diff --git a/crates/burn-import/SUPPORTED-ONNX-OPS.md b/crates/burn-import/SUPPORTED-ONNX-OPS.md index 433aa6370e..dacae6e8cc 100644 --- a/crates/burn-import/SUPPORTED-ONNX-OPS.md +++ b/crates/burn-import/SUPPORTED-ONNX-OPS.md @@ -184,7 +184,7 @@ represent the corresponding Burn Op. | [STFT][177] | ❌ | ❌ | | [StringNormalizer][178] | ❌ | ❌ | | [Sub][179] | ✅ | ✅ | -| [Sum][180] | ❌ | ✅ | +| [Sum][180] | ✅ | ✅ | | [Tan][181] | ❌ | ❌ | | [Tanh][182] | ✅ | ✅ | | [TfIdfVectorizer][183] | ❌ | ❌ | diff --git a/crates/burn-import/onnx-tests/build.rs b/crates/burn-import/onnx-tests/build.rs index df807bd001..391a085a7d 100644 --- a/crates/burn-import/onnx-tests/build.rs +++ b/crates/burn-import/onnx-tests/build.rs @@ -69,6 +69,8 @@ fn main() { .input("tests/conv_transpose2d/conv_transpose2d.onnx") .input("tests/pow/pow.onnx") .input("tests/pow/pow_int.onnx") + .input("tests/sum/sum.onnx") + .input("tests/sum/sum_int.onnx") .input("tests/unsqueeze/unsqueeze.onnx") .input("tests/unsqueeze/unsqueeze_opset16.onnx") .input("tests/unsqueeze/unsqueeze_opset11.onnx") diff --git a/crates/burn-import/onnx-tests/tests/onnx_tests.rs b/crates/burn-import/onnx-tests/tests/onnx_tests.rs index 238671506b..7188e1c396 100644 --- a/crates/burn-import/onnx-tests/tests/onnx_tests.rs +++ b/crates/burn-import/onnx-tests/tests/onnx_tests.rs @@ -75,6 +75,8 @@ include_models!( sqrt, sub_int, sub, + sum, + sum_int, tanh, transpose, conv_transpose2d, @@ -161,6 +163,36 @@ mod tests { assert_eq!(output.to_data(), expected); } + #[test] + fn sum_tensor_and_tensor() { + let device = Default::default(); + let model: sum::Model = sum::Model::default(); + + let input1 = Tensor::::from_floats([1., 2., 3., 4.], &device); + let input2 = Tensor::::from_floats([1., 2., 3., 4.], &device); + let input3 = Tensor::::from_floats([1., 2., 3., 4.], &device); + + let output = model.forward(input1, input2, input3); + let expected = Data::from([3., 6., 9., 12.]); + + assert_eq!(output.to_data(), expected); + } + + #[test] + fn sum_int_tensor_and_int_tensor() { + let device = Default::default(); + let model: sum_int::Model = sum_int::Model::default(); + + let input1 = Tensor::::from_ints([1, 2, 3, 4], &device); + let input2 = Tensor::::from_ints([1, 2, 3, 4], &device); + let input3 = Tensor::::from_ints([1, 2, 3, 4], &device); + + let output = model.forward(input1, input2, input3); + let expected = Data::from([3, 6, 9, 12]); + + assert_eq!(output.to_data(), expected); + } + #[test] fn mul_scalar_with_tensor_and_tensor_with_tensor() { // Initialize the model with weights (loaded from the exported file) diff --git a/crates/burn-import/onnx-tests/tests/sum/sum.onnx b/crates/burn-import/onnx-tests/tests/sum/sum.onnx new file mode 100644 index 0000000000..2ce3419a4e --- /dev/null +++ b/crates/burn-import/onnx-tests/tests/sum/sum.onnx @@ -0,0 +1,22 @@ + + sum-model: +% +input1 +input2 +input3output"SumSumGraphZ +input1 + + +Z +input2 + + +Z +input3 + + +b +output + + +B \ No newline at end of file diff --git a/crates/burn-import/onnx-tests/tests/sum/sum.py b/crates/burn-import/onnx-tests/tests/sum/sum.py new file mode 100644 index 0000000000..bfb13dc8f3 --- /dev/null +++ b/crates/burn-import/onnx-tests/tests/sum/sum.py @@ -0,0 +1,41 @@ +#!/usr/bin/env python3 + +# used to generate model: onnx-tests/tests/sum/sum.onnx + +import onnx +import onnx.helper +import onnx.checker +import numpy as np + +# Create input tensors +input1 = onnx.helper.make_tensor_value_info('input1', onnx.TensorProto.FLOAT, [3]) +input2 = onnx.helper.make_tensor_value_info('input2', onnx.TensorProto.FLOAT, [3]) +input3 = onnx.helper.make_tensor_value_info('input3', onnx.TensorProto.FLOAT, [3]) + +# Create output tensor +output = onnx.helper.make_tensor_value_info('output', onnx.TensorProto.FLOAT, [3]) + +# Create the Sum node +sum_node = onnx.helper.make_node( + 'Sum', + inputs=['input1', 'input2', 'input3'], + outputs=['output'] +) + +# Create the graph (GraphProto) +graph_def = onnx.helper.make_graph( + nodes=[sum_node], + name='SumGraph', + inputs=[input1, input2, input3], + outputs=[output] +) + +# Create the model (ModelProto) +model_def = onnx.helper.make_model(graph_def, producer_name='sum-model') +onnx.checker.check_model(model_def) + +# Save the ONNX model +onnx.save(model_def, 'sum.onnx') + +print("ONNX model 'sum.onnx' generated successfully.") + diff --git a/crates/burn-import/onnx-tests/tests/sum/sum_int.onnx b/crates/burn-import/onnx-tests/tests/sum/sum_int.onnx new file mode 100644 index 0000000000..c1ce819a15 --- /dev/null +++ b/crates/burn-import/onnx-tests/tests/sum/sum_int.onnx @@ -0,0 +1,22 @@ + + sum-model: +% +input1 +input2 +input3output"SumSumGraphZ +input1 + + +Z +input2 + + +Z +input3 + + +b +output + + +B \ No newline at end of file diff --git a/crates/burn-import/onnx-tests/tests/sum/sum_int.py b/crates/burn-import/onnx-tests/tests/sum/sum_int.py new file mode 100644 index 0000000000..97ca2ba18a --- /dev/null +++ b/crates/burn-import/onnx-tests/tests/sum/sum_int.py @@ -0,0 +1,40 @@ +#!/usr/bin/env python3 + +# used to generate model: onnx-tests/tests/sum/sum.onnx + +import onnx +import onnx.helper +import onnx.checker +import numpy as np + +# Create input tensors +input1 = onnx.helper.make_tensor_value_info('input1', onnx.TensorProto.INT64, [3]) +input2 = onnx.helper.make_tensor_value_info('input2', onnx.TensorProto.INT64, [3]) +input3 = onnx.helper.make_tensor_value_info('input3', onnx.TensorProto.INT64, [3]) + +# Create output tensor +output = onnx.helper.make_tensor_value_info('output', onnx.TensorProto.INT64, [3]) + +# Create the Sum node +sum_node = onnx.helper.make_node( + 'Sum', + inputs=['input1', 'input2', 'input3'], + outputs=['output'] +) + +# Create the graph (GraphProto) +graph_def = onnx.helper.make_graph( + nodes=[sum_node], + name='SumGraph', + inputs=[input1, input2, input3], + outputs=[output] +) + +# Create the model (ModelProto) +model_def = onnx.helper.make_model(graph_def, producer_name='sum-model') +onnx.checker.check_model(model_def) + +# Save the ONNX model +onnx.save(model_def, 'sum_int.onnx') + +print("ONNX model 'sum_int.onnx' generated successfully.") diff --git a/crates/burn-import/src/burn/node/base.rs b/crates/burn-import/src/burn/node/base.rs index 15f9396855..027e98ee93 100644 --- a/crates/burn-import/src/burn/node/base.rs +++ b/crates/burn-import/src/burn/node/base.rs @@ -1,14 +1,14 @@ -use super::expand::ExpandNode; use super::{ argmax::ArgMaxNode, avg_pool1d::AvgPool1dNode, avg_pool2d::AvgPool2dNode, batch_norm::BatchNormNode, binary::BinaryNode, clip::ClipNode, concat::ConcatNode, constant::ConstantNode, conv1d::Conv1dNode, conv2d::Conv2dNode, - conv_transpose_2d::ConvTranspose2dNode, dropout::DropoutNode, gather::GatherNode, - gather_elements::GatherElementsNode, global_avg_pool::GlobalAvgPoolNode, + conv_transpose_2d::ConvTranspose2dNode, dropout::DropoutNode, expand::ExpandNode, + gather::GatherNode, gather_elements::GatherElementsNode, global_avg_pool::GlobalAvgPoolNode, layer_norm::LayerNormNode, linear::LinearNode, mask_where::WhereNode, matmul::MatmulNode, max_pool1d::MaxPool1dNode, max_pool2d::MaxPool2dNode, prelu::PReluNode, random_normal::RandomNormalNode, random_uniform::RandomUniformNode, range::RangeNode, - reshape::ReshapeNode, squeeze::SqueezeNode, unary::UnaryNode, unsqueeze::UnsqueezeNode, + reshape::ReshapeNode, squeeze::SqueezeNode, sum::SumNode, unary::UnaryNode, + unsqueeze::UnsqueezeNode, }; use crate::burn::{BurnImports, Scope, Type}; use burn::backend::NdArray; @@ -103,6 +103,7 @@ pub enum Node { Range(RangeNode), Reshape(ReshapeNode), Squeeze(SqueezeNode), + Sum(SumNode), Unary(UnaryNode), Unsqueeze(UnsqueezeNode), Where(WhereNode), @@ -139,6 +140,7 @@ macro_rules! match_all { Node::Range(node) => $func(node), Node::Reshape(node) => $func(node), Node::Squeeze(node) => $func(node), + Node::Sum(node) => $func(node), Node::Unary(node) => $func(node), Node::Unsqueeze(node) => $func(node), Node::Where(node) => $func(node), @@ -185,6 +187,7 @@ impl Node { Node::Range(_) => "range", Node::Reshape(_) => "reshape", Node::Squeeze(_) => "squeeze", + Node::Sum(_) => "add", Node::Unary(unary) => unary.kind.as_str(), Node::Unsqueeze(_) => "unsqueeze", Node::Where(_) => "where", diff --git a/crates/burn-import/src/burn/node/mod.rs b/crates/burn-import/src/burn/node/mod.rs index 5d082f8cc8..b8f44750bd 100644 --- a/crates/burn-import/src/burn/node/mod.rs +++ b/crates/burn-import/src/burn/node/mod.rs @@ -28,6 +28,7 @@ pub(crate) mod random_uniform; pub(crate) mod range; pub(crate) mod reshape; pub(crate) mod squeeze; +pub(crate) mod sum; pub(crate) mod unary; pub(crate) mod unsqueeze; pub(crate) use base::*; diff --git a/crates/burn-import/src/burn/node/sum.rs b/crates/burn-import/src/burn/node/sum.rs new file mode 100644 index 0000000000..ad0a2601f8 --- /dev/null +++ b/crates/burn-import/src/burn/node/sum.rs @@ -0,0 +1,108 @@ +use super::{Node, NodeCodegen}; +use crate::burn::{Scope, TensorType, Type}; + +use burn::record::PrecisionSettings; +use proc_macro2::TokenStream; +use quote::quote; + +#[derive(Debug, Clone, new)] +pub struct SumNode { + pub inputs: Vec, + pub output: TensorType, +} + +impl NodeCodegen for SumNode { + fn output_types(&self) -> Vec { + vec![Type::Tensor(self.output.clone())] + } + + fn input_types(&self) -> Vec { + self.inputs + .iter() + .map(|t| Type::Tensor(t.clone())) + .collect() + } + + fn forward(&self, scope: &mut Scope, node_position: usize) -> TokenStream { + let inputs = self + .inputs + .iter() + .map(|t| scope.tensor_use_owned(t, node_position)); + + let output = &self.output.name; + + quote! { + let #output = #(#inputs)+*; + } + } + + fn into_node(self) -> Node { + Node::Sum(self) + } +} + +#[cfg(test)] +mod tests { + use burn::record::FullPrecisionSettings; + + use super::*; + use crate::burn::{ + graph::BurnGraph, + node::{sum::SumNode, test::assert_tokens}, + TensorType, + }; + + #[test] + fn test_codegen_sum() { + let mut graph = BurnGraph::::default(); + + graph.register(SumNode::new( + vec![ + TensorType::new_float("tensor1", 4), + TensorType::new_float("tensor2", 4), + ], + TensorType::new_float("tensor3", 4), + )); + + graph.register_input_output( + vec!["tensor1".to_string(), "tensor2".to_string()], + vec!["tensor3".to_string()], + ); + + let expected = quote! { + use burn::{ + module::Module, + tensor::{backend::Backend, Tensor}, + }; + + #[derive(Module, Debug)] + pub struct Model { + phantom: core::marker::PhantomData, + device: burn::module::Ignored, + } + + impl Model { + #[allow(unused_variables)] + pub fn new(device: &B::Device) -> Self { + Self { + phantom: core::marker::PhantomData, + device: burn::module::Ignored(device.clone()), + } + } + + #[allow(clippy::let_and_return, clippy::approx_constant)] + pub fn forward( + &self, + tensor1: Tensor, + tensor2: Tensor + ) -> Tensor { + let tensor3 = tensor1 + tensor2; + + tensor3 + } + } + }; + + assert_tokens(graph.codegen(), expected); + } +} diff --git a/crates/burn-import/src/onnx/dim_inference.rs b/crates/burn-import/src/onnx/dim_inference.rs index 3ccd52a1f8..a70502f9ca 100644 --- a/crates/burn-import/src/onnx/dim_inference.rs +++ b/crates/burn-import/src/onnx/dim_inference.rs @@ -67,6 +67,7 @@ pub fn dim_inference(node: &mut Node, graph_io: &mut OnnxGraphIO) { NodeType::Softmax => same_as_input(node), NodeType::Sqrt => same_as_input(node), NodeType::Sub => same_as_input(node), + NodeType::Sum => same_as_input(node), NodeType::Tanh => same_as_input(node), NodeType::Transpose => same_as_input(node), NodeType::Unsqueeze => unsqueeze_update_output(node), diff --git a/crates/burn-import/src/onnx/to_burn.rs b/crates/burn-import/src/onnx/to_burn.rs index bca9efab08..aceac8f641 100644 --- a/crates/burn-import/src/onnx/to_burn.rs +++ b/crates/burn-import/src/onnx/to_burn.rs @@ -43,6 +43,7 @@ use crate::{ range::RangeNode, reshape::ReshapeNode, squeeze::SqueezeNode, + sum::SumNode, unary::UnaryNode, unsqueeze::UnsqueezeNode, }, @@ -293,6 +294,7 @@ impl OnnxGraph { NodeType::Shape => graph.register(Self::shape_conversion(node)), NodeType::Sigmoid => graph.register(Self::sigmoid_conversion(node)), NodeType::Sin => graph.register(Self::sin_conversion(node)), + NodeType::Sum => graph.register(Self::sum_conversion(node)), NodeType::Transpose => graph.register(Self::transpose_conversion(node)), NodeType::Concat => graph.register(Self::concat_conversion(node)), NodeType::Cast => graph.register(Self::cast_conversion(node)), @@ -684,6 +686,17 @@ impl OnnxGraph { UnaryNode::sin(input, output) } + fn sum_conversion(node: Node) -> SumNode { + let inputs = node + .inputs + .iter() + .map(|input| input.to_tensor_type()) + .collect(); + let output = node.outputs.first().unwrap().to_tensor_type(); + + SumNode::new(inputs, output) + } + fn reciprocal_conversion(node: Node) -> UnaryNode { let input = node.inputs.first().unwrap().to_type(); let output = node.outputs.first().unwrap().to_type();