In this project, we want to perform image classification using the fastai deep learning framework based on PyTorch. I will present a step-by-step solution to transfer learning in the case where we have multi-channel images.
Most commonly available pre-trained convolutional neural networks (CNNs) have been trained on RGB (3-channel) images, posing a big challenge when dealing with multi-channel images and low sample size. Training a CNN from scratch is not recommended in this particular scenario due to overfitting.
One approach to transfer learning consists in duplicating RGB weights of the first convolution layer on top of the missing channels, cycling at every 3 additional channels (i.e., RGBX image will use RGBR weights, RGBXYZ will use RGBRGB weights and so forth).
Let’s start from scratch by first implementing a custom dataloader that can handle multi-channel images, then normalise the whole dataset and finally perform transfer learning.
1. Data preparation
import os
from pathlib import Path
from PIL import Image
path_to_images = Path(os.getcwd()) / "data"
img = Image.open(path_to_images / "1.jpg")
display(img)
Let’s check some properties of this figure: dimension and type
props = f"shape: {img.size} - type: {img.mode}"
props
'shape: (614, 816) - type: RGB'
In tensor notation, this is a stack of 3 MxN matrices (corresponding to the channels R, G, and B) with values given by pixel intensities in [0, 1]
from torchvision.transforms.functional import pil_to_tensor
x = pil_to_tensor(img)
props = f"shape: {x.shape}"
props
'shape: torch.Size([3, 816, 614])'
The dataset we are using for this tutorial is composed of 12 RGB images. Let’s augment the number of channels for each of them by repeating the 3 channels n times (e.g., n=3) so to generate multi-channel images
import numpy as np
import torch
images = [Image.open(path_to_images / f"{i+1}.jpg") for i in range(12)]
x_list = [pil_to_tensor(img) for img in images]
def repeat(x, n_times=3):
return torch.cat(n_times*[x], dim=0)
n_times = 3
x_aug_list = [repeat(x, n_times) for x in x_list]
props = f"shape: {x_aug_list[0].shape}"
props
'shape: torch.Size([9, 816, 614])'
Let’s save these augmented images as multi-channel .tif images. Also, we construct a pandas dataframe describing a (mock) traning dataset based on the multi-channel images and we save it as a .csv file
import random
import pandas as pd
from torchvision.transforms.functional import to_pil_image
path_to_multi = path_to_images / "multi"
path_to_train = path_to_multi / "train"
path_to_train.mkdir(parents=True, exist_ok=True)
random.seed(8) # enforcing training dataset labels reproducibility
df = pd.DataFrame()
for i, x in enumerate(x_aug_list): # loop over all multi-channel images
img = to_pil_image(x[0])
img_list = [to_pil_image(x[i+1]) for i in range(3*n_times-1)] # for each image, list each of its channels separately
fname_i = f"{i+1}_multi"
label_i = random.randint(0, 1) # assign random label supposing this is a binary classification problem
img.save(path_to_train / f"{fname_i}.tif", save_all=True, append_images=img_list)
dct_i = {"fname": fname_i, "label": label_i, "is_valid": False, "is_test": False} # build a dataset with no validation nor testing data points
df = pd.concat([df, pd.DataFrame.from_dict(dct_i, orient="index").T])
df.to_csv(path_to_multi / "train.csv", index=False)
2. Data loaders
Now we need to define some tools to be able to deal with multi-channel images in the context of fastai
from PIL import ImageSequence
from functools import partial
from fastai.data.block import TransformBlock
from fastai.torch_core import TensorImage
from torchvision.transforms import ToTensor
def open_image(fn:Path) -> torch.Tensor:
transform = ToTensor()
img = Image.open(fn)
x = torch.cat([transform(channel) for channel in ImageSequence.Iterator(img)], dim=0)
return x
class MultiChannelTensor(TensorImage):
@classmethod
def create(cls, fn: Path) -> None:
return cls(open_image(fn))
def show(self):
pass # can implement custom plotting functionality here for multi-channel (n > 3) images
def __repr__(self):
return (f"MultiChannelTensor: {self.shape}")
def MultiChannelTensorBlock():
return TransformBlock(type_tfms=partial(MultiChannelTensor.create), batch_tfms=None)
We can use the defined MultiChannelTensorBlock to create a data block to be fed into an image dataloader
from fastai.data.transforms import ColSplitter
from fastai.vision.data import CategoryBlock, ColReader, DataBlock, ImageDataLoaders
db = DataBlock(
blocks=(MultiChannelTensorBlock, CategoryBlock),
get_x=ColReader("fname", pref=path_to_train, suff=".tif"),
get_y=ColReader("label"),
splitter=ColSplitter("is_valid"),
)
db.summary(df)
Setting-up type transforms pipelines
Collecting items from fname label is_valid is_test
0 1_multi 0 False False
0 2_multi 1 False False
0 3_multi 1 False False
0 4_multi 0 False False
0 5_multi 0 False False
0 6_multi 0 False False
0 7_multi 0 False False
0 8_multi 0 False False
0 9_multi 0 False False
0 10_multi 0 False False
0 11_multi 1 False False
0 12_multi 0 False False
Found 12 items
2 datasets of sizes 12,0
Setting up Pipeline: ColReader -- {'cols': 'fname', 'pref': Path('/Users/slongobardi/Projects/multi-channel-images/data/multi/train'), 'suff': '.tif', 'label_delim': None} -> partial
Setting up Pipeline: ColReader -- {'cols': 'label', 'pref': '', 'suff': '', 'label_delim': None} -> Categorize -- {'vocab': None, 'sort': True, 'add_na': False}
Building one sample
Pipeline: ColReader -- {'cols': 'fname', 'pref': Path('/Users/slongobardi/Projects/multi-channel-images/data/multi/train'), 'suff': '.tif', 'label_delim': None} -> partial
starting from
fname 1_multi
label 0
is_valid False
is_test False
Name: 0, dtype: object
applying ColReader -- {'cols': 'fname', 'pref': Path('/Users/slongobardi/Projects/multi-channel-images/data/multi/train'), 'suff': '.tif', 'label_delim': None} gives
/Users/slongobardi/Projects/multi-channel-images/data/multi/train/1_multi.tif
applying partial gives
MultiChannelTensor of size 9x816x614
Pipeline: ColReader -- {'cols': 'label', 'pref': '', 'suff': '', 'label_delim': None} -> Categorize -- {'vocab': None, 'sort': True, 'add_na': False}
starting from
fname 1_multi
label 0
is_valid False
is_test False
Name: 0, dtype: object
applying ColReader -- {'cols': 'label', 'pref': '', 'suff': '', 'label_delim': None} gives
0
applying Categorize -- {'vocab': None, 'sort': True, 'add_na': False} gives
TensorCategory(0)
Final sample: (MultiChannelTensor: torch.Size([9, 816, 614]), TensorCategory(0))
Collecting items from fname label is_valid is_test
0 1_multi 0 False False
0 2_multi 1 False False
0 3_multi 1 False False
0 4_multi 0 False False
0 5_multi 0 False False
0 6_multi 0 False False
0 7_multi 0 False False
0 8_multi 0 False False
0 9_multi 0 False False
0 10_multi 0 False False
0 11_multi 1 False False
0 12_multi 0 False False
Found 12 items
2 datasets of sizes 12,0
Setting up Pipeline: ColReader -- {'cols': 'fname', 'pref': Path('/Users/slongobardi/Projects/multi-channel-images/data/multi/train'), 'suff': '.tif', 'label_delim': None} -> partial
Setting up Pipeline: ColReader -- {'cols': 'label', 'pref': '', 'suff': '', 'label_delim': None} -> Categorize -- {'vocab': None, 'sort': True, 'add_na': False}
Setting up after_item: Pipeline: ToTensor
Setting up before_batch: Pipeline:
Setting up after_batch: Pipeline:
Building one batch
Applying item_tfms to the first sample:
Pipeline: ToTensor
starting from
(MultiChannelTensor of size 9x816x614, TensorCategory(0))
applying ToTensor gives
(MultiChannelTensor of size 9x816x614, TensorCategory(0))
Adding the next 3 samples
No before_batch transform to apply
Collating items in a batch
No batch_tfms to apply
Let’s create an image dataloader to iterate over the full dataset described by the dataframe df. Since we only have 12 images in the dataset, we will use a batch size of 4
dls = ImageDataLoaders.from_dblock(
db,
df.loc[(df["is_valid"]==False) & (df["is_test"]==False)],
bs=4,
num_workers=0,
device="cpu",
)
The issues we encounter when attempting transfer learning in fastai are related to inner methods that cannot be called because they will raise an error given they only work with images with up to 3 channels. For example, we cannot make use of the ‘normalize=True’ flag when constructing a vision learner because the associated internal method expects the dataset to be a set of RGB images.
To overcome this issue, we need to manually normalize the dataset by performing an ‘offline’ item trasformation. Let’s compute first some summary statistics of the full dataset
def get_data_stats(dls):
x = []
for xi in next(iter(dls)):
x.append(torch.Tensor(xi[0]))
x = torch.cat(x, dim=0)
mean = x.sum(dim=[0, 2, 3])
std = (x**2).sum(dim=[0, 2, 3])
count = x.shape[0] * x.shape[2] * x.shape[3]
total_mean = mean / count
total_var = std / count - total_mean**2
total_std = torch.sqrt(total_var)
return total_mean, total_std
mean, std = get_data_stats(dls)
Now we can normalise the dataset using an ad hoc item transformation based on the computed statistics
from fastcore.transform import Transform
from torchvision.transforms import Normalize
class MultiChannelNormalize(Transform):
def __init__(self, mean, std, device="cpu"):
self.device = device
self.mean = mean
self.std = std
def encodes(self, x: MultiChannelTensor):
t = Normalize(mean=self.mean, std=self.std)
return t(x)
db = DataBlock(
blocks=(MultiChannelTensorBlock, CategoryBlock),
get_x=ColReader("fname", pref=path_to_train, suff=".tif"),
get_y=ColReader("label"),
splitter=ColSplitter("is_valid"),
item_tfms=MultiChannelNormalize(mean, std), # now we can normalise items
)
dls = ImageDataLoaders.from_dblock(
db,
df.loc[(df["is_valid"]==False) & (df["is_test"]==False)],
bs=4,
num_workers=0,
device="cpu",
)
mean_check, std_check = get_data_stats(dls)
print(mean_check) # close to a tensor of all 0s
print(std_check) # close to a tensor of all 1s
tensor([-1.0915e-07, 2.5988e-08, -2.0791e-08, -1.0915e-07, 2.5988e-08,
-2.0791e-08, -1.0915e-07, 2.5988e-08, -2.0791e-08])
tensor([1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000])
3. Vision learner
We are ready to define a fastai vision learner based on an example architecture which we know has pretrained weights available (e.g., resnet18). Make sure to switch ‘normalize=False’
from fastai.metrics import error_rate
from fastai.vision.all import vision_learner
from fastai.vision.models import resnet18
from fastai.optimizer import ranger
learn = vision_learner(
dls,
resnet18,
metrics=error_rate,
pretrained=True, # freeze the body (non-trainable)
normalize=False,
opt_func=ranger
)
We were able to initialise our learner using a multi-channel image dataset, yay! Also, we manually normalised the dataset so that we can do transfer learing!!! We are still missing something though: we need to adapt the first layer of the network to accommodate all the additional channels after the first 3. We will cycle the weights of the first 3 channels as described in the introduction
def make_batches(x, bs):
if x <= bs:
return [np.min([x, bs])]
else:
return [bs] + make_batches(x - bs, bs)
def create_new_weights(original_weights, n_channels):
dst = torch.zeros(64, n_channels, 7, 7) # resnet18 specific, hardcoded
start = 0
for i in make_batches(n_channels, 3):
dst[:, start:start+i, :, :] = original_weights[:, :i, :, :]
start += i
return dst
def adapt_first_layer(src_model, n_channels, device):
original_weights = src_model[0][0].weight.clone()
new_weights = create_new_weights(original_weights, n_channels)
new_layer = torch.nn.Conv2d(
n_channels,
64, # resnet18 specific, hardcoded
kernel_size=(7, 7), # same
stride=(2, 2), # same
padding=(3, 3), # same
bias=False,
)
new_layer.weight = torch.nn.Parameter(new_weights)
src_model[0][0] = new_layer
src_model.to(device)
Ok, we are ready to adapt the first layer to accommodate more than 3 channels! Notice that there are some hardcoded values you need to change in case you would want to use a different architecture
n_channels = 9
adapt_first_layer(learn.model, n_channels, "cpu")
learn.summary()
Sequential (Input shape: 4 x 9 x 816 x 614)
============================================================================
Layer (type) Output Shape Param # Trainable
============================================================================
4 x 64 x 408 x 307
Conv2d 28224 True
BatchNorm2d 128 True
ReLU
____________________________________________________________________________
4 x 64 x 204 x 154
MaxPool2d
Conv2d 36864 False
BatchNorm2d 128 True
ReLU
Conv2d 36864 False
BatchNorm2d 128 True
Conv2d 36864 False
BatchNorm2d 128 True
ReLU
Conv2d 36864 False
BatchNorm2d 128 True
____________________________________________________________________________
4 x 128 x 102 x 77
Conv2d 73728 False
BatchNorm2d 256 True
ReLU
Conv2d 147456 False
BatchNorm2d 256 True
Conv2d 8192 False
BatchNorm2d 256 True
Conv2d 147456 False
BatchNorm2d 256 True
ReLU
Conv2d 147456 False
BatchNorm2d 256 True
____________________________________________________________________________
4 x 256 x 51 x 39
Conv2d 294912 False
BatchNorm2d 512 True
ReLU
Conv2d 589824 False
BatchNorm2d 512 True
Conv2d 32768 False
BatchNorm2d 512 True
Conv2d 589824 False
BatchNorm2d 512 True
ReLU
Conv2d 589824 False
BatchNorm2d 512 True
____________________________________________________________________________
4 x 512 x 26 x 20
Conv2d 1179648 False
BatchNorm2d 1024 True
ReLU
Conv2d 2359296 False
BatchNorm2d 1024 True
Conv2d 131072 False
BatchNorm2d 1024 True
Conv2d 2359296 False
BatchNorm2d 1024 True
ReLU
Conv2d 2359296 False
BatchNorm2d 1024 True
____________________________________________________________________________
4 x 512 x 1 x 1
AdaptiveAvgPool2d
AdaptiveMaxPool2d
____________________________________________________________________________
4 x 1024
Flatten
BatchNorm1d 2048 True
Dropout
____________________________________________________________________________
4 x 512
Linear 524288 True
ReLU
BatchNorm1d 1024 True
Dropout
____________________________________________________________________________
4 x 2
Linear 1024 True
____________________________________________________________________________
Total params: 11,723,712
Total trainable params: 566,208
Total non-trainable params: 11,157,504
Optimizer used: <function ranger at 0x2db4a3e20>
Loss function: FlattenedLoss of CrossEntropyLoss()
Model frozen up to parameter group #2
Callbacks:
- TrainEvalCallback
- CastToTensor
- Recorder
- ProgressCallback
Notice that most of the parameters (body) are not targeted for training (Trainable: False) since we are doing transfer learning (‘pretrained=True’).
4. Transfer learning
We are ready to train our model! Of course ‘train’ is a big word, we are solving a mock / fake classification problem here just to showcase we can run a training and transfer learning on multi-channel images. First, we find a suitable learning rate using fastai’s lr_find utility. Then, we fine-tune trainable parameters
from fastai.torch_core import set_seed
set_seed(8, reproducible=True) # enforcing reproducibility
learn.lr_find()
SuggestedLRs(valley=0.0020892962347716093)
n_epochs = 5
lr = 2e-3
learn.fine_tune(n_epochs, lr)
| epoch | train_loss | valid_loss | error_rate | time |
|---|---|---|---|---|
| 0 | 0.733440 | None | None | 00:06 |
| 1 | 0.578637 | None | None | 00:06 |
| 2 | 0.498724 | None | None | 00:06 |
| 3 | 0.576430 | None | None | 00:06 |
| 4 | 0.672077 | None | None | 00:06 |
