From d12a36e66024a93d33ec61826a77d5a346c16869 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Omar=20=C3=84nwar?= Date: Wed, 14 Feb 2024 00:45:37 +0600 Subject: [PATCH] Update llama.py --- higgsfield/llama/llama.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/higgsfield/llama/llama.py b/higgsfield/llama/llama.py index 90e6068..63cbbdf 100644 --- a/higgsfield/llama/llama.py +++ b/higgsfield/llama/llama.py @@ -59,14 +59,14 @@ def __init__( if not checkpoint_path: if cpu_init_rank0: if rank == 0: - model = LlamaForCausalLM.from_pretrained(model_name) + model = LlamaForCausalLM.from_pretrained(model_name, use_cache=False) else: - llama_config = LlamaConfig.from_pretrained(model_name) + llama_config = LlamaConfig.from_pretrained(model_name, use_cache=False) with torch.device('meta'): model = LlamaForCausalLM(llama_config) else: - model = LlamaForCausalLM.from_pretrained(model_name) + model = LlamaForCausalLM.from_pretrained(model_name, use_cache=False) else: if not cpu_init_rank0: print("Ignoring cpu_init_rank0=False while loading model from checkpoint path") @@ -298,4 +298,4 @@ def __init__( precision, cpu_init_rank0, cpu_offload, - ) \ No newline at end of file + )