Markov chains, Metropolis-Hastings, Gibbs sampling, and how they relate to Bayesian inference
This post is an introduction to Markov chain Monte Carlo (MCMC) sampling methods. We will consider two methods in particular, namely the Metropolis-Hastings algorithm and Gibbs sampling. We will introduce them and prove why they work, implement practical examples in Python, and eventually explain how sampling is applied for Bayesian inference and why it is so important.
MCMC methods are a family of sampling methods which make use of Markov chains to generate dependent data samples. Their basic idea is to construct Markov chains which are easy to sample from, and whose stationary distribution is our target distribution, such that, when following them, we obtain samples from the target distribution in the limit.
Why do we need this? In a previous post I introduced basic sampling methods, among others covering rejection sampling and importance sampling for complex distributions. Those methods generate independent samples, whereas here we generate dependent ones, as mentioned; this does not yet answer the question, but it is an important distinction. The actual motivation is that the previously presented methods suffer from severe limitations: it is hard to find suitable proposal distributions, in particular in high-dimensional spaces, which leads to high variance and wasteful computation.
MCMC methods, i.e. following a (simple) Markov chain, fare better in these circumstances, in particular because they need less information about the distribution we want to sample from: it suffices to be able to evaluate it up to a constant factor. That is, we do not need to be able to evaluate the full pdf p(x) for a given x; being able to compute z·p(x) for some unknown constant z is enough. At the end of this article we will see why this is so powerful, by applying it to solve a Bayesian inference problem. In many tutorials and explanations this last bit is given only briefly as a side note, but I believe it deserves more spotlight, especially for beginners to Bayesian inference.
Naturally, MCMC methods have disadvantages too: because the samples are correlated, the effective sample size shrinks, and occasionally the methods might not converge at all, or only do so very slowly.
Since, as the name suggests (and as we have stated multiple times already), MCMC methods are based on Markov chains, we introduce these first.
They are a way of modelling stochastic processes as a sequence of events, in which the Markov property states that the next state only depends on the current one, and not on any historic information.
(Small excursus: many practical ML methods require this property, such as reinforcement learning. Requiring this one-step dependency might seem very limiting and impractical; however, note that we can simply expand the state space to arbitrary dimension, in particular including past events in the current state, and thus completely circumvent this “limitation”.)
Formally, let us consider a random variable X, and denote its per-timestep realisations with X₀, X₁, … How X develops over time is given by a transition function P, where P(Xₜ₊₁ = j | Xₜ = i) = p denotes that the chance of X transitioning from state i to state j is p.
To fully specify a Markov chain, we additionally need to define an initial distribution for X, denoted by π₀. With this, we can follow the Markov chain: starting from π₀ and iteratively applying P yields the per-timestep distributions π₁, π₂, …
Let us visualise this with an example. We choose the following transition matrix:

P = [[0.3, 0.5, 0.75],
     [0.1, 0.1, 0.1 ],
     [0.6, 0.4, 0.15]]
Note that, for convenience, in our notation the index ij denotes the transition probability from state j to state i.
We now take a random initial distribution, and follow the Markov chain for 30 steps. This can be implemented as follows in Python:
import numpy as np

P = np.asarray([[0.3, 0.5, 0.75], [0.1, 0.1, 0.1], [0.6, 0.4, 0.15]])
print(f"Transition matrix P: {P}")
# Generate random initial distribution (normalize to obtain valid distribution).
pi = np.random.rand(3)
pi /= np.sum(pi)
print(f"Initial distribution pi_0: {pi}")
for i in range(30):
    pi = np.matmul(P, pi)
    if i % 5 == 0:
        print(f"Distribution after {i} steps: {pi}")
When executing this program, we will get some output similar to this:
Distribution after 0 steps: [0.51555326 0.1 0.38444674]
Distribution after 5 steps: [0.499713 0.1 0.400287]
Distribution after 10 steps: [0.5000053 0.1 0.3999947]
Distribution after 15 steps: [0.4999999 0.1 0.4000001]
Distribution after 20 steps: [0.5 0.1 0.4]
Distribution after 25 steps: [0.5 0.1 0.4]
As we can see, this Markov chain converges, for any initial distribution, to the distribution [0.5, 0.1, 0.4], which we call the stationary distribution of this Markov chain.
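As a quick cross-check, here is a minimal sketch (not part of the walkthrough above) of how the stationary distribution can be found directly: it is the eigenvector of P with eigenvalue 1, normalized to sum to 1.

import numpy as np

P = np.asarray([[0.3, 0.5, 0.75], [0.1, 0.1, 0.1], [0.6, 0.4, 0.15]])
# The columns of P sum to 1, so P has an eigenvalue 1; the corresponding
# (right) eigenvector, normalized to sum to 1, is the stationary distribution.
eigenvalues, eigenvectors = np.linalg.eig(P)
idx = np.argmin(np.abs(eigenvalues - 1.0))
stationary = np.real(eigenvectors[:, idx])
stationary /= np.sum(stationary)
print(stationary)  # approximately [0.5 0.1 0.4]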
Before moving on, we introduce a criterion, needed in the following sections, which lets us identify the stationary distribution a Markov chain converges to: detailed balance. We say a Markov chain satisfies the detailed balance criterion if there exists a distribution π satisfying

π_j · P_ij = π_i · P_ji for all states i, j,

i.e., under the distribution π, the probability of transitioning from state j to state i is the same as that of the reverse transition. Intuitively this also explains why such a π is a stationary distribution: the probability flow between any two states cancels out, so applying P leaves π unchanged. Feel free to convince yourself that this criterion is satisfied for the Markov chain defined above, and that indeed [0.5, 0.1, 0.4] is the distribution satisfying it.
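To make this concrete, here is a minimal numerical check of detailed balance and stationarity for the chain above:

import numpy as np

P = np.asarray([[0.3, 0.5, 0.75], [0.1, 0.1, 0.1], [0.6, 0.4, 0.15]])
pi = np.asarray([0.5, 0.1, 0.4])
# flow[i, j] = P_ij * pi_j is the probability flow from state j to state i;
# detailed balance means this matrix is symmetric.
flow = P * pi
print(np.allclose(flow, flow.T))          # True
# Detailed balance implies stationarity: applying P leaves pi unchanged.
print(np.allclose(np.matmul(P, pi), pi))  # True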
Equipped with this knowledge, we now introduce one of the most frequently used MCMC algorithms: the Metropolis-Hastings algorithm. To recap, what we are trying to do is sample values from a difficult probability distribution f(x), the target distribution.
Let’s begin with an overview of the algorithm. Essentially, it is made up of the following steps:

1. Select an arbitrary initial value x₀ in the target distribution’s support
2. Draw y₁ using a proposal distribution q
3. Compute p₁ (see below)
4. Draw u₁ from the uniform distribution over [0, 1]
5. Set x₁ = y₁ if u₁ ≤ p₁, else set x₁ = x₀
6. Repeat steps 2–5

p₁, the acceptance probability, is given by:

p₁ = min(1, f(y₁) / f(x₀) · q(x₀ | y₁) / q(y₁ | x₀))
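Note that p₁ depends on f only through the ratio f(y₁) / f(x₀), so any constant factor z cancels; this is exactly why evaluating z·p(x) suffices. As an illustration, here is a minimal sketch of a single update (the function name mh_step and its arguments are my own choice, assuming a symmetric Gaussian proposal):

import numpy as np

def mh_step(x, unnormalized_f, proposal_std=1.0):
    # Propose a new value by "jumping" away from x with a Gaussian step.
    y = x + np.random.normal(0.0, proposal_std)
    # With a symmetric proposal, q(x | y) = q(y | x), so the q terms cancel
    # and only the ratio of (unnormalized) target densities remains.
    p = min(1.0, unnormalized_f(y) / unnormalized_f(x))
    # Accept the proposal with probability p, otherwise keep the old value.
    return y if np.random.uniform(0.0, 1.0) <= p else x

Calling such a function repeatedly, starting from any point in the support, yields the chain x₀, x₁, …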
Example
Let’s demonstrate this using a concrete example, implemented in Python. The setup: the target distribution we want to sample from is a Gaussian, and our proposal distribution is another Gaussian. This is of course not a real-world example, but I believe and hope that this simplified setting aids understanding instead of confusing the reader. Note that in this example, all values of interest are 1D.
The corresponding Python code looks as follows:
import matplotlib.pyplot as plt
import numpy as np
import scipy.stats

NUM_SAMPLES = 10000
# Target distribution
f = scipy.stats.norm(5, 2)
# Plot target distribution
x = np.linspace(-5, 15, 5000)
plt.plot(x, f.pdf(x))
# Step 1
x = np.random.uniform(-2, 2)
# Proposal distribution
q = scipy.stats.norm(0, 1)
samples = []
for i in range(NUM_SAMPLES):
    # Step 2
    y = x + q.rvs()
    # Step 3
    p = min(f.pdf(y) / f.pdf(x) * q.pdf(x - y) / q.pdf(y - x), 1)
    # Step 4
    u = np.random.uniform(0, 1)
    # Step 5
    x = y if u <= p else x
    samples.append(x)
plt.hist(samples, density=True, bins=30)
plt.show()
Let’s go over this in some more detail. In the beginning, we use scipy’s stats module to represent our target distribution f, and then plot its pdf. We then define an initial value x to begin sampling with, simply generated from a uniform distribution. We then enter the sampling loop, iteratively generating NUM_SAMPLES values according to the algorithm introduced above. As proposal distribution we use another Gaussian q, which yields a new value y obtained by “jumping” away from x according to this Gaussian. It is worth noting that evaluating the conditional proposal density, e.g. q(y | x), amounts to evaluating q’s pdf at the jump distance y - x: intuitively, the further we jump, the less likely the jump becomes.
Executing this program should yield a plot in which the histogram of the generated samples closely follows the plotted pdf of the target distribution: we see that we correctly sampled from the “unknown” distribution f.
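As an additional sanity check, we can append the following two lines to the script above and compare the empirical moments of the samples against the parameters of the target N(5, 2):

print(f"Sample mean: {np.mean(samples):.2f}  (target mean: 5)")
print(f"Sample std:  {np.std(samples):.2f}  (target std: 2)")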
Proof of Correctness
To prove the correctness of the Metropolis-Hastings algorithm, we need to show that the stationary distribution of the Markov chain it follows is indeed the target distribution. For this, we use the notion of detailed balance introduced above. Remember, this involves showing that

f(xₜ₋₁) · P(xₜ | xₜ₋₁) = f(xₜ) · P(xₜ₋₁ | xₜ),

where P denotes the transition density of the chain, i.e. it does not matter whether we first visit state xₜ₋₁