Fix attention mask for glm4 (#884)
eaidova authored Aug 30, 2024
1 parent af8c28d commit d6e6e1f
Showing 1 changed file with 3 additions and 4 deletions.
7 changes: 3 additions & 4 deletions optimum/exporters/openvino/model_patcher.py
@@ -308,10 +308,9 @@ def _chatglm2_core_attention_forward(self, query_layer, key_layer, value_layer, attention_mask)


 def _glm4_core_attention_forward(self, query_layer, key_layer, value_layer, attention_mask):
-    attention_mask = ~attention_mask
-    context_layer = torch.nn.functional.scaled_dot_product_attention(
-        query_layer, key_layer, value_layer, attention_mask.to(torch.float32)
-    )
+    causal_mask = torch.zeros_like(attention_mask, dtype=torch.float32)
+    causal_mask.masked_fill_(attention_mask, float("-inf"))
+    context_layer = torch.nn.functional.scaled_dot_product_attention(query_layer, key_layer, value_layer, causal_mask)
     context_layer = context_layer.transpose(1, 2).contiguous()
     new_context_layer_shape = context_layer.size()[:-2] + (self.hidden_size_per_partition,)
     context_layer = context_layer.reshape(*new_context_layer_shape)
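Note on the change: torch.nn.functional.scaled_dot_product_attention treats a floating-point attn_mask as an additive bias on the attention scores. The removed code inverted the boolean mask and cast it to float32, producing a 0/1 tensor, so masked positions were merely offset rather than blocked. The new code builds an explicit additive mask with -inf at masked positions. The snippet below is a minimal standalone sketch (not part of the patch) illustrating the difference; it assumes a boolean mask where True means "do not attend", as in the patched GLM4 function, and the tensor shapes are purely illustrative.

import torch
import torch.nn.functional as F

# (batch, heads, seq_len, head_dim) -- illustrative shapes only
q = torch.randn(1, 2, 4, 8)
k = torch.randn(1, 2, 4, 8)
v = torch.randn(1, 2, 4, 8)

# True above the diagonal = positions that must NOT be attended
bool_mask = torch.triu(torch.ones(4, 4, dtype=torch.bool), diagonal=1)

# Old behaviour: a 0/1 float tensor is interpreted as an additive bias,
# so masked positions are not excluded at all.
old_style = F.scaled_dot_product_attention(q, k, v, (~bool_mask).to(torch.float32))

# New behaviour: an additive mask with -inf at masked positions blocks them entirely.
causal_mask = torch.zeros_like(bool_mask, dtype=torch.float32)
causal_mask.masked_fill_(bool_mask, float("-inf"))
new_style = F.scaled_dot_product_attention(q, k, v, causal_mask)

print(torch.allclose(old_style, new_style))  # almost surely False: the 0/1 float mask masks nothing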
