11"""Tests for base abstract classes."""
22
3+ from abc import ABC
4+
35import pytest
46import torch
5- from abc import ABC
67
78from cellmap_data .base_dataset import CellMapBaseDataset
89from cellmap_data .base_image import CellMapImageBase
@@ -18,57 +19,57 @@ def test_cannot_instantiate_abstract_class(self):
1819
1920 def test_incomplete_implementation_raises_error (self ):
2021 """Test that incomplete implementations cannot be instantiated."""
21-
22+
2223 # Missing all abstract methods
2324 class IncompleteDataset (CellMapBaseDataset ):
2425 pass
25-
26+
2627 with pytest .raises (TypeError , match = "Can't instantiate abstract class" ):
2728 IncompleteDataset ()
28-
29+
2930 # Missing some abstract methods
3031 class PartialDataset (CellMapBaseDataset ):
3132 @property
3233 def class_counts (self ):
3334 return {}
34-
35+
3536 @property
3637 def class_weights (self ):
3738 return {}
38-
39+
3940 with pytest .raises (TypeError , match = "Can't instantiate abstract class" ):
4041 PartialDataset ()
4142
4243 def test_complete_implementation_can_be_instantiated (self ):
4344 """Test that complete implementations can be instantiated."""
44-
45+
4546 class CompleteDataset (CellMapBaseDataset ):
4647 def __init__ (self ):
4748 self .classes = ["class1" , "class2" ]
4849 self .input_arrays = {"raw" : {}}
4950 self .target_arrays = {"labels" : {}}
50-
51+
5152 @property
5253 def class_counts (self ):
5354 return {"class1" : 100.0 , "class2" : 200.0 }
54-
55+
5556 @property
5657 def class_weights (self ):
5758 return {"class1" : 0.67 , "class2" : 0.33 }
58-
59+
5960 @property
6061 def validation_indices (self ):
6162 return [0 , 1 , 2 ]
62-
63+
6364 def to (self , device , non_blocking = True ):
6465 return self
65-
66+
6667 def set_raw_value_transforms (self , transforms ):
6768 pass
68-
69+
6970 def set_target_value_transforms (self , transforms ):
7071 pass
71-
72+
7273 # Should not raise
7374 dataset = CompleteDataset ()
7475 assert isinstance (dataset , CellMapBaseDataset )
@@ -83,11 +84,11 @@ def set_target_value_transforms(self, transforms):
8384 def test_attributes_are_defined (self ):
8485 """Test that expected attributes are defined in the base class."""
8586 # Check type annotations exist
86- assert hasattr (CellMapBaseDataset , ' __annotations__' )
87+ assert hasattr (CellMapBaseDataset , " __annotations__" )
8788 annotations = CellMapBaseDataset .__annotations__
88- assert ' classes' in annotations
89- assert ' input_arrays' in annotations
90- assert ' target_arrays' in annotations
89+ assert " classes" in annotations
90+ assert " input_arrays" in annotations
91+ assert " target_arrays" in annotations
9192
9293
9394class TestCellMapImageBase :
@@ -100,52 +101,52 @@ def test_cannot_instantiate_abstract_class(self):
100101
101102 def test_incomplete_implementation_raises_error (self ):
102103 """Test that incomplete implementations cannot be instantiated."""
103-
104+
104105 # Missing all abstract methods
105106 class IncompleteImage (CellMapImageBase ):
106107 pass
107-
108+
108109 with pytest .raises (TypeError , match = "Can't instantiate abstract class" ):
109110 IncompleteImage ()
110-
111+
111112 # Missing some abstract methods
112113 class PartialImage (CellMapImageBase ):
113114 @property
114115 def bounding_box (self ):
115116 return {"x" : (0 , 100 ), "y" : (0 , 100 )}
116-
117+
117118 @property
118119 def sampling_box (self ):
119120 return {"x" : (10 , 90 ), "y" : (10 , 90 )}
120-
121+
121122 with pytest .raises (TypeError , match = "Can't instantiate abstract class" ):
122123 PartialImage ()
123124
124125 def test_complete_implementation_can_be_instantiated (self ):
125126 """Test that complete implementations can be instantiated."""
126-
127+
127128 class CompleteImage (CellMapImageBase ):
128129 def __getitem__ (self , center ):
129130 return torch .zeros ((1 , 64 , 64 ))
130-
131+
131132 @property
132133 def bounding_box (self ):
133134 return {"x" : (0.0 , 100.0 ), "y" : (0.0 , 100.0 )}
134-
135+
135136 @property
136137 def sampling_box (self ):
137138 return {"x" : (10.0 , 90.0 ), "y" : (10.0 , 90.0 )}
138-
139+
139140 @property
140141 def class_counts (self ):
141142 return 1000.0
142-
143+
143144 def to (self , device , non_blocking = True ):
144145 pass
145-
146+
146147 def set_spatial_transforms (self , transforms ):
147148 pass
148-
149+
149150 # Should not raise
150151 image = CompleteImage ()
151152 assert isinstance (image , CellMapImageBase )
@@ -161,59 +162,59 @@ def set_spatial_transforms(self, transforms):
161162
162163 def test_class_counts_supports_dict_return_type (self ):
163164 """Test that class_counts can return a dictionary."""
164-
165+
165166 class MultiClassImage (CellMapImageBase ):
166167 def __getitem__ (self , center ):
167168 return torch .zeros ((1 , 64 , 64 ))
168-
169+
169170 @property
170171 def bounding_box (self ):
171172 return {"x" : (0.0 , 100.0 )}
172-
173+
173174 @property
174175 def sampling_box (self ):
175176 return {"x" : (10.0 , 90.0 )}
176-
177+
177178 @property
178179 def class_counts (self ):
179180 return {"class1" : 500.0 , "class2" : 300.0 , "class3" : 200.0 }
180-
181+
181182 def to (self , device , non_blocking = True ):
182183 pass
183-
184+
184185 def set_spatial_transforms (self , transforms ):
185186 pass
186-
187+
187188 image = MultiClassImage ()
188189 counts = image .class_counts
189190 assert isinstance (counts , dict )
190191 assert counts == {"class1" : 500.0 , "class2" : 300.0 , "class3" : 200.0 }
191192
192193 def test_bounding_box_can_be_none (self ):
193194 """Test that bounding_box property can return None."""
194-
195+
195196 class UnboundedImage (CellMapImageBase ):
196197 def __getitem__ (self , center ):
197198 return torch .zeros ((1 , 64 , 64 ))
198-
199+
199200 @property
200201 def bounding_box (self ):
201202 return None
202-
203+
203204 @property
204205 def sampling_box (self ):
205206 return None
206-
207+
207208 @property
208209 def class_counts (self ):
209210 return 1000.0
210-
211+
211212 def to (self , device , non_blocking = True ):
212213 pass
213-
214+
214215 def set_spatial_transforms (self , transforms ):
215216 pass
216-
217+
217218 image = UnboundedImage ()
218219 assert image .bounding_box is None
219220 assert image .sampling_box is None
0 commit comments