#all_skip
from fastMONAI.vision_all import *
from monai.apps import DecathlonDataset
# Inference with an exported learner
# Medical Segmentation Decathlon task IDs, keyed by a short organ name.
tasks = {
    "brain": "Task01_BrainTumour",
    "heart": "Task02_Heart",
    "spleen": "Task09_Spleen",
}
# Task to run inference for. The same string is used as the MLflow experiment
# name when the trained model is looked up.
task = tasks["heart"]
# NOTE(review): not referenced in the visible cells; kept in case other cells use it.
model_artifact_path = Path(f"model_artifacts/{task}")
# Local data root. With download=True, DecathlonDataset fetches and extracts the
# task archive here if it is not already present.
path = Path('../data')
path.mkdir(exist_ok=True)

# Only the test split is needed. cache_num=0 means nothing is pre-cached, since
# only the image file paths are used below.
test_data = DecathlonDataset(
    root_dir=path,
    task=task,
    section="test",
    download=True,
    cache_num=0,
    num_workers=3,
)
# 2025-08-29 14:42:58,879 - INFO - Verified 'Task02_Heart.tar', md5: 06ee59366e1e5124267b774dbd654057.
# 2025-08-29 14:42:58,879 - INFO - File exists: ../data/Task02_Heart.tar, skipped downloading.
# 2025-08-29 14:42:58,880 - INFO - Non-empty folder exists in ../data/Task02_Heart, skipped extracting.
# Image file paths of the test split (each item in `test_data.data` is a dict
# with an 'image' entry).
test_imgs = [data['image'] for data in test_data.data]
import mlflow
from mlflow.tracking import MlflowClient
client = MlflowClient()

# The MLflow experiment is named after the Decathlon task. get_experiment_by_name
# returns None when no such experiment exists. Fail with a clear message rather
# than an AttributeError on `experiment.experiment_id`.
experiment = mlflow.get_experiment_by_name(task)
if experiment is None:
    raise RuntimeError(f"No MLflow experiment named {task!r} was found.")

# Get latest run
runs = mlflow.search_runs(
    experiment_ids=[experiment.experiment_id],
    order_by=["start_time DESC"],
    max_results=1,
)
# search_runs returns an empty DataFrame when the experiment has no runs. Guard it
# so `iloc[0]` does not fail with an opaque IndexError.
if runs.empty:
    raise RuntimeError(f"MLflow experiment {task!r} has no runs to load artifacts from.")

latest_run_id = runs.iloc[0].run_id
print(f"Loading artifacts from run: {latest_run_id}")
# Download artifacts
learner_path = client.download_artifacts(latest_run_id, "model/learner.pkl")
config_path = client.download_artifacts(latest_run_id, "config/inference_settings.pkl")

# SECURITY NOTE: load_learner unpickles the file, and unpickling can execute
# arbitrary code. This is presumably acceptable here because the artifact comes
# from this project's own MLflow runs (confirm). Never point it at an untrusted
# run or download.
# cpu=False: do not force-map the model onto the CPU.
learn_inf = load_learner(learner_path, cpu=False)
# `load_learner` uses Python's insecure pickle module, which can execute malicious arbitrary code when loading. Only load files you trust.
# If you only need to load model weights and optimizer state, use the safe `Learner.load` instead.
# The settings file unpacks into three values. The first is discarded; only the
# reorder flag and the resample values (presumably the target voxel spacing;
# confirm) are needed for inference.
_, reorder, resample = load_variables(pkl_fn=config_path)
reorder, resample  # notebook display of the loaded settings
# (False, [1.25, 1.25, 1.37])
# Predicted masks are written to a folder inside the task's data directory.
save_path = Path(f'../data/{task}/pred_masks')
save_path.mkdir(parents=True, exist_ok=True)

# Select one test image to run inference on.
idx = 3
img_fn = test_imgs[idx]
img_fn  # notebook display of the selected file path
# '../data/Task02_Heart/imagesTs/la_001.nii.gz'
# Predict a mask for the selected image, applying the reorder/resample settings
# loaded from the run's config and saving the result under save_path.
# pred_fn is opened as a LabelMap afterwards, so it is presumably the saved
# mask's path (confirm).
pred_fn = inference(learn_inf, reorder=reorder, resample=resample, fn=img_fn, save_path=save_path)
from torchio import Subject, ScalarImage, LabelMap
# Plot the input image together with the predicted mask for visual inspection.
subject = Subject(image=ScalarImage(img_fn), mask=LabelMap(pred_fn))
subject.plot(figsize=(10, 5))