Federated learning is a method of training AI models across multiple separate devices or servers, each of which retains its data locally. 

It allows models to learn from a wide range of sources without moving the underlying data, which helps protect privacy and reduce security risks. For this reason, it is often grouped with the broader class of privacy-enhancing technologies and is considered essential for Private AI. Tools that support this approach include TensorFlow Federated, PySyft, and Flower.

It is part of the broader field of machine learning, which trains neural networks and other algorithms to recognize patterns in data and make predictions. Traditional methods gather all data in one place for training. Federated learning improves on this by keeping data in place, making it well-suited to sectors such as finance or healthcare, where data sharing is restricted.

What are the different types of federated learning? 

Federated learning can be categorized into five types, based on how data is distributed and who the participants are. Each type supports a different collaboration scenario while keeping data private; the sketch after this list illustrates the two most common data splits.

  • Horizontal federated learning (HFL): A method used when participants have the same type of data but different users, for example, banks operating in different regions.
  • Vertical federated learning (VFL): Applied when organizations share users but possess different types of data about them, such as a hospital and an insurance provider collaborating.
  • Federated transfer learning (FTL): Used when organizations have minimal overlap in both users and data. It enables knowledge transfer from one domain to another to train a shared model.
  • Cross-device federated learning: A technique involving many user devices, such as phones or sensors, each contributing small, local datasets to train a model collectively.
  • Cross-silo federated learning: Used by a limited number of trusted institutions, such as hospitals or retailers, that securely collaborate to update a shared model over time.
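
To make the first two categories concrete, here is a minimal sketch of how the data might be shaped in each case. All of the names, feature counts, and sample sizes below are invented for illustration; they are not drawn from any real dataset.

```python
import numpy as np

# Hypothetical example data; every name and number here is invented.
rng = np.random.default_rng(seed=0)

# Horizontal (HFL): two banks hold the SAME features for DIFFERENT customers.
features = ["age", "balance", "tenure"]
bank_a = rng.normal(size=(1000, len(features)))  # 1,000 customers in region A
bank_b = rng.normal(size=(1500, len(features)))  # 1,500 other customers in region B

# Vertical (VFL): a hospital and an insurer hold DIFFERENT features
# about the SAME individuals, matched on a shared identifier.
shared_ids = np.arange(800)
hospital_records = rng.normal(size=(800, 5))  # e.g., clinical measurements
insurer_records = rng.normal(size=(800, 2))   # e.g., claims features
```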

How does federated learning work?

Federated learning works by coordinating a cycle of local training and central aggregation. Each round brings the model closer to understanding patterns across data sources, without ever pooling the raw data.

Here is an overview:

Initializing the global model

The process begins with a central team, such as one inside a hospital or financial institution, creating an initial version of the model. That model is then sent to participating systems, along with configuration instructions for training.
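
As a rough sketch, the initialization step might look like the following, using a simple linear model. The function and configuration names here (initialize_global_model, training_config) are hypothetical and not tied to any particular framework.

```python
import numpy as np

rng = np.random.default_rng(seed=0)

def initialize_global_model(num_features: int) -> dict:
    """Create the initial global model: a linear model represented
    as a weight vector plus a bias term."""
    return {
        "weights": rng.normal(scale=0.01, size=num_features),
        "bias": 0.0,
    }

# Configuration instructions sent to participants along with the model.
training_config = {"local_epochs": 1, "learning_rate": 0.1, "batch_size": 32}

global_model = initialize_global_model(num_features=10)
```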

Training on local data

Each participant trains the model using only its own data. For example, a retailer may use customer purchase records, while a healthcare provider uses patient data. The data never leaves the local system, which helps protect privacy.
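
Continuing the sketch, a participant's local training step might run a few gradient-descent updates on its own records. The helper below assumes the linear model from the previous step; X and y stand in for the private local dataset and never leave this function.

```python
import numpy as np

def local_train(model: dict, X: np.ndarray, y: np.ndarray,
                learning_rate: float = 0.1, epochs: int = 1) -> dict:
    """Train the received model on local data only; the raw
    X and y never leave the participant's own system."""
    w, b = model["weights"].copy(), model["bias"]
    for _ in range(epochs):
        preds = X @ w + b                            # local predictions
        error = preds - y                            # residuals on local records
        w -= learning_rate * (X.T @ error) / len(y)  # gradient step on weights
        b -= learning_rate * error.mean()            # gradient step on bias
    return {"weights": w, "bias": b}
```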

Creating model updates

After local training, each participant generates a model update. This update consists of numerical adjustments to the model’s internal settings, also known as parameters. The update captures learning outcomes, not the raw data.
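
In the running sketch, such an update can be as simple as the difference between the locally trained parameters and the global parameters that were received (compute_update is a hypothetical helper, not a library function):

```python
def compute_update(global_model: dict, local_model: dict) -> dict:
    """The update is a set of numerical parameter adjustments;
    it captures learning outcomes, not the raw data."""
    return {
        "weights": local_model["weights"] - global_model["weights"],
        "bias": local_model["bias"] - global_model["bias"],
    }
```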

Aggregating updates centrally

The central system gathers all updates and combines them into a new version of the model. One common method is weighted averaging, where larger or more active participants influence the outcome more. Privacy-preserving techniques may be added at this stage, including trusted execution environments (TEEs) to secure the aggregation process.
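
A minimal sketch of weighted averaging, in the style of the widely used FedAvg algorithm, is shown below; sample_counts weights each participant by how much local data it trained on. Protections such as TEEs or secure aggregation would wrap around this step and are omitted here.

```python
def aggregate(global_model: dict, updates: list, sample_counts: list) -> dict:
    """Combine updates by weighted averaging: participants with more
    local samples influence the new global model proportionally more."""
    total = sum(sample_counts)
    avg_w = sum(u["weights"] * n for u, n in zip(updates, sample_counts)) / total
    avg_b = sum(u["bias"] * n for u, n in zip(updates, sample_counts)) / total
    return {
        "weights": global_model["weights"] + avg_w,
        "bias": global_model["bias"] + avg_b,
    }
```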

Iterating to improve performance

Once the updated model has been created, it is shared again with each participant. They repeat the same process: training the model locally, generating an update, and sending it back. Each cycle builds on the last, fine-tuning the model’s ability to make accurate predictions using patterns from many different data sources.
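
Putting the hypothetical helpers from the previous steps together, the full cycle might look like the loop below. The three synthetic datasets stand in for participants' private data; this is an illustrative sketch, not a production protocol.

```python
import numpy as np

rng = np.random.default_rng(seed=1)
# Synthetic stand-ins for three participants' private datasets.
participants = [(rng.normal(size=(40, 10)), rng.normal(size=40))
                for _ in range(3)]

global_model = initialize_global_model(num_features=10)

for round_num in range(5):                    # five federated rounds
    updates, counts = [], []
    for X_local, y_local in participants:     # training stays on each system
        local_model = local_train(global_model, X_local, y_local)
        updates.append(compute_update(global_model, local_model))
        counts.append(len(y_local))
    # Only parameter updates travel; the raw data never moves.
    global_model = aggregate(global_model, updates, counts)
```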

Federated learning use cases 

As industries scale AI adoption, federated learning is gaining traction. IDC predicts 80% of CIOs will prioritize AI and automation by 2028, with federated approaches playing a key role in unlocking insights across silos. 

Below are three examples of how it supports collaboration across sectors that cannot pool data directly.

Collaborative medical model training

Several hospitals can work together to train a diagnostic model using their own medical images. For example, clinics focused on cancer care may contribute training results from local scans to help the model learn to identify tumors. Each hospital trains the model in its own environment and returns updates, allowing the shared system to handle a broader range of patient examples.

Privacy-preserving product recommendations

Retailers can improve online experiences by learning from user interactions on apps or websites without centralizing the data. Each device updates the model based on local browsing or purchase activity. The updates are then used to refine the recommendation system, without recording or transmitting personal histories.

Cross-institutional fraud detection

Fraud tends to follow common patterns across banks, even when the transactions take place in separate systems. Federated learning allows each institution to train the model on its own transaction records. When updates are sent back and combined, the model improves its ability to detect suspicious behavior across multiple organizations, without requiring any bank to share customer data.

FAQs