Modify LLM Trainer to support BERT and Tiny LLaMA #2031
Storage initializer requirements (loose version ranges replaced with exact pins):

@@ -1,8 +1,5 @@
-einops>=0.6.1
-transformers_stream_generator==0.0.4
-boto3==1.33.9
-transformers>=4.20.0
-peft>=0.3.0
-huggingface_hub==0.16.4
-datasets>=2.13.2
+peft==0.3.0
+datasets==2.15.0
+transformers==4.37.2
+boto3==1.33.9
+huggingface_hub==0.19.3
S3 dataset provider:

@@ -1,6 +1,6 @@
 from dataclasses import dataclass, field
-import json, os
-import boto3
+import json
+import os
 from urllib.parse import urlparse
 from .abstract_dataset_provider import datasetProvider
 from .constants import VOLUME_PATH_DATASET
@@ -39,6 +39,8 @@ def load_config(self, serialised_args):
         self.config = S3DatasetParams(**json.loads(serialised_args))
 
     def download_dataset(self):
+        import boto3
+
         # Create an S3 client for Nutanix Object Store/S3
         s3_client = boto3.client(
             "s3",

Review thread on the `import boto3` line:

"Should we put this import on the top?"

"@tenzen-y I did this on purpose so the Training Operator SDK won't be dependent on boto3 while importing the S3 storage initializer: https://github.com/kubeflow/training-operator/blob/master/sdk/python/kubeflow/training/api/training_client.py#L125"
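The deferred import that this thread discusses keeps boto3 out of the SDK's import path entirely. A minimal sketch of the pattern, with illustrative names rather than the actual file contents:

# Sketch of the lazy-import pattern (class and method names are illustrative).
class S3DatasetProvider:
    def load_config(self, serialised_args):
        # Nothing here touches boto3, so this module (and any SDK code that
        # imports it) can be imported without boto3 installed.
        self.serialised_args = serialised_args

    def download_dataset(self):
        # boto3 is resolved only when a download is actually requested.
        import boto3

        s3_client = boto3.client("s3")
        return s3_client

The cost is a later failure mode: a missing boto3 surfaces at download time instead of import time, which is the trade-off accepted here.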
HuggingFace LLM trainer script:

@@ -1,70 +1,87 @@
+import argparse
+import logging
+from urllib.parse import urlparse
+import json
+
+from datasets import load_from_disk
+from peft import LoraConfig, get_peft_model
 import transformers
 from transformers import (
     AutoModelForCausalLM,
     AutoTokenizer,
+    AutoConfig,
+    AutoModelForImageClassification,
     TrainingArguments,
     DataCollatorForLanguageModeling,
     Trainer,
 )
-import torch
-from datasets import load_dataset
-from peft import LoraConfig, get_peft_model
-from urllib.parse import urlparse
-import os
-import json
+
+
+# Configure logger.
+log_formatter = logging.Formatter(
+    "%(asctime)s %(levelname)-8s %(message)s", "%Y-%m-%dT%H:%M:%SZ"
+)
+logger = logging.getLogger(__file__)
+console_handler = logging.StreamHandler()
+console_handler.setFormatter(log_formatter)
+logger.addHandler(console_handler)
+logger.setLevel(logging.INFO)
 
 
 def setup_model_and_tokenizer(model_uri, transformer_type, model_dir):
     # Set up the model and tokenizer
     parsed_uri = urlparse(model_uri)
     model_name = parsed_uri.netloc + parsed_uri.path
-    transformer_type_class = getattr(transformers, transformer_type)
 
-    model = transformer_type_class.from_pretrained(
+    model = transformer_type.from_pretrained(
         pretrained_model_name_or_path=model_name,
         cache_dir=model_dir,
         local_files_only=True,
         device_map="auto",
         trust_remote_code=True,
     )
 
-    tokenizer = transformers.AutoTokenizer.from_pretrained(
+    tokenizer = AutoTokenizer.from_pretrained(
         pretrained_model_name_or_path=model_name,
         cache_dir=model_dir,
         local_files_only=True,
         device_map="auto",
     )
 
     tokenizer.pad_token = tokenizer.eos_token
     tokenizer.add_pad_token = True
 
     # Freeze model parameters
     for param in model.parameters():
         param.requires_grad = False
 
     return model, tokenizer
 
 
-def load_and_preprocess_data(dataset_name, dataset_dir, transformer_type, tokenizer):
+def load_and_preprocess_data(dataset_dir, transformer_type, tokenizer):
     # Load and preprocess the dataset
-    print("loading dataset")
-    transformer_type_class = getattr(transformers, transformer_type)
-    if transformer_type_class != transformers.AutoModelForImageClassification:
-        dataset = load_dataset(dataset_name, cache_dir=dataset_dir).map(
-            lambda x: tokenizer(x["text"]), batched=True
+    logger.info("Load and preprocess dataset")
+
+    if transformer_type != AutoModelForImageClassification:
+        dataset = load_from_disk(dataset_dir)
+
+        logger.info(f"Dataset specification: {dataset}")
+        logger.info("-" * 40)
+
+        logger.info("Tokenize dataset")
+        # TODO (andreyvelich): Discuss how user should set the tokenizer function.
+        dataset = dataset.map(
+            lambda x: tokenizer(x["text"], padding="max_length", truncation=True),
+            batched=True,
         )
     else:
-        dataset = load_dataset(dataset_name, cache_dir=dataset_dir)
+        dataset = load_from_disk(dataset_dir)
 
-    train_data = dataset["train"]
+    # Check if dataset contains `train` key. Otherwise, load full dataset to train_data.
+    if "train" in dataset:
+        train_data = dataset["train"]
+    else:
+        train_data = dataset
 
     try:
         eval_data = dataset["eval"]
-    except Exception as err:
+    except Exception:
         eval_data = None
-        print("Evaluation dataset is not found")
+        logger.info("Evaluation dataset is not found")
 
     return train_data, eval_data

Review thread on `eval_data = dataset["eval"]`:

"Is it always `dataset["eval"]`?"

"@johnugeorge It depends on the dataset."

"SGTM"
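To make the tokenize step concrete, here is roughly what the batched map call produces on a toy dataset. The checkpoint is an assumption for the example (any tokenizer with a finite model_max_length works) and is downloaded on first use:

# Toy run of the batched tokenize step (checkpoint choice is illustrative).
from datasets import Dataset
from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("bert-base-uncased")
dataset = Dataset.from_dict({"text": ["hello world", "tiny llama and bert"]})

# Same call shape as in load_and_preprocess_data: each example is padded to
# the tokenizer's max length and truncated if it is longer.
dataset = dataset.map(
    lambda x: tokenizer(x["text"], padding="max_length", truncation=True),
    batched=True,
)

print(dataset.column_names)  # e.g. ['text', 'input_ids', 'token_type_ids', 'attention_mask']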
@@ -77,20 +94,27 @@ def setup_peft_model(model, lora_config):
     return model
 
 
-def train_model(model, train_data, eval_data, tokenizer, train_args):
-    # Train the model
+def train_model(model, transformer_type, train_data, eval_data, tokenizer, train_args):
+    # Setup the Trainer.
     trainer = Trainer(
         model=model,
         train_dataset=train_data,
         eval_dataset=eval_data,
         tokenizer=tokenizer,
         args=train_args,
-        data_collator=DataCollatorForLanguageModeling(
-            tokenizer, pad_to_multiple_of=8, mlm=False
-        ),
     )
 
+    # TODO (andreyvelich): Currently, data collator is supported only for casual LM Transformer.
+    if transformer_type == AutoModelForCausalLM:
+        logger.info("Add data collector for language modeling")
+        logger.info("-" * 40)
+        trainer.data_collator = DataCollatorForLanguageModeling(
+            tokenizer,
+            pad_to_multiple_of=8,
+            mlm=False,
+        )
+
+    # Train the model.
     trainer.train()
-    print("training done")
 
 
 def parse_arguments():

Review thread on the data collator TODO:

"What is the TODO? I guess that you'd like to support a data collator other than the causal LM Transformer, right?"

"I think it refers to the Data Collator in HuggingFace: https://huggingface.co/docs/transformers/en/main_classes/data_collator"

"We need to investigate whether we want to apply the Data Collator to other transformers. I will create an issue."

"I see. Thanks."
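To see what the conditional collator actually does: with mlm=False, DataCollatorForLanguageModeling pads the batch and copies input_ids into labels, masking the padding positions with -100, which is the format a causal-LM Trainer expects. A small standalone example (the GPT-2 checkpoint is an assumption, picked only because it is a causal LM):

# Standalone look at the causal-LM data collator (checkpoint is illustrative).
from transformers import AutoTokenizer, DataCollatorForLanguageModeling

tokenizer = AutoTokenizer.from_pretrained("gpt2")
tokenizer.pad_token = tokenizer.eos_token  # GPT-2 ships without a pad token

collator = DataCollatorForLanguageModeling(tokenizer, pad_to_multiple_of=8, mlm=False)
batch = collator([tokenizer("short"), tokenizer("a somewhat longer example")])

# Sequences are padded up to a multiple of 8; labels mirror input_ids with
# padding positions set to -100 so the loss ignores them.
print(batch["input_ids"].shape)  # torch.Size([2, 8])
print(batch["labels"][0])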
@@ -101,8 +125,7 @@ def parse_arguments():
     parser.add_argument("--model_uri", help="model uri")
     parser.add_argument("--transformer_type", help="model transformer type")
     parser.add_argument("--model_dir", help="directory containing model")
-    parser.add_argument("--dataset_dir", help="directory contaning dataset")
-    parser.add_argument("--dataset_name", help="dataset name")
+    parser.add_argument("--dataset_dir", help="directory containing dataset")
     parser.add_argument("--lora_config", help="lora_config")
     parser.add_argument(
         "--training_parameters", help="hugging face training parameters"
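Note that --training_parameters arrives as a JSON string and is unpacked directly into HuggingFace's TrainingArguments, so any TrainingArguments field can be set from the command line. A quick sketch with illustrative values:

# How the JSON --training_parameters string becomes TrainingArguments.
import json
from transformers import TrainingArguments

training_parameters = '{"output_dir": "/tmp/output", "num_train_epochs": 1, "per_device_train_batch_size": 2}'
train_args = TrainingArguments(**json.loads(training_parameters))
print(train_args.per_device_train_batch_size)  # 2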
@@ -112,13 +135,25 @@ def parse_arguments():
 
 
 if __name__ == "__main__":
+    logger.info("Starting HuggingFace LLM Trainer")
     args = parse_arguments()
     train_args = TrainingArguments(**json.loads(args.training_parameters))
+    transformer_type = getattr(transformers, args.transformer_type)
 
+    logger.info("Setup model and tokenizer")
     model, tokenizer = setup_model_and_tokenizer(
-        args.model_uri, args.transformer_type, args.model_dir
+        args.model_uri, transformer_type, args.model_dir
     )
 
+    logger.info("Preprocess dataset")
     train_data, eval_data = load_and_preprocess_data(
-        args.dataset_name, args.dataset_dir, args.transformer_type, tokenizer
+        args.dataset_dir, transformer_type, tokenizer
     )
 
+    logger.info("Setup LoRA config for model")
     model = setup_peft_model(model, args.lora_config)
-    train_model(model, train_data, eval_data, tokenizer, train_args)
+
+    logger.info("Start model training")
+    train_model(model, transformer_type, train_data, eval_data, tokenizer, train_args)
+
+    logger.info("Training is complete")
Trainer requirements (loose version ranges replaced with exact pins; bitsandbytes and einops dropped):

@@ -1,5 +1,3 @@
-peft>=0.3.0
-datasets>=2.13.2
-transformers>=4.20.0
-bitsandbytes>=0.42.0
-einops>=0.6.1
+peft==0.3.0
+datasets==2.15.0
+transformers==4.37.2
Review discussion:

"@andreyvelich why are we downloading the dataset again?"

"It's a great catch @deepanker13! We should remove it."