From a4b7b0d034b5c9b1e9b9aa424eb62419211e183f Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Fernando=20P=C3=A9rez-Garc=C3=ADa?= Date: Wed, 4 Feb 2026 22:34:59 +0000 Subject: [PATCH 1/2] Add methods to export transform as dict for Hydra --- src/torchio/transforms/augmentation/composition.py | 10 ++++++++++ src/torchio/transforms/transform.py | 11 +++++++++++ 2 files changed, 21 insertions(+) diff --git a/src/torchio/transforms/augmentation/composition.py b/src/torchio/transforms/augmentation/composition.py index 2e8b7b3cf..2a98b9bc0 100644 --- a/src/torchio/transforms/augmentation/composition.py +++ b/src/torchio/transforms/augmentation/composition.py @@ -81,6 +81,16 @@ def inverse(self, warn: bool = True) -> Compose: ) return result + def to_hydra_config(self) -> dict: + """Return a dictionary representation of the transform for Hydra instantiation.""" + target = self._get_name_with_module() + transform_dict = {'_target_': target} + transform_dict['transforms'] = [] + transform_dict.update(self._get_reproducing_arguments()) + for transform in self.transforms: + transform_dict['transforms'].append(transform.to_hydra_config()) + return transform_dict + class OneOf(RandomTransform): """Apply only one of the given transforms. diff --git a/src/torchio/transforms/transform.py b/src/torchio/transforms/transform.py index 2375e7ce3..55c85941a 100644 --- a/src/torchio/transforms/transform.py +++ b/src/torchio/transforms/transform.py @@ -596,3 +596,14 @@ def get_mask_from_bounds( mask = torch.zeros_like(tensor, dtype=torch.bool) mask[:, i0:i1, j0:j1, k0:k1] = True return mask + + def _get_name_with_module(self) -> str: + """Return the name of the transform including its module.""" + return f'{self.__class__.__module__}.{self.__class__.__name__}' + + def to_hydra_config(self) -> dict: + """Return a dictionary representation of the transform for Hydra instantiation.""" + target = self._get_name_with_module() + transform_dict = {'_target_': target} + transform_dict.update(self._get_reproducing_arguments()) + return transform_dict From f64b3f793d0e4c80df46d094d295021de6146460 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Fernando=20P=C3=A9rez-Garc=C3=ADa?= Date: Tue, 10 Feb 2026 10:24:05 +0000 Subject: [PATCH 2/2] Fix tuples in config and typing errors --- src/torchio/transforms/augmentation/composition.py | 13 ++++++++----- src/torchio/transforms/transform.py | 10 +++++++++- src/torchio/visualization.py | 8 ++++++++ 3 files changed, 25 insertions(+), 6 deletions(-) diff --git a/src/torchio/transforms/augmentation/composition.py b/src/torchio/transforms/augmentation/composition.py index 2a98b9bc0..c09aa856a 100644 --- a/src/torchio/transforms/augmentation/composition.py +++ b/src/torchio/transforms/augmentation/composition.py @@ -2,6 +2,8 @@ import warnings from collections.abc import Sequence +from typing import Any +from typing import TypeAlias from typing import Union import numpy as np @@ -11,7 +13,9 @@ from ..transform import Transform from . import RandomTransform -TypeTransformsDict = Union[dict[Transform, float], Sequence[Transform]] +TypeTransformsDict: TypeAlias = Union[dict[Transform, float], Sequence[Transform]] +HydraConfig: TypeAlias = dict[str, Any] +HydraConfigDict: TypeAlias = dict[str, HydraConfig] class Compose(Transform): @@ -81,15 +85,14 @@ def inverse(self, warn: bool = True) -> Compose: ) return result - def to_hydra_config(self) -> dict: + def to_hydra_config(self) -> HydraConfig: """Return a dictionary representation of the transform for Hydra instantiation.""" - target = self._get_name_with_module() - transform_dict = {'_target_': target} + transform_dict: HydraConfig = {'_target_': self._get_name_with_module()} transform_dict['transforms'] = [] transform_dict.update(self._get_reproducing_arguments()) for transform in self.transforms: transform_dict['transforms'].append(transform.to_hydra_config()) - return transform_dict + return self._tuples_to_lists(transform_dict) class OneOf(RandomTransform): diff --git a/src/torchio/transforms/transform.py b/src/torchio/transforms/transform.py index 55c85941a..26735d098 100644 --- a/src/torchio/transforms/transform.py +++ b/src/torchio/transforms/transform.py @@ -601,9 +601,17 @@ def _get_name_with_module(self) -> str: """Return the name of the transform including its module.""" return f'{self.__class__.__module__}.{self.__class__.__name__}' + @staticmethod + def _tuples_to_lists(obj): + if isinstance(obj, (tuple, list)): + return [Transform._tuples_to_lists(x) for x in obj] + if isinstance(obj, dict): + return {k: Transform._tuples_to_lists(v) for k, v in obj.items()} + return obj + def to_hydra_config(self) -> dict: """Return a dictionary representation of the transform for Hydra instantiation.""" target = self._get_name_with_module() transform_dict = {'_target_': target} transform_dict.update(self._get_reproducing_arguments()) - return transform_dict + return self._tuples_to_lists(transform_dict) diff --git a/src/torchio/visualization.py b/src/torchio/visualization.py index f177862a4..dc260c8e5 100644 --- a/src/torchio/visualization.py +++ b/src/torchio/visualization.py @@ -103,6 +103,14 @@ def plot_volume( elif rgb and image.num_channels == 3: data = image.data # keep image as it is elif channel is None: + if image.num_channels > 1: + message = ( + 'Multiple channels found in the image. ' + 'Plotting the first channel (0). ' + 'To plot a different channel, please specify the channel ' + 'index using the "channel" argument.' + ) + warnings.warn(message, RuntimeWarning, stacklevel=2) data = image.data[0:1] # just use the first channel else: data = image.data[np.newaxis, channel]