Semi-Supervised Learning with Context-Conditional Generative Adversarial Networks

Remi Denton, Sam Gross, Rob Fergus

Introduction

Deep neural networks have yielded dramatic performance gains in recent years on tasks such as object classification (Krizhevsky et al., 2012), text classification (Zhang et al., 2015) and machine translation (Sutskever et al., 2014; Bahdanau et al., 2015). These successes depend heavily on large training sets of manually annotated data. In many settings, however, such large collections of labels are not readily available, motivating methods that can learn from data in which labels are scarce.

We propose a method for harnessing unlabeled image data based on image in-painting. A generative model is trained to generate pixels within a missing hole, based on the context provided by the surrounding parts of the image. These in-painted images are then used in an adversarial setting (Goodfellow et al., 2014) to train a large discriminator model whose task is to determine whether an image is real (from the unlabeled training set) or fake (an in-painted image). The realistic-looking fake examples provided by the generative model cause the discriminator to learn features that generalize to the related task of classifying objects. Thus, adversarial training for the in-painting task can be used to regularize large discriminative models during supervised training on a handful of labeled images.

Other forms of spatial context within images have recently been utilized for representation learning. Doersch et al. (2015) propose training a CNN to predict the spatial location of one image patch relative to another. Noroozi & Favaro (2016) propose a model that learns by unscrambling image patches, essentially solving a jigsaw puzzle to learn visual representations. In the text domain, context has been successfully leveraged as an unsupervised criterion for training useful word and sentence level representations (Collobert et al., 2011; Mikolov et al., 2015; Kiros et al., 2015).

Deep unsupervised and semi-supervised learning: A popular method of utilizing unlabeled data is to train a deep autoencoder or restricted Boltzmann machine layer-wise (Hinton et al., 2006) and then fine-tune it with labels on a discriminative task. More recently, several autoencoding variants have been proposed for unsupervised and semi-supervised learning, such as the ladder network (Rasmus et al., 2015), stacked what-where autoencoders (Zhao et al., 2016) and variational autoencoders (Kingma & Welling, 2014; Kingma et al., 2014).

Dosovitskiy et al. (2014) achieved state-of-the-art results by training a CNN with a different class for each training example and introducing a set of transformations to provide multiple examples per class. The pseudo-label approach (Lee, 2013) is a simple semi-supervised method that, when labels are unavailable, trains using the class with the highest predicted probability as the label. Springenberg (2015) proposes a categorical generative adversarial network (CatGAN) which can be used for unsupervised and semi-supervised learning. The discriminator in a CatGAN outputs a distribution over classes and is trained to minimize the predicted entropy for real data and maximize the predicted entropy for fake data. Similar to our model, CatGANs use the feature space learned by the discriminator for the final supervised learning task. Salimans et al. (2016) recently proposed a semi-supervised GAN model in which the discriminator outputs a softmax over classes rather than a probability of real vs. fake. An additional ‘generated’ class is used as the target for generated samples. This method differs from our work in that it does not utilize context information and has only been applied to low-resolution datasets. However, its discriminator loss is similar to the one we propose and could be combined with our context-conditional approach.

More traditional semi-supervised methods include graph-based approaches (Zhou et al., 2004; Zhu, 2006) that show impressive performance when good image representations are available. However, the focus of our work is on learning such representations.

Generative models of images: Restricted Boltzmann machines (Salakhutdinov, 2015), denoising autoencoders (Vincent et al., 2008) and variational autoencoders (Kingma & Welling, 2014) optimize a maximum likelihood criterion and thus learn decoders that map from latent space to image space. More recently, generative adversarial networks (Goodfellow et al., 2014) and generative moment matching networks (Li et al., 2015; Dziugaite et al., 2015) have been proposed. These methods ignore data likelihoods and instead directly train a generative model to produce realistic samples. Several extensions to the generative adversarial network framework have been proposed to scale the approach to larger images (Denton et al., 2015; Radford et al., 2016; Salimans et al., 2016). Our work draws on the insights of Radford et al. (2016) regarding adversarial training practices and architecture for the generator network, as well as the notion that the discriminator can produce useful features for classification tasks.

Other models use recurrent approaches to generate images (Gregor et al., 2015; Theis & Bethge, 2015; Mansimov et al., 2016; van den Oord et al., 2016). Dosovitskiy et al. (2015) trained a CNN to generate objects with different shapes, viewpoints and colors. Sohl-Dickstein et al. (2015) propose a generative model based on a reverse diffusion process. While our model does involve image generation, it differs from these approaches in that its main focus is on learning a good representation for classification tasks.

Approach

We present a semi-supervised learning framework built on the generative adversarial networks (GANs) of Goodfellow et al. (2014). We first review the generative adversarial network framework and then introduce context-conditional generative adversarial networks (CC-GANs). Finally, we show how combining a classification objective and a CC-GAN objective provides a unified framework for semi-supervised learning.

The generative adversarial network approach (Goodfellow et al., 2014) is a framework for training generative models, which we briefly review. It consists of two networks pitted against one another in a two-player game: a generative model, $G$, is trained to synthesize images resembling the data distribution, and a discriminative model, $D$, is trained to distinguish between samples drawn from $G$ and images drawn from the training data.

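The conditional generative adversarial network (Mirza & Osindero, 2014) is an extension of the GAN in which both $D$ and $G$ receive an additional vector of information $\mathbf{y}$ as input. The conditional GAN objective is given by:

$$\min_G \max_D \; \mathbb{E}_{\mathbf{x},\mathbf{y}\sim\mathcal{X}}\left[\log D(\mathbf{x},\mathbf{y})\right] + \mathbb{E}_{\mathbf{z}\sim p(\mathbf{z}),\,\mathbf{y}\sim\mathcal{X}}\left[\log\left(1 - D(G(\mathbf{z},\mathbf{y}),\mathbf{y})\right)\right]$$

where $\mathcal{X}$ is the training data and $\mathbf{z}$ is a noise vector drawn from a prior $p(\mathbf{z})$.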

2 Context-Conditional Generative Adversarial Networks

We propose context-conditional generative adversarial networks (CC-GANs) which are conditional GANs where the generator is trained to fill in a missing image patch and the generator and discriminator are conditioned on the surrounding pixels.

In particular, the generator $G$ receives as input an image with a randomly masked-out patch. The generator outputs an entire image. We fill in the missing patch from the generated output and then pass the completed image into $D$. We pass the completed image into $D$, rather than the context and the patch as two separate inputs, so as to prevent $D$ from simply learning to identify discontinuities along the edge of the missing patch.
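The fill-in step can be sketched in a few lines of NumPy. This is a minimal illustration, not the released implementation: it assumes images stored as (C, H, W) arrays and a binary mask that is 0 inside the hole and 1 elsewhere, and `x_g` is a random stand-in for the generator output. The helper name `compose_inpainting` is hypothetical.

```python
import numpy as np

def compose_inpainting(x, x_g, mask):
    """In-painted image x_I: real pixels where mask == 1, generated pixels where mask == 0."""
    return (1.0 - mask) * x_g + mask * x

rng = np.random.default_rng(0)
x = rng.random((3, 64, 64))             # a real 64x64 RGB crop, layout (C, H, W)
mask = np.ones((1, 64, 64))             # single channel, broadcasts over colors
mask[:, 16:48, 16:48] = 0.0             # 32x32 hole (fixed position for illustration)
generator_input = mask * x              # context only: the hole is zeroed out
x_g = rng.random((3, 64, 64))           # hypothetical stand-in for G(generator_input)
x_i = compose_inpainting(x, x_g, mask)  # the single image that is passed to D
```

Because the composite is a single image, $D$ never receives the context and the patch as separate inputs, which is the design choice motivated above.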

3 Combined GAN and CC-GAN

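While the generator of the CC-GAN outputs a full image, only a portion of it (corresponding to the missing hole) is seen by the discriminator. In the combined model, which we denote by CC-GAN2, the fake examples for the discriminator include both the in-painted image $\mathbf{x}_I$ and the full image $\mathbf{x}_G$ produced by the generator (i.e., not just the missing patch). By combining the GAN and CC-GAN approaches, we introduce a wider array of negative examples to the discriminator. The CC-GAN2 objective is given by:

$$\min_G \max_D \; \mathbb{E}_{\mathbf{x}\sim\mathcal{X}}\left[\log D(\mathbf{x})\right] + \mathbb{E}_{\mathbf{x}\sim\mathcal{X},\,\mathbf{m}\sim\mathcal{M}}\left[\log\left(1 - D(\mathbf{x}_I)\right) + \log\left(1 - D(\mathbf{x}_G)\right)\right]$$

where $\mathcal{X}$ is the unlabeled training set and $\mathbf{m}\sim\mathcal{M}$ is the random binary mask that removes the patch to be in-painted.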
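To make the CC-GAN2 objective concrete from the discriminator's side, the sketch below writes it as a loss to minimize (the negated objective, averaged over a batch). It assumes $D$ outputs a probability that its input is real; the function name and the small `eps` guard are illustrative choices, and a framework implementation would normally use its built-in binary cross-entropy instead.

```python
import math

def cc_gan2_discriminator_loss(d_real, d_inpainted, d_generated, eps=1e-12):
    """Negated CC-GAN2 objective for D, averaged over a batch.

    d_real, d_inpainted, d_generated: D's probabilities that x, x_I and x_G are real.
    """
    n = len(d_real)
    total = 0.0
    for p_real, p_inp, p_gen in zip(d_real, d_inpainted, d_generated):
        total -= math.log(p_real + eps)       # real images should score close to 1
        total -= math.log(1.0 - p_inp + eps)  # in-painted images are negatives
        total -= math.log(1.0 - p_gen + eps)  # full generator outputs are negatives too
    return total / n
```

Dropping the `d_generated` term recovers the plain CC-GAN discriminator loss.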

4 Semi-supervised learning with CC-GANs

A common approach to semi-supervised learning is to combine a supervised and an unsupervised objective during training. In this way, unlabeled data can be leveraged to aid the supervised task.

Intuitively, a GAN discriminator must learn something about the structure of natural images in order to effectively distinguish real from generated images. Recently, Radford et al. (2016) showed that a GAN discriminator learns a hierarchical image representation that is useful for object classification. Such results suggest that combining an unsupervised GAN objective with a supervised classification objective would produce a simple and effective semi-supervised learning method. This approach, denoted by SSL-GAN, is illustrated in Fig. 1(b). The discriminator network receives a gradient from the real/fake loss for every real and generated image. The discriminator also receives a gradient from the classification loss on the subset of (real) images for which labels are available.

Generative adversarial networks have shown impressive performance on many diverse datasets. However, samples are most coherent when the set of images the network is trained on comes from a limited domain (e.g., churches or faces). Additionally, it is difficult to train GANs on very large images. Both of these issues suggest that semi-supervised learning with vanilla GANs may not scale well to datasets of large, diverse images. Rather than determining whether a full image is real or fake, context-conditional GANs address a different task: determining whether a part of an image is real or fake given the surrounding context.

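Formally, let $\mathcal{X}_L = \{(\mathbf{x}^1, y^1), \ldots, (\mathbf{x}^n, y^n)\}$ denote a dataset of labeled images. Let $D_c(\mathbf{x})$ denote the output of the classifier head on the discriminator (see Fig. 1(c) for details), with $D_c(y \mid \mathbf{x})$ the probability it assigns to class $y$. Then the semi-supervised CC-GAN objective is:

$$\min_G \max_D \; \mathbb{E}_{\mathbf{x}\sim\mathcal{X}}\left[\log D(\mathbf{x})\right] + \mathbb{E}_{\mathbf{x}\sim\mathcal{X},\,\mathbf{m}\sim\mathcal{M}}\left[\log\left(1 - D(\mathbf{x}_I)\right)\right] + \lambda_c\, \mathbb{E}_{(\mathbf{x},y)\sim\mathcal{X}_L}\left[\log D_c(y \mid \mathbf{x})\right]$$

where $\mathcal{X}$ is the unlabeled training set, $\mathbf{m}\sim\mathcal{M}$ is the random mask defining the hole, and $\mathbf{x}_I$ is the resulting in-painted image.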

The hyperparameter $\lambda_c$ balances the classification and adversarial losses. We only consider the CC-GAN in the semi-supervised setting and thus drop the SSL notation when referring to this model.
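The classification term rewards the classifier head for assigning high probability to the correct label, which for the discriminator is the same as minimizing softmax cross-entropy on the labeled images. The sketch below combines that with the adversarial terms into the quantity $D$ minimizes. It assumes the classifier head produces unnormalized logits that a softmax turns into class probabilities (the text only says it is a classifier head); `lambda_c` defaults to 1, the value used in all experiments, and the function names are hypothetical.

```python
import math

def log_softmax(logits):
    """Numerically stable log-softmax over a list of class logits."""
    top = max(logits)
    log_z = top + math.log(sum(math.exp(v - top) for v in logits))
    return [v - log_z for v in logits]

def ssl_cc_gan_d_loss(d_real, d_inpainted, class_logits, labels,
                      lambda_c=1.0, eps=1e-12):
    """Quantity D minimizes: the negated semi-supervised CC-GAN objective (batch means).

    d_real, d_inpainted: real/fake-head probabilities for real and in-painted images.
    class_logits, labels: classifier-head logits and integer labels for the
    labeled subset only; unlabeled images contribute only adversarial terms.
    """
    loss = -sum(math.log(p + eps) for p in d_real) / len(d_real)
    loss -= sum(math.log(1.0 - p + eps) for p in d_inpainted) / len(d_inpainted)
    if labels:
        # -log D_c(y | x), i.e. softmax cross-entropy, averaged over labeled images
        ce = -sum(log_softmax(z)[y] for z, y in zip(class_logits, labels)) / len(labels)
        loss += lambda_c * ce
    return loss
```

The SSL-GAN baseline of Fig. 1(b) would use the same structure with generated images in place of in-painted ones.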

5 Model Architecture and Training Details

The architecture of our generative model, $G$, is inspired by the generator architecture of the DCGAN (Radford et al., 2016). The model consists of a sequence of convolutional layers with subsampling (but no pooling) followed by a sequence of fractionally-strided convolutional layers. For the discriminator, $D$, we used the VGG-A network (Simonyan & Zisserman, 2015) without the fully connected layers (which we call the VGG-A’ architecture). Details of the generator and discriminator are given in Fig. 2. The input to the generator is an image with a patch zeroed out. In preliminary experiments we also tried passing a separate mask to the generator to make the missing area more explicit, but found that this did not affect performance.
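The encoder-decoder shape of such a generator can be checked with the standard output-size formulas for strided and fractionally-strided (transposed) convolutions. The kernel size 4, stride 2, padding 1 and the four-layer depth below are assumptions in the spirit of DCGAN, not the exact configuration given in Fig. 2.

```python
def conv_out(size, kernel=4, stride=2, padding=1):
    """Spatial size after a strided convolution (floor division)."""
    return (size + 2 * padding - kernel) // stride + 1

def conv_transpose_out(size, kernel=4, stride=2, padding=1):
    """Spatial size after a fractionally-strided (transposed) convolution."""
    return (size - 1) * stride - 2 * padding + kernel

sizes = [64]
for _ in range(4):                 # encoder: strided convolutions, no pooling
    sizes.append(conv_out(sizes[-1]))
for _ in range(4):                 # decoder: fractionally-strided convolutions
    sizes.append(conv_transpose_out(sizes[-1]))
print(sizes)  # [64, 32, 16, 8, 4, 8, 16, 32, 64]
```

With these settings each layer halves or doubles the spatial size, so the generator's output matches its 64×64 input and can be composited with the context pixel by pixel.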

Even with the context conditioning, it is difficult to generate large image patches that look realistic, which makes it problematic to scale our approach to high-resolution images. To address this, we propose conditioning the generator on both the high-resolution image with a missing patch and a low-resolution version of the whole image (with no missing region). In this setting, the generator’s task becomes one of super-resolution on a portion of an image. However, the discriminator does not receive the low-resolution image and thus still faces the same problem of determining whether a given in-painting is plausible. Where indicated, we used this approach in our PASCAL VOC 2007 experiments, with the original image downsampled by a factor of 4. This provided enough information for the generator to fill in larger holes, but not so much that it made the task trivial. This optional low-resolution image is illustrated in Fig. 2 (left) with the dotted line.
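Preparing the two generator inputs for this low-resolution variant might look like the sketch below. The factor of 4 and the 128×128 image, 64×64 hole and 32×32 low-resolution sizes follow the PASCAL setting reported in this paper; average pooling as the downsampling method and the centered hole are assumptions, and how the network fuses the two inputs (for example, by upsampling and concatenating channels) is left open.

```python
import numpy as np

def downsample(x, factor=4):
    """Average-pool a (C, H, W) image by an integer factor (H and W divisible by it)."""
    c, h, w = x.shape
    return x.reshape(c, h // factor, factor, w // factor, factor).mean(axis=(2, 4))

rng = np.random.default_rng(0)
x = rng.random((3, 128, 128))           # full-resolution 128x128 image
mask = np.ones((1, 128, 128))
mask[:, 32:96, 32:96] = 0.0             # 64x64 hole, centered here for clarity

context = mask * x                      # high-resolution input with the hole zeroed
low_res = downsample(x, factor=4)       # 32x32 version of the whole image, no hole
# Both arrays go to the generator; the discriminator receives neither low_res nor mask.
```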

We followed the training procedures of Radford et al. (2016). We used the Adam optimizer (Kingma & Ba, 2015) in all our experiments with a learning rate of 0.0002, momentum term $\beta_1$ of 0.5, and the remaining Adam hyperparameters set to their default values. We set $\lambda_c = 1$ for all experiments.
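For reference, one Adam update with these settings ($\beta_1 = 0.5$, plus the defaults $\beta_2 = 0.999$ and $\epsilon = 10^{-8}$ from Kingma & Ba) is sketched below for a single scalar parameter. `adam_step` is a hypothetical helper written out only to show where $\beta_1$ enters; it is not the training code used for the experiments.

```python
import math

def adam_step(param, grad, state, lr=0.0002, beta1=0.5, beta2=0.999, eps=1e-8):
    """One Adam update for a scalar parameter; `state` holds t, m and v."""
    state["t"] += 1
    state["m"] = beta1 * state["m"] + (1 - beta1) * grad         # first moment
    state["v"] = beta2 * state["v"] + (1 - beta2) * grad * grad  # second moment
    m_hat = state["m"] / (1 - beta1 ** state["t"])               # bias corrections
    v_hat = state["v"] / (1 - beta2 ** state["t"])
    return param - lr * m_hat / (math.sqrt(v_hat) + eps)

state = {"t": 0, "m": 0.0, "v": 0.0}
w = adam_step(1.0, grad=3.0, state=state)  # first step moves w by about lr
```

In PyTorch, for example, the same settings would be `torch.optim.Adam(params, lr=0.0002, betas=(0.5, 0.999))`. The lower $\beta_1$ makes the first-moment estimate average over a shorter window of recent gradients than the usual 0.9.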

Experiments

1 STL-10 classification

STL-10 is a dataset of 96×96 color images with a 1:100 ratio of labeled to unlabeled examples (1,000 labeled training images per fold versus 100,000 unlabeled images), making it an ideal fit for our semi-supervised learning framework. The training set consists of 5,000 labeled images, mapped to 10 predefined folds of 1,000 images each, and 100,000 unlabeled images. The labeled images belong to 10 classes and were extracted from the ImageNet dataset, while the unlabeled images come from a broader distribution of classes. We follow the standard testing protocol and train 10 different models, one on each of the 10 predefined folds. We then evaluate the classification accuracy of each model on the test set and report the mean and standard deviation.

We trained our CC-GAN and CC-GAN2 models on 64×64 crops of the 96×96 images. The hole was 32×32 pixels and its location varied randomly (see Fig. 3 (top)). We trained for 100 epochs and then fine-tuned the discriminator on the 96×96 labeled images, stopping when training accuracy reached 100%. As shown in Table 1, the CC-GAN model performs comparably to the current state of the art (Dosovitskiy et al., 2014) and the CC-GAN2 model improves upon it.
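A sketch of how one STL-10 training example might be sampled under this setup: a random 64×64 crop of the 96×96 image and a 32×32 hole placed at random inside it. Uniform sampling of both positions is an assumption (the text only says the hole location varied randomly), and `sample_training_example` is a hypothetical name.

```python
import numpy as np

def sample_training_example(image, rng, crop=64, hole=32):
    """Random crop x crop window of a (C, H, W) image plus a mask with one random hole.

    The mask is 0 inside the hole x hole square and 1 elsewhere.
    """
    _, h, w = image.shape
    top = rng.integers(0, h - crop + 1)          # crop position, uniform (assumed)
    left = rng.integers(0, w - crop + 1)
    patch = image[:, top:top + crop, left:left + crop]
    hole_top = rng.integers(0, crop - hole + 1)  # hole position inside the crop
    hole_left = rng.integers(0, crop - hole + 1)
    mask = np.ones((1, crop, crop))
    mask[:, hole_top:hole_top + hole, hole_left:hole_left + hole] = 0.0
    return patch, mask

rng = np.random.default_rng(0)
image = rng.random((3, 96, 96))                  # stand-in for one STL-10 image
patch, mask = sample_training_example(image, rng)
generator_input = mask * patch
```

During fine-tuning, the discriminator instead sees the full 96×96 labeled images, as described above.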

We also trained two baseline models in an attempt to tease apart the contributions of adversarial training and context conditional adversarial training. The first is a purely supervised training of the VGG-A’ model (the same architecture as the discriminator in the CC-GAN framework). This was trained using a dropout of 0.5 on the final layer and weight decay of 0.001. The performance of this model is significantly worse than the CC-GAN model.

We also trained a semi-supervised GAN (SSL-GAN, see Fig. 1(b)) on STL-10. This consisted of the same discriminator as the CC-GAN (VGG-A’ architecture) and generator from the DCGAN model (Radford et al., 2016). The training setup in this case is identical to the CC-GAN model. The SSL-GAN performs almost as well as the CC-GAN, confirming our hypothesis that the GAN objective is a useful unsupervised criterion.

2 PASCAL VOC classification

In order to compare against other methods that utilize spatial context, we ran the CC-GAN model on the PASCAL VOC 2007 dataset. This dataset consists of natural images from 20 classes. It contains a large amount of variability, with objects varying in size, pose, and position. The training and validation sets combined contain 5,011 images, and the test set contains 4,952 images. The standard measure of performance is mean average precision (mAP).
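Since results are reported as mAP, here is a sketch of per-class average precision and its mean. It uses the 11-point interpolated AP associated with the VOC 2007 development kit and, as a simplification, ignores the "difficult" annotations that the official evaluation treats specially. The function names are hypothetical, and each class is assumed to have at least one positive image.

```python
def voc07_average_precision(scores, labels):
    """11-point interpolated AP for one class.

    scores: classifier confidences; labels: 1 for positive, 0 for negative images.
    """
    order = sorted(range(len(scores)), key=lambda i: scores[i], reverse=True)
    num_pos = sum(labels)
    tp = 0
    precisions, recalls = [], []
    for rank, i in enumerate(order, start=1):
        tp += labels[i]
        precisions.append(tp / rank)
        recalls.append(tp / num_pos)
    ap = 0.0
    for t in [k / 10 for k in range(11)]:  # recall thresholds 0.0, 0.1, ..., 1.0
        # interpolated precision: best precision at any recall >= t
        best = max((p for p, r in zip(precisions, recalls) if r >= t), default=0.0)
        ap += best / 11
    return ap

def mean_average_precision(per_class):
    """mAP: mean of per-class APs; per_class is a list of (scores, labels) pairs."""
    aps = [voc07_average_precision(s, y) for s, y in per_class]
    return sum(aps) / len(aps)
```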

We trained each model on the combined training and validation set for ~5,000 epochs and evaluated on the test set once (hyperparameters were determined by initially training on the training set alone and measuring performance on the validation set). Following Pathak et al. (2016), we train using random cropping, and then evaluate using the average prediction from 10 random crops.
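The test-time protocol of averaging predictions over 10 random crops can be sketched as follows. The crop size of 128 matches the resolution of the best model but is an assumption here, as is the `model` callable, which stands in for the trained classifier.

```python
import numpy as np

def predict_with_random_crops(image, model, rng, crop=128, num_crops=10):
    """Average the class scores of `model` over `num_crops` random crop x crop windows.

    `model` is a hypothetical callable mapping a (C, crop, crop) array to a score vector.
    """
    _, h, w = image.shape
    scores = []
    for _ in range(num_crops):
        top = rng.integers(0, h - crop + 1)
        left = rng.integers(0, w - crop + 1)
        scores.append(model(image[:, top:top + crop, left:left + crop]))
    return np.mean(scores, axis=0)           # average prediction over the crops
```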

Our best performing model was trained on images of resolution 128×128 with a hole size of 64×64 and a low-resolution input of size 32×32. Table 2 compares our CC-GAN method to other feature learning approaches on the PASCAL test set. It outperforms them, beating the current state of the art (Wang & Gupta, 2015) by 3.8%. It is important to note that our feature extractor is the VGG-A’ model, which is larger than the AlexNet architecture (Krizhevsky et al., 2012) used by the other approaches in Table 2. However, purely supervised training of the two models reveals that VGG-A’ is less than 2% better than AlexNet. Furthermore, our model outperforms the supervised VGG-A’ baseline by 7 percentage points (62.2% vs. 55.2%). This suggests that our gains stem from the CC-GAN method rather than from the use of a better architecture.

Table 3 shows the effect of training on different resolutions. The CC-GAN consistently improves over the baseline CNN regardless of image size. We found that conditioning on the low-resolution image began to help when the hole size was largest (64×64). We hypothesize that the low-resolution conditioning would be more important for larger images, potentially allowing the method to scale to larger image sizes than we explored in this work.

3 Inpainting

We now show some sample in-paintings produced by our CC-GAN generators. In our semi-supervised learning experiments on STL-10 we remove a single fixed-size hole from the image. The top row of Fig. 3 shows in-paintings produced by this model. We also explored different masking schemes, as illustrated in the remaining rows of Fig. 3 (although these did not improve classification results). In all cases we see that training the generator with the adversarial loss produces sharp, semantically plausible in-painting results.

Fig. 4 shows generated images and in-painted images from a model trained with the CC-GAN2 criterion. The output of a CC-GAN generator tends to be corrupted outside the patch used to in-paint the image (since gradients only flow back to the missing patch). However, in the CC-GAN2 model, we see that both the in-painted image and the generated image are coherent and semantically consistent with the masked input image.

Fig. 5 shows in-painted images from a generator trained on 128×128 PASCAL images. Fig. 6 shows the effect of adding a low-resolution (32×32) image as input to the generator. For comparison, we also show the result of in-painting by filling in with a bilinearly upsampled image. Here we see that the generator produces high-frequency structure rather than simply learning to copy the low-resolution patch.

Discussion

We have presented a simple semi-supervised learning framework based on in-painting with an adversarial loss. The generator in our CC-GAN model is capable of producing semantically meaningful in-paintings, and the discriminator performs comparably to or better than existing semi-supervised methods on two classification benchmarks.

Since discriminating real from fake in-paintings is more closely related to the target task of object classification than extracting a feature representation suitable for in-filling, it is not surprising that we are able to exceed the performance of Pathak et al. (2016) on PASCAL classification. Furthermore, since our model operates on images at about half the resolution of those used by other approaches (128×128 vs. 224×224), there is potential for further gains if the resolution of the generator can be improved. Our models and code are available at https://github.com/edenton/cc-gan.

Acknowledgements: Emily Denton is supported by a Google Fellowship. Rob Fergus is grateful for the support of CIFAR.

References