From da34ba5e1a54301ad4d3271de4db7bd71a3e0d31 Mon Sep 17 00:00:00 2001 From: Femi Oho <47154698+oojo12@users.noreply.github.com> Date: Tue, 24 Sep 2024 23:22:34 -0400 Subject: [PATCH 01/14] feat: add topk operation --- crates/burn-import/onnx-tests/build.rs | 1 + .../burn-import/onnx-tests/tests/test_onnx.rs | 15 +++ .../onnx-tests/tests/top_k/top_k.onnx | Bin 0 -> 156 bytes .../onnx-tests/tests/top_k/top_k.py | 57 +++++++++ crates/burn-import/src/burn/graph.rs | 1 + crates/burn-import/src/burn/node/base.rs | 6 +- crates/burn-import/src/burn/node/mod.rs | 1 + crates/burn-import/src/burn/node/top_k.rs | 111 ++++++++++++++++++ .../burn-import/src/onnx/op_configuration.rs | 18 ++- crates/burn-import/src/onnx/to_burn.rs | 17 ++- crates/onnx-ir/src/dim_inference.rs | 30 +++++ crates/onnx-ir/src/from_onnx.rs | 19 ++- 12 files changed, 267 insertions(+), 9 deletions(-) create mode 100644 crates/burn-import/onnx-tests/tests/top_k/top_k.onnx create mode 100644 crates/burn-import/onnx-tests/tests/top_k/top_k.py create mode 100644 crates/burn-import/src/burn/node/top_k.rs diff --git a/crates/burn-import/onnx-tests/build.rs b/crates/burn-import/onnx-tests/build.rs index a7360e012d..ff4a034cbe 100644 --- a/crates/burn-import/onnx-tests/build.rs +++ b/crates/burn-import/onnx-tests/build.rs @@ -107,6 +107,7 @@ fn main() { .input("tests/sum/sum_int.onnx") .input("tests/tanh/tanh.onnx") .input("tests/tile/tile.onnx") + .input("tests/top_k/top_k.onnx") .input("tests/transpose/transpose.onnx") .input("tests/unsqueeze/unsqueeze.onnx") .input("tests/unsqueeze/unsqueeze_opset11.onnx") diff --git a/crates/burn-import/onnx-tests/tests/test_onnx.rs b/crates/burn-import/onnx-tests/tests/test_onnx.rs index a4cd32b485..ad5a011915 100644 --- a/crates/burn-import/onnx-tests/tests/test_onnx.rs +++ b/crates/burn-import/onnx-tests/tests/test_onnx.rs @@ -116,6 +116,7 @@ include_models!( sum_int, tanh, tile, + top_k, transpose, unsqueeze, unsqueeze_opset11, @@ -2125,4 +2126,18 @@ mod tests { assert!(i_output.equal(i_expected).all().into_scalar()); assert!(b_output.equal(b_expected).all().into_scalar()); } + + #[test] + fn top_k() { + // Initialize the model + let device = Default::default(); + let model = top_k::Model::::new(&device); + + // Run the model + let input = Tensor::::from_floats([[1., 2., 3., 4.]], &device); + let output = model.forward(input); + // data from pyTorch + let expected = TensorData::from([[1., 2., 3., 4.]]); + assert!(&expected, output); + } } diff --git a/crates/burn-import/onnx-tests/tests/top_k/top_k.onnx b/crates/burn-import/onnx-tests/tests/top_k/top_k.onnx new file mode 100644 index 0000000000000000000000000000000000000000..814fbeaefd8d9e275b09d07dfbeb30643240d750 GIT binary patch literal 156 zcmd BurnGraph { } } + println!("The node input types are {:#?}", &inputs); // Get the input and output types of the graph using passed in names input_names.iter().for_each(|input| { self.graph_input_types.push( diff --git a/crates/burn-import/src/burn/node/base.rs b/crates/burn-import/src/burn/node/base.rs index a1c9103b41..6c955033cd 100644 --- a/crates/burn-import/src/burn/node/base.rs +++ b/crates/burn-import/src/burn/node/base.rs @@ -11,7 +11,8 @@ use super::{ max_pool1d::MaxPool1dNode, max_pool2d::MaxPool2dNode, mean::MeanNode, pad::PadNode, prelu::PReluNode, random_normal::RandomNormalNode, random_uniform::RandomUniformNode, range::RangeNode, reshape::ReshapeNode, resize::ResizeNode, slice::SliceNode, - squeeze::SqueezeNode, sum::SumNode, tile::TileNode, unary::UnaryNode, unsqueeze::UnsqueezeNode, + squeeze::SqueezeNode, sum::SumNode, tile::TileNode, top_k::TopKNode, unary::UnaryNode, + unsqueeze::UnsqueezeNode, }; use crate::burn::{BurnImports, Scope, Type}; use burn::backend::NdArray; @@ -114,6 +115,7 @@ pub enum Node { Squeeze(SqueezeNode), Sum(SumNode), Tile(TileNode), + TopK(TopKNode), Unary(UnaryNode), Unsqueeze(UnsqueezeNode), Where(WhereNode), @@ -162,6 +164,7 @@ macro_rules! match_all { Node::Squeeze(node) => $func(node), Node::Sum(node) => $func(node), Node::Tile(node) => $func(node), + Node::TopK(node) => $func(node), Node::Unary(node) => $func(node), Node::Unsqueeze(node) => $func(node), Node::Where(node) => $func(node), @@ -218,6 +221,7 @@ impl Node { Node::Squeeze(_) => "squeeze", Node::Sum(_) => "add", Node::Tile(_) => "tile", + Node::TopK(_) => "top_k", Node::Unary(unary) => unary.kind.as_str(), Node::Unsqueeze(_) => "unsqueeze", Node::Where(_) => "where", diff --git a/crates/burn-import/src/burn/node/mod.rs b/crates/burn-import/src/burn/node/mod.rs index ee294ddfd7..605cd9d8f2 100644 --- a/crates/burn-import/src/burn/node/mod.rs +++ b/crates/burn-import/src/burn/node/mod.rs @@ -37,6 +37,7 @@ pub(crate) mod slice; pub(crate) mod squeeze; pub(crate) mod sum; pub(crate) mod tile; +pub(crate) mod top_k; pub(crate) mod unary; pub(crate) mod unsqueeze; pub(crate) use base::*; diff --git a/crates/burn-import/src/burn/node/top_k.rs b/crates/burn-import/src/burn/node/top_k.rs new file mode 100644 index 0000000000..ebf2a22290 --- /dev/null +++ b/crates/burn-import/src/burn/node/top_k.rs @@ -0,0 +1,111 @@ +use super::{Node, NodeCodegen}; +use crate::burn::{Scope, TensorType, Type}; +use burn::config::Config; +use burn::record::PrecisionSettings; +use proc_macro2::TokenStream; +use quote::{quote, ToTokens}; + +// We omit the sorted option because the burn topk impls already come with the topk sorted +#[derive(Config, Debug)] +pub struct TopKConfig { + pub axis: i64, + pub k: i64, +} + + +#[derive(Debug, Clone, new)] +pub struct TopKNode { + pub input: TensorType, + pub outputs: Vec, + pub config: TopKConfig, +} + +impl NodeCodegen for TopKNode { + fn output_types(&self) -> Vec { + self.outputs + .iter() + .map(|t| Type::Tensor(t.clone())) + .collect() + } + + fn input_types(&self) -> Vec { + vec![Type::Tensor(self.input.clone())] + } + + fn forward(&self, scope: &mut Scope, node_position: usize) -> TokenStream { + let axis = self.config.axis.to_token_stream(); + let k = self.config.k.to_token_stream(); + + let input = scope.tensor_use_owned(&self.input, node_position); + let values_output = &self.outputs[0].name; + let indices_output = &self.outputs[1].name; + + quote! { + let (#values_output, #indices_output) = #input.topk_with_indices(#k as usize, #axis as usize); + } + } + + fn into_node(self) -> Node { + Node::TopK(self) + } +} + +#[cfg(test)] +mod tests { + use burn::record::FullPrecisionSettings; + + use super::*; + use crate::burn::{ + graph::BurnGraph, + node::{top_k::TopKNode, test::assert_tokens}, + TensorType, + }; + + #[test] + fn test_codegen_nodes() { + let mut graph = BurnGraph::::default(); + let config = TopKConfig::new(-1, 3); + + graph.register(TopKNode::new( + TensorType::new_float("input_tensor", 4), + vec![ + TensorType::new_float("values_tensor", 4), + TensorType::new_int("indices_tensor", 4), + ], + config, + )); + + graph.register_input_output(vec!["input_tensor".to_string()], vec!["values_tensor".to_string(), "indices_tensor".to_string()]); + + let expected = quote! { + use burn::tensor::Int; + use burn::{ + module::Module, + tensor::{backend::Backend, Tensor}, + }; + + #[derive(Module, Debug)] + pub struct Model { + phantom: core::marker::PhantomData, + device: burn::module::Ignored, + } + + impl Model { + #[allow(unused_variables)] + pub fn new(device: &B::Device) -> Self { + Self { + phantom: core::marker::PhantomData, + device: burn::module::Ignored(device.clone()), + } + } + #[allow(clippy::let_and_return, clippy::approx_constant)] + pub fn forward(&self, input_tensor: Tensor) -> (Tensor, Tensor) { + let (values_tensor, indices_tensor) = input_tensor.topk_with_indices(3i64 as usize, -1i64 as usize); + (values_tensor, indices_tensor) + } + } + }; + + assert_tokens(graph.codegen(), expected); + } +} \ No newline at end of file diff --git a/crates/burn-import/src/onnx/op_configuration.rs b/crates/burn-import/src/onnx/op_configuration.rs index 4621b129d6..b8daad5660 100644 --- a/crates/burn-import/src/onnx/op_configuration.rs +++ b/crates/burn-import/src/onnx/op_configuration.rs @@ -7,7 +7,7 @@ use burn::nn::{ PaddingConfig2d, PaddingConfig3d, }; -use crate::burn::node::{expand::ExpandShape, pad::PadConfig, tile::TileConfig}; +use crate::burn::node::{expand::ExpandShape, pad::PadConfig, tile::TileConfig, top_k::TopKConfig}; use onnx_ir::ir::{ArgType, AttributeValue, Data, ElementType, Node}; /// Create a Conv1dConfig from the attributes of the node @@ -795,6 +795,22 @@ pub fn tile_config(node: &Node) -> TileConfig { TileConfig::new(repeat) } +fn extract_attr_value_i64(node: &Node, key: &str) -> i64 { + let value = node.attrs + .get(key) + .unwrap() + .clone() + .into_i64(); + value +} +/// Create a TopKConfig from the attributes of the node. We don't extract sorted from the TopK node as our topk impl already returns in sorted order. +pub fn top_k_config(node: &Node) -> TopKConfig { + let axis:i64 = extract_attr_value_i64(node, "axis"); + let k: i64 = extract_attr_value_i64(node, "k"); + + TopKConfig::new(axis, k) +} + /// Create a PadConfig from the attributes of the node pub fn pad_config(node: &Node) -> PadConfig { fn get_pads_input(node: &Node) -> Vec { diff --git a/crates/burn-import/src/onnx/to_burn.rs b/crates/burn-import/src/onnx/to_burn.rs index 5c4f34078f..cf7ab7637e 100644 --- a/crates/burn-import/src/onnx/to_burn.rs +++ b/crates/burn-import/src/onnx/to_burn.rs @@ -51,6 +51,7 @@ use crate::{ squeeze::SqueezeNode, sum::SumNode, tile::TileNode, + top_k::TopKNode, unary::UnaryNode, unsqueeze::UnsqueezeNode, }, @@ -67,8 +68,8 @@ use super::op_configuration::{ hard_sigmoid_config, layer_norm_config, leaky_relu_config, linear_config, log_softmax_config, max_pool1d_config, max_pool2d_config, pad_config, reduce_max_config, reduce_mean_config, reduce_min_config, reduce_prod_config, reduce_sum_config, reshape_config, resize_config, - shape_config, slice_config, softmax_config, squeeze_config, tile_config, transpose_config, - unsqueeze_config, + shape_config, slice_config, softmax_config, squeeze_config, tile_config, top_k_config, + transpose_config, unsqueeze_config, }; use onnx_ir::{ convert_constant_value, @@ -338,6 +339,7 @@ impl ParsedOnnxGraph { NodeType::Squeeze => graph.register(Self::squeeze_conversion(node)), NodeType::RandomUniform => graph.register(Self::random_uniform_conversion(node)), NodeType::Tile => graph.register(Self::tile_conversion(node)), + NodeType::TopK => graph.register(Self::top_k_conversion(node)), NodeType::RandomNormal => graph.register(Self::random_normal_conversion(node)), NodeType::ConstantOfShape => { graph.register(Self::constant_of_shape_conversion(node)) @@ -1184,6 +1186,17 @@ impl ParsedOnnxGraph { TileNode::new(input, output, config) } + + fn top_k_conversion(node: Node) -> TopKNode { + // Inputs + let input = TensorType::from(node.inputs.first().unwrap()); + + // Outputs + let outputs = node.outputs.iter().map(TensorType::from).collect(); + let config = top_k_config(&node); + + TopKNode::new(input, outputs, config) + } } /// Extract data from node states and convert it to `TensorData`. diff --git a/crates/onnx-ir/src/dim_inference.rs b/crates/onnx-ir/src/dim_inference.rs index ff580b37aa..737f36d964 100644 --- a/crates/onnx-ir/src/dim_inference.rs +++ b/crates/onnx-ir/src/dim_inference.rs @@ -82,6 +82,7 @@ pub fn dim_inference(node: &mut Node) { NodeType::Sub => same_as_input_broadcast(node), NodeType::Sum => same_as_input_broadcast(node), NodeType::Tanh => same_as_input(node), + NodeType::TopK => top_k_update_output(node), NodeType::Transpose => same_as_input(node), NodeType::Unsqueeze => unsqueeze_update_output(node), NodeType::Where => where_update_outputs(node), @@ -477,6 +478,35 @@ fn same_as_input(node: &mut Node) { node.outputs[0].ty = node.inputs[0].ty.clone(); } +fn top_k_update_output(node: &mut Node) { + let dim = match &node.inputs[0].ty { + ArgType::Tensor(tensor) => tensor.dim, + _ => panic!("TopK: invalid input type"), + }; + + let output_values_elem = match &node.outputs[0].ty { + ArgType::Tensor(tensor) => tensor.elem_type.clone(), + _ => panic!("TopK: invalid output type"), + }; + + let output_indices_elem = match &node.outputs[1].ty { + ArgType::Tensor(_) => ElementType::Int64, + _ => panic!("TopK: invalid output type"), + }; + + node.outputs[0].ty = ArgType::Tensor(TensorType { + dim: dim, + shape: None, // shape is tracked and calculated at runtime + elem_type: output_values_elem, + }); + + node.outputs[1].ty = ArgType::Tensor(TensorType { + dim: dim, + shape: None, // shape is tracked and calculated at runtime + elem_type: output_indices_elem, + }); +} + /// Temporary pass-through stub for dimension inference so that we can export the IR model. fn temporary_pass_through_stub(node: &mut Node) { log::warn!("Must implement dimension inference for {:?}", node); diff --git a/crates/onnx-ir/src/from_onnx.rs b/crates/onnx-ir/src/from_onnx.rs index fa30bcf83c..d77eb11e01 100644 --- a/crates/onnx-ir/src/from_onnx.rs +++ b/crates/onnx-ir/src/from_onnx.rs @@ -236,14 +236,22 @@ impl OnnxGraphBuilder { keep }); - // TODO Update graph inputs and outputs to match the processed nodes inputs and outputs - // This is necessary for the graph to be valid - // ConstantOfShape updates input to be Shape argument and output Tensor dim is updated + let mut processed_nodes_inputs : Vec = Vec::new(); + let mut processed_nodes_outputs : Vec = Vec::new(); + for processed_node in processed_nodes.iter() { + processed_nodes_inputs.extend(processed_node.inputs.clone()); + processed_nodes_outputs.extend(processed_node.outputs.clone()); + } + + println!("The from onnx processed in/outputs are: {:#?}", (&processed_nodes_inputs, &processed_nodes_outputs)); + println!("The from onnx regular in/outputs are: {:#?}", (&inputs, &outputs)); + + // TODO: ConstantOfShape updates input to be Shape argument and output Tensor dim is updated OnnxGraph { nodes: processed_nodes, - inputs, - outputs, + inputs: processed_nodes_inputs, + outputs: processed_nodes_outputs } } @@ -370,6 +378,7 @@ pub fn parse_onnx(onnx_path: &Path) -> OnnxGraph { let builder = OnnxGraphBuilder::default(); let graph = builder.build(&onnx_model); + log::info!("Onnx graph: {:#?}", graph); log::info!("Finished parsing ONNX file: {}", onnx_path.display()); graph From 297394510efd1c234d682887068c19590fae3bb4 Mon Sep 17 00:00:00 2001 From: Femi Oho <47154698+oojo12@users.noreply.github.com> Date: Wed, 25 Sep 2024 15:00:07 -0400 Subject: [PATCH 02/14] fix test --- .../burn-import/onnx-tests/tests/test_onnx.rs | 10 +++++----- .../onnx-tests/tests/top_k/top_k.onnx | Bin 156 -> 147 bytes .../burn-import/onnx-tests/tests/top_k/top_k.py | 2 +- crates/burn-import/src/burn/graph.rs | 1 - 4 files changed, 6 insertions(+), 7 deletions(-) diff --git a/crates/burn-import/onnx-tests/tests/test_onnx.rs b/crates/burn-import/onnx-tests/tests/test_onnx.rs index ad5a011915..e09f03b6e5 100644 --- a/crates/burn-import/onnx-tests/tests/test_onnx.rs +++ b/crates/burn-import/onnx-tests/tests/test_onnx.rs @@ -129,7 +129,7 @@ mod tests { use super::*; - use burn::tensor::{Bool, Int, Shape, Tensor, TensorData}; + use burn::tensor::{cast::ToElement, Bool, Int, Shape, Tensor, TensorData}; use float_cmp::ApproxEq; @@ -2134,10 +2134,10 @@ mod tests { let model = top_k::Model::::new(&device); // Run the model - let input = Tensor::::from_floats([[1., 2., 3., 4.]], &device); - let output = model.forward(input); + let input = Tensor::::from_floats([[1.0, 2.0, 3.0, 4.0], [1.0, 2.0, 3.0, 4.0]], &device); + let (values_tensor, _indices_tensor) = model.forward(input); // data from pyTorch - let expected = TensorData::from([[1., 2., 3., 4.]]); - assert!(&expected, output); + let expected = TensorData::from([[4.0, 3.0, 2.to_f32()], [4.0, 3.0, 2.to_f32()]]); + values_tensor.to_data().assert_eq(&expected, true); } } diff --git a/crates/burn-import/onnx-tests/tests/top_k/top_k.onnx b/crates/burn-import/onnx-tests/tests/top_k/top_k.onnx index 814fbeaefd8d9e275b09d07dfbeb30643240d750..4eb08a05c9df6d563d67962e00a26ab83cfec0f2 100644 GIT binary patch delta 50 zcmbQkIGIt4gTtzWk;{;aF+zwfEHS4vwOEMVGcP4GIki}cB_zMVTZ@~EC9xv2SYo1a FCIB&34GsVR delta 59 zcmbQtIET@ggTrbfBbOByV}uY}SYl3TYOxTzXI@HXa%!;>OGti!x0VPOOJYT4vBZBE In5dNr0O|1-AOHXW diff --git a/crates/burn-import/onnx-tests/tests/top_k/top_k.py b/crates/burn-import/onnx-tests/tests/top_k/top_k.py index 0e807121f6..ca58d8d213 100644 --- a/crates/burn-import/onnx-tests/tests/top_k/top_k.py +++ b/crates/burn-import/onnx-tests/tests/top_k/top_k.py @@ -10,7 +10,7 @@ # Define the value of K k = 3 K = np.array([k], dtype=np.int64) -axis = -1 +axis = 1 new_dims = [X.shape[0], k] input_tensors = [ diff --git a/crates/burn-import/src/burn/graph.rs b/crates/burn-import/src/burn/graph.rs index 35ef372ea7..3ad633fafa 100644 --- a/crates/burn-import/src/burn/graph.rs +++ b/crates/burn-import/src/burn/graph.rs @@ -549,7 +549,6 @@ impl BurnGraph { } } - println!("The node input types are {:#?}", &inputs); // Get the input and output types of the graph using passed in names input_names.iter().for_each(|input| { self.graph_input_types.push( From 578c3467854a1fbab3eddbaff53a3bf199f199f1 Mon Sep 17 00:00:00 2001 From: Femi Oho <47154698+oojo12@users.noreply.github.com> Date: Wed, 25 Sep 2024 15:01:17 -0400 Subject: [PATCH 03/14] fix: update IOEntry::Node j in the input_name map --- crates/onnx-ir/src/from_onnx.rs | 24 +++++++++--------------- 1 file changed, 9 insertions(+), 15 deletions(-) diff --git a/crates/onnx-ir/src/from_onnx.rs b/crates/onnx-ir/src/from_onnx.rs index d77eb11e01..faf666c7a1 100644 --- a/crates/onnx-ir/src/from_onnx.rs +++ b/crates/onnx-ir/src/from_onnx.rs @@ -151,7 +151,7 @@ impl GraphData { for output in node.outputs.iter_mut() { self.input_name_map.insert( output.name.clone(), - IOEntry::Node(self.processed_nodes.len(), 0), + IOEntry::Node(self.processed_nodes.len(), out_count - 1), ); output.name = format!("{}_out{}", node.name, out_count); out_count += 1; @@ -202,11 +202,14 @@ impl OnnxGraphBuilder { pub(crate) fn build(mut self, model_proto: &ModelProto) -> OnnxGraph { self.constants_types = LIFT_CONSTANTS_FOR_NODE_TYPES.into_iter().collect(); + println!("here are the proto graph inputs: {:#?}", &model_proto.graph.input); + println!("here are the proto graph outputs: {:#?}", &model_proto.graph.output); let mut graph_data = GraphData::new( &model_proto.graph.input, &model_proto.graph.output, &model_proto.graph.initializer, ); + println!("here are the proto graph outputs created: {:#?}", &graph_data.outputs); let mut node_iter = model_proto.graph.node.iter().peekable(); @@ -236,22 +239,13 @@ impl OnnxGraphBuilder { keep }); - let mut processed_nodes_inputs : Vec = Vec::new(); - let mut processed_nodes_outputs : Vec = Vec::new(); - for processed_node in processed_nodes.iter() { - processed_nodes_inputs.extend(processed_node.inputs.clone()); - processed_nodes_outputs.extend(processed_node.outputs.clone()); - } - - println!("The from onnx processed in/outputs are: {:#?}", (&processed_nodes_inputs, &processed_nodes_outputs)); - println!("The from onnx regular in/outputs are: {:#?}", (&inputs, &outputs)); - - // TODO: ConstantOfShape updates input to be Shape argument and output Tensor dim is updated - + // TODO Update graph inputs and outputs to match the processed nodes inputs and outputs + // This is necessary for the graph to be valid + // ConstantOfShape updates input to be Shape argument and output Tensor dim is updated OnnxGraph { nodes: processed_nodes, - inputs: processed_nodes_inputs, - outputs: processed_nodes_outputs + inputs, + outputs } } From 17dce603a7d61771229d55206212a99a39229448 Mon Sep 17 00:00:00 2001 From: Femi Oho <47154698+oojo12@users.noreply.github.com> Date: Wed, 25 Sep 2024 16:32:29 -0400 Subject: [PATCH 04/14] fix: only run on macos --- crates/burn-jit/src/tests/unary.rs | 1 + 1 file changed, 1 insertion(+) diff --git a/crates/burn-jit/src/tests/unary.rs b/crates/burn-jit/src/tests/unary.rs index 692133ae68..47e3933c95 100644 --- a/crates/burn-jit/src/tests/unary.rs +++ b/crates/burn-jit/src/tests/unary.rs @@ -4,6 +4,7 @@ mod tests { use burn_tensor::{Distribution, Tensor}; #[test] + #[cfg(target_os = "macos")] fn tanh_should_not_have_numerical_bugs_on_macos() { fn tanh_one_value(input: f32) -> f32 { let tensor = Tensor::::ones([1], &Default::default()) * input; From ce1aa89ab0f8f85310fce8760febabaed1b1c3f0 Mon Sep 17 00:00:00 2001 From: Femi Oho <47154698+oojo12@users.noreply.github.com> Date: Wed, 25 Sep 2024 16:32:48 -0400 Subject: [PATCH 05/14] chore: clean up --- crates/burn-import/onnx-tests/tests/test_onnx.rs | 5 ++++- crates/burn-import/src/burn/node/base.rs | 2 +- crates/burn-import/src/burn/node/top_k.rs | 10 ++++++---- crates/burn-import/src/onnx/op_configuration.rs | 10 +++------- crates/onnx-ir/src/dim_inference.rs | 4 ++-- crates/onnx-ir/src/from_onnx.rs | 5 +---- 6 files changed, 17 insertions(+), 19 deletions(-) diff --git a/crates/burn-import/onnx-tests/tests/test_onnx.rs b/crates/burn-import/onnx-tests/tests/test_onnx.rs index e09f03b6e5..a7987bb704 100644 --- a/crates/burn-import/onnx-tests/tests/test_onnx.rs +++ b/crates/burn-import/onnx-tests/tests/test_onnx.rs @@ -2134,7 +2134,10 @@ mod tests { let model = top_k::Model::::new(&device); // Run the model - let input = Tensor::::from_floats([[1.0, 2.0, 3.0, 4.0], [1.0, 2.0, 3.0, 4.0]], &device); + let input = Tensor::::from_floats( + [[1.0, 2.0, 3.0, 4.0], [1.0, 2.0, 3.0, 4.0]], + &device, + ); let (values_tensor, _indices_tensor) = model.forward(input); // data from pyTorch let expected = TensorData::from([[4.0, 3.0, 2.to_f32()], [4.0, 3.0, 2.to_f32()]]); diff --git a/crates/burn-import/src/burn/node/base.rs b/crates/burn-import/src/burn/node/base.rs index 6c955033cd..69fc8f9712 100644 --- a/crates/burn-import/src/burn/node/base.rs +++ b/crates/burn-import/src/burn/node/base.rs @@ -11,7 +11,7 @@ use super::{ max_pool1d::MaxPool1dNode, max_pool2d::MaxPool2dNode, mean::MeanNode, pad::PadNode, prelu::PReluNode, random_normal::RandomNormalNode, random_uniform::RandomUniformNode, range::RangeNode, reshape::ReshapeNode, resize::ResizeNode, slice::SliceNode, - squeeze::SqueezeNode, sum::SumNode, tile::TileNode, top_k::TopKNode, unary::UnaryNode, + squeeze::SqueezeNode, sum::SumNode, tile::TileNode, top_k::TopKNode, unary::UnaryNode, unsqueeze::UnsqueezeNode, }; use crate::burn::{BurnImports, Scope, Type}; diff --git a/crates/burn-import/src/burn/node/top_k.rs b/crates/burn-import/src/burn/node/top_k.rs index ebf2a22290..df487d6b4d 100644 --- a/crates/burn-import/src/burn/node/top_k.rs +++ b/crates/burn-import/src/burn/node/top_k.rs @@ -12,7 +12,6 @@ pub struct TopKConfig { pub k: i64, } - #[derive(Debug, Clone, new)] pub struct TopKNode { pub input: TensorType, @@ -57,7 +56,7 @@ mod tests { use super::*; use crate::burn::{ graph::BurnGraph, - node::{top_k::TopKNode, test::assert_tokens}, + node::{test::assert_tokens, top_k::TopKNode}, TensorType, }; @@ -75,7 +74,10 @@ mod tests { config, )); - graph.register_input_output(vec!["input_tensor".to_string()], vec!["values_tensor".to_string(), "indices_tensor".to_string()]); + graph.register_input_output( + vec!["input_tensor".to_string()], + vec!["values_tensor".to_string(), "indices_tensor".to_string()], + ); let expected = quote! { use burn::tensor::Int; @@ -108,4 +110,4 @@ mod tests { assert_tokens(graph.codegen(), expected); } -} \ No newline at end of file +} diff --git a/crates/burn-import/src/onnx/op_configuration.rs b/crates/burn-import/src/onnx/op_configuration.rs index b8daad5660..8c0e961c81 100644 --- a/crates/burn-import/src/onnx/op_configuration.rs +++ b/crates/burn-import/src/onnx/op_configuration.rs @@ -796,18 +796,14 @@ pub fn tile_config(node: &Node) -> TileConfig { } fn extract_attr_value_i64(node: &Node, key: &str) -> i64 { - let value = node.attrs - .get(key) - .unwrap() - .clone() - .into_i64(); + let value = node.attrs.get(key).unwrap().clone().into_i64(); value } + /// Create a TopKConfig from the attributes of the node. We don't extract sorted from the TopK node as our topk impl already returns in sorted order. pub fn top_k_config(node: &Node) -> TopKConfig { - let axis:i64 = extract_attr_value_i64(node, "axis"); + let axis: i64 = extract_attr_value_i64(node, "axis"); let k: i64 = extract_attr_value_i64(node, "k"); - TopKConfig::new(axis, k) } diff --git a/crates/onnx-ir/src/dim_inference.rs b/crates/onnx-ir/src/dim_inference.rs index 737f36d964..f093421abc 100644 --- a/crates/onnx-ir/src/dim_inference.rs +++ b/crates/onnx-ir/src/dim_inference.rs @@ -495,13 +495,13 @@ fn top_k_update_output(node: &mut Node) { }; node.outputs[0].ty = ArgType::Tensor(TensorType { - dim: dim, + dim, shape: None, // shape is tracked and calculated at runtime elem_type: output_values_elem, }); node.outputs[1].ty = ArgType::Tensor(TensorType { - dim: dim, + dim, shape: None, // shape is tracked and calculated at runtime elem_type: output_indices_elem, }); diff --git a/crates/onnx-ir/src/from_onnx.rs b/crates/onnx-ir/src/from_onnx.rs index faf666c7a1..013916e088 100644 --- a/crates/onnx-ir/src/from_onnx.rs +++ b/crates/onnx-ir/src/from_onnx.rs @@ -202,14 +202,11 @@ impl OnnxGraphBuilder { pub(crate) fn build(mut self, model_proto: &ModelProto) -> OnnxGraph { self.constants_types = LIFT_CONSTANTS_FOR_NODE_TYPES.into_iter().collect(); - println!("here are the proto graph inputs: {:#?}", &model_proto.graph.input); - println!("here are the proto graph outputs: {:#?}", &model_proto.graph.output); let mut graph_data = GraphData::new( &model_proto.graph.input, &model_proto.graph.output, &model_proto.graph.initializer, ); - println!("here are the proto graph outputs created: {:#?}", &graph_data.outputs); let mut node_iter = model_proto.graph.node.iter().peekable(); @@ -245,7 +242,7 @@ impl OnnxGraphBuilder { OnnxGraph { nodes: processed_nodes, inputs, - outputs + outputs, } } From 5783f76380c1a8d0bc6224e1b48970bc6e1102bc Mon Sep 17 00:00:00 2001 From: Femi Oho <47154698+oojo12@users.noreply.github.com> Date: Wed, 25 Sep 2024 16:38:55 -0400 Subject: [PATCH 06/14] chore: updated supported ops --- crates/burn-import/SUPPORTED-ONNX-OPS.md | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/crates/burn-import/SUPPORTED-ONNX-OPS.md b/crates/burn-import/SUPPORTED-ONNX-OPS.md index 6319fedbe7..668a4da610 100644 --- a/crates/burn-import/SUPPORTED-ONNX-OPS.md +++ b/crates/burn-import/SUPPORTED-ONNX-OPS.md @@ -192,7 +192,7 @@ represent the corresponding Burn Op. | [TfIdfVectorizer][183] | ❌ | ❌ | | [ThresholdedRelu][184] | ❌ | ❌ | | [Tile][185] | ✅ | ✅ | -| [TopK][186] | ❌ | ✅ | +| [TopK][186] | ✅ | ✅ | | [Transpose][187] | ✅ | ✅ | | [Trilu][188] | ❌ | ✅ | | [Unique][189] | ❌ | ❌ | From 16ce366af35aa5534591257de08a37a43cc23a11 Mon Sep 17 00:00:00 2001 From: Femi Oho <47154698+oojo12@users.noreply.github.com> Date: Wed, 25 Sep 2024 16:55:31 -0400 Subject: [PATCH 07/14] cleanup --- crates/burn-import/src/onnx/op_configuration.rs | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/crates/burn-import/src/onnx/op_configuration.rs b/crates/burn-import/src/onnx/op_configuration.rs index 8c0e961c81..dcfdd30d43 100644 --- a/crates/burn-import/src/onnx/op_configuration.rs +++ b/crates/burn-import/src/onnx/op_configuration.rs @@ -800,7 +800,7 @@ fn extract_attr_value_i64(node: &Node, key: &str) -> i64 { value } -/// Create a TopKConfig from the attributes of the node. We don't extract sorted from the TopK node as our topk impl already returns in sorted order. +/// Create a TopKConfig from the attributes of the node. pub fn top_k_config(node: &Node) -> TopKConfig { let axis: i64 = extract_attr_value_i64(node, "axis"); let k: i64 = extract_attr_value_i64(node, "k"); From 4ca56c6cae0a085c67604a7f9ea7a10e65efa064 Mon Sep 17 00:00:00 2001 From: Femi Oho <47154698+oojo12@users.noreply.github.com> Date: Wed, 25 Sep 2024 17:01:14 -0400 Subject: [PATCH 08/14] cleanup --- crates/burn-import/src/burn/node/top_k.rs | 1 - 1 file changed, 1 deletion(-) diff --git a/crates/burn-import/src/burn/node/top_k.rs b/crates/burn-import/src/burn/node/top_k.rs index df487d6b4d..47b8f57323 100644 --- a/crates/burn-import/src/burn/node/top_k.rs +++ b/crates/burn-import/src/burn/node/top_k.rs @@ -5,7 +5,6 @@ use burn::record::PrecisionSettings; use proc_macro2::TokenStream; use quote::{quote, ToTokens}; -// We omit the sorted option because the burn topk impls already come with the topk sorted #[derive(Config, Debug)] pub struct TopKConfig { pub axis: i64, From 49a6a73540328623c096e3433177116a0cea5629 Mon Sep 17 00:00:00 2001 From: Femi Oho <47154698+oojo12@users.noreply.github.com> Date: Fri, 27 Sep 2024 18:42:21 -0400 Subject: [PATCH 09/14] address feedback --- .../burn-import/onnx-tests/tests/test_onnx.rs | 18 ++-- .../onnx-tests/tests/top_k/top_k.py | 100 +++++++++++------- .../onnx-tests/tests/top_k/top_k_opset_1.onnx | Bin 0 -> 147 bytes crates/burn-import/src/burn/node/top_k.rs | 14 ++- .../burn-import/src/onnx/op_configuration.rs | 40 +++++-- crates/burn-import/src/onnx/to_burn.rs | 4 + crates/burn-jit/src/tests/unary.rs | 1 - crates/onnx-ir/src/from_onnx.rs | 1 - 8 files changed, 117 insertions(+), 61 deletions(-) create mode 100644 crates/burn-import/onnx-tests/tests/top_k/top_k_opset_1.onnx diff --git a/crates/burn-import/onnx-tests/tests/test_onnx.rs b/crates/burn-import/onnx-tests/tests/test_onnx.rs index a7987bb704..147080a31e 100644 --- a/crates/burn-import/onnx-tests/tests/test_onnx.rs +++ b/crates/burn-import/onnx-tests/tests/test_onnx.rs @@ -116,7 +116,7 @@ include_models!( sum_int, tanh, tile, - top_k, + top_k_opset_1, transpose, unsqueeze, unsqueeze_opset11, @@ -2128,19 +2128,23 @@ mod tests { } #[test] - fn top_k() { + fn top_k_opset_1() { // Initialize the model let device = Default::default(); - let model = top_k::Model::::new(&device); + let model = top_k_opset1::Model::::new(&device); // Run the model let input = Tensor::::from_floats( [[1.0, 2.0, 3.0, 4.0], [1.0, 2.0, 3.0, 4.0]], &device, ); - let (values_tensor, _indices_tensor) = model.forward(input); - // data from pyTorch - let expected = TensorData::from([[4.0, 3.0, 2.to_f32()], [4.0, 3.0, 2.to_f32()]]); - values_tensor.to_data().assert_eq(&expected, true); + let (values_tensor, indices_tensor) = model.forward(input); + + // expected results + let expected_values_tensor = TensorData::from([[4.0, 3.0, 2.to_f32()], [4.0, 3.0, 2.to_f32()]]); + let expected_indices_tensor = TensorData::from([[3, 2, 1], [3, 2, 1]]); + + values_tensor.to_data().assert_eq(&expected_values_tensor, true); + indices_tensor.to_data().assert_eq(&expected_indices_tensor, true); } } diff --git a/crates/burn-import/onnx-tests/tests/top_k/top_k.py b/crates/burn-import/onnx-tests/tests/top_k/top_k.py index ca58d8d213..e74c99a2f6 100644 --- a/crates/burn-import/onnx-tests/tests/top_k/top_k.py +++ b/crates/burn-import/onnx-tests/tests/top_k/top_k.py @@ -13,45 +13,63 @@ axis = 1 new_dims = [X.shape[0], k] -input_tensors = [ - helper.make_tensor_value_info('X', TensorProto.FLOAT, X.shape), - #helper.make_tensor_value_info('K', TensorProto.INT32, K.shape) -] - -output_tensors = [ - helper.make_tensor_value_info('Values', TensorProto.FLOAT, new_dims), - helper.make_tensor_value_info('Indices', TensorProto.INT32, new_dims) -] +def create_model(op_set_version: int): + input_tensors = [helper.make_tensor_value_info('X', TensorProto.FLOAT, X.shape)] + + output_tensors = [ + helper.make_tensor_value_info('Values', TensorProto.FLOAT, new_dims), + helper.make_tensor_value_info('Indices', TensorProto.INT32, new_dims) + ] + + # Create the TopK node + if op_set_version > 1: + node = helper.make_node( + 'TopK', + inputs=['X', 'K'], + outputs=['Values', 'Indices'], + axis=axis, # Axis along which to find the top K elements + ) + input_tensors.append(helper.make_tensor_value_info('K', TensorProto.INT32, K.shape)) + else: + node = helper.make_node( + 'TopK', + inputs=['X'], + outputs=['Values', 'Indices'], + axis=axis, # Axis along which to find the top K elements + k=k + ) + + # Create the graph + graph = helper.make_graph( + nodes = [node], + name = 'TopKGraph', + inputs = input_tensors, + outputs = output_tensors, + # Unconmment when initializers are supported. Currently we can't test opset 10/11 since the code will require a k value to be initialized for testing. + #initializer = [ + # helper.make_tensor('X', TensorProto.FLOAT, X.shape, X), + # helper.make_tensor('K', TensorProto.INT64, [1], [k]), + #] + ) + + # Create the model + model = helper.make_model( + graph, + ir_version=8, + opset_imports=[onnx.helper.make_operatorsetid("", op_set_version)] + ) + # Check the model + onnx.checker.check_model(model) + + # Save the model to a file + onnx.save(model, f'top_k_opset_{op_set_version}.onnx') + print(f"Model saved to top_k_opset_{op_set_version}.onnx") -# Create the TopK node -node = helper.make_node( - 'TopK', - inputs=['X'],# 'K'], - outputs=['Values', 'Indices'], - axis=axis, # Axis along which to find the top K elements - #largest=-1, - k=k -) - -# Create the graph -graph = helper.make_graph( - nodes = [node], - name = 'TopKGraph', - inputs = input_tensors, - outputs = output_tensors -) - -# Create the model -model = helper.make_model( - graph, - ir_version=8, - opset_imports=[onnx.helper.make_operatorsetid("", 1)] -) - -# Check the model -onnx.checker.check_model(model) - -# Save the model to a file -onnx.save(model, 'top_k.onnx') - -print("Model saved to topk_model.onnx") +def main(): + # Unconmment when initializers are supported. + for op_set_version in [1]: #, 10, 11]: + create_model(op_set_version) + + +if __name__ == "__main__": + main() \ No newline at end of file diff --git a/crates/burn-import/onnx-tests/tests/top_k/top_k_opset_1.onnx b/crates/burn-import/onnx-tests/tests/top_k/top_k_opset_1.onnx new file mode 100644 index 0000000000000000000000000000000000000000..4eb08a05c9df6d563d67962e00a26ab83cfec0f2 GIT binary patch literal 147 zcmdMgWq`87u$* literal 0 HcmV?d00001 diff --git a/crates/burn-import/src/burn/node/top_k.rs b/crates/burn-import/src/burn/node/top_k.rs index 47b8f57323..0d8dd45498 100644 --- a/crates/burn-import/src/burn/node/top_k.rs +++ b/crates/burn-import/src/burn/node/top_k.rs @@ -9,6 +9,7 @@ use quote::{quote, ToTokens}; pub struct TopKConfig { pub axis: i64, pub k: i64, + pub largest: i64, } #[derive(Debug, Clone, new)] @@ -33,13 +34,17 @@ impl NodeCodegen for TopKNode { fn forward(&self, scope: &mut Scope, node_position: usize) -> TokenStream { let axis = self.config.axis.to_token_stream(); let k = self.config.k.to_token_stream(); + let largest = self.config.largest.to_token_stream(); let input = scope.tensor_use_owned(&self.input, node_position); let values_output = &self.outputs[0].name; let indices_output = &self.outputs[1].name; quote! { - let (#values_output, #indices_output) = #input.topk_with_indices(#k as usize, #axis as usize); + let (#values_output, #indices_output) = match #largest { + Some(0) => #input.topk_smallest_with_indices(#k, #axis), + _ => #input.topk_with_indices(#k, #axis) + }; } } @@ -62,7 +67,7 @@ mod tests { #[test] fn test_codegen_nodes() { let mut graph = BurnGraph::::default(); - let config = TopKConfig::new(-1, 3); + let config = TopKConfig::new(1, 3, 1); graph.register(TopKNode::new( TensorType::new_float("input_tensor", 4), @@ -101,7 +106,10 @@ mod tests { } #[allow(clippy::let_and_return, clippy::approx_constant)] pub fn forward(&self, input_tensor: Tensor) -> (Tensor, Tensor) { - let (values_tensor, indices_tensor) = input_tensor.topk_with_indices(3i64 as usize, -1i64 as usize); + let (values_tensor, indices_tensor) = match 1i64 { + Some(0) => input_tensor.topk_smallest_with_indices(3i64, 1i64), + _ => input_tensor.topk_with_indices(3i64, 1i64) + }; (values_tensor, indices_tensor) } } diff --git a/crates/burn-import/src/onnx/op_configuration.rs b/crates/burn-import/src/onnx/op_configuration.rs index dcfdd30d43..9d132497c0 100644 --- a/crates/burn-import/src/onnx/op_configuration.rs +++ b/crates/burn-import/src/onnx/op_configuration.rs @@ -10,6 +10,12 @@ use burn::nn::{ use crate::burn::node::{expand::ExpandShape, pad::PadConfig, tile::TileConfig, top_k::TopKConfig}; use onnx_ir::ir::{ArgType, AttributeValue, Data, ElementType, Node}; +/// Extract and convert a given attribute to i64 +fn extract_attr_value_i64(node: &Node, key: &str) -> i64 { + let value = node.attrs.get(key).expect("Expected the following attribute key: {:?}").clone().into_i64(); + value +} + /// Create a Conv1dConfig from the attributes of the node pub fn conv1d_config(curr: &Node) -> Conv1dConfig { let mut kernel_shape = Vec::new(); // TODO default inferred from weight tensor per spec @@ -795,16 +801,34 @@ pub fn tile_config(node: &Node) -> TileConfig { TileConfig::new(repeat) } -fn extract_attr_value_i64(node: &Node, key: &str) -> i64 { - let value = node.attrs.get(key).unwrap().clone().into_i64(); - value -} - /// Create a TopKConfig from the attributes of the node. pub fn top_k_config(node: &Node) -> TopKConfig { - let axis: i64 = extract_attr_value_i64(node, "axis"); - let k: i64 = extract_attr_value_i64(node, "k"); - TopKConfig::new(axis, k) + // extract the shape of the input data tensor + let data_tensor = match node.inputs.first().unwrap().clone().ty { + ArgType::Tensor(tensor) => tensor, + _ => panic!("Only tensor input is valid"), + }; + + let k = match node.inputs.get(1) { + Some(k_tensor) => { + k_tensor.clone().value.expect("Expecting K tensor to have a value.").into_i64s()[0] + } + _ => extract_attr_value_i64(node, "k") + }; + + let mut axis: i64 = extract_attr_value_i64(node, "axis"); + + // if axis is negative, it is counted from the end + if axis < 0 { + axis += data_tensor.dim as i64; + } + + let largest = match node.attrs.get("largest") { + Some(val) => val.clone().into_i64(), + _ => 1 + }; + + TopKConfig::new(axis, k, largest) } /// Create a PadConfig from the attributes of the node diff --git a/crates/burn-import/src/onnx/to_burn.rs b/crates/burn-import/src/onnx/to_burn.rs index cf7ab7637e..4a0da898c3 100644 --- a/crates/burn-import/src/onnx/to_burn.rs +++ b/crates/burn-import/src/onnx/to_burn.rs @@ -359,6 +359,10 @@ impl ParsedOnnxGraph { .iter() .map(|input| input.name.clone()) .collect::>(); + + println!("The registered graph input types are: {:#?}", self + .0 + .inputs.clone()); let output_names = self .0 .outputs diff --git a/crates/burn-jit/src/tests/unary.rs b/crates/burn-jit/src/tests/unary.rs index 47e3933c95..692133ae68 100644 --- a/crates/burn-jit/src/tests/unary.rs +++ b/crates/burn-jit/src/tests/unary.rs @@ -4,7 +4,6 @@ mod tests { use burn_tensor::{Distribution, Tensor}; #[test] - #[cfg(target_os = "macos")] fn tanh_should_not_have_numerical_bugs_on_macos() { fn tanh_one_value(input: f32) -> f32 { let tensor = Tensor::::ones([1], &Default::default()) * input; diff --git a/crates/onnx-ir/src/from_onnx.rs b/crates/onnx-ir/src/from_onnx.rs index 013916e088..470ab2ddca 100644 --- a/crates/onnx-ir/src/from_onnx.rs +++ b/crates/onnx-ir/src/from_onnx.rs @@ -369,7 +369,6 @@ pub fn parse_onnx(onnx_path: &Path) -> OnnxGraph { let builder = OnnxGraphBuilder::default(); let graph = builder.build(&onnx_model); - log::info!("Onnx graph: {:#?}", graph); log::info!("Finished parsing ONNX file: {}", onnx_path.display()); graph From f293ad3552ba0399a90e2caeb41c65fe512150db Mon Sep 17 00:00:00 2001 From: Femi Oho <47154698+oojo12@users.noreply.github.com> Date: Fri, 27 Sep 2024 18:49:47 -0400 Subject: [PATCH 10/14] usize cast in config and add topk_smallest --- crates/burn-import/src/onnx/op_configuration.rs | 2 +- crates/burn-tensor/src/tensor/api/numeric.rs | 17 +++++++++++++++++ 2 files changed, 18 insertions(+), 1 deletion(-) diff --git a/crates/burn-import/src/onnx/op_configuration.rs b/crates/burn-import/src/onnx/op_configuration.rs index 9d132497c0..691c145d6f 100644 --- a/crates/burn-import/src/onnx/op_configuration.rs +++ b/crates/burn-import/src/onnx/op_configuration.rs @@ -828,7 +828,7 @@ pub fn top_k_config(node: &Node) -> TopKConfig { _ => 1 }; - TopKConfig::new(axis, k, largest) + TopKConfig::new(axis as usize, k as usize, largest) } /// Create a PadConfig from the attributes of the node diff --git a/crates/burn-tensor/src/tensor/api/numeric.rs b/crates/burn-tensor/src/tensor/api/numeric.rs index b67293715b..76f1b08bd7 100644 --- a/crates/burn-tensor/src/tensor/api/numeric.rs +++ b/crates/burn-tensor/src/tensor/api/numeric.rs @@ -731,6 +731,12 @@ where self.sort_descending(dim).select(dim, k_indices) } + /// Returns the `k` smallest elements of the given input tensor along a given dimension. + pub fn topk_smallest(self, k: usize, dim: usize) -> Tensor { + let k_indices = Tensor::arange(0..k as i64, &self.device()); + self.sort(dim).select(dim, k_indices) + } + /// Returns the `k` largest elements of the given input tensor along a given dimension. /// Also returns the indices. pub fn topk_with_indices(self, k: usize, dim: usize) -> (Tensor, Tensor) { @@ -742,6 +748,17 @@ where ) } + /// Returns the `k` smallest elements of the given input tensor along a given dimension. + /// Also returns the indices. + pub fn topk_smallest_with_indices(self, k: usize, dim: usize) -> (Tensor, Tensor) { + let k_indices = Tensor::arange(0..k as i64, &self.device()); + let (values, indices) = self.sort_with_indices(dim); + ( + values.select(dim, k_indices.clone()), + indices.select(dim, k_indices), + ) + } + /// Pad the tensor of rank two or higher with the given value on the last two dimensions. /// /// # Arguments From 449f519a2ffbeee4d1dec967ee8c710fc2bc58f1 Mon Sep 17 00:00:00 2001 From: Femi Oho <47154698+oojo12@users.noreply.github.com> Date: Fri, 27 Sep 2024 19:27:45 -0400 Subject: [PATCH 11/14] EOD --- crates/burn-import/src/burn/node/top_k.rs | 16 ++--- crates/burn-import/src/onnx/to_burn.rs | 3 - crates/burn-tensor/src/tensor/api/numeric.rs | 52 ++++++++++---- crates/burn-tensor/src/tests/ops/topk.rs | 76 ++++++++++++++++++-- 4 files changed, 113 insertions(+), 34 deletions(-) diff --git a/crates/burn-import/src/burn/node/top_k.rs b/crates/burn-import/src/burn/node/top_k.rs index 0d8dd45498..e425c2fc5e 100644 --- a/crates/burn-import/src/burn/node/top_k.rs +++ b/crates/burn-import/src/burn/node/top_k.rs @@ -7,8 +7,8 @@ use quote::{quote, ToTokens}; #[derive(Config, Debug)] pub struct TopKConfig { - pub axis: i64, - pub k: i64, + pub axis: usize, + pub k: usize, pub largest: i64, } @@ -41,11 +41,8 @@ impl NodeCodegen for TopKNode { let indices_output = &self.outputs[1].name; quote! { - let (#values_output, #indices_output) = match #largest { - Some(0) => #input.topk_smallest_with_indices(#k, #axis), - _ => #input.topk_with_indices(#k, #axis) - }; - } + #input.topk_with_indices(#k, #axis, #largest) + }; } fn into_node(self) -> Node { @@ -106,10 +103,7 @@ mod tests { } #[allow(clippy::let_and_return, clippy::approx_constant)] pub fn forward(&self, input_tensor: Tensor) -> (Tensor, Tensor) { - let (values_tensor, indices_tensor) = match 1i64 { - Some(0) => input_tensor.topk_smallest_with_indices(3i64, 1i64), - _ => input_tensor.topk_with_indices(3i64, 1i64) - }; + (values_tensor, indices_tensor) = input_tensor.topk_with_indices(3i64, 1i64, 1i64) (values_tensor, indices_tensor) } } diff --git a/crates/burn-import/src/onnx/to_burn.rs b/crates/burn-import/src/onnx/to_burn.rs index 4a0da898c3..70ef970337 100644 --- a/crates/burn-import/src/onnx/to_burn.rs +++ b/crates/burn-import/src/onnx/to_burn.rs @@ -360,9 +360,6 @@ impl ParsedOnnxGraph { .map(|input| input.name.clone()) .collect::>(); - println!("The registered graph input types are: {:#?}", self - .0 - .inputs.clone()); let output_names = self .0 .outputs diff --git a/crates/burn-tensor/src/tensor/api/numeric.rs b/crates/burn-tensor/src/tensor/api/numeric.rs index 76f1b08bd7..875e494f37 100644 --- a/crates/burn-tensor/src/tensor/api/numeric.rs +++ b/crates/burn-tensor/src/tensor/api/numeric.rs @@ -726,26 +726,50 @@ where } /// Returns the `k` largest elements of the given input tensor along a given dimension. - pub fn topk(self, k: usize, dim: usize) -> Tensor { + pub fn topk(self, k: usize, dim: usize, largest: Option) -> Tensor { let k_indices = Tensor::arange(0..k as i64, &self.device()); - self.sort_descending(dim).select(dim, k_indices) - } - - /// Returns the `k` smallest elements of the given input tensor along a given dimension. - pub fn topk_smallest(self, k: usize, dim: usize) -> Tensor { - let k_indices = Tensor::arange(0..k as i64, &self.device()); - self.sort(dim).select(dim, k_indices) + match largest { + Some(largest) => { + if largest == 1 { + self.sort_descending(dim).select(dim, k_indices) + } else { + self.sort(dim).select(dim, k_indices) + } + }, + _ => { + self.sort_descending(dim).select(dim, k_indices) + } + } } /// Returns the `k` largest elements of the given input tensor along a given dimension. /// Also returns the indices. - pub fn topk_with_indices(self, k: usize, dim: usize) -> (Tensor, Tensor) { + pub fn topk_with_indices(self, k: usize, dim: usize, largest: Option) -> (Tensor, Tensor) { let k_indices = Tensor::arange(0..k as i64, &self.device()); - let (values, indices) = self.sort_descending_with_indices(dim); - ( - values.select(dim, k_indices.clone()), - indices.select(dim, k_indices), - ) + match largest { + Some(largest) => { + if largest == 1 { + let (values, indices) = self.sort_with_indices(dim); + ( + values.select(dim, k_indices.clone()), + indices.select(dim, k_indices), + ) + } else { + let (values, indices) = self.sort_descending_with_indices(dim); + ( + values.select(dim, k_indices.clone()), + indices.select(dim, k_indices), + ) + } + }, + _ => { + let (values, indices) = self.sort_descending_with_indices(dim); + ( + values.select(dim, k_indices.clone()), + indices.select(dim, k_indices), + ) + } + } } /// Returns the `k` smallest elements of the given input tensor along a given dimension. diff --git a/crates/burn-tensor/src/tests/ops/topk.rs b/crates/burn-tensor/src/tests/ops/topk.rs index 9d98926655..6d8f3a0784 100644 --- a/crates/burn-tensor/src/tests/ops/topk.rs +++ b/crates/burn-tensor/src/tests/ops/topk.rs @@ -6,48 +6,85 @@ mod tests { #[test] fn test_topk_1d() { // Int + // largest let tensor = TestTensorInt::<1>::from([1, 2, 3, 4, 5]); - let values = tensor.topk(3, /*dim*/ 0); + let values = tensor.topk(3, /*dim*/ 0, /*largest*/ Some(1)); let expected = TensorData::from([5, 4, 3]); values.into_data().assert_eq(&expected, false); + // smallest + let tensor = TestTensorInt::<1>::from([1, 2, 3, 4, 5]); + + let values = tensor.topk(3, /*dim*/ 0, /*largest*/ Some(0)); + let expected = TensorData::from([1, 2, 3]); + + values.into_data().assert_eq(&expected, false); + // Float + // largest let tensor = TestTensor::<1>::from([1., 2., 3., 4., 5.]); - let values = tensor.topk(3, /*dim*/ 0); + let values = tensor.topk(3, /*dim*/ 0, /*largest*/ Some(1)); let expected = TensorData::from([5., 4., 3.]); values.into_data().assert_approx_eq(&expected, 5); + + // Float + // smallest + let tensor = TestTensor::<1>::from([1., 2., 3., 4., 5.]); + + let values = tensor.topk(3, /*dim*/ 0, /*largest*/ Some(0)); + let expected = TensorData::from([1., 2., 3.]); + + values.into_data().assert_approx_eq(&expected, 1); } #[test] fn test_topk() { // 3D Int + // largest let tensor = TestTensorInt::<3>::from([[[1, 4, 7], [2, 5, 6]], [[3, 0, 9], [8, 2, 8]]]); - let values = tensor.topk(2, /*dim*/ 2); + let values = tensor.topk(2, /*dim*/ 2, /*largest*/ Some(1)); let expected = TensorData::from([[[7, 4], [6, 5]], [[9, 3], [8, 8]]]); values.into_data().assert_eq(&expected, false); + // smallest + let tensor = TestTensorInt::<3>::from([[[1, 4, 7], [2, 5, 6]], [[3, 0, 9], [8, 2, 8]]]); + + let values = tensor.topk(2, /*dim*/ 2, /*largest*/ Some(0)); + let expected = TensorData::from([[[1, 4], [2, 5]], [[0, 3], [2, 8]]]); + + values.into_data().assert_eq(&expected, false); + // 3D Float + // largest let tensor = TestTensor::<3>::from([[[1., 4., 7.], [2., 5., 6.]], [[3., 0., 9.], [8., 2., 8.]]]); - let values = tensor.topk(2, /*dim*/ 2); + let values = tensor.topk(2, /*dim*/ 2, /*largest*/ Some(1)); let expected = TensorData::from([[[7., 4.], [6., 5.]], [[9., 3.], [8., 8.]]]); values.into_data().assert_approx_eq(&expected, 5); + + // smallest + let tensor = + TestTensor::<3>::from([[[1., 4., 7.], [2., 5., 6.]], [[3., 0., 9.], [8., 2., 8.]]]); + + let values = tensor.topk(2, /*dim*/ 2, /*largest*/ Some(0)); + let expected = TensorData::from([[[1, 4], [2, 5]], [[0, 3], [2, 8]]]); } #[test] fn test_topk_with_indices() { // 1D + // largest let tensor = TestTensorInt::<1>::from([1, 2, 3, 4, 5]); - let (values, indices) = tensor.topk_with_indices(3, /*dim*/ 0); + let (values, indices) = tensor.topk_with_indices(3, /*dim*/ 0, /*largest*/ Some(1)); let values_expected = TensorData::from([5, 4, 3]); values.into_data().assert_eq(&values_expected, false); @@ -55,11 +92,23 @@ mod tests { let indices_expected = TensorData::from([4, 3, 2]); indices.into_data().assert_eq(&indices_expected, false); + // smallest + let tensor = TestTensorInt::<1>::from([1, 2, 3, 4, 5]); + + let (values, indices) = tensor.topk_with_indices(3, /*dim*/ 0, /*largest*/ Some(0)); + + let values_expected = TensorData::from([1, 2, 3]); + values.into_data().assert_eq(&values_expected, false); + + let indices_expected = TensorData::from([0, 1, 2]); + indices.into_data().assert_eq(&indices_expected, false); + // 3D + // largest let tensor = TestTensor::<3>::from([[[1., 4., 7.], [2., 5., 6.]], [[3., 0., 9.], [8., 2., 7.]]]); - let (values, indices) = tensor.topk_with_indices(2, /*dim*/ 2); + let (values, indices) = tensor.topk_with_indices(2, /*dim*/ 2, /*largest*/ Some(1)); let values_expected = TensorData::from([[[7., 4.], [6., 5.]], [[9., 3.], [8., 7.]]]); @@ -68,5 +117,20 @@ mod tests { let indices_expected = TensorData::from([[[2, 1], [2, 1]], [[2, 0], [0, 2]]]); indices.into_data().assert_eq(&indices_expected, false); + + // smallest + let tensor = + TestTensor::<3>::from([[[1., 4., 7.], [2., 5., 6.]], [[3., 0., 9.], [8., 2., 7.]]]); + + let (values, indices) = tensor.topk_with_indices(2, /*dim*/ 2, /*largest*/ Some(0)); + + let values_expected = TensorData::from([[[1., 4.], [2., 5.]], [[0., 3.], [2., 7.]]]); + + values.into_data().assert_approx_eq(&values_expected, 5); + + let indices_expected = TensorData::from([[[0, 1], [0, 1]], [[1, 0], [1, 2]]]); + + indices.into_data().assert_eq(&indices_expected, false); + } } From 249d60c25de11aefed7d0df8cd64752baac0ac39 Mon Sep 17 00:00:00 2001 From: Femi Oho <47154698+oojo12@users.noreply.github.com> Date: Fri, 27 Sep 2024 19:43:31 -0400 Subject: [PATCH 12/14] EOD cleanup --- crates/burn-import/src/burn/node/top_k.rs | 8 ++++---- crates/burn-import/src/onnx/op_configuration.rs | 2 +- crates/burn-tensor/src/tensor/api/numeric.rs | 11 ----------- 3 files changed, 5 insertions(+), 16 deletions(-) diff --git a/crates/burn-import/src/burn/node/top_k.rs b/crates/burn-import/src/burn/node/top_k.rs index e425c2fc5e..4fc899dcf7 100644 --- a/crates/burn-import/src/burn/node/top_k.rs +++ b/crates/burn-import/src/burn/node/top_k.rs @@ -9,7 +9,7 @@ use quote::{quote, ToTokens}; pub struct TopKConfig { pub axis: usize, pub k: usize, - pub largest: i64, + pub largest: usize, } #[derive(Debug, Clone, new)] @@ -41,8 +41,8 @@ impl NodeCodegen for TopKNode { let indices_output = &self.outputs[1].name; quote! { - #input.topk_with_indices(#k, #axis, #largest) - }; + let (#values_output, #indices_output) = #input.topk_with_indices(#k, #axis, #largest); + } } fn into_node(self) -> Node { @@ -103,7 +103,7 @@ mod tests { } #[allow(clippy::let_and_return, clippy::approx_constant)] pub fn forward(&self, input_tensor: Tensor) -> (Tensor, Tensor) { - (values_tensor, indices_tensor) = input_tensor.topk_with_indices(3i64, 1i64, 1i64) + let (values_tensor, indices_tensor) = input_tensor.topk_with_indices(3usize, 1usize, 1usize); (values_tensor, indices_tensor) } } diff --git a/crates/burn-import/src/onnx/op_configuration.rs b/crates/burn-import/src/onnx/op_configuration.rs index 691c145d6f..8fd1215e83 100644 --- a/crates/burn-import/src/onnx/op_configuration.rs +++ b/crates/burn-import/src/onnx/op_configuration.rs @@ -828,7 +828,7 @@ pub fn top_k_config(node: &Node) -> TopKConfig { _ => 1 }; - TopKConfig::new(axis as usize, k as usize, largest) + TopKConfig::new(axis as usize, k as usize, largest as usize) } /// Create a PadConfig from the attributes of the node diff --git a/crates/burn-tensor/src/tensor/api/numeric.rs b/crates/burn-tensor/src/tensor/api/numeric.rs index 875e494f37..785caf1c27 100644 --- a/crates/burn-tensor/src/tensor/api/numeric.rs +++ b/crates/burn-tensor/src/tensor/api/numeric.rs @@ -772,17 +772,6 @@ where } } - /// Returns the `k` smallest elements of the given input tensor along a given dimension. - /// Also returns the indices. - pub fn topk_smallest_with_indices(self, k: usize, dim: usize) -> (Tensor, Tensor) { - let k_indices = Tensor::arange(0..k as i64, &self.device()); - let (values, indices) = self.sort_with_indices(dim); - ( - values.select(dim, k_indices.clone()), - indices.select(dim, k_indices), - ) - } - /// Pad the tensor of rank two or higher with the given value on the last two dimensions. /// /// # Arguments From 07f79ab3c771e0a014e1e57672a805894186df92 Mon Sep 17 00:00:00 2001 From: Femi Oho <47154698+oojo12@users.noreply.github.com> Date: Fri, 27 Sep 2024 21:59:53 -0400 Subject: [PATCH 13/14] minor update --- crates/burn-import/onnx-tests/tests/top_k/top_k.py | 2 +- crates/burn-import/src/onnx/op_configuration.rs | 3 ++- 2 files changed, 3 insertions(+), 2 deletions(-) diff --git a/crates/burn-import/onnx-tests/tests/top_k/top_k.py b/crates/burn-import/onnx-tests/tests/top_k/top_k.py index e74c99a2f6..176c743875 100644 --- a/crates/burn-import/onnx-tests/tests/top_k/top_k.py +++ b/crates/burn-import/onnx-tests/tests/top_k/top_k.py @@ -67,7 +67,7 @@ def create_model(op_set_version: int): def main(): # Unconmment when initializers are supported. - for op_set_version in [1]: #, 10, 11]: + for op_set_version in [1, 10, 11]: create_model(op_set_version) diff --git a/crates/burn-import/src/onnx/op_configuration.rs b/crates/burn-import/src/onnx/op_configuration.rs index 8fd1215e83..1290effdd7 100644 --- a/crates/burn-import/src/onnx/op_configuration.rs +++ b/crates/burn-import/src/onnx/op_configuration.rs @@ -12,7 +12,8 @@ use onnx_ir::ir::{ArgType, AttributeValue, Data, ElementType, Node}; /// Extract and convert a given attribute to i64 fn extract_attr_value_i64(node: &Node, key: &str) -> i64 { - let value = node.attrs.get(key).expect("Expected the following attribute key: {:?}").clone().into_i64(); + let error_msg = format!("Expected the following attribute key: {:?}", key); + let value = node.attrs.get(key).expect(&error_msg).clone().into_i64(); value } From 7e221c92e6bd852fbef920067b443aa9bbaad3e9 Mon Sep 17 00:00:00 2001 From: Femi Oho <47154698+oojo12@users.noreply.github.com> Date: Sat, 28 Sep 2024 08:29:01 -0400 Subject: [PATCH 14/14] run checks and other fixes --- .../burn-import/onnx-tests/tests/test_onnx.rs | 11 ++- .../burn-import/src/onnx/op_configuration.rs | 12 +-- crates/burn-tensor/src/tensor/api/numeric.rs | 17 ++-- crates/burn-tensor/src/tests/ops/topk.rs | 18 +++-- .../src/tests/quantization/ops/topk.rs | 79 +++++++++++++++++-- 5 files changed, 106 insertions(+), 31 deletions(-) diff --git a/crates/burn-import/onnx-tests/tests/test_onnx.rs b/crates/burn-import/onnx-tests/tests/test_onnx.rs index 147080a31e..cfbce8830d 100644 --- a/crates/burn-import/onnx-tests/tests/test_onnx.rs +++ b/crates/burn-import/onnx-tests/tests/test_onnx.rs @@ -2141,10 +2141,15 @@ mod tests { let (values_tensor, indices_tensor) = model.forward(input); // expected results - let expected_values_tensor = TensorData::from([[4.0, 3.0, 2.to_f32()], [4.0, 3.0, 2.to_f32()]]); + let expected_values_tensor = + TensorData::from([[4.0, 3.0, 2.to_f32()], [4.0, 3.0, 2.to_f32()]]); let expected_indices_tensor = TensorData::from([[3, 2, 1], [3, 2, 1]]); - values_tensor.to_data().assert_eq(&expected_values_tensor, true); - indices_tensor.to_data().assert_eq(&expected_indices_tensor, true); + values_tensor + .to_data() + .assert_eq(&expected_values_tensor, true); + indices_tensor + .to_data() + .assert_eq(&expected_indices_tensor, true); } } diff --git a/crates/burn-import/src/onnx/op_configuration.rs b/crates/burn-import/src/onnx/op_configuration.rs index 1290effdd7..bda0d745bf 100644 --- a/crates/burn-import/src/onnx/op_configuration.rs +++ b/crates/burn-import/src/onnx/op_configuration.rs @@ -811,10 +811,12 @@ pub fn top_k_config(node: &Node) -> TopKConfig { }; let k = match node.inputs.get(1) { - Some(k_tensor) => { - k_tensor.clone().value.expect("Expecting K tensor to have a value.").into_i64s()[0] - } - _ => extract_attr_value_i64(node, "k") + Some(k_tensor) => k_tensor + .clone() + .value + .expect("Expecting K tensor to have a value.") + .into_i64s()[0], + _ => extract_attr_value_i64(node, "k"), }; let mut axis: i64 = extract_attr_value_i64(node, "axis"); @@ -826,7 +828,7 @@ pub fn top_k_config(node: &Node) -> TopKConfig { let largest = match node.attrs.get("largest") { Some(val) => val.clone().into_i64(), - _ => 1 + _ => 1, }; TopKConfig::new(axis as usize, k as usize, largest as usize) diff --git a/crates/burn-tensor/src/tensor/api/numeric.rs b/crates/burn-tensor/src/tensor/api/numeric.rs index 785caf1c27..77057e2f6d 100644 --- a/crates/burn-tensor/src/tensor/api/numeric.rs +++ b/crates/burn-tensor/src/tensor/api/numeric.rs @@ -735,33 +735,36 @@ where } else { self.sort(dim).select(dim, k_indices) } - }, - _ => { - self.sort_descending(dim).select(dim, k_indices) } + _ => self.sort_descending(dim).select(dim, k_indices), } } /// Returns the `k` largest elements of the given input tensor along a given dimension. /// Also returns the indices. - pub fn topk_with_indices(self, k: usize, dim: usize, largest: Option) -> (Tensor, Tensor) { + pub fn topk_with_indices( + self, + k: usize, + dim: usize, + largest: Option, + ) -> (Tensor, Tensor) { let k_indices = Tensor::arange(0..k as i64, &self.device()); match largest { Some(largest) => { if largest == 1 { - let (values, indices) = self.sort_with_indices(dim); + let (values, indices) = self.sort_descending_with_indices(dim); ( values.select(dim, k_indices.clone()), indices.select(dim, k_indices), ) } else { - let (values, indices) = self.sort_descending_with_indices(dim); + let (values, indices) = self.sort_with_indices(dim); ( values.select(dim, k_indices.clone()), indices.select(dim, k_indices), ) } - }, + } _ => { let (values, indices) = self.sort_descending_with_indices(dim); ( diff --git a/crates/burn-tensor/src/tests/ops/topk.rs b/crates/burn-tensor/src/tests/ops/topk.rs index 6d8f3a0784..59411f7bd8 100644 --- a/crates/burn-tensor/src/tests/ops/topk.rs +++ b/crates/burn-tensor/src/tests/ops/topk.rs @@ -30,7 +30,6 @@ mod tests { let expected = TensorData::from([5., 4., 3.]); values.into_data().assert_approx_eq(&expected, 5); - // Float // smallest let tensor = TestTensor::<1>::from([1., 2., 3., 4., 5.]); @@ -72,7 +71,7 @@ mod tests { // smallest let tensor = - TestTensor::<3>::from([[[1., 4., 7.], [2., 5., 6.]], [[3., 0., 9.], [8., 2., 8.]]]); + TestTensor::<3>::from([[[1., 4., 7.], [2., 5., 6.]], [[3., 0., 9.], [8., 2., 8.]]]); let values = tensor.topk(2, /*dim*/ 2, /*largest*/ Some(0)); let expected = TensorData::from([[[1, 4], [2, 5]], [[0, 3], [2, 8]]]); @@ -84,7 +83,8 @@ mod tests { // largest let tensor = TestTensorInt::<1>::from([1, 2, 3, 4, 5]); - let (values, indices) = tensor.topk_with_indices(3, /*dim*/ 0, /*largest*/ Some(1)); + let (values, indices) = + tensor.topk_with_indices(3, /*dim*/ 0, /*largest*/ Some(1)); let values_expected = TensorData::from([5, 4, 3]); values.into_data().assert_eq(&values_expected, false); @@ -95,7 +95,8 @@ mod tests { // smallest let tensor = TestTensorInt::<1>::from([1, 2, 3, 4, 5]); - let (values, indices) = tensor.topk_with_indices(3, /*dim*/ 0, /*largest*/ Some(0)); + let (values, indices) = + tensor.topk_with_indices(3, /*dim*/ 0, /*largest*/ Some(0)); let values_expected = TensorData::from([1, 2, 3]); values.into_data().assert_eq(&values_expected, false); @@ -108,7 +109,8 @@ mod tests { let tensor = TestTensor::<3>::from([[[1., 4., 7.], [2., 5., 6.]], [[3., 0., 9.], [8., 2., 7.]]]); - let (values, indices) = tensor.topk_with_indices(2, /*dim*/ 2, /*largest*/ Some(1)); + let (values, indices) = + tensor.topk_with_indices(2, /*dim*/ 2, /*largest*/ Some(1)); let values_expected = TensorData::from([[[7., 4.], [6., 5.]], [[9., 3.], [8., 7.]]]); @@ -120,9 +122,10 @@ mod tests { // smallest let tensor = - TestTensor::<3>::from([[[1., 4., 7.], [2., 5., 6.]], [[3., 0., 9.], [8., 2., 7.]]]); + TestTensor::<3>::from([[[1., 4., 7.], [2., 5., 6.]], [[3., 0., 9.], [8., 2., 7.]]]); - let (values, indices) = tensor.topk_with_indices(2, /*dim*/ 2, /*largest*/ Some(0)); + let (values, indices) = + tensor.topk_with_indices(2, /*dim*/ 2, /*largest*/ Some(0)); let values_expected = TensorData::from([[[1., 4.], [2., 5.]], [[0., 3.], [2., 7.]]]); @@ -131,6 +134,5 @@ mod tests { let indices_expected = TensorData::from([[[0, 1], [0, 1]], [[1, 0], [1, 2]]]); indices.into_data().assert_eq(&indices_expected, false); - } } diff --git a/crates/burn-tensor/src/tests/quantization/ops/topk.rs b/crates/burn-tensor/src/tests/quantization/ops/topk.rs index 7913fb79e1..78b4900523 100644 --- a/crates/burn-tensor/src/tests/quantization/ops/topk.rs +++ b/crates/burn-tensor/src/tests/quantization/ops/topk.rs @@ -14,9 +14,19 @@ mod tests { ); let tensor = TestTensor::<1>::from_data(data, &Default::default()); - let values = tensor.topk(3, /*dim*/ 0); + // largest + let values = tensor.clone().topk(3, /*dim*/ 0, /*largest*/ Some(1)); let expected = TensorData::from([5., 4., 3.]); + values + .dequantize() + .into_data() + .assert_approx_eq(&expected, 3); + + // smallest + let values = tensor.clone().topk(3, /*dim*/ 0, /*largest*/ Some(0)); + let expected = TensorData::from([1., 2., 3.]); + values .dequantize() .into_data() @@ -24,7 +34,7 @@ mod tests { } #[test] - fn test_topk() { + fn test_topk_3d() { // Quantized [[[1., 4., 7.], [2., 5., 6.]], [[3., 0., 9.], [8., 2., 7.]]] let data = TensorData::quantized( vec![-100i8, -15, 70, -71, 14, 42, -43, -128, 127, 99, -71, 70], @@ -33,9 +43,20 @@ mod tests { ); let tensor = TestTensor::<3>::from_data(data, &Default::default()); - let values = tensor.topk(2, /*dim*/ 2); + // largest + let values = tensor.clone().topk(2, /*dim*/ 2, /*largest*/ Some(1)); let expected = TensorData::from([[[7., 4.], [6., 5.]], [[9., 3.], [8., 7.]]]); + // Precision 1 to approximate de/quantization errors + values + .dequantize() + .into_data() + .assert_approx_eq(&expected, 1); + + // smallest + let values = tensor.clone().topk(2, /*dim*/ 2, /*largest*/ Some(0)); + let expected = TensorData::from([[[1., 4.], [2., 5.]], [[0., 3.], [2., 7.]]]); + // Precision 1 to approximate de/quantization errors values .dequantize() @@ -44,8 +65,7 @@ mod tests { } #[test] - fn test_topk_with_indices() { - // 1D + fn test_topk_with_indices_1d() { // Quantized [1.0, 2.0, 3.0, 4.0, 5.0] let data = TensorData::quantized( vec![-77i8, -26, 25, 76, 127], @@ -54,7 +74,11 @@ mod tests { ); let tensor = TestTensor::<1>::from_data(data, &Default::default()); - let (values, indices) = tensor.topk_with_indices(3, /*dim*/ 0); + // largest + let (values, indices) = + tensor + .clone() + .topk_with_indices(3, /*dim*/ 0, /*largest*/ Some(1)); let values_expected = TensorData::from([5., 4., 3.]); values @@ -65,7 +89,24 @@ mod tests { let indices_expected = TensorData::from([4, 3, 2]); indices.into_data().assert_eq(&indices_expected, false); - // 3D + // smallest + let (values, indices) = + tensor + .clone() + .topk_with_indices(3, /*dim*/ 0, /*largest*/ Some(0)); + + let values_expected = TensorData::from([1., 2., 3.]); + values + .dequantize() + .into_data() + .assert_eq(&values_expected, false); + + let indices_expected = TensorData::from([0, 1, 2]); + indices.into_data().assert_eq(&indices_expected, false); + } + + #[test] + fn test_topk_with_indices_3d() { // Quantized [[[1., 4., 7.], [2., 5., 6.]], [[3., 0., 9.], [8., 2., 7.]]] let data = TensorData::quantized( vec![-100i8, -15, 70, -71, 14, 42, -43, -128, 127, 99, -71, 70], @@ -74,7 +115,11 @@ mod tests { ); let tensor = TestTensor::<3>::from_data(data, &Default::default()); - let (values, indices) = tensor.topk_with_indices(2, /*dim*/ 2); + // largest + let (values, indices) = + tensor + .clone() + .topk_with_indices(2, /*dim*/ 2, /*largest*/ Some(1)); let values_expected = TensorData::from([[[7., 4.], [6., 5.]], [[9., 3.], [8., 7.]]]); @@ -87,5 +132,23 @@ mod tests { let indices_expected = TensorData::from([[[2, 1], [2, 1]], [[2, 0], [0, 2]]]); indices.into_data().assert_eq(&indices_expected, false); + + // smallest + let (values, indices) = + tensor + .clone() + .topk_with_indices(2, /*dim*/ 2, /*largest*/ Some(0)); + + let values_expected = TensorData::from([[[1., 4.], [2., 5.]], [[0., 3.], [2., 7.]]]); + + // Precision 1 to approximate de/quantization errors + values + .dequantize() + .into_data() + .assert_approx_eq(&values_expected, 1); + + let indices_expected = TensorData::from([[[0, 1], [0, 1]], [[1, 0], [1, 2]]]); + + indices.into_data().assert_eq(&indices_expected, false); } }