Skip to content

Commit ab15ad7

Browse files
Merge pull request #52 from janelia-cellmap/copilot/sub-pr-51
Add comprehensive test coverage for base classes and refactored components
2 parents c79601c + 3096379 commit ab15ad7

File tree

4 files changed

+1193
-0
lines changed

4 files changed

+1193
-0
lines changed

tests/test_base_classes.py

Lines changed: 219 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,219 @@
1+
"""Tests for base abstract classes."""
2+
3+
import pytest
4+
import torch
5+
from abc import ABC
6+
7+
from cellmap_data.base_dataset import CellMapBaseDataset
8+
from cellmap_data.base_image import CellMapImageBase
9+
10+
11+
class TestCellMapBaseDataset:
12+
"""Test the CellMapBaseDataset abstract base class."""
13+
14+
def test_cannot_instantiate_abstract_class(self):
15+
"""Test that CellMapBaseDataset cannot be instantiated directly."""
16+
with pytest.raises(TypeError, match="Can't instantiate abstract class"):
17+
CellMapBaseDataset()
18+
19+
def test_incomplete_implementation_raises_error(self):
20+
"""Test that incomplete implementations cannot be instantiated."""
21+
22+
# Missing all abstract methods
23+
class IncompleteDataset(CellMapBaseDataset):
24+
pass
25+
26+
with pytest.raises(TypeError, match="Can't instantiate abstract class"):
27+
IncompleteDataset()
28+
29+
# Missing some abstract methods
30+
class PartialDataset(CellMapBaseDataset):
31+
@property
32+
def class_counts(self):
33+
return {}
34+
35+
@property
36+
def class_weights(self):
37+
return {}
38+
39+
with pytest.raises(TypeError, match="Can't instantiate abstract class"):
40+
PartialDataset()
41+
42+
def test_complete_implementation_can_be_instantiated(self):
43+
"""Test that complete implementations can be instantiated."""
44+
45+
class CompleteDataset(CellMapBaseDataset):
46+
def __init__(self):
47+
self.classes = ["class1", "class2"]
48+
self.input_arrays = {"raw": {}}
49+
self.target_arrays = {"labels": {}}
50+
51+
@property
52+
def class_counts(self):
53+
return {"class1": 100.0, "class2": 200.0}
54+
55+
@property
56+
def class_weights(self):
57+
return {"class1": 0.67, "class2": 0.33}
58+
59+
@property
60+
def validation_indices(self):
61+
return [0, 1, 2]
62+
63+
def to(self, device, non_blocking=True):
64+
return self
65+
66+
def set_raw_value_transforms(self, transforms):
67+
pass
68+
69+
def set_target_value_transforms(self, transforms):
70+
pass
71+
72+
# Should not raise
73+
dataset = CompleteDataset()
74+
assert isinstance(dataset, CellMapBaseDataset)
75+
assert dataset.classes == ["class1", "class2"]
76+
assert dataset.class_counts == {"class1": 100.0, "class2": 200.0}
77+
assert dataset.class_weights == {"class1": 0.67, "class2": 0.33}
78+
assert dataset.validation_indices == [0, 1, 2]
79+
assert dataset.to("cpu") is dataset
80+
dataset.set_raw_value_transforms(lambda x: x)
81+
dataset.set_target_value_transforms(lambda x: x)
82+
83+
def test_attributes_are_defined(self):
84+
"""Test that expected attributes are defined in the base class."""
85+
# Check type annotations exist
86+
assert hasattr(CellMapBaseDataset, '__annotations__')
87+
annotations = CellMapBaseDataset.__annotations__
88+
assert 'classes' in annotations
89+
assert 'input_arrays' in annotations
90+
assert 'target_arrays' in annotations
91+
92+
93+
class TestCellMapImageBase:
94+
"""Test the CellMapImageBase abstract base class."""
95+
96+
def test_cannot_instantiate_abstract_class(self):
97+
"""Test that CellMapImageBase cannot be instantiated directly."""
98+
with pytest.raises(TypeError, match="Can't instantiate abstract class"):
99+
CellMapImageBase()
100+
101+
def test_incomplete_implementation_raises_error(self):
102+
"""Test that incomplete implementations cannot be instantiated."""
103+
104+
# Missing all abstract methods
105+
class IncompleteImage(CellMapImageBase):
106+
pass
107+
108+
with pytest.raises(TypeError, match="Can't instantiate abstract class"):
109+
IncompleteImage()
110+
111+
# Missing some abstract methods
112+
class PartialImage(CellMapImageBase):
113+
@property
114+
def bounding_box(self):
115+
return {"x": (0, 100), "y": (0, 100)}
116+
117+
@property
118+
def sampling_box(self):
119+
return {"x": (10, 90), "y": (10, 90)}
120+
121+
with pytest.raises(TypeError, match="Can't instantiate abstract class"):
122+
PartialImage()
123+
124+
def test_complete_implementation_can_be_instantiated(self):
125+
"""Test that complete implementations can be instantiated."""
126+
127+
class CompleteImage(CellMapImageBase):
128+
def __getitem__(self, center):
129+
return torch.zeros((1, 64, 64))
130+
131+
@property
132+
def bounding_box(self):
133+
return {"x": (0.0, 100.0), "y": (0.0, 100.0)}
134+
135+
@property
136+
def sampling_box(self):
137+
return {"x": (10.0, 90.0), "y": (10.0, 90.0)}
138+
139+
@property
140+
def class_counts(self):
141+
return 1000.0
142+
143+
def to(self, device, non_blocking=True):
144+
pass
145+
146+
def set_spatial_transforms(self, transforms):
147+
pass
148+
149+
# Should not raise
150+
image = CompleteImage()
151+
assert isinstance(image, CellMapImageBase)
152+
center = {"x": 50.0, "y": 50.0}
153+
result = image[center]
154+
assert isinstance(result, torch.Tensor)
155+
assert result.shape == (1, 64, 64)
156+
assert image.bounding_box == {"x": (0.0, 100.0), "y": (0.0, 100.0)}
157+
assert image.sampling_box == {"x": (10.0, 90.0), "y": (10.0, 90.0)}
158+
assert image.class_counts == 1000.0
159+
image.to("cpu")
160+
image.set_spatial_transforms(None)
161+
162+
def test_class_counts_supports_dict_return_type(self):
163+
"""Test that class_counts can return a dictionary."""
164+
165+
class MultiClassImage(CellMapImageBase):
166+
def __getitem__(self, center):
167+
return torch.zeros((1, 64, 64))
168+
169+
@property
170+
def bounding_box(self):
171+
return {"x": (0.0, 100.0)}
172+
173+
@property
174+
def sampling_box(self):
175+
return {"x": (10.0, 90.0)}
176+
177+
@property
178+
def class_counts(self):
179+
return {"class1": 500.0, "class2": 300.0, "class3": 200.0}
180+
181+
def to(self, device, non_blocking=True):
182+
pass
183+
184+
def set_spatial_transforms(self, transforms):
185+
pass
186+
187+
image = MultiClassImage()
188+
counts = image.class_counts
189+
assert isinstance(counts, dict)
190+
assert counts == {"class1": 500.0, "class2": 300.0, "class3": 200.0}
191+
192+
def test_bounding_box_can_be_none(self):
193+
"""Test that bounding_box property can return None."""
194+
195+
class UnboundedImage(CellMapImageBase):
196+
def __getitem__(self, center):
197+
return torch.zeros((1, 64, 64))
198+
199+
@property
200+
def bounding_box(self):
201+
return None
202+
203+
@property
204+
def sampling_box(self):
205+
return None
206+
207+
@property
208+
def class_counts(self):
209+
return 1000.0
210+
211+
def to(self, device, non_blocking=True):
212+
pass
213+
214+
def set_spatial_transforms(self, transforms):
215+
pass
216+
217+
image = UnboundedImage()
218+
assert image.bounding_box is None
219+
assert image.sampling_box is None

0 commit comments

Comments
 (0)