Medical Image Captioning on Chest X-Rays | by Ashish Thomas Chempolil | Jan, 2021
These are some of the topics the reader should be familiar with to fully understand the modelling part:
For the modelling part, I have created three models. Each one uses a similar CheXNet-based encoder, but the baseline and the attention-based models use entirely different decoder architectures.
The encoder takes the two images and converts them into backbone features that are passed to the decoder. For the encoder, I will use the CheXNet model.
The CheXNet model is a 121-layer DenseNet (DenseNet-121) trained on over 100,000 chest X-ray images (the ChestX-ray14 dataset) to classify 14 diseases. We can load its pretrained weights, ignore the top (classification) layer and pass our images through the rest of the network.
Next, I create an Image_encoder layer that wraps the CheXNet model described above. The CheXNet model's trainable flag is set to False, so its pretrained weights stay frozen during training.
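To make "backbone features" concrete, here is a small pure-Python sketch that tracks the output shape of a standard DenseNet-121 (the architecture CheXNet uses) once the top layer is dropped. It assumes 224×224 inputs and the standard configuration (growth rate 32, dense blocks of 6, 12, 24 and 16 layers, 0.5 compression in the transition layers). The helper name densenet121_feature_shape is hypothetical and not part of any library.

```python
def densenet121_feature_shape(size, growth_rate=32, blocks=(6, 12, 24, 16)):
    """Hypothetical helper: (H, W, C) of DenseNet-121 features without the top layer."""
    size = (size + 2 * 3 - 7) // 2 + 1   # stem: zero-pad 3, 7x7 conv, stride 2
    size = (size + 2 * 1 - 3) // 2 + 1   # stem: zero-pad 1, 3x3 max pool, stride 2
    channels = 64                        # filters produced by the stem convolution
    for i, num_layers in enumerate(blocks):
        channels += num_layers * growth_rate   # each dense layer appends growth_rate maps
        if i < len(blocks) - 1:                # transition layers sit between blocks
            channels //= 2                     # 1x1 conv with 0.5 compression
            size //= 2                         # 2x2 average pool, stride 2
    return size, size, channels

h, w, c = densenet121_feature_shape(224)
print((h, w, c))      # (7, 7, 1024)
print((h * w, c))     # (49, 1024): one feature vector per spatial region
```

Flattening the 7×7 grid gives 49 region vectors of 1024 dimensions each, which is the kind of sequence an attention decoder can attend over. In Keras, freezing the pretrained weights amounts to setting the model's trainable attribute to False.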
Simple Encoder Decoder Model
This model will be our baseline. Here I will build a simple implementation of an image captioning model, with the architecture described below.
I pass the two images through the Image_encoder layer, concatenate the two outputs and pass the result through a Dense layer (Image_dense). The padded, tokenized captions are passed through an embedding layer whose initial weights are pretrained GloVe vectors (300 dimensions); this layer is left trainable. The embeddings then go through an LSTM whose initial state is taken from the output of the Image_dense layer. The LSTM output and the Image_dense output are then added and passed through the output_dense layer, which has as many units as the vocabulary size and a softmax activation.
For training, I chose the Adam optimizer and a custom loss based on sparse categorical cross-entropy that only counts the losses of the actual words in the true captions, so the padding positions are ignored.
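The baseline's data flow can be sketched in NumPy as a forward pass for a single example. This is an illustration under assumptions, not the article's code: each Image_encoder output is assumed to be pooled to a 1024-dimensional vector, Image_dense is assumed to use ReLU, both the LSTM's hidden and cell states are initialised from Image_dense, and all sizes and random weights are hypothetical.

```python
import numpy as np

def sigmoid(x):
    return 1.0 / (1.0 + np.exp(-x))

def softmax(x):
    e = np.exp(x - x.max(axis=-1, keepdims=True))
    return e / e.sum(axis=-1, keepdims=True)

def lstm_step(x, h, c, W, U, b):
    # Input, forget, candidate and output pre-activations stacked in one matmul.
    z = x @ W + h @ U + b
    i, f, g, o = np.split(z, 4, axis=-1)
    c = sigmoid(f) * c + sigmoid(i) * np.tanh(g)
    h = sigmoid(o) * np.tanh(c)
    return h, c

def baseline_decoder(img_feat_1, img_feat_2, token_ids, p):
    # Concatenate both image features and project them (Image_dense, ReLU assumed).
    img = np.concatenate([img_feat_1, img_feat_2], axis=-1)
    img_dense = np.maximum(0.0, img @ p["W_img"] + p["b_img"])
    emb = p["E"][token_ids]                  # (T, emb_dim); GloVe-initialised in the article
    h, c = img_dense, img_dense              # LSTM initial state taken from Image_dense
    outputs = []
    for x in emb:
        h, c = lstm_step(x, h, c, p["W"], p["U"], p["b"])
        merged = h + img_dense               # add LSTM output and image features
        outputs.append(softmax(merged @ p["W_out"] + p["b_out"]))  # output_dense
    return np.stack(outputs)                 # (T, vocab_size)

rng = np.random.default_rng(0)
feat_dim, units, emb_dim, vocab = 1024, 256, 300, 1000   # hypothetical sizes
p = {
    "W_img": rng.normal(0, 0.01, (2 * feat_dim, units)), "b_img": np.zeros(units),
    "E": rng.normal(0, 0.1, (vocab, emb_dim)),
    "W": rng.normal(0, 0.05, (emb_dim, 4 * units)),
    "U": rng.normal(0, 0.05, (units, 4 * units)),
    "b": np.zeros(4 * units),
    "W_out": rng.normal(0, 0.05, (units, vocab)), "b_out": np.zeros(vocab),
}
probs = baseline_decoder(rng.normal(size=feat_dim), rng.normal(size=feat_dim),
                         np.array([1, 5, 42, 7]), p)
print(probs.shape)  # (4, 1000)
```

Because Image_dense is both the initial state and a term added at every step, the image information reaches the output layer directly at each caption position.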
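A minimal NumPy sketch of such a masked loss, assuming padding uses token id 0 and that the decoder outputs softmax probabilities. Positions whose true token is padding are multiplied by zero, so only the words of the true caption contribute, and the sum is averaged over the real words (one common choice). In TensorFlow, the per-token losses can be obtained with tf.keras.losses.SparseCategoricalCrossentropy(reduction="none") and masked the same way.

```python
import numpy as np

def masked_sparse_ce(y_true, y_prob, pad_id=0, eps=1e-12):
    """Sparse categorical cross-entropy over real (non-padding) tokens only.

    y_true: (batch, T) integer token ids, padded with pad_id (assumed 0).
    y_prob: (batch, T, vocab) softmax probabilities from the decoder.
    """
    # Probability the model gave to the correct token at each position.
    picked = np.take_along_axis(y_prob, y_true[..., None], axis=-1)[..., 0]
    loss = -np.log(picked + eps)
    mask = (y_true != pad_id).astype(loss.dtype)   # 1 for caption words, 0 for padding
    return float((loss * mask).sum() / mask.sum())

y_true = np.array([[2, 1, 0]])                      # last position is padding
y_prob = np.array([[[0.25, 0.25, 0.5],
                    [0.5, 0.25, 0.25],
                    [1 / 3, 1 / 3, 1 / 3]]])
print(masked_sparse_ce(y_true, y_prob))             # ≈ 1.0397, i.e. (ln 2 + ln 4) / 2
```

Whatever the model predicts at the padded position, the loss stays the same, which is the point of the mask.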
Attention Model
For this model, I will use global (Luong) attention with the concat scoring function.
For the decoder, I have created a one_step_decoder layer that takes decoder_input, encoder_output and the previous state. decoder_input is a single token index. It is passed through the embedding layer, and the embedding output and encoder_output are then passed through the attention layer, which produces the context vector. The context vector is then passed through the RNN (a GRU here), whose initial state is the state from the previous decoder step.
The decoder model stores the outputs of all steps in a tf.TensorArray and returns it. I train the RNN with teacher forcing: instead of feeding back the RNN's output from the previous step, I pass the next token of the original caption.
Here, the initial state s1 is all zeros, and each input i is the ground-truth token after it has been passed through the attention layer together with the encoder output.
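The concat score from Luong et al.'s global attention is score(q, h_s) = v_a^T tanh(W_a [q; h_s]): the query is concatenated with each encoder vector, projected, squashed with tanh and reduced to a scalar. The NumPy sketch below implements that score, normalises the scores with a softmax and returns the context vector. The shapes and sizes are illustrative assumptions.

```python
import numpy as np

def softmax(x):
    e = np.exp(x - x.max())
    return e / e.sum()

def concat_attention(query, enc_out, W_a, v_a):
    """Global attention with the concat score.

    query:   (d_q,)    decoder-side query vector
    enc_out: (S, d_e)  one encoder vector per image region
    W_a:     (d_q + d_e, d_att), v_a: (d_att,)
    """
    S = enc_out.shape[0]
    # Pair the query with every encoder position: [q; h_s] for s = 1..S.
    pairs = np.concatenate([np.tile(query, (S, 1)), enc_out], axis=1)
    scores = np.tanh(pairs @ W_a) @ v_a          # (S,) one score per region
    weights = softmax(scores)                     # attention weights sum to 1
    context = weights @ enc_out                   # (d_e,) weighted sum of encoder vectors
    return context, weights

rng = np.random.default_rng(0)
enc_out = rng.normal(size=(98, 1024))             # hypothetical: 98 regions of 1024-d
query = rng.normal(size=300)
W_a = rng.normal(0, 0.01, (300 + 1024, 512))
v_a = rng.normal(0, 0.01, 512)
context, weights = concat_attention(query, enc_out, W_a, v_a)
print(context.shape, weights.shape)               # (1024,) (98,)
```

Splitting W_a into the block that multiplies q and the block that multiplies h_s gives the equivalent form W_q q + W_e h_s, which is how the concat score is often written with two Dense layers in Keras.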
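One decoding step can be sketched in NumPy: embed the token, attend over encoder_output with the embedding as the query, and feed the context vector to a GRU cell started from the previous state. Assumptions: the output projection to the vocabulary (W_out) is added because the text does not describe one, biases are omitted, the GRU uses the classic update equations rather than the exact variant Keras applies by default, and all sizes are hypothetical.

```python
import numpy as np

def sigmoid(x):
    return 1.0 / (1.0 + np.exp(-x))

def softmax(x):
    e = np.exp(x - x.max())
    return e / e.sum()

def concat_attention(query, enc_out, W_a, v_a):
    # Concat score v_a^T tanh(W_a [q; h_s]), softmax, weighted sum.
    pairs = np.concatenate([np.tile(query, (len(enc_out), 1)), enc_out], axis=1)
    weights = softmax(np.tanh(pairs @ W_a) @ v_a)
    return weights @ enc_out

def gru_step(x, h, p):
    z = sigmoid(x @ p["Wz"] + h @ p["Uz"])             # update gate
    r = sigmoid(x @ p["Wr"] + h @ p["Ur"])             # reset gate
    h_tilde = np.tanh(x @ p["Wh"] + (r * h) @ p["Uh"]) # candidate state
    return (1.0 - z) * h + z * h_tilde

def one_step_decoder(decoder_input, encoder_output, state, p):
    emb = p["E"][decoder_input]                        # token id -> embedding
    context = concat_attention(emb, encoder_output, p["W_a"], p["v_a"])
    state = gru_step(context, state, p)                # GRU fed the context vector
    probs = softmax(state @ p["W_out"])                # next-token distribution (assumed head)
    return probs, state

rng = np.random.default_rng(0)
vocab, emb_dim, enc_dim, att_dim, units = 50, 16, 32, 8, 24
shapes = {"E": (vocab, emb_dim), "W_a": (emb_dim + enc_dim, att_dim), "v_a": (att_dim,),
          "Wz": (enc_dim, units), "Uz": (units, units),
          "Wr": (enc_dim, units), "Ur": (units, units),
          "Wh": (enc_dim, units), "Uh": (units, units),
          "W_out": (units, vocab)}
p = {name: rng.normal(0, 0.1, shape) for name, shape in shapes.items()}
encoder_output = rng.normal(size=(10, enc_dim))
probs, state = one_step_decoder(3, encoder_output, np.zeros(units), p)
print(probs.shape, state.shape)  # (50,) (24,)
```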
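The teacher-forcing loop can be sketched independently of the step function. Assumptions: the caption is already tokenized, the token at position t is the input at step t while the token at position t + 1 is its target, the initial state is zeros (the s1 above), and a Python list stands in for tf.TensorArray. toy_step is a hypothetical stand-in for one_step_decoder that records which tokens it receives.

```python
import numpy as np

def decode_with_teacher_forcing(caption_ids, encoder_output, step_fn, units):
    """Run one decoder step per caption position, always feeding ground truth."""
    state = np.zeros(units)               # s1: the initial state is all zeros
    outputs = []                          # stands in for tf.TensorArray
    for t in range(len(caption_ids) - 1):
        # Teacher forcing: the input is the true token at position t, not the
        # token the model predicted at the previous step.
        probs, state = step_fn(caption_ids[t], encoder_output, state)
        outputs.append(probs)             # compared against caption_ids[t + 1]
    return np.stack(outputs)              # (T - 1, vocab_size)

fed = []
def toy_step(token, encoder_output, state):
    # Hypothetical stand-in for one_step_decoder: logs its input token and
    # returns a uniform distribution over a 5-word vocabulary.
    fed.append(token)
    return np.full(5, 0.2), state + 1.0

out = decode_with_teacher_forcing([1, 4, 2, 3], None, toy_step, units=8)
print(out.shape, fed)   # (3, 5) [1, 4, 2]
```

The recorded inputs are exactly the caption minus its last token, regardless of what the step function predicts.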
Final Custom Model
This model uses a custom encoder together with the same decoder as the Attention Model.
For this model's encoder, I extract the backbone features from the CheXNet model, specifically the output of its third-to-last layer; this is the output of the Image_encoder layer. I then pass these features through a global flow and a context flow, a design inspired by a model built for image segmentation. This is explained below.
Global Flow and Context Flow: this architecture is taken from Attention-guided Chained Context Aggregation for Image Segmentation (specifically its Chained Context Aggregation Module, CAM). That module was designed for image segmentation, but I use it here to extract image information. The outputs of the Image_encoder (i.e. CheXNet) are sent to the global flow. The outputs of both CheXNet and the global flow are then concatenated and sent to the context flow. Only one context flow is used, since our dataset is small and a more complicated architecture could overfit. The global flow extracts the global information of the image, while the context flow captures its local features. Finally, the outputs of the global flow and the context flow are summed and, after reshaping, batch normalization and dropout, sent to the decoder.
- Global flow: the global flow takes the output of the Image_encoder layer and applies global average pooling, which gives a tensor of shape (batch_size, 1, 1, no. of filters).
- Batch normalization, ReLU and a 1×1 convolution are then applied, and the result is upsampled back to the same shape as the input.
- Context flow: the outputs of the global flow and the Image_encoder layer are concatenated along the last axis.
- Average pooling is then applied, which shrinks the feature map N times.
- Two 3×3 convolutions are then applied in sequence (no CShuffle is applied).
- Next, a 1×1 convolution followed by ReLU and another 1×1 convolution followed by a sigmoid are applied. The result is multiplied with the output of the context fusion module and then added to the output of the context refinement module; the sum is upsampled to the same size as the input (a Conv2DTranspose layer is used here so that it also has the same number of filters as the input).
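The global flow and context flow steps listed above can be sketched in NumPy for a single feature map of shape (H, W, C). This is a simplified illustration, not the CAM reference implementation: batch normalization is left out, no activation is applied after the 3×3 convolutions (the text does not specify one), the refinement step is assumed to be f * gate + f, and the Conv2DTranspose is assumed to use a kernel size equal to its stride N. H and W must be divisible by N; all sizes and weights are hypothetical.

```python
import numpy as np

def sigmoid(x):
    return 1.0 / (1.0 + np.exp(-x))

def conv1x1(x, w):
    # A 1x1 convolution is a matrix multiply over the channel axis.
    return x @ w

def conv3x3(x, k):
    # 'same' 3x3 convolution, stride 1. x: (H, W, C_in), k: (3, 3, C_in, C_out).
    h, w, _ = x.shape
    xp = np.pad(x, ((1, 1), (1, 1), (0, 0)))
    out = np.zeros((h, w, k.shape[-1]))
    for a in range(3):
        for b in range(3):
            out += xp[a:a + h, b:b + w, :] @ k[a, b]
    return out

def avg_pool(x, n):
    # Non-overlapping n x n average pooling; H and W must be divisible by n.
    h, w, c = x.shape
    return x.reshape(h // n, n, w // n, n, c).mean(axis=(1, 3))

def conv_transpose_nxn(x, k):
    # Transposed conv with kernel size == stride == n: every input pixel
    # expands into its own n x n output block. k: (n, n, C_in, C_out).
    h, w, _ = x.shape
    n, c_out = k.shape[0], k.shape[-1]
    return np.einsum("ijc,abcd->iajbd", x, k).reshape(h * n, w * n, c_out)

def global_flow(x, w):
    # Global average pooling -> ReLU -> 1x1 conv -> upsample (batch norm omitted).
    g = x.mean(axis=(0, 1), keepdims=True)            # (1, 1, C)
    g = conv1x1(np.maximum(g, 0.0), w)
    return np.tile(g, (x.shape[0], x.shape[1], 1))    # back to (H, W, C_out)

def context_flow(x, g, p, n):
    f = np.concatenate([x, g], axis=-1)               # encoder + global flow, last axis
    f = avg_pool(f, n)                                # shrink the feature map n times
    f = conv3x3(conv3x3(f, p["K1"]), p["K2"])         # two 3x3 convs, no CShuffle
    gate = sigmoid(conv1x1(np.maximum(conv1x1(f, p["A1"]), 0.0), p["A2"]))
    f = f * gate + f                                  # reweight, then add (assumed form)
    return conv_transpose_nxn(f, p["T"])              # upsample back to (H, W, C)

rng = np.random.default_rng(0)
H, W, C, n = 8, 8, 16, 2                              # hypothetical sizes
x = rng.normal(size=(H, W, C))                        # stand-in for Image_encoder output
p = {"K1": rng.normal(0, 0.1, (3, 3, 2 * C, C)), "K2": rng.normal(0, 0.1, (3, 3, C, C)),
     "A1": rng.normal(0, 0.1, (C, C // 4)), "A2": rng.normal(0, 0.1, (C // 4, C)),
     "T": rng.normal(0, 0.1, (n, n, C, C))}
g = global_flow(x, rng.normal(0, 0.1, (C, C)))
out = g + context_flow(x, g, p, n)                    # summed global + context output
print(out.shape)  # (8, 8, 16)
```

Because a transposed convolution whose kernel size equals its stride has no overlapping windows, each pooled pixel expands into its own N×N block, which is what the einsum computes.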
For each image, the outputs of the context flow and global flow are summed (add_1 and add_2 are the results for image_1 and image_2 respectively). These two sums are concatenated and sent to a Dense layer that converts the number of filters to 512; after a series of reshaping, batch normalization and dropout steps, the result is sent to the decoder. The decoder architecture is the same as that of the Attention Model.