Example: Hidden Markov Model
In this demo we are interested in Bayesian inference of the parameters of a hidden Markov model (HMM). Specifically, we consider a first-order HMM with hidden states $s_0, s_1, \dots, s_T$ and observations $x_1, \dots, x_T$ governed by a state transition probability matrix $A$ and an observation probability matrix $B$:
\[\begin{align*} s_t & \sim \mathcal{C}at(A s_{t-1}),\\ x_t & \sim \mathcal{C}at(B s_t). \end{align*}\]
We assume three possible states ("red", "green" and "blue"), and the goal is to estimate the matrices $A$ and $B$ from a simulated data set. To have a full Bayesian treatment of the problem, both $A$ and $B$ are endowed with priors (Dirichlet distributions on the columns).
using Rocket, ReactiveMP, GraphPPL
using Random, BenchmarkTools, Distributions, LinearAlgebra
using Plots
# Draw one category from `distribution` using `rng` and return it one-hot encoded:
# a `Vector{Float64}` with one entry per category, holding 1.0 at the sampled
# index and 0.0 everywhere else.
function rand_vec(rng, distribution::Categorical)
    sampled = rand(rng, distribution)
    return [i == sampled ? 1.0 : 0.0 for i in 1:ncategories(distribution)]
end
# Simulate `n_samples` steps of a three-state hidden Markov model.
#
# Returns the tuple `(x, s)`: `x[t]` is the one-hot encoded observation and `s[t]`
# the one-hot encoded hidden state at step `t`. The `seed` keyword seeds the
# `MersenneTwister` used for every draw, so equal seeds give equal data.
function generate_data(n_samples; seed = 124)
    rng = MersenneTwister(seed)

    # Transition probabilities (some transitions are impossible); column j holds
    # the distribution of the next state given that the current state is j.
    transition = [0.9 0.0 0.1; 0.1 0.9 0.0; 0.0 0.1 0.9]

    # Observation noise: the matching category is emitted with probability 0.9.
    emission = [0.9 0.05 0.05; 0.05 0.9 0.05; 0.05 0.05 0.9]

    states = Vector{Vector{Float64}}(undef, n_samples)        # one-hot hidden states
    observations = Vector{Vector{Float64}}(undef, n_samples)  # one-hot observations

    # The chain starts from the first state.
    previous = [1.0, 0.0, 0.0]
    for t in 1:n_samples
        # Multiplying by a one-hot vector selects a column; renormalise before sampling.
        state_probs = transition * previous
        states[t] = rand_vec(rng, Categorical(state_probs ./ sum(state_probs)))

        obs_probs = emission * states[t]
        observations[t] = rand_vec(rng, Categorical(obs_probs ./ sum(obs_probs)))

        previous = states[t]
    end

    return observations, states
end
generate_data (generic function with 1 method)
# Test data: simulate N time steps of one-hot observations and hidden states
N = 100
x_data, s_data = generate_data(N)
# Show the index of the active hidden state at every time step
scatter(map(argmax, s_data))
Model specification
# Model specification
# Hidden Markov model over `n` time steps; every prior below is defined for
# three categories.
@model function hidden_markov_model(n)
# Prior on the transition matrix: MatrixDirichlet with all-ones parameters
A ~ MatrixDirichlet(ones(3, 3))
# Prior on the observation matrix: parameters are 10.0 on the diagonal and 1.0 elsewhere
B ~ MatrixDirichlet([ 10.0 1.0 1.0; 1.0 10.0 1.0; 1.0 1.0 10.0 ])
# Initial state: probability 1/3 for each of the three categories
s_0 ~ Categorical(fill(1.0 / 3.0, 3))
# `n` latent state variables and `n` data variables of type Vector{Float64}
s = randomvar(n)
x = datavar(Vector{Float64}, n)
# Chain the time steps: s[t] is linked to the previous state through A,
# and x[t] is linked to s[t] through B
s_prev = s_0
for t in 1:n
s[t] ~ Transition(s_prev, A)
x[t] ~ Transition(s[t], B)
s_prev = s[t]
end
end
# Factorisation constraints for the variables s_0, s, A and B: the states
# (s_0, s) stay together in one joint factor, while A and B each get a
# separate factor.
@constraints function hidden_markov_model_constraints()
q(s_0, s, A, B) = q(s_0, s)q(A)q(B)
end
hidden_markov_model_constraints (generic function with 1 method)
Inference
# The simulated observations are passed under the name `x`
idata = (; x = x_data)
imodel = Model(hidden_markov_model, N)

# Result-keeping strategy for each returned variable
# (NOTE(review): KeepLast presumably keeps only the final iteration; confirm in ReactiveMP docs)
ireturnvars = (A = KeepLast(), B = KeepLast(), s = KeepLast())

# Vague initial marginals for A, B and the states
imarginals = (
    A = vague(MatrixDirichlet, 3, 3),
    B = vague(MatrixDirichlet, 3, 3),
    s = vague(Categorical, 3),
)

# Run 20 iterations and request the free energy alongside the posteriors
result = inference(;
    model = imodel,
    data = idata,
    constraints = hidden_markov_model_constraints(),
    initmarginals = imarginals,
    returnvars = ireturnvars,
    iterations = 20,
    free_energy = true,
)
Inference results:
-----------------------------------------
Free Energy: Real[132.198, 125.35, 120.523, 113.37, 99.2503, 84.8829, 80.6348, 79.6174, 79.1632, 78.9335, 78.8291, 78.7876, 78.7729, 78.768, 78.7666, 78.7661, 78.766, 78.766, 78.766, 78.766]
-----------------------------------------
A = MatrixDirichlet{Float64, Matrix{Float64}}(
a: [50.77712567849792 1.5084854704866...
s = Distributions.Categorical{Float64, Vector{Float64}}[Distributions.Categorical{Fl...
B = MatrixDirichlet{Float64, Matrix{Float64}}(
a: [59.78624669132068 1.1776901428905...
Results
# Posterior mean of the transition matrix A
mean(result.posteriors[:A])
3×3 Matrix{Float64}:
0.892911 0.0830398 0.133822
0.0858821 0.691112 0.042798
0.0212073 0.225849 0.82338
# Posterior mean of the observation matrix B
mean(result.posteriors[:B])
3×3 Matrix{Float64}:
0.908176 0.0422229 0.0718142
0.0694807 0.882211 0.0430307
0.0223437 0.0755656 0.885155
# Top panel: true hidden-state indices with the inferred ones overlaid on top
p1 = scatter(map(argmax, s_data); title = "Inference results", label = "real", ms = 6)
p1 = scatter!(
    p1,
    map(argmax ∘ ReactiveMP.probvec, result.posteriors[:s]);
    label = "inferred",
    ms = 2,
)
# Bottom panel: free energy values reported by the `inference` call
p2 = plot(result.free_energy; label = "Free energy")
plot(p1, p2; layout = @layout([a; b]))
Custom inference
# Hidden Markov model over `n` time steps that carries its factorisation in the
# model definition itself (MeanField by default) and returns its variables
# `s, x, A, B` to the caller.
@model [ default_factorisation = MeanField() ] function custom_optimised_hidden_markov_model(n)
# Prior on the transition matrix: MatrixDirichlet with all-ones parameters
A ~ MatrixDirichlet(ones(3, 3))
# Prior on the observation matrix: parameters are 10.0 on the diagonal and 1.0 elsewhere
B ~ MatrixDirichlet([ 10.0 1.0 1.0; 1.0 10.0 1.0; 1.0 1.0 10.0 ])
# Initial state: probability 1/3 for each of the three categories
s_0 ~ Categorical(fill(1.0 / 3.0, 3))
# `n` latent state variables and `n` data variables of type Vector{Float64}
s = randomvar(n)
x = datavar(Vector{Float64}, n)
s_prev = s_0
for t in 1:n
# Override the default factorisation for this node: `out` and `in` share one
# joint factor, `a` gets its own
s[t] ~ Transition(s_prev, A) where { q = q(out, in)q(a) }
# No `where` clause here, so the default factorisation applies
x[t] ~ Transition(s[t], B)
s_prev = s[t]
end
return s, x, A, B
end
# Hand-written inference loop for `custom_optimised_hidden_markov_model`.
#
# Builds a model with one time step per element of `data`, subscribes collectors
# to the marginals of the states, A and B and to the Bethe free energy score,
# sets vague initial marginals, and then pushes `data` into the model
# `vmp_iters` times. Returns the collected values as the tuple
# `(state marginals, A marginals, B marginals, free energy)`.
function custom_optimised_inference(data, vmp_iters)
    model, (s, x, A, B) = custom_optimised_hidden_markov_model(
        model_options(limit_stack_depth = 500), length(data)
    )

    # Collectors for everything we want to observe during the run
    state_actor = keep(Vector{Marginal})
    A_actor = keep(Marginal)
    B_actor = keep(Marginal)
    fe_actor = ScoreActor(Float64)

    # Tuple elements are evaluated left to right, so subscriptions are created
    # in the order: states, A, B, free energy
    subscriptions = (
        subscribe!(getmarginals(s), state_actor),
        subscribe!(getmarginal(A), A_actor),
        subscribe!(getmarginal(B), B_actor),
        subscribe!(score(Float64, BetheFreeEnergy(), model), fe_actor),
    )

    # Vague initial marginals for A, B and every state variable
    setmarginal!(A, vague(MatrixDirichlet, 3, 3))
    setmarginal!(B, vague(MatrixDirichlet, 3, 3))
    foreach(state -> setmarginal!(state, vague(Categorical, 3)), s)

    # Feed the same observations once per requested iteration
    for _ in 1:vmp_iters
        update!(x, data)
    end

    foreach(unsubscribe!, subscriptions)

    return (
        getvalues(state_actor),
        getvalues(A_actor),
        getvalues(B_actor),
        getvalues(fe_actor),
    )
end
custom_optimised_inference (generic function with 1 method)
# Run the hand-written inference loop on the same data for 20 iterations
sbuffer, Abuffer, Bbuffer, fe = custom_optimised_inference(x_data, 20)
# Sanity check: the final state posteriors must agree with those of `inference`
@assert mean.(last(sbuffer)) ≈ mean.(result.posteriors[:s])
# Top panel: true hidden-state indices with the custom run's final estimates overlaid
p1 = scatter(argmax.(s_data), title="Inference results", label = "real", ms = 6)
p1 = scatter!(p1, argmax.(ReactiveMP.probvec.(last(sbuffer))), label = "inferred", ms = 2)
# Bottom panel: free energy collected by the custom run itself (`fe`), so this
# figure reflects the custom procedure rather than the earlier `inference` call
p2 = plot(fe, label="Free energy")
plot(p1, p2, layout = @layout([ a; b ]))
Benchmark timings
# Benchmark the generic `inference` call. The `$`-prefixed globals are
# interpolated into the benchmark expression; the constraints call is not
# interpolated, so it is evaluated as part of the benchmarked expression.
@benchmark inference(
model = $imodel,
data = $idata,
constraints = hidden_markov_model_constraints(),
initmarginals = $imarginals,
returnvars = $ireturnvars,
iterations = 20,
free_energy = true
)
BenchmarkTools.Trial: 79 samples with 1 evaluation.
Range (min … max): 52.128 ms … 112.747 ms ┊ GC (min … max): 0.00% … 44.84%
Time (median): 57.208 ms ┊ GC (median): 0.00%
Time (mean ± σ): 64.325 ms ± 18.648 ms ┊ GC (mean ± σ): 10.98% ± 15.61%
▂▇█▄ ▃▃▂
████▇███▅▇▃▁▁▃▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▃▁▃▁▁▁▇▃▃▅▃ ▁
52.1 ms Histogram: frequency by time 111 ms <
Memory estimate: 25.25 MiB, allocs estimate: 414972.
# Benchmark the hand-written inference loop on the same data with the same 20 iterations
@benchmark custom_optimised_inference($x_data, 20)
BenchmarkTools.Trial: 83 samples with 1 evaluation.
Range (min … max): 51.265 ms … 106.978 ms ┊ GC (min … max): 0.00% … 42.32%
Time (median): 54.392 ms ┊ GC (median): 0.00%
Time (mean ± σ): 61.385 ms ± 17.403 ms ┊ GC (mean ± σ): 10.50% ± 15.32%
█▂
██▄▆▇▆▅▄▅▃▁▁▁▃▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▃▄▃▄▃▃ ▁
51.3 ms Histogram: frequency by time 105 ms <
Memory estimate: 24.02 MiB, allocs estimate: 388135.