Adaptive Communication Strategies to Achieve the Best Error-Runtime Trade-off in Local-Update SGD

Jianyu Wang, Gauri Joshi

Introduction

Stochastic gradient descent (SGD) is the backbone of state-of-the-art supervised learning, which is revolutionizing inference and decision-making in many diverse applications. Classical SGD was designed to be run on a single computing node, and its error-convergence with respect to the number of iterations has been extensively analyzed and improved via accelerated SGD methods. Due to the massive training data-sets and neural network architectures used today, it has become imperative to design distributed SGD implementations, where gradient computation and aggregation are parallelized across multiple worker nodes. Although parallelism boosts the amount of data processed per iteration, it exposes SGD to unpredictable node slowdown and communication delays stemming from variability in the computing infrastructure. Thus, there is a critical need to make distributed SGD fast, yet robust to system variability.

Need to Optimize Convergence in terms of Error versus Wall-clock Time. The convergence speed of distributed SGD is a product of two factors: 1) the error in the trained model versus the number of iterations, and 2) the number of iterations completed per second. Traditional single-node SGD analysis focuses on optimizing the first factor, because the second factor is generally a constant when SGD is run on a single dedicated server. In distributed SGD, which is often run on shared cloud infrastructure, the second factor depends on several aspects such as the number of worker nodes, their local computation and communication delays, and the protocol (synchronous, asynchronous or periodic) used to aggregate their gradients. Hence, in order to achieve the fastest convergence speed we need: 1) optimization techniques (e.g., variable learning rate) to maximize the error-convergence rate with respect to iterations, and 2) scheduling techniques (e.g., straggler mitigation, infrequent communication) to maximize the number of iterations completed per second. These directions are inter-dependent and need to be explored together rather than in isolation. While many works have advanced the first direction, the second is less explored from a theoretical point of view, and the juxtaposition of both is an unexplored problem.

Local-Update SGD to Reduce Communication Delays. A popular distributed SGD implementation is the parameter server framework Dean et al. (2012); Cui et al. (2014); Li et al. (2014); Gupta et al. (2016); Mitliagkas et al. (2016) where in each iteration, worker nodes compute gradients on one mini-batch of data and a central parameter server aggregates these gradients (synchronously or asynchronously) and updates the parameter vector $\mathbf{x}$. The constant communication between the parameter server and worker nodes in each iteration can be expensive and slow in bandwidth-limited computing environments. Recently proposed distributed SGD frameworks such as Elastic-averaging Zhang et al. (2015); Chaudhari et al. (2017), Federated Learning McMahan et al. (2016); Smith et al. (2017b) and decentralized SGD Lian et al. (2017); Jiang et al. (2017) save this communication cost by allowing worker nodes to perform local updates to the parameter $\mathbf{x}$ instead of just computing gradients. The resulting locally trained models (which are different due to variability in training data across nodes) are periodically averaged through a central server, or via direct inter-worker communication. This local-update strategy has been shown to offer significant speedup in deep neural network training Lian et al. (2017); McMahan et al. (2016).

Error-Runtime Trade-offs in Local-Update SGD. While local updates reduce the communication delay incurred per iteration, discrepancies between local models can result in inferior error-convergence. For example, consider the case of periodic-averaging SGD (PASGD) where each of $m$ worker nodes makes $\tau$ local updates, and the resulting models are averaged after every $\tau$ iterations Moritz et al. (2015); Su & Chen (2015); Chen & Huo (2016); Seide & Agarwal (2016); Zhang et al. (2016); Zhou & Cong (2017); Lin et al. (2018). A larger value of $\tau$ leads to slower convergence with respect to the number of iterations as illustrated in Figure 1. However, if we look at the true convergence with respect to the wall-clock time, then a larger $\tau$, that is, less frequent averaging, saves communication delay and reduces the runtime per iteration. While some recent theoretical works Zhou & Cong (2017); Yu et al. (2018); Wang & Joshi (2018); Stich (2018) study this dependence of the error-convergence with respect to the number of iterations as $\tau$ varies, achieving a provably-optimal speed-up in the true convergence with respect to wall-clock time is an open problem that we aim to address in this work.

Need for Adaptive Communication Strategies. In the error-runtime plot in Figure 1, we observe a trade-off between the convergence speed and the error floor when the number of local updates $\tau$ is varied. A larger $\tau$ gives a faster initial drop in the training loss but results in a higher error floor. This calls for adaptive communication strategies that start with a larger $\tau$ and gradually decrease it as the model gets closer to convergence. Such an adaptive strategy offers a win-win in the error-runtime trade-off by achieving fast convergence as well as a low error floor. To the best of our knowledge, this is the first work to propose an adaptive communication frequency strategy.

Main Contributions. This paper focuses on periodic-averaging local-update SGD (PASGD) and makes the following main contributions:

We first analyze the runtime per iteration of periodic-averaging SGD (PASGD) by modeling local computing time and communication delays as random variables, and quantify its runtime speed-up over fully synchronous SGD. A novel insight from this analysis is that the periodic-averaging strategy not only reduces the communication delay but also mitigates synchronization delays in waiting for slow or straggling nodes.

By combining the runtime analysis with the error-convergence analysis of PASGD Wang & Joshi (2018), we obtain the error-runtime trade-off for different values of $\tau$. Using this combined error-runtime trade-off, we derive an expression for the optimal communication period, which can serve as a useful guideline in practice.

Based on the observations in the runtime and convergence analysis, we develop an adaptive communication scheme: AdaComm. Experiments on training VGG-16 and ResNet-50 deep neural networks in different settings (with/without momentum, fixed/decaying learning rate) show that AdaComm can give a $3\times$ runtime speed-up and still reach the same low training loss as fully synchronous SGD.

We present a convergence analysis for PASGD with variable communication period $\tau$ and variable learning rate $\eta$, generalizing previous work Wang & Joshi (2018). This analysis shows that decaying $\tau$ provides similar convergence benefits as decaying the learning rate, the difference being that varying $\tau$ improves the true convergence with respect to the wall-clock time. Adaptive communication can also be used in conjunction with existing learning rate schedules.

Although we focus on periodic simple-averaging of local models, the insights on error-runtime trade-offs and adaptive communication strategies are directly extendable to other communication-efficient SGD algorithms including Federated Learning McMahan et al. (2016), Elastic-Averaging Zhang et al. (2015) and Decentralized averaging Jiang et al. (2017); Lian et al. (2017), as well as synchronous/asynchronous distributed SGD with a central parameter server Dean et al. (2012); Cui et al. (2014); Dutta et al. (2018).

Problem Framework

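Our goal is to minimize an objective function $F(\mathbf{x})$ defined over a training data-set $S$:

$$F(\mathbf{x}) := \frac{1}{|S|}\sum_{i=1}^{|S|} f(\mathbf{x};s_{i}),$$

where $f(\mathbf{x};s_{i})$ is the composite loss function at the $i^{\text{th}}$ data point. In classic mini-batch stochastic gradient descent (SGD) Dekel et al. (2012), updates to the parameter vector $\mathbf{x}$ are performed as follows. If $\xi_{k}\subset S$ represents a randomly sampled mini-batch, then the update rule is

$$\mathbf{x}_{k+1} = \mathbf{x}_{k} - \eta\, g(\mathbf{x}_{k};\xi_{k}), \tag{2}$$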

where $\eta$ denotes the learning rate and the stochastic gradient is defined as $g(\mathbf{x};\xi)=\frac{1}{|\xi|}\sum_{s_{i}\in\xi}\nabla f(\mathbf{x};s_{i})$. For simplicity, we will use $g(\mathbf{x}_{k})$ instead of $g(\mathbf{x}_{k};\xi_{k})$ in the rest of the paper. A complete review of convergence properties of serial SGD can be found in Bottou et al. (2018).

Periodic-Averaging SGD (PASGD). We consider a distributed SGD framework with $m$ worker nodes where all workers can communicate with others via a central server or via direct inter-worker communication. In periodic-averaging SGD, all workers start at the same initial point $\mathbf{x}_{1}$. Each worker performs $\tau$ local mini-batch SGD updates according to (2), and the local models are averaged by a fusion node or by performing an all-node broadcast. The workers then update their local models with the averaged model, as illustrated in Figure 2. Thus, the overall update rule at the $i^{\text{th}}$ worker is given by

$$\mathbf{x}_{k+1}^{(i)} = \begin{cases} \frac{1}{m}\sum_{j=1}^{m}\left[\mathbf{x}_{k}^{(j)} - \eta\, g(\mathbf{x}_{k}^{(j)})\right], & k \bmod \tau = 0, \\ \mathbf{x}_{k}^{(i)} - \eta\, g(\mathbf{x}_{k}^{(i)}), & \text{otherwise,} \end{cases}$$

where $\mathbf{x}_{k}^{(i)}$ denotes the model parameters at the $i^{\text{th}}$ worker after $k$ iterations and $\tau$ is defined as the communication period. Note that the iteration index $k$ corresponds to the local iterations, and not the number of averaging steps.
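
The update rule described above can be sketched in a few lines of NumPy. This is a minimal single-process simulation, not the authors' implementation: the function names `pasgd` and `noisy_quadratic_grad`, the toy objective $F(\mathbf{x}) = \frac{1}{2}\|\mathbf{x}\|^2$, and the Gaussian gradient noise are all assumptions made for illustration.

```python
import numpy as np

def pasgd(grad_fn, x1, m, tau, eta, num_iters, rng):
    """Simulate periodic-averaging SGD with m workers and communication period tau."""
    # All workers start from the same initial point x1.
    workers = np.tile(np.asarray(x1, dtype=float), (m, 1))
    for k in range(1, num_iters + 1):
        # Every worker takes one local mini-batch SGD step.
        for i in range(m):
            workers[i] -= eta * grad_fn(workers[i], rng)
        # After every tau local steps, all local models are replaced by their average.
        if k % tau == 0:
            workers[:] = workers.mean(axis=0)
    return workers

def noisy_quadratic_grad(x, rng, noise_std=0.1):
    # Hypothetical stochastic gradient of F(x) = 0.5 * ||x||^2 with additive noise.
    return x + noise_std * rng.standard_normal(x.shape)

rng = np.random.default_rng(0)
final = pasgd(noisy_quadratic_grad, np.ones(5), m=4, tau=5, eta=0.1,
              num_iters=50, rng=rng)
```

Setting `tau=1` in this sketch recovers fully synchronous SGD, since the models are averaged after every step.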

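Special Case ($\tau=1$): Fully Synchronous SGD. When $\tau=1$, that is, the local models are synchronized after every iteration, periodic-averaging SGD is equivalent to fully synchronous SGD which has the update rule

$$\mathbf{x}_{k+1} = \mathbf{x}_{k} - \eta\left[\frac{1}{m}\sum_{i=1}^{m} g(\mathbf{x}_{k};\xi_{k}^{(i)})\right].$$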

The analysis of fully synchronous SGD is identical to that of serial SGD with an $m$-fold larger mini-batch size.

Local Computation Times and Communication Delay. In order to analyze the effect of $\tau$ on the expected runtime per iteration, we consider the following delay model. The time taken by the $i^{\text{th}}$ worker to compute a mini-batch gradient at the $k^{\text{th}}$ local step is modeled as a random variable $Y_{i,k}\sim F_{Y}$, assumed to be i.i.d. across workers and mini-batches. The communication delay is a random variable $D$ for each all-node broadcast, as illustrated in Figure 3. The value of the random variable $D$ can depend on the number of workers as follows:

$$D = D_{0}\cdot s(m),$$

where $D_{0}$ represents the time taken for each inter-node communication, and $s(m)$ describes how the delay scales with the number of workers, which depends on the implementation and system characteristics. For example, in the parameter server framework, the communication delay can be proportional to $2\log_{2}(m)$ by exploiting a reduction tree structure Iandola et al. (2016). We assume that $s(m)$ is known beforehand for the communication-efficient distributed SGD framework under consideration.

Convergence Criteria. In the error-convergence analysis, since the objective function is non-convex, we use the expected gradient norm as an indicator of convergence following Ghadimi & Lan (2013); Bottou et al. (2018). We say the algorithm achieves an $\epsilon$-suboptimal solution if:

$$\mathbb{E}\left[\min_{k\in[1,K]}\left\|\nabla F(\mathbf{x}_{k})\right\|^{2}\right]\leq\epsilon.$$

When $\epsilon$ is arbitrarily small, this condition guarantees that the algorithm converges to a stationary point.

Jointly Analyzing Runtime and Error-Convergence

We now present a comparison of the runtime per iteration of periodic-averaging SGD with fully synchronous SGD to illustrate how increasing $\tau$ can lead to a large runtime speed-up. Another interesting effect of performing more local updates is that it mitigates the slowdown due to straggling worker nodes.

Runtime Per Iteration of Fully Synchronous SGD. Fully synchronous SGD is equivalent to periodic-averaging SGD with $\tau=1$. Each of the $m$ workers computes the gradient of one mini-batch and updates the parameter vector $\mathbf{x}$, which takes time $Y_{i,1}$ at the $i^{\text{th}}$ worker. (Typical implementations of fully synchronous SGD have a central server perform the update instead of local updates; here we compare PASGD with fully synchronous SGD without a central parameter server.) After all workers finish their local updates, an all-node broadcast is performed to synchronize and average the models. Thus, the total time to complete each iteration is given by

$$T_{\text{sync}} = \max(Y_{1,1}, Y_{2,1}, \dots, Y_{m,1}) + D, \qquad \mathbb{E}[T_{\text{sync}}] = \mathbb{E}[Y_{m:m}] + \mathbb{E}[D],$$

where $Y_{i,1}$ are i.i.d. random variables with probability distribution $F_{Y}$ and $D$ is the communication delay. The term $Y_{m:m}$ denotes the highest order statistic of $m$ i.i.d. random variables David & Nagaraja (2003).

Runtime Per Iteration of Periodic-Averaging SGD (PASGD). In periodic-averaging SGD, each worker performs $\tau$ local updates before communicating with other workers. Let us denote the average local computation time at the $i^{\text{th}}$ worker by

$$\overline{Y}_{i} = \frac{Y_{i,1} + Y_{i,2} + \dots + Y_{i,\tau}}{\tau}.$$

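Since the communication delay $D$ is amortized over $\tau$ iterations, the average runtime per iteration is

$$\mathbb{E}[T_{\text{P-Avg}}] = \mathbb{E}\left[\max(\overline{Y}_{1}, \overline{Y}_{2}, \dots, \overline{Y}_{m})\right] + \frac{\mathbb{E}[D]}{\tau} = \mathbb{E}[\overline{Y}_{m:m}] + \frac{\mathbb{E}[D]}{\tau}.$$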

The value of the first term $\overline{Y}_{m:m}$ and how it compares with $Y_{m:m}$ depend on the probability distribution $F_{Y}$ of $Y$.
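
To make the comparison concrete, the delay model above can be simulated directly. The following is a minimal Monte Carlo sketch under assumptions that are not taken from the analysis: compute times are drawn from an exponential distribution, the communication delay is a constant, and the function name `runtime_per_iteration` is hypothetical.

```python
import numpy as np

def runtime_per_iteration(m, tau, mean_y, delay, num_rounds, rng):
    """Monte Carlo estimate of E[max_i Ybar_i] + D / tau for m workers."""
    # y[r, i, k]: compute time of worker i at local step k of round r.
    # The exponential distribution is an illustrative assumption for F_Y.
    y = rng.exponential(mean_y, size=(num_rounds, m, tau))
    y_bar = y.mean(axis=2)       # average local compute time of each worker
    slowest = y_bar.max(axis=1)  # all workers wait for the slowest one
    # The (constant) communication delay is amortized over tau iterations.
    return slowest.mean() + delay / tau

rng = np.random.default_rng(0)
t_sync = runtime_per_iteration(m=16, tau=1, mean_y=1.0, delay=1.0,
                               num_rounds=20000, rng=rng)
t_pasgd = runtime_per_iteration(m=16, tau=10, mean_y=1.0, delay=1.0,
                                num_rounds=20000, rng=rng)
speedup = t_sync / t_pasgd
```

In this sketch the speed-up comes from two sources: the delay term shrinks by a factor of $\tau$, and averaging $\tau$ compute times per worker reduces the variability of the slowest worker, which is the straggler-mitigation effect.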

Runtime Benefits of Periodic Averaging Strategy

Figure 4 shows the speed-up for different values of $\alpha$ and $\tau$. When $D$ is comparable with $Y$ ($\alpha=0.9$), periodic-averaging SGD (PASGD) can be almost twice as fast as fully synchronous SGD.

Joint Analysis with Error-convergence

In this subsection, we combine the runtime analysis with previous error-convergence analysis for PASGD Wang & Joshi (2018). Due to space limitations, we state the necessary theoretical assumptions in the Appendix; the assumptions are similar to previous works Zhou & Cong (2017); Wang & Joshi (2018) on the convergence of local-update SGD algorithms.

Theorem 1. For PASGD, under certain assumptions (stated in the Appendix), if the learning rate satisfies $\eta L+\eta^{2}L^{2}\tau(\tau-1)\leq 1$, $Y$ and $D$ are constants, and all workers are initialized at the same point $\mathbf{x}_{1}$, then after a total wall-clock time $T$, the minimal expected squared gradient norm within the time interval $T$ will be bounded by:

where $L$ is the Lipschitz constant of the objective function and $\sigma^{2}$ is the variance bound of mini-batch stochastic gradients.

The proof of Theorem 1 is presented in the Appendix. From the optimization error upper bound (13), one can easily observe the error-runtime trade-off for different communication periods. While a larger $\tau$ reduces the runtime per iteration and makes the first term in (13) smaller, it also adds noise and increases the last term. In Figure 6, we plot theoretical bounds for both fully synchronous SGD ($\tau=1$) and PASGD. Although PASGD with $\tau=10$ starts with a rapid drop, it eventually converges to a higher error floor. This theoretical result is also corroborated by experiments in Section 5. Another direct outcome of Theorem 1 is the determination of the best communication period that balances the first and last terms in (13). We will discuss the selection of the communication period later in Section 4.1.

AdaComm: Proposed Adaptive Communication Strategy

Inspired by the clear trade-off in the learning curve in Figure 6, it would be better to have an adaptive communication strategy that starts with infrequent communication to improve convergence speed, and then increases the frequency to achieve a low error floor. In this section, we are going to develop the proposed adaptive communication scheme.

The basic idea to adapt the communication is to choose the communication period that minimizes the optimization error at each wall-clock time. One way to achieve the idea is switching between the learning curves at their intersections. However, without prior knowledge of various curves, it would be difficult to determine the switch points.

Instead, we divide the whole training procedure into uniform wall-clock time intervals with the same length $T_{0}$. At the beginning of each time interval, we select the value of $\tau$ that has the fastest decay rate in the next $T_{0}$ wall-clock time. If the interval length $T_{0}$ is small enough and the best choice of communication period for each interval can be precisely estimated, then this adaptive scheme should achieve a win-win in the error-runtime trade-off as illustrated in Figure 7.

After setting the interval length, the next question is how to estimate the best communication period for each time interval. In Section 4.1 we use the error-runtime analysis in Section 3.3 to find the best $\tau$ at each time.

From Theorem 1, it can be observed that there is an optimal value $\tau^{*}$ that minimizes the optimization error bound at a given wall-clock time. In particular, consider the simplest setting where $Y$ and $D$ are constants. Then, by minimizing the upper bound (13) over $\tau$, we obtain the following.

Theorem 2. For PASGD, under the same assumptions as Theorem 1, the optimization error upper bound in (13) at time $T$ is minimized when the communication period is

Similarly, for the $l^{\text{th}}$ time interval, workers can be viewed as restarting training at a new initial point $\mathbf{x}_{t=lT_{0}}$. Applying Theorem 2 again, we have

Comparing (15) and (16), it is easy to see that the generated communication period sequence decreases along with the objective value $F(\mathbf{x}_{t})$ when the learning rate is fixed. This result is consistent with the intuition that the trade-off between error-convergence and communication-efficiency varies over time. Compared to the initial phase of training, the benefit of using a large communication period diminishes as the model gets close to convergence. At this later stage, a lower error floor is preferable to a faster runtime.
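
The minimization over $\tau$ can be checked numerically. The sketch below assumes that the bound (13) has the form $\frac{2(F(\mathbf{x}_1)-F_{\text{inf}})}{\eta T}\left(Y+\frac{D}{\tau}\right)+\frac{\eta L\sigma^{2}}{m}+\eta^{2}L^{2}\sigma^{2}(\tau-1)$; this form is an assumption chosen to match the description of its first and last terms, and should be checked against the theorem statement. Under that assumption, setting the derivative with respect to $\tau$ to zero gives $\tau^{*}=\sqrt{2(F(\mathbf{x}_1)-F_{\text{inf}})D/(\eta^{3}L^{2}\sigma^{2}T)}$. The function names and all parameter values are hypothetical.

```python
import math

def error_bound(tau, T, eta, L, sigma2, m, f_gap, Y, D):
    """Assumed form of the error upper bound at wall-clock time T (see lead-in)."""
    # First term: shrinks as tau grows, because the runtime per iteration drops.
    runtime_term = 2.0 * f_gap / (eta * T) * (Y + D / tau)
    # Noise term that is present even in fully synchronous SGD.
    sync_noise = eta * L * sigma2 / m
    # Last term: extra noise from local updates, grows linearly in tau.
    local_noise = eta**2 * L**2 * sigma2 * (tau - 1)
    return runtime_term + sync_noise + local_noise

def optimal_tau(T, eta, L, sigma2, f_gap, D):
    """Stationary point of error_bound with respect to tau (real-valued)."""
    return math.sqrt(2.0 * f_gap * D / (eta**3 * L**2 * sigma2 * T))

# Hypothetical constants, chosen only for illustration.
params = dict(T=100.0, eta=0.05, L=1.0, sigma2=1.0, m=4, f_gap=1.0, Y=1.0, D=1.0)
tau_star = optimal_tau(params["T"], params["eta"], params["L"],
                       params["sigma2"], params["f_gap"], params["D"])
best_int_tau = min(range(1, 200), key=lambda t: error_bound(t, **params))
```

Because the assumed bound is convex in $\tau$, the best integer communication period found by brute force lies within one of the closed-form value.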

Using a fixed learning rate in SGD leads to an error floor at convergence. To further reduce the error, practical SGD implementations generally decay the learning rate or increase the mini-batch size Smith et al. (2017a); Goyal et al. (2017). As we saw from the convergence analysis in Theorem 1, performing local updates adds noise to the stochastic gradients, resulting in a higher error floor at convergence. Decaying the communication period can gradually reduce the variance of gradients and yield a similar improvement in convergence. Thus, adaptive communication strategies are similar in spirit to decaying the learning rate or increasing the mini-batch size. The key difference is that here we are optimizing the true error convergence with respect to wall-clock time rather than the number of iterations.

Practical Considerations

Although (15) and (16) provide useful insights about how to adapt $\tau$ over time, it is still difficult to directly use them in practice because the Lipschitz constant $L$ and the gradient variance bound $\sigma^{2}$ are unknown. For deep neural networks, estimating these constants can be difficult and unreliable due to the highly non-convex and high-dimensional loss surface. As an alternative, we propose a simpler rule where we approximate $F_{\text{inf}}$ by $0$, and divide (16) by (15) to obtain the basic communication period update rule:

$$\tau_{l} = \left\lceil \sqrt{\frac{F(\mathbf{x}_{t=lT_{0}})}{F(\mathbf{x}_{t=0})}}\,\tau_{0} \right\rceil, \tag{17}$$

where $\lceil a\rceil$ is the ceiling function, which rounds $a$ up to the nearest integer $\geq a$. Since the objective function values (i.e., training loss) $F(\mathbf{x}_{t=lT_{0}})$ and $F(\mathbf{x}_{t=0})$ can be easily obtained during training, the only remaining task is to determine the initial communication period $\tau_{0}$. We obtain a heuristic estimate of $\tau_{0}$ by a simple grid search over different values of $\tau$, each run for one or two epochs.

Refinements to the Proposed Adaptive Strategy

The communication period update rule (17) tends to give a decreasing sequence $\{\tau_{l}\}$. Nonetheless, it is possible that the best value of $\tau_{l}$ for the next time interval is larger than the current one due to random noise in the training process. Besides, when the training loss gets stuck on plateaus and decreases very slowly, (17) will result in $\tau_{l}$ saturating at the same value for a long time. To address this issue, we borrow an idea used in classic SGD where the learning rate is decayed by a factor $\gamma$ when the training loss saturates for several epochs Goyal et al. (2017). Similarly, in our scheme, the communication period will be multiplied by $\gamma<1$ when the $\tau_{l}$ given by (17) is not strictly less than $\tau_{l-1}$. To be specific, the communication period for the $l^{\text{th}}$ time interval will be determined as follows:

$$\tau_{l} = \begin{cases} \left\lceil \sqrt{\frac{F(\mathbf{x}_{t=lT_{0}})}{F(\mathbf{x}_{t=0})}}\,\tau_{0} \right\rceil, & \text{if } \left\lceil \sqrt{\frac{F(\mathbf{x}_{t=lT_{0}})}{F(\mathbf{x}_{t=0})}}\,\tau_{0} \right\rceil < \tau_{l-1}, \\ \gamma\,\tau_{l-1}, & \text{otherwise.} \end{cases} \tag{18}$$

In the experiments, $\gamma=1/2$ turns out to be a good choice. One can obtain a more aggressive decay in $\tau_{l}$ by either reducing the value of $\gamma$ or introducing a slack variable $s$ in the condition, such as $\left\lceil\sqrt{\frac{F(\mathbf{x}_{t=lT_{0}})}{F(\mathbf{x}_{t=0})}}\,\tau_{0}\right\rceil+s<\tau_{l-1}$.
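
The refined update rule, including the optional slack variable, can be written as a small helper. This is a sketch rather than the authors' code: the function name `next_comm_period` and its argument names are hypothetical, and rounding $\gamma\tau_{l-1}$ up to an integer that is at least 1 is an added assumption, since the text does not say how a fractional period is handled.

```python
import math

def next_comm_period(loss_now, loss_init, tau_init, tau_prev, gamma=0.5, slack=0):
    """Communication period for the next wall-clock interval."""
    # Basic rule: scale the initial period by the square root of the loss ratio.
    candidate = math.ceil(math.sqrt(loss_now / loss_init) * tau_init)
    # Accept the candidate only if it is strictly smaller than the previous period
    # (a positive slack makes the condition harder to meet, so decay is more aggressive).
    if candidate + slack < tau_prev:
        return candidate
    # Otherwise force a decay by gamma.
    # Assumption: round up and never go below one local step per round.
    return max(1, math.ceil(gamma * tau_prev))

# Loss dropped to a quarter of its initial value: sqrt(0.25) * 20 = 10.
tau_1 = next_comm_period(loss_now=0.25, loss_init=1.0, tau_init=20, tau_prev=20)
# Loss has plateaued, so the candidate equals the previous period and gamma applies.
tau_2 = next_comm_period(loss_now=0.25, loss_init=1.0, tau_init=20, tau_prev=tau_1)
```

In a training loop, this helper would be called once at the start of every interval of length $T_{0}$ with the current training loss.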

Incorporating Adaptive Learning Rate

So far we have considered a fixed learning rate $\eta$ for the local SGD updates at the workers. We now present an adaptive communication strategy that adjusts $\tau_{l}$ for a given variable learning rate schedule, in order to obtain the best error-runtime trade-off. Suppose $\eta_{l}$ denotes the learning rate for the $l^{\text{th}}$ time interval. Then, combining (15) and (16) again, we have

$$\tau_{l} = \left\lceil \sqrt{\frac{\eta_{0}^{3}}{\eta_{l}^{3}}\cdot\frac{F(\mathbf{x}_{t=lT_{0}})}{F(\mathbf{x}_{t=0})}}\,\tau_{0} \right\rceil. \tag{19}$$

Observe that when the learning rate becomes smaller, the communication period $\tau_{l}$ increases. This result corresponds to the intuition that a small learning rate reduces the discrepancy between the local models, and hence is more tolerant of large communication periods.

Equation (19) states that the communication period should be proportional to $(\eta_{0}/\eta_{l})^{3/2}$. However, in practice, it is common to decay the learning rate by a factor of $10$ after some given number of epochs. Such a dramatic change in the learning rate may push the communication period to an unreasonably large value. In the experiments, we observe that when applying (19), the communication period can increase to $\tau=1000$, which causes the training loss to diverge.

To avoid this issue, we propose the adaptive strategy given by (20) below. This strategy can also be justified by theoretical analysis. Suppose that in the $l^{\text{th}}$ time interval, the objective function has a local Lipschitz smoothness $L_{l}$. Then, by using the approximation $\eta_{l}L_{l}\approx 1$, which is common in the SGD literature Balles et al. (2016), we derive the following adaptive strategy:

Apart from coupling the communication period with the learning rate, when to decay the learning rate is another key design factor. In order to eliminate the noise introduced by local updates, we choose to first gradually decay the communication period to $1$ and then decay the learning rate as usual. For example, if the learning rate is scheduled to be decayed at the $80^{\text{th}}$ epoch but at that time the communication period $\tau$ is still larger than $1$, then we will continue to use the current learning rate until $\tau=1$.
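
The two design choices above (coupling $\tau_{l}$ with the learning rate, and deferring the learning rate decay until $\tau=1$) can be sketched as follows. The exponent `1.5` corresponds to the $(\eta_{0}/\eta_{l})^{3/2}$ scaling of (19). The exponent `0.5` is an assumption about the form of (20): substituting $\eta_{l}L_{l}\approx 1$ into the $\eta^{3}L^{2}$ dependence leaves a single factor of $\eta$, which suggests a $(\eta_{0}/\eta_{l})^{1/2}$ scaling. All function and argument names are hypothetical.

```python
import math

def comm_period_with_lr(loss_now, loss_init, tau_init, lr_now, lr_init, exponent=0.5):
    """Communication period coupled with the learning rate.

    exponent=1.5 gives the (eta_0 / eta_l)^(3/2) scaling of (19);
    exponent=0.5 is the assumed milder scaling (see lead-in).
    """
    scale = (lr_init / lr_now) ** exponent * math.sqrt(loss_now / loss_init)
    return max(1, math.ceil(scale * tau_init))

def maybe_decay_lr(lr, tau, decay_due, factor=0.1):
    """Apply a scheduled decay only once the communication period has reached 1."""
    if decay_due and tau == 1:
        return lr * factor
    return lr  # keep the current learning rate while tau > 1

# Effect of a 10x learning rate decay at unchanged loss, starting from tau_0 = 10.
tau_aggressive = comm_period_with_lr(1.0, 1.0, 10, lr_now=0.02, lr_init=0.2, exponent=1.5)
tau_mild = comm_period_with_lr(1.0, 1.0, 10, lr_now=0.02, lr_init=0.2, exponent=0.5)
```

With these numbers, the $3/2$ exponent inflates the period by a factor of roughly 32 after a single decay step, while the milder exponent inflates it by a factor of roughly 3, which illustrates why the former can push $\tau$ to unreasonably large values.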

Theoretical Guarantees for the Convergence of AdaComm

For PASGD with adaptive communication period and adaptive learning rate, suppose the learning rate remains the same within each local update period. If the following conditions are satisfied as $R\to\infty$,

then the averaged model $\overline{\mathbf{x}}$ is guaranteed to converge to a stationary point:

The proof details and a non-asymptotic result (similar to Theorem 1 but with variable $\tau$) are provided in the Appendix. In order to understand the meaning of condition (21), let us first consider the case when $\tau_{0}=\cdots=\tau_{R}$ is a constant. In this case, the convergence condition is identical to that of mini-batch SGD Bottou et al. (2018):

As long as the communication period sequence is bounded, it is trivial to adapt the learning rate scheme in mini-batch SGD (23) to satisfy (21). In particular, when the communication period sequence is decreasing, the last two terms in (21) become easier to satisfy and place fewer constraints on the learning rate sequence.

Experimental Results

Platform. The proposed adaptive communication scheme was implemented in PyTorch Paszke et al. (2017) with Mpi4Py Dalcín et al. (2005). All experiments were conducted on a local cluster with $4$ worker nodes, each of which has an NVIDIA TitanX GPU and a $16$-core Intel Xeon CPU. Worker nodes are connected via a $40$ Gbps ($5000$ MB/s) Ethernet interface. Due to space limitations, additional results with $8$ nodes are listed in Appendix A.

Dataset. We evaluate our method on image classification tasks using the CIFAR10 and CIFAR100 datasets Krizhevsky (2009), each of which consists of 50,000 training images and 10,000 validation images, in 10 and 100 classes respectively. Each worker machine is assigned a partition of the data, which is randomly shuffled after every epoch.

Model. We choose to train deep neural networks VGG-16 Simonyan & Zisserman (2014) and ResNet-50 He et al. (2016) from scratch. (The implementations of VGG-16 and ResNet-50 follow this GitHub repository: https://github.com/meliketoy/wide-resnet.pytorch.) These two neural networks have different architectures and parameter sizes, thus resulting in different performance of periodic averaging. As shown in Figure 8, for VGG-16, the communication time is about $4$ times higher than the computation time. Thus, compared to ResNet-50, it requires a larger $\tau$ in order to reduce the runtime per iteration and achieve fast convergence. A similarly high communication/computation ratio is common in the literature; see Lin et al. (2018); Harlap et al. (2018).

Hyperparameter Choice. The mini-batch size on each worker is $128$. Therefore, the total mini-batch size per iteration is $512$. The initial learning rates for VGG-16 and ResNet-50 are $0.2$ and $0.4$ respectively. The weight decay for both networks is $0.0005$. In the variable learning rate setting, we decay the learning rate by a factor of $10$ after the $80^{\text{th}}/120^{\text{th}}/160^{\text{th}}/200^{\text{th}}$ epochs. We set the time interval length $T_{0}$ to $60$ seconds (about $10$ epochs for the initial communication period).

Metrics. We compare the performance of the proposed adaptive communication scheme with the following methods, each with a fixed communication period: (1) Baseline: fully synchronous SGD ($\tau=1$); (2) Extreme high-throughput case where $\tau=100$; (3) Manually tuned case where a moderate value of $\tau$ is selected after trial runs with different communication periods. Instead of training for a fixed number of epochs, we train all methods for a sufficiently long time to reach convergence and compare the training loss and test accuracy, both of which are recorded after every 100 iterations.

Adaptive Communication in PASGD

We first validate the effectiveness of AdaComm, which uses the communication period update rule (18) combined with (20), on original PASGD without momentum.

Figure 9 presents the results for VGG-16 for both fixed and variable learning rates. A large communication period $\tau$ initially results in a rapid drop in the error, but the error finally converges to a higher floor. By adapting $\tau$, the proposed AdaComm scheme strikes the best error-runtime trade-off in all settings. In Figure 9(a), while fully synchronous SGD takes $33.5$ minutes to reach a training loss of $4\times 10^{-3}$, AdaComm takes $15.5$ minutes, achieving more than a $2\times$ speedup. Similarly, in Figure 9(b), AdaComm takes $11.5$ minutes to reach a training loss of $4.5\times 10^{-2}$, achieving a $3.3\times$ speedup over fully synchronous SGD ($38.0$ minutes).

However, for ResNet-50, the communication overhead is no longer the bottleneck. For a fixed communication period, the negative effect of performing local updates becomes more obvious and cancels the benefit of low communication delay (see Figures 10(b) and 10(c)). It is not surprising to see that fully synchronous SGD is nearly the best in the error-runtime plot among all fixed-$\tau$ methods. Even in this extreme case, adaptive communication can still achieve competitive performance. When combined with learning rate decay, the adaptive scheme is about 1.3 times faster than fully synchronous SGD (see Figure 10(a), $15.0$ versus $21.5$ minutes to achieve a training loss of $3\times 10^{-2}$).

Table 1 lists the test accuracies in different settings; we report the best accuracy within a time budget for each setting. The results show that the adaptive communication method generalizes better than fully synchronous SGD. In the variable learning rate case, the adaptive method even gives better test accuracy than PASGD with the best fixed $\tau$.

Adaptive Communication in Momentum SGD

The adaptive communication scheme is proposed based on the joint error-runtime analysis for PASGD without momentum. However, it can also be extended to other SGD variants, and in this subsection, we show that the proposed method works well for SGD with momentum.

Before presenting the empirical results, we describe how to introduce momentum in PASGD. A naive way is to apply the momentum independently to each local model, where each worker maintains an independent momentum buffer, which is the latest change in the parameter vector $\mathbf{x}$. However, this does not account for the potentially dramatic change in $\mathbf{x}$ at each averaging step. When local models are synchronized, the local momentum buffer will contain the update steps before averaging, resulting in a large momentum term in the first SGD step of each local update period. When $\tau$ is large, this large momentum term can side-track the SGD descent direction, resulting in slower convergence.

To address this issue, a block momentum scheme was proposed in Chen & Huo (2016) and applied to speech recognition tasks. The basic idea is to treat the local updates in each communication period as one big gradient step between two synchronized models, and to introduce a global momentum for this big accumulated step. The update rule can be written as follows in terms of the momentum $\mathbf{u}_{j}$:

where $\mathcal{G}_{j}=\frac{1}{m}\sum_{i=1}^{m}\sum_{k=1}^{\tau}g(\mathbf{x}_{j\tau+k}^{(i)})$ represents the accumulated gradients in the $j^{\text{th}}$ local update period and $\beta_{\text{glob}}$ denotes the global momentum factor. Moreover, workers can also conduct momentum SGD on local models, but their local momentum buffer will be cleared at the beginning of each local update period. That is, we restart momentum SGD on local models after every averaging step. The same strategy was also suggested in Microsoft's CNTK framework Seide & Agarwal (2016). In our experiments, we set the global momentum factor to $0.3$ and the local momentum factor to $0.9$ following Lin et al. (2018). In the fully synchronous case, there is no need to introduce the block momentum and we simply follow the common practice of setting the momentum factor to $0.9$.

AdaComm plus Block Momentum

In Figure 11, we apply our adaptive communication strategy in PASGD with block momentum and observe a significant performance gain on CIFAR10/100. In particular, the adaptive communication scheme has the fastest convergence rate with respect to wall-clock time throughout the whole training process. While fully synchronous SGD gets stuck on a plateau before the first learning rate decay, the training loss of the adaptive method continuously decreases until convergence. For VGG-16 in Figure 11(b), AdaComm is $3.5\times$ faster (in terms of wall-clock time) than fully synchronous SGD in reaching a training loss of $3\times 10^{-3}$. For ResNet-50 in Figure 11(a), AdaComm takes $15.8$ minutes to reach a training loss of $2\times 10^{-2}$, which is $2$ times faster than fully synchronous SGD ($32.6$ minutes).

Concluding Remarks

The design of communication-efficient SGD algorithms that are robust to system variability is vital to scaling machine learning training to resource-limited computing nodes. This paper is one of the first to analyze the convergence of error with respect to wall-clock time instead of number of iterations by accounting for the effect of computation and communication delays on the runtime per iteration. We present a theoretical analysis of the error-runtime trade-off for periodic-averaging SGD (PASGD), where each node performs local updates and their models are averaged after every $\tau$ iterations. Based on the joint error-runtime analysis, we design the first (to the best of our knowledge) adaptive communication strategy called AdaComm for distributed deep learning. Experimental results using VGGNet and ResNet show that the proposed method can achieve up to a $3\times$ improvement in runtime, while achieving the same error floor as fully synchronous SGD.

Going beyond periodic-averaging SGD, our idea of adapting the frequency of averaging distributed SGD updates can be extended to other SGD frameworks, including elastic-averaging Zhang et al. (2015), decentralized SGD (e.g., adapting network sparsity) Lian et al. (2017) and parameter server-based training (e.g., adapting asynchrony).

Acknowledgments

The authors thank Prof. Greg Ganger for helpful discussions. This work was partially supported by NSF CCF-1850029 and an IBM Faculty Award. Experiments were conducted on clusters provided by the Parallel Data Lab at CMU.

References

Appendix A Additional Experimental Results

In the $8$ worker case, the communication among nodes is accomplished via the Nvidia Collective Communication Library (NCCL). The mini-batch size on each node is $64$. The initial learning rate is set as $0.2$ for both VGG-16 and ResNet-50. In Figure 12(a), while fully synchronous SGD takes $17.5$ minutes to reach a $10^{-2}$ training loss, AdaComm takes only $6.0$ minutes, achieving about a $2.9\times$ speedup.
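The quoted speedups are ratios of the wall-clock times needed to reach the same training loss. A minimal sketch of that arithmetic follows, using the times reported in this appendix and in Section 3.2; the helper name `speedup` is hypothetical.

```python
def speedup(baseline_minutes: float, method_minutes: float) -> float:
    """Ratio of wall-clock times needed to reach the same training loss."""
    if method_minutes <= 0:
        raise ValueError("method_minutes must be positive")
    return baseline_minutes / method_minutes


# Times quoted in the text, in minutes (fully synchronous SGD vs. AdaComm).
print(round(speedup(17.5, 6.0), 1))    # 8-worker case, Figure 12(a)
print(round(speedup(32.6, 15.8), 1))   # ResNet-50 with block momentum
```

The two ratios come out at roughly 2.9 and 2.1, in line with the figures stated in the text.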

Appendix B Inefficient Local Updates

It is worth noting an interesting phenomenon in the convergence of periodic averaging SGD (PASGD). When the learning rate is fixed, PASGD with a fine-tuned communication period has better test accuracy than both fully synchronous SGD and the adaptive method, while its training loss remains higher than that of the latter two methods (see Figure 9, Figure 10). In particular, on the CIFAR100 dataset, we observe about a $5\%$ improvement in test accuracy when $\tau=5$. To investigate this phenomenon, we evaluate the test accuracy for PASGD ($\tau=15$) at two frequencies: 1) every $135$ iterations; 2) every $100$ iterations. In the former case, the test accuracy is reported just after the averaging step, since $135$ is a multiple of $15$. However, in the latter case, the test accuracy can come from either the synchronized/averaged model or local models, since $100$ is not divisible by $15$.
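The difference between the two evaluation schedules can be checked directly. The sketch below assumes that averaging happens right after iterations $\tau, 2\tau, 3\tau, \dots$; the function name and the horizon of 2700 iterations are hypothetical choices for illustration.

```python
def evaluations_on_averaged_model(eval_every, tau, total_iters):
    """List (iteration, is_averaged) for every evaluation point.

    Assumes models are averaged right after iterations tau, 2*tau, ...,
    so an evaluation sees the synchronized model only at multiples of tau.
    """
    return [(k, k % tau == 0)
            for k in range(eval_every, total_iters + 1, eval_every)]


every_135 = evaluations_on_averaged_model(135, tau=15, total_iters=2700)
every_100 = evaluations_on_averaged_model(100, tau=15, total_iters=2700)

print(all(flag for _, flag in every_135))    # every evaluation is post-averaging
print([k for k, flag in every_100 if flag])  # only multiples of 300 qualify
```

With a period of 100, only one evaluation in three lands on an averaging step, so most reported accuracies come from local models.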

From Figure 14, it is clear that the accuracy of the local models is much lower than that of the synchronized model, even when the algorithm has converged. Thus, we conjecture that the improvement of test accuracy only happens on the synchronized model. That is, after averaging, the test accuracy will undergo a rapid increase but it decreases again in the following local steps due to noise in stochastic gradients. Such behavior may depend on the geometric structure of the loss surface of specific neural networks. The observation also reveals that the local updates are inefficient, as they reduce the accuracy and make no progress. In this sense, it is necessary for PASGD to reduce the gradient variance by decaying either the learning rate or the communication period.

Appendix C Assumptions for Convergence Analysis

The convergence analysis is conducted under the following assumptions, which are similar to the assumptions made in previous work on the analysis of PASGD Zhou & Cong (2017); Yu et al. (2018); Wang & Joshi (2018); Stich (2018). In particular, we make no assumptions on the convexity of the objective function. We also remove the uniform bound assumption for the norm of stochastic gradients.

The objective function $F(\mathbf{x})$ is differentiable and $L$-Lipschitz smooth, i.e., $\left\|\nabla F(\mathbf{x})-\nabla F(\mathbf{y})\right\|\leq L\left\|\mathbf{x}-\mathbf{y}\right\|$. The function value is bounded below by a scalar $F_{\text{inf}}$.

The variance of stochastic gradient evaluated on a mini-batch ξ\xi is bounded as

where $\beta$ and $\sigma^{2}$ are non-negative constants that are inversely proportional to the mini-batch size.
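For reference, a variance bound of this kind is commonly written in the following form. This is a sketch of the standard relaxed-variance assumption, given under the assumption that $g(\mathbf{x};\xi)$ denotes the stochastic gradient evaluated on mini-batch $\xi$; the notation for the expectation is likewise an assumption.

```latex
\[
\mathbb{E}_{\xi}\left\|g(\mathbf{x};\xi)-\nabla F(\mathbf{x})\right\|^{2}
\leq \beta\left\|\nabla F(\mathbf{x})\right\|^{2}+\sigma^{2}
\]
```

Setting $\beta=0$ in this form recovers the more restrictive uniform variance bound $\sigma^{2}$.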

Appendix D Proof of Theorem 2: Error-runtime Convergence of PASGD

First, let us recall the error analysis of PASGD. We adapt the theorem from Wang & Joshi (2018).

For PASGD, under Assumptions 1, 2 and 3, if the learning rate satisfies $\eta L+\eta^{2}L^{2}\tau(\tau-1)\leq 1$ and all workers are initialized at the same point $\mathbf{x}_{1}$, then after $K$ iterations, we have

where $L$ is the Lipschitz constant of the objective function, $\sigma^{2}$ is the variance bound of mini-batch stochastic gradients and $\overline{\mathbf{x}}_{k}$ denotes the averaged model at the $k^{th}$ iteration.
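The learning-rate condition $\eta L+\eta^{2}L^{2}\tau(\tau-1)\leq 1$ in the theorem is a quadratic inequality in $z=\eta L$, so the largest admissible learning rate has a closed form. The following sketch checks the condition and solves for that bound; the function names are hypothetical.

```python
import math


def satisfies_lr_condition(eta: float, L: float, tau: int) -> bool:
    """Check eta*L + eta^2 * L^2 * tau*(tau-1) <= 1."""
    z = eta * L
    return z + z * z * tau * (tau - 1) <= 1.0


def max_learning_rate(L: float, tau: int) -> float:
    """Largest eta meeting the condition, via the quadratic in z = eta*L."""
    if L <= 0 or tau < 1:
        raise ValueError("need L > 0 and tau >= 1")
    c = tau * (tau - 1)
    if c == 0:  # tau == 1: the condition reduces to eta*L <= 1
        return 1.0 / L
    # Positive root of c*z^2 + z - 1 = 0.
    z = (-1.0 + math.sqrt(1.0 + 4.0 * c)) / (2.0 * c)
    return z / L


for tau in (1, 5, 15):
    print(tau, max_learning_rate(1.0, tau))
```

The bound shrinks as $\tau$ grows, which reflects that longer local update periods require a smaller learning rate for the theorem to apply.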

From the runtime analysis in Section 2, we know that the expected runtime per iteration of PASGD is

Accordingly, the total wall-clock time of training for $K$ iterations is

Appendix E Proof of Theorem 3: the Best Communication Period

Taking the derivative of the upper bound (14) with respect to the communication period, we obtain

Setting the derivative to zero, the communication period is

then the optimal value obtained in (30) must be a global minimum.

Appendix F Proof of Theorem 4: Error-Convergence of Adaptive Communication Scheme

In addition, define the matrix $\mathbf{J}=\mathbf{1}\mathbf{1}^{\top}/(\mathbf{1}^{\top}\mathbf{1})$, where $\mathbf{1}$ denotes the column vector $[1,1,\dots,1]^{\top}$. Unless otherwise stated, $\mathbf{1}$ is a size-$m$ column vector, and the matrix $\mathbf{J}$ and identity matrix $\mathbf{I}$ are of size $m\times m$, where $m$ is the number of workers.
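The role of $\mathbf{J}$ is easiest to see numerically: right-multiplying a matrix whose columns are the worker models by $\mathbf{J}$ replaces every column with the average, while $\mathbf{I}-\mathbf{J}$ extracts the deviation from that average. The sketch below uses plain Python lists; the helper names and the example matrix `X` are hypothetical.

```python
def averaging_matrix(m):
    """J = 1 1^T / (1^T 1): an m x m matrix with every entry equal to 1/m."""
    return [[1.0 / m] * m for _ in range(m)]


def identity(m):
    return [[1.0 if r == c else 0.0 for c in range(m)] for r in range(m)]


def matmul(A, B):
    return [[sum(a * b for a, b in zip(row, col)) for col in zip(*B)]
            for row in A]


def subtract(A, B):
    return [[a - b for a, b in zip(ra, rb)] for ra, rb in zip(A, B)]


m = 4
J = averaging_matrix(m)
# Each column of X is one worker's model (2 parameters, 4 workers).
X = [[1.0, 2.0, 3.0, 6.0],
     [0.5, 0.5, 0.5, 0.5]]

print(matmul(J, J) == J)                         # J is a projection
print(matmul(X, J))                              # every column becomes the average
print(matmul(X, subtract(identity(m), J)))       # deviations from the average
```

The second row of `X` is identical across workers, so its deviation row is zero; this is the situation at the start of each local update period, when all workers hold the same model.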

F.2 Proof

Let us first focus on the $j$-th local update period, where $j\in\{0,1,\dots,R\}$. Without loss of generality, suppose the local index of the $j^{th}$ local update period starts from $1$ and ends with $\tau_{j}$. Then, for the $k$-th local step in the period of interest, we have the following lemma.
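Because the communication period may change from one local update period to the next, it helps to keep the local and global indices apart. The sketch below maps a local step to its global iteration index, taking the first global index of a period to be one plus the total length of all earlier periods; the function names and the example periods are hypothetical.

```python
from itertools import accumulate


def period_start_indices(taus):
    """First global index (1-based) of each local update period."""
    # Running totals 0, tau_0, tau_0 + tau_1, ... without the final total.
    return [1 + s for s in list(accumulate(taus, initial=0))[:-1]]


def to_global_index(taus, j, k):
    """Map local step k (1..tau_j) of period j to its global iteration index."""
    if not 1 <= k <= taus[j]:
        raise ValueError("k must lie in 1..tau_j")
    return period_start_indices(taus)[j] + k - 1


taus = [4, 2, 1]                       # a decreasing communication period
print(period_start_indices(taus))      # first global index of each period
print(to_global_index(taus, 1, 2))     # second local step of period 1
print(sum(taus))                       # total number of iterations
```

The last local step of the final period maps to the total number of iterations, which is the sum of all communication periods.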

For PASGD, under Assumptions 1, 2 and 3, at the $k$-th iteration, we have the following bound for the objective value:

where $\overline{\mathbf{x}}_{k}$ denotes the averaged model at the $k^{th}$ iteration.

Taking the total expectation and summing over all iterates in the $j$-th local update period, we can obtain

Next, we are going to provide an upper bound for the last term in (35). Note that

where (39) follows from the fact that all workers start from the same point at the beginning of each local update period, i.e., $\mathbf{X}_{1}(\mathbf{I}-\mathbf{J})=0$. Accordingly, we have

where the inequality (41) holds because the operator norm of $(\mathbf{I}-\mathbf{J})$ is at most 1. Furthermore, using the fact $(a+b)^{2}\leq 2a^{2}+2b^{2}$, one can get

For the first term $T_{1}$, since the stochastic gradients are unbiased, all cross terms are zero. Thus, combining with Assumption 3, we have

For the second term in (43), directly applying Jensen’s inequality, we get

Substituting the bounds of $T_{1}$ and $T_{2}$ into (43),

Recalling the upper bound (35), we further derive the following bound:

Note that when the learning rate satisfies:

Suppose $l_{j}=\sum_{r=0}^{j-1}\tau_{r}+1$ is the first index in the $j$-th local update period. Without loss of generality, we substitute the local index by the global index:

Summing over all local periods from $j=0$ to $j=R$, one can obtain

After minor rearranging, it is easy to see

F.3 Asymptotic Result (Theorem 4)

For the upper bound (62) to converge to zero as $R\to\infty$, a sufficient condition is

This completes the proof of Theorem 4.

F.4 Simplified Result

We can obtain a simplified result when the learning rate is fixed. To be specific, we have

If we choose the total number of iterations $K=\sum_{j=0}^{R}\tau_{j}$, then