Merge pull request #103 from SeanNaren/feat/deepspeed
Initial Lightning DeepSpeed Integration
minimaxir authored Mar 29, 2021
2 parents 12a647c + 967d33c commit fd2cfca
Showing 4 changed files with 12 additions and 58 deletions.
22 changes: 10 additions & 12 deletions aitextgen/aitextgen.py
@@ -1,3 +1,4 @@
+from pytorch_lightning.plugins import DeepSpeedPlugin
 from transformers import (
     GPT2LMHeadModel,
     GPT2TokenizerFast,
@@ -651,24 +652,21 @@ def train(
         if not is_gpu_used:
             n_gpu = 0
 
-        # use the deepseed plugin if installed and specified
+        # use the DeepSpeed plugin if installed and specified
         deepspeed_plugin = None
-        # if is_gpu_used and use_deepspeed:
-        #     deepspeed_config = gen_deepspeed_config(
-        #         self.get_device(), learning_rate, weight_decay
-        #     )
-        #     deepspeed_plugin = DeepSpeedPlugin(deepseed_config)
-        #     logger.info("Using DeepSpeed training.")
-        #     logger.warning(
-        #         "deepspeed was attempted to be used, but was not installed. "
-        #         + "Using normal training behavior."
-        #     )
+        if is_gpu_used and use_deepspeed:
+            deepspeed_plugin = DeepSpeedPlugin()
+            logger.info("Using DeepSpeed training.")
+            if not fp16:
+                logger.info("Setting FP16 to True for DeepSpeed ZeRO Training.")
+                fp16 = True
+
 
         train_params = dict(
             accumulate_grad_batches=gradient_accumulation_steps,
             gpus=n_gpu,
             max_steps=num_steps,
-            gradient_clip_val=max_grad_norm if not fp16 else 0,
+            gradient_clip_val=max_grad_norm,
             checkpoint_callback=False,
             logger=loggers if loggers else False,
             weights_summary=None,
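For context, the code path above is reached through aitextgen's train() call. Below is a minimal sketch of how the new flag might be exercised; the model size, file name, and step count are placeholders, while use_deepspeed and fp16 are the parameters this PR touches.

from aitextgen import aitextgen

# Hypothetical training run; "input.txt" and num_steps are illustrative only.
ai = aitextgen(tf_gpt2="124M", to_gpu=True)
ai.train(
    "input.txt",
    num_steps=5000,
    use_deepspeed=True,  # activates DeepSpeedPlugin() when a GPU is in use
    fp16=True,           # forced to True anyway for DeepSpeed ZeRO training
)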
44 changes: 0 additions & 44 deletions aitextgen/utils.py
Expand Up @@ -172,47 +172,3 @@ def skip_special_tokens(tensor, device, special_token_ids):
~tensor.unsqueeze(1).eq(special_token_id_tensor.unsqueeze(1)).any(1)
].tolist()


-def gen_deepspeed_config(device, lr, weight_decay):
-    """Deepspeed OneBitAdam config.
-    Adapted from https://pytorch-lightning.readthedocs.io/en/stable/advanced/multi_gpu.html#deepspeed
-    Args:
-        device ([type]): Device for training
-        lr ([type]): Learning rate
-        weight_decay ([type]): Weight decay
-    """
-
-    deepspeed_config = {
-        "zero_allow_untested_optimizer": True,
-        "optimizer": {
-            "type": "OneBitAdam",
-            "params": {
-                "lr": lr,
-                "betas": [0.998, 0.999],
-                "eps": 1e-5,
-                "weight_decay": weight_decay,
-                "cuda_aware": "cuda" in device,
-            },
-        },
-        "scheduler": {
-            "type": "WarmupLR",
-            "params": {
-                "last_batch_iteration": -1,
-                "warmup_min_lr": 0,
-                "warmup_max_lr": 3e-5,
-                "warmup_num_steps": 100,
-            },
-        },
-        "zero_optimization": {
-            "stage": 2,  # Enable Stage 2 ZeRO (Optimizer/Gradient state partitioning)
-            "cpu_offload": True,  # Enable Offloading optimizer state/calculation to the host CPU
-            "contiguous_gradients": True,  # Reduce gradient fragmentation.
-            "overlap_comm": True,  # Overlap reduce/backward operation of gradients for speed.
-            "allgather_bucket_size": 2e8,  # Number of elements to all gather at once.
-            "reduce_bucket_size": 2e8,  # Number of elements we reduce/allreduce at once.
-        },
-    }
-
-    return deepspeed_config
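The removed helper hard-coded a OneBitAdam plus ZeRO Stage 2 configuration; the new code relies on DeepSpeedPlugin's defaults instead. As a hedged sketch (not part of this PR), a comparable custom config could still be handed straight to Lightning, assuming the config argument of DeepSpeedPlugin in pytorch-lightning >= 1.2.3 accepts a dict.

from pytorch_lightning import Trainer
from pytorch_lightning.plugins import DeepSpeedPlugin

# Illustrative only: keys mirror part of the removed gen_deepspeed_config() output.
custom_config = {
    "zero_allow_untested_optimizer": True,
    "zero_optimization": {"stage": 2, "cpu_offload": True},
}

trainer = Trainer(
    gpus=1,
    precision=16,  # DeepSpeed ZeRO training expects fp16
    plugins=DeepSpeedPlugin(config=custom_config),
)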
2 changes: 1 addition & 1 deletion requirements.txt
@@ -1,4 +1,4 @@
 transformers>=4.3.0
 fire>=0.3.0
-pytorch-lightning>=1.2.0
+pytorch-lightning>=1.2.3
 torch>=1.6.0
2 changes: 1 addition & 1 deletion setup.py
@@ -19,7 +19,7 @@
     install_requires=[
         "transformers>=4.3.0",
         "fire>=0.3.0",
-        "pytorch-lightning>=1.2.0",
+        "pytorch-lightning>=1.2.3",
         "torch>=1.6.0",
     ],
 )
