Skip to content

Commit

Permalink
udp
Browse files Browse the repository at this point in the history
upd

upd

upd

prefetch

prefetch doesn't work

upd

upd

upd

refactor

add test for head_dim 64 & 256

upd

upd

upd

upd

upd

upd

upd

wip

upd

upd
  • Loading branch information
yzh119 committed Dec 16, 2024
1 parent d9d8eb1 commit 04ee9bc
Show file tree
Hide file tree
Showing 52 changed files with 5,730 additions and 132 deletions.
22 changes: 22 additions & 0 deletions LICENSE
Original file line number Diff line number Diff line change
Expand Up @@ -199,3 +199,25 @@
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.

-------------------------------------------------------------------------------------------------
Some of the code in this project are adapted from other open-source projects with different
licenses. This product also bundles some third-party components under other open source licenses.
This section summarizes those components and their licenses.
See licenses/ for text of these licenses.

BSD 3-Clause License
--------------------

include/flashinfer/attention/hopper/epilogue.cuh
include/flashinfer/attention/hopper/mainloop.cuh
include/flashinfer/attention/hopper/kernel_traits.cuh
include/flashinfer/attention/hopper/named_barrier.cuh
include/flashinfer/attention/hopper/tile_scheduler.cuh
include/flashinfer/attention/hopper/utils.cuh

BSD 3-Clause "New" License
--------------------------

3rdparty/cutlass
include/flashinfer/attention/hopper/block_sparse_gather.cuh
96 changes: 96 additions & 0 deletions aot_build_utils/generate_batch_paged_prefill_sm90_inst.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,96 @@
"""
Copyright (c) 2024 by FlashInfer team.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
"""

import re
import sys
from pathlib import Path

from .literal_map import (
dtype_literal,
idtype_literal,
mask_mode_literal,
pos_encoding_mode_literal,
)


def get_cu_file_str(
head_dim,
pos_encoding_mode,
allow_fp16_qk_reduction,
mask_mode,
dtype_q,
dtype_kv,
dtype_out,
idtype,
):
def get_insts(attention_variant):
return "\n".join(
[
"""template cudaError_t BatchPrefillWithPagedKVCacheDispatched<{head_dim}, {mask_mode}, /*USE_SWA=*/true, {attention_variant}>(
Params& params,
cudaStream_t stream);
template cudaError_t BatchPrefillWithPagedKVCacheDispatched<{head_dim}, {mask_mode}, /*USE_SWA=*/false, {attention_variant}>(
Params& params,
cudaStream_t stream);
""".format(
head_dim=head_dim,
pos_encoding_mode=pos_encoding_mode_literal[int(pos_encoding_mode)],
allow_fp16_qk_reduction=allow_fp16_qk_reduction,
mask_mode=mask_mode_literal[int(mask_mode)],
attention_variant=attention_variant,
)
]
)

dtype_q = dtype_literal[dtype_q]
dtype_kv = dtype_literal[dtype_kv]
dtype_out = dtype_literal[dtype_out]
idtype = idtype_literal[idtype]

content = f"""#include <flashinfer/attention/hopper/prefill_sm90.cuh>
#include <flashinfer/attention/hopper/variants.cuh>
#include <flashinfer/cutlass_utils.cuh>
namespace flashinfer {{
using DTypeQ = cutlass_dtype_t<{dtype_q}>;
using DTypeKV = cutlass_dtype_t<{dtype_kv}>;
using DTypeO = cutlass_dtype_t<{dtype_out}>;
using Params = BatchPrefillPagedParams<DTypeQ, DTypeKV, DTypeO, {idtype}>;
{get_insts("LogitsSoftCap")}
{get_insts("StandardAttention")}
}}"""
return content


if __name__ == "__main__":
pattern = (
r"batch_paged_prefill_head_([0-9]+)_posenc_([0-9]+)_"
r"fp16qkred_([a-z]+)_mask_([0-9]+)_dtypeq_([a-z0-9]+)_dtypekv_([a-z0-9]+)_dtypeout_([a-z0-9]+)_idtype_([a-z0-9]+)_sm90\.cu"
)
compiled_pattern = re.compile(pattern)
path = Path(sys.argv[1])
fname = path.name
match = compiled_pattern.match(fname)

with open(path, "w") as f:
f.write(get_cu_file_str(*match.groups()))
97 changes: 97 additions & 0 deletions aot_build_utils/generate_batch_ragged_prefill_sm90_inst.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,97 @@
"""
Copyright (c) 2024 by FlashInfer team.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
"""

import re
import sys
from pathlib import Path

from .literal_map import (
dtype_literal,
idtype_literal,
mask_mode_literal,
pos_encoding_mode_literal,
)


def get_cu_file_str(
head_dim,
pos_encoding_mode,
allow_fp16_qk_reduction,
mask_mode,
dtype_q,
dtype_kv,
dtype_out,
idtype,
):

def get_insts(attention_variant):
return "\n".join(
[
"""template cudaError_t BatchPrefillWithRaggedKVCacheDispatched<{head_dim}, {mask_mode}, /*USE_SWA=*/true, {attention_variant}>(
Params& params,
cudaStream_t stream);
template cudaError_t BatchPrefillWithRaggedKVCacheDispatched<{head_dim}, {mask_mode}, /*USE_SWA=*/false, {attention_variant}>(
Params& params,
cudaStream_t stream);
""".format(
head_dim=head_dim,
pos_encoding_mode=pos_encoding_mode_literal[int(pos_encoding_mode)],
allow_fp16_qk_reduction=allow_fp16_qk_reduction,
mask_mode=mask_mode_literal[int(mask_mode)],
attention_variant=attention_variant,
)
]
)

dtype_q = dtype_literal[dtype_q]
dtype_kv = dtype_literal[dtype_kv]
dtype_out = dtype_literal[dtype_out]
idtype = idtype_literal[idtype]

content = f"""#include <flashinfer/attention/hopper/prefill_sm90.cuh>
#include <flashinfer/attention/hopper/variants.cuh>
#include <flashinfer/cutlass_utils.cuh>
namespace flashinfer {{
using DTypeQ = cutlass_dtype_t<{dtype_q}>;
using DTypeKV = cutlass_dtype_t<{dtype_kv}>;
using DTypeO = cutlass_dtype_t<{dtype_out}>;
using Params = BatchPrefillRaggedParams<DTypeQ, DTypeKV, DTypeO, {idtype}>;
{get_insts("LogitsSoftCap")}
{get_insts("StandardAttention")}
}}
"""
return content


if __name__ == "__main__":
pattern = (
r"batch_ragged_prefill_head_([0-9]+)_posenc_([0-9]+)_"
r"fp16qkred_([a-z]+)_mask_([0-9]+)_dtypeq_([a-z0-9]+)_dtypekv_([a-z0-9]+)_dtypeout_([a-z0-9]+)_idtype_([a-z0-9]+)_sm90\.cu"
)
compiled_pattern = re.compile(pattern)
path = Path(sys.argv[1])
fname = path.name
match = compiled_pattern.match(fname)
with open(path, "w") as f:
f.write(get_cu_file_str(*match.groups()))
85 changes: 85 additions & 0 deletions aot_build_utils/generate_single_prefill_sm90_inst.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,85 @@
"""
Copyright (c) 2024 by FlashInfer team.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
"""

import re
import sys
from pathlib import Path

from .literal_map import dtype_literal, mask_mode_literal, pos_encoding_mode_literal


def get_cu_file_str(
head_dim,
pos_encoding_mode,
allow_fp16_qk_reduction,
mask_mode,
dtype_q,
dtype_kv,
dtype_out,
):
content = """#include <flashinfer/attention/hopper/prefill_sm90.cuh>
#include <flashinfer/attention/hopper/variants.cuh>
#include <flashinfer/cutlass_utils.cuh>
namespace flashinfer {{
using DTypeQ = cutlass_dtype_t<{dtype_q}>;
using DTypeKV = cutlass_dtype_t<{dtype_kv}>;
using DTypeO = cutlass_dtype_t<{dtype_out}>;
using Params = SinglePrefillParams<DTypeQ, DTypeKV, DTypeO>;
template cudaError_t SinglePrefillWithKVCacheDispatched<{head_dim}, {mask_mode}, /*USE_SWA=*/true, LogitsSoftCap>(
Params& params,
cudaStream_t stream);
template cudaError_t SinglePrefillWithKVCacheDispatched<{head_dim}, {mask_mode}, /*USE_SWA=*/false, LogitsSoftCap>(
Params& params,
cudaStream_t stream);
template cudaError_t SinglePrefillWithKVCacheDispatched<{head_dim}, {mask_mode}, /*USE_SWA=*/true, StandardAttention>(
Params& params,
cudaStream_t stream);
template cudaError_t SinglePrefillWithKVCacheDispatched<{head_dim}, {mask_mode}, /*USE_SWA=*/false, StandardAttention>(
Params& params,
cudaStream_t stream);
}}
""".format(
head_dim=head_dim,
pos_encoding_mode=pos_encoding_mode_literal[int(pos_encoding_mode)],
allow_fp16_qk_reduction=allow_fp16_qk_reduction,
mask_mode=mask_mode_literal[int(mask_mode)],
dtype_q=dtype_literal[dtype_q],
dtype_kv=dtype_literal[dtype_kv],
dtype_out=dtype_literal[dtype_out],
use_custom_mask="true" if int(mask_mode) == 2 else "false",
)
return content


if __name__ == "__main__":
pattern = (
r"single_prefill_head_([0-9]+)_posenc_([0-9]+)_"
r"fp16qkred_([a-z]+)_mask_([0-9]+)_dtypeq_([a-z0-9]+)_dtypekv_([a-z0-9]+)_dtypeout_([a-z0-9]+)_sm90\.cu"
)

compiled_pattern = re.compile(pattern)
path = Path(sys.argv[1])
fname = path.name
match = compiled_pattern.match(fname)
with open(path, "w") as f:
f.write(get_cu_file_str(*match.groups()))
Loading

0 comments on commit 04ee9bc

Please sign in to comment.