Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Enable negative starts and ends for slice op #1981

Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
9 changes: 8 additions & 1 deletion crates/burn-import/onnx-tests/tests/onnx_tests.rs
Original file line number Diff line number Diff line change
Expand Up @@ -496,11 +496,18 @@ mod tests {
[
[1., 2., 3., 4., 5., 6., 7., 8., 9., 10.],
[11., 12., 13., 14., 15., 16., 17., 18., 19., 20.],
[21., 22., 23., 24., 25., 26., 27., 28., 29., 30.],
[31., 32., 33., 34., 35., 36., 37., 38., 39., 40.],
[41., 42., 43., 44., 45., 46., 47., 48., 49., 50.],
],
&device,
);
let output = model.forward(input);
let expected = TensorData::from([[1f32, 2., 3., 4., 5.]]);
let expected = TensorData::from([
[1f32, 2., 3., 4., 5.],
[11f32, 12., 13., 14., 15.],
[21., 22., 23., 24., 25.],
]);

output.to_data().assert_eq(&expected, true);
}
Expand Down
Binary file modified crates/burn-import/onnx-tests/tests/slice/slice.onnx
Binary file not shown.
13 changes: 6 additions & 7 deletions crates/burn-import/onnx-tests/tests/slice/slice.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,9 +5,10 @@
import onnx
from onnx import helper, TensorProto


def main() -> None:
# Starts
starts_val = [0,0] # Example shape value
starts_val = [-5, 0] # Equivalently [0, 0]
starts_tensor = helper.make_tensor(
name="starts",
data_type=TensorProto.INT64,
Expand All @@ -23,7 +24,7 @@ def main() -> None:
)

# Ends
ends_val = [1,5] # Example shape value
ends_val = [3, -5] # Equivalently [3, 5]
ends_tensor = helper.make_tensor(
name="ends",
data_type=TensorProto.INT64,
Expand All @@ -39,7 +40,7 @@ def main() -> None:
)

# Axes
axes_val = [0,1] # Example shape value
axes_val = [0, 1] # Example shape value
axes_tensor = helper.make_tensor(
name="axes",
data_type=TensorProto.INT64,
Expand Down Expand Up @@ -83,11 +84,9 @@ def main() -> None:
nodes=[starts_node, ends_node, axes_node, steps_node, slice_node],
name="SliceGraph",
inputs=[
helper.make_tensor_value_info("input_tensor", TensorProto.FLOAT, [2, 10]),
],
outputs=[
helper.make_tensor_value_info("output", TensorProto.FLOAT, [1, 5])
helper.make_tensor_value_info("input_tensor", TensorProto.FLOAT, [5, 10]),
],
outputs=[helper.make_tensor_value_info("output", TensorProto.FLOAT, [2, 5])],
)

# Create the model
Expand Down
65 changes: 30 additions & 35 deletions crates/burn-import/src/onnx/op_configuration.rs
Original file line number Diff line number Diff line change
Expand Up @@ -1284,44 +1284,39 @@ pub fn shape_config(curr: &Node) -> (usize, usize) {
}

pub fn slice_config(node: &Node) -> (Vec<usize>, Vec<usize>) {
let start_value = &node.inputs[1].value;
let end_value = &node.inputs[2].value;

let starts = match &node.inputs[1].ty {
ArgType::Tensor(tensor) => {
assert_eq!(tensor.dim, 1, "Slice: ends tensor must be 1D");
if let Some(Data::Int64s(shape)) = start_value.as_ref() {
shape
.iter()
.map(|x| {
assert!(*x >= 0, "Slice: start must be positive");
fn ensure_1d_tensor(node: &Node, index: usize) {
match &node.inputs[index].ty {
ArgType::Tensor(tensor) => assert_eq!(tensor.dim, 1, "Slice: tensor must be 1D"),
_ => panic!("Only tensor input is valid"),
};
}

fn get_input_values(node: &Node, index: usize) -> Vec<usize> {
let tensor_shape = match &node.inputs[0].ty {
ArgType::Tensor(tensor) => tensor.shape.as_ref().unwrap(),
_ => panic!("Only tensor input is valid"),
};
match &node.inputs[index].value {
Some(Data::Int64s(shape)) => shape
.iter()
.enumerate()
.map(|(i, x)| {
if x.is_negative() {
tensor_shape[i] - x.wrapping_abs() as usize
} else {
*x as usize
})
.collect()
} else {
panic!("Tensor data type must be int64")
}
}
})
.collect(),
_ => panic!("Tensor data type must be int64"),
}
_ => panic!("Only tensor input is valid for shape"),
};
}

let ends = match &node.inputs[2].ty {
ArgType::Tensor(tensor) => {
assert_eq!(tensor.dim, 1, "Slice: ends tensor must be 1D");
if let Some(Data::Int64s(shape)) = end_value.as_ref() {
shape
.iter()
.map(|x| {
assert!(*x >= 0, "Slice: end must be positive");
*x as usize
})
.collect()
} else {
panic!("Tensor data type must be int64")
}
}
_ => panic!("Only tensor input is valid for shape"),
};
ensure_1d_tensor(node, 1);
ensure_1d_tensor(node, 2);

let starts = get_input_values(node, 1);
let ends = get_input_values(node, 2);

for (key, value) in node.attrs.iter() {
match key.as_str() {
Expand Down
Loading