model loading #32
Yes, you can use JAXServer for that.

Additionally, can you let me know how to convert a flax model into pytorch_model.bin?

I got an error with the code below:

# NOTE: import paths are assumed here; adjust them to your EasyDel version
from transformers import AutoTokenizer
from EasyDel import FlaxMistralForCausalLM, MistralConfig, JAXServer

config = MistralConfig(rotary_type="complex")
model = FlaxMistralForCausalLM(config, _do_init=False)
tokenizer = AutoTokenizer.from_pretrained(
    'mistralai/Mistral-7B-v0.1',
    model_max_length=4096,
    padding_side="left",
    add_eos_token=True
)
tokenizer.pad_token = tokenizer.eos_token
server = JAXServer.load(
    path='/my/ckpt-path/Mistral-Test',
    model=model,
    tokenizer=tokenizer,
    config_model=config,
    add_params_field=True,
    config=None,
    init_shape=(1, 1)
)

Error:
use mistral_flax_to_pt from the transform functions (I don't know if I said the right name for that func ;\ )

By any chance, could you show an example of loading a flax model and applying the mistral_flax_to_pt function? I'm having trouble since I'm not familiar with flax... Sorry for the many requests.

use

Is this code right for converting a ckpt to pytorch_model.bin?

# NOTE: import paths are assumed here; adjust them to your EasyDel version
import torch
from flax.traverse_util import flatten_dict
from EasyDel import MistralConfig
from EasyDel.transform import mistral_convert_flax_to_pt
from EasyDel.utils import StreamingCheckpointer

_, flax_params = StreamingCheckpointer.load_trainstate_checkpoint(path)
flax_params = flatten_dict(flax_params['params'], sep='.')
pytorch_state_dict = mistral_convert_flax_to_pt(flax_params, MistralConfig())
torch.save(pytorch_state_dict, 'pytorch_model.bin')

I'm wondering if it's OK to use MistralConfig as-is as the config for mistral_convert_flax_to_pt.

Yes, the code is correct, but use the config of the model you want to convert from EasyDel to torch or hf, so that values like num_hidden_layers and so on are taken from the given config.

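For example, a minimal sketch of what this means in practice (the field values below are placeholders, not real Mistral-7B settings; reuse the exact config object from your training run if you still have it):

# hypothetical sketch: build the conversion config from the values you
# trained with, instead of relying on MistralConfig() defaults
train_config = MistralConfig(
    num_hidden_layers=32,      # placeholder; set to your model's value
    num_attention_heads=32,    # placeholder
    hidden_size=4096,          # placeholder
)
pytorch_state_dict = mistral_convert_flax_to_pt(flax_params, train_config)
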
Hmm... so I can't just use the MistralConfig class as-is? Also, if I have to use a custom config, I wonder if I can use the one I used for training.

yes

Is full-finetune the only finetuning method currently supported? Also, are there any plans to create a linear scheduler with a warmup step?

I just made a warmup_linear scheduler function:

from typing import Optional

import chex
import optax


def get_adamw_with_warmup_linear_scheduler(
        steps: int,
        learning_rate_start: float = 5e-5,
        learning_rate_end: float = 1e-5,
        b1: float = 0.9,
        b2: float = 0.999,
        eps: float = 1e-8,
        eps_root: float = 0.0,
        weight_decay: float = 1e-1,
        gradient_accumulation_steps: int = 1,
        mu_dtype: Optional[chex.ArrayDType] = None,
        warmup_steps: int = 500
):
    """
    AdamW optimizer with a linear warmup followed by a linear decay.

    :param steps: total number of training steps
    :param learning_rate_start: peak learning rate, reached at the end of warmup
    :param learning_rate_end: final learning rate at the end of training
    :param b1: Adam beta1
    :param b2: Adam beta2
    :param eps: Adam epsilon
    :param eps_root: epsilon added inside the square root
    :param weight_decay: decoupled weight-decay coefficient
    :param gradient_accumulation_steps: if > 1, wrap the optimizer in optax.MultiSteps
    :param mu_dtype: dtype of the first-moment accumulator
    :param warmup_steps: number of steps in the warmup phase
    :return: optimizer and scheduler with the warmup feature
    """
    # ramp the learning rate up linearly during warmup...
    scheduler_warmup = optax.linear_schedule(
        init_value=5e-8, end_value=learning_rate_start, transition_steps=warmup_steps
    )
    # ...then decay it linearly over the remaining steps
    scheduler_decay = optax.linear_schedule(
        init_value=learning_rate_start, end_value=learning_rate_end,
        transition_steps=steps - warmup_steps
    )
    scheduler_combined = optax.join_schedules(
        schedules=[scheduler_warmup, scheduler_decay], boundaries=[warmup_steps]
    )
    tx = optax.chain(
        optax.scale_by_adam(
            b1=b1,
            b2=b2,
            eps=eps,
            eps_root=eps_root,
            mu_dtype=mu_dtype
        ),
        optax.add_decayed_weights(
            weight_decay=weight_decay
        ),
        optax.scale_by_schedule(scheduler_combined),
        optax.scale(-1)  # gradient descent: negate the scaled update
    )
    if gradient_accumulation_steps > 1:
        tx = optax.MultiSteps(
            tx, gradient_accumulation_steps
        )
    return tx, scheduler_combined

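For reference, a quick usage sketch of the function above (the step counts are placeholders):

# hypothetical usage: 10,000 total steps with a 500-step linear warmup
tx, scheduler = get_adamw_with_warmup_linear_scheduler(
    steps=10_000,
    warmup_steps=500,
)
# the returned schedule can be queried for logging, e.g.:
print(scheduler(0))     # 5e-08, the warmup start value
print(scheduler(500))   # ~5e-05, i.e. learning_rate_start at the end of warmup
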
RLHF is supported for finetuning too, but only for llama1, falcon, and mpt models right now.

I'll create that for you in the next update on the main branch.

Thanks!

I received the file Mistral-1.566301941871643-69 at the end of the model's training, and I was wondering if there is a way to convert this save file to model.bin, or to load it onto a TPU to check that it works.
Thank you for the support!