Application of explainable ensemble artificial intelligence model to categorization of hemodialysis-patient and treatment using nationwide-real-world data in Japan

Background. Although dialysis patients are at a high risk of death, it is difficult for medical practitioners to simultaneously evaluate many inter-related risk factors. In this study, we characterized hemodialysis patients using a machine learning model and evaluated its usefulness for screening hemodialysis patients at a high risk of 1-year death, using the nationwide database of the Japanese Society for Dialysis Therapy.

Materials and methods. The patients were separated into two datasets (n = 39,930 each). We categorized hemodialysis patients in Japan into new clusters generated by the K-means clustering method using the development dataset. The association between a cluster and the risk of death was evaluated using multivariate Cox proportional hazards models. Then, we developed an ensemble model composed of the clusters and support vector machine models in the model development phase, and compared the accuracy of the prediction of mortality between the machine learning models in the model validation phase.

Results. The average age of the subjects was 65.7±12.2 years; 32.7% had diabetes mellitus. The five clusters clearly distinguished the groups on the basis of their characteristics: Cluster 1, young male, and chronic glomerulonephritis; Cluster 2, female, and chronic glomerulonephritis; Cluster 3, diabetes mellitus; Cluster 4, elderly and nephrosclerosis; Cluster 5, elderly and protein energy wasting. These clusters were associated with the risk of death; Cluster 5 compared with Cluster 1, hazard ratio 8.86 (95% CI 7.68, 10.21). The accuracy of the ensemble model for the prediction of 1-year death was 0.948, higher than those of the logistic regression model (0.938), the support vector machine model (0.937), and the deep learning model (0.936).

Conclusions. The clusters clearly categorized patients on the basis of their characteristics, and reflected their prognosis. Our real-world-data-based machine learning system is applicable to identifying high-risk hemodialysis patients in clinical settings, and has a strong potential to guide treatments and improve their prognosis.


Introduction
The mortality rates of dialysis patients are very high, and the number of prevalent end-stage kidney disease (ESKD) patients has been increasing in the USA and Japan [1,2]. To improve their prognosis, early identification of patients at a high risk of death and interventional treatment of their conditions are necessary.
Various risk factors for death in dialysis patients have been identified [3][4][5]. These risk factors are associated with each other forming a complex network which should be simultaneously taken into account and controlled [6][7][8]. The Dialysis Outcomes and Practice Patterns Study (DOPPS) has defined a survival index to predict the hemodialysis patients' risk of death using logistic regression models [6]. We also have developed a nutritional risk index (NRI) for hemodialysis patients using Cox proportional hazards models [8]. However, these indices make some statistical assumptions which limit their application; they also take time to calculate, which is inconvenient when dealing with many patients in clinical settings. The development of a new automatic system is needed to help manage various risk factors simultaneously, and to improve the prognosis of a large number of patients.
Artificial intelligence (AI) methods hold great promise for decision-making in complex systems including those used in medicine for diagnosis and prediction [9,10]. Although AI is useful to accurately diagnose patients at a high risk of death, only a few studies on the prediction of ESKD patients' prognosis have been carried out [11][12][13]. Difficulties constructing AI algorithms for clinical use have been pointed out, such as the scarce availability of reliable and large data sets for AI algorithm construction, the lack of transparency of conventional AI algorithms, the difficult integration of AI algorithms into complex existing clinical work flow, and the cumbersome compliance with regulatory medical frameworks [14]. Overcoming some or all of these difficulties is required to create a new AI-based system for ESKD patients.
Therefore, in this study, we aim to establish an implementable AI system for screening hemodialysis patients at a high risk of death and for predicting their prognosis on the basis of real-world data from the Japanese Society for Dialysis Therapy (JSDT) Renal Data Registry (JRDR). JRDR is a nationwide-data registry and includes 98.8% of ESKD patients in Japan [1]. To provide transparency and accuracy of AI predictions, an ensemble model composed of the K-means method and a support vector machine (SVM) was developed. Then, the performance of the proposed model was compared with that of a SVM-alone model, a deep learning model, and a multivariate logistic regression model. Moreover, considering their usage and applicability to clinical settings, we developed a new total-care system for treating hemodialysis patients at a high risk of death.

Dataset
This is a prospective cohort study of maintenance hemodialysis patients using JRDR data. JSDT has been conducting annual surveys of dialysis facilities in Japan since 1968. The JRDR data from 2008 to 2013 were used in this study. This study was approved by the ethics committee of JSDT and was exempt from the need to obtain informed consent from participants (JSDT No. 33). The data were analyzed anonymously. The study was performed in accordance with the relevant guidelines and the Declaration of Helsinki of 1975 as revised in 1983.
The subjects of this study were the 275,553 patients (Fig 1). The exclusion criteria were as follows: patients younger than twenty years; patients on hemodiafiltration, hemofiltration, or peritoneal dialysis; patients with missing values or outlier values of laboratory data; patients who had a limb amputated; and patients with a hemodialysis vintage of less than one year. Thus, 79,860 patients were included in the analysis. The included subjects were randomly classified into two groups to obtain a dataset for the development of the machine learning algorithms (development dataset, 39,930) and a dataset for the validation of the algorithms (validation dataset, 39,930).

Statistical analyses
Normally distributed variables are presented as mean±standard deviation; otherwise, the median and interquartile ranges are presented.

Development of machine learning models
The variables of the baseline characteristics were Z-score-normalized and used for the following modeling.
K-means method model. Step 1: Patients were grouped into 2 to 10 clusters on the basis of their baseline characteristics by the K-means method using the development dataset. Patients with similar characteristics were grouped in one cluster, and the patients in other clusters showed dissimilar characteristics. First, patients were randomly selected as initial cluster centers. Next, each patient was assigned to one cluster on the basis of the closeness of their characteristics to the cluster center. The mean of samples in a cluster was calculated as the new cluster center, $\mu$. These steps were repeated until the final stable clustering results were obtained. The similarity between a patient $x$ and a center $\mu$ in a cluster was evaluated using the Euclidean distance in an m-dimensional space, $\mathrm{dist}(x,\mu)$:

$$\mathrm{dist}(x,\mu) = \sqrt{\sum_{j=1}^{m} (x_j - \mu_j)^2}$$

where $j$ denotes the $j$th variable of the baseline characteristics and $m$ is the number of variables of the baseline characteristics; in this study, $m = 20$.
Step 2: To evaluate the clustering, the within-cluster sum of squared errors (SSE), namely, distortion $J$, was measured:

$$J = \sum_{i=1}^{n} \sum_{j=1}^{k} r_{ij} \, \lVert x_i - \mu_j \rVert^2 \tag{3}$$

where $\mu_j$ is the center of cluster $j$ (here $j$ indexes clusters rather than variables), $r_{ij} = 1$ if $x_i$ is in cluster $j$ and $r_{ij} = 0$ otherwise, $k$ is the number of clusters, and $n$ is the number of patients.
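To minimize $J$, Eq (3) is partially differentiated with respect to $\mu_j$ and the derivative is set to zero:

$$\frac{\partial J}{\partial \mu_j} = -2 \sum_{i=1}^{n} r_{ij} \, (x_i - \mu_j) = 0$$

which gives

$$\mu_j = \frac{\sum_{i=1}^{n} r_{ij} \, x_i}{\sum_{i=1}^{n} r_{ij}} \tag{5}$$

The elbow method was used to identify the number of clusters at which the rapid decrease in the within-cluster SSE leveled off. Next, to evaluate whether the clusters could discriminate the patients on the basis of their risks of the endpoints, the survival probabilities of the clusters were evaluated using Kaplan-Meier survival curves. The clusters were numbered in order of risk, and Cox proportional hazards models were used to compare the risk of an endpoint between clusters. The Cox proportional hazards models included only the cluster as a categorical variable, because the K-means method can be considered a function of the variables of the baseline characteristics:

$$h(t) = h_0(t) \exp\left( \sum_{c=2}^{k} \beta_c \, \mathbb{1}[\mathrm{Cluster} = c] \right), \qquad \mathrm{Cluster} = f(x_1, \dots, x_m)$$

where $h_0(t)$ is the baseline hazard, $\beta_c$ is the coefficient for Cluster $c$ relative to Cluster 1, and $f$ denotes the cluster assignment by the K-means method. Hazard ratios (HRs) with 95% confidence intervals (CIs) are presented.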
Step 3: The patients in the validation dataset were grouped into clusters using the K-means model trained on the development dataset. Then, the relationship between the clusters and the risk of the endpoints was evaluated using Kaplan-Meier survival curves and Cox proportional hazards models. Considering the results, the optimal number of clusters, k, was determined, and the differences in characteristics between the clusters were statistically evaluated.
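Steps 1 and 2 can be illustrated with a minimal sketch. It assumes plain Lloyd-style iteration with randomly chosen patients as initial centers, as described above; the function names (`kmeans`, `distortion`), the fixed seed, and the synthetic two-dimensional data are illustrative assumptions and not the software or data used in the study.

```python
import math
import random


def euclidean(x, mu):
    """Euclidean distance between a patient vector x and a cluster center mu."""
    return math.sqrt(sum((xj - mj) ** 2 for xj, mj in zip(x, mu)))


def kmeans(data, k, n_iter=100, seed=0):
    """Assign each point to its nearest center, then move centers to cluster means."""
    rng = random.Random(seed)
    centers = [list(p) for p in rng.sample(data, k)]  # random patients as initial centers
    labels = [0] * len(data)
    for _ in range(n_iter):
        # Assignment step: nearest center by Euclidean distance.
        new_labels = [min(range(k), key=lambda c: euclidean(x, centers[c])) for x in data]
        # Update step: each center becomes the mean of its members (the mean formula for mu_j).
        for c in range(k):
            members = [x for x, lab in zip(data, new_labels) if lab == c]
            if members:  # an empty cluster keeps its previous center
                centers[c] = [sum(col) / len(members) for col in zip(*members)]
        if new_labels == labels:  # assignments are stable, so stop
            break
        labels = new_labels
    return centers, labels


def distortion(data, centers, labels):
    """Within-cluster sum of squared errors, J."""
    return sum(euclidean(x, centers[lab]) ** 2 for x, lab in zip(data, labels))


# Synthetic, already-standardized toy data: three blobs in two dimensions (assumption).
_rng = random.Random(1)
toy = [[_rng.gauss(cx, 0.3), _rng.gauss(cy, 0.3)]
       for cx, cy in [(0, 0), (5, 5), (0, 5)] for _ in range(30)]

# Distortion for k = 2..10; plotting these values against k gives the elbow curve.
sse_by_k = {k: distortion(toy, *kmeans(toy, k)) for k in range(2, 11)}
```

In practice the result depends on the random initialization, so several restarts per k are commonly used before reading the elbow; the sketch runs a single start per k for brevity.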
Multivariate logistic regression model. To predict the probabilities of the endpoints, multivariate logistic regression models (LRMs) including all variables of the baseline characteristics were developed using the development dataset as follows:

$$p = \frac{1}{1 + \exp\left(-\left(\beta_0 + \sum_{i=1}^{m} \beta_i x_i\right)\right)}$$

where $p$ is the predicted probability of the endpoint, $x_i$ is the $i$th variable of the baseline characteristics, $\beta_i$ is the parameter estimate for the same variable, and $\beta_0$ is the intercept. When p was estimated to be more than 0.5, a patient's death was predicted. Then, using the validation dataset, we evaluated the accuracy of the prediction using the LRMs.
Support vector machine models. SVM models were used to predict the endpoints. SVM models with a Gaussian radial basis function kernel included all of the variables of the baseline characteristics. In the development of each SVM model, classification was examined on the basis of the three-fold cross validation method, and the accuracy of the prediction was estimated by averaging the three results. Then, the final SVM models were developed. Using the validation dataset, we evaluated the accuracy of the prediction of the endpoints using the SVM models developed.
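The decision rule of the logistic regression model described above can be sketched as follows. The intercept and coefficients used in the usage note are hypothetical values chosen only for illustration; they are not the estimates obtained in the study.

```python
import math


def death_probability(x, beta0, beta):
    """Logistic model: p = 1 / (1 + exp(-(beta0 + sum_i beta_i * x_i)))."""
    z = beta0 + sum(b * xi for b, xi in zip(beta, x))
    # Numerically stable evaluation for large positive or negative z.
    if z >= 0:
        return 1.0 / (1.0 + math.exp(-z))
    e = math.exp(z)
    return e / (1.0 + e)


def predict_death(x, beta0, beta, threshold=0.5):
    """Predict death when the estimated probability is more than the threshold."""
    return death_probability(x, beta0, beta) > threshold
```

For example, with a hypothetical single standardized covariate and coefficient 1.0, `predict_death([2.0], 0.0, [1.0])` returns `True` because the estimated probability exceeds 0.5, whereas a probability of exactly 0.5 is not counted as a predicted death.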
Ensemble model. Using the development dataset, we grouped the patients into the k clusters previously determined by the K-means method (Fig 2). For each cluster, an SVM model including all of the variables of the baseline characteristics was trained to predict the risk of the endpoints. Thus, k SVM models, F(Cluster i), were developed, where φ(Cluster i) is the SVM model for Cluster i. Then, the patients in the validation dataset were grouped into k clusters, and the trained SVM models were applied to the corresponding clusters. The prediction results for the endpoints were then combined.
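The routing logic of the ensemble described above can be sketched as follows: a patient is first assigned to the nearest cluster center and then passed to the classifier trained for that cluster. In this sketch the per-cluster models are simple stand-in functions rather than trained SVMs, and the centers, thresholds, and names (`nearest_cluster`, `ensemble_predict`) are hypothetical.

```python
import math


def nearest_cluster(x, centers):
    """Index of the cluster whose center is closest to x (Euclidean distance)."""
    return min(range(len(centers)), key=lambda c: math.dist(x, centers[c]))


def ensemble_predict(x, centers, cluster_models):
    """Route x to its cluster, then apply that cluster's own classifier."""
    c = nearest_cluster(x, centers)
    return c, cluster_models[c](x)


# Hypothetical two-cluster example with stand-in classifiers (not trained SVMs).
toy_centers = [[0.0, 0.0], [3.0, 3.0]]
toy_models = [
    lambda x: False,       # cluster 0: never predicts the endpoint
    lambda x: x[0] > 3.5,  # cluster 1: toy threshold rule on the first variable
]

example = ensemble_predict([4.0, 3.0], toy_centers, toy_models)
```

Returning the cluster index together with the prediction keeps the interpretable part of the ensemble (which cluster the patient belongs to) visible alongside the classifier output.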
Deep learning models. Deep learning models were developed to predict death at 1-year and 5-years of dialysis (1-year and 5-year deaths, respectively). The numbers of layers and hyperparameters were optimized on the basis of the accuracy to predict the endpoints and to prevent overfitting (Figs 3 and 4). In the development of each deep learning model, two-thirds of the development dataset was used as the training dataset and the remaining one-third was used as the test dataset. Then, using the validation dataset, we evaluated the accuracy of the prediction of the endpoints using the deep learning models.
Evaluation of model performance. The performance of the models developed for the binary diagnosis decision (death or no death) was evaluated in terms of accuracy, sensitivity, and specificity using the validation dataset. Accuracy is calculated as follows:

$$\text{Accuracy} = \frac{TP + TN}{TP + TN + FP + FN}$$

where TP, TN, FP, and FN are the numbers of true positives, true negatives, false positives, and false negatives, respectively. Because accuracy is calculated in this way, its value changes depending on the number of endpoints (the risk of death). Therefore, assuming that sensitivity and specificity were constant, we simulated accuracy at various risks of death from 0.05 to 0.65, and compared the accuracies of the models.
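The simulation described above follows from the definition of accuracy: if a fraction r of patients die, then TP = sensitivity × r × N and TN = specificity × (1 − r) × N, so accuracy = sensitivity × r + specificity × (1 − r). A minimal sketch is given below; the two sensitivity and specificity pairs are hypothetical values, not results from this study.

```python
def accuracy_from_counts(tp, tn, fp, fn):
    """Accuracy from a confusion matrix: (TP + TN) / (TP + TN + FP + FN)."""
    return (tp + tn) / (tp + tn + fp + fn)


def accuracy_at_risk(sensitivity, specificity, risk):
    """Accuracy at a given risk of death, holding sensitivity and specificity constant."""
    return sensitivity * risk + specificity * (1.0 - risk)


# Risks of death from 0.05 to 0.65 in steps of 0.05, as in the simulation range.
RISKS = [round(0.05 * i, 2) for i in range(1, 14)]

# Hypothetical (sensitivity, specificity) pairs for two kinds of model.
HYPOTHETICAL_MODELS = {
    "screening-type": (0.80, 0.70),
    "confirmation-type": (0.20, 0.99),
}

curves = {
    name: [accuracy_at_risk(se, sp, r) for r in RISKS]
    for name, (se, sp) in HYPOTHETICAL_MODELS.items()
}
```

Because the expression is linear in the risk, a model whose specificity exceeds its sensitivity loses accuracy as the risk of death rises, and the accuracy lines of two models with different sensitivity and specificity can cross at some risk.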

Baseline characteristics
The baseline characteristics including biochemical data are shown in Table 1. No statistically significant differences in the baseline characteristics between the development and validation datasets were observed. Machine learning models were constructed (Fig 5).

Fig 3 (caption text). Through the hidden layers, the patient's characteristics were extracted. The dropout rate of each hidden layer was determined appropriately. Adam was used as a learning rate optimization algorithm. ReLUs were used as the activation function of hidden layers, and the logistic activation function was used in the output layer. The performance of a deep learning model was evaluated in terms of accuracy and loss function. The trained model was applied to the validation dataset.

PLOS ONE
Explainable artificial intelligence for hemodialysis patients

K-means method models
The K-means method was conducted, and the models with 2 to 10 clusters were developed. The elbow method showed a decrease in within-cluster SSE with increasing numbers of clusters (Fig 6). Five and six clusters were chosen as candidate numbers of clusters.
The Kaplan-Meier survival curves showed the relationship between the numbers of clusters and the risk of death (Figs 7 and 8). The five-cluster model clearly distinguished the patients on the basis of the risk of 1-year and 5-year deaths in both the development and validation datasets (Figs 7A, 7B, 8A and 8B; Tables 2A and 3A). Cluster 5 showed the highest risks of 1-year and 5-year deaths.
In contrast, in the six-cluster model, the rank of the clusters based on the risk of 1-year death in the development dataset differed from the rank in the validation dataset (Table 2B). Although the risk of 1-year death of Cluster 2 (HR, 1.87) was lower than that of Cluster 3 (HR, 2.55) in the development dataset, the risk of Cluster 2 (HR, 1.58) was higher than that of Cluster 3 (HR, 1.51) in the validation dataset. Moreover, Cluster 6 showed the highest risk of 5-year death in the development dataset (Table 3B). However, in the validation dataset, the risk of Cluster 5 was very close to that of Cluster 6 (Table 3B; Fig 8C and 8D), which suggests that, depending on the patient data, the six-cluster model might not reliably reflect the patients' prognosis. Therefore, considering the stability of the five-cluster model in reflecting the patients' prognosis, k = 5 was considered appropriate for the model, and the five-cluster model was hereafter adopted.

Difference in the characteristics among five clusters
The five-cluster model could cluster the patients on the basis of their characteristics (Table 4). The mean ages in Clusters 4 and 5 were higher than those in the other clusters. A gender difference was observed; most of the patients in Cluster 1 were males (94.2%), and most of those in Cluster 2 were females (92.3%). There were also significant differences in the causes of ESKD between the clusters as follows: Clusters 1 and 2, CGN (74.6% and 60.3%, respectively); Cluster 3, DM (93.3%); Cluster 4, nephrosclerosis (100%). In Cluster 5, the numbers of patients with DM and CGN were almost the same as the mean numbers in the study population (Tables 1 and 4). Moreover, the numbers of patients who had a history of CVD were larger in Clusters 3 to 5 than in Clusters 1 and 2.
There were significant differences in the laboratory data among the clusters. Serum albumin and potassium levels gradually decreased with increasing cluster number. The serum phosphorus and creatinine levels and nPCR in Cluster 5 were lower than those in the other clusters. The numbers of patients at high and medium risk according to the NRI were larger in Clusters 4 and 5 than in the other clusters. The CRP levels in Clusters 4 and 5 were higher than those in the other clusters.
The risk of all-cause death in Cluster 5 was higher than those in the other clusters (Table 5). Trends similar to that of all-cause death were observed in the risks of CVD- and infection-caused deaths. The details of 5-year death were as follows. The proportions of CVD-caused death in

Performance of models to predict death
The five-cluster model had four cutoff points. The accuracies of predicting death on the basis of these cutoff points were compared with those of the LRM, SVM model, ensemble model, deep learning model, and the high-risk group of NRI (Fig 9). The accuracies of predicting 1-year and 5-year deaths using LRM (0.938, 0.759), SVM model (0.937, 0.758), ensemble

The estimated accuracies of the models decreased with increasing risks of 1-year and 5-year deaths (Fig 10). The accuracy curves for predicting 1-year death crossed at a 1-year death risk of 0.3 (Fig 10A). The accuracies for 1-year death using the LRM, SVM model, ensemble model, and deep learning model showed similar patterns, and exceeded 0.9 at a 1-year death risk of 0.1, which was higher than the accuracy of the five-cluster model. Moreover, for the prediction of 5-year death, the accuracies of the LRM, SVM model, ensemble model, and deep learning model were higher than that of the five-cluster model (Fig 10B). The accuracies of the deep learning model, LRM, SVM, and ensemble model were almost the same, more than 0.7, when the risk of 5-year death was about 0.4.
The sensitivities and specificities of the models showed a negative relationship at different cutoff points for clusters (Fig 11). To predict both 1- and 5-year deaths, the five-cluster model

Total-care system for hemodialysis patients
Considering the characteristics of the machine learning models, we adopted an ensemble model with the K-means method and SVM for use in clinical settings. Our recommended system is as follows (Fig 12): After clustering, the patients in Clusters 1 to 3 are followed up periodically, because some of them may be classified into Cluster 4 or 5 in the future. The patients in Clusters 4 and 5 are then examined using the SVM models. If they are diagnosed to be at a high risk, they undergo detailed medical examinations. If diseases or aggravation of comorbid conditions are diagnosed, intervention and therapy are provided. If not, they are followed up as high-risk patients more frequently and thoroughly than the patients in Clusters 1 to 3.
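The decision flow described above can be written down as a small function. This is only a sketch of the workflow as stated in the text; the function name, the string labels, and in particular the handling of Cluster 4 or 5 patients whom the SVM does not flag as high risk (treated here as close follow-up) are assumptions, since the text does not specify that branch.

```python
def triage(cluster, svm_high_risk=False, disease_found=False):
    """Return the follow-up action for a patient, following the described workflow.

    cluster       -- K-means cluster number (1 to 5)
    svm_high_risk -- SVM output for patients in Clusters 4 and 5
    disease_found -- result of the detailed medical examination, if performed
    """
    if cluster not in (1, 2, 3, 4, 5):
        raise ValueError("cluster must be an integer from 1 to 5")
    if cluster <= 3:
        # Clusters 1 to 3: periodic follow-up, since cluster membership may change later.
        return "periodic follow-up"
    if svm_high_risk and disease_found:
        # Detailed examination found disease or aggravated comorbid conditions.
        return "intervention and therapy"
    # Assumption: all remaining Cluster 4 and 5 patients are followed up closely,
    # including those not flagged by the SVM (this branch is not stated in the text).
    return "close follow-up"
```

For instance, `triage(5, svm_high_risk=True, disease_found=True)` yields `"intervention and therapy"`, while `triage(2)` yields `"periodic follow-up"`.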

Discussion
There are various types of machine learning models whose mechanisms cannot be fully understood by humans; these are called black boxes. For this reason, explainable machine learning models have been studied. Among the types of machine learning, K-means is based on the least squares method, and is more understandable than other models. Moreover, SVM can be used to predict patients' prognosis. In this work, we developed an explainable ensemble model for the prediction of patients' prognosis, which was composed of K-means and SVM. Hemodialysis patients were categorized into five clusters by the K-means method on the basis of their baseline characteristics, and the clusters reflected the risk of death. Then, we developed machine learning and statistical models, and compared their performances. The ensemble model of the K-means method and SVM showed the highest accuracy of the prediction of death. Although some studies showed a high accuracy of the prediction of dialysis patients' death using machine learning models, the internal structures of the models were difficult to understand [11][12][13]. There is a tradeoff relationship between the accuracy of prediction and the transparency of algorithms [14]. We attempted to achieve a balance by developing a blended system, which we found useful for identifying patients at a high risk of death, and which was easily applicable to clinical settings. The International Society of Renal Nutrition and Metabolism proposed an algorithm for the nutritional management and support of chronic kidney disease patients [15]. In the algorithm, multiple nutritional examinations, such as measurement of dietary nutritional intakes, subjective global assessment, and anthropometrics, are recommended [15]. However, it is difficult for all of these nutritional examination results to be digitized and evaluated by machine learning models. Moreover, a systematic review of the studies of the data-driven population segmentation analysis pointed out that a perfect diagnosis is not always guaranteed; and the review suggested the importance of assessing the segmentation outcome with a combination of statistical reasoning, clinical judgement, and policy implication [16]. Therefore, we did not leave the entire diagnosis to be performed by a machine learning system, and instead developed the ensemble model as part of the medical system. The ensemble model and detailed medical examinations can complement each other, which enhances the robustness of this system.

Table 4 notes. Variables are expressed as mean±standard deviation. Vintage and CRP are also shown as median and interquartile range. Intergroup comparisons of parameters were performed using the chi-square test, t-test, and the Mann-Whitney U test as appropriate. Abbreviations: DM, diabetes mellitus as a cause of end-stage renal disease; CGN, chronic glomerulonephritis; CVD, cardiovascular disease; BMI, body mass index; CRP, C-reactive protein; nPCR, normalized protein catabolic rate; NRI, nutritional risk index.
According to the JSDT annual report in 2015, the mean age of Japanese dialysis patients was 67.86 years, 64.3% were male, and the causes of ESKD were DM (38.4%), CGN (29.8%), and nephrosclerosis (9.5%) [1]. Considering these basic statistics, our system could divide the patients into the five clusters reflecting their baseline characteristics (Table 6). These characteristics were risk factors for death [3,4,6]. For example, the risks of all-cause death and of CVD- and infection-caused deaths in Cluster 5 were higher than those in other clusters.

Cluster 5 also showed lower serum albumin and creatinine levels and lower nPCR, which are nutritional factors, than the other clusters, and included large proportions of patients at high and medium risk according to the NRI (26.6% and 36.8%, respectively). Moreover, a high serum CRP level, which indicates inflammation, was also observed in Cluster 5. Inflammation is often observed in ESKD patients with malnutrition, and this complex state of malnutrition and inflammation is called protein energy wasting (PEW) [5]. PEW causes CVD, which is a risk factor for death [5,8]. The classification of an elderly patient with PEW into Cluster 5 indicates that the treatment of PEW should be of the highest priority.
Our system could clearly distinguish patients with DM (Cluster 3) or nephrosclerosis (Cluster 4) from those with other conditions. The patients in Cluster 4 showed a higher risk of death than those in Cluster 3. What factor made this difference? Both DM and aging are the main causes of CVD in hemodialysis patients [17]. According to a systematic review, they are risk factors for all-cause and CVD-caused deaths [18]. In our study, no clear differences were observed in the other risk factors reported in the systematic review, such as history of CVD, BMI, hemoglobin level, serum albumin, and CRP levels, between Clusters 3 and 4 [18]. The only factors different between these clusters were the causes of ESKD and age; patients in Cluster 4 were about 8 years older than those in Cluster 3. DOPPS showed no statistically significant difference in mortality rate between patients with DM and hypertension [19]. It is possible that age itself might have caused the survival difference. DM has been the leading cause of ESKD in Japan, and the number of dialysis patients with DM has been stable over the past few years [1]. In contrast, nephrosclerosis is caused by aging and hypertension, and the number of patients with nephrosclerosis has been increasing with the aging of the population in Japan [1]. Elderly patients with nephrosclerosis should be paid more attention, because they are at a high risk of death, and will be a majority among dialysis patients in the near future.
Similar to our study, a cohort study of the health care system in Singapore showed a relationship among K-means clusters, healthcare utilization pattern, and mortality [20]. Why do the clusters obtained by the K-means method reflect the patients' prognosis in the Singapore study and our study? The cluster centers were obtained using Eq (5). $\mu_j$ is a vector equal to the mean of all data of patients in Cluster j. That is, patients in Cluster j are distributed in an m-dimensional sphere with the center at $\mu_j$. In this study, the number of clusters was determined by the links with the risk of death as an important true endpoint, which showed that $\mu_j$ was strongly associated with the risk of death. On the basis of this theoretical background, each cluster had specific characteristics of risk factors for death, such as gender, causes of ESKD, and PEW (Table 6). In risk prediction models using standard statistics, the variables are often arbitrarily selected, whereas in machine learning, patients' features are extracted from their numerical data even when a human does not provide sufficient information. There is a possibility that this feature extraction can clarify new pathophysiological characteristics of diseases. For example, the five clusters in this study, which had different numerical features, may have different courses of change in their body condition after dialysis initiation. Thus, machine learning may uncover new research questions.
The performance of machine learning is often evaluated by the accuracy of classification. When using the validation data, the ensemble model showed a higher accuracy of the prediction of death than the other models. Machine learning models such as SVM and deep learning models are black boxes [21]. Because our ensemble model was composed of the K-means method and SVM model, this combined system of classification and prediction made the results interpretable with high accuracy, and closely matched the clinical decision-making process. The practical applications of this kind of machine learning model have never been reported.
Because accuracy is determined by the number of incident events, it changes with the composition of the sample population. Thus, we evaluated the changes in the accuracies of the models with the changes in the risk of death. The risks of 1-year death in Japan and the USA are 9.6% and 13.4%, and those of 5-year death in Japan, Italy, and the USA are 39.5%, 44.4%, and 58%, respectively [1,2,22,23]. In the simulation using these populations, the machine learning models could show high accuracies, and effectively predicted the prognosis of ESKD patients. The classification performance of diagnostic tests is commonly evaluated in terms of sensitivity and specificity. The machine learning models in this study showed high specificity in predicting 1-year and 5-year deaths. High specificity means that the models yield a small number of false-positive patients. That is, when a patient is diagnosed to be positive for a risk by the models, the possibility of the presence of a disease is high. Therefore, the diagnosis obtained using models with high specificities is useful to confirm the diagnosis. On the other hand, because the sensitivities of the SVM and deep learning models were low, they were not appropriate for screening high-risk patients. The sensitivities of the K-means method using clusters were higher than those of the other models. The clusters might be useful for identifying the high-risk patients.
Our system is applicable to clinical settings within the context of the following limitations. First, in this study, JRDR data were used. These data were obtained from 98.8% of dialysis patients in Japan, and reflect the real-world situation of dialysis patients in Japan. Because our system was developed using these data, its accuracy for Japanese or Asian patients is high, but the results using data from other countries might be biased by the sampling of patients. Second, we did not include patients with missing data in this study, which might cause a selection bias. Third, the JRDR data did not include sufficient data for assessing malnutrition, blood pressure, comorbid conditions, and medications. Therefore, we were unable to evaluate the effects on the clustering of differences in the baseline characteristics such as dietary intake; comorbid conditions such as DM and hypertension; and medications such as hypoglycemic and antihypertensive medicines. Such data would likely improve the accuracy of the models, and further studies are needed to evaluate the relationship between these factors and clustering.

Conclusions
We developed a novel system using machine learning algorithms that analyzes hemodialysis patients' data, categorizes the patients on the basis of their characteristics, and identifies patients at a high risk of death. The new approach has a strong potential to guide treatments and improve hemodialysis patients' prognosis.