from fastMONAI.vision_all import *
from monai.apps import DecathlonDataset
from sklearn.model_selection import train_test_split
Patch-Based Inference
PatchInferenceEngine.
Load data
Load the same dataset and recreate the train/test split from the training notebook.
path = Path('../data')
path.mkdir(exist_ok=True)
task = "Task02_Heart"
training_data = DecathlonDataset(root_dir=path, task=task, section="training",
                                 download=True, cache_num=0, num_workers=3)
2026-02-03 15:10:15,182 - INFO - Verified 'Task02_Heart.tar', md5: 06ee59366e1e5124267b774dbd654057.
2026-02-03 15:10:15,182 - INFO - File exists: ../data/Task02_Heart.tar, skipped downloading.
2026-02-03 15:10:15,183 - INFO - Non-empty folder exists in ../data/Task02_Heart, skipped extracting.
df = pd.DataFrame(training_data.data)
train_df, test_df = train_test_split(df, test_size=0.1, random_state=42)
print(f"Test samples: {len(test_df)}")
test_df
Test samples: 2
| | image | label |
|---|---|---|
| 0 | ../data/Task02_Heart/imagesTr/la_030.nii.gz | ../data/Task02_Heart/labelsTr/la_030.nii.gz |
| 1 | ../data/Task02_Heart/imagesTr/la_024.nii.gz | ../data/Task02_Heart/labelsTr/la_024.nii.gz |
Load configuration
Load the patch configuration saved during training. This ensures preprocessing consistency:
- apply_reorder: Whether to reorder to RAS+ orientation
- target_spacing: Target voxel spacing for resampling
- patch_size, patch_overlap, aggregation_mode: Inference parameters
Critical: Mismatched preprocessing parameters will produce incorrect predictions (e.g., mirrored or rotated outputs).
config_dict = load_patch_variables('patch_config.pkl')
print("Loaded configuration:")
for k, v in config_dict.items():
    print(f" {k}: {v}")
Loaded configuration:
patch_size: [128, 128, 64]
patch_overlap: 0.5
aggregation_mode: hann
apply_reorder: True
target_spacing: [1.25, 1.25, 1.37]
sampler_type: label
label_probabilities: {0: 0.2, 1: 0.8}
samples_per_volume: 8
queue_length: 300
queue_num_workers: 4
keep_largest_component: True
patch_config = PatchConfig(**config_dict)
apply_reorder = config_dict['apply_reorder']
target_spacing = config_dict['target_spacing']
Load trained model
Load the exported learner which contains both the model architecture and best weights.
Two options:
1. Local file: Load best_learner.pkl exported during training
2. MLflow: Download from MLflow artifacts (for experiment tracking workflows)
from fastai.learner import load_learner
import torch
# Option 1: Load from local file
learn = load_learner('models/best_learner.pkl')
# Option 2: Load from MLflow (uncomment to use)
# import mlflow
# run_id = "your_run_id" # Get from MLflow UI
# mlflow.artifacts.download_artifacts(run_id=run_id, artifact_path="model/best_learner.pkl", dst_path="./")
# learn = load_learner('best_learner.pkl')
model = learn.model
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
model.to(device)
model.eval()
print(f"Loaded model: {model.__class__.__name__}")
print(f"Device: {device}")
Normalization (loaded from the config)
Normalization was saved in the PatchConfig during training, so it is already present in the loaded patch_config and is applied automatically – nothing to re-specify here.
Troubleshooting: If predictions appear mirrored, rotated, or otherwise wrong, check that apply_reorder, target_spacing, and normalization in the loaded config match training. To override, pass pre_inference_tfms= explicitly (e.g. for non-serializable transforms).
# Normalization is already in the loaded config and applied automatically:
print("Normalization from config:", patch_config.normalization)
Create PatchInferenceEngine
The engine handles the complete sliding-window inference pipeline:
1. Load and preprocess the image (reorder, resample, normalize)
2. Pad if the image is smaller than the patch size
3. Extract overlapping patches with GridSampler
4. Predict on batches of patches
5. Reconstruct the full volume with GridAggregator using Hann windowing
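The extract, predict, and aggregate steps can be illustrated in one dimension. The sketch below is not fastMONAI's implementation; it is a minimal NumPy illustration that assumes a step of patch * (1 - overlap), a final patch shifted back to fit the input, and a Hann window with its zero endpoints dropped. The names `patch_starts` and `sliding_window_1d` are hypothetical, and the identity function stands in for the network.

```python
import numpy as np

def patch_starts(length, patch, step):
    """Start indices so that patches of size `patch` cover [0, length)."""
    if length < patch:
        raise ValueError("pad the input to at least one patch first")
    starts = list(range(0, length - patch, step))
    starts.append(length - patch)  # last patch is shifted back to fit
    return starts

def sliding_window_1d(signal, model, patch=8, overlap=0.5):
    """Predict patch by patch and blend the overlaps with Hann weights."""
    step = max(1, int(patch * (1 - overlap)))
    # Drop the zero endpoints of the Hann window so every position gets weight > 0
    window = np.hanning(patch + 2)[1:-1]
    summed = np.zeros(len(signal), dtype=float)
    weights = np.zeros(len(signal), dtype=float)
    for start in patch_starts(len(signal), patch, step):
        stop = start + patch
        summed[start:stop] += window * model(signal[start:stop])
        weights[start:stop] += window
    return summed / weights  # weighted average over overlapping patches

# Hypothetical stand-in for a network: the identity function
signal = np.arange(20, dtype=float)
restored = sliding_window_1d(signal, model=lambda x: x)
print(np.allclose(restored, signal))
```

Because each output position is a weighted average of the patch predictions that cover it, an identity model reproduces the input exactly. With a real network, the Hann weights down-weight predictions near patch borders, where context is limited.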
engine = PatchInferenceEngine(
    learner=model,  # Pass model directly (no Learner needed)
    config=patch_config,
    batch_size=4
)
Single image inference
Use engine.predict() to predict on a single image.
test_path = test_df.iloc[0]['image']
pred, affine = engine.predict(test_path, return_affine=True)
print(f"Input: {test_path}")
print(f"Prediction shape: {pred.shape}")
print(f"Unique values: {torch.unique(pred).tolist()}")
Input: ../data/Task02_Heart/imagesTr/la_030.nii.gz
Prediction shape: torch.Size([1, 320, 320, 110])
Unique values: [0, 1]
Batch inference
Use patch_inference() to predict on multiple images with optional NIfTI output.
test_paths = test_df['image'].tolist()
predictions = patch_inference(
    learner=model,  # Pass model directly (no Learner needed)
    config=patch_config,
    file_paths=test_paths,
    save_dir='predictions/patch_heart',
    progress=True
)
print(f"\nGenerated {len(predictions)} predictions")
print(f"Saved to: predictions/patch_heart/")
Generated 2 predictions
Saved to: predictions/patch_heart/
Evaluate predictions
Evaluate predictions against ground truth using comprehensive metrics.
Metrics:
- DSC: Dice score - overlap similarity (higher = better)
- HD95: 95th percentile Hausdorff distance in voxels (lower = better)
- Sens: Sensitivity - true positive rate (higher = better)
- LDR: Lesion detection rate (higher = better)
- RVE: Relative volume error (0 = optimal, + = over-seg, - = under-seg)
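To make the overlap-based metrics concrete, here is a small NumPy sketch of DSC, sensitivity, and signed RVE on a toy binary mask. It uses the standard textbook definitions and hypothetical helper names (`dice`, `sensitivity`, `signed_rve`); the fastMONAI `calculate_*` functions may differ in details such as batching and handling of empty masks.

```python
import numpy as np

def dice(pred, gt):
    """Dice similarity coefficient: 2|P & G| / (|P| + |G|)."""
    pred, gt = pred.astype(bool), gt.astype(bool)
    denom = pred.sum() + gt.sum()
    return 2.0 * np.logical_and(pred, gt).sum() / denom if denom else float("nan")

def sensitivity(pred, gt):
    """True positive rate: TP / (TP + FN)."""
    pred, gt = pred.astype(bool), gt.astype(bool)
    tp = np.logical_and(pred, gt).sum()
    fn = np.logical_and(~pred, gt).sum()
    return tp / (tp + fn) if (tp + fn) else float("nan")

def signed_rve(pred, gt):
    """Signed relative volume error: (|P| - |G|) / |G|; negative = under-segmentation."""
    gt_volume = gt.astype(bool).sum()
    return (pred.astype(bool).sum() - gt_volume) / gt_volume if gt_volume else float("nan")

# Toy example: the ground truth has 4 foreground voxels, the prediction misses one
gt = np.zeros((4, 4), dtype=int)
gt[1:3, 1:3] = 1
pred = gt.copy()
pred[2, 2] = 0

print(f"DSC:  {dice(pred, gt):.4f}")
print(f"Sens: {sensitivity(pred, gt):.4f}")
print(f"RVE:  {signed_rve(pred, gt):+.4f}")
```

In this toy case the prediction is a strict subset of the ground truth, so sensitivity drops below 1 and the signed RVE is negative, the same under-segmentation pattern as a negative RVE on real data.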
from fastMONAI.vision_metrics import (calculate_dsc, calculate_haus,
calculate_confusion_metrics,
calculate_lesion_detection_rate,
calculate_signed_rve)
results = []
for i, pred in enumerate(predictions):
    img_name = Path(test_df.iloc[i]['image']).name
    gt_path = test_df.iloc[i]['label']
    # Load ground truth with matching preprocessing
    gt = MedMask.create(gt_path, apply_reorder=apply_reorder, target_spacing=target_spacing)
    # Prepare 5D tensors [B, C, D, H, W]
    pred_5d = pred.unsqueeze(0).float()
    gt_5d = gt.data.unsqueeze(0).float()
    # Compute metrics using calculate_* functions (for binary masks)
    dsc = calculate_dsc(pred_5d, gt_5d).mean().item()
    hd95 = calculate_haus(pred_5d, gt_5d).mean().item()
    sens = calculate_confusion_metrics(pred_5d, gt_5d, "sensitivity").nanmean().item()
    ldr = calculate_lesion_detection_rate(pred_5d, gt_5d).nanmean().item()
    rve = calculate_signed_rve(pred_5d, gt_5d).nanmean().item()
    results.append({
        'Image': img_name, 'DSC': dsc, 'HD95': hd95,
        'Sens': sens, 'LDR': ldr, 'RVE': rve
    })
results_df = pd.DataFrame(results)
results_df
monai.metrics.utils get_mask_edges:always_return_as_numpy: Argument `always_return_as_numpy` has been deprecated since version 1.5.0. It will be removed in version 1.7.0. The option is removed and the return type will always be equal to the input type.
| | Image | DSC | HD95 | Sens | LDR | RVE |
|---|---|---|---|---|---|---|
| 0 | la_030.nii.gz | 0.938002 | 3.000000 | 0.918056 | 1.0 | -0.042528 |
| 1 | la_024.nii.gz | 0.922835 | 3.316625 | 0.907937 | 1.0 | -0.032287 |
metrics = ['DSC', 'HD95', 'Sens', 'LDR', 'RVE']
print("Test Set Performance Summary")
print("=" * 45)
for metric in metrics:
    mean = results_df[metric].mean()
    std = results_df[metric].std()
    print(f"{metric:8s}: {mean:.4f} ± {std:.4f}")
Test Set Performance Summary
=============================================
DSC : 0.9304 ± 0.0107
HD95 : 3.1583 ± 0.2239
Sens : 0.9130 ± 0.0072
LDR : 1.0000 ± 0.0000
RVE : -0.0374 ± 0.0072
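A note on the ± values: pandas `Series.std()` defaults to the sample standard deviation (`ddof=1`). With only two test cases, this differs noticeably from the population standard deviation, so the spread is best read as a rough indication. The sketch below reproduces the DSC row from the two values in the results table using only the standard library.

```python
import statistics

dsc = [0.938002, 0.922835]  # DSC values copied from the results table

mean = statistics.mean(dsc)
sample_std = statistics.stdev(dsc)       # n - 1 in the denominator (pandas default)
population_std = statistics.pstdev(dsc)  # n in the denominator

print(f"mean ± sample std:     {mean:.4f} ± {sample_std:.4f}")
print(f"mean ± population std: {mean:.4f} ± {population_std:.4f}")
```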
Visualize predictions
from fastMONAI.vision_plot import show_segmentation_comparison
# Visualize first test case
idx = 0
test_img = MedImage.create(
test_df.iloc[idx]['image'],
apply_reorder=apply_reorder,
target_spacing=target_spacing
)
test_gt = MedMask.create(
test_df.iloc[idx]['label'],
apply_reorder=apply_reorder,
target_spacing=target_spacing
)
show_segmentation_comparison(
image=test_img,
ground_truth=test_gt,
prediction=predictions[idx],
metric_value=results_df.iloc[idx]['DSC'],
voxel_size=target_spacing,
anatomical_plane=2 # axial view
)
Summary
In this tutorial, we demonstrated patch-based inference for 3D medical image segmentation:
- load_patch_variables(): Load configuration saved during training
- load_learner(): Load the exported learner, then pass learn.model to the engine (no DataLoaders needed)
- PatchInferenceEngine: Sliding-window inference with GridSampler and GridAggregator
- patch_inference(): Batch inference with optional NIfTI output
Key points:
- Preprocessing parameters (normalization, apply_reorder, target_spacing) are carried in the saved config and applied automatically, so they match training by construction
- Pass the model directly to PatchInferenceEngine: no need for a fastai Learner at inference time
- Hann windowing (aggregation_mode='hann') produces smooth boundaries between patches
- keep_largest_component=True can clean up small spurious predictions
When to use patch-based inference:
- Large images that don’t fit in GPU memory
- Variable-sized inputs (no resizing needed)
- Memory-constrained environments
Tradeoffs:
- Slower than full-image inference (multiple forward passes)
- Overlap and aggregation add computational overhead
- But enables processing of arbitrarily large volumes
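To get a feel for the cost of the extra forward passes, the sketch below estimates the number of patches for the first test volume (prediction shape 320 x 320 x 110) with the loaded patch size and overlap. It assumes a step of patch * (1 - overlap) per axis with the last patch shifted to fit; the grid the engine actually builds may differ, so treat the numbers as an order-of-magnitude estimate. `patches_per_axis` is a hypothetical helper.

```python
import math

def patches_per_axis(length, patch, overlap):
    """Number of patch positions along one axis (assumed grid layout)."""
    step = max(1, int(patch * (1 - overlap)))
    if length <= patch:
        return 1
    return math.ceil((length - patch) / step) + 1

shape = (320, 320, 110)      # spatial shape of the first prediction
patch_size = (128, 128, 64)  # from the loaded config
overlap = 0.5                # from the loaded config
batch_size = 4               # as passed to PatchInferenceEngine

counts = [patches_per_axis(n, p, overlap) for n, p in zip(shape, patch_size)]
n_patches = math.prod(counts)
n_batches = math.ceil(n_patches / batch_size)

print(f"Patches per axis: {counts}")
print(f"Total patches: {n_patches}, forward passes: {n_batches}")
```

Under these assumptions, halving the overlap roughly halves the patch count along each axis where the volume is much larger than the patch, at the cost of less blending between neighbouring patches.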