Skip to content

Commit 0d6aa09

Browse files
Replace T.Normalize with T.ToDtype for improved data type handling in transformations
1 parent 223c131 commit 0d6aa09

File tree

8 files changed

+11
-24
lines changed

8 files changed

+11
-24
lines changed

README.md

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -89,7 +89,7 @@ spatial_transforms = {
8989

9090
# Define value transformations
9191
raw_value_transforms = T.Compose([
92-
T.Normalize(mean=[0.0], std=[255.0]), # Normalize to [0,1]
92+
T.ToDtype(torch.float, scale=True), # Normalize to [0,1] and convert to float
9393
GaussianNoise(std=0.05), # Add noise for augmentation
9494
RandomContrast((0.8, 1.2)), # Vary contrast
9595
])
@@ -293,7 +293,7 @@ from cellmap_data.transforms import (
293293

294294
# Input preprocessing
295295
raw_transforms = T.Compose([
296-
T.Normalize(mean=[0.0], std=[255.0]), # Normalize to [0,1]
296+
T.ToDtype(torch.float, scale=True), # Normalize to [0,1]
297297
GaussianNoise(std=0.1), # Add noise
298298
RandomContrast((0.8, 1.2)), # Vary contrast
299299
NaNtoNum({"nan": 0}) # Handle NaN values

docs/source/cellmap_data.transforms.augment.normalize.rst

Lines changed: 0 additions & 12 deletions
This file was deleted.

src/cellmap_data/datasplit.py

Lines changed: 0 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -108,14 +108,12 @@ def __init__(
108108
spatial_transforms: Optional[Mapping[str, Any]] = None,
109109
train_raw_value_transforms: Optional[T.Transform] = T.Compose(
110110
[
111-
T.Normalize(mean=[0.0], std=[255.0]),
112111
T.ToDtype(torch.float, scale=True),
113112
NaNtoNum({"nan": 0, "posinf": None, "neginf": None}),
114113
],
115114
),
116115
val_raw_value_transforms: Optional[T.Transform] = T.Compose(
117116
[
118-
T.Normalize(mean=[0.0], std=[255.0]),
119117
T.ToDtype(torch.float, scale=True),
120118
NaNtoNum({"nan": 0, "posinf": None, "neginf": None}),
121119
],

tests/test_cellmap_dataset.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -192,7 +192,7 @@ def test_value_transforms_configuration(self, minimal_dataset_config):
192192

193193
raw_transforms = T.Compose(
194194
[
195-
T.Normalize(mean=[0.0], std=[255.0]),
195+
T.ToDtype(torch.float, scale=True),
196196
]
197197
)
198198

tests/test_dataloader.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -286,7 +286,7 @@ def test_loader_with_transforms(self, tmp_path):
286286
input_arrays = {"raw": {"shape": (8, 8, 8), "scale": (4.0, 4.0, 4.0)}}
287287
target_arrays = {"gt": {"shape": (8, 8, 8), "scale": (4.0, 4.0, 4.0)}}
288288

289-
raw_transforms = T.Compose([T.Normalize(mean=[0.0], std=[255.0])])
289+
raw_transforms = T.Compose([T.ToDtype(torch.float, scale=True)])
290290
target_transforms = T.Compose([Binarize(threshold=0.5)])
291291

292292
dataset = CellMapDataset(

tests/test_dataset_writer.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,7 @@
66

77
import pytest
88
import torchvision.transforms.v2 as T
9+
import torch
910

1011
from cellmap_data import CellMapDatasetWriter
1112

@@ -290,7 +291,7 @@ def test_writer_with_value_transforms(self, tmp_path):
290291

291292
output_path = tmp_path / "output.zarr"
292293

293-
raw_transform = T.Normalize(mean=[0.0], std=[255.0])
294+
raw_transform = T.ToDtype(torch.float, scale=True)
294295

295296
target_bounds = {
296297
"pred": {

tests/test_integration.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -43,7 +43,7 @@ def test_basic_training_setup(self, tmp_path):
4343

4444
raw_transforms = T.Compose(
4545
[
46-
T.Normalize(mean=[0.0], std=[255.0]),
46+
T.ToDtype(torch.float, scale=True),
4747
GaussianNoise(std=0.05),
4848
]
4949
)
@@ -220,7 +220,7 @@ def test_complete_augmentation_pipeline(self, tmp_path):
220220
raw_transforms = T.Compose(
221221
[
222222
NaNtoNum({"nan": 0.0}),
223-
T.Normalize(mean=[0.0], std=[255.0]),
223+
T.ToDtype(torch.float, scale=True),
224224
GaussianNoise(std=0.05),
225225
RandomContrast(contrast_range=(0.8, 1.2)),
226226
RandomGamma(gamma_range=(0.8, 1.2)),
@@ -271,7 +271,7 @@ def test_per_target_transforms(self, tmp_path):
271271
# Different transforms for different targets
272272
target_transforms = {
273273
"labels": T.Compose([Binarize(threshold=0.5)]),
274-
"distances": T.Compose([T.Normalize(mean=[0.0], std=[100.0])]),
274+
"distances": T.Compose([T.ToDtype(torch.float, scale=True)]),
275275
}
276276

277277
target_arrays = {

tests/test_transforms.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -336,7 +336,7 @@ def test_sequential_transforms(self):
336336

337337
transforms = T.Compose(
338338
[
339-
T.Normalize(mean=[0.0], std=[255.0]),
339+
T.ToDtype(torch.float32, scale=True),
340340
GaussianNoise(std=0.01),
341341
RandomContrast(contrast_range=(0.9, 1.1)),
342342
]
@@ -356,7 +356,7 @@ def test_transform_pipeline(self):
356356
# Realistic preprocessing pipeline
357357
raw_transforms = T.Compose(
358358
[
359-
T.Normalize(mean=[-128.0], std=[128.0]), # Normalize around 0
359+
T.ToDtype(torch.float32, scale=True),
360360
GaussianNoise(std=0.05),
361361
RandomContrast(contrast_range=(0.8, 1.2)),
362362
]

0 commit comments

Comments
 (0)