67 changes: 67 additions & 0 deletions src/torchio/data/image.py
@@ -774,6 +774,73 @@ def get_center(self, lps: bool = False) -> TypeTripletFloat:
     def set_check_nans(self, check_nans: bool) -> None:
         self.check_nans = check_nans
 
+    def new_like(self, tensor: TypeData, affine: TypeData | None = None) -> Image:
+        """Create a new image of the same type with new tensor data.
+
+        This method creates a new image instance of the same class as the current
+        image, preserving essential attributes like type, check_nans, and reader.
+        This is particularly useful for transforms that need to create new images
+        while maintaining compatibility with custom Image subclasses.
+
+        Args:
+            tensor: 4D tensor with dimensions :math:`(C, W, H, D)` for the new image.
+            affine: :math:`4 \\times 4` matrix to convert voxel coordinates to world
+                coordinates. If ``None``, uses the current image's affine matrix.
+
+        Returns:
+            A new image instance of the same type as the current image.
+
+        Example:
+            >>> import torch
+            >>> import torchio as tio
+            >>> # Standard usage
+            >>> image = tio.ScalarImage('path/to/image.nii.gz')
+            >>> new_tensor = torch.rand(1, 64, 64, 64)
+            >>> new_image = image.new_like(tensor=new_tensor)
+            >>> isinstance(new_image, tio.ScalarImage)
+            True
+
+            >>> # Custom subclass usage
+            >>> class CustomImage(tio.ScalarImage):
+            ...     def __init__(self, tensor, affine, metadata, **kwargs):
+            ...         super().__init__(tensor=tensor, affine=affine, **kwargs)
+            ...         self.metadata = metadata
+            ...
+            ...     def new_like(self, tensor, affine=None):
+            ...         return type(self)(
+            ...             tensor=tensor,
+            ...             affine=affine if affine is not None else self.affine,
+            ...             metadata=self.metadata,  # Preserve custom attribute
+            ...             check_nans=self.check_nans,
+            ...             reader=self.reader,
+            ...         )
+            >>> custom = CustomImage(torch.rand(1, 32, 32, 32), torch.eye(4), {'id': 123})
+            >>> new_custom = custom.new_like(torch.rand(1, 16, 16, 16))
+            >>> new_custom.metadata['id']
+            123
+        """
+        if affine is None:
+            affine = self.affine
+
+        # First, try the standard constructor approach
+        try:
+            return type(self)(
+                tensor=tensor,
+                affine=affine,
+                type=self.type,
+                check_nans=self.check_nans,
+                reader=self.reader,
+            )
+        except TypeError:
+            # If the standard constructor fails (e.g., a custom subclass with
+            # additional required arguments), fall back to a copy-based approach
+            import copy
+
+            new_image = copy.deepcopy(self)
+            new_image.set_data(tensor)
+            new_image.affine = affine
+            return new_image
+
     def plot(self, **kwargs) -> None:
         """Plot image."""
         if self.is_2d():
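Reviewer note: a minimal sketch of the fallback path above, using a hypothetical subclass (`LabeledImage` and its `label` attribute are made up for illustration). The standard constructor call raises `TypeError` because of the extra required argument, so `new_like` falls back to `copy.deepcopy` and preserves the attribute:

```python
import torch
import torchio as tio


class LabeledImage(tio.ScalarImage):  # hypothetical subclass
    def __init__(self, tensor, affine, label, **kwargs):
        super().__init__(tensor=tensor, affine=affine, **kwargs)
        self.label = label


image = LabeledImage(torch.rand(1, 8, 8, 8), torch.eye(4), label='t1')
# type(image)(...) raises TypeError (missing `label`), triggering the
# deepcopy-based fallback, which carries the custom attribute over
new_image = image.new_like(torch.zeros(1, 8, 8, 8))
assert isinstance(new_image, LabeledImage)
assert new_image.label == 't1'
assert torch.equal(new_image.data, torch.zeros(1, 8, 8, 8))
```

Worth noting: the fallback deep-copies everything, including the old tensor, before `set_data` replaces it; subclasses that care about that cost can override `new_like` as in the docstring example.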
6 changes: 3 additions & 3 deletions src/torchio/transforms/preprocessing/spatial/crop.py
@@ -111,12 +111,12 @@ def _crop_image(
         if copy_patch:
             # Create a new image with the cropped data
             cropped_data = image.data[:, i0:i1, j0:j1, k0:k1].clone()
-            new_image = type(image)(
+            new_image = image.new_like(
                 tensor=cropped_data,
                 affine=new_affine,
-                type=image.type,
-                path=image.path,
             )
+            # Preserve path for the new image
+            new_image.path = image.path
             return new_image
         else:
             image.set_data(image.data[:, i0:i1, j0:j1, k0:k1].clone())
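Reviewer note: the user-facing behavior of `Crop` should be unchanged; the `copy_patch` branch now just goes through `new_like`, so subclass types survive. A quick sanity check (shapes chosen arbitrarily):

```python
import torch
import torchio as tio

subject = tio.Subject(t1=tio.ScalarImage(tensor=torch.rand(1, 8, 8, 8)))
cropped = tio.Crop(2)(subject)  # remove 2 voxels from both ends of every axis
assert cropped.t1.shape == (1, 4, 4, 4)
assert type(cropped.t1) is tio.ScalarImage  # type preserved by new_like
```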
17 changes: 14 additions & 3 deletions src/torchio/transforms/preprocessing/spatial/pad.py
@@ -96,7 +96,12 @@ def _check_truncation(self, image: Image, mode: str | float) -> None:
     def apply_transform(self, subject: Subject) -> Subject:
         assert self.bounds_parameters is not None
         low = self.bounds_parameters[::2]
-        for image in self.get_images(subject):
+        images_dict = subject.get_images_dict(
+            intensity_only=False,
+            include=self.include,
+            exclude=self.exclude,
+        )
+        for image_name, image in images_dict.items():
             self._check_truncation(image, self.padding_mode)
             new_origin = apply_affine(image.affine, -np.array(low))
             new_affine = image.affine.copy()
@@ -126,8 +131,14 @@ def apply_transform(self, subject: Subject) -> Subject:
             pad_params = self.bounds_parameters
             paddings = (0, 0), pad_params[:2], pad_params[2:4], pad_params[4:]
             padded = np.pad(image.data, paddings, **kwargs)  # type: ignore[call-overload]
-            image.set_data(torch.as_tensor(padded))
-            image.affine = new_affine
+            new_image = image.new_like(
+                tensor=torch.as_tensor(padded), affine=new_affine
+            )
+            # Replace the image in the subject with the new padded image
+            subject[image_name] = new_image
+
+        # Update attributes to sync dictionary changes with attribute access
+        subject.update_attributes()
         return subject
 
     def inverse(self):
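Reviewer note: `Pad` now builds a fresh image via `new_like` and reassigns it into the subject instead of mutating in place, then calls `update_attributes()` so attribute access matches the dict entry. A short check (sizes chosen for illustration):

```python
import torch
import torchio as tio

subject = tio.Subject(t1=tio.ScalarImage(tensor=torch.rand(1, 8, 8, 8)))
padded = tio.Pad(2)(subject)  # add 2 voxels on both ends of every axis
assert padded.t1.shape == (1, 12, 12, 12)
# update_attributes() keeps subject.t1 pointing at the replacement image
assert padded.t1 is padded['t1']
```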
14 changes: 11 additions & 3 deletions src/torchio/transforms/preprocessing/spatial/resample.py
@@ -180,7 +180,12 @@ def apply_transform(self, subject: Subject) -> Subject:
             assert self.pre_affine_name is not None  # for mypy
             self.check_affine_key_presence(self.pre_affine_name, subject)
 
-        for image in self.get_images(subject):
+        images_dict = subject.get_images_dict(
+            intensity_only=False,
+            include=self.include,
+            exclude=self.exclude,
+        )
+        for image_name, image in images_dict.items():
             # If the current image is the reference, don't resample it
             if self.target is image:
                 continue
@@ -233,8 +238,11 @@ def apply_transform(self, subject: Subject) -> Subject:
             resampled = resampler.Execute(floating_sitk)
 
             array, affine = sitk_to_nib(resampled)
-            image.set_data(torch.as_tensor(array))
-            image.affine = affine
+            new_image = image.new_like(tensor=torch.as_tensor(array), affine=affine)
+            subject[image_name] = new_image
+
+        # Update attributes to sync dictionary changes with attribute access
+        subject.update_attributes()
         return subject
 
     @staticmethod
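Reviewer note: same replace-and-sync pattern in `Resample`. Sketch (a tensor-only `ScalarImage` gets an identity affine, i.e. 1 mm spacing):

```python
import torch
import torchio as tio

subject = tio.Subject(t1=tio.ScalarImage(tensor=torch.rand(1, 8, 8, 8)))
resampled = tio.Resample(2)(subject)  # resample to 2 mm isotropic spacing
assert resampled.t1.spacing == (2.0, 2.0, 2.0)
assert resampled.t1.shape == (1, 4, 4, 4)  # same field of view, half resolution
```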
src/torchio/transforms/preprocessing/spatial/to_orientation.py
@@ -71,7 +71,8 @@ def __init__(self, orientation: str = 'RAS', **kwargs):
         self.args_names = ['orientation']
 
     def apply_transform(self, subject: Subject) -> Subject:
-        for image in subject.get_images(intensity_only=False):
+        images_dict = subject.get_images_dict(intensity_only=False)
+        for image_name, image in images_dict.items():
             current_orientation = ''.join(nib.orientations.aff2axcodes(image.affine))
 
             # If the image is already in the target orientation, skip it
@@ -104,7 +105,9 @@ def apply_transform(self, subject: Subject) -> Subject:
             # Update the image data and affine
             reoriented_array = np.ascontiguousarray(reoriented_array)
             tensor = torch.from_numpy(reoriented_array)
-            image.set_data(tensor)
-            image.affine = reoriented_affine
+            new_image = image.new_like(tensor=tensor, affine=reoriented_affine)
+            subject[image_name] = new_image
+
+        # Update attributes to sync dictionary changes with attribute access
+        subject.update_attributes()
         return subject
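Reviewer note: how the reorientation reads from the outside (assuming `ToOrientation` is exported at the top level like the other transforms):

```python
import nibabel as nib
import torch
import torchio as tio

subject = tio.Subject(t1=tio.ScalarImage(tensor=torch.rand(1, 8, 8, 8)))
reoriented = tio.ToOrientation('LPS')(subject)  # identity affine starts as RAS
axcodes = ''.join(nib.orientations.aff2axcodes(reoriented.t1.affine))
assert axcodes == 'LPS'
```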
16 changes: 11 additions & 5 deletions src/torchio/transforms/preprocessing/spatial/to_reference_space.py
@@ -29,10 +29,17 @@ def __init__(self, reference: Image, **kwargs):
         self.reference = reference
 
     def apply_transform(self, subject: Subject) -> Subject:
-        for image in self.get_images(subject):
+        images_dict = subject.get_images_dict(
+            intensity_only=False,
+            include=self.include,
+            exclude=self.exclude,
+        )
+        for image_name, image in images_dict.items():
             new_image = build_image_from_reference(image.data, self.reference)
-            image.set_data(new_image.data)
-            image.affine = new_image.affine
+            subject[image_name] = new_image
+
+        # Update attributes to sync dictionary changes with attribute access
+        subject.update_attributes()
         return subject
 
     @staticmethod
@@ -49,6 +56,5 @@ def build_image_from_reference(tensor: torch.Tensor, reference: Image) -> Image:
         output_spacing = input_spacing * downsampling_factor
         downsample = Resample(output_spacing, image_interpolation='nearest')
         reference = downsample(reference)
-        class_ = reference.__class__
-        result = class_(tensor=tensor, affine=reference.affine)
+        result = reference.new_like(tensor=tensor, affine=reference.affine)
         return result
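Reviewer note: the helper's arithmetic, mimicked with plain calls rather than importing it (the downsampling factor is assumed to be the ratio of the reference's spatial shape to the tensor's, per the elided lines above the hunk):

```python
import numpy as np
import torch
import torchio as tio

reference = tio.ScalarImage(tensor=torch.rand(1, 8, 8, 8))  # identity affine, 1 mm
tensor = torch.rand(1, 4, 4, 4)  # e.g. a model output at half resolution

factor = np.array(reference.spatial_shape) / np.array(tensor.shape[1:])  # 2.0
output_spacing = tuple(np.array(reference.spacing) * factor)
downsample = tio.Resample(output_spacing, image_interpolation='nearest')
grid = downsample(reference)  # reference grid matching the tensor's shape
result = grid.new_like(tensor=tensor, affine=grid.affine)
assert result.spacing == (2.0, 2.0, 2.0)
assert result.spatial_shape == (4, 4, 4)
```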
14 changes: 11 additions & 3 deletions src/torchio/transforms/preprocessing/spatial/transpose.py
@@ -22,13 +22,21 @@ class Transpose(SpatialTransform):
     """
 
     def apply_transform(self, subject: Subject) -> Subject:
-        for image in self.get_images(subject):
+        images_dict = subject.get_images_dict(
+            intensity_only=False,
+            include=self.include,
+            exclude=self.exclude,
+        )
+        for image_name, image in images_dict.items():
             old_orientation = image.orientation_str
             new_orientation = old_orientation[::-1]
             transform = ToOrientation(new_orientation)
             transposed = transform(image)
-            image.set_data(transposed.data)
-            image.affine = transposed.affine
+            new_image = image.new_like(tensor=transposed.data, affine=transposed.affine)
+            subject[image_name] = new_image
+
+        # Update attributes to sync dictionary changes with attribute access
+        subject.update_attributes()
         return subject
 
     def is_invertible(self):
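Reviewer note: `Transpose` composes with `ToOrientation` by reversing the orientation string, so on a default (RAS) image it yields SAR (assuming `Transpose` is exported at the top level):

```python
import torch
import torchio as tio

subject = tio.Subject(t1=tio.ScalarImage(tensor=torch.rand(1, 8, 8, 8)))
transposed = tio.Transpose()(subject)
assert transposed.t1.orientation_str == 'SAR'  # 'RAS' reversed
```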