Learning Disentangled Joint Continuous and Discrete Representations

Emilien Dupont

Introduction

Disentangled representations are defined as ones where a change in a single unit of the representation corresponds to a change in a single factor of variation of the data while being invariant to others (Bengio et al. (2013)). For example, a disentangled representation of 3D objects could contain a set of units each corresponding to a distinct generative factor such as position, color or scale. Most recent work on learning disentangled representations has focused on modeling continuous factors of variation (Higgins et al. (2016); Kim & Mnih (2018); Chen et al. (2018)). However, a large number of datasets contain inherently discrete generative factors which can be difficult to capture with these methods. In image data, for example, distinct objects or entities would most naturally be represented by discrete variables, while their position or scale might be represented by continuous variables.

Several machine learning tasks, including transfer learning and zero-shot learning, can benefit from disentangled representations (Lake et al. (2017)). Disentangled representations have also been applied to reinforcement learning (Higgins et al. (2017a)) and to learning visual concepts (Higgins et al. (2017b)). Further, in contrast to most representation learning algorithms, disentangled representations are often interpretable since they align with factors of variation of the data. Different approaches have been explored for semi-supervised or supervised learning of factored representations (Kulkarni et al. (2015); Whitney et al. (2016); Yang et al. (2015); Reed et al. (2014)). These approaches achieve impressive results but require either knowledge of the underlying generative factors or other forms of weak supervision. Several methods also exist for unsupervised disentanglement, the two most prominent being InfoGAN and β-VAE (Chen et al. (2016); Higgins et al. (2016)). These frameworks have shown promise in disentangling factors of variation in an unsupervised manner on a number of datasets.

InfoGAN (Chen et al. (2016)) is a framework based on Generative Adversarial Networks (Goodfellow et al. (2014)) which disentangles generative factors by maximizing the mutual information between a subset of latent variables and the generated samples. While this approach is able to model both discrete and continuous factors, it suffers from some of the shortcomings of Generative Adversarial Networks (GANs), such as unstable training and reduced sample diversity. Recent improvements in the training of GANs (Arjovsky et al. (2017); Gulrajani et al. (2017)) have mitigated some of these issues, but stable GAN training still remains a challenge (and this is particularly challenging for InfoGAN, as shown in Kim & Mnih (2018)). β-VAE (Higgins et al. (2016)), in contrast, is based on Variational Autoencoders (Kingma & Welling (2013); Rezende et al. (2014)) and is stable to train. β-VAE, however, can only model continuous latent variables.

In this paper we propose a framework, based on Variational Autoencoders (VAE), that learns disentangled continuous and discrete representations in an unsupervised manner. It comes with the advantages of VAEs, such as stable training, large sample diversity and a principled inference network, while having the flexibility to model a combination of continuous and discrete generative factors. We show how our framework, which we term JointVAE, discovers independent factors of variation on MNIST, FashionMNIST (Xiao et al. (2017)), CelebA (Liu et al. (2015)) and Chairs (Aubry et al. (2014)). For example, on MNIST, JointVAE disentangles digit type (discrete) from slant, width and stroke thickness (continuous). In addition, the model’s learned inference network can infer various properties of data, such as the azimuth of a chair, in an unsupervised manner. The model can also be used for simple image editing, such as rotating a face in an image.

Analysis of β-VAE

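We derive our approach by modifying the β-VAE framework and augmenting it with a joint latent distribution. β-VAEs model a joint distribution of the data $\mathbf{x}$ and a set of latent variables $\mathbf{z}$ and learn continuous disentangled representations by maximizing the objective

$$\mathcal{L}(\theta, \phi) = \mathbb{E}_{q_\phi(\mathbf{z}|\mathbf{x})}[\log p_\theta(\mathbf{x}|\mathbf{z})] - \beta D_{KL}(q_\phi(\mathbf{z}|\mathbf{x}) \,\|\, p(\mathbf{z})) \tag{1}$$

where $\beta$ is a positive constant; $\beta = 1$ recovers the original VAE objective.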

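We can derive further insights by analyzing the role of the KL divergence term in objective (1). During training, the objective will be optimized in expectation over the data $\mathbf{x}$. Writing $q(\mathbf{z}) = \mathbb{E}_{p(\mathbf{x})}[q_\phi(\mathbf{z}|\mathbf{x})]$ for the aggregate posterior, the KL term then becomes (Makhzani & Frey (2017); Kim & Mnih (2018))

$$\mathbb{E}_{p(\mathbf{x})}\left[D_{KL}(q_\phi(\mathbf{z}|\mathbf{x}) \,\|\, p(\mathbf{z}))\right] = I(\mathbf{x}; \mathbf{z}) + D_{KL}(q(\mathbf{z}) \,\|\, p(\mathbf{z})) \geq I(\mathbf{x}; \mathbf{z}) \tag{2}$$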

i.e., when taken in expectation over the data, the KL divergence term is an upper bound on the mutual information between the latents and the data (see appendix for proof and details). Thus, a mini-batch estimate of the mean KL divergence is an estimate of the upper bound on the information $\mathbf{z}$ can transmit about $\mathbf{x}$.
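Identity (2) can be checked exactly on a small discrete example. The sketch below is not part of the JointVAE method; it uses a hypothetical toy setup with two equally likely data points and a 3-state latent, so that every expectation is a finite sum, and confirms that the expected KL term equals the mutual information plus the KL between the aggregate posterior and the prior.

```python
import math


def kl(p, q):
    """KL divergence D_KL(p || q) for discrete distributions given as lists."""
    return sum(pi * math.log(pi / qi) for pi, qi in zip(p, q) if pi > 0)


# Hypothetical toy setup: two equally likely data points x, a 3-state latent z,
# an encoder q(z|x) (one row per x) and a uniform prior p(z).
p_x = [0.5, 0.5]
q_z_given_x = [[0.7, 0.2, 0.1],
               [0.1, 0.3, 0.6]]
p_z = [1 / 3, 1 / 3, 1 / 3]

# Left-hand side of (2): E_{p(x)}[KL(q(z|x) || p(z))].
expected_kl = sum(px * kl(qzx, p_z) for px, qzx in zip(p_x, q_z_given_x))

# Aggregate posterior q(z) = E_{p(x)}[q(z|x)].
q_z = [sum(px * qzx[k] for px, qzx in zip(p_x, q_z_given_x)) for k in range(3)]

# Mutual information I(x; z) under q(x, z) = p(x) q(z|x).
mutual_info = sum(px * kl(qzx, q_z) for px, qzx in zip(p_x, q_z_given_x))

# The non-negative gap between the two: KL(q(z) || p(z)).
gap = kl(q_z, p_z)

print(f"E[KL] = {expected_kl:.4f}, I(x;z) = {mutual_info:.4f}, gap = {gap:.4f}")
```

Because the gap is itself a KL divergence, it can never be negative, which is exactly why the mean KL upper bounds the information $\mathbf{z}$ carries about $\mathbf{x}$.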

Penalizing the mutual information term improves disentanglement but comes at the cost of increased reconstruction error. Recently, several methods have been explored to improve the reconstruction quality without decreasing disentanglement (Burgess et al. (2017); Kim & Mnih (2018); Chen et al. (2018); Gao et al. (2018)). Burgess et al. (2017) in particular propose an objective where the upper bound on the mutual information is controlled and gradually increased during training. Denoting the controlled information capacity by $C$, the objective is defined as

$$\mathcal{L}(\theta, \phi) = \mathbb{E}_{q_\phi(\mathbf{z}|\mathbf{x})}[\log p_\theta(\mathbf{x}|\mathbf{z})] - \gamma \left| D_{KL}(q_\phi(\mathbf{z}|\mathbf{x}) \,\|\, p(\mathbf{z})) - C \right| \tag{3}$$

where $\gamma$ is a constant that forces the KL divergence term to match the capacity $C$. Gradually increasing $C$ during training allows for control of the amount of information the model can encode. This objective has been shown to improve reconstruction quality compared to (1) without reducing disentanglement (Burgess et al. (2017)).
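A minimal sketch of the capacity-controlled objective (3), written as a loss to minimize. It assumes a linear schedule for $C$ (the schedule used in the training details of Appendix C); the function names and the value of $\gamma$ in the usage lines are hypothetical illustrations, not the reference implementation.

```python
def linear_capacity(step, c_max, num_steps, c_min=0.0):
    """Capacity C increased linearly from c_min to c_max over num_steps, then held fixed."""
    frac = min(step / num_steps, 1.0)
    return c_min + frac * (c_max - c_min)


def capacity_loss(recon_log_lik, kl, capacity, gamma):
    """Negative of objective (3): -E[log p(x|z)] + gamma * |KL - C|.

    recon_log_lik: (estimated) expected log-likelihood of the data.
    kl: KL(q(z|x) || p(z)), e.g. averaged over a mini-batch.
    """
    return -recon_log_lik + gamma * abs(kl - capacity)


# Usage with hypothetical values; the schedule mirrors the MNIST setting
# (C from 0 to 5 over 25000 iterations).
for step in (0, 10000, 25000, 40000):
    c = linear_capacity(step, c_max=5.0, num_steps=25000)
    print(step, c, capacity_loss(recon_log_lik=-120.0, kl=3.0, capacity=c, gamma=30.0))
```

The absolute value penalizes the KL for being either above or below the target, so the model is pushed to use, but not exceed, the currently allowed capacity.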

JointVAE Model

We propose a modification to the β-VAE framework which allows us to model a joint distribution of continuous and discrete latent variables. Letting $\mathbf{z}$ denote a set of continuous latent variables and $\mathbf{c}$ denote a set of categorical or discrete latent variables, we define a joint posterior $q_\phi(\mathbf{z}, \mathbf{c}|\mathbf{x})$, prior $p(\mathbf{z}, \mathbf{c})$ and likelihood $p_\theta(\mathbf{x}|\mathbf{z}, \mathbf{c})$. The β-VAE objective then becomes

$$\mathcal{L}(\theta, \phi) = \mathbb{E}_{q_\phi(\mathbf{z}, \mathbf{c}|\mathbf{x})}[\log p_\theta(\mathbf{x}|\mathbf{z}, \mathbf{c})] - \beta D_{KL}(q_\phi(\mathbf{z}, \mathbf{c}|\mathbf{x}) \,\|\, p(\mathbf{z}, \mathbf{c})) \tag{4}$$

where the latent distribution is now jointly continuous and discrete. β-VAE assumes the data is generated by a fixed number of independent factors of variation, so all latent variables are in fact conditionally independent; for the sake of deriving the JointVAE objective, however, we only require conditional independence between the continuous and discrete latents. Under this assumption, i.e. $q_\phi(\mathbf{z}, \mathbf{c}|\mathbf{x}) = q_\phi(\mathbf{z}|\mathbf{x}) q_\phi(\mathbf{c}|\mathbf{x})$, and similarly for the prior, $p(\mathbf{z}, \mathbf{c}) = p(\mathbf{z}) p(\mathbf{c})$, we can rewrite the KL divergence as

$$D_{KL}(q_\phi(\mathbf{z}, \mathbf{c}|\mathbf{x}) \,\|\, p(\mathbf{z}, \mathbf{c})) = D_{KL}(q_\phi(\mathbf{z}|\mathbf{x}) \,\|\, p(\mathbf{z})) + D_{KL}(q_\phi(\mathbf{c}|\mathbf{x}) \,\|\, p(\mathbf{c})) \tag{5}$$

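i.e. we can separate the discrete and continuous KL divergence terms (see appendix for proof). Under this assumption, the loss becomes

$$\mathcal{L}(\theta, \phi) = \mathbb{E}_{q_\phi(\mathbf{z}, \mathbf{c}|\mathbf{x})}[\log p_\theta(\mathbf{x}|\mathbf{z}, \mathbf{c})] - \beta D_{KL}(q_\phi(\mathbf{z}|\mathbf{x}) \,\|\, p(\mathbf{z})) - \beta D_{KL}(q_\phi(\mathbf{c}|\mathbf{x}) \,\|\, p(\mathbf{c})) \tag{6}$$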
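The separation of the joint KL into a continuous and a discrete term can be checked numerically. The sketch below is only an illustration under a stated simplification: to keep every quantity an exact finite sum, it uses a hypothetical 2-state variable as a stand-in for the continuous latent $\mathbf{z}$ (the identity holds for any factorized distributions, continuous or not).

```python
import math
from itertools import product


def kl(p, q):
    """KL divergence D_KL(p || q) for discrete distributions given as lists."""
    return sum(pi * math.log(pi / qi) for pi, qi in zip(p, q) if pi > 0)


# Hypothetical posterior and prior factors. The 2-state "z" is a discrete
# stand-in for the continuous latent, chosen so the check is exact.
q_z, p_z = [0.8, 0.2], [0.5, 0.5]
q_c, p_c = [0.6, 0.3, 0.1], [1 / 3, 1 / 3, 1 / 3]

# Factorized joints: q(z,c|x) = q(z|x) q(c|x) and p(z,c) = p(z) p(c),
# flattened in the same (z, c) order for both.
q_joint = [a * b for a, b in product(q_z, q_c)]
p_joint = [a * b for a, b in product(p_z, p_c)]

joint_kl = kl(q_joint, p_joint)
split_kl = kl(q_z, p_z) + kl(q_c, p_c)
print(joint_kl, split_kl)  # equal up to floating-point rounding
```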

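In our initial experiments, we found that directly optimizing this loss led to the model ignoring the discrete latent variables. Similarly, gradually increasing the channel capacity as in equation (3) led to the model assigning all capacity to the continuous channels. To overcome this, we split the capacity increase: the capacities of the discrete and continuous latent channels are controlled separately, forcing the model to encode information in both the discrete and continuous channels. The final loss is then given by

$$\mathcal{L}(\theta, \phi) = \mathbb{E}_{q_\phi(\mathbf{z}, \mathbf{c}|\mathbf{x})}[\log p_\theta(\mathbf{x}|\mathbf{z}, \mathbf{c})] - \gamma \left| D_{KL}(q_\phi(\mathbf{z}|\mathbf{x}) \,\|\, p(\mathbf{z})) - C_z \right| - \gamma \left| D_{KL}(q_\phi(\mathbf{c}|\mathbf{x}) \,\|\, p(\mathbf{c})) - C_c \right| \tag{7}$$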

where $C_z$ and $C_c$ are gradually increased during training.
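A minimal sketch of the final loss (7) for a single example, written as a quantity to minimize. Two assumptions are made here and are not taken from the text above: the encoder outputs log-variances, and the discrete KL is computed between the categorical class probabilities and a uniform categorical (the categorical approximation described in Appendix D). All names are hypothetical.

```python
import math


def gaussian_kl(mu, logvar):
    """Closed-form KL(N(mu, diag(exp(logvar))) || N(0, I)), summed over units."""
    return sum(0.5 * (m * m + math.exp(lv) - lv - 1.0) for m, lv in zip(mu, logvar))


def categorical_uniform_kl(alpha):
    """KL(Cat(alpha) || Uniform(n)) = sum_k alpha_k * log(n * alpha_k)."""
    n = len(alpha)
    return sum(a * math.log(a * n) for a in alpha if a > 0)


def joint_vae_loss(recon_log_lik, mu, logvar, alphas, c_z, c_c, gamma):
    """Negative of objective (7):
    -E[log p(x|z,c)] + gamma * |KL_z - C_z| + gamma * |KL_c - C_c|.

    alphas: one list of class probabilities per independent discrete latent.
    """
    kl_z = gaussian_kl(mu, logvar)
    kl_c = sum(categorical_uniform_kl(a) for a in alphas)
    return -recon_log_lik + gamma * abs(kl_z - c_z) + gamma * abs(kl_c - c_c)


# Usage with hypothetical encoder outputs: 2 continuous units, one 10-way discrete latent.
loss = joint_vae_loss(-50.0, mu=[0.3, -0.1], logvar=[-0.5, 0.0],
                      alphas=[[0.1] * 10], c_z=1.0, c_c=0.5, gamma=30.0)
print(loss)
```

Keeping two separate absolute-value terms is what lets $C_z$ and $C_c$ be scheduled independently.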

As in the original VAE framework, we parametrize $q_\phi(\mathbf{z}|\mathbf{x})$ by a factorized Gaussian, i.e. $q_\phi(\mathbf{z}|\mathbf{x}) = \prod_i q_\phi(z_i|\mathbf{x})$ where $q_\phi(z_i|\mathbf{x}) = \mathcal{N}(\mu_i, \sigma_i^2)$, and let the prior be a unit Gaussian $p(\mathbf{z}) = \mathcal{N}(0, I)$. $\boldsymbol{\mu}$ and $\boldsymbol{\sigma}^2$ are both parametrized by neural networks.
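The factorized Gaussian posterior can be sampled with the reparametrization trick, and its KL to the unit Gaussian prior has a closed form per unit. The sketch below assumes, as is common practice but not stated above, that the networks output $\log \sigma_i^2$ rather than $\sigma_i^2$; the encoder outputs in the usage lines are hypothetical.

```python
import math
import random


def reparametrize(mu, logvar, rng):
    """Sample z_i = mu_i + sigma_i * eps_i with eps_i ~ N(0, 1) and sigma_i = exp(logvar_i / 2)."""
    return [m + math.exp(0.5 * lv) * rng.gauss(0.0, 1.0) for m, lv in zip(mu, logvar)]


def gaussian_kl_per_unit(mu, logvar):
    """KL(N(mu_i, sigma_i^2) || N(0, 1)) = 0.5 * (mu^2 + sigma^2 - log sigma^2 - 1) for each unit i."""
    return [0.5 * (m * m + math.exp(lv) - lv - 1.0) for m, lv in zip(mu, logvar)]


rng = random.Random(0)
mu, logvar = [0.5, -1.0, 0.0], [0.0, -2.0, 1.0]  # hypothetical encoder outputs
z = reparametrize(mu, logvar, rng)
kl_units = gaussian_kl_per_unit(mu, logvar)
print(z, kl_units)
```

Because the randomness enters only through `eps`, gradients with respect to `mu` and `logvar` can flow through the sample in an autodiff framework.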

2 Parametrization of discrete latent variables

Parametrizing $q_\phi(\mathbf{c}|\mathbf{x})$ is more difficult. Since samples from $q_\phi(\mathbf{c}|\mathbf{x})$ need to be differentiable with respect to its parameters, we cannot parametrize $q_\phi(\mathbf{c}|\mathbf{x})$ by a set of categorical distributions. Recently, Maddison et al. (2016) and Jang et al. (2016) proposed a differentiable relaxation of discrete random variables based on the Gumbel-Max trick (Gumbel (1954)). If $c$ is a categorical variable with class probabilities $\alpha_1, \alpha_2, \ldots, \alpha_n$, then we can sample from a continuous approximation of the categorical distribution by sampling a set of $g_k \sim \text{Gumbel}(0, 1)$ i.i.d. and applying the following transformation

$$y_k = \frac{\exp\left((\log \alpha_k + g_k)/\tau\right)}{\sum_{i=1}^{n} \exp\left((\log \alpha_i + g_i)/\tau\right)} \tag{8}$$

where $\tau$ is a temperature parameter which controls the relaxation. The sample $\mathbf{y}$ is a continuous approximation of the one-hot representation of $\mathbf{c}$; as $\tau \to 0$, samples approach one-hot vectors. The relaxed discrete distribution is called a Concrete or Gumbel Softmax distribution and is denoted by $g(\boldsymbol{\alpha})$, where $\boldsymbol{\alpha}$ is a vector of class probabilities.
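The transformation above can be sketched in plain Python. `gumbel_max` takes the argmax of the perturbed log-probabilities, which gives an exact categorical sample (the Gumbel-Max trick); `gumbel_softmax` replaces the argmax with a temperature-controlled softmax to obtain the relaxed sample $\mathbf{y}$. The function names and the temperature in the usage lines are illustrative choices, not values from the paper.

```python
import math
import random


def sample_gumbel(rng):
    """Draw g ~ Gumbel(0, 1) by inverse transform: g = -log(-log(u)), u ~ U(0, 1)."""
    u = rng.random()
    while u == 0.0:  # rng.random() can return 0.0, which would make log(u) fail
        u = rng.random()
    return -math.log(-math.log(u))


def gumbel_max(alpha, rng):
    """Exact categorical sample: index k maximizing log(alpha_k) + g_k."""
    scores = [math.log(a) + sample_gumbel(rng) for a in alpha]
    return max(range(len(alpha)), key=scores.__getitem__)


def gumbel_softmax(alpha, tau, rng):
    """Relaxed one-hot sample y from the Concrete / Gumbel Softmax distribution g(alpha)."""
    logits = [(math.log(a) + sample_gumbel(rng)) / tau for a in alpha]
    largest = max(logits)  # subtract the max before exponentiating, for numerical stability
    exps = [math.exp(logit - largest) for logit in logits]
    total = sum(exps)
    return [e / total for e in exps]


rng = random.Random(0)
alpha = [0.2, 0.3, 0.5]
y_soft = gumbel_softmax(alpha, tau=0.67, rng=rng)  # point on the probability simplex
k_hard = gumbel_max(alpha, rng)                     # exact category index
print(y_soft, k_hard)
```

With the same Gumbel noise, the relaxed sample's largest entry is always at the Gumbel-Max index, since dividing by $\tau > 0$ and applying a softmax preserve the ordering; lowering $\tau$ only sharpens $\mathbf{y}$ toward that one-hot vector.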

We can parametrize $q_\phi(\mathbf{c}|\mathbf{x})$ by a product of independent Gumbel Softmax distributions, $q_\phi(\mathbf{c}|\mathbf{x}) = \prod_i q_\phi(c_i|\mathbf{x})$, where $q_\phi(c_i|\mathbf{x}) = g(\boldsymbol{\alpha}^{(i)})$ is a Gumbel Softmax distribution with class probabilities $\boldsymbol{\alpha}^{(i)}$. We let the prior $p(\mathbf{c})$ be equal to a product of uniform Gumbel Softmax distributions. This approach allows us to use the reparametrization trick (Kingma & Welling (2013); Rezende et al. (2014)) and efficiently train the discrete model.
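Because both the posterior and the prior factorize over the discrete latents, the discrete KL term is a sum of per-variable terms. The sketch below assumes, following the categorical approximation described in Appendix D, that each term is computed between the class probabilities $\boldsymbol{\alpha}^{(i)}$ and a uniform categorical rather than between the relaxed distributions themselves.

```python
import math


def categorical_kl_to_uniform(alpha):
    """KL(Cat(alpha) || Uniform over n = len(alpha) classes) = log n - H(alpha)."""
    n = len(alpha)
    entropy = -sum(a * math.log(a) for a in alpha if a > 0)
    return math.log(n) - entropy


def discrete_kl(alphas):
    """KL of a product of independent discrete latents against a product of
    uniform priors: the per-variable KL terms add up."""
    return sum(categorical_kl_to_uniform(a) for a in alphas)


# Usage: three hypothetical binary latents (as in the Chairs model).
print(discrete_kl([[0.5, 0.5], [0.9, 0.1], [0.0, 1.0]]))
```

A uniform $\boldsymbol{\alpha}^{(i)}$ contributes zero and a one-hot $\boldsymbol{\alpha}^{(i)}$ contributes $\log n_i$, the most a single $n_i$-way variable can carry.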

3 Architecture

The final architecture of the JointVAE model is shown in Fig. 1. We build the encoder to output the parameters of the continuous distribution, $\boldsymbol{\mu}$ and $\boldsymbol{\sigma}^2$, and of each of the discrete distributions, $\boldsymbol{\alpha}^{(i)}$. We then sample $z_i \sim \mathcal{N}(\mu_i, \sigma_i^2)$ and $c_i \sim g(\boldsymbol{\alpha}^{(i)})$ using the reparametrization trick and concatenate $\mathbf{z}$ and $\mathbf{c}$ into one latent vector which is passed as input to the decoder.
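The sampling-and-concatenation step between encoder and decoder can be sketched as follows. The encoder and decoder themselves are omitted; the encoder outputs in the usage lines are hypothetical, log-variances are assumed as the continuous parametrization, and the temperature is an illustrative value.

```python
import math
import random


def _gumbel(rng):
    """g ~ Gumbel(0, 1) via inverse transform, avoiding log(0)."""
    u = rng.random()
    while u == 0.0:
        u = rng.random()
    return -math.log(-math.log(u))


def _gumbel_softmax(alpha, tau, rng):
    """Relaxed one-hot sample from g(alpha) at temperature tau."""
    logits = [(math.log(a) + _gumbel(rng)) / tau for a in alpha]
    largest = max(logits)
    exps = [math.exp(logit - largest) for logit in logits]
    total = sum(exps)
    return [e / total for e in exps]


def sample_joint_latent(mu, logvar, alphas, tau, rng):
    """Reparametrized samples of z and of every c_i, concatenated into one decoder input."""
    z = [m + math.exp(0.5 * lv) * rng.gauss(0.0, 1.0) for m, lv in zip(mu, logvar)]
    c = [v for alpha in alphas for v in _gumbel_softmax(alpha, tau, rng)]
    return z + c


rng = random.Random(0)
# Hypothetical encoder outputs sized like the Chairs model: 32 continuous units
# and 3 binary discrete latents, giving a 32 + 3 * 2 = 38-dimensional latent vector.
mu, logvar = [0.0] * 32, [0.0] * 32
alphas = [[0.5, 0.5], [0.9, 0.1], [0.3, 0.7]]
latent = sample_joint_latent(mu, logvar, alphas, tau=0.67, rng=rng)
print(len(latent))
```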

4 Choice of and sensitivity to hyperparameters

The JointVAE loss in equation (7) depends on the hyperparameters $\gamma$, $C_c$ and $C_z$. While the choice of these is ultimately empirical, there are various heuristics we can use to narrow the search. The value of $\gamma$, for example, is chosen to be large enough to maintain the capacity at the desired level (e.g. large improvements in reconstruction error should not come at the cost of breaking the capacity constraint). We found the model to be quite robust to changes in $\gamma$. As the capacity of a discrete channel is bounded, $C_c$ is chosen to be the maximum capacity of the channel, encouraging the model to use all categories of the discrete distribution. $C_z$ is more difficult to choose and is often chosen by experiment to be the largest value for which the representation is still disentangled (in the same way that $\beta$ is chosen as the lowest value for which the representation is still disentangled in β-VAE).

Experiments

We perform experiments on several datasets including MNIST, FashionMNIST, CelebA and Chairs. We parametrize the encoder by a convolutional neural network and the decoder by the same network, transposed (for the full architecture and training details see appendix). The code, along with all experiments and trained models presented in this paper, is available at https://github.com/Schlumberger/joint-vae.

MNIST

Disentanglement results and latent traversals for MNIST are shown in Fig. 2. The model was trained with 10 continuous latent variables and one 10-dimensional discrete latent variable. The model discovers several factors of variation in the data, such as digit type (discrete), stroke thickness, angle and width (continuous), in an unsupervised manner. As can be seen from the latent traversals in Fig. 2, the trained model is able to generate realistic samples for a large variety of latent settings. Fig. 4(a) shows digits generated by fixing the discrete latent and sampling the continuous latents from the prior $p(\mathbf{z}) = \mathcal{N}(0, I)$, which can be interpreted as sampling from a distribution conditioned on digit type. As can be seen, the samples are diverse, realistic and honor the conditioning.
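Conditioning on digit type amounts to building a decoder input whose discrete part is fixed and whose continuous part is drawn from the prior. A minimal sketch, assuming an exact one-hot vector is used for the fixed discrete latent (the text does not say whether a one-hot or a relaxed vector was fed to the decoder); the decoder call is omitted and the function name is hypothetical.

```python
import random


def conditional_prior_latent(category, n_continuous, n_categories, rng):
    """Latent vector with z ~ N(0, I) and the discrete latent fixed to a one-hot `category`."""
    z = [rng.gauss(0.0, 1.0) for _ in range(n_continuous)]
    c = [1.0 if k == category else 0.0 for k in range(n_categories)]
    return z + c


rng = random.Random(0)
# MNIST-sized model: 10 continuous units and one 10-way discrete latent.
# Several samples for the same category would all be decoded with the same digit type.
batch = [conditional_prior_latent(3, 10, 10, rng) for _ in range(4)]
print([len(v) for v in batch])
```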

For a large range of hyperparameters we were not able to achieve disentanglement using the purely continuous β-VAE framework (see Fig. 3). This is likely because MNIST has an inherently discrete generative factor (digit type), which β-VAE is unable to map onto a continuous latent variable. In contrast, the JointVAE approach allows us to disentangle the discrete factors while maintaining disentanglement of continuous factors. To the best of our knowledge, JointVAE is, apart from InfoGAN, the only framework which disentangles MNIST in a completely unsupervised manner, and it does so in a more stable way than InfoGAN.

FashionMNIST

Latent traversals for FashionMNIST are shown in Fig. 4(c). We also used 10 continuous latent variables and one 10-dimensional discrete latent variable for this dataset. FashionMNIST is harder to disentangle as the generative factors for creating clothes are not as clear as the ones for drawing digits. However, JointVAE performs well and discovers interesting dimensions, such as sleeve length, heel size and shirt color. As some of the classes of FashionMNIST are very similar (e.g. shirt and t-shirt are two different classes), not all classes are discovered. However, a significant number of them are disentangled, including dress, t-shirt, trousers, sneakers, bag, ankle boot and so on (see Fig. 4(b)).

CelebA

For CelebA we used a model with 32 continuous latent variables and one 10-dimensional discrete latent variable. As shown in Fig. 5, the JointVAE model discovers various factors of variation including azimuth, age and background color, while being able to generate realistic samples. Different settings of the discrete variable correspond to different facial identities. While the samples are not as sharp as those produced by entangled models, we can still see details in the images such as distinct facial features and skin tones (the trade-off between disentanglement and reconstruction quality is a well-known problem which is discussed in Higgins et al. (2016); Burgess et al. (2017); Kim & Mnih (2018); Chen et al. (2018)).

Chairs

For the chairs dataset we used a model with 32 continuous latent variables and 3 binary discrete latent variables. JointVAE discovers several factors of variation such as chair rotation, width and leg style. Furthermore, different settings of the discrete variables correspond to different chair types and colors.

While there is a well-defined discrete generative factor for datasets like MNIST and FashionMNIST, it is less clear what exactly would constitute a discrete factor of variation in datasets like CelebA and Chairs. For example, for CelebA, JointVAE maps various facial identities onto the discrete latent variable. However, facial identity is not necessarily discrete and it is possible that such a factor of variation could also be mapped to a continuous latent variable. JointVAE has a clear advantage in disentangling datasets where discrete factors are prominent (as shown in Fig. 3), but when this is not the case, using frameworks that only disentangle continuous factors may be sufficient.

1 Quantitative evaluation

We quantitatively evaluate our model on the dSprites dataset using the metric recently proposed by Kim & Mnih (2018). Since the dataset is generated from 1 discrete factor (with 3 categories) and 4 continuous factors, we used a model with 6 continuous latent variables and one 3-dimensional discrete latent variable. The results are shown in table 6. Even though the discrete factor in this dataset is not prominent (in the sense that the different categories have very small differences in pixel space), our model is able to achieve scores close to the current best models. Further, as shown in Fig. 6, our model learns meaningful latent representations. In particular, for the discrete factor of variation, JointVAE separates the classes better than other models.

2 Detecting disentanglement in latent distributions

As noted in Section 2, taken in expectation over the data, the KL divergence between the inferred latents $q_\phi(\mathbf{z}, \mathbf{c}|\mathbf{x})$ and the priors upper bounds the mutual information between the latent units and the data. Motivated by this, we can plot the KL divergence values for each latent unit averaged over a mini-batch of data during training. As various factors of variation are discovered in the data, we would expect the KL divergence of the corresponding latent units to increase. This is shown in Fig. 7(a). As the capacities $C_z$ and $C_c$ are increased, the model is able to encode more and more factors of variation. For MNIST, the first factor to be discovered is digit type, followed by angle and width. This is likely because encoding digit type results in the largest reduction in reconstruction error, followed by encoding angle and width, and so on.
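The per-unit curves described above only require averaging the closed-form Gaussian KL of each continuous unit over a mini-batch. A minimal sketch, assuming the encoder reports means and log-variances per example (the discrete units would use $\log n - H(\boldsymbol{\alpha})$ in the same way); the function name and batch values are hypothetical.

```python
import math


def mean_kl_per_unit(batch_mu, batch_logvar):
    """Average KL(N(mu, sigma^2) || N(0, 1)) of each continuous unit over a mini-batch.

    batch_mu[b][i] and batch_logvar[b][i] are the encoder outputs for example b, unit i.
    """
    n_units = len(batch_mu[0])
    totals = [0.0] * n_units
    for mu, logvar in zip(batch_mu, batch_logvar):
        for i, (m, lv) in enumerate(zip(mu, logvar)):
            totals[i] += 0.5 * (m * m + math.exp(lv) - lv - 1.0)
    return [t / len(batch_mu) for t in totals]


# Hypothetical mini-batch of two examples and two units: the first unit matches the
# prior for both examples (KL 0), the second moves its mean with the input.
print(mean_kl_per_unit([[0.0, 1.0], [0.0, -1.0]], [[0.0, 0.0], [0.0, 0.0]]))
```

Logging this vector at regular intervals during training gives one curve per latent unit.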

After training, we can also measure the KL divergence of each latent unit on test data and rank the latent units by their average KL values. This corresponds to ranking the latent units by how much information they are transmitting about $\mathbf{x}$. Fig. 7(b) shows the ranked latent units for MNIST and Chairs along with a latent traversal of each unit. As can be seen, the latent units with large information content encode various aspects of the data while latent units with approximately zero KL divergence do not affect the output.
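Ranking the units is a sort on their average KL values. The sketch below also separates out "active" units using a small KL threshold; that threshold is a hypothetical choice for illustration, since the text only distinguishes units with large KL from units with approximately zero KL.

```python
def rank_units(kl_per_unit, threshold=0.01):
    """Return unit indices sorted by decreasing average KL, and the subset above `threshold`.

    `threshold` is a hypothetical cut-off for "approximately zero" KL.
    """
    order = sorted(range(len(kl_per_unit)), key=lambda i: kl_per_unit[i], reverse=True)
    active = [i for i in order if kl_per_unit[i] > threshold]
    return order, active


# Usage with hypothetical test-set averages for four units.
order, active = rank_units([0.001, 2.3, 0.5, 0.0])
print(order, active)
```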

3 The inference network

One of the advantages of JointVAE is that it comes with an inference network $q_\phi(\mathbf{z}, \mathbf{c}|\mathbf{x})$. For example, on MNIST we can infer the digit type on test data with 88.7% accuracy by simply looking at the value of the discrete latent variable $q_\phi(\mathbf{c}|\mathbf{x})$. Note that this is completely unsupervised, and the accuracy could likely be increased dramatically by using some label information.
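Because the discrete categories are learned without labels, scoring them against digit labels requires some mapping from categories to labels. The text does not specify how this was done; the sketch below shows one common convention, assumed here for illustration: take the most probable category of $q_\phi(\mathbf{c}|\mathbf{x})$ for each test image and map every category to the label it co-occurs with most often.

```python
from collections import Counter, defaultdict


def cluster_accuracy(predicted_categories, labels):
    """Map each discrete category to its most frequent true label, then score accuracy.

    predicted_categories: argmax of q(c|x) for each example (an assumed decision rule).
    labels: ground-truth labels, used only for evaluation.
    """
    by_category = defaultdict(Counter)
    for c, y in zip(predicted_categories, labels):
        by_category[c][y] += 1
    mapping = {c: counts.most_common(1)[0][0] for c, counts in by_category.items()}
    correct = sum(mapping[c] == y for c, y in zip(predicted_categories, labels))
    return correct / len(labels), mapping


# Usage with hypothetical predictions and labels.
accuracy, mapping = cluster_accuracy([0, 0, 0, 1, 1, 2], [7, 7, 1, 3, 3, 5])
print(accuracy, mapping)
```

With as many categories as classes this mapping is close to a one-to-one matching; with many more categories than classes it would overstate accuracy.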

Since we are learning several generative factors, the inference network can also be used to infer properties which we do not have labels for. For example, the latent unit corresponding to azimuth on the chairs dataset correlates well with the actual azimuth of unseen chairs. After training a model on the chairs dataset and identifying the latent unit corresponding to azimuth, we can test the inference network on images that were not used during training. This is shown in Fig. 8(a). As can be seen, the latent unit corresponding to rotation infers the angle of the chair even though no labeled data was given (or available) for this task.

The framework can also be used to perform image editing or manipulation. If we wish to rotate the image of a face, we can encode the face with $q_\phi$, modify the latent corresponding to azimuth and decode the resulting vector with $p_\theta$. Examples of this are shown in Fig. 8(b).
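The encode, modify, decode loop can be written generically over an encoder and a decoder. In this sketch `encode` and `decode` are hypothetical callables (for instance, returning the posterior means and the decoder output); the toy linear versions in the usage lines stand in for trained networks.

```python
def edit_image(encode, decode, x, unit, new_value):
    """Encode x, overwrite one latent unit (e.g. the azimuth unit), and decode the result."""
    latent = list(encode(x))  # copy so the encoder's output is not mutated
    latent[unit] = new_value
    return decode(latent)


# Toy stand-ins for the trained encoder and decoder.
def toy_encode(x):
    return [2.0 * v for v in x]


def toy_decode(z):
    return [v / 2.0 for v in z]


print(edit_image(toy_encode, toy_decode, [1.0, 2.0], unit=0, new_value=10.0))
```

Sweeping `new_value` over a range of values for the azimuth unit yields a sequence of progressively rotated reconstructions.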

4 Robustness and sensitivity to hyperparameters

While our framework is robust with respect to different architectures and optimizers, it is, like most frameworks for unsupervised disentanglement, fairly sensitive to the choice of hyperparameters (all hyperparameters needed to reproduce the results in this paper are given in the appendix). Even with a good choice of hyperparameters, the quality of disentanglement may vary based on the random seed. In general, it is easy to achieve some degree of disentanglement for a large set of hyperparameters, but achieving complete, clean disentanglement (e.g. perfectly separating digit type from other generative factors) can be difficult. It would be interesting to explore more principled approaches for choosing the latent capacities and how to increase them, but we leave this for future work. Further, as mentioned in Section 4, when a discrete generative factor is not present or important, the framework may fail to learn meaningful discrete representations. We have included some failure examples in Fig. 9.

Conclusion

We have proposed JointVAE, a framework for learning disentangled continuous and discrete representations in an unsupervised manner. The framework comes with the advantages of VAEs such as stable training and large sample diversity while being able to model complex jointly continuous and discrete generative factors. We have shown that JointVAE disentangles factors of variation on several datasets while producing realistic samples. In addition, the inference network can be used to infer unlabeled quantities on test data and to edit and manipulate images.

In future work, it would be interesting to combine our approach with recent improvements of the β-VAE framework, such as FactorVAE (Kim & Mnih (2018)) or β-TCVAE (Chen et al. (2018)). Gaining a deeper understanding of how disentanglement depends on the latent channel capacities and how they are increased will likely provide insights to build more stable models. Finally, it would also be interesting to explore the use of other latent distributions since the framework allows the use of any joint distribution of reparametrizable random variables.

The author would like to thank Erik Burton, José Celaya, Suhas Suresha, Vishakh Hegde and the anonymous reviewers for helpful suggestions and comments that helped improve the paper.

References

Supplementary material

Appendix A Proofs

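A.1 KL divergence as an upper bound on mutual information

Writing $q(\mathbf{x}, \mathbf{z}) = p(\mathbf{x}) q_\phi(\mathbf{z}|\mathbf{x})$ and $q(\mathbf{z}) = \mathbb{E}_{p(\mathbf{x})}[q_\phi(\mathbf{z}|\mathbf{x})]$, the expected KL term decomposes as

$$\begin{aligned}
\mathbb{E}_{p(\mathbf{x})}\left[D_{KL}(q_\phi(\mathbf{z}|\mathbf{x}) \,\|\, p(\mathbf{z}))\right]
&= \mathbb{E}_{q(\mathbf{x}, \mathbf{z})}\left[\log \frac{q_\phi(\mathbf{z}|\mathbf{x})}{q(\mathbf{z})} + \log \frac{q(\mathbf{z})}{p(\mathbf{z})}\right] \\
&= I(\mathbf{x}; \mathbf{z}) + D_{KL}(q(\mathbf{z}) \,\|\, p(\mathbf{z})) \\
&\geq I(\mathbf{x}; \mathbf{z})
\end{aligned}$$

where the mutual information is defined under the joint distribution of the data and the encoding distribution.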

A.2 Splitting the discrete and continuous KL divergence terms

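Assuming the continuous and discrete latent variables are conditionally independent, i.e. $q_\phi(\mathbf{z}, \mathbf{c}|\mathbf{x}) = q_\phi(\mathbf{z}|\mathbf{x}) q_\phi(\mathbf{c}|\mathbf{x})$, and similarly for the prior, $p(\mathbf{z}, \mathbf{c}) = p(\mathbf{z}) p(\mathbf{c})$, we can rewrite the joint KL divergence as

$$\begin{aligned}
D_{KL}(q_\phi(\mathbf{z}, \mathbf{c}|\mathbf{x}) \,\|\, p(\mathbf{z}, \mathbf{c}))
&= \mathbb{E}_{q_\phi(\mathbf{z}|\mathbf{x}) q_\phi(\mathbf{c}|\mathbf{x})}\left[\log \frac{q_\phi(\mathbf{z}|\mathbf{x})}{p(\mathbf{z})} + \log \frac{q_\phi(\mathbf{c}|\mathbf{x})}{p(\mathbf{c})}\right] \\
&= D_{KL}(q_\phi(\mathbf{z}|\mathbf{x}) \,\|\, p(\mathbf{z})) + D_{KL}(q_\phi(\mathbf{c}|\mathbf{x}) \,\|\, p(\mathbf{c}))
\end{aligned}$$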

Appendix B Model architecture

The architecture of the model is shown in the table. The nonlinearities in both the encoder and decoder are ReLU, except for the output layer of the decoder, which is a sigmoid.

For 64×64 images (Chairs, CelebA and dSprites) the architecture shown in the table was used. For 32×32 images (MNIST and FashionMNIST, which were resized from 28×28) we used the same architecture with the last convolutional layer in the encoder and the first convolutional layer in the decoder removed.

Appendix C Training details

Parameters and training details for each model.

C.1 MNIST

Latent distribution: 10 continuous, 1 10-dimensional discrete

$C_z$: Increased linearly from 0 to 5 in 25000 iterations

$C_c$: Increased linearly from 0 to 5 in 25000 iterations

C.2 FashionMNIST

Latent distribution: 10 continuous, 1 10-dimensional discrete

$C_z$: Increased linearly from 0 to 5 in 50000 iterations

$C_c$: Increased linearly from 0 to 10 in 50000 iterations

C.3 Chairs

Latent distribution: 32 continuous, 3 binary discrete

$C_z$: Increased linearly from 0 to 30 in 100000 iterations

$C_c$: Increased linearly from 0 to 5 in 100000 iterations

C.4 CelebA

Latent distribution: 32 continuous, 1 10-dimensional discrete

$C_z$: Increased linearly from 0 to 50 in 100000 iterations

$C_c$: Increased linearly from 0 to 10 in 100000 iterations

C.5 dSprites

Latent distribution: 6 continuous, 1 3-dimensional discrete

$C_z$: Increased linearly from 0 to 40 in 300000 iterations

$C_c$: Increased linearly from 0 to 1.1 in 300000 iterations

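Note that since the KL divergence between a categorical variable and a uniform categorical variable is bounded, the discrete capacity is clipped during training if $C_c$ exceeds the maximum capacity. Let $P$ denote a categorical random variable with class probabilities $p_1, \ldots, p_n$ and let $Q$ be a uniform categorical variable over the same $n$ categories; then, since the entropy $H(P)$ is non-negative,

$$D_{KL}(P \,\|\, Q) = \sum_{i=1}^{n} p_i \log \frac{p_i}{1/n} = \log n - H(P) \leq \log n$$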

During training $C_c$ is then clipped as $C_c = \min(C_c, \log n)$.
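The schedules above combined with this clipping rule can be sketched as follows, to see where each run's discrete capacity actually ends up. For models with several discrete latents (Chairs uses 3 binary variables) the sketch assumes the bound is the sum of $\log n_i$ over the independent variables, which follows from the per-variable KL terms adding up; the text itself only states the single-variable bound $\log n$.

```python
import math


def max_discrete_capacity(category_sizes):
    """Upper bound on the discrete KL: sum of log n_i over independent discrete latents."""
    return sum(math.log(n) for n in category_sizes)


def clipped_capacity(step, c_max, num_steps, category_sizes):
    """Linear C_c schedule from 0 to c_max over num_steps, clipped at the maximum capacity."""
    c = min(step / num_steps, 1.0) * c_max
    return min(c, max_discrete_capacity(category_sizes))


# (final C_c, iterations, sizes of the discrete latents), taken from the training details above.
schedules = {
    "MNIST": (5.0, 25000, [10]),
    "FashionMNIST": (10.0, 50000, [10]),
    "Chairs": (5.0, 100000, [2, 2, 2]),
    "CelebA": (10.0, 100000, [10]),
    "dSprites": (1.1, 300000, [3]),
}
for name, (c_max, steps, sizes) in schedules.items():
    print(f"{name}: C_c at end of schedule = {clipped_capacity(steps, c_max, steps, sizes):.3f}")
```

For dSprites, for example, the target of 1.1 slightly exceeds $\log 3 \approx 1.099$, so the clipped value is what is used in practice.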

Appendix D Things that didn’t work

We experimented with several things which we found did not improve disentanglement of joint continuous and discrete representations.

Modifying the latent distribution in β-VAE to include a joint Gaussian and Gumbel-Softmax distribution without changing the loss. This generally resulted in the model ignoring the discrete codes.

Changing the loss function to have a higher $\beta$ on the continuous KL term and a lower $\beta$ on the discrete KL term. For a large number of combinations of $\beta$ values, we found that the model either ignored the discrete latent codes or produced representations where continuous factors were encoded in the discrete latent variables.

In β-VAE, there is a larger weight on the KL term than in a traditional VAE model. In most VAE models with a Gumbel-Softmax latent variable, the KL divergence between the Gumbel-Softmax variables is approximated by the KL divergence between the corresponding categorical variables. We hypothesized that the approximation error might be worse in β-VAE, since there is a larger weight on the KL term. Unfortunately, there is no closed-form expression for the KL divergence between two Gumbel-Softmax variables. We used various approximations of this, but most estimates had very high variance and impeded learning in the model.

Appendix E Choice of discrete dimensions

As discussed in the main section of the paper, it is not clear what exactly would constitute a discrete factor of variation for a dataset like CelebA, for example. As such, the choice is somewhat arbitrary: when using a 10-dimensional discrete latent variable, the model encodes 10 facial identities, and when using more than 10 dimensions it encodes more identities. This was generally found to be quite robust, except when the number of discrete dimensions was exceedingly large (>100), in which case the model would start to encode e.g. facial angles in the discrete dimensions. Similarly, we chose a 10-dimensional discrete latent variable for MNIST because we know a priori that there are 10 types of digits. When choosing fewer than 10 discrete dimensions on MNIST, the model tends to fuse digit types which look similar into one discrete dimension. For example, 4 and 9 or 5 and 8 may correspond to one discrete dimension instead of being separated. When using more than 10 dimensions, the model tends to separate different writing styles of digits into separate dimensions, e.g. 2's with and without a curl at the bottom or 7's with and without a middle stroke may be encoded into different categories.

Appendix F Comparison with InfoGAN

We include comparisons with InfoGAN, which can also disentangle joint continuous and discrete factors. InfoGAN successfully disentangles digit type from angle and width. However, width and stroke thickness remain entangled. Further, InfoGAN models are typically less stable to train.

Appendix G Latent traversals

In all figures, latent traversals of continuous variables are from $\Phi^{-1}(0.05)$ to $\Phi^{-1}(0.95)$, where $\Phi^{-1}$ is the inverse CDF of a unit normal. Latent traversals of discrete variables are from 1 to the number of dimensions of the variable.
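The traversal grid can be generated with the standard library's `statistics.NormalDist`. The text gives only the endpoints, so spacing the steps evenly in probability (rather than in latent value) is an assumption of this sketch; discrete traversals are written as one one-hot vector per category.

```python
from statistics import NormalDist


def continuous_traversal(num_steps):
    """Values from Phi^{-1}(0.05) to Phi^{-1}(0.95), assumed evenly spaced in probability."""
    std_normal = NormalDist(0.0, 1.0)
    quantiles = [0.05 + 0.9 * k / (num_steps - 1) for k in range(num_steps)]
    return [std_normal.inv_cdf(q) for q in quantiles]


def discrete_traversal(n_categories):
    """One one-hot vector per category, from the first dimension to the last."""
    return [[1.0 if j == k else 0.0 for j in range(n_categories)] for k in range(n_categories)]


# Usage: an 8-step traversal for one continuous unit (endpoints are about -1.645 and +1.645).
print([round(v, 3) for v in continuous_traversal(8)])
print(discrete_traversal(3))
```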