Skip to content

Commit

Permalink
Add pattern restriction to MaskLabelGenerator + generalize model use
Browse files Browse the repository at this point in the history
  • Loading branch information
eriknovak committed May 23, 2024
1 parent ece4c90 commit 25fbd75
Showing 1 changed file with 8 additions and 5 deletions.
13 changes: 8 additions & 5 deletions anonipy/anonymize/generators/mask_label_generator.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,16 +34,18 @@ def __init__(
)
use_gpu = False

# prepare the fill-mask pipeline and store the mask token
model, tokenizer = self._prepare_model_and_tokenizer(model_name, use_gpu)
self.mask_token = tokenizer.mask_token
self.pipeline = pipeline(
"fill-mask", model=model, tokenizer=tokenizer, top_k=10
"fill-mask", model=model, tokenizer=tokenizer, top_k=40
)

def generate(self, entity: Entity, text: str, *args, **kwargs):
masks = self._create_masks(entity)
input_texts = self._prepare_generate_inputs(masks, text)
suggestions = self.pipeline(input_texts)
return self._create_substitute(masks, suggestions)
return self._create_substitute(entity, masks, suggestions)

# =================================
# Private methods
Expand All @@ -67,7 +69,7 @@ def _create_masks(self, entity: Entity):
{
"true_text": chunks[idx],
"mask_text": " ".join(
chunks[0:idx] + ["<mask>"] + chunks[idx + 1 :]
chunks[0:idx] + [self.mask_token] + chunks[idx + 1 :]
),
"start_index": entity.start_index,
"end_index": entity.end_index,
Expand All @@ -90,14 +92,15 @@ def _prepare_generate_inputs(self, masks, text):
for m in masks
]

def _create_substitute(self, masks, suggestions):
def _create_substitute(self, entity: Entity, masks, suggestions):
substitute_chunks = []
for mask, suggestion in zip(masks, suggestions):
suggestion = suggestion if type(suggestion) == list else [suggestion]
viable_suggestions = list(
filter(
lambda x: x["token_str"] not in STOPWORDS
and x["token_str"] != mask["true_text"],
and x["token_str"] != mask["true_text"]
and re.match(entity.regex, x["token_str"]),
suggestion,
)
)
Expand Down

0 comments on commit 25fbd75

Please sign in to comment.