Remove tyro #1176

Merged: 42 commits, merged Jan 26, 2024

Commits (the diff below reflects changes from 12 of the 42 commits):
c88e5f6  refactor (vwxyzjn, Dec 8, 2023)
0ebeb38  Remove tyro in `ppo.py` (vwxyzjn, Dec 11, 2023)
527d96f  quick update (vwxyzjn, Dec 11, 2023)
789108a  update default args (vwxyzjn, Dec 11, 2023)
e7f9580  Merge branch 'main' into refactor3 (vwxyzjn, Jan 4, 2024)
d7e66c3  quick push (vwxyzjn, Jan 4, 2024)
fa2a4c8  precommit (vwxyzjn, Jan 4, 2024)
67a8f53  refactor (vwxyzjn, Jan 4, 2024)
fe3f766  quick change (vwxyzjn, Jan 9, 2024)
34ea7d7  remove tyro (vwxyzjn, Jan 9, 2024)
af724f4  quick change (vwxyzjn, Jan 9, 2024)
800b9fd  precommit (vwxyzjn, Jan 9, 2024)
c257d85  quick change (vwxyzjn, Jan 9, 2024)
7960804  fix hello_world (vwxyzjn, Jan 9, 2024)
782fdd3  remove docstring diffences (vwxyzjn, Jan 9, 2024)
3c1f3a0  add `module load cuda/12.1` (vwxyzjn, Jan 9, 2024)
da745c8  push changes (vwxyzjn, Jan 10, 2024)
61506ff  precommit (vwxyzjn, Jan 10, 2024)
04f5a7b  make dpo runnable (vwxyzjn, Jan 10, 2024)
83ad452  fix circular import (vwxyzjn, Jan 10, 2024)
56bea91  quick fix (vwxyzjn, Jan 10, 2024)
5936fa9  refactor (vwxyzjn, Jan 11, 2024)
c636cfd  quick update (vwxyzjn, Jan 11, 2024)
0022af6  path change (vwxyzjn, Jan 11, 2024)
cca6940  update plots (vwxyzjn, Jan 11, 2024)
7b5e7b2  Merge branch 'main' into refactor3 (vwxyzjn, Jan 11, 2024)
bfde066  fix docs (vwxyzjn, Jan 11, 2024)
a6addfd  quick change (vwxyzjn, Jan 11, 2024)
040e32d  Update trl/trainer/model_config.py (vwxyzjn, Jan 12, 2024)
348faf0  Update trl/trainer/model_config.py (vwxyzjn, Jan 12, 2024)
e646352  Update trl/trainer/utils.py (vwxyzjn, Jan 12, 2024)
02e3f3c  Update examples/scripts/dpo.py (vwxyzjn, Jan 12, 2024)
f4d0840  Merge branch 'main' into refactor3 (vwxyzjn, Jan 24, 2024)
d355e38  address comments. use attn_implementation (vwxyzjn, Jan 24, 2024)
87fe921  precommit (vwxyzjn, Jan 24, 2024)
bc1a7aa  remove duplicate code (vwxyzjn, Jan 24, 2024)
161a102  update peft.py (vwxyzjn, Jan 24, 2024)
884aac1  fix test no op dep (vwxyzjn, Jan 24, 2024)
5a496f7  Update trl/trainer/utils.py (vwxyzjn, Jan 26, 2024)
bd5ab8c  Apply suggestions from code review (vwxyzjn, Jan 26, 2024)
455157e  precommit (vwxyzjn, Jan 26, 2024)
74834d2  add docs (vwxyzjn, Jan 26, 2024)
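
Taken together, these commits swap tyro-based CLI parsing for transformers' `HfArgumentParser` across the example and benchmark scripts, which also flattens nested flags such as `--ppo_config.log_with` into plain `--log_with`. A minimal sketch of that before/after pattern (a hypothetical script; only the parsing calls are meant to mirror the PR) might look like:

```python
from dataclasses import dataclass, field

from transformers import HfArgumentParser

from trl import PPOConfig


@dataclass
class ScriptArguments:
    use_peft: bool = field(default=False, metadata={"help": "whether to use PEFT"})


if __name__ == "__main__":
    # Before (tyro): ScriptArguments carried a nested `ppo_config: PPOConfig` field,
    # exposed on the CLI as prefixed flags, e.g. `--ppo_config.log_with wandb`:
    #     args = tyro.cli(ScriptArguments)
    #
    # After (HfArgumentParser): each dataclass is parsed into its own flat namespace,
    # so the same option becomes `--log_with wandb`:
    parser = HfArgumentParser((ScriptArguments, PPOConfig))
    script_args, ppo_config = parser.parse_args_into_dataclasses()
```

Invocation then changes from `python ppo.py --ppo_config.log_with wandb` to `python ppo.py --log_with wandb`, which is the substitution visible throughout the benchmark scripts and docs in the file changes below.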
16 changes: 0 additions & 16 deletions benchmark/benchmark_and_report.sh
@@ -1,19 +1,3 @@
#### Step 1: create a work directory:
# this is necessary because another github action job will remove
# the entire directory, which slurm depends on.
# https://stackoverflow.com/questions/4632028/how-to-create-a-temporary-directory
MY_SLURM_TMP_DIR=/fsx/costa/slurm_tmpdir
mkdir -p $MY_SLURM_TMP_DIR
WORK_DIR=`mktemp -d -p "$MY_SLURM_TMP_DIR"`
cp -r "$PWD" "$WORK_DIR"
cd "$WORK_DIR/$(basename "$PWD")"
echo WORK_DIR: $WORK_DIR

#### Step 2: actual work starts:
echo PATH is $PATH
echo PYTHONPATH is $PYTHONPATH
echo whcih python is $(which python)

export WANDB_ENTITY=huggingface
bash $BENCHMARK_SCRIPT > output.txt

2 changes: 1 addition & 1 deletion benchmark/benchmark_level1.sh
@@ -1,6 +1,6 @@
# hello world experiment
python benchmark/benchmark.py \
--command "python examples/scripts/ppo.py --ppo_config.log_with wandb" \
--command "python examples/scripts/ppo.py --log_with wandb" \
--num-seeds 3 \
--start-seed 1 \
--workers 10 \
4 changes: 2 additions & 2 deletions benchmark/benchmark_level2.sh
@@ -1,6 +1,6 @@
# compound experiments: gpt2xl + grad_accu
python benchmark/benchmark.py \
--command "python examples/scripts/ppo.py --ppo_config.exp_name ppo_gpt2xl_grad_accu --ppo_config.model_name gpt2-xl --ppo_config.mini_batch_size 16 --ppo_config.gradient_accumulation_steps 8 --ppo_config.log_with wandb" \
--command "python examples/scripts/ppo.py --exp_name ppo_gpt2xl_grad_accu --model_name gpt2-xl --mini_batch_size 16 --gradient_accumulation_steps 8 --log_with wandb" \
--num-seeds 3 \
--start-seed 1 \
--workers 10 \
@@ -12,7 +12,7 @@ python benchmark/benchmark.py \

# compound experiments: Cerebras-GPT-6.7B + deepspeed zero2 + grad_accu
python benchmark/benchmark.py \
--command "accelerate launch --config_file examples/accelerate_configs/deepspeed_zero2.yaml examples/scripts/ppo.py --ppo_config.exp_name ppo_Cerebras-GPT-6.7B_grad_accu_deepspeed_stage2 --ppo_config.batch_size 32 --ppo_config.mini_batch_size 32 --ppo_config.log_with wandb --ppo_config.model_name cerebras/Cerebras-GPT-6.7B --ppo_config.reward_model sentiment-analysis:cerebras/Cerebras-GPT-6.7B" \
--command "accelerate launch --config_file examples/accelerate_configs/deepspeed_zero2.yaml examples/scripts/ppo.py --exp_name ppo_Cerebras-GPT-6.7B_grad_accu_deepspeed_stage2 --batch_size 32 --mini_batch_size 32 --log_with wandb --model_name cerebras/Cerebras-GPT-6.7B --reward_model sentiment-analysis:cerebras/Cerebras-GPT-6.7B" \
--num-seeds 3 \
--start-seed 1 \
--workers 10 \
8 changes: 4 additions & 4 deletions benchmark/benchmark_level3.sh
@@ -1,6 +1,6 @@
## w/ and w/o gradient accumulation
python benchmark/benchmark.py \
--command "python examples/scripts/ppo.py --ppo_config.exp_name ppo_step_grad_accu --ppo_config.mini_batch_size 1 --ppo_config.gradient_accumulation_steps 128 --ppo_config.log_with wandb" \
--command "python examples/scripts/ppo.py --exp_name ppo_step_grad_accu --mini_batch_size 1 --gradient_accumulation_steps 128 --log_with wandb" \
--num-seeds 3 \
--start-seed 1 \
--workers 10 \
@@ -12,7 +12,7 @@ python benchmark/benchmark.py \

## w/ different models (gpt2, gpt2-xl, falcon, llama2)
python benchmark/benchmark.py \
--command "python examples/scripts/ppo.py --ppo_config.exp_name ppo_gpt2 --ppo_config.log_with wandb" \
--command "python examples/scripts/ppo.py --exp_name ppo_gpt2 --log_with wandb" \
--num-seeds 3 \
--start-seed 1 \
--workers 10 \
@@ -22,7 +22,7 @@ python benchmark/benchmark.py \
--slurm-total-cpus 12 \
--slurm-template-path benchmark/trl.slurm_template
python benchmark/benchmark.py \
--command "python examples/scripts/ppo.py --ppo_config.exp_name ppo_falcon_rw_1b --ppo_config.model_name tiiuae/falcon-rw-1b --ppo_config.log_with wandb" \
--command "python examples/scripts/ppo.py --exp_name ppo_falcon_rw_1b --model_name tiiuae/falcon-rw-1b --log_with wandb" \
--num-seeds 3 \
--start-seed 1 \
--workers 10 \
@@ -35,7 +35,7 @@ python benchmark/benchmark.py \

## w/ and w/o PEFT
python benchmark/benchmark.py \
--command "python examples/scripts/ppo.py --ppo_config.exp_name ppo_peft --use_peft --ppo_config.log_with wandb" \
--command "python examples/scripts/ppo.py --exp_name ppo_peft --use_peft --log_with wandb" \
--num-seeds 3 \
--start-seed 1 \
--workers 10 \
2 changes: 1 addition & 1 deletion benchmark/post_github_comment.sbatch
@@ -1,6 +1,6 @@
#!/bin/bash
#SBATCH --job-name=trl
#SBATCH --partition=production-cluster
#SBATCH --partition=hopper-cpu
#SBATCH --ntasks=1
#SBATCH --output=slurm/logs/%x_%j.out

3 changes: 3 additions & 0 deletions benchmark/regression_test.sh
@@ -0,0 +1,3 @@
BENCHMARK_SCRIPT="benchmark/benchmark_level1.sh" \
BENCHMARK_PLOT_SCRIPT="benchmark/benchmark_level1_plot.sh" \
bash benchmark/benchmark_and_report.sh
6 changes: 3 additions & 3 deletions benchmark/trl.slurm_template
@@ -1,16 +1,16 @@
#!/bin/bash
#SBATCH --job-name=trl
#SBATCH --partition=production-cluster
#SBATCH --partition=hopper-prod
#SBATCH --gpus-per-task={{gpus_per_task}}
#SBATCH --cpus-per-gpu={{cpus_per_gpu}}
#SBATCH --ntasks={{ntasks}}
#SBATCH --output=slurm/logs/%x_%j.out
#SBATCH --array={{array}}
#SBATCH --exclude=ip-26-0-156-239,ip-26-0-148-151,ip-26-0-146-212,ip-26-0-145-137,ip-26-0-146-249,ip-26-0-146-149,ip-26-0-147-233,ip-26-0-145-154,ip-26-0-144-35,ip-26-0-144-189,ip-26-0-146-183,ip-26-0-147-120,ip-26-0-144-95,ip-26-0-145-193
##SBATCH --exclude=ip-26-0-149-199
{{nodes}}

seeds={{seeds}}
seed=${seeds[$SLURM_ARRAY_TASK_ID % {{len_seeds}}]}

echo "Running task $SLURM_ARRAY_TASK_ID with seed: $seed"
srun {{command}} --ppo_config.seed $seed
srun {{command}} --seed $seed
14 changes: 7 additions & 7 deletions docs/source/sentiment_tuning.mdx
@@ -25,7 +25,7 @@ accelerate launch examples/scripts/ppo.py # launches training
# 3. get help text and documentation
python examples/scripts/ppo.py --help
# 4. configure logging with wandb and, say, mini_batch_size=1 and gradient_accumulation_steps=16
python examples/scripts/ppo.py --ppo_config.log_with wandb --ppo_config.mini_batch_size 1 --ppo_config.gradient_accumulation_steps 16
python examples/scripts/ppo.py --log_with wandb --mini_batch_size 1 --gradient_accumulation_steps 16
```

Note: if you don't want to log with `wandb` remove `log_with="wandb"` in the scripts/notebooks. You can also replace it with your favourite experiment tracker that's [supported by `accelerate`](https://huggingface.co/docs/accelerate/usage_guides/tracking).
@@ -42,7 +42,7 @@ Below are some benchmark results for `examples/scripts/ppo.py`. To reproduce loc

```bash
python benchmark/benchmark.py \
--command "python examples/scripts/ppo.py --ppo_config.log_with wandb" \
--command "python examples/scripts/ppo.py --log_with wandb" \
--num-seeds 5 \
--start-seed 1 \
--workers 10 \
@@ -61,7 +61,7 @@ python benchmark/benchmark.py \

```bash
python benchmark/benchmark.py \
--command "python examples/scripts/ppo.py --ppo_config.exp_name sentiment_tuning_step_grad_accu --ppo_config.mini_batch_size 1 --ppo_config.gradient_accumulation_steps 128 --ppo_config.log_with wandb" \
--command "python examples/scripts/ppo.py --exp_name sentiment_tuning_step_grad_accu --mini_batch_size 1 --gradient_accumulation_steps 128 --log_with wandb" \
--num-seeds 5 \
--start-seed 1 \
--workers 10 \
@@ -79,7 +79,7 @@ python benchmark/benchmark.py \

```bash
python benchmark/benchmark.py \
--command "python examples/scripts/ppo.py --ppo_config.exp_name sentiment_tuning_gpt2 --ppo_config.log_with wandb" \
--command "python examples/scripts/ppo.py --exp_name sentiment_tuning_gpt2 --log_with wandb" \
--num-seeds 5 \
--start-seed 1 \
--workers 10 \
@@ -89,7 +89,7 @@ python benchmark/benchmark.py \
--slurm-total-cpus 12 \
--slurm-template-path benchmark/trl.slurm_template
python benchmark/benchmark.py \
--command "python examples/scripts/ppo.py --ppo_config.exp_name sentiment_tuning_gpt2xl_grad_accu --ppo_config.model_name gpt2-xl --ppo_config.mini_batch_size 16 --ppo_config.gradient_accumulation_steps 8 --ppo_config.log_with wandb" \
--command "python examples/scripts/ppo.py --exp_name sentiment_tuning_gpt2xl_grad_accu --model_name gpt2-xl --mini_batch_size 16 --gradient_accumulation_steps 8 --log_with wandb" \
--num-seeds 5 \
--start-seed 1 \
--workers 10 \
@@ -99,7 +99,7 @@ python benchmark/benchmark.py \
--slurm-total-cpus 12 \
--slurm-template-path benchmark/trl.slurm_template
python benchmark/benchmark.py \
--command "python examples/scripts/ppo.py --ppo_config.exp_name sentiment_tuning_falcon_rw_1b --ppo_config.model_name tiiuae/falcon-rw-1b --ppo_config.log_with wandb" \
--command "python examples/scripts/ppo.py --exp_name sentiment_tuning_falcon_rw_1b --model_name tiiuae/falcon-rw-1b --log_with wandb" \
--num-seeds 5 \
--start-seed 1 \
--workers 10 \
@@ -116,7 +116,7 @@ python benchmark/benchmark.py \

```
python benchmark/benchmark.py \
--command "python examples/scripts/ppo.py --ppo_config.exp_name sentiment_tuning_peft --use_peft --ppo_config.log_with wandb" \
--command "python examples/scripts/ppo.py --exp_name sentiment_tuning_peft --use_peft --log_with wandb" \
--num-seeds 5 \
--start-seed 1 \
--workers 10 \
79 changes: 39 additions & 40 deletions examples/scripts/ddpo.py
@@ -11,59 +11,51 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.


"""
python examples/scripts/ddpo.py \
--num_epochs=200 \
--train_gradient_accumulation_steps=1 \
--sample_num_steps=50 \
--sample_batch_size=6 \
--train_batch_size=3 \
--sample_num_batches_per_epoch=4 \
--per_prompt_stat_tracking=True \
--per_prompt_stat_tracking_buffer_size=32 \
--tracker_project_name="stable_diffusion_training" \
--log_with="wandb"
"""
import os
from dataclasses import dataclass, field

import numpy as np
import torch
import torch.nn as nn
import tyro
from huggingface_hub import hf_hub_download
from huggingface_hub.utils import EntryNotFoundError
from transformers import CLIPModel, CLIPProcessor
from transformers import CLIPModel, CLIPProcessor, HfArgumentParser

from trl import DDPOConfig, DDPOTrainer, DefaultDDPOStableDiffusionPipeline
from trl.import_utils import is_npu_available, is_xpu_available


@dataclass
class ScriptArguments:
hf_user_access_token: str
pretrained_model: str = "runwayml/stable-diffusion-v1-5"
"""the pretrained model to use"""
pretrained_revision: str = "main"
"""the pretrained model revision to use"""
hf_hub_model_id: str = "ddpo-finetuned-stable-diffusion"
"""HuggingFace repo to save model weights to"""
hf_hub_aesthetic_model_id: str = "trl-lib/ddpo-aesthetic-predictor"
"""HuggingFace model ID for aesthetic scorer model weights"""
hf_hub_aesthetic_model_filename: str = "aesthetic-model.pth"
"""HuggingFace model filename for aesthetic scorer model weights"""
use_lora: bool = True
"""Whether to use LoRA."""

ddpo_config: DDPOConfig = field(
default_factory=lambda: DDPOConfig(
num_epochs=200,
train_gradient_accumulation_steps=1,
sample_num_steps=50,
sample_batch_size=6,
train_batch_size=3,
sample_num_batches_per_epoch=4,
per_prompt_stat_tracking=True,
per_prompt_stat_tracking_buffer_size=32,
tracker_project_name="stable_diffusion_training",
log_with="wandb",
project_kwargs={
"logging_dir": "./logs",
"automatic_checkpoint_naming": True,
"total_limit": 5,
"project_dir": "./save",
},
)
pretrained_model: str = field(
default="runwayml/stable-diffusion-v1-5", metadata={"help": "the pretrained model to use"}
)
pretrained_revision: str = field(default="main", metadata={"help": "the pretrained model revision to use"})
hf_hub_model_id: str = field(
default="ddpo-finetuned-stable-diffusion", metadata={"help": "HuggingFace repo to save model weights to"}
)
hf_hub_aesthetic_model_id: str = field(
default="trl-lib/ddpo-aesthetic-predictor",
metadata={"help": "HuggingFace model ID for aesthetic scorer model weights"},
)
hf_hub_aesthetic_model_filename: str = field(
default="aesthetic-model.pth",
metadata={"help": "HuggingFace model filename for aesthetic scorer model weights"},
)
use_lora: bool = field(default=True, metadata={"help": "Whether to use LoRA."})


class MLP(nn.Module):
@@ -192,14 +184,21 @@ def image_outputs_logger(image_data, global_step, accelerate_logger):


if __name__ == "__main__":
args = tyro.cli(ScriptArguments)
parser = HfArgumentParser((ScriptArguments, DDPOConfig))
args, ddpo_config = parser.parse_args_into_dataclasses()
ddpo_config.project_kwargs = {
"logging_dir": "./logs",
"automatic_checkpoint_naming": True,
"total_limit": 5,
"project_dir": "./save",
}
Comment on lines +189 to +194 (Contributor Author): required. Otherwise errors out.

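The post-parse assignment the comment refers to exists because `project_kwargs` is a dict-valued setting, which is awkward to supply as flat command-line flags; per the review note, omitting it makes the run error out. A small sketch of the same parse-then-override pattern, with illustrative values rather than the script's exact ones:

```python
from transformers import HfArgumentParser

from trl import DDPOConfig

# Parse the flat, scalar fields from the CLI, then fill in structured values
# (dicts, nested objects) in code, where argparse-style flags are a poor fit.
parser = HfArgumentParser(DDPOConfig)
(ddpo_config,) = parser.parse_args_into_dataclasses(args=["--num_epochs", "200"])
ddpo_config.project_kwargs = {"logging_dir": "./logs", "project_dir": "./save"}
```
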
pipeline = DefaultDDPOStableDiffusionPipeline(
args.pretrained_model, pretrained_model_revision=args.pretrained_revision, use_lora=args.use_lora
)

trainer = DDPOTrainer(
args.ddpo_config,
ddpo_config,
aesthetic_scorer(args.hf_hub_aesthetic_model_id, args.hf_hub_aesthetic_model_filename),
prompt_fn,
pipeline,
@@ -208,4 +207,4 @@ def image_outputs_logger(image_data, global_step, accelerate_logger):

trainer.train()

trainer.push_to_hub(args.hf_hub_model_id, token=args.hf_user_access_token)
trainer.push_to_hub(args.hf_hub_model_id)
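
With the explicit `hf_user_access_token` argument gone, `push_to_hub` relies on credentials the Hugging Face Hub client can resolve on its own (a cached login or environment token). One way to provide them ahead of time, shown here as an assumption about the intended workflow rather than something this PR documents:

```python
from huggingface_hub import login

# Store a token locally so that trainer.push_to_hub() can authenticate
# without the script taking the token as a CLI argument.
# Equivalent to running `huggingface-cli login` once in the shell.
login(token="hf_xxx")  # replace with your own token
```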