
Efficient federated learning via aggregation of base models

  • Pan Wang ,

    Contributed equally to this work with: Pan Wang, Zhengyi Zhong

    Roles Conceptualization, Formal analysis, Methodology, Project administration, Validation, Writing – original draft

    Affiliation Laboratory for Big Data and Decision, National University of Defense Technology, Changsha, Hunan, China

  • Zhengyi Zhong ,

    Contributed equally to this work with: Pan Wang, Zhengyi Zhong

    Roles Conceptualization, Formal analysis, Methodology, Project administration, Visualization, Writing – original draft

    Affiliation Laboratory for Big Data and Decision, National University of Defense Technology, Changsha, Hunan, China

  • Ji Wang

    Roles Conceptualization, Resources, Supervision

    wangji@nudt.edu.cn

    Affiliation Laboratory for Big Data and Decision, National University of Defense Technology, Changsha, Hunan, China

Abstract

Federated Learning (FL), a distributed computing framework for training machine learning (ML) models, has garnered significant attention for its strong privacy protection. In typical FL, a subset of client models is randomly selected for aggregation in each iteration, which performs well when the data is independent and identically distributed (IID). In real-world scenarios, however, data is often non-independent and identically distributed (Non-IID), and random selection cannot capture knowledge from the different data distributions, resulting in a global model with lower accuracy and slower convergence. To address this challenge, we propose base models: client models whose underlying data distributions are maximally diverse. By combining the parameters of these base models, we can approximate all client models, and we demonstrate the existence of base models theoretically. We then employ an evolutionary algorithm (EA) to identify base models on distributed clients by encoding client IDs and optimizing client selection through crossover, mutation, and other evolutionary operations, addressing the inefficiency of random selection. We conduct experiments on the FashionMNIST, MNIST, and TodayNews datasets, applying the proposed method to FL frameworks such as FedAvg, FedProx, and SCAFFOLD, all of which show superior performance and faster convergence.

1 Introduction

With the rapid advancement of technologies such as the Internet of Things (IoT) and artificial intelligence (AI), mobile devices like smartphones and tablets are increasingly capable of storing, processing, and transmitting vast amounts of data. However, due to privacy concerns and strict data protection regulations [1], internet companies are unable to gather clients' data for model training. To tackle this real-world challenge, McMahan et al. [2] proposed Federated Learning (FL), which facilitates model training among multiple participants while safeguarding data privacy. As a distributed computing framework, FL [2] aggregates model knowledge from different clients by uploading only model parameters, ultimately obtaining a global model that captures diverse data features. FL involves four processes. First, the server distributes the global model to the clients. Second, the clients train on their local data to obtain client-side models. Third, the clients upload their updated local models to the server. Finally, the server aggregates the client models to obtain the new global model. This process iterates until the model converges or achieves the desired performance. The specific process is illustrated in Fig 1.

Fig 1. The process of FL.

Step 1: Server distributes the global model; Step 2: Clients train locally; Step 3: Clients upload the local model; Step 4: Server aggregates models.

https://doi.org/10.1371/journal.pone.0327883.g001
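The four steps above can be sketched as a toy single-machine simulation. All names and the stand-in "training" rule below are illustrative, not the paper's implementation:

```python
import numpy as np

def local_train(global_params, data, lr=0.1, epochs=1):
    """Step 2: a client refines the global model on its own data.
    'Training' here is a toy step pulling parameters toward the data mean."""
    params = global_params.copy()
    for _ in range(epochs):
        params -= lr * (params - data.mean(axis=0))  # stand-in for SGD
    return params

def fl_round(global_params, client_datasets):
    # Steps 1-3: distribute the model, clients train locally and upload
    client_params = [local_train(global_params, d) for d in client_datasets]
    # Step 4: server aggregates, weighted by each client's data size
    sizes = np.array([len(d) for d in client_datasets], dtype=float)
    weights = sizes / sizes.sum()
    return sum(w * p for w, p in zip(weights, client_params))

rng = np.random.default_rng(0)
clients = [rng.normal(loc=i, size=(20, 4)) for i in range(3)]
global_model = np.zeros(4)
for _ in range(5):  # in practice, iterate until convergence
    global_model = fl_round(global_model, clients)
```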

Classical FL performs well under IID settings but poorly under Non-IID settings. Although client data is Non-IID, different clients may still hold similar data distributions, and traditional methods, which select clients for aggregation at random, are likely to include clients with similar distributions. The global model then needs more rounds of aggregation to capture the full knowledge of the clients, reducing both its convergence speed and accuracy. Researchers have combined FL with other sophisticated methods, such as regularization [3,4], knowledge distillation [5–7], and client selection [8,9], to tackle the challenges posed by data heterogeneity in FL. For example, Chai et al. [10] categorized clients into different tiers for stratified training, while Zhu et al. [5] applied data-free distillation to mitigate the effects of data heterogeneity. While these approaches address data heterogeneity, several critical issues remain unresolved. In typical FL algorithms, the server randomly selects models for aggregation and uses the aggregate of these "lucky" clients as the global model. Models selected this way are often biased: a model's characteristics directly reflect the characteristics of its data, and the data distributions of randomly selected clients are likely to be highly similar. Hence, it is difficult to incorporate the full diversity of data characteristics into training.

Marfoq et al. [11] proposed and demonstrated that the data distribution of each client is a mixture of M latent distributions. This finding underscores the inherent diversity of client data and suggests that randomly selecting clients for aggregation may combine models with similar data distributions, failing to capture the full spectrum of data diversity. This motivates us to propose the concept of base models: a set of client models with the most distinct data distributions, whose corresponding local data distributions are mutually orthogonal. Via linear combinations of their parameters, these models can cover all clients' feature spaces and approximate any client model. On distributed clients, there always exist N base models that can represent any model through a combination of parameters. This enables the global model to integrate features from base models as fully as possible; in other words, we can select client models that are as orthogonal as possible to the global model for aggregation, helping the global model incorporate knowledge from clients with significantly different data distributions and thereby improving both convergence speed and accuracy. This approach is orthogonal to previous methods. In large-scale client scenarios, identifying base models is challenging; the genetic algorithm is well suited to global search through crossover, mutation, and other operations, so this paper uses the genetic algorithm to solve for the base models on clients. Assuming a fixed number of clients is selected in each round, the best-performing base models are chosen from all clients for aggregation, accelerating the FL process and enhancing performance.
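To make the idea concrete, here is a toy numerical illustration (entirely hypothetical) of representing a client model as a weighted combination of base-model parameter vectors; the weights are recovered by least squares:

```python
import numpy as np

rng = np.random.default_rng(42)
N, dim = 3, 8
# Hypothetical "base models": three parameter vectors spanning a subspace
base = rng.normal(size=(N, dim))
# Assume a client model lies in the span of the base models
true_w = np.array([0.5, 0.3, 0.2])
client = true_w @ base
# Recover the combination weights by least squares
w, *_ = np.linalg.lstsq(base.T, client, rcond=None)
reconstruction_error = np.linalg.norm(w @ base - client)
```

Because the client vector lies exactly in the span of the base vectors, the recovered weights match and the reconstruction error is essentially zero.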

The key contributions of this paper are as follows:

  • We introduce the concept of base models under the FL architecture and theoretically prove their existence; aggregating base models enhances the aggregation effect and accelerates learning.
  • We propose an efficient federated learning method via aggregation of base models, utilizing the genetic algorithm to optimize the selection of base models and thereby globally optimize the iterative FL process.
  • We conduct experiments on image and text datasets, including FashionMNIST, MNIST, and TodayNews, using classic FL frameworks such as FedAvg [2], FedProx [3], and SCAFFOLD [12]. Our method shows superior performance compared to related approaches.

The rest of the paper is organized as follows. Sect 2 reviews advancements in federated learning, heterogeneous federated learning, and client selection. Sect 3 analyzes the existence of base models and proves the relevant lemmas. Sect 4 details the process for solving base models and provides the corresponding algorithm. Sect 5 validates the proposed method, with experimental results demonstrating its superiority. Sect 6 discusses our research and highlights the limitations of the study. Sect 7 summarizes the paper.

2 Related work

2.1 Federated learning

In 2016, Google proposed FL [2] as a distributed learning framework that aggregates models from distributed clients without uploading user data, thereby safeguarding user privacy. With the rapid development of the Internet of Things, FL has garnered significant attention in fields such as next-word prediction on keyboards [13], financial fraud detection [14], and medical image analysis [15]. Current research on FL primarily focuses on several challenges: communication bottlenecks, client system heterogeneity, privacy protection, and data heterogeneity [16].

First, the need for communication between clients and the server in FL gives rise to communication bottlenecks: wireless links typically offer lower bandwidth and speed than wired ones, which limits FL. To address this issue, researchers have explored a number of techniques, including reducing communication frequency by communicating only once or at specific intervals [17–19], employing quantization methods to decrease the size of communicated vectors [20–22], and utilizing sparsification techniques that randomly sparsify the gradients of locally trained models to reduce communication costs [23–25].

Second, in the distributed computing framework of FL, clients are likely to differ significantly in memory and computation, which can exclude resource-constrained clients from training and poses the challenge of system heterogeneity. Research [9] explored an optimized client selection method to address heterogeneity issues. Furthermore, Li et al. [3] proposed FedProx, which mitigates heterogeneity by adding a regularization term to the client loss function; this work has garnered significant attention.

Third, preserving privacy [26,27] presents a significant challenge in federated learning. Jagarlamudi et al. [28] conducted a comprehensive survey on privacy measurement within federated learning, identifying current gaps and suggesting future directions, including the integration of quantum computing and other cutting-edge technologies. Rabieinejad et al. [29] proposed a two-level privacy-preserving framework that combines federated learning with partially homomorphic encryption to enhance security and reduce attack prediction errors. Yazdinejad et al. [30] concentrated on reinforcing privacy and security in CIoT devices through federated learning and innovative encryption techniques. Collectively, these methods have advanced privacy protection within federated learning, though privacy protection in practical applications [31,32] remains to be further explored.

Furthermore, beyond hardware disparities, distributed clients may also hold significantly different local data, creating data heterogeneity challenges. Zhao et al. [33] explored methods to improve accuracy under Non-IID settings and demonstrated that weight divergence during training is a key factor contributing to accuracy degradation. The Astraea system [34] addresses globally imbalanced data by selectively combining biased local data into more balanced datasets, and applies z-score-based data augmentation to alleviate global imbalance. Additionally, Tucker-decomposition-based algorithms [35] have been proposed to fuse heterogeneous data in FL. Having systematically introduced the challenges faced by FL and summarized related work, we next examine data heterogeneity in more depth.

2.2 Data heterogeneity in federated learning

Data heterogeneity in FL refers to inconsistencies in data distribution among clients, whose samples are non-independent and identically distributed (Non-IID). In FL, federated aggregation methods, which combine the parameters uploaded by the various participants (such as tablets and smartphones) to update the global model, play a crucial role and ultimately determine the success of model training [36–40].

McMahan et al. [2] proposed the federated averaging algorithm, which updates the global model with the average of the received parameters and returns it to the clients for further training. Wang et al. [41] introduced the trimmed-mean aggregation method, which clips model updates to a predefined range, reducing the impact of outliers and potentially malicious client updates. Reyes et al. [42] presented a federated weighted aggregation method, in which the server weights each client's contribution to the global model based on client performance. Liu et al. [43] introduced a hierarchical aggregation method that performs local aggregation at lower levels of the hierarchy before transmitting the results upward, improving the convergence efficiency of the global model.

The aforementioned FL algorithms perform well under client heterogeneity. However, in scenarios with network congestion and limited communication resources, clients often face upload delays, packet loss, and other challenges. The literature [44–49] has applied exponential-moving-average algorithms to FL, achieving effective parameter transmission and model updates even under constrained communication resources. To cope with growing numbers of clients, some researchers have investigated resource allocation methods [50] in wireless networks for FL. While these methods scale effectively to many clients, randomly selected models are often trained on data from similar distributions, resulting in slower convergence and reduced accuracy.

2.3 Client selection

Because local client data is Non-IID, local optimization objectives deviate significantly from the global objective, and randomly selecting clients for aggregation during the FL process exacerbates the negative impact of data heterogeneity. Fu et al. [51] highlighted that client selection in FL is an emerging topic and that an effective client selection scheme can significantly improve model efficiency. Chai et al. [10] confirmed the detrimental effects of random client selection on federated learning performance through theoretical analysis and experimental validation. The literature [8] introduces Oort, which establishes data selection criteria to find clients that are both more informative and faster to train. The literature [9] suggested that in heterogeneous environments, prioritizing clients with higher local loss values can accelerate the convergence of the global model, improving communication efficiency, and provided a convergence proof for biased client selection in FL.

Additionally, Jin et al. [52] selected appropriate clients and excluded unnecessary model updates to save resources, designing an online learning algorithm that jointly controls participant selection in an online manner. Ribero et al. [53] proposed a selection strategy based on client availability, progressively minimizing the impact of client sampling variance on the convergence of the global model and thereby enhancing federated learning performance. Luo et al. [54] proposed an adaptive client sampling algorithm to address system and statistical heterogeneity, minimizing convergence time. Marnissi et al. [55] designed a device selection strategy based on the importance of gradient norms. AdaFL [56] dynamically adjusts the number of selected clients using a piecewise function, starting with a smaller selection size to reduce communication overhead and gradually increasing it to enhance the model's generalization capabilities. TiFLCS-MAR [57] is integrated into the federated learning framework, comprehensively evaluating client attributes and employing a tiered strategy to mitigate issues arising from client heterogeneity. These methods prioritize aggregating clients with large amounts of data, potentially excluding clients with little data from the aggregation process; yet these underrepresented clients may have distinct data distributions orthogonal to those of other client models. This paper further optimizes client selection with an evolutionary algorithm (EA), aiming to aggregate base models covering richer data distributions, accelerating convergence and improving the accuracy of the global model.

3 Theoretical analysis

In contrast to the random client selection of traditional FL, we propose aggregating base models to enhance model training efficiency. Building on the concept of base distributions introduced by Marfoq et al. [11], this section theoretically establishes the existence of base models on distributed clients.

Assume each model in the client set C performs a classification task and can be represented as a weighted combination of N base models; accordingly, the local data distribution Dc of the client model can be expressed as a weighted combination of N base distributions. The data on each client is generated according to its data distribution over the Cartesian product of the input space X and the output space Y. Since the data distributions across the clients in C generally differ, the models sc trained by the clients also differ. Under these conditions, the following optimization problem is considered:

(1)

thus, the objective is to minimize the training loss of the model on the local data distribution, where the first term is the loss function and the client model maps the input space X to the output space Y; the mixture weights lie in the unit simplex. The loss of the model sc during training is evaluated accordingly. For any client, we denote the joint density associated with Dc by pc(x,y), and the marginal densities by pc(x) and pc(y).

During training, an initial screening of base models is conducted, and the corresponding clients are added to the aggregation set; over multiple rounds of iteration, more suitable base models are gradually selected. The local dataset on each client is drawn independently and identically from its data distribution Dc, with the total dataset size given by the sum over clients. Selecting and aggregating base models aims to find client models with strong diversity and low similarity, reducing the loss and enhancing the model's generalization ability. Marfoq et al. [11] point out that without additional assumptions on the local distributions pt(x,y), the improvements from client collaboration via federated learning (FL) algorithms are limited: collaboration may only change sample complexity by a constant factor, rather than significantly reducing the required sample size. Assuming some structural relationship (such as similarity or correlation) between the output distributions pt(y|x) across clients would enable FL algorithms to share information and learn more effectively, significantly reducing sample complexity. Based on these considerations, the following assumptions are made.

Assumption 1. There exist N base models with corresponding data distributions. For every client, the client model sc is a combination of the base models with corresponding weights. It can be expressed as:

(2)

The mixture is a multinomial distribution parameterized by γ, and the probability densities associated with the base distributions are denoted pn(x,y), pn(x), and pn(y).

Assumption 2. For any , it holds that

(3)

Assumption 2 implies that we can restrict the inference to discriminative classification models (such as neural networks). More specifically, we consider a set of trained classification models in the following steps.

Assumption 3. is a set of base models parameterized by after local training, with the boundary of this set contained within S. For any base model in when , its corresponding data distribution is , such that:

(4)

where is a normalization constant, and the function is the log loss function of pn(y|x). The models in are base models, denoted by representing a matrix where the n-th row is , and representing a matrix where the c-th row is . Similarly, and Γ can represent any parameters in the matrix.

Under the aforementioned Assumptions, it is known that pc(x,y) is determined by and . We can prove that the optimal local model on client can be represented as a weighted average of the base models in the set .

Proposition. Using to represent mean squared error loss, regression loss, or cross-entropy loss, let and be the solutions to the following optimization problem:

(5)

For any client model under the distribution Dc, under Assumptions 1, 2, and 3, the client model can be represented as:

(6)

by minimizing this objective, problem (1) can be solved.

Proposition 1 presents the solution to problem (1). First, estimate the parameters by minimizing problem (5) on the training data, i.e., minimizing:

(7)

The above equation is the negative log-likelihood of model (2). Next, use (5) to obtain models for the C clients involved in training. Lastly, for clients unseen during training, keep the base models unchanged, select the weights that maximize the likelihood of the client's data, and predict the client's local model using (5).

Lemma 1. Under Assumptions 1 and 2, let and be the solution to problem (8), then we have:

(8)

Lemma 2. Represent the N probability distributions on Y as qn for , and let . For any probability distribution q on Y, if and only if , it holds that:

(9)

Lemma 3. Since and are solutions to problem (5), under Assumptions 1, 2, and 3, if rs does not depend on , it is possible to minimize by the model . When , it can be proven:

(10)

According to Lemma 3, pc(x,y) depends on and .

For and , let ps(y|x) denote the conditional probability distribution of y given x under model s, defined as:

(11)

where

(12)

The entropy of a probability distribution q on Y can be expressed as:

H(q) = −∑_{y∈Y} q(y) log q(y),    (13)

and the Kullback-Leibler (KL) divergence between two probability distributions q1 and q2 on Y (an asymmetric measure of the difference between two distributions) can be represented as:

KL(q1 ∥ q2) = ∑_{y∈Y} q1(y) log ( q1(y) / q2(y) ).    (14)

For the mean squared error, regression, and cross-entropy loss functions, we verify that in all three cases rs is independent of s, and then conclude using Lemma 3.

Mean squared error. This is a regression problem where g > 0, and there exists . For , , we have:

(15)

and

(16)

Regression loss. This is a binary classification problem where L > 1, and Y = [L]. For , , we have:

(17)

and

(18)

Cross-entropy loss. This is a classification problem where L > 1, and Y = [L]. For and , we have:

(19)

and

(20)

Theorem. For , consider a model that minimizes . Using Lemma 3, for , we have:

(21)

Multiplying both sides of the equation by y and integrating over , in all three cases, we have:

(22)

therefore,

(23)

This means the base models can represent any model in the client set C. As long as Lemmas 1, 2, and 3 hold, the existence of base models is proven. The proofs of Lemmas 1, 2, and 3 are included in the S1 Appendix.

4 Solving base model using the evolutionary algorithm

Building on the existence of base models, this section introduces the method for solving base models using evolutionary algorithms and the corresponding algorithmic procedure.

4.1 Base model solving

Before delving into the methodology of this paper, it is essential to outline the FL process and notation. FL is a distributed machine learning framework consisting of a central server for model aggregation and multiple distributed clients that execute the computing tasks. Assume there are N clients, where client i holds data of volume Di (i = 1,2,...,N) and trains a local model. The server stores no data. In each iteration, the server first distributes the global model to all clients, and each client then trains locally on the received model. The training process is as follows:

(24)

where the left-hand side denotes the model parameters of client i after the e-th local training epoch of the t-th global iteration. In each global iteration, the client performs E local training epochs, and the loss function is that of the i-th client on its local data.

After all clients complete local training, the server randomly selects a subset of the N clients for aggregation. Taking FedAvg as an example, the server aggregates the uploaded models weighted by the data volumes of the selected clients. The aggregation process is as follows.

(25)

where the left-hand side denotes the global model aggregated in the t-th iteration, T is the total number of iterations, and D is computed as follows:

(26)

Finally, the server redistributes the global model to the clients for training, and the above process is iterated repeatedly until convergence or the preset number of iterations is reached.
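The weighted aggregation of Eqs (25)–(26) amounts to a data-volume-weighted average. As a sketch, with models represented as parameter vectors:

```python
import numpy as np

def fedavg_aggregate(models, data_sizes):
    """Eqs (25)-(26): average the selected client models,
    weighted by each client's data volume D_i."""
    D = float(sum(data_sizes))  # Eq (26): total data held by selected clients
    return sum((d / D) * m for d, m in zip(data_sizes, models))

# Two toy client models with data sizes 10 and 30 -> weights 0.25 and 0.75
g = fedavg_aggregate([np.array([1.0, 2.0]), np.array([3.0, 4.0])], [10, 30])
```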

Having described the classic FL method, we now introduce a base-model solving approach based on evolutionary algorithms. Its primary aim is to reform the client selection process, which traditionally selects clients at random. To enhance aggregation efficiency, this paper uses a classic evolutionary algorithm, the genetic algorithm, to solve for the base models to be aggregated. The approach involves several key operations, including genetic encoding, crossover, and mutation.

For genetic encoding, the chromosome length is set to the number of clients selected per round, and each gene on a chromosome represents the ID of a selected client. There are N clients in total, with client IDs ranging from 0 to N–1. After a certain number of federated iterations, a genetic algorithm is applied to optimize the selected clients, aiming to identify the base models among the N client models as closely as possible. The population size is denoted popsize.

In each iteration, the better 50% of individuals are retained from the parent population. Each chromosome is denoted Chmj, where j is the chromosome index. Chromosomes are unaffected by gene order and contain no duplicate genes; therefore, two chromosomes with the same genes in different orders are considered identical.
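A minimal sketch of this encoding (the constants are the illustrative values later used in Sect 5.1; since gene order is irrelevant and genes are distinct, a chromosome is naturally a set of client IDs):

```python
import random

N = 50         # total number of clients
K = 10         # clients selected per round = chromosome length
POPSIZE = 100  # population size

def random_chromosome(rng=random):
    # Genes are distinct client IDs in [0, N-1]; gene order is irrelevant,
    # so a chromosome is stored as a frozenset (no duplicates by construction).
    return frozenset(rng.sample(range(N), K))

population = [random_chromosome() for _ in range(POPSIZE)]
```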

Crossover Operator. The crossover probability is typically set to 1. For every adjacent pair of chromosomes Chmj and Chmj+1, the intersection Common[j,j+1] is computed, consisting of the client IDs present in both chromosomes. Thus, each chromosome Chmj is composed of two parts: the common IDs and the unique IDs.

(27)(28)

the crossover operator exchanges the unique IDs, excluding Common[j,j+1], between the two chromosomes Chmj and Chmj+1 to create new individuals (as shown in Fig 2).

(29)(30)
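A minimal sketch of this crossover, treating chromosomes as sets of client IDs. How many unique IDs are exchanged (a random count here) is our assumption:

```python
import random

def crossover(chm_a, chm_b, rng):
    """Keep the common client IDs in both children and exchange some
    of the unique IDs, per Eqs (27)-(30)."""
    common = chm_a & chm_b
    uniq_a, uniq_b = list(chm_a - common), list(chm_b - common)
    k = rng.randint(0, min(len(uniq_a), len(uniq_b)))  # how many IDs to swap
    swap_a = set(rng.sample(uniq_a, k))
    swap_b = set(rng.sample(uniq_b, k))
    child_a = (chm_a - swap_a) | swap_b  # unique IDs from b replace a's
    child_b = (chm_b - swap_b) | swap_a
    return child_a, child_b

ca, cb = crossover(frozenset({0, 1, 2, 3}), frozenset({2, 3, 4, 5}),
                   random.Random(0))
```

Because equal numbers of unique IDs are swapped, both children keep the original chromosome length, and the common IDs {2, 3} survive in both.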

Mutation Operator. The mutation probability is denoted muta_prob, and the mutation length is a random variable lj. A random gene block of length lj is selected from chromosome Chmj, and the IDs in this block are randomly replaced with client IDs not present in Chmj (as shown in Fig 3).
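The mutation step can be sketched as follows (the choice of a uniformly random mutation length is our assumption):

```python
import random

N = 50  # total number of clients (illustrative)

def mutate(chm, muta_prob, rng):
    """With probability muta_prob, replace a random block of l_j genes
    with client IDs absent from the chromosome, preserving its length."""
    chm = set(chm)
    if rng.random() >= muta_prob:
        return frozenset(chm)                 # no mutation this time
    l_j = rng.randint(1, len(chm))            # mutation length l_j
    removed = set(rng.sample(sorted(chm), l_j))
    outside = sorted(set(range(N)) - chm)     # IDs not in the chromosome
    added = set(rng.sample(outside, l_j))
    return frozenset((chm - removed) | added)

m = mutate(frozenset(range(10)), muta_prob=1.0, rng=random.Random(0))
```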

(31)

where Crossentropy(·) denotes the cross-entropy loss.

Finally, the fitness function is obtained as follows:

(32)

Loss1 aims to minimize the similarity loss of the models in order to select client models that are as orthogonal as possible to the global model, while Loss2 focuses on enhancing classification performance by minimizing cross-entropy. By integrating Loss1 and Loss2, we obtain a fitness function that allows us to simultaneously address feature integration and classification performance during the optimization process. By introducing a hyperparameter μ, we can adjust the relative importance of the two losses in the fitness calculation, thereby flexibly adapting to different application scenarios and requirements. Ultimately, we select individuals with high fitness values to meet the selection criteria for client models.
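The fitness computation can be sketched as below. The concrete similarity measure for Loss1 (cosine similarity between parameter vectors) is our assumption, and `ce_loss` stands in for the cross-entropy of the aggregated model computed elsewhere:

```python
import numpy as np

def fitness(selected_params, global_params, ce_loss, mu=0.5):
    """Fitness of one chromosome, following the structure of Eq (32).
    Loss1: mean similarity between selected client models and the global
    model (lower = more orthogonal = more diverse knowledge).
    Loss2: cross-entropy of the aggregated model (ce_loss).
    mu trades off diversity against classification performance."""
    g = global_params / np.linalg.norm(global_params)
    loss1 = float(np.mean([abs(p @ g) / np.linalg.norm(p)
                           for p in selected_params]))
    loss2 = float(ce_loss)
    return -(mu * loss1 + (1.0 - mu) * loss2)  # higher fitness is better

# A client model orthogonal to the global model incurs zero similarity loss
f = fitness([np.array([0.0, 1.0])], np.array([1.0, 0.0]), ce_loss=0.3)
```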

4.2 Algorithm

The detailed computation process of the proposed method is described in Algorithm 1. The server initializes the global model and distributes it to the individual clients. Each client trains locally for E epochs on the received global model. After training completes, the server selects clients for federated aggregation; every optimization interval of rounds, a genetic algorithm optimizes the selected client ID list to accelerate convergence and improve effectiveness. This iterative process continues until convergence or until the predefined iteration count T is reached.

Algorithm 1. Efficient Federated Learning via Aggregation of Base Models.

Require: learning rate η, federated iteration count T, local training epochs E, total number of clients N, number of clients per round, client i data and data size Di, optimization interval

Ensure: global model

1: Initialize the global model

2: for each federated iteration t = 1, ..., T do

3:   Server distributes the current global model to each client

4:   for each client i = 1, ..., N do

5:     Client i performs E epochs of local training (Eq 24)

6:   end for

7:   if t is a multiple of the optimization interval then

8:     Genetic algorithm optimizes the selected client IDs

9:     Aggregate the models of the optimized client IDs (Eq 25)

10:  else

11:    Randomly select client models for aggregation

12:    Aggregate the selected client models (Eq 25)

13:  end if

14: end for

15: return global model
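The control flow of Algorithm 1 can be condensed into a skeleton; all callables and names here are hypothetical placeholders, not the paper's implementation:

```python
import random

def run_fl(T, interval, n_clients, k, train_local, aggregate, ga_optimize):
    """Skeleton of Algorithm 1: every `interval` rounds the GA chooses
    the client IDs to aggregate, otherwise they are drawn at random."""
    global_model = None  # server-side initialization
    for t in range(1, T + 1):
        # All clients train locally on the distributed global model
        local_models = {i: train_local(i, global_model)
                        for i in range(n_clients)}
        if t % interval == 0:
            selected = ga_optimize(local_models, global_model)  # base models
        else:
            selected = random.sample(range(n_clients), k)
        global_model = aggregate([local_models[i] for i in selected])
    return global_model

# Stub run: clients "train" to their own ID, aggregation is a plain mean
g = run_fl(T=5, interval=2, n_clients=6, k=2,
           train_local=lambda i, m: float(i),
           aggregate=lambda ms: sum(ms) / len(ms),
           ga_optimize=lambda locals_, m: [0, 1])
```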

5 Experiment

In order to validate the effectiveness of the proposed method, this section aims to address the following questions through experiments:

  • Can using evolutionary algorithms to solve the base model in FL improve the performance of the aggregated model?
  • Does the distribution of client data influence the aggregation effect of the proposed method?
  • During the process of solving the base model, how do different optimization intervals affect the effectiveness of the aggregated model?

5.1 Experiment setting

Datasets and models. We evaluate our method extensively on the MNIST, FashionMNIST, and TodayNews datasets, conducting experiments with LeNet and TextCNN. The MNIST dataset contains 70,000 handwritten-digit images (60,000 for training, 10,000 for testing), each 28x28 pixels, across 10 classes (digits 0-9) with a relatively uniform sample size of approximately 7,000 images per class. FashionMNIST likewise contains 70,000 images of the same size (60,000 for training, 10,000 for testing) across 10 categories (T-shirts, trousers, shoes, etc.), also with relatively uniform class sizes. To simulate data heterogeneity among clients in a Non-IID scenario, we partition these datasets using a Dirichlet distribution, yielding distributions closer to real-world applications. The TodayNews dataset contains approximately 30,000 news articles covering various topics; its category distribution is relatively uneven, with some topics having many articles and others few. This imbalance makes the TodayNews dataset well suited for simulating real-world federated learning scenarios.
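The Dirichlet-based Non-IID partition described above can be sketched as follows; the concentration parameter `alpha` is illustrative, not the paper's setting (smaller values give more heterogeneous clients):

```python
import numpy as np

def dirichlet_partition(labels, n_clients=50, alpha=0.5, seed=0):
    """Split sample indices across clients by drawing, for every class,
    client proportions from Dirichlet(alpha)."""
    rng = np.random.default_rng(seed)
    client_indices = [[] for _ in range(n_clients)]
    for c in np.unique(labels):
        idx = np.flatnonzero(labels == c)
        rng.shuffle(idx)
        proportions = rng.dirichlet(alpha * np.ones(n_clients))
        # Convert cumulative proportions to split points within this class
        cuts = (np.cumsum(proportions)[:-1] * len(idx)).astype(int)
        for client_id, part in enumerate(np.split(idx, cuts)):
            client_indices[client_id].extend(part.tolist())
    return client_indices

labels = np.repeat(np.arange(10), 100)  # toy stand-in for MNIST labels
parts = dirichlet_partition(labels)
```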

Baselines. We incorporated components of genetic algorithms into the classic federated learning algorithms FedAvg, FedProx, and SCAFFOLD to validate the effectiveness of the components proposed in this paper. We employed the same training settings and dataset allocations across all experiments. The partitioning of the dataset and the allocation of client data were kept consistent to ensure the comparability of experimental conditions. Parameters were adjusted to ensure that all baseline methods achieved optimal performance.

Hyper-parameters. The specific parameter settings are shown in Table 1. During the FL process there are 50 clients, with 10 participating in training each round. Following the guidance in the literature [58] and considering computational cost, the population size of the genetic algorithm is set to 100, with a selection probability of 50% and a mutation probability of 10%. Every 10 federated iterations, the genetic algorithm is used to optimize the aggregated client IDs.
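A minimal sketch of a genetic algorithm over client-ID subsets matching these hyper-parameters (population 100, 50% selection, 10% mutation). The fitness function here is a toy placeholder; the paper's actual fitness scores candidate subsets by their diversity and accuracy, and all function names are illustrative.

```python
import random

POP_SIZE, SELECT_P, MUTATE_P = 100, 0.5, 0.1
NUM_CLIENTS, SUBSET = 50, 10

def random_individual():
    # an individual encodes a subset of client IDs
    return random.sample(range(NUM_CLIENTS), SUBSET)

def crossover(a, b):
    # keep a valid subset: draw SUBSET ids from the union of both parents
    pool = list(set(a) | set(b))
    return random.sample(pool, SUBSET)

def mutate(ind):
    if random.random() >= MUTATE_P:
        return ind
    drop = random.choice(ind)                              # swap out one id...
    spare = [c for c in range(NUM_CLIENTS) if c not in ind]
    return [c for c in ind if c != drop] + [random.choice(spare)]  # ...for a new one

def evolve(fitness, generations=20):
    pop = [random_individual() for _ in range(POP_SIZE)]
    for _ in range(generations):
        pop.sort(key=fitness, reverse=True)
        parents = pop[: int(POP_SIZE * SELECT_P)]          # top 50% survive
        children = [mutate(crossover(*random.sample(parents, 2)))
                    for _ in range(POP_SIZE - len(parents))]
        pop = parents + children
    return max(pop, key=fitness)

best = evolve(fitness=lambda ind: -abs(sum(ind) - 250))    # toy fitness
```

Each individual remains a valid, duplicate-free set of 10 client IDs through crossover and mutation, which is the invariant the encoding must preserve.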

5.2 Independent and identically distributed data

For question 1, we conduct experiments under the condition that the client data is independent and identically distributed. This section investigates whether the proposed method can improve the accuracy and convergence speed of the aggregated model compared to not using genetic algorithm optimization. The specific experimental results are shown in Figs 4, 5, and 6. In the figures, the horizontal axis represents the number of training epochs, and the vertical axis indicates the accuracy. The blue line shows the results of the classic FL framework with the genetic algorithm, while the orange line shows the results of the classic FL framework without it.

thumbnail
Fig 4. The accuracy of MNIST dataset under FedAvg, FedProx, and SCAFFOLD.

https://doi.org/10.1371/journal.pone.0327883.g004

thumbnail
Fig 5. The accuracy of FashionMNIST dataset under FedAvg, FedProx, and SCAFFOLD.

https://doi.org/10.1371/journal.pone.0327883.g005

thumbnail
Fig 6. The accuracy of TodayNews dataset under FedAvg, FedProx, and SCAFFOLD.

https://doi.org/10.1371/journal.pone.0327883.g006

From Figs 4, 5, and 6, it can be observed that on the MNIST, FashionMNIST, and TodayNews datasets, under the FedAvg, FedProx, and SCAFFOLD frameworks, incorporating genetic algorithm optimization every 10 iterations leads to faster convergence and higher final accuracy than not using the genetic algorithm. Specifically, our proposed method consistently converges within 20-40 iterations while achieving higher accuracy. This is attributed to the ability of our method to help the server select base models with larger discrepancies, thereby enhancing the diversity of features fused into the aggregated model. In the independent and identically distributed (IID) case, the data distribution is the same across all clients, yet there are still discrepancies in training performance. The client models selected by the genetic algorithm, which are as orthogonal as possible to the global model, retain an advantage over those obtained through random selection, so the method outperforms random aggregation even in this setting.

5.3 Non-independent and identically distributed data

However, in real-world scenarios, due to the varying environments each client faces, the data from clients often exhibit non-independent and identically distributed characteristics.

Therefore, to answer question 2, this section extends the experiments from Sect 5.2 to verify performance in a Non-IID environment, keeping all other settings unchanged. Data are partitioned using the Dirichlet method, with the imbalance factor set to 0.1.

From Figs 7, 8, and 9, it can be observed that when the data distribution is imbalanced, the advantages of our method become more pronounced than in the IID scenario of Sect 5.2. In terms of convergence speed, our proposed method consistently converges within 20-40 iterations while achieving higher accuracy. This is because the primary issue addressed in this paper is that, under Non-IID settings, the client models the server selects for aggregation are likely to come from the same (or similar) data distributions, resulting in a biased global model. With non-independent and identically distributed (Non-IID) data, the proposed method selects client models that are orthogonal to the global model parameters for aggregation, facilitating earlier convergence of the global model. In contrast, random selection requires multiple rounds of selection and aggregation before the global model covers the data distributions of all clients, resulting in slower convergence.

thumbnail
Fig 7. The accuracy of MNIST dataset under FedAvg, FedProx, and SCAFFOLD.

https://doi.org/10.1371/journal.pone.0327883.g007

thumbnail
Fig 8. The accuracy of FashionMNIST dataset under FedAvg, FedProx, and SCAFFOLD.

https://doi.org/10.1371/journal.pone.0327883.g008

thumbnail
Fig 9. The accuracy of TodayNews dataset under FedAvg, FedProx, and SCAFFOLD.

https://doi.org/10.1371/journal.pone.0327883.g009

To further validate the effectiveness of the proposed method under different degrees of data heterogeneity, we present in Table 2 a comparative performance analysis of the algorithms with and without evolutionary components when the α value is set to 0.5.

Based on the data in Table 2, the federated learning methods with EA components consistently demonstrate superior performance. They achieve higher accuracy than the methods without EA, and they also converge more quickly, both in the number of rounds and in the time required. Although the genetic algorithm takes approximately 46.96 seconds to perform model selection on the server side, we observe that it reduces the time required for client models to reach local convergence. Thus, while the EA increases resource consumption on the server, it cuts down on resource use on the client side, improving the usability of federated learning on resource-constrained end devices.

thumbnail
Table 2. Comparison between methods with EA and methods without EA.

https://doi.org/10.1371/journal.pone.0327883.t002

For the case where α is set to 0.5, we conducted a statistical significance test to demonstrate the effectiveness of the proposed method.

Null Hypothesis (H0): There is no significant difference in accuracy between the models with and without the components.

Alternative Hypothesis (H1): The accuracy of the models with the components is significantly higher than that of the models without the components.

The calculated p-value is approximately 0.0015, which is less than 0.05. Therefore, at the 0.05 significance level, there is sufficient statistical evidence that the inclusion of the components significantly improves accuracy.
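The exact test used for the p-value above is not detailed here, so the sketch below illustrates one standard way to run such a one-sided paired test: a sign-flip permutation test on per-run accuracies. The accuracy values and the function name are illustrative, not the paper's data.

```python
import random

def paired_permutation_pvalue(with_ea, without_ea, n_perm=10000, seed=0):
    """One-sided paired permutation test: H1 says with_ea > without_ea."""
    rng = random.Random(seed)
    diffs = [a - b for a, b in zip(with_ea, without_ea)]
    observed = sum(diffs) / len(diffs)
    hits = 0
    for _ in range(n_perm):
        # under H0 the sign of each paired difference is exchangeable
        perm = sum(d * rng.choice((1, -1)) for d in diffs) / len(diffs)
        if perm >= observed:
            hits += 1
    return (hits + 1) / (n_perm + 1)   # add-one smoothing avoids p == 0

acc_with = [0.91, 0.93, 0.92, 0.94, 0.90, 0.93]     # toy accuracies with EA
acc_without = [0.87, 0.88, 0.89, 0.90, 0.86, 0.88]  # toy accuracies without EA
p = paired_permutation_pvalue(acc_with, acc_without)
```

A permutation test makes no normality assumption, which is convenient when the number of paired runs is small.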

5.4 Supplementary experiments

5.4.1 Impact of the hyperparameter μ.

In equation (32), μ balances diversity and accuracy in the loss function. We designed a series of experiments on the MNIST dataset to identify the optimal value of μ. As shown in Table 3, we set μ to 0.05, 0.2, 0.4, 0.5, and 1. The global model converged fastest and achieved the best performance when μ was in the range of 0.2 to 0.4.
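The role of μ can be illustrated with a toy weighted objective; the additive form and the score values below are stand-ins for equation (32), which is not reproduced here, not its exact expression.

```python
def fitness(accuracy, diversity, mu=0.3):
    """Toy stand-in for a mu-weighted objective: mu trades diversity
    against accuracy when scoring a candidate client subset."""
    return accuracy + mu * diversity

# With a mid-range mu, a slightly less accurate but more diverse subset
# can outscore an accurate, low-diversity one:
a = fitness(0.92, 0.10, mu=0.3)   # accurate, low diversity
b = fitness(0.90, 0.30, mu=0.3)   # slightly less accurate, more diverse
```

This mirrors the experimental finding that mid-range μ values (0.2 to 0.4) work best: too small and diversity is ignored, too large and accuracy is sacrificed.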

5.4.2 Impact of different optimization intervals.

After a certain number of federated iterations, the method in this paper utilizes a genetic algorithm to optimize the selected client IDs.

To answer question 3, this section investigates the impact of the interval size used by the genetic algorithm on learning outcomes. In Sects 5.2 and 5.3, the interval hyperparameter GA_gap was set to 10. In this section, under the Non-IID data distribution scenario, different interval sizes are explored to observe the convergence behavior of the algorithm and the influence of the interval hyperparameter on the federated aggregation model.

In the experiments, the MNIST, TodayNews, and FashionMNIST datasets are trained under the genetic algorithm-based FL framework. The interval is increased to 12, 15, and 17, and the results are shown in Figs 10, 11, and 12.

thumbnail
Fig 10. Comparison of MNIST with different intervals under FedAvg, FedProx and SCAFFOLD.

https://doi.org/10.1371/journal.pone.0327883.g010

thumbnail
Fig 11. Comparison of TodayNews with different intervals under FedAvg, FedProx and SCAFFOLD.

https://doi.org/10.1371/journal.pone.0327883.g011

thumbnail
Fig 12. Comparison of FashionMNIST with different intervals under FedAvg, FedProx and SCAFFOLD.

https://doi.org/10.1371/journal.pone.0327883.g012

It can be observed from the figures that with a larger GA_gap, fewer rounds of genetic algorithm optimization of the aggregated client models occur within a limited number of iterations, leading to a slight decrease in the performance of the final aggregated model. Across the three methods FedAvg, FedProx, and SCAFFOLD, overall performance is best when GA_gap = 10.

5.4.3 Impact of the Dirichlet distribution.

We ran MNIST experiments with varying Dirichlet distributions to verify our method’s effectiveness under different data heterogeneity levels. As shown in Table 4, at each heterogeneity level, training with the genetic algorithm consistently converges faster than without it, demonstrating the adaptability of our approach across varying heterogeneity levels.

5.4.4 The analysis of base models.

We use PCA-based dimensionality reduction to visualize the base models’ embedding distribution. In Figs 13 and 14, the horizontal and vertical coordinates are the offsets along the first and second principal components after PCA dimensionality reduction, and the point size represents the distance from the global model. The visualizations clearly illustrate the base models’ feature space distribution, further confirming their ability to represent broader model diversity: the client models selected through the genetic algorithm are farther apart than those selected at random, and aggregating base models with larger differences makes the global model converge more rapidly.
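The PCA projection behind Figs 13 and 14 can be sketched as follows: flatten each client model's parameters into a row vector, center the matrix, and project onto the top two principal components via SVD. The random matrix stands in for real model parameters, and `pca_2d` is our own name for the step.

```python
import numpy as np

def pca_2d(param_matrix):
    """Project row vectors (one per client model) onto the first two
    principal components of the centered data."""
    centered = param_matrix - param_matrix.mean(axis=0)
    # rows of vt are the principal directions; keep the first two
    _, _, vt = np.linalg.svd(centered, full_matrices=False)
    return centered @ vt[:2].T

# 10 client models with 256 flattened parameters each (toy data)
models = np.random.default_rng(0).normal(size=(10, 256))
coords = pca_2d(models)   # shape (10, 2): offsets on PC1 and PC2
```

Each row of `coords` gives the (PC1, PC2) position of one client model, so pairwise scatter in this plane reflects parameter-space separation between the selected models.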

thumbnail
Fig 13. Visualization Comparison Between Client Model with NO_EA.

https://doi.org/10.1371/journal.pone.0327883.g013

thumbnail
Fig 14. Visualization Comparison Between Client Model with EA.

https://doi.org/10.1371/journal.pone.0327883.g014

The cosine distance is used to calculate the distance between a base model and the global model, and is computed as follows:

d(A, B) = 1 − (A · B) / (‖A‖ ‖B‖)  (33)

where A and B are the feature vectors produced by the global model and the client model. The cosine distance ranges from 0 to 2; the larger it is, the greater the difference in vector direction. As shown in Table 5, the base models selected by the GA show a greater distance from the globally initialized model. This indicates that our screening approach, while preserving model accuracy, favors client models with more substantial updates for aggregation, thereby incorporating a wider diversity of client data characteristics and enhancing the convergence of the global model.
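The cosine distance of equation (33) is a one-liner in practice; this sketch confirms the 0-to-2 range stated above (0 for parallel vectors, 2 for opposite ones).

```python
import math

def cosine_distance(a, b):
    """d(A, B) = 1 - (A . B) / (||A|| * ||B||), which lies in [0, 2]."""
    dot = sum(x * y for x, y in zip(a, b))
    na = math.sqrt(sum(x * x for x in a))
    nb = math.sqrt(sum(y * y for y in b))
    return 1.0 - dot / (na * nb)

d_parallel = cosine_distance([1.0, 2.0], [2.0, 4.0])    # same direction -> 0
d_opposite = cosine_distance([1.0, 0.0], [-1.0, 0.0])   # opposite direction -> 2
```

Applied to the feature vectors A and B from the global and client models, larger values indicate client updates pointing in more different directions, which is what Table 5 reports for the GA-selected base models.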

6 Discussion

We validated our approach using convolutional neural networks on image and text datasets, including MNIST, FashionMNIST, and TodayNews, achieving encouraging experimental results. Due to limitations in computational resources and time, we did not conduct further validation on larger-scale datasets or networks. To enhance the generalizability of our research, we plan to extend our work to larger datasets in the future to comprehensively evaluate the model’s performance and stability. Additionally, testing the adaptability of different network architectures will be a key focus of our subsequent research to ensure that the model maintains consistent performance across a wider range of applications. Through these expansions, we aim to further validate our conclusions and lay the groundwork for future studies.

Limitation. Using the evolutionary algorithm to select clients that are as orthogonal as possible reduces the number of iterations, but it increases computational consumption during the server’s selection process, even though we typically assume that the server’s computing resources are unlimited. Frequent combination and exchange of client models on the server side can also pose a risk of model privacy leakage. In future work, we will explore differential privacy and other techniques to better protect client-side data security during model transmission.

7 Conclusion

This paper theoretically proves the existence of base models in the FL framework and proposes an efficient federated learning method via aggregation of base models. This method uses the evolutionary algorithm to derive base models from clients with diverse data features, aiming to integrate data characteristics from different distributions into the global model as fully as possible. It effectively addresses the issue of biased aggregation results from randomly selected clients in FL, thus accelerating convergence and improving the effectiveness of federated aggregation. Using the MNIST, FashionMNIST, and TodayNews datasets and the FedAvg, FedProx, and SCAFFOLD frameworks, we validate that in both IID and Non-IID environments, the global model obtained through the proposed aggregation of base models achieves faster convergence and better performance in FL.

Supporting information

References

  1. Voigt P, Von dem Bussche A. The EU general data protection regulation (GDPR). 1st ed. Cham: Springer; 2017.
  2. McMahan B, Moore E, Ramage D, Hampson S, y Arcas BA. Communication-efficient learning of deep networks from decentralized data. In: Artificial Intelligence and Statistics, 2017. p. 1273–82.
  3. Li T, Sahu AK, Zaheer M, Sanjabi M, Talwalkar A, Smith V. Federated optimization in heterogeneous networks. Proc Mach Learn Syst. 2020;2:429–50.
  4. Acar DAE, Zhao Y, Navarro RM, Mattina M, Whatmough PN, Saligrama V. Federated learning based on dynamic regularization. arXiv preprint 2021. https://arxiv.org/abs/2111.04263
  5. Zhu Z, Hong J, Zhou J. Data-free knowledge distillation for heterogeneous federated learning. In: International Conference on Machine Learning, 2021. p. 12878–89.
  6. Jiang D, Shan C, Zhang Z. Federated learning algorithm based on knowledge distillation. In: 2020 International Conference on Artificial Intelligence and Computer Engineering (ICAICE). 2020. p. 163–7.
  7. Li D, Wang J. FedMD: heterogenous federated learning via model distillation. arXiv preprint 2019. https://arxiv.org/abs/1910.03581
  8. Lai F, Zhu X, Madhyastha HV, Chowdhury M. Oort: informed participant selection for scalable federated learning. arXiv preprint 2020. https://arxiv.org/abs/2010.06081
  9. Nishio T, Yonetani R. Client selection for federated learning with heterogeneous resources in mobile edge. In: ICC 2019 - 2019 IEEE International Conference on Communications (ICC). 2019. p. 1–7.
  10. Chai Z, Ali A, Zawad S, Truex S, Anwar A, Baracaldo N, et al. In: Proceedings of the 29th International Symposium on High-Performance Parallel and Distributed Computing, 2020. p. 125–36.
  11. Marfoq O, Neglia G, Bellet A, Kameni L, Vidal R. Federated multi-task learning under a mixture of distributions. Adv Neural Inf Process Syst. 2021;34:15434–47.
  12. Karimireddy SP, Kale S, Mohri M, Reddi S, Stich S, Suresh AT. In: International Conference on Machine Learning, 2020. p. 5132–43.
  13. Hard A, Rao K, Mathews R, Ramaswamy S, Beaufays F, Augenstein S, Eichner H, Kiddon C, Ramage D. Federated learning for mobile keyboard prediction. arXiv preprint 2018. https://arxiv.org/abs/1811.03604
  14. Liu T, Wang Z, He H, Shi W, Lin L, An R, et al. Efficient and secure federated learning for financial applications. Appl Sci. 2023;13(10):5877.
  15. Chaddad A, Wu Y, Desrosiers C. Federated learning for healthcare applications. IEEE Internet Things J. 2023.
  16. Horváth S. Better methods and theory for federated learning: compression, client selection and heterogeneity. arXiv preprint 2022.
  17. McDonald R, Mohri M, Silberman N, Walker D, Mann G. Efficient large-scale distributed training of conditional maximum entropy models. Adv Neural Inf Process Syst. 2009;22.
  18. Zinkevich M, Weimer M, Li L, Smola A. Parallelized stochastic gradient descent. Adv Neural Inf Process Syst. 2010;23.
  19. Stich SU. Local SGD converges fast and communicates little. arXiv preprint 2018.
  20. Wen W, Xu C, Yan F, Wu C, Wang Y, Chen Y, et al. TernGrad: ternary gradients to reduce communication in distributed deep learning. Adv Neural Inf Process Syst. 2017;30.
  21. Wangni J, Wang J, Liu J, Zhang T. Gradient sparsification for communication-efficient distributed optimization. Adv Neural Inf Process Syst. 2018;31.
  22. Hubara I, Courbariaux M, Soudry D, El-Yaniv R, Bengio Y. Quantized neural networks: training neural networks with low precision weights and activations. J Mach Learn Res. 2018;18(187):1–30.
  23. Suresh AT, Felix XY, Kumar S, McMahan HB. Distributed mean estimation with limited communication. In: International Conference on Machine Learning. PMLR; 2017. p. 3329–37.
  24. Konečný J, Richtárik P. Randomized distributed mean estimation: accuracy vs. communication. Front Appl Math Stat. 2018;4:62.
  25. Alistarh D, Hoefler T, Johansson M, Konstantinov N, Khirirat S, Renggli C. The convergence of sparsified gradient methods. Adv Neural Inf Process Syst. 2018;31.
  26. Li Q, Wen Z, Wu Z, Hu S, Wang N, Li Y, et al. A survey on federated learning systems: vision, hype and reality for data privacy and protection. IEEE Trans Knowl Data Eng. 2021;35(4):3347–66.
  27. Yazdinejad A, Dehghantanha A, Karimipour H, Srivastava G, Parizi RM. A robust privacy-preserving federated learning model against model poisoning attacks. IEEE Trans Inf Forensics Secur. 2024.
  28. Jagarlamudli GK, Yazdinejad A, Parizi RM, Pouriyeh S. Exploring privacy measurement in federated learning. J Supercomput. 2024;80(8):10511–51.
  29. Rabieinejad E, Yazdinejad A, Dehghantanha A, Srivastava G. Two-level privacy-preserving framework: federated learning for attack detection in the consumer internet of things. IEEE Trans Consum Electron. 2024;70(1):4258–65.
  30. Yazdinejad A, Dehghantanha A, Srivastava G, Karimipour H, Parizi RM. Hybrid privacy preserving federated learning against irregular users in next-generation Internet of Things. J Syst Archit. 2024;148:103088.
  31. Pati S, Kumar S, Varma A, Edwards B, Lu C, Qu L, et al. Privacy preservation for federated learning in health care. Patterns. 2024;5(7).
  32. Yang M, Huang D, Wan W, Jin M. Federated learning for privacy-preserving medical data sharing in drug development. ACE. 2024;108(1):7–13.
  33. Zhao Y, Li M, Lai L, Suda N, Civin D, Chandra V. Federated learning with non-IID data. arXiv preprint 2018. https://arxiv.org/abs/1806.00582
  34. Duan M, Liu D, Chen X, Tan Y, Ren J, Qiao L, et al. Astraea: self-balancing federated learning for improving classification accuracy of mobile deep learning applications. In: 2019 IEEE 37th International Conference on Computer Design (ICCD). 2019. p. 246–54.
  35. Mo H, Zheng H, Gao M. Multi-source heterogeneous data fusion algorithm based on federated learning. J Comput Res Dev. 2022;59(2):10.
  36. Zhong Z, Bao W, Wang J, Zhu X, Zhang X. FLEE: a hierarchical federated learning framework for distributed deep neural network over cloud, edge, and end device. ACM Trans Intell Syst Technol. 2022;13(5):1–24.
  37. Moshawrab M, Adda M, Bouzouane A, Ibrahim H, Raad A. Reviewing federated machine learning and its use in diseases prediction. Sensors (Basel). 2023;23(4):2112. pmid:36850717
  38. Malekijoo A, Fadaeieslam MJ, Malekijou H, Homayounfar M, Alizadeh-Shabdiz F, Rawassizadeh R. FedZip: a compression framework for communication-efficient federated learning. arXiv preprint 2021. https://arxiv.org/abs/2102.01593
  39. Yang Q, Liu Y, Chen T, Tong Y. Federated machine learning. ACM Trans Intell Syst Technol. 2019;10(2):1–19.
  40. Zhang C, Xie Y, Bai H, Yu B, Li W, Gao Y. A survey on federated learning. Knowl-Based Syst. 2021;216:106775.
  41. Wang T, Zheng Z, Lin F. Federated learning framework based on trimmed mean aggregation rules. 2022. https://ssrn.com/abstract=4181353
  42. Reyes J, Di Jorio L, Low-Kam C, Kersten-Oertel M. Precision-weighted federated learning. arXiv preprint 2021. https://arxiv.org/abs/2107.09627
  43. Liu L, Zhang J, Song S, Letaief KB. Hierarchical quantized federated learning: convergence analysis and system design. arXiv preprint 2021. https://arxiv.org/abs/2103.14272
  44. Xie C, Koyejo S, Gupta I. Asynchronous federated optimization. arXiv preprint 2019. https://arxiv.org/abs/1903.03934
  45. Chen Y, Ning Y, Slawski M, Rangwala H. Asynchronous online federated learning for edge devices with non-IID data. In: 2020 IEEE International Conference on Big Data (Big Data). 2020. p. 15–24. https://doi.org/10.1109/bigdata50022.2020.9378161
  46. Damaskinos G, Guerraoui R, Kermarrec A-M, Nitu V, Patra R, Taiani F. FLeet: online federated learning via staleness awareness and performance prediction. ACM Trans Intell Syst Technol. 2022;13(5):1–30.
  47. Sprague MR, Jalalirad A, Scavuzzo M, Capota C, Neun M, Do L, Kopp M. Asynchronous federated learning for geospatial applications. In: Joint European Conference on Machine Learning and Knowledge Discovery in Databases. Springer; 2018. p. 21–8.
  48. Wu W, He L, Lin W, Mao R, Maple C, Jarvis S. SAFA: a semi-asynchronous protocol for fast federated learning with low overhead. IEEE Trans Comput. 2021;70(5):655–68.
  49. Chai Z, Chen Y, Zhao L, Cheng Y, Rangwala H. FedAT: a communication-efficient federated learning method with asynchronous tiers under non-IID data. arXiv preprint 2020.
  50. Dinh CT, Tran NH, Nguyen MNH, Hong CS, Bao W, Zomaya AY, et al. Federated learning over wireless networks: convergence analysis and resource allocation. IEEE/ACM Trans Netw. 2021;29(1):398–409.
  51. Fu L, Zhang H, Gao G, Zhang M, Liu X. Client selection in federated learning: principles, challenges, and opportunities. IEEE Internet Things J. 2023.
  52. Jin Y, Jiao L, Qian Z, Zhang S, Lu S, Wang X. Resource-efficient and convergence-preserving online participant selection in federated learning. In: 2020 IEEE 40th International Conference on Distributed Computing Systems (ICDCS), 2020. p. 606–16.
  53. Ribero M, Vikalo H, de Veciana G. Federated learning under intermittent client availability and time-varying communication constraints. IEEE J Sel Top Signal Process. 2023;17(1):98–111.
  54. Luo B, Xiao W, Wang S, Huang J, Tassiulas L. Tackling system and statistical heterogeneity for federated learning with adaptive client sampling. In: IEEE INFOCOM 2022 - IEEE Conference on Computer Communications, 2022. p. 1739–48.
  55. Marnissi O, Hammouti HE, Bergou EH. Client selection in federated learning based on gradients importance. In: AIP Conference Proceedings, 2024.
  56. Li Q, Li X, Zhou L, Yan X. AdaFL: adaptive client selection and dynamic contribution evaluation for efficient federated learning. In: ICASSP 2024 - 2024 IEEE International Conference on Acoustics, Speech and Signal Processing (ICASSP). 2024. p. 6645–9.
  57. Sun Y, Li B, Yang K, Bi X, Zhao X. TiFLCS-MARP: client selection and model pricing for federated learning in data markets. Exp Syst Appl. 2024;245:123071.
  58. Eiben AE, Smith JE. Introduction to evolutionary computing. Springer; 2015.