Skip to content

Commit

Permalink
Feature/onnx argmax (#1814)
Browse files Browse the repository at this point in the history
* pre-test

* implementing argmax for burn-import from onnx

* tidying

* fixing return types and tests

* addressing feedback

* only warn when select_last_index!=0
  • Loading branch information
will-maclean authored May 31, 2024
1 parent de0b49e commit 13a6f84
Show file tree
Hide file tree
Showing 11 changed files with 255 additions and 9 deletions.
2 changes: 1 addition & 1 deletion crates/burn-import/SUPPORTED-ONNX-OPS.md
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@ represent the corresponding Burn Op.
| [Acosh][3] |||
| [Add][4] |||
| [And][5] |||
| [ArgMax][6] | ||
| [ArgMax][6] | ||
| [ArgMin][7] |||
| [Asin][8] |||
| [Asinh][9] |||
Expand Down
1 change: 1 addition & 0 deletions crates/burn-import/onnx-tests/build.rs
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@ fn main() {
ModelGen::new()
.input("tests/add/add_int.onnx")
.input("tests/add/add.onnx")
.input("tests/argmax/argmax.onnx")
.input("tests/avg_pool1d/avg_pool1d.onnx")
.input("tests/avg_pool2d/avg_pool2d.onnx")
.input("tests/batch_norm/batch_norm.onnx")
Expand Down
Binary file not shown.
41 changes: 41 additions & 0 deletions crates/burn-import/onnx-tests/tests/argmax/argmax.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,41 @@
#!/usr/bin/env python3

# used to generate model: onnx-tests/tests/argmax/argmax.onnx

import torch
import torch.nn as nn

class Model(nn.Module):
def __init__(self, argmax_dim: int = 0):
super(Model, self).__init__()
self._argmax_dim = argmax_dim

def forward(self, x):
# Note: only keepdim=True is supported in burn
y = torch.argmax(input=x, dim=self._argmax_dim, keepdim=True)
return y

def main():

# Export to onnx
model = Model(1)
model.eval()
device = torch.device("cpu")
onnx_name = "argmax.onnx"
dummy_input = torch.randn((3, 4), device=device)
torch.onnx.export(model, dummy_input, onnx_name,
verbose=False, opset_version=16)

print("Finished exporting model to {}".format(onnx_name))

# Output some test data for use in the test
test_input = torch.randn((2, 3), device=device)
print("Test input data shape: {}".format(test_input.shape))
output = model.forward(test_input)

print("Test output data shape: {}".format(output.shape))



if __name__ == '__main__':
main()
15 changes: 15 additions & 0 deletions crates/burn-import/onnx-tests/tests/onnx_tests.rs
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@ macro_rules! include_models {
include_models!(
add_int,
add,
argmax,
avg_pool2d,
avg_pool1d,
batch_norm,
Expand Down Expand Up @@ -368,6 +369,20 @@ mod tests {
assert_eq!(output.to_data(), expected);
}

#[test]
fn argmax() {
// Initialize the model with weights (loaded from the exported file)
let model: argmax::Model<Backend> = argmax::Model::default();

let device = Default::default();
// Run the model
let input = Tensor::<Backend, 2>::from_floats([[1., 2., 3.], [4., 5., 6.]], &device);
let output = model.forward(input);
let expected = Data::from([[2], [2]]);

assert_eq!(output.to_data(), expected);
}

#[test]
fn globalavrpool_1d_2d() {
// The model contains 1d and 2d global average pooling nodes
Expand Down
102 changes: 102 additions & 0 deletions crates/burn-import/src/burn/node/argmax.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,102 @@
use super::{Node, NodeCodegen};
use crate::burn::{TensorKind, TensorType, ToTokens, Type};

use burn::record::PrecisionSettings;
use quote::quote;

#[derive(Debug, Clone, new)]
pub struct ArgMaxNode {
pub input: TensorType,
pub output: TensorType,
pub axis: usize,
}

impl<PS: PrecisionSettings> NodeCodegen<PS> for ArgMaxNode {
fn output_types(&self) -> Vec<Type> {
let mut output = self.output.clone();
output.kind = TensorKind::Int;
vec![Type::Tensor(output)]
}

fn input_types(&self) -> Vec<crate::burn::Type> {
vec![Type::Tensor(self.input.clone())]
}

fn forward(
&self,
scope: &mut crate::burn::Scope,
node_position: usize,
) -> proc_macro2::TokenStream {
//NOTE: select_last_index and keep_dims are not supported
let axis = self.axis.to_tokens();

let input = scope.tensor_use_owned(&self.input, node_position);
let output = &self.output.name;

quote! {
let #output = #input.argmax(#axis);
}
}

fn into_node(self) -> super::Node<PS> {
Node::ArgMax(self)
}
}

#[cfg(test)]
mod tests {

use burn::record::FullPrecisionSettings;

use super::*;
use crate::burn::{graph::BurnGraph, node::test::assert_tokens, TensorType};

#[test]
fn test_codegen_argmax() {
let mut graph = BurnGraph::<FullPrecisionSettings>::default();

graph.register(ArgMaxNode::new(
TensorType::new_float("tensor1", 2),
TensorType::new_int("tensor2", 2),
1,
));

graph.register_input_output(vec!["tensor1".to_string()], vec!["tensor2".to_string()]);

let expected = quote! {
use burn::tensor::Int;
use burn::{
module::Module,
tensor::{backend::Backend, Tensor},
};

#[derive(Module, Debug)]
pub struct Model<B: Backend> {
phantom: core::marker::PhantomData<B>,
device: burn::module::Ignored<B::Device>,
}

impl<B: Backend> Model <B> {
#[allow(unused_variables)]
pub fn new(device: &B::Device) -> Self {
Self {
phantom: core::marker::PhantomData,
device: burn::module::Ignored(device.clone()),
}
}

#[allow(clippy::let_and_return, clippy::approx_constant)]
pub fn forward(
&self,
tensor1: Tensor<B, 2>
) -> Tensor<B, 2, Int> {
let tensor2 = tensor1.argmax(1);

tensor2
}
}
};

assert_tokens(graph.codegen(), expected);
}
}
20 changes: 12 additions & 8 deletions crates/burn-import/src/burn/node/base.rs
Original file line number Diff line number Diff line change
@@ -1,12 +1,13 @@
use super::{
avg_pool1d::AvgPool1dNode, avg_pool2d::AvgPool2dNode, batch_norm::BatchNormNode,
binary::BinaryNode, clip::ClipNode, concat::ConcatNode, constant::ConstantNode,
conv1d::Conv1dNode, conv2d::Conv2dNode, conv_transpose_2d::ConvTranspose2dNode,
dropout::DropoutNode, gather::GatherNode, global_avg_pool::GlobalAvgPoolNode,
layer_norm::LayerNormNode, linear::LinearNode, mask_where::WhereNode, matmul::MatmulNode,
max_pool1d::MaxPool1dNode, max_pool2d::MaxPool2dNode, prelu::PReluNode,
random_normal::RandomNormalNode, random_uniform::RandomUniformNode, reshape::ReshapeNode,
squeeze::SqueezeNode, unary::UnaryNode, unsqueeze::UnsqueezeNode,
argmax::ArgMaxNode, avg_pool1d::AvgPool1dNode, avg_pool2d::AvgPool2dNode,
batch_norm::BatchNormNode, binary::BinaryNode, clip::ClipNode, concat::ConcatNode,
constant::ConstantNode, conv1d::Conv1dNode, conv2d::Conv2dNode,
conv_transpose_2d::ConvTranspose2dNode, dropout::DropoutNode, gather::GatherNode,
global_avg_pool::GlobalAvgPoolNode, layer_norm::LayerNormNode, linear::LinearNode,
mask_where::WhereNode, matmul::MatmulNode, max_pool1d::MaxPool1dNode,
max_pool2d::MaxPool2dNode, prelu::PReluNode, random_normal::RandomNormalNode,
random_uniform::RandomUniformNode, reshape::ReshapeNode, squeeze::SqueezeNode,
unary::UnaryNode, unsqueeze::UnsqueezeNode,
};
use crate::burn::{BurnImports, Scope, Type};
use burn::backend::NdArray;
Expand Down Expand Up @@ -76,6 +77,7 @@ pub trait NodeCodegen<PS: PrecisionSettings>: std::fmt::Debug {

#[derive(Debug, Clone)]
pub enum Node<PS: PrecisionSettings> {
ArgMax(ArgMaxNode),
AvgPool1d(AvgPool1dNode),
AvgPool2d(AvgPool2dNode),
BatchNorm(BatchNormNode<PS>),
Expand Down Expand Up @@ -108,6 +110,7 @@ macro_rules! match_all {
($self:expr, $func:expr) => {{
#[allow(clippy::redundant_closure_call)]
match $self {
Node::ArgMax(node) => $func(node),
Node::AvgPool1d(node) => $func(node),
Node::AvgPool2d(node) => $func(node),
Node::BatchNorm(node) => $func(node),
Expand Down Expand Up @@ -150,6 +153,7 @@ impl<PS: PrecisionSettings> Serialize for Node<PS> {
impl<PS: PrecisionSettings> Node<PS> {
pub fn name(&self) -> &str {
match self {
Node::ArgMax(_) => "argmax",
Node::AvgPool1d(_) => "avg_pool1d",
Node::AvgPool2d(_) => "avg_pool2d",
Node::BatchNorm(_) => "batch_norm",
Expand Down
1 change: 1 addition & 0 deletions crates/burn-import/src/burn/node/mod.rs
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
mod base;

pub(crate) mod argmax;
pub(crate) mod avg_pool1d;
pub(crate) mod avg_pool2d;
pub(crate) mod batch_norm;
Expand Down
20 changes: 20 additions & 0 deletions crates/burn-import/src/onnx/dim_inference.rs
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@ use super::{
pub fn dim_inference(node: &mut Node, graph_io: &mut OnnxGraphIO) {
match node.node_type {
NodeType::Add => same_as_input(node),
NodeType::ArgMax => argmax_update_outputs(node),
NodeType::AveragePool1d => same_as_input(node),
NodeType::AveragePool2d => same_as_input(node),
NodeType::BatchNormalization => same_as_input(node),
Expand Down Expand Up @@ -362,6 +363,25 @@ fn reduce_mean_update_outputs(node: &mut Node) {
}
}

fn argmax_update_outputs(node: &mut Node) {
if node.inputs.len() != 1 {
panic!("Mean: multiple inputs are not supported");
}

let node_input = &mut node.inputs[0];
let tensor = match node_input.clone().ty {
ArgType::Tensor(tensor) => tensor,
_ => panic!("Only tensor input is valid"),
};

// Note: argmax in burn does not support keepdims=false
node.outputs[0].ty = ArgType::Tensor(TensorType {
dim: tensor.dim,
shape: tensor.shape.clone(),
elem_type: ElementType::Int64,
});
}

/// Update the output tensor dimension
fn squeeze_update_output(node: &mut Node) {
let axes = if node.inputs.len() == 2 {
Expand Down
52 changes: 52 additions & 0 deletions crates/burn-import/src/onnx/op_configuration.rs
Original file line number Diff line number Diff line change
Expand Up @@ -468,6 +468,58 @@ pub fn softmax_config(node: &Node) -> usize {
axis as usize
}

/// Create argmax config from the attributes of the node
pub fn argmax_config(node: &Node) -> usize {
let mut axis: i64 = 0;

// check if the node has only one input
if node.inputs.len() != 1 {
panic!(
"Argmax: multiple inputs are not supported (got {:?})",
node.inputs.len()
);
}

// extract the shape of the input tensor
let tensor = match node.inputs.first().unwrap().clone().ty {
ArgType::Tensor(tensor) => tensor,
_ => panic!("Only tensor input is valid"),
};

// extract the attributes
for (key, value) in node.attrs.iter() {
match key.as_str() {
"axis" => axis = value.clone().into_i64(),
"select_last_index" => {
// not all params are supported in burn
if value.clone().into_i64() != 0 {
log::warn!(
"only select_last_index=0 is supported for argmax in burn. Ignoring supplied value (got {:?})",
value
);
}
}
"keepdims" => {
// not all params are supported in burn
if value.clone().into_i64() != 1 {
panic!(
"Only keepdims=1 is supported for argmax in burn (got {:?})",
value
);
}
}
_ => {}
}
}

// if axis is negative, it is counted from the end
if axis < 0 {
axis += tensor.dim as i64;
}

axis as usize
}

/// Create concat config from the attributes of the node
pub fn concat_config(node: &Node) -> usize {
// the axis is the last dimension (Default: 1 per ONNX spec)
Expand Down
10 changes: 10 additions & 0 deletions crates/burn-import/src/onnx/to_burn.rs
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@ use crate::{
burn::{
graph::BurnGraph,
node::{
argmax::ArgMaxNode,
avg_pool1d::AvgPool1dNode,
avg_pool2d::AvgPool2dNode,
batch_norm::BatchNormNode,
Expand Down Expand Up @@ -235,6 +236,7 @@ impl OnnxGraph {
for node in self.nodes {
match node.node_type {
NodeType::Add => graph.register(Self::add_conversion(node)),
NodeType::ArgMax => graph.register(Self::argmax_conversion(node)),
NodeType::Sub => graph.register(Self::sub_conversion(node)),
NodeType::Mul => graph.register(Self::mul_conversion(node)),
NodeType::Div => graph.register(Self::div_conversion(node)),
Expand Down Expand Up @@ -681,6 +683,14 @@ impl OnnxGraph {
UnaryNode::tanh(input, output)
}

fn argmax_conversion(node: Node) -> ArgMaxNode {
let input = node.inputs.first().unwrap().to_tensor_type();
let output = node.outputs.first().unwrap().to_tensor_type();
let axis = argmax_config(&node);

ArgMaxNode::new(input, output, axis)
}

fn concat_conversion(node: Node) -> ConcatNode {
let inputs = node
.inputs
Expand Down

0 comments on commit 13a6f84

Please sign in to comment.