Example: Custom nonlinear node

using ReactiveMP, Distributions, Random, BenchmarkTools, Rocket, GraphPPL, StableRNGs

Here is an example of creating a custom node whose nonlinear function is approximated with a sample list (`SampleList`).

Custom node structure

"""
    NonlinearNode

Field-less marker type. It carries no data and exists only so that the `@node`
and `@rule` declarations in this file have a type to refer to.
"""
struct NonlinearNode
end

"""
    NonlinearMeta(rng, fn, nsamples)

Settings read by the `NonlinearNode` message rules.

# Fields
- `rng`: random number generator the rules draw samples from
- `fn`: nonlinear function; assumed to take one float input and give one float output
- `nsamples`: number of samples used in the approximation
"""
struct NonlinearMeta{RNG, FN}
    rng::RNG
    fn::FN
    nsamples::Int
end
# Declare `NonlinearNode` to the `@node` macro as a `Deterministic` node with two
# interfaces named `out` and `in`; the `@rule` definitions below refer to these names.
@node NonlinearNode Deterministic [ out, in ]

We need to define two Sum-product message computation rules for our new custom node

  • Rule for outbound message on out edge given inbound message on in edge
  • Rule for outbound message on in edge given inbound message on out edge
  • Both rules accept an optional meta object
# Forward rule (`in` -> `out`): draw `meta.nsamples` values from the inbound message
# on `in`, apply `meta.fn` to each of them and wrap the results in a `SampleList`.
@rule NonlinearNode(:out, Marginalisation) (m_in::NormalMeanVariance, meta::NonlinearMeta) = begin
    draws = rand(meta.rng, m_in, meta.nsamples)
    return SampleList(meta.fn.(draws))
end

# Backward rule (`out` -> `in`): the outbound message on `in` is the log-density
# x -> logpdf(m_out, meta.fn(x)), wrapped in `ContinuousUnivariateLogPdf`.
@rule NonlinearNode(:in, Marginalisation) (m_out::Gamma, meta::NonlinearMeta) = begin
    backward_logpdf = (z) -> logpdf(m_out, meta.fn(z))
    return ContinuousUnivariateLogPdf(backward_logpdf)
end

Model specification

After we have defined our custom node with custom rules we may proceed with a model specification:

\[\begin{equation} \begin{aligned} p(\theta) &= \mathcal{N}(\theta|\mu_{\theta}, \sigma_{\theta}),\\ p(m) &= \mathcal{N}(m|\mu_{m}, \sigma_{m}),\\ w &= f(\theta),\\ p(y_i|m, w) &= \mathcal{N}(y_i|m, w^{-1}), \end{aligned} \end{equation}\]

Given this IID model, we aim to estimate the precision of a Gaussian distribution. We pass a random variable $\theta$ through a nonlinear transformation $f$ to make it positive and therefore suitable as the precision parameter of a Gaussian distribution. Later on, we will estimate the posterior of $\theta$.

# Model with `n` observations y[i], each distributed as NormalMeanPrecision(m, w),
# where w is the output of the custom NonlinearNode applied to the latent variable θ.
# Returns the tuple (θ, m, w, y).
@model function nonlinear_estimation(n)

    # Priors on the latent variable θ and on the mean m
    θ ~ NormalMeanVariance(0.0, 100.0)
    m ~ NormalMeanVariance(0.0, 1.0)

    # w is obtained by passing θ through the custom node
    w ~ NonlinearNode(θ)

    # y is declared with `datavar` (Float64, n entries); the inference call in this
    # file supplies its values through `data = (y = ...)`
    y = datavar(Float64, n)

    for i in 1:n
        y[i] ~ NormalMeanPrecision(m, w)
    end

    return θ, m, w, y
end
# Constraints specification for the model, parameterised by the number of samples.
# NOTE(review): the name is spelled `nconstsraints`; the calls in this file use the
# same spelling, so it must be renamed everywhere or not at all.
@constraints function nconstsraints(nsamples)
    # Form constraints: q(θ) and q(w) are SampleList(nsamples, ...) with
    # LeftProposal() and RightProposal() respectively
    q(θ) :: SampleList(nsamples, LeftProposal())
    q(w) :: SampleList(nsamples, RightProposal())

    # Factorisation constraint: the joint over θ, w and m is the product of the
    # three single-variable factors
    q(θ, w, m) = q(θ)q(m)q(w)
end
nconstsraints (generic function with 1 method)
# TODO: check
# NOTE(review): the TODO above does not say what needs checking. One candidate is the
# (θ, w) argument order below versus the [ out, in ] interface order of the node — confirm.
# Meta specification: gives the NonlinearNode over θ and w a `NonlinearMeta` built from
# a `StableRNG(123)`, the function `fn` and the sample count `nsamples`.
@meta function nmeta(fn, nsamples)
    NonlinearNode(θ, w) -> NonlinearMeta(StableRNG(123), fn, nsamples)
end
nmeta (generic function with 1 method)

Here we generate some data, using an arbitrary nonlinearity for the precision parameter:

# Arbitrary nonlinearity used in this example: |exp(x) * sin(x)|.
function nonlinear_fn(x)
    scaled = exp(x) * sin(x)
    return abs(scaled)
end
nonlinear_fn (generic function with 1 method)
# Seeded generator for the synthetic data set
seed = 123
rng  = MersenneTwister(seed)

# niters: number of VMP iterations; nsamples: number of samples in the approximation
niters, nsamples = 15, 5_000

# n IID samples, generated with mean μ and precision w = nonlinear_fn(θ)
n = 500
μ, θ = -10.0, -1.0
w = nonlinear_fn(θ)

data = rand(rng, NormalMeanPrecision(μ, w), n)

Inference

# Run `niters` iterations of inference on the generated data, starting from vague
# marginals for m and w and requesting θ with `KeepLast()`.
result = inference(;
    model         = Model(nonlinear_estimation, n),
    meta          = nmeta(nonlinear_fn, nsamples),
    constraints   = nconstsraints(nsamples),
    data          = (y = data,),
    initmarginals = (m = vague(NormalMeanPrecision), w = vague(Gamma)),
    returnvars    = (θ = KeepLast(),),
    iterations    = niters,
    showprogress  = true,
)
Inference results:
-----------------------------------------
θ = SampleList(Univariate, 5000)
using Plots, StatsPlots

# Summarise the θ posterior by a Normal built from its `mean_std`
θ_mean, θ_std = mean_std(result.posteriors[:θ])
estimated = Normal(θ_mean, θ_std)

plot(
    estimated;
    title = "Posterior for θ",
    label = "Estimated",
    legend = :bottomright,
    fill = true,
    fillopacity = 0.2,
    xlim = (-3, 3),
    ylim = (0, 2),
)
# Mark the value of θ that was used to generate the data
vline!([θ]; label = "Real value of θ")

Benchmark

# Benchmark the complete inference call. The model, meta, constraints, data and
# iteration count are `$`-interpolated, so they are built once and not re-evaluated
# in every sample.
# The progress meter is switched off: with `showprogress = true` every benchmark
# sample would print progress updates, adding terminal I/O to the measured time and
# flooding the output.
@benchmark inference(
    model = $(Model(nonlinear_estimation, n)),
    meta = $(nmeta(nonlinear_fn, nsamples)),
    constraints = $(nconstsraints(nsamples)),
    data = (y = $data, ),
    initmarginals = (m = vague(NormalMeanPrecision), w = vague(Gamma)),
    returnvars = (θ = KeepLast(), ),
    iterations = $niters,
    showprogress = false
)
BenchmarkTools.Trial: 76 samples with 1 evaluation.
 Range (min … max):  54.324 ms … 107.022 ms  ┊ GC (min … max): 0.00% … 31.05%
 Time  (median):     59.822 ms               ┊ GC (median):    0.00%
 Time  (mean ± σ):   66.289 ms ±  14.851 ms  ┊ GC (mean ± σ):  8.80% ± 12.81%

        █                                                       
  ▅▆▆▆▆▇█▆▆▆▃▄▁▅▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▃▁▁▁▃▁▄▃▄▁▃▃▃▄▁▁▁▁▃ ▁
  54.3 ms         Histogram: frequency by time          104 ms <

 Memory estimate: 24.80 MiB, allocs estimate: 342470.