Diffusion Transformer models tend to process vision and text streams independently until attention, where joint attention between vision and text tokens is performed. SD3.5 and Mochi are two models which have this exact pattern.
The simple implementation of concat -> SDPA -> slice can be inefficient for a few reasons:
- the long sequence length of the vision inputs leads to a slow concat and slice
- if the vision and/or text tokens are padded in the sequence dimension, then either:
  a. you must do a padding-aware concatenate, which is slow or leads to OOM, or
  b. you must concatenate the padded tensors and create a massive attention mask to mask out the padding tokens (a sketch of this case follows the list)
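For concreteness, here is a minimal PyTorch sketch of case (b), the pattern this issue wants to avoid. The helper build_padding_mask, the shapes, and the valid lengths are illustrative assumptions of mine, not an existing API; the point is that the mask is materialized at the full joint sequence length.

import torch
import torch.nn.functional as F

def build_padding_mask(N, N_valid, L, L_valid, device):
    # True where a key token is real, False where it is padding.
    vision_valid = torch.arange(N, device=device) < N_valid   # [N]
    text_valid = torch.arange(L, device=device) < L_valid     # [L]
    key_valid = torch.cat([vision_valid, text_valid])         # [N + L]
    # Broadcast to a full (N + L) x (N + L) additive mask.
    mask = torch.zeros(N + L, N + L, device=device)
    mask.masked_fill_(~key_valid[None, :], float("-inf"))
    return mask

# Illustrative shapes only (b, nh, seq, dh).
b, nh, dh = 1, 8, 64
N, N_valid = 1024, 1000   # padded / actual vision tokens
L, L_valid = 128, 100     # padded / actual text tokens

q, k, v = (torch.randn(b, nh, N, dh) for _ in range(3))
joint_q, joint_k, joint_v = (torch.randn(b, nh, L, dh) for _ in range(3))

query = torch.cat([q, joint_q], dim=2)
key = torch.cat([k, joint_k], dim=2)
value = torch.cat([v, joint_v], dim=2)

# The mask is (N + L) x (N + L); at Mochi scale this is ~44k x 44k.
mask = build_padding_mask(N, N_valid, L, L_valid, q.device)
out = F.scaled_dot_product_attention(query, key, value, attn_mask=mask)

output, joint_output = out[..., :N, :], out[..., N:, :]

At Mochi scale the joint sequence is roughly 44k tokens, so the materialized mask alone is on the order of two billion elements.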
This issue aims to resolve all of these problems by creating an efficient joint_attention operation which has the following interface:
def joint_attention(
    q,
    k,
    v,
    joint_q,
    joint_k,
    joint_v,
    is_causal=False,        # always False
    joint_strategy="rear",  # always "rear"
    scale=None,
    program_config=None,
    compute_kernel_config=None,
):
    """
    A joint_attention operation which takes two sets of q, k, v and performs
    the concatenated non-causal attention over them, returning sliced outputs.

    q, k, v must be in DRAM. Output is in DRAM.
        q, k, v:          [b, nh, N, dh]
        joint_(q, k, v):  [b, nh, L, dh]

    Efficiently implements the following reference:
    """
    assert not is_causal
    assert joint_strategy == "rear"
    query = ttnn.cat([q, joint_q], dim=2)
    key = ttnn.cat([k, joint_k], dim=2)
    value = ttnn.cat([v, joint_v], dim=2)
    out = scaled_dot_product_attention(
        query,
        key,
        value,
        is_causal=is_causal,
        scale=scale,
        program_config=program_config,
    )
    orig_seq = q.shape[2]
    output = out[..., :orig_seq, :]
    joint_output = out[..., orig_seq:, :]
    return output, joint_output
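For illustration, a call site using the interface above might look like the following. The shapes and the program_config / compute_kernel_config values are placeholders of my own, not taken from a real model config.

# hypothetical shapes: b=1, nh=24, N=4096 vision tokens, L=333 text tokens, dh=64
output, joint_output = joint_attention(
    q, k, v,                       # each [1, 24, 4096, 64], in DRAM
    joint_q=joint_q,               # [1, 24, 333, 64], in DRAM
    joint_k=joint_k,
    joint_v=joint_v,
    joint_strategy="rear",         # text tokens appended after vision tokens
    program_config=sdpa_program_config,           # placeholder
    compute_kernel_config=compute_kernel_config,  # placeholder
)
# output:       [1, 24, 4096, 64]  (vision stream)
# joint_output: [1, 24, 333, 64]   (text stream)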
OP requirements
- must be aware of padding on the vision and text tokens so that it can internally apply a non-causal attention mask which masks out the padded tokens, without the caller materializing one
- must allow arbitrary q_chunk and k_chunk regardless of the input sequence length. Since we're dealing with poorly shaped tensors, this is also a good opportunity to remove the requirement that q_chunk and k_chunk evenly divide the input sequence length. (A sketch of deriving the internal padding mask from valid lengths follows this list.)
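As a reference for the padding-awareness requirement, here is a minimal PyTorch sketch of how the op could derive the mask internally from the valid (unpadded) lengths alone, so the caller never builds one. N_valid and L_valid are assumed parameter names, not a proposal for the final interface.

import torch

def joint_attention_reference(q, k, v, joint_q, joint_k, joint_v,
                              N_valid, L_valid, scale=None):
    # q/k/v: [b, nh, N, dh] padded to N; joint_*: [b, nh, L, dh] padded to L.
    N, L = q.shape[2], joint_q.shape[2]
    query = torch.cat([q, joint_q], dim=2)
    key = torch.cat([k, joint_k], dim=2)
    value = torch.cat([v, joint_v], dim=2)
    # Derive key validity from the valid lengths only -- no user-provided mask.
    idx = torch.arange(N + L, device=q.device)
    key_valid = (idx < N_valid) | ((idx >= N) & (idx < N + L_valid))  # [N + L]
    scores = (query @ key.transpose(-2, -1)) * (scale or q.shape[-1] ** -0.5)
    scores = scores.masked_fill(~key_valid[None, None, None, :], float("-inf"))
    out = scores.softmax(dim=-1) @ value
    return out[..., :N, :], out[..., N:, :]

An on-device implementation would apply the same key-validity test per q_chunk x k_chunk tile, which also covers the final partial chunk when the sequence length is not a multiple of the chunk size.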
Future features
- an op like this is a good basis for future ring attention on joint tensors
Impact
In the Mochi model, padded vision and text tokens currently force me to create a 44k x 44k attention mask. Pushing this mask to device accounts for 66% of end-to-end latency for Mochi video generation. This op removes the need for an attention mask entirely, so eliminating that 66% should yield roughly a 3x (1 / 0.34) end-to-end speedup.
Plan