Neural ODEs with PyTorch Lightning and TorchDyn | by Michael Poli | Dec, 2020
Neural Differential Equation (NDE) inference is typically slower than that of comparable discrete neural networks, since these continuous models carry the additional overhead of solving a differential equation. Various approaches have been proposed to alleviate this limitation, e.g., regularizing the vector field so that it is easier to solve. In practice, with adaptive-step solvers, an easier-to-solve NDE translates into a smaller number of function evaluations, or in other words, fewer calls to the neural network parametrizing the vector field.
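To make the link between solver effort and function evaluations concrete, here is a minimal, framework-free sketch of an adaptive-step solver that counts calls to the vector field. The embedded Heun-Euler pair, the tolerances, and the two toy scalar vector fields are illustrative assumptions, not TorchDyn's solvers. The only point is that a field with faster-varying dynamics forces smaller accepted steps and therefore more evaluations.

```python
import math


def adaptive_heun_euler(f, x0, t0, t1, rtol=1e-4, atol=1e-6, h0=0.1):
    """Integrate the scalar ODE dx/dt = f(t, x) from t0 to t1.

    Uses an embedded Heun (2nd order) / Euler (1st order) pair for error
    control and returns (x(t1), number of function evaluations).
    """
    t, x, h, nfe = t0, x0, h0, 0
    while t < t1:
        h = min(h, t1 - t)
        k1 = f(t, x)
        k2 = f(t + h, x + h * k1)
        nfe += 2                                  # two vector field calls per attempt
        x_high = x + 0.5 * h * (k1 + k2)          # Heun, 2nd order
        x_low = x + h * k1                        # Euler, 1st order
        err = abs(x_high - x_low)                 # local error estimate
        tol = atol + rtol * max(abs(x), abs(x_high))
        if err <= tol:                            # accept the step
            t, x = t + h, x_high
        # Shrink or grow h (safety factor 0.9, change clipped to [0.2, 2]).
        scale = 0.9 * math.sqrt(tol / err) if err > 0 else 2.0
        h *= min(2.0, max(0.2, scale))
    return x, nfe


def easy(t, x):
    return -x                                     # smooth, slowly varying field


def hard(t, x):
    return -x + math.sin(20.0 * t)                # fast forcing term


x_easy, nfe_easy = adaptive_heun_euler(easy, 1.0, 0.0, 5.0)
x_hard, nfe_hard = adaptive_heun_euler(hard, 1.0, 0.0, 5.0)
print(nfe_easy, nfe_hard)  # the fast field needs many more evaluations
```

Rejected steps also count toward NFE, which is why stiff or rapidly varying fields are doubly expensive.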
However, regularizing the vector field may not always be an option, particularly when the differential equation is partially known and specified a priori, as is the case, for example, in control applications.
The framework of hypersolvers instead considers the interplay between model and solver, analyzing how the solver and the training strategy can be adapted to maximize NDE speedups.
In their simplest form, hypersolvers take a base solver and enhance its performance on an NDE with an additional learned component, trained to approximate the local errors of the solver. Other techniques are also available, for example, adversarial training, which exploits the weaknesses of the base solver to help it generalize across dynamics.
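For the Euler method, one way to write this (the form assumed in this sketch) is x_next = x + eps * f(t, x) + eps**2 * g(eps, t, x), where g is trained to match the scaled local error (x_exact_next - x - eps * f(t, x)) / eps**2. The snippet below computes that target on dx/dt = -x, whose exact one-step solution is known. It then checks that a hypothetical "oracle" g returning the exact residual makes the corrected step exact.

```python
import math


def euler_step(f, t, x, eps):
    """One explicit Euler step for dx/dt = f(t, x)."""
    return x + eps * f(t, x)


def euler_residual(f, t, x, x_next_exact, eps):
    """Scaled local error of Euler: the training target for g.

    Defined so that x_next_exact == euler_step(f, t, x, eps) + eps**2 * residual.
    """
    return (x_next_exact - euler_step(f, t, x, eps)) / eps**2


def hyper_euler_step(f, g, t, x, eps):
    """Euler step plus the correction term eps**2 * g(eps, t, x)."""
    return euler_step(f, t, x, eps) + eps**2 * g(eps, t, x)


def f(t, x):
    """Toy field dx/dt = -x; its exact one-step map is x -> exp(-eps) * x."""
    return -x


eps, x0 = 0.1, 2.0
exact_next = math.exp(-eps) * x0
target = euler_residual(f, 0.0, x0, exact_next, eps)
print(target)  # ≈ 0.9675, i.e. (exp(-eps) - 1 + eps) / eps**2 * x0


def oracle_g(eps, t, x):
    """Hypothetical perfect hypersolver net: returns the exact residual."""
    return (math.exp(-eps) - 1 + eps) / eps**2 * x


print(hyper_euler_step(f, oracle_g, 0.0, x0, eps) - exact_next)  # ~0, up to rounding
```

In practice g only approximates this target, so the corrected step is more accurate than plain Euler but not exact.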
We will use a HyperEuler hypersolver to speed up inference of the Neural ODE from the previous section. As the hypersolver network g, we will use a tiny neural network made up of a single linear layer.
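Concretely, a HyperEuler step adds a learned correction eps**2 * g(...) to a plain Euler step. The sketch below uses a dependency-free stand-in for a linear layer (the same affine map torch.nn.Linear computes, without autograd) so that it runs anywhere. The input layout [x, f(x), eps] for a 2D state is our assumption. It happens to give 5 × 2 weights + 2 biases = 12 parameters, matching the count quoted at the end of this section, but we can't confirm it is the layout used in the original code. The rotation field f is hypothetical and stands in for the trained Neural ODE.

```python
import random


class Linear:
    """Dependency-free affine layer y = W u + b (same map as torch.nn.Linear, no autograd)."""

    def __init__(self, in_features, out_features, seed=0):
        rng = random.Random(seed)
        bound = 1.0 / in_features ** 0.5
        self.W = [[rng.uniform(-bound, bound) for _ in range(in_features)]
                  for _ in range(out_features)]
        self.b = [rng.uniform(-bound, bound) for _ in range(out_features)]

    def __call__(self, u):
        return [sum(w * ui for w, ui in zip(row, u)) + bi
                for row, bi in zip(self.W, self.b)]

    def num_parameters(self):
        return sum(len(row) for row in self.W) + len(self.b)


def f(x):
    """Hypothetical 2D vector field standing in for the trained Neural ODE."""
    return [-x[1], x[0]]


def hyper_euler_step(f, g, x, eps):
    """One HyperEuler step: x + eps * f(x) + eps**2 * g([x, f(x), eps])."""
    dx = f(x)
    correction = g(x + dx + [eps])  # list concatenation: 2 + 2 + 1 = 5 inputs
    return [xi + eps * di + eps**2 * ci for xi, di, ci in zip(x, dx, correction)]


g = Linear(in_features=5, out_features=2)
print(g.num_parameters())  # 12
print(hyper_euler_step(f, g, [1.0, 0.0], 0.1))
```

Each step still calls f exactly once, so with a fixed step size the NFE equals the number of steps; g only adds one call to a much smaller network.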
Do not be alarmed by the high number of epochs: since we are doing full-batch training, each epoch is a single, cheap training iteration.
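Full-batch training means every epoch evaluates the loss on the whole dataset and takes exactly one gradient step. The self-contained sketch below (plain Python, hand-written gradients, all names hypothetical) fits a 12-parameter linear g to Euler residual targets for a toy rotation field with a fixed step size. It then compares a 20-step rollout of Euler and HyperEuler against the exact solution. Because the toy field is linear, a linear g can represent the residual exactly, so the improvement is far larger than one should expect on a real Neural ODE.

```python
import math
import random

EPS = 0.1  # fixed solver step size (assumption)


def f(x):
    """Toy 2D vector field (rotation) standing in for the trained Neural ODE."""
    return [-x[1], x[0]]


def exact_step(x, eps):
    """Exact flow of f over a time eps: rotation by angle eps."""
    c, s = math.cos(eps), math.sin(eps)
    return [c * x[0] - s * x[1], s * x[0] + c * x[1]]


def features(x, eps):
    """Assumed input to g: [x, f(x), eps], i.e. 5 numbers for a 2D state."""
    return x + f(x) + [eps]


# Full-batch dataset of Euler residual targets (x_exact - x - eps*f(x)) / eps**2.
rng = random.Random(0)
X = [[rng.uniform(-1, 1), rng.uniform(-1, 1)] for _ in range(64)]
U = [features(x, EPS) for x in X]
R = [[(xe - xi - EPS * dxi) / EPS**2
      for xe, xi, dxi in zip(exact_step(x, EPS), x, f(x))] for x in X]

# Single linear layer g(u) = W u + b: 5 * 2 weights + 2 biases = 12 parameters.
W = [[rng.uniform(-0.4, 0.4) for _ in range(5)] for _ in range(2)]
b = [0.0, 0.0]


def g(u):
    return [sum(w * ui for w, ui in zip(row, u)) + bi for row, bi in zip(W, b)]


lr, losses = 0.1, []
for _ in range(500):  # full batch: one epoch == one gradient step
    errs = [[gi - ri for gi, ri in zip(g(u), r)] for u, r in zip(U, R)]
    losses.append(sum(e * e for row in errs for e in row) / len(U))
    for i in range(2):  # gradient descent on the mean squared error
        for j in range(5):
            W[i][j] -= lr * 2 * sum(e[i] * u[j] for e, u in zip(errs, U)) / len(U)
        b[i] -= lr * 2 * sum(e[i] for e in errs) / len(U)


def euler(x):
    return [xi + EPS * di for xi, di in zip(x, f(x))]


def hyper_euler(x):
    return [e + EPS**2 * c for e, c in zip(euler(x), g(features(x, EPS)))]


def rollout(step, x, n):
    for _ in range(n):
        x = step(x)
    return x


x0, n_steps = [1.0, 0.0], 20
target = exact_step(x0, n_steps * EPS)  # exact state at t = 2.0
err_euler = math.dist(rollout(euler, x0, n_steps), target)
err_hyper = math.dist(rollout(hyper_euler, x0, n_steps), target)
print(f"loss {losses[0]:.3g} -> {losses[-1]:.3g}")
print(f"error at t=2: Euler {err_euler:.3g}, HyperEuler {err_hyper:.3g}")
```

Both rollouts call f the same number of times; the accuracy gap comes entirely from the learned correction.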
After training, we can visualize the flows of the Neural ODE solved with HyperEuler, our hypersolver variant of the Euler method, and verify whether they match those obtained with the adaptive-step method.
At the cost of only 12 additional parameters, we reduce the number of function evaluations from over 40 to fewer than 20. Not bad!