Example: Autoregressive model

In this example we perform automated variational Bayesian inference for an autoregressive model that can be represented as follows:

\[\begin{aligned} p(\gamma) &= \mathrm{Gamma}(\gamma|a, b),\\ p(\mathbf{\theta}) &= \mathcal{N}(\mathbf{\theta}|\mathbf{\mu}, \Sigma),\\ p(x_t|\mathbf{x}_{t-1:t-k}) &= \mathcal{N}(x_t|\mathbf{\theta}^{T}\mathbf{x}_{t-1:t-k}, \gamma^{-1}),\\ p(y_t|x_t) &= \mathcal{N}(y_t|x_t, \tau^{-1}), \end{aligned}\]

where $x_t$ is the current state of the system, $\mathbf{x}_{t-1:t-k}$ is the sequence of $k$ previous states, $k$ is the order of the autoregressive process, $\mathbf{\theta}$ is a vector of transition coefficients, $\gamma$ is the precision of the state transition process, and $y_t$ is a noisy observation of $x_t$ with precision $\tau$.
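Stacking the $k$ lagged states into a vector $\mathbf{x}_t = (x_t, x_{t-1}, \dots, x_{t-k+1})^{T}$ turns the AR($k$) process into an equivalent first-order state-space model; this companion form is what both the data generator and the multivariate model below implement:

\[\begin{aligned} \mathbf{x}_t &= \begin{bmatrix} \theta_1 & \theta_2 & \cdots & \theta_{k-1} & \theta_k \\ 1 & 0 & \cdots & 0 & 0 \\ \vdots & & \ddots & & \vdots \\ 0 & 0 & \cdots & 1 & 0 \end{bmatrix} \mathbf{x}_{t-1} + \mathbf{e}_1 w_t, \qquad w_t \sim \mathcal{N}(0, \gamma^{-1}),\\ y_t &= \mathbf{e}_1^{T}\mathbf{x}_t + v_t, \qquad v_t \sim \mathcal{N}(0, \tau^{-1}), \end{aligned}\]

where $\mathbf{e}_1 = (1, 0, \dots, 0)^{T}$ is the first standard basis vector.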

For a more rigorous introduction to Bayesian inference in autoregressive models we refer the reader to Albert Podusenko et al., "Message Passing-Based Inference for Time-Varying Autoregressive Models".

We start by importing all the packages we need:

using Rocket, ReactiveMP, GraphPPL
using Distributions, LinearAlgebra, Parameters, Random, Plots, BenchmarkTools

Let's generate a synthetic dataset; we use a predefined set of coefficients for $k = 5$:

# The following coefficients correspond to stable poles
coefs_ar_5 = [0.10699399235785655, -0.5237303489793305, 0.3068897071844715, -0.17232255282458891, 0.13323964347539288]
function generate_ar_data(rng, n, θ, γ, τ)
    order        = length(θ)
    states       = Vector{Vector{Float64}}(undef, n + 3order)
    observations = Vector{Float64}(undef, n + 3order)

    γ_std = sqrt(inv(γ))
    τ_std = sqrt(inv(τ))

    states[1] = randn(rng, order)

    for i in 2:(n + 3order)
        states[i]       = vcat(rand(rng, Normal(dot(θ, states[i - 1]), γ_std)), states[i-1][1:end-1])
        observations[i] = rand(rng, Normal(states[i][1], τ_std))
    end

    return states[1+3order:end], observations[1+3order:end]
end
generate_ar_data (generic function with 1 method)
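As a sanity check, we can verify the claim in the comment above: the coefficients correspond to stable poles, i.e. all eigenvalues of the AR companion matrix should lie strictly inside the unit circle. A minimal sketch using only `LinearAlgebra` (the `companion` helper is ours, not part of ReactiveMP):

```julia
using LinearAlgebra

# Companion matrix of an AR(k) process: the first row holds the coefficients,
# the sub-diagonal shifts the lagged states down by one position
function companion(θ)
    k = length(θ)
    A = zeros(k, k)
    A[1, :] .= θ
    for i in 2:k
        A[i, i - 1] = 1.0
    end
    return A
end

coefs_ar_5 = [0.10699399235785655, -0.5237303489793305, 0.3068897071844715, -0.17232255282458891, 0.13323964347539288]

poles = eigvals(companion(coefs_ar_5))
all(abs.(poles) .< 1.0)   # expected to be true: the comment above promises stable poles
```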
# Seed for reproducibility
seed = 123
rng  = MersenneTwister(seed)

# Number of observations in synthetic dataset
n = 500

# AR process parameters
real_γ = 5.0
real_τ = 5.0
real_θ = coefs_ar_5

states, observations = generate_ar_data(rng, n, real_θ, real_γ, real_τ)

Let's plot our synthetic dataset:

plot(first.(states), label = "Hidden states")
scatter!(observations, label = "Observations")

The next step is to specify the probabilistic model and run the inference procedure with ReactiveMP. We use the GraphPPL.jl package to specify the probabilistic model together with additional constraints for variational Bayesian inference. We specify two different models: a multivariate AR model for order $k > 1$, and a univariate AR model (which reduces to a simple state-space model) for order $k = 1$.

@model function lar_model(T::Type, n, order, c, τ)

    # We create a sequence of random variables for hidden states
    x = randomvar(n)
    # As well as a sequence of observations
    y = datavar(Float64, n)

    ct = constvar(c)
    # We assume observation noise to be known
    cτ = constvar(τ)

    γ  = randomvar()
    θ  = randomvar()
    x0 = randomvar()

    # Priors for the precision, the AR coefficients and the initial state
    if T === Multivariate
        γ  ~ GammaShapeRate(1.0, 1.0)
        θ  ~ MvNormalMeanPrecision(zeros(order), diageye(order))
        x0 ~ MvNormalMeanPrecision(zeros(order), diageye(order))
    else
        γ  ~ GammaShapeRate(1.0, 1.0)
        θ  ~ NormalMeanPrecision(0.0, 1.0)
        x0 ~ NormalMeanPrecision(0.0, 1.0)
    end

    x_prev = x0

    for i in 1:n

        x[i] ~ AR(x_prev, θ, γ)

        if T === Multivariate
            y[i] ~ NormalMeanPrecision(dot(ct, x[i]), cτ)
        else
            y[i] ~ NormalMeanPrecision(ct * x[i], cτ)
        end

        x_prev = x[i]
    end
    return x, y, θ, γ
end
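In the multivariate branch, the observation model `y[i] ~ NormalMeanPrecision(dot(ct, x[i]), cτ)` relies on `c` selecting the newest state component. Assuming `ReactiveMP.ar_unit` returns the first standard basis vector (consistent with the univariate branch, where `ct * x[i]` is just $x_t$), the idea can be sketched in plain Julia:

```julia
using LinearAlgebra

# Hypothetical illustration: we assume `ReactiveMP.ar_unit(Multivariate, k)`
# returns the first standard basis vector of length k
order = 5
c  = vcat(1.0, zeros(order - 1))   # [1, 0, 0, 0, 0]
xt = [0.7, -0.2, 0.1, 0.4, -0.3]   # a hypothetical state vector x_{t:t-k+1}

dot(c, xt) == first(xt)            # true: the observation sees only the newest state x_t
```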
constraints = @constraints begin
    q(x0, x, θ, γ) = q(x0, x)q(θ)q(γ)
end
Constraints:
  marginals form:
  messages form:
  factorisation:
    q(x0, x, θ, γ) = q(x0, x)q(θ)q(γ)
Options:
  warn = true
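This factorisation constraint is a structured mean-field approximation: the initial and hidden states keep a joint posterior factor, preserving the temporal dependencies along the chain, while the coefficients and the precision are approximated by independent factors:

\[q(x_0, x_{1:n}, \mathbf{\theta}, \gamma) = q(x_0, x_{1:n})\, q(\mathbf{\theta})\, q(\gamma).\]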
@meta function ar_meta(artype, order, stype)
    AR(x, θ, γ) -> ARMeta(artype, order, stype)
end
ar_meta (generic function with 1 method)
morder  = 5
martype = Multivariate
mc      = ReactiveMP.ar_unit(martype, morder)
mmeta   = ar_meta(martype, morder, ARsafe())

moptions = (limit_stack_depth = 100, )

mmodel         = Model(lar_model, martype, length(observations), morder, mc, real_τ)
mdata          = (y = observations, )
minitmarginals = (γ = GammaShapeRate(1.0, 1.0), θ = MvNormalMeanPrecision(zeros(morder), diageye(morder)))
mreturnvars    = (x = KeepLast(), γ = KeepEach(), θ = KeepEach())

# First execution is slow due to Julia's initial compilation
mresult = inference(
    model = mmodel,
    data  = mdata,
    constraints   = constraints,
    meta          = mmeta,
    options       = moptions,
    initmarginals = minitmarginals,
    returnvars    = mreturnvars,
    free_energy   = true,
    iterations    = 100,
    showprogress  = true
)
Inference results:
-----------------------------------------
Free Energy: Real[559.425, 537.172, 529.743, 526.949, 525.878, 525.325, 524.984, 524.737, 524.549, 524.398  …  523.938, 523.938, 523.938, 523.938, 523.938, 523.938, 523.941, 523.94, 523.94, 523.94]
-----------------------------------------
γ = GammaShapeRate{Float64}[GammaShapeRate{Float64}(a=251.0, b=50.296627506130086), ...
θ = MvNormalWeightedMeanPrecision{Float64, Vector{Float64}, Matrix{Float64}}[MvNorma...
x = MvNormalWeightedMeanPrecision{Float64, Vector{Float64}, Matrix{Float64}}[MvNorma...
@unpack x, γ, θ = mresult.posteriors

fe = mresult.free_energy
100-element Vector{Real}:
 559.4249714142759
 537.1715570174406
 529.7428309410466
 526.9485827792428
 525.8775165043139
 525.3245565147481
 524.9843919374539
 524.7365477686021
 524.5488062997024
 524.3975257289626
   ⋮
 523.9380699757653
 523.9379579262677
 523.937853673126
 523.9377694967447
 523.9376856257172
 523.9406581214121
 523.9403854553439
 523.9401192923065
 523.9398879407577
p1 = plot(first.(states), label="Hidden state")
p1 = scatter!(p1, observations, label="Observations")
p1 = plot!(p1, first.(mean.(x)), ribbon = first.(std.(x)), label="Inferred states", legend = :bottomright)

p2 = plot(mean.(γ), ribbon = std.(γ), label = "Inferred transition precision", legend = :topright)
p2 = plot!([ real_γ ], seriestype = :hline, label = "Real transition precision")

p3 = plot(fe, label = "Bethe Free Energy")

plot(p1, p2, p3, layout = @layout([ a; b c ]))

Let's also plot a subrange of our results:

subrange = div(n,5):(div(n, 5) + div(n, 5))

plot(subrange, first.(states)[subrange], label="Hidden state")
scatter!(subrange, observations[subrange], label="Observations")
plot!(subrange, first.(mean.(x))[subrange], ribbon = sqrt.(first.(var.(x)))[subrange], label="Inferred states", legend = :bottomright)

It is also interesting to see what values our AR coefficients converge to:

let
    pθ = plot()

    θms = mean.(θ)
    θvs = var.(θ)

    l = length(θms)

    edim(e) = (a) -> map(r -> r[e], a)

    for i in 1:length(first(θms))
        pθ = plot!(pθ, θms |> edim(i), ribbon = θvs |> edim(i) .|> sqrt, label = "Estimated θ[$i]")
    end

    for i in 1:length(real_θ)
        pθ = plot!(pθ, [ real_θ[i] ], seriestype = :hline, label = "Real θ[$i]")
    end

    plot(pθ, legend = :outertopright, size = (800, 300))
end
println("$(length(real_θ))-order AR inference Bethe Free Energy: ", last(fe))
5-order AR inference Bethe Free Energy: 523.9398879407577

We can also run 1-order AR inference on the 5-order AR data:

uorder  = 1
uartype = Univariate
uc      = ReactiveMP.ar_unit(uartype, uorder)
umeta   = ar_meta(uartype, uorder, ARsafe())

uoptions = (limit_stack_depth = 100, )

umodel         = Model(lar_model, uartype, length(observations), uorder, uc, real_τ)
udata          = (y = observations, )
uinitmarginals = (γ = GammaShapeRate(1.0, 1.0), θ = NormalMeanPrecision(0.0, 1.0))
ureturnvars    = (x = KeepLast(), γ = KeepEach(), θ = KeepEach())

uresult = inference(
    model = umodel,
    data  = udata,
    meta  = umeta,
    constraints   = constraints,
    initmarginals = uinitmarginals,
    returnvars    = ureturnvars,
    free_energy   = true,
    iterations    = 15,
    showprogress  = false
)
Inference results:
-----------------------------------------
Free Energy: Real[543.153, 537.758, 536.225, 535.696, 535.501, 535.426, 535.397, 535.385, 535.381, 535.379, 535.378, 535.378, 535.378, 535.378, 535.378]
-----------------------------------------
γ = GammaShapeRate{Float64}[GammaShapeRate{Float64}(a=251.0, b=99.00892750388891), G...
θ = NormalWeightedMeanPrecision{Float64}[NormalWeightedMeanPrecision{Float64}(xi=-2....
x = NormalWeightedMeanPrecision{Float64}[NormalWeightedMeanPrecision{Float64}(xi=0.2...
println("1-order AR inference Bethe Free Energy: ", last(uresult.free_energy))
1-order AR inference Bethe Free Energy: 535.378

We can see that, according to the final Bethe Free Energy value, the 5-order AR model describes this dataset better than the 1-order one.

We may also be interested in benchmarking our algorithm:

@benchmark inference(model = $umodel, constraints = $constraints, meta = $umeta, data = $udata, initmarginals = $uinitmarginals, free_energy = true, iterations = 15, showprogress = false)
BenchmarkTools.Trial: 16 samples with 1 evaluation.
 Range (min … max):  266.691 ms … 375.633 ms  ┊ GC (min … max):  0.00% … 16.21%
 Time  (median):     340.973 ms               ┊ GC (median):    13.81%
 Time  (mean ± σ):   331.512 ms ±  34.591 ms  ┊ GC (mean ± σ):  11.42% ±  8.00%

  ▁    ▁    ▁        ▁       ▁  ▁  ▁    ▁     ▁▁█     ▁    ▁  █  
  █▁▁▁▁█▁▁▁▁█▁▁▁▁▁▁▁▁█▁▁▁▁▁▁▁█▁▁█▁▁█▁▁▁▁█▁▁▁▁▁███▁▁▁▁▁█▁▁▁▁█▁▁█ ▁
  267 ms           Histogram: frequency by time          376 ms <

 Memory estimate: 86.02 MiB, allocs estimate: 1681202.
@benchmark inference(model = $mmodel, constraints = $constraints, meta = $mmeta, data = $mdata, initmarginals = $minitmarginals, free_energy = true, iterations = 15, showprogress = false)
BenchmarkTools.Trial: 7 samples with 1 evaluation.
 Range (min … max):  791.184 ms … 835.933 ms  ┊ GC (min … max): 16.32% … 17.48%
 Time  (median):     809.171 ms               ┊ GC (median):    15.63%
 Time  (mean ± σ):   813.371 ms ±  17.715 ms  ┊ GC (mean ± σ):  15.43% ±  1.86%

  █     █          █      █                         █ █       █  
  █▁▁▁▁▁█▁▁▁▁▁▁▁▁▁▁█▁▁▁▁▁▁█▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁█▁█▁▁▁▁▁▁▁█ ▁
  791 ms           Histogram: frequency by time          836 ms <

 Memory estimate: 291.41 MiB, allocs estimate: 2434230.