From eeef93196a44dd17a561c9d0592193ff73ba37b8 Mon Sep 17 00:00:00 2001 From: nsantacruz Date: Tue, 9 Apr 2024 14:28:35 +0300 Subject: [PATCH] feat: add cluster caching --- .../summarize_and_embed.py | 55 ++++++++++++++++--- 1 file changed, 48 insertions(+), 7 deletions(-) diff --git a/app/topic_source_curation/summarize_and_embed.py b/app/topic_source_curation/summarize_and_embed.py index 06cbc75..fb189ad 100644 --- a/app/topic_source_curation/summarize_and_embed.py +++ b/app/topic_source_curation/summarize_and_embed.py @@ -24,8 +24,11 @@ - For each cluster chosen, choose one source that is the most interesting / fulfills category quota - Needs thought """ +import json +import os import random -from typing import Any +from typing import Any, Union +from functools import wraps from sklearn.metrics import silhouette_score, pairwise_distances from sklearn.cluster import KMeans from tqdm import tqdm @@ -39,7 +42,7 @@ from basic_langchain.embeddings import OpenAIEmbeddings from basic_langchain.chat_models import ChatOpenAI, ChatAnthropic from basic_langchain.schema import SystemMessage, HumanMessage -from dataclasses import dataclass +from dataclasses import dataclass, asdict random.seed = 567454 @@ -49,6 +52,16 @@ class SummarizedSource: summary: str embedding: np.ndarray = None + def __init__(self, source: Union[TopicPromptSource, dict], summary: str, embedding: np.ndarray = None): + self.source = source if isinstance(source, TopicPromptSource) else TopicPromptSource(**source) + self.summary = summary + self.embedding = np.array(embedding) if embedding is not None else None + + def serialize(self) -> dict: + serial = asdict(self) + serial['embedding'] = self.embedding.tolist() + return serial + @dataclass class SourceCluster: @@ -57,6 +70,12 @@ class SourceCluster: cluster_summary: str embedding: np.ndarray = None + def __init__(self, label: int, summarized_sources: list[Union[SummarizedSource, dict]], cluster_summary: str = None, embedding: np.ndarray = None): + self.label = label + self.summarized_sources = [s if isinstance(s, SummarizedSource) else SummarizedSource(**s) for s in summarized_sources] + self.cluster_summary = cluster_summary + self.embedding = np.array(embedding) if embedding is not None else None + def __len__(self): return len(self.summarized_sources) @@ -66,6 +85,14 @@ def __hash__(self): def __eq__(self, other): return self.label == other.label + def serialize(self) -> dict: + return { + 'label': int(self.label), + 'cluster_summary': self.cluster_summary, + 'summarized_sources': [s.serialize() for s in self.summarized_sources], + "embedding": self.embedding.tolist(), + } + def summarize_topic_page(curated_topic: CuratedTopic) -> list[SummarizedSource]: llm = ChatAnthropic(model='claude-3-haiku-20240307', temperature=0) @@ -206,7 +233,23 @@ def add_embeddings_to_clusters(clusters: list[SourceCluster]) -> list[SourceClus cluster.embedding = embedding return cluster -def summarize_and_embed(curated_topic: CuratedTopic): +def source_cluster_cache(func): + @wraps(func) + def wrapper(curated_topic: CuratedTopic): + cache_filename = f"_cache/source-clusters-{curated_topic.topic.title['en']}.json" + if os.path.exists(cache_filename): + with open(cache_filename, "r") as fin: + raw_clusters = json.load(fin) + return [SourceCluster(**raw_cluster) for raw_cluster in raw_clusters] + clusters = func(curated_topic) + with open(cache_filename, "w") as fout: + json.dump([c.serialize() for c in clusters], fout, ensure_ascii=False) + return clusters + return wrapper + + +@source_cluster_cache +def cluster_by_subtopic(curated_topic: CuratedTopic) -> list[SourceCluster]: # make unique source_by_ref = {s.ref: s for s in curated_topic.sources} curated_topic.sources = list(source_by_ref.values()) @@ -217,14 +260,12 @@ def summarize_and_embed(curated_topic: CuratedTopic): source_clusters = get_source_clusters(n_clusters, curated_topic.topic, summarized_sources, 3) print(f"Num source clusters: {len(source_clusters)}") # sort_by_interest_to_newcomer(cluster_summaries) - source_clusters = sort_by_highest_avg_pairwise_distance(source_clusters) - for cluster in source_clusters: - print(cluster.cluster_summary) + return sort_by_highest_avg_pairwise_distance(source_clusters) if __name__ == '__main__': verbose = True topic_pages = get_exported_topic_pages() - summarize_and_embed(topic_pages[1]) + cluster_by_subtopic(topic_pages[1])