akoumparouli/add_check_param_hashes_across_dp_replicas (NVIDIA#9811)
* Riva and k2 ASR WFST decoding (2) (NVIDIA#9391)

* upload

Signed-off-by: Aleksandr Laptev <alaptev@nvidia.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* add comments and use case

Signed-off-by: Aleksandr Laptev <alaptev@nvidia.com>

* Apply isort and black reformatting

Signed-off-by: GNroy <GNroy@users.noreply.github.com>

* add initial doc

Signed-off-by: Aleksandr Laptev <alaptev@nvidia.com>

* fix doc and k2+cuda eval

Signed-off-by: Aleksandr Laptev <alaptev@nvidia.com>

* isolate decoder components installation and fix suggestions

Signed-off-by: Aleksandr Laptev <alaptev@nvidia.com>

* Apply isort and black reformatting

Signed-off-by: GNroy <GNroy@users.noreply.github.com>

* fix trailing newline

Signed-off-by: Aleksandr Laptev <alaptev@nvidia.com>

---------

Signed-off-by: Aleksandr Laptev <alaptev@nvidia.com>
Signed-off-by: GNroy <GNroy@users.noreply.github.com>
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Co-authored-by: GNroy <GNroy@users.noreply.github.com>
Co-authored-by: Vladimir Bataev <vbataev@nvidia.com>
Signed-off-by: Alexandros Koumparoulis <akoumparouli@nvidia.com>

* Add DdpParamParityChecker Callback

Signed-off-by: Alexandros Koumparoulis <akoumparouli@nvidia.com>

* Improve messaging

Signed-off-by: Alexandros Koumparoulis <akoumparouli@nvidia.com>

* Rename to DdpParityChecker

Signed-off-by: Alexandros Koumparoulis <akoumparouli@nvidia.com>

* Add ddp test

Signed-off-by: Alexandros Koumparoulis <akoumparouli@nvidia.com>

* rename to ddp_parity_checker

Signed-off-by: Alexandros Koumparoulis <akoumparouli@nvidia.com>

* remove redundant imports

Signed-off-by: Alexandros Koumparoulis <akoumparouli@nvidia.com>

* test fix

Signed-off-by: Alexandros Koumparoulis <akoumparouli@nvidia.com>

* missing import

Signed-off-by: Alexandros Koumparoulis <akoumparouli@nvidia.com>

* ignore test

Signed-off-by: Alexandros Koumparoulis <akoumparouli@nvidia.com>

* Apply isort and black reformatting

Signed-off-by: akoumpa <akoumpa@users.noreply.github.com>
Signed-off-by: Alexandros Koumparoulis <akoumparouli@nvidia.com>

* add missing import

Signed-off-by: Alexandros Koumparoulis <akoumparouli@nvidia.com>

* Apply isort and black reformatting

Signed-off-by: akoumpa <akoumpa@users.noreply.github.com>
Signed-off-by: Alexandros Koumparoulis <akoumparouli@nvidia.com>

* another missing import

Signed-off-by: Alexandros Koumparoulis <akoumparouli@nvidia.com>

* Apply isort and black reformatting

Signed-off-by: akoumpa <akoumpa@users.noreply.github.com>
Signed-off-by: Alexandros Koumparoulis <akoumparouli@nvidia.com>

* make limit_val_batches int

Signed-off-by: Alexandros Koumparoulis <akoumparouli@nvidia.com>

* remove dup file

Signed-off-by: Alexandros Koumparoulis <akoumparouli@nvidia.com>

* AG groups decisions on DDP parity

Signed-off-by: Alexandros Koumparoulis <akoumparouli@nvidia.com>

* fix test

Signed-off-by: Alexandros Koumparoulis <akoumparouli@nvidia.com>

* Apply isort and black reformatting

Signed-off-by: akoumpa <akoumpa@users.noreply.github.com>

* Exclude from pytest

Signed-off-by: Alexandros Koumparoulis <akoumparouli@nvidia.com>

* Add L2_NeMo_2_GPT_DDP_Param_Parity_check to NeMo_CICD_Test.needs

Signed-off-by: Alexandros Koumparoulis <akoumparouli@nvidia.com>

* Apply isort and black reformatting

Signed-off-by: akoumpa <akoumpa@users.noreply.github.com>

---------

Signed-off-by: Aleksandr Laptev <alaptev@nvidia.com>
Signed-off-by: GNroy <GNroy@users.noreply.github.com>
Signed-off-by: Alexandros Koumparoulis <akoumparouli@nvidia.com>
Signed-off-by: akoumpa <akoumpa@users.noreply.github.com>
Co-authored-by: Aleksandr Laptev <alaptev@nvidia.com>
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Co-authored-by: GNroy <GNroy@users.noreply.github.com>
Co-authored-by: Vladimir Bataev <vbataev@nvidia.com>
Co-authored-by: akoumpa <akoumpa@users.noreply.github.com>
6 people authored and WoodieDudy committed Aug 26, 2024
1 parent 4cb744e commit 58b1d04
Showing 4 changed files with 222 additions and 1 deletion.
17 changes: 17 additions & 0 deletions .github/workflows/cicd-main.yml
@@ -4753,6 +4753,22 @@ jobs:
          rm -rf examples/llm/gpt_pretrain_results
          rm -rf examples/llm/gpt_index_mappings

  L2_NeMo_2_GPT_DDP_Param_Parity_check:
    needs: [cicd-test-container-setup]
    uses: ./.github/workflows/_test_template.yml
    with:
      RUNNER: self-hosted-azure
      SCRIPT: |
        python tests/lightning/test_ddp_parity_checker.py \
          --vocab-path=/home/TestData/nlp/megatron_gpt/data/gpt/vocab.json \
          --merges-path=/home/TestData/nlp/megatron_gpt/data/gpt/merges.txt \
          --data-path=/home/TestData/nlp/megatron_gpt/data/gpt/simple_wiki_gpt_preproc_text_document
      AFTER_SCRIPT: |
        rm -rf examples/llm/gpt_pretrain_results
        rm -rf examples/llm/gpt_index_mappings

  Nemo_CICD_Test:
    needs:
      - gpu-test
@@ -4859,6 +4875,7 @@ jobs:
      - Speech_Checkpoints_tests
      #- OPTIONAL_L2_Stable_Diffusion_Training
      - L2_NeMo_2_GPT_Pretraining_no_transformer_engine
      - L2_NeMo_2_GPT_DDP_Param_Parity_check
    if: always()
    runs-on: ubuntu-latest
    steps:
3 changes: 2 additions & 1 deletion nemo/lightning/pytorch/callbacks/__init__.py
@@ -1,3 +1,4 @@
from nemo.lightning.pytorch.callbacks.ddp_parity_checker import DdpParityChecker
from nemo.lightning.pytorch.callbacks.model_checkpoint import ModelCheckpoint
from nemo.lightning.pytorch.callbacks.model_transform import ModelTransform
from nemo.lightning.pytorch.callbacks.nsys import NsysCallback
@@ -6,7 +7,6 @@
from nemo.lightning.pytorch.callbacks.progress_bar import MegatronProgressBar
from nemo.lightning.pytorch.callbacks.progress_printer import ProgressPrinter


__all__ = [
    "ModelCheckpoint",
    "ModelTransform",
@@ -15,4 +15,5 @@
    "MegatronProgressBar",
    "ProgressPrinter",
    "PreemptionCallback",
    "DdpParityChecker",
]
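
With this export in place, the callback is importable directly from the package namespace. Below is a minimal usage sketch (illustrative only; the interval value and the trainer arguments are assumptions, not taken from this diff):

# Minimal usage sketch (illustrative; interval and trainer arguments are assumptions).
from nemo import lightning as nl
from nemo.lightning.pytorch.callbacks import DdpParityChecker

# Check parameter hashes across DDP replicas every 10 training steps.
ddp_parity = DdpParityChecker(interval=10)
trainer = nl.Trainer(
    devices=2,
    accelerator="gpu",
    strategy=nl.MegatronStrategy(),
    callbacks=[ddp_parity],
)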
74 changes: 74 additions & 0 deletions nemo/lightning/pytorch/callbacks/ddp_parity_checker.py
@@ -0,0 +1,74 @@
from functools import cache

import torch
from megatron.core.utils import check_param_hashes_across_dp_replicas
from pytorch_lightning.callbacks.callback import Callback

from nemo.lightning import io
from nemo.utils import logging


@cache
def pl_has_dist_opt_with_ovelap(trainer):
    optim_config = getattr(getattr(trainer.strategy.model, 'optim', None), 'config', None)
    if not getattr(optim_config, 'use_distributed_optimizer', False):
        return False
    if not getattr(optim_config, 'overlap_param_gather', False):
        return False
    return True


def pl_check_param_hashes_across_dp_replicas(trainer):
    if pl_has_dist_opt_with_ovelap(trainer):
        for opt in trainer.optimizers:
            opt.disable_pre_hook()
    import megatron.core.parallel_state as mp

    res = check_param_hashes_across_dp_replicas([trainer.strategy.model])
    torch.distributed.barrier()

    all_res = [False for _ in range(mp.get_data_parallel_world_size())]

    torch.distributed.all_gather_object(all_res, res, group=mp.get_data_parallel_group_gloo())

    if pl_has_dist_opt_with_ovelap(trainer):
        for opt in trainer.optimizers:
            opt.enable_pre_hook()
    return all(all_res)


class DdpParityChecker(Callback, io.IOMixin):
    """
    This callback enables weight parity checking across DDP replicas for Mcore models.
    Users can specify how often weights are checked via the `interval` parameter.

    Args:
        interval (int): How frequently (in training steps) to check DDP weights for divergence; must be > 0.

    Example:
        >>> callback = DdpParityChecker(interval=10)
        >>> trainer = Trainer(callbacks=[callback])
    """

    def __init__(self, interval: int = 0):
        """
        interval (int): How frequently to check DDP weights for divergence; must be > 0.
        """
        assert interval > 0, "Expected interval to be > 0. A zero interval makes DdpParityChecker a no-op."
        self.interval = interval
        self.step = 0

    def on_train_batch_start(self, trainer, pl_module, batch, batch_idx, unused=0) -> None:
        if self.step == self.interval - 1:
            if pl_check_param_hashes_across_dp_replicas(trainer):
                logging.info(f"DDP Param parity check passed for batch-id={batch_idx}")
            else:
                trainer.should_stop = True
                trainer.limit_val_batches = 0
                logging.info(f"DDP Param parity check FAILED for batch-id={batch_idx}")
        self.step = (self.step + 1) % self.interval

    def on_train_end(self, trainer, pl_module) -> None:
        if pl_check_param_hashes_across_dp_replicas(trainer):
            logging.info("DDP Param parity check passed at end of training.")
        else:
            logging.info("DDP Param parity check FAILED at end of training.")
129 changes: 129 additions & 0 deletions tests/lightning/test_ddp_parity_checker.py
@@ -0,0 +1,129 @@
import argparse
import os

import pytest
import torch
from megatron.core.optimizer import OptimizerConfig

from nemo import lightning as nl
from nemo.collections import llm
from nemo.collections.llm.gpt.data import PreTrainingDataModule
from nemo.collections.nlp.modules.common.tokenizer_utils import get_nmt_tokenizer
from nemo.lightning.pytorch.callbacks import DdpParityChecker


def make_parser():
    parser = argparse.ArgumentParser(description='Train a small GPT model using NeMo 2.0')
    parser.add_argument('--data-path', type=str, help="Path to data file")
    parser.add_argument('--vocab-path', type=str, help="Path to vocab file")
    parser.add_argument('--merges-path', type=str, help="Path to merges file")

    return parser


def wrap_config(config, trainer):
    class ConfigWrapper(type(config)):
        def configure_model(self, tokenizer) -> "MCoreGPTModel":
            return make_byzantine_model_wrapper(super().configure_model(tokenizer), trainer)

    config.__class__ = ConfigWrapper
    return config


def make_byzantine_model_wrapper(model, trainer):
    class ByzantineModel(type(model)):
        def forward(self, *args, **kwargs):
            ans = super().forward(*args, **kwargs)
            with torch.no_grad():
                import random

                rank = int(os.environ['LOCAL_RANK'])
                if rank != 1:
                    return ans
                for opt in trainer.strategy.model.optim._optimizers:
                    for g in opt.param_groups:
                        for param in g['params']:
                            param.fill_(random.uniform(0, 1))
                return ans

    model.__class__ = ByzantineModel
    return model


@pytest.mark.skip(reason="tested with GH")
def test_failing(trainer, ddp_parity, optim, data, tokenizer):
    config = llm.Llama2Config7B(num_layers=2)
    config = wrap_config(config, trainer)
    model = llm.LlamaModel(config, tokenizer=tokenizer, optim=optim)
    trainer.fit(model, data)


@pytest.mark.skip(reason="tested with GH")
def test_working(trainer, ddp_parity, optim, data, tokenizer):
    config = llm.Llama2Config7B(num_layers=2)
    model = llm.LlamaModel(config, tokenizer=tokenizer, optim=optim)
    trainer.fit(model, data)


def make_trainer_optim(args):
    ddp_parity = DdpParityChecker(1)
    trainer = nl.Trainer(
        devices=2,
        max_steps=4,
        accelerator="gpu",
        strategy=nl.MegatronStrategy(
            ckpt_include_optimizer=False,
        ),
        plugins=nl.MegatronMixedPrecision(precision="bf16-mixed"),
        limit_val_batches=1,
        num_sanity_val_steps=0,
        log_every_n_steps=1,
        logger=None,
        callbacks=[ddp_parity],
    )

    optim = nl.MegatronOptimizerModule(
        config=OptimizerConfig(
            optimizer="adam",
            lr=1e-5,
            use_distributed_optimizer=False,
            fp16=False,
            bf16=True,
            params_dtype=torch.float32,
        ),
    )

    tokenizer = get_nmt_tokenizer(
        "megatron",
        "GPT2BPETokenizer",
        vocab_file=args.vocab_path,
        merges_file=args.merges_path,
    )
    data = PreTrainingDataModule(
        paths=args.data_path,
        seq_length=2048,
        global_batch_size=32,
        seed=1234,
        tokenizer=tokenizer,
    )

    return trainer, ddp_parity, optim, data, tokenizer


@pytest.mark.skip(reason="tested with GH")
def main():
    args = make_parser().parse_args()
    trainer, ddp_parity, optim, data, tokenizer = make_trainer_optim(args)
    test_failing(trainer, ddp_parity, optim, data, tokenizer)
    if not trainer.should_stop:
        raise ValueError("DDP parity checking failed.")

    try:
        test_working(*make_trainer_optim(args))
        print("DDP parity checking worked as expected")
    except:
        raise


if __name__ == "__main__":
    main()
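
A note on the test's design: both wrap_config and make_byzantine_model_wrapper reassign an existing instance's __class__ to a dynamically created subclass, so a single method can be overridden in place (here, forward is overridden to corrupt parameters on one rank and force a parity failure). The toy example below illustrates the pattern on its own (the names are invented for the example):

# Toy illustration of the __class__-swap pattern used by wrap_config and
# make_byzantine_model_wrapper (example names are invented).
class Greeter:
    def greet(self) -> str:
        return "hello"


def make_loud(obj):
    # Dynamically subclass the object's own type and override one method.
    class LoudGreeter(type(obj)):
        def greet(self) -> str:
            return super().greet().upper()

    obj.__class__ = LoudGreeter  # re-type the existing instance in place
    return obj


g = make_loud(Greeter())
assert g.greet() == "HELLO"  # the original instance now uses the overridden method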
