Skip to content

Commit

Permalink
fix(chem-ner): fix how to get entities
Browse files Browse the repository at this point in the history
  • Loading branch information
leogail committed Nov 22, 2024
1 parent 8e34d97 commit bd7f83b
Showing 1 changed file with 28 additions and 19 deletions.
47 changes: 28 additions & 19 deletions services/chem-ner/v1/chem/tagger.py
Original file line number Diff line number Diff line change
Expand Up @@ -43,25 +43,34 @@ def predict_formula_ml(input_text):

predictions = torch.argmax(output.logits, dim=-1)

# Get token that contains "CHEMICAL"
tokens = tokenizer.convert_ids_to_tokens(tokens['input_ids'][0])
chemical_tokens_list = []
i=0

while i < len(predictions[0]):
# prediction [0][i] depends of i : {0 : "B-CHEMICAL" , 1 : "I-CHEMICAL" , 2: "NOT a chemical NE"}
k=0
if predictions[0][i] < 2:
chemical_tokens_toappend = []
while predictions[0][i+k] < 2:
chemical_tokens_toappend.append(tokens[i+k])
k+=1
chemical_tokens_list.append(chemical_tokens_toappend)
i+=k+1
value = []
for chemical_tokens in chemical_tokens_list:
value.append(tokenizer.decode(tokenizer.convert_tokens_to_ids(chemical_tokens)))
return value
#convert the predictions to labels
predicted_labels = [model.config.id2label[pred.item()] for pred in predictions[0]]

chemical_entities = []
current_entity = []

# Iterate over both tokens and entity directly
for token, label in zip(tokenizer.convert_ids_to_tokens(tokens['input_ids'][0]), predicted_labels):
if label.startswith("B-"): # Beginning of an entity
if current_entity:
chemical_entities.append(current_entity)
current_entity = []
current_entity.append(token)
elif label.startswith("I-") and current_entity: # Continuation of an entity
current_entity.append(token)
else:
if current_entity:
chemical_entities.append(current_entity)
current_entity = []

# If there's an entity left at the end (here was a bug with last version)
if current_entity:
chemical_entities.append(current_entity)

# Convert tokens back to string format
chemical_entities = [tokenizer.convert_tokens_to_string(entity_tokens) for entity_tokens in chemical_entities]

return chemical_entities

# if text too long
def split_text(text):
Expand Down

0 comments on commit bd7f83b

Please sign in to comment.