MentorNet: Learning Data-Driven Curriculum for Very Deep Neural Networks on Corrupted Labels
Lu Jiang, Zhengyuan Zhou, Thomas Leung, Li-Jia Li, Li Fei-Fei
Introduction
Zhang et al. (2017a) found that deep convolutional neural networks (CNNs) are capable of memorizing the entire training data even with corrupted labels, where some or all true labels are replaced with random labels. It is widely accepted that deeper CNNs usually lead to better performance. However, the ability of deep CNNs to overfit or memorize the corrupted labels can lead to very poor generalization performance (Zhang et al., 2017a). Recently, Neyshabur et al. (2017) and Arpit et al. (2017) proposed deep learning generalization theories to explain this interesting phenomenon.
This paper studies how to overcome corrupted labels for deep CNNs, so as to improve generalization performance on the clean test data. Although learning models on weakly labeled data might not be novel, improving deep CNNs on corrupted labels is clearly an under-studied problem and worthy of exploration, as deep CNNs are more prone to overfitting and memorizing corrupted labels (Zhang et al., 2017a). To address this issue, we focus on training very deep CNNs from scratch, such as resnet-101 (He et al., 2016) or inception-resnet (Szegedy et al., 2017), which have a few hundred layers and orders of magnitude more parameters than the number of training samples. These networks can achieve state-of-the-art results but perform poorly when trained on corrupted labels.
Motivated by the recent success of Curriculum Learning (CL) (Bengio et al., 2009), this paper tackles this problem using CL, a learning paradigm inspired by the cognitive process of humans and animals, in which a model is learned gradually using samples ordered in a meaningful sequence. A curriculum specifies a scheme under which training samples will be gradually learned. CL has successfully improved the performance on a variety of problems. Our intuition is that, similar to its role in education, a curriculum may provide meaningful supervision to help a student overcome corrupted labels. A reasonable curriculum can help the student focus on the samples whose labels have a high chance of being correct.
However, for deep CNNs, we need to address two limitations of the existing CL methodology. First, existing curriculums are usually predefined and remain fixed during training, ignoring feedback from the student. The learning procedure of deep CNNs is quite complicated and may not be accurately modeled by a predefined curriculum. Second, the alternating minimization commonly used in CL and self-paced learning (Kumar et al., 2010) requires alternating variable updates, which is difficult for training very deep CNNs via mini-batch stochastic gradient descent.
To this end, we propose a method to learn the curriculum from data by a network called MentorNet. MentorNet learns a data-driven curriculum to supervise the base deep CNN, namely StudentNet. MentorNet can be learned to approximate an existing predefined curriculum or to discover new data-driven curriculums from data. The learned data-driven curriculum can be updated a few times, taking into account the StudentNet’s feedback. Whenever MentorNet is learned or updated, we fix its parameters and use it together with StudentNet to minimize the learning objective, where MentorNet controls the timing and attention to learn each sample. At test time, StudentNet makes predictions alone without MentorNet.
The proposed method improves existing curriculum learning in two aspects. First, our curriculum is learned from data rather than predefined by human experts. It takes into account the feedback from StudentNet and can be dynamically adjusted during training. Intuitively, this resembles a “collaborative” learning paradigm, where the curriculum is determined by the teacher and student together. Second, in our algorithm, the learning objective is jointly minimized using MentorNet and StudentNet via mini-batch stochastic gradient descent. Therefore, the algorithm can be conveniently parallelized to train deep CNNs on big data. We prove its convergence and empirically verify it on large-scale benchmarks.
We verify our method on four benchmarks. Results show that it can significantly improve the performance of deep CNNs trained on both controlled and real-world corrupted training data. Notably, to the best of our knowledge, it achieves the best-published result on WebVision (Li et al., 2017a), a large benchmark containing 2.2 million images of real-world noisy labels. To summarize, the contribution of this paper is threefold:
We propose a novel method to learn data-driven curriculums for deep CNNs trained on corrupted labels.
We discuss an algorithm to perform curriculum learning for deep networks via mini-batch stochastic gradient descent.
We verify our method on 4 benchmarks and achieve the best-published result on the WebVision benchmark.
Preliminary on Curriculum Learning
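When the StudentNet parameters $\mathbf{w}$ are fixed, the self-paced curriculum yields the optimal weights:
$$v_i^* = \mathbb{1}(\ell_i \le \lambda), \quad \forall i \in [1, n] \qquad (2)$$
where $\mathbb{1}(\cdot)$ is the indicator function, $\ell_i$ is the loss of the $i$-th sample, and $v_i \in [0, 1]$ is its latent weight. Eq. (2) intuitively explains the predefined curriculum in (Kumar et al., 2010), known as self-paced learning. First, when updating $\mathbf{v}$ with a fixed $\mathbf{w}$, a sample whose loss is smaller than the threshold $\lambda$ is treated as an “easy” sample and is selected in training ($v_i^* = 1$). Otherwise, it is not selected ($v_i^* = 0$). Second, when updating $\mathbf{w}$ with a fixed $\mathbf{v}$, the classifier is trained only on the selected “easy” samples. The hyperparameter $\lambda$ controls the learning pace and corresponds to the “age” of the model. When $\lambda$ is small, only samples of small loss are considered. As $\lambda$ grows, more samples of larger loss are gradually added to train a more “mature” model.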
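The hard weighting of Eq. (2) can be illustrated with a minimal sketch. It assumes the per-sample losses $\ell_i$ are already computed, and `self_paced_weights` is a hypothetical helper name. Raising $\lambda$ admits samples of larger loss, mirroring the growing “age” of the model.

```python
from typing import List


def self_paced_weights(losses: List[float], lam: float) -> List[int]:
    """Optimal weights of Eq. (2): v_i = 1 if loss_i <= lam, otherwise 0."""
    return [1 if loss <= lam else 0 for loss in losses]


losses = [0.1, 0.4, 0.9, 2.5, 3.0]  # toy per-sample losses
for lam in (0.5, 1.0, 3.0):  # a growing pace parameter admits harder samples
    print(lam, self_paced_weights(losses, lam))
```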
As shown, the function $G$ specifies a curriculum, i.e., a sequence of samples with their corresponding weights to be used in training. When $\mathbf{w}$ is fixed, its optimal solution, e.g., Eq. (2), computes the time-varying weight that controls the timing and attention to learn every sample. Recent studies discovered multiple predefined curriculums and verified them in many real-world applications, e.g., in (Fan et al., 2017; Ma et al., 2017a; Sangineto et al., 2016; Chang et al., 2017).
This paper studies learning curriculum from data. In the rest of this paper, Section 3 presents an approach to learn data-driven curriculum by MentorNet. Section 4 discusses an algorithm to optimize the objective in Eq. (1) using MentorNet and StudentNet together via mini-batch training.
Learning Curriculum from Data
Existing curriculums are either predetermined as an analytic expression of $G$ or a function to compute sample weights. Such predefined curriculums cannot be adjusted according to the feedback from the student. This section discusses a new way to learn a data-driven curriculum by a neural network called MentorNet. MentorNet is learned to compute time-varying weights for each training sample. Let $\Theta$ denote the parameters of the MentorNet $g_m$. Given a fixed $\mathbf{w}$, our goal is to learn $\Theta^*$ to compute the weight:
$$g_m(\mathbf{z}_i; \Theta^*) = \arg\min_{v_i \in [0,1]} F(\mathbf{w}, \mathbf{v}), \quad \forall i \in [1, n] \qquad (3)$$
where $F$ is the objective in Eq. (1) and $\mathbf{z}_i = \phi(\mathbf{x}_i, y_i, \mathbf{w})$ is the input feature to MentorNet about the $i$-th sample.
MentorNet can be learned to 1) approximate existing curriculums or 2) discover new curriculums from data.
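Learning to approximate predefined curriculums. Our first task is to learn a MentorNet to approximate a predefined curriculum. To do so, we minimize the objective in Eq. (1) with respect to $\Theta$:
$$\arg\min_{\Theta} \sum_{(\mathbf{x}_i, y_i) \in \mathcal{D}} g_m(\mathbf{z}_i; \Theta)\, \ell_i + G\big(g_m(\mathbf{z}_i; \Theta); \lambda\big) \qquad (4)$$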
Eq. (4) applies to both convex and non-convex $G$. This paper employs the following predefined curriculum. It is derived from (Jiang et al., 2015) and works well in our experiments. As discussed later, it is also related to robust non-convex penalties.
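A minimal sketch of the per-sample target that a MentorNet fitted through Eq. (4) would reproduce. As an assumption of this sketch, the curriculum in Eq. (5) is taken to be $G(\mathbf{v};\lambda)=\sum_i \tfrac{1}{2}\lambda_2 v_i^2-(\lambda_1+\lambda_2)v_i$ with $\lambda_1,\lambda_2\ge 0$. For one sample, $v\ell + \tfrac{1}{2}\lambda_2 v^2-(\lambda_1+\lambda_2)v$ is a convex quadratic in $v$, so its minimizer over $[0,1]$ is the clipped stationary point; the code checks this closed form against a brute-force grid search. All function names are hypothetical.

```python
def per_sample_objective(v: float, loss: float, lam1: float, lam2: float) -> float:
    """v * loss + G(v; lambda) for one sample, under the assumed form of Eq. (5)."""
    return v * loss + 0.5 * lam2 * v * v - (lam1 + lam2) * v


def closed_form_weight(loss: float, lam1: float, lam2: float) -> float:
    """Minimizer over v in [0, 1]; lam2 == 0 reduces to a hard threshold at lam1."""
    if lam2 == 0:
        return 1.0 if loss <= lam1 else 0.0
    # Stationary point 1 - (loss - lam1) / lam2, clipped to the feasible interval.
    return min(max(0.0, 1.0 - (loss - lam1) / lam2), 1.0)


def grid_search_weight(loss: float, lam1: float, lam2: float, steps: int = 10_000) -> float:
    """Brute-force minimizer on an evenly spaced grid over [0, 1]."""
    grid = [i / steps for i in range(steps + 1)]
    return min(grid, key=lambda v: per_sample_objective(v, loss, lam1, lam2))


for loss in (0.2, 1.5, 2.5, 4.0):
    print(loss, closed_form_weight(loss, 1.0, 2.0), grid_search_weight(loss, 1.0, 2.0))
```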
The correct labels may not always be available on the target dataset $\mathcal{D}$. In this case, we learn the curriculum on a different small dataset $\mathcal{D}'$ where the correct labels are available. Intuitively, this resembles first learning a teaching strategy with the student on one topic and then transferring the strategy to a similar topic. Empirically, Section 5.1 substantiates that the curriculum learned on a small subset of CIFAR-10 can be applied to the target CIFAR-100 dataset.
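One way to learn a data-driven curriculum on a small clean set $\mathcal{D}'$ can be sketched as follows. The assumptions are: each sample's target weight is 1 when its observed label is correct and 0 otherwise, the MentorNet is replaced by a hypothetical one-feature model `sigmoid(a * loss + b)`, training uses binary cross-entropy with plain gradient descent, and the losses are synthetic. The real MentorNet uses richer inputs and a richer architecture.

```python
import math
import random
from typing import List, Tuple


def sigmoid(x: float) -> float:
    """Numerically stable logistic function."""
    if x >= 0:
        return 1.0 / (1.0 + math.exp(-x))
    z = math.exp(x)
    return z / (1.0 + z)


def fit_weight_model(losses: List[float], targets: List[float],
                     lr: float = 0.5, epochs: int = 2000) -> Tuple[float, float]:
    """Fit v = sigmoid(a * loss + b) to 0/1 targets by gradient descent on binary cross-entropy."""
    a, b = 0.0, 0.0
    n = len(losses)
    for _ in range(epochs):
        grad_a = grad_b = 0.0
        for loss, t in zip(losses, targets):
            err = sigmoid(a * loss + b) - t  # gradient of BCE w.r.t. the logit
            grad_a += err * loss / n
            grad_b += err / n
        a -= lr * grad_a
        b -= lr * grad_b
    return a, b


# Synthetic clean subset D': the target weight is 1 if the observed label is correct.
rng = random.Random(0)
losses, targets = [], []
for _ in range(200):
    correct = rng.random() < 0.6
    mean, std = (0.5, 0.3) if correct else (2.5, 0.5)  # correct labels tend to have small loss
    losses.append(abs(rng.gauss(mean, std)))
    targets.append(1.0 if correct else 0.0)

a, b = fit_weight_model(losses, targets)
print(sigmoid(a * 0.3 + b), sigmoid(a * 3.0 + b))  # weight for a small loss, then a large loss
```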
A burn-in period is introduced before learning $\Theta$. During the first 20% of the StudentNet’s training epochs, MentorNet is initialized and fixed as $g_m(\mathbf{z}_i; \Theta) = r_i$, where $r_i$ is a Bernoulli random variable. This is equivalent to randomly dropping out a fixed fraction of the training samples. We found that the burn-in process helps StudentNet stabilize its predictions and focus on learning simple and common patterns.
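The burn-in weighting can be sketched as below. The assumptions are: each weight is an independent Bernoulli draw that is 0 with probability `drop_prob` (a hypothetical parameter), the burn-in covers the first 20% of epochs as stated above, and `toy_mentornet` is a hypothetical stand-in for a learned MentorNet that takes over afterwards.

```python
import random
from typing import Callable, List


def burn_in_weights(losses: List[float], epoch: int, total_epochs: int,
                    mentornet: Callable[[float], float], drop_prob: float,
                    rng: random.Random, burn_in_fraction: float = 0.2) -> List[float]:
    """Random 0/1 weights during burn-in; MentorNet weights afterwards."""
    if epoch < burn_in_fraction * total_epochs:
        # Each sample is independently dropped (weight 0) with probability drop_prob.
        return [0.0 if rng.random() < drop_prob else 1.0 for _ in losses]
    return [mentornet(loss) for loss in losses]


def toy_mentornet(loss: float) -> float:
    """Hypothetical stand-in for a learned MentorNet: down-weight large losses."""
    return 1.0 if loss < 1.0 else 0.1


rng = random.Random(0)
batch_losses = [0.2, 0.5, 1.7, 3.0] * 2500
early = burn_in_weights(batch_losses, epoch=5, total_epochs=100,
                        mentornet=toy_mentornet, drop_prob=0.3, rng=rng)
late = burn_in_weights(batch_losses, epoch=50, total_epochs=100,
                       mentornet=toy_mentornet, drop_prob=0.3, rng=rng)
print(sum(early) / len(early))  # fraction kept; its expected value is 1 - drop_prob
print(late[:4])
```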
The label and the training epoch percentage are encoded by two separate embedding layers. The epoch percentage is represented as an integer between 0 and 99 that indicates the StudentNet’s training progress, where 0 represents the first and 99 the last training epoch. The concatenated outputs from the LSTM and the embedding layers are fed into two fully-connected layers, the second of which uses the sigmoid activation to keep the output weights between 0 and 1. The last layer in Fig. 1 is a probabilistic sampling layer, used to implement the sample dropout of the burn-in process on an already learned MentorNet.
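The head described above can be sketched as a toy forward pass. Only the label and epoch-percentage embeddings, their concatenation with the LSTM output, and the final sigmoid follow the text. The LSTM output is taken as a given vector, all weights are random rather than learned, `tanh` after the first fully-connected layer is an assumed activation, and the mapping in `epoch_percentage` is an assumed way to obtain the integer 0 to 99; all names are hypothetical.

```python
import math
import random
from typing import List

Vector = List[float]


def epoch_percentage(epoch: int, total_epochs: int) -> int:
    """Map a 0-based epoch index to an integer in [0, 99]: 0 = first epoch, 99 = last."""
    if total_epochs <= 1:
        return 0
    return (99 * epoch) // (total_epochs - 1)


def linear(x: Vector, weights: List[Vector], bias: Vector) -> Vector:
    return [sum(w * xj for w, xj in zip(row, x)) + b for row, b in zip(weights, bias)]


def random_matrix(rows: int, cols: int, rng: random.Random) -> List[Vector]:
    return [[rng.uniform(-0.5, 0.5) for _ in range(cols)] for _ in range(rows)]


class ToyMentorHead:
    """Embeddings + concatenation + two fully-connected layers ending in a sigmoid."""

    def __init__(self, num_classes: int, lstm_dim: int, emb_dim: int,
                 hidden: int, seed: int = 0) -> None:
        rng = random.Random(seed)
        self.label_emb = random_matrix(num_classes, emb_dim, rng)  # one row per class
        self.epoch_emb = random_matrix(100, emb_dim, rng)  # one row per epoch percentage
        in_dim = lstm_dim + 2 * emb_dim
        self.w1, self.b1 = random_matrix(hidden, in_dim, rng), [0.0] * hidden
        self.w2, self.b2 = random_matrix(1, hidden, rng), [0.0]

    def weight(self, lstm_out: Vector, label: int, epoch_pct: int) -> float:
        # Concatenate the (given) LSTM output with the two embeddings.
        x = lstm_out + self.label_emb[label] + self.epoch_emb[epoch_pct]
        h = [math.tanh(v) for v in linear(x, self.w1, self.b1)]  # assumed activation
        logit = linear(h, self.w2, self.b2)[0]
        return 1.0 / (1.0 + math.exp(-logit))  # sigmoid bounds the weight in (0, 1)


head = ToyMentorHead(num_classes=10, lstm_dim=4, emb_dim=3, hidden=8)
pct = epoch_percentage(epoch=42, total_epochs=120)
print(pct, head.weight([0.1, -0.2, 0.3, 0.0], label=3, epoch_pct=pct))
```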
Discussions
MentorNet is a general framework for both predefined and data-driven curriculum learning, where various curriculums can be learned by the same MentorNet structure with different parameters. This framework is conceptually general and practically flexible, as we can switch curriculums by attaching different MentorNets without modifying the pipeline. Therefore, we also learn MentorNets for predefined curriculums. For predefined curriculums where $G$ is unknown, we directly minimize the error between the MentorNet’s outputs and the desired weights. For example, the desired weight for focal loss (Lin et al., 2017b) is computed by:
$$v_i^* = (1 - p_i)^{\gamma}$$
where $p_i$ is the StudentNet’s predicted probability of the labeled class $y_i$ and $\gamma$ is a hyperparameter for smoothing the distribution.
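Fitting a MentorNet to the focal-loss curriculum by minimizing the squared error to the desired weights can be sketched as follows. As assumptions, the MentorNet is replaced by a hypothetical two-parameter model `sigmoid(a * p + b)` of the predicted probability $p$, trained by plain gradient descent on the mean squared error, and $\gamma = 2$ is an arbitrary choice.

```python
import math
from typing import List, Tuple


def sigmoid(x: float) -> float:
    return 1.0 / (1.0 + math.exp(-x))


def focal_weight(p: float, gamma: float) -> float:
    """Desired focal-loss weight (1 - p)^gamma for the predicted probability p of the labeled class."""
    return (1.0 - p) ** gamma


def mse(a: float, b: float, probs: List[float], targets: List[float]) -> float:
    """Mean squared error between the toy MentorNet outputs and the desired weights."""
    return sum((sigmoid(a * p + b) - t) ** 2 for p, t in zip(probs, targets)) / len(probs)


def fit_toy_mentornet(probs: List[float], targets: List[float],
                      lr: float = 2.0, steps: int = 3000) -> Tuple[float, float]:
    """Gradient descent on the MSE of the hypothetical model sigmoid(a * p + b)."""
    a = b = 0.0
    n = len(probs)
    for _ in range(steps):
        grad_a = grad_b = 0.0
        for p, t in zip(probs, targets):
            s = sigmoid(a * p + b)
            g = 2.0 * (s - t) * s * (1.0 - s) / n  # d(MSE)/d(logit) for this sample
            grad_a += g * p
            grad_b += g
        a -= lr * grad_a
        b -= lr * grad_b
    return a, b


gamma = 2.0
probs = [i / 50 for i in range(51)]  # predicted probabilities of the labeled class
targets = [focal_weight(p, gamma) for p in probs]
a, b = fit_toy_mentornet(probs, targets)
print(mse(0.0, 0.0, probs, targets), mse(a, b, probs, targets))
```

The fitted weight decreases as the StudentNet becomes more confident in the labeled class, which is the qualitative behavior of the focal-loss curriculum.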
This paper tackles the problem of overcoming corrupted labels. It is interesting to analyze why the learned curriculum can improve the generalization performance. It turns out that StudentNet, when jointly learned with MentorNet, may optimize an underlying robust objective that is related to the robust M-estimator (Huber, 2011).
To show this, let $v^*(\lambda, x)$ represent the optimal weight function for a loss variable $x$, defined as:
$$v^*(\lambda, x) = \arg\min_{v \in [0,1]} v x + G(v; \lambda) \qquad (9)$$
As $g_m$ is an approximator to Eq. (9), its properties can be analyzed through the function $v^*(\lambda, x)$. Meng et al. (2015) studied the self-paced objective function and proved that optimizing the SPL algorithm is intrinsically equivalent to minimizing a robust loss function. They showed that, given a fixed $\lambda$ and a $v^*(\lambda, x)$ that decreases with respect to $x$, the underlying objective of Eq. (1) can be obtained by:
$$\frac{1}{n} \sum_{i=1}^{n} F_{\lambda}(\ell_i), \qquad F_{\lambda}(\ell_i) = \int_{0}^{\ell_i} v^*(\lambda, x)\, dx$$
Based on this result, the underlying learning objective of the curriculum in Eq. (5) can be derived.
When $\lambda_1$ and $\lambda_2$ are fixed and $\lambda_2 \neq 0$, the underlying objective function of the curriculum in Eq. (5) is calculated from:
When $\lambda_1 = 0$, it is equivalent to the minimax concave penalty (Zhang, 2010).
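The equivalence can be checked numerically. As an assumption of this sketch, the curriculum in Eq. (5) is taken to be $G(\mathbf{v};\lambda)=\sum_i \tfrac{1}{2}\lambda_2 v_i^2-(\lambda_1+\lambda_2)v_i$, whose optimal weight for $\lambda_2>0$ is $v^*(\lambda,x)=\min(\max(0,\,1-(x-\lambda_1)/\lambda_2),\,1)$. The code integrates $v^*$ with the trapezoidal rule to obtain $F_\lambda(\ell)=\int_0^\ell v^*(\lambda,x)\,dx$ and compares it, for $\lambda_1=0$, with the minimax concave penalty $\rho(t)=t-t^2/(2\lambda_2)$ for $t\le\lambda_2$ and $\lambda_2/2$ otherwise. Function names are hypothetical.

```python
def optimal_weight(x: float, lam1: float, lam2: float) -> float:
    """v*(lambda, x) for the assumed form of Eq. (5); requires lam2 > 0."""
    return min(max(0.0, 1.0 - (x - lam1) / lam2), 1.0)


def underlying_objective(loss: float, lam1: float, lam2: float, steps: int = 20_000) -> float:
    """F_lambda(loss) = integral of v*(lambda, x) over [0, loss], by the trapezoidal rule."""
    h = loss / steps
    total = 0.5 * (optimal_weight(0.0, lam1, lam2) + optimal_weight(loss, lam1, lam2))
    total += sum(optimal_weight(i * h, lam1, lam2) for i in range(1, steps))
    return total * h


def mcp(t: float, reg: float, concavity: float) -> float:
    """Minimax concave penalty (Zhang, 2010) evaluated at t >= 0."""
    if t <= concavity * reg:
        return reg * t - t * t / (2.0 * concavity)
    return 0.5 * concavity * reg * reg


# With lam1 = 0 and lam2 = 2, F_lambda should match the MCP with reg = 1 and concavity = 2.
for loss in (0.5, 1.0, 2.0, 3.5):
    print(loss, round(underlying_objective(loss, 0.0, 2.0), 6), mcp(loss, 1.0, 2.0))
```

With $\lambda_1>0$, the clipped weight equals 1 for losses below $\lambda_1$, so $F_\lambda$ grows linearly there before bending.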
For the data-driven curriculum, if the weight computed by the learned MentorNet decreases with respect to the sample loss, we have:
The Algorithm
The alternating minimization algorithm (Csiszar, 1984) used in related work is intractable for deep CNNs, especially on big datasets, for two important reasons. First, in the subroutine of minimizing $\mathbf{w}$ when fixing $\mathbf{v}$, stochastic gradient descent often takes many steps before converging. This means that it can take a long time before moving past this single sub-step. However, such computation is often wasteful, particularly in the initial part of training, because when $\mathbf{v}$ is far away from the optimal point, there is not much gain in finding the exact optimal $\mathbf{w}$ corresponding to this $\mathbf{v}$. Second, the subroutine of minimizing $\mathbf{v}$ when fixing $\mathbf{w}$ is often difficult, because the fixed vector $\mathbf{v}$ may not only consume a considerable amount of memory but also hinder parallel training on multiple machines. Therefore, optimizing the objective with deep CNNs requires some thought at the algorithmic level.
To minimize Eq. (1), we propose an algorithm called SPADE (Stochastic gradient PArtial DEscent). The algorithm optimizes the StudentNet model parameters $\mathbf{w}$ jointly with a given MentorNet. It provides a simple and elegant way to minimize $\mathbf{w}$ and $\mathbf{v}$ stochastically over mini-batches. As a general approach, it can also take $G$ as an input. Let $\Xi_t$ denote a mini-batch of samples fetched uniformly at random, and let $\mathbf{v}_{\Xi_t}$ represent the sample weights in $\Xi_t$. The MentorNet computes:
$$\mathbf{v}_{\Xi_t} = g_m\big(\phi(\Xi_t, \mathbf{w}); \Theta\big)$$
where $\phi$ is the feature extraction function defined in Eq. (3) and $g_m$ denotes the learned MentorNet discussed in Section 3.1.
As shown in Algorithm 1, for $\mathbf{w}$, a stochastic gradient is computed (via a mini-batch) and applied (Step 12), where $\alpha_t$ is the learning rate. For the latent weight variables $\mathbf{v}$, gradient descent is applied only to the small subset of parameters corresponding to the mini-batch (Step 9 or 11). The partial gradient update on the weight parameters is performed when $G$ is used (Step 9). Otherwise, we directly apply the weights computed by the learned MentorNet (Step 11). In both cases, the weights are computed on the fly within a mini-batch and thus do not need to be kept fixed. As a result, the algorithm can be conveniently parallelized across multiple machines.
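The mini-batch loop can be sketched on a toy problem. This is not the authors’ implementation; the assumptions are: the StudentNet is a one-parameter linear regressor with squared loss, weight decay is omitted, 40% of the targets are corrupted, and the MentorNet is a hypothetical fixed function of the loss, i.e. the branch that applies weights from a learned MentorNet directly (the partial gradient update used when $G$ is given is not shown). Weights are recomputed for every mini-batch, so only $\mathbf{w}$ carries state across steps.

```python
import random
from typing import Callable, List, Tuple

Sample = Tuple[float, float]  # (input x, observed target y)


def spade_toy(data: List[Sample], mentornet: Callable[[float], float],
              epochs: int = 50, batch_size: int = 16, lr: float = 0.05,
              seed: int = 0) -> float:
    """Fit y ~ w * x by mini-batch SGD whose per-sample weights come from `mentornet`."""
    rng = random.Random(seed)
    w = 0.0
    for _ in range(epochs):
        for _ in range(len(data) // batch_size):
            batch = rng.sample(data, batch_size)  # mini-batch drawn uniformly at random
            # Weights are computed on the fly from the current losses; none are stored.
            v = [mentornet((w * x - y) ** 2) for x, y in batch]
            grad = sum(vi * 2.0 * (w * x - y) * x for vi, (x, y) in zip(v, batch)) / batch_size
            w -= lr * grad  # stochastic gradient step on the model parameter
    return w


def threshold_mentornet(loss: float) -> float:
    """Hypothetical fixed MentorNet: keep samples whose loss is below 1."""
    return 1.0 if loss < 1.0 else 0.0


def uniform_weights(loss: float) -> float:
    """No curriculum: every sample gets weight 1."""
    return 1.0


rng = random.Random(1)
data: List[Sample] = []
for _ in range(500):
    x = rng.uniform(-1.0, 1.0)
    if rng.random() < 0.4:
        y = rng.uniform(-5.0, 5.0)  # corrupted target, unrelated to x
    else:
        y = 2.0 * x + rng.gauss(0.0, 0.1)  # clean target with true slope 2
    data.append((x, y))

w_weighted = spade_toy(data, threshold_mentornet)
w_plain = spade_toy(data, uniform_weights)
print(w_weighted, w_plain)
```

With the threshold curriculum the estimated slope should stay close to the true value 2, while uniform weights are pulled toward roughly 0.6 × 2 = 1.2 by the corrupted targets.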
The curriculum can change during training. MentorNet is updated a few times in Algorithm 1. In Step 6, the MentorNet parameters $\Theta$ are updated to adapt to the most recent model parameters of StudentNet. In experiments, we update $\Theta$ twice after the learning rate is changed. Each time, a data-driven curriculum is learned from the data generated by the most recent $\mathbf{w}$ using the method discussed in Section 3.1. The update is consistent with existing curriculum learning methodology (Bengio et al., 2009; Kumar et al., 2010); the difference here is that, for each update, the curriculum is learned rather than specified by human experts.
Under standard assumptions, Theorem 1 shows that the algorithm stabilizes and converges to a stationary point (convergence to global or local minima cannot be guaranteed except for specially structured non-convex objectives (Chen et al., 2018; Zhou et al., 2017b, a)). The proof is in Appendix B. The theorem characterizes the stability of the model parameters $\mathbf{w}$. For the weight parameters $\mathbf{v}$, which are restricted to a compact set, convergence to a stationary point is not always guaranteed. As the model parameters are more important, we only provide a detailed characterization of the model parameters.
For manually designed curriculums, it may be unclear where, or even whether, such a predefined curriculum would converge via mini-batch training. Theorem 1 shows that the learned curriculum can converge and produce a stable StudentNet model. The algorithm can be used to replace the alternating minimization method in related work.
Experiments
This section empirically verifies the proposed method on four benchmarks of controlled corrupted labels in Section 5.1 and real-world noisy labels in Section 5.2. The code can be found at https://github.com/google/mentornet.
This section validates MentorNet on controlled corrupted labels. We follow a common setting in (Zhang et al., 2017a) to train deep CNNs, where the label of each image is independently changed to a uniform random class with probability $p$, where $p$ is the noise fraction and is set to 0.2, 0.4, and 0.8. The labels of the validation data remain clean for evaluation.
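The corruption protocol can be sketched as below; `corrupt_labels` is a hypothetical helper. The sketch assumes the replacement class is drawn uniformly from all classes, so it can coincide with the original, and the fraction of labels that actually change is then about $p(K-1)/K$ for $K$ classes.

```python
import random
from typing import List


def corrupt_labels(labels: List[int], num_classes: int, noise_fraction: float,
                   seed: int = 0) -> List[int]:
    """With probability noise_fraction, replace each label independently by a class
    drawn uniformly from all num_classes classes (it may equal the original)."""
    rng = random.Random(seed)
    return [rng.randrange(num_classes) if rng.random() < noise_fraction else y
            for y in labels]


labels = [i % 10 for i in range(10_000)]  # toy clean labels for 10 classes
noisy = corrupt_labels(labels, num_classes=10, noise_fraction=0.4)
changed = sum(a != b for a, b in zip(labels, noisy)) / len(labels)
print(changed)  # about 0.4 * 9 / 10, since a redrawn label can equal the original
```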
Dataset and StudentNet: We use the same benchmarks as in (Zhang et al., 2017a): CIFAR-10, CIFAR-100 and ImageNet. CIFAR-10 and CIFAR-100 (Krizhevsky & Hinton, 2009) consist of 32×32 color images arranged in 10 and 100 classes, respectively. Both datasets contain 50,000 training and 10,000 validation images. ImageNet ILSVRC2012 (Deng et al., 2009) contains about 1.2 million training and 50k validation images, split into 1,000 classes. Each image is resized to 299×299 with 3 color channels.
We employ three recent deep CNNs as our StudentNets: inception (Szegedy et al., 2016), resnet-101 (He et al., 2016) with wide filters (Zagoruyko & Komodakis, 2016) and inception-resnet v2 (Szegedy et al., 2017). Table 1 shows their numbers of model parameters, training accuracy, and validation accuracy when trained on the clean training data (noise fraction of 0). As shown, they achieve reasonable accuracy on each task.
Baselines: MentorNet is compared against the following baselines. FullModel is the standard StudentNet trained using weight decay, dropout (Srivastava et al., 2014), and data augmentation (Krizhevsky et al., 2012). The hyperparameters are set to the best ones found on the clean training data. Unless specified otherwise, for a fair comparison, the StudentNet with the same hyperparameters is used in all baselines and in our model. Forgetting was introduced in (Arpit et al., 2017), in which the dropout parameter is searched in the range 0.2-0.9. Self-paced (Kumar et al., 2010) and Focal Loss (Lin et al., 2017b) represent well-known predefined curriculums in the literature. We implemented Reed (Reed et al., 2014) and Goldberger (Goldberger & Ben-Reuven, 2017) as recent weakly-supervised learning methods. The above baselines are a mixture of curriculum learning methods and recent methods for dealing with corrupted labels.
Our Model: MentorNet PD is the network learned using our predefined curriculum in Eq. (5) with no additional clean labels. MentorNet DD is the learned data-driven curriculum. It is trained on 5,000 images with true labels, randomly sampled from the CIFAR-10 training set. The same data are used to learn MentorNet DD on CIFAR-100. Note that CIFAR-10 and CIFAR-100 are two different datasets that differ not only in their classes but also in the number of classes. Therefore, it is fair to compare MentorNet DD with other methods that use no true labels on CIFAR-100. Algorithm 1 is used to optimize the StudentNet. The decay factor in computing the loss moving average is set to 0.95. The loss percentile in the moving average is set by cross-validation. As mentioned, a burn-in process is used in the first 20% of the training epochs for both MentorNet DD and MentorNet PD. More details are discussed in Appendix E.
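The loss moving average can be sketched as an exponential moving average of a per-batch loss percentile. The decay of 0.95 follows the text; the nearest-rank percentile, the initialization from the first batch, and the percentile value of 75 (a placeholder for the cross-validated setting) are assumptions, and all names are hypothetical. `loss_features` pairs each loss with its difference to the average, the two loss-based quantities plotted in Fig. 4.

```python
import math
from typing import List, Optional, Tuple


def percentile(values: List[float], q: float) -> float:
    """Nearest-rank q-th percentile (0 < q <= 100)."""
    ordered = sorted(values)
    rank = max(1, math.ceil(q / 100.0 * len(ordered)))
    return ordered[rank - 1]


class LossMovingAverage:
    """Exponential moving average of the q-th percentile of mini-batch losses."""

    def __init__(self, q: float, decay: float = 0.95) -> None:
        self.q = q
        self.decay = decay
        self.value: Optional[float] = None

    def update(self, batch_losses: List[float]) -> float:
        current = percentile(batch_losses, self.q)
        if self.value is None:
            self.value = current  # initialize with the first batch
        else:
            self.value = self.decay * self.value + (1.0 - self.decay) * current
        return self.value


def loss_features(batch_losses: List[float], avg: float) -> List[Tuple[float, float]]:
    """(loss, loss - moving average) for each sample."""
    return [(loss, loss - avg) for loss in batch_losses]


tracker = LossMovingAverage(q=75)  # 75 is a placeholder for the cross-validated percentile
tracker.update([0.2, 0.4, 1.0, 3.0])  # first batch initializes the average to 1.0
avg = tracker.update([0.1, 0.2, 0.3, 0.5])  # 0.95 * 1.0 + 0.05 * 0.3
print(avg, loss_features([0.1, 2.0], avg))
```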
We first compare against the baseline methods on CIFAR-10 and CIFAR-100 in Table 2. On both datasets, each method is verified with two StudentNets (resnet-101 and inception) under noise fractions of 0.2, 0.4, and 0.8. As we see on both datasets, MentorNet improves FullModel across different noise fractions, and the learned data-driven curriculum (MentorNet DD) achieves the best results. The improvement is more significant for the deeper CNN model, resnet-101. For example, on CIFAR-10 with 40% noise, MentorNet DD (with resnet-101) yields an absolute 20% gain over FullModel. After inspecting the results, we found that this may be because MentorNet DD learns a more appropriate curriculum that gives high weights to samples with correct labels. As a result, it helps the StudentNet focus on samples with correct labels. The results indicate that the learned MentorNet can improve the generalization performance of recent deep CNNs and outperform the predefined curriculums (Self-paced and Focal Loss).
Fig. 4 illustrates the best learned data-driven curriculum in our experiments, where the $z$-axis denotes the weights computed by $g_m$, and the $x$ and $y$ axes denote the sample loss $\ell$ and the loss difference to the moving average $\ell - \ell_{pt}$, where $\ell_{pt}$ is the loss moving average. Two observations can be made from Fig. 4. First, the learned curriculum changes during the training of the StudentNet: Fig. 4 (a) and (b) are MentorNets learned at different epochs, and (a) assigns greater weights to samples of large loss more aggressively. Second, the learned curriculums in Fig. 4 generally satisfy the condition in Proposition 1, i.e., the weight generally decreases with the loss. This suggests that the joint learning of StudentNet and MentorNet optimizes an underlying robust objective.
Table 3 compares our method with recently published results under the setting of CIFAR with a 40% noise fraction. We cite the number from (Azadi et al., 2016) and implement the other methods using the same resnet-101 StudentNet. The results show that our result is comparable to, or even better than, the state of the art.
To verify MentorNet for large-scale training, we apply our method to the ImageNet ILSVRC12 (Deng et al., 2009) benchmark to improve the inception-resnet v2 (Szegedy et al., 2017) model. We train the model on ImageNet with 40% noise. Inspired by (Zhang et al., 2017a), we start with an inception-resnet with no regularization (NoReg) and then add weight decay, dropout, and data augmentation to the model. Table 4 shows the comparison. As shown, MentorNet improves the performance of both the inception-resnet without regularization (NoReg) and with full regularization (FullModel). It also outperforms the forgetting baseline (dropout keep probability = 0.2). The results suggest that MentorNet can improve deep CNNs in large-scale training on corrupted labels.
Experiments on real-world noisy labels
To verify MentorNet on real-world noisy labels, we conduct experiments on the large WebVision benchmark (Li et al., 2017a). It contains 2.4 million images with real-world noisy labels, crawled from the web using the 1,000 concepts in ImageNet ILSVRC12. We download the resized images from the official website (https://www.vision.ee.ethz.ch/webvision/download.html). The inception-resnet v2 (Szegedy et al., 2017) is used as our StudentNet, trained using a distributed asynchronous momentum optimizer on 50 GPUs. Since the dataset is very big, for quick experiments, we compare baseline methods using the Google image subset on the first 50 classes. We use Mini to denote this subset and Entire for the entire WebVision. All models are evaluated on the clean ILSVRC12 and WebVision validation sets.
Table 5 lists the comparison results. As we see, the proposed MentorNet significantly improves over the baseline methods on real-world noisy labels. The method marked by the star (*) uses a pre-trained ImageNet model to obtain an additional 30k labels for 118 classes. Following the same protocol, MentorNet* is trained using the additional labels. To the best of our knowledge, our method achieves the best-published result on the WebVision (Li et al., 2017a) benchmark.
Related Work
Curriculum learning (CL), proposed by Bengio et al. (2009), is a learning paradigm in which a model is learned by gradually including samples in training, from easy to complex, so as to increase the learning entropy (Bengio et al., 2009). From the human behavioral perspective, Khan et al. (2011) have shown that CL is consistent with the principle of human teaching. CL has been empirically verified in a variety of problems, such as computer vision (Supancic & Ramanan, 2013; Chen & Gupta, 2015), natural language processing (Turian et al., 2010), and multitask learning (Graves et al., 2017). A common CL approach is to predefine a curriculum. For example, Kumar et al. (2010) proposed a curriculum called self-paced learning, which favors training samples of smaller loss. After that, many predefined curriculums were proposed, e.g., in (Supancic & Ramanan, 2013; Jiang et al., 2014, 2015; Sangineto et al., 2016; Chang et al., 2017; Ma et al., 2017a, b). For example, Jiang et al. (2014) introduced a curriculum of using easy and diverse samples. Fan et al. (2017) proposed to use predefined sample weighting schemes as an implicit way to define a curriculum. Previous work has shown that predefined curriculums are useful in overcoming noisy labels (Chen & Gupta, 2015; Liang et al., 2016; Lin et al., 2017a). In parallel to CL, sample weighting schemes were also studied in (Lin et al., 2017a; Wang et al., 2017; Fan et al., 2018; Dehghani et al., 2018). Compared to existing work, our paper presents a new way of learning data-driven curriculums for deep networks trained on corrupted labels.
Our work is related to weakly-supervised learning methods. Among recent contributions, Reed et al. (2014) developed a robust loss to model “prediction consistency”. Menon et al. (2015) used class-probability estimation to study the corruption process. Sukhbaatar et al. (2014) proposed a noise transformation to estimate the noise distribution. The transformation matrix needs to be periodically updated and is non-trivial to learn. To address this issue, Goldberger et al. (2017) proposed to add an additional softmax layer trained end-to-end with the base model. Azadi et al. (2016) tackled this problem with a regularizer called AIR. This method was shown to be effective, but it relied on additional clean labels to train the representation. More recently, methods have utilized additional labels for label cleaning (Veit et al., 2017), knowledge distillation (Li et al., 2017b), or semi-supervised learning (Vahdat, 2017; Dehghani et al., 2017). Different from previous work, we focus on learning a curriculum to train very deep CNNs on corrupted labels from scratch. In addition, clean labels are not always needed for our method. In Section 5.1, the MentorNet is learned on a small subset of CIFAR-10 and applied to CIFAR-100.
Conclusions
In this paper, we presented a novel method for training deep CNNs on corrupted labels. Our work was built on curriculum learning and advanced the methodology by proposing to learn data-driven curriculum via a neural network called MentorNet. We proposed an algorithm for jointly optimizing deep CNNs with MentorNet on large-scale data. We conducted comprehensive experiments on datasets of controlled and real-world noise. Our empirical results showed that generalization performance of deep CNNs trained on corrupted labels can be effectively improved by the learned data-driven curriculum.
Acknowledgements
The authors would like to thank anonymous reviewers for helpful comments and Deyu Meng, Sergey Ioffe, and Chong Wang for meaningful discussions and kind support.