From 88409ca22aa3308bf7e954bc323a21e4137ccdd9 Mon Sep 17 00:00:00 2001 From: zhenwei-intel Date: Fri, 10 May 2024 09:33:03 +0000 Subject: [PATCH] script for phi3 Signed-off-by: zhenwei-intel --- .../quantization/run_generation_gpu_woq.py | 9 ++++++++- 1 file changed, 8 insertions(+), 1 deletion(-) diff --git a/examples/huggingface/pytorch/text-generation/quantization/run_generation_gpu_woq.py b/examples/huggingface/pytorch/text-generation/quantization/run_generation_gpu_woq.py index 12058ee1d7c..1145a7de632 100644 --- a/examples/huggingface/pytorch/text-generation/quantization/run_generation_gpu_woq.py +++ b/examples/huggingface/pytorch/text-generation/quantization/run_generation_gpu_woq.py @@ -10,6 +10,7 @@ from intel_extension_for_transformers.transformers import AutoModelForCausalLM, AutoRoundConfig, RtnConfig, GPTQConfig from intel_extension_for_transformers.transformers.llm.quantization.utils import convert_dtype_str2torch from transformers.utils import check_min_version +import contextlib parser = argparse.ArgumentParser() parser.add_argument( @@ -241,8 +242,10 @@ generate_kwargs = dict(do_sample=False, temperature=0.9, num_beams=args.num_beams) if args.profile_token_latency: ipex.transformers.optimize.convert_function(user_model, "greedy_search", _greedy_search) + ipex.transformers.optimize.convert_function(user_model, "_greedy_search", _greedy_search) if args.disable_optimize_transformers: ipex.transformers.optimize.convert_function(user_model, "beam_search", _beam_search) + ipex.transformers.optimize.convert_function(user_model, "_beam_search", _beam_search) user_model.config.token_latency = True total_time = 0.0 @@ -253,7 +256,11 @@ dtype=amp_dtype if amp_enabled else None, ): for i in range(num_iter + num_warmup): - with torch.autograd.profiler_legacy.profile(enabled=args.do_profiling, use_xpu=(args.device=="xpu"), record_shapes=False) as prof: + if args.do_profiling: + context = torch.autograd.profiler_legacy.profile(enabled=args.do_profiling, use_xpu=True, record_shapes=True) + else: + context = contextlib.nullcontext() + with context as prof: input_ids = tokenizer( prompt, return_tensors="pt").input_ids.to(args.device) tic = time.time()