@@ -112,8 +112,16 @@ def __init__(
112112 # get target channel weights from stream config
113113 self .target_channel_weights = self .parse_target_channel_weights ()
114114
115- self .geoinfo_channels = []
116- self .geoinfo_idx = []
115+ # select/filter requested geoinfo channels (can be any variable, not just constant-in-time)
116+ self .geoinfo_idx = self .select_geoinfo_channels (ds0 )
117+ self .geoinfo_channels = [ds .variables [i ] for i in self .geoinfo_idx ]
118+ # set geoinfo normalization statistics
119+ if len (self .geoinfo_idx ) > 0 :
120+ self .mean_geoinfo = ds .statistics ["mean" ][self .geoinfo_idx ]
121+ self .stdev_geoinfo = ds .statistics ["stdev" ][self .geoinfo_idx ]
122+ else :
123+ self .mean_geoinfo = np .zeros (0 )
124+ self .stdev_geoinfo = np .ones (0 )
117125
118126 ds_name = stream_info ["name" ]
119127 _logger .info (f"{ ds_name } : source channels: { self .source_channels } " )
@@ -177,12 +185,13 @@ def _get(self, idx: TIndex, channels_idx: list[int]) -> ReaderData:
177185 num_data_fields = len (channels_idx ), num_geo_fields = len (self .geoinfo_idx )
178186 )
179187
188+ # coords-first representation and collapse multiple steps
189+ data = data .transpose ([0 , 2 , 1 ]).reshape ((data .shape [0 ] * data .shape [2 ], - 1 ))
190+
191+ # extract geoinfo channels (can be time-varying, so read from dataset)
192+ geoinfos = data [:, list (self .geoinfo_idx )]
180193 # extract channels
181- data = (
182- data [:, list (channels_idx )]
183- .transpose ([0 , 2 , 1 ])
184- .reshape ((data .shape [0 ] * data .shape [2 ], - 1 ))
185- )
194+ data = data [:, list (channels_idx )]
186195
187196 # construct lat/lon coords
188197 latlon = np .concatenate (
@@ -195,9 +204,6 @@ def _get(self, idx: TIndex, channels_idx: list[int]) -> ReaderData:
195204 # repeat latlon len(t_idxs) times
196205 coords = np .vstack ((latlon ,) * len (t_idxs ))
197206
198- # empty geoinfos for anemoi
199- geoinfos = np .zeros ((len (data ), 0 ), dtype = data .dtype )
200-
201207 # date time matching #data points of data
202208 # Assuming a fixed frequency for the dataset
203209 datetimes = np .repeat (self .ds .dates [didx_start :didx_end ], len (data ) // len (t_idxs ))
@@ -254,6 +260,40 @@ def select_channels(self, ds0: anemoi_datasets, ch_type: str) -> NDArray[np.int6
254260
255261 return np .array (chs_idx , dtype = np .int64 )
256262
def select_geoinfo_channels(self, ds0: anemoi_datasets) -> NDArray[np.int64]:
    """
    Select the geoinfo channels requested in the stream config.

    The requested names are read from the "geoinfo_channels" entry of
    self.stream_info and matched exactly against the variable names of the
    dataset (any variable can be used, not just constant-in-time ones).

    Parameters
    ----------
    ds0 :
        raw anemoi dataset with available channels; its typed_variables keys
        are the candidate names and name_to_index maps a name to its index

    Returns
    -------
    NDArray of channel indices for geoinfo variables, sorted in ascending
    order; empty if no channels are requested or none of them is found

    """

    geoinfo_channels = self.stream_info.get("geoinfo_channels", [])

    # an explicit null entry is treated like a missing entry
    if geoinfo_channels is None:
        geoinfo_channels = []
    # a bare string would turn the membership test below into a substring
    # test (e.g. "z" in "z_500"), so handle it as a single channel name
    elif isinstance(geoinfo_channels, str):
        geoinfo_channels = [geoinfo_channels]

    if len(geoinfo_channels) == 0:
        return np.array([], dtype=np.int64)

    # Select channels that match the geoinfo list (exact match required)
    available = list(ds0.typed_variables.keys())
    chs_idx = sorted(ds0.name_to_index[k] for k in available if k in geoinfo_channels)

    if len(chs_idx) == 0:
        stream_name = self.stream_info["name"]
        _logger.warning(
            f"No matching geoinfo channels found for {stream_name}. "
            f"Requested: {geoinfo_channels}"
        )
    else:
        # report requested channels that are silently dropped otherwise
        missing = [k for k in geoinfo_channels if k not in available]
        if len(missing) > 0:
            stream_name = self.stream_info["name"]
            _logger.warning(
                f"Some geoinfo channels not found for {stream_name}. Missing: {missing}"
            )

    return np.array(chs_idx, dtype=np.int64)
296+
257297
258298def _clip_lat (lats : NDArray ) -> NDArray [np .float32 ]:
259299 """
0 commit comments