Example: Hidden Markov Model

In this demo we are interested in Bayesian inference of the parameters of a hidden Markov model (HMM). Specifically, we consider a first-order HMM with hidden states $s_0, s_1, \dots, s_T$ and observations $x_1, \dots, x_T$ governed by a state transition probability matrix $A$ and an observation probability matrix $B$:

\[\begin{align*} s_t & \sim \mathcal{C}at(A s_{t-1}),\\ x_t & \sim \mathcal{C}at(B s_t). \end{align*}\]

We assume three possible states ("red", "green" and "blue"), and the goal is to estimate matrices $A$ and $B$ from a simulated data set. To have a full Bayesian treatment of the problem, both $A$ and $B$ are endowed with priors (Dirichlet distributions on the columns).

using Rocket, ReactiveMP, GraphPPL
using Random, BenchmarkTools, Distributions, LinearAlgebra
using Plots
# Draw one category from `distribution` using `rng` and return it one-hot encoded:
# a `Float64` vector of length `ncategories(distribution)` whose only non-zero
# entry (1.0) sits at the index of the sampled category.
function rand_vec(rng, distribution::Categorical)
    onehot = zeros(ncategories(distribution))
    drawn = rand(rng, distribution)
    onehot[drawn] = 1.0
    return onehot
end

# Simulate `n_samples` steps of a three-state hidden Markov model.
#
# A `MersenneTwister` seeded with `seed` drives all sampling, so the same
# arguments always give the same data. Returns `(x, s)`: the one-hot encoded
# observations and the one-hot encoded hidden states, each a vector of
# `n_samples` `Vector{Float64}`s.
function generate_data(n_samples; seed = 124)
    rng = MersenneTwister(seed)

    # Transition probabilities (some transitions are impossible)
    A = [
        0.9 0.0 0.1
        0.1 0.9 0.0
        0.0 0.1 0.9
    ]
    # Observation noise
    B = [
        0.9  0.05 0.05
        0.05 0.9  0.05
        0.05 0.05 0.9
    ]

    # One-hot encoded hidden states and observations, filled in step by step
    s = Vector{Vector{Float64}}(undef, n_samples)
    x = Vector{Vector{Float64}}(undef, n_samples)

    # Initial state: the chain starts in the first state
    current = [1.0, 0.0, 0.0]

    for t in 1:n_samples
        # Sample the next hidden state from the (re-normalised) column of A
        state_weights = A * current
        current = rand_vec(rng, Categorical(state_weights ./ sum(state_weights)))
        s[t] = current

        # Sample the observation emitted by that state from the column of B
        obs_weights = B * current
        x[t] = rand_vec(rng, Categorical(obs_weights ./ sum(obs_weights)))
    end

    return x, s
end
generate_data (generic function with 1 method)
# Test data
# Number of time steps to simulate
N = 100

# generate_data returns the observations first and the hidden states second
x_data, s_data = generate_data(N)

# Plot, per time step, the index of the active entry of each one-hot state vector
scatter(argmax.(s_data))

Model specification

# Model specification
# Hidden Markov model over `n` time steps with 3x3 matrices A (state transitions)
# and B (observations), both treated as random variables.
@model function hidden_markov_model(n)

    # Priors: all-ones parameters for A; for B the diagonal parameters (10.0)
    # are larger than the off-diagonal ones (1.0)
    A ~ MatrixDirichlet(ones(3, 3))
    B ~ MatrixDirichlet([ 10.0 1.0 1.0; 1.0 10.0 1.0; 1.0 1.0 10.0 ])

    # Uniform prior over the three possible initial states
    s_0 ~ Categorical(fill(1.0 / 3.0, 3))

    # One latent state variable and one observed vector per time step
    s = randomvar(n)
    x = datavar(Vector{Float64}, n)

    s_prev = s_0

    for t in 1:n
        # State t depends on the previous state through A;
        # observation t depends on state t through B
        s[t] ~ Transition(s_prev, A)
        x[t] ~ Transition(s[t], B)
        s_prev = s[t]
    end

end

# Factorisation of the variational posterior: the states (s_0 and s) share one
# joint factor, while A and B each get a separate factor.
@constraints function hidden_markov_model_constraints()
    q(s_0, s, A, B) = q(s_0, s)q(A)q(B)
end
hidden_markov_model_constraints (generic function with 1 method)

Inference

# Observations, keyed by the name of the data variable `x` in the model
idata = (x = x_data, )

# Model specification together with its argument (number of time steps)
imodel = Model(hidden_markov_model, N)

# Starting marginals for A, B and the states, passed below as `initmarginals`
imarginals = (
    A = vague(MatrixDirichlet, 3, 3),
    B = vague(MatrixDirichlet, 3, 3),
    s = vague(Categorical, 3)
)

# Which posteriors to return; NOTE(review): KeepLast presumably keeps only the
# marginals of the final iteration - confirm against the ReactiveMP docs
ireturnvars = (
    A = KeepLast(),
    B = KeepLast(),
    s = KeepLast()
)

# Run 20 iterations with the factorisation constraints defined above and
# request the free energy values as part of the result
result = inference(
    model         = imodel,
    data          = idata,
    constraints   = hidden_markov_model_constraints(),
    initmarginals = imarginals,
    returnvars    = ireturnvars,
    iterations    = 20,
    free_energy   = true
)
Inference results:
-----------------------------------------
Free Energy: Real[132.198, 125.35, 120.523, 113.37, 99.2503, 84.8829, 80.6348, 79.6174, 79.1632, 78.9335, 78.8291, 78.7876, 78.7729, 78.768, 78.7666, 78.7661, 78.766, 78.766, 78.766, 78.766]
-----------------------------------------
A = MatrixDirichlet{Float64, Matrix{Float64}}(
a: [50.77712567849792 1.5084854704866...
s = Categorical{Float64, Vector{Float64}}[Categorical{Float64, Vector{Float64}}(supp...
B = MatrixDirichlet{Float64, Matrix{Float64}}(
a: [59.78624669132068 1.1776901428905...

Results

# Posterior mean of the state transition matrix A
mean(result.posteriors[:A])
3×3 Matrix{Float64}:
 0.892911   0.0830398  0.133822
 0.0858821  0.691112   0.042798
 0.0212073  0.225849   0.82338
# Posterior mean of the observation matrix B
mean(result.posteriors[:B])
3×3 Matrix{Float64}:
 0.908176   0.0422229  0.0718142
 0.0694807  0.882211   0.0430307
 0.0223437  0.0755656  0.885155
# Top panel: simulated state indices (large markers) overlaid with the most
# probable state of each posterior marginal (small markers)
p1 = scatter(argmax.(s_data), title="Inference results", label = "real", ms = 6)
p1 = scatter!(p1, argmax.(ReactiveMP.probvec.(result.posteriors[:s])), label = "inferred", ms = 2)
# Bottom panel: free energy values returned by `inference`
p2 = plot(result.free_energy, label="Free energy")

# Stack the two panels vertically
plot(p1, p2, layout = @layout([ a; b ]))

Custom inference

# Same hidden Markov model as above, but with the factorisation declared inside
# the model: MeanField() is the default for every node, and the state
# transition node overrides it with an explicit `where { q = ... }` clause.
@model [ default_factorisation = MeanField() ] function custom_optimised_hidden_markov_model(n)

    # Priors on the transition matrix A and the observation matrix B
    A ~ MatrixDirichlet(ones(3, 3))
    B ~ MatrixDirichlet([ 10.0 1.0 1.0; 1.0 10.0 1.0; 1.0 1.0 10.0 ])

    # Uniform prior over the three possible initial states
    s_0 ~ Categorical(fill(1.0 / 3.0, 3))

    # One latent state variable and one observed vector per time step
    s = randomvar(n)
    x = datavar(Vector{Float64}, n)

    s_prev = s_0

    for t in 1:n
        # Keep consecutive states (out, in) in one joint factor, separate from the matrix (a)
        s[t] ~ Transition(s_prev, A) where { q = q(out, in)q(a) }
        # The observation node uses the default factorisation
        x[t] ~ Transition(s[t], B)
        s_prev = s[t]
    end

    # Hand the variables back so the caller can subscribe to them and feed data
    return s, x, A, B
end
# Run the inference loop for `custom_optimised_hidden_markov_model` by hand.
#
# Arguments:
# - `data`: collection of observations; its length sets the number of time
#   steps of the model and it is passed to `update!(x, data)` on every iteration.
# - `vmp_iters`: number of iterations, i.e. how many times the data is fed in.
#
# Returns a 4-tuple with the values collected by the state, A, B and
# free-energy actors (`getvalues` applied to each, in that order).
function custom_optimised_inference(data, vmp_iters)
    n = length(data)

    model, (s, x, A, B) = custom_optimised_hidden_markov_model(model_options(limit_stack_depth = 500), n)

    # Actors that record the marginal updates and the free energy values
    sbuffer = keep(Vector{Marginal})
    Abuffer = keep(Marginal)
    Bbuffer = keep(Marginal)
    fe      = ScoreActor(Float64)

    ssub  = subscribe!(getmarginals(s), sbuffer)
    Asub  = subscribe!(getmarginal(A), Abuffer)
    Bsub  = subscribe!(getmarginal(B), Bbuffer)
    fesub = subscribe!(score(Float64, BetheFreeEnergy(), model), fe)

    try
        # Initial marginals must be set before the first data update
        setmarginal!(A, vague(MatrixDirichlet, 3, 3))
        setmarginal!(B, vague(MatrixDirichlet, 3, 3))

        foreach(s) do svar
            setmarginal!(svar, vague(Categorical, 3))
        end

        # Each pass over the same data performs one iteration
        for _ in 1:vmp_iters
            update!(x, data)
        end
    finally
        # Always release the subscriptions, even if an update throws
        unsubscribe!(ssub)
        unsubscribe!(Asub)
        unsubscribe!(Bsub)
        unsubscribe!(fesub)
    end

    return map(getvalues, (sbuffer, Abuffer, Bbuffer, fe))
end
custom_optimised_inference (generic function with 1 method)
sbuffer, Abuffer, Bbuffer, fe = custom_optimised_inference(x_data, 20)

# Sanity check: the hand-written loop must reproduce the state posteriors
# obtained from the `inference` call
@assert mean.(last(sbuffer)) ≈ mean.(result.posteriors[:s])
# Top panel: simulated state indices overlaid with the most probable state of
# each marginal from the last iteration of the custom run
p1 = scatter(argmax.(s_data), title="Inference results", label = "real", ms = 6)
p1 = scatter!(p1, argmax.(ReactiveMP.probvec.(last(sbuffer))), label = "inferred", ms = 2)
# Bottom panel: free energy recorded by the custom run itself (`fe`), not the
# values from the earlier `inference` result
p2 = plot(fe, label="Free energy")

plot(p1, p2, layout = @layout([ a; b ]))

Benchmark timings

# Benchmark the `inference` call with the same settings as above; the `$`
# prefix interpolates the global values into the benchmarked expression
@benchmark inference(
    model         = $imodel,
    data          = $idata,
    constraints   = hidden_markov_model_constraints(),
    initmarginals = $imarginals,
    returnvars    = $ireturnvars,
    iterations    = 20,
    free_energy   = true
)
BenchmarkTools.Trial: 77 samples with 1 evaluation.
 Range (min … max):  53.801 ms … 124.627 ms  ┊ GC (min … max): 0.00% … 47.37%
 Time  (median):     58.137 ms               ┊ GC (median):    0.00%
 Time  (mean ± σ):   65.653 ms ±  18.536 ms  ┊ GC (mean ± σ):  9.34% ± 14.26%

  ▄█▇▇▂▅                                                        
  ██████▆▅█▅▁▃▆▃▁▁▁▁▁▅▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▃▃▁▃▁▁▅▃▆ ▁
  53.8 ms         Histogram: frequency by time          114 ms <

 Memory estimate: 23.52 MiB, allocs estimate: 398675.
# Benchmark the hand-written inference loop with the same data and 20 iterations
@benchmark custom_optimised_inference($x_data, 20)
BenchmarkTools.Trial: 68 samples with 1 evaluation.
 Range (min … max):  55.950 ms … 168.484 ms  ┊ GC (min … max):  0.00% … 38.05%
 Time  (median):     61.984 ms               ┊ GC (median):     0.00%
 Time  (mean ± σ):   74.093 ms ±  26.332 ms  ┊ GC (mean ± σ):  10.35% ± 15.04%

  ▁█▂                                                           
  ███▅▇▄▆▄▁▄▁▁▃▃▃▁▃▁▁▁▃▁▁▁▁▁▃▃▁▃▁▁▁▃▁▄▁▃▁▁▁▃▁▃▁▁▁▁▁▁▃▁▁▁▁▁▁▁▃▃ ▁
  56 ms           Histogram: frequency by time          151 ms <

 Memory estimate: 22.74 MiB, allocs estimate: 374792.