diff --git a/LICENSE b/LICENSE
index 261eeb9e..7c8f7e14 100644
--- a/LICENSE
+++ b/LICENSE
@@ -199,3 +199,25 @@
 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 See the License for the specific language governing permissions and
 limitations under the License.
+
+-------------------------------------------------------------------------------------------------
+Some of the code in this project is adapted from other open-source projects with different
+licenses. This product also bundles some third-party components under other open source licenses.
+This section summarizes those components and their licenses.
+See licenses/ for the text of these licenses.
+
+BSD 3-Clause License
+--------------------
+
+include/flashinfer/attention/hopper/epilogue.cuh
+include/flashinfer/attention/hopper/mainloop.cuh
+include/flashinfer/attention/hopper/kernel_traits.cuh
+include/flashinfer/attention/hopper/named_barrier.cuh
+include/flashinfer/attention/hopper/tile_scheduler.cuh
+include/flashinfer/attention/hopper/utils.cuh
+
+BSD 3-Clause "New" License
+--------------------------
+
+3rdparty/cutlass
+include/flashinfer/attention/hopper/block_sparse_gather.cuh
diff --git a/aot_build_utils/generate_batch_paged_prefill_sm90_inst.py b/aot_build_utils/generate_batch_paged_prefill_sm90_inst.py
new file mode 100644
index 00000000..80310a74
--- /dev/null
+++ b/aot_build_utils/generate_batch_paged_prefill_sm90_inst.py
@@ -0,0 +1,96 @@
+"""
+Copyright (c) 2024 by FlashInfer team.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+""" + +import re +import sys +from pathlib import Path + +from .literal_map import ( + dtype_literal, + idtype_literal, + mask_mode_literal, + pos_encoding_mode_literal, +) + + +def get_cu_file_str( + head_dim, + pos_encoding_mode, + allow_fp16_qk_reduction, + mask_mode, + dtype_q, + dtype_kv, + dtype_out, + idtype, +): + def get_insts(attention_variant): + return "\n".join( + [ + """template cudaError_t BatchPrefillWithPagedKVCacheDispatched<{head_dim}, {mask_mode}, /*USE_SWA=*/true, {attention_variant}>( + Params& params, + cudaStream_t stream); + +template cudaError_t BatchPrefillWithPagedKVCacheDispatched<{head_dim}, {mask_mode}, /*USE_SWA=*/false, {attention_variant}>( + Params& params, + cudaStream_t stream); + """.format( + head_dim=head_dim, + pos_encoding_mode=pos_encoding_mode_literal[int(pos_encoding_mode)], + allow_fp16_qk_reduction=allow_fp16_qk_reduction, + mask_mode=mask_mode_literal[int(mask_mode)], + attention_variant=attention_variant, + ) + ] + ) + + dtype_q = dtype_literal[dtype_q] + dtype_kv = dtype_literal[dtype_kv] + dtype_out = dtype_literal[dtype_out] + idtype = idtype_literal[idtype] + + content = f"""#include +#include +#include + + +namespace flashinfer {{ + +using DTypeQ = cutlass_dtype_t<{dtype_q}>; +using DTypeKV = cutlass_dtype_t<{dtype_kv}>; +using DTypeO = cutlass_dtype_t<{dtype_out}>; + +using Params = BatchPrefillPagedParams; + +{get_insts("LogitsSoftCap")} + +{get_insts("StandardAttention")} + +}}""" + return content + + +if __name__ == "__main__": + pattern = ( + r"batch_paged_prefill_head_([0-9]+)_posenc_([0-9]+)_" + r"fp16qkred_([a-z]+)_mask_([0-9]+)_dtypeq_([a-z0-9]+)_dtypekv_([a-z0-9]+)_dtypeout_([a-z0-9]+)_idtype_([a-z0-9]+)_sm90\.cu" + ) + compiled_pattern = re.compile(pattern) + path = Path(sys.argv[1]) + fname = path.name + match = compiled_pattern.match(fname) + + with open(path, "w") as f: + f.write(get_cu_file_str(*match.groups())) diff --git a/aot_build_utils/generate_batch_ragged_prefill_sm90_inst.py b/aot_build_utils/generate_batch_ragged_prefill_sm90_inst.py new file mode 100644 index 00000000..e26a7389 --- /dev/null +++ b/aot_build_utils/generate_batch_ragged_prefill_sm90_inst.py @@ -0,0 +1,97 @@ +""" +Copyright (c) 2024 by FlashInfer team. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+""" + +import re +import sys +from pathlib import Path + +from .literal_map import ( + dtype_literal, + idtype_literal, + mask_mode_literal, + pos_encoding_mode_literal, +) + + +def get_cu_file_str( + head_dim, + pos_encoding_mode, + allow_fp16_qk_reduction, + mask_mode, + dtype_q, + dtype_kv, + dtype_out, + idtype, +): + + def get_insts(attention_variant): + return "\n".join( + [ + """template cudaError_t BatchPrefillWithRaggedKVCacheDispatched<{head_dim}, {mask_mode}, /*USE_SWA=*/true, {attention_variant}>( + Params& params, + cudaStream_t stream); + +template cudaError_t BatchPrefillWithRaggedKVCacheDispatched<{head_dim}, {mask_mode}, /*USE_SWA=*/false, {attention_variant}>( + Params& params, + cudaStream_t stream); + """.format( + head_dim=head_dim, + pos_encoding_mode=pos_encoding_mode_literal[int(pos_encoding_mode)], + allow_fp16_qk_reduction=allow_fp16_qk_reduction, + mask_mode=mask_mode_literal[int(mask_mode)], + attention_variant=attention_variant, + ) + ] + ) + + dtype_q = dtype_literal[dtype_q] + dtype_kv = dtype_literal[dtype_kv] + dtype_out = dtype_literal[dtype_out] + idtype = idtype_literal[idtype] + + content = f"""#include +#include +#include + + +namespace flashinfer {{ + +using DTypeQ = cutlass_dtype_t<{dtype_q}>; +using DTypeKV = cutlass_dtype_t<{dtype_kv}>; +using DTypeO = cutlass_dtype_t<{dtype_out}>; + +using Params = BatchPrefillRaggedParams; + +{get_insts("LogitsSoftCap")} + +{get_insts("StandardAttention")} + +}} + """ + return content + + +if __name__ == "__main__": + pattern = ( + r"batch_ragged_prefill_head_([0-9]+)_posenc_([0-9]+)_" + r"fp16qkred_([a-z]+)_mask_([0-9]+)_dtypeq_([a-z0-9]+)_dtypekv_([a-z0-9]+)_dtypeout_([a-z0-9]+)_idtype_([a-z0-9]+)_sm90\.cu" + ) + compiled_pattern = re.compile(pattern) + path = Path(sys.argv[1]) + fname = path.name + match = compiled_pattern.match(fname) + with open(path, "w") as f: + f.write(get_cu_file_str(*match.groups())) diff --git a/aot_build_utils/generate_single_prefill_sm90_inst.py b/aot_build_utils/generate_single_prefill_sm90_inst.py new file mode 100644 index 00000000..13e57999 --- /dev/null +++ b/aot_build_utils/generate_single_prefill_sm90_inst.py @@ -0,0 +1,85 @@ +""" +Copyright (c) 2024 by FlashInfer team. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+""" + +import re +import sys +from pathlib import Path + +from .literal_map import dtype_literal, mask_mode_literal, pos_encoding_mode_literal + + +def get_cu_file_str( + head_dim, + pos_encoding_mode, + allow_fp16_qk_reduction, + mask_mode, + dtype_q, + dtype_kv, + dtype_out, +): + content = """#include +#include +#include + +namespace flashinfer {{ + +using DTypeQ = cutlass_dtype_t<{dtype_q}>; +using DTypeKV = cutlass_dtype_t<{dtype_kv}>; +using DTypeO = cutlass_dtype_t<{dtype_out}>; + +using Params = SinglePrefillParams; + +template cudaError_t SinglePrefillWithKVCacheDispatched<{head_dim}, {mask_mode}, /*USE_SWA=*/true, LogitsSoftCap>( + Params& params, + cudaStream_t stream); + +template cudaError_t SinglePrefillWithKVCacheDispatched<{head_dim}, {mask_mode}, /*USE_SWA=*/false, LogitsSoftCap>( + Params& params, + cudaStream_t stream); + +template cudaError_t SinglePrefillWithKVCacheDispatched<{head_dim}, {mask_mode}, /*USE_SWA=*/true, StandardAttention>( + Params& params, + cudaStream_t stream); + +template cudaError_t SinglePrefillWithKVCacheDispatched<{head_dim}, {mask_mode}, /*USE_SWA=*/false, StandardAttention>( + Params& params, + cudaStream_t stream); +}} + """.format( + head_dim=head_dim, + pos_encoding_mode=pos_encoding_mode_literal[int(pos_encoding_mode)], + allow_fp16_qk_reduction=allow_fp16_qk_reduction, + mask_mode=mask_mode_literal[int(mask_mode)], + dtype_q=dtype_literal[dtype_q], + dtype_kv=dtype_literal[dtype_kv], + dtype_out=dtype_literal[dtype_out], + use_custom_mask="true" if int(mask_mode) == 2 else "false", + ) + return content + + +if __name__ == "__main__": + pattern = ( + r"single_prefill_head_([0-9]+)_posenc_([0-9]+)_" + r"fp16qkred_([a-z]+)_mask_([0-9]+)_dtypeq_([a-z0-9]+)_dtypekv_([a-z0-9]+)_dtypeout_([a-z0-9]+)_sm90\.cu" + ) + + compiled_pattern = re.compile(pattern) + path = Path(sys.argv[1]) + fname = path.name + match = compiled_pattern.match(fname) + with open(path, "w") as f: + f.write(get_cu_file_str(*match.groups())) diff --git a/aot_build_utils/generate_sm90.py b/aot_build_utils/generate_sm90.py new file mode 100644 index 00000000..f87f34e5 --- /dev/null +++ b/aot_build_utils/generate_sm90.py @@ -0,0 +1,200 @@ +""" +Copyright (c) 2024 by FlashInfer team. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +""" + +import argparse +from itertools import product +from pathlib import Path +from typing import List + +from . 
import ( + generate_batch_paged_prefill_sm90_inst, + generate_batch_ragged_prefill_sm90_inst, + generate_single_prefill_sm90_inst, +) + + +def get_sm90_instantiation_cu(args: argparse.Namespace) -> List[str]: + def write_if_different(path: Path, content: str) -> None: + if path.exists() and path.read_text() == content: + return + path.write_text(content) + + path: Path = args.path + head_dims: List[int] = args.head_dims + pos_encoding_modes: List[int] = args.pos_encoding_modes + allow_fp16_qk_reductions: List[int] = args.allow_fp16_qk_reductions + mask_modes: List[int] = args.mask_modes + enable_bf16: bool = args.enable_bf16 + + path.mkdir(parents=True, exist_ok=True) + + idtypes = ["i32"] + prefill_dtypes = ["f16"] + decode_dtypes = ["f16"] + fp16_dtypes = ["f16"] + if enable_bf16: + prefill_dtypes.append("bf16") + decode_dtypes.append("bf16") + fp16_dtypes.append("bf16") + + # single prefill files + single_prefill_sm90_uris = [] + for ( + head_dim, + pos_encoding_mode, + allow_fp16_qk_reduction, + mask_mode, + ) in product( + head_dims, + pos_encoding_modes, + allow_fp16_qk_reductions, + mask_modes, + ): + for dtype_q, dtype_kv in list(zip(prefill_dtypes, prefill_dtypes)): + fname = f"single_prefill_head_{head_dim}_posenc_{pos_encoding_mode}_fp16qkred_{allow_fp16_qk_reduction}_mask_{mask_mode}_dtypeq_{dtype_q}_dtypekv_{dtype_kv}_dtypeout_{dtype_q}_sm90.cu" + content = generate_single_prefill_sm90_inst.get_cu_file_str( + head_dim, + pos_encoding_mode, + allow_fp16_qk_reduction, + mask_mode, + dtype_q, # dtype_q + dtype_kv, # dtype_kv + dtype_q, # dtype_out + ) + for use_sliding_window in [True, False]: + for use_logits_soft_cap in [True, False]: + if ( + mask_mode == 0 + ): # NOTE(Zihao): uri do not contain mask, avoid duplicate uris + single_prefill_sm90_uris.append( + f"single_prefill_with_kv_cache_dtype_q_{dtype_q}_" + f"dtype_kv_{dtype_kv}_" + f"dtype_o_{dtype_q}_" + f"head_dim_{head_dim}_" + f"posenc_{pos_encoding_mode}_" + f"use_swa_{use_sliding_window}_" + f"use_logits_cap_{use_logits_soft_cap}_" + f"f16qk_{bool(allow_fp16_qk_reduction)}_sm90" + ) + write_if_different(path / fname, content) + + # batch prefill files + batch_prefill_sm90_uris = [] + for ( + head_dim, + pos_encoding_mode, + allow_fp16_qk_reduction, + mask_mode, + idtype, + ) in product( + head_dims, + pos_encoding_modes, + allow_fp16_qk_reductions, + mask_modes, + idtypes, + ): + for dtype_q, dtype_kv in list(zip(prefill_dtypes, prefill_dtypes)): + fname = f"batch_paged_prefill_head_{head_dim}_posenc_{pos_encoding_mode}_fp16qkred_{allow_fp16_qk_reduction}_mask_{mask_mode}_dtypeq_{dtype_q}_dtypekv_{dtype_kv}_dtypeout_{dtype_q}_idtype_{idtype}_sm90.cu" + content = generate_batch_paged_prefill_sm90_inst.get_cu_file_str( + head_dim, + pos_encoding_mode, + allow_fp16_qk_reduction, + mask_mode, + dtype_q, # dtype_q + dtype_kv, # dtype_kv + dtype_q, # dtype_out + idtype, + ) + write_if_different(path / fname, content) + + fname = f"batch_ragged_prefill_head_{head_dim}_posenc_{pos_encoding_mode}_fp16qkred_{allow_fp16_qk_reduction}_mask_{mask_mode}_dtypeq_{dtype_q}_dtypekv_{dtype_kv}_dtypeout_{dtype_q}_idtype_{idtype}_sm90.cu" + content = generate_batch_ragged_prefill_sm90_inst.get_cu_file_str( + head_dim, + pos_encoding_mode, + allow_fp16_qk_reduction, + mask_mode, + dtype_q, # dtype_q + dtype_kv, # dtype_kv + dtype_q, # dtype_out + idtype, + ) + write_if_different(path / fname, content) + + for sliding_window in [True, False]: + for logits_soft_cap in [True, False]: + if ( + mask_mode == 0 + ): # NOTE(Zihao): uri do not 
contain mask, avoid duplicate uris + batch_prefill_sm90_uris.append( + f"batch_prefill_with_kv_cache_dtype_q_{dtype_q}_" + f"dtype_kv_{dtype_kv}_" + f"dtype_o_{dtype_q}_" + f"dtype_idx_{idtype}_" + f"head_dim_{head_dim}_" + f"posenc_{pos_encoding_mode}_" + f"use_swa_{sliding_window}_" + f"use_logits_cap_{logits_soft_cap}_" + f"f16qk_{bool(allow_fp16_qk_reduction)}_sm90" + ) + + return single_prefill_sm90_uris + batch_prefill_sm90_uris + + +if __name__ == "__main__": + parser = argparse.ArgumentParser("Generate cuda files") + parser.add_argument( + "--path", type=Path, required=True, help="Path to the dispatch inc file" + ) + parser.add_argument( + "--head_dims", type=int, required=True, nargs="+", help="Head dimensions" + ) + parser.add_argument( + "--pos_encoding_modes", + type=int, + required=True, + nargs="+", + help="Position encoding modes", + ) + parser.add_argument( + "--allow_fp16_qk_reductions", + type=lambda x: x if isinstance(x, int) else int(x.lower() == "true"), + required=True, + nargs="+", + help="Allow fp16 qk reductions", + ) + parser.add_argument( + "--mask_modes", + type=int, + required=True, + nargs="+", + help="Mask modes", + ) + parser.add_argument( + "--enable_bf16", + type=lambda x: x if isinstance(x, int) else x.lower() == "true", + required=True, + nargs="+", + help="Enable bf16", + ) + parser.add_argument( + "--enable_fp8", + type=lambda x: x if isinstance(x, int) else x.lower() == "true", + default=True, + nargs="+", + help="Enable fp8", + ) + args = parser.parse_args() + get_sm90_instantiation_cu(args) diff --git a/benchmarks/bench_hopper_attention.py b/benchmarks/bench_hopper_attention.py new file mode 100644 index 00000000..f5bcc19e --- /dev/null +++ b/benchmarks/bench_hopper_attention.py @@ -0,0 +1,201 @@ +""" +Copyright (c) 2024 by FlashInfer team. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+""" + +import torch +import triton + +import flashinfer + + +def bench_single_prefill(seq_len, num_heads, causal, head_dim): + num_qo_heads = num_kv_heads = num_heads + q = torch.randn(seq_len, num_qo_heads, head_dim, dtype=torch.half, device="cuda") + k = torch.randn(seq_len, num_kv_heads, head_dim, dtype=torch.half, device="cuda") + v = torch.randn(seq_len, num_kv_heads, head_dim, dtype=torch.half, device="cuda") + + sm80_ms, sm90_ms = ( + triton.testing.do_bench( + lambda: flashinfer.single_prefill_with_kv_cache_return_lse( + q, k, v, causal=causal, backend=backend + ), + warmup=100, + rep=1000, + ) + for backend in ["fa2", "fa3"] + ) + + def flops(ms): + if causal: + return seq_len * seq_len * num_qo_heads * head_dim * 2 / ms / 1e9 + else: + return seq_len * seq_len * num_qo_heads * head_dim * 4 / ms / 1e9 + + print( + f"bench_single_prefill (seq_len={seq_len}, num_heads={num_heads}, causal={causal}, head_dim={head_dim}), fa2-template: {flops(sm80_ms):.3f} TFLOPs/s, fa3-template: {flops(sm90_ms):.3f} TFLOPs/s" + ) + + +def bench_batch_ragged_prefill(batch_size, num_heads, seq_len, causal, head_dim): + num_qo_heads = num_kv_heads = num_heads + q = torch.randn( + batch_size * seq_len, num_qo_heads, head_dim, dtype=torch.half, device="cuda" + ) + k = torch.randn( + batch_size * seq_len, num_kv_heads, head_dim, dtype=torch.half, device="cuda" + ) + v = torch.randn( + batch_size * seq_len, num_kv_heads, head_dim, dtype=torch.half, device="cuda" + ) + + sm80_wrapper, sm90_wrapper = ( + flashinfer.BatchPrefillWithRaggedKVCacheWrapper( + torch.empty(256 * 1024 * 1024, dtype=torch.uint8, device="cuda:0"), + kv_layout="NHD", + backend=backend, + ) + for backend in ["fa2", "fa3"] + ) + + qo_indptr = torch.arange(0, batch_size * seq_len + 1, seq_len).int() + kv_indptr = torch.arange(0, batch_size * seq_len + 1, seq_len).int() + + for wrapper in [sm80_wrapper, sm90_wrapper]: + wrapper.plan( + qo_indptr, + kv_indptr, + num_qo_heads, + num_kv_heads, + head_dim, + causal=causal, + ) + + sm80_ms, sm90_ms = ( + triton.testing.do_bench( + lambda: wrapper.run(q, k, v), + warmup=100, + rep=1000, + ) + for wrapper in [sm80_wrapper, sm90_wrapper] + ) + + def flops(ms): + if causal: + return ( + batch_size * seq_len * seq_len * num_qo_heads * head_dim * 2 / ms / 1e9 + ) + else: + return ( + batch_size * seq_len * seq_len * num_qo_heads * head_dim * 4 / ms / 1e9 + ) + + print( + f"bench_batch_ragged_prefill (batch_size={batch_size}, num_heads={num_heads}, seq_len={seq_len}, causal={causal}, head_dim={head_dim}), fa2-template: {flops(sm80_ms):.3f} TFLOPs/s, fa3-template: {flops(sm90_ms):.3f} TFLOPs/s" + ) + + +def bench_batch_paged_prefill( + page_size, batch_size, num_heads, seq_len, causal, head_dim +): + num_qo_heads = num_kv_heads = num_heads + q = torch.randn( + batch_size * seq_len, num_qo_heads, head_dim, dtype=torch.half, device="cuda" + ) + k = torch.randn( + batch_size * seq_len // page_size, + page_size, + num_kv_heads, + head_dim, + dtype=torch.half, + device="cuda", + ) + v = torch.randn( + batch_size * seq_len // page_size, + page_size, + num_kv_heads, + head_dim, + dtype=torch.half, + device="cuda", + ) + + sm80_wrapper, sm90_wrapper = ( + flashinfer.BatchPrefillWithPagedKVCacheWrapper( + torch.empty(256 * 1024 * 1024, dtype=torch.uint8, device="cuda:0"), + kv_layout="NHD", + backend=backend, + ) + for backend in ["fa2", "fa3"] + ) + + qo_indptr = torch.arange(0, batch_size * seq_len + 1, seq_len).int() + kv_indptr = torch.arange( + 0, batch_size * (seq_len // page_size) + 1, (seq_len // 
page_size) + ).int() + kv_indices = torch.arange(0, batch_size * (seq_len // page_size)).int() + last_page_len = torch.ones(batch_size, dtype=torch.int32) * page_size + + for wrapper in [sm80_wrapper, sm90_wrapper]: + wrapper.plan( + qo_indptr, + kv_indptr, + kv_indices, + last_page_len, + num_qo_heads, + num_kv_heads, + head_dim, + page_size, # page_size + causal=causal, + ) + + sm80_ms, sm90_ms = ( + triton.testing.do_bench( + lambda: wrapper.run(q, (k, v)), + warmup=100, + rep=1000, + ) + for wrapper in [sm80_wrapper, sm90_wrapper] + ) + + def flops(ms): + if causal: + return ( + batch_size * seq_len * seq_len * num_qo_heads * head_dim * 2 / ms / 1e9 + ) + else: + return ( + batch_size * seq_len * seq_len * num_qo_heads * head_dim * 4 / ms / 1e9 + ) + + print( + f"bench_batch_paged_prefill (page_size={page_size} batch_size={batch_size}, num_heads={num_heads}, seq_len={seq_len}, causal={causal}, head_dim={head_dim}), fa2-template: {flops(sm80_ms):.3f} TFLOPs/s, fa3-template: {flops(sm90_ms):.3f} TFLOPs/s" + ) + + +if __name__ == "__main__": + bench_batch_paged_prefill(1, 128, 32, 1024, True, 128) + bench_batch_paged_prefill(1, 64, 32, 2048, True, 128) + bench_batch_paged_prefill(1, 32, 32, 4096, True, 128) + bench_batch_paged_prefill(1, 16, 32, 8192, True, 128) + bench_batch_paged_prefill(1, 1, 32, 32768, True, 128) + bench_batch_paged_prefill(16, 128, 32, 1024, True, 128) + bench_batch_paged_prefill(16, 64, 32, 2048, True, 128) + bench_batch_paged_prefill(16, 32, 32, 4096, True, 128) + bench_batch_paged_prefill(16, 16, 32, 8192, True, 128) + bench_batch_paged_prefill(16, 1, 32, 32768, True, 128) + bench_batch_ragged_prefill(128, 32, 1024, True, 128) + bench_batch_ragged_prefill(64, 32, 2048, True, 128) + bench_batch_ragged_prefill(32, 32, 4096, True, 128) + bench_batch_ragged_prefill(16, 32, 8192, True, 128) + bench_batch_ragged_prefill(1, 32, 32768, True, 128) diff --git a/csrc/aot_extension_utils.h b/csrc/aot_extension_utils.h index 76db0168..b701c289 100644 --- a/csrc/aot_extension_utils.h +++ b/csrc/aot_extension_utils.h @@ -30,15 +30,15 @@ #define DISPATCH_mask_mode(expr, const_expr, ...) \ _DISPATCH_SWITCH("mask_mode", expr, _DISPATCH_CASES_mask_mode(const_expr, __VA_ARGS__)) -#define DISPATCH_LOGITS_SOFT_CAP(use_logits_soft_cap, USE_LOGITS_SOFT_CAP, ...) \ - [&]() -> bool { \ - if (use_logits_soft_cap) { \ - constexpr bool USE_LOGITS_SOFT_CAP = true; \ - return __VA_ARGS__(); \ - } else { \ - constexpr bool USE_LOGITS_SOFT_CAP = false; \ - return __VA_ARGS__(); \ - } \ +#define DISPATCH_BOOL(expr, const_expr, ...) \ + [&]() -> bool { \ + if (expr) { \ + constexpr bool const_expr = true; \ + return __VA_ARGS__(); \ + } else { \ + constexpr bool const_expr = false; \ + return __VA_ARGS__(); \ + } \ }() #define DISPATCH_PYTORCH_QKV_DTYPE_TO_CTYPE(q_dtype, kv_dtype, c_type_q, c_type_kv, ...) 
\ diff --git a/csrc/batch_decode.cu b/csrc/batch_decode.cu index 19cea2c8..95b79b22 100644 --- a/csrc/batch_decode.cu +++ b/csrc/batch_decode.cu @@ -56,7 +56,7 @@ std::vector BatchDecodeWithPagedKVCachePlan( using DTypeKV = kv_type; using DTypeO = DTypeQ; return DISPATCH_head_dim(head_dim, HEAD_DIM, [&] { - return DISPATCH_LOGITS_SOFT_CAP(use_logits_soft_cap, USE_LOGITS_SOFT_CAP, [&] { + return DISPATCH_BOOL(use_logits_soft_cap, USE_LOGITS_SOFT_CAP, [&] { using ParamsT = BatchDecodeParams; using AttentionVariant = ComposedAttention 0, USE_LOGITS_SOFT_CAP, [&] { + return DISPATCH_BOOL(logits_soft_cap > 0, USE_LOGITS_SOFT_CAP, [&] { using ParamsT = BatchDecodeParams; using AttentionVariant = ComposedAttention; using RaggedAttentionVariant = ComposedAttention paged_kv( num_kv_heads, page_size, HEAD_DIM, batch_size, kv_layout, static_cast(paged_k_cache.data_ptr()), diff --git a/csrc/batch_prefill_sm90.cu b/csrc/batch_prefill_sm90.cu new file mode 100644 index 00000000..0d358e90 --- /dev/null +++ b/csrc/batch_prefill_sm90.cu @@ -0,0 +1,281 @@ +/* + * Copyright (c) 2023 by FlashInfer team. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#include + +#include +#include +#include +#include +#include +#include +#include +#include + +#include "aot_extension_utils.h" + +namespace flashinfer { + +template +cudaError_t BatchPrefillWithRaggedKVCacheDispatched( + BatchPrefillRaggedParams& params, cudaStream_t stream); + +template +cudaError_t BatchPrefillWithPagedKVCacheDispatched( + BatchPrefillPagedParams& params, cudaStream_t stream); + +} // namespace flashinfer + +using namespace flashinfer; + +std::vector BatchPrefillWithKVCacheSM90Plan( + unsigned int head_dim, bool causal, at::Tensor float_workspace_buffer, + at::Tensor int_workspace_buffer, at::Tensor page_locked_int_workspace_buffer, + at::Tensor qo_indptr, at::Tensor kv_indptr, at::Tensor kv_len_arr, unsigned int batch_size, + unsigned int num_qo_heads, unsigned int num_kv_heads, unsigned int page_size, + bool enable_cuda_graph, int64_t cuda_stream) { + size_t float_workspace_size_in_bytes = + float_workspace_buffer.size(0) * float_workspace_buffer.element_size(); + size_t int_workspace_size_in_bytes = + int_workspace_buffer.size(0) * int_workspace_buffer.element_size(); + + PrefillPlanSM90Info plan_info; + + using IdType = int32_t; + + cudaStream_t stream = reinterpret_cast(cuda_stream); + + cudaError_t status = PrefillSM90Plan( + float_workspace_buffer.data_ptr(), float_workspace_size_in_bytes, + int_workspace_buffer.data_ptr(), page_locked_int_workspace_buffer.data_ptr(), + int_workspace_size_in_bytes, plan_info, qo_indptr.data_ptr(), + kv_indptr.data_ptr(), kv_len_arr.data_ptr(), batch_size, num_qo_heads, + num_kv_heads, head_dim, page_size, causal, enable_cuda_graph, /*sizeof_dtype_o=*/2, stream); + + TORCH_CHECK(status == cudaSuccess, + "PrefillSM90Plan failed with error: ", cudaGetErrorString(status)); + + return plan_info.ToVector(); +} + +void BatchPrefillWithRaggedKVCacheSM90Run( + unsigned int mask_mode_code, at::Tensor 
float_workspace_buffer, at::Tensor int_workspace_buffer, + std::vector plan_info_vec, at::Tensor q, at::Tensor k, at::Tensor v, + std::optional maybe_custom_mask, std::optional maybe_alibi_slopes, + at::Tensor qo_indptr, at::Tensor kv_indptr, std::optional maybe_qk_indptr, + at::Tensor o, unsigned int layout, int32_t window_left, float logits_soft_cap, float sm_scale, + float rope_scale, float rope_theta, std::optional maybe_lse, int64_t cuda_stream) { + PrefillPlanSM90Info plan_info; + plan_info.FromVector(plan_info_vec); + + if (maybe_lse) { + const auto& lse = *maybe_lse; + TORCH_CHECK(lse.size(0) == q.size(0), lse.size(0), q.size(0)); + TORCH_CHECK(lse.size(1) == q.size(1), lse.size(1), q.size(1)); + } + + void* float_buffer_ptr = float_workspace_buffer.data_ptr(); + void* int_buffer_ptr = int_workspace_buffer.data_ptr(); + + unsigned int head_dim = q.size(2); + + auto q_scalar_type = q.scalar_type(); + + QKVLayout kv_layout = static_cast(layout); + cudaStream_t stream = reinterpret_cast(cuda_stream); + const MaskMode mask_mode = static_cast(mask_mode_code); + bool use_logits_soft_cap = logits_soft_cap > 0.f; + bool use_swa = window_left != -1; + + DISPATCH_PYTORCH_DTYPE_TO_CTYPE_FP16(q_scalar_type, qkv_type, [&] { + using DTypeQ = cutlass_dtype_t; + using DTypeKV = DTypeQ; + using DTypeO = DTypeQ; + using IdType = int32_t; + + BatchPrefillRaggedParams params; + + params.q_ptr = static_cast(q.data_ptr()); + params.k_ptr = static_cast(k.data_ptr()); + params.v_ptr = static_cast(v.data_ptr()); + params.o_ptr = static_cast(o.data_ptr()); + params.lse_ptr = maybe_lse ? static_cast(maybe_lse->data_ptr()) : nullptr; + params.q_stride_n = q.stride(0); + params.q_stride_h = q.stride(1); + params.o_stride_n = o.stride(0); + params.o_stride_h = o.stride(1); + if (kv_layout == QKVLayout::kNHD) { + params.k_stride_n = k.stride(0); + params.k_stride_h = k.stride(1); + params.v_stride_n = v.stride(0); + params.v_stride_h = v.stride(1); + } else { + params.k_stride_h = k.stride(0); + params.k_stride_n = k.stride(1); + params.v_stride_h = v.stride(0); + params.v_stride_n = v.stride(1); + } + params.nnz_qo = q.size(0); + params.nnz_kv = k.size(0); + params.head_dim = head_dim; + params.num_qo_heads = q.size(1); + params.num_kv_heads = k.size(1); + params.group_size = params.num_qo_heads / params.num_kv_heads; + params.window_left = window_left; + params.logits_soft_cap = logits_soft_cap; + params.sm_scale_log2 = sm_scale * math::log2e; + params.causal = mask_mode_code == 1; + params.qo_tile_indices = + GetPtrFromBaseOffset(int_buffer_ptr, plan_info.qo_tile_indices_offset); + params.qo_indptr = GetPtrFromBaseOffset(int_buffer_ptr, plan_info.qo_indptr_offset); + params.kv_indptr = GetPtrFromBaseOffset(int_buffer_ptr, plan_info.kv_indptr_offset); + params.qo_lens = GetPtrFromBaseOffset(int_buffer_ptr, plan_info.qo_len_offset); + params.kv_lens = GetPtrFromBaseOffset(int_buffer_ptr, plan_info.kv_len_offset); + params.head_indices = + GetPtrFromBaseOffset(int_buffer_ptr, plan_info.head_indices_offset); + params.work_indptr = GetPtrFromBaseOffset(int_buffer_ptr, plan_info.work_indptr_offset); + + return DISPATCH_head_dim(head_dim, HEAD_DIM, [&] { + return DISPATCH_mask_mode(mask_mode, MASK_MODE, [&] { + return DISPATCH_BOOL(use_logits_soft_cap, USE_LOGITS_SOFT_CAP, [&] { + return DISPATCH_BOOL(use_swa, USE_SWA, [&] { + using AttentionVariant = + std::conditional_t; + cudaError_t status = + BatchPrefillWithRaggedKVCacheDispatched(params, stream); + TORCH_CHECK(status == cudaSuccess, + 
"BatchPrefillWithRaggedKVCacheSM90Run failed with error: ", + cudaGetErrorString(status)); + return true; + }); + }); + }); + }); + }); +} + +void BatchPrefillWithPagedKVCacheSM90Run( + unsigned int mask_mode_code, at::Tensor float_workspace_buffer, at::Tensor int_workspace_buffer, + std::vector plan_info_vec, at::Tensor q, at::Tensor paged_k_cache, + at::Tensor paged_v_cache, std::optional maybe_custom_mask, + std::optional maybe_alibi_slopes, at::Tensor qo_indptr, at::Tensor paged_kv_indptr, + at::Tensor paged_kv_indices, at::Tensor paged_kv_last_page_len, + std::optional maybe_qk_indptr, at::Tensor o, unsigned int layout, + int32_t window_left, float logits_soft_cap, float sm_scale, float rope_scale, float rope_theta, + std::optional maybe_lse, int64_t cuda_stream) { + PrefillPlanSM90Info plan_info; + plan_info.FromVector(plan_info_vec); + + if (maybe_lse) { + const auto& lse = *maybe_lse; + TORCH_CHECK(lse.size(0) == q.size(0), lse.size(0), q.size(0)); + TORCH_CHECK(lse.size(1) == q.size(1), lse.size(1), q.size(1)); + } + QKVLayout kv_layout = static_cast(layout); + unsigned int num_kv_heads, page_size; + unsigned int head_dim = q.size(2); + if (kv_layout == QKVLayout::kHND) { + num_kv_heads = paged_k_cache.size(1); + page_size = paged_k_cache.size(2); + } else { + page_size = paged_k_cache.size(1); + num_kv_heads = paged_k_cache.size(2); + } + + void* float_buffer_ptr = float_workspace_buffer.data_ptr(); + void* int_buffer_ptr = int_workspace_buffer.data_ptr(); + + auto q_scalar_type = q.scalar_type(); + + cudaStream_t stream = reinterpret_cast(cuda_stream); + const MaskMode mask_mode = static_cast(mask_mode_code); + bool use_logits_soft_cap = logits_soft_cap > 0.f; + bool use_swa = window_left != -1; + + DISPATCH_PYTORCH_DTYPE_TO_CTYPE_FP16(q_scalar_type, qkv_type, [&] { + using DTypeQ = cutlass_dtype_t; + using DTypeKV = DTypeQ; + using DTypeO = DTypeQ; + using IdType = int32_t; + + BatchPrefillPagedParams params; + + params.q_ptr = static_cast(q.data_ptr()); + params.k_ptr = static_cast(paged_k_cache.data_ptr()); + params.v_ptr = static_cast(paged_v_cache.data_ptr()); + params.o_ptr = static_cast(o.data_ptr()); + params.lse_ptr = maybe_lse ? 
static_cast(maybe_lse->data_ptr()) : nullptr; + params.q_stride_n = q.stride(0); + params.q_stride_h = q.stride(1); + params.o_stride_n = o.stride(0); + params.o_stride_h = o.stride(1); + if (kv_layout == QKVLayout::kNHD) { + // (num_pages, page_size, num_heads, head_dim) + params.k_stride_n = paged_k_cache.stride(1); + params.k_stride_h = paged_k_cache.stride(2); + params.v_stride_n = paged_v_cache.stride(1); + params.v_stride_h = paged_v_cache.stride(2); + } else { + // (num_pages, num_heads, page_size, head_dim) + params.k_stride_h = paged_k_cache.stride(1); + params.k_stride_n = paged_k_cache.stride(2); + params.v_stride_h = paged_v_cache.stride(1); + params.v_stride_n = paged_v_cache.stride(2); + } + params.nnz_qo = q.size(0); + params.head_dim = head_dim; + params.num_qo_heads = q.size(1); + params.num_kv_heads = num_kv_heads; + params.group_size = params.num_qo_heads / num_kv_heads; + params.page_size = page_size; + params.window_left = window_left; + params.logits_soft_cap = logits_soft_cap; + params.sm_scale_log2 = sm_scale * math::log2e; + params.causal = mask_mode_code == 1; + params.qo_tile_indices = + GetPtrFromBaseOffset(int_buffer_ptr, plan_info.qo_tile_indices_offset); + params.qo_indptr = GetPtrFromBaseOffset(int_buffer_ptr, plan_info.qo_indptr_offset); + params.kv_indptr = GetPtrFromBaseOffset(int_buffer_ptr, plan_info.kv_indptr_offset); + params.qo_lens = GetPtrFromBaseOffset(int_buffer_ptr, plan_info.qo_len_offset); + params.kv_lens = GetPtrFromBaseOffset(int_buffer_ptr, plan_info.kv_len_offset); + params.head_indices = + GetPtrFromBaseOffset(int_buffer_ptr, plan_info.head_indices_offset); + params.work_indptr = GetPtrFromBaseOffset(int_buffer_ptr, plan_info.work_indptr_offset); + params.kv_indices = static_cast(paged_kv_indices.data_ptr()); + + return DISPATCH_head_dim(head_dim, HEAD_DIM, [&] { + return DISPATCH_mask_mode(mask_mode, MASK_MODE, [&] { + return DISPATCH_BOOL(use_logits_soft_cap, USE_LOGITS_SOFT_CAP, [&] { + return DISPATCH_BOOL(use_swa, USE_SWA, [&] { + using AttentionVariant = + std::conditional_t; + cudaError_t status = + BatchPrefillWithPagedKVCacheDispatched(params, stream); + TORCH_CHECK(status == cudaSuccess, + "BatchPrefillWithPagedKVCacheSM90Run failed with error: ", + cudaGetErrorString(status)); + return true; + }); + }); + }); + }); + }); +} diff --git a/csrc/dispatch_utils.h b/csrc/dispatch_utils.h index b34253e2..7885cd72 100644 --- a/csrc/dispatch_utils.h +++ b/csrc/dispatch_utils.h @@ -59,13 +59,13 @@ using namespace flashinfer; #define DISPATCH_mask_mode(expr, const_expr, ...) \ _DISPATCH_SWITCH("mask_mode", expr, _DISPATCH_CASES_mask_mode(const_expr, __VA_ARGS__)) -#define DISPATCH_LOGITS_SOFT_CAP(use_logits_soft_cap, USE_LOGITS_SOFT_CAP, ...) \ - [&]() -> bool { \ - if (use_logits_soft_cap) { \ - constexpr bool USE_LOGITS_SOFT_CAP = true; \ - return __VA_ARGS__(); \ - } else { \ - constexpr bool USE_LOGITS_SOFT_CAP = false; \ - return __VA_ARGS__(); \ - } \ +#define DISPATCH_BOOL(use_logits_soft_cap, USE_LOGITS_SOFT_CAP, ...) \ + [&]() -> bool { \ + if (use_logits_soft_cap) { \ + constexpr bool USE_LOGITS_SOFT_CAP = true; \ + return __VA_ARGS__(); \ + } else { \ + constexpr bool USE_LOGITS_SOFT_CAP = false; \ + return __VA_ARGS__(); \ + } \ }() diff --git a/csrc/flashinfer_gemm_sm90_ops.cu b/csrc/flashinfer_gemm_sm90_ops.cu deleted file mode 100644 index b6802e42..00000000 --- a/csrc/flashinfer_gemm_sm90_ops.cu +++ /dev/null @@ -1,27 +0,0 @@ -/* - * Copyright (c) 2023 by FlashInfer team. 
- * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ -#include "pytorch_extension_utils.h" - -void CutlassSegmentGEMMSM90(at::Tensor float_workspace_buffer, at::Tensor int_workspace_buffer, - at::Tensor all_problems, at::Tensor x_ptr, at::Tensor w_ptr, - at::Tensor y_ptr, at::Tensor x_stride, at::Tensor weight_stride, - at::Tensor y_stride, at::Tensor empty_x_data, bool weight_column_major, - int64_t cuda_stream); - -PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { - m.def("cutlass_segment_gemm_sm90", &CutlassSegmentGEMMSM90, - "Cutlass Segment GEMM operator for SM90"); -} diff --git a/csrc/flashinfer_ops.cu b/csrc/flashinfer_ops.cu index e6676e3a..3d84bfc4 100644 --- a/csrc/flashinfer_ops.cu +++ b/csrc/flashinfer_ops.cu @@ -86,6 +86,14 @@ void append_paged_kv_cache(at::Tensor append_key, at::Tensor append_value, at::T at::Tensor kv_indices, at::Tensor kv_indptr, at::Tensor kv_last_page_len, unsigned int layout, int64_t cuda_stream); +void block_sparse_indices_to_vector_sparse_offsets(at::Tensor block_sparse_indices, + at::Tensor block_sparse_indptr, + at::Tensor vector_sparse_offsets, + at::Tensor vector_sparse_indptr, + at::Tensor kv_len_arr, unsigned int stride_block, + unsigned int stride_n, unsigned int batch_size, + unsigned int block_size, int64_t cuda_stream); + //========== prefill ========== void single_prefill_with_kv_cache(unsigned int mask_mode_code, at::Tensor q, at::Tensor k, @@ -226,6 +234,8 @@ PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { // page m.def("append_paged_kv_cache", &append_paged_kv_cache, "Append paged KV-Cache operator"); + m.def("block_sparse_indices_to_vector_sparse_offsets", + &block_sparse_indices_to_vector_sparse_offsets, "Precompute block sparse offsets"); // prefill m.def("single_prefill_with_kv_cache", &single_prefill_with_kv_cache, diff --git a/csrc/flashinfer_ops_sm90.cu b/csrc/flashinfer_ops_sm90.cu new file mode 100644 index 00000000..cc3ac869 --- /dev/null +++ b/csrc/flashinfer_ops_sm90.cu @@ -0,0 +1,65 @@ +/* + * Copyright (c) 2023 by FlashInfer team. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +#include "aot_extension_utils.h" + +void CutlassSegmentGEMMSM90(at::Tensor float_workspace_buffer, at::Tensor int_workspace_buffer, + at::Tensor all_problems, at::Tensor x_ptr, at::Tensor w_ptr, + at::Tensor y_ptr, at::Tensor x_stride, at::Tensor weight_stride, + at::Tensor y_stride, at::Tensor empty_x_data, bool weight_column_major, + int64_t cuda_stream); + +void single_prefill_with_kv_cache_sm90(unsigned int mask_mode_code, at::Tensor q, at::Tensor k, + at::Tensor v, + std::optional maybe_packed_custom_mask, + std::optional maybe_alibi_slopes, at::Tensor o, + unsigned int layout, int32_t window_left, + float logits_soft_cap, float sm_scale, float rope_scale, + float rope_theta, std::optional maybe_lse, + int64_t cuda_stream); + +std::vector BatchPrefillWithKVCacheSM90Plan( + unsigned int head_dim, bool causal, at::Tensor float_workspace_buffer, + at::Tensor int_workspace_buffer, at::Tensor page_locked_int_workspace_buffer, + at::Tensor qo_indptr, at::Tensor kv_indptr, at::Tensor kv_len_arr, unsigned int batch_size, + unsigned int num_qo_heads, unsigned int num_kv_heads, unsigned int page_size, + bool enable_cuda_graph, int64_t cuda_stream); + +void BatchPrefillWithRaggedKVCacheSM90Run( + unsigned int mask_mode_code, at::Tensor float_workspace_buffer, at::Tensor int_workspace_buffer, + std::vector plan_info_vec, at::Tensor q, at::Tensor k, at::Tensor v, + std::optional maybe_custom_mask, std::optional maybe_alibi_slopes, + at::Tensor qo_indptr, at::Tensor kv_indptr, std::optional maybe_qk_indptr, + at::Tensor o, unsigned int layout, int32_t window_left, float logits_soft_cap, float sm_scale, + float rope_scale, float rope_theta, std::optional maybe_lse, int64_t cuda_stream); + +void BatchPrefillWithPagedKVCacheSM90Run( + unsigned int mask_mode_code, at::Tensor float_workspace_buffer, at::Tensor int_workspace_buffer, + std::vector plan_info_vec, at::Tensor q, at::Tensor paged_k_cache, + at::Tensor paged_v_cache, std::optional maybe_custom_mask, + std::optional maybe_alibi_slopes, at::Tensor qo_indptr, at::Tensor paged_kv_indptr, + at::Tensor paged_kv_indices, at::Tensor paged_kv_last_page_len, + std::optional maybe_qk_indptr, at::Tensor o, unsigned int layout, + int32_t window_left, float logits_soft_cap, float sm_scale, float rope_scale, float rope_theta, + std::optional maybe_lse, int64_t cuda_stream); + +PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { + m.def("cutlass_segment_gemm_sm90", &CutlassSegmentGEMMSM90, + "Cutlass Segment GEMM operator for SM90"); + m.def("single_prefill_with_kv_cache_sm90", &single_prefill_with_kv_cache_sm90); + m.def("batch_prefill_with_kv_cache_sm90_plan", &BatchPrefillWithKVCacheSM90Plan); + m.def("batch_prefill_with_ragged_kv_cache_sm90_run", &BatchPrefillWithRaggedKVCacheSM90Run); + m.def("batch_prefill_with_paged_kv_cache_sm90_run", &BatchPrefillWithPagedKVCacheSM90Run); +} diff --git a/csrc/flashinfer_page_ops.cu b/csrc/flashinfer_page_ops.cu index d78d4ac0..e365eb62 100644 --- a/csrc/flashinfer_page_ops.cu +++ b/csrc/flashinfer_page_ops.cu @@ -20,6 +20,16 @@ void append_paged_kv_cache(at::Tensor append_key, at::Tensor append_value, at::T at::Tensor kv_indices, at::Tensor kv_indptr, at::Tensor kv_last_page_len, unsigned int layout, int64_t cuda_stream); +void block_sparse_indices_to_vector_sparse_offsets(at::Tensor block_sparse_indices, + at::Tensor block_sparse_indptr, + at::Tensor vector_sparse_offsets, + at::Tensor vector_sparse_indptr, + at::Tensor kv_len_arr, unsigned int stride_block, + unsigned int stride_n, unsigned int batch_size, + unsigned 
int block_size, int64_t cuda_stream); + PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { m.def("append_paged_kv_cache", &append_paged_kv_cache, "Append paged KV-Cache operator"); + m.def("block_sparse_indices_to_vector_sparse_offsets", + &block_sparse_indices_to_vector_sparse_offsets, "Precompute block sparse offsets"); } diff --git a/csrc/group_gemm.cu b/csrc/group_gemm.cu index 78779fe5..8fae8b9b 100644 --- a/csrc/group_gemm.cu +++ b/csrc/group_gemm.cu @@ -28,7 +28,7 @@ void CutlassSegmentGEMM(at::Tensor workspace_buffer, at::Tensor all_problems, at cudaStream_t stream = reinterpret_cast(cuda_stream); DISPATCH_PYTORCH_DTYPE_TO_CTYPE_FP16(empty_x_data.scalar_type(), c_type, [&] { - using cutlass_t = typename cutlass_dtype::value; + using cutlass_t = cutlass_dtype_t; auto status = CutlassSegmentGEMMRun( workspace_buffer.data_ptr(), workspace_buffer.element_size() * workspace_buffer.size(0), all_problems.data_ptr(), batch_size, x_ptr.data_ptr(), w_ptr.data_ptr(), y_ptr.data_ptr(), diff --git a/csrc/group_gemm_sm90.cu b/csrc/group_gemm_sm90.cu index 3710cf2f..5341a204 100644 --- a/csrc/group_gemm_sm90.cu +++ b/csrc/group_gemm_sm90.cu @@ -30,7 +30,7 @@ void CutlassSegmentGEMMSM90(at::Tensor float_workspace_buffer, at::Tensor int_wo cudaStream_t stream = reinterpret_cast(cuda_stream); DISPATCH_PYTORCH_DTYPE_TO_CTYPE(empty_x_data.scalar_type(), c_type, [&] { - using cutlass_t = typename cutlass_dtype::value; + using cutlass_t = cutlass_dtype_t; auto status = CutlassSegmentGEMMSM90Run( float_workspace_buffer.data_ptr(), float_workspace_buffer.element_size() * float_workspace_buffer.size(0), diff --git a/csrc/page.cu b/csrc/page.cu index 644a7dc6..db684194 100644 --- a/csrc/page.cu +++ b/csrc/page.cu @@ -110,3 +110,30 @@ void append_paged_kv_cache(at::Tensor append_key, at::Tensor append_value, at::T TORCH_CHECK(success, "AppendPagedKVCache failed to dispatch with dtype ", kv_scalar_dtype); } + +void block_sparse_indices_to_vector_sparse_offsets(at::Tensor block_sparse_indices, + at::Tensor block_sparse_indptr, + at::Tensor vector_sparse_offsets, + at::Tensor vector_sparse_indptr, + at::Tensor kv_len_arr, unsigned int stride_block, + unsigned int stride_n, unsigned int batch_size, + unsigned int block_size, int64_t cuda_stream) { + CHECK_INPUT(block_sparse_indices); + CHECK_INPUT(block_sparse_indptr); + CHECK_INPUT(vector_sparse_offsets); + CHECK_INPUT(vector_sparse_indptr); + CHECK_INPUT(kv_len_arr); + + cudaStream_t stream = reinterpret_cast(cuda_stream); + + cudaError_t status = BlockSparseIndicesToVectorSparseOffset( + static_cast(block_sparse_indices.data_ptr()), + static_cast(block_sparse_indptr.data_ptr()), + static_cast(vector_sparse_offsets.data_ptr()), + static_cast(vector_sparse_indptr.data_ptr()), + static_cast(kv_len_arr.data_ptr()), stride_block, stride_n, batch_size, block_size, + stream); + + TORCH_CHECK(status == cudaSuccess, "BlockSparseIndicesToVectorSparseOffset failed with error: ", + cudaGetErrorString(status)); +} diff --git a/csrc/rope.cu b/csrc/rope.cu index a1018dbc..3f10357b 100644 --- a/csrc/rope.cu +++ b/csrc/rope.cu @@ -139,7 +139,6 @@ void apply_rope_pos_ids_cos_sin_cache(at::Tensor q, at::Tensor k, at::Tensor q_r size_t k_rope_stride_h = k_rope.stride(1); cudaStream_t stream = reinterpret_cast(cuda_stream); - cudaStream_t torch_current_stream(nullptr); DISPATCH_PYTORCH_DTYPE_TO_CTYPE_FP16(q.scalar_type(), c_type, [&] { cudaError_t status = BatchQKApplyRotaryPosIdsCosSinCache( static_cast(q.data_ptr()), static_cast(k.data_ptr()), @@ -231,7 +230,6 @@ void 
apply_llama31_rope_pos_ids(at::Tensor q, at::Tensor k, at::Tensor q_rope, a size_t k_rope_stride_h = k_rope.stride(1); cudaStream_t stream = reinterpret_cast(cuda_stream); - cudaStream_t torch_current_stream(nullptr); DISPATCH_PYTORCH_DTYPE_TO_CTYPE_FP16(q.scalar_type(), c_type, [&] { cudaError_t status = BatchQKApplyLlama31RotaryPosIds( static_cast(q.data_ptr()), static_cast(k.data_ptr()), diff --git a/csrc/single_decode.cu b/csrc/single_decode.cu index 60f9114b..60a2bd76 100644 --- a/csrc/single_decode.cu +++ b/csrc/single_decode.cu @@ -76,7 +76,7 @@ void single_decode_with_kv_cache(at::Tensor q, at::Tensor k, at::Tensor v, at::T using DTypeKV = kv_type; using DTypeO = DTypeQ; return DISPATCH_head_dim(head_dim, HEAD_DIM, [&] { - return DISPATCH_LOGITS_SOFT_CAP(logits_soft_cap > 0, USE_LOGITS_SOFT_CAP, [&] { + return DISPATCH_BOOL(logits_soft_cap > 0, USE_LOGITS_SOFT_CAP, [&] { using ParamsT = SingleDecodeParams; using AttentionVariant = ComposedAttention; using AttentionVariant = ComposedAttention + +#include +#include +#include +#include +#include +#include +#include + +#include "aot_extension_utils.h" + +namespace flashinfer { + +template +cudaError_t SinglePrefillWithKVCacheDispatched(SinglePrefillParams& params, + cudaStream_t stream); + +} // namespace flashinfer + +using namespace flashinfer; + +void single_prefill_with_kv_cache_sm90(unsigned int mask_mode_code, at::Tensor q, at::Tensor k, + at::Tensor v, + std::optional maybe_packed_custom_mask, + std::optional maybe_alibi_slopes, at::Tensor o, + unsigned int layout, int32_t window_left, + float logits_soft_cap, float sm_scale, float rope_scale, + float rope_theta, std::optional maybe_lse, + int64_t cuda_stream) { + unsigned int head_dim = q.size(2); + unsigned int num_qo_heads = q.size(1); + unsigned int qo_len = q.size(0); + + auto q_scalar_type = q.scalar_type(); + + QKVLayout kv_layout = static_cast(layout); + cudaStream_t stream = reinterpret_cast(cuda_stream); + const MaskMode mask_mode = static_cast(mask_mode_code); + bool use_logits_soft_cap = logits_soft_cap > 0.0f; + + DISPATCH_PYTORCH_DTYPE_TO_CTYPE_FP16(q_scalar_type, q_type, [&] { + using DTypeQ = cutlass_dtype_t; + using DTypeKV = DTypeQ; + using DTypeO = DTypeQ; + SinglePrefillParams params; + params.q_ptr = static_cast(q.data_ptr()); + params.k_ptr = static_cast(k.data_ptr()); + params.v_ptr = static_cast(v.data_ptr()); + params.o_ptr = static_cast(o.data_ptr()); + params.lse_ptr = maybe_lse ? 
(static_cast(maybe_lse->data_ptr())) : nullptr; + params.q_stride_n = q.stride(0); + params.q_stride_h = q.stride(1); + params.o_stride_n = o.stride(0); + params.o_stride_h = o.stride(1); + if (kv_layout == QKVLayout::kNHD) { + params.k_stride_n = k.stride(0); + params.k_stride_h = k.stride(1); + params.v_stride_n = v.stride(0); + params.v_stride_h = v.stride(1); + } else { + params.k_stride_h = k.stride(0); + params.k_stride_n = k.stride(1); + params.v_stride_h = v.stride(0); + params.v_stride_n = v.stride(1); + } + params.qo_len = q.size(0); + params.kv_len = k.size(0); + params.head_dim = head_dim; + params.num_qo_heads = q.size(1); + params.num_kv_heads = k.size(1); + params.causal = mask_mode == MaskMode::kCausal; + params.group_size = params.num_qo_heads / params.num_kv_heads; + params.window_left = window_left; + params.logits_soft_cap = logits_soft_cap; + params.sm_scale_log2 = sm_scale * math::log2e; + bool use_swa = window_left != -1; + return DISPATCH_mask_mode(mask_mode, MASK_MODE, [&] { + return DISPATCH_head_dim(head_dim, HEAD_DIM, [&] { + return DISPATCH_BOOL(use_logits_soft_cap, USE_LOGITS_SOFT_CAP, [&] { + return DISPATCH_BOOL(use_swa, USE_SWA, [&] { + using AttentionVariant = + std::conditional_t; + cudaError_t status = + SinglePrefillWithKVCacheDispatched( + params, stream); + TORCH_CHECK(status == cudaSuccess, + "single_prefill_with_kv_cache_sm90 failed with error: " + + std::string(cudaGetErrorString(status))); + return true; + }); + }); + }); + }); + }); +} diff --git a/flashinfer/jit/__init__.py b/flashinfer/jit/__init__.py index 49265c3b..6c418366 100644 --- a/flashinfer/jit/__init__.py +++ b/flashinfer/jit/__init__.py @@ -17,9 +17,11 @@ # Re-export from .activation import gen_act_and_mul_module as gen_act_and_mul_module from .activation import get_act_and_mul_cu_str as get_act_and_mul_cu_str +from .aot_config import prebuilt_ops_uri as prebuilt_ops_uri from .attention import gen_batch_decode_mla_module as gen_batch_decode_mla_module from .attention import gen_batch_decode_module as gen_batch_decode_module from .attention import gen_batch_prefill_module as gen_batch_prefill_module +from .attention import gen_batch_prefill_sm90_module as gen_batch_prefill_sm90_module from .attention import ( gen_customize_single_decode_module as gen_customize_single_decode_module, ) @@ -28,20 +30,21 @@ ) from .attention import gen_single_decode_module as gen_single_decode_module from .attention import gen_single_prefill_module as gen_single_prefill_module +from .attention import gen_single_prefill_sm90_module as gen_single_prefill_sm90_module from .attention import get_batch_decode_mla_uri as get_batch_decode_mla_uri from .attention import get_batch_decode_uri as get_batch_decode_uri +from .attention import get_batch_prefill_sm90_uri as get_batch_prefill_sm90_uri from .attention import get_batch_prefill_uri as get_batch_prefill_uri from .attention import get_single_decode_uri as get_single_decode_uri +from .attention import get_single_prefill_sm90_uri as get_single_prefill_sm90_uri from .attention import get_single_prefill_uri as get_single_prefill_uri from .core import clear_cache_dir, load_cuda_ops from .env import * from .utils import parallel_load_modules as parallel_load_modules -from .aot_config import prebuilt_ops_uri as prebuilt_ops_uri - try: - from .. import _kernels - from .. import _kernels_sm90 + from .. 
import _kernels, _kernels_sm90 + has_prebuilt_ops = True except ImportError: has_prebuilt_ops = False diff --git a/flashinfer/jit/attention.py b/flashinfer/jit/attention.py index 72c2fca2..ba24f7c1 100644 --- a/flashinfer/jit/attention.py +++ b/flashinfer/jit/attention.py @@ -23,6 +23,10 @@ from .batch_decode_mla_templ import batch_decode_mla_suffix, batch_decode_mla_templ from .batch_decode_templ import batch_decode_suffix, batch_decode_templ +from .batch_prefill_sm90_templ import ( + batch_prefill_sm90_suffix, + batch_prefill_sm90_templ, +) from .batch_prefill_templ import batch_prefill_suffix, batch_prefill_templ from .core import load_cuda_ops from .env import FLASHINFER_GEN_SRC_DIR @@ -31,6 +35,10 @@ single_decode_suffix, single_decode_templ, ) +from .single_prefill_sm90_templ import ( + single_prefill_sm90_suffix, + single_prefill_sm90_templ, +) from .single_prefill_templ import ( customizable_single_prefill_templ, single_prefill_suffix, @@ -247,6 +255,35 @@ def get_single_prefill_sources( ) +def get_single_prefill_sm90_sources( + dtype_q: torch.dtype, + dtype_kv: torch.dtype, + dtype_o: torch.dtype, + head_dim: int, + pos_encoding_mode: int, + use_sliding_window: bool, + use_logits_soft_cap: bool, + use_fp16_qk_reduction: bool, +) -> List[str]: + assert not use_fp16_qk_reduction, "fp16 qk reduction is not supported on sm90" + assert ( + pos_encoding_mode == 0 + ), "Currently we only support pos_encoding_mode=0 on sm90" + return render_templates( + single_prefill_sm90_templ, + { + "dtype_q": dtype_map[dtype_q], + "dtype_kv": dtype_map[dtype_kv], + "dtype_o": dtype_map[dtype_o], + "head_dim": head_dim, + "pos_encoding_mode": pos_encoding_mode_literal[pos_encoding_mode], + "use_sliding_window": "true" if use_sliding_window else "false", + "use_logits_soft_cap": "true" if use_logits_soft_cap else "false", + "use_fp16_qk_reduction": "true" if use_fp16_qk_reduction else "false", + }, + ) + + def get_single_prefill_uri( dtype_q: torch.dtype, dtype_kv: torch.dtype, @@ -269,6 +306,10 @@ def get_single_prefill_uri( ) +def get_single_prefill_sm90_uri(*args): + return get_single_prefill_uri(*args) + "_sm90" + + def gen_single_prefill_module(*args): gen_directory = FLASHINFER_GEN_SRC_DIR uri = get_single_prefill_uri(*args) @@ -282,6 +323,19 @@ def gen_single_prefill_module(*args): return load_cuda_ops(uri, source_paths) +def gen_single_prefill_sm90_module(*args): + gen_directory = FLASHINFER_GEN_SRC_DIR + uri = get_single_prefill_sm90_uri(*args) + sources = get_single_prefill_sm90_sources(*args) + source_paths = [] + for suffix, source in zip(single_prefill_sm90_suffix, sources): + path = gen_directory / f"{uri}{suffix}" + source_paths.append(path) + write_if_different(path, source) + + return load_cuda_ops(uri, source_paths) + + def get_batch_prefill_sources( dtype_q: torch.dtype, dtype_kv: torch.dtype, @@ -309,6 +363,37 @@ def get_batch_prefill_sources( ) +def get_batch_prefill_sm90_sources( + dtype_q: torch.dtype, + dtype_kv: torch.dtype, + dtype_o: torch.dtype, + dtype_idx: torch.dtype, + head_dim: int, + pos_encoding_mode: int, + use_sliding_window: bool, + use_logits_soft_cap: bool, + use_fp16_qk_reduction: bool, +) -> List[str]: + assert not use_fp16_qk_reduction, "fp16 qk reduction is not supported on sm90" + assert ( + pos_encoding_mode == 0 + ), "Currently we only support pos_encoding_mode=0 on sm90" + return render_templates( + batch_prefill_sm90_templ, + { + "dtype_q": dtype_map[dtype_q], + "dtype_kv": dtype_map[dtype_kv], + "dtype_o": dtype_map[dtype_o], + "dtype_idx": 
dtype_map[dtype_idx], + "head_dim": head_dim, + "pos_encoding_mode": pos_encoding_mode_literal[pos_encoding_mode], + "use_sliding_window": "true" if use_sliding_window else "false", + "use_logits_soft_cap": "true" if use_logits_soft_cap else "false", + "use_fp16_qk_reduction": "true" if use_fp16_qk_reduction else "false", + }, + ) + + def get_batch_prefill_uri( dtype_q: torch.dtype, dtype_kv: torch.dtype, @@ -333,6 +418,10 @@ def get_batch_prefill_uri( ) +def get_batch_prefill_sm90_uri(*args): + return get_batch_prefill_uri(*args) + "_sm90" + + def gen_batch_prefill_module(*args): gen_directory = FLASHINFER_GEN_SRC_DIR uri = get_batch_prefill_uri(*args) @@ -346,6 +435,19 @@ def gen_batch_prefill_module(*args): return load_cuda_ops(uri, source_paths) +def gen_batch_prefill_sm90_module(*args): + gen_directory = FLASHINFER_GEN_SRC_DIR + uri = get_batch_prefill_sm90_uri(*args) + sources = get_batch_prefill_sm90_sources(*args) + source_paths = [] + for suffix, source in zip(batch_prefill_sm90_suffix, sources): + path = gen_directory / f"{uri}{suffix}" + source_paths.append(path) + write_if_different(path, source) + + return load_cuda_ops(uri, source_paths) + + def get_customize_single_decode_sources( dtype_q: torch.dtype, dtype_kv: torch.dtype, diff --git a/flashinfer/jit/batch_prefill_sm90_templ.py b/flashinfer/jit/batch_prefill_sm90_templ.py new file mode 100644 index 00000000..c06c4aac --- /dev/null +++ b/flashinfer/jit/batch_prefill_sm90_templ.py @@ -0,0 +1,19 @@ +""" +Copyright (c) 2024 by FlashInfer team. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +""" + +batch_prefill_sm90_suffix = [".cu", "_pybind.cc"] + +batch_prefill_sm90_templ = [r"""""", r""""""] diff --git a/flashinfer/jit/single_prefill_sm90_templ.py b/flashinfer/jit/single_prefill_sm90_templ.py new file mode 100644 index 00000000..917cf368 --- /dev/null +++ b/flashinfer/jit/single_prefill_sm90_templ.py @@ -0,0 +1,19 @@ +""" +Copyright (c) 2024 by FlashInfer team. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+""" + +single_prefill_sm90_suffix = [".cu", "_pybind.cc"] + +single_prefill_sm90_templ = [r"""""", r""""""] diff --git a/flashinfer/page.py b/flashinfer/page.py index a008fb08..b0f80903 100644 --- a/flashinfer/page.py +++ b/flashinfer/page.py @@ -51,6 +51,44 @@ def get_page_module(): return _page_module +def block_sparse_indices_to_vector_sparse_offsets( + block_sparse_indices: torch.Tensor, + block_sparse_indptr: torch.Tensor, + vector_sparse_offsets: torch.Tensor, + vector_sparse_indptr: torch.Tensor, + kv_lens: torch.Tensor, + stride_block: int, + stride_n: int, + block_size: int, +) -> torch.Tensor: + if block_size == 1: + if stride_block == 1: + return block_sparse_indices + else: + return block_sparse_indices * stride_block + + with block_sparse_indices.device as device: + assert block_sparse_indices.dtype == torch.int32 + assert block_sparse_indptr.dtype == torch.int32 + assert vector_sparse_offsets.dtype == torch.int32 + assert vector_sparse_indptr.dtype == torch.int32 + assert kv_lens.dtype == torch.int32 + batch_size = block_sparse_indptr.size(0) - 1 + get_page_module().block_sparse_indices_to_vector_sparse_offsets( + block_sparse_indices, + block_sparse_indptr, + vector_sparse_offsets, + vector_sparse_indptr, + kv_lens, + stride_block, + stride_n, + batch_size, + block_size, + get_cuda_stream(device), + ) + return vector_sparse_offsets + + @register_custom_op( "flashinfer::append_paged_kv_cache", mutates_args=("paged_k_cache", "paged_v_cache"), diff --git a/flashinfer/prefill.py b/flashinfer/prefill.py index 82944761..d5e8f5a3 100644 --- a/flashinfer/prefill.py +++ b/flashinfer/prefill.py @@ -24,13 +24,18 @@ from .jit import ( gen_batch_prefill_module, + gen_batch_prefill_sm90_module, gen_single_prefill_module, + gen_single_prefill_sm90_module, + get_batch_prefill_sm90_uri, get_batch_prefill_uri, + get_single_prefill_sm90_uri, get_single_prefill_uri, has_prebuilt_ops, load_cuda_ops, prebuilt_ops_uri, ) +from .page import block_sparse_indices_to_vector_sparse_offsets, get_seq_lens from .quantization import packbits, segment_packbits from .utils import ( MaskMode, @@ -43,14 +48,95 @@ _get_cache_buf, _unpack_paged_kv_cache, canonicalize_torch_dtype, + determine_attention_backend, get_cuda_stream, is_float8, + log2e, register_custom_op, register_fake_op, ) _single_prefill_modules = {} +_single_prefill_sm90_modules = {} _batch_prefill_modules = {} +_batch_prefill_sm90_modules = {} + + +def get_single_prefill_sm90_module(*args): + global _single_prefill_sm90_modules + if args not in _single_prefill_sm90_modules: + uri = get_single_prefill_sm90_uri(*args) + # if has_prebuilt_ops and uri in prebuilt_ops_uri: + from . 
import _kernels_sm90 + + run_func = _kernels_sm90.single_prefill_with_kv_cache_sm90 + # else: + # run_func = gen_single_prefill_sm90_module(*args).run + + @register_custom_op(f"flashinfer::{uri}_run", mutates_args=("tmp", "maybe_lse")) + def run_single_prefill_sm90( + mask_mode: int, + q: torch.Tensor, + k: torch.Tensor, + v: torch.Tensor, + maybe_packed_custom_mask: Optional[torch.Tensor], + tmp: torch.Tensor, + maybe_alibi_slopes: Optional[torch.Tensor], + layout: int, + window_left: int, + logits_soft_cap: float, + sm_scale: float, + rope_scale: float, + rope_theta: float, + maybe_lse: Optional[torch.Tensor], + ) -> torch.Tensor: + with q.device as device: # device guard + o = torch.empty_like(q) + run_func( + mask_mode, + q, + k, + v, + maybe_packed_custom_mask, + # tmp, + maybe_alibi_slopes, + o, + layout, + window_left, + logits_soft_cap, + sm_scale, + rope_scale, + rope_theta, + maybe_lse, + get_cuda_stream(device), + ) + return o + + @register_fake_op(f"flashinfer::{uri}_run") + def _fake_run_single_prefill_sm90( + mask_mode: int, + q: torch.Tensor, + k: torch.Tensor, + v: torch.Tensor, + maybe_packed_custom_mask: Optional[torch.Tensor], + tmp: torch.Tensor, + maybe_alibi_slopes: Optional[torch.Tensor], + layout: int, + window_left: int, + logits_soft_cap: float, + sm_scale: float, + rope_scale: float, + rope_theta: float, + maybe_lse: Optional[torch.Tensor], + ) -> torch.Tensor: + return torch.empty_like(q) + + # Register the module + _single_prefill_sm90_modules[args] = SimpleNamespace( + run=run_single_prefill_sm90 + ) + + return _single_prefill_sm90_modules[args] def get_single_prefill_module(*args): @@ -130,6 +216,207 @@ def _fake_run_single_prefill( return _single_prefill_modules[args] +def get_batch_prefill_sm90_module(*args): + global _batch_prefill_sm90_modules + if args not in _batch_prefill_sm90_modules: + uri = get_batch_prefill_sm90_uri(*args) + + from . 
import _kernels_sm90 + + head_dim = args[4] + plan_func = ( + lambda *plan_args: _kernels_sm90.batch_prefill_with_kv_cache_sm90_plan( + head_dim, + *plan_args, + ) + ) + ragged_run_func = _kernels_sm90.batch_prefill_with_ragged_kv_cache_sm90_run + paged_run_func = _kernels_sm90.batch_prefill_with_paged_kv_cache_sm90_run + + # torch library for ragged_run + + @register_custom_op( + f"flashinfer::{uri}_ragged_run", + mutates_args=( + "float_workspace_buffer", + "int_workspace_buffer", + "maybe_lse", + ), + ) + def ragged_run( + mask_mode: int, + float_workspace_buffer: torch.Tensor, + int_workspace_buffer: torch.Tensor, + plan_info_vec: List[int], + q: torch.Tensor, + k: torch.Tensor, + v: torch.Tensor, + maybe_custom_mask: Optional[torch.Tensor], + maybe_alibi_slopes: Optional[torch.Tensor], + qo_indptr: torch.Tensor, + kv_indptr: torch.Tensor, + maybe_qk_indptr: Optional[torch.Tensor], + layout: int, + window_left: int, + logits_soft_cap: float, + sm_scale: float, + rope_scale: float, + rope_theta: float, + maybe_lse: Optional[torch.Tensor], + ) -> torch.Tensor: + with q.device as device: # device guard + o = torch.empty_like(q) + ragged_run_func( + mask_mode, + float_workspace_buffer, + int_workspace_buffer, + plan_info_vec, + q, + k, + v, + maybe_custom_mask, + maybe_alibi_slopes, + qo_indptr, + kv_indptr, + maybe_qk_indptr, + o, + layout, + window_left, + logits_soft_cap, + sm_scale, + rope_scale, + rope_theta, + maybe_lse, + get_cuda_stream(device), + ) + return o + + @register_fake_op(f"flashinfer::{uri}_ragged_run") + def _fake_ragged_run( + mask_mode: int, + float_workspace_buffer: torch.Tensor, + int_workspace_buffer: torch.Tensor, + plan_info_vec: List[int], + q: torch.Tensor, + k: torch.Tensor, + v: torch.Tensor, + maybe_custom_mask: Optional[torch.Tensor], + maybe_alibi_slopes: Optional[torch.Tensor], + qo_indptr: torch.Tensor, + kv_indptr: torch.Tensor, + maybe_qk_indptr: Optional[torch.Tensor], + layout: int, + window_left: int, + logits_soft_cap: float, + sm_scale: float, + rope_scale: float, + rope_theta: float, + maybe_lse: Optional[torch.Tensor], + ) -> torch.Tensor: + return torch.empty_like(q) + + # torch library for paged_run + + @register_custom_op( + f"flashinfer::{uri}_paged_run", + mutates_args=( + "float_workspace_buffer", + "int_workspace_buffer", + "paged_k_cache", + "paged_v_cache", + "maybe_lse", + ), + ) + def paged_run( + mask_mode: int, + float_workspace_buffer: torch.Tensor, + int_workspace_buffer: torch.Tensor, + plan_info_vec: List[int], + q: torch.Tensor, + paged_k_cache: torch.Tensor, + paged_v_cache: torch.Tensor, + maybe_custom_mask: Optional[torch.Tensor], + maybe_alibi_slopes: Optional[torch.Tensor], + qo_indptr: torch.Tensor, + paged_kv_indptr: torch.Tensor, + paged_kv_indices: torch.Tensor, + paged_kv_last_page_len: torch.Tensor, + maybe_qk_indptr: Optional[torch.Tensor], + layout: int, + window_left: int, + logits_soft_cap: float, + sm_scale: float, + rope_scale: float, + rope_theta: float, + maybe_lse: Optional[torch.Tensor], + ) -> torch.Tensor: + with q.device as device: # device guard + o = torch.empty_like(q) + paged_run_func( + mask_mode, + float_workspace_buffer, + int_workspace_buffer, + plan_info_vec, + q, + paged_k_cache, + paged_v_cache, + maybe_custom_mask, + maybe_alibi_slopes, + qo_indptr, + paged_kv_indptr, + paged_kv_indices, + paged_kv_last_page_len, + maybe_qk_indptr, + o, + layout, + window_left, + logits_soft_cap, + sm_scale, + rope_scale, + rope_theta, + maybe_lse, + get_cuda_stream(device), + ) + return o + + 
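        # The fake op registered below mirrors paged_run's signature and only returns an
        # empty tensor shaped like `q`; this gives torch.compile / meta-device tracing a
        # shape-correct stand-in for the SM90 kernel without actually launching it.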
@register_fake_op(f"flashinfer::{uri}_paged_run") + def _fake_paged_run( + mask_mode: int, + float_workspace_buffer: torch.Tensor, + int_workspace_buffer: torch.Tensor, + plan_info_vec: List[int], + q: torch.Tensor, + paged_k_cache: torch.Tensor, + paged_v_cache: torch.Tensor, + maybe_custom_mask: Optional[torch.Tensor], + maybe_alibi_slopes: Optional[torch.Tensor], + qo_indptr: torch.Tensor, + paged_kv_indptr: torch.Tensor, + paged_kv_indices: torch.Tensor, + paged_kv_last_page_len: torch.Tensor, + maybe_qk_indptr: Optional[torch.Tensor], + layout: int, + window_left: int, + logits_soft_cap: float, + sm_scale: float, + rope_scale: float, + rope_theta: float, + maybe_lse: Optional[torch.Tensor], + ) -> torch.Tensor: + return torch.empty_like(q) + + # Register the module. + # + # Note that plan is not part of model logic. It should not be included in + # Cuda Graph or torch.compile. So, we don't provide a torch library for plan. + _batch_prefill_sm90_modules[args] = SimpleNamespace( + plan=plan_func, + ragged_run=ragged_run, + paged_run=paged_run, + ) + return _batch_prefill_sm90_modules[args] + + def get_batch_prefill_module(*args): global _batch_prefill_modules if args not in _batch_prefill_modules: @@ -236,7 +523,7 @@ def _fake_ragged_run( # torch library for paged_run @register_custom_op( - f"flashinfer::{get_batch_prefill_uri(*args)}_paged_run", + f"flashinfer::{uri}_paged_run", mutates_args=( "float_workspace_buffer", "int_workspace_buffer", @@ -428,6 +715,7 @@ def single_prefill_with_kv_cache( rope_scale: Optional[float] = None, rope_theta: Optional[float] = None, return_lse: bool = False, + backend: str = "auto", ) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]: r"""Prefill/Append attention with KV cache for single request, return the attention output. @@ -485,6 +773,10 @@ def single_prefill_with_kv_cache( The theta used in RoPE, if not provided, will be set to 1e4. return_lse : bool Whether to return the log sum exp value of the attention logits. + backend : str + The implementation backend, could be ``auto``/``fa2`` or ``fa3``. Defaults to ``auto``. + If set to ``auto``, the function will automatically choose the backend based on the + device architecture and kernel availability. Returns ------- @@ -563,7 +855,21 @@ def single_prefill_with_kv_cache( if return_lse: lse = torch.empty((q.size(0), q.size(1)), dtype=torch.float32, device=q.device) - out = get_single_prefill_module( + if backend == "auto": + backend = determine_attention_backend( + q.device, + PosEncodingMode[pos_encoding_mode].value, + allow_fp16_qk_reduction, + packed_custom_mask is not None, # use_custom_mask + q.dtype, + k.dtype, + ) + if backend == "fa2": + module_getter = get_single_prefill_module + elif backend == "fa3": + module_getter = get_single_prefill_sm90_module + + out = module_getter( q.dtype, k.dtype, q.dtype, @@ -733,6 +1039,7 @@ def __init__( paged_kv_last_page_len_buf: Optional[torch.Tensor] = None, custom_mask_buf: Optional[torch.Tensor] = None, qk_indptr_buf: Optional[torch.Tensor] = None, + backend: str = "auto", ) -> None: r"""Constructor of :class:`BatchPrefillWithPagedKVCacheWrapper`. @@ -783,14 +1090,35 @@ def __init__( should be ``[batch_size + 1]``. This argument is only effective when ``use_cuda_graph`` is ``True`` and the custom mask will be used in attention computation. + + backend : str + The implementation backend, could be ``auto``/``fa2`` or ``fa3``. Defaults to ``auto``. 
+ If set to ``auto``, the function will automatically choose the backend based on the + device architecture and kernel availability. """ _check_kv_layout(kv_layout) self._kv_layout = kv_layout self._float_workspace_buffer = float_workspace_buffer self.device = float_workspace_buffer.device - self._int_workspace_buffer = torch.empty( - (8 * 1024 * 1024,), dtype=torch.uint8, device=self.device - ) + if backend in ["fa3", "auto"]: + self._int_workspace_buffer = torch.empty( + (64 * 1024 * 1024,), dtype=torch.uint8, device=self.device + ) + # NOTE(Zihao): assume maximum accumulate kv length is 16M + self._vector_sparse_indices_buffer = torch.empty( + (16 * 1024 * 1024,), dtype=torch.int32, device=self.device + ) + # NOTE(Zihao): assume maximum batch size is 32768 + self._vector_sparse_indptr_buffer = torch.empty( + (32768,), dtype=torch.int32, device=self.device + ) + self._kv_lens_buffer = torch.empty( + (32768,), dtype=torch.int32, device=self.device + ) + else: + self._int_workspace_buffer = torch.empty( + (8 * 1024 * 1024,), dtype=torch.uint8, device=self.device + ) self._pin_memory_int_workspace_buffer = torch.empty( self._int_workspace_buffer.shape, dtype=self._int_workspace_buffer.dtype, @@ -834,6 +1162,7 @@ def __init__( self._custom_mask_buf = custom_mask_buf self._qk_indptr_buf = qk_indptr_buf self._max_total_num_rows = None + self._backend = backend @property def is_cuda_graph_enabled(self) -> bool: @@ -1068,7 +1397,18 @@ def plan( self._cached_q_data_type = q_data_type self._cached_kv_data_type = kv_data_type - self._cached_module = get_batch_prefill_module( + + if self._backend == "auto": + self._backend = determine_attention_backend( + self.device, + PosEncodingMode[pos_encoding_mode].value, + allow_fp16_qk_reduction, + self._custom_mask_buf is not None, # use_custom_mask + q_data_type, + kv_data_type, + ) + + get_module_args = ( q_data_type, kv_data_type, q_data_type, @@ -1079,21 +1419,63 @@ def plan( logits_soft_cap > 0, # use_logits_soft_cap allow_fp16_qk_reduction, ) - with self.device as device: - self._plan_info = self._cached_module.plan( - self._float_workspace_buffer, - self._int_workspace_buffer, - self._pin_memory_int_workspace_buffer, - qo_indptr_host, - paged_kv_indptr_host, - self._max_total_num_rows or total_num_rows, - batch_size, - num_qo_heads, - num_kv_heads, - page_size, - self.is_cuda_graph_enabled, - get_cuda_stream(device), + if self._backend == "fa2": + self._cached_module = get_batch_prefill_module(*get_module_args) + with self.device as device: + self._plan_info = self._cached_module.plan( + self._float_workspace_buffer, + self._int_workspace_buffer, + self._pin_memory_int_workspace_buffer, + qo_indptr_host, + paged_kv_indptr_host, + self._max_total_num_rows or total_num_rows, + batch_size, + num_qo_heads, + num_kv_heads, + page_size, + self.is_cuda_graph_enabled, + get_cuda_stream(device), + ) + else: + self._cached_module = get_batch_prefill_sm90_module(*get_module_args) + paged_kv_last_page_len_host = paged_kv_last_page_len.to("cpu") + kv_lens_arr_host = get_seq_lens( + paged_kv_indptr_host, paged_kv_last_page_len_host, page_size + ) + self._kv_lens_buffer[: len(kv_lens_arr_host)].copy_( + kv_lens_arr_host, non_blocking=non_blocking ) + if page_size != 1: + vector_sparse_indptr_host = torch.cat( + [ + torch.tensor([0], dtype=torch.int32), + torch.cumsum(kv_lens_arr_host, dim=0, dtype=torch.int32), + ], + dim=0, + ) + self._vector_sparse_indptr_buffer[ + : len(vector_sparse_indptr_host) + ].copy_(vector_sparse_indptr_host, non_blocking=non_blocking) + 
else: + vector_sparse_indptr_host = paged_kv_indptr_host + + with self.device as device: + self._plan_info = self._cached_module.plan( + causal, + self._float_workspace_buffer, + self._int_workspace_buffer, + self._pin_memory_int_workspace_buffer, + qo_indptr_host, + vector_sparse_indptr_host, + kv_lens_arr_host, + batch_size, + num_qo_heads, + num_kv_heads, + page_size, + self.is_cuda_graph_enabled, + get_cuda_stream(device), + ) + self._causal = causal self._pos_encoding_mode = pos_encoding_mode self._allow_fp16_qk_reduction = allow_fp16_qk_reduction @@ -1199,6 +1581,13 @@ def run( _check_cached_qkv_data_type( q, k_cache, self._cached_q_data_type, self._cached_kv_data_type ) + stride_block = k_cache.stride(0) + if self._kv_layout == "NHD": + page_size = k_cache.shape[1] + stride_n = k_cache.stride(1) + else: + page_size = k_cache.shape[2] + stride_n = k_cache.stride(2) window_left = self._window_left logits_soft_cap = self._logits_soft_cap sm_scale = self._sm_scale @@ -1228,6 +1617,22 @@ def run( else: mask_mode = MaskMode.NON_CAUSAL.value + if self._backend == "fa3": + # NOTE(Zihao): we divide both stride_block and stride_n by stride_n + # because we will multiply stride_n back in the kernel + sparse_indices = block_sparse_indices_to_vector_sparse_offsets( + self._paged_kv_indices_buf, + self._paged_kv_indptr_buf, + self._vector_sparse_indices_buffer, # output + self._vector_sparse_indptr_buffer, + self._kv_lens_buffer, + stride_block // stride_n, + 1, # stride_n // stride_n + page_size, + ) + else: + sparse_indices = self._paged_kv_indices_buf + out = self._cached_module.paged_run( mask_mode, self._float_workspace_buffer, @@ -1240,7 +1645,7 @@ def run( _get_cache_alibi_slopes_buf(q.shape[1], q.device), self._qo_indptr_buf, self._paged_kv_indptr_buf, - self._paged_kv_indices_buf, + sparse_indices, # self._paged_kv_indices_buf, self._paged_kv_last_page_len_buf, self._qk_indptr_buf, TensorLayout[self._kv_layout].value, @@ -1404,6 +1809,7 @@ def __init__( kv_indptr_buf: Optional[torch.Tensor] = None, custom_mask_buf: Optional[torch.Tensor] = None, qk_indptr_buf: Optional[torch.Tensor] = None, + backend: str = "auto", ) -> None: r"""Constructor of :class:`BatchPrefillWithRaggedKVCacheWrapper`. @@ -1442,16 +1848,26 @@ def __init__( should be ``[batch_size]``. This argument is only effective when ``use_cuda_graph`` is ``True`` and custom mask will be used in attention computation. + + backend : str + The implementation backend, could be ``auto``/``fa2`` or ``fa3``. Defaults to ``auto``. + If set to ``auto``, the function will automatically choose the backend based on the + device architecture and kernel availability. 
""" _check_kv_layout(kv_layout) self._kv_layout = kv_layout self._float_workspace_buffer = float_workspace_buffer self.device = float_workspace_buffer.device - self._int_workspace_buffer = torch.empty( - (8 * 1024 * 1024,), dtype=torch.uint8, device=self.device - ) + if backend in ["fa3", "auto"]: + self._int_workspace_buffer = torch.empty( + (64 * 1024 * 1024,), dtype=torch.uint8, device=self.device + ) + else: + self._int_workspace_buffer = torch.empty( + (8 * 1024 * 1024,), dtype=torch.uint8, device=self.device + ) self._pin_memory_int_workspace_buffer = torch.empty( - (8 * 1024 * 1024,), dtype=torch.uint8, pin_memory=True + self._int_workspace_buffer.shape, dtype=torch.uint8, pin_memory=True ) self._use_cuda_graph = use_cuda_graph if use_cuda_graph: @@ -1478,6 +1894,7 @@ def __init__( self._custom_mask_buf = custom_mask_buf self._qk_indptr_buf = qk_indptr_buf self._max_total_num_rows = None + self._backend = backend @property def is_cuda_graph_enabled(self) -> bool: @@ -1671,7 +2088,18 @@ def plan( self._cached_q_data_type = q_data_type self._cached_kv_data_type = kv_data_type - self._cached_module = get_batch_prefill_module( + + if self._backend == "auto": + self._backend = determine_attention_backend( + self.device, + PosEncodingMode[pos_encoding_mode].value, + allow_fp16_qk_reduction, + self._custom_mask_buf is not None, # use_custom_mask + q_data_type, + kv_data_type, + ) + + get_module_args = ( q_data_type, kv_data_type, q_data_type, @@ -1682,21 +2110,45 @@ def plan( logits_soft_cap > 0, # use_logits_soft_cap allow_fp16_qk_reduction, ) - with self.device as device: - self._plan_info = self._cached_module.plan( - self._float_workspace_buffer, - self._int_workspace_buffer, - self._pin_memory_int_workspace_buffer, - qo_indptr_host, - kv_indptr_host, - self._max_total_num_rows or total_num_rows, - batch_size, - num_qo_heads, - num_kv_heads, - 1, # page_size - self.is_cuda_graph_enabled, - get_cuda_stream(device), - ) + if self._backend == "fa2": + self._cached_module = get_batch_prefill_module(*get_module_args) + with self.device as device: + self._plan_info = self._cached_module.plan( + self._float_workspace_buffer, + self._int_workspace_buffer, + self._pin_memory_int_workspace_buffer, + qo_indptr_host, + kv_indptr_host, + self._max_total_num_rows or total_num_rows, + batch_size, + num_qo_heads, + num_kv_heads, + 1, # page_size + self.is_cuda_graph_enabled, + get_cuda_stream(device), + ) + else: + self._cached_module = get_batch_prefill_sm90_module(*get_module_args) + kv_len_arr = kv_indptr_host[1:] - kv_indptr_host[:-1] + with self.device as device: + # NOTE(Zihao): there are some interface differences between fa2 and fa3 + # we should align the interface in the future + self._plan_info = self._cached_module.plan( + causal, + self._float_workspace_buffer, + self._int_workspace_buffer, + self._pin_memory_int_workspace_buffer, + qo_indptr_host, + kv_indptr_host, + kv_len_arr, + batch_size, + num_qo_heads, + num_kv_heads, + 1, # page_size + self.is_cuda_graph_enabled, + get_cuda_stream(device), + ) + self._causal = causal self._pos_encoding_mode = pos_encoding_mode self._allow_fp16_qk_reduction = allow_fp16_qk_reduction diff --git a/flashinfer/sparse.py b/flashinfer/sparse.py index 4bc92d0d..7732ce05 100644 --- a/flashinfer/sparse.py +++ b/flashinfer/sparse.py @@ -21,7 +21,12 @@ import torch from .decode import get_batch_decode_module -from .prefill import _compute_page_qk_indptr, get_batch_prefill_module +from .page import block_sparse_indices_to_vector_sparse_offsets, get_seq_lens 
+from .prefill import ( + _compute_page_qk_indptr, + get_batch_prefill_module, + get_batch_prefill_sm90_module, +) from .quantization import segment_packbits from .utils import ( MaskMode, @@ -30,6 +35,7 @@ _check_pos_encoding_mode, _get_cache_alibi_slopes_buf, canonicalize_torch_dtype, + determine_attention_backend, get_cuda_stream, ) @@ -107,6 +113,7 @@ class BlockSparseAttentionWrapper: def __init__( self, float_workspace_buffer: torch.Tensor, + backend: str = "auto", ) -> None: r"""Constructs of :class:`BlockSparseAttentionWrapper`. @@ -116,12 +123,34 @@ def __init__( The user reserved float workspace buffer used to store intermediate attention results in the split-k algorithm. The recommended size is 128MB, the device of the workspace buffer should be the same as the device of the input tensors. + backend : str + The implementation backend, could be ``auto``/``fa2`` or ``fa3``. Defaults to ``auto``. + If set to ``auto``, the function will automatically choose the backend based on the + device architecture and kernel availability. """ self._float_workspace_buffer = float_workspace_buffer self.device = float_workspace_buffer.device - self._int_workspace_buffer = torch.empty( - (8 * 1024 * 1024,), dtype=torch.uint8, device=float_workspace_buffer.device - ) + if backend in ["fa3", "auto"]: + self._int_workspace_buffer = torch.empty( + (64 * 1024 * 1024,), dtype=torch.uint8, device=self.device + ) + # NOTE(Zihao): assume maximum accumulate kv length is 16M + self._vector_sparse_indices_buffer = torch.empty( + (16 * 1024 * 1024,), dtype=torch.int32, device=self.device + ) + # NOTE(Zihao): assume maximum batch size is 32768 + self._vector_sparse_indptr_buffer = torch.empty( + (32768,), dtype=torch.int32, device=self.device + ) + self._kv_lens_buffer = torch.empty( + (32768,), dtype=torch.int32, device=self.device + ) + else: + self._int_workspace_buffer = torch.empty( + (8 * 1024 * 1024,), + dtype=torch.uint8, + device=float_workspace_buffer.device, + ) self._pin_memory_int_workspace_buffer = torch.empty( self._int_workspace_buffer.shape, dtype=self._int_workspace_buffer.dtype, @@ -139,6 +168,7 @@ def __init__( self.C: Optional[int] = None self.M: Optional[int] = None self.N: Optional[int] = None + self._backend = backend def reset_workspace_buffer( self, float_workspace_buffer: torch.Tensor, int_workspace_buffer: torch.Tensor @@ -176,6 +206,7 @@ def plan( head_dim: int, mask: Optional[torch.Tensor] = None, packed_mask: Optional[torch.Tensor] = None, + causal: bool = False, pos_encoding_mode: str = "NONE", allow_fp16_qk_reduction: bool = False, logits_soft_cap: Optional[float] = None, @@ -217,6 +248,10 @@ def plan( packed_mask : torch.Tensor, optional The 1D packed mask tensor, if provided, the :attr:`custom_mask` will be ignored. The packed mask tensor is generated by :func:`flashinfer.quantization.packbits`. + causal : bool + Whether to apply causal mask to the attention matrix. + This is only effective when :attr:`custom_mask` is not provided in + :meth:`plan`. pos_encoding_mode : str, optional The position encoding applied inside attention kernels, could be ``NONE``/``ROPE_LLAMA`` (LLAMA style rotary embedding) /``ALIBI``. 
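For reference, a minimal CPU sketch of what the FA3 path in this file expects
block_sparse_indices_to_vector_sparse_offsets to produce, inferred from the
block_size == 1 shortcut in page.py and the normalized strides passed in run(). The
function name and exact output layout here are assumptions for illustration only; the
real conversion is performed by the CUDA kernel in the page module.

import torch

def vector_sparse_offsets_ref(block_indices, block_indptr, kv_lens,
                              stride_block, stride_n, block_size):
    # Expand each per-block index into per-token offsets, request by request
    # (an assumed reference implementation, not the shipped kernel).
    out, out_indptr = [], [0]
    for b in range(block_indptr.numel() - 1):
        blocks = block_indices[block_indptr[b] : block_indptr[b + 1]]
        for j in range(int(kv_lens[b])):
            blk, off = divmod(j, block_size)
            out.append(int(blocks[blk]) * stride_block + off * stride_n)
        out_indptr.append(len(out))
    return (torch.tensor(out, dtype=torch.int32),
            torch.tensor(out_indptr, dtype=torch.int32))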
@@ -303,7 +338,7 @@ def plan( else: self._packed_mask_buf = None self._qk_indptr_buf = None - mask_mode = MaskMode.NON_CAUSAL.value + mask_mode = MaskMode.CAUSAL.value if causal else MaskMode.NON_CAUSAL.value self._mask_mode = mask_mode self.M = M @@ -318,7 +353,7 @@ def plan( # at this moment, when mask is provided, we use the tensor-core implementation if ( R * (num_qo_heads // num_kv_heads) < 4 - and mask_mode == MaskMode.NON_CAUSAL.value + and mask_mode != MaskMode.CUSTOM.value ): # If the operation is not compute-bound, we use the cuda-core implementation self._use_tensor_cores = False @@ -349,7 +384,18 @@ def plan( else: # if the operation is compute-bound, we use the tensor-core implementation self._use_tensor_cores = True - self._cached_module = get_batch_prefill_module( + + if self._backend == "auto": + self._backend = determine_attention_backend( + self.device, + PosEncodingMode[pos_encoding_mode].value, + allow_fp16_qk_reduction, + mask_mode == MaskMode.CUSTOM.value, # use_custom_mask + q_data_type, + kv_data_type, + ) + + get_module_args = ( q_data_type, kv_data_type, q_data_type, @@ -360,21 +406,60 @@ def plan( logits_soft_cap > 0, # use_logits_soft_cap allow_fp16_qk_reduction, ) - - with self.device as device: - self._plan_info = self._cached_module.plan( - self._float_workspace_buffer, - self._int_workspace_buffer, - self._pin_memory_int_workspace_buffer, - qo_indptr_host, - kv_indptr_host, - num_blocks_row, - num_qo_heads, - num_kv_heads, - C, - False, # is_cuda_graph_enabled - get_cuda_stream(device), + if self._backend == "fa2": + self._cached_module = get_batch_prefill_module(*get_module_args) + + with self.device as device: + self._plan_info = self._cached_module.plan( + self._float_workspace_buffer, + self._int_workspace_buffer, + self._pin_memory_int_workspace_buffer, + qo_indptr_host, + kv_indptr_host, + M, # total_num_rows + num_blocks_row, + num_qo_heads, + num_kv_heads, + C, + False, # is_cuda_graph_enabled + get_cuda_stream(device), + ) + else: + self._cached_module = get_batch_prefill_sm90_module(*get_module_args) + kv_lens_arr_host = (kv_indptr_host[1:] - kv_indptr_host[:-1]) * self.C + self._kv_lens_buffer[: len(kv_lens_arr_host)].copy_( + kv_lens_arr_host, non_blocking=non_blocking ) + if self.C != 1: + vector_sparse_indptr_host = torch.cat( + [ + torch.tensor([0], dtype=torch.int32), + torch.cumsum(kv_lens_arr_host, dim=0, dtype=torch.int32), + ], + dim=0, + ) + self._vector_sparse_indptr_buffer[ + : len(vector_sparse_indptr_host) + ].copy_(vector_sparse_indptr_host, non_blocking=non_blocking) + else: + vector_sparse_indptr_host = kv_indptr_host + + with self.device as device: + self._plan_info = self._cached_module.plan( + causal, + self._float_workspace_buffer, + self._int_workspace_buffer, + self._pin_memory_int_workspace_buffer, + qo_indptr_host, + vector_sparse_indptr_host, + kv_lens_arr_host, + num_blocks_row, # batch_size + num_qo_heads, + num_kv_heads, + self.C, # page_size + False, # is_cuda_graph_enabled, + get_cuda_stream(device), + ) self._pos_encoding_mode = pos_encoding_mode self._allow_fp16_qk_reduction = allow_fp16_qk_reduction @@ -426,7 +511,6 @@ def run( return_lse : bool Whether to return the logsumexp of attention output - Returns ------- Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]] @@ -454,6 +538,10 @@ def run( k = k.reshape(-1, self.C, *k.shape[-2:]) v = v.reshape(-1, self.C, *v.shape[-2:]) + stride_block = k.stride(0) + stride_n = k.stride(1) + print(k.shape, stride_block, stride_n) + lse = None if return_lse: lse = 
torch.empty( @@ -461,6 +549,21 @@ def run( ) if self._use_tensor_cores: + if self._backend == "fa3": + sparse_indices = block_sparse_indices_to_vector_sparse_offsets( + self._paged_kv_indices_buf, + self._paged_kv_indptr_buf, + self._vector_sparse_indices_buffer, # output + self._vector_sparse_indptr_buffer, + self._kv_lens_buffer, + stride_block // stride_n, + 1, # stride_n // stride_n + self.C, # block_size + ) + print(self.C, sparse_indices, self._vector_sparse_indices_buffer) + else: + sparse_indices = self._paged_kv_indices_buf + out = self._cached_module.paged_run( self._mask_mode, self._float_workspace_buffer, @@ -473,7 +576,7 @@ def run( _get_cache_alibi_slopes_buf(q.shape[1], self.device), self._qo_indptr, self._paged_kv_indptr_buf, - self._paged_kv_indices_buf, + sparse_indices, # self._paged_kv_indices_buf, self._paged_kv_last_page_len, self._qk_indptr_buf, TensorLayout[self._kv_layout].value, diff --git a/flashinfer/utils.py b/flashinfer/utils.py index 4abce374..d38af827 100644 --- a/flashinfer/utils.py +++ b/flashinfer/utils.py @@ -260,3 +260,95 @@ def determine_gemm_backend(device: torch.device) -> str: return "sm90" else: return "sm80" + + +def is_fa3_backend_supported( + pos_encoding_mode: int, + allow_fp16_qk_reductions: bool, + use_custom_mask: bool, + dtype_q: torch.dtype, + dtype_kv: torch.dtype, +) -> bool: + """ + Check if the FA3 backend is supported based on the given parameters. + NOTE(Zihao): this function is a workaround for the lack of support for certain features in + our FA3 backend, and will be removed once the backend is fully supported. + + Parameters + ---------- + pos_encoding_mode : int + The positional encoding mode. + allow_fp16_qk_reductions : bool + Whether FP16 QK reductions are allowed. + use_custom_mask : bool + Whether a custom mask is used. + dtype_q : torch.dtype + The data type of the query tensor. + dtype_kv : torch.dtype + The data type of the key-value tensor. + + Returns + ------- + bool + True if the FA3 backend is supported, False otherwise. + """ + if use_custom_mask: + return False + if pos_encoding_mode != PosEncodingMode.NONE.value: + return False + if allow_fp16_qk_reductions: + return False + # NOTE: currently fp8 is not supported in our FA3 backend + # will add support soon + if dtype_q in [torch.float8_e4m3fn, torch.float8_e5m2]: + return False + if dtype_kv in [torch.float8_e4m3fn, torch.float8_e5m2]: + return False + return True + + +def determine_attention_backend( + device: torch.device, + pos_encoding_mode: int, + allow_fp16_qk_reductions: bool, + use_custom_mask: bool, + dtype_q: torch.dtype, + dtype_kv: torch.dtype, +) -> str: + """ + Determine the appropriate attention backend based on the device and parameters. + + Parameters + ---------- + device : torch.device + The device to be used. + mask_mode : int + The mask mode. + pos_encoding_mode : int + The positional encoding mode. + allow_fp16_qk_reductions : bool + Whether FP16 QK reductions are allowed. + use_custom_mask : bool + Whether a custom mask is used. + dtype_q : torch.dtype + The data type of the query tensor. + dtype_kv : torch.dtype + The data type of the key-value tensor. + + Returns + ------- + str + The name of the attention backend to be used. 
+ """ + major, _ = get_compute_capability(device) + + if major >= 9 and is_fa3_backend_supported( + pos_encoding_mode, + allow_fp16_qk_reductions, + use_custom_mask, + dtype_q, + dtype_kv, + ): + return "fa3" + else: + return "fa2" diff --git a/include/flashinfer/attention/heap.h b/include/flashinfer/attention/heap.h new file mode 100644 index 00000000..e742c9b6 --- /dev/null +++ b/include/flashinfer/attention/heap.h @@ -0,0 +1,64 @@ +/* + * Copyright (c) 2023 by FlashInfer team. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#ifndef FLASHINFER_ATTENTION_HEAP_H +#define FLASHINFER_ATTENTION_HEAP_H + +#include +#include +#include +#include + +namespace flashinfer { + +/*! + * \brief Heap data structure for (index, value) pairs + * \note minimal element on top + */ +class CTACostHeap { + public: + // first: index, second: cost + using Element = std::pair; + + CTACostHeap(int capacity) : heap_(capacity) { + for (int i = 0; i < capacity; ++i) { + heap_[i] = std::make_pair(i, 0.f); + } + } + + void insert(const Element& element) { + heap_.push_back(element); + std::push_heap(heap_.begin(), heap_.end(), compare); + } + + Element pop() { + std::pop_heap(heap_.begin(), heap_.end(), compare); + Element minElement = heap_.back(); + heap_.pop_back(); + return minElement; + } + + private: + // Custom comparator for the min-heap: compare based on 'val' in the pair + static bool compare(const Element& a, const Element& b) { + return a.second > b.second; // create a min-heap based on val + } + + std::vector heap_; +}; + +} // namespace flashinfer + +#endif // FLASHINFER_ATTENTION_HEAP_H diff --git a/include/flashinfer/attention/hopper/attention_updater.cuh b/include/flashinfer/attention/hopper/attention_updater.cuh new file mode 100644 index 00000000..f9fc1abb --- /dev/null +++ b/include/flashinfer/attention/hopper/attention_updater.cuh @@ -0,0 +1,257 @@ +/* + * Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri + * Dao. Licensed under the BSD 3-Clause. + * + * Modified by the FlashInfer team. + */ +#ifndef FLASHINFER_ATTENTION_HOPPER_ATTENTION_UPDATER_CUH_ +#define FLASHINFER_ATTENTION_HOPPER_ATTENTION_UPDATER_CUH_ + +#include +#include + +namespace flashinfer { + +using namespace cute; + +template +struct MaxOp { + __device__ __forceinline__ T operator()(T const& x, T const& y) { return x > y ? 
x : y; } +}; + +template <> +struct MaxOp { + // This is slightly faster + __device__ __forceinline__ float operator()(float const& x, float const& y) { return max(x, y); } +}; + +template +struct SumOp { + __device__ __forceinline__ T operator()(T const& x, T const& y) { return x + y; } +}; + +template +struct Allreduce { + static_assert(THREADS == 32 || THREADS == 16 || THREADS == 8 || THREADS == 4); + template + static __device__ __forceinline__ T run(T x, Operator& op) { + constexpr int OFFSET = THREADS / 2; + x = op(x, __shfl_xor_sync(uint32_t(-1), x, OFFSET)); + return Allreduce::run(x, op); + } +}; + +template <> +struct Allreduce<2> { + template + static __device__ __forceinline__ T run(T x, Operator& op) { + x = op(x, __shfl_xor_sync(uint32_t(-1), x, 1)); + return x; + } +}; + +template +__device__ __forceinline__ void thread_reduce_(Tensor const& tensor, + Tensor& summary, Operator& op) { + static_assert(Layout0::rank == 2, "Only support 2D Tensor"); + static_assert(Layout1::rank == 1, "Only support 1D Tensor"); + CUTE_STATIC_ASSERT_V(size<0>(summary) == size<0>(tensor)); +#pragma unroll + for (int mi = 0; mi < size<0>(tensor); mi++) { + summary(mi) = init ? tensor(mi, 0) : op(summary(mi), tensor(mi, 0)); +#pragma unroll + for (int ni = 1; ni < size<1>(tensor); ni++) { + summary(mi) = op(summary(mi), tensor(mi, ni)); + } + } +} + +template +__device__ __forceinline__ void quad_allreduce_(Tensor& dst, + Tensor& src, Operator& op) { + CUTE_STATIC_ASSERT_V(size(dst) == size(src)); +#pragma unroll + for (int i = 0; i < size(dst); i++) { + dst(i) = Allreduce<4>::run(src(i), op); + } +} + +template +__device__ __forceinline__ void reduce_(Tensor const& tensor, + Tensor& summary, Operator& op) { + thread_reduce_(tensor, summary, op); + quad_allreduce_(summary, summary, op); +} + +template +__device__ __forceinline__ void reduce_max(Tensor const& tensor, + Tensor& max) { + MaxOp max_op; + reduce_(tensor, max, max_op); +} + +template +__device__ __forceinline__ void reduce_sum(Tensor const& tensor, + Tensor& sum) { + SumOp sum_op; + thread_reduce_(tensor, sum, sum_op); + if constexpr (warp_reduce) { + quad_allreduce_(sum, sum, sum_op); + } +} + +template +__forceinline__ __device__ void apply_exp2(Tensor& tensor, + Tensor const& max) { + static_assert(Layout0::rank == 2, "Only support 2D Tensor"); + static_assert(Layout1::rank == 1, "Only support 1D Tensor"); + CUTE_STATIC_ASSERT_V(size<0>(max) == size<0>(tensor)); +#pragma unroll + for (int mi = 0; mi < size<0>(tensor); ++mi) { + auto row_max = max(mi); +#pragma unroll + for (int ni = 0; ni < size<1>(tensor); ++ni) { + tensor(mi, ni) = exp2f(tensor(mi, ni) - row_max); + } + } +} + +template +__forceinline__ __device__ void scale_apply_exp2(Tensor& tensor, + Tensor const& max, + const float scale) { + static_assert(Layout0::rank == 2, "Only support 2D Tensor"); + static_assert(Layout1::rank == 1, "Only support 1D Tensor"); + CUTE_STATIC_ASSERT_V(size<0>(max) == size<0>(tensor)); +#pragma unroll + for (int mi = 0; mi < size<0>(tensor); ++mi) { + auto row_max = max(mi); +#pragma unroll + for (int ni = 0; ni < size<1>(tensor); ++ni) { + // row_max * scale is a constant for each row, so we can use fma here + tensor(mi, ni) = exp2f(tensor(mi, ni) * scale - row_max * scale); + } + } +} + +template +struct DefaultUpdater { + using TensorT = decltype(make_tensor(Shape>{})); + CUTLASS_DEVICE DefaultUpdater(float scale_ = 1.f) {}; + + __forceinline__ __device__ TensorT get_lse() { return TensorT(); } + + template + __forceinline__ __device__ void 
update(Tensor0& acc_s) { + // NOTE(Zihao): nothing to do here + }; + + template + __forceinline__ __device__ void finalize(Tensor1& acc_s) { + // NOTE(Zihao): nothing to do here + }; + + template + __forceinline__ __device__ void rescale_o(Tensor1& acc_o) { + // NOTE(Zihao): nothing to do here + }; +}; + +template +struct OnlineSoftmax { + using TensorT = decltype(make_tensor(Shape>{})); + TensorT row_max, row_sum, scores_scale; + const float sm_scale_log2; + + CUTLASS_DEVICE OnlineSoftmax(float scale_ = 1.f) : sm_scale_log2(scale_) { clear(scores_scale); }; + + __forceinline__ __device__ TensorT get_lse() const { return row_sum; } + + template + __forceinline__ __device__ void update(Tensor0& acc_s) { + // Reshape acc_s from ((2, 2, V), MMA_M, MMA_N) to (nrow=(2, MMA_M), ncol=(2, V, MMA_N)) + Tensor scores = make_tensor(acc_s.data(), convert_layout_acc_rowcol(acc_s.layout())); + + static_assert(decltype(size<0>(scores))::value == NUM_ROWS_PER_THREAD); + if constexpr (init) { + reduce_max(scores, row_max); + if constexpr (WITH_SCALE) { + scale_apply_exp2(scores, row_max, sm_scale_log2); + } else { + apply_exp2(scores, row_max); + } + reduce_sum(scores, row_sum); + } else { + // update row_max + Tensor scores_max_prev = make_fragment_like(row_max); + cute::copy(row_max, scores_max_prev); + reduce_max(scores, row_max); + // update scores_scale and scale row_sum +#pragma unroll + for (int mi = 0; mi < size(row_max); ++mi) { + float scores_max_cur = row_max(mi); + if constexpr (WITH_SCALE) { + scores_scale(mi) = exp2f((scores_max_prev(mi) - scores_max_cur) * sm_scale_log2); + } else { + scores_scale(mi) = exp2f(scores_max_prev(mi) - scores_max_cur); + } + row_sum(mi) *= scores_scale(mi); + } + // perform exp2 on scores + if constexpr (WITH_SCALE) { + scale_apply_exp2(scores, row_max, sm_scale_log2); + } else { + apply_exp2(scores, row_max); + } + // update row_sum + reduce_sum(scores, row_sum); + } + }; + + template + __forceinline__ __device__ void finalize(Tensor0& acc_s) { + // Reshape acc_s from ((2, 2, V), MMA_M, MMA_N) to (nrow=(2, MMA_M), ncol=(2, V, MMA_N)) + Tensor scores = make_tensor(acc_s.data(), convert_layout_acc_rowcol(acc_s.layout())); + static_assert(decltype(size<0>(scores))::value == NUM_ROWS_PER_THREAD); + SumOp sum_op; + quad_allreduce_(row_sum, row_sum, sum_op); +#pragma unroll + for (int mi = 0; mi < size(row_max); ++mi) { + float sum = row_sum(mi); + float inv_sum = 1.f / sum; + scores_scale(mi) = inv_sum; + if constexpr (WITH_SCALE) { + row_sum(mi) = row_max(mi) * sm_scale_log2 + math::ptx_log2(sum); + } else { + row_sum(mi) = row_max(mi) + math::ptx_log2(sum); + } + } + }; + + template + __forceinline__ __device__ void rescale_o(Tensor1& acc_o) { + // Reshape acc_o from (MMA=4, MMA_M, MMA_K) to (nrow=(2, MMA_M), ncol=(2, MMA_K)) + Tensor acc_o_rowcol = make_tensor(acc_o.data(), convert_layout_acc_rowcol(acc_o.layout())); + static_assert(decltype(size<0>(acc_o_rowcol))::value == NUM_ROWS_PER_THREAD); +#pragma unroll + for (int mi = 0; mi < size(row_max); ++mi) { +#pragma unroll + for (int ni = 0; ni < size<1>(acc_o_rowcol); ++ni) { + acc_o_rowcol(mi, ni) *= scores_scale(mi); + } + } + }; +}; + +template +using OnlineSoftmaxWithScale = OnlineSoftmax; + +template +using OnlineSoftmaxWithoutScale = OnlineSoftmax; + +} // namespace flashinfer + +#endif // FLASHINFER_ATTENTION_HOPPER_ATTENTION_UPDATER_CUH_ diff --git a/include/flashinfer/attention/hopper/block_sparse_gather.cuh b/include/flashinfer/attention/hopper/block_sparse_gather.cuh new file mode 100644 index 
00000000..29988a5b --- /dev/null +++ b/include/flashinfer/attention/hopper/block_sparse_gather.cuh @@ -0,0 +1,196 @@ +/* + * Copyright (c) 2023 - 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + * Modified by the FlashInfer team. + */ +#ifndef FLASHINFER_ATTENTION_HOPPER_BLOCK_SPARSE_GATHER_CUH +#define FLASHINFER_ATTENTION_HOPPER_BLOCK_SPARSE_GATHER_CUH + +#include + +#include "cute/layout.hpp" +#include "cute/tensor.hpp" +#include "cute/util/print.hpp" +#include "cutlass/fast_math.h" + +namespace flashinfer { + +using namespace cute; + +template +struct BlockSparseIndexedGather { + CUTE_HOST_DEVICE constexpr BlockSparseIndexedGather(IdType const* indices) : indices_(indices) {} + + template + CUTE_HOST_DEVICE constexpr IdType operator()(I i) const { + // NOTE(Zihao): there is a risk of out-of-bound access, adding boundary check here + // would degrade performance significantly. It is the user's responsibility to ensure + // that (indptr[-2] + TILE_KV) is less than the size of the indices tensor. + return indices_[i]; + } + + CUTE_HOST_DEVICE friend void print(BlockSparseIndexedGather const& s) { + cute::print("BlockSparseIndexedGather"); + } + + IdType const* indices_; +}; + +/// Custom stride object that applies a function followed by a stride +template +struct CustomStride { + CUTE_HOST_DEVICE constexpr CustomStride(Func const& func, int stride_n) + : func_(func), stride_n_(stride_n) {} + + template + CUTE_HOST_DEVICE friend auto operator*(I i, CustomStride const& s) { + // uint64_t ret; + // #if defined(__CUDA_ARCH__) + // asm("{\n\t" + // "mul.wide.u32 %0, %1, %2;\n\t" + // "}" : "=l"(ret) : "r"(s.func_(i)), "r"(s.stride_n_)); + // #else + // ret = uint64_t(s.func_(i)) * uint64_t(s.stride_n_); + // #endif + // return ret; + + // NOTE(Zihao): if the tensor is larger than 64GB ((2 ** 32) * 16byte), we use + // 64-bit multiplication to avoid overflow. Otherwise, 32-bit multiplication is + // sufficient. 
+ // There is a 20+ TFLOPs/s gap between 32-bit and 64-bit multiplication on H100. + return uint32_t(s.func_(i)) * s.stride_n_; + } + + template + CUTE_HOST_DEVICE friend auto operator*(CustomStride const& s, I i) { + // uint64_t ret; + // #if defined(__CUDA_ARCH__) + // asm("{\n\t" + // "mul.wide.u32 %0, %1, %2;\n\t" + // "}" : "=l"(ret) : "r"(s.func_(i)), "r"(s.stride_n_)); + // #else + // ret = uint64_t(s.func_(i)) * uint64_t(s.stride_n_); + // #endif + // return ret; + + // NOTE(Zihao): if the tensor is larger than 64GB = (2 ** 32) * 16byte (16byte is the + // element size after upcasting), we use 64-bit multiplication to avoid overflow. Otherwise, + // 32-bit multiplication is sufficient. + // There is a 20+ TFLOPs/s gap between 32-bit and 64-bit multiplication on H100. + return uint32_t(s.func_(i)) * s.stride_n_; + } + + CUTE_HOST_DEVICE friend void print(CustomStride const& s) { + cute::print("BlockSparseStride{"); + print(s.func_); + cute::print(","); + print(s.stride_n_); + cute::print("}"); + } + + template + CUTE_HOST_DEVICE constexpr friend auto safe_div(CustomStride const& s, Div const& div) { + return CustomStride(s.func_, safe_div(s.stride_n_, div)); + } + + // Circumvent the requirement on make_layout that shape and stride are integral + template + CUTE_HOST_DEVICE constexpr friend auto make_layout(Shape const& shape, + CustomStride const& stride) { + return Layout(shape, stride); + } + + Func func_; + uint32_t stride_n_; +}; + +template +CUTLASS_HOST_DEVICE auto make_custom_stride_layout(int stride_n, Func&& func) { + return make_layout(make_shape(_1{}, _1{}), + make_stride(CustomStride(static_cast(func), stride_n), _1{})); +} + +/// Helper function to optionally create a block sparse gather tensor +template +CUTLASS_HOST_DEVICE auto make_block_sparse_tensor(Iterator iter, Shape const& shape, int stride_n, + Func&& func) { + Layout matrix_layout = make_identity_layout(shape); + auto offset = as_arithmetic_tuple(repeat_like(shape, _0{})); + Layout gather_layout = make_custom_stride_layout(stride_n, static_cast(func)); + + return make_tensor(iter, ComposedLayout{gather_layout, offset, matrix_layout}); +} + +} // namespace flashinfer + +namespace cute { + +template +CUTE_HOST_DEVICE constexpr auto upcast(Shape const& shape, Stride const& stride) { + if constexpr (is_tuple::value) { + return transform_layout(shape, stride, + [](auto const& s, auto const& d) { return upcast(s, d); }); + } else if constexpr (is_scaled_basis::value) { + if constexpr (Stride::mode() == I) { + return make_layout(shape_div(shape, Int{}), shape_div(stride, Int{})); + } else { + return make_layout(shape, stride); + } + } else { + return upcast(shape, stride); + } + + CUTE_GCC_UNREACHABLE; +} + +template +CUTE_HOST_DEVICE constexpr auto upcast( + ComposedLayout, Offset, Layout> const& layout) { + // Find index of the stride-1 mode - that is the only one that requires updating inner shape and + // offset + auto idx = + find_if(layout.layout_a().stride(), [](auto x) { return is_constant<1, decltype(x)>{}; }); + constexpr int I = decltype(idx)::value; + + // Upcast the outer layout (works as expected) + auto outer = upcast(layout.layout_a()); + + // Upcast the accumulated offset along stride-1 mode + auto offset = + as_arithmetic_tuple(replace(layout.offset(), upcast(get(layout.offset())))); + + // Upcast the inner layout's shape along stride-1 mode + auto inner = upcast(layout.layout_b().shape(), layout.layout_b().stride()); + + return composition(outer, offset, inner); +} + +} // namespace cute + 
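// In short: make_block_sparse_tensor composes an identity layout of (kv_len, head_dim)
// with CustomStride(BlockSparseIndexedGather(indices), stride_n), so row i of the view
// resolves to indices[i] * stride_n in the flat KV cache; the cute::upcast
// specializations above keep that composed layout valid when the tensor is recast to a
// wider element type (e.g. for vectorized copies).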
+#endif // FLASHINFER_ATTENTION_HOPPER_BLOCK_SPARSE_GATHER_CUH diff --git a/include/flashinfer/attention/hopper/epilogue.cuh b/include/flashinfer/attention/hopper/epilogue.cuh new file mode 100644 index 00000000..7f8b5a32 --- /dev/null +++ b/include/flashinfer/attention/hopper/epilogue.cuh @@ -0,0 +1,259 @@ +/* + * Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri + * Dao. Licensed under the BSD 3-Clause. + * + * Modified by the FlashInfer team. + */ +#ifndef FLASHINFER_ATTENTION_HOPPER_EPILOGUE_CUH_ +#define FLASHINFER_ATTENTION_HOPPER_EPILOGUE_CUH_ + +#include + +#include "../../math.cuh" +#include "cute/tensor.hpp" +#include "cutlass/gemm/collective/collective_builder.hpp" +#include "named_barrier.cuh" +#include "utils.cuh" + +namespace flashinfer { + +using namespace cute; + +template +__forceinline__ __device__ void write_tiled(DTypeO* O, const TiledCopyO& tiled_copy_O, + const LayoutO& layout_O, const TileShapeO& tile_shape_O, + const SMemO& sO, int thread_idx, int qo_tile_idx, + int qo_head_idx, int qo_indptr, int64_t qo_len) { + Tensor mO = make_tensor(make_gmem_ptr(O + qo_indptr * stride<0>(layout_O)), layout_O); + Tensor gO = + get_local_tile_tensor(mO, tile_shape_O, qo_head_idx, 0, qo_len)(_, _, qo_tile_idx); // (O, D) + Tensor cO = cute::make_identity_tensor(gO.shape()); // (O, D) -> (o_idx, d_idx) + + ThrCopy thr_copy_O = tiled_copy_O.get_slice(thread_idx); + Tensor tOgO = thr_copy_O.partition_D(gO); // (CPY, CPY_O, CPY_D) + Tensor tOsO = thr_copy_O.partition_S(sO); // (CPY, CPY_O, CPY_D) + Tensor tOcO = thr_copy_O.partition_D(cO); // (CPY, CPY_O, CPY_D) + Tensor tOsOGroup = flatten_1(tOsO); // (CPY, (CPY_O, CPY_D)) + Tensor tOgOGroup = flatten_1(tOgO); // (CPY, (CPY_O, CPY_D)) + Tensor tOcOGroup = flatten_1(tOcO); // (CPY, (CPY_O, CPY_D)) + + const int qo_tile_size = get<0>(tile_shape_O); + int valid_qo_tile_size = std::min(qo_len - qo_tile_idx * qo_tile_size, qo_tile_size); + if (valid_qo_tile_size == qo_tile_size) { + copy(tiled_copy_O, tOsOGroup, tOgOGroup); + } else { + // copy if not out of bound + auto predicate_fn = [&](auto coords) { + auto s_coords = tOcOGroup(_0{}, coords); + return elem_less(get<0>(s_coords), valid_qo_tile_size); + }; + copy_if(tiled_copy_O, predicate_fn, tOsOGroup, tOgOGroup); + } +} + +template +__forceinline__ __device__ void write_O(ElemO* O, const TiledCopyO& tiled_copy_O, + const LayoutO& layout_O, const TileShapeO& tile_shape_O, + const SMemO& sO, int thread_idx, int qo_tile_idx, + int qo_head_idx, int qo_indptr, int qo_len, + int write_warp_idx) { + write_tiled(O, tiled_copy_O, layout_O, tile_shape_O, sO, thread_idx, + qo_tile_idx, qo_head_idx, qo_indptr, qo_len); +} + +template +struct CollectiveEpilogue { + using DTypeO = typename Ktraits::DTypeO; + static constexpr int CTA_Q = Ktraits::CTA_Q; + static constexpr int CTA_KV = Ktraits::CTA_KV; + static constexpr int HEAD_DIM = Ktraits::HEAD_DIM; + using TileShape_QKD = Shape, Int, Int>; + + static constexpr int NUM_WARPS = Ktraits::NUM_WARPS; + static constexpr int NUM_THREADS = NUM_WARPS * cutlass::NumThreadsPerWarp; + + static constexpr int NUM_COPY_THREADS = cutlass::NumThreadsPerWarpGroup; + static constexpr int NUM_MMA_THREADS = NUM_THREADS - NUM_COPY_THREADS; + + using SmemLayoutAtomO = decltype(cutlass::gemm::collective::detail::ss_smem_selector< + GMMA::Major::K, DTypeO, decltype(cute::get<0>(TileShape_QKD{})), + decltype(cute::get<2>(TileShape_QKD{}))>()); + using SmemLayoutO = decltype(tile_to_shape(SmemLayoutAtomO{}, select<0, 
2>(TileShape_QKD{}))); + + using SmemCopyAtomO = Copy_Atom; + using SharedStorage = cute::array_aligned>; + + using ShapeT = cute::Shape; + using StrideT = cute::Shape; + using LayoutT = cute::Layout; + + using ShapeLseT = cute::Shape; + using StrideLseT = cute::Shape<_1, int64_t>; + using LayoutLseT = cute::Layout; + + using GmemTiledCopyOTMA = cute::SM90_TMA_STORE; + using TMA_O = decltype(make_tma_copy( + GmemTiledCopyOTMA{}, + make_tensor(make_gmem_ptr(static_cast(nullptr)), ShapeT{}, StrideT{}), SmemLayoutO{}, + select<0, 2>(TileShape_QKD{}), _1{})); // no mcast for O + + static constexpr int VEC_SIZE = cute::ceil_div(128, sizeof_bits_v); + static_assert(HEAD_DIM % VEC_SIZE == 0); + static constexpr int NUM_THREADS_PER_ROW = HEAD_DIM / VEC_SIZE; + static_assert(NUM_MMA_THREADS % NUM_THREADS_PER_ROW == 0); + static constexpr int NUM_ROWS = NUM_MMA_THREADS / NUM_THREADS_PER_ROW; + using TiledCopyOAtom = cute::Copy_Atom, DTypeO>; + using TiledCopyOThrLayout = decltype(cute::make_layout( + cute::make_shape(Int{}, Int{}), LayoutRight{})); + using TiledCopyOValLayout = + decltype(cute::make_layout(cute::make_shape(_1{}, Int{}), LayoutRight{})); + using TiledCopyO = + decltype(make_tiled_copy(TiledCopyOAtom{}, TiledCopyOThrLayout{}, // Thr layout + TiledCopyOValLayout{} // Val layout + )); + + // used for rmem -> smem O copy in fp8 kernel to undo column permutation + using ThreadLayoutrO = Layout, _4, _1>, Stride<_4, _32, _1, _0>>; + using ValueLayoutrO = + Layout, Int>, Stride<_0, _2, Stride<_4, _1>, _8>>; + using TiledCopyrO = decltype(make_tiled_copy(Copy_Atom, DTypeO>{}, + ThreadLayoutrO{}, ValueLayoutrO{})); + using TiledCopyShaperO = Shape<_8, Int, _16, Int>; + using SmemLayoutrO = decltype(composition(SmemLayoutO{}, Layout{})); + + // Host side kernel arguments + struct Arguments { + DTypeO* O_ptr; + LayoutT const layout_O; + float* lse_ptr; + LayoutLseT const layout_LSE; + }; + + // Device side kernel params + struct Params { + DTypeO* O_ptr; + LayoutT const layout_O; + float* lse_ptr; + LayoutLseT const layout_LSE; + }; + + static Params to_underlying_arguments(Arguments const& args) { + Tensor mO = make_tensor(make_gmem_ptr(args.O_ptr), args.layout_O); + return {args.O_ptr, args.layout_O, args.lse_ptr, args.layout_LSE}; + } + + /// Issue Tma Descriptor Prefetch -- ideally from a single thread for best performance + CUTLASS_DEVICE + static void prefetch_tma_descriptors(Params const& epilogue_params) {} + + template + CUTLASS_DEVICE void store(Params const& epilogue_params, FrgTensorO const& tOrO, + FrgTensorLSE const& lse, SharedStorage& shared_storage, + TiledMma tiled_mma, int thread_idx, BlockCoord const& block_coord) { + auto [qo_tile_idx, qo_head_idx, kv_head_idx, qo_indptr, kv_indptr, qo_len, kv_len] = + block_coord; + Tensor sO = make_tensor(make_smem_ptr(shared_storage.smem_o.data()), SmemLayoutO{}); + auto smem_tiled_copy_O = make_tiled_copy_C(SmemCopyAtomO{}, tiled_mma); + auto smem_thr_copy_O = smem_tiled_copy_O.get_thread_slice(thread_idx); + + Tensor tOrO_out = convert_type(tOrO); + Tensor taccOrO = smem_thr_copy_O.retile_S(tOrO_out); // ((Atom,AtomNum), MMA_M, MMA_N) + Tensor taccOsO = smem_thr_copy_O.partition_D(sO); // ((Atom,AtomNum),PIPE_M,PIPE_N) + + // Make sure all WGs have finished reading V + cutlass::arch::NamedBarrier::sync(NUM_MMA_THREADS, + /*id=*/static_cast(NamedBarriers::kValueEmpty)); + cute::copy(smem_tiled_copy_O, taccOrO, taccOsO); + cutlass::arch::fence_view_async_shared(); // ensure smem writes are visible to TMA + 
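    // All MMA threads arrive on the shared EpilogueBarrier below (the barrier count also
    // includes the producer threads); the designated write warp waits on the same barrier
    // further down, so the O tile in shared memory is complete before it is copied out to
    // global memory.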
cutlass::arch::NamedBarrier::arrive(NUM_MMA_THREADS + Ktraits::NUM_PRODUCER_THREADS, + cutlass::arch::ReservedNamedBarriers::EpilogueBarrier); + + Tensor mLSE = make_tensor(make_gmem_ptr(epilogue_params.lse_ptr), epilogue_params.layout_LSE); + Tensor gLSE = get_lse_local_tile_tensor(mLSE, Shape>{}, qo_head_idx, qo_indptr, + qo_len)(_, qo_tile_idx); + Tensor caccO = cute::make_identity_tensor(select<0, 2>(TileShape_QKD{})); + auto thread_mma = tiled_mma.get_thread_slice(thread_idx); + Tensor taccOcO = thread_mma.partition_C(caccO); // (MMA,MMA_M,MMA_K) + static_assert(decltype(size<0, 0>(taccOcO))::value == 2); + static_assert(decltype(size<0, 1>(taccOcO))::value == 2); + // taccOcO has shape ((2, 2, V), MMA_M, MMA_K), we only take only the row indices. + Tensor taccOcO_row = taccOcO(make_coord(_0{}, _, _0{}), _, _0{}); + CUTE_STATIC_ASSERT_V(size(lse) == size(taccOcO_row)); // MMA_M + if (epilogue_params.lse_ptr) { // don't write to LSE if it's nullptr + if (get<1>(taccOcO_row(_0{})) == 0) { +#pragma unroll + for (int mi = 0; mi < size(lse); ++mi) { + const int row = get<0>(taccOcO_row(mi)); + if (row < qo_len - qo_tile_idx * CTA_Q) { + gLSE(row) = lse(mi); + } + } + } + } + + int write_warp_idx = NUM_WARPS - 1; + if (cutlass::canonical_warp_idx_sync() == write_warp_idx) { + cutlass::arch::NamedBarrier::sync(NUM_MMA_THREADS + Ktraits::NUM_PRODUCER_THREADS, + cutlass::arch::ReservedNamedBarriers::EpilogueBarrier); + } + TiledCopyO gmem_tiled_copy_O; + write_O(epilogue_params.O_ptr, gmem_tiled_copy_O, epilogue_params.layout_O, + select<0, 2>(TileShape_QKD{}), sO, thread_idx, qo_tile_idx, + qo_head_idx, qo_indptr, qo_len, write_warp_idx); + } + + CUTLASS_DEVICE void store_tail() { + // tma_store_wait<0>(); + } + + // Write 0 to output and -inf to LSE + template + CUTLASS_DEVICE void store_zero(Params const& epilogue_params, SharedStorage& shared_storage, + int thread_idx, BlockCoord const& block_coord) { + auto [qo_tile_idx, qo_head_idx, kv_head_idx, qo_indptr, kv_indptr, qo_len, kv_len] = + block_coord; + Tensor mO = make_tensor(make_gmem_ptr(epilogue_params.O_ptr), epilogue_params.layout_O); + Tensor gO = get_local_tile_tensor(mO, select<0, 2>(TileShape_QKD{}), qo_head_idx, qo_indptr, + qo_len)(_, _, qo_tile_idx); // (O, D) + Tensor cO = cute::make_identity_tensor(gO.shape()); // (O, D) -> (o_idx, d_idx) + Tensor mLSE = make_tensor(make_gmem_ptr(epilogue_params.lse_ptr), epilogue_params.layout_LSE); + Tensor gLSE = get_lse_local_tile_tensor(mLSE, Shape>{}, qo_head_idx, qo_indptr, + qo_len)(_, qo_tile_idx); + + TiledCopyO tiled_copy_O; + auto thr_copy_O = tiled_copy_O.get_thread_slice(thread_idx); + Tensor tOgO = thr_copy_O.partition_D(gO); // (CPY, CPY_O, CPY_D) + Tensor tOrO = make_fragment_like(tOgO); // (CPY, CPY_O, CPY_D) + clear(tOrO); + Tensor tOcO = thr_copy_O.partition_D(cO); // (CPY, CPY_O, CPY_D) + Tensor tOgOGroup = flatten_1(tOgO); // (CPY, (CPY_O, CPY_D)) + Tensor tOrOGroup = flatten_1(tOrO); // (CPY, (CPY_O, CPY_D)) + Tensor tOcOGroup = flatten_1(tOcO); // (CPY, (CPY_O, CPY_D)) + + const int qo_tile_size = get<0>(TileShape_QKD{}); + int valid_qo_tile_size = std::min(qo_len - qo_tile_idx * qo_tile_size, qo_tile_size); + if (valid_qo_tile_size == qo_tile_size) { + copy(tiled_copy_O, tOrOGroup, tOgOGroup); + } else { + auto predicate_fn = [&](auto coords) { + auto s_coords = tOcOGroup(_0{}, coords); + return elem_less(get<0>(s_coords), valid_qo_tile_size); + }; + copy_if(tiled_copy_O, predicate_fn, tOrOGroup, tOgOGroup); + } + + static_assert(CTA_Q <= NUM_MMA_THREADS); + if 
(epilogue_params.lse_ptr) { // don't write to LSE if it's nullptr + if (thread_idx < qo_len - qo_tile_idx * CTA_Q) { + gLSE(thread_idx) = -math::inf; + } + } + } +}; + +} // namespace flashinfer + +#endif // FLASHINFER_ATTENTION_HOPPER_EPILOGUE_CUH_ diff --git a/include/flashinfer/attention/hopper/kernel_traits.cuh b/include/flashinfer/attention/hopper/kernel_traits.cuh new file mode 100644 index 00000000..a144b708 --- /dev/null +++ b/include/flashinfer/attention/hopper/kernel_traits.cuh @@ -0,0 +1,120 @@ +/* + * Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri + * Dao. Licensed under the BSD 3-Clause. + * + * Modified by the FlashInfer team. + */ +#ifndef FLASHINFER_ATTENTION_HOPPER_KERNEL_TRAITS_CUH_ +#define FLASHINFER_ATTENTION_HOPPER_KERNEL_TRAITS_CUH_ + +#include + +#include "../../cutlass_utils.cuh" +#include "cute/algorithm/copy.hpp" +#include "cute/atom/mma_atom.hpp" +#include "cutlass/cutlass.h" +#include "cutlass/gemm/collective/collective_builder.hpp" +#include "cutlass/layout/layout.h" +#include "cutlass/numeric_types.h" +#include "cutlass/pipeline/pipeline.hpp" + +namespace flashinfer { + +using namespace cute; + +template +struct SharedStorageQKVO { + cute::array_aligned> smem_q; + cute::array_aligned> smem_k; + union { + cute::array_aligned> smem_v; + cute::array_aligned> smem_o; + }; + struct { + cutlass::arch::ClusterTransactionBarrier barrier_Q; + cutlass::arch::ClusterBarrier barrier_O; + typename MainloopPipeline::SharedStorage pipeline_k; + typename MainloopPipeline::SharedStorage pipeline_v; + }; +}; + +template +struct AttentionKernelTraits { + using DTypeQ = DTypeQ_; + using DTypeKV = DTypeKV_; + using DTypeO = DTypeO_; + using IdType = IdType_; + using DTypeQKAccum = float; + + static constexpr int CTA_Q = CTA_Q_; + static_assert(CTA_Q % 64 == 0); + static constexpr int CTA_KV = CTA_KV_; + static constexpr int HEAD_DIM = HEAD_DIM_; + static_assert(HEAD_DIM % 32 == 0); + + static constexpr int NUM_WARPS = ((CTA_Q / 64) + 1) * 4; + static constexpr int NUM_THREADS = NUM_WARPS * cutlass::NumThreadsPerWarp; + // NOTE(Zihao): the following constant should only be used when TMA is enabled, + // where only one warp inside a warp group is used for TMA. 
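+  // Illustrative sizing (assuming CTA_Q = 128): NUM_WARPS = ((128 / 64) + 1) * 4 = 12
+  // warps, i.e. 384 threads split into one producer warp group and two consumer
+  // (MMA) warp groups; with TMA enabled only a single warp of the producer group issues loads.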
+ static constexpr int NUM_PRODUCER_THREADS = cutlass::NumThreadsPerWarp; + + using AttentionVariant = AttentionVariant_; + using TileShape_QKD = Shape, Int, Int>; + + static constexpr int NUM_STAGES = NUM_STAGES_; + + using AtomLayoutQKD = Layout, _1, _1>>; + using TiledMmaQK = decltype(cute::make_tiled_mma( + cute::GMMA::ss_op_selector(), AtomLayoutQKD{})); + using TiledMmaPV = decltype(cute::make_tiled_mma( + cute::GMMA::rs_op_selector(TileShape_QKD{})), GMMA::Major::K, + GMMA::Major::MN>(), + AtomLayoutQKD{})); + + static constexpr int NUM_MMA_THREADS = size(TiledMmaQK{}); + + using SmemLayoutAtomQ = decltype(cutlass::gemm::collective::detail::ss_smem_selector< + GMMA::Major::K, DTypeQ, decltype(cute::get<0>(TileShape_QKD{})), + decltype(cute::get<2>(TileShape_QKD{}))>()); + using SmemLayoutQ = decltype(tile_to_shape(SmemLayoutAtomQ{}, select<0, 2>(TileShape_QKD{}))); + + using SmemLayoutAtomK = decltype(cutlass::gemm::collective::detail::ss_smem_selector< + GMMA::Major::K, DTypeKV, decltype(cute::get<1>(TileShape_QKD{})), + decltype(cute::get<2>(TileShape_QKD{}))>()); + using SmemLayoutK = decltype(tile_to_shape( + SmemLayoutAtomK{}, + make_shape(shape<1>(TileShape_QKD{}), shape<2>(TileShape_QKD{}), Int{}))); + + using SmemLayoutAtomV = decltype(cutlass::gemm::collective::detail::ss_smem_selector< + GMMA::Major::K, DTypeKV, decltype(cute::get<1>(TileShape_QKD{})), + decltype(cute::get<2>(TileShape_QKD{}))>()); + using SmemLayoutV = decltype(tile_to_shape( + SmemLayoutAtomV{}, + make_shape(get<1>(TileShape_QKD{}), get<2>(TileShape_QKD{}), Int{}))); + + // Note this is the transpose in terms of the view, not in terms of memory. + using SmemLayoutVt = decltype(composition( + SmemLayoutV{}, make_ordered_layout(make_shape(get<2>(TileShape_QKD{}), + get<1>(TileShape_QKD{}), Int{}), + Step<_2, _1, _3>{}))); + + using SmemLayoutAtomO = decltype(cutlass::gemm::collective::detail::ss_smem_selector< + GMMA::Major::K, DTypeO, decltype(cute::get<0>(TileShape_QKD{})), + decltype(cute::get<2>(TileShape_QKD{}))>()); + using SmemLayoutO = decltype(tile_to_shape(SmemLayoutAtomO{}, select<0, 2>(TileShape_QKD{}))); + using MainloopPipeline = + std::conditional_t, + typename cutlass::PipelineAsync>; + using PipelineState = typename cutlass::PipelineState; + + using SharedStorage = SharedStorageQKVO; +}; + +} // namespace flashinfer + +#endif // FLASHINFER_ATTENTION_HOPPER_KERNEL_TRAITS_CUH_ diff --git a/include/flashinfer/attention/hopper/mainloop.cuh b/include/flashinfer/attention/hopper/mainloop.cuh new file mode 100644 index 00000000..a6b561e5 --- /dev/null +++ b/include/flashinfer/attention/hopper/mainloop.cuh @@ -0,0 +1,266 @@ +/* + * Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri + * Dao. Licensed under the BSD 3-Clause. + * + * Modified by the FlashInfer team. 
+ */ +#ifndef FLASHINFER_ATTENTION_HOPPER_MAINLOOP_CUH_ +#define FLASHINFER_ATTENTION_HOPPER_MAINLOOP_CUH_ + +#include +#include +#include +#include + +#include "../../math.cuh" +#include "cute/tensor.hpp" +#include "cutlass/gemm/collective/collective_builder.hpp" +#include "cutlass/pipeline/pipeline.hpp" +#include "mainloop_mma.cuh" +#include "named_barrier.cuh" +#include "utils.cuh" + +namespace flashinfer { + +using namespace cute; + +template +struct CollectiveMainloop { + using DTypeQ = typename Ktraits::DTypeQ; + using DTypeKV = typename Ktraits::DTypeKV; + using TileShape_QKD = typename Ktraits::TileShape_QKD; + static constexpr int CTA_Q = get<0>(TileShape_QKD{}); + static constexpr int CTA_KV = get<1>(TileShape_QKD{}); + + static constexpr int NUM_STAGES = Ktraits::NUM_STAGES; + static constexpr int NUM_MMA_THREADS = Ktraits::NUM_MMA_THREADS; + static constexpr int HEAD_DIM = Ktraits::HEAD_DIM; + + using GmemTiledCopyQ = cute::SM90_TMA_LOAD; + using GmemTiledCopyKV = cute::SM90_TMA_LOAD; + + using SmemLayoutQ = typename Ktraits::SmemLayoutQ; + using SmemLayoutK = typename Ktraits::SmemLayoutK; + using SmemLayoutV = typename Ktraits::SmemLayoutV; + using SmemLayoutVt = typename Ktraits::SmemLayoutVt; + + using ShapeT = cute::Shape; + using StrideT = cute::Shape; // (N, D, H) + using LayoutT = cute::Layout; + + using ShapeLseT = cute::Shape; + using StrideLseT = cute::Shape<_1, int64_t>; + using LayoutLseT = cute::Layout; + + using TMA_Q = decltype(make_tma_copy( + GmemTiledCopyQ{}, + make_tensor(make_gmem_ptr(static_cast(nullptr)), + repeat_like(StrideT{}, int32_t(0)), StrideT{}), + SmemLayoutQ{}, select<0, 2>(TileShape_QKD{}), _1{})); // no mcast for Q + + using TMA_K = decltype(make_tma_copy( + GmemTiledCopyKV{}, + make_tensor(make_gmem_ptr(static_cast(nullptr)), + repeat_like(StrideT{}, int32_t(0)), StrideT{}), + take<0, 2>(SmemLayoutK{}), select<1, 2>(TileShape_QKD{}), _1{})); // no mcast + + using TMA_V = decltype(make_tma_copy( + GmemTiledCopyKV{}, + make_tensor(make_gmem_ptr(static_cast(nullptr)), + repeat_like(StrideT{}, int32_t(0)), StrideT{}), + take<0, 2>(SmemLayoutV{}), select<1, 2>(TileShape_QKD{}), _1{})); // no mcast + + static constexpr bool USE_TMA_LOAD_KV = true; + using MainloopPipeline = typename Ktraits::MainloopPipeline; + using PipelineParams = typename MainloopPipeline::Params; + using PipelineState = typename MainloopPipeline::PipelineState; + + // Set the bytes transferred in this TMA transaction (may involve multiple issues) + static constexpr uint32_t TmaTransactionBytesQ = + static_cast(size(SmemLayoutQ{}) * cutlass::sizeof_bits_v / 8); + static constexpr uint32_t TmaTransactionBytesK = + static_cast(size(take<0, 2>(SmemLayoutK{})) * cutlass::sizeof_bits_v / 8); + + // Whether use scheduler barrier or hardware warp scheduler, using heuristic based on data type + // and head dim + static constexpr bool UseSchedulerBarrier = + cutlass::sizeof_bits_v == 8 ? 
HEAD_DIM >= 128 : HEAD_DIM <= 128; + using WarpScheduler = WarpScheduler; + + // Host side kernel arguments + struct Arguments { + DTypeQ const* Q_ptr; + LayoutT layout_Q; + DTypeKV const* K_ptr; + LayoutT layout_K; + DTypeKV const* V_ptr; + LayoutT layout_V; + int window_left; + float const logits_soft_cap; + float const sm_scale_log2; + }; + + // Device side kernel params + struct Params { + LayoutT layout_Q; + LayoutT layout_K; + LayoutT layout_V; + TMA_Q tma_load_Q; + TMA_K tma_load_K; + TMA_V tma_load_V; + int window_left; + float const logits_soft_cap; + float const sm_scale_log2; + }; + + static Params to_underlying_arguments(Arguments const& args) { + Tensor mQ = make_tensor(make_gmem_ptr(args.Q_ptr), args.layout_Q); + TMA_Q tma_load_Q = make_tma_copy(GmemTiledCopyQ{}, mQ, SmemLayoutQ{}, + select<0, 2>(TileShape_QKD{}), _1{}); // no mcast for Q + Tensor mK = make_tensor(make_gmem_ptr(args.K_ptr), args.layout_K); + TMA_K tma_load_K = make_tma_copy(GmemTiledCopyKV{}, mK, SmemLayoutK{}(_, _, _0{}), + select<1, 2>(TileShape_QKD{}), _1{}); // no mcast + Tensor mV = make_tensor(make_gmem_ptr(args.V_ptr), args.layout_V); + TMA_V tma_load_V = make_tma_copy(GmemTiledCopyKV{}, mV, SmemLayoutV{}(_, _, _0{}), + select<1, 2>(TileShape_QKD{}), _1{}); // no mcast + return {args.layout_Q, args.layout_K, args.layout_V, tma_load_Q, tma_load_K, + tma_load_V, args.window_left, args.logits_soft_cap, args.sm_scale_log2}; + } + + /// Issue Tma Descriptor Prefetch -- ideally from a single thread for best performance + CUTLASS_DEVICE + static void prefetch_tma_descriptors(Params const& mainloop_params) { + cute::prefetch_tma_descriptor(mainloop_params.tma_load_Q.get_tma_descriptor()); + cute::prefetch_tma_descriptor(mainloop_params.tma_load_K.get_tma_descriptor()); + cute::prefetch_tma_descriptor(mainloop_params.tma_load_V.get_tma_descriptor()); + } + + CUTLASS_DEVICE + int get_num_kv_tiles(Params const& mainloop_params, int q_tile_idx, const int qo_len, + const int kv_len) { + static constexpr int CTA_Q = get<0>(TileShape_QKD{}); + static constexpr int CTA_KV = get<1>(TileShape_QKD{}); + int num_kv_tiles = cute::ceil_div(kv_len, CTA_KV); + if constexpr (CAUSAL) { + num_kv_tiles = std::min(num_kv_tiles, + cute::ceil_div((q_tile_idx + 1) * CTA_Q + kv_len - qo_len, CTA_KV)); + } + + return num_kv_tiles; + } + + template + CUTLASS_DEVICE void load(Params const& mainloop_params, MainloopPipeline pipeline_k, + MainloopPipeline pipeline_v, PipelineState& smem_pipe_write_k, + PipelineState& smem_pipe_write_v, SharedStorage& shared_storage, + Scheduler& scheduler, typename Scheduler::Params const& scheduler_params, + typename Scheduler::WorkTileInfo& work_tile_info, + BlockCoord const& block_coord, int work_idx) { + Tensor sQ = make_tensor(make_smem_ptr(shared_storage.smem_q.data()), SmemLayoutQ{}); + Tensor sK = make_tensor(make_smem_ptr(shared_storage.smem_k.data()), SmemLayoutK{}); + Tensor sV = make_tensor(make_smem_ptr(shared_storage.smem_v.data()), SmemLayoutV{}); + + Tensor mQ = mainloop_params.tma_load_Q.get_tma_tensor(mainloop_params.layout_Q.shape()); + Tensor mK = mainloop_params.tma_load_K.get_tma_tensor(mainloop_params.layout_K.shape()); + Tensor mV = mainloop_params.tma_load_V.get_tma_tensor(mainloop_params.layout_V.shape()); + + auto [q_tile_idx, qo_head_idx, kv_head_idx, qo_indptr, kv_indptr, qo_len, kv_len] = block_coord; + + // Prepare the TMA loads + Tensor gQ = get_local_tile_tensor(mQ, select<0, 2>(TileShape_QKD{}), qo_head_idx, qo_indptr, + qo_len)(_, _, q_tile_idx); // (Q, D) + Tensor gK = 
get_local_tile_tensor(mK, select<1, 2>(TileShape_QKD{}), kv_head_idx, kv_indptr, + kv_len); // (K, D, _) + Tensor gV = get_local_tile_tensor(mV, select<1, 2>(TileShape_QKD{}), kv_head_idx, kv_indptr, + kv_len); // (K, D, _) + + Tensor sQ_x = make_tensor(sQ.data(), make_layout(sQ.layout(), Layout<_1>{})); + Tensor gQ_x = make_tensor(gQ.data(), make_layout(gQ.layout(), Layout<_1>{})); + auto [tQgQ, tQsQ] = + tma_partition(mainloop_params.tma_load_Q, _0{}, Layout<_1>{}, group_modes<0, 2>(sQ_x), + group_modes<0, 2>(gQ_x)); // (TMA), (TMA) + auto [tKgK, tKsK] = + tma_partition(mainloop_params.tma_load_K, _0{}, Layout<_1>{}, group_modes<0, 2>(sK), + group_modes<0, 2>(gK)); // (TMA, k), (TMA, PIPE) + auto [tVgV, tVsV] = + tma_partition(mainloop_params.tma_load_V, _0{}, Layout<_1>{}, group_modes<0, 2>(sV), + group_modes<0, 2>(gV)); // (TMA, k), (TMA, PIPE) + + int num_kv_tiles = get_num_kv_tiles(mainloop_params, q_tile_idx, qo_len, kv_len); + int kv_tile_idx = num_kv_tiles - 1; + int swa_begin_kv_tile_idx = 0; + if constexpr (LEFT_SLIDING_WINDOW) { + swa_begin_kv_tile_idx = get_swa_begin_kv_tile_idx(mainloop_params.window_left, + q_tile_idx, qo_len, kv_len); + } + + int lane_predicate = cute::elect_one_sync(); + if (lane_predicate) { + pipeline_k.producer_acquire(smem_pipe_write_k); + copy(mainloop_params.tma_load_K.with(*pipeline_k.producer_get_barrier(smem_pipe_write_k), + /*mcast_mask=*/0), + tKgK(_, kv_tile_idx), tKsK(_, smem_pipe_write_k.index())); + ++smem_pipe_write_k; + } + + // Wait for the MMA warpgroups to say that smem_q is ready + cutlass::arch::NamedBarrier::sync(NUM_MMA_THREADS + Ktraits::NUM_PRODUCER_THREADS, + static_cast(NamedBarriers::kQueryEmpty)); + + if (lane_predicate) { + shared_storage.barrier_Q.arrive_and_expect_tx(TmaTransactionBytesQ); + copy(mainloop_params.tma_load_Q.with( + reinterpret_cast( + shared_storage.barrier_Q), + /*mcast_mask=*/0), + tQgQ, tQsQ); + } + + // Wait for warp 1 to signal that smem_v are ready and V can be copied from gmem + // Need ClusterBarrier, not just NamedBarrier. Otherwise we might have CTA 0 finishing the + // TMA store on O first, call TMA multicast load on V, before CTA 1 can finishing TMA store on + // O. 
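+    // barrier_O is phase-toggled per work tile: waiting on parity (work_idx + 1) % 2
+    // keeps the producer from refilling smem_v while the previous tile's output may
+    // still occupy it, since smem_v and smem_o share a union in SharedStorageQKVO.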
+ shared_storage.barrier_O.wait((work_idx + 1) % 2); + + if (lane_predicate) { +#pragma unroll 2 + for (; kv_tile_idx > swa_begin_kv_tile_idx; --kv_tile_idx) { + pipeline_k.producer_acquire(smem_pipe_write_k); + copy(mainloop_params.tma_load_K.with(*pipeline_k.producer_get_barrier(smem_pipe_write_k), + /*mcast_mask=*/0), + tKgK(_, kv_tile_idx - 1), tKsK(_, smem_pipe_write_k.index())); + ++smem_pipe_write_k; + pipeline_v.producer_acquire(smem_pipe_write_v); + copy(mainloop_params.tma_load_V.with(*pipeline_v.producer_get_barrier(smem_pipe_write_v), + /*mcast_mask=*/0), + tVgV(_, kv_tile_idx), tVsV(_, smem_pipe_write_v.index())); + ++smem_pipe_write_v; + } + } + scheduler.prefetch_next_work(scheduler_params, work_tile_info); + if (lane_predicate) { + pipeline_v.producer_acquire(smem_pipe_write_v); + copy(mainloop_params.tma_load_V.with(*pipeline_v.producer_get_barrier(smem_pipe_write_v), + /*mcast_mask=*/0), + tVgV(_, kv_tile_idx), tVsV(_, smem_pipe_write_v.index())); + ++smem_pipe_write_v; + } + scheduler.broadcast_next_work(work_tile_info); + } + + CUTLASS_DEVICE void load_tail(MainloopPipeline pipeline_k, MainloopPipeline pipeline_v, + PipelineState& smem_pipe_write_k, + PipelineState& smem_pipe_write_v) { + int lane_predicate = cute::elect_one_sync(); + int warp_idx_in_warpgroup = __shfl_sync(0xffffffff, (threadIdx.x / 32) % 4, 0); + if (warp_idx_in_warpgroup == 0 && lane_predicate) { + pipeline_k.producer_tail(smem_pipe_write_k); + pipeline_v.producer_tail(smem_pipe_write_v); + } + } +}; + +} // namespace flashinfer + +#endif // FLASHINFER_ATTENTION_HOPPER_MAINLOOP_CUH_ diff --git a/include/flashinfer/attention/hopper/mainloop_mma.cuh b/include/flashinfer/attention/hopper/mainloop_mma.cuh new file mode 100644 index 00000000..b98df9e0 --- /dev/null +++ b/include/flashinfer/attention/hopper/mainloop_mma.cuh @@ -0,0 +1,265 @@ +/* + * Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri + * Dao. Licensed under the BSD 3-Clause. + * + * Modified by the FlashInfer team. 
+ */ +#ifndef FLASHINFER_ATTENTION_HOPPER_MAINLOOP_MMA_CUH_ +#define FLASHINFER_ATTENTION_HOPPER_MAINLOOP_MMA_CUH_ + +#include +#include +#include +#include + +namespace flashinfer { + +template +CUTLASS_DEVICE void mma_f16(const Params& mainloop_params, AttentionVariant& variant, + MainloopPipeline pipeline_k, MainloopPipeline pipeline_v, + PipelineState& smem_pipe_read_k, PipelineState& smem_pipe_read_v, + FrgTensorO& tOrO, AttentionUpdater& attention_updater, + int kv_tile_idx_count, int swa_begin_kv_tile_idx, + int swa_end_kv_tile_idx, int thread_idx, int work_idx, int q_tile_idx, + SharedStorage& shared_storage, const int32_t qo_len, + const int32_t kv_len, const int32_t qo_head_idx, + const int32_t kv_head_idx) { + using DTypeQ = typename Ktraits::DTypeQ; + using DTypeKV = typename Ktraits::DTypeKV; + using IdType = typename Ktraits::IdType; + using TileShape_QKD = typename Ktraits::TileShape_QKD; + static constexpr int NUM_MMA_THREADS = Ktraits::NUM_MMA_THREADS; + using SmemLayoutQ = typename Ktraits::SmemLayoutQ; + using SmemLayoutK = typename Ktraits::SmemLayoutK; + using SmemLayoutV = typename Ktraits::SmemLayoutV; + using SmemLayoutVt = typename Ktraits::SmemLayoutVt; + static_assert(is_rmem::value, "O tensor must be rmem resident."); + + static constexpr int CTA_Q = get<0>(TileShape_QKD{}); + static constexpr int CTA_KV = get<1>(TileShape_QKD{}); + + Tensor sQ = make_tensor(make_smem_ptr(shared_storage.smem_q.data()), SmemLayoutQ{}); + Tensor sK = make_tensor(make_smem_ptr(shared_storage.smem_k.data()), SmemLayoutK{}); + Tensor sVt = make_tensor(make_smem_ptr(shared_storage.smem_v.data()), SmemLayoutVt{}); + + typename Ktraits::TiledMmaQK tiled_mma_qk; + typename Ktraits::TiledMmaPV tiled_mma_pv; + auto threadMmaQK = tiled_mma_qk.get_thread_slice(thread_idx); + auto threadMmaPV = tiled_mma_pv.get_thread_slice(thread_idx); + + Tensor tSrQ = threadMmaQK.partition_fragment_A(sQ); + Tensor tSrK = threadMmaQK.partition_fragment_B(sK); + Tensor tOrV = threadMmaPV.partition_fragment_B(sVt); + + auto consumer_wait = [](auto& pipeline, auto& smem_pipe_read) { + auto barrier_token = pipeline.consumer_try_wait(smem_pipe_read); + pipeline.consumer_wait(smem_pipe_read, barrier_token); + }; + + tiled_mma_pv.accumulate_ = GMMA::ScaleOut::Zero; + int kv_tile_idx = kv_tile_idx_count - 1; + + cutlass::ConsumerToken barrier_token = + static_cast(shared_storage.barrier_Q.try_wait(work_idx % 2)); + if (barrier_token == cutlass::BarrierStatus::WaitAgain) { + shared_storage.barrier_Q.wait(work_idx % 2); + } + + Tensor tSrS = partition_fragment_C(tiled_mma_qk, select<0, 1>(TileShape_QKD{})); + consumer_wait(pipeline_k, smem_pipe_read_k); + + WarpScheduler::barrier_sync(); + gemm(tiled_mma_qk, tSrQ, tSrK(_, _, _, smem_pipe_read_k.index()), + tSrS); + WarpScheduler::barrier_arrive(); + + if (work_idx != 0) { + int lane_predicate = cute::elect_one_sync(); + if (cutlass::canonical_warp_idx_sync() == Ktraits::NUM_WARPS - 1 && lane_predicate) { +#pragma unroll + for (uint32_t cta_id = 0; cta_id < 1; ++cta_id) { + shared_storage.barrier_O.arrive(cta_id, lane_predicate); + } + } + } + warpgroup_wait<0>(); + pipeline_k.consumer_release(smem_pipe_read_k); + ++smem_pipe_read_k; + + auto col_limit_right = [&](int qo_idx) { return qo_idx + 1 + kv_len - qo_len; }; + auto col_limit_left = [&](int qo_idx) { + return qo_idx + kv_len - qo_len - mainloop_params.window_left; + }; + { + Tensor cS = cute::make_identity_tensor(select<0, 1>(TileShape_QKD{})); + Tensor tScS = threadMmaQK.partition_C(cS); +#pragma unroll + for 
(int i = 0; i < size(tSrS); ++i) { + int qo_idx = get<0>(tScS(i)) + q_tile_idx * CTA_Q; + int kv_idx = get<1>(tScS(i)) + kv_tile_idx * CTA_KV; + tSrS(i) = variant.LogitsTransform(mainloop_params, tSrS(i), /*batch_idx=*/0, qo_idx, kv_idx, + qo_head_idx, kv_head_idx); + if constexpr (!CAUSAL) { // Just masking based on col + if (kv_idx >= kv_len) { + tSrS(i) = -math::inf; + } + } else { + if (kv_idx >= std::min(kv_len, col_limit_right(qo_idx))) { + tSrS(i) = -math::inf; + } + } + if constexpr (LEFT_SLIDING_WINDOW) { + if (kv_idx < col_limit_left(qo_idx)) { + tSrS(i) = -math::inf; + } + } + } + } + + attention_updater.update(tSrS); + Tensor tOrP = make_tensor(convert_type(tSrS).data(), + convert_layout_acc_Aregs(tSrS.layout())); + + constexpr int n_masking_steps = CAUSAL ? cute::ceil_div(CTA_Q, CTA_KV) : 0; + // masking loops +#pragma unroll + for (int masking_step = 0; masking_step < n_masking_steps && kv_tile_idx > swa_begin_kv_tile_idx; + ++masking_step, --kv_tile_idx) { + Tensor tSrS = partition_fragment_C(tiled_mma_qk, select<0, 1>(TileShape_QKD{})); + consumer_wait(pipeline_k, smem_pipe_read_k); + WarpScheduler::barrier_sync(); + gemm(tiled_mma_qk, tSrQ, tSrK(_, _, _, smem_pipe_read_k.index()), + tSrS); + if (masking_step > 0) { + attention_updater.rescale_o(tOrO); + } + consumer_wait(pipeline_v, smem_pipe_read_v); + gemm(tiled_mma_pv, tOrP, + tOrV(_, _, _, smem_pipe_read_v.index()), tOrO); + WarpScheduler::barrier_arrive(); + warpgroup_wait<1>(); + pipeline_k.consumer_release(smem_pipe_read_k); // release K + Tensor cS = cute::make_identity_tensor(select<0, 1>(TileShape_QKD{})); + Tensor tScS = threadMmaQK.partition_C(cS); +#pragma unroll + for (int i = 0; i < size(tSrS); ++i) { + int qo_idx = get<0>(tScS(i)) + q_tile_idx * CTA_Q; + int kv_idx = get<1>(tScS(i)) + (kv_tile_idx - 1) * CTA_KV; + tSrS(i) = variant.LogitsTransform(mainloop_params, tSrS(i), /*batch_idx=*/0, qo_idx, kv_idx, + qo_head_idx, kv_head_idx); + if (kv_idx >= col_limit_right(qo_idx)) { + tSrS(i) = -math::inf; + } + if constexpr (LEFT_SLIDING_WINDOW) { + if (kv_idx < col_limit_left(qo_idx)) { + tSrS(i) = -math::inf; + } + } + } + attention_updater.update(tSrS); + warpgroup_wait<0>(); + pipeline_v.consumer_release(smem_pipe_read_v); // release V + ++smem_pipe_read_k; + ++smem_pipe_read_v; + cute::copy(make_tensor(convert_type(tSrS).data(), + convert_layout_acc_Aregs(tSrS.layout())), + tOrP); + } + +#pragma unroll 1 + for (; kv_tile_idx > swa_end_kv_tile_idx + 1; --kv_tile_idx) { + Tensor tSrS = partition_fragment_C(tiled_mma_qk, select<0, 1>(TileShape_QKD{})); + consumer_wait(pipeline_k, smem_pipe_read_k); + WarpScheduler::barrier_sync(); + gemm(tiled_mma_qk, tSrQ, tSrK(_, _, _, smem_pipe_read_k.index()), + tSrS); + attention_updater.rescale_o(tOrO); + consumer_wait(pipeline_v, smem_pipe_read_v); + gemm(tiled_mma_pv, tOrP, + tOrV(_, _, _, smem_pipe_read_v.index()), tOrO); + WarpScheduler::barrier_arrive(); + warpgroup_wait<1>(); + pipeline_k.consumer_release(smem_pipe_read_k); // release K + // #pragma unroll + Tensor cS = cute::make_identity_tensor(select<0, 1>(TileShape_QKD{})); + Tensor tScS = threadMmaQK.partition_C(cS); +#pragma unroll + for (int i = 0; i < size(tSrS); ++i) { + int qo_idx = get<0>(tScS(i)) + q_tile_idx * CTA_Q; + int kv_idx = get<1>(tScS(i)) + (kv_tile_idx - 1) * CTA_KV; + tSrS(i) = variant.LogitsTransform(mainloop_params, tSrS(i), /*batch_idx=*/0, qo_idx, kv_idx, + qo_head_idx, kv_head_idx); + } + attention_updater.update(tSrS); + warpgroup_wait<0>(); + 
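+      // warpgroup_wait<1> above left the PV GEMM in flight so the logits transform and
+      // online-softmax update overlap with it; warpgroup_wait<0> has now drained that
+      // GEMM, so this stage's V can be released back to the producer.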
pipeline_v.consumer_release(smem_pipe_read_v); // release V + ++smem_pipe_read_k; + ++smem_pipe_read_v; + cute::copy(make_tensor(convert_type(tSrS).data(), + convert_layout_acc_Aregs(tSrS.layout())), + tOrP); + } + + if constexpr (LEFT_SLIDING_WINDOW) { + constexpr int n_swa_masking_steps = cute::ceil_div(CTA_Q, CTA_KV) + 1; +#pragma unroll + for (int masking_step = 0; + masking_step < n_swa_masking_steps && kv_tile_idx > swa_begin_kv_tile_idx; + ++masking_step, --kv_tile_idx) { + Tensor tSrS = partition_fragment_C(tiled_mma_qk, select<0, 1>(TileShape_QKD{})); + consumer_wait(pipeline_k, smem_pipe_read_k); + WarpScheduler::barrier_sync(); + gemm(tiled_mma_qk, tSrQ, + tSrK(_, _, _, smem_pipe_read_k.index()), tSrS); + attention_updater.rescale_o(tOrO); + consumer_wait(pipeline_v, smem_pipe_read_v); + gemm(tiled_mma_pv, tOrP, + tOrV(_, _, _, smem_pipe_read_v.index()), tOrO); + WarpScheduler::barrier_arrive(); + warpgroup_wait<1>(); + pipeline_k.consumer_release(smem_pipe_read_k); // release K + Tensor cS = cute::make_identity_tensor(select<0, 1>(TileShape_QKD{})); + Tensor tScS = threadMmaQK.partition_C(cS); +#pragma unroll + for (int i = 0; i < size(tSrS); ++i) { + int qo_idx = get<0>(tScS(i)) + q_tile_idx * CTA_Q; + int kv_idx = get<1>(tScS(i)) + (kv_tile_idx - 1) * CTA_KV; + tSrS(i) = variant.LogitsTransform(mainloop_params, tSrS(i), /*batch_idx=*/0, qo_idx, kv_idx, + qo_head_idx, kv_head_idx); + if (kv_idx < col_limit_left(qo_idx)) { + tSrS(i) = -math::inf; + } + } + attention_updater.update(tSrS); + warpgroup_wait<0>(); + pipeline_v.consumer_release(smem_pipe_read_v); // release V + ++smem_pipe_read_k; + ++smem_pipe_read_v; + cute::copy(make_tensor(convert_type(tSrS).data(), + convert_layout_acc_Aregs(tSrS.layout())), + tOrP); + } + } + + // Tell warp 0 that smem_q is ready + cutlass::arch::NamedBarrier::arrive(NUM_MMA_THREADS + Ktraits::NUM_PRODUCER_THREADS, + /*id=*/static_cast(NamedBarriers::kQueryEmpty)); + attention_updater.rescale_o(tOrO); + consumer_wait(pipeline_v, smem_pipe_read_v); + gemm(tiled_mma_pv, tOrP, tOrV(_, _, _, smem_pipe_read_v.index()), + tOrO); + attention_updater.finalize(tSrS); + warpgroup_wait<0>(); + pipeline_v.consumer_release(smem_pipe_read_v); // release V, otherwise producers will hang + ++smem_pipe_read_v; + + attention_updater.rescale_o(tOrO); + return; +} + +} // namespace flashinfer + +#endif // FLASHINFER_ATTENTION_HOPPER_MAINLOOP_MMA_CUH_ diff --git a/include/flashinfer/attention/hopper/named_barrier.cuh b/include/flashinfer/attention/hopper/named_barrier.cuh new file mode 100644 index 00000000..8ba3b3a0 --- /dev/null +++ b/include/flashinfer/attention/hopper/named_barrier.cuh @@ -0,0 +1,112 @@ +/* + * Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri + * Dao. Licensed under the BSD 3-Clause. + * + * Modified by the FlashInfer team. 
+ */ +#ifndef FLASHINFER_ATTENTION_HOPPER_NAMED_BARRIERS_CUH_ +#define FLASHINFER_ATTENTION_HOPPER_NAMED_BARRIERS_CUH_ + +#include + +#include "cutlass/arch/barrier.h" +#include "cutlass/cutlass.h" + +namespace flashinfer { + +// Enumerates the reserved named barriers to avoid potential conflicts + +enum class NamedBarriers { + kQueryEmpty = 0, + kValueEmpty = 1, + kWarpSchedulerWG1 = 2, + kWarpSchedulerWG2 = 3, + kWarpSchedulerWG3 = 4, + kPrefetchIndices = 5, +}; + +__device__ __forceinline__ int get_warp_group_barrier_idx(int warp_group_idx) { + return static_cast(NamedBarriers::kWarpSchedulerWG1) + warp_group_idx - 1; +} + +template +__device__ __forceinline__ int get_next_consumer_warp_group_idx() { + static_assert(num_consumer_warp_groups == 2 || num_consumer_warp_groups == 3); + int warp_group_idx = cutlass::canonical_warp_group_idx(); + if constexpr (num_consumer_warp_groups == 2) { + // 1 -> 2, 2 -> 1 + return 3 - warp_group_idx; + } else { + // num_consumer_warp_groups == 3 + // 1 -> 2, 2 -> 3, 3 -> 1 + return (warp_group_idx % 3) + 1; + } +} + +template +__device__ __forceinline__ int get_prev_consumer_warp_group_idx() { + static_assert(num_consumer_warp_groups == 2 || num_consumer_warp_groups == 3); + int warp_group_idx = cutlass::canonical_warp_group_idx(); + if constexpr (num_consumer_warp_groups == 2) { + // 1 -> 2, 2 -> 1 + return 3 - warp_group_idx; + } else { + // num_consumer_warp_groups == 3 + // 1 -> 3, 2 -> 1, 3 -> 2 + return ((warp_group_idx + 1) % 3) + 1; + } +} + +template +struct WarpScheduler { + constexpr static int NUM_MMA_THREADS = Ktraits::NUM_MMA_THREADS; + static CUTLASS_DEVICE void barrier_sync() { + if constexpr (UseSchedulerBarrier) { + cutlass::arch::NamedBarrier::sync( + NUM_MMA_THREADS, get_warp_group_barrier_idx(cutlass::canonical_warp_group_idx())); + } + } + + static CUTLASS_DEVICE void barrier_arrive() { + if constexpr (!UseSchedulerBarrier) { + return; + } + static_assert(NUM_MMA_THREADS == 2 * cutlass::NumThreadsPerWarpGroup || + NUM_MMA_THREADS == 3 * cutlass::NumThreadsPerWarpGroup); + if constexpr (NUM_MMA_THREADS == 2 * cutlass::NumThreadsPerWarpGroup) { + cutlass::arch::NamedBarrier::arrive( + NUM_MMA_THREADS, get_warp_group_barrier_idx(get_next_consumer_warp_group_idx<2>())); + } else { + cutlass::arch::NamedBarrier::arrive( + NUM_MMA_THREADS, get_warp_group_barrier_idx(get_next_consumer_warp_group_idx<3>())); + cutlass::arch::NamedBarrier::arrive( + NUM_MMA_THREADS, get_warp_group_barrier_idx(get_prev_consumer_warp_group_idx<3>())); + } + } + + static CUTLASS_DEVICE void mma_init() { + // Tell producer (warp 0) that smem_q is ready + cutlass::arch::NamedBarrier::arrive(NUM_MMA_THREADS + Ktraits::NUM_PRODUCER_THREADS, + /*id=*/static_cast(NamedBarriers::kQueryEmpty)); + if constexpr (!UseSchedulerBarrier) { + return; + } + static_assert(NUM_MMA_THREADS == 2 * cutlass::NumThreadsPerWarpGroup || + NUM_MMA_THREADS == 3 * cutlass::NumThreadsPerWarpGroup); + if (cutlass::canonical_warp_group_idx() > 1) { + cutlass::arch::NamedBarrier::arrive( + NUM_MMA_THREADS, /*id=*/static_cast(NamedBarriers::kWarpSchedulerWG1)); + } + if constexpr (NUM_MMA_THREADS == 3 * cutlass::NumThreadsPerWarpGroup) { + if (cutlass::canonical_warp_group_idx() > 2) { + cutlass::arch::NamedBarrier::arrive( + NUM_MMA_THREADS, /*id=*/static_cast(NamedBarriers::kWarpSchedulerWG2)); + } + } + } + +}; // struct WarpScheduler + +} // namespace flashinfer + +#endif // FLASHINFER_ATTENTION_HOPPER_NAMED_BARRIERS_CUH_ diff --git a/include/flashinfer/attention/hopper/params.cuh 
b/include/flashinfer/attention/hopper/params.cuh new file mode 100644 index 00000000..fcd80a95 --- /dev/null +++ b/include/flashinfer/attention/hopper/params.cuh @@ -0,0 +1,154 @@ +/* + * Copyright (c) 2024 by FlashInfer team. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#ifndef FLASHINFER_ATTENTION_HOPPER_PARAMS_CUH +#define FLASHINFER_ATTENTION_HOPPER_PARAMS_CUH + +#include + +#include + +namespace flashinfer { + +template +struct SinglePrefillParams { + using DTypeQ = DTypeQ_; + using DTypeKV = DTypeKV_; + using DTypeO = DTypeO_; + // The QKV matrices. + DTypeQ* q_ptr; + DTypeKV* k_ptr; + DTypeKV* v_ptr; + DTypeO* o_ptr; + float* lse_ptr; + + int64_t q_stride_n; + int64_t k_stride_n; + int64_t v_stride_n; + int64_t o_stride_n; + int64_t q_stride_h; + int64_t k_stride_h; + int64_t v_stride_h; + int64_t o_stride_h; + + int qo_len; + int kv_len; + int head_dim; + int num_qo_heads; + int num_kv_heads; + int group_size; + int window_left; + + float logits_soft_cap; + float sm_scale_log2; + bool causal; + + struct AdditionalParams {}; +}; + +template +struct BatchPrefillRaggedParams { + using DTypeQ = DTypeQ_; + using DTypeKV = DTypeKV_; + using DTypeO = DTypeO_; + using IdType = IdType_; + // The QKV matrices. + DTypeQ* q_ptr; + DTypeKV* k_ptr; + DTypeKV* v_ptr; + DTypeO* o_ptr; + float* lse_ptr; + + IdType* qo_tile_indices; + IdType* qo_indptr; + IdType* kv_indptr; + IdType* qo_lens; + IdType* kv_lens; + IdType* head_indices; + IdType* work_indptr; + + int64_t q_stride_n; + int64_t k_stride_n; + int64_t v_stride_n; + int64_t o_stride_n; + int64_t q_stride_h; + int64_t k_stride_h; + int64_t v_stride_h; + int64_t o_stride_h; + int64_t nnz_qo; + int64_t nnz_kv; + + int head_dim; + int num_qo_heads; + int num_kv_heads; + int group_size; + int window_left; + + float logits_soft_cap; + float sm_scale_log2; + bool causal; + + struct AdditionalParams {}; +}; + +template +struct BatchPrefillPagedParams { + using DTypeQ = DTypeQ_; + using DTypeKV = DTypeKV_; + using DTypeO = DTypeO_; + using IdType = IdType_; + // The QKV matrices. 
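+  // Same ragged indexing fields as BatchPrefillRaggedParams, plus kv_indices
+  // (the gather indices consumed by the sparse mainloop) and page_size.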
+ DTypeQ* q_ptr; + DTypeKV* k_ptr; + DTypeKV* v_ptr; + DTypeO* o_ptr; + float* lse_ptr; + + IdType* qo_tile_indices; + IdType* qo_indptr; + IdType* kv_indptr; + IdType* kv_indices; + IdType* qo_lens; + IdType* kv_lens; + IdType* head_indices; + IdType* work_indptr; + + int64_t q_stride_n; + int64_t k_stride_n; + int64_t v_stride_n; + int64_t o_stride_n; + int64_t q_stride_h; + int64_t k_stride_h; + int64_t v_stride_h; + int64_t o_stride_h; + int64_t nnz_qo; + + int head_dim; + int num_qo_heads; + int num_kv_heads; + int group_size; + int page_size; + int window_left; + + float logits_soft_cap; + float sm_scale_log2; + bool causal; + + struct AdditionalParams {}; +}; + +} // namespace flashinfer + +#endif // FLASHINFER_ATTENTION_HOPPER_PARAMS_CUH diff --git a/include/flashinfer/attention/hopper/prefill_sm90.cuh b/include/flashinfer/attention/hopper/prefill_sm90.cuh new file mode 100644 index 00000000..708f80f3 --- /dev/null +++ b/include/flashinfer/attention/hopper/prefill_sm90.cuh @@ -0,0 +1,524 @@ +/* + * Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri + * Dao. Licensed under the BSD 3-Clause. + * + * Modified by the FlashInfer team. + */ +#ifndef FLASHINFER_ATTENTION_HOPPER_PREFILL_SM90_CUH_ +#define FLASHINFER_ATTENTION_HOPPER_PREFILL_SM90_CUH_ + +#include +#include +#include +#include +#include +#include +#include + +#include + +#include "../../cutlass_utils.cuh" +#include "../../exception.h" +#include "../mask.cuh" +#include "cute/tensor.hpp" +#include "cutlass/pipeline/pipeline.hpp" +#include "epilogue.cuh" +#include "kernel_traits.cuh" +#include "mainloop.cuh" +#include "mainloop_mma.cuh" +#include "params.cuh" +#include "sparse_mainloop.cuh" +#include "tile_scheduler.cuh" +#include "utils.cuh" + +namespace flashinfer { + +using namespace cute; + +template +__global__ void __launch_bounds__(Ktraits::NUM_WARPS* cutlass::NumThreadsPerWarp, 1) + PrefillWithKVCacheKernel(CUTE_GRID_CONSTANT + typename CollectiveMainloop::Params const mainloop_params, + CUTE_GRID_CONSTANT + typename CollectiveEpilogue::Params const epilogue_params, + CUTE_GRID_CONSTANT + typename TileScheduler::Params const scheduler_params) { + using DTypeQ = typename Ktraits::DTypeQ; + using DTypeKV = typename Ktraits::DTypeKV; + using DTypeO = typename Ktraits::DTypeO; + using DTypeQKAccum = typename Ktraits::DTypeQKAccum; + using TileShape_QKD = typename Ktraits::TileShape_QKD; + using AttentionVariant = typename Ktraits::AttentionVariant; + AttentionVariant variant(mainloop_params); + + static constexpr int NUM_MMA_THREADS = Ktraits::NUM_MMA_THREADS; + static constexpr int NUM_COPY_THREADS = cutlass::NumThreadsPerWarpGroup; + static constexpr int CTA_Q = Ktraits::CTA_Q; + static constexpr int CTA_KV = Ktraits::CTA_KV; + + static constexpr bool use_tma_load_kv = CollectiveMainloop::USE_TMA_LOAD_KV; + + using AttentionUpdater = + typename AttentionVariant::template Updater<2 * (2 * CTA_Q / NUM_MMA_THREADS)>; + using MainloopPipeline = typename CollectiveMainloop::MainloopPipeline; + using PipelineParams = typename MainloopPipeline::Params; + using PipelineState = typename MainloopPipeline::PipelineState; + + extern __shared__ char shared_memory[]; + auto& shared_storage = *reinterpret_cast(shared_memory); + + int const lane_predicate = cute::elect_one_sync(); + int const warp_idx = cutlass::canonical_warp_idx_sync(); + + // Issue Tma Descriptor Prefetch from a single thread + if (warp_idx == 0 && lane_predicate) { + 
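+    // The dense mainloop prefetches the Q/K/V TMA descriptors; the sparse (paged)
+    // mainloop only has a Q descriptor to prefetch, and the epilogue's prefetch
+    // hook is currently empty.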
CollectiveMainloop::prefetch_tma_descriptors(mainloop_params); + CollectiveEpilogue::prefetch_tma_descriptors(epilogue_params); + } + + // Obtain warp index + int const warp_group_thread_idx = threadIdx.x % cutlass::NumThreadsPerWarpGroup; + + PipelineParams pipeline_params; + int warp_group_idx = cutlass::canonical_warp_group_idx(); + pipeline_params.role = warp_group_idx == 0 ? MainloopPipeline::ThreadCategory::Producer + : MainloopPipeline::ThreadCategory::Consumer; + if constexpr (use_tma_load_kv) { + pipeline_params.transaction_bytes = CollectiveMainloop::TmaTransactionBytesK; + pipeline_params.is_leader = warp_group_thread_idx == 0; + pipeline_params.num_consumers = NUM_MMA_THREADS; + } else { + pipeline_params.producer_arv_count = NUM_COPY_THREADS; + pipeline_params.consumer_arv_count = NUM_MMA_THREADS; + } + + if (warp_idx == 0 && lane_predicate) { + shared_storage.barrier_Q.init(/*num_threads=*/1); + shared_storage.barrier_O.init(/*num_threads=*/1); + } + // We're counting on pipeline_k to call cutlass::arch::fence_barrier_init(); + MainloopPipeline pipeline_k = [&] { + if constexpr (use_tma_load_kv) { + return MainloopPipeline(shared_storage.pipeline_k, pipeline_params, + /*cluster_shape=*/Shape<_1, _1, _1>{}); + } else { + return MainloopPipeline(shared_storage.pipeline_k, pipeline_params); + } + }(); + + MainloopPipeline pipeline_v = [&] { + if constexpr (use_tma_load_kv) { + return MainloopPipeline(shared_storage.pipeline_v, pipeline_params, + /*cluster_shape=*/Shape<_1, _1, _1>{}); + } else { + return MainloopPipeline(shared_storage.pipeline_v, pipeline_params); + } + }(); + + CollectiveMainloop collective_mainloop; + CollectiveEpilogue collective_epilogue; + + // We need this to guarantee that the Pipeline init is visible to all producers and consumer + // blocks in the Cluster + __syncthreads(); + + if (warp_group_idx == 0) { // Producer + if constexpr (use_tma_load_kv) { + cutlass::arch::warpgroup_reg_dealloc(); + } else { + cutlass::arch::warpgroup_reg_dealloc<72>(); + } + + int warp_idx_in_warpgroup = __shfl_sync(0xffffffff, (threadIdx.x / 32) % 4, 0); + if (!use_tma_load_kv || warp_idx_in_warpgroup == 0) { // Load Q, K, V + PipelineState smem_pipe_write_k = cutlass::make_producer_start_state(); + PipelineState smem_pipe_write_v = cutlass::make_producer_start_state(); + + int work_idx = 0; + + TileScheduler scheduler; + for (auto work_tile_info = scheduler.get_initial_work(scheduler_params); + work_tile_info.is_valid(scheduler_params); + work_tile_info = scheduler.template get_next_work( + scheduler_params, work_tile_info)) { + auto block_coord = work_tile_info.get_block_coord(scheduler_params); + auto [q_tile_idx, qo_head_idx, kv_head_idx, qo_indptr, kv_indptr, qo_len, kv_len] = + block_coord; + + if (q_tile_idx * CTA_Q >= qo_len) { + continue; + } + int num_kv_tiles = + collective_mainloop.get_num_kv_tiles(mainloop_params, q_tile_idx, qo_len, kv_len); + if (num_kv_tiles <= 0) { + scheduler.prefetch_next_work(scheduler_params, work_tile_info); + scheduler.broadcast_next_work(work_tile_info); + continue; + } + collective_mainloop.load( + mainloop_params, pipeline_k, pipeline_v, smem_pipe_write_k, smem_pipe_write_v, + shared_storage, scheduler, scheduler_params, work_tile_info, block_coord, work_idx); + ++work_idx; + } + collective_mainloop.load_tail(pipeline_k, pipeline_v, smem_pipe_write_k, smem_pipe_write_v); + } + } else { // Consumer + if constexpr (use_tma_load_kv) { + cutlass::arch::warpgroup_reg_alloc(); + } else { + cutlass::arch::warpgroup_reg_alloc(); + } + + 
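+      // Consumer warp groups claim the registers that the producer group released
+      // (warpgroup_reg_dealloc / warpgroup_reg_alloc) so the MMA warps can keep
+      // larger accumulator fragments resident.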
TileScheduler scheduler; + // Initialize matmul objects. + typename Ktraits::TiledMmaPV tiled_mma_pv; + + PipelineState smem_pipe_read_k, smem_pipe_read_v; + // We don't need separate variables smem_pipe_release_k and smem_pipe_release_v + // (like in Cutlass's gemm) because the read and release pipeline states are always the same. + + CollectiveMainloop::WarpScheduler::mma_init(); + scheduler.init_consumer(); + + int work_idx = 0; + CUTLASS_PRAGMA_NO_UNROLL + for (auto work_tile_info = scheduler.get_initial_work(scheduler_params); + work_tile_info.is_valid(scheduler_params); + work_tile_info = scheduler.template get_next_work(scheduler_params, + work_tile_info)) { + // Attention output (GEMM-II) accumulator. + Tensor tOrO = partition_fragment_C(tiled_mma_pv, select<0, 2>(TileShape_QKD{})); + AttentionUpdater attention_updater(mainloop_params.sm_scale_log2); + + auto block_coord = work_tile_info.get_block_coord(scheduler_params); + auto [q_tile_idx, qo_head_idx, kv_head_idx, qo_indptr, kv_indptr, qo_len, kv_len] = + block_coord; + + if (q_tile_idx * CTA_Q >= qo_len) { + continue; + } + int num_kv_tiles = + collective_mainloop.get_num_kv_tiles(mainloop_params, q_tile_idx, qo_len, kv_len); + if (num_kv_tiles <= 0) { // We exit early and write 0 to gO and -inf to gLSE. + collective_epilogue.store_zero(epilogue_params, shared_storage, + threadIdx.x - NUM_COPY_THREADS, block_coord); + continue; + } + + int swa_begin_kv_tile_idx = 0; + int swa_end_kv_tile_idx = -1; + if constexpr (LEFT_SLIDING_WINDOW) { + swa_begin_kv_tile_idx = get_swa_begin_kv_tile_idx( + mainloop_params.window_left, q_tile_idx, qo_len, kv_len); + swa_end_kv_tile_idx = get_swa_end_kv_tile_idx(mainloop_params.window_left, + q_tile_idx, qo_len, kv_len); + } + + mma_f16( + mainloop_params, variant, pipeline_k, pipeline_v, smem_pipe_read_k, smem_pipe_read_v, + tOrO, attention_updater, num_kv_tiles, swa_begin_kv_tile_idx, swa_end_kv_tile_idx, + threadIdx.x - NUM_COPY_THREADS, work_idx, q_tile_idx, shared_storage, qo_len, kv_len, + qo_head_idx, kv_head_idx); + collective_epilogue.store(epilogue_params, tOrO, attention_updater.get_lse(), shared_storage, + tiled_mma_pv, threadIdx.x - NUM_COPY_THREADS, block_coord); + + ++work_idx; + } + collective_epilogue.store_tail(); + } +} + +template +cudaError_t SinglePrefillWithKVCacheKernelTraitsDispatched( + SinglePrefillParams& params, + cudaStream_t stream) { + using DTypeQ = typename KernelTraits::DTypeQ; + using DTypeKV = typename KernelTraits::DTypeKV; + using DTypeO = typename KernelTraits::DTypeO; + using TileShape_QKD = typename KernelTraits::TileShape_QKD; + + using CollectiveMainloop = CollectiveMainloop; + using CollectiveEpilogue = CollectiveEpilogue; + using Scheduler = SingleTileScheduler; + typename CollectiveMainloop::Params mainloop_params = CollectiveMainloop::to_underlying_arguments( + {params.q_ptr, + get_gmem_layout(params.qo_len, params.num_qo_heads, params.head_dim, params.q_stride_n, + params.q_stride_h), // layout_Q + params.k_ptr, + get_gmem_layout(params.kv_len, params.num_kv_heads, params.head_dim, params.k_stride_n, + params.k_stride_h), // layout_K + params.v_ptr, + get_gmem_layout(params.kv_len, params.num_kv_heads, params.head_dim, params.v_stride_n, + params.v_stride_h), // layout_V + params.window_left, params.logits_soft_cap, params.sm_scale_log2}); + typename CollectiveEpilogue::Params epilogue_params = + CollectiveEpilogue::to_underlying_arguments({ + static_cast(params.o_ptr), + get_gmem_layout(params.qo_len, params.num_qo_heads, params.head_dim, 
params.o_stride_n, + params.o_stride_h), // layout_O + static_cast(params.lse_ptr), + get_lse_gmem_layout(params.qo_len, params.num_qo_heads), // layout_LSE + }); + + int num_tiles_q = cutlass::ceil_div(params.qo_len, KernelTraits::CTA_Q); + // TODO(Zihao): also support kv-head major + typename Scheduler::Arguments scheduler_args = { + num_tiles_q, params.num_qo_heads, params.qo_len, params.kv_len, + cutlass::FastDivmod(params.num_qo_heads / params.num_kv_heads)}; + typename Scheduler::Params scheduler_params = Scheduler::to_underlying_arguments(scheduler_args); + + auto kernel = + (void*)PrefillWithKVCacheKernel; + int smem_size = sizeof(typename KernelTraits::SharedStorage); + FLASHINFER_CUDA_CALL( + cudaFuncSetAttribute(kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, smem_size)); + + int device; + cudaGetDevice(&device); + int multiprocessor_count; + FLASHINFER_CUDA_CALL( + cudaDeviceGetAttribute(&multiprocessor_count, cudaDevAttrMultiProcessorCount, device)); + dim3 grid_dims = Scheduler::get_grid_dim(scheduler_args, multiprocessor_count); + static constexpr int num_ctas = KernelTraits::NUM_WARPS * 32; + dim3 block_dims(num_ctas); + void* args[] = {&mainloop_params, &epilogue_params, &scheduler_params}; + FLASHINFER_CUDA_CALL(cudaLaunchKernel(kernel, grid_dims, block_dims, args, smem_size, stream)); + + return cudaSuccess; +} + +template +cudaError_t BatchPrefillWithPagedKVCacheKernelTraitsDispatched( + BatchPrefillPagedParams& params, + cudaStream_t stream) { + using DTypeQ = typename KernelTraits::DTypeQ; + using DTypeKV = typename KernelTraits::DTypeKV; + using DTypeO = typename KernelTraits::DTypeO; + using TileShape_QKD = typename KernelTraits::TileShape_QKD; + + using CollectiveMainloop = SparseCollectiveMainloop; + using CollectiveEpilogue = CollectiveEpilogue; + using Scheduler = BatchPrefillTileScheduler; + + typename CollectiveMainloop::Params mainloop_params = CollectiveMainloop::to_underlying_arguments( + {params.q_ptr, + get_gmem_layout(params.nnz_qo, params.num_qo_heads, params.head_dim, params.q_stride_n, + params.q_stride_h), // layout_Q + params.k_ptr, + // NOTE(Zihao): nnz was useless here, we can just pass 0 + get_gmem_layout(/*nnz=*/0, params.num_kv_heads, params.head_dim, params.k_stride_n, + params.k_stride_h), // layout_K + params.v_ptr, + get_gmem_layout(/*nnz=*/0, params.num_kv_heads, params.head_dim, params.v_stride_n, + params.v_stride_h), // layout_V + params.kv_indices, params.window_left, params.logits_soft_cap, params.sm_scale_log2}); + typename CollectiveEpilogue::Params epilogue_params = + CollectiveEpilogue::to_underlying_arguments({ + params.o_ptr, + get_gmem_layout(params.nnz_qo, params.num_qo_heads, params.head_dim, params.o_stride_n, + params.o_stride_h), // layout_O + params.lse_ptr, get_lse_gmem_layout(params.nnz_qo, params.num_qo_heads), // layout_LSE + }); + + typename Scheduler::Arguments scheduler_args = { + params.work_indptr, params.head_indices, + params.qo_tile_indices, params.qo_indptr, + params.kv_indptr, params.qo_lens, + params.kv_lens, cutlass::FastDivmod(params.num_qo_heads / params.num_kv_heads)}; + typename Scheduler::Params scheduler_params = Scheduler::to_underlying_arguments(scheduler_args); + + // Get the ptr to kernel function. 
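+  // Dynamic shared memory is sized from KernelTraits::SharedStorage and can exceed
+  // the 48 KB per-block default, so the launch below first raises
+  // cudaFuncAttributeMaxDynamicSharedMemorySize; the grid is then derived by the
+  // tile scheduler from the device's SM count.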
+ auto kernel = + (void*)PrefillWithKVCacheKernel; + int smem_size = sizeof(typename KernelTraits::SharedStorage); + FLASHINFER_CUDA_CALL( + cudaFuncSetAttribute(kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, smem_size)); + + int device; + cudaGetDevice(&device); + int multiprocessor_count; + FLASHINFER_CUDA_CALL( + cudaDeviceGetAttribute(&multiprocessor_count, cudaDevAttrMultiProcessorCount, device)); + dim3 grid_dims = Scheduler::get_grid_dim(scheduler_args, multiprocessor_count); + static constexpr int ctaSize = KernelTraits::NUM_WARPS * 32; + dim3 block_dims(ctaSize); + void* args[] = {&mainloop_params, &epilogue_params, &scheduler_params}; + FLASHINFER_CUDA_CALL(cudaLaunchKernel(kernel, grid_dims, block_dims, args, smem_size, stream)); + + return cudaSuccess; +} + +template +cudaError_t BatchPrefillWithRaggedKVCacheKernelTraitsDispatched( + BatchPrefillRaggedParams& params, + cudaStream_t stream) { + using DTypeQ = typename KernelTraits::DTypeQ; + using DTypeKV = typename KernelTraits::DTypeKV; + using DTypeO = typename KernelTraits::DTypeO; + using TileShape_QKD = typename KernelTraits::TileShape_QKD; + + using CollectiveMainloop = CollectiveMainloop; + using CollectiveEpilogue = CollectiveEpilogue; + using Scheduler = BatchPrefillTileScheduler; + typename CollectiveMainloop::Params mainloop_params = CollectiveMainloop::to_underlying_arguments( + {params.q_ptr, + get_gmem_layout(params.nnz_qo, params.num_qo_heads, params.head_dim, params.q_stride_n, + params.q_stride_h), // layout_Q + params.k_ptr, + // NOTE(Zihao): nnz was useless here, we can just pass 0 + get_gmem_layout(params.nnz_kv, params.num_kv_heads, params.head_dim, params.k_stride_n, + params.k_stride_h), // layout_K + params.v_ptr, + get_gmem_layout(params.nnz_kv, params.num_kv_heads, params.head_dim, params.v_stride_n, + params.v_stride_h), // layout_V + params.window_left, params.logits_soft_cap, params.sm_scale_log2}); + typename CollectiveEpilogue::Params epilogue_params = + CollectiveEpilogue::to_underlying_arguments({ + params.o_ptr, + get_gmem_layout(params.nnz_qo, params.num_qo_heads, params.head_dim, params.o_stride_n, + params.o_stride_h), // layout_O + params.lse_ptr, get_lse_gmem_layout(params.nnz_qo, params.num_qo_heads), // layout_LSE + }); + + // NOTE(Zihao): add support for kv head-major later + typename Scheduler::Arguments scheduler_args = { + params.work_indptr, params.head_indices, + params.qo_tile_indices, params.qo_indptr, + params.kv_indptr, params.qo_lens, + params.kv_lens, cutlass::FastDivmod(params.num_qo_heads / params.num_kv_heads)}; + typename Scheduler::Params scheduler_params = Scheduler::to_underlying_arguments(scheduler_args); + + // Get the ptr to kernel function. 
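+  // Unlike the paged variant above, this ragged path keeps the TMA-based
+  // CollectiveMainloop: each request's K/V rows are contiguous in memory, so no
+  // kv_indices gather is needed.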
+ auto kernel = + (void*)PrefillWithKVCacheKernel; + int smem_size = sizeof(typename KernelTraits::SharedStorage); + FLASHINFER_CUDA_CALL( + cudaFuncSetAttribute(kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, smem_size)); + + int device; + cudaGetDevice(&device); + int multiprocessor_count; + FLASHINFER_CUDA_CALL( + cudaDeviceGetAttribute(&multiprocessor_count, cudaDevAttrMultiProcessorCount, device)); + dim3 grid_dims = Scheduler::get_grid_dim(scheduler_args, multiprocessor_count); + static constexpr int ctaSize = KernelTraits::NUM_WARPS * 32; + dim3 block_dims(ctaSize); + void* args[] = {&mainloop_params, &epilogue_params, &scheduler_params}; + FLASHINFER_CUDA_CALL(cudaLaunchKernel(kernel, grid_dims, block_dims, args, smem_size, stream)); + + return cudaSuccess; +} + +template +cudaError_t SinglePrefillWithKVCacheDispatched(SinglePrefillParams& params, + cudaStream_t stream) { + static_assert(HEAD_DIM == 64 || HEAD_DIM == 128 || HEAD_DIM == 256); + if (MASK_MODE == MaskMode::kCustom) { + return cudaErrorNotSupported; // Not supported yet. + } + constexpr bool CAUSAL = MASK_MODE == MaskMode::kCausal; + if constexpr (HEAD_DIM == 64) { + SinglePrefillWithKVCacheKernelTraitsDispatched< + AttentionKernelTraits, + LEFT_SLINDING_WINDOW, CAUSAL>(params, stream); + } else if constexpr (HEAD_DIM == 128) { + SinglePrefillWithKVCacheKernelTraitsDispatched< + AttentionKernelTraits, + LEFT_SLINDING_WINDOW, CAUSAL>(params, stream); + } else { + // HEAD_DIM == 256; + SinglePrefillWithKVCacheKernelTraitsDispatched< + AttentionKernelTraits, + LEFT_SLINDING_WINDOW, CAUSAL>(params, stream); + } + cudaError_t status = cudaGetLastError(); + return status; +} + +template +cudaError_t BatchPrefillWithRaggedKVCacheDispatched( + BatchPrefillRaggedParams& params, cudaStream_t stream) { + static_assert(HEAD_DIM == 64 || HEAD_DIM == 128 || HEAD_DIM == 256); + if (MASK_MODE == MaskMode::kCustom) { + return cudaErrorNotSupported; // Not supported yet. + } + constexpr bool CAUSAL = MASK_MODE == MaskMode::kCausal; + if constexpr (HEAD_DIM == 64) { + BatchPrefillWithRaggedKVCacheKernelTraitsDispatched< + AttentionKernelTraits, + LEFT_SLINDING_WINDOW, CAUSAL>(params, stream); + } else if constexpr (HEAD_DIM == 128) { + BatchPrefillWithRaggedKVCacheKernelTraitsDispatched< + AttentionKernelTraits, + LEFT_SLINDING_WINDOW, CAUSAL>(params, stream); + } else { + // HEAD_DIM == 256; + BatchPrefillWithRaggedKVCacheKernelTraitsDispatched< + AttentionKernelTraits, + LEFT_SLINDING_WINDOW, CAUSAL>(params, stream); + } + cudaError_t status = cudaGetLastError(); + return status; +} + +template +cudaError_t BatchPrefillWithPagedKVCacheDispatched( + BatchPrefillPagedParams& params, cudaStream_t stream) { + static_assert(HEAD_DIM == 64 || HEAD_DIM == 128 || HEAD_DIM == 256); + if (MASK_MODE == MaskMode::kCustom) { + return cudaErrorNotSupported; // Not supported yet. 
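+    // Only MaskMode::kNone and MaskMode::kCausal reach the dispatch below; kCustom
+    // would need an extra mask load in the mainloop that this SM90 path does not
+    // implement yet.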
+ } + constexpr bool CAUSAL = MASK_MODE == MaskMode::kCausal; + if constexpr (HEAD_DIM == 64) { + // NOTE(Zihao): CTA_KV not tuned for HEAD_DIM == 64, need to optimize later + BatchPrefillWithPagedKVCacheKernelTraitsDispatched< + AttentionKernelTraits, + LEFT_SLINDING_WINDOW, CAUSAL>(params, stream); + } else if constexpr (HEAD_DIM == 128) { + BatchPrefillWithPagedKVCacheKernelTraitsDispatched< + AttentionKernelTraits, + LEFT_SLINDING_WINDOW, CAUSAL>(params, stream); + } else { + // HEAD_DIM == 256; + // NOTE(Zihao): CTA_KV not tuned for HEAD_DIM == 256, need to optimize later + BatchPrefillWithPagedKVCacheKernelTraitsDispatched< + AttentionKernelTraits, + LEFT_SLINDING_WINDOW, CAUSAL>(params, stream); + } + cudaError_t status = cudaGetLastError(); + return status; +}; + +} // namespace flashinfer + +#endif // FLASHINFER_ATTENTION_HOPPER_PREFILL_SM90_CUH_ diff --git a/include/flashinfer/attention/hopper/sparse_mainloop.cuh b/include/flashinfer/attention/hopper/sparse_mainloop.cuh new file mode 100644 index 00000000..263a8c74 --- /dev/null +++ b/include/flashinfer/attention/hopper/sparse_mainloop.cuh @@ -0,0 +1,327 @@ +/* + * Copyright (c) 2024 by FlashInfer team. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#ifndef FLASHINFER_ATTENTION_HOPPER_SPARSE_MAINLOOP_CUH_ +#define FLASHINFER_ATTENTION_HOPPER_SPARSE_MAINLOOP_CUH_ + +#include +#include +#include +#include + +#include "../../math.cuh" +#include "block_sparse_gather.cuh" +#include "cute/tensor.hpp" +#include "cutlass/gemm/collective/collective_builder.hpp" +#include "cutlass/pipeline/pipeline.hpp" +#include "named_barrier.cuh" +#include "utils.cuh" + +namespace flashinfer { + +using namespace cute; + +template +struct SparseCollectiveMainloop { + using DTypeQ = typename Ktraits::DTypeQ; + using DTypeKV = typename Ktraits::DTypeKV; + using IdType = typename Ktraits::IdType; + using TileShape_QKD = typename Ktraits::TileShape_QKD; + static constexpr int CTA_Q = get<0>(TileShape_QKD{}); + static constexpr int CTA_KV = get<1>(TileShape_QKD{}); + + static constexpr int NUM_STAGES = Ktraits::NUM_STAGES; + static constexpr int HEAD_DIM = Ktraits::HEAD_DIM; + static constexpr int NUM_COPY_THREADS = cutlass::NumThreadsPerWarpGroup; + + using GmemTiledCopyQ = cute::SM90_TMA_LOAD; + static constexpr auto AlignmentKV = 128 / cutlass::sizeof_bits::value; + using AlignmentTypeKV = cute::uint_byte_t(sizeof(DTypeKV)) * AlignmentKV>; + // NOTE(Zihao): use SM80_CP_ASYNC for sparse loading of KV-cache + using GmemCopyAtomKV = cute::Copy_Atom, DTypeKV>; + using GmemTiledCopyKV = + decltype(cutlass::gemm::collective::detail::make_simt_gmem_tiled_copy< + GmemCopyAtomKV, NUM_COPY_THREADS, AlignmentKV, + cutlass::detail::TagToStrideB_t, + decltype(cute::get<1>(TileShape_QKD{})), decltype(cute::get<2>(TileShape_QKD{}))>()); + + using SmemLayoutQ = typename Ktraits::SmemLayoutQ; + using SmemLayoutK = typename Ktraits::SmemLayoutK; + using SmemLayoutV = typename Ktraits::SmemLayoutV; + using SmemLayoutVt = typename Ktraits::SmemLayoutVt; + + 
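+  // K/V tiles are gathered through kv_indices with cp.async (SM80_CP_ASYNC above)
+  // instead of TMA, because pages of a paged KV-cache are not contiguous in global
+  // memory; Q still uses a TMA load (GmemTiledCopyQ above).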
using ShapeT = cute::Shape; + using StrideT = cute::Shape; // (N, D, H) + using LayoutT = cute::Layout; + + using ShapeLseT = cute::Shape; + using StrideLseT = cute::Shape<_1, int64_t>; + using LayoutLseT = cute::Layout; + + using TMA_Q = decltype(make_tma_copy( + GmemTiledCopyQ{}, + make_tensor(make_gmem_ptr(static_cast(nullptr)), + repeat_like(StrideT{}, int32_t(0)), StrideT{}), + SmemLayoutQ{}, select<0, 2>(TileShape_QKD{}), _1{})); // no mcast for Q + + static constexpr bool USE_TMA_LOAD_KV = false; + static constexpr int NUM_MMA_THREADS = size(typename Ktraits::TiledMmaQK{}); + using MainloopPipeline = typename Ktraits::MainloopPipeline; + using PipelineParams = typename MainloopPipeline::Params; + using PipelineState = typename MainloopPipeline::PipelineState; + + static constexpr uint32_t TmaTransactionBytesQ = + static_cast(size(SmemLayoutQ{}) * cutlass::sizeof_bits_v / 8); + + static constexpr bool UseSchedulerBarrier = + cutlass::sizeof_bits_v == 8 ? HEAD_DIM >= 128 : HEAD_DIM <= 128; + using WarpScheduler = WarpScheduler; + + // Host side kernel arguments + struct Arguments { + DTypeQ const* Q_ptr; + LayoutT layout_Q; + DTypeKV const* K_ptr; + LayoutT layout_K; + DTypeKV const* V_ptr; + LayoutT layout_V; + IdType const* kv_indices; + int window_left; + float const logits_soft_cap; + float const sm_scale_log2; + }; + + // Device side kernel params + struct Params { + LayoutT layout_Q; + LayoutT layout_K; + LayoutT layout_V; + TMA_Q tma_load_Q; + DTypeKV* K_ptr; + DTypeKV* V_ptr; + IdType* kv_indices; + int window_left; + float const logits_soft_cap; + float const sm_scale_log2; + }; + + static Params to_underlying_arguments(Arguments const& args) { + Tensor mQ = make_tensor(make_gmem_ptr(args.Q_ptr), args.layout_Q); + TMA_Q tma_load_Q = + make_tma_copy(GmemTiledCopyQ{}, mQ, SmemLayoutQ{}, select<0, 2>(TileShape_QKD{}), _1{}); + return {args.layout_Q, + args.layout_K, + args.layout_V, + tma_load_Q, + const_cast(args.K_ptr), + const_cast(args.V_ptr), + const_cast(args.kv_indices), + args.window_left, + args.logits_soft_cap, + args.sm_scale_log2}; + } + + CUTLASS_DEVICE + static void prefetch_tma_descriptors(Params const& mainloop_params) { + cute::prefetch_tma_descriptor(mainloop_params.tma_load_Q.get_tma_descriptor()); + } + + CUTLASS_DEVICE + int get_num_kv_tiles(Params const& mainloop_params, int q_tile_idx, const int qo_len, + const int kv_len) { + static constexpr int CTA_Q = get<0>(TileShape_QKD{}); + static constexpr int CTA_KV = get<1>(TileShape_QKD{}); + int num_kv_tiles = cute::ceil_div(kv_len, CTA_KV); + if constexpr (CAUSAL) { + num_kv_tiles = std::min(num_kv_tiles, + cute::ceil_div((q_tile_idx + 1) * CTA_Q + kv_len - qo_len, CTA_KV)); + } + + return num_kv_tiles; + } + + template + CUTLASS_DEVICE void load(Params const& mainloop_params, MainloopPipeline pipeline_k, + MainloopPipeline pipeline_v, PipelineState& smem_pipe_write_k, + PipelineState& smem_pipe_write_v, SharedStorage& shared_storage, + Scheduler& scheduler, typename Scheduler::Params const& scheduler_params, + typename Scheduler::WorkTileInfo& work_tile_info, + BlockCoord const& block_coord, int work_idx) { + int thread_idx = threadIdx.x; + int warp_idx_in_warpgroup = __shfl_sync(0xffffffff, (thread_idx / 32) % 4, 0); + Tensor sQ = make_tensor(make_smem_ptr(shared_storage.smem_q.data()), SmemLayoutQ{}); + Tensor sK = make_tensor(make_smem_ptr(shared_storage.smem_k.data()), SmemLayoutK{}); + Tensor sV = make_tensor(make_smem_ptr(shared_storage.smem_v.data()), SmemLayoutV{}); + + Tensor mQ = 
mainloop_params.tma_load_Q.get_tma_tensor(mainloop_params.layout_Q.shape()); + + auto [q_tile_idx, qo_head_idx, kv_head_idx, qo_indptr, kv_indptr, qo_len, kv_len] = block_coord; + + // Prepare the TMA loads + Tensor gQ = get_local_tile_tensor(mQ, select<0, 2>(TileShape_QKD{}), qo_head_idx, qo_indptr, + qo_len)(_, _, q_tile_idx); // (Q, D) + + Tensor sQ_x = make_tensor(sQ.data(), make_layout(sQ.layout(), Layout<_1>{})); + Tensor gQ_x = make_tensor(gQ.data(), make_layout(gQ.layout(), Layout<_1>{})); + auto [tQgQ, tQsQ] = + tma_partition(mainloop_params.tma_load_Q, _0{}, Layout<_1>{}, group_modes<0, 2>(sQ_x), + group_modes<0, 2>(gQ_x)); // (TMA), (TMA) + + int num_kv_tiles = get_num_kv_tiles(mainloop_params, q_tile_idx, qo_len, kv_len); + int kv_tile_idx = num_kv_tiles - 1; + int swa_begin_kv_tile_idx = 0; + if constexpr (LEFT_SLIDING_WINDOW) { + swa_begin_kv_tile_idx = get_swa_begin_kv_tile_idx(mainloop_params.window_left, + q_tile_idx, qo_len, kv_len); + } + + constexpr int HEAD_DIM = get<2>(TileShape_QKD{}); + constexpr int CTA_KV = get<1>(TileShape_QKD{}); + auto indexed_gather = BlockSparseIndexedGather(mainloop_params.kv_indices + kv_indptr); + + Tensor mK = make_block_sparse_tensor( // (kv_len, D) + make_gmem_ptr(mainloop_params.K_ptr + kv_head_idx * stride<2>(mainloop_params.layout_K)), + make_shape(kv_len, HEAD_DIM), stride<0>(mainloop_params.layout_K), indexed_gather); + Tensor mV = make_block_sparse_tensor( // (kv_len, D) + make_gmem_ptr(mainloop_params.V_ptr + kv_head_idx * stride<2>(mainloop_params.layout_V)), + make_shape(kv_len, HEAD_DIM), stride<0>(mainloop_params.layout_V), indexed_gather); + + Tensor gK = local_tile(mK, select<1, 2>(TileShape_QKD{}), make_coord(_, _0{})); // (KV, D, kv) + Tensor gV = local_tile(mV, select<1, 2>(TileShape_QKD{}), make_coord(_, _0{})); // (KV, D, kv) + Tensor cKV = cute::make_identity_tensor(gK.shape()); + + GmemTiledCopyKV gmem_tiled_copy_kv; + auto gmem_thr_copy_kv = gmem_tiled_copy_kv.get_slice(thread_idx); + + Tensor tKgK = gmem_thr_copy_kv.partition_S(gK); // (CPY, CPY_KV, CPY_D, kv) + Tensor tKsK = gmem_thr_copy_kv.partition_D(sK); // (CPY, CPY_KV, CPY_D, PIPE) + Tensor tVgV = gmem_thr_copy_kv.partition_S(gV); // (CPY, CPY_KV, CPY_D, kv) + Tensor tVsV = gmem_thr_copy_kv.partition_D(sV); // (CPY, CPY_KV, CPY_D, PIPE) + Tensor tKVcKV = gmem_thr_copy_kv.partition_D(cKV); // (CPY, CPY_KV, CPY_D) + Tensor tKVcKVGroup = flatten_1(tKVcKV); // (CPY, (CPY_KV, CPY_D)) + + int valid_last_kv_tile_size = std::min(kv_len - kv_tile_idx * CTA_KV, CTA_KV); + auto predicate_fn = [&](auto coords) { + auto s_coords = tKVcKVGroup(_0{}, coords); + return elem_less(get<0>(s_coords), valid_last_kv_tile_size); + }; + + // load last k-tile + { + pipeline_k.producer_acquire(smem_pipe_write_k); + Tensor tKgKiGroup = flatten_1(tKgK(_, _, _, kv_tile_idx)); // (CPY, (CPY_KV, CPY_D)) + Tensor tKsKiGroup = + flatten_1(tKsK(_, _, _, smem_pipe_write_k.index())); // (CPY, (CPY_KV, CPY_D)) + copy_if(gmem_tiled_copy_kv, predicate_fn, tKgKiGroup, tKsKiGroup); + + pipeline_k.producer_commit(smem_pipe_write_k, cutlass::arch::cpasync_barrier_arrive); + ++smem_pipe_write_k; + } + + // load Q tile + if (warp_idx_in_warpgroup == 0) { + cutlass::arch::NamedBarrier::sync(NUM_MMA_THREADS + cutlass::NumThreadsPerWarp, + static_cast(NamedBarriers::kQueryEmpty)); + + int lane_predicate = cute::elect_one_sync(); + if (lane_predicate) { + shared_storage.barrier_Q.arrive_and_expect_tx(TmaTransactionBytesQ); + copy(mainloop_params.tma_load_Q.with( + reinterpret_cast( + 
shared_storage.barrier_Q), + /*mcast_mask=*/0), + tQgQ, tQsQ); + } + } + + shared_storage.barrier_O.wait((work_idx + 1) % 2); + + if (kv_tile_idx == swa_begin_kv_tile_idx) { + pipeline_v.producer_acquire(smem_pipe_write_v); + Tensor tVgViGroup = flatten_1(tVgV(_, _, _, kv_tile_idx)); // (CPY, (CPY_KV, CPY_D)) + Tensor tVsViGroup = + flatten_1(tVsV(_, _, _, smem_pipe_write_v.index())); // (CPY, (CPY_KV, CPY_D)) + copy_if(gmem_tiled_copy_kv, predicate_fn, tVgViGroup, tVsViGroup); + + pipeline_v.producer_commit(smem_pipe_write_v, cutlass::arch::cpasync_barrier_arrive); + ++smem_pipe_write_v; + } else { + // load second last k-tile and last v-tile + pipeline_k.producer_acquire(smem_pipe_write_k); + Tensor tKgKi = tKgK(_, _, _, kv_tile_idx - 1); // (CPY, CPY_KV, CPY_D) + Tensor tKsKi = tKsK(_, _, _, smem_pipe_write_k.index()); // (CPY, CPY_KV, CPY_D) + copy(gmem_tiled_copy_kv, tKgKi, tKsKi); + + pipeline_k.producer_commit(smem_pipe_write_k, cutlass::arch::cpasync_barrier_arrive); + ++smem_pipe_write_k; + + pipeline_v.producer_acquire(smem_pipe_write_v); + Tensor tVgViGroup = flatten_1(tVgV(_, _, _, kv_tile_idx)); // (CPY, (CPY_KV, CPY_D)) + Tensor tVsViGroup = + flatten_1(tVsV(_, _, _, smem_pipe_write_v.index())); // (CPY, (CPY_KV, CPY_D)) + copy_if(gmem_tiled_copy_kv, predicate_fn, tVgViGroup, tVsViGroup); + + pipeline_v.producer_commit(smem_pipe_write_v, cutlass::arch::cpasync_barrier_arrive); + --kv_tile_idx; + ++smem_pipe_write_v; + + // load remaining k/v tiles +#pragma unroll 2 + for (; kv_tile_idx > swa_begin_kv_tile_idx; --kv_tile_idx) { + pipeline_k.producer_acquire(smem_pipe_write_k); + + Tensor tKgKi = tKgK(_, _, _, kv_tile_idx - 1); // (CPY, CPY_KV, CPY_D) + Tensor tKsKi = tKsK(_, _, _, smem_pipe_write_k.index()); // (CPY, CPY_KV, CPY_D) + copy(gmem_tiled_copy_kv, tKgKi, tKsKi); + + pipeline_k.producer_commit(smem_pipe_write_k, cutlass::arch::cpasync_barrier_arrive); + ++smem_pipe_write_k; + + pipeline_v.producer_acquire(smem_pipe_write_v); + Tensor tVgVi = tVgV(_, _, _, kv_tile_idx); // (CPY, CPY_KV, CPY_D) + Tensor tVsVi = tVsV(_, _, _, smem_pipe_write_v.index()); // (CPY, CPY_KV, CPY_D) + copy(gmem_tiled_copy_kv, tVgVi, tVsVi); + + pipeline_v.producer_commit(smem_pipe_write_v, cutlass::arch::cpasync_barrier_arrive); + ++smem_pipe_write_v; + } + scheduler.prefetch_next_work(scheduler_params, work_tile_info); + + // load first v tile + { + pipeline_v.producer_acquire(smem_pipe_write_v); + Tensor tVgVi = tVgV(_, _, _, 0); // (CPY, (CPY_KV, CPY_D)) + Tensor tVsVi = tVsV(_, _, _, smem_pipe_write_v.index()); // (CPY, (CPY_KV, CPY_D)) + copy(gmem_tiled_copy_kv, tVgVi, tVsVi); + pipeline_v.producer_commit(smem_pipe_write_v, cutlass::arch::cpasync_barrier_arrive); + ++smem_pipe_write_v; + } + } + + scheduler.broadcast_next_work(work_tile_info); + } + + CUTLASS_DEVICE void load_tail(MainloopPipeline pipeline_k, MainloopPipeline pipeline_v, + PipelineState& smem_pipe_write_k, + PipelineState& smem_pipe_write_v) { + pipeline_k.producer_tail(smem_pipe_write_k); + pipeline_v.producer_tail(smem_pipe_write_v); + } +}; + +} // namespace flashinfer + +#endif // FLASHINFER_ATTENTION_HOPPER_SPARSE_MAINLOOP_CUH_ diff --git a/include/flashinfer/attention/hopper/tile_scheduler.cuh b/include/flashinfer/attention/hopper/tile_scheduler.cuh new file mode 100644 index 00000000..39610271 --- /dev/null +++ b/include/flashinfer/attention/hopper/tile_scheduler.cuh @@ -0,0 +1,196 @@ +/* + * Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri + * Dao. 
Licensed under the BSD 3-Clause. + * + * Modified by the FlashInfer team. + */ +#ifndef FLASHINFER_ATTENTION_HOPPER_TILE_SCHEDULER_CUH_ +#define FLASHINFER_ATTENTION_HOPPER_TILE_SCHEDULER_CUH_ + +#include "cutlass/arch/barrier.h" +#include "cutlass/fast_math.h" +#include "named_barrier.cuh" + +namespace flashinfer { + +struct SingleTileScheduler { + public: + // Host side kernel arguments + struct Arguments { + int const num_qo_tiles, num_qo_heads, qo_len, kv_len; + cutlass::FastDivmod group_size_fastdiv; + }; + + // Device side kernel params + struct Params { + int const qo_len, kv_len; + cutlass::FastDivmod group_size_fastdiv; + }; + + static Params to_underlying_arguments(Arguments const& args) { + return {args.qo_len, args.kv_len, args.group_size_fastdiv}; + } + + static dim3 get_grid_dim(Arguments const& args, int num_sm) { + return {uint32_t(args.num_qo_tiles), uint32_t(args.num_qo_heads)}; + } + + struct WorkTileInfo { + int q_tile_idx = 0; + int qo_head_idx = 0; + int kv_head_idx = 0; + bool is_valid_tile = false; + + CUTLASS_DEVICE + bool is_valid(Params const& params) const { return is_valid_tile; } + + CUTLASS_DEVICE + auto get_block_coord(Params const& params) const { + return cute::tuple{q_tile_idx, qo_head_idx, kv_head_idx, /*qo_indptr=*/0, + /*kv_indptr=*/0, params.qo_len, params.kv_len}; + } + }; + + CUTLASS_DEVICE + SingleTileScheduler() {} + + CUTLASS_DEVICE + WorkTileInfo get_initial_work(Params const& params) const { + int qo_head_idx = blockIdx.y; + int kv_head_idx = params.group_size_fastdiv.divide(qo_head_idx); + return {/*q_tile_idx=*/int(blockIdx.x), qo_head_idx, kv_head_idx, /*is_valid_tile*/ true}; + } + + CUTLASS_DEVICE + void init_consumer() const {} + + CUTLASS_DEVICE + void prefetch_next_work(Params const& params, WorkTileInfo& current_work) const {} + + CUTLASS_DEVICE + void broadcast_next_work(WorkTileInfo& current_work) const {} + + template + CUTLASS_DEVICE WorkTileInfo get_next_work(Params const& params, + WorkTileInfo const& current_work) const { + return {-1, -1, false}; + } +}; + +template +struct BatchPrefillTileScheduler { + public: + // Host side kernel arguments + struct Arguments { + IdType *work_indptr, *head_indices, *qo_tile_indices, *qo_indptr, *kv_indptr, *qo_lens, + *kv_lens; + cutlass::FastDivmod group_size_fastdiv; + }; + + // Device side kernel params + struct Params { + IdType *work_indptr, *head_indices, *qo_tile_indices, *qo_indptr, *kv_indptr, *qo_lens, + *kv_lens; + cutlass::FastDivmod group_size_fastdiv; + }; + + static Params to_underlying_arguments(Arguments const& args) { + return {args.work_indptr, args.head_indices, args.qo_tile_indices, args.qo_indptr, + args.kv_indptr, args.qo_lens, args.kv_lens, args.group_size_fastdiv}; + } + + static dim3 get_grid_dim(Arguments const& args, int num_sm) { + return {132U}; // 132 + } + + struct WorkTileInfo { + int q_tile_idx = 0; + int qo_head_idx = 0; + int kv_head_idx = 0; + int qo_indptr = 0; + int kv_indptr = 0; + int qo_len = 0; + int kv_len = 0; + int counter = 0; + int ptr_begin = 0; + int ptr_end = 0; + + CUTLASS_DEVICE + bool is_valid(Params const& params) const { return counter + ptr_begin < ptr_end; } + + CUTLASS_DEVICE + auto get_block_coord(Params const& params) const { + return cute::tuple{q_tile_idx, qo_head_idx, kv_head_idx, qo_indptr, + kv_indptr, qo_len, kv_len}; + } + }; + + CUTLASS_DEVICE + BatchPrefillTileScheduler() {} + + CUTLASS_DEVICE + WorkTileInfo get_initial_work(Params const& params) const { + int ptr_begin = params.work_indptr[blockIdx.x]; + int ptr_end = 
params.work_indptr[blockIdx.x + 1]; + if (ptr_begin < ptr_end) { + int work_idx = ptr_begin; + int qo_head_idx = params.head_indices[work_idx]; + int kv_head_idx = params.group_size_fastdiv.divide(qo_head_idx); + return {params.qo_tile_indices[work_idx], + qo_head_idx, + kv_head_idx, + params.qo_indptr[work_idx], + params.kv_indptr[work_idx], + params.qo_lens[work_idx], + params.kv_lens[work_idx], + 0, + ptr_begin, + ptr_end}; + } else { + return {-1, -1, -1, -1, -1, -1, 0, ptr_begin, ptr_end}; + } + } + + CUTLASS_DEVICE + void init_consumer() const {} + + CUTLASS_DEVICE + void prefetch_next_work(Params const& params, WorkTileInfo& current_work) const {} + + CUTLASS_DEVICE + void broadcast_next_work(WorkTileInfo& current_work) const {} + + template + CUTLASS_DEVICE WorkTileInfo get_next_work(Params const& params, + WorkTileInfo const& current_work) const { + int work_idx = current_work.ptr_begin + current_work.counter + 1; + if (work_idx < current_work.ptr_end) { + int qo_head_idx = params.head_indices[work_idx]; + int kv_head_idx = params.group_size_fastdiv.divide(qo_head_idx); + return {params.qo_tile_indices[work_idx], + qo_head_idx, + kv_head_idx, + params.qo_indptr[work_idx], + params.kv_indptr[work_idx], + params.qo_lens[work_idx], + params.kv_lens[work_idx], + current_work.counter + 1, + current_work.ptr_begin, + current_work.ptr_end}; + } else { + return {-1, + -1, + -1, + -1, + -1, + -1, + current_work.counter + 1, + current_work.ptr_begin, + current_work.ptr_end}; + } + } +}; + +} // namespace flashinfer + +#endif // FLASHINFER_ATTENTION_HOPPER_TILE_SCHEDULER_CUH_ diff --git a/include/flashinfer/attention/hopper/utils.cuh b/include/flashinfer/attention/hopper/utils.cuh new file mode 100644 index 00000000..0441cbd1 --- /dev/null +++ b/include/flashinfer/attention/hopper/utils.cuh @@ -0,0 +1,165 @@ +/* + * Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri + * Dao. Licensed under the BSD 3-Clause. + * + * Modified by the FlashInfer team. 
+ */ +#ifndef FLASHINFER_ATTENTION_HOPPER_UTILS_CUH_ +#define FLASHINFER_ATTENTION_HOPPER_UTILS_CUH_ + +#include +#include +#include +#include + +#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 800 +#include +#endif + +#include +#include +#include +#include +#include + +#include +#include +#include + +#include "../../math.cuh" +#include "../../utils.cuh" +#include "cutlass/fast_math.h" + +namespace flashinfer { + +using namespace cute; + +template +CUTLASS_DEVICE int get_swa_begin_kv_tile_idx(int window_left, int q_tile_idx, const int qo_len, + const int kv_len) { + return std::max((q_tile_idx * CTA_Q + kv_len - qo_len - window_left) / CTA_KV - 1, 0); +} + +template +CUTLASS_DEVICE int get_swa_end_kv_tile_idx(int window_left, int q_tile_idx, const int qo_len, + const int kv_len) { + return std::max(((q_tile_idx + 1) * CTA_Q + kv_len - qo_len - window_left) / CTA_KV, -1); +} + +template +CUTLASS_HOST_DEVICE auto flatten_1(TensorT tensor) { + Tensor tensor_flatten = cute::flatten(tensor); + return cute::group_modes<1, rank(tensor_flatten)>(tensor_flatten); +} + +CUTLASS_HOST_DEVICE auto get_gmem_layout(int nnz, int num_heads, int head_dim, int64_t n_stride, + int64_t h_stride) { + return make_layout(make_shape(nnz, head_dim, num_heads), + make_stride(n_stride, cute::_1{}, h_stride)); +} + +CUTLASS_HOST_DEVICE auto get_lse_gmem_layout(int nnz, int num_heads) { + return make_layout(make_shape(num_heads, nnz), make_stride(cute::_1{}, int64_t(num_heads))); +} + +template +CUTLASS_DEVICE auto get_local_tile_tensor(const MTensor& m_tensor, const Shape& tile_shape, + int head_idx, int offset, int seq_len) { + auto g_offset = local_tile(m_tensor(_, _, head_idx), cute::make_shape(1, get<1>(tile_shape)), + make_coord(offset, _0{})); + auto g_sequence = + make_tensor(g_offset.data(), + make_layout(cute::make_shape(seq_len, get<1>(tile_shape)), g_offset.stride())); + auto g_tensor = local_tile(g_sequence, tile_shape, make_coord(_, _0{})); + return g_tensor; +} + +template +CUTLASS_DEVICE auto get_lse_local_tile_tensor(const MTensor& m_tensor, const Shape& tile_shape, + int head_idx, int offset, int seq_len) { + auto g_offset = local_tile(m_tensor(head_idx, _), cute::make_shape(_1{}), make_coord(offset)); + + auto g_sequence = make_tensor(g_offset.data(), make_layout(cute::make_shape(seq_len), + cute::make_shape(shape<0>(m_tensor)))); + auto g_tensor = local_tile(g_sequence, tile_shape, make_coord(_)); + return g_tensor; +} + +// For SM90, convert acc_layout from ((2, 2, V), MMA_M, MMA_N) to (nrow=(2, MMA_M), ncol=(2, V, +// MMA_N)) +template +__forceinline__ __device__ auto convert_layout_acc_rowcol(Layout acc_layout) { + static_assert(decltype(size<0, 0>(acc_layout))::value == 2); + static_assert(decltype(size<0, 1>(acc_layout))::value == 2); + static_assert(decltype(rank(acc_layout))::value == 3); + auto l = acc_layout; + return make_layout(make_layout(get<0, 1>(l), get<1>(l)), + make_layout(get<0, 0>(l), get<0, 2>(l), get<2>(l))); +}; + +// For SM90, convert acc_layout from ((2, 2, N / 8), MMA_M, MMA_N) to ((2, 2, 2), MMA_M, (N / 16, +// MMA_N)) +template +__forceinline__ __device__ auto convert_layout_acc_Aregs(Layout acc_layout) { + using X = Underscore; + static_assert(decltype(size<0, 0>(acc_layout))::value == 2); + static_assert(decltype(size<0, 1>(acc_layout))::value == 2); + static_assert(decltype(rank(acc_layout))::value == 3); + static_assert(decltype(rank(get<0>(acc_layout)))::value == 3); + auto l = logical_divide(get<0>(acc_layout), Shape{}); // (2, 2, (2, N / 16))) + return 
make_layout(make_layout(get<0>(l), get<1>(l), get<2, 0>(l)), get<1>(acc_layout), + make_layout(get<2, 1>(l), get<2>(acc_layout))); +}; + +template +__forceinline__ __device__ auto convert_type(Tensor const& tensor) { + using From_type = typename Engine::value_type; + constexpr int numel = decltype(size(tensor))::value; + cutlass::NumericArrayConverter convert_op; + // HACK: this requires tensor to be "contiguous" + auto frag = convert_op(*reinterpret_cast*>(tensor.data())); + return make_tensor(make_rmem_ptr(&frag), tensor.layout()); +} + +template +__forceinline__ __device__ void gemm(TiledMma& tiled_mma, TensorA const& tCrA, TensorB const& tCrB, + TensorC& tCrC) { + constexpr bool Is_RS = + !cute::is_base_of::value; + // Need to cast away const on tCrA since warpgroup_fence_operand doesn't take const + if constexpr (Is_RS) { + warpgroup_fence_operand(const_cast(tCrA)); + } + warpgroup_fence_operand(tCrC); + warpgroup_arrive(); + if constexpr (init) { + tiled_mma.accumulate_ = GMMA::ScaleOut::Zero; + // Unroll the K mode manually to set scale D to 1 + CUTLASS_PRAGMA_UNROLL + for (int k_block = 0; k_block < size<2>(tCrA); ++k_block) { + cute::gemm(tiled_mma, tCrA(_, _, k_block), tCrB(_, _, k_block), tCrC); + tiled_mma.accumulate_ = GMMA::ScaleOut::One; + } + } else { + // cute::gemm(tiled_mma, tCrA, tCrB, tCrC); + // Unroll the K mode manually to set scale D to 1 + CUTLASS_PRAGMA_UNROLL + for (int k_block = 0; k_block < size<2>(tCrA); ++k_block) { + cute::gemm(tiled_mma, tCrA(_, _, k_block), tCrB(_, _, k_block), tCrC); + tiled_mma.accumulate_ = GMMA::ScaleOut::One; + } + } + warpgroup_commit_batch(); + if constexpr (wg_wait >= 0) { + warpgroup_wait(); + } + warpgroup_fence_operand(tCrC); + if constexpr (Is_RS) { + warpgroup_fence_operand(const_cast(tCrA)); + } +} + +} // namespace flashinfer + +#endif // FLASHINFER_ATTENTION_HOPPER_UTILS_CUH_ diff --git a/include/flashinfer/attention/hopper/variants.cuh b/include/flashinfer/attention/hopper/variants.cuh new file mode 100644 index 00000000..75d7c7bc --- /dev/null +++ b/include/flashinfer/attention/hopper/variants.cuh @@ -0,0 +1,63 @@ +/* + * Copyright (c) 2024 by FlashInfer team. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +// NOTE(Zihao): we should merge this with include/flashinfer/attention/variants.cuh in the future +#ifndef FLASHINFER_ATTENTION_HOPPER_VARIANTS_CUH_ +#define FLASHINFER_ATTENTION_HOPPER_VARIANTS_CUH_ +#include + +#include "../../math.cuh" +#include "attention_updater.cuh" + +namespace flashinfer { + +struct StandardAttention { + template + using Updater = OnlineSoftmaxWithScale; + + template + __device__ StandardAttention(const ParamsT& params) {} + + template + __device__ __forceinline__ T LogitsTransform(const ParamsT& params, T logits, uint32_t batch_idx, + uint32_t qo_idx, uint32_t kv_idx, + uint32_t qo_head_idx, uint32_t kv_head_idx) { + return logits; + } +}; + +struct LogitsSoftCap { + float pre_tanh_scale; + float post_tanh_scale; + template + using Updater = OnlineSoftmaxWithoutScale; + + template + __device__ LogitsSoftCap(const ParamsT& params) { + pre_tanh_scale = (params.sm_scale_log2 * math::loge2) * math::ptx_rcp(params.logits_soft_cap); + post_tanh_scale = math::log2e * params.logits_soft_cap; + } + + template + __device__ __forceinline__ T LogitsTransform(const ParamsT& params, T logits, uint32_t batch_idx, + uint32_t qo_idx, uint32_t kv_idx, + uint32_t qo_head_idx, uint32_t kv_head_idx) { + return math::tanh(logits * pre_tanh_scale) * post_tanh_scale; + } +}; + +} // namespace flashinfer + +#endif // FLASHINFER_ATTENTION_HOPPER_VARIANTS_CUH_ diff --git a/include/flashinfer/attention/scheduler.cuh b/include/flashinfer/attention/scheduler.cuh index 64cf106f..f8023171 100644 --- a/include/flashinfer/attention/scheduler.cuh +++ b/include/flashinfer/attention/scheduler.cuh @@ -22,13 +22,13 @@ #include #include #include -#include #include #include "../allocator.h" #include "../exception.h" #include "../pos_enc.cuh" #include "../utils.cuh" +#include "heap.h" namespace flashinfer { @@ -720,5 +720,196 @@ inline cudaError_t PrefillPlan(void* float_buffer, size_t float_workspace_size_i return cudaSuccess; } +inline float cost_function(int qo_len, int kv_len, int group_size) { + return 2 * float(qo_len) * float(group_size) + kv_len; +} + +template +std::vector flatten(const std::vector>& vec, int size_after_flatten) { + std::vector result; + result.reserve(size_after_flatten); + for (const auto& inner_vec : vec) { + result.insert(result.end(), inner_vec.begin(), inner_vec.end()); + } + return std::move(result); +} + +struct PrefillPlanSM90Info { + int64_t qo_tile_indices_offset; + int64_t qo_indptr_offset; + int64_t kv_indptr_offset; + int64_t qo_len_offset; + int64_t kv_len_offset; + int64_t head_indices_offset; + int64_t work_indptr_offset; + + PrefillPlanSM90Info() + : qo_tile_indices_offset(0), + qo_indptr_offset(0), + kv_indptr_offset(0), + qo_len_offset(0), + kv_len_offset(0), + head_indices_offset(0), + work_indptr_offset(0) {} + + // convert PrefillPlanSM90Info to std::vector + std::vector ToVector() const { + return {qo_tile_indices_offset, qo_indptr_offset, kv_indptr_offset, qo_len_offset, + kv_len_offset, head_indices_offset, work_indptr_offset}; + } + + // From std::vector to PrefillPlanSM90Info + void FromVector(const std::vector& vec) { + if (vec.size() != 7) { + std::ostringstream err_msg; + err_msg << "PrefillPlanSM90Info::FromVector: vec.size() should be 8, but got " << vec.size(); + FLASHINFER_ERROR(err_msg.str()); + } + qo_tile_indices_offset = vec[0]; + qo_indptr_offset = vec[1]; + kv_indptr_offset = vec[2]; + qo_len_offset = vec[3]; + kv_len_offset = vec[4]; + head_indices_offset = vec[5]; + work_indptr_offset = vec[6]; + } +}; + +template +cudaError_t 
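+// PrefillSM90Plan distributes (qo_head, q_tile) work items over the 132 persistent CTAs used by
+// the SM90 kernels: requests are sorted by kv_len in descending order, each q tile is assigned to
+// the CTA with the smallest accumulated cost (CTACostHeap), and the per-CTA work lists are then
+// flattened into the qo_tile_indices / qo_indptr / kv_indptr / qo_len / kv_len / head_indices /
+// work_indptr arrays that BatchPrefillTileScheduler consumes on the device.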
PrefillSM90Plan(void* float_buffer, size_t float_workspace_size_in_bytes, + void* int_buffer, void* page_locked_int_buffer, + size_t int_workspace_size_in_bytes, PrefillPlanSM90Info& plan_info, + IdType* qo_indptr_h, IdType* kv_indptr_h, IdType* kv_len_arr_h, + uint32_t batch_size, uint32_t num_qo_heads, uint32_t num_kv_heads, + uint32_t head_dim, uint32_t page_size, bool causal, + bool enable_cuda_graph, uint32_t sizeof_dtype_o, cudaStream_t stream) { + if (num_qo_heads % num_kv_heads != 0) { + std::ostringstream err_msg; + err_msg << "num_qo_heads " << num_qo_heads << " should be divisible by num_kv_heads " + << num_kv_heads; + FLASHINFER_ERROR(err_msg.str()); + } + + std::vector> idx_qo_kv_len_vec; + for (uint32_t i = 0; i < batch_size; ++i) { + int qo_len = qo_indptr_h[i + 1] - qo_indptr_h[i]; + int kv_len = kv_len_arr_h[i]; + if (kv_len < 0) { + std::ostringstream err_msg; + err_msg << "kv_len[" << i << "]" << kv_len << " should be non-negative"; + FLASHINFER_ERROR(err_msg.str()); + } + if (qo_len < 0) { + std::ostringstream err_msg; + err_msg << "qo_indptr[" << i + 1 << "]" << qo_indptr_h[i + 1] << " - qo_indptr[" << i << "]" + << qo_indptr_h[i] << " should be non-negative"; + FLASHINFER_ERROR(err_msg.str()); + } + idx_qo_kv_len_vec.push_back({i, qo_len, kv_len}); + } + + std::sort(idx_qo_kv_len_vec.begin(), idx_qo_kv_len_vec.end(), + [](const auto& a, const auto& b) { return std::get<2>(a) > std::get<2>(b); }); + int cta_tile_q = 128; + if (head_dim == 64) { + cta_tile_q = 192; + } + + const int num_sm90_ctas = 132; // for sm90, the num_ctas is fixed + + CTACostHeap cta_cost_heap(num_sm90_ctas); + std::vector> cta_qo_tile_indices(num_sm90_ctas, std::vector()), + cta_qo_indptr(num_sm90_ctas, std::vector()), + cta_kv_indptr(num_sm90_ctas, std::vector()), + cta_qo_len(num_sm90_ctas, std::vector()), + cta_kv_len(num_sm90_ctas, std::vector()), + cta_head_indices(num_sm90_ctas, std::vector()); + + for (int qo_head_idx = 0; qo_head_idx < num_qo_heads; ++qo_head_idx) { + for (auto& [i, qo_len, kv_len] : idx_qo_kv_len_vec) { + int num_qo_tiles = ceil_div(qo_len, cta_tile_q); + for (int qo_tile_idx = num_qo_tiles - 1; qo_tile_idx >= 0; --qo_tile_idx) { + auto [cta_idx, accum_cost] = cta_cost_heap.pop(); + // NOTE(Zihao): our current FA3 implementation do not fuse query and group heads + // so the group_size in cost_function is always 1 + cta_cost_heap.insert( + {cta_idx, + accum_cost + cost_function(cta_tile_q, + causal + ? 
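+                             // with causal masking, this q tile only attends to a prefix of the
+                             // KV sequence, so discount CTA_TILE_Q keys for every later q tile: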
kv_len - (num_qo_tiles - qo_tile_idx - 1) * cta_tile_q + : kv_len, + /*group_size=*/1)}); + cta_qo_tile_indices[cta_idx].push_back(qo_tile_idx); + cta_qo_indptr[cta_idx].push_back(qo_indptr_h[i]); + cta_qo_len[cta_idx].push_back(qo_len); + cta_kv_indptr[cta_idx].push_back(kv_indptr_h[i]); + cta_kv_len[cta_idx].push_back(kv_len); + cta_head_indices[cta_idx].push_back(qo_head_idx); + } + } + } + + std::vector work_indptr_vec(num_sm90_ctas + 1, 0); + for (uint32_t i = 0; i < num_sm90_ctas; ++i) { + work_indptr_vec[i + 1] = work_indptr_vec[i] + cta_qo_tile_indices[i].size(); + } + IdType total_num_works = work_indptr_vec[num_sm90_ctas]; + auto qo_tile_indices_vec = flatten(cta_qo_tile_indices, total_num_works); + auto qo_indptr_vec = flatten(cta_qo_indptr, total_num_works); + auto kv_indptr_vec = flatten(cta_kv_indptr, total_num_works); + auto qo_len_vec = flatten(cta_qo_len, total_num_works); + auto kv_len_vec = flatten(cta_kv_len, total_num_works); + auto head_indices_vec = flatten(cta_head_indices, total_num_works); + + AlignedAllocator int_allocator(int_buffer, int_workspace_size_in_bytes); + const int max_total_num_works = 1048576; + if (total_num_works > max_total_num_works) { + std::ostringstream err_msg; + err_msg << "total_num_works " << total_num_works << " should be less than " + << max_total_num_works; + FLASHINFER_ERROR(err_msg.str()); + } + plan_info.qo_tile_indices_offset = int_allocator.aligned_alloc_offset( + sizeof(IdType) * max_total_num_works, 16, "batch_prefill_sm90_qo_tile_indices"); + plan_info.qo_indptr_offset = int_allocator.aligned_alloc_offset( + sizeof(IdType) * max_total_num_works, 16, "batch_prefill_sm90_qo_offset"); + plan_info.kv_indptr_offset = int_allocator.aligned_alloc_offset( + sizeof(IdType) * max_total_num_works, 16, "batch_prefill_sm90_kv_offset"); + plan_info.qo_len_offset = int_allocator.aligned_alloc_offset(sizeof(IdType) * max_total_num_works, + 16, "batch_prefill_sm90_qo_len"); + plan_info.kv_len_offset = int_allocator.aligned_alloc_offset(sizeof(IdType) * max_total_num_works, + 16, "batch_prefill_sm90_kv_len"); + plan_info.head_indices_offset = int_allocator.aligned_alloc_offset( + sizeof(IdType) * max_total_num_works, 16, "batch_prefill_sm90_head_indices"); + plan_info.work_indptr_offset = int_allocator.aligned_alloc_offset( + sizeof(IdType) * (num_sm90_ctas + 1), 16, "batch_prefill_sm90_work_indptr"); + + IdType* qo_tile_indices_h = + GetPtrFromBaseOffset(page_locked_int_buffer, plan_info.qo_tile_indices_offset); + IdType* qo_offset_h = + GetPtrFromBaseOffset(page_locked_int_buffer, plan_info.qo_indptr_offset); + IdType* kv_offset_h = + GetPtrFromBaseOffset(page_locked_int_buffer, plan_info.kv_indptr_offset); + IdType* qo_len_h = GetPtrFromBaseOffset(page_locked_int_buffer, plan_info.qo_len_offset); + IdType* kv_len_h = GetPtrFromBaseOffset(page_locked_int_buffer, plan_info.kv_len_offset); + IdType* head_indices_h = + GetPtrFromBaseOffset(page_locked_int_buffer, plan_info.head_indices_offset); + IdType* work_indptr_h = + GetPtrFromBaseOffset(page_locked_int_buffer, plan_info.work_indptr_offset); + + std::copy(qo_tile_indices_vec.begin(), qo_tile_indices_vec.end(), qo_tile_indices_h); + std::copy(qo_indptr_vec.begin(), qo_indptr_vec.end(), qo_offset_h); + std::copy(kv_indptr_vec.begin(), kv_indptr_vec.end(), kv_offset_h); + std::copy(qo_len_vec.begin(), qo_len_vec.end(), qo_len_h); + std::copy(kv_len_vec.begin(), kv_len_vec.end(), kv_len_h); + std::copy(head_indices_vec.begin(), head_indices_vec.end(), head_indices_h); + 
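+  // All plan arrays are staged in the page-locked host buffer; the single cudaMemcpyAsync below
+  // then moves the whole allocated region into the device int workspace on `stream`.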
std::copy(work_indptr_vec.begin(), work_indptr_vec.end(), work_indptr_h); + + size_t num_bytes_to_copy = int_allocator.num_allocated_bytes(); + FLASHINFER_CUDA_CALL(cudaMemcpyAsync(int_buffer, page_locked_int_buffer, num_bytes_to_copy, + cudaMemcpyHostToDevice, stream)); + return cudaSuccess; +} + } // namespace flashinfer #endif // FLASHINFER_ATTENTION_SCHEDULER_CUH_ diff --git a/include/flashinfer/cutlass_utils.cuh b/include/flashinfer/cutlass_utils.cuh index f6d3ef03..5102756a 100644 --- a/include/flashinfer/cutlass_utils.cuh +++ b/include/flashinfer/cutlass_utils.cuh @@ -44,29 +44,37 @@ namespace flashinfer { template struct cutlass_dtype { - using value = T; + using type = T; }; template <> struct cutlass_dtype { - using value = cutlass::half_t; + using type = cutlass::half_t; }; template <> struct cutlass_dtype { - using value = cutlass::bfloat16_t; + using type = cutlass::bfloat16_t; }; template <> struct cutlass_dtype<__nv_fp8_e4m3> { - using value = cutlass::float_e4m3_t; + using type = cutlass::float_e4m3_t; }; template <> struct cutlass_dtype<__nv_fp8_e5m2> { - using value = cutlass::float_e5m2_t; + using type = cutlass::float_e5m2_t; }; +template +using cutlass_dtype_t = typename cutlass_dtype::type; + +template +void compileTimeDebug(T&&) { + static_assert(sizeof(T) == 0, "Compile time debug"); +} + } // namespace flashinfer #endif // FLASHINFER_CUTLASS_UTILS_CUH_ diff --git a/include/flashinfer/page.cuh b/include/flashinfer/page.cuh index fa256be3..04034f19 100644 --- a/include/flashinfer/page.cuh +++ b/include/flashinfer/page.cuh @@ -16,6 +16,8 @@ #ifndef FLASHINFER_PAGE_CUH_ #define FLASHINFER_PAGE_CUH_ +#include + #include #include "fastdiv.cuh" @@ -280,6 +282,55 @@ __global__ void AppendPagedKVCacheKernel(paged_kv_t paged_kv, } } +template +__global__ void BlockSparseIndicesToVectorSparseOffsetsKernel( + IdType* __restrict__ block_sparse_indices, IdType* __restrict__ block_sparse_indptr, + IdType* __restrict__ vector_sparse_offsets, IdType* __restrict__ vector_sparse_indptr, + IdType* __restrict__ kv_lens, const uint32_t stride_block, const uint32_t stride_n, + const uint32_t batch_size, const uint_fastdiv block_size) { +#pragma unroll 1 + for (int b = blockIdx.x; b < batch_size; ++b) { +#pragma unroll 2 + for (int pos = threadIdx.x; pos < kv_lens[b]; pos += blockDim.x) { + uint32_t q, r; + block_size.divmod(pos, q, r); + vector_sparse_offsets[vector_sparse_indptr[b] + pos] = + block_sparse_indices[block_sparse_indptr[b] + q] * stride_block + r * stride_n; + } + } +} + +template +cudaError_t BlockSparseIndicesToVectorSparseOffset( + IdType* block_sparse_indices, IdType* block_sparse_indptr, IdType* vector_sparse_offsets, + IdType* vector_sparse_indptr, IdType* kv_lens, const int64_t stride_block, + const int64_t stride_n, const int64_t batch_size, const uint32_t block_size, + cudaStream_t stream = nullptr) { + int dev_id = 0; + int num_sms = 0; + FLASHINFER_CUDA_CALL(cudaGetDevice(&dev_id)); + FLASHINFER_CUDA_CALL(cudaDeviceGetAttribute(&num_sms, cudaDevAttrMultiProcessorCount, dev_id)); + + uint32_t num_threads = 512; + + uint_fastdiv block_size_fastdiv(block_size); + + auto kernel = BlockSparseIndicesToVectorSparseOffsetsKernel; + void* args[] = {(void*)&block_sparse_indices, + (void*)&block_sparse_indptr, + (void*)&vector_sparse_offsets, + (void*)&vector_sparse_indptr, + (void*)&kv_lens, + (void*)&stride_block, + (void*)&stride_n, + (void*)&batch_size, + (void*)&block_size_fastdiv}; + + FLASHINFER_CUDA_CALL(cudaLaunchKernel((void*)kernel, num_sms, num_threads, args, 0, 
stream)); + + return cudaSuccess; +} + /*! * \brief Append new keys/values to the paged key-value cache in the decode phase * \tparam DType The data type of the key-value cache diff --git a/licenses/LICENSE.cutlass.txt b/licenses/LICENSE.cutlass.txt new file mode 100644 index 00000000..52550084 --- /dev/null +++ b/licenses/LICENSE.cutlass.txt @@ -0,0 +1,27 @@ +Copyright (c) 2017 - 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +SPDX-License-Identifier: BSD-3-Clause + +Redistribution and use in source and binary forms, with or without +modification, are permitted provided that the following conditions are met: + +1. Redistributions of source code must retain the above copyright notice, this +list of conditions and the following disclaimer. + +2. Redistributions in binary form must reproduce the above copyright notice, +this list of conditions and the following disclaimer in the documentation +and/or other materials provided with the distribution. + +3. Neither the name of the copyright holder nor the names of its +contributors may be used to endorse or promote products derived from +this software without specific prior written permission. + +THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" +AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE +IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE +DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE +FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL +DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR +SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER +CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, +OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. diff --git a/licenses/LICENSE.flashattention3.txt b/licenses/LICENSE.flashattention3.txt new file mode 100644 index 00000000..5860e4b3 --- /dev/null +++ b/licenses/LICENSE.flashattention3.txt @@ -0,0 +1,29 @@ +BSD 3-Clause License + +Copyright (c) 2022, the respective contributors, as shown by the AUTHORS file. +All rights reserved. + +Redistribution and use in source and binary forms, with or without +modification, are permitted provided that the following conditions are met: + +* Redistributions of source code must retain the above copyright notice, this + list of conditions and the following disclaimer. + +* Redistributions in binary form must reproduce the above copyright notice, + this list of conditions and the following disclaimer in the documentation + and/or other materials provided with the distribution. + +* Neither the name of the copyright holder nor the names of its + contributors may be used to endorse or promote products derived from + this software without specific prior written permission. + +THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" +AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE +IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE +DISCLAIMED. 
IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE +FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL +DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR +SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER +CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, +OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. diff --git a/setup.py b/setup.py index 9bbf9895..7cb292f2 100644 --- a/setup.py +++ b/setup.py @@ -36,7 +36,9 @@ head_dims = list(map(int, head_dims)) pos_encoding_modes = list(map(int, pos_encoding_modes)) +pos_encoding_modes_sm90 = [mode for mode in pos_encoding_modes if mode != 2] allow_fp16_qk_reductions = list(map(int, allow_fp16_qk_reductions)) +allow_fp16_qk_reductions_sm90 = [mode for mode in allow_fp16_qk_reductions if mode != 1] mask_modes = list(map(int, mask_modes)) enable_aot = os.environ.get("FLASHINFER_ENABLE_AOT", "0") == "1" @@ -66,6 +68,7 @@ def generate_cuda() -> None: try: # no aot_build_utils in sdist sys.path.append(str(root)) from aot_build_utils.generate import get_instantiation_cu + from aot_build_utils.generate_sm90 import get_sm90_instantiation_cu except ImportError: return @@ -79,6 +82,15 @@ def generate_cuda() -> None: enable_bf16=enable_bf16, enable_fp8=enable_fp8, ) + ) + get_sm90_instantiation_cu( + argparse.Namespace( + path=gen_dir, + head_dims=head_dims, + pos_encoding_modes=pos_encoding_modes_sm90, + allow_fp16_qk_reductions=allow_fp16_qk_reductions_sm90, + mask_modes=mask_modes, + enable_bf16=enable_bf16, + ) ) aot_config_str = f"""prebuilt_ops_uri = set({aot_kernel_uris})""" (root / "flashinfer" / "jit" / "aot_config.py").write_text(aot_config_str) @@ -185,10 +197,15 @@ def __init__(self, *args, **kwargs) -> None: ] kernel_sm90_sources = [ "csrc/group_gemm_sm90.cu", - "csrc/flashinfer_gemm_sm90_ops.cu", + "csrc/single_prefill_sm90.cu", + "csrc/batch_prefill_sm90.cu", + "csrc/flashinfer_ops_sm90.cu", ] decode_sources = list(gen_dir.glob("*decode_head*.cu")) - prefill_sources = list(gen_dir.glob("*prefill_head*.cu")) + prefill_sources = [ + f for f in gen_dir.glob("*prefill_head*.cu") if "_sm90" not in f.name + ] + prefill_sm90_sources = list(gen_dir.glob("*prefill_head*_sm90.cu")) ext_modules = [ torch_cpp_ext.CUDAExtension( name="flashinfer._kernels", @@ -202,7 +219,7 @@ def __init__(self, *args, **kwargs) -> None: ), torch_cpp_ext.CUDAExtension( name="flashinfer._kernels_sm90", - sources=kernel_sm90_sources, + sources=kernel_sm90_sources + prefill_sm90_sources, include_dirs=include_dirs, extra_compile_args={ "cxx": cxx_flags, diff --git a/tests/test_block_sparse_indices_to_vector_sparse_offsets.py b/tests/test_block_sparse_indices_to_vector_sparse_offsets.py new file mode 100644 index 00000000..cf2ef003 --- /dev/null +++ b/tests/test_block_sparse_indices_to_vector_sparse_offsets.py @@ -0,0 +1,84 @@ +""" +Copyright (c) 2023 by FlashInfer team. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+See the License for the specific language governing permissions and +limitations under the License. +""" + +import pytest +import torch + +import flashinfer.page + + +@pytest.mark.parametrize("batch_size", [1, 7, 19, 128, 517]) +@pytest.mark.parametrize("kv_len", [97, 199, 2049, 31791]) +@pytest.mark.parametrize("block_size", [1, 3, 7, 16, 64, 79, 128]) +@pytest.mark.parametrize("stride_block", [128]) +@pytest.mark.parametrize("stride_n", [1]) +def test_block_sparse_indices_to_vector_sparse_offsets( + batch_size, kv_len, block_size, stride_block, stride_n +): + if batch_size * kv_len > 1048576: + pytest.skip("skip large test") + num_blocks_per_row = (kv_len + block_size - 1) // block_size + + block_sparse_indices = torch.arange( + batch_size * num_blocks_per_row, device="cuda", dtype=torch.int32 + ) + block_sparse_indptr = torch.arange( + 0, + batch_size * num_blocks_per_row + 1, + num_blocks_per_row, + device="cuda", + dtype=torch.int32, + ) + vector_sparse_offsets_buf = torch.zeros( + batch_size * kv_len, device="cuda", dtype=torch.int32 + ) + vector_sparse_indptr = torch.arange( + 0, batch_size * kv_len + 1, kv_len, device="cuda", dtype=torch.int32 + ) + kv_lens = torch.full((batch_size,), kv_len, device="cuda", dtype=torch.int32) + + vector_sparse_offsets = ( + flashinfer.page.block_sparse_indices_to_vector_sparse_offsets( + block_sparse_indices, + block_sparse_indptr, + vector_sparse_offsets_buf, + vector_sparse_indptr, + kv_lens, + stride_block, + stride_n, + block_size, + ) + ) + + # Check that the output is correct + for i in range(batch_size): + indices_i = block_sparse_indices[ + i * num_blocks_per_row : (i + 1) * num_blocks_per_row + ].cpu() + output_i = vector_sparse_offsets[ + vector_sparse_indptr[i] : vector_sparse_indptr[i + 1] + ].cpu() + + output_ref_i = ( + indices_i[torch.arange(0, kv_len, dtype=torch.int32) // block_size] + * stride_block + + (torch.arange(0, kv_len, dtype=torch.int32) % block_size) * stride_n + ) + torch.testing.assert_close(output_i, output_ref_i) + + +if __name__ == "__main__": + pass diff --git a/tests/test_hopper.py b/tests/test_hopper.py new file mode 100644 index 00000000..1fbad5ff --- /dev/null +++ b/tests/test_hopper.py @@ -0,0 +1,218 @@ +""" +Copyright (c) 2024 by FlashInfer team. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+""" + +import pytest +import torch + +import flashinfer + + +@pytest.mark.parametrize("seq_len", [11, 99, 1763, 9999, 32767]) +@pytest.mark.parametrize("num_qo_heads", [1, 4, 8]) +@pytest.mark.parametrize("num_kv_heads", [1, 4, 8]) +@pytest.mark.parametrize("causal", [False, True]) +@pytest.mark.parametrize("head_dim", [64, 128, 256]) +@pytest.mark.parametrize("logits_soft_cap", [0.0, 30.0]) +def test_single_prefill( + seq_len, num_qo_heads, num_kv_heads, causal, head_dim, logits_soft_cap +): + if num_qo_heads % num_kv_heads != 0: + pytest.skip("num_qo_heads must be divisible by num_kv_heads") + torch.random.manual_seed(123) + q = torch.randn(seq_len, num_qo_heads, head_dim, dtype=torch.half, device="cuda") + k = torch.randn(seq_len, num_kv_heads, head_dim, dtype=torch.half, device="cuda") + v = torch.randn(seq_len, num_kv_heads, head_dim, dtype=torch.half, device="cuda") + + o_sm80, lse_sm80 = flashinfer.single_prefill_with_kv_cache_return_lse( + q, + k, + v, + causal=causal, + logits_soft_cap=logits_soft_cap, + backend="fa2", + ) + + o_sm90, lse_sm90 = flashinfer.single_prefill_with_kv_cache_return_lse( + q, k, v, causal=causal, logits_soft_cap=logits_soft_cap, backend="fa3" + ) + torch.testing.assert_close(lse_sm80, lse_sm90, rtol=1e-3, atol=1e-3) + torch.testing.assert_close(o_sm80, o_sm90, rtol=1e-3, atol=1e-3) + + +@pytest.mark.parametrize("batch_size", [1, 4, 8, 16]) +@pytest.mark.parametrize("seq_len", [11, 99, 1763, 9999, 32767]) +@pytest.mark.parametrize("num_qo_heads", [1, 4, 8]) +@pytest.mark.parametrize("num_kv_heads", [1, 4, 8]) +@pytest.mark.parametrize("causal", [False, True]) +@pytest.mark.parametrize("head_dim", [128]) # [64, 128, 256]) +@pytest.mark.parametrize("logits_soft_cap", [0.0, 30.0]) +def test_batch_ragged_prefill( + batch_size, seq_len, num_qo_heads, num_kv_heads, causal, head_dim, logits_soft_cap +): + if num_qo_heads % num_kv_heads != 0: + pytest.skip("num_qo_heads must be divisible by num_kv_heads") + torch.random.manual_seed(42) + q = torch.randn( + batch_size * seq_len, num_qo_heads, head_dim, dtype=torch.half, device="cuda" + ) + k = torch.randn( + batch_size * seq_len, num_kv_heads, head_dim, dtype=torch.half, device="cuda" + ) + v = torch.randn( + batch_size * seq_len, num_kv_heads, head_dim, dtype=torch.half, device="cuda" + ) + + workspace_buffer = torch.empty( + 256 * 1024 * 1024, dtype=torch.uint8, device="cuda:0" + ) + + wrapper_sm80 = flashinfer.BatchPrefillWithRaggedKVCacheWrapper( + workspace_buffer, backend="fa2" + ) + + wrapper_sm90 = flashinfer.BatchPrefillWithRaggedKVCacheWrapper( + workspace_buffer, backend="fa3" + ) + + qo_indptr = torch.arange(0, batch_size * seq_len + 1, seq_len).int() + kv_indptr = torch.arange(0, batch_size * seq_len + 1, seq_len).int() + + wrapper_sm80.plan( + qo_indptr, + kv_indptr, + num_qo_heads, + num_kv_heads, + head_dim, + causal=causal, + logits_soft_cap=logits_soft_cap, + ) + o_sm80, lse_sm80 = wrapper_sm80.run_return_lse(q, k, v) + + wrapper_sm90.plan( + qo_indptr, + kv_indptr, + num_qo_heads, + num_kv_heads, + head_dim, + causal=causal, + logits_soft_cap=logits_soft_cap, + ) + o_sm90, lse_sm90 = wrapper_sm90.run_return_lse(q, k, v) + + torch.testing.assert_close(lse_sm80, lse_sm90, rtol=1e-3, atol=1e-3) + torch.testing.assert_close(o_sm80, o_sm90, rtol=1e-3, atol=1e-3) + + +@pytest.mark.parametrize("batch_size", [1, 4, 8, 16]) +@pytest.mark.parametrize("seq_len", [11, 12, 99, 1763, 9999, 32767]) +@pytest.mark.parametrize("page_size", [1]) # [1, 16]) +@pytest.mark.parametrize("num_qo_heads", [1, 4, 8]) 
+@pytest.mark.parametrize("num_kv_heads", [1, 4, 8]) +@pytest.mark.parametrize("causal", [False, True]) +@pytest.mark.parametrize("head_dim", [64, 128, 256]) +@pytest.mark.parametrize("logits_soft_cap", [0.0, 30.0]) +def test_batch_paged_prefill( + batch_size, + seq_len, + page_size, + num_qo_heads, + num_kv_heads, + causal, + head_dim, + logits_soft_cap, +): + if num_qo_heads % num_kv_heads != 0: + pytest.skip("num_qo_heads must be divisible by num_kv_heads") + torch.random.manual_seed(42) + q = torch.randn( + batch_size * seq_len, num_qo_heads, head_dim, dtype=torch.half, device="cuda" + ) + num_pages_per_request = (seq_len + page_size - 1) // page_size + k = torch.randn( + batch_size * num_pages_per_request, + page_size, + num_kv_heads, + head_dim, + dtype=torch.half, + device="cuda", + ) + v = torch.randn( + batch_size * num_pages_per_request, + page_size, + num_kv_heads, + head_dim, + dtype=torch.half, + device="cuda", + ) + + workspace_buffer = torch.empty( + 256 * 1024 * 1024, dtype=torch.uint8, device="cuda:0" + ) + + wrapper_sm80 = flashinfer.BatchPrefillWithPagedKVCacheWrapper( + workspace_buffer, backend="fa2" + ) + + wrapper_sm90 = flashinfer.BatchPrefillWithPagedKVCacheWrapper( + workspace_buffer, backend="fa3" + ) + + last_page_len = seq_len - (num_pages_per_request - 1) * page_size + qo_indptr = torch.arange(0, batch_size * seq_len + 1, seq_len).int() + kv_indptr = torch.arange( + 0, batch_size * num_pages_per_request + 1, num_pages_per_request + ).int() + kv_indices = torch.arange(0, batch_size * num_pages_per_request).int() + last_page_len = torch.full((batch_size,), last_page_len, dtype=torch.int32) + + wrapper_sm80.plan( + qo_indptr, + kv_indptr, + kv_indices, + last_page_len, + num_qo_heads, + num_kv_heads, + head_dim, + page_size, + causal=causal, + logits_soft_cap=logits_soft_cap, + ) + o_sm80, lse_sm80 = wrapper_sm80.run_return_lse(q, (k, v)) + + wrapper_sm90.plan( + qo_indptr, + kv_indptr, + kv_indices, + last_page_len, + num_qo_heads, + num_kv_heads, + head_dim, + page_size, + causal=causal, + logits_soft_cap=logits_soft_cap, + ) + o_sm90, lse_sm90 = wrapper_sm90.run_return_lse(q, (k, v)) + + torch.testing.assert_close(lse_sm80, lse_sm90, rtol=1e-3, atol=1e-3) + torch.testing.assert_close(o_sm80, o_sm90, rtol=1e-3, atol=1e-3) + + +if __name__ == "__main__": + # test_batch_prefill(14, 64, 32, 32, False, 128) + # test_batch_prefill(1, 32767, 8, 8, True, 128) + # test_single_prefill(64, 1, 1, False, 256) + # test_batch_paged_prefill(2, 32768, 1, 1, 1, False, 128) + test_batch_paged_prefill(16, 32767, 1, 8, 8, True, 128)
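+    # further manual runs (illustrative parameter picks taken from the pytest parametrization above):
+    # test_single_prefill(1763, 8, 8, True, 128, 30.0)
+    # test_batch_ragged_prefill(4, 99, 8, 8, True, 128, 0.0)
+    # test_batch_paged_prefill(16, 32767, 1, 8, 8, True, 128, 0.0)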