-
Notifications
You must be signed in to change notification settings - Fork 1.9k
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
[core] MatMul broadcasting. #4100
Conversation
hey @annxingyuan, looks like tests don't pass for this on node do you want to take a look at that before we review? |
yes @tafsiri sorry i thought it was flaking at first - removed reviewers and will re-add when ready |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Reviewed 7 of 8 files at r1, 5 of 5 files at r2.
Reviewable status: 0 of 1 approvals obtained (waiting on @annxingyuan, @pyu10055, and @tafsiri)
tfjs-core/src/ops/mat_mul.ts, line 89 at r2 (raw file):
const outShapeOuterDims = batchDimA > batchDimB ? $a.shape.slice(0, -2) : $b.shape.slice(0, -2);
can these be replaced with variables: outerDimsA and outerDimsB?
tfjs-node/src/run_tests.ts, line 84 at r2 (raw file):
'maxPoolWithArgmax', 'rotate', 'flipLeftRight', 'unique', // libtensorflow does not yet support tf.matmul with broadcast 'broadcast with unequal batch dims', 'broadcast with unequal ranks'
why tests for tfjs-node need to be skipped?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Reviewable status: 0 of 1 approvals obtained (waiting on @pyu10055 and @tafsiri)
tfjs-core/src/ops/mat_mul.ts, line 89 at r2 (raw file):
Previously, pyu10055 (Ping Yu) wrote…
can these be replaced with variables: outerDimsA and outerDimsB?
Done
tfjs-node/src/run_tests.ts, line 84 at r2 (raw file):
Previously, pyu10055 (Ping Yu) wrote…
why tests for tfjs-node need to be skipped?
Broadcasting support was recently added in tensorflow 2.0 - libtensorflow does not yet support it.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Reviewable status: complete! 1 of 1 approvals obtained (waiting on @pyu10055 and @tafsiri)
To see the logs from the Cloud Build CI, please join either our discussion or announcement mailing list.
This change is