diff --git a/extra_data/components.py b/extra_data/components.py
index de37663f..49ddb2aa 100644
--- a/extra_data/components.py
+++ b/extra_data/components.py
@@ -1,5 +1,6 @@
 """Interfaces to data from specific instruments
 """
+import inspect
 import logging
 import re
 from copy import copy
@@ -98,6 +99,7 @@ class MultimodDetectorBase:
     # Override in subclass
     _main_data_key = ''  # Key to use for checking data counts match
     _frames_per_entry = 1  # Override if separate pulse dimension in files
+    _modnos_start_at = 0  # Override if module numbers start at 1 (JUNGFRAU)
     module_shape = (0, 0)
     n_modules = 0
 
@@ -282,6 +284,10 @@ def frames_per_train(self):
             raise ValueError(f"Varying number of frames per train: {counts}")
         return counts.pop() * self._frames_per_entry
 
+    @property
+    def n_frames(self):
+        return self.frame_counts.sum() * self._frames_per_entry
+
     def __repr__(self):
         return "<{}: Data interface for detector {!r} with {} modules>".format(
             type(self).__name__, self.detector_name, len(self.source_to_modno),
@@ -407,7 +413,7 @@ def get_array(self, key, *, fill_value=None, roi=(), astype=None):
             Specify e.g. ``np.s_[10:60, 100:200]`` to select pixels within
             each module when reading data. The selection is applied to each
             individual module, so it may only be useful when working with a
             single module.
-        astype: Type
+        astype: dtype
             Data type of the output array. If None (default) the dtype
             matches the input array dtype
@@ -466,6 +472,123 @@ def get_dask_array(self, key, fill_value=None, astype=None):
 
         return self._concat(arrays, modnos, fill_value, astype)
 
+    def _get_data(self, key, *, fill_value=None, roi=(), astype=None):
+        """Get data as a plain NumPy array with no labels"""
+        train_ids = self.train_ids_perframe
+
+        eg_src = min(self.source_to_modno)
+        eg_keydata = self.data[eg_src, key]
+
+        # roi_shape() gives the shape of 1 frame for 1 module with the ROI applied
+        out_shape = ((self.n_modules, len(train_ids))
+                     + roi_shape(eg_keydata.entry_shape, roi))
+
+        dtype = eg_keydata.dtype if astype is None else np.dtype(astype)
+        out = self._out_array(out_shape, dtype, fill_value=fill_value)
+
+        for modno, source in sorted(self.modno_to_source.items()):
+            mod_ix = modno - self._modnos_start_at
+            for chunk in self.data._find_data_chunks(source, key):
+                for tgt_slice, chunk_slice in self._split_align_chunk(chunk, train_ids):
+                    chunk.dataset.read_direct(
+                        out[mod_ix, tgt_slice], source_sel=(chunk_slice,) + roi
+                    )
+
+        return out
+
+    def _apply_framewise(self, f, out, data_params={}):
+        arr = self._get_data(self._main_data_key)
+        # Reshape to (modules, frames, *pixel_dims)
+        ndim_px = len(self.module_shape)
+        ndim_iter = arr.ndim - 1 - ndim_px
+        arr = arr.reshape((arr.shape[0], -1, *arr.shape[-ndim_px:]))
+
+        # Prepare arrays for data to be passed as kwargs (mask, pulseId, etc.)
+        kw_arrs = {}
+        for param, key in data_params.items():
+            a = self._get_data(key)
+            # ndim_inner may be 0, e.g. for per-frame scalars like pulse IDs
+            ndim_inner = a.ndim - 1 - ndim_iter
+            kw_arrs[param] = a.reshape((a.shape[0], -1, *a.shape[a.ndim - ndim_inner:]))
+
+        for i in range(arr.shape[1]):
+            kw = {p: a[:, i] for (p, a) in kw_arrs.items()}
+            out[i] = f(arr[:, i], **kw)
+
+    def _frame_func_to_chunk_func(self, f, out_shape=None, out_dtype=None):
+        eg_srcdata = self.data[min(self.source_to_modno)]
+        main_group = self._main_data_key.rpartition('.')[0] + '.'
+        data_keys = {k.rpartition('.')[2]: k for k in eg_srcdata.keys()
+                     if k.startswith(main_group)}
+        data_params = {}
+        for param_name in list(inspect.signature(f).parameters)[1:]:
+            if param_name in data_keys:
+                data_params[param_name] = data_keys[param_name]
+            else:
+                raise KeyError(f"No {param_name} data available; "
+                               f"possible names are {', '.join(data_keys)}")
+
+        def chunk_func(chunk):
+            if out_shape is not None:
+                out = chunk._out_array((chunk.n_frames, *out_shape), dtype=out_dtype)
+            else:
+                out = [None] * chunk.n_frames
+            chunk._apply_framewise(f, out, data_params)
+            return out
+
+        return chunk_func
+
+    def map_frames(
+        self, f, mapper=None, *,
+        out=None, out_shape=None, out_dtype=None,
+        parts=None, trains_per_part=None, frames_per_part=None
+    ):
+        """Apply a function to each frame of detector data, in parallel.
+
+        The function is called as ``f(frame_data, ...)`` with an array of
+        shape ``(modules, *module_shape)`` for each frame. Any further
+        parameter names are matched against keys in the same group as the
+        main data key (e.g. ``gain`` alongside ``data.adc``) and passed
+        per frame.
+        """
+        if mapper is None:
+            # Default to using multiprocessing with up to 16 cores.
+            # We're likely to spend a fair bit of time reading data, so
+            # more processes than this are unlikely to help.
+            import multiprocessing
+            with multiprocessing.Pool(min(multiprocessing.cpu_count(), 16)) as p:
+                return self.map_frames(
+                    f, p.imap, out=out, out_shape=out_shape, out_dtype=out_dtype,
+                    parts=parts, trains_per_part=trains_per_part,
+                    frames_per_part=frames_per_part,
+                )
+
+        if parts is None and trains_per_part is None and frames_per_part is None:
+            # Default ~4 GiB chunks for 1 MPx detectors. This is probably too
+            # big for all cores in parallel on one node, but in many cases the
+            # limiting step will be loading data, so you want fewer workers
+            # (or split it over multiple nodes)
+            frames_per_part = 1000
+        chunks = list(self.split_trains(parts, trains_per_part, frames_per_part))
+
+        map_kwargs = {}
+        if 'key' in inspect.signature(mapper).parameters:
+            # Dask workaround: avoid pickling & clumsily md5-ing the function
+            # to produce task keys
+            from secrets import token_hex
+            map_kwargs['key'] = [f"map-frames-{token_hex(16)}" for _ in chunks]
+
+        chunk_func = self._frame_func_to_chunk_func(f, out_shape, out_dtype)
+        results_iter = mapper(chunk_func, chunks, **map_kwargs)
+
+        if out is None:
+            if out_shape is not None:
+                out = self._out_array((self.n_frames, *out_shape), dtype=out_dtype)
+            else:
+                out = [None] * self.n_frames
+
+        # Assemble per-chunk results into the output list/array
+        out_cursor = 0
+        for chunk_res in results_iter:
+            if hasattr(chunk_res, 'result'):
+                # Dask returns futures rather than direct results
+                chunk_res = chunk_res.result()
+            to = out_cursor + len(chunk_res)
+            out[out_cursor : to] = chunk_res
+            out_cursor = to
+
+        return out
+
     def trains(self, require_all=True):
         """Iterate over trains for detector data.
 
@@ -1429,6 +1552,7 @@ class JUNGFRAU(MultimodDetectorBase):
         r'(MODULE_|RECEIVER-|JNGFR)(?P<modno>\d+)'
     )
     _main_data_key = 'data.adc'
+    _modnos_start_at = 1
     module_shape = (512, 1024)
 
     def __init__(self, data: DataCollection, detector_name=None, modules=None,
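A sketch of how the new `map_frames()` API could be used, based on the code above. The proposal/run numbers and the summing logic are illustrative, not from this patch, and it assumes the run's JUNGFRAU sources include a `data.gain` key. The first parameter of the frame function receives the main data (`data.adc` for JUNGFRAU) stacked over modules for one frame; any further parameters are matched by name against sibling keys under the same group (e.g. `gain` for `data.gain`) and passed per frame:

```python
import numpy as np
from extra_data import open_run
from extra_data.components import JUNGFRAU

run = open_run(proposal=700000, run=1)  # placeholder proposal & run
jf = JUNGFRAU(run)

def frame_sum(adc, gain):
    # adc & gain arrive as (modules, 512, 1024) arrays for one frame;
    # sum the ADC values recorded in the high-gain stage (gain == 0)
    return adc[gain == 0].sum()

# One scalar per frame, assembled into a (jf.n_frames,) array.
# With no mapper given, this runs in a multiprocessing pool by default.
sums = jf.map_frames(frame_sum, out_shape=(), out_dtype=np.float64)
```

Since `mapper` can be anything with a `map()`-like signature, the same call should also work with e.g. a Dask distributed client's `client.map` in place of the default multiprocessing pool, which is what the `key` workaround in the code above caters for.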