
Commit 3b3fc47
Merge pull request #170 from PanQiWei/temporarily_replace_prepare_inputs_for_generation

Replace base_model's function temporarily
pacman100 authored Mar 14, 2023
2 parents 321cbd6 + 644d68e commit 3b3fc47
Showing 1 changed file with 67 additions and 43 deletions.
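
The essence of the change, as a minimal standalone sketch (illustrative names, not the diff itself): save the original bound method once, install the PEFT-aware replacement only for the duration of a generate() call, and put the original back whether the call succeeds or raises. The commit uses an explicit except/else pair; a try/finally, as below, restores the method on every exit path just the same.

class TemporaryPatchSketch:
    """Hypothetical stand-in for the PeftModel wrappers touched by this diff."""

    def __init__(self, base_model):
        self.base_model = base_model
        # Save the original bound method once so generate() can always restore it.
        self.base_model_prepare_inputs_for_generation = base_model.prepare_inputs_for_generation

    def prepare_inputs_for_generation(self, *args, **kwargs):
        model_kwargs = self.base_model_prepare_inputs_for_generation(*args, **kwargs)
        # ... PEFT-specific adjustments (virtual tokens, prefixes) would go here ...
        return model_kwargs

    def generate(self, **kwargs):
        # Install the PEFT-aware hook only for the duration of this call.
        self.base_model.prepare_inputs_for_generation = self.prepare_inputs_for_generation
        try:
            return self.base_model.generate(**kwargs)
        finally:
            # Restore the original on success and on failure alike.
            self.base_model.prepare_inputs_for_generation = (
                self.base_model_prepare_inputs_for_generation
            )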
src/peft/peft_model.py (110 changes: 67 additions & 43 deletions)
@@ -513,7 +513,6 @@ class PeftModelForCausalLM(PeftModel):
     def __init__(self, model, peft_config: PeftConfig):
         super().__init__(model, peft_config)
         self.base_model_prepare_inputs_for_generation = self.base_model.prepare_inputs_for_generation
-        self.base_model.prepare_inputs_for_generation = self.prepare_inputs_for_generation
 
     def forward(
         self,
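
Why the patch leaves __init__: once installed there, the base model keeps routing through the wrapper's method even when no generation is in flight, which can leak PEFT behavior into any later direct use of the base model. A tiny self-contained demonstration of that hazard (toy classes, nothing PEFT-specific):

class Base:
    def prepare_inputs_for_generation(self, **kwargs):
        return {**kwargs, "source": "base"}

class Wrapper:
    def __init__(self, base):
        self._orig = base.prepare_inputs_for_generation
        # The permanent patch that this hunk deletes from __init__:
        base.prepare_inputs_for_generation = self.prepare_inputs_for_generation

    def prepare_inputs_for_generation(self, **kwargs):
        return {**self._orig(**kwargs), "source": "wrapper"}

base = Base()
Wrapper(base)
# The bare base model is now rerouted through the wrapper indefinitely:
print(base.prepare_inputs_for_generation()["source"])  # -> "wrapper"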
@@ -576,28 +575,38 @@ def forward(
         return self.base_model(inputs_embeds=inputs_embeds, **kwargs)
 
     def generate(self, **kwargs):
-        if not isinstance(self.peft_config, PromptLearningConfig):
-            return self.base_model.generate(**kwargs)
-        else:
-            if "input_ids" not in kwargs:
-                raise ValueError("input_ids must be provided for Peft model generation")
-            if kwargs.get("attention_mask", None) is not None:
-                # concat prompt attention mask
-                prefix_attention_mask = torch.ones(
-                    kwargs["input_ids"].shape[0], self.peft_config.num_virtual_tokens
-                ).to(kwargs["input_ids"].device)
-                kwargs["attention_mask"] = torch.cat((prefix_attention_mask, kwargs["attention_mask"]), dim=1)
-
-            if kwargs.get("position_ids", None) is not None:
-                warnings.warn("Position ids are not supported for parameter efficient tuning. Ignoring position ids.")
-                kwargs["position_ids"] = None
-            if kwargs.get("token_type_ids", None) is not None:
-                warnings.warn(
-                    "Token type ids are not supported for parameter efficient tuning. Ignoring token type ids"
-                )
-                kwargs["token_type_ids"] = None
-
-            return self.base_model.generate(**kwargs)
+        self.base_model.prepare_inputs_for_generation = self.prepare_inputs_for_generation
+        try:
+            if not isinstance(self.peft_config, PromptLearningConfig):
+                outputs = self.base_model.generate(**kwargs)
+            else:
+                if "input_ids" not in kwargs:
+                    raise ValueError("input_ids must be provided for Peft model generation")
+                if kwargs.get("attention_mask", None) is not None:
+                    # concat prompt attention mask
+                    prefix_attention_mask = torch.ones(
+                        kwargs["input_ids"].shape[0], self.peft_config.num_virtual_tokens
+                    ).to(kwargs["input_ids"].device)
+                    kwargs["attention_mask"] = torch.cat((prefix_attention_mask, kwargs["attention_mask"]), dim=1)
+
+                if kwargs.get("position_ids", None) is not None:
+                    warnings.warn(
+                        "Position ids are not supported for parameter efficient tuning. Ignoring position ids."
+                    )
+                    kwargs["position_ids"] = None
+                if kwargs.get("token_type_ids", None) is not None:
+                    warnings.warn(
+                        "Token type ids are not supported for parameter efficient tuning. Ignoring token type ids"
+                    )
+                    kwargs["token_type_ids"] = None
+
+                outputs = self.base_model.generate(**kwargs)
+        except:
+            self.base_model.prepare_inputs_for_generation = self.base_model_prepare_inputs_for_generation
+            raise
+        else:
+            self.base_model.prepare_inputs_for_generation = self.base_model_prepare_inputs_for_generation
+            return outputs
 
     def prepare_inputs_for_generation(self, *args, **kwargs):
         model_kwargs = self.base_model_prepare_inputs_for_generation(*args, **kwargs)
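
To make the attention-mask handling in the new generate() concrete: for prompt-learning methods, every virtual token must be attended to, so a block of ones of width num_virtual_tokens is prepended to the user's mask. A shape-only check (no model needed; sizes are illustrative):

import torch

batch_size, seq_len, num_virtual_tokens = 2, 5, 8
input_ids = torch.randint(0, 100, (batch_size, seq_len))
attention_mask = torch.ones(batch_size, seq_len)

# Mirrors the concat inside generate(): virtual prompt tokens are never masked out.
prefix_attention_mask = torch.ones(input_ids.shape[0], num_virtual_tokens).to(input_ids.device)
attention_mask = torch.cat((prefix_attention_mask, attention_mask), dim=1)

print(attention_mask.shape)  # torch.Size([2, 13]), i.e. (batch, num_virtual_tokens + seq_len)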
@@ -641,13 +650,9 @@ class PeftModelForSeq2SeqLM(PeftModel):
     def __init__(self, model, peft_config: PeftConfig):
         super().__init__(model, peft_config)
         self.base_model_prepare_inputs_for_generation = self.base_model.prepare_inputs_for_generation
-        self.base_model.prepare_inputs_for_generation = self.prepare_inputs_for_generation
         self.base_model_prepare_encoder_decoder_kwargs_for_generation = (
             self.base_model._prepare_encoder_decoder_kwargs_for_generation
         )
-        self.base_model._prepare_encoder_decoder_kwargs_for_generation = (
-            self._prepare_encoder_decoder_kwargs_for_generation
-        )
 
     def forward(
         self,
@@ -740,24 +745,43 @@ def forward(
         )
 
     def generate(self, **kwargs):
-        if not isinstance(self.peft_config, PromptLearningConfig):
-            return self.base_model.generate(**kwargs)
-        else:
-            if "input_ids" not in kwargs:
-                raise ValueError("input_ids must be provided for Peft model generation")
-            if kwargs.get("position_ids", None) is not None:
-                warnings.warn("Position ids are not supported for parameter efficient tuning. Ignoring position ids.")
-                kwargs["position_ids"] = None
-            if kwargs.get("token_type_ids", None) is not None:
-                warnings.warn(
-                    "Token type ids are not supported for parameter efficient tuning. Ignoring token type ids"
-                )
-                kwargs["token_type_ids"] = None
-
-            if self.peft_config.peft_type == PeftType.PREFIX_TUNING:
-                return self.base_model.generate(**kwargs)
-            else:
-                raise NotImplementedError
+        self.base_model.prepare_inputs_for_generation = self.prepare_inputs_for_generation
+        self.base_model._prepare_encoder_decoder_kwargs_for_generation = (
+            self._prepare_encoder_decoder_kwargs_for_generation
+        )
+        try:
+            if not isinstance(self.peft_config, PromptLearningConfig):
+                outputs = self.base_model.generate(**kwargs)
+            else:
+                if "input_ids" not in kwargs:
+                    raise ValueError("input_ids must be provided for Peft model generation")
+                if kwargs.get("position_ids", None) is not None:
+                    warnings.warn(
+                        "Position ids are not supported for parameter efficient tuning. Ignoring position ids."
+                    )
+                    kwargs["position_ids"] = None
+                if kwargs.get("token_type_ids", None) is not None:
+                    warnings.warn(
+                        "Token type ids are not supported for parameter efficient tuning. Ignoring token type ids"
+                    )
+                    kwargs["token_type_ids"] = None
+
+                if self.peft_config.peft_type == PeftType.PREFIX_TUNING:
+                    outputs = self.base_model.generate(**kwargs)
+                else:
+                    raise NotImplementedError
+        except:
+            self.base_model.prepare_inputs_for_generation = self.base_model_prepare_inputs_for_generation
+            self.base_model._prepare_encoder_decoder_kwargs_for_generation = (
+                self.base_model_prepare_encoder_decoder_kwargs_for_generation
+            )
+            raise
+        else:
+            self.base_model.prepare_inputs_for_generation = self.base_model_prepare_inputs_for_generation
+            self.base_model._prepare_encoder_decoder_kwargs_for_generation = (
+                self.base_model_prepare_encoder_decoder_kwargs_for_generation
+            )
+            return outputs
 
     def prepare_inputs_for_generation(self, *args, **kwargs):
         model_kwargs = self.base_model_prepare_inputs_for_generation(*args, **kwargs)
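
Among prompt-learning methods, the seq2seq path above only implements generation for prefix tuning; anything else under PromptLearningConfig raises NotImplementedError, while non-prompt-learning adapters such as LoRA fall straight through to the base model. A condensed, self-contained mirror of that dispatch (stub config classes; the real code checks peft_config.peft_type == PeftType.PREFIX_TUNING rather than isinstance):

class PromptLearningConfig: ...
class PrefixTuningConfig(PromptLearningConfig): ...
class PromptTuningConfig(PromptLearningConfig): ...
class LoraConfig: ...

def seq2seq_generate_dispatch(peft_config):
    """Which path a seq2seq PEFT model takes inside generate()."""
    if not isinstance(peft_config, PromptLearningConfig):
        return "base_model.generate as-is"
    if isinstance(peft_config, PrefixTuningConfig):
        return "base_model.generate with temporarily patched hooks"
    raise NotImplementedError("seq2seq generation for this prompt-learning type")

print(seq2seq_generate_dispatch(LoraConfig()))          # base_model.generate as-is
print(seq2seq_generate_dispatch(PrefixTuningConfig()))  # ... with temporarily patched hooks
# seq2seq_generate_dispatch(PromptTuningConfig())       # would raise NotImplementedError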
