Patch-Based Inference

Sliding-window inference for 3D medical image segmentation using fastMONAI’s PatchInferenceEngine.

from fastMONAI.vision_all import *

from monai.apps import DecathlonDataset
from sklearn.model_selection import train_test_split

Load data

Load the same dataset and recreate the train/test split from the training notebook.

path = Path('../data')
path.mkdir(exist_ok=True)
task = "Task02_Heart"
training_data = DecathlonDataset(root_dir=path, task=task, section="training", 
    download=True, cache_num=0, num_workers=3)
2026-02-03 15:10:15,182 - INFO - Verified 'Task02_Heart.tar', md5: 06ee59366e1e5124267b774dbd654057.
2026-02-03 15:10:15,182 - INFO - File exists: ../data/Task02_Heart.tar, skipped downloading.
2026-02-03 15:10:15,183 - INFO - Non-empty folder exists in ../data/Task02_Heart, skipped extracting.
df = pd.DataFrame(training_data.data)
train_df, test_df = train_test_split(df, test_size=0.1, random_state=42)
print(f"Test samples: {len(test_df)}")
test_df
Test samples: 2
                                         image                                        label
0  ../data/Task02_Heart/imagesTr/la_030.nii.gz  ../data/Task02_Heart/labelsTr/la_030.nii.gz
1  ../data/Task02_Heart/imagesTr/la_024.nii.gz  ../data/Task02_Heart/labelsTr/la_024.nii.gz

Load configuration

Load the patch configuration saved during training. This ensures preprocessing consistency:

  • apply_reorder: Whether to reorder to RAS+ orientation
  • target_spacing: Target voxel spacing for resampling
  • patch_size, patch_overlap, aggregation_mode: Inference parameters

Critical: Mismatched preprocessing parameters will produce incorrect predictions (e.g., mirrored or rotated outputs).

config_dict = load_patch_variables('patch_config.pkl')
print("Loaded configuration:")
for k, v in config_dict.items():
    print(f"  {k}: {v}")
Loaded configuration:
  patch_size: [128, 128, 64]
  patch_overlap: 0.5
  aggregation_mode: hann
  apply_reorder: True
  target_spacing: [1.25, 1.25, 1.37]
  sampler_type: label
  label_probabilities: {0: 0.2, 1: 0.8}
  samples_per_volume: 8
  queue_length: 300
  queue_num_workers: 4
  keep_largest_component: True
patch_config = PatchConfig(**config_dict)
apply_reorder = config_dict['apply_reorder']
target_spacing = config_dict['target_spacing']

Load trained model

Load the exported learner, which contains both the model architecture and the best weights.

Two options:

  1. Local file: Load best_learner.pkl exported during training
  2. MLflow: Download from MLflow artifacts (for experiment tracking workflows)

from fastai.learner import load_learner
import torch

# Option 1: Load from local file
learn = load_learner('models/best_learner.pkl')

# Option 2: Load from MLflow (uncomment to use)
# import mlflow
# run_id = "your_run_id"  # Get from MLflow UI
# mlflow.artifacts.download_artifacts(run_id=run_id, artifact_path="model/best_learner.pkl", dst_path="./")
# learn = load_learner('best_learner.pkl')

model = learn.model
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
model.to(device)
model.eval()

print(f"Loaded model: {model.__class__.__name__}")
print(f"Device: {device}")

Normalization (loaded from the config)

Normalization was saved in the PatchConfig during training, so it is already present in the loaded patch_config and is applied automatically – nothing to re-specify here.

Troubleshooting: If predictions appear mirrored, rotated, or wrong, check that apply_reorder, target_spacing, and normalization in the loaded config match training. To override, pass pre_inference_tfms= explicitly (e.g. for non-serializable transforms).

# Normalization is already in the loaded config and applied automatically:
print("Normalization from config:", patch_config.normalization)
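
When debugging the mirrored or rotated predictions mentioned in the troubleshooting note, it can help to read the anatomical orientation directly from an image's affine. The sketch below is a simplified NumPy illustration, not part of fastMONAI: the helper name axis_codes is hypothetical, and it assumes a 4x4 voxel-to-world affine in the RAS+ world convention with no strongly oblique axes (libraries such as nibabel handle the general case).

```python
import numpy as np

def axis_codes(affine):
    """Return orientation codes such as ('R', 'A', 'S') for a 4x4 affine.

    Simplified: each voxel axis is assigned to the world axis it is most
    aligned with, and the sign decides the direction along that axis.
    """
    labels = (("L", "R"), ("P", "A"), ("I", "S"))  # (negative, positive) per world axis
    rotation = np.asarray(affine, dtype=float)[:3, :3]
    codes = []
    for voxel_axis in range(3):
        column = rotation[:, voxel_axis]
        world_axis = int(np.argmax(np.abs(column)))
        codes.append(labels[world_axis][int(column[world_axis] > 0)])
    return tuple(codes)

# An image already in RAS+ with the spacing from the loaded config
ras = np.diag([1.25, 1.25, 1.37, 1.0])
# The same image with the first axis flipped (left-right mirrored)
las = np.diag([-1.25, 1.25, 1.37, 1.0])
print(axis_codes(ras), axis_codes(las))
```

If an image and its prediction report different codes, the reorder step is the first thing to check.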

Create PatchInferenceEngine

The engine handles the complete sliding-window inference pipeline:

  1. Load and preprocess the image (reorder, resample, normalize)
  2. Pad if the image is smaller than the patch size
  3. Extract overlapping patches with GridSampler
  4. Predict on batches of patches
  5. Reconstruct the full volume with GridAggregator using Hann windowing

engine = PatchInferenceEngine(
    learner=model,  # Pass model directly (no Learner needed)
    config=patch_config,
    batch_size=4
)
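
To make the patch extraction, prediction, and reconstruction steps more concrete, here is a minimal 1D NumPy sketch of Hann-weighted aggregation of overlapping patches. It is an illustration under stated assumptions, not the engine's implementation: the function sliding_window_1d, its parameters, right-side padding, and the trimmed Hann window (endpoints removed so every weight is positive) are all hypothetical choices made for this example.

```python
import numpy as np

def sliding_window_1d(signal, predict, patch_size=8, overlap=0.5):
    """Hann-weighted aggregation of overlapping per-patch predictions (1D)."""
    stride = max(1, int(patch_size * (1 - overlap)))
    n = len(signal)

    # Pad on the right so the signal holds a whole number of strides
    # and is at least one patch long.
    n_steps = max(0, -(-(n - patch_size) // stride))  # ceiling division
    padded_len = n_steps * stride + patch_size
    padded = np.pad(np.asarray(signal, dtype=float), (0, padded_len - n))

    # Drop the zero endpoints of the Hann window so no voxel gets zero weight.
    window = np.hanning(patch_size + 2)[1:-1]

    accumulated = np.zeros(padded_len)
    weights = np.zeros(padded_len)
    for start in range(0, padded_len - patch_size + 1, stride):
        region = slice(start, start + patch_size)
        accumulated[region] += predict(padded[region]) * window
        weights[region] += window

    # Weighted average, cropped back to the original length.
    return (accumulated / weights)[:n]

# With an identity "model", aggregation must reproduce the input exactly.
x = np.arange(21, dtype=float)
out = sliding_window_1d(x, lambda patch: patch)
print(np.allclose(out, x))
```

Because each output value is a weighted average of the patches covering it, predictions near a patch center dominate those near a patch border, which is what smooths the seams between patches.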

Single image inference

Use engine.predict() to predict on a single image.

test_path = test_df.iloc[0]['image']
pred, affine = engine.predict(test_path, return_affine=True)
print(f"Input: {test_path}")
print(f"Prediction shape: {pred.shape}")
print(f"Unique values: {torch.unique(pred).tolist()}")
Input: ../data/Task02_Heart/imagesTr/la_030.nii.gz
Prediction shape: torch.Size([1, 320, 320, 110])
Unique values: [0, 1]

Batch inference

Use patch_inference() to predict on multiple images with optional NIfTI output.

test_paths = test_df['image'].tolist()

predictions = patch_inference(
    learner=model,  # Pass model directly (no Learner needed)
    config=patch_config,
    file_paths=test_paths,
    save_dir='predictions/patch_heart',
    progress=True
)

print(f"\nGenerated {len(predictions)} predictions")
print(f"Saved to: predictions/patch_heart/")

Generated 2 predictions
Saved to: predictions/patch_heart/

Evaluate predictions

Evaluate the predictions against ground truth using five metrics:

  • DSC: Dice score, overlap similarity (higher = better)
  • HD95: 95th percentile Hausdorff distance in voxels (lower = better)
  • Sens: Sensitivity, true positive rate (higher = better)
  • LDR: Lesion detection rate (higher = better)
  • RVE: Relative volume error (0 = optimal, + = over-segmentation, - = under-segmentation)
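
As a worked example of three of these definitions, the NumPy sketch below computes Dice, sensitivity, and signed relative volume error on a toy pair of binary masks. It uses the standard textbook formulas and is not fastMONAI's implementation; the function names are hypothetical and chosen for this illustration only.

```python
import numpy as np

def dice(pred, gt):
    """Dice score: 2 * |pred AND gt| / (|pred| + |gt|)."""
    pred, gt = pred.astype(bool), gt.astype(bool)
    total = pred.sum() + gt.sum()
    return 2.0 * np.logical_and(pred, gt).sum() / total if total else float("nan")

def sensitivity(pred, gt):
    """True positive rate: TP / (TP + FN)."""
    pred, gt = pred.astype(bool), gt.astype(bool)
    tp = np.logical_and(pred, gt).sum()
    fn = np.logical_and(~pred, gt).sum()
    return tp / (tp + fn) if (tp + fn) else float("nan")

def signed_rve(pred, gt):
    """Signed relative volume error: (V_pred - V_gt) / V_gt."""
    v_pred = int(pred.astype(bool).sum())
    v_gt = int(gt.astype(bool).sum())
    return (v_pred - v_gt) / v_gt if v_gt else float("nan")

# Toy masks: ground truth is a 2x2x2 cube (8 voxels);
# the prediction covers only half of it (4 voxels), i.e. under-segmentation.
gt = np.zeros((4, 4, 4), dtype=np.uint8)
gt[1:3, 1:3, 1:3] = 1
pred = np.zeros_like(gt)
pred[1:3, 1:3, 1:2] = 1

print(dice(pred, gt), sensitivity(pred, gt), signed_rve(pred, gt))
```

The negative RVE here reflects under-segmentation, matching the sign convention in the list above.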

from fastMONAI.vision_metrics import (calculate_dsc, calculate_haus,
                                       calculate_confusion_metrics,
                                       calculate_lesion_detection_rate,
                                       calculate_signed_rve)

results = []

for i, pred in enumerate(predictions):
    img_name = Path(test_df.iloc[i]['image']).name
    gt_path = test_df.iloc[i]['label']

    # Load ground truth with matching preprocessing
    gt = MedMask.create(gt_path, apply_reorder=apply_reorder, target_spacing=target_spacing)

    # Prepare 5D tensors [B, C, D, H, W]
    pred_5d = pred.unsqueeze(0).float()
    gt_5d = gt.data.unsqueeze(0).float()

    # Compute metrics using calculate_* functions (for binary masks)
    dsc = calculate_dsc(pred_5d, gt_5d).mean().item()
    hd95 = calculate_haus(pred_5d, gt_5d).mean().item()
    sens = calculate_confusion_metrics(pred_5d, gt_5d, "sensitivity").nanmean().item()
    ldr = calculate_lesion_detection_rate(pred_5d, gt_5d).nanmean().item()
    rve = calculate_signed_rve(pred_5d, gt_5d).nanmean().item()

    results.append({
        'Image': img_name, 'DSC': dsc, 'HD95': hd95,
        'Sens': sens, 'LDR': ldr, 'RVE': rve
    })

results_df = pd.DataFrame(results)
results_df
monai.metrics.utils get_mask_edges:always_return_as_numpy: Argument `always_return_as_numpy` has been deprecated since version 1.5.0. It will be removed in version 1.7.0. The option is removed and the return type will always be equal to the input type.
           Image       DSC      HD95      Sens  LDR        RVE
0  la_030.nii.gz  0.938002  3.000000  0.918056  1.0  -0.042528
1  la_024.nii.gz  0.922835  3.316625  0.907937  1.0  -0.032287
metrics = ['DSC', 'HD95', 'Sens', 'LDR', 'RVE']

print("Test Set Performance Summary")
print("=" * 45)
for metric in metrics:
    mean = results_df[metric].mean()
    std = results_df[metric].std()
    print(f"{metric:8s}: {mean:.4f} ± {std:.4f}")
Test Set Performance Summary
=============================================
DSC     : 0.9304 ± 0.0107
HD95    : 3.1583 ± 0.2239
Sens    : 0.9130 ± 0.0072
LDR     : 1.0000 ± 0.0000
RVE     : -0.0374 ± 0.0072

Visualize predictions

from fastMONAI.vision_plot import show_segmentation_comparison

# Visualize first test case
idx = 0
test_img = MedImage.create(
    test_df.iloc[idx]['image'],
    apply_reorder=apply_reorder,
    target_spacing=target_spacing
)
test_gt = MedMask.create(
    test_df.iloc[idx]['label'],
    apply_reorder=apply_reorder,
    target_spacing=target_spacing
)

show_segmentation_comparison(
    image=test_img,
    ground_truth=test_gt,
    prediction=predictions[idx],
    metric_value=results_df.iloc[idx]['DSC'],
    voxel_size=target_spacing,
    anatomical_plane=2  # axial view
)

Summary

In this tutorial, we demonstrated patch-based inference for 3D medical image segmentation:

  1. load_patch_variables(): Load configuration saved during training
  2. load_learner(): Load the exported learner (architecture and best weights) and pass learn.model to the engine (no DataLoaders needed)
  3. PatchInferenceEngine: Sliding-window inference with GridSampler and GridAggregator
  4. patch_inference(): Batch inference with optional NIfTI output

Key points:

  • Preprocessing parameters (normalization, apply_reorder, target_spacing) are carried in the saved config and applied automatically, so they match training by construction
  • Pass the model directly to PatchInferenceEngine; no fastai Learner is needed at inference time
  • Hann windowing (aggregation_mode='hann') produces smooth boundaries between patches
  • keep_largest_component=True can clean up small spurious predictions
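
To show what keeping the largest component means in practice, the sketch below removes all but the largest connected foreground region from a binary mask using scipy.ndimage.label. This is an assumption about the general technique, not a description of how fastMONAI implements keep_largest_component; the helper name and the default face connectivity are illustrative choices.

```python
import numpy as np
from scipy import ndimage

def keep_largest_component(mask):
    """Return a mask containing only the largest connected foreground region."""
    labeled, num_components = ndimage.label(mask)  # default: face connectivity
    if num_components == 0:
        return np.zeros_like(mask)
    sizes = np.bincount(labeled.ravel())
    sizes[0] = 0  # ignore the background label
    return (labeled == sizes.argmax()).astype(mask.dtype)

# Toy mask: one 3x3x3 blob plus a single isolated voxel far away from it
mask = np.zeros((6, 6, 6), dtype=np.uint8)
mask[0:3, 0:3, 0:3] = 1
mask[5, 5, 5] = 1

cleaned = keep_largest_component(mask)
print(int(mask.sum()), int(cleaned.sum()))
```

This kind of post-processing suits single-structure targets such as the left atrium; it would discard valid regions in tasks with several separate lesions.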

When to use patch-based inference:

  • Large images that don’t fit in GPU memory
  • Variable-sized inputs (no resizing needed)
  • Memory-constrained environments

Tradeoffs:

  • Slower than full-image inference (multiple forward passes)
  • Overlap and aggregation add computational overhead
  • In return, it enables processing of arbitrarily large volumes
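
For a rough sense of the cost of multiple forward passes, the sketch below estimates the patch count for the volume shape and configuration shown in this tutorial. It assumes a stride of patch_size * (1 - patch_overlap) and that the final patch along each axis is padded to fit. Both are assumptions for illustration; the actual sampler may pad borders differently, so the real count can differ.

```python
import math

def patches_per_axis(size, patch, overlap_fraction):
    """Number of patch positions along one axis (illustrative estimate)."""
    stride = max(1, int(patch * (1 - overlap_fraction)))
    size = max(size, patch)  # smaller volumes are padded up to one patch
    return math.ceil((size - patch) / stride) + 1

volume_shape = (320, 320, 110)  # spatial shape of the prediction above
patch_size = (128, 128, 64)     # from the loaded configuration
patch_overlap = 0.5
batch_size = 4                  # as passed to PatchInferenceEngine

counts = [patches_per_axis(s, p, patch_overlap)
          for s, p in zip(volume_shape, patch_size)]
n_patches = math.prod(counts)
n_batches = math.ceil(n_patches / batch_size)

print(counts, n_patches, n_batches)  # [4, 4, 3] 48 12
```

Under these assumptions, halving the overlap roughly halves the stride count per axis, so the number of forward passes is highly sensitive to patch_overlap.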