Skip to content

Commit

Permalink
#6361: Update ttnn repeat to use correct shapes when formatting output
Browse files Browse the repository at this point in the history
  • Loading branch information
tt-aho committed Mar 19, 2024
1 parent e80763b commit 81ebe08
Showing 1 changed file with 4 additions and 4 deletions.
8 changes: 4 additions & 4 deletions ttnn/ttnn/operations/data_movement.py
Original file line number Diff line number Diff line change
Expand Up @@ -504,7 +504,7 @@ def _repeat_validate_input_tensors(operation_name, input_tensor, *args, **kwargs
input_tensor,
ranks=(2, 3, 4),
dtypes=(ttnn.bfloat16, ttnn.bfloat8_b, ttnn.uint16, ttnn.uint32),
layouts=(ttnn.TILE_LAYOUT,),
layouts=(ttnn.TILE_LAYOUT, ttnn.ROW_MAJOR_LAYOUT),
can_be_on_device=True,
can_be_on_cpu=True,
)
Expand Down Expand Up @@ -548,11 +548,11 @@ def repeat(
device = input_tensor.device()
layout = input_tensor.layout
rank = len(input_tensor.shape)
if dtype == ttnn.bfloat16 and rank == 4:
if rank == 4:
output_tensor = ttl.tensor.repeat(input_tensor, shape)
*batch, _, _ = output_tensor.shape
*_, h, w = input_tensor.shape
*_, padded_h, padded_w = input_tensor.shape.with_tile_padding()
*_, h, w = output_tensor.shape
*_, padded_h, padded_w = output_tensor.shape.with_tile_padding()

output_tensor = ttnn.reshape(output_tensor, shape=ttnn.Shape(batch + [h, w], batch + [padded_h, padded_w]))
return output_tensor
Expand Down

0 comments on commit 81ebe08

Please sign in to comment.