Skip to content

Commit

Permalink
Add Entity regex group selection
Browse files Browse the repository at this point in the history
  • Loading branch information
eriknovak committed Jul 16, 2024
1 parent a8d8c59 commit fe68acf
Show file tree
Hide file tree
Showing 3 changed files with 37 additions and 9 deletions.
7 changes: 6 additions & 1 deletion anonipy/anonymize/generators/llm_label_generator.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
import re
from typing import Tuple

import torch
Expand Down Expand Up @@ -130,6 +131,7 @@ def generate(
entity: Entity,
add_entity_attrs: str = "",
temperature: float = 0.0,
use_regex: bool = True,
*args,
**kwargs,
) -> str:
Expand All @@ -145,17 +147,20 @@ def generate(
entity: The entity to generate the label from.
add_entity_attrs: Additional entity attribute description to add to the generation.
temperature: The temperature to use for the generation.
use_regex: Whether to use regex to determine the generated substitute.
Returns:
The generated entity label substitute.
"""

user_prompt = f"What is a random {add_entity_attrs} {entity.label} replacement for {entity.text}? Respond only with the replacement."
# prepare the regex for the entity if needed
regex = None if not use_regex else entity.get_regex_group() or entity.regex
assistant_prompt = gen(
name="replacement",
stop="<|eot_id|>",
regex=entity.regex,
regex=regex,
temperature=temperature,
)
# generate the replacement for the entity
Expand Down
11 changes: 11 additions & 0 deletions anonipy/definitions.py
Original file line number Diff line number Diff line change
Expand Up @@ -50,6 +50,17 @@ def __post_init__(self):
raise ValueError("Custom entities require a regex.")
self.regex = regex_mapping[self.type]

def get_regex_group(self) -> Union[str, None]:
    """Get the pattern of the first capturing group in the entity regex.

    The regex is scanned character by character, so escaped parentheses,
    parentheses inside character classes and non-capturing constructs
    (anything opening with "(?" other than a named group) are not
    mistaken for a capturing group. Nested groups are kept intact in the
    returned pattern.

    Returns:
        The pattern inside the first capturing group. If the regex has no
        complete capturing group, the whole regex is returned unchanged
        (including an empty or missing regex).

    """

    pattern = self.regex
    if not pattern:
        # nothing to scan; hand back the empty/missing regex unchanged
        return pattern

    group_start = None  # index of the first character inside the group
    depth = 0  # parentheses opened inside the group and not yet closed
    in_class = False  # inside [...], where parentheses are literals
    idx = 0
    while idx < len(pattern):
        char = pattern[idx]
        if char == "\\":
            # skip the escaped character, e.g. a literal parenthesis
            idx += 2
            continue
        if in_class:
            in_class = char != "]"
        elif char == "[":
            in_class = True
            # a "]" placed first in the class (after an optional "^")
            # is a literal and does not close the class
            if pattern.startswith("^", idx + 1):
                idx += 1
            if pattern.startswith("]", idx + 1):
                idx += 1
        elif char == "(":
            if group_start is not None:
                depth += 1
            elif pattern.startswith("(?P<", idx):
                # named group: its content starts after the closing ">"
                name_end = pattern.find(">", idx)
                if name_end == -1:
                    break
                group_start = name_end + 1
                idx = name_end
            elif not pattern.startswith("(?", idx):
                group_start = idx + 1
            # any other "(?..." construct does not capture; keep scanning
            # because a capturing group may still be nested inside it
        elif char == ")" and group_start is not None:
            if depth == 0:
                return pattern[group_start:idx]
            depth -= 1
        idx += 1

    # no complete capturing group found: fall back to the whole regex
    return pattern


class Replacement(TypedDict):
"""The class representing the anonipy Replacement object.
Expand Down
28 changes: 20 additions & 8 deletions test/test_generators.py
Original file line number Diff line number Diff line change
Expand Up @@ -116,16 +116,22 @@
label="name",
start_index=30,
end_index=38,
score=1.0,
type="string",
),
"name:pattern": Entity(
text="John Doe",
label="name",
start_index=30,
end_index=38,
type="string",
regex="Person: (.*)",
),
"date": [
Entity(
text="20-05-2024",
label="date",
start_index=86,
end_index=96,
score=1.0,
type="date",
)
]
Expand All @@ -135,7 +141,6 @@
label="date",
start_index=86,
end_index=86 + len(str),
score=1.0,
type="date",
)
for str in DATETIME_STRS
Expand All @@ -145,23 +150,20 @@
label="integer",
start_index=121,
end_index=132,
score=1.0,
type="integer",
),
"float": Entity(
text="123,456,789.000",
label="float",
start_index=121,
end_index=132,
score=1.0,
type="float",
),
"custom": Entity(
text="123-45-6789",
label="custom",
start_index=121,
end_index=132,
score=1.0,
type="custom",
regex="\\d{3}-\\d{2}-\\d{4}",
),
Expand All @@ -187,7 +189,8 @@ def test_has_methods(self):
def test_generate_default(self):
    """Check that default generation fully matches the entity pattern."""
    entity = test_entities["name"]
    generated_text = self.generator.generate(entity)
    # validate against the regex group when there is one, otherwise
    # against the whole entity regex
    regex = entity.get_regex_group() or entity.regex
    match = re.match(regex, generated_text)
    self.assertIsNotNone(match)
    # the match must span the entire generated text, not just a prefix
    self.assertEqual(match.group(0), generated_text)

Expand All @@ -196,7 +199,16 @@ def test_generate_custom(self):
generated_text = self.generator.generate(
entity, add_entity_attrs="Spanish", temperature=0.5
)
match = re.match(entity.regex, generated_text)
regex = entity.get_regex_group() or entity.regex
match = re.match(regex, generated_text)
self.assertNotEqual(match, None)
self.assertEqual(match.group(0), generated_text)

def test_generate_pattern(self):
    """Check generation for an entity whose regex contains a group."""
    pattern_entity = test_entities["name:pattern"]
    substitute = self.generator.generate(pattern_entity)
    # prefer the regex group; fall back to the full regex if it is empty
    expected_pattern = pattern_entity.get_regex_group() or pattern_entity.regex
    found = re.match(expected_pattern, substitute)
    self.assertNotEqual(found, None)
    # the whole substitute has to be covered by the match
    self.assertEqual(found.group(0), substitute)

Expand Down

0 comments on commit fe68acf

Please sign in to comment.