Skip to content

Commit

Permalink
Update llama.py
Browse files Browse the repository at this point in the history
  • Loading branch information
arpanetus authored Feb 13, 2024
1 parent f7a4fbf commit d12a36e
Showing 1 changed file with 4 additions and 4 deletions.
8 changes: 4 additions & 4 deletions higgsfield/llama/llama.py
Original file line number Diff line number Diff line change
Expand Up @@ -59,14 +59,14 @@ def __init__(
if not checkpoint_path:
if cpu_init_rank0:
if rank == 0:
model = LlamaForCausalLM.from_pretrained(model_name)
model = LlamaForCausalLM.from_pretrained(model_name, use_cache=False)
else:
llama_config = LlamaConfig.from_pretrained(model_name)
llama_config = LlamaConfig.from_pretrained(model_name, use_cache=False)

with torch.device('meta'):
model = LlamaForCausalLM(llama_config)
else:
model = LlamaForCausalLM.from_pretrained(model_name)
model = LlamaForCausalLM.from_pretrained(model_name, use_cache=False)
else:
if not cpu_init_rank0:
print("Ignoring cpu_init_rank0=False while loading model from checkpoint path")
Expand Down Expand Up @@ -298,4 +298,4 @@ def __init__(
precision,
cpu_init_rank0,
cpu_offload,
)
)

0 comments on commit d12a36e

Please sign in to comment.