Skip to content

Commit

Permalink
Enable negative starts and ends for slice op
Browse files Browse the repository at this point in the history
  • Loading branch information
JC committed Jul 6, 2024
1 parent d6efb3c commit c2ba5a9
Show file tree
Hide file tree
Showing 4 changed files with 34 additions and 15 deletions.
9 changes: 8 additions & 1 deletion crates/burn-import/onnx-tests/tests/onnx_tests.rs
Original file line number Diff line number Diff line change
Expand Up @@ -496,11 +496,18 @@ mod tests {
[
[1., 2., 3., 4., 5., 6., 7., 8., 9., 10.],
[11., 12., 13., 14., 15., 16., 17., 18., 19., 20.],
[21., 22., 23., 24., 25., 26., 27., 28., 29., 30.],
[31., 32., 33., 34., 35., 36., 37., 38., 39., 40.],
[41., 42., 43., 44., 45., 46., 47., 48., 49., 50.],
],
&device,
);
let output = model.forward(input);
let expected = TensorData::from([[1f32, 2., 3., 4., 5.]]);
let expected = TensorData::from([
[1f32, 2., 3., 4., 5.],
[11f32, 12., 13., 14., 15.],
[21., 22., 23., 24., 25.],
]);

output.to_data().assert_eq(&expected, true);
}
Expand Down
Binary file modified crates/burn-import/onnx-tests/tests/slice/slice.onnx
Binary file not shown.
13 changes: 6 additions & 7 deletions crates/burn-import/onnx-tests/tests/slice/slice.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,9 +5,10 @@
import onnx
from onnx import helper, TensorProto


def main() -> None:
# Starts
starts_val = [0,0] # Example shape value
starts_val = [-5, 0] # Equivalently [0, 0]
starts_tensor = helper.make_tensor(
name="starts",
data_type=TensorProto.INT64,
Expand All @@ -23,7 +24,7 @@ def main() -> None:
)

# Ends
ends_val = [1,5] # Example shape value
ends_val = [3, -5] # Equivalently [3, 5]
ends_tensor = helper.make_tensor(
name="ends",
data_type=TensorProto.INT64,
Expand All @@ -39,7 +40,7 @@ def main() -> None:
)

# Axes
axes_val = [0,1] # Example shape value
axes_val = [0, 1] # Example shape value
axes_tensor = helper.make_tensor(
name="axes",
data_type=TensorProto.INT64,
Expand Down Expand Up @@ -83,11 +84,9 @@ def main() -> None:
nodes=[starts_node, ends_node, axes_node, steps_node, slice_node],
name="SliceGraph",
inputs=[
helper.make_tensor_value_info("input_tensor", TensorProto.FLOAT, [2, 10]),
],
outputs=[
helper.make_tensor_value_info("output", TensorProto.FLOAT, [1, 5])
helper.make_tensor_value_info("input_tensor", TensorProto.FLOAT, [5, 10]),
],
outputs=[helper.make_tensor_value_info("output", TensorProto.FLOAT, [2, 5])],
)

# Create the model
Expand Down
27 changes: 20 additions & 7 deletions crates/burn-import/src/onnx/op_configuration.rs
Original file line number Diff line number Diff line change
Expand Up @@ -1287,15 +1287,24 @@ pub fn slice_config(node: &Node) -> (Vec<usize>, Vec<usize>) {
let start_value = &node.inputs[1].value;
let end_value = &node.inputs[2].value;

let tensor_shape: &Vec<usize> = match &node.inputs[0].ty {
ArgType::Tensor(tensor) => tensor.shape.as_ref().unwrap(),
_ => panic!("Only tensor input is valid"),
};

let starts = match &node.inputs[1].ty {
ArgType::Tensor(tensor) => {
assert_eq!(tensor.dim, 1, "Slice: ends tensor must be 1D");
assert_eq!(tensor.dim, 1, "Slice: starts tensor must be 1D");
if let Some(Data::Int64s(shape)) = start_value.as_ref() {
shape
.iter()
.map(|x| {
assert!(*x >= 0, "Slice: start must be positive");
*x as usize
.enumerate()
.map(|(i, x)| {
if x.is_negative() {
tensor_shape[i] - x.wrapping_abs() as usize
} else {
*x as usize
}
})
.collect()
} else {
Expand All @@ -1311,9 +1320,13 @@ pub fn slice_config(node: &Node) -> (Vec<usize>, Vec<usize>) {
if let Some(Data::Int64s(shape)) = end_value.as_ref() {
shape
.iter()
.map(|x| {
assert!(*x >= 0, "Slice: end must be positive");
*x as usize
.enumerate()
.map(|(i, x)| {
if x.is_negative() {
tensor_shape[i] - x.wrapping_abs() as usize
} else {
*x as usize
}
})
.collect()
} else {
Expand Down

0 comments on commit c2ba5a9

Please sign in to comment.