[Inference] Update DygraphInferencePredictor (#9491)
* update DygraphInferencePredictor

* update batch_size
DrownFish19 authored Dec 2, 2024
1 parent d455181, commit 0b6284e
Showing 3 changed files with 22 additions and 10 deletions.
csrc/gpu/quant_int8.cu (4 additions, 0 deletions)
@@ -65,7 +65,11 @@ __forceinline__ __device__ hip_bfloat16 add_mul<hip_bfloat16>(hip_bfloat16 a, hi
 #else
 template<>
 __forceinline__ __device__ __nv_bfloat16 add_mul<__nv_bfloat16>(__nv_bfloat16 a, __nv_bfloat16 b, __nv_bfloat16 c) {
+#if __CUDA_ARCH__ >= 800
   return __hmul(__hadd(a, b), c);
+#else
+  return (static_cast<float>(a) + static_cast<float>(b)) * static_cast<float>(c);
+#endif
 }
 #endif

csrc/setup_cuda.py (10 additions, 7 deletions)
@@ -57,8 +57,7 @@ def strtobool(v):
 
 def get_gencode_flags():
     if not strtobool(os.getenv("FLAG_LLM_PDC", "False")):
-        prop = paddle.device.cuda.get_device_properties()
-        cc = prop.major * 10 + prop.minor
+        cc = get_sm_version()
         return ["-gencode", "arch=compute_{0},code=sm_{0}".format(cc)]
     else:
         # support more cuda archs
@@ -75,6 +74,7 @@ def get_gencode_flags():
 gencode_flags = get_gencode_flags()
 library_path = os.environ.get("LD_LIBRARY_PATH", "/usr/local/cuda/lib64")
 
+sm_version = get_sm_version()
 
 sources = [
     "./gpu/save_with_output.cc",
@@ -102,16 +102,11 @@ def get_gencode_flags():
     "./gpu/dequant_int8.cu",
     "./gpu/flash_attn_bwd.cc",
     "./gpu/tune_cublaslt_gemm.cu",
-    "./gpu/append_attention.cu",
-    "./gpu/append_attn/get_block_shape_and_split_kv_block.cu",
-    "./gpu/append_attn/decoder_write_cache_with_rope_kernel.cu",
-    "./gpu/append_attn/speculate_write_cache_with_rope_kernel.cu",
     "./gpu/sample_kernels/top_p_sampling_reject.cu",
     "./gpu/update_inputs_v2.cu",
     "./gpu/set_preids_token_penalty_multi_scores.cu",
     "./gpu/speculate_decoding_kernels/ngram_match.cc",
 ]
-sources += find_end_files("./gpu/append_attn/template_instantiation", ".cu")
 sources += find_end_files("./gpu/speculate_decoding_kernels", ".cu")
 
 nvcc_compile_args = gencode_flags
@@ -138,6 +133,14 @@ def get_gencode_flags():
 if cc >= 80:
     sources += ["gpu/int8_gemm_with_cutlass/gemm_dequant.cu"]
 
+    sources += [
+        "./gpu/append_attention.cu",
+        "./gpu/append_attn/get_block_shape_and_split_kv_block.cu",
+        "./gpu/append_attn/decoder_write_cache_with_rope_kernel.cu",
+        "./gpu/append_attn/speculate_write_cache_with_rope_kernel.cu",
+    ]
+    sources += find_end_files("./gpu/append_attn/template_instantiation", ".cu")
+
 if cc >= 89 and cuda_version >= 12.4:
     os.system("python utils/auto_gen_fp8_fp8_gemm_fused_kernels.py")
     os.system("python utils/auto_gen_fp8_fp8_dual_gemm_fused_kernels.py")
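For reference, get_sm_version() replaces the inline device query removed in the first hunk. The sketch below shows equivalent detection logic, assuming Paddle with CUDA support is available; the function names are illustrative rather than the repository's actual helper.

import paddle

def detect_sm_version() -> int:
    # Compute capability as an integer, e.g. major=8, minor=0 -> 80 (sm_80),
    # mirroring the removed "prop.major * 10 + prop.minor" expression.
    prop = paddle.device.cuda.get_device_properties()
    return prop.major * 10 + prop.minor

def single_arch_gencode_flags(cc: int) -> list:
    # nvcc flags targeting only the detected architecture, as in get_gencode_flags().
    return ["-gencode", "arch=compute_{0},code=sm_{0}".format(cc)]

if __name__ == "__main__":
    cc = detect_sm_version()
    print(single_arch_gencode_flags(cc))  # e.g. ['-gencode', 'arch=compute_80,code=sm_80']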
llm/predict/predictor.py (8 additions, 3 deletions)
@@ -733,10 +733,9 @@ def _infer(self, inputs: dict[str, paddle.Tensor]):
             inputs[key] = paddle.to_tensor(inputs[key])
 
         inputs["cache_kvs"] = self.cache_kvs
-        self.model.generate(
+        return self.model.generate(
             **inputs,
         )
-        return None
 
 
 class BlockInferencePredictorMixin(BasePredictor):
@@ -914,6 +913,12 @@ def init_model_inputs(self, config: PredictorArgument):
         self.model_inputs["rope_emb"] = paddle.concat([src_mask.reshape([-1]), tgt_mask.reshape([-1])])
 
     def _preprocess(self, input_text: list[str]):
+        len_input_text = len(input_text)
+        if len_input_text < self.batch_size:
+            padding_len = self.batch_size - len_input_text
+            input_text += [""] * padding_len
+        assert len(input_text) == self.batch_size
+
         if self.tokenizer.chat_template is not None:
             input_text = [input_text] if isinstance(input_text, str) else input_text
             input_text = [self.tokenizer.apply_chat_template(sentence, tokenize=False) for sentence in input_text]
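The padding added to _preprocess keeps the batch shape fixed: when fewer prompts than batch_size arrive, empty strings fill the remaining slots. A standalone sketch of the same idea, with illustrative names rather than the predictor class itself:

def pad_to_batch_size(input_text: list, batch_size: int) -> list:
    # Fill the tail of the batch with empty prompts so it always holds
    # exactly batch_size entries, as _preprocess now does in place.
    padded = list(input_text)
    if len(padded) < batch_size:
        padded += [""] * (batch_size - len(padded))
    assert len(padded) == batch_size
    return padded

if __name__ == "__main__":
    print(pad_to_batch_size(["hello", "world"], 4))  # ['hello', 'world', '', '']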
@@ -1073,7 +1078,7 @@ def predict(self, input_texts: list[str], return_tokens=False):
         if self.tensor_parallel_rank == 0:
             outputs = []
             output_tokens = []
-            while len(outputs) < self.batch_size:
+            while len(outputs) < len(input_texts):
                 result = result_queue.get(timeout=1)
                 outputs.append(result[-1])
                 output_tokens.append(result[-2])
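Together with the padding above, the new loop bound means results are collected per real prompt rather than per padded batch slot. A minimal sketch of that collection pattern, assuming a standard queue.Queue filled by a worker; the tuple layout here is hypothetical:

import queue

def collect_results(result_queue: "queue.Queue", num_real_inputs: int, timeout: float = 1.0):
    # Stop once every real (non-padding) prompt has produced a result,
    # mirroring "while len(outputs) < len(input_texts)".
    outputs, output_tokens = [], []
    while len(outputs) < num_real_inputs:
        result = result_queue.get(timeout=timeout)
        outputs.append(result[-1])        # last element: decoded text (assumed)
        output_tokens.append(result[-2])  # second to last: token ids (assumed)
    return outputs, output_tokens

if __name__ == "__main__":
    q = queue.Queue()
    q.put((0, [1, 2, 3], "hello"))  # (index, token_ids, text) -- hypothetical layout
    q.put((1, [4, 5], "world"))
    print(collect_results(q, num_real_inputs=2)[0])  # ['hello', 'world']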
