# Inference execution

This section explains how to use the `ReactiveMP` reactive API for running inference on probabilistic models that were created with the `GraphPPL` package, as explained in the Model Specification section.

The `ReactiveMP` inference API supports different types of message-passing algorithms, including belief propagation and variational message passing, as well as hybrid algorithms that combine several different types.

Whereas belief propagation performs exact inference for the random variables of interest, variational message passing (VMP) is an approximation method that can be applied to a larger range of models.

The `ReactiveMP` engine itself is not aware of different algorithm types and simply passes messages between nodes. However, during the model specification stage the user may specify different factorisation constraints around factor nodes by using the `where { q = ... }` syntax or with the help of the `@constraints` macro. Different factorisation constraints lead to different message passing update rules.

Inference with `ReactiveMP` usually consists of the same simple building blocks and is designed to support both static and real-time infinite datasets:

1. Create a model with the `@model` macro and get references to random variables and data inputs
2. Subscribe to random variable posterior marginal updates
3. Subscribe to Bethe Free Energy updates (optional)
4. Feed the model with observations
5. Unsubscribe from posterior marginal updates (optional)

It is worth noting that Step 5 is optional: when observations come from an infinite real-time data stream (e.g. from the internet), it may be justified to never unsubscribe and to perform real-time Bayesian inference in a reactive manner as soon as data arrives.

ReactiveMP.jl provides a generic `inference` function to simplify these steps and test models faster. However, this function does not support the full range of ReactiveMP.jl's features. Read about both the automatic and manual approaches below.

## Automatic inference specification

ReactiveMP.jl exports a user-friendly `inference` function to quickly run and test your model with static datasets. Note, however, that this function does not use all capabilities of the ReactiveMP.jl library, and for advanced use cases you may want to resort to the manual inference specification section and the Advanced Tutorial section.

`ReactiveMP.inference` - Function

```
inference(
    model::ModelGenerator;
    data,
    initmarginals = nothing,
    initmessages = nothing,
    constraints = nothing,
    meta = nothing,
    options = (;),
    returnvars = nothing,
    iterations = 1,
    free_energy = false,
    showprogress = false,
    callbacks = nothing,
)
```

This function provides a generic (but somewhat limited) way to perform probabilistic inference in ReactiveMP.jl. Returns `InferenceResult`.

**Arguments**

For more information about some of the arguments, please check below.

- `model::ModelGenerator`: specifies a model generator, with the help of the `Model` function
- `data`: `NamedTuple` or `Dict` with data, required
- `initmarginals = nothing`: `NamedTuple` or `Dict` with initial marginals, optional, defaults to nothing
- `initmessages = nothing`: `NamedTuple` or `Dict` with initial messages, optional, defaults to nothing
- `constraints = nothing`: constraints specification object, optional, see `@constraints`
- `meta = nothing`: meta specification object, optional, may be required for some models, see `@meta`
- `options = (;)`: model creation options, optional, see `model_options`
- `returnvars = nothing`: return structure info, optional, defaults to return everything at each iteration, see below for more information
- `iterations = 1`: number of iterations, optional, defaults to 1, we do not distinguish between variational message passing, loopy belief propagation, or expectation propagation iterations
- `free_energy = false`: compute the Bethe free energy, optional, defaults to false. Can be passed a floating point type, e.g. `Float64`, for better efficiency, but disables automatic differentiation packages, such as ForwardDiff.jl
- `showprogress = false`: show progress module, optional, defaults to false
- `callbacks = nothing`: inference cycle callbacks, optional, see below for more info
- `warn = true`: enables/disables warnings

**Note on NamedTuples**

When passing a `NamedTuple` as a value for some argument, make sure you use a trailing comma for `NamedTuple`s with a single entry. The reason is that Julia treats the expressions `returnvars = (x = KeepLast())` and `returnvars = (x = KeepLast(), )` differently. The first expression creates (or **overwrites!**) a local/global variable named `x` with contents `KeepLast()`. The second expression (note the trailing comma) creates a `NamedTuple` with `x` as a key and `KeepLast()` as the value assigned to this key.
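The difference is easy to check in plain Julia. The sketch below uses integers in place of `KeepLast()` so that it runs without any packages; the variable names `a` and `b` are purely illustrative.

```julia
# Parentheses without a trailing comma: an ordinary assignment.
# This binds (or overwrites) a variable `x` and evaluates to its value.
a = (x = 1)

# Parentheses with a trailing comma: a NamedTuple with a single key `x`.
# No variable is assigned here, so `x` keeps its previous value.
b = (x = 2,)

println(typeof(a))  # an integer type, not a NamedTuple
println(keys(b))    # the NamedTuple has a single key, :x
println(x)          # `x` exists as a variable and still holds 1
```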

**Extended information about some of the arguments**

`model`

The `model` argument accepts a `ModelGenerator` as its input. The easiest way to create the `ModelGenerator` is to use the `Model` function. The `Model` function accepts a model name as its first argument and the rest is passed directly to the model constructor. For example:

```
result = inference(
    # Creates `coin_toss(some_argument, some_keyword_argument = 3)`
    model = Model(coin_toss, some_argument; some_keyword_argument = 3)
)
```

**Note**: The `model` keyword argument does not accept a `FactorGraphModel` instance as a value, as it needs to inject `constraints` and `meta` during the inference procedure.

`initmarginals`

In general, for variational message passing every marginal distribution in a model needs to be pre-initialised. In practice, however, for many models it is sufficient to initialise only a small subset of variables in the model.

`initmessages`

Loopy belief propagation may need some messages in a model to be pre-initialised.

`options`

See `?model_options`.

`returnvars`

`returnvars` specifies the variables of interest and the amount of information to return about their posterior updates. By default the `inference` function tracks and returns every update for each iteration for every random variable in the model. `returnvars` accepts a `NamedTuple` or `Dict` of return variable specifications. There are two specifications:

- `KeepLast`: saves the last update for a variable, ignoring any intermediate results during iterations
- `KeepEach`: saves all updates for a variable for all iterations

Example:

```
result = inference(
    ...,
    returnvars = (
        x = KeepLast(),
        τ = KeepEach()
    )
)
```

`free_energy`

This setting specifies whether the `inference` function should return Bethe Free Energy (BFE) values. Note, however, that it may not be possible to compute BFE values for every model.

Additionally, the argument may accept a floating point type instead of a `Bool` value. Using this option, e.g. `Float64`, improves the performance of the Bethe Free Energy computation, but restricts the use of automatic differentiation packages.

`callbacks`

The inference function has its own lifecycle. The user is free to provide some (or none) of the callbacks to inject some extra logging or other procedures in the inference function, e.g.

```
result = inference(
    ...,
    callbacks = (
        on_marginal_update = (model, name, update) -> println("$(name) has been updated: $(update)"),
        after_inference = (args...) -> println("Inference has been completed")
    )
)
```

The list of all possible callbacks is given below:

- `:on_marginal_update`: args: (model::FactorGraphModel, name::Symbol, update)
- `:before_model_creation`: args: ()
- `:after_model_creation`: args: (model::FactorGraphModel)
- `:before_inference`: args: (model::FactorGraphModel)
- `:before_iteration`: args: (model::FactorGraphModel, iteration::Int)
- `:before_data_update`: args: (model::FactorGraphModel, data)
- `:after_data_update`: args: (model::FactorGraphModel, data)
- `:after_iteration`: args: (model::FactorGraphModel, iteration::Int)
- `:after_inference`: args: (model::FactorGraphModel)

See also: `InferenceResult`

`ReactiveMP.InferenceResult` - Type

`InferenceResult`

This structure is used as a return value from the `inference` function.

**Fields**

- `posteriors`: `Dict` or `NamedTuple` of 'random variable' - 'posterior' pairs. See the `returnvars` argument for `inference`.
- `free_energy`: (optional) An array of Bethe Free Energy values per VMP iteration. See the `free_energy` argument for `inference`.
- `model`: `FactorGraphModel` object reference.
- `returnval`: Return value from executed `@model`.

See also: `inference`
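Putting the arguments together, a call to `inference` might look like the sketch below. It reuses the `normal_estimation` model and the data generation from the Variational Message Passing section of this page. The particular choice of `returnvars`, `iterations = 10`, and vague initial marginals is an illustrative assumption rather than a recommended setting.

```julia
using Rocket, GraphPPL, ReactiveMP, Distributions, Random

# Same model as in the Variational Message Passing section of this page
@model function normal_estimation(n)
    m ~ NormalMeanVariance(0.0, 10.0)
    w ~ Gamma(0.1, 10.0)
    y = datavar(Float64, n)
    for i in 1:n
        y[i] ~ NormalMeanPrecision(m, w) where { q = MeanField() }
    end
    return m, w, y
end

n    = 100
data = rand(MersenneTwister(1234), Normal(-4.0, sqrt(inv(0.2))), n)

result = inference(
    model         = Model(normal_estimation, n),
    data          = (y = data, ),  # single entry, so note the trailing comma
    initmarginals = (m = vague(NormalMeanVariance), w = vague(Gamma)),
    returnvars    = (m = KeepLast(), w = KeepLast()),
    iterations    = 10,
    free_energy   = true,
)

m_posterior = result.posteriors[:m]  # last posterior update for `m`
fe_values   = result.free_energy     # Bethe Free Energy values per iteration
```

The fields read at the end (`posteriors` and `free_energy`) are those listed for `InferenceResult` above.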

## Manual inference specification

For advanced use cases such as online real-time Bayesian inference it is advised to use manual inference specification.

### Model creation

During the model specification stage the user decides on the variables of interest in a model and (optionally) returns them using a `return ...` statement. As an example, consider a simple hierarchical model in which the mean of a Normal distribution is represented by another Normal distribution whose mean is in turn modelled by another Normal distribution.

```
using Rocket, GraphPPL, ReactiveMP, Distributions, Random

@model function my_model()
    m2 ~ NormalMeanVariance(0.0, 1.0)
    m1 ~ NormalMeanVariance(m2, 1.0)
    y = datavar(Float64)
    y ~ NormalMeanVariance(m1, 1.0)
    # Return variables of interest, optional
    return m1, y
end
```

And later on we may create our model and obtain references to the variables of interest:

`model, (m1, y) = my_model()`

Alternatively, it is possible to query variables using square brackets on the `model` object:

```
model, _ = my_model()
model[:m1] # m1
model[:y] # y
```

The `@model` macro also returns a reference to the factor graph as its first return value. The factor graph object (named `model` in the previous example) contains all information about all factor nodes in a model as well as random variables and data inputs. See the Advanced Tutorial section.

### Posterior marginal updates

The `ReactiveMP` package has a reactive API and operates in terms of Observables and Actors. For detailed information about these concepts we refer to the Rocket.jl documentation.

We use the `getmarginal` function to get an observable of posterior marginal updates:

`m1_posterior_updates = getmarginal(m1)`

After that we can subscribe to new updates and perform some actions based on new values:

```
m1_posterior_subscription = subscribe!(m1_posterior_updates, (new_posterior) -> begin
    println("New posterior for m1: ", new_posterior)
end)
```

Sometimes it is useful to return an array of random variables from model specification. In this case we may use the `getmarginals()` function, which transforms an array of observables into an observable of arrays.

```
@model function my_model()
    ...
    m_n = randomvar(n)
    ...
    return m_n, ...
end

model, (m_n, ...) = my_model()
m_n_updates = getmarginals(m_n)
```

### Feeding observations

By default (without any extra factorisation constraints) model specification implies Belief Propagation message passing update rules. In the case of the BP algorithm the `ReactiveMP` package computes exact Bayesian posteriors with a single message passing iteration. To enforce the Belief Propagation message passing update rule for some specific factor node the user may use the `where { q = FullFactorisation() }` option. Read more in the Model Specification section. To perform a message passing iteration we need to pass some data to all our data inputs that were created with the `datavar` function during model specification.

To feed an observation for a specific data input we use the `update!` function:

`update!(y, 0.0)`

`New posterior for m1: Marginal(NormalWeightedMeanPrecision{Float64}(xi=0.0, w=1.5))`

As you can see, after we passed a single value to our data input we got a posterior marginal update from our subscription and printed it with the `println` function. In the case of BP, feeding the same observation again should not change the posterior marginal:

`update!(y, 0.0) # Observation didn't change, should result in the same posterior`

`New posterior for m1: Marginal(NormalWeightedMeanPrecision{Float64}(xi=0.0, w=1.5))`
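The printed posterior can be verified by hand, since every factor in `my_model` is Gaussian. The sketch below redoes the arithmetic in plain Julia without `ReactiveMP`; the variable names are illustrative. It assumes that `xi` and `w` in the output denote the weighted mean (precision times mean) and the precision of the posterior.

```julia
# Prior on m1 after marginalising out m2:
# m2 ~ N(0, 1) and m1 | m2 ~ N(m2, 1), so m1 ~ N(0, 1 + 1).
prior_mean      = 0.0
prior_precision = 1 / (1.0 + 1.0)

# Likelihood y | m1 ~ N(m1, 1) with the observation y = 0.0
observation          = 0.0
likelihood_precision = 1 / 1.0

# Multiplying Gaussian densities adds their precisions and weighted means
w  = prior_precision + likelihood_precision
xi = prior_precision * prior_mean + likelihood_precision * observation

println((xi = xi, w = w))
```

The result, `xi = 0.0` and `w = 1.5`, agrees with the `NormalWeightedMeanPrecision` marginal printed by the subscription.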

If `y` is an array of data inputs it is possible to pass an array of observations to the `update!` function:

```
for i in 1:length(data)
    update!(y[i], data[i])
end
# is equivalent to
update!(y, data)
```

### Variational Message Passing

Variational message passing (VMP) algorithms are generated much in the same way as the belief propagation algorithm we saw in the previous section. There is a major difference though: for VMP algorithm generation we need to define the factorization properties of our approximate distribution. A common approach is to assume that all random variables of the model factorize with respect to each other. This is known as the mean field assumption. In `ReactiveMP`, such factorization properties are specified during the model specification stage using the `where { q = ... }` syntax or with the `@constraints` macro (see the Constraints specification section for more info about the `@constraints` macro). Let's take a look at a simple example to see how it is used. In this model we want to learn the mean and precision of a Normal distribution, where the former is modelled with a Normal distribution and the latter with a Gamma.

`using Rocket, GraphPPL, ReactiveMP, Distributions, Random`

```
real_mean = -4.0
real_precision = 0.2
rng = MersenneTwister(1234)
n = 100
data = rand(rng, Normal(real_mean, sqrt(inv(real_precision))), n)
```

```
@model function normal_estimation(n)
    m ~ NormalMeanVariance(0.0, 10.0)
    w ~ Gamma(0.1, 10.0)
    y = datavar(Float64, n)
    for i in 1:n
        y[i] ~ NormalMeanPrecision(m, w) where { q = MeanField() }
    end
    return m, w, y
end
```

We create our model as usual. However, in order to start the VMP inference procedure we need to set initial posterior marginals for all random variables in the model:

```
model, (m, w, y) = normal_estimation(n)
# We use vague initial marginals
setmarginal!(m, vague(NormalMeanVariance))
setmarginal!(w, vague(Gamma))
```

To perform a single VMP iteration it is enough to feed all data inputs with some values. To perform multiple VMP iterations we should feed all our data inputs with the same values multiple times:

```
m_marginals = []
w_marginals = []

subscriptions = subscribe!([
    (getmarginal(m), (marginal) -> push!(m_marginals, marginal)),
    (getmarginal(w), (marginal) -> push!(w_marginals, marginal)),
])

vmp_iterations = 10

for _ in 1:vmp_iterations
    update!(y, data)
end

unsubscribe!(subscriptions)
```

As we process more VMP iterations, our beliefs about the possible values of `m` and `w` converge and become more confident.
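To see why repeated updates converge, it can help to write the mean-field updates for this particular model by hand. The sketch below is plain Julia and is not how `ReactiveMP` computes messages internally; it applies the standard coordinate-ascent updates for a Normal likelihood with a Normal prior on the mean and a Gamma prior on the precision. The function name `meanfield_normal`, its keyword arguments, and the initial value of the expected precision are assumptions made for illustration. The data is regenerated with `randn`, so the numbers will not match the listing above exactly.

```julia
using Random, Statistics

# Hand-written mean-field updates for
#   m ~ Normal(mean = 0, variance = 10), w ~ Gamma(shape = 0.1, scale = 10),
#   y[i] ~ Normal(mean = m, precision = w)
function meanfield_normal(y; iterations = 10, prior_var = 10.0, shape0 = 0.1, scale0 = 10.0)
    n     = length(y)
    rate0 = 1 / scale0
    E_w   = 1.0                  # illustrative initial guess for E[w]
    μ, λ  = 0.0, 1 / prior_var   # mean and precision of q(m)
    a, b  = shape0, rate0        # shape and rate of q(w)
    for _ in 1:iterations
        # Update q(m) given the current E[w] (the prior mean is zero)
        λ = 1 / prior_var + n * E_w
        μ = E_w * sum(y) / λ
        # Update q(w) given the current q(m); E[(y - m)^2] = (y - μ)^2 + 1/λ
        a = shape0 + n / 2
        b = rate0 + 0.5 * (sum(abs2, y .- μ) + n / λ)
        E_w = a / b
    end
    return (mean = μ, mean_precision = λ, shape = a, rate = b, E_w = E_w)
end

rng  = MersenneTwister(1234)
data = -4.0 .+ sqrt(inv(0.2)) .* randn(rng, 100)

q = meanfield_normal(data)
println("E[m] = ", q.mean, ", E[w] = ", q.E_w)
```

With these settings the expected mean settles close to the sample mean and the expected precision close to the inverse sample variance after a few iterations, which mirrors the qualitative behaviour of the marginals collected above.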

```
using Plots

p1 = plot(title = "'Mean' posterior marginals")
grid1 = -6.0:0.01:4.0

for iter in [ 1, 2, 10 ]
    estimated = Normal(mean(m_marginals[iter]), std(m_marginals[iter]))
    e_pdf = (x) -> pdf(estimated, x)
    plot!(p1, grid1, e_pdf, fill = true, opacity = 0.3, label = "Estimated mean after $iter VMP iterations")
end

plot!(p1, [ real_mean ], seriestype = :vline, label = "Real mean", color = :red4, opacity = 0.7)
```

```
p2 = plot(title = "'Precision' posterior marginals")
grid2 = 0.01:0.001:0.35

for iter in [ 2, 3, 10 ]
    estimated = Gamma(shape(w_marginals[iter]), scale(w_marginals[iter]))
    e_pdf = (x) -> pdf(estimated, x)
    plot!(p2, grid2, e_pdf, fill = true, opacity = 0.3, label = "Estimated precision after $iter VMP iterations")
end

plot!(p2, [ real_precision ], seriestype = :vline, label = "Real precision", color = :red4, opacity = 0.7)
```

### Computing Bethe Free Energy

VMP inference boils down to finding the member of a family of tractable probability distributions that is closest in KL divergence to an intractable posterior distribution. This is achieved by minimizing a quantity known as Variational Free Energy. `ReactiveMP` uses the Bethe Free Energy approximation to the real Variational Free Energy. Free energy is particularly useful for testing convergence of the VMP iterative procedure.
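For reference, the Variational Free Energy mentioned above has a standard form. The notation is an assumption introduced here and is not taken from this page: $s$ denotes the latent variables, $y$ the observations, and $q(s)$ the approximate posterior.

```math
F[q] = \mathbb{E}_{q(s)}\left[\log \frac{q(s)}{p(y, s)}\right] = \mathrm{KL}\left[q(s) \,\|\, p(s \mid y)\right] - \log p(y)
```

Because $\log p(y)$ does not depend on $q$, minimising $F[q]$ is equivalent to minimising the KL divergence to the exact posterior, and $F[q]$ is an upper bound on $-\log p(y)$.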

The `ReactiveMP` package exports the `score` function, which returns an observable of free energy values:

`fe_observable = score(BetheFreeEnergy(), model)`

```
# Reset posterior marginals for `m` and `w`
setmarginal!(m, vague(NormalMeanVariance))
setmarginal!(w, vague(Gamma))

fe_values = []
fe_subscription = subscribe!(fe_observable, (v) -> push!(fe_values, v))

vmp_iterations = 10

for _ in 1:vmp_iterations
    update!(y, data)
end

unsubscribe!(fe_subscription)
```

`plot(fe_values, label = "Bethe Free Energy", xlabel = "Iteration #")`