Linking SDEs and Sampling

In Bayesian inference, when we want to draw samples from an intractable posterior distribution, we often resort to sampling schemes like Metropolis-Hastings or Hamiltonian Monte Carlo. How these actually work can seem like a bit of a black box - but there is a powerful intuition to be gained by considering stochastic processes.

The aim of this post (or perhaps post series) is to provide some background and intuition behind the excellent work: A Complete Recipe for Stochastic Gradient MCMC.

This post involves some stochastic processes and ideas from stochastic calculus. If you want a good introductory resource, in the past I have found:

Brownian Motion Calculus - Wiersema
Stochastic Calculus for Finance (Parts 1 & 2) - Shreve

to be really good introductions.

Ito Process and SDEs - A Recap

An Ito process is one of the form:

$dX_{t} = \mu_{t}dt + \sigma_{t}dW_{t}$

Where $W_{t}$ is a Brownian motion - i.e. $W_{t} - W_{s} \sim N(0, t-s)$ for $s < t$.

This is actually shorthand for an integral equation. We won't go into technicalities, but Brownian paths are (almost surely) nowhere differentiable, so $dX_{t}$ cannot be read as an ordinary derivative.

What is important here is intuition: we can read the differential form as saying that, at any point in time, the change in $X_{t}$ is made up of a deterministic part that evolves through time and a stochastic part driven by a Brownian motion.

This is a stochastic process, and by analogy with the Euler discretisation for ODEs, we can use the Euler-Maruyama discretisation as a practical way to simulate a stochastic process from a given SDE.
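As a minimal sketch of the Euler-Maruyama idea: over a small step $h$, the Brownian increment is replaced by a draw from $N(0, h)$ and the coefficients are frozen at their current values. The function name `euler_maruyama`, its arguments, the step count and the example coefficients below are all choices made for illustration, not anything taken from the paper.

```python
import math
import random


def euler_maruyama(mu, sigma, x0, t_end, n_steps, rng):
    """Simulate one path of dX = mu(x, t) dt + sigma(x, t) dW on [0, t_end].

    mu and sigma are callables (x, t) -> float; rng is a random.Random.
    Returns the list of n_steps + 1 path values, starting at x0.
    """
    h = t_end / n_steps          # time step
    sqrt_h = math.sqrt(h)        # Brownian increments have std sqrt(h)
    x, t = x0, 0.0
    path = [x]
    for _ in range(n_steps):
        dw = sqrt_h * rng.gauss(0.0, 1.0)   # W_{t+h} - W_t ~ N(0, h)
        x = x + mu(x, t) * h + sigma(x, t) * dw
        t += h
        path.append(x)
    return path


# Example (assumed coefficients): constant drift 0.5 and constant
# diffusion 1, for which X_1 ~ N(0.5, 1) when X_0 = 0.
rng = random.Random(0)
finals = [
    euler_maruyama(lambda x, t: 0.5, lambda x, t: 1.0, 0.0, 1.0, 100, rng)[-1]
    for _ in range(2000)
]
mean_final = sum(finals) / len(finals)
print(f"mean of X_1 over 2000 paths: {mean_final:.3f}")
```

The printed mean should land near 0.5, up to Monte Carlo noise. For coefficients that depend on $x$, the scheme is only approximate and the error shrinks as the step size does.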

Fokker-Planck Equation

For a stochastic process, the Fokker-Planck equation is a PDE which describes how the probability density $p(x,t)$ of the process being in state $x$ at time $t$ evolves. If the process is as given above, with drift $\mu(x,t)$ and diffusion coefficient $\sigma(x,t)$:

$\frac{\partial p(x,t)}{\partial t} = -\frac{\partial (\mu(x,t) p(x,t))}{\partial x} +\frac{1}{2} \frac{\partial^{2}(\sigma^{2}(x,t) p(x,t))}{\partial x^{2}}$

The stationary distribution of the SDE is obtained when $\frac{\partial p(x,t)}{\partial t}=0$. Hence the name stationary (also, steady state): as time evolves, the probabilities do not change.

An example

Let's consider $\mu(x) = -x$, $\sigma = \sqrt{2}$.

Setting the time derivative to zero, we get the ODE:

$0 = -\frac{d(-x \, p(x,t))}{dx} + \frac{d^{2} p(x,t)}{dx^{2}}$

Integrating once with respect to $x$ gives:

$\frac{d p(x,t)}{dx} = -x \, p(x,t) + c$

This is a simple ODE, and actually if you have ever integrated the normal distribution, you have probably used this in reverse. If we assume a form for $p$ as:

$p = ke^{\frac{-x^{2}}{2}}$

Simple substitution shows that this is indeed a solution of the ODE, with $c=0$; $k$ is the normalisation constant.
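For completeness, the substitution can be written out. Differentiating the assumed form gives

$\frac{dp}{dx} = -x \, k e^{\frac{-x^{2}}{2}} = -x \, p$

which matches the integrated ODE only if $c = 0$. Requiring $p$ to integrate to one, and using the standard Gaussian integral $\int_{-\infty}^{\infty} e^{\frac{-x^{2}}{2}} dx = \sqrt{2\pi}$, fixes the constant:

$k = \frac{1}{\sqrt{2\pi}}$

so $p$ is the standard normal density.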

This is an important step because we have actually linked an SDE and sampling! What we have found is that if I wish to sample from a standard normal distribution, I can simulate from the SDE:

$dX_{t} = -X_{t}dt + \sqrt{2} dW_{t}$

and, once the process has settled into its stationary distribution, the values of $X_{t}$ are our samples!

Obviously in practice this is not necessary - we can generate samples directly in this case. However, the concept is sound - we can sample from a distribution if only we know the correct stochastic process!
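To see this numerically, here is a rough sketch that simulates the SDE with an Euler-Maruyama step and checks the first two moments of the values it visits. The step size, burn-in length, chain length and starting point are assumed values picked for illustration. Consecutive values are correlated, so the draws are not independent.

```python
import math
import random

rng = random.Random(42)
h = 0.01            # step size (an assumed value; smaller means less bias)
n_burn = 1_000      # steps discarded while the process forgets its start
n_keep = 200_000    # steps kept as (correlated) samples

x = 3.0             # deliberately start far from the mode
samples = []
for i in range(n_burn + n_keep):
    # Euler-Maruyama step for dX = -X dt + sqrt(2) dW
    x += -x * h + math.sqrt(2.0 * h) * rng.gauss(0.0, 1.0)
    if i >= n_burn:
        samples.append(x)

sample_mean = sum(samples) / len(samples)
sample_var = sum((s - sample_mean) ** 2 for s in samples) / len(samples)
print(f"sample mean:     {sample_mean:.3f}")
print(f"sample variance: {sample_var:.3f}")
```

The mean and variance should come out close to 0 and 1. The discretisation introduces a small bias in the variance that shrinks with the step size.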

Solving it the other way round!

What we did before was to solve the ODE to find the stationary distribution. We can, however, instead solve for the drift coefficient given a known probability distribution. At stationarity $p(x,t) = p(x)$, because its partial derivative with respect to $t$ is 0, so we drop $t$ from our notation.

Let's set $D(x) = \frac{\sigma(x)^{2}}{2}$

If we follow the same line of reasoning as before (integrating once and again taking the integration constant to be zero), we get to the general formula:

$\frac{d (D(x) p(x))}{dx} = \mu(x) p(x)$

$p(x)\frac{d D(x)}{dx} + D(x)\frac{d p(x)}{dx} = \mu(x) p(x)$

$\frac{d D(x)}{dx} + D(x) \left[\frac{1}{p(x)}\frac{d p(x)}{dx}\right] = \mu(x)$

We can notice that the term in the square brackets is actually the derivative of the log probability, so we can write this as:

$\frac{d D(x)}{dx} + D(x) \frac{d \ln p(x)}{dx} = \mu(x)$

So, we can pick a form for $D(x)$ and find the appropriate drift coefficient so that the stochastic process has the desired distribution as its stationary distribution!

We can repeat our previous example, but by using $\ln p(x) = \frac{-x^{2}}{2}$ (up to an additive constant) and $D(x)=1$. As such we find the drift coefficient we need is (as expected) $-x$.
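The recipe is mechanical enough to sketch in code. The helper below, whose name `drift_from_target` and finite-difference step `eps` are assumptions made for illustration, evaluates $\mu(x) = \frac{dD}{dx} + D(x) \frac{d \ln p}{dx}$ numerically using central differences. Only derivatives of $\ln p$ appear, so an unnormalised log density is enough.

```python
def drift_from_target(log_p, D, eps=1e-5):
    """Return mu(x) = D'(x) + D(x) * d/dx log p(x), using central differences.

    log_p may be unnormalised: an additive constant does not change
    its derivative.
    """
    def ddx(f, x):
        return (f(x + eps) - f(x - eps)) / (2.0 * eps)

    def mu(x):
        return ddx(D, x) + D(x) * ddx(log_p, x)

    return mu


# Standard normal target with D(x) = 1 recovers the drift -x.
mu_normal = drift_from_target(lambda x: -0.5 * x * x, lambda x: 1.0)
print(mu_normal(1.5))   # approximately -1.5

# A state-dependent choice, D(x) = 1 + x^2, gives a different drift
# for the same target: 2x + (1 + x^2)(-x) = x - x^3.
mu_alt = drift_from_target(lambda x: -0.5 * x * x, lambda x: 1.0 + x * x)
print(mu_alt(2.0))      # approximately 2 - 8 = -6
```

The second call illustrates that the drift depends on the choice of $D(x)$: different pairs of drift and diffusion can share the same stationary distribution.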

Use of this approach

Firstly, simulating an SDE is quite straightforward! Secondly, we don't have to limit ourselves to sampling from simple closed-form distributions. Recall that the log posterior is:

$\ln P(\theta \mid \mathcal{D}) = \ln P(\theta, \mathcal{D}) + C$

The constant $C$ does not depend on $\theta$, so the gradients with respect to $\theta$ are the same. We can therefore compute the gradient of the log joint and use a simple discretisation scheme on the resulting stochastic process to sample from the posterior distribution!
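As a hedged illustration, take a toy model that is entirely made up for this sketch: a prior $\theta \sim N(0, 1)$ and observations $y_{i} \mid \theta \sim N(\theta, 1)$. With $D = 1$ the drift is just the gradient of the log joint, and an Euler-Maruyama step gives a simple (unadjusted) sampler. The model is conjugate, so the exact posterior mean is available as a check. The step size, burn-in and chain length are assumed values.

```python
import math
import random

rng = random.Random(1)

# Hypothetical toy model: theta ~ N(0, 1), y_i | theta ~ N(theta, 1).
data = [rng.gauss(2.0, 1.0) for _ in range(20)]
n = len(data)
sum_y = sum(data)


def grad_log_joint(theta):
    """d/dtheta of ln P(theta, D) = -theta^2/2 - sum_i (y_i - theta)^2/2 + const."""
    return -theta + (sum_y - n * theta)


h = 1e-3            # step size (assumed)
theta = 0.0
draws = []
for i in range(60_000):
    # D = 1, so the drift is the gradient of the log joint and sigma = sqrt(2).
    theta += grad_log_joint(theta) * h + math.sqrt(2.0 * h) * rng.gauss(0.0, 1.0)
    if i >= 10_000:          # discard burn-in
        draws.append(theta)

post_mean = sum(draws) / len(draws)
exact_mean = sum_y / (n + 1)     # conjugate result, used only as a check
print(f"Langevin estimate: {post_mean:.3f}, exact posterior mean: {exact_mean:.3f}")
```

Note that the normalising constant $P(\mathcal{D})$ never appears: only the gradient of the log joint is needed.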

It should be noted, however, that just because the stationary distribution is correct, it doesn't mean that in practice our simulation will converge in any sensible time. Effective simulation of SDEs is a big topic!

Most samplers we might care to use are actually specialisations of the above procedure, but in higher dimensions. It can also be interesting to look at what we can add from other areas - for example, stochastic gradient Hamiltonian Monte Carlo uses an estimate of the gradient rather than its full form. See the paper for more details, but you may also want to check out Monte Carlo Gradient Estimation in Machine Learning for some ways we could compute estimates of this gradient.

Summary

What we have seen in this post is that there is a reasonably simple way to determine a stochastic process (in this case in 1D) which has the stationary distribution we wish to sample from. We can go through an analogous procedure, as listed in the paper, to do the same thing in n dimensions.