diff --git a/src/cellmap_data/dataset_writer.py b/src/cellmap_data/dataset_writer.py index 6622eb7..24f8c75 100644 --- a/src/cellmap_data/dataset_writer.py +++ b/src/cellmap_data/dataset_writer.py @@ -14,6 +14,9 @@ logger = logging.getLogger(__name__) logger.setLevel(logging.INFO) +# Special keys that should not be written to disk +_METADATA_KEYS = {"idx"} + # %% class CellMapDatasetWriter(Dataset): @@ -352,13 +355,34 @@ def __setitem__( if isinstance(idx, (torch.Tensor, np.ndarray, Sequence)): if isinstance(idx, torch.Tensor): idx = idx.cpu().numpy() - for i in idx: - self.__setitem__(i, arrays) + for batch_idx, i in enumerate(idx): + # Extract the data for this specific item in the batch + item_arrays = {} + for array_name, array in arrays.items(): + # Skip special metadata keys + if array_name in _METADATA_KEYS: + continue + if isinstance(array, (int, float)): + # Scalar values are the same for all items + item_arrays[array_name] = array + elif isinstance(array, dict): + # Dictionary of arrays - extract batch item from each + item_arrays[array_name] = { + label: label_array[batch_idx] + for label, label_array in array.items() + } + else: + # Regular array - extract the batch item + item_arrays[array_name] = array[batch_idx] + self.__setitem__(i, item_arrays) return self._current_idx = idx self._current_center = self.get_center(self._current_idx) for array_name, array in arrays.items(): + # Skip special metadata keys + if array_name in _METADATA_KEYS: + continue if isinstance(array, (int, float)): for label in self.classes: self.target_array_writers[array_name][label][ @@ -373,7 +397,7 @@ def __setitem__( for c, label in enumerate(self.classes): self.target_array_writers[array_name][label][ self._current_center - ] = array[:, c, ...] + ] = array[c, ...] def __repr__(self) -> str: """Returns a string representation of the dataset.""" diff --git a/src/cellmap_data/subdataset.py b/src/cellmap_data/subdataset.py index dba384d..c2eaf80 100644 --- a/src/cellmap_data/subdataset.py +++ b/src/cellmap_data/subdataset.py @@ -4,9 +4,9 @@ import torch from torch.utils.data import Subset -from .dataset_writer import CellMapDatasetWriter from .base_dataset import CellMapBaseDataset from .dataset import CellMapDataset +from .dataset_writer import CellMapDatasetWriter from .multidataset import CellMapMultiDataset from .mutable_sampler import MutableSubsetRandomSampler from .utils.sampling import min_redundant_inds diff --git a/tests/test_dataset_writer_batch.py b/tests/test_dataset_writer_batch.py new file mode 100644 index 0000000..fc36e44 --- /dev/null +++ b/tests/test_dataset_writer_batch.py @@ -0,0 +1,209 @@ +""" +Tests for CellMapDatasetWriter batch operations. + +Tests that the writer correctly handles batched write operations. +""" + +import numpy as np +import pytest +import torch + +from cellmap_data import CellMapDatasetWriter + +from .test_helpers import create_test_dataset + + +class TestDatasetWriterBatchOperations: + """Test suite for batch write operations in DatasetWriter.""" + + @pytest.fixture + def writer_setup(self, tmp_path): + """Create writer and config for batch write tests. + + Returns a tuple of (writer, config) where writer is a CellMapDatasetWriter + configured for testing batch operations. + """ + # Create input data + config = create_test_dataset( + tmp_path / "input", + raw_shape=(64, 64, 64), + num_classes=2, + raw_scale=(8.0, 8.0, 8.0), + ) + + # Output path + output_path = tmp_path / "output" / "predictions.zarr" + + target_bounds = { + "pred": { + "x": [0, 512], + "y": [0, 512], + "z": [0, 512], + } + } + + writer = CellMapDatasetWriter( + raw_path=config["raw_path"], + target_path=str(output_path), + classes=["class_0", "class_1"], + input_arrays={"raw": {"shape": (32, 32, 32), "scale": (8.0, 8.0, 8.0)}}, + target_arrays={"pred": {"shape": (32, 32, 32), "scale": (8.0, 8.0, 8.0)}}, + target_bounds=target_bounds, + ) + + return writer, config + + def test_batch_write_with_tensor_indices(self, writer_setup): + """Test writing with a batch of tensor indices.""" + writer, config = writer_setup + + # Simulate batch predictions + batch_size = 8 + indices = torch.tensor(list(range(batch_size))) + + # Create predictions with shape (batch_size, num_classes, *spatial_dims) + predictions = torch.randn(batch_size, 2, 32, 32, 32) + + # This should not raise an error + writer[indices] = {"pred": predictions} + + def test_batch_write_with_numpy_indices(self, writer_setup): + """Test writing with a batch of numpy indices.""" + writer, config = writer_setup + + # Simulate batch predictions + batch_size = 4 + indices = np.array(list(range(batch_size))) + + # Create predictions + predictions = np.random.randn(batch_size, 2, 32, 32, 32).astype(np.float32) + + # This should not raise an error + writer[indices] = {"pred": predictions} + + def test_batch_write_with_list_indices(self, writer_setup): + """Test writing with a batch of list indices.""" + writer, config = writer_setup + + # Simulate batch predictions + batch_size = 4 + indices = [0, 1, 2, 3] + + # Create predictions + predictions = torch.randn(batch_size, 2, 32, 32, 32) + + # This should not raise an error + writer[indices] = {"pred": predictions} + + def test_batch_write_large_batch(self, writer_setup): + """Test writing with a large batch size (as in the error case).""" + writer, config = writer_setup + + # Simulate the error case: batch_size=32 + batch_size = 32 + indices = torch.tensor(list(range(batch_size))) + + # Create predictions with shape (32, 2, 32, 32, 32) + predictions = torch.randn(batch_size, 2, 32, 32, 32) + + # This should not raise ValueError about shape mismatch + writer[indices] = {"pred": predictions} + + def test_batch_write_with_dict_arrays(self, writer_setup): + """Test writing with dictionary of arrays per class.""" + writer, config = writer_setup + + batch_size = 4 + indices = torch.tensor(list(range(batch_size))) + + # Create predictions as dictionary + predictions_dict = { + "class_0": torch.randn(batch_size, 32, 32, 32), + "class_1": torch.randn(batch_size, 32, 32, 32), + } + + # This should not raise an error + writer[indices] = {"pred": predictions_dict} + + def test_batch_write_2d_data(self, tmp_path): + """Test batch writing for 2D data (3D with singleton z dimension).""" + # Import kept at module level; reuse create_test_dataset here + + # Create test dataset with thin Z dimension to simulate 2D + config = create_test_dataset( + tmp_path / "input", + raw_shape=(1, 128, 128), # Thin z dimension + num_classes=1, + raw_scale=(8.0, 4.0, 4.0), + ) + + output_path = tmp_path / "output_2d.zarr" + + target_bounds = { + "pred": { + "z": [0, 8], + "y": [0, 512], + "x": [0, 512], + } + } + + writer = CellMapDatasetWriter( + raw_path=config["raw_path"], + target_path=str(output_path), + classes=["class_0"], + input_arrays={"raw": {"shape": (1, 64, 64), "scale": (8.0, 4.0, 4.0)}}, + target_arrays={"pred": {"shape": (1, 64, 64), "scale": (8.0, 4.0, 4.0)}}, + axis_order="zyx", + target_bounds=target_bounds, + ) + + # Test batch write with thin-z 3D data + batch_size = 4 + indices = torch.tensor(list(range(batch_size))) + predictions = torch.randn(batch_size, 1, 1, 64, 64) + + # This should not raise an error + writer[indices] = {"pred": predictions} + + def test_single_item_write_still_works(self, writer_setup): + """Test that single item writes still work correctly.""" + writer, config = writer_setup + + # Single item write + idx = 0 + predictions = torch.randn(2, 32, 32, 32) + + # This should work as before + writer[idx] = {"pred": predictions} + + def test_batch_write_with_scalar_values(self, writer_setup): + """Test batch writing with scalar values fills all spatial dims.""" + writer, config = writer_setup + + batch_size = 4 + indices = torch.tensor(list(range(batch_size))) + + # Scalar values should be broadcast to full arrays + # Create proper shaped arrays filled with the scalar value + scalar_val = 0.5 + predictions = torch.full((batch_size, 2, 32, 32, 32), scalar_val) + writer[indices] = {"pred": predictions} + + def test_batch_write_mixed_data_types(self, writer_setup): + """Test batch writing preserves data types.""" + writer, config = writer_setup + + batch_size = 4 + indices = torch.tensor(list(range(batch_size))) + + # Test with different dtypes + predictions_float32 = torch.randn( + batch_size, 2, 32, 32, 32, dtype=torch.float32 + ) + writer[indices] = {"pred": predictions_float32} + + predictions_float64 = torch.randn( + batch_size, 2, 32, 32, 32, dtype=torch.float64 + ) + indices2 = torch.tensor(list(range(batch_size, batch_size * 2))) + writer[indices2] = {"pred": predictions_float64}