This example has been auto-generated from the `examples/` folder of the GitHub repository.

# Hierarchical Gaussian Filter

```julia
# Activate local environment, see Project.toml
import Pkg; Pkg.activate(".."); Pkg.instantiate();
```

In this demo the goal is to perform approximate variational Bayesian inference for a univariate Hierarchical Gaussian Filter (HGF).

A simple HGF model can be defined as:

$$\begin{aligned} x^{(j)}_k & \sim \, \mathcal{N}(x^{(j)}_{k - 1}, f_k(x^{(j - 1)}_k)) \\ y_k & \sim \, \mathcal{N}(x^{(j)}_k, \tau_k) \end{aligned}$$

where $j$ is the index of a layer in the hierarchy, $k$ is the time step and $f_k$ is a variance activation function. RxInfer.jl exports the Gaussian Controlled Variance (GCV) node with the $f_k = \exp(\kappa x + \omega)$ variance activation function. By default the node uses Gauss-Hermite cubature with a prespecified number of approximation points. In this demo we also show how to change the hyperparameters of different approximation methods (in this case Gauss-Hermite cubature) with the help of metadata structures. Here is how our model looks with the GCV node:

$$\begin{aligned} z_k & \sim \, \mathcal{N}(z_{k - 1}, \tau_z) \\ x_k & \sim \, \mathcal{N}(x_{k - 1}, \exp(\kappa z_k + \omega)) \\ y_k & \sim \, \mathcal{N}(x_k, \tau_y) \end{aligned}$$
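To build some intuition for the variance activation before we set up the model, here is a small standalone sketch (plain Julia, no RxInfer required) of how the higher-layer state $z$ modulates the lower-layer variance through $f(z) = \exp(\kappa z + \omega)$. The helper name `variance_activation` is ours, not part of RxInfer:

```julia
# Variance activation used by the GCV node: f(z) = exp(κ * z + ω)
variance_activation(z; κ = 1.0, ω = 0.0) = exp(κ * z + ω)

# With κ = 1 and ω = 0, z = 0 gives unit variance for the lower layer;
# positive z inflates the variance, negative z shrinks it.
variance_activation(0.0)   # 1.0
variance_activation(1.0)   # ≈ 2.718
variance_activation(-1.0)  # ≈ 0.368
```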

In this experiment we will create a single time step of the graph and run a variational message passing filtering algorithm to estimate the hidden states of the system. For a more rigorous introduction to the Hierarchical Gaussian Filter we refer to the paper by Ismail Senoz, *Online Message Passing-based Inference in the Hierarchical Gaussian Filter*.
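The filtering idea itself is independent of RxInfer: the posterior of one time step becomes the prior of the next, so a single time-step graph can be reused for the whole sequence. A minimal standalone sketch for a plain Gaussian random walk (no variance activation), using the standard conjugate Gaussian update — the function `filter_step` is our own illustration, not an RxInfer API:

```julia
# One step of filtering for x_k ~ N(x_{k-1}, q), y_k ~ N(x_k, r):
# predict by adding process noise, then condition on the observation.
function filter_step(m, v, y; q = 1.0, r = 0.01)
    v_pred = v + q                  # predictive variance after the random walk
    gain   = v_pred / (v_pred + r)  # how much we trust the observation
    m_post = m + gain * (y - m)     # posterior mean
    v_post = (1 - gain) * v_pred    # posterior variance
    return m_post, v_post
end

# The posterior of step k is reused as the prior of step k + 1
m, v = foldl((state, y) -> filter_step(state..., y), (0.9, 1.1, 1.0); init = (0.0, 5.0))
# m ends up close to 1.0, and v shrinks well below the initial prior variance
```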

For simplicity we will consider $\tau_z$, $\tau_y$, $\kappa$ and $\omega$ known and fixed, but there are no principled limitations to make them random variables too.

To model this process in RxInfer, we first import all needed packages:

```julia
using RxInfer, BenchmarkTools, Random, Plots
```

The next step is to generate some synthetic data:

```julia
function generate_data(rng, k, w, zv, yv)
    z_prev = 0.0
    x_prev = 0.0

    # `n` (the number of observations) is taken from the global scope, see below
    z = Vector{Float64}(undef, n)
    v = Vector{Float64}(undef, n)
    x = Vector{Float64}(undef, n)
    y = Vector{Float64}(undef, n)

    for i in 1:n
        z[i] = rand(rng, Normal(z_prev, sqrt(zv)))
        v[i] = exp(k * z[i] + w)
        x[i] = rand(rng, Normal(x_prev, sqrt(v[i])))
        y[i] = rand(rng, Normal(x[i], sqrt(yv)))

        z_prev = z[i]
        x_prev = x[i]
    end

    return z, x, y
end
```
```julia
# Seed for reproducibility
seed = 42

rng = MersenneTwister(seed)

# Parameters of HGF process
real_k = 1.0
real_w = 0.0
z_variance = abs2(0.2)
y_variance = abs2(0.1)

# Number of observations
n = 300

z, x, y = generate_data(rng, real_k, real_w, z_variance, y_variance);
```

Let's plot our synthetic dataset. The lines represent the hidden states we want to estimate from the noisy observations.

```julia
let
    pz = plot(title = "Hidden States Z")
    px = plot(title = "Hidden States X")

    plot!(pz, 1:n, z, label = "z_i", color = :orange)
    plot!(px, 1:n, x, label = "x_i", color = :green)
    scatter!(px, 1:n, y, label = "y_i", color = :red, ms = 2, alpha = 0.2)

    plot(pz, px, layout = @layout([ a; b ]))
end
```

To create a model we use the `@model` macro:

```julia
# We create a single time step of the corresponding state-space process to
# perform online learning (filtering)
@model function hgf(real_k, real_w, z_variance, y_variance)

    # Priors from previous time step for z
    zt_min_mean = datavar(Float64)
    zt_min_var  = datavar(Float64)

    # Priors from previous time step for x
    xt_min_mean = datavar(Float64)
    xt_min_var  = datavar(Float64)

    zt_min ~ NormalMeanVariance(zt_min_mean, zt_min_var)
    xt_min ~ NormalMeanVariance(xt_min_mean, xt_min_var)

    # Higher layer is modelled as a random walk
    zt ~ NormalMeanVariance(zt_min, z_variance)

    # Lower layer is modelled with GCV node
    gcvnode, xt ~ GCV(xt_min, zt, real_k, real_w)

    # Noisy observations
    y = datavar(Float64)
    y ~ NormalMeanVariance(xt, y_variance)

    return gcvnode
end
```

```julia
@constraints function hgfconstraints()
    q(xt, zt, xt_min) = q(xt, xt_min)q(zt)
end
```

```julia
@meta function hgfmeta()
    # Let's use 31 approximation points in the Gauss-Hermite cubature approximation method
    GCV(xt, zt) -> GCVMetadata(GaussHermiteCubature(31))
end
```
```julia
function run_inference(data, real_k, real_w, z_variance, y_variance)

    # Use the posteriors from the current time step
    # as the priors of the next time step
    autoupdates = @autoupdates begin
        zt_min_mean, zt_min_var = mean_var(q(zt))
        xt_min_mean, xt_min_var = mean_var(q(xt))
    end

    return rxinference(
        model         = hgf(real_k, real_w, z_variance, y_variance),
        constraints   = hgfconstraints(),
        meta          = hgfmeta(),
        data          = (y = data, ),
        autoupdates   = autoupdates,
        keephistory   = length(data),
        historyvars   = (
            xt = KeepLast(),
            zt = KeepLast()
        ),
        initmarginals = (
            zt = NormalMeanVariance(0.0, 5.0),
            xt = NormalMeanVariance(0.0, 5.0),
        ),
        iterations    = 5,
        free_energy   = true,
        autostart     = true,
        callbacks     = (
            after_model_creation = (model, returnval) -> begin
                gcvnode = returnval
                setmarginal!(gcvnode, :y_x, MvNormalMeanCovariance([ 0.0, 0.0 ], [ 5.0, 5.0 ]))
            end,
        )
    )
end
```
```julia
result = run_inference(y, real_k, real_w, z_variance, y_variance);
```

```julia
mz = result.history[:zt];
mx = result.history[:xt];

let
    pz = plot(title = "Hidden States Z")
    px = plot(title = "Hidden States X")

    plot!(pz, 1:n, z, label = "z_i", color = :orange)
    plot!(pz, 1:n, mean.(mz), ribbon = std.(mz), label = "estimated z_i", color = :teal)

    plot!(px, 1:n, x, label = "x_i", color = :green)
    plot!(px, 1:n, mean.(mx), ribbon = std.(mx), label = "estimated x_i", color = :violet)

    plot(pz, px, layout = @layout([ a; b ]))
end
```
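Beyond visual inspection, one could quantify the fit with a root-mean-square error between the posterior means and the true hidden states. A minimal standalone sketch — the helper `rmse` is our own, not part of RxInfer:

```julia
# Hypothetical helper: root-mean-square error between estimates and truth
rmse(estimates, truth) = sqrt(sum(abs2, estimates .- truth) / length(truth))

# In the notebook one would call e.g. `rmse(mean.(mx), x)`;
# here we check the helper on placeholder vectors instead
rmse([1.0, 2.0, 3.0], [1.0, 2.0, 3.0])  # 0.0
rmse([1.0, 2.0], [0.0, 2.0])            # ≈ 0.707
```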

As we can see from the plot, the estimated signal closely resembles the real hidden states, with small variance. We may also be interested in the values of the Bethe Free Energy functional:

```julia
plot(result.free_energy_history, label = "Bethe Free Energy")
```

As we can see, the Bethe Free Energy converges nicely to a stable point.

Finally, let's check the overall performance of the resulting variational message passing algorithm:

```julia
@benchmark run_inference($y, $real_k, $real_w, $z_variance, $y_variance)
```

```
BenchmarkTools.Trial: 64 samples with 1 evaluation.
 Range (min … max):  59.119 ms … 128.093 ms  ┊ GC (min … max): 0.00% … 0.00%
 Time  (median):     77.214 ms               ┊ GC (median):    0.00%
 Time  (mean ± σ):   79.066 ms ±  13.156 ms  ┊ GC (mean ± σ):  6.22% ± 6.90%

   ▂  ▂▂   ▂▂ ▅█▅       ▂▂▂▂
  ▅█▅▅▅▅█▅▅██▁▁███▅███▁▅▅▅▅▁▅████▁▅▁▁▁▁▁▅▅▁▁▁▁▅▅▁▁▁▁▁▁▁▁▁▁▁▁▁▅ ▁
  59.1 ms         Histogram: frequency by time          119 ms <

 Memory estimate: 18.64 MiB, allocs estimate: 447879.
```