Deep learning models for COVID-19 chest x-ray classification: Preventing shortcut learning using feature disentanglement

In response to the COVID-19 global pandemic, recent research has proposed creating deep learning based models that use chest radiographs (CXRs) in a variety of clinical tasks to help manage the crisis. However, the size of existing datasets of CXRs from COVID-19+ patients are relatively small, and researchers often pool CXR data from multiple sources, for example, using different x-ray machines in various patient populations under different clinical scenarios. Deep learning models trained on such datasets have been shown to overfit to erroneous features instead of learning pulmonary characteristics in a phenomenon known as shortcut learning. We propose adding feature disentanglement to the training process. This technique forces the models to identify pulmonary features from the images and penalizes them for learning features that can discriminate between the original datasets that the images come from. We find that models trained in this way indeed have better generalization performance on unseen data; in the best case we found that it improved AUC by 0.13 on held out data. We further find that this outperforms masking out non-lung parts of the CXRs and performing histogram equalization, both of which are recently proposed methods for removing biases in CXR datasets.


Introduction
The Coronavirus Disease (COVID)-19 pandemic has exposed many vulnerabilities in public health systems around the world.Artificial intelligence (AI) is poised to address some of these challenges with potential applications including early disease detection, clinical management tools, disease modeling, and vaccine research [1].Screening patients for COVID-19 based on chest radiograph (CXR) imaging is one potential AI application, and deep learning models have been developed to distinguish COVID-19 pneumonia from normal findings and even rule out other types of pneumonia with high accuracy [2][3][4].Due to the increasing practice of reporting results prior to peer review in the setting of a novel pandemic, many additional models have been reported as preprint papers [5][6][7][8][9][10][11].
The main challenge with training these deep learning models has been the shortage of COVID-19 CXR imaging data.Several public datasets are available, but most are small (100-200 patients).A recent review assessed many AI diagnostic and prognostic models for COVID-19 using the prediction model risk of bias assessment tool (PROBAST) and found that all but one of the 51 studies were at high risk for analysis bias, mainly overfitting due to small sample size [12].Larger databases have typically pooled images from different patient populations at multiple research sites and hospitals, or even from different countries [7,13,14].These unbalanced datasets often contain images with source-specific identifiers (anteroposterior versus posteroanterior positioning, imaging device type, image size, etc) that the models can misidentify as relevant features.
The most common approach for developing COVID-19 CXR diagnostic models is to use transfer and representational learning techniques to adapt classification models that were trained on labeled CXR datasets (to classify pre-COVID-19 pulmonary findings), or on unrelated image datasets, to the new task [2,4,15].However, due to the problems mentioned above-small datasets or larger datasets with a latent imbalance-leads to learned models that suffer from a phenomenon summarized as "shortcut learning" [16].Here, the models will identify dataset-specific features rather than pulmonary specific features in order to distinguish between classes, and the reported performance of these models may not be generalizable to other patient populations.Indeed, recent research proves exactly this [17] and further shows that some shortcuts can persist in external datasets.The problem now becomes, "how to train highly-parameterized deep learning models on publicly available data, without learning shortcuts that are present in such data?"A recent study proposed two preprocessing techniques to solve this problem: histogram equalization, to correct for issues such as contrast differences between images, and lung masking to remove potential source-specific information located outside of the lung area [18].
In this paper, we sought to evaluate the use of feature disentanglement [19], a multi-task training approach for deep neural networks, to prevent shortcut learning in the context of classification of COVID-19 CXR images.Models trained with this method are forced to learn features that can identify COVID-19+ CXRs and are penalized for learning features that can identify what sub-population a CXR is from.We further sought to compare feature disentanglement to the previously proposed histogram equalization and and lung masking methods.Finally, we sought to test all three methods on the ultimate goal of improving model generalization performance on unseen data.We have released source code to reproduce the models and results found in this paper: https://github.com/microsoft/xray-feature-disentanglement.

Datasets
We used two CXR datasets: the open-source COVIDx dataset [20] and a dataset received from the China Consortium of Chest CT CC-CCII [21].These are both image classification datasets that share the same three class labels for CXR images: "normal" (Normal)-collected from patients without pneumonia, "common pneumonia" (Pneumonia) CXRs-collected from patients with pneumonias other than from COVID-19, and "novel COVID-19 pneumonia" (COVID-19+).We treated the source dataset of each image as its domain label.Specifically, there are five domain labels: one for each of the sub-datasets in the COVIDx dataset.See Table 1 for a breakdown of the number of class labels of each type per dataset.Below we describe each dataset in more detail.
The COVIDx dataset is a conglomeration of CXR samples from five other datasets: the COVID-CHESTXRAY dataset (also referred to as the COHEN dataset in related work) [22], the Fig 1 COVID-19 CXR dataset [23], the ACTUALMED COVID-19 CXR dataset [24], the Kaggle RSNA Pneumonia Detection Challenge dataset [25], and the SIRM samples from the Kaggle COVID-19 radiography database [26].The samples are divided among three classes: "normal" control images, non COVID-19 pneumonia images, and COVID-19+ images.We used the scripts provided on the COVIDx GitHub repository to create the dataset.We then merge the training and test sets and filter out all but the first sample for each patient.The label counts in Table 1 for this dataset therefore equal the number of patients.
The CC-CCII dataset is a large CT dataset encompassing CT images from retrospective cohorts from the China Consortium of Chest CT Image Investigation (CFC-CCII) [21].This datasets has samples from Sun Yat-sen Memorial Hospital and Third Affiliated Hospital of Sun Yat-sen University, The First Affiliated Hospital of Anhui Medical University, West China Hospital, Guangzhou Medical University First Affiliated Hospital, Nanjing Renmin Hospital, Yichang Central People's Hospital, and Renmin Hospital of Wuhan University.The dataset we use in this paper consists of CXRs taken from a subset of the patients included in the CT dataset.These images have been labeled with the same classes as in COVIDx.
Ethics statement.The COVIDx dataset is publicly available.The CC-CCII xray images were collected by Dr. Jun Shen from Sun Yat-sen Memorial Hospital.All patient information for the CC-CCII dataset are de-identified and anonymized before the authors received the data, and consent was obtained from all participants.The usage of the x-ray images in the CC-CCII dataset for our validation studies have been approved by the Sun Yat-sen Memorial University IRB.The CT images in the CC-CCII are publicly available.IRB approval was obtained before the start of the study.
The feature-extractor is responsible for creating an embedding, z, for a given CXR, and the classifier is responsible for predicting the label from a given embedding.This representation is helpful in transfer learning settings where frozen feature-extractor models (i.e.models with fixed θ) that have been pre-trained on large CXR datasets can be used to reduce overfitting when relatively few labeled samples are available.We fit f(z, ϕ) over a dataset, , where each image contains a class label and a domain label.The class labels are clinically relevant; for example, they describe whether a CXR is from a normal (Normal) patient, a patient with common pneumonia (Pneumonia), or from a patient that may exhibit novel COVID-19 pneumonia (COVID-19+).The domain labels are not clinically relevant.They encode information about how the sample was collected, such as the clinical site and patient cohorts.In our experiments the domain labels encode which dataset a sample originally came from.An ideal classifier would not be sensitive to how a sample was collected but would rely on pulmonary features present in the different types of class labels.However, for the reasons we outline in the Introduction, classifiers can overfit to spurious signals in CXRs leading to poor generalization when applied to new imagery.
We fit f(z; ϕ) in two optimization settings: a straightforward baseline setting and a feature disentanglement setting in which we force the classifier to learn data representations that are not useful in predicting the domain labels, while remaining useful for predicting the class labels: Baseline.In this setting, we learned ϕ � by minimizing a negative log-likelihood loss, Lðŷ class ; y class Þ, between the class predictions, ŷclass , and the class labels, y class , over D in a standard setup: Our hypothesis was that, in this setting, even simple classifiers will overfit to spurious features and exhibit poor generalization performance.Feature disentanglement.In this setting we assumed that we can further decompose f into a feature extractor, f e (z; ϕ e ) = z 0 , and two classification heads, 1 for an overview of this setup.We then follow the methodology proposed in [19] for training f in a multi-task setting to transform the initial embedding z into a compressed form z 0 that is useful for predicting the class label and not useful for predicting the domain label for a given CXR.Formally, we defined the empirical error with a class loss as above, and an additional domain loss: Note that the error is defined as the difference between the summed class loss and summed domain loss.Larger loss values in the second term will contribute to the overall minimization of the total error which influence the model to learn a compressed representation that is not predictive of the domain labels.λ is a hyperparameter that controls the relative influence of the domain loss on the error.We optimized for parameters, ϕ, according to: As described in [19], this can be solved by iterating using gradient-based optimization with an additional "gradient reversal layer" inserted between f e and f d in the structure of the model (see Fig 1).We considered a modification wherein after each training epoch we optimize for � � d , over the whole training set, for the current, fixed values of ϕ e .Put simply, we trained f d (z 0 ; ϕ d ) to convergence at the end of each epoch.As a consequence, the updates to ϕ e in the next epoch of training are done with respect to the domain loss computed at a local minimum for the domain classifier.We found empirically that this helped reduce variance in results between training runs.This modification is cheap when the parameters of the original feature extractor remain frozen and the intermediate representations, z, are cached-which was always the case in our experiments.

Lung masking and histogram equalization preprocessing
Motivated by the same points that we are [18], proposed two preprocessing methods for eliminating signal in CXR imagery that a classifier might overfit to: lung masking and histogram equalization.Lung masking uses a model to remove the non-lung parts of CXR images with the assumption that features from the imagery surrounding the lungs should not be used in predictive models of lung diseases.Histogram equalization is a common technique in computer vision used to improve the contrast of grayscale images.If histogram equalization is applied to each image in a dataset of CXR image, then a classification model will not be able to exploit relative differences in the contrast of images in its decisions.
These methods can be used in combination with our proposed method for feature disentanglement.In our experiments we used the recent SOTA lung VAE model [27] to create lung masks and implementation of histogram equalization from OpenCV [28].

Experimental setup
Our method depends on a feature extractor model, g(x; θ), to create an initial set of embeddings.We considered three different pre-trained CNNs for this role: Torchxrayvision pretrained DenseNet121 We used the DenseNet121 model [29] weights released in the torchxrayvision package [30,31].This model has been trained on CXR imagery, collected before the COVID-19 pandemic, with multi-task pulmonary disease labels from different datasets.We extracted a 1,024 dimensional feature vector for a given CXR by applying global average pooling after the last convolutional layer in the DenseNet.

ImageNet pretrained DenseNet121
We also used the DenseNet121 model [29] weights from the PyTorch torchvision library [32].This model has been trained on ImageNet, and the only difference in implementation from the torchxrayvision version is that the first convolutional layer operates over 3 channel images instead of 1 channel images.We duplicated input values across the 3 input channels in order to run this model on CXR imagery.We extract features from this DenseNet in the same way as above.

COVID-Net
We used the COVID-Net model and the "COVIDNet-CXR Large" pretrained weights from the same repository that proposed the COVIDx dataset [20].We extract a 2,048 dimensional feature vector for an input CXR by applying global average pooling after the last convolutional layer in the COVID-Net architecture.This is in contrast with the methodology from [20], who flatten the representation after the last convolutional layer to get a 460,800 dimensional representation.
In all experiments we reported average and standard deviation results from 5×10-fold cross validation.In the experiments testing whether we can identify domain labels from pretrained model representations, we stratified by the domain label in order to ensure that there are samples from each dataset in training and testing splits.In the experiments where we measure domain generalization, we report results on the out-of-sample CC-CCII data, we used the out-of-sample data as an additional test set in the cross-validation folds used to test in-sample performance.Specifically, we split the training set into 10-folds, trained a model on data from 9 folds, then tested on the held-out fold, as well as the entire out-of-sample dataset.
Our classification problems are all multi-class; we reported an unweighted average of the area under the ROC curve (AUC) calculated individually for each class in a one-versus-rest manner and average per class accuracy (ACC).Both methods are not sensitive to the class imbalance which we observe in both our class and domain tasks.
In the experiments where we measure domain generalization of the Baseline models, i.e. the models trained without feature disentanglement, we fit an additional logistic regression model on the domain labels, from the representations generated by g(x; θ).This additional step happens after fully training the f(g(x); ϕ) and will allow us to measure how much information about the domain labels the encoder model uses.
When using feature disentanglement, we set f e (z) as a fully connected model with the following structure: Dense(256) !BatchNorm !Dense(64) !BatchNorm, where both dense layers are followed by a ReLU nonlinearity.We set both f c (z 0 ) and f d (z 0 ) as logistic regression layers.We trained with the AdamW optimizer using AMSGrad and an initial learning rate of 0.001.We divided the learning rate by a factor of 10 if the validation class loss has stagnated for over 10 epochs and stopped training either the third time this happens, or after 200 total epochs.
We also decayed λ throughout training; we set λ 0 = 10 and use the following update rule evaluated each epoch (where t is the epoch) to calculate λ: This moves λ from λ 0 to 1 over the course of the maximum 200 epochs of training.We did not thoroughly test this schedule against other choices.We simply aimed for the domain loss to guide training in the initial epochs of training, and for the class loss to converge to a local minimum in the later epochs of training.Finally, all of our experiments are fully reproducible by setting the seed of the random number generators in all component software packages.We varied this seed over the 5x restarts in our experiments.

Identifying domain labels from pre-trained model representations
In Table 2 we show that all of our existing models extract representations that contain enough information, even when run on masked/equalized imagery, for discriminating between domain labels throughout the COVIDx dataset (i.e.determining which sub-dataset from COVIDx a sample belongs to), and domain labels throughout only COVID-19+ samples (i.e.determining where a COVID-19+ sample comes from out of the sub-datasets in COVIDx or the CC-CCII dataset).For example, we find that a DenseNet121 pre-trained on ImageNet extracts features from unmasked CXR images that are sufficient to train a logistic regression model that can identify the source dataset among samples from the COVIDx dataset with a held-out average AUC of 0.95 ± 0.02.The same embeddings can be used to train a logistic regression model that can identify the source dataset given COVID-19+ sample, among the constituent datasets of COVIDx and CC-CCII with a held out AUC of 0.97 ± 0.01.While the preprocessing steps help to reduce this performance, the linear models still perform much better than random guessing, suggesting that overfitting to domain signals in the embedded representations is trivial.

Class and domain performance
In Table 3 we show the within-dataset and generalization performance of models, f(z; ϕ), trained on top of the different feature extractors.Similar to Table 2, the models that are trained in a standard way have high performance on the domain task, and correspondingly high performance on the actual task labels.At the same time, the poor performance of these models on the out-of-sample dataset, CC-CCII, shows that they are overfitting to the training dataset.For example, a model trained on top of torchxrayvision unmasked image embeddings performs similar to random guessing-random guessing performance is 0.33 average accuracy, while this model gets 0.34 average accuracy.Lung masking and histogram equalization improves generalization performance in all cases, however we hypothesize that this may throw out signals that are relevant to the task.
On the other hand, we observe that training f(z; ϕ) with feature disentanglement results in better generalization performance with both unmasked or masked/equalized images in all cases.ImageNet embeddings of masked/equalized images and feature disentanglement training results in the best generalization performance on CC-CCII.Interestingly, training with

Table 2. Results showing how well logistic regression classifiers can identify which sub-dataset a CXR is from within the COVIDx dataset, and how well classifiers can identify which dataset a "COVID-19+" CXR is from across both the COVIDx and CC-CCII datasets. We report AUC values as averages of the one-vs-all binary
AUCs between all classes, and accuracy (ACC) as the average accuracy over all classes.We observe that the representations generated by the classifiers, even from masked/ equalized inputs, contain enough information to accurately identify the sources of the imagery in both cases.

COVIDx datasets
All feature disentanglement on masked/equalized images is not always better than training with feature disentanglement on unmasked images.This supports the idea that lung masking and histogram equalization might throw out relevant task signals.In all cases, training with feature disentanglement dramatically reduces the domain signal in learned representations.For example, the domain performance of the model tuned with torchxrayvision embeddings of unmasked images is 0.94 AUC, while using feature disentanglement reduces this to 0.56 AUC.
The increase in performance from training with feature disentanglement is variable across the pre-trained model used to generate the initial representation.We observe the smallest difference in generalization performance with the COVID-Net pre-trained model which was initially trained on unmasked images from the COVIDx dataset-a smaller dataset than either ImageNet or the CXR dataset used by torchxrayvision-and may not function as an effective feature extractor.Further experiments are needed to determine if unfreezing the parameters of the feature extractor model during training with feature disentanglement is beneficial.We specifically avoid this as we have already found that it is possible to overfit to the training set with a small fully connected model on top of pre-trained embeddings, as well as learn representations that are not predictive of the within-dataset domain labels.Performance gains of fine-tuning through the feature extractor would only be potentially visible through increases in generalization performance, i.e. on the CC-CCII dataset.By correlating such design decisions with CC-CCII performance.We would risk manually overfitting to CC-CCII, and want to avoid doing so in this work.

Feature visualization
In Fig 2 we show the UMAP embeddings of the learned feature representations from the models trained on the torchxrayvision embeddings for images from the COVIDx dataset.This Table 3. Results showing within-dataset class performance, within-dataset domain performance, and out-of-sample class performance from training models with the COVIDx dataset.The task performance (task AUC and task accuracy) shows how well classifiers are able to distinguish between "Normal", "Pneumonia", and "COVID-19+" disease labels, while the domain performance (domain AUC) shows how well classifiers are able to distinguish which sub-dataset an image belongs to.We report AUC values as averages of the one-vs-all binary AUCs between all classes, and accuracy (ACC) as the average accuracy over all classes.In all cases class performance (both within-dataset and out-of-sample) is reported from the classifier trained on samples within-dataset, while domain performance is reported from an additional classifier trained to predict domain labels on top of the learned representations, z 0 , as a measure of how much domain information the representation contains.We observe that using feature disentanglement decreases within-dataset domain performance as expected, and increases out-of-sample class performance-i.e.improves generalization performance.shows how the representations learned without feature disentanglement clearly separate the images from the RSNA dataset from the images from the other datasets, despite not being trained to do this.This is because the RSNA dataset contains most of the samples with "normal" and "pneumonia" disease labels, versus the other datasets, which contain "COVID-19+" samples.Therefore, when a model is trained to separate these classes, it finds features that are correlated with which component dataset an image is from.Such biases can be trivial, such as a different in contrast between the dark and light portions of the image, the position/size of the lungs, markings on the CXR, etc.The same distinction between datasets is not clear when feature disentanglement is used because the learned representations are forced to not be predictive of dataset.As a result, the distinction between "COVID-19+" and "pneumonia" labels becomes less clear when feature disentanglement is used, however, as we show in Table 3, this can improve the generalization ability of the model.

Discussion
Accurate automated diagnosis of COVID-19 pneumonia based on CXRs has been a research focus since the start of the pandemic, given its potential use in emergency departments, urgent care, and resource-limited settings.The development of such devices has been limited by the availability and quality of COVID-19+ datasets, which are typically aggregated from various sources to increase the number of COVID-19+ examples.In this study we show that previously proposed techniques, such as isolating the pulmonary region and harmonizing images taken from different x-ray machines, fail to prevent deep learning models from relying on features that are specific to a particular dataset or image source (e.g.scanner machine type, age of patient, external artifacts in the image).These models are likely to underperform in real-world clinical settings when analyzing images that lack the identifying data the models have learned to rely on.Since the COVID-19 outbreak, various researchers have developed automated COVID-19 CXR diagnosis models.Most previous studies have used transfer learning approaches and compared classification performances obtained between several popular CNN architectures, but all have relied on datasets made up of COVID-19+ CXRs sourced from around the web [12].For COVID-19 negative cases, data are typically sampled from other open CXR datasets.

Fig 2 .
Fig 2. UMAP projections of features learned from models trained with and without feature disentanglement on unmasked imagery.Each point represents a CXR from the COVIDx dataset.The top row colors points by their domain label-which subdataset of the COVIDx dataset they are in-while the bottom row colors points by their disease label.We observe that without feature disentanglement, the learned representations easily separate datasetsdespite not being trained for this task-however, with feature disentanglement, the learned representations do not clearly separate datasets.https://doi.org/10.1371/journal.pone.0274098.g002

Table 1 . Dataset overview. Counts
of disease label type per dataset.The COVIDx dataset is made up of 5 sub-datasets and the CC-CCII dataset is used as a held-out test set.

Fig 1. Overview of the feature disentanglement modeling approach. We
propose to learn a model that simultaneously predicts the class label and domain label for a given CXR image.The parameters of the model are updated to extract representations that contain information about the class label but not about domain label.https://doi.org/10.1371/journal.pone.0274098.g001