Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Onnx op topk #2305

Open
wants to merge 14 commits into
base: main
Choose a base branch
from
Open

Onnx op topk #2305

wants to merge 14 commits into from

Conversation

oojo12
Copy link

@oojo12 oojo12 commented Sep 25, 2024

Pull Request Template

Checklist

  • Confirmed that run-checks all script has been executed.
    • some failed I think due to my computer. I.e test_rotary_encoding_forward failed but passed when I ran it via the ui
  • Made sure the book is up to date with changes in this PR.

Related Issues/PRs

#1714

Provide links to relevant issues and dependent PRs.
#1714

Changes

  • added support for onnx topk op. Relies on version 1 of topk node. Other versions take K as an input instead of an attribute, ref
  • made fn tanh_should_not_have_numerical_bugs_on_macos() only run on macos
  • updated IOEntry::Node in the input_names_map. This was important as before it was never incremented with the number of outputs so you would always have _1 as the output name suffix even if there were x >= 2 outputs.

Summarize the problem being addressed and your solution.

Testing

ran the below:
cargo test
./run-checks.sh all

Describe how these changes have been tested.
instructions listed here and here

Copy link

codecov bot commented Sep 26, 2024

Codecov Report

Attention: Patch coverage is 97.08029% with 4 lines in your changes missing coverage. Please review.

Project coverage is 85.43%. Comparing base (aa79e36) to head (4ca56c6).
Report is 10 commits behind head on main.

Files with missing lines Patch % Lines
crates/onnx-ir/src/dim_inference.rs 86.95% 3 Missing ⚠️
crates/burn-import/src/burn/node/top_k.rs 98.66% 1 Missing ⚠️
Additional details and impacted files
@@            Coverage Diff             @@
##             main    #2305      +/-   ##
==========================================
- Coverage   85.79%   85.43%   -0.36%     
==========================================
  Files         754      766      +12     
  Lines       95189    98065    +2876     
==========================================
+ Hits        81671    83786    +2115     
- Misses      13518    14279     +761     

☔ View full report in Codecov by Sentry.
📢 Have feedback on the report? Share it here.

Copy link
Member

@laggui laggui left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thanks for the addition! 🙏

Overall, implementation looks good! I have a few comments below.

crates/burn-import/onnx-tests/tests/test_onnx.rs Outdated Show resolved Hide resolved
crates/burn-import/onnx-tests/tests/top_k/top_k.py Outdated Show resolved Hide resolved
crates/burn-import/src/burn/node/top_k.rs Outdated Show resolved Hide resolved
crates/burn-import/src/onnx/op_configuration.rs Outdated Show resolved Hide resolved
crates/burn-jit/src/tests/unary.rs Outdated Show resolved Hide resolved
crates/onnx-ir/src/from_onnx.rs Outdated Show resolved Hide resolved
crates/onnx-ir/src/from_onnx.rs Show resolved Hide resolved
@oojo12
Copy link
Author

oojo12 commented Sep 28, 2024

Anytime had fun working on it. Also I finished making the changes.

@oojo12 oojo12 requested a review from laggui September 30, 2024 18:22
Copy link
Member

@laggui laggui left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Sorry for the delay of the follow-up review 😅 been busy the past couple of days.

The changes look good to me! But we can remove the tensor API changes and ONNX op config for topk largest (see my comments)

Comment on lines +829 to +832
let largest = match node.attrs.get("largest") {
Some(val) => val.clone().into_i64(),
_ => 1,
};
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

If you're not checking for "k" as the second input of the node (for opsets 10, 11) and just adding support for opset 1, then we don't need to check for the "largest" attribute here. It's only present in the later version 11 of the op.

So we can remove this from the config and node.

Copy link
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

sure thing will remove and resubmit tonight

@@ -726,20 +726,53 @@ where
}

/// Returns the `k` largest elements of the given input tensor along a given dimension.
pub fn topk(self, k: usize, dim: usize) -> Tensor<B, D, K> {
pub fn topk(self, k: usize, dim: usize, largest: Option<usize>) -> Tensor<B, D, K> {
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Since we're only adding support for opset 1, the "largest" option is not necessary. We can remove the changes to the tensor API.

In any case, it would have been preferable to add it as another method, similar to sort & sort_descending. Something like topk and topk_smallest. But not required for this PR 🙂

Copy link
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

sounds good will remove the tensor api changes as well

@@ -107,6 +107,7 @@ fn main() {
.input("tests/sum/sum_int.onnx")
.input("tests/tanh/tanh.onnx")
.input("tests/tile/tile.onnx")
.input("tests/top_k/top_k.onnx")
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Looks the CI caught something I missed! This file doesn't exist anymore with your changes :)

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

Successfully merging this pull request may close these issues.

2 participants