Patch-based training

Patch-based training and inference for 3D medical image segmentation using TorchIO’s Queue mechanism.

source

normalize_patch_transforms

 normalize_patch_transforms (tfms:list)

Normalize transforms for patch-based workflow.

Extracts underlying TorchIO transforms from fastMONAI wrappers. Also accepts raw TorchIO transforms for backward compatibility.

This enables using the same transform syntax in both standard and patch-based workflows:

>>> from fastMONAI.vision_augmentation import RandomAffine, RandomGamma
>>>
>>> # Same syntax works in both contexts
>>> item_tfms = [RandomAffine(degrees=10), RandomGamma(p=0.5)]   # Standard
>>> patch_tfms = [RandomAffine(degrees=10), RandomGamma(p=0.5)]  # Patch-based

Args:
- tfms: List of fastMONAI wrappers or raw TorchIO transforms.

Returns: List of raw TorchIO transforms suitable for tio.Compose()
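The unwrap-or-pass-through behavior described above is simple enough to sketch in pure Python. This is a conceptual stand-in, not fastMONAI's code: the `.tio_transform` attribute name comes from the docs above, but `_TioTransform`, `_FastMonaiWrapper`, and the function names here are hypothetical.

```python
class _TioTransform:
    """Stand-in for a raw TorchIO transform."""
    pass

class _FastMonaiWrapper:
    """Stand-in for a fastMONAI wrapper exposing `.tio_transform`."""
    def __init__(self, inner):
        self.tio_transform = inner

def extract_tio_transform(tfm):
    """Return the wrapped TorchIO transform, or the input unchanged."""
    return getattr(tfm, 'tio_transform', tfm)

def normalize_patch_transforms_sketch(tfms):
    """None passes through; otherwise unwrap each element."""
    if tfms is None:
        return None
    return [extract_tio_transform(t) for t in tfms]

raw = _TioTransform()
wrapped = _FastMonaiWrapper(_TioTransform())
out = normalize_patch_transforms_sketch([wrapped, raw])
# out[0] is the unwrapped inner transform; out[1] is the raw transform unchanged
```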

# Test _extract_tio_transform and normalize_patch_transforms
from fastMONAI.vision_augmentation import RandomAffine, RandomGamma, RandomFlip, RandomNoise

# Test extraction from fastMONAI wrappers via .tio_transform property
wrapped_affine = RandomAffine(degrees=10)
extracted = _extract_tio_transform(wrapped_affine)
test_eq(type(extracted), tio.RandomAffine)

wrapped_gamma = RandomGamma(p=0.5)
extracted = _extract_tio_transform(wrapped_gamma)
test_eq(type(extracted), tio.RandomGamma)

# Test passthrough for raw TorchIO transforms
raw_affine = tio.RandomAffine(degrees=10)
extracted = _extract_tio_transform(raw_affine)
test_eq(extracted, raw_affine)  # Should be the exact same object

raw_flip = tio.RandomFlip(p=0.5)
extracted = _extract_tio_transform(raw_flip)
test_eq(extracted, raw_flip)

# Test normalize_patch_transforms with list of fastMONAI wrappers
tfms = [RandomAffine(degrees=5), RandomGamma(p=0.5), RandomNoise(p=0.3)]
normalized = normalize_patch_transforms(tfms)
test_eq(len(normalized), 3)
test_eq(type(normalized[0]), tio.RandomAffine)
test_eq(type(normalized[1]), tio.RandomGamma)
test_eq(type(normalized[2]), tio.RandomNoise)

# Test normalize_patch_transforms with mixed list (wrappers + raw TorchIO)
mixed_tfms = [RandomAffine(degrees=5), tio.RandomGamma(p=0.5)]
normalized = normalize_patch_transforms(mixed_tfms)
test_eq(len(normalized), 2)
test_eq(type(normalized[0]), tio.RandomAffine)
test_eq(type(normalized[1]), tio.RandomGamma)

# Test normalize_patch_transforms with None
test_eq(normalize_patch_transforms(None), None)

Configuration


source

PatchConfig

 PatchConfig (patch_size:list=<factory>, patch_overlap:int|float|list=0,
              samples_per_volume:int=8, sampler_type:str='uniform',
              label_probabilities:dict=None, queue_length:int=300,
              queue_num_workers:int=4, aggregation_mode:str='hann',
              apply_reorder:bool=True, target_spacing:list=None,
              padding_mode:int|float|str=0,
              keep_largest_component:bool=False)

Configuration for patch-based training and inference.

Args:
- patch_size: Size of patches [x, y, z].
- patch_overlap: Overlap for inference GridSampler (int, float 0-1, or list).
  - Float 0-1: fraction of patch_size (e.g., 0.5 = 50% overlap)
  - Int >= 1: pixel overlap (e.g., 48 = 48 pixel overlap)
  - List: per-dimension overlap in pixels
- samples_per_volume: Number of patches to extract per volume during training.
- sampler_type: Type of sampler ('uniform', 'label', 'weighted').
- label_probabilities: For LabelSampler, dict mapping label values to probabilities.
- queue_length: Maximum number of patches to store in queue.
- queue_num_workers: Number of workers for parallel patch extraction.
- aggregation_mode: For inference, how to combine overlapping patches ('crop', 'average', 'hann').
- apply_reorder: Whether to reorder to RAS+ canonical orientation. Must match between training and inference. Defaults to True (the common case).
- target_spacing: Target voxel spacing [x, y, z] for resampling. Must match between training and inference.
- padding_mode: Padding mode for CropOrPad when image < patch_size. Default is 0 (zero padding) to align with nnU-Net's approach. Can be int, float, or string (e.g., 'minimum', 'mean').
- keep_largest_component: If True, keep only the largest connected component in binary segmentation predictions. Only applies during inference when return_probabilities=False. Defaults to False.
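The three accepted forms of patch_overlap can be illustrated with a small helper. This is a sketch of the semantics documented above, not the library's implementation; the name `resolve_overlap` is made up for illustration.

```python
def resolve_overlap(patch_size, overlap):
    """Resolve patch_overlap into per-dimension pixel overlaps
    (sketch of the documented semantics)."""
    if isinstance(overlap, float) and 0 <= overlap < 1:
        # Fraction of patch_size, e.g. 0.5 -> 50% overlap
        return [int(s * overlap) for s in patch_size]
    if isinstance(overlap, int):
        # Uniform pixel overlap in every dimension
        return [overlap] * len(patch_size)
    # Already a per-dimension list of pixel overlaps
    return list(overlap)

resolve_overlap([96, 96, 96], 0.5)           # -> [48, 48, 48]
resolve_overlap([96, 96, 96], 48)            # -> [48, 48, 48]
resolve_overlap([96, 96, 64], [48, 48, 32])  # -> [48, 48, 32]
```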

Example:

>>> config = PatchConfig(
...     patch_size=[96, 96, 96],
...     samples_per_volume=16,
...     sampler_type='label',
...     label_probabilities={0: 0.1, 1: 0.9},
...     target_spacing=[0.5, 0.5, 0.5]
... )

# Test PatchConfig
config = PatchConfig(patch_size=[96, 96, 96], samples_per_volume=16)
test_eq(config.patch_size, [96, 96, 96])
test_eq(config.samples_per_volume, 16)
test_eq(config.sampler_type, 'uniform')
test_eq(config.apply_reorder, True)  # Default is now True (the common case)
test_eq(config.target_spacing, None)
test_eq(config.padding_mode, 0)

# Test with preprocessing params
config2 = PatchConfig(
    patch_size=[64, 64, 64],
    apply_reorder=True,
    target_spacing=[0.5, 0.5, 0.5],
    padding_mode=0
)
test_eq(config2.apply_reorder, True)
test_eq(config2.target_spacing, [0.5, 0.5, 0.5])

Subject conversion


source

med_to_subject

 med_to_subject (img:pathlib.Path|str, mask:pathlib.Path|str=None)

Create TorchIO Subject with LAZY loading (paths only, no tensor loading).

This function stores file paths in the Subject, allowing TorchIO’s Queue workers to load volumes on-demand during training. This is memory-efficient as volumes are not loaded into RAM until needed.

Args:
- img: Path to image file.
- mask: Path to mask file (optional).

Returns: TorchIO Subject with ‘image’ and optionally ‘mask’ keys (lazy loaded).

Example:

>>> subject = med_to_subject('image.nii.gz', 'mask.nii.gz')
>>> # Volume NOT loaded yet - only path stored
>>> data = subject['image'].data  # NOW volume is loaded
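The lazy-loading behavior can be mimicked in a few lines of plain Python. This is a conceptual sketch only; TorchIO's ScalarImage/LabelMap implement the real version, and `LazyImage` and `fake_loader` are hypothetical names.

```python
class LazyImage:
    """Conceptual sketch of lazy loading: store the path now,
    load the voxel data only on first access."""
    def __init__(self, path, loader):
        self.path = path
        self._loader = loader  # e.g. a nibabel/TorchIO reader in practice
        self._data = None

    @property
    def data(self):
        if self._data is None:                   # first access triggers the load
            self._data = self._loader(self.path)
        return self._data

loads = []
def fake_loader(path):
    loads.append(path)  # record that a load happened
    return [[0.0]]      # dummy voxel data

img = LazyImage('image.nii.gz', fake_loader)
assert loads == []              # no load yet - only the path is stored
_ = img.data                    # NOW the volume is "loaded"
assert loads == ['image.nii.gz']
```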


source

create_subjects_dataset

 create_subjects_dataset (df:pandas.core.frame.DataFrame, img_col:str,
                          mask_col:str=None, pre_tfms:list=None,
                          ensure_affine_consistency:bool=True)

Build TorchIO SubjectsDataset with LAZY loading from DataFrame.

This function creates a SubjectsDataset that stores only file paths, not loaded tensors. Volumes are loaded on-demand by Queue workers, keeping memory usage constant regardless of dataset size.

Args:
- df: DataFrame with image (and optionally mask) paths.
- img_col: Column name containing image paths.
- mask_col: Column name containing mask paths (optional).
- pre_tfms: List of TorchIO transforms to apply before patch extraction. Use tio.ToCanonical() for reordering and tio.Resample() for resampling.
- ensure_affine_consistency: If True and mask_col is provided, automatically prepends tio.CopyAffine(target='image') to ensure spatial metadata consistency between image and mask. This prevents "More than one value for direction found" errors. Defaults to True.

Returns: TorchIO SubjectsDataset with lazy-loaded subjects.

Example:

>>> # Preprocessing via transforms (applied by workers on-demand)
>>> pre_tfms = [
...     tio.ToCanonical(),              # Reorder to RAS+
...     tio.Resample([0.5, 0.5, 0.5]),  # Resample
...     tio.ZNormalization(),           # Intensity normalization
... ]
>>> dataset = create_subjects_dataset(
...     df, img_col='image', mask_col='label',
...     pre_tfms=pre_tfms
... )
>>> # Memory: ~0 MB (only paths stored, not volumes)
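The ensure_affine_consistency logic amounts to prepending one transform when a mask column is present. Sketched here with transform names as strings, since only the control flow matters; `build_pre_tfms` is a hypothetical name, not fastMONAI's.

```python
def build_pre_tfms(pre_tfms, mask_col, ensure_affine_consistency=True):
    """Sketch of the transform-list assembly: CopyAffine(target='image')
    is prepended only when a mask is present and the flag is on."""
    tfms = list(pre_tfms or [])
    if ensure_affine_consistency and mask_col is not None:
        # Stands in for tio.CopyAffine(target='image')
        tfms.insert(0, "CopyAffine(target='image')")
    return tfms

build_pre_tfms(['ToCanonical', 'ZNormalization'], mask_col='label')
# -> ["CopyAffine(target='image')", 'ToCanonical', 'ZNormalization']
build_pre_tfms(['ToCanonical'], mask_col=None)
# -> ['ToCanonical']
```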

Sampler creation


source

create_patch_sampler

 create_patch_sampler (config:__main__.PatchConfig)

Create appropriate TorchIO sampler based on config.

Args: config: PatchConfig with sampler settings.

Returns: TorchIO PatchSampler instance.

Example:

>>> config = PatchConfig(patch_size=[96, 96, 96], sampler_type='label')
>>> sampler = create_patch_sampler(config)
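The dispatch on sampler_type is straightforward; here is a sketch using class names as strings. The real function returns TorchIO sampler instances, and as the tests below show, 'weighted' currently raises NotImplementedError. `pick_sampler` is an illustrative name only.

```python
def pick_sampler(sampler_type):
    """Sketch of the sampler_type dispatch described above."""
    samplers = {
        'uniform': 'tio.UniformSampler',  # uniform random patch centers
        'label':   'tio.LabelSampler',    # centers drawn per label_probabilities
    }
    if sampler_type == 'weighted':
        raise NotImplementedError('WeightedSampler is not supported yet')
    try:
        return samplers[sampler_type]
    except KeyError:
        raise ValueError(f'Unknown sampler_type: {sampler_type!r}')

pick_sampler('uniform')  # -> 'tio.UniformSampler'
```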

# Test sampler creation
config = PatchConfig(patch_size=[64, 64, 64], sampler_type='uniform')
sampler = create_patch_sampler(config)
test_eq(type(sampler), tio.UniformSampler)

# Test WeightedSampler raises NotImplementedError
from fastcore.test import test_fail
config_weighted = PatchConfig(patch_size=[64, 64, 64], sampler_type='weighted')
test_fail(lambda: create_patch_sampler(config_weighted), contains='WeightedSampler')
# Test utility functions
# Test _get_default_device
device = _get_default_device()
test_eq(type(device), torch.device)

# Test _warn_config_override doesn't raise when values match or when one is None
import warnings
with warnings.catch_warnings(record=True) as w:
    warnings.simplefilter("always")
    _warn_config_override('test_param', True, True)  # Same values - no warning
    _warn_config_override('test_param', True, None)  # Explicit is None - no warning
    _warn_config_override('test_param', None, True)  # Config is None - no warning
    test_eq(len(w), 0)

# Test _warn_config_override warns when values differ
with warnings.catch_warnings(record=True) as w:
    warnings.simplefilter("always")
    _warn_config_override('test_param', True, False)  # Different values - warning
    test_eq(len(w), 1)
    assert 'mismatch' in str(w[0].message)

Patch DataLoaders


source

MedPatchDataLoader

 MedPatchDataLoader (subjects_dataset:torchio.data.dataset.SubjectsDataset,
                     config:__main__.PatchConfig, batch_size:int=4,
                     patch_tfms:list=None, shuffle:bool=True,
                     drop_last:bool=False)

DataLoader wrapper for patch-based training with TorchIO Queue.

This class wraps a TorchIO Queue to provide a fastai-compatible DataLoader interface for patch-based training.

Args:
- subjects_dataset: TorchIO SubjectsDataset.
- config: PatchConfig with queue and sampler settings.
- batch_size: Number of patches per batch. Must be positive.
- patch_tfms: Transforms to apply to extracted patches (training only). Accepts both fastMONAI wrappers (e.g., RandomAffine, RandomGamma) and raw TorchIO transforms. fastMONAI wrappers are automatically normalized to raw TorchIO for internal use.
- shuffle: Whether to shuffle subjects and patches.
- drop_last: Whether to drop the last incomplete batch.
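With samples_per_volume patches drawn per subject, the number of batches per epoch follows directly. A sketch of that arithmetic, assuming the usual drop_last semantics; `batches_per_epoch` is an illustrative helper, not part of the API.

```python
import math

def batches_per_epoch(n_subjects, samples_per_volume, batch_size, drop_last=False):
    """Patches per epoch = subjects x samples_per_volume; batches follow
    from batch_size, with drop_last controlling the trailing partial batch."""
    n_patches = n_subjects * samples_per_volume
    if drop_last:
        return n_patches // batch_size
    return math.ceil(n_patches / batch_size)

batches_per_epoch(50, 8, 4)                  # -> 100
batches_per_epoch(50, 8, 3, drop_last=True)  # -> 133
```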


source

MedPatchDataLoaders

 MedPatchDataLoaders (train_dl:__main__.MedPatchDataLoader,
                      valid_dl:__main__.MedPatchDataLoader,
                      device:torch.device=None)

fastai-compatible DataLoaders for patch-based training with LAZY loading.

This class provides train and validation DataLoaders that work with fastai’s Learner for patch-based training on 3D medical images.

Memory-efficient: Volumes are loaded on-demand by Queue workers, keeping memory usage constant (~150 MB) regardless of dataset size.

Automatic padding: Images smaller than patch_size are automatically padded using SpatialPad (zero padding, nnU-Net standard). Dimensions larger than patch_size are preserved. A message is printed at DataLoader creation to inform you that automatic padding is enabled. This ensures training matches inference behavior where both pad small dimensions to minimum patch_size.
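The padding rule amounts to a per-dimension max; a sketch of the resulting shape (SpatialPad itself also decides where to place the original data within the padded volume, which is omitted here; `padded_shape` is an illustrative name).

```python
def padded_shape(image_shape, patch_size):
    """Each dimension is padded up to patch_size if smaller,
    and left untouched if already large enough."""
    return [max(s, p) for s, p in zip(image_shape, patch_size)]

padded_shape([80, 120, 64], [96, 96, 96])  # -> [96, 120, 96]
```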

Note: Validation uses the same sampling as training (pseudo Dice). For true validation metrics, use PatchInferenceEngine with GridSampler for full-volume sliding window inference.

Example:

>>> import torchio as tio
>>>
>>> # New pattern: preprocessing params in config (DRY)
>>> config = PatchConfig(
...     patch_size=[96, 96, 96],
...     apply_reorder=True,
...     target_spacing=[0.5, 0.5, 0.5]
... )
>>> dls = MedPatchDataLoaders.from_df(
...     df, img_col='image', mask_col='label',
...     valid_pct=0.2,
...     patch_config=config,
...     pre_patch_tfms=[tio.ZNormalization()],
...     bs=4
... )
>>> learn = Learner(dls, model, loss_func=DiceLoss())

Patch-based Inference


source

PatchInferenceEngine

 PatchInferenceEngine (learner, config:__main__.PatchConfig,
                       apply_reorder:bool=None, target_spacing:list=None,
                       batch_size:int=4, pre_inference_tfms:list=None)

Patch-based inference with automatic volume reconstruction.

Uses TorchIO’s GridSampler to extract overlapping patches and GridAggregator to reconstruct the full volume from predictions.

Args:
- learner: fastai Learner or PyTorch model (nn.Module). When passing a raw PyTorch model, load weights first with model.load_state_dict().
- config: PatchConfig with inference settings. Preprocessing params (apply_reorder, target_spacing, padding_mode) can be set here for DRY usage.
- apply_reorder: Whether to reorder to RAS+ orientation. If None, uses config value.
- target_spacing: Target voxel spacing. If None, uses config value.
- batch_size: Number of patches to predict at once. Must be positive.
- pre_inference_tfms: List of TorchIO transforms to apply before patch extraction. IMPORTANT: Should match the pre_patch_tfms used during training (e.g., [tio.ZNormalization()]). This ensures preprocessing consistency between training and inference. Accepts both fastMONAI wrappers and raw TorchIO transforms.
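Inference cost is driven by how many grid locations the GridSampler produces. A sketch of the per-dimension count under the standard sliding-window formula; this approximates TorchIO's exact placement (which snaps the last patch to the volume border), and the function names are illustrative.

```python
import math

def patches_along_dim(dim, patch, overlap):
    """Approximate number of patch positions along one dimension for a
    sliding window with the given pixel overlap (step = patch - overlap)."""
    if dim <= patch:
        return 1  # one (possibly padded) patch covers the whole dimension
    step = patch - overlap
    return math.ceil((dim - patch) / step) + 1

def total_patches(shape, patch_size, overlap):
    """Total grid locations = product of per-dimension counts."""
    n = 1
    for d, p in zip(shape, patch_size):
        n *= patches_along_dim(d, p, overlap)
    return n

patches_along_dim(256, 96, 48)                    # -> 5
total_patches([256, 256, 160], [96, 96, 96], 48)  # -> 5 * 5 * 3 = 75
```

Halving the overlap roughly halves the per-dimension count for large volumes, which is the main lever for trading inference speed against smoother patch blending.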

Example:

>>> # Option 1: From fastai Learner
>>> engine = PatchInferenceEngine(learn, config, pre_inference_tfms=[ZNormalization()])
>>> pred = engine.predict('image.nii.gz')

>>> # Option 2: From raw PyTorch model (recommended for deployment)
>>> model = UNet(spatial_dims=3, in_channels=1, out_channels=2, ...)
>>> model.load_state_dict(torch.load('weights.pth'))
>>> model.cuda().eval()
>>> engine = PatchInferenceEngine(model, config, pre_inference_tfms=[ZNormalization()])
>>> pred = engine.predict('image.nii.gz')
# Test PatchInferenceEngine correctly handles raw PyTorch models
# This verifies the isinstance fix - MONAI UNet has a .model attribute (internal Sequential)
# that should NOT be extracted when passing a raw model
from monai.networks.nets import UNet
from monai.networks.layers import Norm

test_model = UNet(
    spatial_dims=3, in_channels=1, out_channels=2,
    channels=(16, 32), strides=(2,), num_res_units=1,
    norm=Norm.INSTANCE
)
test_config = PatchConfig(patch_size=[32, 32, 32])

# Verify MONAI UNet has .model attribute (the bug scenario)
test_eq(hasattr(test_model, 'model'), True)
test_eq(type(test_model.model).__name__, 'Sequential')  # It's an internal Sequential

# Create engine with raw model - should store the UNet, NOT unet.model
engine = PatchInferenceEngine(test_model, test_config)
test_eq(type(engine.model), UNet)  # Should be UNet, not Sequential

print("PatchInferenceEngine raw model detection test passed!")
PatchInferenceEngine raw model detection test passed!

source

patch_inference

 patch_inference (learner, config:__main__.PatchConfig, file_paths:list,
                  apply_reorder:bool=None, target_spacing:list=None,
                  batch_size:int=4, return_probabilities:bool=False,
                  progress:bool=True, save_dir:str=None,
                  pre_inference_tfms:list=None)

Batch patch-based inference on multiple volumes.

Args:
- learner: PyTorch model or fastai Learner.
- config: PatchConfig with inference settings. Preprocessing params (apply_reorder, target_spacing) can be set here for DRY usage.
- file_paths: List of image paths.
- apply_reorder: Whether to reorder to RAS+ orientation. If None, uses config value.
- target_spacing: Target voxel spacing. If None, uses config value.
- batch_size: Patches per batch.
- return_probabilities: Return probability maps.
- progress: Show progress bar.
- save_dir: Directory to save predictions as NIfTI files. If None, predictions are not saved.
- pre_inference_tfms: List of TorchIO transforms to apply before patch extraction. IMPORTANT: Should match the pre_patch_tfms used during training (e.g., [tio.ZNormalization()]).

Returns: List of predicted tensors.

Example:

>>> # DRY pattern: use same config for training and inference
>>> config = PatchConfig(
...     patch_size=[96, 96, 96],
...     apply_reorder=True,
...     target_spacing=[0.4102, 0.4102, 1.5]
... )
>>> predictions = patch_inference(
...     learner=learn,
...     config=config,  # apply_reorder and target_spacing from config
...     file_paths=val_paths,
...     pre_inference_tfms=[tio.ZNormalization()],
...     save_dir='predictions/patch_based'
... )

Export