Automatic inference specification on static datasets
RxInfer exports the inference function to quickly run and test your model with static datasets. Note, however, that while this function covers almost all capabilities of the inference engine, advanced use cases may require the manual inference specification.
For running inference on real-time datasets see the Reactive Inference section. For manual inference specification see the Manual Inference section.
RxInfer.inference — Function

inference(
model;
data,
initmarginals = nothing,
initmessages = nothing,
constraints = nothing,
meta = nothing,
options = nothing,
returnvars = nothing,
predictvars = nothing,
iterations = nothing,
free_energy = false,
free_energy_diagnostics = BetheFreeEnergyDefaultChecks,
showprogress = false,
callbacks = nothing,
addons = nothing,
postprocess = DefaultPostprocess()
)
This function provides a generic way to perform probabilistic inference in RxInfer.jl. Returns an InferenceResult.
Arguments
For more information about some of the arguments, please check below.
- model: specifies a model generator, required
- data: NamedTuple or Dict with data, required
- initmarginals = nothing: NamedTuple or Dict with initial marginals, optional
- initmessages = nothing: NamedTuple or Dict with initial messages, optional
- constraints = nothing: constraints specification object, optional, see @constraints
- meta = nothing: meta specification object, optional, may be required for some models, see @meta
- options = nothing: model creation options, optional, see ModelInferenceOptions
- returnvars = nothing: return structure info, optional, defaults to return everything at each iteration, see below for more information
- predictvars = nothing: return structure info, optional, see below for more information
- iterations = nothing: number of iterations, optional, defaults to nothing; the inference engine does not distinguish between variational message passing, loopy belief propagation, or expectation propagation iterations, see below for more information
- free_energy = false: compute the Bethe free energy, optional, defaults to false. Can be passed a floating point type, e.g. Float64, for better efficiency, but this disables automatic differentiation packages, such as ForwardDiff.jl
- free_energy_diagnostics = BetheFreeEnergyDefaultChecks: free energy diagnostic checks, optional, by default checks for possible NaNs and Infs. nothing disables all checks
- showprogress = false: show progress module, optional, defaults to false
- callbacks = nothing: inference cycle callbacks, optional, see below for more info
- addons = nothing: inject and send extra computation information along messages, see below for more info
- postprocess = DefaultPostprocess(): inference results postprocessing step, optional, see below for more info
- warn = true: enables/disables warnings
Note on NamedTuples
When passing a NamedTuple as a value for some argument, make sure you use a trailing comma for NamedTuples with a single entry. The reason is that Julia treats the expressions returnvars = (x = KeepLast()) and returnvars = (x = KeepLast(), ) differently. The first expression creates (or overwrites!) a new local/global variable named x with contents KeepLast(). The second expression (note the trailing comma) creates a NamedTuple with x as a key and KeepLast() as the value assigned to that key.
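For example, compare:
returnvars = (x = KeepLast())    # assignment expression: creates a variable `x`, `returnvars` is just `KeepLast()`
returnvars = (x = KeepLast(), )  # trailing comma: `returnvars` is a `NamedTuple` with the key `x`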
Extended information about some of the arguments
model
The model
argument accepts a ModelGenerator
as its input. The easiest way to create the ModelGenerator
is to use the @model
macro. For example:
@model function coin_toss(some_argument; some_keyword_argument = 3)
    ...
end

result = inference(
    model = coin_toss(some_argument; some_keyword_argument = 3)
)
Note: The model
keyword argument does not accept a FactorGraphModel
instance as a value, as it needs to inject constraints
and meta
during the inference procedure.
data
The data
keyword argument must be a NamedTuple
(or Dict
) where keys (of Symbol
type) correspond to all datavar
s defined in the model specification. For example, if a model defines x = datavar(Float64)
the data
field must have an :x
key (of Symbol
type) which holds a value of type Float64
. The values in the data
must have the exact same shape as the datavar
container. In other words, if a model defines x = datavar(Float64, n)
then data[:x]
must provide a container with length n
and with elements of type Float64
.
Note: The behavior of the data
keyword argument is different from that which is used in the rxinference
function.
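As an illustration, here is a minimal sketch (the model and data below are hypothetical and serve only to show how the keys and shapes in data must match the datavar definitions):

using RxInfer

@model function iid_gaussian(n)
    y = datavar(Float64, n)             # expects a vector of `n` Float64 values under the key `:y`
    μ ~ NormalMeanVariance(0.0, 100.0)
    for i in 1:n
        y[i] ~ NormalMeanVariance(μ, 1.0)
    end
end

result = inference(
    model = iid_gaussian(5),
    data  = (y = [0.1, -0.2, 0.3, 0.05, -0.1], )  # the key and the length match the `datavar` definition
)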
initmarginals
For specific types of inference algorithms, such as variational message passing, it might be required to initialize (some of) the marginals before running the inference procedure in order to break the dependency loop. If this is not done, the inference algorithm will not be executed due to the lack of information, and messages and/or marginals will not be updated. In order to specify these initial marginals, you can use the initmarginals argument, for example:
inference(...
initmarginals = (
# initialize the marginal distribution of x as a vague Normal distribution
# if x is a vector, then it simply uses the same value for all elements
# However, it is also possible to provide a vector of distributions to set each element individually
x = vague(NormalMeanPrecision),
),
)
This argument needs to be a named tuple, i.e. initmarginals = (a = ..., ), or a dictionary.
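For instance, here is a minimal sketch (the model, constraints, and data below are hypothetical) in which a mean-field variational approximation requires initial marginals to break the circular dependency between the unknown mean and precision:

using RxInfer

@model function gaussian_with_unknown_mean_precision(n)
    y = datavar(Float64, n)
    μ ~ NormalMeanPrecision(0.0, 0.1)
    τ ~ GammaShapeRate(1.0, 1.0)
    for i in 1:n
        y[i] ~ NormalMeanPrecision(μ, τ)
    end
end

constraints = @constraints begin
    q(μ, τ) = q(μ)q(τ)   # mean-field factorization over μ and τ
end

result = inference(
    model         = gaussian_with_unknown_mean_precision(3),
    data          = (y = [0.5, -0.3, 0.2], ),
    constraints   = constraints,
    initmarginals = (μ = vague(NormalMeanPrecision), τ = vague(GammaShapeRate)),
    iterations    = 10
)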
initmessages
For specific types of inference algorithms, such as loopy belief propagation or expectation propagation, it might be required to initialize (some of) the messages before running the inference procedure in order to break the dependency loop. If this is not done, the inference algorithm will not be executed due to the lack of information, and messages and/or marginals will not be updated. In order to specify these initial messages, you can use the initmessages argument, for example:
inference(...
initmessages = (
# initialize the message distribution of x as a vague Normal distribution
# if x is a vector, then it simply uses the same value for all elements
# However, it is also possible to provide a vector of distributions to set each element individually
x = vague(NormalMeanPrecision),
),
)
This argument needs to be a named tuple, i.e. initmessages = (a = ..., ), or a dictionary.
options

- limit_stack_depth: limits the stack depth for computing messages, helps with StackOverflowError for large models, but reduces the performance of the inference backend. Accepts an integer argument that specifies the maximum recursion depth. Lower is better for avoiding stack overflow errors, but worse for performance.
- pipeline: changes the default pipeline for each factor node in the graph
- global_reactive_scheduler: changes the scheduler of reactive streams, see Rocket.jl for more info, defaults to no scheduler
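For example, a sketch that limits the recursion depth for a large model:
result = inference(
    ...,
    options = (limit_stack_depth = 100, )  # note the trailing comma for a single-entry NamedTuple
)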
returnvars

returnvars specifies the variables of interest and the amount of information to return about their posterior updates.

returnvars accepts a NamedTuple or Dict or a return variable specification. There are two specifications:
- KeepLast: saves the last update for a variable, ignoring any intermediate results during iterations
- KeepEach: saves all updates for a variable for all iterations
Note: if iterations is specified as a number, the inference function tracks and returns every update for each iteration for every random variable in the model (equivalent to KeepEach()). If the number of iterations is set to nothing, the inference function saves the last (and only) update for every random variable in the model (equivalent to KeepLast()). Use iterations = 1 to force the KeepEach() setting when the number of iterations is equal to 1, or set returnvars = KeepEach() manually.
Example:
result = inference(
...,
returnvars = (
x = KeepLast(),
τ = KeepEach()
)
)
It is also possible to set either returnvars = KeepLast() or returnvars = KeepEach(), which acts as an alias and sets the given option for all random variables in the model.
Example:
result = inference(
...,
returnvars = KeepLast()
)
predictvars
predictvars specifies the variables which should be predicted. In the model definition these variables are specified as datavars, but they should not be passed inside the data argument.
Similar to returnvars, predictvars accepts a NamedTuple or Dict. There are two specifications:

- KeepLast: saves the last update for a variable, ignoring any intermediate results during iterations
- KeepEach: saves all updates for a variable for all iterations
Example:
result = inference(
...,
predictvars = (
o = KeepLast(),
τ = KeepEach()
)
)
iterations
Specifies the number of variational (or loopy belief propagation) iterations. By default set to nothing, which is equivalent to doing 1 iteration.
free_energy
This setting specifies whether the inference function should return Bethe Free Energy (BFE) values. Note, however, that it may not be possible to compute BFE values for every model.
Additionally, the argument may accept a floating point type instead of a Bool value. Using this option, e.g. Float64, improves the performance of the Bethe Free Energy computation, but restricts the use of automatic differentiation packages.
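For example, a sketch that computes and inspects the BFE values per iteration:
result = inference(
    ...,
    iterations  = 10,
    free_energy = true
)
result.free_energy   # an array of Bethe Free Energy values, one per iteration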
free_energy_diagnostics
This setting specifies either a single diagnostic check or a tuple of diagnostic checks for the Bethe Free Energy values stream. By default it checks for NaNs and Infs. See also BetheFreeEnergyCheckNaNs and BetheFreeEnergyCheckInfs. Pass nothing to disable any checks.
callbacks
The inference function has its own lifecycle. The user is free to provide some (or none) of the callbacks to inject some extra logging or other procedures in the inference function, e.g.
result = inference(
...,
callbacks = (
on_marginal_update = (model, name, update) -> println("$(name) has been updated: $(update)"),
after_inference = (args...) -> println("Inference has been completed")
)
)
The callbacks keyword argument accepts a named tuple of 'name = callback' pairs. The list of all possible callbacks and their arguments is given below:
- on_marginal_update: args: (model::FactorGraphModel, name::Symbol, update)
- before_model_creation: args: ()
- after_model_creation: args: (model::FactorGraphModel, returnval)
- before_inference: args: (model::FactorGraphModel)
- before_iteration: args: (model::FactorGraphModel, iteration::Int)::Bool
- before_data_update: args: (model::FactorGraphModel, data)
- after_data_update: args: (model::FactorGraphModel, data)
- after_iteration: args: (model::FactorGraphModel, iteration::Int)::Bool
- after_inference: args: (model::FactorGraphModel)
The before_iteration and after_iteration callbacks are allowed to return a true/false value. true indicates that the iterations must be halted and no further inference should be made.
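For example, a sketch that halts the inference early from the after_iteration callback (note the trailing comma for the single-entry NamedTuple):
result = inference(
    ...,
    iterations = 100,
    callbacks  = (
        after_iteration = (model, iteration) -> iteration >= 10,  # return `true` to halt further iterations
    )
)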
addons
The addons field extends the default message computation rules with some extra information, e.g. computing log-scaling factors of messages or saving debug information. Accepts a single addon or a tuple of addons. If set, this replaces the corresponding setting in the options. Automatically changes the default value of the postprocess argument to NoopPostprocess.
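For example, a sketch assuming the AddonLogScale addon from ReactiveMP.jl is available for computing log-scaling factors:
result = inference(
    ...,
    addons = (AddonLogScale(), )   # assumed addon name; `postprocess` defaults to `NoopPostprocess` in this case
)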
postprocess
The postprocess keyword argument controls whether the inference results must be modified in some way before exiting the inference function. By default, the inference function uses the DefaultPostprocess strategy, which removes the Marginal wrapper type from the results. Change this setting to NoopPostprocess if you would like to keep the Marginal wrapper type, which might be useful in combination with the addons argument. If the addons argument has been used, the default strategy automatically changes to NoopPostprocess.
catch_exception
The catch_exception keyword argument specifies whether exceptions during the inference procedure should be caught in the error field of the result. By default, if an exception occurs during the inference procedure, the result will be lost. Set catch_exception = true to obtain a partial result for the inference in case an exception occurs. Use the RxInfer.issuccess and RxInfer.iserror functions to check if the inference completed successfully or failed. If an error occurs, the error field will store a tuple, where the first element is the exception itself and the second element is the caught backtrace. Use the stacktrace function with the backtrace as an argument to recover the stacktrace of the error. Use the Base.showerror function to display the error.
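For example, a sketch for inspecting a failed inference run:
result = inference(
    ...,
    catch_exception = true
)
if RxInfer.iserror(result)
    err, bt = result.error        # the caught exception and its backtrace
    Base.showerror(stderr, err)   # display the error
    stacktrace(bt)                # recover the stacktrace
end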
See also: InferenceResult, rxinference
RxInfer.InferenceResult — Type

InferenceResult

This structure is used as a return value from the inference function.
Public Fields
- posteriors: Dict or NamedTuple of 'random variable' - 'posterior' pairs. See the returnvars argument for inference.
- free_energy: (optional) An array of Bethe Free Energy values per VMP iteration. See the free_energy argument for inference.
- model: FactorGraphModel object reference.
- returnval: Return value from the executed @model.
- error: (optional) A reference to an exception that might have occurred during the inference. See the catch_exception argument for inference.
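For example, a sketch assuming the model defines a random variable x:
result = inference(model = ..., data = ...)
result.posteriors[:x]   # posterior (or a collection of posteriors) for the random variable `x`
result.free_energy      # available only if `free_energy = true` was passed to `inference`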
See also: inference