diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml index 09d4918e..853c8f8e 100644 --- a/.github/workflows/ci.yml +++ b/.github/workflows/ci.yml @@ -31,7 +31,7 @@ jobs: python-version: 3.13 - name: "Install dependencies" - run: pip install ".[test]" + run: pip install ".[test,era5,lineage]" - name: Run tests run: pytest --cov=. --cov-report html --cov-report term diff --git a/docs/era5.md b/docs/era5.md new file mode 100644 index 00000000..f4b4b27f --- /dev/null +++ b/docs/era5.md @@ -0,0 +1,410 @@ +# OpenHEXA Toolbox ERA5 + +Download and process ERA5-Land climate reanalysis data from the [Copernicus Climate Data +Store](https://www.google.com/url?sa=t&source=web&rct=j&opi=89978449&url=https://cds.climate.copernicus.eu/&ved=2ahUKEwi0x-Pl4aqQAxUnRKQEHftaGdAQFnoECBEQAQ&usg=AOvVaw1BwvwpB-Kja5hnXP6DTcbl) +(CDS). + +- [Overview](#overview) +- [Installation](#installation) +- [Supported variables](#supported-variables) +- [Usage](#usage) + - [Prepare and retrieve data requests](#prepare-and-retrieve-data-requests) + - [Move GRIB files into a Zarr store](#move-grib-files-into-a-zarr-store) + - [Read climate data from a Zarr store](#read-climate-data-from-a-zarr-store) + - [Aggregate climate data stored in a Zarr store](#aggregate-climate-data-stored-in-a-zarr-store) +- [Calculate derived variables](#calculate-derived-variables) + - [Relative humidity](#relative-humidity) + - [Wind speed](#wind-speed) +- [Tests](#tests) + +## Overview + +The package provides tools to: +- Download ERA5-Land hourly data from ECMWF's Climate Data Store +- Convert GRIB files to analysis-ready Zarr format +- Perform spatial aggregation using geographic boundaries +- Aggregate data temporally across various periods (daily, weekly, monthly, yearly) +- Support DHIS2-compatible weekly periods (standard, Wednesday, Thursday, Saturday, Sunday weeks) + +## Installation + +With pip: + +```bash +pip install openhexa.toolbox[all] +# Or +pip install openhexa.toolbox[era5] +``` + +With uv: + 
+```bash
+uv add openhexa.toolbox --extra all
+# Or
+uv add openhexa.toolbox --extra era5
+```
+
+## Supported variables
+
+The module supports a subset of ERA5-Land variables commonly used in health applications:
+
+- 10m u-component of wind (`u10`)
+- 10m v-component of wind (`v10`)
+- 2m dewpoint temperature (`d2m`)
+- 2m temperature (`t2m`)
+- Runoff (`ro`)
+- Soil temperature level 1 (`stl1`)
+- Volumetric soil water layer 1 (`swvl1`)
+- Volumetric soil water layer 2 (`swvl2`)
+- Total precipitation (`tp`)
+- Total evaporation (`e`)
+
+When fetching hourly data, we sample instantaneous variables at 4 daily steps: 01:00,
+07:00, 13:00 and 19:00. For accumulated variables (e.g. total precipitation), we only
+retrieve totals at the end of each day.
+
+See [variables.toml](/openhexa/toolbox/era5/data/variables.toml) for more details on
+supported variables.
+
+## Usage
+
+### Prepare and retrieve data requests
+
+Download ERA5-Land data from the CDS API. You'll need to set up your CDS API credentials
+first (see [CDS API setup](https://cds.climate.copernicus.eu/how-to-api)) and accept the
+license of the dataset you want to download.
+ +```python +from datetime import date +from pathlib import Path +from ecmwf.datastores import Client +from openhexa.toolbox.era5.extract import prepare_requests, retrieve_requests +import os + +client = Client(url=os.getenv("CDS_API_URL"), key=os.getenv("CDS_API_KEY")) + +# Prepare the data requests that need to be submitted to the CDS +# If data already exists in the destination zarr store, it will not be requested again +# NB: At this point, no data is moved to the Zarr store - it is used to avoid +# downloading data we already have +requests = prepare_requests( + client=client, + dataset_id="reanalysis-era5-land", + start_date=date(2025, 3, 28), + end_date=date(2025, 4, 5), + variable="2m_temperature", + area=[10, -1, 8, 1], # [north, west, south, east] in degrees + zarr_store=Path("data/2m_temperature.zarr"), +) + +# Submit data requests and retrieve data in GRIB format as they are ready +# Depending on request size and server load, this may take a while +retrieve_requests( + client=client, + dataset_id="reanalysis-era5-land", + requests=requests, + dst_dir=Path("data/raw"), + wait=30, # Check every 30 seconds for completed requests +) +``` + +### Move GRIB files into a Zarr store + +Convert downloaded GRIB files into an analysis-ready Zarr store for efficient access. + +```python +from pathlib import Path +from openhexa.toolbox.era5.extract import grib_to_zarr + +grib_to_zarr( + src_dir=Path("data/raw"), + zarr_store=Path("data/2m_temperature.zarr"), + data_var="t2m", # Short name for 2m temperature +) +``` + +### Read climate data from a Zarr store + +Data is stored in [Zarr](https://zarr.dev/) stores for efficient storage and access of +climate variables as N-dimensional arrays. You can read data in Zarr stores using +[xarray](https://xarray.dev/). + +When opening a Zarr store, no data is loaded into memory yet. You can check the dataset +structure without loading the data. 
+ +```python +import xarray as xr + +ds = xr.open_zarr("data/2m_temperature.zarr", consolidated=True) +print(ds) +``` +``` + Size: 7MB +Dimensions: (latitude: 71, longitude: 91, time: 284) +Coordinates: + * latitude (latitude) float64 568B 16.0 15.9 15.8 15.7 ... 9.3 9.2 9.1 9.0 + * longitude (longitude) float64 728B -6.0 -5.9 -5.8 -5.7 ... 2.7 2.8 2.9 3.0 + * time (time) datetime64[ns] 2kB 2024-10-01T01:00:00 ... 2024-12-10T1... +Data variables: + t2m (latitude, longitude, time) float32 7MB ... +Attributes: + Conventions: CF-1.7 + GRIB_centre: ecmf + GRIB_centreDescription: European Centre for Medium-Range Weather Forecasts + GRIB_edition: 1 + GRIB_subCentre: 0 + history: 2025-10-14T09:02 GRIB to CDM+CF via cfgrib-0.9.1... + institution: European Centre for Medium-Range Weather Forecasts +``` + +You can use real dates and coordinates to index the data. + +```python +import xarray as xr + +t2m = xr.open_zarr("data/2m_temperature.zarr", consolidated=True) +t2m_daily_mean = t2m.resample(time="1D").mean() +t2m_daily_mean.mean(dim=["latitude", "longitude"]).t2m.plot.line() +``` + +![ERA5 2m Temperature Daily Mean](/docs/images/era5_t2m_lineplot.png) + +### Aggregate climate data stored in a Zarr store + +Aggregate hourly climate data by administrative boundaries and time periods. + +```python +from pathlib import Path +import geopandas as gpd +import xarray as xr +from openhexa.toolbox.era5.transform import ( + create_masks, + aggregate_in_space, + aggregate_in_time, + Period, +) + +t2m = xr.open_zarr("./2m_temperature.zarr", consolidated=True, decode_timedelta=False) +``` + +For instantaneous variables (e.g. 2m temperature, soil moisture...), hourly data should +be aggregated to daily 1st. In ERA5-Land data, data is structured along 2 temporal +dimensions: `time` and `step`. 
To aggregate hourly data to daily, you need to average over +the `step` dimension: + +```python +t2m_daily = t2m.mean(dim="step") + +# or to compute daily extremes +t2m_daily_max = t2m.max(dim="step") +t2m_daily_min = t2m.min(dim="step") +``` + +```python +import matplotlib.pyplot as plt + +plt.imshow( + t2m_daily.sel(time="2024-10-04").t2m, + cmap="coolwarm", +) +plt.colorbar(label="Temperature (°C)", shrink=0.8) +plt.axis("off") +``` +![2m temperature raster](/docs/images/era5_t2m_raster.png) + +The module provides helper functions to help you perform spatial aggregation on gridded +ERA5 data. Use the `create_masks()` function to create raster masks from vector +boundaries. Raster masks uses the same grid as the ERA5 dataset. + +```python +import geopandas as gpd +from openhexa.toolbox.era5.transform import create_masks + +# Boundaries geographic file should use EPSG:4326 coordinate reference system (lat/lon) +boundaries = gpd.read_file("boundaries.gpkg") + +masks = create_masks( + gdf=boundaries, + id_column="district_id", # Column in the GeoDataFrame with unique boundary IDs + ds=t2m_daily, +) +``` + +Example of raster mask for 1 vector boundary: + +![Boundary vector](/docs/images/era5_boundary_vector.png) +![Boundary raster mask](/docs/images/era5_boundary_raster.png) + +You can now aggregate daily gridded ERA5 data in space and time: + +```python +from openhexa.toolbox.era5.transform import aggregate_in_space, aggregate_in_time, Period + +# convert from Kelvin to Celsius +t2m_daily = t2m_daily - 273.15 + +t2m_agg = aggregate_in_space( + ds=t2m_daily, + masks=masks, + variable="t2m", + agg="mean", +) +print(t2m_agg) +``` + +``` +shape: (4_970, 3) +┌─────────────┬────────────┬───────────┐ +│ boundary ┆ time ┆ value │ +│ --- ┆ --- ┆ --- │ +│ str ┆ date ┆ f64 │ +╞═════════════╪════════════╪═══════════╡ +│ mPenE8ZIBFC ┆ 2024-10-01 ┆ 26.534632 │ +│ mPenE8ZIBFC ┆ 2024-10-02 ┆ 25.860088 │ +│ mPenE8ZIBFC ┆ 2024-10-03 ┆ 26.068018 │ +│ mPenE8ZIBFC ┆ 2024-10-04 ┆ 
26.103462 │ +│ mPenE8ZIBFC ┆ 2024-10-05 ┆ 24.362678 │ +│ … ┆ … ┆ … │ +│ eKYyXbBdvmB ┆ 2024-12-06 ┆ 25.130324 │ +│ eKYyXbBdvmB ┆ 2024-12-07 ┆ 24.946449 │ +│ eKYyXbBdvmB ┆ 2024-12-08 ┆ 24.840832 │ +│ eKYyXbBdvmB ┆ 2024-12-09 ┆ 25.242334 │ +│ eKYyXbBdvmB ┆ 2024-12-10 ┆ 26.697817 │ +└─────────────┴────────────┴───────────┘ +``` + +Likewise, to aggregate in time (e.g. weekly averages): + +```python +t2m_weekly = aggregate_in_time( + dataframe=t2m_agg, + period=Period.WEEK, + agg="mean", +) +print(t2m_weekly) +``` + +``` +shape: (770, 3) +┌─────────────┬─────────┬───────────┐ +│ boundary ┆ period ┆ value │ +│ --- ┆ --- ┆ --- │ +│ str ┆ str ┆ f64 │ +╞═════════════╪═════════╪═══════════╡ +│ AKVCJJ2TKSi ┆ 2024W40 ┆ 27.33611 │ +│ AKVCJJ2TKSi ┆ 2024W41 ┆ 27.011093 │ +│ AKVCJJ2TKSi ┆ 2024W42 ┆ 27.905081 │ +│ AKVCJJ2TKSi ┆ 2024W43 ┆ 28.239824 │ +│ AKVCJJ2TKSi ┆ 2024W44 ┆ 27.34595 │ +│ … ┆ … ┆ … │ +│ yhs1ecKsLOc ┆ 2024W46 ┆ 27.711391 │ +│ yhs1ecKsLOc ┆ 2024W47 ┆ 26.394333 │ +│ yhs1ecKsLOc ┆ 2024W48 ┆ 24.863514 │ +│ yhs1ecKsLOc ┆ 2024W49 ┆ 24.714464 │ +│ yhs1ecKsLOc ┆ 2024W50 ┆ 24.923738 │ +└─────────────┴─────────┴───────────┘ +``` + +Or per week starting on Sundays: + +``` python +t2m_sunday_week = aggregate_in_time( + dataframe=t2m_agg, + period=Period.WEEK_SUNDAY, + agg="mean", +) +print(t2m_sunday_week) +``` + +``` +shape: (770, 3) +┌─────────────┬────────────┬───────────┐ +│ boundary ┆ period ┆ value │ +│ --- ┆ --- ┆ --- │ +│ str ┆ str ┆ f64 │ +╞═════════════╪════════════╪═══════════╡ +│ AKVCJJ2TKSi ┆ 2024SunW40 ┆ 27.898345 │ +│ AKVCJJ2TKSi ┆ 2024SunW41 ┆ 26.483939 │ +│ AKVCJJ2TKSi ┆ 2024SunW42 ┆ 27.9347 │ +│ AKVCJJ2TKSi ┆ 2024SunW43 ┆ 28.291441 │ +│ AKVCJJ2TKSi ┆ 2024SunW44 ┆ 27.510819 │ +│ … ┆ … ┆ … │ +│ yhs1ecKsLOc ┆ 2024SunW46 ┆ 27.691862 │ +│ yhs1ecKsLOc ┆ 2024SunW47 ┆ 26.316256 │ +│ yhs1ecKsLOc ┆ 2024SunW48 ┆ 25.249807 │ +│ yhs1ecKsLOc ┆ 2024SunW49 ┆ 24.751227 │ +│ yhs1ecKsLOc ┆ 2024SunW50 ┆ 24.542277 │ +└─────────────┴────────────┴───────────┘ +``` + +Or per month: + 
+``` python +t2m_monthly = aggregate_in_time( + dataframe=t2m_agg, + period=Period.MONTH, + agg="mean", +) +print(t2m_monthly) +``` + +``` +shape: (210, 3) +┌─────────────┬────────┬───────────┐ +│ boundary ┆ period ┆ value │ +│ --- ┆ --- ┆ --- │ +│ str ┆ str ┆ f64 │ +╞═════════════╪════════╪═══════════╡ +│ AKVCJJ2TKSi ┆ 202410 ┆ 27.615368 │ +│ AKVCJJ2TKSi ┆ 202411 ┆ 26.527692 │ +│ AKVCJJ2TKSi ┆ 202412 ┆ 25.080745 │ +│ AVb6wBstPAo ┆ 202410 ┆ 29.747595 │ +│ AVb6wBstPAo ┆ 202411 ┆ 26.137431 │ +│ … ┆ … ┆ … │ +│ vQ6AJUeqBpc ┆ 202411 ┆ 25.915338 │ +│ vQ6AJUeqBpc ┆ 202412 ┆ 23.130632 │ +│ yhs1ecKsLOc ┆ 202410 ┆ 29.050539 │ +│ yhs1ecKsLOc ┆ 202411 ┆ 26.628291 │ +│ yhs1ecKsLOc ┆ 202412 ┆ 24.688542 │ +└─────────────┴────────┴───────────┘ +``` + +Note that the period column uses DHIS2 format (e.g. `2024W40` for week 40 of 2024). + +## Calculate derived variables + +### Relative humidity + +You can compute relative humidity from 2m temperature and 2m dewpoint temperature. + +```python +from openhexa.toolbox.era5.transform import calculate_relative_humidity + +rh = calculate_relative_humidity( + t2m=t2m_daily, + d2m=d2m_daily, +) +``` + +### Wind speed + +You can compute wind speed from the 10m u-component and v-component of wind. + +```python +from openhexa.toolbox.era5.transform import calculate_wind_speed + +ws = calculate_wind_speed( + u10=u10_daily, + v10=v10_daily, +) +``` + +## Tests + +The module uses Pytest. To run tests, install development dependencies and execute +Pytest in the virtual environment. 
+ +```bash +uv sync --dev +uv run pytest tests/era5/* +``` \ No newline at end of file diff --git a/docs/images/era5_boundary_raster.png b/docs/images/era5_boundary_raster.png new file mode 100644 index 00000000..1c22784c Binary files /dev/null and b/docs/images/era5_boundary_raster.png differ diff --git a/docs/images/era5_boundary_vector.png b/docs/images/era5_boundary_vector.png new file mode 100644 index 00000000..7d401620 Binary files /dev/null and b/docs/images/era5_boundary_vector.png differ diff --git a/docs/images/era5_t2m_lineplot.png b/docs/images/era5_t2m_lineplot.png new file mode 100644 index 00000000..3f8fc37b Binary files /dev/null and b/docs/images/era5_t2m_lineplot.png differ diff --git a/docs/images/era5_t2m_raster.png b/docs/images/era5_t2m_raster.png new file mode 100644 index 00000000..cb81daca Binary files /dev/null and b/docs/images/era5_t2m_raster.png differ diff --git a/openhexa/toolbox/era5/README.md b/openhexa/toolbox/era5/README.md index 7e6ec0cd..f4b4b27f 100644 --- a/openhexa/toolbox/era5/README.md +++ b/openhexa/toolbox/era5/README.md @@ -1,182 +1,410 @@ # OpenHEXA Toolbox ERA5 -The package contains ETL classes and functions to acquire and process ERA5-Land data. ERA5-Land -provides hourly information of surface variables from 1950 to 5 days before the current date, with -a ~9 km spatial resolution. See [ERA5-Land: data -documentation](https://confluence.ecmwf.int/display/CKB/ERA5-Land%3A+data+documentation) for more -information. +Download and process ERA5-Land climate reanalysis data from the [Copernicus Climate Data +Store](https://www.google.com/url?sa=t&source=web&rct=j&opi=89978449&url=https://cds.climate.copernicus.eu/&ved=2ahUKEwi0x-Pl4aqQAxUnRKQEHftaGdAQFnoECBEQAQ&usg=AOvVaw1BwvwpB-Kja5hnXP6DTcbl) +(CDS). 
+ +- [Overview](#overview) +- [Installation](#installation) +- [Supported variables](#supported-variables) +- [Usage](#usage) + - [Prepare and retrieve data requests](#prepare-and-retrieve-data-requests) + - [Move GRIB files into a Zarr store](#move-grib-files-into-a-zarr-store) + - [Read climate data from a Zarr store](#read-climate-data-from-a-zarr-store) + - [Aggregate climate data stored in a Zarr store](#aggregate-climate-data-stored-in-a-zarr-store) +- [Calculate derived variables](#calculate-derived-variables) + - [Relative humidity](#relative-humidity) + - [Wind speed](#wind-speed) +- [Tests](#tests) + +## Overview + +The package provides tools to: +- Download ERA5-Land hourly data from ECMWF's Climate Data Store +- Convert GRIB files to analysis-ready Zarr format +- Perform spatial aggregation using geographic boundaries +- Aggregate data temporally across various periods (daily, weekly, monthly, yearly) +- Support DHIS2-compatible weekly periods (standard, Wednesday, Thursday, Saturday, Sunday weeks) + +## Installation + +With pip: + +```bash +pip install openhexa.toolbox[all] +# Or +pip install openhexa.toolbox[era5] +``` -Available variables include: -* 2 metre temperature -* Wind components -* Leaf area index -* Volumetric soil water layer -* Total precipitation +With uv: -See [ERA5-Land data -documentation](https://confluence.ecmwf.int/display/CKB/ERA5-Land%3A+data+documentation#ERA5Land:datadocumentation-parameterlistingParameterlistings) -for a full list of available parameters. +```bash +uv add openhexa.toolbox --extra all +# Or +uv add openhexa.toolbox --extra era5 +``` -In addition to download clients for the Copernicus [Climate Data Store](https://cds.climate.copernicus.eu/datasets/reanalysis-era5-land?tab=overview) and [Google Public Datasets](https://cloud.google.com/storage/docs/public-datasets/era5), the package includes an `aggregate` module to aggregate ERA5 measurements in space (geographic boundaries) and time (hourly to daily). 
+## Supported variables -## Usage +The module supports a subset of ERA5-Land variables commonly used in health: -The package contains 3 modules: -* `openhexa.toolbox.era5.cds`: download ERA5-land products from the Copernicus [Climate Data Store](https://cds.climate.copernicus.eu/datasets/reanalysis-era5-land?tab=overview) -* `openhexa.toolbox.era5.google`: download ERA5 products from Google Cloud [Public Datasets](https://cloud.google.com/storage/docs/public-datasets/era5) -* `openhexa.toolbox.era5.aggregate`: aggregate ERA5 data in space and time +- 10m u-component of wind (`u10`) +- 10m v-component of wind (`v10`) +- 2m dewpoint temperature (`d2m`) +- 2m temperature (`t2m`) +- Runoff (`ro`) +- Soil temperature level 1 (`stl1`) +- Volumetric soil water layer 1 (`swvl1`) +- Volumetric soil water layer 2 (`swvl2`) +- Total precipitation (`tp`) +- Total evaporation (`e`) -### Download from CDS +When fetching hourly data, we sample instantaneous variable at 4 daily steps: 01:00, +07:00, 13:00 and 19:00. For accumulated variables (e.g. total precipitation), we only +retrieve totals at the end of each day. -To download products from the Climate Data Store, you will need to create an account and generate an API key in ECMWF (see [CDS](https://cds.climate.copernicus.eu/)). +See [variables.toml](/openhexa/toolbox/era5/data/variables.toml) for more details on +supported variables. -```python -from openhexa.toolbox.era5.cds import CDS, build_request, bounds_from_file +## Usage -cds = CDS(key="") +### Prepare and retrieve data requests -request = build_request( +Download ERA5-Land data from the CDS API. You'll need to set up your CDS API credentials +first (see [CDS API setup](https://cds.climate.copernicus.eu/how-to-api)) and accept the +license of the dataset you want to download. 
+ +```python +from datetime import date +from pathlib import Path +from ecmwf.datastores import Client +from openhexa.toolbox.era5.extract import prepare_requests, retrieve_requests +import os + +client = Client(url=os.getenv("CDS_API_URL"), key=os.getenv("CDS_API_KEY")) + +# Prepare the data requests that need to be submitted to the CDS +# If data already exists in the destination zarr store, it will not be requested again +# NB: At this point, no data is moved to the Zarr store - it is used to avoid +# downloading data we already have +requests = prepare_requests( + client=client, + dataset_id="reanalysis-era5-land", + start_date=date(2025, 3, 28), + end_date=date(2025, 4, 5), variable="2m_temperature", - year=2024, - month=4, - day=[1, 2, 3], - time=[1, 6, 12, 18] + area=[10, -1, 8, 1], # [north, west, south, east] in degrees + zarr_store=Path("data/2m_temperature.zarr"), ) -cds.retrieve( - request=request, - dst_file="data/t2m.grib" +# Submit data requests and retrieve data in GRIB format as they are ready +# Depending on request size and server load, this may take a while +retrieve_requests( + client=client, + dataset_id="reanalysis-era5-land", + requests=requests, + dst_dir=Path("data/raw"), + wait=30, # Check every 30 seconds for completed requests ) ``` -The module also contains helper functions to use bounds from a geoparquet file as an area of -interest. Source bounds are buffered and rounded by default to make sure the required data is -downloaded. +### Move GRIB files into a Zarr store + +Convert downloaded GRIB files into an analysis-ready Zarr store for efficient access. 
```python -bounds = bounds_from_file(fp=Path("data/districts.parquet"), buffer=0.5) - -request = build_request( - variable="total_precipitation", - year=2023, - month=10, - days=[1, 2, 3, 4, 5], - area=bounds -) +from pathlib import Path +from openhexa.toolbox.era5.extract import grib_to_zarr -cds.retrieve( - request=request, - dst_file="data/product.grib" +grib_to_zarr( + src_dir=Path("data/raw"), + zarr_store=Path("data/2m_temperature.zarr"), + data_var="t2m", # Short name for 2m temperature ) ``` -To download multiple products for a given period, use `Client.download_between()`: +### Read climate data from a Zarr store + +Data is stored in [Zarr](https://zarr.dev/) stores for efficient storage and access of +climate variables as N-dimensional arrays. You can read data in Zarr stores using +[xarray](https://xarray.dev/). + +When opening a Zarr store, no data is loaded into memory yet. You can check the dataset +structure without loading the data. ```python -cds.download_between( - variable="2m_temperature", - start=datetime(2020, 1, 1, tzinfo=timezone.utc), - end=datetime(2021, 6, 1, tzinfo=timezone.utc), - dst_dir="data/raw/2m_temperature", - area=bounds +import xarray as xr + +ds = xr.open_zarr("data/2m_temperature.zarr", consolidated=True) +print(ds) +``` +``` + Size: 7MB +Dimensions: (latitude: 71, longitude: 91, time: 284) +Coordinates: + * latitude (latitude) float64 568B 16.0 15.9 15.8 15.7 ... 9.3 9.2 9.1 9.0 + * longitude (longitude) float64 728B -6.0 -5.9 -5.8 -5.7 ... 2.7 2.8 2.9 3.0 + * time (time) datetime64[ns] 2kB 2024-10-01T01:00:00 ... 2024-12-10T1... +Data variables: + t2m (latitude, longitude, time) float32 7MB ... +Attributes: + Conventions: CF-1.7 + GRIB_centre: ecmf + GRIB_centreDescription: European Centre for Medium-Range Weather Forecasts + GRIB_edition: 1 + GRIB_subCentre: 0 + history: 2025-10-14T09:02 GRIB to CDM+CF via cfgrib-0.9.1... 
+ institution: European Centre for Medium-Range Weather Forecasts +``` + +You can use real dates and coordinates to index the data. + +```python +import xarray as xr + +t2m = xr.open_zarr("data/2m_temperature.zarr", consolidated=True) +t2m_daily_mean = t2m.resample(time="1D").mean() +t2m_daily_mean.mean(dim=["latitude", "longitude"]).t2m.plot.line() +``` + +![ERA5 2m Temperature Daily Mean](/docs/images/era5_t2m_lineplot.png) + +### Aggregate climate data stored in a Zarr store + +Aggregate hourly climate data by administrative boundaries and time periods. + +```python +from pathlib import Path +import geopandas as gpd +import xarray as xr +from openhexa.toolbox.era5.transform import ( + create_masks, + aggregate_in_space, + aggregate_in_time, + Period, ) + +t2m = xr.open_zarr("./2m_temperature.zarr", consolidated=True, decode_timedelta=False) ``` -Checking latest available date in the ERA5-Land dataset: +For instantaneous variables (e.g. 2m temperature, soil moisture...), hourly data should +be aggregated to daily 1st. In ERA5-Land data, data is structured along 2 temporal +dimensions: `time` and `step`. To aggregate hourly data to daily, you need to average over +the `step` dimension: ```python -cds = CDS("") +t2m_daily = t2m.mean(dim="step") -cds.latest +# or to compute daily extremes +t2m_daily_max = t2m.max(dim="step") +t2m_daily_min = t2m.min(dim="step") ``` + +```python +import matplotlib.pyplot as plt + +plt.imshow( + t2m_daily.sel(time="2024-10-04").t2m, + cmap="coolwarm", +) +plt.colorbar(label="Temperature (°C)", shrink=0.8) +plt.axis("off") ``` ->>> datetime(2024, 10, 8) +![2m temperature raster](/docs/images/era5_t2m_raster.png) + +The module provides helper functions to help you perform spatial aggregation on gridded +ERA5 data. Use the `create_masks()` function to create raster masks from vector +boundaries. Raster masks uses the same grid as the ERA5 dataset. 
+ +```python +import geopandas as gpd +from openhexa.toolbox.era5.transform import create_masks + +# Boundaries geographic file should use EPSG:4326 coordinate reference system (lat/lon) +boundaries = gpd.read_file("boundaries.gpkg") + +masks = create_masks( + gdf=boundaries, + id_column="district_id", # Column in the GeoDataFrame with unique boundary IDs + ds=t2m_daily, +) ``` -NB: End dates in product requests will be automatically replaced by latest available date if they are greater. +Example of raster mask for 1 vector boundary: -### Download from Google Cloud +![Boundary vector](/docs/images/era5_boundary_vector.png) +![Boundary raster mask](/docs/images/era5_boundary_raster.png) + +You can now aggregate daily gridded ERA5 data in space and time: ```python -from openhexa.toolbox.era5.google import Client +from openhexa.toolbox.era5.transform import aggregate_in_space, aggregate_in_time, Period -google = Client() +# convert from Kelvin to Celsius +t2m_daily = t2m_daily - 273.15 -google.download( - variable="2m_temperature", - date=datetime(2024, 6, 15), - dst_file="data/product.nc" +t2m_agg = aggregate_in_space( + ds=t2m_daily, + masks=masks, + variable="t2m", + agg="mean", ) +print(t2m_agg) ``` -Or to download all products for a given period: +``` +shape: (4_970, 3) +┌─────────────┬────────────┬───────────┐ +│ boundary ┆ time ┆ value │ +│ --- ┆ --- ┆ --- │ +│ str ┆ date ┆ f64 │ +╞═════════════╪════════════╪═══════════╡ +│ mPenE8ZIBFC ┆ 2024-10-01 ┆ 26.534632 │ +│ mPenE8ZIBFC ┆ 2024-10-02 ┆ 25.860088 │ +│ mPenE8ZIBFC ┆ 2024-10-03 ┆ 26.068018 │ +│ mPenE8ZIBFC ┆ 2024-10-04 ┆ 26.103462 │ +│ mPenE8ZIBFC ┆ 2024-10-05 ┆ 24.362678 │ +│ … ┆ … ┆ … │ +│ eKYyXbBdvmB ┆ 2024-12-06 ┆ 25.130324 │ +│ eKYyXbBdvmB ┆ 2024-12-07 ┆ 24.946449 │ +│ eKYyXbBdvmB ┆ 2024-12-08 ┆ 24.840832 │ +│ eKYyXbBdvmB ┆ 2024-12-09 ┆ 25.242334 │ +│ eKYyXbBdvmB ┆ 2024-12-10 ┆ 26.697817 │ +└─────────────┴────────────┴───────────┘ +``` + +Likewise, to aggregate in time (e.g. 
weekly averages): ```python -# if products are already presents in dst_dir, they will be skipped -google.sync( - variable="2m_temperature", - start_date=datetime(2022, 1, 1), - end_date=datetime(2022, 6, 1), - dst_dir="data" +t2m_weekly = aggregate_in_time( + dataframe=t2m_agg, + period=Period.WEEK, + agg="mean", ) +print(t2m_weekly) +``` + +``` +shape: (770, 3) +┌─────────────┬─────────┬───────────┐ +│ boundary ┆ period ┆ value │ +│ --- ┆ --- ┆ --- │ +│ str ┆ str ┆ f64 │ +╞═════════════╪═════════╪═══════════╡ +│ AKVCJJ2TKSi ┆ 2024W40 ┆ 27.33611 │ +│ AKVCJJ2TKSi ┆ 2024W41 ┆ 27.011093 │ +│ AKVCJJ2TKSi ┆ 2024W42 ┆ 27.905081 │ +│ AKVCJJ2TKSi ┆ 2024W43 ┆ 28.239824 │ +│ AKVCJJ2TKSi ┆ 2024W44 ┆ 27.34595 │ +│ … ┆ … ┆ … │ +│ yhs1ecKsLOc ┆ 2024W46 ┆ 27.711391 │ +│ yhs1ecKsLOc ┆ 2024W47 ┆ 26.394333 │ +│ yhs1ecKsLOc ┆ 2024W48 ┆ 24.863514 │ +│ yhs1ecKsLOc ┆ 2024W49 ┆ 24.714464 │ +│ yhs1ecKsLOc ┆ 2024W50 ┆ 24.923738 │ +└─────────────┴─────────┴───────────┘ +``` + +Or per week starting on Sundays: + +``` python +t2m_sunday_week = aggregate_in_time( + dataframe=t2m_agg, + period=Period.WEEK_SUNDAY, + agg="mean", +) +print(t2m_sunday_week) +``` + +``` +shape: (770, 3) +┌─────────────┬────────────┬───────────┐ +│ boundary ┆ period ┆ value │ +│ --- ┆ --- ┆ --- │ +│ str ┆ str ┆ f64 │ +╞═════════════╪════════════╪═══════════╡ +│ AKVCJJ2TKSi ┆ 2024SunW40 ┆ 27.898345 │ +│ AKVCJJ2TKSi ┆ 2024SunW41 ┆ 26.483939 │ +│ AKVCJJ2TKSi ┆ 2024SunW42 ┆ 27.9347 │ +│ AKVCJJ2TKSi ┆ 2024SunW43 ┆ 28.291441 │ +│ AKVCJJ2TKSi ┆ 2024SunW44 ┆ 27.510819 │ +│ … ┆ … ┆ … │ +│ yhs1ecKsLOc ┆ 2024SunW46 ┆ 27.691862 │ +│ yhs1ecKsLOc ┆ 2024SunW47 ┆ 26.316256 │ +│ yhs1ecKsLOc ┆ 2024SunW48 ┆ 25.249807 │ +│ yhs1ecKsLOc ┆ 2024SunW49 ┆ 24.751227 │ +│ yhs1ecKsLOc ┆ 2024SunW50 ┆ 24.542277 │ +└─────────────┴────────────┴───────────┘ ``` -### Aggregation +Or per month: + +``` python +t2m_monthly = aggregate_in_time( + dataframe=t2m_agg, + period=Period.MONTH, + agg="mean", +) +print(t2m_monthly) +``` + +``` +shape: (210, 3) 
+┌─────────────┬────────┬───────────┐ +│ boundary ┆ period ┆ value │ +│ --- ┆ --- ┆ --- │ +│ str ┆ str ┆ f64 │ +╞═════════════╪════════╪═══════════╡ +│ AKVCJJ2TKSi ┆ 202410 ┆ 27.615368 │ +│ AKVCJJ2TKSi ┆ 202411 ┆ 26.527692 │ +│ AKVCJJ2TKSi ┆ 202412 ┆ 25.080745 │ +│ AVb6wBstPAo ┆ 202410 ┆ 29.747595 │ +│ AVb6wBstPAo ┆ 202411 ┆ 26.137431 │ +│ … ┆ … ┆ … │ +│ vQ6AJUeqBpc ┆ 202411 ┆ 25.915338 │ +│ vQ6AJUeqBpc ┆ 202412 ┆ 23.130632 │ +│ yhs1ecKsLOc ┆ 202410 ┆ 29.050539 │ +│ yhs1ecKsLOc ┆ 202411 ┆ 26.628291 │ +│ yhs1ecKsLOc ┆ 202412 ┆ 24.688542 │ +└─────────────┴────────┴───────────┘ +``` + +Note that the period column uses DHIS2 format (e.g. `2024W40` for week 40 of 2024). + +## Calculate derived variables + +### Relative humidity + +You can compute relative humidity from 2m temperature and 2m dewpoint temperature. ```python -from pathlib import Path +from openhexa.toolbox.era5.transform import calculate_relative_humidity -import geopandas as gpd -from openhexa.toolbox.era5.aggregate import build_masks, merge, aggregate, get_transform +rh = calculate_relative_humidity( + t2m=t2m_daily, + d2m=d2m_daily, +) +``` -boundaries = gpd.read_parquet("districts.parquet") -data_dir = Path("data/era5/total_precipitation") +### Wind speed -ds = merge(data_dir) +You can compute wind speed from the 10m u-component and v-component of wind. -ncols = len(ds.longitude) -nrows = len(ds.latitude) -transform = get_transform(ds) -masks = build_masks(boundaries, nrows, ncols, transform) +```python +from openhexa.toolbox.era5.transform import calculate_wind_speed -df = aggregate( - ds=ds, - var="tp", - masks=masks, - boundaries_id=[uid for uid in boundaries["district_id"]] +ws = calculate_wind_speed( + u10=u10_daily, + v10=v10_daily, ) +``` + +## Tests + +The module uses Pytest. To run tests, install development dependencies and execute +Pytest in the virtual environment. 
-print(df) -``` -``` -shape: (18_410, 5) -┌─────────────┬────────────┬───────────┬──────────┬───────────┐ -│ boundary_id ┆ date ┆ mean ┆ min ┆ max │ -│ --- ┆ --- ┆ --- ┆ --- ┆ --- │ -│ str ┆ date ┆ f64 ┆ f64 ┆ f64 │ -╞═════════════╪════════════╪═══════════╪══════════╪═══════════╡ -│ mPenE8ZIBFC ┆ 2024-01-01 ┆ 0.000462 ┆ 0.0 ┆ 0.00086 │ -│ TPgpGxUBU9y ┆ 2024-01-01 ┆ 0.000462 ┆ 0.0 ┆ 0.00086 │ -│ AhST5ZpuCDJ ┆ 2024-01-01 ┆ 0.000462 ┆ 0.0 ┆ 0.00086 │ -│ Lp2BjBVT63s ┆ 2024-01-01 ┆ 0.000462 ┆ 0.0 ┆ 0.00086 │ -│ EdfRX9b9vEb ┆ 2024-01-01 ┆ 0.000462 ┆ 0.0 ┆ 0.00086 │ -│ yhs1ecKsLOc ┆ 2024-01-01 ┆ 0.000462 ┆ 0.0 ┆ 0.00086 │ -│ iHSJypSwlo5 ┆ 2024-01-01 ┆ 0.000462 ┆ 0.0 ┆ 0.00086 │ -│ CTtB0TPRvWc ┆ 2024-01-01 ┆ 0.000462 ┆ 0.0 ┆ 0.00086 │ -│ eVFAuZOzogt ┆ 2024-01-01 ┆ 0.000462 ┆ 0.0 ┆ 0.00086 │ -│ WVEJjdJ2S15 ┆ 2024-01-01 ┆ 0.000462 ┆ 0.0 ┆ 0.00086 │ -│ rbYGKFgupK9 ┆ 2024-01-01 ┆ 0.000462 ┆ 0.0 ┆ 0.00086 │ -│ Nml6rVDElLh ┆ 2024-01-01 ┆ 0.000462 ┆ 0.0 ┆ 0.00086 │ -│ E0hd8TD1M0q ┆ 2024-01-01 ┆ 0.000462 ┆ 0.0 ┆ 0.00086 │ -│ PCg4pLGmKSM ┆ 2024-01-01 ┆ 0.000462 ┆ 0.0 ┆ 0.00086 │ -│ C6EBhE8OnfW ┆ 2024-01-01 ┆ 0.000462 ┆ 0.0 ┆ 0.00086 │ -│ … ┆ … ┆ … ┆ … ┆ … │ -│ CkpfOFkMyrd ┆ 2024-10-07 ┆ 1.883121 ┆ 0.001785 ┆ 2.700447 │ -│ tMXsltjzzmR ┆ 2024-10-07 ┆ 3.579136 ┆ 0.105436 ┆ 4.702504 │ -│ F0ytkh0RExg ┆ 2024-10-07 ┆ 8.415455 ┆ 0.838535 ┆ 17.08884 │ -... 
-│ TTSmaRnHa82 ┆ 2024-10-07 ┆ 1.724243 ┆ 0.007809 ┆ 5.692989 │ -│ jbmw2gdrrTV ┆ 2024-10-07 ┆ 1.176629 ┆ 0.110173 ┆ 1.582995 │ -│ eKYyXbBdvmB ┆ 2024-10-07 ┆ 0.599976 ┆ 0.037771 ┆ 1.189411 │ -└─────────────┴────────────┴───────────┴──────────┴───────────┘ +```bash +uv sync --dev +uv run pytest tests/era5/* ``` \ No newline at end of file diff --git a/openhexa/toolbox/era5/aggregate.py b/openhexa/toolbox/era5/aggregate.py deleted file mode 100644 index ad0075b4..00000000 --- a/openhexa/toolbox/era5/aggregate.py +++ /dev/null @@ -1,370 +0,0 @@ -"""Module for spatial and temporal aggregation of ERA5 data.""" - -from datetime import datetime -from pathlib import Path - -import geopandas as gpd -import numpy as np -import polars as pl -import rasterio -import xarray as xr -from epiweeks import Week -from rasterio.features import rasterize -from rasterio.transform import Affine, from_bounds - - -def clip_dataset(ds: xr.Dataset, xmin: float, ymin: float, xmax: float, ymax: float) -> xr.Dataset: - """Clip input xarray dataset according to the provided bounding box. - - Assumes lat & lon dimensions are named "latitude" and "longitude". Longitude in the - source dataset is expected to be in the range [0, 360], and will be converted to - [-180, 180]. - - Parameters - ---------- - ds : xr.Dataset - Input xarray dataset. - xmin : float - Minimum longitude. - ymin : float - Minimum latitude. - xmax : float - Maximum longitude. - ymax : float - Maximum latitude. - - Returns - ------- - xr.Dataset - Clipped xarray dataset. - """ - ds = ds.assign_coords(longitude=(((ds.longitude + 180) % 360) - 180)).sortby("longitude") - ds = ds.where((ds.longitude >= xmin) & (ds.longitude <= xmax), drop=True) - ds = ds.where((ds.latitude >= ymin) & (ds.latitude <= ymax), drop=True) - return ds - - -def get_transform(ds: xr.Dataset) -> Affine: - """Get rasterio affine transform from xarray dataset. - - Parameters - ---------- - ds : xr.Dataset - Input xarray dataset. 
- - Returns - ------- - Affine - Rasterio affine transform. - """ - transform = from_bounds( - ds.longitude.values.min(), - ds.latitude.values.min(), - ds.longitude.values.max(), - ds.latitude.values.max(), - len(ds.longitude), - len(ds.latitude), - ) - return transform - - -def build_masks( - boundaries: gpd.GeoDataFrame, height: int, width: int, transform: rasterio.Affine -) -> tuple[np.ndarray, rasterio.Affine]: - """Build binary masks for all geometries in a dataframe. - - We build a raster of shape (n_boundaries, n_height, n_width) in order to store one binary mask - per boundary. Boundaries shapes cannot be stored in a single array as we want masks to overlap - if needed. - - Parameters - ---------- - boundaries : gpd.GeoDataFrame - Input GeoDataFrame containing the boundaries. - height : int - Height of the raster (number of pixels) - width : int - Width of the raster (number of pixels) - transform : rasterio.Affine - Raster affine transform - - Returns - ------- - np.ndarray - Binary masks as a numpy ndarray of shape (n_boundaries, height, width) - """ - masks = np.ndarray(shape=(len(boundaries), height, width), dtype=np.bool_) - for i, geom in enumerate(boundaries.geometry): - mask = rasterize( - shapes=[geom.__geo_interface__], - out_shape=(height, width), - fill=0, - default_value=1, - all_touched=True, - transform=transform, - ) - masks[i, :, :] = mask == 1 - return masks - - -def merge(data_dir: Path | str) -> xr.Dataset: - """Merge all .grib files in a directory into a single xarray dataset. - - If multiple values are available for a given time, step, longitude & latitude dimensions, the - maximum value is kept. - - Parameters - ---------- - data_dir : Path | str - Directory containing the .grib files. - - Returns - ------- - xr.Dataset - Merged xarray dataset with time, step, longitude and latitude dimensions. 
- """ - if isinstance(data_dir, str): - data_dir = Path(data_dir) - - files = data_dir.glob("*.grib") - ds = xr.open_dataset(next(files), engine="cfgrib", decode_timedelta=True) - if "time" not in ds.dims and "time" in ds.coords: - # xarray drop the time dimension if it has only one value - ds = ds.expand_dims("time") - - for f in files: - ds2 = xr.open_dataset(f, engine="cfgrib", decode_timedelta=True) - if "time" not in ds2.dims and "time" in ds2.coords: - ds2 = ds2.expand_dims("time") - ds = xr.concat([ds, ds2], dim="tmp_dim").max(dim="tmp_dim") - - return ds - - -def _np_to_datetime(dt64: np.datetime64) -> datetime: - epoch = np.datetime64(0, "s") - one_second = np.timedelta64(1, "s") - seconds_since_epoch = (dt64 - epoch) / one_second - return datetime.fromtimestamp(seconds_since_epoch) - - -def _has_missing_data(da: xr.DataArray) -> bool: - """A DataArray is considered to have missing data if not all hours have measurements.""" - missing = False - - # if da.step.size == 1, da.step is just an int so we cannot iterate over it - # if da.step size > 1, da.step is an array of int (one per step) - if da.step.size > 1: - for step in da.step: - if da.sel(step=step).isnull().all(): - missing = True - else: - missing = da.isnull().all() - - return missing - - -def _week(date: datetime) -> str: - year = date.isocalendar()[0] - week = date.isocalendar()[1] - return f"{year}W{week}" - - -def _epi_week(date: datetime) -> str: - epiweek = Week.fromdate(date) - year = epiweek.year - week = epiweek.week - return f"{year}W{week}" - - -def _month(date: datetime) -> str: - return date.strftime("%Y%m") - - -def aggregate(ds: xr.Dataset, var: str, masks: np.ndarray, boundaries_id: list[str]) -> pl.DataFrame: - """Aggregate hourly measurements in space and time. 
- - Parameters - ---------- - ds : xr.Dataset - Input xarray dataset with time, step, longitude and latitude dimensions - var : str - Variable to aggregate (ex: "t2m" or "tp") - masks : np.ndarray - Binary masks as a numpy ndarray of shape (n_boundaries, height, width) - boundaries_id : list[str] - List of boundary IDs (same order as n_boundaries dimension in masks) - - Notes - ----- - The function aggregates hourly measurements to daily values for each boundary. - - Temporal aggregation is applied first. 3 statistics are computed for each day: daily mean, - daily min, and daily max. - - Spatial aggregation is then applied. For each boundary, 3 statistics are computed: average of - daily means, average of daily min, and average of daily max. These 3 statistics are stored in - the "mean", "min", and "max" columns of the output dataframe. - """ - rows = [] - - for day in ds.time.values: - da = ds[var].sel(time=day) - - if _has_missing_data(da): - continue - - # if there is a step dimension (= hourly measurements), aggregate to daily - # if not, data is already daily - if "step" in da.dims: - da_mean = da.mean(dim="step").values - da_min = da.min(dim="step").values - da_max = da.max(dim="step").values - else: - da_mean = da.values - da_min = da.values - da_max = da.values - - for i, uid in enumerate(boundaries_id): - v_mean = np.nanmean(da_mean[masks[i, :, :]]) - v_min = np.nanmin(da_min[masks[i, :, :]]) - v_max = np.nanmax(da_max[masks[i, :, :]]) - - rows.append( - { - "boundary_id": uid, - "date": _np_to_datetime(day).date(), - "mean": v_mean, - "min": v_min, - "max": v_max, - } - ) - - SCHEMA = { - "boundary_id": pl.String, - "date": pl.Date, - "mean": pl.Float64, - "min": pl.Float64, - "max": pl.Float64, - } - - df = pl.DataFrame(data=rows, schema=SCHEMA) - - # add week, month, and epi_week period columns - df = df.with_columns( - pl.col("date").map_elements(_week, str).alias("week"), - pl.col("date").map_elements(_month, str).alias("month"), - 
pl.col("date").map_elements(_epi_week, str).alias("epi_week"), - ) - - return df - - -def aggregate_per_week( - daily: pl.DataFrame, - column_uid: str, - use_epidemiological_weeks: bool = False, - sum_aggregation: bool = False, -) -> pl.DataFrame: - """Aggregate daily data per week. - - Parameters - ---------- - daily : pl.DataFrame - Daily data with a "week" or "epi_week", "mean", "min", and "max" columns - Length of the dataframe should be (n_boundaries * n_days). - column_uid : str - Column containing the boundary ID. - use_epidemiological_weeks : bool, optional - Use epidemiological weeks instead of iso weeks. - sum_aggregation : bool, optional - If True, sum values instead of computing the mean, for example for total precipitation data. - - Returns - ------- - pl.DataFrame - Weekly aggregated data of length (n_boundaries * n_weeks). - """ - if use_epidemiological_weeks: - week_column = "epi_week" - else: - week_column = "week" - - df = daily.select([column_uid, pl.col(week_column).alias("week"), "mean", "min", "max"]) - - if sum_aggregation: - df = df.group_by([column_uid, "week"]).agg( - [ - pl.col("mean").sum().alias("mean"), - pl.col("min").sum().alias("min"), - pl.col("max").sum().alias("max"), - ] - ) - else: - df = df.group_by([column_uid, "week"]).agg( - [ - pl.col("mean").mean().alias("mean"), - pl.col("min").min().alias("min"), - pl.col("max").max().alias("max"), - ] - ) - - # sort per date since dhis2 period format is "2012W9", we need to extract year and week number - # from the period string and cast them to int before sorting, else "2012W9" will be superior to - # "2012W32" - df = df.sort( - by=[ - pl.col("week").str.split("W").list.get(0).cast(int), - pl.col("week").str.split("W").list.get(1).cast(int), - pl.col(column_uid), - ] - ) - - return df - - -def aggregate_per_month(daily: pl.DataFrame, column_uid: str, sum_aggregation: bool = False) -> pl.DataFrame: - """Aggregate daily data per month. 
- - Parameters - ---------- - daily : pl.DataFrame - Daily data with a "month", "mean", "min", and "max" columns - Length of the dataframe should be (n_boundaries * n_days). - column_uid : str - Column containing the boundary ID. - sum_aggregation : bool, optional - If True, sum values instead of computing the mean, for example for total precipitation data. - - Returns - ------- - pl.DataFrame - Monthly aggregated data of length (n_boundaries * n_months). - """ - df = daily.select([column_uid, "month", "mean", "min", "max"]) - - if sum_aggregation: - df = df.group_by([column_uid, "month"]).agg( - [ - pl.col("mean").sum().alias("mean"), - pl.col("min").sum().alias("min"), - pl.col("max").sum().alias("max"), - ] - ) - else: - df = df.group_by([column_uid, "month"]).agg( - [ - pl.col("mean").mean().alias("mean"), - pl.col("min").min().alias("min"), - pl.col("max").max().alias("max"), - ] - ) - - df = df.sort( - by=[ - pl.col("month").cast(int), - pl.col(column_uid), - ] - ) - - return df diff --git a/openhexa/toolbox/era5/cache.py b/openhexa/toolbox/era5/cache.py new file mode 100644 index 00000000..796e09ab --- /dev/null +++ b/openhexa/toolbox/era5/cache.py @@ -0,0 +1,195 @@ +import gzip +import hashlib +import json +import shutil +from dataclasses import dataclass +from pathlib import Path + +import psycopg +from psycopg.rows import class_row + +from openhexa.toolbox.era5.models import Request + + +@dataclass +class CacheEntry: + job_id: str + file_name: str | None + + +def _hash_data_request(request: Request) -> str: + """Convert data request dict into MD5 hash.""" + json_str = json.dumps(request, sort_keys=True) + return hashlib.md5(json_str.encode()).hexdigest() + + +class Cache: + """Cache data requests using PostgreSQL.""" + + def __init__(self, database_uri: str, cache_dir: Path): + """Initialize cache in database. + + Args: + database_uri: URI of the PostgreSQL database, e.g. + "postgresql://user:password@host:port/dbname". 
+ cache_dir: Directory to store downloaded GRIB files. + """ + self.database_uri = database_uri + self.cache_dir = cache_dir + self._init_db() + self._init_cache_dir() + + def _init_db(self) -> None: + """Create schema and table if they do not exist.""" + with psycopg.connect(self.database_uri) as conn: + with conn.cursor() as cur: + cur.execute("create schema if not exists era5") + cur.execute( + """ + create table if not exists era5.cds_cache ( + cache_key varchar(32) primary key, + request json not null, + job_id varchar(64) not null, + file_name text, + created_at timestamp not null default now(), + updated_at timestamp not null default now(), + expires_at timestamp + ) + """ + ) + cur.execute( + """ + create index if not exists idx_cds_cache_expires + on era5.cds_cache(expires_at) + """ + ) + conn.commit() + + def _init_cache_dir(self) -> None: + """Create cache directory if it does not exist.""" + self.cache_dir.mkdir(parents=True, exist_ok=True) + + def _archive(self, src_fp: Path) -> None: + """Archive a GRIB file using gzip. + + Args: + src_fp: Path to the source GRIB file. + """ + dst_fp = self.cache_dir / f"{src_fp.name}.gz" + with open(src_fp, "rb") as src_f: + with gzip.open(dst_fp, "wb", compresslevel=9) as dst_f: + shutil.copyfileobj(src_f, dst_f) + + def retrieve(self, file_name: str, dst_fp: Path) -> None: + """Retrieve a GRIB file from a gzip archive. + + Args: + file_name: The name of the cached GRIB file. + dst_fp: Path to the destination GRIB file. + """ + src_fp = self.cache_dir / file_name + if not src_fp.exists(): + raise FileNotFoundError(f"Cached file not found: {src_fp}") + with gzip.open(src_fp, "rb") as src_f: + with open(dst_fp, "wb") as dst_f: + shutil.copyfileobj(src_f, dst_f) + + def set(self, request: Request, job_id: str, file_path: Path | None = None) -> None: + """Store a data request in the cache. + + Data request info and metadata are stored in the database. 
If available, the + downloaded GRIB file is archived in the cache directory. + + Args: + request: The data request parameters. + job_id: The ID of the corresponding CDS job. + file_path: Optional path to the downloaded file to be cached. + """ + if file_path: + self._archive(file_path) + file_name = f"{file_path.name}.gz" + else: + file_name = None + + cache_key = _hash_data_request(request) + with psycopg.connect(self.database_uri) as conn: + with conn.cursor() as cur: + cur.execute( + """ + insert into era5.cds_cache ( + cache_key, request, job_id, file_name + ) values (%s, %s, %s, %s) + on conflict (cache_key) do update set + job_id = excluded.job_id, + file_name = excluded.file_name, + updated_at = now() + """, + (cache_key, json.dumps(request), job_id, file_name), + ) + conn.commit() + + def get(self, request: Request) -> CacheEntry | None: + """Retrieve a data request from the cache. + + Args: + request: The data request parameters. + """ + cache_key = _hash_data_request(request) + with psycopg.connect(self.database_uri) as conn: + with conn.cursor(row_factory=class_row(CacheEntry)) as cur: + cur.execute( + """ + select job_id, file_name from era5.cds_cache + where cache_key = %s + """, + (cache_key,), + ) + return cur.fetchone() + + def clean_expired_jobs(self, job_ids: list[str]) -> None: + """Remove cache entries associated with expired jobs. + + NB: The entry is only removed if the associated GRIB file has not been archived + yet. + + Args: + job_ids: The IDs of the expired CDS jobs. 
+ """ + with psycopg.connect(self.database_uri) as conn: + with conn.cursor() as cur: + cur.execute( + """ + delete from era5.cds_cache + where job_id = any(%s) and file_name is null + """, + (job_ids,), + ) + conn.commit() + + def clean_missing_files(self) -> None: + """Remove cache entries with missing archived files.""" + with psycopg.connect(self.database_uri) as conn: + with conn.cursor(row_factory=class_row(CacheEntry)) as cur: + cur.execute( + """ + select job_id, file_name from era5.cds_cache + where file_name is not null + """ + ) + entries = cur.fetchall() + + missing_job_ids: list[str] = [] + for entry in entries: + if entry.file_name and not (self.cache_dir / entry.file_name).exists(): + missing_job_ids.append(entry.job_id) + + if missing_job_ids: + with conn.cursor() as cur: + cur.execute( + """ + delete from era5.cds_cache + where job_id = any(%s) + """, + (missing_job_ids,), + ) + conn.commit() diff --git a/openhexa/toolbox/era5/cds.py b/openhexa/toolbox/era5/cds.py deleted file mode 100644 index 92bbcd4f..00000000 --- a/openhexa/toolbox/era5/cds.py +++ /dev/null @@ -1,463 +0,0 @@ -"""Client to download ERA5-Land data products from the climate data store. - -See . 
-""" - -from __future__ import annotations - -import importlib.resources -import json -import logging -import shutil -import tempfile -import zipfile -from calendar import monthrange -from dataclasses import dataclass -from datetime import datetime, timedelta, timezone -from functools import cached_property -from math import ceil -from pathlib import Path -from time import sleep -from typing import Iterator - -import geopandas as gpd -import xarray as xr -from datapi import ApiClient, Remote -from requests.exceptions import HTTPError - -with importlib.resources.open_text("openhexa.toolbox.era5", "variables.json") as f: - VARIABLES = json.load(f) - -DATASET = "reanalysis-era5-land" - -log = logging.getLogger(__name__) - - -@dataclass -class DataRequest: - """CDS data request as expected by the API.""" - - variable: list[str] - year: str - month: str - day: list[str] - time: list[str] - data_format: str = "grib" - area: list[float] | None = None - - -def bounds_from_file(fp: Path, buffer: float = 0.5) -> list[float]: - """Get bounds from file. - - Parameters - ---------- - fp : Path - File path. - buffer : float, optional - Buffer to add to the bounds (default=0.5). - - Returns - ------- - list[float] - Bounds (north, west, south, east). - """ - boundaries = gpd.read_parquet(fp) - xmin, ymin, xmax, ymax = boundaries.total_bounds - xmin = ceil(xmin - buffer) - ymin = ceil(ymin - buffer) - xmax = ceil(xmax + buffer) - ymax = ceil(ymax + buffer) - return ymax, xmin, ymin, xmax - - -def get_period_chunk(dtimes: list[datetime]) -> dict: - """Get the period chunk for a list of datetimes. - - The period chunk is a dictionary with the "year", "month", "day" and "time" keys as expected by - the CDS API. A period chunk cannot contain more than 1 year and 1 month. However, it can - contain any number of days and times. - - This is the temporal part of a CDS data request. 
- - Parameters - ---------- - dtimes : list[datetime] - A list of datetimes for which we want data - - Returns - ------- - dict - The period chunk, in other words the temporal part of the request payload - - Raises - ------ - ValueError - If the list of datetimes contains more than 1 year or more than 1 month - """ - years = {dtime.year for dtime in dtimes} - if len(years) > 1: - msg = "Cannot create a period chunk for multiple years" - raise ValueError(msg) - months = {dtime.month for dtime in dtimes} - if len(months) > 1: - msg = "Cannot create a period chunk for multiple months" - raise ValueError(msg) - - year = next(iter(years)) - month = next(iter(months)) - days = [] - - for dtime in sorted(dtimes): - if dtime.day not in days: - days.append(dtime.day) - - return { - "year": str(year), - "month": f"{month:02}", - "day": [f"{day:02}" for day in days], - } - - -def iter_chunks(dtimes: list[datetime]) -> Iterator[dict]: - """Get the period chunks for a list of datetimes. - - The period chunks are a list of dictionaries with the "year", "month", "day" and "time" keys as - expected by the CDS API. A period chunk cannot contain more than 1 year and 1 month. However, - it can contain any number of days and times. - - The function tries its best to generate the minimum amount of chunks to minimize the amount of requests. - - Parameters - ---------- - dtimes : list[datetime] - A list of datetimes for which we want data - - Returns - ------- - Iterator[dict] - The period chunks (one per month max) - """ - for year in range(min(dtimes).year, max(dtimes).year + 1): - for month in range(12): - dtimes_month = [dtime for dtime in dtimes if dtime.year == year and dtime.month == month + 1] - if dtimes_month: - yield get_period_chunk(dtimes_month) - - -def list_datetimes_in_dataset(ds: xr.Dataset) -> list[datetime]: - """List datetimes in input dataset for which data is available. 
- - It is assumed that the dataset has a `time` dimension, in addition to `latitude` and `longitude` - dimensions. We consider that a datetime is available in a dataset if non-null data values are - present for more than 1 step. - """ - dtimes = [] - data_vars = list(ds.data_vars) - var = data_vars[0] - - for time in ds.time.values: - dtime = datetime.fromtimestamp(time.astype(int) / 1e9, tz=timezone.utc) - if dtime in dtimes: - continue - non_null = ds.sel(time=time)[var].notnull().sum().values.item() - if non_null > 0: - dtimes.append(dtime) - - return dtimes - - -def list_datetimes_in_dir(data_dir: Path) -> list[datetime]: - """List datetimes in datasets that can be found in an input directory.""" - # make sure all grib files are decompressed and index files are removed - decompress_grib_files(data_dir) - remove_index_files(data_dir) - - dtimes = [] - - for f in data_dir.glob("*.grib"): - ds = xr.open_dataset(f, engine="cfgrib", decode_timedelta=True) - # xarray drop the time dimension if it has only one value, so we expand it - # to make sure structure is consistent with other datasets - if "time" not in ds.dims and "time" in ds.coords: - ds = ds.expand_dims("time") - dtimes += list_datetimes_in_dataset(ds) - - dtimes = sorted(set(dtimes)) - - msg = f"Scanned {data_dir.as_posix()}, found data for {len(dtimes)} dates" - log.info(msg) - - return dtimes - - -def date_range(start: datetime, end: datetime) -> list[datetime]: - """Get a range of dates with a 1-day step.""" - drange = [] - dt = start - while dt <= end: - drange.append(dt) - dt += timedelta(days=1) - return drange - - -def build_request( - variable: str, - year: int, - month: int, - day: list[int] | list[str] | None = None, - time: list[int] | list[str] | None = None, - data_format: str = "grib", - area: list[float] | None = None, -) -> DataRequest: - """Build request payload. - - Parameters - ---------- - variable : str - Climate data store variable name (ex: "2m_temperature"). 
- year : int - Year of interest. - month : int - Month of interest. - day : list[int] | list[str] | None, optional - Days of interest. Defauls to None (all days). - time : list[int] | list[str] | None, optional - Hours of interest (ex: [1, 6, 18]). Defaults to None (all hours). - data_format : str, optional - Output data format ("grib" or "netcdf"). Defaults to "grib". - area : list[float] | None, optional - Area of interest (north, west, south, east). Defaults to None (world). - - Returns - ------- - DataRequest - CDS data equest payload. - - Raises - ------ - ValueError - Request parameters are not valid. - """ - if variable not in VARIABLES: - msg = f"Variable {variable} not supported" - raise ValueError(msg) - - if data_format not in ["grib", "netcdf"]: - msg = f"Data format {data_format} not supported" - raise ValueError(msg) - - # in the CDS data request, area is an array of float or int in the following order: - # [north, west, south, east] - if area: - n, w, s, e = area - msg = "Invalid area of interest" - max_lat = 90 - max_lon = 180 - if ((abs(n) > max_lat) or (abs(s) > max_lat)) or ((abs(w) > max_lon) or (abs(e) > max_lon)): - raise ValueError(msg) - if (n < s) or (e < w): - raise ValueError(msg) - - # in the CDS data request, days must be an array of strings (one string per day) - # ex: ["01", "02", "03"] - if not day: - dmax = monthrange(year, month)[1] - day = list(range(1, dmax + 1)) - - if isinstance(day[0], int): - day = [f"{d:02}" for d in day] - - # in the CDS data request, time must be an array of strings (one string per hour) - # only hours between 00:00 and 23:00 are valid - # ex: ["00:00", "03:00", "06:00"] - if not time: - time = range(24) - - if isinstance(time[0], int): - time = [f"{hour:02}:00" for hour in time] - - return DataRequest( - variable=[variable], - year=str(year), - month=f"{month:02}", - day=day, - time=time, - data_format="grib", - area=list(area) if area else None, - ) - - -class CDS: - """Climate data store API client 
based on datapi.""" - - def __init__(self, key: str, url: str = "https://cds.climate.copernicus.eu/api") -> None: - """Initialize CDS client.""" - self.client = ApiClient(key=key, url=url) - self.client.check_authentication() - msg = f"Sucessfully authenticated to {url}" - log.info(msg) - - @cached_property - def latest(self) -> datetime: - """Get date of latest available product.""" - collection = self.client.get_collection(DATASET) - return collection.end_datetime - - def get_remote_requests(self) -> list[dict]: - """Fetch list of the last 100 data requests in the CDS account.""" - requests = [] - jobs = self.client.get_jobs(limit=100) - for request_id in jobs.request_uids: - try: - remote = self.client.get_remote(request_id) - if remote.status in ["failed", "dismissed", "deleted"]: - continue - requests.append({"request_id": request_id, "request": remote.request}) - except HTTPError: - continue - return requests - - def get_remote_from_request(self, request: DataRequest, existing_requests: list[dict]) -> Remote | None: - """Look for a remote object that matches the provided request payload. - - Parameters - ---------- - request : DataRequest - Data request payload to look for. - existing_requests : list[dict] - List of existing data requests (as returned by self.get_remote_requests()). - - Returns - ------- - Remote | None - Remote object if found, None otherwise. - """ - if not existing_requests: - return None - - for remote_request in existing_requests: - if remote_request["request"] == request.__dict__: - return self.client.get_remote(remote_request["request_id"]) - - return None - - def submit(self, request: DataRequest) -> Remote: - """Submit an async data request to the CDS API. - - If an identical data request has already been submitted, the Remote object corresponding to - the existing data request is returned instead of submitting a new one. 
- """ - return self.client.submit(DATASET, request=request.__dict__) - - def retrieve(self, request: DataRequest, dst_file: Path | str) -> None: - """Submit and download a data request to the CDS API.""" - dst_file = Path(dst_file) - dst_file.parent.mkdir(parents=True, exist_ok=True) - self.client.retrieve(collection_id=DATASET, target=dst_file, request=request.__dict__) - - def download_between( - self, - start: datetime, - end: datetime, - variable: str, - area: list[float], - dst_dir: str | Path, - time: list[int] | None = None, - ) -> None: - """Download all ERA5 data files needed to cover the period. - - Data requests are sent asynchronously (max one per month) to the CDS API and fetched when - they are completed. - - Parameters - ---------- - start : datetime - Start date. - end : datetime - End date. - variable : str - Climate data store variable name (ex: "2m_temperature"). - area : list[float] - Area of interest (north, west, south, east). - dst_dir : str | Path - Output directory. - time : list[int] | None, optional - Hours of interest (ex: [1, 6, 18]). Defaults to None (all hours). 
- """ - dst_dir = Path(dst_dir) - dst_dir.mkdir(parents=True, exist_ok=True) - - if not start.tzinfo: - start = start.astimezone(tz=timezone.utc) - if not end.tzinfo: - end = end.astimezone(tz=timezone.utc) - - if end > self.latest: - end = self.latest - msg = "End date is after latest available product, setting end date to {}".format(end.strftime("%Y-%m-%d")) - log.info(msg) - - # get the list of dates for which we will want to download data, which is the difference - # between the available (already downloaded) and the requested dates - drange = date_range(start, end) - available = [dtime.date() for dtime in list_datetimes_in_dir(dst_dir)] - dates = [d for d in drange if d.date() not in available] - msg = f"Will request data for {len(dates)} dates" - log.info(msg) - - existing_requests = self.get_remote_requests() - remotes: list[Remote] = [] - - for chunk in iter_chunks(dates): - request = build_request(variable=variable, data_format="grib", area=area, time=time, **chunk) - - # has a similar request been submitted recently? if yes, use it instead of submitting - # a new one - remote = self.get_remote_from_request(request, existing_requests) - if remote: - remotes.append(remote) - msg = f"Found existing request for date {request.year}-{request.month}" - log.info(msg) - else: - remote = self.submit(request) - remotes.append(remote) - msg = f"Submitted new data request {remote.request_id} for {request.year}-{request.month}" - - while remotes: - for remote in remotes: - if remote.results_ready: - request = remote.request - fname = f"{request['year']}{request['month']}_{remote.request_id}.grib" - dst_file = Path(dst_dir, fname) - remote.download(dst_file.as_posix()) - msg = f"Downloaded {dst_file.name}" - log.info(msg) - remotes.remove(remote) - remote.delete() - - if remotes: - msg = f"Still {len(remotes)} files to download. Waiting 30s before retrying..." 
- log.info(msg) - sleep(30) - - -def decompress_grib_files(data_dir: Path) -> None: - """Decompress all grib files in a directory.""" - for fp in data_dir.glob("*.grib"): - if zipfile.is_zipfile(fp): - with ( - tempfile.NamedTemporaryFile(suffix=".grib") as tmp, - zipfile.ZipFile(fp, "r") as zip_file, - ): - tmp.write(zip_file.read("data.grib")) - shutil.copy(tmp.name, fp) - msg = f"Decompressed {fp.name}" - log.info(msg) - - -def remove_index_files(data_dir: Path) -> None: - """Remove all GRIB index files in a directory.""" - for fp in data_dir.glob("*.idx"): - fp.unlink() - msg = f"Removed index file {fp.name}" - log.debug(msg) diff --git a/openhexa/toolbox/era5/data/variables.toml b/openhexa/toolbox/era5/data/variables.toml new file mode 100644 index 00000000..865aac43 --- /dev/null +++ b/openhexa/toolbox/era5/data/variables.toml @@ -0,0 +1,79 @@ +# Information about supported ERA5-Land variables. +# +# - name: Name/identifier of the variable, used for requests. +# - short_name: Short name used in data files, used for processing. +# - unit: Scientific unit of the variable. +# - time: A list of hours (HH:MM) to fetch for daily aggregation. +# - Accumulated variables (e.g., precipitation) are fetched at "00:00". +# - Instantaneous variables are sampled at four hours. +# - accumulated: Whether the variable is accumulated (True) or instantaneous (False). 
+ +[10m_u_component_of_wind] +name = "10m_u_component_of_wind" +short_name = "u10" +unit = "m s**-1" +time = ["01:00", "07:00", "13:00", "19:00"] +accumulated = false + +[10m_v_component_of_wind] +name = "10m_v_component_of_wind" +short_name = "v10" +unit = "m s**-1" +time = ["01:00", "07:00", "13:00", "19:00"] +accumulated = false + +[2m_dewpoint_temperature] +name = "2m_dewpoint_temperature" +short_name = "d2m" +unit = "K" +time = ["01:00", "07:00", "13:00", "19:00"] +accumulated = false + +[2m_temperature] +name = "2m_temperature" +short_name = "t2m" +unit = "K" +time = ["01:00", "07:00", "13:00", "19:00"] +accumulated = false + +[runoff] +name = "runoff" +short_name = "ro" +unit = "m" +time = ["00:00"] +accumulated = true + +[soil_temperature_level_1] +name = "soil_temperature_level_1" +short_name = "stl1" +unit = "K" +time = ["01:00", "07:00", "13:00", "19:00"] +accumulated = false + +[volumetric_soil_water_layer_1] +name = "volumetric_soil_water_layer_1" +short_name = "swvl1" +unit = "m**3 m**-3" +time = ["01:00", "07:00", "13:00", "19:00"] +accumulated = false + +[volumetric_soil_water_layer_2] +name = "volumetric_soil_water_layer_2" +short_name = "swvl2" +unit = "m**3 m**-3" +time = ["01:00", "07:00", "13:00", "19:00"] +accumulated = false + +[total_precipitation] +name = "total_precipitation" +short_name = "tp" +unit = "m" +time = ["00:00"] +accumulated = true + +[total_evaporation] +name = "total_evaporation" +short_name = "e" +unit = "m" +time = ["00:00"] +accumulated = true \ No newline at end of file diff --git a/openhexa/toolbox/era5/dhis2weeks.py b/openhexa/toolbox/era5/dhis2weeks.py new file mode 100644 index 00000000..4a2dca2c --- /dev/null +++ b/openhexa/toolbox/era5/dhis2weeks.py @@ -0,0 +1,105 @@ +"""A set of functions to convert dates to all types of DHIS2 periods.""" + +from datetime import date, timedelta +from enum import StrEnum + + +class WeekType(StrEnum): + """DHIS2 weekly period types.""" + + WEEK = "WEEK" + WEEK_WEDNESDAY = 
"WEEK_WEDNESDAY" + WEEK_THURSDAY = "WEEK_THURSDAY" + WEEK_SATURDAY = "WEEK_SATURDAY" + WEEK_SUNDAY = "WEEK_SUNDAY" + + +start_days = { + WeekType.WEEK_WEDNESDAY: 3, + WeekType.WEEK_THURSDAY: 4, + WeekType.WEEK_SATURDAY: 6, + WeekType.WEEK_SUNDAY: 7, +} + + +def get_calendar_week(dt: date, week_type: WeekType) -> tuple[int, int]: + """Get week number and year for a given date and week type. + + Args: + dt: The date to convert. + week_type: The type of week period. One of 'WEEK', 'WEEK_WEDNESDAY', + 'WEEK_THURSDAY', 'WEEK_SATURDAY', 'WEEK_SUNDAY'. + + Returns: + A tuple (year, week number). + + """ + # We can use the ISO calendar for standard Monday weeks + if week_type == WeekType.WEEK: + iso_year, iso_week, _ = dt.isocalendar() + return (iso_year, iso_week) + + # 1st week of the year always contain Jan 4th + week_start = adjust_to_week_start(dt, start_days[week_type]) + jan4 = date(week_start.year, 1, 4) + first_week_start = adjust_to_week_start(jan4, start_days[week_type]) + + # Week start is before the 1st week of the year, so it belongs to the last week of + # the previous year + if week_start < first_week_start: + jan4_prev = date(week_start.year - 1, 1, 4) + first_week_start_prev = adjust_to_week_start(jan4_prev, start_days[week_type]) + weeks_from_start = (week_start - first_week_start_prev).days // 7 + return week_start.year - 1, weeks_from_start + 1 + + # If we are in late December, we might belong to next year's first week + if week_start.month == 12: + week_end = week_start + timedelta(days=6) + if week_end.month == 1: + jan4_next = date(week_start.year + 1, 1, 4) + if week_start <= jan4_next <= week_end: + return week_start.year + 1, 1 + + # Happy path: we are in the current year's weeks + weeks_from_start = (week_start - first_week_start).days // 7 + return week_start.year, weeks_from_start + 1 + + +def to_dhis2_week(dt: date, week_type: WeekType) -> str: + """Convert a date to a DHIS2 period string. + + Args: + dt: The date to convert. 
+ week_type: The type of week period. One of 'WEEK', 'WEEK_WEDNESDAY', + 'WEEK_THURSDAY', 'WEEK_SATURDAY', 'WEEK_SUNDAY'. + + Returns: + The DHIS2 period string. + + """ + year, week = get_calendar_week(dt, week_type) + + prefix = { + WeekType.WEEK: "W", + WeekType.WEEK_WEDNESDAY: "WedW", + WeekType.WEEK_THURSDAY: "ThuW", + WeekType.WEEK_SATURDAY: "SatW", + WeekType.WEEK_SUNDAY: "SunW", + }[week_type] + + return f"{year}{prefix}{week}" + + +def adjust_to_week_start(dt: date, start_day: int) -> date: + """Adjust date to the start of the week. + + Args: + dt: The date to adjust. + start_day: The day of the week the week starts on (1=Monday, 7=Sunday). + + Returns: + The adjusted date. + + """ + days_to_adjust = (dt.weekday() - (start_day - 1)) % 7 + return dt - timedelta(days=days_to_adjust) diff --git a/openhexa/toolbox/era5/extract.py b/openhexa/toolbox/era5/extract.py new file mode 100644 index 00000000..01cb5004 --- /dev/null +++ b/openhexa/toolbox/era5/extract.py @@ -0,0 +1,737 @@ +"""Download ERA5-Land hourly data from the ECMWF Climate Data Store (CDS). + +Provides functions to build requests, submit them to the CDS API, retrieve results, and +move GRIB data to an analysis-ready Zarr store for further processing. 
+""" + +import logging +import shutil +from collections import defaultdict +from dataclasses import dataclass +from datetime import date +from pathlib import Path +from time import sleep +from typing import Literal + +import numpy as np +import numpy.typing as npt +import xarray as xr +import zarr +from dateutil.relativedelta import relativedelta +from ecmwf.datastores import Remote +from ecmwf.datastores.client import Client + +from openhexa.toolbox.era5.cache import Cache +from openhexa.toolbox.era5.models import Job, Request, RequestTemporal +from openhexa.toolbox.era5.utils import get_name, get_variables + +logger = logging.getLogger(__name__) + + +def get_date_range( + start_date: date, + end_date: date, +) -> list[date]: + """Get inclusive date range from start and end dates. + + Returns: + A list of dates from start to end, inclusive. + + """ + if start_date > end_date: + msg = "Start date must be before end date" + logger.error(msg) + raise ValueError(msg) + + date_range: list[date] = [] + current_date = start_date + while current_date <= end_date: + date_range.append(current_date) + current_date += relativedelta(days=1) + return date_range + + +def _bound_date_range( + start_date: date, + end_date: date, + collection_start_date: date, + collection_end_date: date, +) -> tuple[date, date]: + """Bound input date range to the collection's start and end dates. + + Args: + start_date: Requested start date. + end_date: Requested end date. + collection_start_date: Earliest date in the collection. + collection_end_date: Latest date in the collection. + + Returns: + A new date range tuple (start, end) within the collection's date limits. + + """ + start = max(start_date, collection_start_date) + end = min(end_date, collection_end_date) + return start, end + + +def _get_temporal_chunks(dates: list[date]) -> list[RequestTemporal]: + """Get monthly temporal request chunks for the given list of dates. + + Args: + dates: A list of dates to chunk. 
+ + Returns: + A list of RequestTemporal objects, one per month. + + """ + by_month: dict[tuple[int, int], list[int]] = defaultdict(list) + for d in dates: + by_month[(d.year, d.month)].append(d.day) + + chunks: list[RequestTemporal] = [] + for (year, month), days in by_month.items(): + chunks.append( + RequestTemporal( + year=f"{year:04d}", + month=f"{month:02d}", + day=[f"{day:02d}" for day in sorted(set(days))], + ), + ) + return chunks + + +def _build_requests( + dates: list[date], + variable: str, + time: list[str], + area: list[int], + data_format: Literal["grib", "netcdf"] = "grib", + download_format: Literal["unarchived", "zip"] = "unarchived", +) -> list[Request]: + """Build requests for the reanalysis-era5-land dataset. + + Args: + dates: Requested dates. + variable: Requested variable (ex: "2m_temperature"). + time: List of times to request (ex: ["00:00", "01:00", ..., "23:00"]). + area: Geographical area to request (north, west, south, east). + data_format: Data format to request ("grib" or "netcdf"). + download_format: Download format ("unarchived" or "zip"). + + Returns: + A list of Request objects to be submitted to the CDS API. + + """ + requests: list[Request] = [] + temporal_chunks = _get_temporal_chunks(dates) + for chunk in temporal_chunks: + request = Request( + variable=[variable], + year=chunk["year"], + month=chunk["month"], + day=chunk["day"], + time=time, + data_format=data_format, + download_format=download_format, + area=area, + ) + requests.append(request) + return requests + + +def prepare_requests( + client: Client, + dataset_id: str, + start_date: date, + end_date: date, + variable: str, + area: list[int], + zarr_store: Path, +) -> list[Request]: + """Prepare requests for data retrieval from the CDS API. + + This function checks the available dates in the Zarr store and prepares + requests for the missing dates. + + Args: + client: The CDS API client. + dataset_id: ID of the CDS dataset (e.g. "reanalysis-era5-land"). 
+ start_date: Start date for data synchronization. + end_date: End date for data synchronization. + variable: The variable to synchronize (e.g. "2m_temperature"). + area: The geographical area to synchronize (north, west, south, east). + zarr_store: The Zarr store to update or create. + + Returns: + A list of requests to be submitted to the CDS API. + + """ + variables = get_variables() + if variable not in variables: + msg = f"Variable '{variable}' not supported" + raise ValueError(msg) + + dates = get_missing_dates( + client=client, + dataset_id=dataset_id, + start_date=start_date, + end_date=end_date, + zarr_store=zarr_store, + data_var=variables[variable]["short_name"], + ) + + requests = _build_requests( + dates=dates, + variable=variable, + time=variables[variable]["time"], + area=area, + data_format="grib", + download_format="unarchived", + ) + + max_requests = 100 + if len(requests) > max_requests: + msg = f"Too many data requests ({len(requests)}), max is {max_requests}" + logger.error(msg) + raise ValueError(msg) + + return requests + + +def find_jobs(client: Client) -> list[Job]: + """Get the list of current jobs from the CDS API. + + NB: Jobs with expired results are filtered out and we only search for the latest 100 + jobs. + + Args: + client: CDS API client. + + Returns: + A list of submitted jobs. + + """ + r = client.get_jobs(limit=100, sortby="-created", status=["accepted", "running", "successful"]) + return [Job(**job) for job in r.json["jobs"]] + + +def _submit_requests( + client: Client, + collection_id: str, + requests: list[Request], +) -> list[Remote]: + """Submit a list of requests to the CDS API. + + Args: + client: CDS API client. + collection_id: ID of the CDS dataset (e.g. "reanalysis-era5-land"). + requests: List of request parameters. + + Returns: + List of Remote objects representing the submitted requests. 
+ + """ + remotes: list[Remote] = [] + for request in requests: + r = client.submit( + collection_id=collection_id, + request=dict(request), + ) + logger.info("Submitted request %s", r.request_id) + remotes.append(r) + return remotes + + +def _retrieve_remotes(queue: list[Remote], output_dir: Path, cache: Cache | None = None) -> list[Remote]: + """Retrieve the results of the submitted remotes. + + Args: + queue: List of Remote objects to check and download if ready. + output_dir: Directory to save downloaded files. + cache: Cache to use for caching downloaded files (optional). + + Returns: + List of Remote objects that are still pending (not ready). + + """ + output_dir.mkdir(parents=True, exist_ok=True) + pending: list[Remote] = [] + + for remote in queue: + if remote.results_ready: + name = get_name(remote) + fp = output_dir / name + remote.download(target=fp.as_posix()) + if cache: + cache.set(request=Request(**remote.request), job_id=remote.request_id, file_path=fp) + logger.info("Downloaded %s", name) + else: + pending.append(remote) + return pending + + +def retrieve_requests( + client: Client, + dataset_id: str, + requests: list[Request], + dst_dir: Path, + cache: Cache | None = None, + wait: int = 30, +) -> None: + """Submit and retrieve the results of data requests. + + Args: + client: The CDS API client. + dataset_id: The ID of the dataset to retrieve. + requests: The list of requests to retrieve. + dst_dir: The directory containing the source data files. + cache: Cache to use for caching downloaded files (optional). + wait: Seconds to wait between checks for completed requests (default=30). 
+ + """ + logger.debug("Retrieving %s data requests", len(requests)) + + # If using cache, check for already downloaded files and already submitted + # data requests before submitting new requests + if cache: + triage = _triage_requests(client, cache, requests) + for file_name in triage.downloaded: + dst_fp = dst_dir / file_name.replace(".gz", "") + cache.retrieve(file_name, dst_fp) + logger.info("Retrieved file %s from cache", file_name) + remotes = triage.submitted + remotes += _submit_requests( + client=client, + collection_id=dataset_id, + requests=triage.to_submit, + ) + + # If not using cache, submit all requests directly + else: + remotes = _submit_requests( + client=client, + collection_id=dataset_id, + requests=requests, + ) + + while remotes: + remotes = _retrieve_remotes(remotes, dst_dir, cache=cache) + if remotes: + sleep(wait) + + +@dataclass +class TriageResult: + """Result of triaging data requests after checking the cache. + + Attributes: + downloaded: Job IDs of already downloaded requests. + submitted: Remote objects of already submitted requests. + to_submit: Data requests that still need to be submitted. + """ + + downloaded: list[str] + submitted: list[Remote] + to_submit: list[Request] + + +def _triage_requests(client: Client, cache: Cache, requests: list[Request]) -> TriageResult: + """Triage the requests into downloaded, submitted, and to_submit categories. + + Args: + client: The CDS API client. + cache: The cache to use for checking existing downloads. + requests: The list of requests to triage. + + Returns: + A TriageResult object containing the triaged requests. 
+ """ + result = TriageResult( + downloaded=[], + submitted=[], + to_submit=[], + ) + + jobs = find_jobs(client) + cache.clean_expired_jobs(job_ids=[job.jobID for job in jobs if job.expired]) + cache.clean_missing_files() + + for request in requests: + entry = cache.get(request) + if entry and entry.file_name: + result.downloaded.append(entry.file_name) + elif entry: + remote = client.get_remote(entry.job_id) + result.submitted.append(remote) + else: + result.to_submit.append(request) + + logger.debug( + "Triage result: %s downloaded, %s submitted, %s to submit", + len(result.downloaded), + len(result.submitted), + len(result.to_submit), + ) + + return result + + +def _variable_is_in_zarr(zarr_store: Path, data_var: str) -> bool: + """Check if a variable exists in a zarr store. + + Args: + zarr_store: Path to the zarr store. + data_var: Name of the variable to check. + + Returns: + True if the variable exists in the zarr store, False otherwise. + + """ + if not zarr_store.exists(): + raise ValueError(f"Zarr store {zarr_store} does not exist") + ds = xr.open_zarr(zarr_store, consolidated=True, decode_timedelta=False) + return data_var in ds.data_vars + + +def _list_times_in_zarr(store: Path, data_var: str) -> npt.NDArray[np.datetime64]: + """List time dimensions for a specific variable in the zarr store. + + Args: + store: Path to the zarr store. + data_var: Name of the variable to check. + + Returns: + Numpy array of datetime64 values in the time dimension of the specified variable. + + """ + if not store.exists(): + raise ValueError(f"Zarr store {store} does not exist") + ds = xr.open_zarr(store, consolidated=True, decode_timedelta=False) + if data_var not in ds.data_vars: + raise ValueError(f"Variable {data_var} not found in Zarr store {store}") + return ds[data_var].time.values + + +def _clean_dims_and_coords(ds: xr.Dataset) -> xr.Dataset: + """Expand time and step dimensions if needed. 
+ + When data is downloaded for a single day (or a single step per day), + time and step dimensions can be squeezed into a coordinate instead. + In that case, we expand the coordinate into a dimension to ensure + compatibility with the subsequent processes. + + Args: + ds: The xarray dataset to process (loaded from GRIB file) + + Returns: + The xarray dataset with expanded dimensions if needed. + + Raises: + ValueError: If the dataset does not have a time or step dimension. + """ + if "time" in ds.coords and "time" not in ds.dims: + ds = ds.expand_dims("time") + if "step" in ds.coords and "step" not in ds.dims: + ds = ds.expand_dims("step") + if "time" not in ds.dims: + raise ValueError("Dataset does not have a time dimension") + if "step" not in ds.dims: + raise ValueError("Dataset does not have a step dimension") + + # Drop unused dimensions if they exist + ds = ds.drop_vars(["number", "surface"], errors="ignore") + + # Ensure latitude and longitude are rounded to 0.1 degree + ds = ds.assign_coords( + { + "latitude": np.round(ds.latitude.values, 1), + "longitude": np.round(ds.longitude.values, 1), + }, + ) + + return ds + + +def _drop_incomplete_days(ds: xr.Dataset, data_var: str) -> xr.Dataset: + """Drop days with incomplete temporal steps from the dataset. + + A day is incomplete if any temporal step (hour) has completely missing data + across all spatial points. Days with spatial nulls (e.g., water bodies) are + kept as long as each temporal step has at least some valid data. + + Args: + ds: The xarray dataset to process. + data_var: The name of the data variable to check for completeness. + + Returns: + The xarray dataset with incomplete days removed. + """ + has_valid_data = ~ds[data_var].isnull().all(dim=["latitude", "longitude"]) + all_steps_complete = has_valid_data.all(dim="step") + return ds.sel(time=all_steps_complete) + + +def _flatten_time_dimension(ds: xr.Dataset) -> xr.Dataset: + """Flatten the time dimension of the dataset. 
+ + Flatten step dimension into time. Meaning, instead of having time (n=n_days) and + step (n=n_hours) dimensions, we only have one (n=n_days*n_hours). This makes + analysis easier. + + Args: + ds: The xarray dataset to flatten. + + Returns: + The flattened xarray dataset. + + """ + valid_times = ds.valid_time.values.flatten() + ds = ds.stack(new_time=("time", "step")) + ds = ds.reset_index("new_time", drop=True) + ds = ds.assign_coords(new_time=valid_times) + ds = ds.drop_vars(["valid_time"]) + ds = ds.rename({"new_time": "time"}) + + return ds + + +def _prepare_for_zarr(ds: xr.Dataset) -> xr.Dataset: + """Prepare dataset for zarr storage by setting optimal chunks. + + Args: + ds: The xarray Dataset to prepare (after all transformations) + + Returns: + Dataset with optimal chunking for zarr storage. + """ + # Clear previous encoding for all data vars + for var in ds.data_vars: + ds[var].encoding.pop("chunks", None) + + chunks = {"time": 30, "latitude": -1, "longitude": -1} + + return ds.chunk(chunks) + + +def _create_zarr(ds: xr.Dataset, zarr_store: Path) -> None: + """Create a new zarr store from the dataset. + + Args: + ds: The xarray Dataset to store. + zarr_store: Path to the zarr store to create. + + """ + if zarr_store.exists(): + raise ValueError(f"Zarr store {zarr_store} already exists") + ds = _prepare_for_zarr(ds) + ds.to_zarr(zarr_store, mode="w", consolidated=True, zarr_format=2) + logger.debug("Created Zarr store at %s", zarr_store) + + +def _append_zarr(ds: xr.Dataset, zarr_store: Path, data_var: str) -> None: + """Append new data to an existing zarr store. + + The function checks for overlapping time values and only appends new data. + + Args: + ds: The xarray Dataset to append. + zarr_store: Path to the existing zarr store. + data_var: Name of the variable to append. 
+ + """ + existing_ds = xr.open_zarr(zarr_store, consolidated=True, decode_timedelta=False) + + # Validate that lat/lon coordinates match + if not np.array_equal(existing_ds.latitude.values, ds.latitude.values): + msg = f"Latitude coordinates don't match with zarr store {zarr_store.name}" + logger.error(msg) + raise ValueError(msg) + if not np.array_equal(existing_ds.longitude.values, ds.longitude.values): + msg = f"Longitude coordinates don't match with zarr store {zarr_store.name}" + logger.error(msg) + raise ValueError(msg) + + # Check for overlapping times and only append non-overlapping data + if data_var in existing_ds.data_vars: + existing_times = _list_times_in_zarr(zarr_store, data_var) + new_times = ds.time.values + overlap = np.isin(new_times, existing_times) + if overlap.any(): + logger.debug("Time dimension of GRIB file overlaps with existing Zarr store") + ds = ds.isel(time=~overlap) + if len(ds.time) == 0: + logger.debug("No new data to add to Zarr store") + return + + ds = ds.load() + if data_var in existing_ds.data_vars: + ds.to_zarr(zarr_store, mode="a", append_dim="time", zarr_format=2) + else: + ds.to_zarr(zarr_store, mode="a", zarr_format=2) + logger.debug("Added data to Zarr store for variable %s", data_var) + + +def _consolidate_zarr(zarr_store: Path) -> None: + """Consolidate metadata and ensure dimensions are properly sorted. + + The function consolidates the metadata of the zarr store and checks if the time + dimension is sorted. If not, it sorts the time dimension and rewrites the zarr + store. + + Args: + zarr_store: Path to the zarr store to consolidate. 
+ + """ + zarr.consolidate_metadata(zarr_store) + ds = xr.open_zarr(zarr_store, consolidated=True, decode_timedelta=False) + + # Validate input dataset (duplicate time, inconsistent steps per day) + for data_var in ds.data_vars: + times = ds.time.values + if len(times) == 0: + msg = f"Zarr store {zarr_store.name} has not time values for variable {data_var}" + raise RuntimeError(msg) + if len(times) != len(np.unique(times)): + msg = f"Duplicate time values found in Zarr store {zarr_store.name} for variable {data_var}" + raise RuntimeError(msg) + dates = times.astype("datetime64[D]") + _, counts = np.unique(dates, return_counts=True) + if not np.all(counts == counts[0]): + unique_counts = np.unique(counts) + msg = f"Inconsistent steps per day found: {unique_counts}\nExpected all days to have {counts[0]} steps" + raise RuntimeError(msg) + + # Make sure time dimension is sorted + ds_sorted = ds.sortby("time") + if not np.array_equal(ds.time.values, ds_sorted.time.values): + logger.warning("Time dimension is unsorted, rewriting zarr store") + ds_sorted = _prepare_for_zarr(ds_sorted) + _safe_rewrite_zarr(ds_sorted, zarr_store) + + +def _safe_rewrite_zarr(ds: xr.Dataset, zarr_store: Path) -> None: + """Safely rewrite a zarr store with backup.""" + backup = zarr_store.parent / f"{zarr_store.name}.backup" + + for var in ds.data_vars: + ds[var].encoding.pop("chunks", None) + + try: + shutil.move(zarr_store, backup) + ds.to_zarr(zarr_store, mode="w", consolidated=True, zarr_format=2) + shutil.rmtree(backup) + except Exception as e: + if backup.exists(): + if zarr_store.exists(): + shutil.rmtree(zarr_store) + shutil.move(backup, zarr_store) + raise e + + +def _diff_zarr( + start_date: date, + end_date: date, + zarr_store: Path, + data_var: str, +) -> list[date]: + """Get dates between start and end dates that are not in the zarr store. + + Args: + start_date: Start date for data retrieval. + end_date: End date for data retrieval. 
+ zarr_store: The Zarr store to check for existing data. + data_var: Name of the variable to check in the Zarr store. + + Returns: + The list of dates that are not in the Zarr store. + + """ + if not zarr_store.exists(): + return get_date_range(start_date, end_date) + + if not _variable_is_in_zarr(zarr_store, data_var): + return get_date_range(start_date, end_date) + + zarr_dtimes = _list_times_in_zarr(zarr_store, data_var) + zarr_dates = zarr_dtimes.astype("datetime64[D]").astype(date).tolist() + + date_range = get_date_range(start_date, end_date) + return [d for d in date_range if d not in zarr_dates] + + +def get_missing_dates( + client: Client, + dataset_id: str, + start_date: date, + end_date: date, + zarr_store: Path, + data_var: str, +) -> list[date]: + """Get the list of dates between start_date and end_date that are not in the Zarr store. + + Args: + client: The CDS API client. + dataset_id: The ID of the dataset to check. + start_date: Start date for data retrieval. + end_date: End date for data retrieval. + zarr_store: The Zarr store to check for existing data. + data_var: Name of the variable to check in the Zarr store. + + Returns: + A list of dates that are not in the Zarr store. + + """ + collection = client.get_collection(dataset_id) + if not collection.begin_datetime or not collection.end_datetime: + msg = f"Dataset {dataset_id} does not have a defined date range" + raise ValueError(msg) + start_date, end_date = _bound_date_range( + start_date, + end_date, + collection.begin_datetime.date(), + collection.end_datetime.date(), + ) + logger.debug("Checking existing data for variable '%s' from %s to %s", data_var, start_date, end_date) + dates = _diff_zarr(start_date, end_date, zarr_store, data_var) + logger.debug("Missing dates for variable '%s': %s", data_var, dates) + return dates + + +def grib_to_zarr( + src_dir: Path, + zarr_store: Path, + data_var: str, +) -> None: + """Move data in multiple GRIB files to a zarr store. 
+ + The function processes all GRIB files in the source directory and moves the data + to the specified Zarr store (creating or appending as necessary). + + Args: + src_dir: Directory containing the GRIB files. + zarr_store: Path to the zarr store to create or update. + data_var: Short name of the variable to process (e.g. "t2m", "tp", "swvl1"). + + """ + for fp in sorted(src_dir.glob("*.grib")): + logger.info("Processing GRIB file %s", fp.name) + ds = xr.open_dataset(fp, engine="cfgrib", decode_timedelta=False) + ds = _clean_dims_and_coords(ds) + if ds[data_var].isnull().all(): + logger.warning("GRIB file %s is completely empty, skipping", fp.name) + continue + ds = _drop_incomplete_days(ds, data_var=data_var) + if len(ds.time) == 0: + logger.warning("All days dropped from %s after filtering, skipping", fp.name) + continue + ds = _flatten_time_dimension(ds) + + if not zarr_store.exists(): + logger.debug("Creating new Zarr store '%s'", zarr_store.name) + _create_zarr(ds, zarr_store) + else: + logger.debug("Appending data to existing Zarr store '%s'", zarr_store.name) + _append_zarr(ds, zarr_store, data_var) + logger.debug("Consolidating Zarr store '%s'", zarr_store.name) + _consolidate_zarr(zarr_store) + logger.debug("Validating Zarr store '%s'", zarr_store.name) diff --git a/openhexa/toolbox/era5/google.py b/openhexa/toolbox/era5/google.py deleted file mode 100644 index 884dbbab..00000000 --- a/openhexa/toolbox/era5/google.py +++ /dev/null @@ -1,158 +0,0 @@ -"""Download raw historical Era5 products from Google Cloud: -https://console.cloud.google.com/storage/browser/gcp-public-data-arco-era5 - -Products are provided as raw NetCDF files and are usually available with a ~3 month lag. 
-""" - -from __future__ import annotations - -import importlib.resources -import json -import logging -import shutil -import tempfile -from datetime import datetime, timedelta -from functools import cached_property -from pathlib import Path - -import requests -from google.cloud import storage - -with importlib.resources.open_text("openhexa.toolbox.era5", "variables.json") as f: - VARIABLES = json.load(f) - -log = logging.getLogger(__name__) - - -class NotFoundError(Exception): - pass - - -class ParameterError(ValueError): - pass - - -class Client: - def __init__(self): - self.client = storage.Client.create_anonymous_client() - self.bucket = self.client.bucket("gcp-public-data-arco-era5") - - @staticmethod - def prefix(variable: str, date: datetime) -> str: - """Build key prefix for a given product.""" - return f"raw/date-variable-single_level/{date.year}/{date.month:02}/{date.day:02}/{variable}/surface.nc" - - def _subdirs(self, prefix: str) -> list[str]: - """List subdirs.""" - blobs = self.client.list_blobs(self.bucket, prefix=prefix, delimiter="/") - prefixes = [] - for page in blobs.pages: - prefixes += page.prefixes - return prefixes - - @cached_property - def latest(self) -> datetime: - """Get date of latest available product.""" - root = "raw/date-variable-single_level/" - subdirs = self._subdirs(root) # years - subdirs = self._subdirs(max(subdirs)) # months - subdirs = self._subdirs(max(subdirs)) # days - subdir = max(subdirs).split("/") - year = int(subdir[-4]) - month = int(subdir[-3]) - day = int(subdir[-2]) - return datetime(year, month, day) - - def find(self, variable: str, date: datetime) -> str | None: - """Find public URL of product. 
Return None if not found.""" - prefix = self.prefix(variable, date) - blobs = self.client.list_blobs(self.bucket, prefix=prefix, max_results=1) - blobs = list(blobs) - if blobs: - return blobs[0].public_url - else: - return None - - def download(self, variable: str, date: datetime, dst_file: str | Path, overwrite=False): - """Download an Era5 NetCDF product for a given day. - - Parameters - ---------- - variable : str - Climate data store variable name (ex: "2m_temperature"). - date : datetime - Product date (year, month, day). - dst_file : str | Path - Output file. - overwrite : bool, optional - Overwrite existing file (default=False). - - Raises - ------ - ParameterError - Product request parameters are invalid. - NotFoundError - Product not found in bucket. - """ - dst_file = Path(dst_file) - dst_file.parent.mkdir(parents=True, exist_ok=True) - - if dst_file.exists() and not overwrite: - log.debug("Skipping download of %s because file already exists", str(dst_file.absolute())) - return - - if variable not in VARIABLES: - raise ParameterError("%s is not a valid climate data store variable name", variable) - - url = self.find(variable, date) - if not url: - raise NotFoundError("%s product not found for date %s", variable, date.strftime("%Y-%m-%d")) - - with tempfile.NamedTemporaryFile() as tmp: - with open(tmp.name, "wb") as f: - with requests.get(url, stream=True) as r: - for chunk in r.iter_content(chunk_size=1024**2): - if chunk: - f.write(chunk) - - shutil.copy(tmp.name, dst_file) - - log.debug("Downloaded %s", str(dst_file.absolute())) - - def sync(self, variable: str, start_date: datetime, end_date: datetime, dst_dir: str | Path): - """Download all products for a given variable and date range. - - If products are already present in the destination directory, they will be skipped. - Expects file names to be formatted as "YYYY-MM-DD_VARIABLE.nc". - - Parameters - ---------- - variable : str - Climate data store variable name (ex: "2m_temperature"). 
- start_date : datetime - Start date (year, month, day). - end_date : datetime - End date (year, month, day). - dst_dir : str | Path - Output directory. - """ - dst_dir = Path(dst_dir) - dst_dir.mkdir(parents=True, exist_ok=True) - - if start_date > end_date: - raise ParameterError("`start_date` must be before `end_date`") - - date = start_date - if end_date > self.latest: - log.info("Setting `end_date` to the latest available date: %s" % date.strftime("%Y-%m-%d")) - end_date = self.latest - - while date <= end_date: - expected_filename = f"{date.strftime('%Y-%m-%d')}_{variable}.nc" - fpath = Path(dst_dir, expected_filename) - fpath_grib = Path(dst_dir, expected_filename.replace(".nc", ".grib")) - if fpath.exists() or fpath_grib.exists(): - log.debug("%s already exists, skipping download" % expected_filename) - else: - self.download(variable=variable, date=date, dst_file=fpath, overwrite=False) - date += timedelta(days=1) diff --git a/openhexa/toolbox/era5/models.py b/openhexa/toolbox/era5/models.py new file mode 100644 index 00000000..f28d4e70 --- /dev/null +++ b/openhexa/toolbox/era5/models.py @@ -0,0 +1,91 @@ +from dataclasses import dataclass +from typing import Literal, TypedDict + + +class Variable(TypedDict): + """Metadata for a single variable in the ERA5-Land dataset.""" + + name: str + short_name: str + unit: str + time: list[str] + accumulated: bool + + +class Request(TypedDict): + """Request parameters for the 'reanalysis-era5-land' dataset.""" + + variable: list[str] + year: str + month: str + day: list[str] + time: list[str] + data_format: Literal["grib", "netcdf"] + download_format: Literal["unarchived", "zip"] + area: list[int] + + +class RequestTemporal(TypedDict): + """Temporal request parameters.""" + + year: str + month: str + day: list[str] + + +class JobLink(TypedDict): + """A link related to a data request job.""" + + href: str + rel: str + type: str | None + title: str | None + + +class JobMetadataResults(TypedDict): + """Metadata about the 
results of a data request job.""" + + type: str + title: str + status: int + detail: str + trace_id: str + + +class JobMetadata(TypedDict): + """Metadata about a data request job.""" + + results: JobMetadataResults + datasetMetadata: dict[str, str] + qos: dict[str, dict] + origin: str + + +@dataclass +class Job: + """A data request job in the CDS.""" + + processID: str + type: str + jobID: str + status: str + created: str + started: str + finished: str + updated: str + links: list[JobLink] + metadata: JobMetadata + + @property + def expired(self) -> bool: + """Whether the job results have expired. + + Means that a data request has been successfully processed by the server, + but the results expired and cannot be downloaded anymore. This doesn't change + the status, we have to dig into job metadata for this info. + """ + if "results" in self.metadata: + if "type" in self.metadata["results"]: + if self.metadata["results"]["type"] == "results expired": + return True + return False diff --git a/openhexa/toolbox/era5/transform.py b/openhexa/toolbox/era5/transform.py new file mode 100644 index 00000000..23bc3a36 --- /dev/null +++ b/openhexa/toolbox/era5/transform.py @@ -0,0 +1,274 @@ +"""Spatial aggregation of ERA5-Land data.""" + +import logging +from enum import StrEnum +from typing import Literal + +import geopandas as gpd +import numpy as np +import polars as pl +import rasterio.features +import rasterio.transform +import xarray as xr + +from openhexa.toolbox.era5.dhis2weeks import WeekType, to_dhis2_week + +logger = logging.getLogger(__name__) + + +def create_masks(gdf: gpd.GeoDataFrame, id_column: str, ds: xr.Dataset) -> xr.DataArray: + """Create masks for each boundary in the GeoDataFrame. + + Input polygons are rasterized into a grid matching the spatial dimensions of the + dataset. + We use the `all_touched=True` option, so that any pixel touched by a polygon is included in the + mask. This is because we don't want small geometries ending up with zero pixel. 
As a result, + each polygon has its own mask because some pixels may belong to multiple polygons. + + Args: + gdf: A GeoDataFrame containing the boundaries, with a 'geometry' column + id_column: Column in the GeoDataFrame that contains unique identifiers for each + boundary + ds: An xarray Dataset containing the spatial dimensions (latitude and longitude) + + Returns: + An xarray DataArray with dimensions ['boundary', 'latitude', 'longitude'] + containing the masks. Each mask corresponds to a boundary in the GeoDataFrame. + + """ + logger.debug("Creating masks for %s boundaries", len(gdf)) + lat = ds.latitude.values + lon = ds.longitude.values + lat_res = abs(lat[1] - lat[0]) + lon_res = abs(lon[1] - lon[0]) + transform = rasterio.transform.from_bounds( # type: ignore + west=lon.min() - lon_res / 2, + east=lon.max() + lon_res / 2, + north=lat.max() + lat_res / 2, + south=lat.min() - lat_res / 2, + width=len(lon), + height=len(lat), + ) + + masks: list[np.ndarray] = [] + names: list[str] = [] + + for _, row in gdf.iterrows(): + mask = rasterio.features.rasterize( # type: ignore + [row.geometry], + out_shape=(len(lat), len(lon)), + transform=transform, # type: ignore + fill=0, + all_touched=True, + dtype=np.uint8, + ) + masks.append(mask) # type: ignore + names.append(row[id_column]) # type: ignore + + logger.debug("Created masks with shape %s", (len(masks), len(lat), len(lon))) + + return xr.DataArray( + np.stack(masks), + dims=["boundary", "latitude", "longitude"], + coords={ + "boundary": names, + "latitude": lat, + "longitude": lon, + }, + ) + + +def aggregate_in_space( + ds: xr.Dataset, + masks: xr.DataArray, + data_var: str, + agg: Literal["mean", "sum", "min", "max"], +) -> pl.DataFrame: + """Perform spatial aggregation on the dataset using the provided masks. + + Args: + ds: The data containing the variable to aggregate. Dataset is expected to have + 'latitude' and 'longitude' coordinates, and daily data. 
+ masks: An xarray DataArray containing the masks for spatial aggregation, as returned by the + `create_masks()` function. + data_var: Name of the variable to aggregate in input dataset (e.g. "t2m") + agg: Spatial aggregation method (one of "mean", "sum", "min", "max"). + + Returns: + A Polars DataFrame of shape (n_boundaries, n_days) with columns: "boundary", "time", and + "value". + + Raises: + ValueError: If the specified variable is not found in the dataset. + ValueError: If the dataset still contains the 'step' dimension (i.e. data is not daily). + ValueError: If an unsupported aggregation method is specified. + + """ + logger.debug("Aggregating data for variable '%s' using masks", data_var) + if data_var not in ds.data_vars: + msg = f"Variable '{data_var}' not found in dataset" + raise ValueError(msg) + if "step" in ds.dims: + msg = "Dataset still contains 'step' dimension. Please aggregate to daily data first." + raise ValueError(msg) + da = ds[data_var].compute() + area_weights = np.cos(np.deg2rad(ds.latitude)) + results: list[xr.DataArray] = [] + for boundary in masks.boundary: + logger.debug("Aggregating for boundary '%s'", boundary.item()) + mask = masks.sel(boundary=boundary) + if agg == "mean": + weights = area_weights * mask + result = da.weighted(weights).mean(["latitude", "longitude"]) + elif agg == "sum": + result = da.where(mask > 0).sum(["latitude", "longitude"]) + elif agg == "min": + result = da.where(mask > 0).min(["latitude", "longitude"]) + elif agg == "max": + result = da.where(mask > 0).max(["latitude", "longitude"]) + else: + msg = f"Unsupported aggregation method: {agg}" + raise ValueError(msg) + results.append(result) + result = xr.concat(results, dim="boundary").assign_coords(boundary=masks.boundary, time=ds.time) + + n_boundaries = len(result.boundary) + n_times = len(result.time) + + schema = { + "boundary": pl.String, + "time": pl.Date, + "value": pl.Float64, + } + data = { + "boundary": np.repeat(result.boundary.values, n_times), + 
"time": np.tile(result.time.values, n_boundaries), + "value": result.values.flatten(order="C"), + } + return pl.DataFrame(data, schema=schema) + + +class Period(StrEnum): + """Temporal aggregation periods.""" + + DAY = "DAY" + WEEK = "WEEK" + MONTH = "MONTH" + YEAR = "YEAR" + WEEK_WEDNESDAY = "WEEK_WEDNESDAY" + WEEK_THURSDAY = "WEEK_THURSDAY" + WEEK_SATURDAY = "WEEK_SATURDAY" + WEEK_SUNDAY = "WEEK_SUNDAY" + + +def aggregate_in_time( + dataframe: pl.DataFrame, + period: Period, + agg: Literal["mean", "sum", "min", "max"] = "mean", +) -> pl.DataFrame: + """Aggregate the dataframe over the specified temporal period. + + Args: + dataframe: The dataframe to aggregate. + period: The temporal period to aggregate over. + agg: Temporal aggregation method (one of "mean", "sum", "min", "max"). + + Returns: + The aggregated dataframe. + + """ + logger.debug("Aggregating dataframe over period '%s' with method '%s'", period, agg) + # We 1st create a "period" column to be able to group by it + if period == Period.DAY: + df = dataframe.with_columns( + pl.col("time").dt.strftime("%Y%m%d").alias("period"), + ) + elif period == Period.MONTH: + df = dataframe.with_columns( + pl.col("time").dt.strftime("%Y%m").alias("period"), + ) + elif period == Period.YEAR: + df = dataframe.with_columns( + pl.col("time").dt.strftime("%Y").alias("period"), + ) + elif period in ( + Period.WEEK, + Period.WEEK_WEDNESDAY, + Period.WEEK_THURSDAY, + Period.WEEK_SATURDAY, + Period.WEEK_SUNDAY, + ): + df = dataframe.with_columns( + pl.col("time") + .map_elements(lambda dt: to_dhis2_week(dt, WeekType(period)), return_dtype=pl.String) + .alias("period"), + ) + else: + msg = f"Unsupported period: {period}" + raise NotImplementedError(msg) + + if agg == "mean": + df = df.group_by(["boundary", "period"]).agg(pl.col("value").mean().alias("value")) + elif agg == "sum": + df = df.group_by(["boundary", "period"]).agg(pl.col("value").sum().alias("value")) + elif agg == "min": + df = df.group_by(["boundary", 
"period"]).agg(pl.col("value").min().alias("value")) + elif agg == "max": + df = df.group_by(["boundary", "period"]).agg(pl.col("value").max().alias("value")) + else: + msg = f"Unsupported aggregation method: {agg}" + raise ValueError(msg) + + return df.select(["boundary", "period", "value"]).sort(["boundary", "period"]) + + +def calculate_relative_humidity(t2m: xr.DataArray, d2m: xr.DataArray) -> xr.Dataset: + """Calculate relative humidity from 2m temperature and 2m dewpoint temperature. + + Uses Magnus formula to calculate RH from t2m and d2m. + + Args: + t2m: 2m temperature in Kelvin. + d2m: 2m dewpoint temperature in Kelvin. + + Returns: + Relative humidity in percentage. + """ + t2m_c = t2m - 273.15 + d2m_c = d2m - 273.15 + + a = 17.1 # temperature coefficient + b = 235.0 # temperature offset (°C) + base_pressure = 6.1078 + vapor_pressure = base_pressure * np.exp(a * d2m_c / (b + d2m_c)) + sat_vapor_pressure = base_pressure * np.exp(a * t2m_c / (b + t2m_c)) + rh = vapor_pressure / sat_vapor_pressure + rh = rh.clip(0, 1) + rh_da = xr.DataArray( + rh * 100, + dims=t2m.dims, + coords=t2m.coords, + attrs={"units": "%"}, + ) + return xr.Dataset({"rh": rh_da}) + + +def calculate_wind_speed(u10: xr.DataArray, v10: xr.DataArray) -> xr.Dataset: + """Calculate wind speed from u10 and v10 components. + + Args: + u10: U component of wind at 10m in m/s. + v10: V component of wind at 10m in m/s. + + Returns: + Wind speed in m/s. 
+ """ + wind_speed = np.sqrt(u10**2 + v10**2) + wind_speed_da = xr.DataArray( + wind_speed, + dims=u10.dims, + coords=u10.coords, + name="ws", + attrs={"units": "m/s"}, + ) + return xr.Dataset({"ws": wind_speed_da}) diff --git a/openhexa/toolbox/era5/utils.py b/openhexa/toolbox/era5/utils.py new file mode 100644 index 00000000..f5191c0d --- /dev/null +++ b/openhexa/toolbox/era5/utils.py @@ -0,0 +1,33 @@ +import importlib.resources +import tomllib + +from ecmwf.datastores import Remote + +from openhexa.toolbox.era5.models import Variable + + +def get_name(remote: Remote) -> str: + """Create file name from remote request. + + Returns: + File name with format: {year}{month}_{request_id}.{ext} + + """ + request = remote.request + data_format = request["data_format"] + download_format = request["download_format"] + year = request["year"] + month = request["month"] + ext = "zip" if download_format == "zip" else data_format + return f"{year}{month}_{remote.request_id}.{ext}" + + +def get_variables() -> dict[str, Variable]: + """Load ERA5-Land variables metadata. + + Returns: + A dictionary mapping variable names to their metadata. 
+ + """ + with importlib.resources.files("openhexa.toolbox.era5").joinpath("data/variables.toml").open("rb") as f: + return tomllib.load(f) diff --git a/openhexa/toolbox/era5/variables.json b/openhexa/toolbox/era5/variables.json deleted file mode 100644 index e659ef8f..00000000 --- a/openhexa/toolbox/era5/variables.json +++ /dev/null @@ -1,352 +0,0 @@ -{ - "lake_mix_layer_temperature": { - "name": "Lake mix-layer temperature", - "shortname": "lmlt", - "units": "K", - "grib1": true, - "grib2": false - }, - "lake_mix_layer_depth": { - "name": "Lake mix-layer depth", - "shortname": "lmld", - "units": "m", - "grib1": true, - "grib2": false - }, - "lake_bottom_temperature": { - "name": "Lake bottom temperature", - "shortname": "lblt", - "units": "K", - "grib1": true, - "grib2": false - }, - "lake_total_layer_temperature": { - "name": "Lake total layer temperature", - "shortname": "ltlt", - "units": "K", - "grib1": true, - "grib2": false - }, - "lake_shape_factor": { - "name": "Lake shape factor", - "shortname": "lshf", - "units": "dimensionless", - "grib1": true, - "grib2": false - }, - "lake_ice_temperature": { - "name": "Lake ice temperature", - "shortname": "lict", - "units": "K", - "grib1": true, - "grib2": false - }, - "lake_ice_depth": { - "name": "Lake ice depth", - "shortname": "licd", - "units": "m", - "grib1": true, - "grib2": false - }, - "snow_cover": { - "name": "Snow cover", - "shortname": "snowc", - "units": "%", - "grib1": false, - "grib2": true - }, - "snow_depth": { - "name": "Snow depth", - "shortname": "sde", - "units": "m", - "grib1": false, - "grib2": true - }, - "snow_albedo": { - "name": "Snow albedo", - "shortname": "asn", - "units": "(0 - 1)", - "grib1": true, - "grib2": false - }, - "snow_density": { - "name": "Snow density", - "shortname": "rsn", - "units": "kg m**-3", - "grib1": true, - "grib2": false - }, - "volumetric_soil_water_layer_1": { - "name": "Volumetric soil water layer 11", - "shortname": "swvl1", - "units": "m**3 m**-3", - 
"grib1": true, - "grib2": false - }, - "volumetric_soil_water_layer_2": { - "name": "Volumetric soil water layer 21", - "shortname": "swvl2", - "units": "m**3 m**-3", - "grib1": true, - "grib2": false - }, - "volumetric_soil_water_layer_3": { - "name": "Volumetric soil water layer 31", - "shortname": "swvl3", - "units": "m**3 m**-3", - "grib1": true, - "grib2": false - }, - "volumetric_soil_water_layer_4": { - "name": "Volumetric soil water layer 41", - "shortname": "swvl4", - "units": "m**3 m**-3", - "grib1": true, - "grib2": false - }, - "leaf_area_index_low_vegetation": { - "name": "Leaf area index, low vegetation2", - "shortname": "lai_lv", - "units": "m**2 m**-2", - "grib1": true, - "grib2": false - }, - "leaf_area_index_high_vegetation": { - "name": "Leaf area index, high vegetation2", - "shortname": "lai_hv", - "units": "m**2 m**-2", - "grib1": true, - "grib2": false - }, - "surface_pressure": { - "name": "Surface pressure", - "shortname": "sp", - "units": "Pa", - "grib1": true, - "grib2": false - }, - "soil_temperature_level_1": { - "name": "Soil temperature level 11", - "shortname": "stl1", - "units": "K", - "grib1": true, - "grib2": false - }, - "snow_depth_water_equivalent": { - "name": "Snow depth water equivalent", - "shortname": "sd", - "units": "m of water equivalent", - "grib1": true, - "grib2": false - }, - "10m_u_component_of_wind": { - "name": "10 metre U wind component", - "shortname": "u10", - "units": "m s**-1", - "grib1": true, - "grib2": false - }, - "10m_v_component_of_wind": { - "name": "10 metre V wind component", - "shortname": "v10", - "units": "m s**-1", - "grib1": true, - "grib2": false - }, - "2m_temperature": { - "name": "2 metre temperature", - "shortname": "t2m", - "units": "K", - "grib1": true, - "grib2": false - }, - "2m_dewpoint_temperature": { - "name": "2 metre dewpoint temperature", - "shortname": "2d", - "units": "K", - "grib1": true, - "grib2": false - }, - "soil_temperature_level_2": { - "name": "Soil temperature level 
21", - "shortname": "stl2", - "units": "K", - "grib1": true, - "grib2": false - }, - "soil_temperature_level_3": { - "name": "Soil temperature level 31", - "shortname": "stl3", - "units": "K", - "grib1": true, - "grib2": false - }, - "skin_reservoir_content": { - "name": "Skin reservoir content", - "shortname": "src", - "units": "m of water equivalent", - "grib1": false, - "grib2": false - }, - "skin_temperature": { - "name": "Skin temperature", - "shortname": "skt", - "units": "K", - "grib1": true, - "grib2": false - }, - "soil_temperature_level_4": { - "name": "Soil temperature level 41", - "shortname": "stl4", - "units": "K", - "grib1": true, - "grib2": false - }, - "temperature_of_snow_layer": { - "name": "Temperature of snow layer", - "shortname": "tsn", - "units": "K", - "grib1": true, - "grib2": false - }, - "forecast_albedo": { - "name": "Forecast albedo", - "shortname": "fal", - "units": "(0 - 1)", - "grib1": true, - "grib2": false - }, - "surface_runoff": { - "name": "Surface runoff", - "shortname": "sro", - "units": "m", - "grib1": true, - "grib2": false - }, - "sub_surface_runoff": { - "name": "Sub-surface runoff", - "shortname": "ssro", - "units": "m", - "grib1": true, - "grib2": false - }, - "\u00a0snow_evaporation": { - "name": "Snow evaporation", - "shortname": "es", - "units": "m of water equivalent", - "grib1": true, - "grib2": false - }, - "snowmelt": { - "name": "Snowmelt", - "shortname": "smlt", - "units": "m of water equivalent", - "grib1": true, - "grib2": false - }, - "snowfall": { - "name": "Snowfall", - "shortname": "sf", - "units": "m of water equivalent", - "grib1": true, - "grib2": false - }, - "surface_sensible_heat_flux": { - "name": "Surface sensible heat flux", - "shortname": "sshf", - "units": "J m**-2", - "grib1": true, - "grib2": false - }, - "surface_latent_heat_flux": { - "name": "Surface latent heat flux", - "shortname": "slhf", - "units": "J m**-2", - "grib1": true, - "grib2": false - }, - "surface_solar_radiation_downwards": 
{ - "name": "Surface solar radiation downwards", - "shortname": "ssrd", - "units": "J m**-2", - "grib1": true, - "grib2": false - }, - "surface_thermal_radiation_downwards": { - "name": "Surface thermal radiation downwards", - "shortname": "strd", - "units": "J m**-2", - "grib1": true, - "grib2": false - }, - "surface_net_solar_radiation": { - "name": "Surface net solar radiation", - "shortname": "ssr", - "units": "J m**-2", - "grib1": true, - "grib2": false - }, - "surface_net_thermal_radiation": { - "name": "Surface net thermal radiation", - "shortname": "str", - "units": "J m**-2", - "grib1": true, - "grib2": false - }, - "total_evaporation": { - "name": "Total Evaporation", - "shortname": "e", - "units": "m of water equivalent", - "grib1": true, - "grib2": false - }, - "runoff": { - "name": "Runoff", - "shortname": "ro", - "units": "m", - "grib1": true, - "grib2": false - }, - "total_precipitation": { - "name": "Total precipitation", - "shortname": "tp", - "units": "m", - "grib1": true, - "grib2": false - }, - "evaporation_from_the_top_of_canopy": { - "name": "Evaporation from the top of canopy", - "shortname": "evatc", - "units": "m of water equivalent", - "grib1": false, - "grib2": true - }, - "evaporation_from_bare_soil": { - "name": "Evaporation from bare soil", - "shortname": "evabs", - "units": "m of water equivalent", - "grib1": false, - "grib2": true - }, - "evaporation_from_open_water_surfaces_excluding_oceans": { - "name": "Evaporation from open water surfaces excluding oceans", - "shortname": "evaow", - "units": "m of water equivalent", - "grib1": false, - "grib2": true - }, - "evaporation_from_vegetation_transpiration": { - "name": "Evaporation from vegetation transpiration", - "shortname": "evavt", - "units": "m of water equivalent", - "grib1": false, - "grib2": true - }, - "potential_evaporation": { - "name": "Potential evaporation", - "shortname": "pev", - "units": "m", - "grib1": true, - "grib2": false - } -} \ No newline at end of file diff 
--git a/pyproject.toml b/pyproject.toml index a87a9204..4546d7bc 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -4,7 +4,7 @@ build-backend = "setuptools.build_meta" [project] name = "openhexa.toolbox" -version = "2.9.0" +version = "2.10.0" description = "A set of tools to acquire & process data from various sources" authors = [{ name = "Bluesquare", email = "dev@bluesquarehub.com" }] maintainers = [{ name = "Bluesquare", email = "dev@bluesquarehub.com" }] @@ -15,7 +15,7 @@ classifiers = [ "License :: OSI Approved :: MIT License", "Operating System :: OS Independent", ] -requires-python = ">=3.8" +requires-python = ">=3.12" dependencies = [ "requests", "python-dateutil", @@ -31,6 +31,7 @@ dependencies = [ "openhexa.sdk", "humanize", "rich", + "psycopg>=3.2.11", "regex" ] @@ -50,49 +51,41 @@ test = [ "pytest~=8.4.0", "pytest-cov~=7.0.0", "responses", - "cdsapi >=0.7.3", - "cads-api-client >=1.4.0", - "cfgrib", "xarray", - "datapi >=0.3.0", "rasterio", "epiweeks", - "openlineage-python >=1.33.0" + "openlineage-python >=1.33.0", + "ecmwf-datastores-client>=0.4.0", + "pyarrow>=18.1.0", + "xarray>=2025.9.1", + "zarr>=3.1.3", ] era5 = [ - "cdsapi >=0.7.3", - "cads-api-client >=1.4.0", - "cfgrib", - "xarray", - "datapi >=0.3.0", - "rasterio", - "epiweeks" + "cfgrib>=0.9.15.1", + "ecmwf-datastores-client>=0.4.0", + "pyarrow>=21.0.0", + "xarray>=2025.9.1", + "zarr>=3.1.3", + "dask", ] lineage = [ "openlineage-python >=1.33.0" ] -all = ["openhexa.toolbox[era5,lineage]"] +all = [] [tool.setuptools] include-package-data = true [tool.setuptools.packages.find] where = ["."] -include = [ - "openhexa.toolbox.dhis2", - "openhexa.toolbox.era5", - "openhexa.toolbox.hexa", - "openhexa.toolbox.iaso", - "openhexa.toolbox.kobo", - "openhexa.toolbox.lineage", -] +include = ["openhexa.toolbox.*"] namespaces = true [tool.setuptools.package-data] -"openhexa.toolbox.era5" = ["*.json"] +"openhexa.toolbox.era5" = ["data/*.json", "data/*.toml"] [project.urls] "Homepage" = 
"https://github.com/blsq/openhexa-toolbox" diff --git a/recipe/meta.yaml b/recipe/meta.yaml index 90401b7c..5da52ceb 100644 --- a/recipe/meta.yaml +++ b/recipe/meta.yaml @@ -42,6 +42,11 @@ requirements: - openhexa.sdk >=2.2.0 - rich - openlineage-python >=1.33.0 + - psycopg + - ecmwf-datastores-client + - zarr + - dask + - regex test: imports: diff --git a/tests/era5/data/geoms.parquet b/tests/era5/data/geoms.parquet new file mode 100644 index 00000000..4f8d6b52 Binary files /dev/null and b/tests/era5/data/geoms.parquet differ diff --git a/tests/era5/data/sample_202503.grib b/tests/era5/data/sample_202503.grib new file mode 100644 index 00000000..65b58d44 Binary files /dev/null and b/tests/era5/data/sample_202503.grib differ diff --git a/tests/era5/data/sample_202504.grib b/tests/era5/data/sample_202504.grib new file mode 100644 index 00000000..76beca86 Binary files /dev/null and b/tests/era5/data/sample_202504.grib differ diff --git a/tests/era5/data/sample_2m_temperature.zarr.tar.gz b/tests/era5/data/sample_2m_temperature.zarr.tar.gz new file mode 100644 index 00000000..417b09eb Binary files /dev/null and b/tests/era5/data/sample_2m_temperature.zarr.tar.gz differ diff --git a/tests/era5/test_cds.py b/tests/era5/test_cds.py deleted file mode 100644 index c8bb2f51..00000000 --- a/tests/era5/test_cds.py +++ /dev/null @@ -1,158 +0,0 @@ -"""Unit tests for the ERA5 Climate Data Store client.""" - -from __future__ import annotations - -from datetime import datetime -from unittest.mock import Mock, patch - -import datapi -import pytest - -from openhexa.toolbox.era5.cds import ( - CDS, - DataRequest, -) - - -class TestCollection(datapi.catalogue.Collection): - """Datapi Collection object with mocked end_datetime property.""" - - __test__ = False - - def __init__(self, end_datetime: datetime) -> None: - self._end_datetime = end_datetime - - @property - def end_datetime(self): - return self._end_datetime - - -@patch("datapi.ApiClient.check_authentication") -def 
test_cds_init(mock_check_authentication: Mock): - """Test CDS class initialization.""" - mock_check_authentication.return_value = True - CDS(key="xxx") - - -@pytest.fixture -@patch("datapi.ApiClient.check_authentication") -def fake_cds(mock_check_authentication: Mock): - mock_check_authentication.return_value = True - return CDS(key="xxx") - - -@patch("datapi.ApiClient.get_collection") -def test_latest(mock_get_collection: Mock, fake_cds: CDS): - mock_get_collection.return_value = TestCollection(end_datetime=datetime(2023, 1, 1).astimezone()) - assert fake_cds.latest == datetime(2023, 1, 1).astimezone() - - -class TestJobs(datapi.processing.Jobs): - """Datapi Jobs class with mocked request_ids property.""" - - __test__ = False - - def __init__(self, request_ids: list[str]) -> None: - self._request_ids = request_ids - - @property - def request_ids(self): - return self._request_ids - - -class TestRemote(datapi.processing.Remote): - """Datapi Remote class with mocked properties.""" - - __test__ = False - - def __init__(self, request_id: str, status: str, results_ready: bool, request: dict) -> None: - self._request_id = request_id - self._status = status - self._results_ready = results_ready - self._request = request - self.cleanup = False - - @property - def status(self): - return self._status - - @property - def request_id(self): - return self._request_id - - @property - def results_ready(self): - return self._results_ready - - @property - def request(self): - return self._request - - -@patch("datapi.ApiClient.get_jobs") -@patch("datapi.ApiClient.get_remote") -def test_cds_get_remote_requests(mock_get_remote: Mock, mock_get_jobs: Mock, fake_cds: CDS): - mock_get_jobs.return_value = TestJobs( - request_ids=[ - "73dc0d2d-8288-4041-a84d-87e70772d5a8", - "3973ec55-4b38-449b-b7f1-5edd1034f663", - "a5c7093d-56d9-40a4-a363-c60cd242ce66", - ] - ) - - mock_get_remote.return_value = TestRemote( - request_id="73dc0d2d-8288-4041-a84d-87e70772d5a8", status="successful", 
results_ready=True, request={} - ) - - remote_requests = fake_cds.get_remote_requests() - - assert len(remote_requests) == 3 - assert remote_requests[0]["request_id"] == "73dc0d2d-8288-4041-a84d-87e70772d5a8" - assert isinstance(remote_requests[0]["request"], dict) - - -@pytest.fixture -def tp_request() -> DataRequest: - return DataRequest( - variable=["total_precipitation"], - year="2024", - month="12", - day=["01", "02", "03", "04", "05"], - time=["01:00", "06:00", "18:00"], - data_format="grib", - area=[16, -6, 9, 3], - ) - - -@pytest.fixture -def tp_request_remote() -> dict: - return { - "request_id": "73dc0d2d-8288-4041-a84d-87e70772d5a8", - "request": { - "day": ["01", "02", "03", "04", "05"], - "area": [16, -6, 9, 3], - "time": ["01:00", "06:00", "18:00"], - "year": "2024", - "month": "12", - "variable": ["total_precipitation"], - "data_format": "grib", - }, - } - - -@patch("datapi.ApiClient.get_remote") -def test_cds_get_remote_from_request( - mock_get_remote: Mock, fake_cds: CDS, tp_request: DataRequest, tp_request_remote: dict -): - mock_get_remote.return_value = TestRemote( - request_id="73dc0d2d-8288-4041-a84d-87e70772d5a8", - status="successful", - results_ready=True, - request=tp_request_remote, - ) - - existing_requests = [tp_request_remote] - remote = fake_cds.get_remote_from_request(tp_request, existing_requests=existing_requests) - assert remote - assert remote.request_id == "73dc0d2d-8288-4041-a84d-87e70772d5a8" - assert remote.request["request"] == tp_request.__dict__ diff --git a/tests/era5/test_dhis2weeks.py b/tests/era5/test_dhis2weeks.py new file mode 100644 index 00000000..64b7be0f --- /dev/null +++ b/tests/era5/test_dhis2weeks.py @@ -0,0 +1,52 @@ +"""Test dhis2weeks module.""" + +from datetime import date + +from openhexa.toolbox.era5.dhis2weeks import WeekType, get_calendar_week + + +def test_standard_iso_week(): + # 2024 Jan 1 is a Monday, so we expect it to be week 1 of 2024 + dt = date(2024, 1, 1) + assert get_calendar_week(dt, 
WeekType.WEEK) == (2024, 1) + + # 2023 Dec 31 is a Sunday, so it belongs to week 52 of 2023 + dt = date(2023, 12, 31) + assert get_calendar_week(dt, WeekType.WEEK) == (2023, 52) + + +def test_sunday_week_year_boundary(): + """Test Sunday weeks crossing year boundaries.""" + # 2023 Dec 31 is a Sunday starting a week containing Jan 4th, so it should belong + # to week 1 of 2024 for Sunday weeks + dt = date(2023, 12, 31) + assert get_calendar_week(dt, WeekType.WEEK_SUNDAY) == (2024, 1) + + # Next Sunday should be in Week 2 + dt = date(2024, 1, 7) + assert get_calendar_week(dt, WeekType.WEEK_SUNDAY) == (2024, 2) + + +def test_saturday_week_year_start(): + """Test Saturday weeks when year starts on Saturday.""" + # Jan 1 2022 is a Saturday so it should be week 1 of 2022 for Saturday weeks + # However, for Sunday weeks it should belong to the last week of 2021 + dt = date(2022, 1, 1) + assert get_calendar_week(dt, WeekType.WEEK_SATURDAY) == (2022, 1) + assert get_calendar_week(dt, WeekType.WEEK_SUNDAY) == (2021, 52) + + +def test_different_week_types_same_date(): + """Test that the same date can belong to different year/weeks.""" + # 2022 Jan 1 is a Saturday and is expected to be: + # - Week 52 of 2021 for standard ISO weeks (Monday start) + # - Week 1 of 2022 for Wednesday weeks + # - Week 1 of 2022 for Thursday weeks + # - Week 1 of 2022 for Saturday weeks + # - Week 52 of 2021 for Sunday weeks + dt = date(2022, 1, 1) + assert get_calendar_week(dt, WeekType.WEEK) == (2021, 52) + assert get_calendar_week(dt, WeekType.WEEK_WEDNESDAY) == (2022, 1) + assert get_calendar_week(dt, WeekType.WEEK_THURSDAY) == (2022, 1) + assert get_calendar_week(dt, WeekType.WEEK_SATURDAY) == (2022, 1) + assert get_calendar_week(dt, WeekType.WEEK_SUNDAY) == (2021, 52) diff --git a/tests/era5/test_extract.py b/tests/era5/test_extract.py new file mode 100644 index 00000000..26438d5f --- /dev/null +++ b/tests/era5/test_extract.py @@ -0,0 +1,261 @@ +"""Test requests to the CDS API and handling of 
responses.""" + +import shutil +import tarfile +import tempfile +from datetime import date, datetime +from pathlib import Path +from unittest.mock import Mock + +import numpy as np +import pytest +import xarray as xr + +from openhexa.toolbox.era5.extract import ( + Client, + Remote, + Request, + _bound_date_range, + _get_temporal_chunks, + _submit_requests, + get_date_range, + grib_to_zarr, + prepare_requests, + retrieve_requests, +) + + +@pytest.fixture +def mock_client() -> Mock: + client = Mock(spec=Client) + collection = Mock() + collection.begin_datetime = datetime(2020, 1, 1) + collection.end_datetime = datetime(2025, 4, 4) + client.get_collection.return_value = collection + return client + + +@pytest.fixture +def mock_request() -> Request: + return { + "variable": ["2m_temperature"], + "year": "2025", + "month": "03", + "day": ["28", "29", "30", "31"], + "time": ["01:00", "07:00", "13:00", "19:00"], + "data_format": "grib", + "download_format": "unarchived", + "area": [10, -1, 8, 1], + } + + +@pytest.fixture +def sample_grib_file_march() -> Path: + """Small sample GRIB file with 2m_temperature data for March.""" + return Path(__file__).parent / "data" / "sample_202503.grib" + + +@pytest.fixture +def sample_grib_file_april() -> Path: + """Small sample GRIB file with 2m_temperature data for April.""" + return Path(__file__).parent / "data" / "sample_202504.grib" + + +@pytest.fixture +def sample_zarr_store() -> Path: + """Path to a sample Zarr store with data from sample GRIB files.""" + return Path(__file__).parent / "data" / "sample_2m_temperature.zarr.tar.gz" + + +def test_prepare_requests(mock_client): + requests = prepare_requests( + client=mock_client, + dataset_id="reanalysis-era5-land", + start_date=date(2025, 3, 28), + end_date=date(2025, 4, 5), + variable="2m_temperature", + area=[10, -1, 8, 1], + zarr_store=Path("/tmp/do-not-exist.zarr"), + ) + + # The mock client has collection end date of 2025-04-04 + # So we expect requests only up to 2025-04-04 
despite the requested end date
+    # We also expect 2 prepared requests: one for March and one for April
+    assert len(requests) == 2
+    assert requests[0] == {
+        "variable": ["2m_temperature"],
+        "year": "2025",
+        "month": "03",
+        "day": ["28", "29", "30", "31"],
+        "time": ["01:00", "07:00", "13:00", "19:00"],
+        "data_format": "grib",
+        "download_format": "unarchived",
+        "area": [10, -1, 8, 1],
+    }
+    assert requests[1] == {
+        "variable": ["2m_temperature"],
+        "year": "2025",
+        "month": "04",
+        "day": ["01", "02", "03", "04"],
+        "time": ["01:00", "07:00", "13:00", "19:00"],
+        "data_format": "grib",
+        "download_format": "unarchived",
+        "area": [10, -1, 8, 1],
+    }
+
+
+def test_prepare_requests_with_existing_data(sample_zarr_store, mock_client):
+    # Sample zarr store has data from 2025-03-28 to 2025-04-05
+    with tempfile.TemporaryDirectory() as tmpdir, tarfile.open(sample_zarr_store, "r:gz") as tar:
+        tar.extractall(path=tmpdir, filter="data")
+        zarr_store = Path(tmpdir) / "2m_temperature.zarr"
+        requests = prepare_requests(
+            client=mock_client,
+            dataset_id="reanalysis-era5-land",
+            start_date=date(2025, 3, 27),
+            end_date=date(2025, 4, 6),
+            variable="2m_temperature",
+            area=[10, -1, 8, 1],
+            zarr_store=zarr_store,
+        )
+
+        # In the sample zarr store, we already have data between 2025-03-28 and 2025-04-05
+        # In the mock client, the end date of the collection is 2025-04-04
+        # As a result, we expect only 1 request to be prepared: for 2025-03-27
+        assert len(requests) == 1
+        assert requests[0] == {
+            "variable": ["2m_temperature"],
+            "year": "2025",
+            "month": "03",
+            "day": ["27"],
+            "time": ["01:00", "07:00", "13:00", "19:00"],
+            "data_format": "grib",
+            "download_format": "unarchived",
+            "area": [10, -1, 8, 1],
+        }
+
+
+def test_submit_requests(mock_client, mock_request):
+    remote = Mock(spec=Remote)
+    mock_client.submit.return_value = remote
+    remotes = _submit_requests(
+        client=mock_client,
+        collection_id="reanalysis-era5-land",
+        requests=[mock_request,
mock_request], + ) + # We expect 1 remote per request here + assert len(remotes) == 2 + + +def test_retrieve_requests(mock_client, mock_request): + remote1 = Mock(spec=Remote) + remote1.request_id = "remote1" + remote1.request = mock_request + remote1.status = "successful" + remote1.results_ready = True + remote1.download = Mock(side_effect=lambda target: Path(target).touch()) + + remote2 = Mock(spec=Remote) + remote2.request_id = "remote2" + remote2.request = mock_request + remote2.status = "successful" + remote2.results_ready = True + remote2.download = Mock(side_effect=lambda target: Path(target).touch()) + + mock_client.submit.side_effect = [remote1, remote2] + + with tempfile.TemporaryDirectory() as tmpdir: + retrieve_requests( + client=mock_client, + dataset_id="reanalysis-era5-land", + requests=[mock_request, mock_request], + dst_dir=Path(tmpdir), + wait=0, + ) + # We expect 2 grib files to have been downloaded + assert len(list(Path(tmpdir).glob("*.grib"))) == 2 + + +def test_get_date_range(): + start = date(2024, 12, 27) + end = date(2025, 1, 3) + result = get_date_range(start, end) + assert result == [ + date(2024, 12, 27), + date(2024, 12, 28), + date(2024, 12, 29), + date(2024, 12, 30), + date(2024, 12, 31), + date(2025, 1, 1), + date(2025, 1, 2), + date(2025, 1, 3), + ] + + +def test_get_date_range_single_day(): + start = date(2025, 3, 15) + end = date(2025, 3, 15) + result = get_date_range(start, end) + assert result == [date(2025, 3, 15)] + + +def test_get_date_range_invalid(): + start = date(2025, 3, 15) + end = date(2025, 3, 14) + with pytest.raises(ValueError, match="Start date must be before end date"): + get_date_range(start, end) + + +def test_bound_date_range(): + start = date(2024, 12, 27) + end = date(2025, 1, 3) + collection_start = date(2024, 1, 1) + collection_end = date(2024, 12, 31) + bounded_start, bounded_end = _bound_date_range(start, end, collection_start, collection_end) + assert bounded_start == date(2024, 12, 27) + assert 
bounded_end == date(2024, 12, 31)
+
+
+def test_get_temporal_chunks():
+    dates = [
+        date(2024, 1, 31),
+        date(2024, 2, 1),
+        date(2024, 2, 15),
+        date(2024, 3, 1),
+    ]
+    result = _get_temporal_chunks(dates)
+
+    # We expect 3 chunks: one per month
+    assert len(result) == 3
+    assert result[0]["year"] == "2024"
+    assert result[0]["month"] == "01"
+    assert result[0]["day"] == ["31"]
+    assert result[1]["year"] == "2024"
+    assert result[1]["month"] == "02"
+    assert result[1]["day"] == ["01", "15"]
+    assert result[2]["year"] == "2024"
+    assert result[2]["month"] == "03"
+    assert result[2]["day"] == ["01"]
+
+
+def test_grib_to_zarr(sample_grib_file_march, sample_grib_file_april):
+    def _move_grib_to_tmp_dir(grib_file: Path, dst_dir: Path):
+        dst_file = dst_dir / grib_file.name
+        shutil.copy(grib_file, dst_file)
+
+    with tempfile.TemporaryDirectory() as tmpdir:
+        grib_dir = Path(tmpdir) / "grib_files"
+        grib_dir.mkdir()
+        _move_grib_to_tmp_dir(sample_grib_file_march, grib_dir)
+        _move_grib_to_tmp_dir(sample_grib_file_april, grib_dir)
+        zarr_store = Path(tmpdir) / "store.zarr"
+        grib_to_zarr(grib_dir, zarr_store, "t2m")
+        # We expect the Zarr store to have been created and to contain data from both GRIB files
+        # Sample GRIB files contain data from 2025-03-28 to 2025-04-05 (9 days)
+        assert zarr_store.exists()
+        ds = xr.open_zarr(zarr_store, decode_timedelta=True)
+        assert "t2m" in ds
+        times = np.array(ds["time"])
+        assert np.datetime_as_string(times[0], unit="D") == "2025-03-28"
+        assert np.datetime_as_string(times[-1], unit="D") == "2025-04-05"
+        assert len(times) == 9 * 4  # 9 days, 4 time steps per day
diff --git a/tests/era5/test_transform.py b/tests/era5/test_transform.py
new file mode 100644
index 00000000..fd1f6862
--- /dev/null
+++ b/tests/era5/test_transform.py
@@ -0,0 +1,167 @@
+"""Test transform module."""
+
+import tarfile
+import tempfile
+from pathlib import Path
+
+import geopandas as gpd
+import numpy as np
+import polars as pl
+import pytest
+import xarray as
xr + +from openhexa.toolbox.era5.transform import ( + Period, + aggregate_in_space, + aggregate_in_time, + calculate_relative_humidity, + calculate_wind_speed, + create_masks, +) + + +@pytest.fixture +def sample_boundaries() -> gpd.GeoDataFrame: + fp = Path(__file__).parent / "data" / "geoms.parquet" + return gpd.read_parquet(fp) + + +@pytest.fixture +def sample_dataset() -> xr.Dataset: + archive = Path(__file__).parent / "data" / "sample_2m_temperature.zarr.tar.gz" + with tempfile.TemporaryDirectory() as tmp_dir: + with tarfile.open(archive, "r:gz") as tar: + tar.extractall(path=tmp_dir, filter="data") + ds = xr.open_zarr(Path(tmp_dir) / "2m_temperature.zarr", decode_timedelta=False) + ds.load() + return ds + + +def test_create_masks(sample_boundaries, sample_dataset): + masks = create_masks(gdf=sample_boundaries, id_column="boundary_id", ds=sample_dataset) + # We have 4 boundaries in the sample data and the dataset has 21x21 lat/lon points + assert masks.shape == (4, 21, 21) + assert masks.dims == ("boundary", "latitude", "longitude") + # Each boolean mask should contain only 0s and 1s, and at least some 1s + assert sorted(np.unique(masks.data).tolist()) == [0, 1] + assert np.count_nonzero(masks.data) > 100 + + +@pytest.fixture +def sample_masks(sample_boundaries, sample_dataset) -> xr.DataArray: + return create_masks(gdf=sample_boundaries, id_column="boundary_id", ds=sample_dataset) + + +def test_aggregate_in_space(sample_dataset, sample_masks): + ds = sample_dataset.mean(dim="step") + df = aggregate_in_space(ds=ds, masks=sample_masks, data_var="t2m", agg="mean") + + # We have 4 boundaries and 9 days in the sample data, so shape should be 9*4=36 rows + # and 3 columns (boundary, time, value) + assert df.shape == (36, 3) + + expected = pl.Schema({"boundary": pl.String, "time": pl.Date, "value": pl.Float64}) + assert df.schema == expected + + assert df["boundary"].n_unique() == 4 + assert df["time"].n_unique() == 9 + + assert pytest.approx(df["value"].min(), 0.1) 
== 302.38 + assert pytest.approx(df["value"].max(), 0.1) == 307.44 + + # The following aggregation methods do not make sense for 2m_temperature, + # but values should match expected results nonetheless + df = aggregate_in_space(ds=ds, masks=sample_masks, data_var="t2m", agg="sum") + assert df.shape == (36, 3) + assert pytest.approx(df["value"].min(), 0.1) == 11589.94 + assert pytest.approx(df["value"].max(), 0.1) == 68461.84 + df = aggregate_in_space(ds=ds, masks=sample_masks, data_var="t2m", agg="max") + assert df.shape == (36, 3) + assert pytest.approx(df["value"].min(), 0.1) == 305.08 + assert pytest.approx(df["value"].max(), 0.1) == 308.07 + + +def test_aggregate_in_time(sample_masks, sample_dataset): + ds = sample_dataset.mean(dim="step") + df = aggregate_in_space(ds=ds, masks=sample_masks, data_var="t2m", agg="mean") + + weekly = aggregate_in_time(df, Period.WEEK, agg="mean") + # 4 boundaries * 2 weeks = 8 rows + assert weekly.shape[0] == 8 + assert weekly.schema == pl.Schema( + {"boundary": pl.String, "period": pl.String, "value": pl.Float64}, + ) + assert set(weekly.columns) == {"boundary", "period", "value"} + assert weekly["period"].str.starts_with("2025W").all() + + sunday_weekly = aggregate_in_time(df, Period.WEEK_SUNDAY, agg="mean") + assert sunday_weekly["period"].str.starts_with("2025SunW").all() + + monthly = aggregate_in_time(df, Period.MONTH, agg="mean") + assert monthly.shape[0] == 8 # 4 boundaries * 2 months + assert "202503" in monthly["period"].to_list() + + +def test_calculate_relative_humidty(): + t2m = xr.DataArray( + np.array([[[300.0, 305.0]], [[310.0, 290.0]]]), + dims=["time", "latitude", "longitude"], + coords={ + "time": ["2025-01-01", "2025-01-02"], + "latitude": [45.0], + "longitude": [10.0, 11.0], + }, + ) + + # when t2m == d2m, RH should be 100% + d2m = t2m.copy() + result = calculate_relative_humidity(t2m, d2m) + + assert "rh" in result.data_vars + assert result["rh"].dims == ("time", "latitude", "longitude") + assert 
result["rh"].attrs["units"] == "%" + assert np.allclose(result["rh"].values, 100.0, rtol=0.01) + + # RH should be between 0 and 100% with lower dewpoint + d2m_lower = t2m - 10.0 + result2 = calculate_relative_humidity(t2m, d2m_lower) + assert (result2["rh"] < 100.0).all() + assert (result2["rh"] > 0.0).all() + + # check that clipping works (dewpoint higher than temperature = invalid, should clip to 100%) + d2m_higher = t2m + 5.0 + result_clipped = calculate_relative_humidity(t2m, d2m_higher) + assert (result_clipped["rh"] <= 100.0).all() + + +def test_calculate_wind_speed(): + u10 = xr.DataArray( + np.array([[[0.0, 3.0]], [[4.0, 5.0]]]), + dims=["time", "latitude", "longitude"], + coords={ + "time": ["2025-01-01", "2025-01-02"], + "latitude": [45.0], + "longitude": [10.0, 11.0], + }, + ) + v10 = xr.DataArray( + np.array([[[0.0, 4.0]], [[0.0, 12.0]]]), + dims=["time", "latitude", "longitude"], + coords={ + "time": ["2025-01-01", "2025-01-02"], + "latitude": [45.0], + "longitude": [10.0, 11.0], + }, + ) + + result = calculate_wind_speed(u10, v10) + + assert "ws" in result.data_vars + assert result["ws"].dims == ("time", "latitude", "longitude") + assert result["ws"].attrs["units"] == "m/s" + + expected = np.array([[[0.0, 5.0]], [[4.0, 13.0]]]) + assert np.allclose(result["ws"].values, expected, rtol=1e-10) + + # wind speed should always be non-negative + assert (result["ws"] >= 0).all()