Patch-Based Cross-Validation and Final Model Training

Cross-validation gives a more reliable estimate of how well a model generalizes than a single train/validation split.

from fastMONAI.vision_all import *

from monai.apps import DecathlonDataset
from sklearn.model_selection import KFold
from monai.networks.nets import UNet
from monai.networks.layers import Norm
from monai.losses import DiceCELoss
import torchio as tio
import gc

Configuration

We use reduced settings so the notebook runs quickly as a demo. Production runs use many more epochs. The reference script research/vs_seg/patch_based_dev/train_cv.py trains each fold for 500 epochs.

EPOCHS = 20
BS = 4
LR = 1e-3
EXPERIMENT = "Task02_Heart_5Fold_CV"

Download the dataset

We use MONAI’s DecathlonDataset to download the Heart MRI dataset from the Medical Segmentation Decathlon. Cross-validation splits all the cases we load into folds, so we keep them together in one DataFrame.

path = Path('../data')
path.mkdir(exist_ok=True)
task = "Task02_Heart"
training_data = DecathlonDataset(root_dir=path, task=task, section="training",
    download=True, cache_num=0, num_workers=3)
2026-07-01 11:28:15,976 - INFO - Verified 'Task02_Heart.tar', md5: 06ee59366e1e5124267b774dbd654057.
2026-07-01 11:28:15,977 - INFO - File exists: ../data/Task02_Heart.tar, skipped downloading.
2026-07-01 11:28:15,978 - INFO - Non-empty folder exists in ../data/Task02_Heart, skipped extracting.
df = pd.DataFrame(training_data.data)
df.shape
(16, 2)

Analyze the dataset

MedDataset analyzes the masks and recommends a target voxel spacing for resampling. We also keep med_dataset.fingerprint (a content hash of the dataset) to tag every MLflow run for reproducibility.
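To make the idea of a content hash concrete, here is a minimal sketch of a dataset fingerprint. It is a hypothetical helper (dataset_fingerprint is not a fastMONAI function, and med_dataset.fingerprint may be computed differently): it hashes each file's name and bytes with SHA-256, sorting the paths first so the result does not depend on list order but does change when any file's content changes.

```python
import hashlib
import tempfile
from pathlib import Path

def dataset_fingerprint(paths, chunk_size=1 << 20):
    """Hypothetical content hash over a list of files (not fastMONAI's implementation)."""
    digest = hashlib.sha256()
    # Sort so the fingerprint does not depend on the order the paths were listed in.
    for p in sorted(Path(x) for x in paths):
        digest.update(p.name.encode())  # include the file name so renames change the hash
        with open(p, 'rb') as fh:
            for chunk in iter(lambda: fh.read(chunk_size), b''):
                digest.update(chunk)
    return digest.hexdigest()[:16]

with tempfile.TemporaryDirectory() as tmp:
    a = Path(tmp) / 'la_003.nii.gz'
    a.write_bytes(b'mask-a')
    b = Path(tmp) / 'la_004.nii.gz'
    b.write_bytes(b'mask-b')
    fp1 = dataset_fingerprint([a, b])
    fp2 = dataset_fingerprint([b, a])   # same files, different order
    b.write_bytes(b'mask-b-edited')
    fp3 = dataset_fingerprint([a, b])   # one file's content changed

print(fp1 == fp2, fp1 == fp3)  # True False
```

Tagging every MLflow run with such a value means two runs can only share a fingerprint if they saw byte-identical data.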

med_dataset = MedDataset(img_list=df.label.tolist(), dtype=MedMask, max_workers=12)
MedDataset cache: 16 cached, 0 processed
data_info_df = med_dataset.summary()
suggestion = med_dataset.get_suggestion()
target_spacing = suggestion['target_spacing']
target_spacing
[1.25, 1.25, 1.37]
stats = med_dataset.get_size_statistics(target_spacing=target_spacing)
print(f"Image sizes (after resampling to {target_spacing}):")
print(f"  Min:    {stats['min']}")
print(f"  Median: {stats['median']}")
print(f"  Max:    {stats['max']}")

suggested_size = suggest_patch_size(med_dataset, target_spacing=target_spacing)
print(f"\nSuggested patch size: {suggested_size}")
Image sizes (after resampling to [1.25, 1.25, 1.37]):
  Min:    [320.0, 320.0, 90.0]
  Median: [320.0, 320.0, 110.0]
  Max:    [320.0, 320.0, 130.0]

Suggested patch size: [256, 256, 80]
Dim 2: patch_size reduced from 96 to 80 to fit smallest volume (min=90, median=110).
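The log line above implies a simple rule: cap each patch dimension at the smallest volume's size, then round down to a multiple of 16 (a UNet with four stride-2 downsamplings needs sizes divisible by 2^4 = 16). The sketch below reproduces that rule under this assumption; fit_patch_dim and fit_patch_size are hypothetical helpers, and suggest_patch_size's actual logic may differ.

```python
def fit_patch_dim(proposed, min_size, divisor=16):
    """Shrink one patch dimension so it fits the smallest volume and stays divisible by `divisor`."""
    capped = min(proposed, min_size)
    fitted = (capped // divisor) * divisor  # round down to a multiple of the divisor
    if fitted == 0:
        raise ValueError(f"min_size={min_size} is smaller than divisor={divisor}")
    return fitted

def fit_patch_size(proposed, min_shape, divisor=16):
    """Apply fit_patch_dim to every spatial dimension."""
    return [fit_patch_dim(int(p), int(m), divisor) for p, m in zip(proposed, min_shape)]

print(fit_patch_dim(96, 90))                                  # 80
print(fit_patch_size([256, 256, 96], [320.0, 320.0, 90.0]))   # [256, 256, 80]
```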

Configure patch-based training

PatchConfig centralizes every patch-related parameter so that training and inference stay consistent. We use the same configuration for all five folds and for the final model, so the cross-validation estimate reflects the deployed model.

  • patch_size: size of extracted patches [x, y, z] (divisible by 16 for UNet)
  • samples_per_volume: patches extracted per volume per epoch
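  • sampler_type / label_probabilities: 'label' sampling chooses patch centres by class, here 50% background and 50% foreground, so the small foreground structure appears in far more patches than uniform sampling would give
  • patch_overlap / aggregation_mode: sliding-window inference settings; 'hann' weights each patch's centre more than its edges, giving smooth seams between patches
  • target_spacing: resampling grid (from the dataset analysis above)
  • normalization: set once here and applied at both training and inference
  • keep_largest_component: post-processing that keeps only the largest connected component of the predicted mask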
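The effect of label_probabilities can be seen in a simplified NumPy sketch of label-based sampling: pick a class according to its probability, then pick a random voxel of that class as the patch centre. This is an illustration under stated assumptions, not PatchConfig's implementation (which presumably delegates to a TorchIO sampler and also handles image borders, which this sketch ignores); sample_patch_centers is a hypothetical helper.

```python
import numpy as np

def sample_patch_centers(label_map, label_probabilities, n_samples, rng):
    """Draw patch centres: pick a class by probability, then a random voxel of that class."""
    classes = [c for c, p in label_probabilities.items() if p > 0 and np.any(label_map == c)]
    probs = np.array([label_probabilities[c] for c in classes], dtype=float)
    probs /= probs.sum()  # renormalise if a class is absent from this volume
    coords_by_class = {c: np.argwhere(label_map == c) for c in classes}
    centers = []
    for cls in rng.choice(classes, size=n_samples, p=probs):
        coords = coords_by_class[int(cls)]
        centers.append(tuple(coords[rng.integers(len(coords))]))
    return centers

rng = np.random.default_rng(0)
mask = np.zeros((64, 64, 32), dtype=np.uint8)
mask[28:36, 28:36, 12:20] = 1  # foreground is 512 of 131072 voxels (about 0.4%)
centers = sample_patch_centers(mask, {0: 0.5, 1: 0.5}, n_samples=2000, rng=rng)
fg_fraction = np.mean([mask[c] == 1 for c in centers])
print(f"foreground-centred patches: {fg_fraction:.2f}")  # close to 0.50, not 0.004
```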
patch_config = PatchConfig(
    patch_size=[160, 160, 80],
    samples_per_volume=8,
    sampler_type='label',
    label_probabilities={0: 0.5, 1: 0.5},
    patch_overlap=0.5,
    keep_largest_component=True,
    target_spacing=target_spacing,
    aggregation_mode='hann',
    # normalization is set once on the config and applied at both training and inference
    # (no need to re-specify it at inference time).
    normalization=[ZNormalization(masking_method='foreground')]
)

print(f"Patch config: {patch_config}")
Patch config: PatchConfig(patch_size=[160, 160, 80], patch_overlap=0.5, samples_per_volume=8, sampler_type='label', label_probabilities={0: 0.5, 1: 0.5}, queue_length=300, queue_num_workers=4, aggregation_mode='hann', apply_reorder=True, target_spacing=[1.25, 1.25, 1.37], preprocessed=False, padding_mode=0, keep_largest_component=True, binary_threshold=0.5, normalization=[{'name': 'ZNormalization', 'masking_method': 'foreground', 'channel_wise': True}])
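The patch_overlap=0.5 and aggregation_mode='hann' settings control how overlapping patch predictions are stitched together at inference. A one-dimensional NumPy sketch of Hann-weighted aggregation follows; it is a simplified illustration (blend_1d is hypothetical, and the library's 3D aggregator may differ in details such as how it handles window edges). Each window's output is multiplied by a Hann window, accumulated, and divided by the summed weights, so windows that disagree blend gradually instead of producing a visible seam.

```python
import numpy as np

def blend_1d(length, patch, overlap, predict):
    """Sliding-window 1D aggregation with Hann weights (simplified sketch)."""
    step = max(1, int(patch * (1 - overlap)))
    starts = list(range(0, length - patch + 1, step))
    if starts[-1] != length - patch:
        starts.append(length - patch)  # make sure the last window reaches the end
    # np.hanning is zero at both ends; a small floor keeps edge voxels from getting zero weight.
    weight = np.hanning(patch) + 1e-3
    acc = np.zeros(length)
    norm = np.zeros(length)
    for s in starts:
        acc[s:s + patch] += weight * predict(s, patch)
        norm[s:s + patch] += weight
    return acc / norm

ones = blend_1d(100, patch=32, overlap=0.5, predict=lambda s, p: np.ones(p))
print(np.allclose(ones, 1.0))  # True: a weighted average of identical predictions
# Windows that disagree (each predicts its own start offset) blend gradually instead of jumping.
blended = blend_1d(100, patch=32, overlap=0.5, predict=lambda s, p: np.full(p, float(s)))
print(np.abs(np.diff(blended)).max())
```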

Define augmentations

Normalization is already set on the PatchConfig (applied at both training and inference). For the remaining augmentations we use gpu_patch_augmentations, which runs batched on the GPU (roughly 25x faster than the per-sample CPU path) and is applied to each training batch. It is mutually exclusive with patch_tfms. We reuse the same augmentation across all folds and the final model.
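Batched augmentation draws one set of random parameters per sample and then applies the transform to the whole batch in a single vectorised call, rather than calling the transform once per sample. The sketch below shows this with NumPy standing in for GPU tensors and a random gamma adjustment as the example transform; it is an illustration of the batching idea, not the contents of gpu_patch_augmentations, and both helper functions are hypothetical.

```python
import numpy as np

def gamma_per_sample(batch, gammas):
    """Reference path: loop over samples, one transform call each (like a per-sample CPU pipeline)."""
    return np.stack([np.clip(x, 0, 1) ** g for x, g in zip(batch, gammas)])

def gamma_batched(batch, gammas):
    """Batched path: one vectorised op for the whole batch; gammas broadcast over the other dims."""
    g = gammas.reshape(-1, *([1] * (batch.ndim - 1)))  # shape (B, 1, 1, 1, 1)
    return np.clip(batch, 0, 1) ** g

rng = np.random.default_rng(0)
batch = rng.random((4, 1, 16, 16, 8))            # (B, C, X, Y, Z) patches in [0, 1)
gammas = np.exp(rng.uniform(-0.3, 0.3, size=4))  # one random gamma per sample
print(np.allclose(gamma_per_sample(batch, gammas), gamma_batched(batch, gammas)))  # True
```

Both paths give the same result; the batched one issues a single kernel for the batch, which is where most of a GPU speed-up over per-sample processing comes from.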

gpu_aug = gpu_patch_augmentations(patch_config.patch_size, patch_config.target_spacing)

Part 1: Cross-validation (estimate)

We now estimate generalization performance with 5-fold cross-validation:

  1. Split the subjects into 5 folds.
  2. For each fold, train a fresh model on the other 4 folds and evaluate on the held-out fold.
  3. Aggregate the per-case metrics into a mean +/- std summary across all held-out cases.

Each fold gets its own MLflow run, and every model is trained from scratch so no information leaks between folds.

Create cross-validation folds

KFold assigns each subject to exactly one of 5 folds. We store the fold number in a fold column (1 to 5); a subject’s fold is the one in which it is held out for validation (it is used for training in the other four).

kf = KFold(n_splits=5, shuffle=True, random_state=42)
df = df.reset_index(drop=True)
df['fold'] = -1

for fold_num, (_, val_idx) in enumerate(kf.split(df), start=1):
    df.loc[val_idx, 'fold'] = fold_num

df['fold'].value_counts().sort_index()
fold
1    4
2    3
3    3
4    3
5    3
Name: count, dtype: int64

Choosing a fold strategy. Plain KFold is fine when subjects are independent and roughly homogeneous. Real datasets often need:

  • StratifiedKFold to balance a property across folds, for example binning tumour volume into quartiles so each fold has a similar tumour-size distribution (as the VS dataset does).
  • GroupKFold to keep all scans from the same patient in the same fold, preventing leakage when a subject contributes multiple volumes.
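Both alternatives can be wired into the same fold column. The sketch below uses a hypothetical cohort (40 scans from 20 patients, with synthetic volumes) to show StratifiedKFold balancing volume quartiles across folds and GroupKFold keeping each patient's scans in a single fold. The column names volume_ml and patient are illustrative, not from the dataset used here.

```python
import numpy as np
import pandas as pd
from sklearn.model_selection import StratifiedKFold, GroupKFold

rng = np.random.default_rng(0)
# Hypothetical cohort: 40 scans from 20 patients (two scans each) with a tumour volume per scan.
cohort = pd.DataFrame({
    'patient': np.repeat(np.arange(20), 2),
    'volume_ml': rng.lognormal(mean=0.0, sigma=1.0, size=40),
})

# Stratified: bin volume into quartiles and balance those bins across folds.
cohort['vol_bin'] = pd.qcut(cohort['volume_ml'], q=4, labels=False)
skf = StratifiedKFold(n_splits=5, shuffle=True, random_state=42)
cohort['strat_fold'] = -1
for fold, (_, val_idx) in enumerate(skf.split(cohort, cohort['vol_bin']), start=1):
    cohort.loc[val_idx, 'strat_fold'] = fold

# Grouped: every scan of a patient lands in the same fold, so no patient is in both train and val.
gkf = GroupKFold(n_splits=5)
cohort['group_fold'] = -1
for fold, (_, val_idx) in enumerate(gkf.split(cohort, groups=cohort['patient']), start=1):
    cohort.loc[val_idx, 'group_fold'] = fold

print(pd.crosstab(cohort['strat_fold'], cohort['vol_bin']))     # 2 scans per bin per fold
print(cohort.groupby('patient')['group_fold'].nunique().max())  # 1
```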

Define the per-fold train-and-evaluate function

train_one_fold mirrors research/vs_seg/patch_based_dev/train_cv.py. For one fold it:

  1. Splits the data with valid_col='is_val' (the held-out fold is validation).
  2. Builds a fresh UNet, loss, and Learner (no weights carried over between folds).
  3. Trains with EMACheckpoint saving the best smoothed-Dice weights, logging to a per-fold MLflow run.
  4. Runs full-volume patch_inference (with TTA) on the held-out fold and computes the metric suite.
  5. Frees GPU memory before the next fold.
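Step 3 selects the checkpoint on a smoothed Dice rather than the raw per-epoch value. The exact formula inside EMACheckpoint (initialisation, any bias correction) is not shown in this notebook, so the sketch below assumes a plain exponential moving average seeded with the first epoch's value, with momentum=0.9 as in train_one_fold. ema_best_epoch is a hypothetical helper. It shows why smoothing helps: a single lucky spike does not win over a stable plateau.

```python
def ema_best_epoch(values, momentum=0.9):
    """Return (best_epoch, smoothed), where smoothed[i] = momentum*smoothed[i-1] + (1-momentum)*values[i]."""
    smoothed, best_epoch, best = [], None, float('-inf')
    ema = values[0]  # assumption: start the average at the first epoch's value
    for epoch, v in enumerate(values):
        ema = momentum * ema + (1 - momentum) * v
        smoothed.append(ema)
        if ema > best:  # strict improvement, like comp=np.greater
            best, best_epoch = ema, epoch
    return best_epoch, smoothed

dice = [0.5] * 5 + [0.95] + [0.6] * 10  # one lucky spike at epoch 5, then a stable plateau
best_epoch, smoothed = ema_best_epoch(dice)
print(best_epoch)  # 15: the plateau wins, the isolated spike does not
```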

patch_inference returns each prediction in its volume’s original voxel space (it resizes back from the resampled grid), and the ground truth is loaded in that same native space. Surface metrics therefore use the per-case voxel spacing read from the GT file (tio.LabelMap(path).spacing), mirroring train_cv.py. Because voxel spacing varies across the cohort, a single fixed spacing would give wrong HD95/ASSD values in millimetres.
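Why the per-case spacing matters: surface distances are computed between voxel positions, and converting them to millimetres means scaling each axis by that case's spacing. The brute-force sketch below (hd95_mm is a hypothetical helper, not MONAI's calculate_surface_metrics, and it takes boundary points that are assumed to be already extracted) shows the same two-voxel offset along z becoming a different distance in millimetres under two spacings.

```python
import numpy as np

def hd95_mm(pred_pts, gt_pts, spacing_mm):
    """Symmetric 95th-percentile Hausdorff distance between two point sets given in voxel indices."""
    p = np.asarray(pred_pts, dtype=float) * np.asarray(spacing_mm)  # voxel -> millimetres
    g = np.asarray(gt_pts, dtype=float) * np.asarray(spacing_mm)
    d = np.linalg.norm(p[:, None, :] - g[None, :, :], axis=-1)      # all pairwise distances
    return float(np.percentile(np.concatenate([d.min(axis=1), d.min(axis=0)]), 95))

# The same two-voxel offset along z, read with two different voxel spacings:
gt = [(10, 10, 5)]
pred = [(10, 10, 7)]
print(hd95_mm(pred, gt, spacing_mm=(1.25, 1.25, 1.37)))  # about 2.74
print(hd95_mm(pred, gt, spacing_mm=(1.0, 1.0, 2.5)))     # 5.0
```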

from fastMONAI.vision_metrics import (calculate_dsc, calculate_surface_metrics,
                                       calculate_confusion_metrics,
                                       calculate_lesion_detection_rate,
                                       calculate_signed_rve)
def train_one_fold(fold_num, df, patch_config, gpu_aug, fingerprint,
                   epochs=EPOCHS, bs=BS, lr=LR, experiment=EXPERIMENT):
    # Split: the held-out fold is validation, the other four folds are training.
    fold_df = df.copy()
    fold_df['is_val'] = fold_df['fold'] == fold_num

    dls = MedPatchDataLoaders.from_df(
        df=fold_df, img_col='image', mask_col='label', valid_col='is_val',
        patch_config=patch_config, gpu_augmentation=gpu_aug, bs=bs)

    # A fresh model + loss + Learner every fold (no weight leakage between folds).
    model = UNet(spatial_dims=3, in_channels=1, out_channels=2,
                 channels=(16, 32, 64, 128, 256), strides=(2, 2, 2, 2),
                 num_res_units=2, norm=Norm.INSTANCE)
    loss_func = CustomLoss(loss_func=DiceCELoss(
        to_onehot_y=True, softmax=True, include_background=False, batch=True))
    learn = Learner(dls, model, loss_func=loss_func, metrics=[AccumulatedDice(n_classes=2)])

    best_fname = f'best_fold_{fold_num}'
    save_best = EMACheckpoint(monitor='accumulated_dice', momentum=0.9,
                              comp=np.greater, fname=best_fname, with_opt=False)
    mlflow_cb = create_mlflow_callback(
        learn, experiment_name=experiment, run_name=f'fold_{fold_num}',
        extra_tags={'fold': str(fold_num)}, dataset_version=fingerprint)

    learn.fit_one_cycle(epochs, lr, cbs=[mlflow_cb, save_best])
    learn.load(best_fname)

    # Full-volume sliding-window inference on the held-out fold (the real evaluation).
    val_df = fold_df[fold_df['fold'] == fold_num].reset_index(drop=True)
    val_img_paths = val_df['image'].tolist()
    val_mask_paths = val_df['label'].tolist()

    predictions = patch_inference(learner=learn, config=patch_config,
                                  file_paths=val_img_paths,
                                  save_dir=f'predictions/cv_fold_{fold_num}',
                                  progress=True, tta=True)

    results = []
    for i, pred in enumerate(predictions):
        # patch_inference returns predictions in native voxel space, and MedMask.create loads
        # the GT natively too, so both share the original grid. Surface metrics need the
        # per-case voxel spacing read from the GT file (the cohort spacing is non-uniform).
        gt = MedMask.create(val_mask_paths[i])
        pred_5d = pred.unsqueeze(0).float()
        gt_5d = gt.data.unsqueeze(0).float()
        spacing_mm = tio.LabelMap(val_mask_paths[i]).spacing
        sm = calculate_surface_metrics(pred_5d, gt_5d, spacing_mm=spacing_mm)
        results.append({
            'fold': fold_num,
            'image': Path(val_img_paths[i]).name,
            'dsc': calculate_dsc(pred_5d, gt_5d).mean().item(),
            'sensitivity': calculate_confusion_metrics(pred_5d, gt_5d, 'sensitivity').nanmean().item(),
            'precision': calculate_confusion_metrics(pred_5d, gt_5d, 'precision').nanmean().item(),
            'ldr': calculate_lesion_detection_rate(pred_5d, gt_5d).nanmean().item(),
            'rve': calculate_signed_rve(pred_5d, gt_5d).nanmean().item(),
            'hd95_mm': sm['hd95_mm'],
            'assd_mm': sm['assd_mm'],
        })

    results_df = pd.DataFrame(results)
    mlflow_cb.log_metrics_table(results_df.drop(columns=['fold']), display=False)
    mlflow_cb.log_dataframe(results_df)
    print(f"[Fold {fold_num}] DSC: {results_df['dsc'].mean():.4f} +/- {results_df['dsc'].std():.4f}")

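    # Drop every reference to the fold's objects (the callbacks hold the Learner too), collect
    # garbage, then release cached GPU memory before the next fold.
    del learn, model, dls, mlflow_cb, save_best
    gc.collect()
    torch.cuda.empty_cache()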
    return results_df

Run the 5-fold loop

This trains and evaluates all five folds sequentially. Each call returns a per-fold results DataFrame; we collect them for aggregation. (On a single GPU this is the slow part: five full trainings.)

This notebook uses 16 cases. MONAI’s DecathlonDataset(section="training") reserves 20% of Task02_Heart’s 20 labeled cases as a separate validation section that we do not use, so KFold runs over the remaining 16, giving only 3 or 4 held-out cases per fold. Per-fold scores and their spread are therefore noisy, so treat the aggregate as illustrative rather than a precise benchmark; larger datasets give many more cases per fold.
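The case counts follow from two small calculations, sketched below under stated assumptions: DecathlonDataset is assumed to hold out int(n * val_frac) cases with its default val_frac=0.2, and sklearn's KFold gives the first (n % n_splits) folds one extra sample. Both helpers are hypothetical; the result matches the fold counts printed earlier (one fold of 4, four folds of 3).

```python
def decathlon_training_count(n_labeled, val_frac=0.2):
    # Assumption: MONAI keeps int(n * val_frac) cases for section="validation", the rest for "training".
    return n_labeled - int(n_labeled * val_frac)

def kfold_sizes(n_samples, n_splits):
    # sklearn's KFold gives the first (n_samples % n_splits) folds one extra sample.
    base, extra = divmod(n_samples, n_splits)
    return [base + 1 if i < extra else base for i in range(n_splits)]

n_train = decathlon_training_count(20)
print(n_train)                  # 16
print(kfold_sizes(n_train, 5))  # [4, 3, 3, 3, 3]
```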

all_results = [
    train_one_fold(f, df, patch_config, gpu_aug, med_dataset.fingerprint)
    for f in range(1, 6)
]
Logged train/val split (16 rows) to MLflow artifacts
epoch train_loss valid_loss accumulated_dice time
0 1.701012 1.647424 0.069987 00:06
1 1.639387 1.535919 0.133839 00:08
2 1.553601 1.374621 0.272732 00:08
3 1.434225 1.185381 0.377758 00:08
4 1.294279 1.030571 0.325836 00:08
5 1.151020 0.860453 0.423965 00:08
6 1.019822 0.738847 0.492733 00:08
7 0.890736 0.551925 0.643766 00:08
8 0.799126 0.590274 0.572869 00:09
9 0.719957 0.459817 0.680719 00:09
10 0.645591 0.452898 0.670767 00:08
11 0.565130 0.346835 0.753482 00:08
12 0.492632 0.300070 0.788466 00:09
13 0.452296 0.260065 0.824737 00:09
14 0.415433 0.245314 0.835402 00:09
15 0.391586 0.253193 0.826399 00:09
16 0.367455 0.231394 0.839577 00:09
17 0.342011 0.208056 0.863537 00:09
18 0.328732 0.217046 0.853384 00:09
19 0.323015 0.216044 0.852964 00:09

Training finished. Logging model artifacts to MLflow...
Logged best model weights: best_weights.pth
Logged final epoch learner: final_learner.pkl
Loaded best model weights (best_fold_1) for best learner export
Logged best epoch learner: best_learner.pkl
Saved file doesn't contain an optimizer state.
2026/07/01 11:31:16 WARNING mlflow.pytorch: Saving pytorch model by Pickle or CloudPickle format requires exercising caution because these formats rely on Python's object serialization mechanism, which can execute arbitrary code during deserialization.The recommended safe alternative is to set 'export_model' to True to save the pytorch model using the safe graph model format.
2026/07/01 11:31:16 WARNING mlflow.utils.requirements_utils: Found torch version (2.6.0+cu124) contains a local version label (+cu124). MLflow logged a pip requirement for this package as 'torch==2.6.0' without the local version label to make it installable from PyPI. To specify pip requirements containing local version labels, please use `conda_env` or `pip_requirements`.
MLflow run completed. Run ID: df66b6e83362430da80f254e74e40be1
Registered model 'UNet' already exists. Creating a new version of this model...
Created version '53' of model 'UNet'.
Saved file doesn't contain an optimizer state.
Logged metrics table to MLflow run df66b6e83362430da80f254e74e40be1
Logged DataFrame (4 rows) to MLflow run df66b6e83362430da80f254e74e40be1
[Fold 1] DSC: 0.8814 +/- 0.0239
Logged train/val split (16 rows) to MLflow artifacts
epoch train_loss valid_loss accumulated_dice time
0 1.612648 1.548852 0.089702 00:07
1 1.531349 1.409334 0.204033 00:09
2 1.425320 1.234746 0.341356 00:09
3 1.302725 1.077958 0.318564 00:09
4 1.163399 0.875728 0.478166 00:09
5 1.027537 0.749024 0.508049 00:09
6 0.897268 0.608724 0.594825 00:08
7 0.772583 0.437791 0.739318 00:08
8 0.650887 0.372989 0.748023 00:08
9 0.570302 0.284802 0.832174 00:08
10 0.484928 0.315056 0.783244 00:09
11 0.429151 0.265872 0.812273 00:08
12 0.394999 0.232472 0.840002 00:08
13 0.358419 0.226217 0.841759 00:08
14 0.329204 0.217780 0.843075 00:08
15 0.299172 0.223935 0.843498 00:08
16 0.278816 0.231911 0.841079 00:08
17 0.271072 0.196384 0.860136 00:08
18 0.262768 0.195990 0.863490 00:08
19 0.256372 0.205288 0.854965 00:08

Training finished. Logging model artifacts to MLflow...
Logged best model weights: best_weights.pth
Logged final epoch learner: final_learner.pkl
Loaded best model weights (best_fold_2) for best learner export
Logged best epoch learner: best_learner.pkl
Saved file doesn't contain an optimizer state.
2026/07/01 11:35:12 WARNING mlflow.pytorch: Saving pytorch model by Pickle or CloudPickle format requires exercising caution because these formats rely on Python's object serialization mechanism, which can execute arbitrary code during deserialization.The recommended safe alternative is to set 'export_model' to True to save the pytorch model using the safe graph model format.
2026/07/01 11:35:12 WARNING mlflow.utils.requirements_utils: Found torch version (2.6.0+cu124) contains a local version label (+cu124). MLflow logged a pip requirement for this package as 'torch==2.6.0' without the local version label to make it installable from PyPI. To specify pip requirements containing local version labels, please use `conda_env` or `pip_requirements`.
MLflow run completed. Run ID: edffa7dbe1bb4ace929fb3fe6c575db9
Registered model 'UNet' already exists. Creating a new version of this model...
Created version '54' of model 'UNet'.
Saved file doesn't contain an optimizer state.
Logged metrics table to MLflow run edffa7dbe1bb4ace929fb3fe6c575db9
Logged DataFrame (3 rows) to MLflow run edffa7dbe1bb4ace929fb3fe6c575db9
[Fold 2] DSC: 0.8548 +/- 0.0110
Exception in thread QueueFeederThread:
Traceback (most recent call last):
  File "/home/sathiesh/miniconda3/envs/fastmonai-dev/lib/python3.11/multiprocessing/queues.py", line 239, in _feed
    reader_close()
  File "/home/sathiesh/miniconda3/envs/fastmonai-dev/lib/python3.11/multiprocessing/connection.py", line 178, in close
    self._close()
  File "/home/sathiesh/miniconda3/envs/fastmonai-dev/lib/python3.11/multiprocessing/connection.py", line 377, in _close
    _close(self._handle)
OSError: [Errno 9] Bad file descriptor

During handling of the above exception, another exception occurred:

Traceback (most recent call last):
  File "/home/sathiesh/miniconda3/envs/fastmonai-dev/lib/python3.11/threading.py", line 1045, in _bootstrap_inner
    self.run()
  File "/home/sathiesh/miniconda3/envs/fastmonai-dev/lib/python3.11/site-packages/ipykernel/ipkernel.py", line 772, in run_closure
    _threading_Thread_run(self)
  File "/home/sathiesh/miniconda3/envs/fastmonai-dev/lib/python3.11/threading.py", line 982, in run
    self._target(*self._args, **self._kwargs)
  File "/home/sathiesh/miniconda3/envs/fastmonai-dev/lib/python3.11/multiprocessing/queues.py", line 271, in _feed
    queue_sem.release()
ValueError: semaphore or lock released too many times
Logged train/val split (16 rows) to MLflow artifacts
epoch train_loss valid_loss accumulated_dice time
0 1.627658 1.569205 0.072238 00:06
1 1.562548 1.462914 0.143529 00:08
2 1.467380 1.278806 0.303647 00:09
3 1.331324 1.074320 0.363793 00:08
4 1.179066 0.871011 0.454890 00:08
5 1.030407 0.782917 0.446536 00:08
6 0.907171 0.691649 0.506360 00:08
7 0.795361 0.517195 0.641513 00:08
8 0.674236 0.391943 0.738520 00:09
9 0.587636 0.373087 0.737233 00:08
10 0.536868 0.469595 0.632056 00:08
11 0.465647 0.276018 0.808859 00:08
12 0.402727 0.326512 0.751293 00:08
13 0.358276 0.315360 0.750810 00:08
14 0.323637 0.251212 0.815298 00:08
15 0.298074 0.272793 0.796635 00:08
16 0.277327 0.256681 0.807365 00:08
17 0.261703 0.240993 0.814648 00:09
18 0.249138 0.223318 0.826831 00:08
19 0.243259 0.236280 0.819128 00:08

Training finished. Logging model artifacts to MLflow...
Logged best model weights: best_weights.pth
Logged final epoch learner: final_learner.pkl
Loaded best model weights (best_fold_3) for best learner export
Logged best epoch learner: best_learner.pkl
Saved file doesn't contain an optimizer state.
2026/07/01 11:38:57 WARNING mlflow.pytorch: Saving pytorch model by Pickle or CloudPickle format requires exercising caution because these formats rely on Python's object serialization mechanism, which can execute arbitrary code during deserialization.The recommended safe alternative is to set 'export_model' to True to save the pytorch model using the safe graph model format.
2026/07/01 11:38:57 WARNING mlflow.utils.requirements_utils: Found torch version (2.6.0+cu124) contains a local version label (+cu124). MLflow logged a pip requirement for this package as 'torch==2.6.0' without the local version label to make it installable from PyPI. To specify pip requirements containing local version labels, please use `conda_env` or `pip_requirements`.
MLflow run completed. Run ID: 8deae63e58324828870cdd92ce2cbe0c
Registered model 'UNet' already exists. Creating a new version of this model...
Created version '55' of model 'UNet'.
Saved file doesn't contain an optimizer state.
Logged metrics table to MLflow run 8deae63e58324828870cdd92ce2cbe0c
Logged DataFrame (3 rows) to MLflow run 8deae63e58324828870cdd92ce2cbe0c
[Fold 3] DSC: 0.8491 +/- 0.0419
Exception in thread QueueFeederThread:
Traceback (most recent call last):
  File "/home/sathiesh/miniconda3/envs/fastmonai-dev/lib/python3.11/multiprocessing/queues.py", line 239, in _feed
    reader_close()
  File "/home/sathiesh/miniconda3/envs/fastmonai-dev/lib/python3.11/multiprocessing/connection.py", line 178, in close
    self._close()
  File "/home/sathiesh/miniconda3/envs/fastmonai-dev/lib/python3.11/multiprocessing/connection.py", line 377, in _close
    _close(self._handle)
OSError: [Errno 9] Bad file descriptor

During handling of the above exception, another exception occurred:

Traceback (most recent call last):
  File "/home/sathiesh/miniconda3/envs/fastmonai-dev/lib/python3.11/threading.py", line 1045, in _bootstrap_inner
    self.run()
  File "/home/sathiesh/miniconda3/envs/fastmonai-dev/lib/python3.11/site-packages/ipykernel/ipkernel.py", line 772, in run_closure
    _threading_Thread_run(self)
  File "/home/sathiesh/miniconda3/envs/fastmonai-dev/lib/python3.11/threading.py", line 982, in run
    self._target(*self._args, **self._kwargs)
  File "/home/sathiesh/miniconda3/envs/fastmonai-dev/lib/python3.11/multiprocessing/queues.py", line 271, in _feed
    queue_sem.release()
ValueError: semaphore or lock released too many times
Logged train/val split (16 rows) to MLflow artifacts
epoch train_loss valid_loss accumulated_dice time
0 1.407753 1.365662 0.148875 00:06
1 1.338845 1.256114 0.275352 00:08
2 1.242933 1.130846 0.312414 00:08
3 1.118479 0.887514 0.495784 00:09
4 1.025925 0.882973 0.348771 00:09
5 0.919653 0.676784 0.543092 00:09
6 0.784236 0.545908 0.637265 00:09
7 0.688505 0.545615 0.584027 00:09
8 0.593948 0.463441 0.646460 00:09
9 0.519339 0.341473 0.748775 00:09
10 0.456259 0.327666 0.754947 00:09
11 0.396849 0.373677 0.718520 00:09
12 0.354889 0.315748 0.753925 00:09
13 0.323949 0.296111 0.775431 00:09
14 0.290599 0.293732 0.782449 00:09
15 0.276306 0.244641 0.814663 00:09
16 0.248714 0.230736 0.828741 00:09
17 0.237172 0.273702 0.796766 00:09
18 0.234188 0.246298 0.816564 00:09
19 0.223202 0.272461 0.802942 00:09

Training finished. Logging model artifacts to MLflow...
Logged best model weights: best_weights.pth
Logged final epoch learner: final_learner.pkl
Loaded best model weights (best_fold_4) for best learner export
Logged best epoch learner: best_learner.pkl
Saved file doesn't contain an optimizer state.
2026/07/01 11:42:50 WARNING mlflow.pytorch: Saving pytorch model by Pickle or CloudPickle format requires exercising caution because these formats rely on Python's object serialization mechanism, which can execute arbitrary code during deserialization.The recommended safe alternative is to set 'export_model' to True to save the pytorch model using the safe graph model format.
2026/07/01 11:42:50 WARNING mlflow.utils.requirements_utils: Found torch version (2.6.0+cu124) contains a local version label (+cu124). MLflow logged a pip requirement for this package as 'torch==2.6.0' without the local version label to make it installable from PyPI. To specify pip requirements containing local version labels, please use `conda_env` or `pip_requirements`.
MLflow run completed. Run ID: 918e027516604d5e9e86e9f27ecb85ed
Registered model 'UNet' already exists. Creating a new version of this model...
Created version '56' of model 'UNet'.
Saved file doesn't contain an optimizer state.
Logged metrics table to MLflow run 918e027516604d5e9e86e9f27ecb85ed
Logged DataFrame (3 rows) to MLflow run 918e027516604d5e9e86e9f27ecb85ed
[Fold 4] DSC: 0.8226 +/- 0.0703
Exception in thread QueueFeederThread:
Traceback (most recent call last):
  File "/home/sathiesh/miniconda3/envs/fastmonai-dev/lib/python3.11/multiprocessing/queues.py", line 239, in _feed
    reader_close()
  File "/home/sathiesh/miniconda3/envs/fastmonai-dev/lib/python3.11/multiprocessing/connection.py", line 178, in close
    self._close()
  File "/home/sathiesh/miniconda3/envs/fastmonai-dev/lib/python3.11/multiprocessing/connection.py", line 377, in _close
    _close(self._handle)
OSError: [Errno 9] Bad file descriptor

During handling of the above exception, another exception occurred:

Traceback (most recent call last):
  File "/home/sathiesh/miniconda3/envs/fastmonai-dev/lib/python3.11/threading.py", line 1045, in _bootstrap_inner
    self.run()
  File "/home/sathiesh/miniconda3/envs/fastmonai-dev/lib/python3.11/site-packages/ipykernel/ipkernel.py", line 772, in run_closure
    _threading_Thread_run(self)
  File "/home/sathiesh/miniconda3/envs/fastmonai-dev/lib/python3.11/threading.py", line 982, in run
    self._target(*self._args, **self._kwargs)
  File "/home/sathiesh/miniconda3/envs/fastmonai-dev/lib/python3.11/multiprocessing/queues.py", line 271, in _feed
    queue_sem.release()
ValueError: semaphore or lock released too many times
Logged train/val split (16 rows) to MLflow artifacts
epoch train_loss valid_loss accumulated_dice time
0 1.633976 1.574071 0.064475 00:06
1 1.557520 1.455876 0.155453 00:09
2 1.464961 1.290834 0.272214 00:09
3 1.329806 1.076808 0.338858 00:09
4 1.179034 0.942880 0.357032 00:09
5 1.050753 0.922587 0.291638 00:09
6 0.952064 0.710765 0.537526 00:09
7 0.850799 0.565738 0.647201 00:09
8 0.733711 0.555793 0.585369 00:10
9 0.619929 0.505191 0.623079 00:09
10 0.548552 0.351475 0.749176 00:09
11 0.461237 0.380307 0.706646 00:09
12 0.413232 0.362005 0.723976 00:09
13 0.368844 0.338266 0.734550 00:09
14 0.338809 0.280755 0.791673 00:09
15 0.322456 0.313366 0.763290 00:09
16 0.297837 0.276132 0.791697 00:09
17 0.292470 0.272496 0.796069 00:09
18 0.287070 0.264539 0.802578 00:10
19 0.292141 0.265273 0.806502 00:09

Training finished. Logging model artifacts to MLflow...
Logged best model weights: best_weights.pth
Logged final epoch learner: final_learner.pkl
Loaded best model weights (best_fold_5) for best learner export
Logged best epoch learner: best_learner.pkl
Saved file doesn't contain an optimizer state.
2026/07/01 11:46:53 WARNING mlflow.pytorch: Saving pytorch model by Pickle or CloudPickle format requires exercising caution because these formats rely on Python's object serialization mechanism, which can execute arbitrary code during deserialization.The recommended safe alternative is to set 'export_model' to True to save the pytorch model using the safe graph model format.
2026/07/01 11:46:53 WARNING mlflow.utils.requirements_utils: Found torch version (2.6.0+cu124) contains a local version label (+cu124). MLflow logged a pip requirement for this package as 'torch==2.6.0' without the local version label to make it installable from PyPI. To specify pip requirements containing local version labels, please use `conda_env` or `pip_requirements`.
MLflow run completed. Run ID: f0f69e9d48db48ac8102ace2fcb85457
Registered model 'UNet' already exists. Creating a new version of this model...
Created version '57' of model 'UNet'.
Saved file doesn't contain an optimizer state.
Logged metrics table to MLflow run f0f69e9d48db48ac8102ace2fcb85457
Logged DataFrame (3 rows) to MLflow run f0f69e9d48db48ac8102ace2fcb85457
[Fold 5] DSC: 0.7897 +/- 0.0933
Exception ignored in: <function _ConnectionBase.__del__ at 0x7f55ed021c60>
Traceback (most recent call last):
  File "/home/sathiesh/miniconda3/envs/fastmonai-dev/lib/python3.11/multiprocessing/connection.py", line 133, in __del__
    self._close()
  File "/home/sathiesh/miniconda3/envs/fastmonai-dev/lib/python3.11/multiprocessing/connection.py", line 377, in _close
    _close(self._handle)
OSError: [Errno 9] Bad file descriptor

Aggregate cross-validation results

The cross-validation estimate is the mean +/- std across all held-out cases (pooled over the five folds) for each metric; the per-fold DSC block below is where you see fold-to-fold variation. We exclude non-finite hd95_mm / assd_mm values (which occur when exactly one of the prediction or ground truth is empty for a case) from those means, mirroring aggregate_results in train_cv.py.
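Two details of this aggregation are easy to get wrong. With folds of unequal size, the pooled mean over cases is not the same as the mean of the per-fold means, and a single infinite surface distance makes a plain mean infinite (pandas skips NaN but not inf). The toy table below is hypothetical and only illustrates both points.

```python
import numpy as np
import pandas as pd

# Hypothetical per-case results: folds of unequal size, one case with an infinite HD95.
res = pd.DataFrame({
    'fold':    [1, 1, 1, 1, 2, 2, 2],
    'dsc':     [0.9, 0.9, 0.9, 0.9, 0.6, 0.6, 0.6],
    'hd95_mm': [5.0, 6.0, 7.0, 8.0, 9.0, 10.0, np.inf],
})

pooled_dsc = res['dsc'].mean()                                 # every held-out case weighs the same
mean_of_fold_means = res.groupby('fold')['dsc'].mean().mean()  # every fold weighs the same
finite_hd95 = res['hd95_mm'][np.isfinite(res['hd95_mm'])]
print(round(pooled_dsc, 4), round(mean_of_fold_means, 4))  # 0.7714 0.75
print(res['hd95_mm'].mean(), finite_hd95.mean())           # inf 7.5
```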

cv_summary = pd.concat(all_results, ignore_index=True)
cv_summary.to_csv('cv_summary.csv', index=False)

metrics = ['dsc', 'sensitivity', 'precision', 'ldr', 'rve', 'hd95_mm', 'assd_mm']
inf_excluded = {'hd95_mm', 'assd_mm'}  # empty-case inf must not poison the mean

print("=" * 46)
print("  CROSS-VALIDATION SUMMARY")
print("=" * 46)
print(f"  Folds completed: {sorted(int(x) for x in cv_summary['fold'].unique())}")
print(f"  Total subjects:  {len(cv_summary)}\n")
print(f"  {'Metric':<14}{'Mean':>10}{'Std':>10}")
print(f"  {'-' * 34}")
for m in metrics:
    col = cv_summary[m]
    note = ""
    if m in inf_excluded:
        col = col[np.isfinite(col)]
        n_skip = len(cv_summary) - len(col)
        if n_skip:
            note = f"  ({n_skip} non-finite excl.)"
    mean = col.mean() if len(col) else float('nan')
    std = col.std() if len(col) else float('nan')
    print(f"  {m:<14}{mean:>10.4f}{std:>10.4f}{note}")

print("\n  Per-fold DSC:")
for fold_num, group in cv_summary.groupby('fold'):
    print(f"    Fold {fold_num}: {group['dsc'].mean():.4f} +/- {group['dsc'].std():.4f}")

print("\n  Saved to cv_summary.csv")
cv_summary
==============================================
  CROSS-VALIDATION SUMMARY
==============================================
  Folds completed: [1, 2, 3, 4, 5]
  Total subjects:  16

  Metric              Mean       Std
  ----------------------------------
  dsc               0.8421    0.0571
  sensitivity       0.8349    0.0593
  precision         0.8566    0.0920
  ldr               1.0000    0.0000
  rve              -0.0134    0.1328
  hd95_mm          12.6655    5.6159
  assd_mm           2.6557    1.1787

  Per-fold DSC:
    Fold 1: 0.8814 +/- 0.0239
    Fold 2: 0.8548 +/- 0.0110
    Fold 3: 0.8491 +/- 0.0419
    Fold 4: 0.8226 +/- 0.0703
    Fold 5: 0.7897 +/- 0.0933

  Saved to cv_summary.csv
fold image dsc sensitivity precision ldr rve hd95_mm assd_mm
0 1 la_030.nii.gz 0.893506 0.883842 0.903383 1.0 -0.021631 10.986054 1.785649
1 1 la_024.nii.gz 0.868088 0.903248 0.835562 1.0 0.081007 10.905057 2.101182
2 1 la_023.nii.gz 0.908257 0.866054 0.954783 1.0 -0.092931 5.332860 1.423619
3 1 la_004.nii.gz 0.855608 0.828063 0.885048 1.0 -0.064386 18.868773 2.949089
4 2 la_011.nii.gz 0.855397 0.869438 0.841803 1.0 0.032829 12.456079 2.651618
5 2 la_022.nii.gz 0.865549 0.928974 0.810232 1.0 0.146553 8.220000 1.950285
6 2 la_007.nii.gz 0.843481 0.831524 0.855788 1.0 -0.028353 13.512859 2.807123
7 3 la_021.nii.gz 0.860764 0.820345 0.905371 1.0 -0.093913 10.821344 2.176539
8 3 la_029.nii.gz 0.883969 0.872420 0.895827 1.0 -0.026129 6.065236 1.582677
9 3 la_014.nii.gz 0.802560 0.679051 0.980985 1.0 -0.307786 15.377025 3.351596
10 4 la_009.nii.gz 0.819264 0.788457 0.852575 1.0 -0.075205 13.438181 2.660054
11 4 la_020.nii.gz 0.754044 0.795818 0.716437 1.0 0.110799 16.732290 3.771815
12 4 la_016.nii.gz 0.894491 0.858836 0.933234 1.0 -0.079721 8.646152 1.781026
13 5 la_018.nii.gz 0.834325 0.851660 0.817682 1.0 0.041554 12.574852 2.530277
14 5 la_017.nii.gz 0.682516 0.778898 0.607360 1.0 0.282432 28.770000 6.364560
15 5 la_005.nii.gz 0.852396 0.801687 0.909954 1.0 -0.118981 9.941730 2.603577

A note on surface metrics: grid-based vs mesh-based

The HD95 and ASSD values above come from calculate_surface_metrics, which uses MONAI’s grid-based implementation. Grid-based boundary metrics have a known voxel-discretization (staircasing) bias and can disagree between implementations. For publication-grade surface distances, a mesh-based library such as MeshMetrics extracts a surface mesh from the mask and measures continuous, sub-voxel point-to-mesh distances, removing the voxel-size rounding that grid methods have. The mesh is still an approximation of the true surface, just a much finer one than the voxel staircase. It installs from GitHub (pip install git+https://github.com/gasperpodobnik/MeshMetrics.git) and can be swapped in for final benchmark reporting. See MeshMetrics (arXiv:2509.05670) and the implementation-pitfalls study (arXiv:2410.02630).
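The staircasing effect can be measured directly in 2D. The sketch below digitises a disk of radius 10.3 on a unit grid, takes the boundary pixels (inside the disk, with at least one 4-neighbour outside), and measures how far their centres sit from the true circle. It only illustrates the discretization error that grid-based metrics inherit; it is not MONAI's or MeshMetrics' algorithm, and boundary_radius_error is a hypothetical helper.

```python
import numpy as np

def boundary_radius_error(radius, grid=64):
    """Digitise a disk on a unit grid and measure how far its boundary pixels sit from the true circle."""
    idx = np.arange(grid) - grid / 2
    yy, xx = np.meshgrid(idx, idx, indexing='ij')
    dist = np.hypot(yy, xx)
    inside = dist <= radius
    # A boundary pixel is inside the disk but has at least one 4-neighbour outside it.
    padded = np.pad(inside, 1, constant_values=False)
    neighbours_inside = (padded[:-2, 1:-1] & padded[2:, 1:-1] &
                         padded[1:-1, :-2] & padded[1:-1, 2:])
    boundary = inside & ~neighbours_inside
    err = radius - dist[boundary]  # >= 0: boundary pixel centres lie on or inside the true circle
    return err.mean(), err.max()

mean_err, max_err = boundary_radius_error(10.3)
print(f"mean {mean_err:.2f} px, max {max_err:.2f} px")  # sub-voxel but non-zero
```

Because every boundary centre is off the true surface by somewhere between 0 and 1 voxel, distances built from these points carry an error on the order of the voxel size, which is the rounding a mesh-based method reduces.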

Part 2: Final model on all data (deploy)

Cross-validation in Part 1 gave us a performance estimate, but it produced five models, each trained on only 4/5 of the data. None of them is the model we deploy.

For deployment we train one model on all available data, with no fold held out, so it sees every example. We still want something to watch during training, so (following research/vs_seg/final_training/train_final.py) we duplicate a small random subset of the training data into a nominal validation split. It is used only to plot the loss and pseudo-Dice curves. It is not a held-out test set: those subjects are also in training, so their scores are optimistic. The performance number we report is the cross-validation estimate from Part 1.

Build the all-data DataLoaders

Every subject is tagged is_val=False (training), and we append a duplicated random subset tagged is_val=True for monitoring. MedPatchDataLoaders.from_df splits on is_val and applies the reorder, resample, and normalization settings from the same PatchConfig used in Part 1, so preprocessing stays identical.

SEED = 42
NOMINAL_VAL_PCT = 0.15

# All subjects go into training. A small random subset is also duplicated into a
# nominal validation split, used only to watch the loss and pseudo-Dice during
# training (it is not a held-out test set).
n_val = max(2, int(len(df) * NOMINAL_VAL_PCT))
nominal_val_df = df.sample(n=n_val, random_state=SEED).copy()

all_train_df = df.copy()
all_train_df['is_val'] = False
nominal_val_df['is_val'] = True
final_df = pd.concat([all_train_df, nominal_val_df], ignore_index=True)

dls = MedPatchDataLoaders.from_df(
    df=final_df, img_col='image', mask_col='label', valid_col='is_val',
    patch_config=patch_config, gpu_augmentation=gpu_aug, bs=BS)

print(f"Training subjects (all data):                  {len(dls.train.subjects_dataset)}")
print(f"Nominal validation subjects (monitoring only): {len(dls.valid.subjects_dataset)}")
Training subjects (all data):                  16
Nominal validation subjects (monitoring only): 2
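The split logic above can be wrapped in a function and checked on a toy table. This sketch reproduces the same pandas steps as the cell above; the function name build_final_df and the toy file names are hypothetical. The two assertions make the key property explicit: every subject is trained on, and the monitoring subjects are a subset of the training subjects rather than a held-out set.

```python
import pandas as pd


def build_final_df(df, val_pct=0.15, min_val=2, seed=42):
    """All rows train (is_val=False) plus a duplicated random subset flagged is_val=True."""
    n_val = max(min_val, int(len(df) * val_pct))
    nominal_val = df.sample(n=n_val, random_state=seed).copy()
    train = df.copy()
    train["is_val"] = False
    nominal_val["is_val"] = True
    return pd.concat([train, nominal_val], ignore_index=True)


# Toy stand-in for the Decathlon DataFrame (hypothetical file names)
toy = pd.DataFrame({"image": [f"img_{i:03d}.nii.gz" for i in range(16)],
                    "label": [f"lbl_{i:03d}.nii.gz" for i in range(16)]})
final = build_final_df(toy)

train_imgs = set(final.loc[~final["is_val"], "image"])
val_imgs = set(final.loc[final["is_val"], "image"])
assert train_imgs == set(toy["image"])  # every subject is trained on
assert val_imgs <= train_imgs           # monitoring subjects are NOT held out
print(len(train_imgs), len(val_imgs))
```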

Train the final model

We build a fresh UNet, loss, and Learner with the same architecture and configuration as the fold models, and train them on all data. The MLflow run is tagged training_type=final to distinguish it from the fold runs.

model = UNet(spatial_dims=3, in_channels=1, out_channels=2,
             channels=(16, 32, 64, 128, 256), strides=(2, 2, 2, 2),
             num_res_units=2, norm=Norm.INSTANCE)

loss_func = CustomLoss(loss_func=DiceCELoss(
    to_onehot_y=True, softmax=True, include_background=False, batch=True))

learn = Learner(dls, model, loss_func=loss_func, metrics=[AccumulatedDice(n_classes=2)])

mlflow_cb = create_mlflow_callback(
    learn, experiment_name=EXPERIMENT, run_name='final_all_data',
    extra_tags={'training_type': 'final'}, dataset_version=med_dataset.fingerprint)

learn.fit_one_cycle(EPOCHS, LR, cbs=[mlflow_cb])
Logged train/val split (18 rows) to MLflow artifacts
epoch train_loss valid_loss accumulated_dice time
0 1.832774 1.780650 0.027345 00:08
1 1.752023 1.634611 0.065704 00:10
2 1.611906 1.381001 0.237814 00:10
3 1.400263 1.094604 0.297296 00:11
4 1.170122 0.825253 0.469552 00:11
5 0.965562 0.639839 0.570667 00:10
6 0.792902 0.466827 0.691346 00:09
7 0.637594 0.419168 0.707288 00:09
8 0.525203 0.255464 0.839762 00:10
9 0.477103 0.254502 0.832425 00:10
10 0.411660 0.210097 0.856949 00:09
11 0.369581 0.255328 0.819710 00:09
12 0.329591 0.185264 0.869347 00:09
13 0.309124 0.201647 0.850481 00:09
14 0.281584 0.167785 0.879528 00:09
15 0.264362 0.151247 0.892409 00:10
16 0.261826 0.151399 0.888759 00:09
17 0.242328 0.133658 0.905754 00:10
18 0.232283 0.135610 0.904789 00:10
19 0.223154 0.128345 0.909170 00:10
2026/07/01 11:51:06 WARNING mlflow.pytorch: Saving pytorch model by Pickle or CloudPickle format requires exercising caution because these formats rely on Python's object serialization mechanism, which can execute arbitrary code during deserialization.The recommended safe alternative is to set 'export_model' to True to save the pytorch model using the safe graph model format.
2026/07/01 11:51:06 WARNING mlflow.utils.requirements_utils: Found torch version (2.6.0+cu124) contains a local version label (+cu124). MLflow logged a pip requirement for this package as 'torch==2.6.0' without the local version label to make it installable from PyPI. To specify pip requirements containing local version labels, please use `conda_env` or `pip_requirements`.

Training finished. Logging model artifacts to MLflow...
Logged final epoch learner: final_learner.pkl
MLflow run completed. Run ID: b53f9384c29341109aa32b353781370d
Registered model 'UNet' already exists. Creating a new version of this model...
Created version '58' of model 'UNet'.
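The accumulated_dice column reports the pseudo-Dice on the nominal validation split. Assuming "accumulated" means that overlap counts are pooled over all validation batches before a single Dice is computed, the sketch below shows why that can differ from averaging per-batch Dice. AccumulatedDiceSketch is a hypothetical, foreground-only illustration, not fastMONAI's AccumulatedDice class.

```python
import numpy as np


class AccumulatedDiceSketch:
    """Hypothetical illustration: pool overlap counts over batches, then compute one Dice."""

    def __init__(self):
        self.inter = 0  # running |pred AND target|
        self.total = 0  # running |pred| + |target|

    def update(self, pred, target):
        pred = np.asarray(pred, dtype=bool)
        target = np.asarray(target, dtype=bool)
        self.inter += int(np.logical_and(pred, target).sum())
        self.total += int(pred.sum() + target.sum())

    @property
    def value(self):
        return 2.0 * self.inter / self.total if self.total else float("nan")


def dice(pred, target):
    """Plain per-batch Dice for comparison."""
    pred, target = np.asarray(pred, dtype=bool), np.asarray(target, dtype=bool)
    denom = pred.sum() + target.sum()
    return 2.0 * np.logical_and(pred, target).sum() / denom if denom else float("nan")


# Two tiny "batches" of flattened foreground masks: (prediction, target)
batches = [([1, 1, 0, 0], [1, 0, 0, 0]),
           ([0, 0, 0, 0], [0, 0, 0, 1])]
acc = AccumulatedDiceSketch()
for p, t in batches:
    acc.update(p, t)
per_batch_mean = np.mean([dice(p, t) for p, t in batches])
print(acc.value, per_batch_mean)  # 0.5 vs 0.333...
```

Pooling weights each voxel equally, so a batch with almost no foreground cannot drag the score to zero the way it does in a per-batch average.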

Export for deployment

We export the trained Learner and persist the PatchConfig with store_patch_variables. Together, these give the patch inference tutorial everything it needs to load the model and reproduce the exact preprocessing.

Path('models').mkdir(exist_ok=True)
learn.export('models/final_learner.pkl')

store_patch_variables(
    config_fn='patch_config.json',
    patch_size=patch_config.patch_size,
    patch_overlap=patch_config.patch_overlap,
    aggregation_mode=patch_config.aggregation_mode,
    apply_reorder=patch_config.apply_reorder,
    target_spacing=patch_config.target_spacing,
    sampler_type=patch_config.sampler_type,
    label_probabilities=patch_config.label_probabilities,
    samples_per_volume=patch_config.samples_per_volume,
    queue_length=patch_config.queue_length,
    queue_num_workers=patch_config.queue_num_workers,
    keep_largest_component=patch_config.keep_largest_component,
    normalization=patch_config.normalization
)

print("Saved models/final_learner.pkl and patch_config.json")
Saved models/final_learner.pkl and patch_config.json
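The point of persisting the configuration is that inference reloads exactly the preprocessing used in training. As a minimal sketch of that round trip, the snippet saves a hypothetical dataclass holding a few PatchConfig-like fields to JSON and reads it back. The class, helper names, and values are illustrative only (target_spacing matches this notebook's suggestion); the file written by store_patch_variables may use a different layout.

```python
import json
import tempfile
from dataclasses import asdict, dataclass, field
from pathlib import Path


@dataclass
class PatchSettings:
    """Hypothetical stand-in for a few PatchConfig fields (values are illustrative)."""
    patch_size: list = field(default_factory=lambda: [96, 96, 96])
    patch_overlap: float = 0.5
    target_spacing: list = field(default_factory=lambda: [1.25, 1.25, 1.37])
    apply_reorder: bool = True
    keep_largest_component: bool = False


def save_settings(cfg: PatchSettings, fn: Path) -> None:
    fn.write_text(json.dumps(asdict(cfg), indent=2))


def load_settings(fn: Path) -> PatchSettings:
    return PatchSettings(**json.loads(fn.read_text()))


with tempfile.TemporaryDirectory() as tmp:
    fn = Path(tmp) / "patch_config.json"
    cfg = PatchSettings()
    save_settings(cfg, fn)
    restored = load_settings(fn)

assert restored == cfg  # inference sees exactly the training-time settings
```

Lists are used instead of tuples because JSON has no tuple type; a tuple would come back as a list and break the equality check.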

View experiment tracking

All five fold runs (tagged fold=1..5) and the final run (tagged training_type=final) are logged under the Task02_Heart_5Fold_CV experiment.

mlflow_ui = MLflowUIManager()
mlflow_ui.start_ui()
Reusing existing MLflow UI on port 5001
URL: http://localhost:5001
True

Summary

This tutorial covered the full train-and-ship workflow for patch-based 3D segmentation:

Part 1, cross-validation (estimate)

  1. Fold generation: KFold assigns each subject to one validation fold.
  2. Per-fold loop: a fresh Learner per fold (no weight leakage), trained on 4 folds and evaluated on the 5th.
  3. Per-fold metrics: full-volume patch_inference plus the metric suite (DSC, sensitivity, precision, LDR, RVE, HD95, ASSD), with each fold logged as its own MLflow run.
  4. Aggregate: the mean and std across all held-out cases is the generalization estimate.
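Steps 1 and 4 rest on one property: each subject lands in exactly one validation fold. The sketch below checks that with scikit-learn's KFold on 16 hypothetical subject IDs. Whether the notebook shuffles and which seed it uses are assumptions here; fold sizes do not depend on shuffling. With 16 subjects and 5 folds, KFold gives the first fold one extra subject, consistent with the per-case table above, where folds 2 to 5 hold three cases each.

```python
import numpy as np
from sklearn.model_selection import KFold

subjects = np.array([f"subject_{i:02d}" for i in range(16)])  # hypothetical IDs
kf = KFold(n_splits=5, shuffle=True, random_state=42)  # shuffle/seed are assumptions

val_count = np.zeros(len(subjects), dtype=int)  # how often each subject is validated
fold_of = np.zeros(len(subjects), dtype=int)    # which fold validates each subject
for fold, (train_idx, val_idx) in enumerate(kf.split(subjects), start=1):
    assert len(np.intersect1d(train_idx, val_idx)) == 0  # no train/val overlap in a fold
    val_count[val_idx] += 1
    fold_of[val_idx] = fold

assert (val_count == 1).all()  # every subject is held out exactly once
fold_sizes = np.bincount(fold_of)[1:].tolist()
print(fold_sizes)  # [4, 3, 3, 3, 3]
```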

Part 2, final model (deploy)

  1. Train on all data: one model on every subject (a small duplicated nominal-val subset for monitoring only).
  2. Export: learn.export(...) + store_patch_variables(...) for inference (see 12c).

In short: cross-validate to estimate, retrain on all data to deploy. The CV mean and std are your reported performance; the all-data model is what you ship.
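As a sketch of the aggregation step, the snippet pools the per-case DSC values for the 12 cases from folds 2 to 5 in the per-case table above (fold 1 is left out only to keep the snippet short; the full estimate includes it). It contrasts the pooled mean over held-out cases with the mean of per-fold means. With equal fold sizes, as here, the two coincide; with this dataset's 4/3/3/3/3 split they can differ slightly. Reporting the sample standard deviation (pandas' default, ddof=1) is a convention choice worth stating alongside the number.

```python
import pandas as pd

# Per-case DSC for folds 2-5, copied from the per-case table above
res = pd.DataFrame({
    "fold": [2, 2, 2, 3, 3, 3, 4, 4, 4, 5, 5, 5],
    "dsc": [0.855397, 0.865549, 0.843481,
            0.860764, 0.883969, 0.802560,
            0.819264, 0.754044, 0.894491,
            0.834325, 0.682516, 0.852396],
})

# Pooled over held-out cases: each case counts once
pooled_mean = res["dsc"].mean()
pooled_std = res["dsc"].std()  # pandas default: sample std (ddof=1)

# Mean of per-fold means: each fold counts once
fold_means = res.groupby("fold")["dsc"].mean()

print(f"pooled DSC: {pooled_mean:.3f} ± {pooled_std:.3f}")
print(f"mean of fold means: {fold_means.mean():.3f}")
```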

mlflow_ui.stop()
MLflow UI was started externally — not stopping it