ONNX Import/Export

Table of Contents

  1. Overview
  2. Importing ONNX Models
    1. Basic Import
    2. From Bytes
    3. With Options
  3. Exporting to ONNX
    1. Basic Export
    2. With Dynamic Axes
  4. Supported Operators
    1. Math Operations
    2. Tensor Operations
    3. Reduction Operations
    4. Neural Network Layers
    5. Activations
    6. RNN
  5. Working with PyTorch Models
    1. Export from PyTorch
    2. Import in AxonML
  6. Custom Operators
    1. Register Custom Op
  7. Model Optimization
    1. Constant Folding
    2. Quantization
  8. Validation
  9. Model Information

Overview

AxonML supports importing and exporting ONNX (Open Neural Network Exchange) models for interoperability with PyTorch, TensorFlow, and other frameworks.

Supported Features:

  - Importing ONNX models from a file path or raw bytes
  - Exporting AxonML models to ONNX, including dynamic axes
  - A broad operator set: math, tensor, reduction, neural-network layers, activations, and RNN ops
  - Custom operator registration
  - Graph optimization passes and post-training quantization
  - Model validation and inspection

Importing ONNX Models

Basic Import

use axonml::onnx::import_onnx;
use axonml::Tensor; // assuming Tensor is re-exported at the crate root

// Load ONNX model
let model = import_onnx("model.onnx")?;

// Get model information
println!("Inputs: {:?}", model.inputs());
println!("Outputs: {:?}", model.outputs());

// Run inference
let input = Tensor::randn(&[1, 3, 224, 224]);
let output = model.forward(&input)?;

From Bytes

use axonml::onnx::OnnxModel;

let bytes = std::fs::read("model.onnx")?;
let model = OnnxModel::from_bytes(&bytes)?;

With Options

use axonml::onnx::{import_onnx_with_options, ImportOptions};
use axonml::{DType, Device}; // assuming crate-root re-exports

let options = ImportOptions::new()
    .device(Device::CUDA(0))
    .dtype(DType::F16)
    .optimize(true);

let model = import_onnx_with_options("model.onnx", options)?;

Exporting to ONNX

Basic Export

use axonml::onnx::export_onnx;
use axonml::nn::{Linear, ReLU, Sequential}; // assuming layers live under axonml::nn

let model = Sequential::new()
    .add(Linear::new(784, 256))
    .add(ReLU)
    .add(Linear::new(256, 10));

// Export with example input shape
export_onnx(&model, "my_model.onnx", &[1, 784])?;
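
One way to sanity-check the exported file is to run it back through the validator described in Validation below:

use axonml::onnx::validate_model;

// Confirm the exported file is well-formed ONNX
validate_model("my_model.onnx")?;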

With Dynamic Axes

use axonml::onnx::{export_onnx_with_options, ExportOptions};

let options = ExportOptions::new()
    .dynamic_axes(vec![("input", vec![0])])  // Batch dimension is dynamic
    .opset_version(13);

export_onnx_with_options(&model, "model.onnx", &[1, 784], options)?;
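
Because dimension 0 of the input is marked dynamic, the exported graph should accept any batch size on re-import. A minimal round-trip sketch using the import API from above:

use axonml::onnx::import_onnx;

let reloaded = import_onnx("model.onnx")?;

// The same graph handles both batch sizes
let one = reloaded.forward(&Tensor::randn(&[1, 784]))?;
let many = reloaded.forward(&Tensor::randn(&[32, 784]))?;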

Supported Operators

Math Operations

ONNX Op    Status   Notes
Add        ✓        With broadcasting
Sub        ✓        With broadcasting
Mul        ✓        With broadcasting
Div        ✓        With broadcasting
MatMul     ✓        2D and batched
Gemm       ✓        General matrix multiply
Pow        ✓        Element-wise
Sqrt       ✓
Exp        ✓
Log        ✓
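
The broadcasting note means the element-wise ops follow ONNX's NumPy-style rules, so a graph can, for example, add a [3]-shaped bias to an [N, 3] tensor without an explicit expand. A small sketch of the same semantics on the tensor API (assuming an add method analogous to the mul used in the custom-operator example below):

use axonml::Tensor; // assuming the crate-root re-export, as above

let x = Tensor::randn(&[2, 3]);
let b = Tensor::randn(&[3]);

// b is broadcast across the leading dimension, giving a [2, 3] result
let y = x.add(&b); // `add` assumed analogous to the `mul` shown later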

Tensor Operations

ONNX Op     Status   Notes
Reshape     ✓
Transpose   ✓
Concat      ✓        Any axis
Split       ✓
Slice       ✓
Gather      ✓
Squeeze     ✓
Unsqueeze   ✓
Flatten     ✓

Reduction Operations

ONNX Op      Status   Notes
ReduceSum    ✓
ReduceMean   ✓
ReduceMax    ✓
ReduceMin    ✓
ReduceProd   ✓

Neural Network Layers

ONNX Op              Status   Notes
Conv                 ✓        1D and 2D
ConvTranspose        ✓
MaxPool              ✓
AveragePool          ✓
GlobalAveragePool    ✓
BatchNormalization   ✓
Dropout              ✓        Inference mode
Softmax              ✓
LogSoftmax           ✓

Activations

ONNX Op     Status   Notes
Relu        ✓
LeakyRelu   ✓
Sigmoid     ✓
Tanh        ✓
Elu         ✓
Gelu        ✓
Silu        ✓

RNN

ONNX Op   Status   Notes
LSTM      ✓        Unidirectional
GRU       ✓
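
Operators not listed above can often be supplied at import time through the custom-operator mechanism described in Custom Operators below.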

Working with PyTorch Models

Export from PyTorch

import torch

model = torch.hub.load('pytorch/vision', 'resnet18', pretrained=True)
model.eval()

dummy_input = torch.randn(1, 3, 224, 224)
torch.onnx.export(
    model,
    dummy_input,
    "resnet18.onnx",
    input_names=['input'],
    output_names=['output'],
    dynamic_axes={'input': {0: 'batch'}, 'output': {0: 'batch'}},
    opset_version=13,  # match the opset used in the AxonML export example above
)

Import in AxonML

use axonml::onnx::import_onnx;
use axonml::Tensor; // assuming the crate-root re-export, as above

let model = import_onnx("resnet18.onnx")?;

// Run inference
let input = Tensor::randn(&[1, 3, 224, 224]);
let output = model.forward(&input)?;

// Get predictions
let predictions = output.argmax(1);

Custom Operators

Register Custom Op

use axonml::onnx::{register_custom_op, CustomOp};
use axonml::{Result, Tensor}; // assuming crate-root re-exports

struct MyCustomOp;

impl CustomOp for MyCustomOp {
    fn name(&self) -> &str {
        "MyCustomOp"
    }

    fn forward(&self, inputs: &[&Tensor]) -> Result<Vec<Tensor>> {
        // Double the single input tensor element-wise
        let x = inputs[0];
        let y = x.mul(&Tensor::full(x.shape(), 2.0));
        Ok(vec![y])
    }
}

// Register before importing
register_custom_op(MyCustomOp);

let model = import_onnx("model_with_custom_op.onnx")?;

Model Optimization

Constant Folding

use axonml::onnx::{import_onnx, optimize_model, OptimizationPass};

let model = import_onnx("model.onnx")?;

let optimized = optimize_model(&model, &[
    OptimizationPass::ConstantFolding,
    OptimizationPass::EliminateDeadNodes,
    OptimizationPass::FuseOperations,
])?;
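
The result of optimize_model is a regular model, so it can be written back to disk with the export API from earlier:

use axonml::onnx::export_onnx;

// Persist the optimized graph (example input shape assumed to match the original model)
export_onnx(&optimized, "model_optimized.onnx", &[1, 3, 224, 224])?;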

Quantization

use axonml::onnx::{export_onnx, import_onnx, quantize_model};
use axonml::DType; // assuming the crate-root re-export

let model = import_onnx("model.onnx")?;
let quantized = quantize_model(&model, DType::I8)?;

// Save quantized model
export_onnx(&quantized, "model_int8.onnx", &[1, 3, 224, 224])?;
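
Quantization trades numeric precision for size and speed, so it is worth spot-checking that predictions still agree. A quick sketch using only the calls shown earlier:

// Compare top-1 predictions of the original and INT8 models on the same input
let input = Tensor::randn(&[1, 3, 224, 224]);
let fp32_pred = model.forward(&input)?.argmax(1);
let int8_pred = quantized.forward(&input)?.argmax(1);
println!("fp32: {:?}, int8: {:?}", fp32_pred, int8_pred);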

Validation

use axonml::onnx::validate_model;

// Check model is valid ONNX
let result = validate_model("model.onnx");
match result {
    Ok(()) => println!("Model is valid"),
    Err(e) => println!("Validation error: {}", e),
}

Model Information

use axonml::onnx::import_onnx;

let model = import_onnx("model.onnx")?;

// Input/output info
for input in model.inputs() {
    println!("Input: {} - {:?}", input.name, input.shape);
}

for output in model.outputs() {
    println!("Output: {} - {:?}", output.name, output.shape);
}

// Graph info
println!("Nodes: {}", model.graph().nodes().len());
println!("Opset version: {}", model.opset_version());