diff --git a/src/torchio/data/image.py b/src/torchio/data/image.py index d8f74c32..6afd748b 100644 --- a/src/torchio/data/image.py +++ b/src/torchio/data/image.py @@ -1,10 +1,13 @@ from __future__ import annotations +import base64 +import io import warnings from collections import Counter from collections.abc import Callable from collections.abc import Sequence from pathlib import Path +from typing import TYPE_CHECKING from typing import Any import humanize @@ -46,6 +49,9 @@ from .io import sitk_to_nib from .io import write_image +if TYPE_CHECKING: + from matplotlib.figure import Figure + PROTECTED_KEYS = DATA, AFFINE, TYPE, PATH, STEM TypeBound = tuple[float, float] TypeBounds = tuple[TypeBound, TypeBound, TypeBound] @@ -204,6 +210,28 @@ def __repr__(self): string = f'{self.__class__.__name__}({properties})' return string + def _repr_html_(self): + try: + from matplotlib import pyplot as plt + from matplotlib.figure import Figure + except ImportError: + return self.__repr__() + + buffer = io.BytesIO() + fig = self.plot( + return_fig=True, + output_path=buffer, + show=False, + savefig_kwargs={'bbox_inches': 'tight'}, + ) + assert isinstance(fig, Figure) + plt.close(fig) + buffer.seek(0) + + img_str = base64.b64encode(buffer.read()).decode('utf-8') + html = f'' + return html + def __getitem__(self, item): if isinstance(item, (slice, int, tuple)): return self._crop_from_slices(item) @@ -774,14 +802,17 @@ def get_center(self, lps: bool = False) -> TypeTripletFloat: def set_check_nans(self, check_nans: bool) -> None: self.check_nans = check_nans - def plot(self, **kwargs) -> None: + def plot(self, return_fig: bool = False, **kwargs) -> None | Figure: """Plot image.""" if self.is_2d(): self.as_pil().show() else: from ..visualization import plot_volume # avoid circular import - plot_volume(self, **kwargs) + figure = plot_volume(self, **kwargs) + if return_fig: + assert figure is not None + return figure def show(self, viewer_path: TypePath | None = None) -> None: """Open the image using external software.