Commit

Adding fix for llama CI failure caused by ttnn.experimental.tensor.typecast (#11765)

#0: Adding fix for CI failure on ttnn.experimental.typecast.
avoraTT authored Aug 22, 2024
1 parent c9a6748 commit 9022fbb
Showing 1 changed file with 2 additions and 2 deletions.
models/demos/t3000/llama2_70b/tt/llama_attention_optimized.py (2 additions, 2 deletions)

@@ -392,13 +392,13 @@ def prefill_attn_mqa(
 keys = self.layer_past[0]
 # Fill cache expects batch in dim0
 keys_reshaped = ttnn.reshape(keys, [self.max_batch_size, self.n_local_kv_heads, -1, self.head_dim])
-ttnn.fill_cache(keys_reshaped, ttnn.experimental.tensor.typecast(key_layer, ttnn.bfloat8_b), user_id)
+ttnn.fill_cache(keys_reshaped, ttnn.experimental.typecast(key_layer, ttnn.bfloat8_b), user_id)

 # FILL V CACHE
 values = self.layer_past[1]
 # Fill cache expects batch in dim0
 values_reshaped = ttnn.reshape(values, [self.max_batch_size, self.n_local_kv_heads, -1, self.head_dim])
-ttnn.fill_cache(values_reshaped, ttnn.experimental.tensor.typecast(value_layer, ttnn.bfloat8_b), user_id)
+ttnn.fill_cache(values_reshaped, ttnn.experimental.typecast(value_layer, ttnn.bfloat8_b), user_id)

 # SDPA
 attn_output = ttnn.transformer.scaled_dot_product_attention(
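For context, the change swaps the old ttnn.experimental.tensor.typecast call path for ttnn.experimental.typecast when casting the K/V slices to bfloat8_b before filling the KV cache. A minimal sketch of the updated pattern, reusing the names from the diff above for illustration and assuming key_layer is a ttnn tensor already on device:

# Illustrative sketch only; mirrors the call pattern from the diff above.
# Assumes key_layer, keys_reshaped, and user_id are set up as in prefill_attn_mqa.

# Old call path (source of the CI failure):
#   k_bf8 = ttnn.experimental.tensor.typecast(key_layer, ttnn.bfloat8_b)

# Updated call path used by this commit:
k_bf8 = ttnn.experimental.typecast(key_layer, ttnn.bfloat8_b)
ttnn.fill_cache(keys_reshaped, k_bf8, user_id)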
