From 4649d87d9670888c89eeee03d18fafd1848b9e39 Mon Sep 17 00:00:00 2001 From: Bouwe Andela Date: Fri, 6 Sep 2024 21:50:02 +0200 Subject: [PATCH 1/8] Parallel load datasets --- esmvalcore/dataset.py | 76 ++++++++++++++++++++++++++++++++++++++----- 1 file changed, 67 insertions(+), 9 deletions(-) diff --git a/esmvalcore/dataset.py b/esmvalcore/dataset.py index d4bd665aa6..d3efdb4e96 100644 --- a/esmvalcore/dataset.py +++ b/esmvalcore/dataset.py @@ -12,6 +12,7 @@ from pathlib import Path from typing import Any, Iterator, Sequence, Union +import dask from iris.cube import Cube from esmvalcore import esgf, local @@ -79,6 +80,10 @@ def _ismatch(facet_value: FacetValue, pattern: FacetValue) -> bool: and fnmatchcase(facet_value, pattern)) +def _first(elems): + return elems[0] + + class Dataset: """Define datasets, find the related files, and load them. @@ -664,9 +669,19 @@ def files(self) -> Sequence[File]: def files(self, value): self._files = value - def load(self) -> Cube: + def load(self, compute=True) -> Cube: """Load dataset. + Parameters + ---------- + compute: + If :obj:`True`, return the cube immediately. If :obj:`False`, + return a :class:`~dask.delayed.Delayed` object that can be used + to load the cube by calling its + :func:`~dask.delayed.Delayed.compute` method. Multiple datasets + can be loaded in parallel by passing a list of such delayeds + to :func:`dask.compute`. + Raises ------ InputFilesNotFound @@ -689,7 +704,7 @@ def load(self) -> Cube: supplementary_cubes.append(supplementary_cube) output_file = _get_output_file(self.facets, self.session.preproc_dir) - cubes = preprocess( + cubes = dask.delayed(preprocess)( [cube], 'add_supplementary_variables', input_files=input_files, @@ -698,7 +713,10 @@ def load(self) -> Cube: supplementary_cubes=supplementary_cubes, ) - return cubes[0] + cube = dask.delayed(_first)(cubes) + if compute: + return cube.compute() + return cube def _load(self) -> Cube: """Load self.files into an iris cube and return it.""" @@ -763,21 +781,61 @@ def _load(self) -> Cube: 'short_name': self.facets['short_name'], } - result = [ + input_files = [ file.local_file(self.session['download_dir']) if isinstance( file, esgf.ESGFFile) else file for file in self.files ] - for step, kwargs in settings.items(): - result = preprocess( + + debug = self.session['save_intermediary_cubes'] + + result = [] + for input_file in input_files: + files = dask.delayed(preprocess)( + [input_file], + 'fix_file', + input_files=[input_file], + output_file=output_file, + debug=debug, + **settings['fix_file'], + ) + cubes = dask.delayed(preprocess)( + files, + 'load', + input_files=[input_file], + output_file=output_file, + debug=debug, + **settings['load'], + ) + cubes = dask.delayed(preprocess)( + cubes, + 'fix_metadata', + input_files=[input_file], + output_file=output_file, + debug=debug, + **settings['fix_metadata'], + ) + cube = dask.delayed(_first)(cubes) + result.append(cube) + + result = dask.delayed(preprocess)( + result, + 'concatenate', + input_files=input_files, + output_file=output_file, + debug=debug, + **settings['concatenate'], + ) + for step, kwargs in dict(tuple(settings.items())[4:]).items(): + result = dask.delayed(preprocess)( result, step, - input_files=self.files, + input_files=input_files, output_file=output_file, - debug=self.session['save_intermediary_cubes'], + debug=debug, **kwargs, ) - cube = result[0] + cube = dask.delayed(_first)(result) return cube def from_ranges(self) -> list['Dataset']: From bc889ba50a3176b9b840cc9f8259483528cbd45c Mon Sep 17 00:00:00 2001 
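[Usage sketch, not part of the patch series: the compute keyword introduced in this
patch lets a caller collect the Delayed object for each dataset and load several
datasets in parallel with dask.compute. The facet values below are placeholders
borrowed from the test data in patch 2; any datasets that can be found locally work
the same way.]

    import dask

    from esmvalcore.dataset import Dataset

    datasets = [
        Dataset(
            short_name=short_name,
            mip="Amon",
            project="CMIP5",
            dataset="CanESM2",
            ensemble="r1i1p1",
            exp="historical",
            timerange="1850/185002",
        )
        for short_name in ("tas", "pr")  # placeholder variable list
    ]
    for dataset in datasets:
        dataset.find_files()

    # Build one Delayed per dataset; no data is read yet.
    delayeds = [dataset.load(compute=False) for dataset in datasets]
    # Load all cubes in parallel on the configured Dask scheduler.
    cubes = dask.compute(*delayeds)
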
From: Bouwe Andela Date: Mon, 9 Sep 2024 10:49:47 +0200 Subject: [PATCH 2/8] Add test and improve documentation --- esmvalcore/dataset.py | 91 +++++++++++++---------- tests/integration/dataset/test_dataset.py | 12 ++- 2 files changed, 61 insertions(+), 42 deletions(-) diff --git a/esmvalcore/dataset.py b/esmvalcore/dataset.py index d3efdb4e96..60a1b5cb29 100644 --- a/esmvalcore/dataset.py +++ b/esmvalcore/dataset.py @@ -6,13 +6,15 @@ import re import textwrap import uuid +from collections.abc import Iterable from copy import deepcopy from fnmatch import fnmatchcase from itertools import groupby from pathlib import Path -from typing import Any, Iterator, Sequence, Union +from typing import Any, Iterator, Sequence, TypeVar, Union import dask +from dask.delayed import Delayed from iris.cube import Cube from esmvalcore import esgf, local @@ -80,8 +82,12 @@ def _ismatch(facet_value: FacetValue, pattern: FacetValue) -> bool: and fnmatchcase(facet_value, pattern)) -def _first(elems): - return elems[0] +T = TypeVar('T') + + +def _first(elems: Iterable[T]) -> T: + """Return the first element.""" + return next(iter(elems)) class Dataset: @@ -669,16 +675,16 @@ def files(self) -> Sequence[File]: def files(self, value): self._files = value - def load(self, compute=True) -> Cube: + def load(self, compute=True) -> Cube | Delayed: """Load dataset. Parameters ---------- compute: - If :obj:`True`, return the cube immediately. If :obj:`False`, - return a :class:`~dask.delayed.Delayed` object that can be used - to load the cube by calling its - :func:`~dask.delayed.Delayed.compute` method. Multiple datasets + If :obj:`True`, return the :class:`~iris.cube.Cube` immediately. + If :obj:`False`, return a :class:`~dask.delayed.Delayed` object + that can be used to load the cube by calling its + :meth:`~dask.delayed.Delayed.compute` method. Multiple datasets can be loaded in parallel by passing a list of such delayeds to :func:`dask.compute`. @@ -731,7 +737,14 @@ def _load(self) -> Cube: msg = "\n".join(lines) raise InputFilesNotFound(msg) + input_files = [ + file.local_file(self.session['download_dir']) if isinstance( + file, esgf.ESGFFile) else file for file in self.files + ] output_file = _get_output_file(self.facets, self.session.preproc_dir) + debug = self.session['save_intermediary_cubes'] + + # Load all input files and concatenate them. 
fix_dir_prefix = Path( self.session._fixed_file_dir, self._get_joined_summary_facets('_', join_lists=True) + '_', @@ -757,36 +770,6 @@ def _load(self) -> Cube: settings['concatenate'] = { 'check_level': self.session['check_level'] } - settings['cmor_check_metadata'] = { - 'check_level': self.session['check_level'], - 'cmor_table': self.facets['project'], - 'mip': self.facets['mip'], - 'frequency': self.facets['frequency'], - 'short_name': self.facets['short_name'], - } - if 'timerange' in self.facets: - settings['clip_timerange'] = { - 'timerange': self.facets['timerange'], - } - settings['fix_data'] = { - 'check_level': self.session['check_level'], - 'session': self.session, - **self.facets, - } - settings['cmor_check_data'] = { - 'check_level': self.session['check_level'], - 'cmor_table': self.facets['project'], - 'mip': self.facets['mip'], - 'frequency': self.facets['frequency'], - 'short_name': self.facets['short_name'], - } - - input_files = [ - file.local_file(self.session['download_dir']) if isinstance( - file, esgf.ESGFFile) else file for file in self.files - ] - - debug = self.session['save_intermediary_cubes'] result = [] for input_file in input_files: @@ -798,6 +781,7 @@ def _load(self) -> Cube: debug=debug, **settings['fix_file'], ) + # Multiple cubes may be present in a file. cubes = dask.delayed(preprocess)( files, 'load', @@ -806,6 +790,7 @@ def _load(self) -> Cube: debug=debug, **settings['load'], ) + # Combine the cubes into a single cube per file. cubes = dask.delayed(preprocess)( cubes, 'fix_metadata', @@ -817,6 +802,7 @@ def _load(self) -> Cube: cube = dask.delayed(_first)(cubes) result.append(cube) + # Concatenate the cubes from all files. result = dask.delayed(preprocess)( result, 'concatenate', @@ -825,7 +811,34 @@ def _load(self) -> Cube: debug=debug, **settings['concatenate'], ) - for step, kwargs in dict(tuple(settings.items())[4:]).items(): + + # At this point `result` is a list containing a single cube. Apply the + # remaining preprocessor functions to this cube. 
+ settings.clear() + settings['cmor_check_metadata'] = { + 'check_level': self.session['check_level'], + 'cmor_table': self.facets['project'], + 'mip': self.facets['mip'], + 'frequency': self.facets['frequency'], + 'short_name': self.facets['short_name'], + } + if 'timerange' in self.facets: + settings['clip_timerange'] = { + 'timerange': self.facets['timerange'], + } + settings['fix_data'] = { + 'check_level': self.session['check_level'], + 'session': self.session, + **self.facets, + } + settings['cmor_check_data'] = { + 'check_level': self.session['check_level'], + 'cmor_table': self.facets['project'], + 'mip': self.facets['mip'], + 'frequency': self.facets['frequency'], + 'short_name': self.facets['short_name'], + } + for step, kwargs in settings.items(): result = dask.delayed(preprocess)( result, step, diff --git a/tests/integration/dataset/test_dataset.py b/tests/integration/dataset/test_dataset.py index 0c94dc8c48..19d5aa1f4a 100644 --- a/tests/integration/dataset/test_dataset.py +++ b/tests/integration/dataset/test_dataset.py @@ -3,6 +3,7 @@ import iris.coords import iris.cube import pytest +from dask.delayed import Delayed from esmvalcore.config import CFG from esmvalcore.dataset import Dataset @@ -34,7 +35,8 @@ def example_data(tmp_path, monkeypatch): monkeypatch.setitem(CFG, 'output_dir', tmp_path / 'output_dir') -def test_load(example_data): +@pytest.mark.parametrize('lazy', [True, False]) +def test_load(example_data, lazy): tas = Dataset( short_name='tas', mip='Amon', @@ -51,7 +53,11 @@ def test_load(example_data): tas.find_files() print(tas.files) - cube = tas.load() - + if lazy: + result = tas.load(compute=False) + assert isinstance(result, Delayed) + cube = result.compute() + else: + cube = tas.load() assert isinstance(cube, iris.cube.Cube) assert cube.cell_measures() From 0bd3da1ce80971e211921cccf90bae199eac5e57 Mon Sep 17 00:00:00 2001 From: Bouwe Andela Date: Thu, 26 Sep 2024 21:39:32 +0200 Subject: [PATCH 3/8] Use ruff formatting --- esmvalcore/dataset.py | 370 ++++++++++++---------- tests/integration/dataset/test_dataset.py | 67 ++-- 2 files changed, 245 insertions(+), 192 deletions(-) diff --git a/esmvalcore/dataset.py b/esmvalcore/dataset.py index 60a1b5cb29..49c62bcb0d 100644 --- a/esmvalcore/dataset.py +++ b/esmvalcore/dataset.py @@ -1,4 +1,5 @@ """Classes and functions for defining, finding, and loading data.""" + from __future__ import annotations import logging @@ -38,9 +39,9 @@ from esmvalcore.typing import Facets, FacetValue __all__ = [ - 'Dataset', - 'INHERITED_FACETS', - 'datasets_to_recipe', + "Dataset", + "INHERITED_FACETS", + "datasets_to_recipe", ] logger = logging.getLogger(__name__) @@ -48,12 +49,12 @@ File = Union[esgf.ESGFFile, local.LocalFile] INHERITED_FACETS: list[str] = [ - 'dataset', - 'domain', - 'driver', - 'grid', - 'project', - 'timerange', + "dataset", + "domain", + "driver", + "grid", + "project", + "timerange", ] """Inherited facets. 
@@ -72,17 +73,21 @@ def _augment(base: dict, update: dict): def _isglob(facet_value: FacetValue | None) -> bool: """Check if a facet value is a glob pattern.""" - return (isinstance(facet_value, str) - and bool(re.match(r'.*[\*\?]+.*|.*\[.*\].*', facet_value))) + return isinstance(facet_value, str) and bool( + re.match(r".*[\*\?]+.*|.*\[.*\].*", facet_value) + ) def _ismatch(facet_value: FacetValue, pattern: FacetValue) -> bool: """Check if a facet value matches a glob pattern.""" - return (isinstance(pattern, str) and isinstance(facet_value, str) - and fnmatchcase(facet_value, pattern)) + return ( + isinstance(pattern, str) + and isinstance(facet_value, str) + and fnmatchcase(facet_value, pattern) + ) -T = TypeVar('T') +T = TypeVar("T") def _first(elems: Iterable[T]) -> T: @@ -109,25 +114,24 @@ class Dataset: """ _SUMMARY_FACETS = ( - 'short_name', - 'mip', - 'project', - 'dataset', - 'rcm_version', - 'driver', - 'domain', - 'activity', - 'exp', - 'ensemble', - 'grid', - 'version', + "short_name", + "mip", + "project", + "dataset", + "rcm_version", + "driver", + "domain", + "activity", + "exp", + "ensemble", + "grid", + "version", ) """Facets used to create a summary of a Dataset instance.""" def __init__(self, **facets: FacetValue): - self.facets: Facets = {} - self.supplementaries: list['Dataset'] = [] + self.supplementaries: list["Dataset"] = [] self._persist: set[str] = set() self._session: Session | None = None @@ -141,7 +145,7 @@ def __init__(self, **facets: FacetValue): def from_recipe( recipe: Path | str | dict, session: Session, - ) -> list['Dataset']: + ) -> list["Dataset"]: """Read datasets from a recipe. Parameters @@ -160,6 +164,7 @@ def from_recipe( A list of datasets. """ from esmvalcore._recipe.to_datasets import datasets_from_recipe + return datasets_from_recipe(recipe, session) def _file_to_dataset( @@ -168,14 +173,16 @@ def _file_to_dataset( ) -> Dataset: """Create a dataset from a file with a `facets` attribute.""" facets = dict(file.facets) - if 'version' not in self.facets: + if "version" not in self.facets: # Remove version facet if no specific version requested - facets.pop('version', None) + facets.pop("version", None) updated_facets = { f: v - for f, v in facets.items() if f in self.facets - and _isglob(self.facets[f]) and _ismatch(v, self.facets[f]) + for f, v in facets.items() + if f in self.facets + and _isglob(self.facets[f]) + and _ismatch(v, self.facets[f]) } dataset = self.copy() dataset.facets.update(updated_facets) @@ -183,7 +190,7 @@ def _file_to_dataset( # If possible, remove unexpanded facets that can be automatically # populated. 
unexpanded = {f for f, v in dataset.facets.items() if _isglob(v)} - required_for_augment = {'project', 'mip', 'short_name', 'dataset'} + required_for_augment = {"project", "mip", "short_name", "dataset"} if unexpanded and not unexpanded & required_for_augment: copy = dataset.copy() copy.supplementaries = [] @@ -203,10 +210,10 @@ def _get_available_datasets(self) -> Iterator[Dataset]: """ dataset_template = self.copy() dataset_template.supplementaries = [] - if _isglob(dataset_template.facets.get('timerange')): + if _isglob(dataset_template.facets.get("timerange")): # Remove wildcard `timerange` facet, because data finding cannot # handle it - dataset_template.facets.pop('timerange') + dataset_template.facets.pop("timerange") seen = set() partially_defined = [] @@ -217,11 +224,15 @@ def _get_available_datasets(self) -> Iterator[Dataset]: # Filter out identical datasets facetset = frozenset( (f, frozenset(v) if isinstance(v, list) else v) - for f, v in dataset.facets.items()) + for f, v in dataset.facets.items() + ) if facetset not in seen: seen.add(facetset) - if any(_isglob(v) for f, v in dataset.facets.items() - if f != 'timerange'): + if any( + _isglob(v) + for f, v in dataset.facets.items() + if f != "timerange" + ): partially_defined.append((dataset, file)) else: dataset._update_timerange() @@ -231,19 +242,24 @@ def _get_available_datasets(self) -> Iterator[Dataset]: # Only yield datasets with globs if there is no better alternative for dataset, file in partially_defined: - msg = (f"{dataset} with unexpanded wildcards, created from file " - f"{file} with facets {file.facets}. Are the missing facets " - "in the path to the file?" if isinstance( - file, local.LocalFile) else "available on ESGF?") + msg = ( + f"{dataset} with unexpanded wildcards, created from file " + f"{file} with facets {file.facets}. Are the missing facets " + "in the path to the file?" + if isinstance(file, local.LocalFile) + else "available on ESGF?" + ) if expanded: logger.info("Ignoring %s", msg) else: logger.debug( "Not updating timerange and supplementaries for %s " - "because it still contains wildcards.", msg) + "because it still contains wildcards.", + msg, + ) yield dataset - def from_files(self) -> Iterator['Dataset']: + def from_files(self) -> Iterator["Dataset"]: """Create datasets based on the available files. 
The facet values for local files are retrieved from the directory tree @@ -277,17 +293,18 @@ def from_files(self) -> Iterator['Dataset']: """ expanded = False if any(_isglob(v) for v in self.facets.values()): - if _isglob(self.facets['mip']): + if _isglob(self.facets["mip"]): available_mips = _get_mips( - self.facets['project'], # type: ignore - self.facets['short_name'], # type: ignore + self.facets["project"], # type: ignore + self.facets["short_name"], # type: ignore ) mips = [ - mip for mip in available_mips - if _ismatch(mip, self.facets['mip']) + mip + for mip in available_mips + if _ismatch(mip, self.facets["mip"]) ] else: - mips = [self.facets['mip']] # type: ignore + mips = [self.facets["mip"]] # type: ignore for mip in mips: dataset_template = self.copy(mip=mip) @@ -327,7 +344,7 @@ def _remove_unexpanded_supplementaries(self) -> None: "For %s: ignoring supplementary variable '%s', " "unable to expand wildcards %s.", self.summary(shorten=True), - supplementary_ds.facets['short_name'], + supplementary_ds.facets["short_name"], ", ".join(f"'{f}'" for f in unexpanded), ) else: @@ -357,8 +374,9 @@ def _remove_duplicate_supplementaries(self) -> None: not_used = [] supplementaries = list(self.supplementaries) self.supplementaries.clear() - for _, duplicates in groupby(supplementaries, - key=lambda ds: ds['short_name']): + for _, duplicates in groupby( + supplementaries, key=lambda ds: ds["short_name"] + ): group = sorted(duplicates, key=self._match, reverse=True) self.supplementaries.append(group[0]) not_used.extend(group[1:]) @@ -368,27 +386,30 @@ def _remove_duplicate_supplementaries(self) -> None: "List of all supplementary datasets found for %s:\n%s", self.summary(shorten=True), "\n".join( - sorted(ds.summary(shorten=True) - for ds in supplementaries)), + sorted(ds.summary(shorten=True) for ds in supplementaries) + ), ) def _fix_fx_exp(self) -> None: for supplementary_ds in self.supplementaries: - exps = supplementary_ds.facets.get('exp') - frequency = supplementary_ds.facets.get('frequency') - if isinstance(exps, list) and len(exps) > 1 and frequency == 'fx': + exps = supplementary_ds.facets.get("exp") + frequency = supplementary_ds.facets.get("frequency") + if isinstance(exps, list) and len(exps) > 1 and frequency == "fx": for exp in exps: dataset = supplementary_ds.copy(exp=exp) if dataset.files: - supplementary_ds.facets['exp'] = exp + supplementary_ds.facets["exp"] = exp logger.info( "Corrected wrong 'exp' from '%s' to '%s' for " - "supplementary variable '%s' of %s", exps, exp, - supplementary_ds.facets['short_name'], - self.summary(shorten=True)) + "supplementary variable '%s' of %s", + exps, + exp, + supplementary_ds.facets["short_name"], + self.summary(shorten=True), + ) break - def copy(self, **facets: FacetValue) -> 'Dataset': + def copy(self, **facets: FacetValue) -> "Dataset": """Create a copy. Parameters @@ -412,10 +433,9 @@ def copy(self, **facets: FacetValue) -> 'Dataset': for supplementary in self.supplementaries: # The short_name and mip of the supplementary variable are probably # different from the main variable, so don't copy those facets. 
- skip = ('short_name', 'mip') + skip = ("short_name", "mip") supplementary_facets = { - k: v - for k, v in facets.items() if k not in skip + k: v for k, v in facets.items() if k not in skip } new_supplementary = supplementary.copy(**supplementary_facets) new.supplementaries.append(new_supplementary) @@ -423,24 +443,25 @@ def copy(self, **facets: FacetValue) -> 'Dataset': def __eq__(self, other) -> bool: """Compare with another dataset.""" - return (isinstance(other, self.__class__) - and self._session == other._session - and self.facets == other.facets - and self.supplementaries == other.supplementaries) + return ( + isinstance(other, self.__class__) + and self._session == other._session + and self.facets == other.facets + and self.supplementaries == other.supplementaries + ) def __repr__(self) -> str: """Create a string representation.""" first_keys = ( - 'diagnostic', - 'variable_group', - 'dataset', - 'project', - 'mip', - 'short_name', + "diagnostic", + "variable_group", + "dataset", + "project", + "mip", + "short_name", ) def facets2str(facets): - view = {k: facets[k] for k in first_keys if k in facets} for key, value in sorted(facets.items()): if key not in first_keys: @@ -456,7 +477,8 @@ def facets2str(facets): txt.append("supplementaries:") txt.extend( textwrap.indent(facets2str(a.facets), " ") - for a in self.supplementaries) + for a in self.supplementaries + ) if self._session: txt.append(f"session: '{self.session.session_name}'") return "\n".join(txt) @@ -473,7 +495,7 @@ def _get_joined_summary_facets( continue val = self.facets[key] if join_lists and isinstance(val, (tuple, list)): - val = '-'.join(str(elem) for elem in val) + val = "-".join(str(elem) for elem in val) else: val = str(val) summary_facets_vals.append(val) @@ -496,16 +518,23 @@ def summary(self, shorten: bool = False) -> str: return repr(self) title = self.__class__.__name__ - txt = f"{title}: " + self._get_joined_summary_facets(', ') + txt = f"{title}: " + self._get_joined_summary_facets(", ") def supplementary_summary(dataset): return ", ".join( - str(dataset.facets[k]) for k in self._SUMMARY_FACETS - if k in dataset.facets and dataset[k] != self.facets.get(k)) + str(dataset.facets[k]) + for k in self._SUMMARY_FACETS + if k in dataset.facets and dataset[k] != self.facets.get(k) + ) if self.supplementaries: - txt += (", supplementaries: " + "; ".join( - supplementary_summary(a) for a in self.supplementaries) + "") + txt += ( + ", supplementaries: " + + "; ".join( + supplementary_summary(a) for a in self.supplementaries + ) + + "" + ) return txt def __getitem__(self, key): @@ -542,11 +571,11 @@ def set_version(self) -> None: """Set the ``'version'`` facet based on the available data.""" versions: set[str] = set() for file in self.files: - if 'version' in file.facets: - versions.add(file.facets['version']) # type: ignore + if "version" in file.facets: + versions.add(file.facets["version"]) # type: ignore version = versions.pop() if len(versions) == 1 else sorted(versions) if version: - self.set_facet('version', version) + self.set_facet("version", version) for supplementary_ds in self.supplementaries: supplementary_ds.set_version() @@ -594,19 +623,19 @@ def augment_facets(self) -> None: supplementary._augment_facets() def _augment_facets(self): - extra_facets = get_extra_facets(self, self.session['extra_facets_dir']) + extra_facets = get_extra_facets(self, self.session["extra_facets_dir"]) _augment(self.facets, extra_facets) - if 'institute' not in self.facets: + if "institute" not in self.facets: institute = 
get_institutes(self.facets) if institute: - self.facets['institute'] = institute - if 'activity' not in self.facets: + self.facets["institute"] = institute + if "activity" not in self.facets: activity = get_activity(self.facets) if activity: - self.facets['activity'] = activity + self.facets["activity"] = activity _update_cmor_facets(self.facets) - if self.facets.get('frequency') == 'fx': - self.facets.pop('timerange', None) + if self.facets.get("frequency") == "fx": + self.facets.pop("timerange", None) def find_files(self) -> None: """Find files. @@ -616,7 +645,7 @@ def find_files(self) -> None: """ self.augment_facets() - if _isglob(self.facets.get('timerange')): + if _isglob(self.facets.get("timerange")): self._update_timerange() self._find_files() @@ -630,16 +659,16 @@ def _find_files(self) -> None: ) # If project does not support automatic downloads from ESGF, stop here - if self.facets['project'] not in esgf.facets.FACETS: + if self.facets["project"] not in esgf.facets.FACETS: return # 'never' mode: never download files from ESGF and stop here - if self.session['search_esgf'] == 'never': + if self.session["search_esgf"] == "never": return # 'when_missing' mode: if files are available locally, do not check # ESGF - if self.session['search_esgf'] == 'when_missing': + if self.session["search_esgf"] == "when_missing": try: check.data_availability(self, log=False) except InputFilesNotFound: @@ -659,8 +688,8 @@ def _find_files(self) -> None: # Use ESGF files that are newer than the locally available # files. local_file = local_files[file.name] - if 'version' in local_file.facets: - if file.facets['version'] > local_file.facets['version']: + if "version" in local_file.facets: + if file.facets["version"] > local_file.facets["version"]: idx = self.files.index(local_file) self.files[idx] = file @@ -701,7 +730,7 @@ def load(self, compute=True) -> Cube | Delayed: input_files = list(self.files) for supplementary_dataset in self.supplementaries: input_files.extend(supplementary_dataset.files) - esgf.download(input_files, self.session['download_dir']) + esgf.download(input_files, self.session["download_dir"]) cube = self._load() supplementary_cubes = [] @@ -712,10 +741,10 @@ def load(self, compute=True) -> Cube | Delayed: output_file = _get_output_file(self.facets, self.session.preproc_dir) cubes = dask.delayed(preprocess)( [cube], - 'add_supplementary_variables', + "add_supplementary_variables", input_files=input_files, output_file=output_file, - debug=self.session['save_intermediary_cubes'], + debug=self.session["save_intermediary_cubes"], supplementary_cubes=supplementary_cubes, ) @@ -732,72 +761,72 @@ def _load(self) -> Cube: "locally using glob patterns:", "\n".join(str(f) for f in self._file_globs or []), ] - if self.session['search_esgf'] != 'never': - lines.append('or on ESGF.') + if self.session["search_esgf"] != "never": + lines.append("or on ESGF.") msg = "\n".join(lines) raise InputFilesNotFound(msg) input_files = [ - file.local_file(self.session['download_dir']) if isinstance( - file, esgf.ESGFFile) else file for file in self.files + file.local_file(self.session["download_dir"]) + if isinstance(file, esgf.ESGFFile) + else file + for file in self.files ] output_file = _get_output_file(self.facets, self.session.preproc_dir) - debug = self.session['save_intermediary_cubes'] + debug = self.session["save_intermediary_cubes"] # Load all input files and concatenate them. 
fix_dir_prefix = Path( self.session._fixed_file_dir, - self._get_joined_summary_facets('_', join_lists=True) + '_', + self._get_joined_summary_facets("_", join_lists=True) + "_", ) settings: dict[str, dict[str, Any]] = {} - settings['fix_file'] = { - 'output_dir': fix_dir_prefix, - 'add_unique_suffix': True, - 'session': self.session, + settings["fix_file"] = { + "output_dir": fix_dir_prefix, + "add_unique_suffix": True, + "session": self.session, **self.facets, } - settings['load'] = { - 'ignore_warnings': get_ignored_warnings( - self.facets['project'], 'load' + settings["load"] = { + "ignore_warnings": get_ignored_warnings( + self.facets["project"], "load" ), } - settings['fix_metadata'] = { - 'check_level': self.session['check_level'], - 'session': self.session, + settings["fix_metadata"] = { + "check_level": self.session["check_level"], + "session": self.session, **self.facets, } - settings['concatenate'] = { - 'check_level': self.session['check_level'] - } + settings["concatenate"] = {"check_level": self.session["check_level"]} result = [] for input_file in input_files: files = dask.delayed(preprocess)( [input_file], - 'fix_file', + "fix_file", input_files=[input_file], output_file=output_file, debug=debug, - **settings['fix_file'], + **settings["fix_file"], ) # Multiple cubes may be present in a file. cubes = dask.delayed(preprocess)( files, - 'load', + "load", input_files=[input_file], output_file=output_file, debug=debug, - **settings['load'], + **settings["load"], ) # Combine the cubes into a single cube per file. cubes = dask.delayed(preprocess)( cubes, - 'fix_metadata', + "fix_metadata", input_files=[input_file], output_file=output_file, debug=debug, - **settings['fix_metadata'], + **settings["fix_metadata"], ) cube = dask.delayed(_first)(cubes) result.append(cube) @@ -805,38 +834,38 @@ def _load(self) -> Cube: # Concatenate the cubes from all files. result = dask.delayed(preprocess)( result, - 'concatenate', + "concatenate", input_files=input_files, output_file=output_file, debug=debug, - **settings['concatenate'], + **settings["concatenate"], ) # At this point `result` is a list containing a single cube. Apply the # remaining preprocessor functions to this cube. 
settings.clear() - settings['cmor_check_metadata'] = { - 'check_level': self.session['check_level'], - 'cmor_table': self.facets['project'], - 'mip': self.facets['mip'], - 'frequency': self.facets['frequency'], - 'short_name': self.facets['short_name'], + settings["cmor_check_metadata"] = { + "check_level": self.session["check_level"], + "cmor_table": self.facets["project"], + "mip": self.facets["mip"], + "frequency": self.facets["frequency"], + "short_name": self.facets["short_name"], } - if 'timerange' in self.facets: - settings['clip_timerange'] = { - 'timerange': self.facets['timerange'], + if "timerange" in self.facets: + settings["clip_timerange"] = { + "timerange": self.facets["timerange"], } - settings['fix_data'] = { - 'check_level': self.session['check_level'], - 'session': self.session, + settings["fix_data"] = { + "check_level": self.session["check_level"], + "session": self.session, **self.facets, } - settings['cmor_check_data'] = { - 'check_level': self.session['check_level'], - 'cmor_table': self.facets['project'], - 'mip': self.facets['mip'], - 'frequency': self.facets['frequency'], - 'short_name': self.facets['short_name'], + settings["cmor_check_data"] = { + "check_level": self.session["check_level"], + "cmor_table": self.facets["project"], + "mip": self.facets["mip"], + "frequency": self.facets["frequency"], + "short_name": self.facets["short_name"], } for step, kwargs in settings.items(): result = dask.delayed(preprocess)( @@ -851,7 +880,7 @@ def _load(self) -> Cube: cube = dask.delayed(_first)(result) return cube - def from_ranges(self) -> list['Dataset']: + def from_ranges(self) -> list["Dataset"]: """Create a list of datasets from short notations. This expands the ``'ensemble'`` and ``'sub_experiment'`` facets in the @@ -867,10 +896,11 @@ def from_ranges(self) -> list['Dataset']: The datasets. """ datasets = [self] - for key in 'ensemble', 'sub_experiment': + for key in "ensemble", "sub_experiment": if key in self.facets: datasets = [ - ds.copy(**{key: value}) for ds in datasets + ds.copy(**{key: value}) + for ds in datasets for value in ds._expand_range(key) ] return datasets @@ -881,12 +911,12 @@ def _expand_range(self, input_tag): Expansion only supports ensembles defined as strings, not lists. 
""" expanded = [] - regex = re.compile(r'\(\d+:\d+\)') + regex = re.compile(r"\(\d+:\d+\)") def expand_range(input_range): match = regex.search(input_range) if match: - start, end = match.group(0)[1:-1].split(':') + start, end = match.group(0)[1:-1].split(":") for i in range(int(start), int(end) + 1): range_ = regex.sub(str(i), input_range, 1) expand_range(range_) @@ -899,7 +929,8 @@ def expand_range(input_range): if regex.search(elem): raise RecipeError( f"In {self}: {input_tag} expansion " - f"cannot be combined with {input_tag} lists") + f"cannot be combined with {input_tag} lists" + ) expanded.append(tag) else: expand_range(tag) @@ -915,19 +946,20 @@ def _update_timerange(self): dataset = self.copy() dataset.supplementaries = [] dataset.augment_facets() - if 'timerange' not in dataset.facets: - self.facets.pop('timerange', None) + if "timerange" not in dataset.facets: + self.facets.pop("timerange", None) return - timerange = self.facets['timerange'] + timerange = self.facets["timerange"] if not isinstance(timerange, str): raise TypeError( - f"timerange should be a string, got '{timerange!r}'") + f"timerange should be a string, got '{timerange!r}'" + ) check.valid_time_selection(timerange) - if '*' in timerange: + if "*" in timerange: dataset = self.copy() - dataset.facets.pop('timerange') + dataset.facets.pop("timerange") dataset.supplementaries = [] check.data_availability(dataset) intervals = [_get_start_end_date(f) for f in dataset.files] @@ -935,16 +967,16 @@ def _update_timerange(self): min_date = min(interval[0] for interval in intervals) max_date = max(interval[1] for interval in intervals) - if timerange == '*': - timerange = f'{min_date}/{max_date}' - if '*' in timerange.split('/')[0]: - timerange = timerange.replace('*', min_date) - if '*' in timerange.split('/')[1]: - timerange = timerange.replace('*', max_date) + if timerange == "*": + timerange = f"{min_date}/{max_date}" + if "*" in timerange.split("/")[0]: + timerange = timerange.replace("*", min_date) + if "*" in timerange.split("/")[1]: + timerange = timerange.replace("*", max_date) # Make sure that years are in format YYYY - start_date, end_date = timerange.split('/') + start_date, end_date = timerange.split("/") timerange = _dates_to_timerange(start_date, end_date) check.valid_time_selection(timerange) - self.set_facet('timerange', timerange) + self.set_facet("timerange", timerange) diff --git a/tests/integration/dataset/test_dataset.py b/tests/integration/dataset/test_dataset.py index 19d5aa1f4a..b95222506e 100644 --- a/tests/integration/dataset/test_dataset.py +++ b/tests/integration/dataset/test_dataset.py @@ -12,17 +12,38 @@ @pytest.fixture def example_data(tmp_path, monkeypatch): cwd = Path(__file__).parent - tas_src = cwd / 'tas.nc' - areacella_src = cwd / 'areacella.nc' - - rootpath = tmp_path / 'climate_data' - tas_tgt = (rootpath / 'cmip5' / 'output1' / 'CCCma' / 'CanESM2' / - 'historical' / 'mon' / 'atmos' / 'Amon' / 'r1i1p1' / - 'v20120718' / - 'tas_Amon_CanESM2_historical_r1i1p1_185001-200512.nc') - areacella_tgt = (rootpath / 'cmip5' / 'output1' / 'CCCma' / 'CanESM2' / - 'historical' / 'fx' / 'atmos' / 'fx' / 'r0i0p0' / - 'v20120410' / 'areacella_fx_CanESM2_historical_r0i0p0.nc') + tas_src = cwd / "tas.nc" + areacella_src = cwd / "areacella.nc" + + rootpath = tmp_path / "climate_data" + tas_tgt = ( + rootpath + / "cmip5" + / "output1" + / "CCCma" + / "CanESM2" + / "historical" + / "mon" + / "atmos" + / "Amon" + / "r1i1p1" + / "v20120718" + / "tas_Amon_CanESM2_historical_r1i1p1_185001-200512.nc" + ) + 
areacella_tgt = ( + rootpath + / "cmip5" + / "output1" + / "CCCma" + / "CanESM2" + / "historical" + / "fx" + / "atmos" + / "fx" + / "r0i0p0" + / "v20120410" + / "areacella_fx_CanESM2_historical_r0i0p0.nc" + ) tas_tgt.parent.mkdir(parents=True, exist_ok=True) tas_tgt.symlink_to(tas_src) @@ -30,23 +51,23 @@ def example_data(tmp_path, monkeypatch): areacella_tgt.parent.mkdir(parents=True, exist_ok=True) areacella_tgt.symlink_to(areacella_src) - monkeypatch.setitem(CFG, 'rootpath', {'CMIP5': str(rootpath)}) - monkeypatch.setitem(CFG, 'drs', {'CMIP5': 'ESGF'}) - monkeypatch.setitem(CFG, 'output_dir', tmp_path / 'output_dir') + monkeypatch.setitem(CFG, "rootpath", {"CMIP5": str(rootpath)}) + monkeypatch.setitem(CFG, "drs", {"CMIP5": "ESGF"}) + monkeypatch.setitem(CFG, "output_dir", tmp_path / "output_dir") -@pytest.mark.parametrize('lazy', [True, False]) +@pytest.mark.parametrize("lazy", [True, False]) def test_load(example_data, lazy): tas = Dataset( - short_name='tas', - mip='Amon', - project='CMIP5', - dataset='CanESM2', - ensemble='r1i1p1', - exp='historical', - timerange='1850/185002', + short_name="tas", + mip="Amon", + project="CMIP5", + dataset="CanESM2", + ensemble="r1i1p1", + exp="historical", + timerange="1850/185002", ) - tas.add_supplementary(short_name='areacella', mip='fx', ensemble='r0i0p0') + tas.add_supplementary(short_name="areacella", mip="fx", ensemble="r0i0p0") tas.augment_facets() From 4057950891f7c4a403b08f9c93f534732cc8fa86 Mon Sep 17 00:00:00 2001 From: Bouwe Andela Date: Thu, 26 Sep 2024 21:41:34 +0200 Subject: [PATCH 4/8] Add type hint --- esmvalcore/dataset.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/esmvalcore/dataset.py b/esmvalcore/dataset.py index 49c62bcb0d..024a4d08f0 100644 --- a/esmvalcore/dataset.py +++ b/esmvalcore/dataset.py @@ -704,7 +704,7 @@ def files(self) -> Sequence[File]: def files(self, value): self._files = value - def load(self, compute=True) -> Cube | Delayed: + def load(self, compute: bool = True) -> Cube | Delayed: """Load dataset. Parameters From 81b41a09099b33b5ace76022b4a5fc81930d85fb Mon Sep 17 00:00:00 2001 From: Bouwe Andela Date: Thu, 24 Oct 2024 11:11:35 +0200 Subject: [PATCH 5/8] Mark preprocessor functions that modify input as not pure --- esmvalcore/dataset.py | 8 +++++--- 1 file changed, 5 insertions(+), 3 deletions(-) diff --git a/esmvalcore/dataset.py b/esmvalcore/dataset.py index 3d7d067173..33dcdbf9cd 100644 --- a/esmvalcore/dataset.py +++ b/esmvalcore/dataset.py @@ -819,7 +819,7 @@ def _load(self) -> Cube: **settings["load"], ) # Combine the cubes into a single cube per file. - cubes = dask.delayed(preprocess)( + cubes = dask.delayed(preprocess, pure=False)( cubes, "fix_metadata", input_files=[input_file], @@ -831,7 +831,7 @@ def _load(self) -> Cube: result.append(cube) # Concatenate the cubes from all files. 
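    # When pure=True, dask.delayed hashes the function and its arguments into a
    # deterministic task key, so identical calls may be merged into a single
    # task. The steps wrapped with pure=False here (fix_metadata above,
    # concatenate and fix_data below) modify their input cubes in place, so
    # each call is given a unique key and kept as a separate task in the graph.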
- result = dask.delayed(preprocess)( + result = dask.delayed(preprocess, pure=False)( result, "concatenate", input_files=input_files, @@ -855,6 +855,7 @@ def _load(self) -> Cube: "timerange": self.facets["timerange"], } settings["fix_data"] = { + "pure": False, "session": self.session, **self.facets, } @@ -866,7 +867,8 @@ def _load(self) -> Cube: "short_name": self.facets["short_name"], } for step, kwargs in settings.items(): - result = dask.delayed(preprocess)( + pure = settings.pop("pure", True) + result = dask.delayed(preprocess, pure=pure)( result, step, input_files=input_files, From d5a39af44a176884c2012c49b451eb304438c0b2 Mon Sep 17 00:00:00 2001 From: Bouwe Andela Date: Fri, 15 Nov 2024 09:50:38 +0100 Subject: [PATCH 6/8] Remove grouping by file inside fix_metadata --- esmvalcore/cmor/fix.py | 27 ++++++--------------------- 1 file changed, 6 insertions(+), 21 deletions(-) diff --git a/esmvalcore/cmor/fix.py b/esmvalcore/cmor/fix.py index ab81353cfb..e75167a7bb 100644 --- a/esmvalcore/cmor/fix.py +++ b/esmvalcore/cmor/fix.py @@ -8,7 +8,6 @@ from __future__ import annotations import logging -from collections import defaultdict from collections.abc import Sequence from pathlib import Path from typing import TYPE_CHECKING, Optional @@ -137,7 +136,7 @@ def fix_metadata( Returns ------- iris.cube.CubeList - Fixed cubes. + A list containing a single fixed cube. """ # Update extra_facets with variable information given as regular arguments @@ -161,27 +160,13 @@ def fix_metadata( session=session, frequency=frequency, ) - fixed_cubes = CubeList() - # Group cubes by input file and apply all fixes to each group element - # (i.e., each file) individually - by_file = defaultdict(list) - for cube in cubes: - by_file[cube.attributes.get("source_file", "")].append(cube) + for fix in fixes: + cubes = fix.fix_metadata(cubes) - for cube_list in by_file.values(): - cube_list = CubeList(cube_list) - for fix in fixes: - cube_list = fix.fix_metadata(cube_list) - - # The final fix is always GenericFix, whose fix_metadata method always - # returns a single cube - cube = cube_list[0] - - cube.attributes.pop("source_file", None) - fixed_cubes.append(cube) - - return fixed_cubes + # The final fix is always GenericFix, whose fix_metadata method always + # returns a single cube + return CubeList(cubes[:1]) def fix_data( From 99f875d91da6b15dff5611d3ae1201e49b4ae120 Mon Sep 17 00:00:00 2001 From: Bouwe Andela Date: Sat, 30 Nov 2024 21:31:23 +0100 Subject: [PATCH 7/8] Ensure cubes are in a CubeList Co-authored-by: Manuel Schlund <32543114+schlunma@users.noreply.github.com> --- esmvalcore/cmor/fix.py | 1 + 1 file changed, 1 insertion(+) diff --git a/esmvalcore/cmor/fix.py b/esmvalcore/cmor/fix.py index e75167a7bb..45d281ef51 100644 --- a/esmvalcore/cmor/fix.py +++ b/esmvalcore/cmor/fix.py @@ -161,6 +161,7 @@ def fix_metadata( frequency=frequency, ) + cubes = CubeList(cubes) for fix in fixes: cubes = fix.fix_metadata(cubes) From 2f9c39c93793820dfb06c7ce6e9a357f391a59ab Mon Sep 17 00:00:00 2001 From: Bouwe Andela Date: Sat, 30 Nov 2024 22:32:01 +0100 Subject: [PATCH 8/8] Remove source_file cube attribute after CMOR fixes and checks --- esmvalcore/cmor/check.py | 6 ++++++ tests/unit/test_cmor_api.py | 14 +++++--------- 2 files changed, 11 insertions(+), 9 deletions(-) diff --git a/esmvalcore/cmor/check.py b/esmvalcore/cmor/check.py index a75dcdaab4..de10d74424 100644 --- a/esmvalcore/cmor/check.py +++ b/esmvalcore/cmor/check.py @@ -985,6 +985,12 @@ def cmor_check_data( check_level=check_level, ) cube = 
checker(cube).check_data() + # Remove the "source_file" attribute that `esmvalcore.preprocessor.load` + # adds for CMOR fix and check function logging purposes. This is a bit + # ugly and it would be nice to stop using the "source_file" attribute and + # pass the data source as an argument to those functions that need it + # instead. + cube.attributes.pop("source_file", None) return cube diff --git a/tests/unit/test_cmor_api.py b/tests/unit/test_cmor_api.py index cce1fab9d8..66d69215ae 100644 --- a/tests/unit/test_cmor_api.py +++ b/tests/unit/test_cmor_api.py @@ -41,9 +41,7 @@ def test_cmor_check_metadata(mocker): check_level=sentinel.check_level, ) mock_get_cmor_checker.return_value.assert_called_once_with(sentinel.cube) - ( - mock_get_cmor_checker.return_value.return_value.check_metadata.assert_called_once_with() - ) + mock_get_cmor_checker.return_value.return_value.check_metadata.assert_called_once_with() assert cube == sentinel.checked_cube @@ -52,9 +50,6 @@ def test_cmor_check_data(mocker): mock_get_cmor_checker = mocker.patch.object( esmvalcore.cmor.check, "_get_cmor_checker", autospec=True ) - ( - mock_get_cmor_checker.return_value.return_value.check_data.return_value - ) = sentinel.checked_cube cube = cmor_check_data( sentinel.cube, @@ -73,10 +68,11 @@ def test_cmor_check_data(mocker): check_level=sentinel.check_level, ) mock_get_cmor_checker.return_value.assert_called_once_with(sentinel.cube) - ( - mock_get_cmor_checker.return_value.return_value.check_data.assert_called_once_with() + mock_get_cmor_checker.return_value.return_value.check_data.assert_called_once_with() + checked_cube = ( + mock_get_cmor_checker.return_value.return_value.check_data.return_value ) - assert cube == sentinel.checked_cube + assert cube == checked_cube def test_cmor_check(mocker):
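[A self-contained sketch of the task-graph pattern that Dataset._load builds in this
series. The functions below are simplified placeholders, not ESMValCore APIs: each
input file gets its own fix_file -> load -> fix_metadata chain, the per-file cubes
are concatenated, and nothing is executed until compute() is called.]

    import dask

    def fix_file(path):
        # Placeholder for on-disk fixes; returns the (possibly rewritten) path.
        return path

    def load(path):
        # Placeholder for reading one file; a file may contain several cubes.
        return [f"cube from {path}"]

    def fix_metadata(cubes):
        # Placeholder; like the real step, it returns a single-cube list.
        return cubes[:1]

    def concatenate(cubes):
        # Placeholder for merging the per-file cubes into one.
        return [" | ".join(cubes)]

    def first(elems):
        return next(iter(elems))

    files = ["a.nc", "b.nc"]
    per_file = []
    for path in files:
        fixed = dask.delayed(fix_file)(path)
        cubes = dask.delayed(load)(fixed)
        # Steps that modify their input are marked as not pure (see patch 5).
        cubes = dask.delayed(fix_metadata, pure=False)(cubes)
        per_file.append(dask.delayed(first)(cubes))

    result = dask.delayed(concatenate, pure=False)(per_file)
    cube = dask.delayed(first)(result)

    # The whole graph runs only now; the per-file chains can run in parallel.
    print(cube.compute())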