Deep Variational Reinforcement Learning for POMDPs
Maximilian Igl, Luisa Zintgraf, Tuan Anh Le, Frank Wood, Shimon Whiteson
1 Introduction
Most deep reinforcement learning (rl) methods assume that the state of the environment is fully observable at every time step. However, this assumption often does not hold in reality, as occlusions and noisy sensors may limit the agent’s perceptual abilities. Such problems can be formalised as partially observable Markov decision processes (pomdps) (Astrom, 1965; Kaelbling et al., 1998). Because we usually do not have access to the true generative model of our environment, there is a need for reinforcement learning methods that can tackle pomdps when only a stream of observations is given, without any prior knowledge of the latent state space or the transition and observation functions.
Pomdps are notoriously hard to solve: since the current observation does not in general carry all relevant information for choosing an action, information must be aggregated over time and, in general, the entire history must be taken into account.
This history can be encoded either by remembering features of the past (McCallum, 1993) or by performing inference to determine the distribution over possible latent states (Kaelbling et al., 1998). However, the computation of this belief state requires knowledge of the model.
Most previous work in deep learning relies on training a recurrent neural network (rnn) to summarize the past. Examples are the deep recurrent Q-network (drqn) (Hausknecht & Stone, 2015) and the action-specific deep recurrent Q-network (adrqn) (Zhu et al., 2017). Because these approaches are completely model-free, they place a heavy burden on the rnn. Since performing inference implicitly requires a known or learned model, they are likely to summarise the history either by only remembering features of the past or by computing simple heuristics instead of actual belief states. This is often suboptimal in complex tasks. Generalisation is also often easier over beliefs than over trajectories since distinct histories can lead to similar or identical beliefs.
The premise of this work is that deep policy learning for pomdps can be improved by taking less of a black box approach than drqn and adrqn. While we do not want to assume prior knowledge of the transition and observation functions or the latent state representation, we want to allow the agent to learn models of them and infer the belief state using this learned model.
To this end, we propose dvrl, which implements this approach by providing a helpful inductive bias to the agent. In particular, we develop an algorithm that can learn an internal generative model and use it to perform approximate inference to update the belief state. Crucially, the generative model is not only learned based on an elbo objective, but also by how well it enables maximisation of the expected return. This ensures that, unlike in an unsupervised application of variational autoencoders (vaes), the latent state representation and the inference performed on it are suitable for the ultimate control task. Specifically, we develop an approximation to the elbo based on autoencoding sequential Monte Carlo (aesmc) (Le et al., 2018), allowing joint optimisation with the $n$-step policy gradient update. Uncertainty in the belief state is captured by a particle ensemble. A high-level overview of our approach in comparison to previous rnn-based methods is shown in Figure 1.
We evaluate our approach on Mountain Hike and several flickering Atari games. On Mountain Hike, a low dimensional, continuous environment, we can show that dvrl is better than an rnn based approach at inferring the required information from past observations for optimal action selection in a simple setting. Our results on flickering Atari show that this advantage extends to complex environments with high dimensional observation spaces. Here, partial observability is introduced by (1) using only a single frame as input at each time step and (2) returning a blank screen instead of the true frame with probability 0.5.
2 Background
In this section, we formalise pomdps and provide background on recent advances in vaes that we use. Lastly, we describe the policy gradient loss based on $n$-step learning and A2C.
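An agent acts according to its policy $\pi(a_t \mid o_{\leq t}, a_{<t})$, which returns the probability of taking action $a_t$ at time $t$, where $o_{\leq t} = (o_1, \dots, o_t)$ and $a_{<t} = (a_0, \dots, a_{t-1})$ are the observation and action histories, respectively. The agent’s goal is to learn a policy that maximises the expected future return
$$J = \mathbb{E}_{p(\tau)}\left[\sum_{t=1}^{T} \gamma^{t-1} r_t\right] \tag{1}$$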
over trajectories $\tau$ induced by its policy, where $\gamma$ is the discount factor. (The trajectory length $T$ is stochastic and depends on the time at which the agent-environment interaction ends.) We follow the convention of setting $a_0$ to no-op (Zhu et al., 2017).
In general, a pomdp agent must condition its actions on the entire history $(o_{\leq t}, a_{<t}) \in \mathcal{H}_t$. The exponential growth in $t$ of $\mathcal{H}_t$ can be addressed, e.g., with suffix trees (McCallum & Ballard, 1996; Shani et al., 2005; Bellemare et al., 2014; Bellemare, 2015; Messias & Whiteson, 2017). However, those approaches suffer from large memory requirements and are only suitable for small discrete observation spaces.
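Alternatively, it is possible to infer the filtering distribution $b_t(s_t) := p(s_t \mid o_{\leq t}, a_{<t})$, called the belief state. This is a sufficient statistic of the history that can be used as input to an optimal policy $\pi^\star(a_t \mid b_t)$. The belief space does not grow exponentially, but the belief update step requires knowledge of the model:
$$b_{t+1}(s_{t+1}) \propto \int b_t(s_t)\, U(o_{t+1} \mid s_{t+1}, a_t)\, F(s_{t+1} \mid s_t, a_t)\, \mathrm{d}s_t, \tag{2}$$
where $F$ and $U$ denote the transition and observation distributions of the pomdp.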
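The belief update can be made concrete for a small discrete pomdp, where the integral becomes a sum over states. The following is a minimal sketch in plain Python, assuming the transition and observation distributions are given as nested lists indexed by action (a hypothetical representation chosen only for illustration); it makes explicit that both distributions must be known to perform the update.

```python
def belief_update(belief, action, obs, transition, observation):
    """Exact belief update b_t -> b_{t+1} for a small discrete POMDP.

    belief[s]                 : current belief b_t(s)
    transition[a][s][s_next]  : transition probability F(s_next | s, a)
    observation[a][s_next][o] : observation probability U(o | s_next, a)
    """
    n_states = len(belief)
    unnormalised = []
    for s_next in range(n_states):
        # Prediction step: marginalise over the previous state.
        predicted = sum(belief[s] * transition[action][s][s_next] for s in range(n_states))
        # Correction step: weight by the likelihood of the new observation.
        unnormalised.append(observation[action][s_next][obs] * predicted)
    total = sum(unnormalised)  # normalising constant p(o_{t+1} | history)
    return [p / total for p in unnormalised]


# Two states, one action; observing o = 0 shifts the belief towards state 0.
F = [[[0.9, 0.1], [0.1, 0.9]]]
U = [[[0.8, 0.2], [0.3, 0.7]]]
b = belief_update([0.5, 0.5], action=0, obs=0, transition=F, observation=U)
# b is approximately [0.727, 0.273]
```

Each update needs only the previous belief, so its cost does not grow with $t$; however, it uses `transition` and `observation` directly, which is exactly what is unavailable when the model is unknown.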
2.2 Variational Autoencoder
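We define a family of priors $p_\theta(s)$ over a latent state $s$ and decoders $p_\theta(o \mid s)$ over observations $o$, both parameterised by $\theta$. A variational autoencoder (vae) learns $\theta$ from a dataset of observations by maximising, in place of the intractable log marginal likelihood $\log p_\theta(o) = \log \int p_\theta(o \mid s)\, p_\theta(s)\, \mathrm{d}s$, the evidence lower bound (elbo)
$$\mathrm{ELBO}(\theta, \phi, o) = \mathbb{E}_{q_\phi(s \mid o)}\left[\log \frac{p_\theta(o \mid s)\, p_\theta(s)}{q_\phi(s \mid o)}\right] \tag{3}$$
for a family of encoders $q_\phi(s \mid o)$ parameterised by $\phi$. This objective also forces $q_\phi(s \mid o)$ to approximate the posterior $p_\theta(s \mid o)$ under the learned model. Gradients of (3) are estimated by Monte Carlo sampling with the reparameterisation trick (Kingma & Welling, 2014; Rezende et al., 2014).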
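A minimal sketch of a Monte Carlo elbo estimate with the reparameterisation trick, assuming a hypothetical one-dimensional model with a standard normal prior and a unit-variance Gaussian decoder (chosen because its posterior is known in closed form). Plain Python has no automatic differentiation, so the sketch only evaluates the estimator; in a framework such as PyTorch, gradients would flow through the sample `s` into the encoder parameters.

```python
import math
import random


def log_normal(x, mean, var):
    """Log-density of the univariate normal N(mean, var) evaluated at x."""
    return -0.5 * (math.log(2 * math.pi * var) + (x - mean) ** 2 / var)


def elbo_estimate(o, q_mean, q_std, num_samples=1000, seed=0):
    """Monte Carlo ELBO for a toy model p(s) = N(0, 1), p(o | s) = N(s, 1).

    The encoder is q(s | o) = N(q_mean, q_std**2). Samples are drawn with the
    reparameterisation s = q_mean + q_std * eps, eps ~ N(0, 1), so that in an
    autodiff framework the estimate would be differentiable in q_mean and q_std.
    """
    rng = random.Random(seed)
    total = 0.0
    for _ in range(num_samples):
        eps = rng.gauss(0.0, 1.0)
        s = q_mean + q_std * eps
        log_joint = log_normal(o, s, 1.0) + log_normal(s, 0.0, 1.0)
        total += log_joint - log_normal(s, q_mean, q_std ** 2)
    return total / num_samples


# With the exact posterior p(s | o) = N(o / 2, 1 / 2) every sample gives log p(o).
o = 1.3
exact_log_marginal = log_normal(o, 0.0, 2.0)
tight = elbo_estimate(o, o / 2, math.sqrt(0.5))
loose = elbo_estimate(o, 0.0, 1.0)
```

Here `tight` matches `exact_log_marginal` up to rounding, while `loose` falls below it by roughly the KL divergence between the encoder and the true posterior, which illustrates why maximising the elbo also pulls $q_\phi$ towards the posterior.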
2.3 VAE for Time Series
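For sequential data, we assume that a series of latent states $s_{1:T}$ gives rise to a series of observations $o_{1:T}$. We consider a family of generative models parameterised by $\theta$ that consists of the initial distribution $p_\theta(s_1)$, transition distribution $p_\theta(s_t \mid s_{t-1})$ and observation distribution $p_\theta(o_t \mid s_t)$. Given a family of encoder distributions $q_\phi(s_t \mid s_{t-1}, o_t)$, we can also estimate the gradient of the elbo term in the same manner as in (3), noting that:
$$p_\theta(s_{1:T}, o_{1:T}) = p_\theta(s_1)\, p_\theta(o_1 \mid s_1) \prod_{t=2}^{T} p_\theta(s_t \mid s_{t-1})\, p_\theta(o_t \mid s_t), \tag{4}$$
$$q_\phi(s_{1:T} \mid o_{1:T}) = p_\theta(s_1) \prod_{t=2}^{T} q_\phi(s_t \mid s_{t-1}, o_t), \tag{5}$$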
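where we slightly abuse notation for $q_\phi$ by ignoring the fact that we sample from the model $p_\theta(s_1)$ for $t = 1$. Le et al. (2018), Maddison et al. (2017) and Naesseth et al. (2018) introduce a new elbo objective based on sequential Monte Carlo (smc) (Doucet & Johansen, 2009) that allows faster learning in time series:
$$\mathrm{ELBO}_{\mathrm{SMC}}(\theta, \phi, o_{1:T}) = \mathbb{E}\left[\sum_{t=1}^{T} \log\left(\frac{1}{K}\sum_{k=1}^{K} w_t^k\right)\right], \tag{6}$$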
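where $K$ is the number of particles and $w_t^k$ is the weight of particle $k$ at time $t$. Each particle is a tuple containing a weight $w_t^k$ and a value $s_t^k$, which is obtained as follows. Let $s_1^k$ be samples from $p_\theta(s_1)$ for $k = 1, \dots, K$, with weights $w_1^k = p_\theta(o_1 \mid s_1^k)$. For $t = 2, \dots, T$, the weights $w_t^{1:K}$ are obtained by resampling the particle set $(s_{t-1}^k)_{k=1}^{K}$ proportionally to the previous weights and computing
$$s_t^k \sim q_\phi(s_t \mid s_{t-1}^{u_{t-1}^k}, o_t), \qquad w_t^k = \frac{p_\theta(s_t^k \mid s_{t-1}^{u_{t-1}^k})\, p_\theta(o_t \mid s_t^k)}{q_\phi(s_t^k \mid s_{t-1}^{u_{t-1}^k}, o_t)}, \tag{7}$$
where $s_{t-1}^{u_{t-1}^k}$ is the resampled particle with ancestor index $u_{t-1}^k \sim \mathrm{Discrete}(\cdot \mid w_{t-1}^{1:K})$.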
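The particle recursion above can be sketched for a hypothetical one-dimensional linear-Gaussian model. The model, the proposal and all function names are assumptions for illustration, and weights are kept in log space for numerical stability. For this particular model the chosen proposal happens to coincide with the locally optimal one, so each weight depends only on its ancestor particle.

```python
import math
import random


def log_normal(x, mean, var):
    """Log-density of the univariate normal N(mean, var) at x."""
    return -0.5 * (math.log(2 * math.pi * var) + (x - mean) ** 2 / var)


def log_mean_exp(log_ws):
    """Numerically stable log((1/K) * sum_k exp(log_ws[k]))."""
    m = max(log_ws)
    return m + math.log(sum(math.exp(lw - m) for lw in log_ws) / len(log_ws))


def smc_elbo(observations, num_particles=100, seed=0):
    """Single-sample SMC estimate of the ELBO for a toy linear-Gaussian model.

    Hypothetical model:    p(s_1) = N(0, 1), p(s_t | s_{t-1}) = N(0.9 s_{t-1}, 1),
                           p(o_t | s_t) = N(s_t, 1).
    Hypothetical proposal: q(s_t | s_{t-1}, o_t) = N((0.9 s_{t-1} + o_t) / 2, 0.5).
    """
    rng = random.Random(seed)
    K = num_particles
    # t = 1: sample from the model's initial distribution, weight by p(o_1 | s_1).
    particles = [rng.gauss(0.0, 1.0) for _ in range(K)]
    log_ws = [log_normal(observations[0], s, 1.0) for s in particles]
    elbo = log_mean_exp(log_ws)
    for o in observations[1:]:
        # Resample ancestor indices proportionally to the previous weights.
        m = max(log_ws)
        ancestors = rng.choices(range(K), weights=[math.exp(lw - m) for lw in log_ws], k=K)
        new_particles, new_log_ws = [], []
        for u in ancestors:
            prev = particles[u]
            q_mean, q_var = (0.9 * prev + o) / 2, 0.5
            s = rng.gauss(q_mean, math.sqrt(q_var))
            # w = p(s | s_prev) p(o | s) / q(s | s_prev, o), computed in log space.
            log_w = (log_normal(s, 0.9 * prev, 1.0) + log_normal(o, s, 1.0)
                     - log_normal(s, q_mean, q_var))
            new_particles.append(s)
            new_log_ws.append(log_w)
        particles, log_ws = new_particles, new_log_ws
        elbo += log_mean_exp(log_ws)  # accumulate sum_t log((1/K) sum_k w_t^k)
    return elbo
```

The returned value is one sample of the quantity inside the expectation of the smc elbo; averaging over seeds would approximate the expectation itself.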
2.4 A2C
One way to learn the parameters of an agent’s policy is to use $n$-step learning with A2C (Dhariwal et al., 2017; Wu et al., 2017), the synchronous simplification of asynchronous advantage actor-critic (A3C) (Mnih et al., 2016). An actor-critic approach can cope with continuous actions and avoids the need to draw state-action sequences from a replay buffer. The method proposed in this paper is, however, equally applicable to other deep rl algorithms.
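For $n$-step learning, starting at time $t$, the current policy performs $n$ consecutive steps in $n_e$ parallel environments. The gradient update is based on this mini-batch of size $n \times n_e$. The target for the value function $V_\eta(o_{\leq t}, a_{<t})$, parameterised by $\eta$, is the appropriately discounted sum of on-policy rewards up until time $t+n$ and the off-policy bootstrapped value $V_\eta^-(o_{\leq t+n}, a_{<t+n})$. The minus sign denotes that no gradients are propagated through this value. Defining the advantage function as
$$A_t^{n\text{-step}}(o_{\leq t}, a_{<t}; \eta) = \sum_{i=0}^{n-1} \gamma^{i} r_{t+i+1} + \gamma^{n} V_\eta^-(o_{\leq t+n}, a_{<t+n}) - V_\eta(o_{\leq t}, a_{<t}), \tag{8}$$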
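the A2C loss for the policy parameters $\rho$ at time $t$ is
$$\mathcal{L}_t^{\mathrm{A2C}}(\rho) = -A_t^{n\text{-step}}(o_{\leq t}, a_{<t}; \eta)\, \log \pi_\rho(a_t \mid o_{\leq t}, a_{<t}), \tag{9}$$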
and the value function loss to learn $\eta$ can be written as
$$\mathcal{L}_t^{V}(\eta) = A_t^{n\text{-step}}(o_{\leq t}, a_{<t}; \eta)^2. \tag{10}$$
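Lastly, an entropy loss is added to encourage exploration:
$$\mathcal{L}_t^{H}(\rho) = -\mathcal{H}\big(\pi_\rho(\cdot \mid o_{\leq t}, a_{<t})\big), \tag{11}$$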
where $\mathcal{H}(\cdot)$ is the entropy of a distribution.
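A minimal sketch of how the $n$-step targets, advantages and the three A2C losses could be computed for one environment's rollout of $n$ steps. The function names are hypothetical, `rewards[i]` is assumed to be the reward that follows the action at step $t+i$, and the losses are averaged over the rollout (an assumption; implementations differ in how they normalise). Without automatic differentiation the advantages are plain numbers here, which mirrors the fact that no gradients flow through the bootstrapped value or through the advantage in the policy loss.

```python
def n_step_advantages(rewards, values, bootstrap_value, gamma):
    """n-step targets and advantages for a single environment's rollout.

    rewards[i]      : reward received after the action taken at step t + i
    values[i]       : critic estimate V(o_{<=t+i}, a_{<t+i})
    bootstrap_value : V^-(o_{<=t+n}, a_{<t+n}), treated as a constant
    """
    ret = bootstrap_value
    targets = [0.0] * len(rewards)
    for i in reversed(range(len(rewards))):
        # Discounted rewards up to time t + n, then the bootstrapped value.
        ret = rewards[i] + gamma * ret
        targets[i] = ret
    advantages = [g - v for g, v in zip(targets, values)]
    return targets, advantages


def a2c_losses(advantages, log_probs, entropies):
    """Policy, value and entropy losses, averaged over the rollout.

    In an autodiff framework the advantages would be detached in the policy
    loss, while the value loss would backpropagate into the critic.
    """
    n = len(advantages)
    policy_loss = -sum(a * lp for a, lp in zip(advantages, log_probs)) / n
    value_loss = sum(a * a for a in advantages) / n
    entropy_loss = -sum(entropies) / n
    return policy_loss, value_loss, entropy_loss


targets, adv = n_step_advantages([1.0, 0.0, 1.0], [0.5, 0.5, 0.5],
                                 bootstrap_value=2.0, gamma=0.5)
# targets == [1.5, 1.0, 2.0], adv == [1.0, 0.5, 1.5]
```

Iterating backwards lets each step reuse the return of the following step, so earlier steps in the rollout automatically receive longer (up to $n$-step) returns.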
3 Deep Variational Reinforcement Learning
Fundamentally, there are two approaches to aggregating the history in the presence of partial observability: remembering features of the past or maintaining beliefs.
In most previous work, including adrqn (Zhu et al., 2017), the current history $(o_{\leq t}, a_{<t})$ is encoded by an rnn, which leads to the recurrent update equation for the latent state $h_t$:
$$h_t = \psi^{\mathrm{RNN}}_\phi(h_{t-1}, a_{t-1}, o_t). \tag{12}$$
Since this approach is model-free, it is unlikely to approximate belief update steps, instead relying on memory or simple heuristics.
Inspired by the premise that a good way to solve many pomdps involves (1) estimating the transition and observation model of the environment, (2) performing inference under this model, and (3) choosing an action based on the inferred belief state, we propose deep variational reinforcement learning (dvrl). It extends the rnn-based approach to explicitly support belief inference. Training everything end-to-end shapes the learned model to be useful for the rl task at hand, and not only for predicting observations.
We first explain our baseline architecture and training method in Section 3.1. For a fair comparison, we modify the original architecture of Zhu et al. (2017) in several ways. We find that our new baseline outperforms their reported results in the majority of cases.
In Sections 3.2 and 3.3, we explain our new latent belief state $\hat{b}_t$ and the recurrent update function
$$\hat{b}_t = \psi^{\mathrm{DVRL}}_{\theta, \phi}(\hat{b}_{t-1}, a_{t-1}, o_t), \tag{13}$$
which replaces (12). Lastly, in Section 3.4, we describe our modified loss function, which allows learning the model jointly with the policy.
While previous work often used $Q$-learning to train the policy (Hausknecht & Stone, 2015; Zhu et al., 2017; Foerster et al., 2016; Narasimhan et al., 2015), we use $n$-step A2C. This avoids drawing entire trajectories from a replay buffer and allows continuous actions.
Furthermore, since A2C interleaves unrolled trajectories and performs a parameter update only every $n$ steps, it makes it feasible to maintain an approximately correct latent state. A small bias is introduced by not recomputing the latent state after each gradient update step.
We also modify the implementation of backpropagation-through-time (bptt) for $n$-step A2C in the case of policies with latent states. Instead of backpropagating gradients only through the computation graph of the current update involving $n$ steps, we set the size of the computation graph independently to involve $n_s$ steps. This leads to an average bptt-length of $(n_s - 1)/2$. (This is implemented in PyTorch using the retain_graph=True flag in the backward() function.) This decouples the bias-variance tradeoff of choosing $n$ from the bias-runtime tradeoff of choosing $n_s$. Our experiments show that choosing $n_s > n$ greatly improves the agent’s performance.
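To see where the average bptt-length comes from, the following sketch counts, for each environment step, how many earlier steps remain in the retained computation graph. It assumes (hypothetically) that $n_s$ is a multiple of $n$ and that the graph is cut exactly every $n_s$ steps; under these assumptions the average is $(n_s - 1)/2$, independent of $n$.

```python
def bptt_lengths(n, n_s, total_steps):
    """For each step, the number of earlier steps its loss can backpropagate through.

    Hypothetical assumptions: a parameter update happens every n steps using
    backward(retain_graph=True), n_s is a multiple of n, and the latent state is
    detached from the graph only every n_s steps.
    """
    assert n_s % n == 0, "this sketch assumes the graph is cut on an update boundary"
    lengths = []
    steps_since_cut = 0
    for step in range(total_steps):
        lengths.append(steps_since_cut)
        steps_since_cut += 1
        if (step + 1) % n_s == 0:
            steps_since_cut = 0  # graph detached: gradients stop here
    return lengths


for n_s in (5, 25):
    lengths = bptt_lengths(n=5, n_s=n_s, total_steps=100)
    print(n_s, sum(lengths) / len(lengths))  # prints "5 2.0", then "25 12.0"
```

With $n_s = n$ this reduces to the usual per-update truncation; larger $n_s$ lengthens the gradient path without changing how often parameters are updated.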
3.2 Extending the Latent State
For dvrl, we extend the latent state to be a set of $K$ particles, capturing the uncertainty in the belief state (Thrun, 2000; Silver & Veness, 2010). Each particle consists of the triplet $(h_t^k, z_t^k, w_t^k)$ (Chung et al., 2015). The value $h_t^k$ of particle $k$ is the latent state of an rnn; $z_t^k$ is an additional stochastic latent state that allows us to learn stochastic transition models; and $w_t^k$ assigns each particle an importance weight.
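Our latent state $\hat{b}_t$ is thus an approximation of the belief state in our learned model
$$p_\theta(h_{1:T}, z_{1:T}, o_{1:T} \mid a_{0:T-1}) = \prod_{t=1}^{T} p_\theta(z_t \mid h_{t-1}, a_{t-1})\, p_\theta(o_t \mid h_{t-1}, z_t, a_{t-1})\, \delta_{\psi^{\mathrm{RNN}}_\theta(h_{t-1}, z_t, a_{t-1}, o_t)}(h_t), \tag{14}$$
with stochastic transition model $p_\theta(z_t \mid h_{t-1}, a_{t-1})$, decoder $p_\theta(o_t \mid h_{t-1}, z_t, a_{t-1})$, and deterministic transition function $h_t = \psi^{\mathrm{RNN}}_\theta(h_{t-1}, z_t, a_{t-1}, o_t)$, which is denoted using the delta-distribution $\delta$ and for which we use an rnn. The model is trained to jointly optimise the elbo and the expected return.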
3.3 Recurrent Latent State Update
To update the latent state, we proceed as follows:
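First, we resample particles based on their weight by drawing ancestor indices
$$u_{t-1}^k \sim \mathrm{Discrete}\left(\cdot \,\middle|\, \frac{w_{t-1}^{1:K}}{\sum_{j=1}^{K} w_{t-1}^{j}}\right). \tag{15}$$
This improves model learning (Le et al., 2018; Maddison et al., 2017) and allows us to train the model jointly with the $n$-step loss (see Section 3.4).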
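For $k = 1, \dots, K$, new values for $z_t^k$ are sampled from the encoder
$$z_t^k \sim q_\phi(z_t^k \mid h_{t-1}^{u_{t-1}^k}, a_{t-1}, o_t), \tag{16}$$
which conditions on the resampled ancestor values $h_{t-1}^{u_{t-1}^k}$ as well as the last action $a_{t-1}$ and the current observation $o_t$. Latent variables are sampled using the reparameterisation trick. The values $z_t^k$, together with $h_{t-1}^{u_{t-1}^k}$, $a_{t-1}$ and $o_t$, are then passed to the transition function to compute
$$h_t^k = \psi^{\mathrm{RNN}}_\theta(h_{t-1}^{u_{t-1}^k}, z_t^k, a_{t-1}, o_t). \tag{17}$$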
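The weights
$$w_t^k = \frac{p_\theta(z_t^k \mid h_{t-1}^{u_{t-1}^k}, a_{t-1})\, p_\theta(o_t \mid h_{t-1}^{u_{t-1}^k}, z_t^k, a_{t-1})}{q_\phi(z_t^k \mid h_{t-1}^{u_{t-1}^k}, a_{t-1}, o_t)} \tag{18}$$
measure how likely each new latent state value is under the model and how well it explains the current observation.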
To condition the policy on the belief $\hat{b}_t = (h_t^k, z_t^k, w_t^k)_{k=1}^{K}$, we need to encode the set of particles into a vector representation $\hat{h}_t$. We use a second rnn that sequentially takes in each tuple $(h_t^k, z_t^k, w_t^k)$; its last latent state is $\hat{h}_t$.
Additional encoders are used for $a_t$, $o_t$ and $z_t$; see Appendix A for details. Figure 2 summarises the entire update step.
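The update can be summarised in a short sketch with scalar particles. All functions below (`prior`, `decoder`, `proposal`, `rnn`, `encode_belief`) are hypothetical stand-ins for the learned networks, not the architecture of Appendix A; weights are stored as log-weights, and the second rnn over particles is replaced by a simple recurrent fold over normalised weights (an assumption).

```python
import math
import random


def log_normal(x, mean, var):
    """Log-density of the univariate normal N(mean, var) at x."""
    return -0.5 * (math.log(2 * math.pi * var) + (x - mean) ** 2 / var)


# Hypothetical scalar stand-ins for the learned networks; each returns (mean, var).
def prior(h, a):            # transition model p_theta(z_t | h_{t-1}, a_{t-1})
    return 0.5 * h + a, 1.0


def decoder(h, z, a):       # decoder p_theta(o_t | h_{t-1}, z_t, a_{t-1})
    return h + z, 0.5


def proposal(h, a, o):      # encoder q_phi(z_t | h_{t-1}, a_{t-1}, o_t)
    return 0.5 * (0.5 * h + a) + 0.5 * (o - h), 0.5


def rnn(h, z, a, o):        # deterministic transition psi^RNN_theta
    return math.tanh(0.8 * h + z + 0.1 * a + 0.1 * o)


def dvrl_update(hs, log_ws, a, o, rng):
    """One belief update over K particles; returns (hs, zs, log_ws, ancestors)."""
    K = len(hs)
    m = max(log_ws)
    # 1. Resample ancestor indices proportionally to the previous weights.
    ancestors = rng.choices(range(K), weights=[math.exp(lw - m) for lw in log_ws], k=K)
    new_hs, new_zs, new_log_ws = [], [], []
    for u in ancestors:
        h_prev = hs[u]
        # 2. Reparameterised sample z = mean + std * eps from the proposal.
        q_mean, q_var = proposal(h_prev, a, o)
        z = q_mean + math.sqrt(q_var) * rng.gauss(0.0, 1.0)
        # 3. Importance weight: prior(z) * decoder(o | z) / proposal(z), in log space.
        p_mean, p_var = prior(h_prev, a)
        d_mean, d_var = decoder(h_prev, z, a)
        new_log_ws.append(log_normal(z, p_mean, p_var) + log_normal(o, d_mean, d_var)
                          - log_normal(z, q_mean, q_var))
        # 4. Deterministic rnn transition to the new h.
        new_hs.append(rnn(h_prev, z, a, o))
        new_zs.append(z)
    return new_hs, new_zs, new_log_ws, ancestors


def encode_belief(hs, zs, log_ws):
    """Toy stand-in for the second rnn: folds (h, z, normalised w) tuples in turn."""
    m = max(log_ws)
    total = sum(math.exp(lw - m) for lw in log_ws)
    h_hat = 0.0
    for h, z, lw in zip(hs, zs, log_ws):
        h_hat = math.tanh(0.5 * h_hat + h + 0.5 * z + math.exp(lw - m) / total)
    return h_hat


rng = random.Random(0)
hs, log_ws = [0.0] * 8, [0.0] * 8
for a, o in [(0.0, 0.5), (1.0, -0.3), (0.0, 1.2)]:
    hs, zs, log_ws, _ = dvrl_update(hs, log_ws, a, o, rng)
h_hat = encode_belief(hs, zs, log_ws)  # summary that the policy and value function would read
```

Because the weights are carried from one step to the next and only used for resampling at the start of the following update, the particle set at any time is a weighted approximation of the belief in the learned model.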
3.4 Loss Function
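To encourage learning a model, we include the term
$$\mathcal{L}_t^{\mathrm{ELBO}}(\theta, \phi) = -\frac{1}{n}\sum_{i=0}^{n-1} \log\left(\frac{1}{K}\sum_{k=1}^{K} w_{t+i}^k\right) \tag{19}$$
in each gradient update every $n$ steps. This leads to the overall loss:
$$\mathcal{L}_t^{\mathrm{DVRL}}(\rho, \eta, \theta, \phi) = \mathcal{L}_t^{\mathrm{A2C}}(\rho, \theta, \phi) + \lambda^{H}\mathcal{L}_t^{H}(\rho, \theta, \phi) + \lambda^{V}\mathcal{L}_t^{V}(\eta, \theta, \phi) + \lambda^{E}\mathcal{L}_t^{\mathrm{ELBO}}(\theta, \phi). \tag{20}$$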
Compared to (9), (10) and (11), the losses now also depend on the encoder parameters $\phi$ and, for dvrl, model parameters $\theta$, since the policy and value function now condition on the latent states instead of $(o_{\leq t}, a_{<t})$. By introducing the $n$-step approximation $\mathcal{L}_t^{\mathrm{ELBO}}$, we can learn $\theta$ and $\phi$ to jointly optimise the elbo and the rl losses (9), (10) and (11).
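A minimal sketch of the $n$-step elbo term and of the weighted combination of losses, assuming the per-step log-weights of all $K$ particles have been stored during the rollout. The function names and the coefficient names `lam_v`, `lam_h` and `lam_e` are hypothetical.

```python
import math


def log_mean_exp(log_ws):
    """Numerically stable log((1/K) * sum_k exp(log_ws[k]))."""
    m = max(log_ws)
    return m + math.log(sum(math.exp(lw - m) for lw in log_ws) / len(log_ws))


def elbo_loss(log_weights_per_step):
    """n-step ELBO loss: minus the average over n steps of log((1/K) sum_k w^k).

    log_weights_per_step[i][k] is the log-weight of particle k at step t + i.
    """
    n = len(log_weights_per_step)
    return -sum(log_mean_exp(lws) for lws in log_weights_per_step) / n


def total_loss(l_a2c, l_value, l_entropy, l_elbo, lam_v, lam_h, lam_e):
    """Weighted sum of the rl losses and the ELBO loss (coefficients hypothetical)."""
    return l_a2c + lam_v * l_value + lam_h * l_entropy + lam_e * l_elbo


l_elbo = elbo_loss([[0.0, 0.0], [math.log(2.0), math.log(2.0)]])  # equals -log(2) / 2
```

Since the elbo term only needs the weights of the current $n$ steps, it can be added to every A2C update without storing whole trajectories.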
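If we assume that observations and actions are drawn from the stationary state distribution induced by the policy $\pi_\rho$, then $-\mathcal{L}_t^{\mathrm{ELBO}}$ is a stochastic approximation to the (per-step) action-conditioned elbo:
$$\frac{1}{T}\,\mathbb{E}_{p(\tau)}\left[\mathrm{ELBO}_{\mathrm{SMC}}(\theta, \phi, o_{1:T} \mid a_{0:T-1})\right] = \frac{1}{T}\,\mathbb{E}_{p(\tau)}\left[\mathbb{E}\left[\sum_{t=1}^{T} \log\left(\frac{1}{K}\sum_{k=1}^{K} w_t^k\right)\right]\right], \tag{21}$$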
which is a conditional extension of (6) similar to the extension of vaes by Sohn et al. (2015). The expectation over $p(\tau)$ is approximated by sampling trajectories, and the sum $\sum_{t=1}^{T}$ over the entire trajectory is approximated by a sum $\sum_{i=t}^{t+n-1}$ over only a part of it.
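The importance of the resampling step (15) in allowing this approximation becomes clear if we compare (21) with the elbo for the importance weighted autoencoder (iwae) that does not include resampling (Doucet & Johansen, 2009; Burda et al., 2016):
$$\mathrm{ELBO}_{\mathrm{IWAE}}(\theta, \phi, o_{1:T} \mid a_{0:T-1}) = \mathbb{E}\left[\log\left(\frac{1}{K}\sum_{k=1}^{K} \prod_{t=1}^{T} w_t^k\right)\right], \tag{22}$$
where each particle keeps its own history, so that the weights $w_t^k$ are computed without resampling.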
Because this loss is not additive over time, we cannot approximate it with shorter parts of the trajectory.
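The difference in additivity can be checked numerically. The sketch below (hypothetical helper names, arbitrary positive weights used only to exercise the algebra) shows that the smc objective over a sequence equals the sum of the objectives over its parts, whereas the iwae objective does not split in this way. In practice the two objectives would be evaluated on different weights, since iwae particles are never resampled.

```python
import math


def smc_objective(weights):
    """sum_t log((1/K) sum_k w_t^k) for weights[t][k] > 0 (with resampling)."""
    return sum(math.log(sum(w_t) / len(w_t)) for w_t in weights)


def iwae_objective(weights):
    """log((1/K) sum_k prod_t w_t^k) for K particles that are never resampled."""
    K = len(weights[0])
    products = [1.0] * K
    for w_t in weights:
        for k in range(K):
            products[k] *= w_t[k]
    return math.log(sum(products) / K)


w = [[1.0, 3.0], [2.0, 2.0], [4.0, 1.0]]   # arbitrary example weights, T = 3, K = 2
full_smc = smc_objective(w)
split_smc = smc_objective(w[:1]) + smc_objective(w[1:])    # equal to full_smc
full_iwae = iwae_objective(w)                               # log 7
split_iwae = iwae_objective(w[:1]) + iwae_objective(w[1:])  # log 2 + log 5 = log 10
```

For these weights the iwae objective is $\log 7$ on the full sequence but $\log 10$ when computed on the two parts separately, so a window of $n$ steps cannot stand in for the whole trajectory.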
4 Related Work
Most existing pomdp literature focusses on planning algorithms, where the transition and observation functions, as well as a representation of the latent state space, are known (Barto et al., 1995; McAllester & Singh, 1999; Pineau et al., 2003; Ross et al., 2008; Oliehoek et al., 2008; Roijers et al., 2015). In most realistic domains however, these are not known a priori.
There are several approaches that utilise rnns in pomdps (Bakker, 2002; Wierstra et al., 2007; Zhang et al., 2015; Heess et al., 2015), including multi-agent settings (Foerster et al., 2016), learning text-based fantasy games (Narasimhan et al., 2015) or, most recently, applied to Atari (Hausknecht & Stone, 2015; Zhu et al., 2017). As discussed in Section 3, our algorithm extends those approaches by enabling the policy to explicitly reason about a model and the belief state.
Another more specialised approach called QMDP-Net (Karkus et al., 2017) learns a value iteration network (vin) (Tamar et al., 2016) end-to-end and uses it as a transition model for planning. However, a vin makes strong assumptions about the transition function and in QMDP-Net the belief update must be performed analytically.
The idea to learn a particle filter based policy that is trained using policy gradients was previously proposed by Coquelin et al. (2009). However, they assume a known model and rely on finite differences for gradient estimation.
Instead of optimising an elbo to learn a maximum-likelihood approximation for the latent representation and corresponding transition and observation model, previous work also tried to learn those dynamics using spectral methods (Azizzadenesheli et al., 2016), a Bayesian approach (Ross et al., 2011; Katt et al., 2017), or nonparametrically (Doshi-Velez et al., 2015). However, these approaches do not scale to large or continuous state and observation spaces. For continuous states, actions, and observations with Gaussian noise, a Gaussian process model can be learned (Deisenroth & Peters, 2012). An alternative to learning an (approximate) transition and observation model is to learn a model over trajectories (Willems et al., 1995). However, this is again only possible for small, discrete observation spaces.
Due to the complexity of learning in pomdps, previous work has already found benefits in using auxiliary losses. Unlike the losses proposed by Lample & Chaplot (2017), we do not require additional information from the environment. The UNREAL agent (Jaderberg et al., 2016) is, similarly to our work, motivated by the idea of improving the latent representation by utilising all the information already obtained from the environment. While their work focuses on finding unsupervised auxiliary losses that provide good training signals, our goal is to use the auxiliary loss to better align the network computations with the task at hand by incorporating prior knowledge as an inductive bias.
There is some evidence from recent experiments on the dopamine system in mice (Babayan et al., 2018) showing that their response to ambiguous information is consistent with a theory operating on belief states.
5 Experiments
We evaluate dvrl on Mountain Hike and on flickering Atari. We show that dvrl deals better with noisy or partially occluded observations, and that this advantage carries over to complex tasks with high-dimensional, continuous observation spaces such as images. We also perform a series of ablation studies, showing the importance of using many particles, including the elbo training objective in the loss function, and jointly optimising the elbo and RL losses.
More details about the environments and model architectures can be found in Appendix A together with additional results and visualisations. All plots and reported results are smoothed over time and across parallel environments. We average over five random seeds, with shaded areas indicating the standard deviation.
dvrl used 30 particles and we set for both rnn and dvrl. The latent state for the rnn-encoder architecture was of dimension 256, and 128 for both $z$ and $h$ for dvrl. Lastly, and were used, together with RMSProp with a learning rate of for both approaches.
The main difficulty in Mountain Hike is to correctly estimate the current position. Consequently, the achieved return reflects the capability of the network to do so. dvrl outperforms rnn based policies, especially for higher levels of observation noise (Figure 4). In Figure 3 we compare the different trajectories for rnn and dvrl encoders for the same noise, i.e. and for all and . dvrl is better able to follow the mountain ridge, indicating that its inference based history aggregation is superior to a largely memory/heuristics based one.
The example in Figure 3 is representative but selected for clarity: The shown trajectories have compared to an average value of (see Figure 4).
5.2 Atari
We chose flickering Atari as evaluation benchmark, since it was previously used to evaluate the performance of adrqn (Zhu et al., 2017) and drqn (Hausknecht & Stone, 2015). Atari environments (Bellemare et al., 2013) provide a wide set of challenging tasks with high dimensional observation spaces. We test our algorithm on the same subset of games on which drqn and adrqn were evaluated.
Partial observability is introduced by flickering, i.e., by a probability of 0.5 of returning a blank screen instead of the actual observation. Furthermore, only one frame is used as the observation. This is in line with previous work (Hausknecht & Stone, 2015). We use a frameskip of four (a frameskip of one is used for Asteroids due to known rendering issues with this environment), and for the stochastic Atari environments there is a chance of repeating the current action for a second time at each transition.
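The flickering observation model can be sketched as a small wrapper around a single frame. This is a hypothetical helper, not part of any Atari or Gym API, and it assumes that a blank screen is represented by an all-zero frame.

```python
import random


class FlickeringObservation:
    """Hypothetical helper: blank out a single frame with probability p_blank."""

    def __init__(self, p_blank=0.5, seed=None):
        self.p_blank = p_blank
        self.rng = random.Random(seed)

    def __call__(self, frame):
        # frame: one observation as a list of pixel rows (no frame stacking).
        if self.rng.random() < self.p_blank:
            return [[0 for _ in row] for row in frame]  # blank screen
        return frame


flicker = FlickeringObservation(p_blank=0.5, seed=0)
observed = [flicker([[12, 200], [7, 3]]) for _ in range(4)]
```

Because only one frame is passed on and roughly half of them are blank, the agent cannot recover velocities or occluded objects from a single input and must aggregate information over time.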
dvrl used 15 particles and we set for both agents. The dimension of $h$ was 256 for both architectures, as was the dimension of $z$. Larger latent states decreased the performance for the rnn encoder. Lastly, and was used with a learning rate of for rnn and for dvrl, selected out of a set of 6 different rates based on the results on ChopperCommand.
Table 1 shows the results for the more challenging stochastic, flickering environments. Results for the deterministic environments, including returns reported for drqn and adrqn, can be found in Appendix A. dvrl significantly outperforms the rnn-based policy on five out of ten games and narrowly, but significantly, underperforms on only one. This shows that dvrl is viable for high dimensional observation spaces with complex environmental models.
5.3 Ablation Studies
Using more than one particle is important to accurately approximate the belief distribution over the latent state $(h_t, z_t)$. Consequently, we expect that higher particle numbers provide better information to the policy, leading to higher returns. Figure 5a shows that this is indeed the case. This is an important result for our architecture, as it also implies that the resampling step is necessary, as detailed in Section 3.4. Without resampling, we cannot approximate the elbo on only $n$ observations.
Secondly, Figure 5b shows that the inclusion of $\mathcal{L}^{\mathrm{ELBO}}$ to encourage model learning is required for good performance. Furthermore, not backpropagating the policy gradients through the encoder and only learning it based on the elbo (“No joint optim”) also deteriorates performance.
Lastly, we investigate the influence of the backpropagation length $n_s$ on both the rnn- and dvrl-based policies. While increasing $n_s$ universally helps, the key insight here is that a short length (for an average bptt-length of 2 timesteps) has a stronger negative impact on rnn than on dvrl. This is consistent with our notion that rnn mainly performs memory-based reasoning, for which longer backpropagation-through-time is required: the belief update (2) in dvrl is a one-step update from $b_t$ to $b_{t+1}$, without the need to condition on past actions and observations. The proposal distribution can benefit from extended backpropagation lengths, but this is not necessary. Consequently, this result supports our notion that dvrl relies more on inference computations to update the latent state.
6 Conclusion
In this paper we proposed dvrl, a method for solving pomdps given only a stream of observations, without knowledge of the latent state space or the transition and observation functions operating in that space. Our method leverages a new elbo-based auxiliary loss and incorporates an inductive bias into the structure of the policy network, taking advantage of our prior knowledge that an inference step is required for an optimal solution.
We compared dvrl to an rnn-based architecture and found that we consistently outperform it on a diverse set of tasks, including a number of Atari games modified to have partial observability and stochastic transitions.
We also performed several ablation studies showing the necessity of using an ensemble of particles and of joint optimisation of the elbo and RL objective. Furthermore, the results support our claim that the latent state in dvrl approximates a belief distribution in a learned model.
Access to a belief distribution opens up several interesting research directions. Investigating the role of better generalisation capabilities and the more powerful latent state representation on the policy performance of dvrl can give rise to further improvements. dvrl is also likely to benefit from more powerful model architectures and a disentangled latent state. Furthermore, uncertainty in the belief state and access to a learned model can be used for curiosity driven exploration in environments with sparse rewards.
Acknowledgements
We would like to thank Wendelin Boehmer and Greg Farquhar for useful discussions and feedback. The NVIDIA DGX-1 used for this research was donated by the NVIDIA corporation. M. Igl is supported by the UK EPSRC CDT in Autonomous Intelligent Machines and Systems. L. Zintgraf is supported by the Microsoft Research PhD Scholarship Program. T. A. Le is supported by EPSRC DTA and Google (project code DF6700) studentships. F. Wood is supported by DARPA PPAML through the U.S. AFRL under Cooperative Agreement FA8750-14-2-0006; Intel and DARPA D3M, under Cooperative Agreement FA8750-17-2-0093. S. Whiteson is supported by the European Research Council (ERC) under the European Union’s Horizon 2020 research and innovation programme (grant agreement number 637713).
References
Appendix A Experiments
In our implementation, the transition and proposal distributions $p_\theta(z_t \mid h_{t-1}, a_{t-1})$ and $q_\phi(z_t \mid h_{t-1}, a_{t-1}, o_t)$ are multivariate normal distributions over $z_t$ whose mean and diagonal variance are determined by neural networks. For image data, the decoder is a multivariate independent Bernoulli distribution whose parameters are again determined by a neural network. For real-valued vectors we use a normal distribution.
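For reference, the two likelihoods used here can be evaluated as follows. This is a plain-Python sketch with hypothetical function names; in the actual networks the means, variances and Bernoulli probabilities would be outputs of neural networks, and treating pixel intensities in $[0, 1]$ as Bernoulli targets is an assumption.

```python
import math


def diag_normal_log_prob(x, mean, var):
    """log N(x; mean, diag(var)): independent normal log-densities summed over dimensions."""
    return sum(-0.5 * (math.log(2 * math.pi * v) + (xi - m) ** 2 / v)
               for xi, m, v in zip(x, mean, var))


def bernoulli_log_prob(pixels, probs, eps=1e-7):
    """Log-likelihood of pixel values in [0, 1] under independent Bernoulli distributions."""
    total = 0.0
    for x, p in zip(pixels, probs):
        p = min(max(p, eps), 1.0 - eps)  # clamp to avoid log(0)
        total += x * math.log(p) + (1.0 - x) * math.log(1.0 - p)
    return total
```

With a diagonal covariance the multivariate log-density factorises into a sum over dimensions, which is what makes both likelihoods cheap to evaluate for every particle.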
When several inputs are passed to a neural network, they are concatenated to one vector. ReLUs are used as nonlinearities between all layers. Hidden layers are, if not otherwise stated, all of the same dimension as $h$. Batch normalization was used between layers for experiments on Atari but not on Mountain Hike, where it significantly hurt performance. All rnns are GRUs.
Separate encoding functions are used to encode single observations, actions and latent states before they are passed into other networks.
To encode visual observations, we use the same convolutional network as proposed by Mnih et al. (2015), but with only instead of channels in the final layer. The transposed convolutional network of the decoder has the reversed structure. The decoder is preceded by an additional fully connected layer which outputs the required dimension (1568 for Atari’s observations).
Actions are encoded using one fully connected layer of size 128 for Atari and size 64 for Mountain Hike. Lastly, $z_t$ is encoded before being passed into networks by one fully connected layer of the same size as $h_t$.
The policy is one fully connected layer whose size is determined by the action space, i.e., up to 18 outputs with softmax for Atari and only 2 outputs for the learned mean for Mountain Hike, together with a learned variance. The value function is one fully connected layer of size 1.
A2C used parallel environments and -step learning for a total batch size of 80. Hyperparameters were tuned on Chopper Command. The learning rate of both dvrl and action-specific deep recurrent AC network (adr-a2c) was independently tuned on the set of values with being chosen for dvrl on Atari and for dvrl on MountainHike and rnn on both environments. Without further tuning, we set and as is commonly used.
As optimizer we use RMSProp with . We clip gradients at a value of . The discount factor of the control problem is set to and lastly, we use ’orthogonal’ initialization for the network weights.
The source code will be released in the future.
A.2 Additional Experiments and Visualisations
Table 2 shows the results on deterministic and flickering Atari, averaged over 5 random seeds. The values for drqn and adrqn are taken from the respective papers. Note that drqn and adrqn rely on Q-learning instead of A2C, so the results are not directly comparable.
Figures 6 and 7 show individual learning curves for all 10 Atari games, either for the deterministic or the stochastic version of the games.
A.3 Computational Speed
The approximate training speed in frames per second (FPS) is on one GPU on a dgx1 for Atari:
A.4 Model Predictions
In Figure 8 we show reconstructed and predicted images from the dvrl model for several Atari games. The current observation is in the leftmost column. The second column (’dt0’) shows the reconstruction after encoding and decoding the current observation. For the further columns, we make use of the learned generative model to predict future observations. For simplicity we repeat the last action. Columns 2 to 7 show predicted observations for unrolled timesteps. The model was trained as explained in Section 5.2. The reconstructed and predicted images are a weighted average over all 16 particles.
Note that the model is able to correctly predict features of future observations, for example the movement of the cars in ChopperCommand, the (approximate) ball position in Pong or the missing pins in Bowling. Furthermore, it is able to do so even if the current observation is blank, as in Bowling. The model has also correctly learned to randomly predict blank observations.
It can remember features of the current state fairly well, like the positions of barriers (white dots) in Centipede. On the other hand, it clearly struggles with the amount of information present in MsPacman, like the positions of all previously eaten fruits or the location of the ghosts.
Appendix B Algorithms
Algorithm 1 details the recurrent (belief) state computation (i.e., history encoder) for dvrl. Algorithm 2 details the recurrent state computation for rnn. Algorithm 3 describes the overall training algorithm that uses one or the other to aggregate the history. Despite looking complicated, it is just a detailed implementation of $n$-step A2C with two additional changes: the inclusion of $\mathcal{L}^{\mathrm{ELBO}}$, and the option to not delete the computation graph, which allows longer backpropagation in $n$-step A2C.
Results for also using the reconstruction loss for the rnn-based encoder are not shown in the paper, as they were reliably worse than those of rnn without reconstruction loss.