Automatic inference specification on static datasets

RxInfer exports the inference function to quickly run and test your model with static datasets. Note, however, that while this function covers almost all capabilities of the inference engine, for advanced use cases you may want to resort to the manual inference specification.

For running inference on real-time datasets see the Reactive Inference section. For manual inference specification see the Manual Inference section.

RxInfer.inference — Function
inference(
    model; 
    data,
    initmarginals           = nothing,
    initmessages            = nothing,
    constraints             = nothing,
    meta                    = nothing,
    options                 = nothing,
    returnvars              = nothing, 
    predictvars             = nothing, 
    iterations              = nothing,
    free_energy             = false,
    free_energy_diagnostics = BetheFreeEnergyDefaultChecks,
    showprogress            = false,
    callbacks               = nothing,
    addons                  = nothing,
    postprocess             = DefaultPostprocess(),
    catch_exception         = false,
    warn                    = true
)

This function provides a generic way to perform probabilistic inference in RxInfer.jl. Returns InferenceResult.

Arguments

For more information about some of the arguments, please check below.

  • model: specifies a model generator, required
  • data: NamedTuple or Dict with data, required
  • initmarginals = nothing: NamedTuple or Dict with initial marginals, optional
  • initmessages = nothing: NamedTuple or Dict with initial messages, optional
  • constraints = nothing: constraints specification object, optional, see @constraints
  • meta = nothing: meta specification object, optional, may be required for some models, see @meta
  • options = nothing: model creation options, optional, see ModelInferenceOptions
  • returnvars = nothing: return structure info, optional, defaults to return everything at each iteration, see below for more information
  • predictvars = nothing: return structure info, optional, see below for more information
  • iterations = nothing: number of iterations, optional, defaults to nothing; the inference engine does not distinguish between variational message passing, loopy belief propagation, or expectation propagation iterations, see below for more information
  • free_energy = false: compute the Bethe free energy, optional, defaults to false. Can be passed a floating point type, e.g. Float64, for better efficiency, but disables automatic differentiation packages, such as ForwardDiff.jl
  • free_energy_diagnostics = BetheFreeEnergyDefaultChecks: free energy diagnostic checks, optional, by default checks for possible NaNs and Infs. nothing disables all checks.
  • showprogress = false: show a progress meter, optional, defaults to false
  • callbacks = nothing: inference cycle callbacks, optional, see below for more info
  • addons = nothing: inject and send extra computation information along messages, see below for more info
  • postprocess = DefaultPostprocess(): inference results postprocessing step, optional, see below for more info
  • catch_exception = false: catch exceptions during the inference procedure, optional, see below for more info
  • warn = true: enables/disables warnings

Note on NamedTuples

When passing a NamedTuple as a value for some argument, make sure you use a trailing comma for NamedTuples with a single entry. The reason is that Julia treats the returnvars = (x = KeepLast()) and returnvars = (x = KeepLast(), ) expressions differently. The first expression creates (or overwrites!) a local/global variable named x with contents KeepLast(). The second expression (note the trailing comma) creates a NamedTuple with x as a key and KeepLast() as the value assigned to this key.
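The difference is easy to check in plain Julia (this snippet only illustrates the language semantics, not the inference function itself):

```julia
nt = (x = KeepLast(), )  # a NamedTuple with a single key `x`
keys(nt)                 # (:x,)

y = (x = KeepLast())     # just a parenthesized assignment: creates a variable `x`...
# ...and `y` is now simply `KeepLast()`, not a NamedTuple
```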

Extended information about some of the arguments

  • model

The model argument accepts a ModelGenerator as its input. The easiest way to create the ModelGenerator is to use the @model macro. For example:

@model function coin_toss(some_argument, some_keyword_argument = 3)
   ...
end

result = inference(
    model = coin_toss(some_argument; some_keyword_argument = 3)
)

Note: The model keyword argument does not accept a FactorGraphModel instance as a value, as it needs to inject constraints and meta during the inference procedure.
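As a minimal end-to-end sketch, consider a simple Beta-Bernoulli coin model (the model below is illustrative and not part of the API):

```julia
using RxInfer

# A simple Beta-Bernoulli model for `n` coin tosses
@model function coin_model(n)
    y = datavar(Float64, n)  # observed tosses, encoded as 0.0/1.0
    θ ~ Beta(2.0, 7.0)       # prior over the probability of heads
    for i in 1:n
        y[i] ~ Bernoulli(θ)
    end
end

result = inference(
    model = coin_model(5),
    data  = (y = [1.0, 0.0, 1.0, 1.0, 0.0], )
)

result.posteriors[:θ] # posterior Beta distribution over θ
```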

  • data

The data keyword argument must be a NamedTuple (or Dict) where keys (of Symbol type) correspond to all datavars defined in the model specification. For example, if a model defines x = datavar(Float64) the data field must have an :x key (of Symbol type) which holds a value of type Float64. The values in the data must have the exact same shape as the datavar container. In other words, if a model defines x = datavar(Float64, n) then data[:x] must provide a container with length n and with elements of type Float64.

Note: The behavior of the data keyword argument is different from that which is used in the rxinference function.
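For example, for a model that defines x = datavar(Float64, n), the shapes must line up as follows (my_model here is a hypothetical model generator):

```julia
result = inference(
    model = my_model(3),             # hypothetical model with `x = datavar(Float64, 3)`
    data  = (x = [1.0, 2.0, 3.0], )  # length and element type match the datavar
)
```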

  • initmarginals

For specific types of inference algorithms, such as variational message passing, it might be required to initialize (some of) the marginals before running the inference procedure in order to break the dependency loop. If this is not done, the inference algorithm will not be executed due to the lack of information, and messages and/or marginals will not be updated. In order to specify these initial marginals, you can use the initmarginals argument, such as

inference(...
    initmarginals = (
        # initialize the marginal distribution of x as a vague Normal distribution
        # if x is a vector, then it simply uses the same value for all elements
        # However, it is also possible to provide a vector of distributions to set each element individually 
        x = vague(NormalMeanPrecision),  
    ),
)

This argument needs to be a named tuple, i.e. initmarginals = (a = ..., ), or dictionary.
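The same initialization can be expressed with a Dict, which may be convenient when keys are computed programmatically, e.g.:

```julia
inference(
    ...,
    initmarginals = Dict(:x => vague(NormalMeanPrecision)),
)
```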

  • initmessages

For specific types of inference algorithms, such as loopy belief propagation or expectation propagation, it might be required to initialize (some of) the messages before running the inference procedure in order to break the dependency loop. If this is not done, the inference algorithm will not be executed due to the lack of information, and messages and/or marginals will not be updated. In order to specify these initial messages, you can use the initmessages argument, such as

inference(...
    initmessages = (
        # initialize the message for x as a vague Normal distribution
        # if x is a vector, then it simply uses the same value for all elements
        # However, it is also possible to provide a vector of distributions to set each element individually 
        x = vague(NormalMeanPrecision),  
    ),
)

This argument needs to be a named tuple, i.e. initmessages = (a = ..., ), or dictionary.

  • options

  • limit_stack_depth: limits the stack depth for computing messages, helps with StackOverflowError for large models, but reduces the performance of the inference backend. Accepts an integer as an argument that specifies the maximum recursion depth. Lower is better for avoiding stack overflow errors, but worse for performance.

  • pipeline: changes the default pipeline for each factor node in the graph

  • global_reactive_scheduler: changes the scheduler of reactive streams, see Rocket.jl for more info, defaults to no scheduler
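Options are passed as a NamedTuple; for example, limiting the stack depth might look like this (note the trailing comma for the single-entry NamedTuple):

```julia
result = inference(
    ...,
    options = (limit_stack_depth = 100, )
)
```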

  • returnvars

returnvars specifies the variables of interests and the amount of information to return about their posterior updates.

returnvars accepts a NamedTuple, a Dict, or a return var specification. There are two specifications:

  • KeepLast: saves the last update for a variable, ignoring any intermediate results during iterations
  • KeepEach: saves all updates for a variable for all iterations

Note: if iterations is specified as a number, the inference function tracks and returns every update for each iteration for every random variable in the model (equivalent to KeepEach()). If the number of iterations is set to nothing, the inference function saves only the last (and the only) update for every random variable in the model (equivalent to KeepLast()). Use iterations = 1 to force the KeepEach() setting when the number of iterations is equal to 1, or set returnvars = KeepEach() manually.

Example:

result = inference(
    ...,
    returnvars = (
        x = KeepLast(),
        τ = KeepEach()
    )
)

It is also possible to set either returnvars = KeepLast() or returnvars = KeepEach(), which acts as an alias and sets the given option for all random variables in the model.

Example:

result = inference(
    ...,
    returnvars = KeepLast()
)
  • predictvars

predictvars specifies the variables which should be predicted. In the model definition these variables are specified as datavars, but they should not be passed inside the data argument.

Similar to returnvars, predictvars accepts a NamedTuple or Dict. There are two specifications:

  • KeepLast: saves the last update for a variable, ignoring any intermediate results during iterations
  • KeepEach: saves all updates for a variable for all iterations

Example:

result = inference(
    ...,
    predictvars = (
        o = KeepLast(),
        τ = KeepEach()
    )
)
  • iterations

Specifies the number of variational (or loopy belief propagation) iterations. By default set to nothing, which is equivalent to performing a single iteration.
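For example:

```julia
result = inference(
    ...,
    iterations = 10 # run 10 variational message passing iterations
)
```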

  • free_energy

This setting specifies whether the inference function should return Bethe Free Energy (BFE) values. Note, however, that it may not be possible to compute BFE values for every model.

Additionally, the argument may accept a floating point type instead of a Bool value. Using this option, e.g. Float64, improves the performance of the Bethe Free Energy computation, but restricts the use of automatic differentiation packages.

  • free_energy_diagnostics

This setting specifies either a single diagnostic check or a tuple of diagnostic checks for the Bethe Free Energy values stream. By default it checks for NaNs and Infs. See also BetheFreeEnergyCheckNaNs and BetheFreeEnergyCheckInfs. Pass nothing to disable all checks.

  • callbacks

The inference function has its own lifecycle. The user is free to provide some (or none) of the callbacks to inject some extra logging or other procedures in the inference function, e.g.

result = inference(
    ...,
    callbacks = (
        on_marginal_update = (model, name, update) -> println("$(name) has been updated: $(update)"),
        after_inference    = (args...) -> println("Inference has been completed")
    )
)

The callbacks keyword argument accepts a named tuple of 'name = callback' pairs. The list of all possible callbacks and their arguments is presented below:

  • on_marginal_update: args: (model::FactorGraphModel, name::Symbol, update)
  • before_model_creation: args: ()
  • after_model_creation: args: (model::FactorGraphModel, returnval)
  • before_inference: args: (model::FactorGraphModel)
  • before_iteration: args: (model::FactorGraphModel, iteration::Int)::Bool
  • before_data_update: args: (model::FactorGraphModel, data)
  • after_data_update: args: (model::FactorGraphModel, data)
  • after_iteration: args: (model::FactorGraphModel, iteration::Int)::Bool
  • after_inference: args: (model::FactorGraphModel)

The before_iteration and after_iteration callbacks are allowed to return a true/false value. true indicates that iterations must be halted and no further inference should be performed.
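For instance, an early-stopping callback might look like this (the stopping condition here is a stand-in for a real convergence check; note the trailing comma for the single-entry NamedTuple):

```julia
result = inference(
    ...,
    callbacks = (
        # halt the iterations early, here simply after the 5th one
        after_iteration = (model, iteration) -> iteration >= 5,
    )
)
```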

  • addons

The addons field extends the default message computation rules with some extra information, e.g. computing log-scaling factors of messages or saving debug information. Accepts a single addon or a tuple of addons. If set, replaces the corresponding setting in the options. Automatically changes the default value of the postprocess argument to NoopPostprocess.
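A sketch of enabling a single addon, assuming the AddonLogScale addon (exported via ReactiveMP) is available:

```julia
result = inference(
    ...,
    addons = AddonLogScale() # attach log-scaling factors to computed messages
)
```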

  • postprocess

The postprocess keyword argument controls whether the inference results must be modified in some way before exiting the inference function. By default, the inference function uses the DefaultPostprocess strategy, which removes the Marginal wrapper type from the results. Change this setting to NoopPostprocess if you would like to keep the Marginal wrapper type, which might be useful in combination with the addons argument. If the addons argument has been used, the default strategy value automatically changes to NoopPostprocess.

  • catch_exception

The catch_exception keyword argument specifies whether exceptions during the inference procedure should be caught and stored in the error field of the result. By default, if an exception occurs during the inference procedure, the result will be lost. Set catch_exception = true to obtain a partial result for the inference in case an exception occurs. Use the RxInfer.issuccess and RxInfer.iserror functions to check whether the inference completed successfully or failed. If an error occurs, the error field stores a tuple, where the first element is the exception itself and the second element is the caught backtrace. Use the stacktrace function with the backtrace as an argument to recover the stacktrace of the error. Use the Base.showerror function to display the error.
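A minimal sketch of inspecting a caught error:

```julia
result = inference(
    ...,
    catch_exception = true
)

if RxInfer.iserror(result)
    err, backtrace = result.error      # the exception and its backtrace
    Base.showerror(stderr, err)        # display the error
end
```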

See also: InferenceResult, rxinference

RxInfer.InferenceResult — Type
InferenceResult

This structure is used as a return value from the inference function.

Public Fields

  • posteriors: Dict or NamedTuple of 'random variable' - 'posterior' pairs. See the returnvars argument for inference.
  • free_energy: (optional) An array of Bethe Free Energy values per VMP iteration. See the free_energy argument for inference.
  • model: FactorGraphModel object reference.
  • returnval: Return value from executed @model.
  • error: (optional) A reference to an exception, that might have occurred during the inference. See the catch_exception argument for inference.

See also: inference
