Patch-based training

Patch-based training and inference for 3D medical image segmentation using TorchIO’s Queue mechanism.

source

normalize_patch_transforms


def normalize_patch_transforms(
    tfms:list
)->list:

Extract raw TorchIO transforms from a list of fastMONAI wrappers (raw TorchIO transforms pass through).

Lets the same transform syntax work in both standard and patch-based workflows.

Args: tfms: List of fastMONAI wrappers or raw TorchIO transforms.

Returns: List of raw TorchIO transforms suitable for tio.Compose(), or None if tfms is None.

Example: >>> patch_tfms = normalize_patch_transforms([RandomAffine(degrees=10), RandomGamma(p=0.5)])

# Test _extract_tio_transform and normalize_patch_transforms
from fastMONAI.vision_augmentation import RandomAffine, RandomGamma, RandomFlip, RandomNoise

# fastMONAI wrappers are unwrapped to the matching raw TorchIO transform type
wrapped_affine = RandomAffine(degrees=10)
extracted = _extract_tio_transform(wrapped_affine)
test_eq(type(extracted), tio.RandomAffine)

wrapped_gamma = RandomGamma(p=0.5)
extracted = _extract_tio_transform(wrapped_gamma)
test_eq(type(extracted), tio.RandomGamma)

# Raw TorchIO transforms pass through untouched. Check identity explicitly:
# test_eq only checks equality, which would also accept an equal copy.
raw_affine = tio.RandomAffine(degrees=10)
extracted = _extract_tio_transform(raw_affine)
assert extracted is raw_affine, "Raw TorchIO transform should be returned as the exact same object"

raw_flip = tio.RandomFlip(p=0.5)
extracted = _extract_tio_transform(raw_flip)
assert extracted is raw_flip, "Raw TorchIO transform should be returned as the exact same object"

# List of wrappers -> list of raw TorchIO transforms, order preserved
tfms = [RandomAffine(degrees=5), RandomGamma(p=0.5), RandomNoise(p=0.3)]
normalized = normalize_patch_transforms(tfms)
test_eq(len(normalized), 3)
test_eq(type(normalized[0]), tio.RandomAffine)
test_eq(type(normalized[1]), tio.RandomGamma)
test_eq(type(normalized[2]), tio.RandomNoise)

# Mixed list (wrappers + raw TorchIO)
mixed_tfms = [RandomAffine(degrees=5), tio.RandomGamma(p=0.5)]
normalized = normalize_patch_transforms(mixed_tfms)
test_eq(len(normalized), 2)
test_eq(type(normalized[0]), tio.RandomAffine)
test_eq(type(normalized[1]), tio.RandomGamma)

# None is passed through as None (no transforms)
test_eq(normalize_patch_transforms(None), None)

Configuration


source

PatchConfig


def PatchConfig(
    patch_size:list=<factory>, patch_overlap:int | float | list=0, samples_per_volume:int=8,
    sampler_type:str='uniform', label_probabilities:dict=None, queue_length:int=300, queue_num_workers:int=4,
    aggregation_mode:str='hann', apply_reorder:bool=True, target_spacing:list=None, preprocessed:bool=False,
    padding_mode:int | float | str=0, keep_largest_component:bool=False, binary_threshold:float=0.5,
    normalization:list=None
)->None:

Configuration for patch-based training and inference.

Args: patch_size: Size of patches [x, y, z]. A warning is emitted when the patch size is not divisible by 16. patch_overlap: Overlap for inference GridSampler (int, float 0-1, or list). Must be non-negative and smaller than patch_size in every dimension; odd overlaps are rounded down to the nearest even number of voxels (e.g., 47 -> 46, 0.5 of 98 -> 48). - Float 0-1: fraction of patch_size (e.g., 0.5 = 50% overlap) - Int >= 1: voxel overlap (e.g., 48 = 48-voxel overlap) - List: per-dimension overlap in voxels samples_per_volume: Number of patches to extract per volume during training. sampler_type: Type of sampler (‘uniform’, ‘label’, ‘weighted’). Note that ‘weighted’ passes config validation but is not implemented yet: create_patch_sampler raises an error for it. label_probabilities: For LabelSampler, dict mapping label values to probabilities. queue_length: Maximum number of patches to store in queue. queue_num_workers: Number of workers for parallel patch extraction. aggregation_mode: For inference, how to combine overlapping patches (‘crop’, ‘average’, ‘hann’). apply_reorder: Whether to reorder to RAS+ canonical orientation. Must match between training and inference. Defaults to True (the common case). target_spacing: Target voxel spacing [x, y, z] for resampling. Must match between training and inference. preprocessed: If True, data has been preprocessed externally (e.g., via preprocess_dataset()). Training will skip reorder, resample, AND pre_patch_tfms (e.g., normalization) since they were already applied. Inference is unaffected and always applies pre_inference_tfms to raw images. Defaults to False. padding_mode: Padding mode for CropOrPad when image < patch_size. Default is 0 (zero padding). Can be int, float, or string (e.g., ‘minimum’, ‘mean’). keep_largest_component: If True, keep only the largest connected component in binary segmentation predictions. Only applies during inference when return_probabilities=False. Defaults to False. binary_threshold: Decision boundary for single-channel (sigmoid) masks; a voxel is foreground when probability >= binary_threshold (matches MONAI AsDiscrete). Values outside the probability range (e.g., 1.5) are rejected. Only applies when return_probabilities=False and n_classes == 1. Defaults to 0.5. normalization: Single source of truth for pre-patch / pre-inference intensity normalization. A list of fastMONAI transforms (e.g. 
[ZNormalization(masking_method=‘foreground’)]) or JSON spec dicts; coerced to specs and persisted with the config. Read by both MedPatchDataLoaders.from_df (training) and PatchInferenceEngine (inference). The manual pre_patch_tfms / pre_inference_tfms args override this when provided.

Example: >>> config = PatchConfig( … patch_size=[96, 96, 96], … samples_per_volume=16, … sampler_type=‘label’, … label_probabilities={0: 0.1, 1: 0.9}, … target_spacing=[0.5, 0.5, 0.5] … )

import warnings

# Defaults: only patch_size and samples_per_volume are set explicitly
config = PatchConfig(patch_size=[96, 96, 96], samples_per_volume=16)
test_eq(config.patch_size, [96, 96, 96])
test_eq(config.samples_per_volume, 16)
test_eq(config.sampler_type, 'uniform')
test_eq(config.apply_reorder, True)
test_eq(config.target_spacing, None)
test_eq(config.preprocessed, False)
test_eq(config.padding_mode, 0)
test_eq(config.aggregation_mode, 'hann')
test_eq(config.keep_largest_component, False)
test_eq(config.binary_threshold, 0.5)

# Explicit preprocessing params are stored as given
config2 = PatchConfig(
    patch_size=[64, 64, 64],
    apply_reorder=True,
    target_spacing=[0.5, 0.5, 0.5],
    padding_mode=0
)
test_eq(config2.apply_reorder, True)
test_eq(config2.target_spacing, [0.5, 0.5, 0.5])

# preprocessed=True with preprocessing params: no warning.
# Capture warnings so the "no warning" claim is actually checked (it was previously unverified).
with warnings.catch_warnings(record=True) as w:
    warnings.simplefilter("always")
    config3 = PatchConfig(
        patch_size=[96, 96, 96],
        apply_reorder=True,
        target_spacing=[0.5, 0.5, 0.5],
        preprocessed=True
    )
    preprocessed_warns = [x for x in w if 'preprocessed' in str(x.message).lower()]
    test_eq(len(preprocessed_warns), 0)
test_eq(config3.preprocessed, True)
test_eq(config3.apply_reorder, True)
test_eq(config3.target_spacing, [0.5, 0.5, 0.5])

# preprocessed=True without preprocessing params does NOT warn
# (preprocessed=True still has effect: skips pre_patch_tfms during training)
with warnings.catch_warnings(record=True) as w:
    warnings.simplefilter("always")
    config4 = PatchConfig(
        patch_size=[96, 96, 96],
        apply_reorder=False,
        target_spacing=None,
        preprocessed=True
    )
    preprocessed_warns = [x for x in w if 'preprocessed' in str(x.message).lower()]
    test_eq(len(preprocessed_warns), 0)

Subject conversion


source

med_to_subject


def med_to_subject(
    img:pathlib.Path | str, mask:pathlib.Path | str=None
)->Subject:

Create a TorchIO Subject with LAZY loading (stores paths only, no tensors).

Storing only paths lets TorchIO’s Queue workers load volumes on-demand during training, keeping RAM usage low.

Args: img: Path to image file. mask: Path to mask file (optional).

Returns: TorchIO Subject with ‘image’ and optionally ‘mask’ keys (lazy loaded).

Example: >>> subject = med_to_subject(‘image.nii.gz’, ‘mask.nii.gz’) >>> data = subject[‘image’].data # volume loaded only now


source

create_subjects_dataset


def create_subjects_dataset(
    df:DataFrame, img_col:str, mask_col:str=None, pre_tfms:list=None, ensure_affine_consistency:bool=True
)->SubjectsDataset:

Build a TorchIO SubjectsDataset with LAZY loading from a DataFrame.

Stores only file paths (not tensors); volumes are loaded on-demand by Queue workers, keeping memory usage constant regardless of dataset size.

Args: df: DataFrame with image (and optionally mask) paths. img_col: Column name containing image paths. mask_col: Column name containing mask paths (optional). pre_tfms: List of TorchIO transforms to apply before patch extraction. Use tio.ToCanonical() for reordering and tio.Resample() for resampling. ensure_affine_consistency: If True and mask_col is provided, automatically prepends tio.CopyAffine(target=‘image’) to ensure spatial metadata consistency between image and mask. This prevents “More than one value for direction found” errors. Defaults to True.

Returns: TorchIO SubjectsDataset with lazy-loaded subjects.

Example: >>> pre_tfms = [tio.ToCanonical(), tio.Resample([0.5, 0.5, 0.5]), tio.ZNormalization()] >>> dataset = create_subjects_dataset(df, img_col=‘image’, mask_col=‘label’, pre_tfms=pre_tfms)

Sampler creation


source

create_patch_sampler


def create_patch_sampler(
    config:PatchConfig
)->PatchSampler:

Create appropriate TorchIO sampler based on config.

Args: config: PatchConfig with sampler settings.

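Returns: TorchIO PatchSampler instance.

Raises: NotImplementedError if config.sampler_type is ‘weighted’ (WeightedSampler is not supported yet).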

Example: >>> config = PatchConfig(patch_size=[96, 96, 96], sampler_type=‘label’) >>> sampler = create_patch_sampler(config)

# sampler_type='uniform' -> tio.UniformSampler
config = PatchConfig(patch_size=[64, 64, 64], sampler_type='uniform')
sampler = create_patch_sampler(config)
test_eq(type(sampler), tio.UniformSampler)

# WeightedSampler raises NotImplementedError
# (test_fail only checks that the error message mentions 'WeightedSampler', not the exception type)
from fastcore.test import test_fail
config_weighted = PatchConfig(patch_size=[64, 64, 64], sampler_type='weighted')
test_fail(lambda: create_patch_sampler(config_weighted), contains='WeightedSampler')
# _get_default_device returns a torch.device instance
device = _get_default_device()
test_eq(type(device), torch.device)

# _warn_config_override: no warning when values match or one is None
# Argument order, as exercised by the inline comments below: (name, config_value, explicit_value)
import warnings
with warnings.catch_warnings(record=True) as w:
    warnings.simplefilter("always")
    _warn_config_override('test_param', True, True)  # Same values - no warning
    _warn_config_override('test_param', True, None)  # Explicit is None - no warning
    _warn_config_override('test_param', None, True)  # Config is None - no warning
    test_eq(len(w), 0)

# _warn_config_override warns when values differ
with warnings.catch_warnings(record=True) as w:
    warnings.simplefilter("always")
    _warn_config_override('test_param', True, False)  # Different values - warning
    test_eq(len(w), 1)
    assert 'mismatch' in str(w[0].message)

Patch DataLoaders


source

MedPatchDataLoader


def MedPatchDataLoader(
    subjects_dataset:SubjectsDataset, config:PatchConfig, batch_size:int=4, patch_tfms:list=None,
    gpu_augmentation:NoneType=None, shuffle:bool=True, drop_last:bool=False
):

DataLoader wrapper for patch-based training with TorchIO Queue.

This class wraps a TorchIO Queue to provide a fastai-compatible DataLoader interface for patch-based training.

Args: subjects_dataset: TorchIO SubjectsDataset. config: PatchConfig with queue and sampler settings. batch_size: Number of patches per batch. Must be positive. patch_tfms: Transforms to apply to extracted patches (training only). Accepts both fastMONAI wrappers (e.g., RandomAffine, RandomGamma) and raw TorchIO transforms. fastMONAI wrappers are automatically normalized to raw TorchIO for internal use. Mutually exclusive with gpu_augmentation. gpu_augmentation: GpuPatchAugmentation instance for GPU-batched augmentation. Operates on [B,C,D,H,W] tensors already on GPU, avoiding per-sample CPU overhead. Mutually exclusive with patch_tfms. Training only. shuffle: Whether to shuffle subjects and patches. drop_last: Whether to drop last incomplete batch.


source

MedPatchDataLoaders


def MedPatchDataLoaders(
    train_dl:MedPatchDataLoader, valid_dl:MedPatchDataLoader, device:device=None
):

fastai-compatible DataLoaders for patch-based training with LAZY loading.

This class provides train and validation DataLoaders that work with fastai’s Learner for patch-based training on 3D medical images.

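Memory-efficient: Volumes are loaded on-demand by Queue workers, so memory usage stays roughly constant (~150 MB in typical setups) regardless of dataset size. The footprint is governed by queue_length and patch_size rather than by the number of volumes.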

Note: Validation patches are sampled the same way as training patches, so validation metrics are patch-level estimates (“pseudo Dice”). For true validation metrics, use PatchInferenceEngine with GridSampler for full-volume sliding window inference.

Example: >>> config = PatchConfig(patch_size=[96, 96, 96], apply_reorder=True, target_spacing=[0.5, 0.5, 0.5]) >>> dls = MedPatchDataLoaders.from_df( … df, img_col=‘image’, mask_col=‘label’, valid_pct=0.2, … patch_config=config, pre_patch_tfms=[tio.ZNormalization()], bs=4 … ) >>> learn = Learner(dls, model, loss_func=DiceLoss())

import inspect
from fastMONAI.vision_augmentation import GpuPatchAugmentation

# Both gpu_augmentation and patch_tfms -> ValueError (mutually exclusive)
test_fail(
    lambda: MedPatchDataLoaders.from_df(
        pd.DataFrame({'img': ['fake.nii'], 'mask': ['fake.nii']}),
        img_col='img', mask_col='mask',
        patch_tfms=[tio.RandomFlip()],
        gpu_augmentation=GpuPatchAugmentation(flip={'axes': (0,), 'p': 0.5}),
    ),
    contains='Cannot use both'
)

# How gpu_augmentation is wired into train_dl vs valid_dl needs real image files,
# so this only pins the MedPatchDataLoader constructor interface (names and order).
# inspect.signature follows __wrapped__, so it stays correct if __init__ is decorated,
# unlike __code__.co_varnames.
test_eq(tuple(inspect.signature(MedPatchDataLoader.__init__).parameters)[:8],
        ('self', 'subjects_dataset', 'config', 'batch_size',
         'patch_tfms', 'gpu_augmentation', 'shuffle', 'drop_last'))

Patch-based Inference


source

PatchInferenceEngine


def PatchInferenceEngine(
    learner, config:PatchConfig, apply_reorder:bool=None, target_spacing:list=None, batch_size:int=4,
    pre_inference_tfms:list=None, amp:bool=False
):

Patch-based inference with automatic volume reconstruction.

Uses TorchIO’s GridSampler to extract overlapping patches and GridAggregator to reconstruct the full volume from predictions.

Args: learner: fastai Learner or PyTorch model (nn.Module). When passing a raw PyTorch model, load weights first with model.load_state_dict(). config: PatchConfig with inference settings. Preprocessing params (apply_reorder, target_spacing, padding_mode) can be set here for DRY usage. apply_reorder: Whether to reorder to RAS+ orientation. If None, uses config value. target_spacing: Target voxel spacing. If None, uses config value. batch_size: Number of patches to predict at once. Must be positive. pre_inference_tfms: Optional override for config.normalization. If None, normalization is read from config.normalization (the default source of truth, set at training time). Provide this only to override the config (e.g. for non-serializable transforms). Accepts both fastMONAI wrappers and raw TorchIO transforms. amp: If True, use automatic mixed precision (float16) for the forward pass. Only supported on CUDA devices; ignored with a warning on CPU/MPS. Defaults to False.

Example: >>> # normalization is read from config.normalization; amp=True for faster GPU >>> config = PatchConfig(patch_size=[96, 96, 96], normalization=[ZNormalization(masking_method=‘foreground’)]) >>> engine = PatchInferenceEngine(learn, config) >>> pred = engine.predict(‘image.nii.gz’)

# PatchInferenceEngine must store a raw model as-is, not its internal .model
# The engine accepts either a fastai Learner (network under `.model`) or a raw nn.Module;
# a raw module that happens to have its own `.model` attribute must not be unwrapped.
from monai.networks.nets import UNet
from monai.networks.layers import Norm

test_model = UNet(
    spatial_dims=3, in_channels=1, out_channels=2,
    channels=(16, 32), strides=(2,), num_res_units=1,
    norm=Norm.INSTANCE
)
test_config = PatchConfig(patch_size=[32, 32, 32])

# MONAI UNet has a .model attribute (the bug scenario)
test_eq(hasattr(test_model, 'model'), True)
test_eq(type(test_model.model).__name__, 'Sequential')  # It's an internal Sequential

# Raw model -> engine.model is the UNet, NOT unet.model
engine = PatchInferenceEngine(test_model, test_config)
test_eq(type(engine.model), UNet)  # Should be UNet, not Sequential

print("PatchInferenceEngine raw model detection test passed!")
PatchInferenceEngine raw model detection test passed!

source

patch_inference


def patch_inference(
    learner, config:PatchConfig, file_paths:list, apply_reorder:bool=None, target_spacing:list=None,
    batch_size:int=4, return_probabilities:bool=False, progress:bool=True, save_dir:str=None,
    pre_inference_tfms:list=None, tta:bool=False, prefetch:bool=True, amp:bool=False
)->list:

Batch patch-based inference on multiple volumes.

When prefetch=True (default), overlaps I/O with compute: while the current image is being inferred, the next image is loaded and preprocessed in a background thread, and the previous result is saved in the background. This eliminates most I/O idle time, especially on GPU where CPU prep and GPU compute use different hardware.

Args: learner: PyTorch model or fastai Learner. config: PatchConfig with inference settings. Preprocessing params (apply_reorder, target_spacing) can be set here for DRY usage. file_paths: List of image paths. apply_reorder: Whether to reorder to RAS+ orientation. If None, uses config value. target_spacing: Target voxel spacing. If None, uses config value. batch_size: Patches per batch. return_probabilities: Return probability maps. progress: Show progress bar. save_dir: Directory to save predictions as NIfTI files. If None, predictions are not saved. pre_inference_tfms: Optional override for config.normalization. If None, normalization is read from config.normalization (the source of truth set at training time). Provide this only to override the config (e.g. for non-serializable transforms). tta: If True, mirror TTA (8 flip combinations). prefetch: If True (default), overlap I/O with compute using a background thread for preparation and saving. Holds two subjects in memory simultaneously (current + next). Set to False for memory-constrained environments processing very large volumes. amp: If True, use automatic mixed precision (float16) for the forward pass. Only supported on CUDA devices; ignored with a warning on CPU/MPS.

Returns: List of predicted tensors.

Example: >>> config = PatchConfig(patch_size=[96, 96, 96], apply_reorder=True, target_spacing=[0.4102, 0.4102, 1.5], … normalization=[ZNormalization(masking_method=‘foreground’)]) >>> predictions = patch_inference( # normalization auto-applied from config … learner=learn, config=config, file_paths=val_paths, … save_dir=‘predictions/patch_based’, amp=True … )

# Test _TTA_FLIP_AXES and _predict_patch_tta
from itertools import combinations

# _TTA_FLIP_AXES has exactly 8 entries (2^3 combinations for 3 axes)
test_eq(len(_TTA_FLIP_AXES), 8)

# All 2^3 combinations present (each axis in {2,3,4} independently on/off)
# Axes 2, 3, 4 are the spatial dims of a [B, C, D, H, W] batch; the empty tuple is "no flip".
expected_combos = set()
axes = [2, 3, 4]
for r in range(len(axes) + 1):
    for combo in combinations(axes, r):
        expected_combos.add(combo)
actual_combos = set(tuple(sorted(a)) for a in _TTA_FLIP_AXES)
test_eq(actual_combos, expected_combos)

# _predict_patch_tta output shape and probability range
import torch.nn as nn

class _SimpleConv(nn.Module):
    """Minimal model for TTA testing."""
    def __init__(self, out_channels):
        super().__init__()
        # 1x1x1 conv: each output voxel depends only on the same input voxel
        self.conv = nn.Conv3d(1, out_channels, 1)
    def forward(self, x):
        return self.conv(x)

# Binary case (1 output channel -> sigmoid)
model_bin = _SimpleConv(1).eval()
dummy_input = torch.randn(2, 1, 8, 8, 8)  # [B=2, C=1, D, H, W]
with torch.no_grad():
    tta_out = _predict_patch_tta(model_bin, dummy_input)
test_eq(tta_out.shape, torch.Size([2, 1, 8, 8, 8]))
assert tta_out.min() >= 0.0 and tta_out.max() <= 1.0, f"Probabilities out of range: [{tta_out.min()}, {tta_out.max()}]"

# Multi-class case (3 output channels -> softmax)
model_mc = _SimpleConv(3).eval()
with torch.no_grad():
    tta_out_mc = _predict_patch_tta(model_mc, dummy_input)
test_eq(tta_out_mc.shape, torch.Size([2, 3, 8, 8, 8]))
assert tta_out_mc.min() >= 0.0 and tta_out_mc.max() <= 1.0

# TTA on constant input matches single forward pass
# A constant tensor is invariant to flipping, so TTA should equal single pass
# (with a 1x1x1 conv the output is also spatially constant, so un-flipping changes nothing)
const_input = torch.ones(1, 1, 8, 8, 8) * 0.5
with torch.no_grad():
    single_logits = model_bin(const_input)
    single_probs = torch.sigmoid(single_logits).cpu()
    tta_probs = _predict_patch_tta(model_bin, const_input)
assert torch.allclose(single_probs, tta_probs, atol=1e-6), "TTA on constant input should match single forward pass"

print("TTA tests passed!")
TTA tests passed!
# Test _PreparedSubject and decomposed predict path
import tempfile, os, nibabel as nib
from monai.networks.nets import UNet
from monai.networks.layers import Norm

# Create a small synthetic NIfTI file for testing
_test_data = np.random.randn(32, 32, 32).astype(np.float32)
_test_affine = np.eye(4)
_test_nii = nib.Nifti1Image(_test_data, _test_affine)

with tempfile.TemporaryDirectory() as tmpdir:
    img_path = os.path.join(tmpdir, 'test_img.nii.gz')
    nib.save(_test_nii, img_path)

    _model = UNet(
        spatial_dims=3, in_channels=1, out_channels=2,
        channels=(16, 32), strides=(2,), num_res_units=1,
        norm=Norm.INSTANCE
    ).eval()
    _config = PatchConfig(patch_size=[32, 32, 32])
    # apply_reorder=False overrides the config default (True), which emits the
    # "apply_reorder mismatch ... Using explicit argument." warning seen in the output
    _engine = PatchInferenceEngine(_model, _config, apply_reorder=False)

    # Test 1: _prepare_subject returns _PreparedSubject with expected attributes
    prepared = _engine._prepare_subject(img_path)
    assert isinstance(prepared, _PreparedSubject), "Should return _PreparedSubject"
    assert isinstance(prepared.subject, tio.Subject)
    assert isinstance(prepared.grid_sampler, tio.GridSampler)
    assert isinstance(prepared.aggregator, tio.GridAggregator)
    assert isinstance(prepared.patch_loader, DataLoader)
    assert prepared.org_size is not None

    # Test 2: Decomposed path equals predict() output
    # NOTE(review): _prepare_subject is called twice, so _postprocess receives a different
    # prepared subject from the one passed to _run_inference. This looks deliberate
    # (a fresh, unconsumed subject); confirm.
    pred_decomposed, affine_decomposed = _engine._postprocess(
        _engine._run_inference(
            _engine._prepare_subject(img_path)
        ),
        _engine._prepare_subject(img_path)
    )
    pred_predict, affine_predict = _engine.predict(img_path, return_affine=True)

    assert torch.equal(pred_decomposed, pred_predict), "Decomposed path should match predict()"
    assert np.array_equal(affine_decomposed, affine_predict), "Affine should match"

    # Test 3: prefetch=True produces identical results to prefetch=False
    paths = [img_path, img_path]  # Two copies to trigger pipeline path
    preds_prefetch = patch_inference(
        _model, _config, paths, apply_reorder=False, progress=False, prefetch=True
    )
    preds_sequential = patch_inference(
        _model, _config, paths, apply_reorder=False, progress=False, prefetch=False
    )
    assert len(preds_prefetch) == len(preds_sequential) == 2
    for p1, p2 in zip(preds_prefetch, preds_sequential):
        assert torch.equal(p1, p2), "prefetch=True should produce identical results"

    # Test 4: Error propagation -- file-not-found raises (not silently swallowed)
    # Covers a missing first file and a missing file after a valid one (prefetched in the background).
    # test_fail is called without `contains`, so any exception type/message satisfies it.
    test_fail(
        lambda: patch_inference(
            _model, _config, ['/nonexistent/file.nii.gz'],
            apply_reorder=False, progress=False, prefetch=True
        )
    )
    test_fail(
        lambda: patch_inference(
            _model, _config, [img_path, '/nonexistent/file.nii.gz'],
            apply_reorder=False, progress=False, prefetch=True
        )
    )

    # Test 5: Save pipeline works correctly with prefetch=True
    # Both inputs are the same file, so both predictions are written to one output filename.
    save_dir = os.path.join(tmpdir, 'preds')
    preds_saved = patch_inference(
        _model, _config, paths, apply_reorder=False, progress=False,
        save_dir=save_dir, prefetch=True
    )
    assert len(preds_saved) == 2
    saved_files = list(Path(save_dir).glob('*.nii.gz'))
    assert len(saved_files) == 1, f"Expected 1 unique file (same input), got {len(saved_files)}"
    # Verify the saved file is valid NIfTI
    saved_nii = nib.load(str(saved_files[0]))
    assert saved_nii.shape is not None

print("Pipeline inference tests passed!")
Pipeline inference tests passed!
apply_reorder mismatch: explicit=False, config=True. Using explicit argument.
# Safety-net tests (added before refactor to characterize current behavior)
import tempfile as _tempfile, os as _os, nibabel as _nib
from monai.networks.nets import UNet as _UNet
from monai.networks.layers import Norm as _Norm

# --- T1: _normalize_patch_overlap covers zero / fraction / odd / int / list / numpy ---
# Table of (patch_overlap, patch_size, expected per-axis voxel overlap).
# Odd voxel counts are coerced down to the nearest even value.
for _ov_in, _ps_in, _ov_want in [
    (0, [96, 96, 96], (0, 0, 0)),
    (0.5, [96, 96, 96], (48, 48, 48)),
    (0.5, [98, 98, 98], (48, 48, 48)),           # 49 coerced down to even 48
    (47, [64, 64, 64], (46, 46, 46)),            # odd int coerced even
    ([10, 11, 12], [64, 64, 64], (10, 10, 12)),  # per-axis list, odd axis coerced
    (np.int64(48), [96, 96, 96], (48, 48, 48)),  # numpy integer accepted
]:
    test_eq(_normalize_patch_overlap(_ov_in, _ps_in), _ov_want)

# --- T2: PatchConfig.__post_init__ validation + divisibility warning ---
# Invalid enum-like values are rejected with a descriptive message
test_fail(lambda: PatchConfig(sampler_type='bogus'), contains='sampler_type must be one of')
test_fail(lambda: PatchConfig(aggregation_mode='bogus'), contains='aggregation_mode must be one of')
# Overlap must be non-negative (scalar or per-axis) ...
test_fail(lambda: PatchConfig(patch_overlap=-1), contains='cannot be negative')
test_fail(lambda: PatchConfig(patch_overlap=[-1, 0, 0]), contains='cannot be negative')
# ... and strictly smaller than patch_size on every axis
test_fail(lambda: PatchConfig(patch_size=[64, 64, 64], patch_overlap=64), contains='must be less than patch_size')
test_fail(lambda: PatchConfig(patch_size=[64, 64, 64], patch_overlap=[10, 10, 64]), contains='must be less than patch_size')
_ok = PatchConfig(patch_size=[96, 96, 96], patch_overlap=0.5, sampler_type='label', aggregation_mode='hann')
test_eq(_ok.sampler_type, 'label')
# A patch size not divisible by 16 only warns; it does not raise
with warnings.catch_warnings(record=True) as _w:
    warnings.simplefilter('always')
    PatchConfig(patch_size=[90, 90, 90])
    test_eq(any('divisible by 16' in str(_x.message) for _x in _w), True)

# --- T3: return_probabilities=True returns a float probability map (production path, was untested) ---
# A 2-output-channel model yields a 2-channel float map in [0, 1];
# return_probabilities=False yields a single-channel label mask instead.
with _tempfile.TemporaryDirectory() as _tmp:
    _ip = _os.path.join(_tmp, 'img.nii.gz')
    _nib.save(_nib.Nifti1Image(np.random.randn(32, 32, 32).astype(np.float32), np.eye(4)), _ip)
    _net = _UNet(spatial_dims=3, in_channels=1, out_channels=2, channels=(16, 32),
                 strides=(2,), num_res_units=1, norm=_Norm.INSTANCE).eval()
    _eng = PatchInferenceEngine(_net, PatchConfig(patch_size=[32, 32, 32]), apply_reorder=False)
    _prob = _eng.predict(_ip, return_probabilities=True)
    assert _prob.shape[0] == 2, f'expected 2 prob channels, got {tuple(_prob.shape)}'
    assert _prob.dtype.is_floating_point, _prob.dtype
    assert float(_prob.min()) >= 0.0 and float(_prob.max()) <= 1.0
    _mask = _eng.predict(_ip, return_probabilities=False)
    assert _mask.shape[0] == 1, f'expected 1 mask channel, got {tuple(_mask.shape)}'

# --- T4: from_df happy path (split, metadata contract, preprocessed skip) ---
def _tfm_names(ds):
    # Class names of the transforms attached to a dataset, or [] if none.
    # Checks both `transform` and `_transform`, and handles a Compose (has .transforms)
    # as well as a single bare transform.
    t = getattr(ds, 'transform', None)
    if t is None: t = getattr(ds, '_transform', None)
    if t is None: return []
    return [type(x).__name__ for x in getattr(t, 'transforms', [t])]

with _tempfile.TemporaryDirectory() as _tmp:
    # Four tiny image/mask pairs; the last two are flagged for the valid_col split
    _rows = []
    for _i in range(4):
        _ipath = _os.path.join(_tmp, f'img_{_i}.nii.gz')
        _mpath = _os.path.join(_tmp, f'msk_{_i}.nii.gz')
        _nib.save(_nib.Nifti1Image(np.random.randn(24, 24, 24).astype(np.float32), np.eye(4)), _ipath)
        _nib.save(_nib.Nifti1Image((np.random.rand(24, 24, 24) > 0.5).astype(np.uint8), np.eye(4)), _mpath)
        _rows.append({'img': _ipath, 'mask': _mpath, 'is_val': _i >= 2})
    _df = pd.DataFrame(_rows)
    _cfg = PatchConfig(patch_size=[16, 16, 16], samples_per_volume=2,
                       sampler_type='label', label_probabilities={0: 0.5, 1: 0.5})

    # valid_pct split branch: 4 rows -> 2 train / 2 valid
    _dls = MedPatchDataLoaders.from_df(_df, img_col='img', mask_col='mask',
                                       valid_pct=0.5, patch_config=_cfg, seed=0, bs=1)
    test_eq(len(_dls.split_df), 4)
    test_eq(len(_dls._train_source_df), 2)
    test_eq(len(_dls._valid_source_df), 2)
    assert isinstance(_dls.train, MedPatchDataLoader) and isinstance(_dls.valid, MedPatchDataLoader)
    # Metadata attributes that from_df is expected to set on the returned object
    for _a in ['_img_col', '_mask_col', '_pre_patch_tfms', '_apply_reorder', '_target_spacing',
               '_ensure_affine_consistency', '_patch_config', '_train_source_df', '_valid_source_df']:
        assert hasattr(_dls, _a), f'from_df must set {_a}'
    test_eq(_dls.patch_config is _cfg, True)
    test_eq(_dls.apply_reorder, True)  # PatchConfig default

    # valid_col split branch
    _dls2 = MedPatchDataLoaders.from_df(_df, img_col='img', mask_col='mask',
                                        valid_col='is_val', patch_config=_cfg, bs=1)
    test_eq(len(_dls2._train_source_df), 2)
    test_eq(len(_dls2._valid_source_df), 2)

    # preprocessed=True skips reorder/resample
    _cfg_pp = PatchConfig(patch_size=[16, 16, 16], samples_per_volume=2,
                          preprocessed=True, target_spacing=[1, 1, 1])
    _dls3 = MedPatchDataLoaders.from_df(_df, img_col='img', mask_col='mask',
                                        valid_pct=0.5, patch_config=_cfg_pp, seed=0, bs=1)
    _names_pp = _tfm_names(_dls3.train_ds)
    assert 'Resample' not in _names_pp and 'ToCanonical' not in _names_pp, f'preprocessed should skip, got {_names_pp}'

    # contrast: non-preprocessed WITH target_spacing includes Resample
    _cfg_rs = PatchConfig(patch_size=[16, 16, 16], samples_per_volume=2, target_spacing=[1, 1, 1])
    _dls4 = MedPatchDataLoaders.from_df(_df, img_col='img', mask_col='mask',
                                        valid_pct=0.5, patch_config=_cfg_rs, seed=0, bs=1)
    assert 'Resample' in _tfm_names(_dls4.train_ds)

# --- TH: MedPatchDataLoader.__iter__ with patch_tfms yields (MedImage, MedMask) (guards _apply_patch_tfms) ---
with _tempfile.TemporaryDirectory() as _tmp:
    _ip = _os.path.join(_tmp, 'i.nii.gz'); _mp = _os.path.join(_tmp, 'm.nii.gz')
    _nib.save(_nib.Nifti1Image(np.random.randn(20, 20, 20).astype(np.float32), np.eye(4)), _ip)
    _nib.save(_nib.Nifti1Image((np.random.rand(20, 20, 20) > 0.5).astype(np.uint8), np.eye(4)), _mp)
    _ds_h = create_subjects_dataset(pd.DataFrame({'img': [_ip], 'mask': [_mp]}), 'img', 'mask')
    # queue_num_workers=0 keeps patch extraction in-process for a fast, deterministic test
    _cfg_h = PatchConfig(patch_size=[16, 16, 16], samples_per_volume=2, queue_length=4, queue_num_workers=0)
    # flip_probability=0.0: the transform path is exercised without changing the data
    _dl_h = MedPatchDataLoader(_ds_h, _cfg_h, batch_size=2, patch_tfms=[tio.RandomFlip(flip_probability=0.0)])
    # Always close the loader, even if an assertion fails, before the
    # temporary directory holding the volumes is removed.
    try:
        _xb, _yb = next(iter(_dl_h))
        assert isinstance(_xb, MedImage) and isinstance(_yb, MedMask), (type(_xb), type(_yb))
        test_eq(_xb.shape[0], 2)
        test_eq(tuple(_xb.shape[-3:]), (16, 16, 16))
    finally:
        _dl_h.close()

# --- TB: binary_threshold uses >= (MONAI AsDiscrete semantics) and is configurable ---
# _postprocess is fed hand-built constant probability maps, so the model's weights are irrelevant;
# the model is only needed to construct the engine.
with _tempfile.TemporaryDirectory() as _tmp:
    _ip = _os.path.join(_tmp, 'img.nii.gz')
    _nib.save(_nib.Nifti1Image(np.zeros((16, 16, 16), dtype=np.float32), np.eye(4)), _ip)
    _net1 = _UNet(spatial_dims=3, in_channels=1, out_channels=1, channels=(8, 16),
                  strides=(2,), num_res_units=1, norm=_Norm.INSTANCE).eval()
    # default threshold 0.5 with '>=': probability exactly 0.5 is foreground (would be background under '>')
    _eng_d = PatchInferenceEngine(_net1, PatchConfig(patch_size=[16, 16, 16]), apply_reorder=False)
    _prep = _eng_d._prepare_subject(_ip)
    _sp = _prep.input_img.spatial_shape
    _res_half = _eng_d._postprocess(torch.full((1, *_sp), 0.5), _prep, return_probabilities=False)[0]
    test_eq(int(_res_half.sum()), _res_half.numel())   # all 0.5 >= 0.5 -> all foreground
    # higher threshold makes 0.5 background, 0.7 foreground
    _eng_h = PatchInferenceEngine(_net1, PatchConfig(patch_size=[16, 16, 16], binary_threshold=0.7), apply_reorder=False)
    _prep_h = _eng_h._prepare_subject(_ip)
    _res_low = _eng_h._postprocess(torch.full((1, *_sp), 0.5), _prep_h, return_probabilities=False)[0]
    test_eq(int(_res_low.sum()), 0)                     # 0.5 < 0.7 -> background
    _res_at = _eng_h._postprocess(torch.full((1, *_sp), 0.7), _prep_h, return_probabilities=False)[0]
    test_eq(int(_res_at.sum()), _res_at.numel())        # 0.7 >= 0.7 -> foreground
# Thresholds outside the probability range are rejected at config time
test_fail(lambda: PatchConfig(binary_threshold=1.5), contains='binary_threshold')

print('Safety-net tests (T1-T4, TH, TB) passed!')
Safety-net tests (T1-T4, TH, TB) passed!
apply_reorder mismatch: explicit=False, config=True. Using explicit argument.
apply_reorder mismatch: explicit=False, config=True. Using explicit argument.

Export