Skip to content

Commit

Permalink
Enable negative starts and ends for slice op (#1981)
Browse files Browse the repository at this point in the history
* Enable negative starts and ends for slice op

* Refactor slice_config

---------

Co-authored-by: JC <you@example.com>
  • Loading branch information
johnhuichen and JC authored Jul 7, 2024
1 parent 3f9e979 commit c9e9054
Show file tree
Hide file tree
Showing 4 changed files with 44 additions and 43 deletions.
9 changes: 8 additions & 1 deletion crates/burn-import/onnx-tests/tests/onnx_tests.rs
Original file line number Diff line number Diff line change
Expand Up @@ -496,11 +496,18 @@ mod tests {
[
[1., 2., 3., 4., 5., 6., 7., 8., 9., 10.],
[11., 12., 13., 14., 15., 16., 17., 18., 19., 20.],
[21., 22., 23., 24., 25., 26., 27., 28., 29., 30.],
[31., 32., 33., 34., 35., 36., 37., 38., 39., 40.],
[41., 42., 43., 44., 45., 46., 47., 48., 49., 50.],
],
&device,
);
let output = model.forward(input);
let expected = TensorData::from([[1f32, 2., 3., 4., 5.]]);
let expected = TensorData::from([
[1f32, 2., 3., 4., 5.],
[11f32, 12., 13., 14., 15.],
[21., 22., 23., 24., 25.],
]);

output.to_data().assert_eq(&expected, true);
}
Expand Down
Binary file modified crates/burn-import/onnx-tests/tests/slice/slice.onnx
Binary file not shown.
13 changes: 6 additions & 7 deletions crates/burn-import/onnx-tests/tests/slice/slice.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,9 +5,10 @@
import onnx
from onnx import helper, TensorProto


def main() -> None:
# Starts
starts_val = [0,0] # Example shape value
starts_val = [-5, 0] # Equivalently [0, 0]
starts_tensor = helper.make_tensor(
name="starts",
data_type=TensorProto.INT64,
Expand All @@ -23,7 +24,7 @@ def main() -> None:
)

# Ends
ends_val = [1,5] # Example shape value
ends_val = [3, -5] # Equivalently [3, 5]
ends_tensor = helper.make_tensor(
name="ends",
data_type=TensorProto.INT64,
Expand All @@ -39,7 +40,7 @@ def main() -> None:
)

# Axes
axes_val = [0,1] # Example shape value
axes_val = [0, 1] # Example shape value
axes_tensor = helper.make_tensor(
name="axes",
data_type=TensorProto.INT64,
Expand Down Expand Up @@ -83,11 +84,9 @@ def main() -> None:
nodes=[starts_node, ends_node, axes_node, steps_node, slice_node],
name="SliceGraph",
inputs=[
helper.make_tensor_value_info("input_tensor", TensorProto.FLOAT, [2, 10]),
],
outputs=[
helper.make_tensor_value_info("output", TensorProto.FLOAT, [1, 5])
helper.make_tensor_value_info("input_tensor", TensorProto.FLOAT, [5, 10]),
],
outputs=[helper.make_tensor_value_info("output", TensorProto.FLOAT, [2, 5])],
)

# Create the model
Expand Down
65 changes: 30 additions & 35 deletions crates/burn-import/src/onnx/op_configuration.rs
Original file line number Diff line number Diff line change
Expand Up @@ -1284,44 +1284,39 @@ pub fn shape_config(curr: &Node) -> (usize, usize) {
}

pub fn slice_config(node: &Node) -> (Vec<usize>, Vec<usize>) {
let start_value = &node.inputs[1].value;
let end_value = &node.inputs[2].value;

let starts = match &node.inputs[1].ty {
ArgType::Tensor(tensor) => {
assert_eq!(tensor.dim, 1, "Slice: ends tensor must be 1D");
if let Some(Data::Int64s(shape)) = start_value.as_ref() {
shape
.iter()
.map(|x| {
assert!(*x >= 0, "Slice: start must be positive");
fn ensure_1d_tensor(node: &Node, index: usize) {
match &node.inputs[index].ty {
ArgType::Tensor(tensor) => assert_eq!(tensor.dim, 1, "Slice: tensor must be 1D"),
_ => panic!("Only tensor input is valid"),
};
}

fn get_input_values(node: &Node, index: usize) -> Vec<usize> {
let tensor_shape = match &node.inputs[0].ty {
ArgType::Tensor(tensor) => tensor.shape.as_ref().unwrap(),
_ => panic!("Only tensor input is valid"),
};
match &node.inputs[index].value {
Some(Data::Int64s(shape)) => shape
.iter()
.enumerate()
.map(|(i, x)| {
if x.is_negative() {
tensor_shape[i] - x.wrapping_abs() as usize
} else {
*x as usize
})
.collect()
} else {
panic!("Tensor data type must be int64")
}
}
})
.collect(),
_ => panic!("Tensor data type must be int64"),
}
_ => panic!("Only tensor input is valid for shape"),
};
}

let ends = match &node.inputs[2].ty {
ArgType::Tensor(tensor) => {
assert_eq!(tensor.dim, 1, "Slice: ends tensor must be 1D");
if let Some(Data::Int64s(shape)) = end_value.as_ref() {
shape
.iter()
.map(|x| {
assert!(*x >= 0, "Slice: end must be positive");
*x as usize
})
.collect()
} else {
panic!("Tensor data type must be int64")
}
}
_ => panic!("Only tensor input is valid for shape"),
};
ensure_1d_tensor(node, 1);
ensure_1d_tensor(node, 2);

let starts = get_input_values(node, 1);
let ends = get_input_values(node, 2);

for (key, value) in node.attrs.iter() {
match key.as_str() {
Expand Down

0 comments on commit c9e9054

Please sign in to comment.