1414"""Spatial verification metrics."""
1515
1616import dataclasses
17- from typing import Iterable , Mapping , Union
17+ from typing import Collection , Iterable , Mapping , Union
1818import numpy as np
1919from scipy import ndimage
2020from weatherbenchX .metrics import base
@@ -99,6 +99,41 @@ def neighborhood_averaging(
9999 )
100100
101101
102+ def get_fss_mask (
103+ predictions : xr .DataArray ,
104+ targets : xr .DataArray ,
105+ neighborhood_size : Union [int , Iterable [int ]],
106+ wrap_longitude : bool = False ,
107+ ) -> xr .DataArray :
108+ """Get mask for FSS.
109+
110+ The mask is True for pixels where FSS is valid, based on neighborhood
111+ averaging method used in FSS, which propagates NaNs and applies boundary
112+ zeroing for non-wrap-around case.
113+
114+ If any of predictions or targets is NaN in a neighborhood of a pixel,
115+ the neighborhood averaging for that pixel will result in NaN,
116+ unless it's on a boundary that gets zeroed out when wrap_longitude=False.
117+ This mask is True where neighborhood averaging doesn't produce NaN.
118+ This matches masking logic in SquaredPredictionFraction and
119+ SquaredTargetFraction.
120+
121+ Args:
122+ predictions: Predictions DataArray.
123+ targets: Targets DataArray.
124+ neighborhood_size: Neighborhood size for convolution.
125+ wrap_longitude: Whether to wrap longitude in convolution.
126+
127+ Returns:
128+ Boolean mask DataArray.
129+ """
130+ masked_preds = predictions .where (~ targets .isnull ())
131+ neighborhood_preds = neighborhood_averaging (
132+ masked_preds , neighborhood_size , wrap_longitude
133+ )
134+ return ~ neighborhood_preds .isnull ()
135+
136+
102137def get_suffix (
103138 neighborhood_size : Union [int , Iterable [int ]],
104139 wrap_longitude : bool = False ,
@@ -129,13 +164,21 @@ def _compute_per_variable(
129164 predictions : xr .DataArray ,
130165 targets : xr .DataArray ,
131166 ) -> xr .DataArray :
167+ mask = get_fss_mask (
168+ predictions ,
169+ targets ,
170+ self .neighborhood_size_in_pixels ,
171+ self .wrap_longitude ,
172+ )
132173 predictions = neighborhood_averaging (
133174 predictions , self .neighborhood_size_in_pixels , self .wrap_longitude
134175 )
135176 targets = neighborhood_averaging (
136177 targets , self .neighborhood_size_in_pixels , self .wrap_longitude
137178 )
138- return (predictions - targets ) ** 2
179+ result = (predictions - targets ) ** 2
180+ result = result .assign_coords (mask = mask )
181+ return result
139182
140183
141184@dataclasses .dataclass
@@ -155,10 +198,18 @@ def _compute_per_variable(
155198 predictions : xr .DataArray ,
156199 targets : xr .DataArray ,
157200 ) -> xr .DataArray :
201+ mask = get_fss_mask (
202+ predictions ,
203+ targets ,
204+ self .neighborhood_size_in_pixels ,
205+ self .wrap_longitude ,
206+ )
158207 predictions = neighborhood_averaging (
159208 predictions , self .neighborhood_size_in_pixels , self .wrap_longitude
160209 )
161- return predictions ** 2 + xr .zeros_like (targets )
210+ result = predictions ** 2
211+ result = result .assign_coords (mask = mask )
212+ return result
162213
163214
164215@dataclasses .dataclass
@@ -178,10 +229,18 @@ def _compute_per_variable(
178229 predictions : xr .DataArray ,
179230 targets : xr .DataArray ,
180231 ) -> xr .DataArray :
232+ mask = get_fss_mask (
233+ predictions ,
234+ targets ,
235+ self .neighborhood_size_in_pixels ,
236+ self .wrap_longitude ,
237+ )
181238 targets = neighborhood_averaging (
182239 targets , self .neighborhood_size_in_pixels , self .wrap_longitude
183240 )
184- return targets ** 2 + xr .zeros_like (predictions )
241+ result = targets ** 2
242+ result = result .assign_coords (mask = mask )
243+ return result
185244
186245
187246@dataclasses .dataclass
0 commit comments