Normalizing flows: a tutorial

Table of contents

  1. Introduction
  2. Model specification
  3. Model compilation
  4. Probabilistic inference
  5. Parameter estimation

Introduction

Normalizing flows are parameterized mappings of random variables, which map simple base distributions to more complex distributions. These mappings are constrained to be invertible and differentiable and can be composed of multiple simpler mappings for improved expressivity.
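
As a brief aside, invertibility and differentiability matter because of the change-of-variables rule for densities. Writing the flow as ${\bf{y}} = g({\bf{x}})$ with base density $p_X$ (the symbols $g$, $p_X$ and $p_Y$ are notation introduced here for illustration only), the density of the transformed variable is

\[\begin{align*} p_Y({\bf{y}}) &= p_X\big(g^{-1}({\bf{y}})\big) \left| \det \frac{\partial g^{-1}({\bf{y}})}{\partial {\bf{y}}} \right| \end{align*}\]

For a composition $g = g_K \circ \cdots \circ g_1$, the Jacobian determinants of the individual mappings multiply, so their logarithms add up.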

Load required packages

Before we can start, we need to import some packages:

using ReactiveMP
using Rocket
using GraphPPL
using Random

Random.seed!(123)

using LinearAlgebra     # only used for some matrix specifics
using PyPlot            # only used for visualisation
using Distributions     # only used for sampling from multivariate distributions
using Optim             # only used for parameter optimisation

Model specification

Specifying a flow model is easy. The general recipe is as follows: model = FlowModel(input_dim, (layer1(options), layer2(options), ...)). Here the first argument is the input dimension of the model and the second argument is a tuple of layers. An example flow model can be defined as

model = FlowModel(2,
    (
        AdditiveCouplingLayer(PlanarFlow()),
        AdditiveCouplingLayer(PlanarFlow(); permute=false)
    )
)

Alternatively, the input dimension can be passed as an InputLayer, which is then the first element of the layer tuple:

model = FlowModel(
    (
        InputLayer(2),
        AdditiveCouplingLayer(PlanarFlow()),
        AdditiveCouplingLayer(PlanarFlow(); permute=false)
    )
)

In the above AdditiveCouplingLayer layers the input ${\bf{x}} = [x_1, x_2, \ldots, x_N]$ is partitioned into chunks of unit length. These partitions are additively coupled to an output ${\bf{y}} = [y_1, y_2, \ldots, y_N]$ as

\[\begin{align*} y_1 &= x_1 \\ y_2 &= x_2 + f_1(x_1) \\ \vdots \\ y_N &= x_N + f_{N-1}(x_{N-1}) \end{align*}\]

Importantly, this structure can easily be inverted as

\[\begin{align*} x_1 &= y_1 \\ x_2 &= y_2 - f_1(x_1) \\ \vdots \\ x_N &= y_N - f_{N-1}(x_{N-1}) \end{align*}\]

Here $f_n$ is an arbitrarily complex function, chosen in this tutorial to be a PlanarFlow, but it can be exchanged for any function or neural network. The permute keyword argument (which defaults to true) specifies whether the output of this layer should be randomly permuted or shuffled. This ensures that the first element is also transformed in subsequent layers.
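
The coupling equations and their inverse can be sketched in a few lines of plain Julia. The helper names coupling_forward and coupling_backward and the choice of tanh and sin for the functions $f_n$ are hypothetical and only serve as an illustration; this is not how ReactiveMP implements the layer.

```julia
# Hypothetical helpers for illustration; these are not ReactiveMP functions.

# Forward pass: y[1] = x[1], y[n] = x[n] + f_{n-1}(x[n-1])
function coupling_forward(x::AbstractVector, fs)
    y = similar(x)
    y[1] = x[1]
    for n in 2:length(x)
        y[n] = x[n] + fs[n-1](x[n-1])
    end
    return y
end

# Backward pass: x[1] = y[1], x[n] = y[n] - f_{n-1}(x[n-1])
function coupling_backward(y::AbstractVector, fs)
    x = similar(y)
    x[1] = y[1]
    for n in 2:length(y)
        # x[n-1] has already been recovered in the previous step
        x[n] = y[n] - fs[n-1](x[n-1])
    end
    return x
end

fs    = (tanh, sin)                 # arbitrary example functions f_1, f_2
x     = [0.3, -1.2, 2.0]
y     = coupling_forward(x, fs)
x_rec = coupling_backward(y, fs)    # recovers x up to floating-point error
```

Note that the inverse never requires inverting $f_n$ itself, which is why $f_n$ may be arbitrarily complex. The Jacobian of this mapping is lower triangular with ones on its diagonal, so its determinant equals 1.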

A permutation layer can also be added by itself as a PermutationLayer layer with a custom permutation matrix if desired.

model = FlowModel(
    (
        InputLayer(2),
        AdditiveCouplingLayer(PlanarFlow(); permute=false),
        PermutationLayer(PermutationMatrix(2)),
        AdditiveCouplingLayer(PlanarFlow(); permute=false)
    )
)
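
To illustrate what such a layer does, the following minimal sketch applies the 2×2 swap matrix [0 1; 1 0] to a vector using plain matrix algebra. It does not use any ReactiveMP types, and the variable names are chosen for illustration only.

```julia
using LinearAlgebra

P = [0 1; 1 0]        # permutation matrix that swaps two entries
x = [10.0, 20.0]

y     = P * x         # forward: entries are swapped
x_rec = P' * y        # backward: the transpose undoes the permutation
```

Because a permutation matrix is orthogonal, its inverse is its transpose and the absolute value of its determinant is 1, so the layer is trivially invertible and does not change the volume.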

Model compilation

In the current models, the layers are set up to work with the passed input dimension. This means that the function $f_n$ is repeated input_dim-1 times, once for each of the partitions. Furthermore, the permutation layers are set up with proper permutation matrices. If we print the model we get

model
FlowModel{3, Tuple{ReactiveMP.AdditiveCouplingLayerEmpty{Tuple{ReactiveMP.PlanarFlowEmpty{1}}}, PermutationLayer{Int64}, ReactiveMP.AdditiveCouplingLayerEmpty{Tuple{ReactiveMP.PlanarFlowEmpty{1}}}}}(2, (ReactiveMP.AdditiveCouplingLayerEmpty{Tuple{ReactiveMP.PlanarFlowEmpty{1}}}(2, (ReactiveMP.PlanarFlowEmpty{1}(),), 1), PermutationLayer{Int64}(2, [0 1; 1 0]), ReactiveMP.AdditiveCouplingLayerEmpty{Tuple{ReactiveMP.PlanarFlowEmpty{1}}}(2, (ReactiveMP.PlanarFlowEmpty{1}(),), 1)))

The list below describes the terms above. Please note the distinction between types and elements, i.e. FlowModel{types}(elements):

  • FlowModel - specifies that we are dealing with a flow model.
  • 3 - Number of layers.
  • Tuple{AdditiveCouplingLayerEmpty{...},PermutationLayer{Int64},AdditiveCouplingLayerEmpty{...}} - tuple of layer types.
  • Tuple{ReactiveMP.PlanarFlowEmpty{1}} - tuple of functions $f_n$ of a coupling layer (a single function here, since input_dim-1 = 1).
  • PermutationLayer{Int64}(2, [0 1; 1 0]) - permutation layer with input dimension 2 and permutation matrix [0 1; 1 0].

From inspection we can see that the AdditiveCouplingLayerEmpty and PlanarFlowEmpty objects differ from the layers we specified. They are initialized for the correct dimension, but they do not have any parameters registered to them. This is by design, so that the model specification stays separate from potential optimization procedures. Before we perform inference in this model, the parameters should be initialized. We can randomly initialize the parameters as

compiled_model = compile(model)
CompiledFlowModel{3, Tuple{AdditiveCouplingLayer{Tuple{PlanarFlow{Float64, Float64}}}, PermutationLayer{Int64}, AdditiveCouplingLayer{Tuple{PlanarFlow{Float64, Float64}}}}}(2, (AdditiveCouplingLayer{Tuple{PlanarFlow{Float64, Float64}}}(2, (PlanarFlow{Float64, Float64}(-1.6236037455860806, -0.21766510678354617, 0.4922456865251828),), 1), PermutationLayer{Int64}(2, [0 1; 1 0]), AdditiveCouplingLayer{Tuple{PlanarFlow{Float64, Float64}}}(2, (PlanarFlow{Float64, Float64}(0.9809798121241488, 0.0799568295050599, 1.5491245530427917),), 1)))

Probabilistic inference

We can perform inference in our compiled model through standard usage of ReactiveMP. Let's first generate some random 2D data, which is sampled from a normal distribution with mean [1.5, 0.5] and identity covariance and is subsequently passed through a normalizing flow. Using the forward(model, data) function we can propagate data in the forward direction through the flow.

function generate_data(nr_samples::Int64, model::CompiledFlowModel)

    # specify latent sampling distribution
    dist = MvNormal([1.5, 0.5], I)

    # sample from the latent distribution
    x = rand(dist, nr_samples)

    # transform data
    y = zeros(Float64, size(x))
    for k = 1:nr_samples
        y[:,k] .= ReactiveMP.forward(model, x[:,k])
    end

    # return data
    return y, x

end
generate_data (generic function with 1 method)
# generate data
y, x = generate_data(1000, compiled_model)

# plot generated data
_, ax = plt.subplots(ncols=2, figsize=(15,5))
ax[1].scatter(x[1,:], x[2,:], alpha=0.3)
ax[2].scatter(y[1,:], y[2,:], alpha=0.3)
ax[1].set_title("Original data")
ax[2].set_title("Transformed data")
ax[1].grid(), ax[2].grid()
plt.gcf()

The probabilistic model for doing inference can be described as

@model function normalizing_flow(nr_samples::Int64, compiled_model::CompiledFlowModel)

    # initialize variables
    z_μ   = randomvar()
    z_Λ   = randomvar()
    x     = randomvar(nr_samples)
    y_lat = randomvar(nr_samples)
    y     = datavar(Vector{Float64}, nr_samples)

    # specify prior
    z_μ ~ MvNormalMeanCovariance(zeros(2), huge*diagm(ones(2)))
    z_Λ ~ Wishart(2.0, tiny*diagm(ones(2)))

    # specify model
    meta = FlowMeta(compiled_model) # defaults to FlowMeta(compiled_model; approximation=Linearization()).
                                    # other approximation methods can be e.g. FlowMeta(compiled_model; approximation=Unscented(input_dim))

    # specify observations
    for k = 1:nr_samples

        # specify latent state
        x[k] ~ MvNormalMeanPrecision(z_μ, z_Λ) where { q = MeanField() }

        # specify transformed latent value
        y_lat[k] ~ Flow(x[k]) where { meta = meta }

        # specify observations
        y[k] ~ MvNormalMeanCovariance(y_lat[k], tiny*diagm(ones(2)))

    end

    # return variables
    return z_μ, z_Λ, x, y_lat, y

end
normalizing_flow (generic function with 1 method)

Here the flow model is passed inside a metadata object of the flow node. Inference can then be performed as

function inference_flow(data_y::Array{Array{Float64,1},1}, compiled_model::CompiledFlowModel; nr_iterations::Int64=10)

    # fetch number of samples
    nr_samples = length(data_y)

    # define model
    model, (z_μ, z_Λ, x, y_lat, y) = normalizing_flow(nr_samples, compiled_model)

    # initialize buffer for latent states
    mzμ = keep(Marginal)
    mzΛ = keep(Marginal)
    mx  = buffer(Marginal, nr_samples)
    my  = buffer(Marginal, nr_samples)

    # initialize free energy
    fe_values = Vector{Float64}()

    # subscribe to z
    zμ_sub = subscribe!(getmarginal(z_μ), mzμ)
    zΛ_sub = subscribe!(getmarginal(z_Λ), mzΛ)
    x_sub  = subscribe!(getmarginals(x), mx)
    y_sub  = subscribe!(getmarginals(y_lat), my)
    fe_sub = subscribe!(score(BetheFreeEnergy(), model), (fe) -> push!(fe_values, fe))

    # set initial marginals
    setmarginal!(z_μ, MvNormalMeanCovariance(zeros(2), huge*diagm(ones(2))))
    setmarginal!(z_Λ, Wishart(2.0, tiny*diagm(ones(2))))

    # update y according to observations (i.e. perform inference)
    for it = 1:nr_iterations
        ReactiveMP.update!(y, data_y)
    end

    # unsubscribe
    unsubscribe!([zμ_sub, zΛ_sub, x_sub, y_sub, fe_sub])

    # return the marginal values
    return getvalues(mzμ)[end], getvalues(mzΛ)[end], getvalues(mx), getvalues(my), fe_values

end;
inference_flow (generic function with 1 method)

The following line of code then executes the inference algorithm.

zμ_flow, zΛ_flow, x_flow, y_flow, fe_flow = inference_flow([y[:,k] for k=1:size(y,2)], compiled_model)

As we can see, the variational free energy of our model decreases over the iterations.

plt.figure()
plt.plot(1:10, fe_flow/size(y,2))
plt.grid()
plt.xlim(1,10)
plt.xlabel("iteration")
plt.ylabel("normalized variational free energy [nats/sample]")
plt.gcf()

If we plot a random noisy observation and its approximated transformed uncertainty we obtain:

# pick a random observation
id = rand(1:size(y,2))
rand_observation = MvNormal(y[:,id], 5e-1*diagm(ones(2)))
warped_observation = MvNormal(ReactiveMP.backward(compiled_model, y[:,id]), ReactiveMP.inv_jacobian(compiled_model, y[:,id])*5e-1*diagm(ones(2))*ReactiveMP.inv_jacobian(compiled_model, y[:,id])');

# plot inferred means and transformed point
fig, ax = plt.subplots(ncols = 2, figsize=(15,5))
ax[1].scatter(x[1,:], x[2,:], alpha=0.1, label="generated data")
ax[1].contour(repeat(-5:0.1:5, 1, 101), repeat(-5:0.1:5, 1, 101)', map( (x) -> pdf(MvNormal([1.5, 0.5], I), [x...]), collect(Iterators.product(-5:0.1:5, -5:0.1:5))), label="true distribution")
ax[1].scatter(mean(zμ_flow)[1], mean(zμ_flow)[2], color="red", marker="x", label="inferred mean")
ax[1].contour(repeat(-10:0.01:10, 1, 2001), repeat(-10:0.01:10, 1, 2001)', map( (x) -> pdf(warped_observation, [x...]), collect(Iterators.product(-10:0.01:10, -10:0.01:10))), colors="red", levels=1)
ax[1].scatter(mean(warped_observation)..., color="red", s=10, label="transformed noisy observation")
ax[2].scatter(y[1,:], y[2,:], alpha=0.1, label="generated data")
ax[2].scatter(ReactiveMP.forward(compiled_model, mean(zμ_flow))..., color="red", marker="x", label="inferred mean")
ax[2].contour(repeat(-10:0.1:10, 1, 201), repeat(-10:0.1:10, 1, 201)', map( (x) -> pdf(MvNormal([1.5, 0.5], I), ReactiveMP.backward(compiled_model, [x...])), collect(Iterators.product(-10:0.1:10, -10:0.1:10))))
ax[2].contour(repeat(-10:0.1:10, 1, 201), repeat(-10:0.1:10, 1, 201)', map( (x) -> pdf(rand_observation, [x...]), collect(Iterators.product(-10:0.1:10, -10:0.1:10))), colors="red", levels=1, label="random noisy observation")
ax[2].scatter(mean(rand_observation)..., color="red", s=10, label="random noisy observation")
ax[1].grid(), ax[2].grid()
ax[1].set_xlim(-4,4), ax[1].set_ylim(-4,4), ax[2].set_xlim(-10,10), ax[2].set_ylim(-10,10)
ax[1].legend(), ax[2].legend()
fig.suptitle("Generated data")
ax[1].set_title("Latent distribution"), ax[2].set_title("Observed distribution")
plt.gcf()
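
The covariance of the warped observation above is obtained by propagating the observation covariance through the inverse Jacobian of the flow. A minimal sketch of this linearized propagation for a toy 2D additive coupling is given below. The names toy_f, toy_df, toy_backward and toy_inv_jacobian are hypothetical stand-ins and are not the ReactiveMP functions used above.

```julia
using LinearAlgebra

# Toy 2D additive coupling: y1 = x1, y2 = x2 + f(x1), with f = tanh (illustrative choice)
toy_f(x)  = tanh(x)
toy_df(x) = 1 - tanh(x)^2                            # derivative of tanh

# Inverse mapping: x1 = y1, x2 = y2 - f(x1)
toy_backward(y) = [y[1], y[2] - toy_f(y[1])]

# Jacobian of the inverse mapping, evaluated at y
toy_inv_jacobian(y) = [1.0 0.0; -toy_df(y[1]) 1.0]

y  = [0.4, -1.1]                                     # an observed point
Σy = 0.5 * Matrix(1.0I, 2, 2)                        # observation covariance

Jinv = toy_inv_jacobian(y)
μx   = toy_backward(y)                               # warped mean
Σx   = Jinv * Σy * Jinv'                             # linearized warped covariance
```

For this toy coupling the inverse Jacobian has unit determinant, so the determinant of the covariance is unchanged, while its orientation and shape are not.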

Parameter estimation

The flow model is often used to learn unknown probabilistic mappings. Here we demonstrate this for a binary classification task with the following data:

function generate_data(nr_samples::Int64)

    # sample weights
    w = rand(nr_samples,2)

    # sample appraisal
    y = zeros(Float64, nr_samples)
    for k = 1:nr_samples
        y[k] = 1.0*(w[k,1] > 0.5)*(w[k,2] < 0.5)
    end

    # return data
    return y, w

end
generate_data (generic function with 2 methods)
data_y, data_x = generate_data(200);
plt.figure()
plt.scatter(data_x[:,1], data_x[:,2], c=data_y)
plt.grid()
plt.xlabel("w1")
plt.ylabel("w2")
plt.gcf()

We will then specify a possible flow model as

# specify flow model
model = FlowModel(2,
    (
        AdditiveCouplingLayer(PlanarFlow()), # defaults to AdditiveCouplingLayer(PlanarFlow(); permute=true)
        AdditiveCouplingLayer(PlanarFlow()),
        AdditiveCouplingLayer(PlanarFlow()),
        AdditiveCouplingLayer(PlanarFlow(); permute=false)
    )
);
FlowModel{7, Tuple{ReactiveMP.AdditiveCouplingLayerEmpty{Tuple{ReactiveMP.PlanarFlowEmpty{1}}}, PermutationLayer{Int64}, ReactiveMP.AdditiveCouplingLayerEmpty{Tuple{ReactiveMP.PlanarFlowEmpty{1}}}, PermutationLayer{Int64}, ReactiveMP.AdditiveCouplingLayerEmpty{Tuple{ReactiveMP.PlanarFlowEmpty{1}}}, PermutationLayer{Int64}, ReactiveMP.AdditiveCouplingLayerEmpty{Tuple{ReactiveMP.PlanarFlowEmpty{1}}}}}(2, (ReactiveMP.AdditiveCouplingLayerEmpty{Tuple{ReactiveMP.PlanarFlowEmpty{1}}}(2, (ReactiveMP.PlanarFlowEmpty{1}(),), 1), PermutationLayer{Int64}(2, [0 1; 1 0]), ReactiveMP.AdditiveCouplingLayerEmpty{Tuple{ReactiveMP.PlanarFlowEmpty{1}}}(2, (ReactiveMP.PlanarFlowEmpty{1}(),), 1), PermutationLayer{Int64}(2, [0 1; 1 0]), ReactiveMP.AdditiveCouplingLayerEmpty{Tuple{ReactiveMP.PlanarFlowEmpty{1}}}(2, (ReactiveMP.PlanarFlowEmpty{1}(),), 1), PermutationLayer{Int64}(2, [0 1; 1 0]), ReactiveMP.AdditiveCouplingLayerEmpty{Tuple{ReactiveMP.PlanarFlowEmpty{1}}}(2, (ReactiveMP.PlanarFlowEmpty{1}(),), 1)))

The corresponding probabilistic model for the binary classification task can be created as

@model function flow_classifier(nr_samples::Int64, model::FlowModel, params)

    # initialize variables
    x_lat  = randomvar(nr_samples)
    y_lat1 = randomvar(nr_samples)
    y_lat2 = randomvar(nr_samples)
    y      = datavar(Float64, nr_samples)
    x      = datavar(Vector{Float64}, nr_samples)

    # compile flow model
    meta  = FlowMeta(compile(model, params)) # defaults to FlowMeta(compiled_model; approximation=Linearization())

    # specify observations
    for k = 1:nr_samples

        # specify latent state
        x_lat[k] ~ MvNormalMeanPrecision(x[k], 1e3*diagm(ones(2)))

        # specify transformed latent value
        y_lat1[k] ~ Flow(x_lat[k]) where { meta = meta }
        y_lat2[k] ~ dot(y_lat1[k], [1, 1])

        # specify observations
        y[k] ~ Probit(y_lat2[k]) # default: where { pipeline = RequireInbound(in = NormalMeanPrecision(0, 1.0)) }

    end

    # return variables
    return x_lat, x, y_lat1, y_lat2, y

end
flow_classifier (generic function with 1 method)
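
Read as a generative description, the classifier above can be summarized as follows. This summary assumes that the Probit node denotes a Bernoulli likelihood whose success probability is the standard normal cumulative distribution function $\Phi$ of its input, which is consistent with the use of normcdf later in this tutorial. The symbol $g$ for the compiled flow and the superscripts are notation introduced here for illustration.

\[\begin{align*} {\bf{x}}^{\text{lat}}_k &\sim \mathcal{N}\big({\bf{x}}_k, 10^{-3} I\big) \\ {\bf{y}}^{(1)}_k &= g\big({\bf{x}}^{\text{lat}}_k\big) \\ y^{(2)}_k &= [1, 1] \, {\bf{y}}^{(1)}_k \\ p\big(y_k = 1 \mid y^{(2)}_k\big) &= \Phi\big(y^{(2)}_k\big) \end{align*}\]

The covariance $10^{-3} I$ corresponds to the precision 1e3*diagm(ones(2)) in the model, so the latent input is tightly tied to the observed input.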

Here we see that the compilation occurs inside of our probabilistic model. As a result, we can pass parameters (and a model) to this function, which we wish to optimize for some criterion, such as the variational free energy. Inference can be described as

function inference_flow_classifier(data_y::Array{Float64,1}, data_x::Array{Array{Float64,1},1}, model::FlowModel, params)

    # fetch number of samples
    nr_samples = length(data_y)

    # define model
    model, (x_lat, x, y_lat1, y_lat2, y) = flow_classifier(nr_samples, model, params)

    # initialize free energy
    fe_buffer = nothing

    # subscribe
    fe_sub = subscribe!(score(BetheFreeEnergy(), model), (fe) -> fe_buffer = fe)

    # update y and x according to observations (i.e. perform inference)
    ReactiveMP.update!(y, data_y)
    ReactiveMP.update!(x, data_x)

    # unsubscribe
    unsubscribe!(fe_sub)

    # return the marginal values
    return fe_buffer

end
inference_flow_classifier (generic function with 1 method)

For the optimization procedure, we will simplify our inference loop, such that it only accepts the parameters (which we wish to optimize) as an argument and outputs a performance metric.

function f(params)
    fe = inference_flow_classifier(data_y, [data_x[k,:] for k=1:size(data_x,1)], model, params)
    return fe
end
f (generic function with 1 method)

Optimization can be performed using the Optim package. Without an automatic differentiation option, the call

res = optimize(f, randn(nr_params(model)), LBFGS(), Optim.Options(g_tol = 1e-3, iterations = 100, store_trace = true, show_trace = true))

uses finite differences (finitediff) to approximate the gradient and is therefore slower and less accurate. Alternatively, other (custom) optimizers can be implemented, such as:

# load ForwardDiff, which provides ForwardDiff.gradient
using ForwardDiff

# create gradient function
g = (x) -> ForwardDiff.gradient(f, x);

# specify initial params
params = randn(nr_params(model))

# create custom optimizer (here Adam)
optimizer = Adam(params; λ=1e-1)

# allocate space for gradient
∇ = zeros(nr_params(model))

# perform optimization
for it = 1:10000

    # backward pass
    ∇ .= ForwardDiff.gradient(f, optimizer.x)

    # gradient update
    ReactiveMP.update!(optimizer, ∇)

end
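
For readers unfamiliar with Adam, the update performed in such a loop can be sketched as below. This is a minimal, self-contained version of the standard Adam update rule; the type SimpleAdam and the function step! are hypothetical names, and the sketch is not the ReactiveMP implementation, whose defaults and internals may differ.

```julia
# Minimal Adam sketch (hypothetical names; not the ReactiveMP implementation)
mutable struct SimpleAdam
    x::Vector{Float64}    # current parameter estimate
    m::Vector{Float64}    # running mean of the gradient
    v::Vector{Float64}    # running mean of the squared gradient
    t::Int                # iteration counter
    λ::Float64            # step size
    β1::Float64           # decay rate of the first moment
    β2::Float64           # decay rate of the second moment
    ϵ::Float64            # small constant for numerical stability
end

SimpleAdam(x; λ = 1e-3, β1 = 0.9, β2 = 0.999, ϵ = 1e-8) =
    SimpleAdam(copy(x), zero(x), zero(x), 0, λ, β1, β2, ϵ)

function step!(opt::SimpleAdam, ∇::AbstractVector)
    opt.t += 1
    @. opt.m = opt.β1 * opt.m + (1 - opt.β1) * ∇
    @. opt.v = opt.β2 * opt.v + (1 - opt.β2) * ∇^2
    m_hat = opt.m ./ (1 - opt.β1^opt.t)    # bias-corrected first moment
    v_hat = opt.v ./ (1 - opt.β2^opt.t)    # bias-corrected second moment
    @. opt.x -= opt.λ * m_hat / (sqrt(v_hat) + opt.ϵ)
    return opt.x
end

# Usage on a toy quadratic with known minimizer `target`
target = [1.0, -2.0]
toy_grad(x) = 2 .* (x .- target)           # gradient of sum((x - target).^2)

opt = SimpleAdam(zeros(2); λ = 1e-1)
for it in 1:2000
    step!(opt, toy_grad(opt.x))
end
```

In the classification example, the toy gradient would be replaced by the gradient of the free energy with respect to the flow parameters.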

In this tutorial, the parameters are optimized with Optim using forward-mode automatic differentiation:

res = optimize(f, randn(nr_params(model)), LBFGS(), Optim.Options(store_trace = true, show_trace = true), autodiff=:forward)

The optimization results are then given as

params = Optim.minimizer(res)
inferred_model = compile(model, params)
trans_data_x_1 = hcat(map((x) -> ReactiveMP.forward(inferred_model, x), [data_x[k,:] for k=1:size(data_x,1)])...)'
trans_data_x_2 = map((x) -> dot([1, 1], x), [trans_data_x_1[k,:] for k=1:size(data_x,1)])
trans_data_x_2_split = [trans_data_x_2[data_y .== 1.0], trans_data_x_2[data_y .== 0.0]]
fig, ax = plt.subplots(ncols = 3, figsize=(15,5))
ax[1].scatter(data_x[:,1], data_x[:,2], c = data_y)
ax[2].scatter(trans_data_x_1[:,1], trans_data_x_1[:,2], c = data_y)
ax[3].hist(trans_data_x_2_split; stacked=true, bins=50, color = ["gold", "purple"])
ax[1].grid(), ax[2].grid(), ax[3].grid()
ax[1].set_xlim(-0.25,1.25), ax[1].set_ylim(-0.25,1.25)
ax[1].set_title("original data"), ax[2].set_title("|> warp"), ax[3].set_title("|> dot")
plt.gcf()
using StatsFuns: normcdf
classification_map = map((x) -> normcdf(dot([1,1],x)), map((x) -> ReactiveMP.forward(inferred_model, [x...]), collect(Iterators.product(0:0.01:1, 0:0.01:1))))
fig, ax = plt.subplots(ncols = 3, figsize=(20,5))
im1 = ax[1].scatter(data_x[:,1], data_x[:,2], c = data_y)
im2 = ax[2].scatter(data_x[:,1], data_x[:,2], c = normcdf.(trans_data_x_2))
ax[3].contour(repeat(0:0.01:1, 1, 101), repeat(0:0.01:1, 1, 101)', classification_map)
plt.colorbar(im1, ax=ax[1])
plt.colorbar(im2, ax=ax[2])
ax[1].grid(), ax[2].grid(), ax[3].grid()
ax[1].set_xlabel("weight 1"), ax[1].set_ylabel("weight 2"), ax[2].set_xlabel("weight 1"), ax[2].set_ylabel("weight 2"), ax[3].set_xlabel("weight 1"), ax[3].set_ylabel("weight 2")
ax[1].set_title("original labels"), ax[2].set_title("predicted labels"), ax[3].set_title("Classification map")
plt.gcf()