import torch

batch_size = 1  # Random batch size
example_inputs = torch.rand((batch_size, 28, 28))

onnx_program = torch.onnx.export(
    torch_model,
    example_inputs,
    input_names=['input'],
    output_names=['output'],
    dynamic_axes={  # variable input/output: first dimension, corresponding to batch size
        'input': {0: 'batch_size'},
        'output': {0: 'batch_size'}
    },
    f="converted_model.onnx",
    export_params=True,
    do_constant_folding=True  # Optimization
)
At the end of this tutorial, we will have a running deployment of an ONNX model on Hugging Face Spaces.
What is ONNX?
There exist many different deep learning frameworks, across many different programming languages. ONNX is a standard that defines a common set of building blocks and a file format, so that no matter what technology is used to train a model, once it is converted to ONNX it can be deployed virtually anywhere.
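As a small illustration of that file format: once a model has been saved as .onnx, the onnx package can load it and list the standard building blocks its graph is made of, regardless of the framework that produced it. A sketch, assuming the converted_model.onnx we create in the next section:

import onnx

model = onnx.load("converted_model.onnx")
print(model.opset_import)                           # The ONNX operator-set version the model targets
print([node.op_type for node in model.graph.node])  # Standard operators, e.g. ['Flatten', 'Gemm', 'Relu', ...]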
Converting to ONNX
In this section I assume that you have a trained model called torch_model. If you don’t have one, or don’t know how to train a model yet, this post explains how to build and train your own.
Let’s export our model to ONNX format (the export call is shown above). Since torch.onnx.export runs the model, we need to supply an example input. Furthermore, if we don’t want the batch size to be fixed, we need to set the dynamic_axes parameter.
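If you want to verify that the batch dimension really is dynamic after exporting, you can inspect the graph’s input shape; with the export above, the first dimension should show up as the symbolic name 'batch_size' rather than a fixed size. A small check, assuming the export succeeded:

import onnx

model = onnx.load("converted_model.onnx")
dims = model.graph.input[0].type.tensor_type.shape.dim
print([d.dim_param or d.dim_value for d in dims])  # Expected: ['batch_size', 28, 28]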
Running The ONNX Model
We can run any ONNX model inside Python using the onnxruntime package. Let’s run the model we just exported to converted_model.onnx.
import onnx

model = onnx.load("converted_model.onnx")
# If this does not raise an error, we can continue
onnx.checker.check_model(model)
Since ONNX Runtime does not work with PyTorch tensors directly, we need to do a bit of pre-processing (converting the inputs to NumPy arrays) before we can actually run the model.
import onnxruntime as ort
import numpy as np
from torch.utils.data import DataLoader
from torchvision import datasets, transforms

# Define a 'session', which will run the model
ort_session = ort.InferenceSession("converted_model.onnx", providers=["CPUExecutionProvider"])

# The function that will convert PyTorch inputs to ONNX inputs
def to_numpy(tensor):
    return tensor.detach().cpu().numpy() if tensor.requires_grad else tensor.cpu().numpy()

# Sample image
training_data = datasets.FashionMNIST(
    root="data",
    train=True,
    download=True,
    transform=transforms.ToTensor()
)
train_dataloader = DataLoader(training_data, batch_size=64, shuffle=True)

X, y = next(iter(train_dataloader))
testing_image = X[0]
testing_image_label = y[0]

input_name = ort_session.get_inputs()[0].name  # We specified this in our onnx_program's input_names parameter
input_values = to_numpy(testing_image)
ort_inputs = {input_name: input_values}

ort_outputs = ort_session.run(None, ort_inputs)
I specified None for the output_names parameter of the .run() method, which computes all outputs. Since we defined output_names=['output'] when exporting onnx_program, we could also pass that name here as a list: running ort_session.run(['output'], ort_inputs) would result in the same output.
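A quick way to confirm this claim, reusing ort_session and ort_inputs from the snippet above:

# Requesting the output by name should give exactly the same result as passing None
ort_outputs_named = ort_session.run(['output'], ort_inputs)
np.testing.assert_allclose(ort_outputs[0], ort_outputs_named[0])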
Sanity Checking Output
Using ONNX comes with the advantage of a standardized model format which we can run anywhere, but it still needs to give the same output as the model that we trained using PyTorch. Let’s make sure that nothing went wrong during conversion by comparing the PyTorch model output and the ONNX model output on the same input:
torch_outputs = torch_model(testing_image)
torch_outputs = to_numpy(torch_outputs)

np.testing.assert_allclose(ort_outputs[0], torch_outputs, rtol=1e-03, atol=1e-05)  # No error: good to go!
How To Get ONNX predicted labels?
We can get the labels back by using np.argmax:
idx_to_class = {
    0: 'T-shirt/top',
    1: 'Trouser',
    2: 'Pullover',
    3: 'Dress',
    4: 'Coat',
    5: 'Sandal',
    6: 'Shirt',
    7: 'Sneaker',
    8: 'Bag',
    9: 'Ankle boot'
}

label_index = np.argmax(ort_outputs[0], axis=1).item()
class_label = idx_to_class[label_index]
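As a quick usage check, we can compare the prediction against the ground-truth label we pulled from the DataLoader earlier (testing_image_label is a tensor, hence the .item()):

print(f"Predicted: {class_label}, actual: {idx_to_class[testing_image_label.item()]}")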
When using a pre-trained model from torchvision.models, we can retrieve the class label through weights.meta["categories"], e.g. ResNet50_Weights.DEFAULT.meta["categories"][label_index].
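For example, for ResNet-50’s default weights (a small sketch; the printed class depends on the index, and 207 happens to be 'golden retriever' in the standard ImageNet ordering):

from torchvision.models import ResNet50_Weights

weights = ResNet50_Weights.DEFAULT
print(len(weights.meta["categories"]))  # 1000 ImageNet classes
print(weights.meta["categories"][207])  # 'golden retriever'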
Deploying ONNX models
In this section we will deploy our model to Hugging Face Spaces.
For demonstration purposes I will be using a ResNet-50 with default weights, saved as a ‘.onnx’ file.
Follow these steps to get started:
Create an account
Create an account at Hugging Face if you don’t have one.
Create a new space
Select ‘Gradio’ as the Space SDK. Gradio is a high-level API that generates a UI for machine learning models with very little code.
Generate an access token for the Space
Go to Settings > Access tokens and scroll down to ‘Repositories permissions’. Select your Space and enable write permissions. This token will act as your git password when pushing.
Push the app
We only need to write two functions to create a UI and do inference: the predict function and a preprocessing function. The latter depends on the model that you are using. PyTorch pre-trained models also have their required preprocessing made available through {weights_name.VERSION}.transforms():
import numpy as np
import onnxruntime as ort
import gradio as gr
from PIL import Image
from torchvision.models import ResNet50_Weights
weights = ResNet50_Weights.DEFAULT
preprocess = weights.transforms()  # Necessary input transformations
ort_session = ort.InferenceSession("resnet50.onnx", providers=["CPUExecutionProvider"])

def preprocess_inputs(img: Image.Image):
    img = preprocess(img)  # Change this line when using a different model
    img_array = np.array(img).astype(np.float32)
    img_array = np.expand_dims(img_array, axis=0)
    return img_array
def predict(img):
    img = preprocess_inputs(img)
    ort_inputs = {ort_session.get_inputs()[0].name: img}
    ort_outputs = ort_session.run(None, ort_inputs)

    label_index = np.argmax(ort_outputs[0], axis=1).item()
    predicted_label = weights.meta["categories"][label_index]
    return predicted_label
That’s it! Now we can build the interface:
demo = gr.Interface(predict, gr.Image(type="pil", image_mode="RGB"), gr.Label())
demo.launch()
Your file structure should look like this:
.
├── README.md
├── app.py
├── requirements.txt
└── resnet50.onnx
When all the code is in app.py and your project dependencies (imports) are listed in a requirements.txt, you are ready to push and deploy. Run git push. If you encounter an error, you will need to install git-lfs (see the sketch below).
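For the app above, a requirements.txt along these lines should be enough (my assumption; on a Gradio Space the gradio package itself is installed from the SDK setting, so it only needs to be listed if you also run the app locally):

numpy
onnxruntime
pillow
torchvision

And if the push is rejected because of the size of the .onnx file, the usual fix is to track it with Git LFS before committing:

git lfs install
git lfs track "*.onnx"
git add .gitattributes resnet50.onnx
git commit -m "Track the ONNX model with Git LFS"
git push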