Patch-Based Binary Segmentation

Patch-based training for 3D medical image segmentation using fastMONAI's MedPatchDataLoaders with lazy loading: volumes are loaded on demand, keeping memory usage roughly constant regardless of dataset size.

from fastMONAI.vision_all import *

from monai.apps import DecathlonDataset
from sklearn.model_selection import train_test_split

Download external data

We use MONAI's DecathlonDataset to download the heart MRI dataset from the Medical Segmentation Decathlon challenge.

path = Path('../data')
path.mkdir(exist_ok=True)
task = "Task02_Heart"
training_data = DecathlonDataset(root_dir=path, task=task, section="training", 
    download=True, cache_num=0, num_workers=3)
df = pd.DataFrame(training_data.data)
df.shape

Split the labeled data into training and test sets.

train_df, test_df = train_test_split(df, test_size=0.1, random_state=42)
train_df.shape, test_df.shape

Analyze training data

Use MedDataset to analyze the dataset and get preprocessing recommendations.

med_dataset = MedDataset(img_list=train_df.label.tolist(), dtype=MedMask, max_workers=12)
med_dataset.df.head()
data_info_df = med_dataset.summary()
suggestion = med_dataset.get_suggestion()
target_spacing = suggestion['target_spacing']
target_spacing
stats = med_dataset.get_size_statistics(target_spacing=target_spacing)
print(f"Image sizes (after resampling to {target_spacing}):")
print(f"  Min:    {stats['min']}")
print(f"  Median: {stats['median']}")
print(f"  Max:    {stats['max']}")

suggested_size = suggest_patch_size(med_dataset, target_spacing=target_spacing)
print(f"\nSuggested patch size: {suggested_size}")

Configure patch-based training

PatchConfig centralizes all patch-related parameters:

  • patch_size: Size of extracted patches [x, y, z]; each dimension should be divisible by 16 for UNet compatibility
  • samples_per_volume: Number of patches extracted per volume per epoch
  • sampler_type: 'uniform' (random) or 'label' (foreground-weighted)
  • label_probabilities: For the 'label' sampler, the probability of sampling each class
  • queue_length: Number of patches to keep in the memory buffer
  • patch_overlap: Overlap for inference (a float in 0-1 for a fraction of the patch size, or an int for voxels)
  • aggregation_mode: How to combine overlapping patches ('hann' for smooth boundaries)
  • padding_mode: Padding value used when an image is smaller than patch_size (0 = zero padding, the nnU-Net standard)
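Two of these parameters reward a quick sanity check. A sketch in plain Python (not the fastMONAI API; the library's exact overlap conversion may differ) of the divisibility rule and the fractional-overlap interpretation:

```python
# Sketch (plain Python, not fastMONAI internals): sanity-check patch parameters.
patch_size = [128, 128, 64]

# A 5-level UNet with stride-2 downsampling halves each axis 4 times,
# so every patch dimension should be divisible by 2**4 = 16.
assert all(d % 16 == 0 for d in patch_size)

# A fractional patch_overlap is read per axis as a fraction of the patch
# size; an int would be taken as voxels directly.
patch_overlap = 0.5
overlap_voxels = [int(d * patch_overlap) for d in patch_size]
print(overlap_voxels)  # [64, 64, 32]
```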
patch_config = PatchConfig(
    patch_size=[128, 128, 64],
    samples_per_volume=8,
    sampler_type='label',
    label_probabilities={0: 0.2, 1: 0.8},
    patch_overlap=0.5,
    keep_largest_component=True,
    target_spacing=target_spacing,
    aggregation_mode='hann'
    # apply_reorder defaults to True (the common case)
)

print(f"Patch config: {patch_config}")

Alternative: Use PatchConfig.from_dataset(med_dataset) to auto-configure patch size based on dataset analysis.

Define transforms

Patch-based training uses two stages of transforms:

  1. pre_patch_tfms: Applied to full volumes before patch extraction (e.g., normalization)
  2. patch_tfms: Applied to extracted patches during training (augmentations)

Automatic padding: Images smaller than patch_size are automatically padded to the minimum required size using zero padding (nnU-Net standard). Larger dimensions are preserved. This happens AFTER pre_patch_tfms, so normalization is applied to the original intensities first.
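The padding rule above can be sketched in a few lines of plain Python (illustrative only, not the library's internal code): each axis is padded up to the patch size only if it is smaller, and larger axes are left untouched.

```python
# Sketch of the automatic padding rule (not fastMONAI internals):
# pad each axis up to patch_size only when the image is smaller.
def padded_shape(image_shape, patch_size):
    return [max(i, p) for i, p in zip(image_shape, patch_size)]

print(padded_shape([100, 140, 50], [128, 128, 64]))  # [128, 140, 64]
```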

Critical: pre_patch_tfms must match between training and inference for consistent preprocessing.

# Pre-patch transforms (applied to full volumes by Queue workers)
pre_patch_tfms = [ZNormalization()]

# Patch augmentations (applied to training patches only)
patch_tfms = [
    RandomAffine(scales=(0.95, 1.3), degrees=15, translation=5, p=0.5),
    RandomGamma(log_gamma=(-0.3, 0.3), p=0.5),
    RandomBiasField(coefficients=0.5, p=0.3),
    RandomBlur(std=(0.0, 0.8), p=0.2),
    RandomNoise(std=0.1, p=0.2),
    RandomFlip(p=0.5),
]

Create patch-based DataLoaders

MedPatchDataLoaders uses lazy loading:

  • Only file paths are stored at creation time (~0 MB)
  • Volumes are loaded on-demand by Queue workers

bs = 4

dls = MedPatchDataLoaders.from_df(
    df=train_df,
    img_col='image',
    mask_col='label',
    valid_pct=0.1,
    patch_config=patch_config,
    pre_patch_tfms=pre_patch_tfms,
    patch_tfms=patch_tfms,
    bs=bs,
    seed=42
)

print(f"Training subjects: {len(dls.train.subjects_dataset)}")
print(f"Validation subjects: {len(dls.valid.subjects_dataset)}")
# Visualize a batch of patches
batch = next(iter(dls.train))
x, y = batch
print(f"Batch shape - Image: {x.shape}, Mask: {y.shape}")

Create and train a 3D model

We use MONAI’s UNet with:

  • out_channels=2: Softmax output (background + foreground)
  • Instance normalization: Common in medical imaging
  • DiceCELoss: Combines Dice loss with cross-entropy for stable training

from monai.networks.nets import UNet
from monai.networks.layers import Norm
from monai.losses import DiceCELoss
model = UNet(
    spatial_dims=3,
    in_channels=1,
    out_channels=2,
    channels=(16, 32, 64, 128, 256),
    strides=(2, 2, 2, 2),
    num_res_units=2,
    norm=Norm.INSTANCE
)

loss_func = CustomLoss(loss_func=DiceCELoss(
    to_onehot_y=True,
    softmax=True,
    include_background=True
))

We use AccumulatedDice metric which accumulates true positives, false positives, and false negatives across all validation batches before computing Dice.
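The difference from simply averaging per-batch Dice scores is easy to see with toy counts (a plain-Python sketch of the idea, not the fastMONAI implementation): averaging lets small structures with poor overlap dominate the mean, while accumulating first weights every voxel equally.

```python
# Sketch of the accumulated-Dice idea (not the fastMONAI implementation):
# pool TP/FP/FN over all batches, then compute Dice once.
batches = [
    {"tp": 90, "fp": 10, "fn": 10},  # large structure, good overlap
    {"tp": 1,  "fp": 5,  "fn": 4},   # tiny structure, poor overlap
]

# Per-batch Dice averaged: the tiny structure drags the mean down.
per_batch = [2 * b["tp"] / (2 * b["tp"] + b["fp"] + b["fn"]) for b in batches]
mean_dice = sum(per_batch) / len(per_batch)

# Accumulated: a single Dice from pooled counts.
tp = sum(b["tp"] for b in batches)
fp = sum(b["fp"] for b in batches)
fn = sum(b["fn"] for b in batches)
acc_dice = 2 * tp / (2 * tp + fp + fn)

print(f"{mean_dice:.3f} vs {acc_dice:.3f}")  # 0.541 vs 0.863
```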

learn = Learner(dls, model, loss_func=loss_func, metrics=[AccumulatedDice(n_classes=2)])
learn.lr_find()
lr = 1e-3
best_model_fname = "best_heart_patch"
save_best = SaveModelCallback(
    monitor='accumulated_dice',
    comp=np.greater,
    fname=best_model_fname,
    with_opt=False
)
# All params (size, transforms, loss, model name, etc.) are extracted from learn automatically
mlflow_callback = create_mlflow_callback(learn, experiment_name="Task02_Heart_Patch")
learn.fit_one_cycle(100, lr, cbs=[mlflow_callback, save_best])
learn.recorder.plot_loss();

Save configuration for inference

The following parameters MUST be identical between training and inference:

  • apply_reorder: Whether to reorder to RAS+ orientation
  • target_spacing: Target voxel spacing for resampling
  • pre_patch_tfms: Preprocessing transforms (e.g., ZNormalization)

Save these to a pickle file for the inference notebook to load.

store_patch_variables(
    pkl_fn='patch_config.pkl',
    patch_size=patch_config.patch_size,
    patch_overlap=patch_config.patch_overlap,
    aggregation_mode=patch_config.aggregation_mode,
    apply_reorder=patch_config.apply_reorder,
    target_spacing=patch_config.target_spacing,
    sampler_type=patch_config.sampler_type,
    label_probabilities=patch_config.label_probabilities,
    samples_per_volume=patch_config.samples_per_volume,
    queue_length=patch_config.queue_length,
    queue_num_workers=patch_config.queue_num_workers,
    keep_largest_component=patch_config.keep_largest_component
)

print("Configuration saved to patch_config.pkl")

Evaluate on validation set

Evaluate the trained model on the validation set using sliding-window patch inference.
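Sliding-window inference with aggregation_mode='hann' weights each patch by a Hann window before summing and normalizing, so patch borders contribute less than patch centers and overlaps blend smoothly. A 1D conceptual sketch (plain NumPy, not fastMONAI internals):

```python
import numpy as np

# 1D illustration of Hann-weighted patch aggregation (conceptual sketch,
# not fastMONAI internals): overlapping patch predictions are multiplied
# by a Hann window, summed, and normalized by the summed weights.
signal_len, patch, stride = 12, 8, 4       # 50% overlap
window = np.hanning(patch) + 1e-8          # avoid zero weight at patch edges

acc = np.zeros(signal_len)
weight = np.zeros(signal_len)
for start in range(0, signal_len - patch + 1, stride):
    pred = np.ones(patch)                  # stand-in for a model output
    acc[start:start + patch] += pred * window
    weight[start:start + patch] += window

blended = acc / weight
print(np.allclose(blended, 1.0))           # constant input -> constant output
```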

Metrics:

  • DSC: Dice similarity coefficient, an overlap measure (higher is better)
  • HD95: 95th percentile Hausdorff distance in voxels (lower is better)
  • Sens: Sensitivity, the true positive rate (higher is better)
  • LDR: Lesion detection rate (higher is better)
  • RVE: Relative volume error (0 is optimal; positive = over-segmentation, negative = under-segmentation)
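DSC and RVE are simple enough to compute by hand on toy flat masks (a conceptual sketch; the fastMONAI calculate_* functions below operate on 5D tensors and handle batching and edge cases):

```python
import numpy as np

# Toy illustration of DSC and signed RVE on flat binary masks
# (conceptual sketch, not the fastMONAI calculate_* implementation).
gt   = np.array([0, 1, 1, 1, 0, 0])
pred = np.array([0, 1, 1, 1, 1, 0])

inter = np.sum(pred * gt)
dsc = 2 * inter / (pred.sum() + gt.sum())   # 2*3 / (4+3) ~ 0.857
rve = (pred.sum() - gt.sum()) / gt.sum()    # +1/3: over-segmentation

print(f"DSC={dsc:.3f}, RVE={rve:+.3f}")     # DSC=0.857, RVE=+0.333
```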

# Extract file paths from validation subjects
val_subjects = dls.valid.subjects_dataset
val_img_paths = [str(s['image'].path) for s in val_subjects]
val_mask_paths = [str(s['mask'].path) for s in val_subjects]

print(f"Validation set: {len(val_img_paths)} images")
# Load the best model (saved locally by SaveModelCallback)
learn.load(best_model_fname)
print(f"Loaded best model: {best_model_fname}")
learn.export("models/learner.pkl")

save_dir = Path('predictions/patch_heart_val')

predictions = patch_inference(
    learner=learn,
    config=patch_config,
    file_paths=val_img_paths,
    pre_inference_tfms=pre_patch_tfms,
    save_dir=str(save_dir),
    progress=True
)

print(f"\nSaved {len(predictions)} predictions to {save_dir}/")
from fastMONAI.vision_metrics import (calculate_dsc, calculate_haus,
                                       calculate_confusion_metrics,
                                       calculate_lesion_detection_rate,
                                       calculate_signed_rve)

results = []

for i, pred in enumerate(predictions):
    img_name = Path(val_img_paths[i]).name
    gt_path = val_mask_paths[i]

    # Load ground truth with matching preprocessing
    gt = MedMask.create(
        gt_path,
        apply_reorder=patch_config.apply_reorder,
        target_spacing=patch_config.target_spacing
    )

    # Prepare 5D tensors [B, C, D, H, W]
    pred_5d = pred.unsqueeze(0).float()
    gt_5d = gt.data.unsqueeze(0).float()

    # Compute metrics using calculate_* functions (for binary masks)
    dsc = calculate_dsc(pred_5d, gt_5d).mean().item()
    hd95 = calculate_haus(pred_5d, gt_5d).mean().item()
    sens = calculate_confusion_metrics(pred_5d, gt_5d, "sensitivity").nanmean().item()
    ldr = calculate_lesion_detection_rate(pred_5d, gt_5d).nanmean().item()
    rve = calculate_signed_rve(pred_5d, gt_5d).nanmean().item()

    results.append({
        'Image': img_name, 'DSC': dsc, 'HD95': hd95,
        'Sens': sens, 'LDR': ldr, 'RVE': rve
    })

results_df = pd.DataFrame(results)
results_df
metrics = ['DSC', 'HD95', 'Sens', 'LDR', 'RVE']

print("Validation Set Performance Summary")
print("=" * 45)
for metric in metrics:
    mean = results_df[metric].mean()
    std = results_df[metric].std()
    print(f"{metric:8s}: {mean:.4f} ± {std:.4f}")
from fastMONAI.vision_plot import show_segmentation_comparison

# Visualize first validation case
idx = 0
val_img = MedImage.create(
    val_img_paths[idx],
    apply_reorder=patch_config.apply_reorder,
    target_spacing=patch_config.target_spacing
)
val_gt = MedMask.create(
    val_mask_paths[idx],
    apply_reorder=patch_config.apply_reorder,
    target_spacing=patch_config.target_spacing
)

show_segmentation_comparison(
    image=val_img,
    ground_truth=val_gt,
    prediction=predictions[idx],
    metric_value=results_df.iloc[idx]['DSC'],
    voxel_size=patch_config.target_spacing,
    anatomical_plane=2  # axial view
)

View experiment tracking

mlflow_ui = MLflowUIManager()
mlflow_ui.start_ui()

Summary

In this tutorial, we demonstrated patch-based training and evaluation for 3D medical image segmentation:

Training:

  1. PatchConfig: Centralized configuration for patch size, sampling, and inference parameters
  2. MedPatchDataLoaders: Memory-efficient lazy loading with TorchIO Queue
  3. Two-stage transforms: pre_patch_tfms (full volume) + patch_tfms (training augmentation)
  4. AccumulatedDice: nnU-Net-style accumulated metric for reliable validation
  5. store_patch_variables(): Save configuration for inference consistency

Evaluation:

  6. patch_inference(): Batch sliding-window inference with NIfTI output
  7. Comprehensive metrics: DSC, HD95, Sensitivity, LDR, RVE

Key takeaway: Preprocessing parameters (apply_reorder, target_spacing, pre_inference_tfms) MUST match between training and inference for correct predictions.

mlflow_ui.stop()