From fe68acfdce4b8de82836dc5193df23cf609549ab Mon Sep 17 00:00:00 2001 From: eriknovak Date: Tue, 16 Jul 2024 18:18:09 +0200 Subject: [PATCH] Add Entity regex group selection --- .../generators/llm_label_generator.py | 7 ++++- anonipy/definitions.py | 11 ++++++++ test/test_generators.py | 28 +++++++++++++------ 3 files changed, 37 insertions(+), 9 deletions(-) diff --git a/anonipy/anonymize/generators/llm_label_generator.py b/anonipy/anonymize/generators/llm_label_generator.py index 2b34a03..4c31f59 100644 --- a/anonipy/anonymize/generators/llm_label_generator.py +++ b/anonipy/anonymize/generators/llm_label_generator.py @@ -1,3 +1,4 @@ +import re from typing import Tuple import torch @@ -130,6 +131,7 @@ def generate( entity: Entity, add_entity_attrs: str = "", temperature: float = 0.0, + use_regex: bool = True, *args, **kwargs, ) -> str: @@ -145,6 +147,7 @@ def generate( entity: The entity to generate the label from. add_entity_attrs: Additional entity attribute description to add to the generation. temperature: The temperature to use for the generation. + use_regex: Whether to use regex to determine the generated substitute. Returns: The generated entity label substitute. @@ -152,10 +155,12 @@ def generate( """ user_prompt = f"What is a random {add_entity_attrs} {entity.label} replacement for {entity.text}? Respond only with the replacement." + # prepare the regex for the entity if needed + regex = None if not use_regex else entity.get_regex_group() or entity.regex assistant_prompt = gen( name="replacement", stop="<|eot_id|>", - regex=entity.regex, + regex=regex, temperature=temperature, ) # generate the replacement for the entity diff --git a/anonipy/definitions.py b/anonipy/definitions.py index 5faeae0..40cab26 100644 --- a/anonipy/definitions.py +++ b/anonipy/definitions.py @@ -50,6 +50,17 @@ def __post_init__(self): raise ValueError("Custom entities require a regex.") self.regex = regex_mapping[self.type] + def get_regex_group(self) -> Union[str, None]: + """Get the regex group. + + Returns: + The regex group. + + """ + + p_match = re.match(r"^.*?\((.*)\).*$", self.regex) + return p_match.group(1) if p_match else self.regex + class Replacement(TypedDict): """The class representing the anonipy Replacement object. diff --git a/test/test_generators.py b/test/test_generators.py index 0aa54ea..97afea4 100644 --- a/test/test_generators.py +++ b/test/test_generators.py @@ -116,16 +116,22 @@ label="name", start_index=30, end_index=38, - score=1.0, type="string", ), + "name:pattern": Entity( + text="John Doe", + label="name", + start_index=30, + end_index=38, + type="string", + regex="Person: (.*)", + ), "date": [ Entity( text="20-05-2024", label="date", start_index=86, end_index=96, - score=1.0, type="date", ) ] @@ -135,7 +141,6 @@ label="date", start_index=86, end_index=86 + len(str), - score=1.0, type="date", ) for str in DATETIME_STRS @@ -145,7 +150,6 @@ label="integer", start_index=121, end_index=132, - score=1.0, type="integer", ), "float": Entity( @@ -153,7 +157,6 @@ label="float", start_index=121, end_index=132, - score=1.0, type="float", ), "custom": Entity( @@ -161,7 +164,6 @@ label="custom", start_index=121, end_index=132, - score=1.0, type="custom", regex="\\d{3}-\\d{2}-\\d{4}", ), @@ -187,7 +189,8 @@ def test_has_methods(self): def test_generate_default(self): entity = test_entities["name"] generated_text = self.generator.generate(entity) - match = re.match(entity.regex, generated_text) + regex = entity.get_regex_group() or entity.regex + match = re.match(regex, generated_text) self.assertNotEqual(match, None) self.assertEqual(match.group(0), generated_text) @@ -196,7 +199,16 @@ def test_generate_custom(self): generated_text = self.generator.generate( entity, add_entity_attrs="Spanish", temperature=0.5 ) - match = re.match(entity.regex, generated_text) + regex = entity.get_regex_group() or entity.regex + match = re.match(regex, generated_text) + self.assertNotEqual(match, None) + self.assertEqual(match.group(0), generated_text) + + def test_generate_pattern(self): + entity = test_entities["name:pattern"] + generated_text = self.generator.generate(entity) + regex = entity.get_regex_group() or entity.regex + match = re.match(regex, generated_text) self.assertNotEqual(match, None) self.assertEqual(match.group(0), generated_text)