Add GLM-4 and Later GLM Model (Draft) #31977
Conversation
Support Cache class
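For context, "Support Cache class" refers to the transformers `Cache`/`DynamicCache` abstraction that replaces raw past-key-value tuples. A minimal sketch of that API in isolation (shapes are arbitrary; how GLM wires it in is up to the PR):

import torch
from transformers import DynamicCache

# Past key/value states live in a Cache object and are updated per layer
# instead of being re-concatenated as plain tuples.
past_key_values = DynamicCache()

layer_idx = 0
key_states = torch.randn(1, 2, 5, 64)    # (batch, num_kv_heads, seq_len, head_dim)
value_states = torch.randn(1, 2, 5, 64)

# update() appends the new states and returns the full cached tensors.
key_states, value_states = past_key_values.update(key_states, value_states, layer_idx)
print(past_key_values.get_seq_length())  # 5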
Hi @zRzRzRzRzRzRzR ! Thanks for drafting the PR. The workflow has been failing due to the usage of TikToken. Once the converter script converts the tiktoken configuration to an HF tokenizer configuration, you won't need to import tiktoken during inference in tokenization_glm.py.
YEP! Will review today 🤗
Fixed this issue now~ Thanks!
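For context on the tiktoken point above: once the conversion script has written a standard HF tokenizer configuration (tokenizer.json / tokenizer_config.json), inference only needs the regular loading path. A minimal sketch, assuming a hypothetical converted checkpoint directory:

from transformers import AutoTokenizer

# The checkpoint path is hypothetical; after conversion the directory holds
# tokenizer files generated from the original tiktoken vocabulary, so no
# tiktoken import is needed at inference time.
tokenizer = AutoTokenizer.from_pretrained("path/to/converted-glm-4-checkpoint")
print(tokenizer("Hello GLM")["input_ids"])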
Fix attention mask for right padding
I am stopping the review as a LOT of the comments are still not addressed.
this file should be removed as we can map the GPT2Tokenizer directly and use it
same comment here, we can use GPT2TokenizerFast!
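A minimal sketch of what mapping to the existing fast tokenizer could look like, assuming the converted BPE files are compatible with it (the checkpoint path is hypothetical):

from transformers import GPT2TokenizerFast

# If the converted vocab/merges follow the GPT-2 BPE format, the existing
# fast tokenizer class can serve them directly, so no custom
# tokenization_glm.py is needed in the model folder.
tokenizer = GPT2TokenizerFast.from_pretrained("path/to/converted-glm-4-checkpoint")
print(tokenizer.tokenize("Hello GLM"))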
logger = logging.get_logger(__name__)


class GLMConfig(PretrainedConfig):
There is still the issue with the camel casing!
GLM is the name of our model, not Glm. Do we need to stick to camel case in this context as well?
Yes, it's the same as for LLaMA, which we set to Llama!
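A minimal sketch of the naming convention being asked for, following the LLaMA -> Llama precedent (the rename is illustrative, not the final code):

from transformers import PretrainedConfig

# Public class names use the capitalized form of the lowercased model type,
# so GLM maps to Glm* (GlmConfig, GlmModel, ...) rather than GLM*.
class GlmConfig(PretrainedConfig):
    model_type = "glm"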
This is the configuration class to store the configuration of a [`GLMModel`]. It is used to instantiate a Phi-3
model according to the specified arguments, defining the model architecture. Instantiating a configuration with the
Suggested change:
- This is the configuration class to store the configuration of a [`GLMModel`]. It is used to instantiate a Phi-3
- model according to the specified arguments, defining the model architecture. Instantiating a configuration with the
+ This is the configuration class to store the configuration of a [`GLMModel`]. It is used to instantiate a GLM
+ model according to the specified arguments, defining the model architecture. Instantiating a configuration with the
Fixed.
if is_flash_attn_2_available():
    from flash_attn.bert_padding import index_first_axis, pad_input, unpad_input  # noqa
    from flash_attn import flash_attn_func, flash_attn_varlen_func

    _flash_supports_window_size = "window_size" in list(inspect.signature(flash_attn_func).parameters)
again this was refactored
Fixed now.
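For reference, the refactor being pointed to centralizes the flash-attention plumbing in the library, so model files no longer import flash_attn at module level. A hedged sketch of relying on the shared helper instead (the exact import path should be verified against the transformers version being targeted):

# Sketch only: private helpers can move between releases.
from transformers.utils import is_flash_attn_2_available

if is_flash_attn_2_available():
    from transformers.modeling_flash_attention_utils import _flash_attention_forward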
def _get_unpad_data(attention_mask):
    seqlens_in_batch = attention_mask.sum(dim=-1, dtype=torch.int32)
    indices = torch.nonzero(attention_mask.flatten(), as_tuple=False).flatten()
    max_seqlen_in_batch = seqlens_in_batch.max().item()
    cu_seqlens = F.pad(torch.cumsum(seqlens_in_batch, dim=0, dtype=torch.int32), (1, 0))
    return (
        indices,
        cu_seqlens,
        max_seqlen_in_batch,
    )
same comment here
Fixed.
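"Same comment" presumably means this unpadding helper was centralized as well, so the model file does not need its own copy. A hedged sketch (again, verify the private import against the targeted release):

import torch

# Equivalent helper shipped next to the shared flash-attention forward.
from transformers.modeling_flash_attention_utils import _get_unpad_data

attention_mask = torch.tensor([[1, 1, 1, 0], [1, 1, 0, 0]])
indices, cu_seqlens, max_seqlen_in_batch = _get_unpad_data(attention_mask)
print(cu_seqlens)  # tensor([0, 3, 5], dtype=torch.int32)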
class GLMRotaryEmbedding(nn.Module):
    def __init__(self, dim, rope_theta=1, original_impl=False, device=None, dtype=None):
        super().__init__()
        inv_freq = 1.0 / (10000 ** (torch.arange(0, dim, 2, device=device).to(dtype=dtype) / dim))
        self.register_buffer("inv_freq", inv_freq)
        self.dim = dim
        self.original_impl = original_impl
        self.rope_theta = rope_theta

    def forward_impl(
        self,
        seq_len: int,
        n_elem: int,
        dtype: torch.dtype,
        device: torch.device,
        base: int = 10000,
    ):
        """Enhanced Transformer with Rotary Position Embedding.
        Derived from: https://github.com/labmlai/annotated_deep_learning_paper_implementations/blob/master/labml_nn/
        transformers/rope/__init__.py. MIT License:
        https://github.com/labmlai/annotated_deep_learning_paper_implementations/blob/master/license.
        """
        # $\Theta = {\theta_i = 10000^{\frac{2(i-1)}{d}}, i \in [1, 2, ..., \frac{d}{2}]}$
        base = base * self.rope_theta
        theta = 1.0 / (base ** (torch.arange(0, n_elem, 2, dtype=torch.float, device=device) / n_elem))

        # Create position indexes `[0, 1, ..., seq_len - 1]`
        seq_idx = torch.arange(seq_len, dtype=torch.float, device=device)

        # Calculate the product of position index and $\theta_i$
        idx_theta = torch.outer(seq_idx, theta).float()

        cache = torch.stack([torch.cos(idx_theta), torch.sin(idx_theta)], dim=-1).to(dtype=dtype)
        return cache

    def forward(self, max_seq_len, offset=0):
        return self.forward_impl(
            max_seq_len,
            self.dim,
            dtype=self.inv_freq.dtype,
            device=self.inv_freq.device,
        )
Again, same comment here: this is equivalent to LlamaRotaryEmbedding.
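For comparison, a standalone sketch of the Llama-style rotary embedding the comment refers to: inverse frequencies are registered once and forward() returns cos/sin tables for the requested positions, rather than a stacked cos/sin cache. This illustrates the pattern only and is not the exact transformers class:

import torch
from torch import nn


class RotaryEmbeddingSketch(nn.Module):
    def __init__(self, dim, base=10000.0, device=None):
        super().__init__()
        inv_freq = 1.0 / (base ** (torch.arange(0, dim, 2, dtype=torch.float32, device=device) / dim))
        self.register_buffer("inv_freq", inv_freq, persistent=False)

    @torch.no_grad()
    def forward(self, x, position_ids):
        # (batch, seq_len, dim/2): rotation angle for every position/frequency pair
        freqs = position_ids[:, :, None].float() * self.inv_freq[None, None, :]
        emb = torch.cat((freqs, freqs), dim=-1)
        return emb.cos().to(x.dtype), emb.sin().to(x.dtype)


# Usage: position_ids has shape (batch, seq_len); x only supplies the output dtype.
rope = RotaryEmbeddingSketch(dim=64)
cos, sin = rope(torch.empty(1, dtype=torch.float16), torch.arange(8)[None, :])
print(cos.shape)  # torch.Size([1, 8, 64])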
    return tensor_list


class SelfAttention(torch.nn.Module):
This comment is still waiting!
if self.multi_query_attention:
    self.num_multi_query_groups_per_partition = self.multi_query_group_num
    self.qkv_hidden_size = (
        self.projection_size + 2 * self.hidden_size_per_attention_head * self.multi_query_group_num
    )
again same comment about GQA and MQA
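The GQA/MQA comment refers to the naming and shape convention used across transformers models: a single `num_key_value_heads` (1 for MQA, between 1 and `num_attention_heads` for GQA) combined with a `repeat_kv` expansion, rather than `multi_query_attention` / `multi_query_group_num`. A sketch of that pattern, independent of this PR's code:

import torch


def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor:
    # (batch, num_kv_heads, seq, head_dim) -> (batch, num_kv_heads * n_rep, seq, head_dim),
    # so grouped key/value heads can be consumed by the standard attention math.
    batch, num_kv_heads, slen, head_dim = hidden_states.shape
    if n_rep == 1:
        return hidden_states
    hidden_states = hidden_states[:, :, None, :, :].expand(batch, num_kv_heads, n_rep, slen, head_dim)
    return hidden_states.reshape(batch, num_kv_heads * n_rep, slen, head_dim)


# With 32 query heads and 2 key/value heads, each key/value head is shared by 16 query heads.
num_attention_heads, num_key_value_heads, head_dim = 32, 2, 64
key_states = torch.randn(1, num_key_value_heads, 10, head_dim)
key_states = repeat_kv(key_states, num_attention_heads // num_key_value_heads)
print(key_states.shape)  # torch.Size([1, 32, 10, 64])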
The docs for this PR live here. All of your documentation changes will be reflected on that endpoint. The docs are available until 30 days after the last update.
🙏🏻
Feel free to ping me again for a review!
BTW @zRzRzRzRzRzRzR, I took over and am currently adding the model. You can find the new PR here: #33823; it should be ready pretty soon.
This is a draft and we will continue working on it.
Who can review?
@ArthurZucker