Scalable Optimization in the Modular Norm

Tim Large, Yang Liu, Minyoung Huh, Hyojin Bahng, Phillip Isola, Jeremy Bernstein

Introduction

Given the practical impact of deep learning systems trained at the largest scale, there is a need for training algorithms that scale gracefully: without instability and, if possible, without manual tuning. However, current best practices for training have developed somewhat organically and do not rest on a bedrock of sound numerical analysis. For example, while the Adam optimizer is ubiquitous in the field, errors have been found in its proof of convergence, and empirically Adam has been found to scale poorly as either the width or the depth of the network is ramped up.

To remedy this situation, a patchwork of learning rate correction factors has recently been proposed. The general idea is to retrofit a base optimizer such as Adam or SGD with special correction factors intended to render the optimizer’s optimal learning rate invariant to scale. But this situation is not ideal: the correction factors are reportedly difficult to use. Lingle suggests that this may be due to their “higher implementation complexity, many variations, or complex theoretical background”. What’s more, the correction factors are optimizer-specific, meaning that if one switches to a different optimizer one must either look up or recalculate a separate set of correction factors.

The goal of this paper is to simplify matters. We show that both Adam and SGD can be made to scale gracefully with width and depth by simply normalizing their updates in a special norm associated with the network architecture—see Figure 1. We call this norm the modular norm, and provide a Python package called Modula that constructs this norm automatically and in tandem with the architecture.

The modular norm is constructed recursively, leveraging the module tree perspective on neural architectures. It is enough to define how the modular norm propagates through only two elementary operations: composition and concatenation. We show how other basic operations on modules, such as addition and scalar-multiplication, can be implemented through composition and concatenation. And then higher-order structures, such as residual networks, can be built using these basic operations.

Beyond its practical relevance, the modular norm may also prove useful to theoreticians. Various optimization-theoretic quantities are accessible and efficiently calculable in the modular norm. For instance, we show that the gradient of any neural network built from “well-behaved” atomic modules is Lipschitz-continuous in the modular norm of the architecture. This opens the door to porting several more-or-less textbook optimization theory analyses over to the world of deep learning.

It is by now well-known that deep networks do not easily or naturally admit Lipschitz-continuity or smoothness guarantees in the Euclidean norm. Researchers have attempted to address this problem: for instance, Bernstein et al. propose a distance function called deep relative trust, which combines Frobenius norms across network layers. However, deep relative trust is only constructed for the multilayer perceptron and, when used to normalize updates, its employment of the Frobenius norm precludes good width scaling. In contrast, Yang et al. equip individual layers with the RMS–RMS operator norm, finding this to enable good width scaling. Researchers have also looked at building neural net distance functions outside the context of scalability.

Asymptotics

The metrization-based approach to scaling developed in this paper contrasts with the tradition of asymptotic scaling analyses (the study of infinite width and depth limits) more common in the deep learning theory literature. These asymptotic analyses follow an old observation of Neal that interesting properties of the neural network function space are exactly calculable in the infinite width limit and at initialization. This tradition has continued with asymptotic studies of the neural tangent kernel as well as infinite depth limits. However, there is increasing recognition of the limits of these limits, with researchers now often trying to relax limiting results. And ultimately, from a practitioner’s perspective, these results can be difficult to make sense of. In contrast, our framework eschews any kind of limiting or probabilistic analysis. As a consequence, we believe our framework is simpler, more easily relatable to basic mathematical concepts, and ultimately more relevant to what one may encounter in, say, a PyTorch program.

Majorization

In recent work, Streeter and Dillon propose a universal majorize-minimize algorithm: a method that automatically computes and minimizes a majorizer for any computational graph. Despite its generality, current downsides to the method include its overhead, which can be $2\times$ per step, as well as the risk that use of a full majorization may be overly pessimistic. Indeed, Cho and Shin find that an optimization approach leveraging second-order information converges significantly faster than a majorization-inspired approach.

Descent in Normed Spaces

We define the modular norm in Section 3. This section is intended to prime the reader for what is to come. In this section, and the rest of the document, the diamond operator $\diamond$ denotes tensor contraction.

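Suppose that a loss function $\mathcal{L}:\mathcal{W}\to\mathbb{R}$ over a weight space $\mathcal{W}$ satisfies two conditions:

the loss function is differentiable, meaning that the gradient map $\nabla_{\bm{w}}\mathcal{L}:\mathcal{W}\to\mathcal{W}$ exists;

the loss is Lipschitz smooth in the norm $\|{\cdot}\|$, with sharpness constant $\lambda>0$, meaning that:

$$\mathcal{L}({\bm{w}}+\Delta{\bm{w}})\leq\mathcal{L}({\bm{w}})+\nabla_{\bm{w}}\mathcal{L}({\bm{w}})\diamond\Delta{\bm{w}}+\frac{\lambda}{2}\|{\Delta{\bm{w}}}\|^{2}.\qquad(2.1)$$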

Under these conditions, the weight update given by $\Delta{\bm{w}}=\operatorname*{arg\,min}\left[\nabla_{\bm{w}}\mathcal{L}({\bm{w}})\diamond\Delta{\bm{w}}+\frac{\lambda}{2}\|{\Delta{\bm{w}}}\|^{2}\right]$ is guaranteed to reduce the loss. The particular norm $\|{\cdot}\|$ influences the direction of this weight update, while the sharpness constant $\lambda$ influences the size of the update.

In deep learning, we would ideally like the optimal step-size to remain invariant as we scale, say, the width and the depth of the network. Thus, a fundamental problem is to design a norm such that, first, Equation 2.1 actually holds (and is not hopelessly lax), and second, the corresponding sharpness constant $\lambda$ is invariant to the relevant architectural dimensions. If the norm is chosen poorly, the practitioner may end up having to re-tune the step size as the network is scaled up. In this paper, we design a norm for neural networks that meets these requirements: the modular norm.
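To make the role of the norm concrete, the following minimal sketch solves the arg-min update in closed form for two norms on a toy quadratic loss. Everything here is illustrative and not taken from the paper: the matrix `A`, the starting point, and the two sharpness constants are assumptions chosen so that the smoothness bound holds for this particular loss. The Euclidean norm yields a rescaled gradient step, while the $L^{\infty}$ norm yields a sign-descent step.

```python
"""Steepest descent under two different norms on a toy quadratic loss."""

A = [[2.0, 0.5],
     [0.5, 1.0]]  # symmetric positive-definite Hessian of the toy loss


def loss(w):
    # L(w) = 0.5 * w^T A w
    return 0.5 * sum(w[i] * A[i][j] * w[j] for i in range(2) for j in range(2))


def grad(w):
    # gradient of the quadratic: A w
    return [sum(A[i][j] * w[j] for j in range(2)) for i in range(2)]


def euclidean_step(g, lam):
    # argmin of g.dw + (lam/2) * ||dw||_2^2  is  dw = -g / lam
    return [-gi / lam for gi in g]


def linf_step(g, lam):
    # argmin of g.dw + (lam/2) * ||dw||_inf^2  is  dw = -(||g||_1 / lam) * sign(g)
    scale = sum(abs(gi) for gi in g) / lam
    return [-scale * (1.0 if gi > 0 else -1.0 if gi < 0 else 0.0) for gi in g]


# Sharpness constants for which the smoothness bound holds for this A:
lam_l2 = 2.5    # at least the largest eigenvalue of A (about 2.207)
lam_linf = 4.0  # sum of |A_ij|, since |dw^T A dw| <= sum|A_ij| * ||dw||_inf^2

w = [1.0, -2.0]
g = grad(w)
w_l2 = [wi + di for wi, di in zip(w, euclidean_step(g, lam_l2))]
w_linf = [wi + di for wi, di in zip(w, linf_step(g, lam_linf))]
print(loss(w), loss(w_l2), loss(w_linf))
```

Both steps lower the loss, but they move in different directions and need different sharpness constants, which is the sense in which the choice of norm shapes the optimizer.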

2.2 Preview of the modular norm

turns out to hold quite tightly when $\Delta{\bm{W}}$ is a gradient update and ${\bm{x}}$ is a corresponding layer input. This is because gradient updates to a layer are (sums of) outer products that align with layer inputs.

Once we know how to metrize individual layers, a natural question is: can we combine layer-wise norms to produce a norm on the full weight space $\mathcal{W}=\prod_{k}\mathcal{W}_{k}$ of the network? Naïvely, there are many ways to do this: one could take any positive linear combination of the layer-wise norms ($L^{1}$ combination), the square root of any combination of the squared layer-wise norms ($L^{2}$ combination), and so on. But we want the norm to be useful by the criteria of Section 2.1. To this end, we propose the modular norm $\|{\cdot}\|_{\mathcal{W}}$, which ends up as a max ($L^{\infty}$ combination) of scaled layer-wise norms $\|{\cdot}\|_{\mathcal{W}_{k}}$:

$$\|{({\bm{w}}_{1},\ldots,{\bm{w}}_{L})}\|_{\mathcal{W}}:=\max\left(s_{1}\|{{\bm{w}}_{1}}\|_{\mathcal{W}_{1}},\ldots,s_{L}\|{{\bm{w}}_{L}}\|_{\mathcal{W}_{L}}\right).\qquad(2.3)$$

The positive scalar constants $s_{1},\ldots,s_{L}$ are determined by both the architecture of the network and a set of user-specified “mass” parameters. The precise construction of the modular norm, working recursively over the module tree of the network, is given in Section 3; there, we also explain how the modular norm satisfies the criteria of Section 2.1, and the role played by the mass parameters. For now, let us explain what good the modular norm yields in practice.

2.3 Normed optimization

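The main practical use of the modular norm is to normalize weight updates. With reference to Equation 2.3, we define the following operation on weight updates $\Delta{\bm{w}}=(\Delta{\bm{w}}_{1},\ldots,\Delta{\bm{w}}_{L})\in\mathcal{W}$:

$$\mathsf{normalize}(\Delta{\bm{w}}):=\left(\frac{\Delta{\bm{w}}_{1}}{s_{1}\|{\Delta{\bm{w}}_{1}}\|_{\mathcal{W}_{1}}},\ldots,\frac{\Delta{\bm{w}}_{L}}{s_{L}\|{\Delta{\bm{w}}_{L}}\|_{\mathcal{W}_{L}}}\right).$$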

Provided none of the $\Delta{\bm{w}}_{k}$ are zero, $\mathsf{normalize}(\Delta{\bm{w}})$ is a unit vector in the modular norm. We propose using $\mathsf{normalize}$ as a wrapper, along with an explicit learning rate schedule, for any base optimizer such as Adam or SGD. The resulting normed optimizer is thus made architecture-aware via the normalize function. In pseudo-code (and actual Modula code) this amounts to:

delta_w = optim(w.grad())    # get update from base optimizer
net.normalize(delta_w)       # normalize update in the modular norm
w -= eta(step) * delta_w     # apply update with learning rate eta

We find this wrapper to significantly improve the scalability of the base optimizer. It renders the optimal learning rate roughly invariant to width and depth, with seemingly no cost to accuracy. In some instances, it enables training with a simpler optimizer—for example, training GPT with SGD rather than Adam—thus incurring a smaller memory footprint.

Normalization in the modular norm essentially forces individual layers to learn at specified, regulated rates. We view this as balancing learning across the network; no individual layer can learn too fast and destabilize training. This balance is determined by the architecture, along with user-specified mass parameters that provide precise control over the relative learning speed in different submodules.
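The per-layer rescaling described above can be sketched in a few lines of plain Python. This is a toy illustration rather than Modula's implementation: layers are flat lists of floats, `rms_norm` is a stand-in for the operator norms used in the paper, and the `scales` values are hypothetical placeholders for the constants $s_{k}$.

```python
import math


def rms_norm(v):
    # stand-in layer-wise norm; the paper equips layers with operator norms instead
    return math.sqrt(sum(x * x for x in v) / len(v))


def modular_norm(update, scales):
    # max over layers of the scaled layer-wise norm
    return max(s * rms_norm(dw) for dw, s in zip(update, scales))


def normalize(update, scales):
    # rescale each layer's update so that its scaled norm equals one;
    # assumes no layer's update is exactly zero
    return [[x / (s * rms_norm(dw)) for x in dw] for dw, s in zip(update, scales)]


def normed_step(weights, update, scales, lr):
    # apply a normalized update: w <- w - lr * normalize(update)
    unit = normalize(update, scales)
    return [[w - lr * d for w, d in zip(wk, dk)] for wk, dk in zip(weights, unit)]


weights = [[0.5, -0.5], [1.0, 2.0, 3.0]]
raw_update = [[100.0, -50.0], [1e-4, 2e-4, -1e-4]]  # wildly different magnitudes
scales = [2.0, 4.0]                                 # hypothetical s_k values

unit_update = normalize(raw_update, scales)
new_weights = normed_step(weights, raw_update, scales, lr=0.1)
print(modular_norm(unit_update, scales))
```

After normalization every layer sits exactly on its own scaled unit sphere, so the raw magnitudes produced by the base optimizer no longer decide how fast each layer moves; only the scales and the learning rate do.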

For a variety of experiments with normed optimization, see Section 4 and Appendix D. But first, we detail the construction of the modular norm along with its core properties.

Constructing the Modular Norm

Our strategy is to first define the abstract notion of a module, which includes a norm as an attribute. We depict this concept in Figure 2. Then, by providing rules for composing and concatenating modules, we recursively define a norm for any module built via an arbitrary sequence of compositions and concatenations: the modular norm!

A module is a re-usable, composable object useful for building complicated neural networks. Our definition of a module augments the PyTorch module with two real numbers and a norm:

Definition 1. Given input vector space $\mathcal{X}$, output vector space $\mathcal{Y}$ and weight vector space $\mathcal{W}$, a module $\mathsf{M}$ is an object with the following four attributes:

a function, $\mathsf{M}\mathsf{.forward}:\mathcal{W}\times\mathcal{X}\to\mathcal{Y}$, which maps an input and a weight vector to an output; we often abbreviate this attribute to just $\mathsf{M}\equiv\mathsf{M}\mathsf{.forward}$;

a number, $\mathsf{M}\mathsf{.mass}\geq 0$, which will turn out to set the proportion of feature learning that this module contributes to any supermodule;

a number, $\mathsf{M}\mathsf{.sensitivity}\geq 0$, which estimates the module’s sensitivity to input perturbations;

a norm over the weight space, $\mathsf{M}\mathsf{.norm}:\mathcal{W}\to\mathbb{R}_{\geq 0}$, also written $\|{\cdot}\|_{\mathsf{M}}$.

Before we say more about the intended roles of these attributes, let us mention the three kinds of modules that we will care about in practice:

atomic modules, which have weights and whose attributes are hand-declared. Examples include linear modules, embedding modules, and convolution modules.

bond modules, which have no weights and whose attributes are hand-declared. Formally, their weight space is the zero vector space $\mathcal{W}=0$. An example is the $\mathsf{ReLU}$ non-linearity module.

compound modules, built out of other modules, with automatically inferred attributes.

Note that the space of objects that type-check as a module by Definition 1 is vast. Since we need to hand-declare atomic and bond modules in order to build interesting compound modules, we should have an idea of what makes for a “good” module. Simply put, a module is good when its attributes are predictive of its behaviour. To formalize this idea, we say that a module is well-normed if its forward function, sensitivity, and norm satisfy the following two relationships:

Definition 2. Let $\mathsf{M}$ be a module on $(\mathcal{X},\mathcal{Y},\mathcal{W})$, where the input and output spaces have respective norms $\|{\cdot}\|_{\mathcal{X}}$ and $\|{\cdot}\|_{\mathcal{Y}}$. $\mathsf{M}$ is well-normed if for all inputs ${\bm{x}}\in\mathcal{X}$ and weights ${\bm{w}}\in\mathcal{W}$:
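Written in derivative form, which is an assumption chosen here to match the chain-rule argument used later in this section, the two conditions can be sketched as follows, with $\Delta{\bm{w}}$ and $\Delta{\bm{x}}$ denoting arbitrary weight and input perturbations:

```latex
\begin{align*}
\|\nabla_{\bm{w}}\mathsf{M}({\bm{w}},{\bm{x}})\diamond\Delta{\bm{w}}\|_{\mathcal{Y}}
  &\leq \mathsf{M}\mathsf{.norm}(\Delta{\bm{w}})
  &&\text{for all } \Delta{\bm{w}}\in\mathcal{W};\\
\|\nabla_{\bm{x}}\mathsf{M}({\bm{w}},{\bm{x}})\diamond\Delta{\bm{x}}\|_{\mathcal{Y}}
  &\leq \mathsf{M}\mathsf{.sensitivity}\cdot\|\Delta{\bm{x}}\|_{\mathcal{X}}
  &&\text{for all } \Delta{\bm{x}}\in\mathcal{X}.
\end{align*}
```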

Well-normed-ness means that the norm function and sensitivity are a good match for the forward function. The first inequality says that a well-normed module is Lipschitz-continuous over its weight space with constant one. The second inequality says that a well-normed module is Lipschitz-continuous over its input space with constant $\mathsf{M}\mathsf{.sensitivity}$. In practice, we will be interested in well-normed modules where these inequalities hold fairly tightly, since then $\mathsf{M}\mathsf{.sensitivity}$ and $\mathsf{M}\mathsf{.norm}$ will let us estimate the sensitivity of the module to input and weight perturbations. Appendix B provides many examples of well-normed atomic and bond modules.

The remaining attribute $\mathsf{M}\mathsf{.mass}$ will turn out to control the proportion of feature learning that a module contributes to any compound module in which it participates. We formalize this concept in Section 3.3. But before that, we need to understand how to build compound modules.

3.2 Compound modules: Building new modules from old

We consider building new modules from old ones via the binary operations of composition and concatenation, illustrated in Figure 2. Composition is denoted via the serial combination $\mathsf{M}_{2}\circ\mathsf{M}_{1}$, and concatenation via the parallel combination $(\mathsf{M}_{1},\mathsf{M}_{2})$, alternatively referred to as a module tuple. These simple binary combinations will let us build basic algebraic operations on modules (Table 1) as well as complex neural network architectures. We start by defining module composition:

Definition 3. Consider module $\mathsf{M}_{1}$ with input, output and weight space $(\mathcal{X}_{1},\mathcal{Y}_{1},\mathcal{W}_{1})$ and module $\mathsf{M}_{2}$ with input, output and weight space $(\mathcal{X}_{2},\mathcal{Y}_{2},\mathcal{W}_{2})$. $\mathsf{M}_{1}$ and $\mathsf{M}_{2}$ are composable if $\mathcal{X}_{2}=\mathcal{Y}_{1}$. Their composite $\mathsf{M}=\mathsf{M}_{2}\circ\mathsf{M}_{1}$ lives on $(\mathcal{X}_{1},\mathcal{Y}_{2},\mathcal{W}_{1}\times\mathcal{W}_{2})$ with attributes:

$\mathsf{M}\mathsf{.forward}(({\bm{w}}_{1},{\bm{w}}_{2}),{\bm{x}})=\mathsf{M}_{2}\mathsf{.forward}({\bm{w}}_{2},\mathsf{M}_{1}\mathsf{.forward}({\bm{w}}_{1},{\bm{x}}))$;

$\mathsf{M}\mathsf{.mass}=\mathsf{M}_{1}\mathsf{.mass}+\mathsf{M}_{2}\mathsf{.mass}$;

$\mathsf{M}\mathsf{.sensitivity}=\mathsf{M}_{1}\mathsf{.sensitivity}*\mathsf{M}_{2}\mathsf{.sensitivity}$;

$\mathsf{M}\mathsf{.norm}(({\bm{w}}_{1},{\bm{w}}_{2}))$ given by:
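One form of the composite norm consistent with the description in this section is sketched here, under the assumption that each submodule's term is weighted by the mass ratio $\mathsf{M}\mathsf{.mass}/\mathsf{M}_{k}\mathsf{.mass}$ and that the first module's norm is additionally multiplied by the second module's sensitivity:

```latex
\[
\mathsf{M}\mathsf{.norm}(({\bm{w}}_{1},{\bm{w}}_{2}))
=\max\left(
\mathsf{M}_{2}\mathsf{.sensitivity}\cdot
\frac{\mathsf{M}\mathsf{.mass}}{\mathsf{M}_{1}\mathsf{.mass}}\cdot
\mathsf{M}_{1}\mathsf{.norm}({\bm{w}}_{1}),\;
\frac{\mathsf{M}\mathsf{.mass}}{\mathsf{M}_{2}\mathsf{.mass}}\cdot
\mathsf{M}_{2}\mathsf{.norm}({\bm{w}}_{2})
\right),
\]
```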

where if $\mathsf{M}_{1}\mathsf{.mass}$ or $\mathsf{M}_{2}\mathsf{.mass}$ is zero, the corresponding term in the $\max$ is set to zero.

At this stage, we make two comments about this definition. First, in the definition of the composite norm, notice that the norm of the first module couples with the sensitivity of the second module. This reflects the fact that the output of the first module is fed into the second module and not vice versa. Second, observe that the masses of the submodules are involved in setting the balance of the composite norm. Before we further motivate this definition, let us first define module concatenation:

Definition 4. Consider module $\mathsf{M}_{1}$ with input, output and weight space $(\mathcal{X}_{1},\mathcal{Y}_{1},\mathcal{W}_{1})$ and module $\mathsf{M}_{2}$ with input, output and weight space $(\mathcal{X}_{2},\mathcal{Y}_{2},\mathcal{W}_{2})$. We say that $\mathsf{M}_{1}$ and $\mathsf{M}_{2}$ are concatenatable if their input spaces match: $\mathcal{X}_{1}=\mathcal{X}_{2}$. The tuple $\mathsf{M}=(\mathsf{M}_{1},\mathsf{M}_{2})$ has input, output and weight space $(\mathcal{X}_{1},\mathcal{Y}_{1}\times\mathcal{Y}_{2},\mathcal{W}_{1}\times\mathcal{W}_{2})$ and attributes:

$\mathsf{M}\mathsf{.forward}(({\bm{w}}_{1},{\bm{w}}_{2}),{\bm{x}})=(\mathsf{M}_{1}\mathsf{.forward}({\bm{w}}_{1},{\bm{x}}),\mathsf{M}_{2}\mathsf{.forward}({\bm{w}}_{2},{\bm{x}}))$;

$\mathsf{M}\mathsf{.mass}=\mathsf{M}_{1}\mathsf{.mass}+\mathsf{M}_{2}\mathsf{.mass}$;

$\mathsf{M}\mathsf{.sensitivity}=\mathsf{M}_{1}\mathsf{.sensitivity}+\mathsf{M}_{2}\mathsf{.sensitivity}$;

$\mathsf{M}\mathsf{.norm}(({\bm{w}}_{1},{\bm{w}}_{2}))$ given by:
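A form of the tuple norm consistent with the remark that sensitivity does not appear in it is sketched here, again under the assumption that each term is weighted by the mass ratio $\mathsf{M}\mathsf{.mass}/\mathsf{M}_{k}\mathsf{.mass}$:

```latex
\[
\mathsf{M}\mathsf{.norm}(({\bm{w}}_{1},{\bm{w}}_{2}))
=\max\left(
\frac{\mathsf{M}\mathsf{.mass}}{\mathsf{M}_{1}\mathsf{.mass}}\cdot
\mathsf{M}_{1}\mathsf{.norm}({\bm{w}}_{1}),\;
\frac{\mathsf{M}\mathsf{.mass}}{\mathsf{M}_{2}\mathsf{.mass}}\cdot
\mathsf{M}_{2}\mathsf{.norm}({\bm{w}}_{2})
\right),
\]
```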

where if $\mathsf{M}_{1}\mathsf{.mass}$ or $\mathsf{M}_{2}\mathsf{.mass}$ is zero, the corresponding term in the $\max$ is set to zero.

Concatenation is simpler than composition in the sense that neither module is fed through the other, and therefore, sensitivity does not appear in the concatenated norm. To further motivate these definitions, observe that two basic and desirable properties follow as immediate consequences:

Proposition 1. If modules $\mathsf{M}_{1},\mathsf{M}_{2},\mathsf{M}_{3}$ are successively composable, then $\mathsf{M}_{3}\circ(\mathsf{M}_{2}\circ\mathsf{M}_{1})$ equals $(\mathsf{M}_{3}\circ\mathsf{M}_{2})\circ\mathsf{M}_{1}$ in all attributes. If modules $\mathsf{M}_{1},\mathsf{M}_{2},\mathsf{M}_{3}$ are mutually concatenatable, then $((\mathsf{M}_{1},\mathsf{M}_{2}),\mathsf{M}_{3})$ equals $(\mathsf{M}_{1},(\mathsf{M}_{2},\mathsf{M}_{3}))$ in all attributes.

Proposition 2. If modules $\mathsf{M}_{1}$ and $\mathsf{M}_{2}$ are well-normed and composable, then their composite $\mathsf{M}_{2}\circ\mathsf{M}_{1}$ is also well-normed. If modules $\mathsf{M}_{1}$ and $\mathsf{M}_{2}$ are well-normed and concatenatable, then their tuple $(\mathsf{M}_{1},\mathsf{M}_{2})$ is also well-normed with respect to the $L^{1}$ combination norm on the output space: $\|{(\cdot,\cdot)}\|_{\mathcal{Y}_{1}\times\mathcal{Y}_{2}}=\|{\cdot}\|_{\mathcal{Y}_{1}}+\|{\cdot}\|_{\mathcal{Y}_{2}}$.

The proofs follow directly from the definitions and the chain rule. Proposition 1 implies that one may build complicated compound modules without worrying in which order successive combinations are taken. Proposition 2 implies that complicated compounds automatically inherit Lipschitz guarantees.

Taken together, Definitions 3 and 4 define the modular norm $\mathsf{M}\mathsf{.norm}$ of any compound module $\mathsf{M}$.
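The recursion can be sketched in plain Python. The class `Mod` and the functions `compose` and `concatenate` are hypothetical names, not Modula's API, and the norm formulas assume that each submodule's term in the max is weighted by the ratio of the compound mass to the submodule mass. The example then checks numerically that the two ways of bracketing a triple composition agree in mass, sensitivity and norm.

```python
from dataclasses import dataclass
from typing import Callable


@dataclass
class Mod:
    mass: float
    sensitivity: float
    norm: Callable  # maps this module's weights to a non-negative float


def _scaled(total_mass, sub, value):
    # a zero-mass submodule contributes a zero term to the max
    return 0.0 if sub.mass == 0 else (total_mass / sub.mass) * value


def compose(m2, m1):
    """m2 applied after m1; weights are the pair (w1, w2)."""
    mass = m1.mass + m2.mass

    def norm(w):
        w1, w2 = w
        return max(_scaled(mass, m1, m2.sensitivity * m1.norm(w1)),
                   _scaled(mass, m2, m2.norm(w2)))

    return Mod(mass, m1.sensitivity * m2.sensitivity, norm)


def concatenate(m1, m2):
    """the tuple (m1, m2); weights are the pair (w1, w2)."""
    mass = m1.mass + m2.mass

    def norm(w):
        w1, w2 = w
        return max(_scaled(mass, m1, m1.norm(w1)),
                   _scaled(mass, m2, m2.norm(w2)))

    return Mod(mass, m1.sensitivity + m2.sensitivity, norm)


# three toy atoms with scalar weights and the absolute value as norm
a = Mod(mass=1.0, sensitivity=2.0, norm=abs)
b = Mod(mass=3.0, sensitivity=0.5, norm=abs)
c = Mod(mass=2.0, sensitivity=1.5, norm=abs)

left = compose(c, compose(b, a))    # weights nest as ((wa, wb), wc)
right = compose(compose(c, b), a)   # weights nest as (wa, (wb, wc))
wa, wb, wc = 0.3, -1.2, 0.7
left_norm = left.norm(((wa, wb), wc))
right_norm = right.norm((wa, (wb, wc)))
pair = concatenate(a, b)
print(left_norm, right_norm, pair.sensitivity)
```

Under these assumptions the nested norm flattens to a max over the atoms, each term carrying the total mass divided by the atom's mass and the product of the sensitivities of all modules downstream of it, which is why the bracketing does not matter.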

3.3 Mass allocation in compound modules

Suppose we wish to train a network with an input layer, an output layer, and $L$ blocks between:

Then how much learning should happen in the output layer, compared to the blocks, compared to the input layer? And what if we scale the number of blocks $L$: do we want relatively less learning to occur in the network’s extremities? Or do we want the input and output layers to learn non-trivially even in the $L\to\infty$ limit? Since answering these questions is difficult a priori, we introduced the mass parameter to allow a user to set the proportional contribution each module has toward learning:

Proposition 3. Consider a compound module $\mathsf{M}$ derived in any fashion from $L$ well-normed modules $\mathsf{M}_{1},\ldots,\mathsf{M}_{L}$. Given weight setting ${\bm{w}}=({\bm{w}}_{1},\ldots,{\bm{w}}_{L})$, where ${\bm{w}}_{k}$ denotes the weights of module $\mathsf{M}_{k}$, let us perturb ${\bm{w}}$ by $\Delta{\bm{w}}=(\Delta{\bm{w}}_{1},\ldots,\Delta{\bm{w}}_{L})$. If we decompose the linearized change in the output of module $\mathsf{M}$ into one contribution per sub-module:

then the $k$th term in this decomposition satisfies:
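Spelled out in the notation above, and offered as a sketch consistent with the composition and concatenation rules rather than as a verbatim statement, the decomposition and the bound on its $k$th term (for sub-modules of non-zero mass) can be written as:

```latex
\begin{align*}
\nabla_{\bm{w}}\mathsf{M}({\bm{w}},{\bm{x}})\diamond\Delta{\bm{w}}
  &=\sum_{k=1}^{L}\nabla_{{\bm{w}}_{k}}\mathsf{M}({\bm{w}},{\bm{x}})\diamond\Delta{\bm{w}}_{k},\\
\|\nabla_{{\bm{w}}_{k}}\mathsf{M}({\bm{w}},{\bm{x}})\diamond\Delta{\bm{w}}_{k}\|_{\mathcal{Y}}
  &\leq\frac{\mathsf{M}_{k}\mathsf{.mass}}{\mathsf{M}\mathsf{.mass}}\cdot\|\Delta{\bm{w}}\|_{\mathsf{M}}.
\end{align*}
```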

In words: module mass provides the flexibility needed to build complicated compound modules involving many sub-modules, while maintaining precise control over how much learning any sub-module can contribute to the overall compound. Proposition 3 is proved in Appendix E.

In practice, we obtained the best training performance by maintaining a constant amount of learning in the input and output layers even as the number of blocks is scaled (Figure 6). In other words, it seems to be a good idea to assign $\mathsf{OutputLayer}\mathsf{.mass}:\mathsf{HiddenLayers}\mathsf{.mass}:\mathsf{InputLayer}\mathsf{.mass}$ in proportion $1:m:1$, where $m$ is independent of the number of blocks $L$. The exact mass of the hidden layers $m$ needs to be tuned on a new architecture, just as one needs to tune separate learning rates in the input and output layers in $\mu$P; this tuning can be done on a small model prior to scaling (Figure 3). We further discuss mass allocation in Section D.6.
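As a small numerical illustration of the $1:m:1$ allocation, the sketch below computes each part's share of the total mass at several depths. It reads Proposition 3 as bounding a sub-module's contribution by its share of the total mass, and it assumes the hidden mass is split evenly across blocks; both the function name and the even split are assumptions made for illustration.

```python
def learning_shares(num_blocks, hidden_mass, end_mass=1.0):
    """Fraction of the total mass held by each part of the network."""
    total = end_mass + hidden_mass + end_mass
    return {
        "input": end_mass / total,
        "per_block": (hidden_mass / num_blocks) / total,  # assumes an even split
        "output": end_mass / total,
    }


for depth in (4, 16, 64):
    shares = learning_shares(num_blocks=depth, hidden_mass=5.0)
    print(depth, shares)
```

The input and output shares stay fixed as depth grows, while each individual block's share shrinks in proportion to $1/L$.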

3.4 Smoothness in the modular norm

In this section, we study the second derivatives of a module using the modular norm as a measuring stick. Let us start by defining the notion of sharpness that we will consider:

Definition 5. Let $\mathsf{M}$ be a module on $(\mathcal{X},\mathcal{Y},\mathcal{W})$, where the input and output spaces have respective norms $\|{\cdot}\|_{\mathcal{X}}$ and $\|{\cdot}\|_{\mathcal{Y}}$. We say that $\mathsf{M}$ is $(\alpha,\beta,\gamma)$-sharp for constants $\alpha,\beta,\gamma\geq 0$ if, at all inputs ${\bm{x}}\in\mathcal{X}$ and weights ${\bm{w}}\in\mathcal{W}$, the second derivatives of $\mathsf{M}$ are bounded as:
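A natural reading of these bounds, sketched under the assumption that $\alpha$, $\beta$ and $\gamma$ control the weight-weight, weight-input and input-input second derivatives respectively, for all perturbations $\Delta{\bm{w}},\Delta\widetilde{{\bm{w}}}\in\mathcal{W}$ and $\Delta{\bm{x}},\Delta\widetilde{{\bm{x}}}\in\mathcal{X}$:

```latex
\begin{align*}
\|\Delta{\bm{w}}\diamond\nabla^{2}_{{\bm{w}}{\bm{w}}}\mathsf{M}\diamond\Delta\widetilde{{\bm{w}}}\|_{\mathcal{Y}}
  &\leq\alpha\,\|\Delta{\bm{w}}\|_{\mathsf{M}}\,\|\Delta\widetilde{{\bm{w}}}\|_{\mathsf{M}};\\
\|\Delta{\bm{w}}\diamond\nabla^{2}_{{\bm{w}}{\bm{x}}}\mathsf{M}\diamond\Delta{\bm{x}}\|_{\mathcal{Y}}
  &\leq\beta\,\|\Delta{\bm{w}}\|_{\mathsf{M}}\,\|\Delta{\bm{x}}\|_{\mathcal{X}};\\
\|\Delta{\bm{x}}\diamond\nabla^{2}_{{\bm{x}}{\bm{x}}}\mathsf{M}\diamond\Delta\widetilde{{\bm{x}}}\|_{\mathcal{Y}}
  &\leq\gamma\,\|\Delta{\bm{x}}\|_{\mathcal{X}}\,\|\Delta\widetilde{{\bm{x}}}\|_{\mathcal{X}}.
\end{align*}
```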

While one may ultimately be interested in the sharpness of a module with respect to weight perturbations, Definition 5 also tracks sharpness with respect to input perturbations. In fact, tracking this extra information is essential for propagating sharpness bounds up the module tree. Appendix C details the procedure for automatically calculating the sharpness constants of a compound module starting from the sharpness constants of all its submodules; see Propositions 8 and 9 for the specific formulae. Here we highlight one major corollary of these formulae, proved in Appendix E: for a specific choice of block multipliers, the sharpness constant of a residual network is independent of depth:

Suppose $\mathsf{M}$ is a well-normed, $(\alpha,\beta,\gamma)$-sharp module on $(\mathcal{X},\mathcal{X},\mathcal{W})$ with unit sensitivity. Define the depth $L$ residual module $\mathsf{Res}_{L}(\mathsf{M})$ via the module arithmetic of Table 1 as:
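One reading of this definition, sketched under the assumption that the block multipliers are $\tfrac{L-1}{L}$ on the identity branch and $\tfrac{1}{L}$ on the module branch, so that each block has unit sensitivity when $\mathsf{M}$ does:

```latex
\[
\mathsf{Res}_{L}(\mathsf{M})
:=\left(\frac{L-1}{L}\cdot\mathsf{Identity}+\frac{1}{L}\cdot\mathsf{M}\right)^{L}.
\]
```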

Then this residual module $\mathsf{Res}_{L}(\mathsf{M})$ is in fact $(\alpha+\beta+\tfrac{\gamma}{3},\beta+\tfrac{\gamma}{2},\gamma)$-sharp, independent of depth $L$.

For optimization purposes, one may be more interested in the sharpness of the loss function rather than the sharpness of the neural network. Fortunately, it is possible to convert sharpness bounds on modules into sharpness bounds on loss functions, provided a little is known about the error measure:

If the module $\mathsf{M}$ is well-normed and $(\alpha,\beta,\gamma)$-sharp, then the loss function $\mathcal{L}$ satisfies the following three inequalities at all weight settings ${\bm{w}}\in\mathcal{W}$ and for all weight perturbations $\Delta{\bm{w}},\Delta\widetilde{{\bm{w}}}\in\mathcal{W}$:

(i) $|{\Delta{\bm{w}}\diamond\nabla^{2}_{{\bm{w}}{\bm{w}}}\mathcal{L}\diamond\Delta\widetilde{{\bm{w}}}}|\leq(\sigma\alpha+\tau)\,\|{\Delta{\bm{w}}}\|_{\mathsf{M}}\,\|{\Delta\widetilde{{\bm{w}}}}\|_{\mathsf{M}}$;

(ii) $\|{\nabla_{{\bm{w}}}\mathcal{L}({\bm{w}}+\Delta{\bm{w}})-\nabla_{{\bm{w}}}\mathcal{L}({\bm{w}})}\|_{\mathsf{M}}^{*}\leq(\sigma\alpha+\tau)\,\|{\Delta{\bm{w}}}\|_{\mathsf{M}}$, where $\|{\cdot}\|_{\mathsf{M}}^{*}$ is the dual norm of $\|{\cdot}\|_{\mathsf{M}}$;

(iii) $\left|\mathcal{L}({\bm{w}}+\Delta{\bm{w}})-\left[\mathcal{L}({\bm{w}})+\nabla_{\bm{w}}\mathcal{L}\diamond\Delta{\bm{w}}\right]\right|\leq\tfrac{1}{2}(\sigma\alpha+\tau)\,\|{\Delta{\bm{w}}}\|_{\mathsf{M}}^{2}$.

The proof is given in Appendix E, and we present estimates for $\sigma$ and $\tau$ for common error measures in Section C.4. Notice that inequalities (i), (ii) and (iii) are the standard inequalities of smooth optimization, albeit expressed in the modular norm. In fact, (i) implies (ii) implies (iii). In words, inequality (ii) says that the gradient of the loss is Lipschitz-continuous in the modular norm. The Lipschitz constant depends on the module only through the module’s first sharpness coefficient $\alpha$.
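The step from (ii) to (iii) is the usual integration argument from smooth optimization. As a sketch, writing $\lambda=\sigma\alpha+\tau$ and using the defining property of the dual norm, $|{\bm{g}}\diamond\Delta{\bm{w}}|\leq\|{\bm{g}}\|_{\mathsf{M}}^{*}\,\|\Delta{\bm{w}}\|_{\mathsf{M}}$:

```latex
\begin{align*}
\left|\mathcal{L}({\bm{w}}+\Delta{\bm{w}})-\mathcal{L}({\bm{w}})
      -\nabla_{\bm{w}}\mathcal{L}({\bm{w}})\diamond\Delta{\bm{w}}\right|
&=\left|\int_{0}^{1}\big[\nabla_{\bm{w}}\mathcal{L}({\bm{w}}+t\Delta{\bm{w}})
      -\nabla_{\bm{w}}\mathcal{L}({\bm{w}})\big]\diamond\Delta{\bm{w}}\,\mathrm{d}t\right|\\
&\leq\int_{0}^{1}\|\nabla_{\bm{w}}\mathcal{L}({\bm{w}}+t\Delta{\bm{w}})
      -\nabla_{\bm{w}}\mathcal{L}({\bm{w}})\|_{\mathsf{M}}^{*}\,
      \|\Delta{\bm{w}}\|_{\mathsf{M}}\,\mathrm{d}t\\
&\leq\int_{0}^{1}\lambda\,t\,\|\Delta{\bm{w}}\|_{\mathsf{M}}^{2}\,\mathrm{d}t
 =\tfrac{1}{2}\lambda\,\|\Delta{\bm{w}}\|_{\mathsf{M}}^{2}.
\end{align*}
```

Here $\lambda$ plays the same role as the sharpness constant in Equation 2.1, now measured in the modular norm.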

Experiments

Our experiments aimed to test the scalability of training with normed versions of Adam and SGD: whether one can tune the learning rate on a small model, and expect the learning rate to remain close to optimal on models of much larger width and depth. In addition to the learning rate, normed optimization in Modula requires a mass parameter to apportion feature learning between the input, output and hidden layers; we also tested the sensitivity of this parameter, whether it affects learning rate transfer, and to what extent the optimal mass itself transfers across width and depth.

All SGD experiments were done with momentum $\beta=0.9$, and all Adam experiments used $\beta_{1}=0.9$ and $\beta_{2}=0.99$. No weight decay was used in any experiment. Every experiment was done with a linear decay learning rate schedule. As for initialization, we used orthogonal initialization for $\mathsf{Linear}$ and $\mathsf{Conv2D}$ modules, and Gaussian weights projected to a unit norm ball for our $\mathsf{Embed}$ module. This was to ensure all modules were well-normed at initialization. Precise versions of our architectures are described in Appendices B.5 and B.7. We compare with nanoGPT using standard initialization in Section D.4 to make sure our changes recover standard performance. We actually found that unnormed Adam using our GPT architecture transferred learning rate better than in nanoGPT.

We found that normed optimization, with both Adam and SGD as the base optimizer, allows for successful learning rate transfer across width and depth for GPT training on OpenWebText (Figure 1), as well as ResMLP and ResNet training on CIFAR-10 (Figure 4). We present expanded results in Section D.5, including results on test loss. We reproduce the standard finding that train and test loss are remarkably similar in large language model pretraining. As for mass allocation, Figure 3 shows that optimal mass transfers with depth for training a ResMLP on CIFAR-10 with normed Adam, and also that both mass and learning rate transfer quite well from a smaller GPT on OpenWebText to a larger one. We detail more experiments on mass allocation in Section D.6.

Discussion: Limitations and Future Work

This paper was influenced by four main streams of work: first, the Tensor Programs series, starting at TP-IV; second, the papers on universal majorize-minimize algorithms; third, work on deep network metrization; and fourth, the open source deep learning ecosystem including the PyTorch module tree and Karpathy’s YouTube video on autograd. We have distilled and synthesized key ideas from these sources, creating a framework that we believe to be simpler than Tensor Programs, computationally lighter than universal majorization-minimization, more general than prior work on metrization and more scalable than the PyTorch module tree. We have packaged these ideas into a (soon-to-be) open-source library called Modula. Inevitably, Modula has limitations. We highlight some of them here, along with associated avenues for future work.

Loss of well-normed-ness. We have emphasized well-normed-ness (Definition 2) as an important criterion in module design. We show in Section B.1 that, for example, the $\mathsf{Linear}$ module is well-normed when its weights lie within a spectral norm ball. In our experiments, we initialize all weights so that all modules are well-normed, but we do not enforce this property throughout training. Future work could explore regularization as a means to enforce well-normed-ness throughout training, with the hope of attaining better scalability or improved generalization.

Overhead of normalization. As discussed in Section A.3, we implement normalization for $\mathsf{Linear}$ and $\mathsf{Conv2D}$ modules using two steps of online power iteration. While online power iteration is an established and fast primitive in deep learning (in fact, coming from the GAN literature), it does add a modest overhead to training time, as discussed in Section A.4. We think it may be possible to mitigate this overhead by constructing atomic modules with more exotic operator norms. For example, if one equips feature vectors with the $L^{\infty}$ norm rather than the RMS norm, then the induced $L^{\infty}$–$L^{\infty}$ matrix norm is cheaper to compute than the RMS–RMS operator norm. In fact, $L^{\infty}$–$L^{\infty}$ operator normalization has the convenient feature that it decouples over matrix rows, making it more local than spectral normalization and, dare-we-say, more biologically plausible.

Automatic step-size selection. Beyond scalability, recent work has explored the question of automatic learning rate selection, with the Prodigy optimizer serving as a popular example. We tested the Adam version of Prodigy and found it performs well at small scales, essentially working by an implicit form of line search. However, Prodigy will always break at large enough widths, since it requires a lower bound ($d_{0}$) on Adam’s initial learning rate; Yang et al. showed that no such lower bound exists. We believe this issue could be fixed by rebuilding Prodigy on top of Modula. More broadly, we think that designing line search methods in a properly-normed space is a good idea.

Acknowledgements

We are grateful to Chris Mingard, Virgile Richard and Evan Kiely for useful discussions early in the project. Tongzhou Wang and Jyo Pari provided helpful feedback on the writing and figures.

The work was supported by a Packard Fellowship and a Sloan Research Fellowship to PI, by the MIT-IBM Watson AI Lab, by ONR MURI grant N00014-22-1-2740 and the MIT Quest for Intelligence. TL was supported by a Simons Junior Fellowship.

Contribution Statement

All authors were involved in project conception and discussions, which were initiated by JB. TL developed the theory with input from JB. MH and YL made core experimental observations. YL, MH, JB, and HB ran experiments. TL and JB did most of the writing, while JB, MH and YL made the figures. PI contributed guidance and helpful feedback throughout the course of the project. JB wrote the Modula package with help from MH.

References

Appendix A The Modula Package

We created a Python package called Modula that realizes our module framework in code. Modula supplements PyTorch’s Tensor class with two new classes: Vector and Module.

The Vector class is used to store the weights of a module. It allows for basic algebraic operations to be performed on module weights without needing to write for loops over lists of tensors. For example, if v_1 and v_2 are vectors with the same sub-structure, then one may write expressions such as v_1 + v_2 for the vector sum, or v_1 * v_2 for the elementwise product. Internally, a Vector stores a list of tensors and implements operations using efficient PyTorch foreach primitives.
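The idea can be illustrated with a toy stand-in written in plain Python. This is not Modula's `Vector`: it stores flat lists of floats instead of PyTorch tensors and uses ordinary list comprehensions instead of foreach primitives, but it shows how overloading arithmetic removes the explicit loops over lists of tensors.

```python
class Vector:
    """Toy stand-in: wraps a list of per-parameter arrays (here, flat float lists)."""

    def __init__(self, tensors):
        self.tensors = [list(t) for t in tensors]

    def _zip_with(self, other, op):
        # apply op elementwise, tensor by tensor; other may be a Vector or a scalar
        if isinstance(other, Vector):
            assert len(self.tensors) == len(other.tensors), "sub-structure mismatch"
            return Vector([[op(a, b) for a, b in zip(s, o)]
                           for s, o in zip(self.tensors, other.tensors)])
        return Vector([[op(a, other) for a in s] for s in self.tensors])

    def __add__(self, other):
        return self._zip_with(other, lambda a, b: a + b)

    def __sub__(self, other):
        return self._zip_with(other, lambda a, b: a - b)

    def __mul__(self, other):
        return self._zip_with(other, lambda a, b: a * b)

    __rmul__ = __mul__  # allows scalar * Vector

    def __repr__(self):
        return f"Vector({self.tensors})"


v_1 = Vector([[1.0, 2.0], [3.0]])
v_2 = Vector([[10.0, 20.0], [30.0]])
total = v_1 + v_2          # vector sum
product = v_1 * v_2        # elementwise product
stepped = v_1 - 0.1 * v_2  # a gradient-style update with no explicit loops
print(total, product, stepped)
```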

A.2 The Module class

The most significant aspect of the Modula package is the Module class. A Module must have six attributes: two float attributes, namely mass and sensitivity, and four methods:

forward(w: Vector, x: Tensor) -> Tensor # returns an output tensor

initialize() -> Vector # randomly samples a weight vector

normalize(w: Vector) # normalizes w to have unit modular norm

regularize(w: Vector, strength: float) # regularizes w in-place

The norm of a module is not implemented explicitly. Instead, we use the normalize method, which is how the norm enters directly into optimization.

We refer to modules with hand-specified attributes as bonds if they have no weights and atoms if they have weights. Modules formed by combining existing modules are called compounds. Modula automatically constructs the attributes of compound modules. We provide reference implementations for many common modules—see Appendix B. We equip atoms with their natural operator norm, and compute spectral norms via online power iteration. Reference modules may be imported as follows:

from modula.bond import Identity, ReLU, Abs, FunctionalAttention
from modula.atom import Linear, Embed, Conv2D
from modula.compound import ResMLP, ResCNN, Attention, GPT

To make building new compounds easier, Modula overloads the following operations on modules:

M_2 @ M_1 # composes module M_2 with module M_1

(M_1, M_2) # acts as a tuple module in any further composition

M ** L # returns the Lth iterate of module M

Combining these operations, a single expression builds an L-layer residual network from base module M. Comparing with Equation 3.10, we see that Modula expressions closely resemble their mathematical counterparts.
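To illustrate how operator overloading can propagate module attributes, here is a toy sketch. ToyModule is hypothetical and carries no weights, so it is not Modula’s Module class. It assumes the following rules: under composition, masses add and sensitivities multiply; under addition, masses and sensitivities both add; scalar multiplication by a scales the sensitivity by |a| and leaves the mass unchanged. The residual expression at the end mirrors the block used in the proof of Proposition 4.

```python
class ToyModule:
    """Hypothetical weight-free module: a function plus mass and sensitivity."""

    def __init__(self, fn, mass, sensitivity):
        self.fn, self.mass, self.sensitivity = fn, mass, sensitivity

    def forward(self, x):
        return self.fn(x)

    def __matmul__(self, other):
        # self @ other applies `other` first, then `self`.
        return ToyModule(
            lambda x: self.fn(other.fn(x)),
            self.mass + other.mass,
            self.sensitivity * other.sensitivity,
        )

    def __add__(self, other):
        # Module sum: feed x to both modules and add the outputs.
        return ToyModule(
            lambda x: self.fn(x) + other.fn(x),
            self.mass + other.mass,
            self.sensitivity + other.sensitivity,
        )

    def __rmul__(self, a):
        # a * M rescales the output of M by the scalar a.
        return ToyModule(lambda x: a * self.fn(x), self.mass, abs(a) * self.sensitivity)

    def __pow__(self, L):
        # M ** L composes L copies of M.
        out = self
        for _ in range(L - 1):
            out = self @ out
        return out


identity = ToyModule(lambda x: x, mass=0.0, sensitivity=1.0)
M = ToyModule(abs, mass=1.0, sensitivity=1.0)  # toy unit-sensitivity base module

L = 4
block = (1 - 1 / L) * identity + (1 / L) * M
net = block ** L
print(net.mass, net.sensitivity, net.forward(-4.0))
```

Because each block has unit sensitivity, the sensitivity of the composite stays at one for any depth, while its mass grows linearly with L until it is tared.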

Finally, all modules come with a convenience method tare(m: float), which resets the module mass to m, with default m=1.

A.3 Normalization in Modula

We can normalize any base optimizer in the modular norm using the following pattern:

delta_w = optim(w.grad()) # get update from base optimizer
net.normalize(delta_w) # normalize update in the modular norm
w -= lr * delta_w # apply update to weights

Computation of net.normalize(delta_w) requires an efficient estimation of the spectral matrix norm, in the last two dimensions, of the constituent tensors of delta_w; this can be done very quickly to reasonable accuracy using power iteration. We implement this by storing a running estimate of the top singular vector u for each constituent tensor of delta_w. At initialization, u is sampled Gaussian, and each time we normalize a weight update, the previous update’s estimated singular vector is used as the starting value for the power iteration. This enables us to use just two steps of power iteration per weight update. Indeed, for any base optimizer with momentum, successive weight updates should be fairly close; for training without momentum more steps of power iteration may be required.
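The warm-started power iteration described above can be sketched in pure Python as follows. This is an illustrative sketch under stated assumptions, not the code in the Modula package: the function names are hypothetical, matrices are stored as lists of rows, and the same fixed matrix is reused across calls, whereas in training the update changes slightly from step to step.

```python
import math
import random


def matvec(A, v):
    """Return A v for a matrix stored as a list of rows."""
    return [sum(a * x for a, x in zip(row, v)) for row in A]


def rmatvec(A, u):
    """Return A^T u."""
    return [sum(A[i][j] * u[i] for i in range(len(A))) for j in range(len(A[0]))]


def unit(v):
    """Return (v / ||v||, ||v||); assumes v is nonzero."""
    n = math.sqrt(sum(x * x for x in v))
    return [x / n for x in v], n


def spectral_norm_step(A, u, steps=2):
    """Run a few power-iteration steps warm-started from u.

    Returns (sigma, u): an estimate of the largest singular value of A and
    the refined singular vector estimate to reuse on the next call.
    """
    sigma = 0.0
    for _ in range(steps):
        v, _ = unit(rmatvec(A, u))     # v is proportional to A^T u
        u, sigma = unit(matvec(A, v))  # u is proportional to A v; sigma = ||A v||
    return sigma, u


rng = random.Random(0)
A = [[3.0, 0.0], [0.0, 1.0]]             # largest singular value is 3
u = [rng.gauss(0.0, 1.0) for _ in A]     # Gaussian initialization of u
for _ in range(20):                      # stands in for 20 optimizer steps
    sigma, u = spectral_norm_step(A, u)  # u is carried over between calls
print(sigma)
```

The estimate never exceeds the true spectral norm, and carrying u between calls is what lets two steps per call suffice once successive matrices are close.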

A.4 Overhead

To test the overhead of normalization in the modular norm, we trained a width 64 ResMLP with 8 blocks and block-depth 2 for 10k steps on the CIFAR-10 dataset. We repeated the experiment with and without normalization, and in each case with three different random seeds. Without normalization, the training took $101\pm 1$ seconds, and with normalization the training took $124\pm 1$ seconds. So in this experiment, the overhead of modular normalization was around 23%.
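As a quick check of the quoted figure, the relative overhead follows from the two mean timings, ignoring the one-second spread:

```latex
\[
\frac{124 - 101}{101} \;=\; \frac{23}{101} \;\approx\; 0.228 .
\]
```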

We note that the user of the Modula package is free to write new atomic modules with cheaper or more efficient normalize functions. For instance, the Frobenius norm can be used as a proxy for the spectral norm whenever the weight updates have low stable rank. And we note in Section 5 that one could explore more exotic norms such as the $L^{\infty}$ to $L^{\infty}$ operator norm, which is cheaper to compute than the standard spectral norm. Beyond these suggestions, one could explore CUDA-level optimizations to spectral norm computation, which is something that we have not explored.

Appendix B Module and Network Design

In this appendix, we list the basic, hand-declared modules that serve as building blocks for more complicated neural networks. Then we go on to show how these modules may be combined to yield interesting neural networks. This includes discussion of module broadcasting (Section B.3) and mass taring (Section B.4). The appendix culminates with case studies on attention (Section B.6) and transformers (Section B.7).

An atomic module or atom for short is a module with nonzero mass and nonzero parameter space, whose attributes are specifically declared rather than derived. Setting an atom’s mass to zero has the effect of freezing its weights under normed optimization.

These conditions will be automatically satisfied for many neural networks under orthogonal initialization of the weights, and especially if a linear module is immediately preceded by something like a $\mathsf{LayerNorm}$ module. Moreover, orthogonal initialization guarantees that the well-normed inequality

holds tightly in nearly-square matrices at initialization, which is important for getting good signal propagation through the whole network.

Moreover, inspection of the second derivative formula above shows it is always $(0,1,0)$-sharp with respect to the RMS norms on the input and output spaces.

This is at first sight similar to the linear module, the key difference being that in applications we expect the inputs of $\mathsf{Embed}(n,d)$ to be one-hot vectors; as such we consider its input space to carry the $L^{1}$-norm.

B.2 Bond modules

A bond module or bond is a module with zero mass and zero parameter space. They are the “glue” between the atomic modules, needed to construct complex neural networks.

Note that we need not specify a weight space, or mass or norm arguments for a bond module. Moreover, when discussing whether a bond module is $(\alpha,\beta,\gamma)$-sharp, the inequalities for $\alpha$ and $\beta$ are vacuous; thus for bond modules we will abbreviate this notion to $\gamma$-sharp.

To begin, we need two bond modules that are essentially “utility”, as they are crucial for defining basic secondary module operations. These modules are also “type polymorphic” in the sense that they work with any underlying vector space.

For any vector space $\mathcal{Y}$, the adder module $\mathsf{Add}$ has inputs $\mathcal{Y}\times\mathcal{Y}$ and outputs $\mathcal{Y}$. It has forward function

and sensitivity 1. Its significance is that it allows for concatenable modules to be added:

For any norm $\|{\cdot}\|_{\mathcal{Y}}$ on the vector space $\mathcal{Y}$, $\mathsf{Add}$ is well-normed with respect to the $L^{1}$ combination norm $\|({\bm{y}}_{1},{\bm{y}}_{2})\|_{\mathcal{Y}\times\mathcal{Y}}:=\|{\bm{y}}_{1}\|_{\mathcal{Y}}+\|{\bm{y}}_{2}\|_{\mathcal{Y}}$ on its input space. Furthermore, since $\mathsf{Add}$ is a linear mapping, it is $0$-sharp.

For any normed vector space $\mathcal{Y}$ and real number $\lambda$, the scalar multiplier module $\mathsf{Mul}_{\lambda}$ has inputs $\mathcal{Y}$ and outputs $\mathcal{Y}$. Its forward function is:

and its sensitivity is $|\lambda|$. Its significance is that it allows for scalar multiplication of modules:

It is well-normed with respect to any norm on $\mathcal{Y}$ and, being linear, $0$-sharp. When $\lambda=1$, we call this the identity module $\mathsf{Identity}=\mathsf{Mul}_{1}$. Note that $\lambda*\mathsf{Identity}=\mathsf{Mul}_{\lambda}$ for any $\lambda$.

The remaining bond modules are used explicitly as non-linearities in neural networks.

and RMS norm on inputs and outputs. For more on this design decision, see . We also define $\mathsf{ScaledReLU}:=\sqrt{2}*\mathsf{ReLU}$ to be the unit sensitivity counterpart to $\mathsf{ReLU}$.
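A rough numerical illustration of the $\sqrt{2}$ factor follows. It assumes standard Gaussian inputs, which is an assumption of this sketch rather than a condition of the definition: for such inputs ReLU discards about half of the mean squared value, so multiplying by $\sqrt{2}$ brings the RMS norm of the output back to roughly that of the input. This is a statistical sanity check only, not a proof of the well-normed inequality.

```python
import math
import random


def rms(xs):
    """Root-mean-square norm of a list of floats."""
    return math.sqrt(sum(x * x for x in xs) / len(xs))


def scaled_relu(xs):
    """sqrt(2) * ReLU, applied entrywise."""
    return [math.sqrt(2.0) * max(x, 0.0) for x in xs]


rng = random.Random(0)
x = [rng.gauss(0.0, 1.0) for _ in range(100_000)]
ratio = rms(scaled_relu(x)) / rms(x)
print(ratio)  # close to 1 for Gaussian inputs
```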

where $\Phi(x)=\int_{-\infty}^{x}\tfrac{1}{\sqrt{2\pi}}e^{-t^{2}/2}\,dt$ is the cumulative distribution function of the standard Gaussian.

$\mathsf{GELU}$ is well-normed in the same sense as $\mathsf{ReLU}$. We similarly set $\mathsf{ScaledGeLU}=\sqrt{2}*\mathsf{GELU}$.

and has sensitivity 1. It is well-normed, and since it is a linear mapping, it is 0-sharp.

and has sensitivity 1. While it is not automatically well-normed, as long as its inputs have $\|{\bm{x}}\|_{\mathsf{RMS}}\approx 1$, the required inequality is not very far off. Similarly, it is approximately $1$-sharp.

As with $\mathsf{RMSDivide}$, it is approximately well-normed and approximately $1$-sharp.

B.3 Module broadcasting

Let us briefly discuss a supplementary module operation, which we refer to as module broadcasting.

Suppose $\mathsf{M}$ is a module with inputs $\mathcal{X}$, outputs $\mathcal{Y}$ and weights $\mathcal{W}$. Then for any $h\geq 1$, the $h$-times-broadcast of $\mathsf{M}$ is the module $\mathsf{M}^{(h)}$ with the same weight space $\mathcal{W}$, mass, sensitivity and norm as $\mathsf{M}$, but inputs the Cartesian power $\mathcal{X}^{h}=\mathcal{X}\times\ldots\times\mathcal{X}$ and outputs $\mathcal{Y}^{h}=\mathcal{Y}\times\ldots\times\mathcal{Y}$, and forward function

Since this does not define a module with a new set of weights, we will usually just refer to the broadcast module by the same name $\mathsf{M}$, and consider this as just an extension of its forward function.

If $\mathsf{M}$ is well-normed, then so is any broadcast of $\mathsf{M}$ taking $\mathcal{X}^{h}$ to $\mathcal{Y}^{h}$, as long as the norms on $\mathcal{X}^{h}$ and $\mathcal{Y}^{h}$ are taken to be either the “mean $L^{p}$” norms

for $1\leq p\leq\infty$; when $p=\infty$ this is just the max norm. In the case that $\mathsf{M}$ is a bond module (so $\mathcal{W}=0$), any scalar multiple of the mean $L^{p}$ norm can be used (including the standard $L^{p}$ norm).
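A small sketch of broadcasting and of the mean $L^{p}$ norm is given below. The helper names are hypothetical, and the formula used for the norm, $\big(\tfrac{1}{h}\sum_{i}\|x_{i}\|^{p}\big)^{1/p}$, is our reading of “mean $L^{p}$”, consistent with the scale factor $S=h^{-1/p}$ used in Section C.3.

```python
import math


def broadcast(forward, xs):
    """Apply one forward function (shared weights) to each of the h inputs."""
    return [forward(x) for x in xs]


def mean_lp_norm(norms, p):
    """Combine per-copy norms; p = math.inf gives the max norm."""
    if p == math.inf:
        return max(norms)
    return (sum(n ** p for n in norms) / len(norms)) ** (1.0 / p)


print(broadcast(abs, [-1.0, 2.0, -3.0]))
print(mean_lp_norm([3.0, 4.0], 2))
print(mean_lp_norm([3.0, 4.0], math.inf))
print(mean_lp_norm([1.0] * 7, 3))
```

With this normalization, a tuple whose entries all have unit norm itself has unit norm for every $h$ and $p$, which is why the weight-space estimates carry over to the broadcast module unchanged.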

The situation for sharpness is a bit more complicated; we discuss this in Section C.3.

B.4 Mass taring

In order to make working with the mass parameter of modules a bit easier, let us introduce an auxiliary operation:

This way, one can build complex modules starting from atomic modules with unit masses, and then use $\mathsf{tare}$ later to reset their masses to desired quantities for better feature learning with normed descent as in Proposition 3.

B.5 Compound modules and neural networks

Composition, concatenation and the secondary operations of addition, scalar multiplication and iterated concatenation allow us to build a wide variety of neural networks which thus come automatically endowed with the modular norm.

Deep neural networks are typically built as long series of compositions. Let us introduce some terminology:

A deep neural network is a module M\mathsf{M} formed by a composition

where $\mathsf{InputLayer},\mathsf{Block}_{1},\ldots,\mathsf{Block}_{L},\mathsf{OutputLayer}$ are modules; the number of blocks $L\geq 1$ is the depth of the network.

Typically, each of $\mathsf{Block}_{1},\ldots,\mathsf{Block}_{L}$ will be a copy of the same module (allowing them to take different weight values, of course), so that the network can be written as an iterated composition

$\mathsf{InputLayer},\mathsf{Block},\mathsf{OutputLayer}$ can in principle be any modules one likes, but $\mathsf{InputLayer}$ is often some form of embedding module, and $\mathsf{OutputLayer}$ is usually a linear module.

As for the form of Block\mathsf{Block}, we found the following design principle to be quite useful in practice:

Arrange so that each $\mathsf{Block}$ has unit sensitivity.

This ensures that the sensitivity of the whole network stays bounded as $L\to\infty$ (this will also be the case if we ensure that $\mathsf{Block}\mathsf{.sensitivity}=1+O(1/L)$, but unit sensitivity has the advantage that the modular norm becomes very explicit). With this in mind:

Suppose that $\mathsf{M}$ is a module of unit sensitivity whose inputs and outputs are the same space $\mathcal{X}$. For any $L\geq 1$, consider the residual block

and write $\mathsf{Res}_{L}(\mathsf{M})=\mathsf{Block}^{L}$. This is of unit sensitivity, well-normed if $\mathsf{M}$ is, and moreover by Proposition 4 is sharp with $O(1)$ sharpness if $\mathsf{M}$ is.
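To see why the per-block sensitivity matters for depth, the following sketch compares three choices. It assumes that sensitivities multiply under composition, so that $L$ copies of a block with sensitivity $s$ have sensitivity $s^{L}$; the function name is hypothetical.

```python
import math


def network_sensitivity(block_sensitivity, depth):
    """Sensitivity of `depth` composed copies of one block."""
    return block_sensitivity ** depth


for L in (10, 100, 1000):
    unit = network_sensitivity(1.0, L)             # stays exactly 1
    slack = network_sensitivity(1.0 + 1.0 / L, L)  # bounded above by e
    fixed = network_sensitivity(1.1, L)            # grows exponentially in L
    print(L, unit, round(slack, 4), f"{fixed:.3e}")
```

Unit sensitivity and the $1+O(1/L)$ regime both stay bounded as the depth grows, whereas any fixed sensitivity above one blows up.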

A general residual network with residue $\mathsf{M}$ is any neural network of the form

In practice, we will want to apply one more operation: we will want to tare the mass of the residual blocks. To this end, the residual network with residue $\mathsf{M}$, depth $L$ and total block mass $m > 0$ is

Let us give two basic examples of residual networks.

This is a simple residual variation on the multi-layer perceptron. For a width $d\geq 1$, consider the unit sensitivity module

This particular order of operations is inspired by a recent paper of Yang et al.

We invite the reader to compare this to something like $\mathsf{ReLU}\circ\mathsf{Linear}(d,d)\circ\mathsf{LayerNorm}$: the same three core operations are performed, but in a different order in the two cases. The inputs are normalized; the inputs are centered; and the inputs are passed through a nonlinearity that mutates just the negative coordinates.

The $\mathsf{ResMLP}$ network has as its residue an iterated composition of $\mathsf{M}(d)$, where the number of copies of $\mathsf{M}(d)$ in each residue is called the block depth and denoted $B$. It also has just linear initial and final modules. Thus the $\mathsf{ResMLP}$ network of depth $L$, width $d$, block depth $B$ and total block mass $m > 0$ is

Usually we suggest taking $B=1$ or $2$, and $m\sim 1$.

This is a version of ResNet for image classification tasks. For a width $d\geq 1$ and kernel size $K$, consider, similarly to above, the unit sensitivity module

As in the $\mathsf{ResMLP}$, the $\mathsf{ResNet}$ network is a residual network whose residue is an iterated composition of $B$ copies of $\mathsf{M}(d,K)$, where $B$ is the block depth. Its initial and final modules are given by

As defaults, we suggest taking $B=2$, $K=3$ and $m\sim 20$.

B.6 Case study I: Attention

The core of the attention module is a bond module which we call functional attention.

Moreover, we set $\mathsf{FuncAttention}\mathsf{.sensitivity}=1$.

In theory, one could try to break attention up further into more basic constituent modules (such as scaled dot product, softmax, etc.), but keeping $\mathsf{FuncAttention}$ as the basic unit allows one to leverage efficient implementations of attention such as FlashAttention.

In fact, a perhaps surprising result is that with the above $\frac{1}{d_{Q}}$ scaling of the dot product, we can estimate the sensitivity and sharpness of $\mathsf{FuncAttention}$. This relies on giving norms for the input and output spaces; these norms are chosen to be

Over the space of inputs ${\bm{q}},{\bm{k}},{\bm{v}}$ with each $\|{\bm{q}}\|_{\infty\mathsf{RMS}},\|{\bm{k}}\|_{\infty\mathsf{RMS}},\|{\bm{v}}\|_{\infty\mathsf{RMS}}\leq 1$, the functional attention module $\mathsf{FuncAttention}$ is well-normed, and moreover is sharp with sharpness constant $\gamma=3$.

The proof is given in Appendix E. We thus choose to adopt a $\frac{1}{d_{Q}}$-dot-product scaling in our implementation of attention; a rigorous bound as above is not possible for $\frac{1}{\sqrt{d_{Q}}}$-scaling, for instance.
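For concreteness, here is a pure-Python sketch of scaled dot-product attention with $\frac{1}{d_{Q}}$ scaling. It is a hedged illustration rather than the Modula implementation: the function name, the list-of-lists layout and the optional causal flag are assumptions, and batching and heads are omitted.

```python
import math


def func_attention(q, k, v, causal=False):
    """q, k: lists of length-d_Q vectors; v: list of length-d_V vectors."""
    d_q = len(q[0])
    out = []
    for i, q_i in enumerate(q):
        last = i + 1 if causal else len(k)  # positions this query may attend to
        # 1/d_Q scaling of the dot product, rather than 1/sqrt(d_Q).
        scores = [sum(a * b for a, b in zip(q_i, k_j)) / d_q for k_j in k[:last]]
        m = max(scores)  # subtract the max for numerical stability
        w = [math.exp(s - m) for s in scores]
        z = sum(w)
        out.append(
            [sum(w_j * v_j[c] for w_j, v_j in zip(w, v)) / z for c in range(len(v[0]))]
        )
    return out


q = [[1.0, 0.0]]
k = [[0.0, 0.0], [0.0, 0.0]]
v = [[1.0], [3.0]]
print(func_attention(q, k, v))  # equal scores, so the output averages the values
```

One way to see the appeal of this scaling: if each query and key vector has RMS norm at most one, then by Cauchy-Schwarz the dot product has magnitude at most $d_{Q}$, so the scaled scores lie in $[-1,1]$ regardless of the head dimension.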

We can then immediately define a single head of attention.

The scalar multiplication factor of $\frac{1}{3}$ ensures that $\mathsf{Attention}$ has unit sensitivity.

For multiple heads of attention, we simply take advantage of module broadcasting (Definition 6):

where $\mathsf{FuncAttention}$ is broadcast over the heads dimension. Note that in Modula, we do this by creating dummy bond modules called AddHeads and RemoveHeads to reshape the tensors and create/remove the explicit head dimension.

As in the single-headed case, the scalar multiplication factor of $\frac{1}{3}$ ensures unit sensitivity.

B.7 Case study II: GPT

and form the mass one, sensitivity one module

The depth $L\geq 1$, width $d$, total block mass $m > 0$ $\mathsf{GPT}$ module is thus

We suggest, as a default value, a total block mass of $m\sim 5$.

Appendix C More on Smoothness and Sharpness

All our estimates of sharpness for compound modules, as well as the smoothness estimate Proposition 5 for loss functions, depend on an application of the chain rule to compute second derivatives which in the optimization context is sometimes called the Gauss-Newton decomposition.

Indeed, this amounts to simply the following expression for partial derivatives:

C.2 Sharpness under composition and concatenation

Here, we state the two formulae for computing the sharpness of a composition and a concatenation of two modules. The proofs are given in Appendix E.

Suppose that $\mathsf{M}_{2}$ and $\mathsf{M}_{1}$ are well-normed, composable modules that are respectively $(\alpha_{2},\beta_{2},\gamma_{2})$-sharp and $(\alpha_{1},\beta_{1},\gamma_{1})$-sharp. Under the shorthand that $p_{k}\equiv\frac{\mathsf{M}_{k}\mathsf{.mass}}{\mathsf{M}_{1}\mathsf{.mass}+\mathsf{M}_{2}\mathsf{.mass}}$ and $\mu_{k}\equiv\mathsf{M}_{k}\mathsf{.sensitivity}$, the composite $\mathsf{M}_{2}\circ\mathsf{M}_{1}$ is $(\alpha,\beta,\gamma)$-sharp for:

Suppose that $\mathsf{M}_{1}$ and $\mathsf{M}_{2}$ are well-normed, concatenatable modules that are respectively $(\alpha_{1},\beta_{1},\gamma_{1})$-sharp and $(\alpha_{2},\beta_{2},\gamma_{2})$-sharp. Under the shorthand that $p_{k}\equiv\frac{\mathsf{M}_{k}\mathsf{.mass}}{\mathsf{M}_{1}\mathsf{.mass}+\mathsf{M}_{2}\mathsf{.mass}}$ and $\mu_{k}\equiv\mathsf{M}_{k}\mathsf{.sensitivity}$, the tuple $(\mathsf{M}_{1},\mathsf{M}_{2})$ is $(\alpha,\beta,\gamma)$-sharp for:

Taken together, Propositions 8 and 9 specify a recursive procedure for computing the sharpness of any compound module that is built from a set of well-normed modules of known sharpness.

These two sets of formulas are actually associative, as the reader may verify using their favorite computer algebra package. This means, for instance, that if $\mathsf{M}_{1},\mathsf{M}_{2},\mathsf{M}_{3}$ are successively composable, well-normed and each $(\alpha_{k},\beta_{k},\gamma_{k})$-sharp, then the two sets of sharpness estimates coming from applying the above formulas for $\mathsf{M}_{3}\circ(\mathsf{M}_{2}\circ\mathsf{M}_{1})$ and $(\mathsf{M}_{3}\circ\mathsf{M}_{2})\circ\mathsf{M}_{1}$ actually coincide.

C.3 Sharpness under module broadcasting

Suppose $\mathsf{M}$ is a well-normed module with inputs $\mathcal{X}$, outputs $\mathcal{Y}$ and weights $\mathcal{W}$, and suppose moreover that it is $(\alpha,\beta,\gamma)$-sharp. The broadcast module $\mathsf{M}^{(h)}$ has the same weights, mass, sensitivity and norm, but takes $\mathcal{X}^{h}$ to $\mathcal{Y}^{h}$.

By Proposition 6, $\mathsf{M}^{(h)}$ is well-normed, as long as the norms on $\mathcal{X}^{h}$ and $\mathcal{Y}^{h}$ are taken to be

for $1\leq p\leq\infty$; unless $\mathsf{M}$ is a bond module (and thus weight-less), we must take $S=h^{-1/p}$, otherwise $S$ can be any positive scalar.

A natural question is whether $\mathsf{M}^{(h)}$ is also sharp, and if so what its sharpness constants are, with respect to these norms. More or less the same proof as for Proposition 6 shows that the $\alpha$ and $\beta$ bounds for sharpness are always true, with the same $\alpha,\beta$. The $\gamma$ bound is trickier however, and depends subtly on the chosen $S,p$. We highlight three cases where one can say something interesting.

Case 1. $p=\infty$, $S=1$. For the $L^{\infty}$ norm, we have that $\mathsf{M}^{(h)}$ is $(\alpha,\beta,\gamma)$-sharp with the same $\alpha,\beta,\gamma$ by a more or less immediate proof.

Case 2. $p < \infty$, $S=1$. For the “standard” $L^{p}$-norms, we have that $\mathsf{M}^{(h)}$ is $(\alpha,\beta,\gamma)$-sharp with the same $\alpha,\beta,\gamma$. The proof is direct, using the inequality

Case 3. $p=2$, $S=1/\sqrt{h}$. This is the “RMS norm” case. As in Case 2, one could use a very weak inequality to obtain the pessimistic result that $\mathsf{M}^{(h)}$ is $(\alpha,\beta,\sqrt{h}*\gamma)$-sharp. However, one could also make the following observation: if $h$ is large, and $x_{1},\ldots,x_{h}$ are sampled from any normal distribution $N(\mu,\sigma^{2})$, then

In particular, this justifies the statement that “for large $h$, the broadcast module $\mathsf{M}^{(h)}$ is approximately $(\alpha,\beta,\sqrt{3}*\gamma)$-sharp”. While in actual deep learning contexts, the assumption that $x_{1},\ldots,x_{h}$ are sampled from a normal distribution may not be valid, one should still expect the ratio between the two sides of Equation C.13 to stay $O(1)$ as $h\to\infty$, and so even if the “$\sqrt{3}$ rule” is insufficient, the effective sharpness of the broadcast module should not blow up as $h\to\infty$.

C.4 Smoothness estimates for common error measures

We now present estimates on $\sigma$ and $\tau$ for square and cross-entropy error. Both estimates will be in terms of the value of the average loss function $\mathcal{L}$ itself, rather than being truly global over the entire output space $\mathcal{Y}$. Thus, to apply them to real learning problems, one should measure the average loss $\mathcal{L}$ at initialization, and use this for estimates for $\sigma$ and $\tau$; we are implicitly making the assumption that under gradient descent the loss decreases.

Square error

The desired constants $\sigma,\tau$ can then be computed as maxima:

which from the above formulas amounts exactly to

To translate this into a bound for the average loss function $\mathcal{L}$, note that square root is a concave function. Thus if we have outputs ${\bm{y}}_{1},\ldots,{\bm{y}}_{B}$ with true classes $t_{1},\ldots,t_{B}$, Jensen’s inequality yields

allowing us to use $\sigma=\sqrt{\mathcal{L}}$ as our estimate for Proposition 5.

Cross-entropy error

Consider again the RMS norm on $\mathcal{Y}$. An estimate on $\sigma$ can thus be computed as

using the basic fact that if $p_{1},\ldots,p_{d}$ are non-negative numbers that sum to 1, then

(Indeed, for fixed $p_{t}$, the left hand side is maximized at $p_{1}=1-p_{t}$ and all other $p_{i}=0$; one then easily checks that $2(p-1)^{2}\leq-\log(p)$ for all $0 < p \leq 1$.)
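The scalar inequality $2(p-1)^{2}\leq-\log(p)$ used here can be checked numerically on a grid; the grid resolution below is an arbitrary choice. Analytically, the difference $-\log(p)-2(p-1)^{2}$ has derivative $-(2p-1)^{2}/p\leq 0$ and vanishes at $p=1$, so it is non-negative on the whole interval.

```python
import math

# Evaluate the gap -log(p) - 2 (p - 1)^2 on a grid over (0, 1].
ps = [i / 1000 for i in range(1, 1001)]
gap = min(-math.log(p) - 2.0 * (p - 1.0) ** 2 for p in ps)
print(gap)  # smallest gap on the grid, attained at p = 1
```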

A similar concavity argument to the square error case then enables us to use $\sigma=\sqrt{d}*\sqrt{\mathcal{L}}$ as the first derivative bound for average cross-entropy loss.

The second derivative bound depends on more subtle information geometry. Indeed, $\tau$ can be computed to be

where $\lambda$ is the largest eigenvalue of the matrix $\operatorname{diag}({\bm{p}})-{\bm{p}}{\bm{p}}^{T}$. It is possible for this eigenvalue to be quite large (for instance, if $p_{1}=p_{2}=1/2$ and all other $p_{i}=0$, then $\lambda=1/2$). However, the average eigenvalue is

If we presumed that, in the course of a gradient descent optimizing the weights of a module $\mathsf{M}$, the output perturbations $\nabla\mathsf{M}\diamond\Delta{\bm{w}}$ are only generically aligned with the eigenvectors of $\operatorname{diag}({\bm{p}})-{\bm{p}}{\bm{p}}^{T}$, then we could use the “effective” smoothness bound $\tau=1$.

Perhaps this is a dubious assumption, however. A more conservative, but perhaps still dubious, assumption comes from assuming that the logits ${\bm{y}}$ have roughly $N(0,1)$ entries; at least this could be more or less true at initialization. In this case, the largest eigenvalue $\lambda$ is with high probability bounded as

justifying an “approximate” smoothness bound of $\tau=\sqrt{d}$.
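The worked example from this subsection, $p_{1}=p_{2}=1/2$ with all other $p_{i}=0$, can be verified directly. The sketch below builds $\operatorname{diag}({\bm{p}})-{\bm{p}}{\bm{p}}^{T}$ for $d=4$ (an arbitrary choice), confirms that $(1,-1,0,0)$ is an eigenvector with eigenvalue $1/2$, and computes the mean eigenvalue as the trace divided by $d$; the helper name is hypothetical.

```python
def softmax_hessian(p):
    """Return diag(p) - p p^T as a list of rows."""
    d = len(p)
    return [[(p[i] if i == j else 0.0) - p[i] * p[j] for j in range(d)] for i in range(d)]


p = [0.5, 0.5, 0.0, 0.0]
d = len(p)
H = softmax_hessian(p)

v = [1.0, -1.0, 0.0, 0.0]
Hv = [sum(H[i][j] * v[j] for j in range(d)) for i in range(d)]
mean_eig = sum(H[i][i] for i in range(d)) / d  # trace / d = (1 - sum p_i^2) / d

print(Hv)        # equals 0.5 * v
print(mean_eig)  # much smaller than the top eigenvalue
```

The gap between the top eigenvalue and the mean eigenvalue is what separates the worst-case bound from the “effective” one discussed above.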

Appendix D Experimental Details

All experiments with ResMLP and ResNet are done with the CIFAR-10 image dataset with standard train and test splits. For data augmentation on the training set, we use random crop, random horizontal flip and PyTorch AutoAugment.

For the GPT transformer experiments, we compared three different datasets:

The Shakespeare corpus, using character-level tokens;

The TinyStories database using sub-word level tokenization;

OpenWebText using sub-word level tokenization.

No data augmentation was used on the language data. We used data splitting code from .

D.2 Architectures

Full details of the ResMLP, ResNet and GPT architectures we used are given in Sections B.5 and B.7. In every experiment, we used:

$h=8$ heads for GPT, with query and value dimensions $d_{Q}=d_{V}=d/h$ where $d$ is the embedding dimension (width);

context length $128$ for GPT, except for the experiment in Section D.7.

D.3 Hardware

All experiments were run on NVIDIA GPUs using float32-precision. We used a combination of TITAN-RTX, RTX-3090, V100, Ada6000, and H100 devices. Each data point in the experiments takes up to $5$ hours, depending on the computing device used. We ran over 1000 training runs in total.

D.4 Comparing to standard nanoGPT architecture

Our implementation of GPT in Modula has certain differences from off-the-shelf architectures such as nanoGPT. We would summarize the overall changes to transformer architecture and training in the following three points:

the mathematical architecture has slightly different coefficients;

we initialize weight matrices to be orthogonal rather than Gaussian;

we train using normalized weight updates.

The architectural choices we made were entirely informed by the desire for the network to be well-normed and have unit sensitivity: in particular this means that the network enjoys favorable signal propagation properties. In the language of modules, these architectural changes can be summarized as:

Each residual block in our architecture is of the form

where $\mathsf{Block}=\mathsf{Block}_{\mathsf{MLP}}$ or $\mathsf{Block}_{\mathsf{Attn}}$, compared to $\mathsf{Identity}+\frac{1}{\sqrt{L}}\mathsf{Block}$ suggested for nanoGPT;

We use a scaled dot product attention with $\frac{1}{d_{Q}}$ scaling, rather than $\frac{1}{\sqrt{d_{Q}}}$;

We use several additional scalar multiplications to keep the network of unit sensitivity:

Each Attention module (B.44) has a scalar factor of $\frac{1}{3}$;

Each MLP module (B.45) has a scalar factor of $\sqrt{2}$;

The token and position embeddings (B.49) have a scalar factor of $\frac{1}{2}$.

In Figure 5, we ran a comparison of the performance of the standard (unnormed) Adam optimizer trained on OpenWebText with:

the nanoGPT architecture with Gaussian initialization;

our implementation of GPT with orthogonal initialization.

We found that even without using the normed optimizer, our implementation with orthogonal initialization transferred learning rate better. We suggest that even the base Adam optimizer benefits from the above architectural changes.

D.5 Full sweeps

In Figures 9, 10, 11 and 12, at the end of this Appendix, we report on full learning rate sweep experiments, across width and depth, for GPT on OpenWebText and TinyStories, and ResMLP, ResNet on CIFAR-10.

We consistently find that the normed Adam optimizer matches or outperforms unnormed Adam in both test and training loss, all the while exhibiting significantly better transfer across width. The difference in depth transfer is less stark; however, we posit that, in part, unnormed Adam is already benefiting from architectural changes we made to improve depth scaling.

Notice too that normed SGD consistently and significantly outperforms ordinary SGD, often coming close to or matching the performance of Adam. We would like to highlight this, since SGD has a significantly lower memory requirement than Adam, and does not require any tuning of $\beta_{2}$.

D.6 Mass allocation

A novel feature of our normed optimization framework is the need to choose a mass parameter for each atomic module. In the context of networks of the form

where $\mathsf{HiddenLayers}=\mathsf{Block}^{L}$. We typically do this by assuming that $\mathsf{InputLayer},\mathsf{OutputLayer}$ have mass 1, and by resetting the mass of $\mathsf{HiddenLayers}$ by hand to a fixed total mass $m > 0$, by calling $\mathsf{tare}(\mathsf{HiddenLayers},m)$.

In this Appendix, we explore some different aspects of the choice of $m$.

First, we tested whether or not calling $\mathsf{tare}$ is necessary in the first place. Not using tare would leave the “free mass” of $\mathsf{HiddenLayers}\mathsf{.mass}=L*\mathsf{Block}\mathsf{.mass}$; accordingly, as $L$ grows large, the feature learning allotment (see Proposition 3) for $\mathsf{InputLayer}$ and $\mathsf{OutputLayer}$ would grow smaller. Indeed, as the reader can see in Figure 6, this “free mass” arrangement for a ResMLP network on CIFAR-10, which allows the mass of $\mathsf{HiddenLayers}$ to grow with $L$, is very undesirable, and for good learning rate transfer with depth we must fix a mass.

The mass $m$ is thus left as a tunable parameter. We then tested the transferability of mass tuning. Specifically, we wanted to know:

whether one can tune $m$ on a network of small width/depth, and expect that same $m$ to be close to optimal on a larger network;

whether learning rate transfer across width/depth is itself dependent on selecting a good mass $m$;

how sensitive the tuning for $m$ is: whether there is a broad range of acceptable masses, or whether certain precise values lead to big improvements in train or test loss.

Figures 3 and 6 answer Question 1 above in the affirmative, in the context of ResMLP on CIFAR-10 and GPT on OpenWebText. Moreover, in the context of ResMLP on CIFAR-10, they give an answer to Questions 2 and 3: learning rate transfer occurs at a range of values of $m$.

Figure 7 addresses Question 3 in the context of transformers, on three different datasets. Across all three datasets, a mass in the region $m\sim 5$ to $10$ is reasonable.

D.7 Context length

We also tested the dependence of the optimal learning rate for GPT training on OpenWebText on the context length; the results are in Figure 8. Interestingly, we report good transfer of the optimal learning rate from small contexts to long contexts.

D.8 Full sweep results

The next four pages of the Appendix list results of our full learning rate sweeps over width/depth for GPT on OpenWebText and TinyStories, and ResMLP, ResNet on CIFAR-10.

Appendix E Proofs

To prove Proposition 3, it suffices to induct on the construction of a compound module $\mathsf{M}$ by composition and concatenation, with the atomic modules (where the inequality is just part of well-normed-ness) as the base case.

Indeed, suppose either $\mathsf{M}=\mathsf{M}_{2}\circ\mathsf{M}_{1}$ or $\mathsf{M}=(\mathsf{M}_{1},\mathsf{M}_{2})$. Suppose ${\bm{w}}_{k}$ is a weight for one of the atomic modules of $\mathsf{M}$, and write $m$ for the mass of this atomic module. Then ${\bm{w}}_{k}$ must be a weight of either $\mathsf{M}_{1}$ or $\mathsf{M}_{2}$; the inductive assumption is that

Case 1. $\mathsf{M}=\mathsf{M}_{2}\circ\mathsf{M}_{1}$ and ${\bm{w}}_{k}$ is a weight of $\mathsf{M}_{1}$. From the chain rule we then must have:

where the last line is by the definition of the norm of module composition.

Case 2. $\mathsf{M}=\mathsf{M}_{2}\circ\mathsf{M}_{1}$ and ${\bm{w}}_{k}$ is a weight of $\mathsf{M}_{2}$. The chain rule is not needed in this case, and we proceed straight from the inductive assumption:

Case 3. $\mathsf{M}=(\mathsf{M}_{1},\mathsf{M}_{2})$. Given the symmetric roles of $\mathsf{M}_{1},\mathsf{M}_{2}$, without loss of generality assume ${\bm{w}}_{k}$ is a weight of $\mathsf{M}_{1}$. Then,

Proposition 4: Sharpness of residual networks

Suppose $\mathsf{M}$ is a well-normed module of unit sensitivity on $(\mathcal{X},\mathcal{X},\mathcal{W})$ and is $(\alpha,\beta,\gamma)$-sharp. Then, by Proposition 8, for any $L\geq 1$, the module $\tfrac{1}{L}*\mathsf{M}$ is well-normed, of sensitivity $\tfrac{1}{L}$, and $(L\alpha,\beta,\tfrac{1}{L}\gamma)$-sharp.

The module $\tfrac{L-1}{L}*\mathsf{Identity}$ is also well-normed, sensitivity $\tfrac{L-1}{L}$, and $(0,0,0)$-sharp. In particular, the sum

is well-normed, unit sensitivity, and $(L\alpha,\beta,\tfrac{1}{L}\gamma)$-sharp; it has the same mass as the original module.

We induct on the statement for $k=1,2,\ldots$ that $\mathsf{M}_{res}^{k}$ is $(\alpha_{k},\beta_{k},\gamma_{k})$-sharp where

The base case is clearly true, and given the statement for $\mathsf{M}_{res}^{k}$, which has exactly $k$ times the mass of $\mathsf{M}_{res}$, we see that $\mathsf{M}_{res}^{k+1}=\mathsf{M}_{res}\circ\mathsf{M}_{res}^{k}$ is $(\alpha_{k+1},\beta_{k+1},\gamma_{k+1})$-sharp by applying Proposition 8 with $p_{1}=\tfrac{k}{k+1}$ and $p_{2}=\tfrac{1}{k+1}$, where

Setting $k=L$, observe that $1+2+\ldots+(L-1)=\tfrac{1}{2}L(L-1)$ and $1^{2}+2^{2}+\ldots+(L-1)^{2}=\tfrac{1}{6}L(L-1)(2L-1)$, giving
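The two closed forms used in this step are the standard sums of the first $L-1$ integers and of their squares. As a quick sanity check, not part of the proof, the sketch below compares them against brute-force sums in exact integer arithmetic. The function names `linear_sum` and `square_sum` are hypothetical helpers introduced only for this illustration.

```python
def linear_sum(L: int) -> int:
    """Closed form for 1 + 2 + ... + (L - 1)."""
    return L * (L - 1) // 2


def square_sum(L: int) -> int:
    """Closed form for 1^2 + 2^2 + ... + (L - 1)^2."""
    return L * (L - 1) * (2 * L - 1) // 6


# Compare each closed form with the explicit sum for a range of depths L.
for L in range(1, 50):
    assert linear_sum(L) == sum(range(1, L))
    assert square_sum(L) == sum(j * j for j in range(1, L))
```

Both products are always divisible by 2 and 6 respectively, so integer division loses nothing here.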

Proposition 5: Smoothness in the modular norm

The second inequality follows from the first via the fundamental theorem of calculus:

The third inequality follows from the second by again applying the fundamental theorem of calculus, followed by the Cauchy-Schwarz inequality:

Proposition 6: Broadcast modules are well-normed

Suppose $\mathsf{M}$ is a module with inputs $\mathcal{X}$, outputs $\mathcal{Y}$ and weights $\mathcal{W}$, broadcast to take $\mathcal{X}^{h}$ to $\mathcal{Y}^{h}$. We take norms on these spaces to be

where $S=h^{-1/p}$ unless $\mathsf{M}$ is a bond module. Write $\mu=\mathsf{M}\mathsf{.sensitivity}$. Then, for perturbations in the weight direction, which only occur if $\mathsf{M}$ is not a bond module:

For perturbations in the input direction, we have:

Proposition 7: Sensitivity of attention

We prove that the functional attention module $\mathsf{FuncAttention}$ of 9 is well-normed and of unit sensitivity.

Writing $F=\mathsf{FuncAttention}\mathsf{.forward}$ for short, recall that

where ${\bm{M}}$ is the mask (our proof will apply equally for the standard causal mask and also the non-causal ${\bm{M}}\equiv 0$).
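To make the objects in this proof concrete, here is a minimal single-head NumPy sketch of the forward function. It assumes $F({\bm{q}},{\bm{k}},{\bm{v}})={\bm{A}}{\bm{v}}$ with ${\bm{A}}=\operatorname{softmax}(\tfrac{1}{d_{Q}}{\bm{q}}{\bm{k}}^{T}+{\bm{M}})$ applied row-wise; the helper names `softmax_rows`, `causal_mask` and `func_attention` are hypothetical and are not taken from the Modula package.

```python
import numpy as np


def softmax_rows(scores):
    """Row-wise softmax, shifted by the row maximum for numerical stability."""
    shifted = scores - scores.max(axis=-1, keepdims=True)
    weights = np.exp(shifted)
    return weights / weights.sum(axis=-1, keepdims=True)


def causal_mask(n):
    """Mask that is 0 on and below the diagonal and -inf strictly above it."""
    return np.triu(np.full((n, n), -np.inf), k=1)


def func_attention(q, k, v, mask):
    """Assumed forward pass: A = softmax(q k^T / d_Q + M), output A v."""
    d_q = q.shape[-1]
    A = softmax_rows(q @ k.T / d_q + mask)
    return A @ v, A


rng = np.random.default_rng(0)
n, d_q, d_v = 6, 8, 4
q, k = rng.standard_normal((n, d_q)), rng.standard_normal((n, d_q))
v = rng.standard_normal((n, d_v))

for mask in (causal_mask(n), np.zeros((n, n))):
    out, A = func_attention(q, k, v, mask)
    # Each row of A is a probability vector, so the max absolute row sum is 1.
    print(out.shape, np.linalg.norm(A, ord=np.inf))
```

Note the $\tfrac{1}{d_{Q}}$ scaling of the logits, rather than the more common $\tfrac{1}{\sqrt{d_{Q}}}$. Because every row of ${\bm{A}}$ is nonnegative and sums to one, its $L^{\infty}$-operator-norm (the largest absolute row sum) equals one for either choice of mask.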

We will prove that at any $({\bm{q}},{\bm{k}},{\bm{v}})$ satisfying $\|{\bm{q}}\|_{\infty\mathsf{RMS}},\|{\bm{k}}\|_{\infty\mathsf{RMS}},\|{\bm{v}}\|_{\infty\mathsf{RMS}}\leq 1$, for any $(\Delta{\bm{q}},\Delta{\bm{k}},\Delta{\bm{v}})$ we have

For short, write ${\bm{A}}=\operatorname{softmax}(\tfrac{1}{d_{Q}}{\bm{q}}{\bm{k}}^{T}+{\bm{M}})$ for the attention matrix, and write its derivative as

Now, the derivative of $F$ splits into two terms

The calculation of the norm of ${\bm{A}}$ follows by definition from its construction by softmax. For the calculation of the norm of $\Delta{\bm{A}}$, a direct calculation yields that

where we are writing ${\bm{q}}_{i}={\bm{q}}_{i*}$ and so on.

Taking absolute values, applying the Cauchy-Schwarz inequality and summing over $j$ we have

Applying to the matrix $\Delta{\bm{A}}$, we thus have

Taking the max over $i$, this shows the $L^{\infty}$-operator-norm of $\Delta{\bm{A}}$ is at most

which, since $\|{\bm{q}}\|_{\infty\mathsf{RMS}},\|{\bm{k}}\|_{\infty\mathsf{RMS}}\leq 1$, completes the proof.
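The argument above can be spot-checked numerically. The sketch below rests on several assumptions that are not restated in this section: that $\|\cdot\|_{\infty\mathsf{RMS}}$ is the largest root-mean-square norm over the rows of its argument, that $F={\bm{A}}{\bm{v}}$, that the bound on the $L^{\infty}$-operator-norm of $\Delta{\bm{A}}$ is $\|\Delta{\bm{q}}\|\,\|{\bm{k}}\|+\|{\bm{q}}\|\,\|\Delta{\bm{k}}\|$, and that unit sensitivity means the derivative is bounded by $\|\Delta{\bm{q}}\|+\|\Delta{\bm{k}}\|+\|\Delta{\bm{v}}\|$. All helper names are hypothetical.

```python
import numpy as np


def inf_rms(x):
    """Assumed infinity-RMS norm: the largest root-mean-square over rows."""
    return np.sqrt((x ** 2).mean(axis=-1)).max()


def softmax_rows(scores):
    """Row-wise softmax, shifted by the row maximum for numerical stability."""
    shifted = scores - scores.max(axis=-1, keepdims=True)
    weights = np.exp(shifted)
    return weights / weights.sum(axis=-1, keepdims=True)


def attention_derivative(q, k, v, dq, dk, dv, mask):
    """Directional derivative of the assumed F = A v; returns (dF, dA)."""
    d_q = q.shape[-1]
    A = softmax_rows(q @ k.T / d_q + mask)
    dS = (dq @ k.T + q @ dk.T) / d_q  # perturbation of the logits
    # Softmax derivative, row by row: dA_ij = A_ij (dS_ij - sum_m A_im dS_im).
    dA = A * (dS - (A * dS).sum(axis=-1, keepdims=True))
    return A @ dv + dA @ v, dA


def unit_rows(x):
    """Rescale each row to RMS norm 1, so that inf_rms(x) equals 1."""
    return x / np.sqrt((x ** 2).mean(axis=-1, keepdims=True))


def check_bounds(seed, n=5, d_q=8, d_v=3, causal=True):
    """Check both assumed bounds at one random point and direction."""
    rng = np.random.default_rng(seed)
    q = unit_rows(rng.standard_normal((n, d_q)))
    k = unit_rows(rng.standard_normal((n, d_q)))
    v = unit_rows(rng.standard_normal((n, d_v)))
    dq, dk = rng.standard_normal((n, d_q)), rng.standard_normal((n, d_q))
    dv = rng.standard_normal((n, d_v))
    mask = np.triu(np.full((n, n), -np.inf), k=1) if causal else np.zeros((n, n))
    dF, dA = attention_derivative(q, k, v, dq, dk, dv, mask)
    op_norm = np.linalg.norm(dA, ord=np.inf)  # max absolute row sum
    op_bound = inf_rms(dq) * inf_rms(k) + inf_rms(q) * inf_rms(dk)
    total_bound = inf_rms(dq) + inf_rms(dk) + inf_rms(dv)
    return bool(op_norm <= op_bound + 1e-9 and inf_rms(dF) <= total_bound + 1e-9)


print(all(check_bounds(s, causal=c) for s in range(20) for c in (True, False)))
```

A random check of this kind cannot replace the proof; it only guards against a misreading of the norms or of the derivative formula.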

Proposition 7: Sharpness of functional attention

In this section, we estimate the second derivative of the forward function $F$ of functional attention at $({\bm{q}},{\bm{k}},{\bm{v}})$ in perturbation directions $(\Delta{\bm{q}},\Delta{\bm{k}},\Delta{\bm{v}})$ and $(\Delta\widetilde{{\bm{q}}},\Delta\widetilde{{\bm{k}}},\Delta\widetilde{{\bm{v}}})$:

We will prove that functional attention is $\gamma$-sharp where in fact $\gamma=3$; this amounts to proving that

With these conventions, the expression for $\Delta{\bm{A}}$ is thus

In these terms, the second derivative $\Delta^{2}F$ is just

From the estimates of the previous section, we have

so our task is to estimate the $L^{\infty}$-operator-norm of $\Delta^{2}{\bm{A}}$. Thus, we calculate $\Delta^{2}{\bm{A}}$:

We estimate the $L^{\infty}$-operator-norm of these six terms one by one. The first two, (E.72) and (E.73), are the simplest, using inequality (E.62):

Take absolute values, sum over $j$, and apply Cauchy-Schwarz and inequalities (E.63), (E.64):

Taking the max over $i$ and applying $\|{\bm{q}}\|_{\infty\mathsf{RMS}},\|{\bm{k}}\|_{\infty\mathsf{RMS}},\|{\bm{v}}\|_{\infty\mathsf{RMS}}\leq 1$:

and so by Cauchy-Schwarz and the fact that the rows of ${\bm{A}}$ sum to 1:

By a similar argument, for term (E.77) we have:

Thus, we have an estimate on the $L^{\infty}$-operator-norm of $\Delta^{2}{\bm{A}}$:

where all the norms on the right hand side are $\|\cdot\|_{\infty\mathsf{RMS}}$.

Adding this together with (E.70) and (E.71), we obtain (all norms being $\|\cdot\|_{\infty\mathsf{RMS}}$):

Proposition 8: Sharpness under composition

Suppose $\mathsf{M}=\mathsf{M}_{2}\circ\mathsf{M}_{1}$ where $\mathsf{M}_{1},\mathsf{M}_{2}$ are well-normed modules on respectively $(\mathcal{X}_{k},\mathcal{Y}_{k},\mathcal{W}_{k})$ and moreover $(\alpha_{k},\beta_{k},\gamma_{k})$-sharp for $k=1,2$. If $p_{k}=\tfrac{\mathsf{M}_{k}\mathsf{.mass}}{\mathsf{M}\mathsf{.mass}}$ for $k=1,2$, note that by the definition of the modular norm on the composite $\mathsf{M}$, we have for any $\Delta{\bm{w}}=(\Delta{\bm{w}}_{1},\Delta{\bm{w}}_{2})\in\mathcal{W}_{1}\times\mathcal{W}_{2}$:

We must prove that $\mathsf{M}$ is $(\alpha,\beta,\gamma)$-sharp where:

Turning to the second derivative of $\mathsf{M}(\cdot,\cdot)$, we first prove Inequality (E.97). The Gauss-Newton decomposition (C.1) for any $\Delta{\bm{w}}=(\Delta{\bm{w}}_{1},\Delta{\bm{w}}_{2})$ and $\Delta\widetilde{{\bm{w}}}=(\Delta\widetilde{{\bm{w}}}_{1},\Delta\widetilde{{\bm{w}}}_{2})$ yields

Applying the well-normed and sharpness inequalities, the norm of the first (E.100) of these terms is bounded by

The second term (E.101) breaks into four separate terms:

In particular, applying the well-normed and sharpness inequalities, this is bounded by

which completes the proof of Inequality (C.4).

Inequalities (E.98) and (E.99) are simpler. For the first of these, note we have

Term (E.114) breaks into two separate terms

which completes the proof of Inequality (E.98).

Proposition 9: Sharpness under concatenation

Suppose $\mathsf{M}=(\mathsf{M}_{1},\mathsf{M}_{2})$ where $\mathsf{M}_{1},\mathsf{M}_{2}$ are well-normed modules on respectively $(\mathcal{X}_{k},\mathcal{Y}_{k},\mathcal{W}_{k})$ and moreover $(\alpha_{k},\beta_{k},\gamma_{k})$-sharp for $k=1,2$. If $p_{k}=\tfrac{\mathsf{M}_{k}\mathsf{.mass}}{\mathsf{M}\mathsf{.mass}}$ for $k=1,2$, as in the previous proof we have for any $\Delta{\bm{w}}=(\Delta{\bm{w}}_{1},\Delta{\bm{w}}_{2})\in\mathcal{W}_{1}\times\mathcal{W}_{2}$:

We must prove that $\mathsf{M}$ is $(\alpha,\beta,\gamma)$-sharp where

Now, for the first of these identities, we have for $\Delta{\bm{w}}=(\Delta{\bm{w}}_{1},\Delta{\bm{w}}_{2})$ and $\Delta\widetilde{{\bm{w}}}=(\Delta\widetilde{{\bm{w}}}_{1},\Delta\widetilde{{\bm{w}}}_{2})$:

which shows $\alpha=p_{1}^{2}\alpha_{1}+p_{2}^{2}\alpha_{2}$. The expressions for $\beta,\gamma$ follow similarly.
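As a small worked example of the formula $\alpha=p_{1}^{2}\alpha_{1}+p_{2}^{2}\alpha_{2}$, the sketch below computes $\alpha$ for a concatenation from the masses and sharpness constants of its two parts. It assumes that the mass of a concatenation is the sum of the masses of its parts, so that $p_{1}+p_{2}=1$; the function name `concat_alpha` is hypothetical and only $\alpha$ is covered, since the expressions for $\beta$ and $\gamma$ are not written out here.

```python
def concat_alpha(mass1: float, alpha1: float, mass2: float, alpha2: float) -> float:
    """alpha for the concatenation (M1, M2), given masses and alpha constants."""
    total = mass1 + mass2  # assumption: masses add under concatenation
    p1, p2 = mass1 / total, mass2 / total
    return p1 ** 2 * alpha1 + p2 ** 2 * alpha2


# Masses 1 and 3 give p = (0.25, 0.75); with alpha1 = alpha2 = 1 the result is
# 0.25**2 + 0.75**2 = 0.625.
print(concat_alpha(1.0, 1.0, 3.0, 1.0))
```

Since $p_{1}^{2}+p_{2}^{2}\leq 1$ whenever $p_{1}+p_{2}=1$, the resulting $\alpha$ never exceeds $\max(\alpha_{1},\alpha_{2})$ under this assumption.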