Skip to content

Commit

Permalink
feat: added expand to import
Browse files Browse the repository at this point in the history
  • Loading branch information
JachymPutta committed May 31, 2024
1 parent da8a522 commit c282d08
Show file tree
Hide file tree
Showing 6 changed files with 15 additions and 16 deletions.
12 changes: 6 additions & 6 deletions crates/burn-import/onnx-tests/tests/expand/expand.onnx
Original file line number Diff line number Diff line change
@@ -1,14 +1,14 @@

expand:�
expand:�
>shapeshape_constant"Constant*
value*:Bshape�
.
input_tensor
shapeoutput/Expand"Expand ExpandGraphZ
input_tensor


b
shapeoutput/Expand"Expand ExpandGraphZ
input_tensor


b
output


Expand Down
2 changes: 1 addition & 1 deletion crates/burn-import/onnx-tests/tests/expand/expand.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,7 +36,7 @@ def main() -> None:
nodes=[shape_node, expand_node],
name='ExpandGraph',
inputs=[
helper.make_tensor_value_info('input_tensor', TensorProto.FLOAT, [2]),
helper.make_tensor_value_info('input_tensor', TensorProto.FLOAT, [2, 1]),
],
outputs=[
helper.make_tensor_value_info('output', TensorProto.FLOAT, [2, 2])
Expand Down
11 changes: 5 additions & 6 deletions crates/burn-import/onnx-tests/tests/onnx_tests.rs
Original file line number Diff line number Diff line change
Expand Up @@ -1124,13 +1124,12 @@ mod tests {
let device = Default::default();
let model: expand::Model<Backend> = expand::Model::new(&device);

let input1 = Tensor::<Backend, 4>::from_floats([[[[-1.0, 1.0, 42.0, 3.0]]]], &device);
let input2 = Tensor::<Backend, 1, Int>::from_ints([3, 2], &device);
let input1 = Tensor::<Backend, 2>::from_floats([[-1.0], [1.0]], &device);

// let output = model.forward(input1, input2);
// let expected_shape = Shape::from([3, 2]);
//
// assert_eq!(output.shape(), expected_shape);
let output = model.forward(input1);
let expected_shape = Shape::from([2, 2]);

assert_eq!(output.shape(), expected_shape);
}

#[test]
Expand Down
2 changes: 1 addition & 1 deletion crates/burn-import/src/onnx/dim_inference.rs
Original file line number Diff line number Diff line change
Expand Up @@ -481,7 +481,7 @@ fn expand_update_outputs(node: &mut Node) {
None => None,
}
} else {
node.attrs.get("shape").cloned().map(|v| v.into_i64s())
panic!("Expand: invalid number of inputs");
};

let output = match &node.outputs[0].ty {
Expand Down
3 changes: 2 additions & 1 deletion crates/burn-import/src/onnx/from_onnx.rs
Original file line number Diff line number Diff line change
Expand Up @@ -17,12 +17,13 @@ use super::ir::{ArgType, Argument, Node, NodeType};

use protobuf::Message;

const LIFT_CONSTANTS_FOR_NODE_TYPES: [NodeType; 9] = [
const LIFT_CONSTANTS_FOR_NODE_TYPES: [NodeType; 10] = [
NodeType::BatchNormalization,
NodeType::Clip,
NodeType::Conv1d,
NodeType::Conv2d,
NodeType::Dropout,
NodeType::Expand,
NodeType::Reshape,
NodeType::Unsqueeze,
NodeType::ReduceSum,
Expand Down
1 change: 0 additions & 1 deletion crates/burn-import/src/onnx/to_burn.rs
Original file line number Diff line number Diff line change
Expand Up @@ -878,7 +878,6 @@ impl OnnxGraph {
fn expand_conversion(node: Node) -> ExpandNode {
let input = node.inputs.first().unwrap().to_tensor_type();
let output = node.outputs.first().unwrap().to_tensor_type();
println!("{:?}", node);
let shape = expand_config(&node);

ExpandNode::new(input, output, shape)
Expand Down

0 comments on commit c282d08

Please sign in to comment.