From cc27fff742232e45f8852d42db330bb213f87e31 Mon Sep 17 00:00:00 2001 From: mingMelody <2416013822@qq.com> Date: Thu, 21 Nov 2024 12:18:39 +0000 Subject: [PATCH] fix bug for chatglm_v2's RotaryEmbedding dtype --- paddlenlp/transformers/chatglm_v2/modeling.py | 10 +++++----- 1 file changed, 5 insertions(+), 5 deletions(-) diff --git a/paddlenlp/transformers/chatglm_v2/modeling.py b/paddlenlp/transformers/chatglm_v2/modeling.py index 2d2af39c2b30..d8eb469e119b 100644 --- a/paddlenlp/transformers/chatglm_v2/modeling.py +++ b/paddlenlp/transformers/chatglm_v2/modeling.py @@ -98,7 +98,7 @@ class RotaryEmbedding(nn.Layer): def __init__(self, dim, original_impl=False): super().__init__() self.default_dtype = paddle.get_default_dtype() - inv_freq = 1.0 / (10000 ** (paddle.arange(0, dim, 2, dtype="float32") / dim)) + inv_freq = 1.0 / (10000 ** (paddle.arange(0, dim, 2, dtype=self.default_dtype) / dim)) self.register_buffer("inv_freq", inv_freq) self.dim = dim self.original_impl = original_impl @@ -113,16 +113,16 @@ def forward_impl(self, seq_len: int, n_elem: int, base: int = 10000): theta = 1.0 / (base ** (paddle.arange(0, n_elem, 2, dtype="float32") / n_elem)) # Create position indexes `[0, 1, ..., seq_len - 1]` - seq_idx = paddle.arange(0, seq_len, dtype=theta.dtype) + seq_idx = paddle.arange(0, seq_len, dtype="float32") # Calculate the product of position index and $\theta_i$ - idx_theta = paddle.outer(seq_idx, theta).astype(self.default_dtype) + idx_theta = paddle.outer(seq_idx, theta).astype("float32") cache = paddle.stack([paddle.cos(idx_theta), paddle.sin(idx_theta)], axis=-1) # this is to mimic the behaviour of complex32, else we will get different results - if self.default_dtype in (paddle.float16, paddle.bfloat16, paddle.int8): - cache = cache.astype(self.default_dtype) + if self.default_dtype in ("float16", "bfloat16", "int8"): + cache = cache.astype("bfloat16") if self.default_dtype == "bfloat16" else cache.astype("float16") # cache = cache.bfloat16() if dtype 
== paddle.bfloat16 else cache.astype("float16") return cache