Efficient Joint Attention op #16557

Open · 2 of 3 tasks
cglagovichTT opened this issue Jan 9, 2025 · 1 comment

Comments

cglagovichTT (Contributor) commented Jan 9, 2025

Diffusion Transformer models tend to process vision and text streams independently until the attention layers, where joint attention is performed over the concatenated vision and text tokens. SD3.5 and Mochi are two models with exactly this pattern.

The simple implementation of concat -> SDPA -> slice can be inefficient for a few reasons:

  1. the long sequence length of the vision inputs makes the concat and slice slow
  2. if the vision and/or text tokens are padded along the sequence dimension, then either:
    a. you must do a padding-aware concatenate, which is slow and can lead to OOM, or
    b. you must concat the padded tensors and create a massive attention mask to mask out the padding tokens (sketched below)

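For context, option 2b looks roughly like the following in plain PyTorch. This is a minimal sketch: the helper name and the vision_valid / text_valid arguments are illustrative (not part of any existing API), and a single valid length per stream is assumed across the batch for brevity.

import torch
import torch.nn.functional as F

def naive_padded_joint_attention(q, k, v, joint_q, joint_k, joint_v, vision_valid, text_valid):
    # q, k, v:         [b, nh, N_padded, dh]  (vision stream, padded)
    # joint_(q, k, v): [b, nh, L_padded, dh]  (text stream, padded)
    query = torch.cat([q, joint_q], dim=2)
    key = torch.cat([k, joint_k], dim=2)
    value = torch.cat([v, joint_v], dim=2)

    N, L = q.shape[2], joint_q.shape[2]
    S = N + L
    # Mark which key positions hold real tokens vs. padding.
    keep = torch.zeros(S, dtype=torch.bool, device=q.device)
    keep[:vision_valid] = True
    keep[N : N + text_valid] = True
    # Materialize an [S, S] additive mask: 0 for real keys, -inf for padding.
    # For Mochi, S is roughly 44k, so this is the 44k x 44k tensor mentioned in the Impact section.
    attn_mask = torch.zeros(S, S, dtype=q.dtype, device=q.device)
    attn_mask[:, ~keep] = float("-inf")

    out = F.scaled_dot_product_attention(query, key, value, attn_mask=attn_mask)
    return out[..., :N, :], out[..., N:, :]
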
This issue aims to resolve all of these problems by creating an efficient joint_attention operation which has the following interface:

def joint_attention(
    q,
    k,
    v,
    joint_q,
    joint_k,
    joint_v,
    is_causal=False,  # always False
    joint_strategy="rear",  # always rear: joint tokens are appended after the main sequence
    scale=None,
    program_config=None,
    compute_kernel_config=None,
):
    """
    A joint_attention operation which takes two sets of q, k, v
    and performs the concatenated non-causal attention over them,
    returning sliced outputs.
    q, k, v must be in DRAM. Output is in DRAM.

    q, k, v:         [b, nh, N, dh]
    joint_(q, k, v): [b, nh, L, dh]

    Efficiently implements the following reference:
    """
    assert not is_causal
    assert joint_strategy == "rear"

    # Concatenate the two streams along the sequence dimension.
    query = ttnn.concat([q, joint_q], dim=2)
    key = ttnn.concat([k, joint_k], dim=2)
    value = ttnn.concat([v, joint_v], dim=2)

    out = scaled_dot_product_attention(
        query,
        key,
        value,
        is_causal=is_causal,
        scale=scale,
        program_config=program_config,
        compute_kernel_config=compute_kernel_config,
    )

    # Slice the joint output back into the original and joint sequences.
    orig_seq = q.shape[2]
    output = out[..., :orig_seq, :]
    joint_output = out[..., orig_seq:, :]
    return output, joint_output

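For illustration, a call site in a DiT attention block could then look like this once the op exists. Tensor names and the explicit scale are assumptions, not an existing API.

# Hypothetical call site; replaces the explicit concat -> SDPA -> slice above.
# q/k/v are the vision stream [b, nh, N, dh]; joint_q/k/v are the text stream [b, nh, L, dh].
vision_out, text_out = joint_attention(
    q,
    k,
    v,
    joint_q,
    joint_k,
    joint_v,
    scale=head_dim**-0.5,  # explicit 1/sqrt(dh) scaling; head_dim is assumed to be defined by the caller
)
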
OP requirements

  • must be aware of padding on the vision and text tokens so that it can internally create a non-causal attention mask which masks out the padded tokens (a conceptual sketch follows this list)
  • must allow arbitrary q_chunk, k_chunk sizes regardless of the input sequence length. Since we're dealing with poorly shaped tensors, this is also a good opportunity to remove the requirement that q_chunk and k_chunk evenly divide the input sequence length

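To make the first requirement concrete, here is a conceptual sketch in plain PyTorch (not kernel code) of how a padding mask for one k_chunk could be generated on the fly from hypothetical vision_valid / text_valid token counts, so that no full [S, S] mask is ever materialized or pushed to device:

import torch

def k_chunk_padding_mask(k_start, k_chunk, N_padded, vision_valid, text_valid, dtype=torch.float32):
    # Column indices of the concatenated [vision | text] sequence covered by this k-chunk.
    cols = torch.arange(k_start, k_start + k_chunk)
    # A column holds a real token if it lies in the valid prefix of either stream.
    is_vision = cols < vision_valid
    is_text = (cols >= N_padded) & (cols < N_padded + text_valid)
    keep = is_vision | is_text
    # Additive mask applied to every q row of this chunk pair: 0 for real keys, -inf for padding.
    mask = torch.zeros(k_chunk, dtype=dtype)
    mask[~keep] = float("-inf")
    return mask

In the actual kernel the same predicate could be evaluated per tile, so even this small per-chunk tensor is only for illustration.
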
Future features

  • an op like this is a good basis for future ring attention on joint tensors

Impact

In the Mochi model, padded vision and text tokens force me to create a 44k x 44k attention mask. Pushing this mask to device accounts for 66% of end-to-end latency for Mochi video generation. This op removes the need for an attention mask, so it should yield roughly a 3x end-to-end speedup (1 / (1 - 0.66) ≈ 3x).

Plan

cglagovichTT (Contributor, Author) commented:

FYI @uaydonat
