Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
75 changes: 64 additions & 11 deletions xarray/backends/zarr.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,8 @@
from __future__ import annotations

import base64
import functools
import importlib.util
import json
import os
import struct
Expand Down Expand Up @@ -46,6 +48,19 @@
from xarray.core.types import ZarrArray, ZarrGroup


@functools.cache
def _has_unified_chunk_grid() -> bool:
"""Check if zarr has the unified ChunkGrid with is_regular support.

Defers the actual import so zarr stays lazy at module load time.
"""
if importlib.util.find_spec("zarr.core.chunk_grids") is None:
return False
from zarr.core.chunk_grids import ChunkGrid

return hasattr(ChunkGrid, "is_regular")


def _get_mappers(*, storage_options, store, chunk_store):
# expand str and path-like arguments
store = _normalize_path(store)
Expand Down Expand Up @@ -284,7 +299,7 @@ async def async_getitem(self, key):
)


def _determine_zarr_chunks(enc_chunks, var_chunks, ndim, name):
def _determine_zarr_chunks(enc_chunks, var_chunks, ndim, name, zarr_format):
"""
Given encoding chunks (possibly None or []) and variable chunks
(possibly None or []).
Expand All @@ -306,18 +321,24 @@ def _determine_zarr_chunks(enc_chunks, var_chunks, ndim, name):
# while dask chunks can be variable sized
# https://dask.pydata.org/en/latest/array-design.html#chunks
if var_chunks and not enc_chunks:
if zarr_format == 3 and _has_unified_chunk_grid():
return tuple(var_chunks)

if any(len(set(chunks[:-1])) > 1 for chunks in var_chunks):
raise ValueError(
"Zarr requires uniform chunk sizes except for final chunk. "
"Zarr v2 requires uniform chunk sizes except for the final chunk. "
f"Variable named {name!r} has incompatible dask chunks: {var_chunks!r}. "
"Consider rechunking using `chunk()`."
"Consider rechunking using `chunk()`, or switching to the "
"zarr v3 format with zarr-python>=3.2."
)
if any((chunks[0] < chunks[-1]) for chunks in var_chunks):
raise ValueError(
"Final chunk of Zarr array must be the same size or smaller "
f"than the first. Variable named {name!r} has incompatible Dask chunks {var_chunks!r}."
"Consider either rechunking using `chunk()` or instead deleting "
"or modifying `encoding['chunks']`."
"The final chunk of a Zarr v2 array or a Zarr v3 array without the "
"rectilinear chunks extension must be the same size or smaller "
f"than the first. Variable named {name!r} has incompatible Dask "
f"chunks {var_chunks!r}. "
"Consider switching to Zarr v3 with the rectilinear chunks extension, "
"rechunking using `chunk()` or deleting or modifying `encoding['chunks']`."
)
# return the first chunk for each dimension
return tuple(chunk[0] for chunk in var_chunks)
Expand All @@ -340,8 +361,17 @@ def _determine_zarr_chunks(enc_chunks, var_chunks, ndim, name):
var_chunks,
ndim,
name,
zarr_format,
)

# Rectilinear chunks: each element is a sequence of per-chunk edge lengths
if (
zarr_format == 3
and _has_unified_chunk_grid()
and any(not isinstance(x, int) for x in enc_chunks_tuple)
):
return enc_chunks_tuple

for x in enc_chunks_tuple:
if not isinstance(x, int):
raise TypeError(
Expand Down Expand Up @@ -483,6 +513,7 @@ def extract_zarr_variable_encoding(
var_chunks=variable.chunks,
ndim=variable.ndim,
name=name,
zarr_format=zarr_format,
)
if _zarr_v3() and chunks is None:
chunks = "auto"
Expand Down Expand Up @@ -861,9 +892,25 @@ def open_store_variable(self, name):
)
attributes = dict(attributes)

if _has_unified_chunk_grid():
from zarr.core.metadata.v3 import RectilinearChunkGrid, RegularChunkGrid

chunk_grid = zarr_array.metadata.chunk_grid
if isinstance(chunk_grid, RegularChunkGrid):
chunks = chunk_grid.chunk_shape
elif isinstance(chunk_grid, RectilinearChunkGrid):
chunks = chunk_grid.chunk_shapes
else:
chunks = tuple(zarr_array.chunks)
preferred_chunks = dict(zip(dimensions, chunks, strict=True))
else:
# Fallback for older zarr-python without unified chunk grid
chunks = tuple(zarr_array.chunks)
preferred_chunks = dict(zip(dimensions, chunks, strict=True))

encoding = {
"chunks": zarr_array.chunks,
"preferred_chunks": dict(zip(dimensions, zarr_array.chunks, strict=True)),
"chunks": chunks,
"preferred_chunks": preferred_chunks,
}

if _zarr_v3():
Expand Down Expand Up @@ -1248,14 +1295,20 @@ def set_variables(
# parallel writes. See https://github.com/pydata/xarray/issues/10831
effective_write_chunks = encoding.get("shards") or encoding["chunks"]

if self._align_chunks and isinstance(effective_write_chunks, tuple):
# Rectilinear chunks are tuples-of-tuples — align_chunks and
# safe_chunks validation only apply to regular (flat tuple) chunks.
_is_regular_chunks = isinstance(effective_write_chunks, tuple) and all(
isinstance(c, int) for c in effective_write_chunks
)

if self._align_chunks and _is_regular_chunks:
v = grid_rechunk(
v=v,
enc_chunks=effective_write_chunks,
region=region,
)

if self._safe_chunks and isinstance(effective_write_chunks, tuple):
if self._safe_chunks and _is_regular_chunks:
# the hard case
# DESIGN CHOICE: do not allow multiple dask chunks on a single zarr chunk
# (or shard, when sharding is enabled)
Expand Down
48 changes: 48 additions & 0 deletions xarray/tests/test_backends.py
Original file line number Diff line number Diff line change
Expand Up @@ -7264,6 +7264,54 @@ def test_extract_zarr_variable_encoding() -> None:
)


@requires_zarr_v3
@requires_dask
def test_rectilinear_chunks_encoding_roundtrip(tmp_path: Path) -> None:
"""Rectilinear chunk sizes in encoding are passed through to zarr v3."""

import zarr

if not backends.zarr._has_unified_chunk_grid():
pytest.skip("zarr does not have unified ChunkGrid support")

chunk_sizes = [10, 20, 30]
data = np.arange(60, dtype="float32")
ds = xr.Dataset({"var": xr.Variable("x", data)}).chunk({"x": tuple(chunk_sizes)})

store_path = tmp_path / "rectilinear.zarr"
encoding = {"var": {"chunks": [chunk_sizes]}}

with zarr.config.set({"array.rectilinear_chunks": True}):
ds.to_zarr(store_path, zarr_format=3, mode="w", encoding=encoding)

roundtrip = xr.open_zarr(store_path, zarr_format=3)
assert roundtrip.chunks["x"] == tuple(chunk_sizes)
np.testing.assert_array_equal(roundtrip["var"].values, data)


@requires_zarr_v3
@requires_dask
def test_rectilinear_chunks_no_encoding(tmp_path: Path) -> None:
"""Variable dask chunks are written as rectilinear when no encoding is given."""
import zarr

if not backends.zarr._has_unified_chunk_grid():
pytest.skip("zarr does not have unified ChunkGrid support")

chunk_sizes = [15, 25, 20]
data = np.arange(60, dtype="float32")
ds = xr.Dataset({"var": xr.Variable("x", data)}).chunk({"x": tuple(chunk_sizes)})

store_path = tmp_path / "rectilinear_no_enc.zarr"

with zarr.config.set({"array.rectilinear_chunks": True}):
ds.to_zarr(store_path, zarr_format=3, mode="w")

roundtrip = xr.open_zarr(store_path, zarr_format=3)
assert roundtrip.chunks["x"] == tuple(chunk_sizes)
np.testing.assert_array_equal(roundtrip["var"].values, data)


@requires_zarr
@requires_fsspec
@pytest.mark.filterwarnings("ignore:deallocating CachingFileManager")
Expand Down
Loading