From 3a0c8cb4127a0bfd86ddd965ab4ece40b128ee6c Mon Sep 17 00:00:00 2001
From: zhaohu xing <32668889+920232796@users.noreply.github.com>
Date: Wed, 6 Jul 2022 14:19:29 +0800
Subject: [PATCH] Opt 66b (#19)

* autoloader for opt
* opt-66b inference
* Update train.py
* Load data from example dir
* add readme of multi GPU inference

Co-authored-by: Zac Liu
---
 examples/glm_title_generation/train.py |  12 +-
 examples/opt/README.md                 |  98 +++++++++++++-
 examples/opt/generate_opt_66b.py       |  22 ++++
 examples/opt/opt_30b_en_mutigpu.py     |   3 -
 examples/opt/opt_66b_en_mutigpu.py     | 108 ++++++++++++++++
 flagai/auto_model/auto_loader.py       |   8 ++
 flagai/model/base_model.py             |   2 +
 flagai/model/blocks/gpt2_block.py      |   2 +
 flagai/model/gpt2_model.py             | 133 ++++++++++++-------
 flagai/model/layers/attentions.py      |  50 ++++----
 flagai/model/opt_model.py              | 169 +------------------------
 flagai/model/predictor/gpt.py          |  55 ++++++++
 flagai/model/predictor/predictor.py    |   4 +-
 flagai/mp_tools.py                     |  12 +-
 14 files changed, 430 insertions(+), 248 deletions(-)
 create mode 100644 examples/opt/generate_opt_66b.py
 create mode 100644 examples/opt/opt_66b_en_mutigpu.py
 create mode 100644 flagai/model/predictor/gpt.py

diff --git a/examples/glm_title_generation/train.py b/examples/glm_title_generation/train.py
index f7dd2654..e06d2c0b 100644
--- a/examples/glm_title_generation/train.py
+++ b/examples/glm_title_generation/train.py
@@ -27,12 +27,16 @@
     num_checkpoints=1,
 )
 
-cur_dir = os.path.dirname(os.path.abspath(__file__))
-src_dir = cur_dir + '/data/train.src'
-tgt_dir = cur_dir + '/data/train.tgt'
+# cur_dir = os.path.dirname(os.path.abspath(__file__))
+# src_dir = cur_dir + '/data/train.src'
+# tgt_dir = cur_dir + '/data/train.tgt'
+
+src_dir = "./data/train.src"
+tgt_dir = "./data/train.tgt"
+
 maxlen = 256
-auto_loader = AutoLoader("seq2seq",
+auto_loader = AutoLoader("lm",
                          model_name="GLM-large-ch",
                          model_dir="./state_dict/")
 model = auto_loader.get_model()
diff --git a/examples/opt/README.md b/examples/opt/README.md
index 4ad5aa4f..ee727932 100644
--- a/examples/opt/README.md
+++ b/examples/opt/README.md
@@ -52,4 +52,100 @@ out = predictor.predict_generate_randomsample(text,
                                               repetition_penalty=3.0)
 
 print(f"input is {text} \n out is {out}")
-```
\ No newline at end of file
+```
+
+# Multi-GPU inference
+## OPT-30b
+
+To run inference with multiple GPUs and model parallelism, we use torch-DDP and the Megatron-LM library.
+### Basic steps
+1. Set the model-parallel parameters, such as ```model_parallel_size``` and ```world_size```
+2. Initialize torch-DDP
+3. Initialize Megatron-LM model parallelism
+4. Set the random seed
+5. Initialize the model and tokenizer
+6. Run prediction
+### Code
+```python
+import torch
+import os
+import argparse
+from flagai import mpu
+from flagai.auto_model.auto_loader import AutoLoader
+import random
+import numpy as np
+from flagai.model.predictor.predictor import Predictor
+
+# run script : python -m torch.distributed.launch --nproc_per_node=4 --nnodes=1 opt_30b_en_mutigpu.py
+os.environ["ENV_TYPE"] = "deepspeed+mpu"
+model_parallel_size = 4
+world_size = 4
+
+os.environ["MODEL_PARALLEL_SIZE"] = str(model_parallel_size)
+os.environ["WORLD_SIZE"] = str(world_size)
+
+def set_random_seed(seed):
+    """Set random seed for reproducibility."""
+    if seed is not None and seed > 0:
+        random.seed(seed)
+        np.random.seed(seed)
+        torch.manual_seed(seed)
+        mpu.model_parallel_cuda_manual_seed(seed)
+
+parser = argparse.ArgumentParser()
+parser.add_argument('--local_rank',
+                    type=int,
+                    default=0,
+                    help="local_rank")
+
+ds_args = parser.parse_args()
+local_rank = ds_args.local_rank
+
+master_addr = os.environ.get('MASTER_ADDR', '127.0.0.1')
+master_port = os.environ.get('MASTER_PORT', '17501')
+
+device = torch.device("cuda", local_rank)
+
+def initialize_distributed():
+    """Initialize torch.distributed."""
+    torch.backends.cudnn.enabled = False
+    # Manually set the device ids.
+    torch.cuda.set_device(device)
+    # Call the init process
+    init_method = 'tcp://'
+
+    init_method += master_addr + ':' + master_port
+    torch.distributed.init_process_group(
+        backend='nccl',  # gloo
+        world_size=world_size,
+        rank=local_rank,
+        init_method=init_method)
+    mpu.initialize_model_parallel(model_parallel_size)
+
+initialize_distributed()
+
+set_random_seed(123)
+
+print(f"building model...")
+loader = AutoLoader("lm", model_name="opt-30b-en")
+model = loader.get_model()
+tokenizer = loader.get_tokenizer()
+model.half()
+
+model.parallel_output = False
+model.eval()
+model.to(device)
+
+torch.distributed.barrier(group=mpu.get_model_parallel_group())
+
+text = """I think The Old Man and the Sea is a very good book, what do you think? I think """
+
+predictor = Predictor(model, tokenizer)
+out = predictor.predict_generate_randomsample(text)
+if mpu.get_model_parallel_rank() == 0:
+    print(f"pred is {out}")
+```
+### Run script
+```commandline
+python -m torch.distributed.launch --nproc_per_node=4 --nnodes=1 opt_30b_en_mutigpu.py
+```
diff --git a/examples/opt/generate_opt_66b.py b/examples/opt/generate_opt_66b.py
new file mode 100644
index 00000000..bbee7987
--- /dev/null
+++ b/examples/opt/generate_opt_66b.py
@@ -0,0 +1,22 @@
+from flagai.model.predictor.predictor import Predictor
+from flagai.auto_model.auto_loader import AutoLoader
+import torch
+
+loader = AutoLoader(task_name="lm",
+                    model_name="opt-66b-en")
+
+model = loader.get_model()
+tokenizer = loader.get_tokenizer()
+model.eval()
+
+text = """I think The Old Man and the Sea is a very good book, what do you think? Thank you for your question, I think """
+
+predictor = Predictor(model, tokenizer)
+out = predictor.predict_generate_randomsample(text,
+                                              input_max_length=100,
+                                              out_max_length=300,
+                                              top_k=50,
+                                              top_p=0.9,
+                                              repetition_penalty=3.0)
+
+print(f"input is {text} \n out is {out}")
\ No newline at end of file
diff --git a/examples/opt/opt_30b_en_mutigpu.py b/examples/opt/opt_30b_en_mutigpu.py
index ac2a4a1f..623e769f 100644
--- a/examples/opt/opt_30b_en_mutigpu.py
+++ b/examples/opt/opt_30b_en_mutigpu.py
@@ -10,8 +10,6 @@
 import glob
 import time
 
-# run script : python -m torch.distributed.launch --nproc_per_node=4 --nnodes=1 opt_30b_en_mutigpu.py
-
 os.environ["ENV_TYPE"] = "deepspeed+mpu"
 model_parallel_size = 4
 world_size = 4
@@ -61,7 +59,6 @@ def initialize_distributed():
 
 set_random_seed(123)
 
-
 print(f"building model...")
 loader = AutoLoader("lm", model_name="opt-30b-en")
 model = loader.get_model()
diff --git a/examples/opt/opt_66b_en_mutigpu.py b/examples/opt/opt_66b_en_mutigpu.py
new file mode 100644
index 00000000..76616980
--- /dev/null
+++ b/examples/opt/opt_66b_en_mutigpu.py
@@ -0,0 +1,108 @@
+# os.environ["CUDA_VISIBLE_DEVICES"] = "0,2"
+import torch
+import os
+import time
+os.environ["ENV_TYPE"] = "deepspeed+mpu"
+os.environ["MODEL_PARALLEL_SIZE"] = '8'
+os.environ["WORLD_SIZE"] = '8'
+import argparse
+from flagai import mpu
+import random
+import numpy as np
+from flagai.model.predictor.predictor import Predictor
+from flagai.model.opt_model import OPTModel
+from flagai.data.tokenizer import OPTTokenizer
+
+def get_current_rank():
+    with open('current_rank','r',encoding='utf8') as infile:
+        line = infile.readline().strip()
+        return int(line)
+def set_current_rank(rank):
+    with open('current_rank','w',encoding='utf8') as outfile:
+        outfile.write(str(rank))
+
+def get_current_pool():
+    with open('current_pool','r',encoding='utf8') as infile:
+        line = infile.readline().strip()
+        return int(line)
+
+def set_current_pool(rank):
+    with open('current_pool','w',encoding='utf8') as outfile:
+        outfile.write(str(rank))
+
+# run script : python -m torch.distributed.launch --nproc_per_node=2 --nnodes=1 opt_66b_en_mutigpu.py
+parser = argparse.ArgumentParser()
+parser.add_argument('--local_rank',
+                    type=int,
+                    default=0,
+                    help="local_rank")
+
+def set_random_seed(seed):
+    """Set random seed for reproducibility."""
+    if seed is not None and seed > 0:
+        random.seed(seed)
+        np.random.seed(seed)
+        torch.manual_seed(seed)
+        mpu.model_parallel_cuda_manual_seed(seed)
+
+ds_args = parser.parse_args()
+local_rank = ds_args.local_rank
+
+master_addr = os.environ.get('MASTER_ADDR', '127.0.0.1')
+master_port = os.environ.get('MASTER_PORT', '17501')
+
+device = torch.device("cuda", local_rank)
+model_parallel_size = 8
+world_size = 8
+
+def initialize_distributed():
+    """Initialize torch.distributed."""
+    torch.backends.cudnn.enabled = False
+    # Manually set the device ids.
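+    # Each launcher process pins the GPU matching its local_rank, so the
+    # eight model-parallel shards end up on separate devices.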
+ torch.cuda.set_device(device) + # Call the init process + init_method = 'tcp://' + + init_method += master_addr + ':' + master_port + torch.distributed.init_process_group( + backend='nccl', # gloo + world_size=world_size, + rank=local_rank, + init_method=init_method) + mpu.initialize_model_parallel(model_parallel_size) + +initialize_distributed() + +set_current_pool(4) +set_current_rank(0) +set_random_seed(123) +torch.distributed.barrier(group=mpu.get_model_parallel_group()) +tokenizer = OPTTokenizer() + +while get_current_rank() != local_rank: + time.sleep(10) +while get_current_pool() == 0: + time.sleep(10) +set_current_pool(get_current_pool()-1) +print("loading rank {}".format(local_rank)) +set_current_rank(local_rank + 1) + +model = OPTModel.init_from_json('/mnt/models_xingzhaohu/opt-66b-en/config.json') +checkpoint_path = '/mnt/models_xingzhaohu/opt-66b-en/pytorch_model_{:02d}.bin'.format(local_rank) +model.half() +model.eval() +model.to(device) +model.load_weights(checkpoint_path) + +print("loading rank {} finished".format(local_rank)) +set_current_pool(get_current_pool()+1) +print('current rank setting is {}'.format(get_current_pool())) + +torch.distributed.barrier(group=mpu.get_model_parallel_group()) +text = """I think The Old Man and the Sea is a very good book, what do you think? I think """ + +predictor = Predictor(model, tokenizer) +out = predictor.predict_generate_randomsample(text) +if mpu.get_model_parallel_rank() == 0: + print(f"pred is {out}") + diff --git a/flagai/auto_model/auto_loader.py b/flagai/auto_model/auto_loader.py index bcd59b62..2c79a8fd 100644 --- a/flagai/auto_model/auto_loader.py +++ b/flagai/auto_model/auto_loader.py @@ -72,6 +72,8 @@ def __getattr__(self, name): "opt-6.7b-en": ["flagai.model.opt_model","OPTModel", "opt"], "opt-13b-en": ["flagai.model.opt_model","OPTModel", "opt"], "opt-30b-en": ["flagai.model.opt_model","OPTModel", "opt"], + "opt-66b-en": ["flagai.model.opt_model","OPTModel", "opt"], + } TOKENIZER_DICT = { @@ -96,6 +98,8 @@ def __getattr__(self, name): "opt-6.7b-en": ["flagai.data.tokenizer.opt.opt_en_tokenizer","OPTTokenizer"], "opt-13b-en": ["flagai.data.tokenizer.opt.opt_en_tokenizer","OPTTokenizer"], "opt-30b-en": ["flagai.data.tokenizer.opt.opt_en_tokenizer","OPTTokenizer"], + "opt-66b-en": ["flagai.data.tokenizer.opt.opt_en_tokenizer","OPTTokenizer"], + } @@ -106,6 +110,7 @@ def __init__(self, model_name: str = "RoBERTa-base-ch", model_dir: str = "./checkpoints/", only_download_config: bool = False, + device="cpu", **kwargs): """ Args: @@ -169,6 +174,7 @@ def __init__(self, download_path=model_dir, model_name=model_name_, only_download_config=only_download_config, + device=device, **kwargs) model_id = _get_model_id(model_name) @@ -178,6 +184,8 @@ def __init__(self, vocab_file = os.path.join(download_path,'cog-pretrained.model') if not os.path.exists(vocab_file): vocab_file = _get_vocab_path(download_path, "cog-pretrain.model", model_id) + elif model_name == "glm-large-en": + vocab_file = "GLM-large-en" elif model_name == "cpm-large-ch": # two files to load vocab_file_1 = os.path.join(download_path, "vocab.json") diff --git a/flagai/model/base_model.py b/flagai/model/base_model.py index 2005fc43..5480b73b 100644 --- a/flagai/model/base_model.py +++ b/flagai/model/base_model.py @@ -45,6 +45,7 @@ def from_pretrain(cls, download_path='./checkpoints/', model_name='RoBERTa-base-ch', only_download_config=False, + device="cpu", **kwargs): model_id = None try: @@ -87,6 +88,7 @@ def from_pretrain(cls, model_id) if 
os.path.exists(config_path): model = cls.init_from_json(config_path, **kwargs) + model.to(device) if os.getenv('ENV_TYPE') != 'deepspeed+mpu': if os.path.exists(checkpoint_path): model.load_weights(checkpoint_path) diff --git a/flagai/model/blocks/gpt2_block.py b/flagai/model/blocks/gpt2_block.py index 27737b89..925419de 100644 --- a/flagai/model/blocks/gpt2_block.py +++ b/flagai/model/blocks/gpt2_block.py @@ -21,6 +21,7 @@ def __init__(self, n_ctx, config, scale=False): def forward( self, hidden_states, + layer_past=None, attention_mask=None, head_mask=None, use_cache=False, @@ -34,6 +35,7 @@ def forward( attn_outputs = self.attn( hidden_states, + layer_past=layer_past, attention_mask=attention_mask, head_mask=head_mask, use_cache=use_cache, diff --git a/flagai/model/gpt2_model.py b/flagai/model/gpt2_model.py index 427c4be6..ec00a943 100644 --- a/flagai/model/gpt2_model.py +++ b/flagai/model/gpt2_model.py @@ -100,6 +100,8 @@ def __init__(self, config): self.wte = nn.Embedding(config.vocab_size, config.n_embd) self.wpe = nn.Embedding(config.n_positions, config.n_embd) self.drop = nn.Dropout(config.embd_pdrop) + self.project_in = None + self.project_out = None self.h = nn.ModuleList([ GPT2Block(config.n_ctx, config, scale=True) for _ in range(config.n_layer) @@ -114,12 +116,27 @@ def get_input_embeddings(self): def set_input_embeddings(self, new_embeddings): self.wte = new_embeddings + def get_position_embeddings(self, **kwargs): + input_ids = kwargs["input_ids"] + input_shape = input_ids.size() + position_ids = kwargs.get("position_ids", None) + past_length = kwargs["past_length"] + if position_ids is None: + device = input_ids.device + position_ids = torch.arange(past_length, input_shape[-1] + past_length, dtype=torch.long, device=device) + position_ids = position_ids.unsqueeze(0).view(-1, input_shape[-1]) + + position_embeds = self.wpe(position_ids) + return position_embeds + + def forward( self, input_ids, attention_mask=None, + past_key_values=None, position_ids=None, - use_cache=None, + use_cache=False, output_attentions=None, output_hidden_states=None, ): @@ -131,13 +148,22 @@ def forward( if position_ids is not None: position_ids = position_ids.view(-1, input_shape[-1]) - if position_ids is None: - device = input_ids.device - position_ids = torch.arange(0, - input_shape[-1], - dtype=torch.long, - device=device) - position_ids = position_ids.unsqueeze(0).view(-1, input_shape[-1]) + device = input_ids.device + + + if past_key_values is None: + past_length = 0 + past_key_values = tuple([None] * len(self.h)) + full_ids = input_ids + else: + past_length = past_key_values[0][0].size(-2) + full_ids = torch.ones((input_ids.shape[0], past_length + 1), dtype=torch.long, device=device) + + padding_mask = (full_ids > 0).float() + + position_embeds = self.get_position_embeddings(input_ids=input_ids, past_length=past_length, + position_ids=position_ids, padding_mask=padding_mask, + ) # Attention mask. 
if attention_mask is not None: @@ -145,17 +171,21 @@ def forward( attention_mask = (1.0 - attention_mask) * -10000.0 inputs_embeds = self.wte(input_ids) - position_embeds = self.wpe(position_ids) + + output_shape = input_shape + (inputs_embeds.size(-1), ) + + if self.project_in is not None: + inputs_embeds = self.project_in(inputs_embeds) + + # position_embeds = self.wpe(position_ids) hidden_states = inputs_embeds + position_embeds hidden_states = self.drop(hidden_states) - output_shape = input_shape + (hidden_states.size(-1), ) - presents = () if use_cache else None all_self_attentions = () if output_attentions else None all_hidden_states = () if output_hidden_states else None - for i, block in enumerate(self.h): + for i, (block, layer_past) in enumerate(zip(self.h, past_key_values)): if output_hidden_states: all_hidden_states = all_hidden_states + (hidden_states, ) @@ -171,6 +201,7 @@ def custom_forward(*inputs): outputs = checkpoint( create_custom_forward(block), hidden_states, + None, attention_mask, None, use_cache, @@ -180,6 +211,7 @@ def custom_forward(*inputs): outputs = block( hidden_states, + layer_past=layer_past, attention_mask=attention_mask, head_mask=None, use_cache=use_cache, @@ -194,14 +226,19 @@ def custom_forward(*inputs): all_self_attentions = all_self_attentions + ( outputs[2 if use_cache else 1], ) - hidden_states = self.ln_f(hidden_states) + # hidden_states = self.ln_f(hidden_states) + if self.ln_f is not None: + hidden_states = self.ln_f(hidden_states) + + if self.project_out is not None: + hidden_states = self.project_out(hidden_states) hidden_states = hidden_states.view(*output_shape) # Add last hidden state if output_hidden_states: all_hidden_states = all_hidden_states + (hidden_states, ) - return hidden_states + return hidden_states, presents class GPT2Model(BaseModel): @@ -245,6 +282,7 @@ def forward( **data, ): input_ids = data.get("input_ids", None) + past_key_values = data.get("past_key_values", None) attention_mask = data.get("attention_mask", None) position_ids = data.get("position_ids", None) labels = data.get("labels", None) @@ -252,11 +290,18 @@ def forward( output_attentions = data.get("output_attentions", None) output_hidden_states = data.get("output_hidden_states", None) + device = input_ids.device extend_mask = (input_ids > 0).float() if attention_mask is None: - attention_mask = self._make_causal_mask(input_ids) - extend_mask = extend_mask.unsqueeze(1).unsqueeze( - 1) * attention_mask + + if past_key_values is not None: + past_length = past_key_values[0][0].size(-2) + full_ids = torch.zeros((input_ids.shape[0], past_length + 1), dtype=torch.long, device=device) + extend_mask = self._make_causal_mask(full_ids) + else : + attention_mask = self._make_causal_mask(input_ids) + extend_mask = extend_mask.unsqueeze(1).unsqueeze( + 1) * attention_mask transformer_outputs = self.transformer( input_ids, @@ -265,16 +310,17 @@ def forward( use_cache=use_cache, output_attentions=output_attentions, output_hidden_states=output_hidden_states, + past_key_values=past_key_values, ) - logits = transformer_outputs + + logits, past_key_values = transformer_outputs + if os.getenv("ENV_TYPE") == 'deepspeed+mpu': logits_parallel = copy_to_model_parallel_region(logits) else: logits_parallel = logits - # if self.output_predict: - # Parallel logits. 
logits_parallel = F.linear(logits_parallel, self.transformer.wte.weight) @@ -287,36 +333,27 @@ def forward( shift_logits.contiguous().float(), shift_labels).mean() else: loss = F.cross_entropy( - shift_logits.contiguous().float(), shift_labels.long()) - - if self.parallel_output: # Put in different GPUs - return { - 'logits': logits_parallel, - 'loss': loss, - 'hidden_states': None, - } - else: - return { - "logits": - gather_from_model_parallel_region(logits_parallel), - "loss": - loss, - "hidden_states": - None, - } + shift_logits.view(-1, shift_logits.shape[-1]).contiguous().float(), shift_labels.view(-1).contiguous().long()) + + return { + 'logits': logits_parallel, + 'loss': loss, + 'hidden_states': past_key_values, + } + else: - if self.parallel_output: # Put in different GPUs - return { - 'logits': logits_parallel, - 'hidden_states': None, - } + + if os.getenv("ENV_TYPE") == 'deepspeed+mpu': + logits = gather_from_model_parallel_region(logits_parallel) else: - return { - "logits": - gather_from_model_parallel_region(logits_parallel), - "hidden_states": - None, - } + logits = logits_parallel + return { + "logits": + logits, + "hidden_states": + past_key_values, + } + def load_weights(self, checkpoint_path): checkpoint = torch.load(checkpoint_path, diff --git a/flagai/model/layers/attentions.py b/flagai/model/layers/attentions.py index 0cfa5065..3fcf17cf 100644 --- a/flagai/model/layers/attentions.py +++ b/flagai/model/layers/attentions.py @@ -96,12 +96,11 @@ def _attn(self, v, attention_mask=None, head_mask=None, - output_attentions=False): - w = torch.matmul(q, k) + ): + w = torch.matmul(q, k.transpose(-1, -2)) if self.scale: w = w / (float(v.size(-1))**0.5) - nd, ns = w.size(-2), w.size(-1) # if not self.is_cross_attention: # if only "normal" attention layer implements causal mask @@ -119,17 +118,16 @@ def _attn(self, if head_mask is not None: w = w * head_mask w = w.to(v.dtype) # fp16 - outputs = (torch.matmul(w, v), ) - if output_attentions: - outputs += (w, ) - return outputs + outputs = torch.matmul(w, v) + + return outputs, w def merge_heads(self, x): x = x.permute(0, 2, 1, 3).contiguous() new_x_shape = x.size()[:-2] + (x.size(-2) * x.size(-1), ) return x.view(*new_x_shape) # in Tensorflow implem: fct merge_states - def split_heads(self, x, k=False): + def split_heads(self, x): if os.getenv('ENV_TYPE') == 'deepspeed+mpu': new_x_shape = x.size()[:-1] + ( self.num_attention_heads_per_partition, @@ -138,44 +136,52 @@ def split_heads(self, x, k=False): new_x_shape = x.size()[:-1] + (self.n_head, x.size(-1) // self.n_head) x = x.view(*new_x_shape) # in Tensorflow implem: fct split_states - if k: - return x.permute(0, 2, 3, - 1) # (batch, head, head_features, seq_length) - else: - return x.permute(0, 2, 1, - 3) # (batch, head, seq_length, head_features) + # if k: + # return x.permute(0, 2, 3, + # 1) # (batch, head, head_features, seq_length) + # else: + return x.permute(0, 2, 1, + 3) # (batch, head, seq_length, head_features) def forward( self, hidden_states, + layer_past=None, attention_mask=None, head_mask=None, use_cache=False, output_attentions=False, ): - query, key, value = self.c_attn(hidden_states).split(self.split_size, dim=2) query = self.split_heads(query) - key = self.split_heads(key, k=True) + key = self.split_heads(key) value = self.split_heads(value) + if layer_past is not None: + past_key, past_value = layer_past + key = torch.cat((past_key, key), dim=-2) + value = torch.cat((past_value, value), dim=-2) + if use_cache is True: - present = (key.transpose(-2, -1), 
value - ) # transpose to have same shapes + present = (key, value + ) else: present = None attn_outputs = self._attn(query, key, value, attention_mask, head_mask, - output_attentions) + ) a = attn_outputs[0] - + if layer_past is not None: + a = a[:, :, -1:] a = self.merge_heads(a) a = self.c_proj(a) a = self.resid_dropout(a) - - return (a, present) + attn_outputs[1:] # a, present, (attentions) + outputs = (a, present) + if output_attentions: + outputs += (attn_outputs[1]) + return outputs # a, present, (attentions) class T5Attention(nn.Module): diff --git a/flagai/model/opt_model.py b/flagai/model/opt_model.py index ef89695d..874cb2c2 100644 --- a/flagai/model/opt_model.py +++ b/flagai/model/opt_model.py @@ -24,74 +24,6 @@ from flagai.model.gpt2_model import GPT2Model, GPT2Stack, GPT2Config from torch.utils.checkpoint import checkpoint - -# class GPT2Config: -# -# def __init__( -# self, -# vocab_size=50257, -# n_positions=1024, -# n_ctx=1024, -# n_embd=768, -# n_layer=12, -# n_head=12, -# n_inner=None, -# activation_function="gelu_new", -# resid_pdrop=0.1, -# embd_pdrop=0.1, -# attn_pdrop=0.1, -# layer_norm_epsilon=1e-5, -# initializer_range=0.02, -# summary_type="cls_index", -# summary_use_proj=True, -# summary_activation=None, -# summary_proj_to_labels=True, -# summary_first_dropout=0.1, -# scale_attn_weights=True, -# gradient_checkpointing=False, -# use_cache=True, -# bos_token_id=50256, -# eos_token_id=50256, -# checkpoint_activations=False, -# hidden_size=768, -# ): -# self.checkpoint_activations = checkpoint_activations -# self.vocab_size = vocab_size -# # self.n_ctx = n_ctx -# self.n_positions = n_positions -# self.n_ctx = n_positions -# self.n_embd = n_embd -# self.hidden_size = hidden_size -# self.n_layer = n_layer -# self.n_head = n_head -# self.n_inner = n_inner -# self.activation_function = activation_function -# self.resid_pdrop = resid_pdrop -# self.embd_pdrop = embd_pdrop -# self.attn_pdrop = attn_pdrop -# self.layer_norm_epsilon = layer_norm_epsilon -# self.initializer_range = initializer_range -# self.summary_type = summary_type -# self.summary_use_proj = summary_use_proj -# self.summary_activation = summary_activation -# self.summary_first_dropout = summary_first_dropout -# self.summary_proj_to_labels = summary_proj_to_labels -# self.gradient_checkpointing = gradient_checkpointing -# self.scale_attn_weights = scale_attn_weights -# self.use_cache = use_cache -# -# self.bos_token_id = bos_token_id -# self.eos_token_id = eos_token_id - - -_CHECKPOINT_FOR_DOC = "facebook/opt-350m" -_CONFIG_FOR_DOC = "OPTConfig" -_TOKENIZER_FOR_DOC = "GPT2Tokenizer" - -# Base model docstring -_EXPECTED_OUTPUT_SHAPE = [1, 8, 1024] - - OPT_PRETRAINED_MODEL_ARCHIVE_LIST = [ "facebook/opt-125m", "facebook/opt-350m", @@ -120,22 +52,11 @@ def forward(self, attention_mask: torch.LongTensor, past_key_values_length: int # create positions depending on attention_mask positions = (torch.cumsum(attention_mask, dim=1).type_as(attention_mask) * attention_mask).long() - 1 - # cut positions if `past_key_values_length` is > 0 positions = positions[:, past_key_values_length:] return super().forward(positions + self.offset) -def make_causal_mask(input_ids): - device = input_ids.device - bsz, tgt_len = input_ids.shape - mask = torch.full((tgt_len, tgt_len), 0.0).to(device) - mask_cond = torch.arange(mask.size(-1)).to(device) - mask.masked_fill_(mask_cond < (mask_cond + 1).view(mask.size(-1), 1), - 1.0) - - return mask[None, None, :, :].expand(bsz, 1, tgt_len, tgt_len) - class OPTStack(GPT2Stack): def 
__init__(self, config: GPT2Config): super(OPTStack, self).__init__(config) @@ -154,88 +75,12 @@ def __init__(self, config: GPT2Config): else: self.project_in = None - def forward( - self, - input_ids, - attention_mask=None, - position_ids=None, - use_cache=None, - output_attentions=None, - output_hidden_states=None, - ): - input_shape = input_ids.size() - input_ids = input_ids.view(-1, input_shape[-1]) - batch_size = input_ids.shape[0] - - extend_mask = (input_ids > 0).float() - position_embeds = self.wpe(extend_mask, 0) - - # if attention_mask is None: - attention_mask = make_causal_mask(input_ids) - attention_mask = extend_mask.unsqueeze(1).unsqueeze( - 1) * attention_mask - attention_mask = (1.0 - attention_mask) * -10000.0 - - inputs_embeds = self.wte(input_ids) - if self.project_in is not None: - inputs_embeds = self.project_in(inputs_embeds) - hidden_states = inputs_embeds + position_embeds - - hidden_states = self.drop(hidden_states) - - # output_shape = input_shape + (hidden_states.size(-1), ) - - presents = () if use_cache else None - all_self_attentions = () if output_attentions else None - all_hidden_states = () if output_hidden_states else None - - for i, block in enumerate(self.h): - if output_hidden_states: - all_hidden_states = all_hidden_states + (hidden_states, ) - if self.config.checkpoint_activations: - def create_custom_forward(module): - def custom_forward(*inputs): - return module(*inputs) - return custom_forward - - outputs = checkpoint( - create_custom_forward(block), - hidden_states, - attention_mask, - None, - use_cache, - output_attentions, - ) - else: - - outputs = block( - hidden_states, - attention_mask=attention_mask, - head_mask=None, - use_cache=use_cache, - output_attentions=output_attentions, - ) - - hidden_states = outputs[0] - if use_cache is True: - presents = presents + (outputs[1], ) - - if output_attentions: - all_self_attentions = all_self_attentions + ( - outputs[2 if use_cache else 1], ) - - if self.ln_f is not None: - hidden_states = self.ln_f(hidden_states) - - if self.project_out is not None: - hidden_states = self.project_out(hidden_states) - - # hidden_states = hidden_states.view(*output_shape) - # Add last hidden state - if output_hidden_states: - all_hidden_states = all_hidden_states + (hidden_states, ) - - return hidden_states + def get_position_embeddings(self, **kwargs): + pass + padding_mask = kwargs["padding_mask"] + past_length = kwargs["past_length"] + position_embeds = self.wpe(padding_mask, past_length) + return position_embeds def trans_opt_to_gpt_config(opt_config_json): trans_config_json = {} @@ -262,7 +107,6 @@ class OPTModel(GPT2Model): def __init__(self, config, **kwargs): config = trans_opt_to_gpt_config(config) super(OPTModel, self).__init__(config, **kwargs) - # self.config = config self.transformer = OPTStack(self.config) def load_weights(self, checkpoint_path): @@ -279,7 +123,6 @@ def load_weights(self, checkpoint_path): else : checkpoint_[k] = v - checkpoint = self.transpose_weight(checkpoint_) self.load_state_dict(checkpoint, strict=False) self.lm_head.weight.data = nn.Parameter(self.transformer.wte.weight.data) diff --git a/flagai/model/predictor/gpt.py b/flagai/model/predictor/gpt.py new file mode 100644 index 00000000..e99d11f4 --- /dev/null +++ b/flagai/model/predictor/gpt.py @@ -0,0 +1,55 @@ +from flagai.model.predictor.utils import RepetitionPenaltyLogitsProcessor, TemperatureLogitsProcessor, TopPLogitsProcessor, TopKLogitsProcessor, ListProcessor +import torch +import torch.nn.functional as F + + +def 
gpt_random_sample_use_cache(model, tokenizer, text, input_max_length, out_max_length, + top_k, top_p, repetition_penalty, temperature, device): + tokenizer_out = tokenizer.encode_plus(text, max_length=input_max_length) + token_ids = tokenizer_out["input_ids"] + token_end_id = tokenizer.token_end_id + if token_ids[-1] == token_end_id: + token_ids = token_ids[:-1] + + lp = [ + RepetitionPenaltyLogitsProcessor(penalty=repetition_penalty), + TemperatureLogitsProcessor(temperature=temperature), + TopKLogitsProcessor(top_k=top_k), + TopPLogitsProcessor(top_p=top_p), + ] + list_processor = ListProcessor(lp) + + token_ids = torch.tensor(token_ids, device=device, + dtype=torch.long).view(1, -1) + output_ids = [] + sep_id = tokenizer.token_end_id + outputs = model(**{"input_ids": token_ids, "use_cache": True}) + scores = outputs["logits"] + past_key_values = outputs["hidden_states"] + + logit_score = torch.log_softmax(scores[:, -1], dim=-1) + logit_score[:, tokenizer.token_unk_id] = -float('Inf') + + filtered_logits = list_processor(token_ids, logit_score) + next_token = torch.multinomial(F.softmax(filtered_logits, dim=-1), + num_samples=1) + token_ids = torch.cat([token_ids, next_token.long()], dim=1) + + with torch.no_grad(): + for step in range(out_max_length - 1): + outputs = model(**{"input_ids": next_token, "use_cache": True, "past_key_values": past_key_values}) + scores = outputs["logits"] + past_key_values = outputs["hidden_states"] + + logit_score = torch.log_softmax(scores[:, -1], dim=-1) + logit_score[:, tokenizer.token_unk_id] = -float('Inf') + + filtered_logits = list_processor(token_ids, logit_score) + next_token = torch.multinomial(F.softmax(filtered_logits, dim=-1), + num_samples=1) + if sep_id == next_token.item(): + break + output_ids.append(next_token.item()) + token_ids = torch.cat((token_ids, next_token.long()), dim=1) + + return tokenizer.decode(output_ids) \ No newline at end of file diff --git a/flagai/model/predictor/predictor.py b/flagai/model/predictor/predictor.py index a7a414e3..fa426945 100644 --- a/flagai/model/predictor/predictor.py +++ b/flagai/model/predictor/predictor.py @@ -8,6 +8,8 @@ t5_random_sample, gpt_random_sample, \ t5_beamsearch, gpt_beamsearch, bert_random_sample, glm_beamsearch, glm_random_sample from typing import List, Union, Dict, Tuple, Any +from flagai.model.predictor.gpt import gpt_random_sample_use_cache + class Predictor: @@ -277,7 +279,7 @@ def predict_generate_randomsample(self, device) elif "gpt" in self.class_name.lower() or "opt" in self.class_name.lower(): - return gpt_random_sample(self.model, self.tokenizer, text, + return gpt_random_sample_use_cache(self.model, self.tokenizer, text, input_max_length, out_max_length, top_k, top_p, repetition_penalty, temperature, device) diff --git a/flagai/mp_tools.py b/flagai/mp_tools.py index 5f172541..de5394af 100644 --- a/flagai/mp_tools.py +++ b/flagai/mp_tools.py @@ -48,16 +48,17 @@ def check_pytorch_model_mp_size(checkpoint: str, target_mp: int): """ check the checkpoints contains the weights for mp_size = target_mp """ + assert target_mp > 1 assert os.path.isdir(checkpoint) filenames = os.listdir(checkpoint) filenames = [ filename for filename in filenames - if filename.startswith("pytorch_model") + if filename.startswith("pytorch_model_") ] - if 'pytorch_model.bin' in filenames and target_mp == 1: - return True - else: - filenames.remove('pytorch_model.bin') + # if 'pytorch_model.bin' in filenames and target_mp == 1: + # return True + # else: + # filenames.remove('pytorch_model.bin') print( 
"check the weight files in {}, the number of mp_size({}) {} num_of_files({})" .format(checkpoint, target_mp, @@ -233,7 +234,6 @@ def change_pytorch_model_mp_from_1_to_n_new(model_name_brief, checkpoint: str, t d_new[k] = None d_new['module'] = {} with torch.no_grad(): - if "module" in d: d = d["module"]