Building a Custom Semantic Segmentation Model | by Sam Watts | Nov, 2020
Using your own data to create a robust computer vision model
Following on from my previous post here, I wanted to see how feasible it would be to reliably detect and segment a Futoshiki puzzle grid from an image without using a clunky capture grid. It works surprisingly well even when trained on a tiny dataset!
Semantic Segmentation is a step up in complexity versus the more common computer vision tasks such as classification and object detection. The goal is to produce a pixel-level prediction for one or more classes. This prediction is referred to as an image ‘mask’. The example here shows 3 overlaid masks for person, sheep, and dog represented by the different foreground colours.
For my task, the setup is somewhat simpler as there is only one class to predict – the puzzle grid. To train the model, we need pairs of images and masks. The images we are using are full colour, so as an array will have the shape (H, W, 3). The masks on the other hand only have a single value per pixel (1 or 0), so will have shape (H, W, 1).
How do we get the image masks I've just talked about? VIA (the VGG Image Annotator) is a great tool for image labelling: it's open source and runs in a browser from a standalone HTML file.
VIA lets you export labels for multiple images as a csv, with the coordinates of each polygon in json format:
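The export is easy to work with in Python. The sketch below assumes VIA 2.x's CSV layout, in which each row describes one region and the `region_shape_attributes` column holds JSON with `all_points_x` and `all_points_y` lists. The helper name `load_via_polygons` and the example rows are hypothetical.

```python
import csv
import io
import json
from collections import defaultdict

def load_via_polygons(csv_file):
    """Group polygon vertices by image filename from a VIA CSV export.

    Assumes VIA 2.x export columns, where each row is one region and
    `region_shape_attributes` holds JSON such as
    {"name": "polygon", "all_points_x": [...], "all_points_y": [...]}.
    """
    polygons = defaultdict(list)
    for row in csv.DictReader(csv_file):
        shape = json.loads(row["region_shape_attributes"])
        if shape.get("name") != "polygon":
            continue  # skip images with no regions ({}) and non-polygon shapes
        points = list(zip(shape["all_points_x"], shape["all_points_y"]))
        polygons[row["filename"]].append(points)
    return dict(polygons)

# Hypothetical two-row export: one labelled image, one with no regions.
example_csv = io.StringIO(
    'filename,file_size,file_attributes,region_count,region_id,'
    'region_shape_attributes,region_attributes\n'
    'puzzle_01.jpg,52311,"{}",1,0,'
    '"{""name"":""polygon"",""all_points_x"":[10,90,90,10],""all_points_y"":[10,10,90,90]}","{}"\n'
    'puzzle_02.jpg,48812,"{}",0,0,"{}","{}"\n'
)
print(load_via_polygons(example_csv))
# {'puzzle_01.jpg': [[(10, 10), (90, 10), (90, 90), (10, 90)]]}
```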
I then wrote a custom PyTorch dataloader, which converts the polygon JSON into a single-channel image mask. The training image and the target mask are then passed on to the model.
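The polygon-to-mask conversion can be sketched with Pillow's `ImageDraw`. This is not the author's dataloader code: `polygons_to_mask` is a hypothetical helper that fills each polygon with 1 on a blank canvas and returns a mask with the (H, W, 1) shape described earlier, which a dataset's `__getitem__` could call before converting the arrays to tensors.

```python
import numpy as np
from PIL import Image, ImageDraw

def polygons_to_mask(polygons, height, width):
    """Rasterise polygons into a single-channel binary mask of shape (H, W, 1).

    `polygons` is a list of [(x, y), ...] vertex lists for one image, e.g.
    parsed from the VIA polygon JSON. Pixels inside any polygon are 1, others 0.
    """
    canvas = Image.new("L", (width, height), 0)  # Pillow sizes are (W, H)
    draw = ImageDraw.Draw(canvas)
    for points in polygons:
        draw.polygon(points, outline=1, fill=1)
    return np.array(canvas, dtype=np.uint8)[..., np.newaxis]

mask = polygons_to_mask([[(2, 2), (7, 2), (7, 7), (2, 7)]], height=10, width=12)
print(mask.shape)  # (10, 12, 1)
```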
In total I labelled 43 images, which I split 75:25 into training and validation sets. I later added an extra 7 images to serve as a test set. This might not seem like much data to be training a large neural network on – but fortunately there are some techniques we can use to get the most out of this small set of images!
As this is a prototype, I wanted to see if the approach would achieve decent results without building the whole thing myself from scratch and potentially wasting a lot of effort. With that in mind, I used the awesome segmentation-models-pytorch library. The power of this library hinges on transfer learning, which means we can avoid having to train the entire network from a standing start.
U-Net consists of a coupled encoder and decoder structure: the encoder builds high-level abstractions of the input image, and the decoder expands these abstractions back out to provide a pixel-level prediction.
The grey arrows signify skip connections between the encoder and decoder pathways. At every upward step of the decoder, the encoder feature maps with the same spatial dimensions are concatenated with the decoder feature maps. The benefits of this are twofold:
- At each level of the decoder, which would otherwise only contain high-level abstract information about the image, the network is able to combine its learning about high- and low-level features, increasing the fidelity of predictions.
- Skip connections give gradients a shorter path back to the encoder during backpropagation, making optimisation easier. This also helps avoid vanishing gradients when training deeper models.
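To make the concatenation concrete, here is a NumPy sketch of a single decoder step in (N, C, H, W) layout using nearest-neighbour upsampling. The channel counts are illustrative, and a real U-Net typically uses learned upsampling and follows the concatenation with convolutions.

```python
import numpy as np

def upsample_nearest(x, factor=2):
    """Nearest-neighbour upsampling of an (N, C, H, W) array."""
    return x.repeat(factor, axis=2).repeat(factor, axis=3)

def decoder_step(decoder_features, encoder_features):
    """One U-Net-style decoder step: upsample, then concatenate the skip connection.

    Channels are stacked, so the output has C_dec + C_enc channels at the
    encoder's spatial resolution.
    """
    up = upsample_nearest(decoder_features)
    if up.shape[2:] != encoder_features.shape[2:]:
        raise ValueError("encoder and upsampled decoder spatial sizes must match")
    return np.concatenate([up, encoder_features], axis=1)

deep = np.random.rand(1, 256, 8, 8)    # coarse, high-level decoder features
skip = np.random.rand(1, 64, 16, 16)   # finer encoder features from the matching level
print(decoder_step(deep, skip).shape)  # (1, 320, 16, 16)
```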
The beauty of this architecture is also that we can use a model pre-trained on a classification task, on a dataset such as ImageNet, as our encoder. Once we remove the final classification layer from this model, it can be connected to a decoder with untrained weights, with skip connections added to reflect the U-Net structure. This saves a lot of compute time, as our pre-trained encoder already has good parameters for building high-level abstractions of images.
segmentation-models-pytorch provides pre-trained weights for a number of different encoder architectures.
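Google AI published their EfficientNet paper in 2019, introducing compound scaling: growing a convolutional network's depth, width and input resolution together with a single coefficient, rather than tuning each dimension separately. Alongside this, the paper proposed a family of models (B0 to B7) of increasing size that achieved state-of-the-art ImageNet accuracy at the time.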
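Summarised from the EfficientNet paper (general background, not specific to this project), compound scaling grows a baseline network's depth $d$, width $w$ and input resolution $r$ together through a single compound coefficient $\phi$:

$$d = \alpha^{\phi}, \quad w = \beta^{\phi}, \quad r = \gamma^{\phi}, \qquad \text{subject to } \alpha \cdot \beta^{2} \cdot \gamma^{2} \approx 2, \quad \alpha, \beta, \gamma \ge 1$$

Because a convolutional network's FLOPs scale roughly with $d \cdot w^{2} \cdot r^{2}$, each unit increase in $\phi$ roughly doubles the compute. The paper's grid search on the B0 baseline gave $\alpha = 1.2$, $\beta = 1.1$ and $\gamma = 1.15$.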
As a trade-off between size and performance, I chose the B3 variant to use in my model.
Specifying these architecture choices with segmentation-models-pytorch is a breeze:
As the training dataset only contains 36 images, overfitting is a serious concern. If we train for multiple epochs over this small dataset, our model may start fitting to its noise, leading to poor performance on out-of-sample examples. Data augmentation can partly mitigate this problem: as each training image and mask pair is read into memory to pass to the model, we apply several layers of random image transformations, as shown below.
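The exact transforms used here appear as images in the original post and aren't reproduced, so this NumPy sketch uses hypothetical stand-ins (a random horizontal flip, a brightness gain and Gaussian noise) to show the key point: geometric changes must be applied to the image and mask together, while colour changes apply to the image only. Libraries such as albumentations provide many more transforms with this joint handling built in.

```python
import numpy as np

def augment(image, mask, rng):
    """Randomly augment an (H, W, 3) uint8 image and its aligned (H, W, 1) mask.

    Hypothetical transforms for illustration only.
    """
    # Geometric: flip image and mask together so they stay aligned.
    if rng.random() < 0.5:
        image, mask = image[:, ::-1], mask[:, ::-1]
    # Photometric: change the image only; the mask labels are unaffected.
    image = image.astype(np.float32) * rng.uniform(0.8, 1.2)  # brightness gain
    image = image + rng.normal(0.0, 5.0, size=image.shape)     # Gaussian noise
    image = np.clip(image, 0, 255).astype(np.uint8)
    return image, np.ascontiguousarray(mask)

rng = np.random.default_rng(0)
image = rng.integers(0, 256, size=(64, 64, 3), dtype=np.uint8)
mask = np.zeros((64, 64, 1), dtype=np.uint8)
mask[:, :32] = 1  # left half is "puzzle"
aug_image, aug_mask = augment(image, mask, rng)
print(aug_image.shape, aug_mask.shape, int(aug_mask.sum()))  # (64, 64, 3) (64, 64, 1) 2048
```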
It's useful to look at an example image to see the individual effects of each of these layers. This is the first image from our training set:
As you can see below, most of the augmentations by themselves only provide a subtle change; however, when stacked together, they add enough novelty to our training data to help stop the model fitting to the noise of the base dataset.
The most commonly used loss function for semantic segmentation is pixel-wise cross-entropy loss, similar to what is used in general classification tasks. Here, we instead use Dice Loss, which was introduced to address the issue of class imbalance in semantic segmentation:
$$\text{Dice Loss} = 1 - \frac{2\,|A \cap B|}{|A| + |B|}$$
where A is the predicted mask and B is the ground-truth mask.
In practice, the intersection term |A ∩ B| is approximated by summing the element-wise product of the prediction and target mask matrices, while |A| and |B| are the sums of each mask:
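A NumPy sketch of this soft Dice loss, taking per-pixel probabilities as `pred`. The small `eps` smoothing term is an assumption added to avoid division by zero; the same operations on PyTorch tensors give a differentiable loss for training.

```python
import numpy as np

def dice_loss(pred, target, eps=1e-6):
    """Soft Dice loss, 1 - 2|A ∩ B| / (|A| + |B|).

    pred:   per-pixel probabilities in [0, 1] (A)
    target: binary ground-truth mask (B)
    eps:    smoothing term (an assumption) so two empty masks give zero loss
    """
    pred = np.asarray(pred, dtype=np.float64)
    target = np.asarray(target, dtype=np.float64)
    intersection = np.sum(pred * target)      # approximates |A ∩ B|
    total = np.sum(pred) + np.sum(target)     # |A| + |B|
    return float(1.0 - (2.0 * intersection + eps) / (total + eps))

target = np.array([[0, 1], [1, 0]])
pred = np.array([[0.1, 0.9], [0.8, 0.2]])
print(round(dice_loss(pred, target), 4))  # 0.15
```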
We also use Intersection-over-Union (IoU) as a scoring metric. This is the area of overlap between the predicted and ground-truth masks divided by the area of their union, a similar concept to Dice Loss.
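IoU can be computed the same way after thresholding the prediction; the 0.5 threshold below is an assumption. For binary masks the two measures are monotonically related, Dice = 2·IoU / (1 + IoU), which is why they tend to move together.

```python
import numpy as np

def iou_score(pred, target, threshold=0.5, eps=1e-6):
    """Intersection over Union, |A ∩ B| / |A ∪ B|, for one predicted mask.

    pred is thresholded at `threshold` (an assumed value) to get binary mask A;
    target is the binary ground-truth mask B.
    """
    pred_bin = np.asarray(pred) >= threshold
    target_bin = np.asarray(target).astype(bool)
    intersection = np.logical_and(pred_bin, target_bin).sum()
    union = np.logical_or(pred_bin, target_bin).sum()
    return float((intersection + eps) / (union + eps))

pred = np.array([[0.9, 0.8], [0.3, 0.1]])
target = np.array([[1, 0], [1, 0]])
print(round(iou_score(pred, target), 4))  # 0.3333
```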
- Trained for 40 epochs, initial learning rate = 5 × 10⁻⁴
- After the 30th epoch, learning rate = 5 × 10⁻⁵
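In PyTorch this schedule maps naturally onto `torch.optim.lr_scheduler.MultiStepLR` with a single milestone. The sketch assumes the Adam optimizer (the post doesn't say which optimizer was used) and uses a dummy parameter in place of the model's.

```python
import torch

# A single dummy parameter stands in for model.parameters().
params = [torch.nn.Parameter(torch.zeros(1))]
optimizer = torch.optim.Adam(params, lr=5e-4)  # optimizer choice is an assumption
# Multiply the learning rate by 0.1 once 30 epochs have completed.
scheduler = torch.optim.lr_scheduler.MultiStepLR(optimizer, milestones=[30], gamma=0.1)

lrs = []
for epoch in range(40):
    lrs.append(optimizer.param_groups[0]["lr"])
    # A real epoch would run forward passes and loss.backward() here.
    optimizer.step()   # no gradients in this demo, so parameters are unchanged
    scheduler.step()   # advance the schedule once per epoch

print(lrs[0], lrs[29], lrs[30])
```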
I tested the trained model on 7 held-out images from my labelled dataset, and the model achieved an IoU score of 0.94 on these images, including some where the puzzle is at an odd angle or takes up only a small part of the image.
I also ran the model over a short video to see the results more visually, and they were pretty good; it even deals well with an object covering the puzzle!
The enhanced version of the code base I discussed in my prior post can be found here. This version of the model shows some slight activation on background features, which is perhaps a sign of some overfitting.
To conclude, this approach showed some pretty impressive results, especially given the tiny amount of training data that was used!
I found two of the recent DeepMind x UCL Deep Learning Lectures to be a great introduction to computer vision concepts: