Automatic inference specification on real-time datasets

RxInfer exports the rxinference function to quickly run and test your model with dynamic and potentially real-time datasets. This function covers almost all capabilities of the reactive inference engine, but for advanced use cases you may want to resort to the manual inference specification.

For running inference on static datasets see the Static Inference section. For manual inference specification see the Manual Inference section.

RxInfer.rxinference - Function
rxinference(
    model,
    data = nothing,
    datastream = nothing,
    initmarginals = nothing,
    initmessages = nothing,
    autoupdates = nothing,
    constraints = nothing,
    meta = nothing,
    options = nothing,
    returnvars = nothing,
    historyvars = nothing,
    keephistory = nothing,
    iterations = nothing,
    free_energy = false,
    free_energy_diagnostics = BetheFreeEnergyDefaultChecks,
    autostart = true,
    events = nothing,
    callbacks = nothing,
    addons = nothing,
    postprocess = DefaultPostprocess(),
    uselock = false,
    warn = true
)

This function provides a generic way to perform probabilistic inference in RxInfer.jl. It returns an RxInferenceEngine.

Arguments

For more information about some of the arguments, please check below.

  • model: specifies a model generator, required
  • data: NamedTuple or Dict with data, required (or datastream)
  • datastream: A stream of NamedTuple with data, required (or data)
  • initmarginals = nothing: NamedTuple or Dict with initial marginals, optional
  • initmessages = nothing: NamedTuple or Dict with initial messages, optional
  • autoupdates = nothing: auto-updates specification, required for many models, see @autoupdates
  • constraints = nothing: constraints specification object, optional, see @constraints
  • meta = nothing: meta specification object, optional, may be required for some models, see @meta
  • options = nothing: model creation options, optional, see ModelInferenceOptions
  • returnvars = nothing: return structure info, optional, by default creates observables for all random variables that return posteriors at the last VMP iteration, see below for more information
  • historyvars = nothing: history structure info, optional, defaults to no history, see below for more information
  • keephistory = nothing: history buffer size, defaults to empty buffer, see below for more information
  • iterations = nothing: number of iterations, optional, defaults to nothing. The inference engine does not distinguish between variational message passing, loopy belief propagation, and expectation propagation iterations, see below for more information
  • free_energy = false: compute the Bethe free energy, optional, defaults to false. Can be passed a floating point type, e.g. Float64, for better efficiency, but disables automatic differentiation packages, such as ForwardDiff.jl
  • free_energy_diagnostics = BetheFreeEnergyDefaultChecks: free energy diagnostic checks, optional, by default checks for possible NaNs and Infs. nothing disables all checks.
  • autostart = true: specifies whether to call RxInfer.start on the created engine automatically or not
  • showprogress = false: show progress module, optional, defaults to false
  • events = nothing: inference cycle events, optional, see below for more info
  • callbacks = nothing: inference cycle callbacks, optional, see below for more info
  • addons = nothing: inject and send extra computation information along messages, see below for more info
  • postprocess = DefaultPostprocess(): inference results postprocessing step, optional, see below for more info
  • uselock = false: specifies whether to use a lock structure for the inference, if set to true uses Base.Threads.SpinLock. Also accepts a custom AbstractLock.
  • warn = true: enables/disables warnings

Note on NamedTuples

When passing a NamedTuple as a value for some argument, make sure you use a trailing comma for NamedTuples with a single entry. The reason is that Julia treats the returnvars = (x = KeepLast()) and returnvars = (x = KeepLast(), ) expressions differently. The first expression creates (or overwrites!) a local/global variable named x with contents KeepLast(). The second expression (note the trailing comma) creates a NamedTuple with x as a key and KeepLast() as the value assigned to this key.
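The difference can be checked in plain Julia without loading RxInfer. The following is a minimal sketch; the names assigned and single are illustrative only, and the integer 1 stands in for a value such as KeepLast().

```julia
# `(x = 1)` is a parenthesized assignment: it binds the variable `x`
# and the whole expression evaluates to the assigned value.
assigned = (x = 1)

# `(x = 1,)` (note the trailing comma) builds a NamedTuple with the key `:x`.
single = (x = 1,)

println(assigned isa NamedTuple)  # the first form is NOT a NamedTuple
println(single isa NamedTuple)    # the second form is
println(keys(single))
```

If an argument such as returnvars, historyvars, initmarginals or data has only one entry, the second form is the one to use.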

Extended information about some of the arguments

  • data or datastream

Either data or datastream keyword argument is required, but specifying both is not supported and will result in an error.

  • data

The data keyword argument must be a NamedTuple (or Dict) where keys (of Symbol type) correspond to all datavars defined in the model specification. For example, if a model defines x = datavar(Float64) the data field must have an :x key (of Symbol type) which holds an iterable container with values of type Float64. The elements of such containers in the data must have the exact same shape as the datavar container. In other words, if a model defines x = datavar(Float64, n) then data[:x] must provide an iterable container with elements of type Vector{Float64}.

All entries in the data argument are zipped together with the Base.zip function to form one slice of the data chunk. This means all containers in the data argument must be of the same size (the zip iterator finishes as soon as one container has no remaining values). To use a fixed value for a specific datavar, it is not necessary to create a container filled with that value; it is more efficient to use Iterators.repeated to create an infinite iterator.
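The zipping behaviour can be illustrated with Base functions alone. The sketch below is not RxInfer's internal implementation; it only shows how Base.zip pairs up the entries of a NamedTuple and how Iterators.repeated supplies a fixed value without allocating a container. The keys y and c are hypothetical datavar names.

```julia
# Hypothetical data: three observations for `y` and a fixed value for `c`.
data = (y = [1.0, 2.0, 3.0], c = Iterators.repeated(0.5))

# Zip all entries together; each element of the zip is one slice of the data.
slices = NamedTuple[]
for vals in zip(values(data)...)
    # Re-attach the original keys to the zipped values of this slice.
    push!(slices, NamedTuple{keys(data)}(vals))
end

# The infinite iterator does not extend the zip: iteration stops as soon as
# the shortest (here the only finite) container is exhausted.
println(length(slices))
println(slices[1])
```

Because zip stops at the shortest container, a container that is accidentally too short silently truncates the dataset, so it is worth checking lengths before passing data.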

Note: The behavior of the data keyword argument is different from that which is used in the inference function.

  • datastream

The datastream keyword argument must be an observable that supports subscribe! and unsubscribe! functions (streams from the Rocket.jl package are also supported). The elements of the observable must be of type NamedTuple where keys (of Symbol type) correspond to all datavars defined in the model specification, except for those which are listed in the autoupdates specification. For example, if a model defines x = datavar(Float64) (which is not part of the autoupdates specification) the named tuple from the observable must have an :x key (of Symbol type) which holds a value of type Float64. The values in the named tuple must have the exact same shape as the datavar container. In other words, if a model defines x = datavar(Float64, n) then namedtuple[:x] must provide a container with length n and with elements of type Float64.

Note: The behavior of the individual named tuples from the datastream observable is similar to that which is used in the inference function and its data argument. In fact, you can see the rxinference function as an efficient streamed version of the inference function, which automatically updates some datavars with the autoupdates specification and listens to the datastream to update the rest of the datavars.

  • model

The model argument accepts a ModelGenerator as its input. The easiest way to create the ModelGenerator is to use the @model macro. For example:

@model function coin_toss(some_argument, some_keyword_argument = 3)
   ...
end

result = rxinference(
    model = coin_toss(some_argument; some_keyword_argument = 3)
)

Note: The model keyword argument does not accept a FactorGraphModel instance as a value, as it needs to inject constraints and meta during the inference procedure.

  • initmarginals

For specific types of inference algorithms, such as variational message passing, it might be required to initialize (some of) the marginals before running the inference procedure in order to break the dependency loop. If this is not done, the inference algorithm will not be executed due to the lack of information, and messages and/or marginals will not be updated. To specify these initial marginals, use the initmarginals argument, for example:

rxinference(...
    initmarginals = (
        # initialize the marginal distribution of x as a vague Normal distribution
        # if x is a vector, then it simply uses the same value for all elements
        # However, it is also possible to provide a vector of distributions to set each element individually 
        x = vague(NormalMeanPrecision),  
    ),
)

This argument needs to be a named tuple, i.e. initmarginals = (a = ..., ), or dictionary.

  • initmessages

For specific types of inference algorithms, such as loopy belief propagation or expectation propagation, it might be required to initialize (some of) the messages before running the inference procedure in order to break the dependency loop. If this is not done, the inference algorithm will not be executed due to the lack of information, and messages and/or marginals will not be updated. To specify these initial messages, use the initmessages argument, for example:

rxinference(...
    initmessages = (
        # initialize the message for x as a vague Normal distribution
        # if x is a vector, then it simply uses the same value for all elements
        # However, it is also possible to provide a vector of distributions to set each element individually 
        x = vague(NormalMeanPrecision),  
    ),
)

This argument needs to be a named tuple, i.e. initmessages = (a = ..., ), or dictionary.

  • autoupdates

See @autoupdates for more information.

  • options

  • limit_stack_depth: limits the stack depth for computing messages, helps with StackOverflowError for some huge models, but reduces the performance of the inference backend. Accepts an integer that specifies the maximum recursion depth. Lower values are safer against stack overflow errors, but worse for performance.

  • pipeline: changes the default pipeline for each factor node in the graph

  • global_reactive_scheduler: changes the scheduler of reactive streams, see Rocket.jl for more info, defaults to no scheduler

  • returnvars

returnvars accepts a tuple of symbols and specifies the latent variables of interest. For each symbol in the returnvars specification the rxinference function prepares an observable stream (see Rocket.jl) of posterior updates. An agent may subscribe to the new posterior events and perform some actions. For example:

engine = rxinference(
    ...,
    returnvars = (:x, :τ),
    autostart  = false
)

x_subscription = subscribe!(engine.posteriors[:x], (update) -> println("x variable has been updated: ", update))
τ_subscription = subscribe!(engine.posteriors[:τ], (update) -> println("τ variable has been updated: ", update))

RxInfer.start(engine)

...

unsubscribe!(x_subscription)
unsubscribe!(τ_subscription)

RxInfer.stop(engine)

  • historyvars

historyvars specifies the variables of interest and the amount of information to keep in history about the posterior updates. The specification is similar to the returnvars in the inference procedure. The historyvars argument requires keephistory to be greater than zero.

historyvars accepts a NamedTuple, a Dict, or a single return variable specification. There are two specifications:

  • KeepLast: saves the last update for a variable, ignoring any intermediate results during iterations
  • KeepEach: saves all updates for a variable for all iterations

Example:

result = rxinference(
    ...,
    historyvars = (
        x = KeepLast(),
        τ = KeepEach()
    ),
    keephistory = 10
)

It is also possible to set either historyvars = KeepLast() or historyvars = KeepEach() that acts as an alias and sets the given option for all random variables in the model.

Example:

result = rxinference(
    ...,
    historyvars = KeepLast(),
    keephistory = 10
)

  • keephistory

Specifies the buffer size for the updates history both for the historyvars and the free_energy buffers.

  • iterations

Specifies the number of variational (or loopy belief propagation) iterations. Set to nothing by default, which is equivalent to performing 1 iteration.

  • free_energy

This setting specifies whether the rxinference function should create an observable of Bethe Free Energy (BFE) values. The BFE observable returns a newly computed value for each VMP iteration. Note, however, that it may not be possible to compute BFE values for every model. If free_energy = true and keephistory > 0, the engine exposes extra fields to access the history of the Bethe free energy updates:

  • engine.free_energy_history: Returns a free energy history averaged over the VMP iterations
  • engine.free_energy_final_only_history: Returns a free energy history of values computed on last VMP iterations for every observation
  • engine.free_energy_raw_history: Returns a raw free energy history

Additionally, the argument accepts a floating point type instead of a Bool value. Using this option, e.g. Float64, improves the performance of the Bethe Free Energy computation, but prevents the use of automatic differentiation packages.

  • free_energy_diagnostics

This setting specifies either a single diagnostic check or a tuple of diagnostic checks for the stream of Bethe Free Energy values. By default it checks for NaNs and Infs. See also BetheFreeEnergyCheckNaNs and BetheFreeEnergyCheckInfs. Pass nothing to disable all checks.

  • events

The engine returned by the rxinference function has its own lifecycle. You can listen to its events by subscribing to the engine.events field, e.g.

engine = rxinference(
    ...,
    autostart = false
)

subscription = subscribe!(engine.events, (event) -> println(event))

RxInfer.start(engine)

By default all events are disabled. To enable an event, its identifier must be listed in the Val tuple of symbols passed to the events keyword argument.

engine = rxinference(
    events = Val((:on_new_data, :before_history_save, :after_history_save))
)

The list of all possible events and their event data is given below (see RxInferenceEvent for more information about the type of event data):

  • on_new_data: args: (model::FactorGraphModel, data)

  • before_iteration: args: (model::FactorGraphModel, iteration)

  • before_auto_update: args: (model::FactorGraphModel, iteration, auto_updates)

  • after_auto_update: args: (model::FactorGraphModel, iteration, auto_updates)

  • before_data_update: args: (model::FactorGraphModel, iteration, data)

  • after_data_update: args: (model::FactorGraphModel, iteration, data)

  • after_iteration: args: (model::FactorGraphModel, iteration)

  • before_history_save: args: (model::FactorGraphModel, )

  • after_history_save: args: (model::FactorGraphModel, )

  • on_tick: args: (model::FactorGraphModel, )

  • on_error: args: (model::FactorGraphModel, err)

  • on_complete: args: (model::FactorGraphModel, )

  • callbacks

The rxinference function has its own lifecycle. The user is free to provide some (or none) of the callbacks to inject extra logging or other procedures into the preparation of the inference engine. To inject extra procedures during the inference itself, use the events. Here is an example of the callbacks:

result = rxinference(
    ...,
    callbacks = (
        after_model_creation = (model, returnval) -> println("The model has been created. Number of nodes: $(length(getnodes(model)))"),
    )
)

The callbacks keyword argument accepts a named tuple of 'name = callback' pairs. The list of all possible callbacks and their input arguments is given below:

  • before_model_creation: args: ()

  • after_model_creation: args: (model::FactorGraphModel, returnval)

  • before_autostart: args: (engine::RxInferenceEngine)

  • after_autostart: args: (engine::RxInferenceEngine)

  • addons

The addons field extends the default message computation rules with some extra information, e.g. computing log-scaling factors of messages or saving debug-information. Accepts a single addon or a tuple of addons. If set, replaces the corresponding setting in the options. Automatically changes the default value of the postprocess argument to NoopPostprocess.

  • postprocess

The postprocess keyword argument controls whether the inference results must be modified in some way before exiting the inference function. By default, the inference function uses the DefaultPostprocess strategy, which by default removes the Marginal wrapper type from the results. Change this setting to NoopPostprocess if you would like to keep the Marginal wrapper type, which might be useful in the combination with the addons argument. If the addons argument has been used, automatically changes the default strategy value to NoopPostprocess.

See also inference

RxInfer.start - Function
start(engine::RxInferenceEngine)

Starts the RxInferenceEngine by subscribing to the data source, instantiating free energy (if enabled) and starting the event loop. Use RxInfer.stop to stop the RxInferenceEngine. Note that it is not always possible to stop/restart the engine and this depends on the data source type.

See also: RxInfer.stop

RxInfer.stop - Function
stop(engine::RxInferenceEngine)

Stops the RxInferenceEngine by unsubscribing from the data source and from the free energy stream (if enabled), and by stopping the event loop. Use RxInfer.start to start the RxInferenceEngine again. Note that it is not always possible to stop/restart the engine and this depends on the data source type.

See also: RxInfer.start

RxInfer.@autoupdates - Macro
@autoupdates

Creates the auto-updates specification for the rxinference function. In the online-streaming Bayesian inference procedure it is important to update your priors for the future states based on the newly updated posteriors. The @autoupdates structure simplifies such a specification. It accepts a single block of code where each line defines how to update the datavars in the probabilistic model specification.

Each line of code in the auto-update specification defines datavars, which need to be updated, on the left hand side of the equality expression and the update function on the right hand side of the expression. The update function operates on posterior marginals in the form of the q(symbol) expression.

For example:

@autoupdates begin 
    x = f(q(z))
end

This structure specifies that x = datavar(...) should be updated automatically as soon as the inference engine computes a new posterior over the z variable. It then applies the f function to the new posterior and calls update!(x, ...) automatically.

As an example consider the following model and auto-update specification:

@model function kalman_filter()
    x_current_mean = datavar(Float64)
    x_current_var  = datavar(Float64)

    x_current ~ Normal(mean = x_current_mean, var = x_current_var)

    x_next ~ Normal(mean = x_current, var = 1.0)

    y = datavar(Float64)
    y ~ Normal(mean = x_next, var = 1.0)
end

This model has two datavars that represent our prior knowledge of the x_current state of the system. The x_next random variable represents the next state of the system, which is connected to the observed variable y. The auto-update specification could look like:

autoupdates = @autoupdates begin
    x_current_mean, x_current_var = mean_cov(q(x_next))
end

This structure specifies to update our prior as soon as we have a new posterior q(x_next). It then applies the mean_cov function on the updated posteriors and updates datavars x_current_mean and x_current_var automatically.
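To make the feedback loop concrete, the sketch below reproduces by hand what this auto-update does for the kalman_filter model above. It does not use RxInfer or call rxinference; it relies on the fact that the model is linear and Gaussian, so the posterior q(x_next) has a closed form. The function names posterior_x_next and run_filter and the initial prior values are hypothetical, chosen only for illustration.

```julia
# Closed-form posterior of x_next for the model
#   x_current ~ Normal(m, v), x_next ~ Normal(x_current, 1), y ~ Normal(x_next, 1)
# Hypothetical helper, not part of RxInfer.
function posterior_x_next(m, v, y)
    prior_var = v + 1.0                        # variance of x_next before seeing y
    post_var  = 1.0 / (1.0 / prior_var + 1.0)  # combine with the unit-variance likelihood
    post_mean = post_var * (m / prior_var + y)
    return post_mean, post_var
end

# Emulates the auto-update: after each observation the posterior of x_next
# becomes the prior (x_current_mean, x_current_var) for the next observation.
function run_filter(ys; m = 0.0, v = 1.0)
    history = Tuple{Float64, Float64}[]
    for y in ys
        m, v = posterior_x_next(m, v, y)       # plays the role of mean_cov(q(x_next))
        push!(history, (m, v))
    end
    return history
end

history = run_filter([1.0, 1.0, 1.0])
for (step, (m, v)) in enumerate(history)
    println("step ", step, ": mean = ", m, ", var = ", v)
end
```

With rxinference the same loop is driven by the engine: each new observation triggers inference, and the @autoupdates specification writes the resulting mean and variance back into the two datavars before the next observation is processed.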

See also: rxinference

RxInfer.RxInferenceEngine - Type
RxInferenceEngine

The return value of the rxinference function.

Public fields

  • posteriors: Dict or NamedTuple of 'random variable' - 'posterior stream' pairs. See the returnvars argument for the rxinference.
  • free_energy: (optional) A stream of Bethe Free Energy values per VMP iteration. See the free_energy argument for the rxinference.
  • history: (optional) Saves history of previous marginal updates. See the historyvars and keephistory arguments for the rxinference.
  • free_energy_history: (optional) Free energy history, average over variational iterations
  • free_energy_raw_history: (optional) Free energy history, returns computed values of all variational iterations for each data event (if available)
  • free_energy_final_only_history: (optional) Free energy history, returns computed values of final variational iteration for each data event (if available)
  • events: (optional) A stream of events sent by the inference engine. See the events argument for the rxinference.
  • model: FactorGraphModel object reference.
  • returnval: Return value from executed @model.

Use the RxInfer.start(engine) function to subscribe to the data source and start the inference procedure. Use RxInfer.stop(engine) to unsubscribe from the data source and stop the inference procedure. Note that it is not always possible to start/stop the inference procedure.

See also: rxinference, RxInferenceEvent, RxInfer.start, RxInfer.stop

RxInfer.RxInferenceEvent - Type
RxInferenceEvent{T, D}

The RxInferenceEngine sends events in the form of the RxInferenceEvent structure. T represents the type of an event, D represents the type of the data associated with the event. The type of data depends on the type of the event, but it is usually a tuple, which can be unpacked with Julia's destructuring syntax, e.g. model, iteration = event. See the documentation of the rxinference function for possible event types and their associated data types.

The events system itself uses the Rocket.jl library API. For example, one may create a custom event listener in the following way:

using Rocket

struct MyEventListener <: Rocket.Actor{RxInferenceEvent}
    # ... extra fields
end

function Rocket.on_next!(listener::MyEventListener, event::RxInferenceEvent{ :after_iteration })
    model, iteration = event
    println("Iteration $(iteration) has been finished.")
end

function Rocket.on_error!(listener::MyEventListener, err)
    # ...
end

function Rocket.on_complete!(listener::MyEventListener)
    # ...
end

and later on:

engine = rxinference(events = Val((:after_iteration, )), ...)

subscription = subscribe!(engine.events, MyEventListener(...))

See also: rxinference, RxInferenceEngine
