Feature: use hf chat support (#1047)
jmartin-tech committed Dec 19, 2024
2 parents 7b033be + 86de116 commit d338d95
Showing 2 changed files with 89 additions and 12 deletions.
71 changes: 59 additions & 12 deletions garak/generators/huggingface.py
@@ -79,6 +79,9 @@ def _load_client(self):
set_seed(_config.run.seed)

pipeline_kwargs = self._gather_hf_params(hf_constructor=pipeline)
pipeline_kwargs["truncation"] = (
True # this is forced to maintain existing pipeline expectations
)
self.generator = pipeline("text-generation", **pipeline_kwargs)
if self.generator.tokenizer is None:
# account for possible model without a stored tokenizer
@@ -87,6 +90,11 @@ def _load_client(self):
self.generator.tokenizer = AutoTokenizer.from_pretrained(
pipeline_kwargs["model"]
)
if not hasattr(self, "use_chat"):
self.use_chat = (
hasattr(self.generator.tokenizer, "chat_template")
and self.generator.tokenizer.chat_template is not None
)
if not hasattr(self, "deprefix_prompt"):
self.deprefix_prompt = self.name in models_to_deprefix
if _config.loaded:
@@ -98,6 +106,9 @@ def _load_client(self):
def _clear_client(self):
self.generator = None

def _format_chat_prompt(self, prompt: str) -> List[dict]:
return [{"role": "user", "content": prompt}]

def _call_model(
self, prompt: str, generations_this_call: int = 1
) -> List[Union[str, None]]:
@@ -106,13 +117,16 @@ def _call_model(
warnings.simplefilter("ignore", category=UserWarning)
try:
with torch.no_grad():
# workaround for pipeline to truncate the input
encoded_prompt = self.generator.tokenizer(prompt, truncation=True)
truncated_prompt = self.generator.tokenizer.decode(
encoded_prompt["input_ids"], skip_special_tokens=True
)
# according to docs https://huggingface.co/docs/transformers/main/en/chat_templating
# chat template should be automatically utilized if the pipeline tokenizer has support
# and a properly formatted list[dict] is supplied
if self.use_chat:
formatted_prompt = self._format_chat_prompt(prompt)
else:
formatted_prompt = prompt

raw_output = self.generator(
truncated_prompt,
formatted_prompt,
pad_token_id=self.generator.tokenizer.eos_token_id,
max_new_tokens=self.max_tokens,
num_return_sequences=generations_this_call,
@@ -127,10 +141,15 @@ def _call_model(
i["generated_text"] for i in raw_output
] # generator returns 10 outputs by default in __init__

if self.use_chat:
text_outputs = [_o[-1]["content"].strip() for _o in outputs]
else:
text_outputs = outputs

if not self.deprefix_prompt:
return outputs
return text_outputs
else:
return [re.sub("^" + re.escape(prompt), "", _o) for _o in outputs]
return [re.sub("^" + re.escape(prompt), "", _o) for _o in text_outputs]


class OptimumPipeline(Pipeline, HFCompatible):
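A note on the new chat path in Pipeline._call_model: when the tokenizer ships a chat template and the prompt is wrapped as a list of role/content dicts, recent transformers text-generation pipelines return the whole conversation (the prompt turns plus the generated assistant turn) under generated_text, which is why the code takes the content of the last message. A minimal standalone sketch of that flow, assuming a locally available chat-capable model; the model id below is a placeholder, not part of this change:

    from transformers import pipeline

    generator = pipeline("text-generation", model="a-chat-capable-model")  # placeholder id

    chat = [{"role": "user", "content": "Hello world!"}]  # same shape _format_chat_prompt produces
    raw_output = generator(chat, max_new_tokens=32, num_return_sequences=1)

    # With chat-format input, generated_text is the conversation as a list of
    # message dicts, so the reply is the content of the final (assistant) message.
    reply = raw_output[0]["generated_text"][-1]["content"].strip()
    print(reply)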
@@ -468,6 +487,13 @@ def _load_client(self):
self.name, padding_side="left"
)

if not hasattr(self, "use_chat"):
# test tokenizer for `apply_chat_template` support
self.use_chat = (
hasattr(self.tokenizer, "chat_template")
and self.tokenizer.chat_template is not None
)

self.generation_config = transformers.GenerationConfig.from_pretrained(
self.name
)
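The use_chat checks added to both generators reduce to asking whether the loaded tokenizer carries a chat template. A quick way to probe that outside garak, sketched for illustration only (the helper name is made up; only the chat_template attribute check mirrors the code above):

    from transformers import AutoTokenizer

    def has_chat_template(model_id: str) -> bool:
        # Mirrors the detection above: chat formatting is available when the
        # tokenizer exposes a non-None chat_template attribute.
        tokenizer = AutoTokenizer.from_pretrained(model_id)
        return getattr(tokenizer, "chat_template", None) is not None

    print(has_chat_template("gpt2"))  # plain gpt2 is not expected to ship a chat template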
@@ -492,14 +518,27 @@ def _call_model(
if self.top_k is not None:
self.generation_config.top_k = self.top_k

text_output = []
raw_text_output = []
with warnings.catch_warnings():
warnings.simplefilter("ignore", category=UserWarning)
with torch.no_grad():
if self.use_chat:
formatted_prompt = self.tokenizer.apply_chat_template(
self._format_chat_prompt(prompt),
tokenize=False,
add_generation_prompt=True,
)
else:
formatted_prompt = prompt

inputs = self.tokenizer(
prompt, truncation=True, return_tensors="pt"
formatted_prompt, truncation=True, return_tensors="pt"
).to(self.device)

prefix_prompt = self.tokenizer.decode(
inputs["input_ids"][0], skip_special_tokens=True
)

try:
outputs = self.model.generate(
**inputs, generation_config=self.generation_config
@@ -512,14 +551,22 @@ def _call_model(
return returnval
else:
raise e
text_output = self.tokenizer.batch_decode(
raw_text_output = self.tokenizer.batch_decode(
outputs, skip_special_tokens=True, device=self.device
)

if self.use_chat:
text_output = [
re.sub("^" + re.escape(prefix_prompt), "", i).strip()
for i in raw_text_output
]
else:
text_output = raw_text_output

if not self.deprefix_prompt:
return text_output
else:
return [re.sub("^" + re.escape(prompt), "", i) for i in text_output]
return [re.sub("^" + re.escape(prefix_prompt), "", i) for i in text_output]


class LLaVA(Generator, HFCompatible):
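For the Model generator, the chat prompt is rendered to a plain string with apply_chat_template before tokenization, and the re-decoded input ids are kept as prefix_prompt so the echoed prompt (now wrapped in template markup) can still be stripped from the decoded output. A rough standalone sketch of those two steps, assuming a tokenizer that provides a chat template; the model id is a placeholder:

    from transformers import AutoTokenizer

    tokenizer = AutoTokenizer.from_pretrained("a-chat-capable-model")  # placeholder id

    messages = [{"role": "user", "content": "Hello world!"}]
    formatted_prompt = tokenizer.apply_chat_template(
        messages, tokenize=False, add_generation_prompt=True
    )  # one string carrying the model's chat markup plus the assistant cue

    inputs = tokenizer(formatted_prompt, truncation=True, return_tensors="pt")

    # Decoding the input ids back (minus special tokens) recovers the exact prefix
    # that generate() will echo, which is what the deprefix regex strips later.
    prefix_prompt = tokenizer.decode(inputs["input_ids"][0], skip_special_tokens=True)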
30 changes: 30 additions & 0 deletions tests/generators/test_huggingface.py
@@ -50,6 +50,21 @@ def test_pipeline(hf_generator_config):
assert isinstance(item, str)


def test_pipeline_chat(mocker, hf_generator_config):
# uses a ~350M model with chat support
g = garak.generators.huggingface.Pipeline(
"microsoft/DialoGPT-small", config_root=hf_generator_config
)
mock_format = mocker.patch.object(
g, "_format_chat_prompt", wraps=g._format_chat_prompt
)
output = g.generate("Hello world!")
mock_format.assert_called_once()
assert len(output) == 1
for item in output:
assert isinstance(item, str)


def test_inference(mocker, hf_mock_response, hf_generator_config):
model_name = "gpt2"
mock_request = mocker.patch.object(
@@ -121,6 +136,21 @@ def test_model(hf_generator_config):
assert item is None # gpt2 is known to raise an exception, returning `None`


def test_model_chat(mocker, hf_generator_config):
# uses a ~350M model with chat support
g = garak.generators.huggingface.Model(
"microsoft/DialoGPT-small", config_root=hf_generator_config
)
mock_format = mocker.patch.object(
g, "_format_chat_prompt", wraps=g._format_chat_prompt
)
output = g.generate("Hello world!")
mock_format.assert_called_once()
assert len(output) == 1
for item in output:
assert isinstance(item, str)


def test_select_hf_device():
from garak.generators.huggingface import HFCompatible
import torch
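Both new tests spy on _format_chat_prompt with pytest-mock's wraps= option, which records the call while still running the real implementation. A minimal illustration of that pattern with made-up names, using the same mocker fixture the tests above rely on:

    class Greeter:
        def _format(self, name: str) -> str:
            return f"hello {name}"

        def greet(self, name: str) -> str:
            return self._format(name)


    def test_greet_uses_format(mocker):
        g = Greeter()
        # wraps= keeps the real method's behaviour while counting calls
        spy = mocker.patch.object(g, "_format", wraps=g._format)
        assert g.greet("world") == "hello world"
        spy.assert_called_once_with("world")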
