This example has been auto-generated from the examples/ folder of the GitHub repository.

# Kalman filtering and smoothing

# Activate local environment, see Project.toml
import Pkg; Pkg.activate(".."); Pkg.instantiate();

In the following set of examples, the goal is to estimate the hidden states of a dynamical process in which all hidden states are Gaussian.

We start our journey with a simple multivariate Linear Gaussian State Space Model (LGSSM), which can be solved analytically.

We then solve an identification problem which does not have an analytical solution.

Finally, we show how RxInfer.jl can deal with missing observations.

## Gaussian Linear Dynamical System

An LGSSM can be described by the following equations:

\begin{aligned} p(x_i|x_{i - 1}) & = \mathcal{N}(x_i|A x_{i - 1}, \mathcal{Q}),\\ p(y_i|x_i) & = \mathcal{N}(y_i|B x_i, \mathcal{P}), \end{aligned}

where $x_i$ are the hidden states, $y_i$ are the noisy observations, $A$ and $B$ are the state transition and observation matrices, and $\mathcal{Q}$ and $\mathcal{P}$ are the state transition noise and observation noise covariance matrices. For a more rigorous introduction to linear Gaussian dynamical systems, we refer to Simo Särkkä's book Bayesian Filtering and Smoothing.
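Because every factor in this model is linear and Gaussian, the belief over $x_i$ given $y_1, \dots, y_i$ stays Gaussian and can be updated in closed form with the Kalman filter recursions (predict, then correct with the new observation). The minimal stdlib-only sketch below is meant only to build intuition. `kalman_filter` and `simulate_lgssm` are hypothetical helpers, not RxInfer functions, and the seed (42) is arbitrary. It uses the same $A$, $B$, $\mathcal{Q}$, $\mathcal{P}$ settings as this example.

```julia
using LinearAlgebra, Random

# Hypothetical helper (not part of RxInfer): Kalman filter for
# p(x_i | x_{i-1}) = N(A x_{i-1}, Q) and p(y_i | x_i) = N(B x_i, P).
# Returns means and covariances of the filtered beliefs p(x_i | y_1, ..., y_i).
function kalman_filter(y, A, B, Q, P, m0, V0)
    n = length(y)
    m = Vector{Vector{Float64}}(undef, n)
    V = Vector{Matrix{Float64}}(undef, n)
    m_prev, V_prev = m0, V0
    for i in 1:n
        # Predict: propagate the previous belief through the linear dynamics
        m_pred = A * m_prev
        V_pred = A * V_prev * A' + Q
        # Update: correct the prediction with the observation y[i]
        S = B * V_pred * B' + P          # innovation covariance
        K = V_pred * B' / S              # Kalman gain
        m[i] = m_pred + K * (y[i] - B * m_pred)
        V[i] = (I - K * B) * V_pred
        m_prev, V_prev = m[i], V[i]
    end
    return m, V
end

# Hypothetical helper: draw a trajectory from the same model (stdlib only)
function simulate_lgssm(rng, n, A, B, Q, P, x_init)
    LQ, LP = cholesky(Q).L, cholesky(P).L
    x = Vector{Vector{Float64}}(undef, n)
    y = Vector{Vector{Float64}}(undef, n)
    x_prev = x_init
    for i in 1:n
        x[i] = A * x_prev + LQ * randn(rng, length(x_init))
        y[i] = B * x[i] + LP * randn(rng, size(B, 1))
        x_prev = x[i]
    end
    return x, y
end

kf_θ = π / 35
kf_A = [cos(kf_θ) -sin(kf_θ); sin(kf_θ) cos(kf_θ)]
Id2  = Matrix(1.0I, 2, 2)

kf_x, kf_y = simulate_lgssm(MersenneTwister(42), 300, kf_A, Id2, Id2, 25.0 .* Id2, [10.0, -10.0])
kf_m, kf_V = kalman_filter(kf_y, kf_A, Id2, Id2, 25.0 .* Id2, zeros(2), 100.0 .* Id2)

# Mean squared error of the filtered means vs. the raw noisy observations
mse_filter = sum(sum(abs2, kf_m[i] - kf_x[i]) for i in eachindex(kf_x)) / length(kf_x)
mse_obs    = sum(sum(abs2, kf_y[i] - kf_x[i]) for i in eachindex(kf_x)) / length(kf_x)
```

With observation noise variance 25 per dimension, the filtered means should track the hidden states much more closely than the raw observations do.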

To model this process in RxInfer, we first import all the needed packages:

using RxInfer, BenchmarkTools, Random, LinearAlgebra, Plots

The next step is to generate some synthetic data:

function generate_data(rng, n, A, B, Q, P)
    x_prev = [ 10.0, -10.0 ]

    x = Vector{Vector{Float64}}(undef, n)
    y = Vector{Vector{Float64}}(undef, n)

    for i in 1:n
        # Q is the state transition noise, P is the observation noise
        x[i] = rand(rng, MvNormal(A * x_prev, Q))
        y[i] = rand(rng, MvNormal(B * x[i], P))
        x_prev = x[i]
    end

    return x, y
end

# Seed for reproducibility
seed = 1234

rng = MersenneTwister(seed)

# We will model 2-dimensional observations with rotation matrix A
# To avoid clutter we also assume that matrices A, B, P and Q
# are known and fixed for all time-steps
θ = π / 35
A = [ cos(θ) -sin(θ); sin(θ) cos(θ) ]
B = diageye(2)
Q = diageye(2)
P = 25.0 .* diageye(2)

# Number of observations
n = 300;
x, y = generate_data(rng, n, A, B, Q, P);
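The rotation matrix A explains the slowly oscillating shape of the hidden signal. A rotation is orthogonal with determinant 1, so the noise-free dynamics $x_i = A x_{i-1}$ neither grow nor decay, and each step turns the state by θ = π/35. A quick stdlib check, where `A_rot` is an illustrative copy of A:

```julia
using LinearAlgebra

θ_rot = π / 35
A_rot = [cos(θ_rot) -sin(θ_rot); sin(θ_rot) cos(θ_rot)]
I2 = Matrix(1.0I, 2, 2)

is_orthogonal = A_rot' * A_rot ≈ I2      # true: lengths are preserved
unit_det      = det(A_rot) ≈ 1.0         # true: a proper rotation

# 70 steps of π/35 add up to 2π, i.e. one full turn
full_turn = A_rot^70
```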

Let's plot our synthetic dataset. Lines represent the hidden states we want to estimate from the noisy observations, which are shown as dots.

px = plot()

px = plot!(px, getindex.(x, 1), label = "Hidden Signal (dim-1)", color = :orange)
px = scatter!(px, getindex.(y, 1), label = false, markersize = 2, color = :orange)
px = plot!(px, getindex.(x, 2), label = "Hidden Signal (dim-2)", color = :green)
px = scatter!(px, getindex.(y, 2), label = false, markersize = 2, color = :green)

plot(px)

To create a model we use the GraphPPL package and the @model macro:

@model function rotate_ssm(n, x0, A, B, Q, P)

# We create constvar references for better efficiency
cA = constvar(A)
cB = constvar(B)
cQ = constvar(Q)
cP = constvar(P)

# x is a sequence of hidden states
x = randomvar(n)
# y is a sequence of "clamped" observations
y = datavar(Vector{Float64}, n)

x_prior ~ MvNormalMeanCovariance(mean(x0), cov(x0))
x_prev = x_prior

for i in 1:n
x[i] ~ MvNormalMeanCovariance(cA * x_prev, cQ)
y[i] ~ MvNormalMeanCovariance(cB * x[i], cP)
x_prev = x[i]
end

end
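When inference is run on this model with all observations at once, each marginal of x is conditioned on the whole sequence $y_1, \dots, y_n$, so it is a smoothing (not a filtering) distribution. For this linear Gaussian model, the same marginals can also be obtained in closed form with a forward Kalman filter followed by a backward Rauch-Tung-Striebel (RTS) pass, and they should coincide up to numerical error. The sketch below is for intuition only. `rts_smoother` is a hypothetical helper, not RxInfer's implementation, and the toy observation sequence `y_demo` is made up.

```julia
using LinearAlgebra

# Hypothetical helper: closed-form smoothing for
# x_i = A x_{i-1} + N(0, Q), y_i = B x_i + N(0, P)
function rts_smoother(y, A, B, Q, P, m0, V0)
    n = length(y)
    mf = Vector{Vector{Float64}}(undef, n); Vf = Vector{Matrix{Float64}}(undef, n)  # filtered
    mp = Vector{Vector{Float64}}(undef, n); Vp = Vector{Matrix{Float64}}(undef, n)  # predicted
    m, V = m0, V0
    for i in 1:n                                   # forward (filtering) pass
        mp[i] = A * m
        Vp[i] = A * V * A' + Q
        K = Vp[i] * B' / (B * Vp[i] * B' + P)      # Kalman gain
        mf[i] = mp[i] + K * (y[i] - B * mp[i])
        Vf[i] = (I - K * B) * Vp[i]
        m, V = mf[i], Vf[i]
    end
    ms, Vs = copy(mf), copy(Vf)                    # the last filtered belief is already smoothed
    for i in (n - 1):-1:1                          # backward (smoothing) pass
        G = Vf[i] * A' / Vp[i + 1]                 # smoother gain
        ms[i] = mf[i] + G * (ms[i + 1] - mp[i + 1])
        Vs[i] = Vf[i] + G * (Vs[i + 1] - Vp[i + 1]) * G'
    end
    return ms, Vs, mf, Vf
end

# Usage on a short deterministic toy sequence (illustrative values only)
θ_demo = π / 35
A_demo = [cos(θ_demo) -sin(θ_demo); sin(θ_demo) cos(θ_demo)]
I2 = Matrix(1.0I, 2, 2)
y_demo = [10.0 .* [cos(0.1 * i), sin(0.1 * i)] for i in 1:50]

ms, Vs, mf, Vf = rts_smoother(y_demo, A_demo, I2, I2, 25.0 .* I2, zeros(2), 100.0 .* I2)
```

Smoothing uses future observations too, so the smoothed covariances are never larger than the filtered ones.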

To run inference we also specify a prior for our first hidden state:

x0 = MvNormalMeanCovariance(zeros(2), 100.0 * diageye(2));
# For a large number of observations you may need to set the limit_stack_depth option, e.g.
# inference(..., options = (limit_stack_depth = 500, ))
result = inference(
model = rotate_ssm(length(y), x0, A, B, Q, P),
data = (y = y,),
free_energy = true
);

xmarginals  = result.posteriors[:x]
logevidence = -result.free_energy; # given the analytical solution, free energy will be equal to the negative log evidence
px = plot()

px = plot!(px, getindex.(x, 1), label = "Hidden Signal (dim-1)", color = :orange)
px = plot!(px, getindex.(x, 2), label = "Hidden Signal (dim-2)", color = :green)

px = plot!(px, getindex.(mean.(xmarginals), 1), ribbon = getindex.(var.(xmarginals), 1) .|> sqrt, fillalpha = 0.5, label = "Estimated Signal (dim-1)", color = :teal)
px = plot!(px, getindex.(mean.(xmarginals), 2), ribbon = getindex.(var.(xmarginals), 2) .|> sqrt, fillalpha = 0.5, label = "Estimated Signal (dim-2)", color = :violet)

plot(px)

As we can see from our plot, the estimated signal closely resembles the real hidden states, with small variance. We may also be interested in the value of the log evidence (the negative free energy):

logevidence
1-element Vector{Float64}:
-1882.2434870099778
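For a linear Gaussian model, the log evidence also has a closed form via the prediction-error decomposition: $\log p(y_{1:n}) = \sum_i \log \mathcal{N}(y_i \mid B \hat{m}_i, S_i)$, where $\hat{m}_i$ and $S_i$ come from the predict step of a Kalman filter. The sketch below is a hedged cross-check, not how RxInfer computes its free energy. `mvnormal_logpdf` and `kalman_log_evidence` are hypothetical helpers, checked here on a one-observation case whose answer is known analytically.

```julia
using LinearAlgebra

# Hypothetical helper: log-density of N(μ, Σ) at v, via a Cholesky factorisation
function mvnormal_logpdf(v, μ, Σ)
    C = cholesky(Symmetric(Σ))
    r = C.L \ (v - μ)
    return -0.5 * (length(v) * log(2π) + logdet(C) + dot(r, r))
end

# Hypothetical helper: log p(y_1:n) accumulated during a Kalman filter pass
function kalman_log_evidence(y, A, B, Q, P, m0, V0)
    m, V = m0, V0
    logZ = 0.0
    for yi in y
        m_pred = A * m
        V_pred = A * V * A' + Q
        S = B * V_pred * B' + P                    # predictive covariance of yi
        logZ += mvnormal_logpdf(yi, B * m_pred, S)
        K = V_pred * B' / S
        m = m_pred + K * (yi - B * m_pred)
        V = (I - K * B) * V_pred
    end
    return logZ
end

# Sanity check with one 1-D observation y_1 = 1: y_1 ~ N(0, V0 + Q + P) = N(0, 3)
U1 = fill(1.0, 1, 1)
logZ_1d = kalman_log_evidence([[1.0]], U1, U1, U1, U1, [0.0], U1)
expected_1d = -0.5 * (log(2π) + log(3.0) + 1.0 / 3.0)
```

Assuming the free energy is exact for this model, as the comment in the inference code states, `kalman_log_evidence(y, A, B, Q, P, zeros(2), 100.0 .* diageye(2))` on this example's data should agree with `logevidence` up to numerical error.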

We may also be interested in the performance of the resulting belief propagation algorithm:

@benchmark inference(
    model = rotate_ssm(length($y), $x0, $A, $B, $Q, $P),
    data = (y = $y,)
)
BenchmarkTools.Trial: 145 samples with 1 evaluation.
 Range (min … max):  25.289 ms … 105.283 ms  ┊ GC (min … max): 0.00% … 17.25%
 Time  (median):     30.949 ms               ┊ GC (median):    0.00%
 Time  (mean ± σ):   34.474 ms ±   9.607 ms  ┊ GC (mean ± σ):  8.58% ± 12.66%

 Memory estimate: 9.65 MiB, allocs estimate: 223850.

## System Identification Problem

In this example we are going to run Bayesian inference to decouple two random-walk signals that were combined into a single signal through some deterministic function f. We do not have access to the real values of these signals, only to their combination.

First, we create the generate_data function that accepts f as an argument:

using RxInfer, Distributions, StableRNGs, Plots

function generate_data(f, n; seed = 123, x_i_min = -20.0, w_i_min = 20.0, noise = 20.0, real_x_τ = 0.1, real_w_τ = 1.0)
    rng = StableRNG(seed)

    real_x = Vector{Float64}(undef, n)
    real_w = Vector{Float64}(undef, n)
    real_y = Vector{Float64}(undef, n)

    for i in 1:n
        real_x[i] = rand(rng, Normal(x_i_min, sqrt(1.0 / real_x_τ)))
        real_w[i] = rand(rng, Normal(w_i_min, sqrt(1.0 / real_w_τ)))
        real_y[i] = rand(rng, Normal(f(real_x[i], real_w[i]), sqrt(noise)))

        x_i_min = real_x[i]
        w_i_min = real_w[i]
    end

    return real_x, real_w, real_y
end

The function returns the real signals real_x and real_w for later comparison (we are not going to use them during inference) and their combined version real_y, which we use as our observations during inference. We also assume that real_y is corrupted with some measurement noise.

### Combination 1: y = x + w

In our first example, we are going to use simple addition (+) as the function f. In general, it is impossible to decouple the signals x and w without strong priors, but we can try and see how good the inference can be.
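Why strong priors matter here can be seen without any inference at all. Adding an arbitrary sequence c to x and subtracting it from w leaves y = x + w unchanged, so the observations alone cannot distinguish the two decompositions, and only the priors can break the tie. A small stdlib sketch follows. The random walks, the seed, and the shift `c` are made-up illustrative values, and the `id_` prefix is only there to avoid clashing with this example's variables.

```julia
using Random

# Two illustrative random walks and their sum
rng  = MersenneTwister(0)
id_n = 100
id_x = cumsum(randn(rng, id_n)) .- 20.0
id_w = cumsum(randn(rng, id_n)) .+ 20.0
id_y = id_x .+ id_w

# Shift x by an arbitrary smooth sequence c and w by -c
c = 5.0 .* sin.(range(0, 2π, length = id_n))
x_alt = id_x .+ c
w_alt = id_w .- c
y_alt = x_alt .+ w_alt    # identical to id_y up to floating-point rounding
```

Because `c` here is smooth, a random-walk prior barely prefers one decomposition over the other, which is why decoupling x and w is hard.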
The + operation on two random variables also has a special meaning in probabilistic inference, namely the convolution of the pdfs of the two random variables. RxInfer treats it specially with many precomputed analytical rules, which may make the inference task easier.

First, let us create a test dataset:

n = 250
real_x, real_w, real_y = generate_data(+, n);

pl = plot(title = "Underlying signals")
pl = plot!(pl, real_x, label = "x")
pl = plot!(pl, real_w, label = "w")

pr = plot(title = "Combined y = x + w")
pr = scatter!(pr, real_y, ms = 3, color = :red, label = "y")

plot(pl, pr, size = (800, 300))

To run inference, we need to create a probabilistic model: our beliefs about how our data could have been generated. For this we can use the @model macro from RxInfer.jl:

@model function identification_problem(f, n, m_x_0, τ_x_0, a_x, b_x, m_w_0, τ_w_0, a_w, b_w, a_y, b_y)

    x0 ~ Normal(mean = m_x_0, precision = τ_x_0)
    τ_x ~ Gamma(shape = a_x, rate = b_x)
    w0 ~ Normal(mean = m_w_0, precision = τ_w_0)
    τ_w ~ Gamma(shape = a_w, rate = b_w)
    τ_y ~ Gamma(shape = a_y, rate = b_y)

    x = randomvar(n)
    w = randomvar(n)
    s = randomvar(n)
    y = datavar(Float64, n)

    x_i_min = x0
    w_i_min = w0

    for i in 1:n
        x[i] ~ Normal(mean = x_i_min, precision = τ_x)
        w[i] ~ Normal(mean = w_i_min, precision = τ_w)
        s[i] ~ f(x[i], w[i])
        y[i] ~ Normal(mean = s[i], precision = τ_y)

        x_i_min = x[i]
        w_i_min = w[i]
    end

end

RxInfer runs Bayesian inference as a variational optimisation procedure that minimises the variational free energy, i.e. the discrepancy between the exact posterior and its variational proxy q. In our model specification we assumed the noise components to be unknown; thus, we need to enforce a structured mean-field assumption for the variational family of distributions q.
This inevitably reduces the accuracy of the result, but it makes the task easier and allows for fast, analytical message passing-based variational inference:

constraints = @constraints begin
    q(x0, w0, x, w, τ_x, τ_w, τ_y, s) = q(x, x0, w, w0, s)q(τ_w)q(τ_x)q(τ_y)
end

The next step is to assign priors, initialise the needed messages and marginals, and call the inference function:

m_x_0, τ_x_0 = -20.0, 1.0
m_w_0, τ_w_0 = 20.0, 1.0

# We set relatively strong priors for the random walk noise components
# and a relatively vague prior for the noise of the observations
a_x, b_x = 0.01, 0.01 * var(real_x)
a_w, b_w = 0.01, 0.01 * var(real_w)
a_y, b_y = 1.0, 1.0

# We set relatively strong initial messages for x and w
xinit = map(r -> NormalMeanPrecision(r, τ_x_0), reverse(range(-60, -20, length = n)))
winit = map(r -> NormalMeanPrecision(r, τ_w_0), range(20, 60, length = n))

imessages = (x = xinit, w = winit)
imarginals = (τ_x = GammaShapeRate(a_x, b_x), τ_w = GammaShapeRate(a_w, b_w), τ_y = GammaShapeRate(a_y, b_y))

result = inference(
    model = identification_problem(+, n, m_x_0, τ_x_0, a_x, b_x, m_w_0, τ_w_0, a_w, b_w, a_y, b_y),
    data = (y = real_y,),
    options = (limit_stack_depth = 500, ),
    constraints = constraints,
    initmessages = imessages,
    initmarginals = imarginals,
    iterations = 50
)
Inference results:
  Posteriors | available for (w0, w, x0, s, τ_x, τ_w, τ_y, x)

Let's examine our inference results:

τ_x_marginals = result.posteriors[:τ_x]
τ_w_marginals = result.posteriors[:τ_w]
τ_y_marginals = result.posteriors[:τ_y]

smarginals = result.posteriors[:s]
xmarginals = result.posteriors[:x]
wmarginals = result.posteriors[:w];

px1 = plot(legend = :bottomleft, title = "Estimated hidden signals")
px2 = plot(legend = :bottomright, title = "Estimated combined signals")

px1 = plot!(px1, real_x, label = "Real hidden X")
px1 = plot!(px1, mean.(xmarginals[end]), ribbon = var.(xmarginals[end]), label = "Estimated X")

px1 = plot!(px1, real_w, label = "Real hidden W")
px1 = plot!(px1, mean.(wmarginals[end]), ribbon = var.(wmarginals[end]), label = "Estimated W")

px2 = scatter!(px2, real_y, label = "Observations", ms = 2, alpha = 0.5, color = :red)
px2 = plot!(px2, mean.(smarginals[end]), ribbon = std.(smarginals[end]), label = "Combined estimated signal", color = :green)

plot(px1, px2, size = (800, 300))

The inference results are not so bad, even though RxInfer missed the correct values of the signals between time steps 100 and 150.

### Combination 2: y = min(x, w)

In this example we use a slightly more complex function, for which RxInfer does not have precomputed analytical message update rules. We are going to attempt to run Bayesian inference with min as the combination function. Note, however, that using min directly may cause problems for the built-in approximation methods, because its partial derivatives with respect to all but one of its arguments are zero. We therefore generate the data with the min function itself but model it with a slightly smoothed version:

# Smoothed version of min without zeroed derivatives
function smooth_min(x, y)
    if x < y
        return x + 1e-4 * y
    else
        return y + 1e-4 * x
    end
end

RxInfer supports arbitrary nonlinear functions, but it requires an explicit approximation method specification.
That can be achieved with the built-in @meta macro:

min_meta = @meta begin
    # In this example we are going to use a simple Linearization method
    smooth_min() -> Linearization()
end

n = 200
min_real_x, min_real_w, min_real_y = generate_data(min, n, seed = 1, x_i_min = 0.0, w_i_min = 0.0, noise = 1.0, real_x_τ = 1.0, real_w_τ = 1.0);

pl = plot(title = "Underlying signals")
pl = plot!(pl, min_real_x, label = "x")
pl = plot!(pl, min_real_w, label = "w")

pr = plot(title = "Combined y = min(x, w)")
pr = scatter!(pr, min_real_y, ms = 3, color = :red, label = "y")

plot(pl, pr, size = (800, 300))

min_m_x_0, min_τ_x_0 = -1.0, 1.0
min_m_w_0, min_τ_w_0 = 1.0, 1.0

min_a_x, min_b_x = 1.0, 1.0
min_a_w, min_b_w = 1.0, 1.0
min_a_y, min_b_y = 1.0, 1.0

min_imessages = (x = NormalMeanPrecision(min_m_x_0, min_τ_x_0), w = NormalMeanPrecision(min_m_w_0, min_τ_w_0))
min_imarginals = (τ_x = GammaShapeRate(min_a_x, min_b_x), τ_w = GammaShapeRate(min_a_w, min_b_w), τ_y = GammaShapeRate(min_a_y, min_b_y))

min_result = inference(
    model = identification_problem(smooth_min, n, min_m_x_0, min_τ_x_0, min_a_x, min_b_x, min_m_w_0, min_τ_w_0, min_a_w, min_b_w, min_a_y, min_b_y),
    data = (y = min_real_y,),
    meta = min_meta,
    options = (limit_stack_depth = 500, ),
    constraints = constraints,
    initmessages = min_imessages,
    initmarginals = min_imarginals,
    iterations = 100
)
Inference results:
  Posteriors | available for (w0, w, x0, s, τ_x, τ_w, τ_y, x)

min_τ_x_marginals = min_result.posteriors[:τ_x]
min_τ_w_marginals = min_result.posteriors[:τ_w]
min_τ_y_marginals = min_result.posteriors[:τ_y]

min_smarginals = min_result.posteriors[:s]
min_xmarginals = min_result.posteriors[:x]
min_wmarginals = min_result.posteriors[:w]

px1 = plot(legend = :bottomleft, title = "Estimated hidden signals")
px2 = plot(legend = :bottomright, title = "Estimated combined signals")

px1 = plot!(px1, min_real_x, label = "Real hidden X")
px1 = plot!(px1, mean.(min_xmarginals[end]), ribbon = var.(min_xmarginals[end]), label = "Estimated X")

px1 = plot!(px1, min_real_w, label = "Real hidden W")
px1 = plot!(px1, mean.(min_wmarginals[end]), ribbon = var.(min_wmarginals[end]), label = "Estimated W")

px2 = scatter!(px2, min_real_y, label = "Observations", ms = 2, alpha = 0.5, color = :red)
px2 = plot!(px2, mean.(min_smarginals[end]), ribbon = std.(min_smarginals[end]), label = "Combined estimated signal", color = :green)

plot(px1, px2, size = (800, 300))

As we can see, inference with the min function is significantly harder. Even though the combined signal has been inferred with high precision, the underlying x and w signals are barely recovered. This may be expected, since the min function essentially destroys the information about one of the signals, which makes it impossible to decouple two seemingly identical random-walk signals. Only the lower signal is inferred; we have no information about the signal that lies above it. It might still be possible to infer the states with more informative priors and structural information about the two signals (e.g. if they are not random walks).
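Some intuition for why only the lower signal is recovered comes from first-order linearization, the idea behind linearization-based approximations such as the extended Kalman filter. In this approach f(x, w) is replaced by its tangent plane at the input means, so the gradient determines how much information flows to each input. The helper `linearize` below is hypothetical and uses central finite differences. It is not RxInfer's `Linearization()` implementation. It shows that the slope of min with respect to the larger input is exactly zero, while smooth_min leaves only a tiny slope of 1e-4.

```julia
# Hypothetical helper (not RxInfer's implementation): first-order linearization of f(x, w)
# around the means of independent Gaussian inputs x ~ N(mx, vx), w ~ N(mw, vw)
function linearize(f, mx, vx, mw, vw; h = 1e-6)
    # Central finite-difference gradient of f at (mx, mw)
    gx = (f(mx + h, mw) - f(mx - h, mw)) / (2h)
    gw = (f(mx, mw + h) - f(mx, mw - h)) / (2h)
    # Gaussian approximation of f(x, w): mean f(mx, mw), variance ∇f' Σ ∇f
    return (mean = f(mx, mw), var = gx^2 * vx + gw^2 * vw, grad = (gx, gw))
end

# Same smoothed min as in the model above
function smooth_min(x, y)
    if x < y
        return x + 1e-4 * y
    else
        return y + 1e-4 * x
    end
end

# Inputs centred at x = -1 (the lower signal) and w = 1 (the upper signal)
lin_min    = linearize(min, -1.0, 1.0, 1.0, 1.0)
lin_smooth = linearize(smooth_min, -1.0, 1.0, 1.0, 1.0)
```

For min, the gradient with respect to w is exactly zero, so the linearized output variance depends on x alone. smooth_min keeps the approximation well defined, but w still contributes almost nothing.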
### Online (filtering) identification: y = min(x, w)

Another way to approach this problem is to use the online (filtering) inference procedure from RxInfer, but for that we also need to modify our model specification a bit:

@model function rx_identification(f)
    # We are going to continuously update our priors
    # based on new posteriors
    m_x_0 = datavar(Float64)
    τ_x_0 = datavar(Float64)
    m_w_0 = datavar(Float64)
    τ_w_0 = datavar(Float64)
    a_x = datavar(Float64)
    b_x = datavar(Float64)
    a_y = datavar(Float64)
    b_y = datavar(Float64)
    a_w = datavar(Float64)
    b_w = datavar(Float64)

    s = randomvar()
    y = datavar(Float64)

    x0 ~ Normal(mean = m_x_0, precision = τ_x_0)
    τ_x ~ Gamma(shape = a_x, rate = b_x)
    w0 ~ Normal(mean = m_w_0, precision = τ_w_0)
    τ_w ~ Gamma(shape = a_w, rate = b_w)
    τ_y ~ Gamma(shape = a_y, rate = b_y)

    x ~ Normal(mean = x0, precision = τ_x)
    w ~ Normal(mean = w0, precision = τ_w)
    s ~ f(x, w)
    y ~ Normal(mean = s, precision = τ_y)
end

We impose a structured mean-field assumption for this model as well:

rx_constraints = @constraints begin
    q(x0, x, w0, w, τ_x, τ_w, τ_y, s) = q(x0, x)q(w, w0)q(τ_w)q(τ_x)q(s)q(τ_y)
end

Online inference in RxInfer supports the @autoupdates specification, which tells the inference procedure how to update the priors based on newly computed posteriors:

autoupdates = @autoupdates begin
    m_x_0, τ_x_0 = mean_precision(q(x))
    m_w_0, τ_w_0 = mean_precision(q(w))
    a_x = shape(q(τ_x))
    b_x = rate(q(τ_x))
    a_y = shape(q(τ_y))
    b_y = rate(q(τ_y))
    a_w = shape(q(τ_w))
    b_w = rate(q(τ_w))
end

As before, we need to define the @meta structure that specifies the approximation method for the nonlinear function smooth_min (f in the model specification):

rx_meta = @meta begin
    smooth_min() -> Linearization()
end

The next step is to generate our dataset and run the actual inference procedure. For that we use the rxinference function, which has an API similar to that of the inference function:

n = 300
rx_real_x, rx_real_w, rx_real_y = generate_data(min, n, seed = 1, x_i_min = 1.0, w_i_min = -1.0, noise = 1.0, real_x_τ = 1.0, real_w_τ = 1.0);

pl = plot(title = "Underlying signals")
pl = plot!(pl, rx_real_x, label = "x")
pl = plot!(pl, rx_real_w, label = "w")

pr = plot(title = "Combined y = min(x, w)")
pr = scatter!(pr, rx_real_y, ms = 3, color = :red, label = "y")

plot(pl, pr, size = (800, 300))

engine = rxinference(
    model = rx_identification(smooth_min),
    constraints = rx_constraints,
    data = (y = rx_real_y,),
    autoupdates = autoupdates,
    meta = rx_meta,
    returnvars = (:x, :w, :τ_x, :τ_w, :τ_y, :s),
    keephistory = 1000,
    historyvars = KeepLast(),
    initmarginals = (w = NormalMeanVariance(-2.0, 1.0), x = NormalMeanVariance(2.0, 1.0), τ_x = GammaShapeRate(1.0, 1.0), τ_w = GammaShapeRate(1.0, 1.0), τ_y = GammaShapeRate(1.0, 20.0)),
    iterations = 10,
    free_energy = true,
    free_energy_diagnostics = nothing,
    autostart = true,
)
RxInferenceEngine:
  Posteriors stream    | enabled for (w, s, τ_x, τ_w, τ_y, x)
  Free Energy stream   | enabled
  Posteriors history   | available for (x, w, x0, s, τ_x, τ_w, τ_y, w0)
  Free Energy history  | available
  Enabled events       | [  ]

rx_smarginals = engine.history[:s]
rx_xmarginals = engine.history[:x]
rx_wmarginals = engine.history[:w];

px1 = plot(legend = :bottomleft, title = "Estimated hidden signals")
px2 = plot(legend = :bottomright, title = "Estimated combined signals")

px1 = plot!(px1, rx_real_x, label = "Real hidden X")
px1 = plot!(px1, mean.(rx_xmarginals), ribbon = var.(rx_xmarginals), label = "Estimated X")

px1 = plot!(px1, rx_real_w, label = "Real hidden W")
px1 = plot!(px1, mean.(rx_wmarginals), ribbon = var.(rx_wmarginals), label = "Estimated W")

px2 = scatter!(px2, rx_real_y, label = "Observations", ms = 2, alpha = 0.5, color = :red)
px2 = plot!(px2, mean.(rx_smarginals), ribbon = std.(rx_smarginals), label = "Combined estimated signal", color = :green)

plot(px1, px2, size = (800, 300))

The results are quite similar to the smoothing case. As we can see, one of the random walks is again in a "disabled" state: it does not infer anything and simply increases its variance, which is expected for a random walk.

## Handling Missing Data

An interesting case in filtering and smoothing problems is the processing of missing data. Reading devices sometimes fail to acquire data, which leads to missing observations.

Let us assume that the following model generates the data

\begin{aligned} {x}_t &\sim \mathcal{N}\left({x}_{t-1}, 1.0\right) \\ {y}_t &\sim \mathcal{N}\left({x}_{t}, P \right) \end{aligned}

with prior ${x}_0 \sim \mathcal{N}(m_{{x}_0}, v_{{x}_0})$. Suppose that our measurement device fails to acquire data from time to time. In this case, instead of a scalar observation $\hat{y}_t \in \mathbb{R}$, we will sometimes receive a missing observation.
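In closed form, a missing observation simply means there is nothing to condition on. The filter performs its predict step and skips the update, so through a gap the mean stays put and the variance grows by the transition variance at each step. The scalar sketch below illustrates this for the model above. `kalman_filter_missing` is a hypothetical helper, and the sketch assumes a known observation variance r = 1. The RxInfer model below instead learns P. The sketch also only filters, whereas the RxInfer results below are smoothed. The seed is arbitrary.

```julia
using Random

# Hypothetical helper: scalar Kalman filter for x_t ~ N(x_{t-1}, q), y_t ~ N(x_t, r),
# where some entries of y are `missing`
function kalman_filter_missing(y, q, r, m0, v0)
    n = length(y)
    m = zeros(n)
    v = zeros(n)
    m_prev, v_prev = m0, v0
    for t in 1:n
        m_pred, v_pred = m_prev, v_prev + q    # predict
        if ismissing(y[t])
            m[t], v[t] = m_pred, v_pred        # no observation: keep the prediction
        else
            k = v_pred / (v_pred + r)          # Kalman gain
            m[t] = m_pred + k * (y[t] - m_pred)
            v[t] = (1 - k) * v_pred
        end
        m_prev, v_prev = m[t], v[t]
    end
    return m, v
end

rng = MersenneTwister(1)
n_demo = 250
signal_demo = sin.(0.05 .* (1:n_demo))
data_demo = Vector{Union{Float64, Missing}}(signal_demo .+ randn(rng, n_demo))
data_demo[100:125] .= missing

m_demo, v_demo = kalman_filter_missing(data_demo, 1.0, 1.0, 0.0, 1000.0)
```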

using RxInfer, Plots
@model function smoothing(n, x0)

P ~ Gamma(shape = 0.001, scale = 0.001)
x_prior ~ Normal(mean = mean(x0), var = var(x0))

x = randomvar(n)
y = datavar(Float64, n) where { allow_missing = true }

x_prev = x_prior

for i in 1:n
x[i] ~ Normal(mean = x_prev, precision = 1.0)
y[i] ~ Normal(mean = x[i], precision = P)

x_prev = x[i]
end

end

We need to manually extend the set of rules to support ::Missing values

P = 1.0
n = 250

real_signal     = map(e -> sin(0.05 * e), collect(1:n))
noisy_data      = real_signal + rand(Normal(0.0, sqrt(P)), n);
missing_indices = 100:125
missing_data    = similar(noisy_data, Union{Float64, Missing}, )

copyto!(missing_data, noisy_data)

for index in missing_indices
missing_data[index] = missing
end
constraints = @constraints begin
q(x_prior, x, y, P) = q(x_prior, x)q(P)q(y)
end
x0_prior = NormalMeanVariance(0.0, 1000.0)

result = inference(
model = smoothing(n, x0_prior),
data  = (y = missing_data,),
constraints = constraints,
initmarginals = (P = Gamma(0.001, 0.001), ),
returnvars = (x = KeepLast(),),
iterations = 20
);
plot(real_signal, label = "Real signal", legend = :bottomright)
scatter!(missing_indices, real_signal[missing_indices], ms = 2, opacity = 0.75, label = "Missing region")
plot!(mean.(result.posteriors[:x]), ribbon = var.(result.posteriors[:x]), label = "Estimated hidden state")