Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions changes/3705.bugfix.md
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
Fix a performance bug in morton curve generation.
19 changes: 12 additions & 7 deletions src/zarr/core/indexing.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@
from collections.abc import Iterator, Sequence
from dataclasses import dataclass
from enum import Enum
from functools import reduce
from functools import lru_cache, reduce
from types import EllipsisType
from typing import (
TYPE_CHECKING,
Expand Down Expand Up @@ -1467,16 +1467,21 @@ def decode_morton(z: int, chunk_shape: tuple[int, ...]) -> tuple[int, ...]:
return tuple(out)


def morton_order_iter(chunk_shape: tuple[int, ...]) -> Iterator[tuple[int, ...]]:
i = 0
@lru_cache
def _morton_order(chunk_shape: tuple[int, ...]) -> tuple[tuple[int, ...], ...]:
n_total = product(chunk_shape)
order: list[tuple[int, ...]] = []
while len(order) < product(chunk_shape):
i = 0
while len(order) < n_total:
m = decode_morton(i, chunk_shape)
if m not in order and all(x < y for x, y in zip(m, chunk_shape, strict=False)):
if all(x < y for x, y in zip(m, chunk_shape, strict=False)):
order.append(m)
i += 1
for j in range(product(chunk_shape)):
yield order[j]
return tuple(order)


def morton_order_iter(chunk_shape: tuple[int, ...]) -> Iterator[tuple[int, ...]]:
return iter(_morton_order(tuple(chunk_shape)))


def c_order_iter(chunks_per_shard: tuple[int, ...]) -> Iterator[tuple[int, ...]]:
Expand Down
66 changes: 52 additions & 14 deletions tests/test_codecs/test_codecs.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@
TransposeCodec,
)
from zarr.core.buffer import default_buffer_prototype
from zarr.core.indexing import BasicSelection, morton_order_iter
from zarr.core.indexing import BasicSelection, decode_morton, morton_order_iter
from zarr.core.metadata.v3 import ArrayV3Metadata
from zarr.dtype import UInt8
from zarr.errors import ZarrUserWarning
Expand Down Expand Up @@ -171,7 +171,8 @@ def test_open(store: Store) -> None:
assert a.metadata == b.metadata


def test_morton() -> None:
def test_morton_exact_order() -> None:
"""Test exact morton ordering for power-of-2 shapes."""
assert list(morton_order_iter((2, 2))) == [(0, 0), (1, 0), (0, 1), (1, 1)]
assert list(morton_order_iter((2, 2, 2))) == [
(0, 0, 0),
Expand Down Expand Up @@ -206,21 +207,58 @@ def test_morton() -> None:
@pytest.mark.parametrize(
"shape",
[
[2, 2, 2],
[5, 2],
[2, 5],
[2, 9, 2],
[3, 2, 12],
[2, 5, 1],
[4, 3, 6, 2, 7],
[3, 2, 1, 6, 4, 5, 2],
(2, 2, 2),
(5, 2),
(2, 5),
(2, 9, 2),
(3, 2, 12),
(2, 5, 1),
(4, 3, 6, 2, 7),
(3, 2, 1, 6, 4, 5, 2),
(1,),
(1, 1),
(5, 1, 3),
(1, 4, 1, 2),
],
)
def test_morton2(shape: tuple[int, ...]) -> None:
def test_morton_is_permutation(shape: tuple[int, ...]) -> None:
"""Test that morton_order_iter produces every valid coordinate exactly once."""
import itertools

from zarr.core.common import product

order = list(morton_order_iter(shape))
expected_len = product(shape)
# completeness: every valid coordinate is present
assert len(order) == expected_len
# no duplicates
assert len(set(order)) == expected_len
# all coordinates are within bounds
assert all(all(c < s for c, s in zip(coord, shape, strict=True)) for coord in order)
# the set of coordinates equals the full cartesian product
assert set(order) == set(itertools.product(*(range(s) for s in shape)))


@pytest.mark.parametrize(
"shape",
[
(2, 2),
(4, 4),
(2, 2, 2),
(4, 4, 4),
(2, 2, 2, 2),
],
)
def test_morton_ordering(shape: tuple[int, ...]) -> None:
"""Test that the iteration order matches consecutive decode_morton outputs.

For power-of-2 shapes, every decode_morton output is in-bounds,
so the ordering should be exactly decode_morton(0), decode_morton(1), ...
"""

order = list(morton_order_iter(shape))
for i, x in enumerate(order):
assert x not in order[:i] # no duplicates
assert all(x[j] < shape[j] for j in range(len(shape))) # all indices are within bounds
for i, coord in enumerate(order):
assert coord == decode_morton(i, shape)


@pytest.mark.parametrize("store", ["local", "memory"], indirect=["store"])
Expand Down