Skip to content

Commit b59beba

Browse files
author
WeatherBenchX authors
committed
Update NaN mask handling for FSS
PiperOrigin-RevId: 869895126
1 parent 859afd0 commit b59beba

2 files changed

Lines changed: 98 additions & 4 deletions

File tree

weatherbenchX/metrics/metrics_test.py

Lines changed: 35 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -227,6 +227,41 @@ def test_fss(self):
227227
])
228228
np.testing.assert_allclose(out, correct_result)
229229

230+
# Test SPF and STF with NaNs, n=1
231+
prediction1 = xr.DataArray(
232+
[[1, 0, np.nan, 1]],
233+
dims=['latitude', 'longitude'],
234+
name='precipitation',
235+
).to_dataset()
236+
target1 = xr.DataArray(
237+
[[0, np.nan, 1, 0]],
238+
dims=['latitude', 'longitude'],
239+
name='precipitation',
240+
).to_dataset()
241+
metrics1 = {
242+
'fss': spatial.FSS(neighborhood_size_in_pixels=1),
243+
}
244+
stats1 = metrics_base.compute_unique_statistics_for_all_metrics(
245+
metrics1, prediction1, target1
246+
)
247+
spf_stat = metrics1['fss'].statistics['SquaredPredictionFraction']
248+
stf_stat = metrics1['fss'].statistics['SquaredTargetFraction']
249+
250+
# SPF: predictions are masked where targets are NaN.
251+
# prediction becomes [1, 0, nan, 1].where(~[F,T,F,F]) = [1, nan, nan, 1]
252+
# Neighborhood size 1: SPF = prediction^2 = [1, nan, nan, 1]
253+
np.testing.assert_allclose(
254+
stats1[spf_stat.unique_name]['precipitation'].values,
255+
[[1.0, np.nan, np.nan, 1.0]],
256+
)
257+
# STF: targets are masked where predictions are NaN.
258+
# target becomes [0, nan, 1, 0].where(~[F,F,T,F]) = [0, nan, nan, 0]
259+
# Neighborhood size 1: STF = target^2 = [0, nan, nan, 0]
260+
np.testing.assert_allclose(
261+
stats1[stf_stat.unique_name]['precipitation'].values,
262+
[[0.0, np.nan, np.nan, 0.0]],
263+
)
264+
230265
def test_wrapped_metric(self):
231266
target = (
232267
test_utils.mock_prediction_data(

weatherbenchX/metrics/spatial.py

Lines changed: 63 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -14,7 +14,7 @@
1414
"""Spatial verification metrics."""
1515

1616
import dataclasses
17-
from typing import Iterable, Mapping, Union
17+
from typing import Collection, Iterable, Mapping, Union
1818
import numpy as np
1919
from scipy import ndimage
2020
from weatherbenchX.metrics import base
@@ -99,6 +99,41 @@ def neighborhood_averaging(
9999
)
100100

101101

102+
def get_fss_mask(
103+
predictions: xr.DataArray,
104+
targets: xr.DataArray,
105+
neighborhood_size: Union[int, Iterable[int]],
106+
wrap_longitude: bool = False,
107+
) -> xr.DataArray:
108+
"""Get mask for FSS.
109+
110+
The mask is True for pixels where FSS is valid, based on neighborhood
111+
averaging method used in FSS, which propagates NaNs and applies boundary
112+
zeroing for the non-wrap-around case.
113+
114+
If either predictions or targets are NaN within a pixel's neighborhood,
115+
the neighborhood averaging for that pixel will result in NaN,
116+
unless it's on a boundary that gets zeroed out when wrap_longitude=False.
117+
This mask is True where neighborhood averaging doesn't produce NaN.
118+
This matches masking logic in SquaredPredictionFraction and
119+
SquaredTargetFraction.
120+
121+
Args:
122+
predictions: Predictions DataArray.
123+
targets: Targets DataArray.
124+
neighborhood_size: Neighborhood size for convolution.
125+
wrap_longitude: Whether to wrap longitude in convolution.
126+
127+
Returns:
128+
Boolean mask DataArray.
129+
"""
130+
masked_preds = predictions.where(~targets.isnull())
131+
neighborhood_preds = neighborhood_averaging(
132+
masked_preds, neighborhood_size, wrap_longitude
133+
)
134+
return ~neighborhood_preds.isnull()
135+
136+
102137
def get_suffix(
103138
neighborhood_size: Union[int, Iterable[int]],
104139
wrap_longitude: bool = False,
@@ -129,13 +164,21 @@ def _compute_per_variable(
129164
predictions: xr.DataArray,
130165
targets: xr.DataArray,
131166
) -> xr.DataArray:
167+
mask = get_fss_mask(
168+
predictions,
169+
targets,
170+
self.neighborhood_size_in_pixels,
171+
self.wrap_longitude,
172+
)
132173
predictions = neighborhood_averaging(
133174
predictions, self.neighborhood_size_in_pixels, self.wrap_longitude
134175
)
135176
targets = neighborhood_averaging(
136177
targets, self.neighborhood_size_in_pixels, self.wrap_longitude
137178
)
138-
return (predictions - targets) ** 2
179+
result = (predictions - targets) ** 2
180+
result = result.assign_coords(mask=mask)
181+
return result
139182

140183

141184
@dataclasses.dataclass
@@ -155,10 +198,18 @@ def _compute_per_variable(
155198
predictions: xr.DataArray,
156199
targets: xr.DataArray,
157200
) -> xr.DataArray:
201+
mask = get_fss_mask(
202+
predictions,
203+
targets,
204+
self.neighborhood_size_in_pixels,
205+
self.wrap_longitude,
206+
)
158207
predictions = neighborhood_averaging(
159208
predictions, self.neighborhood_size_in_pixels, self.wrap_longitude
160209
)
161-
return predictions**2 + xr.zeros_like(targets)
210+
result = predictions**2
211+
result = result.assign_coords(mask=mask)
212+
return result
162213

163214

164215
@dataclasses.dataclass
@@ -178,10 +229,18 @@ def _compute_per_variable(
178229
predictions: xr.DataArray,
179230
targets: xr.DataArray,
180231
) -> xr.DataArray:
232+
mask = get_fss_mask(
233+
predictions,
234+
targets,
235+
self.neighborhood_size_in_pixels,
236+
self.wrap_longitude,
237+
)
181238
targets = neighborhood_averaging(
182239
targets, self.neighborhood_size_in_pixels, self.wrap_longitude
183240
)
184-
return targets**2 + xr.zeros_like(predictions)
241+
result = targets**2
242+
result = result.assign_coords(mask=mask)
243+
return result
185244

186245

187246
@dataclasses.dataclass

0 commit comments

Comments (0)