How to deploy PyTorch Lightning models to production
By Caleb Kaiser, Cortex Labs
Looking at the machine learning landscape, one of the major trends is the proliferation of projects focused on applying software engineering principles to machine learning. Cortex, for example, recreates the experience of deploying serverless functions, but with inference pipelines. DVC, similarly, implements modern version control and CI/CD pipelines, but for ML.
PyTorch Lightning has a similar philosophy, applied to training. The framework provides a Python wrapper for PyTorch that lets data scientists and engineers write clean, manageable, and performant training code.
As people who built an entire deployment platform in part because we hated writing boilerplate, we're huge fans of PyTorch Lightning. In that spirit, I've put together this guide to deploying PyTorch Lightning models to production. Along the way, we'll look at a few different options for exporting PyTorch Lightning models for inclusion in your inference pipelines.
Every way to deploy a PyTorch Lightning model for inference
There are three ways to export a PyTorch Lightning model for serving:
- Saving the model as a PyTorch checkpoint
- Converting the model to ONNX
- Exporting the model to Torchscript
We can serve all three with Cortex.
1. Package and deploy PyTorch Lightning modules directly
Starting with the simplest approach, let's deploy a PyTorch Lightning model without any conversion steps.
The PyTorch Lightning Trainer, a class which abstracts boilerplate training code (think training and validation steps), has a built-in save_checkpoint() function which will save your model as a .ckpt file. To save your model as a checkpoint, simply add this code to your training script:
Now, before we get into serving this checkpoint, it's important to note that while I keep saying "PyTorch Lightning model," PyTorch Lightning is a wrapper around PyTorch; the project's README literally says "PyTorch Lightning is just organized PyTorch." The exported model, therefore, is a normal PyTorch model, and can be served accordingly.
With a saved checkpoint, we can serve the model quite easily in Cortex. If you're unfamiliar with Cortex, you can familiarize yourself quickly here, but the simple overview of the deployment process with Cortex is:
- We write a prediction API for our model in Python
- We define our API's infrastructure and behavior in YAML
- We deploy the API with a command from the CLI
Our prediction API will use Cortex's Python Predictor class to define an init() function to initialize our API and load the model, and a predict() function to serve predictions when queried:
Pretty simple. We repurpose some code from our training script, add a little inference logic, and that's it. One thing to note is that if you upload your model to S3 (recommended), you'll need to add some logic for accessing it.
Next, we configure our infrastructure in YAML:
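A minimal cortex.yaml for this API might look like the following. The API name and resource values are placeholders, and the exact fields vary by Cortex version, so check the docs for yours:

```yaml
# cortex.yaml (sketch; field names follow Cortex's config at the time of writing)
- name: my-classifier
  predictor:
    type: python
    path: predictor.py
  compute:
    cpu: 1
```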
Again, simple. We give our API a name, tell Cortex where our prediction API is, and allocate some CPU.
Next, we deploy it:
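Deployment is a single CLI command, run from the project directory:

```
$ cortex deploy
```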
Note that we will additionally deploy to a cluster, spun up and managed by Cortex:
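In the Cortex versions this article describes, pointing the same deploy at a cluster environment looked like the following (the environment name is an assumption; see the Cortex docs for your version):

```
$ cortex deploy --env aws
```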
With all deployments, Cortex containerizes our API and exposes it as a web service. With cloud deployments, Cortex configures load balancing, autoscaling, monitoring, updating, and many other infrastructure features.
And that's it! We now have a live web API serving predictions from our model on request.
2. Export to ONNX and serve via ONNX Runtime
Now that we've deployed a vanilla PyTorch checkpoint, let's complicate things a bit.
PyTorch Lightning recently added a convenient abstraction for exporting models to ONNX (previously, you could use PyTorch's built-in conversion functions, though they required a bit more boilerplate). To export your model to ONNX, just add this bit of code to your training script:
Note that your input sample should mimic the shape of your actual model input.
Once you've exported an ONNX model, you can serve it using Cortex's ONNX Predictor. The code will mostly look the same, and the process is identical. For example, this is an ONNX prediction API:
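Here's a sketch of that predictor.py. The (onnx_client, config) constructor signature follows Cortex's ONNX Predictor interface at the time of writing; the payload format is a made-up example:

```python
# predictor.py (sketch)
import numpy as np

class ONNXPredictor:
    def __init__(self, onnx_client, config):
        # onnx_client wraps the ONNX Runtime session Cortex runs for the model
        self.client = onnx_client

    def predict(self, payload):
        x = np.array(payload["input"], dtype=np.float32)
        # run the model and return the highest-scoring class
        prediction = self.client.predict(x)
        return int(np.argmax(prediction))
```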
Basically the same. The only difference is that instead of initializing the model directly, we access it through the onnx_client, which is an ONNX Runtime container Cortex spins up for serving our model.
Our YAML also looks pretty similar:
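As a sketch, assuming the field names from Cortex's config at the time of writing (the name and S3 path are placeholders):

```yaml
# cortex.yaml (sketch)
- name: my-classifier-onnx
  predictor:
    type: onnx
    path: predictor.py
    model: s3://my-bucket/model.onnx
  monitoring:
    model_type: classification
  compute:
    cpu: 1
```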
I added a monitoring flag here just to show how easy it is to configure, and there are some ONNX-specific fields, but otherwise it's the same YAML.
Finally, we deploy by using the same $ cortex deploy command as before, and our ONNX API is live.
3. Serialize with Torchscript's JIT compiler
For a final deployment, we're going to export our PyTorch Lightning model to Torchscript and serve it using PyTorch's JIT compiler. To export the model, simply add this to your training script:
The Python API for this is nearly identical to the vanilla PyTorch example:
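As a sketch, the predictor just swaps checkpoint loading for torch.jit.load() (the "model_path" config key and payload format are made-up placeholders):

```python
# predictor.py (sketch)
import torch

class PythonPredictor:
    def __init__(self, config):
        # load the serialized Torchscript module instead of a checkpoint
        self.model = torch.jit.load(config["model_path"])
        self.model.eval()

    def predict(self, payload):
        x = torch.tensor(payload["input"], dtype=torch.float32)
        with torch.no_grad():
            logits = self.model(x)
        return logits.argmax(dim=-1).tolist()
```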
The YAML stays the same as before, and the CLI command of course is consistent. If we want, we can actually update our previous PyTorch API to use the new model by simply replacing our old predictor.py script with the new one and running $ cortex deploy again.
Cortex automatically performs a rolling update here, in which a new API is spun up and then swapped with the old API, preventing any downtime between model updates.
And that's all there is to it. Now you have a fully operational prediction API for realtime inference, serving predictions from a Torchscript model.
So, which method should you use?
The obvious question here is which method performs best. The truth is that there isn't a straightforward answer, as it depends on your model.
For Transformer models like BERT and GPT-2, ONNX can deliver incredible optimizations (we measured a 40x improvement in throughput on CPUs). For other models, Torchscript likely performs better than vanilla PyTorch, though that too comes with some caveats, as not all models export to Torchscript cleanly.
Fortunately, with how easy it is to deploy using any option, you can test all three in parallel and see which performs best for your particular API.
Original. Reposted with permission.