Merge pull request #19 from BAAI-Open/fix_bugs_for_10b_models_xzh
fix bugs
marscrazy authored May 28, 2022
2 parents 8c64639 + c7997e7 commit 598af04
Showing 4 changed files with 187 additions and 40 deletions.
141 changes: 141 additions & 0 deletions examples/t5_flagai_10b/train_title_with_flagai_t5_11b.py
@@ -0,0 +1,141 @@
import sys
sys.path.append('/mnt/liuguang/FlagAI')
from flagai.trainer import Trainer
from flagai.model.t5_model import T5ForConditionalGeneration
from transformers import T5Tokenizer
from flagai.model.predictor.predictor import Predictor
from torch.utils.data import Dataset
import os
import torch
cur_dir = os.path.dirname(os.path.abspath(__file__))

# train_path = cur_dir + "/data/news.tsv"
train_path = "/mnt/datasets/pens_dataset/train.tsv"

class MyTrainer(Trainer):

    def forward_step(self, data, model, mems):
        # The patched T5 model (see flagai/model/t5_model.py below) returns a
        # plain dict, so index outputs by key rather than by attribute.
        model_outputs = model(**data)
        output = {}
        output['loss'] = model_outputs['loss']
        output['logits'] = model_outputs['logits']
        output['hidden_states'] = model_outputs.get('decoder_hidden_states')
        return output


trainer = MyTrainer(
    env_type='deepspeed',
    epochs=1,
    batch_size=1,
    eval_interval=100000,
    log_interval=1,
    experiment_name='t5-11b',
    load_dir=None,
    lr=1e-4,
    fp16=True,
    master_ip='127.0.0.1',
    master_port=17755,
    num_nodes=1,
    num_gpus=1,
    hostfile='./hostfile',
    model_parallel_size=1,
    deepspeed_config='./deepspeed.json',
    training_script=__file__)

def read_file():
    src = []
    tgt = []

    index = 0
    with open(train_path, 'r', encoding='utf-8') as f:
        for line in f:
            index += 1
            if index == 1:
                # skip the header row
                continue
            line = line.strip('\n').split('\t')
            src_list = line[4].split(" ")
            if len(src_list) > 510:
                # drop sources longer than 510 whitespace tokens
                continue

            src.append(line[4])
            tgt.append(line[3])
            if index == 100000:
                # cap the training set at the first ~100k rows
                break

    return src, tgt

model_name = '/mnt/t5-11b'
tokenizer = T5Tokenizer.from_pretrained('t5-11b')

model = T5ForConditionalGeneration.from_pretrain(download_path = '/mnt', model_name='t5-11b')

print("loading model & tokenizer is done!")

maxlen = 1024

predictor = Predictor(model, tokenizer)


class T5Seq2seqDataset(Dataset):

    def __init__(self, sents_src, sents_tgt, tokenizer, maxlen=512):
        super(T5Seq2seqDataset, self).__init__()
        self.sents_src = sents_src
        self.sents_tgt = sents_tgt
        self.tokenizer = tokenizer
        self.maxlen = maxlen

    def __getitem__(self, i):
        src = self.sents_src[i]
        tgt = self.sents_tgt[i]
        # use the instance tokenizer and honour maxlen instead of relying on
        # the module-level tokenizer with unbounded length
        inputs = self.tokenizer(src, max_length=self.maxlen, truncation=True)
        with self.tokenizer.as_target_tokenizer():
            labels = self.tokenizer(tgt, max_length=self.maxlen, truncation=True)
        output = {}
        output['input_ids'] = inputs.input_ids
        output['target_ids'] = labels.input_ids
        return output

    def __len__(self):
        return len(self.sents_src)

def t5_seq2seq_collate_fn(batch):

    def padding(indice, max_length, pad_idx=0):
        # right-pad every sequence in the batch to max_length
        pad_indice = [
            item + [pad_idx] * max(0, max_length - len(item))
            for item in indice
        ]
        return torch.tensor(pad_indice)

    token_ids_src = [data["input_ids"] for data in batch]
    max_length_src = max([len(t) for t in token_ids_src])
    token_ids_tgt = [data["target_ids"] for data in batch]
    max_length_tgt = max([len(t) for t in token_ids_tgt])

    token_ids_padded = padding(token_ids_src, max_length_src)
    target_ids_padded = padding(token_ids_tgt, max_length_tgt)
    labels_ids = target_ids_padded.clone()
    # mask padding out of the loss
    labels_ids[labels_ids == 0] = -100
    # teacher forcing: the decoder sees the targets minus the last position,
    # and the labels are the targets shifted one step to the left
    target_ids_padded = target_ids_padded[:, :-1].contiguous()
    labels_ids = labels_ids[:, 1:].contiguous()

    return {
        "input_ids": token_ids_padded,
        "decoder_input_ids": target_ids_padded,
        "labels": labels_ids
    }


train_src, train_tgt = read_file()

train_dataset = T5Seq2seqDataset(train_src,
                                 train_tgt,
                                 tokenizer=tokenizer,
                                 maxlen=maxlen)

trainer.train(model,
              train_dataset=train_dataset,
              collate_fn=t5_seq2seq_collate_fn)
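Two usage notes on the script above. First, it assumes a ./hostfile and ./deepspeed.json sit next to it; neither is part of this commit. A minimal DeepSpeed config consistent with fp16=True and batch_size=1 might look like this (illustrative only; tune the ZeRO stage and batch sizes to your cluster):

{
    "train_micro_batch_size_per_gpu": 1,
    "gradient_accumulation_steps": 1,
    "fp16": {
        "enabled": true
    },
    "zero_optimization": {
        "stage": 2
    }
}

Second, a toy batch shows what t5_seq2seq_collate_fn hands the trainer (the token IDs are made up; 1 stands in for </s>):

batch = [
    {"input_ids": [13, 27, 1], "target_ids": [42, 7, 1]},
    {"input_ids": [5, 1], "target_ids": [9, 1]},
]
out = t5_seq2seq_collate_fn(batch)
# out["input_ids"]         -> [[13, 27, 1], [5, 1, 0]]   (0-padded per batch)
# out["decoder_input_ids"] -> [[42, 7], [9, 1]]          (targets minus the last position)
# out["labels"]            -> [[7, 1], [1, -100]]        (shifted left; padding becomes -100)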
6 changes: 6 additions & 0 deletions flagai/model/predictor/predictor.py
@@ -32,6 +32,12 @@ def __init__(self,
"""
self.tokenizer = tokenizer
if getattr(self.tokenizer, "token_end_id", None) is None:
setattr(self.tokenizer, "token_end_id", 1)

if getattr(self.tokenizer, "token_start_id", None) is None:
setattr(self.tokenizer, "token_start_id", 0)

self.model = model
self.model.eval()
self.class_name = type(model).__name__
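The fallback IDs above match T5's tokenizer convention: the pad token (ID 0) doubles as the decoder start token, and </s> is ID 1. A quick check against Hugging Face's T5Tokenizer (illustrative; not part of the commit):

from transformers import T5Tokenizer

tok = T5Tokenizer.from_pretrained('t5-small')
assert tok.pad_token_id == 0  # matches the token_start_id fallback
assert tok.eos_token_id == 1  # matches the token_end_id fallback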
4 changes: 2 additions & 2 deletions flagai/model/predictor/utils.py
@@ -705,9 +705,9 @@ def t5_predict_generate(model,

    with torch.no_grad():
        device = next(model.parameters()).device
        decoder_input_ids = torch.tensor(decoder_input_ids, device=device)
        decoder_input_ids = torch.tensor(decoder_input_ids, device=device, dtype=torch.long)
        if input_ids is not None:
            input_ids = torch.tensor(input_ids, device=device)
            input_ids = torch.tensor(input_ids, device=device, dtype=torch.long)
            if input_ids.ndim == 1:
                input_ids = input_ids.view(1, -1)

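Pinning dtype=torch.long here matters because torch.tensor() infers a float dtype when the ID lists arrive as floats (for example after a numpy round trip), and embedding lookups require integer indices. A minimal sketch of the failure mode (illustrative; not part of the diff):

import torch

emb = torch.nn.Embedding(100, 8)
ids = torch.tensor([0.0, 42.0, 1.0])  # dtype inferred as float32
# emb(ids) raises a RuntimeError complaining about non-integer indices;
# forcing the dtype, as the patch does, avoids it:
out = emb(torch.tensor([0.0, 42.0, 1.0], dtype=torch.long))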
76 changes: 38 additions & 38 deletions flagai/model/t5_model.py
@@ -756,18 +756,18 @@ def __init__(self, config, **kwargs):
        self.device_map = None

    def get_input_embeddings(self):
        return self.shared
        return self.shared.weight.data

    def set_input_embeddings(self, new_embeddings):
        self.shared = new_embeddings
        self.encoder.set_input_embeddings(new_embeddings)
        self.decoder.set_input_embeddings(new_embeddings)

    def set_output_embeddings(self, new_embeddings):
        self.lm_head = new_embeddings
        self.lm_head.weight.data = new_embeddings

    def get_output_embeddings(self):
        return self.lm_head
        return self.lm_head.weight

    def get_encoder(self):
        return self.encoder
@@ -776,23 +776,23 @@ def get_decoder(self):
return self.decoder

def forward(
self,
input_ids=None,
attention_mask=None,
decoder_input_ids=None,
decoder_attention_mask=None,
head_mask=None,
decoder_head_mask=None,
encoder_outputs=None,
past_key_values=None,
inputs_embeds=None,
decoder_inputs_embeds=None,
labels=None,
use_cache=None,
output_attentions=None,
output_hidden_states=None,
return_dict=None,
mems=None,
self,
input_ids=None,
attention_mask=None,
decoder_input_ids=None,
decoder_attention_mask=None,
head_mask=None,
decoder_head_mask=None,
encoder_outputs=None,
past_key_values=None,
inputs_embeds=None,
decoder_inputs_embeds=None,
labels=None,
use_cache=True,
output_attentions=None,
output_hidden_states=None,
return_dict=True,
mems=None,
):
r"""
labels (:obj:`torch.LongTensor` of shape :obj:`(batch_size,)`, `optional`):
@@ -812,8 +812,6 @@ def forward(
        >>> input_ids = tokenizer("summarize: studies have shown that owning a dog is good for you ", return_tensors="pt").input_ids # Batch size 1
        >>> outputs = model.generate(input_ids)
        """
        use_cache = use_cache if use_cache is not None else self.config['use_cache']
        return_dict = return_dict if return_dict is not None else self.config['use_return_dict']

        # FutureWarning: head_mask was separated into two input args - head_mask, decoder_head_mask
        if head_mask is not None and decoder_head_mask is None:
@@ -836,11 +834,11 @@ def forward(
        elif return_dict and not isinstance(encoder_outputs, dict):
            encoder_outputs = {
                "last_hidden_state":
                encoder_outputs[0],
                encoder_outputs[0],
                "hidden_states":
                encoder_outputs[1] if len(encoder_outputs) > 1 else None,
                encoder_outputs[1] if len(encoder_outputs) > 1 else None,
                "attentions":
                encoder_outputs[2] if len(encoder_outputs) > 2 else None,
                encoder_outputs[2] if len(encoder_outputs) > 2 else None,
            }

        hidden_states = encoder_outputs['last_hidden_state']
@@ -898,10 +896,6 @@ def forward(
            self.lm_head = self.lm_head.to(self.encoder.first_device)
            sequence_output = sequence_output.to(self.lm_head.weight.device)

        if self.config['tie_word_embeddings']:
            # Rescale output before projecting on vocab
            # See https://github.com/tensorflow/mesh/blob/fa19d69eafc9a482aff0b59ddd96b025c0cb207d/mesh_tensorflow/transformer/transformer.py#L586
            sequence_output = sequence_output * (self['model_dim']**-0.5)
        #
        lm_logits = self.lm_head(sequence_output)

@@ -915,17 +909,11 @@ def forward(
        if not return_dict:
            output = (lm_logits, ) + decoder_outputs[1:] + encoder_outputs
            return ((loss, ) + output) if loss is not None else output
        if self.config['return_logist_only']:
            if labels is not None:
                return {"loss": loss, "logits": lm_logits}

            else:
                return {"logits": lm_logits}

        return {
            "loss": loss,
            "logits": lm_logits,
            "past_key_values": decoder_outputs.past_key_values,
            "past_key_values": decoder_outputs["past_key_values"],
            "decoder_hidden_states": decoder_outputs["hidden_states"],
            "decoder_attentions": decoder_outputs["attentions"],
            "cross_attentions": decoder_outputs["cross_attentions"],
@@ -983,14 +971,26 @@ def _reorder_cache(self, past, beam_idx):
            reordered_decoder_past = reordered_decoder_past + (
                reordered_layer_past_states, )
        return reordered_decoder_past

    def load_weights(self, checkpoint_path):

        _keys_to_ignore_on_load_missing = [
            r"encoder\.embed_tokens\.weight",
            r"decoder\.embed_tokens\.weight",
            r"lm_head\.weight",
        ]
        _keys_to_ignore_on_load_unexpected = [
            r"decoder\.block\.0\.layer\.1\.EncDecAttention\.relative_attention_bias\.weight",
        ]

        checkpoint = torch.load(checkpoint_path,
                                map_location=torch.device("cpu"))
        if "module" in checkpoint:
            # ddp
            checkpoint = checkpoint["module"]
        self.load_state_dict(checkpoint, strict=True)
        self.load_state_dict(checkpoint, strict=False)
        self.set_output_embeddings(nn.Parameter(self.get_input_embeddings()))

        return checkpoint

class T5EncoderModel(T5PreTrainedModel):
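A note on the load_weights change above: with strict=False, the embedding and lm_head keys listed in _keys_to_ignore_on_load_missing may be absent from the checkpoint, so the method re-ties the output projection to the shared input embedding after loading. The tying itself amounts to sharing one weight tensor between two modules (a plain-PyTorch sketch with made-up sizes, not the commit's code):

import torch.nn as nn

shared = nn.Embedding(32128, 1024)            # vocab_size x d_model
lm_head = nn.Linear(1024, 32128, bias=False)  # projects back to the vocab
lm_head.weight = shared.weight                # one tensor, two modules
assert lm_head.weight.data_ptr() == shared.weight.data_ptr()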
