Gaussian Process-based Amortization of Variational Message Passing Update Rules

Abstract

Variational Message Passing (VMP) facilitates automated variational inference in factorized probabilistic models where connected factors are conjugate pairs. Conjugate-computation Variational Inference (CVI) extends the applicability of VMP to models comprising both conjugate and non-conjugate factors. CVI relies on a gradient that is estimated by Monte Carlo (MC) sampling, which can lead to a substantial computational load. As a result, for models that feature a large number of non-conjugate pairs, CVI-based inference may not scale well to larger model sizes. In this paper, we propose a Gaussian Process-enhanced CVI approach, called GP-CVI, to amortize the computational costs caused by the MC sampling procedures in CVI. Specifically, we train a Gaussian process regression (GPR) model on a set of incoming-outgoing message pairs that were generated by CVI. In operation, we use the “cheaper” GPR model to produce outgoing messages and resort to the more accurate but expensive CVI message only if the variance of the outgoing message exceeds a threshold. By experimental validation, we show that GP-CVI gradually uses more fast memory-based update rule computations and fewer sampling-based update rule computations. As a result, GP-CVI speeds up CVI with a controllable effect on the accuracy of the inference procedure.
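
The variance-gated fallback described in the abstract can be illustrated with a minimal sketch. This is not the authors' implementation: it assumes scalar inputs and outputs in place of real message parameters, reads "variance of the outgoing message" as the GP predictive variance, and uses a toy Monte Carlo estimator (`expensive_rule`) as a stand-in for a CVI update. The class name `AmortizedRule`, the kernel settings, and the threshold value are all hypothetical choices made for illustration.

```python
import numpy as np

def rbf_kernel(a, b, length_scale=1.0, signal_var=1.0):
    """Squared-exponential kernel between two 1-D arrays of inputs."""
    d = a[:, None] - b[None, :]
    return signal_var * np.exp(-0.5 * (d / length_scale) ** 2)

class AmortizedRule:
    """Hypothetical wrapper: stores (input, output) pairs produced by an
    expensive rule and answers from a GP regressor whenever the GP
    predictive variance is at or below a threshold."""

    def __init__(self, expensive_rule, var_threshold=1e-2, noise_var=1e-4):
        self.expensive_rule = expensive_rule
        self.var_threshold = var_threshold
        self.noise_var = noise_var
        self.X = np.empty(0)   # stored inputs
        self.y = np.empty(0)   # stored outputs
        self.n_expensive = 0   # calls answered by the expensive rule
        self.n_cheap = 0       # calls answered by the GP

    def _predict(self, x):
        """GP posterior mean and variance at a scalar input x."""
        if self.X.size == 0:
            return 0.0, np.inf  # no data yet: force the expensive path
        K = rbf_kernel(self.X, self.X) + self.noise_var * np.eye(self.X.size)
        k_star = rbf_kernel(self.X, np.array([x]))[:, 0]
        L = np.linalg.cholesky(K)
        alpha = np.linalg.solve(L.T, np.linalg.solve(L, self.y))
        v = np.linalg.solve(L, k_star)
        mean = k_star @ alpha
        var = rbf_kernel(np.array([x]), np.array([x]))[0, 0] - v @ v
        return float(mean), float(max(var, 0.0))

    def __call__(self, x):
        mean, var = self._predict(x)
        if var <= self.var_threshold:
            self.n_cheap += 1
            return mean
        # GP is too uncertain: pay for the expensive rule and remember the pair.
        out = self.expensive_rule(x)
        self.X = np.append(self.X, x)
        self.y = np.append(self.y, out)
        self.n_expensive += 1
        return out

rng = np.random.default_rng(0)

def expensive_rule(x, n_samples=2000):
    """Toy stand-in for a sampling-based update: MC estimate of E[sin(x + eps)]."""
    eps = rng.normal(0.0, 0.1, size=n_samples)
    return float(np.mean(np.sin(x + eps)))

rule = AmortizedRule(expensive_rule)
queries = rng.uniform(-2.0, 2.0, size=200)
outputs = [rule(x) for x in queries]
print("expensive calls:", rule.n_expensive, "cheap calls:", rule.n_cheap)
```

In this sketch the threshold plays the role of the accuracy control mentioned in the abstract: lowering `var_threshold` sends more queries to the expensive rule, while raising it lets the regressor answer more often at the cost of larger approximation error.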

Publication
2022 European Signal Processing Conference
Date