Commit

WIP - nothing working yet. Disable mix_in by default.
arnocandel committed Apr 11, 2023
1 parent 1640f1f commit 864fd5f
Showing 2 changed files with 74 additions and 26 deletions.
README.md: 2 changes (1 addition, 1 deletion)
@@ -64,7 +64,7 @@ cd ..
6. Add Flash Attention

```bash
- git clone https://github.com/HazyResearch/flash-attention.git
+ git clone https://github.com/h2oai/flash-attention.git
cd flash-attention
python setup.py install
cd csrc/layer_norm
finetune.py: 98 changes (73 additions, 25 deletions)
@@ -116,7 +116,7 @@ def train(

# data_mix_in_path: str = "laion/OIG", # way too big, medium quality
data_mix_in_path: str = "0-hero/OIG-small-chip2", # high quality, 50 MB, good enough for now
- data_mix_in_factor: float = 1.0, # >1: more mix-in data, <1: more of data_path data
+ data_mix_in_factor: float = 0.0, # >1: more mix-in data, <1: more of data_path data
data_mix_in_col_dict: dict = {'user': 'instruction', 'chip2': 'output'},
data_mix_in_prompt_type: str = "instruct", # just instruction->output, same as instruct
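For context on data_mix_in_factor: only parts of the mixing logic appear in the hunks below, so the following is a hypothetical sketch of how a factor like this is typically applied. The helper name and the shared-schema assumption are mine, not the repository's.

```python
# Hypothetical sketch: sample roughly factor * len(primary) rows from the mix-in
# set and concatenate. Assumes both datasets already share the same column schema
# (the column mapping for the mix-in set is handled separately).
from datasets import Dataset, concatenate_datasets

def mix_in(primary: Dataset, extra: Dataset, factor: float, seed: int = 42) -> Dataset:
    if factor <= 0:
        return primary  # mix-in disabled, the new default in this commit
    n = min(int(primary.num_rows * factor), extra.num_rows)
    sampled = extra.shuffle(seed=seed).select(range(n))
    return concatenate_datasets([primary, sampled]).shuffle(seed=seed)
```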

@@ -162,6 +162,7 @@ def train(
logging_steps: int = 1,
save_steps: int = None, # must be round multiple of eval_steps
add_eos_token: bool = False,
+ flash_attention: bool = False,
):
prompt_type = str(prompt_type) # migration from integers
assert prompt_type in prompt_types
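A hypothetical call into train() showing the new defaults in effect. Only keyword arguments visible in this diff are used; base_model and data_path values are placeholders (data_path is inferred from the "more of data_path data" comment), and any other required arguments are assumed to have workable defaults.

```python
# Hypothetical usage; parameter names are taken from this diff.
from finetune import train

train(
    base_model="EleutherAI/gpt-neox-20b",  # one of the models the flash-attention branch checks for
    data_path="my_data.json",              # placeholder
    data_mix_in_factor=0.0,                # OIG mix-in now off by default
    flash_attention=False,                 # new flag, off by default
    lora_r=8,                              # keep LoRA enabled (the WIP flash path forces lora_r = 0)
)
```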
@@ -298,28 +299,68 @@ def generate_and_tokenize_prompt(data_point, add_eos=add_eos_token):
] # could be sped up, probably
return tokenized_full_prompt

- model = prepare_model_for_int8_training(model)
-
- if lora_weights:
- from peft import PeftModel
- model = PeftModel.from_pretrained(
- model,
- lora_weights,
- torch_dtype=torch.float16,
- device_map=device_map,
- local_files_only=local_files_only,
- resume_download=resume_download,
- )
+ if base_model in [
+ 'EleutherAI/gpt-j-6B',
+ 'EleutherAI/gpt-neox-20b',
+ 'togethercomputer/GPT-NeoXT-Chat-Base-20B',
+ ] and flash_attention:
+ log("Enabling Flash attention")
+ # speed up forward prop for attention layer and reduce memory especially for long context lengths
+ from flash_attn.models.gpt import GPTLMHeadModel
+ from flash_attn.models.gpt_neox import gpt_neox_config_to_gpt2_config
+ from flash_attn.models.gptj import gptj_config_to_gpt2_config
+
+ if "gpt-j" in base_model.lower():
+ config = gptj_config_to_gpt2_config(model.config)
+ else:
+ config = gpt_neox_config_to_gpt2_config(model.config)
+ config.use_flash_attn = True
+ config.fused_bias_fc = True
+ config.fused_mlp = True # GPT-NeoX-20B uses "gelu_fast"
+ config.fused_dropout_add_ln = True
+ config.residual_in_fp32 = True
+ lora_target_modules = ['Wqkv']
+ model = GPTLMHeadModel.from_pretrained(base_model, config, device='cuda', dtype=torch.float16)
+ # for v in vars(model2.config):
+ # setattr(model.config, v, getattr(model2.config, v))
+ # model.transformer.config = model.config
+ # model.transformer.h = model2.transformer.layers
+ # model.lm_head = model2.lm_head
+ ### model.transformer.wte = model2.transformer.wte
+ ### model.transformer.embeddings = model2.transformer.embeddings
+ print(model)
+ # FIXME - don't disable LoRA
+ lora_r = 0
+ # FIXME - enable 8-bit
+ # model = prepare_model_for_int8_training(model)
else:
- config = LoraConfig(
- r=lora_r,
- lora_alpha=lora_alpha,
- target_modules=lora_target_modules,
- lora_dropout=lora_dropout,
- bias="none",
- task_type="CAUSAL_LM",
- )
- model = get_peft_model(model, config)
+ model = prepare_model_for_int8_training(model)
+
+ if lora_r > 0:
+ if lora_weights:
+ log("Loading LoRA weights")
+ from peft import PeftModel
+ model = PeftModel.from_pretrained(
+ model,
+ lora_weights,
+ torch_dtype=torch.float16,
+ device_map=device_map,
+ local_files_only=local_files_only,
+ resume_download=resume_download,
+ )
+ else:
+ log("Creating fresh LoRA weights")
+ config = LoraConfig(
+ r=lora_r,
+ lora_alpha=lora_alpha,
+ target_modules=lora_target_modules,
+ lora_dropout=lora_dropout,
+ bias="none",
+ task_type="CAUSAL_LM",
+ )
+ model = get_peft_model(model, config)
+ else:
+ log("LoRA disabled.")

if resume_from_checkpoint:
# Check the available weights and load them
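A minimal standalone sketch of the LoRA branch above, using the same peft helpers that appear in the diff; the model name and hyperparameter values are illustrative, not taken from the repository.

```python
# Illustrative values only; assumes a peft version that still exposes
# prepare_model_for_int8_training and a bitsandbytes install for load_in_8bit.
from transformers import AutoModelForCausalLM
from peft import LoraConfig, get_peft_model, prepare_model_for_int8_training

lora_r = 8  # 0 would skip LoRA entirely, as the flash-attention branch does

model = AutoModelForCausalLM.from_pretrained(
    "EleutherAI/gpt-j-6B", load_in_8bit=True, device_map="auto"
)
model = prepare_model_for_int8_training(model)

if lora_r > 0:
    config = LoraConfig(
        r=lora_r,
        lora_alpha=16,
        target_modules=["q_proj", "v_proj"],  # module names depend on the architecture
        lora_dropout=0.05,
        bias="none",
        task_type="CAUSAL_LM",
    )
    model = get_peft_model(model, config)
    model.print_trainable_parameters()
```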
@@ -339,7 +380,10 @@ def generate_and_tokenize_prompt(data_point, add_eos=add_eos_token):
else:
log(f"Checkpoint {checkpoint_name} not found")

- model.print_trainable_parameters() # Be more transparent about the % of trainable params.
+ try:
+ model.print_trainable_parameters() # Be more transparent about the % of trainable params.
+ except:
+ pass

metrics = {}
for name in supported_metrics:
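The try/except above is presumably needed because print_trainable_parameters() is a method of peft's PeftModel; when LoRA is disabled (the flash-attention branch forces lora_r = 0) the model is a plain torch module without it. A generic fallback that works for any nn.Module:

```python
# Fallback parameter report for models that are not wrapped by peft.
def report_trainable_parameters(model):
    trainable = sum(p.numel() for p in model.parameters() if p.requires_grad)
    total = sum(p.numel() for p in model.parameters())
    print(f"trainable params: {trainable} || all params: {total} || "
          f"trainable%: {100 * trainable / total:.4f}")
```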
@@ -371,7 +415,7 @@ def generate_and_tokenize_prompt(data_point, add_eos=add_eos_token):
train_data_mix_in = None
valid_data_mix_in = None

- if data_mix_in_path:
+ if data_mix_in_path and data_mix_in_factor > 0:
# get mix-in training/validation data - to keep model "sane"
num_rows = data["train"].num_rows
log("Loading mix-in dataset: %s" % data_mix_in_path)
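The column mapping itself is not visible in this hunk; a hypothetical sketch of how data_mix_in_col_dict from the signature above would typically be applied to the OIG mix-in set (the actual code in finetune.py may differ):

```python
# Hypothetical column alignment for the mix-in dataset.
from datasets import load_dataset

data_mix_in_path = "0-hero/OIG-small-chip2"
data_mix_in_col_dict = {"user": "instruction", "chip2": "output"}

mix_in_data = load_dataset(data_mix_in_path)["train"]  # assumes the default 'train' split
mix_in_data = mix_in_data.rename_columns(data_mix_in_col_dict)
mix_in_data = mix_in_data.remove_columns(
    [c for c in mix_in_data.column_names if c not in data_mix_in_col_dict.values()]
)
```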
@@ -413,6 +457,7 @@ def generate_and_tokenize_prompt(data_point, add_eos=add_eos_token):

# get our own training/validation data - for fine-tuning
if val_set_size > 0 and not valid_path and not data_mix_in_path:
+ log("Creating validation data from training data")
# create valid split from train
train_val = data["train"].train_test_split(
test_size=val_set_size, shuffle=True, seed=42
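The split above uses datasets.Dataset.train_test_split, which returns a DatasetDict keyed by "train" and "test"; a quick standalone illustration with made-up data:

```python
from datasets import Dataset

data = Dataset.from_dict({"instruction": [f"q{i}" for i in range(100)],
                          "output": [f"a{i}" for i in range(100)]})
train_val = data.train_test_split(test_size=10, shuffle=True, seed=42)
train_data = train_val["train"]  # 90 rows
valid_data = train_val["test"]   # 10 rows, used as the validation split
```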
@@ -422,8 +467,11 @@ def generate_and_tokenize_prompt(data_point, add_eos=add_eos_token):
else:
train_data = data["train"]
if valid_path:
+ log("Using given validation data")
# use given valid split, has priority over data_mix_in_path
valid_data = data["valid"]
+ else:
+ log("No validation data")
if "prompt_type" not in train_data.column_names:
train_data = train_data.add_column(
"prompt_type",
@@ -549,7 +597,7 @@ def compute_metrics(eval_preds):
learning_rate=learning_rate,
gradient_checkpointing=gradient_checkpointing,
fp16=fp16,
- # cosnider 8-bit adam: https://huggingface.co/docs/transformers/v4.18.0/en/performance#8bit-adam
+ # consider 8-bit adam: https://huggingface.co/docs/transformers/v4.18.0/en/performance#8bit-adam
optim="adamw_torch", # consider "adafactor" to save memory
logging_steps=logging_steps,
logging_strategy="steps",
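The kwargs in this last hunk (learning_rate, fp16, optim, logging_steps, logging_strategy) look like transformers.TrainingArguments. Following the 8-bit-adam/adafactor comment, a hedged sketch of swapping in a memory-saving optimizer, assuming that is the object being built and that a transformers version with bitsandbytes optimizer support is installed:

```python
# Illustrative values; only the optim choice is the point here.
from transformers import TrainingArguments

args = TrainingArguments(
    output_dir="./lora_out",   # placeholder
    learning_rate=3e-4,        # illustrative
    fp16=True,
    gradient_checkpointing=True,
    optim="adamw_bnb_8bit",    # 8-bit Adam via bitsandbytes; or "adafactor" to save memory without it
    logging_steps=1,
    logging_strategy="steps",
)
```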
