
model loading #32

Closed
JinSeoungwoo opened this issue Oct 7, 2023 · 17 comments

Comments

@JinSeoungwoo

JinSeoungwoo commented Oct 7, 2023

I received the file Mistral-1.566301941871643-69 at the end of the model's training. Is there a way to convert this save file to model.bin, or to load it onto a TPU to check that it works?

Thank you for the support!

JinSeoungwoo changed the title from "learning rate decay does not decrease as stated" to "gradient checkpointing and model loading" on Oct 8, 2023
JinSeoungwoo changed the title from "gradient checkpointing and model loading" to "model loading" on Oct 8, 2023
@erfanzar
Owner

erfanzar commented Oct 8, 2023

Yes, you can use JAXServer for that.

@erfanzar
Owner

erfanzar commented Oct 8, 2023

link to JaxServer Docs

@JinSeoungwoo
Author

JinSeoungwoo commented Oct 8, 2023

Additionally, can you let me know how to convert the Flax model into pytorch_model.bin?

@JinSeoungwoo
Author

JinSeoungwoo commented Oct 9, 2023

Yes, you can use JAXServer for that.

I got an error with the code below:

# Imports assumed from the module paths in the traceback below; adjust to your EasyDel version.
from transformers import AutoTokenizer
from EasyDel.serve.serve_utils import JAXServer
from EasyDel.modules.mistral.modelling_mistral_flax import FlaxMistralForCausalLM, MistralConfig

config = MistralConfig(rotary_type="complex")
model = FlaxMistralForCausalLM(config, _do_init=False)

tokenizer = AutoTokenizer.from_pretrained('mistralai/Mistral-7B-v0.1', model_max_length=4096, padding_side="left", add_eos_token=True)
tokenizer.pad_token = tokenizer.eos_token

server = JAXServer.load(
    path='/my/ckpt-path/Mistral-Test',
    model=model,
    tokenizer=tokenizer,
    config_model=config,
    add_params_field=True,
    config=None,
    init_shape=(1, 1)
)

Error:

Traceback (most recent call last):
  File "test.py", line 12, in <module>
    server = JAXServer.load(
  File "/usr/local/lib/python3.8/dist-packages/EasyDel/serve/serve_utils.py", line 414, in load
    server.compile(verbose=verbose)
  File "/usr/local/lib/python3.8/dist-packages/EasyDel/serve/serve_utils.py", line 475, in compile
    for r, a in self.process(
  File "/usr/local/lib/python3.8/dist-packages/EasyDel/serve/serve_utils.py", line 615, in process
    predicted_token = self.greedy_generate(
  File "/usr/local/lib/python3.8/dist-packages/EasyDel/serve/serve_utils.py", line 506, in greedy_generate
    return self._greedy_generate(
  File "/usr/local/lib/python3.8/dist-packages/EasyDel/serve/serve_utils.py", line 286, in greedy_generate
    predict = model.generate(
  File "/usr/local/lib/python3.8/dist-packages/transformers/generation/flax_utils.py", line 417, in generate
    return self._greedy_search(
  File "/usr/local/lib/python3.8/dist-packages/transformers/generation/flax_utils.py", line 636, in _greedy_search
    state = greedy_search_body_fn(state)
  File "/usr/local/lib/python3.8/dist-packages/transformers/generation/flax_utils.py", line 612, in greedy_search_body_fn
    model_outputs = model(state.running_token, params=params, **state.model_kwargs)
  File "/usr/local/lib/python3.8/dist-packages/EasyDel/modules/mistral/modelling_mistral_flax.py", line 561, in __call__
    outputs = self.module.apply(
  File "/usr/local/lib/python3.8/dist-packages/EasyDel/modules/mistral/modelling_mistral_flax.py", line 784, in __call__
    outputs = self.model(
  File "/usr/local/lib/python3.8/dist-packages/EasyDel/modules/mistral/modelling_mistral_flax.py", line 708, in __call__
    outputs = self.layers(
  File "/usr/local/lib/python3.8/dist-packages/EasyDel/modules/mistral/modelling_mistral_flax.py", line 638, in __call__
    output = layer(
  File "/usr/local/lib/python3.8/dist-packages/flax/linen/partitioning.py", line 553, in inner
    return rematted(variable_groups, rng_groups, *dyn_args)
  File "/usr/local/lib/python3.8/dist-packages/flax/linen/partitioning.py", line 550, in rematted
    y = fn(scope, *args)
  File "/usr/local/lib/python3.8/dist-packages/EasyDel/modules/mistral/modelling_mistral_flax.py", line 432, in __call__
    attention_output = self.self_attn(
  File "/usr/local/lib/python3.8/dist-packages/EasyDel/modules/mistral/modelling_mistral_flax.py", line 347, in __call__
    q, k, v, attention_mask = self.concatenate_to_cache_(q, k, v, attention_mask)
  File "/usr/local/lib/python3.8/dist-packages/EasyDel/modules/mistral/modelling_mistral_flax.py", line 316, in concatenate_to_cache_
    attention_mask = nn.combine_masks(pad_mask, attention_mask)
  File "/usr/local/lib/python3.8/dist-packages/flax/linen/attention.py", line 506, in combine_masks
    assert all(
AssertionError: masks must have same rank: (4, 2)

@erfanzar
Owner

erfanzar commented Oct 9, 2023

Additionally, can you let me know how to convert the Flax model into pytorch_model.bin?

Use mistral_flax_to_pt in the transform functions (I'm not sure I got the name of that function exactly right ;\ ).
Right now the Mistral models have a computation problem that I'm trying to fix as soon as I can.

erfanzar added a commit that referenced this issue Oct 9, 2023
@JinSeoungwoo
Author

Additionally, can you let me know how to convert the Flax model into pytorch_model.bin?

Use mistral_flax_to_pt in the transform functions (I'm not sure I got the name of that function exactly right ;\ ). Right now the Mistral models have a computation problem that I'm trying to fix as soon as I can.

By any chance, could you show an example of loading a Flax model and applying the mistral_flax_to_pt function? I'm having trouble since I'm not familiar with Flax... Sorry for the many requests.

@erfanzar
Owner

erfanzar commented Oct 9, 2023

Additionally, can you let me know how to convert the Flax model into pytorch_model.bin?

Use mistral_flax_to_pt in the transform functions (I'm not sure I got the name of that function exactly right ;\ ). Right now the Mistral models have a computation problem that I'm trying to fix as soon as I can.

By any chance, could you show an example of loading a Flax model and applying the mistral_flax_to_pt function? I'm having trouble since I'm not familiar with Flax... Sorry for the many requests.

Use mistral_convert_flax_to_pt.
And that's fine, you can ask any question you want; I'm here to help <3

@JinSeoungwoo
Author

JinSeoungwoo commented Oct 10, 2023

Is this code right for converting the checkpoint to pytorch_model.bin?

# Imports assumed from the libraries referenced in this thread; adjust to your versions.
import torch
from flax.traverse_util import flatten_dict
from fjutils import StreamingCheckpointer  # FJUtils checkpointer (assumed import path)
from EasyDel.transform import mistral_convert_flax_to_pt  # assumed path for the transform functions
from EasyDel import MistralConfig

_, flax_params = StreamingCheckpointer.load_trainstate_checkpoint(path)  # path: your training checkpoint
flax_params = flatten_dict(flax_params['params'], sep='.')
pytorch_state_dict = mistral_convert_flax_to_pt(flax_params, MistralConfig())
torch.save(pytorch_state_dict, 'pytorch_model.bin')

I'm wondering if it's OK to use a default MistralConfig() as the config for mistral_convert_flax_to_pt.

@erfanzar
Owner

Yes, the code is correct, but use the config of the model you want to convert from EasyDel to torch/HF, because values such as num_hidden_layers and so on are taken from the given config.
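
For instance, a minimal sketch reusing flax_params from the snippet above. The values shown are the published Mistral-7B hyperparameters and are only illustrative, and the parameter names follow the Hugging Face MistralConfig convention, so double-check them against EasyDel's MistralConfig:

# Illustrative values only; take these from the config you actually trained with.
config = MistralConfig(
    vocab_size=32000,
    hidden_size=4096,
    intermediate_size=14336,
    num_hidden_layers=32,
    num_attention_heads=32,
    num_key_value_heads=8,
)
pytorch_state_dict = mistral_convert_flax_to_pt(flax_params, config)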

@JinSeoungwoo
Author

Yes, the code is correct, but use the config of the model you want to convert from EasyDel to torch/HF, because values such as num_hidden_layers and so on are taken from the given config.

Hmm... so I can't just use the MistralConfig class as-is? Also, if I have to use a custom config, can I use the one I used for training?

@erfanzar
Owner

Yes, you should use the config that was used to train the model (it's saved in the W&B project if you don't remember it).
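
If the training config was logged to Weights & Biases, one rough way to pull it back down is sketched below. The run path and the key names are placeholders, and this assumes the trainer logged the model hyperparameters into run.config:

# Sketch only: the run path and logged key names are hypothetical; check what your trainer actually logged.
import wandb

api = wandb.Api()
run = api.run("my-entity/my-project/my-run-id")  # placeholder entity/project/run_id
train_config = run.config  # dict of the hyperparameters logged during training

config = MistralConfig(
    num_hidden_layers=train_config["num_hidden_layers"],
    hidden_size=train_config["hidden_size"],
    num_attention_heads=train_config["num_attention_heads"],
)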

@JinSeoungwoo
Author

Is full fine-tuning the only fine-tuning method currently supported?

@JinSeoungwoo
Author

Also, are there any plans to add a linear scheduler with warm-up steps?

@JinSeoungwoo
Author

JinSeoungwoo commented Oct 11, 2023

Also, are there any plans to add a linear scheduler with warm-up steps?

I just made a warmup_linear scheduler function:

from typing import Optional

import chex
import optax


def get_adamw_with_warmup_linear_scheduler(
        steps: int,
        learning_rate_start: float = 5e-5,
        learning_rate_end: float = 1e-5,
        b1: float = 0.9,
        b2: float = 0.999,
        eps: float = 1e-8,
        eps_root: float = 0.0,
        weight_decay: float = 1e-1,
        gradient_accumulation_steps: int = 1,
        mu_dtype: Optional[chex.ArrayDType] = None,
        warmup_steps: int = 500
):
    """
    AdamW optimizer with a linear warm-up followed by linear decay.

    :param steps: total number of training steps
    :param learning_rate_start: peak learning rate reached after warm-up
    :param learning_rate_end: final learning rate at the end of training
    :param b1:
    :param b2:
    :param eps:
    :param eps_root:
    :param weight_decay:
    :param gradient_accumulation_steps:
    :param mu_dtype:
    :param warmup_steps: number of steps for the warm-up phase
    :return: optimizer (optax.GradientTransformation) and the combined schedule
    """
    # Linear warm-up from a tiny value up to learning_rate_start, then linear decay to learning_rate_end.
    scheduler_warmup = optax.linear_schedule(
        init_value=5e-8, end_value=learning_rate_start, transition_steps=warmup_steps
    )
    scheduler_decay = optax.linear_schedule(
        init_value=learning_rate_start, end_value=learning_rate_end, transition_steps=steps - warmup_steps
    )
    scheduler_combined = optax.join_schedules(
        schedules=[scheduler_warmup, scheduler_decay], boundaries=[warmup_steps]
    )

    tx = optax.chain(
        optax.scale_by_adam(
            b1=b1,
            b2=b2,
            eps=eps,
            eps_root=eps_root,
            mu_dtype=mu_dtype
        ),
        optax.add_decayed_weights(
            weight_decay=weight_decay
        ),
        optax.scale_by_schedule(scheduler_combined),
        optax.scale(-1)
    )
    if gradient_accumulation_steps > 1:
        tx = optax.MultiSteps(
            tx, gradient_accumulation_steps
        )
    return tx, scheduler_combined
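
A quick usage sketch for the function above (the step counts are just placeholder numbers):

# Placeholder step counts; pass the optimizer `tx` to your training loop as usual.
tx, scheduler = get_adamw_with_warmup_linear_scheduler(
    steps=10_000,
    warmup_steps=500,
    gradient_accumulation_steps=1,
)

# The schedule is callable, so the learning rate at a given step can be inspected directly:
print(scheduler(0))       # 5e-8 at the start of warm-up
print(scheduler(500))     # 5e-5 (learning_rate_start) at the end of warm-up
print(scheduler(10_000))  # 1e-5 (learning_rate_end) at the end of training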

@erfanzar
Owner

Is full fine-tuning the only fine-tuning method currently supported?

RLHF is also supported for fine-tuning, but right now only for LLaMA-1, Falcon, and MPT models.

@erfanzar
Owner

Also, are there any plans to add a linear scheduler with warm-up steps?

I'll create that for you in the next update on the main branch.

@erfanzar
Owner

Also, are there any plans to add a linear scheduler with warm-up steps?

I just made a warmup_linear scheduler function: [scheduler code quoted above]

Thanks!
Update FJUtils to 0.0.20 and reinstall EasyDel, and then you can use warm_up_linear.
