Skip to content

Commit 3096379

Browse files
Add edge case tests for CellMapImage properties and methods
Co-authored-by: rhoadesScholar <37990507+rhoadesScholar@users.noreply.github.com>
1 parent bba7431 commit 3096379

File tree

1 file changed

+378
-0
lines changed

1 file changed

+378
-0
lines changed

tests/test_image_edge_cases.py

Lines changed: 378 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,378 @@
1+
"""Tests for CellMapImage edge cases and special methods."""
2+
3+
import pytest
4+
import torch
5+
import numpy as np
6+
7+
from cellmap_data import CellMapImage
8+
9+
from .test_helpers import create_test_image_data, create_test_zarr_array
10+
11+
12+
class TestCellMapImageEdgeCases:
13+
"""Test edge cases and special methods in CellMapImage."""
14+
15+
@pytest.fixture
16+
def test_zarr_image(self, tmp_path):
17+
"""Create a test Zarr image."""
18+
data = create_test_image_data((32, 32, 32), pattern="gradient")
19+
path = tmp_path / "test_image.zarr"
20+
create_test_zarr_array(path, data, scale=(4.0, 4.0, 4.0))
21+
return str(path), data
22+
23+
def test_axis_order_longer_than_scale(self, test_zarr_image):
24+
"""Test handling when axis_order has more axes than target_scale."""
25+
path, _ = test_zarr_image
26+
27+
# Provide fewer scale values than axes
28+
image = CellMapImage(
29+
path=path,
30+
target_class="test_class",
31+
target_scale=(4.0, 4.0), # Only 2 values for 3 axes
32+
target_voxel_shape=(16, 16, 16),
33+
axis_order="zyx", # 3 axes
34+
)
35+
36+
# Should pad scale with first value
37+
assert len(image.scale) == 3
38+
assert image.scale["z"] == 4.0 # Padded value
39+
assert image.scale["y"] == 4.0
40+
assert image.scale["x"] == 4.0
41+
42+
def test_axis_order_longer_than_shape(self, test_zarr_image):
43+
"""Test handling when axis_order has more axes than target_voxel_shape."""
44+
path, _ = test_zarr_image
45+
46+
# Provide fewer shape values than axes
47+
image = CellMapImage(
48+
path=path,
49+
target_class="test_class",
50+
target_scale=(4.0, 4.0, 4.0),
51+
target_voxel_shape=(16, 16), # Only 2 values for 3 axes
52+
axis_order="zyx", # 3 axes
53+
)
54+
55+
# Should pad shape with 1s
56+
assert len(image.output_shape) == 3
57+
assert image.output_shape["z"] == 1 # Padded value
58+
assert image.output_shape["y"] == 16
59+
assert image.output_shape["x"] == 16
60+
61+
def test_device_auto_selection_cuda(self, test_zarr_image):
62+
"""Test device auto-selection when no device specified."""
63+
path, _ = test_zarr_image
64+
65+
image = CellMapImage(
66+
path=path,
67+
target_class="test_class",
68+
target_scale=(4.0, 4.0, 4.0),
69+
target_voxel_shape=(16, 16, 16),
70+
)
71+
72+
# Should select an appropriate device
73+
assert image.device in ["cuda", "mps", "cpu"]
74+
75+
def test_explicit_device_selection(self, test_zarr_image):
76+
"""Test explicit device selection."""
77+
path, _ = test_zarr_image
78+
79+
image = CellMapImage(
80+
path=path,
81+
target_class="test_class",
82+
target_scale=(4.0, 4.0, 4.0),
83+
target_voxel_shape=(16, 16, 16),
84+
device="cpu",
85+
)
86+
87+
assert image.device == "cpu"
88+
89+
def test_to_device_method(self, test_zarr_image):
90+
"""Test moving image to different device."""
91+
path, _ = test_zarr_image
92+
93+
image = CellMapImage(
94+
path=path,
95+
target_class="test_class",
96+
target_scale=(4.0, 4.0, 4.0),
97+
target_voxel_shape=(16, 16, 16),
98+
)
99+
100+
# Move to CPU
101+
image.to("cpu")
102+
assert image.device == "cpu"
103+
104+
def test_set_spatial_transforms_none(self, test_zarr_image):
105+
"""Test setting spatial transforms to None."""
106+
path, _ = test_zarr_image
107+
108+
image = CellMapImage(
109+
path=path,
110+
target_class="test_class",
111+
target_scale=(4.0, 4.0, 4.0),
112+
target_voxel_shape=(16, 16, 16),
113+
)
114+
115+
# Set to None
116+
image.set_spatial_transforms(None)
117+
assert image._current_spatial_transforms is None
118+
119+
def test_set_spatial_transforms_with_values(self, test_zarr_image):
120+
"""Test setting spatial transforms with actual transform dict."""
121+
path, _ = test_zarr_image
122+
123+
image = CellMapImage(
124+
path=path,
125+
target_class="test_class",
126+
target_scale=(4.0, 4.0, 4.0),
127+
target_voxel_shape=(16, 16, 16),
128+
)
129+
130+
# Set transforms
131+
transforms = {"mirror": {"axes": {"x": 0.5}}}
132+
image.set_spatial_transforms(transforms)
133+
assert image._current_spatial_transforms == transforms
134+
135+
def test_bounding_box_property(self, test_zarr_image):
136+
"""Test the bounding_box property."""
137+
path, _ = test_zarr_image
138+
139+
image = CellMapImage(
140+
path=path,
141+
target_class="test_class",
142+
target_scale=(4.0, 4.0, 4.0),
143+
target_voxel_shape=(16, 16, 16),
144+
)
145+
146+
bbox = image.bounding_box
147+
148+
# Should be a dict with axis keys
149+
assert isinstance(bbox, dict)
150+
for axis in "zyx":
151+
assert axis in bbox
152+
assert len(bbox[axis]) == 2
153+
assert bbox[axis][0] <= bbox[axis][1]
154+
155+
def test_sampling_box_property(self, test_zarr_image):
156+
"""Test the sampling_box property."""
157+
path, _ = test_zarr_image
158+
159+
image = CellMapImage(
160+
path=path,
161+
target_class="test_class",
162+
target_scale=(4.0, 4.0, 4.0),
163+
target_voxel_shape=(16, 16, 16),
164+
)
165+
166+
sbox = image.sampling_box
167+
168+
# Should be a dict with axis keys
169+
assert isinstance(sbox, dict)
170+
for axis in "zyx":
171+
assert axis in sbox
172+
assert len(sbox[axis]) == 2
173+
174+
def test_class_counts_property(self, test_zarr_image):
175+
"""Test the class_counts property."""
176+
path, _ = test_zarr_image
177+
178+
image = CellMapImage(
179+
path=path,
180+
target_class="test_class",
181+
target_scale=(4.0, 4.0, 4.0),
182+
target_voxel_shape=(16, 16, 16),
183+
)
184+
185+
counts = image.class_counts
186+
187+
# Should be a numeric value or dict
188+
assert isinstance(counts, (int, float, dict))
189+
190+
def test_pad_parameter_true(self, test_zarr_image):
191+
"""Test padding when pad=True."""
192+
path, _ = test_zarr_image
193+
194+
image = CellMapImage(
195+
path=path,
196+
target_class="test_class",
197+
target_scale=(4.0, 4.0, 4.0),
198+
target_voxel_shape=(16, 16, 16),
199+
pad=True,
200+
pad_value=0,
201+
)
202+
203+
assert image.pad is True
204+
assert image.pad_value == 0
205+
206+
def test_pad_parameter_false(self, test_zarr_image):
207+
"""Test when pad=False."""
208+
path, _ = test_zarr_image
209+
210+
image = CellMapImage(
211+
path=path,
212+
target_class="test_class",
213+
target_scale=(4.0, 4.0, 4.0),
214+
target_voxel_shape=(16, 16, 16),
215+
pad=False,
216+
)
217+
218+
assert image.pad is False
219+
220+
def test_interpolation_nearest(self, test_zarr_image):
221+
"""Test interpolation mode nearest."""
222+
path, _ = test_zarr_image
223+
224+
image = CellMapImage(
225+
path=path,
226+
target_class="test_class",
227+
target_scale=(4.0, 4.0, 4.0),
228+
target_voxel_shape=(16, 16, 16),
229+
interpolation="nearest",
230+
)
231+
232+
assert image.interpolation == "nearest"
233+
234+
def test_interpolation_linear(self, test_zarr_image):
235+
"""Test interpolation mode linear."""
236+
path, _ = test_zarr_image
237+
238+
image = CellMapImage(
239+
path=path,
240+
target_class="test_class",
241+
target_scale=(4.0, 4.0, 4.0),
242+
target_voxel_shape=(16, 16, 16),
243+
interpolation="linear",
244+
)
245+
246+
assert image.interpolation == "linear"
247+
248+
def test_value_transform_none(self, test_zarr_image):
249+
"""Test when no value transform is provided."""
250+
path, _ = test_zarr_image
251+
252+
image = CellMapImage(
253+
path=path,
254+
target_class="test_class",
255+
target_scale=(4.0, 4.0, 4.0),
256+
target_voxel_shape=(16, 16, 16),
257+
value_transform=None,
258+
)
259+
260+
assert image.value_transform is None
261+
262+
def test_value_transform_provided(self, test_zarr_image):
263+
"""Test when value transform is provided."""
264+
path, _ = test_zarr_image
265+
266+
transform = lambda x: x * 2
267+
image = CellMapImage(
268+
path=path,
269+
target_class="test_class",
270+
target_scale=(4.0, 4.0, 4.0),
271+
target_voxel_shape=(16, 16, 16),
272+
value_transform=transform,
273+
)
274+
275+
assert image.value_transform is transform
276+
277+
def test_output_size_calculation(self, test_zarr_image):
278+
"""Test that output_size is correctly calculated."""
279+
path, _ = test_zarr_image
280+
281+
image = CellMapImage(
282+
path=path,
283+
target_class="test_class",
284+
target_scale=(4.0, 8.0, 2.0),
285+
target_voxel_shape=(10, 20, 30),
286+
axis_order="zyx",
287+
)
288+
289+
# output_size = voxel_shape * scale
290+
assert image.output_size["z"] == 10 * 4.0
291+
assert image.output_size["y"] == 20 * 8.0
292+
assert image.output_size["x"] == 30 * 2.0
293+
294+
def test_axes_property(self, test_zarr_image):
295+
"""Test that axes property is correctly set."""
296+
path, _ = test_zarr_image
297+
298+
image = CellMapImage(
299+
path=path,
300+
target_class="test_class",
301+
target_scale=(4.0, 4.0, 4.0),
302+
target_voxel_shape=(16, 16, 16),
303+
axis_order="zyx",
304+
)
305+
306+
assert image.axes == "zyx"
307+
308+
def test_context_parameter_none(self, test_zarr_image):
309+
"""Test when no context is provided."""
310+
path, _ = test_zarr_image
311+
312+
image = CellMapImage(
313+
path=path,
314+
target_class="test_class",
315+
target_scale=(4.0, 4.0, 4.0),
316+
target_voxel_shape=(16, 16, 16),
317+
context=None,
318+
)
319+
320+
assert image.context is None
321+
322+
def test_path_attribute(self, test_zarr_image):
323+
"""Test that path attribute is correctly set."""
324+
path, _ = test_zarr_image
325+
326+
image = CellMapImage(
327+
path=path,
328+
target_class="test_class",
329+
target_scale=(4.0, 4.0, 4.0),
330+
target_voxel_shape=(16, 16, 16),
331+
)
332+
333+
assert image.path == path
334+
335+
def test_label_class_attribute(self, test_zarr_image):
336+
"""Test that label_class attribute is correctly set."""
337+
path, _ = test_zarr_image
338+
339+
image = CellMapImage(
340+
path=path,
341+
target_class="my_class",
342+
target_scale=(4.0, 4.0, 4.0),
343+
target_voxel_shape=(16, 16, 16),
344+
)
345+
346+
assert image.label_class == "my_class"
347+
348+
def test_getitem_returns_tensor(self, test_zarr_image):
349+
"""Test that __getitem__ returns a PyTorch tensor."""
350+
path, _ = test_zarr_image
351+
352+
image = CellMapImage(
353+
path=path,
354+
target_class="test_class",
355+
target_scale=(4.0, 4.0, 4.0),
356+
target_voxel_shape=(8, 8, 8),
357+
)
358+
359+
center = {"z": 64.0, "y": 64.0, "x": 64.0}
360+
result = image[center]
361+
362+
assert isinstance(result, torch.Tensor)
363+
assert result.ndim >= 3
364+
365+
def test_nan_pad_value(self, test_zarr_image):
366+
"""Test using NaN as pad value."""
367+
path, _ = test_zarr_image
368+
369+
image = CellMapImage(
370+
path=path,
371+
target_class="test_class",
372+
target_scale=(4.0, 4.0, 4.0),
373+
target_voxel_shape=(16, 16, 16),
374+
pad=True,
375+
pad_value=np.nan,
376+
)
377+
378+
assert np.isnan(image.pad_value)

0 commit comments

Comments
 (0)