diff --git a/llmfoundry/models/hf/hf_causal_lm.py b/llmfoundry/models/hf/hf_causal_lm.py
index 071310d69e..f7f372f5fa 100644
--- a/llmfoundry/models/hf/hf_causal_lm.py
+++ b/llmfoundry/models/hf/hf_causal_lm.py
@@ -284,7 +284,7 @@ def _autoset_attn_implementation_monkeypatch(
         # the different processes. To avoid this contention, we first create the model (on meta device) on local rank
         # zero. This will set up the transformers model cache and avoid the future contention.
         if dist.get_local_rank() == 0:
-            if os.path.isdir(pretrained_model_name_or_path):
+            if pretrained and os.path.isdir(pretrained_model_name_or_path):
                 with init_empty_weights(include_buffers=False):
                     with warnings.catch_warnings():
                         warnings.simplefilter('ignore', UserWarning)
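
For context, here is a minimal sketch of the rank-zero cache warm-up pattern this hunk guards. The names `dist`, `init_empty_weights`, `pretrained`, and `pretrained_model_name_or_path` come from the diff itself; the wrapper function `warm_transformers_cache` and the bare `from_pretrained` call are illustrative assumptions, not the actual surrounding code in the repo. The point of the change is that when `pretrained=False` (training from a config, not loading weights), there is nothing to download or cache, so the local-path check should not fire at all.

```python
# Minimal sketch, assuming `dist` is composer.utils.dist and
# `init_empty_weights` comes from accelerate, as in llm-foundry.
# `warm_transformers_cache` is a hypothetical name for illustration.
import os
import warnings

from accelerate import init_empty_weights
from composer.utils import dist
from transformers import AutoModelForCausalLM


def warm_transformers_cache(
    pretrained: bool,
    pretrained_model_name_or_path: str,
) -> None:
    """Pre-create the model (on the meta device) on local rank 0 only.

    Instantiating once on local rank 0 populates the transformers model
    cache, so the other local ranks do not contend for it when they each
    call ``from_pretrained`` afterwards.
    """
    if dist.get_local_rank() == 0:
        # After the fix: skip the isdir() check when we are not loading
        # pretrained weights, since a local checkpoint path is irrelevant
        # for a from-scratch (config-only) initialization.
        if pretrained and os.path.isdir(pretrained_model_name_or_path):
            with init_empty_weights(include_buffers=False):
                with warnings.catch_warnings():
                    warnings.simplefilter('ignore', UserWarning)
                    # Meta-device construction: no real weight tensors are
                    # allocated, but the cache/config setup still happens.
                    AutoModelForCausalLM.from_pretrained(
                        pretrained_model_name_or_path,
                    )
```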