Inference execution

This section explains how to use the ReactiveMP reactive API to run inference on probabilistic models that were created with the GraphPPL package, as explained in the Model Specification section.

The ReactiveMP inference API supports different types of message-passing algorithms (including hybrid algorithms combining several different types):

  • Belief Propagation
  • Variational Message Passing (VMP)

Whereas belief propagation computes exact inference for the random variables of interest, variational message passing (VMP) is an approximation method that can be applied to a larger range of models.

The ReactiveMP engine itself is not aware of different algorithm types and simply passes messages between nodes. However, during the model specification stage the user may specify different factorisation constraints around factor nodes by using the where { q = ... } syntax or with the help of the @constraints macro. Different factorisation constraints lead to different message passing update rules.

Inference with ReactiveMP usually consists of the same simple building blocks and is designed in such a way as to support both static datasets and real-time infinite data streams:

  1. Create a model with the @model macro and obtain references to random variables and data inputs
  2. Subscribe to random variable posterior marginal updates
  3. Subscribe to Bethe Free Energy updates (optional)
  4. Feed model with observations
  5. Unsubscribe from posterior marginal updates (optional)

It is worth noting that Step 5 is optional: in cases where observations come from an infinite real-time data stream (e.g. from the internet) it may be justified to never unsubscribe, and instead perform real-time Bayesian inference in a reactive manner as soon as data arrives.
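
A minimal sketch of these steps, assuming a hypothetical model some_model (defined with the @model macro) that returns a random variable x and a data input y; the Manual inference specification section below walks through each step in detail:

using Rocket, GraphPPL, ReactiveMP

# 1. Create the model (hypothetical `some_model`) and obtain references
#    to a random variable `x` and a data input `y`
model, (x, y) = some_model()

# 2. Subscribe to posterior marginal updates of `x`
x_subscription = subscribe!(getmarginal(x), (posterior) -> println("New posterior: ", posterior))

# 3. (optional) Subscribe to Bethe Free Energy updates
bfe_subscription = subscribe!(score(BetheFreeEnergy(), model), (fe) -> println("Free energy: ", fe))

# 4. Feed the model with observations
update!(y, 0.0)

# 5. (optional) Unsubscribe from the updates
unsubscribe!(x_subscription)
unsubscribe!(bfe_subscription)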

ReactiveMP.jl provides a generic inference function to simplify these steps and to test models faster. Note, however, that this function does not support the full range of the package's features. Read about both the automatic and the manual approaches below.

Automatic inference specification

ReactiveMP.jl exports a user-friendly inference function to quickly run and test your model with static datasets. Note, however, that this function does not use all capabilities of the ReactiveMP.jl library, and for advanced use cases you may want to resort to the Manual inference specification section and the Advanced Tutorial section.

ReactiveMP.inference - Function
inference(
    model::ModelGenerator; 
    data,
    initmarginals           = nothing,
    initmessages            = nothing,
    constraints             = nothing,
    meta                    = nothing,
    options                 = (;),
    returnvars              = nothing, 
    iterations              = nothing,
    free_energy             = false,
    free_energy_diagnostics = BetheFreeEnergyDefaultChecks,
    showprogress            = false,
    callbacks               = nothing,
    warn                    = true,
)

This function provides a generic (but somewhat limited) way to perform probabilistic inference in ReactiveMP.jl. Returns InferenceResult.

Arguments

For more information about some of the arguments, please check below.

  • model::ModelGenerator: specifies a model generator, with the help of the Model function
  • data: NamedTuple or Dict with data, required
  • initmarginals = nothing: NamedTuple or Dict with initial marginals, optional, defaults to nothing
  • initmessages = nothing: NamedTuple or Dict with initial messages, optional, defaults to nothing
  • constraints = nothing: constraints specification object, optional, see @constraints
  • meta = nothing: meta specification object, optional, may be required for some models, see @meta
  • options = (;): model creation options, optional, see model_options
  • returnvars = nothing: return structure info, optional, defaults to return everything at each iteration, see below for more information
  • iterations = nothing: number of iterations, optional, defaults to nothing; we do not distinguish between variational message passing, loopy belief propagation, or expectation propagation iterations, see below for more information
  • free_energy = false: compute the Bethe free energy, optional, defaults to false. Can be passed a floating point type, e.g. Float64, for better efficiency, but disables automatic differentiation packages, such as ForwardDiff.jl
  • free_energy_diagnostics = BetheFreeEnergyDefaultChecks: free energy diagnostic checks, optional, by default checks for possible NaNs and Infs. nothing disables all checks.
  • showprogress = false: show a progress indicator, optional, defaults to false
  • callbacks = nothing: inference cycle callbacks, optional, see below for more info
  • warn = true: enables/disables warnings

Note on NamedTuples

When passing a NamedTuple as a value for some argument, make sure you use a trailing comma for NamedTuples with a single entry. The reason is that Julia treats the returnvars = (x = KeepLast()) and returnvars = (x = KeepLast(), ) expressions differently. The first expression creates (or overwrites!) a new local/global variable named x with contents KeepLast(). The second expression (note the trailing comma) creates a NamedTuple with x as a key and KeepLast() as the value assigned to this key.
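
For example:

a = (x = KeepLast())   # not a NamedTuple: assigns `KeepLast()` to a new variable `x`
b = (x = KeepLast(), ) # a NamedTuple with the single key `x`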

Extended information about some of the arguments

  • model

The model argument accepts a ModelGenerator as its input. The easiest way to create the ModelGenerator is to use the Model function. The Model function accepts a model name as its first argument and the rest is passed directly to the model constructor. For example:

result = inference(
    # Creates `coin_toss(some_argument, some_keyword_argument = 3)`
    model = Model(coin_toss, some_argument; some_keyword_argument = 3)
)

Note: The model keyword argument does not accept a FactorGraphModel instance as a value, as it needs to inject constraints and meta during the inference procedure.

  • initmarginals

In general, for variational message passing, every marginal distribution in a model needs to be pre-initialised. In practice, however, for many models it is sufficient to initialise only a small subset of the variables in the model.
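
For example (a minimal sketch, reusing the normal_estimation model and the vague initial marginals from the Variational Message Passing section below):

result = inference(
    model         = Model(normal_estimation, n),
    data          = (y = data, ),
    initmarginals = (m = vague(NormalMeanVariance), w = vague(Gamma)),
    iterations    = 10
)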

  • initmessages

Loopy belief propagation may need some messages in a model to be pre-initialised.
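
For example (a sketch with a hypothetical random variable x that participates in a loop):

result = inference(
    ...,
    initmessages = (x = vague(NormalMeanVariance), )
)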

  • options

See ?model_options.

  • returnvars

returnvars specifies the variables of interest and the amount of information to return about their posterior updates.

returnvars accepts a NamedTuple, a Dict, or a return variable specification. There are two specifications:

  • KeepLast: saves the last update for a variable, ignoring any intermediate results during iterations
  • KeepEach: saves all updates for a variable for all iterations

Note: if iterations is specified as a number, the inference function tracks and returns every update for each iteration for every random variable in the model (equivalent to KeepEach()). If the number of iterations is set to nothing, the inference function saves the last (and only) update for every random variable in the model (equivalent to KeepLast()). To force the KeepEach() behaviour when the number of iterations is equal to 1, use iterations = 1 (instead of nothing) or set returnvars = KeepEach() manually.

Example:

result = inference(
    ...,
    returnvars = (
        x = KeepLast(),
        τ = KeepEach()
    )
)

It is also possible to set either returnvars = KeepLast() or returnvars = KeepEach(), which acts as an alias and sets the given option for all random variables in the model.

Example:

result = inference(
    ...,
    returnvars = KeepLast()
)
  • iterations

Specifies the number of variational (or loopy BP) iterations. By default set to nothing, which is equivalent to performing a single iteration.

  • free_energy

This setting specifies whether the inference function should return Bethe Free Energy (BFE) values. Note, however, that it may not be possible to compute BFE values for every model.

Additionally, the argument may accept a floating point type instead of a Bool value. Using this option, e.g. Float64, improves the performance of the Bethe Free Energy computation, but restricts the use of automatic differentiation packages.
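
For example:

result = inference(
    ...,
    iterations  = 10,
    free_energy = Float64 # or simply `true` to keep compatibility with automatic differentiation
)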

  • free_energy_diagnostics

This setting specifies either a single diagnostic check or a tuple of diagnostic checks for the Bethe Free Energy values stream. By default it checks for NaNs and Infs. See also BetheFreeEnergyCheckNaNs and BetheFreeEnergyCheckInfs. Pass nothing to disable all checks.
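
For example:

result = inference(
    ...,
    free_energy             = true,
    free_energy_diagnostics = nothing # disables the default NaN/Inf checks
)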

  • callbacks

The inference function has its own lifecycle. The user is free to provide some (or none) of the callbacks to inject extra logging or other procedures into the inference function, e.g.

result = inference(
    ...,
    callbacks = (
        on_marginal_update = (model, name, update) -> println("$(name) has been updated: $(update)"),
        after_inference    = (args...) -> println("Inference has been completed")
    )
)

The list of all possible callbacks is presented below:

  • :on_marginal_update: args: (model::FactorGraphModel, name::Symbol, update)
  • :before_model_creation: args: ()
  • :after_model_creation: args: (model::FactorGraphModel)
  • :before_inference: args: (model::FactorGraphModel)
  • :before_iteration: args: (model::FactorGraphModel, iteration::Int)
  • :before_data_update: args: (model::FactorGraphModel, data)
  • :after_data_update: args: (model::FactorGraphModel, data)
  • :after_iteration: args: (model::FactorGraphModel, iteration::Int)
  • :after_inference: args: (model::FactorGraphModel)

See also: InferenceResult

ReactiveMP.InferenceResult - Type
InferenceResult

This structure is used as a return value from the inference function.

Fields

  • posteriors: Dict or NamedTuple of 'random variable' - 'posterior' pairs. See the returnvars argument for inference.
  • free_energy: (optional) An array of Bethe Free Energy values per VMP iteration. See the free_energy argument for inference.
  • model: FactorGraphModel object reference.
  • returnval: Return value from executed @model.
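
A minimal usage sketch (assuming the my_model definition from the Manual inference specification section below):

result = inference(model = Model(my_model), data = (y = 0.0, ))

result.posteriors[:m1] # posterior marginal for the random variable `m1`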

See also: inference


Manual inference specification

For advanced use cases, such as online real-time Bayesian inference, it is advised to use the manual inference specification.

Model creation

During the model specification stage, the user decides on the variables of interest in a model and (optionally) returns them using a return ... statement. As an example, consider a simple hierarchical model in which the mean of a Normal distribution is represented by another Normal distribution, whose mean is in turn modelled by a third Normal distribution.

using Rocket, GraphPPL, ReactiveMP, Distributions, Random

@model function my_model()
    m2 ~ NormalMeanVariance(0.0, 1.0)
    m1 ~ NormalMeanVariance(m2, 1.0)

    y = datavar(Float64)
    y ~ NormalMeanVariance(m1, 1.0)

    # Return variables of interest, optional
    return m1, y
end

Later on, we may create our model and obtain references to the variables of interest:

model, (m1, y) = my_model()

Alternatively, it is possible to query variables using square brackets on the model object:

model, _ = my_model()

model[:m1] # m1
model[:y]  # y

The @model macro also returns a reference to the factor graph as its first return value. The factor graph object (named model in the previous example) contains information about all factor nodes in the model, as well as its random variables and data inputs. See the Advanced Tutorial section.

Posterior marginal updates

The ReactiveMP package has a reactive API and operates in terms of Observables and Actors. For detailed information about these concepts we refer to the Rocket.jl documentation.

We use the getmarginal function to get an observable of posterior marginal updates:

m1_posterior_updates = getmarginal(m1)

After that we can subscribe to new updates and perform actions based on the new values:

m1_posterior_subscription = subscribe!(m1_posterior_updates, (new_posterior) -> begin
    println("New posterior for m1: ", new_posterior)
end)

Sometimes it is useful to return an array of random variables from the model specification. In this case we may use the getmarginals() function, which transforms an array of observables into an observable of arrays.

@model function my_model()
    ...
    m_n = randomvar(n)
    ...
    return m_n, ...
end

model, (m_n, ...) = my_model()

m_n_updates = getmarginals(m_n)

Feeding observations

By default (without any extra factorisation constraints), model specification implies belief propagation (BP) message passing update rules. With the BP algorithm, the ReactiveMP package computes exact Bayesian posteriors with a single message passing iteration; to perform such an iteration we need to pass some data to all the data inputs that were created with the datavar function during model specification. To enforce the belief propagation message passing update rule for some specific factor node, the user may use the where { q = FullFactorisation() } option (read more in the Model Specification section).
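
For example, inside a model specification (a sketch reusing the my_model variables from above):

y ~ NormalMeanVariance(m1, 1.0) where { q = FullFactorisation() }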

To feed an observation for a specific data input we use the update! function:

update!(y, 0.0)
New posterior for m1: Marginal(NormalWeightedMeanPrecision{Float64}(xi=0.0, w=1.5))

As you can see, after we passed a single value to our data input, we got a posterior marginal update from our subscription and printed it with the println function. In the case of BP, if observations do not change, the posterior marginal results should not change either:

update!(y, 0.0) # Observation didn't change, should result in the same posterior
New posterior for m1: Marginal(NormalWeightedMeanPrecision{Float64}(xi=0.0, w=1.5))

If y is an array of data inputs, it is possible to pass an array of observations to the update! function:

for i in 1:length(data)
    update!(y[i], data[i])
end
# is an equivalent of
update!(y, data)

Variational Message Passing

Variational message passing (VMP) algorithms are generated in much the same way as the belief propagation algorithm we saw in the previous section. There is a major difference though: for VMP algorithm generation we need to define the factorization properties of our approximate distribution. A common approach is to assume that all random variables of the model factorize with respect to each other. This is known as the mean field assumption. In ReactiveMP, such factorization properties are specified during the model specification stage using the where { q = ... } syntax or with the @constraints macro (see the Constraints specification section for more info about the @constraints macro). Let's take a look at a simple example to see how it is used. In this model we want to learn the mean and precision of a Normal distribution, where the former is modelled with a Normal distribution and the latter with a Gamma.

using Rocket, GraphPPL, ReactiveMP, Distributions, Random
real_mean      = -4.0
real_precision = 0.2
rng            = MersenneTwister(1234)

n    = 100
data = rand(rng, Normal(real_mean, sqrt(inv(real_precision))), n)
@model function normal_estimation(n)
    m ~ NormalMeanVariance(0.0, 10.0)
    w ~ Gamma(0.1, 10.0)

    y = datavar(Float64, n)

    for i in 1:n
        y[i] ~ NormalMeanPrecision(m, w) where { q = MeanField() }
    end

    return m, w, y
end

We create our model as usual; however, in order to start the VMP inference procedure we need to set initial posterior marginals for all random variables in the model:

model, (m, w, y) = normal_estimation(n)

# We use vague initial marginals
setmarginal!(m, vague(NormalMeanVariance))
setmarginal!(w, vague(Gamma))

To perform a single VMP iteration it is enough to feed all data inputs with some values. To perform multiple VMP iterations, we should feed all our data inputs with the same values multiple times:

m_marginals = []
w_marginals = []

subscriptions = subscribe!([
    (getmarginal(m), (marginal) -> push!(m_marginals, marginal)),
    (getmarginal(w), (marginal) -> push!(w_marginals, marginal)),
])

vmp_iterations = 10

for _ in 1:vmp_iterations
    update!(y, data)
end

unsubscribe!(subscriptions)

As we perform more VMP iterations, our beliefs about the possible values of m and w converge and become more confident.

using Plots

p1    = plot(title = "'Mean' posterior marginals")
grid1 = -6.0:0.01:4.0

for iter in [ 1, 2, 10 ]

    estimated = Normal(mean(m_marginals[iter]), std(m_marginals[iter]))
    e_pdf     = (x) -> pdf(estimated, x)

    plot!(p1, grid1, e_pdf, fill = true, opacity = 0.3, label = "Estimated mean after $iter VMP iterations")

end

plot!(p1, [ real_mean ], seriestype = :vline, label = "Real mean", color = :red4, opacity = 0.7)
p2    = plot(title = "'Precision' posterior marginals")
grid2 = 0.01:0.001:0.35

for iter in [ 2, 3, 10 ]

    estimated = Gamma(shape(w_marginals[iter]), scale(w_marginals[iter]))
    e_pdf     = (x) -> pdf(estimated, x)

    plot!(p2, grid2, e_pdf, fill = true, opacity = 0.3, label = "Estimated precision after $iter VMP iterations")

end

plot!(p2, [ real_precision ], seriestype = :vline, label = "Real precision", color = :red4, opacity = 0.7)

Computing Bethe Free Energy

VMP inference boils down to finding the member of a family of tractable probability distributions that is closest in KL divergence to an intractable posterior distribution. This is achieved by minimizing a quantity known as the Variational Free Energy. ReactiveMP uses the Bethe Free Energy approximation to the true Variational Free Energy. The free energy is particularly useful for testing the convergence of the iterative VMP procedure.

The ReactiveMP package exports the score function, which creates an observable of Bethe Free Energy values:

fe_observable = score(BetheFreeEnergy(), model)
# Reset posterior marginals for `m` and `w`
setmarginal!(m, vague(NormalMeanVariance))
setmarginal!(w, vague(Gamma))

fe_values = []

fe_subscription = subscribe!(fe_observable, (v) -> push!(fe_values, v))

vmp_iterations = 10

for _ in 1:vmp_iterations
    update!(y, data)
end

unsubscribe!(fe_subscription)
plot(fe_values, label = "Bethe Free Energy", xlabel = "Iteration #")