Skip to content

Commit

Permalink
Exllamav2 lora support (#4229)
Browse files Browse the repository at this point in the history
---------

Co-authored-by: oobabooga <112222186+oobabooga@users.noreply.github.com>
  • Loading branch information
Ph0rk0z and oobabooga authored Oct 14, 2023
1 parent 1f5a2c5 commit 8cce1f1
Show file tree
Hide file tree
Showing 4 changed files with 47 additions and 12 deletions.
32 changes: 31 additions & 1 deletion modules/LoRA.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,8 @@ def add_lora_to_model(lora_names):
add_lora_autogptq(lora_names)
elif shared.model.__class__.__name__ in ['ExllamaModel', 'ExllamaHF'] or shared.args.loader == 'ExLlama':
add_lora_exllama(lora_names)
elif shared.model.__class__.__name__ in ['Exllamav2Model', 'Exllamav2HF'] or shared.args.loader == ['ExLlamav2', 'ExLlamav2_HF']:
add_lora_exllamav2(lora_names)
else:
add_lora_transformers(lora_names)

Expand Down Expand Up @@ -64,8 +66,36 @@ def add_lora_exllama(lora_names):
return


# Adapted from https://github.com/Ph0rk0z/text-generation-webui-testing
def add_lora_exllamav2(lora_names):

from exllamav2 import ExLlamaV2Lora

if isinstance(shared.model.loras, list):
for lora in shared.model.loras:
lora.unload()

if len(lora_names) > 0:
logger.info("Applying the following LoRAs to {}: {}".format(shared.model_name, ', '.join(lora_names)))
shared.model.loras = []
for lora_name in lora_names:
lora_path = get_lora_path(lora_name)
if shared.model.__class__.__name__ == 'Exllamav2Model':
lora = ExLlamaV2Lora.from_directory(shared.model.model, str(lora_path))
else:
lora = ExLlamaV2Lora.from_directory(shared.model.ex_model, str(lora_path))

shared.model.loras.append(lora)

shared.lora_names = lora_names
else:
shared.lora_names = []
shared.model.loras = None


def add_lora_autogptq(lora_names):
'''
Adapted from https://github.com/Ph0rk0z/text-generation-webui-testing
'''

try:
from auto_gptq import get_gptq_peft_model
Expand Down
4 changes: 3 additions & 1 deletion modules/exllama.py
Original file line number Diff line number Diff line change
Expand Up @@ -98,7 +98,9 @@ def decode(self, ids, **kwargs):

def get_logits(self, token_ids, **kwargs):
self.cache.current_seq_len = 0
self.model.forward(token_ids[:, :-1], self.cache, input_mask=None, preprocess_only=True)
if token_ids.shape[-1] > 1:
self.model.forward(token_ids[:, :-1], self.cache, input_mask=None, preprocess_only=True)

return self.model.forward(token_ids[:, -1:], self.cache, **kwargs).float().cpu()

def generate_with_streaming(self, prompt, state):
Expand Down
13 changes: 8 additions & 5 deletions modules/exllamav2.py
Original file line number Diff line number Diff line change
Expand Up @@ -60,6 +60,7 @@ def from_pretrained(self, path_to_model):
result.cache = cache
result.tokenizer = tokenizer
result.generator = generator
result.loras = None
return result, result

def encode(self, string, **kwargs):
Expand All @@ -75,8 +76,10 @@ def decode(self, ids, **kwargs):

def get_logits(self, token_ids, **kwargs):
self.cache.current_seq_len = 0
self.model.forward(token_ids[:, :-1], self.cache, input_mask=None, preprocess_only=True)
return self.model.forward(token_ids[:, -1:], self.cache, input_mask=None, **kwargs).float().cpu()
if token_ids.shape[-1] > 1:
self.model.forward(token_ids[:, :-1], self.cache, input_mask=None, preprocess_only=True, loras=self.loras)

return self.model.forward(token_ids[:, -1:], self.cache, input_mask=None, loras=self.loras, **kwargs).float().cpu()

def generate_with_streaming(self, prompt, state):
settings = ExLlamaV2Sampler.Settings()
Expand Down Expand Up @@ -105,12 +108,12 @@ def generate_with_streaming(self, prompt, state):

# _gen_begin_base
self.cache.current_seq_len = 0
self.model.forward(ids[:, :-1], self.cache, input_mask=None, preprocess_only=True)
self.model.forward(ids[:, :-1], self.cache, input_mask=None, preprocess_only=True, loras=self.loras)

has_leading_space = False
for i in range(max_new_tokens):
logits = self.model.forward(ids[:, -1:], self.cache, input_mask=None).float().cpu()
token, _, _= ExLlamaV2Sampler.sample(logits, settings, ids, random.random(), self.tokenizer)
logits = self.model.forward(ids[:, -1:], self.cache, input_mask=None, loras=self.loras).float().cpu()
token, _, _ = ExLlamaV2Sampler.sample(logits, settings, ids, random.random(), self.tokenizer)
ids = torch.cat([ids, token], dim=1)

if i == 0 and self.tokenizer.tokenizer.IdToPiece(int(token)).startswith('▁'):
Expand Down
10 changes: 5 additions & 5 deletions modules/exllamav2_hf.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,8 +33,8 @@ def __init__(self, config: ExLlamaV2Config):
split = [float(alloc) for alloc in shared.args.gpu_split.split(",")]

self.ex_model.load(split)

self.generation_config = GenerationConfig()
self.loras = None

self.ex_cache = ExLlamaV2Cache(self.ex_model)
self.past_seq = None
Expand Down Expand Up @@ -97,7 +97,7 @@ def __call__(self, *args, **kwargs):
reset = False
ex_cache.current_seq_len = longest_prefix
if len(seq_tensor) - longest_prefix > 1:
self.ex_model.forward(seq_tensor[longest_prefix:-1].view(1, -1), ex_cache, preprocess_only=True)
self.ex_model.forward(seq_tensor[longest_prefix:-1].view(1, -1), ex_cache, preprocess_only=True, loras=self.loras)
elif len(seq_tensor) == longest_prefix:
# Very tricky: if the prefix we are reusing *is* the input_ids, then we have to back up the cache pointer by one,
# because we feed input_ids[-1] to forward() below, but that last token is already in the cache!
Expand All @@ -106,12 +106,12 @@ def __call__(self, *args, **kwargs):
if reset:
ex_cache.current_seq_len = 0
if len(seq_tensor) > 1:
self.ex_model.forward(seq_tensor[:-1].view(1, -1), ex_cache, preprocess_only=True)
self.ex_model.forward(seq_tensor[:-1].view(1, -1), ex_cache, preprocess_only=True, loras=self.loras)

logits = self.ex_model.forward(seq_tensor[-1:].view(1, -1), ex_cache).to(input_ids.device)
logits = self.ex_model.forward(seq_tensor[-1:].view(1, -1), ex_cache, loras=self.loras).to(input_ids.device)
else:
ex_cache.current_seq_len = 0
logits = self.ex_model.forward(seq_tensor.view(1, -1), ex_cache, last_id_only=False)
logits = self.ex_model.forward(seq_tensor.view(1, -1), ex_cache, last_id_only=False, loras=self.loras)

if is_negative:
self.past_seq_negative = seq_tensor
Expand Down

0 comments on commit 8cce1f1

Please sign in to comment.