
An ensemble-based 3D residual network for the classification of Alzheimer’s disease

  • Xiaoli Yang ,

    Roles Conceptualization, Funding acquisition, Methodology, Resources, Supervision, Writing – review & editing

    yxl@haust.edu.cn

    Affiliation School of Medical Technology and Engineering, Henan University of Science and Technology, Luoyang, China

  • Jiayi Zhou,

    Roles Conceptualization, Formal analysis, Methodology, Software, Writing – original draft, Writing – review & editing

    Affiliation School of Medical Technology and Engineering, Henan University of Science and Technology, Luoyang, China

  • Chenchen Wang,

    Roles Data curation, Formal analysis

    Affiliation School of Medical Technology and Engineering, Henan University of Science and Technology, Luoyang, China

  • Xiao Li,

    Roles Data curation, Formal analysis

    Affiliation School of Medical Technology and Engineering, Henan University of Science and Technology, Luoyang, China

  • Jiawen Wang,

    Roles Investigation, Resources

    Affiliation School of Medical Technology and Engineering, Henan University of Science and Technology, Luoyang, China

  • Angchao Duan,

    Roles Data curation, Investigation

    Affiliation School of Medical Technology and Engineering, Henan University of Science and Technology, Luoyang, China

  • Nuan Du

    Roles Investigation

    Affiliation School of Basic Medicine and Forensic Medicine, Henan University of Science and Technology, Luoyang, China

Abstract

Alzheimer’s disease (AD) is a common type of dementia, with mild cognitive impairment (MCI) being a key precursor. Early MCI diagnosis is crucial for slowing AD progression, but distinguishing MCI from normal controls (NC) is challenging due to subtle imaging differences. Furthermore, differentiating early MCI (EMCI) from late MCI (LMCI) is also important for interventions. This study proposes a deep learning-based approach using a weighted probability-based ensemble method to integrate results from three-dimensional residual networks (3D ResNet). (1) This study employs 3D ResNet-18, 3D ResNet-34, and 3D ResNet-50 architectures with the Convolutional Block Attention Module (CBAM). The attention mechanism enhances performance by helping the model focus on pertinent information. Data augmentation techniques are applied to address limited data and improve accuracy. (2) To overcome the limitations of individual convolutional neural networks (CNN), an ensemble learning method is adopted. The method assigns weights to each 3D CNN model based on prediction accuracy and integrates them to obtain the final result. Our method achieves accuracies of 94.87%, 92.31%, 95.49%, and 95.97% for MCI vs. NC, MCI vs. AD, EMCI vs. LMCI, and NC vs. EMCI vs. LMCI vs. AD, respectively. The results demonstrate the effectiveness of our method for AD diagnosis.

Introduction

Alzheimer’s disease (AD) is an irreversible, progressive neurodegenerative disease and one of the most common causes of dementia in the elderly. As the disease progresses, patients experience a decline in cognitive abilities and daily functioning, which has a profound impact on their quality of life and that of their family members. Early diagnosis is crucial as it can help slow down the progression of the disease and improve the quality of life [1]. Mild cognitive impairment (MCI) represents a condition that exists between AD and the normal age-related cognitive decline in normal controls (NC). MCI is further classified into early MCI (EMCI) and late MCI (LMCI), with EMCI representing a milder form of cognitive impairment compared to LMCI. During the MCI stage, patients experience a slower decline in cognitive functioning, and in the majority of cases, the ability to perform activities of daily living remains largely unimpaired. However, MCI is considered a high-risk state for progression to AD, with approximately 10% to 15% of patients with MCI progressing to AD each year [2]. Therefore, it is essential to closely monitor individuals who are already in the MCI stage for changes in their cognitive functional status. Early intervention and treatment for them may help to delay or prevent the progression of MCI to AD [3].

In recent years, studies related to AD have demonstrated that neuroimaging techniques, such as magnetic resonance imaging (MRI) and positron emission tomography (PET), are more effective than traditional clinical assessments and psychological tests in diagnosing AD [4]. In particular, structural magnetic resonance imaging (sMRI), a widely used neuroimaging analysis method [5], plays an important role in the diagnosis of AD. As technology advances, the integration of neuroimaging techniques with computer-aided diagnosis has gained prominence in the classification and prediction of AD. For example, machine learning algorithms can analyze neuroimaging data to extract and interpret valuable information, thereby enhancing the accuracy of diagnoses and assessments [6].

Traditional machine learning algorithms usually require manual design and selection of features, which can be both time-consuming and challenging when dealing with complex neuroimaging data. In contrast, deep learning methods can automatically learn and extract high-level features from the data, capturing its intricate relationships more effectively. Consequently, deep learning techniques are becoming a more mainstream and effective option for medical image analysis. Convolutional neural networks (CNN) in deep learning have significant advantages in image processing tasks and are becoming more widely adopted [7]. Deep learning has demonstrated exceptional performance in medical image analysis [8]. It also shows considerable potential in the classification and prediction of AD. In 2013, Suk et al. [9] explored the potential of deep learning in AD classification, proposing a deep learning-based feature representation using a stacked auto-encoder. This method revealed nonlinear relationships among features and improved classification accuracy. With the development of deep learning algorithms, more complex CNN architectures have been introduced into the research on AD, such as GoogLeNet [10], VGGNet [11], and ResNet [12]. By introducing the residual learning mechanism, ResNet effectively solves the problems of gradient vanishing and gradient explosion of traditional CNNs during the training process. This advancement enables the network to learn complex feature representations more deeply and efficiently, enhancing the performance and training efficiency of deep learning models.

In the field of deep learning, two-dimensional (2D) images are widely used, especially in various applications related to computer vision and image processing. Some researchers have used 2D slices extracted from 3D MRI images to classify AD. For example, Xu et al. [13] introduced a selective kernel network and channel shuffle in ResNet and proposed an enhanced ResNet to classify AD. Prakash et al. [14] used whole slide 2D images to perform the classification tasks. They employed transfer learning with three models, ResNet-101, ResNet-50, and ResNet-18, and evaluated their performance in detecting AD. However, 2D images have certain limitations, as they are unable to directly convey 3D information or rotational transformations of objects. This may restrict the capabilities of deep learning models in certain scenarios. In the classification of AD, neuroimaging data is crucial, yet 2D images fail to fully utilize the volumetric information of brain imaging. To more accurately analyze and recognize brain information, researchers have introduced 3D CNNs. For instance, Frimpong et al. [15] proposed an AD classification model based on 3D CNN multilayer perceptron (MLP). The model uses an attention mechanism to automatically extract relevant features in the images and generate probability maps, which are then input to the MLP classifier. Zhang et al. [16] proposed a computer-aided method for early classification prediction of AD by introducing an explainable 3D residual attention deep neural network for end-to-end learning from sMRI scans. Wen et al. [17] compared numerous CNN-based research methods for AD classification and concluded that the performance of the different 3D methods was similar, while the performance of the 2D slicing methods was inferior. Unlike 2D CNNs, 3D CNNs process volumetric data, effectively capturing the spatial structure and relationships of objects from different angles. 
3D CNNs not only enhance the model’s ability to interpret complex data but also improve the accuracy and reliability of AD classification [18].

Input images for 3D CNNs can be categorized into three types: 3D whole-brain images, 3D image patches, and 3D regions of interest (ROI). Methods based on image patches may be inadequate for capturing the global features and structural context of the entire brain, potentially resulting in the overlooking of critical information dispersed across various regions. The performance of the ROI-based methods is relatively improved, but it is significantly influenced by the extraction and segmentation techniques. Whole-brain images require more computational resources and sophisticated data processing techniques, but they offer the most comprehensive and integrated information about brain structure and function [19].

The availability of medical image data is limited, which can pose certain challenges to classification tasks. In addition to advancements in network architectures, ensemble learning has been introduced to address these challenges. Tanveer et al. [20] proposed a model that combines ensemble learning with deep learning and employs transfer learning for AD classification, ultimately achieving better classification results. Zhang et al. [21] proposed a method combining 3D CNN with ensemble learning to improve the accuracy of AD classification. Furthermore, a data denoising module was proposed to reduce the boundary noise. Experimental results show that the model effectively improves the training speed of the neural network. Grover et al. [22] proposed an ensemble-based transfer learning method. It uses simple averaging ensemble and weighted averaging ensemble methods to extract superior sparse patterns and features from MRI images. An et al. [23] proposed a three-layer deep ensemble learning framework, including a voting layer, a stacking layer, and an optimizing layer. The proposed architecture demonstrated superior performance in AD classification compared to six other representative ensemble learning methods referenced in their paper.

To address the complexities and challenges in the classification of MCI, this study proposes an ensemble learning framework that integrates three distinct 3D ResNet models. Each base model is augmented with the Convolutional Block Attention Module (CBAM) to prioritize discriminative brain regions in 3D sMRI data, while data augmentation strategies are applied to mitigate the impact of limited data. The proposed method integrates predictions from these models using a weighted probability-based fusion strategy, which dynamically assigns weights based on the performance of each individual model. This study implements three binary classification tasks: AD vs. MCI, NC vs. MCI, EMCI vs. LMCI, and a multiclass classification task: AD vs. EMCI vs. LMCI vs. NC. The primary contributions of this study are as follows:

  1. Improved model design through 3D ResNets with CBAM, which enhances feature extraction by incorporating attention mechanisms alongside deep learning, enabling precise identification of brain changes in 3D MRI scans.
  2. The introduction of a weighted probability-based ensemble method to combine the classification results from three distinct 3D ResNet models. This method shows superior performance compared to using a single model, overcoming the limitations of individual CNNs. The proposed ensemble approach improves the accuracy of AD diagnosis and enhances the model's ability to distinguish between subtle cognitive stages such as early and late MCI.

The remainder of this paper is structured as follows:

The Materials and methods section explains the research data and methods in this study. The Results section presents the results of this study. The Discussion section discusses the experimental findings. Finally, the conclusion of this paper is provided.

Materials and methods

This section provides a detailed description of the data as well as the methods used in this study. The overall flow of the study is given in Fig 1. Initially, the 3D sMRI data undergo preprocessing. Subsequently, the data-augmented images are input into three classification networks for training. Finally, the classification results from the three networks are aggregated to produce final results.

Fig 1. Workflow.

Pi denotes the predicted probabilities of the i-th model, Wi denotes the weight of the i-th model, and n denotes the number of classes.

https://doi.org/10.1371/journal.pone.0324520.g001

Materials

Data acquisition.

The data used in this study are sourced from the Alzheimer’s Disease Neuroimaging Initiative (ADNI) database (http://adni.loni.usc.edu/). The ADNI database is a widely utilized resource on a global scale, containing various forms of data, including sMRI, functional MRI, PET, cerebrospinal fluid analysis, and many other forms of imaging and biomarker data. These data are publicly accessible to the global scientific community to promote research on AD and related neurodegenerative disorders. The data are stored and shared through a public database.

In the initial phase of constructing the dataset, the focus was primarily on the overall characteristics of all patients with cognitive impairment within the MCI category, without specifically considering the sample ratio between EMCI and LMCI. This approach was adopted to provide a comprehensive understanding of MCI as a whole category of cognitive impairment. To further investigate the subtle differences between EMCI and LMCI and to improve the model's discriminative ability, the decision was made to resample the data within these two subtypes. This adjustment aims to create a more balanced and representative dataset for subsequent analysis, model training, and prediction tasks. In this study, T1-weighted sMRI image data were used, including 350 AD data, 629 MCI data, 350 NC data, 350 EMCI data, and 318 LMCI data. A detailed account of the selected data is provided in Table 1.

Table 1. Demographic characteristics and scale scores of the study participants.

https://doi.org/10.1371/journal.pone.0324520.t001

Data preprocessing.

The images downloaded from ADNI have already undergone gradient distortion correction and gradient non-uniformity correction [24]. Building upon this, further preprocessing steps were performed. The first step involved performing anterior commissure-posterior commissure correction on the downloaded DICOM format images. Then, skull stripping was applied to remove non-brain tissue. Finally, the skull-stripped sMRI images were registered to the standard brain atlas of the Montreal Neurological Institute space [25]. The size of the final processed images is 121 × 145 × 121. The dataset was randomly divided into a training set (70%) and a test set (30%).

Data augmentation.

To mitigate the risk of overfitting due to the relatively limited dataset, data augmentation techniques were performed to improve classification accuracy [26]. A series of augmentation operations were applied to the preprocessed 3D images using TorchIO in Python. The augmentation operations on the training set included: intensity rescaling, which adjusts the range of pixel values; random affine transformation, which performs operations such as translating, rotating, and scaling; random flipping, which flips the image horizontally, vertically, or along a specific dimension; adding random noise or blurring; random masking, which generates masked regions in the image; and Z-normalization, which normalizes the image data to a distribution with a mean of 0 and a standard deviation of 1. This latter operation helps to speed up training and improves model convergence. The test set underwent only Z-score normalization.
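As an illustration, a minimal NumPy sketch of a subset of these operations (random flipping, additive Gaussian noise, and Z-normalization) is given below; the study itself uses TorchIO, and the noise level here is an arbitrary choice for demonstration only:

```python
import numpy as np

def z_normalize(vol):
    # Z-normalization: shift to zero mean and scale to unit standard deviation
    return (vol - vol.mean()) / (vol.std() + 1e-8)

def augment(vol, rng):
    # Illustrative train-time augmentation: random flip plus additive Gaussian noise
    if rng.random() < 0.5:
        vol = np.flip(vol, axis=rng.integers(0, 3))   # flip along a random spatial axis
    vol = vol + rng.normal(0.0, 0.05, size=vol.shape)  # noise level is an assumption
    return z_normalize(vol)

rng = np.random.default_rng(0)
volume = rng.random((121, 145, 121))  # matches the preprocessed image size
aug = augment(volume, rng)
print(aug.shape)
```

As in the paper's setup, the test set would pass through `z_normalize` only, while `augment` is applied to training volumes.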

Methods

CNN architecture.

ResNet is a convolutional neural network proposed by He et al. [12]. The main feature of the ResNet architecture is the introduction of a residual learning mechanism, which effectively solves the problems in deep network training by inter-layer connections and downsampling strategies. The ResNet structure is designed with blocks of different depths, including the basic block and the bottleneck block. ResNet-18 and ResNet-34 are composed of different numbers of basic blocks, while ResNet-50 is constituted of bottleneck blocks. The ResNet architecture uses an average pooling layer and a fully connected layer at the end of the network to transform the convolutional feature maps into the final classification results. In this study, 3D versions of ResNet-18, ResNet-34, and ResNet-50 are used, and they are combined with CBAM for classification. Fig 2 shows the structure of the two blocks integrated with CBAM at the end of the ResNet.
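To make the block structure concrete, the following is a minimal PyTorch sketch of a 3D basic block with a skip connection; the class name, channel sizes, and test input are illustrative assumptions, not the authors' code:

```python
import torch
import torch.nn as nn

class BasicBlock3D(nn.Module):
    """3D ResNet basic block: two 3x3x3 convolutions plus an identity shortcut."""
    def __init__(self, in_ch, out_ch, stride=1):
        super().__init__()
        self.conv1 = nn.Conv3d(in_ch, out_ch, 3, stride=stride, padding=1, bias=False)
        self.bn1 = nn.BatchNorm3d(out_ch)
        self.conv2 = nn.Conv3d(out_ch, out_ch, 3, padding=1, bias=False)
        self.bn2 = nn.BatchNorm3d(out_ch)
        # 1x1x1 projection when the skip path and main path shapes differ
        self.down = None
        if stride != 1 or in_ch != out_ch:
            self.down = nn.Sequential(
                nn.Conv3d(in_ch, out_ch, 1, stride=stride, bias=False),
                nn.BatchNorm3d(out_ch))
        self.relu = nn.ReLU(inplace=True)

    def forward(self, x):
        identity = x if self.down is None else self.down(x)
        out = self.relu(self.bn1(self.conv1(x)))
        out = self.bn2(self.conv2(out))
        return self.relu(out + identity)  # residual addition before activation

block = BasicBlock3D(8, 16, stride=2)
y = block(torch.randn(1, 8, 16, 16, 16))
print(tuple(y.shape))  # (1, 16, 8, 8, 8)
```

The bottleneck block used by ResNet-50 follows the same pattern with a 1x1x1 / 3x3x3 / 1x1x1 convolution stack instead of two 3x3x3 convolutions.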

Fig 2. The structure of the bottleneck block and basic block connected to CBAM.

(a) illustrates the bottleneck block, while (b) depicts the basic block. “x×x×x Conv3d, F” denotes a 3D convolution, where x×x×x is the kernel size and F is the number of feature maps; BN refers to batch normalization.

https://doi.org/10.1371/journal.pone.0324520.g002

Attention mechanisms have an important role in human perception. Nowadays, numerous researchers have sought to enhance the performance of CNNs by integrating attention mechanisms into their architectures [18,27]. The attention mechanism enables the network to focus on important features while suppressing irrelevant ones, thus enhancing the model's focus and learning ability related to the input data.

The CBAM is a lightweight and generalized module proposed by Woo et al. [28] that can be integrated into any CNN architecture. The CBAM combines both the channel attention mechanism and the spatial attention mechanism, which can improve the performance of the model while maintaining a relatively low computational overhead. In this study, the CBAM has been modified to be compatible with 3D CNN and incorporated at the end of the ResNet. The specific network structure is given in Table 2.
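A minimal 3D adaptation of CBAM in PyTorch might look as follows. It follows the published CBAM design (channel attention from average- and max-pooled descriptors through a shared MLP, then spatial attention from a convolution over channel-wise statistics), but the reduction ratio and kernel size here are illustrative assumptions rather than the authors' settings:

```python
import torch
import torch.nn as nn

class CBAM3D(nn.Module):
    """3D CBAM sketch: channel attention followed by spatial attention."""
    def __init__(self, channels, reduction=8):
        super().__init__()
        # Channel attention: shared MLP applied to global avg- and max-pooled vectors
        self.mlp = nn.Sequential(
            nn.Linear(channels, channels // reduction),
            nn.ReLU(inplace=True),
            nn.Linear(channels // reduction, channels))
        # Spatial attention: convolution over stacked channel-wise avg and max maps
        self.spatial = nn.Conv3d(2, 1, kernel_size=7, padding=3)

    def forward(self, x):
        b, c = x.shape[:2]
        avg = self.mlp(x.mean(dim=(2, 3, 4)))          # avg-pooled descriptor
        mx = self.mlp(x.amax(dim=(2, 3, 4)))           # max-pooled descriptor
        ca = torch.sigmoid(avg + mx).view(b, c, 1, 1, 1)
        x = x * ca                                     # reweight channels
        sa = torch.sigmoid(self.spatial(torch.cat(
            [x.mean(dim=1, keepdim=True), x.amax(dim=1, keepdim=True)], dim=1)))
        return x * sa                                  # reweight spatial locations

cbam = CBAM3D(16)
out = cbam(torch.randn(2, 16, 8, 8, 8))
print(tuple(out.shape))  # (2, 16, 8, 8, 8)
```

Because the module preserves the input tensor's shape, it can be appended after any 3D convolutional stage, as done here at the end of each ResNet.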

Table 2. The structure of the 3D ResNet models.

https://doi.org/10.1371/journal.pone.0324520.t002

Ensemble learning.

The core idea of the ensemble method is to combine the classification results of multiple models so as to achieve better classification performance than any individual model. This study uses a weighted probability-based fusion method to obtain the final results [29,30]. Specifically, weights of 0.4, 0.3, and 0.3 are assigned to the classification results of the three models. The model with the best performance is allocated the greatest weight, reflecting its superior predictive capability, while the other two models are assigned relatively lower weights. This weighting strategy balances the contribution of the best-performing model, ensuring that its strengths are effectively utilized without disproportionately influencing the final output. At the same time, it mitigates the risk of excessive reliance on any single model, thereby preserving the integrity of the ensemble. By taking the predictive power of all participating models into account, this method enhances the robustness and reliability of the ensemble results.

The weighting strategy (0.4, 0.3, 0.3) employed in this study was determined through systematic pre-experimentation and parameter optimization. Multiple weight combinations were evaluated, such as (0.5, 0.3, 0.2; accuracy: 0.9560), (0.6, 0.2, 0.2; accuracy: 0.9524), and (0.7, 0.2, 0.1; accuracy: 0.9451). The (0.4, 0.3, 0.3) combination achieved an accuracy of 0.9597, demonstrating superior classification performance. Notably, assigning an excessively dominant weight to the best-performing model did not improve overall performance, underscoring the necessity of a balanced weighting scheme.

Let W_i denote the weight of the i-th model. The predicted probabilities of the i-th model are represented as

P_i = (p_{i,1}, p_{i,2}, \ldots, p_{i,n}) (1)

where n denotes the number of classes (n is 2 or 4). P_i is then normalized as

\hat{P}_i = P_i \Big/ \sum_{j=1}^{n} p_{i,j} (2)

The final class label is determined as

\hat{y} = \arg\max_{j} \sum_{i} W_i \, \hat{p}_{i,j} (3)
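Under this notation, the fusion rule of Eqs. (1)-(3) can be sketched in a few lines of NumPy; the per-model probability values below are hypothetical:

```python
import numpy as np

def weighted_ensemble(probs, weights):
    """Weighted probability fusion: normalize each model's class probabilities (Eq. 2),
    weight and sum them across models, and take the argmax as the final label (Eq. 3)."""
    probs = np.asarray(probs, dtype=float)                    # (models, samples, classes)
    probs = probs / probs.sum(axis=-1, keepdims=True)         # normalization
    fused = np.tensordot(np.asarray(weights), probs, axes=1)  # weighted sum over models
    return fused.argmax(axis=-1)

# Hypothetical softmax outputs for 2 samples over 4 classes (NC/EMCI/LMCI/AD)
p1 = [[0.70, 0.10, 0.10, 0.10], [0.20, 0.30, 0.30, 0.20]]
p2 = [[0.60, 0.20, 0.10, 0.10], [0.10, 0.20, 0.50, 0.20]]
p3 = [[0.50, 0.20, 0.20, 0.10], [0.25, 0.25, 0.25, 0.25]]
labels = weighted_ensemble([p1, p2, p3], weights=[0.4, 0.3, 0.3])
print(labels)  # [0 2]
```

Note that for the second sample no single model is confident, yet the weighted sum still yields a clear winning class, which is the intended effect of the fusion.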

Data partition strategy.

In this study, all preprocessing steps are conducted on the raw data before dataset division. The dataset is randomly split into a training set (70%) and a test set (30%), ensuring each sample is exclusively assigned to one subset to prevent overlap. Data augmentation is applied only to the training set to artificially expand its size and enhance model generalization, while the test set undergoes only necessary normalization procedures and remains untouched during training. The test set is strictly reserved for final evaluation and is never used for model optimization or hyperparameter tuning. This design ensures the test set serves as an independent benchmark, enabling unbiased evaluation of model performance. By maintaining clear separation between training and testing phases, the methodology ensures the reliability and validity of the experimental results.
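The 70/30 split with disjoint subsets can be sketched in plain Python as follows; the fixed seed is an arbitrary choice for reproducibility and is not taken from the paper:

```python
import random

def split_dataset(samples, train_frac=0.7, seed=42):
    """Randomly split samples into disjoint train/test subsets before any augmentation."""
    idx = list(range(len(samples)))
    random.Random(seed).shuffle(idx)          # deterministic shuffle
    cut = round(len(samples) * train_frac)    # 70% boundary
    train = [samples[i] for i in idx[:cut]]
    test = [samples[i] for i in idx[cut:]]
    return train, test

train, test = split_dataset(list(range(700)))
print(len(train), len(test))  # 490 210
assert not set(train) & set(test)  # each sample belongs to exactly one subset
```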

Evaluation metrics.

In the classification tasks, the prediction results are categorized into the following four cases: TP denotes positive samples that have been correctly predicted as positive, TN indicates negative samples that have been correctly predicted as negative, FP refers to negative samples that have been incorrectly predicted as positive, and FN signifies positive samples that have been incorrectly predicted as negative. The evaluation metrics used in this study include accuracy, recall, precision, F1-score, and area under the curve (AUC). In the case of multiclass classification, only the first four evaluation metrics are used.

Accuracy is the ratio of the number of correctly predicted samples to the total number of samples.

Accuracy = \frac{TP + TN}{TP + TN + FP + FN} (4)

Recall is the ratio of the number of correctly predicted positive samples to the total number of actual positive samples.

Recall = \frac{TP}{TP + FN} (5)

Precision is the ratio of the number of correctly predicted positive samples to the total number of samples that are predicted as positive.

Precision = \frac{TP}{TP + FP} (6)

F1-score is the harmonic mean of precision and recall, which provides a balance between the two metrics by considering both precision and recall.

F1 = \frac{2 \times Precision \times Recall}{Precision + Recall} (7)

The AUC is calculated by summing the area under the receiver operating characteristic (ROC) curve. Although the ROC curve may not always lead to a straightforward comparison of different models, the AUC serves as an effective measure of the model's overall performance, with AUC values approaching 1 indicating superior performance.
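For reference, the four count-based metrics of Eqs. (4)-(7) can be computed directly from the confusion counts; the counts below are hypothetical:

```python
def classification_metrics(tp, tn, fp, fn):
    """Compute accuracy, recall, precision, and F1-score from confusion counts."""
    accuracy = (tp + tn) / (tp + tn + fp + fn)
    recall = tp / (tp + fn)
    precision = tp / (tp + fp)
    f1 = 2 * precision * recall / (precision + recall)  # harmonic mean
    return accuracy, recall, precision, f1

# Hypothetical counts for a binary task
acc, rec, prec, f1 = classification_metrics(tp=90, tn=85, fp=15, fn=10)
print(round(acc, 3), round(rec, 3), round(prec, 3), round(f1, 3))
# 0.875 0.9 0.857 0.878
```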

Results

In this study, classification tasks focus on the detection of MCI. Initially, binary classification is performed between MCI and NC, as well as between MCI and AD. The state of MCI can be further subdivided into EMCI and LMCI. The distinction between these two states is important for individualized treatment, scientific research, and the optimal allocation of social health resources. Therefore, EMCI and LMCI are classified, and classification is performed across the four states: NC, EMCI, LMCI, and AD.

The experiments were conducted on a computing platform running Python 3.10.12 and PyTorch 2.0.1, with an NVIDIA GeForce RTX 4090 graphics processing unit (GPU) and 64 gigabytes (GB) of memory. The networks are trained with a batch size of 4 and a learning rate of 0.0001 for 100 epochs.

Ablation experiments

This section presents results comparing the performance of models with and without CBAM, as well as with and without data augmentation. Taking the classification tasks of MCI and NC as an example, this study trained three 3D ResNet models. First, without data augmentation, the impact of CBAM on the performance of the 3D ResNet models is examined. Table 3 presents the performance differences between models with and without CBAM. The results indicate that incorporating CBAM improves classification performance and accuracy across the three ResNet models. Subsequently, the model with CBAM is employed to investigate the impact of data augmentation on model performance. The term “data augmentation” is represented as “DA” in Table 3. The results in Table 3 indicate that the classification performance of the three ResNet models has been further improved.

In this study, visualization techniques were integrated to illustrate the role of the CBAM mechanism within the neural network. Fig 3 presents MRI cross-sectional slices from four different categories, with each sample processed through this mechanism. In Fig 3, regions highlighted in shades closer to red indicate areas that received higher attention from the model during the classification decision process. The more intense the red hue, the greater the attention the model paid to these regions during classification. The CBAM mechanism enhances the model's focus on these key regions, enabling it to more effectively extract and emphasize pathological features in AD classification. This attention mechanism aids the model in identifying critical abnormalities or features, thereby significantly improving classification accuracy. The visualization provides a more intuitive explanation of the role of the CBAM attention mechanism within the model, offering valuable insights for model optimization and interpretability.

Fig 3. Visualization of CBAM.

(a) represents the AD category, (b) represents the NC category, (c) represents the EMCI category, and (d) represents the LMCI category.

https://doi.org/10.1371/journal.pone.0324520.g003

Based on the above results, this study ultimately uses models with CBAM to classify images after performing data augmentation. Taking the classification tasks of MCI and NC as an example, the differences in ensemble performance are discussed. There are four combination methods for the three ResNet models. This study compares pairwise ensemble combinations of the three models with the ensemble of all three models. In Table 4, “Ensemble 1” denotes the ensemble of 3D ResNet-50 and 3D ResNet-18, “Ensemble 2” denotes the ensemble of 3D ResNet-50 and 3D ResNet-34, “Ensemble 3” denotes the ensemble of 3D ResNet-34 and 3D ResNet-18, and “Ensemble” denotes the ensemble of all three networks. Compared to individual models, “Ensemble 1”, “Ensemble 3”, and “Ensemble” showed improved accuracy, while “Ensemble 2” did not exhibit an accuracy increase. However, the AUC values of all the ensemble methods were enhanced, indicating an overall improvement in performance. The results indicate that the best performance is achieved when using the ensemble of all three models.

Table 4. Ensemble results of different models.

https://doi.org/10.1371/journal.pone.0324520.t004

All classification results

Based on the aforementioned results, this study employed the proposed method for binary classification between MCI and AD, as well as between EMCI and LMCI. Additionally, a four-class classification involving NC, EMCI, LMCI, and AD was performed. The results are presented in Tables 5 and 6.

Furthermore, Figs 4 and 5 present the ROC curve for binary classification and the confusion matrix for four-class classification, respectively. The data presented in the tables demonstrate that, across all classification tasks, ensemble learning achieves a notable improvement compared to individual models. The ROC curve for the binary classification task clearly illustrates the superior performance of ensemble learning relative to a single model.

Fig 4. The ROC curve of binary classification.

(a) shows the ROC curve of MCI vs. NC, (b) shows the ROC curve of MCI vs. AD, and (c) shows the ROC curve of EMCI vs. LMCI.

https://doi.org/10.1371/journal.pone.0324520.g004

Fig 5. Confusion matrix of multiclass classification results.

(a) is the confusion matrix of the final ensemble results, (b) is the confusion matrix of the 3D ResNet-18 classification results, (c) is the confusion matrix of the 3D ResNet-34 classification results, and (d) is the confusion matrix of the 3D ResNet-50 classification results.

https://doi.org/10.1371/journal.pone.0324520.g005

Discussion

This section compares the results of this study with those of previous studies, then points out the limitations of the proposed method and discusses potential approaches to address these limitations in future research.

The results obtained in this study indicate that the proposed method is a viable approach for AD classification, yielding superior results. To better illustrate the effectiveness of the method, a comparison with prior work was conducted. The results are presented in Tables 7 and 8. Data from Tables 7 and 8 demonstrate that the method outperforms others in both binary and four-class classification tasks.

Table 7. Comparison of binary classification results.

https://doi.org/10.1371/journal.pone.0324520.t007

Table 8. Comparison of multiclass classification results.

https://doi.org/10.1371/journal.pone.0324520.t008

Several studies have used 3D CNN architectures for classification tasks [17,25,31,34,39,41]. The observed performance differences among these architectures are influenced by the datasets employed and the specific training methodologies adopted. Notably, some studies have indicated that 3D CNN architectures outperform 2D CNN architectures in terms of classification performance [17,31]. Liu et al. [32] proposed a multi-model deep learning framework. Their results show that multi-model approaches yield better performance compared to single models. Furthermore, several studies implemented ensemble learning techniques, which demonstrate superior performance [36,37,39]. Variations in datasets, model architectures, and the specific ensemble strategies employed may contribute to the discrepancies in results. Although the outcomes of other studies may fall short of our method, they provide valuable considerations for future research avenues.

In this study, 3D images are used as inputs to the network. The approach was based on 3D ResNet-18, 3D ResNet-34, and 3D ResNet-50, which are combined with CBAM for AD classification. The classification results of the three networks are integrated through a weighted probability-based ensemble method to get the final results. It is well known that 3D whole-brain images contain rich information, but the rich image information contains not only critical information for classification but also some unimportant information. Accordingly, this study combines the network with the attention mechanism to focus the CNN attention on the key information related to the classification tasks. The results of our study show that the attention mechanism helps improve the model's performance. Meanwhile, the limited amount of data can affect classification performance. Therefore, this study employs data augmentation to address this issue, and the previous results have demonstrated the effectiveness of this approach.

A single CNN may be limited by factors such as insufficient data and the choice of model structure, potentially resulting in suboptimal classification performance. To overcome these limitations, ensemble learning becomes an effective strategy. It can combine the results of multiple CNNs, thus reducing the bias of a single model, significantly improving the performance of classification tasks, and enhancing the stability of the model for superior performance in practical applications. Unlike simple majority voting methods, the ensemble method this study adopts relies on the probability distributions of each model's output. Specifically, different weights are assigned to each model based on its classification accuracy, and then the weighted probabilities are summed to obtain the final integrated probability for prediction. This method effectively improves the accuracy and reliability of classification by considering the prediction confidence of each model. The experimental results clearly show that this weighted probability-based ensemble method performs better than single CNN models.

Although the proposed method achieved good results in AD classification, some limitations remain that, if addressed, could further improve the network's performance and generalization ability. First, this study trained on a single modality, sMRI. Future work could incorporate images from other modalities, such as PET, and apply multimodal data to AD classification. Second, the proposed ensemble combines only ResNet variants; future research could investigate integrating other network architectures and explore different strategies for weighting the models in the ensemble.

Conclusion

The accurate diagnosis of MCI in the clinic is particularly important, as it influences not only patients' treatment plans but also the prevention and management of future cognitive decline. To this end, this study proposes a method for classifying different cognitive states from 3D sMRI data: the binary tasks MCI vs. NC, MCI vs. AD, and EMCI vs. LMCI, as well as the four-way task NC vs. EMCI vs. LMCI vs. AD. Methodologically, 3D whole-brain images are fed into three 3D ResNet models, each incorporating the CBAM attention module, which plays a crucial role in improving the models' ability to focus on relevant features. To further improve generalization and robustness, data augmentation techniques are applied. Finally, the classification results of the three models are integrated using a weighted probability ensemble method to achieve more accurate final classifications. The experimental results demonstrate that the proposed method predicts the different cognitive states with high accuracy, validating the model's performance and providing a foundation for future clinical applications. This study offers potential support for interventions and treatments targeting AD and its precursor stages.

Supporting information

Acknowledgments

The authors would like to thank the ADNI (http://adni.loni.usc.edu/) investigators for publicly sharing their valuable neuroimaging data.
