In this tutorial, we give an introduction to neural ordinary differential equations (NODEs) [chen2018neural]. A one-sentence summary of this model family: instead of specifying a discrete sequence of hidden layers, NODEs parameterize the derivative of the hidden state with a neural network.
We will start this tutorial with a discussion of ODEs. Instead of presenting technical details, we give a practical introduction. Next, we formally describe NODEs and show three standard use cases: classification, normalizing flows, and latent dynamics learning. The lecture closes with works that study different aspects of the vanilla NODE.
Organization of the Lecture
- Introduction (10min)
- 1. Formal descriptions of ODEs (20min)
    - 1.1. Computing ODE solutions
    - 1.2. Example: Van der Pol oscillator
    - 1.3. Break: VDP & ODE integration parameters
- 2. Neural ODEs (20min)
    - 2.1. Problem formulation
    - 2.2. Maximum likelihood estimation
    - 2.3. Example: learning VDP sequences with NODE
    - 2.4. Break: Adjoints
- 3. Latent Bayesian Neural ODEs (20min)
    - 3.1. Variational Inference
    - 3.2. Evidence Lower-bound
    - 3.3. Example: Rotating MNIST
    - 3.4. Implementation
- Long Break (15min)
- 4. ResNets are Discretized ODEs (20min)
    - 4.1. Classification Objective
    - 4.2. Implementation
    - 4.3. Training
    - 4.4. Break: ODE solver parameters
- 5. Continuous-time Normalizing Flows (20min)
    - 5.1. Normalizing Flows
    - 5.2. Continuous-time Normalizing Flows
    - 5.3. Implementation
    - 5.4. Training
    - 5.5. Break: Wrap-up
- 6. Related Studies (15min)
    - 6.1. ODE-RNN [rubanova2019latent]
    - 6.2. ODE$^2$VAE [yildiz2019deep]
    - 6.3. Augmented NODEs [dupont2019augmented]
    - 6.4. Regularized NODEs [finlay2020train]
    - 6.5. ACA [zhuang2020adaptive]
    - 6.6. ODE-RL [yildiz2021continuous]
    - 6.7. NSDEs [tzen2019neural], [xu2022infinitely]
    - 6.8. GP-ODEs [hegde2022variational]
- Summary & Q&A (5+25min)
NOTE: Most of the code snippets in this tutorial, as well as the figures, are from the original neural ODE paper [chen2018neural] and the corresponding GitHub repository.
Practicalities
- Each section ends with a 5-10-minute break in which you can read the provided material and/or code snippets, ask questions, or simply rest. Feel free to arrange the breaks according to your needs.
- In addition to mathematical descriptions of the techniques, we provide short code snippets for the model definitions, training, and visualization. Training could be too time-consuming for this session, so make sure to load the pre-trained models if you would like to visualize the fits.
- Most of the implementation in this notebook depends on the provided utility files, some of which might be too involved to grasp immediately. If you're interested, go ahead and check them out.
The following cell imports all the required libraries.
%load_ext autoreload
%autoreload 2
!pip install torch torchvision torchdiffeq numpy scipy matplotlib pillow scikit-learn
import numpy as np
from IPython import display
import time
from sklearn.datasets import make_circles
import torch
import torch.nn as nn
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
from torchdiffeq import odeint
from bnn import BNN
from vae_utils import MNIST_Encoder, MNIST_Decoder
from plot_utils import plot_vdp_trajectories, plot_ode, plot_vdp_animation, plot_cnf_animation, \
plot_mnist_sequences, plot_mnist_predictions, plot_cnf_data
from utils import get_minibatch, mnist_loaders, inf_generator, mnist_accuracy, \
count_parameters, conv3x3, group_norm, Flatten, load_rotating_mnist
1. Ordinary Differential Equations (ODEs)
Ordinary differential equations relate functions of a single independent variable to their derivatives. Formally,
\begin{equation} \dot{\mathbf{x}}(t) = \frac{d\mathbf{x}(t)}{dt} = \lim_{\Delta t \rightarrow 0} \frac{ \mathbf{x}(t + \Delta t) - \mathbf{x}(t)}{\Delta t} = \mathbf{f}(\mathbf{x}(t),\mathbf{u}(t),t), \end{equation}
where
- $t$ denotes time (or any other independent variable)
- $\mathbf{x}(t) \in \mathcal{X} \subseteq \mathbb{R}^d$ is the state vector at time $t$ (the dependent variable)
- $\mathbf{u}(t) \in \mathcal{A} \subseteq \mathbb{R}^m$ is an external control signal
- $\dot{\mathbf{x}}(t) \in \dot{\mathcal{X}} \subseteq \mathbb{R}^d$ is the first-order time derivative of $\mathbf{x}(t)$
- $\mathbf{f} : \mathcal{X} \times \mathcal{A} \times \mathbb{R}_+ \rightarrow \dot{\mathcal{X}}$ is the vector-valued, continuous differential function describing the system's evolution over time, with $\mathbb{R}_+$ denoting the non-negative real numbers.
Informally speaking, $\mathbf{f}$ tells us "by how much the state $\mathbf{x}(t)$ changes under an infinitesimal change in $t$". More formally, the following equation holds in the limit $\Delta t \rightarrow 0$: \begin{equation} \mathbf{x}(t+\Delta t) = \mathbf{x}(t) + \Delta t \cdot \mathbf{f}(\mathbf{x}(t),\mathbf{u}(t),t). \end{equation}
Note-1: We often refer to $\mathbf{f}$ as the vector field or the right-hand side.
Note-2: The problem above, when paired with an initial value, is known as an initial value problem.
Note-3: Throughout this tutorial, we focus on differential functions $\mathbf{f}(\mathbf{x}(t))$ that are independent of control signals and not explicitly parameterized by time (so-called autonomous systems).
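For concreteness, consider the scalar system $\dot{x}(t) = -x(t)$, i.e., $f(x) = -x$, with initial value $x(0) = 1$. One update with $\Delta t = 0.1$ gives \begin{equation} x(0.1) \approx x(0) + 0.1 \cdot f(x(0)) = 1 + 0.1 \cdot (-1) = 0.9, \end{equation} which is close to the exact solution $x(0.1) = e^{-0.1} \approx 0.905$; the approximation becomes exact as $\Delta t \rightarrow 0$.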
1.1. Computing ODE Solutions¶
An "ODE state solution" $\mathbf{x}(t)$ at time $t\in \mathbb{R}_+$ is given by \begin{equation} \mathbf{x}(t) = \mathbf{x}_0 + \int_0^t \mathbf{f}(\mathbf{x}_\tau)~d\tau, \end{equation} where $\mathbf{x}_0$ denotes the initial value and $\tau \in \mathbb{R}_+$ is an auxiliary time variable.
Note-1: Given an initial value $\mathbf{x}_0$ and a set of time points $\{t_0,t_1,\ldots,t_N\}$, we are often interested in the state solutions $\mathbf{x}_{0:N}\equiv\{\mathbf{x}(t_0),\mathbf{x}(t_1),\ldots,\mathbf{x}(t_N)\}$.
Note-2: We occasionally denote $\mathbf{x}_n \equiv \mathbf{x}(t_n)$.
Note-3: The integral above has a closed form only for very simple differential functions, such as the scalar example above, which integrates to $x(t) = x_0 e^{-t}$ (recall the integration rules from high school). Therefore, we almost always resort to numerical solvers.
Numerical solvers: TL;DR: A state solution $\mathbf{x}(t)$ can be numerically computed up to a tolerable error.
The celebrated Picard existence and uniqueness theorem (also known as the Picard–Lindelöf theorem) states that an initial value problem has a unique solution if the differential function satisfies the Lipschitz condition. Despite the uniqueness guarantee, there is no general recipe for computing the solution analytically; therefore, we often resort to numerical methods. The simplest (and typically least accurate) numerical method is Euler's method, i.e., repeated application of the update equation above. More advanced methods such as Heun's method and the Runge-Kutta family of solvers evaluate $\mathbf{f}(\mathbf{x}(t))$ at multiple locations and average the resulting slopes (a speed vs. accuracy trade-off). Even more advanced adaptive-step solvers set the step size $\Delta t$ dynamically to keep an estimate of the local error below a given tolerance.
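To make the mechanics explicit, here is a minimal fixed-step Euler integrator. This is a sketch for illustration only: the helper name euler_integrate is ours, and real applications should use a proper solver library such as torchdiffeq (below).
# a minimal fixed-step Euler integrator (illustration only, not a library function)
def euler_integrate(f, x0, ts):
    ''' f  - callable f(t,x) returning the time derivative, [N,d]
        x0 - [N,d] initial value
        ts - [T] increasing time points
        Output: [T,N,d] state solutions at ts '''
    xs = [x0]
    for t0, t1 in zip(ts[:-1], ts[1:]):
        # x(t + dt) ~= x(t) + dt * f(t, x(t))
        xs.append(xs[-1] + (t1 - t0) * f(t0, xs[-1]))
    return torch.stack(xs)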
In this tutorial, we use the torchdiffeq library, which provides a suite of ODE solvers and implements the adjoint method for gradient estimation.
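As a quick sanity check of the interface (a sketch; Decay is our own toy module implementing the scalar system from above, not part of torchdiffeq):
from torchdiffeq import odeint, odeint_adjoint

# the toy system dx/dt = -x wrapped as an nn.Module
class Decay(nn.Module):
    def forward(self, t, x):
        return -x

x0 = torch.ones(1, 1)            # [N,d]
ts = torch.linspace(0., 5., 50)  # [T]
with torch.no_grad():
    xs = odeint(Decay(), x0, ts, method='dopri5', rtol=1e-6, atol=1e-8)  # [T,N,d]
# the numerical solution matches the closed form x(t) = exp(-t)
assert torch.allclose(xs.squeeze(), torch.exp(-ts), atol=1e-4)
odeint_adjoint has the same signature and backpropagates through the solution via the adjoint method, trading extra computation (a backward ODE solve) for reduced memory; we return to this in the break on adjoints.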
1.2. Example: Van der Pol Oscillator
As an example, we examine the Van der Pol (VDP) oscillator, a parametric $2D$ time-invariant ODE system that evolves according to the following: \begin{equation} \label{eq:vdp} \frac{d}{dt} \begin{bmatrix} x_1 \\ x_2 \end{bmatrix} = \begin{bmatrix} x_2 \\ \mu(1-x_1^2)x_2-x_1 \end{bmatrix}. \end{equation}
Our VDP implementation below follows the two requirements of torchdiffeq:
- The integrated function must be an instance of `nn.Module`.
- The `forward()` function must take a (time, state) pair as input.
# define the differential function
class VDP(nn.Module):
def __init__(self,mu):
''' mu is the only parameter in VDP oscillator '''
super().__init__()
self.mu = mu
def forward(self, t, x):
''' Implements the right hand side
Inputs
t - [] time
x - [N,d] state(s)
Output
\dot{x} - [N,d], time derivative
'''
d1 = x[...,1:2]
d2 = self.mu*(1-x[...,0:1]**2)*x[...,1:2]-x[...,0:1]
return torch.cat([d1,d2],-1)
Next, we instantiate the three ingredients (differential function $\mathbf{f}$, initial value $\mathbf{x}_0$, integration time points $t$), forward integrate, and visualize how integration proceeds.
# create the differential function, which needs to be an nn.Module
vdp = VDP(1.0).to(device)
# initial value, of shape [N,d]
x0 = torch.tensor([[1.0,0.0]]).float().to(device)
# integration time points, of shape [T]
ts = torch.linspace(0., 15., 500).to(device)
# forward integration (no gradients needed for visualization)
with torch.no_grad():
    X = odeint(vdp, x0, ts) # [T,N,d]
# animation
anim = plot_vdp_animation(ts,X,vdp)
display.HTML(anim.to_jshtml())
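As a teaser for the upcoming break on integration parameters, it is instructive to swap solvers and tolerances and compare the resulting trajectories (a sketch; the method names below are standard torchdiffeq options):
# compare fixed-step and adaptive solvers on the same problem
with torch.no_grad():
    X_euler = odeint(vdp, x0, ts, method='euler')   # fixed-step, first order
    X_rk4   = odeint(vdp, x0, ts, method='rk4')     # fixed-step, fourth order
    X_ref   = odeint(vdp, x0, ts, method='dopri5',
                     rtol=1e-7, atol=1e-9)          # adaptive, used as reference
# worst-case deviations from the (accurate) adaptive solution
print((X_euler - X_ref).abs().max().item(), (X_rk4 - X_ref).abs().max().item())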