
Refactor func load_model to class ModelLoader #1909

Merged: 11 commits merged into axolotl-ai-cloud:main on Oct 25, 2024

Conversation

@MengqingCao (Contributor) commented Sep 12, 2024

Description

part of #1758

This PR refactors the func load_model in src/axolotl/utils/models.py into a class ModelLoader. The member functions of ModelLoader are separated according to their features, and all member vars of ModelLoader are shared across these funcs. Moreover, this refactoring makes the model-loading pipeline clearer.

TODO:

  • add UT for ModelLoader

The main changes are listed here:

  • organize common vars into member vars of class ModelLoader
  • split operations in load_model into separate member funcs
  • refactor cfg.load_in_Xbit to kwarg

The UML of ModelLoader:

[UML diagram of ModelLoader]
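As a rough sketch of the structure described above (the method and attribute names here are illustrative assumptions, not the exact code in this PR):

# Illustrative sketch only: names are assumptions based on the PR description,
# not the actual implementation in src/axolotl/utils/models.py.
class ModelLoader:
    def __init__(self, cfg, tokenizer, *, inference=False):
        # common vars that load_model used to pass around become member vars
        self.cfg = cfg
        self.tokenizer = tokenizer
        self.inference = inference
        self.model = None
        self.model_kwargs = {}  # quantization flags etc. are collected here

    def set_quantization_config(self):
        # cfg.load_in_Xbit is refactored into kwargs for from_pretrained
        if self.cfg.load_in_8bit:
            self.model_kwargs["load_in_8bit"] = True
        elif self.cfg.load_in_4bit:
            self.model_kwargs["load_in_4bit"] = True

    def build_model(self):
        # the actual from_pretrained call, split out of the old monolithic func
        ...

    def load(self):
        # the old load_model body becomes a short, readable pipeline
        self.set_quantization_config()
        self.build_model()
        return self.model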

Motivation and Context

Why is this change required?

As the models loaded in Axolotl support more and more features, the func load_model has become huge. This results in confusion about variable changes when abstracting parts of load_model (#1758 (review)). Refactoring load_model will improve the code structure and facilitate stable evolution as more features are introduced in the future.

How has this been tested?

  1. part of the UTs for ModelLoader has been added and the tests pass
  2. I tested fine-tuning and inference (both on the terminal and the gradio webui) with the open_llama_3b_v2 model; here is a screenshot of inference:

[screenshot: GPU inference output]

However, I don't have access to an Ampere or newer GPU, so I cannot run the full UT suite on my local machine. It would be nice if all UTs could be run on CI.

@winglian (Collaborator) commented

@MengqingCao this is on our list to tackle this week to get merged in. We'll need to get this rebased.

@MengqingCao (Contributor Author) commented Oct 14, 2024 via email

@MengqingCao (Contributor Author) commented

@winglian sorry for the slight delay. The rebase is done now, please review it.

BTW, I basically copied the original code to make the review a little easier.
In the future, more if-else branches could be refactored step by step, which will require more models and tests under different configuration conditions.

@MengqingCao (Contributor Author) commented Oct 17, 2024 via email

@MengqingCao (Contributor Author) commented

Hi @winglian, the code has been updated, please retrigger the CI, thanks!

@NanoCode012 self-assigned this Oct 18, 2024
@MengqingCao (Contributor Author) commented Oct 18, 2024

I'm confused why the tests fail on tests/test_prompt_tokenizers.py and tests/test_validation.py, because everything runs fine on my machine. Could you give me some advice? @NanoCode012


@NanoCode012 (Collaborator) commented

@MengqingCao, from the tests, it may be erroring due to the following:

ModuleNotFoundError: No module named 'flash_attn'

Let me see what should be done in a bit.

Comment on lines 116 to 140
@pytest.mark.parametrize("embedding_modules", ["embed_tokens", "lm_head"])
@pytest.mark.parametrize(
    "dist_dtype", [torch.bfloat16, torch.float16, torch.float32]
)
@pytest.mark.parametrize("before_kbit_train_or_finetune", [True, False])
def test_convert_embedding_modules_dtype(
    self, embedding_modules, dist_dtype, before_kbit_train_or_finetune
):
    tokenizer = load_tokenizer(self.cfg)
    self.model_loader.model, _ = load_model(self.cfg, tokenizer, inference=False)

    self.model_loader.convert_embedding_modules_dtype(
        embedding_modules, dist_dtype, before_kbit_train_or_finetune
    )
    for name, module in self.model_loader.model.named_modules():
        if (
            "norm" in name
            or (before_kbit_train_or_finetune and name.endswith(".gate"))
            or (
                any(m in name for m in embedding_modules)
                and hasattr(module, "weight")
            )
        ):
            for _, param in module.named_parameters(recurse=False):
                assert param.dtype == dist_dtype
Collaborator

let's move this one to its own e2e/ test that runs on a GPU instance. I believe it's OOMing

Collaborator

or let's use a config fixture that uses a much smaller model like a 68M parameter model
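A possible shape for such a fixture (the model name and config keys below are illustrative assumptions; any tiny causal LM would do):

import pytest

from axolotl.utils.dict import DictDefault


@pytest.fixture
def small_model_cfg():
    # Hypothetical fixture: a ~68M parameter model such as "JackFram/llama-68m"
    # keeps the test small enough to run without an Ampere-class GPU.
    return DictDefault(
        {
            "base_model": "JackFram/llama-68m",
            "tokenizer_type": "LlamaTokenizer",
            "sequence_len": 256,
            "load_in_8bit": False,
        }
    )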

Contributor Author

@winglian @NanoCode012 Thanks for your help. I have moved it to e2e now. Please retrigger the CI to check if this fix works.

@MengqingCao (Contributor Author) commented

@winglian I spent some time fixing the failed UTs and found that load_cfg breaks the caplog, which causes these UTs to fail. The latest UTs just use DictDefault to create cfg to fix it. Could you please retrigger the CI again to verify the current code?
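Roughly, the fix amounts to building the test config directly instead of going through load_cfg (the keys and values below are illustrative):

from axolotl.utils.dict import DictDefault

# Instead of cfg = load_cfg("path/to/config.yml"), which the author found
# breaks pytest's caplog (possibly via check_remote_config reaching the Hub),
# the UTs now build the config dict directly:
cfg = DictDefault(
    {
        "base_model": "openlm-research/open_llama_3b_v2",  # example value
        "sequence_len": 1024,
        "load_in_8bit": False,
    }
)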

@NanoCode012 (Collaborator) commented

> @winglian I spent some time fixing the failed UTs and found that load_cfg breaks the caplog, which causes these UTs to fail. The latest UTs just use DictDefault to create cfg to fix it. Could you please retrigger the CI again to verify the current code?

That's a nice catch. I was debugging it yesterday and couldn't figure out exactly why it failed when all the tests are run together. I suspected caplog, but when I tried using capsys, it failed too.

I re-triggered the CI, and they are passing so far.

@MengqingCao (Contributor Author) commented

> That's a nice catch. I was debugging it yesterday and couldn't figure out exactly why it failed when all the tests are run together. I suspected caplog, but when I tried using capsys, it failed too.

The cause is well hidden, and it fails from the moment load_cfg is imported. I guess the failure may be caused by accessing resources on the Hub when calling check_remote_config, but unfortunately I can't be sure.

@MengqingCao (Contributor Author) commented Oct 22, 2024 via email

@MengqingCao (Contributor Author) commented

@winglian @NanoCode012 Thanks a lot for your work! All UTs in test_load_model.py pass now.

Since quantized parameters cannot be converted to another dtype simply via .to(dist_dtype), and this has nothing to do with the correctness of convert_embedding_modules_dtype, I turned off load_in_Xbit to make the UT pass.

I should have used a small model from the beginning so that I could test it locally instead of hitting OOM...
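For reference, a test config along these lines would run the dtype-conversion test without quantized loading (the model name and keys are illustrative assumptions):

from axolotl.utils.dict import DictDefault

# Quantized (e.g. bitsandbytes) parameters cannot simply be cast with .to(dtype),
# so the embedding-dtype test loads the model unquantized.
cfg = DictDefault(
    {
        "base_model": "JackFram/llama-68m",  # small model, per review suggestion
        "load_in_8bit": False,
        "load_in_4bit": False,
    }
)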

@winglian merged commit 1d6a5e2 into axolotl-ai-cloud:main on Oct 25, 2024
12 checks passed