Example: Linear Gaussian State Space Model
In this example the goal is to estimate the hidden states of a linear dynamical process in which all hidden states are Gaussian random variables. A simple multivariate Linear Gaussian State Space Model can be described with the following equations:
\[\begin{aligned} p(x_i \mid x_{i - 1}) & = \mathcal{N}(x_i \mid A x_{i - 1}, \mathcal{Q}),\\ p(y_i \mid x_i) & = \mathcal{N}(y_i \mid B x_i, \mathcal{P}), \end{aligned}\]
where $x_i$ are hidden states, $y_i$ are noisy observations, $A$ and $B$ are the state transition and observation matrices, and $\mathcal{Q}$ and $\mathcal{P}$ are the state transition noise and observation noise covariance matrices. For a more rigorous introduction to linear Gaussian dynamical systems we refer the reader to Simo Särkkä, Bayesian Filtering and Smoothing.
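Because the model is linear and all distributions are Gaussian, exact inference is tractable in closed form. For reference (these are standard results, see the Särkkä book above), the Kalman filter recursion that exact message passing reproduces on this model alternates a prediction and an update step:
\[\begin{aligned} m_i^- & = A m_{i - 1}, & \Sigma_i^- & = A \Sigma_{i - 1} A^\top + \mathcal{Q},\\ K_i & = \Sigma_i^- B^\top \left( B \Sigma_i^- B^\top + \mathcal{P} \right)^{-1}, & & \\ m_i & = m_i^- + K_i (y_i - B m_i^-), & \Sigma_i & = (I - K_i B) \Sigma_i^-, \end{aligned}\]
where $m_i$ and $\Sigma_i$ parameterize the filtering distribution $p(x_i \mid y_{1:i})$. Belief propagation on the factor graph of this model recovers exactly these updates (and, with backward messages, the smoothing solution).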
To model this process in ReactiveMP, we first import all the needed packages:
using Rocket, ReactiveMP, GraphPPL, Distributions
using BenchmarkTools, Random, LinearAlgebra, Plots
The next step is to generate some synthetic data:
function generate_data(rng, A, B, Q, P, n)
    # Initial state of the process
    x_prev = [ 10.0, -10.0 ]

    x = Vector{Vector{Float64}}(undef, n)
    y = Vector{Vector{Float64}}(undef, n)

    for i in 1:n
        # State transition with noise covariance `Q`
        x[i] = rand(rng, MvNormal(A * x_prev, Q))
        # Noisy observation with noise covariance `P`
        y[i] = rand(rng, MvNormal(B * x[i], P))
        x_prev = x[i]
    end

    return x, y
end
generate_data (generic function with 1 method)
# Seed for reproducibility
seed = 1234
rng = MersenneTwister(seed)
# We will model 2-dimensional observations with rotation matrix `A`
# To avoid clutter we also assume that matrices `A`, `B`, `P` and `Q`
# are known and fixed for all time-steps
θ = π / 35
A = [ cos(θ) -sin(θ); sin(θ) cos(θ) ]
B = diageye(2)
Q = diageye(2)
P = 25.0 .* diageye(2)
# Number of observations
n = 300
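As a quick sanity check (an illustrative addition, not part of the original example), we can verify that A is a proper rotation matrix, i.e. orthogonal with unit determinant:
# Illustrative check: a rotation matrix satisfies A'A = I and det(A) = 1
@assert A' * A ≈ diageye(2)
@assert det(A) ≈ 1.0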
For a large number of observations you will need to use the limit_stack_depth = 100 option during model creation, e.g. model, (x, y) = create_model(model_options(limit_stack_depth = 100), ...)
x, y = generate_data(rng, A, B, Q, P, n)
Let's plot our synthetic dataset. Lines represent the hidden states that we want to estimate from the noisy observations, which are shown as dots.
# Helper: extract the `dim`-th component from a vector of vectors
slicedim(dim) = (a) -> map(e -> e[dim], a)
px = plot()
px = plot!(px, x |> slicedim(1), label = "Hidden Signal (dim-1)", color = :orange)
px = scatter!(px, y |> slicedim(1), label = false, markersize = 2, color = :orange)
px = plot!(px, x |> slicedim(2), label = "Hidden Signal (dim-2)", color = :green)
px = scatter!(px, y |> slicedim(2), label = false, markersize = 2, color = :green)
plot(px)
Model specification
To create a model we use the GraphPPL package and the @model macro:
@model function rotate_ssm(n, x0, A, B, Q, P)

    # We create constvar references for better efficiency
    cA = constvar(A)
    cB = constvar(B)
    cQ = constvar(Q)
    cP = constvar(P)

    # `x` is a sequence of hidden states
    x = randomvar(n)

    # `y` is a sequence of "clamped" observations
    y = datavar(Vector{Float64}, n)

    x_prior ~ MvNormalMeanCovariance(mean(x0), cov(x0))
    x_prev = x_prior

    for i in 1:n
        x[i] ~ MvNormalMeanCovariance(cA * x_prev, cQ)
        y[i] ~ MvNormalMeanCovariance(cB * x[i], cP)
        x_prev = x[i]
    end

    return x, y
end
Inference
Here we create a custom_inference function to infer the hidden states of our system:
function custom_inference(data, x0, A, B, Q, P)
    n = length(data)

    # We create a model and get references for
    # hidden states and observations
    model, (x, y) = rotate_ssm(n, x0, A, B, Q, P);

    xbuffer = buffer(Marginal, n)
    bfe     = nothing

    # We subscribe to posterior marginals of `x`
    xsubscription = subscribe!(getmarginals(x), xbuffer)

    # We are also interested in the Bethe Free Energy functional,
    # which in this case is equal to minus log evidence
    fsubscription = subscribe!(score(BetheFreeEnergy(), model), (v) -> bfe = v)

    # `update!` feeds the observations into our clamped datavars
    update!(y, data)

    # It is important to always unsubscribe
    unsubscribe!((xsubscription, fsubscription))

    return xbuffer, bfe
end
custom_inference (generic function with 1 method)
Alternatively, you can use the generic ReactiveMP inference API:
result = inference(
    model = Model(rotate_ssm, length(y), x0, A, B, Q, P),
    data  = (y = y,)
);
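The returned result object then holds the computed posteriors. Below is a minimal sketch, assuming this version's inference API (the free_energy keyword is an assumption and is not part of the call above):
# Sketch: posterior marginals are keyed by variable name
xmarginals = result.posteriors[:x]
# Bethe Free Energy values would be available as `result.free_energy`,
# but only if `free_energy = true` is also passed to `inference`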
To run inference we also specify a prior for the first time step:
x0 = MvNormalMeanCovariance(zeros(2), 100.0 * diageye(2))
xmarginals, bfe = custom_inference(y, x0, A, B, Q, P)
px = plot()
px = plot!(px, x |> slicedim(1), label = "Hidden Signal (dim-1)", color = :orange)
px = plot!(px, x |> slicedim(2), label = "Hidden Signal (dim-2)", color = :green)
px = plot!(px, mean.(xmarginals) |> slicedim(1), ribbon = var.(xmarginals) |> slicedim(1) .|> sqrt, fillalpha = 0.5, label = "Estimated Signal (dim-1)", color = :teal)
px = plot!(px, mean.(xmarginals) |> slicedim(2), ribbon = var.(xmarginals) |> slicedim(2) .|> sqrt, fillalpha = 0.5, label = "Estimated Signal (dim-2)", color = :violet)
plot(px)
As we can see from the plot, the estimated signal closely resembles the real hidden states, with small posterior variance. We may also be interested in the value of the minus log evidence:
bfe
1882.2434870101347
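Since the factor graph of this model is a tree (a chain), belief propagation computes exact marginals, and the Bethe Free Energy coincides with the negative log evidence:
\[\text{BFE} = -\log p(y_{1:n}).\]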
We may also be interested in the performance of the resulting belief propagation algorithm:
@benchmark custom_inference($y, $x0, $A, $B, $Q, $P)
BenchmarkTools.Trial: 89 samples with 1 evaluation.
Range (min … max): 44.084 ms … 120.368 ms ┊ GC (min … max): 0.00% … 54.56%
Time (median): 50.072 ms ┊ GC (median): 0.00%
Time (mean ± σ): 56.268 ms ± 18.338 ms ┊ GC (mean ± σ): 9.93% ± 15.24%
▁▇█▃▂
▃▃█████▅█▄▃▁▃▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▃▁▁▄▃▁▁▁▁▁▁▃▄▃ ▁
44.1 ms Histogram: frequency by time 116 ms <
Memory estimate: 18.97 MiB, allocs estimate: 348147.