This lecture is about the overlap between gradient descent and Markov chain Monte Carlo. It consists of five sections, each of which is largely based on the following material:
- Gradient flows and Rethinking SGD noise
- Bayesian Learning via Stochastic Gradient Langevin Dynamics
- A Complete Recipe for Stochastic Gradient MCMC
- Bridging the Gap between SG-MCMC and Stochastic Optimization
This tutorial additionally refers to:
- Stochastic quasi-Newton Langevin Monte Carlo
- A tutorial introduction to Monte Carlo methods, Markov chain Monte Carlo and particle filtering
- MCMC using Hamiltonian dynamics
- MCMC demo
- hamiltorch library
Each section contains an interactive demo or code snippet. We will have short breaks after each section. You are encouraged to play around with these during breaks.
Let's get started by importing the necessary libraries. You need to install hamiltorch in addition to the standard libraries.
%load_ext autoreload
%autoreload 2
import importlib, pickle, os, sys, copy, numpy as np
import matplotlib.pyplot as plt
from matplotlib import animation
from IPython.display import HTML
os.environ['KMP_DUPLICATE_LIB_OK']='True'
# %matplotlib notebook
plt.rcParams.update({
    "text.usetex": True,
    "font.family": "serif",
    "font.serif": ["Palatino"],
})
import torch
import torch.nn as nn
import torch.nn.functional as F
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
import hamiltorch
from anim_utils import plot_opt_animation, plot_mcmc_animation, plot_mcmc_surface
1. Ordinary Differential Equations (ODEs) and Gradient Descent (GD)
1.1. Ordinary Differential Equations (ODEs)
Ordinary differential equations relate an independent variable, a function of that variable, and the derivatives of that function. Formally,
\begin{equation} \dot{x}(t) = \frac{dx(t)}{dt} = \lim_{\Delta t \rightarrow 0} \frac{ x(t + \Delta t) - x(t)}{\Delta t} = f(x(t),t), \end{equation}
where
- $t$ denotes time (or any other independent variable)
- $x(t) \in \mathcal{X} \subseteq \mathbb{R}^d$ is the state vector at time $t$ (hence the dependent variable)
- $\dot{x}(t) \in \dot{\mathcal{X}} \subseteq \mathbb{R}^d$ is the first-order time derivative of $x(t)$
- $f : \mathcal{X} \times \mathbb{R}_+ \rightarrow \dot{\mathcal{X}}$ is the vector-valued differential function, continuous in time, that describes the system's evolution, with $\mathbb{R}_+$ denoting the non-negative real numbers.
We often refer to $f$ as the vector field or the right-hand side. Informally speaking, $f$ tells us "how much the state $x(t)$ would change with an infinitesimal change in $t$".
An "ODE state solution" $x(t)$ at time $t\in \mathbb{R}_+$ is given by \begin{equation} x(t) = x_0 + \int_0^t f(x_\tau)~d\tau, \end{equation} where $x_0$ denotes the initial value and $\tau \in \mathbb{R}_+$ is an auxiliary time variable. Given an initial value $x_0$ and a set of time points $\{t_0,t_1,\ldots,t_N\}$, we are often interested in state solutions $x_{0:N}\equiv\{x(t_0),x(t_1),\ldots,x(t_N)\}$
Numerical Integration¶
The above integral has a tractable form only for very simple differential functions (recall the integration rules from high school). This is why we almost always resort to numerical integration.
More formally, a first-order Taylor expansion of $x$ around $t$, together with $\dot{x}(t) = f(x(t))$, gives: \begin{equation} x(t+\delta) = x(t) + \delta f(x(t)) + o(\delta). \end{equation}
Assuming we can live with the small error term, dropping it yields the following recursive (and discrete) scheme for computing state solutions, known as the forward Euler method: \begin{equation} x(t+\delta) = x(t) + \delta f(x(t)). \end{equation}
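As a quick sanity check, the following minimal sketch (our own illustration; the helper euler_solve and all constants are made up) applies this recursion to the exponential-decay example from above and compares it against the closed-form solution:
import math

def euler_solve(f, x0, t_end, num_steps):
    # forward Euler: repeatedly apply x <- x + delta * f(x)
    delta = t_end / num_steps
    x = x0
    for _ in range(num_steps):
        x = x + delta * f(x)
    return x

x0, t_end = 1.0, 2.0
approx = euler_solve(lambda x: -x, x0, t_end, num_steps=1000)
exact = x0 * math.exp(-t_end)
print(approx, exact)  # approx 0.1351 vs exact 0.1353; the gap shrinks as num_steps grows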
1.2. Gradient Descent (GD)
Gradient descent (GD) is by far the most commonly used optimization technique in the deep learning era. Here, we restrict ourselves to maximum a posteriori (MAP) estimation problems. Let's first set up the notation:
- $\theta$ denotes a parameter vector
- $p(\theta)$ is our prior belief on the parameters
- $p(x\mid\theta)$ denotes the likelihood
- $X \equiv \{ x_i \}_{i=1}^N$ is an iid dataset
By Bayes' rule, the posterior distribution is proportional to the prior times the likelihood: $$ p(\theta \mid X) \propto p(\theta) \prod_i p(x_i \mid \theta).$$
In practice, the MAP estimation task translates to minimizing the negative log posterior density: $$ \min_\theta~~ g(\theta) \equiv -\log p(\theta \mid X) = -\log p(\theta) - \sum_i \log p(x_i \mid \theta) + \text{const},$$ where the constant (the log evidence) does not depend on $\theta$ and can be dropped during optimization.
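As a concrete example, consider a Gaussian likelihood $x_i \sim \mathcal{N}(\theta, \sigma^2)$ with a Gaussian prior $\theta \sim \mathcal{N}(0, \tau^2)$. Up to an additive constant, $$ g(\theta) = \frac{\theta^2}{2\tau^2} + \sum_{i=1}^N \frac{(x_i - \theta)^2}{2\sigma^2}, $$ and setting $\nabla g(\theta) = 0$ yields the closed-form MAP estimate $$ \hat{\theta} = \frac{\tau^2 \sum_i x_i}{\sigma^2 + N\tau^2}. $$ We will reuse this example below to sanity-check gradient descent.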
Gradient descent (GD) aims to minimize a function by taking small steps in the direction of steepest descent, which is the negative gradient direction. More concretely, it iterates the following update equation:
\begin{equation} \theta_{n+1} = \theta_{n} - \delta \nabla g(\theta_{n}), \end{equation}
where $n$ denotes the current step, $\theta_{n}$ is the current iterate, and $\delta>0$ is the learning rate.
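As a minimal sketch of this update rule (the learning rate, iteration count, and toy data below are made-up illustration values), we can iterate it on the Gaussian example from the previous subsection and check that it recovers the closed-form MAP estimate:
torch.manual_seed(0)
sigma, tau, N = 1.0, 2.0, 50
x_toy = 1.5 + sigma * torch.randn(N)  # iid draws around a true mean of 1.5

theta = torch.tensor(0.0, requires_grad=True)
lr = 1e-2  # the learning rate delta
for _ in range(500):
    # negative log posterior of the Gaussian example (up to a constant)
    g = theta**2 / (2 * tau**2) + ((x_toy - theta)**2).sum() / (2 * sigma**2)
    grad, = torch.autograd.grad(g, theta)
    with torch.no_grad():
        theta -= lr * grad  # theta_{n+1} = theta_n - delta * grad g(theta_n)

map_closed_form = tau**2 * x_toy.sum() / (sigma**2 + N * tau**2)
print(theta.item(), map_closed_form.item())  # the two values should agree closely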
From GD to ODEs
Assuming that each GD update step takes "$\delta$ time", i.e. $\theta_{n} \equiv \theta(\delta n)$, and defining $t \equiv \delta n$, we obtain the following:
\begin{equation} \theta(t+\delta) = \theta(t) - \delta \nabla g(\theta(t)). \end{equation}
Comparing this with the last equation of the "numerical integration" section, we conclude that
GD is the forward Euler discretization of an ODE whose vector field is the negative gradient of the optimized function, the so-called gradient flow \begin{equation} \dot{\theta}(t) = -\nabla g(\theta(t)). \end{equation}
Below we contrast gradient descent and ODE flow.
# the loss function: negative log density of a correlated 2-D Gaussian
m = torch.tensor([0.0, 0.0])
cov = torch.tensor([[1.0, 0.81], [0.81, 1.0]])  # covariance matrix
dist = torch.distributions.MultivariateNormal(m, cov)
def forward_simulate(f, x0, ts):
    # forward Euler simulation of dx/dt = f(t, x) over the time grid ts
    X = [x0]
    ode_steps = len(ts) - 1
    for i in range(ode_steps):
        h = ts[i+1] - ts[i]  # step size
        t = ts[i]
        x = X[i]
        x_next = x + h * f(t, x)
        X.append(x_next)
    X = torch.stack(X)  # T,d
    return X
def loss_fnc(x):
    return -dist.log_prob(x)
# GD and ODE flow simulations
def solve_systems(gd_lr=0.33, num_iter=30, num_euler=50):
    x0_ = torch.tensor([0., 5.]).to(torch.float32)
    xpar = torch.nn.Parameter(x0_.clone())  # P
    # gradient descent solution
    opt = torch.optim.SGD([xpar], lr=gd_lr)
    gd_trace = []  # trace of GD iterates
    for i in range(num_iter):
        gd_trace.append(xpar.detach().clone())
        opt.zero_grad()
        loss = loss_fnc(xpar)
        loss.backward()
        opt.step()
    gd_trace = torch.stack(gd_trace).numpy()
    # ODE solution: num_euler fine Euler steps per GD step of length gd_lr
    dt = gd_lr / num_euler
    ts = torch.arange(num_euler * num_iter) * dt
    x0 = torch.nn.Parameter(x0_.clone())
    def odef(t, x):
        f = -loss_fnc(x)
        grad = torch.autograd.grad(f, x)[0]  # vector field: gradient of -g
        return grad
    ode_trace = forward_simulate(odef, x0, ts)  # T,P
    ode_trace = ode_trace.detach().numpy()
    return gd_trace, ode_trace
gd_trace, ode_trace = solve_systems()
anim = plot_opt_animation(loss_fnc, gd_trace, ode_trace)
HTML(anim.to_jshtml())