Data augmentation

Transforms wrapper


source

CustomDictTransform


def CustomDictTransform(
    aug
):

A class that serves as a wrapper to perform an identical transformation on both the image and the target (if it’s a mask).

Vanilla transforms


source

do_pad_or_crop


def do_pad_or_crop(
    o, target_shape, padding_mode, mask_name, dtype:_TensorMeta=Tensor
):

source

PadOrCrop


def PadOrCrop(
    size, padding_mode:int=0, mask_name:NoneType=None
):

Resize image using TorchIO CropOrPad.


source

ZNormalization


def ZNormalization(
    masking_method:NoneType=None, channel_wise:bool=True
):

Apply TorchIO ZNormalization.


source

RescaleIntensity


def RescaleIntensity(
    out_min_max:tuple, in_min_max:tuple
):

Apply TorchIO RescaleIntensity for robust intensity scaling.

Args: out_min_max (tuple[float, float]): Output intensity range (min, max) in_min_max (tuple[float, float]): Input intensity range (min, max)

Example for CT images: # Normalize CT from air (-1000 HU) to bone (1000 HU) into range (-1, 1) transform = RescaleIntensity(out_min_max=(-1, 1), in_min_max=(-1000, 1000))


source

NormalizeIntensity


def NormalizeIntensity(
    nonzero:bool=True, channel_wise:bool=True, subtrahend:float=None, divisor:float=None
):

Apply MONAI NormalizeIntensity.

Args: nonzero (bool): Only normalize non-zero values (default: True) channel_wise (bool): Apply normalization per channel (default: True) subtrahend (float, optional): Value to subtract
divisor (float, optional): Value to divide by


source

BraTSMaskConverter


def BraTSMaskConverter(
    enc:NoneType=None, dec:NoneType=None, split_idx:NoneType=None, order:NoneType=None
):

Convert BraTS masks.


source

BinaryConverter


def BinaryConverter(
    enc:NoneType=None, dec:NoneType=None, split_idx:NoneType=None, order:NoneType=None
):

Convert to binary mask.


source

RandomGhosting


def RandomGhosting(
    intensity:tuple=(0.5, 1), p:float=0.5
):

Apply TorchIO RandomGhosting.


source

RandomSpike


def RandomSpike(
    num_spikes:int=1, intensity:tuple=(1, 3), p:float=0.5
):

Apply TorchIO RandomSpike.


source

RandomNoise


def RandomNoise(
    mean:int=0, std:tuple=(0, 0.25), p:float=0.5
):

Apply TorchIO RandomNoise.


source

RandomBiasField


def RandomBiasField(
    coefficients:float=0.5, order:int=3, p:float=0.5
):

Apply TorchIO RandomBiasField.


source

RandomBlur


def RandomBlur(
    std:tuple=(0, 2), p:float=0.5
):

Apply TorchIO RandomBlur.


source

RandomGamma


def RandomGamma(
    log_gamma:tuple=(-0.3, 0.3), p:float=0.5
):

Apply TorchIO RandomGamma.


source

RandomIntensityScale


def RandomIntensityScale(
    scale_range:tuple=(0.5, 2.0), p:float=0.5
):

Randomly scale image intensities by a multiplicative factor.

Useful for domain generalization across different acquisition protocols with varying intensity ranges.

Args: scale_range (tuple[float, float]): Range of scale factors (min, max). Values > 1 increase intensity, < 1 decrease intensity. p (float): Probability of applying the transform (default: 0.5)

Example: # Scale intensities randomly between 0.5x and 2.0x transform = RandomIntensityScale(scale_range=(0.5, 2.0), p=0.3)


source

RandomMotion


def RandomMotion(
    degrees:int=10, translation:int=10, num_transforms:int=2, image_interpolation:str='linear', p:float=0.5
):

Apply TorchIO RandomMotion.


source

RandomAnisotropy


def RandomAnisotropy(
    axes:tuple=(0, 1, 2), downsampling:tuple=(1.5, 5), image_interpolation:str='linear', scalars_only:bool=True,
    p:float=0.5
):

Apply TorchIO RandomAnisotropy.


source

RandomCutout


def RandomCutout(
    holes:int=1, max_holes:int=3, spatial_size:int=8, max_spatial_size:int=16, fill:str='min', mask_only:bool=True,
    p:float=0.2
):

Randomly erase spherical regions in 3D medical images with mask-aware placement.

Simulates post-operative surgical cavities by filling random ellipsoid volumes with specified values. When mask_only=True (default), cutouts only affect voxels inside the segmentation mask, ensuring no healthy tissue is modified.

Args: holes: Minimum number of cutout regions. Default: 1. max_holes: Maximum number of regions. Default: 3. spatial_size: Minimum cutout diameter in voxels. Default: 8. max_spatial_size: Maximum cutout diameter. Default: 16. fill: Fill value - 'min', 'mean', 'random', or float. Default: 'min'. mask_only: If True, cutouts only affect mask-positive voxels (tumor tissue). If False, cutouts can affect any voxel (original behavior). Default: True. p: Probability of applying transform. Default: 0.2.

Example: >>> # Simulate post-op cavities only within tumor regions >>> tfm = RandomCutout(holes=1, max_holes=2, spatial_size=10, ... max_spatial_size=25, fill='min', mask_only=True, p=0.2)

>>> # Original behavior - cutouts anywhere in the volume
>>> tfm = RandomCutout(mask_only=False, p=0.2)

Dictionary transforms


source

RandomElasticDeformation


def RandomElasticDeformation(
    num_control_points:int=7, max_displacement:float=7.5, image_interpolation:str='linear', p:float=0.5
):

Apply TorchIO RandomElasticDeformation.


source

RandomAffine


def RandomAffine(
    scales:int=0, degrees:int=10, translation:int=0, isotropic:bool=False, image_interpolation:str='linear',
    default_pad_value:float=0.0, p:float=0.5
):

Apply TorchIO RandomAffine.


source

RandomFlip


def RandomFlip(
    axes:str='LRAPIS', p:float=0.5
):

Apply TorchIO RandomFlip.


source

OneOf


def OneOf(
    transform_dict, p:int=1
):

Apply only one of the given transforms using TorchIO OneOf.

Augmentation suggestion

GPU patch augmentation


source

GpuPatchAugmentation


def GpuPatchAugmentation(
    affine:NoneType=None, anisotropy:NoneType=None, flip:NoneType=None, gamma:NoneType=None,
    intensity_scale:NoneType=None, noise:NoneType=None, blur:NoneType=None
):

GPU-batched augmentation for patch-based training.

Operates on [B, C, D, H, W] tensors already on GPU. All operations run under torch.no_grad() since augmentation does not need gradient tracking.

Transform order: spatial (affine, anisotropy, flip) then intensity (gamma, intensity_scale, noise, blur). Spatial transforms apply the same parameters to both image and mask. Intensity transforms skip the mask.

Each transform is controlled by a parameter dict with at minimum a 'p' key for per-sample probability. Pass None to disable a transform.

Args: affine: dict with keys 'scales', 'degrees', 'translation', 'default_pad_value', 'p'. None to disable. anisotropy: dict with keys 'axes', 'downsampling', 'p'. None to disable. flip: dict with keys 'axes', 'p'. None to disable. gamma: dict with keys 'log_gamma', 'p'. None to disable. intensity_scale: dict with keys 'scale_range', 'p'. None to disable. noise: dict with keys 'std', 'p'. None to disable. blur: dict with keys 'std', 'p'. None to disable.

Example::

>>> gpu_aug = GpuPatchAugmentation(
...     affine={'scales': (0.7, 1.4), 'degrees': (30, 30, 30),
...             'translation': (25, 25, 10), 'default_pad_value': 0., 'p': 0.2},
...     gamma={'log_gamma': (-0.3, 0.3), 'p': 0.3},
...     flip={'axes': (0, 1, 2), 'p': 0.5},
... )
>>> img_aug, mask_aug = gpu_aug(img_gpu, mask_gpu)

source

gpu_patch_augmentations


def gpu_patch_augmentations(
    patch_size, target_spacing, anisotropy_threshold:float=3.0, translation_fraction:float=0.15, affine_p:float=0.2,
    anisotropy_p:float=0.25, gamma_p:float=0.3, intensity_scale_p:float=0.1, noise_p:float=0.1, blur_p:float=0.2,
    flip_p:float=0.5
):

Create GpuPatchAugmentation with nnU-Net-inspired defaults.

Factory function that mirrors suggest_patch_augmentations but returns a GpuPatchAugmentation for GPU-batched operation. Uses the same shared parameter logic via _compute_patch_aug_params.

Args: patch_size: List/tuple of 3 ints – patch dimensions. target_spacing: List/tuple of 3 floats – voxel spacing. anisotropy_threshold: Ratio threshold for anisotropy detection (default 3.0). translation_fraction: Fraction of patch_size for translation (default 0.15). affine_p: Probability for RandomAffine (default 0.2). anisotropy_p: Probability for RandomAnisotropy (default 0.25). gamma_p: Probability for RandomGamma (default 0.3). intensity_scale_p: Probability for RandomIntensityScale (default 0.1). noise_p: Probability for RandomNoise (default 0.1). blur_p: Probability for RandomBlur (default 0.2). flip_p: Probability for RandomFlip per axis (default 0.5).

Returns: GpuPatchAugmentation instance.

Example::

>>> gpu_aug = gpu_patch_augmentations([128, 128, 32], [0.5, 0.5, 1.5])
>>> dls = MedPatchDataLoaders.from_df(..., gpu_augmentation=gpu_aug)

source

suggest_patch_augmentations


def suggest_patch_augmentations(
    patch_size, target_spacing, anisotropy_threshold:float=3.0, translation_fraction:float=0.15
):

Suggest patch-based augmentations with nnU-Net-inspired defaults.

Derives rotation degrees, translation, and RandomAnisotropy axes from patch geometry and voxel spacing. Returns a list of fastMONAI transform instances ready for the patch_tfms parameter in MedPatchDataLoaders.

Anisotropy detection: if max(spacing)/min(spacing) >= threshold, rotation is restricted to 5 deg out-of-plane and 30 deg in-plane. Otherwise 30 deg symmetric. Translation is patch_size * fraction per axis.

Args: patch_size: List/tuple of 3 ints – patch dimensions. target_spacing: List/tuple of 3 floats – voxel spacing. anisotropy_threshold: Ratio threshold for anisotropy detection (default 3.0). translation_fraction: Fraction of patch_size for translation (default 0.15).

Returns: list: fastMONAI transform instances (7 normally, 6 if RandomAnisotropy omitted).

Example::

>>> patch_tfms = suggest_patch_augmentations([128, 128, 32], [0.5, 0.5, 1.5])
>>> dls = MedPatchDataLoaders.from_df(..., patch_tfms=patch_tfms)
from fastcore.test import test_eq, test_fail

# --- suggest_patch_augmentations ---

# Isotropic patch and spacing: full 7-transform list, RandomAffine first, RandomFlip last.
iso_tfms = suggest_patch_augmentations([128, 128, 128], [1.0, 1.0, 1.0])
test_eq(len(iso_tfms), 7)
test_eq(type(iso_tfms[0]), RandomAffine)
test_eq(type(iso_tfms[-1]), RandomFlip)

# Thick axis 2: degrees=(5, 5, 30) expands to TorchIO's (-5, 5, -5, 5, -30, 30).
thick_z = suggest_patch_augmentations([128, 128, 32], target_spacing=[0.5, 0.5, 1.5])
test_eq(len(thick_z), 7)
test_eq(thick_z[0].tio_transform.degrees, (-5, 5, -5, 5, -30, 30))

# Thick axis 0: the restricted rotation axes follow the anisotropic spacing.
thick_x = suggest_patch_augmentations([32, 128, 128], target_spacing=[3.0, 0.5, 0.5])
test_eq(thick_x[0].tio_transform.degrees, (-30, 30, -5, 5, -5, 5))

# Isotropic spacing yields symmetric 30 degrees on every axis.
iso_small = suggest_patch_augmentations([64, 64, 64], [1.0, 1.0, 1.0])
test_eq(iso_small[0].tio_transform.degrees, (-30, 30, -30, 30, -30, 30))

# 2D-like patch [128, 128, 1]: RandomAnisotropy restricted to axes (0, 1).
flat_tfms = suggest_patch_augmentations([128, 128, 1], [1.0, 1.0, 1.0])
test_eq(type(flat_tfms[1]), RandomAnisotropy)
test_eq(flat_tfms[1].add_anisotropy.axes, (0, 1))

# Degenerate [1, 1, 1] patch: RandomAnisotropy is omitted from the list.
point_tfms = suggest_patch_augmentations([1, 1, 1], [1.0, 1.0, 1.0])
test_eq(len(point_tfms), 6)
test_eq(all(not isinstance(t, RandomAnisotropy) for t in point_tfms), True)

# Mismatched input lengths must raise.
test_fail(lambda: suggest_patch_augmentations([128, 128], [1.0, 1.0, 1.0]))
test_fail(lambda: suggest_patch_augmentations([128, 128, 128], [1.0, 1.0]))

# Every returned transform must expose a .tio_transform handle.
for tfm in suggest_patch_augmentations([128, 128, 64], [1.0, 1.0, 1.0]):
    assert hasattr(tfm, 'tio_transform'), f"{type(tfm).__name__} missing .tio_transform"
# --- _compute_patch_aug_params ---

# Isotropic geometry: symmetric degrees, no anisotropy flag, all axes eligible.
iso_params = _compute_patch_aug_params([128, 128, 128], [1.0, 1.0, 1.0])
test_eq(iso_params['degrees'], (30, 30, 30))
test_eq(iso_params['is_aniso'], False)
test_eq(iso_params['aniso_axes'], (0, 1, 2))

# Thick axis 2: out-of-plane rotation restricted to 5 degrees, anisotropy detected.
aniso_params = _compute_patch_aug_params([128, 128, 32], [0.5, 0.5, 1.5])
test_eq(aniso_params['degrees'], (5, 5, 30))
test_eq(aniso_params['is_aniso'], True)

# Size-1 dimension drops out of the anisotropy axes.
flat_params = _compute_patch_aug_params([128, 128, 1], [1.0, 1.0, 1.0])
test_eq(flat_params['aniso_axes'], (0, 1))

# Mismatched input lengths must raise.
test_fail(lambda: _compute_patch_aug_params([128, 128], [1.0, 1.0, 1.0]))
test_fail(lambda: _compute_patch_aug_params([128, 128, 128], [1.0, 1.0]))

# --- _build_rotation_matrix_3d ---
# Zero rotation angles must yield a batch of 3x3 identity matrices.
angles = torch.zeros(3, 3)
rot = _build_rotation_matrix_3d(angles)
test_eq(rot.shape, (3, 3, 3))
identity = torch.eye(3)
for mat in rot:
    assert torch.allclose(mat, identity, atol=1e-6), "Zero angles should give identity"

# --- GpuPatchAugmentation ---

# With every per-transform probability at 0, the pipeline is the identity.
gpu_aug_noop = GpuPatchAugmentation(
    affine={'scales': (0.7, 1.4), 'degrees': (30, 30, 30),
            'translation': (10, 10, 10), 'default_pad_value': 0., 'p': 0.0},
    anisotropy={'axes': (0, 1, 2), 'downsampling': (1.5, 4), 'p': 0.0},
    flip={'axes': (0, 1, 2), 'p': 0.0},
    gamma={'log_gamma': (-0.3, 0.3), 'p': 0.0},
    intensity_scale={'scale_range': (0.75, 1.25), 'p': 0.0},
    noise={'std': 0.1, 'p': 0.0},
    blur={'std': (0.5, 1.0), 'p': 0.0},
)
img_in = torch.randn(2, 1, 16, 16, 16)
mask_in = torch.zeros(2, 1, 16, 16, 16)
mask_in[:, :, 4:12, 4:12, 4:12] = 1.0  # cubic foreground region
noop_img, noop_mask = gpu_aug_noop(img_in, mask_in)
test_eq(torch.equal(noop_img, img_in), True)
test_eq(torch.equal(noop_mask, mask_in), True)

# With every probability at 1, shapes must still be preserved.
gpu_aug_all = GpuPatchAugmentation(
    affine={'scales': (0.7, 1.4), 'degrees': (30, 30, 30),
            'translation': (2, 2, 2), 'default_pad_value': 0., 'p': 1.0},
    anisotropy={'axes': (0, 1, 2), 'downsampling': (1.5, 4), 'p': 1.0},
    flip={'axes': (0, 1, 2), 'p': 1.0},
    gamma={'log_gamma': (-0.3, 0.3), 'p': 1.0},
    intensity_scale={'scale_range': (0.75, 1.25), 'p': 1.0},
    noise={'std': 0.1, 'p': 1.0},
    blur={'std': (0.5, 1.0), 'p': 1.0},
)
aug_img, aug_mask = gpu_aug_all(img_in.clone(), mask_in.clone())
test_eq(aug_img.shape, img_in.shape)
test_eq(aug_mask.shape, mask_in.shape)

# Intensity-only transforms must leave the mask untouched.
gpu_aug_intensity = GpuPatchAugmentation(
    gamma={'log_gamma': (-0.3, 0.3), 'p': 1.0},
    intensity_scale={'scale_range': (0.75, 1.25), 'p': 1.0},
    noise={'std': 0.1, 'p': 1.0},
    blur={'std': (0.5, 1.0), 'p': 1.0},
)
mask_copy = mask_in.clone()
_, intensity_mask = gpu_aug_intensity(img_in.clone(), mask_copy)
test_eq(torch.equal(intensity_mask, mask_in), True)

# A None mask is passed through as None; the image keeps its shape.
none_img, none_mask = gpu_aug_all(img_in.clone(), None)
test_eq(none_mask, None)
test_eq(none_img.shape, img_in.shape)
# --- gpu_patch_augmentations factory ---

# Isotropic geometry: factory returns a configured GpuPatchAugmentation.
factory_aug = gpu_patch_augmentations([128, 128, 128], [1.0, 1.0, 1.0])
test_eq(type(factory_aug), GpuPatchAugmentation)
test_eq(factory_aug.affine['degrees'], (30, 30, 30))
test_eq(factory_aug.anisotropy['axes'], (0, 1, 2))

# Anisotropic spacing restricts out-of-plane rotation, mirroring suggest_patch_augmentations.
aniso_aug = gpu_patch_augmentations([128, 128, 32], [0.5, 0.5, 1.5])
test_eq(aniso_aug.affine['degrees'], (5, 5, 30))

# Degenerate [1, 1, 1] patch disables the anisotropy transform.
nodim_aug = gpu_patch_augmentations([1, 1, 1], [1.0, 1.0, 1.0])
test_eq(nodim_aug.anisotropy, None)

# repr should surface the class name and configured transform slots.
rep = repr(factory_aug)
for token in ('GpuPatchAugmentation', 'affine', 'flip'):
    assert token in rep

# Flip with p=1 must move an asymmetric feature consistently in image and mask.
flip_aug = GpuPatchAugmentation(flip={'axes': (0,), 'p': 1.0})
src = torch.zeros(1, 1, 8, 8, 8)
src[0, 0, 0, :, :] = 1.0  # light up only the first D-slice
src_mask = src.clone()
flip_img, flip_mask = flip_aug(src.clone(), src_mask.clone())
test_eq(flip_img.shape, src.shape)
# After flipping axis 0 (D-axis), the lit slice lands at the far end in both tensors.
test_eq(flip_img[0, 0, -1, :, :].sum() > 0, True)
test_eq(flip_mask[0, 0, -1, :, :].sum() > 0, True)
# --- .tio_transform property ---
# Each fastMONAI wrapper must expose the underlying TorchIO transform.

# CustomDictTransform-based wrappers: exact type match.
for wrapper, expected in [
    (RandomAffine(degrees=10), tio.RandomAffine),
    (RandomFlip(p=0.5), tio.RandomFlip),
    (RandomElasticDeformation(p=0.5), tio.RandomElasticDeformation),
]:
    test_eq(type(wrapper.tio_transform), expected)

# DisplayedTransform-based wrappers: exact type match.
for wrapper, expected in [
    (PadOrCrop([64, 64, 64]), tio.CropOrPad),
    (ZNormalization(), tio.ZNormalization),
    (RescaleIntensity((-1, 1), (-1000, 1000)), tio.RescaleIntensity),
    (RandomGamma(p=0.5), tio.RandomGamma),
    (RandomNoise(p=0.5), tio.RandomNoise),
    (RandomBiasField(p=0.5), tio.RandomBiasField),
    (RandomBlur(p=0.5), tio.RandomBlur),
    (RandomGhosting(p=0.5), tio.RandomGhosting),
    (RandomSpike(p=0.5), tio.RandomSpike),
    (RandomMotion(p=0.5), tio.RandomMotion),
    (RandomAnisotropy(p=0.5), tio.RandomAnisotropy),
]:
    test_eq(type(wrapper.tio_transform), expected)

# Custom TorchIO subclasses: isinstance check against the TorchIO base class,
# since these are project-defined transform subclasses rather than stock types.
test_eq(isinstance(RandomIntensityScale(p=0.5).tio_transform, tio.IntensityTransform), True)
test_eq(isinstance(NormalizeIntensity().tio_transform, tio.IntensityTransform), True)
# --- RandomCutout (ItemTransform: operates on (image, target) tuples) ---
import numpy as np

img = MedImage(torch.randn(1, 32, 32, 32))
mask = MedMask(torch.zeros(1, 32, 32, 32))
mask[0, 10:20, 10:20, 10:20] = 1.0  # synthetic tumor region

# mask_only=True (default): cutouts stay inside the tumor, mask untouched.
cutout = RandomCutout(holes=1, spatial_size=8, fill='min', mask_only=True, p=1.0)
cut_img, cut_mask = cutout.encodes((img, mask))
test_eq(type(cut_img), MedImage)
test_eq(type(cut_mask), MedMask)
test_eq(cut_img.shape, img.shape)
# Healthy tissue (mask == 0) must be byte-identical; numpy used for fancy-index compare.
healthy = mask.numpy()[0] == 0
test_eq(np.array_equal(cut_img.numpy()[0, healthy], img.numpy()[0, healthy]), True)
# The segmentation mask itself is never modified.
test_eq(torch.equal(cut_mask, mask), True)

# An all-zero mask leaves nothing to cut: image passes through unchanged.
no_tumor = MedMask(torch.zeros(1, 32, 32, 32))
passthrough_img, _ = cutout.encodes((img, no_tumor))
test_eq(torch.equal(passthrough_img, img), True)

# mask_only=False: cutouts may land anywhere in the volume.
anywhere = RandomCutout(mask_only=False, p=1.0)
any_img, _ = anywhere.encodes((img, mask))
test_eq(any_img.shape, img.shape)

# Classification target (TensorCategory) with mask_only=False still works.
label = TensorCategory(1)
cls_cutout = RandomCutout(mask_only=False, p=1.0)
cls_img, cls_label = cls_cutout.encodes((img, label))
test_eq(type(cls_img), MedImage)
test_eq(cls_label, label)

# mask_only=True with a non-mask target: no mask to intersect, so cutout is skipped.
strict_cutout = RandomCutout(mask_only=True, p=1.0)
skip_img, skip_label = strict_cutout.encodes((img, label))
test_eq(torch.equal(skip_img, img), True)

# tio_transform property exposes a TorchIO intensity transform.
test_eq(isinstance(cutout.tio_transform, tio.IntensityTransform), True)

# Every fill mode runs and preserves shape (mask_only=False).
for fill_mode in ['min', 'mean', 'random', 0.0]:
    fill_cutout = RandomCutout(fill=fill_mode, mask_only=False, p=1.0)
    fill_img, _ = fill_cutout.encodes((img, mask))
    test_eq(fill_img.shape, img.shape)