Graphical Models 2 - Factor Graphs and Sum Product

Factor Graphs

Factor graphs are a type of undirected graphical model. I'm not going to talk about undirected graphical models much here - you can find plenty of good introductions around, and it's not really necessary to have any particular understanding of them for this post.

A factor graph specifies a factorisation of a joint distribution. As such, it is not necessarily unique - one joint distribution can have many valid factor graphs.

As an example, consider $p(x_1, x_2, x_3)$. By the chain rule, this has factorisations:

$p(x_1 | x_2, x_3) p(x_2 | x_3) p(x_3)$
$p(x_2 | x_1, x_3) p(x_1 | x_3) p(x_3)$

etc...

We can write these abstractly as products of factors:
$f_a(x_1, x_2, x_3)f_b(x_2, x_3)f_c(x_3)$
etc...

The factors do not have to be valid distributions - they may be unnormalised. In fact it is common to discard the normaliser entirely, and reinstate it at the end of any computation.
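For instance (a trivial sketch with a made-up factor over a single binary variable):

```python
import numpy as np

# A hypothetical unnormalised factor over one binary variable.
f = np.array([3.0, 1.0])  # f(x) does not sum to 1

# Work with the unnormalised values, then reinstate the normaliser Z at the end.
Z = f.sum()
p = f / Z  # now a valid distribution: [0.75, 0.25]
```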

In [1]:
from IPython.display import Image
Image('images/factorgraph.jpg')
Out[1]:

This factor graph implies $p(x_1, x_2, x_3, x_4, x_5) \propto f_a(x_2, x_3, x_4, x_5)f_b(x_1, x_2)f_c(x_5)$

An example model that would correspond to this factor graph would be:

$p(x_3, x_4| x_2, x_5)p(x_1, x_2)p(x_5)$

Hopefully it is quite clear that a factor graph tells us how groups of variables relate to each other. In the above example, we can see that $x_1$ is related to the rest of the variables only via $f_b$.
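As a quick sanity check, we can realise this factor graph numerically. This is a sketch with hypothetical binary variables and arbitrary non-negative factor values (NumPy arrays standing in for the factors); the normaliser is reinstated at the end.

```python
import numpy as np

rng = np.random.default_rng(0)

# Hypothetical binary factors matching the graph: values are arbitrary.
f_a = rng.random((2, 2, 2, 2))  # f_a(x2, x3, x4, x5)
f_b = rng.random((2, 2))        # f_b(x1, x2)
f_c = rng.random(2)             # f_c(x5)

# Unnormalised joint: p(x1, ..., x5) is proportional to f_a * f_b * f_c.
# Axes: i = x1, j = x2, k = x3, l = x4, m = x5.
joint = np.einsum('jklm,ij,m->ijklm', f_a, f_b, f_c)
joint /= joint.sum()  # reinstate the normaliser at the end
```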

Summation vs. Integration

From here on out I use the integral in a measure-theoretic sense - it denotes an integral for continuous variables and a sum for discrete ones!

Marginalising a factor graph

Taking the above graph, let's say we want to compute $p(x_1)$.

$p(x_1) \propto \int_{x_2} \int_{x_3} \int_{x_4} \int_{x_5} f_a(x_2, x_3, x_4, x_5)f_b(x_1, x_2)f_c(x_5)$

$= \int_{x_2} f_b(x_1, x_2) \int_{x_3} \int_{x_4} \int_{x_5} f_a(x_2, x_3, x_4, x_5)f_c(x_5)$

We can use a different terminology, that is based on the graph rather than summations.

We view marginalisation as a process of propagating information through our distribution.

If we want to compute the marginal for $x_1$, we start at the other nodes, and propagate information from them up the links until we reach $x_1$.

$= \int_{x_2} f_b(x_1, x_2) \int_{x_3} \int_{x_4} \int_{x_5} f_a(x_2, x_3, x_4, x_5)f_c(x_5)$

$= \int_{x_2} f_b(x_1, x_2) \mu_{x_2 \to f_b } = \mu_{f_b \to x_1 } $

Here we use $\mu$ to denote information (or a message). So this says that the information we get when we integrate over all variables but $x_1$ is propagated from $f_b$ only. Information flows into $f_b$ from $x_2$.

So we see that a message from a factor to the next node is given by integrating the product of the factor and all incoming messages over all of the factor's other variables.

Expanding a bit more:

$\mu_{x_2 \to f_b } = \int_{x_3} \int_{x_4} \int_{x_5} f_a(x_2, x_3, x_4, x_5)f_c(x_5)$

From our earlier discussion, this must have the form:

$\mu_{x_2 \to f_b } = \int_{x_3} \int_{x_4} \int_{x_5} f_a(x_2, x_3, x_4, x_5)\mu_{x_3 \to f_a }\mu_{x_4 \to f_a }\mu_{x_5 \to f_a }$

So this means $\mu_{x_3 \to f_a } =1, \ \mu_{x_4 \to f_a }=1 , \ \mu_{x_5 \to f_a }=f_c(x_5)$
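To make this concrete, here is a minimal NumPy sketch of the computation above (the binary factors are hypothetical - any non-negative values would do): we form $\mu_{x_2 \to f_b}$ from the messages into $f_a$, then the marginal of $x_1$, and check it against brute-force marginalisation of the joint.

```python
import numpy as np

rng = np.random.default_rng(1)
f_a = rng.random((2, 2, 2, 2))  # f_a(x2, x3, x4, x5)
f_b = rng.random((2, 2))        # f_b(x1, x2)
f_c = rng.random(2)             # f_c(x5)

# Messages into f_a: x3 and x4 are leaf nodes (they send 1), x5 relays f_c.
mu_x5_to_fa = f_c
# mu_{x2 -> f_b}: integrate f_a over x3, x4, x5, weighted by incoming messages.
mu_x2_to_fb = np.einsum('jklm,m->j', f_a, mu_x5_to_fa)

# Marginal of x1: integrate f_b against the incoming message, then normalise.
p_x1 = np.einsum('ij,j->i', f_b, mu_x2_to_fb)
p_x1 /= p_x1.sum()

# Check against brute-force marginalisation of the full joint.
joint = np.einsum('jklm,ij,m->ijklm', f_a, f_b, f_c)
brute = joint.sum(axis=(1, 2, 3, 4))
brute /= brute.sum()
assert np.allclose(p_x1, brute)
```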

General Sum-Product Algorithm

We can now state some more general rules we can follow to get the marginal we desire.

1) A message from a node to a factor is the product of all the messages it received from its other neighbouring factors. If it is a leaf node (i.e. it receives nothing), it sends 1.
2) A message from a factor to a node is the product of the factor and all incoming messages, integrated over all of the factor's variables except the recipient's. If the factor is a leaf (i.e. the recipient is its only neighbour), it simply sends the factor.

The marginal of any variable is the product of the messages from all its neighbours into it (up to the normaliser we discarded earlier).

Previously we picked a single node to marginalise, but we can easily extend the procedure to get all of them for little extra work.

We pick a root as before, and propagate from the leaves to the root as before. We then reverse the procedure.

We start at the root and propagate the other way using exactly the same rules.

It's easy to see that this will give us messages in both directions for each node-factor pair. So we can just read off all the messages that go into a node of interest, take their product, and we have the marginal of that node.

Another example

Let us consider "inference on a chain". This is the factor graph of a Markov model.

In [2]:
Image('images/chain.jpg')
Out[2]:

$\mu_{f_a \to X_1} = f_a(X_1)$
$\mu_{X_1 \to f_b} = \mu_{f_a \to X_1}$
$\mu_{f_b \to X_2} = \int \mu_{ X_1 \to f_b}f_b(X_1, X_2) dX_1$

and going the other way along the chain:

$\mu_{X_4 \to f_d} = 1$
$\mu_{f_d \to X_3} = \int f_d(X_3, X_4)\mu_{X_4 \to f_d} dX_4 $
$\mu_{X_3 \to f_c} = \mu_{f_d \to X_3}$
$\mu_{f_c \to X_2} = \int f_c(X_2, X_3)\mu_{X_3 \to f_c} dX_3 $

We can show this gives the result we expect by substituting the messages and rearranging the integrals.

$\mu_{f_c \to X_2} = \int f_c(X_2, X_3)(\int f_d(X_3, X_4)\mu_{X_4 \to f_d} dX_4) dX_3 $

$\mu_{f_c \to X_2} = \int f_c(X_2, X_3)\int f_d(X_3, X_4) dX_4 dX_3 $ (using $\mu_{X_4 \to f_d} = 1$)

$\mu_{f_b \to X_2} = \int f_a(X_1)f_b(X_1, X_2) dX_1$

and so:

$\mu_{f_c \to X_2}\mu_{f_b \to X_2} = (\int f_c(X_2, X_3)\int f_d(X_3, X_4) dX_4 dX_3)(\int f_a(X_1)f_b(X_1, X_2) dX_1)$

$ = \int \int \int f_c(X_2, X_3)f_d(X_3, X_4) f_a(X_1)f_b(X_1, X_2) dX_1 dX_4 dX_3 $

which is the traditional marginalisation formula!
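We can also check this numerically. Below is a minimal NumPy sketch of the chain above (the factor values are arbitrary, and I assume $K$ discrete states per variable): we compute the forward and backward messages exactly as listed, and compare the resulting marginal of $X_2$ against brute-force marginalisation of the joint.

```python
import numpy as np

rng = np.random.default_rng(2)
K = 3  # hypothetical number of states per variable
f_a = rng.random(K)       # f_a(X1)
f_b = rng.random((K, K))  # f_b(X1, X2)
f_c = rng.random((K, K))  # f_c(X2, X3)
f_d = rng.random((K, K))  # f_d(X3, X4)

# Forward messages (left to right along the chain).
mu_fa_X1 = f_a
mu_fb_X2 = f_b.T @ mu_fa_X1   # integrate f_b(X1, X2) * mu over X1

# Backward messages (right to left); the leaf node X4 sends 1.
mu_fd_X3 = f_d @ np.ones(K)   # integrate f_d(X3, X4) * 1 over X4
mu_fc_X2 = f_c @ mu_fd_X3     # integrate f_c(X2, X3) * mu over X3

# The marginal of X2 is the product of its incoming messages, normalised.
p_X2 = mu_fb_X2 * mu_fc_X2
p_X2 /= p_X2.sum()

# Brute force: build the full joint and sum out X1, X3, X4.
joint = np.einsum('i,ij,jk,kl->ijkl', f_a, f_b, f_c, f_d)
brute = joint.sum(axis=(0, 2, 3))
brute /= brute.sum()
assert np.allclose(p_X2, brute)
```

Note that message passing never materialises the full $K^4$ joint - it only ever touches $K \times K$ tables, which is the efficiency gain discussed below.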

Summary

The sum-product algorithm can seem somewhat obvious, but it is actually incredibly useful! Firstly, it is efficient - each message is evaluated only once and reused for every marginal we want, in much the same way that dynamic programming avoids repeated work when solving shortest path problems.

It also exploits the factorisation of the underlying distribution, transforming a "global" summation into "local" terms. This has uses in inference with HMMs - we can derive a recursive formula for performing inference and marginalisation.