This example has been auto-generated from the examples/
folder at GitHub repository.
Probit Model (EP)
# Activate local environment, see `Project.toml`
import Pkg; Pkg.activate(".."); Pkg.instantiate();
RxInfer
comes with support for expectation propagation (EP). In this demo we illustrate EP in the context of state-estimation in a linear state-space model that combines a Gaussian state-evolution model with a discrete observation model. Here, the probit function links continuous variable $x_t$ with the discrete variable $y_t$. The model is defined as:
\[\begin{aligned} u &= 0.1 \\ x_0 &\sim \mathcal{N}(0, 100) \\ x_t &\sim \mathcal{N}(x_{t-1}+ u, 0.01) \\ y_t &\sim \mathrm{Ber}(\Phi(x_t)) \end{aligned}\]
Import packages
using RxInfer, StableRNGs, Random, Plots
using StatsFuns: normcdf
Data generation
function generate_data(nr_samples::Int64; seed = 123)
rng = StableRNG(seed)
# hyper parameters
u = 0.1
# allocate space for data
data_x = zeros(nr_samples + 1)
data_y = zeros(nr_samples)
# initialize data
data_x[1] = -2
# generate data
for k = 2:nr_samples + 1
# calculate new x
data_x[k] = data_x[k-1] + u + sqrt(0.01)*randn(rng)
# calculate y
data_y[k-1] = normcdf(data_x[k]) > rand(rng)
end
# return data
return data_x, data_y
end;
n = 40
40
data_x, data_y = generate_data(n);
p = plot(xlabel = "t", ylabel = "x, y")
p = scatter!(p, data_y, label = "y")
p = plot!(p, data_x[2:end], label = "x")
Model specification
@model function probit_model(nr_samples::Int64)
# allocate space for variables
x = randomvar(nr_samples + 1)
y = datavar(Float64, nr_samples)
# specify uninformative prior
x[1] ~ Normal(mean = 0.0, precision = 0.01)
# create model
for k = 2:nr_samples + 1
x[k] ~ Normal(mean = x[k - 1] + 0.1, precision = 100)
y[k - 1] ~ Probit(x[k]) where {
# Probit node by default uses RequireMessage pipeline with vague(NormalMeanPrecision) message as initial value for `in` edge
# To change initial value use may specify it manually, like. Changes to the initial message may improve stability in some situations
pipeline = RequireMessage(in = NormalMeanPrecision(0, 0.01))
}
end
end;
Inference
result = inference(
model = probit_model(length(data_y)),
data = (y = data_y, ),
iterations = 5,
returnvars = (x = KeepLast(),),
free_energy = true
)
Inference results:
Posteriors | available for (x)
Free Energy: | Real[23.1779, 15.743, 15.6467, 15.6462, 15.6462]
Results
mx = result.posteriors[:x]
p = plot(xlabel = "t", ylabel = "x, y", legend = :bottomright)
p = scatter!(p, data_y, label = "y")
p = plot!(p, data_x[2:end], label = "x", lw = 2)
p = plot!(mean.(mx)[2:end], ribbon = std.(mx)[2:end], fillalpha = 0.2, label="x (inferred mean)")
f = plot(xlabel = "t", ylabel = "BFE")
f = plot!(result.free_energy, label = "Bethe Free Energy")
plot(p, f, size = (800, 400))