Skip to content
2 changes: 2 additions & 0 deletions doc/whats-new.rst
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,8 @@ New Features
- Added ``inherit='all_coords'`` option to :py:meth:`DataTree.to_dataset` to inherit
all parent coordinates, not just indexed ones (:issue:`10812`, :pull:`11230`).
By `Alfonso Ladino <https://github.com/aladinor>`_.
- Added complex dtype support to ``FillValueCoder`` for the Zarr backend (:pull:`11151`).
By `Max Jones <https://github.com/maxrjones>`_.

Breaking Changes
~~~~~~~~~~~~~~~~
Expand Down
17 changes: 17 additions & 0 deletions properties/test_encode_decode.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@
import hypothesis.extra.numpy as npst
import numpy as np
from hypothesis import given
from hypothesis import strategies as st

import xarray as xr
from xarray.coding.times import _parse_iso8601
Expand Down Expand Up @@ -48,6 +49,22 @@ def test_CFScaleOffset_coder_roundtrip(original) -> None:
xr.testing.assert_identical(original, roundtripped)


@given(
    real=st.floats(allow_nan=True, allow_infinity=True),
    imag=st.floats(allow_nan=True, allow_infinity=True),
    dtype=st.sampled_from([np.complex64, np.complex128]),
)
def test_FillValueCoder_complex_roundtrip(real, imag, dtype) -> None:
    """Property test: encoding then decoding a complex fill value is lossless.

    The real and imaginary parts are drawn independently from all floats
    (NaN and infinities included) and the scalar is built with either
    ``np.complex64`` or ``np.complex128``.
    """
    from xarray.backends.zarr import FillValueCoder

    np_dtype = np.dtype(dtype)
    original = dtype(complex(real, imag))

    restored = FillValueCoder.decode(
        FillValueCoder.encode(original, np_dtype), np_dtype
    )

    # Compare as arrays of the target dtype so that NaN components are
    # treated as equal by ``assert_equal``.
    expected = np.array(original, dtype=dtype)
    actual = np.array(restored, dtype=dtype)
    np.testing.assert_equal(actual, expected)


@given(dt=datetimes())
def test_iso8601_decode(dt):
iso = dt.isoformat()
Expand Down
61 changes: 55 additions & 6 deletions xarray/backends/zarr.py
Original file line number Diff line number Diff line change
Expand Up @@ -122,40 +122,89 @@ class FillValueCoder:
"""

@classmethod
def encode(cls, value: int | float | str | bytes, dtype: np.dtype[Any]) -> Any:
def encode(
cls, value: int | float | complex | str | bytes, dtype: np.dtype[Any]
) -> Any:
if dtype.kind == "S":
# byte string, this implies that 'value' must also be `bytes` dtype.
assert isinstance(value, bytes)
if not isinstance(value, bytes):
raise TypeError(
f"Failed to encode fill_value: expected bytes for dtype {dtype}, got {type(value).__name__}"
)
return base64.standard_b64encode(value).decode()
elif dtype.kind == "b":
# boolean
return bool(value)
elif dtype.kind in "iu":
# todo: do we want to check for decimals?
if not isinstance(value, int | float | np.integer | np.floating):
raise TypeError(
f"Failed to encode fill_value: expected int or float for dtype {dtype}, got {type(value).__name__}"
)
return int(value)
elif dtype.kind == "f":
if not isinstance(value, int | float | np.integer | np.floating):
raise TypeError(
f"Failed to encode fill_value: expected int or float for dtype {dtype}, got {type(value).__name__}"
)
return base64.standard_b64encode(struct.pack("<d", float(value))).decode()
elif dtype.kind == "c":
# complex - encode each component as base64, matching float encoding
if not isinstance(value, complex) and not np.issubdtype(
type(value), np.complexfloating
):
raise TypeError(
f"Failed to encode fill_value: expected complex for dtype {dtype}, got {type(value).__name__}"
)
return [
base64.standard_b64encode(
struct.pack("<d", float(value.real)) # type: ignore[union-attr]
).decode(),
base64.standard_b64encode(
struct.pack("<d", float(value.imag)) # type: ignore[union-attr]
).decode(),
]
elif dtype.kind == "U":
return str(value)
else:
raise ValueError(f"Failed to encode fill_value. Unsupported dtype {dtype}")

@classmethod
def decode(cls, value: int | float | str | bytes, dtype: str | np.dtype[Any]):
def decode(
cls, value: int | float | str | bytes | list, dtype: str | np.dtype[Any]
):
if dtype == "string":
# zarr V3 string type
return str(value)
elif dtype == "bytes":
# zarr V3 bytes type
assert isinstance(value, str | bytes)
if not isinstance(value, str | bytes):
raise TypeError(
f"Failed to decode fill_value: expected str or bytes for dtype {dtype}, got {type(value).__name__}"
)
return base64.standard_b64decode(value)
np_dtype = np.dtype(dtype)
if np_dtype.kind == "f":
assert isinstance(value, str | bytes)
if not isinstance(value, str | bytes):
raise TypeError(
f"Failed to decode fill_value: expected str or bytes for dtype {np_dtype}, got {type(value).__name__}"
)
return struct.unpack("<d", base64.standard_b64decode(value))[0]
elif np_dtype.kind == "c":
# complex - decode each component from base64, matching float decoding
if not (isinstance(value, list | tuple) and len(value) == 2):
raise TypeError(
f"Failed to decode fill_value: expected a 2-element list for dtype {np_dtype}, got {type(value).__name__}"
)
real = struct.unpack("<d", base64.standard_b64decode(value[0]))[0]
imag = struct.unpack("<d", base64.standard_b64decode(value[1]))[0]
return complex(real, imag)
elif np_dtype.kind == "b":
return bool(value)
elif np_dtype.kind in "iu":
if not isinstance(value, int | float | np.integer | np.floating):
raise TypeError(
f"Failed to decode fill_value: expected int or float for dtype {np_dtype}, got {type(value).__name__}"
)
return int(value)
else:
raise ValueError(f"Failed to decode fill_value. Unsupported dtype {dtype}")
Expand Down
35 changes: 35 additions & 0 deletions xarray/tests/test_backends.py
Original file line number Diff line number Diff line change
Expand Up @@ -7241,6 +7241,41 @@ def test_encode_zarr_attr_value() -> None:
assert actual3 == expected3


@requires_zarr
@pytest.mark.parametrize("dtype", [complex, np.complex64, np.complex128])
def test_fill_value_coder_complex(dtype) -> None:
    """Check that complex fill values survive an encode/decode round trip."""
    from xarray.backends.zarr import FillValueCoder

    np_dtype = np.dtype(dtype)
    # Finite values plus an all-NaN value, each built with the scalar type
    # under test.
    samples = (1 + 2j, -3.5 + 4.5j, complex("nan+nanj"))
    for value in map(dtype, samples):
        restored = FillValueCoder.decode(
            FillValueCoder.encode(value, np_dtype), np_dtype
        )
        np.testing.assert_equal(np.array(restored, dtype=dtype), np.array(value))


@requires_zarr
@pytest.mark.parametrize(
    "value,dtype",
    [
        (np.float32(np.inf), np.float32),
        (np.float32(-np.inf), np.float32),
        (np.float64(np.inf), np.float64),
        (np.float64(-np.inf), np.float64),
        (np.float32(np.nan), np.float32),
        (np.float64(np.nan), np.float64),
    ],
)
def test_fill_value_coder_inf_nan(value, dtype) -> None:
    """Check that +/-inf and NaN float fill values survive a round trip."""
    from xarray.backends.zarr import FillValueCoder

    np_dtype = np.dtype(dtype)
    restored = FillValueCoder.decode(
        FillValueCoder.encode(value, np_dtype), np_dtype
    )

    # Array comparison so that NaN is considered equal to NaN.
    expected = np.array(value, dtype=dtype)
    actual = np.array(restored, dtype=dtype)
    np.testing.assert_equal(actual, expected)


@requires_zarr
def test_extract_zarr_variable_encoding() -> None:
var = xr.Variable("x", [1, 2])
Expand Down
Loading