From 2588d030a2d1555bb3abb2ba7396e6ed4b53eab8 Mon Sep 17 00:00:00 2001 From: erfanzar Date: Tue, 21 May 2024 18:06:16 +0330 Subject: [PATCH] Updating Docs and adding new examples --- README.md | 2 +- docs/generated-cli-cli.md | 2 + generate_documentations.py | 12 +- mkdocs.yml | 14 +- src/python/easydel/etils/auto_tx.py | 30 +- src/python/easydel/etils/configs.py | 13 +- src/python/easydel/etils/easystate.py | 286 +- src/python/easydel/etils/etils.py | 20 +- .../modules/_attentions/blockwise_attn.py | 3 +- .../modules/arctic/modelling_arctic_flax.py | 176 +- .../easydel/modules/attention_module.py | 8 +- .../easydel/modules/auto_easydel_model.py | 168 +- .../modules/cohere/cohere_configuration.py | 27 +- .../modules/cohere/modelling_cohere_flax.py | 331 +- .../modules/dbrx/modelling_dbrx_flax.py | 131 +- .../deepseek_v2/deepseek_configuration.py | 424 +-- .../deepseek_v2/modeling_deepseek_flax.py | 2762 +++++++-------- .../modules/easydel_modelling_utils.py | 340 +- .../modules/falcon/modelling_falcon_flax.py | 95 +- .../easydel/modules/flax_modelling_utils.py | 129 +- .../modules/gemma/gemma_configuration.py | 27 +- .../modules/gemma/modelling_gemma_flax.py | 56 +- .../modules/gpt_j/modelling_gpt_j_flax.py | 2 +- src/python/easydel/modules/grok_1/__init__.py | 10 +- .../modules/grok_1/grok_1_configuration.py | 299 +- .../modules/grok_1/modelling_grok_1_flax.py | 2543 +++++++------- .../modules/jetmoe/jetmoe_configuration.py | 25 +- .../modules/jetmoe/modelling_jetmoe_flax.py | 8 +- .../modules/llama/llama_configuration.py | 147 +- .../modules/llama/modelling_llama_flax.py | 397 ++- .../llama/vision_llama_configuration.py | 10 +- src/python/easydel/modules/mamba/__init__.py | 16 +- .../modules/mamba/mamba_configuration.py | 144 +- .../modules/mamba/modelling_mamba_flax.py | 2145 ++++++------ .../modules/mistral/mistral_configuration.py | 129 +- .../modules/mistral/modelling_mistral_flax.py | 248 +- .../mistral/modelling_vision_mistral_flax.py | 16 +- .../mistral/vision_mistral_configuration.py | 10 +- .../modules/mixtral/mixtral_configuration.py | 142 +- .../modules/mixtral/modelling_mixtral_flax.py | 179 +- .../modules/mosaic_mpt/modelling_mpt_flax.py | 28 +- .../modules/olmo/olmo_configuration.py | 14 +- .../modules/openelm/modelling_openelm_flax.py | 211 +- .../modules/openelm/openelm_configuration.py | 124 +- .../easydel/modules/opt/modelling_opt_flax.py | 2 +- .../easydel/modules/phi/modelling_phi_flax.py | 42 +- .../modules/phi3/modelling_phi3_flax.py | 42 +- .../modules/phi3/phi3_configuration.py | 4 +- .../modules/qwen1/modelling_qwen1_flax.py | 395 ++- .../modules/qwen1/qwen1_configuration.py | 40 +- .../modules/qwen2/modelling_qwen_flax.py | 397 ++- .../modules/qwen2/qwen_configuration.py | 62 +- .../qwen2_moe/configuration_qwen2_moe.py | 26 +- .../qwen2_moe/modeling_qwen2_moe_flax.py | 361 +- .../easydel/modules/roberta/__init__.py | 36 +- .../modules/roberta/modelling_roberta_flax.py | 2818 +++++++-------- .../modules/roberta/roberta_configuration.py | 186 +- src/python/easydel/modules/rwkv/__init__.py | 22 +- .../stablelm/modelling_stablelm_flax.py | 90 +- .../easydel/modules/t5/modelling_t5_flax.py | 2 +- .../easydel/modules/whisper/__init__.py | 28 +- .../modules/whisper/modelling_whisper_flax.py | 3078 ++++++++--------- .../modules/whisper/whisper_configuration.py | 238 +- .../easydel/partitioning/partitioner.py | 13 +- .../reinforcement_learning/__init__.py | 3 +- .../easydel/reinforcement_learning/core.py | 3 +- .../models/modelling_casual_language_rl.py | 21 +- .../trainer/__init__.py | 2 +- .../trainer/partitioner_config.py | 52 +- .../trainer/ppo_config.py | 116 +- .../trainer/training_configs.py | 30 +- .../reinforcement_learning/trainer/utils.py | 42 +- .../utils/collectors.py | 117 +- .../serve/gradio_user_interface_base.py | 30 +- src/python/easydel/serve/jax_serve.py | 318 +- .../easydel/serve/prompters/__init__.py | 14 +- .../easydel/serve/prompters/base_prompter.py | 20 +- .../easydel/serve/serve_engine/__init__.py | 18 +- .../serve/serve_engine/configuration.py | 194 +- .../easydel/serve/serve_engine/serve.py | 1047 +++--- src/python/easydel/serve/torch_serve.py | 167 +- src/python/easydel/serve/utils.py | 50 +- src/python/easydel/smi/smi.py | 44 +- src/python/easydel/trainer/base_trainer.py | 82 +- .../causal_language_model_trainer/__init__.py | 30 +- .../causal_language_model_trainer.py | 26 +- .../fwd_bwd_functions.py | 363 +- .../modeling_output.py | 30 +- src/python/easydel/trainer/dpo/__init__.py | 38 +- src/python/easydel/trainer/dpo/dpo_trainer.py | 2548 +++++++------- .../easydel/trainer/dpo/fwd_bwd_functions.py | 1530 ++++---- .../easydel/trainer/dpo/modelling_output.py | 30 +- src/python/easydel/trainer/dpo/utils.py | 301 +- src/python/easydel/trainer/orpo/__init__.py | 22 +- .../easydel/trainer/orpo/fwd_bwd_functions.py | 797 ++--- .../easydel/trainer/orpo/modelling_output.py | 30 +- .../easydel/trainer/orpo/orpo_trainer.py | 2438 ++++++------- .../trainer/training_configurations.py | 279 +- src/python/easydel/trainer/utils.py | 16 +- .../__init__.py | 30 +- .../fwd_bwd_functions.py | 321 +- .../modelling_output.py | 30 +- .../vision_causal_language_model_trainer.py | 26 +- .../easydel/transform/easydel_transform.py | 85 +- src/python/easydel/transform/falcon.py | 4 +- src/python/easydel/transform/llama.py | 8 +- src/python/easydel/transform/mistral.py | 8 +- src/python/easydel/transform/mpt.py | 4 +- src/python/easydel/utils/prompters.py | 61 +- src/python/easydel/utils/tensor_utils.py | 12 +- src/python/easydel/utils/utils.py | 116 +- 111 files changed, 16330 insertions(+), 15008 deletions(-) create mode 100644 docs/generated-cli-cli.md diff --git a/README.md b/README.md index 31a304450..d52d939cc 100644 --- a/README.md +++ b/README.md @@ -192,7 +192,7 @@ print(f"Hey ! , here's where your model saved {output.checkpoint_path}") ``` > [!NOTE] -> You Can use Lora too, both for DPO and SFT Trainers. +> You Can use Lora too, for DPO, ORPO and SFT Trainers. ## FineTuning diff --git a/docs/generated-cli-cli.md b/docs/generated-cli-cli.md new file mode 100644 index 000000000..a0eecaab1 --- /dev/null +++ b/docs/generated-cli-cli.md @@ -0,0 +1,2 @@ +# cli.cli +::: src.python.easydel.cli.cli \ No newline at end of file diff --git a/generate_documentations.py b/generate_documentations.py index e8537eac0..4f9fde709 100644 --- a/generate_documentations.py +++ b/generate_documentations.py @@ -121,11 +121,11 @@ def main(): handlers: python: options: - docstring_style: sphinx + docstring_style: google repo_url: https://github.com/erfanzar/EasyDeL site_author: Erfan Zare Chavoshi -site_name: easydel +site_name: EasyDeL copyright: Erfan Zare Chavoshi-easydel theme: @@ -138,9 +138,9 @@ def main(): statics = { ("Home",): "index.md", - ("install",): "Install.md", - ("AvailableModels",): "AvailableModels.md", - ("EasyBIT",): "Bits.md", + ("Install",): "Install.md", + ("Available models",): "AvailableModels.md", + ("Easy Bits",): "Bits.md", ("Examples", "EasyState"): "EasyStateExample.md", ("Examples", "LoRA and Transfer Learning"): "LoRA-TransferLearningExample.md", ("Examples", "Fine Tuning Example"): "FineTuningExample.md", @@ -154,7 +154,7 @@ def main(): ("Examples", "MosaicMPT Models"): "MosaicMPT.md", ("Examples", "Easy Attention"): "AttentionModuleExample.md", ("Examples", "Model Parameter Quantization"): "Parameter-Quantization.md", - ("CONTRIBUTING",): "CONTRIBUTING.md" + ("Contributing",): "CONTRIBUTING.md" } cache = {("APIs",) + k: v for k, v in cache.items()} diff --git a/mkdocs.yml b/mkdocs.yml index 9437a9134..4b2ca775e 100644 --- a/mkdocs.yml +++ b/mkdocs.yml @@ -1,5 +1,7 @@ nav: - APIs: + - Cli: + - Cli: generated-cli-cli.md - Data Preprocessing: - Processor: generated-data_preprocessing-_processor.md - Etils: @@ -188,9 +190,9 @@ nav: - Prompters: generated-utils-prompters.md - Tensor Utils: generated-utils-tensor_utils.md - Utils: generated-utils-utils.md - - AvailableModels: AvailableModels.md - - CONTRIBUTING: CONTRIBUTING.md - - EasyBIT: Bits.md + - Available models: AvailableModels.md + - Contributing: CONTRIBUTING.md + - Easy Bits: Bits.md - Examples: - DataProcessing: DataProcessing.md - Easy Attention: AttentionModuleExample.md @@ -206,7 +208,7 @@ nav: - MosaicMPT Models: MosaicMPT.md - PytorchServer: PyTorchServer.md - Home: index.md - - install: Install.md + - Install: Install.md plugins: - search @@ -214,11 +216,11 @@ plugins: handlers: python: options: - docstring_style: sphinx + docstring_style: google repo_url: https://github.com/erfanzar/EasyDeL site_author: Erfan Zare Chavoshi -site_name: easydel +site_name: EasyDeL copyright: Erfan Zare Chavoshi-easydel theme: diff --git a/src/python/easydel/etils/auto_tx.py b/src/python/easydel/etils/auto_tx.py index 28a099014..8e8a19996 100644 --- a/src/python/easydel/etils/auto_tx.py +++ b/src/python/easydel/etils/auto_tx.py @@ -20,20 +20,26 @@ def get_optimizer_and_scheduler( weight_decay: float = 0.02, warmup_steps: int = 0 ): - """ - The get_optimizer_and_scheduler function is a helper function that returns an optimizer and scheduler + """The get_optimizer_and_scheduler function is a helper function that returns an optimizer and scheduler based on the parameters passed to it. - :param optimizer: AVAILABLE_OPTIMIZERS: Choose the optimizer - :param scheduler: AVAILABLE_SCHEDULERS: Determine the learning rate scheduler - :param steps: int: Specify the number of steps in the training process - :param learning_rate: float: Set the learning rate for the optimizer - :param learning_rate_end: float: Set the final learning rate - :param gradient_accumulation_steps: int: Accumulate the gradients before updating the weights - :param extra_optimizer_kwargs: dict | None: Pass extra arguments to the optimizer - :param weight_decay: float: Set the weight decay for adamw optimizer - :param warmup_steps: int: Specify the number of steps to warm up the learning rate - :return: A tuple of two objects: (Optimizer and scheduler) + Args: + optimizer: AVAILABLE_OPTIMIZERS: Choose the optimizer + scheduler: AVAILABLE_SCHEDULERS: Determine the learning rate + scheduler + steps: int: Specify the number of steps in the training process + learning_rate: float: Set the learning rate for the optimizer + learning_rate_end: float: Set the final learning rate + gradient_accumulation_steps: int: Accumulate the gradients + before updating the weights + extra_optimizer_kwargs: dict | None: Pass extra arguments to the + optimizer + weight_decay: float: Set the weight decay for adamw optimizer + warmup_steps: int: Specify the number of steps to warm up the + learning rate + + Returns: + A tuple of two objects: (Optimizer and scheduler) """ if extra_optimizer_kwargs is None: extra_optimizer_kwargs = {} diff --git a/src/python/easydel/etils/configs.py b/src/python/easydel/etils/configs.py index d4aae2e5b..d489f267b 100644 --- a/src/python/easydel/etils/configs.py +++ b/src/python/easydel/etils/configs.py @@ -396,13 +396,14 @@ def get_config(model_type: str, struct: str): - """ - The get_config function takes in a model_type and struct, and returns the corresponding config. + """The get_config function takes in a model_type and struct, and returns the corresponding config. + + Args: + model_type: str: Determine which model to use + struct: str: Specify the structure of the model - :param model_type: str: Determine which model to use - :param struct: str: Specify the structure of the model - :return: A dictionary of hyperparameters - + Returns: + A dictionary of hyperparameters """ if model_type == "llama": return llama_configs[struct] diff --git a/src/python/easydel/etils/easystate.py b/src/python/easydel/etils/easystate.py index 441e33818..951a9321b 100644 --- a/src/python/easydel/etils/easystate.py +++ b/src/python/easydel/etils/easystate.py @@ -69,16 +69,19 @@ class EasyDeLState(struct.PyTreeNode): def apply_gradients(self, *, grads, **kwargs): - """ - The apply_gradients function is the core of the optimizer. It takes in a dictionary of gradients, + """The apply_gradients function is the core of the optimizer. It takes in a dictionary of gradients, and returns an updated version of itself with new parameters and state. The function also updates the step count. - :param self: Refer to the current instance of the class - :param *: Unpack the grads dictionary into positional arguments - :param grads: Pass in the gradients of the loss function with respect to each parameter - :param kwargs: Pass in additional arguments to the function - :return: A new State with the updated parameters and params + Args: + self: Refer to the current instance of the class + : Unpack the grads dictionary into positional arguments + grads: Pass in the gradients of the loss function with + respect to each parameter + **kwargs: Pass in additional arguments to the function + + Returns: + A new State with the updated parameters and params """ if OVERWRITE_WITH_GRADIENT in grads: grads_with_opt = grads['params'] @@ -120,21 +123,28 @@ def create( **kwargs ): - """ - The create function is used to create a new instance of the class. - - :param cls: Create a new instance of the class - :param *: Pass a list of parameters to the function - :param apply_fn: Callable: Apply the model to a batch of data - :param params: core.FrozenDict[str,Any] | Mapping[str,Any]: Pass in the parameters of the model - :param tx: optax.GradientTransformation: Initialize the optimizer - :param tx_init: Optional[dict]: Initialize the optimizer - :param hyperparameters: Optional[dict]: Pass hyperparameters to the state for init - :param module: Optional[EasyDeLFlaxPretrainedModel]: Pass the module to be used int state - :param module_config: Optional[EasyDeLPretrainedConfig]: Pass in the module config - :param module_config_args: Optional[dict]: Store the config args of the model - :param kwargs: Pass in additional parameters to the - :return: A EasyDeLState object + """The create function is used to create a new instance of the class. + + Args: + cls: Create a new instance of the class + : Pass a list of parameters to the function + apply_fn: Callable: Apply the model to a batch of data + params: core.FrozenDict[str,Any] | Mapping[str,Any]: Pass in + the parameters of the model + tx: optax.GradientTransformation: Initialize the optimizer + tx_init: Optional[dict]: Initialize the optimizer + hyperparameters: Optional[dict]: Pass hyperparameters to the + state for init + module: Optional[EasyDeLFlaxPretrainedModel]: Pass the + module to be used int state + module_config: Optional[EasyDeLPretrainedConfig]: Pass in + the module config + module_config_args: Optional[dict]: Store the config args of + the model + **kwargs: Pass in additional parameters to the + + Returns: + A EasyDeLState object """ if hyperparameters is None: hyperparameters = {} @@ -175,22 +185,32 @@ def load( **kwargs ): - """ - The load function is used to load a saved state of the Model and optimizer or Model Only. - - :param cls: Make the function a class method - :param *: Pass in a variable number of arguments - :param step: int: Keep track of the number of steps that have been taken - :param apply_fn: Callable: Apply the optimizer to the model - :param params: core.FrozenDict[str,Any] | Mapping[str,Any]: Pass in the parameters of the model - :param opt_state: Optional[optax.OptState]: optimizer state - :param tx_init: Optional[dict]: Pass the hyperparameters to the optimizer - :param hyperparameters: Optional[dict]: Load hyperparameters from the state dict - :param module: Optional[EasyDeLFlaxPretrainedModel]: Pass in the module - :param module_config: Optional[EasyDeLPretrainedConfig]: Pass the module config - :param module_config_args: Optional[dict]: Pass the config_args to the model - :param kwargs: Pass in any additional parameters that may be needed for the model - :return: A new instance of the class + """The load function is used to load a saved state of the Model and optimizer or Model Only. + + Args: + cls: Make the function a class method + : Pass in a variable number of arguments + step: int: Keep track of the number of steps that have been + taken + apply_fn: Callable: Apply the optimizer to the model + params: core.FrozenDict[str,Any] | Mapping[str,Any]: Pass in + the parameters of the model + opt_state: Optional[optax.OptState]: optimizer state + tx_init: Optional[dict]: Pass the hyperparameters to the + optimizer + hyperparameters: Optional[dict]: Load hyperparameters from + the state dict + module: Optional[EasyDeLFlaxPretrainedModel]: Pass in the + module + module_config: Optional[EasyDeLPretrainedConfig]: Pass the + module config + module_config_args: Optional[dict]: Pass the config_args to + the model + **kwargs: Pass in any additional parameters that may be + needed for the model + + Returns: + A new instance of the class """ if module_config is not None: module_config = copy.deepcopy(module_config) @@ -267,22 +287,29 @@ def load_state( config_kwargs: Optional[dict] = None ): - """ - The load_state function is a class method that loads the state of an EasyDeLModel from a checkpoint. - - :param cls: Create an instance of the class - :param checkpoint_path: str | os.PathLike: Specify the path to the checkpoint file - :param dtype: jnp.dtype: The dtype of the model - :param param_dtype: jnp.dtype: The dtype of the model parameters - :param precision: Optional[Union[str, jax.lax.Precision]]: precision of the model - :param init_optimizer_state: bool: Initialize the optimizer if it's not Initialized yet (if it Initialized the option + """The load_state function is a class method that loads the state of an EasyDeLModel from a checkpoint. + + Args: + cls: Create an instance of the class + checkpoint_path: str | os.PathLike: Specify the path to the + checkpoint file + dtype: jnp.dtype: The dtype of the model + param_dtype: jnp.dtype: The dtype of the model parameters + precision: Optional[Union[str, jax.lax.Precision]]: + precision of the model + init_optimizer_state: bool: Initialize the optimizer if it's + not Initialized yet (if it Initialized the option + state_shard_fns: Optional[Mapping[str,Callable]]: Specify + the function that will be used + verbose: bool: Print out the progress of loading + input_shape: Tuple: input_shape to init module + config_kwargs: Optional[dict] : config kwargs to be passed + to model config will be ignored ) - :param state_shard_fns: Optional[Mapping[str,Callable]]: Specify the function that will be used to shard the loaded state - :param verbose: bool: Print out the progress of loading - :param input_shape: Tuple: input_shape to init module - :param config_kwargs: Optional[dict] : config kwargs to be passed to model config - :return: A state object + + Returns: + A state object """ from ..modules.auto_easydel_model import get_modules_by_type @@ -349,18 +376,25 @@ def save_state( float_dtype: Union[str, jax.numpy.dtype] = None, ): - """ - The save_state function saves the state of a model to disk. - - :param self: Pass the object itself to the function - :param filename: str | os.PathLike: Specify the name of the file to save - :param save_optimizer: bool: Determine whether to save the optimizer state or not - :param checkpoint_dir: Optional[str | os.PathLike]: Specify the directory where the checkpoint is saved - :param verbose: bool: Print out the path of the saved file - :param gather_fns: dict[Callable]: Specify a dictionary of functions that can be used to gather - :param float_dtype: str | jax.numpy.dtype: Specify the precision of the saved model + """The save_state function saves the state of a model to disk. + + Args: + self: Pass the object itself to the function + filename: str | os.PathLike: Specify the name of the file to + save + save_optimizer: bool: Determine whether to save the + optimizer state or not + checkpoint_dir: Optional[str | os.PathLike]: Specify the + directory where the checkpoint is saved + verbose: bool: Print out the path of the saved file + gather_fns: dict[Callable]: Specify a dictionary of + functions that can be used to gather + float_dtype: str | jax.numpy.dtype: Specify the precision of + the saved model :param : Save the optimizer state - :return: None + + Returns: + None """ state = self if not save_optimizer: @@ -385,13 +419,14 @@ def save_state( def free_opt_state(self) -> "EasyDeLState": - """ - The free_opt_state function is used to free the memory allocated by a previous call to setopt. + """The free_opt_state function is used to free the memory allocated by a previous call to setopt. It should be called after all the options have been set, and before you perform any of the transfers. + Args: + self: Represent the instance of the class - :param self: Represent the instance of the class - :return: A new state with the opt_state field set to none + Returns: + A new state with the opt_state field set to none """ return self.replace( opt_state=None @@ -399,10 +434,14 @@ def free_opt_state(self) -> "EasyDeLState": def init_opt_state(self) -> "EasyDeLState": - """ - The init_opt_state function initializes the optimizer state. - :param self: Make the object callable, and params is used to pass in a dictionary of parameters - :return: A new instance of the class with opt_state initialized + """The init_opt_state function initializes the optimizer state. + + Args: + self: Make the object callable, and params is used to pass + in a dictionary of parameters + + Returns: + A new instance of the class with opt_state initialized """ if self.opt_state is None: params_with_opt = ( @@ -447,40 +486,63 @@ def from_pretrained( **kwargs ) -> "EasyDeLState": - """ - The from_pretrained function is a helper function to quickly load a pretrained model and its associated configuration. + """The from_pretrained function is a helper function to quickly load a pretrained model and its associated configuration. This method takes care of returning the correct model class instance based on the `model_type` property in the config object, or when it's missing, falling back to using pattern matching on the `pretrained_model_name_or_path` string: - :param cls: Refer to the class that is being defined - :param pretrained_model_name_or_path: str: Load the pretrained model - :param filename: Optional[str]: Specify the name of the file to download from huggingface hub - :param optimizer: AVAILABLE_OPTIMIZERS: Specify the optimizer used for training - :param scheduler: AVAILABLE_SCHEDULERS: Specify the name of the scheduler to use - :param tx_init: Optional[dict]: Pass the hyperparameters of the optimizer - :param device: Specify the device on which to run the model - :param dtype: jax.numpy.dtype: Specify the dtype of the model parameters - :param param_dtype: jax.numpy.dtype: Specify the data type of the parameters - :param precision: jax.lax.Precision: Control the precision of the calculation - :param sharding_axis_dims: Sequence[int]: Specify the dimension of each axis - :param sharding_axis_names: Sequence[str]: Specify the names of the axes in each shard - :param query_partition_spec: PartitionSpec: Specify the partitioning of the query matrix - :param generation_query_partition_spec: PartitionSpec: Specify the partitioning of the query tensor in + Args: + cls: Refer to the class that is being defined + pretrained_model_name_or_path: str: Load the pretrained + model + filename: Optional[str]: Specify the name of the file to + download from huggingface hub + optimizer: AVAILABLE_OPTIMIZERS: Specify the optimizer used + for training + scheduler: AVAILABLE_SCHEDULERS: Specify the name of the + scheduler to use + tx_init: Optional[dict]: Pass the hyperparameters of the + optimizer + device: Specify the device on which to run the model + dtype: jax.numpy.dtype: Specify the dtype of the model + parameters + param_dtype: jax.numpy.dtype: Specify the data type of the + parameters + precision: jax.lax.Precision: Control the precision of the + calculation + sharding_axis_dims: Sequence[int]: Specify the dimension of + each axis + sharding_axis_names: Sequence[str]: Specify the names of the + axes in each shard + query_partition_spec: PartitionSpec: Specify the + partitioning of the query matrix + generation_query_partition_spec: PartitionSpec: Specify the + partitioning of the query tensor in + value_partition_spec: PartitionSpec: Specify the + partitioning of the value tensor + bias_partition_spec: PartitionSpec: Specify the partitioning + of the bias + attention_partition_spec: PartitionSpec: Partition the + attention weights + shard_attention_computation: bool: Determine whether to use + shard_map or not + input_shape: Sequence[int]: Specify the shape of the input + to be used for training + backend: Optional[str]: Specify the backend used for the + model + init_optimizer_state: bool: Initialize the optimizer state + free_optimizer_state: bool: Free the optimizer state from + memory + verbose: bool: Print the progress of loading the model + state_shard_fns: Optional[Mapping[str,Callable]]: Specify + the function to use for sharding the state + **kwargs: Pass keyword arguments to the function + config_kwargs: Optional[Mapping[str, Any]]: Config kwargs to + be added to config before creating module generation process:param key_partition_spec: PartitionSpec: Specify the partitioning of the key matrix - :param value_partition_spec: PartitionSpec: Specify the partitioning of the value tensor - :param bias_partition_spec: PartitionSpec: Specify the partitioning of the bias - :param attention_partition_spec: PartitionSpec: Partition the attention weights - :param shard_attention_computation: bool: Determine whether to use shard_map or not - :param input_shape: Sequence[int]: Specify the shape of the input to be used for training - :param backend: Optional[str]: Specify the backend used for the model - :param init_optimizer_state: bool: Initialize the optimizer state - :param free_optimizer_state: bool: Free the optimizer state from memory - :param verbose: bool: Print the progress of loading the model - :param state_shard_fns: Optional[Mapping[str,Callable]]: Specify the function to use for sharding the state - :param kwargs: Pass keyword arguments to the function - :param config_kwargs: Optional[Mapping[str, Any]]: Config kwargs to be added to config before creating module - :return: An `EasyDeLState` object + + Returns: + An `EasyDeLState` object """ if free_optimizer_state and init_optimizer_state: raise EasyDeLRuntimeError( @@ -586,9 +648,7 @@ def shard_params( @staticmethod def create_hyperparameters(model_type: str): - """ - it's the only way we can dump xla compiler - """ + """it's the only way we can dump xla compiler""" return { STRING_REP.format( type="str", @@ -625,13 +685,15 @@ def unsafe_dict(dictionary: dict): def __str__(self): - """ - The __str__ function is called when you call str(object) or print(object). + """The __str__ function is called when you call str(object) or print(object). The __repr__ function is called when you type the object name in the interpreter. If no __str__ method exists, Python will use __repr__ as a fallback. - :param self: Refer to the object itself - :return: string + Args: + self: Refer to the object itself + + Returns: + string """ params_size = sum(getattr(n, "size", 0) for n in jax.tree_util.tree_flatten(self.params)[0]) opt_state_size = sum(getattr(n, "size", 0) for n in jax.tree_util.tree_flatten(self.opt_state)[0]) @@ -695,15 +757,17 @@ def find_key(key, dictionary: dict) -> Union[str, None]: def __repr__(self): - """ - The __repr__ function is the "official" string representation of an object. + """The __repr__ function is the "official" string representation of an object. It's what you get when you type the object name at the Python prompt, or pass it to str(). The goal of __repr__ is to be unambiguous: if eval(repr(x)) == x, then __repr__ should return a string that looks like a valid Python expression that could be used to recreate an object with the same value ( given an appropriate environment). If this is not possible, a string formatted using %s formatting is also acceptable. - :param self: Represent the instance of the class - :return: A string that is a valid python expression + Args: + self: Represent the instance of the class + + Returns: + A string that is a valid python expression """ return self.__str__() diff --git a/src/python/easydel/etils/etils.py b/src/python/easydel/etils/etils.py index e61d7d191..a37f92374 100644 --- a/src/python/easydel/etils/etils.py +++ b/src/python/easydel/etils/etils.py @@ -7,8 +7,7 @@ @dataclass class EasyDeLOptimizers: - """ - The code snippet is defining a data class called `EasyDeLOptimizers` using the `@dataclass` + """The code snippet is defining a data class called `EasyDeLOptimizers` using the `@dataclass` decorator. A data class is a class that is primarily used to store data, and it automatically generates special methods such as `__init__`, `__repr__`, and `__eq__` based on the class attributes. @@ -20,8 +19,7 @@ class EasyDeLOptimizers: @dataclass class EasyDeLSchedulers: - """ - The code snippet is defining a data class called `EasyDeLSchedulers` using the `@dataclass` + """The code snippet is defining a data class called `EasyDeLSchedulers` using the `@dataclass` decorator. A data class is a class that is primarily used to store data, and it automatically generates special methods such as `__init__`, `__repr__`, and `__eq__` based on the class attributes. @@ -35,8 +33,7 @@ class EasyDeLSchedulers: @dataclass class EasyDeLGradientCheckPointers: - """ - The code snippet is defining a data class called `EasyDeLGradientCheckPointers` using the `@dataclass` + """The code snippet is defining a data class called `EasyDeLGradientCheckPointers` using the `@dataclass` decorator. A data class is a class that is primarily used to store data, and it automatically generates special methods such as `__init__`, `__repr__`, and `__eq__` based on the class attributes. @@ -94,9 +91,11 @@ def get_logger(name, level: int = logging.INFO) -> logging.Logger: def set_loggers_level(level: int = logging.WARNING): - """ - Function to set the logging level of all loggers to the specified level. - :param level: int: The logging level to set. Defaults to logging.WARNING. + """Function to set the logging level of all loggers to the specified level. + + Args: + level: int: The logging level to set. Defaults to + logging.WARNING. """ logging.root.setLevel(level) for handler in logging.root.handlers: @@ -107,8 +106,7 @@ def define_flags_with_default( _required_fields: List = None, **kwargs ) -> Tuple[argparse.Namespace, Dict[str, Any]]: - """ - Defines flags with default values using argparse. + """Defines flags with default values using argparse. Args: _required_fields: A dictionary with required flag names diff --git a/src/python/easydel/modules/_attentions/blockwise_attn.py b/src/python/easydel/modules/_attentions/blockwise_attn.py index b32dd7fab..587cbac84 100644 --- a/src/python/easydel/modules/_attentions/blockwise_attn.py +++ b/src/python/easydel/modules/_attentions/blockwise_attn.py @@ -1,5 +1,4 @@ -""" -An implementation of Blockwise parallel transformer https://arxiv.org/abs/2305.19370 +"""An implementation of Blockwise parallel transformer https://arxiv.org/abs/2305.19370 Also include a reference implementation of memory-efficient transformer https://arxiv.org/abs/2112.05682 from EasyLM https://github.com/young-geng/EasyLM/blob/main/EasyLM/bpt.py """ diff --git a/src/python/easydel/modules/arctic/modelling_arctic_flax.py b/src/python/easydel/modules/arctic/modelling_arctic_flax.py index e1c5b38eb..edfd8e9e1 100644 --- a/src/python/easydel/modules/arctic/modelling_arctic_flax.py +++ b/src/python/easydel/modules/arctic/modelling_arctic_flax.py @@ -188,23 +188,28 @@ def __call__( init_cache: bool = False, output_attentions: bool = True ): - """ - The __call__ function is the main function of a JAX module. + """The __call__ function is the main function of a JAX module. It defines how the module behaves when called as a function, and it's what you'll use to call your model in practice. The __call__ method takes an input tensor (x) and returns an output tensor (y). In this case, we're defining our model to be a simple linear layer with no activation: y = x @ w + b. - :param self: Refer to the object itself - :param hidden_states: chex.Array: Pass in the hidden state of the model - :param freq_cis: Tuple[chex.Array, chex.Array],: Create the apply_rotary variable - :param attention_mask: chex.Array: Mask the attention weights - :param causal_mask: chex.Array: Mask the attention weights - :param position_ids: chex.Array: Specify the position of each token in a sequence - :param deterministic: bool: Determine whether to use dropout or not - :param init_cache: bool: Initialize the cache - :param output_attentions: bool: Determine whether to return the attention weights - :return: A tuple of (out, attn_output) - + Args: + self: Refer to the object itself + hidden_states: chex.Array: Pass in the hidden state of the + model + freq_cis: Tuple[chex.Array, chex.Array],: Create the + apply_rotary variable + attention_mask: chex.Array: Mask the attention weights + causal_mask: chex.Array: Mask the attention weights + position_ids: chex.Array: Specify the position of each token + in a sequence + deterministic: bool: Determine whether to use dropout or not + init_cache: bool: Initialize the cache + output_attentions: bool: Determine whether to return the + attention weights + + Returns: + A tuple of (out, attn_output) """ batch_size, sequence_length = hidden_states.shape[:2] query_states, key_states, value_states = self.q_proj(hidden_states), self.k_proj(hidden_states), self.v_proj( @@ -480,8 +485,7 @@ def __call__(self, class FlaxArcticSparseMoeBlock(nn.Module): - """ - This implementation is + """This implementation is strictly equivalent to standard MoE with full capacity (no dropped tokens). It's faster since it formulates MoE operations in terms of block-sparse operations to accomodate imbalanced @@ -631,8 +635,7 @@ def __call__( init_cache: bool = False, output_attentions: bool = True, ): - """ - The __call__ function is the main function of a TransformerEncoderLayer. + """The __call__ function is the main function of a TransformerEncoderLayer. It takes in the following arguments: hidden_states (chex.Array): The input to the encoder layer, which is also its output after being processed by all sublayers. @@ -640,16 +643,23 @@ def __call__( used for computing self-attention weights and biases in a more efficient manner than using position embeddings or sinusoidal positional encoding vectors would allow - :param self: Represent the instance of the class - :param hidden_states: chex.Array: Represent the input to the encoder layer - :param freq_cis: Tuple[chex.Array, chex.Array],: Pass the frequency information to the attention layer - :param attention_mask: chex.Array: Mask out the attention weights for certain positions - :param causal_mask: chex.Array: Mask the future tokens - :param position_ids: chex.Array: Indicate the position of each token in the sequence - :param deterministic: bool: Determine whether to use dropout or not - :param init_cache: bool: Initialize the cache for the self-attention layer - :return: A tuple of hidden_states and attention_output - + Args: + self: Represent the instance of the class + hidden_states: chex.Array: Represent the input to the + encoder layer + freq_cis: Tuple[chex.Array, chex.Array],: Pass the frequency + information to the attention layer + attention_mask: chex.Array: Mask out the attention weights + for certain positions + causal_mask: chex.Array: Mask the future tokens + position_ids: chex.Array: Indicate the position of each + token in the sequence + deterministic: bool: Determine whether to use dropout or not + init_cache: bool: Initialize the cache for the self- + attention layer + + Returns: + A tuple of hidden_states and attention_output """ residual_input = hidden_states hidden_states = self.input_layernorm(hidden_states) @@ -733,8 +743,7 @@ def __call__( output_hidden_states: Optional[bool] = False, output_attentions: Optional[bool] = False, ): - """ - The __call__ function is the main function of a TransformerEncoderLayer. + """The __call__ function is the main function of a TransformerEncoderLayer. It takes in the following arguments: hidden_states (chex.Array): The input to the encoder layer, which is also its output after being processed by all sublayers. @@ -742,17 +751,26 @@ def __call__( , used for computing self-attention weights and biases in a more efficient manner than using position embeddings or sinusoidal positional encoding vectors would allow for [2]. - :param self: Represent the instance of the class - :param hidden_states: chex.Array: Represent the input to the encoder layer - :param freq_cis: Tuple[chex.Array, chex.Array],: Pass the frequency information to the attention layer - :param attention_mask: chex.Array: Mask out the attention weights for certain positions - :param causal_mask: chex.Array: Mask the future tokens - :param position_ids: chex.Array: Indicate the position of each token in the sequence - :param deterministic: bool: Determine whether to use dropout or not - :param init_cache: bool: Initialize the cache for the self-attention layer - :param output_attentions: bool: Determine whether to return the attention weights or not - :return: A tuple of hidden_states, attention_output, all_hidden_states and all_router_losses - + Args: + self: Represent the instance of the class + hidden_states: chex.Array: Represent the input to the + encoder layer + freq_cis: Tuple[chex.Array, chex.Array],: Pass the frequency + information to the attention layer + attention_mask: chex.Array: Mask out the attention weights + for certain positions + causal_mask: chex.Array: Mask the future tokens + position_ids: chex.Array: Indicate the position of each + token in the sequence + deterministic: bool: Determine whether to use dropout or not + init_cache: bool: Initialize the cache for the self- + attention layer + output_attentions: bool: Determine whether to return the + attention weights or not + + Returns: + A tuple of hidden_states, attention_output, + all_hidden_states and all_router_losses """ all_hidden_states = () if output_hidden_states else None all_self_attns = () if output_attentions else None @@ -825,17 +843,21 @@ def init_weights( input_shape: Tuple, params: FrozenDict = None ) -> FrozenDict: - """ - The init_weights function is used to initialize the weights of a model. + """The init_weights function is used to initialize the weights of a model. It takes in a rng, which is a random number generator key that can be used to generate random numbers. The input_shape parameter specifies the shape of the inputs that will be fed into this model. The params parameter allows you to pass in pre-trained weights for your model, if you have them available. - :param self: Access variables that belong to the class - :param rng: jax.random.PRNGKey: Initialize the weights of the model - :param input_shape: Tuple: Initialize the input_ids, attention_mask and position_ids - :param params: flax.core.FrozenDict: Pass in the parameters of a pre-trained model - :return: A frozendict of parameters + Args: + self: Access variables that belong to the class + rng: jax.random.PRNGKey: Initialize the weights of the model + input_shape: Tuple: Initialize the input_ids, attention_mask + and position_ids + params: flax.core.FrozenDict: Pass in the parameters of a + pre-trained model + + Returns: + A frozendict of parameters """ self.config.initialization_of_moe = True @@ -908,28 +930,35 @@ def __call__( add_params_field: bool = False, **kwargs ): - """ - The __call__ function is the main function of a JAX module. + """The __call__ function is the main function of a JAX module. It takes as input: - The parameters of the model (self.params) - The inputs to the model (input_ids, attention_mask, position_ids) - Whether we are training (train=True/False) and whether we want to return all hidden states and attentions weights at each layer in addition to just the last layer output (output_hidden_states=True/False). - :param self: Represent the instance of the class - :param input_ids: Pass the input sequence to the model - :param attention_mask: Mask out the padding tokens - :param position_ids: Specify the position of each token in the sequence - :param params: dict: Pass in the parameters of the model - :param past_key_values: dict: Pass the past key values to the model - :param dropout_rng: jax.random.PRNGKey: Pass in a random number generator key to the model - :param train: bool: Determine whether to use dropout or not - :param output_attentions: Optional[bool]: Determine whether to return the attention weights - :param output_hidden_states: Optional[bool]: Determine whether to return the hidden states of all layers - :param return_dict: Optional[bool]: Return a dictionary of the outputs - :param add_params_field: bool: Add a params field to the inputs dictionary - :return: A tuple of (last_hidden_state, past_key_values) - + Args: + self: Represent the instance of the class + input_ids: Pass the input sequence to the model + attention_mask: Mask out the padding tokens + position_ids: Specify the position of each token in the + sequence + params: dict: Pass in the parameters of the model + past_key_values: dict: Pass the past key values to the model + dropout_rng: jax.random.PRNGKey: Pass in a random number + generator key to the model + train: bool: Determine whether to use dropout or not + output_attentions: Optional[bool]: Determine whether to + return the attention weights + output_hidden_states: Optional[bool]: Determine whether to + return the hidden states of all layers + return_dict: Optional[bool]: Return a dictionary of the + outputs + add_params_field: bool: Add a params field to the inputs + dictionary + + Returns: + A tuple of (last_hidden_state, past_key_values) """ output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions @@ -1216,15 +1245,18 @@ def set_output_embeddings(self, new_embeddings): self.module.lm_head = new_embeddings def prepare_inputs_for_generation(self, input_ids, max_length, attention_mask: Optional[chex.Array] = None): - """ - The prepare_inputs_for_generation function is used to prepare the inputs for a generation task. - - :param self: Access variables that belong to the class - :param input_ids: Pass in the input tokens - :param max_length: Set the length of the sequence to be generated - :param attention_mask: Optional[chex.Array]: Mask the attention weights - :return: A dictionary of the past_key_values, attention_mask and position ids - + """The prepare_inputs_for_generation function is used to prepare the inputs for a generation task. + + Args: + self: Access variables that belong to the class + input_ids: Pass in the input tokens + max_length: Set the length of the sequence to be generated + attention_mask: Optional[chex.Array]: Mask the attention + weights + + Returns: + A dictionary of the past_key_values, attention_mask and + position ids """ batch_size, seq_length = input_ids.shape diff --git a/src/python/easydel/modules/attention_module.py b/src/python/easydel/modules/attention_module.py index 73cb72674..93c80d4ee 100644 --- a/src/python/easydel/modules/attention_module.py +++ b/src/python/easydel/modules/attention_module.py @@ -110,9 +110,7 @@ def combine_flash_masks(causal_mask, segment_ids): def get_flash_attention() -> Tuple[Callable, bool, bool]: - """ - return: FlashAttention FN, Upcast Needed to float32,do_shard_map - """ + """return: FlashAttention FN, Upcast Needed to float32,do_shard_map""" platform = jax.lib.xla_bridge.get_backend().platform if platform == "gpu": warnings.warn("for GPU backend use `cudnn` or `pallas_flash`") @@ -1165,9 +1163,7 @@ def test_attentions( chunk_size=128, axis_dims=(1, -1, 1, 1) ): - """ - creates a test for attention module to help you find the best attention mechanism you can use. - """ + """creates a test for attention module to help you find the best attention mechanism you can use.""" import flax try: import pandas diff --git a/src/python/easydel/modules/auto_easydel_model.py b/src/python/easydel/modules/auto_easydel_model.py index a849ddb27..595a2d9e5 100644 --- a/src/python/easydel/modules/auto_easydel_model.py +++ b/src/python/easydel/modules/auto_easydel_model.py @@ -90,7 +90,7 @@ def get_modules_by_type(model_type: str) -> Tuple[ embedding_layer_names=["wte"], rnn_based_or_rwkv=False, layer_norm_names=[ - "norm_1", "norm_2","norm_f" + "norm_1", "norm_2", "norm_f" ] ) ) @@ -374,20 +374,52 @@ def get_modules_by_type(model_type: str) -> Tuple[ def is_flatten(pytree: dict): - """ - The is_flatten function checks if the pytree is flattened. + """The is_flatten function checks if the pytree is flattened. If it is, then the first key in the dictionary will be a tuple of (mpl, mpl_id). Otherwise, it will be an integer representing mpl_id. - :param pytree: dict: Pass the pytree to the function - :return: True if the pytree is a flattened tree, and false otherwise - + Args: + pytree: dict: Pass the pytree to the function + + Returns: + True if the pytree is a flattened tree, and false otherwise """ mpl = [k for k in pytree.keys()][0] return True if isinstance(mpl, tuple) else False class AutoEasyDeLModelForCausalLM: + """This class provides a convenient way to load and shard pretrained causal language models from the Hugging Face Hub + and convert them into EasyDeL compatible models. It utilizes the EasyDeL library for distributed training and inference + with JAX. + + This class inherits from the `EasyDeLFlaxPretrainedModel` class, providing functionalities for model loading, + parameter sharding, and interaction with the EasyDeL framework. + + Attributes: + None + + Examples: + ```python + import jax + from easydel import AutoEasyDeLModelForCausalLM + + # Load a GPT-2 model on a single CPU + model, params = AutoEasyDeLModelForCausalLM.from_pretrained( + "gpt2", + device=jax.devices("cpu")[0] + ) + + # Load a GPT-2 model sharded across 8 GPUs with data parallelism (DP) and fully sharded data parallelism (FSDP) + model, params = AutoEasyDeLModelForCausalLM.from_pretrained( + "gpt2", + sharding_axis_dims=(1, 8, 1, 1), + sharding_axis_names=("dp", "fsdp", "tp", "sp"), + device=jax.devices("cpu")[0] # offload to CPU [OPTIONAL] + ) + ``` + """ + @classmethod def from_pretrained( cls, @@ -416,40 +448,49 @@ def from_pretrained( bit_targeted_params: Optional[List[str]] = None, **kwargs ) -> Tuple[EasyDeLFlaxPretrainedModel, dict]: - """ - The from_pretrained function is a helper function that allows you to instantiate a model from the pretrained - model repository. It takes as input the name of the model (e.g., 'bert-base-uncased') and returns an instance of - the class corresponding to your model, with all weights loaded from disk. - - :param cls: Create an instance of the class that called this function - :param pretrained_model_name_or_path: str: Identify the model in the huggingface model hub - :param device: Specify the device on which to run the model - :param dtype: jax.numpy.dtype: Specify the data type of the model - :param param_dtype: jax.numpy.dtype: Specify the dtype of the parameters - :param precision: jax.lax.Precision: Control the precision of the model - :param sharding_axis_dims: typing.Sequence[int]: Specify the dimension of each axis in the sharded model - :param sharding_axis_names: typing.Sequence[str]: Specify the order of sharding - :param query_partition_spec: PartitionSpec: Specify the partitioning of the query tensor - :param generation_query_partition_spec: PartitionSpec: Specify the partitioning of the query tensor in - generation process - :param key_partition_spec: PartitionSpec: Partition the key matrix - :param value_partition_spec: PartitionSpec: Specify the partitioning of the value tensor - :param bias_partition_spec: PartitionSpec: Specify the Attention Bias partition spec - :param generation_bias_partition_spec: PartitionSpec: Specify the Attention Bias partition spec for generation - :param attention_partition_spec: PartitionSpec: Specify the partitioning of the attention weights - :param shard_attention_computation: bool: whenever to use shard_map for attention - :param input_shape: typing.Sequence[int]: Specify the shape of the input to the model - :param shard_fns: Optional[Mapping[tuple, Callable]]: Sharding Function to be used to shard model - :param backend: typing.Optional[str]: backend to use for model - :param config_kwargs: Optional[Mapping[str, Any]]: Config kwargs to be added to config before creating module - :param auto_shard_params: bool: whether to automaticly shard the model parameters - :param partition_rules: Optional[Tuple[Tuple[str, PartitionSpec]]]: custom partition rules to create partition - specs required to shard model parameters - :param load_in_8bit: bool: whether to load model parameters and convert them into 8bit - :param bit_targeted_params: Optional[List[str]]: list of targeted parameters to be converted into 8bit - :param kwargs: Pass additional arguments to the model and config classes - :return: A model and parameters - + """Loads and shards a pretrained causal language model from the Hugging Face Hub and converts it into an + EasyDeL compatible model. + + Args: + pretrained_model_name_or_path (str): Path or name of the pretrained model in the Hugging Face Hub. + device (jax.Array, optional): Device to load the model on. Defaults to the first CPU. + dtype (jax.numpy.dtype, optional): Data type of the model. Defaults to jax.numpy.float32. + param_dtype (jax.numpy.dtype, optional): Data type of the model parameters. Defaults to jax.numpy.float32. + precision (jax.lax.Precision, optional): Precision for computations. Defaults to jax.lax.Precision("fastest"). + sharding_axis_dims (Sequence[int], optional): Dimensions of each sharding axis. Defaults to (1, -1, 1, 1). + sharding_axis_names (Sequence[str], optional): Names of the sharding axes. Defaults to ("dp", "fsdp", "tp", "sp"). + query_partition_spec (PartitionSpec, optional): Partitioning specification for the query tensor. Defaults to + PartitionSpec(("dp", "fsdp"), "sp", "tp", None). + generation_query_partition_spec (PartitionSpec, optional): Partitioning specification for the query tensor during + generation. Defaults to PartitionSpec(("dp", "fsdp"), None, "tp", None). + key_partition_spec (PartitionSpec, optional): Partitioning specification for the key tensor. Defaults to + PartitionSpec(("dp", "fsdp"), "sp", "tp", None). + value_partition_spec (PartitionSpec, optional): Partitioning specification for the value tensor. Defaults to + PartitionSpec(("dp", "fsdp"), "sp", "tp", None). + bias_partition_spec (PartitionSpec, optional): Partitioning specification for the attention bias. Defaults to + PartitionSpec(("dp", "fsdp"), None, None, None). + generation_bias_partition_spec (PartitionSpec, optional): Partitioning specification for the attention bias during + generation. Defaults to PartitionSpec(("dp", "fsdp"), None, None, None). + attention_partition_spec (PartitionSpec, optional): Partitioning specification for the attention weights. Defaults to + PartitionSpec(("dp", "fsdp"), "sp", "tp", None). + shard_attention_computation (bool, optional): Whether to shard attention computation. Defaults to True. + input_shape (Sequence[int], optional): Shape of the input to the model. Defaults to (1, 1). + shard_fns (Optional[Mapping[tuple, Callable] | dict], optional): Sharding functions to use for the model. If None, + auto-sharding is used if auto_shard_params is True. Defaults to None. + backend (Optional[str], optional): Backend to use for the model. Defaults to None. + config_kwargs (Optional[Mapping[str, Any]], optional): Configuration keyword arguments to pass to the model config. + Defaults to None. + auto_shard_params (bool, optional): Whether to automatically shard the model parameters. Defaults to False. + partition_rules (Optional[Tuple[Tuple[str, PartitionSpec]]], optional): Custom partition rules for parameter + sharding. If not None, shard_fns should also be provided. Defaults to None. + load_in_8bit (bool, optional): Whether to load the model parameters in 8-bit precision. Defaults to False. + bit_targeted_params (Optional[List[str]], optional): List of parameter names to convert to 8-bit precision. If + None and load_in_8bit is True, all kernels and embeddings are converted to 8-bit. Defaults to None. + **kwargs: Additional keyword arguments to pass to the model and config classes. + + Returns: + Tuple[EasyDeLFlaxPretrainedModel, dict]: A tuple containing the EasyDeL model and the loaded and sharded + model parameters. """ logger.debug(f"Downloading model config from {pretrained_model_name_or_path}") @@ -583,28 +624,41 @@ def from_pretrained( backend: Optional[str] = None, **kwargs ) -> EasyDeLPretrainedConfig: - """ - The from_pretrained function is a helper function that allows you to instantiate a model from the pretrained + """The from_pretrained function is a helper function that allows you to instantiate a model from the pretrained model repository. It takes as input the name of the model (e.g., 'bert-base-uncased') and returns an instance of the class corresponding to your model, with all weights loaded from disk. - :param cls: Create an instance of the class that called this function - :param pretrained_model_name_or_path: str: Identify the model in the huggingface model hub - :param sharding_axis_dims: Sequence[int]: Specify the dimension of each axis in the sharded model - :param sharding_axis_names: Sequence[str]: Specify the order of sharding - :param query_partition_spec: PartitionSpec: Specify the partitioning of the query tensor - :param generation_query_partition_spec: PartitionSpec: Specify the partitioning of the query tensor in + Args: + cls: Create an instance of the class that called this + function + pretrained_model_name_or_path: str: Identify the model in + the huggingface model hub + sharding_axis_dims: Sequence[int]: Specify the dimension of + each axis in the sharded model + sharding_axis_names: Sequence[str]: Specify the order of + sharding + query_partition_spec: PartitionSpec: Specify the + partitioning of the query tensor + generation_query_partition_spec: PartitionSpec: Specify the + partitioning of the query tensor in + key_partition_spec: PartitionSpec: Partition the key matrix + value_partition_spec: PartitionSpec: Specify the + partitioning of the value tensor + bias_partition_spec: PartitionSpec: Specify the Attention + Bias partition spec + generation_bias_partition_spec: PartitionSpec: Specify the + Attention Bias partition spec for generation + attention_partition_spec: PartitionSpec: Specify the + partitioning of the attention weights + shard_attention_computation: bool: whenever to use shard_map + for attention + backend: Optional[str]: backend to use for model + **kwargs: Pass additional arguments to the model and config + classes generation process - :param key_partition_spec: PartitionSpec: Partition the key matrix - :param value_partition_spec: PartitionSpec: Specify the partitioning of the value tensor - :param bias_partition_spec: PartitionSpec: Specify the Attention Bias partition spec - :param generation_bias_partition_spec: PartitionSpec: Specify the Attention Bias partition spec for generation - :param attention_partition_spec: PartitionSpec: Specify the partitioning of the attention weights - :param shard_attention_computation: bool: whenever to use shard_map for attention - :param backend: Optional[str]: backend to use for model - :param kwargs: Pass additional arguments to the model and config classes - :return: A Model Config + Returns: + A Model Config """ config = AutoConfig.from_pretrained(pretrained_model_name_or_path) diff --git a/src/python/easydel/modules/cohere/cohere_configuration.py b/src/python/easydel/modules/cohere/cohere_configuration.py index 467c07211..a7514995b 100644 --- a/src/python/easydel/modules/cohere/cohere_configuration.py +++ b/src/python/easydel/modules/cohere/cohere_configuration.py @@ -65,15 +65,17 @@ def __init__( ) def get_partition_rules(self, fully_sharded_data_parallel: bool = True): - """ - The get_partition_rules function is used to define the partitioning scheme for a model. + """The get_partition_rules function is used to define the partitioning scheme for a model. It returns a list of tuples, where each tuple contains two elements: 1) A regex string that matches the name of one or more parameters in the model. 2) A PartitionScheme object that defines how those parameters should be partitioned across devices. - :param fully_sharded_data_parallel: bool: Determine whether to partition the model fully or not - :return: A list of tuples + Args: + fully_sharded_data_parallel: bool: Determine whether to + partition the model fully or not + Returns: + A list of tuples """ return ( @@ -124,13 +126,16 @@ def add_jax_args( bits: Optional[int] = None, **kwargs, ): - """ - The add_jax_args function adds the following arguments to the Transformer class: - - :param self: Refer to the current object - :param tie_word_embeddings: bool: Tie the word embeddings to the decoder - :param gradient_checkpointing: str: Control the amount of memory used by jax - :param bits: Optional[int]: Determine the number of bits used in the quantization + """The add_jax_args function adds the following arguments to the Transformer class: + + Args: + self: Refer to the current object + tie_word_embeddings: bool: Tie the word embeddings to the + decoder + gradient_checkpointing: str: Control the amount of memory + used by jax + bits: Optional[int]: Determine the number of bits used in + the quantization """ self.tie_word_embeddings = tie_word_embeddings self.gradient_checkpointing = gradient_checkpointing diff --git a/src/python/easydel/modules/cohere/modelling_cohere_flax.py b/src/python/easydel/modules/cohere/modelling_cohere_flax.py index 31275dbda..f165ac808 100644 --- a/src/python/easydel/modules/cohere/modelling_cohere_flax.py +++ b/src/python/easydel/modules/cohere/modelling_cohere_flax.py @@ -201,33 +201,37 @@ def _merge_heads(self, hidden_states): @staticmethod def _transpose_sequence_head(query, key, value): - """ - The _transpose_sequence_head function transposes the query, key and value matrices. + """The _transpose_sequence_head function transposes the query, key and value matrices. - :param query: Get the attention weights for each of the heads - :param key: Determine the number of heads - :param value: Store the values of the input - :return: The transpose of the query, key and value matrices + Args: + query: Get the attention weights for each of the heads + key: Determine the number of heads + value: Store the values of the input + Returns: + The transpose of the query, key and value matrices """ return jnp.transpose(query, (0, 2, 1, 3)), jnp.transpose(key, (0, 2, 1, 3)), jnp.transpose(value, (0, 2, 1, 3)) def apply_rotary(self, batch_size, sequence_length, query, key, value, freq_cis, position_ids): - """ - The apply_rotary function is a modified version of the apply_attention function in the BertModel class. + """The apply_rotary function is a modified version of the apply_attention function in the BertModel class. The main difference is that it takes in an additional argument, freq_cis, which are used to calculate the rotary attention weights. The other differences are minor and mostly related to reshaping tensors. - :param self: Access variables that belong to the class - :param batch_size: Reshape the query, key and value tensors - :param sequence_length: Reshape the query, key and value tensors - :param query: Calculate the attention weights - :param key: Calculate the attention - :param value: Compute the attention weights - :param freq_cis: Calculate the frequency of each word in the vocabulary - :param position_ids: Identify the position of each token in the sequence - :return: A tuple of 3 tensors: query, key and value - + Args: + self: Access variables that belong to the class + batch_size: Reshape the query, key and value tensors + sequence_length: Reshape the query, key and value tensors + query: Calculate the attention weights + key: Calculate the attention + value: Compute the attention weights + freq_cis: Calculate the frequency of each word in the + vocabulary + position_ids: Identify the position of each token in the + sequence + + Returns: + A tuple of 3 tensors: query, key and value """ query = query.reshape( batch_size, @@ -269,25 +273,32 @@ def __call__( output_attentions: bool = False, fcm_mask=None, ): - """ - - The __call__ function is the main function of a JAX module. It defines how the module behaves when called + """The __call__ function is the main function of a JAX module. It defines how the module behaves when called with inputs. The __call__ function can be thought of as a "forward pass" through the model, and it should return all outputs that are needed for training or inference. - :param self: Access variables that belong to the class - :param hidden_states: chex.Array: Pass the hidden states of the previous layer - :param freq_cis: Tuple[chex.Array, chex.Array],: Pass in the frequency coefficients for each position - :param attention_mask: chex.Array: Mask out certain tokens in the input sequence - :param position_ids: chex.Array: Determine the position of each token in a sequence - :param causal_mask: chex.Array: Mask out the future tokens in the decoder - :param deterministic: bool: Determine whether to use dropout or not - :param init_cache: bool: Initialize the cache - :param output_attentions: bool: Determine whether to return the attention weights or not - :param fcm_mask: Mask out the attention weights between the input and output tokens + Args: + self: Access variables that belong to the class + hidden_states: chex.Array: Pass the hidden states of the + previous layer + freq_cis: Tuple[chex.Array, chex.Array],: Pass in the + frequency coefficients for each position + attention_mask: chex.Array: Mask out certain tokens in the + input sequence + position_ids: chex.Array: Determine the position of each + token in a sequence + causal_mask: chex.Array: Mask out the future tokens in the + decoder + deterministic: bool: Determine whether to use dropout or not + init_cache: bool: Initialize the cache + output_attentions: bool: Determine whether to return the + attention weights or not + fcm_mask: Mask out the attention weights between the input + and output tokens :param : Determine if the attention is causal or not - :return: A tuple of two arrays + Returns: + A tuple of two arrays """ batch_size, sequence_length = hidden_states.shape[:2] ( @@ -455,16 +466,18 @@ def setup(self) -> None: ) def __call__(self, x: jnp.ndarray, deterministic: bool = True) -> jnp.ndarray: - """ - The __call__ function is the main function of a class. + """The __call__ function is the main function of a class. It is called when an instance of the class (an object) is invoked as a function, i.e., obj(arguments). The __call__ method enables instances of a class to be called like standard Python functions. - :param self: Represent the instance of the class - :param x: jnp.ndarray: Pass in the input to the layer - :param deterministic: bool: Determine whether to use dropout - :return: A tensor that is the result of applying a dropout function to x + Args: + self: Represent the instance of the class + x: jnp.ndarray: Pass in the input to the layer + deterministic: bool: Determine whether to use dropout + Returns: + A tensor that is the result of applying a dropout function + to x """ x = self.down_proj(jax.nn.silu(self.gate_proj(x)) * self.up_proj(x)) return x @@ -527,25 +540,32 @@ def __call__( output_attentions: bool = False, fcm_mask: Optional[jnp.ndarray] = None, ): - """ - The __call__ function is the main function of a TransformerEncoderLayer. + """The __call__ function is the main function of a TransformerEncoderLayer. It takes in hidden states, frequency-domain inputs, and masks as input. It then applies self-attention to the hidden states using those inputs and returns an output tensor with shape (batch_size, sequence_length, model_dim). - :param self: Refer to the class instance itself - :param hidden_states: chex.Array: Pass in the hidden state of the previous layer - :param freq_cis: Tuple[chex.Array, chex.Array],: Pass in the frequency information - :param attention_mask: chex.Array: Mask out the attention weights for padding tokens - :param position_ids: chex.Array: Determine the position of each token in the sequence - :param causal_mask: chex.Array: Mask the attention weights - :param deterministic: bool: Control whether the dropout is applied or not - :param init_cache: bool: Initialize the cache in the attention layer - :param output_attentions: bool: Return the attention weights - :param fcm_mask: Optional[jnp.ndarray]: Mask the self-attention + Args: + self: Refer to the class instance itself + hidden_states: chex.Array: Pass in the hidden state of the + previous layer + freq_cis: Tuple[chex.Array, chex.Array],: Pass in the + frequency information + attention_mask: chex.Array: Mask out the attention weights + for padding tokens + position_ids: chex.Array: Determine the position of each + token in the sequence + causal_mask: chex.Array: Mask the attention weights + deterministic: bool: Control whether the dropout is applied + or not + init_cache: bool: Initialize the cache in the attention + layer + output_attentions: bool: Return the attention weights + fcm_mask: Optional[jnp.ndarray]: Mask the self-attention :param : Control the dropout in the self attention layer - :return: A tuple of two items + Returns: + A tuple of two items """ residual = hidden_states hidden_states = self.input_layernorm(hidden_states) @@ -597,37 +617,42 @@ def __init__( _do_init: bool = True, **kwargs, ): - """ - The __init__ function is called when the class is instantiated. + """The __init__ function is called when the class is instantiated. It sets up the instance of the class, and defines what happens when it's created. The __init__ function can take arguments, but self is always required (it refers to the instance of the object). - - :param self: Refer to the object itself - :param config: CohereConfig: Pass the configuration to the module - :param input_shape: Tuple: Specify the shape of the input to the model - :param seed: int: Set the seed for random number generation - :param dtype: jnp.dtype: Specify the data type of the input - :param _do_init: bool: Control whether the module is initialized or not - :param kwargs: Pass in any additional parameters that the module_class might need + Args: + self: Refer to the object itself + config: CohereConfig: Pass the configuration to the module + input_shape: Tuple: Specify the shape of the input to the + model + seed: int: Set the seed for random number generation + dtype: jnp.dtype: Specify the data type of the input + _do_init: bool: Control whether the module is initialized or + not + **kwargs: Pass in any additional parameters that the + module_class might need :param : Specify the number of layers in the network - :return: The super() of the class + Returns: + The super() of the class """ module = self.module_class(config=config, dtype=dtype, **kwargs) super().__init__(config, module, input_shape=input_shape, seed=seed, dtype=dtype, _do_init=_do_init) def init_weights(self, rng: jax.random.PRNGKey, input_shape: Tuple, params: FrozenDict = None) -> FrozenDict: - """ - The init_weights function is used to initialize the weights of a model. + """The init_weights function is used to initialize the weights of a model. - :param self: Access variables that belong to the class - :param rng: jax.random.PRNGKey: Initialize the weights of the model - :param input_shape: Tuple: Specify the shape of the input tensor - :param params: FrozenDict: Pass in the parameters of a pre-trained model - :return: A frozendict of parameters + Args: + self: Access variables that belong to the class + rng: jax.random.PRNGKey: Initialize the weights of the model + input_shape: Tuple: Specify the shape of the input tensor + params: FrozenDict: Pass in the parameters of a pre-trained + model + Returns: + A frozendict of parameters """ input_ids = jnp.zeros(input_shape, dtype="i4") attention_mask = jnp.ones_like(input_ids) @@ -666,17 +691,18 @@ def init_weights(self, rng: jax.random.PRNGKey, input_shape: Tuple, params: Froz return random_params def init_cache(self, batch_size, max_length): - """ - The init_cache function is used to initialize the cache for a given batch size and sequence length. + """The init_cache function is used to initialize the cache for a given batch size and sequence length. The cache is a dictionary that contains all the intermediate states from each layer in the model. This allows us to run inference on multiple batches without having to re-run forward passes through every layer in the model, which would be very slow. - :param self: Access the module - :param batch_size: Define the batch size of the input tensors - :param max_length: Set the length of the input sequence - :return: A dictionary with the following keys: + Args: + self: Access the module + batch_size: Define the batch size of the input tensors + max_length: Set the length of the input sequence + Returns: + A dictionary with the following keys: """ input_ids = jnp.ones((batch_size, max_length)) attention_mask = jnp.ones_like(input_ids) @@ -704,27 +730,36 @@ def __call__( add_params_field: bool = False, **kwargs ): - """ - The __call__ function is the main function of a JAX module. + """The __call__ function is the main function of a JAX module. It takes in inputs and returns outputs, but it also has some other important features: - It can take in mutable state (e.g., past_key_values) that will be updated during the call and returned at the end. - It can take in random number generators (rngs) that are used to generate random numbers for dropout or sampling operations. - :param self: Represent the instance of the class - :param input_ids: chex.Array: Pass in the input tokens - :param attention_mask: chex.Array: Mask out certain tokens in the input - :param position_ids: chex.Array: Create the positional embeddings - :param params: dict: Pass in the parameters of the model - :param past_key_values: dict: Pass in the past key values from a previous call to __call__ - :param dropout_rng: jax.random.PRNGKey: Make sure that the dropout is applied in a random way - :param train: bool: Determine whether to use dropout or not - :param output_attentions: Optional[bool]: Determine whether to return the attention weights - :param output_hidden_states: Optional[bool]: Return the hidden states of all layers - :param return_dict: Optional[bool]: Determine whether to return a dictionary or not - :param extra_embedding: Optional[Union[jnp.ndarray,None]]: Pass in the embedding for the input_ids - :param add_params_field: bool: Add the params field to the inputs dictionary - :return: A tuple of the following: - + Args: + self: Represent the instance of the class + input_ids: chex.Array: Pass in the input tokens + attention_mask: chex.Array: Mask out certain tokens in the + input + position_ids: chex.Array: Create the positional embeddings + params: dict: Pass in the parameters of the model + past_key_values: dict: Pass in the past key values from a + previous call to __call__ + dropout_rng: jax.random.PRNGKey: Make sure that the dropout + is applied in a random way + train: bool: Determine whether to use dropout or not + output_attentions: Optional[bool]: Determine whether to + return the attention weights + output_hidden_states: Optional[bool]: Return the hidden + states of all layers + return_dict: Optional[bool]: Determine whether to return a + dictionary or not + extra_embedding: Optional[Union[jnp.ndarray,None]]: Pass in + the embedding for the input_ids + add_params_field: bool: Add the params field to the inputs + dictionary + + Returns: + A tuple of the following: """ output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions output_hidden_states = ( @@ -822,27 +857,35 @@ def __call__( output_hidden_states: bool = False, return_dict: bool = True, ): - """ - The __call__ function is the main function of a JAX nn.Module. + """The __call__ function is the main function of a JAX nn.Module. It defines how the module behaves when called as a function, and it's what you'll use to call your model in training loops or inference scripts. The __call__ method should take all inputs that are necessary for computing outputs from the module, and return all outputs that are computed by this module. - :param self: Represent the instance of the class - :param hidden_states: chex.Array: Pass the input tensor to the encoder - :param freq_cis: Tuple[chex.Array, chex.Array],: Pass in the frequency of each token - :param attention_mask: chex.Array: Mask out certain tokens in the input sequence - :param position_ids: chex.Array: Specify the position of each token in a sequence - :param causal_mask: chex.Array: Mask the attention weights - :param deterministic: bool: Determine whether the model is in training or evaluation mode - :param init_cache: bool: Initialize the cache for each layer - :param output_attentions: bool: Determine whether to output the attention weights - :param output_hidden_states: bool: Determine whether to return the hidden states of each layer - :param return_dict: bool: Return a dictionary of the outputs + Args: + self: Represent the instance of the class + hidden_states: chex.Array: Pass the input tensor to the + encoder + freq_cis: Tuple[chex.Array, chex.Array],: Pass in the + frequency of each token + attention_mask: chex.Array: Mask out certain tokens in the + input sequence + position_ids: chex.Array: Specify the position of each token + in a sequence + causal_mask: chex.Array: Mask the attention weights + deterministic: bool: Determine whether the model is in + training or evaluation mode + init_cache: bool: Initialize the cache for each layer + output_attentions: bool: Determine whether to output the + attention weights + output_hidden_states: bool: Determine whether to return the + hidden states of each layer + return_dict: bool: Return a dictionary of the outputs :param : Determine whether to use the forgetful causal mask - :return: A tuple of 3 values + Returns: + A tuple of 3 values """ all_attentions = () if output_attentions else None all_hidden_states = () if output_hidden_states else None @@ -949,26 +992,33 @@ def __call__( return_dict: bool = True, extra_embedding: Optional[Union[jnp.ndarray, None]] = None ): - """ - The __call__ function is the main function of a Flax model. It takes in input_ids, attention_mask, and position_ids + """The __call__ function is the main function of a Flax model. It takes in input_ids, attention_mask, and position_ids and returns the output of the model. The __call__ function also has optional arguments that can be used to control the behavior of the model (e.g., deterministic=True). These optional arguments are passed as keyword arguments when calling a Flax model. - :param self: Represent the instance of the class - :param input_ids: chex.Array: Pass in the input token ids - :param attention_mask: chex.Array: Mask out the padding tokens - :param position_ids: chex.Array: Indicate the position of each token in a sequence - :param deterministic: bool: Control whether dropout is applied or not - :param inputs_embeds: chex.Array: Pass in the embeddings of the input tokens - :param init_cache: bool: Initialize the cache - :param output_attentions: bool: Determine whether to return the attentions or not - :param output_hidden_states: bool: Determine whether to return hidden states - :param return_dict: bool: Return a dictionary of the output or not - :param extra_embedding: Optional[Union[jnp.ndarray: Pass in the embedding of the - :param None]]: Pass in the extra embedding - :return: A tuple of: - + Args: + self: Represent the instance of the class + input_ids: chex.Array: Pass in the input token ids + attention_mask: chex.Array: Mask out the padding tokens + position_ids: chex.Array: Indicate the position of each + token in a sequence + deterministic: bool: Control whether dropout is applied or + not + inputs_embeds: chex.Array: Pass in the embeddings of the + input tokens + init_cache: bool: Initialize the cache + output_attentions: bool: Determine whether to return the + attentions or not + output_hidden_states: bool: Determine whether to return + hidden states + return_dict: bool: Return a dictionary of the output or not + extra_embedding: Optional[Union[jnp.ndarray: Pass in the + embedding of the + None]]: Pass in the extra embedding + + Returns: + A tuple of: """ if inputs_embeds is None: inputs_embeds = self.embed_tokens(input_ids.astype("i4")) @@ -1058,22 +1108,27 @@ def __call__( return_dict: bool = True, extra_embedding: Optional[Union[jnp.ndarray, None]] = None ): - """ - The __call__ function is the main function of a Flax module. It takes in inputs and returns outputs. - - :param self: Refer to the object itself - :param input_ids: chex.Array: Pass the input token ids to the model - :param attention_mask: chex.Array: Mask out the padding tokens - :param position_ids: chex.Array: Specify the position of each token in the input sequence - :param deterministic: bool: Control whether the model is trained or not - :param init_cache: bool: Initialize the cache for the decoder - :param output_attentions: bool: Return the attention weights - :param output_hidden_states: bool: Determine whether to return the hidden states - :param return_dict: bool: Return a dictionary of the outputs or not - :param extra_embedding: Optional[Union[jnp.ndarray: Pass in the embedding of the word that we want to predict - :param None]]: Pass in the extra embedding - :return: The logits and the hidden states - + """The __call__ function is the main function of a Flax module. It takes in inputs and returns outputs. + + Args: + self: Refer to the object itself + input_ids: chex.Array: Pass the input token ids to the model + attention_mask: chex.Array: Mask out the padding tokens + position_ids: chex.Array: Specify the position of each token + in the input sequence + deterministic: bool: Control whether the model is trained or + not + init_cache: bool: Initialize the cache for the decoder + output_attentions: bool: Return the attention weights + output_hidden_states: bool: Determine whether to return the + hidden states + return_dict: bool: Return a dictionary of the outputs or not + extra_embedding: Optional[Union[jnp.ndarray: Pass in the + embedding of the word that we want to predict + None]]: Pass in the extra embedding + + Returns: + The logits and the hidden states """ batch_size, seq_length = input_ids.shape if attention_mask is None: diff --git a/src/python/easydel/modules/dbrx/modelling_dbrx_flax.py b/src/python/easydel/modules/dbrx/modelling_dbrx_flax.py index a21b1ce36..6566944a0 100644 --- a/src/python/easydel/modules/dbrx/modelling_dbrx_flax.py +++ b/src/python/easydel/modules/dbrx/modelling_dbrx_flax.py @@ -181,33 +181,37 @@ def _merge_heads(self, hidden_states): @staticmethod def _transpose_sequence_head(query, key, value): - """ - The _transpose_sequence_head function transposes the query, key and value matrices. + """The _transpose_sequence_head function transposes the query, key and value matrices. - :param query: Get the attention weights for each of the heads - :param key: Determine the number of heads - :param value: Store the values of the input - :return: The transpose of the query, key and value matrices + Args: + query: Get the attention weights for each of the heads + key: Determine the number of heads + value: Store the values of the input + Returns: + The transpose of the query, key and value matrices """ return jnp.transpose(query, (0, 2, 1, 3)), jnp.transpose(key, (0, 2, 1, 3)), jnp.transpose(value, (0, 2, 1, 3)) def apply_rotary(self, batch_size, sequence_length, query, key, value, freq_cis, position_ids): - """ - The apply_rotary function is a modified version of the apply_attention function in the BertModel class. + """The apply_rotary function is a modified version of the apply_attention function in the BertModel class. The main difference is that it takes in an additional argument, freq_cis, which are used to calculate the rotary attention weights. The other differences are minor and mostly related to reshaping tensors. - :param self: Access variables that belong to the class - :param batch_size: Reshape the query, key and value tensors - :param sequence_length: Reshape the query, key and value tensors - :param query: Calculate the attention weights - :param key: Calculate the attention - :param value: Compute the attention weights - :param freq_cis: Calculate the frequency of each word in the vocabulary - :param position_ids: Identify the position of each token in the sequence - :return: A tuple of 3 tensors: query, key and value - + Args: + self: Access variables that belong to the class + batch_size: Reshape the query, key and value tensors + sequence_length: Reshape the query, key and value tensors + query: Calculate the attention weights + key: Calculate the attention + value: Compute the attention weights + freq_cis: Calculate the frequency of each word in the + vocabulary + position_ids: Identify the position of each token in the + sequence + + Returns: + A tuple of 3 tensors: query, key and value """ query = query.reshape( batch_size, @@ -249,24 +253,32 @@ def __call__( output_attentions: bool = False, fcm_mask=None, ): - """ - The __call__ function is the main function of a JAX module. It defines how the module behaves when called + """The __call__ function is the main function of a JAX module. It defines how the module behaves when called with inputs. The __call__ function can be thought of as a "forward pass" through the model, and it should return all outputs that are needed for training or inference. - :param self: Access variables that belong to the class - :param hidden_states: chex.Array: Pass the hidden states of the previous layer - :param freq_cis: Tuple[chex.Array, chex.Array],: Pass in the frequency coefficients for each position - :param attention_mask: chex.Array: Mask out certain tokens in the input sequence - :param position_ids: chex.Array: Determine the position of each token in a sequence - :param causal_mask: chex.Array: Mask out the future tokens in the decoder - :param deterministic: bool: Determine whether to use dropout or not - :param init_cache: bool: Initialize the cache - :param output_attentions: bool: Determine whether to return the attention weights or not - :param fcm_mask: Mask out the attention weights between the input and output tokens + Args: + self: Access variables that belong to the class + hidden_states: chex.Array: Pass the hidden states of the + previous layer + freq_cis: Tuple[chex.Array, chex.Array],: Pass in the + frequency coefficients for each position + attention_mask: chex.Array: Mask out certain tokens in the + input sequence + position_ids: chex.Array: Determine the position of each + token in a sequence + causal_mask: chex.Array: Mask out the future tokens in the + decoder + deterministic: bool: Determine whether to use dropout or not + init_cache: bool: Initialize the cache + output_attentions: bool: Determine whether to return the + attention weights or not + fcm_mask: Mask out the attention weights between the input + and output tokens :param : Determine if the attention is causal or not - :return: A tuple of two arrays + Returns: + A tuple of two arrays """ batch_size, sequence_length = hidden_states.shape[:2] qkv_states = self.Wqkv(hidden_states) @@ -800,17 +812,21 @@ def init_weights( input_shape: Tuple, params: FrozenDict = None ) -> FrozenDict: - """ - The init_weights function is used to initialize the weights of a model. + """The init_weights function is used to initialize the weights of a model. It takes in a rng, which is a random number generator key that can be used to generate random numbers. The input_shape parameter specifies the shape of the inputs that will be fed into this model. The params parameter allows you to pass in pre-trained weights for your model, if you have them available. - :param self: Access variables that belong to the class - :param rng: jax.random.PRNGKey: Initialize the weights of the model - :param input_shape: Tuple: Initialize the input_ids, attention_mask and position_ids - :param params: flax.core.FrozenDict: Pass in the parameters of a pre-trained model - :return: A frozendict of parameters + Args: + self: Access variables that belong to the class + rng: jax.random.PRNGKey: Initialize the weights of the model + input_shape: Tuple: Initialize the input_ids, attention_mask + and position_ids + params: flax.core.FrozenDict: Pass in the parameters of a + pre-trained model + + Returns: + A frozendict of parameters """ self.config.initialization_of_moe = True @@ -884,28 +900,35 @@ def __call__( add_params_field: bool = False, **kwargs ): - """ - The __call__ function is the main function of a JAX module. + """The __call__ function is the main function of a JAX module. It takes as input: - The parameters of the model (self.params) - The inputs to the model (input_ids, attention_mask, position_ids) - Whether we are training (train=True/False) and whether we want to return all hidden states and attentions weights at each layer in addition to just the last layer output (output_hidden_states=True/False). - :param self: Represent the instance of the class - :param input_ids: Pass the input sequence to the model - :param attention_mask: Mask out the padding tokens - :param position_ids: Specify the position of each token in the sequence - :param params: dict: Pass in the parameters of the model - :param past_key_values: dict: Pass the past key values to the model - :param dropout_rng: jax.random.PRNGKey: Pass in a random number generator key to the model - :param train: bool: Determine whether to use dropout or not - :param output_attentions: Optional[bool]: Determine whether to return the attention weights - :param output_hidden_states: Optional[bool]: Determine whether to return the hidden states of all layers - :param return_dict: Optional[bool]: Return a dictionary of the outputs - :param add_params_field: bool: Add a params field to the inputs dictionary - :return: A tuple of (last_hidden_state, past_key_values) - + Args: + self: Represent the instance of the class + input_ids: Pass the input sequence to the model + attention_mask: Mask out the padding tokens + position_ids: Specify the position of each token in the + sequence + params: dict: Pass in the parameters of the model + past_key_values: dict: Pass the past key values to the model + dropout_rng: jax.random.PRNGKey: Pass in a random number + generator key to the model + train: bool: Determine whether to use dropout or not + output_attentions: Optional[bool]: Determine whether to + return the attention weights + output_hidden_states: Optional[bool]: Determine whether to + return the hidden states of all layers + return_dict: Optional[bool]: Return a dictionary of the + outputs + add_params_field: bool: Add a params field to the inputs + dictionary + + Returns: + A tuple of (last_hidden_state, past_key_values) """ output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions diff --git a/src/python/easydel/modules/deepseek_v2/deepseek_configuration.py b/src/python/easydel/modules/deepseek_v2/deepseek_configuration.py index b87c2babc..e38246556 100644 --- a/src/python/easydel/modules/deepseek_v2/deepseek_configuration.py +++ b/src/python/easydel/modules/deepseek_v2/deepseek_configuration.py @@ -1,206 +1,218 @@ -import warnings -from typing import Optional, Dict, Union - -from jax.sharding import PartitionSpec - -from ..easydel_modelling_utils import EasyDeLPretrainedConfig - - -class DeepseekV2Config(EasyDeLPretrainedConfig): - model_type: str = "deepseek_v2" - - def __init__( - self, - vocab_size=102400, - hidden_size=4096, - intermediate_size=11008, - moe_intermediate_size=1407, - num_hidden_layers=30, - num_attention_heads=32, - num_key_value_heads=32, - n_shared_experts=None, - n_routed_experts=None, - ep_size=1, - routed_scaling_factor=1.0, - kv_lora_rank=512, - q_lora_rank=1536, - qk_rope_head_dim=64, - v_head_dim=128, - qk_nope_head_dim=128, - topk_method='gready', - n_group=None, - topk_group=None, - num_experts_per_tok=None, - moe_layer_freq=1, - first_k_dense_replace=0, - norm_topk_prob=False, - scoring_func='softmax', - aux_loss_alpha=0.001, - seq_aux=True, - hidden_act="silu", - max_position_embeddings=2048, - initializer_range=0.02, - rms_norm_eps=1e-6, - use_cache=True, - pad_token_id=None, - bos_token_id=100000, - eos_token_id=100001, - pretraining_tp=1, - tie_word_embeddings=False, - rope_theta=10000.0, - attention_bias=False, - attention_dropout=0.0, - gradient_checkpointing: str = "nothing_saveable", - use_scan_mlp: bool = False, - scan_mlp_chunk_size: int = 1024, - bits: Optional[int] = None, - rope_scaling: Dict[str, Union[str, float]] = None, - **kwargs, - ): - warnings.warn( - "`DeepseekV2` is still in beta mode.", - UserWarning - ) - self.vocab_size = vocab_size - self.max_position_embeddings = max_position_embeddings - self.hidden_size = hidden_size - self.intermediate_size = intermediate_size - self.moe_intermediate_size = moe_intermediate_size - self.num_hidden_layers = num_hidden_layers - self.num_attention_heads = num_attention_heads - self.n_shared_experts = n_shared_experts - self.n_routed_experts = n_routed_experts - self.ep_size = ep_size - self.routed_scaling_factor = routed_scaling_factor - self.kv_lora_rank = kv_lora_rank - self.q_lora_rank = q_lora_rank - self.qk_rope_head_dim = qk_rope_head_dim - self.v_head_dim = v_head_dim - self.qk_nope_head_dim = qk_nope_head_dim - self.topk_method = topk_method - self.n_group = n_group - self.topk_group = topk_group - self.num_experts_per_tok = num_experts_per_tok - self.moe_layer_freq = moe_layer_freq - self.first_k_dense_replace = first_k_dense_replace - self.norm_topk_prob = norm_topk_prob - self.scoring_func = scoring_func - self.aux_loss_alpha = aux_loss_alpha - self.seq_aux = seq_aux - # for backward compatibility - if num_key_value_heads is None: - num_key_value_heads = num_attention_heads - - self.num_key_value_heads = num_key_value_heads - self.hidden_act = hidden_act - self.initializer_range = initializer_range - self.rms_norm_eps = rms_norm_eps - self.pretraining_tp = pretraining_tp - self.use_cache = use_cache - self.rope_theta = rope_theta - self.rope_scaling = rope_scaling - self.attention_bias = attention_bias - self.attention_dropout = attention_dropout - self.gradient_checkpointing = gradient_checkpointing - super().__init__( - pad_token_id=pad_token_id, - bos_token_id=bos_token_id, - eos_token_id=eos_token_id, - tie_word_embeddings=tie_word_embeddings, - use_scan_mlp=use_scan_mlp, - scan_mlp_chunk_size=scan_mlp_chunk_size, - bits=bits, - **kwargs, - ) - - def get_partition_rules(self, fully_sharded_data_parallel: bool = True): - """ - The get_partition_rules function is used to define the partitioning scheme for a model. - It returns a list of tuples, where each tuple contains two elements: - 1) A regex string that matches the name of one or more parameters in the model. - 2) A PartitionScheme object that defines how those parameters should be partitioned. - - :param fully_sharded_data_parallel: bool: Determine whether to use the fully_sharded_data_parallel partitioning scheme or not - :return: A list of tuples - - """ - return ( - - ("model/embed_tokens/embedding", PartitionSpec("sp", "fsdp")), - - ("self_attn/(q_proj|k_proj|v_proj)/kernel", PartitionSpec(("fsdp", "sp"), "tp")), - ("self_attn/o_proj/kernel", PartitionSpec("tp", ("sp", "fsdp"))), - - ("w1/kernel", PartitionSpec(("fsdp", "sp"))), - ("w2/kernel", PartitionSpec(("fsdp", "sp"))), - ("w3/kernel", PartitionSpec(("fsdp", "sp"))), - ("gate/kernel", PartitionSpec(("fsdp", "sp"))), - - ("input_layernorm/kernel", PartitionSpec(None)), - ("post_attention_layernorm/kernel", PartitionSpec(None)), - - ("model/norm/kernel", PartitionSpec(None)), - ("lm_head/kernel", PartitionSpec("fsdp", "sp")), - (".*", PartitionSpec(None)), - ) if not fully_sharded_data_parallel else ( - ("model/embed_tokens/embedding", PartitionSpec(("fsdp", "sp"))), - - ("self_attn/(q_proj|k_proj|v_proj)/kernel", PartitionSpec(("fsdp", "sp"), "tp")), - ("self_attn/o_proj/kernel", PartitionSpec("tp", ("sp", "fsdp"))), - - ("w1/kernel", PartitionSpec(("fsdp", "sp"))), - ("w2/kernel", PartitionSpec(("fsdp", "sp"))), - ("w3/kernel", PartitionSpec(("fsdp", "sp"))), - ("gate/kernel", PartitionSpec(("fsdp", "sp"))), - - ("input_layernorm/kernel", PartitionSpec(None)), - ("post_attention_layernorm/kernel", PartitionSpec(None)), - - ("model/norm/kernel", PartitionSpec(None)), - ("lm_head/kernel", PartitionSpec(("fsdp", "sp"))), - (".*", PartitionSpec(("fsdp", "sp"))), - ) - - def add_jax_args( - self, - gradient_checkpointing: str = "nothing_saveable", - use_scan_mlp: bool = False, - scan_mlp_chunk_size: int = 1024, - bits: Optional[int] = None, - rope_scaling: Dict[str, Union[str, float]] = None, - **kwargs, - ): - """ - The add_jax_args function adds the following arguments to the model: - - :param self: Bind the attributes and methods of a class to an instance of that class - :param gradient_checkpointing: str: Determine whether to use gradient checkpointing - :param use_scan_mlp: bool: Determine whether to use the scan_mlp function or not - :param scan_mlp_chunk_size: int: Chunk the input to the mlp - :param number_rep_kv: int: Control the number of times that the key and value vectors are repeated - :param bits: Optional[int]: Specify the number of bits to use for quantization - :param attention_dropout: float: Set the dropout rate for the attention layer - :param attention_bias: bool: when ever to use attention_bias - :param initialization_of_moe: bool: initialization of moe needs to disable some dynamic part's this boolean - variable will turn them off. - :param rope_scaling: Dict[str, Union[str, float]]: rope_scaling for rope - :return: A tuple of the following: - - """ - self.attention_dropout = attention_dropout - self.attention_bias = attention_bias - self.rope_scaling = rope_scaling - self.number_rep_kv = number_rep_kv - self.gradient_checkpointing = gradient_checkpointing - self.use_scan_mlp = use_scan_mlp - self.scan_mlp_chunk_size = scan_mlp_chunk_size - self.bits = bits - self.initialization_of_moe = initialization_of_moe - - @staticmethod - def get_weight_decay_exclusions(): - return tuple() - - @staticmethod - def rng_keys(): - return 'params', 'dropout', 'fcm' +import warnings +from typing import Optional, Dict, Union + +from jax.sharding import PartitionSpec + +from ..easydel_modelling_utils import EasyDeLPretrainedConfig + + +class DeepseekV2Config(EasyDeLPretrainedConfig): + model_type: str = "deepseek_v2" + + def __init__( + self, + vocab_size=102400, + hidden_size=4096, + intermediate_size=11008, + moe_intermediate_size=1407, + num_hidden_layers=30, + num_attention_heads=32, + num_key_value_heads=32, + n_shared_experts=None, + n_routed_experts=None, + ep_size=1, + routed_scaling_factor=1.0, + kv_lora_rank=512, + q_lora_rank=1536, + qk_rope_head_dim=64, + v_head_dim=128, + qk_nope_head_dim=128, + topk_method='gready', + n_group=None, + topk_group=None, + num_experts_per_tok=None, + moe_layer_freq=1, + first_k_dense_replace=0, + norm_topk_prob=False, + scoring_func='softmax', + aux_loss_alpha=0.001, + seq_aux=True, + hidden_act="silu", + max_position_embeddings=2048, + initializer_range=0.02, + rms_norm_eps=1e-6, + use_cache=True, + pad_token_id=None, + bos_token_id=100000, + eos_token_id=100001, + pretraining_tp=1, + tie_word_embeddings=False, + rope_theta=10000.0, + attention_bias=False, + attention_dropout=0.0, + gradient_checkpointing: str = "nothing_saveable", + use_scan_mlp: bool = False, + scan_mlp_chunk_size: int = 1024, + bits: Optional[int] = None, + rope_scaling: Dict[str, Union[str, float]] = None, + **kwargs, + ): + warnings.warn( + "`DeepseekV2` is still in beta mode.", + UserWarning + ) + self.vocab_size = vocab_size + self.max_position_embeddings = max_position_embeddings + self.hidden_size = hidden_size + self.intermediate_size = intermediate_size + self.moe_intermediate_size = moe_intermediate_size + self.num_hidden_layers = num_hidden_layers + self.num_attention_heads = num_attention_heads + self.n_shared_experts = n_shared_experts + self.n_routed_experts = n_routed_experts + self.ep_size = ep_size + self.routed_scaling_factor = routed_scaling_factor + self.kv_lora_rank = kv_lora_rank + self.q_lora_rank = q_lora_rank + self.qk_rope_head_dim = qk_rope_head_dim + self.v_head_dim = v_head_dim + self.qk_nope_head_dim = qk_nope_head_dim + self.topk_method = topk_method + self.n_group = n_group + self.topk_group = topk_group + self.num_experts_per_tok = num_experts_per_tok + self.moe_layer_freq = moe_layer_freq + self.first_k_dense_replace = first_k_dense_replace + self.norm_topk_prob = norm_topk_prob + self.scoring_func = scoring_func + self.aux_loss_alpha = aux_loss_alpha + self.seq_aux = seq_aux + # for backward compatibility + if num_key_value_heads is None: + num_key_value_heads = num_attention_heads + + self.num_key_value_heads = num_key_value_heads + self.hidden_act = hidden_act + self.initializer_range = initializer_range + self.rms_norm_eps = rms_norm_eps + self.pretraining_tp = pretraining_tp + self.use_cache = use_cache + self.rope_theta = rope_theta + self.rope_scaling = rope_scaling + self.attention_bias = attention_bias + self.attention_dropout = attention_dropout + self.gradient_checkpointing = gradient_checkpointing + super().__init__( + pad_token_id=pad_token_id, + bos_token_id=bos_token_id, + eos_token_id=eos_token_id, + tie_word_embeddings=tie_word_embeddings, + use_scan_mlp=use_scan_mlp, + scan_mlp_chunk_size=scan_mlp_chunk_size, + bits=bits, + **kwargs, + ) + + def get_partition_rules(self, fully_sharded_data_parallel: bool = True): + """The get_partition_rules function is used to define the partitioning scheme for a model. + It returns a list of tuples, where each tuple contains two elements: + 1) A regex string that matches the name of one or more parameters in the model. + 2) A PartitionScheme object that defines how those parameters should be partitioned. + + Args: + fully_sharded_data_parallel: bool: Determine whether to use + the fully_sharded_data_parallel partitioning scheme or + not + + Returns: + A list of tuples + """ + return ( + + ("model/embed_tokens/embedding", PartitionSpec("sp", "fsdp")), + + ("self_attn/(q_proj|k_proj|v_proj)/kernel", PartitionSpec(("fsdp", "sp"), "tp")), + ("self_attn/o_proj/kernel", PartitionSpec("tp", ("sp", "fsdp"))), + + ("w1/kernel", PartitionSpec(("fsdp", "sp"))), + ("w2/kernel", PartitionSpec(("fsdp", "sp"))), + ("w3/kernel", PartitionSpec(("fsdp", "sp"))), + ("gate/kernel", PartitionSpec(("fsdp", "sp"))), + + ("input_layernorm/kernel", PartitionSpec(None)), + ("post_attention_layernorm/kernel", PartitionSpec(None)), + + ("model/norm/kernel", PartitionSpec(None)), + ("lm_head/kernel", PartitionSpec("fsdp", "sp")), + (".*", PartitionSpec(None)), + ) if not fully_sharded_data_parallel else ( + ("model/embed_tokens/embedding", PartitionSpec(("fsdp", "sp"))), + + ("self_attn/(q_proj|k_proj|v_proj)/kernel", PartitionSpec(("fsdp", "sp"), "tp")), + ("self_attn/o_proj/kernel", PartitionSpec("tp", ("sp", "fsdp"))), + + ("w1/kernel", PartitionSpec(("fsdp", "sp"))), + ("w2/kernel", PartitionSpec(("fsdp", "sp"))), + ("w3/kernel", PartitionSpec(("fsdp", "sp"))), + ("gate/kernel", PartitionSpec(("fsdp", "sp"))), + + ("input_layernorm/kernel", PartitionSpec(None)), + ("post_attention_layernorm/kernel", PartitionSpec(None)), + + ("model/norm/kernel", PartitionSpec(None)), + ("lm_head/kernel", PartitionSpec(("fsdp", "sp"))), + (".*", PartitionSpec(("fsdp", "sp"))), + ) + + def add_jax_args( + self, + gradient_checkpointing: str = "nothing_saveable", + use_scan_mlp: bool = False, + scan_mlp_chunk_size: int = 1024, + bits: Optional[int] = None, + rope_scaling: Dict[str, Union[str, float]] = None, + **kwargs, + ): + """The add_jax_args function adds the following arguments to the model: + + Args: + self: Bind the attributes and methods of a class to an + instance of that class + gradient_checkpointing: str: Determine whether to use + gradient checkpointing + use_scan_mlp: bool: Determine whether to use the scan_mlp + function or not + scan_mlp_chunk_size: int: Chunk the input to the mlp + number_rep_kv: int: Control the number of times that the key + and value vectors are repeated + bits: Optional[int]: Specify the number of bits to use for + quantization + attention_dropout: float: Set the dropout rate for the + attention layer + attention_bias: bool: when ever to use attention_bias + initialization_of_moe: bool: initialization of moe needs to + disable some dynamic part's this boolean variable will + turn them off. + rope_scaling: Dict[str, Union[str, float]]: rope_scaling for + rope + + Returns: + A tuple of the following: + """ + self.attention_dropout = attention_dropout + self.attention_bias = attention_bias + self.rope_scaling = rope_scaling + self.number_rep_kv = number_rep_kv + self.gradient_checkpointing = gradient_checkpointing + self.use_scan_mlp = use_scan_mlp + self.scan_mlp_chunk_size = scan_mlp_chunk_size + self.bits = bits + self.initialization_of_moe = initialization_of_moe + + @staticmethod + def get_weight_decay_exclusions(): + return tuple() + + @staticmethod + def rng_keys(): + return 'params', 'dropout', 'fcm' diff --git a/src/python/easydel/modules/deepseek_v2/modeling_deepseek_flax.py b/src/python/easydel/modules/deepseek_v2/modeling_deepseek_flax.py index 490ced31b..b4bec2274 100644 --- a/src/python/easydel/modules/deepseek_v2/modeling_deepseek_flax.py +++ b/src/python/easydel/modules/deepseek_v2/modeling_deepseek_flax.py @@ -1,1368 +1,1394 @@ -import functools -import math -import typing - -import fjformer -import flax -from flax.struct import dataclass -from jax import numpy as jnp, lax -import jax -from fjformer import linen as nn -from flax.traverse_util import unflatten_dict, flatten_dict -from flax.core import freeze, unfreeze, FrozenDict -from typing import Union, Optional, Tuple -from flax.linen import partitioning as nn_partitioning, combine_masks -from transformers.modeling_flax_outputs import FlaxMaskedLMOutput, FlaxBaseModelOutput, FlaxCausalLMOutput -from fjformer.func import auxiliary_load_balancing_loss_func -from ..attention_module import AttentionModule -from ..flax_modelling_utils import ( - ACT2FN, - with_sharding_constraint, - repeat_kv_bnsh, - get_dot_general_by_bits, - BaseJAXAttentionModule, - get_gradient_checkpoint_policy, - block_wise_ffn -) -from jax.sharding import PartitionSpec -import chex -from .deepseek_configuration import DeepseekV2Config -from ..easydel_modelling_utils import EasyDeLFlaxPretrainedModel - -re_mat = nn_partitioning.remat - - -@flax.struct.dataclass -class MoeModelOutput: - last_hidden_state: chex.Array = None - hidden_states: Optional[Tuple[chex.Array]] = None - attentions: Optional[Tuple[chex.Array]] = None - router_logits: Optional[Tuple[chex.Array]] = None - - -@flax.struct.dataclass -class MoeCausalLMOutput(FlaxMaskedLMOutput): - aux_loss: Optional[chex.Array] = None - router_logits: Optional[Tuple[chex.Array]] = None - - -class DeepseekV2RMSNorm(nn.Module): - dim: int - eps: float = 1e-6 - dtype: jnp.dtype = jnp.bfloat16 - param_dtype: jnp.dtype = jnp.bfloat16 - - def setup(self) -> None: - self.weight = self.param( - 'kernel', - nn.initializers.ones, - (self.dim,), - self.param_dtype, - ) - - def _norm(self, x: jnp.ndarray) -> jnp.ndarray: - return x * jax.lax.rsqrt(jnp.square(x).mean(-1, keepdims=True) + self.eps) - - def __call__(self, x: jnp.ndarray) -> jnp.ndarray: - x = x.astype(jnp.promote_types(self.dtype, jnp.float32)) - output = self._norm(x).astype(self.dtype) - weight = fjformer.linen.linen.control_quantization(self.weight, self.dtype) - return output * weight - - -def yarn_find_correction_dim( - num_rotations, dim, base=10000, max_position_embeddings=2048 -): - return (dim * math.log(max_position_embeddings / (num_rotations * 2 * math.pi))) / ( - 2 * math.log(base) - ) - - -def yarn_find_correction_range( - low_rot, high_rot, dim, base=10000, max_position_embeddings=2048 -): - low = math.floor( - yarn_find_correction_dim(low_rot, dim, base, max_position_embeddings) - ) - high = math.ceil( - yarn_find_correction_dim(high_rot, dim, base, max_position_embeddings) - ) - return max(low, 0), min(high, dim - 1) # Clamp values just in case - - -def yarn_get_mscale(scale=1., mscale=1.): - if scale <= 1: - return 1.0 - return 0.1 * mscale * math.log(scale) + 1.0 - - -def yarn_linear_ramp_mask(min, max, dim): - if min == max: - max += 0.001 # Prevent singularity - - linear_func = (jnp.arange(dim, dtype=jnp.float32) - min) / (max - min) - return jnp.clip(linear_func, 0, 1) - - -def init_deepseek_rotary_embedding( - dim, - max_position_embeddings=2048, - base=10000, - method: typing.Literal["linear", "yarn", "dynamic", None] = None, - kwargs: typing.Optional[dict] = None, -): - if method is None: - inv_freq = 1.0 / (base ** (jnp.arange(0, dim, 2).astype("float32") / dim)) - t = jnp.arange(max_position_embeddings, dtype=inv_freq.dtype) - freqs = jnp.outer(t, inv_freq) - emb = jnp.concatenate((freqs, freqs), axis=-1) - return jnp.sin(emb), jnp.cos(emb) - elif method == "linear": - assert kwargs is not None - inv_freq = 1.0 / (base ** (jnp.arange(0, dim, 2).astype("float32") / dim)) - t = jnp.arange(max_position_embeddings, dtype=inv_freq.dtype) / kwargs.get("scaling_factor") - freqs = jnp.outer(t, inv_freq) - emb = jnp.concatenate((freqs, freqs), axis=-1) - return jnp.sin(emb), jnp.cos(emb) - elif method == "dynamic": - assert kwargs is not None - targeted_len = kwargs.get("targeted_len", max_position_embeddings) - if targeted_len > max_position_embeddings: - base = base * ( - (kwargs.get("scaling_factor") * targeted_len / max_position_embeddings) - - (kwargs.get("scaling_factor") - 1) - ) ** (dim / (dim - 2)) - inv_freq = 1.0 / ( - base ** (jnp.arange(0, dim, 2).astype("float32") / dim) - ) - - else: - inv_freq = 1.0 / (base ** (jnp.arange(0, dim, 2).astype("float32") / dim)) - t = jnp.arange(max_position_embeddings, dtype=inv_freq.dtype) / kwargs.get("scaling_factor") - - freqs = jnp.outer(t, inv_freq) - emb = jnp.concatenate((freqs, freqs), axis=-1) - return jnp.sin(emb), jnp.cos(emb) - elif method == "yarn": - - scaling_factor = kwargs.get("scaling_factor", 1.0) - original_max_position_embeddings = kwargs.get("original_max_position_embeddings", 4096) - beta_fast = kwargs.get("beta_fast", 32) - beta_slow = kwargs.get("beta_slow", 1) - mscale = kwargs.get("mscale", 1) - mscale_all_dim = kwargs.get("mscale_all_dim", 0) - freq_extra = 1.0 / ( - base - ** (jnp.arange(0, dim, 2, dtype=jnp.float32) / dim) - ) - freq_inter = 1.0 / ( - scaling_factor - * base - ** (jnp.arange(0, dim, 2, dtype=jnp.float32) / dim) - ) - - low, high = yarn_find_correction_range( - beta_fast, - beta_slow, - dim, - base, - original_max_position_embeddings, - ) - inv_freq_mask = 1.0 - yarn_linear_ramp_mask(low, high, dim // 2).astype("float32") - inv_freq = freq_inter * (1 - inv_freq_mask) + freq_extra * inv_freq_mask - t = jnp.arange(max_position_embeddings, dtype=jnp.float32) - - freqs = jnp.outer(t, inv_freq) - - _mscale = float( - yarn_get_mscale(scaling_factor, mscale) / yarn_get_mscale(scaling_factor, mscale_all_dim) - ) - - emb = jnp.concatenate((freqs, freqs), axis=-1) - return (jnp.sin(emb) * _mscale).astype("float32"), (jnp.cos(emb) * _mscale).astype("float32") - - -def rotate_half(x): - x1 = x[..., : x.shape[-1] // 2] - x2 = x[..., x.shape[-1] // 2:] - return jnp.concatenate((-x2, x1), axis=-1) - - -def apply_rotary_pos_emb(q, k, cos, sin, position_ids, unsqueeze_dim=1): - cos = jnp.expand_dims(cos[position_ids], unsqueeze_dim) - sin = jnp.expand_dims(sin[position_ids], unsqueeze_dim) - - b, h, s, d = q.shape - q = q.view(b, h, s, d // 2, 2).transpose(0, 1, 2, 4, 3).reshape(b, h, s, d) - - b, h, s, d = k.shape - k = k.view(b, h, s, d // 2, 2).transpose(0, 1, 2, 4, 3).reshape(b, h, s, d) - - q_embed = (q * cos) + (rotate_half(q) * sin) - k_embed = (k * cos) + (rotate_half(k) * sin) - return q_embed, k_embed - - -class FlaxDeepseekV2MLP(nn.Module): - config: DeepseekV2Config - dtype: jnp.dtype = jnp.bfloat16 - param_dtype: jnp.dtype = jnp.bfloat16 - precision: Optional[Union[str, jax.lax.Precision]] = jax.lax.Precision("fastest") - hidden_size: Optional[int] = None - intermediate_size: Optional[int] = None - - def setup(self) -> None: - dense = functools.partial( - nn.Linear, - use_bias=False, - dtype=self.dtype, - param_dtype=self.param_dtype, - precision=self.precision, - kernel_init=nn.initializers.normal(), - **get_dot_general_by_bits(self.config.bits, self.config.easy_method) - ) - self.gate_proj = dense(self.intermediate_size or self.config.intermediate_size) - self.up_proj = dense(self.intermediate_size or self.config.intermediate_size) - self.down_proj = dense(self.hidden_size or self.config.hidden_size) - self.act_fn = ACT2FN[self.config.hidden_act] - - def __call__( - self, - x: chex.Array, - e: bool = False # Ignored - ): - return self.down_proj(self.act_fn(self.gate_proj(x)) * self.up_proj(x)) - - -class FlaxMoEGate(nn.Module): - config: DeepseekV2Config - dtype: jnp.dtype = jnp.bfloat16 - param_dtype: jnp.dtype = jnp.bfloat16 - precision: Optional[Union[str, jax.lax.Precision]] = jax.lax.Precision("fastest") - - def setup(self) -> None: - config = self.config - self.top_k = config.num_experts_per_tok - self.n_routed_experts = config.n_routed_experts - self.routed_scaling_factor = config.routed_scaling_factor - self.scoring_func = config.scoring_func - self.alpha = config.aux_loss_alpha - self.seq_aux = config.seq_aux - self.topk_method = config.topk_method - self.n_group = config.n_group - self.topk_group = config.topk_group - - self.norm_topk_prob = config.norm_topk_prob - self.gating_dim = config.hidden_size - self.weight = self.param( - "kernel", - nn.initializers.kaiming_uniform(dtype=self.param_dtype), - (self.n_routed_experts, self.gating_dim) - ) - - def __call__(self, hidden_states, deterministic: bool = True): - bsz, seq_len, h = hidden_states.shape - hidden_states = hidden_states.reshape(-1, h) - logits = jax.lax.batch_matmul( - hidden_states.astype(jnp.float32), - self.weight.astype(jnp.float32), - precision=self.precision - ) - if self.scoring_func == "softmax": - scores = jax.nn.softmax(logits.astype(jnp.float32), axis=-1) - else: - raise NotImplementedError( - f"insupportable scoring function for MoE gating: {self.scoring_func}" - ) - - ### select top-k experts - if self.topk_method == "gready": - topk_weight, topk_idx = jax.lax.top_k( - scores, k=self.top_k - ) - elif self.topk_method == "group_limited_greedy": - group_scores = scores.reshape(bsz * seq_len, self.n_group, -1).max(axis=-1) # [n, n_group] - - # Find the indices of the top k scores in each group - top_k_indices = lax.top_k(group_scores, self.topk_group)[1] # [n, topk_group] - - # Initialize a mask with zeros - group_mask = jnp.zeros_like(group_scores) # [n, n_group] - - # Update the mask: this is a bit tricky in JAX as there is no direct scatter function - n_indices = jnp.arange(group_mask.shape[0])[:, None] - group_mask = group_mask.at[n_indices, top_k_indices].set(1) # [n, n_group] - - # Expand and reshape the group_mask - score_mask = jnp.repeat(group_mask[:, :, None], self.n_routed_experts // self.n_group, axis=2) - score_mask = score_mask.reshape(bsz * seq_len, -1) # [n, e] - - # Apply the mask to scores - masked_scores = jnp.where(score_mask, scores, 0.0) # [n, e] - - # Compute the top k scores after masking - topk_weight, topk_idx = lax.top_k(masked_scores, self.top_k) - else: - raise ValueError() - ### norm gate to sum 1 - if self.top_k > 1 and self.norm_topk_prob: - denominator = jnp.sum(topk_weight, axis=-1, keepdims=True) + 1e-20 - topk_weight = topk_weight / denominator - else: - topk_weight = topk_weight * self.routed_scaling_factor - ### expert-level computation auxiliary loss - if not deterministic and self.alpha > 0.0: - scores_for_aux = scores - aux_topk = self.top_k - topk_idx_for_aux_loss = topk_idx.reshape(bsz, -1) - if self.seq_aux: - scores_for_seq_aux = scores_for_aux.reshape(bsz, seq_len, -1) - ce = jnp.zeros(bsz, self.n_routed_experts) - ce = ce.at[1, topk_idx_for_aux_loss].add( - jnp.ones(bsz, seq_len * aux_topk), - ) - ce = jnp.divide(ce, (seq_len * aux_topk / self.n_routed_experts)) - aux_loss = jnp.mean(jnp.sum((ce * jnp.mean(scores_for_seq_aux, axis=-1)), axis=1)) * self.alpha - else: - mask_ce = jax.nn.one_hot( - topk_idx_for_aux_loss.reshape(-1), num_classes=self.n_routed_experts - ) - ce = jnp.mean(mask_ce.astype("float32"), axis=0) - Pi = jnp.mean(scores_for_aux, axis=0) - fi = ce * self.n_routed_experts - aux_loss = jnp.sum(Pi * fi) * self.alpha - else: - aux_loss = None - return topk_idx, topk_weight, aux_loss - - -class FlaxDeepseekV2MLPCollection(nn.Module): - config: DeepseekV2Config - dtype: jnp.dtype = jnp.bfloat16 - param_dtype: jnp.dtype = jnp.bfloat16 - precision: Optional[Union[str, jax.lax.Precision]] = jax.lax.Precision("fastest") - - def setup(self) -> None: - self.experts = [ - FlaxDeepseekV2MLP( - config=self.config, - dtype=self.dtype, - param_dtype=self.param_dtype, - precision=self.precision, - intermediate_size=self.config.moe_intermediate_size, - name=str(i) - ) - for i in range(self.config.n_routed_experts) - ] - - def __call__(self, hidden_states, flat_topk_idx): - y = jnp.empty_like(hidden_states) - for i, expert in enumerate(self.experts): - y[flat_topk_idx == i] = expert(hidden_states[flat_topk_idx == i]) - return y - - -class FlaxDeepseekV2MoE(nn.Module): - config: DeepseekV2Config - dtype: jnp.dtype = jnp.bfloat16 - param_dtype: jnp.dtype = jnp.bfloat16 - precision: Optional[Union[str, jax.lax.Precision]] = jax.lax.Precision("fastest") - - def setup(self) -> None: - config = self.config - self.num_experts_per_tok = config.num_experts_per_tok - - self.ep_size = 1 - self.experts_per_rank = config.n_routed_experts - self.ep_rank = 0 - self.experts = FlaxDeepseekV2MLPCollection( - config=self.config, - dtype=self.dtype, - param_dtype=self.param_dtype, - precision=self.precision, - ) - self.gate = FlaxMoEGate( - config=config, - dtype=self.dtype, - param_dtype=self.param_dtype, - precision=self.precision, - ) - if config.n_shared_experts is not None: - intermediate_size = config.moe_intermediate_size * config.n_shared_experts - self.shared_experts = FlaxDeepseekV2MoE( - config=self.config, - dtype=self.dtype, - param_dtype=self.param_dtype, - precision=self.precision, - intermediate_size=intermediate_size, - ) - - def __call__( - self, - hidden_states: chex.Array, - e: bool = False # ignored ! - ): - - identity = hidden_states - orig_shape = hidden_states.shape - topk_idx, topk_weight, aux_loss = self.gate(hidden_states) - hidden_states = hidden_states.reshape(-1, hidden_states.shape[-1]) - flat_topk_idx = topk_idx.reshape(-1) - hidden_states = hidden_states.repeat(self.num_experts_per_tok, axis=0) - y = self.experts(hidden_states=hidden_states, flat_topk_idx=flat_topk_idx) - y = (y.reshape(*topk_weight.shape, -1) * jnp.expand_dims(topk_weight, -1)).sum(axis=1) - y = y.reshape(*orig_shape) - if self.config.n_shared_experts is not None: - y = y + self.shared_experts(identity) - return y - - -class FlaxDeepseekV2Attention(BaseJAXAttentionModule): - config: DeepseekV2Config - layer_idx: int - dtype: jnp.dtype = jnp.bfloat16 - param_dtype: jnp.dtype = jnp.bfloat16 - precision: Optional[Union[str, jax.lax.Precision]] = jax.lax.Precision("fastest") - - def setup(self) -> None: - config = self.config - - self.attention_dropout = config.attention_dropout - self.hidden_size = config.hidden_size - self.num_heads = config.num_attention_heads - - self.max_position_embeddings = config.max_position_embeddings - self.rope_theta = config.rope_theta - self.q_lora_rank = config.q_lora_rank - self.qk_rope_head_dim = config.qk_rope_head_dim - self.kv_lora_rank = config.kv_lora_rank - self.v_head_dim = config.v_head_dim - self.qk_nope_head_dim = config.qk_nope_head_dim - self.q_head_dim = config.qk_nope_head_dim + config.qk_rope_head_dim - - self.is_causal = True - - dense_class = functools.partial( - nn.Linear, - dtype=self.dtype, - param_dtype=self.param_dtype, - precision=self.precision, - ) - - self.q_a_proj = dense_class(config.q_lora_rank, use_bias=config.attention_bias) - self.q_a_layernorm = DeepseekV2RMSNorm(config.q_lora_rank) - self.q_b_proj = dense_class(self.num_heads * self.q_head_dim, use_bias=False) - - self.kv_a_proj_with_mqa = dense_class( - config.kv_lora_rank + config.qk_rope_head_dim, - use_bias=config.attention_bias, - ) - self.kv_a_layernorm = DeepseekV2RMSNorm(config.kv_lora_rank) - self.kv_b_proj = dense_class( - self.num_heads - * (self.q_head_dim - self.qk_rope_head_dim + self.v_head_dim), - use_bias=False, - ) - - self.o_proj = dense_class(self.hidden_size, use_bias=config.attention_bias) - - softmax_scale = self.q_head_dim ** (-0.5) - if self.config.rope_scaling is not None: - mscale_all_dim = self.config.rope_scaling.get("mscale_all_dim", 0) - scaling_factor = self.config.rope_scaling["factor"] - if mscale_all_dim: - mscale = yarn_get_mscale(scaling_factor, mscale_all_dim) - softmax_scale = self.softmax_scale * mscale * mscale - - self.attention_performer = AttentionModule( - use_sharding_constraint=self.config.use_sharding_constraint, - block_k_major=self.config.block_k_major, - block_b=self.config.block_b, - block_q=self.config.block_q, - block_k=self.config.block_k, - block_q_major_dkv=self.config.block_q_major_dkv, - block_k_major_dkv=self.config.block_k_major_dkv, - block_k_major_dq=self.config.block_k_major_dq, - block_k_dkv=self.config.block_k_dkv, - block_q_dkv=self.config.block_q_dkv, - block_q_dq=self.config.block_q_dq, - block_k_dq=self.config.block_k_dq, - num_attention_heads=self.config.num_attention_heads, - attention_dropout=self.config.attention_dropout, - head_dims=self.q_head_dim, - attention_partition_spec=self.config.attention_partition_spec, - shard_attention_computation=self.config.shard_attention_computation, - precision=self.precision, - force_float32_tpu=True, - attn_mechanism=self.config.attn_mechanism, - dtype=self.dtype, - bias_partition_spec=self.config.bias_partition_spec, - key_partition_spec=self.config.key_partition_spec, - query_partition_spec=self.config.query_partition_spec, - generation_query_partition_spec=self.config.generation_query_partition_spec, - generation_bias_partition_spec=self.config.generation_bias_partition_spec, - generation_attention_partition_spec=self.config.generation_attention_partition_spec, - value_partition_spec=self.config.value_partition_spec, - scan_ring_attention=self.config.scan_ring_attention, - mesh=self.config.jax_mesh(), - sm_scale=softmax_scale, - axis_name=self.config.attention_axis_name - ) - - def __call__( - self, - hidden_states: chex.Array, - freq_cis: Tuple[chex.Array, chex.Array], - attention_mask: chex.Array, - position_ids: chex.Array, - causal_mask: chex.Array, - segment_ids: Optional[chex.Array] = None, - deterministic: bool = True, - init_cache: bool = False, - output_attentions: bool = False, - fcm_mask=None, - ): - bsz, q_len, _ = hidden_states.shape - - q = self.q_b_proj(self.q_a_layernorm(self.q_a_proj(hidden_states))) - q = q.view(bsz, q_len, self.num_heads, self.q_head_dim).transpose(0, 2, 1, 3) - q_nope = q[:, :, :, :self.qk_nope_head_dim] - q_pe = q[:, :, :, self.qk_nope_head_dim:] - - compressed_kv = self.kv_a_proj_with_mqa(hidden_states) - - k_pe = compressed_kv[:, :, :, self.kv_lora_rank:self.kv_lora_rank + self.qk_rope_head_dim] - compressed_kv = compressed_kv[:, :, :, :self.kv_lora_rank] - - k_pe = k_pe.reshape(bsz, q_len, 1, self.qk_rope_head_dim).transpose(0, 2, 1, 3) - kv = ( - self.kv_b_proj(self.kv_a_layernorm(compressed_kv)) - .reshape(bsz, q_len, self.num_heads, self.qk_nope_head_dim + self.v_head_dim).transpose(0, 2, 1, 3) - ) - - k_nope = kv[:, :, :, :self.qk_nope_head_dim] - value_states = kv[:, :, :, self.qk_nope_head_dim:self.qk_nope_head_dim + self.v_head_dim] - - sin, cos = freq_cis - - q_pe, k_pe = apply_rotary_pos_emb(q_pe, k_pe, cos, sin, position_ids) - - query_states = k_pe.new_empty(bsz, self.num_heads, q_len, self.q_head_dim) - query_states[:, :, :, : self.qk_nope_head_dim] = q_nope - query_states[:, :, :, self.qk_nope_head_dim:] = q_pe - - key_states = k_pe.new_empty(bsz, self.num_heads, q_len, self.q_head_dim) - key_states[:, :, :, : self.qk_nope_head_dim] = k_nope - key_states[:, :, :, self.qk_nope_head_dim:] = k_pe - - query_states = query_states.transpose(0, 2, 1, 3) - key_states = key_states.transpose(0, 2, 1, 3) - value_states = value_states.transpose(0, 2, 1, 3) - - query_length, key_length = query_states.shape[1], key_states.shape[1] - - if self.has_variable("cache", "cached_key"): - mask_shift = self.variables["cache"]["cache_index"] - max_decoder_length = self.variables["cache"]["cached_key"].shape[1] - causal_mask = lax.dynamic_slice( - causal_mask, (0, 0, mask_shift, 0), (1, 1, - query_length, max_decoder_length) - ) - else: - causal_mask = causal_mask[:, :, :query_length, :key_length] - - batch_size = hidden_states.shape[0] - causal_mask = jnp.broadcast_to( - causal_mask, (batch_size,) + causal_mask.shape[1:]) - if attention_mask.ndim == 2: - attention_mask = jnp.expand_dims(attention_mask, axis=(-3, -2)) - attention_mask = jnp.broadcast_to( - attention_mask, causal_mask.shape - ) - attention_mask = combine_masks(attention_mask, causal_mask, fcm_mask) - - dropout_rng = None - - if not deterministic and self.config.attention_dropout > 0.0: - dropout_rng = self.make_rng("dropout") - if self.has_variable("cache", "cached_key") or init_cache: - key_states, value_states, attention_mask = self._concatenate_to_cache( - key_states, - value_states, - query_states, - attention_mask - ) - # if self.config.use_sharding_constraint: - # query_states = with_sharding_constraint( - # query_states, PartitionSpec(("dp", "fsdp"), "sp" if query_states.shape[1] != 1 else None, "tp", None) - # ) - # key_states = with_sharding_constraint( - # key_states, PartitionSpec(("dp", "fsdp"), "sp", "tp", None) - # ) - # value_states = with_sharding_constraint( - # value_states, PartitionSpec(("dp", "fsdp"), "sp", "tp", None) - # ) - attention_bias = lax.select( - attention_mask > 0, - jnp.full(attention_mask.shape, 0.0).astype(self.dtype), - jnp.full(attention_mask.shape, jnp.finfo( - self.dtype).min).astype(self.dtype), - ) - - query_length, key_length = query_states.shape[1], key_states.shape[1] - - attentions = self.attention_performer.__call__( - query_states=query_states, - key_states=key_states, - value_states=value_states, - bias=attention_bias, - attention_mask=attention_mask, - causal=True, - dropout_rng=dropout_rng, - deterministic=deterministic, - query_sequence_length=query_length, - key_value_sequence_length=key_length, - uses_cache=self.has_variable("cache", "cached_key") or init_cache, - segment_ids=segment_ids, - causal_mask=causal_mask - ) - - attn_output = self._merge_heads(attentions.attention_outputs) - if self.config.shard_attention_computation: - attn_output = with_sharding_constraint( - attn_output, PartitionSpec( - ("dp", "fsdp"), - "sp" if attn_output.shape[1] != 1 else None, - "tp" - ) - ) - attn_output = self.o_proj(attn_output) - - outputs = ( - attn_output, attentions.attention_weights - ) if output_attentions else ( - attn_output, - ) - return outputs - - -class FlaxDeepseekV2DecoderLayer(nn.Module): - config: DeepseekV2Config - layer_idx: int - dtype: jnp.dtype = jnp.bfloat16 - param_dtype: jnp.dtype = jnp.bfloat16 - precision: Optional[Union[str, jax.lax.Precision]] = jax.lax.Precision("fastest") - - def setup(self): - config = self.config - layer_idx = self.layer_idx - self.hidden_size = config.hidden_size - - attn_block = FlaxDeepseekV2Attention - mlp_block = FlaxDeepseekV2MLP - mlp_moe_block = FlaxDeepseekV2MoE - - if self.config.gradient_checkpointing != "": - # hidden_states: chex.Array, - # freq_cis: Tuple[chex.Array, chex.Array], - # attention_mask: chex.Array, - # position_ids: chex.Array, - # causal_mask: chex.Array, - # segment_ids: Optional[chex.Array] = None, - # deterministic: bool = True, - # init_cache: bool = False, - # output_attentions: bool = False, - # fcm_mask = None, - attn_block = re_mat( - attn_block, - policy=get_gradient_checkpoint_policy(self.config.gradient_checkpointing), - static_argnums=(1, 3, 4, 6, 7, 8) - ) - mlp_block = re_mat( - mlp_block, - policy=get_gradient_checkpoint_policy(self.config.gradient_checkpointing), - static_argnums=(1,) - ) - - mlp_moe_block = re_mat( - mlp_moe_block, - policy=get_gradient_checkpoint_policy(self.config.gradient_checkpointing), - static_argnums=(1,) - ) - - self.self_attn = attn_block( - config=config, - layer_idx=self.layer_idx, - dtype=self.dtype, - param_dtype=self.param_dtype, - precision=self.precision - ) - - self.mlp = ( - mlp_moe_block( - config=config, - dtype=self.dtype, - param_dtype=self.param_dtype, - precision=self.precision - ) - if ( - config.n_routed_experts is not None - and layer_idx >= config.first_k_dense_replace - and layer_idx % config.moe_layer_freq == 0 - ) - else mlp_block( - config=config, - dtype=self.dtype, - param_dtype=self.param_dtype, - precision=self.precision - ) - ) - self.input_layernorm = DeepseekV2RMSNorm( - config.hidden_size, - eps=config.rms_norm_eps, - dtype=self.dtype, - param_dtype=self.param_dtype - ) - self.post_attention_layernorm = DeepseekV2RMSNorm( - config.hidden_size, - eps=config.rms_norm_eps, - dtype=self.dtype, - param_dtype=self.param_dtype - ) - - def forward( - self, - hidden_states: chex.Array, - freq_cis: Tuple[chex.Array, chex.Array], - attention_mask: chex.Array, - causal_mask: chex.Array, - position_ids: chex.Array, - segment_ids: Optional[chex.Array] = None, - deterministic: bool = True, - init_cache: bool = False, - output_attentions: bool = True - ) -> Tuple[ - chex.Array, Optional[chex.Array] - ]: - - residual = hidden_states - - hidden_states = self.input_layernorm(hidden_states) - - # Self Attention - hidden_states, self_attn_weights = self.self_attn( - hidden_states, - freq_cis, - attention_mask, - position_ids, - causal_mask, - segment_ids, - deterministic, - init_cache, - output_attentions, - None - ) - hidden_states = residual + hidden_states - - # Fully Connected - residual = hidden_states - hidden_states = self.post_attention_layernorm(hidden_states) - - if self.config.use_scan_mlp: - feed_forward_hidden_states = block_wise_ffn( - self.mlp, - hidden_states, - self.config.scan_mlp_chunk_size, - deterministic, - ) - else: - feed_forward_hidden_states = self.mlp( - hidden_states, - deterministic, - ) - hidden_states = residual + feed_forward_hidden_states - - outputs = (hidden_states,) - - if output_attentions: - outputs += (self_attn_weights,) - - return outputs # type:ignore - - -class FlaxDeepseekV2DecoratorCollection(nn.Module): - config: DeepseekV2Config - dtype: jnp.dtype = jnp.bfloat16 - param_dtype: jnp.dtype = jnp.bfloat16 - precision: Optional[Union[str, jax.lax.Precision] - ] = jax.lax.Precision("fastest") - - def setup(self) -> None: - self.layers = [ - FlaxDeepseekV2DecoderLayer( - config=self.config, - dtype=self.dtype, - param_dtype=self.param_dtype, - precision=self.precision, - layer_idx=i, - name=str(i) - ) for i in range(self.config.num_hidden_layers) - ] - - def __call__( - self, - hidden_states: chex.Array, - freq_cis: Tuple[chex.Array, chex.Array], - attention_mask: chex.Array, - causal_mask: chex.Array, - position_ids: chex.Array, - deterministic: bool = True, - init_cache: bool = False, - output_attentions: bool = False, - output_hidden_states: bool = False - ): - all_attentions = () if output_attentions else None - all_hidden_states = () if output_hidden_states else None - for layer in self.layers: - if output_hidden_states: - all_hidden_states += (hidden_states,) - - # hidden_states: chex.Array, - # freq_cis: Tuple[chex.Array, chex.Array], - # attention_mask: chex.Array, - # causal_mask: chex.Array, - # position_ids: chex.Array, - # segment_ids: Optional[chex.Array] = None, - # deterministic: bool = True, - # init_cache: bool = False, - # output_attentions: bool = True - - output = layer( - hidden_states, - freq_cis, - attention_mask, - causal_mask, - position_ids, - None, - deterministic, - init_cache, - output_attentions - ) - hidden_states = output[0] - - if output_attentions: - output_attentions += (output[1],) - - return hidden_states, all_hidden_states, all_attentions - - -class FlaxDeepseekV2Module(nn.Module): - config: DeepseekV2Config - dtype: jnp.dtype = jnp.bfloat16 - param_dtype: jnp.dtype = jnp.bfloat16 - precision: Optional[Union[jax.lax.Precision, str]] = None - - def setup(self): - - self.embed_tokens = nn.Embed( - self.config.vocab_size, - self.config.hidden_size, - embedding_init=jax.nn.initializers.normal( - stddev=self.config.initializer_range), - dtype=self.dtype, - param_dtype=self.param_dtype, - ) - - self.layers = FlaxDeepseekV2DecoratorCollection( - self.config, - dtype=self.dtype, - param_dtype=self.param_dtype, - precision=self.precision - ) - self.norm = DeepseekV2RMSNorm( - self.config.hidden_size, - eps=self.config.rms_norm_eps, - dtype=self.dtype, - param_dtype=self.param_dtype - ) - - initial_rope_kwargs = {} - method = None - if self.config.rope_scaling is not None: - scaling_type = self.config.rope_scaling["type"] - method = scaling_type - if scaling_type != "yarn": - initial_rope_kwargs = dict(scaling_factor=self.config.rope_scaling["factor"]) - else: - initial_rope_kwargs = { - key: self.config.rope_scaling[key] - for key in [ - "original_max_position_embeddings", - "beta_fast", - "beta_slow", - "mscale", - "mscale_all_dim", - ] - if key in self.config.rope_scaling - } - initial_rope_kwargs["scaling_factor"] = self.config.rope_scaling["factor"] - self.freq_cis = init_deepseek_rotary_embedding( - dim=self.config.hidden_size // self.config.num_attention_heads, - max_position_embeddings=( - getattr( - self.config, - "freq_max_position_embeddings", - self.config.max_position_embeddings - ) - ), - base=self.config.rope_theta, - method=method, # type:ignore - kwargs=initial_rope_kwargs - ) - self.causal_mask = flax.linen.make_causal_mask( - jnp.ones( - ( - 1, - getattr( - self.config, - "c_max_position_embeddings", - self.config.max_position_embeddings - ) - ), - dtype="bool" - ), - dtype="bool" - ) - - def __call__( - self, - input_ids: Optional[chex.Array] = None, - attention_mask: Optional[chex.Array] = None, - position_ids: Optional[chex.Array] = None, - deterministic: bool = True, - inputs_embeds: chex.Array = None, - init_cache: bool = False, - output_attentions: bool = False, - output_hidden_states: bool = False, - return_dict: bool = True, - ) -> typing.Union[Tuple[chex.Array, ...], FlaxBaseModelOutput]: - """ - The __call__ function is the main function of a Flax model. - It takes in input_ids, attention_mask, and position_ids as inputs to the model. - The output is a tuple containing: last hidden state (hidden states), all hidden states (if output_hidden_states=True), attentions (if output attentions=True). - - - :param self: Represent the instance of the class - :param input_ids: chex.Array: Pass in the input ids - :param attention_mask: chex.Array: Mask out the attention weights for certain tokens - :param position_ids: chex.Array: Determine the position of each token in a sequence - :param deterministic: bool: Determine whether to use dropout or not - :param inputs_embeds: chex.Array: Pass in the embedding of the input_ids - :param init_cache: bool: Initialize the cache for the decoder - :param output_attentions: bool: Determine whether to return the attention weights or not - :param output_hidden_states: bool: Return all hidden states or just the last one - :param return_dict: bool: Return a dictionary of the outputs or not - :param : Determine whether the model is in training mode or not - :return: A tuple of the hidden states, all hidden states, and attentions - - """ - if inputs_embeds is None: - inputs_embeds = self.embed_tokens(input_ids.astype("i4")) - if attention_mask.ndim == 2: - b, s = attention_mask.shape - attention_mask = attention_mask.reshape(b, 1, 1, s) - - outputs = self.layers( - hidden_states=inputs_embeds, - attention_mask=attention_mask, - position_ids=position_ids, - freq_cis=self.freq_cis, - init_cache=init_cache, - output_attentions=output_attentions, - deterministic=deterministic, - causal_mask=self.causal_mask - ) - - hidden_states = outputs[0] - hidden_states = self.norm(hidden_states) - - if output_hidden_states: - all_hidden_states = outputs[1] + (hidden_states,) - outputs = (hidden_states, all_hidden_states) + outputs[2:] - else: - outputs = (hidden_states,) + outputs[1:] - - if not return_dict: - return tuple(value for value in outputs if value is not None) - - return FlaxBaseModelOutput( - last_hidden_state=hidden_states, - hidden_states=outputs[1], - attentions=outputs[-1], - ) - - -class DeepseekV2PreTrainedModel(EasyDeLFlaxPretrainedModel): - config_class: DeepseekV2Config = DeepseekV2Config - module_class: nn.Module = None - base_model_prefix = "model" - - def __init__( - self, - config: DeepseekV2Config, - dtype: jnp.dtype = jnp.bfloat16, - param_dtype: jnp.dtype = jnp.bfloat16, - precision: Optional[jax.lax.Precision] = jax.lax.Precision("fastest"), - input_shape: Tuple[int, int] = (1, 1), - seed: int = 0, - _do_init: bool = False, - **kwargs - ): - module = self.module_class( - config=config, - dtype=dtype, - param_dtype=param_dtype, - precision=precision, - **kwargs - ) - - super().__init__( - dtype=dtype, _do_init=_do_init, - module=module, config=config, input_shape=input_shape, - seed=seed, - ) - - def init_weights( - self, - rng: jax.random.PRNGKey, - input_shape: Tuple, - params: FrozenDict = None - ) -> FrozenDict: - """ - The init_weights function is used to initialize the weights of a model. - It takes in a rng, which is a random number generator key that can be used to generate random numbers. - The input_shape parameter specifies the shape of the inputs that will be fed into this model. - The params parameter allows you to pass in pre-trained weights for your model, if you have them available. - - :param self: Access variables that belong to the class - :param rng: jax.random.PRNGKey: Initialize the weights of the model - :param input_shape: Tuple: Initialize the input_ids, attention_mask and position_ids - :param params: flax.core.FrozenDict: Pass in the parameters of a pre-trained model - :return: A frozendict of parameters - """ - - self.config.initialization_of_moe = True - input_ids = jnp.zeros(input_shape, dtype="i4") - attention_mask = jnp.ones_like(input_ids, dtype="i4") - position_ids = jnp.broadcast_to( - jnp.arange(jnp.atleast_2d(input_ids).shape[-1], dtype="i4"), - input_shape, - ) - params_rng, dropout_rng = jax.random.split(rng) - rngs = {"params": params_rng, "dropout": dropout_rng} - if self.config.add_cross_attention: - encoder_hidden_states = jnp.zeros( - input_shape + (self.config.hidden_size,)) - encoder_attention_mask = attention_mask - module_init_outputs = self.module.init( - rngs, - input_ids, - attention_mask, - position_ids, - encoder_hidden_states, - encoder_attention_mask, - return_dict=False, - ) - else: - module_init_outputs = self.module.init( - rngs, - input_ids=input_ids, - attention_mask=attention_mask, - position_ids=position_ids, - return_dict=False - ) - random_params = module_init_outputs["params"] - - self.config.initialization_of_moe = False - if params is not None: - random_params = flatten_dict(unfreeze(random_params)) - params = flatten_dict(unfreeze(params)) - for missing_key in self._missing_keys: - params[missing_key] = random_params[missing_key] - self._missing_keys = set() - return freeze(unflatten_dict(params)) - else: - return random_params - - def init_cache(self, batch_size, max_length): - - input_ids = jnp.ones((batch_size, max_length)) - attention_mask = jnp.ones_like(input_ids) - position_ids = jnp.broadcast_to(jnp.arange( - jnp.atleast_2d(input_ids).shape[-1]), input_ids.shape) - - init_variables = self.module.init( - jax.random.PRNGKey(0), input_ids, attention_mask, position_ids, return_dict=False, init_cache=True - ) - return init_variables["cache"] - - def __call__( - self, - input_ids: chex.Array, - attention_mask: Optional[chex.Array] = None, - position_ids: Optional[chex.Array] = None, - params: dict = None, - past_key_values: dict = None, - dropout_rng: jax.random.PRNGKey = None, - train: bool = False, - output_attentions: Optional[bool] = None, - output_hidden_states: Optional[bool] = None, - output_router_logits: Optional[bool] = None, - return_dict: Optional[bool] = None, - add_params_field: bool = False, - **kwargs - ): - """ - The __call__ function is the main function of a JAX module. - It takes as input: - - The parameters of the model (self.params) - - The inputs to the model (input_ids, attention_mask, position_ids) - - Whether we are training (train=True/False) and whether we want to return all hidden states and - attentions weights at each layer in addition to just the last layer output (output_hidden_states=True/False). - - :param self: Represent the instance of the class - :param input_ids: Pass the input sequence to the model - :param attention_mask: Mask out the padding tokens - :param position_ids: Specify the position of each token in the sequence - :param params: dict: Pass in the parameters of the model - :param past_key_values: dict: Pass the past key values to the model - :param dropout_rng: jax.random.PRNGKey: Pass in a random number generator key to the model - :param train: bool: Determine whether to use dropout or not - :param output_attentions: Optional[bool]: Determine whether to return the attention weights - :param output_hidden_states: Optional[bool]: Determine whether to return the hidden states of all layers - :param return_dict: Optional[bool]: Return a dictionary of the outputs - :param add_params_field: bool: Add a params field to the inputs dictionary - :return: A tuple of (last_hidden_state, past_key_values) - - """ - - output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions - output_hidden_states = ( - output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states - ) - return_dict = return_dict if return_dict is not None else self.config.return_dict - - batch_size, sequence_length = input_ids.shape - - if position_ids is None: - if past_key_values is not None: - raise ValueError( - "Make sure to provide `position_ids` when passing `past_key_values`.") - - position_ids = jnp.broadcast_to(jnp.arange(sequence_length)[ - None, :], (batch_size, sequence_length)) - - if attention_mask is None: - attention_mask = jnp.ones((batch_size, sequence_length)) - - rng_s = {} - if dropout_rng is not None: - rng_s["dropout"] = dropout_rng - - inputs = { - "params": params or self.params} if add_params_field else params or self.params - - if self.config.bits is not None: - rng_s['params'] = jax.random.key(0) - if past_key_values: - inputs["cache"] = past_key_values - mutable = ["cache"] - else: - mutable = False - - outputs = self.module.apply( - inputs, - jnp.array(input_ids, dtype="i4"), # input_ids: chex.Array - # attention_mask: Optional[chex.Array] = None - jnp.array(attention_mask, dtype="i4"), - # position_ids: Optional[chex.Array] = None - jnp.array(position_ids, dtype="i4"), - None, # inputs_embeds: Optional[chex.Array] = None - output_attentions, # output_attentions: Optional[bool] = None - # output_hidden_states: Optional[bool] = None - output_hidden_states, - # output_router_logits: Optional[bool] = None - output_router_logits, - False, # init_cache: bool = False - not train, # deterministic: bool = True - return_dict, # return_dict: bool = True - rngs=rng_s, - mutable=mutable, - ) - - if past_key_values is not None and return_dict: - outputs, past_key_values = outputs - outputs["past_key_values"] = unfreeze(past_key_values["cache"]) - return outputs - elif past_key_values is not None and not return_dict: - outputs, past_key_values = outputs - outputs = outputs[:1] + \ - (unfreeze(past_key_values["cache"]),) + outputs[1:] - - return outputs - - -class FlaxDeepseekV2ForCausalLMModule(nn.Module): - config: DeepseekV2Config - dtype: jnp.dtype = jnp.bfloat16 - param_dtype: jnp.dtype = jnp.bfloat16 - precision: Optional[jax.lax.Precision] = jax.lax.Precision("fastest") - - def setup(self) -> None: - self.model = FlaxDeepseekV2Module( - config=self.config, - dtype=self.dtype, - param_dtype=self.param_dtype, - precision=self.precision - ) - self.lm_head = nn.Linear( - self.config.vocab_size, - dtype=self.dtype, - param_dtype=self.param_dtype, - precision=self.precision, - use_bias=False, - kernel_init=nn.initializers.normal(self.config.initializer_range), - **get_dot_general_by_bits(self.config.bits, self.config.easy_method) - ) - - def __call__( - self, - input_ids: chex.Array, - attention_mask: chex.Array, - position_ids: chex.Array, - deterministic: bool = True, - inputs_embeds: chex.Array = None, - init_cache: bool = False, - output_attentions: bool = False, - output_hidden_states: bool = False, - return_dict: bool = True, - ): - """ - The __call__ function is the main function of a Flax module. It defines how the model will be called, - and what it returns. In this case, we are calling our Transformer model with input_ids and attention_mask - as inputs (these are defined in __init__). We also have some optional arguments that can be passed to - the call function: deterministic (whether to use dropout), inputs_embeds (if you want to pass your own embeddings), - output_attentions and output_hidden states which return additional outputs from the transformer layers if set True. Finally, - - :param self: Refer to the object itself - :param input_ids: chex.Array: Pass in the input tokens - :param attention_mask: chex.Array: Mask out the padding tokens - :param position_ids: chex.Array: Specify the position of each token in the sequence - :param deterministic: bool: Determine whether to use dropout in the model - :param inputs_embeds: chex.Array: Pass in the embeddings of the input tokens - :param init_cache: bool: Initialize the cache for the decoder - :param output_attentions: bool: Return the attention weights - :param output_hidden_states: bool: Return the hidden states of all layers - :param return_dict: bool: Return a dictionary of the outputs or just the logits - :param : Determine whether to return the logits or not - :return: A tuple of (lm_logits, hidden_states, attentions) - - """ - batch_size, seq_length = input_ids.shape - - if attention_mask is None: - attention_mask = jnp.ones_like(input_ids) - if position_ids is None: - position_ids = jnp.broadcast_to( - jnp.clip(jnp.cumsum(attention_mask, axis=-1) - 1, a_min=0), - (batch_size, seq_length) - ) - outputs = self.model( - input_ids=input_ids, - attention_mask=attention_mask, - position_ids=position_ids, - deterministic=deterministic, - inputs_embeds=inputs_embeds, - init_cache=init_cache, - output_attentions=output_attentions, - output_hidden_states=output_hidden_states, - return_dict=return_dict - ) - - hidden_states = outputs[0] - - if self.config.tie_word_embeddings: - shared_kernel = self.transformer.variables["params"]["embed_tokens"]["embedding"] - shared_kernel = fjformer.linen.linen.control_quantization(shared_kernel, self.param_dtype).T - lm_logits = self.lm_head.apply( - {"params": {"kernel": shared_kernel}}, hidden_states) - else: - lm_logits = self.lm_head(hidden_states) - - # lm_logits = lm_logits.astype(jnp.float32) - - if not return_dict: - return (lm_logits,) + outputs[1:] - - return FlaxCausalLMOutput(logits=lm_logits, hidden_states=outputs.hidden_states, attentions=outputs.attentions) - - -class FlaxDeepseekV2Model(DeepseekV2PreTrainedModel): - module_class = FlaxDeepseekV2Module - - -class FlaxDeepseekV2ForCausalLM(DeepseekV2PreTrainedModel): - module_class = FlaxDeepseekV2ForCausalLMModule - - def set_input_embeddings(self, value): - self.module.model.embed_tokens = value - - def get_input_embeddings(self): - return self.module.model.embed_tokens - - def set_decoder(self, decoder): - self.module.model = decoder - - def get_decoder(self): - return self.module.model - - def get_output_embeddings(self): - return self.module.lm_head - - def set_output_embeddings(self, new_embeddings): - self.module.lm_head = new_embeddings - - def prepare_inputs_for_generation(self, input_ids, max_length, attention_mask: Optional[chex.Array] = None): - """ - The prepare_inputs_for_generation function is used to prepare the inputs for a generation task. - - :param self: Access variables that belong to the class - :param input_ids: Pass in the input tokens - :param max_length: Set the length of the sequence to be generated - :param attention_mask: Optional[chex.Array]: Mask the attention weights - :return: A dictionary of the past_key_values, attention_mask and position ids - - """ - batch_size, seq_length = input_ids.shape - - past_key_values = self.init_cache(batch_size, max_length) - extended_attention_mask = jnp.ones( - (batch_size, max_length), dtype="i4") - if attention_mask is not None: - position_ids = attention_mask.cumsum(axis=-1) - 1 - extended_attention_mask = lax.dynamic_update_slice( - extended_attention_mask, attention_mask, (0, 0)) - else: - position_ids = jnp.broadcast_to(jnp.arange(seq_length, dtype="i4")[ - None, :], (batch_size, seq_length)) - - return { - "past_key_values": past_key_values, - "attention_mask": extended_attention_mask, - "position_ids": position_ids, - } - - def update_inputs_for_generation(self, model_outputs, model_kwargs): - model_kwargs["past_key_values"] = model_outputs.past_key_values - model_kwargs["position_ids"] = model_kwargs["position_ids"][:, -1:] + 1 - return model_kwargs +import functools +import math +import typing + +import fjformer +import flax +from flax.struct import dataclass +from jax import numpy as jnp, lax +import jax +from fjformer import linen as nn +from flax.traverse_util import unflatten_dict, flatten_dict +from flax.core import freeze, unfreeze, FrozenDict +from typing import Union, Optional, Tuple +from flax.linen import partitioning as nn_partitioning, combine_masks +from transformers.modeling_flax_outputs import FlaxMaskedLMOutput, FlaxBaseModelOutput, FlaxCausalLMOutput +from fjformer.func import auxiliary_load_balancing_loss_func +from ..attention_module import AttentionModule +from ..flax_modelling_utils import ( + ACT2FN, + with_sharding_constraint, + repeat_kv_bnsh, + get_dot_general_by_bits, + BaseJAXAttentionModule, + get_gradient_checkpoint_policy, + block_wise_ffn +) +from jax.sharding import PartitionSpec +import chex +from .deepseek_configuration import DeepseekV2Config +from ..easydel_modelling_utils import EasyDeLFlaxPretrainedModel + +re_mat = nn_partitioning.remat + + +@flax.struct.dataclass +class MoeModelOutput: + last_hidden_state: chex.Array = None + hidden_states: Optional[Tuple[chex.Array]] = None + attentions: Optional[Tuple[chex.Array]] = None + router_logits: Optional[Tuple[chex.Array]] = None + + +@flax.struct.dataclass +class MoeCausalLMOutput(FlaxMaskedLMOutput): + aux_loss: Optional[chex.Array] = None + router_logits: Optional[Tuple[chex.Array]] = None + + +class DeepseekV2RMSNorm(nn.Module): + dim: int + eps: float = 1e-6 + dtype: jnp.dtype = jnp.bfloat16 + param_dtype: jnp.dtype = jnp.bfloat16 + + def setup(self) -> None: + self.weight = self.param( + 'kernel', + nn.initializers.ones, + (self.dim,), + self.param_dtype, + ) + + def _norm(self, x: jnp.ndarray) -> jnp.ndarray: + return x * jax.lax.rsqrt(jnp.square(x).mean(-1, keepdims=True) + self.eps) + + def __call__(self, x: jnp.ndarray) -> jnp.ndarray: + x = x.astype(jnp.promote_types(self.dtype, jnp.float32)) + output = self._norm(x).astype(self.dtype) + weight = fjformer.linen.linen.control_quantization(self.weight, self.dtype) + return output * weight + + +def yarn_find_correction_dim( + num_rotations, dim, base=10000, max_position_embeddings=2048 +): + return (dim * math.log(max_position_embeddings / (num_rotations * 2 * math.pi))) / ( + 2 * math.log(base) + ) + + +def yarn_find_correction_range( + low_rot, high_rot, dim, base=10000, max_position_embeddings=2048 +): + low = math.floor( + yarn_find_correction_dim(low_rot, dim, base, max_position_embeddings) + ) + high = math.ceil( + yarn_find_correction_dim(high_rot, dim, base, max_position_embeddings) + ) + return max(low, 0), min(high, dim - 1) # Clamp values just in case + + +def yarn_get_mscale(scale=1., mscale=1.): + if scale <= 1: + return 1.0 + return 0.1 * mscale * math.log(scale) + 1.0 + + +def yarn_linear_ramp_mask(min, max, dim): + if min == max: + max += 0.001 # Prevent singularity + + linear_func = (jnp.arange(dim, dtype=jnp.float32) - min) / (max - min) + return jnp.clip(linear_func, 0, 1) + + +def init_deepseek_rotary_embedding( + dim, + max_position_embeddings=2048, + base=10000, + method: typing.Literal["linear", "yarn", "dynamic", None] = None, + kwargs: typing.Optional[dict] = None, +): + if method is None: + inv_freq = 1.0 / (base ** (jnp.arange(0, dim, 2).astype("float32") / dim)) + t = jnp.arange(max_position_embeddings, dtype=inv_freq.dtype) + freqs = jnp.outer(t, inv_freq) + emb = jnp.concatenate((freqs, freqs), axis=-1) + return jnp.sin(emb), jnp.cos(emb) + elif method == "linear": + assert kwargs is not None + inv_freq = 1.0 / (base ** (jnp.arange(0, dim, 2).astype("float32") / dim)) + t = jnp.arange(max_position_embeddings, dtype=inv_freq.dtype) / kwargs.get("scaling_factor") + freqs = jnp.outer(t, inv_freq) + emb = jnp.concatenate((freqs, freqs), axis=-1) + return jnp.sin(emb), jnp.cos(emb) + elif method == "dynamic": + assert kwargs is not None + targeted_len = kwargs.get("targeted_len", max_position_embeddings) + if targeted_len > max_position_embeddings: + base = base * ( + (kwargs.get("scaling_factor") * targeted_len / max_position_embeddings) + - (kwargs.get("scaling_factor") - 1) + ) ** (dim / (dim - 2)) + inv_freq = 1.0 / ( + base ** (jnp.arange(0, dim, 2).astype("float32") / dim) + ) + + else: + inv_freq = 1.0 / (base ** (jnp.arange(0, dim, 2).astype("float32") / dim)) + t = jnp.arange(max_position_embeddings, dtype=inv_freq.dtype) / kwargs.get("scaling_factor") + + freqs = jnp.outer(t, inv_freq) + emb = jnp.concatenate((freqs, freqs), axis=-1) + return jnp.sin(emb), jnp.cos(emb) + elif method == "yarn": + + scaling_factor = kwargs.get("scaling_factor", 1.0) + original_max_position_embeddings = kwargs.get("original_max_position_embeddings", 4096) + beta_fast = kwargs.get("beta_fast", 32) + beta_slow = kwargs.get("beta_slow", 1) + mscale = kwargs.get("mscale", 1) + mscale_all_dim = kwargs.get("mscale_all_dim", 0) + freq_extra = 1.0 / ( + base + ** (jnp.arange(0, dim, 2, dtype=jnp.float32) / dim) + ) + freq_inter = 1.0 / ( + scaling_factor + * base + ** (jnp.arange(0, dim, 2, dtype=jnp.float32) / dim) + ) + + low, high = yarn_find_correction_range( + beta_fast, + beta_slow, + dim, + base, + original_max_position_embeddings, + ) + inv_freq_mask = 1.0 - yarn_linear_ramp_mask(low, high, dim // 2).astype("float32") + inv_freq = freq_inter * (1 - inv_freq_mask) + freq_extra * inv_freq_mask + t = jnp.arange(max_position_embeddings, dtype=jnp.float32) + + freqs = jnp.outer(t, inv_freq) + + _mscale = float( + yarn_get_mscale(scaling_factor, mscale) / yarn_get_mscale(scaling_factor, mscale_all_dim) + ) + + emb = jnp.concatenate((freqs, freqs), axis=-1) + return (jnp.sin(emb) * _mscale).astype("float32"), (jnp.cos(emb) * _mscale).astype("float32") + + +def rotate_half(x): + x1 = x[..., : x.shape[-1] // 2] + x2 = x[..., x.shape[-1] // 2:] + return jnp.concatenate((-x2, x1), axis=-1) + + +def apply_rotary_pos_emb(q, k, cos, sin, position_ids, unsqueeze_dim=1): + cos = jnp.expand_dims(cos[position_ids], unsqueeze_dim) + sin = jnp.expand_dims(sin[position_ids], unsqueeze_dim) + + b, h, s, d = q.shape + q = q.view(b, h, s, d // 2, 2).transpose(0, 1, 2, 4, 3).reshape(b, h, s, d) + + b, h, s, d = k.shape + k = k.view(b, h, s, d // 2, 2).transpose(0, 1, 2, 4, 3).reshape(b, h, s, d) + + q_embed = (q * cos) + (rotate_half(q) * sin) + k_embed = (k * cos) + (rotate_half(k) * sin) + return q_embed, k_embed + + +class FlaxDeepseekV2MLP(nn.Module): + config: DeepseekV2Config + dtype: jnp.dtype = jnp.bfloat16 + param_dtype: jnp.dtype = jnp.bfloat16 + precision: Optional[Union[str, jax.lax.Precision]] = jax.lax.Precision("fastest") + hidden_size: Optional[int] = None + intermediate_size: Optional[int] = None + + def setup(self) -> None: + dense = functools.partial( + nn.Linear, + use_bias=False, + dtype=self.dtype, + param_dtype=self.param_dtype, + precision=self.precision, + kernel_init=nn.initializers.normal(), + **get_dot_general_by_bits(self.config.bits, self.config.easy_method) + ) + self.gate_proj = dense(self.intermediate_size or self.config.intermediate_size) + self.up_proj = dense(self.intermediate_size or self.config.intermediate_size) + self.down_proj = dense(self.hidden_size or self.config.hidden_size) + self.act_fn = ACT2FN[self.config.hidden_act] + + def __call__( + self, + x: chex.Array, + e: bool = False # Ignored + ): + return self.down_proj(self.act_fn(self.gate_proj(x)) * self.up_proj(x)) + + +class FlaxMoEGate(nn.Module): + config: DeepseekV2Config + dtype: jnp.dtype = jnp.bfloat16 + param_dtype: jnp.dtype = jnp.bfloat16 + precision: Optional[Union[str, jax.lax.Precision]] = jax.lax.Precision("fastest") + + def setup(self) -> None: + config = self.config + self.top_k = config.num_experts_per_tok + self.n_routed_experts = config.n_routed_experts + self.routed_scaling_factor = config.routed_scaling_factor + self.scoring_func = config.scoring_func + self.alpha = config.aux_loss_alpha + self.seq_aux = config.seq_aux + self.topk_method = config.topk_method + self.n_group = config.n_group + self.topk_group = config.topk_group + + self.norm_topk_prob = config.norm_topk_prob + self.gating_dim = config.hidden_size + self.weight = self.param( + "kernel", + nn.initializers.kaiming_uniform(dtype=self.param_dtype), + (self.n_routed_experts, self.gating_dim) + ) + + def __call__(self, hidden_states, deterministic: bool = True): + bsz, seq_len, h = hidden_states.shape + hidden_states = hidden_states.reshape(-1, h) + logits = jax.lax.batch_matmul( + hidden_states.astype(jnp.float32), + self.weight.astype(jnp.float32), + precision=self.precision + ) + if self.scoring_func == "softmax": + scores = jax.nn.softmax(logits.astype(jnp.float32), axis=-1) + else: + raise NotImplementedError( + f"insupportable scoring function for MoE gating: {self.scoring_func}" + ) + + ### select top-k experts + if self.topk_method == "gready": + topk_weight, topk_idx = jax.lax.top_k( + scores, k=self.top_k + ) + elif self.topk_method == "group_limited_greedy": + group_scores = scores.reshape(bsz * seq_len, self.n_group, -1).max(axis=-1) # [n, n_group] + + # Find the indices of the top k scores in each group + top_k_indices = lax.top_k(group_scores, self.topk_group)[1] # [n, topk_group] + + # Initialize a mask with zeros + group_mask = jnp.zeros_like(group_scores) # [n, n_group] + + # Update the mask: this is a bit tricky in JAX as there is no direct scatter function + n_indices = jnp.arange(group_mask.shape[0])[:, None] + group_mask = group_mask.at[n_indices, top_k_indices].set(1) # [n, n_group] + + # Expand and reshape the group_mask + score_mask = jnp.repeat(group_mask[:, :, None], self.n_routed_experts // self.n_group, axis=2) + score_mask = score_mask.reshape(bsz * seq_len, -1) # [n, e] + + # Apply the mask to scores + masked_scores = jnp.where(score_mask, scores, 0.0) # [n, e] + + # Compute the top k scores after masking + topk_weight, topk_idx = lax.top_k(masked_scores, self.top_k) + else: + raise ValueError() + ### norm gate to sum 1 + if self.top_k > 1 and self.norm_topk_prob: + denominator = jnp.sum(topk_weight, axis=-1, keepdims=True) + 1e-20 + topk_weight = topk_weight / denominator + else: + topk_weight = topk_weight * self.routed_scaling_factor + ### expert-level computation auxiliary loss + if not deterministic and self.alpha > 0.0: + scores_for_aux = scores + aux_topk = self.top_k + topk_idx_for_aux_loss = topk_idx.reshape(bsz, -1) + if self.seq_aux: + scores_for_seq_aux = scores_for_aux.reshape(bsz, seq_len, -1) + ce = jnp.zeros(bsz, self.n_routed_experts) + ce = ce.at[1, topk_idx_for_aux_loss].add( + jnp.ones(bsz, seq_len * aux_topk), + ) + ce = jnp.divide(ce, (seq_len * aux_topk / self.n_routed_experts)) + aux_loss = jnp.mean(jnp.sum((ce * jnp.mean(scores_for_seq_aux, axis=-1)), axis=1)) * self.alpha + else: + mask_ce = jax.nn.one_hot( + topk_idx_for_aux_loss.reshape(-1), num_classes=self.n_routed_experts + ) + ce = jnp.mean(mask_ce.astype("float32"), axis=0) + Pi = jnp.mean(scores_for_aux, axis=0) + fi = ce * self.n_routed_experts + aux_loss = jnp.sum(Pi * fi) * self.alpha + else: + aux_loss = None + return topk_idx, topk_weight, aux_loss + + +class FlaxDeepseekV2MLPCollection(nn.Module): + config: DeepseekV2Config + dtype: jnp.dtype = jnp.bfloat16 + param_dtype: jnp.dtype = jnp.bfloat16 + precision: Optional[Union[str, jax.lax.Precision]] = jax.lax.Precision("fastest") + + def setup(self) -> None: + self.experts = [ + FlaxDeepseekV2MLP( + config=self.config, + dtype=self.dtype, + param_dtype=self.param_dtype, + precision=self.precision, + intermediate_size=self.config.moe_intermediate_size, + name=str(i) + ) + for i in range(self.config.n_routed_experts) + ] + + def __call__(self, hidden_states, flat_topk_idx): + y = jnp.empty_like(hidden_states) + for i, expert in enumerate(self.experts): + y[flat_topk_idx == i] = expert(hidden_states[flat_topk_idx == i]) + return y + + +class FlaxDeepseekV2MoE(nn.Module): + config: DeepseekV2Config + dtype: jnp.dtype = jnp.bfloat16 + param_dtype: jnp.dtype = jnp.bfloat16 + precision: Optional[Union[str, jax.lax.Precision]] = jax.lax.Precision("fastest") + + def setup(self) -> None: + config = self.config + self.num_experts_per_tok = config.num_experts_per_tok + + self.ep_size = 1 + self.experts_per_rank = config.n_routed_experts + self.ep_rank = 0 + self.experts = FlaxDeepseekV2MLPCollection( + config=self.config, + dtype=self.dtype, + param_dtype=self.param_dtype, + precision=self.precision, + ) + self.gate = FlaxMoEGate( + config=config, + dtype=self.dtype, + param_dtype=self.param_dtype, + precision=self.precision, + ) + if config.n_shared_experts is not None: + intermediate_size = config.moe_intermediate_size * config.n_shared_experts + self.shared_experts = FlaxDeepseekV2MoE( + config=self.config, + dtype=self.dtype, + param_dtype=self.param_dtype, + precision=self.precision, + intermediate_size=intermediate_size, + ) + + def __call__( + self, + hidden_states: chex.Array, + e: bool = False # ignored ! + ): + + identity = hidden_states + orig_shape = hidden_states.shape + topk_idx, topk_weight, aux_loss = self.gate(hidden_states) + hidden_states = hidden_states.reshape(-1, hidden_states.shape[-1]) + flat_topk_idx = topk_idx.reshape(-1) + hidden_states = hidden_states.repeat(self.num_experts_per_tok, axis=0) + y = self.experts(hidden_states=hidden_states, flat_topk_idx=flat_topk_idx) + y = (y.reshape(*topk_weight.shape, -1) * jnp.expand_dims(topk_weight, -1)).sum(axis=1) + y = y.reshape(*orig_shape) + if self.config.n_shared_experts is not None: + y = y + self.shared_experts(identity) + return y + + +class FlaxDeepseekV2Attention(BaseJAXAttentionModule): + config: DeepseekV2Config + layer_idx: int + dtype: jnp.dtype = jnp.bfloat16 + param_dtype: jnp.dtype = jnp.bfloat16 + precision: Optional[Union[str, jax.lax.Precision]] = jax.lax.Precision("fastest") + + def setup(self) -> None: + config = self.config + + self.attention_dropout = config.attention_dropout + self.hidden_size = config.hidden_size + self.num_heads = config.num_attention_heads + + self.max_position_embeddings = config.max_position_embeddings + self.rope_theta = config.rope_theta + self.q_lora_rank = config.q_lora_rank + self.qk_rope_head_dim = config.qk_rope_head_dim + self.kv_lora_rank = config.kv_lora_rank + self.v_head_dim = config.v_head_dim + self.qk_nope_head_dim = config.qk_nope_head_dim + self.q_head_dim = config.qk_nope_head_dim + config.qk_rope_head_dim + + self.is_causal = True + + dense_class = functools.partial( + nn.Linear, + dtype=self.dtype, + param_dtype=self.param_dtype, + precision=self.precision, + ) + + self.q_a_proj = dense_class(config.q_lora_rank, use_bias=config.attention_bias) + self.q_a_layernorm = DeepseekV2RMSNorm(config.q_lora_rank) + self.q_b_proj = dense_class(self.num_heads * self.q_head_dim, use_bias=False) + + self.kv_a_proj_with_mqa = dense_class( + config.kv_lora_rank + config.qk_rope_head_dim, + use_bias=config.attention_bias, + ) + self.kv_a_layernorm = DeepseekV2RMSNorm(config.kv_lora_rank) + self.kv_b_proj = dense_class( + self.num_heads + * (self.q_head_dim - self.qk_rope_head_dim + self.v_head_dim), + use_bias=False, + ) + + self.o_proj = dense_class(self.hidden_size, use_bias=config.attention_bias) + + softmax_scale = self.q_head_dim ** (-0.5) + if self.config.rope_scaling is not None: + mscale_all_dim = self.config.rope_scaling.get("mscale_all_dim", 0) + scaling_factor = self.config.rope_scaling["factor"] + if mscale_all_dim: + mscale = yarn_get_mscale(scaling_factor, mscale_all_dim) + softmax_scale = self.softmax_scale * mscale * mscale + + self.attention_performer = AttentionModule( + use_sharding_constraint=self.config.use_sharding_constraint, + block_k_major=self.config.block_k_major, + block_b=self.config.block_b, + block_q=self.config.block_q, + block_k=self.config.block_k, + block_q_major_dkv=self.config.block_q_major_dkv, + block_k_major_dkv=self.config.block_k_major_dkv, + block_k_major_dq=self.config.block_k_major_dq, + block_k_dkv=self.config.block_k_dkv, + block_q_dkv=self.config.block_q_dkv, + block_q_dq=self.config.block_q_dq, + block_k_dq=self.config.block_k_dq, + num_attention_heads=self.config.num_attention_heads, + attention_dropout=self.config.attention_dropout, + head_dims=self.q_head_dim, + attention_partition_spec=self.config.attention_partition_spec, + shard_attention_computation=self.config.shard_attention_computation, + precision=self.precision, + force_float32_tpu=True, + attn_mechanism=self.config.attn_mechanism, + dtype=self.dtype, + bias_partition_spec=self.config.bias_partition_spec, + key_partition_spec=self.config.key_partition_spec, + query_partition_spec=self.config.query_partition_spec, + generation_query_partition_spec=self.config.generation_query_partition_spec, + generation_bias_partition_spec=self.config.generation_bias_partition_spec, + generation_attention_partition_spec=self.config.generation_attention_partition_spec, + value_partition_spec=self.config.value_partition_spec, + scan_ring_attention=self.config.scan_ring_attention, + mesh=self.config.jax_mesh(), + sm_scale=softmax_scale, + axis_name=self.config.attention_axis_name + ) + + def __call__( + self, + hidden_states: chex.Array, + freq_cis: Tuple[chex.Array, chex.Array], + attention_mask: chex.Array, + position_ids: chex.Array, + causal_mask: chex.Array, + segment_ids: Optional[chex.Array] = None, + deterministic: bool = True, + init_cache: bool = False, + output_attentions: bool = False, + fcm_mask=None, + ): + bsz, q_len, _ = hidden_states.shape + + q = self.q_b_proj(self.q_a_layernorm(self.q_a_proj(hidden_states))) + q = q.view(bsz, q_len, self.num_heads, self.q_head_dim).transpose(0, 2, 1, 3) + q_nope = q[:, :, :, :self.qk_nope_head_dim] + q_pe = q[:, :, :, self.qk_nope_head_dim:] + + compressed_kv = self.kv_a_proj_with_mqa(hidden_states) + + k_pe = compressed_kv[:, :, :, self.kv_lora_rank:self.kv_lora_rank + self.qk_rope_head_dim] + compressed_kv = compressed_kv[:, :, :, :self.kv_lora_rank] + + k_pe = k_pe.reshape(bsz, q_len, 1, self.qk_rope_head_dim).transpose(0, 2, 1, 3) + kv = ( + self.kv_b_proj(self.kv_a_layernorm(compressed_kv)) + .reshape(bsz, q_len, self.num_heads, self.qk_nope_head_dim + self.v_head_dim).transpose(0, 2, 1, 3) + ) + + k_nope = kv[:, :, :, :self.qk_nope_head_dim] + value_states = kv[:, :, :, self.qk_nope_head_dim:self.qk_nope_head_dim + self.v_head_dim] + + sin, cos = freq_cis + + q_pe, k_pe = apply_rotary_pos_emb(q_pe, k_pe, cos, sin, position_ids) + + query_states = k_pe.new_empty(bsz, self.num_heads, q_len, self.q_head_dim) + query_states[:, :, :, : self.qk_nope_head_dim] = q_nope + query_states[:, :, :, self.qk_nope_head_dim:] = q_pe + + key_states = k_pe.new_empty(bsz, self.num_heads, q_len, self.q_head_dim) + key_states[:, :, :, : self.qk_nope_head_dim] = k_nope + key_states[:, :, :, self.qk_nope_head_dim:] = k_pe + + query_states = query_states.transpose(0, 2, 1, 3) + key_states = key_states.transpose(0, 2, 1, 3) + value_states = value_states.transpose(0, 2, 1, 3) + + query_length, key_length = query_states.shape[1], key_states.shape[1] + + if self.has_variable("cache", "cached_key"): + mask_shift = self.variables["cache"]["cache_index"] + max_decoder_length = self.variables["cache"]["cached_key"].shape[1] + causal_mask = lax.dynamic_slice( + causal_mask, (0, 0, mask_shift, 0), (1, 1, + query_length, max_decoder_length) + ) + else: + causal_mask = causal_mask[:, :, :query_length, :key_length] + + batch_size = hidden_states.shape[0] + causal_mask = jnp.broadcast_to( + causal_mask, (batch_size,) + causal_mask.shape[1:]) + if attention_mask.ndim == 2: + attention_mask = jnp.expand_dims(attention_mask, axis=(-3, -2)) + attention_mask = jnp.broadcast_to( + attention_mask, causal_mask.shape + ) + attention_mask = combine_masks(attention_mask, causal_mask, fcm_mask) + + dropout_rng = None + + if not deterministic and self.config.attention_dropout > 0.0: + dropout_rng = self.make_rng("dropout") + if self.has_variable("cache", "cached_key") or init_cache: + key_states, value_states, attention_mask = self._concatenate_to_cache( + key_states, + value_states, + query_states, + attention_mask + ) + # if self.config.use_sharding_constraint: + # query_states = with_sharding_constraint( + # query_states, PartitionSpec(("dp", "fsdp"), "sp" if query_states.shape[1] != 1 else None, "tp", None) + # ) + # key_states = with_sharding_constraint( + # key_states, PartitionSpec(("dp", "fsdp"), "sp", "tp", None) + # ) + # value_states = with_sharding_constraint( + # value_states, PartitionSpec(("dp", "fsdp"), "sp", "tp", None) + # ) + attention_bias = lax.select( + attention_mask > 0, + jnp.full(attention_mask.shape, 0.0).astype(self.dtype), + jnp.full(attention_mask.shape, jnp.finfo( + self.dtype).min).astype(self.dtype), + ) + + query_length, key_length = query_states.shape[1], key_states.shape[1] + + attentions = self.attention_performer.__call__( + query_states=query_states, + key_states=key_states, + value_states=value_states, + bias=attention_bias, + attention_mask=attention_mask, + causal=True, + dropout_rng=dropout_rng, + deterministic=deterministic, + query_sequence_length=query_length, + key_value_sequence_length=key_length, + uses_cache=self.has_variable("cache", "cached_key") or init_cache, + segment_ids=segment_ids, + causal_mask=causal_mask + ) + + attn_output = self._merge_heads(attentions.attention_outputs) + if self.config.shard_attention_computation: + attn_output = with_sharding_constraint( + attn_output, PartitionSpec( + ("dp", "fsdp"), + "sp" if attn_output.shape[1] != 1 else None, + "tp" + ) + ) + attn_output = self.o_proj(attn_output) + + outputs = ( + attn_output, attentions.attention_weights + ) if output_attentions else ( + attn_output, + ) + return outputs + + +class FlaxDeepseekV2DecoderLayer(nn.Module): + config: DeepseekV2Config + layer_idx: int + dtype: jnp.dtype = jnp.bfloat16 + param_dtype: jnp.dtype = jnp.bfloat16 + precision: Optional[Union[str, jax.lax.Precision]] = jax.lax.Precision("fastest") + + def setup(self): + config = self.config + layer_idx = self.layer_idx + self.hidden_size = config.hidden_size + + attn_block = FlaxDeepseekV2Attention + mlp_block = FlaxDeepseekV2MLP + mlp_moe_block = FlaxDeepseekV2MoE + + if self.config.gradient_checkpointing != "": + # hidden_states: chex.Array, + # freq_cis: Tuple[chex.Array, chex.Array], + # attention_mask: chex.Array, + # position_ids: chex.Array, + # causal_mask: chex.Array, + # segment_ids: Optional[chex.Array] = None, + # deterministic: bool = True, + # init_cache: bool = False, + # output_attentions: bool = False, + # fcm_mask = None, + attn_block = re_mat( + attn_block, + policy=get_gradient_checkpoint_policy(self.config.gradient_checkpointing), + static_argnums=(1, 3, 4, 6, 7, 8) + ) + mlp_block = re_mat( + mlp_block, + policy=get_gradient_checkpoint_policy(self.config.gradient_checkpointing), + static_argnums=(1,) + ) + + mlp_moe_block = re_mat( + mlp_moe_block, + policy=get_gradient_checkpoint_policy(self.config.gradient_checkpointing), + static_argnums=(1,) + ) + + self.self_attn = attn_block( + config=config, + layer_idx=self.layer_idx, + dtype=self.dtype, + param_dtype=self.param_dtype, + precision=self.precision + ) + + self.mlp = ( + mlp_moe_block( + config=config, + dtype=self.dtype, + param_dtype=self.param_dtype, + precision=self.precision + ) + if ( + config.n_routed_experts is not None + and layer_idx >= config.first_k_dense_replace + and layer_idx % config.moe_layer_freq == 0 + ) + else mlp_block( + config=config, + dtype=self.dtype, + param_dtype=self.param_dtype, + precision=self.precision + ) + ) + self.input_layernorm = DeepseekV2RMSNorm( + config.hidden_size, + eps=config.rms_norm_eps, + dtype=self.dtype, + param_dtype=self.param_dtype + ) + self.post_attention_layernorm = DeepseekV2RMSNorm( + config.hidden_size, + eps=config.rms_norm_eps, + dtype=self.dtype, + param_dtype=self.param_dtype + ) + + def forward( + self, + hidden_states: chex.Array, + freq_cis: Tuple[chex.Array, chex.Array], + attention_mask: chex.Array, + causal_mask: chex.Array, + position_ids: chex.Array, + segment_ids: Optional[chex.Array] = None, + deterministic: bool = True, + init_cache: bool = False, + output_attentions: bool = True + ) -> Tuple[ + chex.Array, Optional[chex.Array] + ]: + + residual = hidden_states + + hidden_states = self.input_layernorm(hidden_states) + + # Self Attention + hidden_states, self_attn_weights = self.self_attn( + hidden_states, + freq_cis, + attention_mask, + position_ids, + causal_mask, + segment_ids, + deterministic, + init_cache, + output_attentions, + None + ) + hidden_states = residual + hidden_states + + # Fully Connected + residual = hidden_states + hidden_states = self.post_attention_layernorm(hidden_states) + + if self.config.use_scan_mlp: + feed_forward_hidden_states = block_wise_ffn( + self.mlp, + hidden_states, + self.config.scan_mlp_chunk_size, + deterministic, + ) + else: + feed_forward_hidden_states = self.mlp( + hidden_states, + deterministic, + ) + hidden_states = residual + feed_forward_hidden_states + + outputs = (hidden_states,) + + if output_attentions: + outputs += (self_attn_weights,) + + return outputs # type:ignore + + +class FlaxDeepseekV2DecoratorCollection(nn.Module): + config: DeepseekV2Config + dtype: jnp.dtype = jnp.bfloat16 + param_dtype: jnp.dtype = jnp.bfloat16 + precision: Optional[Union[str, jax.lax.Precision] + ] = jax.lax.Precision("fastest") + + def setup(self) -> None: + self.layers = [ + FlaxDeepseekV2DecoderLayer( + config=self.config, + dtype=self.dtype, + param_dtype=self.param_dtype, + precision=self.precision, + layer_idx=i, + name=str(i) + ) for i in range(self.config.num_hidden_layers) + ] + + def __call__( + self, + hidden_states: chex.Array, + freq_cis: Tuple[chex.Array, chex.Array], + attention_mask: chex.Array, + causal_mask: chex.Array, + position_ids: chex.Array, + deterministic: bool = True, + init_cache: bool = False, + output_attentions: bool = False, + output_hidden_states: bool = False + ): + all_attentions = () if output_attentions else None + all_hidden_states = () if output_hidden_states else None + for layer in self.layers: + if output_hidden_states: + all_hidden_states += (hidden_states,) + + # hidden_states: chex.Array, + # freq_cis: Tuple[chex.Array, chex.Array], + # attention_mask: chex.Array, + # causal_mask: chex.Array, + # position_ids: chex.Array, + # segment_ids: Optional[chex.Array] = None, + # deterministic: bool = True, + # init_cache: bool = False, + # output_attentions: bool = True + + output = layer( + hidden_states, + freq_cis, + attention_mask, + causal_mask, + position_ids, + None, + deterministic, + init_cache, + output_attentions + ) + hidden_states = output[0] + + if output_attentions: + output_attentions += (output[1],) + + return hidden_states, all_hidden_states, all_attentions + + +class FlaxDeepseekV2Module(nn.Module): + config: DeepseekV2Config + dtype: jnp.dtype = jnp.bfloat16 + param_dtype: jnp.dtype = jnp.bfloat16 + precision: Optional[Union[jax.lax.Precision, str]] = None + + def setup(self): + + self.embed_tokens = nn.Embed( + self.config.vocab_size, + self.config.hidden_size, + embedding_init=jax.nn.initializers.normal( + stddev=self.config.initializer_range), + dtype=self.dtype, + param_dtype=self.param_dtype, + ) + + self.layers = FlaxDeepseekV2DecoratorCollection( + self.config, + dtype=self.dtype, + param_dtype=self.param_dtype, + precision=self.precision + ) + self.norm = DeepseekV2RMSNorm( + self.config.hidden_size, + eps=self.config.rms_norm_eps, + dtype=self.dtype, + param_dtype=self.param_dtype + ) + + initial_rope_kwargs = {} + method = None + if self.config.rope_scaling is not None: + scaling_type = self.config.rope_scaling["type"] + method = scaling_type + if scaling_type != "yarn": + initial_rope_kwargs = dict(scaling_factor=self.config.rope_scaling["factor"]) + else: + initial_rope_kwargs = { + key: self.config.rope_scaling[key] + for key in [ + "original_max_position_embeddings", + "beta_fast", + "beta_slow", + "mscale", + "mscale_all_dim", + ] + if key in self.config.rope_scaling + } + initial_rope_kwargs["scaling_factor"] = self.config.rope_scaling["factor"] + self.freq_cis = init_deepseek_rotary_embedding( + dim=self.config.hidden_size // self.config.num_attention_heads, + max_position_embeddings=( + getattr( + self.config, + "freq_max_position_embeddings", + self.config.max_position_embeddings + ) + ), + base=self.config.rope_theta, + method=method, # type:ignore + kwargs=initial_rope_kwargs + ) + self.causal_mask = flax.linen.make_causal_mask( + jnp.ones( + ( + 1, + getattr( + self.config, + "c_max_position_embeddings", + self.config.max_position_embeddings + ) + ), + dtype="bool" + ), + dtype="bool" + ) + + def __call__( + self, + input_ids: Optional[chex.Array] = None, + attention_mask: Optional[chex.Array] = None, + position_ids: Optional[chex.Array] = None, + deterministic: bool = True, + inputs_embeds: chex.Array = None, + init_cache: bool = False, + output_attentions: bool = False, + output_hidden_states: bool = False, + return_dict: bool = True, + ) -> typing.Union[Tuple[chex.Array, ...], FlaxBaseModelOutput]: + """The __call__ function is the main function of a Flax model. + It takes in input_ids, attention_mask, and position_ids as inputs to the model. + The output is a tuple containing: last hidden state (hidden states), all hidden states (if output_hidden_states=True), attentions (if output attentions=True). + + Args: + self: Represent the instance of the class + input_ids: chex.Array: Pass in the input ids + attention_mask: chex.Array: Mask out the attention weights + for certain tokens + position_ids: chex.Array: Determine the position of each + token in a sequence + deterministic: bool: Determine whether to use dropout or not + inputs_embeds: chex.Array: Pass in the embedding of the + input_ids + init_cache: bool: Initialize the cache for the decoder + output_attentions: bool: Determine whether to return the + attention weights or not + output_hidden_states: bool: Return all hidden states or just + the last one + return_dict: bool: Return a dictionary of the outputs or not + :param : Determine whether the model is in training mode or not + + Returns: + A tuple of the hidden states, all hidden states, and + attentions + """ + if inputs_embeds is None: + inputs_embeds = self.embed_tokens(input_ids.astype("i4")) + if attention_mask.ndim == 2: + b, s = attention_mask.shape + attention_mask = attention_mask.reshape(b, 1, 1, s) + + outputs = self.layers( + hidden_states=inputs_embeds, + attention_mask=attention_mask, + position_ids=position_ids, + freq_cis=self.freq_cis, + init_cache=init_cache, + output_attentions=output_attentions, + deterministic=deterministic, + causal_mask=self.causal_mask + ) + + hidden_states = outputs[0] + hidden_states = self.norm(hidden_states) + + if output_hidden_states: + all_hidden_states = outputs[1] + (hidden_states,) + outputs = (hidden_states, all_hidden_states) + outputs[2:] + else: + outputs = (hidden_states,) + outputs[1:] + + if not return_dict: + return tuple(value for value in outputs if value is not None) + + return FlaxBaseModelOutput( + last_hidden_state=hidden_states, + hidden_states=outputs[1], + attentions=outputs[-1], + ) + + +class DeepseekV2PreTrainedModel(EasyDeLFlaxPretrainedModel): + config_class: DeepseekV2Config = DeepseekV2Config + module_class: nn.Module = None + base_model_prefix = "model" + + def __init__( + self, + config: DeepseekV2Config, + dtype: jnp.dtype = jnp.bfloat16, + param_dtype: jnp.dtype = jnp.bfloat16, + precision: Optional[jax.lax.Precision] = jax.lax.Precision("fastest"), + input_shape: Tuple[int, int] = (1, 1), + seed: int = 0, + _do_init: bool = False, + **kwargs + ): + module = self.module_class( + config=config, + dtype=dtype, + param_dtype=param_dtype, + precision=precision, + **kwargs + ) + + super().__init__( + dtype=dtype, _do_init=_do_init, + module=module, config=config, input_shape=input_shape, + seed=seed, + ) + + def init_weights( + self, + rng: jax.random.PRNGKey, + input_shape: Tuple, + params: FrozenDict = None + ) -> FrozenDict: + """The init_weights function is used to initialize the weights of a model. + It takes in a rng, which is a random number generator key that can be used to generate random numbers. + The input_shape parameter specifies the shape of the inputs that will be fed into this model. + The params parameter allows you to pass in pre-trained weights for your model, if you have them available. + + Args: + self: Access variables that belong to the class + rng: jax.random.PRNGKey: Initialize the weights of the model + input_shape: Tuple: Initialize the input_ids, attention_mask + and position_ids + params: flax.core.FrozenDict: Pass in the parameters of a + pre-trained model + + Returns: + A frozendict of parameters + """ + + self.config.initialization_of_moe = True + input_ids = jnp.zeros(input_shape, dtype="i4") + attention_mask = jnp.ones_like(input_ids, dtype="i4") + position_ids = jnp.broadcast_to( + jnp.arange(jnp.atleast_2d(input_ids).shape[-1], dtype="i4"), + input_shape, + ) + params_rng, dropout_rng = jax.random.split(rng) + rngs = {"params": params_rng, "dropout": dropout_rng} + if self.config.add_cross_attention: + encoder_hidden_states = jnp.zeros( + input_shape + (self.config.hidden_size,)) + encoder_attention_mask = attention_mask + module_init_outputs = self.module.init( + rngs, + input_ids, + attention_mask, + position_ids, + encoder_hidden_states, + encoder_attention_mask, + return_dict=False, + ) + else: + module_init_outputs = self.module.init( + rngs, + input_ids=input_ids, + attention_mask=attention_mask, + position_ids=position_ids, + return_dict=False + ) + random_params = module_init_outputs["params"] + + self.config.initialization_of_moe = False + if params is not None: + random_params = flatten_dict(unfreeze(random_params)) + params = flatten_dict(unfreeze(params)) + for missing_key in self._missing_keys: + params[missing_key] = random_params[missing_key] + self._missing_keys = set() + return freeze(unflatten_dict(params)) + else: + return random_params + + def init_cache(self, batch_size, max_length): + + input_ids = jnp.ones((batch_size, max_length)) + attention_mask = jnp.ones_like(input_ids) + position_ids = jnp.broadcast_to(jnp.arange( + jnp.atleast_2d(input_ids).shape[-1]), input_ids.shape) + + init_variables = self.module.init( + jax.random.PRNGKey(0), input_ids, attention_mask, position_ids, return_dict=False, init_cache=True + ) + return init_variables["cache"] + + def __call__( + self, + input_ids: chex.Array, + attention_mask: Optional[chex.Array] = None, + position_ids: Optional[chex.Array] = None, + params: dict = None, + past_key_values: dict = None, + dropout_rng: jax.random.PRNGKey = None, + train: bool = False, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + output_router_logits: Optional[bool] = None, + return_dict: Optional[bool] = None, + add_params_field: bool = False, + **kwargs + ): + """The __call__ function is the main function of a JAX module. + It takes as input: + - The parameters of the model (self.params) + - The inputs to the model (input_ids, attention_mask, position_ids) + - Whether we are training (train=True/False) and whether we want to return all hidden states and + attentions weights at each layer in addition to just the last layer output (output_hidden_states=True/False). + + Args: + self: Represent the instance of the class + input_ids: Pass the input sequence to the model + attention_mask: Mask out the padding tokens + position_ids: Specify the position of each token in the + sequence + params: dict: Pass in the parameters of the model + past_key_values: dict: Pass the past key values to the model + dropout_rng: jax.random.PRNGKey: Pass in a random number + generator key to the model + train: bool: Determine whether to use dropout or not + output_attentions: Optional[bool]: Determine whether to + return the attention weights + output_hidden_states: Optional[bool]: Determine whether to + return the hidden states of all layers + return_dict: Optional[bool]: Return a dictionary of the + outputs + add_params_field: bool: Add a params field to the inputs + dictionary + + Returns: + A tuple of (last_hidden_state, past_key_values) + """ + + output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions + output_hidden_states = ( + output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states + ) + return_dict = return_dict if return_dict is not None else self.config.return_dict + + batch_size, sequence_length = input_ids.shape + + if position_ids is None: + if past_key_values is not None: + raise ValueError( + "Make sure to provide `position_ids` when passing `past_key_values`.") + + position_ids = jnp.broadcast_to(jnp.arange(sequence_length)[ + None, :], (batch_size, sequence_length)) + + if attention_mask is None: + attention_mask = jnp.ones((batch_size, sequence_length)) + + rng_s = {} + if dropout_rng is not None: + rng_s["dropout"] = dropout_rng + + inputs = { + "params": params or self.params} if add_params_field else params or self.params + + if self.config.bits is not None: + rng_s['params'] = jax.random.key(0) + if past_key_values: + inputs["cache"] = past_key_values + mutable = ["cache"] + else: + mutable = False + + outputs = self.module.apply( + inputs, + jnp.array(input_ids, dtype="i4"), # input_ids: chex.Array + # attention_mask: Optional[chex.Array] = None + jnp.array(attention_mask, dtype="i4"), + # position_ids: Optional[chex.Array] = None + jnp.array(position_ids, dtype="i4"), + None, # inputs_embeds: Optional[chex.Array] = None + output_attentions, # output_attentions: Optional[bool] = None + # output_hidden_states: Optional[bool] = None + output_hidden_states, + # output_router_logits: Optional[bool] = None + output_router_logits, + False, # init_cache: bool = False + not train, # deterministic: bool = True + return_dict, # return_dict: bool = True + rngs=rng_s, + mutable=mutable, + ) + + if past_key_values is not None and return_dict: + outputs, past_key_values = outputs + outputs["past_key_values"] = unfreeze(past_key_values["cache"]) + return outputs + elif past_key_values is not None and not return_dict: + outputs, past_key_values = outputs + outputs = outputs[:1] + \ + (unfreeze(past_key_values["cache"]),) + outputs[1:] + + return outputs + + +class FlaxDeepseekV2ForCausalLMModule(nn.Module): + config: DeepseekV2Config + dtype: jnp.dtype = jnp.bfloat16 + param_dtype: jnp.dtype = jnp.bfloat16 + precision: Optional[jax.lax.Precision] = jax.lax.Precision("fastest") + + def setup(self) -> None: + self.model = FlaxDeepseekV2Module( + config=self.config, + dtype=self.dtype, + param_dtype=self.param_dtype, + precision=self.precision + ) + self.lm_head = nn.Linear( + self.config.vocab_size, + dtype=self.dtype, + param_dtype=self.param_dtype, + precision=self.precision, + use_bias=False, + kernel_init=nn.initializers.normal(self.config.initializer_range), + **get_dot_general_by_bits(self.config.bits, self.config.easy_method) + ) + + def __call__( + self, + input_ids: chex.Array, + attention_mask: chex.Array, + position_ids: chex.Array, + deterministic: bool = True, + inputs_embeds: chex.Array = None, + init_cache: bool = False, + output_attentions: bool = False, + output_hidden_states: bool = False, + return_dict: bool = True, + ): + """The __call__ function is the main function of a Flax module. It defines how the model will be called, + and what it returns. In this case, we are calling our Transformer model with input_ids and attention_mask + as inputs (these are defined in __init__). We also have some optional arguments that can be passed to + the call function: deterministic (whether to use dropout), inputs_embeds (if you want to pass your own embeddings), + output_attentions and output_hidden states which return additional outputs from the transformer layers if set True. Finally, + + Args: + self: Refer to the object itself + input_ids: chex.Array: Pass in the input tokens + attention_mask: chex.Array: Mask out the padding tokens + position_ids: chex.Array: Specify the position of each token + in the sequence + deterministic: bool: Determine whether to use dropout in the + model + inputs_embeds: chex.Array: Pass in the embeddings of the + input tokens + init_cache: bool: Initialize the cache for the decoder + output_attentions: bool: Return the attention weights + output_hidden_states: bool: Return the hidden states of all + layers + return_dict: bool: Return a dictionary of the outputs or + just the logits + :param : Determine whether to return the logits or not + + Returns: + A tuple of (lm_logits, hidden_states, attentions) + """ + batch_size, seq_length = input_ids.shape + + if attention_mask is None: + attention_mask = jnp.ones_like(input_ids) + if position_ids is None: + position_ids = jnp.broadcast_to( + jnp.clip(jnp.cumsum(attention_mask, axis=-1) - 1, a_min=0), + (batch_size, seq_length) + ) + outputs = self.model( + input_ids=input_ids, + attention_mask=attention_mask, + position_ids=position_ids, + deterministic=deterministic, + inputs_embeds=inputs_embeds, + init_cache=init_cache, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict + ) + + hidden_states = outputs[0] + + if self.config.tie_word_embeddings: + shared_kernel = self.transformer.variables["params"]["embed_tokens"]["embedding"] + shared_kernel = fjformer.linen.linen.control_quantization(shared_kernel, self.param_dtype).T + lm_logits = self.lm_head.apply( + {"params": {"kernel": shared_kernel}}, hidden_states) + else: + lm_logits = self.lm_head(hidden_states) + + # lm_logits = lm_logits.astype(jnp.float32) + + if not return_dict: + return (lm_logits,) + outputs[1:] + + return FlaxCausalLMOutput(logits=lm_logits, hidden_states=outputs.hidden_states, attentions=outputs.attentions) + + +class FlaxDeepseekV2Model(DeepseekV2PreTrainedModel): + module_class = FlaxDeepseekV2Module + + +class FlaxDeepseekV2ForCausalLM(DeepseekV2PreTrainedModel): + module_class = FlaxDeepseekV2ForCausalLMModule + + def set_input_embeddings(self, value): + self.module.model.embed_tokens = value + + def get_input_embeddings(self): + return self.module.model.embed_tokens + + def set_decoder(self, decoder): + self.module.model = decoder + + def get_decoder(self): + return self.module.model + + def get_output_embeddings(self): + return self.module.lm_head + + def set_output_embeddings(self, new_embeddings): + self.module.lm_head = new_embeddings + + def prepare_inputs_for_generation(self, input_ids, max_length, attention_mask: Optional[chex.Array] = None): + """The prepare_inputs_for_generation function is used to prepare the inputs for a generation task. + + Args: + self: Access variables that belong to the class + input_ids: Pass in the input tokens + max_length: Set the length of the sequence to be generated + attention_mask: Optional[chex.Array]: Mask the attention + weights + + Returns: + A dictionary of the past_key_values, attention_mask and + position ids + """ + batch_size, seq_length = input_ids.shape + + past_key_values = self.init_cache(batch_size, max_length) + extended_attention_mask = jnp.ones( + (batch_size, max_length), dtype="i4") + if attention_mask is not None: + position_ids = attention_mask.cumsum(axis=-1) - 1 + extended_attention_mask = lax.dynamic_update_slice( + extended_attention_mask, attention_mask, (0, 0)) + else: + position_ids = jnp.broadcast_to(jnp.arange(seq_length, dtype="i4")[ + None, :], (batch_size, seq_length)) + + return { + "past_key_values": past_key_values, + "attention_mask": extended_attention_mask, + "position_ids": position_ids, + } + + def update_inputs_for_generation(self, model_outputs, model_kwargs): + model_kwargs["past_key_values"] = model_outputs.past_key_values + model_kwargs["position_ids"] = model_kwargs["position_ids"][:, -1:] + 1 + return model_kwargs diff --git a/src/python/easydel/modules/easydel_modelling_utils.py b/src/python/easydel/modules/easydel_modelling_utils.py index 9bfc572e8..4df9f85e8 100644 --- a/src/python/easydel/modules/easydel_modelling_utils.py +++ b/src/python/easydel/modules/easydel_modelling_utils.py @@ -43,32 +43,42 @@ class EasyMethod: class EasyDeLPretrainedConfig(PretrainedConfig): - """ - It initializes all the attributes of an object, and it's called when you create a new instance of that class. - :param self: Refer to the instance of the class - :param axis_dims: Sequence[int]: Specify the number of dimensions for each axis - :param axis_names: Sequence[str]: Set the names of the axes - :param attn_mechanism: Literal["vanilla", "flash", "splash", "ring"]: attention mechanism to use - :param block_k: int: block size of key_states - :param block_q: int: block size of query_states - :param block_b: int: block size of bias - :param block_q_major_dkv: int: block size of block_q_major_dkv - :param block_k_major_dkv: int: block size of block_k_major_dkv - :param block_k_dkv: int: block size of block_k_dkv - :param block_q_dkv: int: block size of block_q_dkv - :param block_k_major_dq: int: block size of block_k_major_dq - :param block_k_dq: int: block size of block_k_dq - :param block_q_dq: int: block size of block_q_dq - :param query_partition_spec: PartitionSpec: Specify the partitioning of the query tensor - :param key_partition_spec: PartitionSpec: Partition the key matrix - :param value_partition_spec: PartitionSpec: Specify the partitioning of the value tensor - :param bias_partition_spec: PartitionSpec: Specify the Attention Bias partition spec - :param attention_partition_spec: PartitionSpec: Specify the partitioning of the attention weights - :param shard_attention_computation: bool: whenever to shard qkv b for attention - :param use_sharding_constraint: bool: whether to use sharding constraint for the arrays - :param use_scan_mlp: bool: Determine whether to use scan_mlp or not - :param backend: Optional[None]: Specify the backend to use - :param flash_attention_backward_pass_impl: Literal["triton", "xla"]: Specify the backward pass kernel for flash attention + """It initializes all the attributes of an object, and it's called when you create a new instance of that class. + + Args: + self: Refer to the instance of the class + axis_dims: Sequence[int]: Specify the number of dimensions for + each axis + axis_names: Sequence[str]: Set the names of the axes + attn_mechanism: Literal["vanilla", "flash", "splash", "ring"]: + attention mechanism to use + block_k: int: block size of key_states + block_q: int: block size of query_states + block_b: int: block size of bias + block_q_major_dkv: int: block size of block_q_major_dkv + block_k_major_dkv: int: block size of block_k_major_dkv + block_k_dkv: int: block size of block_k_dkv + block_q_dkv: int: block size of block_q_dkv + block_k_major_dq: int: block size of block_k_major_dq + block_k_dq: int: block size of block_k_dq + block_q_dq: int: block size of block_q_dq + query_partition_spec: PartitionSpec: Specify the partitioning of + the query tensor + key_partition_spec: PartitionSpec: Partition the key matrix + value_partition_spec: PartitionSpec: Specify the partitioning of + the value tensor + bias_partition_spec: PartitionSpec: Specify the Attention Bias + partition spec + attention_partition_spec: PartitionSpec: Specify the + partitioning of the attention weights + shard_attention_computation: bool: whenever to shard qkv b for + attention + use_sharding_constraint: bool: whether to use sharding + constraint for the arrays + use_scan_mlp: bool: Determine whether to use scan_mlp or not + backend: Optional[None]: Specify the backend to use + flash_attention_backward_pass_impl: Literal["triton", "xla"]: + Specify the backward pass kernel for flash attention """ def __init__( @@ -151,14 +161,15 @@ def __init__( def create_mesh( axis_dims: Sequence[int] = (1, -1, 1, 1), axis_names: Sequence[str] = ("dp", "fsdp", "tp", "sp"), backend="" ): - """ - The create_mesh function creates a mesh object that can be used to shard arrays. + """The create_mesh function creates a mesh object that can be used to shard arrays. - :param axis_dims: Sequence[int]: Specify the dimensions of the mesh - :param axis_names: Sequence[str]: Name the axes of the mesh - :param backend: Specify the backend to use - :return: A mesh object + Args: + axis_dims: Sequence[int]: Specify the dimensions of the mesh + axis_names: Sequence[str]: Name the axes of the mesh + backend: Specify the backend to use + Returns: + A mesh object """ array_devices = jax.numpy.ones( (len(jax.devices() if backend == "" else jax.devices(backend)), 1)) @@ -183,14 +194,15 @@ def create_mesh( ) def jax_mesh(self) -> Mesh: - """ - The jax_mesh function is a helper function that creates a Mesh object from the + """The jax_mesh function is a helper function that creates a Mesh object from the axis_dims and axis_names attributes of an object, which are assumed to be lists of integers and strings, respectively. The backend attribute is also used if it exists. - :param self: Refer to the object itself - :return: A jaxMesh + Args: + self: Refer to the object itself + Returns: + A jaxMesh """ return self.create_mesh( axis_dims=[v for k, v in self.axis_dims.items()] if isinstance( @@ -207,12 +219,15 @@ def jax_mesh(self) -> Mesh: def get_partition_rules(self, fully_sharded_data_parallel: bool = True): - """ - The get_partition_rules function is used to specify how the parameters of a model are partitioned across devices. + """The get_partition_rules function is used to specify how the parameters of a model are partitioned across devices. - :param self: Access the attributes of the class - :param fully_sharded_data_parallel: bool: Determine whether the model is fully sharded or not - :return: A tuple of tuples + Args: + self: Access the attributes of the class + fully_sharded_data_parallel: bool: Determine whether the + model is fully sharded or not + + Returns: + A tuple of tuples """ if not fully_sharded_data_parallel: raise NotImplementedError() @@ -222,33 +237,36 @@ def get_partition_rules(self, fully_sharded_data_parallel: bool = True): ) def get_axis_dims(self) -> Sequence[int]: - """ - The get_axis_dims function returns a sequence of integers representing the dimensions of each axis. + """The get_axis_dims function returns a sequence of integers representing the dimensions of each axis. - :param self: Represent the instance of the class - :return: The dimensions of the axes + Args: + self: Represent the instance of the class + Returns: + The dimensions of the axes """ return self.axis_dims def get_axis_names(self) -> Sequence[str]: - """ - The get_axis_names function returns a list of the names of the axes. + """The get_axis_names function returns a list of the names of the axes. - :param self: Represent the instance of the class - :return: A list of the names of all axes + Args: + self: Represent the instance of the class + Returns: + A list of the names of all axes """ return self.axis_names def get_backend(self) -> str: - """ - The get_backend function returns the backend that is currently being used. + """The get_backend function returns the backend that is currently being used. If no backend has been set, it will return the default JAX backend. - :param self: Bind the method to an object - :return: The backend platform + Args: + self: Bind the method to an object + Returns: + The backend platform """ return self.backend if not self.backend == "" else jax.lib.xla_bridge.get_backend().platform @@ -290,47 +308,66 @@ def add_basic_configurations( quantize_kv_cache: bool = ..., flash_attention_backward_pass_impl: Literal["triton", "xla"] = ... ): - """ - It initializes all the attributes of an object, and it's called when you create a new instance of that class. - :param self: Refer to the instance of the class - :param axis_dims: Sequence[int]: Specify the number of dimensions for each axis - :param axis_names: Sequence[str]: Set the names of the axes - :param attn_mechanism: Literal["vanilla", "flash", "splash"]: attention mechanism to use - :param block_k: int: block size of key_states - :param block_q: int: block size of query_states - :param block_b: int: block size of bias - :param block_k_major: int: block size if key major - :param block_q_major_dkv: int: block size of block_q_major_dkv - :param block_k_major_dkv: int: block size of block_k_major_dkv - :param block_k_dkv: int: block size of block_k_dkv - :param block_q_dkv: int: block size of block_q_dkv - :param block_k_major_dq: int: block size of block_k_major_dq - :param block_k_dq: int: block size of block_k_dq - :param block_q_dq: int: block size of block_q_dq - :param query_partition_spec: PartitionSpec: Specify the partitioning of the query tensor - :param key_partition_spec: PartitionSpec: Partition the key matrix - :param value_partition_spec: PartitionSpec: Specify the partitioning of the value tensor - :param bias_partition_spec: PartitionSpec: Specify the Attention Bias partition spec - :param attention_partition_spec: PartitionSpec: Specify the partitioning of the attention weights - :param generation_attention_partition_spec: : PartitionSpec: Specify the partitioning of the attention weights + """It initializes all the attributes of an object, and it's called when you create a new instance of that class. + + Args: + self: Refer to the instance of the class + axis_dims: Sequence[int]: Specify the number of dimensions + for each axis + axis_names: Sequence[str]: Set the names of the axes + attn_mechanism: Literal["vanilla", "flash", "splash"]: + attention mechanism to use + block_k: int: block size of key_states + block_q: int: block size of query_states + block_b: int: block size of bias + block_k_major: int: block size if key major + block_q_major_dkv: int: block size of block_q_major_dkv + block_k_major_dkv: int: block size of block_k_major_dkv + block_k_dkv: int: block size of block_k_dkv + block_q_dkv: int: block size of block_q_dkv + block_k_major_dq: int: block size of block_k_major_dq + block_k_dq: int: block size of block_k_dq + block_q_dq: int: block size of block_q_dq + query_partition_spec: PartitionSpec: Specify the + partitioning of the query tensor + key_partition_spec: PartitionSpec: Partition the key matrix + value_partition_spec: PartitionSpec: Specify the + partitioning of the value tensor + bias_partition_spec: PartitionSpec: Specify the Attention + Bias partition spec + attention_partition_spec: PartitionSpec: Specify the + partitioning of the attention weights + generation_attention_partition_spec: : PartitionSpec: + Specify the partitioning of the attention weights + generation_bias_partition_spec: : PartitionSpec: Specify the + partitioning of the Attention Bias partition spec in + generation process + generation_query_partition_spec: : PartitionSpec: Specify + the partitioning of the query tensor + shard_attention_computation: bool: whenever to use shard_map + for attention + use_sharded_kv_caching: bool: whenever to use shard_map and + sharding for key and value + backend: Optional[None]: Specify the backend to use + easy_method: Literal["train", "serve", "convert"]: easydel + Quantization Method to be applied for + bits: Optional[int]: Model bits for quantization + use_sharding_constraint: bool: whether to use sharding + constraint for the arrays + scan_ring_attention: bool: Whether to use can for ring + attention + scan_attention_layers: bool: Whether to use can for + attention layers + use_scan_mlp: bool: Determine whether to use scan_mlp or not + scan_mlp_chunk_size: int: Size of chunks in scan MLP. + attention_axis_name: str: Name of the attention axis name + quantize_kv_cache: bool: Whether to quantize Key/Value in + attention for generation process. + flash_attention_backward_pass_impl: Literal["triton", + "xla"]: Specify the backward pass kernel for flash + attention in generation process - :param generation_bias_partition_spec: : PartitionSpec: Specify the partitioning of the Attention Bias - partition spec in generation process - :param generation_query_partition_spec: : PartitionSpec: Specify the partitioning of the query tensor in generation process - :param shard_attention_computation: bool: whenever to use shard_map for attention - :param use_sharded_kv_caching: bool: whenever to use shard_map and sharding for key and value - :param backend: Optional[None]: Specify the backend to use - :param easy_method: Literal["train", "serve", "convert"]: easydel Quantization Method to be applied for - :param bits: Optional[int]: Model bits for quantization - :param use_sharding_constraint: bool: whether to use sharding constraint for the arrays - :param scan_ring_attention: bool: Whether to use can for ring attention - :param scan_attention_layers: bool: Whether to use can for attention layers - :param use_scan_mlp: bool: Determine whether to use scan_mlp or not - :param scan_mlp_chunk_size: int: Size of chunks in scan MLP. - :param attention_axis_name: str: Name of the attention axis name - :param quantize_kv_cache: bool: Whether to quantize Key/Value in attention for generation process. - :param flash_attention_backward_pass_impl: Literal["triton", "xla"]: Specify the backward pass kernel for flash attention """ set_attrs_smartly(self, "axis_dims", (1, -1, 1, 1), axis_dims) set_attrs_smartly(self, "axis_names", ("dp", "fsdp", "tp", "sp"), axis_names) @@ -417,14 +454,16 @@ def add_basic_configurations( def __repr__(self): - """ - The __repr__ function is used to generate a string representation of an object. + """The __repr__ function is used to generate a string representation of an object. This function should return a string that can be parsed by the Python interpreter to recreate the object. The __repr__ function is called when you use print() on an object, or when you type its name in the REPL. - :param self: Refer to the instance of the class - :return: A string representation of the object + Args: + self: Refer to the instance of the class + + Returns: + A string representation of the object """ string = f"{self.__class__.__name__}(\n" for k, v in self.__dict__.items(): @@ -442,12 +481,14 @@ def add_jax_args(self, **kwargs): def __str__(self): - """ - The __str__ function is called when you use the print function or when str() is used. + """The __str__ function is called when you use the print function or when str() is used. It should return a string representation of the object. - :param self: Refer to the instance of the class - :return: The object's string representation + Args: + self: Refer to the instance of the class + + Returns: + The object's string representation """ return self.__repr__() @@ -474,62 +515,72 @@ def __init__( ) def get_input_embeddings(self): - """ - The get_input_embeddings function returns the embedding layer of the model. + """The get_input_embeddings function returns the embedding layer of the model. + + Args: + self: Refer to the current object - :param self: Refer to the current object - :return: The embedding layer of the model + Returns: + The embedding layer of the model """ raise NotImplementedError() def set_input_embeddings(self, value): - """ - The set_input_embeddings function is used to set the embedding module of the model. + """The set_input_embeddings function is used to set the embedding module of the model. - :param self: Represent the instance of the class - :param value: Set the embeddings of the model + Args: + self: Represent the instance of the class + value: Set the embeddings of the model """ raise NotImplementedError() def get_output_embeddings(self): - """ - The get_output_embeddings function returns the output embeddings of a model. + """The get_output_embeddings function returns the output embeddings of a model. - :param self: Represent the instance of the class - :return: The output embeddings of the model + Args: + self: Represent the instance of the class + + Returns: + The output embeddings of the model """ raise NotImplementedError() def set_output_embeddings(self, new_embeddings): - """ - The set_output_embeddings function is used to set the output embeddings of a model. + """The set_output_embeddings function is used to set the output embeddings of a model. This function can be used to change the output embedding layer of a pretrained model in order to finetune it to some downstream task. Changing this layer has an effect only if the model has already been fine-tuned on some task (e.g., for classification). If you are training your own language models, you should call this function before you start training. - :param self: Represent the instance of the class - :param new_embeddings: Set the embeddings of the output layer - :return: A new embedding layer + Args: + self: Represent the instance of the class + new_embeddings: Set the embeddings of the output layer + + Returns: + A new embedding layer """ raise NotImplementedError() def set_decoder(self, decoder): - """ - The set_decoder function is used to set the decoder for a given encoder. + """The set_decoder function is used to set the decoder for a given encoder. + + Args: + self: Refer to the object itself + decoder: Set the decoder for a given encoder - :param self: Refer to the object itself - :param decoder: Set the decoder for a given encoder - :return: A decoder + Returns: + A decoder """ raise NotImplementedError() def get_decoder(self): - """ - The get_decoder function is used to create a decoder object. + """The get_decoder function is used to create a decoder object. + + Args: + self: Represent the instance of the class - :param self: Represent the instance of the class - :return: A decoder object + Returns: + A decoder object """ raise NotImplementedError() @@ -537,15 +588,18 @@ def init_cache(self, batch_size: int, max_length: int): raise NotImplementedError("init_cache is not Implemented Yet!") def prepare_inputs_for_generation(self, input_ids, max_length, attention_mask: Optional[chex.Array] = None): - """ - The prepare_inputs_for_generation function is used to prepare the inputs for a generation task. + """The prepare_inputs_for_generation function is used to prepare the inputs for a generation task. - :param self: Access variables that belong to the class - :param input_ids: Pass in the input tokens - :param max_length: Set the length of the sequence to be generated - :param attention_mask: Optional[chex.Array]: Mask the attention weights - :return: A dictionary of the past_key_values, attention_mask and position ids + Args: + self: Access variables that belong to the class + input_ids: Pass in the input tokens + max_length: Set the length of the sequence to be generated + attention_mask: Optional[chex.Array]: Mask the attention + weights + Returns: + A dictionary of the past_key_values, attention_mask and + position ids """ batch_size, seq_length = input_ids.shape @@ -592,14 +646,16 @@ def __call__( def __repr__(self): - """ - The __repr__ function is used to generate a string representation of an object. + """The __repr__ function is used to generate a string representation of an object. This function should return a string that can be parsed by the Python interpreter to recreate the object. The __repr__ function is called when you use print() on an object, or when you type its name in the REPL. - :param self: Refer to the instance of the class - :return: A string representation of the object + Args: + self: Refer to the instance of the class + + Returns: + A string representation of the object """ string = f"{self.__class__.__name__}(\n" for k, v in self.__dict__.items(): @@ -613,12 +669,14 @@ def __repr__(self): def __str__(self): - """ - The __str__ function is called when you use the print function or when str() is used. + """The __str__ function is called when you use the print function or when str() is used. It should return a string representation of the object. - :param self: Refer to the instance of the class - :return: The object's string representation + Args: + self: Refer to the instance of the class + + Returns: + The object's string representation """ return self.__repr__() diff --git a/src/python/easydel/modules/falcon/modelling_falcon_flax.py b/src/python/easydel/modules/falcon/modelling_falcon_flax.py index a66a97ec2..f4264668a 100644 --- a/src/python/easydel/modules/falcon/modelling_falcon_flax.py +++ b/src/python/easydel/modules/falcon/modelling_falcon_flax.py @@ -26,16 +26,20 @@ def built_bloom_alibi(attention_mask, num_attention_heads): - """ - The built_bloom_alibi function is used to create a bloom alibi for the attention mask. + """The built_bloom_alibi function is used to create a bloom alibi for the attention mask. The bloom alibi is used in the Bloom Attention layer to ensure that each token has a unique attention vector, even if it's masked out. This ensures that all tokens have an equal chance of being selected as the most important token in the sequence, which helps with training stability and performance. - :param attention_mask: Mask out the padding tokens in the input sequence - :param num_attention_heads: Determine the number of attention heads in the model - :return: A tensor of shape (batch_size, num_attention_heads, 1, sequence_length) - + Args: + attention_mask: Mask out the padding tokens in the input + sequence + num_attention_heads: Determine the number of attention heads in + the model + + Returns: + A tensor of shape (batch_size, num_attention_heads, 1, + sequence_length) """ batch_size, sequence_length = attention_mask.shape cp2 = 2 ** math.floor(math.log2(num_attention_heads)) @@ -57,18 +61,20 @@ def built_bloom_alibi(attention_mask, num_attention_heads): def precompute_falcon_freq_cis(max_position_embedding: int, head_dim: int, theta: float = 10000): - """ - The precompute_falcon_freq_cis function is used to precompute the sinusoidal frequencies for the FALCON model. + """The precompute_falcon_freq_cis function is used to precompute the sinusoidal frequencies for the FALCON model. The function takes in three arguments: max_position_embedding, head_dim, and theta. The first two are self-explanatory; the third is a hyperparameter that controls how quickly the frequency increases with position (i.e., how many times higher it will be at position i than at position 0). The default value of 10000 was chosen because it worked well on the tasks we tested. - :param max_position_embedding: int: Set the maximum length of the sequence - :param head_dim: int: Determine the size of the positional embedding - :param theta: float: Adjust the frequency of the sinusoid - :return: A tuple of two arrays - + Args: + max_position_embedding: int: Set the maximum length of the + sequence + head_dim: int: Determine the size of the positional embedding + theta: float: Adjust the frequency of the sinusoid + + Returns: + A tuple of two arrays """ inv_freq_cis = 1.0 / (theta ** (jnp.arange(0, head_dim, 2, dtype=jnp.float32) / head_dim)) freq = jnp.einsum("i , j -> i j", jnp.arange(max_position_embedding), inv_freq_cis).astype("float32") @@ -78,46 +84,53 @@ def precompute_falcon_freq_cis(max_position_embedding: int, head_dim: int, theta def _rotate_half(x): - """ - The _rotate_half function takes a 1D array and rotates it by half its length. + """The _rotate_half function takes a 1D array and rotates it by half its length. For example, if the input is [0, 1, 2, 3], then the output will be [-2,-3,-0,-4]. This function is used to rotate the Fourier transform of an image so that its zero-frequency component is in the center of the spectrum. - :param x: Specify the input array - :return: The negative of the second half of x concatenated with the first half - + Args: + x: Specify the input array + + Returns: + The negative of the second half of x concatenated with the first + half """ return jnp.concatenate((-x[..., x.shape[-1] // 2:], x[..., : x.shape[-1] // 2]), axis=-1) def apply_rotary_pos_embedding(tensor, sin_, cos_): - """ - The apply_rotary_pos_embedding function applies a rotary positional embedding to the input tensor. + """The apply_rotary_pos_embedding function applies a rotary positional embedding to the input tensor. + + Args: + tensor: Pass in the tensor that we want to apply the positional + embedding to + sin_: Rotate the tensor by half of its length + cos_: Multiply the tensor and cosine of the angle - :param tensor: Pass in the tensor that we want to apply the positional embedding to - :param sin_: Rotate the tensor by half of its length - :param cos_: Multiply the tensor and cosine of the angle - :return: A tensor with the same shape as its input, - + Returns: + A tensor with the same shape as its input, """ return (tensor * cos_) + (_rotate_half(tensor) * sin_) def dropout_add(linen_drop: flax.linen.Dropout, x: chex.Array, residual: chex.Array, deterministic: bool) -> chex.Array: - """ - The dropout_add function is a helper function that adds the residual to the output of + """The dropout_add function is a helper function that adds the residual to the output of the dropout layer. This is necessary because we want to use deterministic=True when we are evaluating our model, but we still need to add in the residual. The reason for this is that during training, we have two paths through our network: one with dropout and one without. The path without dropout (residual) allows us to backpropagate gradients through both paths at once. - :param linen_drop: flax.linen.Dropout: Specify the dropout layer - :param x: chex.Array: Pass in the input to the dropout layer - :param residual: chex.Array: Add the residual to the output of dropout_add - :param deterministic: bool: Determine whether the dropout layer is active or not - :return: A tensor that is the sum of the residual and a dropout layer - + Args: + linen_drop: flax.linen.Dropout: Specify the dropout layer + x: chex.Array: Pass in the input to the dropout layer + residual: chex.Array: Add the residual to the output of + dropout_add + deterministic: bool: Determine whether the dropout layer is + active or not + + Returns: + A tensor that is the sum of the residual and a dropout layer """ out = linen_drop(inputs=x, deterministic=deterministic) out = residual + out @@ -718,15 +731,17 @@ def __init__(self, config, super().__init__(_do_init=_do_init, module=module, config=config, dtype=dtype, input_shape=input_shape) def init_weights(self, rng: jax.random.PRNGKey, input_shape: Tuple, params: FrozenDict = None) -> FrozenDict: - """ - The init_weights function is used to initialize the weights of a model. + """The init_weights function is used to initialize the weights of a model. - :param self: Access variables that belong to the class - :param rng: jax.random.PRNGKey: Initialize the weights of the model - :param input_shape: Tuple: Specify the shape of the input tensor - :param params: FrozenDict: Pass in the parameters of a pre-trained model - :return: A frozendict of parameters + Args: + self: Access variables that belong to the class + rng: jax.random.PRNGKey: Initialize the weights of the model + input_shape: Tuple: Specify the shape of the input tensor + params: FrozenDict: Pass in the parameters of a pre-trained + model + Returns: + A frozendict of parameters """ input_ids = jnp.zeros(input_shape, dtype="i4") attention_mask = jnp.ones_like(input_ids) diff --git a/src/python/easydel/modules/flax_modelling_utils.py b/src/python/easydel/modules/flax_modelling_utils.py index e42befdaf..cdaadbf53 100644 --- a/src/python/easydel/modules/flax_modelling_utils.py +++ b/src/python/easydel/modules/flax_modelling_utils.py @@ -68,17 +68,18 @@ def canonicalize_dtype( def get_names_from_partition_spec(partition_specs): - """ - The get_names_from_partition_spec function takes a partition_specs argument, which is either a dictionary or list. + """The get_names_from_partition_spec function takes a partition_specs argument, which is either a dictionary or list. If it's a dictionary, the function converts it to a list of values. Then for each item in the partition_specs list: If the item is None, continue (do nothing) and move on to next iteration of loop. If the item is an instance of str (i.e., if it's just one string), add that string to names set and move on to next iteration of loop. Otherwise, (if not None or str), call get_names_from_partition_spec recurs - :param partition_specs: Define the partitioning of a table - :return: A list of the names of all partitions + Args: + partition_specs: Define the partitioning of a table + Returns: + A list of the names of all partitions """ names = set() if isinstance(partition_specs, dict): @@ -95,15 +96,17 @@ def get_names_from_partition_spec(partition_specs): def names_in_mesh(*names): - """ - The names_in_mesh function is a decorator that can be used to check whether + """The names_in_mesh function is a decorator that can be used to check whether the names of the axes passed into a function are valid. It will raise an exception if any of the axis names are not in the physical mesh. For example, if you have a function that takes two axes as arguments, and you want to make sure they're both in your mesh: - :param names: Collect all the names passed to the function into a tuple - :return: A boolean indicating whether all the given + Args: + *names: Collect all the names passed to the function into a + tuple + Returns: + A boolean indicating whether all the given """ return set(names) <= set(pxla.thread_resources.env.physical_mesh.axis_names) @@ -136,17 +139,19 @@ def get_gradient_checkpoint_policy(name): def repeat_kv_bnsh(x: chex.Array, n_rep: int) -> chex.Array: - """ - The repeat_kv_bnsh function is used to repeat the key and value vectors for each head in a multi-head attention + """The repeat_kv_bnsh function is used to repeat the key and value vectors for each head in a multi-head attention module. This function takes as input an array of shape (batch_size, n_heads, sequence_length, head_dim) and returns an array of shape (batch_size, n_heads * nrep, sequence length, head dim). The reason this is necessary is because the attention module expects keys/values/queries to be repeated across heads but not across batches. However we want our keys/values/queries to be repeated both across heads AND batches so that we can use them - :param x: chex.Array: Pass in the input to the function - :param n_rep: int: Repeat the key and value heads - :return: A new array with the same shape as x, except for the second dimension which is n_kv_heads * n_rep + Args: + x: chex.Array: Pass in the input to the function + n_rep: int: Repeat the key and value heads + Returns: + A new array with the same shape as x, except for the second + dimension which is n_kv_heads * n_rep """ bs, n_kv_heads, s, head_dim = x.shape if n_rep == 1: @@ -158,13 +163,15 @@ def repeat_kv_bnsh(x: chex.Array, n_rep: int) -> chex.Array: def repeat_kv_bsnh(x: chex.Array, n_rep: int) -> chex.Array: - """ - The repeat_kv_bsnh function is used to repeat the key and value vectors for each head. + """The repeat_kv_bsnh function is used to repeat the key and value vectors for each head. - :param x: chex.Array: Specify the input array - :param n_rep: int: Repeat the key-value attention heads n_rep times - :return: A new array with the same batch size, sequence length, and head dimension as the input array + Args: + x: chex.Array: Specify the input array + n_rep: int: Repeat the key-value attention heads n_rep times + Returns: + A new array with the same batch size, sequence length, and head + dimension as the input array """ bs, s, n_kv_heads, head_dim = x.shape x = x.transpose(0, 2, 1, 3) @@ -283,15 +290,15 @@ def _calc_su_scaling_factor(scale): def rotate_half(x): - """ - The rotate_half function takes a complex-valued array and rotates the + """The rotate_half function takes a complex-valued array and rotates the phase of its second half by 180 degrees. This is equivalent to multiplying the second half by -i, or equivalently rotating it 90 degrees counterclockwise. + Args: + x: Specify the input array - :param x: Specify the input array - :return: A new array that is the same as the input - + Returns: + A new array that is the same as the input """ x1 = x[..., : x.shape[-1] // 2] x2 = x[..., x.shape[-1] // 2:] @@ -299,31 +306,33 @@ def rotate_half(x): def apply_rotary_pos_emb(tensor, sin_, cos_): - """ - The apply_rotary_pos_emb function applies a rotary positional embedding to the input tensor. + """The apply_rotary_pos_emb function applies a rotary positional embedding to the input tensor. b,h,s,d or pytorch style - :param tensor: Store the tensor that is passed into the function - :param sin_: Rotate the tensor by pi/2 - :param cos_: Apply the cosine function to the tensor - :return: A tensor with the same shape as the input tensor + Args: + tensor: Store the tensor that is passed into the function + sin_: Rotate the tensor by pi/2 + cos_: Apply the cosine function to the tensor + Returns: + A tensor with the same shape as the input tensor """ b, h, s, d = tensor.shape return (tensor * cos_[:, :, :s, :]) + (rotate_half(tensor) * sin_[:, :, :s, :]) def get_ranks_and_size(mesh): - """ - The get_ranks_and_size function is used to determine the number of MPI processes + """The get_ranks_and_size function is used to determine the number of MPI processes (``mp_node_size``) and the number of devices per process (``dp_node_size``). The ``mesh.shape[mp]`` determines how many MPI processes are needed, and then we divide that by the local device count to get ``mp_node_size = max( 1, mp / jax.local )`. This means that if there are more than enough devices for all MPI ranks on a node, each rank will only use one device; otherwise it will use - :param mesh: Get the shape of the mesh - :return: A dictionary with the following keys: + Args: + mesh: Get the shape of the mesh + Returns: + A dictionary with the following keys: """ out = dict(mesh=mesh) total_process_size = mesh.shape["tp"] * mesh.shape["sp"] @@ -342,14 +351,15 @@ def get_ranks_and_size(mesh): def create_mesh( axis_dims: Sequence[int] = (1, -1, 1, 1), axis_names: Sequence[str] = ("dp", "fsdp", "tp", "sp"), backend="" ): - """ - The create_mesh function creates a mesh object that can be used to shard arrays. + """The create_mesh function creates a mesh object that can be used to shard arrays. - :param axis_dims: Sequence[int]: Specify the dimensions of the mesh - :param axis_names: Sequence[str]: Name the axes of the mesh - :param backend: Specify the backend to use - :return: A mesh object + Args: + axis_dims: Sequence[int]: Specify the dimensions of the mesh + axis_names: Sequence[str]: Name the axes of the mesh + backend: Specify the backend to use + Returns: + A mesh object """ array_devices = jax.numpy.ones( (len(jax.devices() if backend == "" else jax.devices(backend)), 1)) @@ -361,15 +371,16 @@ def create_mesh( def add_start_docstrings(*docstr): - """ - The add_start_docstrings function is a decorator that adds the docstrings to the beginning of a function. + """The add_start_docstrings function is a decorator that adds the docstrings to the beginning of a function. The add_start_docstrings function takes in an arbitrary number of strings and returns a decorator. The returned decorator takes in one argument, fn, which is assumed to be a function. The docstring for fn is set equal to the concatenation of all the strings passed into add_start_docstrings plus (if it exists) the original docstring for fn. - :param docstr: Pass in a variable number of arguments to the function - :return: A decorator that adds the docstrings to the function + Args: + *docstr: Pass in a variable number of arguments to the function + Returns: + A decorator that adds the docstrings to the function """ def docstring_decorator(fn): @@ -384,14 +395,17 @@ def get_dot_general_by_bits( bits: Optional[int] = None, mode: Literal["train", "serve", "convert"] = EasyMethod.TRAIN ) -> dict: - """ - The get_general_dot function is a helper function that returns a q_flax.QDotGeneral object + """The get_general_dot function is a helper function that returns a q_flax.QDotGeneral object with the specified number of bits for forward and backward passes. If no bits are specified, the function returns None. - :param bits: Optional[int]: Specify the number of bits for quantization - :param mode: EasyMethod: Specify the use of model to init the QDot Method for (e.q TRAIN,SERVE,...) - :return: A dict that contain dot_general_cls + Args: + bits: Optional[int]: Specify the number of bits for quantization + mode: EasyMethod: Specify the use of model to init the QDot + Method for (e.q TRAIN,SERVE,...) + + Returns: + A dict that contain dot_general_cls """ if mode == EasyMethod.TRAIN: rhs_quant_mode = q_flax.QuantMode.TRAIN @@ -420,19 +434,22 @@ class BaseJAXAttentionModule(nn.Module): @nn.compact def _concatenate_to_cache(self, key, value, query_states, attention_mask): - """ - The _concatenate_to_cache function is used to concatenate the key and value vectors + """The _concatenate_to_cache function is used to concatenate the key and value vectors of a query_states with those of previous queries. This allows for the attention mechanism to look at all previous queries when computing its output. The function takes in three arguments: key, value, and query_states. It also uses two variables that are stored in the cache: cached_key and cached_value. - :param self: Access the variables stored in the cache - :param key: Store the keys of the encoder-decoder attention - :param value: Initialize the cached_value variable - :param query_states: Determine the number of cache vectors to update - :param attention_mask: Mask out the padded vectors in the cache - :return: The key, value and attention_mask + Args: + self: Access the variables stored in the cache + key: Store the keys of the encoder-decoder attention + value: Initialize the cached_value variable + query_states: Determine the number of cache vectors to + update + attention_mask: Mask out the padded vectors in the cache + + Returns: + The key, value and attention_mask """ quantize_kv_cache = self.config.quantize_kv_cache is_initialized = self.has_variable("cache", "cached_key") diff --git a/src/python/easydel/modules/gemma/gemma_configuration.py b/src/python/easydel/modules/gemma/gemma_configuration.py index 66bb66147..d6b9c8a59 100644 --- a/src/python/easydel/modules/gemma/gemma_configuration.py +++ b/src/python/easydel/modules/gemma/gemma_configuration.py @@ -35,8 +35,7 @@ def __init__( hidden_activation=None, **kwargs, ): - """ - The __init__ function is called when the class is instantiated. + """The __init__ function is called when the class is instantiated. It sets up the attributes of an object, which are sometimes called fields or properties. The __init__ function can accept arguments, but self must be the first one. """ @@ -70,15 +69,17 @@ def __init__( ) def get_partition_rules(self, fully_sharded_data_parallel: bool = True): - """ - The get_partition_rules function is used to define the partitioning scheme for a model. + """The get_partition_rules function is used to define the partitioning scheme for a model. It returns a list of tuples, where each tuple contains two elements: 1) A regex string that matches the name of one or more parameters in the model. 2) A PartitionScheme object that defines how those parameters should be partitioned across devices. - :param fully_sharded_data_parallel: bool: Determine whether to partition the model fully or not - :return: A list of tuples + Args: + fully_sharded_data_parallel: bool: Determine whether to + partition the model fully or not + Returns: + A list of tuples """ return ( @@ -122,12 +123,14 @@ def add_jax_args( bits: Optional[int] = None, **kwargs, ): - """ - The add_jax_args function adds the following arguments to the Transformer class: - - :param self: Refer to the current object - :param gradient_checkpointing: str: Control the amount of memory used by jax - :param bits: Optional[int]: Determine the number of bits used in the quantization + """The add_jax_args function adds the following arguments to the Transformer class: + + Args: + self: Refer to the current object + gradient_checkpointing: str: Control the amount of memory + used by jax + bits: Optional[int]: Determine the number of bits used in + the quantization """ self.gradient_checkpointing = gradient_checkpointing self.bits = bits diff --git a/src/python/easydel/modules/gemma/modelling_gemma_flax.py b/src/python/easydel/modules/gemma/modelling_gemma_flax.py index e7840639f..b256d688d 100644 --- a/src/python/easydel/modules/gemma/modelling_gemma_flax.py +++ b/src/python/easydel/modules/gemma/modelling_gemma_flax.py @@ -214,33 +214,37 @@ def _split_heads(self, hidden_states, num_heads): @staticmethod def _transpose_sequence_head(query, key, value): - """ - The _transpose_sequence_head function transposes the query, key and value matrices. + """The _transpose_sequence_head function transposes the query, key and value matrices. - :param query: Get the attention weights for each of the heads - :param key: Determine the number of heads - :param value: Store the values of the input - :return: The transpose of the query, key and value matrices + Args: + query: Get the attention weights for each of the heads + key: Determine the number of heads + value: Store the values of the input + Returns: + The transpose of the query, key and value matrices """ return jnp.transpose(query, (0, 2, 1, 3)), jnp.transpose(key, (0, 2, 1, 3)), jnp.transpose(value, (0, 2, 1, 3)) def apply_rotary(self, batch_size, sequence_length, query, key, value, freq_cis, position_ids): - """ - The apply_rotary function is a modified version of the apply_attention function in the BertModel class. + """The apply_rotary function is a modified version of the apply_attention function in the BertModel class. The main difference is that it takes in an additional argument, freq_cis, which are used to calculate the rotary attention weights. The other differences are minor and mostly related to reshaping tensors. - :param self: Access variables that belong to the class - :param batch_size: Reshape the query, key and value tensors - :param sequence_length: Reshape the query, key and value tensors - :param query: Calculate the attention weights - :param key: Calculate the attention - :param value: Compute the attention weights - :param freq_cis: Calculate the frequency of each word in the vocabulary - :param position_ids: Identify the position of each token in the sequence - :return: A tuple of 3 tensors: query, key and value - + Args: + self: Access variables that belong to the class + batch_size: Reshape the query, key and value tensors + sequence_length: Reshape the query, key and value tensors + query: Calculate the attention weights + key: Calculate the attention + value: Compute the attention weights + freq_cis: Calculate the frequency of each word in the + vocabulary + position_ids: Identify the position of each token in the + sequence + + Returns: + A tuple of 3 tensors: query, key and value """ query = query.reshape( batch_size, @@ -513,8 +517,7 @@ def __call__( class FlaxGemmaPreTrainedModel(EasyDeLFlaxPretrainedModel): - """ - An abstract class to handle weights initialization and a simple interface for downloading and loading pretrained + """An abstract class to handle weights initialization and a simple interface for downloading and loading pretrained models. """ @@ -550,17 +553,18 @@ def __init__( ) def init_cache(self, batch_size, max_length): - """ - The init_cache function is used to initialize the cache for a given batch size and sequence length. + """The init_cache function is used to initialize the cache for a given batch size and sequence length. The cache is a dictionary that contains all the intermediate states from each layer in the model. This allows us to run inference on multiple batches without having to re-run forward passes through every layer in the model, which would be very slow. - :param self: Access the module - :param batch_size: Define the batch size of the input tensors - :param max_length: Set the length of the input sequence - :return: A dictionary with the following keys: + Args: + self: Access the module + batch_size: Define the batch size of the input tensors + max_length: Set the length of the input sequence + Returns: + A dictionary with the following keys: """ input_ids = jnp.ones((batch_size, max_length)) attention_mask = jnp.ones_like(input_ids) diff --git a/src/python/easydel/modules/gpt_j/modelling_gpt_j_flax.py b/src/python/easydel/modules/gpt_j/modelling_gpt_j_flax.py index 5b4940292..9072271b0 100644 --- a/src/python/easydel/modules/gpt_j/modelling_gpt_j_flax.py +++ b/src/python/easydel/modules/gpt_j/modelling_gpt_j_flax.py @@ -19,7 +19,7 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -""" GPT-J model configuration""" +"""GPT-J model configuration""" import math from functools import partial from typing import Optional, Tuple, Union diff --git a/src/python/easydel/modules/grok_1/__init__.py b/src/python/easydel/modules/grok_1/__init__.py index 9ce380e0f..53c38a848 100644 --- a/src/python/easydel/modules/grok_1/__init__.py +++ b/src/python/easydel/modules/grok_1/__init__.py @@ -1,5 +1,5 @@ -from .grok_1_configuration import Grok1Config as Grok1Config -from .modelling_grok_1_flax import ( - FlaxGrok1ForCausalLM as FlaxGrok1ForCausalLM, - FlaxGrok1Model as FlaxGrok1Model -) +from .grok_1_configuration import Grok1Config as Grok1Config +from .modelling_grok_1_flax import ( + FlaxGrok1ForCausalLM as FlaxGrok1ForCausalLM, + FlaxGrok1Model as FlaxGrok1Model +) diff --git a/src/python/easydel/modules/grok_1/grok_1_configuration.py b/src/python/easydel/modules/grok_1/grok_1_configuration.py index f018a9220..5a2f871a6 100644 --- a/src/python/easydel/modules/grok_1/grok_1_configuration.py +++ b/src/python/easydel/modules/grok_1/grok_1_configuration.py @@ -1,147 +1,152 @@ -from ..easydel_modelling_utils import EasyDeLPretrainedConfig -from typing import Union, Optional -from jax.sharding import PartitionSpec - - -class Grok1Config(EasyDeLPretrainedConfig): - model_type: str = "grok-1" - - def __init__( - self, - vocab_size=32000, - hidden_size=4096, - intermediate_size=32768, - num_hidden_layers=32, - num_attention_heads=32, - num_key_value_heads=32, - attn_output_multiplier=1.0, - max_attn_value=1.0, - max_position_embeddings=4096, - embedding_multiplier_scale: float = 1.0, - output_multiplier_scale: float = 1.0, - rms_norm_eps=1e-5, - use_cache=True, - pad_token_id=None, - bos_token_id=1, - eos_token_id=2, - tie_word_embeddings=True, - num_experts_per_tok=2, - num_experts=8, - output_router_logits=False, - router_aux_loss_coef=0.001, - gradient_checkpointing: str = "nothing_saveable", - bits: Optional[int] = None, - **kwargs - ): - self.vocab_size = vocab_size - self.attn_output_multiplier = attn_output_multiplier - self.max_attn_value = max_attn_value - self.max_position_embeddings = max_position_embeddings - self.embedding_multiplier_scale = embedding_multiplier_scale - self.output_multiplier_scale = output_multiplier_scale - self.hidden_size = hidden_size - self.intermediate_size = intermediate_size - self.num_hidden_layers = num_hidden_layers - self.num_attention_heads = num_attention_heads - - # for backward compatibility - if num_key_value_heads is None: - num_key_value_heads = num_attention_heads - - self.num_key_value_heads = num_key_value_heads - self.rms_norm_eps = rms_norm_eps - self.use_cache = use_cache - - self.num_experts_per_tok = num_experts_per_tok - self.num_experts = num_experts - self.output_router_logits = output_router_logits - self.router_aux_loss_coef = router_aux_loss_coef - self.gradient_checkpointing = gradient_checkpointing - self.bits = bits - super().__init__( - pad_token_id=pad_token_id, - bos_token_id=bos_token_id, - eos_token_id=eos_token_id, - tie_word_embeddings=tie_word_embeddings, - **kwargs, - ) - - def get_partition_rules(self, fully_sharded_data_parallel: bool = True): - """ - The get_partition_rules function is used to define the partitioning scheme for a model. - It returns a list of tuples, where each tuple contains two elements: - 1) A regex string that matches the name of one or more parameters in the model. - 2) A PartitionScheme object that defines how those parameters should be partitioned across devices. - - :param fully_sharded_data_parallel: bool: Determine whether to partition the model fully or not - :return: A list of tuples - - """ - return ( - - ("model/embed_tokens/embedding", PartitionSpec("tp", ("fsdp", "sp"))), - - ("attn/(q_proj|k_proj|v_proj)/kernel", PartitionSpec(("fsdp", "sp"), "tp")), - ("attn/o_proj/kernel", PartitionSpec("tp", ("fsdp", "sp"))), - - ("linear/kernel", PartitionSpec(("fsdp", "sp"), "tp")), - ("linear_1/kernel", PartitionSpec("tp", ("fsdp", "sp"))), - ("linear_v/kernel", PartitionSpec(("fsdp", "sp"), "tp")), - ("gate/kernel", PartitionSpec(("fsdp", "sp"))), - - ("post_attn_norm/kernel", PartitionSpec(None)), - ("pre_attn_norm/kernel", PartitionSpec(None)), - ("pre_moe_norm/kernel", PartitionSpec(None)), - ("post_moe_norm/kernel", PartitionSpec(None)), - - ("model/norm/kernel", PartitionSpec(None)), - ("lm_head/kernel", PartitionSpec(("fsdp", "sp"), "tp")), - (".*", PartitionSpec(None)), - ) if not fully_sharded_data_parallel else ( - - ("model/embed_tokens/embedding", PartitionSpec(("fsdp", "sp"))), - - ("attn/(q_proj|k_proj|v_proj)/kernel", PartitionSpec(("fsdp", "sp"))), - ("attn/o_proj/kernel", PartitionSpec(("fsdp", "sp"))), - - ("linear/kernel", PartitionSpec(("fsdp", "sp"))), - ("linear_1/kernel", PartitionSpec(("fsdp", "sp"))), - ("linear_v/kernel", PartitionSpec(("fsdp", "sp"))), - ("gate/kernel", PartitionSpec(("fsdp", "sp"))), - - ("post_attn_norm/kernel", PartitionSpec(("fsdp", "sp"))), - ("pre_attn_norm/kernel", PartitionSpec(("fsdp", "sp"))), - ("pre_moe_norm/kernel", PartitionSpec(("fsdp", "sp"))), - ("post_moe_norm/kernel", PartitionSpec(("fsdp", "sp"))), - - ("model/norm/kernel", PartitionSpec(("fsdp", "sp"))), - ("lm_head/kernel", PartitionSpec(("fsdp", "sp"))), - (".*", PartitionSpec(("fsdp", "sp"))), - ) - - def add_jax_args( - self, - tie_word_embeddings: bool = False, - gradient_checkpointing: str = "nothing_saveable", - bits: Optional[int] = None, - **kwargs, - ): - """ - The add_jax_args function adds the following arguments to the Transformer class: - - :param self: Refer to the current object - :param tie_word_embeddings: bool: Tie the word embeddings to the decoder - :param gradient_checkpointing: str: Control the amount of memory used by jax - :param bits: Optional[int]: Determine the number of bits used in the quantization - """ - self.tie_word_embeddings = tie_word_embeddings - self.gradient_checkpointing = gradient_checkpointing - self.bits = bits - - @staticmethod - def get_weight_decay_exclusions(): - return tuple() - - @staticmethod - def rng_keys(): - return 'params', 'dropout' +from ..easydel_modelling_utils import EasyDeLPretrainedConfig +from typing import Union, Optional +from jax.sharding import PartitionSpec + + +class Grok1Config(EasyDeLPretrainedConfig): + model_type: str = "grok-1" + + def __init__( + self, + vocab_size=32000, + hidden_size=4096, + intermediate_size=32768, + num_hidden_layers=32, + num_attention_heads=32, + num_key_value_heads=32, + attn_output_multiplier=1.0, + max_attn_value=1.0, + max_position_embeddings=4096, + embedding_multiplier_scale: float = 1.0, + output_multiplier_scale: float = 1.0, + rms_norm_eps=1e-5, + use_cache=True, + pad_token_id=None, + bos_token_id=1, + eos_token_id=2, + tie_word_embeddings=True, + num_experts_per_tok=2, + num_experts=8, + output_router_logits=False, + router_aux_loss_coef=0.001, + gradient_checkpointing: str = "nothing_saveable", + bits: Optional[int] = None, + **kwargs + ): + self.vocab_size = vocab_size + self.attn_output_multiplier = attn_output_multiplier + self.max_attn_value = max_attn_value + self.max_position_embeddings = max_position_embeddings + self.embedding_multiplier_scale = embedding_multiplier_scale + self.output_multiplier_scale = output_multiplier_scale + self.hidden_size = hidden_size + self.intermediate_size = intermediate_size + self.num_hidden_layers = num_hidden_layers + self.num_attention_heads = num_attention_heads + + # for backward compatibility + if num_key_value_heads is None: + num_key_value_heads = num_attention_heads + + self.num_key_value_heads = num_key_value_heads + self.rms_norm_eps = rms_norm_eps + self.use_cache = use_cache + + self.num_experts_per_tok = num_experts_per_tok + self.num_experts = num_experts + self.output_router_logits = output_router_logits + self.router_aux_loss_coef = router_aux_loss_coef + self.gradient_checkpointing = gradient_checkpointing + self.bits = bits + super().__init__( + pad_token_id=pad_token_id, + bos_token_id=bos_token_id, + eos_token_id=eos_token_id, + tie_word_embeddings=tie_word_embeddings, + **kwargs, + ) + + def get_partition_rules(self, fully_sharded_data_parallel: bool = True): + """The get_partition_rules function is used to define the partitioning scheme for a model. + It returns a list of tuples, where each tuple contains two elements: + 1) A regex string that matches the name of one or more parameters in the model. + 2) A PartitionScheme object that defines how those parameters should be partitioned across devices. + + Args: + fully_sharded_data_parallel: bool: Determine whether to + partition the model fully or not + + Returns: + A list of tuples + """ + return ( + + ("model/embed_tokens/embedding", PartitionSpec("tp", ("fsdp", "sp"))), + + ("attn/(q_proj|k_proj|v_proj)/kernel", PartitionSpec(("fsdp", "sp"), "tp")), + ("attn/o_proj/kernel", PartitionSpec("tp", ("fsdp", "sp"))), + + ("linear/kernel", PartitionSpec(("fsdp", "sp"), "tp")), + ("linear_1/kernel", PartitionSpec("tp", ("fsdp", "sp"))), + ("linear_v/kernel", PartitionSpec(("fsdp", "sp"), "tp")), + ("gate/kernel", PartitionSpec(("fsdp", "sp"))), + + ("post_attn_norm/kernel", PartitionSpec(None)), + ("pre_attn_norm/kernel", PartitionSpec(None)), + ("pre_moe_norm/kernel", PartitionSpec(None)), + ("post_moe_norm/kernel", PartitionSpec(None)), + + ("model/norm/kernel", PartitionSpec(None)), + ("lm_head/kernel", PartitionSpec(("fsdp", "sp"), "tp")), + (".*", PartitionSpec(None)), + ) if not fully_sharded_data_parallel else ( + + ("model/embed_tokens/embedding", PartitionSpec(("fsdp", "sp"))), + + ("attn/(q_proj|k_proj|v_proj)/kernel", PartitionSpec(("fsdp", "sp"))), + ("attn/o_proj/kernel", PartitionSpec(("fsdp", "sp"))), + + ("linear/kernel", PartitionSpec(("fsdp", "sp"))), + ("linear_1/kernel", PartitionSpec(("fsdp", "sp"))), + ("linear_v/kernel", PartitionSpec(("fsdp", "sp"))), + ("gate/kernel", PartitionSpec(("fsdp", "sp"))), + + ("post_attn_norm/kernel", PartitionSpec(("fsdp", "sp"))), + ("pre_attn_norm/kernel", PartitionSpec(("fsdp", "sp"))), + ("pre_moe_norm/kernel", PartitionSpec(("fsdp", "sp"))), + ("post_moe_norm/kernel", PartitionSpec(("fsdp", "sp"))), + + ("model/norm/kernel", PartitionSpec(("fsdp", "sp"))), + ("lm_head/kernel", PartitionSpec(("fsdp", "sp"))), + (".*", PartitionSpec(("fsdp", "sp"))), + ) + + def add_jax_args( + self, + tie_word_embeddings: bool = False, + gradient_checkpointing: str = "nothing_saveable", + bits: Optional[int] = None, + **kwargs, + ): + """The add_jax_args function adds the following arguments to the Transformer class: + + Args: + self: Refer to the current object + tie_word_embeddings: bool: Tie the word embeddings to the + decoder + gradient_checkpointing: str: Control the amount of memory + used by jax + bits: Optional[int]: Determine the number of bits used in + the quantization + """ + self.tie_word_embeddings = tie_word_embeddings + self.gradient_checkpointing = gradient_checkpointing + self.bits = bits + + @staticmethod + def get_weight_decay_exclusions(): + return tuple() + + @staticmethod + def rng_keys(): + return 'params', 'dropout' diff --git a/src/python/easydel/modules/grok_1/modelling_grok_1_flax.py b/src/python/easydel/modules/grok_1/modelling_grok_1_flax.py index 8eb4971ef..5e23c8ee1 100644 --- a/src/python/easydel/modules/grok_1/modelling_grok_1_flax.py +++ b/src/python/easydel/modules/grok_1/modelling_grok_1_flax.py @@ -1,1252 +1,1291 @@ -import math -from typing import Optional, Tuple, Union - -import chex -from fjformer import linen as nn -import flax.linen.partitioning -import flax.struct -import jax -import jax.numpy as jnp -from fjformer.func import auxiliary_load_balancing_loss_func -from flax.core.frozen_dict import FrozenDict, freeze, unfreeze -from flax.linen import combine_masks -from flax.traverse_util import flatten_dict, unflatten_dict -from jax import lax -from jax.sharding import PartitionSpec -from transformers.modeling_flax_outputs import FlaxMaskedLMOutput - -from .grok_1_configuration import Grok1Config -from ..attention_module import AttentionModule -from ..easydel_modelling_utils import EasyDeLFlaxPretrainedModel -# easydel.modules -from ..flax_modelling_utils import ( - with_sharding_constraint, - get_gradient_checkpoint_policy, - repeat_kv_bnsh, - apply_rotary_pos_emb, - precompute_freq_cis, - get_dot_general_by_bits, - BaseJAXAttentionModule, - block_wise_ffn -) -from fjformer.linen import Linear - -re_mat = flax.linen.partitioning.remat - - -@flax.struct.dataclass -class MoeModelOutput: - last_hidden_state: chex.Array = None - hidden_states: Optional[Tuple[chex.Array]] = None - attentions: Optional[Tuple[chex.Array]] = None - router_logits: Optional[Tuple[chex.Array]] = None - - -@flax.struct.dataclass -class MoeCausalLMOutput(FlaxMaskedLMOutput): - aux_loss: Optional[chex.Array] = None - router_logits: Optional[Tuple[chex.Array]] = None - - -class FlaxGrok1Embedding(nn.Module): - dtype: jnp.dtype = jnp.float32 - - def __call__(self, query, key, freq_cis, position_ids): - sin, cos = freq_cis - - sin = sin[position_ids][:, None, :, :] - cos = cos[position_ids][:, None, :, :] - - key = apply_rotary_pos_emb(key, sin, cos) - query = apply_rotary_pos_emb(query, sin, cos) - - return query.astype(self.dtype), key.astype(self.dtype) - - -def repeat_kv(x: chex.Array, n_rep: int) -> chex.Array: - bs, s, n_kv_heads, head_dim = x.shape - if n_rep == 1: - return x - x = x[:, :, jnp.newaxis, :, :] - x = jnp.repeat(x, n_rep, axis=2) - - return x.reshape(bs, s, - n_kv_heads * n_rep, - head_dim) - - -class FlaxGrok1RMSNorm(nn.Module): - dim: int - eps: float = 1e-6 - dtype: jnp.dtype = jnp.float32 - param_dtype: jnp.dtype = jnp.float32 - - def setup(self) -> None: - self.weight = self.param( - 'kernel', - nn.initializers.ones, - (self.dim,), - self.param_dtype, - ) - - def _norm(self, x: jnp.ndarray) -> jnp.ndarray: - return x * jax.lax.rsqrt(jnp.square(x).mean(-1, keepdims=True) + self.eps) - - def __call__(self, x: jnp.ndarray) -> jnp.ndarray: - x = x.astype(jnp.promote_types(self.dtype, jnp.float32)) - output = self._norm(x).astype(self.dtype) - weight = nn.linen.control_quantization(self.weight, self.dtype) - return output * weight - - -class FlaxGrok1Attention(BaseJAXAttentionModule): - config: Grok1Config - dtype: jnp.dtype = jnp.float32 - param_dtype: jnp.dtype = jnp.float32 - precision: Optional[Union[jax.lax.Precision, str]] = None - - def setup(self): - config = self.config - self.hidden_size = config.hidden_size - self.head_dim = self.config.hidden_size // self.config.num_attention_heads - self.num_key_value_groups = self.config.num_attention_heads // self.config.num_key_value_heads - - if self.num_key_value_groups == 1: - assert self.config.num_attention_heads == self.config.num_key_value_heads - self.q_proj = Linear( - config.num_attention_heads * self.head_dim, - dtype=self.dtype, - param_dtype=self.param_dtype, - use_bias=False, - kernel_init=jax.nn.initializers.normal( - self.config.initializer_range), - precision=self.precision, - **get_dot_general_by_bits(self.config.bits, self.config.easy_method) - ) - self.k_proj = Linear( - config.num_key_value_heads * self.head_dim, - dtype=self.dtype, - param_dtype=self.param_dtype, - use_bias=False, - kernel_init=jax.nn.initializers.normal( - self.config.initializer_range), - precision=self.precision, - **get_dot_general_by_bits(self.config.bits, self.config.easy_method) - ) - self.v_proj = Linear( - config.num_key_value_heads * self.head_dim, - dtype=self.dtype, - param_dtype=self.param_dtype, - use_bias=False, - kernel_init=jax.nn.initializers.normal( - self.config.initializer_range), - precision=self.precision, - **get_dot_general_by_bits(self.config.bits, self.config.easy_method) - ) - self.o_proj = Linear( - config.hidden_size, - dtype=self.dtype, - param_dtype=self.param_dtype, - use_bias=False, - kernel_init=jax.nn.initializers.normal( - self.config.initializer_range), - precision=self.precision, - **get_dot_general_by_bits(self.config.bits, self.config.easy_method) - ) - - self.rotary = FlaxGrok1Embedding(self.dtype) - self.attention_performer = AttentionModule( - use_sharding_constraint=self.config.use_sharding_constraint, - block_k_major=self.config.block_k_major, - block_b=self.config.block_b, - block_q=self.config.block_q, - block_k=self.config.block_k, - block_q_major_dkv=self.config.block_q_major_dkv, - block_k_major_dkv=self.config.block_k_major_dkv, - block_k_major_dq=self.config.block_k_major_dq, - block_k_dkv=self.config.block_k_dkv, - block_q_dkv=self.config.block_q_dkv, - block_q_dq=self.config.block_q_dq, - block_k_dq=self.config.block_k_dq, - num_attention_heads=self.config.num_attention_heads, - attention_dropout=self.config.attention_dropout, - head_dims=self.head_dim, - attention_partition_spec=self.config.attention_partition_spec, - shard_attention_computation=self.config.shard_attention_computation, - precision=self.precision, - force_float32_tpu=True, - attn_mechanism=self.config.attn_mechanism, - dtype=self.dtype, - bias_partition_spec=self.config.bias_partition_spec, - key_partition_spec=self.config.key_partition_spec, - query_partition_spec=self.config.query_partition_spec, - generation_query_partition_spec=self.config.generation_query_partition_spec, - generation_bias_partition_spec=self.config.generation_bias_partition_spec, - generation_attention_partition_spec=self.config.generation_attention_partition_spec, - value_partition_spec=self.config.value_partition_spec, - scan_ring_attention=self.config.scan_ring_attention, - mesh=self.config.jax_mesh(), - sm_scale=1 / math.sqrt(self.head_dim), - axis_name=self.config.attention_axis_name - ) - self.resid_dropout = flax.linen.Dropout(rate=config.resid_pdrop) - - def _merge_heads(self, hidden_states): - return hidden_states.reshape(hidden_states.shape[:2] + (self.hidden_size,)) - - @staticmethod - def _transpose_sequence_head(query, key, value): - """ - The _transpose_sequence_head function transposes the query, key and value matrices. - - :param query: Get the attention weights for each of the heads - :param key: Determine the number of heads - :param value: Store the values of the input - :return: The transpose of the query, key and value matrices - - """ - return jnp.transpose(query, (0, 2, 1, 3)), jnp.transpose(key, (0, 2, 1, 3)), jnp.transpose(value, (0, 2, 1, 3)) - - def apply_rotary(self, batch_size, sequence_length, query, key, value, freq_cis, position_ids): - """ - The apply_rotary function is a modified version of the apply_attention function in the BertModel class. - The main difference is that it takes in an additional argument, freq_cis, which are used to calculate - the rotary attention weights. The other differences are minor and mostly related to reshaping tensors. - - :param self: Access variables that belong to the class - :param batch_size: Reshape the query, key and value tensors - :param sequence_length: Reshape the query, key and value tensors - :param query: Calculate the attention weights - :param key: Calculate the attention - :param value: Compute the attention weights - :param freq_cis: Calculate the frequency of each word in the vocabulary - :param position_ids: Identify the position of each token in the sequence - :return: A tuple of 3 tensors: query, key and value - - """ - query = query.reshape( - batch_size, - sequence_length, - self.config.num_attention_heads, - self.head_dim - ) - key = key.reshape( - batch_size, - sequence_length, - self.config.num_key_value_heads, - self.head_dim - ) - value = value.reshape( - batch_size, - sequence_length, - self.config.num_key_value_heads, - self.head_dim - ) - - query, key, value = self._transpose_sequence_head(query, key, value) - query, key = self.rotary( - position_ids=position_ids, query=query, key=key, freq_cis=freq_cis - ) - key = repeat_kv_bnsh(key, self.num_key_value_groups) - value = repeat_kv_bnsh(value, self.num_key_value_groups) - return self._transpose_sequence_head(query, key, value) - - def __call__( - self, - hidden_states: chex.Array, - freq_cis: Tuple[chex.Array, chex.Array], - attention_mask: chex.Array, - position_ids: chex.Array, - causal_mask: chex.Array, - segment_ids: Optional[chex.Array] = None, - deterministic: bool = True, - init_cache: bool = False, - output_attentions: bool = False, - fcm_mask=None, - ): - """ - - The __call__ function is the main function of a JAX module. It defines how the module behaves when called - with inputs. The __call__ function can be thought of as a "forward pass" through the model, - and it should return all outputs that are needed for training or inference. - - :param self: Access variables that belong to the class - :param hidden_states: chex.Array: Pass the hidden states of the previous layer - :param freq_cis: Tuple[chex.Array, chex.Array],: Pass in the frequency coefficients for each position - :param attention_mask: chex.Array: Mask out certain tokens in the input sequence - :param position_ids: chex.Array: Determine the position of each token in a sequence - :param causal_mask: chex.Array: Mask out the future tokens in the decoder - :param deterministic: bool: Determine whether to use dropout or not - :param init_cache: bool: Initialize the cache - :param output_attentions: bool: Determine whether to return the attention weights or not - :param fcm_mask: Mask out the attention weights between the input and output tokens - :param : Determine if the attention is causal or not - :return: A tuple of two arrays - - """ - batch_size, sequence_length = hidden_states.shape[:2] - query_states, key_states, value_states = self.q_proj(hidden_states), self.k_proj(hidden_states), self.v_proj( - hidden_states) - - query_states = query_states.reshape( - batch_size, sequence_length, self.config.num_attention_heads, self.head_dim) - key_states = key_states.reshape( - batch_size, sequence_length, self.config.num_key_value_heads, self.head_dim) - value_states = value_states.reshape( - batch_size, sequence_length, self.config.num_key_value_heads, self.head_dim) - - query_states, key_states, value_states = self.apply_rotary( - query=query_states, - key=key_states, - value=value_states, - position_ids=position_ids, - freq_cis=freq_cis, - batch_size=batch_size, - sequence_length=sequence_length - ) - - assert_msg = ( - "num_attention_heads repeat wont work likely\n" - f"INFO :\n\trepeat_kv_bnsh Used with num_key_value_groups = {self.num_key_value_groups}\n\t" - f"NH : {self.config.num_attention_heads} KVH : {self.config.num_attention_heads}" - ) - - assert query_states.shape[-2] == self.config.num_attention_heads, assert_msg - assert key_states.shape[-2] == self.config.num_attention_heads, assert_msg - assert value_states.shape[-2] == self.config.num_attention_heads, assert_msg - - query_length, key_length = query_states.shape[1], key_states.shape[1] - - if self.has_variable("cache", "cached_key"): - mask_shift = self.variables["cache"]["cache_index"] - max_decoder_length = self.variables["cache"]["cached_key"].shape[1] - causal_mask = lax.dynamic_slice( - causal_mask, - (0, 0, mask_shift, 0), - (1, 1, query_length, max_decoder_length) - ) - else: - causal_mask = causal_mask[:, :, :query_length, :key_length] - - batch_size = hidden_states.shape[0] - causal_mask = jnp.broadcast_to( - causal_mask, (batch_size,) + causal_mask.shape[1:]) - attention_mask = jnp.broadcast_to(jnp.expand_dims( - attention_mask, axis=(-3, -2)), causal_mask.shape) - attention_mask = combine_masks(attention_mask, causal_mask, fcm_mask) - if attention_mask.ndim == 2: - attention_mask = jnp.expand_dims(attention_mask, axis=(-3, -2)) - - dropout_rng = None - - if not deterministic and self.config.attention_dropout > 0.0: - dropout_rng = self.make_rng("dropout") - - if self.has_variable("cache", "cached_key") or init_cache: - key_states, value_states, attention_mask = self._concatenate_to_cache( - key_states, - value_states, - query_states, - attention_mask - ) - - # if self.config.use_sharding_constraint: - # query_states = with_sharding_constraint( - # query_states, PartitionSpec(("dp", "fsdp"), "sp" if query_states.shape[1] != 1 else None, "tp", None) - # ) - # key_states = with_sharding_constraint( - # key_states, PartitionSpec(("dp", "fsdp"), "sp", "tp", None) - # ) - # value_states = with_sharding_constraint( - # value_states, PartitionSpec(("dp", "fsdp"), "sp", "tp", None) - # ) - attention_bias = lax.select( - attention_mask > 0, - jnp.full(attention_mask.shape, 0.0).astype(self.dtype), - jnp.full(attention_mask.shape, jnp.finfo( - self.dtype).min).astype(self.dtype), - ) - - query_length, key_length = query_states.shape[1], key_states.shape[1] - - attentions = self.attention_performer.__call__( - query_states=query_states, - key_states=key_states, - value_states=value_states, - bias=attention_bias, - attention_mask=attention_mask, - causal=True, - dropout_rng=dropout_rng, - deterministic=deterministic, - query_sequence_length=query_length, - key_value_sequence_length=key_length, - uses_cache=self.has_variable("cache", "cached_key") or init_cache, - segment_ids=segment_ids, - causal_mask=causal_mask - ) - - - attn_output = self._merge_heads(attentions.attention_outputs) - if self.config.shard_attention_computation: - attn_output = with_sharding_constraint( - attn_output, PartitionSpec( - ("dp", "fsdp"), - "sp" if attn_output.shape[1] != 1 else None, - "tp" - ) - ) - attn_output = self.o_proj(attn_output) - - attn_output = self.resid_dropout(attn_output, deterministic=deterministic) - outputs = (attn_output, attentions.attention_weights) if output_attentions else (attn_output,) - return outputs - - -class FlaxGrok1BLockSparseMLP(nn.Module): - config: Grok1Config - dtype: jnp.dtype = jnp.float32 - param_dtype: jnp.dtype = jnp.float32 - precision: Optional[Union[jax.lax.Precision, str]] = None - - def setup(self) -> None: - config = self.config - - self.linear = Linear( - config.intermediate_size, - dtype=self.dtype, - param_dtype=self.param_dtype, - use_bias=False, - kernel_init=jax.nn.initializers.normal( - self.config.initializer_range), - precision=self.precision, - **get_dot_general_by_bits(self.config.bits, self.config.easy_method) - ) - self.linear_1 = Linear( - config.hidden_size, - dtype=self.dtype, - param_dtype=self.param_dtype, - use_bias=False, - kernel_init=jax.nn.initializers.normal( - self.config.initializer_range), - precision=self.precision, - **get_dot_general_by_bits(self.config.bits, self.config.easy_method) - ) - self.linear_v = Linear( - config.intermediate_size, - dtype=self.dtype, - param_dtype=self.param_dtype, - use_bias=False, - kernel_init=jax.nn.initializers.normal( - self.config.initializer_range), - precision=self.precision, - **get_dot_general_by_bits(self.config.bits, self.config.easy_method) - ) - - def __call__(self, x: jnp.ndarray, deterministic: bool = True) -> jnp.ndarray: - """ - The __call__ function is the main function of a class. - It is called when an instance of the class (an object) is invoked as a function, i.e., obj(arguments). - The __call__ method enables instances of a class to be called like standard Python functions. - - :param self: Represent the instance of the class - :param x: jnp.ndarray: Pass in the input to the layer - :param deterministic: bool: Determine whether to use dropout # IGNORED - :return: A tensor that is the result of applying a dropout function to x - - """ - return self.linear_1(nn.gelu(self.linear(x)) * self.linear_v(x)) - - -class FlaxGrok1BlocKSparesTop2MLPCollection(nn.Module): - config: Grok1Config - dtype: jnp.dtype = jnp.bfloat16 - param_dtype: jnp.dtype = jnp.bfloat16 - precision: Optional[jax.lax.Precision] = jax.lax.Precision("fastest") - - def setup(self) -> None: - self.layers = [ - FlaxGrok1BLockSparseMLP( - config=self.config, - dtype=self.dtype, - param_dtype=self.param_dtype, - precision=self.precision, - name=str(i) - ) - for i in range(self.config.num_experts) - ] - - def __call__( - self, - selected_experts: chex.Array, - hidden_states: chex.Array, - routing_weights: chex.Array, - batch_size: int, - sequence_length: int, - hidden_dim: int - ) -> chex.Array: - final_hidden_state = jnp.zeros_like(hidden_states) - - for index in range(self.config.num_experts): - expert_layer_output = block_wise_ffn( - self.layers[index], - hidden_states, - self.config.scan_mlp_chunk_size, - False - ) if self.config.use_scan_mlp else self.layers[index](hidden_states) - expert_layer_output_exp = jnp.sum( - jnp.multiply( - selected_experts == index, routing_weights - ), axis=-1 - )[:, :, None] * expert_layer_output - final_hidden_state += expert_layer_output_exp - return final_hidden_state - - -class FlaxGrok1SparseMoeBlock(nn.Module): - """ - This implementation is - strictly equivalent to standard MoE with full capacity (no - dropped tokens). It's faster since it formulates MoE operations - in terms of block-sparse operations to accomodate imbalanced - assignments of tokens to experts, whereas standard MoE either - (1) drop tokens at the cost of reduced performance or (2) set - capacity factor to number of experts and thus waste computation - and memory on padding. - """ - config: Grok1Config - dtype: jnp.dtype = jnp.bfloat16 - param_dtype: jnp.dtype = jnp.bfloat16 - precision: Optional[ - Union[None, jax.lax.Precision] - ] = jax.lax.Precision("fastest") - - def setup(self) -> None: - self.gate = Linear( - self.config.num_experts, - use_bias=False, - dtype=self.dtype, - param_dtype=self.param_dtype, - precision=self.precision, - kernel_init=nn.initializers.normal(), - ) - - self.experts = FlaxGrok1BlocKSparesTop2MLPCollection( - config=self.config, - dtype=self.dtype, - param_dtype=self.param_dtype, - precision=self.precision - ) - - def __call__( - self, - hidden_states: chex.Array, - e: bool = False # Ignored - ) -> Tuple[chex.Array, chex.Array]: - batch_size, sequence_length, hidden_dim = hidden_states.shape - - router_logits = self.gate(hidden_states).astype( - jnp.promote_types(self.dtype, jnp.float32) - ) - routing_weights, selected_experts = jax.lax.top_k( - router_logits, - k=self.config.num_experts_per_tok - ) - routing_weights = jax.nn.softmax( - routing_weights.astype( - jnp.promote_types(self.dtype, jnp.float32) - ), axis=-1 - ) - - return self.experts( - selected_experts=selected_experts, - batch_size=batch_size, - sequence_length=sequence_length, - hidden_dim=hidden_dim, - hidden_states=hidden_states, - routing_weights=routing_weights - ), router_logits - - -class FlaxGrok1DecoderLayer(nn.Module): - config: Grok1Config - layer_index: int - dtype: jnp.dtype = jnp.bfloat16 - param_dtype: jnp.dtype = jnp.bfloat16 - precision: Optional[Union[str, jax.lax.Precision]] = jax.lax.Precision("fastest") - - def setup(self) -> None: - # hidden_states: chex.Array - # freq_cis: Tuple[chex.Array, chex.Array], - # attention_mask: chex.Array - # causal_mask: chex.Array - # position_ids: chex.Array - # deterministic: bool = True - # init_cache: bool = False - # output_attentions: bool = True - - attn_block = FlaxGrok1Attention - mlp_block = FlaxGrok1SparseMoeBlock - if self.config.gradient_checkpointing != "": - attn_block = re_mat( - attn_block, - policy=get_gradient_checkpoint_policy(self.config.gradient_checkpointing), - static_argnums=( - 3, 5, 6, 7, 8 - ) - ) - mlp_block = re_mat( - mlp_block, - policy=get_gradient_checkpoint_policy(self.config.gradient_checkpointing), - static_argnums=( - 1, - ) - ) - self.attn = attn_block( - config=self.config, - layer_index=self.layer_index, - dtype=self.dtype, - param_dtype=self.param_dtype, - precision=self.precision - ) - self.moe_block = mlp_block( - config=self.config, - dtype=self.dtype, - param_dtype=self.param_dtype, - precision=self.precision - ) - self.pre_attn_norm = FlaxGrok1RMSNorm( - dim=self.config.hidden_size, - eps=self.config.rms_norm_eps, - dtype=self.dtype, - param_dtype=self.param_dtype - ) - self.post_attn_norm = FlaxGrok1RMSNorm( - dim=self.config.hidden_size, - eps=self.config.rms_norm_eps, - dtype=self.dtype, - param_dtype=self.param_dtype - ) - self.pre_moe_norm = FlaxGrok1RMSNorm( - dim=self.config.hidden_size, - eps=self.config.rms_norm_eps, - dtype=self.dtype, - param_dtype=self.param_dtype - ) - self.post_moe_norm = FlaxGrok1RMSNorm( - dim=self.config.hidden_size, - eps=self.config.rms_norm_eps, - dtype=self.dtype, - param_dtype=self.param_dtype - ) - - def __call__( - self, - hidden_states: chex.Array, - freq_cis: Tuple[chex.Array, chex.Array], - attention_mask: chex.Array, - causal_mask: chex.Array, - position_ids: chex.Array, - segment_ids: Optional[chex.Array] = None, - deterministic: bool = True, - init_cache: bool = False, - output_attentions: bool = True, - output_router_logits: Optional[bool] = False, - ): - """ - The __call__ function is the main function of a TransformerEncoderLayer. - It takes in the following arguments: - hidden_states (chex.Array): The input to the encoder layer, which is also its output after being processed by all sublayers. - freq_cis (chex.Array): A tensor containing frequency-domain representations of each token's context vector, used for computing self-attention weights and biases in a more efficient manner than using position embeddings or sinusoidal positional encoding vectors would allow for [2]. This tensor has shape `(batch_size, num - - :param self: Represent the instance of the class - :param hidden_states: chex.Array: Represent the input to the encoder layer - :param freq_cis: Tuple[chex.Array, chex.Array],: Pass the frequency information to the attention layer - :param attention_mask: chex.Array: Mask out the attention weights for certain positions - :param causal_mask: chex.Array: Mask the future tokens - :param position_ids: chex.Array: Indicate the position of each token in the sequence - :param deterministic: bool: Determine whether to use dropout or not - :param init_cache: bool: Initialize the cache for the self-attention layer - :param output_attentions: bool: Determine whether to return the attention weights or not - :return: A tuple of hidden_states and attention_output - - """ - residual = hidden_states - hidden_states = self.pre_attn_norm(hidden_states) - hidden_states, attention_weights, present_key_value = self.attn( - hidden_states, - freq_cis, - attention_mask, - causal_mask, - position_ids, - segment_ids, - deterministic, - init_cache, - output_attentions - ) - - hidden_states = self.post_attn_norm(hidden_states) - hidden_states = residual + hidden_states - - residual = hidden_states - hidden_states = self.pre_moe_norm(hidden_states) - hidden_states, router_logits = self.moe_block(hidden_states) - hidden_states = self.post_moe_norm(hidden_states) - hidden_states = residual + hidden_states - - outputs = (hidden_states,) - if output_attentions: - outputs += (attention_weights,) - if output_router_logits: - outputs += (router_logits,) - return outputs - - -class FlaxGrok1DecoderLayerCollection(nn.Module): - config: Grok1Config - dtype: jnp.dtype = jnp.bfloat16 - param_dtype: jnp.dtype = jnp.bfloat16 - precision: Optional[jax.lax.Precision] = jax.lax.Precision("fastest") - - def setup(self) -> None: - self.blocks = [ - FlaxGrok1DecoderLayer( - layer_index=layer_index, - config=self.config, - dtype=self.dtype, - param_dtype=self.param_dtype, - precision=self.precision, - name=str(layer_index) - ) - - for layer_index in range(self.config.num_hidden_layers) - ] - - def __call__( - self, - hidden_states: chex.Array, - freq_cis: Tuple[chex.Array, chex.Array], - attention_mask: chex.Array, - causal_mask: chex.Array, - position_ids: chex.Array, - deterministic: bool = True, - init_cache: bool = False, - output_hidden_states: Optional[bool] = False, - output_attentions: Optional[bool] = False, - output_router_logits: Optional[bool] = False, - ): - """ - The __call__ function is the main function of a TransformerEncoderLayer. - It takes in the following arguments: - hidden_states (chex.Array): The input to the encoder layer, which is also its output after being processed by all sublayers. - freq_cis (chex.Array): A tensor containing frequency-domain representations of each token's context vector, used for computing self-attention weights and biases in a more efficient manner than using position embeddings or sinusoidal positional encoding vectors would allow for [2]. This tensor has shape `(batch_size, num - - :param self: Represent the instance of the class - :param hidden_states: chex.Array: Represent the input to the encoder layer - :param freq_cis: Tuple[chex.Array, chex.Array],: Pass the frequency information to the attention layer - :param attention_mask: chex.Array: Mask out the attention weights for certain positions - :param causal_mask: chex.Array: Mask the future tokens - :param position_ids: chex.Array: Indicate the position of each token in the sequence - :param deterministic: bool: Determine whether to use dropout or not - :param init_cache: bool: Initialize the cache for the self-attention layer - :param output_attentions: bool: Determine whether to return the attention weights or not - :return: A tuple of hidden_states, attention_output, all_hidden_states and all_router_logits - - """ - all_hidden_states = () if output_hidden_states else None - all_self_attns = () if output_attentions else None - all_router_logits = () if output_router_logits else None - - for block in self.blocks: - if output_hidden_states: - all_hidden_states += (hidden_states,) - layer_outputs = block( - hidden_states=hidden_states, - attention_mask=attention_mask, - position_ids=position_ids, - output_attentions=output_attentions, - output_router_logits=output_router_logits, - init_cache=init_cache, - freq_cis=freq_cis, - causal_mask=causal_mask, - deterministic=deterministic, - ) - - hidden_states = layer_outputs[0] - - if output_attentions: - all_self_attns += (layer_outputs[1],) - - if output_router_logits: - all_router_logits += (layer_outputs[-1],) - - outputs = (hidden_states,) - if output_attentions: - outputs += (all_self_attns,) - if output_hidden_states: - outputs += (all_hidden_states,) - if output_router_logits: - outputs += (all_router_logits,) - return outputs - - -class Grok1PreTrainedModel(EasyDeLFlaxPretrainedModel): - config_class: Grok1Config = Grok1Config - module_class: nn.Module = None - base_model_prefix = "model" - - # main_input_name = "input_ids" - - def __init__( - self, - config: Grok1Config, - dtype: jnp.dtype = jnp.bfloat16, - param_dtype: jnp.dtype = jnp.bfloat16, - precision: Optional[jax.lax.Precision] = jax.lax.Precision( - "fastest"), - input_shape: Tuple[int, int] = (1, 1), - seed: int = 0, - _do_init: bool = False, - **kwargs - ): - module = self.module_class( - config=config, - dtype=dtype, - param_dtype=param_dtype, - precision=precision, - **kwargs - ) - - super().__init__( - dtype=dtype, _do_init=_do_init, - module=module, config=config, input_shape=input_shape, - seed=seed, - ) - - def init_weights(self, rng: jax.random.PRNGKey, input_shape: Tuple, - params: Optional[FrozenDict] = None) -> FrozenDict: - """ - The init_weights function is used to initialize the weights of a model. - It takes in a rng, which is a random number generator key that can be used to generate random numbers. - The input_shape parameter specifies the shape of the inputs that will be fed into this model. - The params parameter allows you to pass in pre-trained weights for your model, if you have them available. - - :param self: Access variables that belong to the class - :param rng: jax.random.PRNGKey: Initialize the weights of the model - :param input_shape: Tuple: Initialize the input_ids, attention_mask and position_ids - :param params: flax.core.FrozenDict: Pass in the parameters of a pre-trained model - :return: A frozendict of parameters - """ - - self.config.initialization_of_moe = True - input_ids = jnp.zeros(input_shape, dtype="i4") - attention_mask = jnp.ones_like(input_ids, dtype="i4") - position_ids = jnp.broadcast_to( - jnp.arange(jnp.atleast_2d(input_ids).shape[-1], dtype="i4"), - input_shape, - ) - params_rng, dropout_rng = jax.random.split(rng) - rngs = {"params": params_rng, "dropout": dropout_rng} - if self.config.add_cross_attention: - encoder_hidden_states = jnp.zeros( - input_shape + (self.config.hidden_size,)) - encoder_attention_mask = attention_mask - module_init_outputs = self.module.init( - rngs, - input_ids, - attention_mask, - position_ids, - encoder_hidden_states, - encoder_attention_mask, - return_dict=False, - ) - else: - module_init_outputs = self.module.init( - rngs, - input_ids=input_ids, - attention_mask=attention_mask, - position_ids=position_ids, - return_dict=False - ) - random_params = module_init_outputs["params"] - - self.config.initialization_of_moe = False - if params is not None: - random_params = flatten_dict(unfreeze(random_params)) - params = flatten_dict(unfreeze(params)) - for missing_key in self._missing_keys: - params[missing_key] = random_params[missing_key] - self._missing_keys = set() - return freeze(unflatten_dict(params)) - else: - return random_params - - def init_cache(self, batch_size, max_length): - - input_ids = jnp.ones((batch_size, max_length)) - attention_mask = jnp.ones_like(input_ids) - position_ids = jnp.broadcast_to(jnp.arange( - jnp.atleast_2d(input_ids).shape[-1]), input_ids.shape) - - init_variables = self.module.init( - jax.random.PRNGKey(0), input_ids, attention_mask, position_ids, return_dict=False, init_cache=True - ) - return init_variables["cache"] - - def __call__( - self, - input_ids: chex.Array, - attention_mask: Optional[chex.Array] = None, - position_ids: Optional[chex.Array] = None, - params: dict = None, - past_key_values: dict = None, - dropout_rng: jax.random.PRNGKey = None, - train: bool = False, - output_attentions: Optional[bool] = None, - output_hidden_states: Optional[bool] = None, - output_router_logits: Optional[bool] = None, - return_dict: Optional[bool] = None, - add_params_field: bool = False, - **kwargs - ): - """ - The __call__ function is the main function of a JAX module. - It takes as input: - - The parameters of the model (self.params) - - The inputs to the model (input_ids, attention_mask, position_ids) - - Whether we are training (train=True/False) and whether we want to return all hidden states and - attentions weights at each layer in addition to just the last layer output (output_hidden_states=True/False). - - :param self: Represent the instance of the class - :param input_ids: Pass the input sequence to the model - :param attention_mask: Mask out the padding tokens - :param position_ids: Specify the position of each token in the sequence - :param params: dict: Pass in the parameters of the model - :param past_key_values: dict: Pass the past key values to the model - :param dropout_rng: jax.random.PRNGKey: Pass in a random number generator key to the model - :param train: bool: Determine whether to use dropout or not - :param output_attentions: Optional[bool]: Determine whether to return the attention weights - :param output_hidden_states: Optional[bool]: Determine whether to return the hidden states of all layers - :param return_dict: Optional[bool]: Return a dictionary of the outputs - :param add_params_field: bool: Add a params field to the inputs dictionary - :return: A tuple of (last_hidden_state, past_key_values) - - """ - - output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions - output_hidden_states = ( - output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states - ) - return_dict = return_dict if return_dict is not None else self.config.return_dict - - batch_size, sequence_length = input_ids.shape - - if position_ids is None: - if past_key_values is not None: - raise ValueError( - "Make sure to provide `position_ids` when passing `past_key_values`.") - - position_ids = jnp.broadcast_to(jnp.arange(sequence_length)[ - None, :], (batch_size, sequence_length)) - - if attention_mask is None: - attention_mask = jnp.ones((batch_size, sequence_length)) - - rng_s = {} - if dropout_rng is not None: - rng_s["dropout"] = dropout_rng - - inputs = { - "params": params or self.params} if add_params_field else params or self.params - - if self.config.bits is not None: - rng_s['params'] = jax.random.key(0) - if past_key_values: - inputs["cache"] = past_key_values - mutable = ["cache"] - else: - mutable = False - - outputs = self.module.apply( - inputs, - jnp.array(input_ids, dtype="i4"), # input_ids: chex.Array - # attention_mask: Optional[chex.Array] = None - jnp.array(attention_mask, dtype="i4"), - # position_ids: Optional[chex.Array] = None - jnp.array(position_ids, dtype="i4"), - None, # inputs_embeds: Optional[chex.Array] = None - output_attentions, # output_attentions: Optional[bool] = None - # output_hidden_states: Optional[bool] = None - output_hidden_states, - # output_router_logits: Optional[bool] = None - output_router_logits, - False, # init_cache: bool = False - not train, # deterministic: bool = True - return_dict, # return_dict: bool = True - rngs=rng_s, - mutable=mutable, - ) - - if past_key_values is not None and return_dict: - outputs, past_key_values = outputs - outputs["past_key_values"] = unfreeze(past_key_values["cache"]) - return outputs - elif past_key_values is not None and not return_dict: - outputs, past_key_values = outputs - outputs = outputs[:1] + \ - (unfreeze(past_key_values["cache"]),) + outputs[1:] - - return outputs - - -class FlaxGrok1Module(nn.Module): - config: Grok1Config - dtype: jnp.dtype = jnp.bfloat16 - param_dtype: jnp.dtype = jnp.bfloat16 - precision: Optional[jax.lax.Precision] = jax.lax.Precision("fastest") - - def setup(self) -> None: - self.embed_tokens = nn.Embed( - self.config.vocab_size, - self.config.hidden_size, - dtype=self.dtype, - param_dtype=self.param_dtype, - ) - - self.layers = FlaxGrok1DecoderLayerCollection( - config=self.config, - dtype=self.dtype, - param_dtype=self.param_dtype, - precision=self.precision - ) - - self.norm = FlaxGrok1RMSNorm( - self.config.hidden_size, - eps=self.config.rms_norm_eps, - dtype=self.dtype, - param_dtype=self.param_dtype - ) - - initial_rope_kwargs = dict( - rope_type="none" - ) - if self.config.rope_scaling is not None: - scaling_type = self.config.rope_scaling["type"] - scaling_factor = self.config.rope_scaling["factor"] - initial_rope_kwargs = dict( - scaling_factor=scaling_factor, - rope_type=scaling_type - ) - self.freq_cis = precompute_freq_cis( - max_position_embeddings=( - getattr(self.config, "freq_max_position_embeddings", self.config.max_position_embeddings) - ), - dim=self.config.hidden_size // self.config.num_attention_heads, - base=self.config.rope_theta, - **initial_rope_kwargs - ) - self.causal_mask = flax.linen.make_causal_mask( - jnp.ones( - (1, getattr(self.config, "c_max_position_embeddings", self.config.max_position_embeddings)), - dtype="bool" - ), dtype="bool" - ) - - def __call__( - self, - input_ids: chex.Array, - attention_mask: chex.Array, - position_ids: chex.Array, - inputs_embeds: Optional[chex.Array] = None, - output_attentions: Optional[bool] = None, - output_hidden_states: Optional[bool] = None, - output_router_logits: Optional[bool] = None, - init_cache: bool = False, - deterministic: bool = True, - return_dict: bool = True, - ) -> MoeModelOutput | Tuple: - if output_router_logits is None: - output_router_logits = self.config.output_router_logits - if input_ids is not None and inputs_embeds is not None: - raise ValueError( - "You cannot specify both decoder_input_ids and decoder_inputs_embeds at the same time") - - if inputs_embeds is None and input_ids is not None: - inputs_embeds = self.embed_tokens(input_ids.astype("i4")) - inputs_embeds = inputs_embeds * self.config.embedding_multiplier_scale - else: - raise ValueError( - "you should specify inputs_embeds or input_ids one of them") - output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions - output_router_logits = ( - output_router_logits if output_router_logits is not None else self.config.output_router_logits - ) - output_hidden_states = ( - output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states - ) - collection_outputs = self.layers( - hidden_states=inputs_embeds, - attention_mask=attention_mask, - position_ids=position_ids, - causal_mask=self.causal_mask, - freq_cis=self.freq_cis, - output_attentions=output_attentions, - output_router_logits=output_router_logits, - output_hidden_states=output_hidden_states, - init_cache=init_cache, - deterministic=deterministic, - ) - all_self_attns = None - all_hidden_states = None - all_router_logits = None - hidden_states = collection_outputs[0] - if output_attentions: - all_self_attns = collection_outputs[1] - if output_hidden_states: - all_hidden_states = collection_outputs[2 if output_attentions else 1] - if output_router_logits: - all_router_logits = collection_outputs[-1] - hidden_states = self.norm(hidden_states) - - if output_hidden_states: - all_hidden_states += (hidden_states,) - if not return_dict: - return tuple( - v - for v in [hidden_states, all_hidden_states, all_self_attns, all_router_logits] - if v is not None - ) - return MoeModelOutput( - last_hidden_state=hidden_states, - hidden_states=all_hidden_states, - attentions=all_self_attns, - router_logits=all_router_logits, - ) - - -class FlaxGrok1Model(Grok1PreTrainedModel): - module_class = FlaxGrok1Module - - -class FlaxGrok1ForCausalLMModule(nn.Module): - config: Grok1Config - dtype: jnp.dtype = jnp.bfloat16 - param_dtype: jnp.dtype = jnp.bfloat16 - precision: Optional[jax.lax.Precision] = jax.lax.Precision("fastest") - - def setup(self) -> None: - self.model = FlaxGrok1Module( - config=self.config, - dtype=self.dtype, - param_dtype=self.param_dtype, - precision=self.precision - ) - self.lm_head = Linear( - self.config.vocab_size, - dtype=self.dtype, - param_dtype=self.param_dtype, - precision=self.precision, - use_bias=False, - kernel_init=nn.initializers.normal(self.config.initializer_range), - **get_dot_general_by_bits(self.config.bits, self.config.easy_method) - ) - - self.output_multiplier_scale = self.config.output_multiplier_scale - - def __call__( - self, - input_ids: chex.Array, - attention_mask: Optional[chex.Array] = None, - position_ids: Optional[chex.Array] = None, - inputs_embeds: Optional[chex.Array] = None, - output_attentions: Optional[bool] = None, - output_hidden_states: Optional[bool] = None, - output_router_logits: Optional[bool] = None, - init_cache: bool = False, - deterministic: bool = True, - return_dict: bool = True, - ) -> MoeCausalLMOutput | Tuple: - if output_router_logits is None: - output_router_logits = self.config.output_router_logits - outputs = self.model( - input_ids=input_ids, - attention_mask=attention_mask, - position_ids=position_ids, - inputs_embeds=inputs_embeds, - output_attentions=output_attentions, - output_hidden_states=output_hidden_states, - output_router_logits=output_router_logits, - init_cache=init_cache, - deterministic=deterministic, - return_dict=True, - ) - logits = self.lm_head(outputs.last_hidden_state) - logits = logits * self.output_multiplier_scale - batch_size, seq_length, hd = logits.shape - aux_loss = None - if output_router_logits and outputs.router_logits is not None: - aux_loss = auxiliary_load_balancing_loss_func( - gate_logits=tuple([logit.reshape(batch_size * seq_length, -1) for logit in outputs.router_logits]), - num_experts=self.num_experts, - top_k=self.num_experts_per_tok, - attention_mask=attention_mask - ) - aux_loss = aux_loss * self.config.router_aux_loss_coef - if not return_dict: - outputs = (logits,) + tuple( - v - for v in [ - aux_loss, - outputs.hidden_states, - outputs.attentions, - outputs.router_logits - ] - if v is not None - ) - return outputs - - return MoeCausalLMOutput( - aux_loss=aux_loss, - logits=logits, - hidden_states=outputs.hidden_states, - attentions=outputs.attentions, - router_logits=outputs.router_logits, - ) - - -class FlaxGrok1ForCausalLM(Grok1PreTrainedModel): - module_class = FlaxGrok1ForCausalLMModule - - def prepare_inputs_for_generation(self, input_ids, max_length, attention_mask: Optional[chex.Array] = None): - """ - The prepare_inputs_for_generation function is used to prepare the inputs for a generation task. - - :param self: Access variables that belong to the class - :param input_ids: Pass in the input tokens - :param max_length: Set the length of the sequence to be generated - :param attention_mask: Optional[chex.Array]: Mask the attention weights - :return: A dictionary of the past_key_values, attention_mask and position ids - - """ - batch_size, seq_length = input_ids.shape - - past_key_values = self.init_cache(batch_size, max_length) - extended_attention_mask = jnp.ones( - (batch_size, max_length), dtype="i4") - if attention_mask is not None: - position_ids = attention_mask.cumsum(axis=-1) - 1 - extended_attention_mask = lax.dynamic_update_slice( - extended_attention_mask, attention_mask, (0, 0)) - else: - position_ids = jnp.broadcast_to(jnp.arange(seq_length, dtype="i4")[ - None, :], (batch_size, seq_length)) - - return { - "past_key_values": past_key_values, - "attention_mask": extended_attention_mask, - "position_ids": position_ids, - } - - def update_inputs_for_generation(self, model_outputs, model_kwargs): - model_kwargs["past_key_values"] = model_outputs.past_key_values - model_kwargs["position_ids"] = model_kwargs["position_ids"][:, -1:] + 1 - return model_kwargs +import math +from typing import Optional, Tuple, Union + +import chex +from fjformer import linen as nn +import flax.linen.partitioning +import flax.struct +import jax +import jax.numpy as jnp +from fjformer.func import auxiliary_load_balancing_loss_func +from flax.core.frozen_dict import FrozenDict, freeze, unfreeze +from flax.linen import combine_masks +from flax.traverse_util import flatten_dict, unflatten_dict +from jax import lax +from jax.sharding import PartitionSpec +from transformers.modeling_flax_outputs import FlaxMaskedLMOutput + +from .grok_1_configuration import Grok1Config +from ..attention_module import AttentionModule +from ..easydel_modelling_utils import EasyDeLFlaxPretrainedModel +# easydel.modules +from ..flax_modelling_utils import ( + with_sharding_constraint, + get_gradient_checkpoint_policy, + repeat_kv_bnsh, + apply_rotary_pos_emb, + precompute_freq_cis, + get_dot_general_by_bits, + BaseJAXAttentionModule, + block_wise_ffn +) +from fjformer.linen import Linear + +re_mat = flax.linen.partitioning.remat + + +@flax.struct.dataclass +class MoeModelOutput: + last_hidden_state: chex.Array = None + hidden_states: Optional[Tuple[chex.Array]] = None + attentions: Optional[Tuple[chex.Array]] = None + router_logits: Optional[Tuple[chex.Array]] = None + + +@flax.struct.dataclass +class MoeCausalLMOutput(FlaxMaskedLMOutput): + aux_loss: Optional[chex.Array] = None + router_logits: Optional[Tuple[chex.Array]] = None + + +class FlaxGrok1Embedding(nn.Module): + dtype: jnp.dtype = jnp.float32 + + def __call__(self, query, key, freq_cis, position_ids): + sin, cos = freq_cis + + sin = sin[position_ids][:, None, :, :] + cos = cos[position_ids][:, None, :, :] + + key = apply_rotary_pos_emb(key, sin, cos) + query = apply_rotary_pos_emb(query, sin, cos) + + return query.astype(self.dtype), key.astype(self.dtype) + + +def repeat_kv(x: chex.Array, n_rep: int) -> chex.Array: + bs, s, n_kv_heads, head_dim = x.shape + if n_rep == 1: + return x + x = x[:, :, jnp.newaxis, :, :] + x = jnp.repeat(x, n_rep, axis=2) + + return x.reshape(bs, s, + n_kv_heads * n_rep, + head_dim) + + +class FlaxGrok1RMSNorm(nn.Module): + dim: int + eps: float = 1e-6 + dtype: jnp.dtype = jnp.float32 + param_dtype: jnp.dtype = jnp.float32 + + def setup(self) -> None: + self.weight = self.param( + 'kernel', + nn.initializers.ones, + (self.dim,), + self.param_dtype, + ) + + def _norm(self, x: jnp.ndarray) -> jnp.ndarray: + return x * jax.lax.rsqrt(jnp.square(x).mean(-1, keepdims=True) + self.eps) + + def __call__(self, x: jnp.ndarray) -> jnp.ndarray: + x = x.astype(jnp.promote_types(self.dtype, jnp.float32)) + output = self._norm(x).astype(self.dtype) + weight = nn.linen.control_quantization(self.weight, self.dtype) + return output * weight + + +class FlaxGrok1Attention(BaseJAXAttentionModule): + config: Grok1Config + dtype: jnp.dtype = jnp.float32 + param_dtype: jnp.dtype = jnp.float32 + precision: Optional[Union[jax.lax.Precision, str]] = None + + def setup(self): + config = self.config + self.hidden_size = config.hidden_size + self.head_dim = self.config.hidden_size // self.config.num_attention_heads + self.num_key_value_groups = self.config.num_attention_heads // self.config.num_key_value_heads + + if self.num_key_value_groups == 1: + assert self.config.num_attention_heads == self.config.num_key_value_heads + self.q_proj = Linear( + config.num_attention_heads * self.head_dim, + dtype=self.dtype, + param_dtype=self.param_dtype, + use_bias=False, + kernel_init=jax.nn.initializers.normal( + self.config.initializer_range), + precision=self.precision, + **get_dot_general_by_bits(self.config.bits, self.config.easy_method) + ) + self.k_proj = Linear( + config.num_key_value_heads * self.head_dim, + dtype=self.dtype, + param_dtype=self.param_dtype, + use_bias=False, + kernel_init=jax.nn.initializers.normal( + self.config.initializer_range), + precision=self.precision, + **get_dot_general_by_bits(self.config.bits, self.config.easy_method) + ) + self.v_proj = Linear( + config.num_key_value_heads * self.head_dim, + dtype=self.dtype, + param_dtype=self.param_dtype, + use_bias=False, + kernel_init=jax.nn.initializers.normal( + self.config.initializer_range), + precision=self.precision, + **get_dot_general_by_bits(self.config.bits, self.config.easy_method) + ) + self.o_proj = Linear( + config.hidden_size, + dtype=self.dtype, + param_dtype=self.param_dtype, + use_bias=False, + kernel_init=jax.nn.initializers.normal( + self.config.initializer_range), + precision=self.precision, + **get_dot_general_by_bits(self.config.bits, self.config.easy_method) + ) + + self.rotary = FlaxGrok1Embedding(self.dtype) + self.attention_performer = AttentionModule( + use_sharding_constraint=self.config.use_sharding_constraint, + block_k_major=self.config.block_k_major, + block_b=self.config.block_b, + block_q=self.config.block_q, + block_k=self.config.block_k, + block_q_major_dkv=self.config.block_q_major_dkv, + block_k_major_dkv=self.config.block_k_major_dkv, + block_k_major_dq=self.config.block_k_major_dq, + block_k_dkv=self.config.block_k_dkv, + block_q_dkv=self.config.block_q_dkv, + block_q_dq=self.config.block_q_dq, + block_k_dq=self.config.block_k_dq, + num_attention_heads=self.config.num_attention_heads, + attention_dropout=self.config.attention_dropout, + head_dims=self.head_dim, + attention_partition_spec=self.config.attention_partition_spec, + shard_attention_computation=self.config.shard_attention_computation, + precision=self.precision, + force_float32_tpu=True, + attn_mechanism=self.config.attn_mechanism, + dtype=self.dtype, + bias_partition_spec=self.config.bias_partition_spec, + key_partition_spec=self.config.key_partition_spec, + query_partition_spec=self.config.query_partition_spec, + generation_query_partition_spec=self.config.generation_query_partition_spec, + generation_bias_partition_spec=self.config.generation_bias_partition_spec, + generation_attention_partition_spec=self.config.generation_attention_partition_spec, + value_partition_spec=self.config.value_partition_spec, + scan_ring_attention=self.config.scan_ring_attention, + mesh=self.config.jax_mesh(), + sm_scale=1 / math.sqrt(self.head_dim), + axis_name=self.config.attention_axis_name + ) + self.resid_dropout = flax.linen.Dropout(rate=config.resid_pdrop) + + def _merge_heads(self, hidden_states): + return hidden_states.reshape(hidden_states.shape[:2] + (self.hidden_size,)) + + @staticmethod + def _transpose_sequence_head(query, key, value): + """The _transpose_sequence_head function transposes the query, key and value matrices. + + Args: + query: Get the attention weights for each of the heads + key: Determine the number of heads + value: Store the values of the input + + Returns: + The transpose of the query, key and value matrices + """ + return jnp.transpose(query, (0, 2, 1, 3)), jnp.transpose(key, (0, 2, 1, 3)), jnp.transpose(value, (0, 2, 1, 3)) + + def apply_rotary(self, batch_size, sequence_length, query, key, value, freq_cis, position_ids): + """The apply_rotary function is a modified version of the apply_attention function in the BertModel class. + The main difference is that it takes in an additional argument, freq_cis, which are used to calculate + the rotary attention weights. The other differences are minor and mostly related to reshaping tensors. + + Args: + self: Access variables that belong to the class + batch_size: Reshape the query, key and value tensors + sequence_length: Reshape the query, key and value tensors + query: Calculate the attention weights + key: Calculate the attention + value: Compute the attention weights + freq_cis: Calculate the frequency of each word in the + vocabulary + position_ids: Identify the position of each token in the + sequence + + Returns: + A tuple of 3 tensors: query, key and value + """ + query = query.reshape( + batch_size, + sequence_length, + self.config.num_attention_heads, + self.head_dim + ) + key = key.reshape( + batch_size, + sequence_length, + self.config.num_key_value_heads, + self.head_dim + ) + value = value.reshape( + batch_size, + sequence_length, + self.config.num_key_value_heads, + self.head_dim + ) + + query, key, value = self._transpose_sequence_head(query, key, value) + query, key = self.rotary( + position_ids=position_ids, query=query, key=key, freq_cis=freq_cis + ) + key = repeat_kv_bnsh(key, self.num_key_value_groups) + value = repeat_kv_bnsh(value, self.num_key_value_groups) + return self._transpose_sequence_head(query, key, value) + + def __call__( + self, + hidden_states: chex.Array, + freq_cis: Tuple[chex.Array, chex.Array], + attention_mask: chex.Array, + position_ids: chex.Array, + causal_mask: chex.Array, + segment_ids: Optional[chex.Array] = None, + deterministic: bool = True, + init_cache: bool = False, + output_attentions: bool = False, + fcm_mask=None, + ): + """The __call__ function is the main function of a JAX module. It defines how the module behaves when called + with inputs. The __call__ function can be thought of as a "forward pass" through the model, + and it should return all outputs that are needed for training or inference. + + Args: + self: Access variables that belong to the class + hidden_states: chex.Array: Pass the hidden states of the + previous layer + freq_cis: Tuple[chex.Array, chex.Array],: Pass in the + frequency coefficients for each position + attention_mask: chex.Array: Mask out certain tokens in the + input sequence + position_ids: chex.Array: Determine the position of each + token in a sequence + causal_mask: chex.Array: Mask out the future tokens in the + decoder + deterministic: bool: Determine whether to use dropout or not + init_cache: bool: Initialize the cache + output_attentions: bool: Determine whether to return the + attention weights or not + fcm_mask: Mask out the attention weights between the input + and output tokens + :param : Determine if the attention is causal or not + + Returns: + A tuple of two arrays + """ + batch_size, sequence_length = hidden_states.shape[:2] + query_states, key_states, value_states = self.q_proj(hidden_states), self.k_proj(hidden_states), self.v_proj( + hidden_states) + + query_states = query_states.reshape( + batch_size, sequence_length, self.config.num_attention_heads, self.head_dim) + key_states = key_states.reshape( + batch_size, sequence_length, self.config.num_key_value_heads, self.head_dim) + value_states = value_states.reshape( + batch_size, sequence_length, self.config.num_key_value_heads, self.head_dim) + + query_states, key_states, value_states = self.apply_rotary( + query=query_states, + key=key_states, + value=value_states, + position_ids=position_ids, + freq_cis=freq_cis, + batch_size=batch_size, + sequence_length=sequence_length + ) + + assert_msg = ( + "num_attention_heads repeat wont work likely\n" + f"INFO :\n\trepeat_kv_bnsh Used with num_key_value_groups = {self.num_key_value_groups}\n\t" + f"NH : {self.config.num_attention_heads} KVH : {self.config.num_attention_heads}" + ) + + assert query_states.shape[-2] == self.config.num_attention_heads, assert_msg + assert key_states.shape[-2] == self.config.num_attention_heads, assert_msg + assert value_states.shape[-2] == self.config.num_attention_heads, assert_msg + + query_length, key_length = query_states.shape[1], key_states.shape[1] + + if self.has_variable("cache", "cached_key"): + mask_shift = self.variables["cache"]["cache_index"] + max_decoder_length = self.variables["cache"]["cached_key"].shape[1] + causal_mask = lax.dynamic_slice( + causal_mask, + (0, 0, mask_shift, 0), + (1, 1, query_length, max_decoder_length) + ) + else: + causal_mask = causal_mask[:, :, :query_length, :key_length] + + batch_size = hidden_states.shape[0] + causal_mask = jnp.broadcast_to( + causal_mask, (batch_size,) + causal_mask.shape[1:]) + attention_mask = jnp.broadcast_to(jnp.expand_dims( + attention_mask, axis=(-3, -2)), causal_mask.shape) + attention_mask = combine_masks(attention_mask, causal_mask, fcm_mask) + if attention_mask.ndim == 2: + attention_mask = jnp.expand_dims(attention_mask, axis=(-3, -2)) + + dropout_rng = None + + if not deterministic and self.config.attention_dropout > 0.0: + dropout_rng = self.make_rng("dropout") + + if self.has_variable("cache", "cached_key") or init_cache: + key_states, value_states, attention_mask = self._concatenate_to_cache( + key_states, + value_states, + query_states, + attention_mask + ) + + # if self.config.use_sharding_constraint: + # query_states = with_sharding_constraint( + # query_states, PartitionSpec(("dp", "fsdp"), "sp" if query_states.shape[1] != 1 else None, "tp", None) + # ) + # key_states = with_sharding_constraint( + # key_states, PartitionSpec(("dp", "fsdp"), "sp", "tp", None) + # ) + # value_states = with_sharding_constraint( + # value_states, PartitionSpec(("dp", "fsdp"), "sp", "tp", None) + # ) + attention_bias = lax.select( + attention_mask > 0, + jnp.full(attention_mask.shape, 0.0).astype(self.dtype), + jnp.full(attention_mask.shape, jnp.finfo( + self.dtype).min).astype(self.dtype), + ) + + query_length, key_length = query_states.shape[1], key_states.shape[1] + + attentions = self.attention_performer.__call__( + query_states=query_states, + key_states=key_states, + value_states=value_states, + bias=attention_bias, + attention_mask=attention_mask, + causal=True, + dropout_rng=dropout_rng, + deterministic=deterministic, + query_sequence_length=query_length, + key_value_sequence_length=key_length, + uses_cache=self.has_variable("cache", "cached_key") or init_cache, + segment_ids=segment_ids, + causal_mask=causal_mask + ) + + + attn_output = self._merge_heads(attentions.attention_outputs) + if self.config.shard_attention_computation: + attn_output = with_sharding_constraint( + attn_output, PartitionSpec( + ("dp", "fsdp"), + "sp" if attn_output.shape[1] != 1 else None, + "tp" + ) + ) + attn_output = self.o_proj(attn_output) + + attn_output = self.resid_dropout(attn_output, deterministic=deterministic) + outputs = (attn_output, attentions.attention_weights) if output_attentions else (attn_output,) + return outputs + + +class FlaxGrok1BLockSparseMLP(nn.Module): + config: Grok1Config + dtype: jnp.dtype = jnp.float32 + param_dtype: jnp.dtype = jnp.float32 + precision: Optional[Union[jax.lax.Precision, str]] = None + + def setup(self) -> None: + config = self.config + + self.linear = Linear( + config.intermediate_size, + dtype=self.dtype, + param_dtype=self.param_dtype, + use_bias=False, + kernel_init=jax.nn.initializers.normal( + self.config.initializer_range), + precision=self.precision, + **get_dot_general_by_bits(self.config.bits, self.config.easy_method) + ) + self.linear_1 = Linear( + config.hidden_size, + dtype=self.dtype, + param_dtype=self.param_dtype, + use_bias=False, + kernel_init=jax.nn.initializers.normal( + self.config.initializer_range), + precision=self.precision, + **get_dot_general_by_bits(self.config.bits, self.config.easy_method) + ) + self.linear_v = Linear( + config.intermediate_size, + dtype=self.dtype, + param_dtype=self.param_dtype, + use_bias=False, + kernel_init=jax.nn.initializers.normal( + self.config.initializer_range), + precision=self.precision, + **get_dot_general_by_bits(self.config.bits, self.config.easy_method) + ) + + def __call__(self, x: jnp.ndarray, deterministic: bool = True) -> jnp.ndarray: + """The __call__ function is the main function of a class. + It is called when an instance of the class (an object) is invoked as a function, i.e., obj(arguments). + The __call__ method enables instances of a class to be called like standard Python functions. + + Args: + self: Represent the instance of the class + x: jnp.ndarray: Pass in the input to the layer + deterministic: bool: Determine whether to use dropout # + IGNORED + + Returns: + A tensor that is the result of applying a dropout function + to x + """ + return self.linear_1(nn.gelu(self.linear(x)) * self.linear_v(x)) + + +class FlaxGrok1BlocKSparesTop2MLPCollection(nn.Module): + config: Grok1Config + dtype: jnp.dtype = jnp.bfloat16 + param_dtype: jnp.dtype = jnp.bfloat16 + precision: Optional[jax.lax.Precision] = jax.lax.Precision("fastest") + + def setup(self) -> None: + self.layers = [ + FlaxGrok1BLockSparseMLP( + config=self.config, + dtype=self.dtype, + param_dtype=self.param_dtype, + precision=self.precision, + name=str(i) + ) + for i in range(self.config.num_experts) + ] + + def __call__( + self, + selected_experts: chex.Array, + hidden_states: chex.Array, + routing_weights: chex.Array, + batch_size: int, + sequence_length: int, + hidden_dim: int + ) -> chex.Array: + final_hidden_state = jnp.zeros_like(hidden_states) + + for index in range(self.config.num_experts): + expert_layer_output = block_wise_ffn( + self.layers[index], + hidden_states, + self.config.scan_mlp_chunk_size, + False + ) if self.config.use_scan_mlp else self.layers[index](hidden_states) + expert_layer_output_exp = jnp.sum( + jnp.multiply( + selected_experts == index, routing_weights + ), axis=-1 + )[:, :, None] * expert_layer_output + final_hidden_state += expert_layer_output_exp + return final_hidden_state + + +class FlaxGrok1SparseMoeBlock(nn.Module): + """This implementation is + strictly equivalent to standard MoE with full capacity (no + dropped tokens). It's faster since it formulates MoE operations + in terms of block-sparse operations to accomodate imbalanced + assignments of tokens to experts, whereas standard MoE either + (1) drop tokens at the cost of reduced performance or (2) set + capacity factor to number of experts and thus waste computation + and memory on padding. + """ + config: Grok1Config + dtype: jnp.dtype = jnp.bfloat16 + param_dtype: jnp.dtype = jnp.bfloat16 + precision: Optional[ + Union[None, jax.lax.Precision] + ] = jax.lax.Precision("fastest") + + def setup(self) -> None: + self.gate = Linear( + self.config.num_experts, + use_bias=False, + dtype=self.dtype, + param_dtype=self.param_dtype, + precision=self.precision, + kernel_init=nn.initializers.normal(), + ) + + self.experts = FlaxGrok1BlocKSparesTop2MLPCollection( + config=self.config, + dtype=self.dtype, + param_dtype=self.param_dtype, + precision=self.precision + ) + + def __call__( + self, + hidden_states: chex.Array, + e: bool = False # Ignored + ) -> Tuple[chex.Array, chex.Array]: + batch_size, sequence_length, hidden_dim = hidden_states.shape + + router_logits = self.gate(hidden_states).astype( + jnp.promote_types(self.dtype, jnp.float32) + ) + routing_weights, selected_experts = jax.lax.top_k( + router_logits, + k=self.config.num_experts_per_tok + ) + routing_weights = jax.nn.softmax( + routing_weights.astype( + jnp.promote_types(self.dtype, jnp.float32) + ), axis=-1 + ) + + return self.experts( + selected_experts=selected_experts, + batch_size=batch_size, + sequence_length=sequence_length, + hidden_dim=hidden_dim, + hidden_states=hidden_states, + routing_weights=routing_weights + ), router_logits + + +class FlaxGrok1DecoderLayer(nn.Module): + config: Grok1Config + layer_index: int + dtype: jnp.dtype = jnp.bfloat16 + param_dtype: jnp.dtype = jnp.bfloat16 + precision: Optional[Union[str, jax.lax.Precision]] = jax.lax.Precision("fastest") + + def setup(self) -> None: + # hidden_states: chex.Array + # freq_cis: Tuple[chex.Array, chex.Array], + # attention_mask: chex.Array + # causal_mask: chex.Array + # position_ids: chex.Array + # deterministic: bool = True + # init_cache: bool = False + # output_attentions: bool = True + + attn_block = FlaxGrok1Attention + mlp_block = FlaxGrok1SparseMoeBlock + if self.config.gradient_checkpointing != "": + attn_block = re_mat( + attn_block, + policy=get_gradient_checkpoint_policy(self.config.gradient_checkpointing), + static_argnums=( + 3, 5, 6, 7, 8 + ) + ) + mlp_block = re_mat( + mlp_block, + policy=get_gradient_checkpoint_policy(self.config.gradient_checkpointing), + static_argnums=( + 1, + ) + ) + self.attn = attn_block( + config=self.config, + layer_index=self.layer_index, + dtype=self.dtype, + param_dtype=self.param_dtype, + precision=self.precision + ) + self.moe_block = mlp_block( + config=self.config, + dtype=self.dtype, + param_dtype=self.param_dtype, + precision=self.precision + ) + self.pre_attn_norm = FlaxGrok1RMSNorm( + dim=self.config.hidden_size, + eps=self.config.rms_norm_eps, + dtype=self.dtype, + param_dtype=self.param_dtype + ) + self.post_attn_norm = FlaxGrok1RMSNorm( + dim=self.config.hidden_size, + eps=self.config.rms_norm_eps, + dtype=self.dtype, + param_dtype=self.param_dtype + ) + self.pre_moe_norm = FlaxGrok1RMSNorm( + dim=self.config.hidden_size, + eps=self.config.rms_norm_eps, + dtype=self.dtype, + param_dtype=self.param_dtype + ) + self.post_moe_norm = FlaxGrok1RMSNorm( + dim=self.config.hidden_size, + eps=self.config.rms_norm_eps, + dtype=self.dtype, + param_dtype=self.param_dtype + ) + + def __call__( + self, + hidden_states: chex.Array, + freq_cis: Tuple[chex.Array, chex.Array], + attention_mask: chex.Array, + causal_mask: chex.Array, + position_ids: chex.Array, + segment_ids: Optional[chex.Array] = None, + deterministic: bool = True, + init_cache: bool = False, + output_attentions: bool = True, + output_router_logits: Optional[bool] = False, + ): + """The __call__ function is the main function of a TransformerEncoderLayer. + It takes in the following arguments: + hidden_states (chex.Array): The input to the encoder layer, which is also its output after being processed by all sublayers. + freq_cis (chex.Array): A tensor containing frequency-domain representations of each token's context vector, used for computing self-attention weights and biases in a more efficient manner than using position embeddings or sinusoidal positional encoding vectors would allow for [2]. This tensor has shape `(batch_size, num + + Args: + self: Represent the instance of the class + hidden_states: chex.Array: Represent the input to the + encoder layer + freq_cis: Tuple[chex.Array, chex.Array],: Pass the frequency + information to the attention layer + attention_mask: chex.Array: Mask out the attention weights + for certain positions + causal_mask: chex.Array: Mask the future tokens + position_ids: chex.Array: Indicate the position of each + token in the sequence + deterministic: bool: Determine whether to use dropout or not + init_cache: bool: Initialize the cache for the self- + attention layer + output_attentions: bool: Determine whether to return the + attention weights or not + + Returns: + A tuple of hidden_states and attention_output + """ + residual = hidden_states + hidden_states = self.pre_attn_norm(hidden_states) + hidden_states, attention_weights, present_key_value = self.attn( + hidden_states, + freq_cis, + attention_mask, + causal_mask, + position_ids, + segment_ids, + deterministic, + init_cache, + output_attentions + ) + + hidden_states = self.post_attn_norm(hidden_states) + hidden_states = residual + hidden_states + + residual = hidden_states + hidden_states = self.pre_moe_norm(hidden_states) + hidden_states, router_logits = self.moe_block(hidden_states) + hidden_states = self.post_moe_norm(hidden_states) + hidden_states = residual + hidden_states + + outputs = (hidden_states,) + if output_attentions: + outputs += (attention_weights,) + if output_router_logits: + outputs += (router_logits,) + return outputs + + +class FlaxGrok1DecoderLayerCollection(nn.Module): + config: Grok1Config + dtype: jnp.dtype = jnp.bfloat16 + param_dtype: jnp.dtype = jnp.bfloat16 + precision: Optional[jax.lax.Precision] = jax.lax.Precision("fastest") + + def setup(self) -> None: + self.blocks = [ + FlaxGrok1DecoderLayer( + layer_index=layer_index, + config=self.config, + dtype=self.dtype, + param_dtype=self.param_dtype, + precision=self.precision, + name=str(layer_index) + ) + + for layer_index in range(self.config.num_hidden_layers) + ] + + def __call__( + self, + hidden_states: chex.Array, + freq_cis: Tuple[chex.Array, chex.Array], + attention_mask: chex.Array, + causal_mask: chex.Array, + position_ids: chex.Array, + deterministic: bool = True, + init_cache: bool = False, + output_hidden_states: Optional[bool] = False, + output_attentions: Optional[bool] = False, + output_router_logits: Optional[bool] = False, + ): + """The __call__ function is the main function of a TransformerEncoderLayer. + It takes in the following arguments: + hidden_states (chex.Array): The input to the encoder layer, which is also its output after being processed by all sublayers. + freq_cis (chex.Array): A tensor containing frequency-domain representations of each token's context vector, used for computing self-attention weights and biases in a more efficient manner than using position embeddings or sinusoidal positional encoding vectors would allow for [2]. This tensor has shape `(batch_size, num + + Args: + self: Represent the instance of the class + hidden_states: chex.Array: Represent the input to the + encoder layer + freq_cis: Tuple[chex.Array, chex.Array],: Pass the frequency + information to the attention layer + attention_mask: chex.Array: Mask out the attention weights + for certain positions + causal_mask: chex.Array: Mask the future tokens + position_ids: chex.Array: Indicate the position of each + token in the sequence + deterministic: bool: Determine whether to use dropout or not + init_cache: bool: Initialize the cache for the self- + attention layer + output_attentions: bool: Determine whether to return the + attention weights or not + + Returns: + A tuple of hidden_states, attention_output, + all_hidden_states and all_router_logits + """ + all_hidden_states = () if output_hidden_states else None + all_self_attns = () if output_attentions else None + all_router_logits = () if output_router_logits else None + + for block in self.blocks: + if output_hidden_states: + all_hidden_states += (hidden_states,) + layer_outputs = block( + hidden_states=hidden_states, + attention_mask=attention_mask, + position_ids=position_ids, + output_attentions=output_attentions, + output_router_logits=output_router_logits, + init_cache=init_cache, + freq_cis=freq_cis, + causal_mask=causal_mask, + deterministic=deterministic, + ) + + hidden_states = layer_outputs[0] + + if output_attentions: + all_self_attns += (layer_outputs[1],) + + if output_router_logits: + all_router_logits += (layer_outputs[-1],) + + outputs = (hidden_states,) + if output_attentions: + outputs += (all_self_attns,) + if output_hidden_states: + outputs += (all_hidden_states,) + if output_router_logits: + outputs += (all_router_logits,) + return outputs + + +class Grok1PreTrainedModel(EasyDeLFlaxPretrainedModel): + config_class: Grok1Config = Grok1Config + module_class: nn.Module = None + base_model_prefix = "model" + + # main_input_name = "input_ids" + + def __init__( + self, + config: Grok1Config, + dtype: jnp.dtype = jnp.bfloat16, + param_dtype: jnp.dtype = jnp.bfloat16, + precision: Optional[jax.lax.Precision] = jax.lax.Precision( + "fastest"), + input_shape: Tuple[int, int] = (1, 1), + seed: int = 0, + _do_init: bool = False, + **kwargs + ): + module = self.module_class( + config=config, + dtype=dtype, + param_dtype=param_dtype, + precision=precision, + **kwargs + ) + + super().__init__( + dtype=dtype, _do_init=_do_init, + module=module, config=config, input_shape=input_shape, + seed=seed, + ) + + def init_weights(self, rng: jax.random.PRNGKey, input_shape: Tuple, + params: Optional[FrozenDict] = None) -> FrozenDict: + """The init_weights function is used to initialize the weights of a model. + It takes in a rng, which is a random number generator key that can be used to generate random numbers. + The input_shape parameter specifies the shape of the inputs that will be fed into this model. + The params parameter allows you to pass in pre-trained weights for your model, if you have them available. + + Args: + self: Access variables that belong to the class + rng: jax.random.PRNGKey: Initialize the weights of the model + input_shape: Tuple: Initialize the input_ids, attention_mask + and position_ids + params: flax.core.FrozenDict: Pass in the parameters of a + pre-trained model + + Returns: + A frozendict of parameters + """ + + self.config.initialization_of_moe = True + input_ids = jnp.zeros(input_shape, dtype="i4") + attention_mask = jnp.ones_like(input_ids, dtype="i4") + position_ids = jnp.broadcast_to( + jnp.arange(jnp.atleast_2d(input_ids).shape[-1], dtype="i4"), + input_shape, + ) + params_rng, dropout_rng = jax.random.split(rng) + rngs = {"params": params_rng, "dropout": dropout_rng} + if self.config.add_cross_attention: + encoder_hidden_states = jnp.zeros( + input_shape + (self.config.hidden_size,)) + encoder_attention_mask = attention_mask + module_init_outputs = self.module.init( + rngs, + input_ids, + attention_mask, + position_ids, + encoder_hidden_states, + encoder_attention_mask, + return_dict=False, + ) + else: + module_init_outputs = self.module.init( + rngs, + input_ids=input_ids, + attention_mask=attention_mask, + position_ids=position_ids, + return_dict=False + ) + random_params = module_init_outputs["params"] + + self.config.initialization_of_moe = False + if params is not None: + random_params = flatten_dict(unfreeze(random_params)) + params = flatten_dict(unfreeze(params)) + for missing_key in self._missing_keys: + params[missing_key] = random_params[missing_key] + self._missing_keys = set() + return freeze(unflatten_dict(params)) + else: + return random_params + + def init_cache(self, batch_size, max_length): + + input_ids = jnp.ones((batch_size, max_length)) + attention_mask = jnp.ones_like(input_ids) + position_ids = jnp.broadcast_to(jnp.arange( + jnp.atleast_2d(input_ids).shape[-1]), input_ids.shape) + + init_variables = self.module.init( + jax.random.PRNGKey(0), input_ids, attention_mask, position_ids, return_dict=False, init_cache=True + ) + return init_variables["cache"] + + def __call__( + self, + input_ids: chex.Array, + attention_mask: Optional[chex.Array] = None, + position_ids: Optional[chex.Array] = None, + params: dict = None, + past_key_values: dict = None, + dropout_rng: jax.random.PRNGKey = None, + train: bool = False, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + output_router_logits: Optional[bool] = None, + return_dict: Optional[bool] = None, + add_params_field: bool = False, + **kwargs + ): + """The __call__ function is the main function of a JAX module. + It takes as input: + - The parameters of the model (self.params) + - The inputs to the model (input_ids, attention_mask, position_ids) + - Whether we are training (train=True/False) and whether we want to return all hidden states and + attentions weights at each layer in addition to just the last layer output (output_hidden_states=True/False). + + Args: + self: Represent the instance of the class + input_ids: Pass the input sequence to the model + attention_mask: Mask out the padding tokens + position_ids: Specify the position of each token in the + sequence + params: dict: Pass in the parameters of the model + past_key_values: dict: Pass the past key values to the model + dropout_rng: jax.random.PRNGKey: Pass in a random number + generator key to the model + train: bool: Determine whether to use dropout or not + output_attentions: Optional[bool]: Determine whether to + return the attention weights + output_hidden_states: Optional[bool]: Determine whether to + return the hidden states of all layers + return_dict: Optional[bool]: Return a dictionary of the + outputs + add_params_field: bool: Add a params field to the inputs + dictionary + + Returns: + A tuple of (last_hidden_state, past_key_values) + """ + + output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions + output_hidden_states = ( + output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states + ) + return_dict = return_dict if return_dict is not None else self.config.return_dict + + batch_size, sequence_length = input_ids.shape + + if position_ids is None: + if past_key_values is not None: + raise ValueError( + "Make sure to provide `position_ids` when passing `past_key_values`.") + + position_ids = jnp.broadcast_to(jnp.arange(sequence_length)[ + None, :], (batch_size, sequence_length)) + + if attention_mask is None: + attention_mask = jnp.ones((batch_size, sequence_length)) + + rng_s = {} + if dropout_rng is not None: + rng_s["dropout"] = dropout_rng + + inputs = { + "params": params or self.params} if add_params_field else params or self.params + + if self.config.bits is not None: + rng_s['params'] = jax.random.key(0) + if past_key_values: + inputs["cache"] = past_key_values + mutable = ["cache"] + else: + mutable = False + + outputs = self.module.apply( + inputs, + jnp.array(input_ids, dtype="i4"), # input_ids: chex.Array + # attention_mask: Optional[chex.Array] = None + jnp.array(attention_mask, dtype="i4"), + # position_ids: Optional[chex.Array] = None + jnp.array(position_ids, dtype="i4"), + None, # inputs_embeds: Optional[chex.Array] = None + output_attentions, # output_attentions: Optional[bool] = None + # output_hidden_states: Optional[bool] = None + output_hidden_states, + # output_router_logits: Optional[bool] = None + output_router_logits, + False, # init_cache: bool = False + not train, # deterministic: bool = True + return_dict, # return_dict: bool = True + rngs=rng_s, + mutable=mutable, + ) + + if past_key_values is not None and return_dict: + outputs, past_key_values = outputs + outputs["past_key_values"] = unfreeze(past_key_values["cache"]) + return outputs + elif past_key_values is not None and not return_dict: + outputs, past_key_values = outputs + outputs = outputs[:1] + \ + (unfreeze(past_key_values["cache"]),) + outputs[1:] + + return outputs + + +class FlaxGrok1Module(nn.Module): + config: Grok1Config + dtype: jnp.dtype = jnp.bfloat16 + param_dtype: jnp.dtype = jnp.bfloat16 + precision: Optional[jax.lax.Precision] = jax.lax.Precision("fastest") + + def setup(self) -> None: + self.embed_tokens = nn.Embed( + self.config.vocab_size, + self.config.hidden_size, + dtype=self.dtype, + param_dtype=self.param_dtype, + ) + + self.layers = FlaxGrok1DecoderLayerCollection( + config=self.config, + dtype=self.dtype, + param_dtype=self.param_dtype, + precision=self.precision + ) + + self.norm = FlaxGrok1RMSNorm( + self.config.hidden_size, + eps=self.config.rms_norm_eps, + dtype=self.dtype, + param_dtype=self.param_dtype + ) + + initial_rope_kwargs = dict( + rope_type="none" + ) + if self.config.rope_scaling is not None: + scaling_type = self.config.rope_scaling["type"] + scaling_factor = self.config.rope_scaling["factor"] + initial_rope_kwargs = dict( + scaling_factor=scaling_factor, + rope_type=scaling_type + ) + self.freq_cis = precompute_freq_cis( + max_position_embeddings=( + getattr(self.config, "freq_max_position_embeddings", self.config.max_position_embeddings) + ), + dim=self.config.hidden_size // self.config.num_attention_heads, + base=self.config.rope_theta, + **initial_rope_kwargs + ) + self.causal_mask = flax.linen.make_causal_mask( + jnp.ones( + (1, getattr(self.config, "c_max_position_embeddings", self.config.max_position_embeddings)), + dtype="bool" + ), dtype="bool" + ) + + def __call__( + self, + input_ids: chex.Array, + attention_mask: chex.Array, + position_ids: chex.Array, + inputs_embeds: Optional[chex.Array] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + output_router_logits: Optional[bool] = None, + init_cache: bool = False, + deterministic: bool = True, + return_dict: bool = True, + ) -> MoeModelOutput | Tuple: + if output_router_logits is None: + output_router_logits = self.config.output_router_logits + if input_ids is not None and inputs_embeds is not None: + raise ValueError( + "You cannot specify both decoder_input_ids and decoder_inputs_embeds at the same time") + + if inputs_embeds is None and input_ids is not None: + inputs_embeds = self.embed_tokens(input_ids.astype("i4")) + inputs_embeds = inputs_embeds * self.config.embedding_multiplier_scale + else: + raise ValueError( + "you should specify inputs_embeds or input_ids one of them") + output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions + output_router_logits = ( + output_router_logits if output_router_logits is not None else self.config.output_router_logits + ) + output_hidden_states = ( + output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states + ) + collection_outputs = self.layers( + hidden_states=inputs_embeds, + attention_mask=attention_mask, + position_ids=position_ids, + causal_mask=self.causal_mask, + freq_cis=self.freq_cis, + output_attentions=output_attentions, + output_router_logits=output_router_logits, + output_hidden_states=output_hidden_states, + init_cache=init_cache, + deterministic=deterministic, + ) + all_self_attns = None + all_hidden_states = None + all_router_logits = None + hidden_states = collection_outputs[0] + if output_attentions: + all_self_attns = collection_outputs[1] + if output_hidden_states: + all_hidden_states = collection_outputs[2 if output_attentions else 1] + if output_router_logits: + all_router_logits = collection_outputs[-1] + hidden_states = self.norm(hidden_states) + + if output_hidden_states: + all_hidden_states += (hidden_states,) + if not return_dict: + return tuple( + v + for v in [hidden_states, all_hidden_states, all_self_attns, all_router_logits] + if v is not None + ) + return MoeModelOutput( + last_hidden_state=hidden_states, + hidden_states=all_hidden_states, + attentions=all_self_attns, + router_logits=all_router_logits, + ) + + +class FlaxGrok1Model(Grok1PreTrainedModel): + module_class = FlaxGrok1Module + + +class FlaxGrok1ForCausalLMModule(nn.Module): + config: Grok1Config + dtype: jnp.dtype = jnp.bfloat16 + param_dtype: jnp.dtype = jnp.bfloat16 + precision: Optional[jax.lax.Precision] = jax.lax.Precision("fastest") + + def setup(self) -> None: + self.model = FlaxGrok1Module( + config=self.config, + dtype=self.dtype, + param_dtype=self.param_dtype, + precision=self.precision + ) + self.lm_head = Linear( + self.config.vocab_size, + dtype=self.dtype, + param_dtype=self.param_dtype, + precision=self.precision, + use_bias=False, + kernel_init=nn.initializers.normal(self.config.initializer_range), + **get_dot_general_by_bits(self.config.bits, self.config.easy_method) + ) + + self.output_multiplier_scale = self.config.output_multiplier_scale + + def __call__( + self, + input_ids: chex.Array, + attention_mask: Optional[chex.Array] = None, + position_ids: Optional[chex.Array] = None, + inputs_embeds: Optional[chex.Array] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + output_router_logits: Optional[bool] = None, + init_cache: bool = False, + deterministic: bool = True, + return_dict: bool = True, + ) -> MoeCausalLMOutput | Tuple: + if output_router_logits is None: + output_router_logits = self.config.output_router_logits + outputs = self.model( + input_ids=input_ids, + attention_mask=attention_mask, + position_ids=position_ids, + inputs_embeds=inputs_embeds, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + output_router_logits=output_router_logits, + init_cache=init_cache, + deterministic=deterministic, + return_dict=True, + ) + logits = self.lm_head(outputs.last_hidden_state) + logits = logits * self.output_multiplier_scale + batch_size, seq_length, hd = logits.shape + aux_loss = None + if output_router_logits and outputs.router_logits is not None: + aux_loss = auxiliary_load_balancing_loss_func( + gate_logits=tuple([logit.reshape(batch_size * seq_length, -1) for logit in outputs.router_logits]), + num_experts=self.num_experts, + top_k=self.num_experts_per_tok, + attention_mask=attention_mask + ) + aux_loss = aux_loss * self.config.router_aux_loss_coef + if not return_dict: + outputs = (logits,) + tuple( + v + for v in [ + aux_loss, + outputs.hidden_states, + outputs.attentions, + outputs.router_logits + ] + if v is not None + ) + return outputs + + return MoeCausalLMOutput( + aux_loss=aux_loss, + logits=logits, + hidden_states=outputs.hidden_states, + attentions=outputs.attentions, + router_logits=outputs.router_logits, + ) + + +class FlaxGrok1ForCausalLM(Grok1PreTrainedModel): + module_class = FlaxGrok1ForCausalLMModule + + def prepare_inputs_for_generation(self, input_ids, max_length, attention_mask: Optional[chex.Array] = None): + """ + The prepare_inputs_for_generation function is used to prepare the inputs for a generation task. + + :param self: Access variables that belong to the class + :param input_ids: Pass in the input tokens + :param max_length: Set the length of the sequence to be generated + :param attention_mask: Optional[chex.Array]: Mask the attention weights + :return: A dictionary of the past_key_values, attention_mask and position ids + + """ + batch_size, seq_length = input_ids.shape + + past_key_values = self.init_cache(batch_size, max_length) + extended_attention_mask = jnp.ones( + (batch_size, max_length), dtype="i4") + if attention_mask is not None: + position_ids = attention_mask.cumsum(axis=-1) - 1 + extended_attention_mask = lax.dynamic_update_slice( + extended_attention_mask, attention_mask, (0, 0)) + else: + position_ids = jnp.broadcast_to(jnp.arange(seq_length, dtype="i4")[ + None, :], (batch_size, seq_length)) + + return { + "past_key_values": past_key_values, + "attention_mask": extended_attention_mask, + "position_ids": position_ids, + } + + def update_inputs_for_generation(self, model_outputs, model_kwargs): + model_kwargs["past_key_values"] = model_outputs.past_key_values + model_kwargs["position_ids"] = model_kwargs["position_ids"][:, -1:] + 1 + return model_kwargs diff --git a/src/python/easydel/modules/jetmoe/jetmoe_configuration.py b/src/python/easydel/modules/jetmoe/jetmoe_configuration.py index 62c56391c..52d32feae 100644 --- a/src/python/easydel/modules/jetmoe/jetmoe_configuration.py +++ b/src/python/easydel/modules/jetmoe/jetmoe_configuration.py @@ -63,15 +63,17 @@ def __init__( ) def get_partition_rules(self, fully_sharded_data_parallel: bool = True): - """ - The get_partition_rules function is used to define the partitioning scheme for a model. + """The get_partition_rules function is used to define the partitioning scheme for a model. It returns a list of tuples, where each tuple contains two elements: 1) A regex string that matches the name of one or more parameters in the model. 2) A PartitionScheme object that defines how those parameters should be partitioned across devices. - :param fully_sharded_data_parallel: bool: Determine whether to partition the model fully or not - :return: A list of tuples + Args: + fully_sharded_data_parallel: bool: Determine whether to + partition the model fully or not + Returns: + A list of tuples """ return ( (".*", PartitionSpec(("fsdp", "sp"))), @@ -84,13 +86,16 @@ def add_jax_args( bits: Optional[int] = None, **kwargs, ): - """ - The add_jax_args function adds the following arguments to the Transformer class: + """The add_jax_args function adds the following arguments to the Transformer class: - :param self: Refer to the current object - :param tie_word_embeddings: bool: Tie the word embeddings to the decoder - :param gradient_checkpointing: str: Control the amount of memory used by jax - :param bits: Optional[int]: Determine the number of bits used in the quantization + Args: + self: Refer to the current object + tie_word_embeddings: bool: Tie the word embeddings to the + decoder + gradient_checkpointing: str: Control the amount of memory + used by jax + bits: Optional[int]: Determine the number of bits used in + the quantization """ self.tie_word_embeddings = tie_word_embeddings self.gradient_checkpointing = gradient_checkpointing diff --git a/src/python/easydel/modules/jetmoe/modelling_jetmoe_flax.py b/src/python/easydel/modules/jetmoe/modelling_jetmoe_flax.py index f68d5a7b0..52cea7b7f 100644 --- a/src/python/easydel/modules/jetmoe/modelling_jetmoe_flax.py +++ b/src/python/easydel/modules/jetmoe/modelling_jetmoe_flax.py @@ -69,9 +69,7 @@ def _make_sliding_window_causal_mask( past_key_values_length: int = 0, sliding_window: int = 4096, ): - """ - Make causal mask used for sliding window attention - """ + """Make causal mask used for sliding window attention""" bsz, tgt_len = input_ids_shape tensor = jnp.full( @@ -91,9 +89,7 @@ def _make_sliding_window_causal_mask( def compute_gating(k: int, num_experts: int, top_k_gates: jnp.ndarray, top_k_indices: jnp.ndarray) -> Tuple[ chex.Array, chex.Array, chex.Array, chex.Array ]: - """ - Compute gating values for the mixture of experts based on probabilities and top-k indices. - """ + """Compute gating values for the mixture of experts based on probabilities and top-k indices.""" zeros = jnp.zeros([top_k_gates.shape[0], num_experts], dtype=top_k_gates.dtype) gates = zeros.at[jnp.arange(zeros.shape[0])[:, None], top_k_indices].set(1) expert_size = gates.astype(jnp.int32).sum(axis=0) diff --git a/src/python/easydel/modules/llama/llama_configuration.py b/src/python/easydel/modules/llama/llama_configuration.py index e2aff48ab..5b8681002 100644 --- a/src/python/easydel/modules/llama/llama_configuration.py +++ b/src/python/easydel/modules/llama/llama_configuration.py @@ -40,47 +40,68 @@ def __init__( scan_layers: bool = False, **kwargs, ): - """ - The __init__ function is called when the class is instantiated. + """The __init__ function is called when the class is instantiated. It sets up the attributes of an object, which are sometimes called fields or properties. The __init__ function can accept arguments, but self must be the first one. - :param self: Refer to the object itself - :param vocab_size: int: Set the size of the vocabulary - :param hidden_size: int: Set the size of the hidden layers in each transformer block - :param intermediate_size: int: Set the size of the intermediate layer - :param num_hidden_layers: int: Determine the number of layers in the transformer - :param num_attention_heads: int: Determine the number of attention heads - :param number_rep_kv: int: Set the number of times to repeat the key and value vectors - :param num_key_value_heads: Optional[int]: Define the number of key-value heads - :param max_position_embeddings: int: Set the maximum length of a sequence - :param rms_norm_eps: float: Prevent division by zero in the rms normalization - :param initializer_range: float: Initialize the weights of the model - :param use_cache: bool: Determine whether the attention layer should use a cache for faster computation - :param bos_token_id: int: Set the beginning of sequence token - :param eos_token_id: int: Specify the end of sentence token - :param resid_pdrop: float: Set the dropout rate for residual connections - :param embd_pdrop: float: Dropout the embedding layer - :param attention_dropout: float: Dropout the attention weights - :param tie_word_embeddings: bool: Tie the word embeddings and output layer weights - :param gradient_checkpointing: str: Specify how to checkpoint the gradients - :param fcm_min_ratio: float: Set the minimum ratio of the number of elements in a tensor to be processed by flash - :param fcm_max_ratio: float: Determine the maximum ratio of - :param rope_scaling: Dict[str: Define the scaling of the rope - :param Union[str: Specify the type of the parameter - :param float]]: Specify the type of the parameter - :param shard_attention_computation: bool: when ever to use shard_map for attention - :param bits: Optional[int]: Specify the number of bits used to quantize the weights - :param rope_theta: float : rope_theta for compute rope - :param attention_bias: bool : whenever to use attention bias or no - :param hidden_act: str : hidden_act for mlp - :param axis_dims: Sequence[int]: Specify the dimensions of each axis - :param axis_names: Sequence[str]: Specify the names of the axes in a tensor - :param scan_layers: bool: Determine whether to use the scan_layers or not - :param kwargs: Pass a variable number of keyword arguments to a function + Args: + self: Refer to the object itself + vocab_size: int: Set the size of the vocabulary + hidden_size: int: Set the size of the hidden layers in each + transformer block + intermediate_size: int: Set the size of the intermediate + layer + num_hidden_layers: int: Determine the number of layers in + the transformer + num_attention_heads: int: Determine the number of attention + heads + number_rep_kv: int: Set the number of times to repeat the + key and value vectors + num_key_value_heads: Optional[int]: Define the number of + key-value heads + max_position_embeddings: int: Set the maximum length of a + sequence + rms_norm_eps: float: Prevent division by zero in the rms + normalization + initializer_range: float: Initialize the weights of the + model + use_cache: bool: Determine whether the attention layer + should use a cache for faster computation + bos_token_id: int: Set the beginning of sequence token + eos_token_id: int: Specify the end of sentence token + resid_pdrop: float: Set the dropout rate for residual + connections + embd_pdrop: float: Dropout the embedding layer + attention_dropout: float: Dropout the attention weights + tie_word_embeddings: bool: Tie the word embeddings and + output layer weights + gradient_checkpointing: str: Specify how to checkpoint the + gradients + fcm_min_ratio: float: Set the minimum ratio of the number of + elements in a tensor to be processed by flash + fcm_max_ratio: float: Determine the maximum ratio of + rope_scaling: Dict[str: Define the scaling of the rope + Union[str: Specify the type of the parameter + float]]: Specify the type of the parameter + shard_attention_computation: bool: when ever to use + shard_map for attention + bits: Optional[int]: Specify the number of bits used to + quantize the weights + rope_theta: float : rope_theta for compute rope + attention_bias: bool : whenever to use attention bias or no + hidden_act: str : hidden_act for mlp + axis_dims: Sequence[int]: Specify the dimensions of each + axis + axis_names: Sequence[str]: Specify the names of the axes in + a tensor + scan_layers: bool: Determine whether to use the scan_layers + or not + **kwargs: Pass a variable number of keyword arguments to a + function :param : Define the number of layers in the model - :return: Nothing + Returns: + Nothing """ num_key_value_heads = num_key_value_heads or number_rep_kv * num_attention_heads self.num_key_value_heads = num_key_value_heads @@ -118,15 +139,17 @@ def __init__( ) def get_partition_rules(self, fully_sharded_data_parallel: bool = True): - """ - The get_partition_rules function is used to define the partitioning scheme for a model. + """The get_partition_rules function is used to define the partitioning scheme for a model. It returns a list of tuples, where each tuple contains two elements: 1) A regex string that matches the name of one or more parameters in the model. 2) A PartitionScheme object that defines how those parameters should be partitioned across devices. - :param fully_sharded_data_parallel: bool: Determine whether to partition the model fully or not - :return: A list of tuples + Args: + fully_sharded_data_parallel: bool: Determine whether to + partition the model fully or not + Returns: + A list of tuples """ return ( @@ -181,23 +204,33 @@ def add_jax_args( scan_layers: bool = True, **kwargs, ): - """ - The add_jax_args function adds the following arguments to the Transformer class: - - :param self: Refer to the current object - :param resid_pdrop: float: Set the dropout rate for residual connections - :param embd_pdrop: float: Set the probability of dropping an embedding - :param attention_dropout: float: Set the probability of dropping out the attention layer - :param tie_word_embeddings: bool: Tie the word embeddings to the decoder - :param gradient_checkpointing: str: Control the amount of memory used by jax - :param fcm_min_ratio: float: Control the minimum ratio of the number of chunks to be used in flash-based computation - :param fcm_max_ratio: float: Set the maximum ratio of the number of input tokens to output tokens - :param number_rep_kv: int: Determine how many times the key and value vectors are repeated - :param bits: Optional[int]: Determine the number of bits used in the quantization - :param rope_theta: float : rope_theta for compute rope - :param attention_bias: bool : whenever to use attention bias or no - :param hidden_act: str : hidden_act for mlp - :param scan_layers: bool: Determine whether to use scan layers or not + """The add_jax_args function adds the following arguments to the Transformer class: + + Args: + self: Refer to the current object + resid_pdrop: float: Set the dropout rate for residual + connections + embd_pdrop: float: Set the probability of dropping an + embedding + attention_dropout: float: Set the probability of dropping + out the attention layer + tie_word_embeddings: bool: Tie the word embeddings to the + decoder + gradient_checkpointing: str: Control the amount of memory + used by jax + fcm_min_ratio: float: Control the minimum ratio of the + number of chunks to be used in flash-based computation + fcm_max_ratio: float: Set the maximum ratio of the number of + input tokens to output tokens + number_rep_kv: int: Determine how many times the key and + value vectors are repeated + bits: Optional[int]: Determine the number of bits used in + the quantization + rope_theta: float : rope_theta for compute rope + attention_bias: bool : whenever to use attention bias or no + hidden_act: str : hidden_act for mlp + scan_layers: bool: Determine whether to use scan layers or + not """ self.scan_layers = scan_layers self.embd_pdrop = embd_pdrop diff --git a/src/python/easydel/modules/llama/modelling_llama_flax.py b/src/python/easydel/modules/llama/modelling_llama_flax.py index 656108daf..e33596253 100644 --- a/src/python/easydel/modules/llama/modelling_llama_flax.py +++ b/src/python/easydel/modules/llama/modelling_llama_flax.py @@ -178,33 +178,37 @@ def _merge_heads(self, hidden_states): @staticmethod def _transpose_sequence_head(query, key, value): - """ - The _transpose_sequence_head function transposes the query, key and value matrices. + """The _transpose_sequence_head function transposes the query, key and value matrices. - :param query: Get the attention weights for each of the heads - :param key: Determine the number of heads - :param value: Store the values of the input - :return: The transpose of the query, key and value matrices + Args: + query: Get the attention weights for each of the heads + key: Determine the number of heads + value: Store the values of the input + Returns: + The transpose of the query, key and value matrices """ return jnp.transpose(query, (0, 2, 1, 3)), jnp.transpose(key, (0, 2, 1, 3)), jnp.transpose(value, (0, 2, 1, 3)) def apply_rotary(self, batch_size, sequence_length, query, key, value, freq_cis, position_ids): - """ - The apply_rotary function is a modified version of the apply_attention function in the BertModel class. + """The apply_rotary function is a modified version of the apply_attention function in the BertModel class. The main difference is that it takes in an additional argument, freq_cis, which are used to calculate the rotary attention weights. The other differences are minor and mostly related to reshaping tensors. - :param self: Access variables that belong to the class - :param batch_size: Reshape the query, key and value tensors - :param sequence_length: Reshape the query, key and value tensors - :param query: Calculate the attention weights - :param key: Calculate the attention - :param value: Compute the attention weights - :param freq_cis: Calculate the frequency of each word in the vocabulary - :param position_ids: Identify the position of each token in the sequence - :return: A tuple of 3 tensors: query, key and value - + Args: + self: Access variables that belong to the class + batch_size: Reshape the query, key and value tensors + sequence_length: Reshape the query, key and value tensors + query: Calculate the attention weights + key: Calculate the attention + value: Compute the attention weights + freq_cis: Calculate the frequency of each word in the + vocabulary + position_ids: Identify the position of each token in the + sequence + + Returns: + A tuple of 3 tensors: query, key and value """ query = query.reshape( batch_size, @@ -246,25 +250,32 @@ def __call__( output_attentions: bool = False, fcm_mask=None, ): - """ - - The __call__ function is the main function of a JAX module. It defines how the module behaves when called + """The __call__ function is the main function of a JAX module. It defines how the module behaves when called with inputs. The __call__ function can be thought of as a "forward pass" through the model, and it should return all outputs that are needed for training or inference. - :param self: Access variables that belong to the class - :param hidden_states: chex.Array: Pass the hidden states of the previous layer - :param freq_cis: Tuple[chex.Array, chex.Array],: Pass in the frequency coefficients for each position - :param attention_mask: chex.Array: Mask out certain tokens in the input sequence - :param position_ids: chex.Array: Determine the position of each token in a sequence - :param causal_mask: chex.Array: Mask out the future tokens in the decoder - :param deterministic: bool: Determine whether to use dropout or not - :param init_cache: bool: Initialize the cache - :param output_attentions: bool: Determine whether to return the attention weights or not - :param fcm_mask: Mask out the attention weights between the input and output tokens + Args: + self: Access variables that belong to the class + hidden_states: chex.Array: Pass the hidden states of the + previous layer + freq_cis: Tuple[chex.Array, chex.Array],: Pass in the + frequency coefficients for each position + attention_mask: chex.Array: Mask out certain tokens in the + input sequence + position_ids: chex.Array: Determine the position of each + token in a sequence + causal_mask: chex.Array: Mask out the future tokens in the + decoder + deterministic: bool: Determine whether to use dropout or not + init_cache: bool: Initialize the cache + output_attentions: bool: Determine whether to return the + attention weights or not + fcm_mask: Mask out the attention weights between the input + and output tokens :param : Determine if the attention is causal or not - :return: A tuple of two arrays + Returns: + A tuple of two arrays """ batch_size, sequence_length = hidden_states.shape[:2] ( @@ -428,16 +439,18 @@ def setup(self) -> None: self.dropout = flax.linen.Dropout(rate=self.config.resid_pdrop) def __call__(self, x: jnp.ndarray, deterministic: bool = True) -> jnp.ndarray: - """ - The __call__ function is the main function of a class. + """The __call__ function is the main function of a class. It is called when an instance of the class (an object) is invoked as a function, i.e., obj(arguments). The __call__ method enables instances of a class to be called like standard Python functions. - :param self: Represent the instance of the class - :param x: jnp.ndarray: Pass in the input to the layer - :param deterministic: bool: Determine whether to use dropout - :return: A tensor that is the result of applying a dropout function to x + Args: + self: Represent the instance of the class + x: jnp.ndarray: Pass in the input to the layer + deterministic: bool: Determine whether to use dropout + Returns: + A tensor that is the result of applying a dropout function + to x """ x = self.down_proj(jax.nn.silu(self.gate_proj(x)) * self.up_proj(x)) x = self.dropout(x, deterministic=deterministic) @@ -508,25 +521,32 @@ def __call__( output_attentions: bool = False, fcm_mask: Optional[jnp.ndarray] = None, ): - """ - The __call__ function is the main function of a TransformerEncoderLayer. + """The __call__ function is the main function of a TransformerEncoderLayer. It takes in hidden states, frequency-domain inputs, and masks as input. It then applies self-attention to the hidden states using those inputs and returns an output tensor with shape (batch_size, sequence_length, model_dim). - :param self: Refer to the class instance itself - :param hidden_states: chex.Array: Pass in the hidden state of the previous layer - :param freq_cis: Tuple[chex.Array, chex.Array],: Pass in the frequency information - :param attention_mask: chex.Array: Mask out the attention weights for padding tokens - :param position_ids: chex.Array: Determine the position of each token in the sequence - :param causal_mask: chex.Array: Mask the attention weights - :param deterministic: bool: Control whether the dropout is applied or not - :param init_cache: bool: Initialize the cache in the attention layer - :param output_attentions: bool: Return the attention weights - :param fcm_mask: Optional[jnp.ndarray]: Mask the self-attention + Args: + self: Refer to the class instance itself + hidden_states: chex.Array: Pass in the hidden state of the + previous layer + freq_cis: Tuple[chex.Array, chex.Array],: Pass in the + frequency information + attention_mask: chex.Array: Mask out the attention weights + for padding tokens + position_ids: chex.Array: Determine the position of each + token in the sequence + causal_mask: chex.Array: Mask the attention weights + deterministic: bool: Control whether the dropout is applied + or not + init_cache: bool: Initialize the cache in the attention + layer + output_attentions: bool: Return the attention weights + fcm_mask: Optional[jnp.ndarray]: Mask the self-attention :param : Control the dropout in the self attention layer - :return: A tuple of two items + Returns: + A tuple of two items """ attn_outputs = self.self_attn( self.input_layernorm(hidden_states), @@ -577,37 +597,42 @@ def __init__( _do_init: bool = True, **kwargs, ): - """ - The __init__ function is called when the class is instantiated. + """The __init__ function is called when the class is instantiated. It sets up the instance of the class, and defines what happens when it's created. The __init__ function can take arguments, but self is always required (it refers to the instance of the object). - - :param self: Refer to the object itself - :param config: LlamaConfig: Pass the configuration to the module - :param input_shape: Tuple: Specify the shape of the input to the model - :param seed: int: Set the seed for random number generation - :param dtype: jnp.dtype: Specify the data type of the input - :param _do_init: bool: Control whether the module is initialized or not - :param kwargs: Pass in any additional parameters that the module_class might need + Args: + self: Refer to the object itself + config: LlamaConfig: Pass the configuration to the module + input_shape: Tuple: Specify the shape of the input to the + model + seed: int: Set the seed for random number generation + dtype: jnp.dtype: Specify the data type of the input + _do_init: bool: Control whether the module is initialized or + not + **kwargs: Pass in any additional parameters that the + module_class might need :param : Specify the number of layers in the network - :return: The super() of the class + Returns: + The super() of the class """ module = self.module_class(config=config, dtype=dtype, **kwargs) super().__init__(config, module, input_shape=input_shape, seed=seed, dtype=dtype, _do_init=_do_init) def init_weights(self, rng: jax.random.PRNGKey, input_shape: Tuple, params: FrozenDict = None) -> FrozenDict: - """ - The init_weights function is used to initialize the weights of a model. + """The init_weights function is used to initialize the weights of a model. - :param self: Access variables that belong to the class - :param rng: jax.random.PRNGKey: Initialize the weights of the model - :param input_shape: Tuple: Specify the shape of the input tensor - :param params: FrozenDict: Pass in the parameters of a pre-trained model - :return: A frozendict of parameters + Args: + self: Access variables that belong to the class + rng: jax.random.PRNGKey: Initialize the weights of the model + input_shape: Tuple: Specify the shape of the input tensor + params: FrozenDict: Pass in the parameters of a pre-trained + model + Returns: + A frozendict of parameters """ input_ids = jnp.zeros(input_shape, dtype="i4") attention_mask = jnp.ones_like(input_ids) @@ -646,17 +671,18 @@ def init_weights(self, rng: jax.random.PRNGKey, input_shape: Tuple, params: Froz return random_params def init_cache(self, batch_size, max_length): - """ - The init_cache function is used to initialize the cache for a given batch size and sequence length. + """The init_cache function is used to initialize the cache for a given batch size and sequence length. The cache is a dictionary that contains all the intermediate states from each layer in the model. This allows us to run inference on multiple batches without having to re-run forward passes through every layer in the model, which would be very slow. - :param self: Access the module - :param batch_size: Define the batch size of the input tensors - :param max_length: Set the length of the input sequence - :return: A dictionary with the following keys: + Args: + self: Access the module + batch_size: Define the batch size of the input tensors + max_length: Set the length of the input sequence + Returns: + A dictionary with the following keys: """ input_ids = jnp.ones((batch_size, max_length)) attention_mask = jnp.ones_like(input_ids) @@ -684,27 +710,36 @@ def __call__( add_params_field: bool = False, **kwargs ): - """ - The __call__ function is the main function of a JAX module. + """The __call__ function is the main function of a JAX module. It takes in inputs and returns outputs, but it also has some other important features: - It can take in mutable state (e.g., past_key_values) that will be updated during the call and returned at the end. - It can take in random number generators (rngs) that are used to generate random numbers for dropout or sampling operations. - :param self: Represent the instance of the class - :param input_ids: chex.Array: Pass in the input tokens - :param attention_mask: chex.Array: Mask out certain tokens in the input - :param position_ids: chex.Array: Create the positional embeddings - :param params: dict: Pass in the parameters of the model - :param past_key_values: dict: Pass in the past key values from a previous call to __call__ - :param dropout_rng: jax.random.PRNGKey: Make sure that the dropout is applied in a random way - :param train: bool: Determine whether to use dropout or not - :param output_attentions: Optional[bool]: Determine whether to return the attention weights - :param output_hidden_states: Optional[bool]: Return the hidden states of all layers - :param return_dict: Optional[bool]: Determine whether to return a dictionary or not - :param extra_embedding: Optional[Union[jnp.ndarray,None]]: Pass in the embedding for the input_ids - :param add_params_field: bool: Add the params field to the inputs dictionary - :return: A tuple of the following: - + Args: + self: Represent the instance of the class + input_ids: chex.Array: Pass in the input tokens + attention_mask: chex.Array: Mask out certain tokens in the + input + position_ids: chex.Array: Create the positional embeddings + params: dict: Pass in the parameters of the model + past_key_values: dict: Pass in the past key values from a + previous call to __call__ + dropout_rng: jax.random.PRNGKey: Make sure that the dropout + is applied in a random way + train: bool: Determine whether to use dropout or not + output_attentions: Optional[bool]: Determine whether to + return the attention weights + output_hidden_states: Optional[bool]: Return the hidden + states of all layers + return_dict: Optional[bool]: Determine whether to return a + dictionary or not + extra_embedding: Optional[Union[jnp.ndarray,None]]: Pass in + the embedding for the input_ids + add_params_field: bool: Add the params field to the inputs + dictionary + + Returns: + A tuple of the following: """ output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions output_hidden_states = ( @@ -802,27 +837,35 @@ def __call__( output_hidden_states: bool = False, return_dict: bool = True, ): - """ - The __call__ function is the main function of a JAX nn.Module. + """The __call__ function is the main function of a JAX nn.Module. It defines how the module behaves when called as a function, and it's what you'll use to call your model in training loops or inference scripts. The __call__ method should take all inputs that are necessary for computing outputs from the module, and return all outputs that are computed by this module. - :param self: Represent the instance of the class - :param hidden_states: chex.Array: Pass the input tensor to the encoder - :param freq_cis: Tuple[chex.Array, chex.Array],: Pass in the frequency of each token - :param attention_mask: chex.Array: Mask out certain tokens in the input sequence - :param position_ids: chex.Array: Specify the position of each token in a sequence - :param causal_mask: chex.Array: Mask the attention weights - :param deterministic: bool: Determine whether the model is in training or evaluation mode - :param init_cache: bool: Initialize the cache for each layer - :param output_attentions: bool: Determine whether to output the attention weights - :param output_hidden_states: bool: Determine whether to return the hidden states of each layer - :param return_dict: bool: Return a dictionary of the outputs + Args: + self: Represent the instance of the class + hidden_states: chex.Array: Pass the input tensor to the + encoder + freq_cis: Tuple[chex.Array, chex.Array],: Pass in the + frequency of each token + attention_mask: chex.Array: Mask out certain tokens in the + input sequence + position_ids: chex.Array: Specify the position of each token + in a sequence + causal_mask: chex.Array: Mask the attention weights + deterministic: bool: Determine whether the model is in + training or evaluation mode + init_cache: bool: Initialize the cache for each layer + output_attentions: bool: Determine whether to output the + attention weights + output_hidden_states: bool: Determine whether to return the + hidden states of each layer + return_dict: bool: Return a dictionary of the outputs :param : Determine whether to use the forgetful causal mask - :return: A tuple of 3 values + Returns: + A tuple of 3 values """ all_attentions = () if output_attentions else None all_hidden_states = () if output_hidden_states else None @@ -930,26 +973,33 @@ def __call__( return_dict: bool = True, extra_embedding: Optional[Union[jnp.ndarray, None]] = None ): - """ - The __call__ function is the main function of a Flax model. It takes in input_ids, attention_mask, and position_ids + """The __call__ function is the main function of a Flax model. It takes in input_ids, attention_mask, and position_ids and returns the output of the model. The __call__ function also has optional arguments that can be used to control the behavior of the model (e.g., deterministic=True). These optional arguments are passed as keyword arguments when calling a Flax model. - :param self: Represent the instance of the class - :param input_ids: chex.Array: Pass in the input token ids - :param attention_mask: chex.Array: Mask out the padding tokens - :param position_ids: chex.Array: Indicate the position of each token in a sequence - :param deterministic: bool: Control whether dropout is applied or not - :param inputs_embeds: chex.Array: Pass in the embeddings of the input tokens - :param init_cache: bool: Initialize the cache - :param output_attentions: bool: Determine whether to return the attentions or not - :param output_hidden_states: bool: Determine whether to return hidden states - :param return_dict: bool: Return a dictionary of the output or not - :param extra_embedding: Optional[Union[jnp.ndarray: Pass in the embedding of the - :param None]]: Pass in the extra embedding - :return: A tuple of: - + Args: + self: Represent the instance of the class + input_ids: chex.Array: Pass in the input token ids + attention_mask: chex.Array: Mask out the padding tokens + position_ids: chex.Array: Indicate the position of each + token in a sequence + deterministic: bool: Control whether dropout is applied or + not + inputs_embeds: chex.Array: Pass in the embeddings of the + input tokens + init_cache: bool: Initialize the cache + output_attentions: bool: Determine whether to return the + attentions or not + output_hidden_states: bool: Determine whether to return + hidden states + return_dict: bool: Return a dictionary of the output or not + extra_embedding: Optional[Union[jnp.ndarray: Pass in the + embedding of the + None]]: Pass in the extra embedding + + Returns: + A tuple of: """ if inputs_embeds is None: inputs_embeds = self.embed_tokens(input_ids.astype("i4")) @@ -1041,22 +1091,27 @@ def __call__( return_dict: bool = True, extra_embedding: Optional[Union[jnp.ndarray, None]] = None ): - """ - The __call__ function is the main function of a Flax module. It takes in inputs and returns outputs. - - :param self: Refer to the object itself - :param input_ids: chex.Array: Pass the input token ids to the model - :param attention_mask: chex.Array: Mask out the padding tokens - :param position_ids: chex.Array: Specify the position of each token in the input sequence - :param deterministic: bool: Control whether the model is trained or not - :param init_cache: bool: Initialize the cache for the decoder - :param output_attentions: bool: Return the attention weights - :param output_hidden_states: bool: Determine whether to return the hidden states - :param return_dict: bool: Return a dictionary of the outputs or not - :param extra_embedding: Optional[Union[jnp.ndarray: Pass in the embedding of the word that we want to predict - :param None]]: Pass in the extra embedding - :return: The logits and the hidden states - + """The __call__ function is the main function of a Flax module. It takes in inputs and returns outputs. + + Args: + self: Refer to the object itself + input_ids: chex.Array: Pass the input token ids to the model + attention_mask: chex.Array: Mask out the padding tokens + position_ids: chex.Array: Specify the position of each token + in the input sequence + deterministic: bool: Control whether the model is trained or + not + init_cache: bool: Initialize the cache for the decoder + output_attentions: bool: Return the attention weights + output_hidden_states: bool: Determine whether to return the + hidden states + return_dict: bool: Return a dictionary of the outputs or not + extra_embedding: Optional[Union[jnp.ndarray: Pass in the + embedding of the word that we want to predict + None]]: Pass in the extra embedding + + Returns: + The logits and the hidden states """ batch_size, seq_length = input_ids.shape if attention_mask is None: @@ -1118,15 +1173,18 @@ def set_output_embeddings(self, new_embeddings): self.module.lm_head = new_embeddings def prepare_inputs_for_generation(self, input_ids, max_length, attention_mask: Optional[chex.Array] = None): - """ - The prepare_inputs_for_generation function is used to prepare the inputs for a generation task. - - :param self: Access variables that belong to the class - :param input_ids: Pass in the input tokens - :param max_length: Set the length of the sequence to be generated - :param attention_mask: Optional[chex.Array]: Mask the attention weights - :return: A dictionary of the past_key_values, attention_mask and position ids - + """The prepare_inputs_for_generation function is used to prepare the inputs for a generation task. + + Args: + self: Access variables that belong to the class + input_ids: Pass in the input tokens + max_length: Set the length of the sequence to be generated + attention_mask: Optional[chex.Array]: Mask the attention + weights + + Returns: + A dictionary of the past_key_values, attention_mask and + position ids """ batch_size, seq_length = input_ids.shape @@ -1161,12 +1219,14 @@ class FlaxLlamaForSequenceClassificationModule(nn.Module): precision: Optional[Union[jax.lax.Precision, str]] = None def setup(self): - """ - The setup function is called once at the beginning of training. + """The setup function is called once at the beginning of training. It initializes the model and optimizer, and sets up any other state that needs to be initialized. - :param self: Access variables that belong to the class - :return: A tuple of the model and the classifier + Args: + self: Access variables that belong to the class + + Returns: + A tuple of the model and the classifier """ self.model = FlaxLlamaModule(self.config, dtype=self.dtype) self.classifier = Linear( @@ -1191,26 +1251,31 @@ def __call__( return_dict: bool = True, extra_embedding: Optional[Union[jnp.ndarray, None]] = None ): - """ - The __call__ function is the main function of a Flax module. + """The __call__ function is the main function of a Flax module. It takes in all the inputs to the model and returns all outputs from it. The __call__ function can be called directly on an instance of a class, or by using parentheses after an instance: >>> my_model = MyModel() # instantiate your model class >>> output = my_model(input) # call your model with input data as arguments to __call__ - :param self: Refer to the class instance - :param input_ids: chex.Array: Pass the input to the model - :param attention_mask: chex.Array: Specify which tokens are masked - :param position_ids: chex.Array: Specify the position of each token in the sequence - :param deterministic: bool: Control whether the model is run in deterministic or stochastic mode - :param init_cache: bool: Initialize the cache for the transformer - :param output_attentions: bool: Return the attention weights - :param output_hidden_states: bool: Return the hidden states of all layers - :param return_dict: bool: Return a dictionary of outputs - :param extra_embedding: Optional[Union[jnp.ndarray: Pass in the embedding of a new word - :param None]]: Pass the extra embedding to the model - :return: A tuple of logits and hidden_states - + Args: + self: Refer to the class instance + input_ids: chex.Array: Pass the input to the model + attention_mask: chex.Array: Specify which tokens are masked + position_ids: chex.Array: Specify the position of each token + in the sequence + deterministic: bool: Control whether the model is run in + deterministic or stochastic mode + init_cache: bool: Initialize the cache for the transformer + output_attentions: bool: Return the attention weights + output_hidden_states: bool: Return the hidden states of all + layers + return_dict: bool: Return a dictionary of outputs + extra_embedding: Optional[Union[jnp.ndarray: Pass in the + embedding of a new word + None]]: Pass the extra embedding to the model + + Returns: + A tuple of logits and hidden_states """ batch_size, seq_length = input_ids.shape if attention_mask is None: diff --git a/src/python/easydel/modules/llama/vision_llama_configuration.py b/src/python/easydel/modules/llama/vision_llama_configuration.py index e508896b6..0fa1383be 100644 --- a/src/python/easydel/modules/llama/vision_llama_configuration.py +++ b/src/python/easydel/modules/llama/vision_llama_configuration.py @@ -16,15 +16,17 @@ def __init__( self.sample_mode = sample_mode def get_partition_rules(self, fully_sharded_data_parallel: bool = True): - """ - The get_partition_rules function is used to define the partitioning scheme for a model. + """The get_partition_rules function is used to define the partitioning scheme for a model. It returns a list of tuples, where each tuple contains two elements: 1) A regex string that matches the name of one or more parameters in the model. 2) A PartitionScheme object that defines how those parameters should be partitioned across devices. - :param fully_sharded_data_parallel: bool: Determine whether to partition the model fully or not - :return: A list of tuples + Args: + fully_sharded_data_parallel: bool: Determine whether to + partition the model fully or not + Returns: + A list of tuples """ return ( diff --git a/src/python/easydel/modules/mamba/__init__.py b/src/python/easydel/modules/mamba/__init__.py index e7fc61c12..36b179eaa 100644 --- a/src/python/easydel/modules/mamba/__init__.py +++ b/src/python/easydel/modules/mamba/__init__.py @@ -1,8 +1,8 @@ -from .mamba_configuration import MambaConfig -from .modelling_mamba_flax import ( - FlaxMambaModule, - FlaxMambaCache, - FlaxMambaForCausalLMModule, - FlaxMambaForCausalLM, - FlaxMambaModel -) +from .mamba_configuration import MambaConfig +from .modelling_mamba_flax import ( + FlaxMambaModule, + FlaxMambaCache, + FlaxMambaForCausalLMModule, + FlaxMambaForCausalLM, + FlaxMambaModel +) diff --git a/src/python/easydel/modules/mamba/mamba_configuration.py b/src/python/easydel/modules/mamba/mamba_configuration.py index 076bd9d23..c6ad3dedb 100644 --- a/src/python/easydel/modules/mamba/mamba_configuration.py +++ b/src/python/easydel/modules/mamba/mamba_configuration.py @@ -1,72 +1,72 @@ -import math - -from ..easydel_modelling_utils import EasyDeLPretrainedConfig -from typing import Optional - - -class MambaConfig(EasyDeLPretrainedConfig): - model_type: str = "mamba" - - def __init__( - self, - vocab_size=50280, - hidden_size=768, - state_size=16, - num_hidden_layers=32, - layer_norm_epsilon=1e-5, - pad_token_id=0, - bos_token_id=0, - eos_token_id=0, - expand=2, - conv_kernel=4, - use_bias=False, - use_conv_bias=True, - hidden_act="silu", - initializer_range=0.1, - residual_in_fp32=True, - time_step_rank="auto", - time_step_scale=1.0, - time_step_min=0.001, - time_step_max=0.1, - time_step_init_scheme="random", - time_step_floor=1e-4, - rescale_prenorm_residual=False, - use_cache=True, - gradient_checkpointing: str = "nothing_saveable", - **kwargs - ): - self.vocab_size = vocab_size - self.hidden_size = hidden_size - self.state_size = state_size - self.num_hidden_layers = num_hidden_layers - self.layer_norm_epsilon = layer_norm_epsilon - self.conv_kernel = conv_kernel - self.expand = expand - self.intermediate_size = int(expand * self.hidden_size) - self.bos_token_id = bos_token_id - self.eos_token_id = eos_token_id - self.pad_token_id = pad_token_id - self.use_bias = use_bias - self.use_conv_bias = use_conv_bias - self.hidden_act = hidden_act - self.initializer_range = initializer_range - self.time_step_rank = math.ceil(self.hidden_size / 16) if time_step_rank == "auto" else time_step_rank - self.time_step_scale = time_step_scale - self.time_step_min = time_step_min - self.time_step_max = time_step_max - self.time_step_init_scheme = time_step_init_scheme - self.time_step_floor = time_step_floor - self.rescale_prenorm_residual = rescale_prenorm_residual - self.residual_in_fp32 = residual_in_fp32 - self.use_cache = use_cache - self.gradient_checkpointing = gradient_checkpointing - super().__init__(**kwargs) - - def get_partition_rules(self, fully_sharded_data_parallel: bool = True): - return super().get_partition_rules(fully_sharded_data_parallel=fully_sharded_data_parallel) - - def add_jax_args( - self, - gradient_checkpointing: str = "nothing_saveable" - ): - self.gradient_checkpointing = gradient_checkpointing +import math + +from ..easydel_modelling_utils import EasyDeLPretrainedConfig +from typing import Optional + + +class MambaConfig(EasyDeLPretrainedConfig): + model_type: str = "mamba" + + def __init__( + self, + vocab_size=50280, + hidden_size=768, + state_size=16, + num_hidden_layers=32, + layer_norm_epsilon=1e-5, + pad_token_id=0, + bos_token_id=0, + eos_token_id=0, + expand=2, + conv_kernel=4, + use_bias=False, + use_conv_bias=True, + hidden_act="silu", + initializer_range=0.1, + residual_in_fp32=True, + time_step_rank="auto", + time_step_scale=1.0, + time_step_min=0.001, + time_step_max=0.1, + time_step_init_scheme="random", + time_step_floor=1e-4, + rescale_prenorm_residual=False, + use_cache=True, + gradient_checkpointing: str = "nothing_saveable", + **kwargs + ): + self.vocab_size = vocab_size + self.hidden_size = hidden_size + self.state_size = state_size + self.num_hidden_layers = num_hidden_layers + self.layer_norm_epsilon = layer_norm_epsilon + self.conv_kernel = conv_kernel + self.expand = expand + self.intermediate_size = int(expand * self.hidden_size) + self.bos_token_id = bos_token_id + self.eos_token_id = eos_token_id + self.pad_token_id = pad_token_id + self.use_bias = use_bias + self.use_conv_bias = use_conv_bias + self.hidden_act = hidden_act + self.initializer_range = initializer_range + self.time_step_rank = math.ceil(self.hidden_size / 16) if time_step_rank == "auto" else time_step_rank + self.time_step_scale = time_step_scale + self.time_step_min = time_step_min + self.time_step_max = time_step_max + self.time_step_init_scheme = time_step_init_scheme + self.time_step_floor = time_step_floor + self.rescale_prenorm_residual = rescale_prenorm_residual + self.residual_in_fp32 = residual_in_fp32 + self.use_cache = use_cache + self.gradient_checkpointing = gradient_checkpointing + super().__init__(**kwargs) + + def get_partition_rules(self, fully_sharded_data_parallel: bool = True): + return super().get_partition_rules(fully_sharded_data_parallel=fully_sharded_data_parallel) + + def add_jax_args( + self, + gradient_checkpointing: str = "nothing_saveable" + ): + self.gradient_checkpointing = gradient_checkpointing diff --git a/src/python/easydel/modules/mamba/modelling_mamba_flax.py b/src/python/easydel/modules/mamba/modelling_mamba_flax.py index da59cc2ed..ad6f915ea 100644 --- a/src/python/easydel/modules/mamba/modelling_mamba_flax.py +++ b/src/python/easydel/modules/mamba/modelling_mamba_flax.py @@ -1,1065 +1,1080 @@ -import functools -import itertools -import math -from typing import Optional, Tuple, Union, List, Dict, Any, Callable, Sequence, TypeVar - -import fjformer -from jax.core import ShapedArray -import chex -from fjformer import linen as nn -import jax -import jax.numpy as jnp -from fjformer.linen import Linear -import numpy as np -from chex import PRNGKey, Shape, Array -from einops import einsum -from flax.core.frozen_dict import FrozenDict, freeze, unfreeze -from flax.linen import partitioning as nn_partitioning -from flax.linen.dtypes import promote_dtype -from flax.linen.linear import ( - default_kernel_init, - ConvGeneralDilatedT, - PrecisionLike, - Dtype, - PaddingLike, - canonicalize_padding, - _conv_dimension_numbers -) -from flax.traverse_util import flatten_dict, unflatten_dict -from jax import lax, eval_shape -import flax.struct -from transformers.modeling_flax_outputs import FlaxBaseModelOutput - -from .mamba_configuration import MambaConfig -from ..easydel_modelling_utils import EasyDeLFlaxPretrainedModel -from ..flax_modelling_utils import ( - get_gradient_checkpoint_policy, - get_dot_general_by_bits, - ACT2FN -) - - -def init_to_value(x, dtype): - return lambda _: x.astype(dtype) - - -@flax.struct.dataclass -class MambaOutput(FlaxBaseModelOutput): - last_hidden_state: chex.Array = None - cache_params: Optional[List[chex.Array]] = None - hidden_states: Optional[Tuple[chex.Array]] = None - - -@flax.struct.dataclass -class MambaCausalLMOutput(FlaxBaseModelOutput): - logits: chex.Array = None - cache_params: Optional[List[chex.Array]] = None - hidden_states: Optional[Tuple[chex.Array]] = None - - -class FlaxMambaCache: - def __init__( - self, - config: MambaConfig, - batch_size: int, - dtype=jnp.float16, - ): - self.seqlen_offset = 0 - self.dtype = dtype - intermediate_size = config.intermediate_size - ssm_state_size = config.state_size - conv_kernel_size = config.conv_kernel - - self.conv_states = { - i: jnp.zeros((batch_size, intermediate_size, conv_kernel_size), dtype=dtype) - for i in range(config.num_hidden_layers) - } - self.ssm_states = { - i: jnp.zeros((batch_size, intermediate_size, ssm_state_size), dtype=dtype) - for i in range(config.num_hidden_layers) - } - - -_T = TypeVar("_T") - - -def create_tuple_parser(n: int) -> Callable[[Union[_T, Sequence[_T]]], tuple[_T, ...]]: - def parse(x: Union[_T, Sequence[_T]]) -> tuple[_T, ...]: - if isinstance(x, Sequence): - if len(x) == n: - return tuple(x) - else: - raise ValueError( - f"x!=n ({x}!=({n}))" - ) - else: - return tuple(itertools.repeat(x, n)) - - return parse - - -class Conv1D(nn.Module): - features: int - kernel_size: int = 1 - stride: int = 1 - padding: int = 0 - dilation: int = 1 - groups: int = 1 - use_bias: bool = True - num_spatial_dims: int = 1 - dtype: jnp.dtype = jnp.float32 - param_dtype: jnp.dtype = jnp.float32 - precision: Optional[Union[str, lax.Precision]] = None - - @nn.compact - def __call__(self, x): - - kernel = self.param( - "kernel", - nn.initializers.lecun_normal(dtype=self.param_dtype), - (self.features, 1, self.kernel_size), - self.param_dtype - ) - unbatched_rank = self.num_spatial_dims + 2 - if x.ndim != unbatched_rank: - raise ValueError( - f"Input to `Conv` needs to have rank {unbatched_rank}," - f" but input has shape {x.shape}.", - ) - - # def rava_run(input_rava): - # input_rava = jnp.expand_dims(input_rava, 0) - # input_rava = lax.conv_general_dilated( - # lhs=input_rava, - # rhs=jnp.asarray(kernel, dtype=self.dtype), - # window_strides=stride, - # padding=padding, - # rhs_dilation=dilation, - # feature_group_count=self.groups, - # ) - # input_rava = jnp.squeeze(input_rava, axis=0) - # if self.use_bias: - # bias = self.param( - # "bias", - # nn.initializers.zeros_init(), - # (self.features,) + (1,) * self.num_spatial_dims, - # self.param_dtype - # ) - # input_rava = input_rava + jnp.asarray(bias, dtype=self.dtype) - # return input_rava - - # return nn.vmap( - # rava_run, - # in_axes=0, - # out_axes=0, - # variable_axes={"params": 0}, - # split_rngs={"params": False} - # )(x) - - # x = jnp.expand_dims(x, 0) - x = lax.conv_general_dilated( - lhs=x, - rhs=jnp.asarray(kernel, dtype=self.dtype), - window_strides=(self.stride,), - padding=((self.padding, self.padding),), - rhs_dilation=(self.dilation,), - feature_group_count=self.groups, - ) - if self.use_bias: - bias = self.param( - "bias", - nn.initializers.zeros_init(), - (self.features,), - self.param_dtype - ) - x = x + jnp.asarray(bias.reshape(1, -1, 1), dtype=self.dtype) - return x - - -def mamba_ssm( - u: jax.Array, - delta: jax.Array, - A: jax.Array, - B: jax.Array, - C: jax.Array, - D: Optional[jax.Array] = None, - delta_bias: Optional[jax.Array] = None, - delta_softplus: bool = False, - associative_scan: bool = True, -) -> jax.Array: - if delta_bias is not None: - raise NotImplementedError("delta_bias not implemented yet.") - - l, d_in = u.shape - n = A.shape[1] - - delta = jnp.asarray(delta, dtype=jnp.float32) - - if delta_softplus: - delta = jax.nn.softplus(delta) - - delta_A = jnp.exp(einsum(delta, A, "l d_in, d_in n -> l d_in n")) - delta_B_u = einsum(delta, B, u, "l d_in, l n, l d_in -> l d_in n") - - x = jnp.zeros((d_in, n)) - - def _scan_fn(x, params): - d_A, d_Bu, C = params - - x = d_A * x + d_Bu - return x, einsum(x, C, "d_in n, n -> d_in") - - def _associative_scan_fn(s, c): - return tuple((c[0] * s[0], c[0] * s[1] + c[1])) - - if associative_scan: - _, y = jax.lax.associative_scan(_associative_scan_fn, (delta_A, delta_B_u)) - y = einsum(y, C, "L d_in n, L n -> L d_in") - else: - _, y = jax.lax.scan(_scan_fn, init=x, xs=[delta_A, delta_B_u, C]) - - y = y + u * D - return y - - -class MambaRMSNorm(nn.Module): - dim: int - eps: float = 1e-6 - dtype: jnp.dtype = jnp.float32 - param_dtype: jnp.dtype = jnp.float32 - - def setup(self) -> None: - self.weight = self.param( - 'kernel', - nn.initializers.ones, - (self.dim,), - self.param_dtype, - ) - - def _norm(self, x: jnp.ndarray) -> jnp.ndarray: - return x * jax.lax.rsqrt(jnp.square(x).mean(-1, keepdims=True) + self.eps) - - def __call__(self, x: jnp.ndarray) -> jnp.ndarray: - x = x.astype(jnp.promote_types(self.dtype, jnp.float32)) - output = self._norm(x).astype(self.dtype) - weight = jnp.asarray(fjformer.linen.linen.control_quantization(self.weight, self.dtype)) - return output * weight - - -class Conv(nn.Module): - features: int - kernel_size: Sequence[int] - strides: Union[None, int, Sequence[int]] = 1 - padding: PaddingLike = "SAME" - input_dilation: Union[None, int, Sequence[int]] = 1 - kernel_dilation: Union[None, int, Sequence[int]] = 1 - feature_group_count: int = 1 - use_bias: bool = True - mask: Optional[Array] = None - dtype: Optional[Dtype] = None - param_dtype: Dtype = jnp.float32 - precision: PrecisionLike = None - kernel_init: Callable[[PRNGKey, Shape, Dtype], Array] = default_kernel_init - bias_init: Callable[[PRNGKey, Shape, Dtype], Array] = nn.initializers.zeros_init() - conv_general_dilated: Optional[ConvGeneralDilatedT] = None - conv_general_dilated_cls: Any = None - - @property - def shared_weights(self) -> bool: - return True - - @nn.compact - def __call__(self, inputs: Array) -> Array: - - if isinstance(self.kernel_size, int): - raise TypeError( - 'Expected Conv kernel_size to be a' - ' tuple/list of integers (eg.: [3, 3]) but got' - f' {self.kernel_size}.' - ) - else: - kernel_size = tuple(self.kernel_size) - - def maybe_broadcast( - x: Optional[Union[int, Sequence[int]]] - ) -> Tuple[int, ...]: - if x is None: - # backward compatibility with using None as sentinel for - # broadcast 1 - x = 1 - if isinstance(x, int): - return (x,) * len(kernel_size) - return tuple(x) - - # Combine all input batch dimensions into a single leading batch axis. - num_batch_dimensions = inputs.ndim - (len(kernel_size) + 1) - if num_batch_dimensions != 1: - input_batch_shape = inputs.shape[:num_batch_dimensions] - total_batch_size = int(np.prod(input_batch_shape)) - flat_input_shape = (total_batch_size,) + inputs.shape[ - num_batch_dimensions: - ] - inputs = jnp.reshape(inputs, flat_input_shape) - - # self.strides or (1,) * (inputs.ndim - 2) - strides = maybe_broadcast(self.strides) - input_dilation = maybe_broadcast(self.input_dilation) - kernel_dilation = maybe_broadcast(self.kernel_dilation) - - padding_lax = canonicalize_padding(self.padding, len(kernel_size)) - if padding_lax == 'CIRCULAR': - kernel_size_dilated = [ - (k - 1) * d + 1 for k, d in zip(kernel_size, kernel_dilation) - ] - zero_pad: List[Tuple[int, int]] = [(0, 0)] - pads = ( - zero_pad - + [((k - 1) // 2, k // 2) for k in kernel_size_dilated] - + [(0, 0)] - ) - inputs = jnp.pad(inputs, pads, mode='wrap') - padding_lax = 'VALID' - elif padding_lax == 'CAUSAL': - if len(kernel_size) != 1: - raise ValueError( - 'Causal padding is only implemented for 1D convolutions.' - ) - left_pad = kernel_dilation[0] * (kernel_size[0] - 1) - pads = [(0, 0), (left_pad, 0), (0, 0)] - inputs = jnp.pad(inputs, pads) - padding_lax = 'VALID' - - dimension_numbers = _conv_dimension_numbers(inputs.shape) - in_features = jnp.shape(inputs)[-1] - - if self.shared_weights: - # One shared convolutional kernel for all pixels in the output. - - inf_f = in_features // self.feature_group_count - # inf_f = 1 - kernel_shape = (self.features, inf_f,) + kernel_size - - else: - if self.feature_group_count != 1: - raise NotImplementedError( - '`lax.conv_general_dilated_local` does not support ' - f'`feature_group_count != 1`, got `{self.feature_group_count}`.' - ) - - # Need to know the spatial output shape of a standard convolution to - # create the unshared convolution kernel. - if self.conv_general_dilated_cls is not None: - conv_general_dilated = self.conv_general_dilated_cls() - elif self.conv_general_dilated is not None: - conv_general_dilated = self.conv_general_dilated - else: - conv_general_dilated = lax.conv_general_dilated - conv_output_shape = eval_shape( - lambda lhs, rhs: conv_general_dilated( # pylint: disable=g-long-lambda - lhs=lhs, - rhs=rhs, - window_strides=strides, - padding=padding_lax, - dimension_numbers=dimension_numbers, - lhs_dilation=input_dilation, - rhs_dilation=kernel_dilation, - ), - inputs, - ShapedArray(kernel_size + (in_features, self.features), inputs.dtype), - ).shape - - # One (unshared) convolutional kernel per each pixel in the output. - kernel_shape = conv_output_shape[1:-1] + ( - np.prod(kernel_size) * in_features, - self.features, - ) - - if self.mask is not None and self.mask.shape != kernel_shape: - raise ValueError( - 'Mask needs to have the same shape as weights. ' - f'Shapes are: {self.mask.shape}, {kernel_shape}' - ) - - kernel = self.param( - 'kernel', self.kernel_init, kernel_shape, self.param_dtype - ) - kernel = jnp.asarray(kernel.transpose(2, 1, 0), self.dtype) - if self.mask is not None: - kernel *= self.mask - - if self.use_bias: - if self.shared_weights: - bias_shape = (self.features,) - else: - bias_shape = conv_output_shape[1:] - - bias = self.param('bias', self.bias_init, bias_shape, self.param_dtype) - else: - bias = None - - inputs, kernel, bias = promote_dtype(inputs, kernel, bias, dtype=self.dtype) - if self.shared_weights: - if self.conv_general_dilated_cls is not None: - conv_general_dilated = self.conv_general_dilated_cls() - elif self.conv_general_dilated is not None: - conv_general_dilated = self.conv_general_dilated - else: - conv_general_dilated = lax.conv_general_dilated - y = conv_general_dilated( - lhs=inputs, - rhs=kernel, - window_strides=strides, - padding=padding_lax, - lhs_dilation=input_dilation, - rhs_dilation=kernel_dilation, - dimension_numbers=dimension_numbers, - feature_group_count=self.feature_group_count, - precision=self.precision, - ) - else: - y = lax.conv_general_dilated_local( - lhs=inputs, - rhs=kernel, - window_strides=strides, - padding=padding_lax, - filter_shape=kernel_size, - lhs_dilation=input_dilation, - rhs_dilation=kernel_dilation, - dimension_numbers=dimension_numbers, - precision=self.precision, - ) - - if self.use_bias: - bias = bias.reshape((1,) * (y.ndim - bias.ndim) + bias.shape) - y += bias - - if num_batch_dimensions != 1: - output_shape = input_batch_shape + y.shape[1:] - y = jnp.reshape(y, output_shape) - return y - - -class FlaxMambaMixer(nn.Module): - config: MambaConfig - layer_idx: int - dtype: jnp.dtype = jnp.float32 - param_dtype: jnp.dtype = jnp.float32 - precision: Optional[Union[str, lax.Precision]] = None - - def setup(self) -> None: - config = self.config - hidden_size = config.hidden_size - ssm_state_size = config.state_size - intermediate_size = config.intermediate_size - time_step_rank = config.time_step_rank - - self.conv1d = Conv1D( - # features=config.intermediate_size, - # kernel_size=(config.conv_kernel,), - # feature_group_count=config.intermediate_size, - # padding="SAME", - # strides=(1,), - # dtype=self.dtype, - # param_dtype=self.param_dtype, - # precision=self.precision, - # use_bias=config.use_conv_bias, - # # ---- # # - features=config.intermediate_size, - kernel_size=config.conv_kernel, - groups=config.intermediate_size, - stride=1, - padding=config.conv_kernel - 1, - use_bias=config.use_conv_bias, - dtype=self.dtype, - param_dtype=self.param_dtype, - precision=self.precision, - ) - - self.activation = config.hidden_act - self.act = ACT2FN[config.hidden_act] - - dt_init_std = self.config.time_step_rank ** -0.5 * self.config.time_step_scale - if self.config.time_step_init_scheme == "constant": - init_kernel_dt = nn.initializers.constant(dt_init_std, dtype=self.param_dtype) - elif self.config.time_step_init_scheme == "random": - # def init_kernel_dt(): - def init_kernel_dt(key, _shape, _dtype): - return jax.nn.initializers.uniform( - scale=dt_init_std * 2, dtype=self.param_dtype - )(key, _shape, _dtype) - dt_init_std - - # return init_r - else: - init_kernel_dt = nn.initializers.normal(self.config.initializer_range, self.param_dtype) - - dt = jax.lax.clamp( - self.config.time_step_floor, - jnp.exp( - jax.random.normal( - key=self.make_rng("params"), - shape=(self.config.intermediate_size,), - dtype=self.param_dtype - ) - * (jnp.log(self.config.time_step_max) - jnp.log(self.config.time_step_min)) - + jnp.log(self.config.time_step_min) - ), - self.config.time_step_max - ) - inv_dt = dt + jnp.log(-jnp.expm1(-dt)) - - dense_class = functools.partial( - Linear, - dtype=self.dtype, - param_dtype=self.param_dtype, - precision=self.precision, - **get_dot_general_by_bits( - self.config.bits, - self.config.easy_method - ) - ) - self.in_proj = dense_class( - intermediate_size * 2, - use_bias=config.use_bias - ) - self.x_proj = dense_class( - time_step_rank + ssm_state_size * 2, - use_bias=False - ) - self.dt_proj = dense_class( - intermediate_size, - use_bias=True, - kernel_init=init_kernel_dt, - bias_init=lambda s1, s2, s3: inv_dt - ) - self.out_proj = dense_class( - hidden_size, - use_bias=config.use_bias - ) - - self.A_log = self.param( - "A_log", - init_to_value( - jnp.log( - jnp.broadcast_to( - jnp.arange(1, ssm_state_size + 1, dtype=jnp.float32)[None, :], - (intermediate_size, ssm_state_size) - ) - ), - self.dtype - ) - ) - self.D = self.param( - "D", init_to_value( - jnp.ones( - intermediate_size - ), - self.dtype - ) - ) - self.ssm_state_size = ssm_state_size - self.intermediate_size = intermediate_size - self.conv_kernel_size = self.config.conv_kernel - self.time_step_rank = self.config.time_step_rank - - def __call__(self, input_states, cache_params=None): - batch_size, seq_len, _ = input_states.shape - dtype = input_states.dtype - projected_states = self.in_proj(input_states).transpose(0, 2, 1) - hidden_states, gate = jnp.split(projected_states, 2, axis=1) - - # 2. Convolution sequence transformation - if cache_params is not None: - ssm_state = cache_params.ssm_states[self.layer_idx] - if cache_params.seqlen_offset > 0: - conv_state = cache_params.conv_states[self.layer_idx] # [batch, intermediate_size, conv_kernel_size] - conv_state = jnp.roll(conv_state, shift=-1, axis=-1) - conv_state = conv_state.at[:, :, -1].set(hidden_states[:, :, 0]) - cache_params.conv_states[self.layer_idx].copy_(conv_state) - hidden_states = jnp.sum(conv_state * self.conv1d.variables["kernel"][:, 0, :], axis=-1) - if self.use_conv_bias: - hidden_states += self.conv1d.variables["bias"] - hidden_states = jnp.expand_dims(self.act(hidden_states).astype(dtype), -1) - # [batch, intermediate_size, 1] : decoding - else: - padding_amount = self.conv_kernel_size - hidden_states.shape[-1] - conv_state = jnp.pad(hidden_states, ((0, 0), (0, padding_amount)), mode='constant') - cache_params.conv_states[self.layer_idx].copy_(conv_state) - hidden_states = self.act(self.conv1d(hidden_states)[..., :seq_len]) - # [batch, intermediate_size, seq_len] - else: - ssm_state = jnp.zeros( - (batch_size, self.intermediate_size, self.ssm_state_size), dtype=dtype - ) - hidden_states = self.act(self.conv1d(hidden_states)[..., :seq_len]) - - ssm_parameters = self.x_proj(hidden_states.transpose(0, 2, 1)) - time_step, B, C = jnp.split( - ssm_parameters, - indices_or_sections=[self.time_step_rank, self.time_step_rank + self.ssm_state_size], - axis=-1 - ) - discrete_time_step = self.dt_proj(time_step) - # [batch, seq_len, intermediate_size] - discrete_time_step = jax.nn.softplus(discrete_time_step).transpose(0, 2, 1) - # [batch, intermediate_size, seq_len] - A = -jnp.exp(self.A_log.astype(jnp.float32)) - # [intermediate_size, ssm_state_size] - modified_a = jnp.expand_dims(jnp.expand_dims(A, axis=0), axis=2) - modified_time_step = jnp.expand_dims(discrete_time_step, axis=-1) - discrete_A = jnp.exp(modified_a * modified_time_step) - # [batch, intermediate_size, seq_len, ssm_state_size] - - discrete_B = modified_time_step * B[:, jnp.newaxis, :, :].astype(jnp.float32) - # [batch, intermediate_size, seq_len, ssm_state_size] - - deltaB_u = discrete_B * hidden_states[:, :, :, jnp.newaxis].astype(jnp.float32) - - # 3.c perform the recurrence y ← SSM(A, B, C)(x) - scan_outputs = [] - for i in range(seq_len): - ssm_state = discrete_A[:, :, i, :] * ssm_state + deltaB_u[:, :, i, :] - # [batch, intermediate_size, ssm_state] - - scan_output = jax.lax.batch_matmul(ssm_state.astype(dtype), jnp.expand_dims(C[:, i, :], -1)) - # [batch, intermediate_size, 1] - - scan_outputs.append(scan_output[:, :, 0]) - - scan_output = jnp.stack(scan_outputs, axis=-1) - # [batch, seq_len, intermediate_size] - scan_output = scan_output + (hidden_states * self.D[jnp.newaxis, :, jnp.newaxis]) - scan_output = (scan_output * self.act(gate)) - - if cache_params is not None: - cache_params.ssm_states[self.layer_idx].copy_(ssm_state) - - # 4. Final linear projection - contextualized_states = self.out_proj(scan_output.transpose(0, 2, 1)) - # [batch, seq_len, hidden_size] - return contextualized_states - - -class FlaxMambaBlock(nn.Module): - config: MambaConfig - layer_idx: int - dtype: jnp.dtype = jnp.float32 - param_dtype: jnp.dtype = jnp.float32 - precision: Optional[Union[str, lax.Precision]] = None - - def setup(self): - config = self.config - self.residual_in_fp32 = config.residual_in_fp32 - self.norm = MambaRMSNorm( - config.hidden_size, - eps=config.layer_norm_epsilon, - dtype=self.dtype, - param_dtype=self.param_dtype - ) - block = FlaxMambaMixer - if self.config.gradient_checkpointing != "": - block = nn_partitioning.remat( - block, - static_argnums=(1,), - policy=get_gradient_checkpoint_policy(self.config.gradient_checkpointing) - ) - self.mixer = block( - config, - self.layer_idx, - self.dtype, - self.param_dtype, - self.precision - ) - - def __call__( - self, - hidden_states: chex.Array, - cache_params: Optional[FlaxMambaCache] = None - ) -> chex.Array: - residual = hidden_states - hidden_states = self.norm( - hidden_states - ) - if self.residual_in_fp32: - residual = residual.astype(jnp.float32) - hidden_states = self.mixer( - hidden_states, - cache_params - ) - hidden_states = residual + hidden_states - return hidden_states - - -class FlaxMambaLayerCollection(nn.Module): - config: MambaConfig - dtype: jnp.dtype = jnp.float32 - param_dtype: jnp.dtype = jnp.float32 - precision: Optional[Union[str, lax.Precision]] = None - - def setup(self) -> None: - self.blocks = [ - FlaxMambaBlock( - config=self.config, - layer_idx=layer_idx, - dtype=self.dtype, - param_dtype=self.param_dtype, - precision=self.precision, - name=str(layer_idx) - ) - for layer_idx in range(self.config.num_hidden_layers) - ] - - def __call__( - self, - hidden_states: chex.Array, - cache_params: Optional[FlaxMambaCache] = None, - output_hidden_states: bool = False - ) -> Tuple[chex.Array, Union[None, Tuple[chex.Array, ...]]]: - all_hidden_states = () if output_hidden_states else None - for block in self.blocks: - hidden_states = block( - hidden_states, - cache_params - ) - - if output_hidden_states: - all_hidden_states = all_hidden_states + (hidden_states,) - - return hidden_states, all_hidden_states - - -class FlaxMambaModule(nn.Module): - config: MambaConfig - dtype: jnp.dtype = jnp.float32 - param_dtype: jnp.dtype = jnp.float32 - precision: Optional[Union[str, lax.Precision]] = None - - def setup(self) -> None: - config = self.config - self.embeddings = nn.Embed( - config.vocab_size, - config.hidden_size, - dtype=self.dtype, - param_dtype=self.param_dtype, - ) - self.layers = FlaxMambaLayerCollection( - config=config, - dtype=self.dtype, - param_dtype=self.param_dtype, - precision=self.precision - ) - - self.norm_f = MambaRMSNorm( - config.hidden_size, - eps=config.layer_norm_epsilon, - dtype=self.dtype, - param_dtype=self.param_dtype, - ) - - def __call__( - self, - input_ids: Optional[chex.Array] = None, - inputs_embeds: Optional[chex.Array] = None, - cache_params: Optional[chex.Array] = None, - deterministic: bool = True, - use_cache: Optional[bool] = None, - output_hidden_states: Optional[bool] = None, - return_dict: Optional[bool] = None, - **kwargs, - ) -> Union[Tuple, MambaOutput]: - - output_hidden_states = ( - output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states - ) - use_cache = use_cache if use_cache is not None else (self.config.use_cache if not deterministic else False) - return_dict = return_dict if return_dict is not None else self.config.use_return_dict - - if (input_ids is None) ^ (inputs_embeds is not None): - raise ValueError( - "You cannot specify both input_ids and inputs_embeds at the same time, and must specify either one" - ) - - if inputs_embeds is None: - inputs_embeds = self.embeddings(input_ids) - - if deterministic and use_cache: - use_cache = False - - if cache_params is None and use_cache: - cache_params = FlaxMambaCache( - self.config, inputs_embeds.shape[0], dtype=inputs_embeds.dtype - ) - - hidden_states = inputs_embeds - hidden_states, all_hidden_states = self.layers(hidden_states, cache_params=cache_params) - - if use_cache: - cache_params.seqlen_offset += inputs_embeds.shape[1] - - hidden_states = self.norm_f(hidden_states) - - if output_hidden_states: - all_hidden_states = all_hidden_states + (hidden_states,) - - if not return_dict: - return tuple(v for v in [hidden_states, cache_params, all_hidden_states] if v is not None) - - return MambaOutput( - last_hidden_state=hidden_states, - cache_params=cache_params if use_cache else None, - hidden_states=all_hidden_states, - ) - - -class FlaxMambaForCausalLMModule(nn.Module): - config: MambaConfig - dtype: jnp.dtype = jnp.float32 - param_dtype: jnp.dtype = jnp.float32 - precision: Optional[Union[str, lax.Precision]] = None - - def setup(self) -> None: - self.backbone = FlaxMambaModule( - self.config, - self.dtype, - self.param_dtype, - self.precision - ) - self.lm_head = Linear( - self.config.vocab_size, - use_bias=False, - dtype=self.dtype, - param_dtype=self.param_dtype, - precision=self.precision - ) - - def __call__( - self, - input_ids: Optional[chex.Array] = None, - inputs_embeds: Optional[chex.Array] = None, - cache_params: Optional[chex.Array] = None, - deterministic: bool = True, - use_cache: Optional[bool] = None, - output_hidden_states: Optional[bool] = None, - return_dict: Optional[bool] = None, - **kwargs, - ) -> Union[Tuple, MambaCausalLMOutput]: - return_dict = return_dict if return_dict is not None else self.config.use_return_dict - - # input_ids: Optional[chex.Array] = None, - # inputs_embeds: Optional[chex.Array] = None, - # deterministic: bool = True, - # cache_params: Optional[List[chex.Array]] = None, - # use_cache: Optional[bool] = None, - # output_hidden_states: Optional[bool] = None, - # return_dict: Optional[bool] = None, - - mamba_outputs = self.backbone( - input_ids=input_ids, - inputs_embeds=inputs_embeds, - deterministic=deterministic, - cache_params=cache_params, - use_cache=use_cache, - output_hidden_states=output_hidden_states, - return_dict=return_dict, - ) - hidden_states = mamba_outputs[0] - - logits = self.lm_head(hidden_states).astype(jnp.float32) - - if not return_dict: - return (logits,) + mamba_outputs[1:] - - return MambaCausalLMOutput( - logits=logits, - cache_params=mamba_outputs.cache_params, - hidden_states=mamba_outputs.hidden_states, - ) - - -class FlaxMambaPretrainedModel(EasyDeLFlaxPretrainedModel): - config_class = MambaConfig - base_model_prefix = "backbone" - module_class: nn.Module = None - - def __init__( - self, - config: MambaConfig, - input_shape: Tuple = (1, 1), - seed: int = 0, - dtype: jnp.dtype = jnp.float32, - param_dtype: jnp.dtype = jnp.float32, - precision: Optional[Union[str, lax.Precision]] = None, - _do_init: bool = True, - **kwargs, - ): - """ - The __init__ function is called when the class is instantiated. - It sets up the instance of the class, and defines what happens when it's created. - The __init__ function can take arguments, but self is always required (it refers to the instance of the object). - - - :param self: Refer to the object itself - :param config: MambaConfig: Pass the configuration to the module - :param input_shape: Tuple: Specify the shape of the input to the model - :param seed: int: Set the seed for random number generation - :param dtype: jnp.dtype: Specify the data type of the model ra - :param param_dtype: jnp.dtype: Specify the data type of the param_dtype - :param precision: Optional[Union[str, lax.Precision]]: precision for model operations - :param _do_init: bool: Control whether the module is initialized or not - :param kwargs: Pass in any additional parameters that the module_class might need - :param : Specify the number of layers in the network - :return: The super() of the class - - """ - module = self.module_class( - config=config, - dtype=dtype, - param_dtype=param_dtype, - precision=precision, - **kwargs - ) - super().__init__( - config, - module, - input_shape=(input_shape[0], 1), - seed=seed, - dtype=dtype, - _do_init=_do_init - ) - - def init_weights(self, rng: jax.random.PRNGKey, input_shape: Tuple, params: FrozenDict = None) -> FrozenDict: - """ - The init_weights function is used to initialize the weights of a model. - - :param self: Access variables that belong to the class - :param rng: jax.random.PRNGKey: Initialize the weights of the model - :param input_shape: Tuple: Specify the shape of the input tensor - :param params: FrozenDict: Pass in the parameters of a pre-trained model - :return: A frozendict of parameters - - """ - input_ids = jnp.zeros(input_shape, dtype="i4") - params_rng, dropout_rng = jax.random.split(rng) - rngs = {"params": params_rng, "dropout": dropout_rng} - - module_init_outputs = self.module.init( - rngs, - input_ids, - return_dict=False - ) - - random_params = module_init_outputs["params"] - - if params is not None: - random_params = flatten_dict(unfreeze(random_params)) - params = flatten_dict(unfreeze(params)) - for missing_key in self._missing_keys: - params[missing_key] = random_params[missing_key] - self._missing_keys = set() - return freeze(unflatten_dict(params)) - else: - return random_params - - def init_cache(self, batch_size, max_length): - return None - - def __call__( - self, - input_ids: Optional[chex.Array] = None, - inputs_embeds: Optional[chex.Array] = None, - cache_params: dict = None, - deterministic: bool = True, - params: dict = None, - dropout_rng: jax.random.PRNGKey = None, - train: bool = False, - output_hidden_states: Optional[bool] = None, - return_dict: Optional[bool] = None, - extra_embedding: Optional[Union[jnp.ndarray, None]] = None, - add_params_field: bool = False, - attention_mask: Optional[chex.Array] = None, # Ignored(we are using an SSM model not attention) - use_cache: bool = False, - **kwargs - ): - """ - The __call__ function is the main function of a JAX module. - - :param self: Represent the instance of the class - :param input_ids: Optional[chex.Array]: Pass in the input tokens - :param inputs_embeds: Optional[chex.Array]: Pass in the embedded tokens - :param cache_params: dict: Pass in the past cache_params from a previous call to __call__ - :param params: dict: Pass in the parameters of the model - :param dropout_rng: jax.random.PRNGKey: Make sure that the dropout is applied in a random way - :param train: bool: Determine whether to use dropout or not - :param output_hidden_states: Optional[bool]: Return the hidden states of all layers - :param return_dict: Optional[bool]: Determine whether to return a dictionary or not - :param extra_embedding: Optional[Union[jnp.ndarray,None]]: Pass in the embedding for the input_ids - :param add_params_field: bool: Add the params field to the inputs dictionary - :return: A tuple of the following: - - """ - output_hidden_states = ( - output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states - ) - return_dict = return_dict if return_dict is not None else self.config.return_dict - - batch_size, sequence_length = input_ids.shape - - assert sequence_length <= self.config.max_position_embeddings, "Maximum Position Embedding Reached !" - if cache_params is not None: - assert isinstance(cache_params, FlaxMambaCache), f"Wrong cache input_type of {type(cache_params)}" - rngs = {} - if dropout_rng is not None: - rngs["dropout"] = dropout_rng - - rngs["params"] = jax.random.key(0) - - inputs = { - "params": params or self.params - } if add_params_field else params or self.params - - # input_ids: Optional[chex.Array] = None, - # inputs_embeds: Optional[chex.Array] = None, - # cache_params: Optional[chex.Array] = None, - # deterministic: bool = True, - # use_cache: Optional[bool] = None, - # output_hidden_states: Optional[bool] = None, - # return_dict: Optional[bool] = None, - - return self.module.apply( - inputs, - input_ids, - inputs_embeds, - cache_params, - train, - use_cache, - output_hidden_states, - return_dict, - rngs=rngs, - mutable=False, - ) - - -class FlaxMambaModel(FlaxMambaPretrainedModel): - module_class = FlaxMambaModule - - -class FlaxMambaForCausalLM(FlaxMambaPretrainedModel): - module_class = FlaxMambaForCausalLMModule - - def update_inputs_for_generation( - self, - outputs: MambaOutput, - model_kwargs: Dict[str, Any], - **kwargs - ) -> Dict[str, Any]: - model_kwargs["cache_params"] = outputs.get("cache_params", None) - return model_kwargs - - def prepare_inputs_for_generation( - self, - input_ids, - max_length, - **kwargs - ): - return { - "cache_params": kwargs.get("cache_params", None) - } +import functools +import itertools +import math +from typing import Optional, Tuple, Union, List, Dict, Any, Callable, Sequence, TypeVar + +import fjformer +from jax.core import ShapedArray +import chex +from fjformer import linen as nn +import jax +import jax.numpy as jnp +from fjformer.linen import Linear +import numpy as np +from chex import PRNGKey, Shape, Array +from einops import einsum +from flax.core.frozen_dict import FrozenDict, freeze, unfreeze +from flax.linen import partitioning as nn_partitioning +from flax.linen.dtypes import promote_dtype +from flax.linen.linear import ( + default_kernel_init, + ConvGeneralDilatedT, + PrecisionLike, + Dtype, + PaddingLike, + canonicalize_padding, + _conv_dimension_numbers +) +from flax.traverse_util import flatten_dict, unflatten_dict +from jax import lax, eval_shape +import flax.struct +from transformers.modeling_flax_outputs import FlaxBaseModelOutput + +from .mamba_configuration import MambaConfig +from ..easydel_modelling_utils import EasyDeLFlaxPretrainedModel +from ..flax_modelling_utils import ( + get_gradient_checkpoint_policy, + get_dot_general_by_bits, + ACT2FN +) + + +def init_to_value(x, dtype): + return lambda _: x.astype(dtype) + + +@flax.struct.dataclass +class MambaOutput(FlaxBaseModelOutput): + last_hidden_state: chex.Array = None + cache_params: Optional[List[chex.Array]] = None + hidden_states: Optional[Tuple[chex.Array]] = None + + +@flax.struct.dataclass +class MambaCausalLMOutput(FlaxBaseModelOutput): + logits: chex.Array = None + cache_params: Optional[List[chex.Array]] = None + hidden_states: Optional[Tuple[chex.Array]] = None + + +class FlaxMambaCache: + def __init__( + self, + config: MambaConfig, + batch_size: int, + dtype=jnp.float16, + ): + self.seqlen_offset = 0 + self.dtype = dtype + intermediate_size = config.intermediate_size + ssm_state_size = config.state_size + conv_kernel_size = config.conv_kernel + + self.conv_states = { + i: jnp.zeros((batch_size, intermediate_size, conv_kernel_size), dtype=dtype) + for i in range(config.num_hidden_layers) + } + self.ssm_states = { + i: jnp.zeros((batch_size, intermediate_size, ssm_state_size), dtype=dtype) + for i in range(config.num_hidden_layers) + } + + +_T = TypeVar("_T") + + +def create_tuple_parser(n: int) -> Callable[[Union[_T, Sequence[_T]]], tuple[_T, ...]]: + def parse(x: Union[_T, Sequence[_T]]) -> tuple[_T, ...]: + if isinstance(x, Sequence): + if len(x) == n: + return tuple(x) + else: + raise ValueError( + f"x!=n ({x}!=({n}))" + ) + else: + return tuple(itertools.repeat(x, n)) + + return parse + + +class Conv1D(nn.Module): + features: int + kernel_size: int = 1 + stride: int = 1 + padding: int = 0 + dilation: int = 1 + groups: int = 1 + use_bias: bool = True + num_spatial_dims: int = 1 + dtype: jnp.dtype = jnp.float32 + param_dtype: jnp.dtype = jnp.float32 + precision: Optional[Union[str, lax.Precision]] = None + + @nn.compact + def __call__(self, x): + + kernel = self.param( + "kernel", + nn.initializers.lecun_normal(dtype=self.param_dtype), + (self.features, 1, self.kernel_size), + self.param_dtype + ) + unbatched_rank = self.num_spatial_dims + 2 + if x.ndim != unbatched_rank: + raise ValueError( + f"Input to `Conv` needs to have rank {unbatched_rank}," + f" but input has shape {x.shape}.", + ) + + # def rava_run(input_rava): + # input_rava = jnp.expand_dims(input_rava, 0) + # input_rava = lax.conv_general_dilated( + # lhs=input_rava, + # rhs=jnp.asarray(kernel, dtype=self.dtype), + # window_strides=stride, + # padding=padding, + # rhs_dilation=dilation, + # feature_group_count=self.groups, + # ) + # input_rava = jnp.squeeze(input_rava, axis=0) + # if self.use_bias: + # bias = self.param( + # "bias", + # nn.initializers.zeros_init(), + # (self.features,) + (1,) * self.num_spatial_dims, + # self.param_dtype + # ) + # input_rava = input_rava + jnp.asarray(bias, dtype=self.dtype) + # return input_rava + + # return nn.vmap( + # rava_run, + # in_axes=0, + # out_axes=0, + # variable_axes={"params": 0}, + # split_rngs={"params": False} + # )(x) + + # x = jnp.expand_dims(x, 0) + x = lax.conv_general_dilated( + lhs=x, + rhs=jnp.asarray(kernel, dtype=self.dtype), + window_strides=(self.stride,), + padding=((self.padding, self.padding),), + rhs_dilation=(self.dilation,), + feature_group_count=self.groups, + ) + if self.use_bias: + bias = self.param( + "bias", + nn.initializers.zeros_init(), + (self.features,), + self.param_dtype + ) + x = x + jnp.asarray(bias.reshape(1, -1, 1), dtype=self.dtype) + return x + + +def mamba_ssm( + u: jax.Array, + delta: jax.Array, + A: jax.Array, + B: jax.Array, + C: jax.Array, + D: Optional[jax.Array] = None, + delta_bias: Optional[jax.Array] = None, + delta_softplus: bool = False, + associative_scan: bool = True, +) -> jax.Array: + if delta_bias is not None: + raise NotImplementedError("delta_bias not implemented yet.") + + l, d_in = u.shape + n = A.shape[1] + + delta = jnp.asarray(delta, dtype=jnp.float32) + + if delta_softplus: + delta = jax.nn.softplus(delta) + + delta_A = jnp.exp(einsum(delta, A, "l d_in, d_in n -> l d_in n")) + delta_B_u = einsum(delta, B, u, "l d_in, l n, l d_in -> l d_in n") + + x = jnp.zeros((d_in, n)) + + def _scan_fn(x, params): + d_A, d_Bu, C = params + + x = d_A * x + d_Bu + return x, einsum(x, C, "d_in n, n -> d_in") + + def _associative_scan_fn(s, c): + return tuple((c[0] * s[0], c[0] * s[1] + c[1])) + + if associative_scan: + _, y = jax.lax.associative_scan(_associative_scan_fn, (delta_A, delta_B_u)) + y = einsum(y, C, "L d_in n, L n -> L d_in") + else: + _, y = jax.lax.scan(_scan_fn, init=x, xs=[delta_A, delta_B_u, C]) + + y = y + u * D + return y + + +class MambaRMSNorm(nn.Module): + dim: int + eps: float = 1e-6 + dtype: jnp.dtype = jnp.float32 + param_dtype: jnp.dtype = jnp.float32 + + def setup(self) -> None: + self.weight = self.param( + 'kernel', + nn.initializers.ones, + (self.dim,), + self.param_dtype, + ) + + def _norm(self, x: jnp.ndarray) -> jnp.ndarray: + return x * jax.lax.rsqrt(jnp.square(x).mean(-1, keepdims=True) + self.eps) + + def __call__(self, x: jnp.ndarray) -> jnp.ndarray: + x = x.astype(jnp.promote_types(self.dtype, jnp.float32)) + output = self._norm(x).astype(self.dtype) + weight = jnp.asarray(fjformer.linen.linen.control_quantization(self.weight, self.dtype)) + return output * weight + + +class Conv(nn.Module): + features: int + kernel_size: Sequence[int] + strides: Union[None, int, Sequence[int]] = 1 + padding: PaddingLike = "SAME" + input_dilation: Union[None, int, Sequence[int]] = 1 + kernel_dilation: Union[None, int, Sequence[int]] = 1 + feature_group_count: int = 1 + use_bias: bool = True + mask: Optional[Array] = None + dtype: Optional[Dtype] = None + param_dtype: Dtype = jnp.float32 + precision: PrecisionLike = None + kernel_init: Callable[[PRNGKey, Shape, Dtype], Array] = default_kernel_init + bias_init: Callable[[PRNGKey, Shape, Dtype], Array] = nn.initializers.zeros_init() + conv_general_dilated: Optional[ConvGeneralDilatedT] = None + conv_general_dilated_cls: Any = None + + @property + def shared_weights(self) -> bool: + return True + + @nn.compact + def __call__(self, inputs: Array) -> Array: + + if isinstance(self.kernel_size, int): + raise TypeError( + 'Expected Conv kernel_size to be a' + ' tuple/list of integers (eg.: [3, 3]) but got' + f' {self.kernel_size}.' + ) + else: + kernel_size = tuple(self.kernel_size) + + def maybe_broadcast( + x: Optional[Union[int, Sequence[int]]] + ) -> Tuple[int, ...]: + if x is None: + # backward compatibility with using None as sentinel for + # broadcast 1 + x = 1 + if isinstance(x, int): + return (x,) * len(kernel_size) + return tuple(x) + + # Combine all input batch dimensions into a single leading batch axis. + num_batch_dimensions = inputs.ndim - (len(kernel_size) + 1) + if num_batch_dimensions != 1: + input_batch_shape = inputs.shape[:num_batch_dimensions] + total_batch_size = int(np.prod(input_batch_shape)) + flat_input_shape = (total_batch_size,) + inputs.shape[ + num_batch_dimensions: + ] + inputs = jnp.reshape(inputs, flat_input_shape) + + # self.strides or (1,) * (inputs.ndim - 2) + strides = maybe_broadcast(self.strides) + input_dilation = maybe_broadcast(self.input_dilation) + kernel_dilation = maybe_broadcast(self.kernel_dilation) + + padding_lax = canonicalize_padding(self.padding, len(kernel_size)) + if padding_lax == 'CIRCULAR': + kernel_size_dilated = [ + (k - 1) * d + 1 for k, d in zip(kernel_size, kernel_dilation) + ] + zero_pad: List[Tuple[int, int]] = [(0, 0)] + pads = ( + zero_pad + + [((k - 1) // 2, k // 2) for k in kernel_size_dilated] + + [(0, 0)] + ) + inputs = jnp.pad(inputs, pads, mode='wrap') + padding_lax = 'VALID' + elif padding_lax == 'CAUSAL': + if len(kernel_size) != 1: + raise ValueError( + 'Causal padding is only implemented for 1D convolutions.' + ) + left_pad = kernel_dilation[0] * (kernel_size[0] - 1) + pads = [(0, 0), (left_pad, 0), (0, 0)] + inputs = jnp.pad(inputs, pads) + padding_lax = 'VALID' + + dimension_numbers = _conv_dimension_numbers(inputs.shape) + in_features = jnp.shape(inputs)[-1] + + if self.shared_weights: + # One shared convolutional kernel for all pixels in the output. + + inf_f = in_features // self.feature_group_count + # inf_f = 1 + kernel_shape = (self.features, inf_f,) + kernel_size + + else: + if self.feature_group_count != 1: + raise NotImplementedError( + '`lax.conv_general_dilated_local` does not support ' + f'`feature_group_count != 1`, got `{self.feature_group_count}`.' + ) + + # Need to know the spatial output shape of a standard convolution to + # create the unshared convolution kernel. + if self.conv_general_dilated_cls is not None: + conv_general_dilated = self.conv_general_dilated_cls() + elif self.conv_general_dilated is not None: + conv_general_dilated = self.conv_general_dilated + else: + conv_general_dilated = lax.conv_general_dilated + conv_output_shape = eval_shape( + lambda lhs, rhs: conv_general_dilated( # pylint: disable=g-long-lambda + lhs=lhs, + rhs=rhs, + window_strides=strides, + padding=padding_lax, + dimension_numbers=dimension_numbers, + lhs_dilation=input_dilation, + rhs_dilation=kernel_dilation, + ), + inputs, + ShapedArray(kernel_size + (in_features, self.features), inputs.dtype), + ).shape + + # One (unshared) convolutional kernel per each pixel in the output. + kernel_shape = conv_output_shape[1:-1] + ( + np.prod(kernel_size) * in_features, + self.features, + ) + + if self.mask is not None and self.mask.shape != kernel_shape: + raise ValueError( + 'Mask needs to have the same shape as weights. ' + f'Shapes are: {self.mask.shape}, {kernel_shape}' + ) + + kernel = self.param( + 'kernel', self.kernel_init, kernel_shape, self.param_dtype + ) + kernel = jnp.asarray(kernel.transpose(2, 1, 0), self.dtype) + if self.mask is not None: + kernel *= self.mask + + if self.use_bias: + if self.shared_weights: + bias_shape = (self.features,) + else: + bias_shape = conv_output_shape[1:] + + bias = self.param('bias', self.bias_init, bias_shape, self.param_dtype) + else: + bias = None + + inputs, kernel, bias = promote_dtype(inputs, kernel, bias, dtype=self.dtype) + if self.shared_weights: + if self.conv_general_dilated_cls is not None: + conv_general_dilated = self.conv_general_dilated_cls() + elif self.conv_general_dilated is not None: + conv_general_dilated = self.conv_general_dilated + else: + conv_general_dilated = lax.conv_general_dilated + y = conv_general_dilated( + lhs=inputs, + rhs=kernel, + window_strides=strides, + padding=padding_lax, + lhs_dilation=input_dilation, + rhs_dilation=kernel_dilation, + dimension_numbers=dimension_numbers, + feature_group_count=self.feature_group_count, + precision=self.precision, + ) + else: + y = lax.conv_general_dilated_local( + lhs=inputs, + rhs=kernel, + window_strides=strides, + padding=padding_lax, + filter_shape=kernel_size, + lhs_dilation=input_dilation, + rhs_dilation=kernel_dilation, + dimension_numbers=dimension_numbers, + precision=self.precision, + ) + + if self.use_bias: + bias = bias.reshape((1,) * (y.ndim - bias.ndim) + bias.shape) + y += bias + + if num_batch_dimensions != 1: + output_shape = input_batch_shape + y.shape[1:] + y = jnp.reshape(y, output_shape) + return y + + +class FlaxMambaMixer(nn.Module): + config: MambaConfig + layer_idx: int + dtype: jnp.dtype = jnp.float32 + param_dtype: jnp.dtype = jnp.float32 + precision: Optional[Union[str, lax.Precision]] = None + + def setup(self) -> None: + config = self.config + hidden_size = config.hidden_size + ssm_state_size = config.state_size + intermediate_size = config.intermediate_size + time_step_rank = config.time_step_rank + + self.conv1d = Conv1D( + # features=config.intermediate_size, + # kernel_size=(config.conv_kernel,), + # feature_group_count=config.intermediate_size, + # padding="SAME", + # strides=(1,), + # dtype=self.dtype, + # param_dtype=self.param_dtype, + # precision=self.precision, + # use_bias=config.use_conv_bias, + # # ---- # # + features=config.intermediate_size, + kernel_size=config.conv_kernel, + groups=config.intermediate_size, + stride=1, + padding=config.conv_kernel - 1, + use_bias=config.use_conv_bias, + dtype=self.dtype, + param_dtype=self.param_dtype, + precision=self.precision, + ) + + self.activation = config.hidden_act + self.act = ACT2FN[config.hidden_act] + + dt_init_std = self.config.time_step_rank ** -0.5 * self.config.time_step_scale + if self.config.time_step_init_scheme == "constant": + init_kernel_dt = nn.initializers.constant(dt_init_std, dtype=self.param_dtype) + elif self.config.time_step_init_scheme == "random": + # def init_kernel_dt(): + def init_kernel_dt(key, _shape, _dtype): + return jax.nn.initializers.uniform( + scale=dt_init_std * 2, dtype=self.param_dtype + )(key, _shape, _dtype) - dt_init_std + + # return init_r + else: + init_kernel_dt = nn.initializers.normal(self.config.initializer_range, self.param_dtype) + + dt = jax.lax.clamp( + self.config.time_step_floor, + jnp.exp( + jax.random.normal( + key=self.make_rng("params"), + shape=(self.config.intermediate_size,), + dtype=self.param_dtype + ) + * (jnp.log(self.config.time_step_max) - jnp.log(self.config.time_step_min)) + + jnp.log(self.config.time_step_min) + ), + self.config.time_step_max + ) + inv_dt = dt + jnp.log(-jnp.expm1(-dt)) + + dense_class = functools.partial( + Linear, + dtype=self.dtype, + param_dtype=self.param_dtype, + precision=self.precision, + **get_dot_general_by_bits( + self.config.bits, + self.config.easy_method + ) + ) + self.in_proj = dense_class( + intermediate_size * 2, + use_bias=config.use_bias + ) + self.x_proj = dense_class( + time_step_rank + ssm_state_size * 2, + use_bias=False + ) + self.dt_proj = dense_class( + intermediate_size, + use_bias=True, + kernel_init=init_kernel_dt, + bias_init=lambda s1, s2, s3: inv_dt + ) + self.out_proj = dense_class( + hidden_size, + use_bias=config.use_bias + ) + + self.A_log = self.param( + "A_log", + init_to_value( + jnp.log( + jnp.broadcast_to( + jnp.arange(1, ssm_state_size + 1, dtype=jnp.float32)[None, :], + (intermediate_size, ssm_state_size) + ) + ), + self.dtype + ) + ) + self.D = self.param( + "D", init_to_value( + jnp.ones( + intermediate_size + ), + self.dtype + ) + ) + self.ssm_state_size = ssm_state_size + self.intermediate_size = intermediate_size + self.conv_kernel_size = self.config.conv_kernel + self.time_step_rank = self.config.time_step_rank + + def __call__(self, input_states, cache_params=None): + batch_size, seq_len, _ = input_states.shape + dtype = input_states.dtype + projected_states = self.in_proj(input_states).transpose(0, 2, 1) + hidden_states, gate = jnp.split(projected_states, 2, axis=1) + + # 2. Convolution sequence transformation + if cache_params is not None: + ssm_state = cache_params.ssm_states[self.layer_idx] + if cache_params.seqlen_offset > 0: + conv_state = cache_params.conv_states[self.layer_idx] # [batch, intermediate_size, conv_kernel_size] + conv_state = jnp.roll(conv_state, shift=-1, axis=-1) + conv_state = conv_state.at[:, :, -1].set(hidden_states[:, :, 0]) + cache_params.conv_states[self.layer_idx].copy_(conv_state) + hidden_states = jnp.sum(conv_state * self.conv1d.variables["kernel"][:, 0, :], axis=-1) + if self.use_conv_bias: + hidden_states += self.conv1d.variables["bias"] + hidden_states = jnp.expand_dims(self.act(hidden_states).astype(dtype), -1) + # [batch, intermediate_size, 1] : decoding + else: + padding_amount = self.conv_kernel_size - hidden_states.shape[-1] + conv_state = jnp.pad(hidden_states, ((0, 0), (0, padding_amount)), mode='constant') + cache_params.conv_states[self.layer_idx].copy_(conv_state) + hidden_states = self.act(self.conv1d(hidden_states)[..., :seq_len]) + # [batch, intermediate_size, seq_len] + else: + ssm_state = jnp.zeros( + (batch_size, self.intermediate_size, self.ssm_state_size), dtype=dtype + ) + hidden_states = self.act(self.conv1d(hidden_states)[..., :seq_len]) + + ssm_parameters = self.x_proj(hidden_states.transpose(0, 2, 1)) + time_step, B, C = jnp.split( + ssm_parameters, + indices_or_sections=[self.time_step_rank, self.time_step_rank + self.ssm_state_size], + axis=-1 + ) + discrete_time_step = self.dt_proj(time_step) + # [batch, seq_len, intermediate_size] + discrete_time_step = jax.nn.softplus(discrete_time_step).transpose(0, 2, 1) + # [batch, intermediate_size, seq_len] + A = -jnp.exp(self.A_log.astype(jnp.float32)) + # [intermediate_size, ssm_state_size] + modified_a = jnp.expand_dims(jnp.expand_dims(A, axis=0), axis=2) + modified_time_step = jnp.expand_dims(discrete_time_step, axis=-1) + discrete_A = jnp.exp(modified_a * modified_time_step) + # [batch, intermediate_size, seq_len, ssm_state_size] + + discrete_B = modified_time_step * B[:, jnp.newaxis, :, :].astype(jnp.float32) + # [batch, intermediate_size, seq_len, ssm_state_size] + + deltaB_u = discrete_B * hidden_states[:, :, :, jnp.newaxis].astype(jnp.float32) + + # 3.c perform the recurrence y ← SSM(A, B, C)(x) + scan_outputs = [] + for i in range(seq_len): + ssm_state = discrete_A[:, :, i, :] * ssm_state + deltaB_u[:, :, i, :] + # [batch, intermediate_size, ssm_state] + + scan_output = jax.lax.batch_matmul(ssm_state.astype(dtype), jnp.expand_dims(C[:, i, :], -1)) + # [batch, intermediate_size, 1] + + scan_outputs.append(scan_output[:, :, 0]) + + scan_output = jnp.stack(scan_outputs, axis=-1) + # [batch, seq_len, intermediate_size] + scan_output = scan_output + (hidden_states * self.D[jnp.newaxis, :, jnp.newaxis]) + scan_output = (scan_output * self.act(gate)) + + if cache_params is not None: + cache_params.ssm_states[self.layer_idx].copy_(ssm_state) + + # 4. Final linear projection + contextualized_states = self.out_proj(scan_output.transpose(0, 2, 1)) + # [batch, seq_len, hidden_size] + return contextualized_states + + +class FlaxMambaBlock(nn.Module): + config: MambaConfig + layer_idx: int + dtype: jnp.dtype = jnp.float32 + param_dtype: jnp.dtype = jnp.float32 + precision: Optional[Union[str, lax.Precision]] = None + + def setup(self): + config = self.config + self.residual_in_fp32 = config.residual_in_fp32 + self.norm = MambaRMSNorm( + config.hidden_size, + eps=config.layer_norm_epsilon, + dtype=self.dtype, + param_dtype=self.param_dtype + ) + block = FlaxMambaMixer + if self.config.gradient_checkpointing != "": + block = nn_partitioning.remat( + block, + static_argnums=(1,), + policy=get_gradient_checkpoint_policy(self.config.gradient_checkpointing) + ) + self.mixer = block( + config, + self.layer_idx, + self.dtype, + self.param_dtype, + self.precision + ) + + def __call__( + self, + hidden_states: chex.Array, + cache_params: Optional[FlaxMambaCache] = None + ) -> chex.Array: + residual = hidden_states + hidden_states = self.norm( + hidden_states + ) + if self.residual_in_fp32: + residual = residual.astype(jnp.float32) + hidden_states = self.mixer( + hidden_states, + cache_params + ) + hidden_states = residual + hidden_states + return hidden_states + + +class FlaxMambaLayerCollection(nn.Module): + config: MambaConfig + dtype: jnp.dtype = jnp.float32 + param_dtype: jnp.dtype = jnp.float32 + precision: Optional[Union[str, lax.Precision]] = None + + def setup(self) -> None: + self.blocks = [ + FlaxMambaBlock( + config=self.config, + layer_idx=layer_idx, + dtype=self.dtype, + param_dtype=self.param_dtype, + precision=self.precision, + name=str(layer_idx) + ) + for layer_idx in range(self.config.num_hidden_layers) + ] + + def __call__( + self, + hidden_states: chex.Array, + cache_params: Optional[FlaxMambaCache] = None, + output_hidden_states: bool = False + ) -> Tuple[chex.Array, Union[None, Tuple[chex.Array, ...]]]: + all_hidden_states = () if output_hidden_states else None + for block in self.blocks: + hidden_states = block( + hidden_states, + cache_params + ) + + if output_hidden_states: + all_hidden_states = all_hidden_states + (hidden_states,) + + return hidden_states, all_hidden_states + + +class FlaxMambaModule(nn.Module): + config: MambaConfig + dtype: jnp.dtype = jnp.float32 + param_dtype: jnp.dtype = jnp.float32 + precision: Optional[Union[str, lax.Precision]] = None + + def setup(self) -> None: + config = self.config + self.embeddings = nn.Embed( + config.vocab_size, + config.hidden_size, + dtype=self.dtype, + param_dtype=self.param_dtype, + ) + self.layers = FlaxMambaLayerCollection( + config=config, + dtype=self.dtype, + param_dtype=self.param_dtype, + precision=self.precision + ) + + self.norm_f = MambaRMSNorm( + config.hidden_size, + eps=config.layer_norm_epsilon, + dtype=self.dtype, + param_dtype=self.param_dtype, + ) + + def __call__( + self, + input_ids: Optional[chex.Array] = None, + inputs_embeds: Optional[chex.Array] = None, + cache_params: Optional[chex.Array] = None, + deterministic: bool = True, + use_cache: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + **kwargs, + ) -> Union[Tuple, MambaOutput]: + + output_hidden_states = ( + output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states + ) + use_cache = use_cache if use_cache is not None else (self.config.use_cache if not deterministic else False) + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + if (input_ids is None) ^ (inputs_embeds is not None): + raise ValueError( + "You cannot specify both input_ids and inputs_embeds at the same time, and must specify either one" + ) + + if inputs_embeds is None: + inputs_embeds = self.embeddings(input_ids) + + if deterministic and use_cache: + use_cache = False + + if cache_params is None and use_cache: + cache_params = FlaxMambaCache( + self.config, inputs_embeds.shape[0], dtype=inputs_embeds.dtype + ) + + hidden_states = inputs_embeds + hidden_states, all_hidden_states = self.layers(hidden_states, cache_params=cache_params) + + if use_cache: + cache_params.seqlen_offset += inputs_embeds.shape[1] + + hidden_states = self.norm_f(hidden_states) + + if output_hidden_states: + all_hidden_states = all_hidden_states + (hidden_states,) + + if not return_dict: + return tuple(v for v in [hidden_states, cache_params, all_hidden_states] if v is not None) + + return MambaOutput( + last_hidden_state=hidden_states, + cache_params=cache_params if use_cache else None, + hidden_states=all_hidden_states, + ) + + +class FlaxMambaForCausalLMModule(nn.Module): + config: MambaConfig + dtype: jnp.dtype = jnp.float32 + param_dtype: jnp.dtype = jnp.float32 + precision: Optional[Union[str, lax.Precision]] = None + + def setup(self) -> None: + self.backbone = FlaxMambaModule( + self.config, + self.dtype, + self.param_dtype, + self.precision + ) + self.lm_head = Linear( + self.config.vocab_size, + use_bias=False, + dtype=self.dtype, + param_dtype=self.param_dtype, + precision=self.precision + ) + + def __call__( + self, + input_ids: Optional[chex.Array] = None, + inputs_embeds: Optional[chex.Array] = None, + cache_params: Optional[chex.Array] = None, + deterministic: bool = True, + use_cache: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + **kwargs, + ) -> Union[Tuple, MambaCausalLMOutput]: + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + # input_ids: Optional[chex.Array] = None, + # inputs_embeds: Optional[chex.Array] = None, + # deterministic: bool = True, + # cache_params: Optional[List[chex.Array]] = None, + # use_cache: Optional[bool] = None, + # output_hidden_states: Optional[bool] = None, + # return_dict: Optional[bool] = None, + + mamba_outputs = self.backbone( + input_ids=input_ids, + inputs_embeds=inputs_embeds, + deterministic=deterministic, + cache_params=cache_params, + use_cache=use_cache, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + ) + hidden_states = mamba_outputs[0] + + logits = self.lm_head(hidden_states).astype(jnp.float32) + + if not return_dict: + return (logits,) + mamba_outputs[1:] + + return MambaCausalLMOutput( + logits=logits, + cache_params=mamba_outputs.cache_params, + hidden_states=mamba_outputs.hidden_states, + ) + + +class FlaxMambaPretrainedModel(EasyDeLFlaxPretrainedModel): + config_class = MambaConfig + base_model_prefix = "backbone" + module_class: nn.Module = None + + def __init__( + self, + config: MambaConfig, + input_shape: Tuple = (1, 1), + seed: int = 0, + dtype: jnp.dtype = jnp.float32, + param_dtype: jnp.dtype = jnp.float32, + precision: Optional[Union[str, lax.Precision]] = None, + _do_init: bool = True, + **kwargs, + ): + """The __init__ function is called when the class is instantiated. + It sets up the instance of the class, and defines what happens when it's created. + The __init__ function can take arguments, but self is always required (it refers to the instance of the object). + + Args: + self: Refer to the object itself + config: MambaConfig: Pass the configuration to the module + input_shape: Tuple: Specify the shape of the input to the + model + seed: int: Set the seed for random number generation + dtype: jnp.dtype: Specify the data type of the model ra + param_dtype: jnp.dtype: Specify the data type of the + param_dtype + precision: Optional[Union[str, lax.Precision]]: precision + for model operations + _do_init: bool: Control whether the module is initialized or + not + **kwargs: Pass in any additional parameters that the + module_class might need + :param : Specify the number of layers in the network + + Returns: + The super() of the class + """ + module = self.module_class( + config=config, + dtype=dtype, + param_dtype=param_dtype, + precision=precision, + **kwargs + ) + super().__init__( + config, + module, + input_shape=(input_shape[0], 1), + seed=seed, + dtype=dtype, + _do_init=_do_init + ) + + def init_weights(self, rng: jax.random.PRNGKey, input_shape: Tuple, params: FrozenDict = None) -> FrozenDict: + """The init_weights function is used to initialize the weights of a model. + + Args: + self: Access variables that belong to the class + rng: jax.random.PRNGKey: Initialize the weights of the model + input_shape: Tuple: Specify the shape of the input tensor + params: FrozenDict: Pass in the parameters of a pre-trained + model + + Returns: + A frozendict of parameters + """ + input_ids = jnp.zeros(input_shape, dtype="i4") + params_rng, dropout_rng = jax.random.split(rng) + rngs = {"params": params_rng, "dropout": dropout_rng} + + module_init_outputs = self.module.init( + rngs, + input_ids, + return_dict=False + ) + + random_params = module_init_outputs["params"] + + if params is not None: + random_params = flatten_dict(unfreeze(random_params)) + params = flatten_dict(unfreeze(params)) + for missing_key in self._missing_keys: + params[missing_key] = random_params[missing_key] + self._missing_keys = set() + return freeze(unflatten_dict(params)) + else: + return random_params + + def init_cache(self, batch_size, max_length): + return None + + def __call__( + self, + input_ids: Optional[chex.Array] = None, + inputs_embeds: Optional[chex.Array] = None, + cache_params: dict = None, + deterministic: bool = True, + params: dict = None, + dropout_rng: jax.random.PRNGKey = None, + train: bool = False, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + extra_embedding: Optional[Union[jnp.ndarray, None]] = None, + add_params_field: bool = False, + attention_mask: Optional[chex.Array] = None, # Ignored(we are using an SSM model not attention) + use_cache: bool = False, + **kwargs + ): + """The __call__ function is the main function of a JAX module. + + Args: + self: Represent the instance of the class + input_ids: Optional[chex.Array]: Pass in the input tokens + inputs_embeds: Optional[chex.Array]: Pass in the embedded + tokens + cache_params: dict: Pass in the past cache_params from a + previous call to __call__ + params: dict: Pass in the parameters of the model + dropout_rng: jax.random.PRNGKey: Make sure that the dropout + is applied in a random way + train: bool: Determine whether to use dropout or not + output_hidden_states: Optional[bool]: Return the hidden + states of all layers + return_dict: Optional[bool]: Determine whether to return a + dictionary or not + extra_embedding: Optional[Union[jnp.ndarray,None]]: Pass in + the embedding for the input_ids + add_params_field: bool: Add the params field to the inputs + dictionary + + Returns: + A tuple of the following: + """ + output_hidden_states = ( + output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states + ) + return_dict = return_dict if return_dict is not None else self.config.return_dict + + batch_size, sequence_length = input_ids.shape + + assert sequence_length <= self.config.max_position_embeddings, "Maximum Position Embedding Reached !" + if cache_params is not None: + assert isinstance(cache_params, FlaxMambaCache), f"Wrong cache input_type of {type(cache_params)}" + rngs = {} + if dropout_rng is not None: + rngs["dropout"] = dropout_rng + + rngs["params"] = jax.random.key(0) + + inputs = { + "params": params or self.params + } if add_params_field else params or self.params + + # input_ids: Optional[chex.Array] = None, + # inputs_embeds: Optional[chex.Array] = None, + # cache_params: Optional[chex.Array] = None, + # deterministic: bool = True, + # use_cache: Optional[bool] = None, + # output_hidden_states: Optional[bool] = None, + # return_dict: Optional[bool] = None, + + return self.module.apply( + inputs, + input_ids, + inputs_embeds, + cache_params, + train, + use_cache, + output_hidden_states, + return_dict, + rngs=rngs, + mutable=False, + ) + + +class FlaxMambaModel(FlaxMambaPretrainedModel): + module_class = FlaxMambaModule + + +class FlaxMambaForCausalLM(FlaxMambaPretrainedModel): + module_class = FlaxMambaForCausalLMModule + + def update_inputs_for_generation( + self, + outputs: MambaOutput, + model_kwargs: Dict[str, Any], + **kwargs + ) -> Dict[str, Any]: + model_kwargs["cache_params"] = outputs.get("cache_params", None) + return model_kwargs + + def prepare_inputs_for_generation( + self, + input_ids, + max_length, + **kwargs + ): + return { + "cache_params": kwargs.get("cache_params", None) + } diff --git a/src/python/easydel/modules/mistral/mistral_configuration.py b/src/python/easydel/modules/mistral/mistral_configuration.py index 2e0a62662..1d7a6311a 100644 --- a/src/python/easydel/modules/mistral/mistral_configuration.py +++ b/src/python/easydel/modules/mistral/mistral_configuration.py @@ -37,43 +37,61 @@ def __init__( attention_bias: bool = False, **kwargs, ): - """ - The __init__ function is called when the class is instantiated. + """The __init__ function is called when the class is instantiated. It allows the class to initialize the attributes of a class. The self parameter is a reference to the current instance of the class, and is used to access variables that belong to the class. - :param self: Represent the instance of the class - :param vocab_size: Define the size of the vocabulary - :param hidden_size: Determine the size of the embedding layers - :param intermediate_size: Define the size of the intermediate layer in each transformer block - :param num_hidden_layers: Determine the number of layers in the encoder and decoder - :param num_attention_heads: Determine the number of attention heads in each layer - :param num_key_value_heads: Specify the number of heads for key and value - :param hidden_act: Specify the activation function used in the hidden layers - :param max_position_embeddings: Set the maximum length of the sequence - :param initializer_range: Initialize the weights of the model - :param rms_norm_eps: Avoid division by zero in the rms normalization - :param use_cache: Determine whether to use the cache in the decoder - :param pad_token_id: Specify the token id of the padding token - :param bos_token_id: Specify the beginning of sentence token id - :param eos_token_id: Specify the end of sentence token - :param tie_word_embeddings: Tie the word embeddings and the output layer - :param rope_theta: Control the number of tokens in a rope - :param sliding_window: Control the number of tokens that are processed in parallel - :param gradient_checkpointing: str: Specify whether to use gradient checkpointing - :param use_scan_mlp: bool: Determine whether or not to use the scan_mlp function - :param scan_mlp_chunk_size: int: Specify the chunk size of the scan mlp - :param number_rep_kv: int: Specify the number of times to repeat the key and value vectors - :param attention_dropout: float: Set the dropout rate for the attention layer - :param bits: Optional[int]: Specify the number of bits used for quantization - :param axis_dims: Sequence[int]: Specify the dimension of each axis - :param axis_names: Sequence[str]: Specify the names of each axis in the tensor - :param "mp"): Define the maximum position embeddings - :param attention_bias: bool: when ever to use attention_bias - :param kwargs: Pass a variable number of keyword arguments to a function + Args: + self: Represent the instance of the class + vocab_size: Define the size of the vocabulary + hidden_size: Determine the size of the embedding layers + intermediate_size: Define the size of the intermediate layer + in each transformer block + num_hidden_layers: Determine the number of layers in the + encoder and decoder + num_attention_heads: Determine the number of attention heads + in each layer + num_key_value_heads: Specify the number of heads for key and + value + hidden_act: Specify the activation function used in the + hidden layers + max_position_embeddings: Set the maximum length of the + sequence + initializer_range: Initialize the weights of the model + rms_norm_eps: Avoid division by zero in the rms + normalization + use_cache: Determine whether to use the cache in the decoder + pad_token_id: Specify the token id of the padding token + bos_token_id: Specify the beginning of sentence token id + eos_token_id: Specify the end of sentence token + tie_word_embeddings: Tie the word embeddings and the output + layer + rope_theta: Control the number of tokens in a rope + sliding_window: Control the number of tokens that are + processed in parallel + gradient_checkpointing: str: Specify whether to use gradient + checkpointing + use_scan_mlp: bool: Determine whether or not to use the + scan_mlp function + scan_mlp_chunk_size: int: Specify the chunk size of the scan + mlp + number_rep_kv: int: Specify the number of times to repeat + the key and value vectors + attention_dropout: float: Set the dropout rate for the + attention layer + bits: Optional[int]: Specify the number of bits used for + quantization + axis_dims: Sequence[int]: Specify the dimension of each axis + axis_names: Sequence[str]: Specify the names of each axis in + the tensor + "mp"): Define the maximum position embeddings + attention_bias: bool: when ever to use attention_bias + **kwargs: Pass a variable number of keyword arguments to a + function :param : Define the number of layers in the model - :return: An instance of the class + Returns: + An instance of the class """ self.vocab_size = vocab_size self.max_position_embeddings = max_position_embeddings @@ -114,15 +132,18 @@ def __init__( @staticmethod def get_partition_rules(fully_sharded_data_parallel: bool = True): - """ - The get_partition_rules function is used to define the partitioning scheme for a model. + """The get_partition_rules function is used to define the partitioning scheme for a model. It returns a list of tuples, where each tuple contains two elements: 1) A regex string that matches the name of one or more parameters in the model. 2) A PartitionScheme object that defines how those parameters should be partitioned. - :param fully_sharded_data_parallel: bool: Determine whether to use the fully_sharded_data_parallel partitioning scheme or not - :return: A list of tuples + Args: + fully_sharded_data_parallel: bool: Determine whether to use + the fully_sharded_data_parallel partitioning scheme or + not + Returns: + A list of tuples """ return ( @@ -171,20 +192,28 @@ def add_jax_args( attention_bias: bool = False, **kwargs, ): - """ - The add_jax_args function adds the following arguments to the model: - - :param self: Bind the attributes and methods of a class to an instance of that class - :param gradient_checkpointing: str: Determine whether to use gradient checkpointing - :param use_scan_mlp: bool: Determine whether to use the scan_mlp function or notn - :param scan_mlp_chunk_size: int: Chunk the input to the mlp - :param number_rep_kv: int: Control the number of times that the key and value vectors are repeated - :param bits: Optional[int]: Specify the number of bits to use for quantization - :param attention_dropout: float: Set the dropout rate for the attention layer - :param attention_bias: bool: when ever to use attention_bias - :param rope_scaling: Dict[str, Union[str, float]]: rope_scaling for rope - :return: A tuple of the following: - + """The add_jax_args function adds the following arguments to the model: + + Args: + self: Bind the attributes and methods of a class to an + instance of that class + gradient_checkpointing: str: Determine whether to use + gradient checkpointing + use_scan_mlp: bool: Determine whether to use the scan_mlp + function or notn + scan_mlp_chunk_size: int: Chunk the input to the mlp + number_rep_kv: int: Control the number of times that the key + and value vectors are repeated + bits: Optional[int]: Specify the number of bits to use for + quantization + attention_dropout: float: Set the dropout rate for the + attention layer + attention_bias: bool: when ever to use attention_bias + rope_scaling: Dict[str, Union[str, float]]: rope_scaling for + rope + + Returns: + A tuple of the following: """ self.attention_bias = attention_bias diff --git a/src/python/easydel/modules/mistral/modelling_mistral_flax.py b/src/python/easydel/modules/mistral/modelling_mistral_flax.py index 0d21ed8e6..1a95fb9a1 100644 --- a/src/python/easydel/modules/mistral/modelling_mistral_flax.py +++ b/src/python/easydel/modules/mistral/modelling_mistral_flax.py @@ -51,9 +51,7 @@ def _make_sliding_window_causal_mask( past_key_values_length: int = 0, sliding_window: int = 4096, ): - """ - Make causal mask used for sliding window attention - """ + """Make causal mask used for sliding window attention""" bsz, tgt_len = input_ids_shape tensor = jnp.full( @@ -214,33 +212,37 @@ def _merge_heads(self, hidden_states): @staticmethod def _transpose_sequence_head(query, key, value): - """ - The _transpose_sequence_head function transposes the query, key and value matrices. + """The _transpose_sequence_head function transposes the query, key and value matrices. - :param query: Get the attention weights for each of the heads - :param key: Determine the number of heads - :param value: Store the values of the input - :return: The transpose of the query, key and value matrices + Args: + query: Get the attention weights for each of the heads + key: Determine the number of heads + value: Store the values of the input + Returns: + The transpose of the query, key and value matrices """ return jnp.transpose(query, (0, 2, 1, 3)), jnp.transpose(key, (0, 2, 1, 3)), jnp.transpose(value, (0, 2, 1, 3)) def apply_rotary(self, batch_size, sequence_length, query, key, value, freq_cis, position_ids): - """ - The apply_rotary function is a modified version of the apply_attention function in the BertModel class. + """The apply_rotary function is a modified version of the apply_attention function in the BertModel class. The main difference is that it takes in an additional argument, freq_cis, which are used to calculate the rotary attention weights. The other differences are minor and mostly related to reshaping tensors. - :param self: Access variables that belong to the class - :param batch_size: Reshape the query, key and value tensors - :param sequence_length: Reshape the query, key and value tensors - :param query: Calculate the attention weights - :param key: Calculate the attention - :param value: Compute the attention weights - :param freq_cis: Calculate the frequency of each word in the vocabulary - :param position_ids: Identify the position of each token in the sequence - :return: A tuple of 3 tensors: query, key and value - + Args: + self: Access variables that belong to the class + batch_size: Reshape the query, key and value tensors + sequence_length: Reshape the query, key and value tensors + query: Calculate the attention weights + key: Calculate the attention + value: Compute the attention weights + freq_cis: Calculate the frequency of each word in the + vocabulary + position_ids: Identify the position of each token in the + sequence + + Returns: + A tuple of 3 tensors: query, key and value """ query = query.reshape( batch_size, @@ -282,25 +284,32 @@ def __call__( output_attentions: bool = False, fcm_mask=None, ): - """ - - The __call__ function is the main function of a JAX module. It defines how the module behaves when called + """The __call__ function is the main function of a JAX module. It defines how the module behaves when called with inputs. The __call__ function can be thought of as a "forward pass" through the model, and it should return all outputs that are needed for training or inference. - :param self: Access variables that belong to the class - :param hidden_states: chex.Array: Pass the hidden states of the previous layer - :param freq_cis: Tuple[chex.Array, chex.Array],: Pass in the frequency coefficients for each position - :param attention_mask: chex.Array: Mask out certain tokens in the input sequence - :param position_ids: chex.Array: Determine the position of each token in a sequence - :param causal_mask: chex.Array: Mask out the future tokens in the decoder - :param deterministic: bool: Determine whether to use dropout or not - :param init_cache: bool: Initialize the cache - :param output_attentions: bool: Determine whether to return the attention weights or not - :param fcm_mask: Mask out the attention weights between the input and output tokens + Args: + self: Access variables that belong to the class + hidden_states: chex.Array: Pass the hidden states of the + previous layer + freq_cis: Tuple[chex.Array, chex.Array],: Pass in the + frequency coefficients for each position + attention_mask: chex.Array: Mask out certain tokens in the + input sequence + position_ids: chex.Array: Determine the position of each + token in a sequence + causal_mask: chex.Array: Mask out the future tokens in the + decoder + deterministic: bool: Determine whether to use dropout or not + init_cache: bool: Initialize the cache + output_attentions: bool: Determine whether to return the + attention weights or not + fcm_mask: Mask out the attention weights between the input + and output tokens :param : Determine if the attention is causal or not - :return: A tuple of two arrays + Returns: + A tuple of two arrays """ batch_size, sequence_length = hidden_states.shape[:2] query_states, key_states, value_states = self.q_proj(hidden_states), self.k_proj(hidden_states), self.v_proj( @@ -488,8 +497,7 @@ def __call__( init_cache: bool = False, output_attentions: bool = True ): - """ - The __call__ function is the main function of a TransformerEncoderLayer. + """The __call__ function is the main function of a TransformerEncoderLayer. It takes in the following arguments: hidden_states (chex.Array): The input to the encoder layer, which is also its output after being processed by all sublayers. @@ -497,17 +505,25 @@ def __call__( used for computing self-attention weights and biases in a more efficient manner than using position embeddings or sinusoidal positional encoding vectors would allow for [2]. - :param self: Represent the instance of the class - :param hidden_states: chex.Array: Represent the input to the encoder layer - :param freq_cis: Tuple[chex.Array, chex.Array],: Pass the frequency information to the attention layer - :param attention_mask: chex.Array: Mask out the attention weights for certain positions - :param causal_mask: chex.Array: Mask the future tokens - :param position_ids: chex.Array: Indicate the position of each token in the sequence - :param deterministic: bool: Determine whether to use dropout or not - :param init_cache: bool: Initialize the cache for the self-attention layer - :param output_attentions: bool: Determine whether to return the attention weights or not - :return: A tuple of hidden_states and attention_output - + Args: + self: Represent the instance of the class + hidden_states: chex.Array: Represent the input to the + encoder layer + freq_cis: Tuple[chex.Array, chex.Array],: Pass the frequency + information to the attention layer + attention_mask: chex.Array: Mask out the attention weights + for certain positions + causal_mask: chex.Array: Mask the future tokens + position_ids: chex.Array: Indicate the position of each + token in the sequence + deterministic: bool: Determine whether to use dropout or not + init_cache: bool: Initialize the cache for the self- + attention layer + output_attentions: bool: Determine whether to return the + attention weights or not + + Returns: + A tuple of hidden_states and attention_output """ # hidden_states: chex.Array, @@ -579,18 +595,21 @@ def init_weights( input_shape: Tuple, params: flax.core.FrozenDict = None ) -> flax.core.FrozenDict: - """ - The init_weights function is used to initialize the weights of a model. + """The init_weights function is used to initialize the weights of a model. It takes in an rng, which is a random number generator key that can be used to generate random numbers. The input_shape parameter specifies the shape of the inputs that will be fed into this model. The params parameter allows you to pass in pre-trained weights for your model, if you have them available. - :param self: Access variables that belong to the class - :param rng: jax.random.PRNGKey: Initialize the weights of the model - :param input_shape: Tuple: Initialize the input_ids, attention_mask and position_ids - :param params: flax.core.FrozenDict: Pass in the parameters of a pre-trained model - :return: A frozendict of parameters + Args: + self: Access variables that belong to the class + rng: jax.random.PRNGKey: Initialize the weights of the model + input_shape: Tuple: Initialize the input_ids, attention_mask + and position_ids + params: flax.core.FrozenDict: Pass in the parameters of a + pre-trained model + Returns: + A frozendict of parameters """ input_ids = jnp.zeros(input_shape, dtype="i4") attention_mask = jnp.ones_like(input_ids) @@ -661,28 +680,35 @@ def __call__( add_params_field: bool = False, **kwargs ): - """ - The __call__ function is the main function of a JAX module. + """The __call__ function is the main function of a JAX module. It takes as input: - The parameters of the model (self.params) - The inputs to the model (input_ids, attention_mask, position_ids) - Whether we are training (train=True/False) and whether we want to return all hidden states and attentions weights at each layer in addition to just the last layer output (output_hidden_states=True/False). - :param self: Represent the instance of the class - :param input_ids: Pass the input sequence to the model - :param attention_mask: Mask out the padding tokens - :param position_ids: Specify the position of each token in the sequence - :param params: dict: Pass in the parameters of the model - :param past_key_values: dict: Pass the past key values to the model - :param dropout_rng: jax.random.PRNGKey: Pass in a random number generator key to the model - :param train: bool: Determine whether to use dropout or not - :param output_attentions: Optional[bool]: Determine whether to return the attention weights - :param output_hidden_states: Optional[bool]: Determine whether to return the hidden states of all layers - :param return_dict: Optional[bool]: Return a dictionary of the outputs - :param add_params_field: bool: Add a params field to the inputs dictionary - :return: A tuple of (last_hidden_state, past_key_values) - + Args: + self: Represent the instance of the class + input_ids: Pass the input sequence to the model + attention_mask: Mask out the padding tokens + position_ids: Specify the position of each token in the + sequence + params: dict: Pass in the parameters of the model + past_key_values: dict: Pass the past key values to the model + dropout_rng: jax.random.PRNGKey: Pass in a random number + generator key to the model + train: bool: Determine whether to use dropout or not + output_attentions: Optional[bool]: Determine whether to + return the attention weights + output_hidden_states: Optional[bool]: Determine whether to + return the hidden states of all layers + return_dict: Optional[bool]: Return a dictionary of the + outputs + add_params_field: bool: Add a params field to the inputs + dictionary + + Returns: + A tuple of (last_hidden_state, past_key_values) """ output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions output_hidden_states = ( @@ -876,25 +902,31 @@ def __call__( output_hidden_states: bool = False, return_dict: bool = True, ) -> typing.Union[Tuple[Array, ...], FlaxBaseModelOutput]: - """ - The __call__ function is the main function of a Flax model. + """The __call__ function is the main function of a Flax model. It takes in input_ids, attention_mask, and position_ids as inputs to the model. The output is a tuple containing: last hidden state (hidden states), all hidden states (if output_hidden_states=True), attentions (if output attentions=True). - - :param self: Represent the instance of the class - :param input_ids: chex.Array: Pass in the input ids - :param attention_mask: chex.Array: Mask out the attention weights for certain tokens - :param position_ids: chex.Array: Determine the position of each token in a sequence - :param deterministic: bool: Determine whether to use dropout or not - :param inputs_embeds: chex.Array: Pass in the embedding of the input_ids - :param init_cache: bool: Initialize the cache for the decoder - :param output_attentions: bool: Determine whether to return the attention weights or not - :param output_hidden_states: bool: Return all hidden states or just the last one - :param return_dict: bool: Return a dictionary of the outputs or not + Args: + self: Represent the instance of the class + input_ids: chex.Array: Pass in the input ids + attention_mask: chex.Array: Mask out the attention weights + for certain tokens + position_ids: chex.Array: Determine the position of each + token in a sequence + deterministic: bool: Determine whether to use dropout or not + inputs_embeds: chex.Array: Pass in the embedding of the + input_ids + init_cache: bool: Initialize the cache for the decoder + output_attentions: bool: Determine whether to return the + attention weights or not + output_hidden_states: bool: Return all hidden states or just + the last one + return_dict: bool: Return a dictionary of the outputs or not :param : Determine whether the model is in training mode or not - :return: A tuple of the hidden states, all hidden states, and attentions + Returns: + A tuple of the hidden states, all hidden states, and + attentions """ if inputs_embeds is None: inputs_embeds = self.embed_tokens(input_ids.astype("i4")) @@ -979,26 +1011,32 @@ def __call__( output_hidden_states: bool = False, return_dict: bool = True, ): - """ - The __call__ function is the main function of a Flax module. It defines how the model will be called, - and what it returns. In this case, we are calling our Transformer model with input_ids and attention_mask - as inputs (these are defined in __init__). We also have some optional arguments that can be passed to - the call function: deterministic (whether to use dropout), inputs_embeds (if you want to pass your own embeddings), - output_attentions and output_hidden states which return additional outputs from the transformer layers if set True. Finally, - - :param self: Refer to the object itself - :param input_ids: chex.Array: Pass in the input tokens - :param attention_mask: chex.Array: Mask out the padding tokens - :param position_ids: chex.Array: Specify the position of each token in the sequence - :param deterministic: bool: Determine whether to use dropout in the model - :param inputs_embeds: chex.Array: Pass in the embeddings of the input tokens - :param init_cache: bool: Initialize the cache for the decoder - :param output_attentions: bool: Return the attention weights - :param output_hidden_states: bool: Return the hidden states of all layers - :param return_dict: bool: Return a dictionary of the outputs or just the logits - :param : Determine whether to return the logits or not - :return: A tuple of (lm_logits, hidden_states, attentions) - + """The __call__ function is the main function of a Flax module. It defines how the model will be called, + and what it returns. In this case, we are calling our Transformer model with input_ids and attention_mask + as inputs (these are defined in __init__). We also have some optional arguments that can be passed to + the call function: deterministic (whether to use dropout), inputs_embeds (if you want to pass your own embeddings), + output_attentions and output_hidden states which return additional outputs from the transformer layers if set True. Finally, + + Args: + self: Refer to the object itself + input_ids: chex.Array: Pass in the input tokens + attention_mask: chex.Array: Mask out the padding tokens + position_ids: chex.Array: Specify the position of each token + in the sequence + deterministic: bool: Determine whether to use dropout in the + model + inputs_embeds: chex.Array: Pass in the embeddings of the + input tokens + init_cache: bool: Initialize the cache for the decoder + output_attentions: bool: Return the attention weights + output_hidden_states: bool: Return the hidden states of all + layers + return_dict: bool: Return a dictionary of the outputs or + just the logits + :param : Determine whether to return the logits or not + + Returns: + A tuple of (lm_logits, hidden_states, attentions) """ batch_size, seq_length = input_ids.shape diff --git a/src/python/easydel/modules/mistral/modelling_vision_mistral_flax.py b/src/python/easydel/modules/mistral/modelling_vision_mistral_flax.py index 9e30277b4..fa451a902 100644 --- a/src/python/easydel/modules/mistral/modelling_vision_mistral_flax.py +++ b/src/python/easydel/modules/mistral/modelling_vision_mistral_flax.py @@ -52,15 +52,17 @@ def init_cache(self, batch_size, max_length): return init_variables["cache"] def init_weights(self, rng: jax.random.PRNGKey, input_shape: Tuple, params: FrozenDict = None) -> FrozenDict: - """ - The init_weights function is used to initialize the weights of a model. + """The init_weights function is used to initialize the weights of a model. - :param self: Access variables that belong to the class - :param rng: jax.random.PRNGKey: Initialize the weights of the model - :param input_shape: Tuple: Specify the shape of the input tensor - :param params: FrozenDict: Pass in the parameters of a pre-trained model - :return: A frozendict of parameters + Args: + self: Access variables that belong to the class + rng: jax.random.PRNGKey: Initialize the weights of the model + input_shape: Tuple: Specify the shape of the input tensor + params: FrozenDict: Pass in the parameters of a pre-trained + model + Returns: + A frozendict of parameters """ input_ids = jnp.zeros(input_shape, dtype="i4") attention_mask = jnp.ones_like(input_ids) diff --git a/src/python/easydel/modules/mistral/vision_mistral_configuration.py b/src/python/easydel/modules/mistral/vision_mistral_configuration.py index 2861b3c1d..5b7ab7764 100644 --- a/src/python/easydel/modules/mistral/vision_mistral_configuration.py +++ b/src/python/easydel/modules/mistral/vision_mistral_configuration.py @@ -16,15 +16,17 @@ def __init__( self.sample_mode = sample_mode def get_partition_rules(self, fully_sharded_data_parallel: bool = True): - """ - The get_partition_rules function is used to define the partitioning scheme for a model. + """The get_partition_rules function is used to define the partitioning scheme for a model. It returns a list of tuples, where each tuple contains two elements: 1) A regex string that matches the name of one or more parameters in the model. 2) A PartitionScheme object that defines how those parameters should be partitioned across devices. - :param fully_sharded_data_parallel: bool: Determine whether to partition the model fully or not - :return: A list of tuples + Args: + fully_sharded_data_parallel: bool: Determine whether to + partition the model fully or not + Returns: + A list of tuples """ return ( diff --git a/src/python/easydel/modules/mixtral/mixtral_configuration.py b/src/python/easydel/modules/mixtral/mixtral_configuration.py index f064403de..d42a54eae 100644 --- a/src/python/easydel/modules/mixtral/mixtral_configuration.py +++ b/src/python/easydel/modules/mixtral/mixtral_configuration.py @@ -43,46 +43,66 @@ def __init__( router_jitter_noise=0.0, **kwargs, ): - """ - The __init__ function is called when the class is instantiated. + """The __init__ function is called when the class is instantiated. It allows the class to initialize the attributes of a class. The self parameter is a reference to the current instance of the class, and is used to access variables that belong to the class. - :param self: Represent the instance of the class - :param vocab_size: Define the size of the vocabulary - :param hidden_size: Determine the size of the embedding layers - :param intermediate_size: Define the size of the intermediate layer in each transformer block - :param num_hidden_layers: Determine the number of layers in the encoder and decoder - :param num_attention_heads: Determine the number of attention heads in each layer - :param num_key_value_heads: Specify the number of heads for key and value - :param hidden_act: Specify the activation function used in the hidden layers - :param max_position_embeddings: Set the maximum length of the sequence - :param initializer_range: Initialize the weights of the model - :param rms_norm_eps: Avoid division by zero in the rms normalization - :param use_cache: Determine whether to use the cache in the decoder - :param pad_token_id: Specify the token id of the padding token - :param bos_token_id: Specify the beginning of sentence token id - :param eos_token_id: Specify the end of sentence token - :param tie_word_embeddings: Tie the word embeddings and the output layer - :param rope_theta: Control the number of tokens in a rope - :param sliding_window: Control the number of tokens that are processed in parallel - :param gradient_checkpointing: str: Specify whether to use gradient checkpointing - :param use_scan_mlp: bool: Determine whether or not to use the scan_mlp function - :param scan_mlp_chunk_size: int: Specify the chunk size of the scan mlp - :param number_rep_kv: int: Specify the number of times to repeat the key and value vectors - :param bits: Optional[int]: Specify the number of bits used for quantization - :param axis_dims: Sequence[int]: Specify the dimension of each axis - :param axis_names: Sequence[str]: Specify the names of each axis in the tensor - :param "mp"): Define the maximum position embeddings - :param kwargs: Pass a variable number of keyword arguments to a function - :param rope_scaling: Dict[str, Union[str, float]]: rope scaling information - :param attention_dropout: float: Set the dropout rate for the attention layer - :param initialization_of_moe: bool: initialization of moe needs to disable some dynamic part's this boolean - variable will turn them off. - :param attention_bias: bool: when ever to use attention_bias + Args: + self: Represent the instance of the class + vocab_size: Define the size of the vocabulary + hidden_size: Determine the size of the embedding layers + intermediate_size: Define the size of the intermediate layer + in each transformer block + num_hidden_layers: Determine the number of layers in the + encoder and decoder + num_attention_heads: Determine the number of attention heads + in each layer + num_key_value_heads: Specify the number of heads for key and + value + hidden_act: Specify the activation function used in the + hidden layers + max_position_embeddings: Set the maximum length of the + sequence + initializer_range: Initialize the weights of the model + rms_norm_eps: Avoid division by zero in the rms + normalization + use_cache: Determine whether to use the cache in the decoder + pad_token_id: Specify the token id of the padding token + bos_token_id: Specify the beginning of sentence token id + eos_token_id: Specify the end of sentence token + tie_word_embeddings: Tie the word embeddings and the output + layer + rope_theta: Control the number of tokens in a rope + sliding_window: Control the number of tokens that are + processed in parallel + gradient_checkpointing: str: Specify whether to use gradient + checkpointing + use_scan_mlp: bool: Determine whether or not to use the + scan_mlp function + scan_mlp_chunk_size: int: Specify the chunk size of the scan + mlp + number_rep_kv: int: Specify the number of times to repeat + the key and value vectors + bits: Optional[int]: Specify the number of bits used for + quantization + axis_dims: Sequence[int]: Specify the dimension of each axis + axis_names: Sequence[str]: Specify the names of each axis in + the tensor + "mp"): Define the maximum position embeddings + **kwargs: Pass a variable number of keyword arguments to a + function + rope_scaling: Dict[str, Union[str, float]]: rope scaling + information + attention_dropout: float: Set the dropout rate for the + attention layer + initialization_of_moe: bool: initialization of moe needs to + disable some dynamic part's this boolean variable will + turn them off. + attention_bias: bool: when ever to use attention_bias :param : Define the number of layers in the model - :return: An instance of the class + Returns: + An instance of the class """ self.vocab_size = vocab_size self.max_position_embeddings = max_position_embeddings @@ -127,15 +147,18 @@ def __init__( ) def get_partition_rules(self, fully_sharded_data_parallel: bool = True): - """ - The get_partition_rules function is used to define the partitioning scheme for a model. + """The get_partition_rules function is used to define the partitioning scheme for a model. It returns a list of tuples, where each tuple contains two elements: 1) A regex string that matches the name of one or more parameters in the model. 2) A PartitionScheme object that defines how those parameters should be partitioned. - :param fully_sharded_data_parallel: bool: Determine whether to use the fully_sharded_data_parallel partitioning scheme or not - :return: A list of tuples + Args: + fully_sharded_data_parallel: bool: Determine whether to use + the fully_sharded_data_parallel partitioning scheme or + not + Returns: + A list of tuples """ return ( @@ -187,22 +210,31 @@ def add_jax_args( initialization_of_moe: bool = False, **kwargs, ): - """ - The add_jax_args function adds the following arguments to the model: - - :param self: Bind the attributes and methods of a class to an instance of that class - :param gradient_checkpointing: str: Determine whether to use gradient checkpointing - :param use_scan_mlp: bool: Determine whether to use the scan_mlp function or not - :param scan_mlp_chunk_size: int: Chunk the input to the mlp - :param number_rep_kv: int: Control the number of times that the key and value vectors are repeated - :param bits: Optional[int]: Specify the number of bits to use for quantization - :param attention_dropout: float: Set the dropout rate for the attention layer - :param attention_bias: bool: when ever to use attention_bias - :param initialization_of_moe: bool: initialization of moe needs to disable some dynamic part's this boolean - variable will turn them off. - :param rope_scaling: Dict[str, Union[str, float]]: rope_scaling for rope - :return: A tuple of the following: - + """The add_jax_args function adds the following arguments to the model: + + Args: + self: Bind the attributes and methods of a class to an + instance of that class + gradient_checkpointing: str: Determine whether to use + gradient checkpointing + use_scan_mlp: bool: Determine whether to use the scan_mlp + function or not + scan_mlp_chunk_size: int: Chunk the input to the mlp + number_rep_kv: int: Control the number of times that the key + and value vectors are repeated + bits: Optional[int]: Specify the number of bits to use for + quantization + attention_dropout: float: Set the dropout rate for the + attention layer + attention_bias: bool: when ever to use attention_bias + initialization_of_moe: bool: initialization of moe needs to + disable some dynamic part's this boolean variable will + turn them off. + rope_scaling: Dict[str, Union[str, float]]: rope_scaling for + rope + + Returns: + A tuple of the following: """ self.attention_dropout = attention_dropout self.attention_bias = attention_bias diff --git a/src/python/easydel/modules/mixtral/modelling_mixtral_flax.py b/src/python/easydel/modules/mixtral/modelling_mixtral_flax.py index af02433ef..02a4b3e85 100644 --- a/src/python/easydel/modules/mixtral/modelling_mixtral_flax.py +++ b/src/python/easydel/modules/mixtral/modelling_mixtral_flax.py @@ -187,23 +187,28 @@ def __call__( init_cache: bool = False, output_attentions: bool = True ): - """ - The __call__ function is the main function of a JAX module. + """The __call__ function is the main function of a JAX module. It defines how the module behaves when called as a function, and it's what you'll use to call your model in practice. The __call__ method takes an input tensor (x) and returns an output tensor (y). In this case, we're defining our model to be a simple linear layer with no activation: y = x @ w + b. - :param self: Refer to the object itself - :param hidden_states: chex.Array: Pass in the hidden state of the model - :param freq_cis: Tuple[chex.Array, chex.Array],: Create the apply_rotary variable - :param attention_mask: chex.Array: Mask the attention weights - :param causal_mask: chex.Array: Mask the attention weights - :param position_ids: chex.Array: Specify the position of each token in a sequence - :param deterministic: bool: Determine whether to use dropout or not - :param init_cache: bool: Initialize the cache - :param output_attentions: bool: Determine whether to return the attention weights - :return: A tuple of (out, attn_output) - + Args: + self: Refer to the object itself + hidden_states: chex.Array: Pass in the hidden state of the + model + freq_cis: Tuple[chex.Array, chex.Array],: Create the + apply_rotary variable + attention_mask: chex.Array: Mask the attention weights + causal_mask: chex.Array: Mask the attention weights + position_ids: chex.Array: Specify the position of each token + in a sequence + deterministic: bool: Determine whether to use dropout or not + init_cache: bool: Initialize the cache + output_attentions: bool: Determine whether to return the + attention weights + + Returns: + A tuple of (out, attn_output) """ batch_size, sequence_length = hidden_states.shape[:2] query_states, key_states, value_states = self.q_proj(hidden_states), self.k_proj(hidden_states), self.v_proj( @@ -392,8 +397,7 @@ def __call__( class FlaxMixtralSparseMoeBlock(nn.Module): - """ - This implementation is + """This implementation is strictly equivalent to standard MoE with full capacity (no dropped tokens). It's faster since it formulates MoE operations in terms of block-sparse operations to accomodate imbalanced @@ -529,23 +533,30 @@ def __call__( output_attentions: bool = True, output_router_logits: Optional[bool] = None, ): - """ - The __call__ function is the main function of a TransformerEncoderLayer. + """The __call__ function is the main function of a TransformerEncoderLayer. It takes in the following arguments: hidden_states (chex.Array): The input to the encoder layer, which is also its output after being processed by all sublayers. freq_cis (chex.Array): A tensor containing frequency-domain representations of each token's context vector, used for computing self-attention weights and biases in a more efficient manner than using position embeddings or sinusoidal positional encoding vectors would allow for [2]. This tensor has shape `(batch_size, num - :param self: Represent the instance of the class - :param hidden_states: chex.Array: Represent the input to the encoder layer - :param freq_cis: Tuple[chex.Array, chex.Array],: Pass the frequency information to the attention layer - :param attention_mask: chex.Array: Mask out the attention weights for certain positions - :param causal_mask: chex.Array: Mask the future tokens - :param position_ids: chex.Array: Indicate the position of each token in the sequence - :param deterministic: bool: Determine whether to use dropout or not - :param init_cache: bool: Initialize the cache for the self-attention layer - :param output_attentions: bool: Determine whether to return the attention weights or not - :return: A tuple of hidden_states and attention_output - + Args: + self: Represent the instance of the class + hidden_states: chex.Array: Represent the input to the + encoder layer + freq_cis: Tuple[chex.Array, chex.Array],: Pass the frequency + information to the attention layer + attention_mask: chex.Array: Mask out the attention weights + for certain positions + causal_mask: chex.Array: Mask the future tokens + position_ids: chex.Array: Indicate the position of each + token in the sequence + deterministic: bool: Determine whether to use dropout or not + init_cache: bool: Initialize the cache for the self- + attention layer + output_attentions: bool: Determine whether to return the + attention weights or not + + Returns: + A tuple of hidden_states and attention_output """ residual = hidden_states hidden_states = self.input_layernorm(hidden_states) @@ -620,23 +631,31 @@ def __call__( output_attentions: Optional[bool] = False, output_router_logits: Optional[bool] = None, ): - """ - The __call__ function is the main function of a TransformerEncoderLayer. + """The __call__ function is the main function of a TransformerEncoderLayer. It takes in the following arguments: hidden_states (chex.Array): The input to the encoder layer, which is also its output after being processed by all sublayers. freq_cis (chex.Array): A tensor containing frequency-domain representations of each token's context vector, used for computing self-attention weights and biases in a more efficient manner than using position embeddings or sinusoidal positional encoding vectors would allow for [2]. This tensor has shape `(batch_size, num - :param self: Represent the instance of the class - :param hidden_states: chex.Array: Represent the input to the encoder layer - :param freq_cis: Tuple[chex.Array, chex.Array],: Pass the frequency information to the attention layer - :param attention_mask: chex.Array: Mask out the attention weights for certain positions - :param causal_mask: chex.Array: Mask the future tokens - :param position_ids: chex.Array: Indicate the position of each token in the sequence - :param deterministic: bool: Determine whether to use dropout or not - :param init_cache: bool: Initialize the cache for the self-attention layer - :param output_attentions: bool: Determine whether to return the attention weights or not - :return: A tuple of hidden_states, attention_output, all_hidden_states and all_router_logits - + Args: + self: Represent the instance of the class + hidden_states: chex.Array: Represent the input to the + encoder layer + freq_cis: Tuple[chex.Array, chex.Array],: Pass the frequency + information to the attention layer + attention_mask: chex.Array: Mask out the attention weights + for certain positions + causal_mask: chex.Array: Mask the future tokens + position_ids: chex.Array: Indicate the position of each + token in the sequence + deterministic: bool: Determine whether to use dropout or not + init_cache: bool: Initialize the cache for the self- + attention layer + output_attentions: bool: Determine whether to return the + attention weights or not + + Returns: + A tuple of hidden_states, attention_output, + all_hidden_states and all_router_logits """ all_hidden_states = () if output_hidden_states else None all_self_attns = () if output_attentions else None @@ -714,17 +733,21 @@ def init_weights( input_shape: Tuple, params: FrozenDict = None ) -> FrozenDict: - """ - The init_weights function is used to initialize the weights of a model. + """The init_weights function is used to initialize the weights of a model. It takes in a rng, which is a random number generator key that can be used to generate random numbers. The input_shape parameter specifies the shape of the inputs that will be fed into this model. The params parameter allows you to pass in pre-trained weights for your model, if you have them available. - :param self: Access variables that belong to the class - :param rng: jax.random.PRNGKey: Initialize the weights of the model - :param input_shape: Tuple: Initialize the input_ids, attention_mask and position_ids - :param params: flax.core.FrozenDict: Pass in the parameters of a pre-trained model - :return: A frozendict of parameters + Args: + self: Access variables that belong to the class + rng: jax.random.PRNGKey: Initialize the weights of the model + input_shape: Tuple: Initialize the input_ids, attention_mask + and position_ids + params: flax.core.FrozenDict: Pass in the parameters of a + pre-trained model + + Returns: + A frozendict of parameters """ self.config.initialization_of_moe = True @@ -798,28 +821,35 @@ def __call__( add_params_field: bool = False, **kwargs ): - """ - The __call__ function is the main function of a JAX module. + """The __call__ function is the main function of a JAX module. It takes as input: - The parameters of the model (self.params) - The inputs to the model (input_ids, attention_mask, position_ids) - Whether we are training (train=True/False) and whether we want to return all hidden states and attentions weights at each layer in addition to just the last layer output (output_hidden_states=True/False). - :param self: Represent the instance of the class - :param input_ids: Pass the input sequence to the model - :param attention_mask: Mask out the padding tokens - :param position_ids: Specify the position of each token in the sequence - :param params: dict: Pass in the parameters of the model - :param past_key_values: dict: Pass the past key values to the model - :param dropout_rng: jax.random.PRNGKey: Pass in a random number generator key to the model - :param train: bool: Determine whether to use dropout or not - :param output_attentions: Optional[bool]: Determine whether to return the attention weights - :param output_hidden_states: Optional[bool]: Determine whether to return the hidden states of all layers - :param return_dict: Optional[bool]: Return a dictionary of the outputs - :param add_params_field: bool: Add a params field to the inputs dictionary - :return: A tuple of (last_hidden_state, past_key_values) - + Args: + self: Represent the instance of the class + input_ids: Pass the input sequence to the model + attention_mask: Mask out the padding tokens + position_ids: Specify the position of each token in the + sequence + params: dict: Pass in the parameters of the model + past_key_values: dict: Pass the past key values to the model + dropout_rng: jax.random.PRNGKey: Pass in a random number + generator key to the model + train: bool: Determine whether to use dropout or not + output_attentions: Optional[bool]: Determine whether to + return the attention weights + output_hidden_states: Optional[bool]: Determine whether to + return the hidden states of all layers + return_dict: Optional[bool]: Return a dictionary of the + outputs + add_params_field: bool: Add a params field to the inputs + dictionary + + Returns: + A tuple of (last_hidden_state, past_key_values) """ output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions @@ -1130,15 +1160,18 @@ def set_output_embeddings(self, new_embeddings): self.module.lm_head = new_embeddings def prepare_inputs_for_generation(self, input_ids, max_length, attention_mask: Optional[chex.Array] = None): - """ - The prepare_inputs_for_generation function is used to prepare the inputs for a generation task. - - :param self: Access variables that belong to the class - :param input_ids: Pass in the input tokens - :param max_length: Set the length of the sequence to be generated - :param attention_mask: Optional[chex.Array]: Mask the attention weights - :return: A dictionary of the past_key_values, attention_mask and position ids - + """The prepare_inputs_for_generation function is used to prepare the inputs for a generation task. + + Args: + self: Access variables that belong to the class + input_ids: Pass in the input tokens + max_length: Set the length of the sequence to be generated + attention_mask: Optional[chex.Array]: Mask the attention + weights + + Returns: + A dictionary of the past_key_values, attention_mask and + position ids """ batch_size, seq_length = input_ids.shape diff --git a/src/python/easydel/modules/mosaic_mpt/modelling_mpt_flax.py b/src/python/easydel/modules/mosaic_mpt/modelling_mpt_flax.py index 4bf8d52c2..1b142959f 100644 --- a/src/python/easydel/modules/mosaic_mpt/modelling_mpt_flax.py +++ b/src/python/easydel/modules/mosaic_mpt/modelling_mpt_flax.py @@ -143,21 +143,27 @@ def __call__( deterministic: bool = False ): - """ - The __call__ function is the main function of a JAX module. + """The __call__ function is the main function of a JAX module. It takes in inputs and returns outputs, just like any other Python function. The difference is that __call__ can also take in state (e.g., parameters) from the module itself, and it can update that state as part of its computation. - :param self: Access variables that belong to the class - :param hidden_states: chex.Array: Pass the input to the attention layer - :param attention_mask: chex.Array: Mask out certain positions in the sequence - :param position_bias: chex.Array: Add a bias to the attention scores - :param causal_mask: chex.Array: Mask out certain positions in the sequence - :param init_cache: bool: Initialize the cache - :param deterministic: bool: deterministic to activate dropouts and detect training process - :return: The output of the attention layer - + Args: + self: Access variables that belong to the class + hidden_states: chex.Array: Pass the input to the attention + layer + attention_mask: chex.Array: Mask out certain positions in + the sequence + position_bias: chex.Array: Add a bias to the attention + scores + causal_mask: chex.Array: Mask out certain positions in the + sequence + init_cache: bool: Initialize the cache + deterministic: bool: deterministic to activate dropouts and + detect training process + + Returns: + The output of the attention layer """ inp_shape = hidden_states.shape mixed_qkv = self.Wqkv(hidden_states) diff --git a/src/python/easydel/modules/olmo/olmo_configuration.py b/src/python/easydel/modules/olmo/olmo_configuration.py index 2e06f308e..2591f8265 100644 --- a/src/python/easydel/modules/olmo/olmo_configuration.py +++ b/src/python/easydel/modules/olmo/olmo_configuration.py @@ -24,9 +24,7 @@ class BlockType(StrEnum): class OLMoConfig(EasyDeLPretrainedConfig): - """ - OLMo (model) configuration. - """ + """OLMo (model) configuration.""" def __init__( self, @@ -113,15 +111,17 @@ def add_jax_args( self.gradient_checkpointing = gradient_checkpointing def get_partition_rules(self, fully_sharded_data_parallel: bool = True): - """ - The get_partition_rules function is used to define the partitioning scheme for a model. + """The get_partition_rules function is used to define the partitioning scheme for a model. It returns a list of tuples, where each tuple contains two elements: 1) A regex string that matches the name of one or more parameters in the model. 2) A PartitionScheme object that defines how those parameters should be partitioned across devices. - :param fully_sharded_data_parallel: bool: Determine whether to partition the model fully or not - :return: A list of tuples + Args: + fully_sharded_data_parallel: bool: Determine whether to + partition the model fully or not + Returns: + A list of tuples """ return ( diff --git a/src/python/easydel/modules/openelm/modelling_openelm_flax.py b/src/python/easydel/modules/openelm/modelling_openelm_flax.py index 5fddc7691..b41af3bf9 100644 --- a/src/python/easydel/modules/openelm/modelling_openelm_flax.py +++ b/src/python/easydel/modules/openelm/modelling_openelm_flax.py @@ -191,33 +191,37 @@ def _merge_heads(self, hidden_states): @staticmethod def _transpose_sequence_head(query, key, value): - """ - The _transpose_sequence_head function transposes the query, key and value matrices. + """The _transpose_sequence_head function transposes the query, key and value matrices. - :param query: Get the attention weights for each of the heads - :param key: Determine the number of heads - :param value: Store the values of the input - :return: The transpose of the query, key and value matrices + Args: + query: Get the attention weights for each of the heads + key: Determine the number of heads + value: Store the values of the input + Returns: + The transpose of the query, key and value matrices """ return jnp.transpose(query, (0, 2, 1, 3)), jnp.transpose(key, (0, 2, 1, 3)), jnp.transpose(value, (0, 2, 1, 3)) def apply_rotary(self, batch_size, sequence_length, query, key, value, freq_cis, position_ids): - """ - The apply_rotary function is a modified version of the apply_attention function in the BertModel class. + """The apply_rotary function is a modified version of the apply_attention function in the BertModel class. The main difference is that it takes in an additional argument, freq_cis, which are used to calculate the rotary attention weights. The other differences are minor and mostly related to reshaping tensors. - :param self: Access variables that belong to the class - :param batch_size: Reshape the query, key and value tensors - :param sequence_length: Reshape the query, key and value tensors - :param query: Calculate the attention weights - :param key: Calculate the attention - :param value: Compute the attention weights - :param freq_cis: Calculate the frequency of each word in the vocabulary - :param position_ids: Identify the position of each token in the sequence - :return: A tuple of 3 tensors: query, key and value - + Args: + self: Access variables that belong to the class + batch_size: Reshape the query, key and value tensors + sequence_length: Reshape the query, key and value tensors + query: Calculate the attention weights + key: Calculate the attention + value: Compute the attention weights + freq_cis: Calculate the frequency of each word in the + vocabulary + position_ids: Identify the position of each token in the + sequence + + Returns: + A tuple of 3 tensors: query, key and value """ query = query.reshape( batch_size, @@ -257,25 +261,32 @@ def __call__( output_attentions: bool = False, fcm_mask=None, ): - """ - - The __call__ function is the main function of a JAX module. It defines how the module behaves when called + """The __call__ function is the main function of a JAX module. It defines how the module behaves when called with inputs. The __call__ function can be thought of as a "forward pass" through the model, and it should return all outputs that are needed for training or inference. - :param self: Access variables that belong to the class - :param hidden_states: chex.Array: Pass the hidden states of the previous layer - :param freq_cis: Tuple[chex.Array, chex.Array],: Pass in the frequency coefficients for each position - :param attention_mask: chex.Array: Mask out certain tokens in the input sequence - :param position_ids: chex.Array: Determine the position of each token in a sequence - :param causal_mask: chex.Array: Mask out the future tokens in the decoder - :param deterministic: bool: Determine whether to use dropout or not - :param init_cache: bool: Initialize the cache - :param output_attentions: bool: Determine whether to return the attention weights or not - :param fcm_mask: Mask out the attention weights between the input and output tokens + Args: + self: Access variables that belong to the class + hidden_states: chex.Array: Pass the hidden states of the + previous layer + freq_cis: Tuple[chex.Array, chex.Array],: Pass in the + frequency coefficients for each position + attention_mask: chex.Array: Mask out certain tokens in the + input sequence + position_ids: chex.Array: Determine the position of each + token in a sequence + causal_mask: chex.Array: Mask out the future tokens in the + decoder + deterministic: bool: Determine whether to use dropout or not + init_cache: bool: Initialize the cache + output_attentions: bool: Determine whether to return the + attention weights or not + fcm_mask: Mask out the attention weights between the input + and output tokens :param : Determine if the attention is causal or not - :return: A tuple of two arrays + Returns: + A tuple of two arrays """ batch_size, sequence_length = hidden_states.shape[:2] output_attentions = False @@ -728,25 +739,31 @@ def __call__( output_hidden_states: bool = False, return_dict: bool = True, ) -> typing.Union[Tuple[Array, ...], FlaxBaseModelOutput]: - """ - The __call__ function is the main function of a Flax model. + """The __call__ function is the main function of a Flax model. It takes in input_ids, attention_mask, and position_ids as inputs to the model. The output is a tuple containing: last hidden state (hidden states), all hidden states (if output_hidden_states=True), attentions (if output attentions=True). - - :param self: Represent the instance of the class - :param input_ids: chex.Array: Pass in the input ids - :param attention_mask: chex.Array: Mask out the attention weights for certain tokens - :param position_ids: chex.Array: Determine the position of each token in a sequence - :param deterministic: bool: Determine whether to use dropout or not - :param inputs_embeds: chex.Array: Pass in the embedding of the input_ids - :param init_cache: bool: Initialize the cache for the decoder - :param output_attentions: bool: Determine whether to return the attention weights or not - :param output_hidden_states: bool: Return all hidden states or just the last one - :param return_dict: bool: Return a dictionary of the outputs or not + Args: + self: Represent the instance of the class + input_ids: chex.Array: Pass in the input ids + attention_mask: chex.Array: Mask out the attention weights + for certain tokens + position_ids: chex.Array: Determine the position of each + token in a sequence + deterministic: bool: Determine whether to use dropout or not + inputs_embeds: chex.Array: Pass in the embedding of the + input_ids + init_cache: bool: Initialize the cache for the decoder + output_attentions: bool: Determine whether to return the + attention weights or not + output_hidden_states: bool: Return all hidden states or just + the last one + return_dict: bool: Return a dictionary of the outputs or not :param : Determine whether the model is in training mode or not - :return: A tuple of the hidden states, all hidden states, and attentions + Returns: + A tuple of the hidden states, all hidden states, and + attentions """ if inputs_embeds is None: inputs_embeds = self.token_embeddings(input_ids.astype("i4")) @@ -819,18 +836,21 @@ def init_weights( input_shape: Tuple, params: flax.core.FrozenDict = None ) -> flax.core.FrozenDict: - """ - The init_weights function is used to initialize the weights of a model. + """The init_weights function is used to initialize the weights of a model. It takes in a rng, which is a random number generator key that can be used to generate random numbers. The input_shape parameter specifies the shape of the inputs that will be fed into this model. The params parameter allows you to pass in pre-trained weights for your model, if you have them available. - :param self: Access variables that belong to the class - :param rng: jax.random.PRNGKey: Initialize the weights of the model - :param input_shape: Tuple: Initialize the input_ids, attention_mask and position_ids - :param params: flax.core.FrozenDict: Pass in the parameters of a pre-trained model - :return: A frozendict of parameters + Args: + self: Access variables that belong to the class + rng: jax.random.PRNGKey: Initialize the weights of the model + input_shape: Tuple: Initialize the input_ids, attention_mask + and position_ids + params: flax.core.FrozenDict: Pass in the parameters of a + pre-trained model + Returns: + A frozendict of parameters """ input_ids = jnp.zeros(input_shape, dtype="i4") attention_mask = jnp.ones_like(input_ids) @@ -903,28 +923,35 @@ def __call__( add_params_field: bool = False, **kwargs ): - """ - The __call__ function is the main function of a JAX module. + """The __call__ function is the main function of a JAX module. It takes as input: - The parameters of the model (self.params) - The inputs to the model (input_ids, attention_mask, position_ids) - Whether we are training (train=True/False) and whether we want to return all hidden states and attentions weights at each layer in addition to just the last layer output (output_hidden_states=True/False). - :param self: Represent the instance of the class - :param input_ids: Pass the input sequence to the model - :param attention_mask: Mask out the padding tokens - :param position_ids: Specify the position of each token in the sequence - :param params: dict: Pass in the parameters of the model - :param past_key_values: dict: Pass the past key values to the model - :param dropout_rng: jax.random.PRNGKey: Pass in a random number generator key to the model - :param train: bool: Determine whether to use dropout or not - :param output_attentions: Optional[bool]: Determine whether to return the attention weights - :param output_hidden_states: Optional[bool]: Determine whether to return the hidden states of all layers - :param return_dict: Optional[bool]: Return a dictionary of the outputs - :param add_params_field: bool: Add a params field to the inputs dictionary - :return: A tuple of (last_hidden_state, past_key_values) - + Args: + self: Represent the instance of the class + input_ids: Pass the input sequence to the model + attention_mask: Mask out the padding tokens + position_ids: Specify the position of each token in the + sequence + params: dict: Pass in the parameters of the model + past_key_values: dict: Pass the past key values to the model + dropout_rng: jax.random.PRNGKey: Pass in a random number + generator key to the model + train: bool: Determine whether to use dropout or not + output_attentions: Optional[bool]: Determine whether to + return the attention weights + output_hidden_states: Optional[bool]: Determine whether to + return the hidden states of all layers + return_dict: Optional[bool]: Return a dictionary of the + outputs + add_params_field: bool: Add a params field to the inputs + dictionary + + Returns: + A tuple of (last_hidden_state, past_key_values) """ output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions output_hidden_states = ( @@ -1026,26 +1053,32 @@ def __call__( output_hidden_states: bool = False, return_dict: bool = True, ): - """ - The __call__ function is the main function of a Flax module. It defines how the model will be called, - and what it returns. In this case, we are calling our Transformer model with input_ids and attention_mask - as inputs (these are defined in __init__). We also have some optional arguments that can be passed to - the call function: deterministic (whether to use dropout), inputs_embeds (if you want to pass your own embeddings), - output_attentions and output_hidden states which return additional outputs from the transformer layers if set True. Finally, - - :param self: Refer to the object itself - :param input_ids: chex.Array: Pass in the input tokens - :param attention_mask: chex.Array: Mask out the padding tokens - :param position_ids: chex.Array: Specify the position of each token in the sequence - :param deterministic: bool: Determine whether to use dropout in the model - :param inputs_embeds: chex.Array: Pass in the embeddings of the input tokens - :param init_cache: bool: Initialize the cache for the decoder - :param output_attentions: bool: Return the attention weights - :param output_hidden_states: bool: Return the hidden states of all layers - :param return_dict: bool: Return a dictionary of the outputs or just the logits - :param : Determine whether to return the logits or not - :return: A tuple of (lm_logits, hidden_states, attentions) - + """The __call__ function is the main function of a Flax module. It defines how the model will be called, + and what it returns. In this case, we are calling our Transformer model with input_ids and attention_mask + as inputs (these are defined in __init__). We also have some optional arguments that can be passed to + the call function: deterministic (whether to use dropout), inputs_embeds (if you want to pass your own embeddings), + output_attentions and output_hidden states which return additional outputs from the transformer layers if set True. Finally, + + Args: + self: Refer to the object itself + input_ids: chex.Array: Pass in the input tokens + attention_mask: chex.Array: Mask out the padding tokens + position_ids: chex.Array: Specify the position of each token + in the sequence + deterministic: bool: Determine whether to use dropout in the + model + inputs_embeds: chex.Array: Pass in the embeddings of the + input tokens + init_cache: bool: Initialize the cache for the decoder + output_attentions: bool: Return the attention weights + output_hidden_states: bool: Return the hidden states of all + layers + return_dict: bool: Return a dictionary of the outputs or + just the logits + :param : Determine whether to return the logits or not + + Returns: + A tuple of (lm_logits, hidden_states, attentions) """ batch_size, seq_length = input_ids.shape diff --git a/src/python/easydel/modules/openelm/openelm_configuration.py b/src/python/easydel/modules/openelm/openelm_configuration.py index e3fe6c857..675ac099d 100644 --- a/src/python/easydel/modules/openelm/openelm_configuration.py +++ b/src/python/easydel/modules/openelm/openelm_configuration.py @@ -10,8 +10,7 @@ def make_divisible( divisor: Optional[int] = 8, min_value: Optional[Union[float, int]] = None, ) -> Union[float, int]: - """ - This function is taken from the original tf repo. + """This function is taken from the original tf repo. It ensures that all layers have a channel number that is divisible by the divisor It can be seen at: https://github.com/tensorflow/models/blob/2cfc99eff5e5eb729c6793d2f3d03aa1c9be2b15/research/slim/nets/mobilenet/mobilenet.py#L62 @@ -82,43 +81,61 @@ def __init__( bits: Optional[int] = None, **kwargs, ): - """ - The __init__ function is called when the class is instantiated. + """The __init__ function is called when the class is instantiated. It allows the class to initialize the attributes of a class. The self parameter is a reference to the current instance of the class, and is used to access variables that belong to the class. - :param self: Represent the instance of the class - :param vocab_size: Define the size of the vocabulary - :param hidden_size: Determine the size of the embedding layers - :param intermediate_size: Define the size of the intermediate layer in each transformer block - :param num_hidden_layers: Determine the number of layers in the encoder and decoder - :param num_attention_heads: Determine the number of attention heads in each layer - :param num_key_value_heads: Specify the number of heads for key and value - :param hidden_act: Specify the activation function used in the hidden layers - :param max_position_embeddings: Set the maximum length of the sequence - :param initializer_range: Initialize the weights of the model - :param rms_norm_eps: Avoid division by zero in the rms normalization - :param use_cache: Determine whether to use the cache in the decoder - :param pad_token_id: Specify the token id of the padding token - :param bos_token_id: Specify the beginning of sentence token id - :param eos_token_id: Specify the end of sentence token - :param tie_word_embeddings: Tie the word embeddings and the output layer - :param rope_theta: Control the number of tokens in a rope - :param sliding_window: Control the number of tokens that are processed in parallel - :param gradient_checkpointing: str: Specify whether to use gradient checkpointing - :param use_scan_mlp: bool: Determine whether or not to use the scan_mlp function - :param scan_mlp_chunk_size: int: Specify the chunk size of the scan mlp - :param number_rep_kv: int: Specify the number of times to repeat the key and value vectors - :param attention_dropout: float: Set the dropout rate for the attention layer - :param bits: Optional[int]: Specify the number of bits used for quantization - :param axis_dims: Sequence[int]: Specify the dimension of each axis - :param axis_names: Sequence[str]: Specify the names of each axis in the tensor - :param "mp"): Define the maximum position embeddings - :param attention_bias: bool: when ever to use attention_bias - :param kwargs: Pass a variable number of keyword arguments to a function + Args: + self: Represent the instance of the class + vocab_size: Define the size of the vocabulary + hidden_size: Determine the size of the embedding layers + intermediate_size: Define the size of the intermediate layer + in each transformer block + num_hidden_layers: Determine the number of layers in the + encoder and decoder + num_attention_heads: Determine the number of attention heads + in each layer + num_key_value_heads: Specify the number of heads for key and + value + hidden_act: Specify the activation function used in the + hidden layers + max_position_embeddings: Set the maximum length of the + sequence + initializer_range: Initialize the weights of the model + rms_norm_eps: Avoid division by zero in the rms + normalization + use_cache: Determine whether to use the cache in the decoder + pad_token_id: Specify the token id of the padding token + bos_token_id: Specify the beginning of sentence token id + eos_token_id: Specify the end of sentence token + tie_word_embeddings: Tie the word embeddings and the output + layer + rope_theta: Control the number of tokens in a rope + sliding_window: Control the number of tokens that are + processed in parallel + gradient_checkpointing: str: Specify whether to use gradient + checkpointing + use_scan_mlp: bool: Determine whether or not to use the + scan_mlp function + scan_mlp_chunk_size: int: Specify the chunk size of the scan + mlp + number_rep_kv: int: Specify the number of times to repeat + the key and value vectors + attention_dropout: float: Set the dropout rate for the + attention layer + bits: Optional[int]: Specify the number of bits used for + quantization + axis_dims: Sequence[int]: Specify the dimension of each axis + axis_names: Sequence[str]: Specify the names of each axis in + the tensor + "mp"): Define the maximum position embeddings + attention_bias: bool: when ever to use attention_bias + **kwargs: Pass a variable number of keyword arguments to a + function :param : Define the number of layers in the model - :return: An instance of the class + Returns: + An instance of the class """ self.vocab_size = vocab_size self.max_context_length = max_context_length @@ -161,15 +178,18 @@ def __init__( ) def get_partition_rules(self, fully_sharded_data_parallel: bool = True): - """ - The get_partition_rules function is used to define the partitioning scheme for a model. + """The get_partition_rules function is used to define the partitioning scheme for a model. It returns a list of tuples, where each tuple contains two elements: 1) A regex string that matches the name of one or more parameters in the model. 2) A PartitionScheme object that defines how those parameters should be partitioned. - :param fully_sharded_data_parallel: bool: Determine whether to use the fully_sharded_data_parallel partitioning scheme or not - :return: A list of tuples + Args: + fully_sharded_data_parallel: bool: Determine whether to use + the fully_sharded_data_parallel partitioning scheme or + not + Returns: + A list of tuples """ return ( @@ -215,17 +235,23 @@ def add_jax_args( rope_scaling: Dict[str, Union[str, float]] = None, **kwargs, ): - """ - The add_jax_args function adds the following arguments to the model: - - :param self: Bind the attributes and methods of a class to an instance of that class - :param gradient_checkpointing: str: Determine whether to use gradient checkpointing - :param use_scan_mlp: bool: Determine whether to use the scan_mlp function or notn - :param scan_mlp_chunk_size: int: Chunk the input to the mlp - :param bits: Optional[int]: Specify the number of bits to use for quantization - :param rope_scaling: Dict[str, Union[str, float]]: rope_scaling for rope - :return: A tuple of the following: - + """The add_jax_args function adds the following arguments to the model: + + Args: + self: Bind the attributes and methods of a class to an + instance of that class + gradient_checkpointing: str: Determine whether to use + gradient checkpointing + use_scan_mlp: bool: Determine whether to use the scan_mlp + function or notn + scan_mlp_chunk_size: int: Chunk the input to the mlp + bits: Optional[int]: Specify the number of bits to use for + quantization + rope_scaling: Dict[str, Union[str, float]]: rope_scaling for + rope + + Returns: + A tuple of the following: """ self.rope_scaling = rope_scaling diff --git a/src/python/easydel/modules/opt/modelling_opt_flax.py b/src/python/easydel/modules/opt/modelling_opt_flax.py index 38145c257..05c4f80d7 100644 --- a/src/python/easydel/modules/opt/modelling_opt_flax.py +++ b/src/python/easydel/modules/opt/modelling_opt_flax.py @@ -13,7 +13,7 @@ # See the License for the specific language governing permissions and # limitations under the License. # THIS SCRIPT IS EDITED FROM ORIGINAL IMPLEMENTATION OF TRANSFORMERS OPT -""" Flax OPT model.""" +"""Flax OPT model.""" from functools import partial diff --git a/src/python/easydel/modules/phi/modelling_phi_flax.py b/src/python/easydel/modules/phi/modelling_phi_flax.py index cc62f6225..fc68161af 100644 --- a/src/python/easydel/modules/phi/modelling_phi_flax.py +++ b/src/python/easydel/modules/phi/modelling_phi_flax.py @@ -180,34 +180,40 @@ def _merge_heads(self, hidden_states): @staticmethod def _transpose_sequence_head(query_states, key, value): - """ - The _transpose_sequence_head function transposes the query_states, key and value matrices. + """The _transpose_sequence_head function transposes the query_states, key and value matrices. - :param query_states: Get the attention weights for each of the heads - :param key: Determine the number of heads - :param value: Store the values of the input - :return: The transpose of the query_states, key and value matrices + Args: + query_states: Get the attention weights for each of the + heads + key: Determine the number of heads + value: Store the values of the input + Returns: + The transpose of the query_states, key and value matrices """ return jnp.transpose(query_states, (0, 2, 1, 3)), jnp.transpose(key, (0, 2, 1, 3)), jnp.transpose(value, (0, 2, 1, 3)) def apply_rotary(self, batch_size, sequence_length, query, key, value, freq_cis, position_ids): - """ - The apply_rotary function is a modified version of the apply_attention function in the BertModel class. + """The apply_rotary function is a modified version of the apply_attention function in the BertModel class. The main difference is that it takes in an additional argument, freq_cis, which are used to calculate the rotary attention weights. The other differences are minor and mostly related to reshaping tensors. - :param self: Access variables that belong to the class - :param batch_size: Reshape the query_states, key and value tensors - :param sequence_length: Reshape the query_states, key and value tensors - :param query: Calculate the attention weights - :param key: Calculate the attention - :param value: Compute the attention weights - :param freq_cis: Calculate the frequency of each word in the vocabulary - :param position_ids: Identify the position of each token in the sequence - :return: A tuple of 3 tensors: query_states, key and value - + Args: + self: Access variables that belong to the class + batch_size: Reshape the query_states, key and value tensors + sequence_length: Reshape the query_states, key and value + tensors + query: Calculate the attention weights + key: Calculate the attention + value: Compute the attention weights + freq_cis: Calculate the frequency of each word in the + vocabulary + position_ids: Identify the position of each token in the + sequence + + Returns: + A tuple of 3 tensors: query_states, key and value """ query = query.reshape( batch_size, diff --git a/src/python/easydel/modules/phi3/modelling_phi3_flax.py b/src/python/easydel/modules/phi3/modelling_phi3_flax.py index 8b4fbed02..e1f3b7bfa 100644 --- a/src/python/easydel/modules/phi3/modelling_phi3_flax.py +++ b/src/python/easydel/modules/phi3/modelling_phi3_flax.py @@ -196,34 +196,40 @@ def _merge_heads(self, hidden_states): @staticmethod def _transpose_sequence_head(query_states, key, value): - """ - The _transpose_sequence_head function transposes the query_states, key and value matrices. + """The _transpose_sequence_head function transposes the query_states, key and value matrices. - :param query_states: Get the attention weights for each of the heads - :param key: Determine the number of heads - :param value: Store the values of the input - :return: The transpose of the query_states, key and value matrices + Args: + query_states: Get the attention weights for each of the + heads + key: Determine the number of heads + value: Store the values of the input + Returns: + The transpose of the query_states, key and value matrices """ return jnp.transpose(query_states, (0, 2, 1, 3)), jnp.transpose(key, (0, 2, 1, 3)), jnp.transpose(value, (0, 2, 1, 3)) def apply_rotary(self, batch_size, sequence_length, query, key, value, freq_cis, position_ids): - """ - The apply_rotary function is a modified version of the apply_attention function in the BertModel class. + """The apply_rotary function is a modified version of the apply_attention function in the BertModel class. The main difference is that it takes in an additional argument, freq_cis, which are used to calculate the rotary attention weights. The other differences are minor and mostly related to reshaping tensors. - :param self: Access variables that belong to the class - :param batch_size: Reshape the query_states, key and value tensors - :param sequence_length: Reshape the query_states, key and value tensors - :param query: Calculate the attention weights - :param key: Calculate the attention - :param value: Compute the attention weights - :param freq_cis: Calculate the frequency of each word in the vocabulary - :param position_ids: Identify the position of each token in the sequence - :return: A tuple of 3 tensors: query_states, key and value - + Args: + self: Access variables that belong to the class + batch_size: Reshape the query_states, key and value tensors + sequence_length: Reshape the query_states, key and value + tensors + query: Calculate the attention weights + key: Calculate the attention + value: Compute the attention weights + freq_cis: Calculate the frequency of each word in the + vocabulary + position_ids: Identify the position of each token in the + sequence + + Returns: + A tuple of 3 tensors: query_states, key and value """ query = query.reshape( batch_size, diff --git a/src/python/easydel/modules/phi3/phi3_configuration.py b/src/python/easydel/modules/phi3/phi3_configuration.py index 61bce45e8..87f4d1447 100644 --- a/src/python/easydel/modules/phi3/phi3_configuration.py +++ b/src/python/easydel/modules/phi3/phi3_configuration.py @@ -118,9 +118,7 @@ def get_partition_rules(self, fully_sharded_data_parallel: bool = True): ) def _rope_scaling_validation(self): - """ - Validate the `rope_scaling` configuration. - """ + """Validate the `rope_scaling` configuration.""" if self.rope_scaling is None: return diff --git a/src/python/easydel/modules/qwen1/modelling_qwen1_flax.py b/src/python/easydel/modules/qwen1/modelling_qwen1_flax.py index cfc8bcb95..34a679a67 100644 --- a/src/python/easydel/modules/qwen1/modelling_qwen1_flax.py +++ b/src/python/easydel/modules/qwen1/modelling_qwen1_flax.py @@ -177,16 +177,18 @@ def setup(self) -> None: ) def __call__(self, x: jnp.ndarray, deterministic: bool = True) -> jnp.ndarray: - """ - The __call__ function is the main function of a class. + """The __call__ function is the main function of a class. It is called when an instance of the class (an object) is invoked as a function, i.e., obj(arguments). The __call__ method enables instances of a class to be called like standard Python functions. - :param self: Represent the instance of the class - :param x: jnp.ndarray: Pass in the input to the layer - :param deterministic: bool: Determine whether to use dropout - :return: A tensor that is the result of applying a dropout function to x + Args: + self: Represent the instance of the class + x: jnp.ndarray: Pass in the input to the layer + deterministic: bool: Determine whether to use dropout + Returns: + A tensor that is the result of applying a dropout function + to x """ x = self.c_proj(jax.nn.silu(self.w2(x)) * self.w1(x)) return x @@ -278,33 +280,37 @@ def _merge_heads(self, hidden_states): @staticmethod def _transpose_sequence_head(query, key, value): - """ - The _transpose_sequence_head function transposes the query, key and value matrices. + """The _transpose_sequence_head function transposes the query, key and value matrices. - :param query: Get the attention weights for each of the heads - :param key: Determine the number of heads - :param value: Store the values of the input - :return: The transpose of the query, key and value matrices + Args: + query: Get the attention weights for each of the heads + key: Determine the number of heads + value: Store the values of the input + Returns: + The transpose of the query, key and value matrices """ return jnp.transpose(query, (0, 2, 1, 3)), jnp.transpose(key, (0, 2, 1, 3)), jnp.transpose(value, (0, 2, 1, 3)) def apply_rotary(self, batch_size, sequence_length, query, key, value, rotary_pos_emb_list, position_ids): - """ - The apply_rotary function is a modified version of the apply_attention function in the BertModel class. + """The apply_rotary function is a modified version of the apply_attention function in the BertModel class. The main difference is that it takes in an additional argument, rotary_pos_emb_list, which are used to calculate the rotary attention weights. The other differences are minor and mostly related to reshaping tensors. - :param self: Access variables that belong to the class - :param batch_size: Reshape the query, key and value tensors - :param sequence_length: Reshape the query, key and value tensors - :param query_states: Calculate the attention weights - :param key: Calculate the attention - :param value: Compute the attention weights - :param rotary_pos_emb_list: Calculate the frequency of each word in the vocabulary - :param position_ids: Identify the position of each token in the sequence - :return: A tuple of 3 tensors: query_states, key and value - + Args: + self: Access variables that belong to the class + batch_size: Reshape the query, key and value tensors + sequence_length: Reshape the query, key and value tensors + query_states: Calculate the attention weights + key: Calculate the attention + value: Compute the attention weights + rotary_pos_emb_list: Calculate the frequency of each word in + the vocabulary + position_ids: Identify the position of each token in the + sequence + + Returns: + A tuple of 3 tensors: query_states, key and value """ query_states, key = self.rotary( position_ids=position_ids, query_states=query_states, key=key, rotary_pos_emb_list=rotary_pos_emb_list @@ -326,25 +332,32 @@ def __call__( encoder_attention_mask: Optional[chex.Array] = None, fcm_mask=None, ): - """ - - The __call__ function is the main function of a JAX module. It defines how the module behaves when called + """The __call__ function is the main function of a JAX module. It defines how the module behaves when called with inputs. The __call__ function can be thought of as a "forward pass" through the model, and it should return all outputs that are needed for training or inference. - :param self: Access variables that belong to the class - :param hidden_states: chex.Array: Pass the hidden states of the previous layer - :param rotary_pos_emb_list: list[chex.Array]: Pass in the frequency coefficients for each position - :param attention_mask: chex.Array: Mask out certain tokens in the input sequence - :param position_ids: chex.Array: Determine the position of each token in a sequence - :param causal_mask: chex.Array: Mask out the future tokens in the decoder - :param deterministic: bool: Determine whether to use dropout or not - :param init_cache: bool: Initialize the cache - :param output_attentions: bool: Determine whether to return the attention weights or not - :param fcm_mask: Mask out the attention weights between the input and output tokens + Args: + self: Access variables that belong to the class + hidden_states: chex.Array: Pass the hidden states of the + previous layer + rotary_pos_emb_list: list[chex.Array]: Pass in the frequency + coefficients for each position + attention_mask: chex.Array: Mask out certain tokens in the + input sequence + position_ids: chex.Array: Determine the position of each + token in a sequence + causal_mask: chex.Array: Mask out the future tokens in the + decoder + deterministic: bool: Determine whether to use dropout or not + init_cache: bool: Initialize the cache + output_attentions: bool: Determine whether to return the + attention weights or not + fcm_mask: Mask out the attention weights between the input + and output tokens :param : Determine if the attention is causal or not - :return: A tuple of two arrays + Returns: + A tuple of two arrays """ batch_size, sequence_length = hidden_states.shape[:2] mixed_x_layer: chex.Array = self.c_attn(hidden_states) @@ -517,25 +530,32 @@ def __call__( encoder_attention_mask: Optional[chex.Array] = None, fcm_mask: Optional[jnp.ndarray] = None, ): - """ - The __call__ function is the main function of a TransformerEncoderLayer. + """The __call__ function is the main function of a TransformerEncoderLayer. It takes in hidden states, frequency-domain inputs, and masks as input. It then applies self-attention to the hidden states using those inputs and returns an output tensor with shape (batch_size, sequence_length, model_dim). - :param self: Refer to the class instance itself - :param hidden_states: chex.Array: Pass in the hidden state of the previous layer - :param rotary_pos_emb_list: list[chex.Array]: Pass in the frequency information - :param attention_mask: chex.Array: Mask out the attention weights for padding tokens - :param position_ids: chex.Array: Determine the position of each token in the sequence - :param causal_mask: chex.Array: Mask the attention weights - :param deterministic: bool: Control whether the dropout is applied or not - :param init_cache: bool: Initialize the cache in the attention layer - :param output_attentions: bool: Return the attention weights - :param fcm_mask: Optional[jnp.ndarray]: Mask the self-attention + Args: + self: Refer to the class instance itself + hidden_states: chex.Array: Pass in the hidden state of the + previous layer + rotary_pos_emb_list: list[chex.Array]: Pass in the frequency + information + attention_mask: chex.Array: Mask out the attention weights + for padding tokens + position_ids: chex.Array: Determine the position of each + token in the sequence + causal_mask: chex.Array: Mask the attention weights + deterministic: bool: Control whether the dropout is applied + or not + init_cache: bool: Initialize the cache in the attention + layer + output_attentions: bool: Return the attention weights + fcm_mask: Optional[jnp.ndarray]: Mask the self-attention :param : Control the dropout in the self attention layer - :return: A tuple of two items + Returns: + A tuple of two items """ # hidden_states: chex.Array # rotary_pos_emb_list: list[chex.Array] @@ -616,36 +636,41 @@ def __init__( _do_init: bool = True, **kwargs, ): - """ - The __init__ function is called when the class is instantiated. + """The __init__ function is called when the class is instantiated. It sets up the instance of the class, and defines what happens when it's created. The __init__ function can take arguments, but self is always required (it refers to the instance of the object). - - :param self: Refer to the object itself - :param config: Qwen1Config: Pass the configuration to the module - :param input_shape: Tuple: Specify the shape of the input to the model - :param seed: int: Set the seed for random number generation - :param dtype: jnp.dtype: Specify the data type of the input - :param _do_init: bool: Control whether the module is initialized or not - :param kwargs: Pass in any additional parameters that the module_class might need + Args: + self: Refer to the object itself + config: Qwen1Config: Pass the configuration to the module + input_shape: Tuple: Specify the shape of the input to the + model + seed: int: Set the seed for random number generation + dtype: jnp.dtype: Specify the data type of the input + _do_init: bool: Control whether the module is initialized or + not + **kwargs: Pass in any additional parameters that the + module_class might need :param : Specify the number of h in the network - :return: The super() of the class + Returns: + The super() of the class """ module = self.module_class(config=config, dtype=dtype, **kwargs) super().__init__(config, module, input_shape=input_shape, seed=seed, dtype=dtype, _do_init=_do_init) def init_weights(self, rng: jax.random.PRNGKey, input_shape: Tuple, params: FrozenDict = None) -> FrozenDict: - """ - The init_weights function is used to initialize the weights of a model. + """The init_weights function is used to initialize the weights of a model. - :param self: Access variables that belong to the class - :param rng: jax.random.PRNGKey: Initialize the weights of the model - :param input_shape: Tuple: Specify the shape of the input tensor - :param params: FrozenDict: Pass in the parameters of a pre-trained model - :return: A frozendict of parameters + Args: + self: Access variables that belong to the class + rng: jax.random.PRNGKey: Initialize the weights of the model + input_shape: Tuple: Specify the shape of the input tensor + params: FrozenDict: Pass in the parameters of a pre-trained + model + Returns: + A frozendict of parameters """ input_ids = jnp.zeros(input_shape, dtype="i4") attention_mask = jnp.ones_like(input_ids) @@ -683,17 +708,18 @@ def init_weights(self, rng: jax.random.PRNGKey, input_shape: Tuple, params: Froz return random_params def init_cache(self, batch_size, max_length): - """ - The init_cache function is used to initialize the cache for a given batch size and sequence length. + """The init_cache function is used to initialize the cache for a given batch size and sequence length. The cache is a dictionary that contains all the intermediate states from each layer in the model. This allows us to run inference on multiple batches without having to re-run forward passes through every layer in the model, which would be very slow. - :param self: Access the module - :param batch_size: Define the batch size of the input tensors - :param max_length: Set the length of the input sequence - :return: A dictionary with the following keys: + Args: + self: Access the module + batch_size: Define the batch size of the input tensors + max_length: Set the length of the input sequence + Returns: + A dictionary with the following keys: """ input_ids = jnp.ones((batch_size, max_length)) attention_mask = jnp.ones_like(input_ids) @@ -741,27 +767,36 @@ def __call__( add_params_field: bool = False, **kwargs ): - """ - The __call__ function is the main function of a JAX module. + """The __call__ function is the main function of a JAX module. It takes in inputs and returns outputs, but it also has some other important features: - It can take in mutable state (e.g., past_key_values) that will be updated during the call and returned at the end. - It can take in random number generators (rngs) that are used to generate random numbers for dropout or sampling operations. - :param self: Represent the instance of the class - :param input_ids: chex.Array: Pass in the input tokens - :param attention_mask: chex.Array: Mask out certain tokens in the input - :param position_ids: chex.Array: Create the positional embeddings - :param params: dict: Pass in the parameters of the model - :param past_key_values: dict: Pass in the past key values from a previous call to __call__ - :param dropout_rng: jax.random.PRNGKey: Make sure that the dropout is applied in a random way - :param train: bool: Determine whether to use dropout or not - :param output_attentions: Optional[bool]: Determine whether to return the attention weights - :param output_hidden_states: Optional[bool]: Return the hidden states of all h - :param return_dict: Optional[bool]: Determine whether to return a dictionary or not - :param extra_embedding: Optional[Union[jnp.ndarray,None]]: Pass in the embedding for the input_ids - :param add_params_field: bool: Add the params field to the inputs dictionary - :return: A tuple of the following: - + Args: + self: Represent the instance of the class + input_ids: chex.Array: Pass in the input tokens + attention_mask: chex.Array: Mask out certain tokens in the + input + position_ids: chex.Array: Create the positional embeddings + params: dict: Pass in the parameters of the model + past_key_values: dict: Pass in the past key values from a + previous call to __call__ + dropout_rng: jax.random.PRNGKey: Make sure that the dropout + is applied in a random way + train: bool: Determine whether to use dropout or not + output_attentions: Optional[bool]: Determine whether to + return the attention weights + output_hidden_states: Optional[bool]: Return the hidden + states of all h + return_dict: Optional[bool]: Determine whether to return a + dictionary or not + extra_embedding: Optional[Union[jnp.ndarray,None]]: Pass in + the embedding for the input_ids + add_params_field: bool: Add the params field to the inputs + dictionary + + Returns: + A tuple of the following: """ output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions output_hidden_states = ( @@ -836,15 +871,18 @@ def __call__( return outputs def prepare_inputs_for_generation(self, input_ids, max_length, attention_mask: Optional[chex.Array] = None): - """ - The prepare_inputs_for_generation function is used to prepare the inputs for a generation task. - - :param self: Access variables that belong to the class - :param input_ids: Pass in the input tokens - :param max_length: Set the length of the sequence to be generated - :param attention_mask: Optional[chex.Array]: Mask the attention weights - :return: A dictionary of the past_key_values, attention_mask and position ids - + """The prepare_inputs_for_generation function is used to prepare the inputs for a generation task. + + Args: + self: Access variables that belong to the class + input_ids: Pass in the input tokens + max_length: Set the length of the sequence to be generated + attention_mask: Optional[chex.Array]: Mask the attention + weights + + Returns: + A dictionary of the past_key_values, attention_mask and + position ids """ batch_size, seq_length = input_ids.shape extended_attention_mask = jnp.ones( @@ -904,27 +942,35 @@ def __call__( output_hidden_states: bool = False, return_dict: bool = True, ): - """ - The __call__ function is the main function of a JAX nn.Module. + """The __call__ function is the main function of a JAX nn.Module. It defines how the module behaves when called as a function, and it's what you'll use to call your model in training loops or inference scripts. The __call__ method should take all inputs that are necessary for computing outputs from the module, and return all outputs that are computed by this module. - :param self: Represent the instance of the class - :param hidden_states: chex.Array: Pass the input tensor to the encoder - :param rotary_pos_emb_list: chex.Array: Pass in the frequency of each token - :param attention_mask: chex.Array: Mask out certain tokens in the input sequence - :param position_ids: chex.Array: Specify the position of each token in a sequence - :param causal_mask: chex.Array: Mask the attention weights - :param deterministic: bool: Determine whether the model is in training or evaluation mode - :param init_cache: bool: Initialize the cache for each layer - :param output_attentions: bool: Determine whether to output the attention weights - :param output_hidden_states: bool: Determine whether to return the hidden states of each layer - :param return_dict: bool: Return a dictionary of the outputs + Args: + self: Represent the instance of the class + hidden_states: chex.Array: Pass the input tensor to the + encoder + rotary_pos_emb_list: chex.Array: Pass in the frequency of + each token + attention_mask: chex.Array: Mask out certain tokens in the + input sequence + position_ids: chex.Array: Specify the position of each token + in a sequence + causal_mask: chex.Array: Mask the attention weights + deterministic: bool: Determine whether the model is in + training or evaluation mode + init_cache: bool: Initialize the cache for each layer + output_attentions: bool: Determine whether to output the + attention weights + output_hidden_states: bool: Determine whether to return the + hidden states of each layer + return_dict: bool: Return a dictionary of the outputs :param : Determine whether to use the forgetful causal mask - :return: A tuple of 3 values + Returns: + A tuple of 3 values """ all_attentions = () if output_attentions else None all_hidden_states = () if output_hidden_states else None @@ -1034,25 +1080,32 @@ def __call__( return_dict: bool = True, extra_embedding: Optional[Union[jnp.ndarray, None]] = None ): - """ - The __call__ function is the main function of a Flax model. It takes in input_ids, attention_mask, and position_ids + """The __call__ function is the main function of a Flax model. It takes in input_ids, attention_mask, and position_ids and returns the output of the model. The __call__ function also has optional arguments that can be used to control the behavior of the model (e.g., deterministic=True). These optional arguments are passed as keyword arguments when calling a Flax model. - :param self: Represent the instance of the class - :param input_ids: chex.Array: Pass in the input token ids - :param attention_mask: chex.Array: Mask out the padding tokens - :param position_ids: chex.Array: Indicate the position of each token in a sequence - :param deterministic: bool: Control whether dropout is applied or not - :param inputs_embeds: chex.Array: Pass in the embeddings of the input tokens - :param init_cache: bool: Initialize the cache - :param output_attentions: bool: Determine whether to return the attentions or not - :param output_hidden_states: bool: Determine whether to return hidden states - :param return_dict: bool: Return a dictionary of the output or not - :param extra_embedding: Optional[Union[jnp.ndarray, None]]: Pass in the embedding of the - :return: A tuple of: - + Args: + self: Represent the instance of the class + input_ids: chex.Array: Pass in the input token ids + attention_mask: chex.Array: Mask out the padding tokens + position_ids: chex.Array: Indicate the position of each + token in a sequence + deterministic: bool: Control whether dropout is applied or + not + inputs_embeds: chex.Array: Pass in the embeddings of the + input tokens + init_cache: bool: Initialize the cache + output_attentions: bool: Determine whether to return the + attentions or not + output_hidden_states: bool: Determine whether to return + hidden states + return_dict: bool: Return a dictionary of the output or not + extra_embedding: Optional[Union[jnp.ndarray, None]]: Pass in + the embedding of the + + Returns: + A tuple of: """ if inputs_embeds is None: inputs_embeds = self.wte(input_ids.astype("i4")) @@ -1167,22 +1220,27 @@ def __call__( return_dict: bool = True, extra_embedding: Optional[Union[jnp.ndarray, None]] = None ): - """ - The __call__ function is the main function of a Flax module. It takes in inputs and returns outputs. - - :param self: Refer to the object itself - :param input_ids: chex.Array: Pass the input token ids to the model - :param attention_mask: chex.Array: Mask out the padding tokens - :param position_ids: chex.Array: Specify the position of each token in the input sequence - :param deterministic: bool: Control whether the model is trained or not - :param init_cache: bool: Initialize the cache for the decoder - :param output_attentions: bool: Return the attention weights - :param output_hidden_states: bool: Determine whether to return the hidden states - :param return_dict: bool: Return a dictionary of the outputs or not - :param extra_embedding: Optional[Union[jnp.ndarray: Pass in the embedding of the word that we want to predict - :param None]]: Pass in the extra embedding - :return: The logits and the hidden states - + """The __call__ function is the main function of a Flax module. It takes in inputs and returns outputs. + + Args: + self: Refer to the object itself + input_ids: chex.Array: Pass the input token ids to the model + attention_mask: chex.Array: Mask out the padding tokens + position_ids: chex.Array: Specify the position of each token + in the input sequence + deterministic: bool: Control whether the model is trained or + not + init_cache: bool: Initialize the cache for the decoder + output_attentions: bool: Return the attention weights + output_hidden_states: bool: Determine whether to return the + hidden states + return_dict: bool: Return a dictionary of the outputs or not + extra_embedding: Optional[Union[jnp.ndarray: Pass in the + embedding of the word that we want to predict + None]]: Pass in the extra embedding + + Returns: + The logits and the hidden states """ batch_size, seq_length = input_ids.shape if attention_mask is None: @@ -1252,12 +1310,14 @@ class FlaxQwen1ForSequenceClassificationModule(nn.Module): precision: Optional[Union[jax.lax.Precision, str]] = None def setup(self): - """ - The setup function is called once at the beginning of training. + """The setup function is called once at the beginning of training. It initializes the model and optimizer, and sets up any other state that needs to be initialized. - :param self: Access variables that belong to the class - :return: A tuple of the model and the classifier + Args: + self: Access variables that belong to the class + + Returns: + A tuple of the model and the classifier """ self.model = FlaxQwen1Module(self.config, dtype=self.dtype) self.classifier = Linear( @@ -1282,26 +1342,31 @@ def __call__( return_dict: bool = True, extra_embedding: Optional[Union[jnp.ndarray, None]] = None ): - """ - The __call__ function is the main function of a Flax module. + """The __call__ function is the main function of a Flax module. It takes in all the inputs to the model and returns all outputs from it. The __call__ function can be called directly on an instance of a class, or by using parentheses after an instance: >>> my_model = MyModel() # instantiate your model class >>> output = my_model(input) # call your model with input data as arguments to __call__ - :param self: Refer to the class instance - :param input_ids: chex.Array: Pass the input to the model - :param attention_mask: chex.Array: Specify which tokens are masked - :param position_ids: chex.Array: Specify the position of each token in the sequence - :param deterministic: bool: Control whether the model is run in deterministic or stochastic mode - :param init_cache: bool: Initialize the cache for the transformer - :param output_attentions: bool: Return the attention weights - :param output_hidden_states: bool: Return the hidden states of all h - :param return_dict: bool: Return a dictionary of outputs - :param extra_embedding: Optional[Union[jnp.ndarray: Pass in the embedding of a new word - :param None]]: Pass the extra embedding to the model - :return: A tuple of logits and hidden_states - + Args: + self: Refer to the class instance + input_ids: chex.Array: Pass the input to the model + attention_mask: chex.Array: Specify which tokens are masked + position_ids: chex.Array: Specify the position of each token + in the sequence + deterministic: bool: Control whether the model is run in + deterministic or stochastic mode + init_cache: bool: Initialize the cache for the transformer + output_attentions: bool: Return the attention weights + output_hidden_states: bool: Return the hidden states of all + h + return_dict: bool: Return a dictionary of outputs + extra_embedding: Optional[Union[jnp.ndarray: Pass in the + embedding of a new word + None]]: Pass the extra embedding to the model + + Returns: + A tuple of logits and hidden_states """ batch_size, seq_length = input_ids.shape if attention_mask is None: diff --git a/src/python/easydel/modules/qwen1/qwen1_configuration.py b/src/python/easydel/modules/qwen1/qwen1_configuration.py index d27174764..1a206bdda 100644 --- a/src/python/easydel/modules/qwen1/qwen1_configuration.py +++ b/src/python/easydel/modules/qwen1/qwen1_configuration.py @@ -72,15 +72,17 @@ def __init__( ) def get_partition_rules(self, fully_sharded_data_parallel: bool = True): - """ - The get_partition_rules function is used to define the partitioning scheme for a model. + """The get_partition_rules function is used to define the partitioning scheme for a model. It returns a list of tuples, where each tuple contains two elements: 1) A regex string that matches the name of one or more parameters in the model. 2) A PartitionScheme object that defines how those parameters should be partitioned across devices. - :param fully_sharded_data_parallel: bool: Determine whether to partition the model fully or not - :return: A list of tuples + Args: + fully_sharded_data_parallel: bool: Determine whether to + partition the model fully or not + Returns: + A list of tuples """ return ( @@ -129,18 +131,24 @@ def add_jax_args( init_rope_cache_auto: bool = False, **kwargs, ): - """ - The add_jax_args function adds the following arguments to the Transformer class: - - :param self: Refer to the current object - :param gradient_checkpointing: str: Control the amount of memory used by jax - :param use_scan_mlp: bool: Determine whether to use the scan_mlp function or not - :param scan_mlp_chunk_size: int: Set the chunk size for scan_mlp - :param init_rope_cache_auto: bool: Whether to use the rope_cache_auto in model - :param bits: Optional[int]: Determine the number of bits used in the quantization - :param scan_layers: bool: Determine whether to use scan layers or not - :return: The following: - + """The add_jax_args function adds the following arguments to the Transformer class: + + Args: + self: Refer to the current object + gradient_checkpointing: str: Control the amount of memory + used by jax + use_scan_mlp: bool: Determine whether to use the scan_mlp + function or not + scan_mlp_chunk_size: int: Set the chunk size for scan_mlp + init_rope_cache_auto: bool: Whether to use the + rope_cache_auto in model + bits: Optional[int]: Determine the number of bits used in + the quantization + scan_layers: bool: Determine whether to use scan layers or + not + + Returns: + The following: """ self.scan_layers = scan_layers self.gradient_checkpointing = gradient_checkpointing diff --git a/src/python/easydel/modules/qwen2/modelling_qwen_flax.py b/src/python/easydel/modules/qwen2/modelling_qwen_flax.py index 1ed4cf24c..60171fe9d 100644 --- a/src/python/easydel/modules/qwen2/modelling_qwen_flax.py +++ b/src/python/easydel/modules/qwen2/modelling_qwen_flax.py @@ -126,16 +126,18 @@ def setup(self) -> None: self.dropout = flax.linen.Dropout(rate=self.config.resid_pdrop) def __call__(self, x: jnp.ndarray, deterministic: bool = True) -> jnp.ndarray: - """ - The __call__ function is the main function of a class. + """The __call__ function is the main function of a class. It is called when an instance of the class (an object) is invoked as a function, i.e., obj(arguments). The __call__ method enables instances of a class to be called like standard Python functions. - :param self: Represent the instance of the class - :param x: jnp.ndarray: Pass in the input to the layer - :param deterministic: bool: Determine whether to use dropout - :return: A tensor that is the result of applying a dropout function to x + Args: + self: Represent the instance of the class + x: jnp.ndarray: Pass in the input to the layer + deterministic: bool: Determine whether to use dropout + Returns: + A tensor that is the result of applying a dropout function + to x """ x = self.down_proj(jax.nn.silu(self.gate_proj(x)) * self.up_proj(x)) x = self.dropout(x, deterministic=deterministic) @@ -244,33 +246,37 @@ def _merge_heads(self, hidden_states): @staticmethod def _transpose_sequence_head(query, key, value): - """ - The _transpose_sequence_head function transposes the query, key and value matrices. + """The _transpose_sequence_head function transposes the query, key and value matrices. - :param query: Get the attention weights for each of the heads - :param key: Determine the number of heads - :param value: Store the values of the input - :return: The transpose of the query, key and value matrices + Args: + query: Get the attention weights for each of the heads + key: Determine the number of heads + value: Store the values of the input + Returns: + The transpose of the query, key and value matrices """ return jnp.transpose(query, (0, 2, 1, 3)), jnp.transpose(key, (0, 2, 1, 3)), jnp.transpose(value, (0, 2, 1, 3)) def apply_rotary(self, batch_size, sequence_length, query, key, value, freq_cis, position_ids): - """ - The apply_rotary function is a modified version of the apply_attention function in the BertModel class. + """The apply_rotary function is a modified version of the apply_attention function in the BertModel class. The main difference is that it takes in an additional argument, freq_cis, which are used to calculate the rotary attention weights. The other differences are minor and mostly related to reshaping tensors. - :param self: Access variables that belong to the class - :param batch_size: Reshape the query, key and value tensors - :param sequence_length: Reshape the query, key and value tensors - :param query: Calculate the attention weights - :param key: Calculate the attention - :param value: Compute the attention weights - :param freq_cis: Calculate the frequency of each word in the vocabulary - :param position_ids: Identify the position of each token in the sequence - :return: A tuple of 3 tensors: query, key and value - + Args: + self: Access variables that belong to the class + batch_size: Reshape the query, key and value tensors + sequence_length: Reshape the query, key and value tensors + query: Calculate the attention weights + key: Calculate the attention + value: Compute the attention weights + freq_cis: Calculate the frequency of each word in the + vocabulary + position_ids: Identify the position of each token in the + sequence + + Returns: + A tuple of 3 tensors: query, key and value """ query = query.reshape(batch_size, sequence_length, self.config.num_attention_heads, self.head_dim) key = key.reshape(batch_size, sequence_length, self.config.num_key_value_heads, self.head_dim) @@ -297,25 +303,32 @@ def __call__( output_attentions: bool = False, fcm_mask=None, ): - """ - - The __call__ function is the main function of a JAX module. It defines how the module behaves when called + """The __call__ function is the main function of a JAX module. It defines how the module behaves when called with inputs. The __call__ function can be thought of as a "forward pass" through the model, and it should return all outputs that are needed for training or inference. - :param self: Access variables that belong to the class - :param hidden_states: chex.Array: Pass the hidden states of the previous layer - :param freq_cis: Tuple[chex.Array, chex.Array],: Pass in the frequency coefficients for each position - :param attention_mask: chex.Array: Mask out certain tokens in the input sequence - :param position_ids: chex.Array: Determine the position of each token in a sequence - :param causal_mask: chex.Array: Mask out the future tokens in the decoder - :param deterministic: bool: Determine whether to use dropout or not - :param init_cache: bool: Initialize the cache - :param output_attentions: bool: Determine whether to return the attention weights or not - :param fcm_mask: Mask out the attention weights between the input and output tokens + Args: + self: Access variables that belong to the class + hidden_states: chex.Array: Pass the hidden states of the + previous layer + freq_cis: Tuple[chex.Array, chex.Array],: Pass in the + frequency coefficients for each position + attention_mask: chex.Array: Mask out certain tokens in the + input sequence + position_ids: chex.Array: Determine the position of each + token in a sequence + causal_mask: chex.Array: Mask out the future tokens in the + decoder + deterministic: bool: Determine whether to use dropout or not + init_cache: bool: Initialize the cache + output_attentions: bool: Determine whether to return the + attention weights or not + fcm_mask: Mask out the attention weights between the input + and output tokens :param : Determine if the attention is causal or not - :return: A tuple of two arrays + Returns: + A tuple of two arrays """ batch_size, sequence_length = hidden_states.shape[:2] query_states, key_states, value_states = self.q_proj(hidden_states), self.k_proj(hidden_states), self.v_proj( @@ -499,25 +512,32 @@ def __call__( output_attentions: bool = False, fcm_mask: Optional[jnp.ndarray] = None, ): - """ - The __call__ function is the main function of a TransformerEncoderLayer. + """The __call__ function is the main function of a TransformerEncoderLayer. It takes in hidden states, frequency-domain inputs, and masks as input. It then applies self-attention to the hidden states using those inputs and returns an output tensor with shape (batch_size, sequence_length, model_dim). - :param self: Refer to the class instance itself - :param hidden_states: chex.Array: Pass in the hidden state of the previous layer - :param freq_cis: Tuple[chex.Array, chex.Array],: Pass in the frequency information - :param attention_mask: chex.Array: Mask out the attention weights for padding tokens - :param position_ids: chex.Array: Determine the position of each token in the sequence - :param causal_mask: chex.Array: Mask the attention weights - :param deterministic: bool: Control whether the dropout is applied or not - :param init_cache: bool: Initialize the cache in the attention layer - :param output_attentions: bool: Return the attention weights - :param fcm_mask: Optional[jnp.ndarray]: Mask the self-attention + Args: + self: Refer to the class instance itself + hidden_states: chex.Array: Pass in the hidden state of the + previous layer + freq_cis: Tuple[chex.Array, chex.Array],: Pass in the + frequency information + attention_mask: chex.Array: Mask out the attention weights + for padding tokens + position_ids: chex.Array: Determine the position of each + token in the sequence + causal_mask: chex.Array: Mask the attention weights + deterministic: bool: Control whether the dropout is applied + or not + init_cache: bool: Initialize the cache in the attention + layer + output_attentions: bool: Return the attention weights + fcm_mask: Optional[jnp.ndarray]: Mask the self-attention :param : Control the dropout in the self attention layer - :return: A tuple of two items + Returns: + A tuple of two items """ attn_outputs = self.self_attn( self.input_layernorm(hidden_states), @@ -568,36 +588,41 @@ def __init__( _do_init: bool = True, **kwargs, ): - """ - The __init__ function is called when the class is instantiated. + """The __init__ function is called when the class is instantiated. It sets up the instance of the class, and defines what happens when it's created. The __init__ function can take arguments, but self is always required (it refers to the instance of the object). - - :param self: Refer to the object itself - :param config: Qwen2Config: Pass the configuration to the module - :param input_shape: Tuple: Specify the shape of the input to the model - :param seed: int: Set the seed for random number generation - :param dtype: jnp.dtype: Specify the data type of the input - :param _do_init: bool: Control whether the module is initialized or not - :param kwargs: Pass in any additional parameters that the module_class might need + Args: + self: Refer to the object itself + config: Qwen2Config: Pass the configuration to the module + input_shape: Tuple: Specify the shape of the input to the + model + seed: int: Set the seed for random number generation + dtype: jnp.dtype: Specify the data type of the input + _do_init: bool: Control whether the module is initialized or + not + **kwargs: Pass in any additional parameters that the + module_class might need :param : Specify the number of layers in the network - :return: The super() of the class + Returns: + The super() of the class """ module = self.module_class(config=config, dtype=dtype, **kwargs) super().__init__(config, module, input_shape=input_shape, seed=seed, dtype=dtype, _do_init=_do_init) def init_weights(self, rng: jax.random.PRNGKey, input_shape: Tuple, params: FrozenDict = None) -> FrozenDict: - """ - The init_weights function is used to initialize the weights of a model. + """The init_weights function is used to initialize the weights of a model. - :param self: Access variables that belong to the class - :param rng: jax.random.PRNGKey: Initialize the weights of the model - :param input_shape: Tuple: Specify the shape of the input tensor - :param params: FrozenDict: Pass in the parameters of a pre-trained model - :return: A frozendict of parameters + Args: + self: Access variables that belong to the class + rng: jax.random.PRNGKey: Initialize the weights of the model + input_shape: Tuple: Specify the shape of the input tensor + params: FrozenDict: Pass in the parameters of a pre-trained + model + Returns: + A frozendict of parameters """ input_ids = jnp.zeros(input_shape, dtype="i4") attention_mask = jnp.ones_like(input_ids) @@ -635,17 +660,18 @@ def init_weights(self, rng: jax.random.PRNGKey, input_shape: Tuple, params: Froz return random_params def init_cache(self, batch_size, max_length): - """ - The init_cache function is used to initialize the cache for a given batch size and sequence length. + """The init_cache function is used to initialize the cache for a given batch size and sequence length. The cache is a dictionary that contains all the intermediate states from each layer in the model. This allows us to run inference on multiple batches without having to re-run forward passes through every layer in the model, which would be very slow. - :param self: Access the module - :param batch_size: Define the batch size of the input tensors - :param max_length: Set the length of the input sequence - :return: A dictionary with the following keys: + Args: + self: Access the module + batch_size: Define the batch size of the input tensors + max_length: Set the length of the input sequence + Returns: + A dictionary with the following keys: """ input_ids = jnp.ones((batch_size, max_length)) attention_mask = jnp.ones_like(input_ids) @@ -673,27 +699,36 @@ def __call__( add_params_field: bool = False, **kwargs ): - """ - The __call__ function is the main function of a JAX module. + """The __call__ function is the main function of a JAX module. It takes in inputs and returns outputs, but it also has some other important features: - It can take in mutable state (e.g., past_key_values) that will be updated during the call and returned at the end. - It can take in random number generators (rngs) that are used to generate random numbers for dropout or sampling operations. - :param self: Represent the instance of the class - :param input_ids: chex.Array: Pass in the input tokens - :param attention_mask: chex.Array: Mask out certain tokens in the input - :param position_ids: chex.Array: Create the positional embeddings - :param params: dict: Pass in the parameters of the model - :param past_key_values: dict: Pass in the past key values from a previous call to __call__ - :param dropout_rng: jax.random.PRNGKey: Make sure that the dropout is applied in a random way - :param train: bool: Determine whether to use dropout or not - :param output_attentions: Optional[bool]: Determine whether to return the attention weights - :param output_hidden_states: Optional[bool]: Return the hidden states of all layers - :param return_dict: Optional[bool]: Determine whether to return a dictionary or not - :param extra_embedding: Optional[Union[jnp.ndarray,None]]: Pass in the embedding for the input_ids - :param add_params_field: bool: Add the params field to the inputs dictionary - :return: A tuple of the following: - + Args: + self: Represent the instance of the class + input_ids: chex.Array: Pass in the input tokens + attention_mask: chex.Array: Mask out certain tokens in the + input + position_ids: chex.Array: Create the positional embeddings + params: dict: Pass in the parameters of the model + past_key_values: dict: Pass in the past key values from a + previous call to __call__ + dropout_rng: jax.random.PRNGKey: Make sure that the dropout + is applied in a random way + train: bool: Determine whether to use dropout or not + output_attentions: Optional[bool]: Determine whether to + return the attention weights + output_hidden_states: Optional[bool]: Return the hidden + states of all layers + return_dict: Optional[bool]: Determine whether to return a + dictionary or not + extra_embedding: Optional[Union[jnp.ndarray,None]]: Pass in + the embedding for the input_ids + add_params_field: bool: Add the params field to the inputs + dictionary + + Returns: + A tuple of the following: """ output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions output_hidden_states = ( @@ -793,27 +828,35 @@ def __call__( output_hidden_states: bool = False, return_dict: bool = True, ): - """ - The __call__ function is the main function of a JAX nn.Module. + """The __call__ function is the main function of a JAX nn.Module. It defines how the module behaves when called as a function, and it's what you'll use to call your model in training loops or inference scripts. The __call__ method should take all inputs that are necessary for computing outputs from the module, and return all outputs that are computed by this module. - :param self: Represent the instance of the class - :param hidden_states: chex.Array: Pass the input tensor to the encoder - :param freq_cis: Tuple[chex.Array, chex.Array],: Pass in the frequency of each token - :param attention_mask: chex.Array: Mask out certain tokens in the input sequence - :param position_ids: chex.Array: Specify the position of each token in a sequence - :param causal_mask: chex.Array: Mask the attention weights - :param deterministic: bool: Determine whether the model is in training or evaluation mode - :param init_cache: bool: Initialize the cache for each layer - :param output_attentions: bool: Determine whether to output the attention weights - :param output_hidden_states: bool: Determine whether to return the hidden states of each layer - :param return_dict: bool: Return a dictionary of the outputs + Args: + self: Represent the instance of the class + hidden_states: chex.Array: Pass the input tensor to the + encoder + freq_cis: Tuple[chex.Array, chex.Array],: Pass in the + frequency of each token + attention_mask: chex.Array: Mask out certain tokens in the + input sequence + position_ids: chex.Array: Specify the position of each token + in a sequence + causal_mask: chex.Array: Mask the attention weights + deterministic: bool: Determine whether the model is in + training or evaluation mode + init_cache: bool: Initialize the cache for each layer + output_attentions: bool: Determine whether to output the + attention weights + output_hidden_states: bool: Determine whether to return the + hidden states of each layer + return_dict: bool: Return a dictionary of the outputs :param : Determine whether to use the forgetful causal mask - :return: A tuple of 3 values + Returns: + A tuple of 3 values """ all_attentions = () if output_attentions else None all_hidden_states = () if output_hidden_states else None @@ -929,26 +972,33 @@ def __call__( return_dict: bool = True, extra_embedding: Optional[Union[jnp.ndarray, None]] = None ): - """ - The __call__ function is the main function of a Flax model. It takes in input_ids, attention_mask, and position_ids + """The __call__ function is the main function of a Flax model. It takes in input_ids, attention_mask, and position_ids and returns the output of the model. The __call__ function also has optional arguments that can be used to control the behavior of the model (e.g., deterministic=True). These optional arguments are passed as keyword arguments when calling a Flax model. - :param self: Represent the instance of the class - :param input_ids: chex.Array: Pass in the input token ids - :param attention_mask: chex.Array: Mask out the padding tokens - :param position_ids: chex.Array: Indicate the position of each token in a sequence - :param deterministic: bool: Control whether dropout is applied or not - :param inputs_embeds: chex.Array: Pass in the embeddings of the input tokens - :param init_cache: bool: Initialize the cache - :param output_attentions: bool: Determine whether to return the attentions or not - :param output_hidden_states: bool: Determine whether to return hidden states - :param return_dict: bool: Return a dictionary of the output or not - :param extra_embedding: Optional[Union[jnp.ndarray: Pass in the embedding of the - :param None]]: Pass in the extra embedding - :return: A tuple of: - + Args: + self: Represent the instance of the class + input_ids: chex.Array: Pass in the input token ids + attention_mask: chex.Array: Mask out the padding tokens + position_ids: chex.Array: Indicate the position of each + token in a sequence + deterministic: bool: Control whether dropout is applied or + not + inputs_embeds: chex.Array: Pass in the embeddings of the + input tokens + init_cache: bool: Initialize the cache + output_attentions: bool: Determine whether to return the + attentions or not + output_hidden_states: bool: Determine whether to return + hidden states + return_dict: bool: Return a dictionary of the output or not + extra_embedding: Optional[Union[jnp.ndarray: Pass in the + embedding of the + None]]: Pass in the extra embedding + + Returns: + A tuple of: """ if inputs_embeds is None: inputs_embeds = self.embed_tokens(input_ids.astype("i4")) @@ -1038,22 +1088,27 @@ def __call__( return_dict: bool = True, extra_embedding: Optional[Union[jnp.ndarray, None]] = None ): - """ - The __call__ function is the main function of a Flax module. It takes in inputs and returns outputs. - - :param self: Refer to the object itself - :param input_ids: chex.Array: Pass the input token ids to the model - :param attention_mask: chex.Array: Mask out the padding tokens - :param position_ids: chex.Array: Specify the position of each token in the input sequence - :param deterministic: bool: Control whether the model is trained or not - :param init_cache: bool: Initialize the cache for the decoder - :param output_attentions: bool: Return the attention weights - :param output_hidden_states: bool: Determine whether to return the hidden states - :param return_dict: bool: Return a dictionary of the outputs or not - :param extra_embedding: Optional[Union[jnp.ndarray: Pass in the embedding of the word that we want to predict - :param None]]: Pass in the extra embedding - :return: The logits and the hidden states - + """The __call__ function is the main function of a Flax module. It takes in inputs and returns outputs. + + Args: + self: Refer to the object itself + input_ids: chex.Array: Pass the input token ids to the model + attention_mask: chex.Array: Mask out the padding tokens + position_ids: chex.Array: Specify the position of each token + in the input sequence + deterministic: bool: Control whether the model is trained or + not + init_cache: bool: Initialize the cache for the decoder + output_attentions: bool: Return the attention weights + output_hidden_states: bool: Determine whether to return the + hidden states + return_dict: bool: Return a dictionary of the outputs or not + extra_embedding: Optional[Union[jnp.ndarray: Pass in the + embedding of the word that we want to predict + None]]: Pass in the extra embedding + + Returns: + The logits and the hidden states """ batch_size, seq_length = input_ids.shape if attention_mask is None: @@ -1115,15 +1170,18 @@ def set_output_embeddings(self, new_embeddings): self.module.lm_head = new_embeddings def prepare_inputs_for_generation(self, input_ids, max_length, attention_mask: Optional[chex.Array] = None): - """ - The prepare_inputs_for_generation function is used to prepare the inputs for a generation task. - - :param self: Access variables that belong to the class - :param input_ids: Pass in the input tokens - :param max_length: Set the length of the sequence to be generated - :param attention_mask: Optional[chex.Array]: Mask the attention weights - :return: A dictionary of the past_key_values, attention_mask and position ids - + """The prepare_inputs_for_generation function is used to prepare the inputs for a generation task. + + Args: + self: Access variables that belong to the class + input_ids: Pass in the input tokens + max_length: Set the length of the sequence to be generated + attention_mask: Optional[chex.Array]: Mask the attention + weights + + Returns: + A dictionary of the past_key_values, attention_mask and + position ids """ batch_size, seq_length = input_ids.shape @@ -1158,12 +1216,14 @@ class FlaxQwen2ForSequenceClassificationModule(nn.Module): precision: Optional[Union[jax.lax.Precision, str]] = None def setup(self): - """ - The setup function is called once at the beginning of training. + """The setup function is called once at the beginning of training. It initializes the model and optimizer, and sets up any other state that needs to be initialized. - :param self: Access variables that belong to the class - :return: A tuple of the model and the classifier + Args: + self: Access variables that belong to the class + + Returns: + A tuple of the model and the classifier """ self.model = FlaxQwen2Module(self.config, dtype=self.dtype) self.classifier = Linear( @@ -1188,26 +1248,31 @@ def __call__( return_dict: bool = True, extra_embedding: Optional[Union[jnp.ndarray, None]] = None ): - """ - The __call__ function is the main function of a Flax module. + """The __call__ function is the main function of a Flax module. It takes in all the inputs to the model and returns all outputs from it. The __call__ function can be called directly on an instance of a class, or by using parentheses after an instance: >>> my_model = MyModel() # instantiate your model class >>> output = my_model(input) # call your model with input data as arguments to __call__ - :param self: Refer to the class instance - :param input_ids: chex.Array: Pass the input to the model - :param attention_mask: chex.Array: Specify which tokens are masked - :param position_ids: chex.Array: Specify the position of each token in the sequence - :param deterministic: bool: Control whether the model is run in deterministic or stochastic mode - :param init_cache: bool: Initialize the cache for the transformer - :param output_attentions: bool: Return the attention weights - :param output_hidden_states: bool: Return the hidden states of all layers - :param return_dict: bool: Return a dictionary of outputs - :param extra_embedding: Optional[Union[jnp.ndarray: Pass in the embedding of a new word - :param None]]: Pass the extra embedding to the model - :return: A tuple of logits and hidden_states - + Args: + self: Refer to the class instance + input_ids: chex.Array: Pass the input to the model + attention_mask: chex.Array: Specify which tokens are masked + position_ids: chex.Array: Specify the position of each token + in the sequence + deterministic: bool: Control whether the model is run in + deterministic or stochastic mode + init_cache: bool: Initialize the cache for the transformer + output_attentions: bool: Return the attention weights + output_hidden_states: bool: Return the hidden states of all + layers + return_dict: bool: Return a dictionary of outputs + extra_embedding: Optional[Union[jnp.ndarray: Pass in the + embedding of a new word + None]]: Pass the extra embedding to the model + + Returns: + A tuple of logits and hidden_states """ batch_size, seq_length = input_ids.shape if attention_mask is None: diff --git a/src/python/easydel/modules/qwen2/qwen_configuration.py b/src/python/easydel/modules/qwen2/qwen_configuration.py index f510f2ced..8cc44b95d 100644 --- a/src/python/easydel/modules/qwen2/qwen_configuration.py +++ b/src/python/easydel/modules/qwen2/qwen_configuration.py @@ -81,15 +81,17 @@ def __init__( ) def get_partition_rules(self, fully_sharded_data_parallel: bool = True): - """ - The get_partition_rules function is used to define the partitioning scheme for a model. + """The get_partition_rules function is used to define the partitioning scheme for a model. It returns a list of tuples, where each tuple contains two elements: 1) A regex string that matches the name of one or more parameters in the model. 2) A PartitionScheme object that defines how those parameters should be partitioned across devices. - :param fully_sharded_data_parallel: bool: Determine whether to partition the model fully or not - :return: A list of tuples + Args: + fully_sharded_data_parallel: bool: Determine whether to + partition the model fully or not + Returns: + A list of tuples """ return ( @@ -146,26 +148,38 @@ def add_jax_args( rope_scaling: Optional[Mapping[str, str | float]] = None, **kwargs, ): - """ - The add_jax_args function adds the following arguments to the Transformer class: - - :param self: Refer to the current object - :param resid_pdrop: float: Set the dropout rate for residual connections - :param embd_pdrop: float: Set the probability of dropping an embedding - :param attention_dropout: float: Set the probability of dropping out the attention layer - :param tie_word_embeddings: bool: Tie the word embeddings to the decoder - :param gradient_checkpointing: str: Control the amount of memory used by jax - :param fcm_min_ratio: float: Control the minimum ratio of the number of chunks to be used in flash-based computation - :param fcm_max_ratio: float: Set the maximum ratio of the number of input tokens to output tokens - :param use_scan_mlp: bool: Determine whether to use the scan_mlp function or not - :param scan_mlp_chunk_size: int: Set the chunk size for scan_mlp - :param number_rep_kv: int: Determine how many times the key and value vectors are repeated - :param bits: Optional[int]: Determine the number of bits used in the quantization - :param rope_theta: float : rope_theta for compute rope - :param hidden_act: str : hidden_act for mlp - :param scan_layers: bool: Determine whether to use scan layers or not - :return: The following: - + """The add_jax_args function adds the following arguments to the Transformer class: + + Args: + self: Refer to the current object + resid_pdrop: float: Set the dropout rate for residual + connections + embd_pdrop: float: Set the probability of dropping an + embedding + attention_dropout: float: Set the probability of dropping + out the attention layer + tie_word_embeddings: bool: Tie the word embeddings to the + decoder + gradient_checkpointing: str: Control the amount of memory + used by jax + fcm_min_ratio: float: Control the minimum ratio of the + number of chunks to be used in flash-based computation + fcm_max_ratio: float: Set the maximum ratio of the number of + input tokens to output tokens + use_scan_mlp: bool: Determine whether to use the scan_mlp + function or not + scan_mlp_chunk_size: int: Set the chunk size for scan_mlp + number_rep_kv: int: Determine how many times the key and + value vectors are repeated + bits: Optional[int]: Determine the number of bits used in + the quantization + rope_theta: float : rope_theta for compute rope + hidden_act: str : hidden_act for mlp + scan_layers: bool: Determine whether to use scan layers or + not + + Returns: + The following: """ self.scan_layers = scan_layers self.embd_pdrop = embd_pdrop diff --git a/src/python/easydel/modules/qwen2_moe/configuration_qwen2_moe.py b/src/python/easydel/modules/qwen2_moe/configuration_qwen2_moe.py index 6d40e9c4b..bee58b5ac 100644 --- a/src/python/easydel/modules/qwen2_moe/configuration_qwen2_moe.py +++ b/src/python/easydel/modules/qwen2_moe/configuration_qwen2_moe.py @@ -72,15 +72,17 @@ def __init__( ) def get_partition_rules(self, fully_sharded_data_parallel: bool = True): - """ - The get_partition_rules function is used to define the partitioning scheme for a model. + """The get_partition_rules function is used to define the partitioning scheme for a model. It returns a list of tuples, where each tuple contains two elements: 1) A regex string that matches the name of one or more parameters in the model. 2) A PartitionScheme object that defines how those parameters should be partitioned across devices. - :param fully_sharded_data_parallel: bool: Determine whether to partition the model fully or not - :return: A list of tuples + Args: + fully_sharded_data_parallel: bool: Determine whether to + partition the model fully or not + Returns: + A list of tuples """ return ( @@ -128,15 +130,17 @@ def add_jax_args( bits: Optional[int] = None, **kwargs, ): - """ - The add_jax_args function adds the following arguments to the Transformer class: - - :param self: Refer to the current object + """The add_jax_args function adds the following arguments to the Transformer class: - :param gradient_checkpointing: str: Control the amount of memory used by jax - :param bits: Optional[int]: Determine the number of bits used in the quantization - :return: The following: + Args: + self: Refer to the current object + gradient_checkpointing: str: Control the amount of memory + used by jax + bits: Optional[int]: Determine the number of bits used in + the quantization + Returns: + The following: """ self.gradient_checkpointing = gradient_checkpointing self.bits = bits diff --git a/src/python/easydel/modules/qwen2_moe/modeling_qwen2_moe_flax.py b/src/python/easydel/modules/qwen2_moe/modeling_qwen2_moe_flax.py index 7aa2922ea..8fd36d128 100644 --- a/src/python/easydel/modules/qwen2_moe/modeling_qwen2_moe_flax.py +++ b/src/python/easydel/modules/qwen2_moe/modeling_qwen2_moe_flax.py @@ -145,16 +145,18 @@ def setup(self) -> None: ) def __call__(self, x: jnp.ndarray, deterministic: bool = True) -> jnp.ndarray: - """ - The __call__ function is the main function of a class. + """The __call__ function is the main function of a class. It is called when an instance of the class (an object) is invoked as a function, i.e., obj(arguments). The __call__ method enables instances of a class to be called like standard Python functions. - :param self: Represent the instance of the class - :param x: jnp.ndarray: Pass in the input to the layer - :param deterministic: bool: Determine whether to use dropout - :return: A tensor that is the result of applying a dropout function to x + Args: + self: Represent the instance of the class + x: jnp.ndarray: Pass in the input to the layer + deterministic: bool: Determine whether to use dropout + Returns: + A tensor that is the result of applying a dropout function + to x """ x = self.down_proj(jax.nn.silu(self.gate_proj(x)) * self.up_proj(x)) return x @@ -262,33 +264,37 @@ def _merge_heads(self, hidden_states): @staticmethod def _transpose_sequence_head(query, key, value): - """ - The _transpose_sequence_head function transposes the query, key and value matrices. + """The _transpose_sequence_head function transposes the query, key and value matrices. - :param query: Get the attention weights for each of the heads - :param key: Determine the number of heads - :param value: Store the values of the input - :return: The transpose of the query, key and value matrices + Args: + query: Get the attention weights for each of the heads + key: Determine the number of heads + value: Store the values of the input + Returns: + The transpose of the query, key and value matrices """ return jnp.transpose(query, (0, 2, 1, 3)), jnp.transpose(key, (0, 2, 1, 3)), jnp.transpose(value, (0, 2, 1, 3)) def apply_rotary(self, batch_size, sequence_length, query, key, value, freq_cis, position_ids): - """ - The apply_rotary function is a modified version of the apply_attention function in the BertModel class. + """The apply_rotary function is a modified version of the apply_attention function in the BertModel class. The main difference is that it takes in an additional argument, freq_cis, which are used to calculate the rotary attention weights. The other differences are minor and mostly related to reshaping tensors. - :param self: Access variables that belong to the class - :param batch_size: Reshape the query, key and value tensors - :param sequence_length: Reshape the query, key and value tensors - :param query: Calculate the attention weights - :param key: Calculate the attention - :param value: Compute the attention weights - :param freq_cis: Calculate the frequency of each word in the vocabulary - :param position_ids: Identify the position of each token in the sequence - :return: A tuple of 3 tensors: query, key and value - + Args: + self: Access variables that belong to the class + batch_size: Reshape the query, key and value tensors + sequence_length: Reshape the query, key and value tensors + query: Calculate the attention weights + key: Calculate the attention + value: Compute the attention weights + freq_cis: Calculate the frequency of each word in the + vocabulary + position_ids: Identify the position of each token in the + sequence + + Returns: + A tuple of 3 tensors: query, key and value """ query = query.reshape(batch_size, sequence_length, self.config.num_attention_heads, self.head_dim) key = key.reshape(batch_size, sequence_length, self.config.num_key_value_heads, self.head_dim) @@ -315,25 +321,32 @@ def __call__( output_attentions: bool = False, fcm_mask=None, ): - """ - - The __call__ function is the main function of a JAX module. It defines how the module behaves when called + """The __call__ function is the main function of a JAX module. It defines how the module behaves when called with inputs. The __call__ function can be thought of as a "forward pass" through the model, and it should return all outputs that are needed for training or inference. - :param self: Access variables that belong to the class - :param hidden_states: chex.Array: Pass the hidden states of the previous layer - :param freq_cis: Tuple[chex.Array, chex.Array],: Pass in the frequency coefficients for each position - :param attention_mask: chex.Array: Mask out certain tokens in the input sequence - :param position_ids: chex.Array: Determine the position of each token in a sequence - :param causal_mask: chex.Array: Mask out the future tokens in the decoder - :param deterministic: bool: Determine whether to use dropout or not - :param init_cache: bool: Initialize the cache - :param output_attentions: bool: Determine whether to return the attention weights or not - :param fcm_mask: Mask out the attention weights between the input and output tokens + Args: + self: Access variables that belong to the class + hidden_states: chex.Array: Pass the hidden states of the + previous layer + freq_cis: Tuple[chex.Array, chex.Array],: Pass in the + frequency coefficients for each position + attention_mask: chex.Array: Mask out certain tokens in the + input sequence + position_ids: chex.Array: Determine the position of each + token in a sequence + causal_mask: chex.Array: Mask out the future tokens in the + decoder + deterministic: bool: Determine whether to use dropout or not + init_cache: bool: Initialize the cache + output_attentions: bool: Determine whether to return the + attention weights or not + fcm_mask: Mask out the attention weights between the input + and output tokens :param : Determine if the attention is causal or not - :return: A tuple of two arrays + Returns: + A tuple of two arrays """ batch_size, sequence_length = hidden_states.shape[:2] query_states, key_states, value_states = self.q_proj(hidden_states), self.k_proj(hidden_states), self.v_proj( @@ -500,8 +513,7 @@ def __call__( class FlaxQwen2MoeSparseMoeBlock(nn.Module): - """ - This implementation is + """This implementation is strictly equivalent to standard MoE with full capacity (no dropped tokens). It's faster since it formulates MoE operations in terms of block-sparse operations to accomodate imbalanced @@ -661,25 +673,32 @@ def __call__( fcm_mask: Optional[jnp.ndarray] = None, ): - """ - The __call__ function is the main function of a TransformerEncoderLayer. + """The __call__ function is the main function of a TransformerEncoderLayer. It takes in hidden states, frequency-domain inputs, and masks as input. It then applies self-attention to the hidden states using those inputs and returns an output tensor with shape (batch_size, sequence_length, model_dim). - :param self: Refer to the class instance itself - :param hidden_states: chex.Array: Pass in the hidden state of the previous layer - :param freq_cis: Tuple[chex.Array, chex.Array],: Pass in the frequency information - :param attention_mask: chex.Array: Mask out the attention weights for padding tokens - :param position_ids: chex.Array: Determine the position of each token in the sequence - :param causal_mask: chex.Array: Mask the attention weights - :param deterministic: bool: Control whether the dropout is applied or not - :param init_cache: bool: Initialize the cache in the attention layer - :param output_attentions: bool: Return the attention weights - :param fcm_mask: Optional[jnp.ndarray]: Mask the self-attention + Args: + self: Refer to the class instance itself + hidden_states: chex.Array: Pass in the hidden state of the + previous layer + freq_cis: Tuple[chex.Array, chex.Array],: Pass in the + frequency information + attention_mask: chex.Array: Mask out the attention weights + for padding tokens + position_ids: chex.Array: Determine the position of each + token in the sequence + causal_mask: chex.Array: Mask the attention weights + deterministic: bool: Control whether the dropout is applied + or not + init_cache: bool: Initialize the cache in the attention + layer + output_attentions: bool: Return the attention weights + fcm_mask: Optional[jnp.ndarray]: Mask the self-attention :param : Control the dropout in the self attention layer - :return: A tuple of two items + Returns: + A tuple of two items """ attn_outputs = self.self_attn( self.input_layernorm(hidden_states), @@ -728,36 +747,41 @@ def __init__( _do_init: bool = True, **kwargs, ): - """ - The __init__ function is called when the class is instantiated. + """The __init__ function is called when the class is instantiated. It sets up the instance of the class, and defines what happens when it's created. The __init__ function can take arguments, but self is always required (it refers to the instance of the object). - - :param self: Refer to the object itself - :param config: Qwen2MoeConfig: Pass the configuration to the module - :param input_shape: Tuple: Specify the shape of the input to the model - :param seed: int: Set the seed for random number generation - :param dtype: jnp.dtype: Specify the data type of the input - :param _do_init: bool: Control whether the module is initialized or not - :param kwargs: Pass in any additional parameters that the module_class might need + Args: + self: Refer to the object itself + config: Qwen2MoeConfig: Pass the configuration to the module + input_shape: Tuple: Specify the shape of the input to the + model + seed: int: Set the seed for random number generation + dtype: jnp.dtype: Specify the data type of the input + _do_init: bool: Control whether the module is initialized or + not + **kwargs: Pass in any additional parameters that the + module_class might need :param : Specify the number of layers in the network - :return: The super() of the class + Returns: + The super() of the class """ module = self.module_class(config=config, dtype=dtype, **kwargs) super().__init__(config, module, input_shape=input_shape, seed=seed, dtype=dtype, _do_init=_do_init) def init_weights(self, rng: jax.random.PRNGKey, input_shape: Tuple, params: FrozenDict = None) -> FrozenDict: - """ - The init_weights function is used to initialize the weights of a model. + """The init_weights function is used to initialize the weights of a model. - :param self: Access variables that belong to the class - :param rng: jax.random.PRNGKey: Initialize the weights of the model - :param input_shape: Tuple: Specify the shape of the input tensor - :param params: FrozenDict: Pass in the parameters of a pre-trained model - :return: A frozendict of parameters + Args: + self: Access variables that belong to the class + rng: jax.random.PRNGKey: Initialize the weights of the model + input_shape: Tuple: Specify the shape of the input tensor + params: FrozenDict: Pass in the parameters of a pre-trained + model + Returns: + A frozendict of parameters """ input_ids = jnp.zeros(input_shape, dtype="i4") attention_mask = jnp.ones_like(input_ids) @@ -795,17 +819,18 @@ def init_weights(self, rng: jax.random.PRNGKey, input_shape: Tuple, params: Froz return random_params def init_cache(self, batch_size, max_length): - """ - The init_cache function is used to initialize the cache for a given batch size and sequence length. + """The init_cache function is used to initialize the cache for a given batch size and sequence length. The cache is a dictionary that contains all the intermediate states from each layer in the model. This allows us to run inference on multiple batches without having to re-run forward passes through every layer in the model, which would be very slow. - :param self: Access the module - :param batch_size: Define the batch size of the input tensors - :param max_length: Set the length of the input sequence - :return: A dictionary with the following keys: + Args: + self: Access the module + batch_size: Define the batch size of the input tensors + max_length: Set the length of the input sequence + Returns: + A dictionary with the following keys: """ input_ids = jnp.ones((batch_size, max_length)) attention_mask = jnp.ones_like(input_ids) @@ -834,27 +859,36 @@ def __call__( add_params_field: bool = False, **kwargs ): - """ - The __call__ function is the main function of a JAX module. + """The __call__ function is the main function of a JAX module. It takes in inputs and returns outputs, but it also has some other important features: - It can take in mutable state (e.g., past_key_values) that will be updated during the call and returned at the end. - It can take in random number generators (rngs) that are used to generate random numbers for dropout or sampling operations. - :param self: Represent the instance of the class - :param input_ids: chex.Array: Pass in the input tokens - :param attention_mask: chex.Array: Mask out certain tokens in the input - :param position_ids: chex.Array: Create the positional embeddings - :param params: dict: Pass in the parameters of the model - :param past_key_values: dict: Pass in the past key values from a previous call to __call__ - :param dropout_rng: jax.random.PRNGKey: Make sure that the dropout is applied in a random way - :param train: bool: Determine whether to use dropout or not - :param output_attentions: Optional[bool]: Determine whether to return the attention weights - :param output_hidden_states: Optional[bool]: Return the hidden states of all layers - :param return_dict: Optional[bool]: Determine whether to return a dictionary or not - :param extra_embedding: Optional[Union[jnp.ndarray,None]]: Pass in the embedding for the input_ids - :param add_params_field: bool: Add the params field to the inputs dictionary - :return: A tuple of the following: - + Args: + self: Represent the instance of the class + input_ids: chex.Array: Pass in the input tokens + attention_mask: chex.Array: Mask out certain tokens in the + input + position_ids: chex.Array: Create the positional embeddings + params: dict: Pass in the parameters of the model + past_key_values: dict: Pass in the past key values from a + previous call to __call__ + dropout_rng: jax.random.PRNGKey: Make sure that the dropout + is applied in a random way + train: bool: Determine whether to use dropout or not + output_attentions: Optional[bool]: Determine whether to + return the attention weights + output_hidden_states: Optional[bool]: Return the hidden + states of all layers + return_dict: Optional[bool]: Determine whether to return a + dictionary or not + extra_embedding: Optional[Union[jnp.ndarray,None]]: Pass in + the embedding for the input_ids + add_params_field: bool: Add the params field to the inputs + dictionary + + Returns: + A tuple of the following: """ output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions output_hidden_states = ( @@ -958,27 +992,35 @@ def __call__( output_router_logits: Optional[bool] = None, return_dict: bool = True, ): - """ - The __call__ function is the main function of a JAX nn.Module. + """The __call__ function is the main function of a JAX nn.Module. It defines how the module behaves when called as a function, and it's what you'll use to call your model in training loops or inference scripts. The __call__ method should take all inputs that are necessary for computing outputs from the module, and return all outputs that are computed by this module. - :param self: Represent the instance of the class - :param hidden_states: chex.Array: Pass the input tensor to the encoder - :param freq_cis: Tuple[chex.Array, chex.Array],: Pass in the frequency of each token - :param attention_mask: chex.Array: Mask out certain tokens in the input sequence - :param position_ids: chex.Array: Specify the position of each token in a sequence - :param causal_mask: chex.Array: Mask the attention weights - :param deterministic: bool: Determine whether the model is in training or evaluation mode - :param init_cache: bool: Initialize the cache for each layer - :param output_attentions: bool: Determine whether to output the attention weights - :param output_hidden_states: bool: Determine whether to return the hidden states of each layer - :param return_dict: bool: Return a dictionary of the outputs + Args: + self: Represent the instance of the class + hidden_states: chex.Array: Pass the input tensor to the + encoder + freq_cis: Tuple[chex.Array, chex.Array],: Pass in the + frequency of each token + attention_mask: chex.Array: Mask out certain tokens in the + input sequence + position_ids: chex.Array: Specify the position of each token + in a sequence + causal_mask: chex.Array: Mask the attention weights + deterministic: bool: Determine whether the model is in + training or evaluation mode + init_cache: bool: Initialize the cache for each layer + output_attentions: bool: Determine whether to output the + attention weights + output_hidden_states: bool: Determine whether to return the + hidden states of each layer + return_dict: bool: Return a dictionary of the outputs :param : Determine whether to use the forgetful causal mask - :return: A tuple of 3 values + Returns: + A tuple of 3 values """ all_attentions = () if output_attentions else None all_hidden_states = () if output_hidden_states else None @@ -1205,22 +1247,27 @@ def __call__( return_dict: bool = True, extra_embedding: Optional[Union[jnp.ndarray, None]] = None ): - """ - The __call__ function is the main function of a Flax module. It takes in inputs and returns outputs. - - :param self: Refer to the object itself - :param input_ids: chex.Array: Pass the input token ids to the model - :param attention_mask: chex.Array: Mask out the padding tokens - :param position_ids: chex.Array: Specify the position of each token in the input sequence - :param deterministic: bool: Control whether the model is trained or not - :param init_cache: bool: Initialize the cache for the decoder - :param output_attentions: bool: Return the attention weights - :param output_hidden_states: bool: Determine whether to return the hidden states - :param return_dict: bool: Return a dictionary of the outputs or not - :param extra_embedding: Optional[Union[jnp.ndarray: Pass in the embedding of the word that we want to predict - :param None]]: Pass in the extra embedding - :return: The logits and the hidden states - + """The __call__ function is the main function of a Flax module. It takes in inputs and returns outputs. + + Args: + self: Refer to the object itself + input_ids: chex.Array: Pass the input token ids to the model + attention_mask: chex.Array: Mask out the padding tokens + position_ids: chex.Array: Specify the position of each token + in the input sequence + deterministic: bool: Control whether the model is trained or + not + init_cache: bool: Initialize the cache for the decoder + output_attentions: bool: Return the attention weights + output_hidden_states: bool: Determine whether to return the + hidden states + return_dict: bool: Return a dictionary of the outputs or not + extra_embedding: Optional[Union[jnp.ndarray: Pass in the + embedding of the word that we want to predict + None]]: Pass in the extra embedding + + Returns: + The logits and the hidden states """ if output_router_logits is None: output_router_logits = self.config.output_router_logits @@ -1303,15 +1350,18 @@ def set_output_embeddings(self, new_embeddings): self.module.lm_head = new_embeddings def prepare_inputs_for_generation(self, input_ids, max_length, attention_mask: Optional[chex.Array] = None): - """ - The prepare_inputs_for_generation function is used to prepare the inputs for a generation task. - - :param self: Access variables that belong to the class - :param input_ids: Pass in the input tokens - :param max_length: Set the length of the sequence to be generated - :param attention_mask: Optional[chex.Array]: Mask the attention weights - :return: A dictionary of the past_key_values, attention_mask and position ids - + """The prepare_inputs_for_generation function is used to prepare the inputs for a generation task. + + Args: + self: Access variables that belong to the class + input_ids: Pass in the input tokens + max_length: Set the length of the sequence to be generated + attention_mask: Optional[chex.Array]: Mask the attention + weights + + Returns: + A dictionary of the past_key_values, attention_mask and + position ids """ batch_size, seq_length = input_ids.shape @@ -1346,12 +1396,14 @@ class FlaxQwen2MoeForSequenceClassificationModule(nn.Module): precision: Optional[Union[jax.lax.Precision, str]] = None def setup(self): - """ - The setup function is called once at the beginning of training. + """The setup function is called once at the beginning of training. It initializes the model and optimizer, and sets up any other state that needs to be initialized. - :param self: Access variables that belong to the class - :return: A tuple of the model and the classifier + Args: + self: Access variables that belong to the class + + Returns: + A tuple of the model and the classifier """ self.model = FlaxQwen2MoeModule(self.config, dtype=self.dtype) self.classifier = Linear( @@ -1376,26 +1428,31 @@ def __call__( return_dict: bool = True, extra_embedding: Optional[Union[jnp.ndarray, None]] = None ): - """ - The __call__ function is the main function of a Flax module. + """The __call__ function is the main function of a Flax module. It takes in all the inputs to the model and returns all outputs from it. The __call__ function can be called directly on an instance of a class, or by using parentheses after an instance: >>> my_model = MyModel() # instantiate your model class >>> output = my_model(input) # call your model with input data as arguments to __call__ - :param self: Refer to the class instance - :param input_ids: chex.Array: Pass the input to the model - :param attention_mask: chex.Array: Specify which tokens are masked - :param position_ids: chex.Array: Specify the position of each token in the sequence - :param deterministic: bool: Control whether the model is run in deterministic or stochastic mode - :param init_cache: bool: Initialize the cache for the transformer - :param output_attentions: bool: Return the attention weights - :param output_hidden_states: bool: Return the hidden states of all layers - :param return_dict: bool: Return a dictionary of outputs - :param extra_embedding: Optional[Union[jnp.ndarray: Pass in the embedding of a new word - :param None]]: Pass the extra embedding to the model - :return: A tuple of logits and hidden_states - + Args: + self: Refer to the class instance + input_ids: chex.Array: Pass the input to the model + attention_mask: chex.Array: Specify which tokens are masked + position_ids: chex.Array: Specify the position of each token + in the sequence + deterministic: bool: Control whether the model is run in + deterministic or stochastic mode + init_cache: bool: Initialize the cache for the transformer + output_attentions: bool: Return the attention weights + output_hidden_states: bool: Return the hidden states of all + layers + return_dict: bool: Return a dictionary of outputs + extra_embedding: Optional[Union[jnp.ndarray: Pass in the + embedding of a new word + None]]: Pass the extra embedding to the model + + Returns: + A tuple of logits and hidden_states """ batch_size, seq_length = input_ids.shape if attention_mask is None: diff --git a/src/python/easydel/modules/roberta/__init__.py b/src/python/easydel/modules/roberta/__init__.py index 686224015..4bd60bb9d 100644 --- a/src/python/easydel/modules/roberta/__init__.py +++ b/src/python/easydel/modules/roberta/__init__.py @@ -1,18 +1,18 @@ -from .roberta_configuration import RobertaConfig -from .modelling_roberta_flax import ( - FlaxRobertaForCausalLM, - FlaxRobertaForMultipleChoice, - FlaxRobertaForMaskedLMModule, - FlaxRobertaForQuestionAnswering, - FlaxRobertaForSequenceClassification, - FlaxRobertaForTokenClassification, -) - -__all__ = ( - "FlaxRobertaForSequenceClassification", - "FlaxRobertaForQuestionAnswering", - "FlaxRobertaForTokenClassification", - "FlaxRobertaForMultipleChoice", - "FlaxRobertaForCausalLM", - "RobertaConfig" -) +from .roberta_configuration import RobertaConfig +from .modelling_roberta_flax import ( + FlaxRobertaForCausalLM, + FlaxRobertaForMultipleChoice, + FlaxRobertaForMaskedLMModule, + FlaxRobertaForQuestionAnswering, + FlaxRobertaForSequenceClassification, + FlaxRobertaForTokenClassification, +) + +__all__ = ( + "FlaxRobertaForSequenceClassification", + "FlaxRobertaForQuestionAnswering", + "FlaxRobertaForTokenClassification", + "FlaxRobertaForMultipleChoice", + "FlaxRobertaForCausalLM", + "RobertaConfig" +) diff --git a/src/python/easydel/modules/roberta/modelling_roberta_flax.py b/src/python/easydel/modules/roberta/modelling_roberta_flax.py index bdaede4ff..49a70d4b6 100644 --- a/src/python/easydel/modules/roberta/modelling_roberta_flax.py +++ b/src/python/easydel/modules/roberta/modelling_roberta_flax.py @@ -1,1409 +1,1409 @@ -# Model is modified version from EasyLM -# Supports 8,6,4 BIT and flash attention -import math -from gc import unfreeze -from typing import Optional, Tuple - -import chex -import fjformer -from flax import linen as nn -from flax.core import FrozenDict, freeze -from flax.linen.attention import make_attention_mask, make_causal_mask, combine_masks, dot_product_attention_weights -from flax.linen.partitioning import remat -from flax.traverse_util import flatten_dict, unflatten_dict -import flax.linen -from transformers.modeling_flax_outputs import FlaxBaseModelOutputWithPastAndCrossAttentions, FlaxMaskedLMOutput, \ - FlaxSequenceClassifierOutput, FlaxMultipleChoiceModelOutput, FlaxTokenClassifierOutput, \ - FlaxQuestionAnsweringModelOutput, FlaxCausalLMOutputWithCrossAttentions, \ - FlaxBaseModelOutputWithPoolingAndCrossAttentions - -from .roberta_configuration import RobertaConfig -import jax -from jax.sharding import PartitionSpec -from jax import lax, numpy as jnp -from ..easydel_modelling_utils import EasyDeLFlaxPretrainedModel -from ..attention_module import AttentionModule -from ..flax_modelling_utils import get_gradient_checkpoint_policy, ACT2FN, get_dot_general_by_bits, \ - BaseJAXAttentionModule - -from fjformer.linen import Linear - - -class FlaxRobertaEmbeddings(nn.Module): - config: RobertaConfig - dtype: jnp.dtype = jnp.float32 # the dtype of the computation - param_dtype: jnp.dtype = jnp.float32 - precision: Optional[lax.Precision] = None - - def setup(self): - self.word_embeddings = nn.Embed( - self.config.vocab_size, - self.config.hidden_size, - embedding_init=jax.nn.initializers.normal(stddev=self.config.initializer_range), - dtype=self.dtype, - ) - self.position_embeddings = nn.Embed( - self.config.max_position_embeddings, - self.config.hidden_size, - embedding_init=jax.nn.initializers.normal(stddev=self.config.initializer_range), - dtype=self.dtype, - ) - self.token_type_embeddings = nn.Embed( - self.config.type_vocab_size, - self.config.hidden_size, - embedding_init=jax.nn.initializers.normal(stddev=self.config.initializer_range), - dtype=self.dtype, - ) - self.LayerNorm = nn.LayerNorm(epsilon=self.config.layer_norm_eps, dtype=self.dtype) - self.dropout = flax.linen.Dropout(rate=self.config.hidden_dropout_prob) - - def __call__(self, input_ids, token_type_ids, position_ids, attention_mask, deterministic: bool = True): - inputs_embeds = self.word_embeddings(input_ids.astype("i4")) - position_embeds = self.position_embeddings(position_ids.astype("i4")) - token_type_embeddings = self.token_type_embeddings(token_type_ids.astype("i4")) - - hidden_states = inputs_embeds + token_type_embeddings + position_embeds - - hidden_states = self.LayerNorm(hidden_states) - hidden_states = self.dropout(hidden_states, deterministic=deterministic) - return hidden_states - - -class FlaxRobertaSelfAttention(BaseJAXAttentionModule): - config: RobertaConfig - causal: bool = False - dtype: jnp.dtype = jnp.float32 - param_dtype: jnp.dtype = jnp.float32 - precision: Optional[lax.Precision] = None - - def setup(self): - self.head_dim = self.config.hidden_size // self.config.num_attention_heads - if self.config.hidden_size % self.config.num_attention_heads != 0: - raise ValueError( - "`config.hidden_size`: {self.config.hidden_size} has to be a multiple of `config.num_attention_heads` " - " : {self.config.num_attention_heads}" - ) - self.attention_performer = AttentionModule( - use_sharding_constraint=self.config.use_sharding_constraint, - block_k_major=self.config.block_k_major, - block_b=self.config.block_b, - block_q=self.config.block_q, - block_k=self.config.block_k, - block_q_major_dkv=self.config.block_q_major_dkv, - block_k_major_dkv=self.config.block_k_major_dkv, - block_k_major_dq=self.config.block_k_major_dq, - block_k_dkv=self.config.block_k_dkv, - block_q_dkv=self.config.block_q_dkv, - block_q_dq=self.config.block_q_dq, - block_k_dq=self.config.block_k_dq, - num_attention_heads=self.config.num_attention_heads, - attention_dropout=0.0, - head_dims=self.head_dim, - attention_partition_spec=self.config.attention_partition_spec, - shard_attention_computation=self.config.shard_attention_computation, - precision=self.precision, - force_float32_tpu=True, - attn_mechanism=self.config.attn_mechanism, - dtype=self.dtype, - bias_partition_spec=self.config.bias_partition_spec, - key_partition_spec=self.config.key_partition_spec, - query_partition_spec=self.config.query_partition_spec, - generation_query_partition_spec=self.config.generation_query_partition_spec, - generation_bias_partition_spec=self.config.generation_bias_partition_spec, - generation_attention_partition_spec=self.config.generation_attention_partition_spec, - value_partition_spec=self.config.value_partition_spec, - scan_ring_attention=self.config.scan_ring_attention, - mesh=self.config.jax_mesh(), - sm_scale=1 / math.sqrt(self.head_dim), - axis_name=self.config.attention_axis_name, - backward_pass_impl=self.config.flash_attention_backward_pass_impl - ) - self.query = Linear( - self.config.hidden_size, - dtype=self.dtype, - param_dtype=self.param_dtype, - kernel_init=jax.nn.initializers.normal(self.config.initializer_range), - **get_dot_general_by_bits(bits=self.config.bits, mode=self.config.easy_method) - ) - self.key = Linear( - self.config.hidden_size, - dtype=self.dtype, - param_dtype=self.param_dtype, - kernel_init=jax.nn.initializers.normal(self.config.initializer_range), - **get_dot_general_by_bits(bits=self.config.bits, mode=self.config.easy_method) - ) - self.value = Linear( - self.config.hidden_size, - dtype=self.dtype, - param_dtype=self.param_dtype, - kernel_init=jax.nn.initializers.normal(self.config.initializer_range), - **get_dot_general_by_bits(bits=self.config.bits, mode=self.config.easy_method) - ) - - if self.causal: - self.causal_mask = make_causal_mask( - jnp.ones((1, getattr(self.config, "c_max_position_embeddings", self.config.max_position_embeddings)), - dtype="bool"), dtype="bool" - ) - - def _split_heads(self, hidden_states): - return hidden_states.reshape(hidden_states.shape[:2] + (self.config.num_attention_heads, self.head_dim)) - - def _merge_heads(self, hidden_states): - return hidden_states.reshape(hidden_states.shape[:2] + (self.config.hidden_size,)) - - def __call__( - self, - hidden_states, - attention_mask, - layer_head_mask, - segment_ids: Optional[chex.Array] = None, - key_value_states: Optional[jnp.array] = None, - init_cache: bool = False, - deterministic=True, - output_attentions: bool = False, - ): - is_cross_attention = key_value_states is not None - batch_size = hidden_states.shape[0] - - query_states = self.query(hidden_states) - if is_cross_attention: - key_states = self.key(key_value_states) - value_states = self.value(key_value_states) - else: - key_states = self.key(hidden_states) - value_states = self.value(hidden_states) - - query_states = self._split_heads(query_states) - key_states = self._split_heads(key_states) - value_states = self._split_heads(value_states) - - if self.causal: - query_length, key_length = query_states.shape[1], key_states.shape[1] - if self.has_variable("cache", "cached_key"): - mask_shift = self.variables["cache"]["cache_index"] - max_decoder_length = self.variables["cache"]["cached_key"].shape[1] - causal_mask = lax.dynamic_slice( - self.causal_mask, (0, 0, mask_shift, 0), (1, 1, query_length, max_decoder_length) - ) - else: - causal_mask = self.causal_mask[:, :, :query_length, :key_length] - causal_mask = jnp.broadcast_to(causal_mask, (batch_size,) + causal_mask.shape[1:]) - - if attention_mask is not None and self.causal: - attention_mask = jnp.broadcast_to(jnp.expand_dims(attention_mask, axis=(-3, -2)), causal_mask.shape) - attention_mask = combine_masks(attention_mask, causal_mask) - elif self.causal: - attention_mask = causal_mask - elif attention_mask is not None: - attention_mask = jnp.expand_dims(attention_mask, axis=(-3, -2)) - - if self.causal and (self.has_variable("cache", "cached_key") or init_cache): - key_states, value_states, attention_mask = self._concatenate_to_cache( - key_states, value_states, query_states, attention_mask - ) - - if attention_mask is not None: - attention_bias = lax.select( - attention_mask > 0, - jnp.full(attention_mask.shape, 0.0).astype(self.dtype), - jnp.full(attention_mask.shape, jnp.finfo(self.dtype).min).astype(self.dtype), - ) - else: - attention_bias = None - - dropout_rng = None - if not deterministic and self.config.attention_probs_dropout_prob > 0.0: - dropout_rng = self.make_rng("dropout") - if layer_head_mask is None: - out = self.attention_performer.__call__( - query_states=query_states, - key_states=key_states, - value_states=value_states, - dropout_rng=dropout_rng, - deterministic=deterministic, - causal=True, - bias=attention_bias, - attention_mask=attention_mask, - uses_cache=False, - query_sequence_length=query_states.shape[1], - key_value_sequence_length=key_states.shape[1], - segment_ids=segment_ids, - causal_mask=causal_mask - ) - attn_weights = out.attention_weights - attn_output = out.attention_outputs - else: - - attn_weights = dot_product_attention_weights( - query_states, - key_states, - bias=attention_bias, - dropout_rng=dropout_rng, - dropout_rate=self.config.attention_probs_dropout_prob, - broadcast_dropout=True, - deterministic=deterministic, - dtype=self.dtype, - precision=None, - ) - - attn_weights = jnp.einsum("...hqk,h->...hqk", attn_weights, layer_head_mask) - attn_output = jnp.einsum("...hqk,...khd->...qhd", attn_weights, value_states) - - attn_output = attn_output.reshape(attn_output.shape[:2] + (-1,)) - - outputs = (attn_output, attn_weights) if output_attentions else (attn_output,) - return outputs - - -class FlaxRobertaSelfOutput(nn.Module): - config: RobertaConfig - dtype: jnp.dtype = jnp.float32 - param_dtype: jnp.dtype = jnp.float32 - precision: Optional[lax.Precision] = None - - def setup(self): - self.dense = Linear( - self.config.hidden_size, - kernel_init=jax.nn.initializers.normal(self.config.initializer_range), - dtype=self.dtype, - param_dtype=self.param_dtype, - precision=self.precision, - **get_dot_general_by_bits(bits=self.config.bits, mode=self.config.easy_method) - ) - self.LayerNorm = nn.LayerNorm(epsilon=self.config.layer_norm_eps, dtype=self.dtype) - self.dropout = flax.linen.Dropout(rate=self.config.hidden_dropout_prob) - - def __call__(self, hidden_states, input_tensor, deterministic: bool = True): - hidden_states = self.dense(hidden_states) - hidden_states = self.dropout(hidden_states, deterministic=deterministic) - hidden_states = self.LayerNorm(hidden_states + input_tensor) - return hidden_states - - -class FlaxRobertaAttention(nn.Module): - config: RobertaConfig - causal: bool = False - dtype: jnp.dtype = jnp.float32 - param_dtype: jnp.dtype = jnp.float32 - precision: Optional[lax.Precision] = None - - def setup(self): - self.self = FlaxRobertaSelfAttention( - self.config, - causal=self.causal, - dtype=self.dtype, - param_dtype=self.param_dtype, - precision=self.precision - ) - self.output = FlaxRobertaSelfOutput(self.config, dtype=self.dtype) - - def __call__( - self, - hidden_states, - attention_mask, - layer_head_mask, - key_value_states=None, - init_cache=False, - deterministic=True, - output_attentions: bool = False, - ): - attn_outputs = self.self( - hidden_states, - attention_mask, - layer_head_mask=layer_head_mask, - key_value_states=key_value_states, - init_cache=init_cache, - deterministic=deterministic, - output_attentions=output_attentions, - ) - attn_output = attn_outputs[0] - hidden_states = self.output(attn_output, hidden_states, deterministic=deterministic) - - outputs = (hidden_states,) - - if output_attentions: - outputs += (attn_outputs[1],) - - return outputs - - -class FlaxRobertaIntermediate(nn.Module): - config: RobertaConfig - dtype: jnp.dtype = jnp.float32 - param_dtype: jnp.dtype = jnp.float32 - precision: Optional[lax.Precision] = None - - def setup(self): - self.dense = Linear( - self.config.intermediate_size, - kernel_init=jax.nn.initializers.normal(self.config.initializer_range), - dtype=self.dtype, - param_dtype=self.param_dtype, - precision=self.precision, - **get_dot_general_by_bits(bits=self.config.bits, mode=self.config.easy_method) - ) - self.activation = ACT2FN[self.config.hidden_act] - - def __call__(self, hidden_states): - hidden_states = self.dense(hidden_states) - hidden_states = self.activation(hidden_states) - return hidden_states - - -class FlaxRobertaOutput(nn.Module): - config: RobertaConfig - dtype: jnp.dtype = jnp.float32 - param_dtype: jnp.dtype = jnp.float32 - precision: Optional[lax.Precision] = None - - def setup(self): - self.dense = Linear( - self.config.hidden_size, - kernel_init=jax.nn.initializers.normal(self.config.initializer_range), - dtype=self.dtype, - precision=self.precision, - param_dtype=self.param_dtype, - **get_dot_general_by_bits(bits=self.config.bits, mode=self.config.easy_method) - ) - self.dropout = flax.linen.Dropout(rate=self.config.hidden_dropout_prob) - self.LayerNorm = nn.LayerNorm(epsilon=self.config.layer_norm_eps, dtype=self.dtype) - - def __call__(self, hidden_states, attention_output, deterministic: bool = True): - hidden_states = self.dense(hidden_states) - hidden_states = self.dropout(hidden_states, deterministic=deterministic) - hidden_states = self.LayerNorm(hidden_states + attention_output) - return hidden_states - - -class FlaxRobertaLayer(nn.Module): - config: RobertaConfig - dtype: jnp.dtype = jnp.float32 - param_dtype: jnp.dtype = jnp.float32 - precision: Optional[lax.Precision] = None - - def setup(self): - self.attention = FlaxRobertaAttention( - self.config, - causal=self.config.is_decoder, - dtype=self.dtype, - param_dtype=self.param_dtype, - precision=self.precision - ) - self.intermediate = FlaxRobertaIntermediate(self.config, dtype=self.dtype) - self.output = FlaxRobertaOutput(self.config, dtype=self.dtype) - if self.config.add_cross_attention: - self.crossattention = FlaxRobertaAttention( - self.config, - causal=True, - dtype=self.dtype, - param_dtype=self.param_dtype, - precision=self.precision - ) - - def __call__( - self, - hidden_states, - attention_mask, - layer_head_mask, - encoder_hidden_states: Optional[jnp.ndarray] = None, - encoder_attention_mask: Optional[jnp.ndarray] = None, - init_cache: bool = False, - deterministic: bool = True, - output_attentions: bool = False, - ): - # Self Attention - attention_outputs = self.attention( - hidden_states, - attention_mask, - layer_head_mask=layer_head_mask, - init_cache=init_cache, - deterministic=deterministic, - output_attentions=output_attentions, - ) - attention_output = attention_outputs[0] - - # Cross-Attention Block - if encoder_hidden_states is not None: - cross_attention_outputs = self.crossattention( - attention_output, - attention_mask=encoder_attention_mask, - layer_head_mask=layer_head_mask, - key_value_states=encoder_hidden_states, - deterministic=deterministic, - output_attentions=output_attentions, - ) - attention_output = cross_attention_outputs[0] - - hidden_states = self.intermediate(attention_output) - hidden_states = self.output(hidden_states, attention_output, deterministic=deterministic) - - outputs = (hidden_states,) - - if output_attentions: - outputs += (attention_outputs[1],) - if encoder_hidden_states is not None: - outputs += (cross_attention_outputs[1],) - return outputs - - -class FlaxRobertaLayerCollection(nn.Module): - config: RobertaConfig - dtype: jnp.dtype = jnp.float32 # the dtype of the computation - param_dtype: jnp.dtype = jnp.float32 - precision: Optional[lax.Precision] = None - - def setup(self): - block = FlaxRobertaLayer - if self.config.gradient_checkpointing != "": - block = remat( - block, - static_argnums=(5, 6, 7), - policy=get_gradient_checkpoint_policy(self.config.gradient_checkpointing) - ) - - self.layers = [ - block( - self.config, - name=str(i), - dtype=self.dtype, - param_dtype=self.param_dtype, - precision=self.precision - ) - for i in range(self.config.num_hidden_layers) - ] - - def __call__( - self, - hidden_states, - attention_mask, - head_mask, - encoder_hidden_states: Optional[jnp.ndarray] = None, - encoder_attention_mask: Optional[jnp.ndarray] = None, - init_cache: bool = False, - deterministic: bool = True, - output_attentions: bool = False, - output_hidden_states: bool = False, - return_dict: bool = True, - ): - all_attentions = () if output_attentions else None - all_hidden_states = () if output_hidden_states else None - all_cross_attentions = () if (output_attentions and encoder_hidden_states is not None) else None - - # Check if head_mask has a correct number of layers specified if desired - if head_mask is not None: - if head_mask.shape[0] != (len(self.layers)): - raise ValueError( - f"The head_mask should be specified for {len(self.layers)} layers, but it is for " - f" {head_mask.shape[0]}." - ) - - for i, layer in enumerate(self.layers): - if output_hidden_states: - all_hidden_states += (hidden_states,) - - layer_outputs = layer( - hidden_states, - attention_mask, - head_mask[i] if head_mask is not None else None, - encoder_hidden_states, - encoder_attention_mask, - init_cache, - deterministic, - output_attentions, - ) - - hidden_states = layer_outputs[0] - - if output_attentions: - all_attentions += (layer_outputs[1],) - - if encoder_hidden_states is not None: - all_cross_attentions += (layer_outputs[2],) - - if output_hidden_states: - all_hidden_states += (hidden_states,) - - outputs = (hidden_states, all_hidden_states, all_attentions, all_cross_attentions) - - if not return_dict: - return tuple(v for v in outputs if v is not None) - - return FlaxBaseModelOutputWithPastAndCrossAttentions( - last_hidden_state=hidden_states, - hidden_states=all_hidden_states, - attentions=all_attentions, - cross_attentions=all_cross_attentions, - ) - - -class FlaxRobertaEncoder(nn.Module): - config: RobertaConfig - dtype: jnp.dtype = jnp.float32 - param_dtype: jnp.dtype = jnp.float32 - precision: Optional[lax.Precision] = None - - def setup(self): - self.layer = FlaxRobertaLayerCollection( - self.config, - dtype=self.dtype, - param_dtype=self.param_dtype, - precision=self.precision, - ) - - def __call__( - self, - hidden_states, - attention_mask, - head_mask, - encoder_hidden_states: Optional[jnp.ndarray] = None, - encoder_attention_mask: Optional[jnp.ndarray] = None, - init_cache: bool = False, - deterministic: bool = True, - output_attentions: bool = False, - output_hidden_states: bool = False, - return_dict: bool = True, - ): - return self.layer( - hidden_states, - attention_mask, - head_mask=head_mask, - encoder_hidden_states=encoder_hidden_states, - encoder_attention_mask=encoder_attention_mask, - init_cache=init_cache, - deterministic=deterministic, - output_attentions=output_attentions, - output_hidden_states=output_hidden_states, - return_dict=return_dict, - ) - - -class FlaxRobertaPooler(nn.Module): - config: RobertaConfig - dtype: jnp.dtype = jnp.float32 - param_dtype: jnp.dtype = jnp.float32 - precision: Optional[lax.Precision] = None - - def setup(self): - self.dense = Linear( - self.config.hidden_size, - kernel_init=jax.nn.initializers.normal(self.config.initializer_range), - dtype=self.dtype, - param_dtype=self.param_dtype, - precision=self.precision, - **get_dot_general_by_bits(bits=self.config.bits, mode=self.config.easy_method) - ) - - def __call__(self, hidden_states): - cls_hidden_state = hidden_states[:, 0] - cls_hidden_state = self.dense(cls_hidden_state) - return nn.tanh(cls_hidden_state) - - -class FlaxRobertaLMHead(nn.Module): - config: RobertaConfig - dtype: jnp.dtype = jnp.float32 - param_dtype: jnp.dtype = jnp.float32 - precision: Optional[lax.Precision] = None - - def setup(self): - self.dense = Linear( - self.config.hidden_size, - dtype=self.dtype, - param_dtype=self.param_dtype, - precision=self.precision, - kernel_init=jax.nn.initializers.normal(self.config.initializer_range), - **get_dot_general_by_bits(bits=self.config.bits, mode=self.config.easy_method) - ) - self.layer_norm = nn.LayerNorm(epsilon=self.config.layer_norm_eps, dtype=self.dtype) - self.decoder = Linear( - self.config.vocab_size, - dtype=self.dtype, - use_bias=False, - param_dtype=self.param_dtype, - precision=self.precision, - kernel_init=jax.nn.initializers.normal(self.config.initializer_range), - **get_dot_general_by_bits(bits=self.config.bits, mode=self.config.easy_method) - ) - self.bias = self.param( - "bias", - jax.nn.initializers.zeros, - ( - self.config.vocab_size, - ) - ) - - def __call__(self, hidden_states, shared_embedding=None): - hidden_states = self.dense(hidden_states) - hidden_states = ACT2FN["gelu"](hidden_states) - hidden_states = self.layer_norm(hidden_states) - - if shared_embedding is not None: - hidden_states = self.decoder.apply({"params": {"kernel": shared_embedding.T}}, hidden_states) - else: - hidden_states = self.decoder(hidden_states) - - bias = fjformer.linen.linen.control_quantization(self.bias, self.dtype) - hidden_states += bias - return hidden_states - - -class FlaxRobertaClassificationHead(nn.Module): - config: RobertaConfig - dtype: jnp.dtype = jnp.float32 - param_dtype: jnp.dtype = jnp.float32 - precision: Optional[lax.Precision] = None - - def setup(self): - self.dense = Linear( - self.config.hidden_size, - dtype=self.dtype, - kernel_init=jax.nn.initializers.normal(self.config.initializer_range), - param_dtype=self.param_dtype, - precision=self.precision, - **get_dot_general_by_bits(bits=self.config.bits, mode=self.config.easy_method) - ) - classifier_dropout = ( - self.config.classifier_dropout - if self.config.classifier_dropout is not None - else self.config.hidden_dropout_prob - ) - self.dropout = flax.linen.Dropout(rate=classifier_dropout) - self.out_proj = Linear( - self.config.num_labels, - dtype=self.dtype, - param_dtype=self.param_dtype, - precision=self.precision, - kernel_init=jax.nn.initializers.normal(self.config.initializer_range), - **get_dot_general_by_bits(bits=self.config.bits, mode=self.config.easy_method) - ) - - def __call__(self, hidden_states, deterministic=True): - hidden_states = hidden_states[:, 0, :] - hidden_states = self.dropout(hidden_states, deterministic=deterministic) - hidden_states = self.dense(hidden_states) - hidden_states = nn.tanh(hidden_states) - hidden_states = self.dropout(hidden_states, deterministic=deterministic) - hidden_states = self.out_proj(hidden_states) - return hidden_states - - -class FlaxRobertaPreTrainedModel(EasyDeLFlaxPretrainedModel): - config_class = RobertaConfig - base_model_prefix = "roberta" - - module_class: nn.Module = None - - def __init__( - self, - config: RobertaConfig, - input_shape: Tuple = (1, 1), - seed: int = 0, - dtype: jnp.dtype = jnp.float32, - param_dtype: jnp.dtype = jnp.float32, - precision: Optional[lax.Precision] = None, - _do_init: bool = True, - **kwargs, - ): - module = self.module_class( - config=config, - dtype=dtype, - param_dtype=param_dtype, - precision=precision, - **kwargs) - super().__init__(config, module, input_shape=input_shape, seed=seed, dtype=dtype, _do_init=_do_init) - - def init_weights(self, rng: jax.random.PRNGKey, input_shape: Tuple, params: FrozenDict = None) -> FrozenDict: - input_ids = jnp.zeros(input_shape, dtype="i4") - token_type_ids = jnp.ones_like(input_ids) - mask = (input_ids != self.config.pad_token_id).astype("i4") - - if mask.ndim > 2: - mask = mask.reshape((-1, mask.shape[-1])) - incremental_indices = jnp.cumsum(mask, axis=1).astype("i4") * mask - incremental_indices = incremental_indices.reshape(input_ids.shape) - else: - incremental_indices = jnp.cumsum(mask, axis=1).astype("i4") * mask - - position_ids = incremental_indices.astype("i4") + self.config.pad_token_id - - attention_mask = jnp.ones_like(input_ids) - head_mask = jnp.ones((self.config.num_hidden_layers, self.config.num_attention_heads)) - - params_rng, dropout_rng = jax.random.split(rng) - rngs = {"params": params_rng, "dropout": dropout_rng} - - if self.config.add_cross_attention: - encoder_hidden_states = jnp.zeros(input_shape + (self.config.hidden_size,)) - encoder_attention_mask = attention_mask - module_init_outputs = self.module.init( - rngs, - input_ids, - attention_mask, - token_type_ids, - position_ids, - head_mask, - encoder_hidden_states, - encoder_attention_mask, - return_dict=False, - ) - else: - module_init_outputs = self.module.init( - rngs, input_ids, attention_mask, token_type_ids, position_ids, head_mask, return_dict=False - ) - - random_params = module_init_outputs["params"] - - if params is not None: - random_params = flatten_dict(unfreeze(random_params)) - params = flatten_dict(unfreeze(params)) - for missing_key in self._missing_keys: - params[missing_key] = random_params[missing_key] - self._missing_keys = set() - return freeze(unflatten_dict(params)) - else: - return random_params - - def init_cache(self, batch_size, max_length): - - # init input variables to retrieve cache - input_ids = jnp.ones((batch_size, max_length), dtype="i4") - attention_mask = jnp.ones_like(input_ids, dtype="i4") - position_ids = jnp.broadcast_to(jnp.arange(jnp.atleast_2d(input_ids).shape[-1]), input_ids.shape) - - init_variables = self.module.init( - jax.random.PRNGKey(0), input_ids, attention_mask, position_ids, return_dict=False, init_cache=True - ) - return unfreeze(init_variables["cache"]) - - def __call__( - self, - input_ids, - attention_mask=None, - token_type_ids=None, - position_ids=None, - head_mask=None, - encoder_hidden_states=None, - encoder_attention_mask=None, - params: dict = None, - dropout_rng: jax.random.PRNGKey = None, - train: bool = False, - output_attentions: Optional[bool] = None, - output_hidden_states: Optional[bool] = None, - return_dict: Optional[bool] = None, - past_key_values: dict = None, - ): - output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions - output_hidden_states = ( - output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states - ) - return_dict = return_dict if return_dict is not None else self.config.return_dict - - # init input tensors if not passed - if token_type_ids is None: - token_type_ids = jnp.zeros_like(input_ids) - - if position_ids is None: - mask = (input_ids != self.config.pad_token_id).astype("i4") - - if mask.ndim > 2: - mask = mask.reshape((-1, mask.shape[-1])) - incremental_indices = jnp.cumsum(mask, axis=1).astype("i4") * mask - incremental_indices = incremental_indices.reshape(input_ids.shape) - else: - incremental_indices = jnp.cumsum(mask, axis=1).astype("i4") * mask - - position_ids = incremental_indices.astype("i4") + self.config.pad_token_id - - if attention_mask is None: - attention_mask = jnp.ones_like(input_ids) - - if head_mask is None: - head_mask = jnp.ones((self.config.num_hidden_layers, self.config.num_attention_heads)) - - # Handle any PRNG if needed - rngs = {} - if dropout_rng is not None: - rngs["dropout"] = dropout_rng - - inputs = {"params": params or self.params} - - if self.config.add_cross_attention: - if past_key_values: - inputs["cache"] = past_key_values - mutable = ["cache"] - else: - mutable = False - - outputs = self.module.apply( - inputs, - jnp.array(input_ids, dtype="i4"), - jnp.array(attention_mask, dtype="i4"), - token_type_ids=jnp.array(token_type_ids, dtype="i4"), - position_ids=jnp.array(position_ids, dtype="i4"), - head_mask=jnp.array(head_mask, dtype="i4"), - encoder_hidden_states=encoder_hidden_states, - encoder_attention_mask=encoder_attention_mask, - deterministic=not train, - output_attentions=output_attentions, - output_hidden_states=output_hidden_states, - return_dict=return_dict, - rngs=rngs, - mutable=mutable, - ) - - # add updated cache to model output - if past_key_values is not None and return_dict: - outputs, past_key_values = outputs - outputs["past_key_values"] = unfreeze(past_key_values["cache"]) - return outputs - elif past_key_values is not None and not return_dict: - outputs, past_key_values = outputs - outputs = outputs[:1] + (unfreeze(past_key_values["cache"]),) + outputs[1:] - - else: - outputs = self.module.apply( - inputs, - jnp.array(input_ids, dtype="i4"), - jnp.array(attention_mask, dtype="i4"), - token_type_ids=jnp.array(token_type_ids, dtype="i4"), - position_ids=jnp.array(position_ids, dtype="i4"), - head_mask=jnp.array(head_mask, dtype="i4"), - deterministic=not train, - output_attentions=output_attentions, - output_hidden_states=output_hidden_states, - return_dict=return_dict, - rngs=rngs, - ) - - return outputs - - -# Copied from transformers.models.bert.modeling_flax_bert.FlaxBertModule with Bert->Roberta -class FlaxRobertaModule(nn.Module): - config: RobertaConfig - dtype: jnp.dtype = jnp.float32 # the dtype of the computation - param_dtype: jnp.dtype = jnp.float32 - precision: Optional[lax.Precision] = None - add_pooling_layer: bool = True - - def setup(self): - self.embeddings = FlaxRobertaEmbeddings( - config=self.config, - dtype=self.dtype, - param_dtype=self.param_dtype, - precision=self.precision - ) - self.encoder = FlaxRobertaEncoder( - config=self.config, - dtype=self.dtype, - param_dtype=self.param_dtype, - precision=self.precision - ) - self.pooler = FlaxRobertaPooler( - config=self.config, - dtype=self.dtype, - param_dtype=self.param_dtype, - precision=self.precision - ) - - def __call__( - self, - input_ids, - attention_mask, - token_type_ids: Optional[jnp.ndarray] = None, - position_ids: Optional[jnp.ndarray] = None, - head_mask: Optional[jnp.ndarray] = None, - encoder_hidden_states: Optional[jnp.ndarray] = None, - encoder_attention_mask: Optional[jnp.ndarray] = None, - init_cache: bool = False, - deterministic: bool = True, - output_attentions: bool = False, - output_hidden_states: bool = False, - return_dict: bool = True, - ): - # make sure `token_type_ids` is correctly initialized when not passed - if token_type_ids is None: - token_type_ids = jnp.zeros_like(input_ids) - - # make sure `position_ids` is correctly initialized when not passed - if position_ids is None: - position_ids = jnp.broadcast_to(jnp.arange(jnp.atleast_2d(input_ids).shape[-1]), input_ids.shape) - - hidden_states = self.embeddings( - input_ids, token_type_ids, position_ids, attention_mask, deterministic=deterministic - ) - outputs = self.encoder( - hidden_states, - attention_mask, - head_mask=head_mask, - deterministic=deterministic, - encoder_hidden_states=encoder_hidden_states, - encoder_attention_mask=encoder_attention_mask, - init_cache=init_cache, - output_attentions=output_attentions, - output_hidden_states=output_hidden_states, - return_dict=return_dict, - ) - hidden_states = outputs[0] - pooled = self.pooler(hidden_states) if self.add_pooling_layer else None - - if not return_dict: - # if pooled is None, don't return it - if pooled is None: - return (hidden_states,) + outputs[1:] - return (hidden_states, pooled) + outputs[1:] - - return FlaxBaseModelOutputWithPoolingAndCrossAttentions( - last_hidden_state=hidden_states, - pooler_output=pooled, - hidden_states=outputs.hidden_states, - attentions=outputs.attentions, - cross_attentions=outputs.cross_attentions, - ) - - -class FlaxRobertaModel(FlaxRobertaPreTrainedModel): - module_class = FlaxRobertaModule - - -class FlaxRobertaForMaskedLMModule(nn.Module): - config: RobertaConfig - dtype: jnp.dtype = jnp.float32 # the dtype of the computation - param_dtype: jnp.dtype = jnp.float32 - precision: Optional[lax.Precision] = None - - def setup(self): - self.roberta = FlaxRobertaModule( - config=self.config, - add_pooling_layer=False, - dtype=self.dtype, - param_dtype=self.param_dtype, - precision=self.precision - ) - self.lm_head = FlaxRobertaLMHead( - config=self.config, - dtype=self.dtype, - param_dtype=self.param_dtype, - precision=self.precision - ) - - def __call__( - self, - input_ids, - attention_mask, - token_type_ids, - position_ids, - head_mask, - deterministic: bool = True, - output_attentions: bool = False, - output_hidden_states: bool = False, - return_dict: bool = True, - ): - # Model - outputs = self.roberta( - input_ids, - attention_mask, - token_type_ids, - position_ids, - head_mask, - deterministic=deterministic, - output_attentions=output_attentions, - output_hidden_states=output_hidden_states, - return_dict=return_dict, - ) - - hidden_states = outputs[0] - if self.config.tie_word_embeddings: - shared_embedding = self.roberta.variables["params"]["embeddings"]["word_embeddings"]["embedding"] - shared_embedding = fjformer.linen.linen.control_quantization(shared_embedding, self.param_dtype) - else: - shared_embedding = None - - # Compute the prediction scores - logits = self.lm_head(hidden_states, shared_embedding=shared_embedding) - - if not return_dict: - return (logits,) + outputs[1:] - - return FlaxMaskedLMOutput( - logits=logits, - hidden_states=outputs.hidden_states, - attentions=outputs.attentions, - ) - - -class FlaxRobertaForSequenceClassificationModule(nn.Module): - config: RobertaConfig - dtype: jnp.dtype = jnp.float32 # the dtype of the computation - param_dtype: jnp.dtype = jnp.float32 - precision: Optional[lax.Precision] = None - - def setup(self): - self.roberta = FlaxRobertaModule( - config=self.config, - dtype=self.dtype, - add_pooling_layer=False, - param_dtype=self.param_dtype, - precision=self.precision - ) - self.classifier = FlaxRobertaClassificationHead( - config=self.config, - dtype=self.dtype, - param_dtype=self.param_dtype, - precision=self.precision - ) - - def __call__( - self, - input_ids, - attention_mask, - token_type_ids, - position_ids, - head_mask, - deterministic: bool = True, - output_attentions: bool = False, - output_hidden_states: bool = False, - return_dict: bool = True, - ): - # Model - outputs = self.roberta( - input_ids, - attention_mask, - token_type_ids, - position_ids, - head_mask, - deterministic=deterministic, - output_attentions=output_attentions, - output_hidden_states=output_hidden_states, - return_dict=return_dict, - ) - - sequence_output = outputs[0] - logits = self.classifier(sequence_output, deterministic=deterministic) - - if not return_dict: - return (logits,) + outputs[1:] - - return FlaxSequenceClassifierOutput( - logits=logits, - hidden_states=outputs.hidden_states, - attentions=outputs.attentions, - ) - - -class FlaxRobertaForSequenceClassification(FlaxRobertaPreTrainedModel): - module_class = FlaxRobertaForSequenceClassificationModule - - -class FlaxRobertaForMultipleChoiceModule(nn.Module): - config: RobertaConfig - dtype: jnp.dtype = jnp.float32 # the dtype of the computation - param_dtype: jnp.dtype = jnp.float32 - precision: Optional[lax.Precision] = None - - def setup(self): - self.roberta = FlaxRobertaModule( - config=self.config, - dtype=self.dtype, - param_dtype=self.param_dtype, - precision=self.precision - ) - self.dropout = flax.linen.Dropout(rate=self.config.hidden_dropout_prob) - self.classifier = Linear( - 1, - dtype=self.dtype, - param_dtype=self.param_dtype, - precision=self.precision, - **get_dot_general_by_bits(bits=self.config.bits, mode=self.config.easy_method) - ) - - def __call__( - self, - input_ids, - attention_mask, - token_type_ids, - position_ids, - head_mask, - deterministic: bool = True, - output_attentions: bool = False, - output_hidden_states: bool = False, - return_dict: bool = True, - ): - num_choices = input_ids.shape[1] - input_ids = input_ids.reshape(-1, input_ids.shape[-1]) if input_ids is not None else None - attention_mask = attention_mask.reshape(-1, attention_mask.shape[-1]) if attention_mask is not None else None - token_type_ids = token_type_ids.reshape(-1, token_type_ids.shape[-1]) if token_type_ids is not None else None - position_ids = position_ids.reshape(-1, position_ids.shape[-1]) if position_ids is not None else None - - # Model - outputs = self.roberta( - input_ids, - attention_mask, - token_type_ids, - position_ids, - head_mask, - deterministic=deterministic, - output_attentions=output_attentions, - output_hidden_states=output_hidden_states, - return_dict=return_dict, - ) - - pooled_output = outputs[1] - pooled_output = self.dropout(pooled_output, deterministic=deterministic) - logits = self.classifier(pooled_output) - - reshaped_logits = logits.reshape(-1, num_choices) - - if not return_dict: - return (reshaped_logits,) + outputs[2:] - - return FlaxMultipleChoiceModelOutput( - logits=reshaped_logits, - hidden_states=outputs.hidden_states, - attentions=outputs.attentions, - ) - - -class FlaxRobertaForMultipleChoice(FlaxRobertaPreTrainedModel): - module_class = FlaxRobertaForMultipleChoiceModule - - -class FlaxRobertaForTokenClassificationModule(nn.Module): - config: RobertaConfig - dtype: jnp.dtype = jnp.float32 # the dtype of the computation - param_dtype: jnp.dtype = jnp.float32 - precision: Optional[lax.Precision] = None - - def setup(self): - self.roberta = FlaxRobertaModule( - config=self.config, - dtype=self.dtype, - add_pooling_layer=False, - param_dtype=self.param_dtype, - precision=self.precision - ) - classifier_dropout = ( - self.config.classifier_dropout - if self.config.classifier_dropout is not None - else self.config.hidden_dropout_prob - ) - self.dropout = flax.linen.Dropout(rate=classifier_dropout) - self.classifier = Linear( - self.config.num_labels, - dtype=self.dtype, - param_dtype=self.param_dtype, - precision=self.precision, - **get_dot_general_by_bits(bits=self.config.bits, mode=self.config.easy_method) - - ) - - def __call__( - self, - input_ids, - attention_mask, - token_type_ids, - position_ids, - head_mask, - deterministic: bool = True, - output_attentions: bool = False, - output_hidden_states: bool = False, - return_dict: bool = True, - ): - # Model - outputs = self.roberta( - input_ids, - attention_mask, - token_type_ids, - position_ids, - head_mask, - deterministic=deterministic, - output_attentions=output_attentions, - output_hidden_states=output_hidden_states, - return_dict=return_dict, - ) - - hidden_states = outputs[0] - hidden_states = self.dropout(hidden_states, deterministic=deterministic) - logits = self.classifier(hidden_states) - - if not return_dict: - return (logits,) + outputs[1:] - - return FlaxTokenClassifierOutput( - logits=logits, - hidden_states=outputs.hidden_states, - attentions=outputs.attentions, - ) - - -class FlaxRobertaForTokenClassification(FlaxRobertaPreTrainedModel): - module_class = FlaxRobertaForTokenClassificationModule - - -class FlaxRobertaForQuestionAnsweringModule(nn.Module): - config: RobertaConfig - dtype: jnp.dtype = jnp.float32 # the dtype of the computation - param_dtype: jnp.dtype = jnp.float32 - precision: Optional[lax.Precision] = None - - def setup(self): - self.roberta = FlaxRobertaModule( - config=self.config, - dtype=self.dtype, - add_pooling_layer=False, - param_dtype=self.param_dtype, - precision=self.precision - ) - self.qa_outputs = Linear( - self.config.num_labels, - dtype=self.dtype, - param_dtype=self.param_dtype, - precision=self.precision, - **get_dot_general_by_bits(bits=self.config.bits, mode=self.config.easy_method) - ) - - def __call__( - self, - input_ids, - attention_mask, - token_type_ids, - position_ids, - head_mask, - deterministic: bool = True, - output_attentions: bool = False, - output_hidden_states: bool = False, - return_dict: bool = True, - ): - # Model - outputs = self.roberta( - input_ids, - attention_mask, - token_type_ids, - position_ids, - head_mask, - deterministic=deterministic, - output_attentions=output_attentions, - output_hidden_states=output_hidden_states, - return_dict=return_dict, - ) - - hidden_states = outputs[0] - - logits = self.qa_outputs(hidden_states) - start_logits, end_logits = logits.split(self.config.num_labels, axis=-1) - start_logits = start_logits.squeeze(-1) - end_logits = end_logits.squeeze(-1) - - if not return_dict: - return (start_logits, end_logits) + outputs[1:] - - return FlaxQuestionAnsweringModelOutput( - start_logits=start_logits, - end_logits=end_logits, - hidden_states=outputs.hidden_states, - attentions=outputs.attentions, - ) - - -class FlaxRobertaForQuestionAnswering(FlaxRobertaPreTrainedModel): - module_class = FlaxRobertaForQuestionAnsweringModule - - -class FlaxRobertaForCausalLMModule(nn.Module): - config: RobertaConfig - dtype: jnp.dtype = jnp.float32 # the dtype of the computation - param_dtype: jnp.dtype = jnp.float32 - precision: Optional[lax.Precision] = None - - def setup(self): - self.roberta = FlaxRobertaModule( - config=self.config, - add_pooling_layer=False, - dtype=self.dtype, - param_dtype=self.param_dtype, - precision=self.precision - ) - self.lm_head = FlaxRobertaLMHead( - config=self.config, - dtype=self.dtype, - param_dtype=self.param_dtype, - precision=self.precision - ) - - def __call__( - self, - input_ids, - attention_mask, - position_ids, - token_type_ids: Optional[jnp.ndarray] = None, - head_mask: Optional[jnp.ndarray] = None, - encoder_hidden_states: Optional[jnp.ndarray] = None, - encoder_attention_mask: Optional[jnp.ndarray] = None, - init_cache: bool = False, - deterministic: bool = True, - output_attentions: bool = False, - output_hidden_states: bool = False, - return_dict: bool = True, - ): - # Model - outputs = self.roberta( - input_ids, - attention_mask, - token_type_ids, - position_ids, - head_mask, - encoder_hidden_states=encoder_hidden_states, - encoder_attention_mask=encoder_attention_mask, - init_cache=init_cache, - deterministic=deterministic, - output_attentions=output_attentions, - output_hidden_states=output_hidden_states, - return_dict=return_dict, - ) - - hidden_states = outputs[0] - if self.config.tie_word_embeddings: - shared_embedding = self.roberta.variables["params"]["embeddings"]["word_embeddings"]["embedding"] - shared_embedding = fjformer.linen.linen.control_quantization(shared_embedding, self.param_dtype) - else: - shared_embedding = None - - # Compute the prediction scores - logits = self.lm_head(hidden_states, shared_embedding=shared_embedding) - - if not return_dict: - return (logits,) + outputs[1:] - - return FlaxCausalLMOutputWithCrossAttentions( - logits=logits, - hidden_states=outputs.hidden_states, - attentions=outputs.attentions, - cross_attentions=outputs.cross_attentions, - ) - - -class FlaxRobertaForCausalLM(FlaxRobertaPreTrainedModel): - module_class = FlaxRobertaForCausalLMModule - - def prepare_inputs_for_generation(self, input_ids, max_length, attention_mask: Optional[jnp.DeviceArray] = None): - batch_size, seq_length = input_ids.shape - - past_key_values = self.init_cache(batch_size, max_length) - extended_attention_mask = jnp.ones((batch_size, max_length), dtype="i4") - if attention_mask is not None: - position_ids = attention_mask.cumsum(axis=-1) - 1 - extended_attention_mask = lax.dynamic_update_slice(extended_attention_mask, attention_mask, (0, 0)) - else: - position_ids = jnp.broadcast_to(jnp.arange(seq_length, dtype="i4")[None, :], (batch_size, seq_length)) - - return { - "past_key_values": past_key_values, - "attention_mask": extended_attention_mask, - "position_ids": position_ids, - } - - def update_inputs_for_generation(self, model_outputs, model_kwargs): - model_kwargs["past_key_values"] = model_outputs.past_key_values - model_kwargs["position_ids"] = model_kwargs["position_ids"][:, -1:] + 1 - return model_kwargs +# Model is modified version from EasyLM +# Supports 8,6,4 BIT and flash attention +import math +from gc import unfreeze +from typing import Optional, Tuple + +import chex +import fjformer +from flax import linen as nn +from flax.core import FrozenDict, freeze +from flax.linen.attention import make_attention_mask, make_causal_mask, combine_masks, dot_product_attention_weights +from flax.linen.partitioning import remat +from flax.traverse_util import flatten_dict, unflatten_dict +import flax.linen +from transformers.modeling_flax_outputs import FlaxBaseModelOutputWithPastAndCrossAttentions, FlaxMaskedLMOutput, \ + FlaxSequenceClassifierOutput, FlaxMultipleChoiceModelOutput, FlaxTokenClassifierOutput, \ + FlaxQuestionAnsweringModelOutput, FlaxCausalLMOutputWithCrossAttentions, \ + FlaxBaseModelOutputWithPoolingAndCrossAttentions + +from .roberta_configuration import RobertaConfig +import jax +from jax.sharding import PartitionSpec +from jax import lax, numpy as jnp +from ..easydel_modelling_utils import EasyDeLFlaxPretrainedModel +from ..attention_module import AttentionModule +from ..flax_modelling_utils import get_gradient_checkpoint_policy, ACT2FN, get_dot_general_by_bits, \ + BaseJAXAttentionModule + +from fjformer.linen import Linear + + +class FlaxRobertaEmbeddings(nn.Module): + config: RobertaConfig + dtype: jnp.dtype = jnp.float32 # the dtype of the computation + param_dtype: jnp.dtype = jnp.float32 + precision: Optional[lax.Precision] = None + + def setup(self): + self.word_embeddings = nn.Embed( + self.config.vocab_size, + self.config.hidden_size, + embedding_init=jax.nn.initializers.normal(stddev=self.config.initializer_range), + dtype=self.dtype, + ) + self.position_embeddings = nn.Embed( + self.config.max_position_embeddings, + self.config.hidden_size, + embedding_init=jax.nn.initializers.normal(stddev=self.config.initializer_range), + dtype=self.dtype, + ) + self.token_type_embeddings = nn.Embed( + self.config.type_vocab_size, + self.config.hidden_size, + embedding_init=jax.nn.initializers.normal(stddev=self.config.initializer_range), + dtype=self.dtype, + ) + self.LayerNorm = nn.LayerNorm(epsilon=self.config.layer_norm_eps, dtype=self.dtype) + self.dropout = flax.linen.Dropout(rate=self.config.hidden_dropout_prob) + + def __call__(self, input_ids, token_type_ids, position_ids, attention_mask, deterministic: bool = True): + inputs_embeds = self.word_embeddings(input_ids.astype("i4")) + position_embeds = self.position_embeddings(position_ids.astype("i4")) + token_type_embeddings = self.token_type_embeddings(token_type_ids.astype("i4")) + + hidden_states = inputs_embeds + token_type_embeddings + position_embeds + + hidden_states = self.LayerNorm(hidden_states) + hidden_states = self.dropout(hidden_states, deterministic=deterministic) + return hidden_states + + +class FlaxRobertaSelfAttention(BaseJAXAttentionModule): + config: RobertaConfig + causal: bool = False + dtype: jnp.dtype = jnp.float32 + param_dtype: jnp.dtype = jnp.float32 + precision: Optional[lax.Precision] = None + + def setup(self): + self.head_dim = self.config.hidden_size // self.config.num_attention_heads + if self.config.hidden_size % self.config.num_attention_heads != 0: + raise ValueError( + "`config.hidden_size`: {self.config.hidden_size} has to be a multiple of `config.num_attention_heads` " + " : {self.config.num_attention_heads}" + ) + self.attention_performer = AttentionModule( + use_sharding_constraint=self.config.use_sharding_constraint, + block_k_major=self.config.block_k_major, + block_b=self.config.block_b, + block_q=self.config.block_q, + block_k=self.config.block_k, + block_q_major_dkv=self.config.block_q_major_dkv, + block_k_major_dkv=self.config.block_k_major_dkv, + block_k_major_dq=self.config.block_k_major_dq, + block_k_dkv=self.config.block_k_dkv, + block_q_dkv=self.config.block_q_dkv, + block_q_dq=self.config.block_q_dq, + block_k_dq=self.config.block_k_dq, + num_attention_heads=self.config.num_attention_heads, + attention_dropout=0.0, + head_dims=self.head_dim, + attention_partition_spec=self.config.attention_partition_spec, + shard_attention_computation=self.config.shard_attention_computation, + precision=self.precision, + force_float32_tpu=True, + attn_mechanism=self.config.attn_mechanism, + dtype=self.dtype, + bias_partition_spec=self.config.bias_partition_spec, + key_partition_spec=self.config.key_partition_spec, + query_partition_spec=self.config.query_partition_spec, + generation_query_partition_spec=self.config.generation_query_partition_spec, + generation_bias_partition_spec=self.config.generation_bias_partition_spec, + generation_attention_partition_spec=self.config.generation_attention_partition_spec, + value_partition_spec=self.config.value_partition_spec, + scan_ring_attention=self.config.scan_ring_attention, + mesh=self.config.jax_mesh(), + sm_scale=1 / math.sqrt(self.head_dim), + axis_name=self.config.attention_axis_name, + backward_pass_impl=self.config.flash_attention_backward_pass_impl + ) + self.query = Linear( + self.config.hidden_size, + dtype=self.dtype, + param_dtype=self.param_dtype, + kernel_init=jax.nn.initializers.normal(self.config.initializer_range), + **get_dot_general_by_bits(bits=self.config.bits, mode=self.config.easy_method) + ) + self.key = Linear( + self.config.hidden_size, + dtype=self.dtype, + param_dtype=self.param_dtype, + kernel_init=jax.nn.initializers.normal(self.config.initializer_range), + **get_dot_general_by_bits(bits=self.config.bits, mode=self.config.easy_method) + ) + self.value = Linear( + self.config.hidden_size, + dtype=self.dtype, + param_dtype=self.param_dtype, + kernel_init=jax.nn.initializers.normal(self.config.initializer_range), + **get_dot_general_by_bits(bits=self.config.bits, mode=self.config.easy_method) + ) + + if self.causal: + self.causal_mask = make_causal_mask( + jnp.ones((1, getattr(self.config, "c_max_position_embeddings", self.config.max_position_embeddings)), + dtype="bool"), dtype="bool" + ) + + def _split_heads(self, hidden_states): + return hidden_states.reshape(hidden_states.shape[:2] + (self.config.num_attention_heads, self.head_dim)) + + def _merge_heads(self, hidden_states): + return hidden_states.reshape(hidden_states.shape[:2] + (self.config.hidden_size,)) + + def __call__( + self, + hidden_states, + attention_mask, + layer_head_mask, + segment_ids: Optional[chex.Array] = None, + key_value_states: Optional[jnp.array] = None, + init_cache: bool = False, + deterministic=True, + output_attentions: bool = False, + ): + is_cross_attention = key_value_states is not None + batch_size = hidden_states.shape[0] + + query_states = self.query(hidden_states) + if is_cross_attention: + key_states = self.key(key_value_states) + value_states = self.value(key_value_states) + else: + key_states = self.key(hidden_states) + value_states = self.value(hidden_states) + + query_states = self._split_heads(query_states) + key_states = self._split_heads(key_states) + value_states = self._split_heads(value_states) + + if self.causal: + query_length, key_length = query_states.shape[1], key_states.shape[1] + if self.has_variable("cache", "cached_key"): + mask_shift = self.variables["cache"]["cache_index"] + max_decoder_length = self.variables["cache"]["cached_key"].shape[1] + causal_mask = lax.dynamic_slice( + self.causal_mask, (0, 0, mask_shift, 0), (1, 1, query_length, max_decoder_length) + ) + else: + causal_mask = self.causal_mask[:, :, :query_length, :key_length] + causal_mask = jnp.broadcast_to(causal_mask, (batch_size,) + causal_mask.shape[1:]) + + if attention_mask is not None and self.causal: + attention_mask = jnp.broadcast_to(jnp.expand_dims(attention_mask, axis=(-3, -2)), causal_mask.shape) + attention_mask = combine_masks(attention_mask, causal_mask) + elif self.causal: + attention_mask = causal_mask + elif attention_mask is not None: + attention_mask = jnp.expand_dims(attention_mask, axis=(-3, -2)) + + if self.causal and (self.has_variable("cache", "cached_key") or init_cache): + key_states, value_states, attention_mask = self._concatenate_to_cache( + key_states, value_states, query_states, attention_mask + ) + + if attention_mask is not None: + attention_bias = lax.select( + attention_mask > 0, + jnp.full(attention_mask.shape, 0.0).astype(self.dtype), + jnp.full(attention_mask.shape, jnp.finfo(self.dtype).min).astype(self.dtype), + ) + else: + attention_bias = None + + dropout_rng = None + if not deterministic and self.config.attention_probs_dropout_prob > 0.0: + dropout_rng = self.make_rng("dropout") + if layer_head_mask is None: + out = self.attention_performer.__call__( + query_states=query_states, + key_states=key_states, + value_states=value_states, + dropout_rng=dropout_rng, + deterministic=deterministic, + causal=True, + bias=attention_bias, + attention_mask=attention_mask, + uses_cache=False, + query_sequence_length=query_states.shape[1], + key_value_sequence_length=key_states.shape[1], + segment_ids=segment_ids, + causal_mask=causal_mask + ) + attn_weights = out.attention_weights + attn_output = out.attention_outputs + else: + + attn_weights = dot_product_attention_weights( + query_states, + key_states, + bias=attention_bias, + dropout_rng=dropout_rng, + dropout_rate=self.config.attention_probs_dropout_prob, + broadcast_dropout=True, + deterministic=deterministic, + dtype=self.dtype, + precision=None, + ) + + attn_weights = jnp.einsum("...hqk,h->...hqk", attn_weights, layer_head_mask) + attn_output = jnp.einsum("...hqk,...khd->...qhd", attn_weights, value_states) + + attn_output = attn_output.reshape(attn_output.shape[:2] + (-1,)) + + outputs = (attn_output, attn_weights) if output_attentions else (attn_output,) + return outputs + + +class FlaxRobertaSelfOutput(nn.Module): + config: RobertaConfig + dtype: jnp.dtype = jnp.float32 + param_dtype: jnp.dtype = jnp.float32 + precision: Optional[lax.Precision] = None + + def setup(self): + self.dense = Linear( + self.config.hidden_size, + kernel_init=jax.nn.initializers.normal(self.config.initializer_range), + dtype=self.dtype, + param_dtype=self.param_dtype, + precision=self.precision, + **get_dot_general_by_bits(bits=self.config.bits, mode=self.config.easy_method) + ) + self.LayerNorm = nn.LayerNorm(epsilon=self.config.layer_norm_eps, dtype=self.dtype) + self.dropout = flax.linen.Dropout(rate=self.config.hidden_dropout_prob) + + def __call__(self, hidden_states, input_tensor, deterministic: bool = True): + hidden_states = self.dense(hidden_states) + hidden_states = self.dropout(hidden_states, deterministic=deterministic) + hidden_states = self.LayerNorm(hidden_states + input_tensor) + return hidden_states + + +class FlaxRobertaAttention(nn.Module): + config: RobertaConfig + causal: bool = False + dtype: jnp.dtype = jnp.float32 + param_dtype: jnp.dtype = jnp.float32 + precision: Optional[lax.Precision] = None + + def setup(self): + self.self = FlaxRobertaSelfAttention( + self.config, + causal=self.causal, + dtype=self.dtype, + param_dtype=self.param_dtype, + precision=self.precision + ) + self.output = FlaxRobertaSelfOutput(self.config, dtype=self.dtype) + + def __call__( + self, + hidden_states, + attention_mask, + layer_head_mask, + key_value_states=None, + init_cache=False, + deterministic=True, + output_attentions: bool = False, + ): + attn_outputs = self.self( + hidden_states, + attention_mask, + layer_head_mask=layer_head_mask, + key_value_states=key_value_states, + init_cache=init_cache, + deterministic=deterministic, + output_attentions=output_attentions, + ) + attn_output = attn_outputs[0] + hidden_states = self.output(attn_output, hidden_states, deterministic=deterministic) + + outputs = (hidden_states,) + + if output_attentions: + outputs += (attn_outputs[1],) + + return outputs + + +class FlaxRobertaIntermediate(nn.Module): + config: RobertaConfig + dtype: jnp.dtype = jnp.float32 + param_dtype: jnp.dtype = jnp.float32 + precision: Optional[lax.Precision] = None + + def setup(self): + self.dense = Linear( + self.config.intermediate_size, + kernel_init=jax.nn.initializers.normal(self.config.initializer_range), + dtype=self.dtype, + param_dtype=self.param_dtype, + precision=self.precision, + **get_dot_general_by_bits(bits=self.config.bits, mode=self.config.easy_method) + ) + self.activation = ACT2FN[self.config.hidden_act] + + def __call__(self, hidden_states): + hidden_states = self.dense(hidden_states) + hidden_states = self.activation(hidden_states) + return hidden_states + + +class FlaxRobertaOutput(nn.Module): + config: RobertaConfig + dtype: jnp.dtype = jnp.float32 + param_dtype: jnp.dtype = jnp.float32 + precision: Optional[lax.Precision] = None + + def setup(self): + self.dense = Linear( + self.config.hidden_size, + kernel_init=jax.nn.initializers.normal(self.config.initializer_range), + dtype=self.dtype, + precision=self.precision, + param_dtype=self.param_dtype, + **get_dot_general_by_bits(bits=self.config.bits, mode=self.config.easy_method) + ) + self.dropout = flax.linen.Dropout(rate=self.config.hidden_dropout_prob) + self.LayerNorm = nn.LayerNorm(epsilon=self.config.layer_norm_eps, dtype=self.dtype) + + def __call__(self, hidden_states, attention_output, deterministic: bool = True): + hidden_states = self.dense(hidden_states) + hidden_states = self.dropout(hidden_states, deterministic=deterministic) + hidden_states = self.LayerNorm(hidden_states + attention_output) + return hidden_states + + +class FlaxRobertaLayer(nn.Module): + config: RobertaConfig + dtype: jnp.dtype = jnp.float32 + param_dtype: jnp.dtype = jnp.float32 + precision: Optional[lax.Precision] = None + + def setup(self): + self.attention = FlaxRobertaAttention( + self.config, + causal=self.config.is_decoder, + dtype=self.dtype, + param_dtype=self.param_dtype, + precision=self.precision + ) + self.intermediate = FlaxRobertaIntermediate(self.config, dtype=self.dtype) + self.output = FlaxRobertaOutput(self.config, dtype=self.dtype) + if self.config.add_cross_attention: + self.crossattention = FlaxRobertaAttention( + self.config, + causal=True, + dtype=self.dtype, + param_dtype=self.param_dtype, + precision=self.precision + ) + + def __call__( + self, + hidden_states, + attention_mask, + layer_head_mask, + encoder_hidden_states: Optional[jnp.ndarray] = None, + encoder_attention_mask: Optional[jnp.ndarray] = None, + init_cache: bool = False, + deterministic: bool = True, + output_attentions: bool = False, + ): + # Self Attention + attention_outputs = self.attention( + hidden_states, + attention_mask, + layer_head_mask=layer_head_mask, + init_cache=init_cache, + deterministic=deterministic, + output_attentions=output_attentions, + ) + attention_output = attention_outputs[0] + + # Cross-Attention Block + if encoder_hidden_states is not None: + cross_attention_outputs = self.crossattention( + attention_output, + attention_mask=encoder_attention_mask, + layer_head_mask=layer_head_mask, + key_value_states=encoder_hidden_states, + deterministic=deterministic, + output_attentions=output_attentions, + ) + attention_output = cross_attention_outputs[0] + + hidden_states = self.intermediate(attention_output) + hidden_states = self.output(hidden_states, attention_output, deterministic=deterministic) + + outputs = (hidden_states,) + + if output_attentions: + outputs += (attention_outputs[1],) + if encoder_hidden_states is not None: + outputs += (cross_attention_outputs[1],) + return outputs + + +class FlaxRobertaLayerCollection(nn.Module): + config: RobertaConfig + dtype: jnp.dtype = jnp.float32 # the dtype of the computation + param_dtype: jnp.dtype = jnp.float32 + precision: Optional[lax.Precision] = None + + def setup(self): + block = FlaxRobertaLayer + if self.config.gradient_checkpointing != "": + block = remat( + block, + static_argnums=(5, 6, 7), + policy=get_gradient_checkpoint_policy(self.config.gradient_checkpointing) + ) + + self.layers = [ + block( + self.config, + name=str(i), + dtype=self.dtype, + param_dtype=self.param_dtype, + precision=self.precision + ) + for i in range(self.config.num_hidden_layers) + ] + + def __call__( + self, + hidden_states, + attention_mask, + head_mask, + encoder_hidden_states: Optional[jnp.ndarray] = None, + encoder_attention_mask: Optional[jnp.ndarray] = None, + init_cache: bool = False, + deterministic: bool = True, + output_attentions: bool = False, + output_hidden_states: bool = False, + return_dict: bool = True, + ): + all_attentions = () if output_attentions else None + all_hidden_states = () if output_hidden_states else None + all_cross_attentions = () if (output_attentions and encoder_hidden_states is not None) else None + + # Check if head_mask has a correct number of layers specified if desired + if head_mask is not None: + if head_mask.shape[0] != (len(self.layers)): + raise ValueError( + f"The head_mask should be specified for {len(self.layers)} layers, but it is for " + f" {head_mask.shape[0]}." + ) + + for i, layer in enumerate(self.layers): + if output_hidden_states: + all_hidden_states += (hidden_states,) + + layer_outputs = layer( + hidden_states, + attention_mask, + head_mask[i] if head_mask is not None else None, + encoder_hidden_states, + encoder_attention_mask, + init_cache, + deterministic, + output_attentions, + ) + + hidden_states = layer_outputs[0] + + if output_attentions: + all_attentions += (layer_outputs[1],) + + if encoder_hidden_states is not None: + all_cross_attentions += (layer_outputs[2],) + + if output_hidden_states: + all_hidden_states += (hidden_states,) + + outputs = (hidden_states, all_hidden_states, all_attentions, all_cross_attentions) + + if not return_dict: + return tuple(v for v in outputs if v is not None) + + return FlaxBaseModelOutputWithPastAndCrossAttentions( + last_hidden_state=hidden_states, + hidden_states=all_hidden_states, + attentions=all_attentions, + cross_attentions=all_cross_attentions, + ) + + +class FlaxRobertaEncoder(nn.Module): + config: RobertaConfig + dtype: jnp.dtype = jnp.float32 + param_dtype: jnp.dtype = jnp.float32 + precision: Optional[lax.Precision] = None + + def setup(self): + self.layer = FlaxRobertaLayerCollection( + self.config, + dtype=self.dtype, + param_dtype=self.param_dtype, + precision=self.precision, + ) + + def __call__( + self, + hidden_states, + attention_mask, + head_mask, + encoder_hidden_states: Optional[jnp.ndarray] = None, + encoder_attention_mask: Optional[jnp.ndarray] = None, + init_cache: bool = False, + deterministic: bool = True, + output_attentions: bool = False, + output_hidden_states: bool = False, + return_dict: bool = True, + ): + return self.layer( + hidden_states, + attention_mask, + head_mask=head_mask, + encoder_hidden_states=encoder_hidden_states, + encoder_attention_mask=encoder_attention_mask, + init_cache=init_cache, + deterministic=deterministic, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + ) + + +class FlaxRobertaPooler(nn.Module): + config: RobertaConfig + dtype: jnp.dtype = jnp.float32 + param_dtype: jnp.dtype = jnp.float32 + precision: Optional[lax.Precision] = None + + def setup(self): + self.dense = Linear( + self.config.hidden_size, + kernel_init=jax.nn.initializers.normal(self.config.initializer_range), + dtype=self.dtype, + param_dtype=self.param_dtype, + precision=self.precision, + **get_dot_general_by_bits(bits=self.config.bits, mode=self.config.easy_method) + ) + + def __call__(self, hidden_states): + cls_hidden_state = hidden_states[:, 0] + cls_hidden_state = self.dense(cls_hidden_state) + return nn.tanh(cls_hidden_state) + + +class FlaxRobertaLMHead(nn.Module): + config: RobertaConfig + dtype: jnp.dtype = jnp.float32 + param_dtype: jnp.dtype = jnp.float32 + precision: Optional[lax.Precision] = None + + def setup(self): + self.dense = Linear( + self.config.hidden_size, + dtype=self.dtype, + param_dtype=self.param_dtype, + precision=self.precision, + kernel_init=jax.nn.initializers.normal(self.config.initializer_range), + **get_dot_general_by_bits(bits=self.config.bits, mode=self.config.easy_method) + ) + self.layer_norm = nn.LayerNorm(epsilon=self.config.layer_norm_eps, dtype=self.dtype) + self.decoder = Linear( + self.config.vocab_size, + dtype=self.dtype, + use_bias=False, + param_dtype=self.param_dtype, + precision=self.precision, + kernel_init=jax.nn.initializers.normal(self.config.initializer_range), + **get_dot_general_by_bits(bits=self.config.bits, mode=self.config.easy_method) + ) + self.bias = self.param( + "bias", + jax.nn.initializers.zeros, + ( + self.config.vocab_size, + ) + ) + + def __call__(self, hidden_states, shared_embedding=None): + hidden_states = self.dense(hidden_states) + hidden_states = ACT2FN["gelu"](hidden_states) + hidden_states = self.layer_norm(hidden_states) + + if shared_embedding is not None: + hidden_states = self.decoder.apply({"params": {"kernel": shared_embedding.T}}, hidden_states) + else: + hidden_states = self.decoder(hidden_states) + + bias = fjformer.linen.linen.control_quantization(self.bias, self.dtype) + hidden_states += bias + return hidden_states + + +class FlaxRobertaClassificationHead(nn.Module): + config: RobertaConfig + dtype: jnp.dtype = jnp.float32 + param_dtype: jnp.dtype = jnp.float32 + precision: Optional[lax.Precision] = None + + def setup(self): + self.dense = Linear( + self.config.hidden_size, + dtype=self.dtype, + kernel_init=jax.nn.initializers.normal(self.config.initializer_range), + param_dtype=self.param_dtype, + precision=self.precision, + **get_dot_general_by_bits(bits=self.config.bits, mode=self.config.easy_method) + ) + classifier_dropout = ( + self.config.classifier_dropout + if self.config.classifier_dropout is not None + else self.config.hidden_dropout_prob + ) + self.dropout = flax.linen.Dropout(rate=classifier_dropout) + self.out_proj = Linear( + self.config.num_labels, + dtype=self.dtype, + param_dtype=self.param_dtype, + precision=self.precision, + kernel_init=jax.nn.initializers.normal(self.config.initializer_range), + **get_dot_general_by_bits(bits=self.config.bits, mode=self.config.easy_method) + ) + + def __call__(self, hidden_states, deterministic=True): + hidden_states = hidden_states[:, 0, :] + hidden_states = self.dropout(hidden_states, deterministic=deterministic) + hidden_states = self.dense(hidden_states) + hidden_states = nn.tanh(hidden_states) + hidden_states = self.dropout(hidden_states, deterministic=deterministic) + hidden_states = self.out_proj(hidden_states) + return hidden_states + + +class FlaxRobertaPreTrainedModel(EasyDeLFlaxPretrainedModel): + config_class = RobertaConfig + base_model_prefix = "roberta" + + module_class: nn.Module = None + + def __init__( + self, + config: RobertaConfig, + input_shape: Tuple = (1, 1), + seed: int = 0, + dtype: jnp.dtype = jnp.float32, + param_dtype: jnp.dtype = jnp.float32, + precision: Optional[lax.Precision] = None, + _do_init: bool = True, + **kwargs, + ): + module = self.module_class( + config=config, + dtype=dtype, + param_dtype=param_dtype, + precision=precision, + **kwargs) + super().__init__(config, module, input_shape=input_shape, seed=seed, dtype=dtype, _do_init=_do_init) + + def init_weights(self, rng: jax.random.PRNGKey, input_shape: Tuple, params: FrozenDict = None) -> FrozenDict: + input_ids = jnp.zeros(input_shape, dtype="i4") + token_type_ids = jnp.ones_like(input_ids) + mask = (input_ids != self.config.pad_token_id).astype("i4") + + if mask.ndim > 2: + mask = mask.reshape((-1, mask.shape[-1])) + incremental_indices = jnp.cumsum(mask, axis=1).astype("i4") * mask + incremental_indices = incremental_indices.reshape(input_ids.shape) + else: + incremental_indices = jnp.cumsum(mask, axis=1).astype("i4") * mask + + position_ids = incremental_indices.astype("i4") + self.config.pad_token_id + + attention_mask = jnp.ones_like(input_ids) + head_mask = jnp.ones((self.config.num_hidden_layers, self.config.num_attention_heads)) + + params_rng, dropout_rng = jax.random.split(rng) + rngs = {"params": params_rng, "dropout": dropout_rng} + + if self.config.add_cross_attention: + encoder_hidden_states = jnp.zeros(input_shape + (self.config.hidden_size,)) + encoder_attention_mask = attention_mask + module_init_outputs = self.module.init( + rngs, + input_ids, + attention_mask, + token_type_ids, + position_ids, + head_mask, + encoder_hidden_states, + encoder_attention_mask, + return_dict=False, + ) + else: + module_init_outputs = self.module.init( + rngs, input_ids, attention_mask, token_type_ids, position_ids, head_mask, return_dict=False + ) + + random_params = module_init_outputs["params"] + + if params is not None: + random_params = flatten_dict(unfreeze(random_params)) + params = flatten_dict(unfreeze(params)) + for missing_key in self._missing_keys: + params[missing_key] = random_params[missing_key] + self._missing_keys = set() + return freeze(unflatten_dict(params)) + else: + return random_params + + def init_cache(self, batch_size, max_length): + + # init input variables to retrieve cache + input_ids = jnp.ones((batch_size, max_length), dtype="i4") + attention_mask = jnp.ones_like(input_ids, dtype="i4") + position_ids = jnp.broadcast_to(jnp.arange(jnp.atleast_2d(input_ids).shape[-1]), input_ids.shape) + + init_variables = self.module.init( + jax.random.PRNGKey(0), input_ids, attention_mask, position_ids, return_dict=False, init_cache=True + ) + return unfreeze(init_variables["cache"]) + + def __call__( + self, + input_ids, + attention_mask=None, + token_type_ids=None, + position_ids=None, + head_mask=None, + encoder_hidden_states=None, + encoder_attention_mask=None, + params: dict = None, + dropout_rng: jax.random.PRNGKey = None, + train: bool = False, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + past_key_values: dict = None, + ): + output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions + output_hidden_states = ( + output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states + ) + return_dict = return_dict if return_dict is not None else self.config.return_dict + + # init input tensors if not passed + if token_type_ids is None: + token_type_ids = jnp.zeros_like(input_ids) + + if position_ids is None: + mask = (input_ids != self.config.pad_token_id).astype("i4") + + if mask.ndim > 2: + mask = mask.reshape((-1, mask.shape[-1])) + incremental_indices = jnp.cumsum(mask, axis=1).astype("i4") * mask + incremental_indices = incremental_indices.reshape(input_ids.shape) + else: + incremental_indices = jnp.cumsum(mask, axis=1).astype("i4") * mask + + position_ids = incremental_indices.astype("i4") + self.config.pad_token_id + + if attention_mask is None: + attention_mask = jnp.ones_like(input_ids) + + if head_mask is None: + head_mask = jnp.ones((self.config.num_hidden_layers, self.config.num_attention_heads)) + + # Handle any PRNG if needed + rngs = {} + if dropout_rng is not None: + rngs["dropout"] = dropout_rng + + inputs = {"params": params or self.params} + + if self.config.add_cross_attention: + if past_key_values: + inputs["cache"] = past_key_values + mutable = ["cache"] + else: + mutable = False + + outputs = self.module.apply( + inputs, + jnp.array(input_ids, dtype="i4"), + jnp.array(attention_mask, dtype="i4"), + token_type_ids=jnp.array(token_type_ids, dtype="i4"), + position_ids=jnp.array(position_ids, dtype="i4"), + head_mask=jnp.array(head_mask, dtype="i4"), + encoder_hidden_states=encoder_hidden_states, + encoder_attention_mask=encoder_attention_mask, + deterministic=not train, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + rngs=rngs, + mutable=mutable, + ) + + # add updated cache to model output + if past_key_values is not None and return_dict: + outputs, past_key_values = outputs + outputs["past_key_values"] = unfreeze(past_key_values["cache"]) + return outputs + elif past_key_values is not None and not return_dict: + outputs, past_key_values = outputs + outputs = outputs[:1] + (unfreeze(past_key_values["cache"]),) + outputs[1:] + + else: + outputs = self.module.apply( + inputs, + jnp.array(input_ids, dtype="i4"), + jnp.array(attention_mask, dtype="i4"), + token_type_ids=jnp.array(token_type_ids, dtype="i4"), + position_ids=jnp.array(position_ids, dtype="i4"), + head_mask=jnp.array(head_mask, dtype="i4"), + deterministic=not train, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + rngs=rngs, + ) + + return outputs + + +# Copied from transformers.models.bert.modeling_flax_bert.FlaxBertModule with Bert->Roberta +class FlaxRobertaModule(nn.Module): + config: RobertaConfig + dtype: jnp.dtype = jnp.float32 # the dtype of the computation + param_dtype: jnp.dtype = jnp.float32 + precision: Optional[lax.Precision] = None + add_pooling_layer: bool = True + + def setup(self): + self.embeddings = FlaxRobertaEmbeddings( + config=self.config, + dtype=self.dtype, + param_dtype=self.param_dtype, + precision=self.precision + ) + self.encoder = FlaxRobertaEncoder( + config=self.config, + dtype=self.dtype, + param_dtype=self.param_dtype, + precision=self.precision + ) + self.pooler = FlaxRobertaPooler( + config=self.config, + dtype=self.dtype, + param_dtype=self.param_dtype, + precision=self.precision + ) + + def __call__( + self, + input_ids, + attention_mask, + token_type_ids: Optional[jnp.ndarray] = None, + position_ids: Optional[jnp.ndarray] = None, + head_mask: Optional[jnp.ndarray] = None, + encoder_hidden_states: Optional[jnp.ndarray] = None, + encoder_attention_mask: Optional[jnp.ndarray] = None, + init_cache: bool = False, + deterministic: bool = True, + output_attentions: bool = False, + output_hidden_states: bool = False, + return_dict: bool = True, + ): + # make sure `token_type_ids` is correctly initialized when not passed + if token_type_ids is None: + token_type_ids = jnp.zeros_like(input_ids) + + # make sure `position_ids` is correctly initialized when not passed + if position_ids is None: + position_ids = jnp.broadcast_to(jnp.arange(jnp.atleast_2d(input_ids).shape[-1]), input_ids.shape) + + hidden_states = self.embeddings( + input_ids, token_type_ids, position_ids, attention_mask, deterministic=deterministic + ) + outputs = self.encoder( + hidden_states, + attention_mask, + head_mask=head_mask, + deterministic=deterministic, + encoder_hidden_states=encoder_hidden_states, + encoder_attention_mask=encoder_attention_mask, + init_cache=init_cache, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + ) + hidden_states = outputs[0] + pooled = self.pooler(hidden_states) if self.add_pooling_layer else None + + if not return_dict: + # if pooled is None, don't return it + if pooled is None: + return (hidden_states,) + outputs[1:] + return (hidden_states, pooled) + outputs[1:] + + return FlaxBaseModelOutputWithPoolingAndCrossAttentions( + last_hidden_state=hidden_states, + pooler_output=pooled, + hidden_states=outputs.hidden_states, + attentions=outputs.attentions, + cross_attentions=outputs.cross_attentions, + ) + + +class FlaxRobertaModel(FlaxRobertaPreTrainedModel): + module_class = FlaxRobertaModule + + +class FlaxRobertaForMaskedLMModule(nn.Module): + config: RobertaConfig + dtype: jnp.dtype = jnp.float32 # the dtype of the computation + param_dtype: jnp.dtype = jnp.float32 + precision: Optional[lax.Precision] = None + + def setup(self): + self.roberta = FlaxRobertaModule( + config=self.config, + add_pooling_layer=False, + dtype=self.dtype, + param_dtype=self.param_dtype, + precision=self.precision + ) + self.lm_head = FlaxRobertaLMHead( + config=self.config, + dtype=self.dtype, + param_dtype=self.param_dtype, + precision=self.precision + ) + + def __call__( + self, + input_ids, + attention_mask, + token_type_ids, + position_ids, + head_mask, + deterministic: bool = True, + output_attentions: bool = False, + output_hidden_states: bool = False, + return_dict: bool = True, + ): + # Model + outputs = self.roberta( + input_ids, + attention_mask, + token_type_ids, + position_ids, + head_mask, + deterministic=deterministic, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + ) + + hidden_states = outputs[0] + if self.config.tie_word_embeddings: + shared_embedding = self.roberta.variables["params"]["embeddings"]["word_embeddings"]["embedding"] + shared_embedding = fjformer.linen.linen.control_quantization(shared_embedding, self.param_dtype) + else: + shared_embedding = None + + # Compute the prediction scores + logits = self.lm_head(hidden_states, shared_embedding=shared_embedding) + + if not return_dict: + return (logits,) + outputs[1:] + + return FlaxMaskedLMOutput( + logits=logits, + hidden_states=outputs.hidden_states, + attentions=outputs.attentions, + ) + + +class FlaxRobertaForSequenceClassificationModule(nn.Module): + config: RobertaConfig + dtype: jnp.dtype = jnp.float32 # the dtype of the computation + param_dtype: jnp.dtype = jnp.float32 + precision: Optional[lax.Precision] = None + + def setup(self): + self.roberta = FlaxRobertaModule( + config=self.config, + dtype=self.dtype, + add_pooling_layer=False, + param_dtype=self.param_dtype, + precision=self.precision + ) + self.classifier = FlaxRobertaClassificationHead( + config=self.config, + dtype=self.dtype, + param_dtype=self.param_dtype, + precision=self.precision + ) + + def __call__( + self, + input_ids, + attention_mask, + token_type_ids, + position_ids, + head_mask, + deterministic: bool = True, + output_attentions: bool = False, + output_hidden_states: bool = False, + return_dict: bool = True, + ): + # Model + outputs = self.roberta( + input_ids, + attention_mask, + token_type_ids, + position_ids, + head_mask, + deterministic=deterministic, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + ) + + sequence_output = outputs[0] + logits = self.classifier(sequence_output, deterministic=deterministic) + + if not return_dict: + return (logits,) + outputs[1:] + + return FlaxSequenceClassifierOutput( + logits=logits, + hidden_states=outputs.hidden_states, + attentions=outputs.attentions, + ) + + +class FlaxRobertaForSequenceClassification(FlaxRobertaPreTrainedModel): + module_class = FlaxRobertaForSequenceClassificationModule + + +class FlaxRobertaForMultipleChoiceModule(nn.Module): + config: RobertaConfig + dtype: jnp.dtype = jnp.float32 # the dtype of the computation + param_dtype: jnp.dtype = jnp.float32 + precision: Optional[lax.Precision] = None + + def setup(self): + self.roberta = FlaxRobertaModule( + config=self.config, + dtype=self.dtype, + param_dtype=self.param_dtype, + precision=self.precision + ) + self.dropout = flax.linen.Dropout(rate=self.config.hidden_dropout_prob) + self.classifier = Linear( + 1, + dtype=self.dtype, + param_dtype=self.param_dtype, + precision=self.precision, + **get_dot_general_by_bits(bits=self.config.bits, mode=self.config.easy_method) + ) + + def __call__( + self, + input_ids, + attention_mask, + token_type_ids, + position_ids, + head_mask, + deterministic: bool = True, + output_attentions: bool = False, + output_hidden_states: bool = False, + return_dict: bool = True, + ): + num_choices = input_ids.shape[1] + input_ids = input_ids.reshape(-1, input_ids.shape[-1]) if input_ids is not None else None + attention_mask = attention_mask.reshape(-1, attention_mask.shape[-1]) if attention_mask is not None else None + token_type_ids = token_type_ids.reshape(-1, token_type_ids.shape[-1]) if token_type_ids is not None else None + position_ids = position_ids.reshape(-1, position_ids.shape[-1]) if position_ids is not None else None + + # Model + outputs = self.roberta( + input_ids, + attention_mask, + token_type_ids, + position_ids, + head_mask, + deterministic=deterministic, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + ) + + pooled_output = outputs[1] + pooled_output = self.dropout(pooled_output, deterministic=deterministic) + logits = self.classifier(pooled_output) + + reshaped_logits = logits.reshape(-1, num_choices) + + if not return_dict: + return (reshaped_logits,) + outputs[2:] + + return FlaxMultipleChoiceModelOutput( + logits=reshaped_logits, + hidden_states=outputs.hidden_states, + attentions=outputs.attentions, + ) + + +class FlaxRobertaForMultipleChoice(FlaxRobertaPreTrainedModel): + module_class = FlaxRobertaForMultipleChoiceModule + + +class FlaxRobertaForTokenClassificationModule(nn.Module): + config: RobertaConfig + dtype: jnp.dtype = jnp.float32 # the dtype of the computation + param_dtype: jnp.dtype = jnp.float32 + precision: Optional[lax.Precision] = None + + def setup(self): + self.roberta = FlaxRobertaModule( + config=self.config, + dtype=self.dtype, + add_pooling_layer=False, + param_dtype=self.param_dtype, + precision=self.precision + ) + classifier_dropout = ( + self.config.classifier_dropout + if self.config.classifier_dropout is not None + else self.config.hidden_dropout_prob + ) + self.dropout = flax.linen.Dropout(rate=classifier_dropout) + self.classifier = Linear( + self.config.num_labels, + dtype=self.dtype, + param_dtype=self.param_dtype, + precision=self.precision, + **get_dot_general_by_bits(bits=self.config.bits, mode=self.config.easy_method) + + ) + + def __call__( + self, + input_ids, + attention_mask, + token_type_ids, + position_ids, + head_mask, + deterministic: bool = True, + output_attentions: bool = False, + output_hidden_states: bool = False, + return_dict: bool = True, + ): + # Model + outputs = self.roberta( + input_ids, + attention_mask, + token_type_ids, + position_ids, + head_mask, + deterministic=deterministic, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + ) + + hidden_states = outputs[0] + hidden_states = self.dropout(hidden_states, deterministic=deterministic) + logits = self.classifier(hidden_states) + + if not return_dict: + return (logits,) + outputs[1:] + + return FlaxTokenClassifierOutput( + logits=logits, + hidden_states=outputs.hidden_states, + attentions=outputs.attentions, + ) + + +class FlaxRobertaForTokenClassification(FlaxRobertaPreTrainedModel): + module_class = FlaxRobertaForTokenClassificationModule + + +class FlaxRobertaForQuestionAnsweringModule(nn.Module): + config: RobertaConfig + dtype: jnp.dtype = jnp.float32 # the dtype of the computation + param_dtype: jnp.dtype = jnp.float32 + precision: Optional[lax.Precision] = None + + def setup(self): + self.roberta = FlaxRobertaModule( + config=self.config, + dtype=self.dtype, + add_pooling_layer=False, + param_dtype=self.param_dtype, + precision=self.precision + ) + self.qa_outputs = Linear( + self.config.num_labels, + dtype=self.dtype, + param_dtype=self.param_dtype, + precision=self.precision, + **get_dot_general_by_bits(bits=self.config.bits, mode=self.config.easy_method) + ) + + def __call__( + self, + input_ids, + attention_mask, + token_type_ids, + position_ids, + head_mask, + deterministic: bool = True, + output_attentions: bool = False, + output_hidden_states: bool = False, + return_dict: bool = True, + ): + # Model + outputs = self.roberta( + input_ids, + attention_mask, + token_type_ids, + position_ids, + head_mask, + deterministic=deterministic, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + ) + + hidden_states = outputs[0] + + logits = self.qa_outputs(hidden_states) + start_logits, end_logits = logits.split(self.config.num_labels, axis=-1) + start_logits = start_logits.squeeze(-1) + end_logits = end_logits.squeeze(-1) + + if not return_dict: + return (start_logits, end_logits) + outputs[1:] + + return FlaxQuestionAnsweringModelOutput( + start_logits=start_logits, + end_logits=end_logits, + hidden_states=outputs.hidden_states, + attentions=outputs.attentions, + ) + + +class FlaxRobertaForQuestionAnswering(FlaxRobertaPreTrainedModel): + module_class = FlaxRobertaForQuestionAnsweringModule + + +class FlaxRobertaForCausalLMModule(nn.Module): + config: RobertaConfig + dtype: jnp.dtype = jnp.float32 # the dtype of the computation + param_dtype: jnp.dtype = jnp.float32 + precision: Optional[lax.Precision] = None + + def setup(self): + self.roberta = FlaxRobertaModule( + config=self.config, + add_pooling_layer=False, + dtype=self.dtype, + param_dtype=self.param_dtype, + precision=self.precision + ) + self.lm_head = FlaxRobertaLMHead( + config=self.config, + dtype=self.dtype, + param_dtype=self.param_dtype, + precision=self.precision + ) + + def __call__( + self, + input_ids, + attention_mask, + position_ids, + token_type_ids: Optional[jnp.ndarray] = None, + head_mask: Optional[jnp.ndarray] = None, + encoder_hidden_states: Optional[jnp.ndarray] = None, + encoder_attention_mask: Optional[jnp.ndarray] = None, + init_cache: bool = False, + deterministic: bool = True, + output_attentions: bool = False, + output_hidden_states: bool = False, + return_dict: bool = True, + ): + # Model + outputs = self.roberta( + input_ids, + attention_mask, + token_type_ids, + position_ids, + head_mask, + encoder_hidden_states=encoder_hidden_states, + encoder_attention_mask=encoder_attention_mask, + init_cache=init_cache, + deterministic=deterministic, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + ) + + hidden_states = outputs[0] + if self.config.tie_word_embeddings: + shared_embedding = self.roberta.variables["params"]["embeddings"]["word_embeddings"]["embedding"] + shared_embedding = fjformer.linen.linen.control_quantization(shared_embedding, self.param_dtype) + else: + shared_embedding = None + + # Compute the prediction scores + logits = self.lm_head(hidden_states, shared_embedding=shared_embedding) + + if not return_dict: + return (logits,) + outputs[1:] + + return FlaxCausalLMOutputWithCrossAttentions( + logits=logits, + hidden_states=outputs.hidden_states, + attentions=outputs.attentions, + cross_attentions=outputs.cross_attentions, + ) + + +class FlaxRobertaForCausalLM(FlaxRobertaPreTrainedModel): + module_class = FlaxRobertaForCausalLMModule + + def prepare_inputs_for_generation(self, input_ids, max_length, attention_mask: Optional[jnp.DeviceArray] = None): + batch_size, seq_length = input_ids.shape + + past_key_values = self.init_cache(batch_size, max_length) + extended_attention_mask = jnp.ones((batch_size, max_length), dtype="i4") + if attention_mask is not None: + position_ids = attention_mask.cumsum(axis=-1) - 1 + extended_attention_mask = lax.dynamic_update_slice(extended_attention_mask, attention_mask, (0, 0)) + else: + position_ids = jnp.broadcast_to(jnp.arange(seq_length, dtype="i4")[None, :], (batch_size, seq_length)) + + return { + "past_key_values": past_key_values, + "attention_mask": extended_attention_mask, + "position_ids": position_ids, + } + + def update_inputs_for_generation(self, model_outputs, model_kwargs): + model_kwargs["past_key_values"] = model_outputs.past_key_values + model_kwargs["position_ids"] = model_kwargs["position_ids"][:, -1:] + 1 + return model_kwargs diff --git a/src/python/easydel/modules/roberta/roberta_configuration.py b/src/python/easydel/modules/roberta/roberta_configuration.py index 007d9da3f..eb47ec5d1 100644 --- a/src/python/easydel/modules/roberta/roberta_configuration.py +++ b/src/python/easydel/modules/roberta/roberta_configuration.py @@ -1,93 +1,93 @@ -from jax.sharding import PartitionSpec -from ..easydel_modelling_utils import EasyDeLPretrainedConfig - - -class RobertaConfig(EasyDeLPretrainedConfig): - model_type: str = "roberta" - - def __init__( - self, - vocab_size=50265, - hidden_size=768, - num_hidden_layers=12, - num_attention_heads=12, - intermediate_size=3072, - hidden_act="gelu", - hidden_dropout_prob=0.1, - attention_probs_dropout_prob=0.1, - max_position_embeddings=514, - type_vocab_size=1, - initializer_range=0.02, - layer_norm_eps=1e-5, - pad_token_id=1, - bos_token_id=0, - eos_token_id=2, - position_embedding_type="absolute", - use_cache=True, - classifier_dropout=None, - gradient_checkpointing="nothing_saveable", - **kwargs - ): - self.vocab_size = vocab_size - self.hidden_size = hidden_size - self.num_hidden_layers = num_hidden_layers - self.num_attention_heads = num_attention_heads - self.hidden_act = hidden_act - self.intermediate_size = intermediate_size - self.hidden_dropout_prob = hidden_dropout_prob - self.attention_probs_dropout_prob = attention_probs_dropout_prob - self.max_position_embeddings = max_position_embeddings - self.type_vocab_size = type_vocab_size - self.initializer_range = initializer_range - self.layer_norm_eps = layer_norm_eps - self.position_embedding_type = position_embedding_type - self.use_cache = use_cache - self.classifier_dropout = classifier_dropout - self.gradient_checkpointing = gradient_checkpointing - super().__init__(pad_token_id=pad_token_id, bos_token_id=bos_token_id, eos_token_id=eos_token_id, **kwargs) - - def get_partition_rules(self, fully_sharded_data_parallel: bool = True): - return ( - ("embeddings/(position_embeddings|token_type_embeddings)/embedding", PartitionSpec()), - ("embeddings/word_embeddings/embedding", PartitionSpec()), - ("attention/self/(key|query|value)/kernel", PartitionSpec("fsdp", "tp")), - ("attention/self/(key|query|value)/bias", PartitionSpec()), - ("attention/output/dense/kernel", PartitionSpec("tp", "fsdp")), - ("attention/output/dense/bias", PartitionSpec()), - ("(LayerNorm|layer_norm)/(bias|scale)", PartitionSpec()), - ("intermediate/dense/kernel", PartitionSpec("fsdp", "tp")), - ("intermediate/dense/bias", PartitionSpec("tp")), - ("output/dense/kernel", PartitionSpec("tp", "fsdp")), - ("output/dense/bias", PartitionSpec()), - ("lm_head/dense/kernel", PartitionSpec()), - ("lm_head/dense/bias", PartitionSpec()), - ("lm_head/decoder/kernel", PartitionSpec("fsdp", "tp")), - ("lm_head/decoder/bias", PartitionSpec("tp")), - (".*", PartitionSpec()), - ) if not fully_sharded_data_parallel else ( - ("embeddings/(position_embeddings|token_type_embeddings)/embedding", PartitionSpec()), - ("embeddings/word_embeddings/embedding", PartitionSpec()), - ("attention/self/(key|query|value)/kernel", PartitionSpec(("fsdp", "sp"))), - ("attention/self/(key|query|value)/bias", PartitionSpec()), - ("attention/output/dense/kernel", PartitionSpec(("fsdp", "sp"))), - ("attention/output/dense/bias", PartitionSpec()), - ("(LayerNorm|layer_norm)/(bias|scale)", PartitionSpec()), - ("intermediate/dense/kernel", PartitionSpec(("fsdp", "sp"))), - ("intermediate/dense/bias", PartitionSpec("sp")), - ("output/dense/kernel", PartitionSpec(("fsdp", "sp"))), - ("output/dense/bias", PartitionSpec()), - ("lm_head/dense/kernel", PartitionSpec()), - ("lm_head/dense/bias", PartitionSpec()), - ("lm_head/decoder/kernel", PartitionSpec(("fsdp", "sp"))), - ("lm_head/decoder/bias", PartitionSpec("sp")), - (".*", PartitionSpec()), - ) - - def add_jax_args( - self, - gradient_checkpointing="nothing_saveable", - **kwargs - ): - self.gradient_checkpointing = gradient_checkpointing - for k, v in kwargs.items(): - setattr(self, k, v) +from jax.sharding import PartitionSpec +from ..easydel_modelling_utils import EasyDeLPretrainedConfig + + +class RobertaConfig(EasyDeLPretrainedConfig): + model_type: str = "roberta" + + def __init__( + self, + vocab_size=50265, + hidden_size=768, + num_hidden_layers=12, + num_attention_heads=12, + intermediate_size=3072, + hidden_act="gelu", + hidden_dropout_prob=0.1, + attention_probs_dropout_prob=0.1, + max_position_embeddings=514, + type_vocab_size=1, + initializer_range=0.02, + layer_norm_eps=1e-5, + pad_token_id=1, + bos_token_id=0, + eos_token_id=2, + position_embedding_type="absolute", + use_cache=True, + classifier_dropout=None, + gradient_checkpointing="nothing_saveable", + **kwargs + ): + self.vocab_size = vocab_size + self.hidden_size = hidden_size + self.num_hidden_layers = num_hidden_layers + self.num_attention_heads = num_attention_heads + self.hidden_act = hidden_act + self.intermediate_size = intermediate_size + self.hidden_dropout_prob = hidden_dropout_prob + self.attention_probs_dropout_prob = attention_probs_dropout_prob + self.max_position_embeddings = max_position_embeddings + self.type_vocab_size = type_vocab_size + self.initializer_range = initializer_range + self.layer_norm_eps = layer_norm_eps + self.position_embedding_type = position_embedding_type + self.use_cache = use_cache + self.classifier_dropout = classifier_dropout + self.gradient_checkpointing = gradient_checkpointing + super().__init__(pad_token_id=pad_token_id, bos_token_id=bos_token_id, eos_token_id=eos_token_id, **kwargs) + + def get_partition_rules(self, fully_sharded_data_parallel: bool = True): + return ( + ("embeddings/(position_embeddings|token_type_embeddings)/embedding", PartitionSpec()), + ("embeddings/word_embeddings/embedding", PartitionSpec()), + ("attention/self/(key|query|value)/kernel", PartitionSpec("fsdp", "tp")), + ("attention/self/(key|query|value)/bias", PartitionSpec()), + ("attention/output/dense/kernel", PartitionSpec("tp", "fsdp")), + ("attention/output/dense/bias", PartitionSpec()), + ("(LayerNorm|layer_norm)/(bias|scale)", PartitionSpec()), + ("intermediate/dense/kernel", PartitionSpec("fsdp", "tp")), + ("intermediate/dense/bias", PartitionSpec("tp")), + ("output/dense/kernel", PartitionSpec("tp", "fsdp")), + ("output/dense/bias", PartitionSpec()), + ("lm_head/dense/kernel", PartitionSpec()), + ("lm_head/dense/bias", PartitionSpec()), + ("lm_head/decoder/kernel", PartitionSpec("fsdp", "tp")), + ("lm_head/decoder/bias", PartitionSpec("tp")), + (".*", PartitionSpec()), + ) if not fully_sharded_data_parallel else ( + ("embeddings/(position_embeddings|token_type_embeddings)/embedding", PartitionSpec()), + ("embeddings/word_embeddings/embedding", PartitionSpec()), + ("attention/self/(key|query|value)/kernel", PartitionSpec(("fsdp", "sp"))), + ("attention/self/(key|query|value)/bias", PartitionSpec()), + ("attention/output/dense/kernel", PartitionSpec(("fsdp", "sp"))), + ("attention/output/dense/bias", PartitionSpec()), + ("(LayerNorm|layer_norm)/(bias|scale)", PartitionSpec()), + ("intermediate/dense/kernel", PartitionSpec(("fsdp", "sp"))), + ("intermediate/dense/bias", PartitionSpec("sp")), + ("output/dense/kernel", PartitionSpec(("fsdp", "sp"))), + ("output/dense/bias", PartitionSpec()), + ("lm_head/dense/kernel", PartitionSpec()), + ("lm_head/dense/bias", PartitionSpec()), + ("lm_head/decoder/kernel", PartitionSpec(("fsdp", "sp"))), + ("lm_head/decoder/bias", PartitionSpec("sp")), + (".*", PartitionSpec()), + ) + + def add_jax_args( + self, + gradient_checkpointing="nothing_saveable", + **kwargs + ): + self.gradient_checkpointing = gradient_checkpointing + for k, v in kwargs.items(): + setattr(self, k, v) diff --git a/src/python/easydel/modules/rwkv/__init__.py b/src/python/easydel/modules/rwkv/__init__.py index e9bb463d8..642f232dd 100644 --- a/src/python/easydel/modules/rwkv/__init__.py +++ b/src/python/easydel/modules/rwkv/__init__.py @@ -1,11 +1,11 @@ -from .modelling_rwkv_flax import ( - FlaxRwkvForCausalLM as FlaxRwkvForCausalLM, - FlaxRwkvModel as FlaxRwkvModel -) -from .rwkv_configuration import RwkvConfig as RwkvConfig - -__all__ = ( - "FlaxRwkvForCausalLM", - "FlaxRwkvModel", - "RwkvConfig" -) +from .modelling_rwkv_flax import ( + FlaxRwkvForCausalLM as FlaxRwkvForCausalLM, + FlaxRwkvModel as FlaxRwkvModel +) +from .rwkv_configuration import RwkvConfig as RwkvConfig + +__all__ = ( + "FlaxRwkvForCausalLM", + "FlaxRwkvModel", + "RwkvConfig" +) diff --git a/src/python/easydel/modules/stablelm/modelling_stablelm_flax.py b/src/python/easydel/modules/stablelm/modelling_stablelm_flax.py index 901ed1a89..f247df645 100644 --- a/src/python/easydel/modules/stablelm/modelling_stablelm_flax.py +++ b/src/python/easydel/modules/stablelm/modelling_stablelm_flax.py @@ -109,16 +109,18 @@ def setup(self) -> None: self.act_fn = ACT2FN[config.hidden_act] def __call__(self, x: jnp.ndarray, deterministic: bool = True) -> jnp.ndarray: - """ - The __call__ function is the main function of a class. + """The __call__ function is the main function of a class. It is called when an instance of the class (an object) is invoked as a function, i.e., obj(arguments). The __call__ method enables instances of a class to be called like standard Python functions. - :param self: Represent the instance of the class - :param x: jnp.ndarray: Pass in the input to the layer - :param deterministic: bool: Determine whether to use dropout # Ignored - :return: A tensor that is the result of function to x + Args: + self: Represent the instance of the class + x: jnp.ndarray: Pass in the input to the layer + deterministic: bool: Determine whether to use dropout # + Ignored + Returns: + A tensor that is the result of function to x """ return self.down_proj(self.act_fn(self.gate_proj(x)) * self.up_proj(x)) @@ -225,33 +227,38 @@ def _merge_heads(self, hidden_states): @staticmethod def _transpose_sequence_head(query, key, value): - """ - The _transpose_sequence_head function transposes the query, key and value matrices. + """The _transpose_sequence_head function transposes the query, key and value matrices. - :param query: Get the attention weights for each of the heads - :param key: Determine the number of heads - :param value: Store the values of the input - :return: The transpose of the query, key and value matrices + Args: + query: Get the attention weights for each of the heads + key: Determine the number of heads + value: Store the values of the input + Returns: + The transpose of the query, key and value matrices """ return jnp.transpose(query, (0, 2, 1, 3)), jnp.transpose(key, (0, 2, 1, 3)), jnp.transpose(value, (0, 2, 1, 3)) def apply_rotary(self, batch_size, sequence_length, query, key, value, freq_cis, position_ids): - """ - The apply_rotary function is a modified version of the apply_attention function in the BertModel class. + """The apply_rotary function is a modified version of the apply_attention function in the BertModel class. The main difference is that it takes in an additional argument, freq_cis, which are used to calculate the rotary attention weights. The other differences are minor and mostly related to reshaping tensors. - :param self: Access variables that belong to the class - :param batch_size: Reshape the query_states, key and value tensors - :param sequence_length: Reshape the query_states, key and value tensors - :param query: Calculate the attention weights - :param key: Calculate the attention - :param value: Compute the attention weights - :param freq_cis: Calculate the frequency of each word in the vocabulary - :param position_ids: Identify the position of each token in the sequence - :return: A tuple of 3 tensors: query_states, key and value - + Args: + self: Access variables that belong to the class + batch_size: Reshape the query_states, key and value tensors + sequence_length: Reshape the query_states, key and value + tensors + query: Calculate the attention weights + key: Calculate the attention + value: Compute the attention weights + freq_cis: Calculate the frequency of each word in the + vocabulary + position_ids: Identify the position of each token in the + sequence + + Returns: + A tuple of 3 tensors: query_states, key and value """ query = query.reshape( batch_size, @@ -311,25 +318,32 @@ def __call__( output_attentions: bool = False, fcm_mask=None, ): - """ - - The __call__ function is the main function of a JAX module. It defines how the module behaves when called + """The __call__ function is the main function of a JAX module. It defines how the module behaves when called with inputs. The __call__ function can be thought of as a "forward pass" through the model, and it should return all outputs that are needed for training or inference. - :param self: Access variables that belong to the class - :param hidden_states: chex.Array: Pass the hidden states of the previous layer - :param freq_cis: Tuple[chex.Array, chex.Array],: Pass in the frequency coefficients for each position - :param attention_mask: chex.Array: Mask out certain tokens in the input sequence - :param position_ids: chex.Array: Determine the position of each token in a sequence - :param causal_mask: chex.Array: Mask out the future tokens in the decoder - :param deterministic: bool: Determine whether to use dropout or not - :param init_cache: bool: Initialize the cache - :param output_attentions: bool: Determine whether to return the attention weights or not - :param fcm_mask: Mask out the attention weights between the input and output tokens + Args: + self: Access variables that belong to the class + hidden_states: chex.Array: Pass the hidden states of the + previous layer + freq_cis: Tuple[chex.Array, chex.Array],: Pass in the + frequency coefficients for each position + attention_mask: chex.Array: Mask out certain tokens in the + input sequence + position_ids: chex.Array: Determine the position of each + token in a sequence + causal_mask: chex.Array: Mask out the future tokens in the + decoder + deterministic: bool: Determine whether to use dropout or not + init_cache: bool: Initialize the cache + output_attentions: bool: Determine whether to return the + attention weights or not + fcm_mask: Mask out the attention weights between the input + and output tokens :param : Determine if the attention is causal or not - :return: A tuple of two arrays + Returns: + A tuple of two arrays """ batch_size, sequence_length = hidden_states.shape[:2] query_states, key_states, value_states = self.q_proj(hidden_states), self.k_proj(hidden_states), self.v_proj( diff --git a/src/python/easydel/modules/t5/modelling_t5_flax.py b/src/python/easydel/modules/t5/modelling_t5_flax.py index 5e898f149..69d842ac1 100644 --- a/src/python/easydel/modules/t5/modelling_t5_flax.py +++ b/src/python/easydel/modules/t5/modelling_t5_flax.py @@ -13,7 +13,7 @@ # See the License for the specific language governing permissions and # limitations under the License. # This model is copied from the Transformers and this script will apply pjit on them -""" Flax T5 model.""" +"""Flax T5 model.""" import copy from typing import Callable, Optional, Tuple diff --git a/src/python/easydel/modules/whisper/__init__.py b/src/python/easydel/modules/whisper/__init__.py index a861bb5d4..50bbd134a 100644 --- a/src/python/easydel/modules/whisper/__init__.py +++ b/src/python/easydel/modules/whisper/__init__.py @@ -1,14 +1,14 @@ -from .modelling_whisper_flax import ( - FlaxWhisperForConditionalGeneration as FlaxWhisperForConditionalGeneration, - FlaxWhisperForAudioClassification as FlaxWhisperForAudioClassification, - FlaxWhisperTimeStampLogitsProcessor as FlaxWhisperTimeStampLogitsProcessor -) - -from .whisper_configuration import WhisperConfig as WhisperConfig - -__all__ = ( - "WhisperConfig", - "FlaxWhisperTimeStampLogitsProcessor", - "FlaxWhisperForAudioClassification", - "FlaxWhisperForConditionalGeneration" -) +from .modelling_whisper_flax import ( + FlaxWhisperForConditionalGeneration as FlaxWhisperForConditionalGeneration, + FlaxWhisperForAudioClassification as FlaxWhisperForAudioClassification, + FlaxWhisperTimeStampLogitsProcessor as FlaxWhisperTimeStampLogitsProcessor +) + +from .whisper_configuration import WhisperConfig as WhisperConfig + +__all__ = ( + "WhisperConfig", + "FlaxWhisperTimeStampLogitsProcessor", + "FlaxWhisperForAudioClassification", + "FlaxWhisperForConditionalGeneration" +) diff --git a/src/python/easydel/modules/whisper/modelling_whisper_flax.py b/src/python/easydel/modules/whisper/modelling_whisper_flax.py index 56bc8e3e4..36071f475 100644 --- a/src/python/easydel/modules/whisper/modelling_whisper_flax.py +++ b/src/python/easydel/modules/whisper/modelling_whisper_flax.py @@ -1,1539 +1,1539 @@ -import random -from functools import partial - -import fjformer -from flax.linen import make_causal_mask -from jax.random import PRNGKey -import math -from typing import Optional, Tuple, Union, Any -import flax.linen -from fjformer import linen as nn -import jax -import jax.numpy as jnp -from flax.core.frozen_dict import FrozenDict, freeze, unfreeze -from flax.linen import combine_masks -from flax.linen import partitioning as nn_partitioning -from flax.traverse_util import flatten_dict, unflatten_dict -from fjformer.linen import Linear -from jax import lax -from jax.sharding import PartitionSpec -from transformers import FlaxWhisperTimeStampLogitsProcessor -from transformers.modeling_flax_outputs import ( - FlaxBaseModelOutput, - FlaxSequenceClassifierOutput, - FlaxBaseModelOutputWithPastAndCrossAttentions, - FlaxSeq2SeqModelOutput, - FlaxSeq2SeqLMOutput, - FlaxCausalLMOutputWithCrossAttentions -) - -from .whisper_configuration import WhisperConfig -from ..attention_module import AttentionModule -from ..easydel_modelling_utils import EasyDeLFlaxPretrainedModel -# easydel.modules -from ..flax_modelling_utils import ( - with_sharding_constraint, - get_gradient_checkpoint_policy, - get_dot_general_by_bits, - BaseJAXAttentionModule, - ACT2FN -) - -remat = nn_partitioning.remat - - -def sinusoidal_embedding_init(key, shape, dtype=jnp.float_) -> jax.Array: - length, channels = shape - if channels % 2 != 0: - raise ValueError( - f"Number of channels has to be divisible by 2 for sinusoidal positional embeddings, got {channels} channels." - ) - log_timescale_increment = math.log(10000) / (channels // 2 - 1) - inv_timescales = jnp.exp(-log_timescale_increment * jnp.arange(channels // 2)) - scaled_time = jnp.arange(length).reshape(-1, 1) * inv_timescales.reshape(1, -1) - return jnp.concatenate([jnp.sin(scaled_time), jnp.cos(scaled_time)], axis=1).astype(dtype) - - -class FlaxWhisperAttention(BaseJAXAttentionModule): - config: WhisperConfig - embed_dim: int - num_heads: int - dropout: float = 0.0 - causal: bool = False - bias: bool = True - dtype: jnp.dtype = jnp.float32 - param_dtype: jnp.dtype = jnp.float32 - precision: Optional[Union[str, lax.Precision]] = None - - def setup(self) -> None: - self.head_dim = self.embed_dim // self.num_heads - if self.head_dim * self.num_heads != self.embed_dim: - raise ValueError( - f"embed_dim must be divisible by num_heads (got `embed_dim`: {self.embed_dim}" - f" and `num_heads`: {self.num_heads})." - ) - - dense = partial( - Linear, - self.embed_dim, - dtype=self.dtype, - param_dtype=self.param_dtype, - precision=self.precision, - kernel_init=jax.nn.initializers.normal(self.config.init_std), - **get_dot_general_by_bits(self.config.bits, self.config.easy_method) - ) - - self.q_proj = dense(use_bias=self.bias) - self.k_proj = dense(use_bias=False) - self.v_proj = dense(use_bias=self.bias) - self.out_proj = dense(use_bias=self.bias) - - self.attention_performer = AttentionModule( - use_sharding_constraint=self.config.use_sharding_constraint, - block_k_major=self.config.block_k_major, - block_b=self.config.block_b, - block_q=self.config.block_q, - block_k=self.config.block_k, - block_q_major_dkv=self.config.block_q_major_dkv, - block_k_major_dkv=self.config.block_k_major_dkv, - block_k_major_dq=self.config.block_k_major_dq, - block_k_dkv=self.config.block_k_dkv, - block_q_dkv=self.config.block_q_dkv, - block_q_dq=self.config.block_q_dq, - block_k_dq=self.config.block_k_dq, - num_attention_heads=self.config.num_attention_heads, - attention_dropout=self.config.attention_dropout, - head_dims=self.head_dim, - attention_partition_spec=self.config.attention_partition_spec, - shard_attention_computation=self.config.shard_attention_computation, - precision=self.precision, - force_float32_tpu=True, - attn_mechanism=self.config.attn_mechanism, - dtype=self.dtype, - bias_partition_spec=self.config.bias_partition_spec, - key_partition_spec=self.config.key_partition_spec, - query_partition_spec=self.config.query_partition_spec, - generation_query_partition_spec=self.config.generation_query_partition_spec, - generation_bias_partition_spec=self.config.generation_bias_partition_spec, - generation_attention_partition_spec=self.config.generation_attention_partition_spec, - value_partition_spec=self.config.value_partition_spec, - scan_ring_attention=self.config.scan_ring_attention, - mesh=self.config.jax_mesh(), - sm_scale=1 / math.sqrt(self.head_dim), - axis_name=self.config.attention_axis_name, - backward_pass_impl=self.config.flash_attention_backward_pass_impl - ) - - def __call__( - self, - hidden_states: jnp.ndarray, - key_value_states: Optional[jnp.ndarray] = None, - attention_mask: Optional[jnp.ndarray] = None, - causal_mask: Optional[jnp.ndarray] = None, - init_cache: bool = False, - deterministic: bool = True, - ) -> tuple[Any, Any]: - is_cross_attention = key_value_states is not None - batch_size = hidden_states.shape[0] - - query_states = self.q_proj(hidden_states) - - if is_cross_attention: - key_states = self.k_proj(key_value_states) - value_states = self.v_proj(key_value_states) - else: - key_states = self.k_proj(hidden_states) - value_states = self.v_proj(hidden_states) - - query_states = self._split_heads(query_states) - key_states = self._split_heads(key_states) - value_states = self._split_heads(value_states) - - if self.causal: - assert causal_mask is not None, "seems like you forgot to pass causal_mask" - query_length, key_length = query_states.shape[1], key_states.shape[1] - if self.has_variable("cache", "cached_key"): - mask_shift = self.variables["cache"]["cache_index"] - max_decoder_length = self.variables["cache"]["cached_key"].shape[1] - causal_mask = lax.dynamic_slice( - causal_mask, - (0, 0, mask_shift, 0), - (1, 1, query_length, max_decoder_length), - ) - else: - causal_mask = causal_mask[:, :, :query_length, :key_length] - causal_mask = jnp.broadcast_to(causal_mask, (batch_size,) + causal_mask.shape[1:]) - - # combine masks if needed - if attention_mask is not None and self.causal: - attention_mask = jnp.broadcast_to(jnp.expand_dims(attention_mask, axis=(-3, -2)), causal_mask.shape) - attention_mask = combine_masks(attention_mask, causal_mask) - elif self.causal: - attention_mask = causal_mask - elif attention_mask is not None: - attention_mask = jnp.expand_dims(attention_mask, axis=(-3, -2)) - - # During fast autoregressive decoding, we feed one position at a time, - # and cache the keys and values step by step. - - if self.causal and (self.has_variable("cache", "cached_key") or init_cache): - key_states, value_states, attention_mask = self._concatenate_to_cache( - key_states, value_states, query_states, attention_mask - ) - - query_length, key_length = query_states.shape[1], key_states.shape[1] - # Convert the boolean attention mask to an attention bias. - if attention_mask is not None: - # attention mask in the form of attention bias - attention_bias = lax.select( - attention_mask > 0, - jnp.full(attention_mask.shape, 0.0).astype(self.dtype), - jnp.full(attention_mask.shape, jnp.finfo(self.dtype).min).astype(self.dtype), - ) - else: - attention_bias = None - - dropout_rng = None - if not deterministic and self.dropout > 0.0: - dropout_rng = self.make_rng("dropout") - - attentions = self.attention_performer.__call__( - query_states=query_states, - key_states=key_states, - value_states=value_states, - bias=attention_bias, - attention_mask=attention_mask, - causal=False, - dropout_rng=dropout_rng, - deterministic=deterministic, - query_sequence_length=query_length, - key_value_sequence_length=key_length, - uses_cache=self.has_variable("cache", "cached_key") or init_cache, - segment_ids=None, - causal_mask=causal_mask - ) - - attn_output = self._merge_heads(attentions.attention_outputs) - if self.config.shard_attention_computation: - attn_output = with_sharding_constraint( - attn_output, PartitionSpec( - ("dp", "fsdp"), - "sp" if attn_output.shape[1] != 1 else None, - "tp" - ) - ) - attn_output = self.out_proj(attn_output) - - return attn_output, attentions.attention_outputs - - def _split_heads(self, hidden_state) -> jnp.ndarray: - return hidden_state.reshape(hidden_state.shape[:2] + (self.num_heads, self.head_dim)) - - def _merge_heads(self, hidden_state) -> jnp.ndarray: - return hidden_state.reshape(hidden_state.shape[:2] + (self.embed_dim,)) - - -class FlaxWhisperEncoderLayer(nn.Module): - config: WhisperConfig - dtype: jnp.dtype = jnp.float32 - param_dtype: jnp.dtype = jnp.float32 - precision: Optional[Union[str, lax.Precision]] = None - - def setup(self) -> None: - self.embed_dim = self.config.d_model - self.self_attn = FlaxWhisperAttention( - config=self.config, - embed_dim=self.embed_dim, - num_heads=self.config.encoder_attention_heads, - dropout=self.config.attention_dropout, - dtype=self.dtype, - param_dtype=self.param_dtype, - precision=self.precision - ) - self.self_attn_layer_norm = nn.LayerNorm(dtype=self.dtype, epsilon=1e-05) - self.dropout_layer = flax.linen.Dropout(rate=self.config.dropout) - self.activation_fn = ACT2FN[self.config.activation_function] - self.activation_dropout_layer = flax.linen.Dropout(rate=self.config.activation_dropout) - self.fc1 = Linear( - self.config.encoder_ffn_dim, - dtype=self.dtype, - param_dtype=self.param_dtype, - precision=self.precision, - kernel_init=jax.nn.initializers.normal(self.config.init_std), - **get_dot_general_by_bits(self.config.bits, self.config.easy_method) - ) - self.fc2 = Linear( - self.embed_dim, - dtype=self.dtype, - param_dtype=self.param_dtype, - precision=self.precision, - kernel_init=jax.nn.initializers.normal(self.config.init_std), - **get_dot_general_by_bits(self.config.bits, self.config.easy_method) - ) - self.final_layer_norm = nn.LayerNorm(dtype=self.dtype, epsilon=1e-05) - - def __call__( - self, - hidden_states: jnp.ndarray, - attention_mask: jnp.ndarray, - causal_mask: Optional[jnp.ndarray] = None, - output_attentions: bool = True, - deterministic: bool = True, - ) -> Tuple[jnp.ndarray]: - residual = hidden_states - hidden_states = self.self_attn_layer_norm(hidden_states) - hidden_states, attn_weights = self.self_attn( - hidden_states=hidden_states, - attention_mask=attention_mask, - causal_mask=causal_mask - ) - hidden_states = self.dropout_layer(hidden_states, deterministic=deterministic) - hidden_states = residual + hidden_states - - residual = hidden_states - hidden_states = self.final_layer_norm(hidden_states) - hidden_states = self.activation_fn(self.fc1(hidden_states)) - hidden_states = self.activation_dropout_layer(hidden_states, deterministic=deterministic) - hidden_states = self.fc2(hidden_states) - hidden_states = self.dropout_layer(hidden_states, deterministic=deterministic) - hidden_states = residual + hidden_states - - outputs = (hidden_states,) - - if output_attentions: - outputs += (attn_weights,) - - return outputs - - -class FlaxWhisperEncoderLayerCollection(nn.Module): - config: WhisperConfig - dtype: jnp.dtype = jnp.float32 - param_dtype: jnp.dtype = jnp.float32 - precision: Optional[Union[str, lax.Precision]] = None - - def setup(self): - block = FlaxWhisperEncoderLayer - if self.config.gradient_checkpointing != "": - block = remat( - block, - static_argnums=(2, 3, 4), - policy=get_gradient_checkpoint_policy(self.config.gradient_checkpointing) - ) - self.layers = [ - block(self.config, name=str(i), dtype=self.dtype) - for i in range(self.config.encoder_layers) - ] - self.layerdrop = self.config.encoder_layerdrop - - def __call__( - self, - hidden_states, - attention_mask, - causal_mask: Optional[jnp.ndarray] = None, - deterministic: bool = True, - output_attentions: bool = False, - output_hidden_states: bool = False, - return_dict: bool = True, - ): - all_attentions = () if output_attentions else None - all_hidden_states = () if output_hidden_states else None - - for encoder_layer in self.layers: - if output_hidden_states: - all_hidden_states = all_hidden_states + (hidden_states,) - dropout_probability = random.uniform(0, 1) - if not deterministic and (dropout_probability < self.layerdrop): # skip the layer - layer_outputs = (None, None) - else: - layer_outputs = encoder_layer( - hidden_states, - attention_mask, - causal_mask, - output_attentions, - deterministic, - ) - hidden_states = layer_outputs[0] - if output_attentions: - all_attentions = all_attentions + (layer_outputs[1],) - - if output_hidden_states: - all_hidden_states += (hidden_states,) - - outputs = (hidden_states, all_hidden_states, all_attentions) - - if not return_dict: - return tuple(v for v in outputs if v is not None) - - return FlaxBaseModelOutput( - last_hidden_state=hidden_states, hidden_states=all_hidden_states, attentions=all_attentions - ) - - -class FlaxWhisperDecoderLayer(nn.Module): - config: WhisperConfig - dtype: jnp.dtype = jnp.float32 - param_dtype: jnp.dtype = jnp.float32 - precision: Optional[Union[str, lax.Precision]] = None - - def setup(self) -> None: - self.embed_dim = self.config.d_model - self.self_attn = FlaxWhisperAttention( - config=self.config, - embed_dim=self.embed_dim, - num_heads=self.config.decoder_attention_heads, - dropout=self.config.attention_dropout, - causal=True, - dtype=self.dtype, - param_dtype=self.param_dtype, - precision=self.precision - ) - self.dropout_layer = flax.linen.Dropout(rate=self.config.dropout) - self.activation_fn = ACT2FN[self.config.activation_function] - self.activation_dropout_layer = flax.linen.Dropout(rate=self.config.activation_dropout) - - self.self_attn_layer_norm = nn.LayerNorm(dtype=self.dtype, epsilon=1e-05) - self.encoder_attn = FlaxWhisperAttention( - config=self.config, - embed_dim=self.embed_dim, - num_heads=self.config.decoder_attention_heads, - dropout=self.config.attention_dropout, - dtype=self.dtype, - param_dtype=self.param_dtype, - precision=self.precision - ) - self.encoder_attn_layer_norm = nn.LayerNorm(dtype=self.dtype, epsilon=1e-05) - self.fc1 = Linear( - self.config.decoder_ffn_dim, - dtype=self.dtype, - param_dtype=self.param_dtype, - precision=self.precision, - kernel_init=jax.nn.initializers.normal(self.config.init_std), - **get_dot_general_by_bits(self.config.bits, self.config.easy_method) - ) - self.fc2 = Linear( - self.embed_dim, - param_dtype=self.param_dtype, - precision=self.precision, - dtype=self.dtype, - kernel_init=jax.nn.initializers.normal(self.config.init_std), - **get_dot_general_by_bits(self.config.bits, self.config.easy_method) - ) - self.final_layer_norm = nn.LayerNorm(dtype=self.dtype, epsilon=1e-05) - - def __call__( - self, - hidden_states: jnp.ndarray, - attention_mask: jnp.ndarray, - causal_mask: Optional[jnp.ndarray] = None, - encoder_hidden_states: Optional[jnp.ndarray] = None, - encoder_attention_mask: Optional[jnp.ndarray] = None, - init_cache: bool = False, - output_attentions: bool = True, - deterministic: bool = True, - ) -> Tuple[jnp.ndarray]: - residual = hidden_states - hidden_states = self.self_attn_layer_norm(hidden_states) - - # Self Attention - hidden_states, self_attn_weights = self.self_attn( - hidden_states=hidden_states, - attention_mask=attention_mask, - causal_mask=causal_mask, - init_cache=init_cache - ) - hidden_states = self.dropout_layer(hidden_states, deterministic=deterministic) - hidden_states = residual + hidden_states - - # Cross-Attention Block - cross_attn_weights = None - if encoder_hidden_states is not None: - residual = hidden_states - - hidden_states = self.encoder_attn_layer_norm(hidden_states) - hidden_states, cross_attn_weights = self.encoder_attn( - hidden_states=hidden_states, - causal_mask=causal_mask, - key_value_states=encoder_hidden_states, - attention_mask=encoder_attention_mask, - ) - hidden_states = self.dropout_layer(hidden_states, deterministic=deterministic) - hidden_states = residual + hidden_states - - # Fully Connected - residual = hidden_states - hidden_states = self.final_layer_norm(hidden_states) - hidden_states = self.activation_fn(self.fc1(hidden_states)) - hidden_states = self.activation_dropout_layer(hidden_states, deterministic=deterministic) - hidden_states = self.fc2(hidden_states) - hidden_states = self.dropout_layer(hidden_states, deterministic=deterministic) - hidden_states = residual + hidden_states - - outputs = (hidden_states,) - - if output_attentions: - outputs += (self_attn_weights, cross_attn_weights) - - return outputs - - -class FlaxWhisperDecoderLayerCollection(nn.Module): - config: WhisperConfig - dtype: jnp.dtype = jnp.float32 - param_dtype: jnp.dtype = jnp.float32 - precision: Optional[Union[str, lax.Precision]] = None - - def setup(self): - - block = FlaxWhisperDecoderLayer - if self.config.gradient_checkpointing != "": - block = remat( - block, - static_argnums=(4, 5, 6, 7), - policy=get_gradient_checkpoint_policy(self.config.gradient_checkpointing) - ) - self.layers = [ - block(self.config, name=str(i), dtype=self.dtype) - for i in range(self.config.decoder_layers) - ] - - self.layerdrop = self.config.decoder_layerdrop - - def __call__( - self, - hidden_states, - attention_mask, - causal_mask: Optional[jnp.ndarray] = None, - encoder_hidden_states: Optional[jnp.ndarray] = None, - encoder_attention_mask: Optional[jnp.ndarray] = None, - deterministic: bool = True, - init_cache: bool = False, - output_attentions: bool = False, - output_hidden_states: bool = False, - return_dict: bool = True, - ): - # decoder layers - all_hidden_states = () if output_hidden_states else None - all_self_attns = () if output_attentions else None - all_cross_attentions = () if (output_attentions and encoder_hidden_states is not None) else None - - for decoder_layer in self.layers: - if output_hidden_states: - all_hidden_states += (hidden_states,) - # add LayerDrop (see https://arxiv.org/abs/1909.11556 for description) - dropout_probability = random.uniform(0, 1) - if not deterministic and (dropout_probability < self.layerdrop): - layer_outputs = (None, None, None) - else: - layer_outputs = decoder_layer( - hidden_states, - attention_mask, - causal_mask, - encoder_hidden_states, - encoder_attention_mask, - init_cache, - output_attentions, - deterministic, - ) - - hidden_states = layer_outputs[0] - if output_attentions: - all_self_attns += (layer_outputs[1],) - - if encoder_hidden_states is not None: - all_cross_attentions += (layer_outputs[2],) - - # add hidden states from the last decoder layer - if output_hidden_states: - all_hidden_states += (hidden_states,) - - outputs = [hidden_states, all_hidden_states, all_self_attns, all_cross_attentions] - - if not return_dict: - return tuple(v for v in outputs if v is not None) - - return FlaxBaseModelOutputWithPastAndCrossAttentions( - last_hidden_state=hidden_states, - hidden_states=all_hidden_states, - attentions=all_self_attns, - cross_attentions=all_cross_attentions, - ) - - -class FlaxWhisperEncoder(nn.Module): - config: WhisperConfig - dtype: jnp.dtype = jnp.float32 - param_dtype: jnp.dtype = jnp.float32 - precision: Optional[Union[str, lax.Precision]] = None - - def setup(self) -> None: - self.conv1 = nn.Conv( - self.config.d_model, - kernel_size=(3,), - padding=1, - kernel_init=jax.nn.initializers.normal(self.config.init_std), - dtype=self.dtype, - param_dtype=self.param_dtype, - precision=self.precision - ) - self.conv2 = nn.Conv( - self.config.d_model, - kernel_size=(3,), - strides=2, - padding=1, - kernel_init=jax.nn.initializers.normal(self.config.init_std), - dtype=self.dtype, - param_dtype=self.param_dtype, - precision=self.precision - ) - - self.dropout_layer = flax.linen.Dropout(rate=self.config.dropout) - - self.layers = FlaxWhisperEncoderLayerCollection( - self.config, - dtype=self.dtype, - param_dtype=self.param_dtype, - precision=self.precision - ) - - self.embed_positions = nn.Embed( - self.config.max_source_positions, - self.config.d_model, - dtype=self.dtype, - embedding_init=sinusoidal_embedding_init, - param_dtype=self.param_dtype, - ) - - self.layer_norm = nn.LayerNorm(dtype=self.dtype, epsilon=1e-05) - - def __call__( - self, - input_features: jnp.ndarray, - causal_mask: Optional[jnp.ndarray] = None, - output_attentions: bool = False, - output_hidden_states: bool = False, - return_dict: bool = True, - deterministic: bool = True, - ) -> tuple[Any | None, ...] | FlaxBaseModelOutput: - if input_features.shape[1:] != (self.config.num_mel_bins, self.config.max_source_positions * 2): - raise ValueError( - "input_features.shape[1:], must be equal to (self.config.num_mel_bins," - f" self.config.max_source_positions * 2) (got {input_features.shape[1:]}, but should be" - f" ({self.config.num_mel_bins}, {self.config.max_source_positions * 2}))" - ) - - input_features = input_features.transpose(0, 2, 1) - hidden_states = jax.nn.gelu(self.conv1(input_features), approximate=False) - hidden_states = jax.nn.gelu(self.conv2(hidden_states), approximate=False) - - embed_positions = self.embed_positions(jnp.arange(self.config.max_source_positions)) - # freeze the sinusoidal embeddings by stopping the back-prop - embed_positions = jax.lax.stop_gradient(embed_positions) - hidden_states = hidden_states + embed_positions - - hidden_states = self.dropout_layer(hidden_states, deterministic=deterministic) - - outputs = self.layers( - hidden_states, - causal_mask=causal_mask, - attention_mask=None, - deterministic=deterministic, - output_attentions=output_attentions, - output_hidden_states=output_hidden_states, - return_dict=return_dict, - ) - - last_hidden_states = outputs[0] - last_hidden_states = self.layer_norm(last_hidden_states) - - # update the last element in `hidden_states` after applying `layernorm` above - hidden_states = None - if output_hidden_states: - hidden_states = outputs[1] - hidden_states = hidden_states[:-1] + (last_hidden_states,) - - if not return_dict: - outputs = (last_hidden_states, hidden_states) + (outputs[2:] if output_hidden_states else outputs[1:]) - return tuple(v for v in outputs if v is not None) - - return FlaxBaseModelOutput( - last_hidden_state=last_hidden_states, - hidden_states=hidden_states, - attentions=outputs.attentions, - ) - - -class FlaxWhisperDecoder(nn.Module): - config: WhisperConfig - dtype: jnp.dtype = jnp.float32 - param_dtype: jnp.dtype = jnp.float32 - precision: Optional[Union[str, lax.Precision]] = None - - def setup(self) -> None: - self.embed_tokens = nn.Embed( - self.config.vocab_size, - self.config.d_model, - dtype=self.dtype, - param_dtype=self.param_dtype - ) - self.embed_positions = nn.Embed( - self.config.max_target_positions, - self.config.d_model, - dtype=self.dtype, - param_dtype=self.param_dtype - ) - - self.layers = FlaxWhisperDecoderLayerCollection( - self.config, - dtype=self.dtype, - param_dtype=self.param_dtype, - precision=self.precision - ) - - self.dropout_layer = flax.linen.Dropout(rate=self.config.dropout) - - self.layer_norm = nn.LayerNorm(dtype=self.dtype, epsilon=1e-5) - - def __call__( - self, - input_ids: jnp.ndarray, - attention_mask: jnp.ndarray, - position_ids: jnp.ndarray, - causal_mask: Optional[jnp.ndarray] = None, - encoder_hidden_states: Optional[jnp.ndarray] = None, - init_cache: bool = False, - output_attentions: bool = False, - output_hidden_states: bool = False, - return_dict: bool = True, - deterministic: bool = True, - ) -> tuple[Any, ...] | FlaxBaseModelOutputWithPastAndCrossAttentions: - input_embeds = self.embed_tokens(input_ids) - position_embeds = self.embed_positions(position_ids) - - hidden_states = input_embeds + position_embeds - hidden_states = self.dropout_layer(hidden_states, deterministic=deterministic) - - outputs = self.layers( - hidden_states, - attention_mask=attention_mask, - causal_mask=causal_mask, - encoder_hidden_states=encoder_hidden_states, - deterministic=deterministic, - init_cache=init_cache, - output_attentions=output_attentions, - output_hidden_states=output_hidden_states, - return_dict=return_dict, - ) - - last_hidden_states = outputs[0] - last_hidden_states = self.layer_norm(last_hidden_states) - - hidden_states = None - if output_hidden_states: - hidden_states = outputs[1] - hidden_states = hidden_states[:-1] + (last_hidden_states,) - - if not return_dict: - outputs = (last_hidden_states, hidden_states) + (outputs[2:] if output_hidden_states else outputs[1:]) - return tuple(v for v in outputs if v is not None) - - return FlaxBaseModelOutputWithPastAndCrossAttentions( - last_hidden_state=last_hidden_states, - hidden_states=hidden_states, - attentions=outputs.attentions, - cross_attentions=outputs.cross_attentions, - ) - - -class FlaxWhisperModule(nn.Module): - config: WhisperConfig - dtype: jnp.dtype = jnp.float32 - param_dtype: jnp.dtype = jnp.float32 - precision: Optional[Union[str, lax.Precision]] = None - - def setup(self) -> None: - self.encoder = FlaxWhisperEncoder( - self.config, - dtype=self.dtype, - param_dtype=self.param_dtype, - precision=self.precision - ) - self.decoder = FlaxWhisperDecoder( - self.config, - dtype=self.dtype, - param_dtype=self.param_dtype, - precision=self.precision - ) - - self.causal_mask = make_causal_mask( - jnp.ones((1, max(self.config.max_source_positions, self.config.target_positions)), dtype="bool"), - dtype="bool" - ) - - def __call__( - self, - input_features: jnp.ndarray, - decoder_input_ids: jnp.ndarray, - decoder_attention_mask: jnp.ndarray, - decoder_position_ids: jnp.ndarray, - output_attentions: bool = False, - output_hidden_states: bool = False, - return_dict: bool = True, - deterministic: bool = True, - ): - encoder_outputs = self.encoder( - input_features, - causal_mask=self.causal_mask, - output_attentions=output_attentions, - output_hidden_states=output_hidden_states, - return_dict=return_dict, - deterministic=deterministic, - ) - - decoder_outputs = self.decoder( - input_ids=decoder_input_ids, - causal_mask=self.causal_mask, - attention_mask=decoder_attention_mask, - position_ids=decoder_position_ids, - encoder_hidden_states=encoder_outputs[0], - output_attentions=output_attentions, - output_hidden_states=output_hidden_states, - return_dict=return_dict, - deterministic=deterministic, - ) - - if not return_dict: - return decoder_outputs + encoder_outputs - - return FlaxSeq2SeqModelOutput( - last_hidden_state=decoder_outputs.last_hidden_state, - decoder_hidden_states=decoder_outputs.hidden_states, - decoder_attentions=decoder_outputs.attentions, - cross_attentions=decoder_outputs.cross_attentions, - encoder_last_hidden_state=encoder_outputs.last_hidden_state, - encoder_hidden_states=encoder_outputs.hidden_states, - encoder_attentions=encoder_outputs.attentions, - ) - - def _get_encoder_module(self): - return self.encoder - - def _get_decoder_module(self): - return self.decoder - - -class FlaxWhisperPreTrainedModel(EasyDeLFlaxPretrainedModel): - config_class = WhisperConfig - base_model_prefix: str = "model" - main_input_name = "input_features" - module_class: nn.Module = None - - def __init__( - self, - config: WhisperConfig, - input_shape: Tuple[int] = None, - seed: int = 0, - dtype: jnp.dtype = jnp.float32, - param_dtype: jnp.dtype = jnp.float32, - precision: Optional[Union[str, lax.Precision]] = None, - _do_init: bool = True, - **kwargs, - ): - module = self.module_class( - config=config, - dtype=dtype, - param_dtype=param_dtype, - precision=precision, - **kwargs - ) - if input_shape is None: - input_shape = (1, config.num_mel_bins, 2 * config.max_source_positions) - super().__init__(config, module, input_shape=input_shape, seed=seed, dtype=dtype, _do_init=_do_init) - - def enable_gradient_checkpointing(self): - self._module = self.module_class( - config=self.config, - dtype=self.dtype, - gradient_checkpointing=True, - ) - - def init_weights(self, rng: jax.random.PRNGKey, input_shape: Tuple, params: FrozenDict = None) -> FrozenDict: - # init input tensors - input_features = jnp.zeros(input_shape, dtype="f4") - input_features = input_features.at[(..., -1)].set(self.config.eos_token_id) - - decoder_input_ids = jnp.zeros((input_shape[0], 1), dtype="i4") - decoder_attention_mask = jnp.ones_like(decoder_input_ids) - - batch_size, sequence_length = decoder_input_ids.shape - decoder_position_ids = jnp.broadcast_to(jnp.arange(sequence_length)[None, :], (batch_size, sequence_length)) - - params_rng, dropout_rng = jax.random.split(rng) - rngs = {"params": params_rng, "dropout": dropout_rng} - - random_params = self.module.init( - rngs, - input_features=input_features, - decoder_input_ids=decoder_input_ids, - decoder_attention_mask=decoder_attention_mask, - decoder_position_ids=decoder_position_ids, - )["params"] - - if params is not None: - random_params = flatten_dict(unfreeze(random_params)) - params = flatten_dict(unfreeze(params)) - for missing_key in self._missing_keys: - params[missing_key] = random_params[missing_key] - self._missing_keys = set() - return freeze(unflatten_dict(params)) - else: - return random_params - - def init_cache(self, batch_size, max_length, encoder_outputs): - decoder_input_ids = jnp.ones((batch_size, max_length), dtype="i4") - decoder_attention_mask = jnp.ones_like(decoder_input_ids) - decoder_position_ids = jnp.broadcast_to( - jnp.arange(jnp.atleast_2d(decoder_input_ids).shape[-1]), decoder_input_ids.shape - ) - - def _decoder_forward( - module, - decoder_input_ids, - decoder_attention_mask, - decoder_position_ids, - **kwargs - ): - decoder_module = module._get_decoder_module() - return decoder_module( - decoder_input_ids, - decoder_attention_mask, - decoder_position_ids, - **kwargs, - ) - - init_variables = self.module.init( - jax.random.PRNGKey(0), - decoder_input_ids=decoder_input_ids, - decoder_attention_mask=decoder_attention_mask, - decoder_position_ids=decoder_position_ids, - encoder_hidden_states=encoder_outputs[0], - init_cache=True, - method=_decoder_forward, - ) - return unfreeze(init_variables["cache"]) - - def encode( - self, - input_features: jnp.ndarray, - attention_mask: Optional[jnp.ndarray] = None, - output_attentions: Optional[bool] = None, - output_hidden_states: Optional[bool] = None, - return_dict: Optional[bool] = None, - train: bool = False, - params: dict = None, - dropout_rng: PRNGKey = None, - add_params_field: bool = False, - **kwargs, - ): - output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions - output_hidden_states = ( - output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states - ) - return_dict = return_dict if return_dict is not None else self.config.return_dict - - rngs = {} - if dropout_rng is not None: - rngs["dropout"] = dropout_rng - - def _encoder_forward(module, input_features, **kwargs): - encode_module = module._get_encoder_module() - return encode_module(input_features, **kwargs) - - return self.module.apply( - {"params": params or self.params} if add_params_field else params or self.params, - input_features=jnp.array(input_features, dtype="f4"), - output_attentions=output_attentions, - output_hidden_states=output_hidden_states, - return_dict=return_dict, - deterministic=not train, - rngs=rngs, - method=_encoder_forward, - ) - - def decode( - self, - decoder_input_ids, - encoder_outputs, - encoder_attention_mask: Optional[jnp.ndarray] = None, - decoder_attention_mask: Optional[jnp.ndarray] = None, - decoder_position_ids: Optional[jnp.ndarray] = None, - past_key_values: dict = None, - output_attentions: Optional[bool] = None, - output_hidden_states: Optional[bool] = None, - return_dict: Optional[bool] = None, - train: bool = False, - params: dict = None, - dropout_rng: PRNGKey = None, - add_params_field: bool = False, - ): - - output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions - output_hidden_states = ( - output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states - ) - return_dict = return_dict if return_dict is not None else self.config.return_dict - - encoder_hidden_states = encoder_outputs[0] - - batch_size, sequence_length = decoder_input_ids.shape - if decoder_position_ids is None: - if past_key_values is not None: - raise ValueError("Make sure to provide `decoder_position_ids` when passing `past_key_values`.") - - if decoder_attention_mask is not None: - decoder_position_ids = (decoder_attention_mask.cumsum(-1) * decoder_attention_mask) - 1 - else: - decoder_position_ids = jnp.broadcast_to( - jnp.arange(sequence_length)[None, :], (batch_size, sequence_length) - ) - - if decoder_attention_mask is None: - decoder_attention_mask = jnp.ones((batch_size, sequence_length)) - - rngs = {} - if dropout_rng is not None: - rngs["dropout"] = dropout_rng - - inputs = {"params": params or self.params} if add_params_field else params or self.params - - if past_key_values: - inputs["cache"] = past_key_values - mutable = ["cache"] - else: - mutable = False - - def _decoder_forward(module, decoder_input_ids, decoder_attention_mask, decoder_position_ids, **kwargs): - decoder_module = module._get_decoder_module() - return decoder_module( - input_ids=decoder_input_ids, - attention_mask=decoder_attention_mask, - position_ids=decoder_position_ids, - **kwargs, - ) - - outputs = self.module.apply( - inputs, - decoder_input_ids=jnp.array(decoder_input_ids, dtype="i4"), - decoder_attention_mask=jnp.array(decoder_attention_mask, dtype="i4"), - decoder_position_ids=jnp.array(decoder_position_ids, dtype="i4"), - encoder_hidden_states=encoder_hidden_states, - output_attentions=output_attentions, - output_hidden_states=output_hidden_states, - return_dict=return_dict, - deterministic=not train, - rngs=rngs, - mutable=mutable, - method=_decoder_forward, - ) - - # add updated cache to model output - if past_key_values is not None and return_dict: - outputs, past = outputs - outputs["past_key_values"] = unfreeze(past["cache"]) - return outputs - elif past_key_values is not None and not return_dict: - outputs, past = outputs - outputs = outputs[:1] + (unfreeze(past["cache"]),) + outputs[1:] - - return outputs - - def __call__( - self, - input_features: jnp.ndarray, - decoder_input_ids: jnp.ndarray, - attention_mask: Optional[jnp.ndarray] = None, - decoder_attention_mask: Optional[jnp.ndarray] = None, - position_ids: Optional[jnp.ndarray] = None, - decoder_position_ids: Optional[jnp.ndarray] = None, - output_attentions: Optional[bool] = None, - output_hidden_states: Optional[bool] = None, - return_dict: Optional[bool] = None, - train: bool = False, - params: dict = None, - dropout_rng: PRNGKey = None, - extra_embedding: Optional[Union[jnp.ndarray, None]] = None, - add_params_field: bool = False, - ): - output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions - output_hidden_states = ( - output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states - ) - return_dict = return_dict if return_dict is not None else self.config.return_dict - - # prepare decoder inputs - if decoder_position_ids is None: - if decoder_attention_mask is not None: - decoder_position_ids = (decoder_attention_mask.cumsum(-1) * decoder_attention_mask) - 1 - else: - batch_size, sequence_length = decoder_input_ids.shape - decoder_position_ids = jnp.broadcast_to( - jnp.arange(sequence_length)[None, :], (batch_size, sequence_length) - ) - if decoder_attention_mask is None: - decoder_attention_mask = jnp.ones_like(decoder_input_ids) - - # Handle any PRNG if needed - rngs = {"dropout": dropout_rng} if dropout_rng is not None else {} - - return self.module.apply( - {"params": params or self.params} if add_params_field else params or self.params, - input_features=jnp.array(input_features, dtype="f4"), - decoder_input_ids=jnp.array(decoder_input_ids, dtype="i4"), - decoder_attention_mask=jnp.array(decoder_attention_mask, dtype="i4"), - decoder_position_ids=jnp.array(decoder_position_ids, dtype="i4"), - output_attentions=output_attentions, - output_hidden_states=output_hidden_states, - return_dict=return_dict, - deterministic=not train, - rngs=rngs, - ) - - -class FlaxWhisperModel(FlaxWhisperPreTrainedModel): - config: WhisperConfig - dtype: jnp.dtype = jnp.float32 - param_dtype: jnp.dtype = jnp.float32 - precision: Optional[Union[str, lax.Precision]] = None - module_class = FlaxWhisperModule - - -class FlaxWhisperForConditionalGenerationModule(nn.Module): - config: WhisperConfig - dtype: jnp.dtype = jnp.float32 - param_dtype: jnp.dtype = jnp.float32 - precision: Optional[Union[str, lax.Precision]] = None - - def setup(self) -> None: - self.model = FlaxWhisperModule( - config=self.config, - dtype=self.dtype, - param_dtype=self.param_dtype, - precision=self.precision - ) - self.lm_head = Linear( - self.config.vocab_size, - use_bias=False, - dtype=self.dtype, - param_dtype=self.param_dtype, - precision=self.precision, - kernel_init=jax.nn.initializers.normal(self.config.init_std), - **get_dot_general_by_bits(self.config.bits, self.config.easy_method) - ) - - def _get_encoder_module(self): - return self.model.encoder - - def _get_decoder_module(self): - return self.model.decoder - - def __call__( - self, - input_features, - decoder_input_ids, - decoder_attention_mask: jnp.ndarray = None, - decoder_position_ids: jnp.ndarray = None, - position_ids: jnp.ndarray = None, - attention_mask: jnp.ndarray = None, - output_attentions: bool = False, - output_hidden_states: bool = False, - return_dict: bool = True, - deterministic: bool = True, - ): - outputs = self.model( - input_features=input_features, - decoder_input_ids=decoder_input_ids, - decoder_attention_mask=decoder_attention_mask, - decoder_position_ids=decoder_position_ids, - output_attentions=output_attentions, - output_hidden_states=output_hidden_states, - return_dict=return_dict, - deterministic=deterministic, - ) - - hidden_states = outputs[0] - - if self.config.tie_word_embeddings: - shared_embedding = self.model.decoder.embed_tokens.variables["params"]["embedding"] - - shared_embedding = fjformer.linen.linen.control_quantization(shared_embedding, self.param_dtype).T - lm_logits = self.lm_head.apply({"params": {"kernel": shared_embedding}}, hidden_states) - else: - lm_logits = self.lm_head(hidden_states) - - if not return_dict: - output = (lm_logits,) + outputs[1:] - return output - - return FlaxSeq2SeqLMOutput( - logits=lm_logits, - decoder_hidden_states=outputs.decoder_hidden_states, - decoder_attentions=outputs.decoder_attentions, - cross_attentions=outputs.cross_attentions, - encoder_last_hidden_state=outputs.encoder_last_hidden_state, - encoder_hidden_states=outputs.encoder_hidden_states, - encoder_attentions=outputs.encoder_attentions, - ) - - -class FlaxWhisperForConditionalGeneration(FlaxWhisperPreTrainedModel): - module_class = FlaxWhisperForConditionalGenerationModule - dtype: jnp.dtype = jnp.float32 - param_dtype: jnp.dtype = jnp.float32 - precision: Optional[Union[str, lax.Precision]] = None - - def decode( - self, - decoder_input_ids, - encoder_outputs, - encoder_attention_mask: Optional[jnp.ndarray] = None, - decoder_attention_mask: Optional[jnp.ndarray] = None, - decoder_position_ids: Optional[jnp.ndarray] = None, - past_key_values: dict = None, - output_attentions: Optional[bool] = None, - output_hidden_states: Optional[bool] = None, - return_dict: Optional[bool] = None, - train: bool = False, - params: dict = None, - dropout_rng: PRNGKey = None, - add_params_field: Optional[bool] = False - ): - output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions - output_hidden_states = ( - output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states - ) - return_dict = return_dict if return_dict is not None else self.config.return_dict - - encoder_hidden_states = encoder_outputs[0] - - batch_size, sequence_length = decoder_input_ids.shape - if decoder_position_ids is None: - if past_key_values is not None: - raise ValueError("Make sure to provide `decoder_position_ids` when passing `past_key_values`.") - - if decoder_attention_mask is not None: - decoder_position_ids = (decoder_attention_mask.cumsum(-1) * decoder_attention_mask) - 1 - else: - decoder_position_ids = jnp.broadcast_to( - jnp.arange(sequence_length)[None, :], (batch_size, sequence_length) - ) - if decoder_attention_mask is None: - decoder_attention_mask = jnp.ones((batch_size, sequence_length), dtype="i4") - - # Handle any PRNG if needed - rngs = {} - if dropout_rng is not None: - rngs["dropout"] = dropout_rng - - inputs = {"params": params or self.params} if add_params_field else params or self.params - - if past_key_values: - inputs["cache"] = past_key_values - mutable = ["cache"] - else: - mutable = False - - def _decoder_forward(module, decoder_input_ids, decoder_attention_mask, decoder_position_ids, **kwargs): - decoder_module = module._get_decoder_module() - outputs = decoder_module( - input_ids=decoder_input_ids, - attention_mask=decoder_attention_mask, - position_ids=decoder_position_ids, - **kwargs, - ) - hidden_states = outputs[0] - - if self.config.tie_word_embeddings: - shared_embedding = module.model.decoder.embed_tokens.variables["params"]["embedding"] - - shared_embedding = fjformer.linen.linen.control_quantization(shared_embedding, self.param_dtype).T - lm_logits = module.lm_head.apply({"params": {"kernel": shared_embedding}}, hidden_states) - else: - lm_logits = module.lm_head(hidden_states) - - return lm_logits, outputs - - outputs = self.module.apply( - inputs, - decoder_input_ids=jnp.array(decoder_input_ids, dtype="i4"), - decoder_attention_mask=jnp.array(decoder_attention_mask, dtype="i4"), - decoder_position_ids=jnp.array(decoder_position_ids, dtype="i4"), - encoder_hidden_states=encoder_hidden_states, - output_attentions=output_attentions, - output_hidden_states=output_hidden_states, - return_dict=return_dict, - deterministic=not train, - rngs=rngs, - mutable=mutable, - method=_decoder_forward, - ) - - if past_key_values is None: - lm_logits, decoder_outputs = outputs - else: - (lm_logits, decoder_outputs), past = outputs - - if return_dict: - outputs = FlaxCausalLMOutputWithCrossAttentions( - logits=lm_logits, - hidden_states=decoder_outputs.hidden_states, - attentions=decoder_outputs.attentions, - cross_attentions=decoder_outputs.cross_attentions, - ) - else: - outputs = (lm_logits,) + decoder_outputs[1:] - - # add updated cache to model output - if past_key_values is not None and return_dict: - outputs["past_key_values"] = unfreeze(past["cache"]) - return outputs - elif past_key_values is not None and not return_dict: - outputs = outputs[:1] + (unfreeze(past["cache"]),) + outputs[1:] - - return outputs - - def generate( - self, - input_features, - generation_config=None, - logits_processor=None, - return_timestamps=None, - task=None, - language=None, - is_multilingual=None, - **kwargs, - ): - if generation_config is None: - generation_config = self.generation_config - - if return_timestamps is not None: - generation_config.return_timestamps = return_timestamps - - if task is not None: - generation_config.task = task - - if is_multilingual is not None: - generation_config.is_multilingual = is_multilingual - - if language is not None: - generation_config.language = language - - if kwargs is not None and "decoder_input_ids" in kwargs: - decoder_input_length = len(kwargs["decoder_input_ids"]) - else: - decoder_input_length = 1 - - forced_decoder_ids = [] - - if hasattr(generation_config, "is_multilingual") and generation_config.is_multilingual: - if hasattr(generation_config, "language"): - forced_decoder_ids.append((1, generation_config.lang_to_id[generation_config.language])) - else: - forced_decoder_ids.append((1, None)) - - if hasattr(generation_config, "task"): - forced_decoder_ids.append((2, generation_config.task_to_id[generation_config.task])) - else: - forced_decoder_ids.append((2, generation_config.task_to_id["transcribe"])) - - if ( - hasattr(generation_config, "return_timestamps") and generation_config.return_timestamps - ) or return_timestamps: - logits_processor = [ - FlaxWhisperTimeStampLogitsProcessor(generation_config, self.config, decoder_input_length) - ] - else: - if forced_decoder_ids and forced_decoder_ids[-1][0] != generation_config.no_timestamps_token_id: - idx = forced_decoder_ids[-1][0] + 1 if forced_decoder_ids else 1 - forced_decoder_ids.append((idx, generation_config.no_timestamps_token_id)) - - if len(forced_decoder_ids) > 0: - generation_config.forced_decoder_ids = forced_decoder_ids - - return super().generate( - input_features, - generation_config, - logits_processor=logits_processor, - **kwargs, - ) - - def prepare_inputs_for_generation( - self, - decoder_input_ids, - max_length, - attention_mask: Optional[jax.Array] = None, - decoder_attention_mask: Optional[jax.Array] = None, - encoder_outputs=None, - **kwargs, - ): - # initializing the cache - batch_size, seq_length = decoder_input_ids.shape - - past_key_values = self.init_cache(batch_size, max_length, encoder_outputs) - extended_attention_mask = jnp.ones((batch_size, max_length), dtype="i4") - if decoder_attention_mask is not None: - position_ids = decoder_attention_mask.cumsum(-1) - 1 - extended_attention_mask = lax.dynamic_update_slice(extended_attention_mask, decoder_attention_mask, (0, 0)) - else: - position_ids = jnp.broadcast_to(jnp.arange(seq_length, dtype="i4")[None, :], (batch_size, seq_length)) - - return { - "past_key_values": past_key_values, - "encoder_outputs": encoder_outputs, - "encoder_attention_mask": attention_mask, - "decoder_attention_mask": extended_attention_mask, - "decoder_position_ids": position_ids, - } - - def update_inputs_for_generation(self, model_outputs, model_kwargs): - model_kwargs["past_key_values"] = model_outputs.past_key_values - model_kwargs["decoder_position_ids"] = model_kwargs["decoder_position_ids"][:, -1:] + 1 - return model_kwargs - - -class FlaxWhisperForAudioClassificationModule(nn.Module): - config: WhisperConfig - dtype: jnp.dtype = jnp.float32 - param_dtype: jnp.dtype = jnp.float32 - precision: Optional[Union[str, lax.Precision]] = None - - def setup(self) -> None: - self.encoder = FlaxWhisperEncoder( - config=self.config, - dtype=self.dtype, - param_dtype=self.param_dtype, - precision=self.precision - ) - self.config.is_encoder_decoder = False - num_layers = self.config.num_hidden_layers + 1 - if self.config.use_weighted_layer_sum: - self.layer_weights = jnp.repeat(1 / num_layers, num_layers) - self.projector = Linear( - self.config.classifier_proj_size, - dtype=self.dtype, - param_dtype=self.param_dtype, - precision=self.precision, - **get_dot_general_by_bits(self.config.bits, self.config.easy_method) - ) - self.classifier = Linear( - self.config.num_labels, - dtype=self.dtype, - param_dtype=self.param_dtype, - precision=self.precision, - **get_dot_general_by_bits(self.config.bits, self.config.easy_method) - ) - - def __call__( - self, - input_features, - encoder_outputs=None, - output_attentions=None, - output_hidden_states: bool = True, - return_dict: bool = True, - ): - output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions - output_hidden_states = ( - output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states - ) - return_dict = return_dict if return_dict is not None else self.config.use_return_dict - - if encoder_outputs is None: - encoder_outputs = self.encoder( - input_features, - output_attentions=output_attentions, - output_hidden_states=output_hidden_states, - return_dict=return_dict, - ) - - if self.config.use_weighted_layer_sum: - hidden_states = jnp.stack(encoder_outputs, axis=1) - norm_weights = jax.nn.softmax(self.layer_weights, axis=-1) - hidden_states = jnp.sum(hidden_states * jnp.reshape(norm_weights, [-1, 1, 1]), axis=1) - else: - hidden_states = encoder_outputs[0] - - hidden_states = self.projector(hidden_states) - pooled_output = jnp.mean(hidden_states, axis=1) - - logits = self.classifier(pooled_output) - - if not return_dict: - return (logits,) + encoder_outputs[1:] - - return FlaxSequenceClassifierOutput( - logits=logits, - hidden_states=encoder_outputs.hidden_states, - attentions=encoder_outputs.attentions, - ) - - -class FlaxWhisperForAudioClassification(FlaxWhisperPreTrainedModel): - module_class = FlaxWhisperForAudioClassificationModule - dtype: jnp.dtype = jnp.float32 - param_dtype: jnp.dtype = jnp.float32 - precision: Optional[Union[str, lax.Precision]] = None - - def init_weights(self, rng: jax.random.PRNGKey, input_shape: Tuple, params: FrozenDict = None) -> FrozenDict: - # init input tensors - input_features = jnp.zeros(input_shape, dtype="f4") - input_features = input_features.at[(..., -1)].set(self.config.eos_token_id) - - params_rng, dropout_rng = jax.random.split(rng) - rngs = {"params": params_rng, "dropout": dropout_rng} - - random_params = self.module.init( - rngs, - input_features=input_features, - )["params"] - - if params is not None: - random_params = flatten_dict(unfreeze(random_params)) - params = flatten_dict(unfreeze(params)) - for missing_key in self._missing_keys: - params[missing_key] = random_params[missing_key] - self._missing_keys = set() - return freeze(unflatten_dict(params)) - else: - return random_params - - def __call__( - self, - input_features: jnp.ndarray, - attention_mask: Optional[jnp.ndarray] = None, - output_attentions: Optional[bool] = None, - output_hidden_states: Optional[bool] = None, - return_dict: Optional[bool] = None, - train: bool = False, - params: dict = None, - dropout_rng: PRNGKey = None, - add_params_field: Optional[bool] = False, - **kwargs, - ): - output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions - output_hidden_states = ( - output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states - ) - return_dict = return_dict if return_dict is not None else self.config.return_dict - - # Handle any PRNG if needed - rngs = {} - if dropout_rng is not None: - rngs["dropout"] = dropout_rng - - return self.module.apply( - {"params": params or self.params} if add_params_field else params or self.params, - input_features=jnp.array(input_features, dtype="f4"), - output_attentions=output_attentions, - output_hidden_states=output_hidden_states, - return_dict=return_dict, - rngs=rngs, - ) +import random +from functools import partial + +import fjformer +from flax.linen import make_causal_mask +from jax.random import PRNGKey +import math +from typing import Optional, Tuple, Union, Any +import flax.linen +from fjformer import linen as nn +import jax +import jax.numpy as jnp +from flax.core.frozen_dict import FrozenDict, freeze, unfreeze +from flax.linen import combine_masks +from flax.linen import partitioning as nn_partitioning +from flax.traverse_util import flatten_dict, unflatten_dict +from fjformer.linen import Linear +from jax import lax +from jax.sharding import PartitionSpec +from transformers import FlaxWhisperTimeStampLogitsProcessor +from transformers.modeling_flax_outputs import ( + FlaxBaseModelOutput, + FlaxSequenceClassifierOutput, + FlaxBaseModelOutputWithPastAndCrossAttentions, + FlaxSeq2SeqModelOutput, + FlaxSeq2SeqLMOutput, + FlaxCausalLMOutputWithCrossAttentions +) + +from .whisper_configuration import WhisperConfig +from ..attention_module import AttentionModule +from ..easydel_modelling_utils import EasyDeLFlaxPretrainedModel +# easydel.modules +from ..flax_modelling_utils import ( + with_sharding_constraint, + get_gradient_checkpoint_policy, + get_dot_general_by_bits, + BaseJAXAttentionModule, + ACT2FN +) + +remat = nn_partitioning.remat + + +def sinusoidal_embedding_init(key, shape, dtype=jnp.float_) -> jax.Array: + length, channels = shape + if channels % 2 != 0: + raise ValueError( + f"Number of channels has to be divisible by 2 for sinusoidal positional embeddings, got {channels} channels." + ) + log_timescale_increment = math.log(10000) / (channels // 2 - 1) + inv_timescales = jnp.exp(-log_timescale_increment * jnp.arange(channels // 2)) + scaled_time = jnp.arange(length).reshape(-1, 1) * inv_timescales.reshape(1, -1) + return jnp.concatenate([jnp.sin(scaled_time), jnp.cos(scaled_time)], axis=1).astype(dtype) + + +class FlaxWhisperAttention(BaseJAXAttentionModule): + config: WhisperConfig + embed_dim: int + num_heads: int + dropout: float = 0.0 + causal: bool = False + bias: bool = True + dtype: jnp.dtype = jnp.float32 + param_dtype: jnp.dtype = jnp.float32 + precision: Optional[Union[str, lax.Precision]] = None + + def setup(self) -> None: + self.head_dim = self.embed_dim // self.num_heads + if self.head_dim * self.num_heads != self.embed_dim: + raise ValueError( + f"embed_dim must be divisible by num_heads (got `embed_dim`: {self.embed_dim}" + f" and `num_heads`: {self.num_heads})." + ) + + dense = partial( + Linear, + self.embed_dim, + dtype=self.dtype, + param_dtype=self.param_dtype, + precision=self.precision, + kernel_init=jax.nn.initializers.normal(self.config.init_std), + **get_dot_general_by_bits(self.config.bits, self.config.easy_method) + ) + + self.q_proj = dense(use_bias=self.bias) + self.k_proj = dense(use_bias=False) + self.v_proj = dense(use_bias=self.bias) + self.out_proj = dense(use_bias=self.bias) + + self.attention_performer = AttentionModule( + use_sharding_constraint=self.config.use_sharding_constraint, + block_k_major=self.config.block_k_major, + block_b=self.config.block_b, + block_q=self.config.block_q, + block_k=self.config.block_k, + block_q_major_dkv=self.config.block_q_major_dkv, + block_k_major_dkv=self.config.block_k_major_dkv, + block_k_major_dq=self.config.block_k_major_dq, + block_k_dkv=self.config.block_k_dkv, + block_q_dkv=self.config.block_q_dkv, + block_q_dq=self.config.block_q_dq, + block_k_dq=self.config.block_k_dq, + num_attention_heads=self.config.num_attention_heads, + attention_dropout=self.config.attention_dropout, + head_dims=self.head_dim, + attention_partition_spec=self.config.attention_partition_spec, + shard_attention_computation=self.config.shard_attention_computation, + precision=self.precision, + force_float32_tpu=True, + attn_mechanism=self.config.attn_mechanism, + dtype=self.dtype, + bias_partition_spec=self.config.bias_partition_spec, + key_partition_spec=self.config.key_partition_spec, + query_partition_spec=self.config.query_partition_spec, + generation_query_partition_spec=self.config.generation_query_partition_spec, + generation_bias_partition_spec=self.config.generation_bias_partition_spec, + generation_attention_partition_spec=self.config.generation_attention_partition_spec, + value_partition_spec=self.config.value_partition_spec, + scan_ring_attention=self.config.scan_ring_attention, + mesh=self.config.jax_mesh(), + sm_scale=1 / math.sqrt(self.head_dim), + axis_name=self.config.attention_axis_name, + backward_pass_impl=self.config.flash_attention_backward_pass_impl + ) + + def __call__( + self, + hidden_states: jnp.ndarray, + key_value_states: Optional[jnp.ndarray] = None, + attention_mask: Optional[jnp.ndarray] = None, + causal_mask: Optional[jnp.ndarray] = None, + init_cache: bool = False, + deterministic: bool = True, + ) -> tuple[Any, Any]: + is_cross_attention = key_value_states is not None + batch_size = hidden_states.shape[0] + + query_states = self.q_proj(hidden_states) + + if is_cross_attention: + key_states = self.k_proj(key_value_states) + value_states = self.v_proj(key_value_states) + else: + key_states = self.k_proj(hidden_states) + value_states = self.v_proj(hidden_states) + + query_states = self._split_heads(query_states) + key_states = self._split_heads(key_states) + value_states = self._split_heads(value_states) + + if self.causal: + assert causal_mask is not None, "seems like you forgot to pass causal_mask" + query_length, key_length = query_states.shape[1], key_states.shape[1] + if self.has_variable("cache", "cached_key"): + mask_shift = self.variables["cache"]["cache_index"] + max_decoder_length = self.variables["cache"]["cached_key"].shape[1] + causal_mask = lax.dynamic_slice( + causal_mask, + (0, 0, mask_shift, 0), + (1, 1, query_length, max_decoder_length), + ) + else: + causal_mask = causal_mask[:, :, :query_length, :key_length] + causal_mask = jnp.broadcast_to(causal_mask, (batch_size,) + causal_mask.shape[1:]) + + # combine masks if needed + if attention_mask is not None and self.causal: + attention_mask = jnp.broadcast_to(jnp.expand_dims(attention_mask, axis=(-3, -2)), causal_mask.shape) + attention_mask = combine_masks(attention_mask, causal_mask) + elif self.causal: + attention_mask = causal_mask + elif attention_mask is not None: + attention_mask = jnp.expand_dims(attention_mask, axis=(-3, -2)) + + # During fast autoregressive decoding, we feed one position at a time, + # and cache the keys and values step by step. + + if self.causal and (self.has_variable("cache", "cached_key") or init_cache): + key_states, value_states, attention_mask = self._concatenate_to_cache( + key_states, value_states, query_states, attention_mask + ) + + query_length, key_length = query_states.shape[1], key_states.shape[1] + # Convert the boolean attention mask to an attention bias. + if attention_mask is not None: + # attention mask in the form of attention bias + attention_bias = lax.select( + attention_mask > 0, + jnp.full(attention_mask.shape, 0.0).astype(self.dtype), + jnp.full(attention_mask.shape, jnp.finfo(self.dtype).min).astype(self.dtype), + ) + else: + attention_bias = None + + dropout_rng = None + if not deterministic and self.dropout > 0.0: + dropout_rng = self.make_rng("dropout") + + attentions = self.attention_performer.__call__( + query_states=query_states, + key_states=key_states, + value_states=value_states, + bias=attention_bias, + attention_mask=attention_mask, + causal=False, + dropout_rng=dropout_rng, + deterministic=deterministic, + query_sequence_length=query_length, + key_value_sequence_length=key_length, + uses_cache=self.has_variable("cache", "cached_key") or init_cache, + segment_ids=None, + causal_mask=causal_mask + ) + + attn_output = self._merge_heads(attentions.attention_outputs) + if self.config.shard_attention_computation: + attn_output = with_sharding_constraint( + attn_output, PartitionSpec( + ("dp", "fsdp"), + "sp" if attn_output.shape[1] != 1 else None, + "tp" + ) + ) + attn_output = self.out_proj(attn_output) + + return attn_output, attentions.attention_outputs + + def _split_heads(self, hidden_state) -> jnp.ndarray: + return hidden_state.reshape(hidden_state.shape[:2] + (self.num_heads, self.head_dim)) + + def _merge_heads(self, hidden_state) -> jnp.ndarray: + return hidden_state.reshape(hidden_state.shape[:2] + (self.embed_dim,)) + + +class FlaxWhisperEncoderLayer(nn.Module): + config: WhisperConfig + dtype: jnp.dtype = jnp.float32 + param_dtype: jnp.dtype = jnp.float32 + precision: Optional[Union[str, lax.Precision]] = None + + def setup(self) -> None: + self.embed_dim = self.config.d_model + self.self_attn = FlaxWhisperAttention( + config=self.config, + embed_dim=self.embed_dim, + num_heads=self.config.encoder_attention_heads, + dropout=self.config.attention_dropout, + dtype=self.dtype, + param_dtype=self.param_dtype, + precision=self.precision + ) + self.self_attn_layer_norm = nn.LayerNorm(dtype=self.dtype, epsilon=1e-05) + self.dropout_layer = flax.linen.Dropout(rate=self.config.dropout) + self.activation_fn = ACT2FN[self.config.activation_function] + self.activation_dropout_layer = flax.linen.Dropout(rate=self.config.activation_dropout) + self.fc1 = Linear( + self.config.encoder_ffn_dim, + dtype=self.dtype, + param_dtype=self.param_dtype, + precision=self.precision, + kernel_init=jax.nn.initializers.normal(self.config.init_std), + **get_dot_general_by_bits(self.config.bits, self.config.easy_method) + ) + self.fc2 = Linear( + self.embed_dim, + dtype=self.dtype, + param_dtype=self.param_dtype, + precision=self.precision, + kernel_init=jax.nn.initializers.normal(self.config.init_std), + **get_dot_general_by_bits(self.config.bits, self.config.easy_method) + ) + self.final_layer_norm = nn.LayerNorm(dtype=self.dtype, epsilon=1e-05) + + def __call__( + self, + hidden_states: jnp.ndarray, + attention_mask: jnp.ndarray, + causal_mask: Optional[jnp.ndarray] = None, + output_attentions: bool = True, + deterministic: bool = True, + ) -> Tuple[jnp.ndarray]: + residual = hidden_states + hidden_states = self.self_attn_layer_norm(hidden_states) + hidden_states, attn_weights = self.self_attn( + hidden_states=hidden_states, + attention_mask=attention_mask, + causal_mask=causal_mask + ) + hidden_states = self.dropout_layer(hidden_states, deterministic=deterministic) + hidden_states = residual + hidden_states + + residual = hidden_states + hidden_states = self.final_layer_norm(hidden_states) + hidden_states = self.activation_fn(self.fc1(hidden_states)) + hidden_states = self.activation_dropout_layer(hidden_states, deterministic=deterministic) + hidden_states = self.fc2(hidden_states) + hidden_states = self.dropout_layer(hidden_states, deterministic=deterministic) + hidden_states = residual + hidden_states + + outputs = (hidden_states,) + + if output_attentions: + outputs += (attn_weights,) + + return outputs + + +class FlaxWhisperEncoderLayerCollection(nn.Module): + config: WhisperConfig + dtype: jnp.dtype = jnp.float32 + param_dtype: jnp.dtype = jnp.float32 + precision: Optional[Union[str, lax.Precision]] = None + + def setup(self): + block = FlaxWhisperEncoderLayer + if self.config.gradient_checkpointing != "": + block = remat( + block, + static_argnums=(2, 3, 4), + policy=get_gradient_checkpoint_policy(self.config.gradient_checkpointing) + ) + self.layers = [ + block(self.config, name=str(i), dtype=self.dtype) + for i in range(self.config.encoder_layers) + ] + self.layerdrop = self.config.encoder_layerdrop + + def __call__( + self, + hidden_states, + attention_mask, + causal_mask: Optional[jnp.ndarray] = None, + deterministic: bool = True, + output_attentions: bool = False, + output_hidden_states: bool = False, + return_dict: bool = True, + ): + all_attentions = () if output_attentions else None + all_hidden_states = () if output_hidden_states else None + + for encoder_layer in self.layers: + if output_hidden_states: + all_hidden_states = all_hidden_states + (hidden_states,) + dropout_probability = random.uniform(0, 1) + if not deterministic and (dropout_probability < self.layerdrop): # skip the layer + layer_outputs = (None, None) + else: + layer_outputs = encoder_layer( + hidden_states, + attention_mask, + causal_mask, + output_attentions, + deterministic, + ) + hidden_states = layer_outputs[0] + if output_attentions: + all_attentions = all_attentions + (layer_outputs[1],) + + if output_hidden_states: + all_hidden_states += (hidden_states,) + + outputs = (hidden_states, all_hidden_states, all_attentions) + + if not return_dict: + return tuple(v for v in outputs if v is not None) + + return FlaxBaseModelOutput( + last_hidden_state=hidden_states, hidden_states=all_hidden_states, attentions=all_attentions + ) + + +class FlaxWhisperDecoderLayer(nn.Module): + config: WhisperConfig + dtype: jnp.dtype = jnp.float32 + param_dtype: jnp.dtype = jnp.float32 + precision: Optional[Union[str, lax.Precision]] = None + + def setup(self) -> None: + self.embed_dim = self.config.d_model + self.self_attn = FlaxWhisperAttention( + config=self.config, + embed_dim=self.embed_dim, + num_heads=self.config.decoder_attention_heads, + dropout=self.config.attention_dropout, + causal=True, + dtype=self.dtype, + param_dtype=self.param_dtype, + precision=self.precision + ) + self.dropout_layer = flax.linen.Dropout(rate=self.config.dropout) + self.activation_fn = ACT2FN[self.config.activation_function] + self.activation_dropout_layer = flax.linen.Dropout(rate=self.config.activation_dropout) + + self.self_attn_layer_norm = nn.LayerNorm(dtype=self.dtype, epsilon=1e-05) + self.encoder_attn = FlaxWhisperAttention( + config=self.config, + embed_dim=self.embed_dim, + num_heads=self.config.decoder_attention_heads, + dropout=self.config.attention_dropout, + dtype=self.dtype, + param_dtype=self.param_dtype, + precision=self.precision + ) + self.encoder_attn_layer_norm = nn.LayerNorm(dtype=self.dtype, epsilon=1e-05) + self.fc1 = Linear( + self.config.decoder_ffn_dim, + dtype=self.dtype, + param_dtype=self.param_dtype, + precision=self.precision, + kernel_init=jax.nn.initializers.normal(self.config.init_std), + **get_dot_general_by_bits(self.config.bits, self.config.easy_method) + ) + self.fc2 = Linear( + self.embed_dim, + param_dtype=self.param_dtype, + precision=self.precision, + dtype=self.dtype, + kernel_init=jax.nn.initializers.normal(self.config.init_std), + **get_dot_general_by_bits(self.config.bits, self.config.easy_method) + ) + self.final_layer_norm = nn.LayerNorm(dtype=self.dtype, epsilon=1e-05) + + def __call__( + self, + hidden_states: jnp.ndarray, + attention_mask: jnp.ndarray, + causal_mask: Optional[jnp.ndarray] = None, + encoder_hidden_states: Optional[jnp.ndarray] = None, + encoder_attention_mask: Optional[jnp.ndarray] = None, + init_cache: bool = False, + output_attentions: bool = True, + deterministic: bool = True, + ) -> Tuple[jnp.ndarray]: + residual = hidden_states + hidden_states = self.self_attn_layer_norm(hidden_states) + + # Self Attention + hidden_states, self_attn_weights = self.self_attn( + hidden_states=hidden_states, + attention_mask=attention_mask, + causal_mask=causal_mask, + init_cache=init_cache + ) + hidden_states = self.dropout_layer(hidden_states, deterministic=deterministic) + hidden_states = residual + hidden_states + + # Cross-Attention Block + cross_attn_weights = None + if encoder_hidden_states is not None: + residual = hidden_states + + hidden_states = self.encoder_attn_layer_norm(hidden_states) + hidden_states, cross_attn_weights = self.encoder_attn( + hidden_states=hidden_states, + causal_mask=causal_mask, + key_value_states=encoder_hidden_states, + attention_mask=encoder_attention_mask, + ) + hidden_states = self.dropout_layer(hidden_states, deterministic=deterministic) + hidden_states = residual + hidden_states + + # Fully Connected + residual = hidden_states + hidden_states = self.final_layer_norm(hidden_states) + hidden_states = self.activation_fn(self.fc1(hidden_states)) + hidden_states = self.activation_dropout_layer(hidden_states, deterministic=deterministic) + hidden_states = self.fc2(hidden_states) + hidden_states = self.dropout_layer(hidden_states, deterministic=deterministic) + hidden_states = residual + hidden_states + + outputs = (hidden_states,) + + if output_attentions: + outputs += (self_attn_weights, cross_attn_weights) + + return outputs + + +class FlaxWhisperDecoderLayerCollection(nn.Module): + config: WhisperConfig + dtype: jnp.dtype = jnp.float32 + param_dtype: jnp.dtype = jnp.float32 + precision: Optional[Union[str, lax.Precision]] = None + + def setup(self): + + block = FlaxWhisperDecoderLayer + if self.config.gradient_checkpointing != "": + block = remat( + block, + static_argnums=(4, 5, 6, 7), + policy=get_gradient_checkpoint_policy(self.config.gradient_checkpointing) + ) + self.layers = [ + block(self.config, name=str(i), dtype=self.dtype) + for i in range(self.config.decoder_layers) + ] + + self.layerdrop = self.config.decoder_layerdrop + + def __call__( + self, + hidden_states, + attention_mask, + causal_mask: Optional[jnp.ndarray] = None, + encoder_hidden_states: Optional[jnp.ndarray] = None, + encoder_attention_mask: Optional[jnp.ndarray] = None, + deterministic: bool = True, + init_cache: bool = False, + output_attentions: bool = False, + output_hidden_states: bool = False, + return_dict: bool = True, + ): + # decoder layers + all_hidden_states = () if output_hidden_states else None + all_self_attns = () if output_attentions else None + all_cross_attentions = () if (output_attentions and encoder_hidden_states is not None) else None + + for decoder_layer in self.layers: + if output_hidden_states: + all_hidden_states += (hidden_states,) + # add LayerDrop (see https://arxiv.org/abs/1909.11556 for description) + dropout_probability = random.uniform(0, 1) + if not deterministic and (dropout_probability < self.layerdrop): + layer_outputs = (None, None, None) + else: + layer_outputs = decoder_layer( + hidden_states, + attention_mask, + causal_mask, + encoder_hidden_states, + encoder_attention_mask, + init_cache, + output_attentions, + deterministic, + ) + + hidden_states = layer_outputs[0] + if output_attentions: + all_self_attns += (layer_outputs[1],) + + if encoder_hidden_states is not None: + all_cross_attentions += (layer_outputs[2],) + + # add hidden states from the last decoder layer + if output_hidden_states: + all_hidden_states += (hidden_states,) + + outputs = [hidden_states, all_hidden_states, all_self_attns, all_cross_attentions] + + if not return_dict: + return tuple(v for v in outputs if v is not None) + + return FlaxBaseModelOutputWithPastAndCrossAttentions( + last_hidden_state=hidden_states, + hidden_states=all_hidden_states, + attentions=all_self_attns, + cross_attentions=all_cross_attentions, + ) + + +class FlaxWhisperEncoder(nn.Module): + config: WhisperConfig + dtype: jnp.dtype = jnp.float32 + param_dtype: jnp.dtype = jnp.float32 + precision: Optional[Union[str, lax.Precision]] = None + + def setup(self) -> None: + self.conv1 = nn.Conv( + self.config.d_model, + kernel_size=(3,), + padding=1, + kernel_init=jax.nn.initializers.normal(self.config.init_std), + dtype=self.dtype, + param_dtype=self.param_dtype, + precision=self.precision + ) + self.conv2 = nn.Conv( + self.config.d_model, + kernel_size=(3,), + strides=2, + padding=1, + kernel_init=jax.nn.initializers.normal(self.config.init_std), + dtype=self.dtype, + param_dtype=self.param_dtype, + precision=self.precision + ) + + self.dropout_layer = flax.linen.Dropout(rate=self.config.dropout) + + self.layers = FlaxWhisperEncoderLayerCollection( + self.config, + dtype=self.dtype, + param_dtype=self.param_dtype, + precision=self.precision + ) + + self.embed_positions = nn.Embed( + self.config.max_source_positions, + self.config.d_model, + dtype=self.dtype, + embedding_init=sinusoidal_embedding_init, + param_dtype=self.param_dtype, + ) + + self.layer_norm = nn.LayerNorm(dtype=self.dtype, epsilon=1e-05) + + def __call__( + self, + input_features: jnp.ndarray, + causal_mask: Optional[jnp.ndarray] = None, + output_attentions: bool = False, + output_hidden_states: bool = False, + return_dict: bool = True, + deterministic: bool = True, + ) -> tuple[Any | None, ...] | FlaxBaseModelOutput: + if input_features.shape[1:] != (self.config.num_mel_bins, self.config.max_source_positions * 2): + raise ValueError( + "input_features.shape[1:], must be equal to (self.config.num_mel_bins," + f" self.config.max_source_positions * 2) (got {input_features.shape[1:]}, but should be" + f" ({self.config.num_mel_bins}, {self.config.max_source_positions * 2}))" + ) + + input_features = input_features.transpose(0, 2, 1) + hidden_states = jax.nn.gelu(self.conv1(input_features), approximate=False) + hidden_states = jax.nn.gelu(self.conv2(hidden_states), approximate=False) + + embed_positions = self.embed_positions(jnp.arange(self.config.max_source_positions)) + # freeze the sinusoidal embeddings by stopping the back-prop + embed_positions = jax.lax.stop_gradient(embed_positions) + hidden_states = hidden_states + embed_positions + + hidden_states = self.dropout_layer(hidden_states, deterministic=deterministic) + + outputs = self.layers( + hidden_states, + causal_mask=causal_mask, + attention_mask=None, + deterministic=deterministic, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + ) + + last_hidden_states = outputs[0] + last_hidden_states = self.layer_norm(last_hidden_states) + + # update the last element in `hidden_states` after applying `layernorm` above + hidden_states = None + if output_hidden_states: + hidden_states = outputs[1] + hidden_states = hidden_states[:-1] + (last_hidden_states,) + + if not return_dict: + outputs = (last_hidden_states, hidden_states) + (outputs[2:] if output_hidden_states else outputs[1:]) + return tuple(v for v in outputs if v is not None) + + return FlaxBaseModelOutput( + last_hidden_state=last_hidden_states, + hidden_states=hidden_states, + attentions=outputs.attentions, + ) + + +class FlaxWhisperDecoder(nn.Module): + config: WhisperConfig + dtype: jnp.dtype = jnp.float32 + param_dtype: jnp.dtype = jnp.float32 + precision: Optional[Union[str, lax.Precision]] = None + + def setup(self) -> None: + self.embed_tokens = nn.Embed( + self.config.vocab_size, + self.config.d_model, + dtype=self.dtype, + param_dtype=self.param_dtype + ) + self.embed_positions = nn.Embed( + self.config.max_target_positions, + self.config.d_model, + dtype=self.dtype, + param_dtype=self.param_dtype + ) + + self.layers = FlaxWhisperDecoderLayerCollection( + self.config, + dtype=self.dtype, + param_dtype=self.param_dtype, + precision=self.precision + ) + + self.dropout_layer = flax.linen.Dropout(rate=self.config.dropout) + + self.layer_norm = nn.LayerNorm(dtype=self.dtype, epsilon=1e-5) + + def __call__( + self, + input_ids: jnp.ndarray, + attention_mask: jnp.ndarray, + position_ids: jnp.ndarray, + causal_mask: Optional[jnp.ndarray] = None, + encoder_hidden_states: Optional[jnp.ndarray] = None, + init_cache: bool = False, + output_attentions: bool = False, + output_hidden_states: bool = False, + return_dict: bool = True, + deterministic: bool = True, + ) -> tuple[Any, ...] | FlaxBaseModelOutputWithPastAndCrossAttentions: + input_embeds = self.embed_tokens(input_ids) + position_embeds = self.embed_positions(position_ids) + + hidden_states = input_embeds + position_embeds + hidden_states = self.dropout_layer(hidden_states, deterministic=deterministic) + + outputs = self.layers( + hidden_states, + attention_mask=attention_mask, + causal_mask=causal_mask, + encoder_hidden_states=encoder_hidden_states, + deterministic=deterministic, + init_cache=init_cache, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + ) + + last_hidden_states = outputs[0] + last_hidden_states = self.layer_norm(last_hidden_states) + + hidden_states = None + if output_hidden_states: + hidden_states = outputs[1] + hidden_states = hidden_states[:-1] + (last_hidden_states,) + + if not return_dict: + outputs = (last_hidden_states, hidden_states) + (outputs[2:] if output_hidden_states else outputs[1:]) + return tuple(v for v in outputs if v is not None) + + return FlaxBaseModelOutputWithPastAndCrossAttentions( + last_hidden_state=last_hidden_states, + hidden_states=hidden_states, + attentions=outputs.attentions, + cross_attentions=outputs.cross_attentions, + ) + + +class FlaxWhisperModule(nn.Module): + config: WhisperConfig + dtype: jnp.dtype = jnp.float32 + param_dtype: jnp.dtype = jnp.float32 + precision: Optional[Union[str, lax.Precision]] = None + + def setup(self) -> None: + self.encoder = FlaxWhisperEncoder( + self.config, + dtype=self.dtype, + param_dtype=self.param_dtype, + precision=self.precision + ) + self.decoder = FlaxWhisperDecoder( + self.config, + dtype=self.dtype, + param_dtype=self.param_dtype, + precision=self.precision + ) + + self.causal_mask = make_causal_mask( + jnp.ones((1, max(self.config.max_source_positions, self.config.target_positions)), dtype="bool"), + dtype="bool" + ) + + def __call__( + self, + input_features: jnp.ndarray, + decoder_input_ids: jnp.ndarray, + decoder_attention_mask: jnp.ndarray, + decoder_position_ids: jnp.ndarray, + output_attentions: bool = False, + output_hidden_states: bool = False, + return_dict: bool = True, + deterministic: bool = True, + ): + encoder_outputs = self.encoder( + input_features, + causal_mask=self.causal_mask, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + deterministic=deterministic, + ) + + decoder_outputs = self.decoder( + input_ids=decoder_input_ids, + causal_mask=self.causal_mask, + attention_mask=decoder_attention_mask, + position_ids=decoder_position_ids, + encoder_hidden_states=encoder_outputs[0], + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + deterministic=deterministic, + ) + + if not return_dict: + return decoder_outputs + encoder_outputs + + return FlaxSeq2SeqModelOutput( + last_hidden_state=decoder_outputs.last_hidden_state, + decoder_hidden_states=decoder_outputs.hidden_states, + decoder_attentions=decoder_outputs.attentions, + cross_attentions=decoder_outputs.cross_attentions, + encoder_last_hidden_state=encoder_outputs.last_hidden_state, + encoder_hidden_states=encoder_outputs.hidden_states, + encoder_attentions=encoder_outputs.attentions, + ) + + def _get_encoder_module(self): + return self.encoder + + def _get_decoder_module(self): + return self.decoder + + +class FlaxWhisperPreTrainedModel(EasyDeLFlaxPretrainedModel): + config_class = WhisperConfig + base_model_prefix: str = "model" + main_input_name = "input_features" + module_class: nn.Module = None + + def __init__( + self, + config: WhisperConfig, + input_shape: Tuple[int] = None, + seed: int = 0, + dtype: jnp.dtype = jnp.float32, + param_dtype: jnp.dtype = jnp.float32, + precision: Optional[Union[str, lax.Precision]] = None, + _do_init: bool = True, + **kwargs, + ): + module = self.module_class( + config=config, + dtype=dtype, + param_dtype=param_dtype, + precision=precision, + **kwargs + ) + if input_shape is None: + input_shape = (1, config.num_mel_bins, 2 * config.max_source_positions) + super().__init__(config, module, input_shape=input_shape, seed=seed, dtype=dtype, _do_init=_do_init) + + def enable_gradient_checkpointing(self): + self._module = self.module_class( + config=self.config, + dtype=self.dtype, + gradient_checkpointing=True, + ) + + def init_weights(self, rng: jax.random.PRNGKey, input_shape: Tuple, params: FrozenDict = None) -> FrozenDict: + # init input tensors + input_features = jnp.zeros(input_shape, dtype="f4") + input_features = input_features.at[(..., -1)].set(self.config.eos_token_id) + + decoder_input_ids = jnp.zeros((input_shape[0], 1), dtype="i4") + decoder_attention_mask = jnp.ones_like(decoder_input_ids) + + batch_size, sequence_length = decoder_input_ids.shape + decoder_position_ids = jnp.broadcast_to(jnp.arange(sequence_length)[None, :], (batch_size, sequence_length)) + + params_rng, dropout_rng = jax.random.split(rng) + rngs = {"params": params_rng, "dropout": dropout_rng} + + random_params = self.module.init( + rngs, + input_features=input_features, + decoder_input_ids=decoder_input_ids, + decoder_attention_mask=decoder_attention_mask, + decoder_position_ids=decoder_position_ids, + )["params"] + + if params is not None: + random_params = flatten_dict(unfreeze(random_params)) + params = flatten_dict(unfreeze(params)) + for missing_key in self._missing_keys: + params[missing_key] = random_params[missing_key] + self._missing_keys = set() + return freeze(unflatten_dict(params)) + else: + return random_params + + def init_cache(self, batch_size, max_length, encoder_outputs): + decoder_input_ids = jnp.ones((batch_size, max_length), dtype="i4") + decoder_attention_mask = jnp.ones_like(decoder_input_ids) + decoder_position_ids = jnp.broadcast_to( + jnp.arange(jnp.atleast_2d(decoder_input_ids).shape[-1]), decoder_input_ids.shape + ) + + def _decoder_forward( + module, + decoder_input_ids, + decoder_attention_mask, + decoder_position_ids, + **kwargs + ): + decoder_module = module._get_decoder_module() + return decoder_module( + decoder_input_ids, + decoder_attention_mask, + decoder_position_ids, + **kwargs, + ) + + init_variables = self.module.init( + jax.random.PRNGKey(0), + decoder_input_ids=decoder_input_ids, + decoder_attention_mask=decoder_attention_mask, + decoder_position_ids=decoder_position_ids, + encoder_hidden_states=encoder_outputs[0], + init_cache=True, + method=_decoder_forward, + ) + return unfreeze(init_variables["cache"]) + + def encode( + self, + input_features: jnp.ndarray, + attention_mask: Optional[jnp.ndarray] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + train: bool = False, + params: dict = None, + dropout_rng: PRNGKey = None, + add_params_field: bool = False, + **kwargs, + ): + output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions + output_hidden_states = ( + output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states + ) + return_dict = return_dict if return_dict is not None else self.config.return_dict + + rngs = {} + if dropout_rng is not None: + rngs["dropout"] = dropout_rng + + def _encoder_forward(module, input_features, **kwargs): + encode_module = module._get_encoder_module() + return encode_module(input_features, **kwargs) + + return self.module.apply( + {"params": params or self.params} if add_params_field else params or self.params, + input_features=jnp.array(input_features, dtype="f4"), + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + deterministic=not train, + rngs=rngs, + method=_encoder_forward, + ) + + def decode( + self, + decoder_input_ids, + encoder_outputs, + encoder_attention_mask: Optional[jnp.ndarray] = None, + decoder_attention_mask: Optional[jnp.ndarray] = None, + decoder_position_ids: Optional[jnp.ndarray] = None, + past_key_values: dict = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + train: bool = False, + params: dict = None, + dropout_rng: PRNGKey = None, + add_params_field: bool = False, + ): + + output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions + output_hidden_states = ( + output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states + ) + return_dict = return_dict if return_dict is not None else self.config.return_dict + + encoder_hidden_states = encoder_outputs[0] + + batch_size, sequence_length = decoder_input_ids.shape + if decoder_position_ids is None: + if past_key_values is not None: + raise ValueError("Make sure to provide `decoder_position_ids` when passing `past_key_values`.") + + if decoder_attention_mask is not None: + decoder_position_ids = (decoder_attention_mask.cumsum(-1) * decoder_attention_mask) - 1 + else: + decoder_position_ids = jnp.broadcast_to( + jnp.arange(sequence_length)[None, :], (batch_size, sequence_length) + ) + + if decoder_attention_mask is None: + decoder_attention_mask = jnp.ones((batch_size, sequence_length)) + + rngs = {} + if dropout_rng is not None: + rngs["dropout"] = dropout_rng + + inputs = {"params": params or self.params} if add_params_field else params or self.params + + if past_key_values: + inputs["cache"] = past_key_values + mutable = ["cache"] + else: + mutable = False + + def _decoder_forward(module, decoder_input_ids, decoder_attention_mask, decoder_position_ids, **kwargs): + decoder_module = module._get_decoder_module() + return decoder_module( + input_ids=decoder_input_ids, + attention_mask=decoder_attention_mask, + position_ids=decoder_position_ids, + **kwargs, + ) + + outputs = self.module.apply( + inputs, + decoder_input_ids=jnp.array(decoder_input_ids, dtype="i4"), + decoder_attention_mask=jnp.array(decoder_attention_mask, dtype="i4"), + decoder_position_ids=jnp.array(decoder_position_ids, dtype="i4"), + encoder_hidden_states=encoder_hidden_states, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + deterministic=not train, + rngs=rngs, + mutable=mutable, + method=_decoder_forward, + ) + + # add updated cache to model output + if past_key_values is not None and return_dict: + outputs, past = outputs + outputs["past_key_values"] = unfreeze(past["cache"]) + return outputs + elif past_key_values is not None and not return_dict: + outputs, past = outputs + outputs = outputs[:1] + (unfreeze(past["cache"]),) + outputs[1:] + + return outputs + + def __call__( + self, + input_features: jnp.ndarray, + decoder_input_ids: jnp.ndarray, + attention_mask: Optional[jnp.ndarray] = None, + decoder_attention_mask: Optional[jnp.ndarray] = None, + position_ids: Optional[jnp.ndarray] = None, + decoder_position_ids: Optional[jnp.ndarray] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + train: bool = False, + params: dict = None, + dropout_rng: PRNGKey = None, + extra_embedding: Optional[Union[jnp.ndarray, None]] = None, + add_params_field: bool = False, + ): + output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions + output_hidden_states = ( + output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states + ) + return_dict = return_dict if return_dict is not None else self.config.return_dict + + # prepare decoder inputs + if decoder_position_ids is None: + if decoder_attention_mask is not None: + decoder_position_ids = (decoder_attention_mask.cumsum(-1) * decoder_attention_mask) - 1 + else: + batch_size, sequence_length = decoder_input_ids.shape + decoder_position_ids = jnp.broadcast_to( + jnp.arange(sequence_length)[None, :], (batch_size, sequence_length) + ) + if decoder_attention_mask is None: + decoder_attention_mask = jnp.ones_like(decoder_input_ids) + + # Handle any PRNG if needed + rngs = {"dropout": dropout_rng} if dropout_rng is not None else {} + + return self.module.apply( + {"params": params or self.params} if add_params_field else params or self.params, + input_features=jnp.array(input_features, dtype="f4"), + decoder_input_ids=jnp.array(decoder_input_ids, dtype="i4"), + decoder_attention_mask=jnp.array(decoder_attention_mask, dtype="i4"), + decoder_position_ids=jnp.array(decoder_position_ids, dtype="i4"), + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + deterministic=not train, + rngs=rngs, + ) + + +class FlaxWhisperModel(FlaxWhisperPreTrainedModel): + config: WhisperConfig + dtype: jnp.dtype = jnp.float32 + param_dtype: jnp.dtype = jnp.float32 + precision: Optional[Union[str, lax.Precision]] = None + module_class = FlaxWhisperModule + + +class FlaxWhisperForConditionalGenerationModule(nn.Module): + config: WhisperConfig + dtype: jnp.dtype = jnp.float32 + param_dtype: jnp.dtype = jnp.float32 + precision: Optional[Union[str, lax.Precision]] = None + + def setup(self) -> None: + self.model = FlaxWhisperModule( + config=self.config, + dtype=self.dtype, + param_dtype=self.param_dtype, + precision=self.precision + ) + self.lm_head = Linear( + self.config.vocab_size, + use_bias=False, + dtype=self.dtype, + param_dtype=self.param_dtype, + precision=self.precision, + kernel_init=jax.nn.initializers.normal(self.config.init_std), + **get_dot_general_by_bits(self.config.bits, self.config.easy_method) + ) + + def _get_encoder_module(self): + return self.model.encoder + + def _get_decoder_module(self): + return self.model.decoder + + def __call__( + self, + input_features, + decoder_input_ids, + decoder_attention_mask: jnp.ndarray = None, + decoder_position_ids: jnp.ndarray = None, + position_ids: jnp.ndarray = None, + attention_mask: jnp.ndarray = None, + output_attentions: bool = False, + output_hidden_states: bool = False, + return_dict: bool = True, + deterministic: bool = True, + ): + outputs = self.model( + input_features=input_features, + decoder_input_ids=decoder_input_ids, + decoder_attention_mask=decoder_attention_mask, + decoder_position_ids=decoder_position_ids, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + deterministic=deterministic, + ) + + hidden_states = outputs[0] + + if self.config.tie_word_embeddings: + shared_embedding = self.model.decoder.embed_tokens.variables["params"]["embedding"] + + shared_embedding = fjformer.linen.linen.control_quantization(shared_embedding, self.param_dtype).T + lm_logits = self.lm_head.apply({"params": {"kernel": shared_embedding}}, hidden_states) + else: + lm_logits = self.lm_head(hidden_states) + + if not return_dict: + output = (lm_logits,) + outputs[1:] + return output + + return FlaxSeq2SeqLMOutput( + logits=lm_logits, + decoder_hidden_states=outputs.decoder_hidden_states, + decoder_attentions=outputs.decoder_attentions, + cross_attentions=outputs.cross_attentions, + encoder_last_hidden_state=outputs.encoder_last_hidden_state, + encoder_hidden_states=outputs.encoder_hidden_states, + encoder_attentions=outputs.encoder_attentions, + ) + + +class FlaxWhisperForConditionalGeneration(FlaxWhisperPreTrainedModel): + module_class = FlaxWhisperForConditionalGenerationModule + dtype: jnp.dtype = jnp.float32 + param_dtype: jnp.dtype = jnp.float32 + precision: Optional[Union[str, lax.Precision]] = None + + def decode( + self, + decoder_input_ids, + encoder_outputs, + encoder_attention_mask: Optional[jnp.ndarray] = None, + decoder_attention_mask: Optional[jnp.ndarray] = None, + decoder_position_ids: Optional[jnp.ndarray] = None, + past_key_values: dict = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + train: bool = False, + params: dict = None, + dropout_rng: PRNGKey = None, + add_params_field: Optional[bool] = False + ): + output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions + output_hidden_states = ( + output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states + ) + return_dict = return_dict if return_dict is not None else self.config.return_dict + + encoder_hidden_states = encoder_outputs[0] + + batch_size, sequence_length = decoder_input_ids.shape + if decoder_position_ids is None: + if past_key_values is not None: + raise ValueError("Make sure to provide `decoder_position_ids` when passing `past_key_values`.") + + if decoder_attention_mask is not None: + decoder_position_ids = (decoder_attention_mask.cumsum(-1) * decoder_attention_mask) - 1 + else: + decoder_position_ids = jnp.broadcast_to( + jnp.arange(sequence_length)[None, :], (batch_size, sequence_length) + ) + if decoder_attention_mask is None: + decoder_attention_mask = jnp.ones((batch_size, sequence_length), dtype="i4") + + # Handle any PRNG if needed + rngs = {} + if dropout_rng is not None: + rngs["dropout"] = dropout_rng + + inputs = {"params": params or self.params} if add_params_field else params or self.params + + if past_key_values: + inputs["cache"] = past_key_values + mutable = ["cache"] + else: + mutable = False + + def _decoder_forward(module, decoder_input_ids, decoder_attention_mask, decoder_position_ids, **kwargs): + decoder_module = module._get_decoder_module() + outputs = decoder_module( + input_ids=decoder_input_ids, + attention_mask=decoder_attention_mask, + position_ids=decoder_position_ids, + **kwargs, + ) + hidden_states = outputs[0] + + if self.config.tie_word_embeddings: + shared_embedding = module.model.decoder.embed_tokens.variables["params"]["embedding"] + + shared_embedding = fjformer.linen.linen.control_quantization(shared_embedding, self.param_dtype).T + lm_logits = module.lm_head.apply({"params": {"kernel": shared_embedding}}, hidden_states) + else: + lm_logits = module.lm_head(hidden_states) + + return lm_logits, outputs + + outputs = self.module.apply( + inputs, + decoder_input_ids=jnp.array(decoder_input_ids, dtype="i4"), + decoder_attention_mask=jnp.array(decoder_attention_mask, dtype="i4"), + decoder_position_ids=jnp.array(decoder_position_ids, dtype="i4"), + encoder_hidden_states=encoder_hidden_states, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + deterministic=not train, + rngs=rngs, + mutable=mutable, + method=_decoder_forward, + ) + + if past_key_values is None: + lm_logits, decoder_outputs = outputs + else: + (lm_logits, decoder_outputs), past = outputs + + if return_dict: + outputs = FlaxCausalLMOutputWithCrossAttentions( + logits=lm_logits, + hidden_states=decoder_outputs.hidden_states, + attentions=decoder_outputs.attentions, + cross_attentions=decoder_outputs.cross_attentions, + ) + else: + outputs = (lm_logits,) + decoder_outputs[1:] + + # add updated cache to model output + if past_key_values is not None and return_dict: + outputs["past_key_values"] = unfreeze(past["cache"]) + return outputs + elif past_key_values is not None and not return_dict: + outputs = outputs[:1] + (unfreeze(past["cache"]),) + outputs[1:] + + return outputs + + def generate( + self, + input_features, + generation_config=None, + logits_processor=None, + return_timestamps=None, + task=None, + language=None, + is_multilingual=None, + **kwargs, + ): + if generation_config is None: + generation_config = self.generation_config + + if return_timestamps is not None: + generation_config.return_timestamps = return_timestamps + + if task is not None: + generation_config.task = task + + if is_multilingual is not None: + generation_config.is_multilingual = is_multilingual + + if language is not None: + generation_config.language = language + + if kwargs is not None and "decoder_input_ids" in kwargs: + decoder_input_length = len(kwargs["decoder_input_ids"]) + else: + decoder_input_length = 1 + + forced_decoder_ids = [] + + if hasattr(generation_config, "is_multilingual") and generation_config.is_multilingual: + if hasattr(generation_config, "language"): + forced_decoder_ids.append((1, generation_config.lang_to_id[generation_config.language])) + else: + forced_decoder_ids.append((1, None)) + + if hasattr(generation_config, "task"): + forced_decoder_ids.append((2, generation_config.task_to_id[generation_config.task])) + else: + forced_decoder_ids.append((2, generation_config.task_to_id["transcribe"])) + + if ( + hasattr(generation_config, "return_timestamps") and generation_config.return_timestamps + ) or return_timestamps: + logits_processor = [ + FlaxWhisperTimeStampLogitsProcessor(generation_config, self.config, decoder_input_length) + ] + else: + if forced_decoder_ids and forced_decoder_ids[-1][0] != generation_config.no_timestamps_token_id: + idx = forced_decoder_ids[-1][0] + 1 if forced_decoder_ids else 1 + forced_decoder_ids.append((idx, generation_config.no_timestamps_token_id)) + + if len(forced_decoder_ids) > 0: + generation_config.forced_decoder_ids = forced_decoder_ids + + return super().generate( + input_features, + generation_config, + logits_processor=logits_processor, + **kwargs, + ) + + def prepare_inputs_for_generation( + self, + decoder_input_ids, + max_length, + attention_mask: Optional[jax.Array] = None, + decoder_attention_mask: Optional[jax.Array] = None, + encoder_outputs=None, + **kwargs, + ): + # initializing the cache + batch_size, seq_length = decoder_input_ids.shape + + past_key_values = self.init_cache(batch_size, max_length, encoder_outputs) + extended_attention_mask = jnp.ones((batch_size, max_length), dtype="i4") + if decoder_attention_mask is not None: + position_ids = decoder_attention_mask.cumsum(-1) - 1 + extended_attention_mask = lax.dynamic_update_slice(extended_attention_mask, decoder_attention_mask, (0, 0)) + else: + position_ids = jnp.broadcast_to(jnp.arange(seq_length, dtype="i4")[None, :], (batch_size, seq_length)) + + return { + "past_key_values": past_key_values, + "encoder_outputs": encoder_outputs, + "encoder_attention_mask": attention_mask, + "decoder_attention_mask": extended_attention_mask, + "decoder_position_ids": position_ids, + } + + def update_inputs_for_generation(self, model_outputs, model_kwargs): + model_kwargs["past_key_values"] = model_outputs.past_key_values + model_kwargs["decoder_position_ids"] = model_kwargs["decoder_position_ids"][:, -1:] + 1 + return model_kwargs + + +class FlaxWhisperForAudioClassificationModule(nn.Module): + config: WhisperConfig + dtype: jnp.dtype = jnp.float32 + param_dtype: jnp.dtype = jnp.float32 + precision: Optional[Union[str, lax.Precision]] = None + + def setup(self) -> None: + self.encoder = FlaxWhisperEncoder( + config=self.config, + dtype=self.dtype, + param_dtype=self.param_dtype, + precision=self.precision + ) + self.config.is_encoder_decoder = False + num_layers = self.config.num_hidden_layers + 1 + if self.config.use_weighted_layer_sum: + self.layer_weights = jnp.repeat(1 / num_layers, num_layers) + self.projector = Linear( + self.config.classifier_proj_size, + dtype=self.dtype, + param_dtype=self.param_dtype, + precision=self.precision, + **get_dot_general_by_bits(self.config.bits, self.config.easy_method) + ) + self.classifier = Linear( + self.config.num_labels, + dtype=self.dtype, + param_dtype=self.param_dtype, + precision=self.precision, + **get_dot_general_by_bits(self.config.bits, self.config.easy_method) + ) + + def __call__( + self, + input_features, + encoder_outputs=None, + output_attentions=None, + output_hidden_states: bool = True, + return_dict: bool = True, + ): + output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions + output_hidden_states = ( + output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states + ) + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + if encoder_outputs is None: + encoder_outputs = self.encoder( + input_features, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + ) + + if self.config.use_weighted_layer_sum: + hidden_states = jnp.stack(encoder_outputs, axis=1) + norm_weights = jax.nn.softmax(self.layer_weights, axis=-1) + hidden_states = jnp.sum(hidden_states * jnp.reshape(norm_weights, [-1, 1, 1]), axis=1) + else: + hidden_states = encoder_outputs[0] + + hidden_states = self.projector(hidden_states) + pooled_output = jnp.mean(hidden_states, axis=1) + + logits = self.classifier(pooled_output) + + if not return_dict: + return (logits,) + encoder_outputs[1:] + + return FlaxSequenceClassifierOutput( + logits=logits, + hidden_states=encoder_outputs.hidden_states, + attentions=encoder_outputs.attentions, + ) + + +class FlaxWhisperForAudioClassification(FlaxWhisperPreTrainedModel): + module_class = FlaxWhisperForAudioClassificationModule + dtype: jnp.dtype = jnp.float32 + param_dtype: jnp.dtype = jnp.float32 + precision: Optional[Union[str, lax.Precision]] = None + + def init_weights(self, rng: jax.random.PRNGKey, input_shape: Tuple, params: FrozenDict = None) -> FrozenDict: + # init input tensors + input_features = jnp.zeros(input_shape, dtype="f4") + input_features = input_features.at[(..., -1)].set(self.config.eos_token_id) + + params_rng, dropout_rng = jax.random.split(rng) + rngs = {"params": params_rng, "dropout": dropout_rng} + + random_params = self.module.init( + rngs, + input_features=input_features, + )["params"] + + if params is not None: + random_params = flatten_dict(unfreeze(random_params)) + params = flatten_dict(unfreeze(params)) + for missing_key in self._missing_keys: + params[missing_key] = random_params[missing_key] + self._missing_keys = set() + return freeze(unflatten_dict(params)) + else: + return random_params + + def __call__( + self, + input_features: jnp.ndarray, + attention_mask: Optional[jnp.ndarray] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + train: bool = False, + params: dict = None, + dropout_rng: PRNGKey = None, + add_params_field: Optional[bool] = False, + **kwargs, + ): + output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions + output_hidden_states = ( + output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states + ) + return_dict = return_dict if return_dict is not None else self.config.return_dict + + # Handle any PRNG if needed + rngs = {} + if dropout_rng is not None: + rngs["dropout"] = dropout_rng + + return self.module.apply( + {"params": params or self.params} if add_params_field else params or self.params, + input_features=jnp.array(input_features, dtype="f4"), + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + rngs=rngs, + ) diff --git a/src/python/easydel/modules/whisper/whisper_configuration.py b/src/python/easydel/modules/whisper/whisper_configuration.py index 87e7aacb4..cb688b830 100644 --- a/src/python/easydel/modules/whisper/whisper_configuration.py +++ b/src/python/easydel/modules/whisper/whisper_configuration.py @@ -1,119 +1,119 @@ -import math -from typing import Sequence, Optional - -from jax.sharding import PartitionSpec - -from ..easydel_modelling_utils import EasyDeLPretrainedConfig - - -class WhisperConfig(EasyDeLPretrainedConfig): - model_type: str = "whisper" - attribute_map = { - "num_attention_heads": "encoder_attention_heads", - "hidden_size": "d_model" - } - - def __init__( - self, - vocab_size=51865, - num_mel_bins=80, - encoder_layers=4, - encoder_attention_heads=6, - decoder_layers=4, - decoder_attention_heads=6, - decoder_ffn_dim=1536, - encoder_ffn_dim=1536, - encoder_layerdrop=0.0, - decoder_layerdrop=0.0, - decoder_start_token_id=50257, - use_cache=True, - is_encoder_decoder=True, - activation_function="gelu", - d_model=384, - dropout=0.0, - attention_dropout=0.0, - activation_dropout=0.0, - init_std=0.02, - scale_embedding=False, - max_source_positions=1500, - max_target_positions=448, - pad_token_id=50256, - bos_token_id=50256, - eos_token_id=50256, - suppress_tokens=None, - begin_suppress_tokens=[220, 50256], - use_weighted_layer_sum=False, - classifier_proj_size=256, - apply_spec_augment=False, - mask_time_prob=0.05, - mask_time_length=10, - mask_time_min_masks=2, - mask_feature_prob=0.0, - mask_feature_length=10, - mask_feature_min_masks=0, - median_filter_width=7, - bits: Optional[int] = None, - gradient_checkpointing: str = "nothing_saveable", - **kwargs, - ): - self.vocab_size = vocab_size - self.num_mel_bins = num_mel_bins - self.d_model = d_model - self.encoder_layers = encoder_layers - self.encoder_attention_heads = encoder_attention_heads - self.decoder_layers = decoder_layers - self.decoder_attention_heads = decoder_attention_heads - self.decoder_ffn_dim = decoder_ffn_dim - self.encoder_ffn_dim = encoder_ffn_dim - self.dropout = dropout - self.attention_dropout = attention_dropout - self.activation_dropout = activation_dropout - self.activation_function = activation_function - self.init_std = init_std - self.encoder_layerdrop = encoder_layerdrop - self.decoder_layerdrop = decoder_layerdrop - self.use_cache = use_cache - self.num_hidden_layers = encoder_layers - self.scale_embedding = scale_embedding # scale factor will be sqrt(d_model) if True - self.max_source_positions = max_source_positions - self.max_target_positions = max_target_positions - - # Audio Classification-specific parameters. Feel free to ignore for other classes. - self.classifier_proj_size = classifier_proj_size - self.use_weighted_layer_sum = use_weighted_layer_sum - - # fine-tuning config parameters for SpecAugment: https://arxiv.org/abs/1904.08779 - self.apply_spec_augment = apply_spec_augment - self.mask_time_prob = mask_time_prob - self.mask_time_length = mask_time_length - self.mask_time_min_masks = mask_time_min_masks - self.mask_feature_prob = mask_feature_prob - self.mask_feature_length = mask_feature_length - self.mask_feature_min_masks = mask_feature_min_masks - - self.median_filter_width = median_filter_width - self.bits = bits - self.gradient_checkpointing = gradient_checkpointing - - super().__init__( - pad_token_id=pad_token_id, - bos_token_id=bos_token_id, - eos_token_id=eos_token_id, - is_encoder_decoder=is_encoder_decoder, - decoder_start_token_id=decoder_start_token_id, - suppress_tokens=suppress_tokens, - begin_suppress_tokens=begin_suppress_tokens, - **kwargs, - ) - - def add_jax_args( - self, - bits: Optional[int] = None, - gradient_checkpointing: str = "nothing_saveable", - **kwargs - ): - self.bits = bits - self.gradient_checkpointing = gradient_checkpointing - for k, v in kwargs.items(): - if not hasattr(self, k): - setattr(self, k, v) +import math +from typing import Sequence, Optional + +from jax.sharding import PartitionSpec + +from ..easydel_modelling_utils import EasyDeLPretrainedConfig + + +class WhisperConfig(EasyDeLPretrainedConfig): + model_type: str = "whisper" + attribute_map = { + "num_attention_heads": "encoder_attention_heads", + "hidden_size": "d_model" + } + + def __init__( + self, + vocab_size=51865, + num_mel_bins=80, + encoder_layers=4, + encoder_attention_heads=6, + decoder_layers=4, + decoder_attention_heads=6, + decoder_ffn_dim=1536, + encoder_ffn_dim=1536, + encoder_layerdrop=0.0, + decoder_layerdrop=0.0, + decoder_start_token_id=50257, + use_cache=True, + is_encoder_decoder=True, + activation_function="gelu", + d_model=384, + dropout=0.0, + attention_dropout=0.0, + activation_dropout=0.0, + init_std=0.02, + scale_embedding=False, + max_source_positions=1500, + max_target_positions=448, + pad_token_id=50256, + bos_token_id=50256, + eos_token_id=50256, + suppress_tokens=None, + begin_suppress_tokens=[220, 50256], + use_weighted_layer_sum=False, + classifier_proj_size=256, + apply_spec_augment=False, + mask_time_prob=0.05, + mask_time_length=10, + mask_time_min_masks=2, + mask_feature_prob=0.0, + mask_feature_length=10, + mask_feature_min_masks=0, + median_filter_width=7, + bits: Optional[int] = None, + gradient_checkpointing: str = "nothing_saveable", + **kwargs, + ): + self.vocab_size = vocab_size + self.num_mel_bins = num_mel_bins + self.d_model = d_model + self.encoder_layers = encoder_layers + self.encoder_attention_heads = encoder_attention_heads + self.decoder_layers = decoder_layers + self.decoder_attention_heads = decoder_attention_heads + self.decoder_ffn_dim = decoder_ffn_dim + self.encoder_ffn_dim = encoder_ffn_dim + self.dropout = dropout + self.attention_dropout = attention_dropout + self.activation_dropout = activation_dropout + self.activation_function = activation_function + self.init_std = init_std + self.encoder_layerdrop = encoder_layerdrop + self.decoder_layerdrop = decoder_layerdrop + self.use_cache = use_cache + self.num_hidden_layers = encoder_layers + self.scale_embedding = scale_embedding # scale factor will be sqrt(d_model) if True + self.max_source_positions = max_source_positions + self.max_target_positions = max_target_positions + + # Audio Classification-specific parameters. Feel free to ignore for other classes. + self.classifier_proj_size = classifier_proj_size + self.use_weighted_layer_sum = use_weighted_layer_sum + + # fine-tuning config parameters for SpecAugment: https://arxiv.org/abs/1904.08779 + self.apply_spec_augment = apply_spec_augment + self.mask_time_prob = mask_time_prob + self.mask_time_length = mask_time_length + self.mask_time_min_masks = mask_time_min_masks + self.mask_feature_prob = mask_feature_prob + self.mask_feature_length = mask_feature_length + self.mask_feature_min_masks = mask_feature_min_masks + + self.median_filter_width = median_filter_width + self.bits = bits + self.gradient_checkpointing = gradient_checkpointing + + super().__init__( + pad_token_id=pad_token_id, + bos_token_id=bos_token_id, + eos_token_id=eos_token_id, + is_encoder_decoder=is_encoder_decoder, + decoder_start_token_id=decoder_start_token_id, + suppress_tokens=suppress_tokens, + begin_suppress_tokens=begin_suppress_tokens, + **kwargs, + ) + + def add_jax_args( + self, + bits: Optional[int] = None, + gradient_checkpointing: str = "nothing_saveable", + **kwargs + ): + self.bits = bits + self.gradient_checkpointing = gradient_checkpointing + for k, v in kwargs.items(): + if not hasattr(self, k): + setattr(self, k, v) diff --git a/src/python/easydel/partitioning/partitioner.py b/src/python/easydel/partitioning/partitioner.py index c93181930..9e199c23e 100644 --- a/src/python/easydel/partitioning/partitioner.py +++ b/src/python/easydel/partitioning/partitioner.py @@ -17,17 +17,20 @@ def get_partitions( jax_attn_format: bool = True, fsdp_on_batch: bool = True ) -> EasyDeLPartitions: - """ - The get_partitions function is a helper function that returns an EasyDeLPartitions object. + """The get_partitions function is a helper function that returns an EasyDeLPartitions object. The EasyDeLPartitions object contains the PartitionSpec objects for each of the five tensors in the attention computation: query, key, value, bias and attention. The PartitionSpec objects are used to specify how each tensor should be partitioned across devices (i.e., which dimensions of each tensor should be split across devices). For example, if we want to split the batch dimension of all five tensors across two devices then we would set ``query_partition_spec=key_partition_spec=value_partition_spec= - :param jax_attn_format: bool: Specify whether the attention - :param fsdp_on_batch: bool: Determine whether the batch dimension is partitioned - :return: A easydelpartitions object + Args: + jax_attn_format: bool: Specify whether the attention + fsdp_on_batch: bool: Determine whether the batch dimension is + partitioned + + Returns: + A easydelpartitions object """ if jax_attn_format: if fsdp_on_batch: diff --git a/src/python/easydel/reinforcement_learning/__init__.py b/src/python/easydel/reinforcement_learning/__init__.py index d675e74da..b0b3f53ba 100644 --- a/src/python/easydel/reinforcement_learning/__init__.py +++ b/src/python/easydel/reinforcement_learning/__init__.py @@ -1,5 +1,4 @@ -"""Using This Feature is not recommended since it's not fully completed -""" +"""Using This Feature is not recommended since it's not fully completed""" from .models import AutoRLModelForCasualLMWithValueHead __all__ = ( diff --git a/src/python/easydel/reinforcement_learning/core.py b/src/python/easydel/reinforcement_learning/core.py index fdf017449..f8ee5839e 100644 --- a/src/python/easydel/reinforcement_learning/core.py +++ b/src/python/easydel/reinforcement_learning/core.py @@ -189,8 +189,7 @@ def build_bert_batch_from_txt(text_list, tokenizer, device): def multinomial(logits, num_samples: int, replacement: bool = False): - """ - Implements the `torch.multinomial` function in JAX. + """Implements the `torch.multinomial` function in JAX. Args: logits (jnp.array): The unnormalized log probabilities of the events. diff --git a/src/python/easydel/reinforcement_learning/models/modelling_casual_language_rl.py b/src/python/easydel/reinforcement_learning/models/modelling_casual_language_rl.py index 20a61e2b1..be81d92a6 100644 --- a/src/python/easydel/reinforcement_learning/models/modelling_casual_language_rl.py +++ b/src/python/easydel/reinforcement_learning/models/modelling_casual_language_rl.py @@ -39,12 +39,13 @@ class ValueHead(nn.Module): kernel_init: Callable = nn.initializers.orthogonal() def setup(self): - """ - The setup function is called by the model's constructor. + """The setup function is called by the model's constructor. It initializes all the layers in your model, and assigns them to member variables. The setup function should be used for any initialization that needs to happen before running forward(). This includes things like loading weights from a file, or setting up an optimizer. - :param self: Represent the instance of the class + + Args: + self: Represent the instance of the class """ self.dropout = flax.linen.Dropout(self.summary_dropout_prob) @@ -58,16 +59,18 @@ def setup(self): ) def __call__(self, hidden_states: chex.Array, deterministic: bool = True): - """ - The __call__ function is the main function of a class. + """The __call__ function is the main function of a class. It is called when an instance of the class (an object) is invoked as a function, e.g., x(arg). The __call__ method enables instances of a class to be called like standard Python functions. - :param self: Represent the instance of the class - :param hidden_states: chex.Array: Pass the hidden states of the previous layer - :param deterministic: bool: Determine whether to use dropout - :return: A tensor of shape (batch_size, num_classes) + Args: + self: Represent the instance of the class + hidden_states: chex.Array: Pass the hidden states of the + previous layer + deterministic: bool: Determine whether to use dropout + Returns: + A tensor of shape (batch_size, num_classes) """ return self.summary(self.dropout(hidden_states, deterministic=deterministic)) diff --git a/src/python/easydel/reinforcement_learning/trainer/__init__.py b/src/python/easydel/reinforcement_learning/trainer/__init__.py index d3f5a12fa..8b1378917 100644 --- a/src/python/easydel/reinforcement_learning/trainer/__init__.py +++ b/src/python/easydel/reinforcement_learning/trainer/__init__.py @@ -1 +1 @@ - + diff --git a/src/python/easydel/reinforcement_learning/trainer/partitioner_config.py b/src/python/easydel/reinforcement_learning/trainer/partitioner_config.py index 4b3b74802..68f04247b 100644 --- a/src/python/easydel/reinforcement_learning/trainer/partitioner_config.py +++ b/src/python/easydel/reinforcement_learning/trainer/partitioner_config.py @@ -1,26 +1,26 @@ -from typing import Optional, Sequence - -import jax -from jax.sharding import Sharding, Mesh, PartitionSpec - - -class PartitionerConfig: - def __init__( - self, - axis_dims: Sequence[int] = (1, -1, 1, 1), - axis_names: Sequence[str] = ("dp", "fsdp", "tp", "sp"), - backend: Optional[None] = jax.default_backend(), - input_ids_partition_spec: PartitionSpec = PartitionSpec("dp", "fsdp"), - attention_mask_partition_spec: PartitionSpec = PartitionSpec("dp", "fsdp"), - ): - self.axis_dims = axis_dims - self.axis_names = axis_names - self.backend = backend - self.input_ids_partition_spec = input_ids_partition_spec - self.attention_mask_partition_spec = attention_mask_partition_spec - - def __str__(self): - return self.__repr__() - - def __repr__(self): - return self.__class__.__name__ + "(" + "".join("\n\t" + k for k, v in self.__dict__.items()) + "\n)" +from typing import Optional, Sequence + +import jax +from jax.sharding import Sharding, Mesh, PartitionSpec + + +class PartitionerConfig: + def __init__( + self, + axis_dims: Sequence[int] = (1, -1, 1, 1), + axis_names: Sequence[str] = ("dp", "fsdp", "tp", "sp"), + backend: Optional[None] = jax.default_backend(), + input_ids_partition_spec: PartitionSpec = PartitionSpec("dp", "fsdp"), + attention_mask_partition_spec: PartitionSpec = PartitionSpec("dp", "fsdp"), + ): + self.axis_dims = axis_dims + self.axis_names = axis_names + self.backend = backend + self.input_ids_partition_spec = input_ids_partition_spec + self.attention_mask_partition_spec = attention_mask_partition_spec + + def __str__(self): + return self.__repr__() + + def __repr__(self): + return self.__class__.__name__ + "(" + "".join("\n\t" + k for k, v in self.__dict__.items()) + "\n)" diff --git a/src/python/easydel/reinforcement_learning/trainer/ppo_config.py b/src/python/easydel/reinforcement_learning/trainer/ppo_config.py index b3e3d546a..ccb199263 100644 --- a/src/python/easydel/reinforcement_learning/trainer/ppo_config.py +++ b/src/python/easydel/reinforcement_learning/trainer/ppo_config.py @@ -50,50 +50,78 @@ def __init__( extra_optimizer_kwargs: dict | None = None, weight_decay: Optional[float] = 0.01, ): - """ - Configuration class for PPOTrainer - :param exp_name: str : the name of this experiment (by default is the file name without the extension name) - :param seed: int :Seed value for random generations - :param task_name: Optional[str] : Name of task to use - used only for tracking purposes - :param model_name: Optional[str] :Name of model to use - used only for tracking purposes - :param query_dataset: Optional[str] :Name of dataset to query - used only for tracking purposes - :param reward_model: Optional[str] :The reward model to use - used only for tracking purposes - :param remove_unused_columns: bool : Remove unused columns from the dataset if `datasets.Dataset` is used - :param tracker_kwargs: Optional[dict] : Keyword arguments for the tracker - :param accelerator_kwargs: Optional[dict] :Keyword arguments for the accelerator - :param project_kwargs: Optional[dict] : Keyword arguments for the accelerator project config (e.g. `logging_dir`) - :param tracker_project_name: str :Name of project to use for tracking - :param push_to_hub_if_best_kwargs: Optional[dict] :Keyword arguments for pushing model to the hub during training - (e.g. pretrained_model_name_or_path). - :param steps: int : Number of training steps - :param learning_rate: float :Adam learning rate - :param adap_kl_ctrl: bool :Use adaptive KL control, otherwise linear - :param init_kl_coef: Optional[float] : Initial KL penalty coefficient (used for adaptive and linear control) - :param kl_penalty: Literal["kl", "abs", "mse", "full"] : kl penalty options: 'kl': model_logp - ref_logp, - 'abs': abs(kl), 'mse': mean squared error mse(kl) and 'full': the actual kl for all tokens in the distribution - :param target: Optional[float] :Target KL value for adaptive KL control - :param horizon: Optional[float] :Horizon for adaptive KL control - :param gamma: float :Gamma parameter for advantage calculation - :param lam: float : Lambda parameter for advantage calculation - :param cliprange: float : Range for clipping in PPO policy gradient loss - :param cliprange_value: float : Range for clipping values in loss calculation - :param vf_coef: float : Scaling factor for value loss - :param batch_size: int :Number of samples per optimisation step - :param gradient_accumulation_steps: int :The number of gradient accumulation steps - :param ppo_epochs: int : Number of optimisation epochs per batch of samples - :param max_grad_norm: Optional[float] :Maximum gradient norm for gradient clipping - :param target_kl: float :Stop early if we exceed this value by over 50% - :param compare_steps: int : Number of steps between comparison of the current reward with the best seen so far - :param ratio_threshold : float :Skip mini-batches with high PPO ratios that can cause loss spikes - :param use_score_scaling: bool : Use score scaling - :param use_score_norm: bool : Use score normalization. Only applicable if use_score_scaling is True - :param score_clip: Optional[float] :Score clipping - :param whiten_rewards: bool :Whiten the rewards before compute advantages - :param is_encoder_decoder: Optional[bool] :TO BE FILLED In RUNTIME: Whether the model is an encoder-decoder model - :param warmup_steps: Optional[int]: - :param learning_rate_end: float : - :param extra_optimizer_kwargs: dict | None : - :param weight_decay: Optional[float] : Weight decay is Optimizer Weight decay :\ + """Configuration class for PPOTrainer + + Args: + exp_name: str : the name of this experiment (by default is + the file name without the extension name) + seed: int :Seed value for random generations + task_name: Optional[str] : Name of task to use - used only + for tracking purposes + model_name: Optional[str] :Name of model to use - used only + for tracking purposes + query_dataset: Optional[str] :Name of dataset to query - + used only for tracking purposes + reward_model: Optional[str] :The reward model to use - used + only for tracking purposes + remove_unused_columns: bool : Remove unused columns from the + dataset if `datasets.Dataset` is used + tracker_kwargs: Optional[dict] : Keyword arguments for the + tracker + accelerator_kwargs: Optional[dict] :Keyword arguments for + the accelerator + project_kwargs: Optional[dict] : Keyword arguments for the + accelerator project config (e.g. `logging_dir`) + tracker_project_name: str :Name of project to use for + tracking + push_to_hub_if_best_kwargs: Optional[dict] :Keyword + arguments for pushing model to the hub during training + steps: int : Number of training steps + learning_rate: float :Adam learning rate + adap_kl_ctrl: bool :Use adaptive KL control, otherwise + linear + init_kl_coef: Optional[float] : Initial KL penalty + coefficient (used for adaptive and linear control) + kl_penalty: Literal["kl", "abs", "mse", "full"] : kl penalty + options: 'kl': model_logp - ref_logp, + target: Optional[float] :Target KL value for adaptive KL + control + horizon: Optional[float] :Horizon for adaptive KL control + gamma: float :Gamma parameter for advantage calculation + lam: float : Lambda parameter for advantage calculation + cliprange: float : Range for clipping in PPO policy gradient + loss + cliprange_value: float : Range for clipping values in loss + calculation + vf_coef: float : Scaling factor for value loss + batch_size: int :Number of samples per optimisation step + gradient_accumulation_steps: int :The number of gradient + accumulation steps + ppo_epochs: int : Number of optimisation epochs per batch of + samples + max_grad_norm: Optional[float] :Maximum gradient norm for + gradient clipping + target_kl: float :Stop early if we exceed this value by over + 50% + compare_steps: int : Number of steps between comparison of + the current reward with the best seen so far + ratio_threshold: float :Skip mini-batches with high PPO + ratios that can cause loss spikes + use_score_scaling: bool : Use score scaling + use_score_norm: bool : Use score normalization. Only + applicable if use_score_scaling is True + score_clip: Optional[float] :Score clipping + whiten_rewards: bool :Whiten the rewards before compute + advantages + is_encoder_decoder: Optional[bool] :TO BE FILLED In RUNTIME: + Whether the model is an encoder-decoder model + warmup_steps: Optional[int]: + learning_rate_end: float : + extra_optimizer_kwargs: dict | None : + weight_decay: Optional[float] : Weight decay is Optimizer + Weight decay :\ + (e.g. pretrained_model_name_or_path). + 'abs': abs(kl), 'mse': mean squared error mse(kl) and 'full': the actual kl for all tokens in the distribution """ tracker_kwargs = tracker_kwargs if tracker_kwargs is not None else {} diff --git a/src/python/easydel/reinforcement_learning/trainer/training_configs.py b/src/python/easydel/reinforcement_learning/trainer/training_configs.py index 0d82d4a4d..0ac9465fb 100644 --- a/src/python/easydel/reinforcement_learning/trainer/training_configs.py +++ b/src/python/easydel/reinforcement_learning/trainer/training_configs.py @@ -1,15 +1,15 @@ -from dataclasses import dataclass -from typing import Optional - - -@dataclass -class RewardConfig: - max_length: Optional[int] = None - """ - The maximum length of the sequences in the batch. This argument is - required if you want to use the default data collator. - """ - gradient_checkpointing: Optional[bool] = True - """If True, use gradient checkpointing to save memory at the expense of slower backward pass.""" - gradient_checkpointing_kwargs: Optional[dict] = None - """Keyword arguments to pass to the gradient checkpointing function.""" +from dataclasses import dataclass +from typing import Optional + + +@dataclass +class RewardConfig: + max_length: Optional[int] = None + """ + The maximum length of the sequences in the batch. This argument is + required if you want to use the default data collator. + """ + gradient_checkpointing: Optional[bool] = True + """If True, use gradient checkpointing to save memory at the expense of slower backward pass.""" + gradient_checkpointing_kwargs: Optional[dict] = None + """Keyword arguments to pass to the gradient checkpointing function.""" diff --git a/src/python/easydel/reinforcement_learning/trainer/utils.py b/src/python/easydel/reinforcement_learning/trainer/utils.py index 2bf68935d..6fdc7848c 100644 --- a/src/python/easydel/reinforcement_learning/trainer/utils.py +++ b/src/python/easydel/reinforcement_learning/trainer/utils.py @@ -1,21 +1,21 @@ -from typing import Union - -import chex -import jax.numpy - - -def pad_to_length(tensor: chex.Array, length: int, pad_value: Union[int, float], axis: int = -1) -> chex.Array: - if tensor.shape[axis] >= length: - if tensor.ndim == 2: - tensor = tensor[:, :length] - return tensor - else: - pad_size = list(tensor.shape) - pad_size[axis] = length - tensor.shape[axis] - return jax.numpy.concatenate( - [ - tensor, - pad_value * jax.numpy.ones(pad_size, dtype=tensor.dtype), - ], - axis=axis, - ) +from typing import Union + +import chex +import jax.numpy + + +def pad_to_length(tensor: chex.Array, length: int, pad_value: Union[int, float], axis: int = -1) -> chex.Array: + if tensor.shape[axis] >= length: + if tensor.ndim == 2: + tensor = tensor[:, :length] + return tensor + else: + pad_size = list(tensor.shape) + pad_size[axis] = length - tensor.shape[axis] + return jax.numpy.concatenate( + [ + tensor, + pad_value * jax.numpy.ones(pad_size, dtype=tensor.dtype), + ], + axis=axis, + ) diff --git a/src/python/easydel/reinforcement_learning/utils/collectors.py b/src/python/easydel/reinforcement_learning/utils/collectors.py index 08f2712fb..be98a009a 100644 --- a/src/python/easydel/reinforcement_learning/utils/collectors.py +++ b/src/python/easydel/reinforcement_learning/utils/collectors.py @@ -1,58 +1,59 @@ -from dataclasses import dataclass -from typing import Optional, List, Dict, Any - -from ..core import pad_sequence -from jax import numpy as jnp - - -@dataclass -class DPODataCollatorWithPadding: - r""" - DPO DataCollator class that pads the tokenized inputs to the maximum length of the batch. - - :param pad_token_id: int: The tokenizers pad_token_id. - :param label_pad_token_id: int: The label used for masking. - :param is_encoder_decoder: Optional[bool]: Whether you model has an encoder_decoder architecture - """ - - pad_token_id: int = 0 - label_pad_token_id: int = -100 - is_encoder_decoder: Optional[bool] = False - - def __call__(self, features: List[Dict[str, Any]]) -> Dict[str, Any]: - padded_batch = {} - for k in features[0].keys(): - if k.endswith("_input_ids") or k.endswith("_attention_mask") or k.endswith("_labels"): - if self.is_encoder_decoder: - to_pad = [jnp.array(ex[k], dtype="i4") for ex in features] - - if (k.startswith("prompt")) and (k.endswith("input_ids")): - padding_value = self.pad_token_id - elif k.endswith("_attention_mask"): - padding_value = 0 - elif (k.startswith("chosen")) or (k.startswith("rejected")) or ("decoder" in k): - padding_value = self.label_pad_token_id - else: - raise ValueError(f"Unexpected key in batch '{k}'") - padded_batch[k] = pad_sequence(to_pad, batch_first=True, padding_value=padding_value).astype("i4") - else: - if "prompt" in k: - to_pad = [jnp.array(ex[k][::-1], dtype="i4") for ex in features] - else: - to_pad = [jnp.array(ex[k], dtype="i4") for ex in features] - if k.endswith("_input_ids"): - padding_value = self.pad_token_id - elif k.endswith("_labels"): - padding_value = self.label_pad_token_id - elif k.endswith("_attention_mask"): - padding_value = 0 - else: - raise ValueError(f"Unexpected key in batch '{k}'") - padded_batch[k] = pad_sequence(to_pad, batch_first=True, padding_value=padding_value).astype("i4") - if "prompt" in k: - padded_batch[k] = jnp.flip(padded_batch[k], axis=[1]) - elif k.endswith("_logps"): - padded_batch[k] = jnp.array([ex[k] for ex in features]) - else: - padded_batch[k] = [ex[k] for ex in features] - return padded_batch +from dataclasses import dataclass +from typing import Optional, List, Dict, Any + +from ..core import pad_sequence +from jax import numpy as jnp + + +@dataclass +class DPODataCollatorWithPadding: + r"""DPO DataCollator class that pads the tokenized inputs to the maximum length of the batch. + + Args: + pad_token_id: int: The tokenizers pad_token_id. + label_pad_token_id: int: The label used for masking. + is_encoder_decoder: Optional[bool]: Whether you model has an + encoder_decoder architecture + """ + + pad_token_id: int = 0 + label_pad_token_id: int = -100 + is_encoder_decoder: Optional[bool] = False + + def __call__(self, features: List[Dict[str, Any]]) -> Dict[str, Any]: + padded_batch = {} + for k in features[0].keys(): + if k.endswith("_input_ids") or k.endswith("_attention_mask") or k.endswith("_labels"): + if self.is_encoder_decoder: + to_pad = [jnp.array(ex[k], dtype="i4") for ex in features] + + if (k.startswith("prompt")) and (k.endswith("input_ids")): + padding_value = self.pad_token_id + elif k.endswith("_attention_mask"): + padding_value = 0 + elif (k.startswith("chosen")) or (k.startswith("rejected")) or ("decoder" in k): + padding_value = self.label_pad_token_id + else: + raise ValueError(f"Unexpected key in batch '{k}'") + padded_batch[k] = pad_sequence(to_pad, batch_first=True, padding_value=padding_value).astype("i4") + else: + if "prompt" in k: + to_pad = [jnp.array(ex[k][::-1], dtype="i4") for ex in features] + else: + to_pad = [jnp.array(ex[k], dtype="i4") for ex in features] + if k.endswith("_input_ids"): + padding_value = self.pad_token_id + elif k.endswith("_labels"): + padding_value = self.label_pad_token_id + elif k.endswith("_attention_mask"): + padding_value = 0 + else: + raise ValueError(f"Unexpected key in batch '{k}'") + padded_batch[k] = pad_sequence(to_pad, batch_first=True, padding_value=padding_value).astype("i4") + if "prompt" in k: + padded_batch[k] = jnp.flip(padded_batch[k], axis=[1]) + elif k.endswith("_logps"): + padded_batch[k] = jnp.array([ex[k] for ex in features]) + else: + padded_batch[k] = [ex[k] for ex in features] + return padded_batch diff --git a/src/python/easydel/serve/gradio_user_interface_base.py b/src/python/easydel/serve/gradio_user_interface_base.py index c98c46ed9..774f82358 100644 --- a/src/python/easydel/serve/gradio_user_interface_base.py +++ b/src/python/easydel/serve/gradio_user_interface_base.py @@ -13,8 +13,7 @@ def chat_interface_components( max_new_tokens: int, max_compile_tokens: int ): - """ - The function `chat_interface_components` creates the components for a chat interface, including + """The function `chat_interface_components` creates the components for a chat interface, including a chat history, message box, buttons for submitting, stopping, and clearing the conversation, and sliders for advanced options. """ @@ -173,10 +172,11 @@ def build_inference( max_new_tokens: int, max_compile_tokens: int ) -> gr.Blocks: - """ - The function "build_inference" returns a gr.Blocks object that model + """The function "build_inference" returns a gr.Blocks object that model interface components. - :return: a gr.Blocks object. + + Returns: + a gr.Blocks object. """ with gr.Blocks( theme=seafoam @@ -191,14 +191,16 @@ def build_inference( def __repr__(self): - """ - The __repr__ function is used to generate a string representation of an object. + """The __repr__ function is used to generate a string representation of an object. This function should return a string that can be parsed by the Python interpreter to recreate the object. The __repr__ function is called when you use print() on an object, or when you type its name in the REPL. - :param self: Refer to the instance of the class - :return: A string representation of the object + Args: + self: Refer to the instance of the class + + Returns: + A string representation of the object """ string = f"{self.__class__.__name__}(\n" for k, v in self.__dict__.items(): @@ -214,11 +216,13 @@ def __repr__(self): def __str__(self): - """ - The __str__ function is called when you use the print function or when str() is used. + """The __str__ function is called when you use the print function or when str() is used. It should return a string representation of the object. - :param self: Refer to the instance of the class - :return: The object's string representation + Args: + self: Refer to the instance of the class + + Returns: + The object's string representation """ return self.__repr__() diff --git a/src/python/easydel/serve/jax_serve.py b/src/python/easydel/serve/jax_serve.py index b4cbb548d..13e65572f 100644 --- a/src/python/easydel/serve/jax_serve.py +++ b/src/python/easydel/serve/jax_serve.py @@ -131,15 +131,15 @@ class JAXServer(GradioUserInference): def __init__(self, server_config=None): - """ - The __init__ function is called when the class is instantiated. + """The __init__ function is called when the class is instantiated. It sets up all the attributes that will be used by other methods in the class. + Args: + self: Refer to the current instance of a class + server_config: Pass the JAXServerConfig object - :param self: Refer to the current instance of a class - :param server_config: Pass the JAXServerConfig object - :return: A fastapi object - + Returns: + A fastapi object """ ( self.process_uvicorn, @@ -172,8 +172,7 @@ def __init__(self, server_config=None): self.app = gr.mount_gradio_app(self.app, self.gradio_inference(), "/gradio_chat") def status(self): - """ - The status function returns a dictionary with the following keys: + """The status function returns a dictionary with the following keys: server_config: A dictionary containing all the configuration parameters for this server. devices: A string describing which devices are available to JAX. number_of_backends: The number of backends available to JAX. This is usually equal to the number of GPUs @@ -181,9 +180,11 @@ def status(self): system BIOS settings (e.g., because they are defective). It can also be more than one if you have multiple machines connected via MPI and running under Horov - :param self: Represent the instance of the class - :return: A dictionary with the following keys: - + Args: + self: Represent the instance of the class + + Returns: + A dictionary with the following keys: """ return { "server_config": {k: v for k, v in self.server_config.__dict__.items()}, @@ -196,25 +197,25 @@ def status(self): @staticmethod def get_memory(): - """ - The get_memory function returns the total memory of the system in bytes. - + """The get_memory function returns the total memory of the system in bytes. - :return: The amount of memory used by the program - + Returns: + The amount of memory used by the program """ return get_mem() def configure_generate_functions(self, model, tokenizer): - """ - The configure_generate_functions function is used to configure the generation functions for a given model. + """The configure_generate_functions function is used to configure the generation functions for a given model. + + Args: + self: Access variables within the class + model: Generate the model + tokenizer: Get the eos_token_id, pad_token_id and bos token + id - :param self: Access variables within the class - :param model: Generate the model - :param tokenizer: Get the eos_token_id, pad_token_id and bos token id - :return: A function that takes in three parameters: - + Returns: + A function that takes in three parameters: """ assert self.partition_specs is not None, "you should first shard params with using ``shard_params`` method" @@ -301,19 +302,22 @@ def generate(parameters, input_ids, attention_mask): self._funcs_generated = True def auto_configure(self, model, params, tokenizer, partition_rules): - """ - The auto_configure function is a helper function that will automatically configure the model for distributed training. + """The auto_configure function is a helper function that will automatically configure the model for distributed training. It does this by: 1) sharding the parameters of the model based on partition_rules, and then 2) configuring generate functions to be used in distributed training. - :param self: Represent the instance of the class - :param model: Configure the model - :param params: Store the parameters that are used to configure the model - :param tokenizer: Tokenize the input text - :param partition_rules: Specify how the parameters should be partitioned - :return: A dictionary with the following keys: - + Args: + self: Represent the instance of the class + model: Configure the model + params: Store the parameters that are used to configure the + model + tokenizer: Tokenize the input text + partition_rules: Specify how the parameters should be + partitioned + + Returns: + A dictionary with the following keys: """ self.shard_params(params=params, partition_rules=partition_rules) self.configure_generate_functions(model, tokenizer) @@ -324,15 +328,17 @@ def generate( input_ids: chex.Array, attention_mask: chex.Array, ): - """ - The generate function is used to generate a sequence of tokens from the model. - - :param self: Access variables that belong to the class - :param params: Union[flax.core.FrozenDict, dict]: Pass the parameters of the model to be used in generating text - :param input_ids: chex.Array: Pass the input to the model - :param attention_mask: chex.Array: Mask the padding tokens - :return: The logits of the model - + """The generate function is used to generate a sequence of tokens from the model. + + Args: + self: Access variables that belong to the class + params: Union[flax.core.FrozenDict, dict]: Pass the + parameters of the model to be used in generating text + input_ids: chex.Array: Pass the input to the model + attention_mask: chex.Array: Mask the padding tokens + + Returns: + The logits of the model """ if not self._funcs_generated: raise NotImplementedError( @@ -357,21 +363,27 @@ def load( do_memory_log: bool = False, verbose: bool = True ) -> "JAXServer": - """ - The load function is used to load a pretrained model from disk. - - :param cls: Refer to the class itself - :param model: transformers.FlaxPreTrainedModel: Initialize the server - :param config_model: transformers.PretrainedConfig: Get the partition rules - :param tokenizer: transformers.PreTrainedTokenizer: Load the tokenizer from the model - :param path: Union[str, os.PathLike]: Specify the path to the checkpoint file - :param server_config: Configure the server - :param add_params_field: bool: Add a params field to the server - :param init_shape: tuple: Specify the shape of the input to be used for generating shard_fns - :param do_memory_log: bool: Log the memory usage of the server - :param verbose: bool: Print the compilation process - :return: A server - + """The load function is used to load a pretrained model from disk. + + Args: + cls: Refer to the class itself + model: transformers.FlaxPreTrainedModel: Initialize the + server + config_model: transformers.PretrainedConfig: Get the + partition rules + tokenizer: transformers.PreTrainedTokenizer: Load the + tokenizer from the model + path: Union[str, os.PathLike]: Specify the path to the + checkpoint file + server_config: Configure the server + add_params_field: bool: Add a params field to the server + init_shape: tuple: Specify the shape of the input to be used + for generating shard_fns + do_memory_log: bool: Log the memory usage of the server + verbose: bool: Print the compilation process + + Returns: + A server """ assert hasattr(model, "init_weights"), "model must contain init_weights func in order to init params for shard_fns" @@ -500,8 +512,7 @@ def from_parameters( shard_parameters: bool = False, verbose: bool = True ) -> "JAXServer": - """ - The from_parameters function is used to load a model from the parameters of a pretrained model. + """The from_parameters function is used to load a model from the parameters of a pretrained model. It takes in the following arguments: - cls: The class of the server you are loading, this should be Server or TPU_Server depending on what backend you want to use. @@ -510,18 +521,22 @@ def from_parameters( where *model* is replaced with whatever transformer you are using (e.g., bert). You can also create your own custom - :param cls: Create a new instance of the class - :param model: transformers.FlaxPreTrainedModel: Load the model - :param config_model: transformers.PretrainedConfig: Get the partition rules - :param tokenizer: transformers.PreTrainedTokenizer: Tokenize the input text - :param params: Dict: Pass in the parameters of the model - :param server_config: Pass in the server_config file for the server - :param add_params_field: bool: Add a params field to the server - :param do_memory_log: bool: Log the memory usage of the server - :param shard_parameters:bool: whenever a shard model parameters. - :param verbose: bool: Print out the status of the compilation - :return: A server object - + Args: + cls: Create a new instance of the class + model: transformers.FlaxPreTrainedModel: Load the model + config_model: transformers.PretrainedConfig: Get the + partition rules + tokenizer: transformers.PreTrainedTokenizer: Tokenize the + input text + params: Dict: Pass in the parameters of the model + server_config: Pass in the server_config file for the server + add_params_field: bool: Add a params field to the server + do_memory_log: bool: Log the memory usage of the server + shard_parameters: bool: whenever a shard model parameters. + verbose: bool: Print out the status of the compilation + + Returns: + A server object """ assert hasattr(model, "init_weights"), ( "model must contain init_weights func in order to init params for shard_fns" @@ -564,16 +579,18 @@ def from_parameters( return server def compile(self, verbose: bool = True) -> bool: - """ - The compile function is used to compile the model for use in inference. + """The compile function is used to compile the model for use in inference. It does this by running through all possible combinations of rules and actions, and compiling them into functions that can be called later on during inference. This allows us to avoid having to recompile the model every time we want to run it, which would be very slow. - :param self: Represent the instance of the class - :param verbose: bool: Print out the compiling process - :return: True, but what does it do? + Args: + self: Represent the instance of the class + verbose: bool: Print out the compiling process + + Returns: + True, but what does it do? """ assert self._funcs_generated, "funcs are not generated yet" assert self.partition_specs is not None, "rules should not be None" @@ -609,18 +626,19 @@ def greedy_generate(self, input_ids: chex.Array, attention_mask: chex.Array, ): - """ - The greedy_generate function is a helper function that takes in the model parameters, input_ids and attention_mask + """The greedy_generate function is a helper function that takes in the model parameters, input_ids and attention_mask and returns the generated tokens. It uses greedy search to generate tokens one at a time. - - :param self: Refer to the object itself - :param params: Union[flax.core.FrozenDict, dict]: Pass the parameters to the model - :param input_ids: chex.Array: Pass in the input sequence - :param attention_mask: chex.Array: Mask the input tokens + Args: + self: Refer to the object itself + params: Union[flax.core.FrozenDict, dict]: Pass the + parameters to the model + input_ids: chex.Array: Pass in the input sequence + attention_mask: chex.Array: Mask the input tokens :param : Specify the parameters of the model - :return: generated_ids - + + Returns: + generated_ids """ if not self._funcs_generated: raise NotImplementedError( @@ -634,18 +652,20 @@ def greedy_generate(self, def shard_params(self, params, partition_rules): - """ - The shard_params function takes in a set of parameters and a partition rule. + """The shard_params function takes in a set of parameters and a partition rule. The partition rule is used to determine how the parameters should be sharded across devices. For example, if we have two devices, one with 4GB of memory and another with 8GB of memory, we may want to shard our model such that the device with more memory has more parameters on it. This function returns an updated version of params where each parameter is now stored on its own device. - :param self: Bind the instance of the class to a method - :param params: Pass the parameters of the model to be sharded - :param partition_rules: Specify how the parameters should be partitioned - :return: The sharded parameters - + Args: + self: Bind the instance of the class to a method + params: Pass the parameters of the model to be sharded + partition_rules: Specify how the parameters should be + partitioned + + Returns: + The sharded parameters """ logging.log( logging.INFO, @@ -663,8 +683,7 @@ def shard_params(self, params, partition_rules): def forward_chat(self, data: ChatRequest): - """ - The forward_chat function is the main function of this class. + """The forward_chat function is the main function of this class. It takes in a ChatRequest object, which contains a prompt and history. The prompt is the user"s input to be processed by the chatbot, while history is an array of previous inputs and outputs from both sides (user and bot). @@ -672,10 +691,12 @@ def forward_chat(self, data: ChatRequest): This formatted string is then passed through our sample() method, which returns an output response as well as how many tokens were used to generate it. - :param self: Access the attributes and methods of the class - :param data: ChatRequest: Pass in the data from the request - :return: A dictionary with the following keys: - + Args: + self: Access the attributes and methods of the class + data: ChatRequest: Pass in the data from the request + + Returns: + A dictionary with the following keys: """ if not self._funcs_generated: return { @@ -703,9 +724,7 @@ def forward_chat(self, data: ChatRequest): } def format_instruct(self, system: str, instruction: str) -> str: - """ - Here you will get the system and instruction from user, and you can apply your prompting style - """ + """Here you will get the system and instruction from user, and you can apply your prompting style""" conversation = [] if system is not None and system != "": conversation.append({ @@ -721,9 +740,7 @@ def format_instruct(self, system: str, instruction: str) -> str: ) def format_chat(self, history: List[List[str]], prompt: str, system: Union[str, None]) -> str: - """ - Here you will get the system, prompt and history from user, and you can apply your prompting style - """ + """Here you will get the system, prompt and history from user, and you can apply your prompting style""" conversation = [] if system is not None and system != "": conversation.append({ @@ -753,17 +770,19 @@ def format_chat(self, history: List[List[str]], prompt: str, system: Union[str, ) def forward_instruct(self, data: InstructRequest): - """ - The forward_instruct function is the main function of this class. + """The forward_instruct function is the main function of this class. It takes in a InstructRequest object, which contains the system and instruction to be processed. The function then formats the input string using format_instruct, and passes it into sample(). sample() returns a tuple containing (response, used_tokens). The response is returned as part of the response dictionary. If no valid responses are found by sample(), None will be returned instead. - :param self: Bind the method to the object - :param data: InstructRequest: Pass the system and instruction to the function - :return: A dictionary with three keys: - + Args: + self: Bind the method to the object + data: InstructRequest: Pass the system and instruction to + the function + + Returns: + A dictionary with three keys: """ if not self._funcs_generated: return { @@ -789,18 +808,19 @@ def forward_instruct(self, data: InstructRequest): } def forward_instruct_non_api(self, prompt, system, greedy): - """ - The forward_instruct_non_api function is a wrapper for the forward_instruct function. + """The forward_instruct_non_api function is a wrapper for the forward_instruct function. It takes in a prompt, system, and greedy flag as arguments and returns the response from the forward_instruct function. The purpose of this wrapper is to allow users to call forward_instruct without having to create an InstructRequest object. - :param self: Represent the instance of the class - :param prompt: Pass the instruction to the system - :param system: Specify which system to use for the instruction - :param greedy: Determine whether the system should return - :return: The response from the forward_instruct function - + Args: + self: Represent the instance of the class + prompt: Pass the instruction to the system + system: Specify which system to use for the instruction + greedy: Determine whether the system should return + + Returns: + The response from the forward_instruct function """ data = InstructRequest( prompt=prompt, @@ -810,18 +830,20 @@ def forward_instruct_non_api(self, prompt, system, greedy): return self.forward_instruct(data) def forward_chat_non_api(self, prompt, history, greedy): - """ - The forward_chat_non_api function is a wrapper for the forward_chat function. + """The forward_chat_non_api function is a wrapper for the forward_chat function. It takes in a prompt, history, and greedy parameter and returns the response from the forward_chat function. The purpose of this wrapper is to allow users to use the chatbot without having to create ChatRequest objects. - :param self: Represent the instance of the class - :param prompt: Pass the user's input to the model - :param history: Pass the history of the conversation to the model - :param greedy: Determine whether the model should use a greedy search - :return: A chat-response object - + Args: + self: Represent the instance of the class + prompt: Pass the user's input to the model + history: Pass the history of the conversation to the model + greedy: Determine whether the model should use a greedy + search + + Returns: + A chat-response object """ data = ChatRequest( prompt=prompt, @@ -875,19 +897,23 @@ def sample(self, max_new_tokens: int = None, **kwargs ): - """ - The sample function is the main function of a model. It takes in an input string and returns a list of strings + """The sample function is the main function of a model. It takes in an input string and returns a list of strings that are generated from that input string. The sample function can be called multiple times with different inputs, and each time it will return a new set of outputs based on those inputs. - :param self: Access the class attributes - :param string: str: Pass the string that we want to generate - :param *: Pass a variable number of arguments to a function - :param greedy: bool: Determine whether to use the greedy or non-greedy version of the generate function - :param max_new_tokens: int: Set the number of tokens to generate - :param kwargs: Pass any additional parameters to the sample function - :return: A generator that yields the predicted text and the number of tokens generated - + Args: + self: Access the class attributes + string: str: Pass the string that we want to generate + : Pass a variable number of arguments to a function + greedy: bool: Determine whether to use the greedy or non- + greedy version of the generate function + max_new_tokens: int: Set the number of tokens to generate + **kwargs: Pass any additional parameters to the sample + function + + Returns: + A generator that yields the predicted text and the number of + tokens generated """ fixed_pad = self.server_config.max_sequence_length - self.server_config.max_compile_tokens @@ -950,16 +976,17 @@ def sample(self, break def fire(self): - """ - The fire function is a wrapper around the uvicorn.run function that allows you + """The fire function is a wrapper around the uvicorn.run function that allows you to run your model in a separate process from the main one. This is useful for running models on GPUs, as it prevents any other processes from using them while the model is being served. - :param self: Refer to the instance of the class - :return: A process, which is a child of the main process - + Args: + self: Refer to the instance of the class + + Returns: + A process, which is a child of the main process """ assert self._funcs_generated, "you have to first add your model and parameters into server before using fire " \ "with using ``configure_generate_functions``" @@ -971,13 +998,14 @@ def run(): self.process_uvicorn.start() def end(self): - """ - The end function is used to stop the server. + """The end function is used to stop the server. It will wait for the process to end before returning. - :param self: Represent the instance of the class - :return: The process_uvicorn - + Args: + self: Represent the instance of the class + + Returns: + The process_uvicorn """ if self.process_uvicorn is not None: self.process_uvicorn.join() diff --git a/src/python/easydel/serve/prompters/__init__.py b/src/python/easydel/serve/prompters/__init__.py index 3545ac442..126c36b98 100644 --- a/src/python/easydel/serve/prompters/__init__.py +++ b/src/python/easydel/serve/prompters/__init__.py @@ -1,7 +1,7 @@ -from .cargo_prompter import CargoPrompter -from .guanaco_prompter import GuanacoPrompter -from .llama2_prompter import Llama2Prompter -from .openchat_prompter import OpenChatPrompter -from .chatml_prompter import ChatMLPrompter -from .gemma_prompter import GemmaPrompter -from .zephyr_prompter import ZephyrPrompter +from .cargo_prompter import CargoPrompter +from .guanaco_prompter import GuanacoPrompter +from .llama2_prompter import Llama2Prompter +from .openchat_prompter import OpenChatPrompter +from .chatml_prompter import ChatMLPrompter +from .gemma_prompter import GemmaPrompter +from .zephyr_prompter import ZephyrPrompter diff --git a/src/python/easydel/serve/prompters/base_prompter.py b/src/python/easydel/serve/prompters/base_prompter.py index 9ba8d5449..d3b5d0a23 100644 --- a/src/python/easydel/serve/prompters/base_prompter.py +++ b/src/python/easydel/serve/prompters/base_prompter.py @@ -79,14 +79,16 @@ def retrival_qa_template( ) + self.assistant_message_token def __repr__(self): - """ - The __repr__ function is used to generate a string representation of an object. + """The __repr__ function is used to generate a string representation of an object. This function should return a string that can be parsed by the Python interpreter to recreate the object. The __repr__ function is called when you use print() on an object, or when you type its name in the REPL. - :param self: Refer to the instance of the class - :return: A string representation of the object + Args: + self: Refer to the instance of the class + + Returns: + A string representation of the object """ string = f"{self.__class__.__name__}(\n" for k, v in self.__dict__.items(): @@ -101,11 +103,13 @@ def __repr__(self): return string + ")" def __str__(self): - """ - The __str__ function is called when you use the print function or when str() is used. + """The __str__ function is called when you use the print function or when str() is used. It should return a string representation of the object. - :param self: Refer to the instance of the class - :return: The object's string representation + Args: + self: Refer to the instance of the class + + Returns: + The object's string representation """ return self.__repr__() diff --git a/src/python/easydel/serve/serve_engine/__init__.py b/src/python/easydel/serve/serve_engine/__init__.py index d204990f3..49423e743 100644 --- a/src/python/easydel/serve/serve_engine/__init__.py +++ b/src/python/easydel/serve/serve_engine/__init__.py @@ -1,9 +1,9 @@ -from .serve import EasyServe, EasyServeConfig, LLMBaseReq -from .client import EasyClient - -__all__ = ( - "EasyServe", - "EasyServeConfig", - "LLMBaseReq", - "EasyClient" -) +from .serve import EasyServe, EasyServeConfig, LLMBaseReq +from .client import EasyClient + +__all__ = ( + "EasyServe", + "EasyServeConfig", + "LLMBaseReq", + "EasyClient" +) diff --git a/src/python/easydel/serve/serve_engine/configuration.py b/src/python/easydel/serve/serve_engine/configuration.py index 2fb437386..17a6d5d23 100644 --- a/src/python/easydel/serve/serve_engine/configuration.py +++ b/src/python/easydel/serve/serve_engine/configuration.py @@ -1,91 +1,103 @@ -from typing import Sequence, Optional -from jax.sharding import PartitionSpec -from dataclasses import dataclass - - -@dataclass -class EasyServeConfig: - """ - :param host: str: Set the host address of the server - :param port: int: Specify the port number that the server will run on - :param batch_size: int: Set the batch size of the model - :param max_sequence_length: int: Set the maximum length of the text that can be generated - :param max_new_tokens: int: Determine how many tokens can be added to the vocabulary - :param max_compile_tokens: int: Set the maximum number of tokens that can be streamed at a time - :param generation_ps: jax.sharding.PartitionSpec : PartitionSpec to use for sharding data - :param temperature: float: Control the randomness of the output - :param top_p: float: Control the diversity of the text generated - :param top_k: int: Limit the number of tokens that can be generated - :param logging: bool: Print out the progress of the server - :param mesh_axes_names: Sequence[str]: Specify the names of the axes in the mesh tensor - :param mesh_axes_shape: Sequence[int]: Specify the shape of the mesh - :param dtype: str: Specify the data type of the model - :param use_prefix_tokenizer: bool: Determine if the tokenizer should be used to generate tokens - :param pre_compile: bool: Pre-compile the model - :return: Nothing - - """ - host: str = "0.0.0.0" - port: int = 2059 - - batch_size: int = 1 - max_sequence_length: int = 4096 - max_new_tokens: int = 4096 - max_compile_tokens: int = 64 - temperature: float = 0.1 - top_p: float = 0.95 - top_k: int = 50 - repetition_penalty: float = 1.2 - greedy: bool = False - - logging: bool = True - - mesh_axes_names: Sequence[str] = ("dp", "fsdp", "tp", "sp") - mesh_axes_shape: Sequence[int] = (1, -1, 1, 1) - generation_ps: PartitionSpec = PartitionSpec("dp", "fsdp") - dtype: str = "fp16" - - eos_token_id: Optional[int] = None - pad_token_id: Optional[int] = None - bos_token_id: Optional[int] = None - - use_prefix_tokenizer: bool = True - pre_compile: bool = True - - verbose: bool = True - - use_mxn_break_point: bool = True - - def __repr__(self): - - """ - The __repr__ function is used to generate a string representation of an object. - This function should return a string that can be parsed by the Python interpreter - to recreate the object. The __repr__ function is called when you use print() on an - object, or when you type its name in the REPL. - - :param self: Refer to the instance of the class - :return: A string representation of the object - """ - string = f"{self.__class__.__name__}(\n" - for k, v in self.__dict__.items(): - if not k.startswith("_"): - - try: - repr_src = f"\t{k} : " + v.__str__().replace("\n", "\n\t") + "\n" - string += repr_src if len(repr_src) < 500 else f"\t{k} : " + f"{v.__class__.__name__}(...)" + "\n" - except TypeError: - ... - - return string + ")" - - def __str__(self): - - """ - The __str__ function is called when you use the print function or when str() is used. - It should return a string representation of the object. - - :param self: Refer to the instance of the class - :return: The object's string representation - """ - return self.__repr__() +from typing import Sequence, Optional +from jax.sharding import PartitionSpec +from dataclasses import dataclass + + +@dataclass +class EasyServeConfig: + """ + Args: + host: str: Set the host address of the server + port: int: Specify the port number that the server will run on + batch_size: int: Set the batch size of the model + max_sequence_length: int: Set the maximum length of the text + that can be generated + max_new_tokens: int: Determine how many tokens can be added to + the vocabulary + max_compile_tokens: int: Set the maximum number of tokens that + can be streamed at a time + generation_ps: jax.sharding.PartitionSpec : PartitionSpec to use + for sharding data + temperature: float: Control the randomness of the output + top_p: float: Control the diversity of the text generated + top_k: int: Limit the number of tokens that can be generated + logging: bool: Print out the progress of the server + mesh_axes_names: Sequence[str]: Specify the names of the axes in + the mesh tensor + mesh_axes_shape: Sequence[int]: Specify the shape of the mesh + dtype: str: Specify the data type of the model + use_prefix_tokenizer: bool: Determine if the tokenizer should be + used to generate tokens + pre_compile: bool: Pre-compile the model + + Returns: + Nothing + """ + host: str = "0.0.0.0" + port: int = 2059 + + batch_size: int = 1 + max_sequence_length: int = 4096 + max_new_tokens: int = 4096 + max_compile_tokens: int = 64 + temperature: float = 0.1 + top_p: float = 0.95 + top_k: int = 50 + repetition_penalty: float = 1.2 + greedy: bool = False + + logging: bool = True + + mesh_axes_names: Sequence[str] = ("dp", "fsdp", "tp", "sp") + mesh_axes_shape: Sequence[int] = (1, -1, 1, 1) + generation_ps: PartitionSpec = PartitionSpec("dp", "fsdp") + dtype: str = "fp16" + + eos_token_id: Optional[int] = None + pad_token_id: Optional[int] = None + bos_token_id: Optional[int] = None + + use_prefix_tokenizer: bool = True + pre_compile: bool = True + + verbose: bool = True + + use_mxn_break_point: bool = True + + def __repr__(self): + + """The __repr__ function is used to generate a string representation of an object. + This function should return a string that can be parsed by the Python interpreter + to recreate the object. The __repr__ function is called when you use print() on an + object, or when you type its name in the REPL. + + Args: + self: Refer to the instance of the class + + Returns: + A string representation of the object + """ + string = f"{self.__class__.__name__}(\n" + for k, v in self.__dict__.items(): + if not k.startswith("_"): + + try: + repr_src = f"\t{k} : " + v.__str__().replace("\n", "\n\t") + "\n" + string += repr_src if len(repr_src) < 500 else f"\t{k} : " + f"{v.__class__.__name__}(...)" + "\n" + except TypeError: + ... + + return string + ")" + + def __str__(self): + + """The __str__ function is called when you use the print function or when str() is used. + It should return a string representation of the object. + + Args: + self: Refer to the instance of the class + + Returns: + The object's string representation + """ + return self.__repr__() diff --git a/src/python/easydel/serve/serve_engine/serve.py b/src/python/easydel/serve/serve_engine/serve.py index d6336f0a4..74a053ef6 100644 --- a/src/python/easydel/serve/serve_engine/serve.py +++ b/src/python/easydel/serve/serve_engine/serve.py @@ -1,506 +1,541 @@ -import asyncio -import copy -import functools -import json -import logging -import time -import warnings - -import jax -import websocket -import websockets -from fjformer import with_sharding_constraint, match_partition_rules, make_shard_and_gather_fns, get_dtype -from jax import numpy as jnp - -from ...etils.etils import get_logger -from ...modules.easydel_modelling_utils import EasyDeLFlaxPretrainedModel -from flax.core import FrozenDict -from transformers import PreTrainedTokenizerBase, GenerationConfig -from typing import Callable, Tuple, List, Optional, Union, Dict -from .configuration import EasyServeConfig -from jax.sharding import PartitionSpec, Mesh -from jax.experimental.pjit import pjit -from dataclasses import dataclass - -logger = get_logger(__name__) - - -@dataclass -class LLMBaseReq: - greedy_generate_function: Callable - non_greedy_generate_function: Callable - tokenizer: PreTrainedTokenizerBase - prefix_tokenizer: PreTrainedTokenizerBase - - -class EasyServe: - def __init__( - self, - llm: EasyDeLFlaxPretrainedModel, - params: Union[FrozenDict, dict], - tokenizer: PreTrainedTokenizerBase, - prefix_tokenizer: PreTrainedTokenizerBase, - greedy_generate_function: Callable, - non_greedy_generate_function: Callable, - serve_config: EasyServeConfig, - ): - self.llm = llm - self.params = params - self.tokenizer = tokenizer - self.prefix_tokenizer = prefix_tokenizer - self.greedy_generate_function = greedy_generate_function - self.non_greedy_generate_function = non_greedy_generate_function - self.serve_config = serve_config - if serve_config.pre_compile: - self.compile(verbose=serve_config.verbose) - - def get_generation_function(self, greedy: bool): - return self.greedy_generate_function if greedy else self.non_greedy_generate_function - - def conversation_template(self, conversation: List[Dict]) -> str: - """ - The conversation_template function takes a list of ConversationItem objects and returns a string. - where system message, user message, and assistant message are the content fields of the ConversationItem objects. - If there is no system message in the conversation, then it will be omitted from the template. - - :param self: Refer to the current instance of a class - :param conversation: List[ConversationItem]: Pass in the conversation items - :return: A string that is a concatenation of the messages in the conversation - - """ - return self.tokenizer.apply_chat_template( - conversation=conversation, - add_generation_prompt=True, - tokenize=False - ) - - async def generate(self, socket): - data = json.loads(await socket.recv()) - prompt = self.conversation_template(data["conversation"]) - max_new_tokens = data.get("max_new_tokens", None) or self.serve_config.max_new_tokens - greedy = data.get("greedy", None) or self.serve_config.greedy - start = time.time() - send_data = {} - prl_res = 0 - for response, num_token_generated in self.sample( - string=prompt, - max_new_tokens=max_new_tokens, - greedy=greedy, - - ): - generation_duration = time.time() - start - tokens_pre_second = num_token_generated / generation_duration - - send_data = { - "response": response[prl_res:], - "num_token_generated": num_token_generated, - "greedy": greedy, - "model_prompt": prompt, - "generation_duration": generation_duration, - "tokens_pre_second": tokens_pre_second, - "done": False - } - prl_res += len(response) - await socket.send(json.dumps(send_data)) - - send_data["done"] = True - send_data["response"] = "" - await socket.send(json.dumps(send_data)) - - async def handle_client(self, socket: websocket.WebSocket, path: str): - try: - logger.info("connection open") - if path == "/stream/v1/conversation": - await self.generate(socket) - elif path == "/": - await socket.send(json.dumps({"status": "AgentX server is Running..."})) - else: - await socket.send(json.dumps({"error": f"invalid path {path}"})) - except websockets.ConnectionClosed: - logger.info("connection closed") - except Exception as e: - logger.warning(f"Error: {e}") - - @staticmethod - def create_shard_and_gather_functions( - parameters: dict, - partition_rules: Tuple[Tuple[str, PartitionSpec]], - dtype: Union[jax.numpy.dtype, str] = "fp16" - ): - - """ - The create_shard_and_gather_functions function takes in a dictionary of parameters, - a tuple of partition rules, and an optional dtype. It then matches the partition rules to the - parameters and creates shard functions for each parameter. The shard functions are used to - split up a parameter into shards (or partitions) that can be stored on different devices. - The gather function is used to combine all the shards back together again. - - :param parameters: dict: Specify the parameters of the model - :param partition_rules: Tuple[Tuple[str, PartitionSpec]]: Specify which parameters to partition - :param dtype: jax.numpy.dtype | str: Specify the data type of the parameters - :return: A tuple of three elements: - """ - partition_specs = match_partition_rules(partition_rules, parameters) - shard_fns, gather_fns = make_shard_and_gather_fns( - partition_specs=partition_specs, - dtype_specs=get_dtype(dtype) - ) - return shard_fns, gather_fns, partition_specs - - @staticmethod - def shard_parameters( - mesh: Mesh, - params: Union[FrozenDict, dict], - partition_rules: Tuple[Tuple[str, PartitionSpec]], - serve_config: EasyServeConfig, - ): - - """ - The shard_parameters function takes a set of parameters and partitions them according to the partition_rules. - - :param mesh: Mesh: Create a mesh object that is used to shard the parameters - :param params: FrozenDict | dict: Pass in the parameters of the model - :param partition_rules: Tuple[Tuple[str, PartitionSpec]]: Specify the partitioning rules for each parameter - :param serve_config: EasyServeConfig: Specify the dtype of the parameters - :param : Create a mesh of devices - :return: sharded parameters - """ - - partition_specs = match_partition_rules(params=params, rules=partition_rules) - shard_fns, _ = make_shard_and_gather_fns(partition_specs, get_dtype(serve_config.dtype)) - - with mesh: - params = jax.tree_map( - lambda func, param: func(param), shard_fns, params - ) - - return params - - @staticmethod - def create_generation_functions_and_tokenizers( - model: EasyDeLFlaxPretrainedModel, - tokenizer: PreTrainedTokenizerBase, - serve_config: EasyServeConfig, - partition_specs: dict[str, PartitionSpec] - ) -> LLMBaseReq: - """ - The create_generation_functions_and_tokenizers function is used to create the functions that will be used for - generation. It also creates a tokenizer object that can be used to encode and decode text. The function takes in - a model, a tokenizer, an EasyServeConfig object (which contains all the parameters needed for generation), and - partition_specs which are specifications about how data should be partitioned across devices. - - :param model: EasyDeLFlaxPretrainedModel: Create the model and tokenizer - :param tokenizer: PreTrainedTokenizerBase: Create a tokenizer object - :param serve_config: EasyServeConfig: Create the generation function - :param partition_specs: dict[str, PartitionSpec]: Specify the sharding of the model parameters - :return: An LLMBaseReq object - """ - if tokenizer.pad_token is None: - logging.info( - "Tokenizer does not contain padding token setting padding token to eos token for open end generation") - tokenizer.pad_token = tokenizer.eos_token - - try: - tokenizer.padding_side = "left" - tokenizer.truncation_side = "left" - prefix_tokenizer = copy.deepcopy(tokenizer) - tokenizer.padding_side = "right" - tokenizer.truncation_side = "right" - tokenizer = copy.deepcopy(tokenizer) - - except: - warnings.warn( - f"The class Model of Tokenizer {type(tokenizer)} do not support deepcopy option " - ) - if serve_config.use_prefix_tokenizer: - tokenizer.padding_side = "left" - tokenizer.truncation_side = "left" - else: - tokenizer.padding_side = "right" - tokenizer.truncation_side = "right" - prefix_tokenizer = tokenizer - - @functools.partial( - pjit, - in_shardings=(partition_specs, PartitionSpec(), PartitionSpec()), - out_shardings=(PartitionSpec()) - ) - def greedy_generate_function( - parameters, - input_ids, - attention_mask - ): - input_ids = with_sharding_constraint(input_ids, serve_config.generation_ps) - attention_mask = with_sharding_constraint(attention_mask, serve_config.generation_ps) - predict = model.generate( - input_ids, - attention_mask=attention_mask, - params=parameters, - generation_config=GenerationConfig( - max_new_tokens=serve_config.max_compile_tokens, - - eos_token_id=serve_config.eos_token_id or tokenizer.eos_token_id, - pad_token_id=serve_config.pad_token_id or tokenizer.pad_token_id, - bos_token_id=serve_config.bos_token_id or tokenizer.bos_token_id, - - do_sample=False, - num_beams=1, - ) - ).sequences[:, input_ids.shape[1]:] - return predict - - @functools.partial( - pjit, - in_shardings=(partition_specs, PartitionSpec(), PartitionSpec()), - out_shardings=(PartitionSpec()) - ) - def non_greedy_generate_function( - parameters, - input_ids, - attention_mask - ): - input_ids = with_sharding_constraint(input_ids, serve_config.generation_ps) - attention_mask = with_sharding_constraint(attention_mask, serve_config.generation_ps) - predict = model.generate( - input_ids, - attention_mask=attention_mask, - params=parameters, - generation_config=GenerationConfig( - max_new_tokens=serve_config.max_compile_tokens, - - eos_token_id=serve_config.eos_token_id or tokenizer.eos_token_id, - pad_token_id=serve_config.pad_token_id or tokenizer.pad_token_id, - bos_token_id=serve_config.bos_token_id or tokenizer.bos_token_id, - - temperature=serve_config.temperature, - repetition_penalty=serve_config.repetition_penalty, - do_sample=True, - num_beams=1, - top_p=serve_config.top_p, - top_k=serve_config.top_k, - ) - ).sequences[:, input_ids.shape[1]:] - return predict - - return LLMBaseReq( - greedy_generate_function=greedy_generate_function, - non_greedy_generate_function=non_greedy_generate_function, - tokenizer=tokenizer, - prefix_tokenizer=prefix_tokenizer - ) - - @classmethod - def from_parameters( - cls, - llm: EasyDeLFlaxPretrainedModel, - params: dict, - tokenizer: PreTrainedTokenizerBase, - serve_config: EasyServeConfig, - partition_rules: Tuple[Tuple[str, PartitionSpec]], - shard_parameters: bool = True, - ): - - """ - The from_parameters function is the main entry point for creating a model that can be served. - It takes in a pretrained model, parameters, tokenizer and serve_config as input and returns an object of type - EasyServe. - - :param cls: Create a new instance of the class - :param llm: EasyDeLFlaxPretrainedModel: Pass the model to the class - :param params: dict: Pass the parameters of the model - :param tokenizer: PreTrainedTokenizerBase: Create the tokenizer and prefix_tokenizer - :param serve_config: EasyServeConfig: Configure the model for serving - :param partition_rules: Tuple[Tuple[str, PartitionSpec]]: Partition the parameters of the model - :param shard_parameters: bool: Specify whether the parameters should be sharded or not - :param : Shard the parameters of the model - :return: A EasyServe object - """ - shard_fns, gather_fns, partition_specs = cls.create_shard_and_gather_functions( - parameters=params, - partition_rules=partition_rules, - dtype=serve_config.dtype - ) - llm_base_req = cls.create_generation_functions_and_tokenizers( - model=llm, - tokenizer=tokenizer, - partition_specs=partition_specs, - serve_config=serve_config - ) - - if shard_parameters: - params = cls.shard_parameters( - params=params, - partition_rules=partition_rules, - serve_config=serve_config, - mesh=llm.config.jax_mesh() - ) - - return cls( - llm=llm, - serve_config=serve_config, - tokenizer=llm_base_req.tokenizer, - prefix_tokenizer=llm_base_req.prefix_tokenizer, - params=params, - greedy_generate_function=llm_base_req.greedy_generate_function, - non_greedy_generate_function=llm_base_req.non_greedy_generate_function, - ) - - def sample( - self, - string: str, - *, - greedy: bool = False, - max_new_tokens: int = None, - **kwargs - ): - """ - The process function is the main function of a model. It takes in an input string and returns a list of strings - that are generated from that input string. The process function can be called multiple times with different inputs, - and each time it will return a new set of outputs based on those inputs. - - :param self: Access the class attributes - :param string: str: Pass the string that we want to generate - :param greedy: bool: Determine whether to use the greedy or non-greedy version of the generate function - :param max_new_tokens: int: Set the number of tokens to generate - :param kwargs: Pass any additional parameters to the process function - :return: A generator that yields the predicted text and the number of tokens generated - - """ - with self.llm.config.jax_mesh(): - fixed_pad = self.serve_config.max_sequence_length - self.serve_config.max_compile_tokens - tokens = self.prefix_tokenizer( - string, - max_length=fixed_pad, - padding="max_length", - return_tensors="jax" - ) if self.serve_config.use_prefix_tokenizer else self.tokenizer( - string, - return_tensors="jax" - ) - - input_ids = tokens.input_ids - attention_mask = tokens.attention_mask - num_generated_tokens = 0 - - for _ in range( - (max_new_tokens or self.serve_config.max_new_tokens) // self.serve_config.max_compile_tokens): - - predicted_token = self.get_generation_function(greedy=greedy)( - self.params, - input_ids, - attention_mask - ) - - num_generated_tokens += predicted_token.shape[-1] - plus_attn_mask = jnp.ones( - (len(attention_mask), self.serve_config.max_compile_tokens), - dtype="i4" - ) - - input_ids = jnp.concatenate( - (input_ids, predicted_token), dtype="i4", - axis=-1 - )[:, -fixed_pad:] - - attention_mask = jnp.concatenate( - (attention_mask, plus_attn_mask), dtype="i4", - axis=-1 - )[:, -fixed_pad:] - - returns = ( - self.tokenizer.decode( - input_ids[0][-num_generated_tokens:], # type:ignore - skip_special_tokens=True - ), - num_generated_tokens - ) - - yield returns - - if self.serve_config.use_mxn_break_point: - if self.serve_config.max_compile_tokens != predicted_token.shape[-1]: - break - if ( - predicted_token[0][-1] == (self.serve_config.eos_token_id or self.tokenizer.eos_token_id) - or - predicted_token[0][-1] == (self.serve_config.eos_token_id or self.prefix_tokenizer.eos_token_id) - ): - break - - def compile(self, verbose: bool = True) -> bool: - """ - The compile function is used to compile the model for use in inference. - It does this by running through all possible combinations of rules and actions, - and compiling them into functions that can be called later on during inference. - This allows us to avoid having to recompile the model every time we want to run it, - which would be very slow. - - :param self: Represent the instance of the class - :param verbose: bool: Print out the compiling process - :return: True, but what does it do? - """ - if self.serve_config.use_prefix_tokenizer: - if verbose: - logger.info("Compiling greedy generate function") - response, tokens = [None] * 2 - for response, tokens in self.sample( - string="", - max_new_tokens=self.serve_config.max_compile_tokens, - greedy=True - ): - ... - if verbose: - logger.info("Compiling non-greedy generate function") - for response, tokens in self.sample( - string="", - max_new_tokens=self.serve_config.max_compile_tokens, - greedy=False - ): - ... - - else: - warnings.warn( - "Skip Compiling the compiling process is useless " - "when you are not using prefix tokenizer", - ) - return True - - def __repr__(self): - - """ - The __repr__ function is used to generate a string representation of an object. - This function should return a string that can be parsed by the Python interpreter - to recreate the object. The __repr__ function is called when you use print() on an - object, or when you type its name in the REPL. - - :param self: Refer to the instance of the class - :return: A string representation of the object - """ - string = f"{self.__class__.__name__}(\n" - for k, v in self.__dict__.items(): - if not k.startswith("_"): - try: - repr_src = f"\t{k} : " + v.__str__().replace("\n", "\n\t") + "\n" - string += repr_src if len(repr_src) < 500 else f"\t{k} : " + f"{v.__class__.__name__}(...)" + "\n" - except TypeError: - ... - return string + ")" - - def __str__(self): - - """ - The __str__ function is called when you use the print function or when str() is used. - It should return a string representation of the object. - - :param self: Refer to the instance of the class - :return: The object's string representation - """ - return self.__repr__() - - def fire(self): - async def run_engine(): - async with websockets.serve(self.handle_client, self.serve_config.host, self.serve_config.port) as ws: - logger.info(f"Starting EasyDeL websocket server on {self.serve_config.host}:{self.serve_config.port}") - await ws.wait_closed() - - asyncio.run(run_engine()) +import asyncio +import copy +import functools +import json +import logging +import time +import warnings + +import jax +import websocket +import websockets +from fjformer import with_sharding_constraint, match_partition_rules, make_shard_and_gather_fns, get_dtype +from jax import numpy as jnp + +from ...etils.etils import get_logger +from ...modules.easydel_modelling_utils import EasyDeLFlaxPretrainedModel +from flax.core import FrozenDict +from transformers import PreTrainedTokenizerBase, GenerationConfig +from typing import Callable, Tuple, List, Optional, Union, Dict +from .configuration import EasyServeConfig +from jax.sharding import PartitionSpec, Mesh +from jax.experimental.pjit import pjit +from dataclasses import dataclass + +logger = get_logger(__name__) + + +@dataclass +class LLMBaseReq: + greedy_generate_function: Callable + non_greedy_generate_function: Callable + tokenizer: PreTrainedTokenizerBase + prefix_tokenizer: PreTrainedTokenizerBase + + +class EasyServe: + def __init__( + self, + llm: EasyDeLFlaxPretrainedModel, + params: Union[FrozenDict, dict], + tokenizer: PreTrainedTokenizerBase, + prefix_tokenizer: PreTrainedTokenizerBase, + greedy_generate_function: Callable, + non_greedy_generate_function: Callable, + serve_config: EasyServeConfig, + ): + self.llm = llm + self.params = params + self.tokenizer = tokenizer + self.prefix_tokenizer = prefix_tokenizer + self.greedy_generate_function = greedy_generate_function + self.non_greedy_generate_function = non_greedy_generate_function + self.serve_config = serve_config + if serve_config.pre_compile: + self.compile(verbose=serve_config.verbose) + + def get_generation_function(self, greedy: bool): + return self.greedy_generate_function if greedy else self.non_greedy_generate_function + + def conversation_template(self, conversation: List[Dict]) -> str: + """The conversation_template function takes a list of ConversationItem objects and returns a string. + where system message, user message, and assistant message are the content fields of the ConversationItem objects. + If there is no system message in the conversation, then it will be omitted from the template. + + Args: + self: Refer to the current instance of a class + conversation: List[ConversationItem]: Pass in the + conversation items + + Returns: + A string that is a concatenation of the messages in the + conversation + """ + return self.tokenizer.apply_chat_template( + conversation=conversation, + add_generation_prompt=True, + tokenize=False + ) + + async def generate(self, socket): + data = json.loads(await socket.recv()) + prompt = self.conversation_template(data["conversation"]) + max_new_tokens = data.get("max_new_tokens", None) or self.serve_config.max_new_tokens + greedy = data.get("greedy", None) or self.serve_config.greedy + start = time.time() + send_data = {} + prl_res = 0 + for response, num_token_generated in self.sample( + string=prompt, + max_new_tokens=max_new_tokens, + greedy=greedy, + + ): + generation_duration = time.time() - start + tokens_pre_second = num_token_generated / generation_duration + + send_data = { + "response": response[prl_res:], + "num_token_generated": num_token_generated, + "greedy": greedy, + "model_prompt": prompt, + "generation_duration": generation_duration, + "tokens_pre_second": tokens_pre_second, + "done": False + } + prl_res += len(response) + await socket.send(json.dumps(send_data)) + + send_data["done"] = True + send_data["response"] = "" + await socket.send(json.dumps(send_data)) + + async def handle_client(self, socket: websocket.WebSocket, path: str): + try: + logger.info("connection open") + if path == "/stream/v1/conversation": + await self.generate(socket) + elif path == "/": + await socket.send(json.dumps({"status": "AgentX server is Running..."})) + else: + await socket.send(json.dumps({"error": f"invalid path {path}"})) + except websockets.ConnectionClosed: + logger.info("connection closed") + except Exception as e: + logger.warning(f"Error: {e}") + + @staticmethod + def create_shard_and_gather_functions( + parameters: dict, + partition_rules: Tuple[Tuple[str, PartitionSpec]], + dtype: Union[jax.numpy.dtype, str] = "fp16" + ): + + """The create_shard_and_gather_functions function takes in a dictionary of parameters, + a tuple of partition rules, and an optional dtype. It then matches the partition rules to the + parameters and creates shard functions for each parameter. The shard functions are used to + split up a parameter into shards (or partitions) that can be stored on different devices. + The gather function is used to combine all the shards back together again. + + Args: + parameters: dict: Specify the parameters of the model + partition_rules: Tuple[Tuple[str, PartitionSpec]]: Specify + which parameters to partition + dtype: jax.numpy.dtype | str: Specify the data type of the + parameters + + Returns: + A tuple of three elements: + """ + partition_specs = match_partition_rules(partition_rules, parameters) + shard_fns, gather_fns = make_shard_and_gather_fns( + partition_specs=partition_specs, + dtype_specs=get_dtype(dtype) + ) + return shard_fns, gather_fns, partition_specs + + @staticmethod + def shard_parameters( + mesh: Mesh, + params: Union[FrozenDict, dict], + partition_rules: Tuple[Tuple[str, PartitionSpec]], + serve_config: EasyServeConfig, + ): + + """The shard_parameters function takes a set of parameters and partitions them according to the partition_rules. + + Args: + mesh: Mesh: Create a mesh object that is used to shard the + parameters + params: FrozenDict | dict: Pass in the parameters of the + model + partition_rules: Tuple[Tuple[str, PartitionSpec]]: Specify + the partitioning rules for each parameter + serve_config: EasyServeConfig: Specify the dtype of the + parameters + :param : Create a mesh of devices + + Returns: + sharded parameters + """ + + partition_specs = match_partition_rules(params=params, rules=partition_rules) + shard_fns, _ = make_shard_and_gather_fns(partition_specs, get_dtype(serve_config.dtype)) + + with mesh: + params = jax.tree_map( + lambda func, param: func(param), shard_fns, params + ) + + return params + + @staticmethod + def create_generation_functions_and_tokenizers( + model: EasyDeLFlaxPretrainedModel, + tokenizer: PreTrainedTokenizerBase, + serve_config: EasyServeConfig, + partition_specs: dict[str, PartitionSpec] + ) -> LLMBaseReq: + """The create_generation_functions_and_tokenizers function is used to create the functions that will be used for + generation. It also creates a tokenizer object that can be used to encode and decode text. The function takes in + a model, a tokenizer, an EasyServeConfig object (which contains all the parameters needed for generation), and + partition_specs which are specifications about how data should be partitioned across devices. + + Args: + model: EasyDeLFlaxPretrainedModel: Create the model and + tokenizer + tokenizer: PreTrainedTokenizerBase: Create a tokenizer + object + serve_config: EasyServeConfig: Create the generation + function + partition_specs: dict[str, PartitionSpec]: Specify the + sharding of the model parameters + + Returns: + An LLMBaseReq object + """ + if tokenizer.pad_token is None: + logging.info( + "Tokenizer does not contain padding token setting padding token to eos token for open end generation") + tokenizer.pad_token = tokenizer.eos_token + + try: + tokenizer.padding_side = "left" + tokenizer.truncation_side = "left" + prefix_tokenizer = copy.deepcopy(tokenizer) + tokenizer.padding_side = "right" + tokenizer.truncation_side = "right" + tokenizer = copy.deepcopy(tokenizer) + + except: + warnings.warn( + f"The class Model of Tokenizer {type(tokenizer)} do not support deepcopy option " + ) + if serve_config.use_prefix_tokenizer: + tokenizer.padding_side = "left" + tokenizer.truncation_side = "left" + else: + tokenizer.padding_side = "right" + tokenizer.truncation_side = "right" + prefix_tokenizer = tokenizer + + @functools.partial( + pjit, + in_shardings=(partition_specs, PartitionSpec(), PartitionSpec()), + out_shardings=(PartitionSpec()) + ) + def greedy_generate_function( + parameters, + input_ids, + attention_mask + ): + input_ids = with_sharding_constraint(input_ids, serve_config.generation_ps) + attention_mask = with_sharding_constraint(attention_mask, serve_config.generation_ps) + predict = model.generate( + input_ids, + attention_mask=attention_mask, + params=parameters, + generation_config=GenerationConfig( + max_new_tokens=serve_config.max_compile_tokens, + + eos_token_id=serve_config.eos_token_id or tokenizer.eos_token_id, + pad_token_id=serve_config.pad_token_id or tokenizer.pad_token_id, + bos_token_id=serve_config.bos_token_id or tokenizer.bos_token_id, + + do_sample=False, + num_beams=1, + ) + ).sequences[:, input_ids.shape[1]:] + return predict + + @functools.partial( + pjit, + in_shardings=(partition_specs, PartitionSpec(), PartitionSpec()), + out_shardings=(PartitionSpec()) + ) + def non_greedy_generate_function( + parameters, + input_ids, + attention_mask + ): + input_ids = with_sharding_constraint(input_ids, serve_config.generation_ps) + attention_mask = with_sharding_constraint(attention_mask, serve_config.generation_ps) + predict = model.generate( + input_ids, + attention_mask=attention_mask, + params=parameters, + generation_config=GenerationConfig( + max_new_tokens=serve_config.max_compile_tokens, + + eos_token_id=serve_config.eos_token_id or tokenizer.eos_token_id, + pad_token_id=serve_config.pad_token_id or tokenizer.pad_token_id, + bos_token_id=serve_config.bos_token_id or tokenizer.bos_token_id, + + temperature=serve_config.temperature, + repetition_penalty=serve_config.repetition_penalty, + do_sample=True, + num_beams=1, + top_p=serve_config.top_p, + top_k=serve_config.top_k, + ) + ).sequences[:, input_ids.shape[1]:] + return predict + + return LLMBaseReq( + greedy_generate_function=greedy_generate_function, + non_greedy_generate_function=non_greedy_generate_function, + tokenizer=tokenizer, + prefix_tokenizer=prefix_tokenizer + ) + + @classmethod + def from_parameters( + cls, + llm: EasyDeLFlaxPretrainedModel, + params: dict, + tokenizer: PreTrainedTokenizerBase, + serve_config: EasyServeConfig, + partition_rules: Tuple[Tuple[str, PartitionSpec]], + shard_parameters: bool = True, + ): + + """The from_parameters function is the main entry point for creating a model that can be served. + It takes in a pretrained model, parameters, tokenizer and serve_config as input and returns an object of type + EasyServe. + + Args: + cls: Create a new instance of the class + llm: EasyDeLFlaxPretrainedModel: Pass the model to the class + params: dict: Pass the parameters of the model + tokenizer: PreTrainedTokenizerBase: Create the tokenizer and + prefix_tokenizer + serve_config: EasyServeConfig: Configure the model for + serving + partition_rules: Tuple[Tuple[str, PartitionSpec]]: Partition + the parameters of the model + shard_parameters: bool: Specify whether the parameters + should be sharded or not + :param : Shard the parameters of the model + + Returns: + A EasyServe object + """ + shard_fns, gather_fns, partition_specs = cls.create_shard_and_gather_functions( + parameters=params, + partition_rules=partition_rules, + dtype=serve_config.dtype + ) + llm_base_req = cls.create_generation_functions_and_tokenizers( + model=llm, + tokenizer=tokenizer, + partition_specs=partition_specs, + serve_config=serve_config + ) + + if shard_parameters: + params = cls.shard_parameters( + params=params, + partition_rules=partition_rules, + serve_config=serve_config, + mesh=llm.config.jax_mesh() + ) + + return cls( + llm=llm, + serve_config=serve_config, + tokenizer=llm_base_req.tokenizer, + prefix_tokenizer=llm_base_req.prefix_tokenizer, + params=params, + greedy_generate_function=llm_base_req.greedy_generate_function, + non_greedy_generate_function=llm_base_req.non_greedy_generate_function, + ) + + def sample( + self, + string: str, + *, + greedy: bool = False, + max_new_tokens: int = None, + **kwargs + ): + """The process function is the main function of a model. It takes in an input string and returns a list of strings + that are generated from that input string. The process function can be called multiple times with different inputs, + and each time it will return a new set of outputs based on those inputs. + + Args: + self: Access the class attributes + string: str: Pass the string that we want to generate + greedy: bool: Determine whether to use the greedy or non- + greedy version of the generate function + max_new_tokens: int: Set the number of tokens to generate + **kwargs: Pass any additional parameters to the process + function + + Returns: + A generator that yields the predicted text and the number of + tokens generated + """ + with self.llm.config.jax_mesh(): + fixed_pad = self.serve_config.max_sequence_length - self.serve_config.max_compile_tokens + tokens = self.prefix_tokenizer( + string, + max_length=fixed_pad, + padding="max_length", + return_tensors="jax" + ) if self.serve_config.use_prefix_tokenizer else self.tokenizer( + string, + return_tensors="jax" + ) + + input_ids = tokens.input_ids + attention_mask = tokens.attention_mask + num_generated_tokens = 0 + + for _ in range( + (max_new_tokens or self.serve_config.max_new_tokens) // self.serve_config.max_compile_tokens): + + predicted_token = self.get_generation_function(greedy=greedy)( + self.params, + input_ids, + attention_mask + ) + + num_generated_tokens += predicted_token.shape[-1] + plus_attn_mask = jnp.ones( + (len(attention_mask), self.serve_config.max_compile_tokens), + dtype="i4" + ) + + input_ids = jnp.concatenate( + (input_ids, predicted_token), dtype="i4", + axis=-1 + )[:, -fixed_pad:] + + attention_mask = jnp.concatenate( + (attention_mask, plus_attn_mask), dtype="i4", + axis=-1 + )[:, -fixed_pad:] + + returns = ( + self.tokenizer.decode( + input_ids[0][-num_generated_tokens:], # type:ignore + skip_special_tokens=True + ), + num_generated_tokens + ) + + yield returns + + if self.serve_config.use_mxn_break_point: + if self.serve_config.max_compile_tokens != predicted_token.shape[-1]: + break + if ( + predicted_token[0][-1] == (self.serve_config.eos_token_id or self.tokenizer.eos_token_id) + or + predicted_token[0][-1] == (self.serve_config.eos_token_id or self.prefix_tokenizer.eos_token_id) + ): + break + + def compile(self, verbose: bool = True) -> bool: + """The compile function is used to compile the model for use in inference. + It does this by running through all possible combinations of rules and actions, + and compiling them into functions that can be called later on during inference. + This allows us to avoid having to recompile the model every time we want to run it, + which would be very slow. + + Args: + self: Represent the instance of the class + verbose: bool: Print out the compiling process + + Returns: + True, but what does it do? + """ + if self.serve_config.use_prefix_tokenizer: + if verbose: + logger.info("Compiling greedy generate function") + response, tokens = [None] * 2 + for response, tokens in self.sample( + string="", + max_new_tokens=self.serve_config.max_compile_tokens, + greedy=True + ): + ... + if verbose: + logger.info("Compiling non-greedy generate function") + for response, tokens in self.sample( + string="", + max_new_tokens=self.serve_config.max_compile_tokens, + greedy=False + ): + ... + + else: + warnings.warn( + "Skip Compiling the compiling process is useless " + "when you are not using prefix tokenizer", + ) + return True + + def __repr__(self): + + """The __repr__ function is used to generate a string representation of an object. + This function should return a string that can be parsed by the Python interpreter + to recreate the object. The __repr__ function is called when you use print() on an + object, or when you type its name in the REPL. + + Args: + self: Refer to the instance of the class + + Returns: + A string representation of the object + """ + string = f"{self.__class__.__name__}(\n" + for k, v in self.__dict__.items(): + if not k.startswith("_"): + try: + repr_src = f"\t{k} : " + v.__str__().replace("\n", "\n\t") + "\n" + string += repr_src if len(repr_src) < 500 else f"\t{k} : " + f"{v.__class__.__name__}(...)" + "\n" + except TypeError: + ... + return string + ")" + + def __str__(self): + + """The __str__ function is called when you use the print function or when str() is used. + It should return a string representation of the object. + + Args: + self: Refer to the instance of the class + + Returns: + The object's string representation + """ + return self.__repr__() + + def fire(self): + async def run_engine(): + async with websockets.serve(self.handle_client, self.serve_config.host, self.serve_config.port) as ws: + logger.info(f"Starting EasyDeL websocket server on {self.serve_config.host}:{self.serve_config.port}") + await ws.wait_closed() + + asyncio.run(run_engine()) diff --git a/src/python/easydel/serve/torch_serve.py b/src/python/easydel/serve/torch_serve.py index b19faee28..417b1cada 100644 --- a/src/python/easydel/serve/torch_serve.py +++ b/src/python/easydel/serve/torch_serve.py @@ -108,15 +108,17 @@ def __str__(self): class PyTorchServer(GradioUserInference): def __init__(self, server_config: PyTorchServerConfig): - """ - The __init__ function is called when the class is instantiated. + """The __init__ function is called when the class is instantiated. It sets up the instance of the class, and defines all its attributes. The __init__ function can accept arguments, which are passed at instantiation. - :param self: Represent the instance of the class - :param server_config: PyTorchServerConfig: Pass the configuration parameters to the class - :return: The app, which is a fastapi object - + Args: + self: Represent the instance of the class + server_config: PyTorchServerConfig: Pass the configuration + parameters to the class + + Returns: + The app, which is a fastapi object """ self.model, self.tokenizer = [None] * 2 @@ -141,12 +143,13 @@ def __init__(self, server_config: PyTorchServerConfig): @staticmethod def get_gpu_memory(num_gpus_req=None): - """ - The get_gpu_memory function returns the amount of available GPU memory in GB. + """The get_gpu_memory function returns the amount of available GPU memory in GB. - :param num_gpus_req: Specify the number of gpus to be used - :return: The amount of free memory on each gpu - + Args: + num_gpus_req: Specify the number of gpus to be used + + Returns: + The amount of free memory on each gpu """ gpu_m = [] dc = torch.cuda.device_count() @@ -160,12 +163,13 @@ def get_gpu_memory(num_gpus_req=None): return gpu_m def get_model_load_kwargs(self): - """ - The get_model_load_kwargs function is used to set the torch_dtype, device_map and max_memory parameters for loading a model. + """The get_model_load_kwargs function is used to set the torch_dtype, device_map and max_memory parameters for loading a model. + + Args: + self: Bind the method to an object - :param self: Bind the method to an object - :return: A dictionary with the following keys: - + Returns: + A dictionary with the following keys: """ if self.server_config.dtype == "fp16": dtype = torch.float16 @@ -184,8 +188,7 @@ def get_model_load_kwargs(self): def status(self): - """ - The status function returns a dictionary with the following keys: + """The status function returns a dictionary with the following keys: server_config: A dictionary of configuration parameters. devices: The number of GPUs available to the server. device_sharding: Whether device sharding is enabled. If True, then each request will be served by @@ -194,9 +197,11 @@ def status(self): initialization function via torch-serve"s DeviceShardingStrategy class. See https://pytorch-lightning.readthedoc - :param self: Represent the instance of the class - :return: A dictionary with the following keys: - + Args: + self: Represent the instance of the class + + Returns: + A dictionary with the following keys: """ return { "server_config": {k: v for k, v in self.server_config.__dict__.items()}, @@ -208,17 +213,19 @@ def status(self): } def forward_instruct_fast_api(self, data: InstructRequest): - """ - The forward_instruct_fast_api function is a ReST API endpoint that takes in an InstructRequest object and returns + """The forward_instruct_fast_api function is a ReST API endpoint that takes in an InstructRequest object and returns a response. The InstructRequest object contains the following fields: - system (str): A string representing the name of the system to be instructed. This should match one of the systems defined in your server_config file, or else it will default to "default". If you want to instruct multiple systems at once, use forward_instruct_fast instead. - :param self: Refer to the object itself - :param data: InstructRequest: Pass in the data that is used to generate the response - :return: A dictionary with a single key, response - + Args: + self: Refer to the object itself + data: InstructRequest: Pass in the data that is used to + generate the response + + Returns: + A dictionary with a single key, response """ string = self.format_instruct( system=data.system, @@ -238,14 +245,16 @@ def forward_instruct_fast_api(self, data: InstructRequest): } def forward_chat_fast_api(self, data: ChatRequest): - """ - The forward_chat_fast_api function is a ReST API endpoint that takes in a ChatRequest object and returns the + """The forward_chat_fast_api function is a ReST API endpoint that takes in a ChatRequest object and returns the response from the model. - :param self: Refer to the object itself - :param data: ChatRequest: Pass the data from the serve_engine to the function - :return: A dictionary with a single key, response - + Args: + self: Refer to the object itself + data: ChatRequest: Pass the data from the serve_engine to + the function + + Returns: + A dictionary with a single key, response """ string = self.format_chat( system=data.system, @@ -266,9 +275,7 @@ def forward_chat_fast_api(self, data: ChatRequest): } def format_instruct(self, system: str, instruction: str) -> str: - """ - Here you will get the system and instruction from user, and you can apply your prompting style - """ + """Here you will get the system and instruction from user, and you can apply your prompting style""" conversation = [] if system is not None and system != "": conversation.append({ @@ -284,9 +291,7 @@ def format_instruct(self, system: str, instruction: str) -> str: ) def format_chat(self, history: List[List[str]], prompt: str, system: typing.Union[str, None]) -> str: - """ - Here you will get the system, prompt and history from user, and you can apply your prompting style - """ + """Here you will get the system, prompt and history from user, and you can apply your prompting style""" conversation = [] if system is not None and system != "": conversation.append({ @@ -327,21 +332,29 @@ def sample( stream: bool = True, sample: bool = True ): - """ - The sample function is the main function of this class. It takes a string as input and returns a generator that yields strings. - - :param self: Represent the instance of the class - :param string: str: Pass the string to be generated - :param max_new_tokens: Optional[int]: Limit the number of new tokens that can be generated - :param max_sequence_length: Optional[int]: Set the maximum length of the generated text - :param temperature: Optional[float]: Control the randomness of the text generation - :param top_k:Optional[int]: Filter out the top k tokens with the highest probability - :param top_p:Optional[int]: Control the probability of sampling from the top n tokens - :param repetition_penalty: optional[float]: repetition penalty for generation - :param stream: bool: Determine whether to stream the output or not - :param sample: optional[bool]: Indicate whether to sample from the distribution or take the argmax - :return: A generator - + """The sample function is the main function of this class. It takes a string as input and returns a generator that yields strings. + + Args: + self: Represent the instance of the class + string: str: Pass the string to be generated + max_new_tokens: Optional[int]: Limit the number of new + tokens that can be generated + max_sequence_length: Optional[int]: Set the maximum length + of the generated text + temperature: Optional[float]: Control the randomness of the + text generation + top_k: Optional[int]: Filter out the top k tokens with the + highest probability + top_p: Optional[int]: Control the probability of sampling + from the top n tokens + repetition_penalty: optional[float]: repetition penalty for + generation + stream: bool: Determine whether to stream the output or not + sample: optional[bool]: Indicate whether to sample from the + distribution or take the argmax + + Returns: + A generator """ assert self.model is not None, "you should first load model with ``load`` method" tokens = self.tokenizer( @@ -405,16 +418,20 @@ def sample( return pred def load(self, pretrained_model_name_or_path: str, tokenizer_repo: str = None, auto_config: bool = True, **kwargs): - """ - The load function is used to load a model from the HuggingFace Model Hub. - - :param self: Represent the instance of the class - :param pretrained_model_name_or_path: str: Specify the name of the model to be loaded - :param tokenizer_repo: str: Specify the repo id of the tokenizer - :param auto_config: bool: Determine whether the model should be loaded with a server_config file or not - :param kwargs: Pass a variable number of keyword arguments to the function - :return: A tuple of model and tokenizer - + """The load function is used to load a model from the HuggingFace Model Hub. + + Args: + self: Represent the instance of the class + pretrained_model_name_or_path: str: Specify the name of the + model to be loaded + tokenizer_repo: str: Specify the repo id of the tokenizer + auto_config: bool: Determine whether the model should be + loaded with a server_config file or not + **kwargs: Pass a variable number of keyword arguments to the + function + + Returns: + A tuple of model and tokenizer """ load_kwargs = kwargs if not auto_config else self.get_model_load_kwargs() load_kwargs = load_kwargs | kwargs @@ -486,12 +503,13 @@ def gradio_inference(self): ) def fire(self): - """ - The fire function starts the uvicorn server in a separate process. + """The fire function starts the uvicorn server in a separate process. + + Args: + self: Represent the instance of the class - :param self: Represent the instance of the class - :return: A process that runs the uvicorn server - + Returns: + A process that runs the uvicorn server """ def run(): @@ -501,13 +519,14 @@ def run(): self.process_uvicorn.start() def end(self): - """ - The end function is used to stop the server. + """The end function is used to stop the server. It will wait for the process to end before returning. - :param self: Represent the instance of the class - :return: A boolean value - + Args: + self: Represent the instance of the class + + Returns: + A boolean value """ if self.process_uvicorn is not None: self.process_uvicorn.join() diff --git a/src/python/easydel/serve/utils.py b/src/python/easydel/serve/utils.py index e36879ef0..be53be110 100644 --- a/src/python/easydel/serve/utils.py +++ b/src/python/easydel/serve/utils.py @@ -58,22 +58,27 @@ def __init__( "monospace", ), ): - """ - The __init__ function is called when the class is instantiated. + """The __init__ function is called when the class is instantiated. It sets up the object with all of its instance variables and other things it needs to function properly. - - :param self: Represent the instance of the object - :param *: Unpack the list of parameters into a tuple - :param primary_hue: Union[colors.Color,str]: Set the primary color of the theme - :param secondary_hue: Union[colors.Color,str]: Set the secondary color of the theme - :param neutral_hue: Union[colors.Color,str]: Set the neutral color of the theme - :param spacing_size: Union[sizes.Size,str]: Set the spacing size of the theme - :param radius_size: Union[sizes.Size,str]: Set the radius of the buttons and other elements - :param text_size: Union[sizes.Size,str]: Set the size of the text in the app - - :return: The class object - + Args: + self: Represent the instance of the object + : Unpack the list of parameters into a tuple + primary_hue: Union[colors.Color,str]: Set the primary color + of the theme + secondary_hue: Union[colors.Color,str]: Set the secondary + color of the theme + neutral_hue: Union[colors.Color,str]: Set the neutral color + of the theme + spacing_size: Union[sizes.Size,str]: Set the spacing size of + the theme + radius_size: Union[sizes.Size,str]: Set the radius of the + buttons and other elements + text_size: Union[sizes.Size,str]: Set the size of the text + in the app + + Returns: + The class object """ super().__init__( @@ -179,8 +184,9 @@ def create_generate_function( :param logits_processor :LogitsProcessor: Processor for model logits. Defaults to None. :param return_prediction_only :bool: Whether to return only the generated sequences. Defaults to True. - :return :Callable[[Any, chex.Array, chex.Array], chex.Array]: Sharded function for text generation. - + Returns: + Callable[[Any, chex.Array, chex.Array], chex.Array]: Sharded + function for text generation. """ def generate_fn( @@ -190,10 +196,14 @@ def generate_fn( ) -> chex.Array: """Generate text sequences using the provided model and parameters. - :param parameters:Union[dict, jax.tree_util.PyTreeDef]: Model parameters. - :param input_ids: chex.Array: Input token IDs. - :param attention_mask:chex.Array: Attention mask. - :return: Generated array sequences. + Args: + parameters: Union[dict, jax.tree_util.PyTreeDef]: Model + parameters. + input_ids: chex.Array: Input token IDs. + attention_mask: chex.Array: Attention mask. + + Returns: + Generated array sequences. """ input_ids = with_sharding_constraint( input_ids, diff --git a/src/python/easydel/smi/smi.py b/src/python/easydel/smi/smi.py index 71d3ed4cd..ae3011d40 100644 --- a/src/python/easydel/smi/smi.py +++ b/src/python/easydel/smi/smi.py @@ -14,18 +14,20 @@ # Edited version of Jax-SMI from https://github.com/ayaka14732/jax-smi/ def run(note_book=None, interval: float = 1, dir_prefix: str = '/dev/shm', dpr=True): - """ - The run function is a simple wrapper around the go tool pprof command. + """The run function is a simple wrapper around the go tool pprof command. It runs the command every interval seconds and prints out its output to stdout. If you are running this in a notebook, it will print to IPython's display instead of stdout. + Args: + note_book: Determine whether the program is running in a + notebook or not + interval: float: Specify the time interval between each refresh + dir_prefix: str: Specify the directory where the memory + dpr: Control whether the output is displayed in a notebook or + not - :param note_book: Determine whether the program is running in a notebook or not - :param interval: float: Specify the time interval between each refresh - :param dir_prefix: str: Specify the directory where the memory - :param dpr: Control whether the output is displayed in a notebook or not - :return: The output of the pprof command - + Returns: + The output of the pprof command """ if note_book is None: import os @@ -62,15 +64,16 @@ def is_notebook(): def get_mem(dir_prefix: str = "/dev/shm" if sys.platform != "win32" else "."): - """ - The get_mem function is a wrapper around the go tool pprof command. + """The get_mem function is a wrapper around the go tool pprof command. It takes in an optional argument, dir_prefix, which defaults to /dev/shm. The function then runs the go tool pprof command with arguments -tags and dir_prefix/memory.prof, and returns its stdout as a string. - :param dir_prefix: str: Specify the directory where - :return: A string of the memory profile - + Args: + dir_prefix: str: Specify the directory where + + Returns: + A string of the memory profile """ return subprocess.run( args=['go', 'tool', 'pprof', '-tags', f'{dir_prefix}/memory.prof'], @@ -81,13 +84,16 @@ def get_mem(dir_prefix: str = "/dev/shm" if sys.platform != "win32" else "."): def initialise_tracking(interval: float = 0.5, dir_prefix: str = "/dev/shm" if sys.platform != "win32" else ".") -> None: - """ - The initialise_tracking function starts a daemon thread that periodically saves the current memory profile to disk. + """The initialise_tracking function starts a daemon thread that periodically saves the current memory profile to disk. + + Args: + interval: float: Specify the time interval between each memory + profile + dir_prefix: str: Specify the directory where the memory profile + will be saved - :param interval: float: Specify the time interval between each memory profile - :param dir_prefix: str: Specify the directory where the memory profile will be saved - :return: Nothing, but it starts a thread that - + Returns: + Nothing, but it starts a thread that """ def inner(): diff --git a/src/python/easydel/trainer/base_trainer.py b/src/python/easydel/trainer/base_trainer.py index 988971bed..59e276bb1 100644 --- a/src/python/easydel/trainer/base_trainer.py +++ b/src/python/easydel/trainer/base_trainer.py @@ -67,8 +67,7 @@ def __init__( checkpoint_path: Union[str, os.PathLike] = None, _do_init_fns: bool = True ): - """ - The __init__ function is called when the class is instantiated. + """The __init__ function is called when the class is instantiated. It sets up all the variables that are needed for training, including: - The timer to keep track of how long each epoch takes. - The dataloaders for both training and evaluation (if provided). @@ -78,15 +77,19 @@ def __init__( or loaded from a checkpoint file (see below). This means that you can pass in either - :param self: Represent the instance of the class - :param arguments: TrainArguments: Pass the arguments to the trainer - :param dataset_train: Dataset: Pass the training dataset to the trainer - :param dataset_eval: Dataset: Pass the validation dataset - :param finetune: bool: Load the model from a checkpoint - :param checkpoint_path: Union[str,os.PathLike] : Load the checkpoint path - :param _do_init_fns: bool: Initialize the functions - :return: Nothing, it just initializes the class - + Args: + self: Represent the instance of the class + arguments: TrainArguments: Pass the arguments to the trainer + dataset_train: Dataset: Pass the training dataset to the + trainer + dataset_eval: Dataset: Pass the validation dataset + finetune: bool: Load the model from a checkpoint + checkpoint_path: Union[str,os.PathLike] : Load the + checkpoint path + _do_init_fns: bool: Initialize the functions + + Returns: + Nothing, it just initializes the class """ # Loggers self.timer = getattr(self, "timer", None) @@ -177,12 +180,11 @@ def __repr__(self): @staticmethod def finish(): - """ - The finish function is called when the experiment ends. + """The finish function is called when the experiment ends. It can be used to save data, upload files, or do any other cleanup tasks. - :return: A dictionary of the run's metadata - + Returns: + A dictionary of the run's metadata """ wandb.finish() @@ -203,16 +205,17 @@ def _start(): return threading.Thread(target=_start) def initialize_trainer_utils(self): - """ - The initialize_trainer_utils function is responsible for initializing the following: + """The initialize_trainer_utils function is responsible for initializing the following: - wandb_runtime (if you use_wandb is True) - timer object (for logging time taken by various functions) - dataloader objects for training and evaluation data, along with max steps per epoch. The configure_dataloader function accomplishes this task. - :param self: Represent the instance of the class - :return: A tuple of functions + Args: + self: Represent the instance of the class + Returns: + A tuple of functions """ self.wandb_runtime = None if self.arguments.use_wandb: @@ -279,24 +282,27 @@ def create_collate_function( @abc.abstractmethod def configure_functions(self) -> TrainerConfigureFunctionFuncOutput: - """ - The configure_functions function is responsible for configuring the functions that will be used in training. + """The configure_functions function is responsible for configuring the functions that will be used in training. It does this by first defining a function called function_configurations, which initializes the model parameters and returns them as a EasyDeLState object. The EasyDeLState object contains all the information needed to train or evaluate on a batch of data, including: - :param self: Access the class attributes - :return: A TrainerConfigureFunctionFuncOutput object + Args: + self: Access the class attributes + + Returns: + A TrainerConfigureFunctionFuncOutput object """ raise NotImplementedError def configure_dataloader(self) -> TrainerConfigureDataloaderFuncOutput: - """ - The configure_dataloader function is used to configure the dataloader for training and evaluation. + """The configure_dataloader function is used to configure the dataloader for training and evaluation. - :param self: Refer to the class instance itself - :return: A TrainerConfigureDataloaderFuncOutput object + Args: + self: Refer to the class instance itself + Returns: + A TrainerConfigureDataloaderFuncOutput object """ def create_tf_dataset(dataset: Dataset, is_train: bool) -> Iterator[ndarray[Any, Any]]: @@ -332,9 +338,7 @@ def create_tf_dataset_from_iterable(dataset: IterableDataset, is_train: bool) -> ) def calculate_steps(dataset: Union[Dataset, IterableDataset], is_train: bool): - """ - Return total number of steps to train or evaluate on. - """ + """Return total number of steps to train or evaluate on.""" if hasattr(dataset, "__len__"): num_steps = len(dataset) * (self.arguments.num_train_epochs if is_train else 1) max_steps = self.arguments.max_training_steps if is_train else self.arguments.max_evaluation_steps @@ -369,12 +373,14 @@ def to_tf_dataloader(dataset: Union[Dataset, IterableDataset], is_train: bool): ) def configure_model(self) -> TrainerConfigureModelFuncOutput: - """ - The configure_model function is responsible for creating the model, optimizer and scheduler. + """The configure_model function is responsible for creating the model, optimizer and scheduler. - :param self: Represent the instance of the class - :return: A model, optimizer, scheduler and config in TrainerConfigureModelFuncOutput Object + Args: + self: Represent the instance of the class + Returns: + A model, optimizer, scheduler and config in + TrainerConfigureModelFuncOutput Object """ extra_configs = {} if self.arguments.extra_configs is None else self.arguments.extra_configs if self.arguments.model_class is not None: @@ -452,12 +458,8 @@ def _save_state( @abc.abstractmethod def train(self): - """ - abstract of Train Function to train model - """ + """abstract of Train Function to train model""" @abc.abstractmethod def eval(self, state): - """ - abstract of Eval Function to evaluate model - """ + """abstract of Eval Function to evaluate model""" diff --git a/src/python/easydel/trainer/causal_language_model_trainer/__init__.py b/src/python/easydel/trainer/causal_language_model_trainer/__init__.py index 550d09952..9fa030717 100644 --- a/src/python/easydel/trainer/causal_language_model_trainer/__init__.py +++ b/src/python/easydel/trainer/causal_language_model_trainer/__init__.py @@ -1,15 +1,15 @@ -from .causal_language_model_trainer import ( - CausalLanguageModelTrainer as CausalLanguageModelTrainer, - CausalLMTrainerOutput as CausalLMTrainerOutput -) -from .fwd_bwd_functions import ( - create_casual_language_model_train_step as create_casual_language_model_train_step, - create_casual_language_model_evaluation_step as create_casual_language_model_evaluation_step -) - -__all__ = ( - "create_casual_language_model_train_step", - "create_casual_language_model_evaluation_step", - "CausalLanguageModelTrainer", - "CausalLMTrainerOutput" -) +from .causal_language_model_trainer import ( + CausalLanguageModelTrainer as CausalLanguageModelTrainer, + CausalLMTrainerOutput as CausalLMTrainerOutput +) +from .fwd_bwd_functions import ( + create_casual_language_model_train_step as create_casual_language_model_train_step, + create_casual_language_model_evaluation_step as create_casual_language_model_evaluation_step +) + +__all__ = ( + "create_casual_language_model_train_step", + "create_casual_language_model_evaluation_step", + "CausalLanguageModelTrainer", + "CausalLMTrainerOutput" +) diff --git a/src/python/easydel/trainer/causal_language_model_trainer/causal_language_model_trainer.py b/src/python/easydel/trainer/causal_language_model_trainer/causal_language_model_trainer.py index 2d8c711c0..d67ed2b2d 100644 --- a/src/python/easydel/trainer/causal_language_model_trainer/causal_language_model_trainer.py +++ b/src/python/easydel/trainer/causal_language_model_trainer/causal_language_model_trainer.py @@ -57,14 +57,16 @@ def collate_fn(batch): return collate_fn def configure_functions(self) -> TrainerConfigureFunctionFuncOutput: - """ - The configure_functions function is responsible for configuring the functions that will be used in training. + """The configure_functions function is responsible for configuring the functions that will be used in training. It does this by first defining a function called function_configurations, which initializes the model parameters and returns them as a EasyDeLState object. The EasyDeLState object contains all the information needed to train or evaluate on a batch of data, including: - :param self: Access the class attributes - :return: A TrainerConfigureFunctionFuncOutput object + Args: + self: Access the class attributes + + Returns: + A TrainerConfigureFunctionFuncOutput object """ def initialize_state_function(): @@ -355,18 +357,20 @@ def train( model_parameters: Optional[flax.core.FrozenDict] = None, state: Optional[EasyDeLState] = None ) -> CausalLMTrainerOutput: - """ - The train function is the main function of this module. + """The train function is the main function of this module. It takes a model_parameters argument which can be used to load a pretrained model and finetune it. The train function returns an CausalLMTrainerOutput object that contains the last saved file name, predict func, train state, mesh and checkpoint streamer. + Args: + self: Make the class methods aware of other methods and + attributes within the class + model_parameters: flax.core.FrozenDict: Load a pre-trained + model + state: Optional[EasyDeLState]: Ready to Use State - :param self: Make the class methods aware of other methods and attributes within the class - :param model_parameters: flax.core.FrozenDict: Load a pre-trained model - :param state: Optional[EasyDeLState]: Ready to Use State - :return: An object of type "CausalLMTrainerOutput" - + Returns: + An object of type "CausalLMTrainerOutput" """ def get_layer_names(frozen_dict, prefix=""): diff --git a/src/python/easydel/trainer/causal_language_model_trainer/fwd_bwd_functions.py b/src/python/easydel/trainer/causal_language_model_trainer/fwd_bwd_functions.py index bfbece03b..a9cc69475 100644 --- a/src/python/easydel/trainer/causal_language_model_trainer/fwd_bwd_functions.py +++ b/src/python/easydel/trainer/causal_language_model_trainer/fwd_bwd_functions.py @@ -1,175 +1,188 @@ -from fjformer.func.loss_func import ( - cross_entropy_loss_and_accuracy, - SpecialLossNormalizingFactor, - get_loss_normalizing_factor_and_weights, - compute_weighted_cross_entropy_and_accuracy, -) - -import jax -from jax.sharding import PartitionSpec -from jax import numpy as jnp -from fjformer import with_sharding_constraint - - -def create_casual_language_model_train_step( - partition_spec=PartitionSpec(("dp", "fsdp"), "sp"), - label_smoothing_factor=0.0, - z_loss=0.0, - gradient_accumulation_steps: int = 1, -): - """ - The create_casual_language_model_train_step function is a training step function that takes in the current state - of the model,and a batch of data. It then calculates the loss and accuracy for this batch, and returns - an updated state with new parameters based on these gradients. - - :param partition_spec: Specify which devices the model will be split across - :param label_smoothing_factor: A float in [0, 1] specifying the amount of label smoothing to apply, - where 0 means no smoothing. - :param z_loss: A regularization term that adds a penalty for large weights, where 0 means no regularization. - :param gradient_accumulation_steps: int : gradient accumulation step size from arguments - :return: A casual_language_model_train_step function that takes in the current state of the model, - """ - assert gradient_accumulation_steps > 0, "gradient_accumulation_steps must be greater than 0" # Ignore - - def casual_language_model_train_step(state, batch): - """ - The casual_language_model_train_step function is a training step function that takes in the current state - of the model and a batch of data. It then calculates the loss and accuracy for this batch, - and returns an updated state with new parameters based on these gradients. - - :param state: Store the model parameters - :param batch: Pass the data to the model, dict with - input_ids(bs, seq_len), labels(bs, seq_len-1), attention_mask(bs, seq_len) - :return: A tuple of (state, loss, accuracy) - """ - batch = with_sharding_constraint(batch, partition_spec) - - def calculate_loss(params): - labels = batch.get("labels", None) - if labels is None: - labels = batch["input_ids"][..., 1:] - else: - labels = labels[..., 1:] - model_outputs = state.apply_fn(params=params, **batch, return_dict=True) - logits = model_outputs.logits - aux_loss = getattr(model_outputs, "aux_loss", None) - loss_normalizing_factor = ( - SpecialLossNormalizingFactor.NUM_REAL_TARGET_TOKENS - ) - # loss_weights is 1 unless the label is <= 0 or the attention mask is 0 - loss_weights = jnp.where( - (batch["attention_mask"][:, 1:] != 0) & (labels > 0), 1, 0 - ) - lnf, weights = get_loss_normalizing_factor_and_weights( - loss_normalizing_factor, - { - "decoder_target_tokens": labels, - "decoder_loss_weights": loss_weights, - }, - ) - ( - loss, - z_loss_computed, - weight_sum, - accuracy, - ) = compute_weighted_cross_entropy_and_accuracy( - logits=logits[:, :-1, :], - targets=labels, - weights=weights, - label_smoothing=label_smoothing_factor, - z_loss=z_loss, - loss_normalizing_factor=lnf, - ) - if aux_loss is not None: - loss += aux_loss - return loss, (accuracy, z_loss_computed, aux_loss) - - grad_fn = jax.value_and_grad(calculate_loss, has_aux=True) - (loss__, (accuracy__, z_loss_computed__, aux_loss__)), grad = grad_fn(state.params) - state = state.apply_gradients(grads=grad) - - grad_norms = jax.tree_map(jnp.linalg.norm, grad) - max_grad_norm = jax.tree_util.tree_reduce(jnp.maximum, grad_norms) - mean_grad_norm = jax.tree_util.tree_reduce( - jnp.add, jax.tree_map(jnp.sum, grad_norms) - ) / jax.tree_util.tree_reduce(jnp.add, jax.tree_map(jnp.size, grad_norms)) - metrics = { - "accuracy": accuracy__, - "regularization_z_loss": z_loss_computed__, - "max_grad_norm": max_grad_norm, - "mean_grad_norm": mean_grad_norm, - "grad_norms": grad_norms, - } - if aux_loss__ is not None: - metrics.update({"aux_loss": aux_loss__}) - return state, loss__, metrics - - return casual_language_model_train_step - - -def create_casual_language_model_evaluation_step( - partition_spec=PartitionSpec(("dp", "fsdp"), "sp") -): - """ - The create_casual_language_model_evaluation_step function is used to create a function that calculates the loss - and accuracy of a model. It takes in a set of parameters, which are then passed into the state.apply_fn function - to generate logits for each token in the batch. The cross entropy loss and accuracy are then calculated from these - logits. - - :param partition_spec: Specify the partitioning of the model parameters - :return: A function that can be used to calculate the loss and accuracy of a model - - """ - - def casual_language_model_evaluation_step(state, batch_eval): - """ - The casual_language_model_evaluation_step function is used to calculate the loss and accuracy of a model. - It takes in a set of parameters, which are then passed into the state.apply_fn function - to generate logits for each token in the batch. The cross entropy loss and accuracy are then calculated from - these logits. - - :param state: Store the model parameters and other information about the training process - :param batch_eval: Pass the batch of data to the function - :return: The loss and accuracy of the model - - """ - batch_eval = with_sharding_constraint(batch_eval, partition_spec) - - def calculate_loss(params): - """ - The calculate_loss function is used to calculate the loss and accuracy of a model. - It takes in a set of parameters, which are then passed into the state.apply_fn function - to generate logits for each token in the batch. The cross entropy loss and accuracy are then calculated - from these logits. - - :param params: Pass the model parameters to the function - :return: The loss and the accuracy - - """ - labels = batch_eval.get("labels", None) - if labels is None: - labels = batch_eval["input_ids"][..., 1:] - else: - labels = labels[..., 1:] - model_outputs = state.apply_fn(params=params, **batch_eval, return_dict=True) - logits = model_outputs.logits - aux_loss = getattr(model_outputs, "aux_loss", None) - valid = jnp.where( - (batch_eval["attention_mask"][:, 1:].astype(jnp.float32) != 0) - & (labels > 0), - 1.0, - 0.0, - ) - loss, accuracy = cross_entropy_loss_and_accuracy( - logits[:, :-1, :], - labels, - valid, - ) - if aux_loss is not None: - loss += aux_loss - return loss, (accuracy, aux_loss) - - loss__, (accuracy__, aux_loss__) = calculate_loss(state.params) - return loss__, accuracy__, aux_loss__ - - return casual_language_model_evaluation_step +from fjformer.func.loss_func import ( + cross_entropy_loss_and_accuracy, + SpecialLossNormalizingFactor, + get_loss_normalizing_factor_and_weights, + compute_weighted_cross_entropy_and_accuracy, +) + +import jax +from jax.sharding import PartitionSpec +from jax import numpy as jnp +from fjformer import with_sharding_constraint + + +def create_casual_language_model_train_step( + partition_spec=PartitionSpec(("dp", "fsdp"), "sp"), + label_smoothing_factor=0.0, + z_loss=0.0, + gradient_accumulation_steps: int = 1, +): + """The create_casual_language_model_train_step function is a training step function that takes in the current state + of the model,and a batch of data. It then calculates the loss and accuracy for this batch, and returns + an updated state with new parameters based on these gradients. + + Args: + partition_spec: Specify which devices the model will be split + across + label_smoothing_factor: A float in [0, 1] specifying the amount + of label smoothing to apply, where 0 means no smoothing. + z_loss: A regularization term that adds a penalty for large + weights, where 0 means no regularization. + gradient_accumulation_steps: int : gradient accumulation step + size from arguments + + Returns: + A casual_language_model_train_step function that takes in the + current state of the model, + """ + assert gradient_accumulation_steps > 0, "gradient_accumulation_steps must be greater than 0" # Ignore + + def casual_language_model_train_step(state, batch): + """The casual_language_model_train_step function is a training step function that takes in the current state + of the model and a batch of data. It then calculates the loss and accuracy for this batch, + and returns an updated state with new parameters based on these gradients. + + Args: + state: Store the model parameters + batch: Pass the data to the model, dict with input_ids(bs, + seq_len), labels(bs, seq_len-1), attention_mask(bs, + seq_len) + + Returns: + A tuple of (state, loss, accuracy) + """ + batch = with_sharding_constraint(batch, partition_spec) + + def calculate_loss(params): + labels = batch.get("labels", None) + if labels is None: + labels = batch["input_ids"][..., 1:] + else: + labels = labels[..., 1:] + model_outputs = state.apply_fn(params=params, **batch, return_dict=True) + logits = model_outputs.logits + aux_loss = getattr(model_outputs, "aux_loss", None) + loss_normalizing_factor = ( + SpecialLossNormalizingFactor.NUM_REAL_TARGET_TOKENS + ) + # loss_weights is 1 unless the label is <= 0 or the attention mask is 0 + loss_weights = jnp.where( + (batch["attention_mask"][:, 1:] != 0) & (labels > 0), 1, 0 + ) + lnf, weights = get_loss_normalizing_factor_and_weights( + loss_normalizing_factor, + { + "decoder_target_tokens": labels, + "decoder_loss_weights": loss_weights, + }, + ) + ( + loss, + z_loss_computed, + weight_sum, + accuracy, + ) = compute_weighted_cross_entropy_and_accuracy( + logits=logits[:, :-1, :], + targets=labels, + weights=weights, + label_smoothing=label_smoothing_factor, + z_loss=z_loss, + loss_normalizing_factor=lnf, + ) + if aux_loss is not None: + loss += aux_loss + return loss, (accuracy, z_loss_computed, aux_loss) + + grad_fn = jax.value_and_grad(calculate_loss, has_aux=True) + (loss__, (accuracy__, z_loss_computed__, aux_loss__)), grad = grad_fn(state.params) + state = state.apply_gradients(grads=grad) + + grad_norms = jax.tree_map(jnp.linalg.norm, grad) + max_grad_norm = jax.tree_util.tree_reduce(jnp.maximum, grad_norms) + mean_grad_norm = jax.tree_util.tree_reduce( + jnp.add, jax.tree_map(jnp.sum, grad_norms) + ) / jax.tree_util.tree_reduce(jnp.add, jax.tree_map(jnp.size, grad_norms)) + metrics = { + "accuracy": accuracy__, + "regularization_z_loss": z_loss_computed__, + "max_grad_norm": max_grad_norm, + "mean_grad_norm": mean_grad_norm, + "grad_norms": grad_norms, + } + if aux_loss__ is not None: + metrics.update({"aux_loss": aux_loss__}) + return state, loss__, metrics + + return casual_language_model_train_step + + +def create_casual_language_model_evaluation_step( + partition_spec=PartitionSpec(("dp", "fsdp"), "sp") +): + """The create_casual_language_model_evaluation_step function is used to create a function that calculates the loss + and accuracy of a model. It takes in a set of parameters, which are then passed into the state.apply_fn function + to generate logits for each token in the batch. The cross entropy loss and accuracy are then calculated from these + logits. + + Args: + partition_spec: Specify the partitioning of the model parameters + + Returns: + A function that can be used to calculate the loss and accuracy + of a model + """ + + def casual_language_model_evaluation_step(state, batch_eval): + """The casual_language_model_evaluation_step function is used to calculate the loss and accuracy of a model. + It takes in a set of parameters, which are then passed into the state.apply_fn function + to generate logits for each token in the batch. The cross entropy loss and accuracy are then calculated from + these logits. + + Args: + state: Store the model parameters and other information + about the training process + batch_eval: Pass the batch of data to the function + + Returns: + The loss and accuracy of the model + """ + batch_eval = with_sharding_constraint(batch_eval, partition_spec) + + def calculate_loss(params): + """ + The calculate_loss function is used to calculate the loss and accuracy of a model. + It takes in a set of parameters, which are then passed into the state.apply_fn function + to generate logits for each token in the batch. The cross entropy loss and accuracy are then calculated + from these logits. + + :param params: Pass the model parameters to the function + :return: The loss and the accuracy + + """ + labels = batch_eval.get("labels", None) + if labels is None: + labels = batch_eval["input_ids"][..., 1:] + else: + labels = labels[..., 1:] + model_outputs = state.apply_fn(params=params, **batch_eval, return_dict=True) + logits = model_outputs.logits + aux_loss = getattr(model_outputs, "aux_loss", None) + valid = jnp.where( + (batch_eval["attention_mask"][:, 1:].astype(jnp.float32) != 0) + & (labels > 0), + 1.0, + 0.0, + ) + loss, accuracy = cross_entropy_loss_and_accuracy( + logits[:, :-1, :], + labels, + valid, + ) + if aux_loss is not None: + loss += aux_loss + return loss, (accuracy, aux_loss) + + loss__, (accuracy__, aux_loss__) = calculate_loss(state.params) + return loss__, accuracy__, aux_loss__ + + return casual_language_model_evaluation_step diff --git a/src/python/easydel/trainer/causal_language_model_trainer/modeling_output.py b/src/python/easydel/trainer/causal_language_model_trainer/modeling_output.py index 858c697eb..599d01b24 100644 --- a/src/python/easydel/trainer/causal_language_model_trainer/modeling_output.py +++ b/src/python/easydel/trainer/causal_language_model_trainer/modeling_output.py @@ -1,15 +1,15 @@ -from dataclasses import dataclass -import jax -from typing import Any, Optional, Callable, Mapping -from ...etils.easystate import EasyDeLState - - -@dataclass -class CausalLMTrainerOutput: - state: EasyDeLState - mesh: Optional[jax.sharding.Mesh] - checkpoint_manager: Any - gather_fns: Optional[Any | Mapping[str, Callable] | dict[Callable]] = None - shard_fns: Optional[Any | Mapping[str, Callable] | dict[Callable]] = None - last_save_file_name: Optional[str] = None - checkpoint_path: Optional[str] = None +from dataclasses import dataclass +import jax +from typing import Any, Optional, Callable, Mapping +from ...etils.easystate import EasyDeLState + + +@dataclass +class CausalLMTrainerOutput: + state: EasyDeLState + mesh: Optional[jax.sharding.Mesh] + checkpoint_manager: Any + gather_fns: Optional[Any | Mapping[str, Callable] | dict[Callable]] = None + shard_fns: Optional[Any | Mapping[str, Callable] | dict[Callable]] = None + last_save_file_name: Optional[str] = None + checkpoint_path: Optional[str] = None diff --git a/src/python/easydel/trainer/dpo/__init__.py b/src/python/easydel/trainer/dpo/__init__.py index c7667a999..30d8f00d0 100644 --- a/src/python/easydel/trainer/dpo/__init__.py +++ b/src/python/easydel/trainer/dpo/__init__.py @@ -1,19 +1,19 @@ -from .modelling_output import DPOTrainerOutput as DPOTrainerOutput -from .fwd_bwd_functions import ( - create_dpo_train_function as create_dpo_train_function, - create_dpo_eval_function as create_dpo_eval_function, - create_concatenated_forward as create_concatenated_forward, - get_batch_log_probs as get_batch_log_probs, - concatenated_inputs as concatenated_inputs -) -from .dpo_trainer import DPOTrainer as DPOTrainer - -__all__ = ( - "DPOTrainer", - "create_dpo_train_function", - "create_dpo_eval_function", - "create_concatenated_forward", - "get_batch_log_probs", - "concatenated_inputs", - "DPOTrainerOutput" -) +from .modelling_output import DPOTrainerOutput as DPOTrainerOutput +from .fwd_bwd_functions import ( + create_dpo_train_function as create_dpo_train_function, + create_dpo_eval_function as create_dpo_eval_function, + create_concatenated_forward as create_concatenated_forward, + get_batch_log_probs as get_batch_log_probs, + concatenated_inputs as concatenated_inputs +) +from .dpo_trainer import DPOTrainer as DPOTrainer + +__all__ = ( + "DPOTrainer", + "create_dpo_train_function", + "create_dpo_eval_function", + "create_concatenated_forward", + "get_batch_log_probs", + "concatenated_inputs", + "DPOTrainerOutput" +) diff --git a/src/python/easydel/trainer/dpo/dpo_trainer.py b/src/python/easydel/trainer/dpo/dpo_trainer.py index c406347fb..823c04a8e 100644 --- a/src/python/easydel/trainer/dpo/dpo_trainer.py +++ b/src/python/easydel/trainer/dpo/dpo_trainer.py @@ -1,1274 +1,1274 @@ -import copy -import os -import sys -import time -import typing -import warnings -from abc import ABC -from collections import defaultdict -import flax.core -import jax -import tensorflow.data -import tensorflow_datasets -import termcolor -import wandb -from fjformer import match_partition_rules, make_shard_and_gather_fns -from tqdm import tqdm - -from typing import Optional, Literal, Dict, Union, Any, Callable, Mapping - -from jax.experimental.pjit import pjit -from datasets import Dataset -from jax import numpy as jnp - -from ...etils.etils import get_logger -from ..training_configurations import TrainArguments -from ..base_trainer import ( - BaseTrainer, - TrainerConfigureFunctionFuncOutput, - TrainerConfigureDataloaderFuncOutput, - TrainerConfigureModelFuncOutput -) -from ...etils import EasyDeLState, EasyDeLTimerError -from transformers import PreTrainedTokenizerBase -from jax.sharding import PartitionSpec - -from ...utils import Timers -from .utils import ( - pad_to_length, - DPODataCollatorWithPadding, - leave_alone_context_manager -) -from .fwd_bwd_functions import ( - create_dpo_train_function, - create_dpo_eval_function, - create_concatenated_forward, -) -from .modelling_output import DPOTrainerOutput - -logger = get_logger(__name__) - - -class DPOTrainer(BaseTrainer, ABC): - """ - easydel DPO Trainer Class - """ - - def __init__( - self, - arguments: TrainArguments, - model_state: EasyDeLState | str, - ref_model_state: Optional[EasyDeLState | str] = None, - beta: float = 0.1, - label_smoothing: float = .0, - loss_type: Literal["sigmoid", "hinge", "ipo", "kto"] = "sigmoid", - label_pad_token_id: int = -100, - padding_value: int = None, - train_dataset: Optional[Dataset] = None, - eval_dataset: Optional[Union[Dataset, Dict[str, Dataset]]] = None, - tokenizer: Optional[PreTrainedTokenizerBase] = None, - data_collator: Optional[Callable] = None, - max_length: Optional[int] = None, - max_prompt_length: Optional[int] = None, - max_target_length: Optional[int] = None, - precompute_ref_log_probs: bool = False, - model_init_kwargs: Optional[Dict] = None, - ref_model_init_kwargs: Optional[Dict] = None, - reference_free: bool = False, - auto_shard_model_state: bool = True, - auto_shard_ref_model_state: bool = True, - is_encoder_decoder: Optional[bool] = False, - dataset_map_arguments: Optional[dict] = None, - low_mem_usage: bool = True, - auto_fix_data: bool = True, - _do_init_fns: bool = True, - ): - - """ - The __init__ function is called when the class is instantiated. - It sets up the attributes of an object. - - - :param self: Refer to the object itself - :param model_state: EasyDeLState | str: Pass the model state to the trainer - :param ref_model_state: Optional[EasyDeLState | str]: Pass the reference model state - :param beta: float: Control the strength of the regularization term - :param label_smoothing: float: Smooth the labels - :param loss_type: Literal["sigmoid", "hinge", "ipo", "kto"] : Determine the loss function used - :param arguments: TrainArguments: Pass the arguments to the trainer - :param label_pad_token_id: int: Pad the labels - :param padding_value: int: Specify the value that is used for padding - :param train_dataset: Optional[Dataset]: Load the training dataset - :param eval_dataset: Optional[Union[Dataset, Dict[str, Dataset]]] : Pass the evaluation dataset to the trainer - :param tokenizer: Optional[PreTrainedTokenizerBase]: Pass the tokenizer to the trainer - :param max_length: Optional[int]: Set the maximum length of the input sequence - :param max_prompt_length: Optional[int]: Set the maximum length of the prompt - :param max_target_length: Optional[int]: Truncate the target sequence - :param data_collator: Optional[Callable]: Function to be used for creating datasets. - :param precompute_ref_log_probs: bool: Precompute the log probabilities of the reference model - :param model_init_kwargs: Optional[Dict]: Pass in the model_kwargs to model for init process - :param ref_model_init_kwargs: Optional[Dict]: Pass the ref_model_init_kwargs to ref_model for init process - :param auto_shard_model_state: bool: whenever to automatically shard `model_state` - :param auto_shard_ref_model_state: bool: whenever to automatically shard `ref_model_state` - :param dataset_map_arguments: Optional[dict]: arguments to be passed to train and eval datasets for - tokenizing process with `dataset.map`. - :param _do_init_fns: bool : preferred to set ture to trainer will automatically configure - model with provided training Arguments - :param : Set the padding value for the model - """ - assert arguments is not None, ( - "You Have to pass arguments that will be used for training but you have passed" - "`arguments=None`" - ) - assert isinstance(arguments, TrainArguments), ( - f"arguments type must be `TrainArguments` but got {type(arguments)}" - ) - if model_init_kwargs is None: - model_init_kwargs = {} - elif not isinstance(model_state, str): - raise ValueError("You passed model_kwargs to the DPOTrainer. But your model is already instantiated.") - - if ref_model_init_kwargs is None: - ref_model_init_kwargs = {} - elif not isinstance(ref_model_state, str): - raise ValueError( - "You passed ref_model_kwargs to the DPOTrainer. But your ref_model is already instantiated." - ) - - if isinstance(model_state, str): - warnings.warn( - "You passed a model_id to the DPOTrainer. This will automatically create an " - "`AutoEasyDeLModelForCausalLM` for you." - ) - model_state = EasyDeLState.from_pretrained( - model_state, - **model_init_kwargs - ) - if isinstance(ref_model_state, str): - warnings.warn( - "You passed a ref model_id to the DPOTrainer. This will automatically create an " - "`AutoEasyDeLModelForCausalLM`" - ) - ref_model_state = EasyDeLState.from_pretrained( - ref_model_state, - **ref_model_init_kwargs - ) - - if loss_type in ["hinge", "ipo", "kto_pair"] and label_smoothing > 0: - warnings.warn( - "You are using a loss type that does not support label smoothing. Ignoring label_smoothing parameter." - ) - self.auto_fix_data = auto_fix_data - - if tokenizer is None: - raise ValueError("tokenizer must be specified to tokenize a DPO dataset.") - if max_length is None: - warnings.warn( - "`max_length` is not set in the DPOTrainer's init" - " it will default to `512` by default, but you should do it yourself in the future.", - UserWarning, - ) - max_length = 512 - if max_prompt_length is None: - warnings.warn( - "`max_prompt_length` is not set in the DPOTrainer's init" - " it will default to `128` by default, but you should do it yourself in the future.", - UserWarning, - ) - max_prompt_length = 128 - - if max_target_length is None and is_encoder_decoder: - warnings.warn( - "When using an encoder decoder architecture, you should set `max_target_length` in the " - "DPOTrainer's init it will default to `128` by default, but you should do it yourself in the future.", - UserWarning, - ) - max_target_length = 128 - - padding_value = padding_value if padding_value is not None else tokenizer.pad_token_id - self.max_length = max_length - self.label_pad_token_id = label_pad_token_id - self.padding_value = padding_value - self.max_prompt_length = max_prompt_length - self.truncation_mode = arguments.truncation_mode - - self.max_target_length = max_target_length - self.tokenizer = tokenizer - self.precompute_ref_log_probs = precompute_ref_log_probs - self.reference_free = reference_free - self.is_encoder_decoder = False - self._precomputed_train_ref_log_probs = False - self._precomputed_eval_ref_log_probs = False - self.beta = beta - self.label_smoothing = label_smoothing - self.loss_type = loss_type - self.low_mem_usage = low_mem_usage - data_collator = DPODataCollatorWithPadding( - max_prompt_length=self.max_prompt_length, - max_target_length=self.max_target_length, - pad_token_id=tokenizer.pad_token_id, - label_pad_token_id=label_pad_token_id, - is_encoder_decoder=False, - ) if data_collator is None else data_collator - self._stored_metrics = defaultdict(lambda: defaultdict(list)) - if dataset_map_arguments is None: - dataset_map_arguments = {} - train_dataset = train_dataset.map( - self.tokenize_row, - **dataset_map_arguments - ) - if eval_dataset is not None: - eval_dataset = eval_dataset.map( - self.tokenize_row, - **dataset_map_arguments - ) - - self.arguments = arguments - self.hp_name = None - self.deepspeed = None - self.is_in_train = False - - self.data_collator = data_collator - self.train_dataset = train_dataset - self.eval_dataset = eval_dataset - self.tokenizer = tokenizer - self.ref_model_state = ref_model_state - self.model_state = model_state - self._loggers_initialized = False - self.mesh = self.arguments.get_mesh() - assert padding_value is not None, "`padding_value` can not be set as `None` it must be an integer." - - self.concatenated_forward = create_concatenated_forward( - is_encoder_decoder=self.is_encoder_decoder, - padding_value=padding_value, - label_pad_token_id=label_pad_token_id, - ) - self.auto_shard_ref_model_state = auto_shard_ref_model_state - self.auto_shard_model_state = auto_shard_model_state - - self._cached_p_l_s = None - self._cached_c_l_s = None - self._cached_r_l_s = None - super().__init__( - arguments=arguments, - dataset_train=train_dataset, - dataset_eval=eval_dataset, - finetune=True, - checkpoint_path=None, - _do_init_fns=_do_init_fns - ) - - def initialize_trainer_utils(self): - """ - The initialize_trainer_utils function is responsible for initializing the following: - - wandb_runtime (if you use_wandb is True) - - timer object (for logging time taken by various functions) - - dataloader objects for training and evaluation data, along with max steps per epoch. - The configure_dataloader function accomplishes this task. - - :param self: Represent the instance of the class - :return: A tuple of functions - - """ - self.wandb_runtime = self.arguments.get_wandb_init() if self.arguments.use_wandb else None - self.timer = Timers( - use_wandb=False, - tensorboard_writer=self.arguments.get_board() - ) - - self.timer("configure dataloaders").start() - dataset_configurations = self.configure_dataloader() - self.dataloader_train = dataset_configurations.dataloader_train - self.max_training_steps = dataset_configurations.max_training_steps - self.dataloader_eval = dataset_configurations.dataloader_eval - self.max_evaluation_steps = dataset_configurations.max_evaluation_steps - - self.timer("configure dataloaders").stop() - - self.timer.log(["configure dataloaders"]) - - self.timer("configure Model, Optimizer, Scheduler and Config").start() - model_configurations = self.configure_model() - model = model_configurations.model - tx = model_configurations.tx - scheduler = model_configurations.scheduler - config = model_configurations.config - self.model = model - self.tx = tx - self.scheduler = scheduler - self.config = config - if self.rapture is not None: - lora_modules = self.rapture.apply_lora( - module=model, - parameters=self.arguments.rapture_config.parameters, - tx=tx, - ) - self.lora_parameters = lora_modules.lora_parameters - self.lora_apply_fn = lora_modules.lora_module.__call__ - self.lora_opt_state = lora_modules.lora_opt_state - self.lora_model = lora_modules.lora_module - self.lora_tx = lora_modules.lora_tx - - self.timer("configure Model, Optimizer, Scheduler and Config").stop() - self.timer.log(["configure Model, Optimizer, Scheduler and Config"]) - - self.timer("configure functions and sharding them").start() - - if self.auto_shard_model_state: - self.timer("Sharding Model State").start() - self.model_state: EasyDeLState = self.shard_states( - self.model_state, - self.model_state.module.config.get_partition_rules(self.arguments.fully_sharded_data_parallel) - ) - - termcolor.cprint("initializing TX and Schedulers for `model_state`", force_color=True, color="cyan") - - params_with_opt = ( - self.model_state.params[ - 'params' - ] if '_overwrite_with_gradient' in self.model_state.params else self.model_state.params - ) - opt_state = self.tx.init(params_with_opt) - - self.model_state = self.model_state.replace( - opt_state=opt_state, - tx=self.tx - ) - - self.timer("Sharding Model State").stop() - self.timer.log(["Sharding Model State"]) - if self.auto_shard_ref_model_state and self.ref_model_state is not None: - self.timer("Sharding Ref Model State").start() - self.ref_model_state = self.shard_states( - self.ref_model_state, - self.ref_model_state.module.config.get_partition_rules(self.arguments.fully_sharded_data_parallel) - ) - self.timer("Sharding Ref Model State").stop() - self.timer.log(["Sharding Ref Model State"]) - - function_configurations = self.configure_functions() - self.create_sharded_state_from_params_function = ( - function_configurations.create_sharded_state_from_params_function - ) - self.sharded_train_step_function = function_configurations.sharded_train_step_function - self.sharded_eval_step_function = function_configurations.sharded_eval_step_function - self.mesh = function_configurations.mesh - self.checkpoint_manager = function_configurations.checkpoint_manager - self.initialize_state_function = function_configurations.initialize_state_function - self.timer("configure functions and sharding them").stop() - self.timer.log(["configure functions and sharding them"]) - - def create_collate_function( - self, - max_sequence_length: int, - truncation_mode: typing.Literal["keep_end", "keep_start"] = "keep_end", - ) -> Callable: - return self.data_collator - - def shard_states(self, state, rules): - with self.arguments.get_mesh(): - partition_spec = match_partition_rules(rules=rules, params=jax.eval_shape(lambda: state)) - - def _shard(x): - return x - - shard = pjit( - _shard, - in_shardings=(PartitionSpec(),), - out_shardings=partition_spec - ) - return shard(state) - - def configure_dataloader(self) -> TrainerConfigureDataloaderFuncOutput: - dataloader_train = self.get_train_dataloader() - max_evaluation_steps = None - dataloader_eval = None - - max_training_steps = self.arguments.num_train_epochs * len( - dataloader_train - ) if self.arguments.max_training_steps is None else self.arguments.max_training_steps - if self.eval_dataset is not None: - dataloader_eval = self.get_eval_dataloader(self.eval_dataset) - max_evaluation_steps = len(dataloader_eval) - return TrainerConfigureDataloaderFuncOutput( - dataloader_train=dataloader_train, # type:ignore - max_training_steps=max_training_steps, - dataloader_eval=dataloader_eval, - max_evaluation_steps=max_evaluation_steps - ) - - def configure_functions(self) -> TrainerConfigureFunctionFuncOutput: - def initialize_state_function(): - initialized_parameters = self.model.init_weights( - jax.random.PRNGKey(0), - self.arguments.init_input_shape - ) - - if self.arguments.dtype == jnp.bfloat16: - initialized_parameters = self.model.to_bf16(initialized_parameters) - elif self.arguments.dtype == jnp.float16: - initialized_parameters = self.model.to_fp16(initialized_parameters) - - tx = self.tx - parameters = flax.core.freeze({"params": initialized_parameters}) - tx_init = copy.deepcopy(self.arguments.optimizer_kwargs) - - if self.rapture is not None: - lora_parameters = self.lora_parameters - if self.arguments.dtype == jnp.bfloat16: - lora_parameters = self.model.to_bf16(lora_parameters) - elif self.arguments.dtype == jnp.float16: - lora_parameters = self.model.to_fp16(lora_parameters) - - return EasyDeLState( - step=0, - apply_fn=self.lora_apply_fn, - params=lora_parameters, - tx=self.lora_tx, - opt_state=self.lora_opt_state, - tx_init=EasyDeLState.safe_dict(tx_init), - hyperparameters=EasyDeLState.create_hyperparameters(self.model.config.model_type), - module=self.lora_model, - module_config=self.model_state.module.config, - module_config_args=None, - ) - else: - return EasyDeLState.create( - tx=tx, - params=parameters, - apply_fn=self.model.__call__, - module_config=copy.deepcopy(self.model_state.module.config), - tx_init=tx_init, - hyperparameters=EasyDeLState.create_hyperparameters(self.model.config.model_type), - module=self.model, - module_config_args=None - ) - - def create_state_from_params_function(parameters): - if self.rapture is None: - return EasyDeLState.create( - tx=self.tx, - params=parameters, - apply_fn=self.model.__call__, - module_config=copy.deepcopy(self.model_state.module.config), - tx_init=copy.deepcopy(self.arguments.optimizer_kwargs), - hyperparameters=EasyDeLState.create_hyperparameters(self.model.config.model_type), - module=self.model, - module_config_args=None - ) - else: - return EasyDeLState( - step=0, - apply_fn=self.lora_apply_fn, - params=parameters, - tx=self.lora_tx, - opt_state=self.lora_opt_state, - tx_init=EasyDeLState.safe_dict(copy.deepcopy(self.arguments.optimizer_kwargs)), - hyperparameters=EasyDeLState.create_hyperparameters(self.model.config.model_type), - module=self.lora_model, - module_config=self.model_state.module.config, - module_config_args=None, - ) - - state_shape = jax.eval_shape(lambda: self.model_state) - - state_partition_spec = match_partition_rules( - self.config.get_partition_rules( - fully_sharded_data_parallel=self.arguments.fully_sharded_data_parallel - ) if self.arguments.custom_rule is None else self.arguments.custom_rule, - state_shape - ) - create_sharded_state_from_params_function = pjit( - create_state_from_params_function, - in_shardings=(state_partition_spec.params,), - out_shardings=state_partition_spec, - donate_argnums=(0,) - ) - train_function = create_dpo_train_function( - concatenated_forward=self.concatenated_forward, - ref_state=self.ref_model_state, - loss_type=self.loss_type, - reference_free=self.reference_free, - label_smoothing=self.label_smoothing, - beta=self.beta - ) - sharded_train_step_function = pjit( - train_function, - in_shardings=(state_partition_spec, self.arguments.step_partition_spec), - out_shardings=(state_partition_spec, PartitionSpec()), - ) - - eval_function = create_dpo_eval_function( - concatenated_forward=self.concatenated_forward, - ref_state=self.ref_model_state, - loss_type=self.loss_type, - reference_free=self.reference_free, - label_smoothing=self.label_smoothing, - beta=self.beta - ) - - sharded_eval_step_function = pjit( - eval_function, - in_shardings=(state_partition_spec, self.arguments.step_partition_spec), - out_shardings=(state_partition_spec, PartitionSpec()), - ) - - self.arguments.ckpt_path_exists() - self.state_partition_spec = state_partition_spec - self.state_shape = state_shape - checkpoint_manager = self.arguments.get_streaming_checkpointer() - mesh = self.arguments.get_mesh() - return TrainerConfigureFunctionFuncOutput( - initialize_state_function=initialize_state_function, - sharded_train_step_function=sharded_train_step_function, - create_sharded_state_from_params_function=create_sharded_state_from_params_function, - checkpoint_manager=checkpoint_manager, - mesh=mesh, - sharded_eval_step_function=sharded_eval_step_function - ) - - def configure_model(self) -> TrainerConfigureModelFuncOutput: - config = self.model_state.module.config - tx, scheduler = self.arguments.get_optimizer_and_scheduler(self.max_training_steps) - return TrainerConfigureModelFuncOutput( - model=self.model_state.module, - config=config, # type: ignore - scheduler=scheduler, - tx=tx - ) - - def _get_train_dataloader(self) -> tensorflow.data.Dataset: - - """ - The _get_train_dataloader function is used to create a tensorflow.data.Dataset object for the training dataset. - - :param self: Represent the instance of the class - :return: A dataloader object - """ - if self.train_dataset is None: - raise ValueError("Trainer: training requires a train_dataset.") - - train_dataset = self.train_dataset - data_collator = self.data_collator - - return tensorflow_datasets.as_numpy( - train_dataset.to_tf_dataset( - batch_size=self.arguments.total_batch_size, - collate_fn=data_collator, - num_workers=self.arguments.dataloader_num_workers, - shuffle=True, - drop_remainder=True - ) - ) - - def _get_eval_dataloader(self, eval_dataset: Optional[Dataset] = None) -> tensorflow.data.Dataset: - """ - Returns the evaluation [`~tensorflow.data.Dataset`]. - - Subclass and override this method if you want to inject some custom behavior. - - Args: - eval_dataset (`torch.utils.data.Dataset`, *optional*): - If provided, will override `self.eval_dataset`. If it is a [`~datasets.Dataset`], columns not accepted - by the `model.forward()` method are automatically removed. It must implement `__len__`. - """ - if eval_dataset is None and self.eval_dataset is None: - raise ValueError("Trainer: evaluation requires an eval_dataset.") - eval_dataset = eval_dataset if eval_dataset is not None else self.eval_dataset - - return tensorflow_datasets.as_numpy( - eval_dataset.to_tf_dataset( - batch_size=self.arguments.total_batch_size, - collate_fn=self.data_collator, - num_workers=self.arguments.dataloader_num_workers, - shuffle=False, - drop_remainder=True - ) - ) - - def get_train_dataloader( - self, - ) -> tensorflow.data.Dataset: - """ - Returns the training [`~tensorflow.data.Dataset`]. - """ - - if self.precompute_ref_log_probs and not self._precomputed_train_ref_log_probs: - - data_loader = tensorflow_datasets.as_numpy( - self.train_dataset.to_tf_dataset( - batch_size=self.arguments.total_batch_size, - collate_fn=self.data_collator, - num_workers=self.arguments.dataloader_num_workers, - shuffle=False, - drop_remainder=True - ) - ) - reference_chosen_log_probs = [] - reference_rejected_log_probs = [] - for padded_batch in tqdm(iterable=data_loader, desc="Train dataset reference log probs"): - reference_chosen_logp, reference_rejected_logp = self.compute_reference_log_probs( - self.model_state, - padded_batch, - ) - reference_chosen_log_probs.append(reference_chosen_logp) - reference_rejected_log_probs.append(reference_rejected_logp) - - all_reference_chosen_log_probs = jnp.concatenate(reference_chosen_log_probs) - all_reference_rejected_log_probs = jnp.concatenate(reference_rejected_log_probs) - self.train_dataset = self.train_dataset.add_column( - name="reference_chosen_log_probs", column=all_reference_chosen_log_probs - ) - self.train_dataset = self.train_dataset.add_column( - name="reference_rejected_log_probs", column=all_reference_rejected_log_probs - ) - - self._precomputed_train_ref_log_probs = True - return self._get_train_dataloader() - - def get_eval_dataloader(self, eval_dataset: Optional[Dataset] = None) -> tensorflow.data.Dataset: - """ - Returns the evaluation [`~tensorflow.data.Dataset`]. - """ - if eval_dataset is None and self.eval_dataset is None: - raise ValueError("Trainer: evaluation requires an eval_dataset.") - eval_dataset = eval_dataset if eval_dataset is not None else self.eval_dataset - - if self.precompute_ref_log_probs and not self._precomputed_eval_ref_log_probs: - - # prepare dataloader - data_loader = tensorflow_datasets.as_numpy( - eval_dataset.to_tf_dataset( - batch_size=self.arguments.total_batch_size, - collate_fn=self.data_collator, - num_workers=self.arguments.dataloader_num_workers, - shuffle=False, - drop_remainder=True - ) - ) - - reference_chosen_log_probs = [] - reference_rejected_log_probs = [] - for padded_batch in tqdm(iterable=data_loader, desc="Eval dataset reference log probs"): - reference_chosen_logp, reference_rejected_logp = self.compute_reference_log_probs( - self.model_state, - padded_batch - ) - reference_chosen_log_probs.append(reference_chosen_logp.cpu()) - reference_rejected_log_probs.append(reference_rejected_logp.cpu()) - - all_reference_chosen_log_probs = jnp.concatenate(reference_chosen_log_probs) - all_reference_rejected_log_probs = jnp.concatenate(reference_rejected_log_probs) - - eval_dataset = eval_dataset.add_column(name="reference_chosen_log_probs", - column=all_reference_chosen_log_probs) - eval_dataset = eval_dataset.add_column( - name="reference_rejected_log_probs", column=all_reference_rejected_log_probs - ) - - if self.eval_dataset is not None: - self.eval_dataset = eval_dataset - self._precomputed_eval_ref_log_probs = True - - return self._get_eval_dataloader(eval_dataset=eval_dataset) - - def build_tokenized_answer(self, prompt, answer): - """ - Llama tokenizer does satisfy `enc(a + b) = enc(a) + enc(b)`. - It does ensure `enc(a + b) = enc(a) + enc(a + b)[len(enc(a)):]`. - """ - - full_tokenized = self.tokenizer(prompt + answer, add_special_tokens=False) - prompt_input_ids = self.tokenizer(prompt, add_special_tokens=False)["input_ids"] - - answer_input_ids = full_tokenized["input_ids"][len(prompt_input_ids):] - answer_attention_mask = full_tokenized["attention_mask"][len(prompt_input_ids):] - prompt_input_ids = jnp.asarray(prompt_input_ids, dtype="i4") - answer_input_ids = jnp.asarray(answer_input_ids, dtype="i4") - full_concat_input_ids = jnp.concatenate( - ( - prompt_input_ids, - answer_input_ids - ) - ) - - # Prepare input tokens for token by token comparison - full_input_ids = jnp.array(full_tokenized["input_ids"]) - - if len(full_input_ids) != len(full_concat_input_ids): - raise ValueError("Prompt input ids and answer input ids should have the same length.") - - response_token_ids_start_idx = len(prompt_input_ids) - if prompt_input_ids.tolist() != full_tokenized["input_ids"][:response_token_ids_start_idx]: - response_token_ids_start_idx -= 1 - - prompt_input_ids = full_tokenized["input_ids"][:response_token_ids_start_idx] - prompt_attention_mask = full_tokenized["attention_mask"][:response_token_ids_start_idx] - - if len(prompt_input_ids) != len(prompt_attention_mask): - raise ValueError("Prompt input ids and attention mask should have the same length.") - - answer_input_ids = full_tokenized["input_ids"][response_token_ids_start_idx:] - answer_attention_mask = full_tokenized["attention_mask"][response_token_ids_start_idx:] - - return dict( - prompt_input_ids=jnp.array(prompt_input_ids, dtype="i4"), - prompt_attention_mask=jnp.array(prompt_attention_mask, dtype="i4"), - input_ids=jnp.array(answer_input_ids, dtype="i4"), - attention_mask=jnp.array(answer_attention_mask, dtype="i4"), - ) - - def tokenize_row(self, feature, state: EasyDeLState = None) -> Dict: - - """ - The tokenize_row function is responsible for taking a single row of data and converting it into the format that - the model expects. This includes: - - Tokenizing the text (using HuggingFace's tokenizer) - - Padding/truncating sequences to a fixed length (if necessary) - - Creating attention masks, which tell the model which tokens are padding and which aren't. - - :param self: Represent the instance of the class - :param feature: Pass in the data from the dataset - :param state: EasyDeLState: Keep track of the state of the tokenizer - :return: A dictionary of the following keys - """ - batch = {} - prompt = feature["prompt"] - chosen = feature["chosen"] - rejected = feature["rejected"] - - if not isinstance(prompt, str): - raise ValueError(f"prompt should be an str but got {type(prompt)} , {prompt}") - prompt_tokens = self.tokenizer( - prompt, - add_special_tokens=False, - return_tensors="np", - ) - prompt_tokens = {f"prompt_{k}": v for k, v in prompt_tokens.items()} - - if not isinstance(chosen, str): - raise ValueError(f"chosen should be an str but got {type(chosen)} , {chosen}") - chosen_tokens = self.build_tokenized_answer(prompt, chosen) - - if not isinstance(rejected, str): - raise ValueError(f"rejected should be an str but got {type(rejected)}") - rejected_tokens = self.build_tokenized_answer(prompt, rejected) - v2d = lambda ar: ar.reshape(1, -1) if ar.ndim == 1 else ar - - def add_tkn(n, ar): - return jnp.concatenate( - ( - jnp.array(n).reshape(1, 1), - v2d(ar) - ), axis=-1 - ) - - def add_post_tkn(n, ar): - return jnp.concatenate( - ( - v2d(ar), - jnp.array(n).reshape(1, 1) - ), axis=-1 - ) - - prompt_tokens["prompt_input_ids"] = add_tkn( - self.tokenizer.bos_token_id, - prompt_tokens["prompt_input_ids"] - ) - chosen_tokens["prompt_input_ids"] = add_tkn( - self.tokenizer.bos_token_id, - chosen_tokens["prompt_input_ids"] - ) - rejected_tokens["prompt_input_ids"] = add_tkn( - self.tokenizer.bos_token_id, - rejected_tokens["prompt_input_ids"] - ) - - prompt_tokens["prompt_attention_mask"] = add_tkn( - 1, prompt_tokens["prompt_attention_mask"] - ) - chosen_tokens["prompt_attention_mask"] = add_tkn( - 1, chosen_tokens["prompt_attention_mask"] - ) - rejected_tokens["prompt_attention_mask"] = add_tkn( - 1, rejected_tokens["prompt_attention_mask"] - ) - - # add EOS token to end of answer - chosen_tokens["input_ids"] = add_post_tkn(self.tokenizer.eos_token_id, chosen_tokens["input_ids"]) - chosen_tokens["attention_mask"] = add_post_tkn(1, chosen_tokens["attention_mask"]) - - rejected_tokens["input_ids"] = add_post_tkn(self.tokenizer.eos_token_id, rejected_tokens["input_ids"]) - rejected_tokens["attention_mask"] = add_post_tkn(1, rejected_tokens["attention_mask"]) - - longer_response_length = max(chosen_tokens["input_ids"].shape[-1], rejected_tokens["input_ids"].shape[-1]) - - # if combined sequence is too long, truncate the prompt - for answer_tokens in [chosen_tokens, rejected_tokens, prompt_tokens]: - length_rn = answer_tokens["prompt_input_ids"].shape[-1] + longer_response_length - if length_rn > self.max_length: - - if self.truncation_mode == "keep_start": - for k in ["prompt_input_ids", "prompt_attention_mask"]: - answer_tokens[k] = answer_tokens[k][:, : self.max_prompt_length] - elif self.truncation_mode == "keep_end": - for k in ["prompt_input_ids", "prompt_attention_mask"]: - answer_tokens[k] = answer_tokens[k][:, -self.max_prompt_length:] - else: - raise ValueError(f"Unknown truncation mode: {self.truncation_mode}") - # if that's still too long, truncate the response - for answer_tokens in [chosen_tokens, rejected_tokens]: - if answer_tokens["prompt_input_ids"].shape[-1] + longer_response_length > self.max_length: - for k in ["input_ids", "attention_mask"]: - answer_tokens[k] = answer_tokens[k][:, : self.max_length - self.max_prompt_length] - - chosen_sequence_tokens = { - k: jnp.concatenate( - (v2d(chosen_tokens[f"prompt_{k}"]), v2d(chosen_tokens[k])), - axis=-1 - ) for k in ["input_ids", "attention_mask"] - } - rejected_sequence_tokens = { - k: jnp.concatenate( - (v2d(rejected_tokens[f"prompt_{k}"]), v2d(rejected_tokens[k])), - axis=-1 - ) for k in ["input_ids", "attention_mask"] - } - chosen_sequence_tokens["labels"] = chosen_sequence_tokens["input_ids"][:] - chosen_sequence_tokens["labels"] = chosen_sequence_tokens["labels"].at[ - : len(chosen_tokens["prompt_input_ids"]) - ].set([self.label_pad_token_id] * len(chosen_tokens["prompt_input_ids"])) - rejected_sequence_tokens["labels"] = rejected_sequence_tokens["input_ids"][:] - rejected_sequence_tokens["labels"] = rejected_sequence_tokens["labels"].at[ - : len(rejected_tokens["prompt_input_ids"]) - ].set( - ([self.label_pad_token_id] * len(rejected_tokens["prompt_input_ids"])) - ) - - for k, tokens_ in { - "chosen_": chosen_sequence_tokens, - "rejected_": rejected_sequence_tokens, - "": prompt_tokens, - }.items(): - for type_key, tokens in tokens_.items(): - if type_key == "token_type_ids": - continue - - b, s = tokens.shape - - if self.max_prompt_length > s: - if k == "chosen_": - if type_key == "input_ids": - tokens = pad_to_length( - tokens, - self.max_target_length, - pad_value=self.padding_value, - axis=-1 - ) - elif type_key == "attention_mask": - tokens = pad_to_length( - tokens, - self.max_target_length, - pad_value=0, - axis=-1 - ) - elif type_key == "labels": - tokens = pad_to_length( - tokens, - self.max_target_length, - pad_value=self.padding_value, - axis=-1 - ) - - tokens = tokens[..., :self.max_target_length] - - if tokens.shape[-1] != self.max_target_length: - raise ValueError( - f"there was an error in padding token with `type_key` of {type_key}" - f". it must have sequence_length of {self.max_target_length} but we got {tokens.shape[-1]}" - f" From {k}{type_key}" - ) - tokens = tokens[..., :self.max_target_length] - elif k == "rejected_": - if type_key == "input_ids": - tokens = pad_to_length( - tokens, - self.max_target_length, - pad_value=self.padding_value, - axis=-1 - ) - elif type_key == "attention_mask": - tokens = pad_to_length( - tokens, - self.max_target_length, - pad_value=0, - axis=-1 - ) - elif type_key == "labels": - tokens = pad_to_length( - tokens, - self.max_target_length, - pad_value=self.padding_value, - axis=-1 - ) - tokens = tokens[..., :self.max_target_length] - if tokens.shape[-1] != self.max_target_length: - raise ValueError( - f"there was an error in padding token with `type_key` of {type_key}" - f". it must have sequence_length of {self.max_target_length} but we got {tokens.shape[-1]}" - f" From {k}{type_key}" - ) - elif k == "": - if type_key == "prompt_input_ids": - tokens = pad_to_length( - tokens, - self.max_prompt_length, - pad_value=self.padding_value, - axis=-1 - ) - elif type_key == "prompt_attention_mask": - tokens = pad_to_length( - tokens, - self.max_prompt_length, - pad_value=0, - axis=-1 - ) - elif type_key == "prompt_labels": - tokens = pad_to_length( - tokens, - self.max_prompt_length, - pad_value=self.padding_value, - axis=-1 - ) - tokens = tokens[..., :self.max_prompt_length] - if tokens.shape[-1] != self.max_prompt_length: - raise ValueError( - f"there was an error in padding token with `type_key` of {type_key}" - f". it must have sequence_length of {self.max_prompt_length} but we got {tokens.shape[-1]}" - f" From {k}{type_key}" - ) - batch[f"{k}{type_key}"] = tokens - return batch - - def compute_reference_log_probs( - self, - state: EasyDeLState, - padded_batch: Dict, - ) -> tuple[Any, Any]: - """ - Computes log probabilities of the reference model for a single padded batch of a DPO specific dataset. - """ - - if self.ref_model_state is None: - ( - reference_chosen_log_probs, - reference_rejected_log_probs, - _, - _, - ) = self.concatenated_forward( - apply_fn=state.apply_fn, - params=state.params, - batch=padded_batch, - ) - else: - ( - reference_chosen_log_probs, - reference_rejected_log_probs, - _, - _, - ) = self.concatenated_forward( - apply_fn=self.ref_model_state.apply_fn, - params=self.ref_model_state.params, - batch=padded_batch, - ) - - return reference_chosen_log_probs, reference_rejected_log_probs - - def _save_state( - self, - state: EasyDeLState, - gather_fns: Optional[Any | Mapping[str, Callable] | dict[Callable]], - milestone: bool = False - ) -> str: - step = int( - jax.device_get( - state.step - ) - ) + self.arguments.step_start_point if self.arguments.step_start_point is not None else int( - jax.device_get( - state.step - ) - ) - checkpoint_name = f"{self.arguments.model_name}-S{step}" - filename = f"{checkpoint_name}_{step}" if milestone else f"{checkpoint_name}" - filename += ".easy" - termcolor.cprint(f"Saving Model {filename}.", color="cyan", force_color=True) - state.save_state( - filename=filename, - checkpoint_dir=os.path.join(self.arguments.save_dir, self.arguments.model_name), - gather_fns=gather_fns, - float_dtype=self.dtype, - verbose=self.arguments.verbose, - save_optimizer=self.arguments.save_optimizer_state, - ) - return filename - - def train(self) -> DPOTrainerOutput: - assert self.model_state is not None, "model_state can not be None for training purpose" - with self.mesh: - with jax.default_device(jax.devices("cpu")[0]) if self.low_mem_usage else leave_alone_context_manager: - dir_prefix: str = "/dev/shm" if sys.platform != "win32" else "." - checkpoint_path = "SAVING_SKIPPED" - - pbar = tqdm(total=self.max_training_steps) - pbar.set_description("Training") - current_step = self.model_state.step.tolist() if isinstance( - self.model_state.step, - jax.Array - ) else self.model_state.step - - loss_sum = None - chosen_rewards_sum = None - rejected_rewards_sum = None - - try: - for epoch_index in range(self.arguments.num_train_epochs): - for batch in self.dataloader_train: - current_step += 1 - if self.arguments.step_start_point > current_step: - ... - elif current_step < self.max_training_steps: - time_start = time.time() - - self.model_state, metrics = self.sharded_train_step_function( - self.model_state, - batch - ) - total_time = time.time() - time_start - ( - loss, chosen_rewards, rejected_rewards - ) = metrics.loss, metrics.chosen_rewards[0], metrics.rejected_rewards[0] - - loss_sum = loss.tolist() if loss_sum is None else loss_sum + loss - - rejected_rewards_sum = ( - rejected_rewards.tolist() if ( - rejected_rewards_sum is None - ) else rejected_rewards_sum + rejected_rewards - ) - chosen_rewards_sum = ( - chosen_rewards.tolist() if ( - chosen_rewards_sum is None - ) else chosen_rewards_sum + chosen_rewards - ) - train_metrics = { - "train/loss": loss.tolist(), - "train/mean_loss": loss_sum / (current_step - self.arguments.step_start_point), - "train/mean_rejected_rewards": rejected_rewards_sum / ( - current_step - self.arguments.step_start_point - ), - "train/mean_chosen_rewards": chosen_rewards_sum / ( - current_step - self.arguments.step_start_point - ), - "train/learning_rate": self.scheduler( - jax.device_get(self.model_state.step) - ).tolist(), - "train/step": current_step, - "train/step_time": total_time, - "train/perplexity": jnp.exp(loss).tolist(), - "train/epoch": epoch_index - } - log_metrics = copy.deepcopy(train_metrics) - train_metrics.update(self.arguments.captured_memory) - if self.arguments.use_wandb: - with jax.spmd_mode("allow_all"): - self.wandb_runtime.log( - train_metrics - ) - pbar.update(1) - pbar.set_postfix(**{k.replace("train/", ""): v for k, v in log_metrics.items()}) - else: - break - except KeyboardInterrupt: - termcolor.cprint( - "KeyboardInterrupt At training model Will return Current State of the Model with Parameters.", - color="cyan", - force_color=True - ) - - except EasyDeLTimerError: - termcolor.cprint( - "Training reached out maximum training Time Killing training Process " - "and Will return Current State of the Model with Parameters.", - color="cyan", - force_color=True - ) - - if self.arguments.merge_lora_rapture_parameters and self.rapture is not None: - print( - termcolor.colored( - "Info : ", color="red", force_color=True - ), - termcolor.colored( - "Merging LoRA Parameters.", color="white", force_color=True - ) - ) - self.model_state = self.model_state.replace( - params=self.rapture.merge_parameters(self.model_state.params) - ) - - shard_fns, gather_fns = make_shard_and_gather_fns( - partition_specs=match_partition_rules( - rules=self.model_state.module.config.get_partition_rules( - self.arguments.fully_sharded_data_parallel - ), - params=jax.eval_shape(lambda: self.model_state) - ), - dtype_specs=self.arguments.dtype - ) - output = DPOTrainerOutput( - state=self.model_state, - mesh=self.mesh, - shard_fns=shard_fns, - gather_fns=gather_fns, - checkpoint_manager=self.checkpoint_manager, - ) - if self.arguments.save_steps is None and self.arguments.do_last_save: - shard_fns, gather_fns = make_shard_and_gather_fns( - match_partition_rules( - self.config.get_partition_rules( - fully_sharded_data_parallel=self.arguments.fully_sharded_data_parallel - ) if self.arguments.custom_rule is None else self.arguments.custom_rule, - jax.eval_shape(lambda: self.model_state) - ), - dtype_specs=self.dtype - ) # You have to re-init the new shard and gather functions in order to be able to skip LoRA weight - # crashing errors and saving errors - filename = self._save_state( - state=self.model_state, - gather_fns=gather_fns - ) - checkpoint_path = f"{str(self.arguments.get_path())}/{filename}" - - if self.arguments.do_eval: - for _ in self.eval( - self.model_state - ): - ... - - output.checkpoint_path = checkpoint_path - output.last_save_file_name = filename - wandb.finish() - - return output - - def eval(self, model_state: EasyDeLState) -> typing.Iterator[dict]: - """Evaluate the Given Model State and yield the eval metrics""" - assert self.eval_dataset is not None, "`dataloader_eval` is required by evaluator function." - with self.mesh: - pbar = tqdm(total=self.max_evaluation_steps) - pbar.set_description("Evaluating") - current_step = 0 - loss_sum = None - chosen_rewards_sum = None - rejected_rewards_sum = None - - try: - for batch in self.dataloader_eval: - current_step += 1 - time_start = time.time() - for key in self.arguments.ids_to_pop_from_dataset: - _ = batch.pop(key, None) - for key in list(batch.keys()): - if not ( - key.endswith("_input_ids") - or key.endswith("_attention_mask") - or key.endswith("_labels") - ): - _ = batch.pop(key, None) - - metrics = self.sharded_eval_step_function( - model_state, - batch - ) - total_time = time.time() - time_start - ( - loss, chosen_rewards, rejected_rewards - ) = metrics.loss, metrics.chosen_rewards[0], metrics.rejected_rewards[0] - - loss_sum = loss.tolist() if loss_sum is None else loss_sum + loss - rejected_rewards_sum = ( - rejected_rewards.tolist() if ( - rejected_rewards_sum is None - ) else rejected_rewards_sum + rejected_rewards - ) - chosen_rewards_sum = ( - chosen_rewards.tolist() if ( - chosen_rewards_sum is None - ) else chosen_rewards_sum + chosen_rewards - ) - - eval_metrics = { - "eval/loss": loss.tolist(), - "eval/mean_loss": loss_sum / (current_step - self.arguments.step_start_point), - "eval/mean_rejected_rewards": rejected_rewards_sum / ( - current_step - self.arguments.step_start_point - ), - "eval/mean_chosen_rewards": chosen_rewards_sum / ( - current_step - self.arguments.step_start_point - ), - "eval/step": current_step, - "eval/step_time": total_time, - "eval/perplexity": jnp.exp(loss).tolist(), - } - log_metrics = copy.deepcopy(eval_metrics) - eval_metrics.update(self.arguments.captured_memory) - if self.arguments.use_wandb: - with jax.spmd_mode("allow_all"): - self.wandb_runtime.log( - eval_metrics - ) - - pbar.update(1) - pbar.set_postfix(**{k.replace("eval/", ""): v for k, v in log_metrics.items()}) - yield eval_metrics - except KeyboardInterrupt: - termcolor.cprint( - "KeyboardInterrupt At Evaluation model Will return Nothing and just pass.", - color="cyan", - force_color=True - ) - - def __repr__(self): - - """ - The __repr__ function is used to generate a string representation of an object. - This function should return a string that can be parsed by the Python interpreter - to recreate the object. The __repr__ function is called when you use print() on an - object, or when you type its name in the REPL. - - :param self: Refer to the instance of the class - :return: A string representation of the object - """ - string = f"{self.__class__.__name__}(\n" - for k, v in self.__dict__.items(): - if not k.startswith("_"): - try: - repr_src = f"\t{k} : " + v.__str__().replace("\n", "\n\t") + "\n" - string += repr_src if len(repr_src) < 350 else f"\t{k} : " + f"{v.__class__.__name__}(...)" + "\n" - except TypeError: - repr_src = f"\t{k} : " + "EasyDeLReadingError" + "\n" - string += repr_src if len(repr_src) < 350 else f"\t{k} : " + f"{v.__class__.__name__}(...)" + "\n" - - return string + ")" - - def __str__(self): - - """ - The __str__ function is called when you use the print function or when str() is used. - It should return a string representation of the object. - - :param self: Refer to the instance of the class - :return: The object's string representation - """ - return self.__repr__() +import copy +import os +import sys +import time +import typing +import warnings +from abc import ABC +from collections import defaultdict +import flax.core +import jax +import tensorflow.data +import tensorflow_datasets +import termcolor +import wandb +from fjformer import match_partition_rules, make_shard_and_gather_fns +from tqdm import tqdm + +from typing import Optional, Literal, Dict, Union, Any, Callable, Mapping + +from jax.experimental.pjit import pjit +from datasets import Dataset +from jax import numpy as jnp + +from ...etils.etils import get_logger +from ..training_configurations import TrainArguments +from ..base_trainer import ( + BaseTrainer, + TrainerConfigureFunctionFuncOutput, + TrainerConfigureDataloaderFuncOutput, + TrainerConfigureModelFuncOutput +) +from ...etils import EasyDeLState, EasyDeLTimerError +from transformers import PreTrainedTokenizerBase +from jax.sharding import PartitionSpec + +from ...utils import Timers +from .utils import ( + pad_to_length, + DPODataCollatorWithPadding, + leave_alone_context_manager +) +from .fwd_bwd_functions import ( + create_dpo_train_function, + create_dpo_eval_function, + create_concatenated_forward, +) +from .modelling_output import DPOTrainerOutput + +logger = get_logger(__name__) + + +class DPOTrainer(BaseTrainer, ABC): + """ + easydel DPO Trainer Class + """ + + def __init__( + self, + arguments: TrainArguments, + model_state: EasyDeLState | str, + ref_model_state: Optional[EasyDeLState | str] = None, + beta: float = 0.1, + label_smoothing: float = .0, + loss_type: Literal["sigmoid", "hinge", "ipo", "kto"] = "sigmoid", + label_pad_token_id: int = -100, + padding_value: int = None, + train_dataset: Optional[Dataset] = None, + eval_dataset: Optional[Union[Dataset, Dict[str, Dataset]]] = None, + tokenizer: Optional[PreTrainedTokenizerBase] = None, + data_collator: Optional[Callable] = None, + max_length: Optional[int] = None, + max_prompt_length: Optional[int] = None, + max_target_length: Optional[int] = None, + precompute_ref_log_probs: bool = False, + model_init_kwargs: Optional[Dict] = None, + ref_model_init_kwargs: Optional[Dict] = None, + reference_free: bool = False, + auto_shard_model_state: bool = True, + auto_shard_ref_model_state: bool = True, + is_encoder_decoder: Optional[bool] = False, + dataset_map_arguments: Optional[dict] = None, + low_mem_usage: bool = True, + auto_fix_data: bool = True, + _do_init_fns: bool = True, + ): + + """ + The __init__ function is called when the class is instantiated. + It sets up the attributes of an object. + + + :param self: Refer to the object itself + :param model_state: EasyDeLState | str: Pass the model state to the trainer + :param ref_model_state: Optional[EasyDeLState | str]: Pass the reference model state + :param beta: float: Control the strength of the regularization term + :param label_smoothing: float: Smooth the labels + :param loss_type: Literal["sigmoid", "hinge", "ipo", "kto"] : Determine the loss function used + :param arguments: TrainArguments: Pass the arguments to the trainer + :param label_pad_token_id: int: Pad the labels + :param padding_value: int: Specify the value that is used for padding + :param train_dataset: Optional[Dataset]: Load the training dataset + :param eval_dataset: Optional[Union[Dataset, Dict[str, Dataset]]] : Pass the evaluation dataset to the trainer + :param tokenizer: Optional[PreTrainedTokenizerBase]: Pass the tokenizer to the trainer + :param max_length: Optional[int]: Set the maximum length of the input sequence + :param max_prompt_length: Optional[int]: Set the maximum length of the prompt + :param max_target_length: Optional[int]: Truncate the target sequence + :param data_collator: Optional[Callable]: Function to be used for creating datasets. + :param precompute_ref_log_probs: bool: Precompute the log probabilities of the reference model + :param model_init_kwargs: Optional[Dict]: Pass in the model_kwargs to model for init process + :param ref_model_init_kwargs: Optional[Dict]: Pass the ref_model_init_kwargs to ref_model for init process + :param auto_shard_model_state: bool: whenever to automatically shard `model_state` + :param auto_shard_ref_model_state: bool: whenever to automatically shard `ref_model_state` + :param dataset_map_arguments: Optional[dict]: arguments to be passed to train and eval datasets for + tokenizing process with `dataset.map`. + :param _do_init_fns: bool : preferred to set ture to trainer will automatically configure + model with provided training Arguments + :param : Set the padding value for the model + """ + assert arguments is not None, ( + "You Have to pass arguments that will be used for training but you have passed" + "`arguments=None`" + ) + assert isinstance(arguments, TrainArguments), ( + f"arguments type must be `TrainArguments` but got {type(arguments)}" + ) + if model_init_kwargs is None: + model_init_kwargs = {} + elif not isinstance(model_state, str): + raise ValueError("You passed model_kwargs to the DPOTrainer. But your model is already instantiated.") + + if ref_model_init_kwargs is None: + ref_model_init_kwargs = {} + elif not isinstance(ref_model_state, str): + raise ValueError( + "You passed ref_model_kwargs to the DPOTrainer. But your ref_model is already instantiated." + ) + + if isinstance(model_state, str): + warnings.warn( + "You passed a model_id to the DPOTrainer. This will automatically create an " + "`AutoEasyDeLModelForCausalLM` for you." + ) + model_state = EasyDeLState.from_pretrained( + model_state, + **model_init_kwargs + ) + if isinstance(ref_model_state, str): + warnings.warn( + "You passed a ref model_id to the DPOTrainer. This will automatically create an " + "`AutoEasyDeLModelForCausalLM`" + ) + ref_model_state = EasyDeLState.from_pretrained( + ref_model_state, + **ref_model_init_kwargs + ) + + if loss_type in ["hinge", "ipo", "kto_pair"] and label_smoothing > 0: + warnings.warn( + "You are using a loss type that does not support label smoothing. Ignoring label_smoothing parameter." + ) + self.auto_fix_data = auto_fix_data + + if tokenizer is None: + raise ValueError("tokenizer must be specified to tokenize a DPO dataset.") + if max_length is None: + warnings.warn( + "`max_length` is not set in the DPOTrainer's init" + " it will default to `512` by default, but you should do it yourself in the future.", + UserWarning, + ) + max_length = 512 + if max_prompt_length is None: + warnings.warn( + "`max_prompt_length` is not set in the DPOTrainer's init" + " it will default to `128` by default, but you should do it yourself in the future.", + UserWarning, + ) + max_prompt_length = 128 + + if max_target_length is None and is_encoder_decoder: + warnings.warn( + "When using an encoder decoder architecture, you should set `max_target_length` in the " + "DPOTrainer's init it will default to `128` by default, but you should do it yourself in the future.", + UserWarning, + ) + max_target_length = 128 + + padding_value = padding_value if padding_value is not None else tokenizer.pad_token_id + self.max_length = max_length + self.label_pad_token_id = label_pad_token_id + self.padding_value = padding_value + self.max_prompt_length = max_prompt_length + self.truncation_mode = arguments.truncation_mode + + self.max_target_length = max_target_length + self.tokenizer = tokenizer + self.precompute_ref_log_probs = precompute_ref_log_probs + self.reference_free = reference_free + self.is_encoder_decoder = False + self._precomputed_train_ref_log_probs = False + self._precomputed_eval_ref_log_probs = False + self.beta = beta + self.label_smoothing = label_smoothing + self.loss_type = loss_type + self.low_mem_usage = low_mem_usage + data_collator = DPODataCollatorWithPadding( + max_prompt_length=self.max_prompt_length, + max_target_length=self.max_target_length, + pad_token_id=tokenizer.pad_token_id, + label_pad_token_id=label_pad_token_id, + is_encoder_decoder=False, + ) if data_collator is None else data_collator + self._stored_metrics = defaultdict(lambda: defaultdict(list)) + if dataset_map_arguments is None: + dataset_map_arguments = {} + train_dataset = train_dataset.map( + self.tokenize_row, + **dataset_map_arguments + ) + if eval_dataset is not None: + eval_dataset = eval_dataset.map( + self.tokenize_row, + **dataset_map_arguments + ) + + self.arguments = arguments + self.hp_name = None + self.deepspeed = None + self.is_in_train = False + + self.data_collator = data_collator + self.train_dataset = train_dataset + self.eval_dataset = eval_dataset + self.tokenizer = tokenizer + self.ref_model_state = ref_model_state + self.model_state = model_state + self._loggers_initialized = False + self.mesh = self.arguments.get_mesh() + assert padding_value is not None, "`padding_value` can not be set as `None` it must be an integer." + + self.concatenated_forward = create_concatenated_forward( + is_encoder_decoder=self.is_encoder_decoder, + padding_value=padding_value, + label_pad_token_id=label_pad_token_id, + ) + self.auto_shard_ref_model_state = auto_shard_ref_model_state + self.auto_shard_model_state = auto_shard_model_state + + self._cached_p_l_s = None + self._cached_c_l_s = None + self._cached_r_l_s = None + super().__init__( + arguments=arguments, + dataset_train=train_dataset, + dataset_eval=eval_dataset, + finetune=True, + checkpoint_path=None, + _do_init_fns=_do_init_fns + ) + + def initialize_trainer_utils(self): + """ + The initialize_trainer_utils function is responsible for initializing the following: + - wandb_runtime (if you use_wandb is True) + - timer object (for logging time taken by various functions) + - dataloader objects for training and evaluation data, along with max steps per epoch. + The configure_dataloader function accomplishes this task. + + :param self: Represent the instance of the class + :return: A tuple of functions + + """ + self.wandb_runtime = self.arguments.get_wandb_init() if self.arguments.use_wandb else None + self.timer = Timers( + use_wandb=False, + tensorboard_writer=self.arguments.get_board() + ) + + self.timer("configure dataloaders").start() + dataset_configurations = self.configure_dataloader() + self.dataloader_train = dataset_configurations.dataloader_train + self.max_training_steps = dataset_configurations.max_training_steps + self.dataloader_eval = dataset_configurations.dataloader_eval + self.max_evaluation_steps = dataset_configurations.max_evaluation_steps + + self.timer("configure dataloaders").stop() + + self.timer.log(["configure dataloaders"]) + + self.timer("configure Model, Optimizer, Scheduler and Config").start() + model_configurations = self.configure_model() + model = model_configurations.model + tx = model_configurations.tx + scheduler = model_configurations.scheduler + config = model_configurations.config + self.model = model + self.tx = tx + self.scheduler = scheduler + self.config = config + if self.rapture is not None: + lora_modules = self.rapture.apply_lora( + module=model, + parameters=self.arguments.rapture_config.parameters, + tx=tx, + ) + self.lora_parameters = lora_modules.lora_parameters + self.lora_apply_fn = lora_modules.lora_module.__call__ + self.lora_opt_state = lora_modules.lora_opt_state + self.lora_model = lora_modules.lora_module + self.lora_tx = lora_modules.lora_tx + + self.timer("configure Model, Optimizer, Scheduler and Config").stop() + self.timer.log(["configure Model, Optimizer, Scheduler and Config"]) + + self.timer("configure functions and sharding them").start() + + if self.auto_shard_model_state: + self.timer("Sharding Model State").start() + self.model_state: EasyDeLState = self.shard_states( + self.model_state, + self.model_state.module.config.get_partition_rules(self.arguments.fully_sharded_data_parallel) + ) + + termcolor.cprint("initializing TX and Schedulers for `model_state`", force_color=True, color="cyan") + + params_with_opt = ( + self.model_state.params[ + 'params' + ] if '_overwrite_with_gradient' in self.model_state.params else self.model_state.params + ) + opt_state = self.tx.init(params_with_opt) + + self.model_state = self.model_state.replace( + opt_state=opt_state, + tx=self.tx + ) + + self.timer("Sharding Model State").stop() + self.timer.log(["Sharding Model State"]) + if self.auto_shard_ref_model_state and self.ref_model_state is not None: + self.timer("Sharding Ref Model State").start() + self.ref_model_state = self.shard_states( + self.ref_model_state, + self.ref_model_state.module.config.get_partition_rules(self.arguments.fully_sharded_data_parallel) + ) + self.timer("Sharding Ref Model State").stop() + self.timer.log(["Sharding Ref Model State"]) + + function_configurations = self.configure_functions() + self.create_sharded_state_from_params_function = ( + function_configurations.create_sharded_state_from_params_function + ) + self.sharded_train_step_function = function_configurations.sharded_train_step_function + self.sharded_eval_step_function = function_configurations.sharded_eval_step_function + self.mesh = function_configurations.mesh + self.checkpoint_manager = function_configurations.checkpoint_manager + self.initialize_state_function = function_configurations.initialize_state_function + self.timer("configure functions and sharding them").stop() + self.timer.log(["configure functions and sharding them"]) + + def create_collate_function( + self, + max_sequence_length: int, + truncation_mode: typing.Literal["keep_end", "keep_start"] = "keep_end", + ) -> Callable: + return self.data_collator + + def shard_states(self, state, rules): + with self.arguments.get_mesh(): + partition_spec = match_partition_rules(rules=rules, params=jax.eval_shape(lambda: state)) + + def _shard(x): + return x + + shard = pjit( + _shard, + in_shardings=(PartitionSpec(),), + out_shardings=partition_spec + ) + return shard(state) + + def configure_dataloader(self) -> TrainerConfigureDataloaderFuncOutput: + dataloader_train = self.get_train_dataloader() + max_evaluation_steps = None + dataloader_eval = None + + max_training_steps = self.arguments.num_train_epochs * len( + dataloader_train + ) if self.arguments.max_training_steps is None else self.arguments.max_training_steps + if self.eval_dataset is not None: + dataloader_eval = self.get_eval_dataloader(self.eval_dataset) + max_evaluation_steps = len(dataloader_eval) + return TrainerConfigureDataloaderFuncOutput( + dataloader_train=dataloader_train, # type:ignore + max_training_steps=max_training_steps, + dataloader_eval=dataloader_eval, + max_evaluation_steps=max_evaluation_steps + ) + + def configure_functions(self) -> TrainerConfigureFunctionFuncOutput: + def initialize_state_function(): + initialized_parameters = self.model.init_weights( + jax.random.PRNGKey(0), + self.arguments.init_input_shape + ) + + if self.arguments.dtype == jnp.bfloat16: + initialized_parameters = self.model.to_bf16(initialized_parameters) + elif self.arguments.dtype == jnp.float16: + initialized_parameters = self.model.to_fp16(initialized_parameters) + + tx = self.tx + parameters = flax.core.freeze({"params": initialized_parameters}) + tx_init = copy.deepcopy(self.arguments.optimizer_kwargs) + + if self.rapture is not None: + lora_parameters = self.lora_parameters + if self.arguments.dtype == jnp.bfloat16: + lora_parameters = self.model.to_bf16(lora_parameters) + elif self.arguments.dtype == jnp.float16: + lora_parameters = self.model.to_fp16(lora_parameters) + + return EasyDeLState( + step=0, + apply_fn=self.lora_apply_fn, + params=lora_parameters, + tx=self.lora_tx, + opt_state=self.lora_opt_state, + tx_init=EasyDeLState.safe_dict(tx_init), + hyperparameters=EasyDeLState.create_hyperparameters(self.model.config.model_type), + module=self.lora_model, + module_config=self.model_state.module.config, + module_config_args=None, + ) + else: + return EasyDeLState.create( + tx=tx, + params=parameters, + apply_fn=self.model.__call__, + module_config=copy.deepcopy(self.model_state.module.config), + tx_init=tx_init, + hyperparameters=EasyDeLState.create_hyperparameters(self.model.config.model_type), + module=self.model, + module_config_args=None + ) + + def create_state_from_params_function(parameters): + if self.rapture is None: + return EasyDeLState.create( + tx=self.tx, + params=parameters, + apply_fn=self.model.__call__, + module_config=copy.deepcopy(self.model_state.module.config), + tx_init=copy.deepcopy(self.arguments.optimizer_kwargs), + hyperparameters=EasyDeLState.create_hyperparameters(self.model.config.model_type), + module=self.model, + module_config_args=None + ) + else: + return EasyDeLState( + step=0, + apply_fn=self.lora_apply_fn, + params=parameters, + tx=self.lora_tx, + opt_state=self.lora_opt_state, + tx_init=EasyDeLState.safe_dict(copy.deepcopy(self.arguments.optimizer_kwargs)), + hyperparameters=EasyDeLState.create_hyperparameters(self.model.config.model_type), + module=self.lora_model, + module_config=self.model_state.module.config, + module_config_args=None, + ) + + state_shape = jax.eval_shape(lambda: self.model_state) + + state_partition_spec = match_partition_rules( + self.config.get_partition_rules( + fully_sharded_data_parallel=self.arguments.fully_sharded_data_parallel + ) if self.arguments.custom_rule is None else self.arguments.custom_rule, + state_shape + ) + create_sharded_state_from_params_function = pjit( + create_state_from_params_function, + in_shardings=(state_partition_spec.params,), + out_shardings=state_partition_spec, + donate_argnums=(0,) + ) + train_function = create_dpo_train_function( + concatenated_forward=self.concatenated_forward, + ref_state=self.ref_model_state, + loss_type=self.loss_type, + reference_free=self.reference_free, + label_smoothing=self.label_smoothing, + beta=self.beta + ) + sharded_train_step_function = pjit( + train_function, + in_shardings=(state_partition_spec, self.arguments.step_partition_spec), + out_shardings=(state_partition_spec, PartitionSpec()), + ) + + eval_function = create_dpo_eval_function( + concatenated_forward=self.concatenated_forward, + ref_state=self.ref_model_state, + loss_type=self.loss_type, + reference_free=self.reference_free, + label_smoothing=self.label_smoothing, + beta=self.beta + ) + + sharded_eval_step_function = pjit( + eval_function, + in_shardings=(state_partition_spec, self.arguments.step_partition_spec), + out_shardings=(state_partition_spec, PartitionSpec()), + ) + + self.arguments.ckpt_path_exists() + self.state_partition_spec = state_partition_spec + self.state_shape = state_shape + checkpoint_manager = self.arguments.get_streaming_checkpointer() + mesh = self.arguments.get_mesh() + return TrainerConfigureFunctionFuncOutput( + initialize_state_function=initialize_state_function, + sharded_train_step_function=sharded_train_step_function, + create_sharded_state_from_params_function=create_sharded_state_from_params_function, + checkpoint_manager=checkpoint_manager, + mesh=mesh, + sharded_eval_step_function=sharded_eval_step_function + ) + + def configure_model(self) -> TrainerConfigureModelFuncOutput: + config = self.model_state.module.config + tx, scheduler = self.arguments.get_optimizer_and_scheduler(self.max_training_steps) + return TrainerConfigureModelFuncOutput( + model=self.model_state.module, + config=config, # type: ignore + scheduler=scheduler, + tx=tx + ) + + def _get_train_dataloader(self) -> tensorflow.data.Dataset: + + """ + The _get_train_dataloader function is used to create a tensorflow.data.Dataset object for the training dataset. + + :param self: Represent the instance of the class + :return: A dataloader object + """ + if self.train_dataset is None: + raise ValueError("Trainer: training requires a train_dataset.") + + train_dataset = self.train_dataset + data_collator = self.data_collator + + return tensorflow_datasets.as_numpy( + train_dataset.to_tf_dataset( + batch_size=self.arguments.total_batch_size, + collate_fn=data_collator, + num_workers=self.arguments.dataloader_num_workers, + shuffle=True, + drop_remainder=True + ) + ) + + def _get_eval_dataloader(self, eval_dataset: Optional[Dataset] = None) -> tensorflow.data.Dataset: + """ + Returns the evaluation [`~tensorflow.data.Dataset`]. + + Subclass and override this method if you want to inject some custom behavior. + + Args: + eval_dataset (`torch.utils.data.Dataset`, *optional*): + If provided, will override `self.eval_dataset`. If it is a [`~datasets.Dataset`], columns not accepted + by the `model.forward()` method are automatically removed. It must implement `__len__`. + """ + if eval_dataset is None and self.eval_dataset is None: + raise ValueError("Trainer: evaluation requires an eval_dataset.") + eval_dataset = eval_dataset if eval_dataset is not None else self.eval_dataset + + return tensorflow_datasets.as_numpy( + eval_dataset.to_tf_dataset( + batch_size=self.arguments.total_batch_size, + collate_fn=self.data_collator, + num_workers=self.arguments.dataloader_num_workers, + shuffle=False, + drop_remainder=True + ) + ) + + def get_train_dataloader( + self, + ) -> tensorflow.data.Dataset: + """ + Returns the training [`~tensorflow.data.Dataset`]. + """ + + if self.precompute_ref_log_probs and not self._precomputed_train_ref_log_probs: + + data_loader = tensorflow_datasets.as_numpy( + self.train_dataset.to_tf_dataset( + batch_size=self.arguments.total_batch_size, + collate_fn=self.data_collator, + num_workers=self.arguments.dataloader_num_workers, + shuffle=False, + drop_remainder=True + ) + ) + reference_chosen_log_probs = [] + reference_rejected_log_probs = [] + for padded_batch in tqdm(iterable=data_loader, desc="Train dataset reference log probs"): + reference_chosen_logp, reference_rejected_logp = self.compute_reference_log_probs( + self.model_state, + padded_batch, + ) + reference_chosen_log_probs.append(reference_chosen_logp) + reference_rejected_log_probs.append(reference_rejected_logp) + + all_reference_chosen_log_probs = jnp.concatenate(reference_chosen_log_probs) + all_reference_rejected_log_probs = jnp.concatenate(reference_rejected_log_probs) + self.train_dataset = self.train_dataset.add_column( + name="reference_chosen_log_probs", column=all_reference_chosen_log_probs + ) + self.train_dataset = self.train_dataset.add_column( + name="reference_rejected_log_probs", column=all_reference_rejected_log_probs + ) + + self._precomputed_train_ref_log_probs = True + return self._get_train_dataloader() + + def get_eval_dataloader(self, eval_dataset: Optional[Dataset] = None) -> tensorflow.data.Dataset: + """ + Returns the evaluation [`~tensorflow.data.Dataset`]. + """ + if eval_dataset is None and self.eval_dataset is None: + raise ValueError("Trainer: evaluation requires an eval_dataset.") + eval_dataset = eval_dataset if eval_dataset is not None else self.eval_dataset + + if self.precompute_ref_log_probs and not self._precomputed_eval_ref_log_probs: + + # prepare dataloader + data_loader = tensorflow_datasets.as_numpy( + eval_dataset.to_tf_dataset( + batch_size=self.arguments.total_batch_size, + collate_fn=self.data_collator, + num_workers=self.arguments.dataloader_num_workers, + shuffle=False, + drop_remainder=True + ) + ) + + reference_chosen_log_probs = [] + reference_rejected_log_probs = [] + for padded_batch in tqdm(iterable=data_loader, desc="Eval dataset reference log probs"): + reference_chosen_logp, reference_rejected_logp = self.compute_reference_log_probs( + self.model_state, + padded_batch + ) + reference_chosen_log_probs.append(reference_chosen_logp.cpu()) + reference_rejected_log_probs.append(reference_rejected_logp.cpu()) + + all_reference_chosen_log_probs = jnp.concatenate(reference_chosen_log_probs) + all_reference_rejected_log_probs = jnp.concatenate(reference_rejected_log_probs) + + eval_dataset = eval_dataset.add_column(name="reference_chosen_log_probs", + column=all_reference_chosen_log_probs) + eval_dataset = eval_dataset.add_column( + name="reference_rejected_log_probs", column=all_reference_rejected_log_probs + ) + + if self.eval_dataset is not None: + self.eval_dataset = eval_dataset + self._precomputed_eval_ref_log_probs = True + + return self._get_eval_dataloader(eval_dataset=eval_dataset) + + def build_tokenized_answer(self, prompt, answer): + """ + Llama tokenizer does satisfy `enc(a + b) = enc(a) + enc(b)`. + It does ensure `enc(a + b) = enc(a) + enc(a + b)[len(enc(a)):]`. + """ + + full_tokenized = self.tokenizer(prompt + answer, add_special_tokens=False) + prompt_input_ids = self.tokenizer(prompt, add_special_tokens=False)["input_ids"] + + answer_input_ids = full_tokenized["input_ids"][len(prompt_input_ids):] + answer_attention_mask = full_tokenized["attention_mask"][len(prompt_input_ids):] + prompt_input_ids = jnp.asarray(prompt_input_ids, dtype="i4") + answer_input_ids = jnp.asarray(answer_input_ids, dtype="i4") + full_concat_input_ids = jnp.concatenate( + ( + prompt_input_ids, + answer_input_ids + ) + ) + + # Prepare input tokens for token by token comparison + full_input_ids = jnp.array(full_tokenized["input_ids"]) + + if len(full_input_ids) != len(full_concat_input_ids): + raise ValueError("Prompt input ids and answer input ids should have the same length.") + + response_token_ids_start_idx = len(prompt_input_ids) + if prompt_input_ids.tolist() != full_tokenized["input_ids"][:response_token_ids_start_idx]: + response_token_ids_start_idx -= 1 + + prompt_input_ids = full_tokenized["input_ids"][:response_token_ids_start_idx] + prompt_attention_mask = full_tokenized["attention_mask"][:response_token_ids_start_idx] + + if len(prompt_input_ids) != len(prompt_attention_mask): + raise ValueError("Prompt input ids and attention mask should have the same length.") + + answer_input_ids = full_tokenized["input_ids"][response_token_ids_start_idx:] + answer_attention_mask = full_tokenized["attention_mask"][response_token_ids_start_idx:] + + return dict( + prompt_input_ids=jnp.array(prompt_input_ids, dtype="i4"), + prompt_attention_mask=jnp.array(prompt_attention_mask, dtype="i4"), + input_ids=jnp.array(answer_input_ids, dtype="i4"), + attention_mask=jnp.array(answer_attention_mask, dtype="i4"), + ) + + def tokenize_row(self, feature, state: EasyDeLState = None) -> Dict: + + """ + The tokenize_row function is responsible for taking a single row of data and converting it into the format that + the model expects. This includes: + - Tokenizing the text (using HuggingFace's tokenizer) + - Padding/truncating sequences to a fixed length (if necessary) + - Creating attention masks, which tell the model which tokens are padding and which aren't. + + :param self: Represent the instance of the class + :param feature: Pass in the data from the dataset + :param state: EasyDeLState: Keep track of the state of the tokenizer + :return: A dictionary of the following keys + """ + batch = {} + prompt = feature["prompt"] + chosen = feature["chosen"] + rejected = feature["rejected"] + + if not isinstance(prompt, str): + raise ValueError(f"prompt should be an str but got {type(prompt)} , {prompt}") + prompt_tokens = self.tokenizer( + prompt, + add_special_tokens=False, + return_tensors="np", + ) + prompt_tokens = {f"prompt_{k}": v for k, v in prompt_tokens.items()} + + if not isinstance(chosen, str): + raise ValueError(f"chosen should be an str but got {type(chosen)} , {chosen}") + chosen_tokens = self.build_tokenized_answer(prompt, chosen) + + if not isinstance(rejected, str): + raise ValueError(f"rejected should be an str but got {type(rejected)}") + rejected_tokens = self.build_tokenized_answer(prompt, rejected) + v2d = lambda ar: ar.reshape(1, -1) if ar.ndim == 1 else ar + + def add_tkn(n, ar): + return jnp.concatenate( + ( + jnp.array(n).reshape(1, 1), + v2d(ar) + ), axis=-1 + ) + + def add_post_tkn(n, ar): + return jnp.concatenate( + ( + v2d(ar), + jnp.array(n).reshape(1, 1) + ), axis=-1 + ) + + prompt_tokens["prompt_input_ids"] = add_tkn( + self.tokenizer.bos_token_id, + prompt_tokens["prompt_input_ids"] + ) + chosen_tokens["prompt_input_ids"] = add_tkn( + self.tokenizer.bos_token_id, + chosen_tokens["prompt_input_ids"] + ) + rejected_tokens["prompt_input_ids"] = add_tkn( + self.tokenizer.bos_token_id, + rejected_tokens["prompt_input_ids"] + ) + + prompt_tokens["prompt_attention_mask"] = add_tkn( + 1, prompt_tokens["prompt_attention_mask"] + ) + chosen_tokens["prompt_attention_mask"] = add_tkn( + 1, chosen_tokens["prompt_attention_mask"] + ) + rejected_tokens["prompt_attention_mask"] = add_tkn( + 1, rejected_tokens["prompt_attention_mask"] + ) + + # add EOS token to end of answer + chosen_tokens["input_ids"] = add_post_tkn(self.tokenizer.eos_token_id, chosen_tokens["input_ids"]) + chosen_tokens["attention_mask"] = add_post_tkn(1, chosen_tokens["attention_mask"]) + + rejected_tokens["input_ids"] = add_post_tkn(self.tokenizer.eos_token_id, rejected_tokens["input_ids"]) + rejected_tokens["attention_mask"] = add_post_tkn(1, rejected_tokens["attention_mask"]) + + longer_response_length = max(chosen_tokens["input_ids"].shape[-1], rejected_tokens["input_ids"].shape[-1]) + + # if combined sequence is too long, truncate the prompt + for answer_tokens in [chosen_tokens, rejected_tokens, prompt_tokens]: + length_rn = answer_tokens["prompt_input_ids"].shape[-1] + longer_response_length + if length_rn > self.max_length: + + if self.truncation_mode == "keep_start": + for k in ["prompt_input_ids", "prompt_attention_mask"]: + answer_tokens[k] = answer_tokens[k][:, : self.max_prompt_length] + elif self.truncation_mode == "keep_end": + for k in ["prompt_input_ids", "prompt_attention_mask"]: + answer_tokens[k] = answer_tokens[k][:, -self.max_prompt_length:] + else: + raise ValueError(f"Unknown truncation mode: {self.truncation_mode}") + # if that's still too long, truncate the response + for answer_tokens in [chosen_tokens, rejected_tokens]: + if answer_tokens["prompt_input_ids"].shape[-1] + longer_response_length > self.max_length: + for k in ["input_ids", "attention_mask"]: + answer_tokens[k] = answer_tokens[k][:, : self.max_length - self.max_prompt_length] + + chosen_sequence_tokens = { + k: jnp.concatenate( + (v2d(chosen_tokens[f"prompt_{k}"]), v2d(chosen_tokens[k])), + axis=-1 + ) for k in ["input_ids", "attention_mask"] + } + rejected_sequence_tokens = { + k: jnp.concatenate( + (v2d(rejected_tokens[f"prompt_{k}"]), v2d(rejected_tokens[k])), + axis=-1 + ) for k in ["input_ids", "attention_mask"] + } + chosen_sequence_tokens["labels"] = chosen_sequence_tokens["input_ids"][:] + chosen_sequence_tokens["labels"] = chosen_sequence_tokens["labels"].at[ + : len(chosen_tokens["prompt_input_ids"]) + ].set([self.label_pad_token_id] * len(chosen_tokens["prompt_input_ids"])) + rejected_sequence_tokens["labels"] = rejected_sequence_tokens["input_ids"][:] + rejected_sequence_tokens["labels"] = rejected_sequence_tokens["labels"].at[ + : len(rejected_tokens["prompt_input_ids"]) + ].set( + ([self.label_pad_token_id] * len(rejected_tokens["prompt_input_ids"])) + ) + + for k, tokens_ in { + "chosen_": chosen_sequence_tokens, + "rejected_": rejected_sequence_tokens, + "": prompt_tokens, + }.items(): + for type_key, tokens in tokens_.items(): + if type_key == "token_type_ids": + continue + + b, s = tokens.shape + + if self.max_prompt_length > s: + if k == "chosen_": + if type_key == "input_ids": + tokens = pad_to_length( + tokens, + self.max_target_length, + pad_value=self.padding_value, + axis=-1 + ) + elif type_key == "attention_mask": + tokens = pad_to_length( + tokens, + self.max_target_length, + pad_value=0, + axis=-1 + ) + elif type_key == "labels": + tokens = pad_to_length( + tokens, + self.max_target_length, + pad_value=self.padding_value, + axis=-1 + ) + + tokens = tokens[..., :self.max_target_length] + + if tokens.shape[-1] != self.max_target_length: + raise ValueError( + f"there was an error in padding token with `type_key` of {type_key}" + f". it must have sequence_length of {self.max_target_length} but we got {tokens.shape[-1]}" + f" From {k}{type_key}" + ) + tokens = tokens[..., :self.max_target_length] + elif k == "rejected_": + if type_key == "input_ids": + tokens = pad_to_length( + tokens, + self.max_target_length, + pad_value=self.padding_value, + axis=-1 + ) + elif type_key == "attention_mask": + tokens = pad_to_length( + tokens, + self.max_target_length, + pad_value=0, + axis=-1 + ) + elif type_key == "labels": + tokens = pad_to_length( + tokens, + self.max_target_length, + pad_value=self.padding_value, + axis=-1 + ) + tokens = tokens[..., :self.max_target_length] + if tokens.shape[-1] != self.max_target_length: + raise ValueError( + f"there was an error in padding token with `type_key` of {type_key}" + f". it must have sequence_length of {self.max_target_length} but we got {tokens.shape[-1]}" + f" From {k}{type_key}" + ) + elif k == "": + if type_key == "prompt_input_ids": + tokens = pad_to_length( + tokens, + self.max_prompt_length, + pad_value=self.padding_value, + axis=-1 + ) + elif type_key == "prompt_attention_mask": + tokens = pad_to_length( + tokens, + self.max_prompt_length, + pad_value=0, + axis=-1 + ) + elif type_key == "prompt_labels": + tokens = pad_to_length( + tokens, + self.max_prompt_length, + pad_value=self.padding_value, + axis=-1 + ) + tokens = tokens[..., :self.max_prompt_length] + if tokens.shape[-1] != self.max_prompt_length: + raise ValueError( + f"there was an error in padding token with `type_key` of {type_key}" + f". it must have sequence_length of {self.max_prompt_length} but we got {tokens.shape[-1]}" + f" From {k}{type_key}" + ) + batch[f"{k}{type_key}"] = tokens + return batch + + def compute_reference_log_probs( + self, + state: EasyDeLState, + padded_batch: Dict, + ) -> tuple[Any, Any]: + """ + Computes log probabilities of the reference model for a single padded batch of a DPO specific dataset. + """ + + if self.ref_model_state is None: + ( + reference_chosen_log_probs, + reference_rejected_log_probs, + _, + _, + ) = self.concatenated_forward( + apply_fn=state.apply_fn, + params=state.params, + batch=padded_batch, + ) + else: + ( + reference_chosen_log_probs, + reference_rejected_log_probs, + _, + _, + ) = self.concatenated_forward( + apply_fn=self.ref_model_state.apply_fn, + params=self.ref_model_state.params, + batch=padded_batch, + ) + + return reference_chosen_log_probs, reference_rejected_log_probs + + def _save_state( + self, + state: EasyDeLState, + gather_fns: Optional[Any | Mapping[str, Callable] | dict[Callable]], + milestone: bool = False + ) -> str: + step = int( + jax.device_get( + state.step + ) + ) + self.arguments.step_start_point if self.arguments.step_start_point is not None else int( + jax.device_get( + state.step + ) + ) + checkpoint_name = f"{self.arguments.model_name}-S{step}" + filename = f"{checkpoint_name}_{step}" if milestone else f"{checkpoint_name}" + filename += ".easy" + termcolor.cprint(f"Saving Model {filename}.", color="cyan", force_color=True) + state.save_state( + filename=filename, + checkpoint_dir=os.path.join(self.arguments.save_dir, self.arguments.model_name), + gather_fns=gather_fns, + float_dtype=self.dtype, + verbose=self.arguments.verbose, + save_optimizer=self.arguments.save_optimizer_state, + ) + return filename + + def train(self) -> DPOTrainerOutput: + assert self.model_state is not None, "model_state can not be None for training purpose" + with self.mesh: + with jax.default_device(jax.devices("cpu")[0]) if self.low_mem_usage else leave_alone_context_manager: + dir_prefix: str = "/dev/shm" if sys.platform != "win32" else "." + checkpoint_path = "SAVING_SKIPPED" + + pbar = tqdm(total=self.max_training_steps) + pbar.set_description("Training") + current_step = self.model_state.step.tolist() if isinstance( + self.model_state.step, + jax.Array + ) else self.model_state.step + + loss_sum = None + chosen_rewards_sum = None + rejected_rewards_sum = None + + try: + for epoch_index in range(self.arguments.num_train_epochs): + for batch in self.dataloader_train: + current_step += 1 + if self.arguments.step_start_point > current_step: + ... + elif current_step < self.max_training_steps: + time_start = time.time() + + self.model_state, metrics = self.sharded_train_step_function( + self.model_state, + batch + ) + total_time = time.time() - time_start + ( + loss, chosen_rewards, rejected_rewards + ) = metrics.loss, metrics.chosen_rewards[0], metrics.rejected_rewards[0] + + loss_sum = loss.tolist() if loss_sum is None else loss_sum + loss + + rejected_rewards_sum = ( + rejected_rewards.tolist() if ( + rejected_rewards_sum is None + ) else rejected_rewards_sum + rejected_rewards + ) + chosen_rewards_sum = ( + chosen_rewards.tolist() if ( + chosen_rewards_sum is None + ) else chosen_rewards_sum + chosen_rewards + ) + train_metrics = { + "train/loss": loss.tolist(), + "train/mean_loss": loss_sum / (current_step - self.arguments.step_start_point), + "train/mean_rejected_rewards": rejected_rewards_sum / ( + current_step - self.arguments.step_start_point + ), + "train/mean_chosen_rewards": chosen_rewards_sum / ( + current_step - self.arguments.step_start_point + ), + "train/learning_rate": self.scheduler( + jax.device_get(self.model_state.step) + ).tolist(), + "train/step": current_step, + "train/step_time": total_time, + "train/perplexity": jnp.exp(loss).tolist(), + "train/epoch": epoch_index + } + log_metrics = copy.deepcopy(train_metrics) + train_metrics.update(self.arguments.captured_memory) + if self.arguments.use_wandb: + with jax.spmd_mode("allow_all"): + self.wandb_runtime.log( + train_metrics + ) + pbar.update(1) + pbar.set_postfix(**{k.replace("train/", ""): v for k, v in log_metrics.items()}) + else: + break + except KeyboardInterrupt: + termcolor.cprint( + "KeyboardInterrupt At training model Will return Current State of the Model with Parameters.", + color="cyan", + force_color=True + ) + + except EasyDeLTimerError: + termcolor.cprint( + "Training reached out maximum training Time Killing training Process " + "and Will return Current State of the Model with Parameters.", + color="cyan", + force_color=True + ) + + if self.arguments.merge_lora_rapture_parameters and self.rapture is not None: + print( + termcolor.colored( + "Info : ", color="red", force_color=True + ), + termcolor.colored( + "Merging LoRA Parameters.", color="white", force_color=True + ) + ) + self.model_state = self.model_state.replace( + params=self.rapture.merge_parameters(self.model_state.params) + ) + + shard_fns, gather_fns = make_shard_and_gather_fns( + partition_specs=match_partition_rules( + rules=self.model_state.module.config.get_partition_rules( + self.arguments.fully_sharded_data_parallel + ), + params=jax.eval_shape(lambda: self.model_state) + ), + dtype_specs=self.arguments.dtype + ) + output = DPOTrainerOutput( + state=self.model_state, + mesh=self.mesh, + shard_fns=shard_fns, + gather_fns=gather_fns, + checkpoint_manager=self.checkpoint_manager, + ) + if self.arguments.save_steps is None and self.arguments.do_last_save: + shard_fns, gather_fns = make_shard_and_gather_fns( + match_partition_rules( + self.config.get_partition_rules( + fully_sharded_data_parallel=self.arguments.fully_sharded_data_parallel + ) if self.arguments.custom_rule is None else self.arguments.custom_rule, + jax.eval_shape(lambda: self.model_state) + ), + dtype_specs=self.dtype + ) # You have to re-init the new shard and gather functions in order to be able to skip LoRA weight + # crashing errors and saving errors + filename = self._save_state( + state=self.model_state, + gather_fns=gather_fns + ) + checkpoint_path = f"{str(self.arguments.get_path())}/{filename}" + + if self.arguments.do_eval: + for _ in self.eval( + self.model_state + ): + ... + + output.checkpoint_path = checkpoint_path + output.last_save_file_name = filename + wandb.finish() + + return output + + def eval(self, model_state: EasyDeLState) -> typing.Iterator[dict]: + """Evaluate the Given Model State and yield the eval metrics""" + assert self.eval_dataset is not None, "`dataloader_eval` is required by evaluator function." + with self.mesh: + pbar = tqdm(total=self.max_evaluation_steps) + pbar.set_description("Evaluating") + current_step = 0 + loss_sum = None + chosen_rewards_sum = None + rejected_rewards_sum = None + + try: + for batch in self.dataloader_eval: + current_step += 1 + time_start = time.time() + for key in self.arguments.ids_to_pop_from_dataset: + _ = batch.pop(key, None) + for key in list(batch.keys()): + if not ( + key.endswith("_input_ids") + or key.endswith("_attention_mask") + or key.endswith("_labels") + ): + _ = batch.pop(key, None) + + metrics = self.sharded_eval_step_function( + model_state, + batch + ) + total_time = time.time() - time_start + ( + loss, chosen_rewards, rejected_rewards + ) = metrics.loss, metrics.chosen_rewards[0], metrics.rejected_rewards[0] + + loss_sum = loss.tolist() if loss_sum is None else loss_sum + loss + rejected_rewards_sum = ( + rejected_rewards.tolist() if ( + rejected_rewards_sum is None + ) else rejected_rewards_sum + rejected_rewards + ) + chosen_rewards_sum = ( + chosen_rewards.tolist() if ( + chosen_rewards_sum is None + ) else chosen_rewards_sum + chosen_rewards + ) + + eval_metrics = { + "eval/loss": loss.tolist(), + "eval/mean_loss": loss_sum / (current_step - self.arguments.step_start_point), + "eval/mean_rejected_rewards": rejected_rewards_sum / ( + current_step - self.arguments.step_start_point + ), + "eval/mean_chosen_rewards": chosen_rewards_sum / ( + current_step - self.arguments.step_start_point + ), + "eval/step": current_step, + "eval/step_time": total_time, + "eval/perplexity": jnp.exp(loss).tolist(), + } + log_metrics = copy.deepcopy(eval_metrics) + eval_metrics.update(self.arguments.captured_memory) + if self.arguments.use_wandb: + with jax.spmd_mode("allow_all"): + self.wandb_runtime.log( + eval_metrics + ) + + pbar.update(1) + pbar.set_postfix(**{k.replace("eval/", ""): v for k, v in log_metrics.items()}) + yield eval_metrics + except KeyboardInterrupt: + termcolor.cprint( + "KeyboardInterrupt At Evaluation model Will return Nothing and just pass.", + color="cyan", + force_color=True + ) + + def __repr__(self): + + """ + The __repr__ function is used to generate a string representation of an object. + This function should return a string that can be parsed by the Python interpreter + to recreate the object. The __repr__ function is called when you use print() on an + object, or when you type its name in the REPL. + + :param self: Refer to the instance of the class + :return: A string representation of the object + """ + string = f"{self.__class__.__name__}(\n" + for k, v in self.__dict__.items(): + if not k.startswith("_"): + try: + repr_src = f"\t{k} : " + v.__str__().replace("\n", "\n\t") + "\n" + string += repr_src if len(repr_src) < 350 else f"\t{k} : " + f"{v.__class__.__name__}(...)" + "\n" + except TypeError: + repr_src = f"\t{k} : " + "EasyDeLReadingError" + "\n" + string += repr_src if len(repr_src) < 350 else f"\t{k} : " + f"{v.__class__.__name__}(...)" + "\n" + + return string + ")" + + def __str__(self): + + """ + The __str__ function is called when you use the print function or when str() is used. + It should return a string representation of the object. + + :param self: Refer to the instance of the class + :return: The object's string representation + """ + return self.__repr__() diff --git a/src/python/easydel/trainer/dpo/fwd_bwd_functions.py b/src/python/easydel/trainer/dpo/fwd_bwd_functions.py index 41880686a..06f76b2c2 100644 --- a/src/python/easydel/trainer/dpo/fwd_bwd_functions.py +++ b/src/python/easydel/trainer/dpo/fwd_bwd_functions.py @@ -1,731 +1,799 @@ -import typing -import chex -import flax.core -import jax - -from typing import Literal, Dict, Union, Tuple, List, Callable - -from jax import numpy as jnp -from ...etils import EasyDeLState -from flax.struct import dataclass -from .utils import pad_to_length - - -@dataclass -class DPOStepOut: - loss: chex.Array - chosen_rewards: chex.Array - rejected_rewards: chex.Array - - -def create_concatenated_forward( - is_encoder_decoder, - label_pad_token_id, - padding_value, - truncation_mode: typing.Literal["keep_end", "keep_start"] = "keep_end", - fixed_max_length: int | None = None -): - """ - The create_concatenated_forward function is a helper function that creates a forward pass function for the - model. The forward pass function takes in an apply_fn, which is the model's apply_fn, and runs it on concatenated - inputs. It returns chosen log probs, rejected log probs, chosen logits and rejected logits. - - :param is_encoder_decoder: Determine whether the model is an encoder-decoder model or not - :param label_pad_token_id: Pad the labels to the same length - :param padding_value: Pad the inputs to the same length - :param truncation_mode: typing.Literal["keep_end","keep_start"]: where to pad and where to keep. - :param fixed_max_length : int|None: by providing fixed_max_length the func will always return a fixed sequence length - and won't use dynamic methods. - :return: A function that takes in a apply_fn, params and a batch of inputs, - """ - - def concatenated_forward( - apply_fn: Callable, - params: dict | flax.core.FrozenDict, - batch: Dict[str, Union[List, chex.Array]] - - ) -> Tuple[chex.Array, chex.Array, chex.Array, chex.Array]: - """ - The concatenated_forward function is used to compute the log-probabilities of both chosen and rejected labels. - - :param apply_fn: Callable: Pass in the model function - :param params: dict | flax.core.FrozenDict: Pass the model parameters to the function - :param batch: Dict[str, Union[List, chex.Array]] : Pass the batch of data to the concatenated_forward function - :return: The log_probs of the chosen and rejected labels, as well as their corresponding logits - """ - assert padding_value is not None, "`padding_value` can not be set as `None` it must be an integer." - concatenated_batch = concatenated_inputs( - batch, - is_encoder_decoder=is_encoder_decoder, - label_pad_token_id=label_pad_token_id, - padding_value=padding_value, - truncation_mode=truncation_mode, - fixed_max_length=fixed_max_length - ) - len_chosen = batch["chosen_labels"].shape[0] - concatenated_batch["concatenated_input_ids"] = concatenated_batch["concatenated_input_ids"].reshape( - concatenated_batch["concatenated_input_ids"].shape[0], -1 - ) - concatenated_batch["concatenated_labels"] = concatenated_batch["concatenated_labels"].reshape( - concatenated_batch["concatenated_labels"].shape[0], -1 - ) - concatenated_batch["concatenated_attention_mask"] = concatenated_batch["concatenated_attention_mask"].reshape( - concatenated_batch["concatenated_attention_mask"].shape[0], -1 - ) - model_kwargs = ( - { - "labels": concatenated_batch["concatenated_labels"], - "decoder_input_ids": concatenated_batch.pop("concatenated_decoder_input_ids", None), - } - if is_encoder_decoder - else {} - ) - all_logits = apply_fn( - concatenated_batch["concatenated_input_ids"], - attention_mask=concatenated_batch["concatenated_attention_mask"], - params=params, - **model_kwargs, - ).logits - - all_log_probs = get_batch_log_probs( - all_logits, - concatenated_batch["concatenated_labels"], - average_log_prob=False, - is_encoder_decoder=is_encoder_decoder, - label_pad_token_id=label_pad_token_id, - ) - - chosen_log_probs = all_log_probs[:len_chosen] - rejected_log_probs = all_log_probs[len_chosen:] - - chosen_logits = all_logits[:len_chosen] - rejected_logits = all_logits[len_chosen:] - - return chosen_log_probs, rejected_log_probs, chosen_logits, rejected_logits - - return concatenated_forward - - -def get_batch_log_probs( - logits: chex.Array, - labels: chex.Array, - average_log_prob: bool = False, - label_pad_token_id: int = -100, - is_encoder_decoder: bool = False, -) -> chex.Array: - """ - The get_batch_log_probs function computes the log probability of a batch of sequences. - - :param logits: chex.Array: Compute the log_softmax of the input - :param labels: chex.Array: Mask the logits - :param average_log_prob: bool: Determine whether to average the log prob over the sequence length - :param label_pad_token_id: int: Mask out the padding tokens in the labels - :param is_encoder_decoder: bool: Indicate whether the model is an encoder-decoder model - :param : Determine whether to average the log probability over all tokens or not - :return: The log probability of the labels given the logits - """ - - # sudo code - # (per_token_log_probs * loss_mask).sum(-1) - # or - # (per_token_log_probs * loss_mask).sum(-1) / loss_mask.sum(-1) - - if logits.shape[:-1] != labels.shape: - raise ValueError("Logits (batch and sequence length dim) and labels must have the same shape.") - - if not is_encoder_decoder: - labels = labels[:, 1:] - logits = logits[:, :-1, :] - - batch, seq_len, dim = logits.shape - loss_mask = labels != label_pad_token_id - labels = jax.lax.select( - labels == label_pad_token_id, - jnp.zeros(labels.shape, dtype=labels.dtype), - labels - ) - logits_log_s = jax.nn.log_softmax( - logits, -1 - ) - per_token_log_probs = jnp.take_along_axis( - logits_log_s, - axis=2, - indices=labels[:, :, None] - ).reshape(batch, seq_len) - - if average_log_prob: - log_prob = jnp.sum((per_token_log_probs * loss_mask), axis=-1) / jnp.sum(loss_mask, axis=-1) - else: - log_prob = jnp.sum((per_token_log_probs * loss_mask), axis=-1) - - return log_prob - - -def concatenated_inputs( - batch: Dict[str, Union[List, chex.Array]], - is_encoder_decoder: bool = False, - label_pad_token_id: int = -100, - padding_value: int = 0, - truncation_mode: typing.Literal["keep_end", "keep_start"] = "keep_end", - fixed_max_length: int | None = None -) -> Dict[str, chex.Array]: - """ - The concatenated_inputs function takes a batch of chosen and rejected examples, - and concatenates them together. This is useful for training the model to predict whether an example was chosen - by the human annotator. The function also pads all inputs to - the same length as the longest input in that batch. - - :param batch: Dict[str,Union[List,chex.Array]]: Pass the batch of data into the function, - Allow for the batch to be a list of arrays or just an array, - Specify the type of data that is being passed in - - :param is_encoder_decoder: bool: Determine whether the model is an encoder-decoder model - :param label_pad_token_id: int: Pad the labels with a value of -100 - :param padding_value: int: Pad the input_ids and attention_mask arrays to the same length - :param truncation_mode: typing.Literal["keep_end", "keep_start"]: is left padded or not should it keep start of the - array or the end of the array?. - - :param fixed_max_length : int|None: by providing fixed_max_length the func will always return a fixed sequence - length and won't use dynamic methods. - - :return: A dictionary of the concatenated inputs - """ - concatenated_batch = {} - if fixed_max_length is None: - if is_encoder_decoder: - max_length = max(batch["chosen_labels"].shape[-1], batch["rejected_labels"].shape[-1]) - else: - max_length = max(batch["chosen_input_ids"].shape[-1], batch["rejected_input_ids"].shape[-1]) - else: - max_length = fixed_max_length - for k in batch: - if k.startswith("chosen") and isinstance(batch[k], jax.Array): - if "labels" in k or is_encoder_decoder: - pad_value = label_pad_token_id - elif k.endswith("_input_ids"): - pad_value = padding_value - elif k.endswith("_attention_mask"): - pad_value = 0 - else: - raise KeyError("couldn't find pad_value [Dataset Issue]") - concatenated_key = k.replace("chosen", "concatenated") - concatenated_batch[concatenated_key] = pad_to_length(batch[k], max_length, pad_value=pad_value) - for k in batch: - if k.startswith("rejected") and isinstance(batch[k], jax.Array): - if "labels" in k or is_encoder_decoder: - pad_value = label_pad_token_id - elif k.endswith("_input_ids"): - assert padding_value is not None, "`padding_value` can not be set as `None`" - pad_value = padding_value - elif k.endswith("_attention_mask"): - pad_value = 0 - else: - raise KeyError("couldn't find pad_value [Dataset Issue]") - concatenated_key = k.replace("rejected", "concatenated") - v2d = lambda ar: ar.reshape(ar.shape[0], -1) - concatenated_batch[concatenated_key] = jnp.concatenate( - ( - v2d(concatenated_batch[concatenated_key]), - pad_to_length(v2d(batch[k]), max_length, pad_value=pad_value), - ), - axis=0, - ) - for k in list(concatenated_batch.keys()): - val = concatenated_batch[k] - if val.ndim == 3: - # making 3d array 2d - concatenated_batch[k] = val.reshape(val.shape[0], -1) - if is_encoder_decoder: - concatenated_batch["concatenated_input_ids"] = batch["prompt_input_ids"].repeat(2, 1) - concatenated_batch["concatenated_attention_mask"] = ( - batch["prompt_attention_mask"].repeat(2, 1) - ) - - return concatenated_batch - - -def create_dpo_train_function( - concatenated_forward: Callable, - ref_state: EasyDeLState = None, - beta: float = 0.1, - label_smoothing: float = 0, - loss_type: Literal["sigmoid", "hinge", "ipo", "kto"] = "sigmoid", - reference_free: bool = False, -): - """ - The create_dpo_train_function function is a helper function that creates the DPO training step. - - :param concatenated_forward: Callable: Define the forward pass of the model - :param ref_state: EasyDeLState: Specify the reference policy - :param beta: float: Scale the logits - :param label_smoothing: float: Smooth the labels - :param loss_type: Literal["sigmoid", "hinge", "ipo", "kto"]: Determine the loss function - :param reference_free: bool: Indicate whether the reference policy is used or not - :return: A function that takes in a state and a batch - """ - - def _sigmoid_dpo_loss( - logits: chex.Array, - policy_chosen_log_probs: chex.Array = None, # IGNORED - reference_chosen_log_probs: chex.Array = None, # IGNORED - policy_rejected_log_probs: chex.Array = None, # IGNORED - reference_rejected_log_probs: chex.Array = None # IGNORED - ): - - """ - The _sigmoid_dpo_loss function is a helper function for the sigmoid_dpo_loss - function. It computes the loss of each example in a batch, given its logits - and (optionally) its chosen/rejected log probabilities under both policies. - - :param logits: chex.Array: Compute the loss - :param policy_chosen_log_probs: chex.Array: Calculate the policy loss - :param reference_chosen_log_probs: chex.Array: Compute the loss for the reference policy # IGNORED - :param policy_rejected_log_probs: chex.Array: Calculate the loss for the rejected samples # IGNORED - :param reference_rejected_log_probs: chex.Array: Calculate the loss of rejected samples # IGNORED - :return: an array represent loss - """ - losses = ( - -jax.nn.log_sigmoid(beta * logits) * (1 - label_smoothing) - - jax.nn.log_sigmoid(-beta * logits) * label_smoothing - ) - return losses - - def _hinge_dpo_loss( - logits: chex.Array, - policy_chosen_log_probs: chex.Array, # IGNORED - reference_chosen_log_probs: chex.Array, # IGNORED - policy_rejected_log_probs: chex.Array, # IGNORED - reference_rejected_log_probs: chex.Array # IGNORED - ): - - """ - The _hinge_dpo_loss function is a helper function that computes the loss for DPO. - - :param logits: chex.Array: Calculate the hinge loss - :param policy_chosen_log_probs: chex.Array: Compute the policy loss - :param reference_chosen_log_probs: chex.Array: Compute the loss for the reference policy # IGNORED - :param policy_rejected_log_probs: chex.Array: Calculate the loss for the rejected samples # IGNORED - :param reference_rejected_log_probs: chex.Array: Calculate the loss of rejected samples # IGNORED - :return: an array represent The hinge loss - """ - return jax.relu(1 - beta * logits) - - def _ipo_dpo_loss( - logits: chex.Array, - policy_chosen_log_probs: chex.Array, # IGNORED - reference_chosen_log_probs: chex.Array, # IGNORED - policy_rejected_log_probs: chex.Array, # IGNORED - reference_rejected_log_probs: chex.Array # IGNORED - ): - """ - The _ipo_dpo_loss function is a helper function that calculates the loss for - the IPO-DPO algorithm. It takes in the logits, policy_chosen_log_probs, - reference_chosen_log_probs, policy rejected log probs and reference rejected - log probs as inputs. The output of this function is used to calculate the loss - for each batch of data. - - :param logits: chex.Array: Calculate the loss - :param policy_chosen_log_probs: chex.Array: Compute the - :param reference_chosen_log_probs: chex.Array: Compute the loss for the reference policy # IGNORED - :param policy_rejected_log_probs: chex.Array: Calculate the loss for the rejected samples # IGNORED - :param reference_rejected_log_probs: chex.Array: Calculate the loss of rejected samples # IGNORED - :return: an array represent loss - """ - return (logits - 1 / (2 * beta)) ** 2 - - def _kto_pair_dpo_loss( - logits: chex.Array, # IGNORED - policy_chosen_log_probs: chex.Array, - reference_chosen_log_probs: chex.Array, - policy_rejected_log_probs: chex.Array, - reference_rejected_log_probs: chex.Array - ): - - """ - The _kto_pair_dpo_loss function is a helper function that computes the loss for - a single pair of trajectories. It takes in two sets of log probabilities, one from - the policy and one from the reference distribution. The first set are the log - probabilities for actions taken by each agent in a trajectory, while the second set - are those for actions not taken by each agent (i.e., rejected). The function then - computes KL divergences between these two sets of distributions and uses them to compute losses. - - :param logits: chex.Array: Calculate the log_probs - :param policy_chosen_log_probs: chex.Array: Calculate the chosen_kl # IGNORED - :param reference_chosen_log_probs: chex.Array: Calculate the chosen_kl - :param policy_rejected_log_probs: chex.Array: Calculate the rejected_kl variable - :param reference_rejected_log_probs: chex.Array: Calculate the rejected_kl variable - :return: an array represent loss - """ - chosen_kl = jax.lax.clamp( - min=0, - x=jnp.mean(policy_chosen_log_probs - reference_chosen_log_probs), - max=1e9 - ) - rejected_kl = jax.lax.clamp( - min=0, - x=jnp.mean(policy_rejected_log_probs - reference_rejected_log_probs), - max=1e9 - ) - - chosen_log_ratios = policy_chosen_log_probs - reference_chosen_log_probs - rejected_log_ratios = policy_rejected_log_probs - reference_rejected_log_probs - losses = jnp.concatenate( - ( - 1 - jax.nn.sigmoid(beta * (chosen_log_ratios - rejected_kl)), - 1 - jax.nn.sigmoid(beta * (chosen_kl - rejected_log_ratios)), - ), - 0, - ) - - return losses - - if loss_type == "sigmoid": - _loss_func = _sigmoid_dpo_loss - elif loss_type == "hinge": - _loss_func = _hinge_dpo_loss - elif loss_type == "ipo": - _loss_func = _ipo_dpo_loss - elif loss_type == "kto_pair": - _loss_func = _kto_pair_dpo_loss - else: - raise ValueError(f"UnKnown loss_type {loss_type}") - - def dpo_step( - state: EasyDeLState, - batch: dict - ) -> tuple[EasyDeLState, DPOStepOut]: - - """ - The dpo_step function is the core of DPO. It takes a state and a batch, - and returns an updated state. The update is done by calculating the loss - for each example in the batch, then taking its gradient with respect to - the parameters of the policy network (which are stored in `state`). This - gradient is then used to update `state`. - - :param state: EasyDeLState: Store the parameters of the model - :param batch: dict: Pass the data to the model - :return: A new state, which is a collection of the parameters and apply_fn - """ - - def calculate_loss(params: dict | flax.core.FrozenDict): - ( - policy_chosen_log_probs, - policy_rejected_log_probs, - policy_chosen_logits, - policy_rejected_logits, - ) = concatenated_forward( - state.apply_fn, - params, - batch - ) - - if "reference_chosen_log_probs" in batch and "reference_rejected_log_probs" in batch: - reference_chosen_log_probs = batch["reference_chosen_log_probs"] - reference_rejected_log_probs = batch["reference_rejected_log_probs"] - else: - if ref_state is None: - ( - reference_chosen_log_probs, - reference_rejected_log_probs, - _, - _, - ) = concatenated_forward( - state.apply_fn, - state.params, - batch - ) - else: - ( - reference_chosen_log_probs, - reference_rejected_log_probs, - _, - _, - ) = concatenated_forward( - ref_state.apply_fn, - ref_state.params, - batch - ) - - pi_log_ratios = policy_chosen_log_probs - policy_rejected_log_probs - - if reference_free: - ref_log_ratios = 0 - else: - ref_log_ratios = reference_chosen_log_probs - reference_rejected_log_probs - - logits = pi_log_ratios - ref_log_ratios - losses = _loss_func( - logits, - policy_chosen_log_probs, - reference_chosen_log_probs, - policy_rejected_log_probs, - reference_rejected_log_probs - ) - chosen_rewards = ( - beta - * ( - policy_chosen_log_probs - reference_chosen_log_probs - ) - ) - rejected_rewards = ( - beta - * ( - policy_rejected_log_probs - - reference_rejected_log_probs - ) - ) - return losses[0], (chosen_rewards, rejected_rewards) - - grad_fn = jax.value_and_grad(calculate_loss, has_aux=True) - (__loss, (__chosen_rewards, __rejected_rewards)), grads = grad_fn(state.params) - new_state = state.apply_gradients(grads=grads) - return new_state, DPOStepOut( - loss=__loss, - rejected_rewards=__rejected_rewards, - chosen_rewards=__chosen_rewards - ) - - return dpo_step - - -def create_dpo_eval_function( - concatenated_forward: Callable, - ref_state: EasyDeLState = None, - beta: float = 0.1, - label_smoothing: float = 0, - loss_type: Literal["sigmoid", "hinge", "ipo", "kto"] = "sigmoid", - reference_free: bool = False, -): - """ - The create_dpo_eval_function function is a helper function that creates the DPO evaluating step. - - :param concatenated_forward: Callable: Define the forward pass of the model - :param ref_state: EasyDeLState: Specify the reference policy - :param beta: float: Scale the logits - :param label_smoothing: float: Smooth the labels - :param loss_type: Literal["sigmoid", "hinge", "ipo", "kto"]: Determine the loss function - :param reference_free: bool: Indicate whether the reference policy is used or not - :return: A function that takes in a state and a batch - """ - - def _sigmoid_dpo_loss( - logits: chex.Array, - policy_chosen_log_probs: chex.Array = None, # IGNORED - reference_chosen_log_probs: chex.Array = None, # IGNORED - policy_rejected_log_probs: chex.Array = None, # IGNORED - reference_rejected_log_probs: chex.Array = None # IGNORED - ): - - """ - The _sigmoid_dpo_loss function is a helper function for the sigmoid_dpo_loss - function. It computes the loss of each example in a batch, given its logits - and (optionally) its chosen/rejected log probabilities under both policies. - - :param logits: chex.Array: Compute the loss - :param policy_chosen_log_probs: chex.Array: Calculate the policy loss - :param reference_chosen_log_probs: chex.Array: Compute the loss for the reference policy # IGNORED - :param policy_rejected_log_probs: chex.Array: Calculate the loss for the rejected samples # IGNORED - :param reference_rejected_log_probs: chex.Array: Calculate the loss of rejected samples # IGNORED - :return: an array represent loss - """ - losses = ( - -jax.nn.log_sigmoid(beta * logits) * (1 - label_smoothing) - - jax.nn.log_sigmoid(-beta * logits) * label_smoothing - ) - return losses - - def _hinge_dpo_loss( - logits: chex.Array, - policy_chosen_log_probs: chex.Array, # IGNORED - reference_chosen_log_probs: chex.Array, # IGNORED - policy_rejected_log_probs: chex.Array, # IGNORED - reference_rejected_log_probs: chex.Array # IGNORED - ): - - """ - The _hinge_dpo_loss function is a helper function that computes the loss for DPO. - - :param logits: chex.Array: Calculate the hinge loss - :param policy_chosen_log_probs: chex.Array: Compute the policy loss - :param reference_chosen_log_probs: chex.Array: Compute the loss for the reference policy # IGNORED - :param policy_rejected_log_probs: chex.Array: Calculate the loss for the rejected samples # IGNORED - :param reference_rejected_log_probs: chex.Array: Calculate the loss of rejected samples # IGNORED - :return: an array represent The hinge loss - """ - return jax.relu(1 - beta * logits) - - def _ipo_dpo_loss( - logits: chex.Array, - policy_chosen_log_probs: chex.Array, # IGNORED - reference_chosen_log_probs: chex.Array, # IGNORED - policy_rejected_log_probs: chex.Array, # IGNORED - reference_rejected_log_probs: chex.Array # IGNORED - ): - """ - The _ipo_dpo_loss function is a helper function that calculates the loss for - the IPO-DPO algorithm. It takes in the logits, policy_chosen_log_probs, - reference_chosen_log_probs, policy rejected log probs and reference rejected - log probs as inputs. The output of this function is used to calculate the loss - for each batch of data. - - :param logits: chex.Array: Calculate the loss - :param policy_chosen_log_probs: chex.Array: Compute the - :param reference_chosen_log_probs: chex.Array: Compute the loss for the reference policy # IGNORED - :param policy_rejected_log_probs: chex.Array: Calculate the loss for the rejected samples # IGNORED - :param reference_rejected_log_probs: chex.Array: Calculate the loss of rejected samples # IGNORED - :return: an array represent loss - """ - return (logits - 1 / (2 * beta)) ** 2 - - def _kto_pair_dpo_loss( - logits: chex.Array, # IGNORED - policy_chosen_log_probs: chex.Array, - reference_chosen_log_probs: chex.Array, - policy_rejected_log_probs: chex.Array, - reference_rejected_log_probs: chex.Array - ): - - """ - The _kto_pair_dpo_loss function is a helper function that computes the loss for - a single pair of trajectories. It takes in two sets of log probabilities, one from - the policy and one from the reference distribution. The first set are the log - probabilities for actions taken by each agent in a trajectory, while the second set - are those for actions not taken by each agent (i.e., rejected). The function then - computes KL divergences between these two sets of distributions and uses them to compute losses. - - :param logits: chex.Array: Calculate the log_probs - :param policy_chosen_log_probs: chex.Array: Calculate the chosen_kl # IGNORED - :param reference_chosen_log_probs: chex.Array: Calculate the chosen_kl - :param policy_rejected_log_probs: chex.Array: Calculate the rejected_kl variable - :param reference_rejected_log_probs: chex.Array: Calculate the rejected_kl variable - :return: an array represent loss - """ - chosen_kl = jax.lax.clamp( - min=0, - x=jnp.mean(policy_chosen_log_probs - reference_chosen_log_probs), - max=1e9 - ) - rejected_kl = jax.lax.clamp( - min=0, - x=jnp.mean(policy_rejected_log_probs - reference_rejected_log_probs), - max=1e9 - ) - - chosen_log_ratios = policy_chosen_log_probs - reference_chosen_log_probs - rejected_log_ratios = policy_rejected_log_probs - reference_rejected_log_probs - losses = jnp.concatenate( - ( - 1 - jax.nn.sigmoid(beta * (chosen_log_ratios - rejected_kl)), - 1 - jax.nn.sigmoid(beta * (chosen_kl - rejected_log_ratios)), - ), - 0, - ) - - return losses - - if loss_type == "sigmoid": - _loss_func = _sigmoid_dpo_loss - elif loss_type == "hinge": - _loss_func = _hinge_dpo_loss - elif loss_type == "ipo": - _loss_func = _ipo_dpo_loss - elif loss_type == "kto_pair": - _loss_func = _kto_pair_dpo_loss - else: - raise ValueError(f"UnKnown loss_type {loss_type}") - - def dpo_step( - state: EasyDeLState, - batch: dict - ) -> DPOStepOut: - - """ - The dpo_step function is the core of DPO. It takes a state and a batch, - and returns an updated state. The update is done by calculating the loss - for each example in the batch, then taking its gradient with respect to - the parameters of the policy network (which are stored in `state`). This - gradient is then used to update `state`. - - :param state: EasyDeLState: Store the parameters of the model - :param batch: dict: Pass the data to the model - :return: A `DPOStepOut` class - """ - - def calculate_loss(params: dict | flax.core.FrozenDict): - ( - policy_chosen_log_probs, - policy_rejected_log_probs, - policy_chosen_logits, - policy_rejected_logits, - ) = concatenated_forward( - state.apply_fn, - params, - batch - ) - - if "reference_chosen_log_probs" in batch and "reference_rejected_log_probs" in batch: - reference_chosen_log_probs = batch["reference_chosen_log_probs"] - reference_rejected_log_probs = batch["reference_rejected_log_probs"] - else: - if ref_state is None: - ( - reference_chosen_log_probs, - reference_rejected_log_probs, - _, - _, - ) = concatenated_forward( - state.apply_fn, - state.params, - batch - ) - else: - ( - reference_chosen_log_probs, - reference_rejected_log_probs, - _, - _, - ) = concatenated_forward( - ref_state.apply_fn, - ref_state.params, - batch - ) - - pi_log_ratios = policy_chosen_log_probs - policy_rejected_log_probs - - if reference_free: - ref_log_ratios = 0 - else: - ref_log_ratios = reference_chosen_log_probs - reference_rejected_log_probs - - logits = pi_log_ratios - ref_log_ratios - losses = _loss_func( - logits, - policy_chosen_log_probs, - reference_chosen_log_probs, - policy_rejected_log_probs, - reference_rejected_log_probs - ) - chosen_rewards = ( - beta - * ( - policy_chosen_log_probs - reference_chosen_log_probs - ) - ) - rejected_rewards = ( - beta - * ( - policy_rejected_log_probs - - reference_rejected_log_probs - ) - ) - return losses[0], (chosen_rewards, rejected_rewards) - - __loss, (__chosen_rewards, __rejected_rewards) = calculate_loss(state.params) - - return DPOStepOut( - loss=__loss, - rejected_rewards=__rejected_rewards, - chosen_rewards=__chosen_rewards - ) - - return dpo_step +import typing +import chex +import flax.core +import jax + +from typing import Literal, Dict, Union, Tuple, List, Callable + +from jax import numpy as jnp +from ...etils import EasyDeLState +from flax.struct import dataclass +from .utils import pad_to_length + + +@dataclass +class DPOStepOut: + loss: chex.Array + chosen_rewards: chex.Array + rejected_rewards: chex.Array + + +def create_concatenated_forward( + is_encoder_decoder, + label_pad_token_id, + padding_value, + truncation_mode: typing.Literal["keep_end", "keep_start"] = "keep_end", + fixed_max_length: int | None = None +): + """The create_concatenated_forward function is a helper function that creates a forward pass function for the + model. The forward pass function takes in an apply_fn, which is the model's apply_fn, and runs it on concatenated + inputs. It returns chosen log probs, rejected log probs, chosen logits and rejected logits. + + Args: + is_encoder_decoder: Determine whether the model is an encoder- + decoder model or not + label_pad_token_id: Pad the labels to the same length + padding_value: Pad the inputs to the same length + truncation_mode: typing.Literal["keep_end","keep_start"]: where + to pad and where to keep. + fixed_max_length: int|None: by providing fixed_max_length the + func will always return a fixed sequence length + and won't use dynamic methods. + + Returns: + A function that takes in a apply_fn, params and a batch of + inputs, + """ + + def concatenated_forward( + apply_fn: Callable, + params: dict | flax.core.FrozenDict, + batch: Dict[str, Union[List, chex.Array]] + + ) -> Tuple[chex.Array, chex.Array, chex.Array, chex.Array]: + """The concatenated_forward function is used to compute the log-probabilities of both chosen and rejected labels. + + Args: + apply_fn: Callable: Pass in the model function + params: dict | flax.core.FrozenDict: Pass the model + parameters to the function + batch: Dict[str, Union[List, chex.Array]] : Pass the batch + of data to the concatenated_forward function + + Returns: + The log_probs of the chosen and rejected labels, as well as + their corresponding logits + """ + assert padding_value is not None, "`padding_value` can not be set as `None` it must be an integer." + concatenated_batch = concatenated_inputs( + batch, + is_encoder_decoder=is_encoder_decoder, + label_pad_token_id=label_pad_token_id, + padding_value=padding_value, + truncation_mode=truncation_mode, + fixed_max_length=fixed_max_length + ) + len_chosen = batch["chosen_labels"].shape[0] + concatenated_batch["concatenated_input_ids"] = concatenated_batch["concatenated_input_ids"].reshape( + concatenated_batch["concatenated_input_ids"].shape[0], -1 + ) + concatenated_batch["concatenated_labels"] = concatenated_batch["concatenated_labels"].reshape( + concatenated_batch["concatenated_labels"].shape[0], -1 + ) + concatenated_batch["concatenated_attention_mask"] = concatenated_batch["concatenated_attention_mask"].reshape( + concatenated_batch["concatenated_attention_mask"].shape[0], -1 + ) + model_kwargs = ( + { + "labels": concatenated_batch["concatenated_labels"], + "decoder_input_ids": concatenated_batch.pop("concatenated_decoder_input_ids", None), + } + if is_encoder_decoder + else {} + ) + all_logits = apply_fn( + concatenated_batch["concatenated_input_ids"], + attention_mask=concatenated_batch["concatenated_attention_mask"], + params=params, + **model_kwargs, + ).logits + + all_log_probs = get_batch_log_probs( + all_logits, + concatenated_batch["concatenated_labels"], + average_log_prob=False, + is_encoder_decoder=is_encoder_decoder, + label_pad_token_id=label_pad_token_id, + ) + + chosen_log_probs = all_log_probs[:len_chosen] + rejected_log_probs = all_log_probs[len_chosen:] + + chosen_logits = all_logits[:len_chosen] + rejected_logits = all_logits[len_chosen:] + + return chosen_log_probs, rejected_log_probs, chosen_logits, rejected_logits + + return concatenated_forward + + +def get_batch_log_probs( + logits: chex.Array, + labels: chex.Array, + average_log_prob: bool = False, + label_pad_token_id: int = -100, + is_encoder_decoder: bool = False, +) -> chex.Array: + """The get_batch_log_probs function computes the log probability of a batch of sequences. + + Args: + logits: chex.Array: Compute the log_softmax of the input + labels: chex.Array: Mask the logits + average_log_prob: bool: Determine whether to average the log + prob over the sequence length + label_pad_token_id: int: Mask out the padding tokens in the + labels + is_encoder_decoder: bool: Indicate whether the model is an + encoder-decoder model + :param : Determine whether to average the log probability over all tokens or not + + Returns: + The log probability of the labels given the logits + """ + + # sudo code + # (per_token_log_probs * loss_mask).sum(-1) + # or + # (per_token_log_probs * loss_mask).sum(-1) / loss_mask.sum(-1) + + if logits.shape[:-1] != labels.shape: + raise ValueError("Logits (batch and sequence length dim) and labels must have the same shape.") + + if not is_encoder_decoder: + labels = labels[:, 1:] + logits = logits[:, :-1, :] + + batch, seq_len, dim = logits.shape + loss_mask = labels != label_pad_token_id + labels = jax.lax.select( + labels == label_pad_token_id, + jnp.zeros(labels.shape, dtype=labels.dtype), + labels + ) + logits_log_s = jax.nn.log_softmax( + logits, -1 + ) + per_token_log_probs = jnp.take_along_axis( + logits_log_s, + axis=2, + indices=labels[:, :, None] + ).reshape(batch, seq_len) + + if average_log_prob: + log_prob = jnp.sum((per_token_log_probs * loss_mask), axis=-1) / jnp.sum(loss_mask, axis=-1) + else: + log_prob = jnp.sum((per_token_log_probs * loss_mask), axis=-1) + + return log_prob + + +def concatenated_inputs( + batch: Dict[str, Union[List, chex.Array]], + is_encoder_decoder: bool = False, + label_pad_token_id: int = -100, + padding_value: int = 0, + truncation_mode: typing.Literal["keep_end", "keep_start"] = "keep_end", + fixed_max_length: int | None = None +) -> Dict[str, chex.Array]: + """The concatenated_inputs function takes a batch of chosen and rejected examples, + and concatenates them together. This is useful for training the model to predict whether an example was chosen + by the human annotator. The function also pads all inputs to + the same length as the longest input in that batch. + + Args: + batch: Dict[str,Union[List,chex.Array]]: Pass the batch of data + into the function, + is_encoder_decoder: bool: Determine whether the model is an + encoder-decoder model + label_pad_token_id: int: Pad the labels with a value of -100 + padding_value: int: Pad the input_ids and attention_mask arrays + to the same length + truncation_mode: typing.Literal["keep_end", "keep_start"]: is + left padded or not should it keep start of the + fixed_max_length: int|None: by providing fixed_max_length the + func will always return a fixed sequence length and won't + use dynamic methods. + Allow for the batch to be a list of arrays or just an array, + Specify the type of data that is being passed in + + array or the end of the array?. + + Returns: + A dictionary of the concatenated inputs + """ + concatenated_batch = {} + if fixed_max_length is None: + if is_encoder_decoder: + max_length = max(batch["chosen_labels"].shape[-1], batch["rejected_labels"].shape[-1]) + else: + max_length = max(batch["chosen_input_ids"].shape[-1], batch["rejected_input_ids"].shape[-1]) + else: + max_length = fixed_max_length + for k in batch: + if k.startswith("chosen") and isinstance(batch[k], jax.Array): + if "labels" in k or is_encoder_decoder: + pad_value = label_pad_token_id + elif k.endswith("_input_ids"): + pad_value = padding_value + elif k.endswith("_attention_mask"): + pad_value = 0 + else: + raise KeyError("couldn't find pad_value [Dataset Issue]") + concatenated_key = k.replace("chosen", "concatenated") + concatenated_batch[concatenated_key] = pad_to_length(batch[k], max_length, pad_value=pad_value) + for k in batch: + if k.startswith("rejected") and isinstance(batch[k], jax.Array): + if "labels" in k or is_encoder_decoder: + pad_value = label_pad_token_id + elif k.endswith("_input_ids"): + assert padding_value is not None, "`padding_value` can not be set as `None`" + pad_value = padding_value + elif k.endswith("_attention_mask"): + pad_value = 0 + else: + raise KeyError("couldn't find pad_value [Dataset Issue]") + concatenated_key = k.replace("rejected", "concatenated") + v2d = lambda ar: ar.reshape(ar.shape[0], -1) + concatenated_batch[concatenated_key] = jnp.concatenate( + ( + v2d(concatenated_batch[concatenated_key]), + pad_to_length(v2d(batch[k]), max_length, pad_value=pad_value), + ), + axis=0, + ) + for k in list(concatenated_batch.keys()): + val = concatenated_batch[k] + if val.ndim == 3: + # making 3d array 2d + concatenated_batch[k] = val.reshape(val.shape[0], -1) + if is_encoder_decoder: + concatenated_batch["concatenated_input_ids"] = batch["prompt_input_ids"].repeat(2, 1) + concatenated_batch["concatenated_attention_mask"] = ( + batch["prompt_attention_mask"].repeat(2, 1) + ) + + return concatenated_batch + + +def create_dpo_train_function( + concatenated_forward: Callable, + ref_state: EasyDeLState = None, + beta: float = 0.1, + label_smoothing: float = 0, + loss_type: Literal["sigmoid", "hinge", "ipo", "kto"] = "sigmoid", + reference_free: bool = False, +): + """The create_dpo_train_function function is a helper function that creates the DPO training step. + + Args: + concatenated_forward: Callable: Define the forward pass of the + model + ref_state: EasyDeLState: Specify the reference policy + beta: float: Scale the logits + label_smoothing: float: Smooth the labels + loss_type: Literal["sigmoid", "hinge", "ipo", "kto"]: Determine + the loss function + reference_free: bool: Indicate whether the reference policy is + used or not + + Returns: + A function that takes in a state and a batch + """ + + def _sigmoid_dpo_loss( + logits: chex.Array, + policy_chosen_log_probs: chex.Array = None, # IGNORED + reference_chosen_log_probs: chex.Array = None, # IGNORED + policy_rejected_log_probs: chex.Array = None, # IGNORED + reference_rejected_log_probs: chex.Array = None # IGNORED + ): + + """The _sigmoid_dpo_loss function is a helper function for the sigmoid_dpo_loss + function. It computes the loss of each example in a batch, given its logits + and (optionally) its chosen/rejected log probabilities under both policies. + + Args: + logits: chex.Array: Compute the loss + policy_chosen_log_probs: chex.Array: Calculate the policy + loss + reference_chosen_log_probs: chex.Array: Compute the loss for + the reference policy # IGNORED + policy_rejected_log_probs: chex.Array: Calculate the loss + for the rejected samples # IGNORED + reference_rejected_log_probs: chex.Array: Calculate the loss + of rejected samples # IGNORED + + Returns: + an array represent loss + """ + losses = ( + -jax.nn.log_sigmoid(beta * logits) * (1 - label_smoothing) + - jax.nn.log_sigmoid(-beta * logits) * label_smoothing + ) + return losses + + def _hinge_dpo_loss( + logits: chex.Array, + policy_chosen_log_probs: chex.Array, # IGNORED + reference_chosen_log_probs: chex.Array, # IGNORED + policy_rejected_log_probs: chex.Array, # IGNORED + reference_rejected_log_probs: chex.Array # IGNORED + ): + + """The _hinge_dpo_loss function is a helper function that computes the loss for DPO. + + Args: + logits: chex.Array: Calculate the hinge loss + policy_chosen_log_probs: chex.Array: Compute the policy loss + reference_chosen_log_probs: chex.Array: Compute the loss for + the reference policy # IGNORED + policy_rejected_log_probs: chex.Array: Calculate the loss + for the rejected samples # IGNORED + reference_rejected_log_probs: chex.Array: Calculate the loss + of rejected samples # IGNORED + + Returns: + an array represent The hinge loss + """ + return jax.relu(1 - beta * logits) + + def _ipo_dpo_loss( + logits: chex.Array, + policy_chosen_log_probs: chex.Array, # IGNORED + reference_chosen_log_probs: chex.Array, # IGNORED + policy_rejected_log_probs: chex.Array, # IGNORED + reference_rejected_log_probs: chex.Array # IGNORED + ): + """The _ipo_dpo_loss function is a helper function that calculates the loss for + the IPO-DPO algorithm. It takes in the logits, policy_chosen_log_probs, + reference_chosen_log_probs, policy rejected log probs and reference rejected + log probs as inputs. The output of this function is used to calculate the loss + for each batch of data. + + :param logits: chex.Array: Calculate the loss + :param policy_chosen_log_probs: chex.Array: Compute the + :param reference_chosen_log_probs: chex.Array: Compute the loss for the reference policy # IGNORED + :param policy_rejected_log_probs: chex.Array: Calculate the loss for the rejected samples # IGNORED + :param reference_rejected_log_probs: chex.Array: Calculate the loss of rejected samples # IGNORED + :return: an array represent loss + """ + return (logits - 1 / (2 * beta)) ** 2 + + def _kto_pair_dpo_loss( + logits: chex.Array, # IGNORED + policy_chosen_log_probs: chex.Array, + reference_chosen_log_probs: chex.Array, + policy_rejected_log_probs: chex.Array, + reference_rejected_log_probs: chex.Array + ): + + """The _kto_pair_dpo_loss function is a helper function that computes the loss for + a single pair of trajectories. It takes in two sets of log probabilities, one from + the policy and one from the reference distribution. The first set are the log + probabilities for actions taken by each agent in a trajectory, while the second set + are those for actions not taken by each agent (i.e., rejected). The function then + computes KL divergences between these two sets of distributions and uses them to compute losses. + + Args: + logits: chex.Array: Calculate the log_probs + policy_chosen_log_probs: chex.Array: Calculate the chosen_kl + # IGNORED + reference_chosen_log_probs: chex.Array: Calculate the + chosen_kl + policy_rejected_log_probs: chex.Array: Calculate the + rejected_kl variable + reference_rejected_log_probs: chex.Array: Calculate the + rejected_kl variable + + Returns: + an array represent loss + """ + chosen_kl = jax.lax.clamp( + min=0, + x=jnp.mean(policy_chosen_log_probs - reference_chosen_log_probs), + max=1e9 + ) + rejected_kl = jax.lax.clamp( + min=0, + x=jnp.mean(policy_rejected_log_probs - reference_rejected_log_probs), + max=1e9 + ) + + chosen_log_ratios = policy_chosen_log_probs - reference_chosen_log_probs + rejected_log_ratios = policy_rejected_log_probs - reference_rejected_log_probs + losses = jnp.concatenate( + ( + 1 - jax.nn.sigmoid(beta * (chosen_log_ratios - rejected_kl)), + 1 - jax.nn.sigmoid(beta * (chosen_kl - rejected_log_ratios)), + ), + 0, + ) + + return losses + + if loss_type == "sigmoid": + _loss_func = _sigmoid_dpo_loss + elif loss_type == "hinge": + _loss_func = _hinge_dpo_loss + elif loss_type == "ipo": + _loss_func = _ipo_dpo_loss + elif loss_type == "kto_pair": + _loss_func = _kto_pair_dpo_loss + else: + raise ValueError(f"UnKnown loss_type {loss_type}") + + def dpo_step( + state: EasyDeLState, + batch: dict + ) -> tuple[EasyDeLState, DPOStepOut]: + + """The dpo_step function is the core of DPO. It takes a state and a batch, + and returns an updated state. The update is done by calculating the loss + for each example in the batch, then taking its gradient with respect to + the parameters of the policy network (which are stored in `state`). This + gradient is then used to update `state`. + + Args: + state: EasyDeLState: Store the parameters of the model + batch: dict: Pass the data to the model + + Returns: + A new state, which is a collection of the parameters and + apply_fn + """ + + def calculate_loss(params: dict | flax.core.FrozenDict): + ( + policy_chosen_log_probs, + policy_rejected_log_probs, + policy_chosen_logits, + policy_rejected_logits, + ) = concatenated_forward( + state.apply_fn, + params, + batch + ) + + if "reference_chosen_log_probs" in batch and "reference_rejected_log_probs" in batch: + reference_chosen_log_probs = batch["reference_chosen_log_probs"] + reference_rejected_log_probs = batch["reference_rejected_log_probs"] + else: + if ref_state is None: + ( + reference_chosen_log_probs, + reference_rejected_log_probs, + _, + _, + ) = concatenated_forward( + state.apply_fn, + state.params, + batch + ) + else: + ( + reference_chosen_log_probs, + reference_rejected_log_probs, + _, + _, + ) = concatenated_forward( + ref_state.apply_fn, + ref_state.params, + batch + ) + + pi_log_ratios = policy_chosen_log_probs - policy_rejected_log_probs + + if reference_free: + ref_log_ratios = 0 + else: + ref_log_ratios = reference_chosen_log_probs - reference_rejected_log_probs + + logits = pi_log_ratios - ref_log_ratios + losses = _loss_func( + logits, + policy_chosen_log_probs, + reference_chosen_log_probs, + policy_rejected_log_probs, + reference_rejected_log_probs + ) + chosen_rewards = ( + beta + * ( + policy_chosen_log_probs - reference_chosen_log_probs + ) + ) + rejected_rewards = ( + beta + * ( + policy_rejected_log_probs + - reference_rejected_log_probs + ) + ) + return losses[0], (chosen_rewards, rejected_rewards) + + grad_fn = jax.value_and_grad(calculate_loss, has_aux=True) + (__loss, (__chosen_rewards, __rejected_rewards)), grads = grad_fn(state.params) + new_state = state.apply_gradients(grads=grads) + return new_state, DPOStepOut( + loss=__loss, + rejected_rewards=__rejected_rewards, + chosen_rewards=__chosen_rewards + ) + + return dpo_step + + +def create_dpo_eval_function( + concatenated_forward: Callable, + ref_state: EasyDeLState = None, + beta: float = 0.1, + label_smoothing: float = 0, + loss_type: Literal["sigmoid", "hinge", "ipo", "kto"] = "sigmoid", + reference_free: bool = False, +): + """The create_dpo_eval_function function is a helper function that creates the DPO evaluating step. + + Args: + concatenated_forward: Callable: Define the forward pass of the + model + ref_state: EasyDeLState: Specify the reference policy + beta: float: Scale the logits + label_smoothing: float: Smooth the labels + loss_type: Literal["sigmoid", "hinge", "ipo", "kto"]: Determine + the loss function + reference_free: bool: Indicate whether the reference policy is + used or not + + Returns: + A function that takes in a state and a batch + """ + + def _sigmoid_dpo_loss( + logits: chex.Array, + policy_chosen_log_probs: chex.Array = None, # IGNORED + reference_chosen_log_probs: chex.Array = None, # IGNORED + policy_rejected_log_probs: chex.Array = None, # IGNORED + reference_rejected_log_probs: chex.Array = None # IGNORED + ): + + """The _sigmoid_dpo_loss function is a helper function for the sigmoid_dpo_loss + function. It computes the loss of each example in a batch, given its logits + and (optionally) its chosen/rejected log probabilities under both policies. + + Args: + logits: chex.Array: Compute the loss + policy_chosen_log_probs: chex.Array: Calculate the policy + loss + reference_chosen_log_probs: chex.Array: Compute the loss for + the reference policy # IGNORED + policy_rejected_log_probs: chex.Array: Calculate the loss + for the rejected samples # IGNORED + reference_rejected_log_probs: chex.Array: Calculate the loss + of rejected samples # IGNORED + + Returns: + an array represent loss + """ + losses = ( + -jax.nn.log_sigmoid(beta * logits) * (1 - label_smoothing) + - jax.nn.log_sigmoid(-beta * logits) * label_smoothing + ) + return losses + + def _hinge_dpo_loss( + logits: chex.Array, + policy_chosen_log_probs: chex.Array, # IGNORED + reference_chosen_log_probs: chex.Array, # IGNORED + policy_rejected_log_probs: chex.Array, # IGNORED + reference_rejected_log_probs: chex.Array # IGNORED + ): + + """The _hinge_dpo_loss function is a helper function that computes the loss for DPO. + + Args: + logits: chex.Array: Calculate the hinge loss + policy_chosen_log_probs: chex.Array: Compute the policy loss + reference_chosen_log_probs: chex.Array: Compute the loss for + the reference policy # IGNORED + policy_rejected_log_probs: chex.Array: Calculate the loss + for the rejected samples # IGNORED + reference_rejected_log_probs: chex.Array: Calculate the loss + of rejected samples # IGNORED + + Returns: + an array represent The hinge loss + """ + return jax.relu(1 - beta * logits) + + def _ipo_dpo_loss( + logits: chex.Array, + policy_chosen_log_probs: chex.Array, # IGNORED + reference_chosen_log_probs: chex.Array, # IGNORED + policy_rejected_log_probs: chex.Array, # IGNORED + reference_rejected_log_probs: chex.Array # IGNORED + ): + """The _ipo_dpo_loss function is a helper function that calculates the loss for + the IPO-DPO algorithm. It takes in the logits, policy_chosen_log_probs, + reference_chosen_log_probs, policy rejected log probs and reference rejected + log probs as inputs. The output of this function is used to calculate the loss + for each batch of data. + + :param logits: chex.Array: Calculate the loss + :param policy_chosen_log_probs: chex.Array: Compute the + :param reference_chosen_log_probs: chex.Array: Compute the loss for the reference policy # IGNORED + :param policy_rejected_log_probs: chex.Array: Calculate the loss for the rejected samples # IGNORED + :param reference_rejected_log_probs: chex.Array: Calculate the loss of rejected samples # IGNORED + :return: an array represent loss + """ + return (logits - 1 / (2 * beta)) ** 2 + + def _kto_pair_dpo_loss( + logits: chex.Array, # IGNORED + policy_chosen_log_probs: chex.Array, + reference_chosen_log_probs: chex.Array, + policy_rejected_log_probs: chex.Array, + reference_rejected_log_probs: chex.Array + ): + + """The _kto_pair_dpo_loss function is a helper function that computes the loss for + a single pair of trajectories. It takes in two sets of log probabilities, one from + the policy and one from the reference distribution. The first set are the log + probabilities for actions taken by each agent in a trajectory, while the second set + are those for actions not taken by each agent (i.e., rejected). The function then + computes KL divergences between these two sets of distributions and uses them to compute losses. + + Args: + logits: chex.Array: Calculate the log_probs + policy_chosen_log_probs: chex.Array: Calculate the chosen_kl + # IGNORED + reference_chosen_log_probs: chex.Array: Calculate the + chosen_kl + policy_rejected_log_probs: chex.Array: Calculate the + rejected_kl variable + reference_rejected_log_probs: chex.Array: Calculate the + rejected_kl variable + + Returns: + an array represent loss + """ + chosen_kl = jax.lax.clamp( + min=0, + x=jnp.mean(policy_chosen_log_probs - reference_chosen_log_probs), + max=1e9 + ) + rejected_kl = jax.lax.clamp( + min=0, + x=jnp.mean(policy_rejected_log_probs - reference_rejected_log_probs), + max=1e9 + ) + + chosen_log_ratios = policy_chosen_log_probs - reference_chosen_log_probs + rejected_log_ratios = policy_rejected_log_probs - reference_rejected_log_probs + losses = jnp.concatenate( + ( + 1 - jax.nn.sigmoid(beta * (chosen_log_ratios - rejected_kl)), + 1 - jax.nn.sigmoid(beta * (chosen_kl - rejected_log_ratios)), + ), + 0, + ) + + return losses + + if loss_type == "sigmoid": + _loss_func = _sigmoid_dpo_loss + elif loss_type == "hinge": + _loss_func = _hinge_dpo_loss + elif loss_type == "ipo": + _loss_func = _ipo_dpo_loss + elif loss_type == "kto_pair": + _loss_func = _kto_pair_dpo_loss + else: + raise ValueError(f"UnKnown loss_type {loss_type}") + + def dpo_step( + state: EasyDeLState, + batch: dict + ) -> DPOStepOut: + + """The dpo_step function is the core of DPO. It takes a state and a batch, + and returns an updated state. The update is done by calculating the loss + for each example in the batch, then taking its gradient with respect to + the parameters of the policy network (which are stored in `state`). This + gradient is then used to update `state`. + + Args: + state: EasyDeLState: Store the parameters of the model + batch: dict: Pass the data to the model + + Returns: + A `DPOStepOut` class + """ + + def calculate_loss(params: dict | flax.core.FrozenDict): + ( + policy_chosen_log_probs, + policy_rejected_log_probs, + policy_chosen_logits, + policy_rejected_logits, + ) = concatenated_forward( + state.apply_fn, + params, + batch + ) + + if "reference_chosen_log_probs" in batch and "reference_rejected_log_probs" in batch: + reference_chosen_log_probs = batch["reference_chosen_log_probs"] + reference_rejected_log_probs = batch["reference_rejected_log_probs"] + else: + if ref_state is None: + ( + reference_chosen_log_probs, + reference_rejected_log_probs, + _, + _, + ) = concatenated_forward( + state.apply_fn, + state.params, + batch + ) + else: + ( + reference_chosen_log_probs, + reference_rejected_log_probs, + _, + _, + ) = concatenated_forward( + ref_state.apply_fn, + ref_state.params, + batch + ) + + pi_log_ratios = policy_chosen_log_probs - policy_rejected_log_probs + + if reference_free: + ref_log_ratios = 0 + else: + ref_log_ratios = reference_chosen_log_probs - reference_rejected_log_probs + + logits = pi_log_ratios - ref_log_ratios + losses = _loss_func( + logits, + policy_chosen_log_probs, + reference_chosen_log_probs, + policy_rejected_log_probs, + reference_rejected_log_probs + ) + chosen_rewards = ( + beta + * ( + policy_chosen_log_probs - reference_chosen_log_probs + ) + ) + rejected_rewards = ( + beta + * ( + policy_rejected_log_probs + - reference_rejected_log_probs + ) + ) + return losses[0], (chosen_rewards, rejected_rewards) + + __loss, (__chosen_rewards, __rejected_rewards) = calculate_loss(state.params) + + return DPOStepOut( + loss=__loss, + rejected_rewards=__rejected_rewards, + chosen_rewards=__chosen_rewards + ) + + return dpo_step diff --git a/src/python/easydel/trainer/dpo/modelling_output.py b/src/python/easydel/trainer/dpo/modelling_output.py index 5e412678f..e23ddefc7 100644 --- a/src/python/easydel/trainer/dpo/modelling_output.py +++ b/src/python/easydel/trainer/dpo/modelling_output.py @@ -1,15 +1,15 @@ -from dataclasses import dataclass -import jax -from typing import Any, Optional, Callable, Mapping -from ...etils.easystate import EasyDeLState - - -@dataclass -class DPOTrainerOutput: - state: EasyDeLState - mesh: Optional[jax.sharding.Mesh] - checkpoint_manager: Any - gather_fns: Optional[Any | Mapping[str, Callable] | dict[Callable]] = None - shard_fns: Optional[Any | Mapping[str, Callable] | dict[Callable]] = None - last_save_file_name: Optional[str] = None - checkpoint_path: Optional[str] = None +from dataclasses import dataclass +import jax +from typing import Any, Optional, Callable, Mapping +from ...etils.easystate import EasyDeLState + + +@dataclass +class DPOTrainerOutput: + state: EasyDeLState + mesh: Optional[jax.sharding.Mesh] + checkpoint_manager: Any + gather_fns: Optional[Any | Mapping[str, Callable] | dict[Callable]] = None + shard_fns: Optional[Any | Mapping[str, Callable] | dict[Callable]] = None + last_save_file_name: Optional[str] = None + checkpoint_path: Optional[str] = None diff --git a/src/python/easydel/trainer/dpo/utils.py b/src/python/easydel/trainer/dpo/utils.py index 796b88ad4..caef86adf 100644 --- a/src/python/easydel/trainer/dpo/utils.py +++ b/src/python/easydel/trainer/dpo/utils.py @@ -1,150 +1,151 @@ -import chex -import jax - -from typing import Optional, Dict, Union, Any, List -from jax import numpy as jnp -from dataclasses import dataclass -from contextlib import contextmanager - - -def pad_to_length(tensor: chex.Array, length: int, pad_value: Union[int, float], axis: int = -1) -> chex.Array: - if tensor.shape[axis] >= length: - if tensor.ndim == 2: - tensor = tensor[:, :length] - return tensor - else: - pad_size = list(tensor.shape) - pad_size[axis] = length - tensor.shape[axis] - return jax.numpy.concatenate( - [ - tensor, - pad_value * jax.numpy.ones(pad_size, dtype=tensor.dtype), - ], - axis=axis, - ) - - -@dataclass -class DPODataCollatorWithPadding: - r""" - DPO DataCollator class that pads the tokenized inputs to the maximum length of the batch. - - :param pad_token_id: int: The tokenizers pad_token_id. - :param label_pad_token_id: int: The label used for masking. - :param is_encoder_decoder: Optional[bool]: Whether you model has an encoder_decoder architecture - """ - max_prompt_length: int - max_target_length: int - pad_token_id: int = 0 - label_pad_token_id: int = -100 - is_encoder_decoder: Optional[bool] = False - ids_to_pop_from_dataset: Optional[dict] = None - auto_fix_data: bool = True - - def __call__(self, features: List[Dict[str, Any]]) -> Dict[str, Any]: - padded_batch = {} - for k in features[0].keys(): - if k.endswith("_input_ids") or k.endswith("_attention_mask") or k.endswith("_labels"): - if self.is_encoder_decoder: - to_pad = [jnp.array(ex[k], dtype="i4") for ex in features] - - if (k.startswith("prompt")) and (k.endswith("input_ids")): - padding_value = self.pad_token_id - elif k.endswith("_attention_mask"): - padding_value = 0 - elif (k.startswith("chosen")) or (k.startswith("rejected")) or ("decoder" in k): - padding_value = self.label_pad_token_id - else: - raise ValueError(f"Unexpected key in batch '{k}'") - padded_batch[k] = pad_sequence(to_pad, batch_first=True, padding_value=padding_value).astype("i4") - else: - if "prompt" in k: - to_pad = [jnp.array(ex[k][::-1], dtype="i4") for ex in features] - else: - to_pad = [jnp.array(ex[k], dtype="i4") for ex in features] - if k.endswith("_input_ids"): - padding_value = self.pad_token_id - elif k.endswith("_labels"): - padding_value = self.label_pad_token_id - elif k.endswith("_attention_mask"): - padding_value = 0 - else: - raise ValueError(f"Unexpected key in batch '{k}'") - padded_batch[k] = pad_sequence(to_pad, batch_first=True, padding_value=padding_value).astype("i4") - if "prompt" in k: - padded_batch[k] = jnp.flip(padded_batch[k], axis=[1]) - elif k.endswith("_logps"): - padded_batch[k] = jnp.array([ex[k] for ex in features]) - else: - padded_batch[k] = [ex[k] for ex in features] - if self.ids_to_pop_from_dataset: - for key in self.ids_to_pop_from_dataset: - _ = padded_batch.pop(key, None) - for key in list(padded_batch.keys()): - if not ( - key.endswith("_input_ids") - or key.endswith("_attention_mask") - or key.endswith("_labels") - or key.endswith("_log_probs") - ): - _ = padded_batch.pop(key, None) - for k in list(padded_batch.keys()): - v = padded_batch[k] - padded_batch[k] = v.reshape(v.shape[0], -1) - if self.auto_fix_data: - padded_batch["rejected_input_ids"] = padded_batch["rejected_input_ids"][..., :self.max_target_length] - padded_batch[ - "rejected_attention_mask" - ] = padded_batch["rejected_attention_mask"][..., :self.max_target_length] - padded_batch["rejected_labels"] = padded_batch["rejected_labels"][..., :self.max_target_length] - - padded_batch["chosen_input_ids"] = padded_batch["chosen_input_ids"][..., :self.max_target_length] - padded_batch["chosen_attention_mask"] = padded_batch["chosen_attention_mask"][..., :self.max_target_length] - padded_batch["chosen_labels"] = padded_batch["chosen_labels"][..., :self.max_target_length] - - padded_batch["prompt_input_ids"] = padded_batch["prompt_input_ids"][..., :self.max_prompt_length] - padded_batch[ - "prompt_attention_mask" - ] = padded_batch["prompt_attention_mask"][..., :self.max_prompt_length] - - return padded_batch - - -def pad_sequence( - sequences, - batch_first=False, - padding_value=0, - max_len: int | None = None -): - max_len = max(seq.shape[-1] for seq in sequences) if max_len is None else max_len - padding_value = jnp.array(padding_value).reshape(1) - if batch_first: - padded_seqs = [ - jnp.concatenate( - [ - seq.reshape(1, -1), - jnp.ones((1, max_len - seq.shape[-1])) * padding_value - ], - axis=1 - ) if seq.shape[-1] < max_len else seq.reshape(1, -1) - for seq in sequences - ] - else: - padded_seqs = [ - jnp.concatenate( - [ - jnp.ones((1, max_len - seq.shape[-1])) * padding_value, - seq.reshape(1, -1) - ], - axis=1 - ) if seq.shape[-1] < max_len else seq.reshape(1, -1) - for seq in sequences - ] - - return jnp.array(padded_seqs) - - -@contextmanager -def leave_alone_context_manager(): - # Perform setup actions (none in this case) - yield +import chex +import jax + +from typing import Optional, Dict, Union, Any, List +from jax import numpy as jnp +from dataclasses import dataclass +from contextlib import contextmanager + + +def pad_to_length(tensor: chex.Array, length: int, pad_value: Union[int, float], axis: int = -1) -> chex.Array: + if tensor.shape[axis] >= length: + if tensor.ndim == 2: + tensor = tensor[:, :length] + return tensor + else: + pad_size = list(tensor.shape) + pad_size[axis] = length - tensor.shape[axis] + return jax.numpy.concatenate( + [ + tensor, + pad_value * jax.numpy.ones(pad_size, dtype=tensor.dtype), + ], + axis=axis, + ) + + +@dataclass +class DPODataCollatorWithPadding: + r"""DPO DataCollator class that pads the tokenized inputs to the maximum length of the batch. + + Args: + pad_token_id: int: The tokenizers pad_token_id. + label_pad_token_id: int: The label used for masking. + is_encoder_decoder: Optional[bool]: Whether you model has an + encoder_decoder architecture + """ + max_prompt_length: int + max_target_length: int + pad_token_id: int = 0 + label_pad_token_id: int = -100 + is_encoder_decoder: Optional[bool] = False + ids_to_pop_from_dataset: Optional[dict] = None + auto_fix_data: bool = True + + def __call__(self, features: List[Dict[str, Any]]) -> Dict[str, Any]: + padded_batch = {} + for k in features[0].keys(): + if k.endswith("_input_ids") or k.endswith("_attention_mask") or k.endswith("_labels"): + if self.is_encoder_decoder: + to_pad = [jnp.array(ex[k], dtype="i4") for ex in features] + + if (k.startswith("prompt")) and (k.endswith("input_ids")): + padding_value = self.pad_token_id + elif k.endswith("_attention_mask"): + padding_value = 0 + elif (k.startswith("chosen")) or (k.startswith("rejected")) or ("decoder" in k): + padding_value = self.label_pad_token_id + else: + raise ValueError(f"Unexpected key in batch '{k}'") + padded_batch[k] = pad_sequence(to_pad, batch_first=True, padding_value=padding_value).astype("i4") + else: + if "prompt" in k: + to_pad = [jnp.array(ex[k][::-1], dtype="i4") for ex in features] + else: + to_pad = [jnp.array(ex[k], dtype="i4") for ex in features] + if k.endswith("_input_ids"): + padding_value = self.pad_token_id + elif k.endswith("_labels"): + padding_value = self.label_pad_token_id + elif k.endswith("_attention_mask"): + padding_value = 0 + else: + raise ValueError(f"Unexpected key in batch '{k}'") + padded_batch[k] = pad_sequence(to_pad, batch_first=True, padding_value=padding_value).astype("i4") + if "prompt" in k: + padded_batch[k] = jnp.flip(padded_batch[k], axis=[1]) + elif k.endswith("_logps"): + padded_batch[k] = jnp.array([ex[k] for ex in features]) + else: + padded_batch[k] = [ex[k] for ex in features] + if self.ids_to_pop_from_dataset: + for key in self.ids_to_pop_from_dataset: + _ = padded_batch.pop(key, None) + for key in list(padded_batch.keys()): + if not ( + key.endswith("_input_ids") + or key.endswith("_attention_mask") + or key.endswith("_labels") + or key.endswith("_log_probs") + ): + _ = padded_batch.pop(key, None) + for k in list(padded_batch.keys()): + v = padded_batch[k] + padded_batch[k] = v.reshape(v.shape[0], -1) + if self.auto_fix_data: + padded_batch["rejected_input_ids"] = padded_batch["rejected_input_ids"][..., :self.max_target_length] + padded_batch[ + "rejected_attention_mask" + ] = padded_batch["rejected_attention_mask"][..., :self.max_target_length] + padded_batch["rejected_labels"] = padded_batch["rejected_labels"][..., :self.max_target_length] + + padded_batch["chosen_input_ids"] = padded_batch["chosen_input_ids"][..., :self.max_target_length] + padded_batch["chosen_attention_mask"] = padded_batch["chosen_attention_mask"][..., :self.max_target_length] + padded_batch["chosen_labels"] = padded_batch["chosen_labels"][..., :self.max_target_length] + + padded_batch["prompt_input_ids"] = padded_batch["prompt_input_ids"][..., :self.max_prompt_length] + padded_batch[ + "prompt_attention_mask" + ] = padded_batch["prompt_attention_mask"][..., :self.max_prompt_length] + + return padded_batch + + +def pad_sequence( + sequences, + batch_first=False, + padding_value=0, + max_len: int | None = None +): + max_len = max(seq.shape[-1] for seq in sequences) if max_len is None else max_len + padding_value = jnp.array(padding_value).reshape(1) + if batch_first: + padded_seqs = [ + jnp.concatenate( + [ + seq.reshape(1, -1), + jnp.ones((1, max_len - seq.shape[-1])) * padding_value + ], + axis=1 + ) if seq.shape[-1] < max_len else seq.reshape(1, -1) + for seq in sequences + ] + else: + padded_seqs = [ + jnp.concatenate( + [ + jnp.ones((1, max_len - seq.shape[-1])) * padding_value, + seq.reshape(1, -1) + ], + axis=1 + ) if seq.shape[-1] < max_len else seq.reshape(1, -1) + for seq in sequences + ] + + return jnp.array(padded_seqs) + + +@contextmanager +def leave_alone_context_manager(): + # Perform setup actions (none in this case) + yield diff --git a/src/python/easydel/trainer/orpo/__init__.py b/src/python/easydel/trainer/orpo/__init__.py index fee8387e6..2698d46fd 100644 --- a/src/python/easydel/trainer/orpo/__init__.py +++ b/src/python/easydel/trainer/orpo/__init__.py @@ -1,11 +1,11 @@ -from .fwd_bwd_functions import create_orpo_step_function, create_concatenated_forward, odds_ratio_loss -from .modelling_output import ORPOTrainerOutput -from .orpo_trainer import ORPOTrainer - -__all__ = ( - "create_orpo_step_function", - "create_concatenated_forward", - "odds_ratio_loss", - "ORPOTrainerOutput", - "ORPOTrainer" -) +from .fwd_bwd_functions import create_orpo_step_function, create_concatenated_forward, odds_ratio_loss +from .modelling_output import ORPOTrainerOutput +from .orpo_trainer import ORPOTrainer + +__all__ = ( + "create_orpo_step_function", + "create_concatenated_forward", + "odds_ratio_loss", + "ORPOTrainerOutput", + "ORPOTrainer" +) diff --git a/src/python/easydel/trainer/orpo/fwd_bwd_functions.py b/src/python/easydel/trainer/orpo/fwd_bwd_functions.py index 5d8a9b940..49751fce5 100644 --- a/src/python/easydel/trainer/orpo/fwd_bwd_functions.py +++ b/src/python/easydel/trainer/orpo/fwd_bwd_functions.py @@ -1,385 +1,412 @@ -import typing -import warnings - -import chex -import fjformer -import flax.core -import jax - -from typing import Literal, Dict, Union, Tuple, List, Callable - -from jax import numpy as jnp -from ...etils import EasyDeLState -from flax.struct import dataclass -from jax.sharding import PartitionSpec - - -def pad_to_length(tensor: chex.Array, length: int, pad_value: Union[int, float], axis: int = -1) -> chex.Array: - if tensor.shape[axis] >= length: - if tensor.ndim == 2: - tensor = tensor[:, :length] - return tensor - else: - pad_size = list(tensor.shape) - pad_size[axis] = length - tensor.shape[axis] - return jax.numpy.concatenate( - [ - tensor, - pad_value * jax.numpy.ones(pad_size, dtype=tensor.dtype), - ], - axis=axis, - ) - - -@dataclass -class ORPOStepOut: - loss: chex.Array - metrics: Dict[str, Union[chex.Array, str]] - - -def create_concatenated_forward( - is_encoder_decoder, - label_pad_token_id, - padding_value, - truncation_mode: typing.Literal["keep_end", "keep_start"] = "keep_end", - fixed_max_length: int | None = None -): - """ - The create_concatenated_forward function is a helper function that creates a forward pass function for the - model. The forward pass function takes in an apply_fn, which is the model's apply_fn, and runs it on concatenated - inputs. It returns chosen log probs, rejected log probs, chosen logits and rejected logits. - - :param is_encoder_decoder: Determine whether the model is an encoder-decoder model or not - :param label_pad_token_id: Pad the labels to the same length - :param padding_value: Pad the inputs to the same length - :param truncation_mode: typing.Literal["keep_end","keep_start"]: where to pad and where to keep. - :param fixed_max_length : int|None: by providing fixed_max_length the func will always return a fixed sequence length - and won't use dynamic methods. - :return: A function that takes in a apply_fn, params and a batch of inputs, - """ - - def concatenated_forward( - apply_fn: Callable, - params: dict | flax.core.FrozenDict, - batch: Dict[str, Union[List, chex.Array]] - - ) -> Tuple[chex.Array, chex.Array, chex.Array, chex.Array, chex.Array]: - """ - The concatenated_forward function is used to compute the log-probabilities of both chosen and rejected labels. - - :param apply_fn: Callable: Pass in the model function - :param params: dict | flax.core.FrozenDict: Pass the model parameters to the function - :param batch: Dict[str, Union[List, chex.Array]] : Pass the batch of data to the concatenated_forward function - :return: The log_probs of the chosen and rejected labels, as well as their corresponding logits - """ - assert padding_value is not None, "`padding_value` can not be set as `None` it must be an integer." - concatenated_batch = concatenated_inputs( - batch, - is_encoder_decoder=is_encoder_decoder, - label_pad_token_id=label_pad_token_id, - padding_value=padding_value, - truncation_mode=truncation_mode, - fixed_max_length=fixed_max_length - ) - len_chosen = batch["chosen_labels"].shape[0] - concatenated_batch["concatenated_input_ids"] = concatenated_batch["concatenated_input_ids"].reshape( - concatenated_batch["concatenated_input_ids"].shape[0], -1 - ) - concatenated_batch["concatenated_labels"] = concatenated_batch["concatenated_labels"].reshape( - concatenated_batch["concatenated_labels"].shape[0], -1 - ) - concatenated_batch["concatenated_attention_mask"] = concatenated_batch["concatenated_attention_mask"].reshape( - concatenated_batch["concatenated_attention_mask"].shape[0], -1 - ) - model_kwargs = ( - { - "labels": concatenated_batch["concatenated_labels"], - "decoder_input_ids": concatenated_batch.pop("concatenated_decoder_input_ids", None), - } - if is_encoder_decoder - else {} - ) - all_logits = apply_fn( - concatenated_batch["concatenated_input_ids"], - attention_mask=concatenated_batch["concatenated_attention_mask"], - params=params, - **model_kwargs, - ).logits - - def cross_entropy_loss(logits, labels, mask): - if not is_encoder_decoder: - logits = logits[..., :-1, :] - labels = labels[..., 1:] - mask = mask[..., 1:] - loss = fjformer.cross_entropy_loss_and_accuracy(logits, labels, mask)[0] - return loss - - if is_encoder_decoder: - labels = concatenated_batch["concatenated_labels"] - else: - labels = concatenated_batch["concatenated_input_ids"] - - chosen_nll_loss = cross_entropy_loss( - all_logits[:len_chosen], - labels[:len_chosen], - concatenated_batch["concatenated_attention_mask"][:len_chosen] - ) - all_log_probs = get_batch_log_probs( - all_logits, - concatenated_batch["concatenated_labels"], - average_log_prob=False, - is_encoder_decoder=is_encoder_decoder, - label_pad_token_id=label_pad_token_id, - ) - - chosen_log_probs = all_log_probs[:len_chosen] - rejected_log_probs = all_log_probs[len_chosen:] - - chosen_logits = all_logits[:len_chosen] - rejected_logits = all_logits[len_chosen:] - return chosen_log_probs, rejected_log_probs, chosen_logits, rejected_logits, chosen_nll_loss - - return concatenated_forward - - -def get_batch_log_probs( - logits: chex.Array, - labels: chex.Array, - average_log_prob: bool = False, - label_pad_token_id: int = -100, - is_encoder_decoder: bool = False, -) -> chex.Array: - """ - The get_batch_log_probs function computes the log probability of a batch of sequences. - - :param logits: chex.Array: Compute the log_softmax of the input - :param labels: chex.Array: Mask the logits - :param average_log_prob: bool: Determine whether to average the log prob over the sequence length - :param label_pad_token_id: int: Mask out the padding tokens in the labels - :param is_encoder_decoder: bool: Indicate whether the model is an encoder-decoder model - :param : Determine whether to average the log probability over all tokens or not - :return: The log probability of the labels given the logits - """ - - # sudo code - # (per_token_log_probs * loss_mask).sum(-1) - # or - # (per_token_log_probs * loss_mask).sum(-1) / loss_mask.sum(-1) - - if logits.shape[:-1] != labels.shape: - raise ValueError("Logits (batch and sequence length dim) and labels must have the same shape.") - - if not is_encoder_decoder: - labels = labels[:, 1:] - logits = logits[:, :-1, :] - - batch, seq_len, dim = logits.shape - loss_mask = labels != label_pad_token_id - - labels = jnp.where(labels == label_pad_token_id, 0, labels) - - per_token_logps = jnp.take_along_axis( - jax.nn.log_softmax(logits, axis=-1), axis=2, indices=labels[:, :, None] - ).reshape(batch, seq_len) - - if average_log_prob: - return (per_token_logps * loss_mask).sum(-1) / loss_mask.sum(-1) - else: - return (per_token_logps * loss_mask).sum(-1) - - -def concatenated_inputs( - batch: Dict[str, Union[List, chex.Array]], - is_encoder_decoder: bool = False, - label_pad_token_id: int = -100, - padding_value: int = 0, - truncation_mode: typing.Literal["keep_end", "keep_start"] = "keep_end", - fixed_max_length: int | None = None -) -> Dict[str, chex.Array]: - """ - The concatenated_inputs function takes a batch of chosen and rejected examples, - and concatenates them together. This is useful for training the model to predict whether an example was chosen - by the human annotator. The function also pads all inputs to - the same length as the longest input in that batch. - - :param batch: Dict[str,Union[List,chex.Array]]: Pass the batch of data into the function, - Allow for the batch to be a list of arrays or just an array, - Specify the type of data that is being passed in - - :param is_encoder_decoder: bool: Determine whether the model is an encoder-decoder model - :param label_pad_token_id: int: Pad the labels with a value of -100 - :param padding_value: int: Pad the input_ids and attention_mask arrays to the same length - :param truncation_mode: typing.Literal["keep_end", "keep_start"]: is left padded or not should it keep start of the - array or the end of the array?. - - :param fixed_max_length : int|None: by providing fixed_max_length the func will always return a fixed sequence - length and won't use dynamic methods. - - :return: A dictionary of the concatenated inputs - """ - concatenated_batch = {} - if fixed_max_length is None: - if is_encoder_decoder: - max_length = max(batch["chosen_labels"].shape[-1], batch["rejected_labels"].shape[-1]) - else: - max_length = max(batch["chosen_input_ids"].shape[-1], batch["rejected_input_ids"].shape[-1]) - else: - max_length = fixed_max_length - for k in batch: - if k.startswith("chosen") and isinstance(batch[k], jax.Array): - if "labels" in k or is_encoder_decoder: - pad_value = label_pad_token_id - elif k.endswith("_input_ids"): - pad_value = padding_value - elif k.endswith("_attention_mask"): - pad_value = 0 - else: - raise KeyError("couldn't find pad_value [Dataset Issue]") - concatenated_key = k.replace("chosen", "concatenated") - concatenated_batch[concatenated_key] = pad_to_length(batch[k], max_length, pad_value=pad_value) - for k in batch: - if k.startswith("rejected") and isinstance(batch[k], jax.Array): - if "labels" in k or is_encoder_decoder: - pad_value = label_pad_token_id - elif k.endswith("_input_ids"): - assert padding_value is not None, "`padding_value` can not be set as `None`" - pad_value = padding_value - elif k.endswith("_attention_mask"): - pad_value = 0 - else: - raise KeyError("couldn't find pad_value [Dataset Issue]") - concatenated_key = k.replace("rejected", "concatenated") - v2d = lambda ar: ar.reshape(ar.shape[0], -1) - concatenated_batch[concatenated_key] = jnp.concatenate( - ( - v2d(concatenated_batch[concatenated_key]), - pad_to_length(v2d(batch[k]), max_length, pad_value=pad_value), - ), - axis=0, - ) - for k in list(concatenated_batch.keys()): - val = concatenated_batch[k] - if val.ndim == 3: - # making 3d array 2d - concatenated_batch[k] = val.reshape(val.shape[0], -1) - if is_encoder_decoder: - warnings.warn("`concatenated_input_ids` will be repeated (encoder decoder model detected)") - concatenated_batch["concatenated_input_ids"] = batch["prompt_input_ids"].repeat(2, 1) - concatenated_batch["concatenated_attention_mask"] = ( - batch["prompt_attention_mask"].repeat(2, 1) - ) - - return concatenated_batch - - -def odds_ratio_loss( - beta: float, - policy_chosen_logps: chex.Array, - policy_rejected_logps: chex.Array, -) -> Tuple[chex.Array, chex.Array, chex.Array, chex.Array, chex.Array]: - log_odds = (policy_chosen_logps - policy_rejected_logps) - ( - jnp.log1p(-jnp.exp(policy_chosen_logps)) - jnp.log1p(-jnp.exp(policy_rejected_logps)) - ) - sig_ratio = jax.nn.sigmoid(log_odds) - ratio = jnp.log(sig_ratio) - losses = beta * ratio - # jax.debug.print("policy_chosen_logps : {x}", x=policy_chosen_logps) - # jax.debug.print("policy_rejected_logps : {x}", x=policy_rejected_logps) - # jax.debug.print("sig_ratio : {x}", x=sig_ratio) - # jax.debug.print("ratio : {x}", x=ratio) - # jax.debug.print("log_odds : {x}", x=log_odds) - # jax.debug.print("losses : {x}", x=losses) - - chosen_rewards = beta * jax.lax.stop_gradient(policy_chosen_logps) - rejected_rewards = beta * jax.lax.stop_gradient(policy_rejected_logps) - - return ( - losses, - chosen_rewards, - rejected_rewards, - jnp.mean(ratio), - jnp.mean(log_odds) - ) - - -def create_orpo_step_function( - concatenated_forward: Callable, - beta: float = 0.1, - mode: Literal["train", "eval"] = "train", - batch_partition_spec: PartitionSpec = PartitionSpec(("fsdp", "dp"), "sp") -): - """ - The create_orpo_step_function function is a helper function that creates the ORPO training step. - - :param concatenated_forward: Callable: Define the forward pass of the model - :param beta: float: Scale the logits - :param mode: Literal["train", "eval"] : "train", "eval" function modes - :param batch_partition_spec: PartitionSpec: Batch PartitionSpec - :return: A function that takes in a state and a batch - """ - - def orpo_step( - state: EasyDeLState, - batch: dict - ) -> tuple[EasyDeLState, ORPOStepOut]: - """ - The orpo_step function is the core of ORPO. It takes a state and a batch, - and returns an updated state. The update is done by calculating the loss - for each example in the batch, then taking its gradient with respect to - the parameters of the policy network (which are stored in `state`). This - gradient is then used to update `state`. - - :param state: EasyDeLState: Store the parameters of the model - :param batch: dict: Pass the data to the model - :return: A new state, which is a collection of the parameters and apply_fn - """ - batch = fjformer.with_sharding_constraint(batch, partition_specs=batch_partition_spec) - - def calculate_loss(params: dict | flax.core.FrozenDict): - ( - policy_chosen_log_probs, - policy_rejected_log_probs, - policy_chosen_logits, - policy_rejected_logits, - policy_nll_loss - ) = concatenated_forward( - state.apply_fn, - params, - batch - ) - - losses, chosen_rewards, rejected_rewards, log_odds_ratio, log_odds_chosen = odds_ratio_loss( - beta, policy_chosen_log_probs, policy_rejected_log_probs - ) - - loss = policy_nll_loss - losses.mean() - - reward_accuracies = (chosen_rewards > rejected_rewards).astype("float32") - metrics = {} - prefix = "eval_" if mode == "eval" else "" - metrics[f"{prefix}rewards/chosen"] = chosen_rewards.mean() - metrics[f"{prefix}rewards/rejected"] = rejected_rewards.mean() - metrics[f"{prefix}rewards/accuracies"] = reward_accuracies.mean() - metrics[f"{prefix}rewards/margins"] = (chosen_rewards - rejected_rewards).mean() - metrics[f"{prefix}logps/rejected"] = policy_rejected_log_probs.mean() - metrics[f"{prefix}logps/chosen"] = policy_chosen_log_probs.mean() - metrics[f"{prefix}logits/rejected"] = policy_rejected_logits.mean() - metrics[f"{prefix}logits/chosen"] = policy_chosen_logits.mean() - metrics[f"{prefix}nll_loss"] = policy_nll_loss.mean() - metrics[f"{prefix}log_odds_ratio"] = log_odds_ratio - metrics[f"{prefix}log_odds_chosen"] = log_odds_chosen - return loss, metrics - - if mode == "train": - grad_fn = jax.value_and_grad(calculate_loss, has_aux=True) - (__loss, (__metrics)), grads = grad_fn(state.params) - new_state = state.apply_gradients(grads=grads) - else: - __loss, __metrics = calculate_loss(state.params) - new_state = state - return new_state, ORPOStepOut( - loss=__loss, - metrics=__metrics - ) - - return orpo_step +import typing +import warnings + +import chex +import fjformer +import flax.core +import jax + +from typing import Literal, Dict, Union, Tuple, List, Callable + +from jax import numpy as jnp +from ...etils import EasyDeLState +from flax.struct import dataclass +from jax.sharding import PartitionSpec + + +def pad_to_length(tensor: chex.Array, length: int, pad_value: Union[int, float], axis: int = -1) -> chex.Array: + if tensor.shape[axis] >= length: + if tensor.ndim == 2: + tensor = tensor[:, :length] + return tensor + else: + pad_size = list(tensor.shape) + pad_size[axis] = length - tensor.shape[axis] + return jax.numpy.concatenate( + [ + tensor, + pad_value * jax.numpy.ones(pad_size, dtype=tensor.dtype), + ], + axis=axis, + ) + + +@dataclass +class ORPOStepOut: + loss: chex.Array + metrics: Dict[str, Union[chex.Array, str]] + + +def create_concatenated_forward( + is_encoder_decoder, + label_pad_token_id, + padding_value, + truncation_mode: typing.Literal["keep_end", "keep_start"] = "keep_end", + fixed_max_length: int | None = None +): + """The create_concatenated_forward function is a helper function that creates a forward pass function for the + model. The forward pass function takes in an apply_fn, which is the model's apply_fn, and runs it on concatenated + inputs. It returns chosen log probs, rejected log probs, chosen logits and rejected logits. + + Args: + is_encoder_decoder: Determine whether the model is an encoder- + decoder model or not + label_pad_token_id: Pad the labels to the same length + padding_value: Pad the inputs to the same length + truncation_mode: typing.Literal["keep_end","keep_start"]: where + to pad and where to keep. + fixed_max_length: int|None: by providing fixed_max_length the + func will always return a fixed sequence length + and won't use dynamic methods. + + Returns: + A function that takes in a apply_fn, params and a batch of + inputs, + """ + + def concatenated_forward( + apply_fn: Callable, + params: dict | flax.core.FrozenDict, + batch: Dict[str, Union[List, chex.Array]] + + ) -> Tuple[chex.Array, chex.Array, chex.Array, chex.Array, chex.Array]: + """The concatenated_forward function is used to compute the log-probabilities of both chosen and rejected labels. + + Args: + apply_fn: Callable: Pass in the model function + params: dict | flax.core.FrozenDict: Pass the model + parameters to the function + batch: Dict[str, Union[List, chex.Array]] : Pass the batch + of data to the concatenated_forward function + + Returns: + The log_probs of the chosen and rejected labels, as well as + their corresponding logits + """ + assert padding_value is not None, "`padding_value` can not be set as `None` it must be an integer." + concatenated_batch = concatenated_inputs( + batch, + is_encoder_decoder=is_encoder_decoder, + label_pad_token_id=label_pad_token_id, + padding_value=padding_value, + truncation_mode=truncation_mode, + fixed_max_length=fixed_max_length + ) + len_chosen = batch["chosen_labels"].shape[0] + concatenated_batch["concatenated_input_ids"] = concatenated_batch["concatenated_input_ids"].reshape( + concatenated_batch["concatenated_input_ids"].shape[0], -1 + ) + concatenated_batch["concatenated_labels"] = concatenated_batch["concatenated_labels"].reshape( + concatenated_batch["concatenated_labels"].shape[0], -1 + ) + concatenated_batch["concatenated_attention_mask"] = concatenated_batch["concatenated_attention_mask"].reshape( + concatenated_batch["concatenated_attention_mask"].shape[0], -1 + ) + model_kwargs = ( + { + "labels": concatenated_batch["concatenated_labels"], + "decoder_input_ids": concatenated_batch.pop("concatenated_decoder_input_ids", None), + } + if is_encoder_decoder + else {} + ) + all_logits = apply_fn( + concatenated_batch["concatenated_input_ids"], + attention_mask=concatenated_batch["concatenated_attention_mask"], + params=params, + **model_kwargs, + ).logits + + def cross_entropy_loss(logits, labels, mask): + if not is_encoder_decoder: + logits = logits[..., :-1, :] + labels = labels[..., 1:] + mask = mask[..., 1:] + loss = fjformer.cross_entropy_loss_and_accuracy(logits, labels, mask)[0] + return loss + + if is_encoder_decoder: + labels = concatenated_batch["concatenated_labels"] + else: + labels = concatenated_batch["concatenated_input_ids"] + + chosen_nll_loss = cross_entropy_loss( + all_logits[:len_chosen], + labels[:len_chosen], + concatenated_batch["concatenated_attention_mask"][:len_chosen] + ) + all_log_probs = get_batch_log_probs( + all_logits, + concatenated_batch["concatenated_labels"], + average_log_prob=False, + is_encoder_decoder=is_encoder_decoder, + label_pad_token_id=label_pad_token_id, + ) + + chosen_log_probs = all_log_probs[:len_chosen] + rejected_log_probs = all_log_probs[len_chosen:] + + chosen_logits = all_logits[:len_chosen] + rejected_logits = all_logits[len_chosen:] + return chosen_log_probs, rejected_log_probs, chosen_logits, rejected_logits, chosen_nll_loss + + return concatenated_forward + + +def get_batch_log_probs( + logits: chex.Array, + labels: chex.Array, + average_log_prob: bool = False, + label_pad_token_id: int = -100, + is_encoder_decoder: bool = False, +) -> chex.Array: + """The get_batch_log_probs function computes the log probability of a batch of sequences. + + Args: + logits: chex.Array: Compute the log_softmax of the input + labels: chex.Array: Mask the logits + average_log_prob: bool: Determine whether to average the log + prob over the sequence length + label_pad_token_id: int: Mask out the padding tokens in the + labels + is_encoder_decoder: bool: Indicate whether the model is an + encoder-decoder model + :param : Determine whether to average the log probability over all tokens or not + + Returns: + The log probability of the labels given the logits + """ + + # sudo code + # (per_token_log_probs * loss_mask).sum(-1) + # or + # (per_token_log_probs * loss_mask).sum(-1) / loss_mask.sum(-1) + + if logits.shape[:-1] != labels.shape: + raise ValueError("Logits (batch and sequence length dim) and labels must have the same shape.") + + if not is_encoder_decoder: + labels = labels[:, 1:] + logits = logits[:, :-1, :] + + batch, seq_len, dim = logits.shape + loss_mask = labels != label_pad_token_id + + labels = jnp.where(labels == label_pad_token_id, 0, labels) + + per_token_logps = jnp.take_along_axis( + jax.nn.log_softmax(logits, axis=-1), axis=2, indices=labels[:, :, None] + ).reshape(batch, seq_len) + + if average_log_prob: + return (per_token_logps * loss_mask).sum(-1) / loss_mask.sum(-1) + else: + return (per_token_logps * loss_mask).sum(-1) + + +def concatenated_inputs( + batch: Dict[str, Union[List, chex.Array]], + is_encoder_decoder: bool = False, + label_pad_token_id: int = -100, + padding_value: int = 0, + truncation_mode: typing.Literal["keep_end", "keep_start"] = "keep_end", + fixed_max_length: int | None = None +) -> Dict[str, chex.Array]: + """The concatenated_inputs function takes a batch of chosen and rejected examples, + and concatenates them together. This is useful for training the model to predict whether an example was chosen + by the human annotator. The function also pads all inputs to + the same length as the longest input in that batch. + + Args: + batch: Dict[str,Union[List,chex.Array]]: Pass the batch of data + into the function, + is_encoder_decoder: bool: Determine whether the model is an + encoder-decoder model + label_pad_token_id: int: Pad the labels with a value of -100 + padding_value: int: Pad the input_ids and attention_mask arrays + to the same length + truncation_mode: typing.Literal["keep_end", "keep_start"]: is + left padded or not should it keep start of the + fixed_max_length: int|None: by providing fixed_max_length the + func will always return a fixed sequence length and won't + use dynamic methods. + Allow for the batch to be a list of arrays or just an array, + Specify the type of data that is being passed in + + array or the end of the array?. + + Returns: + A dictionary of the concatenated inputs + """ + concatenated_batch = {} + if fixed_max_length is None: + if is_encoder_decoder: + max_length = max(batch["chosen_labels"].shape[-1], batch["rejected_labels"].shape[-1]) + else: + max_length = max(batch["chosen_input_ids"].shape[-1], batch["rejected_input_ids"].shape[-1]) + else: + max_length = fixed_max_length + for k in batch: + if k.startswith("chosen") and isinstance(batch[k], jax.Array): + if "labels" in k or is_encoder_decoder: + pad_value = label_pad_token_id + elif k.endswith("_input_ids"): + pad_value = padding_value + elif k.endswith("_attention_mask"): + pad_value = 0 + else: + raise KeyError("couldn't find pad_value [Dataset Issue]") + concatenated_key = k.replace("chosen", "concatenated") + concatenated_batch[concatenated_key] = pad_to_length(batch[k], max_length, pad_value=pad_value) + for k in batch: + if k.startswith("rejected") and isinstance(batch[k], jax.Array): + if "labels" in k or is_encoder_decoder: + pad_value = label_pad_token_id + elif k.endswith("_input_ids"): + assert padding_value is not None, "`padding_value` can not be set as `None`" + pad_value = padding_value + elif k.endswith("_attention_mask"): + pad_value = 0 + else: + raise KeyError("couldn't find pad_value [Dataset Issue]") + concatenated_key = k.replace("rejected", "concatenated") + v2d = lambda ar: ar.reshape(ar.shape[0], -1) + concatenated_batch[concatenated_key] = jnp.concatenate( + ( + v2d(concatenated_batch[concatenated_key]), + pad_to_length(v2d(batch[k]), max_length, pad_value=pad_value), + ), + axis=0, + ) + for k in list(concatenated_batch.keys()): + val = concatenated_batch[k] + if val.ndim == 3: + # making 3d array 2d + concatenated_batch[k] = val.reshape(val.shape[0], -1) + if is_encoder_decoder: + warnings.warn("`concatenated_input_ids` will be repeated (encoder decoder model detected)") + concatenated_batch["concatenated_input_ids"] = batch["prompt_input_ids"].repeat(2, 1) + concatenated_batch["concatenated_attention_mask"] = ( + batch["prompt_attention_mask"].repeat(2, 1) + ) + + return concatenated_batch + + +def odds_ratio_loss( + beta: float, + policy_chosen_logps: chex.Array, + policy_rejected_logps: chex.Array, +) -> Tuple[chex.Array, chex.Array, chex.Array, chex.Array, chex.Array]: + log_odds = (policy_chosen_logps - policy_rejected_logps) - ( + jnp.log1p(-jnp.exp(policy_chosen_logps)) - jnp.log1p(-jnp.exp(policy_rejected_logps)) + ) + sig_ratio = jax.nn.sigmoid(log_odds) + ratio = jnp.log(sig_ratio) + losses = beta * ratio + # jax.debug.print("policy_chosen_logps : {x}", x=policy_chosen_logps) + # jax.debug.print("policy_rejected_logps : {x}", x=policy_rejected_logps) + # jax.debug.print("sig_ratio : {x}", x=sig_ratio) + # jax.debug.print("ratio : {x}", x=ratio) + # jax.debug.print("log_odds : {x}", x=log_odds) + # jax.debug.print("losses : {x}", x=losses) + + chosen_rewards = beta * jax.lax.stop_gradient(policy_chosen_logps) + rejected_rewards = beta * jax.lax.stop_gradient(policy_rejected_logps) + + return ( + losses, + chosen_rewards, + rejected_rewards, + jnp.mean(ratio), + jnp.mean(log_odds) + ) + + +def create_orpo_step_function( + concatenated_forward: Callable, + beta: float = 0.1, + mode: Literal["train", "eval"] = "train", + batch_partition_spec: PartitionSpec = PartitionSpec(("fsdp", "dp"), "sp") +): + """The create_orpo_step_function function is a helper function that creates the ORPO training step. + + Args: + concatenated_forward: Callable: Define the forward pass of the + model + beta: float: Scale the logits + mode: Literal["train", "eval"] : "train", "eval" function modes + batch_partition_spec: PartitionSpec: Batch PartitionSpec + + Returns: + A function that takes in a state and a batch + """ + + def orpo_step( + state: EasyDeLState, + batch: dict + ) -> tuple[EasyDeLState, ORPOStepOut]: + """The orpo_step function is the core of ORPO. It takes a state and a batch, + and returns an updated state. The update is done by calculating the loss + for each example in the batch, then taking its gradient with respect to + the parameters of the policy network (which are stored in `state`). This + gradient is then used to update `state`. + + Args: + state: EasyDeLState: Store the parameters of the model + batch: dict: Pass the data to the model + + Returns: + A new state, which is a collection of the parameters and + apply_fn + """ + batch = fjformer.with_sharding_constraint(batch, partition_specs=batch_partition_spec) + + def calculate_loss(params: dict | flax.core.FrozenDict): + ( + policy_chosen_log_probs, + policy_rejected_log_probs, + policy_chosen_logits, + policy_rejected_logits, + policy_nll_loss + ) = concatenated_forward( + state.apply_fn, + params, + batch + ) + + losses, chosen_rewards, rejected_rewards, log_odds_ratio, log_odds_chosen = odds_ratio_loss( + beta, policy_chosen_log_probs, policy_rejected_log_probs + ) + + loss = policy_nll_loss - losses.mean() + + reward_accuracies = (chosen_rewards > rejected_rewards).astype("float32") + metrics = {} + prefix = "eval_" if mode == "eval" else "" + metrics[f"{prefix}rewards/chosen"] = chosen_rewards.mean() + metrics[f"{prefix}rewards/rejected"] = rejected_rewards.mean() + metrics[f"{prefix}rewards/accuracies"] = reward_accuracies.mean() + metrics[f"{prefix}rewards/margins"] = (chosen_rewards - rejected_rewards).mean() + metrics[f"{prefix}logps/rejected"] = policy_rejected_log_probs.mean() + metrics[f"{prefix}logps/chosen"] = policy_chosen_log_probs.mean() + metrics[f"{prefix}logits/rejected"] = policy_rejected_logits.mean() + metrics[f"{prefix}logits/chosen"] = policy_chosen_logits.mean() + metrics[f"{prefix}nll_loss"] = policy_nll_loss.mean() + metrics[f"{prefix}log_odds_ratio"] = log_odds_ratio + metrics[f"{prefix}log_odds_chosen"] = log_odds_chosen + return loss, metrics + + if mode == "train": + grad_fn = jax.value_and_grad(calculate_loss, has_aux=True) + (__loss, (__metrics)), grads = grad_fn(state.params) + new_state = state.apply_gradients(grads=grads) + else: + __loss, __metrics = calculate_loss(state.params) + new_state = state + return new_state, ORPOStepOut( + loss=__loss, + metrics=__metrics + ) + + return orpo_step diff --git a/src/python/easydel/trainer/orpo/modelling_output.py b/src/python/easydel/trainer/orpo/modelling_output.py index b74a2e563..703921f1c 100644 --- a/src/python/easydel/trainer/orpo/modelling_output.py +++ b/src/python/easydel/trainer/orpo/modelling_output.py @@ -1,15 +1,15 @@ -from dataclasses import dataclass -import jax -from typing import Any, Optional, Callable, Mapping -from ...etils.easystate import EasyDeLState - - -@dataclass -class ORPOTrainerOutput: - state: EasyDeLState - mesh: Optional[jax.sharding.Mesh] - checkpoint_manager: Any - gather_fns: Optional[Any | Mapping[str, Callable] | dict[Callable]] = None - shard_fns: Optional[Any | Mapping[str, Callable] | dict[Callable]] = None - last_save_file_name: Optional[str] = None - checkpoint_path: Optional[str] = None +from dataclasses import dataclass +import jax +from typing import Any, Optional, Callable, Mapping +from ...etils.easystate import EasyDeLState + + +@dataclass +class ORPOTrainerOutput: + state: EasyDeLState + mesh: Optional[jax.sharding.Mesh] + checkpoint_manager: Any + gather_fns: Optional[Any | Mapping[str, Callable] | dict[Callable]] = None + shard_fns: Optional[Any | Mapping[str, Callable] | dict[Callable]] = None + last_save_file_name: Optional[str] = None + checkpoint_path: Optional[str] = None diff --git a/src/python/easydel/trainer/orpo/orpo_trainer.py b/src/python/easydel/trainer/orpo/orpo_trainer.py index 5f950ef53..9370b739f 100644 --- a/src/python/easydel/trainer/orpo/orpo_trainer.py +++ b/src/python/easydel/trainer/orpo/orpo_trainer.py @@ -1,1219 +1,1219 @@ -import copy -import os -import sys -import time -import typing -import warnings -from abc import ABC -from collections import defaultdict -from glob import glob - -import flax.core -import jax -import tensorflow.data -import tensorflow_datasets -import termcolor -import wandb -from fjformer import match_partition_rules, make_shard_and_gather_fns -from flax.core import FrozenDict -from tqdm import tqdm - -from typing import ( - Optional, - Literal, - Dict, - Union, - Any, - Callable, - Mapping, - Tuple -) - -from jax.experimental.pjit import pjit -from datasets import Dataset -from jax import numpy as jnp - -from ...etils.etils import get_logger -from ..training_configurations import TrainArguments -from ..base_trainer import ( - BaseTrainer, - TrainerConfigureFunctionFuncOutput, - TrainerConfigureDataloaderFuncOutput, - TrainerConfigureModelFuncOutput -) -from ...etils import EasyDeLState, EasyDeLTimerError -from transformers import PreTrainedTokenizerBase -from jax.sharding import PartitionSpec - -from ...utils import Timers, prefix_print -from ..dpo.utils import ( - pad_to_length, - DPODataCollatorWithPadding, - leave_alone_context_manager -) -from .fwd_bwd_functions import ( - create_orpo_step_function, - create_concatenated_forward, -) -from .modelling_output import ORPOTrainerOutput - -logger = get_logger(__name__) - - -class ORPOTrainer(BaseTrainer, ABC): - """ - easydel ORPO Trainer Class - """ - - def __init__( - self, - arguments: TrainArguments, - max_length: Optional[int] = None, - max_prompt_length: Optional[int] = None, - max_completion_length: Optional[int] = None, - beta: float = 0.1, - disable_dropout: bool = True, - label_pad_token_id: int = -100, - is_encoder_decoder: bool = False, - padding_value: int = None, - data_collator: Optional[DPODataCollatorWithPadding] = None, - train_dataset: Optional[Dataset] = None, - eval_dataset: Optional[Union[Dataset, Dict[str, Dataset]]] = None, - tokenizer: Optional[PreTrainedTokenizerBase] = None, - _do_init_fns: bool = True, - dataset_map_arguments: Optional[Dict[str, Any]] = None, - low_mem_usage: bool = False, - ): - - """ - The __init__ function is called when the class is instantiated. - It sets up the attributes of an object. - - - :param self: Refer to the object itself - :param beta: float: Control the strength of the regularization term - :param arguments: TrainArguments: Pass the arguments to the trainer - :param label_pad_token_id: int: Pad the labels - :param padding_value: int: Specify the value that is used for padding - :param train_dataset: Optional[Dataset]: Load the training dataset - :param eval_dataset: Optional[Union[Dataset, Dict[str, Dataset]]] : Pass the evaluation dataset to the trainer - :param tokenizer: Optional[PreTrainedTokenizerBase]: Pass the tokenizer to the trainer - :param max_length: Optional[int]: Set the maximum length of the input sequence - :param max_prompt_length: Optional[int]: Set the maximum length of the prompt - :param max_completion_length: Optional[int]: Truncate the target sequence - :param data_collator: Optional[Callable]: Function to be used for creating datasets. - tokenizing process with `dataset.map`. - :param _do_init_fns: bool : preferred to set ture to trainer will automatically configure - model with provided training Arguments - :param : Set the padding value for the model - """ - - assert arguments is not None, ( - "You Have to pass arguments that will be used for training but you have passed" - "`arguments=None`" - ) - assert isinstance(arguments, TrainArguments), ( - f"arguments type must be `TrainArguments` but got {type(arguments)}" - ) - - if tokenizer is None: - raise ValueError("tokenizer must be specified to tokenize a ORPO dataset.") - if max_length is None: - warnings.warn( - "`max_length` is not set in the ORPOTrainer's init" - " it will default to `512` by default, but you should do it yourself in the future.", - UserWarning, - ) - max_length = 512 - if max_prompt_length is None: - warnings.warn( - "`max_prompt_length` is not set in the ORPOTrainer's init" - " it will default to `128` by default, but you should do it yourself in the future.", - UserWarning, - ) - max_prompt_length = 128 - - if max_completion_length is None: - warnings.warn( - "When using an encoder decoder architecture, you should set `max_completion_length` in the " - "ORPOTrainer's init it will default to `128` by default, but you should do it yourself in the future.", - UserWarning, - ) - max_completion_length = 128 - - padding_value = padding_value if padding_value is not None else tokenizer.pad_token_id - self.max_length = max_length - self.label_pad_token_id = label_pad_token_id - self.padding_value = padding_value - self.max_prompt_length = max_prompt_length - self.truncation_mode = arguments.truncation_mode - self.disable_dropout = disable_dropout - self.max_completion_length = max_completion_length - self.tokenizer = tokenizer - self.is_encoder_decoder = is_encoder_decoder - self.low_mem_usage = low_mem_usage - self.beta = beta - data_collator = DPODataCollatorWithPadding( - max_prompt_length=self.max_prompt_length, - max_target_length=self.max_completion_length, - pad_token_id=tokenizer.pad_token_id, - label_pad_token_id=label_pad_token_id, - is_encoder_decoder=False, - ) if data_collator is None else data_collator - self._stored_metrics = defaultdict(lambda: defaultdict(list)) - if dataset_map_arguments is None: - dataset_map_arguments = {} - train_dataset = train_dataset.map( - self.tokenize_row, - **dataset_map_arguments - ) - if eval_dataset is not None: - eval_dataset = eval_dataset.map( - self.tokenize_row, - **dataset_map_arguments - ) - - self.arguments = arguments - self.hp_name = None - self.deepspeed = None - self.is_in_train = False - - self.data_collator = data_collator - self.train_dataset = train_dataset - self.eval_dataset = eval_dataset - self.tokenizer = tokenizer - self._loggers_initialized = False - self.mesh = self.arguments.get_mesh() - assert padding_value is not None, "`padding_value` can not be set as `None` it must be an integer." - - self.concatenated_forward = create_concatenated_forward( - is_encoder_decoder=self.is_encoder_decoder, - padding_value=padding_value, - label_pad_token_id=label_pad_token_id, - ) - - self._cached_p_l_s = None - self._cached_c_l_s = None - self._cached_r_l_s = None - - super().__init__( - arguments=arguments, - dataset_train=train_dataset, - dataset_eval=eval_dataset, - finetune=True, - checkpoint_path=None, - _do_init_fns=_do_init_fns - ) - - def build_tokenized_answer(self, prompt, answer): - """ - Llama tokenizer does satisfy `enc(a + b) = enc(a) + enc(b)`. - It does ensure `enc(a + b) = enc(a) + enc(a + b)[len(enc(a)):]`. - """ - - full_tokenized = self.tokenizer(prompt + answer, add_special_tokens=False) - prompt_input_ids = self.tokenizer(prompt, add_special_tokens=False)["input_ids"] - - answer_input_ids = full_tokenized["input_ids"][len(prompt_input_ids):] - answer_attention_mask = full_tokenized["attention_mask"][len(prompt_input_ids):] - prompt_input_ids = jnp.asarray(prompt_input_ids, dtype="i4") - answer_input_ids = jnp.asarray(answer_input_ids, dtype="i4") - full_concat_input_ids = jnp.concatenate( - ( - prompt_input_ids, - answer_input_ids - ) - ) - - # Prepare input tokens for token by token comparison - full_input_ids = jnp.array(full_tokenized["input_ids"]) - - if len(full_input_ids) != len(full_concat_input_ids): - raise ValueError("Prompt input ids and answer input ids should have the same length.") - - response_token_ids_start_idx = len(prompt_input_ids) - if prompt_input_ids.tolist() != full_tokenized["input_ids"][:response_token_ids_start_idx]: - response_token_ids_start_idx -= 1 - - prompt_input_ids = full_tokenized["input_ids"][:response_token_ids_start_idx] - prompt_attention_mask = full_tokenized["attention_mask"][:response_token_ids_start_idx] - - if len(prompt_input_ids) != len(prompt_attention_mask): - raise ValueError("Prompt input ids and attention mask should have the same length.") - - answer_input_ids = full_tokenized["input_ids"][response_token_ids_start_idx:] - answer_attention_mask = full_tokenized["attention_mask"][response_token_ids_start_idx:] - - return dict( - prompt_input_ids=jnp.array(prompt_input_ids, dtype="i4"), - prompt_attention_mask=jnp.array(prompt_attention_mask, dtype="i4"), - input_ids=jnp.array(answer_input_ids, dtype="i4"), - attention_mask=jnp.array(answer_attention_mask, dtype="i4"), - ) - - def tokenize_row(self, feature, state: EasyDeLState = None) -> Dict: - - """ - The tokenize_row function is responsible for taking a single row of data and converting it into the format that - the model expects. This includes: - - Tokenizing the text (using HuggingFace's tokenizer) - - Padding/truncating sequences to a fixed length (if necessary) - - Creating attention masks, which tell the model which tokens are padding and which aren't. - - :param self: Represent the instance of the class - :param feature: Pass in the data from the dataset - :param state: EasyDeLState: Keep track of the state of the tokenizer - :return: A dictionary of the following keys - """ - batch = {} - prompt = feature["prompt"] - chosen = feature["chosen"] - rejected = feature["rejected"] - - if not isinstance(prompt, str): - raise ValueError(f"prompt should be an str but got {type(prompt)} , {prompt}") - prompt_tokens = self.tokenizer( - prompt, - add_special_tokens=False, - return_tensors="np", - ) - prompt_tokens = {f"prompt_{k}": v for k, v in prompt_tokens.items()} - - if not isinstance(chosen, str): - raise ValueError(f"chosen should be an str but got {type(chosen)} , {chosen}") - chosen_tokens = self.build_tokenized_answer(prompt, chosen) - - if not isinstance(rejected, str): - raise ValueError(f"rejected should be an str but got {type(rejected)}") - rejected_tokens = self.build_tokenized_answer(prompt, rejected) - v2d = lambda ar: ar.reshape(1, -1) if ar.ndim == 1 else ar - - def add_tkn(n, ar): - return jnp.concatenate( - ( - jnp.array(n).reshape(1, 1), - v2d(ar) - ), axis=-1 - ) - - def add_post_tkn(n, ar): - return jnp.concatenate( - ( - v2d(ar), - jnp.array(n).reshape(1, 1) - ), axis=-1 - ) - - prompt_tokens["prompt_input_ids"] = add_tkn( - self.tokenizer.bos_token_id, - prompt_tokens["prompt_input_ids"] - ) - chosen_tokens["prompt_input_ids"] = add_tkn( - self.tokenizer.bos_token_id, - chosen_tokens["prompt_input_ids"] - ) - rejected_tokens["prompt_input_ids"] = add_tkn( - self.tokenizer.bos_token_id, - rejected_tokens["prompt_input_ids"] - ) - - prompt_tokens["prompt_attention_mask"] = add_tkn( - 1, prompt_tokens["prompt_attention_mask"] - ) - chosen_tokens["prompt_attention_mask"] = add_tkn( - 1, chosen_tokens["prompt_attention_mask"] - ) - rejected_tokens["prompt_attention_mask"] = add_tkn( - 1, rejected_tokens["prompt_attention_mask"] - ) - - # add EOS token to end of answer - chosen_tokens["input_ids"] = add_post_tkn(self.tokenizer.eos_token_id, chosen_tokens["input_ids"]) - chosen_tokens["attention_mask"] = add_post_tkn(1, chosen_tokens["attention_mask"]) - - rejected_tokens["input_ids"] = add_post_tkn(self.tokenizer.eos_token_id, rejected_tokens["input_ids"]) - rejected_tokens["attention_mask"] = add_post_tkn(1, rejected_tokens["attention_mask"]) - - longer_response_length = max(chosen_tokens["input_ids"].shape[-1], rejected_tokens["input_ids"].shape[-1]) - - # if combined sequence is too long, truncate the prompt - for answer_tokens in [chosen_tokens, rejected_tokens, prompt_tokens]: - length_rn = answer_tokens["prompt_input_ids"].shape[-1] + longer_response_length - if length_rn > self.max_length: - - if self.truncation_mode == "keep_start": - for k in ["prompt_input_ids", "prompt_attention_mask"]: - answer_tokens[k] = answer_tokens[k][:, : self.max_prompt_length] - elif self.truncation_mode == "keep_end": - for k in ["prompt_input_ids", "prompt_attention_mask"]: - answer_tokens[k] = answer_tokens[k][:, -self.max_prompt_length:] - else: - raise ValueError(f"Unknown truncation mode: {self.truncation_mode}") - # if that's still too long, truncate the response - for answer_tokens in [chosen_tokens, rejected_tokens]: - if answer_tokens["prompt_input_ids"].shape[-1] + longer_response_length > self.max_length: - for k in ["input_ids", "attention_mask"]: - answer_tokens[k] = answer_tokens[k][:, : self.max_length - self.max_prompt_length] - - chosen_sequence_tokens = { - k: jnp.concatenate( - (v2d(chosen_tokens[f"prompt_{k}"]), v2d(chosen_tokens[k])), - axis=-1 - ) for k in ["input_ids", "attention_mask"] - } - rejected_sequence_tokens = { - k: jnp.concatenate( - (v2d(rejected_tokens[f"prompt_{k}"]), v2d(rejected_tokens[k])), - axis=-1 - ) for k in ["input_ids", "attention_mask"] - } - chosen_sequence_tokens["labels"] = chosen_sequence_tokens["input_ids"][:] - chosen_sequence_tokens["labels"] = chosen_sequence_tokens["labels"].at[ - : len(chosen_tokens["prompt_input_ids"]) - ].set([self.label_pad_token_id] * len(chosen_tokens["prompt_input_ids"])) - rejected_sequence_tokens["labels"] = rejected_sequence_tokens["input_ids"][:] - rejected_sequence_tokens["labels"] = rejected_sequence_tokens["labels"].at[ - : len(rejected_tokens["prompt_input_ids"]) - ].set( - ([self.label_pad_token_id] * len(rejected_tokens["prompt_input_ids"])) - ) - - for k, tokens_ in { - "chosen_": chosen_sequence_tokens, - "rejected_": rejected_sequence_tokens, - "": prompt_tokens, - }.items(): - for type_key, tokens in tokens_.items(): - if type_key == "token_type_ids": - continue - - b, s = tokens.shape - - if self.max_prompt_length > s: - if k == "chosen_": - if type_key == "input_ids": - tokens = pad_to_length( - tokens, - self.max_completion_length, - pad_value=self.padding_value, - axis=-1 - ) - elif type_key == "attention_mask": - tokens = pad_to_length( - tokens, - self.max_completion_length, - pad_value=0, - axis=-1 - ) - elif type_key == "labels": - tokens = pad_to_length( - tokens, - self.max_completion_length, - pad_value=self.padding_value, - axis=-1 - ) - - tokens = tokens[..., :self.max_completion_length] - - if tokens.shape[-1] != self.max_completion_length: - raise ValueError( - f"there was an error in padding token with `type_key` of {type_key}" - f". it must have sequence_length of {self.max_completion_length} but we got {tokens.shape[-1]}" - f" From {k}{type_key}" - ) - tokens = tokens[..., :self.max_completion_length] - elif k == "rejected_": - if type_key == "input_ids": - tokens = pad_to_length( - tokens, - self.max_completion_length, - pad_value=self.padding_value, - axis=-1 - ) - elif type_key == "attention_mask": - tokens = pad_to_length( - tokens, - self.max_completion_length, - pad_value=0, - axis=-1 - ) - elif type_key == "labels": - tokens = pad_to_length( - tokens, - self.max_completion_length, - pad_value=self.padding_value, - axis=-1 - ) - tokens = tokens[..., :self.max_completion_length] - if tokens.shape[-1] != self.max_completion_length: - raise ValueError( - f"there was an error in padding token with `type_key` of {type_key}" - f". it must have sequence_length of {self.max_completion_length} but we got {tokens.shape[-1]}" - f" From {k}{type_key}" - ) - elif k == "": - if type_key == "prompt_input_ids": - tokens = pad_to_length( - tokens, - self.max_prompt_length, - pad_value=self.padding_value, - axis=-1 - ) - elif type_key == "prompt_attention_mask": - tokens = pad_to_length( - tokens, - self.max_prompt_length, - pad_value=0, - axis=-1 - ) - elif type_key == "prompt_labels": - tokens = pad_to_length( - tokens, - self.max_prompt_length, - pad_value=self.padding_value, - axis=-1 - ) - tokens = tokens[..., :self.max_prompt_length] - if tokens.shape[-1] != self.max_prompt_length: - raise ValueError( - f"there was an error in padding token with `type_key` of {type_key}" - f". it must have sequence_length of {self.max_prompt_length} but we got {tokens.shape[-1]}" - f" From {k}{type_key}" - ) - batch[f"{k}{type_key}"] = tokens - return batch - - def configure_functions(self) -> TrainerConfigureFunctionFuncOutput: - """ - The configure_functions function is responsible for configuring the functions that will be used in training. - It does this by first defining a function called function_configurations, which initializes the model parameters - and returns - them as a EasyDeLState object. The EasyDeLState object contains all the information needed to train or evaluate - on a batch of data, including: - :param self: Access the class attributes - :return: A TrainerConfigureFunctionFuncOutput object - - """ - - def initialize_state_function(): - initialized_parameters = self.model.init_weights( - jax.random.PRNGKey(0), - self.arguments.init_input_shape - ) - - if self.arguments.dtype == jnp.bfloat16: - initialized_parameters = self.model.to_bf16(initialized_parameters) - elif self.arguments.dtype == jnp.float16: - initialized_parameters = self.model.to_fp16(initialized_parameters) - - tx = self.tx - parameters = flax.core.freeze({"params": initialized_parameters}) - tx_init = copy.deepcopy(self.arguments.optimizer_kwargs) - - if self.rapture is not None: - lora_parameters = self.lora_parameters - if self.arguments.dtype == jnp.bfloat16: - lora_parameters = self.model.to_bf16(lora_parameters) - elif self.arguments.dtype == jnp.float16: - lora_parameters = self.model.to_fp16(lora_parameters) - - return EasyDeLState( - step=0, - apply_fn=self.lora_apply_fn, - params=lora_parameters, - tx=self.lora_tx, - opt_state=self.lora_opt_state, - tx_init=EasyDeLState.safe_dict(tx_init), - hyperparameters=EasyDeLState.create_hyperparameters(self.model.config.model_type), - module=self.lora_model, - module_config=self.model.config, - module_config_args=None, - ) - else: - return EasyDeLState.create( - tx=tx, - params=parameters, - apply_fn=self.model.__call__, - module_config=copy.deepcopy(self.model.config), - tx_init=tx_init, - hyperparameters=EasyDeLState.create_hyperparameters(self.model.config.model_type), - module=self.model, - module_config_args=None - ) - - def create_state_from_params_function(parameters): - if self.rapture is None: - return EasyDeLState.create( - tx=self.tx, - params=parameters, - apply_fn=self.model.__call__, - module_config=copy.deepcopy(self.model.config), - tx_init=copy.deepcopy(self.arguments.optimizer_kwargs), - hyperparameters=EasyDeLState.create_hyperparameters(self.model.config.model_type), - module=self.model, - module_config_args=None - ) - else: - return EasyDeLState( - step=0, - apply_fn=self.lora_apply_fn, - params=parameters, - tx=self.lora_tx, - opt_state=self.lora_opt_state, - tx_init=EasyDeLState.safe_dict(copy.deepcopy(self.arguments.optimizer_kwargs)), - hyperparameters=EasyDeLState.create_hyperparameters(self.model.config.model_type), - module=self.lora_model, - module_config=self.model.config, - module_config_args=None, - ) - - state_shape = jax.eval_shape(initialize_state_function) - state_partition_spec = match_partition_rules( - self.config.get_partition_rules( - fully_sharded_data_parallel=self.arguments.fully_sharded_data_parallel - ) if self.arguments.custom_rule is None else self.arguments.custom_rule, - state_shape - ) - create_sharded_state_from_params_function = pjit( - create_state_from_params_function, - in_shardings=(state_partition_spec.params,), - out_shardings=state_partition_spec, - donate_argnums=(0,) - ) - sharded_train_step_function = pjit( - create_orpo_step_function( - mode="train", - beta=self.beta, - concatenated_forward=self.concatenated_forward, - batch_partition_spec=self.arguments.step_partition_spec - ), - in_shardings=(state_partition_spec, PartitionSpec()), - out_shardings=(state_partition_spec, PartitionSpec(),), - - ) - - sharded_eval_step_function = pjit( - create_orpo_step_function( - mode="eval", - beta=self.beta, - concatenated_forward=self.concatenated_forward, - batch_partition_spec=self.arguments.step_partition_spec - ), - in_shardings=(state_partition_spec, PartitionSpec()), - out_shardings=(state_partition_spec, PartitionSpec(),), - - ) - - mesh = self.arguments.get_mesh() - self.arguments.ckpt_path_exists() - checkpoint_manager = self.arguments.get_streaming_checkpointer() - self.state_partition_spec = state_partition_spec - self.state_shape = state_shape - - return TrainerConfigureFunctionFuncOutput( - create_sharded_state_from_params_function=create_sharded_state_from_params_function, - sharded_train_step_function=sharded_train_step_function, - sharded_eval_step_function=sharded_eval_step_function, - mesh=mesh, - checkpoint_manager=checkpoint_manager, - initialize_state_function=initialize_state_function - ) - - def initialize_state( - self, - model_parameters: Optional[flax.core.FrozenDict] = None, - state: Optional[EasyDeLState] = None, - ) -> Tuple[EasyDeLState, Mapping[str, Callable], Mapping[str, Callable]]: - if model_parameters is None and state is None and self.rapture is None and self.checkpoint_path is None: - raise RuntimeError( - "You are passing `model_parameters=None`, `state=None`, and `checkpoint_path=None` and also you are not" - " using LoRA, if you are " - "Using LoRA make sure to pass parameters and Rapture Config correctly otherwise pass the " - "model_parameters or state." - ) - if model_parameters is None and state is None: - model_parameters = self.lora_parameters - with self.mesh: - shard_fns, gather_fns = make_shard_and_gather_fns( - self.state_partition_spec, - dtype_specs=self.dtype - ) - if state is not None: - sharded_state = state - params = sharded_state.params if not self.arguments.do_shard_fns else jax.tree_util.tree_map( - lambda f, x: f(x), - shard_fns.params, - sharded_state.params - ) - sharded_state.params = params - if sharded_state.opt_state is None: - prefix_print( - "Action", "Optimizer State is not Found!, initializing one." - ) - with jax.default_device(self.arguments.offload_device): - sharded_state = sharded_state.init_opt_state() - opt_state = sharded_state.opt_state if not self.arguments.do_shard_fns else jax.tree_util.tree_map( - lambda f, x: f(x), - shard_fns.opt_state, - sharded_state.opt_state - ) - sharded_state = sharded_state.replace( - opt_state=opt_state - ) - elif self.finetune: - - if model_parameters is None and self.checkpoint_path is not None: - prefix_print( - "Action", f"Loading Model From {self.checkpoint_path}" - ) - with jax.default_device(self.arguments.offload_device): - sharded_state = EasyDeLState.load_state( - verbose=self.arguments.verbose, - state_shard_fns=shard_fns, - init_optimizer_state=True, - checkpoint_path=self.checkpoint_path, - input_shape=self.arguments.init_input_shape, - config_kwargs=self.arguments.loaded_model_config_kwargs - ) - state_shape = jax.eval_shape(lambda: sharded_state) - state_partition_spec = match_partition_rules( - self.config.get_partition_rules( - fully_sharded_data_parallel=self.arguments.fully_sharded_data_parallel - ) if self.arguments.custom_rule is None else self.arguments.custom_rule, - state_shape - ) - sharded_train_step_function = pjit( - create_orpo_step_function( - mode="train", - beta=self.beta, - concatenated_forward=self.concatenated_forward, - batch_partition_spec=self.arguments.step_partition_spec - ), - in_shardings=(state_partition_spec, PartitionSpec()), - out_shardings=(state_partition_spec, PartitionSpec(),), - - ) - - sharded_eval_step_function = pjit( - create_orpo_step_function( - mode="eval", - beta=self.beta, - concatenated_forward=self.concatenated_forward, - batch_partition_spec=self.arguments.step_partition_spec - ), - in_shardings=(state_partition_spec, PartitionSpec()), - out_shardings=(state_partition_spec, PartitionSpec(),), - ) - - self.state_partition_spec = state_partition_spec - self.state_shape = state_shape - self.sharded_train_step_function = sharded_train_step_function - self.sharded_eval_step_function = sharded_eval_step_function - - if self.arguments.remove_ckpt_after_load: - os.remove(self.checkpoint_path) - elif model_parameters is not None and self.checkpoint_path is None: - prefix_print( - "Action", f"Sharding Passed Parameters" - ) - from flax.core import unfreeze - if not isinstance(model_parameters, flax.core.FrozenDict): - prefix_print( - "Warning", - "Model Parameters should be like FrozenDict({'params': params}) make sure to " - "pass as type FrozenDict in case of not getting UnExcepted Errors " - ) - - model_parameters = model_parameters if not self.arguments.do_shard_fns else jax.tree_util.tree_map( - lambda f, x: f(x), - shard_fns.params, - model_parameters, - ) - sharded_state = self.create_sharded_state_from_params_function(model_parameters) - elif model_parameters is not None and self.checkpoint_path is not None: - raise EasyDeLTimerError( - "You can't pass `model_parameters` and `checkpoint_path` at same time" - ) - else: - raise EasyDeLTimerError( - "You should pass `model_parameters` or `checkpoint_path` to trainer in order to load model" - ) - else: - sharded_state = self.initialize_state_function() - params = sharded_state.params if not self.arguments.do_shard_fns else jax.tree_util.tree_map( - lambda f, x: f(x), - shard_fns.params, - sharded_state.params - ) - sharded_state.params = params - - self.sharded_state = sharded_state - return sharded_state, shard_fns, gather_fns - - def _save_state( - self, - state: EasyDeLState, - gather_fns: Optional[Any | Mapping[str, Callable] | dict[Callable]], - milestone: bool = False - ) -> str: - step = int( - jax.device_get( - state.step - ) - ) + self.arguments.step_start_point if self.arguments.step_start_point is not None else int( - jax.device_get( - state.step - ) - ) - - checkpoint_dir = os.path.join(self.arguments.save_dir, self.arguments.model_name) - filename_extension = ".easy" - if self.arguments.save_total_limit: - checkpoint_files = glob(os.path.join(checkpoint_dir, f"*{filename_extension}")) - checkpoint_files.sort(key=os.path.getmtime) - for old_checkpoint in checkpoint_files[:-self.arguments.save_total_limit]: - os.remove(old_checkpoint) - termcolor.cprint(f"Removed old checkpoint: {old_checkpoint}", color="red", force_color=True) - - checkpoint_name = f"{self.arguments.model_name}-S{step}" - filename = f"{checkpoint_name}_{step}" if milestone else f"{checkpoint_name}" - filename += ".easy" - termcolor.cprint(f"Saving Model {filename}.", color="cyan", force_color=True) - state.save_state( - filename=filename, - checkpoint_dir=checkpoint_dir, - gather_fns=gather_fns, - float_dtype=self.dtype, - verbose=self.arguments.verbose, - save_optimizer=self.arguments.save_optimizer_state, - ) - return filename - - def initialize_trainer_utils(self): - """ - The initialize_trainer_utils function is responsible for initializing the following: - - wandb_runtime (if you use_wandb is True) - - timer object (for logging time taken by various functions) - - dataloader objects for training and evaluation data, along with max steps per epoch. - The configure_dataloader function accomplishes this task. - - :param self: Represent the instance of the class - :return: A tuple of functions - - """ - self.wandb_runtime = self.arguments.get_wandb_init() if self.arguments.use_wandb else None - self.timer = Timers( - use_wandb=False, - tensorboard_writer=self.arguments.get_board() - ) - - self.timer("configure dataloaders").start() - dataset_configurations = self.configure_dataloader() - self.dataloader_train = dataset_configurations.dataloader_train - self.max_training_steps = dataset_configurations.max_training_steps - self.dataloader_eval = dataset_configurations.dataloader_eval - self.max_evaluation_steps = dataset_configurations.max_evaluation_steps - - self.timer("configure dataloaders").stop() - - self.timer.log(["configure dataloaders"]) - - self.timer("configure Model, Optimizer, Scheduler and Config").start() - model_configurations = self.configure_model() - model = model_configurations.model - tx = model_configurations.tx - scheduler = model_configurations.scheduler - config = model_configurations.config - self.model = model - self.tx = tx - self.scheduler = scheduler - self.config = config - if self.rapture is not None: - lora_modules = self.rapture.apply_lora( - module=model, - parameters=self.arguments.rapture_config.parameters, - tx=tx, - ) - self.lora_parameters = lora_modules.lora_parameters - self.lora_apply_fn = lora_modules.lora_module.__call__ - self.lora_opt_state = lora_modules.lora_opt_state - self.lora_model = lora_modules.lora_module - self.lora_tx = lora_modules.lora_tx - - self.timer("configure Model, Optimizer, Scheduler and Config").stop() - self.timer.log(["configure Model, Optimizer, Scheduler and Config"]) - - self.timer("configure functions and sharding them").start() - - function_configurations = self.configure_functions() - self.create_sharded_state_from_params_function = ( - function_configurations.create_sharded_state_from_params_function - ) - self.sharded_train_step_function = function_configurations.sharded_train_step_function - self.sharded_eval_step_function = function_configurations.sharded_eval_step_function - self.mesh = function_configurations.mesh - self.checkpoint_manager = function_configurations.checkpoint_manager - self.initialize_state_function = function_configurations.initialize_state_function - self.timer("configure functions and sharding them").stop() - self.timer.log(["configure functions and sharding them"]) - - def create_collate_function( - self, - max_sequence_length: int, - truncation_mode: typing.Literal["keep_end", "keep_start"] = "keep_end", - ) -> Callable: - return self.data_collator - - def shard_states(self, state, rules): - with self.arguments.get_mesh(): - partition_spec = match_partition_rules(rules=rules, params=jax.eval_shape(lambda: state)) - - def _shard(x): - return x - - shard = pjit( - _shard, - in_shardings=(PartitionSpec(),), - out_shardings=partition_spec - ) - return shard(state) - - def configure_dataloader(self) -> TrainerConfigureDataloaderFuncOutput: - dataloader_train = self.get_train_dataloader() - max_evaluation_steps = None - dataloader_eval = None - - max_training_steps = self.arguments.num_train_epochs * len( - dataloader_train - ) if self.arguments.max_training_steps is None else self.arguments.max_training_steps - if self.eval_dataset is not None: - dataloader_eval = self.get_eval_dataloader(self.eval_dataset) - max_evaluation_steps = len(dataloader_eval) - return TrainerConfigureDataloaderFuncOutput( - dataloader_train=dataloader_train, # type:ignore - max_training_steps=max_training_steps, - dataloader_eval=dataloader_eval, - max_evaluation_steps=max_evaluation_steps - ) - - def _get_train_dataloader(self) -> tensorflow.data.Dataset: - - """ - The _get_train_dataloader function is used to create a tensorflow.data.Dataset object for the training dataset. - - :param self: Represent the instance of the class - :return: A dataloader object - """ - if self.train_dataset is None: - raise ValueError("Trainer: training requires a train_dataset.") - - train_dataset = self.train_dataset - data_collator = self.data_collator - - return tensorflow_datasets.as_numpy( - train_dataset.to_tf_dataset( - batch_size=self.arguments.total_batch_size, - collate_fn=data_collator, - num_workers=self.arguments.dataloader_num_workers, - shuffle=True, - drop_remainder=True - ) - ) - - def _get_eval_dataloader(self, eval_dataset: Optional[Dataset] = None) -> tensorflow.data.Dataset: - """ - Returns the evaluation [`~tensorflow.data.Dataset`]. - - Subclass and override this method if you want to inject some custom behavior. - - Args: - eval_dataset (`torch.utils.data.Dataset`, *optional*): - If provided, will override `self.eval_dataset`. If it is a [`~datasets.Dataset`], columns not accepted - by the `model.forward()` method are automatically removed. It must implement `__len__`. - """ - if eval_dataset is None and self.eval_dataset is None: - raise ValueError("Trainer: evaluation requires an eval_dataset.") - eval_dataset = eval_dataset if eval_dataset is not None else self.eval_dataset - - return tensorflow_datasets.as_numpy( - eval_dataset.to_tf_dataset( - batch_size=self.arguments.total_batch_size, - collate_fn=self.data_collator, - num_workers=self.arguments.dataloader_num_workers, - shuffle=False, - drop_remainder=True - ) - ) - - def get_train_dataloader( - self, - ) -> tensorflow.data.Dataset: - """ - Returns the training [`~tensorflow.data.Dataset`]. - """ - return self._get_train_dataloader() - - def get_eval_dataloader(self, eval_dataset: Optional[Dataset] = None) -> tensorflow.data.Dataset: - """ - Returns the evaluation [`~tensorflow.data.Dataset`]. - """ - if eval_dataset is None and self.eval_dataset is None: - raise ValueError("Trainer: evaluation requires an eval_dataset.") - eval_dataset = eval_dataset if eval_dataset is not None else self.eval_dataset - return self._get_eval_dataloader(eval_dataset=eval_dataset) - - def train( - self, - model_parameters: Optional[flax.core.FrozenDict] = None, - state: Optional[EasyDeLState] = None - ) -> ORPOTrainerOutput: - def get_layer_names(frozen_dict, prefix=""): - layer_names = {} - for key, value in frozen_dict.items(): - if isinstance(value, FrozenDict): - layer_names.update(get_layer_names(value, prefix=f"{prefix}_{key}")) - else: - layer_name = f"{prefix}_{key}".lstrip("/") - layer_names[layer_name] = value - return layer_names - - def count_model_parameters(_p): - termcolor.cprint( - f"Model Contain {sum(n.size for n in jax.tree_util.tree_flatten(flax.core.unfreeze(_p))[0]) / 1e9} " - f"Billion Parameters", - color="red", force_color=True - ) - - checkpoint_path = "SAVING_SKIPPED" - if self.arguments.performance_mode: - termcolor.cprint( - "Performance Mode is ON, we will ignore the Memory Tracking, WANDB Logging, and extra information " - "Process.", - color="red", - force_color=True - ) - sharded_state, shard_fns, gather_fns = self.initialize_state( - model_parameters=model_parameters, - state=state - ) - self.model_state = sharded_state - count_model_parameters(sharded_state.params) - with self.mesh: - with jax.default_device(jax.devices("cpu")[0]) if self.low_mem_usage else leave_alone_context_manager(): - dir_prefix: str = "/dev/shm" if sys.platform != "win32" else "." - checkpoint_path = "SAVING_SKIPPED" - - pbar = tqdm(total=self.max_training_steps) - pbar.set_description("Training") - current_step = self.model_state.step.tolist() if isinstance( - self.model_state.step, - jax.Array - ) else self.model_state.step - - loss_sum = None - - try: - for epoch_index in range(self.arguments.num_train_epochs): - for batch in self.dataloader_train: - current_step += 1 - if self.arguments.step_start_point > current_step: - ... - elif current_step < self.max_training_steps: - time_start = time.time() - - self.model_state, outputs = self.sharded_train_step_function( - self.model_state, - batch - ) - total_time = time.time() - time_start - (loss, metrics) = outputs.loss, outputs.metrics - - loss_sum = loss.tolist() if loss_sum is None else loss_sum + loss - - train_metrics = { - "train/loss": loss.tolist(), - "train/mean_loss": loss_sum / (current_step - self.arguments.step_start_point), - "train/learning_rate": self.scheduler( - jax.device_get(self.model_state.step)).tolist(), - "train/step": current_step, - "train/step_time": total_time, - "train/perplexity": jnp.exp(loss).tolist(), - "train/epoch": epoch_index - } - train_metrics.update(metrics) - log_metrics = copy.deepcopy(train_metrics) - train_metrics.update(self.arguments.captured_memory) - if self.arguments.use_wandb: - with jax.spmd_mode("allow_all"): - self.wandb_runtime.log( - train_metrics - ) - pbar.update(1) - pbar.set_postfix(**{k.replace("train/", ""): v for k, v in log_metrics.items()}) - else: - break - except KeyboardInterrupt: - termcolor.cprint( - "KeyboardInterrupt At training model Will return Current State of the Model with Parameters.", - color="cyan", - force_color=True - ) - - except EasyDeLTimerError: - termcolor.cprint( - "Training reached out maximum training Time Killing training Process " - "and Will return Current State of the Model with Parameters.", - color="cyan", - force_color=True - ) - - if self.arguments.merge_lora_rapture_parameters and self.rapture is not None: - print( - termcolor.colored( - "Info : ", color="red", force_color=True - ), - termcolor.colored( - "Merging LoRA Parameters.", color="white", force_color=True - ) - ) - self.model_state = self.model_state.replace( - params=self.rapture.merge_parameters(self.model_state.params) - ) - - shard_fns, gather_fns = make_shard_and_gather_fns( - partition_specs=match_partition_rules( - rules=self.model_state.module.config.get_partition_rules( - self.arguments.fully_sharded_data_parallel - ), - params=jax.eval_shape(lambda: self.model_state) - ), - dtype_specs=self.arguments.dtype - ) - output = ORPOTrainerOutput( - state=self.model_state, - mesh=self.mesh, - shard_fns=shard_fns, - gather_fns=gather_fns, - checkpoint_manager=self.checkpoint_manager, - ) - if self.arguments.save_steps is None and self.arguments.do_last_save: - shard_fns, gather_fns = make_shard_and_gather_fns( - match_partition_rules( - self.config.get_partition_rules( - fully_sharded_data_parallel=self.arguments.fully_sharded_data_parallel - ) if self.arguments.custom_rule is None else self.arguments.custom_rule, - jax.eval_shape(lambda: self.model_state) - ), - dtype_specs=self.dtype - ) # You have to re-init the new shard and gather functions in order to be able to skip LoRA weight - # crashing errors and saving errors - filename = self._save_state( - state=self.model_state, - gather_fns=gather_fns - ) - checkpoint_path = f"{str(self.arguments.get_path())}/{filename}" - - if self.arguments.do_eval: - for _ in self.eval( - self.model_state - ): - ... - - output.checkpoint_path = checkpoint_path - output.last_save_file_name = filename - wandb.finish() - - return output - - def eval(self, model_state: EasyDeLState) -> typing.Iterator[dict]: - """Evaluate the Given Model State and yield the eval metrics""" - assert self.eval_dataset is not None, "`dataloader_eval` is required by evaluator function." - with self.mesh: - pbar = tqdm(total=self.max_evaluation_steps) - pbar.set_description("Evaluating") - current_step = 0 - loss_sum = None - try: - for batch in self.dataloader_eval: - current_step += 1 - time_start = time.time() - for key in self.arguments.ids_to_pop_from_dataset: - _ = batch.pop(key, None) - for key in list(batch.keys()): - if not ( - key.endswith("_input_ids") - or key.endswith("_attention_mask") - or key.endswith("_labels") - ): - _ = batch.pop(key, None) - - _, outputs = self.sharded_eval_step_function( - model_state, - batch - ) - total_time = time.time() - time_start - ( - loss, metrics - ) = outputs.loss, outputs.metrics - - loss_sum = loss.tolist() if loss_sum is None else loss_sum + loss - - eval_metrics = { - "eval/loss": loss.tolist(), - "eval/mean_loss": loss_sum / (current_step - self.arguments.step_start_point), - "eval/step": current_step, - "eval/step_time": total_time, - "eval/perplexity": jnp.exp(loss).tolist(), - } - eval_metrics.update(metrics) - log_metrics = copy.deepcopy(eval_metrics) - eval_metrics.update(self.arguments.captured_memory) - if self.arguments.use_wandb: - with jax.spmd_mode("allow_all"): - self.wandb_runtime.log( - eval_metrics - ) - - pbar.update(1) - pbar.set_postfix(**{k.replace("eval/", ""): v for k, v in log_metrics.items()}) - yield eval_metrics - except KeyboardInterrupt: - termcolor.cprint( - "KeyboardInterrupt At Evaluation model Will return Nothing and just pass.", - color="cyan", - force_color=True - ) - - def __repr__(self): - - """ - The __repr__ function is used to generate a string representation of an object. - This function should return a string that can be parsed by the Python interpreter - to recreate the object. The __repr__ function is called when you use print() on an - object, or when you type its name in the REPL. - - :param self: Refer to the instance of the class - :return: A string representation of the object - """ - string = f"{self.__class__.__name__}(\n" - for k, v in self.__dict__.items(): - if not k.startswith("_"): - try: - repr_src = f"\t{k} : " + v.__str__().replace("\n", "\n\t") + "\n" - string += repr_src if len(repr_src) < 350 else f"\t{k} : " + f"{v.__class__.__name__}(...)" + "\n" - except TypeError: - repr_src = f"\t{k} : " + "EasyDeLReadingError" + "\n" - string += repr_src if len(repr_src) < 350 else f"\t{k} : " + f"{v.__class__.__name__}(...)" + "\n" - - return string + ")" - - def __str__(self): - - """ - The __str__ function is called when you use the print function or when str() is used. - It should return a string representation of the object. - - :param self: Refer to the instance of the class - :return: The object's string representation - """ - return self.__repr__() +import copy +import os +import sys +import time +import typing +import warnings +from abc import ABC +from collections import defaultdict +from glob import glob + +import flax.core +import jax +import tensorflow.data +import tensorflow_datasets +import termcolor +import wandb +from fjformer import match_partition_rules, make_shard_and_gather_fns +from flax.core import FrozenDict +from tqdm import tqdm + +from typing import ( + Optional, + Literal, + Dict, + Union, + Any, + Callable, + Mapping, + Tuple +) + +from jax.experimental.pjit import pjit +from datasets import Dataset +from jax import numpy as jnp + +from ...etils.etils import get_logger +from ..training_configurations import TrainArguments +from ..base_trainer import ( + BaseTrainer, + TrainerConfigureFunctionFuncOutput, + TrainerConfigureDataloaderFuncOutput, + TrainerConfigureModelFuncOutput +) +from ...etils import EasyDeLState, EasyDeLTimerError +from transformers import PreTrainedTokenizerBase +from jax.sharding import PartitionSpec + +from ...utils import Timers, prefix_print +from ..dpo.utils import ( + pad_to_length, + DPODataCollatorWithPadding, + leave_alone_context_manager +) +from .fwd_bwd_functions import ( + create_orpo_step_function, + create_concatenated_forward, +) +from .modelling_output import ORPOTrainerOutput + +logger = get_logger(__name__) + + +class ORPOTrainer(BaseTrainer, ABC): + """ + easydel ORPO Trainer Class + """ + + def __init__( + self, + arguments: TrainArguments, + max_length: Optional[int] = None, + max_prompt_length: Optional[int] = None, + max_completion_length: Optional[int] = None, + beta: float = 0.1, + disable_dropout: bool = True, + label_pad_token_id: int = -100, + is_encoder_decoder: bool = False, + padding_value: int = None, + data_collator: Optional[DPODataCollatorWithPadding] = None, + train_dataset: Optional[Dataset] = None, + eval_dataset: Optional[Union[Dataset, Dict[str, Dataset]]] = None, + tokenizer: Optional[PreTrainedTokenizerBase] = None, + _do_init_fns: bool = True, + dataset_map_arguments: Optional[Dict[str, Any]] = None, + low_mem_usage: bool = False, + ): + + """ + The __init__ function is called when the class is instantiated. + It sets up the attributes of an object. + + + :param self: Refer to the object itself + :param beta: float: Control the strength of the regularization term + :param arguments: TrainArguments: Pass the arguments to the trainer + :param label_pad_token_id: int: Pad the labels + :param padding_value: int: Specify the value that is used for padding + :param train_dataset: Optional[Dataset]: Load the training dataset + :param eval_dataset: Optional[Union[Dataset, Dict[str, Dataset]]] : Pass the evaluation dataset to the trainer + :param tokenizer: Optional[PreTrainedTokenizerBase]: Pass the tokenizer to the trainer + :param max_length: Optional[int]: Set the maximum length of the input sequence + :param max_prompt_length: Optional[int]: Set the maximum length of the prompt + :param max_completion_length: Optional[int]: Truncate the target sequence + :param data_collator: Optional[Callable]: Function to be used for creating datasets. + tokenizing process with `dataset.map`. + :param _do_init_fns: bool : preferred to set ture to trainer will automatically configure + model with provided training Arguments + :param : Set the padding value for the model + """ + + assert arguments is not None, ( + "You Have to pass arguments that will be used for training but you have passed" + "`arguments=None`" + ) + assert isinstance(arguments, TrainArguments), ( + f"arguments type must be `TrainArguments` but got {type(arguments)}" + ) + + if tokenizer is None: + raise ValueError("tokenizer must be specified to tokenize a ORPO dataset.") + if max_length is None: + warnings.warn( + "`max_length` is not set in the ORPOTrainer's init" + " it will default to `512` by default, but you should do it yourself in the future.", + UserWarning, + ) + max_length = 512 + if max_prompt_length is None: + warnings.warn( + "`max_prompt_length` is not set in the ORPOTrainer's init" + " it will default to `128` by default, but you should do it yourself in the future.", + UserWarning, + ) + max_prompt_length = 128 + + if max_completion_length is None: + warnings.warn( + "When using an encoder decoder architecture, you should set `max_completion_length` in the " + "ORPOTrainer's init it will default to `128` by default, but you should do it yourself in the future.", + UserWarning, + ) + max_completion_length = 128 + + padding_value = padding_value if padding_value is not None else tokenizer.pad_token_id + self.max_length = max_length + self.label_pad_token_id = label_pad_token_id + self.padding_value = padding_value + self.max_prompt_length = max_prompt_length + self.truncation_mode = arguments.truncation_mode + self.disable_dropout = disable_dropout + self.max_completion_length = max_completion_length + self.tokenizer = tokenizer + self.is_encoder_decoder = is_encoder_decoder + self.low_mem_usage = low_mem_usage + self.beta = beta + data_collator = DPODataCollatorWithPadding( + max_prompt_length=self.max_prompt_length, + max_target_length=self.max_completion_length, + pad_token_id=tokenizer.pad_token_id, + label_pad_token_id=label_pad_token_id, + is_encoder_decoder=False, + ) if data_collator is None else data_collator + self._stored_metrics = defaultdict(lambda: defaultdict(list)) + if dataset_map_arguments is None: + dataset_map_arguments = {} + train_dataset = train_dataset.map( + self.tokenize_row, + **dataset_map_arguments + ) + if eval_dataset is not None: + eval_dataset = eval_dataset.map( + self.tokenize_row, + **dataset_map_arguments + ) + + self.arguments = arguments + self.hp_name = None + self.deepspeed = None + self.is_in_train = False + + self.data_collator = data_collator + self.train_dataset = train_dataset + self.eval_dataset = eval_dataset + self.tokenizer = tokenizer + self._loggers_initialized = False + self.mesh = self.arguments.get_mesh() + assert padding_value is not None, "`padding_value` can not be set as `None` it must be an integer." + + self.concatenated_forward = create_concatenated_forward( + is_encoder_decoder=self.is_encoder_decoder, + padding_value=padding_value, + label_pad_token_id=label_pad_token_id, + ) + + self._cached_p_l_s = None + self._cached_c_l_s = None + self._cached_r_l_s = None + + super().__init__( + arguments=arguments, + dataset_train=train_dataset, + dataset_eval=eval_dataset, + finetune=True, + checkpoint_path=None, + _do_init_fns=_do_init_fns + ) + + def build_tokenized_answer(self, prompt, answer): + """ + Llama tokenizer does satisfy `enc(a + b) = enc(a) + enc(b)`. + It does ensure `enc(a + b) = enc(a) + enc(a + b)[len(enc(a)):]`. + """ + + full_tokenized = self.tokenizer(prompt + answer, add_special_tokens=False) + prompt_input_ids = self.tokenizer(prompt, add_special_tokens=False)["input_ids"] + + answer_input_ids = full_tokenized["input_ids"][len(prompt_input_ids):] + answer_attention_mask = full_tokenized["attention_mask"][len(prompt_input_ids):] + prompt_input_ids = jnp.asarray(prompt_input_ids, dtype="i4") + answer_input_ids = jnp.asarray(answer_input_ids, dtype="i4") + full_concat_input_ids = jnp.concatenate( + ( + prompt_input_ids, + answer_input_ids + ) + ) + + # Prepare input tokens for token by token comparison + full_input_ids = jnp.array(full_tokenized["input_ids"]) + + if len(full_input_ids) != len(full_concat_input_ids): + raise ValueError("Prompt input ids and answer input ids should have the same length.") + + response_token_ids_start_idx = len(prompt_input_ids) + if prompt_input_ids.tolist() != full_tokenized["input_ids"][:response_token_ids_start_idx]: + response_token_ids_start_idx -= 1 + + prompt_input_ids = full_tokenized["input_ids"][:response_token_ids_start_idx] + prompt_attention_mask = full_tokenized["attention_mask"][:response_token_ids_start_idx] + + if len(prompt_input_ids) != len(prompt_attention_mask): + raise ValueError("Prompt input ids and attention mask should have the same length.") + + answer_input_ids = full_tokenized["input_ids"][response_token_ids_start_idx:] + answer_attention_mask = full_tokenized["attention_mask"][response_token_ids_start_idx:] + + return dict( + prompt_input_ids=jnp.array(prompt_input_ids, dtype="i4"), + prompt_attention_mask=jnp.array(prompt_attention_mask, dtype="i4"), + input_ids=jnp.array(answer_input_ids, dtype="i4"), + attention_mask=jnp.array(answer_attention_mask, dtype="i4"), + ) + + def tokenize_row(self, feature, state: EasyDeLState = None) -> Dict: + + """ + The tokenize_row function is responsible for taking a single row of data and converting it into the format that + the model expects. This includes: + - Tokenizing the text (using HuggingFace's tokenizer) + - Padding/truncating sequences to a fixed length (if necessary) + - Creating attention masks, which tell the model which tokens are padding and which aren't. + + :param self: Represent the instance of the class + :param feature: Pass in the data from the dataset + :param state: EasyDeLState: Keep track of the state of the tokenizer + :return: A dictionary of the following keys + """ + batch = {} + prompt = feature["prompt"] + chosen = feature["chosen"] + rejected = feature["rejected"] + + if not isinstance(prompt, str): + raise ValueError(f"prompt should be an str but got {type(prompt)} , {prompt}") + prompt_tokens = self.tokenizer( + prompt, + add_special_tokens=False, + return_tensors="np", + ) + prompt_tokens = {f"prompt_{k}": v for k, v in prompt_tokens.items()} + + if not isinstance(chosen, str): + raise ValueError(f"chosen should be an str but got {type(chosen)} , {chosen}") + chosen_tokens = self.build_tokenized_answer(prompt, chosen) + + if not isinstance(rejected, str): + raise ValueError(f"rejected should be an str but got {type(rejected)}") + rejected_tokens = self.build_tokenized_answer(prompt, rejected) + v2d = lambda ar: ar.reshape(1, -1) if ar.ndim == 1 else ar + + def add_tkn(n, ar): + return jnp.concatenate( + ( + jnp.array(n).reshape(1, 1), + v2d(ar) + ), axis=-1 + ) + + def add_post_tkn(n, ar): + return jnp.concatenate( + ( + v2d(ar), + jnp.array(n).reshape(1, 1) + ), axis=-1 + ) + + prompt_tokens["prompt_input_ids"] = add_tkn( + self.tokenizer.bos_token_id, + prompt_tokens["prompt_input_ids"] + ) + chosen_tokens["prompt_input_ids"] = add_tkn( + self.tokenizer.bos_token_id, + chosen_tokens["prompt_input_ids"] + ) + rejected_tokens["prompt_input_ids"] = add_tkn( + self.tokenizer.bos_token_id, + rejected_tokens["prompt_input_ids"] + ) + + prompt_tokens["prompt_attention_mask"] = add_tkn( + 1, prompt_tokens["prompt_attention_mask"] + ) + chosen_tokens["prompt_attention_mask"] = add_tkn( + 1, chosen_tokens["prompt_attention_mask"] + ) + rejected_tokens["prompt_attention_mask"] = add_tkn( + 1, rejected_tokens["prompt_attention_mask"] + ) + + # add EOS token to end of answer + chosen_tokens["input_ids"] = add_post_tkn(self.tokenizer.eos_token_id, chosen_tokens["input_ids"]) + chosen_tokens["attention_mask"] = add_post_tkn(1, chosen_tokens["attention_mask"]) + + rejected_tokens["input_ids"] = add_post_tkn(self.tokenizer.eos_token_id, rejected_tokens["input_ids"]) + rejected_tokens["attention_mask"] = add_post_tkn(1, rejected_tokens["attention_mask"]) + + longer_response_length = max(chosen_tokens["input_ids"].shape[-1], rejected_tokens["input_ids"].shape[-1]) + + # if combined sequence is too long, truncate the prompt + for answer_tokens in [chosen_tokens, rejected_tokens, prompt_tokens]: + length_rn = answer_tokens["prompt_input_ids"].shape[-1] + longer_response_length + if length_rn > self.max_length: + + if self.truncation_mode == "keep_start": + for k in ["prompt_input_ids", "prompt_attention_mask"]: + answer_tokens[k] = answer_tokens[k][:, : self.max_prompt_length] + elif self.truncation_mode == "keep_end": + for k in ["prompt_input_ids", "prompt_attention_mask"]: + answer_tokens[k] = answer_tokens[k][:, -self.max_prompt_length:] + else: + raise ValueError(f"Unknown truncation mode: {self.truncation_mode}") + # if that's still too long, truncate the response + for answer_tokens in [chosen_tokens, rejected_tokens]: + if answer_tokens["prompt_input_ids"].shape[-1] + longer_response_length > self.max_length: + for k in ["input_ids", "attention_mask"]: + answer_tokens[k] = answer_tokens[k][:, : self.max_length - self.max_prompt_length] + + chosen_sequence_tokens = { + k: jnp.concatenate( + (v2d(chosen_tokens[f"prompt_{k}"]), v2d(chosen_tokens[k])), + axis=-1 + ) for k in ["input_ids", "attention_mask"] + } + rejected_sequence_tokens = { + k: jnp.concatenate( + (v2d(rejected_tokens[f"prompt_{k}"]), v2d(rejected_tokens[k])), + axis=-1 + ) for k in ["input_ids", "attention_mask"] + } + chosen_sequence_tokens["labels"] = chosen_sequence_tokens["input_ids"][:] + chosen_sequence_tokens["labels"] = chosen_sequence_tokens["labels"].at[ + : len(chosen_tokens["prompt_input_ids"]) + ].set([self.label_pad_token_id] * len(chosen_tokens["prompt_input_ids"])) + rejected_sequence_tokens["labels"] = rejected_sequence_tokens["input_ids"][:] + rejected_sequence_tokens["labels"] = rejected_sequence_tokens["labels"].at[ + : len(rejected_tokens["prompt_input_ids"]) + ].set( + ([self.label_pad_token_id] * len(rejected_tokens["prompt_input_ids"])) + ) + + for k, tokens_ in { + "chosen_": chosen_sequence_tokens, + "rejected_": rejected_sequence_tokens, + "": prompt_tokens, + }.items(): + for type_key, tokens in tokens_.items(): + if type_key == "token_type_ids": + continue + + b, s = tokens.shape + + if self.max_prompt_length > s: + if k == "chosen_": + if type_key == "input_ids": + tokens = pad_to_length( + tokens, + self.max_completion_length, + pad_value=self.padding_value, + axis=-1 + ) + elif type_key == "attention_mask": + tokens = pad_to_length( + tokens, + self.max_completion_length, + pad_value=0, + axis=-1 + ) + elif type_key == "labels": + tokens = pad_to_length( + tokens, + self.max_completion_length, + pad_value=self.padding_value, + axis=-1 + ) + + tokens = tokens[..., :self.max_completion_length] + + if tokens.shape[-1] != self.max_completion_length: + raise ValueError( + f"there was an error in padding token with `type_key` of {type_key}" + f". it must have sequence_length of {self.max_completion_length} but we got {tokens.shape[-1]}" + f" From {k}{type_key}" + ) + tokens = tokens[..., :self.max_completion_length] + elif k == "rejected_": + if type_key == "input_ids": + tokens = pad_to_length( + tokens, + self.max_completion_length, + pad_value=self.padding_value, + axis=-1 + ) + elif type_key == "attention_mask": + tokens = pad_to_length( + tokens, + self.max_completion_length, + pad_value=0, + axis=-1 + ) + elif type_key == "labels": + tokens = pad_to_length( + tokens, + self.max_completion_length, + pad_value=self.padding_value, + axis=-1 + ) + tokens = tokens[..., :self.max_completion_length] + if tokens.shape[-1] != self.max_completion_length: + raise ValueError( + f"there was an error in padding token with `type_key` of {type_key}" + f". it must have sequence_length of {self.max_completion_length} but we got {tokens.shape[-1]}" + f" From {k}{type_key}" + ) + elif k == "": + if type_key == "prompt_input_ids": + tokens = pad_to_length( + tokens, + self.max_prompt_length, + pad_value=self.padding_value, + axis=-1 + ) + elif type_key == "prompt_attention_mask": + tokens = pad_to_length( + tokens, + self.max_prompt_length, + pad_value=0, + axis=-1 + ) + elif type_key == "prompt_labels": + tokens = pad_to_length( + tokens, + self.max_prompt_length, + pad_value=self.padding_value, + axis=-1 + ) + tokens = tokens[..., :self.max_prompt_length] + if tokens.shape[-1] != self.max_prompt_length: + raise ValueError( + f"there was an error in padding token with `type_key` of {type_key}" + f". it must have sequence_length of {self.max_prompt_length} but we got {tokens.shape[-1]}" + f" From {k}{type_key}" + ) + batch[f"{k}{type_key}"] = tokens + return batch + + def configure_functions(self) -> TrainerConfigureFunctionFuncOutput: + """ + The configure_functions function is responsible for configuring the functions that will be used in training. + It does this by first defining a function called function_configurations, which initializes the model parameters + and returns + them as a EasyDeLState object. The EasyDeLState object contains all the information needed to train or evaluate + on a batch of data, including: + :param self: Access the class attributes + :return: A TrainerConfigureFunctionFuncOutput object + + """ + + def initialize_state_function(): + initialized_parameters = self.model.init_weights( + jax.random.PRNGKey(0), + self.arguments.init_input_shape + ) + + if self.arguments.dtype == jnp.bfloat16: + initialized_parameters = self.model.to_bf16(initialized_parameters) + elif self.arguments.dtype == jnp.float16: + initialized_parameters = self.model.to_fp16(initialized_parameters) + + tx = self.tx + parameters = flax.core.freeze({"params": initialized_parameters}) + tx_init = copy.deepcopy(self.arguments.optimizer_kwargs) + + if self.rapture is not None: + lora_parameters = self.lora_parameters + if self.arguments.dtype == jnp.bfloat16: + lora_parameters = self.model.to_bf16(lora_parameters) + elif self.arguments.dtype == jnp.float16: + lora_parameters = self.model.to_fp16(lora_parameters) + + return EasyDeLState( + step=0, + apply_fn=self.lora_apply_fn, + params=lora_parameters, + tx=self.lora_tx, + opt_state=self.lora_opt_state, + tx_init=EasyDeLState.safe_dict(tx_init), + hyperparameters=EasyDeLState.create_hyperparameters(self.model.config.model_type), + module=self.lora_model, + module_config=self.model.config, + module_config_args=None, + ) + else: + return EasyDeLState.create( + tx=tx, + params=parameters, + apply_fn=self.model.__call__, + module_config=copy.deepcopy(self.model.config), + tx_init=tx_init, + hyperparameters=EasyDeLState.create_hyperparameters(self.model.config.model_type), + module=self.model, + module_config_args=None + ) + + def create_state_from_params_function(parameters): + if self.rapture is None: + return EasyDeLState.create( + tx=self.tx, + params=parameters, + apply_fn=self.model.__call__, + module_config=copy.deepcopy(self.model.config), + tx_init=copy.deepcopy(self.arguments.optimizer_kwargs), + hyperparameters=EasyDeLState.create_hyperparameters(self.model.config.model_type), + module=self.model, + module_config_args=None + ) + else: + return EasyDeLState( + step=0, + apply_fn=self.lora_apply_fn, + params=parameters, + tx=self.lora_tx, + opt_state=self.lora_opt_state, + tx_init=EasyDeLState.safe_dict(copy.deepcopy(self.arguments.optimizer_kwargs)), + hyperparameters=EasyDeLState.create_hyperparameters(self.model.config.model_type), + module=self.lora_model, + module_config=self.model.config, + module_config_args=None, + ) + + state_shape = jax.eval_shape(initialize_state_function) + state_partition_spec = match_partition_rules( + self.config.get_partition_rules( + fully_sharded_data_parallel=self.arguments.fully_sharded_data_parallel + ) if self.arguments.custom_rule is None else self.arguments.custom_rule, + state_shape + ) + create_sharded_state_from_params_function = pjit( + create_state_from_params_function, + in_shardings=(state_partition_spec.params,), + out_shardings=state_partition_spec, + donate_argnums=(0,) + ) + sharded_train_step_function = pjit( + create_orpo_step_function( + mode="train", + beta=self.beta, + concatenated_forward=self.concatenated_forward, + batch_partition_spec=self.arguments.step_partition_spec + ), + in_shardings=(state_partition_spec, PartitionSpec()), + out_shardings=(state_partition_spec, PartitionSpec(),), + + ) + + sharded_eval_step_function = pjit( + create_orpo_step_function( + mode="eval", + beta=self.beta, + concatenated_forward=self.concatenated_forward, + batch_partition_spec=self.arguments.step_partition_spec + ), + in_shardings=(state_partition_spec, PartitionSpec()), + out_shardings=(state_partition_spec, PartitionSpec(),), + + ) + + mesh = self.arguments.get_mesh() + self.arguments.ckpt_path_exists() + checkpoint_manager = self.arguments.get_streaming_checkpointer() + self.state_partition_spec = state_partition_spec + self.state_shape = state_shape + + return TrainerConfigureFunctionFuncOutput( + create_sharded_state_from_params_function=create_sharded_state_from_params_function, + sharded_train_step_function=sharded_train_step_function, + sharded_eval_step_function=sharded_eval_step_function, + mesh=mesh, + checkpoint_manager=checkpoint_manager, + initialize_state_function=initialize_state_function + ) + + def initialize_state( + self, + model_parameters: Optional[flax.core.FrozenDict] = None, + state: Optional[EasyDeLState] = None, + ) -> Tuple[EasyDeLState, Mapping[str, Callable], Mapping[str, Callable]]: + if model_parameters is None and state is None and self.rapture is None and self.checkpoint_path is None: + raise RuntimeError( + "You are passing `model_parameters=None`, `state=None`, and `checkpoint_path=None` and also you are not" + " using LoRA, if you are " + "Using LoRA make sure to pass parameters and Rapture Config correctly otherwise pass the " + "model_parameters or state." + ) + if model_parameters is None and state is None: + model_parameters = self.lora_parameters + with self.mesh: + shard_fns, gather_fns = make_shard_and_gather_fns( + self.state_partition_spec, + dtype_specs=self.dtype + ) + if state is not None: + sharded_state = state + params = sharded_state.params if not self.arguments.do_shard_fns else jax.tree_util.tree_map( + lambda f, x: f(x), + shard_fns.params, + sharded_state.params + ) + sharded_state.params = params + if sharded_state.opt_state is None: + prefix_print( + "Action", "Optimizer State is not Found!, initializing one." + ) + with jax.default_device(self.arguments.offload_device): + sharded_state = sharded_state.init_opt_state() + opt_state = sharded_state.opt_state if not self.arguments.do_shard_fns else jax.tree_util.tree_map( + lambda f, x: f(x), + shard_fns.opt_state, + sharded_state.opt_state + ) + sharded_state = sharded_state.replace( + opt_state=opt_state + ) + elif self.finetune: + + if model_parameters is None and self.checkpoint_path is not None: + prefix_print( + "Action", f"Loading Model From {self.checkpoint_path}" + ) + with jax.default_device(self.arguments.offload_device): + sharded_state = EasyDeLState.load_state( + verbose=self.arguments.verbose, + state_shard_fns=shard_fns, + init_optimizer_state=True, + checkpoint_path=self.checkpoint_path, + input_shape=self.arguments.init_input_shape, + config_kwargs=self.arguments.loaded_model_config_kwargs + ) + state_shape = jax.eval_shape(lambda: sharded_state) + state_partition_spec = match_partition_rules( + self.config.get_partition_rules( + fully_sharded_data_parallel=self.arguments.fully_sharded_data_parallel + ) if self.arguments.custom_rule is None else self.arguments.custom_rule, + state_shape + ) + sharded_train_step_function = pjit( + create_orpo_step_function( + mode="train", + beta=self.beta, + concatenated_forward=self.concatenated_forward, + batch_partition_spec=self.arguments.step_partition_spec + ), + in_shardings=(state_partition_spec, PartitionSpec()), + out_shardings=(state_partition_spec, PartitionSpec(),), + + ) + + sharded_eval_step_function = pjit( + create_orpo_step_function( + mode="eval", + beta=self.beta, + concatenated_forward=self.concatenated_forward, + batch_partition_spec=self.arguments.step_partition_spec + ), + in_shardings=(state_partition_spec, PartitionSpec()), + out_shardings=(state_partition_spec, PartitionSpec(),), + ) + + self.state_partition_spec = state_partition_spec + self.state_shape = state_shape + self.sharded_train_step_function = sharded_train_step_function + self.sharded_eval_step_function = sharded_eval_step_function + + if self.arguments.remove_ckpt_after_load: + os.remove(self.checkpoint_path) + elif model_parameters is not None and self.checkpoint_path is None: + prefix_print( + "Action", f"Sharding Passed Parameters" + ) + from flax.core import unfreeze + if not isinstance(model_parameters, flax.core.FrozenDict): + prefix_print( + "Warning", + "Model Parameters should be like FrozenDict({'params': params}) make sure to " + "pass as type FrozenDict in case of not getting UnExcepted Errors " + ) + + model_parameters = model_parameters if not self.arguments.do_shard_fns else jax.tree_util.tree_map( + lambda f, x: f(x), + shard_fns.params, + model_parameters, + ) + sharded_state = self.create_sharded_state_from_params_function(model_parameters) + elif model_parameters is not None and self.checkpoint_path is not None: + raise EasyDeLTimerError( + "You can't pass `model_parameters` and `checkpoint_path` at same time" + ) + else: + raise EasyDeLTimerError( + "You should pass `model_parameters` or `checkpoint_path` to trainer in order to load model" + ) + else: + sharded_state = self.initialize_state_function() + params = sharded_state.params if not self.arguments.do_shard_fns else jax.tree_util.tree_map( + lambda f, x: f(x), + shard_fns.params, + sharded_state.params + ) + sharded_state.params = params + + self.sharded_state = sharded_state + return sharded_state, shard_fns, gather_fns + + def _save_state( + self, + state: EasyDeLState, + gather_fns: Optional[Any | Mapping[str, Callable] | dict[Callable]], + milestone: bool = False + ) -> str: + step = int( + jax.device_get( + state.step + ) + ) + self.arguments.step_start_point if self.arguments.step_start_point is not None else int( + jax.device_get( + state.step + ) + ) + + checkpoint_dir = os.path.join(self.arguments.save_dir, self.arguments.model_name) + filename_extension = ".easy" + if self.arguments.save_total_limit: + checkpoint_files = glob(os.path.join(checkpoint_dir, f"*{filename_extension}")) + checkpoint_files.sort(key=os.path.getmtime) + for old_checkpoint in checkpoint_files[:-self.arguments.save_total_limit]: + os.remove(old_checkpoint) + termcolor.cprint(f"Removed old checkpoint: {old_checkpoint}", color="red", force_color=True) + + checkpoint_name = f"{self.arguments.model_name}-S{step}" + filename = f"{checkpoint_name}_{step}" if milestone else f"{checkpoint_name}" + filename += ".easy" + termcolor.cprint(f"Saving Model {filename}.", color="cyan", force_color=True) + state.save_state( + filename=filename, + checkpoint_dir=checkpoint_dir, + gather_fns=gather_fns, + float_dtype=self.dtype, + verbose=self.arguments.verbose, + save_optimizer=self.arguments.save_optimizer_state, + ) + return filename + + def initialize_trainer_utils(self): + """ + The initialize_trainer_utils function is responsible for initializing the following: + - wandb_runtime (if you use_wandb is True) + - timer object (for logging time taken by various functions) + - dataloader objects for training and evaluation data, along with max steps per epoch. + The configure_dataloader function accomplishes this task. + + :param self: Represent the instance of the class + :return: A tuple of functions + + """ + self.wandb_runtime = self.arguments.get_wandb_init() if self.arguments.use_wandb else None + self.timer = Timers( + use_wandb=False, + tensorboard_writer=self.arguments.get_board() + ) + + self.timer("configure dataloaders").start() + dataset_configurations = self.configure_dataloader() + self.dataloader_train = dataset_configurations.dataloader_train + self.max_training_steps = dataset_configurations.max_training_steps + self.dataloader_eval = dataset_configurations.dataloader_eval + self.max_evaluation_steps = dataset_configurations.max_evaluation_steps + + self.timer("configure dataloaders").stop() + + self.timer.log(["configure dataloaders"]) + + self.timer("configure Model, Optimizer, Scheduler and Config").start() + model_configurations = self.configure_model() + model = model_configurations.model + tx = model_configurations.tx + scheduler = model_configurations.scheduler + config = model_configurations.config + self.model = model + self.tx = tx + self.scheduler = scheduler + self.config = config + if self.rapture is not None: + lora_modules = self.rapture.apply_lora( + module=model, + parameters=self.arguments.rapture_config.parameters, + tx=tx, + ) + self.lora_parameters = lora_modules.lora_parameters + self.lora_apply_fn = lora_modules.lora_module.__call__ + self.lora_opt_state = lora_modules.lora_opt_state + self.lora_model = lora_modules.lora_module + self.lora_tx = lora_modules.lora_tx + + self.timer("configure Model, Optimizer, Scheduler and Config").stop() + self.timer.log(["configure Model, Optimizer, Scheduler and Config"]) + + self.timer("configure functions and sharding them").start() + + function_configurations = self.configure_functions() + self.create_sharded_state_from_params_function = ( + function_configurations.create_sharded_state_from_params_function + ) + self.sharded_train_step_function = function_configurations.sharded_train_step_function + self.sharded_eval_step_function = function_configurations.sharded_eval_step_function + self.mesh = function_configurations.mesh + self.checkpoint_manager = function_configurations.checkpoint_manager + self.initialize_state_function = function_configurations.initialize_state_function + self.timer("configure functions and sharding them").stop() + self.timer.log(["configure functions and sharding them"]) + + def create_collate_function( + self, + max_sequence_length: int, + truncation_mode: typing.Literal["keep_end", "keep_start"] = "keep_end", + ) -> Callable: + return self.data_collator + + def shard_states(self, state, rules): + with self.arguments.get_mesh(): + partition_spec = match_partition_rules(rules=rules, params=jax.eval_shape(lambda: state)) + + def _shard(x): + return x + + shard = pjit( + _shard, + in_shardings=(PartitionSpec(),), + out_shardings=partition_spec + ) + return shard(state) + + def configure_dataloader(self) -> TrainerConfigureDataloaderFuncOutput: + dataloader_train = self.get_train_dataloader() + max_evaluation_steps = None + dataloader_eval = None + + max_training_steps = self.arguments.num_train_epochs * len( + dataloader_train + ) if self.arguments.max_training_steps is None else self.arguments.max_training_steps + if self.eval_dataset is not None: + dataloader_eval = self.get_eval_dataloader(self.eval_dataset) + max_evaluation_steps = len(dataloader_eval) + return TrainerConfigureDataloaderFuncOutput( + dataloader_train=dataloader_train, # type:ignore + max_training_steps=max_training_steps, + dataloader_eval=dataloader_eval, + max_evaluation_steps=max_evaluation_steps + ) + + def _get_train_dataloader(self) -> tensorflow.data.Dataset: + + """ + The _get_train_dataloader function is used to create a tensorflow.data.Dataset object for the training dataset. + + :param self: Represent the instance of the class + :return: A dataloader object + """ + if self.train_dataset is None: + raise ValueError("Trainer: training requires a train_dataset.") + + train_dataset = self.train_dataset + data_collator = self.data_collator + + return tensorflow_datasets.as_numpy( + train_dataset.to_tf_dataset( + batch_size=self.arguments.total_batch_size, + collate_fn=data_collator, + num_workers=self.arguments.dataloader_num_workers, + shuffle=True, + drop_remainder=True + ) + ) + + def _get_eval_dataloader(self, eval_dataset: Optional[Dataset] = None) -> tensorflow.data.Dataset: + """ + Returns the evaluation [`~tensorflow.data.Dataset`]. + + Subclass and override this method if you want to inject some custom behavior. + + Args: + eval_dataset (`torch.utils.data.Dataset`, *optional*): + If provided, will override `self.eval_dataset`. If it is a [`~datasets.Dataset`], columns not accepted + by the `model.forward()` method are automatically removed. It must implement `__len__`. + """ + if eval_dataset is None and self.eval_dataset is None: + raise ValueError("Trainer: evaluation requires an eval_dataset.") + eval_dataset = eval_dataset if eval_dataset is not None else self.eval_dataset + + return tensorflow_datasets.as_numpy( + eval_dataset.to_tf_dataset( + batch_size=self.arguments.total_batch_size, + collate_fn=self.data_collator, + num_workers=self.arguments.dataloader_num_workers, + shuffle=False, + drop_remainder=True + ) + ) + + def get_train_dataloader( + self, + ) -> tensorflow.data.Dataset: + """ + Returns the training [`~tensorflow.data.Dataset`]. + """ + return self._get_train_dataloader() + + def get_eval_dataloader(self, eval_dataset: Optional[Dataset] = None) -> tensorflow.data.Dataset: + """ + Returns the evaluation [`~tensorflow.data.Dataset`]. + """ + if eval_dataset is None and self.eval_dataset is None: + raise ValueError("Trainer: evaluation requires an eval_dataset.") + eval_dataset = eval_dataset if eval_dataset is not None else self.eval_dataset + return self._get_eval_dataloader(eval_dataset=eval_dataset) + + def train( + self, + model_parameters: Optional[flax.core.FrozenDict] = None, + state: Optional[EasyDeLState] = None + ) -> ORPOTrainerOutput: + def get_layer_names(frozen_dict, prefix=""): + layer_names = {} + for key, value in frozen_dict.items(): + if isinstance(value, FrozenDict): + layer_names.update(get_layer_names(value, prefix=f"{prefix}_{key}")) + else: + layer_name = f"{prefix}_{key}".lstrip("/") + layer_names[layer_name] = value + return layer_names + + def count_model_parameters(_p): + termcolor.cprint( + f"Model Contain {sum(n.size for n in jax.tree_util.tree_flatten(flax.core.unfreeze(_p))[0]) / 1e9} " + f"Billion Parameters", + color="red", force_color=True + ) + + checkpoint_path = "SAVING_SKIPPED" + if self.arguments.performance_mode: + termcolor.cprint( + "Performance Mode is ON, we will ignore the Memory Tracking, WANDB Logging, and extra information " + "Process.", + color="red", + force_color=True + ) + sharded_state, shard_fns, gather_fns = self.initialize_state( + model_parameters=model_parameters, + state=state + ) + self.model_state = sharded_state + count_model_parameters(sharded_state.params) + with self.mesh: + with jax.default_device(jax.devices("cpu")[0]) if self.low_mem_usage else leave_alone_context_manager(): + dir_prefix: str = "/dev/shm" if sys.platform != "win32" else "." + checkpoint_path = "SAVING_SKIPPED" + + pbar = tqdm(total=self.max_training_steps) + pbar.set_description("Training") + current_step = self.model_state.step.tolist() if isinstance( + self.model_state.step, + jax.Array + ) else self.model_state.step + + loss_sum = None + + try: + for epoch_index in range(self.arguments.num_train_epochs): + for batch in self.dataloader_train: + current_step += 1 + if self.arguments.step_start_point > current_step: + ... + elif current_step < self.max_training_steps: + time_start = time.time() + + self.model_state, outputs = self.sharded_train_step_function( + self.model_state, + batch + ) + total_time = time.time() - time_start + (loss, metrics) = outputs.loss, outputs.metrics + + loss_sum = loss.tolist() if loss_sum is None else loss_sum + loss + + train_metrics = { + "train/loss": loss.tolist(), + "train/mean_loss": loss_sum / (current_step - self.arguments.step_start_point), + "train/learning_rate": self.scheduler( + jax.device_get(self.model_state.step)).tolist(), + "train/step": current_step, + "train/step_time": total_time, + "train/perplexity": jnp.exp(loss).tolist(), + "train/epoch": epoch_index + } + train_metrics.update(metrics) + log_metrics = copy.deepcopy(train_metrics) + train_metrics.update(self.arguments.captured_memory) + if self.arguments.use_wandb: + with jax.spmd_mode("allow_all"): + self.wandb_runtime.log( + train_metrics + ) + pbar.update(1) + pbar.set_postfix(**{k.replace("train/", ""): v for k, v in log_metrics.items()}) + else: + break + except KeyboardInterrupt: + termcolor.cprint( + "KeyboardInterrupt At training model Will return Current State of the Model with Parameters.", + color="cyan", + force_color=True + ) + + except EasyDeLTimerError: + termcolor.cprint( + "Training reached out maximum training Time Killing training Process " + "and Will return Current State of the Model with Parameters.", + color="cyan", + force_color=True + ) + + if self.arguments.merge_lora_rapture_parameters and self.rapture is not None: + print( + termcolor.colored( + "Info : ", color="red", force_color=True + ), + termcolor.colored( + "Merging LoRA Parameters.", color="white", force_color=True + ) + ) + self.model_state = self.model_state.replace( + params=self.rapture.merge_parameters(self.model_state.params) + ) + + shard_fns, gather_fns = make_shard_and_gather_fns( + partition_specs=match_partition_rules( + rules=self.model_state.module.config.get_partition_rules( + self.arguments.fully_sharded_data_parallel + ), + params=jax.eval_shape(lambda: self.model_state) + ), + dtype_specs=self.arguments.dtype + ) + output = ORPOTrainerOutput( + state=self.model_state, + mesh=self.mesh, + shard_fns=shard_fns, + gather_fns=gather_fns, + checkpoint_manager=self.checkpoint_manager, + ) + if self.arguments.save_steps is None and self.arguments.do_last_save: + shard_fns, gather_fns = make_shard_and_gather_fns( + match_partition_rules( + self.config.get_partition_rules( + fully_sharded_data_parallel=self.arguments.fully_sharded_data_parallel + ) if self.arguments.custom_rule is None else self.arguments.custom_rule, + jax.eval_shape(lambda: self.model_state) + ), + dtype_specs=self.dtype + ) # You have to re-init the new shard and gather functions in order to be able to skip LoRA weight + # crashing errors and saving errors + filename = self._save_state( + state=self.model_state, + gather_fns=gather_fns + ) + checkpoint_path = f"{str(self.arguments.get_path())}/{filename}" + + if self.arguments.do_eval: + for _ in self.eval( + self.model_state + ): + ... + + output.checkpoint_path = checkpoint_path + output.last_save_file_name = filename + wandb.finish() + + return output + + def eval(self, model_state: EasyDeLState) -> typing.Iterator[dict]: + """Evaluate the Given Model State and yield the eval metrics""" + assert self.eval_dataset is not None, "`dataloader_eval` is required by evaluator function." + with self.mesh: + pbar = tqdm(total=self.max_evaluation_steps) + pbar.set_description("Evaluating") + current_step = 0 + loss_sum = None + try: + for batch in self.dataloader_eval: + current_step += 1 + time_start = time.time() + for key in self.arguments.ids_to_pop_from_dataset: + _ = batch.pop(key, None) + for key in list(batch.keys()): + if not ( + key.endswith("_input_ids") + or key.endswith("_attention_mask") + or key.endswith("_labels") + ): + _ = batch.pop(key, None) + + _, outputs = self.sharded_eval_step_function( + model_state, + batch + ) + total_time = time.time() - time_start + ( + loss, metrics + ) = outputs.loss, outputs.metrics + + loss_sum = loss.tolist() if loss_sum is None else loss_sum + loss + + eval_metrics = { + "eval/loss": loss.tolist(), + "eval/mean_loss": loss_sum / (current_step - self.arguments.step_start_point), + "eval/step": current_step, + "eval/step_time": total_time, + "eval/perplexity": jnp.exp(loss).tolist(), + } + eval_metrics.update(metrics) + log_metrics = copy.deepcopy(eval_metrics) + eval_metrics.update(self.arguments.captured_memory) + if self.arguments.use_wandb: + with jax.spmd_mode("allow_all"): + self.wandb_runtime.log( + eval_metrics + ) + + pbar.update(1) + pbar.set_postfix(**{k.replace("eval/", ""): v for k, v in log_metrics.items()}) + yield eval_metrics + except KeyboardInterrupt: + termcolor.cprint( + "KeyboardInterrupt At Evaluation model Will return Nothing and just pass.", + color="cyan", + force_color=True + ) + + def __repr__(self): + + """ + The __repr__ function is used to generate a string representation of an object. + This function should return a string that can be parsed by the Python interpreter + to recreate the object. The __repr__ function is called when you use print() on an + object, or when you type its name in the REPL. + + :param self: Refer to the instance of the class + :return: A string representation of the object + """ + string = f"{self.__class__.__name__}(\n" + for k, v in self.__dict__.items(): + if not k.startswith("_"): + try: + repr_src = f"\t{k} : " + v.__str__().replace("\n", "\n\t") + "\n" + string += repr_src if len(repr_src) < 350 else f"\t{k} : " + f"{v.__class__.__name__}(...)" + "\n" + except TypeError: + repr_src = f"\t{k} : " + "EasyDeLReadingError" + "\n" + string += repr_src if len(repr_src) < 350 else f"\t{k} : " + f"{v.__class__.__name__}(...)" + "\n" + + return string + ")" + + def __str__(self): + + """ + The __str__ function is called when you use the print function or when str() is used. + It should return a string representation of the object. + + :param self: Refer to the instance of the class + :return: The object's string representation + """ + return self.__repr__() diff --git a/src/python/easydel/trainer/training_configurations.py b/src/python/easydel/trainer/training_configurations.py index 1d73bde80..4725ad3a6 100644 --- a/src/python/easydel/trainer/training_configurations.py +++ b/src/python/easydel/trainer/training_configurations.py @@ -122,85 +122,143 @@ def __init__( loaded_model_config_kwargs: Optional[dict] = None, **kwargs ): - """ -The __init__ function is called when the class is instantiated. -It sets up the attributes of an object, which are sometimes called fields or properties. -The __init__ function can accept arguments, just like a normal function. - -:param self: Represent the instance of the class -:param model_name: str: Specify the model name -:param num_train_epochs: int: Set the number of epochs for training -:param model_huggingface_repo_id: Optional[str]: Load a pretrained model from the huggingface model hub -:param model_class: Optional[EasyDeLFlaxPretrainedModel]: Pass a model class to the trainer -:param total_batch_size: int: Set the batch size of the model -:param max_training_steps: Optional[int]: Set the maximum total number of training steps across all epochs -:param max_evaluation_steps: Optional[int]: Set the maximum number of steps to evaluate for -:param optimizer: AVAILABLE_OPTIMIZERS: Specify the optimizer used to train the model -:param scheduler: AVAILABLE_SCHEDULERS: Set the learning rate scheduler -:param learning_rate: Union[int, float] : Set the learning rate for the optimizer -:param learning_rate_end: Optional[float]: Set the learning rate at the end of training -:param gradient_accumulation_steps: int: Accumulate gradients over multiple batches -:param weight_decay: float: Specify the weight decay to be used by the optimizer -:param label_smoothing_factor: float: Set the label smoothing factor to be used by the loss function -:param z_loss: float: Set the z loss factor to be used by the loss function -:param gradient_checkpointing: AVAILABLE_GRADIENT_CHECKPOINTS: Determine how to use gradient checkpointing -:param max_sequence_length: Optional[int]: Set the maximum length of the input sequence -:param sharding_array: Union[tuple,int]: Specify the mesh of devices to use for training -:param is_fine_tuning: bool: Tell the model whether or not to initialize the weights of -:param do_train: bool: Indicate whether to train the model or not -:param do_eval: bool: Determine whether to run evaluation on the validation set after training -:param do_test: Optional[bool]: Determine if the model should be tested -:param train_on_inputs: bool: Use input_ids instead of labels, overrides ignored (-100) tokens in the labels -:param backend: Optional[str]: Specify the backend of jax -:param extra_optimizer_kwargs: dict: Pass extra arguments to the optimizer -:param save_steps: Optional[int]: Save the model after every n steps -:param save_dir: str: Define the directory where the checkpoints will be saved -:param save_total_limit: int: Set the maximum number of checkpoints to keep, older checkpoints will be deleted -:param dtype: jnp.dtype: Set the dtype of the model parameters -:param param_dtype: jnp.dtype: Specify the data type of the model parameters -:param fully_sharded_data_parallel: bool: Determine if the model should be fully fsdp or not -:param use_wandb: bool: Enable or disable the wandb logging -:param custom_rule: Mapping[str, PartitionSpec]: Specify the partitioning rules of the model -:param extra_configs: Optional[dict]: Pass extra configurations to the model class -:param ids_to_pop_from_dataset: Optional[list]: Remove some of the ids from the dataset -:param remove_ckpt_after_load: bool: Remove the checkpoint after loading it -:param configs_to_initialize_model_class: Optional[dict]: Pass extra configurations to the model class -:param do_last_save: bool: Save the model after training is complete -:param model_parameters: Optional[dict]: Pass the model parameters to the model class -:param do_shard_fns: bool: Shard the model functions across devices -:param track_memory: bool: Track the memory usage of the model -:param loss_re_mat: str: Specify the regular expression to match the loss function name -:param loss_chunk: int: Chunk the loss to avoid memory overflow -:param truncation_mode: typing.Literal["keep_end", "keep_start"]: Determine if the input is left padded or not and - which side of the array should remain in case of using maximum padding. -:param warmup_steps: int: Specify the number of steps to warm up the learning rate -:param init_input_shape: Tuple[int, int]: Initialize the model with a shape that is not (batch_size, length) -:param step_partition_spec: PartitionSpec: Partition the model for training -:param training_time: Optional[str]: Set a time limit for the training process -:param dataloader_num_workers: Optional[int]: Set the number of workers used by pytorch's -:param dataloader_pin_memory: Optional[bool]: Pin the memory of the dataloader -:param jax_distributed_config: Optional[dict]: Configure the jax distributed backend -:param log_all_workers: bool: Log all workers in wandb, -:param wandb_entity: Optional[str]: Specify the entity to use when logging to weights & biases -:param save_optimizer_state : bool: when ever to save optimizer state and other args in checkpoint -:param step_start_point: Optional[int]: start training from given step for example instead of starting training from - step 0 it will start from 20000 and leave the data behind -:param verbose: bool: when ever to turn verbose mode of or on -:param offload_device: jax.Device: device to be used to offload parameters on -:param rapture_config: Optional[EasyDeLXRaptureConfig]: LoRA Config for models -:param merge_lora_rapture_parameters: bool: whenever to merge lora parameters with original parameters before saving -:param state_apply_fn_kwarguments_to_model: Optional[dict]: state_apply_fn_kwarguments_to_model is a dictionary that - be used to apply the parameters and extra things that you want to deliver to model. -:param remove_unused_columns: bool: when ever to remove the unused data columns from dataset -:param force_batch_and_gradient_accumulation_steps_calculation: bool: whether to force batch and gradient to be - applied as total batch_size (e.g total_batch_size = total_batch_size * gradient_accumulation_steps be applied) -:param performance_mode: bool: whether to optimize the whole training process this will cut off some logging options - and optimize training process. -:param neftune_noise_alpha: Optional[float]: If not `None`, this will activate NEFTune noise embeddings. This has been - proven to drastically improve model performances for instruction fine-tuning. -:param loaded_model_config_kwargs: Optional[dict]: config key arguments to be passed to the model while being loaded -from checkpoint -:param **kwargs: Pass keyword, variable-length argument list + """The __init__ function is called when the class is instantiated. + It sets up the attributes of an object, which are sometimes called fields or properties. + The __init__ function can accept arguments, just like a normal function. + + Args: + self: Represent the instance of the class + model_name: str: Specify the model name + num_train_epochs: int: Set the number of epochs for training + model_huggingface_repo_id: Optional[str]: Load a pretrained + model from the huggingface model hub + model_class: Optional[EasyDeLFlaxPretrainedModel]: Pass a + model class to the trainer + total_batch_size: int: Set the batch size of the model + max_training_steps: Optional[int]: Set the maximum total + number of training steps across all epochs + max_evaluation_steps: Optional[int]: Set the maximum number + of steps to evaluate for + optimizer: AVAILABLE_OPTIMIZERS: Specify the optimizer used + to train the model + scheduler: AVAILABLE_SCHEDULERS: Set the learning rate + scheduler + learning_rate: Union[int, float] : Set the learning rate for + the optimizer + learning_rate_end: Optional[float]: Set the learning rate at + the end of training + gradient_accumulation_steps: int: Accumulate gradients over + multiple batches + weight_decay: float: Specify the weight decay to be used by + the optimizer + label_smoothing_factor: float: Set the label smoothing + factor to be used by the loss function + z_loss: float: Set the z loss factor to be used by the loss + function + gradient_checkpointing: AVAILABLE_GRADIENT_CHECKPOINTS: + Determine how to use gradient checkpointing + max_sequence_length: Optional[int]: Set the maximum length + of the input sequence + sharding_array: Union[tuple,int]: Specify the mesh of + devices to use for training + is_fine_tuning: bool: Tell the model whether or not to + initialize the weights of + do_train: bool: Indicate whether to train the model or not + do_eval: bool: Determine whether to run evaluation on the + validation set after training + do_test: Optional[bool]: Determine if the model should be + tested + train_on_inputs: bool: Use input_ids instead of labels, + overrides ignored (-100) tokens in the labels + backend: Optional[str]: Specify the backend of jax + extra_optimizer_kwargs: dict: Pass extra arguments to the + optimizer + save_steps: Optional[int]: Save the model after every n + steps + save_dir: str: Define the directory where the checkpoints + will be saved + save_total_limit: int: Set the maximum number of checkpoints + to keep, older checkpoints will be deleted + dtype: jnp.dtype: Set the dtype of the model parameters + param_dtype: jnp.dtype: Specify the data type of the model + parameters + fully_sharded_data_parallel: bool: Determine if the model + should be fully fsdp or not + use_wandb: bool: Enable or disable the wandb logging + custom_rule: Mapping[str, PartitionSpec]: Specify the + partitioning rules of the model + extra_configs: Optional[dict]: Pass extra configurations to + the model class + ids_to_pop_from_dataset: Optional[list]: Remove some of the + ids from the dataset + remove_ckpt_after_load: bool: Remove the checkpoint after + loading it + configs_to_initialize_model_class: Optional[dict]: Pass + extra configurations to the model class + do_last_save: bool: Save the model after training is + complete + model_parameters: Optional[dict]: Pass the model parameters + to the model class + do_shard_fns: bool: Shard the model functions across devices + track_memory: bool: Track the memory usage of the model + loss_re_mat: str: Specify the regular expression to match + the loss function name + loss_chunk: int: Chunk the loss to avoid memory overflow + truncation_mode: typing.Literal["keep_end", "keep_start"]: + Determine if the input is left padded or not and which + side of the array should remain in case of using maximum + padding. + warmup_steps: int: Specify the number of steps to warm up + the learning rate + init_input_shape: Tuple[int, int]: Initialize the model with + a shape that is not (batch_size, length) + step_partition_spec: PartitionSpec: Partition the model for + training + training_time: Optional[str]: Set a time limit for the + training process + dataloader_num_workers: Optional[int]: Set the number of + workers used by pytorch's + dataloader_pin_memory: Optional[bool]: Pin the memory of the + dataloader + jax_distributed_config: Optional[dict]: Configure the jax + distributed backend + log_all_workers: bool: Log all workers in wandb, + wandb_entity: Optional[str]: Specify the entity to use when + logging to weights & biases + save_optimizer_state: bool: when ever to save optimizer + state and other args in checkpoint + step_start_point: Optional[int]: start training from given + step for example instead of starting training from step + 0 it will start from 20000 and leave the data behind + verbose: bool: when ever to turn verbose mode of or on + offload_device: jax.Device: device to be used to offload + parameters on + rapture_config: Optional[EasyDeLXRaptureConfig]: LoRA Config + for models + merge_lora_rapture_parameters: bool: whenever to merge lora + parameters with original parameters before saving + state_apply_fn_kwarguments_to_model: Optional[dict]: + state_apply_fn_kwarguments_to_model is a dictionary that + be used to apply the parameters and extra things that + you want to deliver to model. + remove_unused_columns: bool: when ever to remove the unused + data columns from dataset + force_batch_and_gradient_accumulation_steps_calculation: + bool: whether to force batch and gradient to be applied + as total batch_size (e.g total_batch_size = + total_batch_size * gradient_accumulation_steps be + applied) + performance_mode: bool: whether to optimize the whole + training process this will cut off some logging options + and optimize training process. + neftune_noise_alpha: Optional[float]: If not `None`, this + will activate NEFTune noise embeddings. This has been + proven to drastically improve model performances for + instruction fine-tuning. + loaded_model_config_kwargs: Optional[dict]: config key + arguments to be passed to the model while being loaded + **kwargs: Pass keyword, variable-length argument list + from checkpoint """ super().__init__() @@ -368,15 +426,16 @@ def __call__(self): return {k: v for k, v in self.__dict__.items()} def get_meter_dict(self): - """ - The get_meter_dict function is used to return a dictionary of the hyperparameters. + """The get_meter_dict function is used to return a dictionary of the hyperparameters. The function iterates through all the attributes in the class and returns a dictionary with the key as "hyperparameters/{k}" and value as v for each attribute k,v in self.__dict__ if it is an instance of int, float, str, bool or torch.Tensor. - :param self: Represent the instance of the class - :return: A dictionary of hyperparameters + Args: + self: Represent the instance of the class + Returns: + A dictionary of hyperparameters """ return { f"hyperparameters/{k}": v for k, v in self.__dict__.items() if @@ -384,13 +443,14 @@ def get_meter_dict(self): } def get_wandb_init(self) -> Run | RunDisabled | None: - """ - The get_wandb_init function is a helper function that returns the wandb.init() call with + """The get_wandb_init function is a helper function that returns the wandb.init() call with the project name, config object, and tags set to appropriate values for this model. - :param self: Pass the class instance to the function - :return: A wandb or None + Args: + self: Pass the class instance to the function + Returns: + A wandb or None """ return wandb.init( project=f"EasyDeL-{self.model_name}", @@ -423,8 +483,7 @@ def string_func(it_self): return string def get_path(self): - """ - The get_path function returns a pathlib.Path object, which is a class that + """The get_path function returns a pathlib.Path object, which is a class that represents file paths and provides methods for interacting with the files at those paths. The get_path function takes no arguments and returns an instance of the Path class initialized with two arguments: self.save_dir (a string) and @@ -432,38 +491,42 @@ def get_path(self): store our model checkpoints, while the model name will be used to create unique filenames for each checkpoint. - :param self: Represent the instance of the class - :return: A pathlib + Args: + self: Represent the instance of the class + Returns: + A pathlib """ return pathlib.Path( self.save_dir, self.model_name ) def ckpt_path_exists(self): - """ - The ckpt_path_exists function checks to see if the path exists. If it does not, then it creates a new directory. + """The ckpt_path_exists function checks to see if the path exists. If it does not, then it creates a new directory. - :param self: Represent the instance of the class - :return: A path + Args: + self: Represent the instance of the class + Returns: + A path """ path = self.get_path() if not path.exists(): path.mkdir(parents=True) def get_mesh(self): - """ - The get_mesh function is used to create a mesh object that can be used + """The get_mesh function is used to create a mesh object that can be used to define the geometry of the device. The mesh object contains two arrays: a list of vertices and a list of faces. Each face is defined by three indices, which correspond to three vertices in the vertex array. The get_mesh function is called when creating an instance of DeviceGeometry, which is then passed into an instance of DeviceSimulation. - :param self: Refer to the object itself - :return: A mesh object with the device array shape and the mesh names + Args: + self: Refer to the object itself + Returns: + A mesh object with the device array shape and the mesh names """ return Mesh( create_device_mesh( @@ -489,15 +552,16 @@ def get_optimizer_and_scheduler( ) def get_streaming_checkpointer(self): - """ - The get_streaming_checkpointer function is used to save the model's weights. + """The get_streaming_checkpointer function is used to save the model's weights. The streaming checkpointer saves the model's weights in a file called "checkpoint" and then saves a copy of that file with an incrementing number appended to it (e.g., checkpoint_001, checkpoint_002, etc.). This allows you to keep multiple versions of your trained models. - :param self: Represent the instance of the class - :return: A CheckpointManager object + Args: + self: Represent the instance of the class + Returns: + A CheckpointManager object """ return CheckpointManager( os.path.join(self.save_dir, self.model_name), @@ -506,15 +570,16 @@ def get_streaming_checkpointer(self): ) def get_board(self): - """ - The get_board function is a helper function that returns a TensorBoard object. + """The get_board function is a helper function that returns a TensorBoard object. The TensorBoard object is used to log the training and validation loss, as well as the accuracy of the model during training. The get_board function takes no arguments, and returns an instance of torch.utils.tensorboard SummaryWriter class. - :param self: Represent the instance of the class - :return: A summary-writer object + Args: + self: Represent the instance of the class + Returns: + A summary-writer object """ return torch.utils.tensorboard.SummaryWriter( log_dir=str(self.get_path()), diff --git a/src/python/easydel/trainer/utils.py b/src/python/easydel/trainer/utils.py index 5bc82170d..77d71982c 100644 --- a/src/python/easydel/trainer/utils.py +++ b/src/python/easydel/trainer/utils.py @@ -170,8 +170,7 @@ def _collate_batch(examples, tokenizer, pad_to_multiple_of: Optional[int] = None def tolist(x): - """ - from HF + """from HF Args: x: @@ -186,8 +185,7 @@ def tolist(x): class DataCollatorForCompletionOnlyLM: - """ - Data collator used for completion tasks. It ensures that all the tokens of the labels are set to an 'ignore_index' + """Data collator used for completion tasks. It ensures that all the tokens of the labels are set to an 'ignore_index' when they do not come from the assistant. This ensures that the loss is only calculated on the completion made by the assistant. """ @@ -270,9 +268,7 @@ def _whole_word_mask(self, input_tokens: List[str], max_predictions=512): return mask_labels def jax_mask_tokens(self, inputs: Any, special_tokens_mask: Optional[Any] = None) -> Tuple[Any, Any]: - """ - Prepare masked tokens inputs/labels for masked language modeling: 80% MASK, 10% random, 10% original. - """ + """Prepare masked tokens inputs/labels for masked language modeling: 80% MASK, 10% random, 10% original.""" labels = np.copy(inputs) probability_matrix = np.full(labels.shape, 0.15) if special_tokens_mask is None: @@ -434,8 +430,7 @@ def format_dataset(examples): def instructions_formatting_function(tokenizer: AutoTokenizer): - r""" - from TRL + r"""from TRL return a callable function that takes in an "instructions" dataset and returns a formatted dataset, based on the tokenizer apply chat template to the dataset """ @@ -463,8 +458,7 @@ def format_dataset(examples): def get_formatting_func_from_dataset( dataset: Union[Dataset, "ConstantLengthDataset"], tokenizer: AutoTokenizer # type: ignore ) -> Optional[Callable]: - r""" - from TRL + r"""from TRL Finds the correct formatting function based on the dataset structure. Currently supported datasets are: - `ChatML` with [{"role": str, "content": str}] - `instruction` with [{"prompt": str, "completion": str}] diff --git a/src/python/easydel/trainer/vision_causal_language_model_trainer/__init__.py b/src/python/easydel/trainer/vision_causal_language_model_trainer/__init__.py index c7fbe5c66..1bd9d261a 100644 --- a/src/python/easydel/trainer/vision_causal_language_model_trainer/__init__.py +++ b/src/python/easydel/trainer/vision_causal_language_model_trainer/__init__.py @@ -1,15 +1,15 @@ -from .modelling_output import VisionCausalLMTrainerOutput as VisionCausalLMTrainerOutput -from .fwd_bwd_functions import ( - create_vision_casual_language_model_train_step as create_vision_casual_language_model_train_step, - create_vision_casual_language_model_evaluation_step as create_vision_casual_language_model_evaluation_step, - VisionCausalLanguageModelStepOutput as VisionCausalLanguageModelStepOutput -) -from .vision_causal_language_model_trainer import VisionCausalLanguageModelTrainer as VisionCausalLanguageModelTrainer - -__all__ = ( - "create_vision_casual_language_model_train_step", - "create_vision_casual_language_model_evaluation_step", - "VisionCausalLanguageModelStepOutput", - "VisionCausalLanguageModelTrainer", - "VisionCausalLMTrainerOutput" -) +from .modelling_output import VisionCausalLMTrainerOutput as VisionCausalLMTrainerOutput +from .fwd_bwd_functions import ( + create_vision_casual_language_model_train_step as create_vision_casual_language_model_train_step, + create_vision_casual_language_model_evaluation_step as create_vision_casual_language_model_evaluation_step, + VisionCausalLanguageModelStepOutput as VisionCausalLanguageModelStepOutput +) +from .vision_causal_language_model_trainer import VisionCausalLanguageModelTrainer as VisionCausalLanguageModelTrainer + +__all__ = ( + "create_vision_casual_language_model_train_step", + "create_vision_casual_language_model_evaluation_step", + "VisionCausalLanguageModelStepOutput", + "VisionCausalLanguageModelTrainer", + "VisionCausalLMTrainerOutput" +) diff --git a/src/python/easydel/trainer/vision_causal_language_model_trainer/fwd_bwd_functions.py b/src/python/easydel/trainer/vision_causal_language_model_trainer/fwd_bwd_functions.py index f0263e66d..d2b08c945 100644 --- a/src/python/easydel/trainer/vision_causal_language_model_trainer/fwd_bwd_functions.py +++ b/src/python/easydel/trainer/vision_causal_language_model_trainer/fwd_bwd_functions.py @@ -1,156 +1,165 @@ -from fjformer.func.loss_func import cross_entropy_loss_and_accuracy - -import jax -from jax.sharding import PartitionSpec -from jax import numpy as jnp -from fjformer import ( - with_sharding_constraint -) -import chex -from ...etils.easystate import EasyDeLState -from flax.struct import dataclass - - -@dataclass -class VisionCausalLanguageModelStepOutput: - loss: chex.Array - text_loss: chex.Array - text_accuracy: chex.Array - vision_loss: chex.Array - vision_accuracy: chex.Array - - -def create_vision_casual_language_model_train_step(partition_spec=PartitionSpec(("dp", "fsdp"), "sp")): - """ - The create_vision_casual_language_model_train_step function is a training step function that takes in the current - state of the model,and a batch of data. It then calculates the loss and accuracy for this batch, and returns - an updated state with new parameters based on these gradients. - - :param partition_spec: Specify which devices the model will be split across - :return: A casual_language_model_train_step function that takes in the current state of the model, - - """ - - def vision_casual_language_model_train_step(state, batch) -> [ - EasyDeLState, - chex.Array, - VisionCausalLanguageModelStepOutput - ]: - """ - The vision_casual_language_model_train_step function is a training step function that takes in the current state - of the model and a batch of data. It then calculates the loss and accuracy for this batch, - and returns an updated state with new parameters based on these gradients. - - :param state: Store the model parameters - :param batch: Pass the data to the model - :return: A tuple of (state, loss, VisionCausalLanguageModelStepOutput) - - """ - batch = with_sharding_constraint(batch, partition_spec) - - def calculate_loss(params): - labels = batch.get("labels", None) - if labels is None: - labels = batch["input_ids"][..., 1:] - else: - labels = labels[..., 1:] - label_vision_mask = batch.pop("label_vision_mask") - - model_outputs = state.apply_fn(params=params, **batch, return_dict=True) - logits = model_outputs.logits - aux_loss = getattr(model_outputs, "aux_loss", None) - - vision_loss, vision_accuracy = cross_entropy_loss_and_accuracy( - logits[:, :-1, :], - jnp.where(label_vision_mask, labels, 0), - batch["attention_mask"].astype(jnp.float32)[:, 1:] * label_vision_mask - ) - text_loss, text_accuracy = cross_entropy_loss_and_accuracy( - logits[:, :-1, :], - jnp.where(label_vision_mask, 0, labels), - batch["attention_mask"].astype(jnp.float32)[:, 1:] * (1.0 - label_vision_mask) - ) - - loss = 0.5 * (vision_loss + text_loss + (aux_loss if aux_loss is not None else 0.)) - - return loss, VisionCausalLanguageModelStepOutput( - loss=loss, - text_accuracy=text_accuracy, - vision_accuracy=vision_accuracy, - text_loss=text_loss, - vision_loss=vision_loss - ) - - grad_fn = jax.value_and_grad(calculate_loss, has_aux=True) - (loss__, metrics), grad = grad_fn(state.params) - state = state.apply_gradients(grads=grad) - return state, loss__, metrics - - return vision_casual_language_model_train_step - - -def create_vision_casual_language_model_evaluation_step(partition_spec=PartitionSpec(("dp", "fsdp"), "sp")): - """ - The create_vision_casual_language_model_evaluation_step function is used to create a function that calculates the - loss and accuracy of a model. It takes in a set of parameters, which are then passed into the state.apply_fn function - to generate logits for each token in the batch. The cross entropy loss and accuracy are then calculated from these - logits. - - :param partition_spec: Specify the partitioning of the model parameters - :return: A function that can be used to calculate the loss and accuracy of a model - - """ - - def vision_casual_language_model_evaluation_step(state, batch) -> [ - EasyDeLState, - chex.Array, - VisionCausalLanguageModelStepOutput - ]: - """ - The vision_casual_language_model_train_step function is a training step function that takes in the current state - of the model and a batch of data. It then calculates the loss and accuracy for this batch, - and returns an updated state with new parameters based on these gradients. - - :param state: Store the model parameters - :param batch: Pass the data to the model - :return: A tuple of (state, loss, VisionCausalLanguageModelStepOutput) - - """ - batch = with_sharding_constraint(batch, partition_spec) - - def calculate_loss(params): - labels = batch.get("labels", None) - if labels is None: - labels = batch["input_ids"][..., 1:] - else: - labels = labels[..., 1:] - label_vision_mask = batch.pop("label_vision_mask") - model_outputs = state.apply_fn(params=params, **batch, return_dict=True) - logits = model_outputs.logits - aux_loss = getattr(model_outputs, "aux_loss", None) - - vision_loss, vision_accuracy = cross_entropy_loss_and_accuracy( - logits[:, :-1, :], - jnp.where(label_vision_mask, labels, 0), - batch["attention_mask"].astype(jnp.float32)[:, 1:] * label_vision_mask - ) - text_loss, text_accuracy = cross_entropy_loss_and_accuracy( - logits[:, :-1, :], - jnp.where(label_vision_mask, 0, labels), - batch["attention_mask"].astype(jnp.float32)[:, 1:] * (1.0 - label_vision_mask) - ) - - loss = 0.5 * (vision_loss + text_loss + (aux_loss if aux_loss is not None else 0.)) - - return loss, VisionCausalLanguageModelStepOutput( - loss=loss, - text_accuracy=text_accuracy, - vision_accuracy=vision_accuracy, - text_loss=text_loss, - vision_loss=vision_loss - ) - - loss__, metrics = calculate_loss(state.params) - return loss__, metrics - - return vision_casual_language_model_evaluation_step +from fjformer.func.loss_func import cross_entropy_loss_and_accuracy + +import jax +from jax.sharding import PartitionSpec +from jax import numpy as jnp +from fjformer import ( + with_sharding_constraint +) +import chex +from ...etils.easystate import EasyDeLState +from flax.struct import dataclass + + +@dataclass +class VisionCausalLanguageModelStepOutput: + loss: chex.Array + text_loss: chex.Array + text_accuracy: chex.Array + vision_loss: chex.Array + vision_accuracy: chex.Array + + +def create_vision_casual_language_model_train_step(partition_spec=PartitionSpec(("dp", "fsdp"), "sp")): + """The create_vision_casual_language_model_train_step function is a training step function that takes in the current + state of the model,and a batch of data. It then calculates the loss and accuracy for this batch, and returns + an updated state with new parameters based on these gradients. + + Args: + partition_spec: Specify which devices the model will be split + across + + Returns: + A casual_language_model_train_step function that takes in the + current state of the model, + """ + + def vision_casual_language_model_train_step(state, batch) -> [ + EasyDeLState, + chex.Array, + VisionCausalLanguageModelStepOutput + ]: + """The vision_casual_language_model_train_step function is a training step function that takes in the current state + of the model and a batch of data. It then calculates the loss and accuracy for this batch, + and returns an updated state with new parameters based on these gradients. + + Args: + state: Store the model parameters + batch: Pass the data to the model + + Returns: + A tuple of (state, loss, + VisionCausalLanguageModelStepOutput) + """ + batch = with_sharding_constraint(batch, partition_spec) + + def calculate_loss(params): + labels = batch.get("labels", None) + if labels is None: + labels = batch["input_ids"][..., 1:] + else: + labels = labels[..., 1:] + label_vision_mask = batch.pop("label_vision_mask") + + model_outputs = state.apply_fn(params=params, **batch, return_dict=True) + logits = model_outputs.logits + aux_loss = getattr(model_outputs, "aux_loss", None) + + vision_loss, vision_accuracy = cross_entropy_loss_and_accuracy( + logits[:, :-1, :], + jnp.where(label_vision_mask, labels, 0), + batch["attention_mask"].astype(jnp.float32)[:, 1:] * label_vision_mask + ) + text_loss, text_accuracy = cross_entropy_loss_and_accuracy( + logits[:, :-1, :], + jnp.where(label_vision_mask, 0, labels), + batch["attention_mask"].astype(jnp.float32)[:, 1:] * (1.0 - label_vision_mask) + ) + + loss = 0.5 * (vision_loss + text_loss + (aux_loss if aux_loss is not None else 0.)) + + return loss, VisionCausalLanguageModelStepOutput( + loss=loss, + text_accuracy=text_accuracy, + vision_accuracy=vision_accuracy, + text_loss=text_loss, + vision_loss=vision_loss + ) + + grad_fn = jax.value_and_grad(calculate_loss, has_aux=True) + (loss__, metrics), grad = grad_fn(state.params) + state = state.apply_gradients(grads=grad) + return state, loss__, metrics + + return vision_casual_language_model_train_step + + +def create_vision_casual_language_model_evaluation_step(partition_spec=PartitionSpec(("dp", "fsdp"), "sp")): + """The create_vision_casual_language_model_evaluation_step function is used to create a function that calculates the + loss and accuracy of a model. It takes in a set of parameters, which are then passed into the state.apply_fn function + to generate logits for each token in the batch. The cross entropy loss and accuracy are then calculated from these + logits. + + Args: + partition_spec: Specify the partitioning of the model parameters + + Returns: + A function that can be used to calculate the loss and accuracy + of a model + """ + + def vision_casual_language_model_evaluation_step(state, batch) -> [ + EasyDeLState, + chex.Array, + VisionCausalLanguageModelStepOutput + ]: + """The vision_casual_language_model_train_step function is a training step function that takes in the current state + of the model and a batch of data. It then calculates the loss and accuracy for this batch, + and returns an updated state with new parameters based on these gradients. + + Args: + state: Store the model parameters + batch: Pass the data to the model + + Returns: + A tuple of (state, loss, + VisionCausalLanguageModelStepOutput) + """ + batch = with_sharding_constraint(batch, partition_spec) + + def calculate_loss(params): + labels = batch.get("labels", None) + if labels is None: + labels = batch["input_ids"][..., 1:] + else: + labels = labels[..., 1:] + label_vision_mask = batch.pop("label_vision_mask") + model_outputs = state.apply_fn(params=params, **batch, return_dict=True) + logits = model_outputs.logits + aux_loss = getattr(model_outputs, "aux_loss", None) + + vision_loss, vision_accuracy = cross_entropy_loss_and_accuracy( + logits[:, :-1, :], + jnp.where(label_vision_mask, labels, 0), + batch["attention_mask"].astype(jnp.float32)[:, 1:] * label_vision_mask + ) + text_loss, text_accuracy = cross_entropy_loss_and_accuracy( + logits[:, :-1, :], + jnp.where(label_vision_mask, 0, labels), + batch["attention_mask"].astype(jnp.float32)[:, 1:] * (1.0 - label_vision_mask) + ) + + loss = 0.5 * (vision_loss + text_loss + (aux_loss if aux_loss is not None else 0.)) + + return loss, VisionCausalLanguageModelStepOutput( + loss=loss, + text_accuracy=text_accuracy, + vision_accuracy=vision_accuracy, + text_loss=text_loss, + vision_loss=vision_loss + ) + + loss__, metrics = calculate_loss(state.params) + return loss__, metrics + + return vision_casual_language_model_evaluation_step diff --git a/src/python/easydel/trainer/vision_causal_language_model_trainer/modelling_output.py b/src/python/easydel/trainer/vision_causal_language_model_trainer/modelling_output.py index c2fbfa82e..3edb47aa1 100644 --- a/src/python/easydel/trainer/vision_causal_language_model_trainer/modelling_output.py +++ b/src/python/easydel/trainer/vision_causal_language_model_trainer/modelling_output.py @@ -1,15 +1,15 @@ -from dataclasses import dataclass -import jax -from typing import Any, Optional, Callable, Mapping -from ...etils.easystate import EasyDeLState - - -@dataclass -class VisionCausalLMTrainerOutput: - state: EasyDeLState - mesh: Optional[jax.sharding.Mesh] - checkpoint_manager: Any - gather_fns: Optional[Any | Mapping[str, Callable] | dict[Callable]] = None - shard_fns: Optional[Any | Mapping[str, Callable] | dict[Callable]] = None - last_save_file_name: Optional[str] = None - checkpoint_path: Optional[str] = None +from dataclasses import dataclass +import jax +from typing import Any, Optional, Callable, Mapping +from ...etils.easystate import EasyDeLState + + +@dataclass +class VisionCausalLMTrainerOutput: + state: EasyDeLState + mesh: Optional[jax.sharding.Mesh] + checkpoint_manager: Any + gather_fns: Optional[Any | Mapping[str, Callable] | dict[Callable]] = None + shard_fns: Optional[Any | Mapping[str, Callable] | dict[Callable]] = None + last_save_file_name: Optional[str] = None + checkpoint_path: Optional[str] = None diff --git a/src/python/easydel/trainer/vision_causal_language_model_trainer/vision_causal_language_model_trainer.py b/src/python/easydel/trainer/vision_causal_language_model_trainer/vision_causal_language_model_trainer.py index 2196e1646..4ed265323 100644 --- a/src/python/easydel/trainer/vision_causal_language_model_trainer/vision_causal_language_model_trainer.py +++ b/src/python/easydel/trainer/vision_causal_language_model_trainer/vision_causal_language_model_trainer.py @@ -60,14 +60,16 @@ def collate_fn(batch): return collate_fn def configure_functions(self) -> TrainerConfigureFunctionFuncOutput: - """ - The configure_functions function is responsible for configuring the functions that will be used in training. + """The configure_functions function is responsible for configuring the functions that will be used in training. It does this by first defining a function called function_configurations, which initializes the model parameters and returns them as a EasyDeLState object. The EasyDeLState object contains all the information needed to train or evaluate on a batch of data, including: - :param self: Access the class attributes - :return: A TrainerConfigureFunctionFuncOutput object + Args: + self: Access the class attributes + + Returns: + A TrainerConfigureFunctionFuncOutput object """ def initialize_state_function(): @@ -315,18 +317,20 @@ def train( model_parameters: Optional[flax.core.FrozenDict] = None, state: Optional[EasyDeLState] = None ) -> VisionCausalLMTrainerOutput: - """ - The train function is the main function of this module. + """The train function is the main function of this module. It takes a model_parameters argument which can be used to load a pretrained model and finetune it. The train function returns an TrainerOutput object that contains the last saved file name, predict func, train state, mesh and checkpoint streamer. + Args: + self: Make the class methods aware of other methods and + attributes within the class + model_parameters: flax.core.FrozenDict: Load a pre-trained + model + state: Optional[EasyDeLState]: Ready to Use State - :param self: Make the class methods aware of other methods and attributes within the class - :param model_parameters: flax.core.FrozenDict: Load a pre-trained model - :param state: Optional[EasyDeLState]: Ready to Use State - :return: An object of type "TrainerOutput" - + Returns: + An object of type "TrainerOutput" """ def count_model_parameters(_p): diff --git a/src/python/easydel/transform/easydel_transform.py b/src/python/easydel/transform/easydel_transform.py index 106b80c20..653b877ff 100644 --- a/src/python/easydel/transform/easydel_transform.py +++ b/src/python/easydel/transform/easydel_transform.py @@ -40,16 +40,18 @@ def float_tensor_to_dtype(tensor, dtype): def match_keywords(string, ts, ns): - """ - The match_keywords function takes a string, and two lists of strings. + """The match_keywords function takes a string, and two lists of strings. The first list is the "must-have" keywords, and the second list is the "not-allowed" keywords. It returns True if all the must-have keywords are in string, but none of not allowed are in it. - :param string: Pass in the text that is being searched - :param ts: Specify the required keywords and ns is used to specify the non-required keywords - :param ns: Specify a list of negative keywords - :return: True if all the keywords in ts are present and none of the - + Args: + string: Pass in the text that is being searched + ts: Specify the required keywords and ns is used to specify the + non-required keywords + ns: Specify a list of negative keywords + + Returns: + True if all the keywords in ts are present and none of the """ for t in ts: if t not in string: @@ -75,27 +77,34 @@ def huggingface_to_easydel( remove_state_dict: bool = False, **kwargs ): - """ - The huggingface_to_easydel function takes a huggingface model's state_dict and converts it to an easydel + """The huggingface_to_easydel function takes a huggingface model's state_dict and converts it to an easydel model's flax_dict. The function is designed to be used in conjunction with the load_huggingface function, which loads a huggingface model from disk. The embedding layer name must be specified as well as the device on which the conversion will take place. - :param state_dict: Load the weights from a huggingface model - :param embedding_layer_names: List[str]: Identify the embedding layer in the huggingface model - :param device: Determine which device the model will be loaded on - :param layer_norm_names: Replaces weight or kernel with (scale) - :param shard_fns: Optional[Mapping[tuple, Callable]]: Sharding Function to be used to shard model - :param convert_to_8bit : bool: whenever to convert the into 8bit format - :param params_pattern_selection : Optional[re.Pattern]: patter to use to find the parameters of the model which will + Args: + state_dict: Load the weights from a huggingface model + embedding_layer_names: List[str]: Identify the embedding layer + in the huggingface model + device: Determine which device the model will be loaded on + layer_norm_names: Replaces weight or kernel with (scale) + shard_fns: Optional[Mapping[tuple, Callable]]: Sharding Function + to be used to shard model + convert_to_8bit: bool: whenever to convert the into 8bit format + params_pattern_selection: Optional[re.Pattern]: patter to use to + find the parameters of the model which will + dtype: jax.numpy.dtype: Specify the data type of the tensors + rnn_based_or_rwkv: bool: rnn_based_or_rwkv is a conditioner + which decide whenever it finds a value in tree + verbose: bool: whenever to log sharding or converting process + remove_state_dict: bool : whether to remove state dict during + the transforming process be converted to 8bit format. - :param dtype: jax.numpy.dtype: Specify the data type of the tensors - :param rnn_based_or_rwkv: bool: rnn_based_or_rwkv is a conditioner which decide whenever it finds a value in tree that start with time_mix_ it will automatically reshape that for easydel use case - :param verbose:bool: whenever to log sharding or converting process - :param remove_state_dict:bool : whether to remove state dict during the transforming process - :return: A dictionary of the weights and biases in a format that can be used by flax (it's an UnFlattenDict) + Returns: + A dictionary of the weights and biases in a format that can be + used by flax (it's an UnFlattenDict) """ embedding_layer_names = set(embedding_layer_names or []) layer_norm_names = set(layer_norm_names or []) @@ -151,14 +160,16 @@ def huggingface_to_easydel( def read_ckpt(path: [str, os.PathLike], shard_fns=None, add_extra_past_fix: list = None): - """ - The read_ckpt function reads a checkpoint file and returns the tensors in it. + """The read_ckpt function reads a checkpoint file and returns the tensors in it. - :param path: [str, os.PathLike]: Specify the path to the checkpoint file - :param shard_fns: Shard the tensors - :param add_extra_past_fix: list: Add an extra past to the key - :return: A dictionary of tensors - + Args: + path: [str, os.PathLike]: Specify the path to the checkpoint + file + shard_fns: Shard the tensors + add_extra_past_fix: list: Add an extra past to the key + + Returns: + A dictionary of tensors """ tensors = {} with open(path, "rb") as stream: @@ -175,15 +186,17 @@ def read_ckpt(path: [str, os.PathLike], shard_fns=None, add_extra_past_fix: list def save_ckpt(train_state, path, gather_fns=None, float_dtype=None): - """ - The save_ckpt function saves the state of a training run to disk. + """The save_ckpt function saves the state of a training run to disk. - :param train_state: Store the current state of the training process - :param path: Specify the location of the checkpoint file - :param gather_fns: Specify a function that will be used to convert the tensor to bytes - :param float_dtype: Convert the tensor to a specific dtype - :return: Nothing - + Args: + train_state: Store the current state of the training process + path: Specify the location of the checkpoint file + gather_fns: Specify a function that will be used to convert the + tensor to bytes + float_dtype: Convert the tensor to a specific dtype + + Returns: + Nothing """ train_state = to_state_dict(train_state) diff --git a/src/python/easydel/transform/falcon.py b/src/python/easydel/transform/falcon.py index 7b0d9ec2d..915f3dfb0 100644 --- a/src/python/easydel/transform/falcon.py +++ b/src/python/easydel/transform/falcon.py @@ -20,9 +20,7 @@ def match_keywords(string, ts, ns): def falcon_from_pretrained(model_id, device): - """ - return: Weight or Params for easydel Model , Config - """ + """return: Weight or Params for easydel Model , Config""" # Requested By vwxyzjn at https://github.com/erfanzar/EasyDeL/issues/15#issue-1881044170 config = FalconConfig.from_pretrained(model_id) model = FalconForCausalLM.from_pretrained(model_id) diff --git a/src/python/easydel/transform/llama.py b/src/python/easydel/transform/llama.py index 1a5e0bba4..9b4896b43 100644 --- a/src/python/easydel/transform/llama.py +++ b/src/python/easydel/transform/llama.py @@ -148,9 +148,7 @@ def llama_convert_flax_to_pt(flax_params, config: LlamaConfig, dtype=jnp.float16 def llama_easydel_to_hf(path, config: LlamaConfig): - """ - Takes path to easydel saved ckpt and return the model in pytorch (Transformers Huggingface) - """ + """Takes path to easydel saved ckpt and return the model in pytorch (Transformers Huggingface)""" torch_params = load_and_convert_checkpoint_to_torch(path) edited_params = {} for k, v in torch_params.items(): @@ -161,9 +159,7 @@ def llama_easydel_to_hf(path, config: LlamaConfig): def llama_from_pretrained(model_id, device): - """ - return: Weight or Params for easydel Model , Config - """ + """return: Weight or Params for easydel Model , Config""" config = LlamaConfig.from_pretrained(model_id) model = LlamaForCausalLM.from_pretrained(model_id) easydel_wights = llama_convert_hf_to_flax( diff --git a/src/python/easydel/transform/mistral.py b/src/python/easydel/transform/mistral.py index c00d3df6d..84915100b 100644 --- a/src/python/easydel/transform/mistral.py +++ b/src/python/easydel/transform/mistral.py @@ -248,9 +248,7 @@ def mistral_convert_flax_to_pt(flax_params, config: MistralConfig, dtype=jnp.flo def mistral_easydel_to_hf(path, config: MistralConfig): - """ - Takes path to easydel saved ckpt and return the model in pytorch (Transformers Huggingface) - """ + """Takes path to easydel saved ckpt and return the model in pytorch (Transformers Huggingface)""" torch_params = load_and_convert_checkpoint_to_torch(path) edited_params = {} for k, v in torch_params.items(): @@ -261,9 +259,7 @@ def mistral_easydel_to_hf(path, config: MistralConfig): def mistral_from_pretrained(model_id, device): - """ - return: Weight or Params for easydel Model , Config - """ + """return: Weight or Params for easydel Model , Config""" config = MistralConfig.from_pretrained(model_id) model = MistralForCausalLM.from_pretrained(model_id) easydel_wights = mistral_convert_hf_to_flax( diff --git a/src/python/easydel/transform/mpt.py b/src/python/easydel/transform/mpt.py index 1038b6270..dc083acbe 100644 --- a/src/python/easydel/transform/mpt.py +++ b/src/python/easydel/transform/mpt.py @@ -144,9 +144,7 @@ def mpt_convert_flax_to_pt_1b(state_dict_flax, n_layers: int, device=torch.devic def mpt_from_pretrained(model_id, device, **kwargs): - """ - return: Weight or Params for easydel Model , Config - """ + """return: Weight or Params for easydel Model , Config""" config = MptConfig.from_pretrained(model_id) model = AutoModelForCausalLM.from_pretrained(model_id, trust_remote_code=True, **kwargs) diff --git a/src/python/easydel/utils/prompters.py b/src/python/easydel/utils/prompters.py index 652c5a6d6..a696f0e3c 100644 --- a/src/python/easydel/utils/prompters.py +++ b/src/python/easydel/utils/prompters.py @@ -6,17 +6,20 @@ def antitoxin_prompter( prompt: str, system: typing.Optional[str] = None, ): - """ - The antitoxin_prompter function takes in a history of user-assistant interactions, + """The antitoxin_prompter function takes in a history of user-assistant interactions, a prompt from the user, and optionally a system response. It returns an input string that can be fed into the antitoxin model to generate an assistant response. - :param history: typing.List[str]: Pass in the history of the conversation - :param prompt: str: Pass the user's input to the assistant - :param system: typing.Optional[str]: Pass the system's response to the prompt + Args: + history: typing.List[str]: Pass in the history of the + conversation + prompt: str: Pass the user's input to the assistant + system: typing.Optional[str]: Pass the system's response to the + prompt :param : Store the history of user and assistant interaction - :return: A string that contains the user's prompt, - + + Returns: + A string that contains the user's prompt, """ sys_str = f"<|im_start|>system\n{system}<|im_end|>\n" if system is not None else "" histories = "" @@ -30,16 +33,18 @@ def antitoxin_prompter_chat_format( history: typing.List[str], system: typing.Optional[str] = None, ): - """ - The antitoxin_prompter_chat_format function takes a list of strings and returns a string. + """The antitoxin_prompter_chat_format function takes a list of strings and returns a string. The input is the history of the chat, which is a list of tuples where each tuple contains two strings: the user's message and the assistant's response. The output is formatted as follows: - :param history: typing.List[str]: Pass in the history of user and assistant messages - :param system: typing.Optional[str]: Pass in the system message + Args: + history: typing.List[str]: Pass in the history of user and + assistant messages + system: typing.Optional[str]: Pass in the system message :param : Store the history of the conversation - :return: A string that contains the system message and - + + Returns: + A string that contains the system message and """ sys_str = f"<|im_start|>system\n{system}<|im_end|>\n" if system is not None else "" histories = "" @@ -54,17 +59,20 @@ def llama2_prompter( system: typing.Optional[str] = None, ): - """ - The llama2_prompter function takes a history of user-system interactions, + """The llama2_prompter function takes a history of user-system interactions, a prompt for the next system response, and optionally a system response. It returns an LLAMA2 formatted string that can be used as input to the LLAMA2 model. - :param history: typing.List[str]: Store the history of user input and system response - :param prompt: str: Specify the prompt to be displayed - :param system: typing.Optional[str]: Indicate that the system is optional + Args: + history: typing.List[str]: Store the history of user input and + system response + prompt: str: Specify the prompt to be displayed + system: typing.Optional[str]: Indicate that the system is + optional :param : Specify the system's response - :return: A string that is a concatenation of the - + + Returns: + A string that is a concatenation of the """ do_strip = False if system is not None: @@ -84,18 +92,19 @@ def llama2_prompter_chat_format( system: str, messages: typing.List[str], ): - """ - The llama2_prompter_chat_format function takes a system message and a list of messages, + """The llama2_prompter_chat_format function takes a system message and a list of messages, and returns the formatted string that can be used to create an LLAMA2 chat file. The system message is optional, and if it is not provided then the function will return only the user messages. The user messages are expected to be in pairs: one for each speaker (system or human). The first element of each pair should be the name of that speaker. - :param system: str: Store the system message - :param messages: typing.List[str]: Pass in a list of strings + Args: + system: str: Store the system message + messages: typing.List[str]: Pass in a list of strings :param : Add the system message to the beginning of the chat - :return: A string that is the - + + Returns: + A string that is the """ if system is not None: string = [f'[INST] <>\n{system}\n<>\n\n'] diff --git a/src/python/easydel/utils/tensor_utils.py b/src/python/easydel/utils/tensor_utils.py index fe69a5487..f6e660343 100644 --- a/src/python/easydel/utils/tensor_utils.py +++ b/src/python/easydel/utils/tensor_utils.py @@ -5,21 +5,15 @@ def pt2np(array: torch.Tensor) -> np.array: - """ - Convert Pytorch Array to Numpy Array - """ + """Convert Pytorch Array to Numpy Array""" return array.detach().cpu().numpy() def np2jax(array: np.array) -> chex.Array: - """ - Convert Numpy Array to JAX Array - """ + """Convert Numpy Array to JAX Array""" return jnp.asarray(array) def pt2jax(array: torch.Tensor) -> chex.Array: - """ - Convert Pytorch Array to JAX Array - """ + """Convert Pytorch Array to JAX Array""" return np2jax(pt2np(array)) diff --git a/src/python/easydel/utils/utils.py b/src/python/easydel/utils/utils.py index b04ba7734..1de7175bf 100644 --- a/src/python/easydel/utils/utils.py +++ b/src/python/easydel/utils/utils.py @@ -14,14 +14,15 @@ class Timer: def __init__(self, name): - """ - The __init__ function is called when the class is instantiated. + """The __init__ function is called when the class is instantiated. It sets up the object with a name and initializes other variables. - :param self: Represent the instance of the class - :param name: Give the timer a name - :return: An instance of the class - + Args: + self: Represent the instance of the class + name: Give the timer a name + + Returns: + An instance of the class """ self.name_ = name self.elapsed_ = 0.0 @@ -29,53 +30,57 @@ def __init__(self, name): self.start_time = time.time() def start(self): - """ - The start function starts the timer. + """The start function starts the timer. Args: None - :param self: Access the attributes and methods of the class in python - :return: Nothing - + Args: + self: Access the attributes and methods of the class in + python + + Returns: + Nothing """ assert not self.started_, "timer has already been started" self.start_time = time.time() self.started_ = True def stop(self): - """ - The stop function stops the timer and adds the time elapsed since start was called to the total elapsed time. + """The stop function stops the timer and adds the time elapsed since start was called to the total elapsed time. + Args: + self: Represent the instance of the class - :param self: Represent the instance of the class - :return: The time elapsed since the start function was called - + Returns: + The time elapsed since the start function was called """ assert self.started_, "timer is not started" self.elapsed_ += time.time() - self.start_time self.started_ = False def reset(self): - """ - The reset function sets the elapsed time to 0.0 and the started flag to False. + """The reset function sets the elapsed time to 0.0 and the started flag to False. + + Args: + self: Represent the instance of the class - :param self: Represent the instance of the class - :return: True if the timer was running, false otherwise - + Returns: + True if the timer was running, false otherwise """ self.elapsed_ = 0.0 self.started_ = False def elapsed(self, reset=True): - """ - The elapsed function returns the elapsed time in seconds since the timer was started. + """The elapsed function returns the elapsed time in seconds since the timer was started. If reset is True, then it also resets the timer to zero and restarts it. If reset is False, then it leaves the timer running. - :param self: Represent the instance of the class - :param reset: Reset the timer - :return: The elapsed time in seconds - + Args: + self: Represent the instance of the class + reset: Reset the timer + + Returns: + The elapsed time in seconds """ started_ = self.started_ if self.started_: @@ -138,16 +143,18 @@ def __call__(self, name): def write(self, names, iteration, normalizer=1.0, reset=False): - """ - The write function is used to write the elapsed time of a timer to Tensorboard and/or Weights & Biases. - - :param self: Make the function a method of the class - :param names: Specify which timer(s) to write - :param iteration: Keep track of the number of iterations - :param normalizer: Normalize the time elapsed by a certain value - :param reset: Reset the timer after it has been written to tensorboard - :return: Nothing - + """The write function is used to write the elapsed time of a timer to Tensorboard and/or Weights & Biases. + + Args: + self: Make the function a method of the class + names: Specify which timer(s) to write + iteration: Keep track of the number of iterations + normalizer: Normalize the time elapsed by a certain value + reset: Reset the timer after it has been written to + tensorboard + + Returns: + Nothing """ assert normalizer > 0.0 for name in names: @@ -160,15 +167,16 @@ def write(self, names, iteration, normalizer=1.0, reset=False): wandb.log({f"timers/{name}": value}, step=iteration) def log(self, names, normalizer=1.0, reset=True): - """ - The log function is used to print the time elapsed for a given function. - - :param self: Represent the instance of the class - :param names: Specify the name of the timer that we want to log - :param normalizer: Normalize the time taken to run a function - :param reset: Reset the timer after logging - :return: The time taken for the given name - + """The log function is used to print the time elapsed for a given function. + + Args: + self: Represent the instance of the class + names: Specify the name of the timer that we want to log + normalizer: Normalize the time taken to run a function + reset: Reset the timer after logging + + Returns: + The time taken for the given name """ assert normalizer > 0.0 @@ -207,13 +215,15 @@ def get_mesh( shape: typing.Sequence[int] = (1, -1, 1, 1), axis_names: typing.Sequence[str] = ("dp", "fsdp", "tp", "sp") ): - """ - The get_mesh function is a helper function that creates a JAX Mesh object. - - :param shape: typing.Sequence[int]: Specify the shape of the array that is used to create the mesh - :param axis_names: typing.Sequence[int]: Specify the Axis Names in mesh - :return: A mesh object - + """The get_mesh function is a helper function that creates a JAX Mesh object. + + Args: + shape: typing.Sequence[int]: Specify the shape of the array that + is used to create the mesh + axis_names: typing.Sequence[int]: Specify the Axis Names in mesh + + Returns: + A mesh object """ from jax.sharding import Mesh from jax.experimental import mesh_utils