Skip to content

Commit 0f6489a

Browse files
authored
Record attributes in provenance via the LocalFile interface instead of directly accessing the file (#2854)
1 parent 4b3e8b0 commit 0f6489a

File tree

8 files changed

+188
-100
lines changed

8 files changed

+188
-100
lines changed

esmvalcore/_provenance.py

Lines changed: 3 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,6 @@
44
import logging
55
import os
66
from functools import total_ordering
7-
from pathlib import Path
87

98
from netCDF4 import Dataset
109
from PIL import Image
@@ -214,11 +213,9 @@ def _initialize_activity(self, activity):
214213
def _initialize_entity(self):
215214
"""Initialize the entity representing the file."""
216215
if self.attributes is None:
217-
self.attributes = {}
218-
if "nc" in Path(self.filename).suffix:
219-
with Dataset(self.filename, "r") as dataset:
220-
for attr in dataset.ncattrs():
221-
self.attributes[attr] = dataset.getncattr(attr)
216+
# This happens for ancestor files of preprocessor files as created
217+
# in esmvalcore.preprocessor.PreprocessorFile.__init__.
218+
self.attributes = copy.deepcopy(self.filename.attributes)
222219

223220
attributes = {
224221
"attribute:" + str(k).replace(" ", "_"): str(v)

esmvalcore/local.py

Lines changed: 40 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@
22

33
from __future__ import annotations
44

5+
import copy
56
import itertools
67
import logging
78
import os
@@ -15,15 +16,18 @@
1516
from cf_units import Unit
1617
from netCDF4 import Dataset, Variable
1718

18-
from .config import CFG
19-
from .config._config import get_project_config
20-
from .exceptions import RecipeError
19+
from esmvalcore.config import CFG
20+
from esmvalcore.config._config import get_project_config
21+
from esmvalcore.exceptions import RecipeError
22+
from esmvalcore.preprocessor._io import _load_from_file
2123

2224
if TYPE_CHECKING:
2325
from collections.abc import Iterable
2426

25-
from .esgf import ESGFFile
26-
from .typing import Facets, FacetValue
27+
import iris.cube
28+
29+
from esmvalcore.esgf import ESGFFile
30+
from esmvalcore.typing import Facets, FacetValue
2731

2832
logger = logging.getLogger(__name__)
2933

@@ -854,3 +858,34 @@ def facets(self) -> Facets:
854858
@facets.setter
855859
def facets(self, value: Facets) -> None:
856860
self._facets = value
861+
862+
@property
863+
def attributes(self) -> dict[str, Any]:
864+
"""Attributes read from the file."""
865+
if not hasattr(self, "_attributes"):
866+
msg = (
867+
"Attributes have not been read yet. Call the `to_iris` method "
868+
"first to read the attributes from the file."
869+
)
870+
raise ValueError(msg)
871+
return self._attributes
872+
873+
@attributes.setter
874+
def attributes(self, value: dict[str, Any]) -> None:
875+
self._attributes = value
876+
877+
def to_iris(
878+
self,
879+
ignore_warnings: list[dict[str, Any]] | None = None,
880+
) -> iris.cube.CubeList:
881+
"""Load the data as Iris cubes.
882+
883+
Returns
884+
-------
885+
iris.cube.CubeList
886+
The loaded data.
887+
"""
888+
cubes = _load_from_file(self, ignore_warnings=ignore_warnings)
889+
# Cache the attributes.
890+
self.attributes = copy.deepcopy(dict(cubes[0].attributes.globals))
891+
return cubes

esmvalcore/preprocessor/__init__.py

Lines changed: 18 additions & 44 deletions
Original file line numberDiff line numberDiff line change
@@ -104,6 +104,7 @@
104104
if TYPE_CHECKING:
105105
from collections.abc import Callable, Iterable, Sequence
106106

107+
import prov.model
107108
from dask.delayed import Delayed
108109

109110
from esmvalcore.dataset import Dataset, File
@@ -528,8 +529,9 @@ def __init__(
528529
input_files.extend(supplementary.files)
529530
ancestors = [TrackedFile(f) for f in input_files]
530531
else:
531-
# Multimodel preprocessor functions set ancestors at runtime
532-
# instead of here.
532+
# Multimodel preprocessor functions set ancestors at runtime,
533+
# in `esmvalcore.preprocessor.multi_model_statistics` and
534+
# `esmvalcore.preprocessor.ensemble_statistics` instead of here.
533535
input_files = []
534536
ancestors = []
535537

@@ -556,6 +558,8 @@ def __init__(
556558
ancestors=ancestors,
557559
)
558560

561+
self.activity = None
562+
559563
def check(self) -> None:
560564
"""Check preprocessor settings."""
561565
check_preprocessor_settings(self.settings)
@@ -579,6 +583,10 @@ def cubes(self) -> list[Cube]:
579583
"""Cubes."""
580584
if self._cubes is None:
581585
self._cubes = [ds.load() for ds in self.datasets] # type: ignore
586+
# Initialize provenance after loading the data, so that we can reuse
587+
# the global attributes that have been read from the input files.
588+
self.initialize_provenance(self.activity)
589+
582590
return self._cubes
583591

584592
@cubes.setter
@@ -669,6 +677,7 @@ def group(self, keys: list) -> str:
669677
def _apply_multimodel(
670678
products: set[PreprocessorFile],
671679
step: str,
680+
activity: prov.model.ProvActivity,
672681
debug: bool | None,
673682
) -> set[PreprocessorFile]:
674683
"""Apply multi model step to products."""
@@ -679,6 +688,10 @@ def _apply_multimodel(
679688
step,
680689
"\n".join(str(p) for p in products - exclude),
681690
)
691+
for output_product_group in settings.get("output_products", {}).values():
692+
for output_product in output_product_group.values():
693+
output_product.initialize_provenance(activity)
694+
682695
result: list[PreprocessorFile] = preprocess( # type: ignore
683696
products - exclude, # type: ignore
684697
step,
@@ -714,50 +727,10 @@ def __init__(
714727
self.debug = debug
715728
self.write_ncl_interface = write_ncl_interface
716729

717-
def _initialize_product_provenance(self) -> None:
718-
"""Initialize product provenance."""
719-
self._initialize_products(self.products)
720-
self._initialize_multimodel_provenance()
721-
self._initialize_ensemble_provenance()
722-
723-
def _initialize_multiproduct_provenance(self, step: str) -> None:
724-
input_products = self._get_input_products(step)
725-
if input_products:
726-
statistic_products = set()
727-
728-
for input_product in input_products:
729-
step_settings = input_product.settings[step]
730-
output_products = step_settings.get("output_products", {})
731-
732-
for product in output_products.values():
733-
statistic_products.update(product.values())
734-
735-
self._initialize_products(statistic_products)
736-
737-
def _initialize_multimodel_provenance(self) -> None:
738-
"""Initialize provenance for multi-model statistics."""
739-
step = "multi_model_statistics"
740-
self._initialize_multiproduct_provenance(step)
741-
742-
def _initialize_ensemble_provenance(self) -> None:
743-
"""Initialize provenance for ensemble statistics."""
744-
step = "ensemble_statistics"
745-
self._initialize_multiproduct_provenance(step)
746-
747-
def _get_input_products(self, step: str) -> list[PreprocessorFile]:
748-
"""Get input products."""
749-
return [
750-
product for product in self.products if step in product.settings
751-
]
752-
753-
def _initialize_products(self, products: set[PreprocessorFile]) -> None:
754-
"""Initialize products."""
755-
for product in products:
756-
product.initialize_provenance(self.activity)
757-
758730
def _run(self, _) -> list[str]: # noqa: C901,PLR0912
759731
"""Run the preprocessor."""
760-
self._initialize_product_provenance()
732+
for product in self.products:
733+
product.activity = self.activity
761734

762735
steps = {
763736
step for product in self.products for step in product.settings
@@ -773,6 +746,7 @@ def _run(self, _) -> list[str]: # noqa: C901,PLR0912
773746
self.products = _apply_multimodel(
774747
self.products,
775748
step,
749+
self.activity,
776750
self.debug,
777751
)
778752
else:

esmvalcore/preprocessor/_io.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -113,7 +113,9 @@ def load(
113113
Invalid type for ``file``.
114114
115115
"""
116-
if isinstance(file, (str, Path)):
116+
if hasattr(file, "to_iris"):
117+
cubes = file.to_iris(ignore_warnings=ignore_warnings)
118+
elif isinstance(file, (str, Path)):
117119
extension = (
118120
file.suffix
119121
if isinstance(file, Path)

tests/integration/preprocessor/test_preprocessing_task.py

Lines changed: 35 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -1,12 +1,15 @@
11
"""Tests for `esmvalcore.preprocessor.PreprocessingTask`."""
22

3+
from pathlib import Path
4+
35
import iris
46
import iris.cube
57
import pytest
68
from prov.model import ProvDocument
79

810
import esmvalcore.preprocessor
911
from esmvalcore.dataset import Dataset
12+
from esmvalcore.local import LocalFile
1013
from esmvalcore.preprocessor import PreprocessingTask, PreprocessorFile
1114

1215

@@ -15,11 +18,11 @@ def test_load_save_task(tmp_path, mocker, scheduler_lock):
1518
"""Test that a task that just loads and saves a file."""
1619
# Prepare a test dataset
1720
cube = iris.cube.Cube(data=[273.0], var_name="tas", units="K")
18-
in_file = tmp_path / "tas_in.nc"
21+
in_file = LocalFile(tmp_path / "tas_in.nc")
1922
iris.save(cube, in_file)
2023
dataset = Dataset(short_name="tas")
2124
dataset.files = [in_file]
22-
dataset.load = lambda: cube.copy()
25+
dataset.load = lambda: in_file.to_iris()[0]
2326

2427
# Create task
2528
task = PreprocessingTask(
@@ -62,33 +65,39 @@ def test_load_save_and_other_task(tmp_path, monkeypatch):
6265
# Prepare test datasets
6366
in_cube = iris.cube.Cube(data=[0.0], var_name="tas", units="degrees_C")
6467
(tmp_path / "climate_data").mkdir()
65-
file1 = tmp_path / "climate_data" / "tas_dataset1.nc"
66-
file2 = tmp_path / "climate_data" / "tas_dataset2.nc"
68+
file1 = LocalFile(tmp_path / "climate_data" / "tas_dataset1.nc")
69+
file2 = LocalFile(tmp_path / "climate_data" / "tas_dataset2.nc")
6770

6871
# Save cubes for reading global attributes into provenance
6972
iris.save(in_cube, target=file1)
7073
iris.save(in_cube, target=file2)
7174

7275
dataset1 = Dataset(short_name="tas", dataset="dataset1")
7376
dataset1.files = [file1]
74-
dataset1.load = lambda: in_cube.copy()
77+
dataset1.load = lambda: file1.to_iris()[0]
7578

7679
dataset2 = Dataset(short_name="tas", dataset="dataset1")
7780
dataset2.files = [file2]
78-
dataset2.load = lambda: in_cube.copy()
81+
dataset2.load = lambda: file2.to_iris()[0]
7982

8083
# Create some mock preprocessor functions and patch
8184
# `esmvalcore.preprocessor` so it uses them.
8285
def single_preproc_func(cube):
8386
cube.data = cube.core_data() + 1.0
8487
return cube
8588

86-
def multi_preproc_func(products):
89+
def multi_preproc_func(products, output_products):
90+
# Preprocessor function that mimics the behaviour of e.g.
91+
# `esmvalcore.preprocessor.multi_model_statistics`.
8792
for product in products:
8893
cube = product.cubes[0]
8994
cube.data = cube.core_data() + 1.0
9095
product.cubes = [cube]
91-
return products
96+
output_product = output_products[""]["mean"]
97+
output_product.cubes = [
98+
iris.cube.Cube([5.0], var_name="tas", units="degrees_C"),
99+
]
100+
return products | {output_product}
92101

93102
monkeypatch.setattr(
94103
esmvalcore.preprocessor,
@@ -132,7 +141,17 @@ def multi_preproc_func(products):
132141
filename=tmp_path / "tas_dataset2.nc",
133142
settings={
134143
"single_preproc_func": {},
135-
"multi_preproc_func": {},
144+
"multi_preproc_func": {
145+
"output_products": {
146+
"": {
147+
"mean": PreprocessorFile(
148+
filename=tmp_path / "tas_dataset2_mean.nc",
149+
attributes={"dataset": "dataset2_mean"},
150+
settings={},
151+
),
152+
},
153+
},
154+
},
136155
},
137156
datasets=[dataset2],
138157
attributes={"dataset": "dataset2"},
@@ -149,9 +168,9 @@ def multi_preproc_func(products):
149168

150169
task.run()
151170

152-
# Check that two files were saved and the preprocessor functions were
171+
# Check that three files were saved and the preprocessor functions were
153172
# only applied to the second one.
154-
assert len(task.products) == 2
173+
assert len(task.products) == 3
155174
for product in task.products:
156175
print(product.filename)
157176
assert product.filename.exists()
@@ -161,6 +180,11 @@ def multi_preproc_func(products):
161180
assert out_cube.data.tolist() == [0.0]
162181
elif product.attributes["dataset"] == "dataset2":
163182
assert out_cube.data.tolist() == [2.0]
183+
elif product.attributes["dataset"] == "dataset2_mean":
184+
assert out_cube.data.tolist() == [5.0]
164185
else:
165186
msg = "unexpected product"
166187
raise AssertionError(msg)
188+
provenance_file = Path(product.provenance_file)
189+
assert provenance_file.exists()
190+
assert provenance_file.read_text(encoding="utf-8")

tests/integration/recipe/test_recipe.py

Lines changed: 7 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -1411,8 +1411,14 @@ def get_diagnostic_filename(basename, cfg, extension="nc"):
14111411

14121412
def simulate_preprocessor_run(task):
14131413
"""Simulate preprocessor run."""
1414-
task._initialize_product_provenance()
14151414
for product in task.products:
1415+
# Populate the LocalFile.attributes attribute and initialize
1416+
# provenance as done in `PreprocessingTask.cubes`.
1417+
for dataset in product.datasets:
1418+
for file in dataset.files:
1419+
file.to_iris()
1420+
product.initialize_provenance(task.activity)
1421+
14161422
create_test_file(product.filename)
14171423
product.save_provenance()
14181424

@@ -1871,9 +1877,6 @@ def test_ensemble_statistics(tmp_path, patched_datafinder, session):
18711877

18721878
assert len(product_out) == len(datasets) * len(statistics)
18731879

1874-
task._initialize_product_provenance()
1875-
assert next(iter(products)).provenance is not None
1876-
18771880

18781881
def test_multi_model_statistics(tmp_path, patched_datafinder, session):
18791882
statistics = ["mean", "max"]
@@ -1920,9 +1923,6 @@ def test_multi_model_statistics(tmp_path, patched_datafinder, session):
19201923

19211924
assert len(product_out) == len(statistics)
19221925

1923-
task._initialize_product_provenance()
1924-
assert next(iter(products)).provenance is not None
1925-
19261926

19271927
def test_multi_model_statistics_exclude(tmp_path, patched_datafinder, session):
19281928
statistics = ["mean", "max"]
@@ -1976,8 +1976,6 @@ def test_multi_model_statistics_exclude(tmp_path, patched_datafinder, session):
19761976
for id_, _ in product_out:
19771977
assert id_ != "OBS"
19781978
assert id_ == "CMIP5"
1979-
task._initialize_product_provenance()
1980-
assert next(iter(products)).provenance is not None
19811979

19821980

19831981
def test_groupby_combined_statistics(tmp_path, patched_datafinder, session):

0 commit comments

Comments
 (0)