Skip to content

Commit

Permalink
feat: expand onnx import (tracel-ai#1813)
Browse files Browse the repository at this point in the history
* feat: added expand to import
  • Loading branch information
JachymPutta authored and LilDojd committed Jun 5, 2024
1 parent 4651eda commit 37eec0f
Show file tree
Hide file tree
Showing 12 changed files with 236 additions and 2 deletions.
2 changes: 1 addition & 1 deletion crates/burn-import/SUPPORTED-ONNX-OPS.md
Original file line number Diff line number Diff line change
Expand Up @@ -59,7 +59,7 @@ represent the corresponding Burn Op.
| [Equal][51] |||
| [Erf][52] |||
| [Exp][53] |||
| [Expand][54] | ||
| [Expand][54] | ||
| [EyeLike][55] |||
| [Flatten][56] |||
| [Floor][57] |||
Expand Down
1 change: 1 addition & 0 deletions crates/burn-import/onnx-tests/build.rs
Original file line number Diff line number Diff line change
Expand Up @@ -41,6 +41,7 @@ fn main() {
.input("tests/mul/mul.onnx")
.input("tests/neg/neg.onnx")
.input("tests/not/not.onnx")
.input("tests/expand/expand.onnx")
.input("tests/greater/greater.onnx")
.input("tests/greater_or_equal/greater_or_equal.onnx")
.input("tests/less/less.onnx")
Expand Down
15 changes: 15 additions & 0 deletions crates/burn-import/onnx-tests/tests/expand/expand.onnx
Original file line number Diff line number Diff line change
@@ -0,0 +1,15 @@

expand:�
>shapeshape_constant"Constant*
value*:Bshape�
.
input_tensor
shapeoutput/Expand"Expand ExpandGraphZ
input_tensor


b
output


B
53 changes: 53 additions & 0 deletions crates/burn-import/onnx-tests/tests/expand/expand.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,53 @@
#!/usr/bin/env python3

# used to generate model: onnx-tests/tests/expand/expand.onnx

import onnx
from onnx import helper, TensorProto

def main() -> None:
# Define the shape tensor as a constant node
shape_value = [2, 2] # Example shape value
shape_tensor = helper.make_tensor(
name='shape',
data_type=TensorProto.INT64,
dims=[len(shape_value)],
vals=shape_value,
)

shape_node = helper.make_node(
'Constant',
name='shape_constant',
inputs=[],
outputs=['shape'],
value=shape_tensor,
)

# Define the Expand node that uses the outputs from the constant nodes
expand_node = helper.make_node(
'Expand',
name='/Expand',
inputs=['input_tensor', 'shape'],
outputs=['output']
)

# Create the graph
graph_def = helper.make_graph(
nodes=[shape_node, expand_node],
name='ExpandGraph',
inputs=[
helper.make_tensor_value_info('input_tensor', TensorProto.FLOAT, [2, 1]),
],
outputs=[
helper.make_tensor_value_info('output', TensorProto.FLOAT, [2, 2])
],
)

# Create the model
model_def = helper.make_model(graph_def, producer_name='expand')

# Save the model to a file
onnx.save(model_def, 'expand.onnx')

if __name__ == '__main__':
main()
14 changes: 14 additions & 0 deletions crates/burn-import/onnx-tests/tests/onnx_tests.rs
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,7 @@ include_models!(
equal,
erf,
exp,
expand,
flatten,
gather,
gelu,
Expand Down Expand Up @@ -1149,6 +1150,19 @@ mod tests {
output.to_data().assert_approx_eq(&expected, 2);
}

#[test]
fn expand() {
let device = Default::default();
let model: expand::Model<Backend> = expand::Model::new(&device);

let input1 = Tensor::<Backend, 2>::from_floats([[-1.0], [1.0]], &device);

let output = model.forward(input1);
let expected_shape = Shape::from([2, 2]);

assert_eq!(output.shape(), expected_shape);
}

#[test]
fn gelu() {
let device = Default::default();
Expand Down
4 changes: 4 additions & 0 deletions crates/burn-import/src/burn/node/base.rs
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
use super::expand::ExpandNode;
use super::{
argmax::ArgMaxNode, avg_pool1d::AvgPool1dNode, avg_pool2d::AvgPool2dNode,
batch_norm::BatchNormNode, binary::BinaryNode, clip::ClipNode, concat::ConcatNode,
Expand Down Expand Up @@ -90,6 +91,7 @@ pub enum Node<PS: PrecisionSettings> {
ConvTranspose2d(ConvTranspose2dNode<PS>),
PRelu(PReluNode<PS>),
Dropout(DropoutNode),
Expand(ExpandNode),
Gather(GatherNode),
GlobalAvgPool(GlobalAvgPoolNode),
LayerNorm(LayerNormNode<PS>),
Expand Down Expand Up @@ -124,6 +126,7 @@ macro_rules! match_all {
Node::ConvTranspose2d(node) => $func(node),
Node::PRelu(node) => $func(node),
Node::Dropout(node) => $func(node),
Node::Expand(node) => $func(node),
Node::Gather(node) => $func(node),
Node::GlobalAvgPool(node) => $func(node),
Node::LayerNorm(node) => $func(node),
Expand Down Expand Up @@ -168,6 +171,7 @@ impl<PS: PrecisionSettings> Node<PS> {
Node::ConvTranspose2d(_) => "conv_transpose2d",
Node::PRelu(_) => "prelu",
Node::Dropout(_) => "dropout",
Node::Expand(_) => "expand",
Node::Gather(_) => "gather",
Node::GlobalAvgPool(_) => "global_avg_pool",
Node::LayerNorm(_) => "layer_norm",
Expand Down
92 changes: 92 additions & 0 deletions crates/burn-import/src/burn/node/expand.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,92 @@
use super::{Node, NodeCodegen};
use crate::burn::{Scope, TensorType, ToTokens, Type};
use burn::record::PrecisionSettings;
use proc_macro2::TokenStream;
use quote::quote;

#[derive(Debug, Clone, new)]
pub struct ExpandNode {
pub input: TensorType,
pub output: TensorType,
pub shape: Vec<i64>,
}

impl<PS: PrecisionSettings> NodeCodegen<PS> for ExpandNode {
fn output_types(&self) -> Vec<Type> {
vec![Type::Tensor(self.output.clone())]
}

fn input_types(&self) -> Vec<Type> {
vec![Type::Tensor(self.input.clone())]
}

fn forward(&self, scope: &mut Scope, node_position: usize) -> TokenStream {
let input = scope.tensor_use_owned(&self.input, node_position);
let shape = &self.shape.to_tokens();
let output = &self.output.name;

quote! {
let #output = #input.expand(#shape);
}
}

fn into_node(self) -> Node<PS> {
Node::Expand(self)
}
}

#[cfg(test)]
mod tests {
use burn::record::FullPrecisionSettings;

use super::*;
use crate::burn::{
graph::BurnGraph,
node::{expand::ExpandNode, test::assert_tokens},
TensorType,
};

#[test]
fn test_codegen_nodes() {
let mut graph = BurnGraph::<FullPrecisionSettings>::default();

graph.register(ExpandNode::new(
TensorType::new_float("tensor1", 4),
TensorType::new_float("tensor2", 4),
[4, 4, 4, 4].into(),
));

graph.register_input_output(vec!["tensor1".to_string()], vec!["tensor2".to_string()]);

let expected = quote! {
use burn::{
module::Module,
tensor::{backend::Backend, Tensor},
};

#[derive(Module, Debug)]
pub struct Model<B: Backend> {
phantom: core::marker::PhantomData<B>,
device: burn::module::Ignored<B::Device>,
}

impl<B: Backend> Model <B> {
#[allow(unused_variables)]
pub fn new(device: &B::Device) -> Self {
Self {
phantom: core::marker::PhantomData,
device: burn::module::Ignored(device.clone()),
}
}
#[allow(clippy::let_and_return, clippy::approx_constant)]
pub fn forward(&self, tensor1: Tensor<B, 4>) -> Tensor<B, 4> {
let tensor2 = tensor1.expand([4,4,4,4]);

tensor2
}
}
};

assert_tokens(graph.codegen(), expected);
}
}
1 change: 1 addition & 0 deletions crates/burn-import/src/burn/node/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@ pub(crate) mod conv1d;
pub(crate) mod conv2d;
pub(crate) mod conv_transpose_2d;
pub(crate) mod dropout;
pub(crate) mod expand;
pub(crate) mod gather;
pub(crate) mod global_avg_pool;
pub(crate) mod layer_norm;
Expand Down
28 changes: 28 additions & 0 deletions crates/burn-import/src/onnx/dim_inference.rs
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,7 @@ pub fn dim_inference(node: &mut Node, graph_io: &mut OnnxGraphIO) {
NodeType::Equal => equal_update_outputs(node),
NodeType::Erf => same_as_input(node),
NodeType::Exp => same_as_input(node),
NodeType::Expand => expand_update_outputs(node),
NodeType::Flatten => flatten_update_outputs(node),
NodeType::Gelu => same_as_input(node),
NodeType::GatherElements => same_as_input(node),
Expand Down Expand Up @@ -491,6 +492,33 @@ fn equal_update_outputs(node: &mut Node) {
}
}

fn expand_update_outputs(node: &mut Node) {
let shape = if node.inputs.len() == 2 {
match &node.inputs[1].value {
Some(value) => match value {
Data::Int64s(shape) => Some(shape.clone()),
_ => panic!("Expand: invalid input types"),
},
None => None,
}
} else {
panic!("Expand: invalid number of inputs");
};

let output = match &node.outputs[0].ty {
ArgType::Tensor(tensor) => tensor.clone(),
_ => panic!("Expand: invalid output types"),
};

if let Some(shape) = shape {
node.outputs[0].ty = ArgType::Tensor(TensorType {
dim: shape.len(),
shape: None, // shape is calculated at runtime
..output
});
}
}

fn shape_update_outputs(node: &mut Node) {
if node.inputs.len() != 1 {
panic!("Shape: multiple inputs are not supported: {:?}", node);
Expand Down
3 changes: 2 additions & 1 deletion crates/burn-import/src/onnx/from_onnx.rs
Original file line number Diff line number Diff line change
Expand Up @@ -17,12 +17,13 @@ use super::ir::{ArgType, Argument, Node, NodeType};

use protobuf::Message;

const LIFT_CONSTANTS_FOR_NODE_TYPES: [NodeType; 9] = [
const LIFT_CONSTANTS_FOR_NODE_TYPES: [NodeType; 10] = [
NodeType::BatchNormalization,
NodeType::Clip,
NodeType::Conv1d,
NodeType::Conv2d,
NodeType::Dropout,
NodeType::Expand,
NodeType::Reshape,
NodeType::Unsqueeze,
NodeType::ReduceSum,
Expand Down
15 changes: 15 additions & 0 deletions crates/burn-import/src/onnx/op_configuration.rs
Original file line number Diff line number Diff line change
Expand Up @@ -262,6 +262,21 @@ pub fn avg_pool2d_config(curr: &Node) -> AvgPool2dConfig {
.with_count_include_pad(count_include_pad == 1)
}

pub fn expand_config(node: &Node) -> Vec<i64> {
let input_value = &node.inputs[1].value;
match &node.inputs[1].ty {
ArgType::Tensor(tensor) => {
assert_eq!(tensor.dim, 1, "Expand: shape tensor must be 1D");
if let Some(Data::Int64s(shape)) = input_value.as_ref() {
shape.clone()
} else {
panic!("Tensor data type must be int64")
}
}
_ => panic!("Only tensor input is valid for shape"),
}
}

/// Create a FlattenConfig from the attributes of the node
pub fn flatten_config(curr: &Node) -> (usize, usize) {
// the begin dimension is the first dimension (Default: 1 per ONNX spec)
Expand Down
10 changes: 10 additions & 0 deletions crates/burn-import/src/onnx/to_burn.rs
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,7 @@ use crate::{
conv2d::Conv2dNode,
conv_transpose_2d::ConvTranspose2dNode,
dropout::DropoutNode,
expand::ExpandNode,
gather::GatherNode,
global_avg_pool::GlobalAvgPoolNode,
layer_norm::LayerNormNode,
Expand Down Expand Up @@ -244,6 +245,7 @@ impl OnnxGraph {
NodeType::Equal => graph.register(Self::equal_conversion(node)),
NodeType::Erf => graph.register(Self::erf_conversion(node)),
NodeType::Exp => graph.register(Self::exp_conversion(node)),
NodeType::Expand => graph.register(Self::expand_conversion(node)),
NodeType::Clip => graph.register(Self::clip_conversion(node)),
NodeType::Cos => graph.register(Self::cos_conversion(node)),
NodeType::Conv1d => graph.register(Self::conv1d_conversion::<PS>(node)),
Expand Down Expand Up @@ -908,6 +910,14 @@ impl OnnxGraph {
UnaryNode::exp(input, output)
}

fn expand_conversion(node: Node) -> ExpandNode {
let input = node.inputs.first().unwrap().to_tensor_type();
let output = node.outputs.first().unwrap().to_tensor_type();
let shape = expand_config(&node);

ExpandNode::new(input, output, shape)
}

fn neg_conversion(node: Node) -> UnaryNode {
let input = node.inputs.first().unwrap().to_type();
let output = node.outputs.first().unwrap().to_type();
Expand Down

0 comments on commit 37eec0f

Please sign in to comment.