Gemma #630

Merged: 73 commits, Apr 4, 2024

Commits (73)
1928a47
Create Gemma models.
solitude-alive Mar 31, 2024
e7e4c88
Create _model_builders.py for gemma.
solitude-alive Mar 31, 2024
a56ff40
Create _component_builders.py for gemma.
solitude-alive Mar 31, 2024
e764355
Create _model_utils.py for gemma.
solitude-alive Mar 31, 2024
47c693d
Update for activation Option.
solitude-alive Mar 31, 2024
2db033f
Create default 2B_full.yaml.
solitude-alive Mar 31, 2024
f6f29b3
Update to support the output weight being the same as the token embedding.
solitude-alive Mar 31, 2024
76a7558
Update to support .safetensors weight files.
solitude-alive Mar 31, 2024
6371378
tie weight for gemma.
solitude-alive Mar 31, 2024
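Several of the commits above and below (f6f29b3, 6371378, 0f78903, 8986464) revolve around tying Gemma's output projection to its token embedding so the checkpoint does not need a separate lm_head weight. A minimal sketch of that idea in plain PyTorch follows; the class and attribute names are illustrative, not the repo's actual modules.

import torch.nn as nn

class TiedDecoderSketch(nn.Module):
    """Toy decoder whose output projection reuses the token-embedding matrix."""

    def __init__(self, vocab_size: int = 1000, dim: int = 64) -> None:
        super().__init__()
        self.tok_embeddings = nn.Embedding(vocab_size, dim)
        self.output = nn.Linear(dim, vocab_size, bias=False)
        # Tie the weights: both modules now share one tensor, so only
        # tok_embeddings.weight needs to exist in the checkpoint.
        self.output.weight = self.tok_embeddings.weight

model = TiedDecoderSketch()
assert model.output.weight.data_ptr() == model.tok_embeddings.weight.data_ptr()

As commit 8986464 hints, this kind of tying generally has to be applied (or re-applied) with care around FSDP wrapping, since wrapping can swap out the underlying parameter objects.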
f3912da
Create comments
solitude-alive Mar 31, 2024
b355634
Create comments.
solitude-alive Mar 31, 2024
0f78903
Fix lm_head.weight not being loaded from the checkpoint.
solitude-alive Mar 31, 2024
8986464
Fix the bug for models that need to tie weights and move it after FSDP wrapp…
solitude-alive Apr 1, 2024
f9301e3
complete todo to check the pad_id.
solitude-alive Apr 1, 2024
1bd6210
complete todo to check the position embedding.
solitude-alive Apr 1, 2024
eb53e96
complete todo to check the attention.
solitude-alive Apr 1, 2024
71952e9
Update Gemma model with tied weight on load and save model.
solitude-alive Apr 1, 2024
de4bc58
Update for tied model.
solitude-alive Apr 1, 2024
142b06e
Remove tie_weight during setup.
solitude-alive Apr 1, 2024
e71bc04
Update recipes/configs/gemma/2B_full.yaml
solitude-alive Apr 2, 2024
8ce9a1f
Add the model lm_head weight to the weight map manually if it is not …
solitude-alive Apr 2, 2024
a9d0aee
fix the typo.
solitude-alive Apr 2, 2024
683727d
Add safetensors.
solitude-alive Apr 2, 2024
62f8b91
Add blog for gemma.
solitude-alive Apr 2, 2024
aec3851
Merge remote-tracking branch 'origin/gemma' into gemma
solitude-alive Apr 2, 2024
620fdc1
Update 2B_full.yaml.
solitude-alive Apr 2, 2024
98eb166
Remove is_safetensors_available.
solitude-alive Apr 2, 2024
2a9fa63
Update the default activation to match the nn.Module.
solitude-alive Apr 2, 2024
6aa8574
Update docstring.
solitude-alive Apr 2, 2024
1486b1a
Remove the LoRA code.
solitude-alive Apr 2, 2024
640925d
Remove the LoRA code.
solitude-alive Apr 2, 2024
63f63a6
Remove the LoRA code.
solitude-alive Apr 2, 2024
e9eb5fe
Remove the gemma code.
solitude-alive Apr 3, 2024
1728802
Add gemma.
solitude-alive Apr 3, 2024
3621df3
Add GemmaRMSNorm.
solitude-alive Apr 3, 2024
fa602f7
Add GemmaRMSNorm.
solitude-alive Apr 3, 2024
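Commits 3621df3 and fa602f7 add a Gemma-specific RMSNorm (later commits move it into torchtune/models/gemma/rms_norm.py and keep refining it). Gemma's reference norm is usually described as normalizing in float32 and scaling by (1 + weight) with a zero-initialized weight; the sketch below assumes that behavior and is not the PR's exact code.

import torch
import torch.nn as nn

class GemmaRMSNormSketch(nn.Module):
    """RMSNorm variant: normalize in float32, then scale by (1 + weight)."""

    def __init__(self, dim: int, eps: float = 1e-6) -> None:
        super().__init__()
        self.eps = eps
        # Zero init so (1 + weight) starts out as an identity scale.
        self.weight = nn.Parameter(torch.zeros(dim))

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        input_dtype = x.dtype
        x = x.float()
        x = x * torch.rsqrt(x.pow(2).mean(dim=-1, keepdim=True) + self.eps)
        x = x * (1.0 + self.weight.float())
        return x.to(input_dtype)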
de519fd
Add gemma_full_finetune_distributed.
solitude-alive Apr 3, 2024
82642c0
Create gemma_full_finetune_distributed.
solitude-alive Apr 3, 2024
c99092b
Update Gemma setting.
solitude-alive Apr 3, 2024
3268481
Add norm before calculating attention.
solitude-alive Apr 3, 2024
44a5473
Update docstring.
solitude-alive Apr 3, 2024
95c0faf
Update default max_seq_len.
solitude-alive Apr 3, 2024
1fc0ee5
separate utility function in checkpointer_utils for loading and savin…
solitude-alive Apr 3, 2024
007c2b2
Update safe_torch_load function.
solitude-alive Apr 3, 2024
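The safetensors-related commits (76a7558, 683727d, 98eb166, 1fc0ee5, 007c2b2) teach the checkpointer utilities to read .safetensors shards alongside regular torch checkpoints. Below is a rough sketch of that dispatch using the standard safetensors and torch loading APIs; the function name is hypothetical and not necessarily torchtune's safe_torch_load.

from pathlib import Path

import torch
from safetensors.torch import load_file

def load_checkpoint_file(path: str) -> dict:
    """Load one checkpoint shard, picking the loader from the file extension."""
    p = Path(path)
    if p.suffix == ".safetensors":
        # safetensors shards hold raw tensors and need no pickling.
        return load_file(p)
    # Fall back to an ordinary torch checkpoint.
    return torch.load(p, map_location="cpu", weights_only=True)

state_dict = {}
for shard in (
    "model-00001-of-00002.safetensors",
    "model-00002-of-00002.safetensors",
):
    state_dict.update(load_checkpoint_file(f"/tmp/gemma/{shard}"))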
cfb1175
Update shared weight setting.
solitude-alive Apr 3, 2024
05f9b1a
Remove break after step.
solitude-alive Apr 3, 2024
c297c0d
Add head_dim param for hf_to_tune function.
solitude-alive Apr 3, 2024
70bec74
To solve conflict.
solitude-alive Apr 3, 2024
269fd24
Add parameters to TransformerDecoder.
solitude-alive Apr 3, 2024
e3e88d0
Merge branch 'main' into gemma
solitude-alive Apr 4, 2024
62f936d
Remove unnecessary comments.
solitude-alive Apr 4, 2024
e17ce75
Move GemmaRMSNorm to torchtune/models/gemma/rms_norm.py
solitude-alive Apr 4, 2024
1940802
Move GemmaRMSNorm to torchtune/models/gemma/rms_norm.py
solitude-alive Apr 4, 2024
1f62925
Update GemmaRMSNorm.
solitude-alive Apr 4, 2024
c3e99a4
Update GemmaRMSNorm.
solitude-alive Apr 4, 2024
32e5e10
Update GemmaRMSNorm.
solitude-alive Apr 4, 2024
9889a59
Update GemmaTransformerDecoder.
solitude-alive Apr 4, 2024
b8b6b6d
Update GemmaTransformerDecoder.
solitude-alive Apr 4, 2024
6385652
Update docstring.
solitude-alive Apr 4, 2024
ddacd6c
Update para name.
solitude-alive Apr 4, 2024
83e3f69
Update Gemma 2B_full.yaml.
solitude-alive Apr 4, 2024
da89d81
Update shared weight setting.
solitude-alive Apr 4, 2024
7370f67
Update transformer.py.
solitude-alive Apr 4, 2024
59cb14c
fix typo.
solitude-alive Apr 4, 2024
59e6a4e
Update default setting.
solitude-alive Apr 4, 2024
390a77c
Update default setting.
solitude-alive Apr 4, 2024
7adf487
Remove _model_utils.py.
solitude-alive Apr 4, 2024
8ba1412
Update GemmaTransformerDecoder logits.
solitude-alive Apr 4, 2024
53a3ddf
Update default parameter.
solitude-alive Apr 4, 2024
23b3a02
Add save json to out_dir.
solitude-alive Apr 5, 2024
e7c99c3
Update gemma.yaml
solitude-alive Apr 5, 2024
29eeba9
Remove gemma.yaml
solitude-alive Apr 5, 2024
321f59e
fix the path-does-not-exist error.
solitude-alive Apr 5, 2024
86 changes: 86 additions & 0 deletions recipes/configs/gemma/2B_full.yaml
@@ -0,0 +1,86 @@
# Config for multi-device full finetuning in full_finetune_distributed.py
# using a gemma 2B model
#
# This config assumes that you've run the following command before launching
# this run:
#   tune download --repo-id google/gemma-2b \
#     --hf-token <HF_TOKEN> \
#     --output-dir /tmp/gemma
#
# To launch on 4 devices, run the following command from root:
#   tune --nnodes 1 --nproc_per_node 4 full_finetune_distributed \
#     --config gemma/2B_full
#
# You can add specific overrides through the command line. For example
# to override the checkpointer directory while launching training
# you can run:
#   tune --nnodes 1 --nproc_per_node 4 full_finetune_distributed \
#     --config gemma/2B_full \
#     checkpointer.checkpoint_dir=<YOUR_CHECKPOINT_DIR>
#
# This config works only when the model is being fine-tuned on 2+ GPUs.


# Tokenizer
tokenizer:
  _component_: torchtune.models.gemma.gemma_tokenizer
  path: /tmp/gemma/tokenizer.model

# Dataset
dataset:
  _component_: torchtune.datasets.alpaca_dataset
  train_on_input: True
seed: null
shuffle: True

# Model Arguments
model:
  _component_: torchtune.models.gemma.gemma_2b

checkpointer:
  _component_: torchtune.utils.FullModelHFCheckpointer
  checkpoint_dir: /tmp/gemma/
  checkpoint_files: [
    model-00001-of-00002.safetensors,
    model-00002-of-00002.safetensors,
  ]
  recipe_checkpoint: null
  output_dir: /tmp/gemma
  model_type: GEMMA
  share_weights:
    share_weights: True
    weight_tying_config: {
      "output": "tok_embeddings"
    }
resume_from_checkpoint: False

# Fine-tuning arguments
batch_size: 2
epochs: 3
optimizer:
  _component_: torch.optim.AdamW
  lr: 2e-5
loss:
  _component_: torch.nn.CrossEntropyLoss
max_steps_per_epoch: null
gradient_accumulation_steps: 1


# Training env
device: cuda

# Distributed
cpu_offload: False

# Memory management
enable_activation_checkpointing: True

# Reduced precision
dtype: bf16

# Logging
metric_logger:
  _component_: torchtune.utils.metric_logging.DiskLogger
  log_dir: ${output_dir}
output_dir: /tmp/alpaca-gemma-finetune
log_every_n_steps: null
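
For readers who want to poke at this config outside of a full training run: every key with a _component_ entry is a dotted import path plus constructor kwargs. Below is a minimal, hypothetical loader (not torchtune's own config machinery) that resolves and instantiates such a node.

import importlib

import yaml

def instantiate_component(node: dict):
    """Import the dotted path in `_component_` and call it with the remaining keys."""
    kwargs = dict(node)
    dotted_path = kwargs.pop("_component_")
    module_name, attr_name = dotted_path.rsplit(".", 1)
    factory = getattr(importlib.import_module(module_name), attr_name)
    return factory(**kwargs)

with open("recipes/configs/gemma/2B_full.yaml") as f:
    cfg = yaml.safe_load(f)

model = instantiate_component(cfg["model"])          # torchtune.models.gemma.gemma_2b()
tokenizer = instantiate_component(cfg["tokenizer"])  # gemma_tokenizer(path=/tmp/gemma/tokenizer.model)

In the actual recipes this resolution is handled by torchtune's config utilities together with command-line overrides such as checkpointer.checkpoint_dir=<YOUR_CHECKPOINT_DIR> shown in the comments above.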