This example has been auto-generated from the examples/ folder at the GitHub repository.
Simple Nonlinear Node
# Activate local environment, see `Project.toml`
import Pkg; Pkg.activate(".."); Pkg.instantiate();
using RxInfer, Random, StableRNGs
Here is an example of creating a custom node that approximates a nonlinear function with a sample list.
Custom node creation
struct NonlinearNode end # Dummy structure just to make Julia happy
struct NonlinearMeta{R, F}
    rng      :: R
    fn       :: F   # Nonlinear function, we assume 1 float input - 1 float output
    nsamples :: Int # Number of samples used in approximation
end
@node NonlinearNode Deterministic [ out, in ]
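For illustration only, a meta object for this node could be constructed directly; the RNG seed, nonlinearity, and sample count below are placeholder values, not part of the original example:
# Illustration only: a meta object with a fixed RNG, a simple positive
# nonlinearity and 1000 samples (placeholder values)
example_meta = NonlinearMeta(StableRNG(42), x -> abs(sin(x)), 1000)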
We need to define two Sum-product message computation rules for our new custom node:
- Rule for the outbound message on the `out` edge given the inbound message on the `in` edge
- Rule for the outbound message on the `in` edge given the inbound message on the `out` edge
- Both rules accept an optional meta object
# Rule for outbound message on `out` edge given inbound message on `in` edge
@rule NonlinearNode(:out, Marginalisation) (m_in::NormalMeanVariance, meta::NonlinearMeta) = begin
samples = rand(meta.rng, m_in, meta.nsamples)
return SampleList(map(meta.fn, samples))
end
# Rule for outbound message on `in` edge given inbound message on `out` edge
@rule NonlinearNode(:in, Marginalisation) (m_out::Gamma, meta::NonlinearMeta) = begin
return ContinuousUnivariateLogPdf((x) -> logpdf(m_out, meta.fn(x)))
end
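As a quick sanity check (with hypothetical values, not taken from the example), the message returned by the `:in` rule is just an unnormalised log-density: evaluating it at a point is the same as pushing that point through the nonlinearity and evaluating the inbound Gamma message there.
# Hypothetical values for a sanity check of the `:in` rule's message
check_fn   = x -> abs(exp(x) * sin(x))
check_mout = Gamma(2.0, 3.0) # hypothetical inbound message on the `out` edge
check_msg  = ContinuousUnivariateLogPdf((x) -> logpdf(check_mout, check_fn(x)))
logpdf(check_msg, -1.0) ≈ logpdf(check_mout, check_fn(-1.0)) # true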
Model specification
After we have defined our custom node and its rules, we may proceed with the model specification:
\[\begin{aligned} p(\theta) &= \mathcal{N}(\theta|\mu_{\theta}, \sigma_{\theta}),\\ p(m) &= \mathcal{N}(m|\mu_{m}, \sigma_{m}),\\ w &= f(\theta),\\ p(y_i|m, w) &= \mathcal{N}(y_i|m, w), \end{aligned}\]
Given this IID model, we aim to estimate the precision of a Gaussian distribution. We pass the random variable $\theta$ through a nonlinear transformation $f$ to make it positive and suitable as the precision parameter of a Gaussian distribution. Later on, we will estimate the posterior of $\theta$.
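For instance, with the nonlinearity $f(x) = |e^{x}\sin(x)|$ used later in this example, $f(-1) = e^{-1}\sin(1) \approx 0.31 > 0$, so the transformed value is always a valid (positive) precision.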
@model function nonlinear_estimation(n)
    θ ~ Normal(mean = 0.0, variance = 100.0)
    m ~ Normal(mean = 0.0, variance = 1.0)
    w ~ NonlinearNode(θ)
    y = datavar(Float64, n)
    for i in 1:n
        y[i] ~ Normal(mean = m, precision = w)
    end
end
@constraints function nconstraints(nsamples)
    q(θ) :: SampleList(nsamples, LeftProposal())
    q(w) :: SampleList(nsamples, RightProposal())
    q(θ, w, m) = q(θ)q(m)q(w)
end
nconstraints (generic function with 1 method)
@meta function nmeta(fn, nsamples)
    NonlinearNode(θ, w) -> NonlinearMeta(StableRNG(123), fn, nsamples)
end
nmeta (generic function with 1 method)
Here we generate some data
nonlinear_fn(x) = abs(exp(x) * sin(x))
nonlinear_fn (generic function with 1 method)
seed = 123
rng = StableRNG(seed)
niters = 15 # Number of VMP iterations
nsamples = 5_000 # Number of samples in approximation
n = 500 # Number of IID samples
μ = -10.0
θ = -1.0
w = nonlinear_fn(θ)
data = rand(rng, NormalMeanPrecision(μ, w), n);
result = inference(
    model         = nonlinear_estimation(n),
    meta          = nmeta(nonlinear_fn, nsamples),
    constraints   = nconstraints(nsamples),
    data          = (y = data, ),
    initmarginals = (m = vague(NormalMeanPrecision), w = vague(Gamma)),
    returnvars    = (θ = KeepLast(), ),
    iterations    = niters,
    showprogress  = true
)
Inference results:
Posteriors | available for (θ)
θposterior = result.posteriors[:θ]
SampleList(Univariate, 5000)
using Plots, StatsPlots
estimated = Normal(mean_std(θposterior)...)
plot(estimated, title="Posterior for θ", label = "Estimated", legend = :bottomright, fill = true, fillopacity = 0.2, xlim = (-3, 3), ylim = (0, 2))
vline!([ θ ], label = "Real value of θ")
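As an optional extra check (not part of the original example), one could push a point estimate of θ through the nonlinearity and compare it with the true precision used to generate the data:
# Rough point-estimate check: f(E[θ]) versus the true precision w
println("f(mean(θ)) ≈ ", nonlinear_fn(mean(θposterior)), " | true w ≈ ", w)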