A multiple instance learning approach for detecting COVID-19 in peripheral blood smears

A wide variety of diseases are commonly diagnosed via the visual examination of cell morphology within a peripheral blood smear. For certain diseases, such as COVID-19, morphological impact across the multitude of blood cell types is still poorly understood. In this paper, we present a multiple instance learning-based approach to aggregate high-resolution morphological information across many blood cells and cell types to automatically diagnose disease at a per-patient level. We integrated image and diagnostic information from 236 patients to demonstrate not only that there is a significant link between blood and a patient's COVID-19 infection status, but also that novel machine learning approaches offer a powerful and scalable means to analyze peripheral blood smears. Our results both back up and extend hematological findings relating blood cell morphology to COVID-19, and achieve high diagnostic efficacy, with 79% accuracy and a ROC-AUC of 0.90.

1 Model Training and Architecture

Training Infrastructure
To implement the neural network models used throughout this work we used PyTorch [1]. During training we applied data augmentation to individual blood cell images across both branches and their respective training regimes. The images were augmented using Torchvision [2] with standard horizontal/vertical flipping, random rotations, and color jittering. We tracked performance on the held-out validation set during model training using Weights & Biases [3], with the final model used for evaluation being the iteration that performed best on the validation set. To test our system across all available data, we used repeated stratified k-fold cross validation from Scikit-Learn [4] at the patient level, stratified by COVID-19 test result. There were a total of 236 patients, 125 of which were labelled positive for COVID-19. We first split our data into an 83%-17% training/testing split, then split the training data again with an 80%-20% split to obtain our final training and validation sets. The proportion of COVID-19-positive patients was approximately balanced within each set, and no patient's data appeared in more than one set. We rotated through our dataset for a total of six folds (a single fold consisting of a unique training/validation/test set). Unless otherwise indicated, all performance metrics are reported as the average test-set performance across all six folds, where multiple independent models were trained from scratch exclusively on the individual folds. This strategy enabled us to test our system on all available data while isolating the test data during the training process.
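The patient-level splitting scheme can be sketched as follows. The patient IDs and label assignment below are placeholders (the study used 236 patients, 125 positive), and the real pipeline also repeats the procedure, but the nesting of the splits matches the description: six folds holding out roughly 17% of patients for testing, with the remainder split again 80%/20% into training and validation, each stratified by COVID-19 status.

```python
# Sketch of patient-level repeated stratified cross validation.
# Patient IDs and labels are synthetic stand-ins for the real cohort.
import numpy as np
from sklearn.model_selection import StratifiedKFold, train_test_split

rng = np.random.default_rng(0)
patient_ids = np.arange(236)
labels = np.zeros(236, dtype=int)
labels[rng.choice(236, size=125, replace=False)] = 1  # 125 COVID-19 positives

folds = []
# Six folds: each holds out ~1/6 (~17%) of patients as the test set.
skf = StratifiedKFold(n_splits=6, shuffle=True, random_state=0)
for train_val_idx, test_idx in skf.split(patient_ids, labels):
    # Split the remaining patients 80%/20% into training and validation,
    # again stratified by COVID-19 test result.
    train_idx, val_idx = train_test_split(
        train_val_idx,
        test_size=0.2,
        stratify=labels[train_val_idx],
        random_state=0,
    )
    folds.append((train_idx, val_idx, test_idx))
```

Because the splits are made over patient IDs rather than individual images, no patient's images can leak between the training, validation, and test sets.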

Model Details
The single-image model was trained for a total of 800 epochs. To generate validation metrics representative of our task, we evaluated our model by averaging all outputs across a patient's data (i.e. running every image included within a patient's PBS through the model and taking the average output as the patient prediction). The final model selected was the one with peak validation performance, which was evaluated at the end of every epoch.
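This patient-level evaluation rule can be sketched as below; `model` is a stand-in for the trained single-image network, and any module mapping an image batch to per-image logits would fit.

```python
# Minimal sketch of patient-level evaluation: run the single-image model on
# every image in a patient's smear and average the per-image probabilities.
import torch

def patient_prediction(model: torch.nn.Module, images: torch.Tensor) -> torch.Tensor:
    """images: (N, C, H, W) tensor holding all images from one patient."""
    model.eval()
    with torch.no_grad():
        per_image = torch.sigmoid(model(images)).squeeze(-1)  # (N,) probabilities
    return per_image.mean()  # scalar patient-level probability

# Toy usage: a trivial linear "model" over flattened 8x8 grayscale images.
toy_model = torch.nn.Sequential(torch.nn.Flatten(), torch.nn.Linear(64, 1))
patient_images = torch.randn(12, 1, 8, 8)
p = patient_prediction(toy_model, patient_images)
```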
The multi-image branch of our system consists of three steps: feature extraction, attention-based aggregation, and classification. For the feature extraction portion of our pipeline we chose the ResNet50 [5] neural network architecture, configured to produce a feature vector f_i ∈ ℝ^32 for each image i. The attention mechanism used a multilayer perceptron (MLP) with a tanh activation to produce a single importance value per image, followed by a softmax function across images. The resulting attention weights are used to compute a weighted sum of the feature vectors. Finally, this combined feature vector is transformed into a classification by a single linear layer.
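The three steps above can be condensed into a short PyTorch sketch. For brevity the ResNet50 backbone is replaced by a single linear feature extractor, and the input dimension and attention hidden size are assumptions; the structure (per-image features, tanh-MLP attention score, softmax weighting, linear classifier) follows the description.

```python
# Sketch of the multi-image branch: per-image features, attention-based
# aggregation, and a single-layer classifier over the bag-level feature.
import torch
import torch.nn as nn

class AttentionMIL(nn.Module):
    def __init__(self, in_dim: int = 128, feat_dim: int = 32, attn_dim: int = 64):
        super().__init__()
        # Stand-in for the ResNet50 backbone producing f_i in R^32 per image.
        self.features = nn.Linear(in_dim, feat_dim)
        # MLP + tanh producing one importance score per image.
        self.attention = nn.Sequential(
            nn.Linear(feat_dim, attn_dim),
            nn.Tanh(),
            nn.Linear(attn_dim, 1),
        )
        self.classifier = nn.Linear(feat_dim, 1)

    def forward(self, bag: torch.Tensor) -> torch.Tensor:
        """bag: (N, in_dim) -- all images from one patient."""
        f = self.features(bag)                       # (N, feat_dim)
        a = torch.softmax(self.attention(f), dim=0)  # (N, 1), sums to 1
        z = (a * f).sum(dim=0)                       # weighted bag feature
        return self.classifier(z)                    # single patient-level logit

model = AttentionMIL()
logit = model(torch.randn(20, 128))  # a bag of 20 per-image feature vectors
```

Because the softmax is taken across the images in a bag, the attention weights form a convex combination, so the aggregated feature vector stays on the same scale regardless of how many images a patient contributes.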
Several components of our multi-image branch were critical to successful performance. The most important component was the use of a unique random image subset generation process for model training, which was not used during inference. Specifically, during model training, we randomly selected a subset of 16 images per patient (from all available per-patient images) to form our model input. This random set of 16 training images changed at each iteration step of model optimization. During model inference, however, we used all available images from each patient (see Fig 2 for the distribution of image count across patients). This random sampling strategy first helped address the high memory usage encountered during network training. Second, we hypothesize that it also acted as a regularization strategy (akin to image augmentation), preventing the network from relying on relatively few images within a patient's PBS image set to drive predictions.
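A minimal sketch of this train-time subsampling, assuming a `Dataset` wrapper over per-patient image stacks (the wrapper and its field names are illustrative, not the authors' code): each time a patient is drawn during training, a fresh random subset of 16 images is sampled, while at inference the full image set is returned.

```python
# Sketch of per-patient random subsampling: 16 images per bag in training,
# all images at inference. The subset changes on every access.
import torch
from torch.utils.data import Dataset

class PatientBagDataset(Dataset):
    def __init__(self, patient_images, labels, bag_size=16, train=True):
        self.patient_images = patient_images  # list of (N_i, C, H, W) tensors
        self.labels = labels
        self.bag_size = bag_size
        self.train = train

    def __len__(self):
        return len(self.patient_images)

    def __getitem__(self, idx):
        images = self.patient_images[idx]
        if self.train:
            # Fresh random subset of bag_size images on every access.
            keep = torch.randperm(images.shape[0])[: self.bag_size]
            images = images[keep]
        return images, self.labels[idx]

# Toy usage: two synthetic "patients" with 40 and 25 images respectively.
data = [torch.randn(40, 3, 8, 8), torch.randn(25, 3, 8, 8)]
ds = PatientBagDataset(data, labels=torch.tensor([1.0, 0.0]))
bag, label = ds[0]
```

Sampling inside `__getitem__` means the subset is redrawn at every optimization step, which is what gives the strategy its augmentation-like regularizing effect.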
We trained our multi-image model for 2000 epochs. As with the single-image model, validation performance was evaluated by examining the prediction of the patient's infection status using all of their data, and the model weights which achieved peak performance were saved for final evaluation.
To train our white blood cell (WBC) classifier we used the dataset of Acevedo et al. [7], which used the same model of imaging system (Cellavision DM9600) used within our work. Within their dataset they classified WBCs into eight distinct classes: Neutrophils, Eosinophils, Basophils, Lymphocytes, Monocytes, Immature Granulocytes (IG), Erythroblasts, and Platelets. However, we found that our dataset contained a large number of what are canonically known as smudge cells: WBCs which are broken during the slide-making process [8]. These cells were not represented within the dataset provided by Acevedo et al. [7]. To compensate for this exclusion we manually labelled 2000 images within our main dataset and included them with the 17,000 images of the other classes.
Using this expanded dataset we trained a convolutional neural network (CNN) to predict the classification of each cell. We followed a similar procedure to the single-image classifier used for detecting COVID-19, using the same set of image augmentations and the ResNet50 [5] network architecture, trained from random initialization. We used the AdamW optimizer [6] with a fixed learning rate of 3×10⁻⁴, and trained the model for 100 epochs. We used stratified k-fold cross validation to ensure equal proportions of each cell type across our training, validation, and test sets of 70%, 15%, and 15% respectively. With this training strategy our model achieved an average classification accuracy of 92% across our test set.

Scan Data Statistics
To understand the underlying trends within our data we plotted the distribution of WBC types (determined by our automated WBC classifier). Within Fig 1 we can see that the most common cell type is the neutrophil, constituting ∼35% of the images, while other cell types occur at medium to low frequencies. We also examined the number of images per patient, as shown in Fig 2. There is no meaningful difference in image counts between conditions (COVID-19 positive vs. negative); however, there is a great deal of variation across the entire dataset.
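The cell-type proportions summarized in Fig 1 amount to a simple tabulation over the classifier's per-image predictions; a minimal sketch, with a toy label list standing in for the real classifier output:

```python
# Sketch of tabulating cell-type proportions from per-image predictions.
from collections import Counter

predicted_types = ["neutrophil"] * 35 + ["lymphocyte"] * 20 + ["monocyte"] * 10
counts = Counter(predicted_types)
total = sum(counts.values())
proportions = {cell: n / total for cell, n in counts.items()}
```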