diff --git a/src/torchio/data/image.py b/src/torchio/data/image.py
index d8f74c32..438234c5 100644
--- a/src/torchio/data/image.py
+++ b/src/torchio/data/image.py
@@ -774,6 +774,73 @@ def get_center(self, lps: bool = False) -> TypeTripletFloat:
     def set_check_nans(self, check_nans: bool) -> None:
         self.check_nans = check_nans
 
+    def new_like(self, tensor: TypeData, affine: TypeData | None = None) -> Image:
+        """Create a new image of the same type with new tensor data.
+
+        This method creates a new image instance of the same class as the current
+        image, preserving essential attributes like type, check_nans, and reader.
+        This is particularly useful for transforms that need to create new images
+        while maintaining compatibility with custom Image subclasses.
+
+        Args:
+            tensor: 4D tensor with dimensions :math:`(C, W, H, D)` for the new image.
+            affine: :math:`4 \\times 4` matrix to convert voxel coordinates to world
+                coordinates. If ``None``, uses the current image's affine matrix.
+
+        Returns:
+            A new image instance of the same type as the current image.
+
+        Example:
+            >>> import torch
+            >>> import torchio as tio
+            >>> # Standard usage
+            >>> image = tio.ScalarImage('path/to/image.nii.gz')
+            >>> new_tensor = torch.rand(1, 64, 64, 64)
+            >>> new_image = image.new_like(tensor=new_tensor)
+            >>> isinstance(new_image, tio.ScalarImage)
+            True
+
+            >>> # Custom subclass usage
+            >>> class CustomImage(tio.ScalarImage):
+            ...     def __init__(self, tensor, affine, metadata, **kwargs):
+            ...         super().__init__(tensor=tensor, affine=affine, **kwargs)
+            ...         self.metadata = metadata
+            ...
+            ...     def new_like(self, tensor, affine=None):
+            ...         return type(self)(
+            ...             tensor=tensor,
+            ...             affine=affine if affine is not None else self.affine,
+            ...             metadata=self.metadata,  # Preserve custom attribute
+            ...             check_nans=self.check_nans,
+            ...             reader=self.reader,
+            ...         )
+            >>> custom = CustomImage(torch.rand(1, 32, 32, 32), torch.eye(4), {'id': 123})
+            >>> new_custom = custom.new_like(torch.rand(1, 16, 16, 16))
+            >>> new_custom.metadata['id']
+            123
+        """
+        if affine is None:
+            affine = self.affine
+
+        # First, try the standard constructor approach
+        try:
+            return type(self)(
+                tensor=tensor,
+                affine=affine,
+                type=self.type,
+                check_nans=self.check_nans,
+                reader=self.reader,
+            )
+        except TypeError:
+            # If the standard constructor fails (e.g., a custom subclass with
+            # additional required arguments), fall back to a copy-based approach
+            import copy
+
+            new_image = copy.deepcopy(self)
+            new_image.set_data(tensor)
+            new_image.affine = affine
+            return new_image
+
     def plot(self, **kwargs) -> None:
         """Plot image."""
         if self.is_2d():
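
For review context, a minimal sketch (not part of the patch) of how the `TypeError` fallback above is expected to behave for a subclass that adds a required constructor argument but does not override `new_like`. The `TaggedImage` class and its `tag` attribute are hypothetical:

```python
import torch
import torchio as tio

class TaggedImage(tio.ScalarImage):  # hypothetical subclass, no new_like override
    def __init__(self, tensor, affine, tag, **kwargs):
        super().__init__(tensor=tensor, affine=affine, **kwargs)
        self.tag = tag

image = TaggedImage(torch.rand(1, 8, 8, 8), torch.eye(4), tag='t1')
# type(self)(...) should raise TypeError (missing `tag`), so new_like
# falls back to deepcopy + set_data, preserving both type and attribute.
new = image.new_like(tensor=torch.zeros(1, 8, 8, 8))
assert isinstance(new, TaggedImage)
assert new.tag == 't1'
assert torch.equal(new.data, torch.zeros(1, 8, 8, 8))
```
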
diff --git a/src/torchio/transforms/preprocessing/spatial/crop.py b/src/torchio/transforms/preprocessing/spatial/crop.py
index 3142972f..36ef12e0 100644
--- a/src/torchio/transforms/preprocessing/spatial/crop.py
+++ b/src/torchio/transforms/preprocessing/spatial/crop.py
@@ -111,12 +111,12 @@ def _crop_image(
         if copy_patch:
             # Create a new image with the cropped data
             cropped_data = image.data[:, i0:i1, j0:j1, k0:k1].clone()
-            new_image = type(image)(
+            new_image = image.new_like(
                 tensor=cropped_data,
                 affine=new_affine,
-                type=image.type,
-                path=image.path,
             )
+            # Preserve path for the new image
+            new_image.path = image.path
             return new_image
         else:
             image.set_data(image.data[:, i0:i1, j0:j1, k0:k1].clone())
diff --git a/src/torchio/transforms/preprocessing/spatial/pad.py b/src/torchio/transforms/preprocessing/spatial/pad.py
index 70cf35d6..071ffb87 100644
--- a/src/torchio/transforms/preprocessing/spatial/pad.py
+++ b/src/torchio/transforms/preprocessing/spatial/pad.py
@@ -96,7 +96,12 @@ def _check_truncation(self, image: Image, mode: str | float) -> None:
     def apply_transform(self, subject: Subject) -> Subject:
         assert self.bounds_parameters is not None
         low = self.bounds_parameters[::2]
-        for image in self.get_images(subject):
+        images_dict = subject.get_images_dict(
+            intensity_only=False,
+            include=self.include,
+            exclude=self.exclude,
+        )
+        for image_name, image in images_dict.items():
             self._check_truncation(image, self.padding_mode)
             new_origin = apply_affine(image.affine, -np.array(low))
             new_affine = image.affine.copy()
@@ -126,8 +131,14 @@ def apply_transform(self, subject: Subject) -> Subject:
             pad_params = self.bounds_parameters
             paddings = (0, 0), pad_params[:2], pad_params[2:4], pad_params[4:]
             padded = np.pad(image.data, paddings, **kwargs)  # type: ignore[call-overload]
-            image.set_data(torch.as_tensor(padded))
-            image.affine = new_affine
+            new_image = image.new_like(
+                tensor=torch.as_tensor(padded), affine=new_affine
+            )
+            # Replace the image in the subject with the new padded image
+            subject[image_name] = new_image
+
+        # Update attributes to sync dictionary changes with attribute access
+        subject.update_attributes()
         return subject
 
     def inverse(self):
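
As a quick smoke test of the Crop/Pad changes, a sketch assuming the patched transforms (standard torchio API only; not part of the diff):

```python
import torch
import torchio as tio

subject = tio.Subject(t1=tio.ScalarImage(tensor=torch.rand(1, 8, 8, 8)))

padded = tio.Pad(2)(subject)   # 2 voxels on every side: 8 + 4 = 12
assert padded.t1.shape == (1, 12, 12, 12)

cropped = tio.Crop(2)(padded)  # a symmetric crop undoes the padding
assert cropped.t1.shape == (1, 8, 8, 8)
```

Note that with this patch, Pad replaces the image object in the subject dictionary rather than mutating it in place, which is why `update_attributes` is called afterwards.
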
diff --git a/src/torchio/transforms/preprocessing/spatial/resample.py b/src/torchio/transforms/preprocessing/spatial/resample.py
index bb8ef63c..85771049 100644
--- a/src/torchio/transforms/preprocessing/spatial/resample.py
+++ b/src/torchio/transforms/preprocessing/spatial/resample.py
@@ -180,7 +180,12 @@ def apply_transform(self, subject: Subject) -> Subject:
             assert self.pre_affine_name is not None  # for mypy
             self.check_affine_key_presence(self.pre_affine_name, subject)
 
-        for image in self.get_images(subject):
+        images_dict = subject.get_images_dict(
+            intensity_only=False,
+            include=self.include,
+            exclude=self.exclude,
+        )
+        for image_name, image in images_dict.items():
             # If the current image is the reference, don't resample it
             if self.target is image:
                 continue
@@ -233,8 +238,11 @@ def apply_transform(self, subject: Subject) -> Subject:
             resampled = resampler.Execute(floating_sitk)
 
             array, affine = sitk_to_nib(resampled)
-            image.set_data(torch.as_tensor(array))
-            image.affine = affine
+            new_image = image.new_like(tensor=torch.as_tensor(array), affine=affine)
+            subject[image_name] = new_image
+
+        # Update attributes to sync dictionary changes with attribute access
+        subject.update_attributes()
         return subject
 
     @staticmethod
diff --git a/src/torchio/transforms/preprocessing/spatial/to_orientation.py b/src/torchio/transforms/preprocessing/spatial/to_orientation.py
index d33e24df..fb198f5e 100644
--- a/src/torchio/transforms/preprocessing/spatial/to_orientation.py
+++ b/src/torchio/transforms/preprocessing/spatial/to_orientation.py
@@ -71,7 +71,8 @@ def __init__(self, orientation: str = 'RAS', **kwargs):
         self.args_names = ['orientation']
 
     def apply_transform(self, subject: Subject) -> Subject:
-        for image in subject.get_images(intensity_only=False):
+        images_dict = subject.get_images_dict(intensity_only=False)
+        for image_name, image in images_dict.items():
             current_orientation = ''.join(nib.orientations.aff2axcodes(image.affine))
 
             # If the image is already in the target orientation, skip it
@@ -104,7 +105,9 @@ def apply_transform(self, subject: Subject) -> Subject:
             # Update the image data and affine
             reoriented_array = np.ascontiguousarray(reoriented_array)
             tensor = torch.from_numpy(reoriented_array)
-            image.set_data(tensor)
-            image.affine = reoriented_affine
+            new_image = image.new_like(tensor=tensor, affine=reoriented_affine)
+            subject[image_name] = new_image
 
+        # Update attributes to sync dictionary changes with attribute access
+        subject.update_attributes()
         return subject
diff --git a/src/torchio/transforms/preprocessing/spatial/to_reference_space.py b/src/torchio/transforms/preprocessing/spatial/to_reference_space.py
index 3b5af1c3..1671e652 100644
--- a/src/torchio/transforms/preprocessing/spatial/to_reference_space.py
+++ b/src/torchio/transforms/preprocessing/spatial/to_reference_space.py
@@ -29,10 +29,17 @@ def __init__(self, reference: Image, **kwargs):
         self.reference = reference
 
     def apply_transform(self, subject: Subject) -> Subject:
-        for image in self.get_images(subject):
+        images_dict = subject.get_images_dict(
+            intensity_only=False,
+            include=self.include,
+            exclude=self.exclude,
+        )
+        for image_name, image in images_dict.items():
             new_image = build_image_from_reference(image.data, self.reference)
-            image.set_data(new_image.data)
-            image.affine = new_image.affine
+            subject[image_name] = new_image
+
+        # Update attributes to sync dictionary changes with attribute access
+        subject.update_attributes()
         return subject
 
     @staticmethod
@@ -49,6 +56,5 @@ def build_image_from_reference(tensor: torch.Tensor, reference: Image) -> Image:
         output_spacing = input_spacing * downsampling_factor
         downsample = Resample(output_spacing, image_interpolation='nearest')
         reference = downsample(reference)
-        class_ = reference.__class__
-        result = class_(tensor=tensor, affine=reference.affine)
+        result = reference.new_like(tensor=tensor, affine=reference.affine)
         return result
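
A sketch of what the Resample change buys for subclasses, assuming the `new_like` plumbing from this diff. The `MetaImage` class and its `meta` attribute are illustrative, not part of the patch:

```python
import torch
import torchio as tio

class MetaImage(tio.ScalarImage):  # illustrative subclass with a new_like override
    def __init__(self, tensor, affine, meta=None, **kwargs):
        super().__init__(tensor=tensor, affine=affine, **kwargs)
        self.meta = meta or {}

    def new_like(self, tensor, affine=None):
        return type(self)(
            tensor=tensor,
            affine=self.affine if affine is None else affine,
            meta=self.meta,
            check_nans=self.check_nans,
            reader=self.reader,
        )

subject = tio.Subject(img=MetaImage(torch.rand(1, 8, 8, 8), torch.eye(4), meta={'id': 1}))
resampled = tio.Resample(2)(subject)      # resample to 2 mm isotropic spacing
assert isinstance(resampled.img, MetaImage)  # type survives the replacement
assert resampled.img.meta == {'id': 1}       # custom attribute is carried over
```

Previously, `set_data` plus an affine assignment mutated the original object; routing through `new_like` lets subclasses control how their extra state is propagated.
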
diff --git a/src/torchio/transforms/preprocessing/spatial/transpose.py b/src/torchio/transforms/preprocessing/spatial/transpose.py
index a215ad22..e8d07984 100644
--- a/src/torchio/transforms/preprocessing/spatial/transpose.py
+++ b/src/torchio/transforms/preprocessing/spatial/transpose.py
@@ -22,13 +22,21 @@ class Transpose(SpatialTransform):
     """
 
     def apply_transform(self, subject: Subject) -> Subject:
-        for image in self.get_images(subject):
+        images_dict = subject.get_images_dict(
+            intensity_only=False,
+            include=self.include,
+            exclude=self.exclude,
+        )
+        for image_name, image in images_dict.items():
             old_orientation = image.orientation_str
             new_orientation = old_orientation[::-1]
             transform = ToOrientation(new_orientation)
             transposed = transform(image)
-            image.set_data(transposed.data)
-            image.affine = transposed.affine
+            new_image = image.new_like(tensor=transposed.data, affine=transposed.affine)
+            subject[image_name] = new_image
+
+        # Update attributes to sync dictionary changes with attribute access
+        subject.update_attributes()
         return subject
 
     def is_invertible(self):
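
A similar sketch for the orientation transforms, assuming `ToOrientation` is exported at the top level as in recent torchio releases:

```python
import torch
import torchio as tio

# An identity affine means the image starts out in RAS orientation
subject = tio.Subject(img=tio.ScalarImage(tensor=torch.rand(1, 8, 8, 8)))
reoriented = tio.ToOrientation('LPS')(subject)
assert reoriented.img.orientation == ('L', 'P', 'S')
```
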
diff --git a/tests/transforms/test_custom_image_subclass.py b/tests/transforms/test_custom_image_subclass.py
new file mode 100644
index 00000000..ac7198a1
--- /dev/null
+++ b/tests/transforms/test_custom_image_subclass.py
@@ -0,0 +1,220 @@
+"""Tests for custom Image subclasses with transforms."""
+
+import pytest
+import torch
+
+import torchio as tio
+
+
+class HistoryImage(tio.ScalarImage):
+    """Custom Image with a required extra parameter, for testing."""
+
+    def __init__(self, tensor, affine, history, **kwargs):
+        super().__init__(tensor=tensor, affine=affine, **kwargs)
+        self.history = history
+
+    def new_like(self, tensor, affine=None):
+        return type(self)(
+            tensor=tensor,
+            affine=affine if affine is not None else self.affine,
+            history=self.history,
+            check_nans=self.check_nans,
+            reader=self.reader,
+        )
+
+
+class MetadataImage(tio.ScalarImage):
+    """Custom Image with an optional extra parameter, for testing."""
+
+    def __init__(self, tensor, affine, metadata=None, **kwargs):
+        super().__init__(tensor=tensor, affine=affine, **kwargs)
+        self.metadata = metadata or {}
+
+    def new_like(self, tensor, affine=None):
+        return type(self)(
+            tensor=tensor,
+            affine=affine if affine is not None else self.affine,
+            metadata=self.metadata,
+            check_nans=self.check_nans,
+            reader=self.reader,
+        )
+
+
+class TestCustomImageSubclass:
+    """Test suite for custom Image subclasses with transforms."""
+
+    @pytest.fixture
+    def history_image(self):
+        """Create a HistoryImage for testing."""
+        tensor = torch.rand(1, 10, 10, 10)
+        affine = torch.eye(4)
+        return HistoryImage(tensor=tensor, affine=affine, history=['created'])
+
+    @pytest.fixture
+    def metadata_image(self):
+        """Create a MetadataImage for testing."""
+        tensor = torch.rand(1, 12, 12, 12)
+        affine = torch.eye(4)
+        return MetadataImage(
+            tensor=tensor, affine=affine, metadata={'id': 123, 'source': 'test'}
+        )
+
+    @pytest.fixture
+    def history_subject(self, history_image):
+        """Create a Subject with a HistoryImage."""
+        return tio.Subject(image=history_image)
+
+    @pytest.fixture
+    def metadata_subject(self, metadata_image):
+        """Create a Subject with a MetadataImage."""
+        return tio.Subject(image=metadata_image)
+
+    def test_crop_with_history_image(self, history_subject):
+        """Test that Crop works with a custom Image that requires a history argument."""
+        transform = tio.Crop(cropping=2)
+        result = transform(history_subject)
+
+        # Check that the result is still a HistoryImage
+        assert isinstance(result.image, HistoryImage)
+
+        # Check that the custom attribute is preserved
+        assert result.image.history == ['created']
+
+        # Check that cropping worked correctly
+        assert result.image.shape == (1, 6, 6, 6)
+
+    def test_crop_with_metadata_image(self, metadata_subject):
+        """Test that Crop works with a custom Image that has optional parameters."""
+        transform = tio.Crop(cropping=1)
+        result = transform(metadata_subject)
+
+        # Check that the result is still a MetadataImage
+        assert isinstance(result.image, MetadataImage)
+
+        # Check that the custom attribute is preserved
+        assert result.image.metadata == {'id': 123, 'source': 'test'}
+
+        # Check that cropping worked correctly
+        assert result.image.shape == (1, 10, 10, 10)
+
+    def test_chained_transforms_preserve_attributes(self, history_subject):
+        """Test that chained transforms preserve custom attributes."""
+        # Chain multiple transforms
+        transform = tio.Compose(
+            [
+                tio.Crop(cropping=1),
+                tio.Crop(cropping=1),
+            ]
+        )
+
+        result = transform(history_subject)
+
+        # Check that the result is still a HistoryImage after multiple transforms
+        assert isinstance(result.image, HistoryImage)
+
+        # Check that the custom attribute is preserved through the chain
+        assert result.image.history == ['created']
+
+        # Check that both crops were applied
+        assert result.image.shape == (1, 6, 6, 6)
+
+    def test_backward_compatibility_standard_images(self):
+        """Test that standard Images still work with transforms."""
+        # Create a standard ScalarImage
+        tensor = torch.rand(1, 10, 10, 10)
+        affine = torch.eye(4)
+        image = tio.ScalarImage(tensor=tensor, affine=affine)
+        subject = tio.Subject(image=image)
+
+        # Apply transform
+        transform = tio.Crop(cropping=2)
+        result = transform(subject)
+
+        # Check that it still works
+        assert isinstance(result.image, tio.ScalarImage)
+        assert result.image.shape == (1, 6, 6, 6)
+
+    def test_to_reference_space_with_custom_image(self, history_image):
+        """Test that ToReferenceSpace works with custom images."""
+        # Create a tensor to place into the reference space
+        embedding_tensor = torch.rand(1, 10, 10, 10)
+
+        # Use ToReferenceSpace.from_tensor
+        result = tio.ToReferenceSpace.from_tensor(embedding_tensor, history_image)
+
+        # Check that the result preserves the custom class type
+        assert isinstance(result, HistoryImage)
+
+        # Check that the custom attribute is preserved
+        assert result.history == ['created']
+
+    def test_new_like_method_directly(self, history_image):
+        """Test the new_like method directly."""
+        new_tensor = torch.rand(1, 5, 5, 5)
+        # A valid scaling affine (the last row of an affine must stay 0, 0, 0, 1)
+        new_affine = torch.diag(torch.tensor([2.0, 2.0, 2.0, 1.0]))
+
+        # Create a new image using new_like
+        new_image = history_image.new_like(tensor=new_tensor, affine=new_affine)
+
+        # Check type preservation
+        assert isinstance(new_image, HistoryImage)
+
+        # Check attribute preservation
+        assert new_image.history == ['created']
+
+        # Check the new data
+        assert torch.equal(new_image.data, new_tensor)
+        assert torch.allclose(
+            torch.tensor(new_image.affine).float(), new_affine.float()
+        )
+
+    def test_new_like_with_default_affine(self, metadata_image):
+        """Test the new_like method with the default affine (None)."""
+        new_tensor = torch.rand(1, 8, 8, 8)
+
+        # Create a new image using new_like with the default affine
+        new_image = metadata_image.new_like(tensor=new_tensor)
+
+        # Check that the original affine is used
+        assert torch.allclose(
+            torch.tensor(new_image.affine), torch.tensor(metadata_image.affine)
+        )
+
+        # Check attribute preservation
+        assert new_image.metadata == {'id': 123, 'source': 'test'}
+    def test_label_map_subclass(self):
+        """Test that custom LabelMap subclasses also work."""
+
+        class CustomLabelMap(tio.LabelMap):
+            def __init__(self, tensor, affine, labels_info, **kwargs):
+                super().__init__(tensor=tensor, affine=affine, **kwargs)
+                self.labels_info = labels_info
+
+            def new_like(self, tensor, affine=None):
+                return type(self)(
+                    tensor=tensor,
+                    affine=affine if affine is not None else self.affine,
+                    labels_info=self.labels_info,
+                    check_nans=self.check_nans,
+                    reader=self.reader,
+                )
+
+        # Create a custom label map
+        tensor = torch.randint(0, 3, (1, 8, 8, 8))
+        affine = torch.eye(4)
+        labels_info = {0: 'background', 1: 'tissue1', 2: 'tissue2'}
+
+        custom_label = CustomLabelMap(
+            tensor=tensor, affine=affine, labels_info=labels_info
+        )
+        subject = tio.Subject(labels=custom_label)
+
+        # Apply transform
+        transform = tio.Crop(cropping=1)
+        result = transform(subject)
+
+        # Check preservation
+        assert isinstance(result.labels, CustomLabelMap)
+        assert result.labels.labels_info == labels_info
+        assert result.labels.shape == (1, 6, 6, 6)
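
Finally, a sketch of the default `new_like` path for the built-in types, which should go through the `type(self)(...)` constructor rather than the deepcopy fallback (assumes the patched `Image.new_like`):

```python
import torch
import torchio as tio

label = tio.LabelMap(tensor=torch.randint(0, 2, (1, 8, 8, 8)))
new_label = label.new_like(tensor=torch.zeros(1, 8, 8, 8, dtype=torch.int64))
assert isinstance(new_label, tio.LabelMap)
assert new_label.type == tio.LABEL  # the label type is preserved
```

The new test module can be run in isolation with `pytest tests/transforms/test_custom_image_subclass.py`.
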