Battling label distribution shift in a dynamic world | by Av Shrikumar | Nov, 2020
Maximum likelihood with appropriate calibration goes a long way.
by Amr M. Alexandari & Avanti Shrikumar
In this tutorial, we will see how we can use a combination of model calibration and a simple iterative procedure to make our model predictions robust to shifts in the class proportions between when the model is trained and when it is deployed.
Label shift in a real world setting
Say we build a classifier using data gathered in June to predict the probability that a patient has COVID19 based on the severity of their symptoms. At the time we train our classifier, covid positivity was at a comparatively lower rate in the community. Now it’s November, and the rate of covid positivity has increased considerably. Is it still a good idea to use our classifier from June to predict who has covid?
To understand why the classifier from June might underestimate the prevalence of covid, let’s imagine our classifier is built using a single variable capturing symptom severity that we will call the “disease score”. (This argument generalizes to classifiers built using many variables, as well as to classifiers with more than 2 output classes, but the intuition is easiest to see in the single-variable, two-class case.) Here is a toy visualization of what the ground-truth distribution of symptom severity might look like for positive and negative cases, in a situation where 10% of all tested cases in June were positive:
The red line shows the predicted fraction of positives for an ideal classifier that matches the ground truth. This ground-truth probability can be calculated by taking the height of the orange bar at a given disease score and dividing it by the combined height of the blue and orange bars at that disease score.
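The bar-height calculation can be reproduced on simulated data. The sketch below assumes hypothetical toy distributions that are not the ones behind the article's figures: negatives' disease scores drawn from N(0, 1), positives' from N(2, 1), with 10% positives. Histogram counts stand in for the blue and orange bars, and the per-bin ratio orange / (blue + orange) estimates the ideal classifier's output.

```python
import numpy as np

# Hypothetical toy data: 10% positives, negatives ~ N(0, 1), positives ~ N(2, 1).
rng = np.random.default_rng(0)
n_total = 100_000
n_pos = int(0.10 * n_total)
neg_scores = rng.normal(loc=0.0, scale=1.0, size=n_total - n_pos)
pos_scores = rng.normal(loc=2.0, scale=1.0, size=n_pos)

# Shared bins of width 0.25; the counts stand in for the bar heights.
bins = np.linspace(-4.0, 6.0, 41)
blue, _ = np.histogram(neg_scores, bins=bins)    # negatives
orange, _ = np.histogram(pos_scores, bins=bins)  # positives

# Ideal classifier output per bin: orange height / (blue + orange height).
total = blue + orange
frac_pos = np.full(len(total), np.nan)  # NaN marks empty bins
nonempty = total > 0
frac_pos[nonempty] = orange[nonempty] / total[nonempty]

centers = 0.5 * (bins[:-1] + bins[1:])
for center, frac in zip(centers[12:29:4], frac_pos[12:29:4]):
    print(f"disease score {center:+.3f}: fraction positive = {frac:.3f}")
```

In this toy setup the two densities are equal at a score of 1, so the bins near 1 show a fraction of positives close to the 10% base rate. The fraction climbs toward 1 for higher scores.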
Now let us visualize the case where the proportion of positives among the tested population has risen to 70% in November (note: this is an extreme shift, and we show it solely for ease of visualization). The symptoms of covid have not changed between June and November, which means the overall shapes of the blue and orange distributions stay the same. However, the height of the orange bars relative to the blue bars increases to reflect the greater proportion of positives, giving:
The solid red line shows the predictions from the classifier trained in June, while the dashed red line shows the true fraction of positives for data gathered in November. As we can see, the classifier that was ideal for data gathered in June underestimates the probability that a patient has COVID19 when it is deployed in November. This phenomenon is called label shift or prior probability shift. If we knew the labels for the testing data, we could simply train a new classifier to output the class probabilities. In practice, however, we don’t know the labels for the testing data. All we observe is the overall distribution of symptom severity for the covid-positive and covid-negative patients combined, which looks something like this:
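The gap between the solid and dashed red lines can also be computed in closed form under an assumed toy model. The sketch uses hypothetical Gaussian class-conditional densities p(x|y) that stay fixed between June and November. The function names are ours, for illustration only.

```python
import numpy as np

def gaussian_pdf(x, mean, std=1.0):
    return np.exp(-0.5 * ((x - mean) / std) ** 2) / (std * np.sqrt(2.0 * np.pi))

def posterior_positive(x, frac_positive):
    """Bayes-optimal P(positive | disease score x) for a given base rate,
    assuming hypothetical densities: negatives ~ N(0, 1), positives ~ N(2, 1)."""
    pos = frac_positive * gaussian_pdf(x, mean=2.0)
    neg = (1.0 - frac_positive) * gaussian_pdf(x, mean=0.0)
    return pos / (pos + neg)

scores = np.array([0.0, 1.0, 2.0, 3.0])
june_model = posterior_positive(scores, frac_positive=0.10)      # solid red line
november_truth = posterior_positive(scores, frac_positive=0.70)  # dashed red line
for s, old, new in zip(scores, june_model, november_truth):
    print(f"score {s:.1f}: June classifier {old:.3f} vs. November truth {new:.3f}")
```

At every score, the June classifier's output falls below the November truth. At a score of 1, where the two densities are equal, the two outputs are simply the base rates: 0.100 vs. 0.700.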
So, how can we get an updated classifier?
Adapting via Maximum Likelihood
It turns out there is a simple way to adapt our classifier to account for the shift in class proportions. To see how this method works, let us introduce some terminology. The dataset that we train on is called the “source domain”, while the dataset we deploy the model on is called the “target domain”. We will use the following notation:
- y denotes the class, which can be one of “positive for covid” or “negative for covid”
- p(y) denotes the proportion of patients belonging to class y in the source domain (i.e. in June)
- q(y) denotes the proportion of patients belonging to class y in the target domain (i.e. in November)
- p(y|x) denotes the conditional probability in the source domain (i.e. in June) that the patient belongs to class y given symptoms that look like x
- q(y|x) denotes the conditional probability in the target domain (i.e. in November) that the patient belongs to class y given symptoms that look like x
- p(x|y) denotes the conditional probability in the source domain (i.e. in June) of observing certain symptoms, given that you know the patient belongs to class y
- q(x|y) denotes the conditional probability in the target domain (i.e. in November) of observing certain symptoms, given that you know the patient belongs to class y
Let’s begin by taking stock of which of these quantities we do have information about. We can estimate p(y) by simply calculating the average class proportions in our data from the source domain. We also have an estimate of p(y|x) from building our classifier on the source domain. (Remember, the classifier was trained to estimate the probability that a person belongs to a particular class, in this case covid vs. no covid, given their symptoms.) What about p(x|y)? While it is technically possible to estimate p(x|y) from the source-domain data, in practice this can be very hard when x is high-dimensional. Thus, we will not assume that we have access to p(x|y). However, if we assume that the symptoms of covid do not change between June and November, we can assume that p(x|y) = q(x|y). To see why this is useful, let us consider how we can update our classifier if we were given a guess for the value of q(y). From Bayes’ rule, we have:
$$q(y|x) = \frac{q(x|y)\,q(y)}{\sum_{y'} q(x|y')\,q(y')}$$
If we substitute our assumption that q(x|y) = p(x|y), we get:
$$q(y|x) = \frac{p(x|y)\,q(y)}{\sum_{y'} p(x|y')\,q(y')}$$
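As mentioned before, we don’t have an estimate of p(x|y). However, we can get around this by applying Bayes’ rule again, p(x|y) = p(y|x) p(x) / p(y), which gives:
$$q(y|x) = \frac{\frac{p(y|x)\,p(x)}{p(y)}\,q(y)}{\sum_{y'} \frac{p(y'|x)\,p(x)}{p(y')}\,q(y')}$$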
We’re very close! We have estimates of p(y|x) and p(y) from the training data, and we assumed we were given a guess for q(y). The only term we don’t know is p(x). Fortunately, p(x) does not depend on y, so it cancels out between the numerator and the denominator, giving us:
$$q(y|x) = \frac{\frac{q(y)}{p(y)}\,p(y|x)}{\sum_{y'} \frac{q(y')}{p(y')}\,p(y'|x)}$$
Let’s develop some intuition for this formula. In the numerator, we see that the predictions p(y|x) are being re-weighted by the class ratios q(y)/p(y) to produce q(y|x), and the denominator is simply normalizing the re-weighted predictions to sum to 1 across all classes (in order to get a valid probability distribution). If q(y) were 0 for all except one class, then after applying this re-weighting and normalization, q(y|x) would again be 0 for all except that one class (for which it would be 1).
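The re-weighting formula translates directly into a few lines of NumPy. This is a minimal sketch. The function name and the array layout (one row per example, one column per class) are our own conventions, and the code assumes every source prior p(y) is strictly positive.

```python
import numpy as np

def adapt_posteriors(source_probs, source_priors, target_priors):
    """Turn source-domain posteriors p(y|x) into target posteriors q(y|x).

    source_probs:  (n_examples, n_classes) array of p(y|x)
    source_priors: (n_classes,) array of p(y), all entries > 0
    target_priors: (n_classes,) array holding a guess for q(y)
    """
    source_probs = np.asarray(source_probs, dtype=float)
    ratios = np.asarray(target_priors, dtype=float) / np.asarray(source_priors, dtype=float)
    # Numerator: each class's prediction re-weighted by q(y)/p(y).
    reweighted = source_probs * ratios
    # Denominator: renormalize each row so it sums to 1 across classes.
    return reweighted / reweighted.sum(axis=1, keepdims=True)

# Classes ordered as [negative, positive]; June prevalence 10%, guessed November prevalence 70%.
print(adapt_posteriors([[0.9, 0.1], [0.5, 0.5]],
                       source_priors=[0.9, 0.1],
                       target_priors=[0.3, 0.7]))
```

A prediction equal to the June base rate (0.1 positive) maps exactly to the guessed November rate (0.7). A 50/50 prediction becomes 21/22 ≈ 0.955 positive. Setting q(y) to [0, 1] sends every row to [0, 1], matching the one-class case described above.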
We have thus seen that if we are given a guess for q(y), we can get an updated estimate of q(y|x). But how can we obtain a good guess for q(y)? Well, one thing that you (like the authors of this work) might intuitively try to do is to estimate q(y) by (a) making predictions on all the target-domain examples using your original source-domain classifier p(y|x), and (b) averaging the predictions to obtain an initial guess for q(y). When we apply this to our toy covid example from before, we get an estimate of q(y)=34%.
Using this guess for q(y), we can get an estimate of q(y|x) using Bayes’ rule as shown above. But wait! What if we go back and apply our updated q(y|x) to all the target domain examples to re-estimate q(y)? When we do this, we get an estimate of q(y)=53%, which is even closer to the ground truth. We can now repeat this process: average the value of q(y|x) over all the testing set examples to re-estimate q(y), and then use our updated q(y) to re-estimate q(y|x). If we do this for multiple iterations, we eventually converge to the true value q(y)=70%, as shown below:
If this iterative procedure reminds you of Expectation Maximization, you’d be correct! The algorithm we have described above is indeed a valid expectation maximization algorithm, as was first shown in 2002 by Saerens et al.¹ Incidentally, that 2002 paper was incorrectly described in several recent works as requiring access to an estimate of p(x|y) — however, as we have shown, p(x|y) is not used in the EM algorithm.
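The iterative procedure can be sketched as follows. Only the classifier outputs p(y|x) on the target examples and the source priors p(y) enter the loop; p(x|y) never appears. The demo reuses a hypothetical Gaussian toy model (negatives ~ N(0, 1), positives ~ N(2, 1)) with 70% positives in the target domain. The tolerance and iteration cap are arbitrary choices of ours.

```python
import numpy as np

def reweight(source_probs, source_priors, target_priors):
    """q(y|x) proportional to p(y|x) * q(y)/p(y), normalized over classes."""
    weighted = source_probs * (target_priors / source_priors)
    return weighted / weighted.sum(axis=1, keepdims=True)

def em_prior_adaptation(source_probs, source_priors, tol=1e-8, max_iter=1000):
    """Alternate between re-estimating q(y|x) and q(y) until q(y) stops changing."""
    source_probs = np.asarray(source_probs, dtype=float)
    source_priors = np.asarray(source_priors, dtype=float)
    # Start from q(y) = p(y): the first update is then just the average
    # of the unadapted source-classifier predictions.
    target_priors = source_priors.copy()
    for _ in range(max_iter):
        # E-step: posteriors under the current guess for q(y).
        target_probs = reweight(source_probs, source_priors, target_priors)
        # M-step: the new q(y) is the average posterior over target examples.
        new_priors = target_probs.mean(axis=0)
        change = np.max(np.abs(new_priors - target_priors))
        target_priors = new_priors
        if change < tol:
            break
    return target_priors, reweight(source_probs, source_priors, target_priors)

def gaussian_pdf(x, mean):
    return np.exp(-0.5 * (x - mean) ** 2) / np.sqrt(2.0 * np.pi)

# Simulated November data with 70% positives (hypothetical toy model).
rng = np.random.default_rng(0)
n = 50_000
is_positive = rng.random(n) < 0.70
scores = np.where(is_positive, rng.normal(2.0, 1.0, n), rng.normal(0.0, 1.0, n))

# Ideal June classifier p(y|x), built when 10% of cases were positive.
june_priors = np.array([0.9, 0.1])  # [negative, positive]
joint = np.column_stack([june_priors[0] * gaussian_pdf(scores, 0.0),
                         june_priors[1] * gaussian_pdf(scores, 2.0)])
june_probs = joint / joint.sum(axis=1, keepdims=True)

print("first guess for q(positive):", june_probs[:, 1].mean())
est_priors, est_probs = em_prior_adaptation(june_probs, june_priors)
print("EM estimate of q(positive):", est_priors[1])
```

The first guess, the average June prediction, lands well below 70%. The EM estimate ends up close to the simulated 70%, up to sampling noise.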
Calibration with Bias Correction
One issue that can arise when applying the EM updates is that the predictions p(y|x) output by a model might not be calibrated. By calibrated, we simply mean that among all examples for which the model outputs a probability of x% for a particular class, x% of those examples actually belong to that class. Modern neural networks are notoriously prone to miscalibration, as was shown by Guo et al.², and this miscalibration is believed to stem from overfitting to the training set. To visualize this miscalibration, we can bin the predicted probabilities for a given class into intervals of size 0.1 on the x-axis and, for each bin, plot the actual fraction of examples belonging to the class on the y-axis. Any divergence from the x=y line indicates miscalibration. If we do this for a model trained on the CIFAR10 object recognition dataset, here is what we observe for the “cat” class when we make predictions on the testing set:
As you can see, there is a marked deviation from the x=y line, indicating miscalibration. This suggests that the predicted probabilities p(y|x) are not very reliable, and if we were to use them as-is in the EM algorithm, we might get poor results. The approach proposed in Guo et al. to correct for miscalibration is called Temperature Scaling (TS). In TS, the predicted probabilities are adjusted by dividing the logits fed into the softmax by a “temperature” parameter T. The parameter T is chosen to minimize the negative log-likelihood on a held-out validation set. Formally, with TS, the new predicted probabilities become:
$$\hat{p}(y_i|x) = \frac{\exp(z_i(x)/T)}{\sum_j \exp(z_j(x)/T)}$$
Where z(x) is a function that returns the original logit vector from the classifier. If we use TS to adjust the probabilities in the CIFAR10 case above, here is what we observe for the “cat” class:
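Fitting T amounts to a one-dimensional optimization of the validation negative log-likelihood. The sketch below is a from-scratch illustration using SciPy, not the implementation from Guo et al. or from any package. It assumes integer class labels and optimizes log T so that T stays positive. The synthetic check at the end (logits inflated by a factor of 3) is our own construction.

```python
import numpy as np
from scipy.optimize import minimize_scalar
from scipy.special import log_softmax, softmax

def fit_temperature(valid_logits, valid_labels):
    """Return the T minimizing the mean NLL of softmax(z(x) / T) on validation data."""
    logits = np.asarray(valid_logits, dtype=float)
    labels = np.asarray(valid_labels, dtype=int)
    rows = np.arange(len(labels))

    def nll(log_t):
        log_probs = log_softmax(logits / np.exp(log_t), axis=1)
        return -log_probs[rows, labels].mean()

    result = minimize_scalar(nll, bounds=(-5.0, 5.0), method="bounded")
    return float(np.exp(result.x))

def temperature_scale(logits, temperature):
    return softmax(np.asarray(logits, dtype=float) / temperature, axis=1)

# Synthetic check: labels are drawn from softmax(z), but the "model" reports 3 * z,
# so it is overconfident and the fitted temperature should come out near 3.
rng = np.random.default_rng(0)
true_logits = rng.normal(size=(20_000, 3))
true_probs = softmax(true_logits, axis=1)
labels = (rng.random((20_000, 1)) < true_probs.cumsum(axis=1)).argmax(axis=1)
temperature = fit_temperature(3.0 * true_logits, labels)
print("fitted T:", temperature)
calibrated = temperature_scale(3.0 * true_logits, temperature)
```

A single scalar T can soften or sharpen all predictions, but it cannot shift probability mass toward one particular class. That limitation motivates the bias terms discussed next.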
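We still notice some issues. There is a consistent bias: the “true” fraction of cat labels is higher than the predicted probability in almost every bin, so the network seems to be systematically underestimating the probability that an observation belongs to the “cat” class. To fix this, we propose introducing explicit class-specific bias correction terms into the TS equation, which we call Bias-Corrected Temperature Scaling (BCTS). BCTS is defined as follows:
$$\hat{p}(y_i|x) = \frac{\exp(z_i(x)/T + b_i)}{\sum_j \exp(z_j(x)/T + b_j)}$$
where b_i is the bias for class i, optimized together with T to minimize the negative log-likelihood on the held-out validation set.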
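BCTS adds one bias per class to the temperature-scaling fit. The sketch below fits log T and the bias vector b jointly with SciPy's L-BFGS-B on the validation negative log-likelihood. It is our own illustration under the same assumptions as a plain TS fit (integer labels, logits available), not the paper's or any package's code. Adding the same constant to every b_i leaves the softmax unchanged, so only differences between the biases are meaningful.

```python
import numpy as np
from scipy.optimize import minimize
from scipy.special import log_softmax, softmax

def fit_bcts(valid_logits, valid_labels):
    """Fit T and class biases b to minimize the NLL of softmax(z(x)/T + b)."""
    logits = np.asarray(valid_logits, dtype=float)
    labels = np.asarray(valid_labels, dtype=int)
    n, k = logits.shape
    rows = np.arange(n)

    def nll(params):
        # params[0] is log T (keeps T > 0); params[1:] are the biases b_1..b_k.
        log_probs = log_softmax(logits / np.exp(params[0]) + params[1:], axis=1)
        return -log_probs[rows, labels].mean()

    result = minimize(nll, x0=np.zeros(k + 1), method="L-BFGS-B")
    return float(np.exp(result.x[0])), result.x[1:]

def bcts_scale(logits, temperature, biases):
    return softmax(np.asarray(logits, dtype=float) / temperature + biases, axis=1)

# Synthetic check: an overconfident model (logits doubled) that also misses a
# class-specific offset, so some classes are systematically under-predicted.
rng = np.random.default_rng(1)
true_logits = rng.normal(size=(20_000, 3))
true_bias = np.array([0.5, 0.0, -0.5])
true_probs = softmax(true_logits + true_bias, axis=1)
labels = (rng.random((20_000, 1)) < true_probs.cumsum(axis=1)).argmax(axis=1)
temperature, biases = fit_bcts(2.0 * true_logits, labels)
print("fitted T:", temperature, "centered biases:", biases - biases.mean())
```

Here the fitted T should come out near 2, and the centered biases should come out near [0.5, 0, -0.5]. Plain TS could only recover the temperature.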
Applying BCTS, we get:
As we can see, the systematic bias in the predicted probabilities has been greatly reduced. In our paper, we found that BCTS consistently tended to give superior results for label shift adaptation compared to TS.
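The reliability plots in this section come from binning predictions into intervals of width 0.1. The sketch below shows that computation, assuming integer class labels. The helper name is ours. scikit-learn's `sklearn.calibration.calibration_curve` offers similar functionality for binary labels.

```python
import numpy as np

def reliability_curve(pred_probs, labels, class_index, n_bins=10):
    """For one class, return (mean predicted probability, observed fraction of
    that class) for each non-empty bin of width 1/n_bins."""
    pred = np.asarray(pred_probs, dtype=float)[:, class_index]
    in_class = (np.asarray(labels) == class_index).astype(float)
    # Interior edges map p in [k/n_bins, (k+1)/n_bins) to bin k; p = 1.0 lands in the last bin.
    interior_edges = np.linspace(0.0, 1.0, n_bins + 1)[1:-1]
    bin_ids = np.digitize(pred, interior_edges)
    mean_pred, frac_true = [], []
    for b in range(n_bins):
        mask = bin_ids == b
        if mask.any():
            mean_pred.append(pred[mask].mean())
            frac_true.append(in_class[mask].mean())
    return np.array(mean_pred), np.array(frac_true)

# A perfectly calibrated toy predictor: each label is drawn with exactly the predicted probability.
rng = np.random.default_rng(2)
p = rng.random(100_000)
probs = np.column_stack([1.0 - p, p])
labels = (rng.random(100_000) < p).astype(int)
mean_pred, frac_true = reliability_curve(probs, labels, class_index=1)
print(np.round(mean_pred - frac_true, 3))  # deviations from the x=y line
```

For a calibrated predictor like this one, every deviation is close to zero, up to sampling noise. A miscalibrated model shows the systematic gaps seen in the CIFAR10 plots.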
Putting it all together
Let’s summarize our procedure for adapting to label shift. Our algorithm takes the following inputs:
- N observations from the target domain (i.e. the testing set) with unknown labels
- a trained model that outputs p(y|x), constructed using data from the source domain
- a held-out validation set from the source domain with known labels
It then proceeds as follows:
1. Calibrate the predictions p(y|x) on the validation set. We recommend using bias-corrected temperature scaling for this.
2. Compute p(y) by averaging the calibrated p(y|x) on the validation set. Set q(y) = p(y).
3. For the observations from the target domain:
   - Compute q(y|x) ∝ p(y|x) q(y)/p(y), normalized to sum to 1 across classes, on each observation from the target domain
   - Compute q’(y) by averaging q(y|x) over all N target-domain observations
   - If |q’(y) - q(y)| exceeds some tolerance parameter, set q(y) = q’(y) and repeat step 3.
4. Return q(y|x) as the predictions for the target domain.
This simple algorithm, described in Alexandari et al. (ICML 2020)³, achieves state-of-the-art results compared to alternative methods such as BBSE⁴ and RLLS⁵. This was independently verified in a paper by Garg et al.⁶, which states (referring to the maximum likelihood approach as MLLS):
“Across all shifts, MLLS (with BCTS-calibrated classifiers) uniformly dominates BBSE, RLLS, …”
Garg et al. also provided a theoretical analysis of the maximum likelihood approach that confirms the critical importance of good calibration (but also shows that there is more work to be done beyond using BCTS calibration).
To see this method in action with code, we can use the Python abstention package, which implements all of these methods and makes battling label shift as easy as:
from abstention.calibration import TempScaling
from abstention.label_shift import EMImbalanceAdapter

# Instantiate the BCTS calibrator factory
bcts_calibrator_factory = TempScaling(verbose=False,
                                      bias_positions='all')

# Specify that we would like to use Maximum Likelihood (EM) for the
# label shift adaptation, with BCTS for calibration
imbalance_adapter = EMImbalanceAdapter(calibrator_factory=
                                       bcts_calibrator_factory)

# Get the function that will do the label shift adaptation
# (creating this function requires supplying the validation set
# labels/predictions as well as the test-set predictions).
# valid_labels, valid_preds and shifted_test_preds are your own arrays;
# keyword names follow the abstention README, so check them against your version.
imbalance_adapter_func = imbalance_adapter(
    valid_labels=valid_labels,
    tofit_initial_posterior_probs=shifted_test_preds,
    valid_posterior_probs=valid_preds)

# Get the adapted test-set predictions
adapted_shifted_test_preds = imbalance_adapter_func(shifted_test_preds)
A Colab notebook demonstrating this process is available here.