diff --git a/services/chem-ner/v1/chem/tagger.py b/services/chem-ner/v1/chem/tagger.py
index 3cb64a2f..e615bc38 100755
--- a/services/chem-ner/v1/chem/tagger.py
+++ b/services/chem-ner/v1/chem/tagger.py
@@ -43,25 +43,34 @@ def predict_formula_ml(input_text):
     predictions = torch.argmax(output.logits, dim=-1)
 
-    # Get token that contains "CHEMICAL"
-    tokens = tokenizer.convert_ids_to_tokens(tokens['input_ids'][0])
-    chemical_tokens_list = []
-    i=0
-
-    while i < len(predictions[0]):
-        # prediction [0][i] depends of i : {0 : "B-CHEMICAL" , 1 : "I-CHEMICAL" , 2: "NOT a chemical NE"}
-        k=0
-        if predictions[0][i] < 2:
-            chemical_tokens_toappend = []
-            while predictions[0][i+k] < 2:
-                chemical_tokens_toappend.append(tokens[i+k])
-                k+=1
-            chemical_tokens_list.append(chemical_tokens_toappend)
-        i+=k+1
-    value = []
-    for chemical_tokens in chemical_tokens_list:
-        value.append(tokenizer.decode(tokenizer.convert_tokens_to_ids(chemical_tokens)))
-    return value
+    # Convert the predictions to label strings
+    predicted_labels = [model.config.id2label[pred.item()] for pred in predictions[0]]
+
+    chemical_entities = []
+    current_entity = []
+
+    # Iterate over tokens and their predicted labels together
+    for token, label in zip(tokenizer.convert_ids_to_tokens(tokens['input_ids'][0]), predicted_labels):
+        if label.startswith("B-"):  # Beginning of an entity
+            if current_entity:
+                chemical_entities.append(current_entity)
+                current_entity = []
+            current_entity.append(token)
+        elif label.startswith("I-") and current_entity:  # Continuation of an entity
+            current_entity.append(token)
+        else:
+            if current_entity:
+                chemical_entities.append(current_entity)
+                current_entity = []
+
+    # Flush any entity still open at the end (the previous version dropped it)
+    if current_entity:
+        chemical_entities.append(current_entity)
+
+    # Convert token lists back to strings
+    chemical_entities = [tokenizer.convert_tokens_to_string(entity_tokens) for entity_tokens in chemical_entities]
+
+    return chemical_entities
 
 
 # if text too long
 def split_text(text):
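For reference, below is a minimal, dependency-free sketch of the BIO grouping that the new hunk implements, including the end-of-sequence flush the old version was missing. The function name group_bio_entities and the sample tokens/labels are illustrative only and are not part of tagger.py; in the service the labels come from model.config.id2label and detokenization goes through tokenizer.convert_tokens_to_string rather than a plain join.

# Standalone sketch of the BIO grouping logic (illustrative, not the service code).

def group_bio_entities(tokens, labels):
    """Group tokens whose labels form B-*/I-* spans into entity token lists."""
    entities, current = [], []
    for token, label in zip(tokens, labels):
        if label.startswith("B-"):                 # a new entity starts here
            if current:
                entities.append(current)
            current = [token]
        elif label.startswith("I-") and current:   # entity continues
            current.append(token)
        else:                                      # outside any entity
            if current:
                entities.append(current)
            current = []
    if current:                                    # flush a trailing entity
        entities.append(current)
    return entities


if __name__ == "__main__":
    # Hypothetical tokens/labels; the trailing "ethanol" entity is only kept
    # because of the final flush, which is exactly the bug fixed in this diff.
    tokens = ["Dissolve", "sodium", "chloride", "in", "ethanol"]
    labels = ["O", "B-CHEMICAL", "I-CHEMICAL", "O", "B-CHEMICAL"]
    print(group_bio_entities(tokens, labels))
    # [['sodium', 'chloride'], ['ethanol']]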