Mixture-of-Experts Variational Autoencoder for clustering and generating from similarity-based representations on single cell data

Clustering high-dimensional data, such as images or biological measurements, is a long-standing problem and has been studied extensively. Recently, Deep Clustering has gained popularity due to its flexibility in fitting the specific peculiarities of complex data. Here we introduce the Mixture-of-Experts Similarity Variational Autoencoder (MoE-Sim-VAE), a novel generative clustering model. The model can learn multi-modal distributions of high-dimensional data and use these to generate realistic data with high efficacy and efficiency. MoE-Sim-VAE is based on a Variational Autoencoder (VAE), where the decoder consists of a Mixture-of-Experts (MoE) architecture. This specific architecture allows for various modes of the data to be automatically learned by means of the experts. Additionally, we encourage the lower dimensional latent representation of our model to follow a Gaussian mixture distribution and to accurately represent the similarities between the data points. We assess the performance of our model on the MNIST benchmark data set and challenging real-world tasks of clustering mouse organs from single-cell RNA-sequencing measurements and defining cell subpopulations from mass cytometry (CyTOF) measurements on hundreds of different datasets. MoE-Sim-VAE exhibits superior clustering performance on all these tasks in comparison to the baselines as well as competitor methods.


Author summary
Clustering single cell measurements into relevant biological phenotypes, such as cell types or tissue types, is an important task in computational biology. We developed a computational approach which allows incorporating prior knowledge about the single cell similarity into the training process, and ultimately achieve significant better clustering performance compared to baseline methods. This single cell similarity can be defined to benefit specific needs of the modeling goal, for example to either cluster cell type or tissue type, respectively.
In addition, we are able to generate new realistic single cell data from a respective mode of the phenotype due to the architecture of the model, which consists of smaller

Introduction
Clustering has been studied extensively [1,2] in machine learning and has found wide application in identifying grouping structure in high dimensional biological data such as various omics data modalities. Recently, many Deep Clustering approaches were proposed, which modified (Variational) Autoencoder ((V)AE) architectures [2,3] or by varying regularization of the latent representation [4][5][6][7].
The reconstruction error usually drives the definition of the latent representation learned from an AE or VAE. The representation for AE models is unconstrained and typically places data objects close to each other according to an implicit similarity measure that also yields favorable reconstruction error. In contrast, VAE models regularize the latent representation such that the represented inputs follow a certain variational distribution. This construction enables sampling from the latent representation and data generation via the decoder of a VAE. Typically, the variational distribution is assumed standard Gaussian, but for example Jiang et al. [7] introduced a mixture-of-Gaussians variational distribution for clustering purposes.
A key component of clustering approaches is the choice of similarity metric for the considered data objects which we try to group [8]. Such similarity metrics are either defined a priori or learned from the data to specifically solve classification tasks via a Siamese network architecture [9]. Dimensionality reduction approaches, such as UMAP [10] or t-SNE [11], allow to specify a similarity metric for projection and thereby define the data separation in the inferred latent representation.
In this work, we introduce the Mixture-of-Experts Similarity Variational Autoencoder (MoE--Sim-VAE), a new deep architecture that performs similarity-based representation learning, clustering of the data and generation of data from each specific data mode. Due to a combined loss function, it can be jointly optimized. We empirically assess the scope of the model and present superior clustering performance on the canonical benchmark MNIST. Moreover, in an ablation study, we show the efficiency and precision of MoE-Sim-VAE for data generation purposes in comparison to the most related state-of-the-art method [7]. We achieve superior results on the identification of tissue-or cell type groupings via MoE-Sim-VAE on a murine single-cell RNAsequencing atlas and mass cytometry measurements of Peripheral Blood Mononuclear Cells.

MoE-Sim-VAE
Here we introduce the Mixture-of-Experts Similarity Variational Autoencoder (MoE-Sim-VAE, Fig 1). The model is based on the Variational Autoencoder [12]. While the encoder network is shared across all data points, the decoder of the MoE-Sim-VAE consists of a number of K different subnetworks, forming a Mixture-of-Experts architecture [13]. Each subnetwork constitutes a generator for a specific data mode and is learned from the data.
The variational distribution over the latent representation is defined to be a mixture of multivariate Gaussians, first introduced by Jiang et al. [7]. In our model, we aim to learn the mixture components in the latent representation to be standard Gaussians where ω k are mixture coefficients, μ k are the means for each mixture component, I is the identity matrix and K is the number of mixture components. The dimension of the latent representation z needs to be defined to suit the demands of Gaussian mixtures which have limitations in higher dimensions [14]. Similar to optimizing an Evidence Lower Bound (ELBO), we penalize the latent representation via the reconstruction loss of the data L reconst and by using the Kullback-Leibler (KL) divergence for multivariate Gaussians [7] on the latent representation where k is a constant, N 0 � N ðμ 0 ; Σ 0 ¼ IÞ, and I is the identity matrix. Further, PLOS COMPUTATIONAL BIOLOGY estimated from the samples of the latent representation. Finally, we assume μ 0 = μ 1 resulting in the following simplified objective to penalize exclusively the covariance of each cluster. It remains to define the reconstruction loss L reconst , where we choose a Binary Cross-Entropy (BCE) between the original data x (scaled between 0 and 1) and the reconstructed data x reconst , where i iterates the batch size N and d the dimensions of the data D. We motivate the BCE loss due to better convergence properties using artificial neural networks in comparison to mean squared error [15]. Finally the loss for the VAE part is defined by with a weighting coefficient π 1 which can be optimized as a hyperparameter. Similarity clustering and gating of latent representation. Training of a data mode-specific generator expert requires samples from the same data mode. This necessitates to solve a clustering problem, that is, mapping the data via the latent representation into K clusters, each corresponding to one of the K generator experts. We solve this clustering problem via a clustering network, also referred to as gating network for MoE models. It takes as input the latent representation z i of sample i and outputs probabilities p ik 2 [0, 1] for clustering sample i into cluster k. According to this cluster assignment, sample i is then gated to expert k = argmax k p ik for each sample i. We further define the cluster centers μ k for k 2 {1, . . ., K} similar as in the Expectation Maximization (EM) algorithm for Gaussian Mixture models [16] as where N k is the absolute number of data points assigned to cluster k based on highest probability p ik for each sample i = 1, . . ., N. The Gaussian mixture distributed latent representation (via KL loss in Eq 3) is motivation for the empirical computation of the cluster means and further, similar as in the EM algorithm, it allows iterative optimization of the means of the Gaussians. We train the clustering network to reconstruct a data-driven similarity matrix S, using the Binary Cross-Entropy to minimize the error in PP T � S, with P ≔ {p ik } i2{1,. . .,N},k2{1,. . .,K} where N is the number of samples (e.g., batch size). Intuitively, PP T approximates the similarity matrix S since values in PP T are only close to 1 when similar data objects are assigned to the same cluster, similar to the entries in the adjacency similarity matrix S. This similarity matrix is derived in an unsupervised way in our experiments (e.g. UMAP projection of the data and k-nearest-neighbors or distance thresholding to define the adjacency matrix for the batch), but can also be used to include weakly-supervised information (e.g., knowledge about diseased vs. non-diseased patients). If labels are available, the model could even be used to derive a latent representation with supervision. The similarity feature in MoE-Sim-VAE thus allows to include prior knowledge about the best similarity measure on the data. Moreover, we apply the DEPICT loss from Dizaji et al. [4], to improve the robustness of the clustering. For the DEPICT loss, we additionally propagate a noisy probabilityp ik through the clustering network using dropout after each layer. The goal is to predict the same cluster for both, the noisyp ik and the clean probability p ik (without applying dropout). Dizaji et al. [4] derived as objective function a standard cross-entropy loss whereby q ik is computed via the auxiliary function We refer to Dizaji et al. [4] for the exact derivation. The DEPICT loss encourages the model to learn invariant features from the latent representation for clustering with respect to noise [4]. Looking at it from a different perspective, the loss helps to define a latent representation which has those invariant features to be able to reconstruct the similarity and therefore the clustering correctly. The complete clustering loss function L Clustering is then defined by with a mixture coefficient π 2 which can be optimized as a hyperparameter.

MoE-Sim-VAE loss function. Finally, the MoE-Sim-VAE model loss is defined by
which consists of the two main loss functions L VAE , acting as a regularization for the latent representation, and L Clustering , which helps to learn the mixture components based on an a priori defined data similarity. The model objective function L MoEÀ SimÀ VAE can then be optimized endto-end to train all parts of the model.

Related work
(V)AEs have been extensively used for clustering [1,[4][5][6][17][18][19][20]. The most related approaches to MoE-Sim-VAE are Jiang et al. [7] and Zhang et al. [3]. Jiang et al. [7] introduced the VaDE model, comprising a mixture of Gaussians as underlying distribution in the latent representation of a Variational Autoencoder. Optimizing the Evidence Lower Bound (ELBO) of the log-likelihood of the data can be rewritten to optimize the reconstruction loss of the data and KL divergence between the variational posterior and the mixture of Gaussians prior. Jiang et al. [7] use two separate networks for reconstruction and the generation process. Further, to effectively generate images from a specific data mode and to increase image quality, the sampled points have to surpass a certain posterior threshold and are otherwise rejected. This leads to an increased computational effort. The MoE Decoder of our model, which is used for both reconstruction and generation, does not need such a threshold.
Zhang et al. [3] have introduced a mixture of autoencoders (MIXAE) model. The latent representation of the MIXAE is defined as the concatenation of the latent representation vectors of each single autoencoder in the model. Based on this concatenated latent representation, a Mixture Assignment Network predicts probabilities which are used in the Mixture Aggregation to form the output of the generator network. Each AE model learns the manifold of a specific cluster, similarly to our MoE Decoder. However, MIXAE does not optimize a variational distribution, such that generation of data from a distribution over the latent representation is not possible, in contrast to the MoE-Sim-VAE (Fig 2).

Results
In the following we report superior clustering and generating results of MoE-Sim-VAE on real world problems. First, we evaluate MoE-Sim-VAE on images from MNIST and show why a MoE decoder is beneficial. Second, we present significantly better clustering results on mouse organ single-cell RNA sequencing data. Third, we apply MoE-Sim-VAE to cluster cell types in Peripheral Blood Mononuclear Cells using CyTOF measurements on 272 distinct data sets significantly better than competitors. (Exact model and optimization details as well as preprocessing steps for all experiments can be found in S1 Text)

Unsupervised clustering, representation learning and data generation on MNIST
We trained a MoE-Sim-VAE model on images from MNIST. We compared our model against multiple models which were recently reviewed in Aljalbout et al. [1], and specifically against VaDE [7] which shares similar properties with MoE-Sim-VAE. The VaDE model is comprising a mixture of Gaussians as underlying distribution in the latent representation of a Variational Autoencoder (more detailed comparison in Section Related work).

PLOS COMPUTATIONAL BIOLOGY
We compare the models with the Normalized Mutual Information (NMI) criterion but also classification accuracy (ACC) ( Table 1). The MoE-Sim-VAE outperforms the other methods w.r.t. clustering performance when comparing NMI and achieves the second-best result when comparing ACC. Note that for comparability reasons we used the number of experts k = 10 in our model to fit the existing number of digits in MNIST. To prove that MoE-Sim-VAE is able to learn the correct number of experts, we report a study on synthetic data in supporting information (S1 Text and S1 Fig).
We use a UMAP projection [10] of MNIST as our similarity measure and then apply knearest-neighbors of each sample in a batch. In an ablation study, we show the importance of the similarity matrix to create a clear separation of the different digits in the latent representation. Therefore, we computed a test statistic based on the Maximum Mean Discrepancy (MMD) [28,29] which can be used to test if two samples are drawn from the same distribution (see Section 1.2 in S1 Text). In this work we use MMD to test if samples of different clusters of the latent representation are similar. When sampling twice from the same cluster we get an average MMD test statistic of t sim = −0.05 with, and t = −0.11 without similarity matrix, whereas the average distance between samples from two different clusters is significantly larger when training with similarity matrix t sim = 221.66 compared to when training without t = 49.29. This clearly suggests better separation on the latent representations between the clusters when being able to define a respective similarity (S2 Fig).
In addition to the clustering network, we can make use of the latent representation for image generation purposes. The latent representation is trained as a mixture of standard Gaussians. The means of these Gaussians are the centers of the clusters trained via the clustering network. Therefore, the variational distribution can be sampled from and gated to the clusterspecific expert in the MoE-decoder. The expert then generates new data points for the specific data mode. Results and the schematic are displayed in

PLOS COMPUTATIONAL BIOLOGY
In an ablation study, we compare the two models MoE-Sim-VAE and VaDE [7] on generating MNIST images with the request for a specific digit. The goal is to show that a MoE decoder, as proposed in our model, is beneficial. We focus our comparison to VaDE since this model, as the MoE-Sim-VAE, resorts to a mixture of Gaussian latent representation but differs in generating images by means of a single decoder network instead of a Mixture-of-Expert decoder network. The rationale for our design choice is to ensure that smaller sub-networks learn to reproduce and generate specific modes of the data, in this case of specific MNIST digits.
To show that both models' latent representations are separating the different clusters well, we computed the Maximum Mean Discrepancy (MMD) [28], similar as introduced above. An MMD statistic of t MoE-Sim-VAE = 256.31 and t VaDE = 355.14 suggests separation of the clusters when sampling in the latent representations of both models. Therefore, both latent representations can separate the clusters of respective digits well, such that the decoder gets well-defined samples to generate the requested digit. Hence, the main difference of generating specific digits arises in the decoder/generator networks (S3 Fig).
We evaluated the importance of the MoE-Decoder to (1) accurately generate requested digits and (2) be efficient in generating requested digits. Specifically, we sampled 10, 000 points from each mixture component in the latent representation, generated images, and used the model's internal clustering to assign a probability to which digits were generated. To generate correct and high-quality images with VaDE, the posterior of the latent representation needs to be evaluated for each sample. This was done for the different thresholds ϕ 2 [0.0, 0.1, 0.2, � � �, 0.9, 0.999]. The default threshold [7] used was ϕ = 0.999. To compare the separation of the clusters in the latent representation above using MMD, we used a threshold of only ϕ = 0.8, which already is enough to have higher separation based on MMD. Instead of thresholding the latent representation, we ran the generation process for MoE-Sim-VAE for each threshold with the same settings. To generate images from VaDE we used the Python implementation (https://github.com/slim1017/VaDE) and model weights publicly available from Jiang et al. [7].
As a result the MoE-Sim-VAE generates digits more accurately with fewer resources required, especially when comparing the number of iterations required to fulfill the default posterior threshold of 0.999. VaDE needs nearly 2 million iterations to find samples that fulfill the aforementioned threshold criterion whereas the MoE-Sim-VAE only requires 10, 000 for a comparable sample accuracy. In comparison, the mean accuracy over all thresholds for MoE--Sim-VAE is 0.970, whereas VaDE reaches on average only 0.944 (S4, S5 and S6 Figs).

Clustering organ-specific single cell RNA-seq data
Single-cell RNA-sequencing (scRNA-seq) measurements allow measuring transcriptomes of tens of thousands of single cells. Clustering of the resulting data into groups representing biological phenotypes, such as cell type or tissue type, constitutes a major analysis task in scRNAseq studies. In the following, we present how MoE-Sim-VAE outperforms the methods Gaussian Mixture Models (GMM), k-means, hierarchical clustering, HDBSCAN, fuzzy-c-means (FCM), Louvain and scVI. [30][31][32][33] for clustering the scRNA-seq data of the Tabula Muris study covering seven different mouse organs [34]. The method scVI is a well established deep generative modeling framework designed for single cell transcriptomic data modeling the count data with a Poisson distribution, and allows to perform several downstream analysis tasks, such as clustering. Also in this example we used MoE-Sim-VAE with a BCE loss instead of a mean squared error loss for MoE-Sim-VAE, motivated due to better convergence properties in combination with artificial neural networks [15], and due to literature where BCE was used as reconstruction loss on RNA-seq data for visualization purposes [35].
MoE-Sim-VAE allows for incorporation of a user defined similarity and therefore also prior knowledge about the data. From literature we identified for each organ a representing signature gene and use each to encode prior expectation of organ assignment for each cell measurement. Namely, we take advantage of the high expressions of Lpl in the heart, Miox in kidney, Hpx in liver, Tspan1 in large intestine, Prx in lung, Cd79a in spleen and Dntt in thymus [36][37][38][39]. Single cells which show above average expression of a respective signature gene are considered to be similar. For the training of MoE-Sim-VAE we only considered cells which show above average expression in exactly one of the respective organ-specific signature genes (does not apply for the test data). To highlight the influence of the similarity prior in this example, we report the average accuracy of 0.92 of having a correct similarity assignment for each organ based on above average expression (S1 Table).
MoE-Sim-VAE outperforms all above mentioned baseline approaches in clustering the single cells with respect to the organ of origin. Our model reaches a F-measure of 0.748 and is therefore close to 0.1 better compared to the second best competitor. We performed a hyperparameter screening for the competitor methods (more details in Chapter 4 in S1 Text) and chose the best results achieved on the test dataset based on the F-measure as well. In Table 2 we present the exact results in detailed comparison. In Fig 3A and 3B) we show a Principal Component Analysis (PCA) of the original data as well as of the latent representation of MoE-Sim-VAE. It can be seen that the organs are better separated in the latent representation inferred from our model which enables for better clustering results of MoE-Sim-VAE (Fig 3C)). In Fig 3D-3I) we present the results of the competitor methods in the latent representation of MoE-Sim-VAE and can clearly see that Louvain performs second best, but poorly separates cells from organs which are close to each other or overlap in the original PCA representation. Fig 3J) visualizes the Leiden clustering results on the latent representation inferred from scVI. This latent space shows a more detailed separated latent representation but with overlapping or separated true labels. For example, samples from the heart are separated in up to seven different groups. This leads to less precise clustering results concerning the task of identifying tissue types, but might be beneficial when clustering cell types. This also highlights

PLOS COMPUTATIONAL BIOLOGY
the importance of being able to incorporate prior knowledge when inferring latent representations for specific clustering tasks, such as grouping tissue types.

Learning cell type composition in peripheral blood mononuclear cells using CyTOF measurements
In the following, we want to assess representation learning performance on the real-world problem of cell type definition from single-cell measurements. Cytometry by time-of-flight mass spectrometry (CyTOF) is a state-of-the-art technique allowing measurements of up to 1, 000 cells per second and in parallel over 40 different protein markers of the cells [40]. Defining biologically relevant cell subpopulations by clustering this data is a common learning task [41,42]. Many methods have been developed to tackle the problem introduced above and were compared on four publicly available datasets in Weber and Robinson [42]. The best out of 18 methods were FlowSOM [43], PhenoGraph [44] and X-shift [45]. These are based on k-nearestneighbors heuristics, either defined from a spanning graph or from estimating the data density. In contrast to these methods, MoE-Sim-VAE can map new cells into the latent representation, assign probabilities for cell types, and infer an interpretable latent representation, allowing intuitive downstream analysis by domain experts.
We applied MoE-Sim-VAE to the same datasets as in Weber and Robinson [42] and achieve superior results in classification using the F-measure [41] in three out of four datasets. Similarly as in Weber and Robinson [42], we trained MoE-Sim-VAE 30 times and report in Table 3 (adopted from Weber and Robinson [42]) the means and standard deviation across all runs (S7 Fig). As a MoE-Sim-VAE similarity measure we used a UMAP projection with Canberra distance [46] as metric and computed similarly to the MNIST experiments the k-nearest-neighbors of each sample in the batch. This applies for all CyTOF experiments.

PLOS COMPUTATIONAL BIOLOGY
Further, we trained a MoE-Sim-VAE model with a fixed number of experts k = 15 (thereby slightly overestimating the true number of subpopulations) on 268 datasets from Bodenmiller et al. [47] and achieve superior clustering results of cell subpopulations in the data when comparing to state-of-the-art methods in this field (PhenoGraph, X-Shift, FlowSOM). Results are summarized in Fig 4 as well as exactly listed in S2 Table. Furthermore, we visualize in detail the reconstruction of the original data per expert using a Principal Component Analysis on the original data space. This visualization also shows that many experts were silenced during the training, since only seven out of possible 15 experts where selected (S8 Fig).

Discussion
Our MoE-Sim-VAE model can infer similarity-based representations, perform clustering tasks, and efficiently as well as accurately generate high-dimensional data. The training of the model is performed by optimizing a joint objective function consisting of data reconstruction, clustering, and KL loss, where the latter regularizes the latent representation. On the benchmark dataset of MNIST, we present superior clustering performance and the efficiency and accuracy of MoE-Sim-VAE in generating high-dimensional data. On the biological real-world tasks of clustering mouse organs and defining cell subpopulations in complex single-cell data, we show superior performances compared to state-of-the-art methods on a vast range of over 270 datasets and therefore demonstrate the MoE-Sim-VAE's real-world usefulness.
To achieve outstanding clustering performances the choice of the similarity measure as well as the hyperparameter tuning, such as for the loss coefficients, play a crucial role. As shown in the ablation study for clustering MNIST, setting the similarity clustering loss to zero has a tremendous effect on the learned latent representation and the clustering performance. In general, we could observe that the loss coefficients of the reconstruction loss and clustering loss need to be selected close to one, whereas the loss coefficient for the KL loss is closer to zero. A less crucial role plays the selection of number of experts, as shown on clustering synthetic data, on the example of clustering mouse organs based on single cell RNA-sequencing data, or when clustering cell types from mass cytometry measurements. Even when defining more experts than the number of expected clusters, MoE-Sim-VAE did not target each single expert. A minimum number of required experts to distribute the different modes of the data with respect to the defined similarity where selected by the model.
Future work might include to add adversarial training to the MoE decoder, which could improve image generation to create even more realistic images. Also, specific applications might benefit from replacing the Gaussian with a different mixture model. Especially biological data is not always generated from Gaussian distributions. So far the MoE-Sim-VAE's similarity measure has to be defined by the user. Relaxing this requirement and allowing for learning a useful similarity measure automatically for inferring latent representations will be an interesting extension to explore. This could be useful in a weakly-supervised setting, which often occurs for example in clinical data consisting of healthy and diseased patients. Minor details between a healthy and diseased patient might make a huge difference and could be learned from the data using neural networks.
In summary, we expect the MoE-VAE model, as well its future extensions, to be a valuable contribution to the computational biology toolbox to identify biological group structure in high-dimensional molecular data modalities under consideration of weak prior knowledge, in particular including single-cell omics data. In comparison, B) shows results of the MoE-Sim-VAE model ignoring the similarity matrix setting the loss coefficient to zero. One can observe that the MMD statistic, which can be seen as a similarity measure of two distributions, is way lower compared to the model including the similarity matrix. Further, also the UMAP projection confirms less separation in the latent representation between the different digits. which can be seen as the distance between pairs of distributions. The figures on the right side show the separation of the cluster in the latent representation based on a dimensionality reduction via UMAP [10]. A) shows the results for the clusters of VaDE at a posterior threshold of 0.8 which is the first threshold which shows total separation of all clusters. B) shows the separation of the clusters in latent space learned from MoE-Sim-VAE. For both methods, all distributions belonging to clusters of different respective digits show a larger distance compared to the diagonal of matching distributions, such that we generate images from a well-separated latent representation for both methods and therefore the main difference comes from the decoders. (EPS)

S4 Fig. Comparison of data generation process between Moe-Sim-VAE and VaDE. A)
shows the accuracy of how certain a specific digit can be generated from the respective cluster in the latent representation whereas B) compares the number of runs until a sample from the latent representation satisfied the posterior criterion from VaDE. It needs to be mentioned that MoE-Sim-VAE does not require any thresholding such that we ran the data generation process multiple times with the same settings to compare with VaDE. In total 10000 samples are generated for each digit.  Table. Signature gene accuracy. Accuracy of assigning a organ similarity based on high gene expression of prior selected organ specific signature genes for the split training and test data set. We computed the balanced accuracy for each single organ vs. the rest, respectively. (XLS) S2 Table. Exact results on 268 mass cytometry experiments. CyTOF measurements from peripheral blood mononuclear cells (PBMCs) were taken and the goal is to define the different cell types present in the data. The ground truth was defined using the SPADE algorithm [48], which can visualize the high dimensional data in such a way to be able to manual gate the cells. We compare to other fully unsupervised methods as FlowSOM, X-shift and PhenoGraph and achieve in most cases the best F-measure. (XLS)