Most basic RNN models are set up to predict the next state directly, or to predict the discrete probability distribution of the next state, typically by applying the softmax function on the output layer. The output is therefore a probability mass function, assigning a probability to every possible outcome.

How should we predict continuous variables? Of course, we can directly output the value from the NN. But what if we need the probability distribution? The softmax approach does not work on continuous variables, since there is an infinite number of possible outcomes for a continuous random variable. The answer is: Mixture Density Models. Gaussian Mixture Models assume that the output distribution is a mixture of normal distributions, so we only need to predict the parameters of each mixture component. You can read the original paper by Bishop here. One of the best-known demos using this model is the handwriting synthesis demo by Graves.

Implementation Details

The mixture distribution is defined as:

$$p(y \mid x) = \sum_{k=1}^{K} \pi_k(x)\,\mathcal{N}\big(y \mid \mu_k(x), \Sigma_k(x)\big)$$

where $\mathcal{N}(y \mid \mu, \Sigma)$ is the multivariate Gaussian density, defined by:

$$\mathcal{N}(y \mid \mu, \Sigma) = \frac{1}{\sqrt{(2\pi)^d \lvert\Sigma\rvert}}\exp\left(-\frac{1}{2}(y-\mu)^\top \Sigma^{-1} (y-\mu)\right)$$
Assuming that the covariance matrix is a diagonal matrix with the same variance across all dimensions:

$$\Sigma_k = \sigma_k^2 I$$

the density simplifies to:

$$\mathcal{N}\big(y \mid \mu_k, \sigma_k^2 I\big) = \frac{1}{(2\pi\sigma_k^2)^{d/2}}\exp\left(-\frac{\lVert y - \mu_k \rVert^2}{2\sigma_k^2}\right)$$
So we only need to predict the $K$ means $\mu_k$, the variances $\sigma_k$, and the mixture coefficients $\pi_k$. They are defined as follows, given that $z$ is the output of the NN for input $x$, split into three groups $z^{\pi}, z^{\mu}, z^{\sigma}$:

$$\mu_k = z_k^{\mu}, \qquad \sigma_k = \exp\big(z_k^{\sigma}\big), \qquad \pi_k = \frac{\exp\big(z_k^{\pi}\big)}{\sum_{j=1}^{K}\exp\big(z_j^{\pi}\big)}$$

The exponential keeps the variances positive, and the softmax keeps the mixture coefficients positive and summing to one.
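To make this concrete, here is a minimal NumPy sketch of how a raw output vector could be split into the three parameter groups. The layout of $z$ (mixture logits first, then means, then log standard deviations) and the function name are illustrative choices, not necessarily what my actual code uses:

```python
import numpy as np

def mdn_params(z, k, d):
    """Split a raw NN output vector z into mixture parameters.

    Assumes (for illustration) the layout: k mixture logits,
    then k*d means, then k log standard deviations.
    """
    z_pi = z[:k]
    z_mu = z[k:k + k * d]
    z_sigma = z[k + k * d:]

    # Softmax: mixture coefficients are positive and sum to one.
    e = np.exp(z_pi - np.max(z_pi))
    pi = e / np.sum(e)

    mu = z_mu.reshape(k, d)   # means are used directly
    sigma = np.exp(z_sigma)   # exp keeps the deviations positive
    return pi, mu, sigma
```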
We use the negative log-likelihood as the loss function:

$$\mathcal{L} = -\sum_{i}\log\left(\sum_{k=1}^{K}\pi_k\,\mathcal{N}\big(y_i \mid \mu_k, \sigma_k^2 I\big)\right)$$
The problem with this is that computing the $\log\left(\sum\exp(\cdot)\right)$ expression naively is numerically unstable and may cause NaNs (they spread like wildfire). Luckily there is the LogSumExp trick, which eliminates most of the instability:

$$\log\sum_{k=1}^{K}\exp(a_k) = a^{*} + \log\sum_{k=1}^{K}\exp\big(a_k - a^{*}\big), \qquad a^{*} = \max_{k} a_k$$

where $a_k = \log\pi_k + \log\mathcal{N}\big(y \mid \mu_k, \sigma_k^2 I\big)$; subtracting the maximum shifts the largest exponent to zero, so the sum can never overflow.
In my implementation, I added a small epsilon value just before taking the log, just to be sure.
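Putting this together, here is a minimal NumPy sketch of the stable loss computation for a single target. The function name and the exact epsilon placement are my own for illustration:

```python
import numpy as np

def mdn_neg_log_likelihood(pi, mu, sigma, y, eps=1e-8):
    """Negative log-likelihood of a target y (shape (d,)) under the
    isotropic Gaussian mixture, computed with the LogSumExp trick."""
    d = mu.shape[1]
    # Per-component log density: log N(y | mu_k, sigma_k^2 * I),
    # computed in log space to avoid underflow far from the mean.
    sq_dist = np.sum((y - mu) ** 2, axis=1)
    log_gauss = -0.5 * d * np.log(2.0 * np.pi * sigma ** 2) - sq_dist / (2.0 * sigma ** 2)

    # a_k = log pi_k + log N_k; eps guards against log(0).
    a = np.log(pi + eps) + log_gauss

    # LogSumExp: shift by the max so the largest exponent is exp(0) = 1.
    a_max = np.max(a)
    return -(a_max + np.log(np.sum(np.exp(a - a_max)) + eps))
```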

Results

I trained my model to fit this plane:

As you can see, there are multiple possible (y, z) output values for each input x value. Thus the mean squared error approach would not work here. This is the result of fitting the model after 100,000 iterations:

The red dots are points randomly sampled from the output mixture distribution, and the black lines are the mean values of each mixture component. The transparency of each black line indicates its mixture weight. Notice that the mean lines are spaced almost evenly from each other, in order to maximize the likelihood. That also helps the model capture the output plane with no parts left out.
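For reference, sampling a red dot is a two-step process: pick a mixture component according to the weights, then draw from that component's isotropic Gaussian. A minimal sketch, assuming the pi, mu, sigma arrays from the parameterization above:

```python
import numpy as np

def sample_mixture(pi, mu, sigma, rng=None):
    """Draw one sample: choose a component by its weight, then
    sample from that component's isotropic Gaussian."""
    rng = rng or np.random.default_rng()
    k = rng.choice(len(pi), p=pi)
    return rng.normal(mu[k], sigma[k])  # broadcasts over the d dimensions
```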

The source code is available here.