
Commit

fix attention mask for glm4
eaidova committed Aug 30, 2024
1 parent 9a18ae0 commit afd2b6d
Showing 1 changed file with 9 additions and 10 deletions.
19 changes: 9 additions & 10 deletions optimum/exporters/openvino/model_patcher.py
@@ -308,10 +308,9 @@ def _chatglm2_core_attention_forward(self, query_layer, key_layer, value_layer,


 def _glm4_core_attention_forward(self, query_layer, key_layer, value_layer, attention_mask):
-    attention_mask = ~attention_mask
-    context_layer = torch.nn.functional.scaled_dot_product_attention(
-        query_layer, key_layer, value_layer, attention_mask.to(torch.float32)
-    )
+    causal_mask = torch.zeros_like(attention_mask, dtype=torch.float32)
+    causal_mask.masked_fill_(attention_mask, float("-inf"))
+    context_layer = torch.nn.functional.scaled_dot_product_attention(query_layer, key_layer, value_layer, causal_mask)
     context_layer = context_layer.transpose(1, 2).contiguous()
     new_context_layer_shape = context_layer.size()[:-2] + (self.hidden_size_per_partition,)
     context_layer = context_layer.reshape(*new_context_layer_shape)
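In short, the patch stops passing a float-cast boolean mask to scaled_dot_product_attention (SDPA adds a float mask to the attention scores, so a 0/1 tensor masks nothing) and instead builds an additive mask with -inf at masked positions. A minimal, self-contained sketch of the same conversion, assuming the GLM-4 convention that attention_mask is boolean with True marking positions to mask out; tensor shapes are illustrative:

import torch

# Assumed GLM-4 convention: attention_mask is boolean, True = "do not attend here".
# Shapes are illustrative: (batch, heads, seq_len, head_dim).
query = torch.randn(1, 2, 4, 8)
key = torch.randn(1, 2, 4, 8)
value = torch.randn(1, 2, 4, 8)
attention_mask = torch.triu(torch.ones(1, 1, 4, 4, dtype=torch.bool), diagonal=1)

# Old behaviour: ~attention_mask cast to float gives a 0/1 tensor that SDPA *adds*
# to the attention logits, so nothing is actually masked.
# Fixed behaviour: an additive mask with -inf at masked positions and 0 elsewhere.
causal_mask = torch.zeros_like(attention_mask, dtype=torch.float32)
causal_mask.masked_fill_(attention_mask, float("-inf"))

context_layer = torch.nn.functional.scaled_dot_product_attention(query, key, value, causal_mask)
print(context_layer.shape)  # torch.Size([1, 2, 4, 8])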
@@ -404,9 +403,9 @@ def _llama_gemma_update_causal_mask_legacy(self, attention_mask, input_tensor, c
 offset = 0
 mask_shape = attention_mask.shape
 mask_slice = (attention_mask.eq(0.0)).to(dtype=dtype) * min_dtype
-causal_mask[
-    : mask_shape[0], : mask_shape[1], offset : mask_shape[2] + offset, : mask_shape[3]
-] = mask_slice
+causal_mask[: mask_shape[0], : mask_shape[1], offset : mask_shape[2] + offset, : mask_shape[3]] = (
+    mask_slice
+)

if (
self.config._attn_implementation == "sdpa"
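This hunk and the matching one in _dbrx_update_causal_mask_legacy below only re-wrap the same statement (black-style formatting); the underlying idiom is unchanged: a 0/1 padding mask is converted into an additive slice where padded positions get the dtype's minimum value, then written into the prepared causal mask. A small illustrative sketch of that idiom, with assumed dtype setup and tensor shapes rather than the real surrounding function:

import torch

# Illustrative values; dtype, min_dtype and the tensor shapes are assumptions,
# not the exact setup of the patched functions.
dtype = torch.float32
min_dtype = torch.finfo(dtype).min

# 4D padding mask: 1.0 = real token, 0.0 = padding; shape (batch, 1, q_len, kv_len).
attention_mask = torch.tensor([[1.0, 1.0, 1.0, 0.0]])[:, None, None, :].expand(1, 1, 4, 4)
causal_mask = torch.zeros(1, 1, 4, 4, dtype=dtype)

offset = 0
mask_shape = attention_mask.shape
# Padded positions (== 0.0) become min_dtype, everything else stays 0.
mask_slice = (attention_mask.eq(0.0)).to(dtype=dtype) * min_dtype
causal_mask[: mask_shape[0], : mask_shape[1], offset : mask_shape[2] + offset, : mask_shape[3]] = mask_slice
print(causal_mask[0, 0, 0])  # last column holds the large negative fill value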
@@ -1966,9 +1965,9 @@ def _dbrx_update_causal_mask_legacy(
 offset = 0
 mask_shape = attention_mask.shape
 mask_slice = (attention_mask.eq(0.0)).to(dtype=dtype) * min_dtype
-causal_mask[
-    : mask_shape[0], : mask_shape[1], offset : mask_shape[2] + offset, : mask_shape[3]
-] = mask_slice
+causal_mask[: mask_shape[0], : mask_shape[1], offset : mask_shape[2] + offset, : mask_shape[3]] = (
+    mask_slice
+)

if (
self.config._attn_implementation == "sdpa"
