From 004cd2b84beceece6a2a1a555a58bfe50407936d Mon Sep 17 00:00:00 2001
From: Jun Gong
Date: Tue, 14 Mar 2023 13:15:56 -0700
Subject: [PATCH] [AIR] Add a sanity checking release test for Alpa and ray nightly. (#32995)

Signed-off-by: Jun Gong
---
 release/alpa_tests/2_g4dn_12xlarge.yaml      |  22 +
 release/alpa_tests/app_config.yaml           |  34 ++
 release/alpa_tests/run.sh                    |  26 +
 release/alpa_tests/train_opt_2_7b_minimum.py | 511 +++++++++++++++++++
 release/release_tests.yaml                   |  24 +
 5 files changed, 617 insertions(+)
 create mode 100644 release/alpa_tests/2_g4dn_12xlarge.yaml
 create mode 100644 release/alpa_tests/app_config.yaml
 create mode 100644 release/alpa_tests/run.sh
 create mode 100644 release/alpa_tests/train_opt_2_7b_minimum.py

diff --git a/release/alpa_tests/2_g4dn_12xlarge.yaml b/release/alpa_tests/2_g4dn_12xlarge.yaml
new file mode 100644
index 0000000000000..8778954d12bca
--- /dev/null
+++ b/release/alpa_tests/2_g4dn_12xlarge.yaml
@@ -0,0 +1,22 @@
+cloud_id: {{env["ANYSCALE_CLOUD_ID"]}}
+region: us-west-2
+
+max_workers: 2
+
+head_node_type:
+  name: head_node
+  instance_type: g4dn.12xlarge
+
+worker_node_types:
+  - name: worker_node
+    instance_type: g4dn.12xlarge
+    min_workers: 1
+    max_workers: 1
+    use_spot: false
+
+aws:
+  BlockDeviceMappings:
+    - DeviceName: /dev/sda1
+      Ebs:
+        DeleteOnTermination: true
+        VolumeSize: 500
diff --git a/release/alpa_tests/app_config.yaml b/release/alpa_tests/app_config.yaml
new file mode 100644
index 0000000000000..f5f3c632dce00
--- /dev/null
+++ b/release/alpa_tests/app_config.yaml
@@ -0,0 +1,34 @@
+base_image: {{ env["RAY_IMAGE_ML_NIGHTLY_GPU"] | default("anyscale/ray-ml:nightly-py37-gpu") }}
+env_vars: {}
+debian_packages:
+  - curl
+
+python:
+  pip_packages:
+    - pytest
+    - awscli
+    - cupy-cuda113
+    - numpy==1.21.0
+    - protobuf==3.20.0
+  conda_packages: []
+
+post_build_cmds:
+  # Upgrade pip first.
+  - pip3 install --upgrade pip
+  # Install Alpa from source for now.
+  # TODO(jungong) : pip install alpa after next release.
+  - git clone https://github.com/alpa-projects/alpa.git
+  - pip3 install -e alpa
+  # Install custom-built jaxlib.
+  - pip install jaxlib==0.3.22+cuda113.cudnn820 -f https://alpa-projects.github.io/wheels.html
+  # Install nvidia dependencies.
+  - pip3 install --no-cache-dir nvidia-pyindex
+  - pip3 install --no-cache-dir nvidia-tensorrt==7.2.3.4
+  # Hugging Face transformers.
+  - pip3 install -U transformers
+  # Install the Ray wheel under test after the Alpa dependencies, since Alpa's
+  # setup.py requires Ray 2.1.0 right now and would override the installed
+  # wheel if the order were reversed.
+  - pip3 uninstall ray -y || true && pip3 install -U {{ env["RAY_WHEELS"] | default("ray") }}
+  # Sanity check.
+  - {{ env["RAY_WHEELS_SANITY_CHECK"] | default("echo No Ray wheels sanity check") }}
diff --git a/release/alpa_tests/run.sh b/release/alpa_tests/run.sh
new file mode 100644
index 0000000000000..3ea8734ffb814
--- /dev/null
+++ b/release/alpa_tests/run.sh
@@ -0,0 +1,26 @@
+#!/bin/bash
+
+# Integration test for Alpa and Ray.
+
+# Echo commands and exit if any test command fails.
+set -x -e -o pipefail
+
+TRAIN_FILE=https://air-example-data-2.s3.us-west-2.amazonaws.com/alpa/alllines.txt
+S3_MODEL_DIR=s3://air-example-data-2/alpa/opt/models/models--facebook--opt-2.7b/
+LOCAL_MODEL_DIR=/tmp/opt-2.7b/
+OUTPUT_DIR=/tmp/alpa_outputs/
+
+mkdir -p $LOCAL_MODEL_DIR
+mkdir -p $OUTPUT_DIR
+
+# Download weights and tokenizer.
+aws s3 sync $S3_MODEL_DIR $LOCAL_MODEL_DIR
+
+# Run training.
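+# The release test runs on 2 x g4dn.12xlarge nodes (4 T4 GPUs each, 8 GPUs in
+# total). train_opt_2_7b_minimum.py derives the data-parallel degree as
+# num_devices // (operator_parallel * pipeline_parallel), i.e. 8 // (1 * 4) = 2
+# with the flags below.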
+python train_opt_2_7b_minimum.py \ + --operator_parallel 1 \ + --pipeline_parallel 4 \ + --model_name_or_path $LOCAL_MODEL_DIR \ + --output_dir $OUTPUT_DIR \ + --train_file $TRAIN_FILE \ + --max_train_samples 100 diff --git a/release/alpa_tests/train_opt_2_7b_minimum.py b/release/alpa_tests/train_opt_2_7b_minimum.py new file mode 100644 index 0000000000000..d665d18d047a7 --- /dev/null +++ b/release/alpa_tests/train_opt_2_7b_minimum.py @@ -0,0 +1,511 @@ +#!/usr/bin/env python +# coding=utf-8 +# Copyright 2021 The HuggingFace Team All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +""" +Pre-training/Fine-tuning the library models for causal language modeling +(GPT, GPT-2, CTRL, ...) on a text file or a dataset. + +Here is the full list of checkpoints on the hub that can be fine-tuned by this script: +https://huggingface.co/models?filter=text-generation +""" + +from dataclasses import dataclass, field +import functools +from itertools import chain +import json +import logging +import os +from statistics import mean +import time +from typing import Optional + +import datasets +from datasets import Dataset, load_dataset +import numpy as np +from tqdm import tqdm + +import alpa +from alpa.global_env import global_config +from alpa.model.model_util import DynamicScale, TrainState +import jax +import jax.numpy as jnp +import optax +import transformers +import tensorflow as tf +from transformers import ( + FLAX_MODEL_FOR_CAUSAL_LM_MAPPING, + AutoConfig, + AutoTokenizer, + FlaxAutoModelForCausalLM, + HfArgumentParser, +) + + +logger = logging.getLogger(__name__) + +MODEL_CONFIG_CLASSES = list(FLAX_MODEL_FOR_CAUSAL_LM_MAPPING.keys()) +MODEL_TYPES = tuple(conf.model_type for conf in MODEL_CONFIG_CLASSES) + + +def setup_logging(): + # Make one log on every process with the configuration for debugging. + logging.basicConfig( + format="%(asctime)s - %(levelname)s - %(name)s - %(message)s", + datefmt="%m/%d/%Y %H:%M:%S", + level=logging.INFO, + ) + + # Setup logging, we only want one process per machine to log things on the screen. + logger.setLevel(logging.INFO) + + datasets.utils.logging.set_verbosity_warning() + # Set the verbosity to info of the Transformers logger (on main process only): + transformers.utils.logging.set_verbosity_info() + + +@dataclass +class TrainingArguments: + output_dir: str = field( + metadata={ + "help": "The output directory where the model and checkpoints are saved." 
+        },
+    )
+    per_device_train_batch_size: int = field(
+        default=1, metadata={"help": "Batch size per GPU/TPU core/CPU for training."}
+    )
+    num_micro_batches: int = field(
+        default=1,
+        metadata={"help": "The number of micro batches for gradient accumulation."},
+    )
+    operator_parallel: int = field(
+        default=1, metadata={"help": "The degree of operator model parallelism."}
+    )
+    pipeline_parallel: int = field(
+        default=1, metadata={"help": "The degree of pipeline model parallelism."}
+    )
+    learning_rate: float = field(
+        default=5e-5, metadata={"help": "The initial learning rate for AdamW."}
+    )
+    num_train_epochs: int = field(
+        default=1, metadata={"help": "Total number of training epochs to perform."}
+    )
+    logging_steps: int = field(
+        default=10, metadata={"help": "Log every X update steps."}
+    )
+    save_steps: int = field(
+        default=100, metadata={"help": "Save a checkpoint every X update steps."}
+    )
+
+
+@dataclass
+class ModelArguments:
+    """
+    Arguments pertaining to which model/config/tokenizer we are going to fine-tune,
+    or train from scratch.
+    """
+
+    model_name_or_path: Optional[str] = field(
+        metadata={"help": "The model checkpoint for weights initialization."},
+    )
+    model_type: Optional[str] = field(
+        default=None,
+        metadata={
+            "help": "If training from scratch, pass a model type from the list: "
+            + ", ".join(MODEL_TYPES)
+        },
+    )
+    tokenizer_name: Optional[str] = field(
+        default=None,
+        metadata={
+            "help": "Pretrained tokenizer name or path if not the same as model_name"
+        },
+    )
+
+
+@dataclass
+class DataTrainingArguments:
+    """Arguments pertaining to what data we are going to use for training and eval."""
+
+    train_file: Optional[str] = field(
+        metadata={"help": "The input training data file (a text file)."}
+    )
+    max_train_samples: Optional[int] = field(
+        default=None,
+        metadata={
+            "help": (
+                "For debugging purposes or quicker training, truncate the number of "
+                "training examples to this value if set."
+            )
+        },
+    )
+    block_size: Optional[int] = field(
+        default=1024,
+        metadata={
+            "help": (
+                "Optional input sequence length after tokenization. "
+                "The training dataset will be truncated into blocks of this size. "
+                "Defaults to the model max input length for single-sentence inputs "
+                "(taking special tokens into account)."
+            )
+        },
+    )
+    preprocessing_num_workers: Optional[int] = field(
+        default=1,
+        metadata={"help": "The number of processes to use for the preprocessing."},
+    )
+
+
+def data_loader(dataset: Dataset, batch_size: int, shuffle: bool = False):
+    """Yields batches of size `batch_size` from `dataset` (dropping any remainder),
+    sharded over all local devices. Batches are shuffled if `shuffle` is `True`.
+    """
+    data_collator = transformers.DefaultDataCollator("np")
+    tf_dataset = dataset.to_tf_dataset(
+        batch_size=batch_size,
+        columns=dataset.column_names,
+        collate_fn=data_collator,
+        shuffle=shuffle,
+        drop_remainder=True,
+    )
+
+    for batch in tf_dataset:
+        batch = {k: v._numpy() for k, v in batch.items()}
+        yield batch
+
+
+# Main data processing function that concatenates all texts from
+# our dataset and generates chunks of block_size.
+def group_texts(block_size, examples):
+    # Concatenate all texts.
+    concatenated_examples = {k: list(chain(*v)) for k, v in examples.items()}
+    # Length of the first concatenated example.
+    total_length = len(next(iter(concatenated_examples.values())))
+    # We drop the small remainder; we could pad instead if the model supported
+    # padding. You can customize this part to your needs.
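+    # For example, with the default block_size of 1024, a concatenated length
+    # of 2500 tokens yields two full blocks (2048 tokens) and the trailing 452
+    # tokens are dropped.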
+    if total_length >= block_size:
+        total_length = (total_length // block_size) * block_size
+    # Split into chunks of block_size.
+    result = {
+        k: [t[i : i + block_size] for i in range(0, total_length, block_size)]
+        for k, t in concatenated_examples.items()
+    }
+    result["labels"] = result["input_ids"].copy()
+    return result
+
+
+# Define the gradient update step function.
+def train_step(state, batch):
+    """Main training step function."""
+
+    def loss_fn(logits, labels):
+        shift_logits = logits[..., :-1, :]
+        shift_labels = labels[..., 1:]
+        loss = optax.softmax_cross_entropy(
+            shift_logits, jax.nn.one_hot(shift_labels, logits.shape[-1])
+        )
+        return loss.mean()
+
+    def compute_loss(params):
+        labels = batch.pop("labels")
+        logits = state.apply_fn(**batch, params=params, deterministic=True)[0]
+        loss = loss_fn(logits, labels)
+        return loss
+
+    dynamic_scale = state.dynamic_scale
+    grad_fn = dynamic_scale.value_and_grad(compute_loss)
+    dynamic_scale, is_fin, loss, grads = grad_fn(state.params)
+
+    new_state = state.apply_gradients(grads=grads)
+    new_state = new_state.replace(
+        opt_state=jax.tree_map(
+            functools.partial(jnp.where, is_fin),
+            new_state.opt_state,
+            state.opt_state,
+        ),
+        params=jax.tree_map(
+            functools.partial(jnp.where, is_fin), new_state.params, state.params
+        ),
+        master_copy=jax.tree_map(
+            functools.partial(jnp.where, is_fin),
+            new_state.master_copy,
+            state.master_copy,
+        ),
+        dynamic_scale=dynamic_scale,
+    )
+
+    metrics = {"loss": loss}
+
+    return new_state, metrics
+
+
+def save_checkpoint(state, model, tokenizer, training_args):
+    """Utility to checkpoint the model and tokenizer in output_dir."""
+    alpa.prefetch(state.params)
+    params = alpa.util.map_to_nparray(state.params)
+    model.save_pretrained(training_args.output_dir, params=params)
+    tokenizer.save_pretrained(training_args.output_dir)
+
+
+def build_dataset(data_args, tokenizer):
+    # TODO(jungong) : replace huggingface dataset with Ray dataset.
+    dataset = load_dataset(
+        "text",
+        data_files={
+            "train": data_args.train_file,
+        },
+        keep_linebreaks=False,
+    )
+
+    # Preprocessing the datasets.
+    # First we tokenize all the texts.
+    column_names = dataset["train"].column_names
+    text_column_name = "text" if "text" in column_names else column_names[0]
+
+    def tokenize_function(examples):
+        return tokenizer(examples[text_column_name])
+
+    logger.info("***** Tokenize dataset *****")
+    tokenized_datasets = dataset.map(
+        tokenize_function,
+        batched=True,
+        num_proc=data_args.preprocessing_num_workers,
+        remove_columns=column_names,
+        load_from_cache_file=False,
+    )
+
+    block_size = min(data_args.block_size, tokenizer.model_max_length)
+
+    # Note that with `batched=True`, this map processes 1,000 texts together,
+    # so group_texts throws away a remainder for each of those groups of 1,000 texts.
+    # You can adjust that batch size here, but a higher value might be slower to
+    # preprocess.
+
+    logger.info("***** Build dataset *****")
+    lm_datasets = tokenized_datasets.map(
+        functools.partial(group_texts, block_size),
+        batched=True,
+        num_proc=data_args.preprocessing_num_workers,
+        load_from_cache_file=False,
+    )
+    train_dataset = lm_datasets["train"]
+    if data_args.max_train_samples is not None:
+        max_train_samples = min(len(train_dataset), data_args.max_train_samples)
+        train_dataset = train_dataset.select(range(max_train_samples))
+
+    return train_dataset
+
+
+def save_json_metrics(metrics):
+    # Skip the first couple of data points for a more accurate throughput estimate.
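+    # The first iterations include Alpa's initial compilation and warm-up, so
+    # their throughput is not representative of steady-state performance.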
+ to_report = { + "throughput_tokens": ( + mean(metrics["tokens"][2:]) if len(metrics["tokens"]) > 2 else 0.0 + ), + "throughput_tflops": ( + mean(metrics["tflops"][2:]) if len(metrics["tflops"]) > 2 else 0.0 + ), + } + test_output_json = os.environ.get( + "TEST_OUTPUT_JSON", "/tmp/alpa_opt_2_7b_sanity_check.json" + ) + + print("Writing metrics: ", to_report, f" to {test_output_json}") + + with open(test_output_json, "wt") as f: + json.dump(to_report, f) + + +def main(): + # Global initialization. + alpa.init(cluster="ray") + + tf.config.experimental.set_visible_devices([], "GPU") + + # "cupy" doesn't really work, use "xla_extension" instead. + global_config.nccl_mode = "xla_extension" + + # See all possible arguments in src/transformers/training_args.py + # or by passing the --help flag to this script. + # We now keep distinct sets of args, for a cleaner separation of concerns. + parser = HfArgumentParser( + (ModelArguments, DataTrainingArguments, TrainingArguments) + ) + model_args, data_args, training_args = parser.parse_args_into_dataclasses() + + if os.path.exists(training_args.output_dir) and os.listdir( + training_args.output_dir + ): + raise ValueError( + f"Directory ({training_args.output_dir}) already exists and is not empty." + ) + + logger.info(f"Training/evaluation parameters {training_args}") + + setup_logging() + + # Load pretrained model and tokenizer + + # Distributed training: + config = AutoConfig.from_pretrained(model_args.model_name_or_path) + tokenizer = AutoTokenizer.from_pretrained( + model_args.model_name_or_path, + use_fast=True, + ) + + assert model_args.model_name_or_path, "model_name_or_path is required" + model = FlaxAutoModelForCausalLM.from_pretrained( + model_args.model_name_or_path, + config=config, + dtype=getattr(jnp, "float16"), + use_auth_token=None, + ) + + # Training dataset. + train_dataset = build_dataset(data_args, tokenizer) + + # Adjust batch size and num_micro_batches for small datasets + num_devices = alpa.get_global_num_devices() + # Store some constant + num_epochs = training_args.num_train_epochs + data_parallel = num_devices // ( + training_args.operator_parallel * training_args.pipeline_parallel + ) + train_batch_size = training_args.per_device_train_batch_size * data_parallel + steps_per_epoch = len(train_dataset) // train_batch_size + total_train_steps = steps_per_epoch * num_epochs + + # create adam optimizer + optimizer = optax.chain( + optax.clip_by_global_norm(1.0), + optax.adamw(learning_rate=training_args.learning_rate), + ) + + # Fix a bug in huggingface's transformers implementation + # (https://github.com/huggingface/transformers/pull/18462) + alpa.global_config.flax_always_use_fp16_embedding = True + + # Setup train state + state = TrainState.create( + apply_fn=model.__call__, + params=model.params, + tx=optimizer, + dynamic_scale=DynamicScale(), + use_master_copy=True, + ) + + # Create parallel version of the train and eval step + method = alpa.get_3d_parallel_method( + num_micro_batches=training_args.num_micro_batches, + data_parallel=-1, + operator_parallel=training_args.operator_parallel, + pipeline_parallel=training_args.pipeline_parallel, + ) + + p_train_step = alpa.parallelize(train_step, method=method, donate_argnums=(0,)) + + logger.info("***** Training *****") + logger.info(f" Num examples = {len(train_dataset)}") + logger.info(f" Num Epochs = {num_epochs}") + logger.info( + " Batch size per device (w. 
accumulation) = " + f"{training_args.per_device_train_batch_size}" + ) + logger.info( + f" Global train batch size (w. parallel & distributed) = {train_batch_size}" + ) + logger.info(f" Total optimization steps = {total_train_steps}") + logger.info(f" NCCL mode = {global_config.nccl_mode}") + + train_time = 0 + train_metrics = [] + epochs = tqdm(range(num_epochs), desc="Epoch ... ", position=0) + + step_ct = 0 + last_time = time.time() + + epochs.write("Initial compilation. This might take some minutes...") + + # Track and report throughput per iteration. These are the metrics we + # care about over time. + metrics_to_report = { + "tokens": [], + "tflops": [], + } + for epoch in epochs: + train_start = time.time() + + # Generate an epoch by shuffling sampling indices from the train dataset + train_loader = data_loader(train_dataset, train_batch_size, shuffle=True) + steps_per_epoch = len(train_dataset) // train_batch_size + + for step in tqdm( + range(steps_per_epoch), desc="Training...", position=1, leave=False + ): + batch = next(train_loader) + batch["position_ids"] = ( + batch["attention_mask"].cumsum(axis=1) * batch["attention_mask"] + ) - 1 + state, train_metric = p_train_step(state, batch) + train_metrics.append(train_metric) + + cur_step = epoch * (len(train_dataset) // train_batch_size) + step + + step_ct += 1 + if cur_step % training_args.logging_steps == 0 and cur_step > 0: + latency = (time.time() - last_time) / step_ct + throughput_tokens = np.prod(batch["input_ids"].shape) / latency + throughput_tflops = alpa.util.compute_gpt_tflops( + batch_size=batch["input_ids"].shape[0], + seq_len=batch["input_ids"].shape[1], + num_layers=config.num_hidden_layers, + hidden_size=config.hidden_size, + vocab_size=config.vocab_size, + num_gpus=alpa.get_global_num_devices(), + latency=latency, + ) + step_ct = 0 + + # Save metrics + train_time += time.time() - train_start + train_metric = jax.tree_map(np.mean, train_metric) + + # Metrics we report from the release test. + metrics_to_report["tokens"].append(throughput_tokens) + metrics_to_report["tflops"].append(throughput_tflops) + + epochs.write( + f"Step... {cur_step} | " + f"Loss: {train_metric['loss'].mean():.4f}, " + f"Throughput: {throughput_tokens:.2f} token/s, " + f"{throughput_tflops:.2f} TFLOP/s" + ) + + train_metrics = [] + last_time = time.time() + + if cur_step % training_args.save_steps == 0 and cur_step > 0: + # save checkpoint after each epoch + epochs.write("\nSave checkpoint...") + save_checkpoint(state, model, tokenizer, training_args) + + # Save the final model + epochs.write("\nSave the final model...") + save_checkpoint(state, model, tokenizer, training_args) + + # Save JSON metrics + save_json_metrics(metrics_to_report) + + +if __name__ == "__main__": + main() diff --git a/release/release_tests.yaml b/release/release_tests.yaml index eb19c3cf0166b..e1e5a92039e17 100644 --- a/release/release_tests.yaml +++ b/release/release_tests.yaml @@ -2287,6 +2287,30 @@ alert: default +######################## +# Alpa tests +######################## + +- name: alpa_opt_2_7b_sanity_check + group: Alpa tests + working_dir: alpa_tests + + frequency: nightly + team: ml + + cluster: + cluster_env: app_config.yaml + cluster_compute: 2_g4dn_12xlarge.yaml + + run: + timeout: 3600 + script: bash run.sh + + wait_for_nodes: + num_nodes: 2 + + alert: default + ######################## # RLlib tests ########################