
Primary Submission Category: Machine Learning and Causal Inference

Missing Not At Random Data In Federated Learning Systems

Authors: David Goetze, Rohit Bhattacharya, Jeannie Albrecht

Presenting Author: David Goetze

Federated learning is a technique for training machine learning models on multiple datasets held in local nodes without exchanging information between these nodes. Often these nodes correspond to individual users operating their own devices. A central server sends a model, e.g., an autocorrect model, to all devices and intermittently pings them for model usage and accuracy. Each device can choose to respond, sending back its loss or gradients, which the central server uses to train and push an updated model back to all users. Thus, federated learning enables data privacy while building a shared model across users.
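For concreteness, here is a minimal sketch of one synchronous round of this protocol; the helper names (`local_gradient`, `server_round`) and the least-squares loss are illustrative assumptions, not part of any particular system described here.

```python
# Minimal sketch of one synchronous federated SGD round.
# Names and the least-squares loss are illustrative assumptions.
import numpy as np

def local_gradient(model, data):
    """Each device computes the gradient of its loss on private data.
    Placeholder loss: mean squared error on a local (X, y) dataset."""
    X, y = data
    return 2 * X.T @ (X @ model - y) / len(y)

def server_round(model, devices, lr=0.01):
    """Server broadcasts the model, collects gradients from the devices
    that choose to respond (None marks a non-responder), averages them,
    and applies one gradient step."""
    grads = [local_gradient(model, d) for d in devices if d is not None]
    if not grads:
        return model  # no participants this round; model is unchanged
    return model - lr * np.mean(grads, axis=0)
```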

Our focus is on federated learning for models trained with stochastic gradient descent, e.g., deep neural nets, when not all users choose to participate. Depending on why certain users do not share their data, the resulting updates can cause the model to perform poorly on the general population (e.g., poor autocorrection for younger users if these users disproportionately choose not to participate). We examine missing data in federated learning through the lens of causal inference. In particular, we propose a reweighted version of stochastic gradient descent that uses the propensity score of missingness to produce unbiased gradient estimates under Missing Not At Random assumptions. Finally, we present FLeWM, a federated learning system built to test our technique, and verify our results empirically on both simulated and real-world data.
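The reweighting idea can be sketched as follows: if device i responds with probability pi_i (its propensity of being observed), then scaling each returned gradient by 1/pi_i yields a Horvitz-Thompson-style estimate of the population-average gradient, since E[(R_i / pi_i) g_i] = E[g_i] when R_i is the response indicator. The sketch below assumes the propensities have already been estimated; the function name and interface are hypothetical, not FLeWM's actual API.

```python
# Hypothetical sketch of propensity-reweighted gradient aggregation.
# The propensities pi[i] would in practice be estimated from a model
# of non-response; here they are assumed given.
import numpy as np

def reweighted_update(model, responses, propensities, n_devices, lr=0.01):
    """responses: list of (device_index, gradient) from responders.
    propensities: propensities[i] = P(device i responds).
    Weighting each gradient by 1 / propensities[i] makes the average an
    unbiased (Horvitz-Thompson) estimate of the gradient that all
    n_devices would have contributed under full participation."""
    total = np.zeros_like(model)
    for i, grad in responses:
        total += grad / propensities[i]
    return model - lr * total / n_devices
```

Under Missing Completely At Random participation all propensities are equal and this reduces to the plain average in the earlier sketch; the reweighting matters precisely when participation depends on user characteristics.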