BB: Using forward hooks #904

Conversation

@BenjaminBossan (Member) commented on Sep 5, 2023:

- Some changes: use forward hooks (a rough sketch of the pattern follows below).
- Remove the hook handles with try ... finally to be 100% sure they are cleaned up.
- Create id_to_adapter_dict on the fly.
- The hook is a partial, not a local function.
- Had to make changes to PeftModel, because LoraModel.forward was never called otherwise.
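A minimal, self-contained sketch of this pattern (illustrative only, not the actual code in this PR; forward_with_adapter_indices and lora_layers are made-up names, and with_kwargs=True requires PyTorch >= 2.0):

from functools import partial


def _pre_forward(target, args, kwargs, adapter_indices):
    # Forward pre-hook (with_kwargs=True): inject the per-sample adapter indices
    # into the wrapped layer's forward kwargs.
    kwargs["adapter_indices"] = adapter_indices
    return args, kwargs


def forward_with_adapter_indices(model, lora_layers, adapter_indices, *args, **kwargs):
    # Register one pre-hook per LoRA layer as a functools.partial (not a local
    # closure), run the forward pass, and always remove the handles afterwards.
    handles = []
    try:
        for layer in lora_layers:
            hook = partial(_pre_forward, adapter_indices=adapter_indices)
            handles.append(layer.register_forward_pre_hook(hook, with_kwargs=True))
        return model(*args, **kwargs)
    finally:
        # try ... finally guarantees cleanup even if the forward pass raises.
        for handle in handles:
            handle.remove()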

@pacman100 (Contributor) left a comment:


Thank you so much @BenjaminBossan for working on a feasible way to support multiple LoRAs in a batch. Left a few comments!

@@ -487,7 +487,10 @@ def get_base_model(self):
    """
    Returns the base model.
    """
    return self.base_model if self.active_peft_config.is_prompt_learning else self.base_model.model
    # TODO: why do we not always return self.base_model.model?
@pacman100 (Contributor):

Because it doesn't exist for prompt learning methods.
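For illustration, a quick way to see the difference (a sketch using the public PEFT API, with gpt2 as a stand-in model; the exact wrapped class depends on the model):

from transformers import AutoModelForCausalLM
from peft import LoraConfig, PromptTuningConfig, TaskType, get_peft_model

# With LoRA, base_model is the tuner (LoraModel) and base_model.model is the
# wrapped transformers model.
base = AutoModelForCausalLM.from_pretrained("gpt2")
lora_model = get_peft_model(base, LoraConfig(task_type=TaskType.CAUSAL_LM))
print(type(lora_model.base_model))        # LoraModel
print(type(lora_model.base_model.model))  # GPT2LMHeadModel

# With prompt learning, base_model is the transformers model itself, so there
# is no base_model.model to return.
base = AutoModelForCausalLM.from_pretrained("gpt2")
pt_config = PromptTuningConfig(task_type=TaskType.CAUSAL_LM, num_virtual_tokens=8)
pt_model = get_peft_model(base, pt_config)
print(type(pt_model.base_model))          # GPT2LMHeadModel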

if self.disable_adapters:
    return final_output

id_to_adapter_dict = dict(enumerate(self.lora_A, start=1))  # 0 is base
@pacman100 (Contributor):

Creating this on the fly for each call and on each module is redundant. Better to pass it as a kwarg.

# adapter_names = ["base"] + list(self.peft_config.keys())
# return {adapter_name: index for index, adapter_name in enumerate(adapter_names)}

# def get_adapter_indices(self, adapter_names, return_tensor=True):
@pacman100 (Contributor):

This util is required by the end user so that they can get the adapter indices by calling model.get_adapter_indices(adapter_names). An example is shown below:

def get_model_pred(prompts, adapter_names):
    model.eval()
    adapter_indices = model.get_adapter_indices(adapter_names).cuda()
    inputs = tokenizer(prompts, return_tensors="pt", padding=True)
    if "token_type_ids" in inputs:
        del inputs["token_type_ids"]
    inputs = {k: v.cuda() for k, v in inputs.items()}
    outputs = model.generate(
        **inputs,
        adapter_indices=adapter_indices,
        max_new_tokens=128,
        temperature=0.2,
        top_k=50,
        top_p=0.95,
        do_sample=True,
        repetition_penalty=1.1,
        eos_token_id=tokenizer.eos_token_id,
    )
    return tokenizer.batch_decode(outputs, skip_special_tokens=False)
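For context, a rough sketch of what such a utility could look like, based on the commented-out lines in the diff above (the final name, signature, and placement in the merged PR may differ):

import torch


def get_adapter_indices(model, adapter_names, return_tensor=True):
    # Index 0 is reserved for the base model, i.e. samples that use no adapter.
    known_adapters = ["base"] + list(model.peft_config.keys())
    name_to_index = {name: index for index, name in enumerate(known_adapters)}
    indices = [name_to_index[name] for name in adapter_names]
    return torch.tensor(indices) if return_tensor else indices

With this, ["base", "adapter1", "adapter1"] would map to tensor([0, 1, 1]), assuming "adapter1" is the first configured adapter.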

return output


def _pre_forward(target, args, kwargs, adapter_indices):
@pacman100 (Contributor):

Nice!

@@ -363,3 +363,33 @@ def run_with_disable(config_kwargs, bias):
@parameterized.expand(TEST_CASES)
def test_adding_multiple_adapters_with_bias_raises(self, test_name, model_id, config_cls, config_kwargs):
    self._test_adding_multiple_adapters_with_bias_raises(model_id, config_cls, config_kwargs)

def test_mixed_adapter_lora(self):
@pacman100 (Contributor):

Nice test!!!

@BenjaminBossan (Member, Author):

See #903 for the combined PR.
