from fastMONAI.vision_all import *
from monai.apps import DecathlonDataset
from sklearn.model_selection import train_test_split

# --- Patch-Based Binary Segmentation ---
# MedPatchDataLoaders with lazy loading: volumes are loaded on demand, keeping
# memory usage constant regardless of dataset size.

# Download external data: use MONAI's DecathlonDataset to download the Heart
# MRI dataset from the Medical Segmentation Decathlon challenge.
path = Path('../data')
path.mkdir(exist_ok=True)

task = "Task02_Heart"
# cache_num=0 disables MONAI's in-memory caching -- we only need the file list.
training_data = DecathlonDataset(root_dir=path, task=task, section="training",
                                 download=True, cache_num=0, num_workers=3)

df = pd.DataFrame(training_data.data)
df.shape

# Split the labeled data into training and test sets (fixed seed for reproducibility).
train_df, test_df = train_test_split(df, test_size=0.1, random_state=42)
train_df.shape, test_df.shape
# --- Analyze training data ---
# Use MedDataset to analyze the dataset and get preprocessing recommendations.
# Analysis runs on the label volumes (masks), hence dtype=MedMask.
med_dataset = MedDataset(img_list=train_df.label.tolist(), dtype=MedMask, max_workers=12)
med_dataset.df.head()

data_info_df = med_dataset.summary()
suggestion = med_dataset.get_suggestion()
target_spacing = suggestion['target_spacing']
target_spacing

# Size statistics after resampling everything to the suggested target spacing.
stats = med_dataset.get_size_statistics(target_spacing=target_spacing)
print(f"Image sizes (after resampling to {target_spacing}):")
print(f" Min: {stats['min']}")
print(f" Median: {stats['median']}")
print(f" Max: {stats['max']}")

suggested_size = suggest_patch_size(med_dataset, target_spacing=target_spacing)
print(f"\nSuggested patch size: {suggested_size}")

# --- Configure patch-based training ---
PatchConfig centralizes all patch-related parameters:
- `patch_size`: size of extracted patches `[x, y, z]` — should be divisible by 16 for UNet compatibility
- `samples_per_volume`: number of patches extracted per volume per epoch
- `sampler_type`: `'uniform'` (random) or `'label'` (foreground-weighted)
- `label_probabilities`: for the `'label'` sampler, the probability of sampling each class
- `queue_length`: number of patches to keep in the memory buffer
- `patch_overlap`: overlap for inference (float 0–1 for a fraction, or int for pixels)
- `aggregation_mode`: how overlapping patches are combined (`'hann'` for smooth boundaries)
- `padding_mode`: padding mode when image < patch_size (0 = zero padding, nnU-Net standard)
# Central configuration for patch extraction, sampling, and inference.
patch_config = PatchConfig(
    patch_size=[128, 128, 64],        # divisible by 16 for UNet compatibility
    samples_per_volume=8,
    sampler_type='label',             # foreground-weighted patch sampling
    label_probabilities={0: 0.2, 1: 0.8},
    patch_overlap=0.5,                # 50% overlap during sliding-window inference
    keep_largest_component=True,
    target_spacing=target_spacing,
    aggregation_mode='hann'           # smooth blending of overlapping patches
    # apply_reorder defaults to True (the common case)
)
print(f"Patch config: {patch_config}")
# Alternative: PatchConfig.from_dataset(med_dataset) auto-configures the
# patch size based on the dataset analysis above.
Define transforms
Patch-based training uses two stages of transforms:
- `pre_patch_tfms`: applied to full volumes before patch extraction (e.g., normalization)
- `patch_tfms`: applied to extracted patches during training (augmentations)
Automatic padding: images smaller than `patch_size` are automatically padded to the minimum required size using zero padding (nnU-Net standard); larger dimensions are preserved. Padding happens AFTER `pre_patch_tfms`, so normalization is applied to the original intensities first.
Critical: `pre_patch_tfms` must match between training and inference for consistent preprocessing.
# Pre-patch transforms (applied to full volumes by Queue workers)
pre_patch_tfms = [ZNormalization()]

# Patch augmentations (applied to training patches only)
patch_tfms = [
    RandomAffine(scales=(0.95, 1.3), degrees=15, translation=5, p=0.5),
    RandomGamma(log_gamma=(-0.3, 0.3), p=0.5),
    RandomBiasField(coefficients=0.5, p=0.3),
    RandomBlur(std=(0.0, 0.8), p=0.2),
    RandomNoise(std=0.1, p=0.2),
    RandomFlip(p=0.5),
]

# --- Create patch-based DataLoaders ---
MedPatchDataLoaders uses lazy loading: only file paths are stored at creation time (~0 MB); volumes are loaded on demand by Queue workers.
bs = 4
# Lazy-loading patch DataLoaders: only paths are stored up front; Queue
# workers load volumes and extract patches on demand.
dls = MedPatchDataLoaders.from_df(
    df=train_df,
    img_col='image',
    mask_col='label',
    valid_pct=0.1,
    patch_config=patch_config,
    pre_patch_tfms=pre_patch_tfms,
    patch_tfms=patch_tfms,
    bs=bs,
    seed=42
)
print(f"Training subjects: {len(dls.train.subjects_dataset)}")
print(f"Validation subjects: {len(dls.valid.subjects_dataset)}")

# Visualize a batch of patches
batch = next(iter(dls.train))
x, y = batch
print(f"Batch shape - Image: {x.shape}, Mask: {y.shape}")

# --- Create and train a 3D model ---
We use MONAI’s UNet with: `out_channels=2` for a softmax output (background + foreground); instance normalization, common in medical imaging; and DiceCELoss, which combines Dice loss with cross-entropy for stable training.
from monai.networks.nets import UNet
from monai.networks.layers import Norm
from monai.losses import DiceCELoss

# 3D UNet: 2 output channels (background + foreground) for a softmax head,
# instance normalization (common in medical imaging).
model = UNet(
    spatial_dims=3,
    in_channels=1,
    out_channels=2,
    channels=(16, 32, 64, 128, 256),
    strides=(2, 2, 2, 2),
    num_res_units=2,
    norm=Norm.INSTANCE
)

# DiceCELoss combines Dice with cross-entropy for stable training.
loss_func = CustomLoss(loss_func=DiceCELoss(
    to_onehot_y=True,
    softmax=True,
    include_background=True
))

# AccumulatedDice accumulates true positives, false positives, and false
# negatives across all validation batches before computing Dice.
learn = Learner(dls, model, loss_func=loss_func, metrics=[AccumulatedDice(n_classes=2)])
learn.lr_find()
lr = 1e-3

# Checkpoint the best model (by accumulated Dice, higher is better).
best_model_fname = "best_heart_patch"
save_best = SaveModelCallback(
    monitor='accumulated_dice',
    comp=np.greater,
    fname=best_model_fname,
    with_opt=False
)

# All params (size, transforms, loss, model name, etc.) are extracted from learn automatically
mlflow_callback = create_mlflow_callback(learn, experiment_name="Task02_Heart_Patch")

learn.fit_one_cycle(100, lr, cbs=[mlflow_callback, save_best])
learn.recorder.plot_loss();

# --- Save configuration for inference ---
The following parameters MUST be identical between training and inference: `apply_reorder` (whether to reorder to RAS+ orientation), `target_spacing` (target voxel spacing for resampling), and `pre_patch_tfms` (preprocessing transforms, e.g. ZNormalization).
Save these to a pickle file for the inference notebook to load.
# Persist every setting the inference notebook must reproduce exactly.
store_patch_variables(
    pkl_fn='patch_config.pkl',
    patch_size=patch_config.patch_size,
    patch_overlap=patch_config.patch_overlap,
    aggregation_mode=patch_config.aggregation_mode,
    apply_reorder=patch_config.apply_reorder,
    target_spacing=patch_config.target_spacing,
    sampler_type=patch_config.sampler_type,
    label_probabilities=patch_config.label_probabilities,
    samples_per_volume=patch_config.samples_per_volume,
    queue_length=patch_config.queue_length,
    queue_num_workers=patch_config.queue_num_workers,
    keep_largest_component=patch_config.keep_largest_component
)
print("Configuration saved to patch_config.pkl")

# --- Evaluate on validation set ---
Evaluate the trained model on the validation set using sliding-window patch inference.
Metrics: DSC (Dice score — overlap similarity; higher is better), HD95 (95th-percentile Hausdorff distance in voxels; lower is better), Sens (sensitivity — true-positive rate; higher is better), LDR (lesion detection rate; higher is better), and RVE (relative volume error; 0 is optimal, positive means over-segmentation, negative means under-segmentation).
# Extract file paths from validation subjects
val_subjects = dls.valid.subjects_dataset
val_img_paths = [str(s['image'].path) for s in val_subjects]
val_mask_paths = [str(s['mask'].path) for s in val_subjects]
print(f"Validation set: {len(val_img_paths)} images")

# Load the best model (saved locally by SaveModelCallback)
learn.load(best_model_fname)
print(f"Loaded best model: {best_model_fname}")
learn.export("models/learner.pkl")  # plain string: no interpolation needed

# Sliding-window patch inference; predictions are also written to save_dir.
save_dir = Path('predictions/patch_heart_val')
predictions = patch_inference(
    learner=learn,
    config=patch_config,
    file_paths=val_img_paths,
    pre_inference_tfms=pre_patch_tfms,  # must match training pre_patch_tfms
    save_dir=str(save_dir),
    progress=True
)
print(f"\nSaved {len(predictions)} predictions to {save_dir}/")

from fastMONAI.vision_metrics import (calculate_dsc, calculate_haus,
                                      calculate_confusion_metrics,
                                      calculate_lesion_detection_rate,
                                      calculate_signed_rve)
# Per-case evaluation: compare each prediction against its ground-truth mask.
results = []
for i, pred in enumerate(predictions):
    img_name = Path(val_img_paths[i]).name
    gt_path = val_mask_paths[i]

    # Load ground truth with matching preprocessing (same reorder + spacing
    # as the predictions, otherwise voxel-wise metrics are meaningless).
    gt = MedMask.create(
        gt_path,
        apply_reorder=patch_config.apply_reorder,
        target_spacing=patch_config.target_spacing
    )

    # Prepare 5D tensors [B, C, D, H, W]
    pred_5d = pred.unsqueeze(0).float()
    gt_5d = gt.data.unsqueeze(0).float()

    # Compute metrics using calculate_* functions (for binary masks)
    dsc = calculate_dsc(pred_5d, gt_5d).mean().item()
    hd95 = calculate_haus(pred_5d, gt_5d).mean().item()
    sens = calculate_confusion_metrics(pred_5d, gt_5d, "sensitivity").nanmean().item()
    ldr = calculate_lesion_detection_rate(pred_5d, gt_5d).nanmean().item()
    rve = calculate_signed_rve(pred_5d, gt_5d).nanmean().item()

    results.append({
        'Image': img_name, 'DSC': dsc, 'HD95': hd95,
        'Sens': sens, 'LDR': ldr, 'RVE': rve
    })

results_df = pd.DataFrame(results)
results_df

# Mean ± std summary over the validation cases.
metrics = ['DSC', 'HD95', 'Sens', 'LDR', 'RVE']
print("Validation Set Performance Summary")
print("=" * 45)
for metric in metrics:
    mean = results_df[metric].mean()
    std = results_df[metric].std()
    print(f"{metric:8s}: {mean:.4f} ± {std:.4f}")

from fastMONAI.vision_plot import show_segmentation_comparison
# Visualize first validation case
idx = 0
# Reload image and ground truth with the same preprocessing used at inference.
val_img = MedImage.create(
    val_img_paths[idx],
    apply_reorder=patch_config.apply_reorder,
    target_spacing=patch_config.target_spacing
)
val_gt = MedMask.create(
    val_mask_paths[idx],
    apply_reorder=patch_config.apply_reorder,
    target_spacing=patch_config.target_spacing
)
show_segmentation_comparison(
    image=val_img,
    ground_truth=val_gt,
    prediction=predictions[idx],
    metric_value=results_df.iloc[idx]['DSC'],
    voxel_size=patch_config.target_spacing,
    anatomical_plane=2  # axial view
)

# --- View experiment tracking ---
# Launch the local MLflow UI to browse the logged experiment runs.
mlflow_ui = MLflowUIManager()
mlflow_ui.start_ui()

# --- Summary ---
In this tutorial, we demonstrated patch-based training and evaluation for 3D medical image segmentation.
Training: (1) `PatchConfig` — centralized configuration for patch size, sampling, and inference parameters; (2) `MedPatchDataLoaders` — memory-efficient lazy loading with the TorchIO Queue; (3) two-stage transforms — `pre_patch_tfms` (full volume) plus `patch_tfms` (training augmentation); (4) `AccumulatedDice` — an nnU-Net-style accumulated metric for reliable validation; (5) `store_patch_variables()` — saving the configuration for inference consistency.
Evaluation: (6) `patch_inference()` — batched sliding-window inference with NIfTI output; (7) comprehensive metrics — DSC, HD95, sensitivity, LDR, and RVE.
Key takeaway: the preprocessing parameters (`apply_reorder`, `target_spacing`, `pre_inference_tfms`) MUST match between training and inference for correct predictions.
mlflow_ui.stop()