Out of curiosity, I wanted to test the new omni-staging on the code I have developed, and I noticed a new error that did not occur before:
```
TypeError: No constant handler for type: <enum 'CCADimensions'>
```
I am using this IntEnum to index vectors by name inside some jit-compiled functions, and I suspect the error is related to the different way omni-staging handles constants.
I could work around the problem using some internal JAX functions:
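```python
from functools import partial

import numpy as np
from jax.lib.xla_bridge import _python_scalar_handler, register_constant_handler

# Private JAX internals: treat CCADimensions members like Python int scalars (int32).
register_constant_handler(CCADimensions, partial(_python_scalar_handler, np.dtype(np.int32)))
```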
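A possible alternative that avoids private internals, sketched here as an assumption rather than the fix adopted upstream: convert each enum member with `int()` before using it as an index, so only built-in Python ints reach the traced code. The members `X`, `Y`, `Z` and the function `sum_xy` are hypothetical, since the issue does not show the real `CCADimensions` definition.

```python
from enum import IntEnum

import jax
import jax.numpy as jnp


class CCADimensions(IntEnum):
    # Hypothetical members; the real enum is not shown in the issue.
    X = 0
    Y = 1
    Z = 2


@jax.jit
def sum_xy(v):
    # int(...) turns each IntEnum member into a plain Python int, so JAX
    # never has to handle the enum type itself as a constant.
    return v[int(CCADimensions.X)] + v[int(CCADimensions.Y)]


v = jnp.array([1.0, 2.0, 3.0])
result = sum_xy(v)
```

The trade-off is readability at each call site; the enum still names the dimensions, but every index expression needs the explicit conversion.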
Is there a plan to fix support for IntEnum constants? Or to add a helper method that lets users register custom IntEnum types as constants?
Sorry for not following up on this before! It came up again in #6129, and I think #6130 should fix it. (If it doesn't, please share a minimal repro where it fails, though I'm pretty sure it fixes the issue!)