Classification
The following line imports all of the functions and classes from the fastMONAI library:
from fastMONAI.vision_all import *
Downloading external data
To demonstrate the use of fastMONAI, we download the NoduleMNIST3D dataset from MedMNIST v2, a dataset containing lung nodules with labels indicating whether the nodules are benign (0) or malignant (1):
df, _ = download_medmnist3d_dataset(study='NoduleMNIST3D', max_workers=2)
Inspecting the data
Let’s look at how the processed DataFrame is formatted:
df.head()
| | img_path | labels | is_val |
|---|---|---|---|
| 0 | ../data/NoduleMNIST3D/train_images/0_nodule.nii.gz | 0 | False |
| 1 | ../data/NoduleMNIST3D/train_images/1_nodule.nii.gz | 1 | False |
| 2 | ../data/NoduleMNIST3D/train_images/2_nodule.nii.gz | 1 | False |
| 3 | ../data/NoduleMNIST3D/train_images/3_nodule.nii.gz | 0 | False |
| 4 | ../data/NoduleMNIST3D/train_images/4_nodule.nii.gz | 0 | False |
In fastMONAI, various data augmentation techniques are available for training vision models, and they can also optionally be applied during inference. The following code cell specifies a list of transformations to be applied to the items in the training set. The complete list of available transformations in the library can be found at https://fastmonai.no/vision_augment.
item_tfms = [ZNormalization(), PadOrCrop(size=28), RandomAffine(degrees=35, isotropic=True)]
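To make the first two transforms concrete, here is a minimal NumPy sketch of per-volume z-normalization followed by a centered pad-or-crop. It is an illustration under stated assumptions, not the fastMONAI implementation: the helper names `z_normalize` and `pad_or_crop` are hypothetical, padding uses zeros, and the random test volume stands in for a real scan.

```python
import numpy as np

def z_normalize(volume):
    """Rescale a volume to zero mean and unit standard deviation."""
    volume = volume.astype(np.float32)
    std = volume.std()
    if std == 0:  # constant volume: only remove the mean
        return volume - volume.mean()
    return (volume - volume.mean()) / std

def pad_or_crop(volume, size):
    """Center-crop or zero-pad every axis to the requested length."""
    out = volume
    for axis in range(out.ndim):
        n = out.shape[axis]
        if n > size:  # crop symmetrically around the center
            start = (n - size) // 2
            index = [slice(None)] * out.ndim
            index[axis] = slice(start, start + size)
            out = out[tuple(index)]
        elif n < size:  # pad symmetrically with zeros
            before = (size - n) // 2
            pad = [(0, 0)] * out.ndim
            pad[axis] = (before, size - n - before)
            out = np.pad(out, pad, mode="constant")
    return out

# Hypothetical volume whose axes are larger than, equal to, and smaller than 28.
rng = np.random.default_rng(0)
volume = rng.normal(loc=100.0, scale=20.0, size=(30, 28, 25))

prepared = pad_or_crop(z_normalize(volume), size=28)
print(prepared.shape)  # (28, 28, 28)
```

The order mirrors the list above: intensities are normalized first, then the spatial size is fixed so that every item in a batch has the same shape.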
Before feeding the data into a model, we must create a DataLoaders object for our dataset. There are several ways to get the data into DataLoaders. In the following line, we call the MedImageDataLoaders.from_df factory method, which is the most basic way of building a DataLoaders.
Here, we pass the processed DataFrame and define the columns for the images (fn_col) and the labels (label_col), the item transforms (item_tfms), the voxel spacing (resample), and the batch size (bs).
dls = MedImageDataLoaders.from_df(df, fn_col='img_path', label_col='labels',
                                  splitter=ColSplitter('is_val'), item_tfms=item_tfms,
                                  resample=1, bs=64)
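The call above splits the rows on the is_val column and groups them into batches. A rough pure-Python sketch of that idea follows; it is not how fastai implements ColSplitter or batching, the helper names are hypothetical, and the ten rows are made-up stand-ins for the real DataFrame.

```python
def split_by_flag(rows, flag_col):
    """Return (train_indices, valid_indices) based on a boolean column."""
    train_idx = [i for i, row in enumerate(rows) if not row[flag_col]]
    valid_idx = [i for i, row in enumerate(rows) if row[flag_col]]
    return train_idx, valid_idx

def batches(indices, bs):
    """Group indices into consecutive batches of at most bs items."""
    return [indices[i:i + bs] for i in range(0, len(indices), bs)]

# Hypothetical rows that mimic the columns of the processed DataFrame.
rows = [
    {"img_path": f"train_images/{i}_nodule.nii.gz", "labels": i % 2, "is_val": i % 5 == 0}
    for i in range(10)
]

train_idx, valid_idx = split_by_flag(rows, "is_val")
print(train_idx)              # [1, 2, 3, 4, 6, 7, 8, 9]
print(valid_idx)              # [0, 5]
print(batches(train_idx, 3))  # [[1, 2, 3], [4, 6, 7], [8, 9]]
```

In the real pipeline the training batches are also shuffled each epoch, which this sketch leaves out.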
We can now take a look at a batch of images in the training set using show_batch:
dls.show_batch(max_n=2, anatomical_plane=2)
We’re now ready to construct a deep learning classification model.
Create and train a 3D deep learning model
We import a classification network from MONAI and configure it based on our task, including defining the input image size, the number of classes to predict, channels, etc.
from monai.networks.nets import Classifier
model = Classifier(in_shape=[1, 28, 28, 28], classes=2,
                   channels=(8, 16, 32, 64), strides=(2, 2, 2))
Then we create a Learner, which is a fastai object that combines the data and our defined model for training.
learn = Learner(dls, model, metrics=accuracy)
learn.fit_one_cycle(4)
epoch | train_loss | valid_loss | accuracy | time |
---|---|---|---|---|
0 | 0.574831 | 0.409310 | 0.837121 | 00:04 |
1 | 0.477787 | 0.383374 | 0.840909 | 00:03 |
2 | 0.408993 | 0.345690 | 0.859848 | 00:03 |
3 | 0.365673 | 0.357712 | 0.840909 | 00:03 |
Note: Small random variations are involved in training CNN models. Hence, when running the notebook, you may see different results.
With the model trained, let’s look at some predictions on the validation data. The show_results method plots instances, their target values, and their corresponding predictions from the model.
learn.show_results(max_n=2, anatomical_plane=2)
Model evaluation and interpretation
Let’s look at how often and for what instances our trained model becomes confused while making predictions on the validation data:
interp = ClassificationInterpretation.from_learner(learn)
cm = interp.plot_confusion_matrix()
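For readers who want to see what a confusion matrix counts, here is a small pure-Python sketch. The target and prediction lists are invented for illustration and are not results from the model above; rows are the actual class and columns the predicted class.

```python
def confusion_matrix(targets, predictions, n_classes=2):
    """matrix[actual][predicted] = number of instances."""
    matrix = [[0] * n_classes for _ in range(n_classes)]
    for actual, predicted in zip(targets, predictions):
        matrix[actual][predicted] += 1
    return matrix

# Hypothetical labels: 0 = benign, 1 = malignant.
targets     = [0, 0, 0, 0, 1, 1, 0, 1]
predictions = [0, 0, 1, 0, 1, 0, 0, 1]

matrix = confusion_matrix(targets, predictions)
print(matrix)  # [[4, 1], [1, 2]]

# Per-class rates are often more informative than accuracy on imbalanced data.
sensitivity = matrix[1][1] / sum(matrix[1])  # malignant cases correctly found
specificity = matrix[0][0] / sum(matrix[0])  # benign cases correctly found
accuracy_value = (matrix[0][0] + matrix[1][1]) / len(targets)
print(sensitivity, specificity, accuracy_value)
```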
Class imbalance is a common challenge in medical datasets, and it is something we’re facing in our example dataset. When dealing with a classification task on such imbalanced datasets, specific sampling techniques may be necessary. See the Advanced section for further details.
interp.plot_top_losses(k=4, anatomical_plane=2)
Here are the instances our model was most confused about (in other words, most confident but wrong).
The following lines import only the components needed for our classification task:
from fastai.vision.learner import Learner
from fastai.losses import CrossEntropyLossFlat
from fastai.vision.data import ImageBlock, CategoryBlock
from fastai.data.transforms import ColReader, ColSplitter
from fastai.metrics import accuracy
from fastai.callback.schedule import lr_find
from monai.networks.nets import Classifier
from fastMONAI.vision_augmentation import PadOrCrop, RandomAffine, ZNormalization
from fastMONAI.external_data import download_medmnist3d_dataset
from fastMONAI.dataset_info import MedDataset, get_class_weights
from fastMONAI.vision_core import MedImage
from fastMONAI.vision_data import MedDataBlock
from fastMONAI.utils import store_variables
df_train_val, df_test = download_medmnist3d_dataset(study='NoduleMNIST3D', max_workers=8)
df_train_val.head(1)
| | img_path | labels | is_val |
|---|---|---|---|
| 0 | ../data/NoduleMNIST3D/train_images/0_nodule.nii.gz | 0 | False |
med_dataset = MedDataset(img_list=df_train_val.img_path.tolist(), max_workers=12)
data_info_df = med_dataset.summary()
data_info_df.head()
| | dim_0 | dim_1 | dim_2 | voxel_0 | voxel_1 | voxel_2 | orientation | example_path | total |
|---|---|---|---|---|---|---|---|---|---|
| 0 | 28 | 28 | 28 | 1.0 | 1.0 | 1.0 | RAS+ | ../data/NoduleMNIST3D/train_images/0_nodule.nii.gz | 1323 |
resample, reorder = med_dataset.suggestion()
resample, reorder
([1.0, 1.0, 1.0], False)
img_size = med_dataset.get_largest_img_size(resample=resample)
img_size
[28.0, 28.0, 28.0]
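The relationship between image dimensions, voxel spacing, and the size after resampling can be sketched as follows. This is an assumption-based illustration of the general rule that physical extent (voxels times spacing) is preserved, not a description of how get_largest_img_size is implemented; the helper name `resampled_size` and the second example volume are hypothetical.

```python
def resampled_size(dims, spacing, target_spacing):
    """Voxel count per axis after resampling, keeping the physical extent fixed."""
    return [d * s / t for d, s, t in zip(dims, spacing, target_spacing)]

# Values from the dataset summary above: 28 voxels per axis at 1.0 spacing.
print(resampled_size([28, 28, 28], [1.0, 1.0, 1.0], [1.0, 1.0, 1.0]))
# [28.0, 28.0, 28.0]

# Hypothetical anisotropic scan resampled to isotropic 1.0 spacing.
print(resampled_size([512, 512, 100], [0.5, 0.5, 2.0], [1.0, 1.0, 1.0]))
# [256.0, 256.0, 200.0]
```

Because the suggested spacing equals the native spacing here, resampling leaves the 28 x 28 x 28 volumes unchanged.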
bs = 64
in_shape = [1, 28, 28, 28]
item_tfms = [ZNormalization(), PadOrCrop(size=28), RandomAffine(degrees=35, isotropic=True)]
As we mentioned earlier, there are several ways to get the data into DataLoaders. In this section, let’s rebuild it using DataBlock. Here we need to define what our input and target should be (MedImage and CategoryBlock for classification), how to get the images and the labels, how to split the data, which item transforms should be applied during training, whether to reorder voxel orientations, and the voxel spacing. Take a look at fastai’s documentation for DataBlock for further information: https://docs.fast.ai/data.block.html#DataBlock.
dblock = MedDataBlock(blocks=(ImageBlock(cls=MedImage), CategoryBlock),
                      splitter=ColSplitter('is_val'),
                      get_x=ColReader('img_path'),
                      get_y=ColReader('labels'),
                      item_tfms=item_tfms,
                      reorder=reorder,
                      resample=resample)
Now we pass our processed DataFrame and the batch size to create a DataLoaders object.
dls = dblock.dataloaders(df_train_val, bs=bs)
len(dls.train_ds.items), len(dls.valid_ds.items)
(1158, 165)
dls.show_batch(max_n=6, figsize=(5, 5), anatomical_plane=2)
model = Classifier(in_shape=in_shape, classes=2,
                   channels=(8, 16, 32, 64), strides=(2, 2, 2))
Choosing a loss function
Class imbalance is a common challenge in medical datasets, and it is something we’re facing in our example dataset:
print(df_train_val.labels.value_counts())
labels
0 986
1 337
Name: count, dtype: int64
There are multiple ways to deal with class imbalance. A straightforward technique is to use balancing weights in the model’s loss function, i.e., penalizing misclassifications for instances belonging to the minority class more heavily than those of the majority class.
train_labels = df_train_val.loc[~df_train_val.is_val]['labels']
class_weights = get_class_weights(train_labels)
print(class_weights)
tensor([0.6709, 1.9627])
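The printed weights are consistent with the common "balanced" heuristic, weight = n_samples / (n_classes * class_count). The sketch below assumes that get_class_weights follows this rule, which the text does not state. The per-class training counts 863 and 295 are also not printed anywhere above; they are back-computed from the printed weights and the 1158 training items, so treat them as inferred values.

```python
def balanced_class_weights(counts):
    """Weight each class inversely to its frequency: n / (k * count)."""
    total = sum(counts)
    n_classes = len(counts)
    return [total / (n_classes * count) for count in counts]

# Inferred training-set counts (assumption): 863 benign, 295 malignant, 1158 in total.
weights = balanced_class_weights([863, 295])
print([round(w, 4) for w in weights])  # [0.6709, 1.9627]
```

With these weights, each class contributes the same total weight to the loss (863 x 0.6709 is approximately 295 x 1.9627), which is what makes errors on the minority class count more per instance.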
loss_func = CrossEntropyLossFlat(weight=class_weights)
learn = Learner(dls, model, loss_func=loss_func, metrics=accuracy)
learn.summary()
Classifier (Input shape: 64 x 1 x 28 x 28 x 28)
============================================================================
Layer (type) Output Shape Param # Trainable
============================================================================
64 x 8 x 14 x 14 x
Conv3d 224 True
InstanceNorm3d 0 False
PReLU 1 True
Conv3d 1736 True
InstanceNorm3d 0 False
PReLU 1 True
Conv3d 224 True
____________________________________________________________________________
64 x 16 x 7 x 7 x 7
Conv3d 3472 True
InstanceNorm3d 0 False
PReLU 1 True
Conv3d 6928 True
InstanceNorm3d 0 False
PReLU 1 True
Conv3d 3472 True
____________________________________________________________________________
64 x 32 x 4 x 4 x 4
Conv3d 13856 True
InstanceNorm3d 0 False
PReLU 1 True
Conv3d 27680 True
InstanceNorm3d 0 False
PReLU 1 True
Conv3d 13856 True
Reshape
____________________________________________________________________________
64 x 2048
Flatten
____________________________________________________________________________
64 x 2
Linear 4098 True
____________________________________________________________________________
Total params: 75,552
Total trainable params: 75,552
Total non-trainable params: 0
Optimizer used: <function Adam at 0x7f368fbefaf0>
Loss function: FlattenedLoss of CrossEntropyLoss()
Callbacks:
- TrainEvalCallback
- CastToTensor
- Recorder
- ProgressCallback
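The shapes and parameter counts in the summary can be checked by hand. The sketch below assumes 3 x 3 x 3 convolutions with padding 1, which is inferred from the counts themselves (1 x 8 x 27 + 8 = 224), and reads the three Conv3d rows per stage as two convolutions on the main path plus one on the residual path; that reading is an interpretation of the summary, not something the text states. Only the first three channel values appear in the summary, matching the three strides.

```python
def conv_out(n, kernel=3, stride=2, padding=1):
    """Output length of a convolution along one spatial axis."""
    return (n + 2 * padding - kernel) // stride + 1

def conv3d_params(c_in, c_out, kernel=3):
    """Weights plus biases of a 3D convolution."""
    return c_in * c_out * kernel ** 3 + c_out

size, c_in, total = 28, 1, 0
for c_out in (8, 16, 32):  # one stage per stride value
    size = conv_out(size)
    stage = (conv3d_params(c_in, c_out) + 1      # strided conv + PReLU
             + conv3d_params(c_out, c_out) + 1   # second conv + PReLU
             + conv3d_params(c_in, c_out))       # conv on the residual path
    total += stage
    print(c_out, size, stage)  # 8 14 2186 / 16 7 13874 / 32 4 55394
    c_in = c_out

flat = c_in * size ** 3        # 32 * 4 * 4 * 4 features after flattening
linear = flat * 2 + 2          # two output classes plus biases
total += linear
print(flat, linear, total)     # 2048 4098 75552
```

The spatial size halves at each stage (28, 14, 7, 4), and the totals agree with the 2048 flattened features, the 4098 parameters of the linear layer, and the 75,552 total parameters reported above.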
lr = learn.lr_find()
learn.fit_one_cycle(4, lr.valley)
epoch | train_loss | valid_loss | accuracy | time |
---|---|---|---|---|
0 | 0.631502 | 0.476412 | 0.842424 | 00:03 |
1 | 0.552307 | 0.455522 | 0.824242 | 00:03 |
2 | 0.505492 | 0.442922 | 0.860606 | 00:03 |
3 | 0.474527 | 0.442131 | 0.860606 | 00:03 |
learn.show_results(anatomical_plane=2)
learn.save('model-2')
Path('models/model-2.pth')
Inference on test data
learn.load('model-2');
dls.valid_ds.items = df_test
preds, targs = learn.get_preds();
accuracy(preds, targs)
TensorBase(0.8194)
Test-time augmentation
Test-time augmentation (TTA) is a technique where you apply data augmentation transforms when making predictions and average the outputs. In addition to often yielding better performance, the variation in the output across the TTA runs can provide some measure of the model’s robustness and sensitivity to augmentations.
preds, targs = learn.tta();
accuracy(preds, targs)
TensorBase(0.8387)
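The averaging idea behind TTA can be sketched with NumPy. This is a toy illustration, not the way learn.tta is implemented: `toy_model` is a hypothetical stand-in for a trained classifier, the flips are arbitrary example augmentations, and the input volume is random.

```python
import numpy as np

def softmax(logits):
    """Convert logits to probabilities in a numerically stable way."""
    exp = np.exp(logits - logits.max())
    return exp / exp.sum()

def toy_model(volume):
    """Hypothetical classifier: two logits from the mean of each half of axis 0."""
    half = volume.shape[0] // 2
    logits = np.array([volume[:half].mean(), volume[half:].mean()])
    return softmax(logits)

def tta_predict(model, volume, augmentations):
    """Average class probabilities over augmented copies of one input."""
    probs = np.stack([model(augment(volume)) for augment in augmentations])
    return probs.mean(axis=0), probs.std(axis=0)

augmentations = [
    lambda v: v,                   # original volume
    lambda v: np.flip(v, axis=0),  # flip along the first axis
    lambda v: np.flip(v, axis=1),  # flip along the second axis
]

rng = np.random.default_rng(0)
volume = rng.normal(size=(28, 28, 28))

mean_probs, spread = tta_predict(toy_model, volume, augmentations)
print(mean_probs, spread)
```

The mean is the TTA prediction, while the per-class standard deviation gives a rough indication of how sensitive the prediction is to the augmentations.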
Export learner
store_variables(pkl_fn='vars.pkl', size=in_shape, reorder=reorder, resample=resample)
learn.export('learner.pkl')
Make a simple web app
Make a simple web application with Gradio and host it on Hugging Face Spaces.
learn = load_learner('learner.pkl', cpu=True)
_, reorder, resample = load_variables(pkl_fn='vars.pkl')
import gradio as gr
gr.Interface(fn=lambda fileobj: gradio_image_classifier(fileobj, learn,
                                                        reorder, resample),
             inputs=['file'],
             outputs=gr.Label(num_top_classes=2),
             examples=[df_test.img_path[0], df_test.img_path[200]],
             title='Example app').launch()
Running on local URL: http://127.0.0.1:7860
To create a public link, set `share=True` in `launch()`.