BB: Using forward hooks #904
Conversation
Some changes:
- use forward hooks
- remove them with `try ... finally` to be 100% sure
- create `id_to_adapter_dict` on the fly
- had to make changes to `PeftModel`, because `LoraModel.forward` was never called otherwise
- the hook is a `partial`, not a local function
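Roughly, the pattern looks like this minimal sketch (the function names, the hook body, and the way LoRA layers are detected are illustrative assumptions, not the actual PR code):

```python
from functools import partial


def _pre_forward(target, args, kwargs, adapter_indices):
    # Module-level hook function, so it can be wrapped in a partial rather
    # than defined as a local closure; it forwards the per-sample adapter
    # indices to the LoRA layer as a kwarg.
    kwargs["adapter_indices"] = adapter_indices
    return args, kwargs


def forward_with_adapters(model, adapter_indices, *args, **kwargs):
    handles = []
    try:
        # Register a pre-forward hook on every LoRA layer.
        for module in model.modules():
            if hasattr(module, "lora_A"):  # assumption: how LoRA layers are found
                hook = partial(_pre_forward, adapter_indices=adapter_indices)
                handles.append(module.register_forward_pre_hook(hook, with_kwargs=True))
        return model(*args, **kwargs)
    finally:
        # try ... finally guarantees the hooks are removed even if forward raises.
        for handle in handles:
            handle.remove()
```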
Thank you so much @BenjaminBossan for working on a feasible way to support multiple LoRAs in a batch. Left a few comments!
```python
@@ -487,7 +487,10 @@ def get_base_model(self):
        """
        Returns the base model.
        """
        return self.base_model if self.active_peft_config.is_prompt_learning else self.base_model.model
        # TODO: why do we not always return self.base_model.model?
```
Because it doesn't exist for prompt learning methods.
```python
if self.disable_adapters:
    return final_output

id_to_adapter_dict = dict(enumerate(self.lora_A, start=1))  # 0 is base
```
Creating this on the fly for each call and on each module is redundant. Better to pass it as a kwarg.
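For illustration, a sketch of the suggestion (hypothetical helper name; the mapping is computed once per call at the `PeftModel` level and then passed down to the LoRA layers, rather than being rebuilt inside every module):

```python
def build_id_to_adapter_dict(peft_config):
    # Hypothetical helper: build the index -> adapter-name mapping once
    # per forward/generate call; index 0 is reserved for the base model.
    adapter_names = ["base"] + list(peft_config.keys())
    return dict(enumerate(adapter_names))
```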
```python
# adapter_names = ["base"] + list(self.peft_config.keys())
# return {adapter_name: index for index, adapter_name in enumerate(adapter_names)}

# def get_adapter_indices(self, adapter_names, return_tensor=True):
```
This util is required by the end user so that they can get the adapter indices by calling `model.get_adapter_indices(adapter_names)`. An example is shown below:
```python
def get_model_pred(prompts, adapter_names):
    model.eval()
    adapter_indices = model.get_adapter_indices(adapter_names).cuda()
    inputs = tokenizer(prompts, return_tensors="pt", padding=True)
    if "token_type_ids" in inputs:
        del inputs["token_type_ids"]
    inputs = {k: v.cuda() for k, v in inputs.items()}
    outputs = model.generate(
        **inputs,
        adapter_indices=adapter_indices,
        max_new_tokens=128,
        temperature=0.2,
        top_k=50,
        top_p=0.95,
        do_sample=True,
        repetition_penalty=1.1,
        eos_token_id=tokenizer.eos_token_id,
    )
    return tokenizer.batch_decode(outputs, skip_special_tokens=False)
```
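For reference, one possible body for the commented-out `get_adapter_indices` stub above, written as a `PeftModel` method (an assumption inferred from the usage shown here, not the PR's actual code):

```python
import torch

def get_adapter_indices(self, adapter_names, return_tensor=True):
    # Map each adapter name in the batch to its integer index;
    # index 0 is reserved for the base model.
    name_to_index = {"base": 0}
    name_to_index.update(
        {name: i for i, name in enumerate(self.peft_config.keys(), start=1)}
    )
    indices = [name_to_index[name] for name in adapter_names]
    return torch.tensor(indices) if return_tensor else indices
```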
```python
    return output


def _pre_forward(target, args, kwargs, adapter_indices):
```
Nice!
```python
@@ -363,3 +363,33 @@ def run_with_disable(config_kwargs, bias):
    @parameterized.expand(TEST_CASES)
    def test_adding_multiple_adapters_with_bias_raises(self, test_name, model_id, config_cls, config_kwargs):
        self._test_adding_multiple_adapters_with_bias_raises(model_id, config_cls, config_kwargs)

    def test_mixed_adapter_lora(self):
```
Nice test!!!
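As a rough illustration of what such a test might check (assumed shape, not the actual test body): a batch routed to several adapters should match running each adapter separately on its own row.

```python
import torch

def check_mixed_adapter_lora(model, inputs, adapter_indices):
    # Hypothetical check: the mixed-batch output for each row should equal
    # the output of a single-row forward pass using only that row's adapter.
    mixed = model(**inputs, adapter_indices=adapter_indices).logits
    for row in range(adapter_indices.shape[0]):
        single_inputs = {k: v[row : row + 1] for k, v in inputs.items()}
        single = model(**single_inputs, adapter_indices=adapter_indices[row : row + 1]).logits
        torch.testing.assert_close(mixed[row : row + 1], single)
```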
See #903 for the combined PR.