From e7f7b3d9413ab24183b5aa84dce89cd785dcd146 Mon Sep 17 00:00:00 2001 From: erfanzar Date: Tue, 21 May 2024 15:44:38 +0330 Subject: [PATCH] Adding Examples for training --- .../causal_language_model_training_example.py | 216 +++++++++++++++++ .../{dpo => }/dpo_training_example.py | 0 examples/training/sft_training_example.py | 226 ++++++++++++++++++ src/python/easydel/etils/__init__.py | 5 +- src/python/easydel/etils/etils.py | 67 +++++- .../trainer/training_configurations.py | 2 - 6 files changed, 512 insertions(+), 4 deletions(-) create mode 100644 examples/training/causal_language_model_training_example.py rename examples/training/{dpo => }/dpo_training_example.py (100%) create mode 100644 examples/training/sft_training_example.py diff --git a/examples/training/causal_language_model_training_example.py b/examples/training/causal_language_model_training_example.py new file mode 100644 index 000000000..6fa23f0aa --- /dev/null +++ b/examples/training/causal_language_model_training_example.py @@ -0,0 +1,216 @@ +import transformers + +from easydel import ( + AutoEasyDeLModelForCausalLM, + TrainArguments, + EasyDeLOptimizers, + EasyDeLSchedulers, + EasyDeLGradientCheckPointers, + EasyDeLState, + EasyDeLXRapTureConfig, + CausalLanguageModelTrainer, + get_modules_by_type, + easystate_to_huggingface_model, +) +from datasets import load_dataset +from flax.core import FrozenDict +from transformers import AutoTokenizer +from jax import numpy as jnp, sharding +import jax +from transformers import AutoConfig +from huggingface_hub import HfApi +from easydel.etils import define_flags_with_default + +PartitionSpec = sharding.PartitionSpec +api = HfApi() + +FLAGS, DEF_FLAGS = define_flags_with_default( + pretrained_model_name_or_path="", + pretrained_model_name_or_path_tokenizer="", + new_repo_id="", + train_dataset="", + model_name="CLM-EasyDeL", + sharding_axis_dims=(1, -1, 1, 1), + max_length=2048, + input_shape=(8, 2048), + use_lora=False, + block_size=512, + attn_mechanism="sharded_vanilla", + weight_decay=0.02, + total_batch_size=24, + use_pjit_attention_force=False, + gradient_accumulation_steps=1, + step_start_point=0, + num_train_epochs=3, + learning_rate=2e-5, + learning_rate_end=9e-6, + warmup_steps=7, + optimizer=EasyDeLOptimizers.ADAMW, + scheduler=EasyDeLSchedulers.WARM_UP_COSINE, + gradient_checkpointing=EasyDeLGradientCheckPointers.NOTHING_SAVEABLE, + lora_dim=64, + fully_fine_tune_parameters=["embed_tokens"], + lora_fine_tune_parameters=["q_proj", "v_proj", "k_proj", "o_proj"], + training_time="90H", + _required_fields=["pretrained_model_name_or_path", "new_repo_id", "train_dataset"] +) + + +def main(): + pretrained_model_name_or_path_tokenizer = ( + FLAGS.pretrained_model_name_or_path_tokenizer if ( + FLAGS.pretrained_model_name_or_path_tokenizer != "" + ) else FLAGS.pretrained_model_name_or_path + ) + + dtype = jnp.bfloat16 + sharding_axis_dims = eval(FLAGS.sharding_axis_dims) + qps = PartitionSpec(("dp", "fsdp"), "sp", None, "tp") + kps = PartitionSpec(("dp", "fsdp"), "sp", None, "tp") + vps = PartitionSpec(("dp", "fsdp"), "sp", None, "tp") + bps = PartitionSpec(("dp", "fsdp"), "sp", None, None) + aps = PartitionSpec(("dp", "fsdp"), "sp", None, "tp") + + attention_partitions = dict( + query_partition_spec=qps, + key_partition_spec=kps, + value_partition_spec=vps, + bias_partition_spec=bps, + attention_partition_spec=aps, + ) + + model, params = AutoEasyDeLModelForCausalLM.from_pretrained( + FLAGS.pretrained_model_name_or_path, + device=jax.devices('cpu')[0], + 
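        # Load the checkpoint on a CPU device first so the full-precision weights do not have
        # to fit in accelerator memory up front; the trainer later shards them according to
        # config.get_partition_rules(...) and the mesh implied by sharding_axis_dims.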
input_shape=FLAGS.input_shape, + device_map="auto", + sharding_axis_dims=sharding_axis_dims, + config_kwargs=dict( + use_scan_mlp=False, + attn_mechanism=FLAGS.attn_mechanism, + **attention_partitions + ), + **attention_partitions + ) + + config = model.config + + model_use_tie_word_embedding = config.tie_word_embeddings + + model_parameters = FrozenDict({"params": params}) + + tokenizer = AutoTokenizer.from_pretrained( + pretrained_model_name_or_path_tokenizer, + trust_remote_code=True + ) + + config.add_basic_configurations( + attn_mechanism=FLAGS.attn_mechanism, + shard_attention_computation=True, + **attention_partitions + ) + + configs_to_initialize_model_class = { + "config": config, + "dtype": dtype, + "param_dtype": dtype, + "input_shape": FLAGS.input_shape + } + + if tokenizer.pad_token == None: + tokenizer.pad_token = tokenizer.eos_token + + rapture_config = EasyDeLXRapTureConfig( + model_parameters, + lora_dim=FLAGS.lora_dim, + fully_fine_tune_parameters=FLAGS.fully_fine_tune_parameters, + lora_fine_tune_parameters=FLAGS.lora_fine_tune_parameters, + verbose=True + ) if FLAGS.use_lora else None + + train_dataset = load_dataset(FLAGS.train_dataset, split="train") + + train_arguments = TrainArguments( + model_class=get_modules_by_type(config.model_type)[1], + configs_to_initialize_model_class=configs_to_initialize_model_class, + custom_rule=config.get_partition_rules(True), + + num_train_epochs=FLAGS.num_train_epochs, + learning_rate=FLAGS.learning_rate, + learning_rate_end=FLAGS.learning_rate_end, + warmup_steps=FLAGS.warmup_steps, + optimizer=FLAGS.optimizer, + scheduler=FLAGS.scheduler, + weight_decay=FLAGS.weight_decay, + total_batch_size=FLAGS.total_batch_size, + init_input_shape=FLAGS.input_shape, + max_sequence_length=FLAGS.max_length, + model_name=FLAGS.model_name, + training_time=FLAGS.training_time, + gradient_checkpointing=FLAGS.gradient_checkpointing, + sharding_array=sharding_axis_dims, + use_pjit_attention_force=FLAGS.use_pjit_attention_force, + gradient_accumulation_steps=FLAGS.gradient_accumulation_steps, + step_start_point=FLAGS.step_start_point, + + dtype=dtype, + param_dtype=dtype, + + force_batch_and_gradient_accumulation_steps_calculation=False, + rapture_config=rapture_config, + track_memory=True + ) + + trainer = CausalLanguageModelTrainer( + arguments=train_arguments, + dataset_train=train_dataset, + ) + + output = trainer.train( + model_parameters=model_parameters if not FLAGS.use_lora else None, + state=None + ) + + api.create_repo(FLAGS.new_repo_id, exist_ok=True) + + api.upload_file( + path_or_fileobj=output.checkpoint_path, + repo_id=FLAGS.new_repo_id, + path_in_repo=output.last_save_file_name + ) + + with jax.default_device(jax.devices("cpu")[0]): + state = EasyDeLState.load_state( + output.checkpoint_path, + input_shape=FLAGS.input_shape, + ) + + if model_use_tie_word_embedding: + state_new_params = { + "params": state.params["params"] | { + "lm_head": { + "kernel": state.params["params"]["model"]["embed_tokens"]["embedding"].T + } + } + } + + state = state.replace(params=state_new_params) + + config = AutoConfig.from_pretrained(FLAGS.pretrained_model_name_or_path) + tokenizer = AutoTokenizer.from_pretrained(FLAGS.pretrained_model_name_or_path) + + with jax.default_device(jax.devices("cpu")[0]): + model = easystate_to_huggingface_model( + state=state, + base_huggingface_module=type(transformers.AutoModelForCausalLM.from_config(config)), + config=config + ) + + half_model = model.half() + + tokenizer.push_to_hub(FLAGS.new_repo_id) + 
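    # The tokenizer and the fp16 Hugging Face conversion are pushed to the same repo that
    # already holds the raw EasyDeL checkpoint uploaded above.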
half_model.push_to_hub(FLAGS.new_repo_id) + + +if __name__ == "__main__": + main() diff --git a/examples/training/dpo/dpo_training_example.py b/examples/training/dpo_training_example.py similarity index 100% rename from examples/training/dpo/dpo_training_example.py rename to examples/training/dpo_training_example.py diff --git a/examples/training/sft_training_example.py b/examples/training/sft_training_example.py new file mode 100644 index 000000000..83156c604 --- /dev/null +++ b/examples/training/sft_training_example.py @@ -0,0 +1,226 @@ +import transformers + +from easydel import ( + AutoEasyDeLModelForCausalLM, + TrainArguments, + EasyDeLOptimizers, + EasyDeLSchedulers, + EasyDeLGradientCheckPointers, + EasyDeLState, + EasyDeLXRapTureConfig, + SFTTrainer, + get_modules_by_type, + easystate_to_huggingface_model, + conversations_formatting_function +) +from datasets import load_dataset +from flax.core import FrozenDict +from transformers import AutoTokenizer +from jax import numpy as jnp, sharding +import jax +from transformers import AutoConfig +from huggingface_hub import HfApi +from easydel.etils import define_flags_with_default + +PartitionSpec = sharding.PartitionSpec +api = HfApi() + +FLAGS, DEF_FLAGS = define_flags_with_default( + pretrained_model_name_or_path="", + pretrained_model_name_or_path_tokenizer="", + new_repo_id="", + train_dataset="", + model_name="SFT-EasyDeL", + sharding_axis_dims=(1, -1, 1, 1), + max_length=2048, + input_shape=(8, 2048), + use_lora=False, + block_size=512, + attn_mechanism="sharded_vanilla", + weight_decay=0.02, + total_batch_size=24, + use_pjit_attention_force=False, + gradient_accumulation_steps=1, + step_start_point=0, + num_train_epochs=3, + learning_rate=2e-5, + learning_rate_end=9e-6, + warmup_steps=7, + optimizer=EasyDeLOptimizers.ADAMW, + scheduler=EasyDeLSchedulers.WARM_UP_COSINE, + gradient_checkpointing=EasyDeLGradientCheckPointers.NOTHING_SAVEABLE, + lora_dim=64, + fully_fine_tune_parameters=["embed_tokens"], + lora_fine_tune_parameters=["q_proj", "v_proj", "k_proj", "o_proj"], + packing_sft=True, + messages_field="conversation", + training_time="90H", + _required_fields=["pretrained_model_name_or_path", "new_repo_id", "train_dataset"] +) + + +def main(): + pretrained_model_name_or_path_tokenizer = ( + FLAGS.pretrained_model_name_or_path_tokenizer if ( + FLAGS.pretrained_model_name_or_path_tokenizer != "" + ) else FLAGS.pretrained_model_name_or_path + ) + + dtype = jnp.bfloat16 + sharding_axis_dims = eval(FLAGS.sharding_axis_dims) + qps = PartitionSpec(("dp", "fsdp"), "sp", None, "tp") + kps = PartitionSpec(("dp", "fsdp"), "sp", None, "tp") + vps = PartitionSpec(("dp", "fsdp"), "sp", None, "tp") + bps = PartitionSpec(("dp", "fsdp"), "sp", None, None) + aps = PartitionSpec(("dp", "fsdp"), "sp", None, "tp") + + attention_partitions = dict( + query_partition_spec=qps, + key_partition_spec=kps, + value_partition_spec=vps, + bias_partition_spec=bps, + attention_partition_spec=aps, + ) + + model, params = AutoEasyDeLModelForCausalLM.from_pretrained( + FLAGS.pretrained_model_name_or_path, + device=jax.devices('cpu')[0], + input_shape=FLAGS.input_shape, + device_map="auto", + sharding_axis_dims=sharding_axis_dims, + config_kwargs=dict( + use_scan_mlp=False, + attn_mechanism=FLAGS.attn_mechanism, + **attention_partitions + ), + **attention_partitions + ) + + config = model.config + + model_use_tie_word_embedding = config.tie_word_embeddings + + model_parameters = FrozenDict({"params": params}) + + tokenizer = AutoTokenizer.from_pretrained( + 
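        # pretrained_model_name_or_path_tokenizer falls back to the model repo when the
        # dedicated tokenizer flag is left empty (resolved at the top of main()).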
pretrained_model_name_or_path_tokenizer, + trust_remote_code=True + ) + + config.add_basic_configurations( + attn_mechanism=FLAGS.attn_mechanism, + shard_attention_computation=True, + **attention_partitions + ) + + configs_to_initialize_model_class = { + "config": config, + "dtype": dtype, + "param_dtype": dtype, + "input_shape": FLAGS.input_shape + } + + if tokenizer.pad_token == None: + tokenizer.pad_token = tokenizer.eos_token + + rapture_config = EasyDeLXRapTureConfig( + model_parameters, + lora_dim=FLAGS.lora_dim, + fully_fine_tune_parameters=FLAGS.fully_fine_tune_parameters, + lora_fine_tune_parameters=FLAGS.lora_fine_tune_parameters, + verbose=True + ) if FLAGS.use_lora else None + + train_dataset = load_dataset(FLAGS.train_dataset, split="train") + + train_arguments = TrainArguments( + model_class=get_modules_by_type(config.model_type)[1], + configs_to_initialize_model_class=configs_to_initialize_model_class, + custom_rule=config.get_partition_rules(True), + + num_train_epochs=FLAGS.num_train_epochs, + learning_rate=FLAGS.learning_rate, + learning_rate_end=FLAGS.learning_rate_end, + warmup_steps=FLAGS.warmup_steps, + optimizer=FLAGS.optimizer, + scheduler=FLAGS.scheduler, + weight_decay=FLAGS.weight_decay, + total_batch_size=FLAGS.total_batch_size, + init_input_shape=FLAGS.input_shape, + max_sequence_length=FLAGS.max_length, + model_name=FLAGS.model_name, + training_time=FLAGS.training_time, + gradient_checkpointing=FLAGS.gradient_checkpointing, + sharding_array=sharding_axis_dims, + use_pjit_attention_force=FLAGS.use_pjit_attention_force, + gradient_accumulation_steps=FLAGS.gradient_accumulation_steps, + step_start_point=FLAGS.step_start_point, + + dtype=dtype, + param_dtype=dtype, + + force_batch_and_gradient_accumulation_steps_calculation=False, + rapture_config=rapture_config, + track_memory=True + ) + + trainer = SFTTrainer( + arguments=train_arguments, + train_dataset=train_dataset, + eval_dataset=None, + tokenizer=tokenizer, + dataset_text_field=None, + formatting_func=lambda x: [ + conversations_formatting_function(tokenizer, messages_field=FLAGS.messages_field)(x)], + packing=FLAGS.packing_sft, + num_of_sequences=FLAGS.max_length, + ) + + output = trainer.train( + model_parameters=model_parameters if not FLAGS.use_lora else None, + state=None + ) + + api.create_repo(FLAGS.new_repo_id, exist_ok=True) + + api.upload_file( + path_or_fileobj=output.checkpoint_path, + repo_id=FLAGS.new_repo_id, + path_in_repo=output.last_save_file_name + ) + + with jax.default_device(jax.devices("cpu")[0]): + state = EasyDeLState.load_state( + output.checkpoint_path, + input_shape=FLAGS.input_shape, + ) + + if model_use_tie_word_embedding: + state_new_params = { + "params": state.params["params"] | { + "lm_head": { + "kernel": state.params["params"]["model"]["embed_tokens"]["embedding"].T + } + } + } + + state = state.replace(params=state_new_params) + + config = AutoConfig.from_pretrained(FLAGS.pretrained_model_name_or_path) + tokenizer = AutoTokenizer.from_pretrained(FLAGS.pretrained_model_name_or_path) + + with jax.default_device(jax.devices("cpu")[0]): + model = easystate_to_huggingface_model( + state=state, + base_huggingface_module=type(transformers.AutoModelForCausalLM.from_config(config)), + config=config + ) + + half_model = model.half() + + tokenizer.push_to_hub(FLAGS.new_repo_id) + half_model.push_to_hub(FLAGS.new_repo_id) + + +if __name__ == "__main__": + main() diff --git a/src/python/easydel/etils/__init__.py b/src/python/easydel/etils/__init__.py index 35ce8f2a2..6ce5ea625 100644 
--- a/src/python/easydel/etils/__init__.py +++ b/src/python/easydel/etils/__init__.py @@ -14,7 +14,10 @@ EasyDeLSchedulers, AVAILABLE_OPTIMIZERS, AVAILABLE_SCHEDULERS, - AVAILABLE_GRADIENT_CHECKPOINTS + AVAILABLE_GRADIENT_CHECKPOINTS, + define_flags_with_default, + set_loggers_level, + get_logger ) from .errors import ( diff --git a/src/python/easydel/etils/etils.py b/src/python/easydel/etils/etils.py index 7ce41ed9c..e61d7d191 100644 --- a/src/python/easydel/etils/etils.py +++ b/src/python/easydel/etils/etils.py @@ -1,6 +1,8 @@ import logging from dataclasses import dataclass -from typing import Literal +from typing import Literal, List, Tuple, Dict, Callable, Any + +import argparse @dataclass @@ -99,3 +101,66 @@ def set_loggers_level(level: int = logging.WARNING): logging.root.setLevel(level) for handler in logging.root.handlers: handler.setLevel(level) + + +def define_flags_with_default( + _required_fields: List = None, + **kwargs +) -> Tuple[argparse.Namespace, Dict[str, Any]]: + """ + Defines flags with default values using argparse. + + Args: + _required_fields: A dictionary with required flag names + **kwargs: Keyword arguments representing flag names and default values. + + Returns: + A tuple containing: + - An argparse.Namespace object containing parsed arguments. + - A dictionary mapping flag names to default values. + """ + _required_fields = _required_fields if _required_fields is not None else [] + parser = argparse.ArgumentParser() + + default_values = {} + + for name, value in kwargs.items(): + default_values[name] = value + + # Custom type handling: + if isinstance(value, tuple): + # For tuples, use a custom action to convert the string to a tuple of ints + parser.add_argument( + f"--{name}", + type=str, # Read as string + default=str(value), # Store default as string + help=f"Value for {name} (comma-separated integers)", + action=StoreTupleAction + ) + else: + # For other types, infer type from default value + parser.add_argument( + f"--{name}", + type=type(value), + default=value, + help=f"Value for {name}" + ) + + args = parser.parse_args() + for key in _required_fields: + if getattr(args, key) == "": + raise ValueError(f"Required field {key} for argument parser.") + return args, default_values + + +class StoreTupleAction(argparse.Action): + """Custom action to store a comma-separated string as a tuple of ints.""" + + def __call__(self, parser, namespace, values, option_string=None): + try: + setattr(namespace, self.dest, tuple(int(v) for v in values.split(","))) + except ValueError: + raise argparse.ArgumentTypeError( + f"Invalid value for {option_string}: {values} " + f"(should be comma-separated integers)" + ) diff --git a/src/python/easydel/trainer/training_configurations.py b/src/python/easydel/trainer/training_configurations.py index f516a9930..1d73bde80 100644 --- a/src/python/easydel/trainer/training_configurations.py +++ b/src/python/easydel/trainer/training_configurations.py @@ -240,9 +240,7 @@ def __init__( available_backends = len(jax.devices(backend)) if force_batch_and_gradient_accumulation_steps_calculation: total_batch_size *= gradient_accumulation_steps # Changed and will be handled inside FJFormer - array_devices = jnp.ones((available_backends, 1)).reshape(sharding_array) - JaxDistributedConfig.initialize(jax_distributed_config) self.force_batch_and_gradient_accumulation_steps_calculation = ( force_batch_and_gradient_accumulation_steps_calculation
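
A quick usage sketch of the new define_flags_with_default helper (the script name train.py
and the flag values shown are illustrative; the helper, the flag names, and the behaviour
described come from the code above). Tuple-valued defaults are registered with
StoreTupleAction, so on the command line they are given as comma-separated integers, while
an unset tuple flag keeps its default as the string form of the tuple; this is why the
training examples call eval(FLAGS.sharding_axis_dims) before building the mesh.

    from easydel.etils import define_flags_with_default

    FLAGS, DEFAULTS = define_flags_with_default(
        pretrained_model_name_or_path="",
        sharding_axis_dims=(1, -1, 1, 1),  # tuple default, parsed by StoreTupleAction on the CLI
        learning_rate=2e-5,                # argument type inferred from the default (float)
        _required_fields=["pretrained_model_name_or_path"],  # still "" after parsing -> ValueError
    )

    # e.g. python train.py --pretrained_model_name_or_path my-org/my-model \
    #          --sharding_axis_dims 1,-1,1,1 --learning_rate 1e-5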