On Catastrophic Forgetting and Mode Collapse in Generative Adversarial Networks
Hoang Thanh-Tung, Truyen Tran
I Introduction
GANs are a powerful tool for modeling complex distributions. Training a GAN to approximate a single target distribution is usually treated as a single task. In this paper, we introduce a novel view of GAN training as a continual learning problem in which the sequence of changing model distributions is the sequence of tasks. We find that GANs suffer from catastrophic forgetting, a problem often observed in continual learning settings. Catastrophic forgetting (CF) in artificial neural networks is the problem where knowledge of previously learned tasks is abruptly destroyed by learning the current task. When a GAN suffers from CF, it exhibits undesired behaviors such as mode collapse and non-convergence.
In Section III, we show that GAN training is a continual learning problem and demonstrate the CF problem on a number of datasets. We show that catastrophic forgetting and mode collapse are two different but interrelated problems that together can make the training of GANs non-convergent (Sections III-B, IV-B). To avoid mode collapse and improve convergence, it is important to address the CF problem. We identify two factors that cause CF in GANs: 1) information from previous tasks is not used in the current task; 2) knowledge from previous tasks is not usable for the current task and vice versa. Our findings shed light on how to avoid catastrophic forgetting and learn the target distribution properly (Section V).
In Section IV, we investigate the effect of CF and mode collapse on the landscape of the discriminator's output. We find that when a GAN converges to a good local equilibrium without mode collapse, real datapoints are wide local maxima of the discriminator. We show that the sharper the local maxima are, the more severe mode collapse is. Section IV-B shows that when CF happens, the discriminator is directionally monotonic. A GAN with a directionally monotonic discriminator does not converge to an equilibrium. This confirms that CF is a cause of non-convergence.
Section V explains how state-of-the-art methods for stabilizing GANs, such as Wasserstein GAN, zero-centered gradient penalty on training examples (GAN-R1), zero-centered gradient penalty on interpolated samples (GAN-0GP), and optimizers with momentum, can prevent CF and mode collapse. Finally, we introduce a new loss function that helps prevent CF while adding zero computational overhead.
We show the relationship between CF, mode collapse, and non-convergence.
We study the relationship between the sharpness of local maxima and mode collapse.
We show that CF tends to make the discriminator directionally monotonic around real datapoints.
We identify the causes of CF and explain the effectiveness of methods for preventing CF in GANs.
II Related works
Convergence. Prior works on the convergence of GANs usually consider convergence in parameter space. However, convergence in parameter space tells little about the quality of the equilibrium that a GAN converges to. For example, Thanh-Tung et al. demonstrated that TTUR can make a GAN converge to a collapsed equilibrium. Consensus Optimization can introduce spurious local equilibria with unknown properties to the game.
We directly study the behaviors of GANs in the data space. By analyzing the discriminator's output landscape, we find that when a GAN converges, real datapoints are local maxima of the discriminator. We discover the relationship between the sharpness of local maxima and both mode collapse and generalization.
Catastrophic forgetting. Seff et al. studied the standard continual learning setting in which a GAN is trained to generate samples from a set of distributions introduced sequentially. The problem is solved by the direct application of continual learning algorithms such as Elastic Weight Consolidation (EWC) to GANs. Liang et al. independently came up with a similar intuition that GAN training is a continual learning problem. (Liang et al. came up with the idea a few months after us. They agreed that we are the first to consider the catastrophic forgetting problem in a single GAN. Their preprint has not been published at any conference or journal.) Their paper, however, did not study the causes and effects of the problem and focused on applying continual learning algorithms to address catastrophic forgetting in GANs. We focus on explaining the causes and effects of the problem and its relationship to mode collapse and non-convergence.
III Catastrophic forgetting problem in GANs
At each iteration of the training process, is updated to better fool . , the model distribution at iteration , is different from the model distribution at the previous iteration and the next iteration . The knowledge required to separate from is different from that for the pair . and are two different classification tasks to the discriminator. In the original theoretical formulation of GAN, at every GAN iteration, the discriminator and the generator are trained until convergence. That means can be arbitrarily different from . In practice, at each iteration, only a limited number of gradient updates are applied to the players. We can consider a chunk of consecutive model distributions as a task to the discriminator. The sequence of changing model distributions and the target distribution form a sequence of tasks to the discriminator. Because the generator at iteration , , can only generate samples from , , the discriminator at iteration , cannot access samples from previous model distributions . That makes the learning process of a continual learning problem. Similarly, the generator has to fool a sequence of changing discriminators . The training process of a GAN poses a different continual learning problem to each of the players. In this paper, we focus on the continual learning problem in the discriminator, as many prior works have shown that the quality of a GAN mainly depends on its discriminator.
If the sequence converges to a distribution , then the sequence of tasks converges to a single task of separating 2 distributions and . In practice, however, the sequence of model distributions does not always converge. Nagarajan and Kolter formally proved that the players in Wasserstein GAN do not converge to an equilibrium but oscillate in a small cycle around the equilibrium. Although non-saturating GAN (GAN-NS) was proven to be convergent under strong assumptions , Fedus et al. observed that on many real world datasets, the distance between and (measured in KL-divergence and Jensen-Shannon divergence) does not decrease as increases. The authors suggested that can approach in many different and unpredictable ways. These results imply that in the most common variants of GANs, can be arbitrarily different from for large . If the knowledge used for separating and cannot be used for separating and , a discriminator trained on could forget , i.e. it classifies samples in wrongly (Fig. 11b). When this happens, we say that the discriminator exhibits catastrophic forgetting behaviors.
III-B Catastrophic forgetting in GANs
We begin by analyzing the problem on the 8 Gaussian dataset, a dataset generated by a mixture of 8 Gaussians placed on a circle. In Fig. 1, red datapoints are generated samples and blue datapoints are real samples. The discriminator and generator are 2-hidden-layer MLPs with 64 hidden neurons. The ReLU activation function was used. is a 2-dimensional standard normal distribution. SGD with constant learning rate of was used for both networks. The vector at a datapoint shows the negative gradient . The vector shows the direction in which decreases the fastest. The length of the vector corresponds to the speed of change in . Because the gradient field is conservative, the difference between the loss of two datapoints and is:
where and is a path from to . For the variants in Table I, only depends on and . Because decreasing in these GANs corresponds to increasing , going in the direction of increases the score . Let be a fake datapoint. Updating with SGD with a small enough learning rate will move in the direction of by a distance proportional to . If the discriminator is fixed, then SGD updates will move along its integral curve, in the direction of increasing . In practice, gradient updates are not applied to but to the generator’s parameters. Because the generator also minimizes , gradient updates to the generator move in a direction that approximates . is a good approximation of the direction that will move in the next iteration.
Fig. 1a - 1d show the evolution of a GAN-NS on the 8 Gaussian dataset. In Fig. 1a - 1c, the discriminator assigns higher scores to datapoints that are further away from the fake datapoints, regardless of the true labels of these points. This is shown by the gradient vectors pointing away from the fake datapoints. The integral curves do not converge to any real datapoints. If is fixed, updating with gradient descent makes diverge. Because gradients w.r.t. different fake datapoints have the same direction, almost all fake datapoints move in the same direction and do not spread out over the space. Because of CF, the generator is unable to break out of mode collapse.
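For readers who want to reproduce the toy setting, the 8 Gaussian dataset can be sketched as below. This is a minimal illustration, not the authors' code: the function name `sample_8_gaussians`, the circle radius, and the per-mode standard deviation are assumptions, since the text only states that the 8 Gaussians are placed on a circle.

```python
import numpy as np

def sample_8_gaussians(n, radius=2.0, std=0.02, seed=0):
    """Draw n points from a mixture of 8 Gaussians centred on a circle.

    radius and std are illustrative assumptions, not values from the paper.
    """
    rng = np.random.default_rng(seed)
    angles = 2.0 * np.pi * np.arange(8) / 8.0
    centers = radius * np.stack([np.cos(angles), np.sin(angles)], axis=1)
    component = rng.integers(0, 8, size=n)       # mode index of each sample
    noise = std * rng.standard_normal((n, 2))
    return centers[component] + noise, component

real_samples, modes = sample_8_gaussians(1024)
```

Returning the mode index alongside each sample makes it easy to count how many of the 8 modes a generator covers, which is one simple way to quantify mode collapse on this dataset.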
Inside the green box (Fig. 1a), gradients at all datapoints have approximately the same direction. The loss decreases (the score increases) monotonically along the direction of the green vector , a random vector that points away from the fake datapoints. Graphically, we see that the angles between the green vector and are less than for all in the box. Thus, the dot product is positive. The line integral in Eqn. 1 is positive for in the box that satisfy . monotonically decreases along the direction of . We say that is monotonic in direction . We have the following observation:
In a large neighborhood around a real datapoint, (and therefore, ) is directionally monotonic.
A theoretical explanation of this phenomenon is given in Sec. IV-B. Because fake samples in Fig. 1a-1d are concentrated in a small region (i.e. mode collapse), can easily separate them from distant real samples and does not learn useful features of the real data. We say that catastrophically forgets real samples that are far away from the current fake samples. Mode collapse and CF are interrelated: each problem makes the other more severe.
In Fig. 1b, fake datapoints on the right of the red box have higher scores than real datapoints on the left, although in Fig. 1a, these real datapoints have higher scores than these fake datapoints. Going from Fig. 1a to 1d, we observe that the vectors' directions change as soon as fake datapoints move. The phenomenon suggests that information about previous model distributions is not preserved in the discriminator. As tries to separate from , it assigns low scores to regions with fake samples and higher scores to other regions. Because does not 'remember' , it could assign high scores to regions previously occupied by , i.e. could classify old fake samples as real. Fake samples at iteration 3000 (Fig. 1a) are classified as real by (Fig. 1b). Similar behaviors are observed on MNIST (Fig. 11b). Because of forgetting, could direct to move to a region which has visited before. That could cause and to fall into a learning loop and fail to converge to an equilibrium. In Fig. 1a - 1d, the model distribution rotates around the circle indefinitely. CF is a cause of non-convergence.
III-B2 Catastrophic forgetting on image datasets
for . We use the same for all images in Fig. 2. We choose to visualize instead of because explodes if . The quality of the image decreases as increases. A good discriminator should assign lower scores to samples with lower quality. should be higher than , i.e. is a local maximum of . If is a local maximum of , must have a local maximum at (the center of each subplot). The result reported below was observed in all 10 different runs of the experiment.
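The visualization described above amounts to evaluating the discriminator along a line through a real datapoint. A minimal sketch follows, assuming a slice of the form D(x + k·u) for a shared random unit direction u. The exact function the authors plot and the range of k are not recoverable from the text here, and `landscape_slice` and `toy_discriminator` are hypothetical names with a hand-written stand-in for a trained network.

```python
import numpy as np

def landscape_slice(discriminator, x, u, ks):
    """Return the discriminator's output at x + k * u for every k in ks."""
    points = x[None, :] + ks[:, None] * u[None, :]
    return discriminator(points)

# Hypothetical stand-in for a trained discriminator: its score peaks at x_real.
x_real = np.array([0.5, -1.0, 2.0])

def toy_discriminator(points):
    sq_dist = np.sum((points - x_real) ** 2, axis=1)
    return 1.0 / (1.0 + sq_dist)

rng = np.random.default_rng(0)
u = rng.standard_normal(x_real.shape)
u /= np.linalg.norm(u)               # random unit direction, shared by all slices
ks = np.linspace(-1.0, 1.0, 21)      # ks[10] is the centre of the plot (k = 0)
scores = landscape_slice(toy_discriminator, x_real, u, ks)
```

If the real datapoint is a local maximum of the discriminator, the slice peaks at the centre index. A directionally monotonic discriminator would instead produce a slice that only increases or only decreases across `ks`.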
Fig. 2 demonstrates the problem on MNIST. The generator and discriminator are 3-hidden-layer MLPs with 512 hidden neurons. SGD with constant learning rate was used in training.
As shown in Fig. 2, the generated images keep changing from one shape to another, implying that the game does not converge to an equilibrium. In a large neighborhood around every real image, the discriminator's output is monotonic in the sampled direction. At iteration 100000, for every image, is a decreasing function (Fig. 2f), while at iteration 200000, is an increasing function (Fig. 2g). More concretely, let be the discriminator's directional derivative along direction at at iteration . Then Fig. 2f and 2g show that and for some near the real datapoint , have opposite directions. The knowledge of (what learned on ) is not usable for .
We trained DCGAN on CelebA and CIFAR-10 to study the effect of network architecture and dataset complexity on the level of forgetting. Network architecture and hyperparameters are given in Table II.
On CelebA, Fig. 9a - 9g show that a CNN suffers less from CF than an MLP. The discriminator in DCGAN-NS is not directionally monotonic and it successfully makes many real datapoints its local maxima (see Sec. IV for more). The discriminator can effectively discriminate real images from neighboring noisy images. The generator moves fake datapoints toward these local maxima and produces recognizable faces.
On CIFAR-10 (Fig. 10a - 10g), the discriminator cannot discriminate real images from noisy images. The function in Fig. 10b is almost an increasing function while in Fig. 10d it is almost a decreasing function. The training does not converge as fake images change significantly as the learning progresses.
Conclusion: GAN-NS trained on high dimensional datasets exhibits the same catastrophic forgetting behaviors as on toy datasets: (1) real datapoints are not local maxima of the discriminator or in more extreme cases, the discriminator is directionally monotonic in the neighborhoods of real datapoints; (2) the gradients w.r.t. datapoints in the neighborhood of a real datapoint change their directions significantly as fake datapoints move.
III-B3 The causes of Catastrophic Forgetting
Based on the above experiments, we identified two reasons for CF:
Information from previous tasks is not carried to/used for the current task. SGD does not use information from previous model distributions, . At iteration , SGD update for the discriminator is computed from samples from and only. Because information from is not used in training, the discriminator forgets , i.e. it does not assign low score to samples from .
The current task is significantly different from previous tasks so the knowledge of the current task cannot be used for previous tasks and vice versa. As old knowledge is overwritten by new knowledge, optimizing the discriminator on the current task will degrade its performance on older tasks.
Methods for preventing CF are studied in Section V.
IV The output landscape
We apply the visualization technique in Section III-B2 to other variants of GAN. We reuse the network architecture and learning rate from the experiment in Fig. 2. We replace SGD with Adam with . We run each experiment 10 times with different random seeds and report results that are consistent between different runs. The evolution of the landscape and generated samples of GAN-NS, GAN-0GP with , GAN-R1 with , and WGAN-GP with are shown in Fig. 3, 4, 5, and 6 respectively.
GAN-0GP, GAN-R1, and WGAN-GP have significantly better sample quality and diversity than GAN-NS. GAN-NS does not exhibit good convergence behavior: the digit in an image changes from one digit to another as the training progresses (Fig. 3). Note that this does not contradict the statement in that GAN-NS converges to an equilibrium: many of the assumptions in that paper are not satisfied in practice, e.g. the learning rate is not decayed toward 0. GAN-0GP, GAN-R1, and WGAN-GP exhibit better convergence behaviors: for many images, the digits stay the same during training.
We observe that throughout the training process of GAN-0GP, GAN-R1, and WGAN-GP, for every real datapoint, the function always has a local maximum at , implying that real datapoints are local maxima of the discriminator. This can also be seen in GAN-R1 trained on the 8 Gaussian dataset (Fig. 1e - 1g): the gradients w.r.t. datapoints in the neighborhood of a real datapoint point toward that real datapoint (GAN-0GP and WGAN-GP exhibit the same behaviors). If a fake datapoint is in the basin of attraction of a real datapoint and gradient updates are applied directly on the fake datapoint, it will be attracted toward the real datapoint. Different attractors (local maxima) in different regions of the data space attract different fake datapoints in different directions, spreading fake datapoints over the space and effectively reducing mode collapse.
Fig. 7 shows that GAN-0GP with suffers from mild mode collapse. This is consistent with the analysis by the authors of GAN-0GP. Thanh-Tung et al. claimed that larger leads to better generalization but may slow down the training. The maxima in Fig. 7 are much sharper than those in Fig. 6. The discriminator overfits to the real training datapoints and forces the scores of nearby datapoints to be close to 0. That creates many flat regions where the gradients of the discriminator w.r.t. datapoints in these regions are vanishingly small. A fake datapoint located in a flat region cannot move toward the real datapoint because the gradient is vanishingly small. Real datapoints in Fig. 7 have small basins of attraction and cannot effectively spread fake samples over the space. The diversity of generated samples is thus reduced, making mode collapse visible. In order to attract fake datapoints in different directions, local maxima should be wide, i.e. they should have large basins of attraction.
The landscapes of GAN-NS in Fig. 2 and 3 contain many flat regions where the scores are very close to 1 or 0. The same problem is seen on the 8 Gaussian dataset (datapoints in the orange and blue boxes in Fig. 1a-1d have scores close to 1 and 0, respectively). However, unlike Fig. 7, the real datapoints in Fig. 1a - 1d, 2, and 3 are not local maxima. The discriminator in GAN-NS underfits the data.
CNN based discriminators do not create flat regions in the output landscape (Fig. 9b-9d and 10b-10d). However, when the dataset is more complicated, DCGAN-NS discriminator fails to make real datapoints local maxima and the training does not converge (Fig. 10a-10g). The discriminator underfits the data because it is not powerful enough to learn features that separate real and fake/noisy samples. More powerful discriminators based on ResNet significantly improve the quality of GANs (e.g. ). We make the following observation:
For a GAN to converge to a good local equilibrium, real datapoints should be wide local maxima of the discriminator.
IV-B The effect of catastrophic forgetting on the landscape
We investigate the effect of CF on Dirac GAN , a GAN that learns a 1 dimensional Dirac distribution located at the origin, . In the original Dirac GAN, the discriminator is a linear function with 1 parameter, and the model distribution is a Dirac distribution located at , . is the generator’s parameter. Initially, . At each iteration, the training dataset of Dirac GAN contains two training examples: a real training example , and a fake training example . Gradient updates are applied directly on the fake training example.
The unique equilibrium is . Mescheder et al. showed that the players in Dirac GAN do not converge to an equilibrium (see Fig. 1 in ). To make the game converge to the above equilibrium, the authors proposed the R1 gradient penalty, which pushes the gradient w.r.t. the real datapoint to (Table I). A high dimensional GAN can be reduced to a Dirac GAN by considering a pair of real and fake samples and the discriminator's output along the line connecting these samples (similar to the landscape in Fig. 2-6).
Because the discriminator in the original Dirac GAN is a linear function with a single parameter, the output of Dirac discriminator is always a monotonic function. We consider a generic discriminator which is a 1 hidden layer neural network: where , and is a monotonically increasing activation function such as Leaky ReLU (Fig. 8). At equilibrium, and is any function with a global maximum at . Although can have global maxima (see Fig. 8h), optimizing only on the current task makes a monotonic function (Fig. 8f).
The optimal Dirac discriminator that minimizes in Eqn. 3 is a monotonic function.
Let where be the discriminator and be a non-decreasing activation function such as ReLU, Leaky ReLU, Sigmoid, or Tanh. Let be the real datapoint, be the fake datapoint. The empirically optimal discriminator must maximize the difference .
If is ReLU or Leaky ReLU or Tanh, then , , thus
If is Sigmoid, then and . For both cases, we have
The equality for both Eqn. 1 and 2 is achieved for all cases when and . The optimal discriminator’s parameters are .
Without loss of generality, assume .
Because is monotonic, is monotonic. ∎
Optimizing the performance of pushes it toward , making monotonic (Fig. 8a - 8e). This explains the directional monotonicity of discriminators in Fig. 1a-1d, 2.
Although the discriminator in Fig. 8f minimizes the score of the current fake datapoint, it assigns high scores to (old) fake datapoints on the left of the real datapoint, i.e. it forgets these datapoints. If the discriminator is fixed, then minimizing corresponds to moving to . Dirac GAN with a monotonic discriminator does not converge. When the generator and discriminator are trained with alternating SGD, the two players oscillate around the equilibrium (Fig. 8a - 8e).
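The non-convergence of Dirac GAN with a linear (hence monotonic) discriminator can be checked numerically. The sketch below is an illustration under stated assumptions rather than the experiment in Fig. 8: it uses the minimax objective V(w, x_fake) = log(1 - sigmoid(w * x_fake)) with the real datapoint at 0, it applies simultaneous rather than alternating gradient steps (chosen because the outward spiral is then easy to verify by hand), and the initial values and learning rate are hypothetical.

```python
import math

def sigmoid(t):
    return 1.0 / (1.0 + math.exp(-t))

def simulate_dirac_gan(w=1.0, x_fake=1.0, lr=0.1, steps=300):
    """Simultaneous gradient steps on V(w, x_fake) = log(1 - sigmoid(w * x_fake)).

    The linear discriminator D(x) = w * x ascends V; the fake datapoint descends V.
    The real datapoint sits at 0, where D(0) = 0 regardless of w.
    """
    history = [(w, x_fake)]
    for _ in range(steps):
        s = sigmoid(w * x_fake)
        grad_w = -x_fake * s          # dV/dw
        grad_x = -w * s               # dV/dx_fake
        # Both players move using the values from the previous step.
        w, x_fake = w + lr * grad_w, x_fake - lr * grad_x
        history.append((w, x_fake))
    return history

history = simulate_dirac_gan()
radii = [math.hypot(w, x) for w, x in history]
fake_positions = [x for _, x in history]
```

With these updates, each step multiplies the squared distance to the equilibrium (0, 0) by 1 + lr² · s², so `radii` never shrinks, while `fake_positions` keeps crossing the real datapoint at 0. The fake datapoint therefore circles the equilibrium instead of settling on it.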
The problem can be alleviated if one old fake datapoint is added to the training dataset. Fig. 8g - 8j show that when an old fake example is added, Dirac GAN has better convergence behavior (the small fluctuation is due to the large constant learning rate of 0.1). The discriminator at iteration 10 has a global maximum at the origin. If the discriminator is fixed, then will converge to 0. The experiment suggests that information about previous model distributions helps GANs converge. used a buffer of recent old fake samples to refine reasonably good fake samples. Recent old fake samples reduce the oscillation around the equilibrium, helping GANs to converge faster and produce sharper images. However, because the number of samples needed to capture the statistics of a distribution grows exponentially with its dimensionality, storing old fake datapoints is not efficient for high dimensional data. In the next section, we study more efficient methods for preserving information about old distributions.
V Preventing catastrophic forgetting
Based on the reasons identified in Section III-B, we propose the following ways to address the CF problem:
Preserve and use information from previous tasks in the current task.
Introduce prior knowledge to the game in a way such that old knowledge is useful for the new task and is not erased by the new task.
Optimizers with momentum. The update rule of SGD with momentum
The momentum term is a simple form of memory that carries gradient information from previous training iterations to the current iteration. When the discriminator/generator is updated with , the performance of the network on previous tasks is also improved. The effectiveness of momentum in preventing CF is demonstrated in Fig. 1h: the discriminator’s gradient pattern is more stable and similar to those of GAN-0GP and GAN-R1.
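The memory effect of momentum can be made concrete with a short sketch. The heavy-ball form used here (velocity = beta * velocity + gradient) is one common convention and is an assumption, since the paper's exact update rule is not reproduced in this text. `momentum_step`, the gradient values, and the hyperparameters are all hypothetical.

```python
import numpy as np

def momentum_step(theta, velocity, grad, lr=0.01, beta=0.9):
    """One step of SGD with (heavy-ball) momentum."""
    velocity = beta * velocity + grad      # decayed sum of past gradients
    theta = theta - lr * velocity
    return theta, velocity

# Feed in a hypothetical sequence of gradients, one per consecutive "task".
grads = [np.array([1.0, 0.0]), np.array([0.0, 1.0]), np.array([-1.0, 0.0])]
theta, velocity = np.zeros(2), np.zeros(2)
for g in grads:
    theta, velocity = momentum_step(theta, velocity, g)

# The velocity still contains the earlier gradients, weighted by powers of beta.
expected = 0.9**2 * grads[0] + 0.9 * grads[1] + grads[2]
```

After three steps the velocity equals `expected`, so the most recent update is still influenced by gradients computed on earlier model distributions. Plain SGD would use only `grads[2]`.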
Continual learning algorithms such as EWC and online EWC prevent important knowledge of previous tasks from being overwritten by the new task. At the end of a task , online EWC computes the importance of each parameter to the task and adds a regularization term to the loss function of task :
where is the value of at the end of task , balances the importance of the current task and previous tasks, accumulates the importance of throughout the training process. Because consecutive model distributions are similar, we consider a chunk of distributions as a task to the discriminator. The importance is computed every GAN training iteration. The regularizer prevents important weights from deviating too far from the values that are optimal for previous tasks while allowing less important weights to change more freely. It helps the discriminator preserve important information about old distributions. Liang et al. independently proposed a similar way of adapting continual learning methods to GANs. Experiments in the paper showed that continual learning methods improve the quality of GANs.
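A minimal sketch of an EWC-style regularizer is given below. It assumes a diagonal Fisher estimate (mean squared per-sample gradient) as the importance measure and a decayed running sum for the online accumulation. These are standard choices but are assumptions here, and the function names, decay factor, and numbers are hypothetical rather than the authors' configuration.

```python
import numpy as np

def update_importance(importance, grad_samples, decay=0.9):
    """Accumulate a diagonal Fisher estimate: mean squared gradient per weight."""
    fisher = np.mean(np.square(grad_samples), axis=0)
    return decay * importance + fisher

def ewc_penalty(theta, theta_star, importance, lam=1.0):
    """Quadratic penalty that anchors important weights near theta_star."""
    return 0.5 * lam * np.sum(importance * np.square(theta - theta_star))

theta_star = np.array([1.0, -2.0])                   # weights at the end of the old task
grad_samples = np.array([[2.0, 0.1], [2.0, -0.1]])   # hypothetical per-sample gradients
importance = update_importance(np.zeros(2), grad_samples)

# Move each weight by the same amount and compare the resulting penalties.
shift_important = ewc_penalty(theta_star + np.array([0.5, 0.0]), theta_star, importance)
shift_unimportant = ewc_penalty(theta_star + np.array([0.0, 0.5]), theta_star, importance)
```

Moving the first weight, which had large gradients on the old task, costs far more than moving the second by the same amount. This is the mechanism that lets less important weights change freely.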
V-B Introducing prior knowledge to the game
In Dirac GAN, if the discriminator has a local maximum at the real datapoint then it can always classify the real and the fake datapoint correctly, regardless of location of the fake datapoint. Because separating different fake distributions from the target distribution requires the same knowledge, that knowledge will not be erased from the discriminator. We want to introduce to the game the knowledge that real datapoints should be local maxima. R1 and 0GP are two ways to implement that.
The R1 regularizer (the third row in Table I) forces the gradients w.r.t. a real datapoint to be , making it a local extremum of the discriminator. As the discriminator maximizes the score of real datapoints, real datapoints become local maxima of the discriminator. Fig. 1e - 1g show that real datapoints are always local maxima and the gradient pattern of the discriminator stays unchanged as moves toward . Fig. 5 demonstrates the same effect of R1 on MNIST. Note that noisy images that are far away from the real images (e.g. for ) have higher scores than real images. This is because no regularizer is applied to these noisy images.
The 0GP regularizer (the fourth row in Table I) pushes gradients w.r.t. datapoints on the line connecting a real datapoint and a fake datapoint toward . 0GP forces the score to increase gradually as we move from to . During training, is paired with different . Thus, the score is greater than the scores of fake datapoints in a wider neighborhood. That fixes the problem of R1 and creates wider local maxima (Fig. 4, 9). Thanh-Tung et al. showed that GAN-0GP generalizes better than GAN-R1. Although generalization is beyond the scope of this paper, we believe that the sharpness of the discriminator's landscape is related to its generalization capability. Prior works on generalization of neural networks showed that flat (wide) minima of the loss surface generalize better than sharp minima. Creating discriminators with wide local maxima is a good way to improve GANs' generalizability.
WGAN-GP (the first row in Table I) uses the 1-centered gradient penalty (1GP), which pushes gradients w.r.t. datapoints on the line connecting a real datapoint and a fake datapoint toward , forcing the score to increase gradually from to . Fig. 6 shows that real datapoints are local maxima of the discriminator. Wu et al. showed that WGAN-0GP performs slightly better than WGAN-1GP. Our hypothesis is that 0GP creates wider maxima than 1GP as it makes the score on the line from to change more slowly.
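The three penalties discussed above differ only in where the input gradient is evaluated and which norm it is pushed toward. The sketch below illustrates this with a toy one-hidden-layer tanh discriminator whose input gradient is written analytically. In practice the gradient would come from automatic differentiation. The network, its random weights, and the function names are hypothetical, and the penalty weights from Table I are omitted.

```python
import numpy as np

rng = np.random.default_rng(0)
W = rng.standard_normal((16, 2))   # hypothetical discriminator weights
b = rng.standard_normal(16)
v = rng.standard_normal(16)

def discriminator(x):
    """Toy 1-hidden-layer discriminator: D(x) = v . tanh(W x + b)."""
    return np.tanh(x @ W.T + b) @ v

def input_gradient(x):
    """Analytic gradient of D with respect to each input row of x."""
    h = np.tanh(x @ W.T + b)
    return ((1.0 - h**2) * v) @ W

def interpolate(x_real, x_fake, rng):
    """Random points on the segments joining paired real and fake datapoints."""
    alpha = rng.uniform(size=(x_real.shape[0], 1))
    return alpha * x_real + (1.0 - alpha) * x_fake

def r1_penalty(x_real):
    # Zero-centered penalty evaluated at real datapoints only.
    return np.mean(np.sum(input_gradient(x_real) ** 2, axis=1))

def zero_gp_penalty(x_hat):
    # Zero-centered penalty evaluated at interpolated datapoints.
    return np.mean(np.sum(input_gradient(x_hat) ** 2, axis=1))

def one_gp_penalty(x_hat):
    # One-centered penalty: pushes the gradient norm toward 1.
    norms = np.linalg.norm(input_gradient(x_hat), axis=1)
    return np.mean((norms - 1.0) ** 2)

x_real = rng.standard_normal((8, 2))
x_fake = rng.standard_normal((8, 2))
x_hat = interpolate(x_real, x_fake, rng)
penalties = (r1_penalty(x_real), zero_gp_penalty(x_hat), one_gp_penalty(x_hat))
```

R1 and 0GP share the same formula and differ only in the points they are evaluated at, which is why 0GP constrains the landscape over a wider neighborhood of each real datapoint.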
Imbalanced weights for real and fake samples. To prevent the discriminator from forgetting distant real datapoints, we propose to increase the weight of the loss for real datapoints:
where is an empirically chosen hyperparameter, are the losses for real and fake samples, respectively. When , the discriminator is penalized more if it assigns a low score to a real datapoint. The situation where real datapoints are local minima, as in Fig. 10b, or have low scores, as in the blue boxes in Fig. 1a - 1b, is less likely to occur. Fig. 10k shows that the new loss successfully helps the discriminator make more real datapoints local maxima and thus improves the quality of fake samples. Table III shows the effectiveness of the imbalanced loss on the CIFAR-10 dataset: it significantly improves the Inception Score and reduces the score's variance. The imbalanced loss is orthogonal to gradient penalties and can be used to improve gradient penalties (the last two rows in Table III).
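The imbalanced loss can be sketched as a weighted sum of the usual real and fake cross-entropy terms. The example below assumes the discriminator outputs logits and uses the standard sigmoid cross-entropy for both terms. The function names and the weight value 10 are hypothetical, since the value used in the paper is not recoverable from this text.

```python
import numpy as np

def log_sigmoid(t):
    """Numerically stable log(sigmoid(t))."""
    return -np.logaddexp(0.0, -t)

def imbalanced_discriminator_loss(real_logits, fake_logits, real_weight=1.0):
    """Cross-entropy discriminator loss with a larger weight on real samples."""
    loss_real = -np.mean(log_sigmoid(real_logits))    # penalises low scores on real data
    loss_fake = -np.mean(log_sigmoid(-fake_logits))   # penalises high scores on fake data
    return real_weight * loss_real + loss_fake

real_logits = np.array([-3.0, 0.5])   # first real sample is "forgotten" (low score)
fake_logits = np.array([-2.0, -1.0])
balanced = imbalanced_discriminator_loss(real_logits, fake_logits, real_weight=1.0)
imbalanced = imbalanced_discriminator_loss(real_logits, fake_logits, real_weight=10.0)
```

Because the change is a single scalar multiplication inside the loss, it adds no extra forward or backward passes, which matches the claim of zero computational overhead.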
VI Conclusion
Catastrophic forgetting is an important problem in GANs. It is directly related to mode collapse and non-convergence. Addressing catastrophic forgetting leads to better convergence and less mode collapse. Methods such as imbalanced loss, zero-centered gradient penalties, optimizers with momentum, and continual learning are effective at preventing catastrophic forgetting in GANs. 0GP helps GANs converge to good local equilibria where real datapoints are wide local maxima of the discriminator. The gradient penalty is a promising method for improving the generalizability of GANs.
Appendix A Experiments on synthetic datasets
Appendix B Landscapes of different GANs
This section includes figures for different GANs. The general configuration for all experiments is shown in Table IV. Hyperparameters specific to each experiment are shown in the caption of the corresponding figure.
In each figure, the ’Real’ subfloat shows real samples from MNIST dataset. Each cell in a ’Landscape’ subfloat shows a slice of the landscape - the value of , for the corresponding real sample at the specified iteration. Each ’Generated’ subfloat shows the generated samples at that iteration.