diff --git a/workflows/chatbot/inference/generate.py b/workflows/chatbot/inference/generate.py index b1653261122..51ed093fcae 100644 --- a/workflows/chatbot/inference/generate.py +++ b/workflows/chatbot/inference/generate.py @@ -817,6 +817,9 @@ def main(): from optimum.habana.utils import set_seed set_seed(args.seed) + else: + from transformers import set_seed + set_seed(args.seed) tokenizer_path = ( args.tokenizer_name if args.tokenizer_name is not None else base_model_path