Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[NeMo-UX] Make TE and Apex dependencies optional #9732

Merged
merged 3 commits into from
Jul 15, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
43 changes: 43 additions & 0 deletions .github/workflows/cicd-main.yml
Original file line number Diff line number Diff line change
Expand Up @@ -4516,6 +4516,48 @@ jobs:
AFTER_SCRIPT: |
rm -rf examples/multimodal/text_to_image/sd_train_results
L2_NeMo_2_GPT_Pretraining_no_transformer_engine:
needs: [cicd-test-container-setup]
runs-on: self-hosted-azure
timeout-minutes: 10
container:
image: nemoci.azurecr.io/nemo_container_${{ github.run_id }}
options:
--device=/dev/nvidia0
--gpus all
--shm-size=8g
--env TRANSFORMERS_OFFLINE=0
--volume /mnt/datadrive/TestData:/home/TestData
steps:
- name: Checkout repository
uses: actions/checkout@v4
- run: |
pip uninstall -y apex ## TODO: remove when apex is no longer a dependency
pip uninstall -y transformer_engine
python examples/llm/megatron_gpt_pretraining.py \
--devices=2 \
--max-steps=3 \
--experiment-dir=examples/llm/gpt_pretrain_results \
--vocab-path=/home/TestData/nlp/megatron_gpt/data/gpt/vocab.json \
--merges-path=/home/TestData/nlp/megatron_gpt/data/gpt/merges.txt \
--data-path=/home/TestData/nlp/megatron_gpt/data/gpt/simple_wiki_gpt_preproc_text_document \
--index-mapping-dir=examples/llm/gpt_index_mappings
python examples/llm/megatron_gpt_pretraining.py \
--devices=2 \
--max-steps=6 \
--experiment-dir=examples/llm/gpt_pretrain_results \
--vocab-path=/home/TestData/nlp/megatron_gpt/data/gpt/vocab.json \
--merges-path=/home/TestData/nlp/megatron_gpt/data/gpt/merges.txt \
--data-path=/home/TestData/nlp/megatron_gpt/data/gpt/simple_wiki_gpt_preproc_text_document \
--index-mapping-dir=examples/llm/gpt_index_mappings
rm -rf examples/llm/gpt_pretrain_results
rm -rf examples/llm/gpt_index_mappings
- uses: "NVIDIA/NeMo/.github/actions/cancel-workflow@main"
if: "failure()"

Nemo_CICD_Test:
needs:
- L0_Unit_Tests_GPU
Expand Down Expand Up @@ -4616,6 +4658,7 @@ jobs:
- L2_TTS_Fast_dev_runs_1_Hifigan
- Speech_Checkpoints_tests
- L2_Stable_Diffusion_Training
- L2_NeMo_2_GPT_Pretraining_no_transformer_engine
if: always()
runs-on: ubuntu-latest
steps:
Expand Down
2 changes: 1 addition & 1 deletion Dockerfile.ci
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,7 @@ WORKDIR /workspace
# Install NeMo requirements
ARG TE_TAG=7d576ed25266a17a7b651f2c12e8498f67e0baea
ARG MODELOPT_VERSION=0.13.0
ARG MCORE_TAG=c0164bcfd4f8213a10a6b1e47ef80721a68b4fb6
ARG MCORE_TAG=c7a1f82d761577e6ca0338d3521eac82f2aa0904
ARG APEX_TAG=810ffae374a2b9cb4b5c5e28eaeca7d7998fca0c
RUN \
--mount=type=bind,source=requirements,target=requirements \
Expand Down
109 changes: 109 additions & 0 deletions examples/llm/megatron_gpt_pretraining.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,109 @@
## NOTE: This script is present for github-actions testing only.
## There are no guarantees that this script is up-to-date with latest NeMo.

import argparse

from megatron.core.optimizer import OptimizerConfig
from pytorch_lightning.loggers import TensorBoardLogger

from nemo import lightning as nl
from nemo.collections import llm
from nemo.collections.llm.api import train
from nemo.collections.llm.gpt.data import PreTrainingDataModule
from nemo.collections.nlp.modules.common.tokenizer_utils import get_nmt_tokenizer
from nemo.lightning import NeMoLogger
from nemo.lightning.pytorch.callbacks import ModelCheckpoint
from nemo.lightning.pytorch.optim.megatron import MegatronOptimizerModule


def get_args():
parser = argparse.ArgumentParser(description='Train a small GPT model using NeMo 2.0')
parser.add_argument('--devices', type=int, help="Number of devices to use for training")
parser.add_argument('--max-steps', type=int, help="Number of steps to train for")
parser.add_argument('--experiment-dir', type=str, help="directory to write results and checkpoints to")
parser.add_argument('--data-path', type=str, help="Path to data file")
parser.add_argument('--vocab-path', type=str, help="Path to vocab file")
parser.add_argument('--merges-path', type=str, help="Path to merges file")
parser.add_argument('--index-mapping-dir', type=str, help="directory to write index mappings to")

return parser.parse_args()


if __name__ == '__main__':

args = get_args()

seq_length = 2048

tokenizer = get_nmt_tokenizer(
"megatron",
"GPT2BPETokenizer",
vocab_file=args.vocab_path,
merges_file=args.merges_path,
)
data = PreTrainingDataModule(
paths=args.data_path,
seq_length=2048,
global_batch_size=32,
seed=1234,
tokenizer=tokenizer,
)
gpt_config = llm.GPTConfig(
num_layers=12,
hidden_size=768,
ffn_hidden_size=3072,
num_attention_heads=12,
seq_length=seq_length,
init_method_std=0.023,
hidden_dropout=0.1,
attention_dropout=0.1,
layernorm_epsilon=1e-5,
make_vocab_size_divisible_by=128,
)
model = llm.GPTModel(gpt_config, tokenizer=data.tokenizer)
strategy = nl.MegatronStrategy()
checkpoint_callback = ModelCheckpoint(
every_n_train_steps=5000,
enable_nemo_ckpt_io=False,
async_save=False,
)
callbacks = [checkpoint_callback]

loggers = []
tensorboard_logger = TensorBoardLogger(
save_dir='dummy', ## NOTE: this gets overwritten by default
)
loggers.append(tensorboard_logger)

opt_config = OptimizerConfig(
optimizer='adam',
lr=6e-4,
min_lr=6e-5,
use_distributed_optimizer=False,
bf16=True,
)
opt = MegatronOptimizerModule(config=opt_config)

trainer = nl.Trainer(
devices=args.devices,
max_steps=args.max_steps,
accelerator="gpu",
strategy=strategy,
logger=loggers,
callbacks=callbacks,
log_every_n_steps=1,
plugins=nl.MegatronMixedPrecision(precision="bf16-mixed", amp_O2=False),
)

nemo_logger = NeMoLogger(
dir=args.experiment_dir,
)

train(
model=model,
data=data,
trainer=trainer,
log=nemo_logger,
tokenizer='data',
optim=opt,
)
19 changes: 8 additions & 11 deletions examples/nlp/machine_translation/nmt_transformer_infer_megatron.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,21 +29,12 @@

from nemo.collections.nlp.models.machine_translation.megatron_nmt_model import MegatronNMTModel
from nemo.collections.nlp.modules.common.megatron.megatron_init import fake_initialize_model_parallel
from nemo.collections.nlp.modules.common.megatron.utils import ApexGuardDefaults
from nemo.collections.nlp.parts.nlp_overrides import NLPDDPStrategy, NLPSaveRestoreConnector
from nemo.core.config import hydra_runner
from nemo.utils import logging
from nemo.utils.app_state import AppState
from nemo.utils.model_utils import inject_model_parallel_rank

try:
from apex.transformer.pipeline_parallel.utils import _reconfigure_microbatch_calculator

HAVE_APEX = True
except (ImportError, ModuleNotFoundError):
ModelType = ApexGuardDefaults()
HAVE_APEX = False


@hydra_runner(config_path="conf", config_name="nmt_megatron_infer")
def main(cfg) -> None:
Expand Down Expand Up @@ -101,13 +92,19 @@ def main(cfg) -> None:
src_text.append(line.strip())
if len(src_text) == cfg.batch_size:
translations = model.translate(
text=src_text, source_lang=cfg.source_lang, target_lang=cfg.target_lang,
text=src_text,
source_lang=cfg.source_lang,
target_lang=cfg.target_lang,
)
for translation in translations:
tgt_f.write(translation + "\n")
src_text = []
if len(src_text) > 0:
translations = model.translate(text=src_text, source_lang=cfg.source_lang, target_lang=cfg.target_lang,)
translations = model.translate(
text=src_text,
source_lang=cfg.source_lang,
target_lang=cfg.target_lang,
)
for translation in translations:
tgt_f.write(translation + "\n")

Expand Down
6 changes: 2 additions & 4 deletions nemo/collections/llm/gpt/data/pre_training.py
Original file line number Diff line number Diff line change
Expand Up @@ -173,10 +173,8 @@ def load_state_dict(self, state_dict: Dict[str, Any]) -> None:
state_dict: the datamodule state returned by ``state_dict``.
"""
try:
from apex.transformer.pipeline_parallel.utils import _GLOBAL_NUM_MICROBATCHES_CALCULATOR
except ModuleNotFoundError:
from nemo.lightning.apex_utils import _GLOBAL_NUM_MICROBATCHES_CALCULATOR
from megatron.core.num_microbatches_calculator import _GLOBAL_NUM_MICROBATCHES_CALCULATOR

consumed_samples = state_dict['consumed_samples']
self.data_sampler.init_consumed_samples = consumed_samples
self.data_sampler.prev_consumed_samples = consumed_samples
Expand Down
15 changes: 14 additions & 1 deletion nemo/collections/llm/gpt/model/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,12 @@
from nemo.lightning.megatron_parallel import MaskedTokenLossReduction
from nemo.lightning.pytorch.optim import MegatronOptimizerModule, OptimizerModule

HAVE_TE = True
try:
import transformer_engine

Check notice

Code scanning / CodeQL

Unused import Note

Import of 'transformer_engine' is not used.
except (ImportError, ModuleNotFoundError):
HAVE_TE = False

if TYPE_CHECKING:
from megatron.core.models.gpt.gpt_model import GPTModel as MCoreGPTModel

Expand Down Expand Up @@ -77,6 +83,13 @@
)


def default_layer_spec(config: "GPTConfig") -> ModuleSpec:
if HAVE_TE:
return transformer_engine_layer_spec(config)
else:
return local_layer_spec(config)


@dataclass
class GPTConfig(TransformerConfig, io.IOMixin):
# From megatron.core.models.gpt.gpt_model.GPTModel
Expand All @@ -93,7 +106,7 @@
# TODO: Move this to better places?
get_attention_mask_from_fusion: bool = False

transformer_layer_spec: Union[ModuleSpec, Callable[["GPTConfig"], ModuleSpec]] = transformer_engine_layer_spec
transformer_layer_spec: Union[ModuleSpec, Callable[["GPTConfig"], ModuleSpec]] = default_layer_spec
forward_step_fn: Callable = gpt_forward_step
data_step_fn: Callable = gpt_data_step

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -65,19 +65,10 @@
from nemo.core.classes.common import PretrainedModelInfo
from nemo.utils import logging

try:
import apex.transformer.pipeline_parallel.utils
from apex.transformer.pipeline_parallel.utils import get_num_microbatches

HAVE_APEX = True

except (ImportError, ModuleNotFoundError):

HAVE_APEX = False

try:
from megatron.core import InferenceParams, dist_checkpointing, parallel_state, tensor_parallel
from megatron.core.models.gpt import GPTModel as MCoreGPTModel
from megatron.core.num_microbatches_calculator import get_num_microbatches
from megatron.core.pipeline_parallel.schedules import get_forward_backward_func
from megatron.core.transformer.utils import make_sharded_tensors_for_checkpoint

Expand Down
Loading
Loading