From 3f4af70d33e7d03b91f895fc85e43d9a28b33d8c Mon Sep 17 00:00:00 2001 From: Fabrice TIERCELIN Date: Wed, 22 May 2024 20:40:42 +0200 Subject: [PATCH] Implement functions that no longer exists --- .../mpt/hf_prefixlm_converter.py | 64 +++++++++++++++++-- 1 file changed, 59 insertions(+), 5 deletions(-) diff --git a/llava/model/language_model/mpt/hf_prefixlm_converter.py b/llava/model/language_model/mpt/hf_prefixlm_converter.py index 8c1a648..74707bb 100644 --- a/llava/model/language_model/mpt/hf_prefixlm_converter.py +++ b/llava/model/language_model/mpt/hf_prefixlm_converter.py @@ -12,20 +12,74 @@ from typing import Any, Dict, List, Optional, Tuple, Union import torch from transformers.models.bloom.modeling_bloom import BaseModelOutputWithPastAndCrossAttentions, BloomForCausalLM, BloomModel, CausalLMOutputWithCrossAttentions, CrossEntropyLoss -from transformers.models.bloom.modeling_bloom import _expand_mask as _expand_mask_bloom -from transformers.models.bloom.modeling_bloom import _make_causal_mask as _make_causal_mask_bloom from transformers.models.bloom.modeling_bloom import logging from transformers.models.gpt2.modeling_gpt2 import GPT2LMHeadModel from transformers.models.gpt_neo.modeling_gpt_neo import GPTNeoForCausalLM from transformers.models.gpt_neox.modeling_gpt_neox import GPTNeoXForCausalLM from transformers.models.gptj.modeling_gptj import GPTJForCausalLM from transformers.models.opt.modeling_opt import OPTForCausalLM -from transformers.models.opt.modeling_opt import _expand_mask as _expand_mask_opt -from transformers.models.opt.modeling_opt import _make_causal_mask as _make_causal_mask_opt logger = logging.get_logger(__name__) _SUPPORTED_GPT_MODELS = (GPT2LMHeadModel, GPTJForCausalLM, GPTNeoForCausalLM, GPTNeoXForCausalLM) CAUSAL_GPT_TYPES = Union[GPT2LMHeadModel, GPTJForCausalLM, GPTNeoForCausalLM, GPTNeoXForCausalLM] +def _make_causal_mask_bloom( + input_ids_shape: torch.Size, device: torch.device, past_key_values_length: int +) -> torch.BoolTensor: + """ + Make causal mask used for self-attention. + """ + batch_size, target_length = input_ids_shape + mask = torch.empty((target_length, target_length + past_key_values_length), dtype=torch.bool, device=device) + # ONNX doesn't support `torch.Tensor.triu` properly, thus we use this workaround + seq_ids = torch.arange(target_length, device=device) + mask[:, past_key_values_length:] = seq_ids[:, None] < seq_ids[None, :] + + if past_key_values_length > 0: + mask[:, :past_key_values_length] = False + + expanded_mask = mask[None, None, :, :].expand(batch_size, 1, target_length, target_length + past_key_values_length) + return expanded_mask + +def _expand_mask_bloom(mask: torch.Tensor, tgt_length: int) -> torch.BoolTensor: + """ + Expands attention_mask from `[batch_size, src_length]` to `[batch_size, 1, tgt_length, src_length]`. + """ + batch_size, src_length = mask.shape + tgt_length = tgt_length if tgt_length is not None else src_length + + expanded_mask = ~(mask[:, None, None, :].to(torch.bool)) + return expanded_mask.expand(batch_size, 1, tgt_length, src_length) + +def _make_causal_mask_opt( + input_ids_shape: torch.Size, dtype: torch.dtype, device: torch.device, past_key_values_length: int = 0 +): + """ + Make causal mask used for bi-directional self-attention. + """ + bsz, tgt_len = input_ids_shape + mask = torch.full((tgt_len, tgt_len), torch.finfo(dtype).min, device=device) + mask_cond = torch.arange(mask.size(-1), device=device) + mask.masked_fill_(mask_cond < (mask_cond + 1).view(mask.size(-1), 1), 0) + mask = mask.to(dtype) + + if past_key_values_length > 0: + mask = torch.cat([torch.zeros(tgt_len, past_key_values_length, dtype=dtype, device=device), mask], dim=-1) + return mask[None, None, :, :].expand(bsz, 1, tgt_len, tgt_len + past_key_values_length) + + +def _expand_mask_opt(mask: torch.Tensor, dtype: torch.dtype, tgt_len: Optional[int] = None): + """ + Expands attention_mask from `[bsz, seq_len]` to `[bsz, 1, tgt_seq_len, src_seq_len]`. + """ + bsz, src_len = mask.size() + tgt_len = tgt_len if tgt_len is not None else src_len + + expanded_mask = mask[:, None, None, :].expand(bsz, 1, tgt_len, src_len).to(dtype) + + inverted_mask = 1.0 - expanded_mask + + return inverted_mask.masked_fill(inverted_mask.to(torch.bool), torch.finfo(dtype).min) + def _convert_gpt_causal_lm_to_prefix_lm(model: CAUSAL_GPT_TYPES) -> CAUSAL_GPT_TYPES: """Converts a GPT-style Causal LM to a Prefix LM. @@ -412,4 +466,4 @@ def add_bidirectional_mask_if_missing(batch: Dict[str, Any]): elif 'labels' in batch and 'attention_mask' in batch: batch['bidirectional_mask'] = torch.logical_and(torch.eq(batch['attention_mask'], 1), torch.eq(batch['labels'], -100)).type_as(batch['attention_mask']) else: - raise KeyError('No bidirectional_mask in batch and not sure how to construct one.') \ No newline at end of file + raise KeyError('No bidirectional_mask in batch and not sure how to construct one.')