Delta node manual
RxInfer.jl offers a comprehensive set of stochastic nodes, with a primary emphasis on distributions from the exponential family and their compositions, such as Gaussian with controlled variance (GCV) or autoregressive (AR) nodes. The DeltaNode stands out in this package: it represents a deterministic transformation of a single random variable or a group of them. This guide describes the DeltaNode and its functionality.
Features and Supported Inference Scenarios
The delta node supports several approximation methods for probabilistic inference. Which method is appropriate depends on the nodes connected to the delta node. We distinguish the following deterministic transformation scenarios:
- Gaussian Nodes: For delta nodes linked to strictly multivariate or univariate Gaussian distributions, the recommended methods are Linearization or Unscented transforms.
- Exponential Family Nodes: For delta nodes connected to nodes from the exponential family, CVI (Conjugate-computation Variational Inference) is the method of choice.
- Stacking Delta Nodes: For scenarios where delta nodes are stacked, either Linearization or Unscented transforms are suitable.
The table below summarizes the features of the delta node in RxInfer.jl, categorized by the approximation method:
Methods | Gaussian Nodes | Exponential Family Nodes | Stacking Delta Nodes |
---|---|---|---|
Linearization | ✓ | ✗ | ✓ |
Unscented | ✓ | ✗ | ✓ |
CVI | ✓ | ✓ | ✗ |
Gaussian Case
In the context of Gaussian distributions, we recommend either the Linearization or Unscented method for delta node approximation. The Linearization method provides a first-order approximation, while the Unscented method delivers a more precise second-order approximation. Note that although the Unscented method is more accurate, it may require hyperparameter tuning.
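To make the first-order idea concrete, here is a minimal sketch in plain Julia of what linearization does to a univariate Gaussian passed through a function: the function is replaced by its tangent line at the input mean. The helper `linearize_forward` and the explicit derivative argument are illustrative assumptions, not RxInfer internals.

```julia
# Hypothetical helper (not part of RxInfer): push x ~ N(m, v) through f
# using the first-order Taylor expansion f(x) ≈ f(m) + df(m) * (x - m).
function linearize_forward(f, df, m::Real, v::Real)
    slope = df(m)                      # derivative of f at the input mean
    return (mean = f(m), var = slope^2 * v)
end

dtanh(x) = 1 - tanh(x)^2               # analytic derivative of tanh

# Approximate the distribution of tanh(x) for x ~ N(0, 1).
approx = linearize_forward(tanh, dtanh, 0.0, 1.0)
println(approx)
```

Because only the slope at the mean is used, the approximation degrades when the input variance is large relative to the curvature of the function.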
For clarity, consider the following example:
using RxInfer
@model function delta_node_example()
z = datavar(Float64)
x ~ Normal(mean=0.0, var=1.0)
y ~ tanh(x)
z ~ Normal(mean=y, var=1.0)
end
To perform inference on this model, designate the approximation method for the delta node (here, the tanh function) using the @meta specification:
delta_meta = @meta begin
tanh() -> Linearization()
end
Meta specification:
tanh() -> Linearization()
Options:
warn = true
or
delta_meta = @meta begin
tanh() -> Unscented()
end
Meta specification:
tanh() -> Unscented{Float64, Float64, Float64, Nothing}(0.001, 2.0, 0.0, nothing)
Options:
warn = true
For a deeper understanding of the Unscented method and its parameters, consult the docstrings.
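As a rough illustration of what such parameters control, the sketch below implements the textbook scaled unscented transform for a scalar Gaussian in plain Julia. It assumes the three numbers printed in the meta specification above correspond to the usual `alpha`, `beta`, and `kappa` parameters; the function `unscented_forward` is hypothetical and may differ from what RxInfer does internally.

```julia
# Hypothetical helper (not RxInfer's implementation): textbook scaled
# unscented transform for a scalar x ~ N(m, v) pushed through f.
# Defaults are assumed to mirror the printed values (0.001, 2.0, 0.0).
function unscented_forward(f, m::Real, v::Real; alpha = 1e-3, beta = 2.0, kappa = 0.0)
    n = 1                                   # dimension of the input
    lambda = alpha^2 * (n + kappa) - n      # scaling parameter
    spread = sqrt((n + lambda) * v)         # distance of the outer sigma points from the mean
    points = (m, m + spread, m - spread)    # 2n + 1 sigma points

    w_outer = 1 / (2 * (n + lambda))        # weight of each outer point
    wm = (lambda / (n + lambda), w_outer, w_outer)                          # mean weights
    wc = (lambda / (n + lambda) + (1 - alpha^2 + beta), w_outer, w_outer)   # variance weights

    fy = map(f, points)                     # propagate the sigma points through f
    mean_y = sum(wm .* fy)
    var_y = sum(wc .* (fy .- mean_y) .^ 2)
    return (mean = mean_y, var = var_y)
end

# Sanity check: a linear map should be recovered exactly, up to rounding.
lin = unscented_forward(x -> 2x + 1, 0.5, 2.0; alpha = 1.0, kappa = 2.0)
println(lin)
```

In this formulation `alpha` and `kappa` set how far the sigma points sit from the mean, while `beta` only adjusts the weight of the central point in the variance estimate.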
Because tanh is invertible, specifying its inverse function can make the inference procedure more efficient:
delta_meta = @meta begin
tanh() -> DeltaMeta(method = Linearization(), inverse = atanh)
end
Meta specification:
tanh() -> DeltaMeta{Linearization, typeof(atanh)}(Linearization(), atanh)
Options:
warn = true
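One way to see how a known inverse can help: the message flowing backward from y to x can be approximated by applying the chosen method directly to the inverse function. The plain-Julia sketch below does this for tanh with a first-order expansion of atanh. The name `linearize_backward` is hypothetical, and the exact computation RxInfer performs may differ.

```julia
# Hypothetical helper (not part of RxInfer): approximate the backward message
# toward x, given a Gaussian message y ~ N(my, vy) on the output of y = tanh(x),
# by linearizing the inverse function around my.
function linearize_backward(finv, dfinv, my::Real, vy::Real)
    slope = dfinv(my)                  # derivative of the inverse at the output mean
    return (mean = finv(my), var = slope^2 * vy)
end

datanh(y) = 1 / (1 - y^2)              # analytic derivative of atanh, valid for |y| < 1

back = linearize_backward(atanh, datanh, 0.5, 0.1)
println(back)
```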
To execute the inference procedure:
inference(model = delta_node_example(), meta=delta_meta, data = (z = 1.0,))
Inference results:
Posteriors | available for (y, x)
The same approach applies when the delta node takes multiple inputs. For instance:
f(x, g) = x*tanh(g)
f (generic function with 1 method)
@model function delta_node_example()
z = datavar(Float64)
x ~ Normal(mean=1.0, var=1.0)
g ~ Normal(mean=1.0, var=1.0)
y ~ f(x, g)
z ~ Normal(mean=y, var=0.1)
end
The corresponding meta specification is:
delta_meta = @meta begin
f() -> DeltaMeta(method = Linearization())
end
Meta specification:
f() -> DeltaMeta{Linearization, Nothing}(Linearization(), nothing)
Options:
warn = true
or simply
delta_meta = @meta begin
f() -> Linearization()
end
Meta specification:
f() -> Linearization()
Options:
warn = true
If functions describing the backward relation for each variable of f are known, you can provide them as a tuple of inverse functions, in the order of the variables:
delta_meta = @meta begin
f() -> DeltaMeta(method = Linearization(), inverse=(f_back_x, f_back_g))
end
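The names `f_back_x` and `f_back_g` are not defined in this guide. As a hedged illustration, the sketch below shows what they could look like for `f(x, g) = x*tanh(g)`, assuming each inverse receives the output first, followed by the remaining inputs in their original order. Check the DeltaMeta docstrings for the signature actually expected.

```julia
f(x, g) = x * tanh(g)

# Hypothetical inverses: each one solves y = f(x, g) for a single input,
# taking the output y first and the other input second (assumed convention).
f_back_x(y, g) = y / tanh(g)           # solve for x; requires g != 0
f_back_g(y, x) = atanh(y / x)          # solve for g; requires |y / x| < 1

# Round-trip check on an arbitrary point.
x, g = 1.5, 0.7
y = f(x, g)
println((f_back_x(y, g), f_back_g(y, x)))
```

Both inverses are only defined on part of the input space, so a practical model would need priors that keep the variables away from the problematic regions.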
Exponential Family Case
When the delta node is connected to nodes from the exponential family (excluding Gaussians), the Linearization and Unscented methods are not applicable. In such cases, CVI (Conjugate-computation Variational Inference) is available. Here's a modified example:
using RxInfer
@model function delta_node_example1()
z = datavar(Float64)
x ~ Gamma(shape=1.0, rate=1.0)
y ~ tanh(x)
z ~ Bernoulli(y)
end
The corresponding meta specification can be represented as:
using StableRNGs
using Optimisers
delta_meta = @meta begin
tanh() -> DeltaMeta(method = CVI(StableRNG(42), 100, 100, Optimisers.Descent(0.01)))
end
Meta specification:
tanh() -> DeltaMeta{ProdCVI{StableRNGs.LehmerRNG, Optimisers.Descent{Float64}, ForwardDiffGrad{0}, true}, Nothing}(ProdCVI{StableRNGs.LehmerRNG, Optimisers.Descent{Float64}, ForwardDiffGrad{0}, true}(StableRNGs.LehmerRNG(state=0x00000000000000000000000000000055), 100, 100, Optimisers.Descent{Float64}(0.01), ForwardDiffGrad{0}(), 1, Val{true}(), true), nothing)
Options:
warn = true
Consult the ProdCVI docstrings for a detailed explanation of these parameters.