Skip to content

Commit

Permalink
fix: allow load_adapter to use different device
Browse files Browse the repository at this point in the history
  • Loading branch information
yhZhai committed Apr 8, 2024
1 parent e07095a commit f1fe0f0
Showing 1 changed file with 5 additions and 2 deletions.
7 changes: 5 additions & 2 deletions src/peft/peft_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -794,7 +794,7 @@ def _update_offload(self, offload_index: dict[dict[str:str]], adapters_weights:
os.makedirs(base_name)
safe_save_file(safe_dict, new_fname, metadata=metadata)

def load_adapter(self, model_id: str, adapter_name: str, is_trainable: bool = False, **kwargs: Any):
def load_adapter(self, model_id: str, adapter_name: str, is_trainable: bool = False, torch_device: Optional[str] = None, **kwargs: Any):
"""
Load a trained adapter into the model.
Expand All @@ -811,13 +811,16 @@ def load_adapter(self, model_id: str, adapter_name: str, is_trainable: bool = Fa
is_trainable (`bool`, *optional*, defaults to `False`):
Whether the adapter should be trainable or not. If `False`, the adapter will be frozen and can only be
used for inference.
torch_device (`str`, *optional*, defaults to `None`):
The device to load the adapter onto. If `None`, the device will be inferred.
kwargs (`optional`):
Additional arguments to modify the way the adapter is loaded, e.g. the token for Hugging Face Hub.
"""
from .mapping import PEFT_TYPE_TO_CONFIG_MAPPING

hf_hub_download_kwargs, kwargs = self._split_kwargs(kwargs)
torch_device = infer_device()
if torch_device is None:
torch_device = infer_device()

if adapter_name not in self.peft_config:
# load the config
Expand Down

0 comments on commit f1fe0f0

Please sign in to comment.