LEACE: Perfect linear concept erasure in closed form
Nora Belrose, David Schneider-Joseph, Shauli Ravfogel, Ryan Cotterell, Edward Raff, Stella Biderman
Introduction
The ability to prevent a machine learning system from using a specified concept is important for fairness and interpretability. Popular notions of fairness require that protected attributes should not causally affect predictions, and interpretability research often estimates the causal effect of a concept by attempting to remove it from a model’s internal representations.
In this work, we improve upon existing concept erasure techniques using a theory-driven approach. We focus on the case where the predictor class $\mathcal{V}$ is the set of linear classifiers, and prove a previously unnoticed equivalence: a classification task is linearly guarded if and only if every class has exactly the same mean feature vector (§ 3). Leveraging this equivalence, we derive a simple necessary and sufficient condition for an affine transformation to produce linearly guarded features. We then identify the unique surgical transformation in this family: the one that minimizes the mean squared distance from the original features with respect to all norms induced by inner products, including the popular Euclidean and Mahalanobis norms. We name it LEAst-squares Concept Erasure (LEACE) (§ 4).
We empirically validate our proposals, demonstrating the superiority of LEACE for erasing gender bias from BERT representations (§ 5.2), and using concept scrubbing to measure the extent to which large language models use part-of-speech information (§ 6).
Preliminaries
We borrow the concept of guardedness from Ravfogel et al., who define it in terms of $\mathcal{V}$-information. We opt for a slightly more general definition here, which is equivalent to theirs in the case of cross-entropy loss (see Appendix G).
Theoretical Results
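Our primary theoretical result is that the following conditions are all equivalent:
1. The data are linearly guarded.
2. For every convex loss, the trivially attainable loss (that of the best constant predictor) is optimal.
3. All class centroids (class-conditional mean feature vectors) are equal.
4. The features and the labels have zero cross-covariance.
5. Every linear classifier on the features satisfies statistical parity with respect to the labels.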
The equivalence of conditions 1, 2, and 5 is relatively straightforward to show, and the relevant theorems can be found in Appendices B and C. The other equivalences are proven below (cond. 3 $\leftrightarrow$ cond. 2 in § 3.1 and § 3.2; cond. 3 $\leftrightarrow$ cond. 4 in § 3.3).
The following result establishes the implication from condition 3 to condition 2.
3.2 Linear Guardedness Implies Equality of Class Centroids
We now prove the implication from condition 2 to condition 3. Condition 2 applies when the trivially attainable loss is optimal for all convex losses, including cross-entropy loss in particular. And if it holds for cross-entropy loss, we now show that condition 3—the class centroids are equal—must follow. First a more general lemma:
The first-order optimality condition on the component of our parameters and yields the equations:
We now show that Lemma 3.2 applies to the widely used cross-entropy loss:
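As a numerical illustration of the first-order argument, the sketch below (ours, not taken from the paper's code) evaluates the gradient of the mean cross-entropy loss at the constant predictor, i.e. zero weights and a bias equal to the log class priors. Assuming one-hot labels and population (1/n) covariances, that gradient equals the negated cross-covariance between labels and features, so it vanishes exactly when the class centroids coincide. All function and variable names are illustrative.

```python
import numpy as np

def softmax(eta):
    """Row-wise softmax with the usual max-subtraction for stability."""
    eta = eta - eta.max(axis=1, keepdims=True)
    e = np.exp(eta)
    return e / e.sum(axis=1, keepdims=True)

def ce_grad_at_constant(X, Z):
    """Gradient of mean cross-entropy at W = 0, b = log(priors).

    X: (n, d) features, Z: (n, k) one-hot labels.
    Returns (grad_W, grad_b) with shapes (k, d) and (k,).
    """
    n = X.shape[0]
    priors = Z.mean(axis=0)
    logits = np.tile(np.log(priors), (n, 1))  # constant predictor's logits
    probs = softmax(logits)                   # every row equals the priors
    grad_W = (probs - Z).T @ X / n
    grad_b = (probs - Z).mean(axis=0)
    return grad_W, grad_b

rng = np.random.default_rng(0)
n, d, k = 600, 5, 3
labels = np.arange(n) % k                     # every class is present
Z = np.eye(k)[labels]
X = rng.normal(size=(n, d)) + 0.5 * labels[:, None]  # class-dependent means

# Cross-covariance between features and labels, shape (d, k).
sigma_xz = (X - X.mean(axis=0)).T @ (Z - Z.mean(axis=0)) / n
grad_W, grad_b = ce_grad_at_constant(X, Z)

# Equalize the class centroids and recompute the gradient.
class_means = np.stack([X[labels == j].mean(axis=0) for j in range(k)])
X_eq = X - class_means[labels]
grad_W_eq, _ = ce_grad_at_constant(X_eq, Z)

print(np.abs(grad_W + sigma_xz.T).max())  # gradient matches -Sigma_ZX
print(np.abs(grad_W_eq).max())            # gradient vanishes for equal centroids
```

Because cross-entropy is convex in the parameters, a vanishing gradient at the constant predictor means no linear classifier can improve on it.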
3.3 Linearly Guarded Labels Have Zero Covariance with the Features
The next theorem establishes the equivalence of conditions 3 and 4.
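The link between equal centroids and zero covariance can be checked numerically. The sketch below is our own illustration, assuming one-hot labels and 1/n-normalized covariances: column j of the cross-covariance matrix then equals the prior of class j times the gap between that class's centroid and the overall mean, so the matrix is zero exactly when every centroid equals the overall mean. Variable names are illustrative.

```python
import numpy as np

rng = np.random.default_rng(1)
n, d, k = 900, 4, 3
labels = rng.permutation(np.arange(n) % k)   # every class is present
Z = np.eye(k)[labels]                        # one-hot labels, (n, k)
X = rng.normal(size=(n, d)) + rng.normal(size=(k, d))[labels]

mu = X.mean(axis=0)                          # overall feature mean
priors = Z.mean(axis=0)                      # empirical class frequencies
class_means = np.stack([X[labels == j].mean(axis=0) for j in range(k)])

# Left-hand side: cross-covariance of features and labels, (d, k).
sigma_xz = (X - mu).T @ (Z - priors) / n

# Right-hand side: prior-weighted centroid offsets, one column per class.
weighted_offsets = (priors[:, None] * (class_means - mu)).T

print(np.abs(sigma_xz - weighted_offsets).max())  # agreement up to rounding
```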
We have thus established the equivalence of the first four conditions stated earlier. See Appendix C for the last one, on statistical parity.
Least-Squares Concept Erasure
Implications for prior work. Notably, the above theorems imply that three previously proposed methods in the literature, Spectral Attribute Removal (SAL), Mean Projection, and Fair PCA, are guaranteed to achieve linear guardedness given suitable hyperparameters. See Appendix D for further discussion.
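To make this concrete, here is a simplified sketch of an orthogonal-projection eraser in the spirit of these methods. It is our own illustration rather than any of the cited implementations. It relies on the zero-covariance condition: for an affine map with linear part P, the cross-covariance of the transformed features with the labels is P times the original cross-covariance, so projecting out the column space of that matrix suffices. The function name and tolerance are assumptions.

```python
import numpy as np

def orthogonal_eraser(X, Z, tol=1e-10):
    """Return a (d, d) orthogonal projector that nulls colsp(Sigma_XZ)."""
    n, d = X.shape
    sigma_xz = (X - X.mean(axis=0)).T @ (Z - Z.mean(axis=0)) / n
    U, s, _ = np.linalg.svd(sigma_xz, full_matrices=False)
    U = U[:, s > tol * s.max()]          # basis of the column space
    return np.eye(d) - U @ U.T           # project onto its complement

rng = np.random.default_rng(0)
n, d, k = 1500, 6, 3
labels = rng.permutation(np.arange(n) % k)
Z = np.eye(k)[labels]
X = rng.normal(size=(n, d)) + rng.normal(size=(k, d))[labels]

P = orthogonal_eraser(X, Z)
X_erased = X @ P                         # P is symmetric, so no transpose needed

# Cross-covariance after erasure should vanish up to rounding error.
residual = (X_erased - X_erased.mean(axis=0)).T @ (Z - Z.mean(axis=0)) / n
print(np.abs(residual).max())
```

An eraser of this kind satisfies the guardedness condition but ignores correlations between features, which is why it is not least-squares optimal in general.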
See Appendices E.1 and E.2 for two independent proofs of Theorem 4.2. ∎
Putting together Theorems 4.2 and 4.3 and rearranging, we arrive at the LEACE formula:
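As an illustration, the following NumPy sketch implements the formula in the form we understand it to take: r(x) = x - W^+ P W (x - E[X]), where W is the whitening matrix (the pseudoinverse of the square root of the feature covariance) and P is the orthogonal projector onto the column space of W times the cross-covariance. This is our own reconstruction under those assumptions, not the authors' reference implementation; function names and the tolerance are illustrative.

```python
import numpy as np

def fit_leace(X, Z, tol=1e-10):
    """Estimate the mean and the (d, d) erasure matrix A = W^+ P W."""
    n = X.shape[0]
    mu = X.mean(axis=0)
    Xc, Zc = X - mu, Z - Z.mean(axis=0)
    sigma_xx = Xc.T @ Xc / n
    sigma_xz = Xc.T @ Zc / n

    # Whitening matrix W = (Sigma_XX^{1/2})^+ and its pseudoinverse.
    evals, evecs = np.linalg.eigh(sigma_xx)
    keep = evals > tol * evals.max()
    root = np.sqrt(evals[keep])
    V = evecs[:, keep]
    W = (V / root) @ V.T
    W_pinv = (V * root) @ V.T

    # Orthogonal projector onto colsp(W Sigma_XZ).
    U, s, _ = np.linalg.svd(W @ sigma_xz, full_matrices=False)
    U = U[:, s > tol * s.max()]
    return mu, W_pinv @ (U @ U.T) @ W

def apply_leace(X, mu, A):
    """Row-wise r(x) = x - A (x - mu)."""
    return X - (X - mu) @ A.T

rng = np.random.default_rng(0)
n, d, k = 2000, 8, 3
labels = rng.permutation(np.arange(n) % k)
Z = np.eye(k)[labels]
mix = rng.normal(size=(d, d))            # induces correlated features
X = rng.normal(size=(n, d)) @ mix + rng.normal(size=(k, d))[labels]

mu, A = fit_leace(X, Z)
X_erased = apply_leace(X, mu, A)

residual = (X_erased - X_erased.mean(axis=0)).T @ (Z - Z.mean(axis=0)) / n
edit_mse = np.mean(np.sum((X_erased - X) ** 2, axis=1))
print(np.abs(residual).max(), edit_mse)
```

The eigendecomposition route is chosen so that a rank-deficient covariance is handled through the pseudoinverse rather than a plain matrix inverse.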
4.2 Oblique Projections are Least-Squares Optimal
Evaluation
Following Ravfogel et al., we evaluate the ability of our method to remove gender information from the last hidden layer of a frozen BERT model. We use the biographies dataset of De-Arteaga et al., composed of short biographies annotated by both binary gender and profession. We embed each biography with the [CLS] representation in the last layer of BERT, enforce the same-conditional-mean constraint to remove gender information from the [CLS], and then evaluate the performance of the model, after the intervention, on the main task of profession prediction. We compare our intervention with RLACE, which uses gradient-based optimization to solve a linear concept-erasure adversarial game.
First, we evaluate the ability of logistic regression classifiers to recover the removed information. The results, presented in Fig. 2, show that our method is the only one to achieve random accuracy (perfect erasure) with a small edit, although RLACE (but not INLP) comes close. At the same time, our method is around 2 orders of magnitude faster, and does not require gradient-based optimization.
5.2 Downstream Fairness
How does our intervention affect the behavior of the model on the main classification task of profession prediction? We fit a logistic regression profession-prediction classifier over the projected [CLS] representations.
5.3 Revisiting Amnesic Probing
Elazar et al. have introduced the idea of amnesic probing as a causal intervention that aims to test the importance of a given concept (e.g. part-of-speech tag) to some main task (e.g. language modeling). They applied Iterative Nullspace Projection (INLP) to remove different concepts from the hidden representations of the model, and assessed the degree to which its behavior changed when performing masked language modeling. Since INLP often requires dozens of iterations to completely erase the concept, its usage in this context raises concerns of collateral damage due to the magnitude of the intervention and the non-exhaustive nature of INLP removal. Here, we replicate their experiments on the bert-base-uncased model with our interventions.
We use part-of-speech (POS) tags as our concept of interest. We collect sentences and their coarse POS tags (“Noun”, “Verb” etc.; 18 in total) from the English Universal Dependencies dataset. We tokenize the sentences with the BERT tokenizer and map each word-piece to the POS tag of the word to which it belongs. We collect the unmasked BERT representations for each layer, intervene to linearly erase the POS concept from that layer, and continue the forward pass until the last layer, from which we compute the distribution of the MLM over the vocabulary. Note that in each experiment we intervene on a single layer. We quantify the decrease in accuracy following the intervention, as well as the increase in the loss. We compare with a baseline intervention of a random orthogonal projection whose null space has the same rank as the label space (18). For INLP, we perform 20 iterations. This is needed because INLP does not effectively remove the concept; even after 20 iterations, classification accuracy is above majority accuracy. As a result, INLP reduces the rank of the representation by 360. By contrast, our method decreases the rank by just 17.
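The rank figures can be reproduced on synthetic data. The sketch below is our own illustration with made-up features, not the BERT representations: with 18 one-hot classes the centered label columns sum to zero, so the cross-covariance matrix has rank at most 17, while 20 INLP iterations that each remove 18 directions remove 360 in total. The dimensions and seed are arbitrary assumptions.

```python
import numpy as np

rng = np.random.default_rng(0)
n, d, k = 5000, 64, 18                     # k = number of coarse POS tags
labels = rng.permutation(np.arange(n) % k)
Z = np.eye(k)[labels]                      # one-hot labels
X = rng.normal(size=(n, d)) + rng.normal(size=(k, d))[labels]

sigma_xz = (X - X.mean(axis=0)).T @ (Z - Z.mean(axis=0)) / n

# Rank removed by an eraser whose projector spans colsp(Sigma_XZ).
erased_rank = np.linalg.matrix_rank(sigma_xz, tol=1e-10)

# Rank removed by INLP as described above: 18 directions per iteration.
inlp_rank = 20 * k

print(erased_rank, inlp_rank)
```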
Results.
The results are shown in Fig. 4b. Our intervention only mildly changes BERT LM accuracy and loss until layer 8, with the highest drop recorded in layer 11. INLP, in contrast, shows maximum effect at layer 6. Since it removes hundreds of dimensions, it is difficult to attribute this effect to the erasure of the concept. These results suggest that the causal effect of the POS concept on the language model is concentrated in layer 11. Interestingly, this stands in contrast with POS linear probing results, which are optimal at earlier layers. As Elazar et al. have noted, probing does not generally correlate with intervention-based analysis techniques.
Concept Scrubbing
Unfortunately, Elazar et al. were forced to limit their interventions to a single layer due to the limitations of INLP. INLP often requires the deletion of several dozen dimensions before linear guardedness is achieved, as demonstrated in Figure 2. Kumar et al. show empirically and theoretically that INLP causes needless “collateral damage” to useful parts of the representation that are orthogonal to the concept being erased. Because of this collateral damage, it’s impossible to apply INLP to multiple layers of a transformer without causing its outputs to collapse into gibberish.
Instead, we would like to erase all linear information about a concept in every intermediate representation, which we term concept scrubbing. LEACE makes concept scrubbing possible and eminently practical. It causes minimal collateral damage, induces little computational overhead, and the covariance statistics it relies on can be computed in a streaming fashion, without ever storing all the hidden states in memory or on disk.
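One way to compute these statistics in a streaming fashion is sketched below. This is a minimal illustration of ours, not the authors' implementation: it accumulates raw first and second moments per minibatch and converts them to covariances at the end. The class name and interface are hypothetical, and a production version would likely prefer a Welford-style update for better numerical stability.

```python
import numpy as np

class StreamingCov:
    """Accumulates the statistics needed for erasure, one batch at a time."""

    def __init__(self, d, k):
        self.n = 0
        self.sum_x = np.zeros(d)
        self.sum_z = np.zeros(k)
        self.sum_xx = np.zeros((d, d))
        self.sum_xz = np.zeros((d, k))

    def update(self, X, Z):
        """X: (batch, d) hidden states, Z: (batch, k) one-hot labels."""
        self.n += X.shape[0]
        self.sum_x += X.sum(axis=0)
        self.sum_z += Z.sum(axis=0)
        self.sum_xx += X.T @ X
        self.sum_xz += X.T @ Z

    def finalize(self):
        """Return (mean, Sigma_XX, Sigma_XZ) with 1/n normalization."""
        mu_x = self.sum_x / self.n
        mu_z = self.sum_z / self.n
        sigma_xx = self.sum_xx / self.n - np.outer(mu_x, mu_x)
        sigma_xz = self.sum_xz / self.n - np.outer(mu_x, mu_z)
        return mu_x, sigma_xx, sigma_xz

rng = np.random.default_rng(0)
n, d, k, batch = 1200, 6, 4, 100
labels = rng.permutation(np.arange(n) % k)
Z = np.eye(k)[labels]
X = rng.normal(size=(n, d)) + rng.normal(size=(k, d))[labels]

stats = StreamingCov(d, k)
for start in range(0, n, batch):          # only one batch is held at a time
    stats.update(X[start:start + batch], Z[start:start + batch])
mu_x, sigma_xx, sigma_xz = stats.finalize()
```

Memory use is fixed by the feature and label dimensions, independent of the number of tokens processed.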
Dataset. For each model family, we use a sample from the respective pretraining distribution: the validation split of the Pile for the Pythia models, and the RedPajama replication of the LLaMA pretraining corpus for the LLaMA family. We sample one slice of tokens for fitting the LEACE parameters and another slice of tokens for evaluation. Since neither corpus comes with part-of-speech tags, we use a model from the SpaCy library to automatically generate Universal Dependency tags.
Baseline method. We also run concept scrubbing using full-rank SAL, which is similar to our method but lacks a bias term and does not adjust for correlations between features (Appendix D).
Architecture. We focus on autoregressive language models. We evaluate our method on EleutherAI’s Pythia 160M, 1.4B, 6.9B, and 12B models, and Meta’s LLaMA 7B, 13B, and 30B. We apply concept erasure to the input of each transformer block, immediately after normalization is applied (LayerNorm or RMSNorm).
Randomized erasure. Almost any intervention on a neural network will cause its performance to degrade to some extent. Following Elazar et al., we isolate the effect of the concept erasure by comparing it to a control condition in which we orthogonally project onto a random linear subspace of the same rank as the cross-covariance matrix. To reduce the variance of our results, we sample a fresh subspace for each minibatch, and erase that subspace at each layer, reporting the cross-entropy loss averaged over subspaces.
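A control of this kind might be sketched as follows. This is our own reading of the procedure, assuming that "erasing a subspace" means projecting the hidden states onto its orthogonal complement, with the basis drawn by QR-decomposing a Gaussian matrix. The helper name, dimensions, and batch count are hypothetical.

```python
import numpy as np

def random_subspace_eraser(d, rank, rng):
    """Projector that removes a uniformly random rank-`rank` subspace of R^d."""
    # Reduced QR of a Gaussian matrix gives an orthonormal basis, shape (d, rank).
    Q, _ = np.linalg.qr(rng.normal(size=(d, rank)))
    return np.eye(d) - Q @ Q.T

rng = np.random.default_rng(0)
d, rank, batch = 32, 17, 8                # rank matches the cross-covariance

projectors = []
for _ in range(3):                        # a fresh subspace for each minibatch
    P = random_subspace_eraser(d, rank, rng)
    hidden = rng.normal(size=(batch, d))  # stand-in for real hidden states
    erased = hidden @ P                   # P is symmetric
    projectors.append(P)

print(np.linalg.matrix_rank(erased))      # at most d - rank
```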
6.2 Results
We find strong evidence that autoregressive language models heavily rely on linearly encoded part-of-speech information. While erasing a randomly selected subspace has little to no effect on language modeling performance, scrubbing away part-of-speech information induces a large increase in perplexity across all models (Table 1).
The specific numbers, however, depend on the erasure method used: SAL induces significantly larger increases in perplexity for all models we tested. We take this to mean that SAL inflicts more collateral damage on other useful features in the representation than LEACE does. In other words, interventions made with LEACE are more surgical than those made with prior work; they more closely approximate the ideal of a perfect intervention which only erases the target concept and keeps everything else fixed. Had this experiment been conducted with SAL alone, we would have overestimated the causal effect of part-of-speech.
Limitations and Future Work
Much work remains to be done to validate concept scrubbing. Specifically, we’d like to see experiments that target concepts much narrower than part-of-speech, and use behavioral metrics to determine whether scrubbing changes the network in the ways we’d intuitively expect. If these experiments succeed, an exciting next step would be the incorporation of concept scrubbing into the pretraining and/or finetuning process. This may make it possible to train deep neural networks subject to conceptual constraints. It remains to be seen if gradient-based optimizers will be able to “circumvent” such constraints by learning completely nonlinear representations of protected attributes.
A major motivation of concept erasure is that it promises to prevent models from using a concept in a post hoc, model-agnostic fashion. But if our concept scrubbing procedure turns out to yield unsatisfactory results in practical use cases, the most promising research direction might then be to improve model-specific techniques, such as those that modify the training procedure.
Acknowledgements
We are grateful to CoreWeave for providing the compute resources used in Section 6. Shauli Ravfogel is grateful to be supported by the Bloomberg Data Science PhD Fellowship.
References
Appendix A Additional Related Work
The problem of linear concept erasure is an instance of the general problem of information removal. Information removal methods generally divide into adversarial methods, which are applied during training, and the post-hoc linear methods considered in this paper. Adversarial methods use a gradient-reversal layer during training to induce representations that do not encode the protected attribute. However, Elazar and Goldberg have shown that these methods fail to exhaustively remove all the information associated with the protected attribute: it is often possible to train new adversaries that successfully recover the removed information. Linear methods have been proposed as a tractable alternative, where one identifies a linear subspace that captures the concept of interest, and neutralizes it using algebraic techniques. Different methods have been proposed for the identification of the subspace, e.g. PCA and variants thereof, orthogonal-rotation, classification-based, spectral, and adversarial approaches.
Few works theoretically characterize the condition of linear guardedness. Haghighatkhah et al. extensively analyzed the problem of preventing linear classification, with a focus on decreasing accuracy. They provide a constructive proof of an optimal intervention for an SVM classifier. Ravfogel et al. have proposed a formal definition of linear guardedness based on $\mathcal{V}$-information, and characterized the fairness implications of guardedness; we show the relations with our definition above. Ravfogel et al. provide an adversarial formulation of the problem, derive a closed-form solution for certain cases, and propose an SGD-based optimization for others. While they seek an orthogonal projection, we showed empirically that their solution is very close to ours. Sadeghi et al. and Sadeghi and Boddeti both study an adversarial formulation of concept erasure for linear regression, trading off erasure against main-task performance. In contrast to Ravfogel et al., they consider a general linear adversary, i.e. not necessarily a projection matrix. Closest to our work are Kleindessner et al., Haghighatkhah et al., and Shao et al. As we showed above (§ 4), those methods do achieve the goal of linear guardedness, though their authors do not prove this fact. At the same time, they are not optimal in terms of damage to the original representation space.
Appendix B Equivalence of Guardedness with the Optimality of Constant Predictors
The following two theorems establish the equivalence of conditions 1 and 2 (indeed, they do so in the general setting, with no assumption of convex loss or linear predictors).
Combining equations (4) and (5), and the fact that all constant functions exist in our function class $\mathcal{V}$, we arrive at our desired result:
Appendix C Linear Guardedness is Equivalent to Linear Statistical Parity
To measure the effect of linear guardedness on main-task classifiers, we use the following minimal definition of “fairness” with respect to an attribute, adapted from Edwards and Storkey.
We now prove the equivalence of conditions 3 and 5.
This matches the definition of statistical parity provided in Definition C.1.
Appendix D Implications for Prior Work
In this section we discuss the implications of Theorem 4.1, which characterizes the necessary and sufficient conditions for an affine erasure function to yield a perfectly linearly guarded dataset, for methods proposed in prior work.
Appendix E Derivation of LEACE
Below are two independent proofs of Theorem 4.2.
at which point the weights of each subproblem become irrelevant, and our objective may as well be Euclidean, allowing us to view each row as an independent optimization problem not just in this basis, but from any convenient one.
The sub-objective is then:
E.2 Covector Proof
Appendix F The Optimality of Oblique Projections
Appendix G Equivalence of Guardedness Definitions
where $\mathcal{L}(\eta,z)=-\log\frac{\exp(\eta_{z})}{\sum_{i=1}^{k}\exp(\eta_{i})}$ is the cross-entropy loss function.