
Port mistral transformer checkpoint #1768

Merged: 9 commits into keras-team:master on Aug 21, 2024

Conversation

cosmo3769 (Contributor)

Hi @mattdangerw @ariG23498,

Ported the Mistral transformers checkpoint to KerasNLP. Please check. Thank you!

mattdangerw (Member) left a comment

Looks good overall! One comment.

Could you include a small Colab showing generation, just to verify this is working? Since we don't have numerics validation yet.
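A sketch of the kind of quick generation check being requested, assuming a build that includes this PR's converter (the prompt and length are arbitrary):

```python
import keras_nlp

# Load the HF checkpoint through the new converter and sample some text.
model = keras_nlp.models.MistralCausalLM.from_preset(
    "hf://mistralai/Mistral-7B-v0.1"
)
print(model.generate("What is Keras?", max_length=30))
```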

"rope_max_wavelength": transformers_config["rope_theta"],
"layer_norm_epsilon": transformers_config["rms_norm_eps"],
"sliding_window": transformers_config["sliding_window"],
"dtype": transformers_config["torch_dtype"],
mattdangerw (Member)

I don't think we should convert dtype. We don't for other models.

We will create a backbone with the default Keras floating-point type unless the user supplies their own dtype argument. But we don't restore the saved dtype policy by default.
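In other words, the checkpoint's saved dtype is ignored at load time unless the caller asks for one; a minimal sketch, assuming the standard from_preset kwargs:

```python
import keras_nlp

# Loads in Keras' default float type (usually float32)...
model = keras_nlp.models.MistralBackbone.from_preset(
    "hf://mistralai/Mistral-7B-v0.1"
)

# ...unless the user explicitly supplies a dtype of their own:
model_bf16 = keras_nlp.models.MistralBackbone.from_preset(
    "hf://mistralai/Mistral-7B-v0.1", dtype="bfloat16"
)
```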

cosmo3769 (Contributor, Author)

Without dtype conversion, I am getting an error:

```
DTypePromotionError: The DTypes <class 'numpy.dtypes.Float16DType'> and <class 'numpy.dtype[bfloat16]'> do not have a common DType. For example they cannot be stored in a single array unless the dtype is object.
```
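The same promotion failure can be reproduced in plain NumPy; a minimal sketch, assuming the ml_dtypes package that provides the bfloat16 NumPy dtype:

```python
import numpy as np
import ml_dtypes  # registers bfloat16 as a NumPy dtype

a = np.array([1.0], dtype=np.float16)
b = np.array([1.0], dtype=ml_dtypes.bfloat16)

# float16 and bfloat16 have no common promotion dtype in NumPy:
np.concatenate([a, b])  # raises DTypePromotionError
```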

mattdangerw (Member)

Interesting. I think this is something we will have to solve during weight conversion, and not by sticking this value in the config. I will take a look.

ariG23498 (Collaborator)

@mattdangerw do you think this would be a better place to keep the check?
https://github.com/keras-team/keras-nlp/blob/f80fbfd0eaeee7a9e63a4c98a81ff8aba5506f3e/keras_nlp/src/utils/transformers/safetensor_utils.py#L97

We can check whether the dtypes match there; if a conversion is needed, warn the user that a type conversion is happening at this stage to port the weights, and then continue?
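A hypothetical sketch of that check; the helper name and its arguments are illustrative, not the actual keras-nlp code:

```python
import warnings
import numpy as np

def assign_hf_weight(keras_variable, hf_tensor):
    """Illustrative helper: warn on a dtype mismatch, then assign."""
    if np.dtype(keras_variable.dtype) != hf_tensor.dtype:
        warnings.warn(
            f"Porting weight from {hf_tensor.dtype} to "
            f"{keras_variable.dtype}."
        )
    keras_variable.assign(hf_tensor)  # assign casts to the variable's dtype
```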

cosmo3769 (Contributor, Author)

@ariG23498 Makes sense. 💡

mattdangerw (Member)

I would actually think any type conversion should happen inside the assign call. I tried removing this dtype line and could not reproduce the error. Is this only on a specific backend?

https://github.com/keras-team/keras/blob/413f859d892394f584fcdd61b41d13e5999242a3/keras/src/backend/common/variables.py#L224

I don't think we need to warn that type conversion is happening. Loading a half-precision save at full precision, or vice versa, is quite common.
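To illustrate the point about assign, a minimal standalone sketch with Keras 3 (not the converter code itself):

```python
import numpy as np
import keras

# A float32 variable accepts a float16 value; assign casts it to the
# variable's own dtype, so no explicit conversion is needed upstream.
v = keras.Variable(np.zeros((2,), dtype="float32"))
v.assign(np.ones((2,), dtype="float16"))
print(v.dtype)  # float32
```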

mattdangerw (Member)

So we can just remove this line right? I'll give that a try, and land if things look good.

cosmo3769 (Contributor, Author)

Yeah, I checked it now by removing this line. It works.

cosmo3769 (Contributor, Author)

> Could you include a small Colab showing generation, just to verify this is working? Since we don't have numerics validation yet.

Working demo Colab link

```python
class TestTask(TestCase):
    @pytest.mark.large
    def test_convert_tiny_preset(self):
        model = MistralCausalLM.from_preset("hf://mistralai/Mistral-7B-v0.1")
```
mattdangerw (Member)

This is too big to run in our automated testing regularly. @ariG23498 can you detail what you did to make hf://ariG23498/tiny-gemma-test?

ariG23498 (Collaborator)

Here is detailed code to build a small test model and upload it to the Hub.

@cosmo3769 could you take a look at it?
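A sketch of one way to build such a tiny random-weight Mistral model and push it to the Hugging Face Hub (the repo name is illustrative; assumes the transformers library and a huggingface_hub login):

```python
from transformers import AutoTokenizer, MistralConfig, MistralForCausalLM

# Shrink every dimension so the checkpoint is a few megabytes at most.
config = MistralConfig(
    hidden_size=8,
    intermediate_size=16,
    num_hidden_layers=2,
    num_attention_heads=2,
    num_key_value_heads=1,
)
model = MistralForCausalLM(config)  # randomly initialized tiny model
tokenizer = AutoTokenizer.from_pretrained("mistralai/Mistral-7B-v0.1")

# "your-username/tiny-mistral-test" is a placeholder repo name.
model.push_to_hub("your-username/tiny-mistral-test")
tokenizer.push_to_hub("your-username/tiny-mistral-test")
```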

cosmo3769 (Contributor, Author)

@ariG23498 Sure.

mattdangerw (Member)

The main thing we need before we merge is a smaller test case. Left a comment on the big thread though; still not sure exactly where things are breaking if you remove dtype from the config.

cosmo3769 (Contributor, Author)

Added tiny-mistral test.
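Presumably the test now targets a tiny checkpoint instead of the 7B model; a sketch of what that looks like, mirroring the earlier snippet (the Hub repo name here is a hypothetical placeholder, not confirmed by the thread):

```python
class TestTask(TestCase):
    @pytest.mark.large
    def test_convert_tiny_preset(self):
        # Placeholder repo name for a tiny randomly initialized checkpoint.
        model = MistralCausalLM.from_preset("hf://cosmo3769/tiny-mistral-test")
        model.generate("What is Keras?", max_length=30)
```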

mattdangerw (Member) left a comment

LGTM! Will merge after the tests run.

cosmo3769 (Contributor, Author)

Resolved the merge conflict.

mattdangerw added the kokoro:force-run (Runs Tests on GPU) label on Aug 21, 2024
kokoro-team removed the kokoro:force-run (Runs Tests on GPU) label on Aug 21, 2024
mattdangerw (Member)

The JAX failure is from #1783, but this one looks good. Pulling this in!

mattdangerw merged commit 081e4c8 into keras-team:master on Aug 21, 2024
9 of 10 checks passed
pkgoogle pushed a commit to pkgoogle/keras-hub that referenced this pull request on Aug 22, 2024:

* ported mistral
* update test
* fix config
* fix typo
* switched float32 to float16
* tiny-mistral-test
* removed dtype config