This example was auto-generated from the `examples/` folder in the GitHub repository.

# Probit Model (EP)

# Activate the project environment located one directory up (see its Project.toml) and
# install any dependencies it declares that are not yet available.
import Pkg
Pkg.activate("..")
Pkg.instantiate();

RxInfer comes with support for expectation propagation (EP). In this demo we illustrate EP in the context of state estimation in a linear state-space model that combines a Gaussian state-evolution model with a discrete observation model. Here, the probit link, given by the standard normal cumulative distribution function $\Phi$, connects the continuous variable $x_t$ with the discrete variable $y_t$. The model is defined as follows, where the second parameter of $\mathcal{N}$ denotes the variance:

$$\begin{aligned} u &= 0.1 \\ x_0 &\sim \mathcal{N}(0, 100) \\ x_t &\sim \mathcal{N}(x_{t-1}+ u, 0.01) \\ y_t &\sim \mathrm{Ber}(\Phi(x_t)) \end{aligned}$$

## Import packages

using RxInfer, StableRNGs, Random, Plots
using StatsFuns: normcdf

## Data generation

"""
    generate_data(nr_samples::Integer; seed = 123, u = 0.1, x0 = -2.0, noise_var = 0.01,
                  rng = StableRNG(seed))

Simulate `nr_samples` observations from the probit state-space model of this example.

The latent state starts at `x0` and evolves as a Gaussian random walk with drift,
`x[k] = x[k-1] + u + sqrt(noise_var) * ε` with `ε ~ N(0, 1)`. Each observation is
`y[k-1] = 1.0` with probability `normcdf(x[k])` and `0.0` otherwise.

# Arguments
- `nr_samples`: number of observations to generate (must be non-negative).

# Keywords
- `seed = 123`: seed of the default `StableRNG`; ignored when `rng` is supplied.
- `u = 0.1`: drift added to the state at every step.
- `x0 = -2.0`: initial state, stored in `data_x[1]`.
- `noise_var = 0.01`: variance of the state-transition noise.
- `rng = StableRNG(seed)`: random number generator used for all draws.

# Returns
`(data_x, data_y)`, where `data_x` has length `nr_samples + 1` (the initial state followed
by one state per observation) and `data_y` has length `nr_samples` with entries `0.0`/`1.0`.

# Throws
`ArgumentError` if `nr_samples` is negative.
"""
function generate_data(nr_samples::Integer; seed = 123, u = 0.1, x0 = -2.0,
                       noise_var = 0.01, rng = StableRNG(seed))
    nr_samples >= 0 ||
        throw(ArgumentError("nr_samples must be non-negative, got $nr_samples"))

    # data_x carries one extra slot so that data_x[1] holds the initial state.
    data_x = zeros(nr_samples + 1)
    data_y = zeros(nr_samples)
    data_x[1] = x0

    noise_std = sqrt(noise_var)
    for k in 2:(nr_samples + 1)
        # Random-walk transition with drift u (randn is drawn before rand, as always).
        data_x[k] = data_x[k - 1] + u + noise_std * randn(rng)

        # Probit link: the Bool comparison is stored as 0.0/1.0 in the Float64 vector.
        data_y[k - 1] = normcdf(data_x[k]) > rand(rng)
    end

    return data_x, data_y
end;
# Number of observations to simulate.
n = 40

data_x, data_y = generate_data(n);

# Plot the binary observations together with the latent trajectory. data_x[1] is the
# initial state, which has no observation, so it is dropped to align x with y.
p = plot(xlabel = "t", ylabel = "x, y")
p = scatter!(p, data_y, label = "y")
p = plot!(p, data_x[2:end], label = "x")

## Model specification

# Probit state-space model. The latent states x[1], …, x[nr_samples + 1] follow a Gaussian
# random walk with drift 0.1 (transition precision 100, i.e. variance 0.01), and each
# observation y[k - 1] is tied to the state x[k] through a Probit node.
#   nr_samples — number of observations (length of y); x holds one extra entry, x[1], for
#                the initial state, which has no observation of its own.
# NOTE(review): the drift 0.1 is hard-coded here and must be kept in sync with `u` in
# generate_data.
@model function probit_model(nr_samples::Int64)

# allocate space for variables: latent states x (including the initial state) and observed data y
x = randomvar(nr_samples + 1)
y = datavar(Float64, nr_samples)

# weakly informative prior on the initial state: precision 0.01, i.e. variance 100
x[1] ~ Normal(mean = 0.0, precision = 0.01)

# create model: a state transition followed by the probit observation of the new state
for k = 2:nr_samples + 1
x[k] ~ Normal(mean = x[k - 1] + 0.1, precision = 100)
y[k - 1] ~ Probit(x[k]) where {
# By default, the Probit node uses the RequireMessage pipeline with a vague(NormalMeanPrecision) message as the initial value for the `in` edge.
# To change the initial value, you may specify it manually, as below. Changing the initial message may improve stability in some situations.
pipeline = RequireMessage(in = NormalMeanPrecision(0, 0.01))
}
end

end;

## Inference

# Run message-passing inference for 5 iterations on the observed y, keeping only the
# final posterior over the latent states x and recording the free energy per iteration.
result = inference(
    model       = probit_model(length(data_y)),
    data        = (y = data_y,),
    returnvars  = (x = KeepLast(),),
    free_energy = true,
    iterations  = 5,
)
Inference results:
Posteriors       | available for (x)
Free Energy:     | Real[23.1779, 15.743, 15.6467, 15.6462, 15.6462]

## Results

# Posterior marginals over the latent states (last iteration only, as requested via KeepLast()).
mx = result.posteriors[:x]

# x[1] is the initial state and has no matching observation; drop it to align with y.
mx_mean = mean.(mx)[2:end]
mx_std  = std.(mx)[2:end]

p = plot(xlabel = "t", ylabel = "x, y", legend = :bottomright)
p = scatter!(p, data_y, label = "y")
p = plot!(p, data_x[2:end], label = "x", lw = 2)
# Target `p` explicitly rather than relying on the implicit "current plot".
p = plot!(p, mx_mean, ribbon = mx_std, fillalpha = 0.2, label = "x (inferred mean)")

# One free-energy value is recorded per inference iteration, so the x-axis is the
# iteration index, not time.
f = plot(xlabel = "iteration", ylabel = "BFE")
f = plot!(f, result.free_energy, label = "Bethe Free Energy")

plot(p, f, size = (800, 400))