
Onnx op topk #2305

Open · wants to merge 14 commits into main
2 changes: 1 addition & 1 deletion crates/burn-import/SUPPORTED-ONNX-OPS.md
@@ -192,7 +192,7 @@ represent the corresponding Burn Op.
| [TfIdfVectorizer][183] | ❌ | ❌ |
| [ThresholdedRelu][184] | ❌ | ❌ |
| [Tile][185] | ✅ | ✅ |
| [TopK][186] | ❌ | ✅ |
| [TopK][186] | ✅ | ✅ |
| [Transpose][187] | ✅ | ✅ |
| [Trilu][188] | ❌ | ✅ |
| [Unique][189] | ❌ | ❌ |
1 change: 1 addition & 0 deletions crates/burn-import/onnx-tests/build.rs
@@ -107,6 +107,7 @@ fn main() {
.input("tests/sum/sum_int.onnx")
.input("tests/tanh/tanh.onnx")
.input("tests/tile/tile.onnx")
.input("tests/top_k/top_k.onnx")
Member comment:

Looks like the CI caught something I missed! This file doesn't exist anymore with your changes :)

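The generator script below writes files named top_k_opset_{version}.onnx, so this input presumably needs to point at the opset-1 artifact instead. A sketch of the corrected build.rs line (hypothetical path, chosen to match the include_models! entry in test_onnx.rs):

    .input("tests/top_k/top_k_opset_1.onnx")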
.input("tests/transpose/transpose.onnx")
.input("tests/unsqueeze/unsqueeze.onnx")
.input("tests/unsqueeze/unsqueeze_opset11.onnx")
29 changes: 28 additions & 1 deletion crates/burn-import/onnx-tests/tests/test_onnx.rs
@@ -116,6 +116,7 @@ include_models!(
sum_int,
tanh,
tile,
top_k_opset_1,
transpose,
unsqueeze,
unsqueeze_opset11,
@@ -128,7 +129,7 @@ mod tests {

use super::*;

use burn::tensor::{Bool, Int, Shape, Tensor, TensorData};
use burn::tensor::{cast::ToElement, Bool, Int, Shape, Tensor, TensorData};

use float_cmp::ApproxEq;

@@ -2125,4 +2126,30 @@ mod tests {
assert!(i_output.equal(i_expected).all().into_scalar());
assert!(b_output.equal(b_expected).all().into_scalar());
}

#[test]
fn top_k_opset_1() {
// Initialize the model
let device = Default::default();
let model = top_k_opset_1::Model::<Backend>::new(&device);

// Run the model
let input = Tensor::<Backend, 2>::from_floats(
[[1.0, 2.0, 3.0, 4.0], [1.0, 2.0, 3.0, 4.0]],
&device,
);
let (values_tensor, indices_tensor) = model.forward(input);

// expected results
let expected_values_tensor =
TensorData::from([[4.0, 3.0, 2.to_f32()], [4.0, 3.0, 2.to_f32()]]);
let expected_indices_tensor = TensorData::from([[3, 2, 1], [3, 2, 1]]);

values_tensor
.to_data()
.assert_eq(&expected_values_tensor, true);
indices_tensor
.to_data()
.assert_eq(&expected_indices_tensor, true);
}
}
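As an aside: the new cast::ToElement import above appears to be needed only for the 2.to_f32() calls in the expected data. A plain f32 literal would presumably do the same and let the import be dropped. A minimal sketch:

    let expected_values_tensor =
        TensorData::from([[4.0f32, 3.0, 2.0], [4.0, 3.0, 2.0]]);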
Binary file added crates/burn-import/onnx-tests/tests/top_k/top_k.onnx
75 changes: 75 additions & 0 deletions crates/burn-import/onnx-tests/tests/top_k/top_k.py
@@ -0,0 +1,75 @@
import numpy as np
import onnx
from onnx import helper, TensorProto

# Define the input tensor
X = np.array([[0, 1, 2, 3],
[4, 5, 6, 7],
[8, 9, 10, 11]], dtype=np.float32)

# Define the value of K
k = 3
K = np.array([k], dtype=np.int64)
axis = 1
new_dims = [X.shape[0], k]

def create_model(op_set_version: int):
input_tensors = [helper.make_tensor_value_info('X', TensorProto.FLOAT, X.shape)]

output_tensors = [
helper.make_tensor_value_info('Values', TensorProto.FLOAT, new_dims),
helper.make_tensor_value_info('Indices', TensorProto.INT64, new_dims)  # TopK indices are int64 per the ONNX spec
]

# Create the TopK node
if op_set_version > 1:
node = helper.make_node(
'TopK',
inputs=['X', 'K'],
outputs=['Values', 'Indices'],
axis=axis, # Axis along which to find the top K elements
)
input_tensors.append(helper.make_tensor_value_info('K', TensorProto.INT64, K.shape))  # K is int64, matching the np.int64 array above
else:
node = helper.make_node(
'TopK',
inputs=['X'],
outputs=['Values', 'Indices'],
axis=axis, # Axis along which to find the top K elements
k=k
)

# Create the graph
graph = helper.make_graph(
nodes = [node],
name = 'TopKGraph',
inputs = input_tensors,
outputs = output_tensors,
# Uncomment when initializers are supported. Currently we can't test opsets 10/11,
# since the code would require a K value to be initialized for testing.
#initializer = [
# helper.make_tensor('X', TensorProto.FLOAT, X.shape, X),
# helper.make_tensor('K', TensorProto.INT64, [1], [k]),
#]
)

# Create the model
model = helper.make_model(
graph,
ir_version=8,
opset_imports=[onnx.helper.make_operatorsetid("", op_set_version)]
)
# Check the model
onnx.checker.check_model(model)

# Save the model to a file
onnx.save(model, f'top_k_opset_{op_set_version}.onnx')
print(f"Model saved to top_k_opset_{op_set_version}.onnx")

def main():
# Uncomment when initializers are supported.
for op_set_version in [1, 10, 11]:
create_model(op_set_version)


if __name__ == "__main__":
main()
6 changes: 5 additions & 1 deletion crates/burn-import/src/burn/node/base.rs
@@ -11,7 +11,8 @@ use super::{
max_pool1d::MaxPool1dNode, max_pool2d::MaxPool2dNode, mean::MeanNode, pad::PadNode,
prelu::PReluNode, random_normal::RandomNormalNode, random_uniform::RandomUniformNode,
range::RangeNode, reshape::ReshapeNode, resize::ResizeNode, slice::SliceNode,
squeeze::SqueezeNode, sum::SumNode, tile::TileNode, unary::UnaryNode, unsqueeze::UnsqueezeNode,
squeeze::SqueezeNode, sum::SumNode, tile::TileNode, top_k::TopKNode, unary::UnaryNode,
unsqueeze::UnsqueezeNode,
};
use crate::burn::{BurnImports, Scope, Type};
use burn::backend::NdArray;
@@ -114,6 +115,7 @@ pub enum Node<PS: PrecisionSettings> {
Squeeze(SqueezeNode),
Sum(SumNode),
Tile(TileNode),
TopK(TopKNode),
Unary(UnaryNode),
Unsqueeze(UnsqueezeNode),
Where(WhereNode),
@@ -162,6 +164,7 @@ macro_rules! match_all {
Node::Squeeze(node) => $func(node),
Node::Sum(node) => $func(node),
Node::Tile(node) => $func(node),
Node::TopK(node) => $func(node),
Node::Unary(node) => $func(node),
Node::Unsqueeze(node) => $func(node),
Node::Where(node) => $func(node),
@@ -218,6 +221,7 @@ impl<PS: PrecisionSettings> Node<PS> {
Node::Squeeze(_) => "squeeze",
Node::Sum(_) => "add",
Node::Tile(_) => "tile",
Node::TopK(_) => "top_k",
Node::Unary(unary) => unary.kind.as_str(),
Node::Unsqueeze(_) => "unsqueeze",
Node::Where(_) => "where",
1 change: 1 addition & 0 deletions crates/burn-import/src/burn/node/mod.rs
@@ -37,6 +37,7 @@ pub(crate) mod slice;
pub(crate) mod squeeze;
pub(crate) mod sum;
pub(crate) mod tile;
pub(crate) mod top_k;
pub(crate) mod unary;
pub(crate) mod unsqueeze;
pub(crate) use base::*;
114 changes: 114 additions & 0 deletions crates/burn-import/src/burn/node/top_k.rs
@@ -0,0 +1,114 @@
use super::{Node, NodeCodegen};
use crate::burn::{Scope, TensorType, Type};
use burn::config::Config;
use burn::record::PrecisionSettings;
use proc_macro2::TokenStream;
use quote::{quote, ToTokens};

#[derive(Config, Debug)]
pub struct TopKConfig {
pub axis: usize,
pub k: usize,
pub largest: usize,
}

#[derive(Debug, Clone, new)]
pub struct TopKNode {
pub input: TensorType,
pub outputs: Vec<TensorType>,
pub config: TopKConfig,
}

impl<PS: PrecisionSettings> NodeCodegen<PS> for TopKNode {
fn output_types(&self) -> Vec<Type> {
self.outputs
.iter()
.map(|t| Type::Tensor(t.clone()))
.collect()
}

fn input_types(&self) -> Vec<Type> {
vec![Type::Tensor(self.input.clone())]
}

fn forward(&self, scope: &mut Scope, node_position: usize) -> TokenStream {
let axis = self.config.axis.to_token_stream();
let k = self.config.k.to_token_stream();
let largest = self.config.largest.to_token_stream();

let input = scope.tensor_use_owned(&self.input, node_position);
let values_output = &self.outputs[0].name;
let indices_output = &self.outputs[1].name;

quote! {
let (#values_output, #indices_output) = #input.topk_with_indices(#k, #axis, #largest);
}
}

fn into_node(self) -> Node<PS> {
Node::TopK(self)
}
}

#[cfg(test)]
mod tests {
use burn::record::FullPrecisionSettings;

use super::*;
use crate::burn::{
graph::BurnGraph,
node::{test::assert_tokens, top_k::TopKNode},
TensorType,
};

#[test]
fn test_codegen_nodes() {
let mut graph = BurnGraph::<FullPrecisionSettings>::default();
let config = TopKConfig::new(1, 3, 1);

graph.register(TopKNode::new(
TensorType::new_float("input_tensor", 4),
vec![
TensorType::new_float("values_tensor", 4),
TensorType::new_int("indices_tensor", 4),
],
config,
));

graph.register_input_output(
vec!["input_tensor".to_string()],
vec!["values_tensor".to_string(), "indices_tensor".to_string()],
);

let expected = quote! {
use burn::tensor::Int;
use burn::{
module::Module,
tensor::{backend::Backend, Tensor},
};

#[derive(Module, Debug)]
pub struct Model<B: Backend> {
phantom: core::marker::PhantomData<B>,
device: burn::module::Ignored<B::Device>,
}

impl<B: Backend> Model <B> {
#[allow(unused_variables)]
pub fn new(device: &B::Device) -> Self {
Self {
phantom: core::marker::PhantomData,
device: burn::module::Ignored(device.clone()),
}
}
#[allow(clippy::let_and_return, clippy::approx_constant)]
pub fn forward(&self, input_tensor: Tensor<B, 4>) -> (Tensor<B, 4>, Tensor<B, 4, Int>) {
let (values_tensor, indices_tensor) = input_tensor.topk_with_indices(3usize, 1usize, 1usize);
(values_tensor, indices_tensor)
}
}
};

assert_tokens(graph.codegen(), expected);
}
}
41 changes: 40 additions & 1 deletion crates/burn-import/src/onnx/op_configuration.rs
@@ -7,9 +7,16 @@ use burn::nn::{
PaddingConfig2d, PaddingConfig3d,
};

use crate::burn::node::{expand::ExpandShape, pad::PadConfig, tile::TileConfig};
use crate::burn::node::{expand::ExpandShape, pad::PadConfig, tile::TileConfig, top_k::TopKConfig};
use onnx_ir::ir::{ArgType, AttributeValue, Data, ElementType, Node};

/// Extract and convert a given attribute to i64
fn extract_attr_value_i64(node: &Node, key: &str) -> i64 {
let error_msg = format!("Expected the following attribute key: {:?}", key);
node.attrs.get(key).expect(&error_msg).clone().into_i64()
}

/// Create a Conv1dConfig from the attributes of the node
pub fn conv1d_config(curr: &Node) -> Conv1dConfig {
let mut kernel_shape = Vec::new(); // TODO default inferred from weight tensor per spec
@@ -795,6 +802,38 @@ pub fn tile_config(node: &Node) -> TileConfig {
TileConfig::new(repeat)
}

/// Create a TopKConfig from the attributes of the node.
pub fn top_k_config(node: &Node) -> TopKConfig {
// Extract the input data tensor type; its rank is needed to normalize a negative axis
let data_tensor = match node.inputs.first().unwrap().clone().ty {
ArgType::Tensor(tensor) => tensor,
_ => panic!("Only tensor input is valid"),
};

let k = match node.inputs.get(1) {
Some(k_tensor) => k_tensor
.clone()
.value
.expect("Expecting K tensor to have a value.")
.into_i64s()[0],
_ => extract_attr_value_i64(node, "k"),
};

let mut axis: i64 = extract_attr_value_i64(node, "axis");

// if axis is negative, it is counted from the end
if axis < 0 {
axis += data_tensor.dim as i64;
}

let largest = match node.attrs.get("largest") {
Some(val) => val.clone().into_i64(),
_ => 1,
};
Comment on lines +829 to +832

Member comment:

If you're not checking for "k" as the second input of the node (for opsets 10, 11) and just adding support for opset 1, then we don't need to check for the "largest" attribute here. It's only present in the later version 11 of the op.

So we can remove this from the config and node.

Author reply:

Sure thing, will remove and resubmit tonight.


TopKConfig::new(axis as usize, k as usize, largest as usize)
}
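A minimal sketch of the simplification suggested in the thread above, dropping largest from both the config and the generated call. This assumes burn exposes a two-argument topk_with_indices(k, dim) form; treat the exact signature as an assumption:

    #[derive(Config, Debug)]
    pub struct TopKConfig {
        pub axis: usize,
        pub k: usize,
    }

    // In TopKNode::forward, the generated call would then lose its third argument:
    quote! {
        let (#values_output, #indices_output) = #input.topk_with_indices(#k, #axis);
    }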

/// Create a PadConfig from the attributes of the node
pub fn pad_config(node: &Node) -> PadConfig {
fn get_pads_input(node: &Node) -> Vec<i64> {