Merge pull request #270 from WentseChen/1102_ds
DeepSpeed support
huangshiyu13 authored Nov 9, 2023
2 parents 6470b87 + b6e78a2 commit b44efad
Showing 16 changed files with 168 additions and 55 deletions.
2 changes: 2 additions & 0 deletions .gitignore
@@ -156,6 +156,8 @@ api_docs
*.json
opponent_pool
!/examples/selfplay/opponent_templates/tictactoe_opponent/info.json
!/examples/nlp/ds_config.json
!/examples/nlp/eval_ds_config.json
wandb_run
examples/dmc/new.gif
/examples/snake/submissions/rl/actor_2000.pth
8 changes: 8 additions & 0 deletions examples/nlp/README.md
@@ -6,6 +6,14 @@ Users can train the dialog task via:
python train_ppo.py --config nlp_ppo.yaml
```

Users can train the dialog task with DeepSpeed via:

```shell
deepspeed train_ppo.py --config nlp_ppo_ds.yaml
```
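For context, the `deepspeed` command is DeepSpeed's launcher: it typically spawns one worker process per visible GPU (forwarding a `--local_rank` argument to each), while the DeepSpeed settings themselves are picked up from the `deepspeed_config: ds_config.json` entry in `nlp_ppo_ds.yaml` introduced by this change.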

After the training, users can chat with the agent via:

```shell
11 changes: 11 additions & 0 deletions examples/nlp/ds_config.json
@@ -0,0 +1,11 @@
{
"train_batch_size": 32,
"train_micro_batch_size_per_gpu": 16,
"steps_per_print": 10,
"zero_optimization": {
"stage": 2,
"reduce_bucket_size": 5e7,
"allgather_bucket_size": 5e7
},
"fp16": {"enabled": false, "loss_scale_window": 100}
}
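As a point of reference (not part of this commit), DeepSpeed requires `train_batch_size` to equal `train_micro_batch_size_per_gpu × gradient_accumulation_steps × world_size`, so the values above (32 and 16, with accumulation left at its default of 1) assume a two-GPU launch. A minimal sketch of feeding such a file to `deepspeed.initialize`, with a toy model and optimizer standing in for the project's networks and meant to be run under the `deepspeed` launcher:

```python
import json

import deepspeed
import torch
import torch.nn as nn

# Toy stand-in for the policy/value network that train_ppo.py builds.
model = nn.Linear(768, 2)
optimizer = torch.optim.AdamW(model.parameters(), lr=1e-7)

with open("examples/nlp/ds_config.json") as f:
    ds_config = json.load(f)

# ZeRO stage 2 shards optimizer states and gradients across ranks.
# initialize() returns (engine, optimizer, dataloader, lr_scheduler);
# the engine replaces the plain model for forward/backward/step calls.
engine, optimizer, _, _ = deepspeed.initialize(
    model=model, optimizer=optimizer, config=ds_config
)
```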
10 changes: 10 additions & 0 deletions examples/nlp/eval_ds_config.json
@@ -0,0 +1,10 @@
{
"train_batch_size": 32,
"train_micro_batch_size_per_gpu": 16,
"steps_per_print": 10,
"zero_optimization": {
"stage": 0,
"offload_param": {"device": "cpu"}
},
"fp16": {"enabled": false}
}
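This evaluation config (ZeRO stage 0, fp16 disabled) is the one consumed by the frozen reward and reference models further down in this diff: they are wrapped with `deepspeed.initialize` without an optimizer and only run forward passes. One caveat worth noting: DeepSpeed honors `offload_param` only at ZeRO stage 3, so at stage 0 that entry is effectively inert.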
14 changes: 6 additions & 8 deletions examples/nlp/nlp_ppo.yaml
@@ -1,19 +1,18 @@
seed: 0
lr: 2e-7
critic_lr: 2e-7
lr: 1e-7
critic_lr: 1e-7
run_dir: ./run_results/
log_interval: 1
use_valuenorm: true
use_adv_normalize: true
wandb_entity: "openrl-lab"
ppo_epoch: 5
episode_length: 112
episode_length: 128
num_mini_batch: 20
use_share_model: true

hidden_size: 1
use_deepspeed: true
use_fp16: true


model_path: rajkumarrrk/gpt2-fine-tuned-on-daily-dialog
env:
@@ -25,9 +24,8 @@ vec_info_class:
id: "NLPVecInfo"
reward_class:
id: "NLPReward"
args: {
"intent_model": "rajkumarrrk/roberta-daily-dialog-intent-classifier",
args: {
"ref_model": "rajkumarrrk/gpt2-fine-tuned-on-daily-dialog",
"use_deepspeed": true,
"intent_model": "rajkumarrrk/roberta-daily-dialog-intent-classifier",
}

38 changes: 38 additions & 0 deletions examples/nlp/nlp_ppo_ds.yaml
@@ -0,0 +1,38 @@
seed: 0
lr: 1e-7
critic_lr: 1e-7
run_dir: ./run_results/
log_interval: 1
use_valuenorm: true
use_adv_normalize: true
wandb_entity: "openrl-lab"
ppo_epoch: 5
episode_length: 128
num_mini_batch: 20
use_share_model: true

hidden_size: 1

use_deepspeed: true
use_fp16: false
use_offload: false
deepspeed_config: ds_config.json

model_path: rajkumarrrk/gpt2-fine-tuned-on-daily-dialog
env:
args: {
'tokenizer_path': 'gpt2',
'data_path': 'daily_dialog',
}
vec_info_class:
id: "NLPVecInfo"
reward_class:
id: "NLPReward"
args: {
"use_deepspeed": true,
"ref_ds_config": "eval_ds_config.json",
"ref_model": "rajkumarrrk/gpt2-fine-tuned-on-daily-dialog",
"intent_ds_config": "eval_ds_config.json",
"intent_model": "rajkumarrrk/roberta-daily-dialog-intent-classifier",
}
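The `reward_class.args` mapping above is handed to the reward constructors changed later in this diff, with `intent_ds_config` and `ref_ds_config` presumably ending up as the `ds_config` arguments of `Intent` and the KL penalty respectively. Since those values are opened directly with `open()`, the relative paths assume the command is launched from `examples/nlp/` (or should be adjusted accordingly).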

6 changes: 6 additions & 0 deletions examples/nlp/train_ppo.py
@@ -12,6 +12,12 @@
def train():
# create environment
cfg_parser = create_config_parser()
try:
import deepspeed

cfg_parser = deepspeed.add_config_arguments(cfg_parser)
except ImportError:
print("DeepSpeed is not installed; running the NLP task without it.")
cfg = cfg_parser.parse_args()

env_num = 5
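For readers unfamiliar with it, `deepspeed.add_config_arguments` augments an existing parser with DeepSpeed's standard flags (`--deepspeed`, `--deepspeed_config`, and related options); a rough standalone sketch using plain argparse rather than OpenRL's config parser:

```python
import argparse

import deepspeed

parser = argparse.ArgumentParser()
parser.add_argument("--config", type=str, default="nlp_ppo_ds.yaml")

# Adds --deepspeed, --deepspeed_config and related flags so the same
# script can be driven by the `deepspeed` launcher or by plain `python`.
parser = deepspeed.add_config_arguments(parser)

args = parser.parse_args([])  # empty argv: just inspect the defaults
print(args.deepspeed, args.deepspeed_config)
```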
8 changes: 7 additions & 1 deletion openrl/configs/config.py
@@ -1246,11 +1246,17 @@ def create_config_parser():
type=int,
help="local_rank",
)
parser.add_argument(
"--use_offload",
default=False,
type=bool,
help="whether to use offload (deepspeed)",
)
parser.add_argument(
"--use_fp16",
default=False,
type=bool,
help="whether to use fp16",
help="whether to use fp16 (deepspeed)",
)

return parser
25 changes: 18 additions & 7 deletions openrl/envs/nlp/rewards/intent.py
@@ -9,24 +9,28 @@
from openrl.supports.opengpu.manager import LocalGPUManager


def get_eval_ds_config(offload, stage=0):
def get_default_ds_config(offload=True, stage=0, fp16=True):
device = "cpu" if offload else "none"
zero_opt_dict = {
"stage": stage,
"offload_param": {"device": device},
}
return {
"train_batch_size": 28,
"train_micro_batch_size_per_gpu": 7,
"train_batch_size": 16,
"train_micro_batch_size_per_gpu": 16,
"steps_per_print": 10,
"zero_optimization": zero_opt_dict,
"fp16": {"enabled": True},
"fp16": {"enabled": fp16},
}


class Intent:
def __init__(
self, intent_model: str, intent_coeff: float = 1.0, use_deepspeed: bool = True
self,
intent_model: str,
intent_coeff: float = 1.0,
use_deepspeed: bool = True,
ds_config: str = "default",
) -> None:
super().__init__()

@@ -65,10 +69,17 @@ def __init__(self, input_ids, attention_mask):
if self.use_deepspeed:
import deepspeed

if ds_config == "default":
ds_config = get_default_ds_config()
else:
import json

with open(ds_config) as file:
ds_config = json.load(file)

self._device = "cuda"
self._model = self._model.to("cuda")
ds_config = get_eval_ds_config(offload=True, stage=0)
self._model, *_ = deepspeed.initialize(model=self._model, config=ds_config)
self._device = "cuda"
else:
if torch.cuda.is_available():
manager = LocalGPUManager()
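Given the new signature, a custom evaluation config can be passed to the intent reward roughly as follows (a sketch only: constructing `Intent` downloads and loads the classifier, and the coefficient value here is illustrative):

```python
from openrl.envs.nlp.rewards.intent import Intent

# ds_config="default" keeps the built-in get_default_ds_config();
# a path loads that JSON file instead, as in the constructor above.
intent_reward = Intent(
    intent_model="rajkumarrrk/roberta-daily-dialog-intent-classifier",
    intent_coeff=1.0,
    use_deepspeed=True,
    ds_config="examples/nlp/eval_ds_config.json",
)
```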
35 changes: 24 additions & 11 deletions openrl/envs/nlp/rewards/kl_penalty.py
@@ -10,18 +10,18 @@
from openrl.envs.nlp.utils.distribution import CategoricalDistribution


def get_eval_ds_config(offload, stage=0):
def get_default_ds_config(offload=True, stage=0, fp16=True):
device = "cpu" if offload else "none"
zero_opt_dict = {
"stage": stage,
"offload_param": {"device": device},
}
return {
"train_batch_size": 28, #
"train_micro_batch_size_per_gpu": 7,
"train_batch_size": 16,
"train_micro_batch_size_per_gpu": 16,
"steps_per_print": 10,
"zero_optimization": zero_opt_dict,
"fp16": {"enabled": True},
"fp16": {"enabled": fp16},
}


@@ -32,10 +32,10 @@ def __init__(
ref_model: str,
apply_model_parallel: bool = True,
use_deepspeed: bool = True,
ds_config: str = "default",
):
super().__init__()
self.use_deepspeed = use_deepspeed
self.use_fp16 = True

# reference model
self._apply_model_parallel = apply_model_parallel
@@ -50,7 +50,19 @@ def __init__(
if self.use_deepspeed:
import deepspeed

ds_config = get_eval_ds_config(offload=True, stage=0)
if ds_config == "default":
self.use_fp16 = True
ds_config = get_default_ds_config()
else:
import json

with open(ds_config) as file:
ds_config = json.load(file)
if "fp16" in ds_config:
self.use_fp16 = ds_config["fp16"]["enabled"]
else:
self.use_fp16 = False

self._ref_engine, *_ = deepspeed.initialize(model=self, config=ds_config)
elif torch.cuda.is_available():
if self._apply_model_parallel and self._ref_net.is_parallelizable:
@@ -94,11 +106,12 @@ def __call__(
self._ref_net, input_ids, past_model_kwargs
)

if self.use_fp16:
for key in ["input_ids", "position_ids"]:
model_inputs[key] = model_inputs[key].half().int()
for key in ["attention_mask"]:
model_inputs[key] = model_inputs[key].half()
if self.use_deepspeed:
if self.use_fp16:
for key in ["input_ids", "position_ids"]:
model_inputs[key] = model_inputs[key].half().int()
for key in ["attention_mask"]:
model_inputs[key] = model_inputs[key].half()

with torch.no_grad():
output = self._ref_net(output_hidden_states=True, **model_inputs)
Expand Down
4 changes: 3 additions & 1 deletion openrl/modules/networks/policy_network.py
@@ -53,10 +53,12 @@ def __init__(
self._influence_layer_N = cfg.influence_layer_N
self._use_policy_vhead = cfg.use_policy_vhead
self._recurrent_N = cfg.recurrent_N
self._use_fp16 = cfg.use_fp16 and cfg.use_deepspeed
self.use_half = use_half
self.tpdv = dict(dtype=torch.float32, device=device)

self._use_fp16 = cfg.use_fp16
assert not (cfg.use_fp16 and not cfg.use_deepspeed)

policy_obs_shape = get_policy_obs_space(input_space)

if "Dict" in policy_obs_shape.__class__.__name__:
5 changes: 4 additions & 1 deletion openrl/modules/networks/policy_value_network_gpt.py
@@ -37,15 +37,18 @@ def __init__(
self.disable_drop_out = disable_drop_out
self._use_valuenorm = cfg.use_valuenorm
super(CausalLMActorCriticPolicy, self).__init__(
cfg,
input_space,
action_space,
model_name=cfg.model_path,
device=device,
)
self.use_half = use_half
self._use_fp16 = cfg.use_fp16 and cfg.use_deepspeed
self.tpdv = dict(dtype=torch.float32, device=device)

self._use_fp16 = cfg.use_fp16
assert not (cfg.use_fp16 and not cfg.use_deepspeed)

def get_actor_para(self):
return self._policy_model.parameters()

9 changes: 5 additions & 4 deletions openrl/modules/networks/utils/nlp/base_policy.py
@@ -124,13 +124,14 @@ class GenerationOutputs:
class LMActorCriticPolicy(nn.Module):
def __init__(
self,
cfg: Any,
observation_space: DictSpace,
action_space: Discrete,
model_name: str,
optimizer_kwargs: Dict[str, Any] = {},
weight_decay: float = 1e-6,
use_sde: bool = None,
apply_model_parallel: bool = False, # TODO
# apply_model_parallel: bool = True,
optimizer_class: torch.optim.Optimizer = torch.optim.AdamW,
generation_kwargs: Dict[str, Any] = {},
prompt_truncation_side: str = "left",
@@ -146,15 +147,15 @@ def __init__(
optimizer_kwargs (Dict[str, Any], optional): optimizer kwargs. Defaults to {}.
weight_decay (float, optional): weight decay. Defaults to 1e-6.
use_sde (bool, optional): Use state-dependent exploration. Defaults to None.
apply_model_parallel (bool, optional): whether to apply model parallel. Defaults to True.
apply_model_parallel (bool, optional): whether to apply model parallelism; defaults to using it only when DeepSpeed is disabled.
optimizer_class (torch.optim.Optimizer, optional): Optimizer class. Defaults to torch.optim.AdamW.
generation_kwargs (Dict[str, Any], optional): generation parameters for rollout. Defaults to {}.
prompt_truncation_side (str, optional): truncation side for prompt text. Defaults to "left".
"""
super().__init__()
self._use_deepspeed = True # TODO
self._use_deepspeed = cfg.use_deepspeed
self._action_space = action_space
self._apply_model_parallel = apply_model_parallel
self._apply_model_parallel = not cfg.use_deepspeed # TODO
self._build_model_heads(model_name, config, device)
self._action_dist = CategoricalDistribution(self._action_space.n)
self._generation_kwargs = generation_kwargs
2 changes: 2 additions & 0 deletions openrl/modules/networks/utils/nlp/causal_policy.py
@@ -21,6 +21,7 @@
class CausalLMActorCriticPolicy(LMActorCriticPolicy):
def __init__(
self,
cfg: Any,
observation_space: DictSpace,
action_space: Discrete,
model_name: str,
@@ -36,6 +37,7 @@ def __init__(
device: str = "cpu",
):
super().__init__(
cfg,
observation_space,
action_space,
model_name,