Scalable Optimization in the Modular Norm
Tim Large, Yang Liu, Minyoung Huh, Hyojin Bahng, Phillip Isola, Jeremy Bernstein
Introduction
Given the practical impact of deep learning systems trained at the largest scale, there is a need for training algorithms that scale gracefully: without instability and, if possible, without manual tuning. However, current best practices for training have developed somewhat organically and do not live on a bedrock of sound numerical analysis. For example, while the Adam optimizer is ubiquitous in the field, errors have been found in its proof of convergence, and empirically Adam has been found to scale poorly as either the width or the depth of the network is ramped up.
To remedy this situation, a patchwork of learning rate correction factors has recently been proposed. The general idea is to retrofit a base optimizer such as Adam or SGD with special correction factors intended to render the optimizer’s optimal learning rate invariant to scale. But this situation is not ideal: the correction factors are reportedly difficult to use. Lingle suggests that this may be due to their “higher implementation complexity, many variations, or complex theoretical background”. What’s more, the correction factors are optimizer-specific, meaning that if one switches to a different optimizer one must either look up or recalculate a separate set of correction factors.
The goal of this paper is to simplify matters. We show that both Adam and SGD can be made to scale gracefully with width and depth by simply normalizing their updates in a special norm associated with the network architecture—see Figure 1. We call this norm the modular norm, and provide a Python package called Modula that constructs this norm automatically and in tandem with the architecture.
The modular norm is constructed recursively, leveraging the module tree perspective on neural architectures. It is enough to define how the modular norm propagates through only two elementary operations: composition and concatenation. We show how other basic operations on modules, such as addition and scalar-multiplication, can be implemented through composition and concatenation. And then higher-order structures, such as residual networks, can be built using these basic operations.
Beyond its practical relevance, the modular norm may also prove useful to theoreticians. Various optimization-theoretic quantities are accessible and efficiently calculable in the modular norm. For instance, we show that the gradient of any neural network built from “well-behaved” atomic modules is Lipschitz-continuous in the modular norm of the architecture. This opens the door to porting several more-or-less textbook optimization theory analyses over to the world of deep learning.
It is by now well-known that deep networks do not easily or naturally admit Lipschitz-continuity or smoothness guarantees in the Euclidean norm. Researchers have attempted to address this problem: for instance, Bernstein et al. propose a distance function called deep relative trust, which combines Frobenius norms across network layers. However, deep relative trust is only constructed for the multilayer perceptron and, when used to normalize updates, its employment of the Frobenius norm precludes good width scaling. In contrast, Yang et al. equip individual layers with the RMS–RMS operator norm, finding this to enable good width scaling. Researchers have also looked at building neural net distance functions outside the context of scalability.
Asymptotics
The metrization-based approach to scaling developed in this paper contrasts with the tradition of asymptotic scaling analyses (the study of infinite width and depth limits) more common in the deep learning theory literature. These asymptotic analyses follow an old observation of Neal that interesting properties of the neural network function space are exactly calculable in the infinite width limit and at initialization. This tradition has continued with asymptotic studies of the neural tangent kernel as well as infinite depth limits. However, there is increasing recognition of the limits of these limits, with researchers now often trying to relax limiting results. And ultimately, from a practitioner’s perspective, these results can be difficult to make sense of. In contrast, our framework eschews any kind of limiting or probabilistic analysis. As a consequence, we believe our framework is simpler, more easily relatable to basic mathematical concepts, and ultimately more relevant to what one may encounter in, say, a PyTorch program.
Majorization
In recent work, Streeter and Dillon propose a universal majorize-minimize algorithm: a method that automatically computes and minimizes a majorizer for any computational graph. Despite its generality, current downsides to the method include its overhead, which can be 2× per step, as well as the risk that use of a full majorization may be overly pessimistic. Indeed, Cho and Shin find that an optimization approach leveraging second-order information converges significantly faster than a majorization-inspired approach. Related ideas appear in .
Descent in Normed Spaces
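We define the modular norm in Section 3. This section is intended to prime the reader for what is to come. In this section, and the rest of the document, the diamond operator $\diamond$ denotes tensor contraction.
Consider minimizing a loss function $\mathcal{L}: \mathcal{W} \to \mathbb{R}$ over a weight space $\mathcal{W}$ equipped with a norm $\|\cdot\|$, and suppose that:
the loss function is differentiable, meaning that the gradient map $\nabla_w \mathcal{L}: \mathcal{W} \to \mathcal{W}^*$ exists;
the loss is Lipschitz smooth in the norm $\|\cdot\|$, with sharpness constant $\lambda > 0$, meaning that for all $w, \Delta w \in \mathcal{W}$:
$$\left|\mathcal{L}(w + \Delta w) - \left[\mathcal{L}(w) + \nabla_w \mathcal{L}(w) \diamond \Delta w\right]\right| \le \frac{\lambda}{2} \|\Delta w\|^2. \qquad (2.1)$$
Under these conditions, the weight update given by
$$\Delta w = -\frac{\|\nabla_w \mathcal{L}(w)\|^{\dagger}}{\lambda} \cdot \arg\max_{\|t\| = 1} \nabla_w \mathcal{L}(w) \diamond t,$$
where $\|\cdot\|^{\dagger}$ is the dual norm, is guaranteed to reduce the loss whenever the gradient is nonzero. The particular norm influences the direction of this weight update, while the sharpness constant influences the size of the update.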
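To make the roles of the norm and the sharpness constant concrete, here is a toy sketch in pure Python (our own example, not Modula code) of the update that moves along $\arg\max_{\|t\|=1} g \diamond t$ with length $\|g\|^{\dagger}/\lambda$, for the quadratic $f(w) = \tfrac{1}{2}\|w\|_2^2$. In the Euclidean norm the dual norm is again Euclidean and $\lambda = 1$. In the $\ell_\infty$ norm the dual norm is $\ell_1$, the maximizing direction is the sign vector, and the sharpness constant grows to $\lambda = n$ because $\|\Delta w\|_2^2 \le n\,\|\Delta w\|_\infty^2$.

```python
import math


def loss(w):
    # f(w) = 0.5 * ||w||_2^2; its gradient is w itself
    return 0.5 * sum(x * x for x in w)


def grad(w):
    return list(w)


def descent_step_l2(w, sharpness):
    """Steepest-descent step in the l2 norm (whose dual norm is also l2)."""
    g = grad(w)
    dual = math.sqrt(sum(x * x for x in g))      # ||g||_2
    direction = [x / dual for x in g]            # argmax over ||t||_2 = 1 of g . t
    return [wi - (dual / sharpness) * ti for wi, ti in zip(w, direction)]


def descent_step_linf(w, sharpness):
    """Steepest-descent step in the l_inf norm (whose dual norm is l1)."""
    g = grad(w)
    dual = sum(abs(x) for x in g)                # ||g||_1
    # argmax over ||t||_inf = 1 of g . t is sign(g); zero entries may take either sign
    direction = [math.copysign(1.0, x) for x in g]
    return [wi - (dual / sharpness) * ti for wi, ti in zip(w, direction)]


w0 = [3.0, -1.0, 2.0]
n = len(w0)
w_l2 = descent_step_l2(w0, sharpness=1.0)             # f is 1-smooth in l2
w_linf = descent_step_linf(w0, sharpness=float(n))    # and n-smooth in l_inf
print(loss(w0), loss(w_l2), loss(w_linf))
```

Both steps reduce the loss from 7 but differ in direction and size: the $\ell_2$ step lands on the minimizer (up to rounding), while the $\ell_\infty$ step, limited by its larger sharpness constant, moves every coordinate by the same amount and reaches loss 1.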
In deep learning, we would ideally like the optimal step-size to remain invariant as we scale, say, the width and the depth of the network. Thus, a fundamental problem is to design a norm such that, first, Equation 2.1 actually holds (and is not hopelessly lax), and second, the corresponding sharpness constant is invariant to the relevant architectural dimensions. If the norm is chosen poorly, the practitioner may end up having to re-tune the step size as the network is scaled up. In this paper, we design a norm for neural networks that meets these requirements: the modular norm.
2.2 Preview of the modular norm
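The inequality $\|\Delta W x\|_{\mathrm{RMS}} \le \|\Delta W\|_{\mathrm{RMS}\to\mathrm{RMS}} \cdot \|x\|_{\mathrm{RMS}}$ turns out to hold quite tightly when $\Delta W$ is a gradient update to a linear layer and $x$ is a corresponding layer input. This is because gradient updates to a layer are (sums of) outer products that align with layer inputs.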
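A quick numerical check of this claim, assuming NumPy. It relies on the fact that the RMS→RMS operator norm of a $d_{\text{out}} \times d_{\text{in}}$ matrix is $\sqrt{d_{\text{in}}/d_{\text{out}}}$ times its spectral norm. For a single example, the weight gradient of a linear layer is the outer product $g x^\top$ of the output gradient $g$ and the input $x$; for that update the inequality holds with equality, whereas for an unrelated input the ratio is typically small.

```python
import numpy as np


def rms(v):
    return np.sqrt(np.mean(v ** 2))


def rms_to_rms(A):
    # induced RMS -> RMS operator norm: sqrt(d_in / d_out) * largest singular value
    d_out, d_in = A.shape
    return np.sqrt(d_in / d_out) * np.linalg.norm(A, 2)


rng = np.random.default_rng(0)
d_in, d_out = 256, 128
x = rng.standard_normal(d_in)        # layer input
g = rng.standard_normal(d_out)       # gradient of the loss w.r.t. the layer output
dW = np.outer(g, x)                  # per-example weight gradient: rank one, aligned with x

# ratio of the two sides of the inequality (1.0 means it is tight)
tightness_aligned = rms(dW @ x) / (rms_to_rms(dW) * rms(x))

x_other = rng.standard_normal(d_in)  # an input the update was not built from
tightness_random = rms(dW @ x_other) / (rms_to_rms(dW) * rms(x_other))
print(tightness_aligned, tightness_random)
```

The second ratio equals the absolute cosine between the two inputs, which is roughly $1/\sqrt{d_{\text{in}}}$ in size for independent Gaussian vectors.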
Once we know how to metrize individual layers, a natural question is: can we combine layer-wise norms to produce a norm on the full weight space of the network? Naïvely, there are many ways to do this: one could take any positive linear combination of the layer-wise norms ($\ell_1$ combination), the square root of any combination of the squared layer-wise norms ($\ell_2$ combination), and so on. But we want the norm to be useful by the criteria of Section 2.1. To this end, we propose the modular norm $\|\cdot\|_{\mathsf{M}}$, which ends up as a max ($\ell_\infty$ combination) of scaled layer-wise norms $s_l \|W_l\|_{\mathrm{RMS}\to\mathrm{RMS}}$:
$$\|(W_1, \dots, W_L)\|_{\mathsf{M}} := \max\left(s_1 \|W_1\|_{\mathrm{RMS}\to\mathrm{RMS}}, \dots, s_L \|W_L\|_{\mathrm{RMS}\to\mathrm{RMS}}\right). \qquad (2.3)$$
The positive scalar constants $s_1, \dots, s_L$ are determined by both the architecture of the network and a set of user-specified “mass” parameters. The precise construction of the modular norm, working recursively over the module tree of the network, is given in Section 3; there, we also explain how the modular norm satisfies the criteria of Section 2.1, and the role played by the mass parameters. For now, let us explain what the modular norm is good for in practice.
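A sketch of the max combination in Equation 2.3 for a stack of linear layers, assuming NumPy and hand-picked constants $s_l$ (in the paper these come from the architecture and the mass parameters, which this toy does not model). The $\ell_1$ and $\ell_2$ combinations are computed alongside for comparison.

```python
import numpy as np


def rms_to_rms(A):
    # induced RMS -> RMS operator norm: sqrt(d_in / d_out) * spectral norm
    d_out, d_in = A.shape
    return np.sqrt(d_in / d_out) * np.linalg.norm(A, 2)


def combined_norms(layers, scales):
    """Return the l1, l2 and max (l_inf) combinations of scaled layer-wise norms."""
    scaled = [s * rms_to_rms(W) for W, s in zip(layers, scales)]
    return {
        "l1": sum(scaled),
        "l2": float(np.sqrt(sum(v ** 2 for v in scaled))),
        "max": max(scaled),   # the form the modular norm takes
    }


layers = [2.0 * np.eye(4), 3.0 * np.eye(4)]
scales = [1.0, 0.5]           # hypothetical s_1, s_2
norms = combined_norms(layers, scales)
print(norms)
```

With the max combination, a weight update has modular norm at most one exactly when every scaled layer norm is at most one, so each layer's budget can be checked, and enforced, separately.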
2.3 Normed optimization
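The main practical use of the modular norm is to normalize weight updates. With reference to Equation 2.3, we define the following operation on weight updates $\Delta w = (\Delta W_1, \dots, \Delta W_L)$:
$$\mathtt{normalize}(\Delta w) := \left(\frac{\Delta W_1}{s_1 \|\Delta W_1\|_{\mathrm{RMS}\to\mathrm{RMS}}}, \dots, \frac{\Delta W_L}{s_L \|\Delta W_L\|_{\mathrm{RMS}\to\mathrm{RMS}}}\right).$$
Provided none of the $\Delta W_l$ are zero, $\mathtt{normalize}(\Delta w)$ is a unit vector in the modular norm. We propose using normalize as a wrapper, along with an explicit learning rate schedule, for any base optimizer such as Adam or SGD. The resulting normed optimizer is thus made architecture-aware via the normalize function. In pseudo-code, which here is also actual Modula code, this amounts to: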
delta_w = optim(w.grad())    # get update from base optimizer
net.normalize(delta_w)       # normalize update in the modular norm
w -= eta(step) * delta_w     # apply update with learning rate eta
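For readers without Modula at hand, the following NumPy sketch mimics the three lines above on a toy problem. Everything in it is an assumption for illustration: the helpers `normalize` and `eta` are our own stand-ins (not Modula's API), the constants $s_l$ are hand-picked, the base optimizer is plain SGD, and the loss $\tfrac12\sum_l \|W_l - T_l\|_F^2$ is a separable stand-in for a network loss. Each normalized update moves layer $l$ by exactly $\eta/s_l$ in the RMS→RMS norm, so the schedule alone sets how far each layer moves.

```python
import numpy as np


def rms_to_rms(A):
    d_out, d_in = A.shape
    return np.sqrt(d_in / d_out) * np.linalg.norm(A, 2)


def normalize(updates, scales):
    # divide each layer update by s_l * ||dW_l||_{RMS->RMS}: the result has unit modular norm
    return [dW / (s * rms_to_rms(dW)) for dW, s in zip(updates, scales)]


def eta(step, total, base=1.0):
    # linear decay learning rate schedule, from `base` down towards zero
    return base * (1 - step / total)


rng = np.random.default_rng(0)
targets = []
for shape in [(16, 8), (4, 16)]:
    T = rng.standard_normal(shape)
    targets.append(5.0 * T / rms_to_rms(T))     # fix each target's RMS->RMS norm to 5
weights = [np.zeros_like(T) for T in targets]
scales = [1.0, 2.0]                               # hypothetical per-layer constants s_l


def loss(ws):
    return 0.5 * sum(np.sum((W - T) ** 2) for W, T in zip(ws, targets))


initial = loss(weights)
total = 100
for step in range(total):
    grads = [W - T for W, T in zip(weights, targets)]   # exact gradient of the toy loss
    delta = normalize(grads, scales)                     # base optimizer here: plain SGD
    weights = [W - eta(step, total) * d for W, d in zip(weights, delta)]
final = loss(weights)
print(initial, final)
```

Because the step length no longer depends on the gradient's magnitude, a decaying schedule is what lets the iterates settle: once a layer overshoots its target, the remaining error stays below the current step length.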
We find this wrapper to significantly improve the scalability of the base optimizer. It renders the optimal learning rate roughly invariant to width and depth, with seemingly no cost to accuracy. In some instances, it enables training with a simpler optimizer—for example, training GPT with SGD rather than Adam—thus incurring a smaller memory footprint.
Normalization in the modular norm essentially forces individual layers to learn at specified, regulated rates. We view this as balancing learning across the network; no individual layer can learn too fast and destabilize training. This balance is determined by the architecture, along with user-specified mass parameters that provide precise control over the relative learning speed in different submodules.
For a variety of experiments with normed optimization, see Section 4 and Appendix D. But first, we detail the construction of the modular norm along with its core properties.
Constructing the Modular Norm
Our strategy is to first define the abstract notion of a module, which includes a norm as an attribute. We depict this concept in Figure 2. Then, by providing rules for composing and concatenating modules, we recursively define a norm for any module built via an arbitrary sequence of compositions and concatenations: the modular norm!
A module is a re-usable, composable object useful for building complicated neural networks. Our definition of a module augments the PyTorch module with two real numbers and a norm:
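Given input vector space $\mathcal{X}$, output vector space $\mathcal{Y}$ and weight vector space $\mathcal{W}$, a module $\mathsf{M}$ is an object with the following four attributes:
a function, $\mathsf{M}.\mathsf{forward}: \mathcal{W} \times \mathcal{X} \to \mathcal{Y}$, which maps an input and a weight vector to an output; we often abbreviate this attribute to just $\mathsf{M}$;
a number, $\mathsf{M}.\mathsf{mass} \ge 0$, which will turn out to set the proportion of feature learning that this module contributes to any supermodule;
a number, $\mathsf{M}.\mathsf{sensitivity} \ge 0$, which estimates the module’s sensitivity to input perturbations;
a norm, $\mathsf{M}.\mathsf{norm}: \mathcal{W} \to \mathbb{R}$, which measures the size of weight vectors.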
Before we say more about the intended roles of these attributes, let us mention the three kinds of modules that we will care about in practice:
atomic modules, whose attributes are hand-declared, and have weights. Examples include linear modules, embedding modules, and convolution modules.
bond modules, whose attributes are hand-declared, but have no weights. Formally, their weight space is the zero vector space $\mathcal{W} = \{0\}$. An example is the non-linearity module.
compound modules, built out of other modules, with automatically inferred attributes.
Note that the space of objects that type-check as a module by Definition 1 is vast. Since we need to hand-declare atomic and bond modules in order to build interesting compound modules, we should have an idea of what makes for a “good” module. Simply put, a module is good when its attributes are predictive of its behaviour. To formalize this idea, we say that a module is well-normed if its forward function, sensitivity, and norm satisfy the following two relationships:
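Let $\mathsf{M}$ be a module on $(\mathcal{X}, \mathcal{Y}, \mathcal{W})$, where the input and output spaces have respective norms $\|\cdot\|_{\mathcal{X}}$ and $\|\cdot\|_{\mathcal{Y}}$. $\mathsf{M}$ is well-normed if for all inputs $x \in \mathcal{X}$ and weights $w \in \mathcal{W}$:
$$\|\nabla_w \mathsf{M}(w, x) \diamond \Delta w\|_{\mathcal{Y}} \le \mathsf{M}.\mathsf{norm}(\Delta w) \quad \text{for all } \Delta w \in \mathcal{W};$$
$$\|\nabla_x \mathsf{M}(w, x) \diamond \Delta x\|_{\mathcal{Y}} \le \mathsf{M}.\mathsf{sensitivity} \cdot \|\Delta x\|_{\mathcal{X}} \quad \text{for all } \Delta x \in \mathcal{X}.$$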
Well-normed-ness means that the norm function and sensitivity are a good match for the forward function. The first inequality says that a well-normed module is Lipschitz-continuous over its weight space with constant one. The second inequality says that a well-normed module is Lipschitz-continuous over its input space with constant $\mathsf{M}.\mathsf{sensitivity}$. In practice, we will be interested in well-normed modules where these inequalities hold fairly tightly, since then $\mathsf{M}.\mathsf{norm}$ and $\mathsf{M}.\mathsf{sensitivity}$ will let us estimate the sensitivity of the module to weight and input perturbations, respectively. Appendix B provides many examples of well-normed atomic and bond modules.
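To see the two inequalities in action, the sketch below checks them for a linear module $x \mapsto Wx$ with RMS norms on inputs and outputs, under assumptions we state explicitly rather than take from the definition: the module norm is the RMS→RMS operator norm, the sensitivity is 1, the weights satisfy $\|W\|_{\mathrm{RMS}\to\mathrm{RMS}} \le 1$, and the input has unit RMS norm. Since the module is linear in each argument, the linearized changes are simply $\Delta W x$ and $W \Delta x$. NumPy is assumed.

```python
import numpy as np


def rms(v):
    return np.sqrt(np.mean(v ** 2))


def rms_to_rms(A):
    d_out, d_in = A.shape
    return np.sqrt(d_in / d_out) * np.linalg.norm(A, 2)


rng = np.random.default_rng(1)
d_out, d_in = 32, 64
W = rng.standard_normal((d_out, d_in))
W /= rms_to_rms(W)                 # unit RMS->RMS norm, so sensitivity 1 is valid
x = rng.standard_normal(d_in)
x /= rms(x)                        # unit RMS input

for _ in range(100):
    dW = rng.standard_normal((d_out, d_in))
    dx = rng.standard_normal(d_in)
    # first inequality: weight perturbations are bounded by the module norm
    assert rms(dW @ x) <= rms_to_rms(dW) * (1 + 1e-12)
    # second inequality: input perturbations are bounded by sensitivity * input norm
    assert rms(W @ dx) <= 1.0 * rms(dx) * (1 + 1e-12)
print("both inequalities held on 100 random perturbations")
```

If the weights drift outside the unit ball, or the inputs grow beyond unit RMS norm, these checks can fail, which is why well-normed-ness is tied to where the weights and inputs live.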
The remaining attribute, $\mathsf{M}.\mathsf{mass}$, will turn out to control the proportion of feature learning that a module contributes to any compound module in which it participates. We formalize this concept in Section 3.3. But before that, we need to understand how to build compound modules.
3.2 Compound modules: Building new modules from old
We consider building new modules from old ones via the binary operations of composition and concatenation, illustrated in Figure 2. Composition is denoted via the serial combination $\mathsf{M}_2 \circ \mathsf{M}_1$, and concatenation via the parallel combination $(\mathsf{M}_1, \mathsf{M}_2)$, alternatively referred to as a module tuple. These simple binary combinations will let us build basic algebraic operations on modules (Table 1) as well as complex neural network architectures. We start by defining module composition:
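Consider module $\mathsf{M}_1$ with input, output and weight space $(\mathcal{X}_1, \mathcal{Y}_1, \mathcal{W}_1)$ and module $\mathsf{M}_2$ with input, output and weight space $(\mathcal{X}_2, \mathcal{Y}_2, \mathcal{W}_2)$. $\mathsf{M}_1$ and $\mathsf{M}_2$ are composable if $\mathcal{Y}_1 = \mathcal{X}_2$. Their composite $\mathsf{M} = \mathsf{M}_2 \circ \mathsf{M}_1$ lives on $(\mathcal{X}_1, \mathcal{Y}_2, \mathcal{W}_1 \times \mathcal{W}_2)$ with attributes:
$\mathsf{M}.\mathsf{forward}((w_1, w_2), x) = \mathsf{M}_2.\mathsf{forward}(w_2, \mathsf{M}_1.\mathsf{forward}(w_1, x))$;
$\mathsf{M}.\mathsf{mass} = \mathsf{M}_1.\mathsf{mass} + \mathsf{M}_2.\mathsf{mass}$;
$\mathsf{M}.\mathsf{sensitivity} = \mathsf{M}_1.\mathsf{sensitivity} \cdot \mathsf{M}_2.\mathsf{sensitivity}$;
$\mathsf{M}.\mathsf{norm}((w_1, w_2))$ given by:
$$\max\left(\mathsf{M}_2.\mathsf{sensitivity} \cdot \frac{\mathsf{M}.\mathsf{mass}}{\mathsf{M}_1.\mathsf{mass}} \cdot \mathsf{M}_1.\mathsf{norm}(w_1),\ \frac{\mathsf{M}.\mathsf{mass}}{\mathsf{M}_2.\mathsf{mass}} \cdot \mathsf{M}_2.\mathsf{norm}(w_2)\right)$$
where if $\mathsf{M}_1.\mathsf{mass}$ or $\mathsf{M}_2.\mathsf{mass}$ is zero, the corresponding term in the max is set to zero.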
At this stage, we make two comments about this definition. First, in the definition of the composite norm, notice that the norm of the first module couples with the sensitivity of the second module. This reflects the fact that the output of the first module is fed into the second module and not vice versa. Second, observe that the masses of the submodules are involved in setting the balance of the composite norm. Before we further motivate this definition, let us first define module concatenation:
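Consider module $\mathsf{M}_1$ with input, output and weight space $(\mathcal{X}_1, \mathcal{Y}_1, \mathcal{W}_1)$ and module $\mathsf{M}_2$ with input, output and weight space $(\mathcal{X}_2, \mathcal{Y}_2, \mathcal{W}_2)$. We say that $\mathsf{M}_1$ and $\mathsf{M}_2$ are concatenatable if their input spaces match: $\mathcal{X}_1 = \mathcal{X}_2$. The tuple $\mathsf{M} = (\mathsf{M}_1, \mathsf{M}_2)$ has input, output and weight space $(\mathcal{X}_1, \mathcal{Y}_1 \times \mathcal{Y}_2, \mathcal{W}_1 \times \mathcal{W}_2)$ and attributes:
$\mathsf{M}.\mathsf{forward}((w_1, w_2), x) = (\mathsf{M}_1.\mathsf{forward}(w_1, x), \mathsf{M}_2.\mathsf{forward}(w_2, x))$;
$\mathsf{M}.\mathsf{mass} = \mathsf{M}_1.\mathsf{mass} + \mathsf{M}_2.\mathsf{mass}$;
$\mathsf{M}.\mathsf{sensitivity} = \mathsf{M}_1.\mathsf{sensitivity} + \mathsf{M}_2.\mathsf{sensitivity}$;
$\mathsf{M}.\mathsf{norm}((w_1, w_2))$ given by:
$$\max\left(\frac{\mathsf{M}.\mathsf{mass}}{\mathsf{M}_1.\mathsf{mass}} \cdot \mathsf{M}_1.\mathsf{norm}(w_1),\ \frac{\mathsf{M}.\mathsf{mass}}{\mathsf{M}_2.\mathsf{mass}} \cdot \mathsf{M}_2.\mathsf{norm}(w_2)\right)$$
where if $\mathsf{M}_1.\mathsf{mass}$ or $\mathsf{M}_2.\mathsf{mass}$ is zero, the corresponding term in the max is set to zero.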
Concatenation is simpler than composition in the sense that neither module is fed through the other, and therefore, sensitivity does not appear in the concatenated norm. To further motivate these definitions, observe that two basic and desirable properties follow as immediate consequences:
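If modules $\mathsf{M}_1, \mathsf{M}_2, \mathsf{M}_3$ are successively composable, then $\mathsf{M}_3 \circ (\mathsf{M}_2 \circ \mathsf{M}_1)$ equals $(\mathsf{M}_3 \circ \mathsf{M}_2) \circ \mathsf{M}_1$ in all attributes. If modules $\mathsf{M}_1, \mathsf{M}_2, \mathsf{M}_3$ are mutually concatenatable, then $((\mathsf{M}_1, \mathsf{M}_2), \mathsf{M}_3)$ equals $(\mathsf{M}_1, (\mathsf{M}_2, \mathsf{M}_3))$ in all attributes.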
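If modules $\mathsf{M}_1$ and $\mathsf{M}_2$ are well-normed and composable, then their composite $\mathsf{M}_2 \circ \mathsf{M}_1$ is also well-normed. If modules $\mathsf{M}_1$ and $\mathsf{M}_2$ are well-normed and concatenatable, then their tuple $(\mathsf{M}_1, \mathsf{M}_2)$ is also well-normed with respect to the $\ell_1$ combination norm on the output space:
$$\|(y_1, y_2)\|_{\mathcal{Y}_1 \times \mathcal{Y}_2} := \|y_1\|_{\mathcal{Y}_1} + \|y_2\|_{\mathcal{Y}_2}.$$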
The proofs follow directly from the definitions and the chain rule. Proposition 1 implies that one may build complicated compound modules without worrying in which order successive combinations are taken. Proposition 2 implies that complicated compounds automatically inherit Lipschitz guarantees.
Taken together, Definitions 3 and 4 define the modular norm $\mathsf{M}.\mathsf{norm}$ of any compound module $\mathsf{M}$.
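The attribute rules in Definitions 3 and 4 are simple enough to code directly. The sketch below is a hypothetical toy (not the Modula API): it omits the forward function, gives each atom a scalar weight whose norm is its absolute value, and checks on one example that the two groupings of a three-module composite agree in mass, sensitivity and norm, as Proposition 1 states.

```python
from dataclasses import dataclass
from typing import Any, Callable


@dataclass
class Module:
    """Hypothetical stand-in for a module: mass, sensitivity and norm (forward omitted)."""
    mass: float
    sensitivity: float
    norm: Callable[[Any], float]


def compose(m2: Module, m1: Module) -> Module:
    """Definition 3: the composite m2 o m1, with weights stored as the pair (w1, w2)."""
    mass = m1.mass + m2.mass

    def norm(w):
        w1, w2 = w
        t1 = m2.sensitivity * (mass / m1.mass) * m1.norm(w1) if m1.mass else 0.0
        t2 = (mass / m2.mass) * m2.norm(w2) if m2.mass else 0.0
        return max(t1, t2)

    return Module(mass, m1.sensitivity * m2.sensitivity, norm)


def concat(m1: Module, m2: Module) -> Module:
    """Definition 4: the tuple (m1, m2), with weights stored as the pair (w1, w2)."""
    mass = m1.mass + m2.mass

    def norm(w):
        w1, w2 = w
        t1 = (mass / m1.mass) * m1.norm(w1) if m1.mass else 0.0
        t2 = (mass / m2.mass) * m2.norm(w2) if m2.mass else 0.0
        return max(t1, t2)

    return Module(mass, m1.sensitivity + m2.sensitivity, norm)


def toy_atom(mass: float, sensitivity: float) -> Module:
    # hypothetical atom with a scalar weight whose norm is its absolute value
    return Module(mass, sensitivity, abs)


a, b, c = toy_atom(1.0, 2.0), toy_atom(2.0, 0.5), toy_atom(3.0, 1.0)
left = compose(c, compose(b, a))     # weights ((w_a, w_b), w_c)
right = compose(compose(c, b), a)    # weights (w_a, (w_b, w_c))
print(left.mass, left.sensitivity, left.norm(((0.3, -0.7), 0.2)))
print(right.mass, right.sensitivity, right.norm((0.3, (-0.7, 0.2))))
```

A zero mass simply drops the corresponding term from the max, matching the convention in both definitions; this is how weightless bond modules enter the calculation.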
3.3 Mass allocation in compound modules
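Suppose we wish to train a network with an input layer, an output layer, and $L$ blocks between:
$$\mathsf{M} = \mathsf{M}_{\mathrm{out}} \circ \mathsf{M}_{\mathrm{block}}^{(L)} \circ \cdots \circ \mathsf{M}_{\mathrm{block}}^{(1)} \circ \mathsf{M}_{\mathrm{in}}.$$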
Then how much learning should happen in the output layer, compared to the blocks, compared to the input layer? And what if we scale the number of blocks $L$: do we want relatively less learning to occur in the network’s extremities? Or do we want the input and output layers to learn non-trivially even in the $L \to \infty$ limit? Since answering these questions is difficult a priori, we introduced the mass parameter to allow a user to set the proportional contribution each module has toward learning:
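Consider a compound module $\mathsf{M}$ derived in any fashion from well-normed modules $\mathsf{M}_1, \dots, \mathsf{M}_n$. Given weight setting $w = (w_1, \dots, w_n)$, where $w_k$ denote the weights of module $\mathsf{M}_k$, let us perturb $w$ by $\Delta w = (\Delta w_1, \dots, \Delta w_n)$. If we decompose the linearized change in the output of module $\mathsf{M}$ into one contribution per sub-module:
$$\nabla_w \mathsf{M}(w, x) \diamond \Delta w = \sum_{k=1}^{n} \nabla_{w_k} \mathsf{M}(w, x) \diamond \Delta w_k,$$
then the $k$th term in this decomposition satisfies:
$$\|\nabla_{w_k} \mathsf{M}(w, x) \diamond \Delta w_k\|_{\mathcal{Y}} \le \frac{\mathsf{M}_k.\mathsf{mass}}{\mathsf{M}.\mathsf{mass}} \cdot \mathsf{M}.\mathsf{norm}(\Delta w).$$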
In words: module mass provides the flexibility needed to build complicated compound modules involving many sub-modules, while maintaining precise control over how much learning any sub-module can contribute to the overall compound. Proposition 3 is proved in Appendix E.
In practice, we obtained the best training performance by maintaining a constant amount of learning in the input and output layers even as the number of blocks is scaled (Figure 6). In other words, it seems to be a good idea to assign masses in proportion $\mathsf{M}_{\mathrm{in}}.\mathsf{mass} : \mathsf{M}_{\mathrm{blocks}}.\mathsf{mass} : \mathsf{M}_{\mathrm{out}}.\mathsf{mass} = 1 : M : 1$, where $\mathsf{M}_{\mathrm{blocks}}$ denotes the composite of all $L$ blocks and the hidden mass $M$ is independent of the number of blocks $L$. The exact mass of the hidden layers needs to be tuned on a new architecture, just as one needs to tune separate learning rates in the input and output layers in μP; this tuning can be done on a small model prior to scaling (Figure 3). We further discuss mass allocation in Section D.6.
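A small arithmetic sketch of this recipe, with hypothetical helper names and an assumed 1 : M : 1 split between the input layer, all L blocks together, and the output layer. Under Proposition 3, the bound on each sub-module's share of the output change is its mass divided by the total mass, so the input and output layers keep the same share, 1/(M + 2), however many blocks there are, while each block's share shrinks like 1/L.

```python
def mass_split(num_blocks: int, blocks_mass: float, io_mass: float = 1.0):
    """Hypothetical helper: input : all blocks : output = io_mass : blocks_mass : io_mass."""
    per_block = blocks_mass / num_blocks
    return io_mass, [per_block] * num_blocks, io_mass


def learning_shares(num_blocks: int, blocks_mass: float):
    # Proposition 3: sub-module k contributes at most mass_k / total mass
    # of the linearized output change (relative to the modular norm of the update)
    m_in, m_blocks, m_out = mass_split(num_blocks, blocks_mass)
    total = m_in + sum(m_blocks) + m_out
    return m_in / total, m_blocks[0] / total, m_out / total


for L in (2, 8, 64):
    share_in, share_block, share_out = learning_shares(L, blocks_mass=5.0)
    print(f"L={L:3d}  input {share_in:.4f}  per block {share_block:.4f}  output {share_out:.4f}")
```

Only the hidden mass M (here 5.0) would need tuning; the depth L can then change without moving the learning budget of the input and output layers.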
3.4 Smoothness in the modular norm
In this section, we study the second derivatives of a module using the modular norm as a measuring stick. Let us start by defining the notion of sharpness that we will consider:
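Let $\mathsf{M}$ be a module on $(\mathcal{X}, \mathcal{Y}, \mathcal{W})$, where the input and output spaces have respective norms $\|\cdot\|_{\mathcal{X}}$ and $\|\cdot\|_{\mathcal{Y}}$. We say that $\mathsf{M}$ is $(\alpha, \beta, \gamma)$-sharp for constants $\alpha, \beta, \gamma \ge 0$ if, at all inputs $x \in \mathcal{X}$ and weights $w \in \mathcal{W}$, the second derivatives of $\mathsf{M}$ are bounded as:
$$\|\Delta w \diamond \nabla_{ww}^2 \mathsf{M} \diamond \widetilde{\Delta w}\|_{\mathcal{Y}} \le \alpha \, \mathsf{M}.\mathsf{norm}(\Delta w) \, \mathsf{M}.\mathsf{norm}(\widetilde{\Delta w});$$
$$\|\Delta w \diamond \nabla_{wx}^2 \mathsf{M} \diamond \widetilde{\Delta x}\|_{\mathcal{Y}} \le \beta \, \mathsf{M}.\mathsf{norm}(\Delta w) \, \|\widetilde{\Delta x}\|_{\mathcal{X}};$$
$$\|\Delta x \diamond \nabla_{xx}^2 \mathsf{M} \diamond \widetilde{\Delta x}\|_{\mathcal{Y}} \le \gamma \, \|\Delta x\|_{\mathcal{X}} \, \|\widetilde{\Delta x}\|_{\mathcal{X}},$$
for all $\Delta w, \widetilde{\Delta w} \in \mathcal{W}$ and $\Delta x, \widetilde{\Delta x} \in \mathcal{X}$.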
While one may ultimately be interested in the sharpness of a module with respect to weight perturbations, Definition 5 also tracks sharpness with respect to input perturbations. In fact, tracking this extra information is essential for propagating sharpness bounds up the module tree. Appendix C details the procedure for automatically calculating the sharpness constants of a compound module starting from the sharpness constants of all its submodules; see Propositions 8 and 9 for the specific formulae. Here we highlight one major corollary of these formulae, proved in Appendix E: for a specific choice of block multipliers, the sharpness constant of a residual network is independent of depth:
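Suppose $\mathsf{M}$ is a well-normed, $(\alpha, \beta, \gamma)$-sharp module on $(\mathcal{X}, \mathcal{X}, \mathcal{W})$ with unit sensitivity. Define the depth $L$ residual module $\mathsf{RES}_L(\mathsf{M})$ via the module arithmetic of Table 1 as:
$$\mathsf{RES}_L(\mathsf{M}) := \left(\frac{L-1}{L} \cdot \mathsf{Id} + \frac{1}{L} \cdot \mathsf{M}\right)^{L}.$$
Then this residual module is in fact sharp with constants that depend only on $\alpha$, $\beta$ and $\gamma$, independent of the depth $L$.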
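The mechanism behind depth independence can be seen in one dimension. The sketch below is our own illustration, not the proof in Appendix E: it takes the block to be tanh (unit sensitivity, with $|\tanh''| \le 4/(3\sqrt{3})$), builds $L$ steps of $h \mapsto \frac{L-1}{L} h + \frac{1}{L}\tanh(h)$, and tracks only the input-input second derivative (the $\gamma$-type term) by forward-mode chain rule. Each step has slope at most 1 and curvature at most $1/L$ times that of tanh, so the $L$ curvature contributions add up to at most that of a single tanh. Dropping the $1/L$ scaling instead gives steps $h \mapsto h + \tanh(h)$, whose slope at the origin doubles with every block.

```python
import math


def tanh_with_derivatives(x):
    t = math.tanh(x)
    d1 = 1.0 - t * t          # tanh'(x)
    d2 = -2.0 * t * d1        # tanh''(x)
    return t, d1, d2


def residual_net(x, depth, scaled=True):
    """Value, first and second derivative (w.r.t. x) after `depth` residual steps.

    scaled=True:  step(h) = (1 - 1/L) * h + (1/L) * tanh(h)
    scaled=False: step(h) = h + tanh(h)
    """
    a, b = (1.0 - 1.0 / depth, 1.0 / depth) if scaled else (1.0, 1.0)
    h, dh, d2h = x, 1.0, 0.0
    for _ in range(depth):
        t, t1, t2 = tanh_with_derivatives(h)
        g1 = a + b * t1                       # step'(h)
        g2 = b * t2                           # step''(h)
        # chain rule: (g o h)' = g'(h) h',  (g o h)'' = g''(h) h'^2 + g'(h) h''
        h, dh, d2h = a * h + b * t, g1 * dh, g2 * dh * dh + g1 * d2h
    return h, dh, d2h


grid = [i / 10 for i in range(-50, 51)]
for L in (1, 10, 100, 1000):
    worst = max(abs(residual_net(x, L)[2]) for x in grid)
    print(f"L={L:5d}  max |f''| on grid = {worst:.4f}")
print("unscaled, L=10: f'(0) =", residual_net(0.0, 10, scaled=False)[1])
```

The scaled network's curvature stays below the single-block bound at every depth tried, while the unscaled one already has slope 1024 at the origin after ten blocks.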
For optimization purposes, one may be more interested in the sharpness of the loss function rather than the sharpness of the neural network. Fortunately, it is possible to convert sharpness bounds on modules into sharpness bounds on loss functions, provided a little is known about the error measure:
If the module is well-normed and -sharp, then the loss function satisfies the following three inequalities at all weight settings and for all weight perturbations :
,
where is the dual norm of ;
The proof is given in Appendix E, and we present estimates of the relevant constants for common error measures in Section C.4. Notice that inequalities (i), (ii) and (iii) are the standard inequalities of smooth optimization, albeit expressed in the modular norm. In fact, (i) implies (ii) implies (iii). In words, inequality (ii) says that the gradient of the loss is Lipschitz-continuous in the modular norm. The Lipschitz constant depends on the module only through the module’s first sharpness coefficient.
Experiments
Our experiments aimed to test the scalability of training with normed versions of Adam and SGD: whether one can tune the learning rate on a small model, and expect the learning rate to remain close to optimal on models of much larger width and depth. In addition to the learning rate, normed optimization in Modula requires a mass parameter to apportion feature learning between the input, output and hidden layers; we also tested the sensitivity of this parameter, whether it affects learning rate transfer, and to what extent the optimal mass itself transfers across width and depth.
All SGD experiments were done with momentum , and all Adam experiments used and . No weight decay was used in any experiment. Every experiment was done with a linear decay learning rate schedule. As for initialization, we used orthogonal initialization for Linear and Conv2D modules, and Gaussian weights projected to a unit norm ball for our Embed module. This was to ensure all modules were well-normed at initialization. Precise versions of our architectures are described in Appendices B.5 and B.7. We compare with nanoGPT using standard initialization in Section D.4 to make sure our changes recover standard performance. We actually found that unnormed Adam transferred learning rate better with our GPT architecture than with nanoGPT.
We found that normed optimization, with both Adam and SGD as the base optimizer, allows for successful learning rate transfer across width and depth for GPT training on OpenWebText (Figure 1), as well as ResMLP and ResNet training on CIFAR-10 (Figure 4). We present expanded results in Section D.5, including results on test loss. We reproduce the standard finding that train and test loss are remarkably similar in large language model pretraining. As for mass allocation, Figure 3 shows that optimal mass transfers with depth for training a ResMLP on CIFAR-10 with normed Adam, and also that both mass and learning rate transfer quite well from a smaller GPT on OpenWebText to a larger one. We detail more experiments on mass allocation in Section D.6.
Discussion: Limitations and Future Work
This paper was influenced by four main streams of work: first, the Tensor Programs series, starting at TP-IV; second, the papers on universal majorize-minimize algorithms; third, work on deep network metrization; and fourth, the open source deep learning ecosystem including the PyTorch module tree and Karpathy’s YouTube video on autograd. We have distilled and synthesized key ideas from these sources, creating a framework that we believe to be simpler than Tensor Programs, computationally lighter than universal majorization-minimization, more general than prior work on metrization and more scalable than the PyTorch module tree. We have packaged these ideas into a (soon-to-be) open-source library called Modula. Inevitably, Modula has limitations. We highlight some of them here, along with associated avenues for future work.
Loss of well-normed-ness. We have emphasized well-normed-ness (Definition 2) as an important criterion in module design. We show in Section B.1 that, for example, the Linear module is well-normed when its weights lie within a spectral norm ball. In our experiments, we initialize all weights so that all modules are well-normed, but we do not enforce this property throughout training. Future work could explore regularization as a means to enforce well-normed-ness throughout training, with the hope of attaining better scalability or improved generalization.
Overhead of normalization. As discussed in Section A.3, we implement normalization for Linear and Conv2D modules using two steps of online power iteration. While online power iteration is an established and fast primitive in deep learning (in fact, it comes from the GAN literature), it does add a modest overhead to training time, as discussed in Section A.4. We think it may be possible to mitigate this overhead by constructing atomic modules with more exotic operator norms. For example, if one equips feature vectors with the $\ell_\infty$ norm rather than the RMS norm, then the induced $\ell_\infty$–$\ell_\infty$ matrix norm is cheaper to compute than the RMS–RMS operator norm. In fact, $\ell_\infty$–$\ell_\infty$ operator normalization has the convenient feature that it decouples over matrix rows, making it more local than spectral normalization and, dare we say, more biologically plausible.
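The cost difference is easy to see in code. The induced $\ell_\infty \to \ell_\infty$ norm of a matrix is its largest absolute row sum, so it needs a single pass over the entries and no iteration, and one row-local way to normalize in it is to rescale each row to unit $\ell_1$ norm. The NumPy sketch below is our own illustration with hypothetical function names, not part of Modula.

```python
import numpy as np


def linf_operator_norm(A):
    # ||A||_{inf -> inf} = max_i sum_j |A_ij|, the largest absolute row sum
    return np.abs(A).sum(axis=1).max()


def normalize_rows(A, eps=1e-12):
    # row-local normalization: rescale each row to unit l1 norm,
    # which gives ||A||_{inf -> inf} = 1 without any global computation
    return A / np.maximum(np.abs(A).sum(axis=1, keepdims=True), eps)


A = np.array([[1.0, -2.0, 0.5],
              [0.0, 3.0, -1.0]])
norm = linf_operator_norm(A)                         # row sums are 3.5 and 4.0
x = np.sign(A[np.argmax(np.abs(A).sum(axis=1))])     # extremal input with ||x||_inf = 1
print(norm, np.abs(A @ x).max(), linf_operator_norm(normalize_rows(A)))
```

The sign vector of the heaviest row attains the bound, which confirms that the row-sum formula is the operator norm and not merely an upper bound.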
Automatic step-size selection. Beyond scalability, recent work has explored the question of automatic learning rate selection, with the Prodigy optimizer serving as a popular example. We tested the Adam version of Prodigy and found it performs well at small scales, essentially working by an implicit form of line search. However, Prodigy will always break at large enough widths, since it requires a lower bound () on Adam’s initial learning rate; Yang et al. showed that no such lower bound exists. We believe this issue could be fixed by rebuilding Prodigy on top of Modula. More broadly, we think that designing line search methods in a properly-normed space is a good idea.
Acknowledgements
We are grateful to Chris Mingard, Virgile Richard and Evan Kiely for useful discussions early in the project. Tongzhou Wang and Jyo Pari provided helpful feedback on the writing and figures.
The work was supported by a Packard Fellowship and a Sloan Research Fellowship to PI, by the MIT-IBM Watson AI Lab, by ONR MURI grant N00014-22-1-2740 and the MIT Quest for Intelligence. TL was supported by a Simons Junior Fellowship.
Contribution Statement
All authors were involved in project conception and discussions, which were initiated by JB. TL developed the theory with input from JB. MH and YL made core experimental observations. YL, MH, JB, and HB ran experiments. TL and JB did most of the writing, while JB, MH and YL made the figures. PI contributed guidance and helpful feedback throughout the course of the project. JB wrote the Modula package with help from MH.
References
Appendix A The Modula Package
We created a Python package called Modula that realizes our module framework in code. Modula supplements PyTorch’s Tensor class with two new classes: Vector and Module.
The Vector class is used to store the weights of a module. It allows for basic algebraic operations to be performed on module weights without needing to write for loops over lists of tensors. For example, if v_1 and v_2 are vectors with the same sub-structure, then one may write expressions such as v_1 + v_2 for the vector sum, or v_1 * v_2 for the elementwise product. Internally, a Vector stores a list of tensors and implements operations using efficient PyTorch foreach primitives.
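A minimal sketch of such a container, assuming NumPy arrays in place of PyTorch tensors and plain Python loops in place of the foreach primitives. The class below is hypothetical and only mirrors the behaviour described in this paragraph.

```python
import numpy as np


class Vector:
    """Hypothetical weight container: a flat list of arrays with elementwise algebra."""

    def __init__(self, tensors):
        self.tensors = list(tensors)

    def _zip(self, other, op):
        if isinstance(other, Vector):
            # both vectors must share the same sub-structure
            assert len(self.tensors) == len(other.tensors)
            return Vector(op(a, b) for a, b in zip(self.tensors, other.tensors))
        return Vector(op(a, other) for a in self.tensors)   # broadcast a scalar

    def __add__(self, other):
        return self._zip(other, lambda a, b: a + b)

    def __mul__(self, other):
        return self._zip(other, lambda a, b: a * b)

    __rmul__ = __mul__    # so that 0.5 * v works as well as v * 0.5

    def __repr__(self):
        return f"Vector({self.tensors!r})"


v_1 = Vector([np.ones((2, 2)), np.arange(3.0)])
v_2 = Vector([2 * np.ones((2, 2)), np.ones(3)])
print(v_1 + v_2, v_1 * v_2, 0.5 * v_1)
```

Keeping the tensors in one flat list is what makes a single batched call per operation possible; the loops here are only a stand-in for that.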
A.2 The Module class
The most significant aspect of the Modula package is the Module class. A Module must have six attributes: two float attributes, namely mass and sensitivity, and four methods:
forward(w: Vector, x: Tensor) -> Tensor # returns an output tensor
initialize() -> Vector # randomly samples a weight vector
normalize(w: Vector) # normalizes w to have unit modular norm
regularize(w: Vector, strength: float) # regularizes w in-place
The norm of a module is not implemented explicitly; instead, we implement the normalize method, since normalization is how the norm is used directly in optimization.
We refer to modules with hand-specified attributes as bonds if they have no weights and atoms if they have weights. Modules formed by combining existing modules are called compounds. Modula automatically constructs the attributes of compound modules. We provide reference implementations for many common modules—see Appendix B. We equip atoms with their natural operator norm, and compute spectral norms via online power iteration. Reference modules may be imported as follows:
from modula.bond import Identity, ReLU, Abs, FunctionalAttention
from modula.atom import Linear, Embed, Conv2D
from modula.compound import ResMLP, ResCNN, Attention, GPT
To make building new compounds easier, Modula overloads the following operations on modules:
M_2 @ M_1 # composes module M_2 with module M_1
(M_1, M_2) # acts as a tuple module in any further composition
M ** L # returns the Lth iterate of module M
These operations combine into compact expressions; for instance, a single such expression builds an L-layer residual network from base module M. Comparing with Equation 3.10, we see that Modula expressions closely resemble their mathematical counterparts.
Finally, all modules come with a convenience method tare(m: float), which resets the module mass to m, with default m=1.
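To illustrate how such overloads can compute compound attributes automatically, here is a toy class that tracks only mass and sensitivity and follows Definitions 3 and 4. It is a hypothetical sketch rather than Modula's implementation: the real classes also carry forward, initialize, normalize and regularize methods, and here tare simply overwrites the mass attribute.

```python
class ToyModule:
    """Hypothetical module tracking only mass and sensitivity."""

    def __init__(self, mass=1.0, sensitivity=1.0):
        self.mass = mass
        self.sensitivity = sensitivity

    @staticmethod
    def _as_module(m):
        # a Python tuple of modules acts as their concatenation (Definition 4)
        if isinstance(m, tuple):
            return ToyModule(sum(x.mass for x in m), sum(x.sensitivity for x in m))
        return m

    def __matmul__(self, other):
        # self @ other composes self after other (Definition 3)
        other = self._as_module(other)
        return ToyModule(self.mass + other.mass, self.sensitivity * other.sensitivity)

    def __rmatmul__(self, other):
        # handles (M_1, M_2) @ self, since tuples do not define @
        return self._as_module(other) @ self

    def __pow__(self, L):
        # L-fold iterate M @ M @ ... @ M (each copy would carry its own weights)
        result = self
        for _ in range(L - 1):
            result = self @ result
        return result

    def tare(self, m=1.0):
        self.mass = m
        return self


M_1, M_2 = ToyModule(1.0, 0.5), ToyModule(2.0, 2.0)
net = M_2 @ M_1                          # mass 3.0, sensitivity 1.0
wide = (M_1, M_2) @ ToyModule(1.0, 1.0)  # tuple after a unit module: mass 4.0, sensitivity 2.5
deep = ToyModule(1.0, 1.0) ** 8          # mass 8.0
tared = (ToyModule(1.0, 1.0) ** 8).tare(2.0)
print(net.mass, net.sensitivity, wide.mass, wide.sensitivity, deep.mass, tared.mass)
```

Because Python evaluates `(M_1, M_2) @ M` by falling back to `M.__rmatmul__`, a bare tuple can participate in composition without a dedicated tuple class.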
A.3 Normalization in Modula
We can normalize any base optimizer in the modular norm using the following pattern:
delta_w = optim(w.grad())    # get update from base optimizer
net.normalize(delta_w)       # normalize update in the modular norm
w -= lr * delta_w            # apply update to weights
Computation of net.normalize(delta_w) requires an efficient estimation of the spectral matrix norm, in the last two dimensions, of the constituent tensors of delta_w; this can be done very quickly to reasonable accuracy using power iteration. We implement this by storing a running estimate of the top singular vector u for each constituent tensor of delta_w. At initialization, u is sampled Gaussian, and each time we normalize a weight update, the previous update’s estimated singular vector is used as the starting value for the power iteration. This enables us to use just two steps of power iteration per weight update. Indeed, for any base optimizer with momentum, successive weight updates should be fairly close; for training without momentum more steps of power iteration may be required.
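A NumPy sketch of the warm-started scheme just described, for a single matrix. The class name, its interface and the use of NumPy in place of PyTorch are assumptions for illustration; the stored vector here lives in the input space of the matrix (a right singular vector), and each call performs two power-iteration steps on $A^\top A$ before returning $\|A u\|$, which can only underestimate the spectral norm.

```python
import numpy as np


class SpectralNormEstimator:
    """Hypothetical helper: power iteration for ||A||_2 that keeps its vector between calls."""

    def __init__(self, d_in, rng):
        u = rng.standard_normal(d_in)
        self.u = u / np.linalg.norm(u)   # Gaussian start, as described above

    def __call__(self, A, steps=2):
        u = self.u
        for _ in range(steps):
            u = A.T @ (A @ u)            # one power-iteration step on A^T A
            u /= np.linalg.norm(u)
        self.u = u                       # warm start for the next weight update
        return np.linalg.norm(A @ u)     # never exceeds the true spectral norm


rng = np.random.default_rng(0)
U, _ = np.linalg.qr(rng.standard_normal((16, 16)))
V, _ = np.linalg.qr(rng.standard_normal((8, 8)))
A = U[:, :8] @ np.diag([10.0, 5.0, 1, 1, 1, 1, 1, 1]) @ V.T   # spectral norm 10

estimator = SpectralNormEstimator(8, rng)
estimates = []
for _ in range(10):                      # ten "weight updates", two steps each
    estimates.append(estimator(A))
print(estimates)
```

Because the vector carries over between calls, the estimate keeps improving even though each call does only two steps; if the matrix changed abruptly between calls, more steps would be needed, in line with the caveat about training without momentum.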
A.4 Overhead
To test the overhead of normalization in the modular norm, we trained a width 64 ResMLP with 8 blocks and block-depth 2 for 10k steps on the CIFAR-10 dataset. We repeated the experiment with and without normalization, and in each case with three different random seeds. Without normalization, the training took seconds, and with normalization the training took seconds. So in this experiment, the overhead of modular normalization was around 23%.
We note that the user of the Modula package is free to write new atomic modules with cheaper or more efficient normalize functions. For instance, the Frobenius norm can be used as a proxy for the spectral norm whenever the weight updates have low stable rank. And we note in Section 5 that one could explore more exotic norms such as the $\ell_\infty$–$\ell_\infty$ operator norm, which is cheaper to compute than the standard spectral norm. Beyond these suggestions, one could explore CUDA-level optimizations to spectral norm computation, which we have not attempted.
Appendix B Module and Network Design
In this appendix, we list the basic, hand-declared modules that serve as building blocks for more complicated neural networks. Then we go on to show how these modules may be combined to yield interesting neural networks. This includes discussion of module broadcasting (Section B.3) and mass taring (Section B.4). The appendix culminates with case studies on attention (Section B.6) and transformers (Section B.7).
An atomic module, or atom for short, is a module with nonzero mass and nonzero parameter space, whose attributes are specifically declared rather than derived. Setting an atom’s mass to zero has the effect of freezing its weights under normed optimization.
These conditions will be automatically satisfied for many neural networks under orthogonal initialization of the weights, and especially if a linear module is immediately preceded by something like a module. Moreover, orthogonal initialization guarantees that the well-normed inequality
holds tightly in nearly-square matrices at initialization, which is important for getting good signal propagation through the whole network.
Moreover, inspection of the second derivative formula above shows it is always $(0, 1, 0)$-sharp with respect to the RMS norms on the input and output spaces.
This is at first sight similar to the Linear module, the key difference being that in applications we expect the inputs of Embed to be one-hot vectors; as such we consider its input space to carry the $\ell_1$-norm.
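Why the $\ell_1$ norm suits one-hot inputs: the extreme points of the $\ell_1$ unit ball are the signed basis vectors, so the induced $\ell_1 \to$ RMS operator norm of an embedding matrix is its largest column RMS norm, and looking up token $i$ is exactly $W$ applied to the one-hot vector $e_i$. The NumPy sketch below checks both facts; treating $\ell_1 \to$ RMS as the embedding's norm is our reading of this paragraph, and the helper names are hypothetical.

```python
import numpy as np


def rms(v):
    return np.sqrt(np.mean(v ** 2))


def l1_to_rms(W):
    # extreme points of the l1 ball are +/- basis vectors, so the induced
    # l1 -> RMS norm is the largest RMS norm of any column
    return max(rms(W[:, j]) for j in range(W.shape[1]))


rng = np.random.default_rng(0)
d_embed, vocab = 8, 5
W = rng.standard_normal((d_embed, vocab))
token = 3
one_hot = np.eye(vocab)[token]

lookup = W[:, token]                     # what an embedding layer returns
assert np.allclose(W @ one_hot, lookup)  # the same as a matrix-vector product

x = rng.standard_normal(vocab)           # any input, not just one-hot
assert rms(W @ x) <= l1_to_rms(W) * np.abs(x).sum() * (1 + 1e-12)
print(l1_to_rms(W))
```

For one-hot inputs the bound is attained at the heaviest column, so the norm is as tight as possible on the inputs the module actually sees.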
B.2 Bond modules
A bond module or bond is a module with zero mass and zero parameter space. They are the “glue” between the atomic modules, needed to construct complex neural networks.
Note that we need not specify a weight space, or mass or norm arguments for a bond module. Moreover, when discussing whether a bond module is $(\alpha, \beta, \gamma)$-sharp, the inequalities for $\alpha$ and $\beta$ are vacuous; thus for bond modules we will abbreviate this notion to $\gamma$-sharp.
To begin, we need two bond modules that are essentially “utility”, as they are crucial for defining basic secondary module operations. These modules are also “type polymorphic” in the sense that they work with any underlying vector space.
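For any vector space $\mathcal{Y}$, the adder module $\mathsf{Add}$ has inputs $\mathcal{Y} \times \mathcal{Y}$ and outputs $\mathcal{Y}$. It has forward function
$$\mathsf{Add}(y_1, y_2) = y_1 + y_2$$
and sensitivity 1. Its significance is that it allows for concatenatable modules to be added:
$$\mathsf{M}_1 + \mathsf{M}_2 := \mathsf{Add} \circ (\mathsf{M}_1, \mathsf{M}_2).$$
For any norm on the vector space $\mathcal{Y}$, $\mathsf{Add}$ is well-normed with respect to the $\ell_1$ combination norm on its input space. Furthermore, $\mathsf{Add}$ is $0$-sharp.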
For any normed vector space and real number the scalar multiplier module has inputs and outputs . Its forward function is:
and its sensitivity is . Its significance is that it allows for scalar multiplication of modules:
It is well-normed with respect to any norm on , and -sharp. When , we call this the identity module . Note that for any .
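The toy sketch below shows how the two utility bonds turn addition and scalar multiplication of modules into compositions and concatenations. The `Module` class and all helper names are hypothetical, weights are omitted, and we assume that sensitivities multiply under composition and add under concatenation. For the scalar multiplier we record |a| as its sensitivity.

```python
from dataclasses import dataclass
from typing import Callable

import numpy as np


@dataclass
class Module:
    """Toy weightless module: a forward map together with a sensitivity."""
    forward: Callable
    sensitivity: float


def compose(m2: Module, m1: Module) -> Module:
    """m2 ∘ m1; assumed rule: sensitivities multiply."""
    return Module(lambda x: m2.forward(m1.forward(x)), m1.sensitivity * m2.sensitivity)


def concat(m1: Module, m2: Module) -> Module:
    """(m1, m2): shared input, tuple of outputs; assumed rule: sensitivities add."""
    return Module(lambda x: (m1.forward(x), m2.forward(x)), m1.sensitivity + m2.sensitivity)


# Adder bond: (x1, x2) -> x1 + x2, with sensitivity 1.
Add = Module(lambda pair: pair[0] + pair[1], 1.0)


def Mul(a: float) -> Module:
    """Scalar multiplier bond: x -> a * x, recorded with sensitivity |a|."""
    return Module(lambda x: a * x, abs(a))


Identity = Mul(1.0)


def add_modules(m1: Module, m2: Module) -> Module:
    """M1 + M2 := Add ∘ (M1, M2)."""
    return compose(Add, concat(m1, m2))


def scale_module(a: float, m: Module) -> Module:
    """a * M := Mul(a) ∘ M."""
    return compose(Mul(a), m)


# A convex mix of two unit-sensitivity maps keeps unit sensitivity.
mix = add_modules(scale_module(0.25, Identity), scale_module(0.75, Identity))
print(mix.sensitivity)                      # 1.0
print(mix.forward(np.array([1.0, -2.0])))   # [ 1. -2.]
```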
The remaining bond modules are used explicitly as non-linearities in neural networks.
and RMS norm on inputs and outputs. For more on this design decision, see . We also define to be the unit sensitivity counterpart to .
where is the cumulative distribution function of the standard Gaussian.
is well-normed in the same sense as . We similarly set .
and has sensitivity 1. It is well-normed, and since it is a linear mapping, it is 0-sharp.
and has sensitivity 1. While it is not automatically well-normed, as long as its inputs have , the required inequality is not very far off. Similarly, it is approximately -sharp.
As with , it is approximately well-normed and approximately -sharp.
B.3 Module broadcasting
Let us briefly discuss a supplementary module operation, which we refer to as module broadcasting.
Suppose is a module with inputs , outputs and weights . Then for any , the -times-broadcast of is the module with the same weight space , mass, sensitivity and norm as , but inputs the Cartesian power and outputs , and forward function
Since this is not defining a module with a new set of weights, we will usually just refer to the broadcast module by the same name , and consider this as just an extension of its forward function.
If is well-normed, then so is any broadcast of taking to , as long as the norms on and are taken to be either the “mean ” norms
for ; when this is just the max norm. In the case that is a bond module (so that it has no weights), any scalar multiple of the mean norm can be used (including the standard norm).
The situation for sharpness is a bit more complicated; we discuss this in Section C.3.
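A minimal numerical illustration of why broadcasting preserves well-normedness in the input direction follows. If each copy satisfies the per-copy inequality, then so does any mean ℓp combination of the copies, because the mean ℓp norm is monotone in each per-copy norm. The helper names are hypothetical. For simplicity the per-copy norm is taken to be Euclidean, and only the input-direction inequality for a linear map of spectral norm below 1 is checked.

```python
import numpy as np


def broadcast(forward, weights, xs):
    """n-times broadcast: apply one module, with shared weights, to each input copy."""
    return [forward(weights, x) for x in xs]


def mean_p_norm(vectors, norm, p):
    """(1/n * sum_i ||v_i||^p)^(1/p); for p = inf this is the max of the per-copy norms."""
    vals = np.array([norm(v) for v in vectors])
    if np.isinf(p):
        return vals.max()
    return np.mean(vals ** p) ** (1.0 / p)


rng = np.random.default_rng(1)
W = rng.standard_normal((8, 8))
W *= 0.9 / np.linalg.norm(W, 2)          # rescale to spectral norm 0.9
linear = lambda w, x: w @ x              # forward function of a linear "module"

dxs = [rng.standard_normal(8) for _ in range(5)]   # one input perturbation per copy
douts = broadcast(linear, W, dxs)                   # linear, so outputs move by W @ dx
for p in (1, 2, np.inf):
    # True for each p: every copy shrinks, so every mean l_p combination shrinks.
    print(p, mean_p_norm(douts, np.linalg.norm, p) <= mean_p_norm(dxs, np.linalg.norm, p))
```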
B.4 Mass taring
In order to make working with the mass parameter of modules a bit easier, let us introduce an auxiliary operation:
This way, one can build complex modules starting from atomic modules with unit masses, and then use the tare operation later to reset their masses to the desired values for better feature learning with normed descent, as in Proposition 3.
B.5 Compound modules and neural networks
Composition, concatenation and the secondary operations of addition, scalar multiplication and iterated concatenation allow us to build a wide variety of neural networks which thus come automatically endowed with the modular norm.
Deep neural networks are typically built as long series of compositions. Let us introduce some terminology:
A deep neural network is a module formed by a composition
where are modules; the number of blocks is the depth of the network.
Typically, each of will be copies of the same module (allowing them to take different weight values, of course), so that the network can be written as an iterated composition
can in principle be any module one likes, but is often some form of embedding module, and is usually a linear module.
As for the form of , we found the following design principle to be quite useful in practice:
Arrange so that each has unit sensitivity.
This ensures that the sensitivity of the whole network stays bounded as the depth grows (this will also be the case if we ensure that , but unit sensitivity has the advantage that the modular norm becomes very explicit). With this in mind:
Suppose that is a module of unit sensitivity whose inputs and outputs are the same space . For any , consider the residual block
and write . This is of unit sensitivity, well-normed if is, and moreover by Proposition 4 is sharp with O(1) sharpness if is.
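A small sketch of the residual block and its iterated composition follows. It assumes, as in the proof of Proposition 4, that the block has the form (1 − α)·x + α·M(x) with α in [0, 1]. The name `alpha`, the choice α = 1/depth, and the tanh-of-orthogonal-map residue are our own illustrative choices. The residue is 1-Lipschitz in the Euclidean norm, so each block and the whole stack should be as well. The check below tests the Lipschitz consequence of unit sensitivity numerically.

```python
import numpy as np


def residual_block(residue, alpha):
    """x -> (1 - alpha) * x + alpha * residue(x), with 0 <= alpha <= 1."""
    return lambda x: (1.0 - alpha) * x + alpha * residue(x)


def iterate(f, depth):
    """Iterated composition f ∘ f ∘ ... ∘ f with `depth` copies."""
    def g(x):
        for _ in range(depth):
            x = f(x)
        return x
    return g


rng = np.random.default_rng(2)
d, depth = 16, 50
q, _ = np.linalg.qr(rng.standard_normal((d, d)))   # orthogonal matrix
residue = lambda x: np.tanh(q @ x)                  # 1-Lipschitz: |tanh'| <= 1, ||q|| = 1
network = iterate(residual_block(residue, 1.0 / depth), depth)

ratios = []
for _ in range(100):
    x, y = rng.standard_normal(d), rng.standard_normal(d)
    ratios.append(np.linalg.norm(network(x) - network(y)) / np.linalg.norm(x - y))
print(max(ratios) <= 1.0 + 1e-12)   # True: the stack never expands distances
```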
A general residual network with residue is any neural network of the form
In practice, we will want to apply one more operation: we will want to tare the mass of the residual blocks. To this end, the residual network with residue , depth and total block mass is
Let us give two basic examples of residual networks.
This is a simple residual variation on the multi-layer perceptron. For a width , consider the unit sensitivity module
This particular order of operations is inspired by a recent paper of Yang et al. .
We invite the reader to compare this to something like : in both cases three core operations are performed, though in a different order: the inputs are normalized, the inputs are centered, and the inputs are passed through a nonlinearity that modifies just the negative coordinates.
The network has as its residue an iterated composition of , where the number of copies of in each residue is called the block depth and denoted . Its initial and final modules are both linear. Thus the network of depth , width , block depth and total block mass is
Usually we suggest taking or , and .
This is a version of ResNet for image classification tasks. For a width and kernel size , consider similarly to above the unit sensitivity module
As in the , the network is a residual network with as its residue an iterated composition of copies of where is the block depth. Its initial and final modules are given by
As defaults, we suggest taking and .
B.6 Case study I: Attention
The core of the attention module is a bond module which we call functional attention.
Moreover, we set .
In theory, one could try to break up attention further into more basic constituent modules (such as scaled dot product, softmax, etc.), but keeping functional attention as the basic unit allows one to leverage efficient implementations of attention such as FlashAttention .
In fact, a perhaps surprising result is that with the above scaling of the dot product, we can estimate the sensitivity and sharpness of . This relies on giving norms for the input and output spaces; these norms are chosen to be
Over the space of inputs with each , the functional attention module is well-normed, and moreover is sharp with sharpness constant .
The proof is given in Appendix E. We thus choose to adopt a 1/d dot-product scaling in our implementation of attention; a rigorous bound as above is not possible for 1/√d scaling, for instance.
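The following single-head sketch of functional attention in NumPy assumes a 1/d scaling of the dot product (the conventional choice being 1/√d) and a causal mask. The function name and the argument layout, three arrays of shape (n, d), are our own. Production code would call a fused kernel such as FlashAttention instead, which is why functional attention is kept as a single unit.

```python
import numpy as np


def functional_attention(query, key, value, causal=True):
    """softmax(query @ key.T / d + mask) @ value for a single head.

    query, key, value: arrays of shape (n, d) for a context of length n.
    Returns the output (n, d) and the attention matrix (n, n).
    """
    n, d = query.shape
    scores = query @ key.T / d                           # 1/d scaling, not 1/sqrt(d)
    if causal:
        future = np.triu(np.ones((n, n), dtype=bool), k=1)
        scores = np.where(future, -np.inf, scores)       # position i attends only to j <= i
    scores = scores - scores.max(axis=1, keepdims=True)  # numerical stability
    weights = np.exp(scores)                             # exp(-inf) = 0 on masked entries
    weights = weights / weights.sum(axis=1, keepdims=True)
    return weights @ value, weights


rng = np.random.default_rng(3)
n, d = 6, 8
q, k, v = (rng.standard_normal((n, d)) for _ in range(3))
out, attn = functional_attention(q, k, v)
print(out.shape, np.allclose(attn.sum(axis=1), 1.0))    # (6, 8) True
```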
We can then immediately define a single head of attention.
The scalar multiplication factor of ensures that has unit sensitivity.
For multiple heads of attention, we simply take advantage of module broadcasting (Definition 6):
where is broadcast over the heads dimension. Note that in Modula, we do this by creating dummy bond modules called AddHeads and RemoveHeads to reshape the tensors and create/remove the explicit head dimension.
As in the single-headed case, the scalar multiplication factor of ensures unit sensitivity.
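The reshaping performed by AddHeads and RemoveHeads can be sketched as a pair of inverse array operations. The functions below are hypothetical stand-ins written for illustration, not Modula's code. Broadcasting a single-head function over the heads dimension then reduces to a map over the leading axis.

```python
import numpy as np


def add_heads(x, num_heads):
    """(n, d) -> (num_heads, n, d // num_heads): split the feature axis into heads."""
    n, d = x.shape
    if d % num_heads != 0:
        raise ValueError("width must be divisible by the number of heads")
    return x.reshape(n, num_heads, d // num_heads).transpose(1, 0, 2)


def remove_heads(x):
    """(num_heads, n, d_head) -> (n, num_heads * d_head): inverse of add_heads."""
    h, n, d_head = x.shape
    return x.transpose(1, 0, 2).reshape(n, h * d_head)


def broadcast_over_heads(head_fn, x):
    """Apply the same single-head function to every head (module broadcasting)."""
    return np.stack([head_fn(x_h) for x_h in x])


x = np.arange(24.0).reshape(3, 8)      # context length 3, width 8
heads = add_heads(x, num_heads=4)      # shape (4, 3, 2)
y = remove_heads(broadcast_over_heads(lambda h: 2.0 * h, heads))
print(heads.shape, np.array_equal(y, 2.0 * x))   # (4, 3, 2) True
```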
B.7 Case study II: GPT
and form the mass one, sensitivity one module
The depth , width , total block mass module is thus
We suggest, as a default value, a total block mass of .
Appendix C More on Smoothness and Sharpness
All our estimates of sharpness for compound modules, as well as the smoothness estimate of Proposition 5 for loss functions, depend on applying the chain rule to compute second derivatives; in the optimization context, the resulting formula is sometimes called the Gauss-Newton decomposition.
Indeed, this amounts to simply the following expression for partial derivatives:
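For reference, the standard second-order chain rule behind this decomposition is given below for a composite h = g ∘ f of twice-differentiable maps, writing y = f(x), in both directional-derivative and index form. The symbols g, f, u, v and the indices are our own notation, not the paper's.

```latex
D^2 h(x)[u, v] = D^2 g\big(f(x)\big)\big[Df(x)[u],\, Df(x)[v]\big]
               + Dg\big(f(x)\big)\big[D^2 f(x)[u, v]\big],

\frac{\partial^2 h}{\partial x_a \partial x_b}
  = \sum_{i,j} \frac{\partial^2 g}{\partial y_i \partial y_j}
      \frac{\partial f_i}{\partial x_a} \frac{\partial f_j}{\partial x_b}
  + \sum_i \frac{\partial g}{\partial y_i}
      \frac{\partial^2 f_i}{\partial x_a \partial x_b}.
```

The first term involves only first derivatives of the inner map and is the part that Gauss-Newton methods retain. The second term carries the curvature of the inner map.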
C.2 Sharpness under composition and concatenation
Here, we state the two formulae for computing the sharpness of a composition and a concatenation of two modules. The proofs are given in Appendix E.
Suppose that and are well-normed, composable modules that are respectively -sharp and -sharp. Under the shorthand that and , the composite is -sharp for:
Suppose that and are well-normed, concatenatable modules that are respectively -sharp and -sharp. Under the shorthand that and , the tuple is -sharp for:
Taken together, Propositions 8 and 9 specify a recursive procedure for computing the sharpness of any compound module that is built from a set of well-normed modules of known sharpness.
These two sets of formulas are actually associative, as the reader may verify using their favorite computer algebra package. This means, for instance, that if are successively composable, well-normed and each -sharp, then the two sets of sharpness estimates coming from applying the above formulas for and actually coincide.
C.3 Sharpness under module broadcasting
Suppose is a well-normed module with inputs , outputs and weights , and suppose moreover that it is -sharp. The broadcast module has the same weights, mass, sensitivity and norm, but takes to .
By Proposition 6, is well-normed, as long as the norms on and are taken to be
for ; unless is a bond module (and thus weight-less), we must take , otherwise can be any positive scalar.
A natural question is whether is also sharp, and if so what its sharpness constants are, with respect to these norms. More or less the same proof as for Proposition 6 shows that the and bounds for sharpness are always true, with the same . The bound is trickier, however, and depends subtly on the chosen . We highlight three cases where one can say something interesting.
Case 1. . For the norm, we have that is -sharp with the same by a more or less immediate proof.
Case 2. . For the “standard” -norms, we have that is -sharp with the same . The proof is direct, using the inequality
Case 3. . This is the “RMS norm” case. As in Case 2, one could use a very weak inequality to obtain the pessimistic result that is -sharp. However, one could also make the following observation: if is large, and are sampled from any normal distribution , then
In particular, this justifies the statement that “for large , the broadcast module is approximately -sharp”. While in actual deep learning contexts, the assumption that are sampled from a normal distribution may not be valid, one should still expect the ratio between the two sides of Equation C.13 to stay O(1) as , and so even if the “ rule” is insufficient, the effective sharpness of the broadcast module should not blow up as .
C.4 Smoothness estimates for common error measures
We now present estimates on and for square and cross-entropy error. Both estimates will be in terms of the value of the average loss function itself, rather than being truly global over the entire output space . Thus, to apply them to real learning problems, one should measure the average loss at initialization and use it to estimate and ; we are implicitly assuming that the loss decreases under gradient descent.
Square error
The desired constants can then be computed as maxima:
which from the above formulas amounts exactly to
To translate this into a bound for the average loss function , note that square root is a concave function. Thus if we have outputs with true classes , Jensen’s inequality yields
allowing us to use as our estimate for Proposition 5.
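The Jensen step used here is the following elementary inequality for the concave square root, stated for any non-negative per-example quantities ℓ_1, …, ℓ_n (our notation). It also follows from the Cauchy-Schwarz inequality applied to the vectors (√ℓ_i) and (1, …, 1).

```latex
\frac{1}{n}\sum_{i=1}^{n} \sqrt{\ell_i}
  \;\le\; \sqrt{\frac{1}{n}\sum_{i=1}^{n} \ell_i},
  \qquad \ell_1, \dots, \ell_n \ge 0.
```

A per-example bound of the form c·√ℓ_i therefore becomes a bound of c times the square root of the average loss.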
Cross-entropy error
Consider again the RMS norm on . An estimate on can thus be computed as
using the basic fact that if are non-negative numbers that sum to 1, then
(Indeed, for fixed , the left hand side is maximized at and all other = 0; one then easily checks that for all .)
A similar concavity argument to the square error case then enables us to use as the first derivative bound for average cross-entropy loss.
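The first-derivative estimates for cross-entropy start from the standard fact that its gradient with respect to the logits is softmax(z) − onehot(y). The sketch below checks this against central finite differences and evaluates the gradient's RMS norm, which is the Euclidean norm divided by √d. The function names are ours.

```python
import numpy as np


def softmax(z):
    e = np.exp(z - z.max())
    return e / e.sum()


def cross_entropy(z, y):
    """-log softmax(z)[y], computed stably via log-sum-exp."""
    m = z.max()
    return m + np.log(np.exp(z - m).sum()) - z[y]


def cross_entropy_grad(z, y):
    """Gradient with respect to the logits: softmax(z) - onehot(y)."""
    g = softmax(z)
    g[y] -= 1.0
    return g


def rms(v):
    return np.sqrt(np.mean(v ** 2))


rng = np.random.default_rng(4)
d, y, eps = 10, 3, 1e-6
z = rng.standard_normal(d)
# Central finite difference along each coordinate direction e_j.
numeric = np.array([
    (cross_entropy(z + eps * e, y) - cross_entropy(z - eps * e, y)) / (2 * eps)
    for e in np.eye(d)
])
grad = cross_entropy_grad(z, y)
print(np.max(np.abs(numeric - grad)) < 1e-6, rms(grad))
```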
The second derivative bound depends on more subtle information geometry. Indeed, can be computed to be
where is the largest eigenvalue of the matrix . It is possible for this eigenvalue to be quite large (for instance, if and all other , then ). However, the average eigenvalue is
If we presumed that, in the course of a gradient descent optimizing the weights of a module , the output perturbations are only generically aligned with the eigenvectors of , then we could use the “effective” smoothness bound .
Perhaps this is a dubious assumption however. A more conservative, but perhaps still dubious, assumption comes from assuming that the logits have roughly entries—at least this could be more or less true at initialization. In this case, the largest eigenvalue is with high probability bounded as
justifying an “approximate” smoothness bound of .
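In plain Euclidean coordinates on the logits, the Hessian of cross-entropy is diag(p) − p pᵀ, independent of the true class. Its average eigenvalue is its trace divided by d, namely (1 − Σ p_i²)/d. The sketch below computes the largest and average eigenvalues and checks the formula by finite differences. Whether the matrix in the text carries extra dimension-dependent factors from the RMS norm cannot be recovered here, so the sketch stays in ℓ2 coordinates. The function names are ours.

```python
import numpy as np


def softmax(z):
    e = np.exp(z - z.max())
    return e / e.sum()


def ce_hessian_from_probs(p):
    """Hessian of cross-entropy w.r.t. the logits, written in terms of p = softmax(z)."""
    return np.diag(p) - np.outer(p, p)


def eigen_summary(p):
    """Largest eigenvalue and average eigenvalue (trace / d) of diag(p) - p p^T."""
    eigs = np.linalg.eigvalsh(ce_hessian_from_probs(p))
    return eigs.max(), eigs.mean()


d = 8
two_way = np.zeros(d)
two_way[:2] = 0.5                  # probability split evenly between two classes
uniform = np.full(d, 1.0 / d)
for name, p in [("two-way", two_way), ("uniform", uniform)]:
    lam_max, lam_avg = eigen_summary(p)
    print(name, round(lam_max, 4), round(lam_avg, 4), round((1 - np.sum(p ** 2)) / d, 4))

# Column j of the Hessian is the derivative of the gradient softmax(z) - onehot(y) along e_j.
z, eps = np.random.default_rng(5).standard_normal(d), 1e-6
fd = np.column_stack([(softmax(z + eps * e) - softmax(z - eps * e)) / (2 * eps) for e in np.eye(d)])
print(np.allclose(fd, ce_hessian_from_probs(softmax(z)), atol=1e-8))
```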
Appendix D Experimental Details
All experiments with ResMLP and ResNet are done with the CIFAR-10 image dataset with standard train and test splits. For data augmentation on the training set, we use random crop, random horizontal flip and PyTorch AutoAugment.
For the GPT transformer experiments, we compared three different datasets:
The Shakespeare corpus, using character-level tokens ;
The TinyStories database using sub-word level tokenization;
OpenWebText using sub-word level tokenization .
No data augmentation was used on the language data. We used data splitting code from .
D.2 Architectures
Full details of the ResMLP, ResNet and GPT architectures we used are given in Sections B.5 and B.7. In every experiment, we used:
heads for GPT, with query and value dimensions where is the embedding dimension (width);
context length for GPT, except for the experiment in Section D.7.
D.3 Hardware
All experiments were run on NVIDIA GPUs using float32 precision. We used a combination of TITAN-RTX, RTX-3090, V100, Ada6000, and H100 devices. Each data point in the experiments takes up to hours, depending on the computing device used. We ran over 1000 training runs in total.
D.4 Comparing to standard nanoGPT architecture
Our implementation of GPT in Modula has certain differences from off-the-shelf architectures such as nanoGPT . We would summarize the overall changes to the transformer architecture and training in the following three points:
the mathematical architecture has slightly different coefficients;
we initialize weight matrices to be orthogonal rather than Gaussian;
we train using normalized weight updates.
The architectural choices we made were entirely informed by the desire for the network to be well-normed and have unit sensitivity: in particular this means that the network enjoys favorable signal propagation properties. In the language of modules, these architectural changes can be summarized as:
Each residual block in our architecture is of the form
where or , compared to suggested for nanoGPT;
We use scaled dot-product attention with 1/d scaling, rather than 1/√d;
We use several additional scalar multiplications to keep the network of unit sensitivity:
Each Attention module (B.44) has a scalar factor of ;
Each MLP module (B.45) has a scalar factor of ;
The token and position embeddings (B.49) have a scalar factor of .
In Figure 5, we compare the performance of the standard (unnormed) Adam optimizer trained on OpenWebText with:
the nanoGPT architecture with Gaussian initialization;
our implementation of GPT with orthogonal initialization.
We found that even without using the normed optimizer, our implementation with orthogonal initialization transferred learning rate better. We suggest that even the base Adam optimizer benefits from the above architectural changes.
D.5 Full sweeps
In Figures 9, 10, 11 and 12, at the end of this Appendix, we report on full learning rate sweep experiments, across width and depth, for GPT on OpenWebText and TinyStories, and ResMLP, ResNet on CIFAR-10.
We consistently find that the normed Adam optimizer matches or outperforms unnormed Adam in both test and training loss, while exhibiting significantly better transfer across width. The difference in depth transfer is less stark; however, we posit that this is partly because unnormed Adam already benefits from the architectural changes we made to improve depth scaling.
Notice too that normed SGD consistently and significantly outperforms ordinary SGD, often coming close to or matching the performance of Adam. We would like to highlight this, since SGD has a significantly lower memory requirement than Adam, and does not require any tuning of .
D.6 Mass allocation
A novel feature of our normed optimization framework is the need to choose a mass parameter for each atomic module. In the context of networks of the form
where . We typically do this by assuming that have mass 1, and by hand resetting the mass of to be a fixed total mass , by calling .
In this Appendix, we explore several aspects of the choice of .
First, we tested whether or not calling is necessary in the first place. Not using tare would leave the “free mass” of ; accordingly as grows large, the feature learning allotment (see Proposition 3) for and would grow smaller. Indeed, as the reader can see in Figure 6, this “free mass” arrangement for a ResMLP network on CIFAR-10, allowing the mass of to grow with is very undesirable, and for good learning rate transfer with depth we must fix a mass.
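The sketch below makes the depth dependence concrete. It reads a module's feature learning allotment as its share of the total mass, with the input and output modules of mass 1 as above. The function, its arguments, the per-block mass of 1 in the untared case, and the total block mass of 5 are hypothetical simplifications chosen for illustration.

```python
def allotments(depth, block_mass=1.0, total_block_mass=None):
    """Mass shares of the input module, of each residual block, and of the output module.

    Input and output modules have mass 1. Untared, each of the `depth` blocks keeps
    `block_mass`; tared, the blocks share `total_block_mass` equally.
    """
    blocks = depth * block_mass if total_block_mass is None else total_block_mass
    total = 1.0 + blocks + 1.0
    return {"input": 1.0 / total, "per_block": blocks / depth / total, "output": 1.0 / total}


for depth in (2, 8, 62):
    free = allotments(depth)                          # "free mass": grows with depth
    tared = allotments(depth, total_block_mass=5.0)   # blocks tared to a fixed total
    print(depth, round(free["input"], 4), round(tared["input"], 4))
```

Without taring, the input and output shares decay like 1/(depth + 2). With taring they stay fixed as depth grows, which matches the arrangement the experiment above favours.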
The mass is thus left as a tunable parameter. We then tested the transferability of mass tuning. Specifically, we wanted to know:
whether one can tune on a network of small width/depth, and expect that same to be close to optimal on a larger network;
whether learning rate transfer across width/depth is itself dependent on selecting a good mass ;
how sensitive the tuning for is: if there is a broad range of acceptable masses, or certain precise values lead to big improvements in train or test loss.
Figures 3 and 6 answer Question 1 above in the affirmative, in the context of ResMLP on CIFAR-10 and GPT on OpenWebText. Moreover, in the context of ResMLP on CIFAR-10, they answer Questions 2 and 3: learning rate transfer occurs at a range of values of .
Figure 7 addresses Question 3 in the context of transformers, on three different datasets. Across all three datasets, a mass in the region to is reasonable.
D.7 Context length
We also tested the dependence of the optimal learning rate for GPT training on OpenWebText on the context length; the results are in Figure 8. Interestingly, we report good transfer of the optimal learning rate from small contexts to long contexts.
D.8 Full sweep results
The next four pages of the Appendix list results of our full learning rate sweeps over width/depth for GPT on OpenWebText and TinyStories, and ResMLP, ResNet on CIFAR-10.
Appendix E Proofs
To prove Proposition 3, it suffices to induct on the construction of a compound module by composition and concatenation, with the atomic modules (where the inequality is just part of well-normed-ness) as the base case.
Indeed, suppose either or . Suppose is a weight for one of the atomic modules of , and write for the mass of this atomic module. Then it must be a weight of either or ; the inductive assumption is that
Case 1. and is a weight of . From the chain rule we then must have:
where the last line is by the definition of the norm of module composition.
Case 2. and is a weight of . The chain rule is not needed in this case, and we proceed straight from the inductive assumption:
Case 3. . Given the symmetric roles of , without loss of generality assume is a weight of . Then,
Proposition 4: Sharpness of residual networks
Suppose is a well-normed module of unit sensitivity on and is -sharp. Then, by Proposition 8 for any , the module is well-normed, sensitivity , and -sharp.
The module is also well-normed, sensitivity , and -sharp. In particular, the sum
is well-normed, unit sensitivity, and -sharp; it has the same mass as the original module.
We induct on the statement for that is -sharp where
The base case is clearly true, and given the statement for , which has exactly times the mass of , we see that is -sharp by applying Proposition 8 with and , where
Setting , observe that and , giving
Proposition 5: Smoothness in the modular norm
The second inequality follows from the first via the fundamental theorem of calculus:
The third inequality follows from the second by again applying the fundamental theorem of calculus, followed by the Cauchy-Schwarz inequality:
Proposition 6: Broadcast modules are well-normed
Suppose is a module with inputs , outputs and weights , broadcast to take to . We take norms on these spaces to be
where unless is a bond module. Write . Then, for perturbations in the weight direction, which only occur if is not a bond module:
For perturbations in the input direction, we have:
Proposition 7: Sensitivity of attention
We prove that the functional attention module of 9 is well-normed and of unit sensitivity.
Writing for short, recall that
where is the mask (our proof will apply equally for the standard causal mask and also the non-causal ).
We will prove that at any satisfying , for any we have
For short, write for the attention matrix and its derivative as
Now, the derivative of splits into two terms
The calculation of the norm of follows by definition from its construction by softmax. For the calculation of the norm of , a direct calculation yields that
where we are writing and so on.
Taking absolute values, applying the Cauchy-Schwarz inequality and summing over we have
Applying to the matrix , we thus have
Taking the max over , this shows the -operator-norm of is at most
which, since , completes the proof.
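The argument above differentiates a row-wise softmax. The standard identity for the directional derivative of A = softmax(S) in direction dS is dA_ij = A_ij (dS_ij − Σ_k A_ik dS_ik), and it can be checked numerically as below. The names are ours, and the mask is omitted for brevity. The identity also shows that every row of dA sums to zero, since the rows of A sum to 1.

```python
import numpy as np


def row_softmax(s):
    e = np.exp(s - s.max(axis=1, keepdims=True))
    return e / e.sum(axis=1, keepdims=True)


def row_softmax_jvp(a, ds):
    """Derivative of A = row_softmax(S) in direction dS, expressed through A itself."""
    return a * (ds - (a * ds).sum(axis=1, keepdims=True))


rng = np.random.default_rng(6)
s, ds = rng.standard_normal((5, 5)), rng.standard_normal((5, 5))
eps = 1e-6
numeric = (row_softmax(s + eps * ds) - row_softmax(s - eps * ds)) / (2 * eps)
analytic = row_softmax_jvp(row_softmax(s), ds)
print(np.allclose(numeric, analytic, atol=1e-8), np.allclose(analytic.sum(axis=1), 0.0))
```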
Proposition 7: Sharpness of functional attention
In this section, we estimate the second derivative of the forward function of functional attention at in perturbation directions and :
We will prove that functional attention is -sharp where in fact ; this amounts to proving that
With these conventions, the expression for is thus
In these terms, the second derivative is just
From the estimates of the previous section, we have
so our task is to estimate the -operator-norm of . Thus, we calculate :
We estimate the -operator-norm of these six terms one by one. The first two, (E.72) and (E.73), are the simplest, using inequality (E.62):
Take absolute values, sum over , and apply Cauchy-Schwarz and inequalities (E.63),(E.64):
Taking the max over and applying :
and so by Cauchy-Schwarz and the fact that the rows of sum to 1:
By a similar argument, for term (E.77) we have:
Thus, we have an estimate on the -operator-norm of :
where all the norms on the right hand side are .
Adding this together with (E.70) and (E.71), we obtain (all norms being :
Proposition 8: Sharpness under composition
Suppose where are well-normed modules on respectively and moreover -sharp for . If for , note that by the definition of the modular norm on the composite , we have for any :
We must prove that is sharp where:
Turning to the second derivative of , we prove the first Inequality (E.97). The Gauss-Newton decomposition (C.1) for any and yields
Applying the well-normed and sharpness inequalities, the norm of the first (E.100) of these terms is bounded by
The second term (E.101) breaks into four separate terms:
In particular, applying the well-normed and sharpness inequalities, this is bounded by
which completes the proof of Inequality (C.4).
Inequalities (E.98) and (E.99) are simpler. For the first of these, note we have
Term (E.114) breaks into two separate terms
which completes the proof of Inequality (E.98).
Proposition 9: Sharpness under concatenation
Suppose where are well-normed modules on respectively and moreover -sharp for . If for , as in the previous proof we have for any :
We must prove that is -sharp where
Now, for the first of these identities, we have for and :
which shows . The expressions for follow similarly.