In this tutorial, we introduce neural ordinary differential equations (NODEs) [chen2018neural]. A one-sentence summary of this model family is

an ODE system in which the differential function is a neural network (NN).

We start this tutorial with a discussion of ODEs. Instead of presenting technical details, we give a practical introduction to ODEs. Next, we formally describe NODEs and show three standard use cases: classification, normalizing flows, and latent dynamics learning. The lecture closes with works that study different aspects of the vanilla NODE.

Organization of the Lecture¶

  Introduction (10min)

  1. Formal descriptions of ODEs (20min)
    1.1. Computing ODE solutions
    1.2. Example: Van der Pol oscillator
    1.3. Break: VDP & ODE integration parameters

  2. Neural ODEs (20min)
    2.1. Problem formulation
    2.2. Maximum likelihood estimation
    2.3. Example: learning VDP sequences with NODE
    2.4. Break: Adjoints

  3. Latent Bayesian Neural ODEs (20min)
    3.1. Variational Inference
    3.2. Evidence Lower-bound
    3.3. Example: Rotating MNIST
    3.4. Implementation

    Long Break (15min)

  4. ResNets are Discretized ODEs (20min)
    4.1. Classification Objective
    4.2. Implementation
    4.3. Training
    4.4. Break: ODE solver parameters

  5. Continuous-time Normalizing Flows (20min)
    5.1. Normalizing Flows
    5.2. Continuous-time Normalizing Flows
    5.3. Implementation
    5.4. Training
    5.5. Break: Wrap-up

  6. Related Studies (15min)
    6.1. ODE-RNN [rubanova2019latent]
    6.2. ODE$^2$VAE [yildiz2019deep]
    6.3. Augmented NODEs [dupont2019augmented]
    6.4. Regularized NODEs [finlay2020train]
    6.5. ACA [zhuang2020adaptive]
    6.6. ODE-RL [yildiz2021continuous]
    6.7. NSDEs [tzen2019neural], [xu2022infinitely]
    6.8. GP-ODEs [hegde2022variational]

  Summary & Q&A (5+25min)

NOTE: Most of the code snippets in this tutorial, as well as the figures, are from the original neural ODE paper and the corresponding GitHub repo.

Practicalities¶

  • Each section ends with a 5-10-min break in which you can read the provided material and/or code snippets, ask questions, or just take a rest. Feel free to arrange your breaks in accordance with your needs.
  • In addition to mathematical descriptions of the techniques, we provide short code snippets for the model definitions, training and visualization. Training could be too time-consuming for this session, so make sure to load the pre-trained models if you would like to visualize the fits.
  • Most of the implementation in this notebook depends on the provided utility files, some of which might be too involved to grasp immediately. If you're interested, go ahead and check them out.

The following cell imports all the required libraries.

In [53]:
%load_ext autoreload
%autoreload 2
!pip install torch torchvision torchdiffeq numpy scipy matplotlib pillow scikit-learn

import numpy as np
from IPython import display
import time
from sklearn.datasets import make_circles

import torch
import torch.nn as nn
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

from torchdiffeq  import odeint
from bnn          import BNN
from vae_utils    import MNIST_Encoder, MNIST_Decoder
from plot_utils   import plot_vdp_trajectories, plot_ode, plot_vdp_animation, plot_cnf_animation, \
    plot_mnist_sequences, plot_mnist_predictions, plot_cnf_data
from utils       import get_minibatch, mnist_loaders, inf_generator, mnist_accuracy, \
    count_parameters, conv3x3, group_norm, Flatten, load_rotating_mnist

1. Ordinary Differential Equations (ODEs)¶

Ordinary differential equations relate functions of an independent variable to the derivatives of those functions. Formally,

\begin{equation} \dot{\mathbf{x}}(t) = \frac{d\mathbf{x}(t)}{dt} = \lim_{\Delta t \rightarrow 0} \frac{ \mathbf{x}(t + \Delta t) - \mathbf{x}(t)}{\Delta t} = \mathbf{f}(\mathbf{x}(t),\mathbf{u}(t),t), \end{equation}

where

  • $t$ denotes time (or any other independent variable)
  • $\mathbf{x}(t) \in \mathcal{X} \subseteq \mathbb{R}^d$ is the state vector at time $t$ (the dependent variable)
  • $\mathbf{u}(t) \in \mathcal{A} \subseteq \mathbb{R}^m$ is the external control signal
  • $\dot{\mathbf{x}}(t) \in \dot{\mathcal{X}} \subseteq \mathbb{R}^d$ is the first-order time derivative of $\mathbf{x}(t)$
  • $\mathbf{f} : \mathcal{X} \times \mathcal{A} \times \mathbb{R}_+ \rightarrow \dot{\mathcal{X}}$ is the vector-valued, continuous-in-time differential function describing the system's evolution over time, with $\mathbb{R}_+$ denoting the non-negative real numbers.

Informally speaking, $\mathbf{f}$ tells us "how much the state $\mathbf{x}(t)$ changes with an infinitesimal change in $t$". More formally, the following equation holds in the limit $\Delta t \rightarrow 0$: \begin{equation} \mathbf{x}(t+\Delta t) = \mathbf{x}(t) + \Delta t \cdot \mathbf{f}(\mathbf{x}(t),\mathbf{u}(t),t). \end{equation}

Note-1: We often refer to $\mathbf{f}$ as the vector field or the right-hand side.
Note-2: The above problem, together with an initial value $\mathbf{x}(0)=\mathbf{x}_0$, is known as an initial value problem.
Note-3: Throughout this tutorial, we focus on differential functions $\mathbf{f}(\mathbf{x}(t))$ that are independent of control signals and not explicitly parameterized by time.

1.1. Computing ODE Solutions¶

An "ODE state solution" $\mathbf{x}(t)$ at time $t\in \mathbb{R}_+$ is given by \begin{equation} \mathbf{x}(t) = \mathbf{x}_0 + \int_0^t \mathbf{f}(\mathbf{x}_\tau)~d\tau, \end{equation} where $\mathbf{x}_0$ denotes the initial value and $\tau \in \mathbb{R}_+$ is an auxiliary time variable.

Note-1: Given an initial value $\mathbf{x}_0$ and a set of time points $\{t_0,t_1,\ldots,t_N\}$, we are often interested in the state solutions $\mathbf{x}_{0:N}\equiv\{\mathbf{x}(t_0),\mathbf{x}(t_1),\ldots,\mathbf{x}(t_N)\}$.
Note-2: We occasionally denote $\mathbf{x}_n \equiv \mathbf{x}(t_n)$.
Note-3: The above integral has a closed form only for rather simple differential functions (recall the integration rules from high school). Therefore, we almost always resort to numerical solvers.
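For instance (an illustrative aside), the linear ODE $\dot{x}(t) = a x(t)$ with initial value $x(0)=x_0$ admits the closed-form solution \begin{equation} x(t) = x_0 + \int_0^t a\,x(\tau)~d\tau = x_0 e^{at}, \end{equation} whereas even the two-dimensional Van der Pol system introduced below has no known closed-form solution.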

Numerical solvers: TL;DR: A state solution $\mathbf{x}(t)$ can be numerically computed up to a tolerable error.

The celebrated Picard existence and uniqueness theorem states that an initial value problem has a unique solution if the differential function satisfies a Lipschitz condition. Despite the uniqueness guarantee, there is no general recipe for computing the solution analytically; therefore, we often resort to numerical methods. The simplest and least accurate numerical method is Euler's method, which repeatedly applies the update equation above with a fixed step size (see the sketch below). More advanced methods such as Heun's method and the Runge-Kutta family of solvers compute averaged slopes by evaluating $\mathbf{f}(\mathbf{x}(t))$ at multiple locations (a speed vs. accuracy trade-off). Even more advanced adaptive-step solvers set the step size $\Delta t$ dynamically.
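Below is a minimal fixed-step Euler integrator (an illustrative sketch only; the tutorial itself relies on torchdiffeq). The differential function f follows the f(t, x) signature used throughout this notebook:

import torch

def euler_integrate(f, x0, ts):
    ''' Fixed-step Euler method (illustrative sketch).
        f  - callable f(t, x) returning dx/dt of shape [N,d]
        x0 - [N,d] initial value
        ts - [T]   1D tensor of time points
        Returns the states at the requested time points, shape [T,N,d].
    '''
    xs = [x0]
    for t0, t1 in zip(ts[:-1], ts[1:]):
        x = xs[-1]
        xs.append(x + (t1 - t0) * f(t0, x))  # x(t+dt) ≈ x(t) + dt * f(t, x(t))
    return torch.stack(xs)

For instance, euler_integrate(vdp, x0, ts) with the VDP system defined in the cells below would yield trajectories comparable to odeint, just less accurate on coarse time grids.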

In this tutorial, we use the torchdiffeq library, which also implements the adjoint method for gradient estimation.

1.2. Example: Van der Pol Oscillator¶

As an example, we examine the Van der Pol (VDP) oscillator, a parametric $2D$ time-invariant ODE system that evolves according to the following: \begin{equation} \label{eq:vdp} \frac{d}{dt} \begin{bmatrix} x_1 \\ x_2 \end{bmatrix} = \begin{bmatrix} x_2 \\ \mu(1-x_1^2)x_2-x_1 \end{bmatrix}. \end{equation}

Our VDP implementation below follows the two requirements of torchdiffeq:

  • The integrated function must be an instance of nn.Module.
  • Its forward() function must take a (time, state) pair as input.
In [2]:
# define the differential function
class VDP(nn.Module):
    
    def __init__(self,mu):
        ''' mu is the only parameter in VDP oscillator '''
        super().__init__()
        self.mu = mu
        
    def forward(self, t, x):
        ''' Implements the right hand side
            Inputs
                t - []     time
                x - [N,d]  state(s)
            Output
                \dot{x} - [N,d], time derivative
        '''
        d1 = x[...,1:2]
        d2 = self.mu*(1-x[...,0:1]**2)*x[...,1:2]-x[...,0:1]
        return torch.cat([d1,d2],-1)

Next, we instantiate the three ingredients (differential function $\mathbf{f}$, initial value $\mathbf{x}_0$, integration time points $t$), forward integrate, and visualize how integration proceeds.

In [3]:
# create the differential function, needs to be a nn.Module
vdp = VDP(1.0).to(device)

# initial value, of shape [N,n]
x0 = torch.tensor([[1.0,0.0]]).float().to(device)

# integration time points, of shape [T]
ts = torch.linspace(0., 15., 500).to(device)

# forward integration
with torch.no_grad():
    X = odeint(vdp, x0, ts) # [T,N,n]

# animation
anim = plot_vdp_animation(ts,X,vdp)
display.HTML(anim.to_jshtml())

1.3. Break: VDP & ODE Parameters¶

The Van der Pol oscillator has a single parameter, which is set to $\mu=1$ above. The cell below implements the same illustration, except that we plot instead of animate. Use this break to play around with the parameter $\mu$ and the initial value $\mathbf{x}_0$ to see how even the tiniest change affects the whole trajectory. Note that below we visualize two trajectories, as $\mathbf{x}_0$ contains two initial values.

In [4]:
# feel free to modify the parameter
vdp = VDP(5.0).to(device)

# feel free to try out different initial values
x0 = torch.tensor(
    [[-2.0,-3.0],[-2.0,3.0]]
).float().to(device)

# integration time points, of shape [T]
ts = torch.linspace(0., 15., 500).to(device)

# forward integration
with torch.no_grad():
    X = odeint(vdp, x0, ts) # [T,N,D]

plot_ode(ts,X,vdp)

2. Neural ODE (NODE)¶

To motivate neural ODEs, imagine that we observe a sequence $\mathbf{y}_{0:N}$ generated by a continuous-time system (the sequence could be measurements from a physical system, the motion of objects, the flow of electric current, substance rates in a chemical reaction, etc.). How can we find the time evolution of such observed systems?

  • If we know the underlying ODE (such as the VDP system), we can use statistics and optimization tools to estimate its parameters (e.g., $\mu$).
  • What if we know neither the functional form nor the parameters of the ODE system? Then we estimate the time evolution function with a function approximator (in practice, one could define GP-ODEs, linear-regression ODEs, kernel-regression ODEs, etc.).

2.1. Problem Formulation¶

In more concrete terms, let's say our dataset contains a noisy observed sequence $\mathbf{y}_{0:N}$

\begin{align} \mathbf{y}_n &= \mathbf{x}_n + \boldsymbol{\epsilon}_n, \qquad \boldsymbol{\epsilon}_n\sim\mathcal{N}(\mathbf{0},\sigma^2\mathbf{I}), \end{align}

where each observation is a perturbation of an unknown state $\mathbf{x}_n$ generated by an unknown underlying vector field $\mathbf{f}_\text{true}$

\begin{align} \mathbf{x}_n &= \mathbf{x}_0 + \int_0^{t_n} \mathbf{f}_\text{true}(\mathbf{x}_\tau)~d\tau. \end{align}

Our goal is to learn a neural network $\mathbf{f}_\mathbf{w}$ with parameters $\mathbf{w}$ that matches the unknown dynamics:

$$\mathbf{f}_\mathbf{w} \approx \mathbf{f}_\text{true}.$$

Let's start by implementing a NODE system. We use a simple multi-layer perceptron with two hidden layers. Since vector fields are smooth, we opt for the smooth ELU activation instead of ReLU.

In [55]:
class NODE(nn.Module):
    def __init__(self, d):
        ''' d - ODE dimensionality '''
        super().__init__()
        self._f = nn.Sequential(nn.Linear(d,200), 
                                nn.ELU(), 
                                nn.Linear(200,200), 
                                nn.ELU(), 
                                nn.Linear(200,d))
    
    def ode_rhs(self, t, x):
        ''' differential function = f(x)'''
        return self._f(x)
    
    def forward(self, ts, x0, method='dopri5'):
        ''' Forward integrates the NODE system and returns state solutions
            Input
                ts - [T]   time points
                x0 - [N,d] initial value
            Returns
                X  - [T,N,d] forward simulated states
        '''
        return odeint(self.ode_rhs, x0, ts, method=method)

Now, let's see what the forward trajectory $\mathbf{x}_{0:N}$ looks like when the differential function $\mathbf{f}_\mathbf{w}$ is an NN with randomly initialized weights $w_i \sim \mathbb{U}(-k,k)$. Here, $\mathbb{U}$ denotes the uniform distribution and $k=1/\sqrt{d_{\text{in}}}$, with $d_{\text{in}}$ the number of input features (the PyTorch default for nn.Linear). As you will see below, small random weights typically translate into smooth and small function outputs, making the initial trajectory smooth as well.

In [59]:
node = NODE(2).to(device)

# let's compute the integral of our neural net!
x0 = torch.tensor([[1.0,0.0]]).float().to(device)
ts = torch.linspace(0., 20., 1000).to(device)

X = node(ts,x0)
plot_ode(ts, X, node.ode_rhs)

2.2. Maximum Likelihood Estimation¶

The simplest approach to approximate the unknown vector field $\mathbf{f}_\text{true}$ is the maximum-likelihood estimation. Since we do not have access to the vector field $\mathbf{f}_\text{true}$, we propose to match the forward simulated states with the observations:

\begin{align} \min_\mathbf{w} ~~ \mathcal{L} = \frac{1}{2} \sum_n ||\mathbf{y}_n-\mathbf{x}_n||_2^2 \qquad \text{s.t.} \qquad \mathbf{x}_n = \mathbf{x}_0 + \int_0^{t_n} \mathbf{f}_\mathbf{w}(\mathbf{x}_\tau)~d\tau. \end{align}

Observe that forward simulated states $\mathbf{x}(t)$ are functions of NN parameters $\mathbf{w}$. In the following, we show the dependency explicitly by using $\mathbf{x}(t_n;\mathbf{w})$ instead of $\mathbf{x}_n$. The gradient of the loss wrt $\mathbf{w}$ can be computed by chain rule:

\begin{align} \frac{d\mathcal{L}}{d\mathbf{w}} = \sum_n (\mathbf{x}(t_n;\mathbf{w})-\mathbf{y}_n) \frac{d\mathbf{x}(t_n;\mathbf{w})}{d\mathbf{w}} \end{align}

The second term is the derivative of the forward simulated state $\frac{d\mathbf{x}(t_n;\mathbf{w})}{d\mathbf{w}}$ with respect to the vector field parameters $\mathbf{w}$. In other words, we need to differentiate through the ODE solver, which is not a straightforward task. This can be done with forward sensitivity or adjoint equations; both techniques compute the gradient by solving a second ODE system. Due to its lower memory footprint, the torchdiffeq library implements the latter.
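As a minimal sketch (with a hypothetical toy differential function, not the NODE class above), switching to adjoint-based gradients in torchdiffeq amounts to calling odeint_adjoint, which additionally requires the differential function to be an nn.Module so that its parameters can be collected:

from torchdiffeq import odeint_adjoint

class ODEFunc(nn.Module):
    ''' Toy differential function with the (t, x) signature expected by the solver. '''
    def __init__(self, d):
        super().__init__()
        self.net = nn.Sequential(nn.Linear(d, 64), nn.ELU(), nn.Linear(64, d))
    def forward(self, t, x):
        return self.net(x)

f_w = ODEFunc(2).to(device)
x0  = torch.tensor([[1.0, 0.0]]).to(device)
ts  = torch.linspace(0., 5., 100).to(device)

# Gradients are computed by solving a second (adjoint) ODE backward in time,
# so the solver's intermediate states need not be stored.
X    = odeint_adjoint(f_w, x0, ts)   # [T,N,d]
loss = (X**2).mean()
loss.backward()                      # populates gradients of f_w.parameters()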

2.3. Example: Learning VDP Sequences with NODE¶

Next, we test our NODE system on noisy VDP sequences. To generate data, we randomly pick 10 initial values and forward integrate all trajectories concurrently. Luckily, this only requires setting the initial values; the rest of the implementation stays the same.

In [57]:
# lets first generate data
vdp = VDP(1.0).to(device)
x0 = 6*torch.rand([10,2]).to(device) - 3 # 10 random initial values in [-3,3]
tvdp = torch.linspace(0., 10., 50).to(device)
with torch.no_grad():
    Xvdp = odeint(vdp, x0, tvdp)
    Yvdp = Xvdp + torch.randn_like(Xvdp)*0.1 # noisy observations with std 0.1

plot_vdp_trajectories(tvdp, Yvdp, vdp)
Plotting the first 3 data sequences.

We now train the model on the observed sequences. To speed up training, we optimize over random subsequences instead of the whole sequences (see the get_minibatch function in utils.py, sketched below).
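The provided utility lives in utils.py; a hypothetical minimal version of subsequence sampling (names and behavior are illustrative assumptions, not the provided implementation) could look like this:

def get_minibatch_sketch(ts, Y, tsub):
    ''' Illustrative subsequence sampler (NOT the provided get_minibatch).
        ts   - [T]     time points
        Y    - [T,N,d] observed sequences
        tsub - int     subsequence length
    '''
    T  = ts.shape[0]
    i0 = torch.randint(0, T - tsub + 1, ()).item()  # random start index
    # the differential function is time-invariant, so the absolute start time does not matter
    return ts[i0:i0+tsub], Y[i0:i0+tsub]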

In [60]:
# optimization loop
Niter  = 1000 # number of optimization iterations
tsub   = 11   # subsequence length in each minibatch

optimizer = torch.optim.Adam(node.parameters(),1e-3)
for i in range(Niter):
    optimizer.zero_grad()
    t_,Y_ = get_minibatch(tvdp, Yvdp, tsub=tsub)
    Xhat = node(t_, Y_[0]) # forward simulation
    loss = ((Xhat-Y_)**2).mean() # MSE
    loss.backward()
    optimizer.step()
    if i%50==0:
        Xhat = node(tvdp, Yvdp[0]) # forward simulation
        display.clear_output(wait=True)
        plot_ode(tvdp, Yvdp, node.ode_rhs, Xhat.detach())
KeyboardInterrupt (the training cell was interrupted manually; a pre-trained model is loaded below)

Finally, let's load and visualize a trained model.

In [61]:
state_dict = torch.load('etc/trained_node.pkl')
node.load_state_dict(state_dict)
node.eval()

Xhat = node(tvdp, Yvdp[0]) # forward simulation from the initial observations
plot_ode(tvdp, Yvdp, node.ode_rhs, Xhat.detach())

2.4. Break: NN Differential Function and/or Adjoints¶

For this break, we have two suggestions to look into:

  • If you would like to play around with the differential function, go ahead and try out shallower/deeper nets, other activations, smaller/larger weight initializations, etc.
  • If you are more into theory, take a look at adjoints, i.e., the ODEs that give us the gradients of an ODE system. You can read Section 2.1 of this tutorial or Sections 1 and 3 of this technical report for a derivation of adjoints; the key equations are summarized below.
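For reference, the adjoint method of the original NODE paper introduces the adjoint state $\mathbf{a}(t) = \partial \mathcal{L} / \partial \mathbf{x}(t)$ and obtains the gradients by solving a second ODE backward in time:

\begin{align} \frac{d\mathbf{a}(t)}{dt} = -\mathbf{a}(t)^\top \frac{\partial \mathbf{f}_\mathbf{w}(\mathbf{x}(t))}{\partial \mathbf{x}(t)}, \qquad \frac{d\mathcal{L}}{d\mathbf{w}} = -\int_{t_N}^{t_0} \mathbf{a}(t)^\top \frac{\partial \mathbf{f}_\mathbf{w}(\mathbf{x}(t))}{\partial \mathbf{w}} ~dt, \end{align}

so the memory cost of backpropagation does not grow with the number of solver steps.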

3. Latent Bayesian Neural ODEs (ODEVAE)¶

All the ODE systems we have investigated so far are defined in data space, i.e., the data and the differential equation live in the same space. As an example where this modeling choice breaks down, consider the video of a flying ball. The motion of the ball can surely be explained by an ODE; however, the observations themselves (pixels) do not follow any simple ODE. To handle such cases, a reasonable modeling choice is to simultaneously learn an embedding of the videos into a latent space and a latent ODE system that explains the motion.

A suitable generative model for a given high-dimensional observed sequence $\mathbf{y}_{0:N}$ could be as follows:

\begin{align} \mathbf{z}_0 &\sim p(\mathbf{z}_0) \\ \mathbf{z}_n &= \mathbf{z}_0 + \int_0^{t_n} \mathbf{f}_\text{true}(\mathbf{z}_\tau) d\tau \\ \mathbf{y}_n &\sim p(\mathbf{y}_n | \mathbf{z}_n), \quad \forall n \in [0,N] \end{align}

where $\mathbf{z}_n$ corresponds to the latent embedding of $\mathbf{y}_n$. The unknowns are

  • the initial value for each sequence
  • the latent dynamics
  • the observation mapping.

[Figure: the latent NODE (ODE-VAE) generative model]

3.1. Variational Inference¶

As before, we propose to infer the unknown dynamics $\mathbf{f}_\text{true}$ by a NODE system $\mathbf{f}_\mathbf{w}$. This time, our goal is to maintain uncertainty estimates over both the initial value and ODE dynamics. For this, we resort to VI with the following approximations:

  • amortized inference (encoder) to approximate the initial value distribution $q(\mathbf{z}_0|\mathbf{y}_{0:N})$ for an input sequence $\mathbf{y}_{0:N}$
  • mean-field inference $q(\mathbf{w})$ for the dynamics parameters
  • a decoder $\mathbf{d}(\mathbf{z}_n)$ that gives the parameters of the observation mapping $p(\mathbf{y}_n | \mathbf{z}_n)$.

In turn, the resulting formulation becomes a hybrid ODE-VAE model. Our variational posterior factorizes as follows:

$$ q(\mathbf{z}_0,\mathbf{w}|\mathbf{y}_{0:N}) = q(\mathbf{z}_0|\mathbf{y}_{0:N}) ~ q(\mathbf{w}),$$

where both distributions are assumed to be Gaussian with diagonal covariance.
Remark-1: Extensions to multiple sequences would require variational posteriors for all initial values $\{\mathbf{z}_{0}^{(r)}\} _{r=1}^R$.

Remark-2: Our variational formulation corresponds to having a Bayesian NN differential function, i.e., a BNODE. The stochasticity of BNNs (each evaluation of a BNN on the same input would give a different output) violates the ODE definition (which requires the differential function to be continuous). Therefore, our framework first draws a sample of the differential function, and then uses that function draw to solve the ODE system.

3.2. Evidence Lower-bound¶

Following the standard ELBO derivations, we end up at the following bound:

\begin{align} \log p(\mathbf{y}_{0:N}) \geq \sum_n \mathbb{E}_{q(\mathbf{z}_0,\mathbf{w}|\mathbf{y}_{0:N})}[\log p(\mathbf{y}_n|\mathbf{z}_0,\mathbf{w})] - \text{KL}(q(\mathbf{z}_0 | \mathbf{y}_{0:N}) || p(\mathbf{z}_0)) - \text{KL}(q(\mathbf{w}) || p(\mathbf{w})). \end{align}

Thanks to Gaussian posteriors, KL terms are tractable. The intractable expected log-likelihood is approximated by Monte Carlo sampling:

\begin{align} \mathbb{E}_{q(\mathbf{z}_0,\mathbf{w}|\mathbf{y}_{0:N})}[\log p(\mathbf{y}_{0:N} |\mathbf{z}_0,\mathbf{w})] \approx \frac{1}{L} \sum_{l=1}^L \sum_{n=0}^N \log p(\mathbf{y}_n|\mathbf{z}_0^{(l)},\mathbf{w}^{(l)}). \end{align}

The following procedure specifies how to compute the likelihood given the samples $\mathbf{z}_0^{(l)}$ and $\mathbf{w}^{(l)}$:

  1. Drawing an initial value and a vector field sample \begin{align} \mathbf{z}_0^{(l)} &\sim q(\mathbf{z}_0|\mathbf{y}_{0:N}) \\ \mathbf{w}^{(l)} &\sim q(\mathbf{w}) \end{align}

  2. Forward simulating \begin{align} \mathbf{z}_n^{(l)} = \mathbf{z}_0^{(l)} + \int_0^{t_n} \mathbf{f}_{\mathbf{w}^{(l)}}(\mathbf{z}_\tau)~d\tau \end{align}

  3. Decoding \begin{align} \mathbf{x}_n^{(l)} &\equiv \mathbf{d}(\mathbf{z}_n^{(l)}), \quad \forall n \in [0,N]. \end{align}

Remarks:

  1. We consider a mean-field approximation for differential function parameters.
  2. Initial value distribution $q(\mathbf{z}_0|\mathbf{y}_{0:N})$ is also a diagonal Gaussian whose mean and variance parameters are given by the encoder NN.
  3. The ELBO is jointly optimized wrt encoder, bnode and decoder parameters.

3.3. Example Dataset: Rotating MNIST¶

In the following example, our dataset consists of rotating images of the MNIST digit 3. Since each pixel value is restricted to $[0,1]$, we opt for a Bernoulli observation model instead of a Gaussian:

$$\log p(\mathbf{y} | \mathbf{x}) = \sum_n y_n\log x_n + (1-y_n)\log(1-x_n), \qquad \mathbf{x}=\mathbf{d}(\mathbf{z}),$$

where index $n$ denotes the observation dimensions (not the time index).
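As a small sanity check (a sketch, not part of the original notebook), this Bernoulli log-likelihood is exactly the negative of PyTorch's sum-reduced binary cross-entropy, which is what the ODEVAE class below uses via nn.BCELoss(reduction='sum'):

# Hypothetical check that sum-reduced BCE equals the negative Bernoulli log-likelihood
y = torch.rand(4, 28*28)                        # "observations" in [0,1]
x = torch.rand(4, 28*28).clamp(1e-6, 1 - 1e-6)  # decoder outputs in (0,1)
loglik = (y*torch.log(x) + (1 - y)*torch.log(1 - x)).sum()
bce    = nn.BCELoss(reduction='sum')(x, y)
print(torch.allclose(loglik, -bce))             # True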

The following cell reads the dataset.

In [62]:
# we read 1042 sequences of length 16, where each observation is a 28x28 grey-scale image
Ymnist_tr, Ymnist_test = load_rotating_mnist(device) # [T,N,1,28,28]
plot_mnist_sequences(Ymnist_tr)

# let's create artificial time points corresponding to rotation angles <===> T=16
tmnist = 0.1*torch.arange(16).to(device) 
Plotting 5 rotating MNIST sequences.

3.4. Implementation¶

We now implement our ODEVAE class. If you would like to learn more about the encoder and decoder implementation details, check out vae_utils.py.

In [63]:
from torch.distributions import Normal, kl_divergence

class ODEVAE(nn.Module):
    def __init__(self, q, n_filt=16):
        ''' Inputs:
                q      - latent dimensionaliy
                n_filt - number of filters in the first CNN layer
        '''
        super().__init__()
        self.encoder  = MNIST_Encoder(q, n_filt)
        self.bnode    = BNN(n_in=q, n_out=q, n_hid_layers=2, n_hidden=100, act='elu')
        self.decoder  = MNIST_Decoder(q, n_filt)
        self.obs_loss = nn.BCELoss(reduction='sum')
        self.q        = q
        
    def forward(self, ts, Y, method='dopri5'):
        ''' Performs encoding, latent forward integration and decoding.
            Note that we always draw a single sample from the encoder to improve the readability of our code.
            Inputs:
                ts - [T]           observation time points
                Y  - [T,N,1,28,28] input sequences
            Returns:
                q_z0_mu  - [N,q]           initial value means
                q_z0_sig - [N,q]           initial value std
                zt       - [T,N,q]       latent trajectory
                Xhat     - [T,N,1,28,28] reconstructions
        '''
        [T,N,nc,d,d] = Y.shape
        # encode mean and variance
        q_z0_mu, q_z0_sig = self.encoder(Y) # N,q & N,q
        # sample differential function
        f = self.bnode.draw_f()
        ode_rhs = lambda t,x: f(x)
        # sample initial values
        z0 = q_z0_mu + q_z0_sig*torch.randn_like(q_z0_sig)
        # forward integrate
        zt = odeint(ode_rhs, z0, ts, method=method) # T,N,q
        # decode
        Xhat = self.decoder(zt) # T,N,nc,d,d
        return q_z0_mu, q_z0_sig, zt, Xhat

odevae = ODEVAE(q=8).to(device)

Now let's implement the ELBO.

In [64]:
def compute_elbo(odevae, ts, Y):
    ''' Computes the ELBO.
        Note that we always draw a single sample from the encoder to improve the readability of our code.
        Inputs:
            ts - [T] observation time points
            Y  - [T,N,1,28,28] input sequences
        Returns:
            rec    - [] expected log likelihood
            kl_enc - [] the KL term due to z_0 
            kl_bnn - [] the KL term due to bnn weights w
    '''
    q_z0_mu, q_z0_sig, zt, Xhat = odevae(ts, Y)
    # reconstruction
    rec = -odevae.obs_loss(Xhat,Y)
    # KL divergence on z_0
    q_z0_mu, q_z0_sig = q_z0_mu.reshape(-1), q_z0_sig.reshape(-1)
    q = Normal(q_z0_mu,q_z0_sig)
    N = Normal(torch.zeros_like(q_z0_mu),torch.ones_like(q_z0_sig))
    kl_enc = kl_divergence(q,N).sum()
    # KL divergence on bnn weights
    kl_bnn = odevae.bnode.kl()
    return rec, kl_enc, kl_bnn

We finally train the model.

In [65]:
Nsub  = 25  # number of sequences in each minibatch
C     = Ymnist_tr.shape[1] / Nsub # scaling factor (total number of sequences / minibatch size)
Niter = 2000

optimizer = torch.optim.Adam(odevae.parameters(), 1e-3)

for i in range(Niter):
    optimizer.zero_grad()
    t_,Y_ = get_minibatch(tmnist, Ymnist_tr, Nsub=Nsub)
    rec, kl_enc, kl_bnn = compute_elbo(odevae, t_, Y_)
    rec  = rec*C 
    kl   = kl_enc*C + kl_bnn
    loss = -rec + kl
    loss.backward()
    optimizer.step()
    if i%25==0:
        with torch.no_grad():
            t_,Y_ = get_minibatch(tmnist, Ymnist_tr, Nsub=5)
            q_z0_mu, q_z0_sig, zt, Xhat = odevae(t_,Y_)
            display.clear_output(wait=True)
            plot_mnist_predictions(Y_, zt, Xhat)
Plotting 5 rotating MNIST sequences (top rows) and corresponding predictions (bottom).
KeyboardInterrupt (the training cell was interrupted manually; a pre-trained model is loaded below)

The following cells import a trained model and then plot the training and test predictions. We first visualize PCA embeddings of the latent trajectories $\mathbf{z}_{0:N}$, where each color corresponds to the embedding of one sequence. Note that a single sample is drawn from the encoder and BNN, i.e., $L=1$. We then visualize five sequences and corresponding predictions.

In [30]:
# load a trained model
state_dict = torch.load('etc/trained_odevae.pkl')
odevae.load_state_dict(state_dict)
odevae.eval();

t_,Y_ = get_minibatch(tmnist, Ymnist_tr, Nsub=5)
q_z0_mu, q_z0_sig, zt, Xhat = odevae(t_,Y_)
plot_mnist_predictions(Y_, zt, Xhat)
Plotting 5 rotating MNIST sequences (top rows) and corresponding predictions (bottom).
In [31]:
t_,Y_ = get_minibatch(tmnist, Ymnist_test, Nsub=5)
q_z0_mu, q_z0_sig, zt, Xhat = odevae(t_,Y_)
plot_mnist_predictions(Y_, zt, Xhat)
Plotting 5 rotating MNIST sequences (top rows) and corresponding predictions (bottom).

15-MIN BREAK¶

My fav online radio: https://radyobozcaada.com/player/index.html

4. ResNets are Discretized ODEs¶

So far, we have examined NODEs from a dynamical systems standpoint. We showed that a NODE is an instance of ODE models in which the differential function is a neural network. Thanks to the universal approximation property of neural networks, NODEs can in principle approximate any ODE system.

Our presentation is orthogonal to that of the original NODE paper, which describes the model starting from Residual Networks (ResNets). ResNet is among the first "very deep" architectures for classification problems. In a nutshell, ResNets consist of layers with skip connections, leading to the following transformation of the hidden state $\mathbf{x}_n$ at layer $n$: $$ \mathbf{x}_{n+1} = \mathbf{x}_n + \mathbf{f}(\mathbf{x}_n;\theta_n),$$ where $\theta_n$ denotes the parameters at layer $n$. As we showed previously, this update equation is equivalent to computing an ODE solution with Euler's method and a fixed time increment $\Delta t = 1$: $$\mathbf{x}_{n+1} = \mathbf{x}_{n} + \Delta t \cdot \mathbf{f}(\mathbf{x}_{n},t_n;\theta).$$ Therefore, we can interpret a ResNet as a rough approximation of a NODE with fixed time increments (see the small check below). In the following, we show how ResNets can be trivially replaced by their ODE counterpart, dubbed "ODE Networks" (ODENets). Since we use adaptive-step ODE solvers, which can evaluate the state at any point in time, ODENets can be interpreted as infinitely deep.
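As a quick numerical check (a sketch with a toy layer, not part of the original notebook), a single residual update with a shared differential function coincides with one Euler step of step size $\Delta t=1$:

# Hypothetical toy check: the residual update x + f(x) equals one Euler step of dx/dt = f(x)
f = nn.Linear(2, 2).to(device)               # toy shared "residual branch"
x = torch.randn(1, 2).to(device)

x_res   = x + f(x)                           # one residual block
x_euler = odeint(lambda t, s: f(s), x,       # one Euler step over [0, 1]
                 torch.tensor([0., 1.]).to(device),
                 method='euler', options={'step_size': 1.0})[-1]
print(torch.allclose(x_res, x_euler))        # True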

[Figure: ResNet vs. ODE network, from the original NODE paper]

Note: ResNets typically have different parameters $\theta_n$ at each layer $n$. A similar parameterization for NODEs can be achieved by explicitly parameterizing the differential function by time $t$ (as the ConcatConv2d layers below do).

4.1. Classification Objective¶

Now we formulate the classification objective. Given a dataset of images and labels $\{(\mathbf{x}_n,\mathbf{y}_n)\}_{n=1}^N$, we apply the following chain of transformations:

  • downsampling (to extract relevant features)
  • feature transformation (implemented by ResNets/ODENets)
  • fully connected layers (to map transformed features into class labels)

\begin{align} \min_\mathbf{w} ~~ \sum_n \texttt{cross_entropy}(\mathbf{y}_n,\hat{\mathbf{y}}_n) \qquad \text{s.t.} \qquad \hat{\mathbf{y}}_n = \mathbf{f}_{\text{fc}}(\mathbf{f}_{\text{trans}}(\mathbf{f}_{\text{down}}(\mathbf{x}_n))). \end{align}

4.2. Implementation¶

We start our implementation with residual networks.

In [66]:
class ResNet(nn.Module):
    def __init__(self, num_blocks, inplanes, planes, stride=1):
        super(ResNet, self).__init__()
        self.blocks = nn.Sequential(*[ResNetBlock(inplanes, planes) for _ in range(num_blocks)])
    
    def forward(self,x):
        return self.blocks(x)

class ResNetBlock(nn.Module):
    def __init__(self, inplanes, planes, stride=1):
        super(ResNetBlock, self).__init__()
        self.net = nn.Sequential(group_norm(inplanes), 
                                nn.ReLU(inplace=True), 
                                conv3x3(inplanes, planes, stride), 
                                group_norm(planes), 
                                nn.ReLU(inplace=True),
                                conv3x3(planes, planes))

    def forward(self, x):
        shortcut = x
        net_out  = self.net(x)
        return net_out + shortcut

Next, we implement the neural ODE block. Similar to the previous section, we only implement the differential function and forward integrate. Notable differences from the time-series fitting example:

  • We are only interested in the final state of the ODE system (intermediate states are not important).
  • The integration time points are arbitrary (here fixed to $[0,1]$).
  • Inside the differential function, we concatenate the states with the current time stamp. This way, we learn a time-dependent and hence more powerful differential function (since the differential function evaluated at two different time points can differ).
In [67]:
class NODE(nn.Module):
    def __init__(self, dim):
        super(NODE, self).__init__()
        self.norm1 = group_norm(dim)
        self.relu  = nn.ReLU(inplace=True)
        self.conv1 = ConcatConv2d(dim, dim, 3, 1, 1)
        self.norm2 = group_norm(dim)
        self.conv2 = ConcatConv2d(dim, dim, 3, 1, 1)
        self.norm3 = group_norm(dim)
        self.integration_time = torch.tensor([0, 1]).float()

    def ode_rhs(self, t, x):
        out = self.norm1(x)
        out = self.relu(out)
        out = self.conv1(t, out)
        out = self.norm2(out)
        out = self.relu(out)
        out = self.conv2(t, out)
        out = self.norm3(out)
        return out

    def forward(self, x, method='dopri5'):
        ''' Forward integrates the NODE system and returns state solutions
            Input
                x   - [N, num_filt, w, c] initial value
            Returns
                out - [N, num_filt, w, c] the final state of the ODE system
        '''
        self.integration_time = self.integration_time.type_as(x)
        # we solve the ODE system with looser tolerances (larger permissible error) for faster computation
        out = odeint(self.ode_rhs, x, self.integration_time, method=method, rtol=1e-3, atol=1e-6)
        return out[-1]

    
class ConcatConv2d(nn.Module):
    ''' Convolutional layers that use current time stamp information '''

    def __init__(self, dim_in, dim_out, ksize=3, stride=1, padding=0, dilation=1, groups=1, bias=True):
        super(ConcatConv2d, self).__init__()
        self._layer = nn.Conv2d(dim_in + 1, dim_out, kernel_size=ksize, stride=stride, padding=padding, 
                                dilation=dilation, groups=groups, bias=bias)

    def forward(self, t, x):
        tt  = torch.ones_like(x[:, :1, :, :]) * t
        ttx = torch.cat([tt, x], 1)  
        return self._layer(ttx)

4.3. Training¶

Next, we create the downsampling, feature transformation and final classification layers.

In [68]:
trans_layer = 'odenet' # can be replaced with 'resnet'
num_filt    = 16

# downsampling
downsampling_layers = [
    nn.Conv2d(1, num_filt, 3, 1),
    group_norm(num_filt),
    nn.ReLU(inplace=True),
    nn.Conv2d(num_filt, num_filt, 4, 2, 1),
    group_norm(num_filt),
    nn.ReLU(inplace=True),
    nn.Conv2d(num_filt, num_filt, 4, 2, 1),
]

# feature transformation
if trans_layer=='odenet':
    feature_layers = NODE(num_filt)
else:
    feature_layers = ResNet(6, num_filt, num_filt)

    
# fully connected layer
fc_layers = [group_norm(num_filt), nn.ReLU(inplace=True), nn.AdaptiveAvgPool2d((1, 1)), Flatten(), nn.Linear(num_filt, 10)]

model = nn.Sequential(*downsampling_layers, feature_layers, *fc_layers).to(device)
print('Number of parameters: {}'.format(count_parameters(model)))
Number of parameters: 13674

We finally load the data and start training.

In [69]:
lr       = 0.1
niters   = 1000
batch_size  = 100
print_every = 10
test_every  = 100

train_loader, test_loader, train_eval_loader = mnist_loaders(batch_size)
data_gen = inf_generator(train_loader)

optimizer = torch.optim.SGD(model.parameters(), lr=lr, momentum=0.9)
criterion = nn.CrossEntropyLoss().to(device)

start_time = time.time()

for itr in range(1,niters):
    optimizer.zero_grad()
    x, y = data_gen.__next__()
    x = x.to(device)
    y = y.to(device)
    logits = model(x)
    loss   = criterion(logits, y)
    loss.backward()
    optimizer.step()
    
    # print the train trace
    if itr % print_every == 0:
        end_time = time.time()
        print("Iter {:04d} | Time {:.3f} | loss {:.4f}".format(itr, end_time-start_time, loss.item()))
        start_time = time.time()
        
    # print the test trace
    if itr % test_every == 0:
        with torch.no_grad():
            val_acc   = mnist_accuracy(model, device, test_loader)
            train_acc = mnist_accuracy(model, device, train_eval_loader)
            print("Iter {:04d} | Train Acc {:.4f} | Test Acc {:.4f}".format(itr, train_acc, val_acc))
            start_time = time.time()
Iter 0010 | Time 6.542 | loss 2.3096
Iter 0020 | Time 4.249 | loss 2.2972
Iter 0030 | Time 3.997 | loss 2.3011
Iter 0040 | Time 3.802 | loss 2.2980
Iter 0050 | Time 4.125 | loss 2.2560
Iter 0060 | Time 3.900 | loss 2.2598
Iter 0070 | Time 4.135 | loss 2.2707
Iter 0080 | Time 3.838 | loss 2.2191
Iter 0090 | Time 3.869 | loss 2.2175
Iter 0100 | Time 4.226 | loss 2.1295
Iter 0100 | Train Acc 0.2241 | Test Acc 0.2304
Iter 0110 | Time 4.251 | loss 2.1413
Iter 0120 | Time 4.707 | loss 2.0850
Iter 0130 | Time 4.607 | loss 1.9517
Iter 0140 | Time 4.918 | loss 1.9446
Iter 0150 | Time 4.886 | loss 1.8859
Iter 0160 | Time 5.123 | loss 1.7849
Iter 0170 | Time 4.808 | loss 1.8127
Iter 0180 | Time 3.504 | loss 1.6171
Iter 0190 | Time 5.056 | loss 1.4815
Iter 0200 | Time 4.927 | loss 1.3827
Iter 0200 | Train Acc 0.5300 | Test Acc 0.4915
Iter 0210 | Time 4.652 | loss 1.4271
Iter 0220 | Time 5.089 | loss 1.2138
Iter 0230 | Time 5.435 | loss 1.4058
Iter 0240 | Time 4.788 | loss 1.3113
Iter 0250 | Time 5.520 | loss 1.1210
Iter 0260 | Time 4.495 | loss 1.2944
Iter 0270 | Time 4.125 | loss 1.0757
Iter 0280 | Time 4.546 | loss 0.9276
Iter 0290 | Time 4.931 | loss 0.8121
Iter 0300 | Time 4.986 | loss 0.8438
Iter 0300 | Train Acc 0.7711 | Test Acc 0.7578
Iter 0310 | Time 5.618 | loss 0.6255
Iter 0320 | Time 5.369 | loss 0.6043
Iter 0330 | Time 4.268 | loss 0.9336
Iter 0340 | Time 4.155 | loss 0.6923
Iter 0350 | Time 3.723 | loss 0.5035
Iter 0360 | Time 4.180 | loss 0.5041
Iter 0370 | Time 4.250 | loss 0.5891
Iter 0380 | Time 4.120 | loss 0.5354
Iter 0390 | Time 4.771 | loss 0.4401
Iter 0400 | Time 3.975 | loss 0.3725
Iter 0400 | Train Acc 0.9137 | Test Acc 0.8963
Iter 0410 | Time 5.455 | loss 0.4050
Iter 0420 | Time 5.021 | loss 0.4614
Iter 0430 | Time 5.138 | loss 0.3602
Iter 0440 | Time 5.415 | loss 0.3874
Iter 0450 | Time 4.886 | loss 0.3347
Iter 0460 | Time 5.736 | loss 0.3265
Iter 0470 | Time 5.647 | loss 0.3185
Iter 0480 | Time 5.115 | loss 0.3498
Iter 0490 | Time 5.542 | loss 0.4387
Iter 0500 | Time 5.116 | loss 0.2702
Iter 0500 | Train Acc 0.9478 | Test Acc 0.9344
Iter 0510 | Time 5.081 | loss 0.2277
Iter 0520 | Time 4.992 | loss 0.2838
Iter 0530 | Time 5.295 | loss 0.1582
Iter 0540 | Time 4.998 | loss 0.1799
Iter 0550 | Time 5.261 | loss 0.1326
Iter 0560 | Time 4.771 | loss 0.1627
Iter 0570 | Time 5.035 | loss 0.2268
Iter 0580 | Time 4.707 | loss 0.1679
Iter 0590 | Time 5.040 | loss 0.2313
Iter 0600 | Time 4.861 | loss 0.3063
Iter 0600 | Train Acc 0.9611 | Test Acc 0.9530
Iter 0610 | Time 6.264 | loss 0.1455
Iter 0620 | Time 5.131 | loss 0.0863
Iter 0630 | Time 5.089 | loss 0.3072
Iter 0640 | Time 4.739 | loss 0.2373
Iter 0650 | Time 4.729 | loss 0.2867
Iter 0660 | Time 3.876 | loss 0.1067
Iter 0670 | Time 4.914 | loss 0.2129
Iter 0680 | Time 4.827 | loss 0.1498
Iter 0690 | Time 4.755 | loss 0.1544
Iter 0700 | Time 5.183 | loss 0.1400
Iter 0700 | Train Acc 0.9578 | Test Acc 0.9385
Iter 0710 | Time 4.593 | loss 0.1683
Iter 0720 | Time 4.276 | loss 0.1603
Iter 0730 | Time 4.959 | loss 0.1810
Iter 0740 | Time 5.126 | loss 0.1254
Iter 0750 | Time 5.013 | loss 0.1094
Iter 0760 | Time 4.967 | loss 0.1577
Iter 0770 | Time 4.986 | loss 0.1577
Iter 0780 | Time 4.987 | loss 0.1373
Iter 0790 | Time 5.078 | loss 0.1809
Iter 0800 | Time 4.979 | loss 0.2024
Iter 0800 | Train Acc 0.9681 | Test Acc 0.9556
Iter 0810 | Time 4.789 | loss 0.1155
Iter 0820 | Time 4.343 | loss 0.1994
Iter 0830 | Time 4.870 | loss 0.0731
Iter 0840 | Time 4.715 | loss 0.1313
Iter 0850 | Time 4.789 | loss 0.1689
Iter 0860 | Time 5.046 | loss 0.1357
Iter 0870 | Time 5.221 | loss 0.1015
Iter 0880 | Time 4.790 | loss 0.2174
Iter 0890 | Time 4.182 | loss 0.0859
Iter 0900 | Time 3.657 | loss 0.1786
Iter 0900 | Train Acc 0.9715 | Test Acc 0.9670
Iter 0910 | Time 3.574 | loss 0.1121
Iter 0920 | Time 4.230 | loss 0.1827
Iter 0930 | Time 5.000 | loss 0.2199
Iter 0940 | Time 4.898 | loss 0.0507
Iter 0950 | Time 4.963 | loss 0.0881
Iter 0960 | Time 4.880 | loss 0.1290
Iter 0970 | Time 5.087 | loss 0.1222
Iter 0980 | Time 4.660 | loss 0.1613
Iter 0990 | Time 5.193 | loss 0.0820

4.4. Break: ODE Solver Parameters¶

Our continuous-time classification algorithm relies on solving an intermediate ODE system. To solve the ODE system, we use an adaptive-step ODE solver named dopri5 (Dormand-Prince, also known as RK45). Just like Euler's method, dopri5 takes a finite number of steps to compute the state solutions, but this time the step size $\Delta t$ adaptively changes at every step. Notice that taking small steps (= small $\Delta t$) leads to more accurate solutions, at the expense of taking more steps (= higher execution time).

Adaptive-step solvers control $\Delta t$ based on the local discretization error. As you can see in NODE.forward(), we pass tolerance values (rtol and atol) to the ODE solver. Roughly speaking, these values control "how much error we can live with". In this break, you can study how changing these values affects the execution time and overall performance (please do not forget to re-run all three cells above for testing); a small timing sketch is given below. You can also check out the scipy RK45 function to learn more about the tolerances.
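As a starting point (a sketch that assumes a VDP differential function vdp, an initial value x0, and time points ts from earlier cells are still in memory), the following loop times odeint under different tolerances:

import time

# Hypothetical timing sketch: looser tolerances => fewer solver steps => faster solves
for rtol in [1e-3, 1e-6, 1e-9]:
    tic = time.time()
    with torch.no_grad():
        X = odeint(vdp, x0, ts, method='dopri5', rtol=rtol, atol=rtol*1e-3)
    print('rtol={:.0e}  wall time={:.3f}s'.format(rtol, time.time()-tic))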

5. Continuous-time Normalizing Flows¶

5.1. Normalizing Flows¶

Next, we turn to the second application of NODEs, namely continuous-time normalizing flows (CNFs). Normalizing flows (NFs) build complex distributions by transforming a simple probability distribution through a fixed number of invertible mappings. In the context of NFs, the change-of-variables theorem describes the change in the density of a random variable $z_0$ under a deterministic, invertible transformation $f(z_0)$: $$ z_1 = f(z_0), \qquad \log p(z_1) = \log p(z_0) - \log \left| \text{det} \frac{\partial f}{\partial z_0} \right|. $$

To increase the expressiveness of the transformations, we chain several transformations: $$ z_K = f_{K-1} \circ f_{K-2} \circ \ldots \circ f_0 (z_0), \qquad \log p(z_K) = \log p(z_0) - \sum_{k=0}^{K-1} \log \left| \text{det} \frac{\partial f_k}{\partial z_{k}} \right|. $$

Here, we transform samples $z_0 \sim p(z_0)$ from a simple base distribution (such as standard Gaussian) into a more complex distribution $p(z_K)$. Once the transformations are known, we can compute any expectation $\mathbb{E}_{p(z_K)}[h(z_K)]$ as follows:

$$\mathbb{E}_{p(z_K)}[h(z_K)] = \mathbb{E}_{p(z_0)}[h(f_{K-1} \circ f_{K-2} \circ \ldots \circ f_0 (z_0))]$$
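As a tiny illustration (a sketch, not from the original notebook), for a single affine transformation $z_1 = a z_0 + b$ the formula reduces to $\log p(z_1) = \log p(z_0) - \log|a|$, which we can verify against torch.distributions:

# Hypothetical one-step flow: an affine transform of a standard normal
a, b = 2.0, 1.0
base = torch.distributions.Normal(0., 1.)
z0 = base.sample((5,))
z1 = a*z0 + b
logp_z1 = base.log_prob(z0) - torch.log(torch.tensor(abs(a)))
# analytically z1 ~ N(b, a^2), so we can check against the exact density
print(torch.allclose(logp_z1, torch.distributions.Normal(b, abs(a)).log_prob(z1)))  # True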

5.2. Continuous-time Normalizing Flows¶

Similar to ODENets, CNFs are the continuous-time counterpart of standard NFs. That is, a simple base distribution $p(z(t_0))$ is transformed into a more complex $p(z(t_1))$ via an ODE flow:

[Figure: a discrete normalizing flow vs. a continuous-time normalizing flow]

Since we replaced the finite set of transformations with an ODE, the above formula for the change in density no longer applies. Here, the so-called instantaneous change of variables theorem kicks in. Given a random variable $z(t)$ whose density $p(z(t))$ depends on time, the change in log density due to a continuous-time transformation $\frac{dz}{dt} = f(z(t),t)$ has the following expression: $$ \frac{\partial \log p(z(t))}{\partial t} = -\text{tr} \left( \frac{df}{dz(t)} \right), $$ where $\text{tr}$ refers to the trace operator. Then we have

$$ \log p(z(t_1)) = \log p(z(t_0)) - \int_{t_0}^{t_1} \text{tr} \left( \frac{df}{dz(\tau)} \right) d\tau $$

for the following ODE system:

$$ z(t_1) = z(t_0) + \int_{t_0}^{t_1} f(z(\tau),\tau)~d\tau$$

Below is the implementation of the trace operator:

In [70]:
def trace_df_dz(f, z):
    """Calculates the trace of the Jacobian df/dz.
    Stolen from: https://github.com/rtqichen/ffjord/blob/master/lib/layers/odefunc.py#L13
    Input:
        f - function output [N,d]
        z - current state [N,d]
    Returns:
        tr(df/dz) - [N]
    """
    sum_diag = 0.
    for i in range(z.shape[1]):
        sum_diag += torch.autograd.grad(f[:, i].sum(), z, create_graph=True)[0].contiguous()[:, i].contiguous()
    return sum_diag.contiguous()

5.3. Implementation¶

Now we implement the CNF. Our implementation closely resembles the previous NODE implementation. However, this time we concurrently compute the ODE state solutions and the log-density change, both of which are needed to optimize the flow. Consequently, the differential function takes the current state and density as input and computes both time derivatives (the differential function plus the trace term).

In our implementation, we consider the following time-dependent flow: $$ \frac{dz(t)}{dt} = f(t,z(t)) = U_t ~\text{h}(W_t z(t) + b_t),$$ where h is a non-linear function and the parameters $(U_t, W_t, b_t)$ are given by a neural network. Please see hyper_net.py for the implementation.

Remark-1: Any parameterized function can replace self.f below. We choose to use a HyperNetwork as in the original NODE GitHub repo; a simpler alternative is sketched below.
Remark-2: Unlike NFs, CNFs do not require $f$ to be bijective, since we can integrate the ODE system backward in time.
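For instance, a plain time-conditioned MLP (a hypothetical alternative, not the HyperNetwork used below) would be an equally valid self.f:

class SimpleFlowRHS(nn.Module):
    ''' Hypothetical alternative differential function for the CNF:
        a plain MLP applied to the concatenation [z, t]. '''
    def __init__(self, dim, hidden=64):
        super().__init__()
        self.net = nn.Sequential(nn.Linear(dim + 1, hidden), nn.Tanh(),
                                 nn.Linear(hidden, dim))
    def forward(self, t, z):
        tt = torch.ones_like(z[:, :1]) * t          # broadcast time to [N,1]
        return self.net(torch.cat([z, tt], dim=1))  # [N,dim]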

In [71]:
from hyper_net import HyperNetwork

class CNF(nn.Module):
    """Adapted from the NumPy implementation at:
    https://gist.github.com/rtqichen/91924063aa4cc95e7ef30b3a5491cc52
    """
    def __init__(self, in_out_dim, hidden_dim, width):
        super().__init__()
        self.f = HyperNetwork(in_out_dim, hidden_dim, width)

    def ode_rhs(self, t, states):
        ''' Differential function implementation. states is (x1,logp_diff_t1) where
                x1 - [N,d] initial values for ODE states
                logp_diff_t1 - [N,1] initial values for density changes
        '''
        z,logp_z = states # [N,d], [N,1]
        N = z.shape[0]
        with torch.set_grad_enabled(True):
            z.requires_grad_(True)
            dz_dt      = self.f(t,z) # [N,d] 
            dlogp_z_dt = -trace_df_dz(dz_dt, z).view(N, 1)
        return (dz_dt, dlogp_z_dt)
    
    def forward(self, ts, z0, logp_diff_t0, method='dopri5'):
        ''' Forward integrates the CNF system. Returns state and density change solutions.
            Input
                ts - [T]   time points
                z0 - [N,d] initial values for ODE states
                logp_diff_t0 - [N,1] initial values for density changes
            Returns:
                zt -     [T,N,...]  state trajectory computed at t
                logp_t - [T,N,1]    density change computed over time
        '''
        zt, logp_t = odeint(self.ode_rhs, (z0, logp_diff_t0), ts, method=method)
        return zt, logp_t 

5.4. Training¶

Next, we visualize the dataset (samples from the target density) and train the model.

In [72]:
# data generation
Ntrain = 10000

def get_batch(num_samples):
    points, _ = make_circles(n_samples=num_samples, noise=0.06, factor=0.5)
    return torch.tensor(points).type(torch.float32).to(device) # N,2
tr_data = get_batch(Ntrain)

plot_cnf_data(tr_data)
In [73]:
# model and flow parameters
hidden_dim = 32
width      = 64
t0 = 0  # flow start time
t1 = 1  # flow end time

# optimization parameters
lr     = 3e-3
niters = 1000
Nsamp  = 100
print_every = 25

# model
cnf  = CNF(in_out_dim=2, hidden_dim=hidden_dim, width=width).to(device)
ts   = torch.tensor([t1, t0]).type(torch.float32).to(device) # for training, we flow the samples backward (in time) 
p_z0 = torch.distributions.MultivariateNormal(
    loc=torch.tensor([0.0, 0.0]).to(device),
    covariance_matrix=torch.tensor([[0.1, 0.0], [0.0, 0.1]]).to(device)
)

optimizer = torch.optim.Adam(cnf.parameters(), lr=lr)
for itr in range(1, niters+1):
    optimizer.zero_grad()

    # get a random sample minibatch
    idx = torch.randperm(Ntrain)[:Nsamp]
    x1  = tr_data[idx] # Nsamp,2
    
    # initialize initial densities
    logp_diff_t1 = torch.zeros(Nsamp, 1).type(torch.float32).to(device)
    
    # compute the backward solutions
    z_t,  logp_diff_t  = cnf(ts, x1, logp_diff_t1) # outputs time first
    z_t0, logp_diff_t0 = z_t[-1], logp_diff_t[-1]
    
    # compute the density of each sample
    logp_x = p_z0.log_prob(z_t0).to(device) - logp_diff_t0.view(-1)
    loss   = -logp_x.mean(0)
    loss.backward()
    optimizer.step()
    
    if itr%print_every==0:
        print('Iter: {}, loss: {:.4f}'.format(itr, loss.item()))

print('Training complete after {} iters.'.format(itr))
Iter: 25, loss: 1.6009
Iter: 50, loss: 1.5427
Iter: 75, loss: 1.5508
Iter: 100, loss: 1.5496
Iter: 125, loss: 1.5721
Iter: 150, loss: 1.5380
Iter: 175, loss: 1.6878
Iter: 200, loss: 1.5414
KeyboardInterrupt (the training cell was interrupted manually; a pre-trained model is loaded below)

Let's visualize how the flow behaves over time:

In [41]:
# load the pre-trained model
state_dict = torch.load('etc/trained_cnf.pkl')
cnf.load_state_dict(state_dict)
cnf.eval()

# samples
viz_samples   = 30000
viz_timesteps = 41
target_sample = get_batch(viz_samples)

# simulate the flow
with torch.no_grad():
    # Generate evolution of samples
    z_t0 = p_z0.sample([viz_samples]).to(device)
    logp_diff_t0 = torch.zeros(viz_samples, 1).type(torch.float32).to(device)

    ts = torch.tensor(np.linspace(t0, t1, viz_timesteps)).to(device)
    z_t_samples, _  = cnf(ts, z_t0, logp_diff_t0)

    # Generate evolution of density
    x = np.linspace(-1.5, 1.5, 100)
    y = np.linspace(-1.5, 1.5, 100)
    points = np.vstack(np.meshgrid(x, y)).reshape([2, -1]).T
    
    z_t1 = torch.tensor(points).type(torch.float32).to(device)
    logp_diff_t1 = torch.zeros(z_t1.shape[0], 1).type(torch.float32).to(device)
    ts = torch.tensor(np.linspace(t1, t0, viz_timesteps)).to(device)
    z_t_density, logp_diff_t = cnf(ts, z_t1, logp_diff_t1)

anim = plot_cnf_animation(target_sample, t0, t1, viz_timesteps, p_z0, z_t1, z_t_samples, z_t_density, logp_diff_t)
display.HTML(anim.to_jshtml())
Out[41]:
(Animation: evolution of the CNF samples and density between the base Gaussian and the target two-circles distribution.)

5.5. Break: Wrap-off¶

This is our last break in this tutorial. Use this time to go through the entire lecture material to get ready for the upcoming Q&A session. Alternatively, you can dive deeper into CNFs, e.g., by changing the differential function, integration length, etc.
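For instance, one quick (purely illustrative) experiment is to integrate the trained flow beyond the training end time t1 and watch how the samples keep deforming. A minimal sketch, reusing cnf, p_z0, t0, t1 and device from above; the factor 2.0 is arbitrary:

In [ ]:
# Extrapolate the trained flow beyond the training horizon t1 (illustrative experiment).
with torch.no_grad():
    z0_demo = p_z0.sample([2000]).to(device)                  # base samples, [2000,2]
    lp0     = torch.zeros(2000, 1, device=device)             # initial density changes
    ts_long = torch.linspace(t0, 2.0 * t1, 50).to(device)     # integrate past t1
    z_traj, _ = cnf(ts_long, z0_demo, lp0)                    # [50,2000,2]
# z_traj[k] holds the samples at time ts_long[k]; plotting a few slices shows the flow.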

6. Related Studies¶

6.1. ODE-RNN [rubanova2019latent]¶

In the figure below, vertical lines mark observation times. Standard RNNs have constant or undefined hidden states between observations. The states of a Neural ODE follow a complex trajectory but are fully determined by the initial state. The ODE-RNN model has hidden states that obey an ODE between observations and are additionally updated at each observation.

(Figure from [rubanova2019latent]: hidden-state trajectories of a standard RNN, a Neural ODE, and ODE-RNN; vertical lines mark observation times.)
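As a rough illustration (not the authors' implementation), an ODE-RNN cell evolves its hidden state with a learned ODE between observation times and applies a standard RNN update at each observation. The sketch below reuses odeint and nn imported above; f_hidden and the GRU cell are hypothetical components:

In [ ]:
# Minimal ODE-RNN sketch: ODE evolution between observations, RNN update at observations.
class ODERNNSketch(nn.Module):
    def __init__(self, obs_dim, hid_dim):
        super().__init__()
        self.f_hidden = nn.Sequential(nn.Linear(hid_dim, 64), nn.Tanh(), nn.Linear(64, hid_dim))
        self.rnn_cell = nn.GRUCell(obs_dim, hid_dim)
    def forward(self, ts, xs):                     # ts: [T], xs: [T,N,obs_dim]
        h, hs = torch.zeros(xs.shape[1], self.rnn_cell.hidden_size, device=xs.device), []
        for i in range(len(ts)):
            if i > 0:                              # evolve h continuously from t_{i-1} to t_i
                h = odeint(lambda t, h: self.f_hidden(h), h, ts[i-1:i+1])[-1]
            h = self.rnn_cell(xs[i], h)            # discrete update at the observation
            hs.append(h)
        return torch.stack(hs)                     # hidden states at all observation times

model = ODERNNSketch(obs_dim=3, hid_dim=16)
hs = model(torch.linspace(0., 1., 5), torch.randn(5, 7, 3))    # [5,7,16]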

6.2. ODE$^2$VAE [yildiz2019deep]¶

The model explicitly decomposes the latent space into position and momentum components and solves a second-order ODE system. The latent ODE dynamics are parameterized by deep Bayesian neural networks.

(Figure from [yildiz2019deep]: ODE$^2$VAE overview with position and momentum latent components.)
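A second-order ODE $\ddot{s} = f_{\text{acc}}(s, \dot{s})$ can be rewritten as a first-order system over a position-velocity pair, which is the form the latent dynamics take. A minimal sketch; the MLP f_acc stands in for the paper's deep Bayesian neural network:

In [ ]:
# Second-order latent dynamics as a first-order ODE over (position s, velocity v).
latent_dim = 4
f_acc = nn.Sequential(nn.Linear(2 * latent_dim, 64), nn.Tanh(), nn.Linear(64, latent_dim))

def second_order_rhs(t, state):                   # state: [N, 2*latent_dim] = concat(s, v)
    s, v = state.chunk(2, dim=-1)
    ds_dt = v                                     # position changes with velocity
    dv_dt = f_acc(state)                          # velocity changes with the learned acceleration
    return torch.cat([ds_dt, dv_dt], dim=-1)

state0 = torch.randn(8, 2 * latent_dim)           # initial (s0, v0), e.g. from an encoder
traj = odeint(second_order_rhs, state0, torch.linspace(0., 1., 10))   # [10, 8, 2*latent_dim]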

6.3. Augmented NODEs [dupont2019augmented]¶

NODE flows are homeomorphisms, so the learned features preserve the topology of the input space. This implies that NODEs can only continuously deform the input space and cannot, for example, tear a connected region apart. The paper introduces Augmented Neural ODEs (ANODEs), which augment the space on which the ODE is solved, allowing the model to use the additional dimensions to learn more complex functions with simpler flows.

(Figures from [dupont2019augmented]: flows learned by NODEs vs. Augmented NODEs.)
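The augmentation itself is a small change: concatenate a few zero-valued dimensions to the input, solve the ODE in the higher-dimensional space, and drop (or further process) the extra coordinates afterwards. A minimal sketch with a hypothetical differential function f_aug:

In [ ]:
# Minimal Augmented NODE sketch: solve the ODE in (d + aug_dim) dimensions.
d, aug_dim = 2, 3
f_aug = nn.Sequential(nn.Linear(d + aug_dim, 64), nn.Tanh(), nn.Linear(64, d + aug_dim))

x     = torch.randn(16, d)                                            # original inputs
x_aug = torch.cat([x, torch.zeros(x.shape[0], aug_dim)], dim=-1)      # append zero dimensions
out   = odeint(lambda t, z: f_aug(z), x_aug, torch.tensor([0., 1.]))[-1]
features = out[:, :d]       # keep the original dimensions (or feed all of them to a classifier)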

6.4. Regularized NODEs [finlay2020train]¶

Adaptive numerical ODE solvers may take very small steps, which in practice leads to dynamics equivalent to many hundreds of layers. Leveraging connections with optimal transport, the paper shows that regularizing the Jacobian norm and the kinetic energy of the differential function leads to simpler vector fields that can be trained much faster while achieving the same accuracy.

(Figure from [finlay2020train]: vector fields learned with and without regularization.)
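Concretely, two penalties are added to the training loss along the trajectory: the kinetic energy $\int \|f(z(t),t)\|^2 \mathrm{d}t$ and a stochastic estimate of $\int \|\partial f/\partial z\|_F^2 \mathrm{d}t$. The sketch below is only illustrative (f_reg, the weights, and resampling the Hutchinson probe at every evaluation are simplifications); it augments the ODE state with running integrals of the two quantities:

In [ ]:
# Illustrative regularizers of [finlay2020train]: accumulate ||f||^2 and an
# estimate of ||df/dz||_F^2 as extra ODE states, then penalize their end values.
f_reg = nn.Sequential(nn.Linear(2, 64), nn.Tanh(), nn.Linear(64, 2))

def regularized_rhs(t, state):
    z, _, _ = state
    with torch.set_grad_enabled(True):
        z  = z.requires_grad_(True)
        dz = f_reg(z)                                               # [N,2]
        e  = torch.randn_like(z)                                    # Hutchinson probe
        eJ = torch.autograd.grad(dz, z, e, create_graph=True)[0]    # vector-Jacobian product e^T (df/dz)
    kinetic   = (dz ** 2).sum(-1, keepdim=True)                     # ||f||^2
    frobenius = (eJ ** 2).sum(-1, keepdim=True)                     # unbiased estimate of ||df/dz||_F^2
    return (dz, kinetic, frobenius)

z0, zeros = torch.randn(32, 2), torch.zeros(32, 1)
zT, kin, frob = odeint(regularized_rhs, (z0, zeros, zeros), torch.tensor([0., 1.]))
reg_loss = 0.01 * kin[-1].mean() + 0.01 * frob[-1].mean()           # added to the task loss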

6.5. ACA [zhuang2020adaptive]¶

The adjoint method suffers from numerical errors in reverse-mode integration (used for gradient computation), as illustrated in the figure below. The paper presents the Adaptive Checkpoint Adjoint (ACA) method: during automatic differentiation, ACA applies a trajectory checkpointing strategy that records the forward trajectory and reuses it as the reverse trajectory, guaranteeing accuracy.

(Figure from [zhuang2020adaptive]: numerical error of the adjoint's reverse-mode integration.)
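For context, torchdiffeq exposes both gradient modes discussed here: odeint backpropagates through the solver's internal operations (exact gradients of the numerical solution, but memory grows with the number of steps), while odeint_adjoint integrates an adjoint ODE backward (constant memory, but the reverse trajectory can deviate from the forward one, which is the error ACA's checkpointing removes). A minimal comparison with a hypothetical differential function:

In [ ]:
# Gradients via backprop-through-the-solver vs. the continuous adjoint.
from torchdiffeq import odeint_adjoint

class DemoRHS(nn.Module):
    def __init__(self):
        super().__init__()
        self.net = nn.Sequential(nn.Linear(2, 32), nn.Tanh(), nn.Linear(32, 2))
    def forward(self, t, z):
        return self.net(z)

f_demo  = DemoRHS()
z0      = torch.randn(8, 2)
ts_demo = torch.tensor([0., 1.])

loss_bp  = odeint(f_demo, z0, ts_demo)[-1].pow(2).sum()            # backprop through solver ops
loss_adj = odeint_adjoint(f_demo, z0, ts_demo)[-1].pow(2).sum()    # continuous adjoint
g_bp  = torch.autograd.grad(loss_bp,  list(f_demo.parameters()))
g_adj = torch.autograd.grad(loss_adj, list(f_demo.parameters()))
# The two estimates agree only up to solver tolerances; ACA instead checkpoints and
# reuses the forward trajectory in the backward pass (not shown here).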

6.6. ODE-RL [yildiz2021continuous]¶

Model-based reinforcement learning (MBRL) approaches typically rely on discrete-time state transition models, whereas physical systems and the vast majority of control tasks operate in continuous time. This work presents a new perspective on RL in which the dynamics are approximated by an ensemble of NODEs (sketched below). The authors also introduce a novel actor-critic algorithm for policy learning that addresses the fact that Q-functions vanish in continuous time.

(Figure from [yildiz2021continuous]: overview of the continuous-time model-based RL framework.)
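To give a flavor of the dynamics model only (the actor-critic part is not sketched, and control inputs are omitted for brevity), one can average the drifts of several independently initialized differential functions; the code below is purely illustrative:

In [ ]:
# Illustrative ensemble of differential functions; the mean drift is integrated and
# the spread across members can serve as an epistemic-uncertainty proxy.
state_dim, E = 3, 5
ensemble = nn.ModuleList([
    nn.Sequential(nn.Linear(state_dim, 64), nn.Tanh(), nn.Linear(64, state_dim)) for _ in range(E)
])

def ensemble_rhs(t, s):                                   # s: [N, state_dim]
    drifts = torch.stack([f(s) for f in ensemble])        # [E, N, state_dim]
    return drifts.mean(0)                                 # average over ensemble members

s0 = torch.randn(4, state_dim)
pred = odeint(ensemble_rhs, s0, torch.linspace(0., 1., 20))   # [20, 4, state_dim]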

6.7. NSDEs [tzen2019neural], [xu2022infinitely]¶

Deep latent Gaussian models combine deterministic transformations of random variables with small independent Gaussian perturbations. [tzen2019neural] shows that, as the number of layers tends to infinity, the limiting latent object is an Itô diffusion that solves a stochastic differential equation (SDE), and develops a variational inference framework for these neural SDEs via stochastic automatic differentiation in Wiener space. Later, [xu2022infinitely] define a stochastic process (an NSDE) on the weights of a BNN, leading to so-called "infinitely deep BNNs".

(Figure from [tzen2019neural] and [xu2022infinitely]: neural SDEs as infinite-depth limits of deep latent Gaussian models and infinitely deep BNNs.)
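To make the limiting object concrete: a neural SDE adds a diffusion term to the NODE drift, $\mathrm{d}z = f(z)\,\mathrm{d}t + g(z)\,\mathrm{d}W_t$. Below is a minimal hand-rolled Euler-Maruyama simulation with hypothetical drift and (diagonal) diffusion networks; dedicated SDE solvers (e.g., from the torchsde package) would normally be used instead:

In [ ]:
# Minimal Euler-Maruyama simulation of a neural SDE dz = f(z) dt + g(z) dW.
d = 2
f_drift = nn.Sequential(nn.Linear(d, 32), nn.Tanh(), nn.Linear(32, d))
g_diff  = nn.Sequential(nn.Linear(d, 32), nn.Tanh(), nn.Linear(32, d), nn.Softplus())

def simulate_nsde(z0, T=1.0, n_steps=100):
    dt = T / n_steps
    z, path = z0, [z0]
    for _ in range(n_steps):
        dW = torch.randn_like(z) * dt ** 0.5              # Brownian increment ~ N(0, dt)
        z  = z + f_drift(z) * dt + g_diff(z) * dW         # Euler-Maruyama step
        path.append(z)
    return torch.stack(path)                              # [n_steps+1, N, d]

paths = simulate_nsde(torch.randn(16, d))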

6.8. GP-ODEs [hegde2022variational]¶

This work proposes a Bayesian nonparametric model that uses Gaussian processes to infer posteriors over unknown ODE systems. The method combines sparse variational inference with decoupled functional sampling to represent vector-field posteriors.

(Figure from [hegde2022variational]: posterior over the ODE vector field inferred with Gaussian processes.)

References¶


[chen2018neural] Chen, R. T. Q., Rubanova, Y., Bettencourt, J., & Duvenaud, D. (2018). Neural ordinary differential equations. NeurIPS.

[rubanova2019latent] Rubanova, Y., Chen, R. T. Q., & Duvenaud, D. (2019). Latent ODEs for irregularly-sampled time series. arXiv preprint arXiv:1907.03907.

[yildiz2019deep] Yildiz, C., Heinonen, M., & Lähdesmäki, H. (2019). ODE$^2$VAE: Deep generative second order ODEs with Bayesian neural networks. NeurIPS.

[tzen2019neural] Tzen, B., & Raginsky, M. (2019). Neural stochastic differential equations: Deep latent Gaussian models in the diffusion limit. arXiv preprint arXiv:1905.09883.

[dupont2019augmented] Dupont, E., Doucet, A., & Teh, Y. W. (2019). Augmented Neural ODEs. NeurIPS.

[finlay2020train] Finlay, C., Jacobsen, J.-H., Nurbekyan, L., & Oberman, A. (2020). How to train your neural ODE: The world of Jacobian and kinetic regularization. ICML.

[zhuang2020adaptive] Zhuang, J., Dvornek, N., Li, X., Tatikonda, S., Papademetris, X., & Duncan, J. (2020). Adaptive checkpoint adjoint method for gradient estimation in neural ODE. ICML.

[yildiz2021continuous] Yildiz, C., Heinonen, M., & Lähdesmäki, H. (2021). Continuous-time model-based reinforcement learning. ICML.

[xu2022infinitely] Xu, W., Chen, R. T. Q., Li, X., & Duvenaud, D. (2022). Infinitely deep Bayesian neural networks with stochastic differential equations. AISTATS.

[hegde2022variational] Hegde, P., Yildiz, C., Lähdesmäki, H., Kaski, S., & Heinonen, M. (2022). Variational multiple shooting for Bayesian ODEs with Gaussian processes. UAI.