Motivation
The reparameterization trick allows for the efficient computation of gradients through random variables, enabling the optimization of parametric probability models using stochastic gradient descent.
Consider a function of the form:

$$y = \mathbb{E}_{x \sim p_\theta(x)}\left[ f(x) \right]$$

where $p_\theta(x)$ is a probability distribution parameterized by $\theta$, and $f$ is a deterministic function.
Our goal is to answer the question: How can we compute $\nabla_\theta y$?
Before calculating $\nabla_\theta y$, let's discuss how to calculate $y$ itself. We can use Monte Carlo methods to estimate the expectation by repeatedly sampling $x_i \sim p_\theta(x)$, calculating $f(x_i)$, and then averaging all of the $f(x_i)$:

$$y \approx \frac{1}{N} \sum_{i=1}^{N} f(x_i), \qquad x_i \sim p_\theta(x)$$
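To make this concrete, here is a minimal PyTorch sketch. The Gaussian $p_\theta(x) = \mathcal{N}(\theta, 1)$ and $f(x) = x^2$ are illustrative choices, not specified by the setup above:

```python
import torch

# Monte Carlo estimate of y = E_{x ~ p_theta(x)}[f(x)].
# Assumed for illustration: p_theta = N(theta, 1) and f(x) = x^2.
theta = torch.tensor(1.0)
f = lambda x: x ** 2

x = torch.distributions.Normal(theta, 1.0).sample((10_000,))  # x_i ~ p_theta(x)
y_hat = f(x).mean()                                           # (1/N) * sum_i f(x_i)
print(y_hat)  # ~2.0, since E[x^2] = theta^2 + 1 for x ~ N(theta, 1)
```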
The following computational graph describes how to generate a single sample $y$ from the parameters $\theta$.
```mermaid
graph TD;
    θ{θ} --> P(P);
    P --> X{X};
    X --> F[F];
    F --> Y{Y};
    P -.-> θ;
    X -.-> P;
    F -.-> X;
    Y -.-> F;
```
The shape and style of nodes/edges define their semantics:
- Diamond nodes represent data ($\theta$, $X$, $Y$), and act as the inputs/outputs of functions.
- Rounded boxes represent stochastic functions ($P$).
- Rectangular boxes represent deterministic functions ($F$).
- Solid arrows show the direction data flows when generating a sample based on the parameters $\theta$.
- Dashed arrows show the direction gradients flow when differentiating the output with respect to the parameters $\theta$.
The computational graph makes it clear how to calculate $\nabla_\theta y$: we simply accumulate the gradients along the dashed arrows. However, an issue arises: we need to differentiate through the stochastic function $P$, and differentiating through a sampling operation is not well-defined.
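The sketch below shows the problem concretely (same illustrative Gaussian as above): in PyTorch, `.sample()` produces a tensor that is detached from $\theta$, so autograd has no path along the dashed arrows:

```python
import torch

theta = torch.tensor(1.0, requires_grad=True)
x = torch.distributions.Normal(theta, 1.0).sample((10,))  # the stochastic function P
y = (x ** 2).mean()                                       # the deterministic function F

print(x.requires_grad)  # False: sampling severed the connection to theta
# y.backward() would raise a RuntimeError here, because autograd has no
# path from y back through the sampling operation to theta.
```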
We discuss two options for handling this: the REINFORCE estimator and the reparameterization trick.
REINFORCE Estimator
The REINFORCE estimator uses the log-derivative trick, $\nabla_\theta p_\theta(x) = p_\theta(x)\, \nabla_\theta \log p_\theta(x)$, to rewrite $\nabla_\theta y$ as an expectation that we can estimate directly:

$$
\begin{aligned}
\nabla_\theta y &= \nabla_\theta \, \mathbb{E}_{x \sim p_\theta(x)}\left[ f(x) \right] \\
&= \nabla_\theta \int p_\theta(x)\, f(x)\, dx \\
&= \int \nabla_\theta p_\theta(x)\, f(x)\, dx \\
&= \int p_\theta(x)\, \nabla_\theta \log p_\theta(x)\, f(x)\, dx \\
&= \mathbb{E}_{x \sim p_\theta(x)}\left[ f(x)\, \nabla_\theta \log p_\theta(x) \right]
\end{aligned}
$$
This final expression is called the REINFORCE estimator for $\nabla_\theta y$. Similar to how we approximated $y$, we can approximate this expectation by sampling $x_i \sim p_\theta(x)$ and averaging $f(x_i)\, \nabla_\theta \log p_\theta(x_i)$. One issue with this estimator is that it has high variance, since it multiplies $f(x_i)$ by $\nabla_\theta \log p_\theta(x_i)$, and fluctuations in either factor are amplified by the other.
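Here is a sketch of the estimator in PyTorch, again under the assumed Gaussian $p_\theta = \mathcal{N}(\theta, 1)$ and $f(x) = x^2$ (the helper name `reinforce_grad` is ours). Since `log_prob` depends on $\theta$ while the samples do not, differentiating the surrogate $\frac{1}{N}\sum_i f(x_i)\, \log p_\theta(x_i)$ yields exactly the averaged REINFORCE estimate:

```python
import torch

def reinforce_grad(theta, f, n_samples=10_000):
    """REINFORCE estimate of grad_theta E_{x ~ N(theta, 1)}[f(x)] (illustrative)."""
    dist = torch.distributions.Normal(theta, 1.0)
    x = dist.sample((n_samples,))                 # x_i ~ p_theta(x), detached from theta
    surrogate = (f(x) * dist.log_prob(x)).mean()  # (1/N) sum_i f(x_i) log p_theta(x_i)
    surrogate.backward()                          # grad = (1/N) sum_i f(x_i) grad log p
    return theta.grad

theta = torch.tensor(1.0, requires_grad=True)
print(reinforce_grad(theta, lambda x: x ** 2))    # ~2.0 = d/dtheta (theta^2 + 1)
```

The high variance is easy to observe empirically: rerunning the estimate with a small `n_samples` produces gradients that scatter widely around the true value $2\theta = 2$:

```python
theta = torch.tensor(1.0, requires_grad=True)
grads = []
for _ in range(100):
    theta.grad = None  # reset the accumulated gradient before each estimate
    grads.append(reinforce_grad(theta, lambda x: x ** 2, n_samples=100).item())

mean = sum(grads) / len(grads)
std = (sum((g - mean) ** 2 for g in grads) / (len(grads) - 1)) ** 0.5
print(f"mean={mean:.2f}, std={std:.2f}")  # mean near 2.0, but individual estimates are noisy
```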
Instead of using the high-variance REINFORCE estimator, we can use the reparameterization trick.