Hello, and thanks for the idea.
-
I noticed that when working with a bfloat16 model, conversion between torch and flax would raise errors, since numpy has no bfloat16 dtype to round-trip through.
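A minimal sketch of the kind of failure I mean (my assumption is that the torch → numpy → jax round trip is where bfloat16 breaks):

```python
# Minimal sketch of the failure mode (assumption: the torch -> numpy -> jax
# round trip is where bfloat16 breaks, because numpy has no bfloat16 dtype).
import torch
import jax.numpy as jnp

t = torch.zeros(4, dtype=torch.bfloat16)
try:
    t.numpy()  # typically raises TypeError: Got unsupported ScalarType BFloat16
except TypeError as e:
    print(e)

# One possible workaround: go through float32 first, then cast back to
# bfloat16 on the jax side.
arr = jnp.asarray(t.to(torch.float32).numpy()).astype(jnp.bfloat16)
print(arr.dtype)  # bfloat16
```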
The following commit tries to implement this conversion for Mistral, in both directions.
main...yhavinga:EasyDeL:mistral_bfloat16_conversions
The test python_test/mistral_torch_bf16_conversions.py runs these conversions, then compares the output of the two torch models and the flax model.
In both float32 and bfloat16 mode there are differences; the bf16 differences are larger than the fp32 ones.
Since the fp32 numbers are really close, I'm inclined to believe the model implementations are correct, and that the output differences come from differences in how the algorithms are implemented, causing slightly different values at the highest precision that then compound.
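For reference, the comparison boils down to something like this (a hypothetical helper, not the actual code in python_test/mistral_torch_bf16_conversions.py):

```python
# Hypothetical comparison helper: upcast both outputs to float32 before
# diffing, so the comparison itself does not add rounding noise.
import numpy as np
import torch

def max_abs_diff(torch_logits: torch.Tensor, flax_logits) -> float:
    a = torch_logits.detach().to(torch.float32).cpu().numpy()
    b = np.asarray(flax_logits, dtype=np.float32)
    return float(np.max(np.abs(a - b)))
```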
I am not sure if this commit (conversion of bfloat16 parameters) is worth pursuing. Maybe in the future numpy will support bfloat16 conversions, rendering this code useless. And it's not a general solution for all models. Maybe a more general approach for model conversions would be: always upcast to fp32, convert between torch and flax, then downcast to the datatype the model was trained in.
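Roughly, that general approach would look like this (a sketch with made-up helper names, not part of the linked commit, and without the parameter-tree handling EasyDeL would actually need):

```python
# Sketch of the "always go through fp32" idea; helper names are made up.
import numpy as np
import torch
import jax.numpy as jnp

def torch_param_to_flax(tensor: torch.Tensor, target_dtype=jnp.bfloat16):
    # Upcast to float32 so numpy can hold the buffer, hand it to jax,
    # then downcast to the dtype the model was trained in.
    as_fp32 = tensor.detach().to(torch.float32).cpu().numpy()
    return jnp.asarray(as_fp32).astype(target_dtype)

def flax_param_to_torch(array, target_dtype=torch.bfloat16):
    # Same idea in the other direction: float32 is the common ground.
    as_fp32 = np.array(jnp.asarray(array, dtype=jnp.float32))
    return torch.from_numpy(as_fp32).to(target_dtype)
```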