Example: Gamma Mixture

This example implements one of the experiments outlined in https://biaslab.github.io/publication/mp-based-inference-in-gmm/.

To model this process in ReactiveMP, we first import all the required packages:

using Rocket, ReactiveMP, GraphPPL
using Distributions, Random, StableRNGs
using StatsPlots
# create a custom structure that bundles the model parameters for convenience
struct GammaMixtureModelParameters
    nmixtures   # number of mixture components
    priors_as   # collection of priors for the shape parameters a
    priors_bs   # collection of priors for the rate parameters b
    prior_s     # prior for the selection variable s
end

Model specification

@model function gamma_mixture_model(nobservations, parameters::GammaMixtureModelParameters)

    # fetch information from struct
    nmixtures = parameters.nmixtures
    priors_as = parameters.priors_as
    priors_bs = parameters.priors_bs
    prior_s   = parameters.prior_s

    # set prior on global selection variable
    s ~ Dirichlet(probvec(prior_s))

    # allocate vectors of random variables
    as = randomvar(nmixtures)
    bs = randomvar(nmixtures)

    # set priors on variables of mixtures
    for i in 1:nmixtures
        as[i] ~ GammaShapeRate(shape(priors_as[i]), rate(priors_as[i]))
        bs[i] ~ GammaShapeRate(shape(priors_bs[i]), rate(priors_bs[i]))
    end

    # introduce random variables for local selection variables and data
    z = randomvar(nobservations)
    y = datavar(Float64, nobservations)

    # convert the vectors to tuples, as required by the GammaMixture node
    tas = tuple(as...)
    tbs = tuple(bs...)

    # specify local selection variable and data generating process
    for i in 1:nobservations
        z[i] ~ Categorical(s)
        y[i] ~ GammaMixture(z[i], tas, tbs)
    end

    # return random variables
    return s, as, bs, z, y

end
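
For reference, the generative process encoded by this model is (shape-rate parameterisation throughout):

s        ~ Dirichlet(prior_s)
a_k, b_k ~ Gamma(shape, rate)         for k = 1, ..., nmixtures
z_i      ~ Categorical(s)             for i = 1, ..., nobservations
y_i      ~ Gamma(a_{z_i}, b_{z_i})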

The shape parameters as have no conjugate update in this model, so their posterior marginals are constrained to point masses; the starting_point argument supplies the initial value for the optimisation that locates each point mass:

constraints = @constraints begin
    q(as) :: PointMass(starting_point = (args...) -> [ 1.0 ])
end
Constraints:
  marginals form:
    q(as) :: PointMassFormConstraint() [ prod_constraint = ProdGeneric(fallback = ProdAnalytical()) ]
  messages form:
  factorisation:
Options:
  warn = true

Generate a test dataset for verification

# specify seed and number of data points
rng = StableRNG(43)
n_samples = 2500

# specify parameters of mixture model that generates the data
# Note that the mixture components have exactly the same means
mixtures  = [ Gamma(9.0, inv(27.0)), Gamma(90.0, inv(270.0)) ]
nmixtures = length(mixtures)
mixing    = rand(rng, nmixtures)
mixing    = mixing ./ sum(mixing)
mixture   = MixtureModel(mixtures, mixing)

# generate data set
dataset = rand(rng, mixture, n_samples)
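
As a quick sanity check (illustrative, not part of the original experiment), the empirical moments of the sample can be compared against the theoretical moments of the generating mixture; mean and var for a MixtureModel are provided by Distributions.jl:

# sanity check: empirical moments of the sample should be close
# to the theoretical moments of the generating mixture
println("mean: theoretical = $(mean(mixture)), empirical = $(mean(dataset))")
println("var:  theoretical = $(var(mixture)),  empirical = $(var(dataset))")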

Inference

# specify priors of probabilistic model
# NOTE: As the means of the mixtures "collide", we specify an informative prior for the selector variable
nmixtures = 2
gpriors = GammaMixtureModelParameters(
    nmixtures,                                                    # number of mixtures
    [ Gamma(1.0, 0.1), Gamma(1.0, 1.0) ],                         # priors on variables a
    [ GammaShapeRate(10.0, 2.0), GammaShapeRate(1.0, 3.0) ],      # priors on variables b
    Dirichlet(1e3*mixing)                                         # prior on variable s
)

gmodel         = Model(gamma_mixture_model, length(dataset), gpriors)
gdata          = (y = dataset, )
# initial marginals are required to bootstrap the mean-field iterations
ginitmarginals = (s = gpriors.prior_s, z = vague(Categorical, gpriors.nmixtures), bs = GammaShapeRate(1.0, 1.0))
greturnvars    = (s = KeepLast(), z = KeepLast(), as = KeepEach(), bs = KeepEach())

goptions = (
    limit_stack_depth     = 100,
    default_factorisation = MeanField() # mixture models currently require the mean-field assumption
)

gresult = inference(
    model       = gmodel,
    data        = gdata,
    constraints = constraints,
    options     = goptions,
    initmarginals = ginitmarginals,
    returnvars    = greturnvars,
    free_energy   = true,
    iterations    = 250,
    showprogress  = true
)
Inference results:
-----------------------------------------
Free Energy: Real[757.147, -62.2868, -544.017, -866.001, -1094.6, -1264.11, -1393.83, -1495.54, -1576.83, -1642.81  …  -2035.5, -2035.5, -2035.5, -2035.5, -2035.5, -2035.5, -2035.5, -2035.5, -2035.5, -2035.5]
-----------------------------------------
as = Vector{PointMass{Float64}}[[PointMass{Float64}(0.5509517542657298), PointMass{Fl...
s  = Dirichlet{Float64, Vector{Float64}, Float64}(alpha=[2893.4287750101134, 606.5712...
bs = Vector{GammaShapeRate{Float64}}[[GammaShapeRate{Float64}(a=687.4706291245399, b=...
z  = Distributions.Categorical{Float64, Vector{Float64}}[Distributions.Categorical{Fl...
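
Since free_energy = true was passed, the result also carries the Bethe free energy trace (plotted below). A quick programmatic convergence check (illustrative) is to inspect the change over the final iterations:

# illustrative convergence check on the Bethe free energy trace
fe = gresult.free_energy
println("final BFE: $(last(fe))")
println("change over last 10 iterations: $(abs(fe[end] - fe[end - 10]))")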

Verification

# extract inferred parameters
_as, _bs = mean.(gresult.posteriors[:as][end]), mean.(gresult.posteriors[:bs][end])
# convert from shape-rate (ReactiveMP) to the shape-scale parameterisation used by Distributions.jl
_dists   = map(g -> Gamma(g[1], inv(g[2])), zip(_as, _bs))
_mixing  = mean(gresult.posteriors[:s])

# create model from inferred parameters
_mixture   = MixtureModel(_dists, _mixing);

# report on outcome of inference
println("Generated means: $(mean(mixtures[1])) and $(mean(mixtures[2]))")
println("Inferred means: $(mean(_dists[1])) and $(mean(_dists[2]))")
println("========")
println("Generated mixing: $(mixing)")
println("Inferred mixing: $(_mixing)")
Generated means: 0.3333333333333333 and 0.33333333333333337
Inferred means: 0.3314217819856029 and 0.35365172726821925
========
Generated mixing: [0.799908395356677, 0.20009160464332298]
Inferred mixing: [0.8266939357171744, 0.17330606428282555]
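
As an additional whole-distribution comparison (illustrative, not part of the original experiment), one can check the average log-likelihood that the generating and the inferred mixtures assign to the data; loglikelihood is provided by Distributions.jl:

# illustrative check: both mixtures should assign similar average
# log-likelihood to the data if inference recovered the distribution
println("avg. loglikelihood (generating): $(loglikelihood(mixture, dataset) / n_samples)")
println("avg. loglikelihood (inferred):   $(loglikelihood(_mixture, dataset) / n_samples)")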

Results

# plot results
p1 = histogram(dataset, ylim = (0, 13), xlim = (0, 1), normalize=:pdf, label="observations")
p1 = plot!(mixture, label=false, title="Generated mixtures")

p2 = histogram(dataset, ylim = (0, 13), xlim = (0, 1), normalize=:pdf, label="data", opacity=0.3)
p2 = plot!(_mixture, label=false, title="Inferred mixtures", linewidth=3.0)

# evaluate the convergence of the algorithm by monitoring the Bethe free energy (BFE)
p3 = plot(gresult.free_energy, label=false, xlabel="iterations", title="Bethe FE")

plot(p1, p2, layout = @layout([ a; b ]))
plot(p3)
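
Note that when this example is run as a plain script, only the last plot call is displayed automatically. If a single figure is preferred, the three panels can be combined (a cosmetic variation):

# combine all three panels into a single figure
plot(p1, p2, p3, layout = @layout([ a; b; c ]), size = (600, 800))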