Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -7,7 +7,7 @@ __pycache__/ | |
.idea | ||
*.so | ||
env.ipynb | ||
env.ipynb | ||
env.py | ||
pallas_env.py | ||
test_EasyDeLState.py | ||
io.py | ||
|
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,10 @@ | ||
version: 2 | ||
build: | ||
os: ubuntu-22.04 | ||
tools: | ||
python: "3.12" | ||
sphinx: | ||
configuration: docs/conf.py | ||
python: | ||
install: | ||
- requirements: docs/requirements.txt |
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,109 +1,40 @@ | ||
## EasyDeLState | ||
|
||
`EasyDeLState` is a container in EasyDeL that stores the | |
`Model Parameters`, _Optimizer State, Model Config, Model Type, Optimizer and Scheduler Configs_ | |
|
||
Let's look at some examples of using `EasyDeLState`. | |
|
||
### Fine-tuning | ||
|
||
Fine-tune from a previously saved state or from a new state. | |
|
||
```python | ||
from easydel import ( | ||
AutoEasyDeLConfig, | ||
EasyDeLState | ||
) | ||
from transformers import AutoTokenizer | ||
from jax import numpy as jnp, lax | ||
import jax | ||
|
||
huggingface_model_repo_id = "REPO_ID" | ||
checkpoint_name = "CKPT_NAME" | ||
|
||
state = EasyDeLState.from_pretrained( | ||
pretrained_model_name_or_path=huggingface_model_repo_id, | ||
filename=checkpoint_name, | ||
optimizer="adamw", | ||
scheduler="none", | ||
tx_init=None, | ||
device=jax.devices('cpu')[0], # Offload Device | ||
dtype=jnp.bfloat16, | ||
param_dtype=jnp.bfloat16, | ||
precision=lax.Precision("fastest"), | ||
sharding_axis_dims=(1, -1, 1, 1), | ||
sharding_axis_names=("dp", "fsdp", "tp", "sp"), | ||
query_partition_spec=jax.sharding.PartitionSpec(("dp", "fsdp"), "sp", "tp", None), | ||
key_partition_spec=jax.sharding.PartitionSpec(("dp", "fsdp"), "sp", "tp", None), | ||
value_partition_spec=jax.sharding.PartitionSpec(("dp", "fsdp"), "sp", "tp", None), | ||
bias_partition_spec=jax.sharding.PartitionSpec(("dp", "fsdp"), None, None, None), | ||
attention_partition_spec=jax.sharding.PartitionSpec(("dp", "fsdp"), "sp", "tp", None), | ||
shard_attention_computation=False, | ||
input_shape=(1, 1), | ||
backend=None, | ||
init_optimizer_state=False, | ||
free_optimizer_state=True, | ||
verbose=True, | ||
state_shard_fns=None, | ||
) | ||
|
||
config = AutoEasyDeLConfig.from_pretrained( | ||
huggingface_model_repo_id | ||
) | ||
|
||
tokenizer = AutoTokenizer.from_pretrained( | ||
huggingface_model_repo_id, | ||
trust_remote_code=True | ||
) | ||
|
||
max_length = config.max_position_embeddings | ||
|
||
configs_to_initialize_model_class = { | ||
'config': config, | ||
'dtype': jnp.bfloat16, | ||
'param_dtype': jnp.bfloat16, | ||
'input_shape': (8, 8) | ||
} | ||
``` | ||
|
||
`EasyDeLState` also has `.load_state()` and `.save_state()`, along with other useful methods such as | |
`.free_opt_state()`, which frees the optimizer state, and `.shard_params()`, which shards the parameters. | |
See the docs to find out more about these options. | |
|
||
### Converting to Hugging Face and PyTorch | |
|
||
Let's see how to convert an EasyDeLMistral model to a Hugging Face PyTorch Mistral model from a trained state. | |
|
||
```python | ||
|
||
from transformers import MistralForCausalLM | ||
from easydel import ( | ||
AutoEasyDeLConfig, | ||
EasyDeLState, | ||
easystate_to_huggingface_model | ||
) | ||
import jax | ||
|
||
huggingface_model_repo_id = "REPO_ID" | ||
|
||
config = AutoEasyDeLConfig.from_pretrained( | ||
huggingface_model_repo_id | ||
) | ||
with jax.default_device(jax.devices("cpu")[0]): | ||
model = easystate_to_huggingface_model( | ||
state=EasyDeLState.load_state( | ||
"PATH_TO_CKPT", | ||
input_shape=(8, 2048) | ||
), # You can Pass EasyDeLState here | ||
base_huggingface_module=MistralForCausalLM, | ||
config=config, | ||
) | ||
|
||
model = model.half() # it's a huggingface model now | ||
``` | ||
|
||
### Other Use Cases | ||
|
||
`EasyDeLState` is general-purpose: you can use it throughout EasyDeL, for example for a stand-alone model, | |
serving, fine-tuning, and many other features. It's up to you to test how creative you are 😇. | |
**EasyDeLState: A Snapshot of Your EasyDeL Model** | ||
|
||
The `EasyDeLState` class acts like a comprehensive container that holds all the essential information about your EasyDeL | ||
model at a given point in time. Think of it as a snapshot of your model. It includes: | ||
|
||
* **Training Progress:** | ||
* `step`: Tracks the current training step. | ||
* **Model Itself:** | ||
* `module`: Holds the actual instance of your EasyDeL model. | ||
* `module_config`: Stores the model's configuration settings. | ||
* `module_config_args`: Keeps track of arguments used to create the configuration (useful for reloading). | ||
* `apply_fn`: References the core function that applies your model to data. | ||
* **Learned Parameters:** | ||
* `params`: Contains the trained weights and biases of your model. | ||
* **Optimizer Information:** | ||
* `tx`: Stores the optimizer you're using to update the model's parameters (e.g., AdamW). | ||
* `opt_state`: Keeps track of the optimizer's internal state (this is important for things like momentum in | ||
optimizers). | ||
* `tx_init`: Remembers the initial settings used to create the optimizer (again, for reloading purposes). | ||
* **Additional Settings:** | ||
* `hyperparameters`: Provides a flexible place to store other hyperparameters related to your model or training | ||
process. | ||
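
The field list above can be pictured as a plain immutable record. The sketch below is only an illustration using the standard library: `StateSnapshot` and `linear` are hypothetical names invented here, and this is not how `EasyDeLState` itself is implemented. It assumes the snapshot is treated as immutable, so advancing training produces a new snapshot rather than mutating the old one.

```python
from dataclasses import dataclass, field, replace
from typing import Any, Callable, Dict, Optional


@dataclass(frozen=True)
class StateSnapshot:
    """Hypothetical stand-in mirroring the fields listed above (not EasyDeL's class)."""

    step: int                        # current training step
    apply_fn: Callable[..., Any]     # function that applies the model to data
    params: Dict[str, Any]           # learned weights and biases
    tx: Any = None                   # optimizer object
    opt_state: Any = None            # optimizer's internal state (e.g. momentum)
    tx_init: Optional[Dict[str, Any]] = None  # settings used to build the optimizer
    hyperparameters: Dict[str, Any] = field(default_factory=dict)
    module: Any = None               # model instance
    module_config: Any = None        # model configuration
    module_config_args: Optional[Dict[str, Any]] = None


def linear(params: Dict[str, float], x: float) -> float:
    """Toy model used only to give apply_fn something to call."""
    return params["w"] * x + params["b"]


snapshot = StateSnapshot(
    step=0,
    apply_fn=linear,
    params={"w": 2.0, "b": 1.0},
    tx_init={"optimizer": "adamw", "scheduler": "none"},
)

# The record is frozen, so moving forward in training yields a new snapshot.
next_snapshot = replace(snapshot, step=snapshot.step + 1)
prediction = next_snapshot.apply_fn(next_snapshot.params, 3.0)
```

Keeping `tx_init` next to `params` is what makes reloading possible in this picture: the optimizer can be rebuilt from its recorded settings instead of being serialized as an object.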
|
||
**Key Capabilities of EasyDeLState:** | ||
|
||
* **Initialization (`create`)**: Lets you create a brand new `EasyDeLState` to start training. | ||
* **Loading (`load`, `load_state`, `from_pretrained`)**: Enables you to reload a saved model from a checkpoint file or | ||
even a pre-trained model from a repository like Hugging Face Hub. | ||
* **Saving (`save_state`)**: Allows you to save your model's current state, including its parameters and optimizer | ||
state. | ||
* **Optimizer Management (`apply_gradients`, `free_opt_state`, `init_opt_state`)**: Provides methods for updating the | ||
model's parameters using gradients, releasing optimizer memory, and re-initializing the optimizer if needed. | ||
* **Sharding (`shard_params`)**: Helps you distribute your model's parameters efficiently across multiple devices ( | ||
important for training large models). | ||
* **PyTorch Conversion (`to_pytorch`)**: Gives you a way to convert your EasyDeL model to its PyTorch equivalent. | ||
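
To make the optimizer-management bullet more concrete, here is a toy, pure-Python sketch of the update, free, and re-initialize cycle. The function names are borrowed from the list above purely for readability; the bodies, the dict-based state, and the momentum-SGD rule are all assumptions made for this illustration and do not reflect EasyDeL's actual signatures or optimizer.

```python
from typing import Dict

Params = Dict[str, float]


def init_opt_state(params: Params) -> Params:
    """Create a zeroed momentum buffer, one entry per parameter."""
    return {name: 0.0 for name in params}


def apply_gradients(state: dict, grads: Params, lr: float = 0.1, beta: float = 0.9) -> dict:
    """Return a new state after one momentum-SGD step (toy optimizer)."""
    if state["opt_state"] is None:
        raise ValueError("opt_state was freed; re-initialize it before training")
    momentum = {k: beta * state["opt_state"][k] + grads[k] for k in state["params"]}
    params = {k: state["params"][k] - lr * momentum[k] for k in state["params"]}
    return {"step": state["step"] + 1, "params": params, "opt_state": momentum}


def free_opt_state(state: dict) -> dict:
    """Drop the optimizer buffers, e.g. before inference or export."""
    return {**state, "opt_state": None}


params = {"w": 1.0}
state = {"step": 0, "params": params, "opt_state": init_opt_state(params)}

# One training step updates both the parameters and the momentum buffer.
state = apply_gradients(state, {"w": 0.5})

# Freeing keeps the parameters but releases the optimizer memory.
inference_state = free_opt_state(state)

# Training can resume later by building a fresh optimizer state.
resumed = {**inference_state, "opt_state": init_opt_state(inference_state["params"])}
```

The point of the sketch is the memory trade-off: the optimizer state is as large as the parameters in this toy (and often larger for optimizers such as AdamW), so dropping it is worthwhile whenever no further gradient steps are planned.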
|
||
**In Essence:** | ||
|
||
`EasyDeLState` streamlines the process of managing, saving, loading, and even converting your EasyDeL models. It ensures | ||
that you can easily work with your models and maintain consistency throughout your machine learning workflow. |
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,20 @@ | ||
# Minimal makefile for Sphinx documentation | ||
# | ||
|
||
# You can set these variables from the command line, and also | ||
# from the environment for the first two. | ||
SPHINXOPTS ?= | ||
SPHINXBUILD ?= sphinx-build | ||
SOURCEDIR = . | ||
BUILDDIR = _build | ||
|
||
# Put it first so that "make" without argument is like "make help". | ||
help: | ||
@$(SPHINXBUILD) -M help "$(SOURCEDIR)" "$(BUILDDIR)" $(SPHINXOPTS) $(O) | ||
|
||
.PHONY: help Makefile | ||
|
||
# Catch-all target: route all unknown targets to Sphinx using the new | ||
# "make mode" option. $(O) is meant as a shortcut for $(SPHINXOPTS). | ||
%: Makefile | ||
@$(SPHINXBUILD) -M $@ "$(SOURCEDIR)" "$(BUILDDIR)" $(SPHINXOPTS) $(O) |
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,33 @@ | ||
import os | ||
import sys | ||
|
||
sys.path.insert(0, os.path.abspath("..")) | ||
|
||
project = "EasyDeL" | ||
copyright = "2023, Erfan Zare Chavoshi - EasyDeL" | ||
author = "Erfan Zare Chavoshi" | ||
|
||
extensions = [ | ||
"sphinx.ext.autodoc", | ||
"sphinx.ext.napoleon", | ||
"sphinx.ext.viewcode", | ||
"sphinx.ext.intersphinx", | ||
"sphinx_autodoc_typehints", | ||
] | ||
|
||
templates_path = ["_templates"] | ||
exclude_patterns = ["_build", "Thumbs.db", ".DS_Store"] | ||
|
||
intersphinx_mapping = { | ||
"jax": ("https://jax.readthedocs.io/en/latest/", None), | ||
"pytorch": ("https://pytorch.org/docs/stable/", None), | ||
"numpy": ("https://numpy.org/doc/stable/", None), | ||
} | ||
|
||
html_theme = "sphinx_book_theme" | ||
html_static_path = ["_static"] | ||
html_css_files = [ | ||
"custom.css", | ||
] | ||
|
||
source_suffix = [".rst", ".md", ".ipynb"] |
This file was deleted.