Example: Expectation Propagation / Probit model

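ReactiveMP comes with support for expectation propagation (EP). In this demo we illustrate EP in the context of state estimation in a linear state-space model that combines a Gaussian state-evolution model with a discrete observation model. Here, the probit function links the continuous variable $x_t$ with the discrete (binary) variable $y_t$. In the model below, $\mathcal{N}(\mu, \sigma^2)$ is parameterized by its mean and variance, and $\Phi$ denotes the cumulative distribution function of the standard normal distribution. The model is defined as: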

\[\begin{aligned} u &= 0.1 \\ x_0 &\sim \mathcal{N}(0, 100) \\ x_t &\sim \mathcal{N}(x_{t-1} + u, 0.01) \\ y_t &\sim \mathrm{Ber}(\Phi(x_t)) \end{aligned}\]

Import packages

using StatsFuns: normcdf
using Random, Plots

using ReactiveMP, GraphPPL, Rocket, StableRNGs

Data generation process

"""
    generate_data(nr_samples; seed = 123, u = 0.1, x_init = -2.0, state_var = 0.01)

Simulate `nr_samples` steps of the probit state-space model.

The latent state follows a Gaussian random walk with drift,
`x[k] = x[k - 1] + u + sqrt(state_var) * ε` with `ε ~ N(0, 1)`, starting from the
fixed value `x_init`. Each observation is `1.0` with probability `normcdf(x[k])`
and `0.0` otherwise.

# Arguments
- `nr_samples::Integer`: number of observations to generate; must be non-negative.

# Keywords
- `seed = 123`: seed of the `StableRNG` used for every random draw.
- `u = 0.1`: drift added to the state at every step.
- `x_init = -2.0`: initial state `data_x[1]` (fixed, not sampled from a prior).
- `state_var = 0.01`: variance of the state-transition noise.

# Returns
- `data_x::Vector{Float64}`: `nr_samples + 1` states; `data_x[1]` is the initial state.
- `data_y::Vector{Float64}`: `nr_samples` binary observations; `data_y[k]` is
  generated from `data_x[k + 1]`.

# Throws
- `ArgumentError` if `nr_samples` is negative.
"""
function generate_data(
    nr_samples::Integer;
    seed = 123,
    u = 0.1,
    x_init = -2.0,
    state_var = 0.01,
)
    nr_samples >= 0 ||
        throw(ArgumentError("nr_samples must be non-negative, got $nr_samples"))

    rng = StableRNG(seed)

    # standard deviation of the transition noise; loop-invariant, so compute once
    noise_std = sqrt(state_var)

    # allocate space for data (one extra state for the initial value)
    data_x = zeros(nr_samples + 1)
    data_y = zeros(nr_samples)

    # initialize data
    data_x[1] = x_init

    # generate data; the draw order (randn, then rand) per step fixes the RNG stream
    for k in 2:(nr_samples + 1)
        # calculate new x
        data_x[k] = data_x[k - 1] + u + noise_std * randn(rng)

        # calculate y: Bernoulli(Φ(x)) draw, stored as 1.0 / 0.0
        data_y[k - 1] = normcdf(data_x[k]) > rand(rng)
    end

    return data_x, data_y
end
generate_data (generic function with 1 method)
# Simulate the data set and plot the latent states together with the binary observations.
n = 40
(data_x, data_y) = generate_data(n)

# data_x[1] is the initial state and has no observation, so plot states from index 2.
p = plot(; xlabel = "t", ylabel = "x, y")
p = scatter!(p, data_y; label = "y")
p = plot!(p, data_x[2:end]; label = "x")

Model specification

# Probit state-space model: the latent states `x` follow a Gaussian random walk with
# drift 0.1, and each observation `y[k - 1]` is tied to the state `x[k]` through a
# Probit node. `nr_samples` is the number of observations; `x` has `nr_samples + 1`
# entries because `x[1]` is the initial state, which has no associated observation.
@model function probit_model(nr_samples::Int64)

    # allocate space for variables: latent states and Float64-valued observed data
    x = randomvar(nr_samples + 1)
    y = datavar(Float64, nr_samples)

    # specify uninformative prior: precision 0.01, i.e. variance 100
    x[1] ~ NormalMeanPrecision(0.0, 0.01)

    # create model
    for k = 2:nr_samples + 1
        # state transition with drift 0.1 and precision 100 (variance 0.01)
        x[k] ~ NormalMeanPrecision(x[k - 1] + 0.1, 100)
        # binary observation of the current state through the probit link
        y[k - 1] ~ Probit(x[k]) where {
            # Probit node by default uses RequireInbound pipeline with vague(NormalMeanPrecision) message as initial value for `in` edge
            # To change the initial value you may specify it manually, like
            # pipeline = RequireInbound(in = NormalMeanPrecision(0, 1.0))
        }
    end

    # return the state and observation variables
    return x, y

end

Inference

# Run 10 inference iterations on the probit model conditioned on the observations,
# keeping only the final posterior over `x` and recording the free energy per iteration.
result = inference(
    model        = Model(probit_model, length(data_y)),
    data         = (y = data_y,),
    returnvars   = (x = KeepLast(),),
    iterations   = 10,
    free_energy  = true,
    showprogress = true,
)
Inference results:
-----------------------------------------
Free Energy: Real[NaN, 75.647, 17.3856, 15.6541, 15.6463, 15.6462, 15.6462, 15.6462, 15.6462, 15.6462]
-----------------------------------------
x = NormalWeightedMeanPrecision{Float64}[NormalWeightedMeanPrecision{Float64}(xi=-11...
# Posterior marginals of all latent states. mx[1] is the initial state, which has no
# observation, so `[2:end]` aligns the estimates with data_x[2:end] and data_y.
mx = result.posteriors[:x]

p = plot(xlabel = "t", ylabel = "x, y", legend = :bottomright)
p = scatter!(p, data_y, label = "y")
p = plot!(p, data_x[2:end], label = "x", lw = 2)
# Target `p` explicitly instead of relying on Plots' implicit "current plot".
p = plot!(
    p,
    mean.(mx)[2:end];
    ribbon = std.(mx)[2:end],
    fillalpha = 0.2,
    label = "x (inferred mean)",
)

# Skip the first iteration: its free energy can be NaN (as in the run shown above).
f = plot(xlabel = "t", ylabel = "BFE")
f = plot!(f, result.free_energy[2:end], label = "Bethe Free Energy")

plot(p, f, size = (800, 400))