FIX: Generating with mixed adapter batches and with beam search enabled #2287

Open
wants to merge 5 commits into main
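Before the diff, here is a minimal sketch (not part of the changed files) of the call pattern this fix targets: generating with per-sample `adapter_names` while beam search is enabled. The model name, prompt, and generation settings are illustrative assumptions; the adapter setup mirrors the tests added in this PR.

```python
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer
from peft import LoraConfig, get_peft_model

# illustrative base model; any small causal LM with a default LoRA target mapping works
base = AutoModelForCausalLM.from_pretrained("facebook/opt-125m")
tokenizer = AutoTokenizer.from_pretrained("facebook/opt-125m")

config = LoraConfig(task_type="CAUSAL_LM", init_lora_weights=False)
model = get_peft_model(base, config, adapter_name="adapter0").eval()
model.add_adapter("adapter1", config)

inputs = tokenizer(["A sample prompt"] * 3, return_tensors="pt")

with torch.inference_mode():
    out = model.generate(
        **inputs,
        # one adapter name per input sample; "__base__" selects the base model
        adapter_names=["__base__", "adapter0", "adapter1"],
        num_beams=5,      # beam search is what previously broke mixed adapter batches
        max_length=20,
    )
print(tokenizer.batch_decode(out, skip_special_tokens=True))
```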
21 changes: 21 additions & 0 deletions src/peft/tuners/lora/model.py
@@ -458,13 +458,34 @@ def _enable_peft_forward_hooks(self, *args, **kwargs):
if unexpected_adapters:
raise ValueError(f"Trying to infer with non-existing adapter(s): {', '.join(sorted(unexpected_adapters))}")

# deal with beam search
num_beams = kwargs.get("num_beams", None)
uses_beam_search = isinstance(num_beams, int) and (num_beams > 1)
original_adapter_names = adapter_names[:]
if uses_beam_search:
if not isinstance(adapter_names, (list, tuple)):
raise TypeError(f"Got adapter names of type {type(adapter_names)}, expected a list of str.")
# When beam search is used, the inputs are repeated num_beams times, so we repeat each adapter name num_beams
# times and then flatten the nested list. For encoder-decoder models, this extended list should not be applied
# to the encoder part.
adapter_names = sum(([n] * num_beams for n in adapter_names), [])

hook_handles = []
for module in self.modules():
if isinstance(module, LoraLayer) or isinstance(module, ModulesToSaveWrapper):
pre_forward = partial(_adapter_names_pre_forward_hook, adapter_names=adapter_names)
handle = module.register_forward_pre_hook(pre_forward, with_kwargs=True)
hook_handles.append(handle)

if uses_beam_search and hasattr(self.model, "get_encoder"):
# For encoder-decoder models, even when applying beam search, the encoder part of the model should not use
# the extended adapter_names.
for module in self.model.get_encoder().modules():
if isinstance(module, LoraLayer) or isinstance(module, ModulesToSaveWrapper):
pre_forward = partial(_adapter_names_pre_forward_hook, adapter_names=original_adapter_names)
handle = module.register_forward_pre_hook(pre_forward, with_kwargs=True)
hook_handles.append(handle)

yield

for handle in hook_handles:
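For reference outside the diff, a small standalone sketch of the two mechanisms used in the hunk above: expanding the per-sample adapter names to match the beam-expanded batch, and injecting the list into module calls through a forward pre-hook registered with `with_kwargs=True` (the real hook is `_adapter_names_pre_forward_hook`; the `Echo` module and the concrete names here are hypothetical).

```python
import torch
from functools import partial

# 1) Beam search repeats each input sample num_beams times, so each adapter name is
#    repeated num_beams times and the nested list is flattened (as in the hunk above).
adapter_names = ["__base__", "adapter0", "adapter1"]
num_beams = 2
expanded = sum(([name] * num_beams for name in adapter_names), [])
print(expanded)  # ['__base__', '__base__', 'adapter0', 'adapter0', 'adapter1', 'adapter1']

# 2) A forward pre-hook registered with with_kwargs=True can inject the list into every
#    forward call of a module, which is how the names reach the LoRA layers during generate.
class Echo(torch.nn.Module):
    def forward(self, x, adapter_names=None):
        print("adapter_names seen by forward:", adapter_names)
        return x

def pre_forward(module, args, kwargs, adapter_names):
    kwargs["adapter_names"] = adapter_names
    return args, kwargs

module = Echo()
handle = module.register_forward_pre_hook(partial(pre_forward, adapter_names=expanded), with_kwargs=True)
module(torch.zeros(len(expanded), 4))
handle.remove()
```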
12 changes: 12 additions & 0 deletions tests/test_decoder_models.py
@@ -303,6 +303,18 @@ def test_merge_layers_nan(self, test_name, model_id, config_cls, config_kwargs):
def test_mixed_adapter_batches(self, test_name, model_id, config_cls, config_kwargs):
self._test_mixed_adapter_batches(model_id, config_cls, config_kwargs)

@parameterized.expand(
PeftTestConfigManager.get_grid_parameters(
{
"model_ids": PEFT_DECODER_MODELS_TO_TEST,
"lora_kwargs": {"init_lora_weights": [False]},
"task_type": "CAUSAL_LM",
},
)
)
def test_generate_with_mixed_adapter_batches(self, test_name, model_id, config_cls, config_kwargs):
self._test_generate_with_mixed_adapter_batches_and_beam_search(model_id, config_cls, config_kwargs)

@parameterized.expand(
PeftTestConfigManager.get_grid_parameters(FULL_GRID, filter_params_func=skip_oft_or_hra_and_gpt2)
)
12 changes: 12 additions & 0 deletions tests/test_encoder_decoder_models.py
@@ -118,6 +118,18 @@ def test_merge_layers(self, test_name, model_id, config_cls, config_kwargs):
def test_mixed_adapter_batches(self, test_name, model_id, config_cls, config_kwargs):
self._test_mixed_adapter_batches(model_id, config_cls, config_kwargs)

@parameterized.expand(
PeftTestConfigManager.get_grid_parameters(
{
"model_ids": PEFT_ENCODER_DECODER_MODELS_TO_TEST,
"lora_kwargs": {"init_lora_weights": [False]},
"task_type": "SEQ_2_SEQ_LM",
},
)
)
def test_generate_with_mixed_adapter_batches(self, test_name, model_id, config_cls, config_kwargs):
self._test_generate_with_mixed_adapter_batches_and_beam_search(model_id, config_cls, config_kwargs)

# skip non lora models - generate does not work for prefix tuning, prompt tuning
@parameterized.expand(PeftTestConfigManager.get_grid_parameters(FULL_GRID))
def test_generate(self, test_name, model_id, config_cls, config_kwargs):
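The encoder-decoder tests above exercise the `get_encoder()` branch added in `model.py`. As a rough illustration (hypothetical numbers): for encoder-decoder models, `generate()` runs the encoder on the original batch and only the decoder on the beam-expanded batch, so the encoder must keep the original `adapter_names` while the decoder needs the expanded list.

```python
batch_size, num_beams = 3, 10
adapter_names = ["__base__", "adapter0", "adapter1"]  # one name per input sample

expanded = [name for name in adapter_names for _ in range(num_beams)]

# the encoder sees batch_size samples, the decoder sees batch_size * num_beams beam hypotheses
assert len(adapter_names) == batch_size            # 3
assert len(expanded) == batch_size * num_beams     # 30
```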
80 changes: 80 additions & 0 deletions tests/testing_common.py
@@ -938,6 +938,86 @@ def _test_mixed_adapter_batches(self, model_id, config_cls, config_kwargs):
assert torch.allclose(logits_adapter0[1::3], logits_mixed[1::3], atol=atol, rtol=rtol)
assert torch.allclose(logits_adapter1[2::3], logits_mixed[2::3], atol=atol, rtol=rtol)

def _test_generate_with_mixed_adapter_batches_and_beam_search(self, model_id, config_cls, config_kwargs):
# Test generating with beam search and with mixing different adapters in a single batch by passing the
# adapter_names argument. See #2283.
if config_cls not in (LoraConfig,):
return pytest.skip(f"Mixed adapter batches not supported for {config_cls}")

config = config_cls(
base_model_name_or_path=model_id,
**config_kwargs,
)

torch.manual_seed(0)
model = self.transformers_class.from_pretrained(model_id)
model = get_peft_model(model, config, adapter_name="adapter0").eval()
model.add_adapter("adapter1", config)

# In contrast to forward, generate can sometimes return the same results as the base model even with LoRA
# applied, because the impact of LoRA is not big enough. Therefore, use this "trick" of scaling up the LoRA
# weights to make their effect stronger.
for name, param in model.named_parameters():
if model.base_model.prefix in name:
param.data.mul_(10.0)

model = model.to(self.torch_device).eval()

dummy_input = self.prepare_inputs_for_testing()
# ensure that we have at least 3 samples for this test
dummy_input = {k: torch.cat([v for _ in range(3)]) for k, v in dummy_input.items()}

gen_kwargs = {**dummy_input, "max_length": 20, "num_beams": 10, "early_stopping": True}
with torch.inference_mode():
with model.disable_adapter():
gen_base = model.generate(**gen_kwargs)

model.set_adapter("adapter0")
with torch.inference_mode():
gen_adapter0 = model.generate(**gen_kwargs)

model.set_adapter("adapter1")
with torch.inference_mode():
gen_adapter1 = model.generate(**gen_kwargs)

def remove_padding(seq, pad_value):
lst = list(seq)
while lst and (lst[-1] == pad_value):
lst.pop()
return lst

def gens_are_same(gen0, gen1):
# Special function to compare generations. We cannot use torch.allclose because it will raise an error when
# the sequence lengths differ. Moreover, we need to remove the padding from the sequences. This is because,
# even though normally identical sequences should have the same length, when we do mixed adapter batches, each
# sample will be padded to the longest sequence in that mixed batch, which can be different from the longest
# sequence without mixed adapter batches.
pad_value = model.config.eos_token_id
for sample0, sample1 in zip(gen0, gen1):
sample0 = remove_padding(sample0, pad_value)
sample1 = remove_padding(sample1, pad_value)
if (len(sample0) != len(sample1)) or (sample0 != sample1):
# at least one sample differs, the generations are not identical
return False
return True

# sanity check that there are enough outputs and that they are different
assert len(gen_base) == len(gen_adapter0) == len(gen_adapter1) >= 3
assert not gens_are_same(gen_base, gen_adapter0)
assert not gens_are_same(gen_base, gen_adapter1)
assert not gens_are_same(gen_adapter0, gen_adapter1)

# alternate between base model, adapter0, and adapter1
adapters = ["__base__", "adapter0", "adapter1"]
gen_kwargs["adapter_names"] = [adapters[i % 3] for i in (range(len(dummy_input["input_ids"])))]

with torch.inference_mode():
gen_mixed = model.generate(**gen_kwargs)

assert gens_are_same(gen_base[::3], gen_mixed[::3])
assert gens_are_same(gen_adapter0[1::3], gen_mixed[1::3])
assert gens_are_same(gen_adapter1[2::3], gen_mixed[2::3])

def _test_generate(self, model_id, config_cls, config_kwargs):
model = self.transformers_class.from_pretrained(model_id)
config = config_cls(
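As a side note on the comparison logic in `_test_generate_with_mixed_adapter_batches_and_beam_search`: a toy illustration (hypothetical token ids, pad value 0) of why trailing padding must be stripped before comparing generations coming from differently padded batches.

```python
def remove_padding(seq, pad_value=0):
    lst = list(seq)
    while lst and lst[-1] == pad_value:
        lst.pop()
    return lst

gen_mixed = [5, 8, 3, 0, 0, 0]   # padded to the longest sequence in the mixed batch
gen_single = [5, 8, 3, 0]        # padded to a shorter maximum in its own batch

assert gen_mixed != gen_single                                  # naive element-wise comparison fails
assert remove_padding(gen_mixed) == remove_padding(gen_single)  # both reduce to [5, 8, 3]
```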