# Test .tio_transform property
# Each wrapper must expose the TorchIO transform it delegates to through
# `.tio_transform`. Every case pairs a zero-argument factory with the exact
# TorchIO class expected; the factories are lambdas so each wrapper is built
# right before its own check, in the listed order.
_exact_tio_type_cases = (
    # CustomDictTransform-based wrappers
    (lambda: RandomAffine(degrees=10), tio.RandomAffine),
    (lambda: RandomFlip(p=0.5), tio.RandomFlip),
    (lambda: RandomElasticDeformation(p=0.5), tio.RandomElasticDeformation),
    # DisplayedTransform-based wrappers
    (lambda: PadOrCrop([64, 64, 64]), tio.CropOrPad),
    (lambda: ZNormalization(), tio.ZNormalization),
    (lambda: RescaleIntensity((-1, 1), (-1000, 1000)), tio.RescaleIntensity),
    (lambda: RandomGamma(p=0.5), tio.RandomGamma),
    (lambda: RandomNoise(p=0.5), tio.RandomNoise),
    (lambda: RandomBiasField(p=0.5), tio.RandomBiasField),
    (lambda: RandomBlur(p=0.5), tio.RandomBlur),
    (lambda: RandomGhosting(p=0.5), tio.RandomGhosting),
    (lambda: RandomSpike(p=0.5), tio.RandomSpike),
    (lambda: RandomMotion(p=0.5), tio.RandomMotion),
)
for _make_wrapper, _expected_tio_cls in _exact_tio_type_cases:
    _actual_tio_cls = type(_make_wrapper().tio_transform)
    test_eq(_actual_tio_cls, _expected_tio_cls)

# Custom TorchIO wrappers are project-defined subclasses, so they are checked
# with isinstance against tio.IntensityTransform instead of an exact type.
_scale_wrapper = RandomIntensityScale(p=0.5)
test_eq(isinstance(_scale_wrapper.tio_transform, tio.IntensityTransform), True)
test_eq(isinstance(NormalizeIntensity().tio_transform, tio.IntensityTransform), True)
Data augmentation
Transforms wrapper
CustomDictTransform
def CustomDictTransform(
aug
):
A class that serves as a wrapper to perform an identical transformation on both the image and the target (if it’s a mask).
Vanilla transforms
do_pad_or_crop
def do_pad_or_crop(
o, target_shape, padding_mode, mask_name, dtype:_TensorMeta=Tensor
):
PadOrCrop
def PadOrCrop(
size, padding_mode:int=0, mask_name:NoneType=None
):
Resize image using TorchIO CropOrPad.
ZNormalization
def ZNormalization(
masking_method:NoneType=None, channel_wise:bool=True
):
Apply TorchIO ZNormalization.
RescaleIntensity
def RescaleIntensity(
out_min_max:tuple, in_min_max:tuple
):
Apply TorchIO RescaleIntensity for robust intensity scaling.
Args:
    out_min_max (tuple[float, float]): Output intensity range (min, max)
    in_min_max (tuple[float, float]): Input intensity range (min, max)
Example for CT images:

    # Normalize CT from air (-1000 HU) to bone (1000 HU) into range (-1, 1)
    transform = RescaleIntensity(out_min_max=(-1, 1), in_min_max=(-1000, 1000))
NormalizeIntensity
def NormalizeIntensity(
nonzero:bool=True, channel_wise:bool=True, subtrahend:float=None, divisor:float=None
):
Apply MONAI NormalizeIntensity.
Args:
    nonzero (bool): Only normalize non-zero values (default: True)
    channel_wise (bool): Apply normalization per channel (default: True)
    subtrahend (float, optional): Value to subtract
    divisor (float, optional): Value to divide by
BraTSMaskConverter
def BraTSMaskConverter(
enc:NoneType=None, dec:NoneType=None, split_idx:NoneType=None, order:NoneType=None
):
Convert BraTS masks.
BinaryConverter
def BinaryConverter(
enc:NoneType=None, dec:NoneType=None, split_idx:NoneType=None, order:NoneType=None
):
Convert to binary mask.
RandomGhosting
def RandomGhosting(
intensity:tuple=(0.5, 1), p:float=0.5
):
Apply TorchIO RandomGhosting.
RandomSpike
def RandomSpike(
num_spikes:int=1, intensity:tuple=(1, 3), p:float=0.5
):
Apply TorchIO RandomSpike.
RandomNoise
def RandomNoise(
mean:int=0, std:tuple=(0, 0.25), p:float=0.5
):
Apply TorchIO RandomNoise.
RandomBiasField
def RandomBiasField(
coefficients:float=0.5, order:int=3, p:float=0.5
):
Apply TorchIO RandomBiasField.
RandomBlur
def RandomBlur(
std:tuple=(0, 2), p:float=0.5
):
Apply TorchIO RandomBlur.
RandomGamma
def RandomGamma(
log_gamma:tuple=(-0.3, 0.3), p:float=0.5
):
Apply TorchIO RandomGamma.
RandomIntensityScale
def RandomIntensityScale(
scale_range:tuple=(0.5, 2.0), p:float=0.5
):
Randomly scale image intensities by a multiplicative factor.
Useful for domain generalization across different acquisition protocols with varying intensity ranges.
Args:
    scale_range (tuple[float, float]): Range of scale factors (min, max). Values > 1 increase intensity, values < 1 decrease intensity.
    p (float): Probability of applying the transform (default: 0.5)
Example:

    # Scale intensities randomly between 0.5x and 2.0x
    transform = RandomIntensityScale(scale_range=(0.5, 2.0), p=0.3)
RandomMotion
def RandomMotion(
degrees:int=10, translation:int=10, num_transforms:int=2, image_interpolation:str='linear', p:float=0.5
):
Apply TorchIO RandomMotion.
RandomCutout
def RandomCutout(
holes:int=1, max_holes:int=3, spatial_size:int=8, max_spatial_size:int=16, fill:str='min', mask_only:bool=True,
p:float=0.2
):
Randomly erase spherical regions in 3D medical images with mask-aware placement.
Simulates post-operative surgical cavities by filling random ellipsoid volumes with specified values. When mask_only=True (default), cutouts only affect voxels inside the segmentation mask, ensuring no healthy tissue is modified.
Args:
    holes: Minimum number of cutout regions. Default: 1.
    max_holes: Maximum number of regions. Default: 3.
    spatial_size: Minimum cutout diameter in voxels. Default: 8.
    max_spatial_size: Maximum cutout diameter. Default: 16.
    fill: Fill value - 'min', 'mean', 'random', or a float. Default: 'min'.
    mask_only: If True, cutouts only affect mask-positive voxels (tumor tissue). If False, cutouts can affect any voxel (original behavior). Default: True.
    p: Probability of applying the transform. Default: 0.2.
Example:
    >>> # Simulate post-op cavities only within tumor regions
    >>> tfm = RandomCutout(holes=1, max_holes=2, spatial_size=10,
    ...                    max_spatial_size=25, fill='min', mask_only=True, p=0.2)

    >>> # Original behavior - cutouts anywhere in the volume
    >>> tfm = RandomCutout(mask_only=False, p=0.2)
Dictionary transforms
RandomElasticDeformation
def RandomElasticDeformation(
num_control_points:int=7, max_displacement:float=7.5, image_interpolation:str='linear', p:float=0.5
):
Apply TorchIO RandomElasticDeformation.
RandomAffine
def RandomAffine(
scales:int=0, degrees:int=10, translation:int=0, isotropic:bool=False, image_interpolation:str='linear',
default_pad_value:float=0.0, p:float=0.5
):
Apply TorchIO RandomAffine.
RandomFlip
def RandomFlip(
axes:str='LR', p:float=0.5
):
Apply TorchIO RandomFlip.
OneOf
def OneOf(
transform_dict, p:int=1
):
Apply only one of the given transforms using TorchIO OneOf.
# Test RandomCutout (ItemTransform - expects tuple input)
import numpy as np

# Create test data
test_img = MedImage(torch.randn(1, 32, 32, 32))
test_mask = MedMask(torch.zeros(1, 32, 32, 32))
test_mask[0, 10:20, 10:20, 10:20] = 1.0  # Tumor region

# Every encodes() call below receives *clones* of the test tensors, so
# test_img / test_mask stay pristine. If the transform ever modified its inputs
# in place, comparing a result against the very tensor that was passed in would
# make the "unchanged" checks pass vacuously. Later calls would also run on
# already-cut inputs.

# Test mask_only=True (default): only tumor voxels affected
cutout = RandomCutout(holes=1, spatial_size=8, fill='min', mask_only=True, p=1.0)
result_img, result_mask = cutout.encodes((test_img.clone(), test_mask.clone()))
test_eq(type(result_img), MedImage)
test_eq(type(result_mask), MedMask)
test_eq(result_img.shape, test_img.shape)
# Verify: healthy tissue (mask==0) unchanged (use numpy for comparison)
healthy_region = test_mask.numpy()[0] == 0
test_eq(np.array_equal(result_img.numpy()[0, healthy_region], test_img.numpy()[0, healthy_region]), True)
# Verify: mask unchanged
test_eq(torch.equal(result_mask, test_mask), True)

# Test empty mask skips cutout (mask_only=True)
empty_mask = MedMask(torch.zeros(1, 32, 32, 32))
result_img, _ = cutout.encodes((test_img.clone(), empty_mask.clone()))
test_eq(torch.equal(result_img, test_img), True)  # Unchanged

# Test mask_only=False: cutouts can affect any voxel
cutout_any = RandomCutout(mask_only=False, p=1.0)
result_img, _ = cutout_any.encodes((test_img.clone(), test_mask.clone()))
test_eq(result_img.shape, test_img.shape)

# Test with TensorCategory target (classification task with mask_only=False)
test_label = TensorCategory(1)
cutout_cls = RandomCutout(mask_only=False, p=1.0)
result_img, result_label = cutout_cls.encodes((test_img.clone(), test_label))
test_eq(type(result_img), MedImage)
test_eq(result_label, test_label)

# Test TensorCategory with mask_only=True skips cutout (no mask available)
cutout_mask_only = RandomCutout(mask_only=True, p=1.0)
result_img, result_label = cutout_mask_only.encodes((test_img.clone(), test_label))
test_eq(torch.equal(result_img, test_img), True)  # Unchanged - no mask to intersect

# tio_transform property
test_eq(isinstance(cutout.tio_transform, tio.IntensityTransform), True)

# Test fill modes with mask_only=False
for fill_mode in ['min', 'mean', 'random', 0.0]:
    cutout = RandomCutout(fill=fill_mode, mask_only=False, p=1.0)
    result_img, _ = cutout.encodes((test_img.clone(), test_mask.clone()))
    test_eq(result_img.shape, test_img.shape)