Skip to content

Commit

Permalink
Merge pull request #9 from harmonydata/search-instruments
Browse files Browse the repository at this point in the history
Search instruments
  • Loading branch information
woodthom2 authored Sep 6, 2024
2 parents 2f1a8f7 + 23c8086 commit e33911a
Show file tree
Hide file tree
Showing 6 changed files with 223 additions and 58 deletions.
147 changes: 101 additions & 46 deletions harmony_api/helpers.py
Original file line number Diff line number Diff line change
Expand Up @@ -79,7 +79,7 @@ def get_example_instruments() -> List[Instrument]:
example_instruments = []

with open(
str(os.getcwd()) + "/example_questionnaires.json", "r", encoding="utf-8"
str(os.getcwd()) + "/example_questionnaires.json", "r", encoding="utf-8"
) as file:
for line in file:
instrument = Instrument.model_validate_json(line)
Expand Down Expand Up @@ -108,24 +108,24 @@ def get_mhc_embeddings(model_name: str) -> tuple:
data_path = os.path.join(dir_path, "../mhc_embeddings") # submodule

with open(
os.path.join(data_path, "mhc_questions.txt"), "r", encoding="utf-8"
os.path.join(data_path, "mhc_questions.txt"), "r", encoding="utf-8"
) as file:
for line in file:
mhc_question = Question(question_text=line)
mhc_questions.append(mhc_question)

with open(
os.path.join(data_path, "mhc_all_metadatas.json"), "r", encoding="utf-8"
os.path.join(data_path, "mhc_all_metadatas.json"), "r", encoding="utf-8"
) as file:
for line in file:
mhc_meta = json.loads(line)
mhc_all_metadata.append(mhc_meta)

with open(
os.path.join(
data_path, f"mhc_embeddings_{model_name.replace('/', '-')}.npy"
),
"rb",
os.path.join(
data_path, f"mhc_embeddings_{model_name.replace('/', '-')}.npy"
),
"rb",
) as file:
mhc_embeddings = np.load(file, allow_pickle=True)
except (Exception,) as e:
Expand Down Expand Up @@ -153,8 +153,8 @@ def get_catalogue_data_default() -> dict:
else:
if settings.AZURE_STORAGE_URL:
with requests.get(
url=f"{settings.AZURE_STORAGE_URL}/catalogue_data/{all_questions_ever_seen_json}",
stream=True,
url=f"{settings.AZURE_STORAGE_URL}/catalogue_data/{all_questions_ever_seen_json}",
stream=True,
) as response:
if response.ok:
buffer = BytesIO()
Expand All @@ -171,8 +171,8 @@ def get_catalogue_data_default() -> dict:
else:
if settings.AZURE_STORAGE_URL:
with requests.get(
url=f"{settings.AZURE_STORAGE_URL}/catalogue_data/{instrument_idx_to_question_idxs_json}",
stream=True,
url=f"{settings.AZURE_STORAGE_URL}/catalogue_data/{instrument_idx_to_question_idxs_json}",
stream=True,
) as response:
if response.ok:
buffer = BytesIO()
Expand All @@ -193,8 +193,8 @@ def get_catalogue_data_default() -> dict:
else:
if settings.AZURE_STORAGE_URL:
with requests.get(
url=f"{settings.AZURE_STORAGE_URL}/catalogue_data/{all_instruments_preprocessed_json}",
stream=True,
url=f"{settings.AZURE_STORAGE_URL}/catalogue_data/{all_instruments_preprocessed_json}",
stream=True,
) as response:
if response.ok:
buffer = BytesIO()
Expand Down Expand Up @@ -237,8 +237,8 @@ def get_catalogue_data_model_embeddings(model: dict) -> np.ndarray:
decompressor_results = []
decompressor = bz2.BZ2Decompressor()
with requests.get(
url=f"{settings.AZURE_STORAGE_URL}/catalogue_data/{embeddings_filename}",
stream=True,
url=f"{settings.AZURE_STORAGE_URL}/catalogue_data/{embeddings_filename}",
stream=True,
) as response:
if response.ok:
for chunk in response.iter_content(chunk_size=1024):
Expand All @@ -252,12 +252,22 @@ def get_catalogue_data_model_embeddings(model: dict) -> np.ndarray:
return all_embeddings_concatenated


def filter_catalogue_data(catalogue_data: dict, sources: List[str]) -> dict:
def filter_catalogue_data(
catalogue_data: dict,
sources: List[str] | None = None,
topics: List[str] | None = None,
instrument_length_min: int | None = None,
instrument_length_max: int | None = None,
) -> dict:
"""
Filter catalogue data to only keep instruments with the sources.
:param catalogue_data: Catalogue data.
:param sources: Only keep instruments from sources.
:param topics: Only keep instruments with these topics. Topics can be found in the metadata of each instrument.
:param instrument_length_min: Only keep instruments with min number of questions.
:param instrument_length_max: Only keep instruments with max number of questions.
:return: The filtered catalogue data.
"""

def normalize_text(text: str):
Expand All @@ -266,17 +276,30 @@ def normalize_text(text: str):

return text

# Lowercase sources
sources_set = {x.strip().lower() for x in sources if x.strip()}
if not sources:
sources = []
if not topics:
topics = []

# If the value for any of these is less than 1, set it to 1
if instrument_length_min and (instrument_length_min < 1):
instrument_length_min = 1
if instrument_length_max and (instrument_length_max < 1):
instrument_length_max = 1

# Nothing to filter
if not sources_set:
return catalogue_data
# If min length is bigger than max length, set the min length to equal the max length
if instrument_length_min and instrument_length_max:
if instrument_length_min > instrument_length_max:
instrument_length_min = instrument_length_max

# Lowercase sources and topics
sources_set = {x.strip().lower() for x in sources if x.strip()}
topics_set = {x.strip().lower() for x in topics if x.strip()}

# Create a dictionary with questions and their vectors
question_normalized_to_vector: dict[str, List[float]] = {}
for question, vector in zip(
catalogue_data["all_questions"], catalogue_data["all_embeddings_concatenated"]
catalogue_data["all_questions"], catalogue_data["all_embeddings_concatenated"]
):
question_normalized = normalize_text(question)
if question_normalized not in question_normalized_to_vector:
Expand All @@ -285,13 +308,45 @@ def normalize_text(text: str):
# Find instrument indexes to remove
idxs_instruments_to_remove: List[int] = []
for instrument_idx, catalogue_instrument in enumerate(
catalogue_data["all_instruments"]
catalogue_data["all_instruments"]
):
if (
catalogue_instrument["metadata"]["source"].strip().lower()
not in sources_set
):
idxs_instruments_to_remove.append(instrument_idx)
questions_len = len(catalogue_instrument["questions"])

# By min instrument questions length
if instrument_length_min:
if questions_len < instrument_length_min:
idxs_instruments_to_remove.append(instrument_idx)
continue

# By max instrument questions length
if instrument_length_max:
if questions_len > instrument_length_max:
idxs_instruments_to_remove.append(instrument_idx)
continue

# By sources
if sources_set:
if (
catalogue_instrument["metadata"]["source"].strip().lower()
not in sources_set
):
idxs_instruments_to_remove.append(instrument_idx)
continue

# By topics
if topics_set:
not_found_topics_len = 0
catalogue_instrument_topics: list[str] = catalogue_instrument[
"metadata"
].get("topics", [])
for topic in topics_set:
if topic not in [
x.strip().lower() for x in catalogue_instrument_topics if x.strip()
]:
not_found_topics_len += 1
if not_found_topics_len == len(topics_set):
idxs_instruments_to_remove.append(instrument_idx)
continue

# Remove instruments
for idx_instrument_to_remove in sorted(idxs_instruments_to_remove, reverse=True):
Expand Down Expand Up @@ -375,7 +430,7 @@ def check_model_availability(model: dict) -> bool:


def get_cached_text_vectors(
instruments: List[Instrument], model: dict, query: str | None = None
instruments: List[Instrument], model: dict, query: str | None = None
) -> dict[str, List[float]]:
"""
Get cached text vectors.
Expand Down Expand Up @@ -432,63 +487,63 @@ def get_vectorisation_function_for_model(model: dict) -> Callable | None:
vectorisation_function: Callable | None = None

if (
model["framework"] == HUGGINGFACE_MINILM_L12_V2["framework"]
and model["model"] == HUGGINGFACE_MINILM_L12_V2["model"]
model["framework"] == HUGGINGFACE_MINILM_L12_V2["framework"]
and model["model"] == HUGGINGFACE_MINILM_L12_V2["model"]
):
vectorisation_function = (
hugging_face_embeddings.get_hugging_face_embeddings_minilm_l12_v2
)

elif (
model["framework"] == HUGGINGFACE_MPNET_BASE_V2["framework"]
and model["model"] == HUGGINGFACE_MPNET_BASE_V2["model"]
model["framework"] == HUGGINGFACE_MPNET_BASE_V2["framework"]
and model["model"] == HUGGINGFACE_MPNET_BASE_V2["model"]
):
vectorisation_function = (
hugging_face_embeddings.get_hugging_face_embeddings_mpnet_base_v2
)

elif (
model["framework"] == OPENAI_ADA_02["framework"]
and model["model"] == OPENAI_ADA_02["model"]
model["framework"] == OPENAI_ADA_02["framework"]
and model["model"] == OPENAI_ADA_02["model"]
):
vectorisation_function = openai_embeddings.get_openai_embeddings_ada_02
elif (
model["framework"] == OPENAI_3_LARGE["framework"]
and model["model"] == OPENAI_3_LARGE["model"]
model["framework"] == OPENAI_3_LARGE["framework"]
and model["model"] == OPENAI_3_LARGE["model"]
):
vectorisation_function = openai_embeddings.get_openai_embeddings_3_large
elif (
model["framework"] == AZURE_OPENAI_3_LARGE["framework"]
and model["model"] == AZURE_OPENAI_3_LARGE["model"]
model["framework"] == AZURE_OPENAI_3_LARGE["framework"]
and model["model"] == AZURE_OPENAI_3_LARGE["model"]
):
vectorisation_function = (
azure_openai_embeddings.get_azure_openai_embeddings_3_large
)
elif (
model["framework"] == AZURE_OPENAI_ADA_02["framework"]
and model["model"] == AZURE_OPENAI_ADA_02["model"]
model["framework"] == AZURE_OPENAI_ADA_02["framework"]
and model["model"] == AZURE_OPENAI_ADA_02["model"]
):
vectorisation_function = (
azure_openai_embeddings.get_azure_openai_embeddings_ada_02
)
elif (
model["framework"] == GOOGLE_GECKO_MULTILINGUAL["framework"]
and model["model"] == GOOGLE_GECKO_MULTILINGUAL["model"]
model["framework"] == GOOGLE_GECKO_MULTILINGUAL["framework"]
and model["model"] == GOOGLE_GECKO_MULTILINGUAL["model"]
):
vectorisation_function = (
google_embeddings.get_google_embeddings_gecko_multilingual
)
elif (
model["framework"] == GOOGLE_GECKO_003["framework"]
and model["model"] == GOOGLE_GECKO_003["model"]
model["framework"] == GOOGLE_GECKO_003["framework"]
and model["model"] == GOOGLE_GECKO_003["model"]
):
vectorisation_function = google_embeddings.get_google_embeddings_gecko_003

return vectorisation_function


def assign_missing_ids_to_instruments(
instruments: List[Instrument],
instruments: List[Instrument],
) -> List[Instrument]:
"""
Assign missing IDs to instruments.
Expand Down
Loading

0 comments on commit e33911a

Please sign in to comment.