PyTorch: Load and Predict; Towards a Simple and Standard Inference API
PyTorch is an open source Deep Learning framework that accelerates the path from research prototyping to production deployment.
After training a model, we often need a simple and standard way to load it and make predictions (inference). For example, we may load an image classification model and then run several predictions on images that come from a live camera, from files on disk, and so on.
Naturally, the inference API should be simple and standard: 1) load the model once; 2) make several predictions.
Inference API
In this article, we present a simple inference API for classifying images from the well-known CIFAR10 data set. In just a few lines of code, here is how to load the trained model, load some test data, and then make predictions.
import torch
from CIFAR10_clf import CIFAR10_clf
# Load the CIFAR10 classifier for inference
clf = CIFAR10_clf()
clf.load("cifar10_model")
# load test data
images = torch.load("test_images")
labels = torch.load("test_labels")
# Make predictions
predicted = clf.predict(images)
Test Samples
Prediction Results
- GroundTruth: cat ship ship plane frog frog car frog
- Predicted: cat ship ship plane (deer) frog car frog
The loaded model performs relatively well on these test samples: it classifies 7 of the 8 images correctly and makes one mistake (it predicts deer instead of frog).
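Note that predict returns integer class indices, while the results above are shown as class names. A minimal sketch of that mapping follows, assuming the short class names and label order of the standard CIFAR10 data set. The helpers to_names and accuracy are hypothetical, not part of the article's code, and the index lists are reconstructed from the results listed above for illustration only.

```python
# Short class names in CIFAR10 label order (assumed, not taken from the article's code).
CLASSES = ("plane", "car", "bird", "cat", "deer",
           "dog", "frog", "horse", "ship", "truck")


def to_names(indices):
    """Map an iterable of integer class indices to class names."""
    return [CLASSES[int(i)] for i in indices]


def accuracy(predicted, labels):
    """Fraction of positions where the prediction matches the label."""
    pairs = list(zip(predicted, labels))
    return sum(int(p) == int(t) for p, t in pairs) / len(pairs)


# Indices reconstructed from the printed results (illustrative only).
labels = [3, 8, 8, 0, 6, 6, 1, 6]
predicted = [3, 8, 8, 0, 4, 6, 1, 6]

print("GroundTruth:", " ".join(to_names(labels)))
print("Predicted:  ", " ".join(to_names(predicted)))
print("Accuracy:", accuracy(predicted, labels))  # 7 of 8 correct -> 0.875
```

Both helpers also accept 1-D tensors such as the output of clf.predict(images), because each element is converted with int().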
Training and saving the model
The model can be trained by following this well-documented tutorial. Once training is done, only the model's learned parameters (its state_dict) are saved:
torch.save(net.state_dict(), "cifar10_model")
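Because only the state_dict is saved, loading requires an instance of the same architecture to receive the weights. The round trip can be sketched as below, using a tiny nn.Linear as a stand-in for the real network; the file name demo_model and the use of map_location="cpu" are assumptions made for this illustration, not part of the original workflow.

```python
import os
import tempfile

import torch
import torch.nn as nn

# Tiny stand-in model; in the article this would be the CIFAR10 network.
net = nn.Linear(4, 2)

with tempfile.TemporaryDirectory() as tmp:
    path = os.path.join(tmp, "demo_model")  # hypothetical file name

    # Save only the learned parameters, not the Python class itself.
    torch.save(net.state_dict(), path)

    # Loading needs a fresh instance with the same architecture.
    restored = nn.Linear(4, 2)
    # map_location="cpu" allows weights saved on a GPU to load on a CPU-only machine.
    restored.load_state_dict(torch.load(path, map_location="cpu"))

restored.eval()  # switch to inference mode

# Both models should now produce identical outputs for the same input.
x = torch.randn(3, 4)
with torch.no_grad():
    same = torch.equal(net(x), restored(x))
print(same)
```

Saving the state_dict rather than the whole model object keeps the file independent of how the class is pickled, at the cost of needing the class definition at load time.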
Model Inference Class
Here is the code to wrap the model to be easily used for inference as shown above.
import torch
import torch.nn as nn
import torch.nn.functional as F


class CIFAR10_clf(nn.Module):
    def __init__(self):
        super(CIFAR10_clf, self).__init__()
        self.conv1 = nn.Conv2d(3, 6, 5)
        self.pool = nn.MaxPool2d(2, 2)
        self.conv2 = nn.Conv2d(6, 16, 5)
        self.fc1 = nn.Linear(16 * 5 * 5, 120)
        self.fc2 = nn.Linear(120, 84)
        self.fc3 = nn.Linear(84, 10)

    def forward(self, x):
        x = self.pool(F.relu(self.conv1(x)))
        x = self.pool(F.relu(self.conv2(x)))
        x = x.view(-1, 16 * 5 * 5)  # flatten the feature maps
        x = F.relu(self.fc1(x))
        x = F.relu(self.fc2(x))
        x = self.fc3(x)
        return x

    def load(self, model_path):
        # Restore the trained weights and switch to inference mode
        self.load_state_dict(torch.load(model_path))
        self.eval()

    def predict(self, images):
        # Gradients are not needed for inference
        with torch.no_grad():
            outputs = self(images)
        # The index of the highest score is the predicted class
        _, predicted = torch.max(outputs, 1)
        return predicted
This looks like very familiar PyTorch code. Two methods are added: load and predict. The model can now be used in production in much the same way as a scikit-learn estimator.
# Load the CIFAR10 classifier for inference
clf = CIFAR10_clf()
clf.load("cifar10_model")
# Make predictions
predicted = clf.predict(images)
In this way, using the model for inference is simplified, and prediction can be applied in many different scenarios (online predictions, offline predictions, plain Python code, notebook tutorials, client/server, ...). Moreover, this simple approach isolates the model design and training details from the user.
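For the offline scenario, where many images are already stored, one possible pattern is to load the model once and then predict in fixed-size chunks so that memory use stays bounded. The sketch below is an assumption-laden illustration: predict_in_batches is a hypothetical helper that is not part of the class above, and a small nn.Sequential stands in for the trained classifier.

```python
import torch
import torch.nn as nn


def predict_in_batches(model, images, batch_size=64):
    """Return predicted class indices for `images`, processed chunk by chunk."""
    model.eval()  # inference mode
    results = []
    with torch.no_grad():  # no gradients needed for prediction
        for chunk in torch.split(images, batch_size):
            outputs = model(chunk)
            results.append(outputs.argmax(dim=1))
    return torch.cat(results)


# Stand-in classifier with 10 outputs (hypothetical, untrained).
model = nn.Sequential(nn.Flatten(), nn.Linear(3 * 32 * 32, 10))

# 150 random tensors shaped like CIFAR10 images (3 channels, 32x32).
images = torch.randn(150, 3, 32, 32)

predicted = predict_in_batches(model, images, batch_size=64)
print(predicted.shape)  # one class index per image
```

The same helper works with any nn.Module that maps a batch of images to class scores, so it could be called with a loaded CIFAR10_clf instance in place of the stand-in.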
Making predictions should be simple and standard for developers and researchers.
Regards