Example: Linear Regression

In this example we solve a simple linear regression problem, but in a Bayesian setting. We specify the model:

\[y_i = a x_i + b\]

where $a$ and $b$ are random variables with some vague priors.
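
Concretely, with the priors and the unit-variance Gaussian observation noise used in the code below, the full generative model reads

\[
\begin{aligned}
a &\sim \mathcal{N}(0, 1), \\
b &\sim \mathcal{N}(0, 100), \\
y_i &\sim \mathcal{N}(a x_i + b, 1), \qquad i = 1, \dots, n.
\end{aligned}
\]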

The first step is to import all the required packages and define the model:

using ReactiveMP, GraphPPL, Rocket, Random, Plots, StableRNGs, BenchmarkTools

Model specification

@model function linear_regression(n)
    # Vague Gaussian priors for the slope and the intercept
    a ~ NormalMeanVariance(0.0, 1.0)
    b ~ NormalMeanVariance(0.0, 100.0)

    # Placeholders for the observed data, to be supplied at inference time
    x = datavar(Float64, n)
    y = datavar(Float64, n)

    # Likelihood: each observation is a noisy linear function of the input
    for i in 1:n
        y[i] ~ NormalMeanVariance(a * x[i] + b, 1.0)
    end

    return a, b, x, y
end
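
For reference (this closed-form derivation is not part of the original example): because the model is linear and Gaussian, the exact joint posterior over $\theta = (a, b)$ is available analytically. With design matrix $X = [\, x \;\; \mathbf{1} \,]$, prior $\theta \sim \mathcal{N}(\mu_0, \Sigma_0)$ where $\mu_0 = (0, 0)^\top$ and $\Sigma_0 = \mathrm{diag}(1, 100)$, and unit noise variance, standard Bayesian linear regression gives

\[
p(\theta \mid X, y) = \mathcal{N}(\theta \mid \mu_n, \Sigma_n), \qquad
\Sigma_n = \left(\Sigma_0^{-1} + X^\top X\right)^{-1}, \qquad
\mu_n = \Sigma_n \left(\Sigma_0^{-1} \mu_0 + X^\top y\right),
\]

and the message-passing procedure below approximates the marginals of this posterior.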

Dataset

To test our inference procedure, we create a dataset in which the observations are corrupted with Gaussian white noise (of known variance).

reala = 0.5 # true slope
realb = 25  # true intercept

N = 250     # number of observations

rng = StableRNG(1234)

xorig = collect(1:N)

xdata = xorig .+ randn(rng, N)                    # inputs observed with noise
ydata = realb .+ reala .* xorig .+ randn(rng, N); # outputs generated from the noiseless inputs

plot(xdata, label = "X", title = "Linear regression dataset")
plot!(ydata, label = "Y")

Inference
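
Both coefficients are shared by every observation, so the factor graph of this model contains loops. We therefore provide an initial message for $b$ and run several iterations of loopy message passing; `KeepLast()` tells `inference` to return only the marginals from the final iteration.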

results = inference(
    model = Model(linear_regression, length(xdata)),
    data  = (y = ydata, x = xdata),
    initmessages = (b = NormalMeanVariance(0.0, 100.0), ),
    returnvars   = (a = KeepLast(), b = KeepLast()),
    iterations   = 20
);
Inference results:
-----------------------------------------
a = NormalWeightedMeanPrecision{Float64}(xi=2.613625160215102e6, w=5.221277087207961...
b = NormalWeightedMeanPrecision{Float64}(xi=6198.9695185245455, w=249.00599724152573...
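The posterior marginals are returned in the weighted-mean/precision parameterization. They can be inspected in the more familiar mean/variance form via the standard `mean` and `var` methods (a quick sketch; these lines are not part of the original example):

println("a: mean = ", mean(results.posteriors[:a]), " var = ", var(results.posteriors[:a]))
println("b: mean = ", mean(results.posteriors[:b]), " var = ", var(results.posteriors[:b]))

Next we visualise the prior and the resulting posterior for each coefficient.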
pra = plot(range(-3, 3, length = 1000), (x) -> pdf(NormalMeanVariance(0.0, 1.0), x), title="Prior for a parameter", fillalpha=0.3, fillrange = 0, label="Prior P(a)", c=1)
pra = vline!(pra, [ reala ], label="Real a", c = 3)
psa = plot(range(0.45, 0.55, length = 1000), (x) -> pdf(results.posteriors[:a], x), title="Posterior for a parameter", fillalpha=0.3, fillrange = 0, label="Posterior P(a)", c=2)
psa = vline!(psa, [ reala ], label="Real a", c = 3)

plot(pra, psa, size = (1000, 200))
prb = plot(range(-40, 40, length = 1000), (x) -> pdf(NormalMeanVariance(0.0, 100.0), x), title="Prior for b parameter", fillalpha=0.3, fillrange = 0, label="Prior P(b)", c=1, legend = :topleft)
prb = vline!(prb, [ realb ], label="Real b", c = 3)
psb = plot(range(23, 28, length = 1000), (x) -> pdf(results.posteriors[:b], x), title="Posterior for b parameter", fillalpha=0.3, fillrange = 0, label="Posterior P(b)", c=2, legend = :topleft)
psb = vline!(psb, [ realb ], label="Real b", c = 3)

plot(prb, psb, size = (1000, 200))
a = results.posteriors[:a]
b = results.posteriors[:b]

println("Real a: ", reala, " | Estimated a: ", mean(a), " | Error: ", abs(mean(a) - reala))
println("Real b: ", realb, " | Estimated b: ", mean(b), " | Error: ", abs(mean(b) - realb))
Real a: 0.5 | Estimated a: 0.5005720088325591 | Error: 0.0005720088325591455
Real b: 25 | Estimated b: 24.894860313391554 | Error: 0.10513968660844597
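
As an additional sanity check (a sketch, not part of the original example), we can compare these estimates against the ordinary least-squares solution, which for this linear-Gaussian model should be close to the posterior means. The `ols` helper below is hypothetical:

# Hypothetical ordinary least-squares baseline for comparison
function ols(x::AbstractVector, y::AbstractVector)
    X = hcat(x, ones(length(x))) # design matrix with an intercept column
    coefs = X \ y                # least-squares solve
    return (a = coefs[1], b = coefs[2])
end

olsfit = ols(xdata, ydata)
println("OLS a: ", olsfit.a, " | OLS b: ", olsfit.b)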

We can see that ReactiveMP.jl estimated the true values of the linear regression coefficients with high precision. Let's also benchmark the resulting inference procedure.

@benchmark inference(
    model = Model($linear_regression, length($xdata)),
    data  = (y = $ydata, x = $xdata),
    initmessages = (b = NormalMeanVariance(0.0, 100.0), ),
    returnvars   = (a = KeepLast(), b = KeepLast()),
    iterations = 20
)
BenchmarkTools.Trial: 62 samples with 1 evaluation.
 Range (min … max):  67.509 ms … 153.554 ms  ┊ GC (min … max): 0.00% … 40.23%
 Time  (median):     73.461 ms               ┊ GC (median):    0.00%
 Time  (mean ± σ):   81.513 ms ±  21.032 ms  ┊ GC (mean ± σ):  8.73% ± 13.88%

    ▁▄█▁                                                        
  ▃▇██████▄▄▄▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▃▁▁▃▁▄▁▁▁▁▃▁▁▁▁▁▁▄ ▁
  67.5 ms         Histogram: frequency by time          144 ms <

 Memory estimate: 23.73 MiB, allocs estimate: 500957.