Commit: upd upd upd prefetch prefetch doesn't work upd upd upd refactor add test for head_dim 64 & 256 upd upd upd upd upd upd upd wip upd upd

Showing 52 changed files with 5,730 additions and 132 deletions. Three of the new files are reproduced below: AOT build scripts that each generate a CUDA translation unit instantiating the SM90 prefill kernels for one combination of head dimension, mask mode, and data types.

New file (96 additions): generator for batch paged prefill SM90 kernel instantiations

""" | ||
Copyright (c) 2024 by FlashInfer team. | ||
Licensed under the Apache License, Version 2.0 (the "License"); | ||
you may not use this file except in compliance with the License. | ||
You may obtain a copy of the License at | ||
http://www.apache.org/licenses/LICENSE-2.0 | ||
Unless required by applicable law or agreed to in writing, software | ||
distributed under the License is distributed on an "AS IS" BASIS, | ||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
See the License for the specific language governing permissions and | ||
limitations under the License. | ||
""" | ||
|
||
import re | ||
import sys | ||
from pathlib import Path | ||
|
||
from .literal_map import ( | ||
dtype_literal, | ||
idtype_literal, | ||
mask_mode_literal, | ||
pos_encoding_mode_literal, | ||
) | ||
|
||
|
||
def get_cu_file_str( | ||
head_dim, | ||
pos_encoding_mode, | ||
allow_fp16_qk_reduction, | ||
mask_mode, | ||
dtype_q, | ||
dtype_kv, | ||
dtype_out, | ||
idtype, | ||
): | ||
def get_insts(attention_variant): | ||
return "\n".join( | ||
[ | ||
"""template cudaError_t BatchPrefillWithPagedKVCacheDispatched<{head_dim}, {mask_mode}, /*USE_SWA=*/true, {attention_variant}>( | ||
Params& params, | ||
cudaStream_t stream); | ||
template cudaError_t BatchPrefillWithPagedKVCacheDispatched<{head_dim}, {mask_mode}, /*USE_SWA=*/false, {attention_variant}>( | ||
Params& params, | ||
cudaStream_t stream); | ||
""".format( | ||
head_dim=head_dim, | ||
pos_encoding_mode=pos_encoding_mode_literal[int(pos_encoding_mode)], | ||
allow_fp16_qk_reduction=allow_fp16_qk_reduction, | ||
mask_mode=mask_mode_literal[int(mask_mode)], | ||
attention_variant=attention_variant, | ||
) | ||
] | ||
) | ||
|
||
dtype_q = dtype_literal[dtype_q] | ||
dtype_kv = dtype_literal[dtype_kv] | ||
dtype_out = dtype_literal[dtype_out] | ||
idtype = idtype_literal[idtype] | ||
|
||
content = f"""#include <flashinfer/attention/hopper/prefill_sm90.cuh> | ||
#include <flashinfer/attention/hopper/variants.cuh> | ||
#include <flashinfer/cutlass_utils.cuh> | ||
namespace flashinfer {{ | ||
using DTypeQ = cutlass_dtype_t<{dtype_q}>; | ||
using DTypeKV = cutlass_dtype_t<{dtype_kv}>; | ||
using DTypeO = cutlass_dtype_t<{dtype_out}>; | ||
using Params = BatchPrefillPagedParams<DTypeQ, DTypeKV, DTypeO, {idtype}>; | ||
{get_insts("LogitsSoftCap")} | ||
{get_insts("StandardAttention")} | ||
}}""" | ||
return content | ||
|
||
|
||
if __name__ == "__main__": | ||
pattern = ( | ||
r"batch_paged_prefill_head_([0-9]+)_posenc_([0-9]+)_" | ||
r"fp16qkred_([a-z]+)_mask_([0-9]+)_dtypeq_([a-z0-9]+)_dtypekv_([a-z0-9]+)_dtypeout_([a-z0-9]+)_idtype_([a-z0-9]+)_sm90\.cu" | ||
) | ||
compiled_pattern = re.compile(pattern) | ||
path = Path(sys.argv[1]) | ||
fname = path.name | ||
match = compiled_pattern.match(fname) | ||
|
||
with open(path, "w") as f: | ||
f.write(get_cu_file_str(*match.groups())) |
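
For reference, a minimal standalone sketch of the filename decoding step. The sample filename here is hypothetical, but it follows the pattern the script expects; the regex captures become, in order, the positional arguments of get_cu_file_str.

import re

pattern = re.compile(
    r"batch_paged_prefill_head_([0-9]+)_posenc_([0-9]+)_"
    r"fp16qkred_([a-z]+)_mask_([0-9]+)_dtypeq_([a-z0-9]+)_dtypekv_([a-z0-9]+)"
    r"_dtypeout_([a-z0-9]+)_idtype_([a-z0-9]+)_sm90\.cu"
)
# Hypothetical name: head_dim 128, posenc 0, fp16 qk reduction off,
# mask mode 0, f16 q/kv/out dtypes, i32 index type.
fname = (
    "batch_paged_prefill_head_128_posenc_0_fp16qkred_false_mask_0_"
    "dtypeq_f16_dtypekv_f16_dtypeout_f16_idtype_i32_sm90.cu"
)
match = pattern.match(fname)
print(match.groups())
# -> ('128', '0', 'false', '0', 'f16', 'f16', 'f16', 'i32')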

aot_build_utils/generate_batch_ragged_prefill_sm90_inst.py (new file, 97 additions)

""" | ||
Copyright (c) 2024 by FlashInfer team. | ||
Licensed under the Apache License, Version 2.0 (the "License"); | ||
you may not use this file except in compliance with the License. | ||
You may obtain a copy of the License at | ||
http://www.apache.org/licenses/LICENSE-2.0 | ||
Unless required by applicable law or agreed to in writing, software | ||
distributed under the License is distributed on an "AS IS" BASIS, | ||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
See the License for the specific language governing permissions and | ||
limitations under the License. | ||
""" | ||
|
||
import re | ||
import sys | ||
from pathlib import Path | ||
|
||
from .literal_map import ( | ||
dtype_literal, | ||
idtype_literal, | ||
mask_mode_literal, | ||
pos_encoding_mode_literal, | ||
) | ||
|
||
|
||
def get_cu_file_str( | ||
head_dim, | ||
pos_encoding_mode, | ||
allow_fp16_qk_reduction, | ||
mask_mode, | ||
dtype_q, | ||
dtype_kv, | ||
dtype_out, | ||
idtype, | ||
): | ||
|
||
def get_insts(attention_variant): | ||
return "\n".join( | ||
[ | ||
"""template cudaError_t BatchPrefillWithRaggedKVCacheDispatched<{head_dim}, {mask_mode}, /*USE_SWA=*/true, {attention_variant}>( | ||
Params& params, | ||
cudaStream_t stream); | ||
template cudaError_t BatchPrefillWithRaggedKVCacheDispatched<{head_dim}, {mask_mode}, /*USE_SWA=*/false, {attention_variant}>( | ||
Params& params, | ||
cudaStream_t stream); | ||
""".format( | ||
head_dim=head_dim, | ||
pos_encoding_mode=pos_encoding_mode_literal[int(pos_encoding_mode)], | ||
allow_fp16_qk_reduction=allow_fp16_qk_reduction, | ||
mask_mode=mask_mode_literal[int(mask_mode)], | ||
attention_variant=attention_variant, | ||
) | ||
] | ||
) | ||
|
||
dtype_q = dtype_literal[dtype_q] | ||
dtype_kv = dtype_literal[dtype_kv] | ||
dtype_out = dtype_literal[dtype_out] | ||
idtype = idtype_literal[idtype] | ||
|
||
content = f"""#include <flashinfer/attention/hopper/prefill_sm90.cuh> | ||
#include <flashinfer/attention/hopper/variants.cuh> | ||
#include <flashinfer/cutlass_utils.cuh> | ||
namespace flashinfer {{ | ||
using DTypeQ = cutlass_dtype_t<{dtype_q}>; | ||
using DTypeKV = cutlass_dtype_t<{dtype_kv}>; | ||
using DTypeO = cutlass_dtype_t<{dtype_out}>; | ||
using Params = BatchPrefillRaggedParams<DTypeQ, DTypeKV, DTypeO, {idtype}>; | ||
{get_insts("LogitsSoftCap")} | ||
{get_insts("StandardAttention")} | ||
}} | ||
""" | ||
return content | ||
|
||
|
||
if __name__ == "__main__": | ||
pattern = ( | ||
r"batch_ragged_prefill_head_([0-9]+)_posenc_([0-9]+)_" | ||
r"fp16qkred_([a-z]+)_mask_([0-9]+)_dtypeq_([a-z0-9]+)_dtypekv_([a-z0-9]+)_dtypeout_([a-z0-9]+)_idtype_([a-z0-9]+)_sm90\.cu" | ||
) | ||
compiled_pattern = re.compile(pattern) | ||
path = Path(sys.argv[1]) | ||
fname = path.name | ||
match = compiled_pattern.match(fname) | ||
with open(path, "w") as f: | ||
f.write(get_cu_file_str(*match.groups())) |
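
Not part of this commit, but a sketch of one plausible way to invoke the generator: because the script resolves .literal_map with a relative import, it has to run as a module with the aot_build_utils package importable. Whether the build system invokes it exactly this way is an assumption; the output filename below is hypothetical.

import subprocess
import sys

# The output file's name must match the regex in the script above.
out = (
    "batch_ragged_prefill_head_128_posenc_0_fp16qkred_false_mask_0_"
    "dtypeq_f16_dtypekv_f16_dtypeout_f16_idtype_i32_sm90.cu"
)
# Run from the repository root so the aot_build_utils package resolves.
subprocess.check_call(
    [
        sys.executable,
        "-m",
        "aot_build_utils.generate_batch_ragged_prefill_sm90_inst",
        out,
    ]
)
# The script decodes the name and writes the generated .cu unit to `out`.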

New file (85 additions): generator for single prefill SM90 kernel instantiations

""" | ||
Copyright (c) 2024 by FlashInfer team. | ||
Licensed under the Apache License, Version 2.0 (the "License"); | ||
you may not use this file except in compliance with the License. | ||
You may obtain a copy of the License at | ||
http://www.apache.org/licenses/LICENSE-2.0 | ||
Unless required by applicable law or agreed to in writing, software | ||
distributed under the License is distributed on an "AS IS" BASIS, | ||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
See the License for the specific language governing permissions and | ||
limitations under the License. | ||
""" | ||
|
||
import re | ||
import sys | ||
from pathlib import Path | ||
|
||
from .literal_map import dtype_literal, mask_mode_literal, pos_encoding_mode_literal | ||
|
||
|
||
def get_cu_file_str( | ||
head_dim, | ||
pos_encoding_mode, | ||
allow_fp16_qk_reduction, | ||
mask_mode, | ||
dtype_q, | ||
dtype_kv, | ||
dtype_out, | ||
): | ||
content = """#include <flashinfer/attention/hopper/prefill_sm90.cuh> | ||
#include <flashinfer/attention/hopper/variants.cuh> | ||
#include <flashinfer/cutlass_utils.cuh> | ||
namespace flashinfer {{ | ||
using DTypeQ = cutlass_dtype_t<{dtype_q}>; | ||
using DTypeKV = cutlass_dtype_t<{dtype_kv}>; | ||
using DTypeO = cutlass_dtype_t<{dtype_out}>; | ||
using Params = SinglePrefillParams<DTypeQ, DTypeKV, DTypeO>; | ||
template cudaError_t SinglePrefillWithKVCacheDispatched<{head_dim}, {mask_mode}, /*USE_SWA=*/true, LogitsSoftCap>( | ||
Params& params, | ||
cudaStream_t stream); | ||
template cudaError_t SinglePrefillWithKVCacheDispatched<{head_dim}, {mask_mode}, /*USE_SWA=*/false, LogitsSoftCap>( | ||
Params& params, | ||
cudaStream_t stream); | ||
template cudaError_t SinglePrefillWithKVCacheDispatched<{head_dim}, {mask_mode}, /*USE_SWA=*/true, StandardAttention>( | ||
Params& params, | ||
cudaStream_t stream); | ||
template cudaError_t SinglePrefillWithKVCacheDispatched<{head_dim}, {mask_mode}, /*USE_SWA=*/false, StandardAttention>( | ||
Params& params, | ||
cudaStream_t stream); | ||
}} | ||
""".format( | ||
head_dim=head_dim, | ||
pos_encoding_mode=pos_encoding_mode_literal[int(pos_encoding_mode)], | ||
allow_fp16_qk_reduction=allow_fp16_qk_reduction, | ||
mask_mode=mask_mode_literal[int(mask_mode)], | ||
dtype_q=dtype_literal[dtype_q], | ||
dtype_kv=dtype_literal[dtype_kv], | ||
dtype_out=dtype_literal[dtype_out], | ||
use_custom_mask="true" if int(mask_mode) == 2 else "false", | ||
) | ||
return content | ||
|
||
|
||
if __name__ == "__main__": | ||
pattern = ( | ||
r"single_prefill_head_([0-9]+)_posenc_([0-9]+)_" | ||
r"fp16qkred_([a-z]+)_mask_([0-9]+)_dtypeq_([a-z0-9]+)_dtypekv_([a-z0-9]+)_dtypeout_([a-z0-9]+)_sm90\.cu" | ||
) | ||
|
||
compiled_pattern = re.compile(pattern) | ||
path = Path(sys.argv[1]) | ||
fname = path.name | ||
match = compiled_pattern.match(fname) | ||
with open(path, "w") as f: | ||
f.write(get_cu_file_str(*match.groups())) |
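
One caveat common to all three generators: they call match.groups() without checking that the filename actually matched, so an unexpected name fails with an opaque AttributeError. A defensive variant of the final lines (an addition for illustration, not in the commit) would fail with a clear message instead:

    match = compiled_pattern.match(fname)
    if match is None:
        raise ValueError(f"unrecognized instantiation filename: {fname}")

    with open(path, "w") as f:
        f.write(get_cu_file_str(*match.groups()))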