Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add onnx mean #2119

Merged
merged 8 commits into from
Aug 7, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion crates/burn-import/SUPPORTED-ONNX-OPS.md
Original file line number Diff line number Diff line change
Expand Up @@ -108,7 +108,7 @@ represent the corresponding Burn Op.
| [MaxPool2d][98] | ✅ | ✅ |
| [MaxRoiPool][99] | ❌ | ❌ |
| [MaxUnpool][100] | ❌ | ❌ |
| [Mean][101] | ❌ | ✅ |
| [Mean][101] | ✅ | ✅ |
| [MeanVarianceNormalization][102] | ❌ | ❌ |
| [MelWeightMatrix][103] | ❌ | ❌ |
| [Min][104] | ✅ | ✅ |
Expand Down
1 change: 1 addition & 0 deletions crates/burn-import/onnx-tests/build.rs
Original file line number Diff line number Diff line change
Expand Up @@ -51,6 +51,7 @@ fn main() {
.input("tests/maxpool1d/maxpool1d.onnx")
.input("tests/maxpool2d/maxpool2d.onnx")
.input("tests/min/min.onnx")
.input("tests/mean/mean.onnx")
.input("tests/mul/mul.onnx")
.input("tests/neg/neg.onnx")
.input("tests/not/not.onnx")
Expand Down
23 changes: 23 additions & 0 deletions crates/burn-import/onnx-tests/tests/mean/mean.onnx
Original file line number Diff line number Diff line change
@@ -0,0 +1,23 @@


mean-model:‹
&
input1
input2
input3output"Mean MeanGraphZ
input1


Z
input2


Z
input3


b
output


B
41 changes: 41 additions & 0 deletions crates/burn-import/onnx-tests/tests/mean/mean.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,41 @@
#!/usr/bin/env python3

# used to generate model: onnx-tests/tests/mean/mean.onnx
#
# Builds a minimal ONNX graph containing a single `Mean` node (the
# element-wise mean of its variadic inputs) over three float tensors,
# validates it, and saves it as `mean.onnx` in the current directory.

import onnx
import onnx.helper
import onnx.checker

# Declare the three graph inputs: rank-1 float tensors of length 3.
input1 = onnx.helper.make_tensor_value_info('input1', onnx.TensorProto.FLOAT, [3])
input2 = onnx.helper.make_tensor_value_info('input2', onnx.TensorProto.FLOAT, [3])
input3 = onnx.helper.make_tensor_value_info('input3', onnx.TensorProto.FLOAT, [3])

# Declare the single graph output (same element type/shape as the inputs).
output = onnx.helper.make_tensor_value_info('output', onnx.TensorProto.FLOAT, [3])

# Create the Mean node. ONNX `Mean` is variadic: it accepts any number of
# inputs and produces their element-wise average.
mean_node = onnx.helper.make_node(
    'Mean',
    inputs=['input1', 'input2', 'input3'],
    outputs=['output']
)

# Create the graph (GraphProto) around the single node.
graph_def = onnx.helper.make_graph(
    nodes=[mean_node],
    name='MeanGraph',
    inputs=[input1, input2, input3],
    outputs=[output]
)

# Create the model (ModelProto), check it is well-formed, then serialize.
model_def = onnx.helper.make_model(graph_def, producer_name='mean-model')
onnx.checker.check_model(model_def)

# Save the ONNX model
onnx.save(model_def, 'mean.onnx')

print("ONNX model 'mean.onnx' generated successfully.")

16 changes: 16 additions & 0 deletions crates/burn-import/onnx-tests/tests/onnx_tests.rs
Original file line number Diff line number Diff line change
Expand Up @@ -60,6 +60,7 @@ include_models!(
maxpool1d,
maxpool2d,
min,
mean,
mul,
neg,
not,
Expand Down Expand Up @@ -208,6 +209,21 @@ mod tests {
output.to_data().assert_eq(&expected, true);
}

#[test]
fn mean_tensor_and_tensor() {
    // End-to-end check of the imported ONNX `Mean` model: the forward pass
    // must produce the element-wise mean of the three input tensors.
    let device = Default::default();
    let model: mean::Model<Backend> = mean::Model::default();

    let a = Tensor::<Backend, 1>::from_floats([1., 2., 3., 4.], &device);
    let b = Tensor::<Backend, 1>::from_floats([2., 2., 4., 0.], &device);
    let c = Tensor::<Backend, 1>::from_floats([3., 2., 5., -4.], &device);

    // (1+2+3)/3, (2+2+2)/3, (3+4+5)/3, (4+0-4)/3
    let expected = TensorData::from([2.0f32, 2., 4., 0.]);
    let output = model.forward(a, b, c);

    output.to_data().assert_eq(&expected, true);
}

#[test]
fn mul_scalar_with_tensor_and_tensor_with_tensor() {
// Initialize the model with weights (loaded from the exported file)
Expand Down
11 changes: 7 additions & 4 deletions crates/burn-import/src/burn/node/base.rs
Original file line number Diff line number Diff line change
Expand Up @@ -8,10 +8,10 @@ use super::{
conv_transpose_3d::ConvTranspose3dNode, dropout::DropoutNode, expand::ExpandNode,
gather::GatherNode, gather_elements::GatherElementsNode, global_avg_pool::GlobalAvgPoolNode,
layer_norm::LayerNormNode, linear::LinearNode, mask_where::WhereNode, matmul::MatmulNode,
max_pool1d::MaxPool1dNode, max_pool2d::MaxPool2dNode, pad::PadNode, prelu::PReluNode,
random_normal::RandomNormalNode, random_uniform::RandomUniformNode, range::RangeNode,
reshape::ReshapeNode, resize::ResizeNode, slice::SliceNode, squeeze::SqueezeNode, sum::SumNode,
unary::UnaryNode, unsqueeze::UnsqueezeNode,
max_pool1d::MaxPool1dNode, max_pool2d::MaxPool2dNode, mean::MeanNode, pad::PadNode,
prelu::PReluNode, random_normal::RandomNormalNode, random_uniform::RandomUniformNode,
range::RangeNode, reshape::ReshapeNode, resize::ResizeNode, slice::SliceNode,
squeeze::SqueezeNode, sum::SumNode, unary::UnaryNode, unsqueeze::UnsqueezeNode,
};
use crate::burn::{BurnImports, Scope, Type};
use burn::backend::NdArray;
Expand Down Expand Up @@ -105,6 +105,7 @@ pub enum Node<PS: PrecisionSettings> {
Matmul(MatmulNode),
MaxPool1d(MaxPool1dNode),
MaxPool2d(MaxPool2dNode),
Mean(MeanNode),
Pad(PadNode),
Range(RangeNode),
Reshape(ReshapeNode),
Expand Down Expand Up @@ -151,6 +152,7 @@ macro_rules! match_all {
Node::Matmul(node) => $func(node),
Node::MaxPool1d(node) => $func(node),
Node::MaxPool2d(node) => $func(node),
Node::Mean(node) => $func(node),
Node::Pad(node) => $func(node),
Node::Range(node) => $func(node),
Node::Reshape(node) => $func(node),
Expand Down Expand Up @@ -205,6 +207,7 @@ impl<PS: PrecisionSettings> Node<PS> {
Node::Matmul(_) => "matmul",
Node::MaxPool1d(_) => "max_pool1d",
Node::MaxPool2d(_) => "max_pool2d",
Node::Mean(_) => "mean",
Node::Pad(_) => "pad",
Node::Range(_) => "range",
Node::Reshape(_) => "reshape",
Expand Down
109 changes: 109 additions & 0 deletions crates/burn-import/src/burn/node/mean.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,109 @@
use super::{Node, NodeCodegen};
use crate::burn::{Scope, TensorType, Type};

use burn::record::PrecisionSettings;
use proc_macro2::TokenStream;
use quote::quote;

/// Codegen node for the ONNX `Mean` operator: the element-wise mean of a
/// variadic list of input tensors.
#[derive(Debug, Clone, new)]
pub struct MeanNode {
    // The input tensors to be averaged (ONNX `Mean` takes one or more inputs).
    pub inputs: Vec<TensorType>,
    // The single output tensor holding the element-wise mean.
    pub output: TensorType,
}

impl<PS: PrecisionSettings> NodeCodegen<PS> for MeanNode {
    fn input_types(&self) -> Vec<Type> {
        // Every input of ONNX `Mean` is a tensor.
        let mut types = Vec::with_capacity(self.inputs.len());
        for tensor in &self.inputs {
            types.push(Type::Tensor(tensor.clone()));
        }
        types
    }

    fn output_types(&self) -> Vec<Type> {
        vec![Type::Tensor(self.output.clone())]
    }

    fn forward(&self, scope: &mut Scope, node_position: usize) -> TokenStream {
        // One owned-use token per input tensor, in declaration order.
        let operands: Vec<TokenStream> = self
            .inputs
            .iter()
            .map(|tensor| scope.tensor_use_owned(tensor, node_position))
            .collect();

        let result = &self.output.name;
        let count = self.inputs.len() as u32;

        // Emit `let out = (in1 + in2 + ...) / N;` — sum all operands, then
        // divide by the compile-time input count to get the mean.
        quote! {
            let #result = (#(#operands)+*) / #count;
        }
    }

    fn into_node(self) -> Node<PS> {
        Node::Mean(self)
    }
}

#[cfg(test)]
mod tests {
    use burn::record::FullPrecisionSettings;

    use super::*;
    use crate::burn::{
        graph::BurnGraph,
        node::{mean::MeanNode, test::assert_tokens},
        TensorType,
    };

    /// Verifies the exact Rust source generated for a two-input `MeanNode`:
    /// the forward body must be `(tensor1 + tensor2) / 2u32`.
    #[test]
    fn test_codegen_mean() {
        let mut graph = BurnGraph::<FullPrecisionSettings>::default();

        // Register a MeanNode averaging two rank-4 float tensors.
        graph.register(MeanNode::new(
            vec![
                TensorType::new_float("tensor1", 4),
                TensorType::new_float("tensor2", 4),
            ],
            TensorType::new_float("tensor3", 4),
        ));

        graph.register_input_output(
            vec!["tensor1".to_string(), "tensor2".to_string()],
            vec!["tensor3".to_string()],
        );

        // The full expected generated model, compared token-by-token below.
        let expected = quote! {
            use burn::{
                module::Module,
                tensor::{backend::Backend, Tensor},
            };

            #[derive(Module, Debug)]
            pub struct Model<B: Backend> {
                phantom: core::marker::PhantomData<B>,
                device: burn::module::Ignored<B::Device>,
            }

            impl<B: Backend> Model <B> {
                #[allow(unused_variables)]
                pub fn new(device: &B::Device) -> Self {
                    Self {
                        phantom: core::marker::PhantomData,
                        device: burn::module::Ignored(device.clone()),
                    }
                }

                #[allow(clippy::let_and_return, clippy::approx_constant)]
                pub fn forward(
                    &self,
                    tensor1: Tensor<B, 4>,
                    tensor2: Tensor<B, 4>
                ) -> Tensor<B, 4> {
                    // Mean of two inputs: sum divided by the input count (2).
                    let tensor3 = (tensor1 + tensor2) / 2u32;

                    tensor3
                }
            }
        };

        assert_tokens(graph.codegen(), expected);
    }
}
1 change: 1 addition & 0 deletions crates/burn-import/src/burn/node/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@ pub(crate) mod mask_where;
pub(crate) mod matmul;
pub(crate) mod max_pool1d;
pub(crate) mod max_pool2d;
pub(crate) mod mean;
pub(crate) mod pad;
pub(crate) mod prelu;
pub(crate) mod random_normal;
Expand Down
9 changes: 9 additions & 0 deletions crates/burn-import/src/onnx/to_burn.rs
Original file line number Diff line number Diff line change
Expand Up @@ -78,6 +78,7 @@ use onnx_ir::{
};

pub use crate::burn::graph::RecordType;
use crate::burn::node::mean::MeanNode;

/// Generate code and states from `.onnx` files and save them to the `out_dir`.
#[derive(Debug, Default)]
Expand Down Expand Up @@ -268,6 +269,7 @@ impl ParsedOnnxGraph {
NodeType::Max => graph.register(Self::max_conversion(node)),
NodeType::MaxPool1d => graph.register(Self::max_pool1d_conversion(node)),
NodeType::MaxPool2d => graph.register(Self::max_pool2d_conversion(node)),
NodeType::Mean => graph.register(Self::mean_conversion(node)),
NodeType::PRelu => graph.register(Self::prelu_conversion::<PS>(node)),
NodeType::AveragePool1d => graph.register(Self::avg_pool_1d_conversion(node)),
NodeType::AveragePool2d => graph.register(Self::avg_pool_2d_conversion(node)),
Expand Down Expand Up @@ -972,6 +974,13 @@ impl ParsedOnnxGraph {
MaxPool2dNode::new(name, input, output, config)
}

/// Converts an ONNX `Mean` node into a `MeanNode` for code generation.
fn mean_conversion(node: Node) -> MeanNode {
    // ONNX `Mean` is variadic: map every declared input to a tensor type.
    let mut inputs = Vec::with_capacity(node.inputs.len());
    for input in node.inputs.iter() {
        inputs.push(TensorType::from(input));
    }
    let output = TensorType::from(node.outputs.first().unwrap());

    MeanNode::new(inputs, output)
}

fn prelu_conversion<PS: PrecisionSettings>(node: Node) -> PReluNode {
let input = TensorType::from(node.inputs.first().unwrap());
let output = TensorType::from(node.outputs.first().unwrap());
Expand Down
Loading