Computer Vision Problems: Building a multi-label classifier using the PASCAL dataset
Book notes: Deep Learning for Coders
Multi-Label Classification
The field of computer vision includes a set of core problems such as image classification, localisation, image segmentation, and object detection. Among these, image classification can be considered the most fundamental, and for good reason: it has revolutionised and propelled technological advances in important fields such as the automotive industry, healthcare, and manufacturing.
For this tutorial, we will be using the PASCAL dataset. Since each of its images can carry multiple labels, it serves as a perfect example for a multi-label classification task.
!pip install -Uqq fastbook
import fastbook
fastbook.setup_book()
from fastai.vision.all import *
# download and untar the PASCAL 2007 dataset
path = untar_data(URLs.PASCAL_2007)
# load the training CSV and look at the first rows
df = pd.read_csv(path/'train.csv')
df.head()
We can see that each dataset example (instance) is represented by a row containing the file name, the labels (chair, car, ...) separated by spaces, and an is_valid boolean column.
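To make the format concrete (a quick check of my own, not from the original notes), we can look at a single row and split its label string:
# inspect one example: file name, space-separated labels, and the is_valid flag
row = df.iloc[0]
print(row['fname'], row['labels'], row['is_valid'])
print(row['labels'].split(' '))  # the labels become a plain Python list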
Constructing a DataBlock
To convert our DataFrame into a DataLoaders object, we first need to go through a Datasets conversion: DataLoaders is built on top of Datasets, so we'll create a Datasets object first and then turn it into a DataLoaders.
We also need to tell the DataBlock constructor how to get the input and the target; otherwise it assumes each row is both at once and ends up duplicating the data examples.
dblock = DataBlock(get_x = lambda r: r["fname"], get_y = lambda r: r["labels"])
dsets = dblock.datasets(df)
dsets.train[0]
In practice we'll need the full path of each image and a list of labels, so we'll use a more verbose version of the code above.
def get_x(r): return path/"train"/r["fname"]
def get_y(r): return r['labels'].split(" ")
dblock = DataBlock(get_x = get_x, get_y = get_y)
dsets = dblock.datasets(df)
dsets.train[0]
Next, we'll need to specify block types so the images are opened and both images and labels are converted to tensors.
dblock = DataBlock(blocks= (ImageBlock, MultiCategoryBlock) , get_x=get_x,
get_y=get_y)
dsets = dblock.datasets(df)
show_image(dsets.train[0][0])
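It can also help (an extra check, not in the original notes) to look at the encoded target: MultiCategoryBlock one-hot encodes the labels, and the vocabulary maps positions back to label names.
# the target is a one-hot encoded tensor with one position per class
idxs = torch.where(dsets.train[0][1] == 1.)[0]
print(dsets.train.vocab[idxs])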
Notice that we haven't used the is_valid column yet, so fastai has been using a random splitter to build the validation set. To use the is_valid column instead, we have to write a splitter function that we'll feed to the DataBlock constructor.
def splitter(df):
    # training rows are those where is_valid is False, validation rows where it is True
    train = df.index[~df['is_valid']].tolist()
    valid = df.index[df['is_valid']].tolist()
    return train, valid
# Note that we apply RandomResizedCrop to ensure
# that every image ends up the same size
dblock = DataBlock(blocks=(ImageBlock, MultiCategoryBlock), splitter=splitter,
get_x=get_x, get_y=get_y,
item_tfms=RandomResizedCrop(128,min_scale=0.35))
dls = dblock.dataloaders(df)
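As a quick sanity check (this step isn't in the original notes), we can display a batch and confirm the shapes produced by the crop:
# show a few training images together with their (possibly multiple) labels
dls.show_batch(max_n=4)
# images are 128x128 after RandomResizedCrop; targets are one-hot vectors
xb, yb = dls.one_batch()
print(xb.shape, yb.shape)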
Binary Cross Entropy
Now we'll need to create a Learner, which is constructed from four main things:
- model
- DataLoaders object
- Optimizer
- Loss function
Here we can't use nll_loss or softmax, because we don't really want our probabilities to sum up to 1 (we may be unconfident about every label, or confident about several at once), and we don't want a loss function that returns the value for only a single label.
On the other hand, binary cross entropy, which is essentially the mnist loss with a log added, will be useful for this task.
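As a minimal hand-written sketch (my own illustration of the idea, not necessarily the book's exact code), binary cross entropy applies a sigmoid to each activation independently and then takes the negative log of the probability assigned to the correct outcome for each label:
def binary_cross_entropy(inputs, targets):
    # squash each raw activation into (0, 1) independently per label
    inputs = inputs.sigmoid()
    # probability of the correct outcome for each label, then mean negative log
    return -torch.where(targets == 1, inputs, 1 - inputs).log().mean()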
For the accuracy metric, we want to apply a sigmoid first to map the activations between 0 and 1, along with a threshold to decide whether we are confident enough about each specific label.
loss_func = nn.BCEWithLogitsLoss()
def accuracy_multi(inp, targ, thresh=0.5, sigmoid=True):
    # apply the sigmoid unless the inputs are already probabilities,
    # threshold each activation, and compare against the boolean targets
    if sigmoid: inp = inp.sigmoid()
    return ((inp>thresh)==targ.bool()).float().mean()
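To see what this metric does (a hypothetical example with made-up numbers), we can feed it some dummy activations and one-hot targets:
# two items, three possible labels: raw activations and one-hot targets
acts = torch.tensor([[ 2.0, -1.0,  0.5],
                     [-0.5,  3.0, -2.0]])
targs = torch.tensor([[1, 0, 1],
                      [0, 1, 0]])
print(accuracy_multi(acts, targs))  # fraction of per-label decisions that match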
#train the model
learn = cnn_learner(dls, resnet50, metrics=partial(accuracy_multi, thresh=0.2))
learn.fine_tune(3, base_lr=3e-3, freeze_epochs=4)
To help us decide which threshold value to use, we can grab the validation predictions and try a few threshold values.
preds, targs = learn.get_preds()
# try a range of threshold values and compute the accuracy for each
xs = torch.linspace(0.05,0.95,29)
accs = [accuracy_multi(preds, targs, thresh=i, sigmoid=False) for i in xs]
plt.plot(xs,accs);
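Once we have the curve, we can go one step further (not in the original notes) and read off the threshold with the highest validation accuracy:
# pick the threshold that maximises accuracy on the validation predictions
accs_t = torch.stack(accs)
best_thresh = xs[accs_t.argmax()]
print(best_thresh, accs_t.max())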