Ideas and figures in this summary are taken from Model-Based Reinforcement Learning for Atari (SimPLe).

The traditional trade-off between model-based and model-free reinforcement learning is sample complexity versus final performance: model-based algorithms generally have lower sample complexity, but at the cost of worse final performance. This work introduces an approach that significantly reduces sample complexity on Atari at the cost of some final performance.

Architecture

The architecture is split into two primary models: (1) the world model and (2) the policy.

World Model

Illustration of the World Model architecture

The world model is further broken down into two components:

A skip-connected encoder-decoder, where the decoder attempts to generate the next frame

An inference network approximating the posterior q(z|x)

The latent encoding z and an embedding of the action are injected at the center (bottleneck) of the encoder-decoder. The encoding z is also predicted autoregressively as an auxiliary objective; since future observations are not available at inference time, the autoregressive model is used to predict z.
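A minimal sketch of this data flow, with toy NumPy stand-ins for the convolutional encoder and decoder (all shapes and the `world_model_step` name are illustrative, not from the paper):

```python
import numpy as np

def encode(frame):
    # Toy stand-in for a convolutional encoder; SimPLe uses strided convs.
    h1 = frame.reshape(-1)            # higher-res features kept for the skip connection
    h2 = h1[:64]                      # bottleneck features
    return h2, h1

def decode(bottleneck, skip):
    # Toy stand-in for the decoder: "upsample" and add the skip connection.
    up = np.tile(bottleneck, skip.size // bottleneck.size)
    return up + skip

def world_model_step(frame, z, action_emb):
    bottleneck, skip = encode(frame)
    # The latent z and the action embedding are injected at the center.
    bottleneck = bottleneck + z + action_emb
    return decode(bottleneck, skip)   # prediction of the next frame (flattened)

frame = np.zeros((16, 16))            # 256 "pixels"
z = np.ones(64)                       # latent from q(z|x) (training) or the AR model (inference)
action_emb = np.ones(64)
pred = world_model_step(frame, z, action_emb)
print(pred.shape)                     # (256,)
```

The point is only the wiring: the skip connection carries detail straight from encoder to decoder, while z and the action condition the bottleneck.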

Policy

The policy is a PPO (summary) agent implemented as a standard actor-critic. Its parameters are completely separate from those of the world model. Since this work only explores environments with discrete action spaces, the policy outputs logits that parameterize a categorical (multinomial) distribution over actions.
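Sampling from that distribution amounts to a softmax over the logits followed by a categorical draw; a small sketch (function name illustrative):

```python
import numpy as np

def sample_action(logits, rng):
    # Numerically stable softmax turns logits into a categorical distribution.
    probs = np.exp(logits - logits.max())
    probs /= probs.sum()
    return rng.choice(len(logits), p=probs)

rng = np.random.default_rng(0)
logits = np.array([2.0, 0.5, -1.0, 0.0])  # one logit per discrete action
action = sample_action(logits, rng)       # an int in {0, 1, 2, 3}
```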

Training

The training scheme introduces a few tricks: loss clipping, early environment resets, and scheduled sampling.

Loss Clipping

One of the main problems with explicit model-based RL is that much of the network's capacity can be spent modeling irrelevant information such as the background. Loss clipping is a heuristic hypothesized to mitigate this: gradients are zero wherever the loss is below a constant C, so pixels that are already modeled well enough stop contributing.
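One way to implement this (the value of C and the exact form in the paper may differ) is to floor the per-pixel loss at C, which makes its gradient zero wherever the raw loss is below the threshold:

```python
import numpy as np

C = 0.01  # clipping threshold (illustrative value, not taken from the paper)

def clipped_l2_loss(pred, target):
    per_pixel = (pred - target) ** 2
    # Pixels whose loss is already below C are replaced by the constant C,
    # contributing zero gradient and freeing capacity for hard-to-model pixels.
    return np.maximum(per_pixel, C).mean()

perfect = clipped_l2_loss(np.zeros(4), np.zeros(4))  # hits the floor C
wrong = clipped_l2_loss(np.ones(4), np.zeros(4))     # unclipped L2 loss
```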

Early Environment Resets

Because the world model is not perfect, simulated rollouts eventually diverge from realistic observations. To mitigate this, the simulated environment (env') is reset to a ground-truth observation (uniformly sampled from D) every N steps (generally N = 50).
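The rollout loop with early resets can be sketched as follows; `StubSimEnv`, `rollout`, and the trivial dynamics are hypothetical stand-ins, with only N = 50 taken from the summary:

```python
import random

N = 50  # reset interval from the summary

class StubSimEnv:
    """Hypothetical stand-in for the learned simulated environment env'."""
    def set_state(self, obs):
        self.obs = obs
    def step(self, action):
        self.obs += 1          # trivial "dynamics" for illustration
        return self.obs, 0.0   # next observation, reward

def rollout(sim_env, buffer_D, policy, steps):
    trajectory = []
    for t in range(steps):
        if t % N == 0:
            # Every N steps, re-anchor env' to a ground-truth observation from D.
            sim_env.set_state(random.choice(buffer_D))
        obs, reward = sim_env.step(policy(sim_env.obs))
        trajectory.append((obs, reward))
    return trajectory

traj = rollout(StubSimEnv(), buffer_D=[0, 100, 200], policy=lambda o: 0, steps=120)
```

The periodic re-anchoring bounds how far the model's compounding prediction error can carry a rollout away from the true state distribution.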

Scheduled Sampling

While training the world model, the probability of feeding it its own predicted frames (rather than ground-truth frames) is linearly annealed to 100% (scheduled sampling). This helps avoid the model drifting to unrealistic observations when consuming its own predictions, an issue that is especially pronounced while the model is weak at the beginning of training.
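A minimal sketch of this annealing schedule (function names illustrative, schedule shape assumed linear per the summary):

```python
import random

def own_prediction_prob(step, total_steps):
    # Fraction of inputs drawn from the model's own predictions,
    # linearly annealed from 0% to 100% over training.
    return min(step / total_steps, 1.0)

def next_input(real_frame, predicted_frame, step, total_steps):
    if random.random() < own_prediction_prob(step, total_steps):
        return predicted_frame  # model consumes its own output
    return real_frame           # teacher forcing with the ground-truth frame
```

Early in training every input is a ground-truth frame (pure teacher forcing); by the end the model is trained entirely on its own rollouts, matching how it is used at inference time.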

In the future it would be interesting to give the policy access to the hidden state of the autoregressive model; this seemed to be quite helpful in World Models (summary).