Learning Disentangled Joint Continuous and Discrete Representations
Emilien Dupont
Introduction
Disentangled representations are defined as ones where a change in a single unit of the representation corresponds to a change in a single factor of variation of the data while being invariant to others (Bengio et al. (2013)). For example, a disentangled representation of 3D objects could contain a set of units each corresponding to a distinct generative factor such as position, color or scale. Most recent work on learning disentangled representations has focused on modeling continuous factors of variation (Higgins et al. (2016); Kim & Mnih (2018); Chen et al. (2018)). However, a large number of datasets contain inherently discrete generative factors which can be difficult to capture with these methods. In image data for example, distinct objects or entities would most naturally be represented by discrete variables, while their position or scale might be represented by continuous variables.
Several machine learning tasks, including transfer learning and zero-shot learning, can benefit from disentangled representations (Lake et al. (2017)). Disentangled representations have also been applied to reinforcement learning (Higgins et al. (2017a)) and to learning visual concepts (Higgins et al. (2017b)). Further, in contrast to most representation learning algorithms, disentangled representations are often interpretable since they align with factors of variation of the data. Different approaches have been explored for semi-supervised or supervised learning of factored representations (Kulkarni et al. (2015); Whitney et al. (2016); Yang et al. (2015); Reed et al. (2014)). These approaches achieve impressive results but either require knowledge of the underlying generative factors or other forms of weak supervision. Several methods also exist for unsupervised disentanglement with the two most prominent being InfoGAN and β-VAE (Chen et al. (2016); Higgins et al. (2016)). These frameworks have shown promise in disentangling factors of variation in an unsupervised manner on a number of datasets.
InfoGAN (Chen et al. (2016)) is a framework based on Generative Adversarial Networks (Goodfellow et al. (2014)) which disentangles generative factors by maximizing the mutual information between a subset of latent variables and the generated samples. While this approach is able to model both discrete and continuous factors, it suffers from some of the shortcomings of Generative Adversarial Networks (GAN), such as unstable training and reduced sample diversity. Recent improvements in the training of GANs (Arjovsky et al. (2017); Gulrajani et al. (2017)) have mitigated some of these issues, but stable GAN training remains a challenge (and this is particularly challenging for InfoGAN as shown in Kim & Mnih (2018)). β-VAE (Higgins et al. (2016)), in contrast, is based on Variational Autoencoders (Kingma & Welling (2013); Rezende et al. (2014)) and is stable to train. β-VAE, however, can only model continuous latent variables.
In this paper we propose a framework, based on Variational Autoencoders (VAE), that learns disentangled continuous and discrete representations in an unsupervised manner. It comes with the advantages of VAEs, such as stable training, large sample diversity and a principled inference network, while having the flexibility to model a combination of continuous and discrete generative factors. We show how our framework, which we term JointVAE, discovers independent factors of variation on MNIST, FashionMNIST (Xiao et al. (2017)), CelebA (Liu et al. (2015)) and Chairs (Aubry et al. (2014)). For example, on MNIST, JointVAE disentangles digit type (discrete) from slant, width and stroke thickness (continuous). In addition, the model’s learned inference network can infer various properties of data, such as the azimuth of a chair, in an unsupervised manner. The model can also be used for simple image editing, such as rotating a face in an image.
Analysis of β-VAE
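We derive our approach by modifying the β-VAE framework and augmenting it with a joint latent distribution. β-VAEs model a joint distribution of the data $\mathbf{x}$ and a set of latent variables $\mathbf{z}$ and learn continuous disentangled representations by maximizing the objective
$$\mathcal{L}(\theta, \phi) = \mathbb{E}_{q_\phi(\mathbf{z}|\mathbf{x})}\left[\log p_\theta(\mathbf{x}|\mathbf{z})\right] - \beta\, D_{KL}\left(q_\phi(\mathbf{z}|\mathbf{x}) \,\|\, p(\mathbf{z})\right) \tag{1}$$
where the encoder $q_\phi(\mathbf{z}|\mathbf{x})$ and the decoder $p_\theta(\mathbf{x}|\mathbf{z})$ are neural networks with parameters $\phi$ and $\theta$, $p(\mathbf{z})$ is the prior and $\beta$ is a positive constant weighting the KL divergence term.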
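We can derive further insights by analyzing the role of the KL divergence term in the objective (1). During training, the objective will be optimized in expectation over the data $\mathbf{x}$. Writing $q(\mathbf{z}) = \mathbb{E}_{p(\mathbf{x})}[q_\phi(\mathbf{z}|\mathbf{x})]$ for the aggregate posterior, the KL term then becomes (Makhzani & Frey (2017); Kim & Mnih (2018))
$$\mathbb{E}_{p(\mathbf{x})}\left[D_{KL}\left(q_\phi(\mathbf{z}|\mathbf{x}) \,\|\, p(\mathbf{z})\right)\right] = I(\mathbf{x}; \mathbf{z}) + D_{KL}\left(q(\mathbf{z}) \,\|\, p(\mathbf{z})\right) \geq I(\mathbf{x}; \mathbf{z}) \tag{2}$$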
i.e., when taken in expectation over the data, the KL divergence term is an upper bound on the mutual information between the latents and the data (see appendix for proof and details). Thus, a mini batch estimate of the mean KL divergence is an estimate of the upper bound on the information $\mathbf{z}$ can transmit about $\mathbf{x}$.
Penalizing the mutual information term improves disentanglement but comes at the cost of increased reconstruction error. Recently, several methods have been explored to improve the reconstruction quality without decreasing disentanglement (Burgess et al. (2017); Kim & Mnih (2018); Chen et al. (2018); Gao et al. (2018)). Burgess et al. (2017) in particular propose an objective where the upper bound on the mutual information is controlled and gradually increased during training. Denoting the controlled information capacity by $C$, the objective is defined as
$$\mathcal{L}(\theta, \phi) = \mathbb{E}_{q_\phi(\mathbf{z}|\mathbf{x})}\left[\log p_\theta(\mathbf{x}|\mathbf{z})\right] - \gamma \left| D_{KL}\left(q_\phi(\mathbf{z}|\mathbf{x}) \,\|\, p(\mathbf{z})\right) - C \right| \tag{3}$$
where $\gamma$ is a constant which forces the KL divergence term to match the capacity $C$. Gradually increasing $C$ during training allows for control of the amount of information the model can encode. This objective has been shown to improve reconstruction quality as compared to (1) without reducing disentanglement (Burgess et al. (2017)).
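The capacity-controlled objective (3) can be sketched in a few lines of plain Python. The function names `linear_capacity` and `capacity_objective` are illustrative and not taken from the paper, and the linear schedule uses the form reported in the training details (for example from 0 to 5 over 25000 iterations). The value of `gamma` in the usage line is an arbitrary placeholder.

```python
def linear_capacity(step, c_max, num_steps, c_min=0.0):
    """Capacity C at a training step: increased linearly, then held at c_max."""
    if step >= num_steps:
        return c_max
    return c_min + (c_max - c_min) * step / num_steps


def capacity_objective(log_likelihood, kl, capacity, gamma):
    """Objective (3), to be maximized: reconstruction term minus gamma * |KL - C|."""
    return log_likelihood - gamma * abs(kl - capacity)


# Halfway through a 0 -> 5 schedule over 25000 iterations the capacity is 2.5.
c = linear_capacity(step=12500, c_max=5.0, num_steps=25000)
value = capacity_objective(log_likelihood=-100.0, kl=3.0, capacity=c, gamma=30.0)
```

The penalty is zero only when the KL term equals the current capacity, so raising `capacity` over time lets the model encode progressively more information.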
JointVAE Model
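We propose a modification to the β-VAE framework which allows us to model a joint distribution of continuous and discrete latent variables. Letting $\mathbf{z}$ denote a set of continuous latent variables and $\mathbf{c}$ denote a set of categorical or discrete latent variables, we define a joint posterior $q_\phi(\mathbf{z}, \mathbf{c}|\mathbf{x})$, prior $p(\mathbf{z}, \mathbf{c})$ and likelihood $p_\theta(\mathbf{x}|\mathbf{z}, \mathbf{c})$. The β-VAE objective then becomes
$$\mathcal{L}(\theta, \phi) = \mathbb{E}_{q_\phi(\mathbf{z}, \mathbf{c}|\mathbf{x})}\left[\log p_\theta(\mathbf{x}|\mathbf{z}, \mathbf{c})\right] - \beta\, D_{KL}\left(q_\phi(\mathbf{z}, \mathbf{c}|\mathbf{x}) \,\|\, p(\mathbf{z}, \mathbf{c})\right) \tag{4}$$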
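where the latent distribution is now jointly continuous and discrete. Assuming the continuous and discrete latent variables are conditionally independent (β-VAE assumes the data is generated by a fixed number of independent factors of variation, so all latent variables are in fact conditionally independent; for the sake of deriving the JointVAE objective, however, we only require conditional independence between the continuous and discrete latents), i.e. $q_\phi(\mathbf{z}, \mathbf{c}|\mathbf{x}) = q_\phi(\mathbf{z}|\mathbf{x})\, q_\phi(\mathbf{c}|\mathbf{x})$, and similarly for the prior, $p(\mathbf{z}, \mathbf{c}) = p(\mathbf{z})\, p(\mathbf{c})$, we can rewrite the KL divergence as
$$D_{KL}\left(q_\phi(\mathbf{z}, \mathbf{c}|\mathbf{x}) \,\|\, p(\mathbf{z}, \mathbf{c})\right) = D_{KL}\left(q_\phi(\mathbf{z}|\mathbf{x}) \,\|\, p(\mathbf{z})\right) + D_{KL}\left(q_\phi(\mathbf{c}|\mathbf{x}) \,\|\, p(\mathbf{c})\right) \tag{5}$$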
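i.e. we can separate the discrete and continuous KL divergence terms (see appendix for proof). Under this assumption, the loss becomes
$$\mathcal{L}(\theta, \phi) = \mathbb{E}_{q_\phi(\mathbf{z}, \mathbf{c}|\mathbf{x})}\left[\log p_\theta(\mathbf{x}|\mathbf{z}, \mathbf{c})\right] - \beta\, D_{KL}\left(q_\phi(\mathbf{z}|\mathbf{x}) \,\|\, p(\mathbf{z})\right) - \beta\, D_{KL}\left(q_\phi(\mathbf{c}|\mathbf{x}) \,\|\, p(\mathbf{c})\right) \tag{6}$$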
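In our initial experiments, we found that directly optimizing this loss led to the model ignoring the discrete latent variables. Similarly, gradually increasing the channel capacity as in equation (3) leads to the model assigning all capacity to the continuous channels. To overcome this, we split the capacity increase: the capacities of the discrete and continuous latent channels are controlled separately, forcing the model to encode information in both the discrete and the continuous channels. The final loss is then given by
$$\mathcal{L}(\theta, \phi) = \mathbb{E}_{q_\phi(\mathbf{z}, \mathbf{c}|\mathbf{x})}\left[\log p_\theta(\mathbf{x}|\mathbf{z}, \mathbf{c})\right] - \gamma \left| D_{KL}\left(q_\phi(\mathbf{z}|\mathbf{x}) \,\|\, p(\mathbf{z})\right) - C_z \right| - \gamma \left| D_{KL}\left(q_\phi(\mathbf{c}|\mathbf{x}) \,\|\, p(\mathbf{c})\right) - C_c \right| \tag{7}$$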
where $C_z$ and $C_c$ are gradually increased during training.
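As a minimal sketch of the split-capacity loss, the function below evaluates the final JointVAE objective from scalar estimates of its three terms. The name `joint_vae_objective` and the numbers in the usage line are hypothetical; in practice the inputs would be mini batch estimates produced by the encoder and decoder.

```python
def joint_vae_objective(log_likelihood, kl_cont, kl_disc, cap_z, cap_c, gamma):
    """Split-capacity objective (to be maximized).

    kl_cont and kl_disc are the continuous and discrete KL terms;
    cap_z and cap_c are the current capacities C_z and C_c.
    """
    cont_penalty = gamma * abs(kl_cont - cap_z)
    disc_penalty = gamma * abs(kl_disc - cap_c)
    return log_likelihood - cont_penalty - disc_penalty


# The discrete KL matches its capacity, the continuous KL is 1 nat short.
value = joint_vae_objective(-80.0, kl_cont=4.0, kl_disc=2.0,
                            cap_z=5.0, cap_c=2.0, gamma=30.0)
```

Because each channel has its own penalty, the continuous channel cannot absorb the capacity that is reserved for the discrete channel, which is the failure mode described above for a single shared capacity.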
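3.1 Parametrization of continuous latent variables
As in the original VAE framework, we parametrize $q_\phi(\mathbf{z}|\mathbf{x})$ by a factorised Gaussian, i.e. $q_\phi(\mathbf{z}|\mathbf{x}) = \prod_i q_\phi(z_i|\mathbf{x})$ where $q_\phi(z_i|\mathbf{x}) = \mathcal{N}(\mu_i, \sigma_i^2)$, and let the prior be a unit Gaussian $p(\mathbf{z}) = \mathcal{N}(0, I)$. $\boldsymbol{\mu}$ and $\boldsymbol{\sigma}^2$ are both parametrized by neural networks.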
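For a factorised Gaussian posterior and a unit Gaussian prior, both the reparametrized sample and the KL divergence have simple closed forms. The sketch below assumes the encoder outputs the log-variance $\log \sigma_i^2$ for each unit, a common convention that is not stated in the text. The helper names are illustrative.

```python
import math
import random


def sample_gaussian(mu, logvar, rng):
    """Reparametrized sample z_i = mu_i + sigma_i * eps_i with eps_i ~ N(0, 1)."""
    return [m + math.exp(0.5 * lv) * rng.gauss(0.0, 1.0)
            for m, lv in zip(mu, logvar)]


def gaussian_kl(mu, logvar):
    """Per-unit KL( N(mu_i, sigma_i^2) || N(0, 1) ) in nats."""
    return [0.5 * (m * m + math.exp(lv) - lv - 1.0)
            for m, lv in zip(mu, logvar)]


rng = random.Random(0)
z = sample_gaussian([0.0, 1.0], [0.0, 0.0], rng)
kl = gaussian_kl([0.0, 1.0], [0.0, 0.0])  # first unit equals the prior
```

Summing the per-unit values gives the continuous KL term of the loss. Keeping them separate is what later allows individual latent units to be inspected.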
3.2 Parametrization of discrete latent variables
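Parametrizing $q_\phi(\mathbf{c}|\mathbf{x})$ is more difficult. Since $q_\phi(\mathbf{c}|\mathbf{x})$ needs to be differentiable with respect to its parameters, we cannot parametrize $q_\phi(\mathbf{c}|\mathbf{x})$ by a set of categorical distributions. Recently, Maddison et al. (2016) and Jang et al. (2016) proposed a differentiable relaxation of discrete random variables based on the Gumbel Max trick (Gumbel (1954)). If $c$ is a categorical variable with class probabilities $\alpha_1, \alpha_2, \dots, \alpha_n$, then we can sample from a continuous approximation of the categorical distribution by sampling a set of $g_k \sim \text{Gumbel}(0, 1)$ i.i.d. and applying the following transformation
$$y_k = \frac{\exp\left((\log \alpha_k + g_k)/\tau\right)}{\sum_{i=1}^{n} \exp\left((\log \alpha_i + g_i)/\tau\right)} \tag{8}$$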
where $\tau$ is a temperature parameter which controls the relaxation. The sample $\mathbf{y}$ is a continuous approximation of the one hot representation of $c$. The relaxed discrete distribution is called a Concrete or Gumbel Softmax distribution and is denoted by $g(\boldsymbol{\alpha})$ where $\boldsymbol{\alpha}$ is a vector of class probabilities.
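We can parametrize $q_\phi(\mathbf{c}|\mathbf{x})$ by a product of independent Gumbel Softmax distributions, $q_\phi(\mathbf{c}|\mathbf{x}) = \prod_i q_\phi(c_i|\mathbf{x})$, where $q_\phi(c_i|\mathbf{x}) = g(\boldsymbol{\alpha}^{(i)})$ is a Gumbel Softmax distribution with class probabilities $\boldsymbol{\alpha}^{(i)}$. We let the prior $p(\mathbf{c})$ be equal to a product of uniform Gumbel Softmax distributions. This approach allows us to use the reparametrization trick (Kingma & Welling (2013); Rezende et al. (2014)) and efficiently train the discrete model.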
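The Gumbel Softmax transformation can be written directly from its definition. The following is a plain-Python sketch for a single discrete variable. The function name, the small constant `eps` that guards the logarithms, and the temperature used in the example call are assumptions made for illustration.

```python
import math
import random


def sample_gumbel_softmax(alpha, temperature, rng, eps=1e-20):
    """Draw one relaxed one-hot sample y from class probabilities alpha."""
    # g_k = -log(-log(u_k)) with u_k ~ Uniform(0, 1) is a Gumbel(0, 1) sample.
    gumbels = [-math.log(-math.log(rng.random() + eps) + eps) for _ in alpha]
    logits = [(math.log(a + eps) + g) / temperature
              for a, g in zip(alpha, gumbels)]
    # Subtract the maximum before exponentiating for numerical stability.
    top = max(logits)
    weights = [math.exp(l - top) for l in logits]
    total = sum(weights)
    return [w / total for w in weights]


rng = random.Random(0)
y = sample_gumbel_softmax([0.1, 0.6, 0.3], temperature=0.5, rng=rng)
```

The output is a point on the probability simplex. As the temperature approaches zero the samples approach one-hot vectors, while larger temperatures give smoother samples and lower-variance gradients.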
3.3 Architecture
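The final architecture of the JointVAE model is shown in Fig. 1. We build the encoder to output the parameters of the continuous distribution, $\boldsymbol{\mu}$ and $\boldsymbol{\sigma}^2$, and of each of the discrete distributions, $\boldsymbol{\alpha}^{(i)}$. We then sample $z_i \sim \mathcal{N}(\mu_i, \sigma_i^2)$ and $c_i \sim g(\boldsymbol{\alpha}^{(i)})$ using the reparametrization trick and concatenate $\mathbf{z}$ and $\mathbf{c}$ into one latent vector which is passed as input to the decoder.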
3.4 Choice and sensitivity of hyperparameters
The JointVAE loss in equation 7 depends on the hyperparameters $\gamma$, $C_c$ and $C_z$. While the choice of these is ultimately empirical, there are various heuristics we can use to narrow the search. The value of $\gamma$, for example, is chosen so that it is large enough to maintain the capacity at the desired level (e.g. large improvements in reconstruction error should not come at the cost of breaking the capacity constraint). We found the model to be quite robust to changes in $\gamma$. As the capacity of a discrete channel is bounded, $C_c$ is chosen to be the maximum capacity of the channel, encouraging the model to use all categories of the discrete distribution. $C_z$ is more difficult to choose and is often chosen by experiment to be the largest value where the representation is still disentangled (in a similar way that $\beta$ is chosen as the lowest value where the representation is still disentangled in β-VAE).
Experiments
We perform experiments on several datasets including MNIST, FashionMNIST, CelebA and Chairs. We parametrize the encoder by a convolutional neural network and the decoder by the same network, transposed (for the full architecture and training details see appendix). The code, along with all experiments and trained models presented in this paper, is available at https://github.com/Schlumberger/joint-vae.
MNIST
Disentanglement results and latent traversals for MNIST are shown in Fig. 2. The model was trained with 10 continuous latent variables and one discrete 10-dimensional latent variable. The model discovers several factors of variation in the data, such as digit type (discrete), stroke thickness, angle and width (continuous) in an unsupervised manner. As can be seen from the latent traversals in Fig. 2, the trained model is able to generate realistic samples for a large variety of latent settings. Fig. 4(a) shows digits generated by fixing the discrete latent and sampling the continuous latents from the prior $p(\mathbf{z})$, which can be interpreted as sampling from a distribution conditioned on digit type. As can be seen, the samples are diverse, realistic and honor the conditioning.
For a large range of hyperparameters we were not able to achieve disentanglement using the purely continuous β-VAE framework (see Fig. 3). This is likely because MNIST has an inherently discrete generative factor (digit type), which β-VAE is unable to map onto a continuous latent variable. In contrast, the JointVAE approach allows us to disentangle the discrete factors while maintaining disentanglement of continuous factors. To the best of our knowledge, JointVAE is, apart from InfoGAN, the only framework which disentangles MNIST in a completely unsupervised manner and it does so in a more stable way than InfoGAN.
FashionMNIST
Latent traversals for FashionMNIST are shown in Fig. 4(c). We also used 10 continuous and 1 discrete latent variable for this dataset. FashionMNIST is harder to disentangle as the generative factors for creating clothes are not as clear as the ones for drawing digits. However, JointVAE performs well and discovers interesting dimensions, such as sleeve length, heel size and shirt color. As some of the classes of FashionMNIST are very similar (e.g. shirt and t-shirt are two different classes), not all classes are discovered. However, a significant number of them are disentangled, including dress, t-shirt, trousers, sneakers, bag and ankle boot (see Fig. 4(b)).
CelebA
For CelebA we used a model with 32 continuous latent variables and one 10 dimensional discrete latent variable. As shown in Fig. 5, the JointVAE model discovers various factors of variation including azimuth, age and background color, while being able to generate realistic samples. Different settings of the discrete variable correspond to different facial identities. While the samples are not as sharp as those produced by entangled models, we can still see details in the images such as distinct facial features and skin tones (the trade-off between disentanglement and reconstruction quality is a well known problem which is discussed in Higgins et al. (2016); Burgess et al. (2017); Kim & Mnih (2018); Chen et al. (2018)).
Chairs
For the chairs dataset we used a model with 32 continuous latent variables and 3 binary discrete latent variables. JointVAE discovers several factors of variation such as chair rotation, width and leg style. Furthermore, different settings of the discrete variables correspond to different chair types and colors.
While there is a well defined discrete generative factor for datasets like MNIST and FashionMNIST, it is less clear what exactly would constitute a discrete factor of variation in datasets like CelebA and Chairs. For example, for CelebA, JointVAE maps various facial identities onto the discrete latent variable. However, facial identity is not necessarily discrete and it is possible that such a factor of variation could also be mapped to a continuous latent variable. JointVAE has a clear advantage in disentangling datasets where discrete factors are prominent (as shown in Fig. 3) but when this is not the case using frameworks that only disentangle continuous factors may be sufficient.
4.1 Quantitative evaluation
We quantitatively evaluate our model on the dSprites dataset using the metric recently proposed by Kim & Mnih (2018). Since the dataset is generated from 1 discrete factor (with 3 categories) and 4 continuous factors, we used a model with 6 continuous latent variables and one 3 dimensional discrete latent variable. The results are shown in table 6. Even though the discrete factor in this dataset is not prominent (in the sense that the different categories have very small differences in pixel space) our model is able to achieve scores close to the current best models. Further, as shown in Fig. 6, our model learns meaningful latent representations. In particular, for the discrete factor of variation, JointVAE is able to better separate the classes than other models.
4.2 Detecting disentanglement in latent distributions
As noted in Section 2, taken in expectation over data, the KL divergence between the inferred latents and the priors upper bounds the mutual information between the latent units and the data. Motivated by this, we can plot the KL divergence values for each latent unit averaged over a mini batch of data during training. As various factors of variation are discovered in the data, we would expect the KL divergence of the corresponding latent units to increase. This is shown in Fig. 7(a). As the capacities $C_z$ and $C_c$ are increased, the model is able to encode more and more factors of variation. For MNIST, the first factor to be discovered is digit type, followed by angle and width. This is likely because encoding digit type results in the largest reconstruction error reduction, followed by encoding angle and width and so on.
After training, we can also measure the KL divergence of each latent unit on test data and rank the latent units by their average KL values. This corresponds to ranking the latent units by how much information they are transmitting about $\mathbf{x}$. Fig. 7(b) shows the ranked latent units for MNIST and Chairs along with a latent traversal of each unit. As can be seen, the latent units with large information content encode various aspects of the data while latent units with approximately zero KL divergence do not affect the output.
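The ranking of continuous latent units by average KL divergence can be sketched as follows. The function names are hypothetical, and the encoder outputs are assumed to be given as per-example lists of means and log-variances (the log-variance convention is an assumption, not something stated in the text).

```python
import math


def mean_kl_per_unit(mus, logvars):
    """Average KL( N(mu_i, sigma_i^2) || N(0, 1) ) of each unit over a batch."""
    num_units = len(mus[0])
    totals = [0.0] * num_units
    for mu, logvar in zip(mus, logvars):
        for i, (m, lv) in enumerate(zip(mu, logvar)):
            totals[i] += 0.5 * (m * m + math.exp(lv) - lv - 1.0)
    return [t / len(mus) for t in totals]


def rank_units(mean_kls):
    """Unit indices ordered from most to least informative."""
    return sorted(range(len(mean_kls)), key=lambda i: mean_kls[i], reverse=True)


# Two examples, three units: unit 0 matches the prior and carries no information.
mus = [[0.0, 2.0, 1.0], [0.0, -2.0, 1.0]]
logvars = [[0.0, 0.0, 0.0], [0.0, 0.0, 0.0]]
kls = mean_kl_per_unit(mus, logvars)
order = rank_units(kls)
```

A unit whose average KL stays near zero has collapsed to the prior, which is consistent with the observation that traversing such units does not change the output.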
4.3 The inference network
One of the advantages of JointVAE is that it comes with an inference network $q_\phi(\mathbf{z}, \mathbf{c}|\mathbf{x})$. For example, on MNIST we can infer the digit type on test data with 88.7% accuracy by simply looking at the value of the discrete latent variable $\mathbf{c}$. Of course, this is completely unsupervised and the accuracy could likely be increased dramatically by using some label information.
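The text does not say how the categories of the discrete latent are matched to digit labels when the accuracy is computed. The sketch below assumes one plausible protocol: each category is assigned the label that occurs most often among the examples mapped to it, with labels used only for evaluation. Here `categories` stands for the most probable category inferred for each test image, and all names are illustrative.

```python
from collections import Counter, defaultdict


def category_to_label(categories, labels):
    """Assign each discrete category the most frequent true label among its examples."""
    votes = defaultdict(Counter)
    for c, y in zip(categories, labels):
        votes[c][y] += 1
    return {c: counter.most_common(1)[0][0] for c, counter in votes.items()}


def unsupervised_accuracy(categories, labels):
    """Fraction of examples whose category maps to their true label."""
    mapping = category_to_label(categories, labels)
    correct = sum(mapping[c] == y for c, y in zip(categories, labels))
    return correct / len(labels)


# Toy example: category 0 mostly holds 7s, category 1 holds 3s.
cats = [0, 0, 0, 1, 1, 2]
labs = [7, 7, 1, 3, 3, 7]
acc = unsupervised_accuracy(cats, labs)
```

Under this protocol two categories may map to the same label, so a model that splits one digit into several writing styles is not penalized for doing so.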
Since we are learning several generative factors, the inference network can also be used to infer properties which we do not have labels for. For example, the latent unit corresponding to azimuth on the chairs dataset correlates well with the actual azimuth of unseen chairs. After training a model on the chairs dataset and identifying the latent unit corresponding to azimuth, we can test the inference network on images that were not used during training. This is shown in Fig. 8(a). As can be seen, the latent unit corresponding to rotation infers the angle of the chair even though no labeled data was given (or available) for this task.
The framework can also be used to perform image editing or manipulation. If we wish to rotate the image of a face, we can encode the face with $q_\phi$, modify the latent corresponding to azimuth and decode the resulting vector with $p_\theta$. Examples of this are shown in Fig. 8(b).
4.4 Robustness and sensitivity to hyperparameters
While our framework is robust with respect to different architectures and optimizers, it is, like most frameworks for unsupervised disentanglement, fairly sensitive to the choice of hyperparameters (all hyperparameters needed to reproduce the results in this paper are given in the appendix). Even with a good choice of hyperparameters, the quality of disentanglement may vary based on the random seed. In general, it is easy to achieve some degree of disentanglement for a large set of hyperparameters, but achieving completely clean disentanglement (e.g. perfectly separating digit type from other generative factors) can be difficult. It would be interesting to explore more principled approaches for choosing the latent capacities and how to increase them, but we leave this for future work. Further, as mentioned in Section 4, when a discrete generative factor is not present or important, the framework may fail to learn meaningful discrete representations. We have included some failure examples in Fig. 9.
Conclusion
We have proposed JointVAE, a framework for learning disentangled continuous and discrete representations in an unsupervised manner. The framework comes with the advantages of VAEs such as stable training and large sample diversity while being able to model complex jointly continuous and discrete generative factors. We have shown that JointVAE disentangles factors of variation on several datasets while producing realistic samples. In addition, the inference network can be used to infer unlabeled quantities on test data and to edit and manipulate images.
In future work, it would be interesting to combine our approach with recent improvements of the β-VAE framework, such as FactorVAE (Kim & Mnih (2018)) or β-TCVAE (Chen et al. (2018)). Gaining a deeper understanding of how disentanglement depends on the latent channel capacities and how they are increased will likely provide insights to build more stable models. Finally, it would also be interesting to explore the use of other latent distributions since the framework allows the use of any joint distribution of reparametrizable random variables.
The author would like to thank Erik Burton, José Celaya, Suhas Suresha, Vishakh Hegde and the anonymous reviewers for helpful suggestions and comments that helped improve the paper.
References
Supplementary material
Appendix A Proofs
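A.1 KL divergence term as an upper bound on mutual information
Let $q(\mathbf{x}, \mathbf{z}) = p(\mathbf{x})\, q_\phi(\mathbf{z}|\mathbf{x})$ and let $q(\mathbf{z}) = \mathbb{E}_{p(\mathbf{x})}[q_\phi(\mathbf{z}|\mathbf{x})]$ be the aggregate posterior. Then
$$\begin{aligned}
\mathbb{E}_{p(\mathbf{x})}\left[D_{KL}\left(q_\phi(\mathbf{z}|\mathbf{x}) \,\|\, p(\mathbf{z})\right)\right]
&= \mathbb{E}_{p(\mathbf{x})} \mathbb{E}_{q_\phi(\mathbf{z}|\mathbf{x})}\left[\log \frac{q_\phi(\mathbf{z}|\mathbf{x})}{p(\mathbf{z})}\right] \\
&= \mathbb{E}_{q(\mathbf{x}, \mathbf{z})}\left[\log \frac{q_\phi(\mathbf{z}|\mathbf{x})}{q(\mathbf{z})}\right] + \mathbb{E}_{q(\mathbf{z})}\left[\log \frac{q(\mathbf{z})}{p(\mathbf{z})}\right] \\
&= I(\mathbf{x}; \mathbf{z}) + D_{KL}\left(q(\mathbf{z}) \,\|\, p(\mathbf{z})\right) \\
&\geq I(\mathbf{x}; \mathbf{z})
\end{aligned}$$
where the mutual information is defined under the joint distribution of the data and the encoding distribution, $q(\mathbf{x}, \mathbf{z})$. The inequality holds because the KL divergence is non-negative.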
A.2 Splitting the discrete and continuous KL divergence terms
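Assuming the continuous and discrete latent variables are conditionally independent, i.e. $q_\phi(\mathbf{z}, \mathbf{c}|\mathbf{x}) = q_\phi(\mathbf{z}|\mathbf{x})\, q_\phi(\mathbf{c}|\mathbf{x})$, and similarly for the prior, $p(\mathbf{z}, \mathbf{c}) = p(\mathbf{z})\, p(\mathbf{c})$, we can rewrite the joint KL divergence as
$$\begin{aligned}
D_{KL}\left(q_\phi(\mathbf{z}, \mathbf{c}|\mathbf{x}) \,\|\, p(\mathbf{z}, \mathbf{c})\right)
&= \mathbb{E}_{q_\phi(\mathbf{z}, \mathbf{c}|\mathbf{x})}\left[\log \frac{q_\phi(\mathbf{z}, \mathbf{c}|\mathbf{x})}{p(\mathbf{z}, \mathbf{c})}\right] \\
&= \mathbb{E}_{q_\phi(\mathbf{z}|\mathbf{x})} \mathbb{E}_{q_\phi(\mathbf{c}|\mathbf{x})}\left[\log \frac{q_\phi(\mathbf{z}|\mathbf{x})}{p(\mathbf{z})} + \log \frac{q_\phi(\mathbf{c}|\mathbf{x})}{p(\mathbf{c})}\right] \\
&= \mathbb{E}_{q_\phi(\mathbf{z}|\mathbf{x})}\left[\log \frac{q_\phi(\mathbf{z}|\mathbf{x})}{p(\mathbf{z})}\right] + \mathbb{E}_{q_\phi(\mathbf{c}|\mathbf{x})}\left[\log \frac{q_\phi(\mathbf{c}|\mathbf{x})}{p(\mathbf{c})}\right] \\
&= D_{KL}\left(q_\phi(\mathbf{z}|\mathbf{x}) \,\|\, p(\mathbf{z})\right) + D_{KL}\left(q_\phi(\mathbf{c}|\mathbf{x}) \,\|\, p(\mathbf{c})\right)
\end{aligned}$$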
Appendix B Model architecture
The architecture of the model is shown in the table. The non linearities in both the encoder and decoder are ReLU except for the output layer of the decoder which is a sigmoid.
For 64 by 64 images (Chairs, CelebA and dSprites) the architecture shown in the table was used. For 32 by 32 images (MNIST and FashionMNIST which were resized from 28 by 28) we used the same architecture with the last conv layer in the encoder and first in the decoder removed.
Appendix C Training details
Parameters and training details for each model.
C.1 MNIST
Latent distribution: 10 continuous, 1 10-dimensional discrete
$C_z$: Increased linearly from 0 to 5 in 25000 iterations
$C_c$: Increased linearly from 0 to 5 in 25000 iterations
C.2 FashionMNIST
Latent distribution: 10 continuous, 1 10-dimensional discrete
$C_z$: Increased linearly from 0 to 5 in 50000 iterations
$C_c$: Increased linearly from 0 to 10 in 50000 iterations
C.3 Chairs
Latent distribution: 32 continuous, 3 binary discrete
$C_z$: Increased linearly from 0 to 30 in 100000 iterations
$C_c$: Increased linearly from 0 to 5 in 100000 iterations
C.4 CelebA
Latent distribution: 32 continuous, 1 10-dimensional discrete
$C_z$: Increased linearly from 0 to 50 in 100000 iterations
$C_c$: Increased linearly from 0 to 10 in 100000 iterations
C.5 dSprites
Latent distribution: 6 continuous, 1 3-dimensional discrete
$C_z$: Increased linearly from 0 to 40 in 300000 iterations
$C_c$: Increased linearly from 0 to 1.1 in 300000 iterations
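Note that since the KL divergence between a categorical variable and a uniform categorical variable is bounded, the discrete capacity $C_c$ is clipped during training if $C_c$ exceeds the maximum capacity. Let $P$ denote a categorical random variable over $n$ categories with probabilities $p_i$ and let $Q$ be a uniform categorical variable with $q_i = 1/n$, then
$$D_{KL}(P \,\|\, Q) = \sum_{i=1}^{n} p_i \log \frac{p_i}{q_i} = -H(P) + \log n \leq \log n$$
During training $C_c$ is then clipped as $C_c \leftarrow \min\left(C_c, \sum_i \log n_i\right)$, where $n_i$ is the number of categories of the $i$-th discrete latent variable (for a single discrete variable the bound is simply $\log n$).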
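The bound on the discrete capacity is easy to check numerically. The helpers below are an illustrative sketch with hypothetical names: they compute the KL divergence to a uniform categorical, the maximum capacity for a list of discrete dimensions, and the clipped value of $C_c$.

```python
import math


def categorical_kl_to_uniform(probs):
    """KL(P || Uniform) = sum_i p_i * log(p_i * n), skipping zero-probability terms."""
    n = len(probs)
    return sum(p * math.log(p * n) for p in probs if p > 0.0)


def max_discrete_capacity(disc_dims):
    """Sum of log(n_i) over the discrete latent variables, in nats."""
    return sum(math.log(n) for n in disc_dims)


def clipped_capacity(c_c, disc_dims):
    """Clip the discrete capacity at its maximum attainable value."""
    return min(c_c, max_discrete_capacity(disc_dims))


one_hot_kl = categorical_kl_to_uniform([1.0, 0.0, 0.0])  # attains the bound log(3)
mnist_cap = clipped_capacity(5.0, [10])                  # log(10), about 2.30 nats
chairs_cap = clipped_capacity(5.0, [2, 2, 2])            # 3 * log(2), about 2.08 nats
```

This explains why a final value such as 5 or 10 for $C_c$ in the settings above acts as "use the full discrete channel": for one 10-dimensional variable the effective capacity never exceeds $\log 10$, and the dSprites value of 1.1 sits right at $\log 3 \approx 1.0986$.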
Appendix D Things that didn’t work
We experimented with several things which we found did not improve disentanglement of joint continuous and discrete representations.
Modifying the latent distribution in β-VAE to include a joint Gaussian and Gumbel-Softmax distribution without changing the loss. This generally resulted in the model ignoring the discrete codes.
Changing the loss function to have a higher $\beta$ on the continuous KL term and a lower $\beta$ on the discrete KL term. For a large number of combinations of the two $\beta$ values, we found the model either to ignore the discrete latent codes or to produce representations where continuous factors were encoded in the discrete latent variables.
In β-VAE, there is a larger weight on the KL term than in a traditional VAE model. In most VAE models with a Gumbel-Softmax latent variable, the KL divergence between the Gumbel-Softmax variables is approximated by the KL divergence between the corresponding categorical variables. We hypothesized that the approximation error might be worse in β-VAE, since there is a larger weight on the KL term. Unfortunately, there is no closed form expression of the KL divergence between two Gumbel-Softmax variables. We used various approximations of this, but most estimates had very high variance and impeded learning in the model.
Appendix E Choice of discrete dimensions
As discussed in the main section of the paper, it is not clear what exactly would constitute a discrete factor of variation for a dataset like CelebA for example. As such, the choice is somewhat arbitrary: when using a 10 dimensional discrete latent variable, the model encodes 10 facial identities and when using more than 10 dimensions it encodes more identities. This was generally found to be quite robust, except when the number of discrete dimensions was exceedingly large (>100), when the model would start to encode e.g. facial angles in the discrete dimensions. Similarly, the reason we choose a 10 dimensional discrete latent variable for MNIST is because we know a priori that there are 10 types of digits. When choosing less than 10 discrete dimensions on MNIST, the model tends to fuse digit types which look similar into one discrete dimension. For example, 4 and 9 or 5 and 8 may correspond to one discrete dimension instead of being separated. When using more than 10 dimensions, the model tends to separate different writing styles of digits into separate dimensions, e.g. 2’s with and without a curl at the bottom or 7’s with and without a middle stroke may be encoded into different categories.
Appendix F Comparison with InfoGAN
We include comparisons with InfoGAN, which can also disentangle joint continuous and discrete factors. InfoGAN successfully disentangles digit type from angle and width. However, width and stroke thickness remain entangled. Further, InfoGAN models are typically less stable to train.
Appendix G Latent traversals
In all figures, latent traversals of continuous variables run between a low and a high quantile of the prior, computed with $\Phi^{-1}$, the inverse cdf of a unit normal. Latent traversals of discrete variables are from 1 to the number of dimensions of the variable.
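A traversal grid of this kind can be sketched with the standard library. The quantile levels 0.05 and 0.95 used as defaults below are assumptions chosen for illustration, as are the function names. Discrete traversals are represented as one-hot vectors, one per category.

```python
from statistics import NormalDist


def continuous_traversal(num_steps, low_q=0.05, high_q=0.95):
    """Values of the inverse unit-normal cdf at evenly spaced quantiles (num_steps >= 2)."""
    std_normal = NormalDist(0.0, 1.0)
    step = (high_q - low_q) / (num_steps - 1)
    return [std_normal.inv_cdf(low_q + i * step) for i in range(num_steps)]


def discrete_traversal(num_categories):
    """One one-hot vector per category of a discrete latent variable."""
    return [[1.0 if j == k else 0.0 for j in range(num_categories)]
            for k in range(num_categories)]


values = continuous_traversal(5)   # symmetric around 0 for symmetric quantiles
one_hots = discrete_traversal(3)
```

Spacing the points evenly in quantile rather than in latent value places more traversal steps where the prior has more mass, so each step covers an equal share of prior probability.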