Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Feature/onnx argmax #1814

Merged
merged 8 commits into from
May 31, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion crates/burn-import/SUPPORTED-ONNX-OPS.md
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@ represent the corresponding Burn Op.
| [Acosh][3] | ❌ | ❌ |
| [Add][4] | ✅ | ✅ |
| [And][5] | ❌ | ❌ |
| [ArgMax][6] | | ✅ |
| [ArgMax][6] | | ✅ |
| [ArgMin][7] | ❌ | ❌ |
| [Asin][8] | ❌ | ❌ |
| [Asinh][9] | ❌ | ❌ |
Expand Down
1 change: 1 addition & 0 deletions crates/burn-import/onnx-tests/build.rs
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@ fn main() {
ModelGen::new()
.input("tests/add/add_int.onnx")
.input("tests/add/add.onnx")
.input("tests/argmax/argmax.onnx")
.input("tests/avg_pool1d/avg_pool1d.onnx")
.input("tests/avg_pool2d/avg_pool2d.onnx")
.input("tests/batch_norm/batch_norm.onnx")
Expand Down
Binary file not shown.
41 changes: 41 additions & 0 deletions crates/burn-import/onnx-tests/tests/argmax/argmax.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,41 @@
#!/usr/bin/env python3

# used to generate model: onnx-tests/tests/argmax/argmax.onnx

import torch
import torch.nn as nn

class Model(nn.Module):
def __init__(self, argmax_dim: int = 0):
super(Model, self).__init__()
self._argmax_dim = argmax_dim

def forward(self, x):
# Note: only keepdim=True is supported in burn
y = torch.argmax(input=x, dim=self._argmax_dim, keepdim=True)
return y

def main():

# Export to onnx
model = Model(1)
model.eval()
device = torch.device("cpu")
onnx_name = "argmax.onnx"
dummy_input = torch.randn((3, 4), device=device)
torch.onnx.export(model, dummy_input, onnx_name,
verbose=False, opset_version=16)

print("Finished exporting model to {}".format(onnx_name))

# Output some test data for use in the test
test_input = torch.randn((2, 3), device=device)
print("Test input data shape: {}".format(test_input.shape))
output = model.forward(test_input)

print("Test output data shape: {}".format(output.shape))



if __name__ == '__main__':
main()
15 changes: 15 additions & 0 deletions crates/burn-import/onnx-tests/tests/onnx_tests.rs
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@ macro_rules! include_models {
include_models!(
add_int,
add,
argmax,
avg_pool2d,
avg_pool1d,
batch_norm,
Expand Down Expand Up @@ -368,6 +369,20 @@ mod tests {
assert_eq!(output.to_data(), expected);
}

#[test]
fn argmax() {
// Initialize the model with weights (loaded from the exported file)
let model: argmax::Model<Backend> = argmax::Model::default();

let device = Default::default();
// Run the model
let input = Tensor::<Backend, 2>::from_floats([[1., 2., 3.], [4., 5., 6.]], &device);
let output = model.forward(input);
let expected = Data::from([[2], [2]]);

assert_eq!(output.to_data(), expected);
}

#[test]
fn globalavrpool_1d_2d() {
// The model contains 1d and 2d global average pooling nodes
Expand Down
102 changes: 102 additions & 0 deletions crates/burn-import/src/burn/node/argmax.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,102 @@
use super::{Node, NodeCodegen};
use crate::burn::{TensorKind, TensorType, ToTokens, Type};

use burn::record::PrecisionSettings;
use quote::quote;

#[derive(Debug, Clone, new)]
pub struct ArgMaxNode {
pub input: TensorType,
pub output: TensorType,
pub axis: usize,
}

impl<PS: PrecisionSettings> NodeCodegen<PS> for ArgMaxNode {
fn output_types(&self) -> Vec<Type> {
let mut output = self.output.clone();
output.kind = TensorKind::Int;
vec![Type::Tensor(output)]
}

fn input_types(&self) -> Vec<crate::burn::Type> {
vec![Type::Tensor(self.input.clone())]
}

fn forward(
&self,
scope: &mut crate::burn::Scope,
node_position: usize,
) -> proc_macro2::TokenStream {
//NOTE: select_last_index and keep_dims are not supported
let axis = self.axis.to_tokens();

let input = scope.tensor_use_owned(&self.input, node_position);
let output = &self.output.name;

quote! {
let #output = #input.argmax(#axis);
}
}

fn into_node(self) -> super::Node<PS> {
Node::ArgMax(self)
}
}

#[cfg(test)]
mod tests {

use burn::record::FullPrecisionSettings;

use super::*;
use crate::burn::{graph::BurnGraph, node::test::assert_tokens, TensorType};

#[test]
fn test_codegen_argmax() {
let mut graph = BurnGraph::<FullPrecisionSettings>::default();

graph.register(ArgMaxNode::new(
TensorType::new_float("tensor1", 2),
TensorType::new_int("tensor2", 2),
1,
));

graph.register_input_output(vec!["tensor1".to_string()], vec!["tensor2".to_string()]);

let expected = quote! {
use burn::tensor::Int;
use burn::{
module::Module,
tensor::{backend::Backend, Tensor},
};

#[derive(Module, Debug)]
pub struct Model<B: Backend> {
phantom: core::marker::PhantomData<B>,
device: burn::module::Ignored<B::Device>,
}

impl<B: Backend> Model <B> {
#[allow(unused_variables)]
pub fn new(device: &B::Device) -> Self {
Self {
phantom: core::marker::PhantomData,
device: burn::module::Ignored(device.clone()),
}
}

#[allow(clippy::let_and_return, clippy::approx_constant)]
pub fn forward(
&self,
tensor1: Tensor<B, 2>
) -> Tensor<B, 2, Int> {
let tensor2 = tensor1.argmax(1);

tensor2
}
}
};

assert_tokens(graph.codegen(), expected);
}
}
20 changes: 12 additions & 8 deletions crates/burn-import/src/burn/node/base.rs
Original file line number Diff line number Diff line change
@@ -1,12 +1,13 @@
use super::{
avg_pool1d::AvgPool1dNode, avg_pool2d::AvgPool2dNode, batch_norm::BatchNormNode,
binary::BinaryNode, clip::ClipNode, concat::ConcatNode, constant::ConstantNode,
conv1d::Conv1dNode, conv2d::Conv2dNode, conv_transpose_2d::ConvTranspose2dNode,
dropout::DropoutNode, gather::GatherNode, global_avg_pool::GlobalAvgPoolNode,
layer_norm::LayerNormNode, linear::LinearNode, mask_where::WhereNode, matmul::MatmulNode,
max_pool1d::MaxPool1dNode, max_pool2d::MaxPool2dNode, prelu::PReluNode,
random_normal::RandomNormalNode, random_uniform::RandomUniformNode, reshape::ReshapeNode,
squeeze::SqueezeNode, unary::UnaryNode, unsqueeze::UnsqueezeNode,
argmax::ArgMaxNode, avg_pool1d::AvgPool1dNode, avg_pool2d::AvgPool2dNode,
batch_norm::BatchNormNode, binary::BinaryNode, clip::ClipNode, concat::ConcatNode,
constant::ConstantNode, conv1d::Conv1dNode, conv2d::Conv2dNode,
conv_transpose_2d::ConvTranspose2dNode, dropout::DropoutNode, gather::GatherNode,
global_avg_pool::GlobalAvgPoolNode, layer_norm::LayerNormNode, linear::LinearNode,
mask_where::WhereNode, matmul::MatmulNode, max_pool1d::MaxPool1dNode,
max_pool2d::MaxPool2dNode, prelu::PReluNode, random_normal::RandomNormalNode,
random_uniform::RandomUniformNode, reshape::ReshapeNode, squeeze::SqueezeNode,
unary::UnaryNode, unsqueeze::UnsqueezeNode,
};
use crate::burn::{BurnImports, Scope, Type};
use burn::backend::NdArray;
Expand Down Expand Up @@ -76,6 +77,7 @@ pub trait NodeCodegen<PS: PrecisionSettings>: std::fmt::Debug {

#[derive(Debug, Clone)]
pub enum Node<PS: PrecisionSettings> {
ArgMax(ArgMaxNode),
AvgPool1d(AvgPool1dNode),
AvgPool2d(AvgPool2dNode),
BatchNorm(BatchNormNode<PS>),
Expand Down Expand Up @@ -108,6 +110,7 @@ macro_rules! match_all {
($self:expr, $func:expr) => {{
#[allow(clippy::redundant_closure_call)]
match $self {
Node::ArgMax(node) => $func(node),
Node::AvgPool1d(node) => $func(node),
Node::AvgPool2d(node) => $func(node),
Node::BatchNorm(node) => $func(node),
Expand Down Expand Up @@ -150,6 +153,7 @@ impl<PS: PrecisionSettings> Serialize for Node<PS> {
impl<PS: PrecisionSettings> Node<PS> {
pub fn name(&self) -> &str {
match self {
Node::ArgMax(_) => "argmax",
Node::AvgPool1d(_) => "avg_pool1d",
Node::AvgPool2d(_) => "avg_pool2d",
Node::BatchNorm(_) => "batch_norm",
Expand Down
1 change: 1 addition & 0 deletions crates/burn-import/src/burn/node/mod.rs
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
mod base;

pub(crate) mod argmax;
pub(crate) mod avg_pool1d;
pub(crate) mod avg_pool2d;
pub(crate) mod batch_norm;
Expand Down
20 changes: 20 additions & 0 deletions crates/burn-import/src/onnx/dim_inference.rs
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@ use super::{
pub fn dim_inference(node: &mut Node, graph_io: &mut OnnxGraphIO) {
match node.node_type {
NodeType::Add => same_as_input(node),
NodeType::ArgMax => argmax_update_outputs(node),
NodeType::AveragePool1d => same_as_input(node),
NodeType::AveragePool2d => same_as_input(node),
NodeType::BatchNormalization => same_as_input(node),
Expand Down Expand Up @@ -362,6 +363,25 @@ fn reduce_mean_update_outputs(node: &mut Node) {
}
}

fn argmax_update_outputs(node: &mut Node) {
if node.inputs.len() != 1 {
panic!("Mean: multiple inputs are not supported");
}

let node_input = &mut node.inputs[0];
let tensor = match node_input.clone().ty {
ArgType::Tensor(tensor) => tensor,
_ => panic!("Only tensor input is valid"),
};

// Note: argmax in burn does not support keepdims=false
node.outputs[0].ty = ArgType::Tensor(TensorType {
dim: tensor.dim,
shape: tensor.shape.clone(),
elem_type: ElementType::Int64,
});
}

/// Update the output tensor dimension
fn squeeze_update_output(node: &mut Node) {
let axes = if node.inputs.len() == 2 {
Expand Down
52 changes: 52 additions & 0 deletions crates/burn-import/src/onnx/op_configuration.rs
Original file line number Diff line number Diff line change
Expand Up @@ -468,6 +468,58 @@ pub fn softmax_config(node: &Node) -> usize {
axis as usize
}

/// Create argmax config from the attributes of the node
pub fn argmax_config(node: &Node) -> usize {
let mut axis: i64 = 0;

// check if the node has only one input
if node.inputs.len() != 1 {
panic!(
"Argmax: multiple inputs are not supported (got {:?})",
node.inputs.len()
);
}

// extract the shape of the input tensor
let tensor = match node.inputs.first().unwrap().clone().ty {
ArgType::Tensor(tensor) => tensor,
_ => panic!("Only tensor input is valid"),
};

// extract the attributes
for (key, value) in node.attrs.iter() {
match key.as_str() {
"axis" => axis = value.clone().into_i64(),
"select_last_index" => {
// not all params are supported in burn
if value.clone().into_i64() != 0 {
log::warn!(
"only select_last_index=0 is supported for argmax in burn. Ignoring supplied value (got {:?})",
value
);
}
}
"keepdims" => {
// not all params are supported in burn
if value.clone().into_i64() != 1 {
panic!(
"Only keepdims=1 is supported for argmax in burn (got {:?})",
value
);
}
}
_ => {}
}
}

// if axis is negative, it is counted from the end
if axis < 0 {
axis += tensor.dim as i64;
}

axis as usize
}

/// Create concat config from the attributes of the node
pub fn concat_config(node: &Node) -> usize {
// the axis is the last dimension (Default: 1 per ONNX spec)
Expand Down
10 changes: 10 additions & 0 deletions crates/burn-import/src/onnx/to_burn.rs
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@ use crate::{
burn::{
graph::BurnGraph,
node::{
argmax::ArgMaxNode,
avg_pool1d::AvgPool1dNode,
avg_pool2d::AvgPool2dNode,
batch_norm::BatchNormNode,
Expand Down Expand Up @@ -235,6 +236,7 @@ impl OnnxGraph {
for node in self.nodes {
match node.node_type {
NodeType::Add => graph.register(Self::add_conversion(node)),
NodeType::ArgMax => graph.register(Self::argmax_conversion(node)),
NodeType::Sub => graph.register(Self::sub_conversion(node)),
NodeType::Mul => graph.register(Self::mul_conversion(node)),
NodeType::Div => graph.register(Self::div_conversion(node)),
Expand Down Expand Up @@ -681,6 +683,14 @@ impl OnnxGraph {
UnaryNode::tanh(input, output)
}

fn argmax_conversion(node: Node) -> ArgMaxNode {
let input = node.inputs.first().unwrap().to_tensor_type();
let output = node.outputs.first().unwrap().to_tensor_type();
let axis = argmax_config(&node);

ArgMaxNode::new(input, output, axis)
}

fn concat_conversion(node: Node) -> ConcatNode {
let inputs = node
.inputs
Expand Down
Loading