feat: added slice onnx import (#1856)
* feat: added slice onnx import

* fix: axes, steps handling
JachymPutta authored Jun 11, 2024
1 parent dd60446 commit 671ec8c
Showing 12 changed files with 318 additions and 3 deletions.
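For context, ONNX `Slice` takes `starts`, `ends`, and optional `axes`/`steps` inputs and copies a sub-tensor out of its input, while Burn's `Tensor::slice` takes one half-open range per axis. The importer added here maps the constant `starts`/`ends` values onto those ranges (the test fixture in this commit uses unit steps). A minimal sketch of the resulting semantics, assuming the `burn` crate with its `NdArray` backend enabled:

use burn::backend::NdArray;
use burn::tensor::Tensor;

fn main() {
    let device = Default::default();
    // Same [2, 10] input as the test fixture below.
    let input = Tensor::<NdArray, 2>::from_floats(
        [
            [1., 2., 3., 4., 5., 6., 7., 8., 9., 10.],
            [11., 12., 13., 14., 15., 16., 17., 18., 19., 20.],
        ],
        &device,
    );
    // ONNX starts = [0, 0], ends = [1, 5] become the ranges 0..1 and 0..5.
    let output = input.slice([0..1, 0..5]);
    assert_eq!(output.dims(), [1, 5]); // first row, first five columns
}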
2 changes: 1 addition & 1 deletion crates/burn-import/SUPPORTED-ONNX-OPS.md
@@ -171,7 +171,7 @@ represent the corresponding Burn Op.
| [Sin][164] |||
| [Sinh][165] |||
| [Size][166] |||
- | [Slice][167] | ❌ | ❌ |
+ | [Slice][167] | ✅ | ✅ |
| [Softmax][168] |||
| [SoftmaxCrossEntropyLoss][169] |||
| [Softplus][170] |||
1 change: 1 addition & 0 deletions crates/burn-import/onnx-tests/build.rs
@@ -69,6 +69,7 @@ fn main() {
.input("tests/conv_transpose2d/conv_transpose2d.onnx")
.input("tests/pow/pow.onnx")
.input("tests/pow/pow_int.onnx")
.input("tests/slice/slice.onnx")
.input("tests/sum/sum.onnx")
.input("tests/sum/sum_int.onnx")
.input("tests/unsqueeze/unsqueeze.onnx")
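Each `.input(...)` call above registers an ONNX fixture with burn-import's code generator, which turns the model into Rust source at build time. A trimmed sketch of that pattern, based on burn-import's `ModelGen` builder (the real build.rs registers every fixture, not just slice):

use burn_import::onnx::ModelGen;

fn main() {
    // Convert the ONNX fixture into Rust code under OUT_DIR at build time.
    ModelGen::new()
        .input("tests/slice/slice.onnx")
        .out_dir("model/")
        .run_from_script();
}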
19 changes: 19 additions & 0 deletions crates/burn-import/onnx-tests/tests/onnx_tests.rs
@@ -71,6 +71,7 @@ include_models!(
    sigmoid,
    sign,
    sin,
    slice,
    softmax,
    sqrt,
    sub_int,
@@ -459,6 +460,24 @@ mod tests {
        assert!(expected_sum_2d.approx_eq(output_sum_2d, (1.0e-4, 2)));
    }

    #[test]
    fn slice() {
        let model: slice::Model<Backend> = slice::Model::default();
        let device = Default::default();

        let input = Tensor::<Backend, 2>::from_floats(
            [
                [1., 2., 3., 4., 5., 6., 7., 8., 9., 10.],
                [11., 12., 13., 14., 15., 16., 17., 18., 19., 20.],
            ],
            &device,
        );
        let output = model.forward(input);
        let expected = Data::from([[1., 2., 3., 4., 5.]]);

        assert_eq!(output.to_data(), expected);
    }

    #[test]
    fn softmax() {
        // Initialize the model without weights (because the exported file does not contain them)
Binary file added crates/burn-import/onnx-tests/tests/slice/slice.onnx
Binary file not shown.
101 changes: 101 additions & 0 deletions crates/burn-import/onnx-tests/tests/slice/slice.py
@@ -0,0 +1,101 @@
#!/usr/bin/env python3

# used to generate model: onnx-tests/tests/slice/slice.onnx

import onnx
from onnx import helper, TensorProto

def main() -> None:
    # Starts: the slice begins at index 0 on both axes
    starts_val = [0, 0]
    starts_tensor = helper.make_tensor(
        name="starts",
        data_type=TensorProto.INT64,
        dims=[len(starts_val)],
        vals=starts_val,
    )
    starts_node = helper.make_node(
        "Constant",
        name="starts_constant",
        inputs=[],
        outputs=["starts"],
        value=starts_tensor,
    )

    # Ends: the slice stops before index 1 on axis 0 and index 5 on axis 1 (exclusive)
    ends_val = [1, 5]
    ends_tensor = helper.make_tensor(
        name="ends",
        data_type=TensorProto.INT64,
        dims=[len(ends_val)],
        vals=ends_val,
    )
    ends_node = helper.make_node(
        "Constant",
        name="ends_constant",
        inputs=[],
        outputs=["ends"],
        value=ends_tensor,
    )

    # Axes: slice along both axis 0 and axis 1
    axes_val = [0, 1]
    axes_tensor = helper.make_tensor(
        name="axes",
        data_type=TensorProto.INT64,
        dims=[len(axes_val)],
        vals=axes_val,
    )
    axes_node = helper.make_node(
        "Constant",
        name="axes_constant",
        inputs=[],
        outputs=["axes"],
        value=axes_tensor,
    )

    # Steps: unit step along each sliced axis
    steps_val = [1, 1]
    steps_tensor = helper.make_tensor(
        name="steps",
        data_type=TensorProto.INT64,
        dims=[len(steps_val)],
        vals=steps_val,
    )
    steps_node = helper.make_node(
        "Constant",
        name="steps_constant",
        inputs=[],
        outputs=["steps"],
        value=steps_tensor,
    )

    # Define the Slice node that uses the outputs from the constant nodes
    slice_node = helper.make_node(
        "Slice",
        name="slice_node",
        inputs=["input_tensor", "starts", "ends", "axes", "steps"],
        outputs=["output"],
    )

    # Create the graph
    graph_def = helper.make_graph(
        nodes=[starts_node, ends_node, axes_node, steps_node, slice_node],
        name="SliceGraph",
        inputs=[
            helper.make_tensor_value_info("input_tensor", TensorProto.FLOAT, [2, 10]),
        ],
        outputs=[
            helper.make_tensor_value_info("output", TensorProto.FLOAT, [1, 5]),
        ],
    )

    # Create the model
    model_def = helper.make_model(graph_def, producer_name="slice")

    # Save the model to a file
    onnx.save(model_def, "slice.onnx")


if __name__ == "__main__":
    main()
5 changes: 4 additions & 1 deletion crates/burn-import/src/burn/node/base.rs
@@ -7,7 +7,7 @@ use super::{
    layer_norm::LayerNormNode, linear::LinearNode, mask_where::WhereNode, matmul::MatmulNode,
    max_pool1d::MaxPool1dNode, max_pool2d::MaxPool2dNode, prelu::PReluNode,
    random_normal::RandomNormalNode, random_uniform::RandomUniformNode, range::RangeNode,
-   reshape::ReshapeNode, squeeze::SqueezeNode, sum::SumNode, unary::UnaryNode,
+   reshape::ReshapeNode, slice::SliceNode, squeeze::SqueezeNode, sum::SumNode, unary::UnaryNode,
    unsqueeze::UnsqueezeNode,
};
use crate::burn::{BurnImports, Scope, Type};
@@ -102,6 +102,7 @@ pub enum Node<PS: PrecisionSettings> {
    MaxPool2d(MaxPool2dNode),
    Range(RangeNode),
    Reshape(ReshapeNode),
    Slice(SliceNode),
    Squeeze(SqueezeNode),
    Sum(SumNode),
    Unary(UnaryNode),
@@ -139,6 +140,7 @@ macro_rules! match_all {
            Node::MaxPool2d(node) => $func(node),
            Node::Range(node) => $func(node),
            Node::Reshape(node) => $func(node),
            Node::Slice(node) => $func(node),
            Node::Squeeze(node) => $func(node),
            Node::Sum(node) => $func(node),
            Node::Unary(node) => $func(node),
@@ -186,6 +188,7 @@ impl<PS: PrecisionSettings> Node<PS> {
            Node::MaxPool2d(_) => "max_pool2d",
            Node::Range(_) => "range",
            Node::Reshape(_) => "reshape",
            Node::Slice(_) => "slice",
            Node::Squeeze(_) => "squeeze",
            Node::Sum(_) => "add",
            Node::Unary(unary) => unary.kind.as_str(),
1 change: 1 addition & 0 deletions crates/burn-import/src/burn/node/mod.rs
@@ -27,6 +27,7 @@ pub(crate) mod random_normal;
pub(crate) mod random_uniform;
pub(crate) mod range;
pub(crate) mod reshape;
pub(crate) mod slice;
pub(crate) mod squeeze;
pub(crate) mod sum;
pub(crate) mod unary;
90 changes: 90 additions & 0 deletions crates/burn-import/src/burn/node/slice.rs
@@ -0,0 +1,90 @@
use super::{Node, NodeCodegen};
use crate::burn::{Scope, TensorType, Type};
use burn::record::PrecisionSettings;
use proc_macro2::TokenStream;
use quote::quote;

#[derive(Debug, Clone, new)]
pub struct SliceNode {
    pub input: TensorType,
    pub output: TensorType,
    pub starts: Vec<usize>,
    pub ends: Vec<usize>,
}

impl<PS: PrecisionSettings> NodeCodegen<PS> for SliceNode {
    fn output_types(&self) -> Vec<Type> {
        vec![Type::Tensor(self.output.clone())]
    }

    fn input_types(&self) -> Vec<Type> {
        vec![Type::Tensor(self.input.clone())]
    }

    fn forward(&self, scope: &mut Scope, node_position: usize) -> TokenStream {
        let input = scope.tensor_use_owned(&self.input, node_position);
        let output = &self.output.name;
        let starts = &self.starts;
        let ends = &self.ends;

        // Emit one half-open range per axis: [start0..end0, start1..end1, ...]
        quote! {
            let #output = #input.slice([#(#starts..#ends),*]);
        }
    }

    fn into_node(self) -> Node<PS> {
        Node::Slice(self)
    }
}
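A note on the `quote!` call in `forward` above: the `#(#starts..#ends),*` repetition walks both vectors in lockstep and emits one comma-separated range expression per axis. A standalone sketch of that expansion, assuming only the `quote` crate as a dependency:

use quote::quote;

fn main() {
    let starts = vec![0usize, 0];
    let ends = vec![1usize, 5];
    let tokens = quote! {
        let output = input.slice([#(#starts..#ends),*]);
    };
    // Prints the token stream, roughly:
    // let output = input.slice([0usize..1usize, 0usize..5usize]);
    println!("{tokens}");
}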

#[cfg(test)]
mod tests {
    use burn::record::FullPrecisionSettings;

    use super::*;
    use crate::burn::{
        graph::BurnGraph,
        node::{slice::SliceNode, test::assert_tokens},
        TensorType,
    };

    #[test]
    fn test_codegen_slice() {
        let mut graph = BurnGraph::<FullPrecisionSettings>::default();
        graph.register(SliceNode::new(
            TensorType::new_float("tensor1", 4),
            TensorType::new_float("tensor2", 4),
            vec![0, 0, 0, 0],
            vec![1, 1, 1, 1],
        ));
        graph.register_input_output(vec!["tensor1".to_string()], vec!["tensor2".to_string()]);

        let expected = quote! {
            use burn::{
                module::Module,
                tensor::{backend::Backend, Tensor},
            };

            #[derive(Module, Debug)]
            pub struct Model<B: Backend> {
                phantom: core::marker::PhantomData<B>,
                device: burn::module::Ignored<B::Device>,
            }

            impl<B: Backend> Model<B> {
                #[allow(unused_variables)]
                pub fn new(device: &B::Device) -> Self {
                    Self {
                        phantom: core::marker::PhantomData,
                        device: burn::module::Ignored(device.clone()),
                    }
                }

                #[allow(clippy::let_and_return, clippy::approx_constant)]
                pub fn forward(&self, tensor1: Tensor<B, 4>) -> Tensor<B, 4> {
                    let tensor2 = tensor1.slice([0usize..1usize, 0usize..1usize, 0usize..1usize, 0usize..1usize]);

                    tensor2
                }
            }
        };

        assert_tokens(graph.codegen(), expected);
    }
}
28 changes: 28 additions & 0 deletions crates/burn-import/src/onnx/dim_inference.rs
@@ -63,6 +63,7 @@ pub fn dim_inference(node: &mut Node) {
        NodeType::Sigmoid => same_as_input(node),
        NodeType::Sign => same_as_input(node),
        NodeType::Sin => same_as_input(node),
        NodeType::Slice => slice_update_outputs(node),
        NodeType::Softmax => same_as_input(node),
        NodeType::Sqrt => same_as_input(node),
        NodeType::Sub => same_as_input(node),
@@ -423,6 +424,33 @@ fn squeeze_update_output(node: &mut Node) {
    });
}

/// Infer the rank of a Slice output from its constant `starts` input;
/// the concrete output shape is only known at runtime.
fn slice_update_outputs(node: &mut Node) {
    let shape = match &node.inputs[1].value {
        Some(value) => match value {
            Data::Int64s(shape) => Some(shape.clone()),
            _ => panic!("Slice: invalid input types"),
        },
        None => None,
    };

    if shape.is_none() {
        panic!("Slice: invalid shape");
    }

    let output = match &node.outputs[0].ty {
        ArgType::Tensor(tensor) => tensor.clone(),
        _ => panic!("Slice: invalid output types"),
    };

    if let Some(shape) = shape {
        node.outputs[0].ty = ArgType::Tensor(TensorType {
            dim: shape.len(),
            shape: None, // shape is calculated at runtime
            ..output
        });
    }
}

/// Update the output tensor dimension based on the "axes" attribute or the second input
fn unsqueeze_update_output(node: &mut Node) {
    let axes = if node.inputs.len() == 2 {
3 changes: 2 additions & 1 deletion crates/burn-import/src/onnx/from_onnx.rs
@@ -18,7 +18,7 @@ use super::ir::{ArgType, Argument, Node, NodeType};

use protobuf::Message;

- const LIFT_CONSTANTS_FOR_NODE_TYPES: [NodeType; 10] = [
+ const LIFT_CONSTANTS_FOR_NODE_TYPES: [NodeType; 11] = [
    NodeType::BatchNormalization,
    NodeType::Clip,
    NodeType::Conv1d,
@@ -28,6 +28,7 @@ const LIFT_CONSTANTS_FOR_NODE_TYPES: [NodeType; 10] = [
    NodeType::Reshape,
    NodeType::Unsqueeze,
    NodeType::ReduceSum,
    NodeType::Slice,
    NodeType::Squeeze,
];
