Practical Secure Aggregation for Federated Learning on User-Held Data

Keith Bonawitz, Vladimir Ivanov, Ben Kreuter, Antonio Marcedone, H. Brendan McMahan, Sarvar Patel, Daniel Ramage, Aaron Segal, Karn Seth

Introduction

Secure Aggregation is a class of Secure Multi-Party Computation algorithms wherein a group of mutually distrustful parties $u \in \mathcal{U}$ each hold a private value $x_u$ and collaborate to compute an aggregate value, such as the sum $\sum_{u\in\mathcal{U}} x_u$, without revealing to one another any information about their private values except what is learnable from the aggregate value itself. In this work, we consider training a deep neural network in the Federated Learning model, using distributed gradient descent across user-held training data on mobile devices, with Secure Aggregation protecting the privacy of each user's model gradient. We identify a combination of efficiency and robustness requirements which, to the best of our knowledge, are unmet by existing algorithms in the literature. We proceed to design a novel, communication-efficient Secure Aggregation protocol for high-dimensional data that tolerates up to $1/3$ of users failing to complete the protocol. For 16-bit input values, our protocol offers $1.73\times$ communication expansion for $2^{10}$ users and $2^{20}$-dimensional vectors, and $1.98\times$ expansion for $2^{14}$ users and $2^{24}$-dimensional vectors.

Secure Aggregation for Federated Learning

Consider training a deep neural network to predict the next word that a user will type as she composes a text message, in order to improve typing accuracy for a phone's on-screen keyboard. A modeler may wish to train such a model on all text messages across a large population of users. However, text messages frequently contain sensitive information; users may be reluctant to upload a copy of them to the modeler's servers. Instead, we consider training such a model in a Federated Learning setting, wherein each user maintains a private database of her text messages securely on her own mobile device, and a shared global model is trained under the coordination of a central server based upon highly processed, minimally scoped, ephemeral updates from users.

In the Federated Learning setting, each user $u \in \mathcal{U}$ holds a private set $D_u$ of training examples with $D = \bigcup_{u\in\mathcal{U}} D_u$. To run stochastic gradient descent, for each update we select data from a random subset $\mathcal{U}' \subset \mathcal{U}$ and form a (virtual) minibatch $B = \bigcup_{u\in\mathcal{U}'} D_u$ (in practice we might have, say, $|\mathcal{U}'| = 10^4$ while $|\mathcal{U}| = 10^7$; we might also consider only a subset of each user's local dataset). The minibatch loss gradient $\nabla\mathcal{L}_f(B,\Theta)$ can be rewritten as a weighted average across users: $\nabla\mathcal{L}_f(B,\Theta) = \frac{1}{|B|}\sum_{u\in\mathcal{U}'}\delta_u^t$, where $\delta_u^t = |D_u|\,\nabla\mathcal{L}_f(D_u,\Theta^t)$. A user can thus share just $\langle |D_u|, \delta_u^t \rangle$ with the server, from which a gradient descent step $\Theta^{t+1} \leftarrow \Theta^t - \eta\frac{\sum_{u\in\mathcal{U}'}\delta_u^t}{\sum_{u\in\mathcal{U}'}|D_u|}$ may be taken (treating the $D_u$ as disjoint, so that $\sum_{u\in\mathcal{U}'}|D_u| = |B|$).
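To make the rewriting concrete, here is a minimal sketch, assuming a toy one-parameter least-squares model (our choice for illustration, not the paper's) and exact rational arithmetic so that equality can be checked exactly. The function names and datasets are hypothetical. It checks that the step taken from the users' $\langle |D_u|, \delta_u^t \rangle$ pairs matches an ordinary minibatch step on $B$.

```python
from fractions import Fraction


def example_grad(theta, example):
    """Gradient of the squared loss 0.5 * (theta * x - y) ** 2 with respect to theta."""
    x, y = example
    return (theta * x - y) * x


def mean_grad(theta, data):
    """nabla L_f(D, theta): the average per-example gradient over dataset D."""
    return sum(example_grad(theta, e) for e in data) / len(data)


def client_update(theta, local_data):
    """What user u shares: <|D_u|, delta_u>, with delta_u = |D_u| * nabla L_f(D_u, theta)."""
    return len(local_data), len(local_data) * mean_grad(theta, local_data)


def server_step(theta, updates, lr):
    """theta <- theta - lr * sum(delta_u) / sum(|D_u|)."""
    total_examples = sum(n for n, _ in updates)
    total_delta = sum(delta for _, delta in updates)
    return theta - lr * total_delta / total_examples


# Hypothetical local datasets of (x, y) pairs; their union is the minibatch B.
users = {
    "alice": [(1, 2), (2, 3)],
    "bob": [(3, 5)],
    "carol": [(1, 1), (4, 7), (2, 2)],
}
theta, lr = Fraction(1, 2), Fraction(1, 10)
updates = [client_update(theta, d) for d in users.values()]
minibatch = [e for d in users.values() for e in d]
# The federated step equals a plain SGD step on B.
assert server_step(theta, updates, lr) == theta - lr * mean_grad(theta, minibatch)
```

Each user ships only a count and one vector (here a scalar), yet the server's step is identical to centralized minibatch SGD.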

Although each update $\langle |D_u|, \delta_u^t \rangle$ is ephemeral and contains less information than the raw $D_u$, a user might still wonder what information remains. There is evidence that a trained neural network's parameters sometimes allow reconstruction of training examples; might the parameter updates be subject to similar attacks? For example, if the input $x$ is a one-hot vocabulary-length vector encoding the most recently typed word, common neural network architectures will contain at least one parameter $\theta_w$ in $\Theta$ for each word $w$ such that $\frac{\partial\mathcal{L}_f}{\partial\theta_w}$ is non-zero only when $x$ encodes $w$. Thus, the set of recently typed words in $D_u$ would be revealed by inspecting the non-zero entries of $\delta_u^t$. The server does not need to inspect any individual user's update, however; it requires only the sums $\sum_{u\in\mathcal{U}}|D_u|$ and $\sum_{u\in\mathcal{U}}\delta_u^t$. Using a Secure Aggregation protocol would ensure that the server learns only that one or more users in $\mathcal{U}$ wrote the word $w$, but not which users.
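The leakage argument can be illustrated with a toy embedding layer. In this sketch (an assumption for illustration), the loss is $\sum_i \langle E[x_i], g \rangle$ for a fixed upstream vector $g$ standing in for the gradient that flows back from the rest of the network; the vocabulary, words, and helper names are hypothetical. Rows of one user's gradient are non-zero exactly for the words she typed, while the sum over users exposes only the union.

```python
VOCAB = ["the", "cat", "sat", "on", "mat", "secret"]  # hypothetical vocabulary


def embedding_grad(words, upstream):
    """Gradient w.r.t. the embedding table E of the toy loss sum_i dot(E[w_i], upstream).

    Each typed word w adds `upstream` to row E[w]; all other rows stay zero.
    """
    grad = {w: [0.0] * len(upstream) for w in VOCAB}
    for w in words:
        grad[w] = [g + u for g, u in zip(grad[w], upstream)]
    return grad


def nonzero_rows(grad):
    """Words whose embedding row received a non-zero gradient."""
    return {w for w, row in grad.items() if any(row)}


upstream = [0.5, -1.0, 0.25]
alice = embedding_grad(["the", "cat", "sat"], upstream)
bob = embedding_grad(["the", "secret"], upstream)
print(sorted(nonzero_rows(bob)))  # ['secret', 'the']: Bob's words are exposed

# The aggregate reveals only the union of typed words, not who typed each.
total = {w: [a + b for a, b in zip(alice[w], bob[w])] for w in VOCAB}
print(sorted(nonzero_rows(total)))  # ['cat', 'sat', 'secret', 'the']
```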

Federated Learning systems face several practical challenges. Mobile devices have only sporadic access to power and network connectivity, so the set $\mathcal{U}$ participating in each update step is unpredictable and the system must be robust to users dropping out. Because $\Theta$ may contain millions of parameters, updates $\delta_u^t$ may be large, representing a direct cost to users on metered network plans. Mobile devices also generally cannot establish direct communication channels with other mobile devices (relying on a server or service provider to mediate such communication), nor can they natively authenticate other mobile devices. Thus, Federated Learning motivates a need for a Secure Aggregation protocol that: (1) operates on high-dimensional vectors, (2) is communication efficient, even with a novel set of users on each instantiation, (3) is robust to users dropping out, and (4) provides the strongest possible security under the constraints of a server-mediated, unauthenticated network model.

A Practical Secure Aggregation Protocol

We consider the following threat models.

Threat Model T1: The server is honest-but-curious; that is, it follows the protocol honestly, but tries to learn as much as possible from the messages it receives from users.

Threat Model T2: The server can lie to users about which other users have dropped out, including reporting dropouts inconsistently among different users.

Threat Model T3: The server can lie about who dropped out (as in T2) and can also access the private memory of some limited number of users (who themselves follow the protocol honestly). In this case, the privacy requirement applies only to the inputs of the remaining users.

We develop our protocol in a series of refinements. We begin by assuming that all parties complete the protocol and possess pairwise secure communication channels with ample bandwidth. Each pair of users first agrees on a matched pair of input perturbations. That is, user $u$ samples a vector $s_{u,v}$ uniformly from $[0,R)^k$ for each other user $v$. Users $u$ and $v$ exchange $s_{u,v}$ and $s_{v,u}$ over their secure channel and compute perturbations $p_{u,v} = s_{u,v} - s_{v,u} \pmod{R}$, noting that $p_{u,v} = -p_{v,u} \pmod{R}$ and taking $p_{u,v} = 0$ when $u = v$. Each user sends to the server $y_u = x_u + \sum_{v\in\mathcal{U}} p_{u,v} \pmod{R}$. The server simply sums the perturbed values: $\bar{x} = \sum_{u\in\mathcal{U}} y_u \pmod{R}$. Correctness is guaranteed because the paired perturbations in $y_u$ cancel:

$$\bar{x} = \sum_{u\in\mathcal{U}} y_u = \sum_{u\in\mathcal{U}} x_u + \sum_{u\in\mathcal{U}}\sum_{v\in\mathcal{U}} p_{u,v} = \sum_{u\in\mathcal{U}} x_u \pmod{R},$$

since each term $p_{u,v}$ is cancelled by $p_{v,u} = -p_{u,v} \pmod{R}$.

Protocol 0 guarantees perfect privacy for the users: because the $s_{u,v}$ values that users add are uniformly sampled, the $y_u$ values appear uniformly random to the server, subject to the constraint that $\bar{x} = \sum_{u\in\mathcal{U}} y_u \pmod{R}$. In fact, even if the server can access the memory of some users, privacy holds for the remaining users. A more complete and formal argument is deferred to the full version of this paper.
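A minimal simulation of Protocol 0 follows, assuming every user finishes and that the pairwise exchange of $s_{u,v}$ over secure channels happens out of band. The function name and toy inputs are hypothetical; Python's `secrets` module stands in for each user's uniform sampling from $[0,R)$.

```python
import secrets


def protocol0(inputs, R):
    """Simulate Protocol 0 (no dropouts, ideal pairwise channels).

    inputs maps each user to a length-k list of ints in [0, R).
    Returns the masked vectors y_u and the server's aggregate.
    """
    users = list(inputs)
    k = len(next(iter(inputs.values())))
    # Each user u samples s_{u,v} uniformly from [0, R)^k for every other user v.
    s = {u: {v: [secrets.randbelow(R) for _ in range(k)] for v in users if v != u}
         for u in users}
    masked = {}
    for u in users:
        y = list(inputs[u])
        for v in users:
            if v == u:
                continue  # p_{u,u} = 0
            # p_{u,v} = s_{u,v} - s_{v,u} (mod R), hence p_{u,v} = -p_{v,u} (mod R).
            y = [(yi + a - b) % R for yi, a, b in zip(y, s[u][v], s[v][u])]
        masked[u] = y
    # The server sees only the y_u and sums them; every pair of masks cancels.
    aggregate = [sum(col) % R for col in zip(*masked.values())]
    return masked, aggregate


R = 2**16
inputs = {"a": [1, 2, 3], "b": [4, 5, 6], "c": [7, 8, 9], "d": [10, 11, 12]}
masked, aggregate = protocol0(inputs, R)
print(aggregate)  # [22, 26, 30]
```

Each individual `masked[u]` looks uniformly random, yet the aggregate is exact; drop any single `y_u` from the sum and the result is garbage, which is the robustness problem Protocol 1 addresses.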

Protocol 1: Dropped User Recovery using Secret Sharing

Unfortunately, Protocol 0 fails several of our design criteria, including robustness: if any user $u$ fails to complete the protocol by sending her $y_u$ to the server, the resulting sum will be masked by the perturbations that $y_u$ would have cancelled. To achieve robustness, we first add an initial round to the protocol in which user $u$ generates a public/private keypair, and broadcasts the public key over the pairwise channels. All future messages from $u$ to $v$ will be intermediated by the server but encrypted with $v$'s public key, and signed by $u$, simulating a secure authenticated channel. This allows the server to maintain a consistent view of which users have successfully passed each round of the protocol. (We assume here, temporarily, that the server faithfully delivers all messages between users.)

We also add a secret-sharing round between users after the $s_{u,v}$ values have been selected. In this round, each user computes $n$ shares of each perturbation $p_{u,v}$ using a $(t,n)$-threshold secret-sharing scheme, such as Shamir's Secret Sharing, for some $t > \frac{n}{2}$. (A $(t,n)$ secret-sharing scheme splits a secret into $n$ shares such that any subset of $t$ shares is sufficient to recover the secret, but given any subset of fewer than $t$ shares the secret remains completely hidden.) For each secret user $u$ holds, she encrypts one share with each user $v$'s public key, then delivers all of these shares to the server. The server gathers shares from a subset of the users $\mathcal{U}_1 \subseteq \mathcal{U}$ of size at least $t$ (e.g., by waiting for a fixed period), then considers all other users dropped. The server delivers to each user $v \in \mathcal{U}_1$ the secret shares that were encrypted for that user; all the users in $\mathcal{U}_1$ now infer a consistent view of the surviving user set $\mathcal{U}_1$ from the set of received shares. When a user computes $y_u$, she includes only those perturbations related to surviving users; that is, $y_u = x_u + \sum_{v\in\mathcal{U}_1} p_{u,v} \pmod{R}$.
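Shamir's scheme, which the text names, can be sketched over a prime field as follows. The field modulus $2^{127}-1$ (a Mersenne prime) and the function names are illustrative assumptions; a deployment would pick the field to fit the secrets being shared and use authenticated, encrypted delivery of shares as described above.

```python
import secrets

PRIME = 2**127 - 1  # Mersenne prime; field size chosen for illustration only


def split(secret, t, n, prime=PRIME):
    """Split `secret` into n shares (x, f(x)); any t of them reconstruct it."""
    if not 0 < t <= n < prime:
        raise ValueError("need 0 < t <= n < prime")
    # Random polynomial f of degree t - 1 with f(0) = secret.
    coeffs = [secret % prime] + [secrets.randbelow(prime) for _ in range(t - 1)]

    def f(x):
        acc = 0
        for c in reversed(coeffs):  # Horner's rule
            acc = (acc * x + c) % prime
        return acc

    return [(x, f(x)) for x in range(1, n + 1)]


def reconstruct(shares, prime=PRIME):
    """Lagrange interpolation at x = 0 from at least t shares with distinct x."""
    secret = 0
    for i, (xi, yi) in enumerate(shares):
        num, den = 1, 1
        for j, (xj, _) in enumerate(shares):
            if i != j:
                num = num * (-xj) % prime
                den = den * (xi - xj) % prime
        secret = (secret + yi * num * pow(den, -1, prime)) % prime
    return secret


n = 5
t = n // 2 + 1                   # t > n/2, as Protocol 1 requires
shares = split(123456789, t, n)
print(reconstruct(shares[:t]))   # 123456789
print(reconstruct(shares[2:]))   # 123456789 (any t shares work)
```

With fewer than $t$ shares, every candidate secret is equally consistent with what was seen, which is the "completely hidden" property the protocol relies on.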

After the server has received $y_u$ from at least $t$ users $\mathcal{U}_2 \subseteq \mathcal{U}_1$, it proceeds to a new unmasking round, considering all other users to be dropped. From the remaining users in $\mathcal{U}_2$, the server requests all shares of secrets generated by the dropped users in $\mathcal{U}_1 \setminus \mathcal{U}_2$. As long as $|\mathcal{U}_2| > t$, each user will respond with those shares. Once the server receives shares from at least $t$ users, it reconstructs the perturbations for $\mathcal{U}_1 \setminus \mathcal{U}_2$ and computes the aggregate value: $\bar{x} = \sum_{u\in\mathcal{U}_2} y_u - \sum_{u\in\mathcal{U}_2}\sum_{v\in\mathcal{U}_1\setminus\mathcal{U}_2} p_{u,v} \pmod{R}$. Correctness is guaranteed for $\bar{\mathcal{U}} = \mathcal{U}_2$ as long as at least $t$ users complete the protocol. In this case, the sum $\bar{x}$ includes the values of at least $t > \frac{n}{2}$ users, and all perturbations cancel out:

$$\bar{x} = \sum_{u\in\mathcal{U}_2} x_u + \sum_{u\in\mathcal{U}_2}\sum_{v\in\mathcal{U}_1} p_{u,v} - \sum_{u\in\mathcal{U}_2}\sum_{v\in\mathcal{U}_1\setminus\mathcal{U}_2} p_{u,v} = \sum_{u\in\mathcal{U}_2} x_u + \sum_{u\in\mathcal{U}_2}\sum_{v\in\mathcal{U}_2} p_{u,v} = \sum_{u\in\mathcal{U}_2} x_u \pmod{R}.$$

However, security has been lost: if the server incorrectly omits $u$ from $\mathcal{U}_2$, either inadvertently (e.g., $y_u$ arrives slightly too late) or by malicious intent, the honest users in $\mathcal{U}_2$ will supply the server with all the secret shares needed to remove all the perturbations that masked $x_u$ in $y_u$. This means we cannot guarantee security even against honest-but-curious servers (Threat Model T1).
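The recovery computation and the security failure just described can be checked with a small simulation. To keep it short, the sketch assumes the secret-sharing step succeeds and hands the server the reconstructed perturbations directly; all helper names and inputs are hypothetical.

```python
import secrets


def vadd(a, b, R):
    return [(x + y) % R for x, y in zip(a, b)]


def vsub(a, b, R):
    return [(x - y) % R for x, y in zip(a, b)]


def pairwise_perturbations(users, R, k):
    """p[u][v] = s_{u,v} - s_{v,u} (mod R); p[u][u] = 0, so p[u][v] = -p[v][u] (mod R)."""
    s = {u: {v: [secrets.randbelow(R) for _ in range(k)] for v in users} for u in users}
    return {u: {v: [0] * k if u == v else vsub(s[u][v], s[v][u], R) for v in users}
            for u in users}


def mask(x_u, u, U1, p, R):
    """y_u = x_u + sum_{v in U1} p_{u,v} (mod R)."""
    y = list(x_u)
    for v in U1:
        y = vadd(y, p[u][v], R)
    return y


def unmask_sum(ys, U1, U2, p, R, k):
    """x_bar = sum_{u in U2} y_u - sum_{u in U2} sum_{v in U1 - U2} p_{u,v} (mod R).

    Assumes the server has already reconstructed, from t shares each,
    the perturbations of every dropped user in U1 - U2.
    """
    total = [0] * k
    for u in U2:
        total = vadd(total, ys[u], R)
        for v in U1 - U2:
            total = vsub(total, p[u][v], R)
    return total


R, k = 2**16, 3
x = {"a": [1, 2, 3], "b": [4, 5, 6], "c": [7, 8, 9], "d": [10, 11, 12]}
U1 = set(x)                       # users that completed the secret-sharing round
p = pairwise_perturbations(sorted(U1), R, k)
U2 = U1 - {"d"}                   # "d" drops out before sending y_d
ys = {u: mask(x[u], u, U1, p, R) for u in U2}
print(unmask_sum(ys, U1, U2, p, R, k))  # [12, 15, 18]

# Failure mode: the server falsely treats "c" as dropped, obtains shares of all
# of c's perturbations from honest users, and strips them from y_c.
recovered = list(ys["c"])
for v in U1:
    recovered = vsub(recovered, p["c"][v], R)
print(recovered == x["c"])  # True: x_c is exposed
```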

Protocol 2: Double-Masking to Thwart a Malicious Server

To guarantee security, we introduce a double-masking structure that protects $x_u$ even when the server can reconstruct $u$'s perturbations. First, each user $u$ samples an additional random value $b_u$ uniformly from $[0,R)^k$ during the same round as the generation of the $s_{u,v}$ values. During the secret sharing round, the user also generates and distributes shares of $b_u$ to each of the other users. When generating $y_u$, users also add this secondary mask: $y_u = x_u + b_u + \sum_{v\in\mathcal{U}_1} p_{u,v} \pmod{R}$. During the unmasking round, the server must make an explicit choice with respect to each user $u \in \mathcal{U}_1$: from each surviving member $v \in \mathcal{U}_2$, the server can request either a share of the $p_{u,v}$ perturbations associated with $u$ or a share of the $b_u$ for $u$; an honest user $v$ will only respond if $|\mathcal{U}_2| > t$, and will never reveal both kinds of shares for the same user. After gathering at least $t$ shares of $p_{u,v}$ for all $u \in \mathcal{U}_1 \setminus \mathcal{U}_2$ and $t$ shares of $b_u$ for all $u \in \mathcal{U}_2$, the server reconstructs the secrets and computes the aggregate value: $\bar{x} = \sum_{u\in\mathcal{U}_2} y_u - \sum_{u\in\mathcal{U}_2} b_u - \sum_{u\in\mathcal{U}_2}\sum_{v\in\mathcal{U}_1\setminus\mathcal{U}_2} p_{u,v} \pmod{R}$.

We can now guarantee security in Threat Model T1 for $t > \frac{n}{2}$, since $x_u$ always remains masked by either the $p_{u,v}$ or by $b_u$. It can be shown that in Threat Models T2 and T3 the thresholds must be raised to $\frac{2n}{3}$ and $\frac{4n}{5}$, respectively. We defer the detailed analysis, as well as the case of arbitrarily malicious and colluding servers and users, to the full version. (The security argument involves bounding the number of shares the server can recover by forging dropouts.)

Protocol 3: Exchanging Secrets Efficiently

While Protocol 2 is robust and secure with the right choice of $t$, it requires $O(kn^2)$ communication, which we address in this refinement of the protocol. Observe that a single secret value may be expanded to a vector of pseudorandom values by using it to seed a cryptographically secure pseudorandom generator (PRG). Thus we can generate just scalar seeds $s_{u,v}$ and $b_u$ and expand them to $k$-element vectors. Still, each user has $(n-1)$ secrets $s_{u,v}$ with other users and must publish shares of all these secrets. We use key agreement to establish these secrets more efficiently. Each user generates a Diffie-Hellman secret key $s^{SK}$ and public key $s^{PK}$. Users send their public keys to the server (authenticated as per Protocol 1); the server then broadcasts all public keys to all users, retaining a copy for itself. Each pair of users $u, v$ can now agree on a secret $s_{u,v} = s_{v,u} = \operatorname{\textsc{Agree}}(s^{SK}_u, s^{PK}_v) = \operatorname{\textsc{Agree}}(s^{SK}_v, s^{PK}_u)$. To construct perturbations, we assume a total ordering on $\mathcal{U}$ and take $p_{u,v} = \operatorname{\textsc{PRG}}(s_{u,v})$ for $u < v$, $p_{u,v} = -\operatorname{\textsc{PRG}}(s_{u,v})$ for $u > v$, and $p_{u,v} = 0$ for $u = v$ (as before). The server now only needs to learn $s^{SK}_u$ to reconstruct all of $u$'s perturbations; therefore $u$ need only distribute shares of $s^{SK}_u$ and $b_u$ during the secret sharing round. The security of Protocol 3 can be shown to be essentially identical to that of Protocol 2 in each of the different threat models.
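The key-agreement and seed-expansion steps can be sketched with toy parameters. Everything concrete here is an assumption for illustration: finite-field Diffie-Hellman modulo $2^{127}-1$ with base 3 (far too weak for real use), and counter-mode SHA-256 from Python's `hashlib` as a stand-in PRG. The paper does not fix these choices, and the helper names are hypothetical. The sketch only shows that $\textsc{Agree}$ is symmetric and that the signed PRG streams cancel in the sum.

```python
import hashlib
import secrets

# Toy Diffie-Hellman parameters: illustration only, NOT secure choices.
P = 2**127 - 1
G = 3


def keygen():
    """Return (s^SK, s^PK) with s^PK = G^{s^SK} mod P."""
    sk = secrets.randbelow(P - 2) + 1
    return sk, pow(G, sk, P)


def agree(sk, other_pk):
    """Agree(s_u^SK, s_v^PK) = G^{s_u s_v} mod P, symmetric in u and v."""
    return pow(other_pk, sk, P)


def prg(seed, k, R):
    """Expand an integer seed to k values in [0, R) via counter-mode SHA-256.

    With R a power of two, reducing a 256-bit digest mod R introduces no bias.
    """
    seed_bytes = seed.to_bytes(16, "big")
    return [int.from_bytes(hashlib.sha256(seed_bytes + i.to_bytes(8, "big")).digest(), "big") % R
            for i in range(k)]


def perturbation(u, v, sk_u, pk_v, k, R):
    """p_{u,v} = PRG(s_{u,v}) if u < v, -PRG(s_{u,v}) if u > v, and 0 if u == v."""
    if u == v:
        return [0] * k
    stream = prg(agree(sk_u, pk_v), k, R)
    return stream if u < v else [(-e) % R for e in stream]


R, k = 2**16, 5
users = [1, 2, 3, 4]                    # a total ordering on U
keys = {u: keygen() for u in users}     # u -> (s_u^SK, s_u^PK)
x = {u: [u * 10 + i for i in range(k)] for u in users}
y = {}
for u in users:
    y_u = list(x[u])
    for v in users:
        p_uv = perturbation(u, v, keys[u][0], keys[v][1], k, R)
        y_u = [(a + c) % R for a, c in zip(y_u, p_uv)]
    y[u] = y_u
total = [sum(col) % R for col in zip(*y.values())]
print(total)  # [100, 104, 108, 112, 116]
```

Because the server retains every $s^{PK}_v$, learning a dropped user's single scalar $s^{SK}_u$ lets it call `perturbation(u, v, ...)` for every $v$, which is why only $s^{SK}_u$ and $b_u$ need to be secret-shared.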

Protocol 4: Minimizing Trust in Practice

Protocol 3 is not practically deployable for mobile devices because they lack pairwise secure communication and authentication. We propose to bootstrap the communication protocol by replacing the exchange of public/private keys described in Protocol 1 with a server-mediated key agreement, where each user generates a Diffie-Hellman secret key $c^{SK}$ and public key $c^{PK}$ and advertises the latter together with $s^{PK}$. (This can be viewed as bootstrapping an SSL/TLS connection between each pair of users.) We note immediately that the server may now conduct man-in-the-middle attacks, but argue that this is tolerable for several reasons. First, it is essentially inevitable for users that lack authentication mechanisms or a pre-existing public-key infrastructure. Relying only on the non-maliciousness of the bootstrapping round also constitutes minimization of trust: the code implementing this stage is small and could be publicly audited, outsourced to a trusted third party, or implemented via a trusted compute platform offering a remote attestation capability. Moreover, the protocol meaningfully increases security (by protecting against anything less than an actively malicious attack by the server) and provides forward secrecy (compromising the server at any time after the key exchange provides no benefit to the attacker, even if all data and communications had been fully logged).

We summarize the protocol's performance in Table 3. Assuming that key agreement public keys and encrypted secret shares are 256 bits, and that users' inputs all lie in the same range $[0, R_U - 1]$ (taking $R = n(R_U - 1) + 1$ to ensure no overflow), each user transfers

$$\frac{256(7n-4) + k\left\lceil \log_2\left(n(R_U-1)+1\right)\right\rceil + n}{k\left\lceil \log_2 R_U \right\rceil}$$

times as much data as if she sent a raw vector.
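The expansion formula can be evaluated directly. This sketch (hypothetical function names) plugs in the 16-bit setting from the introduction, $R_U = 2^{16}$, and reproduces the quoted $1.73\times$ and $1.98\times$ figures; it computes $\lceil \log_2 m \rceil$ with integer arithmetic to avoid floating-point edge cases.

```python
def ceil_log2(m):
    """ceil(log2(m)) for an integer m >= 1, computed exactly."""
    return (m - 1).bit_length()


def expansion_factor(n, k, R_U, key_bits=256):
    """One user's traffic divided by the size of her raw k-vector with entries in [0, R_U - 1]."""
    R = n * (R_U - 1) + 1  # modulus large enough that the sum of n inputs cannot overflow
    protocol_bits = key_bits * (7 * n - 4) + k * ceil_log2(R) + n
    raw_bits = k * ceil_log2(R_U)
    return protocol_bits / raw_bits


print(round(expansion_factor(2**10, 2**20, 2**16), 2))  # 1.73
print(round(expansion_factor(2**14, 2**24, 2**16), 2))  # 1.98
```

For large $k$ the ratio approaches $\lceil \log_2 R \rceil / \lceil \log_2 R_U \rceil$, so the dominant cost is the extra $\lceil \log_2 n \rceil$ bits per entry needed to hold the sum without overflow.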

Related work

The restricted case of secure aggregation in which all users but one have an input of 0 can be expressed as a dining cryptographers network (DC-net), which provides anonymity by using pairwise blinding of inputs, allowing each user's input to be learned untraceably. Recent research has examined the communication efficiency of DC-nets and their operation in the presence of malicious users. However, if even one user aborts too early, existing protocols must restart from scratch, which can be very expensive. Pairwise blinding in a modulo-addition-based encryption scheme has been explored, but existing schemes are neither efficient for vectors nor robust to even a single failure. Other schemes (e.g., those based on the Paillier cryptosystem) are very computationally expensive.

References