
Runtime stitching / runtime weights #671

Closed
AleksKnezevic opened this issue Sep 12, 2024 · 2 comments · Fixed by #1301
AleksKnezevic commented Sep 12, 2024

Proposal for Front-End (FE) Interaction with the tt-mlir Runtime for On-Device Tensors

In this document, I will refer to the tt-mlir runtime as the “runtime” and the Forge and PJRT runtimes as “third-party runtimes” (TPRT).

Key Use Cases:

  • Multi-batch inference: avoid re-writing weights/constants to the device on every iteration.
  • Runtime stitching: the output of one model feeds as the input to another (including past cache) without routing IOs back through the host.
  • Training loops: weights remain on the device and can be updated across training iterations.

The following, I believe, addresses all these use cases:

  • TPRT should push parameters once, and they should remain live on the device.
  • TPRT pushes activations for each iteration.
  • TPRT can create two input tensors for double-buffered graph execution if needed.
  • TPRT is responsible for deallocating all tensors.
  • The runtime leaves outputs on the device until explicitly copied to the host by TPRT.
  • Outputs can be stored in either L1 or DRAM.
  • The compiler may not always know the layout of an on-device tensor (e.g., the FE may not recompile the same graph when executed twice, and activations could be in DRAM on the first execution and L1 cache on the second).
  • If there is a layout mismatch, the runtime converts the tensor to the required layout.
  • Tensors in L1 that are not needed by the current program should be moved to DRAM.

To accomplish this, I propose the following API:


Tensor toDevice(Tensor, Device, Layout)

  • Copies a tensor to the device with the specified layout.
  • Returns a handle to the on-device tensor.

Tensor toDevice(Tensor, Device)

  • Copies a tensor to the device, interleaved into DRAM.
  • Returns a handle to the on-device tensor.

Tensor toHost(Tensor)

  • Waits for tensor operations to complete and copies a tensor to the host.
  • The tensor remains allocated in device memory.

void wait(Tensor)

  • Blocks until all outstanding operations on the tensor have completed.

Layout getLayout(Binary, ProgramIndex, InputIndex)

  • Returns the layout of the input at the given index, as defined in the binary.


std::vector<Tensor> submit(Device, Binary, programIndex, inputTensors)

  • Executes the binary on the device.
  • Asserts that all input tensors are on the device.
  • Non-blocking: returns immediately; the caller can barrier if desired.
  • Calls toLayout on any input tensors that are not in the correct layout.
  • Moves any tensors in L1 that are not used by this program to DRAM.
  • Returns a list of on-device output tensors.

Tensor toLayout(Device, Tensor, Layout)

  • Returns tensor of required layout

void Deallocate(Tensor, Device)

  • Deallocates tensor on device
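
To make the flow concrete, here is a rough sketch of how a TPRT might drive this API for the multi-batch and runtime-stitching cases. The Tensor/Device/Binary/Layout types, the input ordering (activation at index 0, weights after it), and the surrounding helper structure are illustrative assumptions, not part of the proposal:

#include <vector>

// Push parameters once; they stay live on the device across iterations.
std::vector<Tensor> pushWeights(Device device, Binary binary,
                                const std::vector<Tensor> &hostWeights) {
  std::vector<Tensor> deviceWeights;
  for (size_t i = 0; i < hostWeights.size(); ++i) {
    Layout layout = getLayout(binary, /*programIndex=*/0, /*inputIndex=*/i + 1);
    deviceWeights.push_back(toDevice(hostWeights[i], device, layout));
  }
  return deviceWeights;
}

void runBatches(Device device, Binary modelA, Binary modelB,
                const std::vector<Tensor> &weightsA,
                const std::vector<Tensor> &weightsB,
                const std::vector<Tensor> &hostActivations) {
  for (const Tensor &hostAct : hostActivations) {
    // Activations are pushed per iteration; DRAM-interleaved is fine, since
    // submit() calls toLayout on anything not in the expected layout.
    Tensor act = toDevice(hostAct, device);

    std::vector<Tensor> inputsA = {act};
    inputsA.insert(inputsA.end(), weightsA.begin(), weightsA.end());
    std::vector<Tensor> outsA = submit(device, modelA, /*programIndex=*/0, inputsA);

    // Runtime stitching: model A's on-device output feeds model B directly,
    // without a round trip through the host.
    std::vector<Tensor> inputsB = {outsA[0]};
    inputsB.insert(inputsB.end(), weightsB.begin(), weightsB.end());
    std::vector<Tensor> outsB = submit(device, modelB, /*programIndex=*/0, inputsB);

    // toHost barriers on the tensor and copies it back; the device copy stays
    // allocated until the TPRT deallocates it.
    Tensor result = toHost(outsB[0]);
    (void)result;

    // The TPRT owns deallocation of everything it no longer needs.
    Deallocate(act, device);
    for (const Tensor &t : outsA) Deallocate(t, device);
    for (const Tensor &t : outsB) Deallocate(t, device);
  }
}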
jnie-TT commented Sep 12, 2024

Hey @AleksKnezevic, this looks great! I had a couple of minor comments.

void toLayout(Device, Tensor, Layout)

Converts tensor layout

I don't think we can convert the layout of a Tensor in place; TTNN ops always allocate and return a new tensor. Maybe we could update this to:

Tensor toLayout(Tensor, Layout)

That would solely convert the layout of the tensor. If we ever want to convert the layout and move tensors to the device or across devices, we can call the toDevice API.
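
For example (the names here are purely illustrative):

// toLayout allocates and returns a new tensor in the requested layout;
// the input tensor is left untouched (no in-place conversion).
Tensor converted = toLayout(hostTensor, requiredLayout);

// Converting the layout while also moving to (or across) devices would
// instead go through the toDevice API.
Tensor onDevice = toDevice(hostTensor, device, requiredLayout);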


void Deallocate(Tensor, Device)

Deallocates tensor on device

We should probably add a force flag that signals to ttnn whether or not to force-deallocate a tensor. By default, it would deallocate only when the reference count reaches 0.
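
Something along these lines (the flag name and default are just a suggestion):

// With force = false (the default), TTNN frees the buffer only once the
// tensor's reference count reaches 0; force = true deallocates immediately.
void Deallocate(Tensor tensor, Device device, bool force = false);

// Usage:
Deallocate(activation, device);                   // respects the reference count
Deallocate(staleOutput, device, /*force=*/true);  // free now, unconditionally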


Layout getLayout(Device, Binary, ProgramIndex, InputIndex)

  • Returns the layout of the input at the given index, as defined in the binary.

We probably don't need the Device for this API.

@AleksKnezevic
Thanks @jnie-TT, I modified the API above. As for deallocate, if the user (through TTRT) is deallocating a tensor, then it should be fine to force-deallocate in TTNN.
