-
Notifications
You must be signed in to change notification settings - Fork 3.5k
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Change const to used dtype if it is passed in #7285
Conversation
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I'd like to see one change to the unit test, otherwise looks great, thank you so much for figuring this out.
tests/python/relay/test_const.py
Outdated
|
||
def test_const_dtype(): | ||
strides = (1, 1) | ||
np_array = np.array(strides) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Maybe explicitly cast this to int32 before making the const to ensure we'd hit the error in linux with the old behavior?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I was just hoping that windows would make int32 but your suggestion is much better.
* Add fix and unit test for const autoconvert dtype. * formatting * Address review comment, casting input value to int32 * Fix failing test * Augment unit test
* Add fix and unit test for const autoconvert dtype. * formatting * Address review comment, casting input value to int32 * Fix failing test * Augment unit test
* Add fix and unit test for const autoconvert dtype. * formatting * Address review comment, casting input value to int32 * Fix failing test * Augment unit test
* Add fix and unit test for const autoconvert dtype. * formatting * Address review comment, casting input value to int32 * Fix failing test * Augment unit test
If dtype is passed then the value is autoconverted to dtype. Windows defaults python integers to int32 while linux defaults to int64. This causes issues in the onnx frontend where constants are created with dtype=int64 but on Windows they would result in int32 constants and break the importer.
https://discuss.tvm.apache.org/t/onnx-frontend-giving-an-error-when-working-a-simple-model-with-tvm/8742