[feat] Ray train integration #312

Merged 67 commits on Mar 31, 2023

Commits
cd1d062
WIP
Yard1 Jan 6, 2023
ea38369
WIP
Yard1 Jan 13, 2023
ca6d7d1
Cleanup
Yard1 Jan 14, 2023
698985a
Nit
Yard1 Jan 14, 2023
993525b
Merge branch 'CarperAI:main' into ray_train_integration_2
Yard1 Jan 15, 2023
649c01e
Cleanup
Yard1 Jan 17, 2023
0793c95
Fixes
Yard1 Jan 17, 2023
f3bd70b
Merge branch 'main' into ray_train_integration_2
Yard1 Jan 17, 2023
9bdcbfe
Make sure master_port, master_addr are set
Yard1 Jan 17, 2023
496c891
Make private
Yard1 Jan 17, 2023
55022f2
Tweak
Yard1 Jan 18, 2023
df0530e
Restore wanddb
Yard1 Jan 18, 2023
7be5448
Add ray.init() back, remove unnecesary code
Yard1 Jan 18, 2023
fe0bf79
Merge branch 'main' into ray_train_integration_2
Yard1 Jan 27, 2023
b885912
Set ACCELERATE_TORCH_DEVICE
Yard1 Jan 27, 2023
bd9a769
Merge branch 'main' into ray_train_integration_2
Yard1 Feb 7, 2023
864d757
refactor(ray_tune): collapse files into `sweep.py` & fix w&b reports
maxreciprocate Feb 16, 2023
2cda942
feat(configs): add & revert back flat updating of the config
maxreciprocate Feb 16, 2023
bbe2ecb
fix(base_trainer): reenable w&b logging through ray-tune
maxreciprocate Feb 16, 2023
0092f60
revert(base_trainer): remove trlx's verbosity limit when under ray
maxreciprocate Feb 16, 2023
88e8d86
refactor(sweep): flatten code & remove debug prints
maxreciprocate Feb 16, 2023
caf13be
chore(configs/ppo_sweep): update variable names to nested structure
maxreciprocate Feb 16, 2023
fc13a3d
Merge branch 'main' into ray-train-integration
maxreciprocate Feb 16, 2023
b5b59ad
style(ray_trainer): satisfy black
maxreciprocate Feb 16, 2023
deb667a
style(sweep): satisfy isort
maxreciprocate Feb 16, 2023
55260fb
Merge branch 'main' into ray-train-integration
maxreciprocate Feb 22, 2023
9082cce
chore(configs/sweeps): update variable names to the nested structure
maxreciprocate Feb 27, 2023
02c0976
chore(ilql_sentiments): restructure config loading for sweeps
maxreciprocate Feb 27, 2023
9c39c1c
fix(sweep): rework `best_config` block
maxreciprocate Feb 27, 2023
f500740
chore(configs): disable `scheduler` in default configs
maxreciprocate Feb 27, 2023
2cf94cf
fix(accelerate_trainer): device mismatch from `get_device`
maxreciprocate Feb 27, 2023
ebe6308
chore(base_trainer): lower verbosity when under sweep
maxreciprocate Feb 27, 2023
0a1303e
chore(ppo_trainer): reenable w&b logging in `make_experience`
maxreciprocate Feb 27, 2023
1ac7b99
feat(scripts): add an example of setting up ray cluster on slurm
maxreciprocate Feb 27, 2023
b3664b6
style(sweep): satisfy isort
maxreciprocate Feb 27, 2023
614af8c
revert(base_trainer): remove logging verbosity changes under ray
maxreciprocate Feb 28, 2023
697217a
Merge branch 'main' into ray-train-integration
maxreciprocate Mar 1, 2023
77a71ca
Revert "Merge branch 'main' into ray-train-integration"
maxreciprocate Mar 2, 2023
840c9b2
Fix device mismatch, update to accelerate>=0.17.0 (#360)
Yard1 Mar 10, 2023
7c8f1aa
Merge branch 'main' into ray-train-integration
maxreciprocate Mar 13, 2023
812569b
merge(trainers): unmerge old merges, merge anew recent changes
maxreciprocate Mar 13, 2023
786709b
merge(models): merge renaming of the directory
maxreciprocate Mar 13, 2023
3377293
feat(base_trainer): enable w&b logging under ray
maxreciprocate Mar 14, 2023
0cf350e
feat(sweep): remove `default_config` from argparse
maxreciprocate Mar 14, 2023
722be0d
fix(default_configs): disable schedulers
maxreciprocate Mar 14, 2023
4cd3e61
feat(setup.cfg): pin `ray` wheel, update `accelerate` `deepspeed`
maxreciprocate Mar 14, 2023
a7e7bb4
merge: revert to upstream changes
maxreciprocate Mar 14, 2023
6fd5aef
fix(scripts/sweep): remove `default_config`
maxreciprocate Mar 15, 2023
ab5a860
merge(configs): remove yml files
maxreciprocate Mar 15, 2023
a189067
Merge branch 'main' into ray-train-integration
maxreciprocate Mar 15, 2023
9b262a3
merge(examples): upstream config usage
maxreciprocate Mar 15, 2023
c7ac679
fix(setup.cfg): condition ray's pinned wheel
maxreciprocate Mar 15, 2023
f9875f0
Use `AccelerateTrainer` from Ray (#386)
Yard1 Mar 23, 2023
8027bd3
Merge branch 'main' into ray-train-integration
maxreciprocate Mar 29, 2023
ee63dd9
chore(sweep): explicitly pin a GPU per worker
maxreciprocate Mar 30, 2023
a677e25
fix(base_trainer): remove checkpointing while under ray
maxreciprocate Mar 30, 2023
23d2c69
chore(README): update sweep instructions
maxreciprocate Mar 30, 2023
c69166a
feat(configs/sweeps): update with more values
maxreciprocate Mar 30, 2023
e14488b
style: satisfy flake
maxreciprocate Mar 30, 2023
81fd0e7
style: satisfy flake
maxreciprocate Mar 30, 2023
55a5aa9
revert(examples): remove redundant device selection
maxreciprocate Mar 30, 2023
962dfa3
style: satisfy black
maxreciprocate Mar 30, 2023
fd6f9c1
fix(base_trainer): report evaluation stats at the end
maxreciprocate Mar 31, 2023
69b0bb8
fix(base_trainer): log final stats for w&b
maxreciprocate Mar 31, 2023
45feae6
fix(examples/sentiments): improve hyperparameters
maxreciprocate Mar 31, 2023
ec638c3
style: satisfy black
maxreciprocate Mar 31, 2023
6c7e6f9
fix(README): add ray cluster manual creation instruction
maxreciprocate Mar 31, 2023
3 changes: 2 additions & 1 deletion README.md
@@ -104,7 +104,8 @@ For more usage see the [NeMo README](./trlx/models)
 #### Use Ray Tune to launch hyperparameter sweep

 ```bash
-python -m trlx.sweep --config configs/sweeps/ppo_sweep.yml examples/ppo_sentiments.py
+ray start --head --port=6379
+python -m trlx.sweep --config configs/sweeps/ppo_sweep.yml --accelerate_config configs/accelerate/ddp.yaml --num_gpus 4 examples/ppo_sentiments.py
 ```

 #### Benchmark your trlX fork against trlX's `main` branch
2 changes: 1 addition & 1 deletion configs/accelerate/ddp.yaml
@@ -2,7 +2,7 @@ compute_environment: LOCAL_MACHINE
 deepspeed_config: {}
 distributed_type: MULTI_GPU
 downcast_bf16: no
-dynamo_backend: 'NO'
+dynamo_config: {}
 fsdp_config: {}
 gpu_ids: all
 machine_rank: 0
2 changes: 1 addition & 1 deletion configs/accelerate/zero2-bf16.yaml
@@ -9,7 +9,7 @@ deepspeed_config:
   zero_stage: 2
 distributed_type: DEEPSPEED
 downcast_bf16: no
-dynamo_backend: 'NO'
+dynamo_config: {}
 fsdp_config: {}
 machine_rank: 0
 main_training_function: main
2 changes: 1 addition & 1 deletion configs/accelerate/zero2-fp16.yaml
@@ -9,7 +9,7 @@ deepspeed_config:
   zero_stage: 2
 distributed_type: DEEPSPEED
 downcast_bf16: no
-dynamo_backend: 'NO'
+dynamo_config: {}
 fsdp_config: {}
 machine_rank: 0
 main_training_function: main
2 changes: 1 addition & 1 deletion configs/accelerate/zero3.yaml
@@ -10,7 +10,7 @@ deepspeed_config:
   zero_stage: 3
 distributed_type: DEEPSPEED
 downcast_bf16: no
-dynamo_backend: 'NO'
+dynamo_config: {}
 fsdp_config: {}
 machine_rank: 0
 main_training_function: main
21 changes: 15 additions & 6 deletions configs/sweeps/ilql_sweep.yml
@@ -3,17 +3,26 @@ tune_config:
   metric: "metrics/sentiments"
   search_alg: "random"
   scheduler: "fifo"
-  num_samples: 32
+  num_samples: 64

-lr:
+# https://docs.ray.io/en/latest/tune/api_docs/search_space.html#tune-sample-docs
+optimizer.kwargs.lr:
   strategy: "loguniform"
-  values: [0.00001, 0.01]
-tau:
+  values: [0.000001, 0.001]
+method.tau:
   strategy: "uniform"
   values: [0.6, 0.9]
-steps_for_target_q_sync:
+method.steps_for_target_q_sync:
   strategy: "choice"
   values: [1, 5, 10]
-alpha:
+method.alpha:
   strategy: "loguniform"
   values: [0.001, 1.0]
+
+# disable checkpointing for storage sake
+train.checkpoint_interval:
+  strategy: "choice"
+  values: [10000000]
+train.save_best:
+  strategy: "choice"
+  values: [false]
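
For readers unfamiliar with the three `strategy` names used in this sweep file, the sketch below illustrates what each one draws from, using only Python's standard library. It shows the sampling semantics only and is not Ray Tune's implementation; `SPACE` and `sample_param` are hypothetical names, and the ranges are copied from the updated `ilql_sweep.yml`.

```python
import math
import random

# Search space copied from the updated configs/sweeps/ilql_sweep.yml.
SPACE = {
    "optimizer.kwargs.lr": {"strategy": "loguniform", "values": [0.000001, 0.001]},
    "method.tau": {"strategy": "uniform", "values": [0.6, 0.9]},
    "method.steps_for_target_q_sync": {"strategy": "choice", "values": [1, 5, 10]},
    "method.alpha": {"strategy": "loguniform", "values": [0.001, 1.0]},
    "train.save_best": {"strategy": "choice", "values": [False]},
}


def sample_param(spec, rng):
    """Draw one value for a parameter spec (hypothetical helper, not a Ray API)."""
    strategy, values = spec["strategy"], spec["values"]
    if strategy == "uniform":
        low, high = values
        return rng.uniform(low, high)
    if strategy == "loguniform":
        # Uniform in log-space, so each order of magnitude is equally likely.
        low, high = values
        return math.exp(rng.uniform(math.log(low), math.log(high)))
    if strategy == "choice":
        return rng.choice(values)
    raise ValueError(f"unknown strategy: {strategy}")


rng = random.Random(0)
trial = {name: sample_param(spec, rng) for name, spec in SPACE.items()}
for name, value in trial.items():
    print(f"{name} = {value}")
```

A single-element `choice` list, as used for `train.checkpoint_interval` and `train.save_best`, always yields that element, which is how the sweep file pins a setting for every trial.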
30 changes: 22 additions & 8 deletions configs/sweeps/ppo_sweep.yml
@@ -6,12 +6,26 @@ tune_config:
   num_samples: 32

 # https://docs.ray.io/en/latest/tune/api_docs/search_space.html#tune-sample-docs
-lr:
+optimizer.kwargs.lr:
   strategy: "loguniform"
-  values: [0.00001, 0.01]
-init_kl_coef:
-  strategy: "uniform"
-  values: [0, 0.2]
-vf_coef:
-  strategy: "uniform"
-  values: [0.5, 2]
+  values: [0.000001, 0.001]
+method.init_kl_coef:
+  strategy: "loguniform"
+  values: [0.0001, 0.2]
+model.num_layers_unfrozen:
+  strategy: "choice"
+  values: [-1, 2, 6]
+method.num_rollouts:
+  strategy: "choice"
+  values: [32, 128, 512]
+method.target:
+  strategy: "choice"
+  values: [null, 1]
+
+# disable checkpointing for storage sake
+train.checkpoint_interval:
+  strategy: "choice"
+  values: [10000000]
+train.save_best:
+  strategy: "choice"
+  values: [false]
2 changes: 1 addition & 1 deletion examples/ilql_sentiments.py
@@ -37,7 +37,7 @@ def metric_fn(samples: List[str], **kwargs) -> Dict[str, List[float]]:
     trlx.train(
         samples=imdb["text"],
         rewards=imdb["label"],
-        eval_prompts=["I don't know much about Hungarian underground"] * 64,
+        eval_prompts=["I don't know much about Hungarian underground"] * 256,
         metric_fn=metric_fn,
         config=config,
     )
2 changes: 1 addition & 1 deletion examples/ppo_sentiments.py
@@ -47,7 +47,7 @@ def reward_fn(samples: List[str], **kwargs) -> List[float]:
     trlx.train(
         reward_fn=reward_fn,
         prompts=prompts,
-        eval_prompts=["I don't know much about Hungarian underground"] * 64,
+        eval_prompts=["I don't know much about Hungarian underground"] * 256,
         config=config,
     )

@@ -6,7 +6,7 @@ deepspeed_config:
   zero3_init_flag: false
 distributed_type: DEEPSPEED
 downcast_bf16: 'no'
-dynamo_backend: 'NO'
+dynamo_config: {}
 fsdp_config: {}
 gpu_ids: null
 machine_rank: 0
40 changes: 40 additions & 0 deletions scripts/sweep-cw.sh
@@ -0,0 +1,40 @@
+#!/bin/bash
+#SBATCH --job-name=trlx-sweep
+#SBATCH --account=trlx
+#SBATCH --partition=a100-cu117
+#SBATCH --nodes=2
+#SBATCH --ntasks-per-node=1
+#SBATCH --mem=0
+#SBATCH --output=%j
+#SBATCH --exclusive
+
+export NCCL_DEBUG=WARN
+export NCCL_PROTO=simple
+export FI_EFA_FORK_SAFE=1
+export FI_LOG_LEVEL=1
+export FI_EFA_USE_DEVICE_RDMA=1
+export FI_EFA_ENABLE_SHM_TRANSFER=0
+export FI_PROVIDER=efa
+export FI_EFA_TX_MIN_CREDITS=64
+# export CUDA_LAUNCH_BLOCKING=1
+
+export HOSTNAMES=`scontrol show hostnames "$SLURM_JOB_NODELIST"`
+export MASTER_ADDR=$(scontrol show hostnames "$SLURM_JOB_NODELIST" | head -n 1)
+
+cd $TRLX
+source $TRLX/venv-with-pinned-ray/bin/activate
+
+ray start --head --port=6379 &
+
+export HOSTNAMES=($HOSTNAMES)
+for node in ${HOSTNAMES[@]:1}; do
+    echo "Starting ray worker @ $node"
+    srun --nodes=1 --ntasks=1 -w "$node" ray start --address $MASTER_ADDR:6379 --block &
+done
+
+sleep 10
+ray status
+
+NUM_GPUS=16
+python -m trlx.sweep -y --config configs/sweeps/ppo_sweep.yml --accelerate_config configs/accelerate/zero2-bf16.yaml --num_gpus $NUM_GPUS examples/ppo_sentiments.py
+# python -m trlx.sweep -y --config configs/sweeps/ilql_sweep.yml --default_config configs/ilql_config.yml --accelerate_config configs/accelerate/zero2-bf16.yaml --num_gpus $NUM_GPUS examples/ilql_sentiments.py
8 changes: 5 additions & 3 deletions setup.cfg
@@ -11,22 +11,24 @@ license = MIT
 [options]
 packages = find:
 install_requires =
-    accelerate>=0.16.0
+    accelerate>=0.17.1
     attrs>=22.1.0
     cattrs>=22.2.0
     datasets
-    deepspeed>=0.7.3
+    deepspeed>=0.8.1
     einops>=0.4.1
     numpy>=1.23.2
     torchtyping
     transformers>=4.21.2
     tqdm
     rich
     wandb>=0.13.5
-    ray>=2.0.1
     tabulate>=0.9.0
     networkx
     tritonclient
+    ray@https://ray-ci-artifact-branch-public.s3.amazonaws.com/42bb0357a6fb13e4994789c824f3623f32869ad8/tmp/artifacts/.whl/ray-3.0.0.dev0-cp38-cp38-manylinux2014_x86_64.whl ; python_version=="3.8"
+    ray@https://ray-ci-artifact-branch-public.s3.amazonaws.com/42bb0357a6fb13e4994789c824f3623f32869ad8/tmp/artifacts/.whl/ray-3.0.0.dev0-cp39-cp39-manylinux2014_x86_64.whl ; python_version=="3.9"
+    ray@https://ray-ci-artifact-branch-public.s3.amazonaws.com/42bb0357a6fb13e4994789c824f3623f32869ad8/tmp/artifacts/.whl/ray-3.0.0.dev0-cp310-cp310-manylinux2014_x86_64.whl ; python_version=="3.10"

 [options.extras_require]
 bnb = bitsandbytes
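
The three pinned `ray` wheel lines in `setup.cfg` each carry a `python_version` environment marker, so the installer picks at most one of them for the running interpreter. As a rough standard-library illustration of that selection (pip evaluates markers itself; `PINNED_TAGS` and `wheel_tag_for` are hypothetical names, not part of pip or setuptools):

```python
import sys

# CPython tags of the three pinned wheels listed in setup.cfg.
PINNED_TAGS = {"3.8": "cp38", "3.9": "cp39", "3.10": "cp310"}


def wheel_tag_for(version_info=sys.version_info):
    """Return the wheel tag whose python_version marker matches, or None."""
    # The python_version marker has the form "<major>.<minor>".
    python_version = f"{version_info[0]}.{version_info[1]}"
    return PINNED_TAGS.get(python_version)


print(wheel_tag_for((3, 9, 16)))  # matches the cp39 wheel
print(wheel_tag_for((3, 11, 0)))  # no pinned wheel matches
```

Under this reading, an interpreter outside 3.8 to 3.10 matches none of the markers, and because the unconditional `ray>=2.0.1` line is removed, these lines alone would not install `ray` there. The wheels are also tagged `manylinux2014_x86_64`.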
18 changes: 16 additions & 2 deletions trlx/data/configs.py
@@ -295,13 +295,27 @@ def from_dict(cls, config: Dict):

     @classmethod
     def update(cls, baseconfig: Dict, config: Dict):
+        update = {}
+        # unflatten a string variable name into a nested dictionary
+        # key1.key2.key3: value -> {key1: {key2: {key3: value}}}
+        for name, value in config.items():
+            if isinstance(value, dict):
+                update[name] = value
+            else:
+                *layers, var = name.split(".")
+                if layers:
+                    d = update.setdefault(layers[0], {})
+                    for layer in layers[1:]:
+                        d = d.setdefault(layer, {})
+                    d[var] = value
+
         if not isinstance(baseconfig, Dict):
             baseconfig = baseconfig.to_dict()

         updates = set()
-        merged = merge(baseconfig, config, updates)
+        merged = merge(baseconfig, update, updates)

-        for param in config:
+        for param in update:
             if param not in updates:
                 raise ValueError(f"parameter {param} is not present in the config (typo or a wrong config)")

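
To see what the new unflattening loop in `update` does with sweep-style dotted keys, here is a standalone sketch of the same loop. `unflatten` is a hypothetical name used for illustration; in the PR the logic sits inline in the `update` classmethod and is followed by a `merge` call that is not reproduced here.

```python
from typing import Any, Dict


def unflatten(config: Dict[str, Any]) -> Dict[str, Any]:
    """Turn {"a.b.c": v} into {"a": {"b": {"c": v}}}, mirroring the loop in the diff."""
    update: Dict[str, Any] = {}
    for name, value in config.items():
        if isinstance(value, dict):
            # Already-nested sections are passed through unchanged.
            update[name] = value
        else:
            *layers, var = name.split(".")
            if layers:
                # Walk (or create) one nested dict per dotted prefix.
                d = update.setdefault(layers[0], {})
                for layer in layers[1:]:
                    d = d.setdefault(layer, {})
                d[var] = value
    return update


flat = {
    "optimizer.kwargs.lr": 1e-5,
    "method.init_kl_coef": 0.01,
    "train.save_best": False,
}
print(unflatten(flat))
```

In this sketch, a non-dict value whose key contains no dot (for example a bare `lr`) is skipped by the loop, so it never reaches the later check that raises on unknown parameters. That may be worth keeping in mind when migrating older flat sweep files to the dotted names.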
16 changes: 8 additions & 8 deletions trlx/data/default_configs.py
@@ -27,16 +27,16 @@ def default_ppo_config():
         model=ModelConfig(model_path="lvwerra/gpt2-imdb", num_layers_unfrozen=2),
         tokenizer=TokenizerConfig(tokenizer_path="gpt2", truncation_side="right"),
         optimizer=OptimizerConfig(
-            name="adamw", kwargs=dict(lr=1.0e-4, betas=(0.9, 0.95), eps=1.0e-8, weight_decay=1.0e-6)
+            name="adamw", kwargs=dict(lr=3e-5, betas=(0.9, 0.95), eps=1.0e-8, weight_decay=1.0e-6)
         ),
-        scheduler=SchedulerConfig(name="cosine_annealing", kwargs=dict(T_max=10000, eta_min=1.0e-4)),
+        scheduler=SchedulerConfig(name="cosine_annealing", kwargs=dict(T_max=1e12, eta_min=3e-5)),
         method=PPOConfig(
             name="PPOConfig",
             num_rollouts=128,
             chunk_size=128,
             ppo_epochs=4,
-            init_kl_coef=0.05,
-            target=6,
+            init_kl_coef=0.001,
+            target=None,
             horizon=10000,
             gamma=1,
             lam=0.95,
@@ -61,7 +61,7 @@ def default_ilql_config():
     return TRLConfig(
         train=TrainConfig(
             seq_length=64,
-            batch_size=32,
+            batch_size=128,
             epochs=100,
             total_steps=1000,
             checkpoint_interval=1000,
@@ -75,7 +75,7 @@ def default_ilql_config():
             name="adamw", kwargs=dict(lr=5.0e-5, betas=(0.9, 0.95), eps=1.0e-8, weight_decay=1.0e-6)
         ),
         scheduler=SchedulerConfig(
-            name="cosine_annealing", kwargs=dict(T_max=1000, eta_min=5.0e-5)  # train.total_steps
+            name="cosine_annealing", kwargs=dict(T_max=1e12, eta_min=5.0e-5)  # train.total_steps
         ),
         method=ILQLConfig(
             name="ilqlconfig",
@@ -87,7 +87,7 @@ def default_ilql_config():
             beta=0,
             steps_for_target_q_sync=5,
             two_qs=True,
-            gen_kwargs=dict(max_new_tokens=56, top_k=20, beta=4, temperature=1.0),
+            gen_kwargs=dict(max_new_tokens=56, top_k=20, beta=1, temperature=1.0),
         ),
     )

@@ -110,7 +110,7 @@ def default_sft_config():
             name="adamw", kwargs=dict(lr=1.0e-4, betas=(0.9, 0.95), eps=1.0e-8, weight_decay=1.0e-6)
         ),
         scheduler=SchedulerConfig(
-            name="cosine_annealing", kwargs=dict(T_max=10000, eta_min=1.0e-4)  # train.total_steps
+            name="cosine_annealing", kwargs=dict(T_max=1e12, eta_min=1.0e-4)  # train.total_steps
         ),
         method=SFTConfig(
             name="sftconfig",
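
One reading of the `T_max=1e12` changes, in line with the commit titled "chore(configs): disable `scheduler` in default configs", is that the cosine period is stretched far beyond any realistic run, so the learning rate stays at its starting value even when a sweep samples an `lr` that differs from `eta_min`. The sketch below evaluates the standard closed-form cosine annealing schedule, eta_t = eta_min + (eta_max - eta_min) * (1 + cos(pi * t / T_max)) / 2, in plain Python. `cosine_lr` is a hypothetical helper, `sampled_lr` is an illustrative value rather than one taken from the diff, and the assumption is that the `cosine_annealing` scheduler follows this formula.

```python
import math


def cosine_lr(step, base_lr, eta_min, t_max):
    """Closed-form cosine annealing from base_lr down to eta_min over t_max steps."""
    return eta_min + (base_lr - eta_min) * (1 + math.cos(math.pi * step / t_max)) / 2


steps = (0, 5_000, 10_000)
sampled_lr = 1e-3  # illustrative value a sweep might draw; not from the diff

# Old PPO scheduler kwargs: the rate decays to eta_min by step 10,000.
old = [cosine_lr(s, sampled_lr, 1.0e-4, 10_000) for s in steps]

# New PPO scheduler kwargs: the rate is effectively constant over the same steps.
new = [cosine_lr(s, sampled_lr, 3e-5, 1e12) for s in steps]

print("old:", old)
print("new:", new)
```

With the default values themselves, `lr` equals `eta_min` in each of the three configs, so the schedule is flat regardless of `T_max`; the larger `T_max` matters once `optimizer.kwargs.lr` is overridden.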