# Test _extract_tio_transform and normalize_patch_transforms
from fastMONAI.vision_augmentation import RandomAffine, RandomGamma, RandomFlip, RandomNoise

# --- _extract_tio_transform: unwrap fastMONAI wrappers via .tio_transform ---
wrapped_affine = RandomAffine(degrees=10)
extracted = _extract_tio_transform(wrapped_affine)
test_eq(type(extracted), tio.RandomAffine)

wrapped_gamma = RandomGamma(p=0.5)
extracted = _extract_tio_transform(wrapped_gamma)
test_eq(type(extracted), tio.RandomGamma)

# --- _extract_tio_transform: raw TorchIO transforms pass through unchanged ---
raw_affine = tio.RandomAffine(degrees=10)
extracted = _extract_tio_transform(raw_affine)
test_eq(extracted, raw_affine)  # Should be the exact same object

raw_flip = tio.RandomFlip(p=0.5)
extracted = _extract_tio_transform(raw_flip)
test_eq(extracted, raw_flip)

# --- normalize_patch_transforms: list of fastMONAI wrappers ---
tfms = [RandomAffine(degrees=5), RandomGamma(p=0.5), RandomNoise(p=0.3)]
normalized = normalize_patch_transforms(tfms)
test_eq(len(normalized), 3)
test_eq(type(normalized[0]), tio.RandomAffine)
test_eq(type(normalized[1]), tio.RandomGamma)
test_eq(type(normalized[2]), tio.RandomNoise)

# --- normalize_patch_transforms: mixed list (wrappers + raw TorchIO) ---
mixed_tfms = [RandomAffine(degrees=5), tio.RandomGamma(p=0.5)]
normalized = normalize_patch_transforms(mixed_tfms)
test_eq(len(normalized), 2)
test_eq(type(normalized[0]), tio.RandomAffine)
test_eq(type(normalized[1]), tio.RandomGamma)

# --- normalize_patch_transforms: None passes through as None ---
test_eq(normalize_patch_transforms(None), None)
# Patch-based training
normalize_patch_transforms
def normalize_patch_transforms(
tfms:list
)->list:
Normalize transforms for patch-based workflow.
Extracts underlying TorchIO transforms from fastMONAI wrappers. Also accepts raw TorchIO transforms for backward compatibility.
This enables using the same transform syntax in both standard and patch-based workflows:
>>> from fastMONAI.vision_augmentation import RandomAffine, RandomGamma
>>>
>>> # Same syntax works in both contexts
>>> item_tfms = [RandomAffine(degrees=10), RandomGamma(p=0.5)] # Standard
>>> patch_tfms = [RandomAffine(degrees=10), RandomGamma(p=0.5)] # Patch-based
Args: tfms: List of fastMONAI wrappers or raw TorchIO transforms
Returns: List of raw TorchIO transforms suitable for tio.Compose()
Configuration
PatchConfig
def PatchConfig(
patch_size:list=<factory>, patch_overlap:int | float | list=0, samples_per_volume:int=8,
sampler_type:str='uniform', label_probabilities:dict=None, queue_length:int=300, queue_num_workers:int=4,
aggregation_mode:str='hann', apply_reorder:bool=True, target_spacing:list=None, preprocessed:bool=False,
padding_mode:int | float | str=0, keep_largest_component:bool=False
)->None:
Configuration for patch-based training and inference.
Args: patch_size: Size of patches [x, y, z]. patch_overlap: Overlap for inference GridSampler (int, float 0-1, or list). - Float 0-1: fraction of patch_size (e.g., 0.5 = 50% overlap) - Int >= 1: pixel overlap (e.g., 48 = 48 pixel overlap) - List: per-dimension overlap in pixels samples_per_volume: Number of patches to extract per volume during training. sampler_type: Type of sampler (‘uniform’, ‘label’, ‘weighted’). label_probabilities: For LabelSampler, dict mapping label values to probabilities. queue_length: Maximum number of patches to store in queue. queue_num_workers: Number of workers for parallel patch extraction. aggregation_mode: For inference, how to combine overlapping patches (‘crop’, ‘average’, ‘hann’). apply_reorder: Whether to reorder to RAS+ canonical orientation. Must match between training and inference. Defaults to True (the common case). target_spacing: Target voxel spacing [x, y, z] for resampling. Must match between training and inference. preprocessed: If True, data has been preprocessed externally (e.g., via preprocess_dataset()). Training will skip reorder, resample, AND pre_patch_tfms (e.g., normalization) since they were already applied. Inference is unaffected and always applies pre_inference_tfms to raw images. Defaults to False. padding_mode: Padding mode for CropOrPad when image < patch_size. Default is 0 (zero padding) to align with nnU-Net’s approach. Can be int, float, or string (e.g., ‘minimum’, ‘mean’). keep_largest_component: If True, keep only the largest connected component in binary segmentation predictions. Only applies during inference when return_probabilities=False. Defaults to False.
Example: >>> config = PatchConfig( … patch_size=[96, 96, 96], … samples_per_volume=16, … sampler_type=‘label’, … label_probabilities={0: 0.1, 1: 0.9}, … target_spacing=[0.5, 0.5, 0.5] … )
# Test PatchConfig
import warnings  # used below; imported here so this cell is self-contained

# --- Default values ---
config = PatchConfig(patch_size=[96, 96, 96], samples_per_volume=16)
test_eq(config.patch_size, [96, 96, 96])
test_eq(config.samples_per_volume, 16)
test_eq(config.sampler_type, 'uniform')
test_eq(config.apply_reorder, True)  # Default is now True (the common case)
test_eq(config.target_spacing, None)
test_eq(config.preprocessed, False)  # Default is False
test_eq(config.padding_mode, 0)

# --- Explicit preprocessing params ---
config2 = PatchConfig(
    patch_size=[64, 64, 64],
    apply_reorder=True,
    target_spacing=[0.5, 0.5, 0.5],
    padding_mode=0
)
test_eq(config2.apply_reorder, True)
test_eq(config2.target_spacing, [0.5, 0.5, 0.5])

# --- preprocessed=True with actual preprocessing params (no warning) ---
config3 = PatchConfig(
    patch_size=[96, 96, 96],
    apply_reorder=True,
    target_spacing=[0.5, 0.5, 0.5],
    preprocessed=True
)
test_eq(config3.preprocessed, True)
test_eq(config3.apply_reorder, True)
test_eq(config3.target_spacing, [0.5, 0.5, 0.5])

# --- preprocessed=True WITHOUT preprocessing params must NOT warn ---
# (preprocessed=True still has effect: skips pre_patch_tfms during training)
with warnings.catch_warnings(record=True) as w:
    warnings.simplefilter("always")
    config4 = PatchConfig(
        patch_size=[96, 96, 96],
        apply_reorder=False,
        target_spacing=None,
        preprocessed=True
    )
    preprocessed_warns = [x for x in w if 'preprocessed' in str(x.message).lower()]
    test_eq(len(preprocessed_warns), 0)
# Subject conversion
med_to_subject
def med_to_subject(
img:pathlib.Path | str, mask:pathlib.Path | str=None
)->Subject:
Create TorchIO Subject with LAZY loading (paths only, no tensor loading).
This function stores file paths in the Subject, allowing TorchIO’s Queue workers to load volumes on-demand during training. This is memory-efficient as volumes are not loaded into RAM until needed.
Args: img: Path to image file. mask: Path to mask file (optional).
Returns: TorchIO Subject with ‘image’ and optionally ‘mask’ keys (lazy loaded).
Example: >>> subject = med_to_subject(‘image.nii.gz’, ‘mask.nii.gz’) >>> # Volume NOT loaded yet - only path stored >>> data = subject[‘image’].data # NOW volume is loaded
create_subjects_dataset
def create_subjects_dataset(
df:DataFrame, img_col:str, mask_col:str=None, pre_tfms:list=None, ensure_affine_consistency:bool=True
)->SubjectsDataset:
Build TorchIO SubjectsDataset with LAZY loading from DataFrame.
This function creates a SubjectsDataset that stores only file paths, not loaded tensors. Volumes are loaded on-demand by Queue workers, keeping memory usage constant regardless of dataset size.
Args: df: DataFrame with image (and optionally mask) paths. img_col: Column name containing image paths. mask_col: Column name containing mask paths (optional). pre_tfms: List of TorchIO transforms to apply before patch extraction. Use tio.ToCanonical() for reordering and tio.Resample() for resampling. ensure_affine_consistency: If True and mask_col is provided, automatically prepends tio.CopyAffine(target=‘image’) to ensure spatial metadata consistency between image and mask. This prevents “More than one value for direction found” errors. Defaults to True.
Returns: TorchIO SubjectsDataset with lazy-loaded subjects.
Example: >>> # Preprocessing via transforms (applied by workers on-demand) >>> pre_tfms = [ … tio.ToCanonical(), # Reorder to RAS+ … tio.Resample([0.5, 0.5, 0.5]), # Resample … tio.ZNormalization(), # Intensity normalization … ] >>> dataset = create_subjects_dataset( … df, img_col=‘image’, mask_col=‘label’, … pre_tfms=pre_tfms … ) >>> # Memory: ~0 MB (only paths stored, not volumes)
Sampler creation
create_patch_sampler
def create_patch_sampler(
config:PatchConfig
)->PatchSampler:
Create appropriate TorchIO sampler based on config.
Args: config: PatchConfig with sampler settings.
Returns: TorchIO PatchSampler instance.
Example: >>> config = PatchConfig(patch_size=[96, 96, 96], sampler_type=‘label’) >>> sampler = create_patch_sampler(config)
# --- Sampler creation tests ---
from fastcore.test import test_fail

# A 'uniform' sampler_type should produce a TorchIO UniformSampler.
config = PatchConfig(patch_size=[64, 64, 64], sampler_type='uniform')
sampler = create_patch_sampler(config)
test_eq(type(sampler), tio.UniformSampler)

# 'weighted' is not implemented and must raise NotImplementedError
# mentioning WeightedSampler.
config_weighted = PatchConfig(patch_size=[64, 64, 64], sampler_type='weighted')
test_fail(lambda: create_patch_sampler(config_weighted), contains='WeightedSampler')
# Test utility functions
# --- Utility function tests ---
import warnings

# _get_default_device returns a torch.device instance.
device = _get_default_device()
test_eq(type(device), torch.device)

# _warn_config_override stays silent when values match or either side is None.
with warnings.catch_warnings(record=True) as w:
    warnings.simplefilter("always")
    _warn_config_override('test_param', True, True)   # Same values - no warning
    _warn_config_override('test_param', True, None)   # Explicit is None - no warning
    _warn_config_override('test_param', None, True)   # Config is None - no warning
    test_eq(len(w), 0)

# _warn_config_override emits exactly one warning when values differ.
with warnings.catch_warnings(record=True) as w:
    warnings.simplefilter("always")
    _warn_config_override('test_param', True, False)  # Different values - warning
    test_eq(len(w), 1)
    assert 'mismatch' in str(w[0].message)
# Patch DataLoaders
MedPatchDataLoader
def MedPatchDataLoader(
subjects_dataset:SubjectsDataset, config:PatchConfig, batch_size:int=4, patch_tfms:list=None,
gpu_augmentation:NoneType=None, shuffle:bool=True, drop_last:bool=False
):
DataLoader wrapper for patch-based training with TorchIO Queue.
This class wraps a TorchIO Queue to provide a fastai-compatible DataLoader interface for patch-based training.
Args: subjects_dataset: TorchIO SubjectsDataset. config: PatchConfig with queue and sampler settings. batch_size: Number of patches per batch. Must be positive. patch_tfms: Transforms to apply to extracted patches (training only). Accepts both fastMONAI wrappers (e.g., RandomAffine, RandomGamma) and raw TorchIO transforms. fastMONAI wrappers are automatically normalized to raw TorchIO for internal use. Mutually exclusive with gpu_augmentation. gpu_augmentation: GpuPatchAugmentation instance for GPU-batched augmentation. Operates on [B,C,D,H,W] tensors already on GPU, avoiding per-sample CPU overhead. Mutually exclusive with patch_tfms. Training only. shuffle: Whether to shuffle subjects and patches. drop_last: Whether to drop last incomplete batch.
MedPatchDataLoaders
def MedPatchDataLoaders(
train_dl:MedPatchDataLoader, valid_dl:MedPatchDataLoader, device:device=None
):
fastai-compatible DataLoaders for patch-based training with LAZY loading.
This class provides train and validation DataLoaders that work with fastai’s Learner for patch-based training on 3D medical images.
Memory-efficient: Volumes are loaded on-demand by Queue workers, keeping memory usage constant (~150 MB) regardless of dataset size.
Note: Validation uses the same sampling as training (pseudo Dice). For true validation metrics, use PatchInferenceEngine with GridSampler for full-volume sliding window inference.
Example: >>> import torchio as tio >>> >>> # New pattern: preprocessing params in config (DRY) >>> config = PatchConfig( … patch_size=[96, 96, 96], … apply_reorder=True, … target_spacing=[0.5, 0.5, 0.5] … ) >>> dls = MedPatchDataLoaders.from_df( … df, img_col=‘image’, mask_col=‘label’, … valid_pct=0.2, … patch_config=config, … pre_patch_tfms=[tio.ZNormalization()], … bs=4 … ) >>> learn = Learner(dls, model, loss_func=DiceLoss())
# --- Mutual exclusivity of gpu_augmentation and patch_tfms ---
from fastMONAI.vision_augmentation import GpuPatchAugmentation

# Supplying both gpu_augmentation and patch_tfms must raise ValueError.
test_fail(
    lambda: MedPatchDataLoaders.from_df(
        pd.DataFrame({'img': ['fake.nii'], 'mask': ['fake.nii']}),
        img_col='img', mask_col='mask',
        patch_tfms=[tio.RandomFlip()],
        gpu_augmentation=GpuPatchAugmentation(flip={'axes': (0,), 'p': 0.5}),
    ),
    contains='Cannot use both'
)

# Verify gpu_augmentation is part of MedPatchDataLoader's __init__ signature.
# (We can't fully instantiate from_df without real files, so inspect the
# parameter names directly.)
test_eq(MedPatchDataLoader.__init__.__code__.co_varnames[:8],
        ('self', 'subjects_dataset', 'config', 'batch_size',
         'patch_tfms', 'gpu_augmentation', 'shuffle', 'drop_last'))
# Patch-based Inference
PatchInferenceEngine
def PatchInferenceEngine(
learner, config:PatchConfig, apply_reorder:bool=None, target_spacing:list=None, batch_size:int=4,
pre_inference_tfms:list=None
):
Patch-based inference with automatic volume reconstruction.
Uses TorchIO’s GridSampler to extract overlapping patches and GridAggregator to reconstruct the full volume from predictions.
Args: learner: fastai Learner or PyTorch model (nn.Module). When passing a raw PyTorch model, load weights first with model.load_state_dict(). config: PatchConfig with inference settings. Preprocessing params (apply_reorder, target_spacing, padding_mode) can be set here for DRY usage. apply_reorder: Whether to reorder to RAS+ orientation. If None, uses config value. target_spacing: Target voxel spacing. If None, uses config value. batch_size: Number of patches to predict at once. Must be positive. pre_inference_tfms: List of TorchIO transforms to apply before patch extraction. IMPORTANT: Should match the pre_patch_tfms used during training (e.g., [tio.ZNormalization()]). This ensures preprocessing consistency between training and inference. Accepts both fastMONAI wrappers and raw TorchIO transforms.
Example: >>> # Option 1: From fastai Learner >>> engine = PatchInferenceEngine(learn, config, pre_inference_tfms=[ZNormalization()]) >>> pred = engine.predict(‘image.nii.gz’)
>>> # Option 2: From raw PyTorch model (recommended for deployment)
>>> model = UNet(spatial_dims=3, in_channels=1, out_channels=2, ...)
>>> model.load_state_dict(torch.load('final_weights.pth'))
>>> model.cuda().eval()
>>> engine = PatchInferenceEngine(model, config, pre_inference_tfms=[ZNormalization()])
>>> pred = engine.predict('image.nii.gz')
# --- PatchInferenceEngine: raw PyTorch model handling ---
# Verifies the isinstance fix: MONAI's UNet exposes a .model attribute (an
# internal nn.Sequential) that must NOT be mistaken for a fastai Learner's
# .model when a raw nn.Module is passed in.
from monai.networks.nets import UNet
from monai.networks.layers import Norm

test_model = UNet(
    spatial_dims=3, in_channels=1, out_channels=2,
    channels=(16, 32), strides=(2,), num_res_units=1,
    norm=Norm.INSTANCE
)
test_config = PatchConfig(patch_size=[32, 32, 32])

# Confirm the bug scenario exists: MONAI UNet has an internal .model Sequential.
test_eq(hasattr(test_model, 'model'), True)
test_eq(type(test_model.model).__name__, 'Sequential')  # It's an internal Sequential

# The engine must store the UNet itself, NOT test_model.model.
engine = PatchInferenceEngine(test_model, test_config)
test_eq(type(engine.model), UNet)  # Should be UNet, not Sequential
print("PatchInferenceEngine raw model detection test passed!")
# Output: PatchInferenceEngine raw model detection test passed!
patch_inference
def patch_inference(
learner, config:PatchConfig, file_paths:list, apply_reorder:bool=None, target_spacing:list=None,
batch_size:int=4, return_probabilities:bool=False, progress:bool=True, save_dir:str=None,
pre_inference_tfms:list=None, tta:bool=False, prefetch:bool=True
)->list:
Batch patch-based inference on multiple volumes.
When prefetch=True (default), overlaps I/O with compute: while the current image is being inferred, the next image is loaded and preprocessed in a background thread, and the previous result is saved in the background. This eliminates most I/O idle time, especially on GPU where CPU prep and GPU compute use different hardware.
Args: learner: PyTorch model or fastai Learner. config: PatchConfig with inference settings. Preprocessing params (apply_reorder, target_spacing) can be set here for DRY usage. file_paths: List of image paths. apply_reorder: Whether to reorder to RAS+ orientation. If None, uses config value. target_spacing: Target voxel spacing. If None, uses config value. batch_size: Patches per batch. return_probabilities: Return probability maps. progress: Show progress bar. save_dir: Directory to save predictions as NIfTI files. If None, predictions are not saved. pre_inference_tfms: List of TorchIO transforms to apply before patch extraction. IMPORTANT: Should match the pre_patch_tfms used during training (e.g., [tio.ZNormalization()]). tta: If True, apply nnU-Net-style mirror TTA (8 flip combinations). prefetch: If True (default), overlap I/O with compute using a background thread for preparation and saving. Holds two subjects in memory simultaneously (current + next). Set to False for memory-constrained environments processing very large volumes.
Returns: List of predicted tensors.
Example: >>> config = PatchConfig( … patch_size=[96, 96, 96], … apply_reorder=True, … target_spacing=[0.4102, 0.4102, 1.5] … ) >>> predictions = patch_inference( … learner=learn, … config=config, # apply_reorder and target_spacing from config … file_paths=val_paths, … pre_inference_tfms=[tio.ZNormalization()], … save_dir=‘predictions/patch_based’ … )
# --- _TTA_FLIP_AXES and _predict_patch_tta tests ---
from itertools import combinations
import torch.nn as nn

# Test 1: _TTA_FLIP_AXES enumerates all 2^3 = 8 flip combinations of the
# three spatial axes (dims 2, 3, 4 of a [B, C, D, H, W] tensor).
test_eq(len(_TTA_FLIP_AXES), 8)
expected_combos = set()
axes = [2, 3, 4]
for r in range(len(axes) + 1):
    for combo in combinations(axes, r):
        expected_combos.add(combo)
actual_combos = set(tuple(sorted(a)) for a in _TTA_FLIP_AXES)
test_eq(actual_combos, expected_combos)


class _SimpleConv(nn.Module):
    """Minimal model for TTA testing."""
    def __init__(self, out_channels):
        super().__init__()
        self.conv = nn.Conv3d(1, out_channels, 1)

    def forward(self, x):
        return self.conv(x)


# Test 2a: binary case (1 output channel -> sigmoid probabilities in [0, 1]).
model_bin = _SimpleConv(1).eval()
dummy_input = torch.randn(2, 1, 8, 8, 8)  # [B=2, C=1, D, H, W]
with torch.no_grad():
    tta_out = _predict_patch_tta(model_bin, dummy_input)
test_eq(tta_out.shape, torch.Size([2, 1, 8, 8, 8]))
assert tta_out.min() >= 0.0 and tta_out.max() <= 1.0, f"Probabilities out of range: [{tta_out.min()}, {tta_out.max()}]"

# Test 2b: multi-class case (3 output channels -> softmax probabilities).
model_mc = _SimpleConv(3).eval()
with torch.no_grad():
    tta_out_mc = _predict_patch_tta(model_mc, dummy_input)
test_eq(tta_out_mc.shape, torch.Size([2, 3, 8, 8, 8]))
assert tta_out_mc.min() >= 0.0 and tta_out_mc.max() <= 1.0

# Test 3: a constant tensor is invariant to flipping, so TTA must equal a
# single forward pass.
const_input = torch.ones(1, 1, 8, 8, 8) * 0.5
with torch.no_grad():
    single_logits = model_bin(const_input)
    single_probs = torch.sigmoid(single_logits).cpu()
    tta_probs = _predict_patch_tta(model_bin, const_input)
assert torch.allclose(single_probs, tta_probs, atol=1e-6), "TTA on constant input should match single forward pass"
print("TTA tests passed!")
# Output: TTA tests passed!
# --- _PreparedSubject and decomposed predict path ---
import tempfile, os, nibabel as nib
from monai.networks.nets import UNet
from monai.networks.layers import Norm

# Small synthetic NIfTI volume to drive the full inference pipeline.
_test_data = np.random.randn(32, 32, 32).astype(np.float32)
_test_affine = np.eye(4)
_test_nii = nib.Nifti1Image(_test_data, _test_affine)

with tempfile.TemporaryDirectory() as tmpdir:
    img_path = os.path.join(tmpdir, 'test_img.nii.gz')
    nib.save(_test_nii, img_path)
    _model = UNet(
        spatial_dims=3, in_channels=1, out_channels=2,
        channels=(16, 32), strides=(2,), num_res_units=1,
        norm=Norm.INSTANCE
    ).eval()
    _config = PatchConfig(patch_size=[32, 32, 32])
    _engine = PatchInferenceEngine(_model, _config, apply_reorder=False)

    # Test 1: _prepare_subject returns a _PreparedSubject carrying everything
    # needed for grid-based inference.
    prepared = _engine._prepare_subject(img_path)
    assert isinstance(prepared, _PreparedSubject), "Should return _PreparedSubject"
    assert isinstance(prepared.subject, tio.Subject)
    assert isinstance(prepared.grid_sampler, tio.GridSampler)
    assert isinstance(prepared.aggregator, tio.GridAggregator)
    assert isinstance(prepared.patch_loader, DataLoader)
    assert prepared.org_size is not None

    # Test 2: manual prepare -> run -> postprocess equals one-shot predict().
    pred_decomposed, affine_decomposed = _engine._postprocess(
        _engine._run_inference(
            _engine._prepare_subject(img_path)
        ),
        _engine._prepare_subject(img_path)
    )
    pred_predict, affine_predict = _engine.predict(img_path, return_affine=True)
    assert torch.equal(pred_decomposed, pred_predict), "Decomposed path should match predict()"
    assert np.array_equal(affine_decomposed, affine_predict), "Affine should match"

    # Test 3: prefetch=True produces identical results to prefetch=False.
    paths = [img_path, img_path]  # Two copies to trigger pipeline path
    preds_prefetch = patch_inference(
        _model, _config, paths, apply_reorder=False, progress=False, prefetch=True
    )
    preds_sequential = patch_inference(
        _model, _config, paths, apply_reorder=False, progress=False, prefetch=False
    )
    assert len(preds_prefetch) == len(preds_sequential) == 2
    for p1, p2 in zip(preds_prefetch, preds_sequential):
        assert torch.equal(p1, p2), "prefetch=True should produce identical results"

    # Test 4: error propagation -- a missing file raises instead of being
    # silently swallowed by the prefetch pipeline.
    test_fail(
        lambda: patch_inference(
            _model, _config, ['/nonexistent/file.nii.gz'],
            apply_reorder=False, progress=False, prefetch=True
        )
    )
    test_fail(
        lambda: patch_inference(
            _model, _config, [img_path, '/nonexistent/file.nii.gz'],
            apply_reorder=False, progress=False, prefetch=True
        )
    )

    # Test 5: save pipeline works with prefetch=True; identical inputs
    # collapse to one unique output file.
    save_dir = os.path.join(tmpdir, 'preds')
    preds_saved = patch_inference(
        _model, _config, paths, apply_reorder=False, progress=False,
        save_dir=save_dir, prefetch=True
    )
    assert len(preds_saved) == 2
    saved_files = list(Path(save_dir).glob('*.nii.gz'))
    assert len(saved_files) == 1, f"Expected 1 unique file (same input), got {len(saved_files)}"
    # Verify the saved file is valid NIfTI.
    saved_nii = nib.load(str(saved_files[0]))
    assert saved_nii.shape is not None
    print("Pipeline inference tests passed!")
# Output: Pipeline inference tests passed!
# Output (warning): apply_reorder mismatch: explicit=False, config=True. Using explicit argument.