diff --git a/op_builder/xpu/inference.py b/op_builder/xpu/inference.py index 9114dcc2c315..a9ac4f84c2ca 100644 --- a/op_builder/xpu/inference.py +++ b/op_builder/xpu/inference.py @@ -30,7 +30,10 @@ def cxx_args(self): def load(self): try: - import intel_extension_for_pytorch.deepspeed - return intel_extension_for_pytorch.deepspeed.transformer_inference.transformer_inference + import intel_extension_for_pytorch + if hasattr(intel_extension_for_pytorch, "deepspeed"): + return intel_extension_for_pytorch.deepspeed.transformer_inference.transformer_inference + else: + return intel_extension_for_pytorch.xpu.deepspeed except ImportError: raise ImportError("Please install intel-extension-for-pytorch >= 2.1.30 to include DeepSpeed kernels.")