Artificial Data for Image Classification | by Ayaan Haque | Dec, 2020
A novel method for using GANs and artificial data in classification tasks
In this article, I will review a new method for using GANs, or Generative Adversarial Networks, for semi-supervised classification from the paper “EC-GAN: Low-Sample Classification using Semi-Supervised Algorithms and GANs.” My paper was recently accepted to the 35th AAAI Conference on Artificial Intelligence, to be held in February, and will appear in the abstract program and the proceedings. This article includes a review of the method, its most important results, and a PyTorch tutorial on how to implement a simplified version of the method.
EC-GAN, which stands for External Classifier GAN, is a semi-supervised algorithm that uses artificial data generated by a GAN to improve image classification. Semi-supervised learning has gained interest in recent years because it allows models to learn from limited labeled data. GANs have recently been applied to classification tasks, and these methods often share a single architecture for both classification and discrimination. However, this may require the model to converge to a separate data distribution for each task, which may reduce overall performance. Moreover, restricted, fully-supervised learning has received much less attention. In this setting the dataset is very small and no unlabeled data is available.
EC-GAN uses artificial data from GANs and semi-supervised pseudo-labeling to effectively increase the size of datasets and improve classification. Importantly, EC-GAN attaches a GAN’s generator to a classifier, hence the name, as opposed to sharing a single architecture for discrimination and classification. The promising results of the algorithm could prompt new related research on how to use artificial data for many different machine learning tasks and applications.
What are Semi-Supervised and Fully-Supervised Learning?
Semi-supervised learning is a machine learning approach in which a model learns from both labeled and unlabeled data, reducing the need for labeled data. Deep learning models require lots of data to perform well because of their sheer size. Traditionally, if a data sample lacks a corresponding label, a model cannot learn from it. Semi-supervised learning has therefore grown as an alternative, because many tasks have unlabeled data available. Many different methods have been developed in recent research.
A more severe scenario involves tasks where even unlabeled data is unavailable and the dataset contains only a small amount of fully labeled data. This domain is known as restricted, fully-supervised learning. Few recent methods address this setting, because most formulate the problem so that some amount of unlabeled data is available to learn from. In these scenarios, any increase in the size of the dataset can be beneficial, even if the added data is unlabeled or artificial. As such, the EC-GAN method attempts to use a Generative Adversarial Network (Goodfellow et al. 2014) to address this problem.
What are Generative Adversarial Networks?
A GAN trains two neural networks against each other. A generative network (the generator) attempts to generate images resembling real training samples by replicating the data’s distribution. Simultaneously, a discriminative network (the discriminator) predicts the probability that a given image comes from the real training set rather than from the generator. The two models compete, and as a result the generator eventually produces images resembling real training samples. During training, the generator is updated based on the discriminator’s predictions to create better images, and the discriminator improves at classifying images as real or fake. The goal is for the two networks to reach an equilibrium. At that point the generator creates almost perfect images, and the discriminator can do no better than chance, assigning a probability of 0.5 to every image. To learn more about the GAN loss objective, see Goodfellow et al. (2014).
Many existing methods using GANs for semi-supervised learning use a single network with separate classification and discrimination branches (Salimans et al. 2016). A traditional classifier assigns each input to a class, and its output is a probability distribution over K classes. In the shared design, the network has individual layers at its end for each task. The discriminator head and the classifier head each independently update the base network of shared parameters. This means the network attempts to minimize two separate losses with the same parameters. That could be a concern, since the two objectives compete for the same weights. Multi-task learning can be beneficial in certain scenarios. For classification and discrimination specifically, however, the features learned for each task may not be similar enough to warrant a shared, multi-task architecture. This motivates a method that separates the two tasks into individual networks while still training them in a mutually beneficial relationship.
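The 50% equilibrium can be checked with a small calculation in plain Python. This sketch assumes the binary cross-entropy losses from the method section: real images are labeled 1, generated images are labeled 0, and the generator is trained toward 1. Under that assumption, a discriminator that outputs 0.5 for every image has a loss of 2 ln 2, and the generator's loss is ln 2.

```python
import math

def bce(p, target):
    """Binary cross-entropy for one predicted probability p and a 0/1 target."""
    return -(target * math.log(p) + (1 - target) * math.log(1 - p))

# At equilibrium the discriminator outputs 0.5 for every image, real or generated.
d_real = 0.5  # D(x): predicted probability that a real image is real
d_fake = 0.5  # D(G(z)): predicted probability that a generated image is real

disc_loss = bce(d_real, 1) + bce(d_fake, 0)  # real labeled 1, generated labeled 0
gen_loss = bce(d_fake, 1)                    # generator pushes D(G(z)) toward 1

print(f"discriminator loss: {disc_loss:.4f}")  # 1.3863 (= 2 ln 2)
print(f"generator loss:     {gen_loss:.4f}")   # 0.6931 (= ln 2)
```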
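The sketch below is a generic illustration of the shared-trunk, two-head layout described above. It is not the exact architecture of Salimans et al. (2016). The class name and layer sizes are hypothetical. It shows that both losses backpropagate into the same trunk parameters, which is the coupling EC-GAN avoids by using an external classifier.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

class SharedDiscriminatorClassifier(nn.Module):
    """Hypothetical shared network: one trunk, a real/fake head, and a class head."""

    def __init__(self, num_classes: int = 10):
        super().__init__()
        # Shared feature extractor for 3x32x32 inputs (sizes chosen only for illustration).
        self.trunk = nn.Sequential(
            nn.Conv2d(3, 32, 4, 2, 1), nn.LeakyReLU(0.2),   # -> 32 x 16 x 16
            nn.Conv2d(32, 64, 4, 2, 1), nn.LeakyReLU(0.2),  # -> 64 x 8 x 8
            nn.Flatten(),
        )
        self.disc_head = nn.Linear(64 * 8 * 8, 1)           # real/fake logit
        self.cls_head = nn.Linear(64 * 8 * 8, num_classes)  # class logits

    def forward(self, x):
        features = self.trunk(x)
        return self.disc_head(features).squeeze(1), self.cls_head(features)

net = SharedDiscriminatorClassifier()
images = torch.randn(4, 3, 32, 32)
labels = torch.randint(0, 10, (4,))
real_targets = torch.ones(4)

disc_logit, class_logits = net(images)
# Two different objectives, one set of shared trunk parameters.
loss = (F.binary_cross_entropy_with_logits(disc_logit, real_targets)
        + F.cross_entropy(class_logits, labels))
loss.backward()
print(net.trunk[0].weight.grad is not None)  # True: both losses reach the shared trunk
```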
EC-GAN addresses restricted, fully-supervised learning by leveraging GANs and artificial data while also separating the tasks of classification and discrimination.
The Method and Implementation
The algorithm consists of three separate models: a generator, a discriminator, and a classifier. At every training iteration, the generator is given random vectors and generates corresponding images. The discriminator is then updated to better distinguish between real and generated samples, and the generator is updated to better fool the discriminator. These are standard GAN training procedures.
The losses for the discriminator and generator can be defined by the following:
In the following equations, BCE is binary cross-entropy, D is the discriminator, G is the generator, x is real, labeled data, and z is a random vector.
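The equations themselves are not reproduced in this copy. The version below is a reconstruction, not the paper's verbatim notation. It takes the standard form implied by the definitions above and by the training procedure in the PyTorch walkthrough: real images are labeled 1, generated images are labeled 0, and the generator is trained with labels of 1.

```latex
\begin{aligned}
\mathcal{L}_D &= \mathrm{BCE}\big(D(x),\, 1\big) + \mathrm{BCE}\big(D(G(z)),\, 0\big) \\
\mathcal{L}_G &= \mathrm{BCE}\big(D(G(z)),\, 1\big)
\end{aligned}
```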
Simultaneously, a classifier is trained in a standard fashion on available real data and their respective labels. All of the available real data have labels in this method.
We then use generated images as additional classifier inputs during training. This is the semi-supervised portion of our algorithm, because the generated images do not have associated labels. To create labels, we use a pseudo-labeling scheme that assigns each generated image the most likely class according to the current state of the classifier. A generated image and its pseudo-label are retained only if the classifier predicts that class with high confidence, meaning a probability above a certain threshold. The loss on these pseudo-labeled images is multiplied by a hyperparameter λ, which controls the relative importance of generated data compared to real samples.
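Here is a minimal PyTorch sketch of the pseudo-labeling rule just described. The function name `select_pseudo_labels` and the threshold of 0.7 are illustrative assumptions, not values taken from the article.

```python
import torch
import torch.nn.functional as F

def select_pseudo_labels(logits: torch.Tensor, threshold: float):
    """Hypothetical helper returning (keep_mask, pseudo_labels) for classifier logits.

    pseudo_labels is the most likely class for every sample. keep_mask is True only
    where the softmax probability of that class exceeds `threshold`.
    """
    probs = F.softmax(logits, dim=1)
    top_prob, pseudo_labels = probs.max(dim=1)  # most likely class and its probability
    keep_mask = top_prob > threshold
    return keep_mask, pseudo_labels

# Three generated images, three classes: confident, uncertain, confident.
logits = torch.tensor([[8.0, 0.0, 0.0],
                       [0.2, 0.1, 0.0],
                       [0.0, 0.0, 9.0]])
keep, labels = select_pseudo_labels(logits, threshold=0.7)
print(keep.tolist(), labels.tolist())  # [True, False, True] [0, 0, 2]
```

The second image still receives an argmax label. It is simply dropped because the classifier is not confident about it.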
The combined loss of the classifier can be defined by the following equation:
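The equation is not reproduced in this copy. The reconstruction below is consistent with the symbols defined just after it and with the description of the two loss components. The indicator 𝟙[·] is a symbol added here: it keeps only generated images whose highest softmax probability exceeds t.

```latex
\mathcal{L}_C = \mathrm{CE}\big(C(x),\, y\big)
  + \lambda \cdot \mathbb{1}\Big[\max_k \operatorname{softmax}\big(C(G(z))\big)_k > t\Big]
  \cdot \mathrm{CE}\Big(C(G(z)),\, \arg\max_k C(G(z))_k\Big)
```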
In the equation above, x is the real data, y is the corresponding labels, z is a random vector, CE is cross-entropy, λ is the unsupervised loss weight, C is the classifier, and t is the pseudo-labeling threshold.
The first component of the loss is the standard fully-supervised loss, where the cross-entropy is calculated on the labeled data. The second component is the unsupervised loss, where the cross-entropy is computed between the classifier’s predictions on the GAN-generated images and the hypothesized pseudo-labels. These pseudo-labels are produced with the “argmax” function. The threshold is a key component, because without it the model may be negatively impacted by poor, unrealistic GAN generations. If a GAN generation is poor, the classifier will not be able to label it with confidence, so it will be excluded from the loss.
λ is also an important component, as λ controls the importance of the unsupervised loss. We incorporate λ because generated images are only meant to supplement the classifier and should be less significant than real, labeled data when calculating loss.
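The next sketch combines the two components and λ into one classifier loss. The function name and the defaults λ = 0.1 and t = 0.7 are hypothetical. When no generated image clears the threshold, the sketch returns a zero unsupervised term, because a mean cross-entropy over zero samples would be NaN.

```python
import torch
import torch.nn.functional as F

def ec_gan_classifier_loss(real_logits, real_labels, fake_logits, lam=0.1, threshold=0.7):
    """Hypothetical helper: supervised CE plus lam * CE on confident pseudo-labels."""
    supervised = F.cross_entropy(real_logits, real_labels)

    # Pseudo-labels and confidences come from the classifier's own predictions.
    probs = F.softmax(fake_logits.detach(), dim=1)
    top_prob, pseudo_labels = probs.max(dim=1)
    keep = top_prob > threshold
    if keep.any():
        unsupervised = F.cross_entropy(fake_logits[keep], pseudo_labels[keep])
    else:
        # Nothing passed the threshold, so generated data contributes nothing this batch.
        unsupervised = fake_logits.new_zeros(())
    return supervised + lam * unsupervised, supervised, unsupervised

real_logits = torch.zeros(2, 3, requires_grad=True)  # uniform predictions on real data
real_labels = torch.tensor([0, 1])
fake_logits = torch.tensor([[10.0, 0.0, 0.0],        # confident -> kept
                            [0.0, 0.0, 0.0]],        # uniform -> dropped
                           requires_grad=True)
total, sup, unsup = ec_gan_classifier_loss(real_logits, real_labels, fake_logits, lam=0.5)
total.backward()
print(fake_logits.grad)  # second row is all zeros: the dropped image gives no gradient
```

Setting `lam=0` reduces the loss to plain supervised training. Larger values let generated images pull harder on the classifier.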
Code Implementation in PyTorch
Now that the algorithm itself has been described, let’s write some code using PyTorch.
The specific model architectures are neither critical to this method nor unique to it. However, to achieve the best performance, we will use the DC-GAN, or Deep Convolutional GAN (Radford et al. 2015), architecture, which is a deep, convolutional implementation of a standard GAN. The code for the generator and discriminator is shown below.
To simplify, in the following code snippets, the model architectures are coded according to the DC-GAN paper and implementation. To learn more about these specific models, see Radford et al. (2015).
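The embedded code is not reproduced in this copy, so here is a minimal sketch of DC-GAN-style models for 32×32 RGB images. The layer widths (64), the depth, and the weight initialization are assumptions modeled on the common PyTorch DCGAN example. They are not necessarily the author's exact code.

```python
import torch
import torch.nn as nn

nz, ngf, ndf, nc = 100, 64, 64, 3  # latent size, generator/discriminator widths, image channels

class Generator(nn.Module):
    """DC-GAN-style generator: 100x1x1 noise vector -> 3x32x32 image in [-1, 1]."""
    def __init__(self):
        super().__init__()
        self.main = nn.Sequential(
            nn.ConvTranspose2d(nz, ngf * 4, 4, 1, 0, bias=False),       # -> (ngf*4) x 4 x 4
            nn.BatchNorm2d(ngf * 4), nn.ReLU(True),
            nn.ConvTranspose2d(ngf * 4, ngf * 2, 4, 2, 1, bias=False),  # -> (ngf*2) x 8 x 8
            nn.BatchNorm2d(ngf * 2), nn.ReLU(True),
            nn.ConvTranspose2d(ngf * 2, ngf, 4, 2, 1, bias=False),      # -> ngf x 16 x 16
            nn.BatchNorm2d(ngf), nn.ReLU(True),
            nn.ConvTranspose2d(ngf, nc, 4, 2, 1, bias=False),           # -> nc x 32 x 32
            nn.Tanh(),
        )

    def forward(self, z):
        return self.main(z)

class Discriminator(nn.Module):
    """DC-GAN-style discriminator: 3x32x32 image -> probability that it is real."""
    def __init__(self):
        super().__init__()
        self.main = nn.Sequential(
            nn.Conv2d(nc, ndf, 4, 2, 1, bias=False),           # -> ndf x 16 x 16
            nn.LeakyReLU(0.2, inplace=True),
            nn.Conv2d(ndf, ndf * 2, 4, 2, 1, bias=False),      # -> (ndf*2) x 8 x 8
            nn.BatchNorm2d(ndf * 2), nn.LeakyReLU(0.2, inplace=True),
            nn.Conv2d(ndf * 2, ndf * 4, 4, 2, 1, bias=False),  # -> (ndf*4) x 4 x 4
            nn.BatchNorm2d(ndf * 4), nn.LeakyReLU(0.2, inplace=True),
            nn.Conv2d(ndf * 4, 1, 4, 1, 0, bias=False),        # -> 1 x 1 x 1
            nn.Sigmoid(),
        )

    def forward(self, x):
        return self.main(x).view(-1)  # one probability per image

def weights_init(m):
    # Conv weights from N(0, 0.02); BatchNorm scale from N(1, 0.02), bias 0 (PyTorch DCGAN example).
    if isinstance(m, (nn.Conv2d, nn.ConvTranspose2d)):
        nn.init.normal_(m.weight, 0.0, 0.02)
    elif isinstance(m, nn.BatchNorm2d):
        nn.init.normal_(m.weight, 1.0, 0.02)
        nn.init.zeros_(m.bias)

netG = Generator(); netG.apply(weights_init)
netD = Discriminator(); netD.apply(weights_init)
fake = netG(torch.randn(16, nz, 1, 1))
print(fake.shape, netD(fake).shape)  # torch.Size([16, 3, 32, 32]) torch.Size([16])
```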
The third network required in this algorithm is the classifier, and for this example, we will use a ResNet-18. The code is below.
This is a classic ResNet-18 implementation in PyTorch, adapted for 32×32 inputs just like the GAN models. However, feel free to use whatever classifier architecture you prefer, as long as it accepts inputs of the same size as the images the generator produces.
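The classifier code is also missing from this copy. Below is a minimal sketch of a ResNet-18 adapted to 32×32 inputs. It assumes the common CIFAR-style adaptation: a 3×3 stem convolution, no initial max-pool, and four stages of two basic blocks. Ten output classes are assumed, as in SVHN.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

class BasicBlock(nn.Module):
    """Two 3x3 convolutions with a residual (shortcut) connection."""
    def __init__(self, in_planes, planes, stride=1):
        super().__init__()
        self.conv1 = nn.Conv2d(in_planes, planes, 3, stride, 1, bias=False)
        self.bn1 = nn.BatchNorm2d(planes)
        self.conv2 = nn.Conv2d(planes, planes, 3, 1, 1, bias=False)
        self.bn2 = nn.BatchNorm2d(planes)
        self.shortcut = nn.Sequential()  # identity when shapes already match
        if stride != 1 or in_planes != planes:
            # 1x1 projection so the shortcut matches the block's output shape.
            self.shortcut = nn.Sequential(
                nn.Conv2d(in_planes, planes, 1, stride, bias=False),
                nn.BatchNorm2d(planes),
            )

    def forward(self, x):
        out = F.relu(self.bn1(self.conv1(x)))
        out = self.bn2(self.conv2(out))
        return F.relu(out + self.shortcut(x))

class ResNet18(nn.Module):
    """ResNet-18 for 32x32 inputs: 3x3 stem, no max-pool, four stages of two blocks."""
    def __init__(self, num_classes=10):
        super().__init__()
        self.in_planes = 64
        self.conv1 = nn.Conv2d(3, 64, 3, 1, 1, bias=False)
        self.bn1 = nn.BatchNorm2d(64)
        self.layer1 = self._make_layer(64, stride=1)   # 32 x 32
        self.layer2 = self._make_layer(128, stride=2)  # 16 x 16
        self.layer3 = self._make_layer(256, stride=2)  # 8 x 8
        self.layer4 = self._make_layer(512, stride=2)  # 4 x 4
        self.linear = nn.Linear(512, num_classes)

    def _make_layer(self, planes, stride):
        layers = []
        for s in (stride, 1):  # only the first block of a stage may downsample
            layers.append(BasicBlock(self.in_planes, planes, s))
            self.in_planes = planes
        return nn.Sequential(*layers)

    def forward(self, x):
        out = F.relu(self.bn1(self.conv1(x)))
        out = self.layer4(self.layer3(self.layer2(self.layer1(out))))
        out = F.adaptive_avg_pool2d(out, 1).flatten(1)
        return self.linear(out)  # raw logits, suitable for cross-entropy loss

netC = ResNet18(num_classes=10)
logits = netC(torch.randn(4, 3, 32, 32))
print(logits.shape)  # torch.Size([4, 10])
```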
Now, let’s move on to the algorithm itself. The implementation of the algorithm can be done quite simply and effectively. The following snippet shows the steps in each minibatch to execute the algorithm in a simplified form.
To begin training, we load the images and labels from the available dataset. We also create labels for the GAN, which are simply tensors of 0s and 1s used to train the discriminator. We then create random vectors (torch.randn) of size 100x1x1 and pass them through the generator (netG) to create fake images.
The discriminator (netD) is first trained on the real images with labels of 1. It is then trained on the fake images created by the generator (fakeImageBatch) with labels of 0. Each time, the gradients are cleared (optD.zero_grad), the loss is calculated and backpropagated, and the optimizer takes a step to update the discriminator (optD.step). Next, the generator is trained on the discriminator’s predictions for the fake images. Its loss is calculated using labels of 1, and the generator is updated with optG.step. The classifier is then trained on the available real images in a conventional fashion using cross-entropy loss. This loss is labeled realClassifierLoss, and the classifier is updated with it (optC.step).
Now, the classifier (netC) is given the GAN-generated images (fakeImageBatch) and produces classification predictions on them. These predictions are converted into hard pseudo-labels (torch.argmax), creating a tensor of labels. The predictions are also passed through a softmax activation function to determine the predicted probability of each class for each image (probs). Each softmax distribution is then examined to find the probability of the most likely class, which is compared to the given threshold. The indices of images whose probability is above the threshold are kept (toKeep). Finally, the loss (fakeClassifierLoss) is computed only on the pseudo-labels (predictedLabels) that passed the threshold, and it is backpropagated (.backward()) to update the classifier.
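The original snippet is not reproduced in this copy. Below is a minimal sketch of one minibatch, reconstructed from the walkthrough above and reusing its variable names. Several pieces are assumptions for illustration: the tiny stand-in networks, the random stand-in batch, the optimizer settings, the threshold of 0.7, and λ = 0.1 (named `lam` here).

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

torch.manual_seed(0)
nz, num_classes, batch_size = 100, 10, 8
threshold, lam = 0.7, 0.1  # hypothetical pseudo-labeling threshold t and loss weight λ

# Tiny stand-in networks so the step runs on its own; swap in the DC-GAN and ResNet-18 models.
netG = nn.Sequential(nn.ConvTranspose2d(nz, 3, 32, 1, 0), nn.Tanh())           # 100x1x1 -> 3x32x32
netD = nn.Sequential(nn.Conv2d(3, 1, 32, 1, 0), nn.Flatten(0), nn.Sigmoid())  # 3x32x32 -> P(real)
netC = nn.Sequential(nn.Flatten(), nn.Linear(3 * 32 * 32, num_classes))       # 3x32x32 -> logits

optD = torch.optim.Adam(netD.parameters(), lr=2e-4, betas=(0.5, 0.999))
optG = torch.optim.Adam(netG.parameters(), lr=2e-4, betas=(0.5, 0.999))
optC = torch.optim.Adam(netC.parameters(), lr=1e-3)
bce = nn.BCELoss()

# Stand-in for one labeled minibatch (e.g. from SVHN), scaled to [-1, 1].
inputs = torch.rand(batch_size, 3, 32, 32) * 2 - 1
labels = torch.randint(0, num_classes, (batch_size,))
realLabel = torch.ones(batch_size)   # GAN target for real images
fakeLabel = torch.zeros(batch_size)  # GAN target for generated images

# 1) Discriminator: real images -> 1, generated images -> 0.
optD.zero_grad()
fakeImageBatch = netG(torch.randn(batch_size, nz, 1, 1))
lossD = bce(netD(inputs), realLabel) + bce(netD(fakeImageBatch.detach()), fakeLabel)
lossD.backward()
optD.step()

# 2) Generator: push the discriminator's predictions on generated images toward 1.
optG.zero_grad()
lossG = bce(netD(fakeImageBatch), realLabel)
lossG.backward()
optG.step()

# 3) Classifier on real labeled data.
optC.zero_grad()
realClassifierLoss = F.cross_entropy(netC(inputs), labels)
realClassifierLoss.backward()
optC.step()

# 4) Classifier on generated images, keeping only confident pseudo-labels.
optC.zero_grad()
predictionsFake = netC(fakeImageBatch.detach())         # detach: this step must not update netG
predictedLabels = torch.argmax(predictionsFake, dim=1)  # hard pseudo-labels
probs = F.softmax(predictionsFake, dim=1)
mostLikelyProbs = probs.gather(1, predictedLabels.unsqueeze(1)).squeeze(1)
toKeep = mostLikelyProbs > threshold
fakeClassifierLoss = torch.zeros(())
if toKeep.any():  # a mean over zero kept samples would be NaN, so skip empty selections
    fakeClassifierLoss = lam * F.cross_entropy(predictionsFake[toKeep], predictedLabels[toKeep])
    fakeClassifierLoss.backward()
    optC.step()
```

This sketch takes a separate optimizer step for each classifier loss. Summing realClassifierLoss and fakeClassifierLoss before a single optC.step() would follow the combined loss equation more literally. Wrapping these four stages in a loop over a real data loader turns the step into a full training loop.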
This is the simple implementation of the algorithm, and it is now clear how the classifier works in tandem with the GAN to use the artificial images for classification.
Regarding the most important results, the classification accuracy of EC-GAN was compared to that of a bare classifier and of the shared discriminator method discussed earlier. Since EC-GAN focuses on separating classification and discrimination, a direct comparison of the two methods is important. The following table contains the results of both methods at varying labeled dataset sizes. Training and testing were done on the SVHN dataset, a common academic benchmark for classification and GAN algorithms.
On small datasets, the external classifier method performs on par with, and occasionally better than, a shared architecture. This could be because each network learns its own task with its own parameters, which can allow both networks to reach their potential. In a shared architecture, by contrast, the network updates for two tasks simultaneously. Moreover, the shared architecture does not, by definition, increase the size of the dataset, since it does not train the classifier on GAN images. This empirical analysis suggests that two factors may be key to the algorithm’s strong performance: separating classification from discrimination, and supplementing classification with generated images.
There were other ablation results and evaluations performed for this algorithm, which will be available with the rest of the paper after the conference in February.
In this article, we reviewed a new generative model that attaches an external classifier to a GAN to improve classification performance on restricted, fully-supervised datasets. The model allows classifiers to leverage GAN image generations to improve classification, while separating the tasks of discrimination and classification. The results show promising potential for real applications to image processing problems, and the implementation in code is intuitive and efficient.
This work is exciting because it reveals ways that artificial data can be used to perform machine learning tasks. With just a small dataset of images, a GAN can significantly increase the effective size of the dataset. Deep learning approaches rely heavily on large amounts of data, so this increase can help many deep learning tasks perform at a higher level. There has been little research on how to use artificial data most effectively, or on how and why it helps. This method and paper show the potential of the approach. I look forward to feedback on this paper at AAAI 2021, so be sure to look out for the conference and the proceedings in February. Thanks for reading.
 Goodfellow, I. J.; Pouget-Abadie, J.; Mirza, M.; Xu, B.; Warde-Farley, D.; Ozair, S.; Courville, A.; and Bengio, Y. 2014. Generative Adversarial Networks.
 Salimans, T.; Goodfellow, I.; Zaremba, W.; Cheung, V.; Radford, A.; and Chen, X. 2016. Improved Techniques For Training GANs. In Advances in neural information processing systems, 2234–2242.
 Radford, A.; Metz, L.; and Chintala, S. 2015. Unsupervised Representation Learning with Deep Convolutional Generative Adversarial Networks.