Motivation

The reparameterization trick allows for the efficient computation of gradients through random variables, enabling the optimization of parametric probability models using stochastic gradient descent.

Consider a function of the form:

$$J(\theta) = \mathbb{E}_{x \sim p_\theta(x)}\left[f(x)\right]$$

where $p_\theta(x)$ is a probability distribution parameterized by $\theta$, and $f$ is a deterministic function.

Our goal is to answer the question: how can we compute $\nabla_\theta J(\theta)$?

Before calculating $\nabla_\theta J(\theta)$, let's discuss how to calculate $J(\theta)$ itself. We can use Monte Carlo methods to estimate the expectation by repeatedly sampling $x_i \sim p_\theta(x)$, calculating $f(x_i)$, and then averaging the results:

$$J(\theta) \approx \frac{1}{N} \sum_{i=1}^{N} f(x_i), \qquad x_i \sim p_\theta(x)$$

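To make this concrete, here is a minimal NumPy sketch; the choices $p_\theta = \mathcal{N}(\theta, 1)$ and $f(x) = x^2$ are illustrative assumptions, not part of the setup above:

```python
import numpy as np

rng = np.random.default_rng(0)

theta = 2.0
n_samples = 100_000

# Sample x_i ~ p_theta(x); the choice p_theta = N(theta, 1) is illustrative.
x = rng.normal(loc=theta, scale=1.0, size=n_samples)

# With f(x) = x**2, the true value is E[f(x)] = theta**2 + 1 = 5.0.
j_estimate = np.mean(x ** 2)
print(j_estimate)  # ~5.0
```
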
The following is the computational graph that describes how to generate a single sample $Y$ from the parameters $\theta$.

```mermaid
graph TD;
    θ{θ} --> P(P);
    P --> X{X};
    X --> F[F];
    F --> Y{Y};
    P -.-> θ;
    X -.-> P;
    F -.-> X;
    Y -.-> F;
```

The shape and style of nodes/edges define their semantics:

  • Diamond nodes represent data ($\theta$, $X$, $Y$), and act as the input/output of functions.
  • Rounded boxes represent stochastic functions ($P$).
  • Rectangular boxes represent deterministic functions ($F$).
  • Solid arrows show the direction data flows when generating a sample based on the parameters $\theta$.
  • Dashed arrows show the direction that gradients flow when differentiating the output with respect to the parameters $\theta$.

The computational graph makes it clear how to calculate $\nabla_\theta J(\theta)$: we simply accumulate the gradients along the dashed arrows. However, an issue arises because we need to differentiate through the stochastic function $P$, which is not well-defined.
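
To see the issue concretely, here is a small PyTorch sketch (again with the illustrative choice $p_\theta = \mathcal{N}(\theta, 1)$): sampling severs the gradient path from the output back to $\theta$.

```python
import torch

theta = torch.tensor(2.0, requires_grad=True)

# Sampling is performed internally under no_grad, so the result is
# detached from theta: there is no path for gradients to flow along.
x = torch.distributions.Normal(theta, 1.0).sample()
print(x.requires_grad)  # False

y = x ** 2
# y.backward() would raise an error here: y has no grad_fn connecting
# it back to theta.
```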

We discuss two options for handling this: the REINFORCE estimator and the reparameterization trick.

REINFORCE Estimator

The REINFORCE estimator uses the log-derivative trick, $\nabla_\theta p_\theta(x) = p_\theta(x) \nabla_\theta \log p_\theta(x)$, to rewrite $\nabla_\theta J(\theta)$ as an expectation that we can calculate directly:

$$\nabla_\theta J(\theta) = \nabla_\theta \int p_\theta(x) f(x) \, dx = \int p_\theta(x) \nabla_\theta \log p_\theta(x) \, f(x) \, dx = \mathbb{E}_{x \sim p_\theta(x)}\left[ f(x) \nabla_\theta \log p_\theta(x) \right]$$

This final expression is called the REINFORCE estimator for $\nabla_\theta J(\theta)$. Similar to how we approximated $J(\theta)$, we can approximate this expectation by sampling $x_i \sim p_\theta(x)$ and averaging $f(x_i) \nabla_\theta \log p_\theta(x_i)$. One issue with this estimator is that it has high variance, since it involves multiplying $f(x)$ by $\nabla_\theta \log p_\theta(x)$.
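
As a concrete illustration, here is a minimal NumPy sketch of the estimator; the choices $p_\theta = \mathcal{N}(\theta, 1)$ and $f(x) = x^2$ are illustrative assumptions carried over from the sketch above, for which the true gradient is $\nabla_\theta (\theta^2 + 1) = 2\theta$:

```python
import numpy as np

rng = np.random.default_rng(0)

theta = 2.0
n_samples = 100_000

# Sample x_i ~ p_theta(x) with the illustrative choice p_theta = N(theta, 1).
x = rng.normal(loc=theta, scale=1.0, size=n_samples)

# For N(theta, 1), the score is grad_theta log p_theta(x) = x - theta.
score = x - theta

# REINFORCE estimate of grad_theta E[f(x)] with f(x) = x**2.
grad_estimate = np.mean(x ** 2 * score)
print(grad_estimate)  # ~2 * theta = 4.0, but with high variance
```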

Instead of using the high-variance REINFORCE estimator, we can use the reparameterization trick.

Reparameterization trick
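
The reparameterization trick sidesteps the problem by moving the randomness out of the parameterized distribution. In its standard form, instead of sampling $x \sim p_\theta(x)$ directly, we express $x$ as a deterministic, differentiable function of $\theta$ and a parameter-free noise variable $\epsilon$:

$$x = g_\theta(\epsilon), \qquad \epsilon \sim p(\epsilon)$$

Since the sampling step no longer depends on $\theta$, the gradient can move inside the expectation:

$$\nabla_\theta J(\theta) = \nabla_\theta \, \mathbb{E}_{\epsilon \sim p(\epsilon)}\left[f(g_\theta(\epsilon))\right] = \mathbb{E}_{\epsilon \sim p(\epsilon)}\left[\nabla_\theta f(g_\theta(\epsilon))\right]$$

Here is a minimal NumPy sketch for the same illustrative setup as above ($p_\theta = \mathcal{N}(\theta, 1)$, $f(x) = x^2$), where the natural reparameterization is $g_\theta(\epsilon) = \theta + \epsilon$ with $\epsilon \sim \mathcal{N}(0, 1)$:

```python
import numpy as np

rng = np.random.default_rng(0)

theta = 2.0
n_samples = 100_000

# Reparameterize: x = g_theta(eps) = theta + eps, eps ~ N(0, 1), so the
# randomness no longer depends on theta.
eps = rng.normal(size=n_samples)
x = theta + eps

# grad_theta f(g_theta(eps)) = f'(x) * dg/dtheta = 2 * x * 1.
grad_estimate = np.mean(2.0 * x)
print(grad_estimate)  # ~2 * theta = 4.0, with far lower variance
```

This is the same transform that `torch.distributions.Normal(...).rsample()` applies, which is why `rsample()` supports gradients while `sample()` does not.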
