diff --git a/garak/generators/huggingface.py b/garak/generators/huggingface.py
index abfddc9c..341c3f0e 100644
--- a/garak/generators/huggingface.py
+++ b/garak/generators/huggingface.py
@@ -79,6 +79,9 @@ def _load_client(self):
             set_seed(_config.run.seed)
 
         pipeline_kwargs = self._gather_hf_params(hf_constructor=pipeline)
+        pipeline_kwargs["truncation"] = (
+            True  # this is forced to maintain existing pipeline expectations
+        )
         self.generator = pipeline("text-generation", **pipeline_kwargs)
         if self.generator.tokenizer is None:
             # account for possible model without a stored tokenizer
@@ -87,6 +90,11 @@ def _load_client(self):
             self.generator.tokenizer = AutoTokenizer.from_pretrained(
                 pipeline_kwargs["model"]
             )
+        if not hasattr(self, "use_chat"):
+            self.use_chat = (
+                hasattr(self.generator.tokenizer, "chat_template")
+                and self.generator.tokenizer.chat_template is not None
+            )
         if not hasattr(self, "deprefix_prompt"):
             self.deprefix_prompt = self.name in models_to_deprefix
         if _config.loaded:
@@ -98,6 +106,9 @@ def _load_client(self):
     def _clear_client(self):
         self.generator = None
 
+    def _format_chat_prompt(self, prompt: str) -> List[dict]:
+        return [{"role": "user", "content": prompt}]
+
     def _call_model(
         self, prompt: str, generations_this_call: int = 1
     ) -> List[Union[str, None]]:
@@ -106,13 +117,16 @@ def _call_model(
             warnings.simplefilter("ignore", category=UserWarning)
             try:
                 with torch.no_grad():
-                    # workaround for pipeline to truncate the input
-                    encoded_prompt = self.generator.tokenizer(prompt, truncation=True)
-                    truncated_prompt = self.generator.tokenizer.decode(
-                        encoded_prompt["input_ids"], skip_special_tokens=True
-                    )
+                    # according to docs https://huggingface.co/docs/transformers/main/en/chat_templating
+                    # chat template should be automatically utilized if the pipeline tokenizer has support
+                    # and a properly formatted list[dict] is supplied
+                    if self.use_chat:
+                        formatted_prompt = self._format_chat_prompt(prompt)
+                    else:
+                        formatted_prompt = prompt
+
                     raw_output = self.generator(
-                        truncated_prompt,
+                        formatted_prompt,
                         pad_token_id=self.generator.tokenizer.eos_token_id,
                         max_new_tokens=self.max_tokens,
                         num_return_sequences=generations_this_call,
@@ -127,10 +141,15 @@
                     i["generated_text"] for i in raw_output
                 ]  # generator returns 10 outputs by default in __init__
 
+        if self.use_chat:
+            text_outputs = [_o[-1]["content"].strip() for _o in outputs]
+        else:
+            text_outputs = outputs
+
         if not self.deprefix_prompt:
-            return outputs
+            return text_outputs
         else:
-            return [re.sub("^" + re.escape(prompt), "", _o) for _o in outputs]
+            return [re.sub("^" + re.escape(prompt), "", _o) for _o in text_outputs]
 
 
 class OptimumPipeline(Pipeline, HFCompatible):
@@ -468,6 +487,13 @@ def _load_client(self):
                 self.name, padding_side="left"
             )
 
+        if not hasattr(self, "use_chat"):
+            # test tokenizer for `apply_chat_template` support
+            self.use_chat = (
+                hasattr(self.tokenizer, "chat_template")
+                and self.tokenizer.chat_template is not None
+            )
+
         self.generation_config = transformers.GenerationConfig.from_pretrained(
             self.name
         )
@@ -492,14 +518,27 @@
         if self.top_k is not None:
             self.generation_config.top_k = self.top_k
 
-        text_output = []
+        raw_text_output = []
         with warnings.catch_warnings():
             warnings.simplefilter("ignore", category=UserWarning)
             with torch.no_grad():
+                if self.use_chat:
+                    formatted_prompt = self.tokenizer.apply_chat_template(
+                        self._format_chat_prompt(prompt),
+                        tokenize=False,
+                        add_generation_prompt=True,
+                    )
+                else:
+                    formatted_prompt = prompt
+
                 inputs = self.tokenizer(
-                    prompt, truncation=True, return_tensors="pt"
+                    formatted_prompt, truncation=True, return_tensors="pt"
                 ).to(self.device)
 
+                prefix_prompt = self.tokenizer.decode(
+                    inputs["input_ids"][0], skip_special_tokens=True
+                )
+
                 try:
                     outputs = self.model.generate(
                         **inputs, generation_config=self.generation_config
@@ -512,14 +551,22 @@
                         return returnval
                     else:
                         raise e
-            text_output = self.tokenizer.batch_decode(
+            raw_text_output = self.tokenizer.batch_decode(
                 outputs, skip_special_tokens=True, device=self.device
             )
 
+        if self.use_chat:
+            text_output = [
+                re.sub("^" + re.escape(prefix_prompt), "", i).strip()
+                for i in raw_text_output
+            ]
+        else:
+            text_output = raw_text_output
+
         if not self.deprefix_prompt:
             return text_output
         else:
-            return [re.sub("^" + re.escape(prompt), "", i) for i in text_output]
+            return [re.sub("^" + re.escape(prefix_prompt), "", i) for i in text_output]
 
 
 class LLaVA(Generator, HFCompatible):
diff --git a/tests/generators/test_huggingface.py b/tests/generators/test_huggingface.py
index f784d95d..fd830027 100644
--- a/tests/generators/test_huggingface.py
+++ b/tests/generators/test_huggingface.py
@@ -50,6 +50,21 @@ def test_pipeline(hf_generator_config):
         assert isinstance(item, str)
 
 
+def test_pipeline_chat(mocker, hf_generator_config):
+    # uses a ~350M model with chat support
+    g = garak.generators.huggingface.Pipeline(
+        "microsoft/DialoGPT-small", config_root=hf_generator_config
+    )
+    mock_format = mocker.patch.object(
+        g, "_format_chat_prompt", wraps=g._format_chat_prompt
+    )
+    output = g.generate("Hello world!")
+    mock_format.assert_called_once()
+    assert len(output) == 1
+    for item in output:
+        assert isinstance(item, str)
+
+
 def test_inference(mocker, hf_mock_response, hf_generator_config):
     model_name = "gpt2"
     mock_request = mocker.patch.object(
@@ -121,6 +136,21 @@ def test_model(hf_generator_config):
         assert item is None  # gpt2 is known raise exception returning `None`
 
 
+def test_model_chat(mocker, hf_generator_config):
+    # uses a ~350M model with chat support
+    g = garak.generators.huggingface.Model(
+        "microsoft/DialoGPT-small", config_root=hf_generator_config
+    )
+    mock_format = mocker.patch.object(
+        g, "_format_chat_prompt", wraps=g._format_chat_prompt
+    )
+    output = g.generate("Hello world!")
+    mock_format.assert_called_once()
+    assert len(output) == 1
+    for item in output:
+        assert isinstance(item, str)
+
+
 def test_select_hf_device():
     from garak.generators.huggingface import HFCompatible
     import torch
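Reviewer note (illustrative, not part of the patch): a minimal sketch of what the new use_chat branches rely on. The checkpoint name below is only an assumed example of a tokenizer that ships a chat_template; the fallback mirrors the non-chat path kept in _call_model. The Pipeline path hands the list[dict] straight to the text-generation pipeline, while the Model path renders it with apply_chat_template as shown here.

from transformers import AutoTokenizer

# Assumed example checkpoint with a bundled chat template; any chat-tuned model works.
tokenizer = AutoTokenizer.from_pretrained("HuggingFaceH4/zephyr-7b-beta")

# Same shape as the _format_chat_prompt() output in the diff above.
messages = [{"role": "user", "content": "Hello world!"}]

if getattr(tokenizer, "chat_template", None) is not None:
    # Chat path: render the conversation to a plain string before tokenizing.
    formatted = tokenizer.apply_chat_template(
        messages, tokenize=False, add_generation_prompt=True
    )
else:
    # Non-chat fallback: pass the raw prompt through unchanged.
    formatted = messages[0]["content"]

print(formatted)  # model-specific templated string wrapping the user turn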