ttnn.repeat throws runtime memory error #6361

Closed

kpaigwar opened this issue Mar 13, 2024 · 26 comments
@kpaigwar
Contributor

Describe the bug
When trying to expand a tensor of shape [1, 1, 32, 16] to [1, 1, 32, 16*2048] using ttnn.repeat, the kernel fails with the following error:
FATAL | 32KB runtime args targeting kernel reader_concat_interleaved_start_id on (x=0,y=0) are too large. Cannot be written as they will run into memory region reserved for result. Max allowable size is 1 KB.

However, the same output tensor can be successfully created using the ttnn.repeat_interleave API.

To Reproduce
Run the test code below in your tt-metal environment:

import ttnn
import torch
from tests.ttnn.utils_for_testing import assert_with_pcc


device = ttnn.open_device(device_id=0)
torch_input_tensor = torch.randn((1, 1, 32, 16), dtype=torch.bfloat16)
repeat_shape = (1, 1, 1, 2048)

torch_result_repeat_interleave = torch_input_tensor.repeat_interleave(2048, dim=3)
torch_result_repeat = torch_input_tensor.repeat(repeat_shape)

repeat_tensor = torch.randn(repeat_shape, dtype=torch.bfloat16)
repeat_tensor = ttnn.from_torch(repeat_tensor, layout=ttnn.TILE_LAYOUT, device=device)

input_tensor = ttnn.from_torch(torch_input_tensor, layout=ttnn.TILE_LAYOUT, device=device)

repeat_interleaved_output = ttnn.repeat_interleave(input_tensor, 2048, dim=3)
repeat_interleaved_output = ttnn.to_torch(repeat_interleaved_output)
print(repeat_interleaved_output.shape)
assert_with_pcc(torch_result_repeat_interleave, repeat_interleaved_output, 0.9999)
print("Repeat Interleave Passed")

repeat_output = ttnn.repeat(input_tensor, repeat_tensor.shape)
repeat_output = ttnn.to_torch(repeat_output)
print(repeat_output.shape)
assert_with_pcc(torch_result_repeat_interleave, repeat_output, 0.9999)

ttnn.close_device(device)

Expected behavior
If ttnn.repeat_interleave passes, ttnn.repeat is also expected to work.

Screenshots
(Screenshot of the error output, 2024-03-13, attached in the original issue.)

@tt-aho
Contributor

tt-aho commented Mar 13, 2024

Updating this issue from chat:

I think it’s this many because they’re trying to repeat the tensor 2048 times, and repeat is implemented under the hood as concat. Concat works on a variable number of tensor inputs, but each tensor requires 4 args, so repeat is probably a naive wrapper that calls concat with 2048 copies of the same tensor, which results in 32KB of args.

For an actual repeat op that works similarly to concat, it would only need 4 + some other args instead of 4 * the number of tensors/repeats.
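
A back-of-the-envelope check of that arg count (a sketch only; the 4-args-per-input and 4-bytes-per-arg figures are assumptions taken from the description above):

num_repeats = 2048       # repeat factor requested on the last dim
args_per_input = 4       # assumed runtime args per concat input tensor
bytes_per_arg = 4        # assumed 4-byte (uint32) runtime args
total_bytes = num_repeats * args_per_input * bytes_per_arg
print(total_bytes // 1024, "KB")  # -> 32 KB, matching the FATAL message above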

@umadevimcw
Contributor

@tt-aho @kpaigwar To get clarification (I assume some conversation is going on): do we need to address this issue? What is the plan for this? Any suggestions/ideas?

@kpaigwar
Contributor Author

kpaigwar commented Mar 15, 2024

@umadevimcw, the repeat op is required for the performant Mamba model implementation, which we are targeting by the end of this month.

Requirement

Equivalent OP for torch.repeat

input_tensor = torch.rand(1, 1, 32, 16)
repeat_shape = (1, 1, 1, 2048)
output = input_tensor.repeat(repeat_shape)
# output.shape -> torch.Size([1, 1, 32, 32768]), i.e. 16 * 2048 on the last dim

Non-performant workaround

I have created a transformation matrix to do an equivalent repeat op, but it's a big matmul: (1, 1, 32, 16) @ (1, 1, 16, 16*2048). We want an implementation that is more performant than this matmul.

Unit test for above workaround

import ttnn
import torch
from tests.ttnn.utils_for_testing import assert_with_pcc

device = ttnn.open_device(device_id=0)
torch_input_tensor = torch.randn((1, 1, 32, 16), dtype=torch.bfloat16)
repeat_shape = (1, 1, 1, 2048)

torch_result_repeat = torch_input_tensor.repeat(repeat_shape)

input_tensor = ttnn.from_torch(torch_input_tensor, layout=ttnn.TILE_LAYOUT, device=device)

# [I | I | ... | I]: 2048 copies of the 16x16 identity tiled horizontally, so
# input @ transform == input repeated 2048 times along the last dim
repeat_transform_matrix = torch.eye(16).repeat(1, 2048).unsqueeze(0).unsqueeze(0)
repeat_transform_matrix = ttnn.from_torch(repeat_transform_matrix, layout=ttnn.TILE_LAYOUT, device=device, memory_config=ttnn.DRAM_MEMORY_CONFIG, dtype=ttnn.bfloat16)

repeat_transform_output = ttnn.matmul(input_tensor, repeat_transform_matrix, memory_config=ttnn.L1_MEMORY_CONFIG)
repeat_transform_output = ttnn.to_torch(repeat_transform_output)
print(repeat_transform_output.shape)
assert_with_pcc(torch_result_repeat, repeat_transform_output, 0.9999)
print("Repeat Transform Passed")
ttnn.close_device(device)

@KalaivaniMCW
Contributor

For the ttnn.repeat bug, I found the issues below:

  • While running unit tests for ttnn.repeat along different dimensions, the cases for the h and w dimensions failed. There is a shape mismatch in ttnn.repeat (data_movement.py) along the h and w dims: it currently reshapes the (h, w) of the repeat output back to the input shape. This needs to be changed so that the output tensor is reshaped to the correct h and w values.

  • As per the composite ops implementation, repeat calls the concat op for each dimension with a vector of n copies of the input tensor as the input argument, along with the dim number (n is the number of repeats for that dimension).
    So in this case, (1, 1, 1, 2048) with dim=3, concat is called with a vector of 2048 input tensors (see the sketch after this list).

  • As per concat.cpp, the h and w of the input tensors should be multiples of 32 if the concat happens on the h or w dim, so repeat fails when these conditions aren't met.

  • For this issue, I tried changing the input shape from (1, 1, 32, 16) to (1, 1, 32, 32) and running with a smaller repeat shape, say (1, 1, 1, 50), and it worked. Any higher value causes a hang and a segmentation fault (core dumped), and I have to restart the terminal.

  • Also, in the test code given, repeat_output is being compared with the repeat_interleave result. Is that intentional?
    assert_with_pcc(torch_result_repeat_interleave, repeat_output, 0.9999)

  • I was able to pass the assert only for cases where the repeats are up to 50 and h, w are multiples of 32 (if h, w need to be repeated): assert_with_pcc(torch_result_repeat, repeat_output, 0.9999)
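
A rough sketch of the composite behaviour described in the second bullet above (an illustration only, not the actual tt-metal source; the helper name is made up):

def naive_repeat(x, repeat_shape):
    # For every dim with n > 1, call concat once with n copies of the current
    # tensor; for (1, 1, 1, 2048) that is a single concat with 2048 inputs,
    # which is what blows up the runtime-arg count.
    out = x
    for dim, n in enumerate(repeat_shape):
        if n > 1:
            out = ttnn.concat([out] * n, dim=dim)
    return out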

@jliangTT jliangTT removed their assignment Mar 15, 2024
@kpaigwar
Contributor Author

Answering @tt-aho's comment:
"If you are targeting performance, are you planning to run this model/op sharded? Or exactly as specified in this issue? The optimizations/support for sharding are different than interleaved so want to make sure we're targeting/optimizing the right thing."
-> Yes, we are planning to run this op in a sharded way. The unit test is just an example to show the workaround we have.

@kpaigwar
Contributor Author

kpaigwar commented Mar 15, 2024

@KalaivaniMCW
assert_with_pcc(torch_result_repeat_interleave, repeat_output, 0.9999)
The repeat_interleave was just to show that both have the same output shape and that it's possible to compute it without any memory errors.

@kevinwuTT

@kpaigwar Hello, just to be sure, did you mean to tag @KalaivaniMCW?

@kpaigwar
Contributor Author

Summarizing the requirements:

  1. An op equivalent to torch.repeat.
  2. Should support both L1 sharded and interleaved. If the sharded implementation would take more time, an interleaved implementation is okay.
  3. Should be more performant than the current workaround of using a big transformation matrix.

@tt-aho
Contributor

tt-aho commented Mar 15, 2024

The restriction on shapes being tile-sized is because the concat op does not natively support RM layout; it pads/formats to tile layout, runs concat, and unpads, which means it can't concat on dim 2/3 if there is padding on the same dim (there is no restriction on concatenating other dims). The restriction on the number of inputs/args is because we have a limit on the number of runtime args (I think 256?), which results in an assert if exceeded.
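
An illustration (plain PyTorch, with a hypothetical padding helper) of why concatenating on a padded dim is a problem: the real data and the tile padding end up interleaved in the output, so it cannot be unpadded afterwards.

import torch

def pad_to_tile_width(x, tile=32):
    # pad the last dim up to the next multiple of the tile width
    pad = (-x.shape[-1]) % tile
    return torch.nn.functional.pad(x, (0, pad))

x = torch.randn(1, 1, 32, 16)
x_tiled = pad_to_tile_width(x)               # last dim 16 -> 32 (16 data + 16 pad)
cat = torch.cat([x_tiled, x_tiled], dim=-1)  # (1, 1, 32, 64): data, pad, data, pad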

However, I am not sure what would be causing the hang/segmentation fault. I would need to look into this.

As for support for this op, there are 2 steps I see for supporting the interleaved case:

  1. Update the concat op to natively support RM, to allow for the 16-element width. This is okay since the required width has an aligned size. Most of the current concat logic could be reused for the RM case with some minor adjustments.
  2. Update repeat to be a variant of concat instead of wrapping concat with X copies of the tensor, so that only the args for 1 tensor need to be passed to the kernel. Repeating on multiple dims would just be a loop for now, so a repeat on X dims would be X op calls internally (fine for this case, where the repeat is on one dim).

This would add basic interleaved support for RM concat and a more optimal repeat (there are better optimizations that could be done for repeat, since it can have more efficient read/write logic than concat).
An alternative, potentially lower-perf solution, but probably the easiest way to get functionality, is to make the composite repeat break up the concats in a smart way, e.g. a powers-of-2 decomposition, instead of one big concat (see the sketch below).
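
A minimal sketch of that powers-of-2 decomposition (an illustration only, not the proposed implementation; it assumes ttnn.concat accepts a list of tensors and a dim argument, as used elsewhere in this thread):

def repeat_via_doubling(x, repeats, dim):
    # Build x repeated `repeats` times along `dim` using its binary expansion:
    # O(log repeats) concat calls, each with at most 2 inputs, so the
    # per-kernel runtime-arg count stays small.
    result = None
    block = x  # holds x repeated 2**k times after k doublings
    while repeats:
        if repeats & 1:
            result = block if result is None else ttnn.concat([result, block], dim=dim)
        repeats >>= 1
        if repeats:
            block = ttnn.concat([block, block], dim=dim)
    return result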

However, since the target is sharding, I don't think it's useful to look into optimizing the interleaved version beyond basic support for debugging, unless it's needed elsewhere.

@kpaigwar do you have the specs for sharding for this op? Also, I just want to check whether this issue/request is a workaround for a temporary broadcast (similar to #5769) or an actually needed standalone op.

@kpaigwar
Contributor Author

@tt-aho both issues are similar, but they differ in how the broadcasting happens. In this particular issue, the last dim of the (1, 1, 32, 16) tensor is 16 and is being broadcast, whereas in issue #5769 the weights' batch dim of 1 is being broadcast to the input batch size of 32. In other words, this repeat is many-to-many and #5769 is one-to-many.
We don't have shard specs right now as we are operating in L1 interleaved.
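
A small torch illustration of that distinction (the #5769 shapes here are only indicative, based on the description above):

x = torch.randn(1, 1, 32, 16)
many_to_many = x.repeat(1, 1, 1, 2048)  # this issue: every value replicated -> (1, 1, 32, 32768)

w = torch.randn(1, 1, 32, 16)           # indicative weight shape with a singleton batch dim
one_to_many = w.expand(32, -1, -1, -1)  # #5769-style: batch dim 1 broadcast to 32 -> (32, 1, 32, 16)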

@tt-aho
Contributor

tt-aho commented Mar 15, 2024

Okay, I think this makes sense. We can try to move forward with my suggestions above on implementing interleaved repeat to unblock functionality/basic perf. Once you have the shard specs, we can review the requirements for sharded repeat; otherwise, we can look into more perf optimizations for interleaved repeat.

@ntarafdar
Contributor

Thanks @jliangTT

tt-aho added commits that referenced this issue Mar 18, 2024
@kpaigwar
Contributor Author

kpaigwar commented Mar 18, 2024

Update: We found that the workaround matmul op for torch.repeat is the bottleneck for our performance, taking 0.39 ms (12% of device time). See row 177 in the attached profiling sheet. We are hoping this issue can be prioritized.
cc @aguhaTT
ops_perf_results_mamba_2024_03_18_16_20_50.xlsx

@tt-aho
Contributor

tt-aho commented Mar 18, 2024

You can try the branch aho/repeat, which implements repeat on device and is in the process of being merged to main. It is tested for the shape you have in this issue, but it has not been optimized: since you mentioned you will be sharding the model, optimizing the interleaved version currently doesn't seem like a priority.

@kpaigwar
Contributor Author

Thanks, we will try that and update on the sharding performance

@ntarafdar
Contributor

ntarafdar commented Mar 18, 2024

@kpaigwar if @tt-aho's PR addresses the runtime memory error, can we close this issue?
We should use a different issue to track performance.

tt-aho added commits that referenced this issue Mar 18, 2024
@tt-aho
Contributor

tt-aho commented Mar 18, 2024

This is now merged to main. Please confirm whether it addresses the issue, and follow @tarafdarTT's suggestion of opening a new issue once you have the perf requirements/specifications.

@kpaigwar
Contributor Author

@tt-aho I tested this on main and I am getting shape mismatch errors on my example as well as on test_repeat.py (modified to reproduce the error).

Test

import ttnn
import torch
from tests.ttnn.utils_for_testing import assert_with_pcc


device = ttnn.open_device(device_id=0)
torch_input_tensor = torch.randn((1, 1, 32, 32), dtype=torch.bfloat16)
repeat_shape = torch.randn((1, 1, 1, 2), dtype=torch.bfloat16)  # used only for its shape

input_tensor1 = ttnn.from_torch(repeat_shape, layout=ttnn.TILE_LAYOUT)
input_tensor1 = ttnn.to_device(input_tensor1, device)
torch_result = torch_input_tensor.repeat(repeat_shape.shape)

input_tensor = ttnn.from_torch(torch_input_tensor, layout=ttnn.TILE_LAYOUT, device=device)

output = ttnn.repeat(input_tensor, input_tensor1.shape)
output = ttnn.to_torch(output)

assert_with_pcc(torch_result, output, 0.9999)

ttnn.close_device(device)

Error

File "/proj_sw/user_dev/kpaigwar/tt-metal/ttnn/ttnn/operations/core.py", line 227, in reshape                                                                                               
  return ttnn._ttnn.operations.core.reshape(input_tensor, shape)                                                                                                                            
RuntimeError: TT_ASSERT @ tt_eager/tensor/tensor.cpp:240: this->volume() == tt::tt_metal::compute_volume(new_shape)                                                                           
info:                                                                                                                                                                                         
{} != {}                                                                                                                                                                                      
2048
1024

@tt-aho
Contributor

tt-aho commented Mar 19, 2024

I see. I'm not entirely familiar with why ttnn does additional reshaping under the hood and need to look into it, but my test with the tt_lib version passes for your shape case, so it might just be that something in ttnn needs updating.

@tt-aho
Contributor

tt-aho commented Mar 19, 2024

I have a fix for the ttnn version in PR #6526. Note that for your specific case with input shape [1, 1, 32, 16], you should create your input tensor in RM (row-major) layout and not TILE, or else your output will have padding interleaved inside the tensor.
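
A minimal sketch of that suggestion (same call pattern as the repro scripts earlier in this thread; ttnn.ROW_MAJOR_LAYOUT is the row-major layout constant, and whether this exact case passes is discussed in the next comment):

torch_input_tensor = torch.randn((1, 1, 32, 16), dtype=torch.bfloat16)
input_tensor = ttnn.from_torch(
    torch_input_tensor,
    layout=ttnn.ROW_MAJOR_LAYOUT,  # row-major instead of ttnn.TILE_LAYOUT for the 16-wide input
    device=device,
)

shape_tensor = ttnn.from_torch(
    torch.zeros((1, 1, 1, 2048), dtype=torch.bfloat16),
    layout=ttnn.ROW_MAJOR_LAYOUT,
    device=device,
)
output = ttnn.repeat(input_tensor, shape_tensor.shape)  # expected shape (1, 1, 32, 32768)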

@kpaigwar
Contributor Author

@tt-aho I tested with your fix; it worked for cases where the tensor doesn't have any padding, but it failed on the following tensors:

torch_input_tensor = torch.randn((1, 1, 32, 16), dtype=torch.bfloat16)
repeat_shape = torch.randn((1, 1, 1, 2048), dtype=torch.bfloat16)

@kpaigwar
Contributor Author

I will close this issue and address the padding issue on your PR #6526.
