Commit
update with spark management
ofilangi committed Oct 31, 2024
1 parent 19cf542 commit d5647eb
Showing 9 changed files with 251 additions and 242 deletions.
3 changes: 1 addition & 2 deletions config/mesh-demo.json
@@ -23,8 +23,7 @@
"constraints" : {
"meshv:active" : "true",
"rdf:type" : "meshv:Concept"
}

}
}
}
},
5 changes: 4 additions & 1 deletion config/planteome-demo-only-TO-0000394.json
@@ -1,6 +1,6 @@
{
"encodeur" : "sentence-transformers/all-MiniLM-L6-v2",
-"threshold_similarity_tag_chunk" : 0.70,
+"threshold_similarity_tag_chunk" : 0.10,
"threshold_similarity_tag" : 0.80,
"batch_size" : 32,

@@ -36,6 +36,9 @@
"selected_term" : [
"Crops%2C+Agricultural%2Fmetabolism%5BMeSH%5D"
]
},
"from_file" : {
"json_file" : "data/msd/export-pubmed-20241014-4-planetome-tagging-sub-test/part-00016-6787be90-eb7f-4950-8ef0-98d9dbbbcd38-c000.json"
}

}
2 changes: 1 addition & 1 deletion config/pubmed-all.json
@@ -41,7 +41,7 @@
"populate_abstract_embeddings" : {
"abstracts_per_file" : 500,
"from_file" : {
"json_dir" : "/scratch/ofilangi/export-pubmed-20241014-4-planetome-tagging"
"json_file" : "data/msd/export-pubmed-20241014-4-planetome-tagging-sub-test/part-00016-6787be90-eb7f-4950-8ef0-98d9dbbbcd38-c000.json"
}
}
}
2 changes: 1 addition & 1 deletion config/transformon-demo.json
@@ -10,7 +10,7 @@
"transformon": {
"url": " https://entrepot.recherche.data.gouv.fr/api/access/datafile/:persistentId?persistentId=doi:10.57745/X2ZFLG",
"prefix": "http://opendata.inrae.fr/PO2/Ontology/TransformON/Component/",
"format": "xml",
"format": "turtle",
"label" : "skos:prefLabel",
"properties": ["skos:scopeNote"]
}
2 changes: 1 addition & 1 deletion llm_semantic_annotator/abstract/abstract_manager.py
@@ -178,7 +178,7 @@ def _set_embedding_abstract_file(self):
print(f"{pth_filename} already exists !")
continue
results = self._get_data_abstracts_file(json_f)
self.mem.save_pth(self.mem.encode_abstracts(results,genname),genname)
self.mem.save_pth(self.mem.encode_abstracts(results),genname)

def manage_abstracts(self):

17 changes: 6 additions & 11 deletions llm_semantic_annotator/similarity/model_embedding_manager.py
@@ -33,18 +33,14 @@

# https://huggingface.co/spaces/mteb/leaderboard

-class Singleton(type):
-    _instances = {}
-    def __call__(cls, *args, **kwargs):
-        if cls not in cls._instances:
-            cls._instances[cls] = super().__call__(*args, **kwargs)
-        return cls._instances[cls]
-

-class ModelEmbeddingManager(metaclass=Singleton):
+class ModelEmbeddingManager():
    def __init__(self, config):
        self.config = config
-        self.retention_dir = config['retention_dir']
+        if 'retention_dir' not in config:
+            self.retention_dir = "/tmp"
+        else:
+            self.retention_dir = config['retention_dir']
        self.encoder=config['encodeur']
        self.model_suffix=self.encoder.split('/')[-1]
        self.model_name = config.get('encodeur', self.encoder)
@@ -134,7 +130,7 @@ def encode_tags(self,tags):

        return tags_embedding

-    def encode_abstracts(self,abstracts,genname) :
+    def encode_abstracts(self,abstracts) :
        """
        abstract : {
            'doi',
@@ -147,7 +143,6 @@ def encode_abstracts(self,abstracts,genname) :
        chunks_doi_ref = []
        lcount = 0

-        print("Flat abstracts to build batch.....",genname)
        for item in tqdm(abstracts):
            if 'abstract' in item and item['abstract'].strip() != '':
                if 'title' in item and item['title'].strip() != '':
55 changes: 39 additions & 16 deletions llm_semantic_annotator/tag/owl_tag_manager.py
@@ -16,8 +16,15 @@ def __init__(self,config,model_embedding_manager):
        else:
            self.debug_nb_terms_by_ontology = -1

-        self.retention_dir = config['retention_dir']
-        self.ontologies_by_link = config['ontologies']
+        if 'retention_dir' not in config:
+            self.retention_dir = "/tmp"
+        else:
+            self.retention_dir = config['retention_dir']
+
+        if 'ontologies' not in config:
+            self.ontologies_by_link = {}
+        else:
+            self.ontologies_by_link = config['ontologies']

        self.prefixes = {}

@@ -28,8 +35,8 @@ def __init__(self,config,model_embedding_manager):

        if 'force' not in config:
            config['force'] = False
-        else:
-            self.force = config['force']
+
+        self.force = config['force']

        self.mem = model_embedding_manager
        self.tags_owl_path_filename = f"tags_owl_"
@@ -81,25 +88,50 @@ def remove_prefix_tags(self,text):

        v = re.sub(pattern, '', text)
        return re.sub(r'\(\)', '', v)

    def build_corpus(
        self,
        ontology,
        ontology_group_name,
        ontology_config,
-        debug_nb_terms_by_ontology):
+        debug_nb_terms_by_ontology,
+        owl_content=None):

        tags_owl_path_filename = self.tags_owl_path_filename+ontology
        tag_embeddings = self.mem.load_pth(tags_owl_path_filename)

        if (len(tag_embeddings)>0):
            return tag_embeddings

+        tags = self.build_tags_from_owl(ontology,ontology_group_name,ontology_config,debug_nb_terms_by_ontology,owl_content)
+
+        df = pd.DataFrame({
+            'ontology' : [ ele['ontology'] for ele in tags ],
+            'term' : [ ele['term'] for ele in tags ],
+            'rdfs:label': [ ele['rdfs_label'] for ele in tags ],
+            'description': [ ele['description'] for ele in tags ],
+        })
+
+        df.to_csv(self.retention_dir+f"/tags_owl_{ontology}.csv", index=False)
+        self.mem.save_pth(self.mem.encode_tags(tags),tags_owl_path_filename)
+        return tags
+
+    def build_tags_from_owl(
+        self,
+        ontology,
+        ontology_group_name,
+        ontology_config,
+        debug_nb_terms_by_ontology,
+        owl_content=None):

        # Load the local OWL file

        g = Graph()
        print("loading ontology: ",ontology)
-        g.parse(ontology_config['filepath'], format=ontology_config['format'])
+        if owl_content:
+            g.parse(data=owl_content, format=ontology_config['format'])
+        else:
+            g.parse(ontology_config['filepath'], format=ontology_config['format'])

        # Namespace for rdfs
        RDFS = Namespace("http://www.w3.org/2000/01/rdf-schema#")
@@ -180,15 +212,6 @@ def build_corpus(
                break
            nb_record+=1

-        df = pd.DataFrame({
-            'ontology' : [ ele['ontology'] for ele in tags ],
-            'term' : [ ele['term'] for ele in tags ],
-            'rdfs:label': [ ele['rdfs_label'] for ele in tags ],
-            'description': [ ele['description'] for ele in tags ],
-        })
-
-        df.to_csv(self.retention_dir+f"/tags_owl_{ontology}.csv", index=False)
-        self.mem.save_pth(self.mem.encode_tags(tags),tags_owl_path_filename)
        return tags

    def manage_tags(self):
198 changes: 198 additions & 0 deletions main_msd_spark.py
@@ -0,0 +1,198 @@
"""
Exemple d'exécution :
python main_msd_spark.py config.json
spark-submit \
--py-files <chemin/vers/vos/dependances.zip> \
--files <chemin/vers/votre/fichier/de/configuration.json> \
main_msd_spark.py <chemin/vers/votre/fichier/de/configuration.json>
"""

import os
import json
import sys
import numpy as np
import pandas as pd
from pyspark.sql import SparkSession
from pyspark.sql.functions import pandas_udf, PandasUDFType, col, udf
from pyspark.sql.types import ArrayType, FloatType, StringType, StructType, StructField

import argparse
from llm_semantic_annotator import ModelEmbeddingManager, OwlTagManager

# Schema definitions
schema_abstracts = StructType([
    StructField("doi", StringType()),
    StructField("embedding", ArrayType(FloatType()))
])

schema_tags = StructType([
    StructField("term", StringType()),
    StructField("ontology", StringType()),
    StructField("label", StringType()),
    StructField("group", StringType()),
    StructField("embedding", ArrayType(FloatType()))
])

def create_encode_abstracts_pandas(config_dict):
    @pandas_udf(schema_abstracts, PandasUDFType.GROUPED_MAP)
    def encode_abstracts_pandas(key, pdf):
        mem = ModelEmbeddingManager(config_dict)
        abstracts = [{"doi": row.doi, "title": row.title, "abstract": row.abstract} for _, row in pdf.iterrows()]
        embeddings = mem.encode_abstracts(abstracts)
        result = [{"doi": doi, "embedding": emb.tolist()} for doi, emb_list in embeddings.items() for emb in emb_list]
        return pd.DataFrame(result)
    return encode_abstracts_pandas

def create_encode_tags_pandas(config_dict):
    @pandas_udf(schema_tags, PandasUDFType.GROUPED_MAP)
    def encode_tags_pandas(key, pdf):
        mem = ModelEmbeddingManager(config_dict)
        tags = [{
            "ontology": row.ontology,
            "term": row.term,
            "rdfs_label": row.rdfs_label,
            "description": row.description,
            "group": row.group
        } for _, row in pdf.iterrows()]
        tags_embedding = mem.encode_tags(tags)
        result = [{
            "term": term,
            "ontology": data['ontology'],
            "label": data['label'],
            "group": data['group'],
            "embedding": data['emb'].tolist()
        } for term, data in tags_embedding.items()]
        return pd.DataFrame(result)
    return encode_tags_pandas
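
# Note: both grouped-map UDFs instantiate ModelEmbeddingManager inside the task,
# so each Spark executor loads the encoder locally instead of receiving it from
# the driver (presumably why the Singleton metaclass was dropped in this commit).
# A minimal driver-side sketch of the same call; the config keys come from the
# demo configs, but the values here are illustrative only:
#
#   mem = ModelEmbeddingManager({"encodeur": "sentence-transformers/all-MiniLM-L6-v2"})
#   embs = mem.encode_abstracts([{"doi": "10.1234/x", "title": "t", "abstract": "some text"}])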

def cosine_similarity(vec1, vec2):
    if vec1 is None or vec2 is None:
        return None
    a, b = np.array(vec1), np.array(vec2)
    if a.size == 0 or b.size == 0 or a.shape[0] != b.shape[0]:
        return None
    cosine_sim = np.dot(a, b) / (np.linalg.norm(a) * np.linalg.norm(b))
    return float(cosine_sim) if np.isfinite(cosine_sim) else None
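
# Quick sanity checks for cosine_similarity (illustrative, not part of the commit):
#   cosine_similarity([1.0, 0.0], [1.0, 0.0])  -> 1.0   (identical directions)
#   cosine_similarity([1.0, 0.0], [0.0, 1.0])  -> 0.0   (orthogonal vectors)
#   cosine_similarity([], [1.0])               -> None  (empty or mismatched shapes)
#   cosine_similarity(None, [1.0])             -> None  (missing embedding)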

def check_config(config):
    if 'populate_abstract_embeddings' not in config:
        print("Error: 'populate_abstract_embeddings' parameter is missing in the configuration file.")
        sys.exit(1)

    if 'from_file' not in config['populate_abstract_embeddings']:
        print("Error: 'from_file' parameter is missing in the configuration file.")
        sys.exit(1)

    if 'populate_owl_tag_embeddings' not in config:
        print("Error: 'populate_owl_tag_embeddings' parameter is missing in the configuration file.")
        sys.exit(1)

    if 'ontologies' not in config['populate_owl_tag_embeddings']:
        print("Error: 'ontologies' parameter is missing in the configuration file.")
        sys.exit(1)

def get_abstracts_from_config(config):
    abstracts = []
    from_file = config['populate_abstract_embeddings']['from_file']

    if 'json_dir' in from_file:
        json_dirs = from_file['json_dir']
        abstracts.extend([json_dirs] if isinstance(json_dirs, str) else json_dirs)

    if 'json_file' in from_file:
        json_files = from_file['json_file']
        abstracts.extend([json_files] if isinstance(json_files, str) else json_files)

    if not abstracts:
        print("Warning: No JSON directories or files specified for abstracts.")

    return abstracts
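
# Accepted "from_file" shapes and the resulting read list (paths hypothetical):
#   {"json_file": "a.json"}                         -> ["a.json"]
#   {"json_file": ["a.json", "b.json"]}             -> ["a.json", "b.json"]
#   {"json_dir": "some_dir"}                        -> ["some_dir"]
#   {"json_dir": "some_dir", "json_file": "a.json"} -> ["some_dir", "a.json"]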

def create_spark_session():
    return SparkSession.builder \
        .appName("MetabolomicsSemanticsDL_Annotation") \
        .getOrCreate()

def main(config_file):
    with open(config_file, 'r') as f:
        config = json.load(f)

    check_config(config)

    spark = create_spark_session()

    root_workdir = config_file.split("/").pop().split(".json")[0] + "_workdir/spark"
    print("root:", root_workdir)
    parquet_abstracts_path = root_workdir + "/abstracts_embeddings"
    parquet_tags_path = root_workdir + "/tags_embeddings"
    results = root_workdir + "/results"

    if os.path.exists(parquet_abstracts_path):
        print("Loading abstract embeddings from the existing Parquet output.")
        result_df_doi = spark.read.parquet(parquet_abstracts_path)
    else:
        abstracts = get_abstracts_from_config(config)

        df = spark.read.json(abstracts)
        encode_abstracts_pandas_udf = create_encode_abstracts_pandas(config)
        result_df_doi = df.groupBy("doi").apply(encode_abstracts_pandas_udf)
        result_df_doi.write.mode("overwrite").parquet(parquet_abstracts_path)

    if os.path.exists(parquet_tags_path):
        print("Loading tag embeddings from the existing Parquet output.")
        spark_df_tags = spark.read.parquet(parquet_tags_path)
    else:
        encode_tags_pandas_udf = create_encode_tags_pandas(config)
        mem = ModelEmbeddingManager(config)
        tag_manager = OwlTagManager(config['populate_owl_tag_embeddings'], mem)

        tags_list = []

        for ontology_group_name, ontologies in config['populate_owl_tag_embeddings']['ontologies'].items():
            for ontology in tag_manager.get_ontologies(ontologies):

                filepath = tag_manager._get_local_filepath_ontology(ontology, ontologies[ontology]['format'])
                # read the OWL file content, which may live on the Hadoop cluster
                owl_content = spark.sparkContext.wholeTextFiles(filepath).values().collect()[0]

                tags_list.extend(
                    tag_manager.build_tags_from_owl(
                        ontology,
                        ontology_group_name,
                        ontologies[ontology],
                        -1, owl_content=owl_content)
                )

        spark_df_tags = spark.createDataFrame(tags_list)
        result_df_tags = spark_df_tags.groupBy("term").apply(encode_tags_pandas_udf)
        result_df_tags = result_df_tags.withColumnRenamed('term', 'tag')
        spark_df_tags = result_df_tags
        spark_df_tags.write.mode("overwrite").parquet(parquet_tags_path)

    result_df_doi = result_df_doi.withColumnRenamed("embedding", "abstract_embedding")
    spark_df_tags = spark_df_tags.withColumnRenamed("embedding", "tag_embedding")

    print(f"Number of abstracts: {result_df_doi.count()}")
    print(f"Number of tags: {spark_df_tags.count()}")

    cosine_similarity_udf = udf(cosine_similarity, FloatType())
    try:
        result_df = result_df_doi.crossJoin(spark_df_tags) \
            .withColumn("similarity", cosine_similarity_udf(col("abstract_embedding"), col("tag_embedding"))) \
            .select("doi", "tag", "similarity") \
            .filter(col("similarity") >= config["threshold_similarity_tag_chunk"])

        result_df.show(truncate=False)
        result_df.write.mode("overwrite").parquet(results)
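        # result_df carries one (doi, tag, similarity) row per abstract/term
        # pair above the threshold, e.g. ("10.1234/example", "<term URI>", 0.82)
        # (values illustrative).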

    except Exception as e:
        print(f"An error occurred while computing the similarities: {str(e)}")

    spark.stop()

if __name__ == "__main__":
    parser = argparse.ArgumentParser(description="Metabolomics Semantics DL Annotation")
    parser.add_argument("config_file", help="Path to the configuration file")
    args = parser.parse_args()
    main(args.config_file)