A short guide on FastAI vision model conversion to ONNX. Code included. 👀
ONNX (Open Neural Network Exchange) is an open specification that consists of three parts: a definition of an extensible computation graph model, definitions of standard data types, and definitions of built-in operators. The extensible computation graph model and the standard data types together make up the Intermediate Representation (IR).
With load_learner() I am loading the previously exported FastAI model, hot_dog_model_resnet18_256_256.pkl. If you trained your own model, you can skip the load step because your model is already stored in learn. The underlying PyTorch model is available through the model attribute of learn, i.e. learn.model. I don't want to train the model in subsequent steps, so I am also setting it to evaluation mode with eval(). For more details on eval() and torch.no_grad(), see the discussion [link].

Adding a softmax layer on top of the model turns the raw output ('not_hot_dog', array([[-3.0275817, 1.2424631]], dtype=float32)) into ('not_hot_dog', array([[0.01378838, 0.98621166]], dtype=float32)).
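Below is a minimal sketch of appending a softmax layer with plain PyTorch, assuming nn.Sequential as the wrapper. The small network named body is a hypothetical stand-in for learn.model, so the snippet runs without the .pkl file. The last lines apply the same softmax to the raw scores quoted above.

```python
import torch
from torch import nn

# Hypothetical stand-in for learn.model: maps a BCHW image batch to raw
# scores (logits) for two classes. With a real learner, use learn.model.
body = nn.Sequential(
    nn.Conv2d(3, 8, kernel_size=3, padding=1),
    nn.AdaptiveAvgPool2d(1),
    nn.Flatten(),
    nn.Linear(8, 2),
)

# Append a softmax over the class dimension (dim=1 of a B x C output)
# and switch the whole model to evaluation mode.
final_model = nn.Sequential(body, nn.Softmax(dim=1)).eval()

with torch.no_grad():  # no gradient tracking needed for inference
    probs = final_model(torch.randn(1, 3, 256, 256))
print(probs, probs.sum(dim=1))

# The same softmax applied to the raw scores quoted in the text.
raw_scores = torch.tensor([[-3.0275817, 1.2424631]])
print(torch.softmax(raw_scores, dim=1))  # approximately [[0.0138, 0.9862]]
```

With a real learner the wrapper line would read nn.Sequential(learn.model, nn.Softmax(dim=1)). This is one possible way to add the layer, not necessarily the exact code used for the hot dog model.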
Notice the range of the inference results: with the added softmax layer, the scores are scaled to between 0 and 1 and sum to 1.

The model with the softmax layer, final_model, is the one I use for export. First I create a dummy tensor that defines the input dimensions of the ONNX model. These dimensions are given as batch x channels x height x width, the BCHW format. My FastAI model was trained on 256 x 256 images, a size defined in the FastAI DataBlock API, so the same dimensions must be used for the ONNX export: torch.randn(1, 3, 256, 256).
I once used torch.randn(1, 3, 320, 320) for the export while the training image dimensions were 3 x 224 x 224. It took me a while to figure out why I got poor results from my ONNX models.

The export_params argument, if set to True, includes the parameters of the trained model in the export. It's important to use True in this case (it is also the default), because we want our model with its parameters. As you might have guessed, export_params=False exports a model without parameters. The full torch.onnx documentation is here [link].

To check which input dimensions an ONNX model expects, call get_inputs()[0].shape on the ONNX Runtime inference session instance. If you prefer a GUI, Netron [link] can help you visualize the architecture of the neural network. Predictions come from the session's run() method. It returns a list of numpy arrays, one per model output, and for this model the single output holds the softmaxed probabilities.