Skip to content

Commit f1719d8

Browse files
Introduce geoinfo_channels for anemoi datasets (#1760)
* implemented * remove eval in interface * lint * incorporate requested changes * fix imports * Introduced geoinfo_channel parameter for anemoi dataset --------- Co-authored-by: moritzhauschulz <[email protected]>
1 parent 40aa725 commit f1719d8

File tree

2 files changed

+53
-11
lines changed

2 files changed

+53
-11
lines changed

src/weathergen/datasets/data_reader_anemoi.py

Lines changed: 50 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -112,8 +112,16 @@ def __init__(
112112
# get target channel weights from stream config
113113
self.target_channel_weights = self.parse_target_channel_weights()
114114

115-
self.geoinfo_channels = []
116-
self.geoinfo_idx = []
115+
# select/filter requested geoinfo channels (can be any variable, not just constant-in-time)
116+
self.geoinfo_idx = self.select_geoinfo_channels(ds0)
117+
self.geoinfo_channels = [ds.variables[i] for i in self.geoinfo_idx]
118+
# set geoinfo normalization statistics
119+
if len(self.geoinfo_idx) > 0:
120+
self.mean_geoinfo = ds.statistics["mean"][self.geoinfo_idx]
121+
self.stdev_geoinfo = ds.statistics["stdev"][self.geoinfo_idx]
122+
else:
123+
self.mean_geoinfo = np.zeros(0)
124+
self.stdev_geoinfo = np.ones(0)
117125

118126
ds_name = stream_info["name"]
119127
_logger.info(f"{ds_name}: source channels: {self.source_channels}")
@@ -177,12 +185,13 @@ def _get(self, idx: TIndex, channels_idx: list[int]) -> ReaderData:
177185
num_data_fields=len(channels_idx), num_geo_fields=len(self.geoinfo_idx)
178186
)
179187

188+
# coords-first representation and collapse multiple steps
189+
data = data.transpose([0, 2, 1]).reshape((data.shape[0] * data.shape[2], -1))
190+
191+
# extract geoinfo channels (can be time-varying, so read from dataset)
192+
geoinfos = data[:, list(self.geoinfo_idx)]
180193
# extract channels
181-
data = (
182-
data[:, list(channels_idx)]
183-
.transpose([0, 2, 1])
184-
.reshape((data.shape[0] * data.shape[2], -1))
185-
)
194+
data = data[:, list(channels_idx)]
186195

187196
# construct lat/lon coords
188197
latlon = np.concatenate(
@@ -195,9 +204,6 @@ def _get(self, idx: TIndex, channels_idx: list[int]) -> ReaderData:
195204
# repeat latlon len(t_idxs) times
196205
coords = np.vstack((latlon,) * len(t_idxs))
197206

198-
# empty geoinfos for anemoi
199-
geoinfos = np.zeros((len(data), 0), dtype=data.dtype)
200-
201207
# date time matching #data points of data
202208
# Assuming a fixed frequency for the dataset
203209
datetimes = np.repeat(self.ds.dates[didx_start:didx_end], len(data) // len(t_idxs))
@@ -254,6 +260,40 @@ def select_channels(self, ds0: anemoi_datasets, ch_type: str) -> NDArray[np.int6
254260

255261
return np.array(chs_idx, dtype=np.int64)
256262

263+
def select_geoinfo_channels(self, ds0: anemoi_datasets) -> NDArray[np.int64]:
    """
    Select geoinfo channels (can be any variable, not just constant-in-time)

    The requested channel names are read from the ``geoinfo_channels`` entry
    of the stream config. Names must match the dataset variable names exactly.

    Parameters
    ----------
    ds0 :
        raw anemoi dataset with available channels

    Returns
    -------
    NDArray of channel indices for geoinfo variables, sorted in ascending
    order. Empty if no geoinfo channels are configured or none of the
    requested names exist in the dataset.

    """

    requested = self.stream_info.get("geoinfo_channels", [])

    # nothing requested (missing, None or empty entry): no geoinfo channels
    if requested is None or len(requested) == 0:
        return np.array([], dtype=np.int64)

    # exact-name matching; a set keeps the membership test O(1) per variable
    requested_names = set(requested)
    available_names = list(ds0.typed_variables.keys())
    matched_idx = sorted(
        ds0.name_to_index[name] for name in available_names if name in requested_names
    )

    # report every requested channel that could not be resolved, so that a typo
    # in the config does not silently reduce the number of geoinfo channels
    available_set = set(available_names)
    missing = [name for name in requested if name not in available_set]

    if len(matched_idx) == 0:
        stream_name = self.stream_info["name"]
        _logger.warning(
            f"No matching geoinfo channels found for {stream_name}. "
            f"Requested: {geoinfo_channels_repr(requested)}"
        ) if False else _logger.warning(
            f"No matching geoinfo channels found for {stream_name}. "
            f"Requested: {requested}"
        )
    elif len(missing) > 0:
        stream_name = self.stream_info["name"]
        _logger.warning(
            f"Some requested geoinfo channels were not found for {stream_name}. "
            f"Missing: {missing}"
        )

    return np.array(matched_idx, dtype=np.int64)
296+
257297

258298
def _clip_lat(lats: NDArray) -> NDArray[np.float32]:
259299
"""

src/weathergen/datasets/data_reader_base.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -610,7 +610,9 @@ def normalize_geoinfos(self, geoinfos: NDArray[DType]) -> NDArray[DType]:
610610

611611
assert geoinfos.shape[-1] == len(self.geoinfo_idx), "incorrect number of geoinfo channels"
612612
for i, _ in enumerate(self.geoinfo_idx):
613-
geoinfos[..., i] = (geoinfos[..., i] - self.mean_geoinfo[i]) / self.stdev_geoinfo[i]
613+
# for constant fields, just center the data (resulting in 0s after subtracting mean)
614+
stdev = 1.0 if np.isclose(self.stdev_geoinfo[i], 0) else self.stdev_geoinfo[i]
615+
geoinfos[..., i] = (geoinfos[..., i] - self.mean_geoinfo[i]) / stdev
614616

615617
return geoinfos
616618

0 commit comments

Comments
 (0)