Skip to content

Commit

Permalink
feat: add cluster caching
Browse files Browse the repository at this point in the history
  • Loading branch information
nsantacruz committed Apr 9, 2024
1 parent d38cc0e commit eeef931
Showing 1 changed file with 48 additions and 7 deletions.
55 changes: 48 additions & 7 deletions app/topic_source_curation/summarize_and_embed.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,8 +24,11 @@
- For each cluster chosen, choose one source that is the most interesting / fulfills category quota
- Needs thought
"""
import json
import os
import random
from typing import Any
from typing import Any, Union
from functools import wraps
from sklearn.metrics import silhouette_score, pairwise_distances
from sklearn.cluster import KMeans
from tqdm import tqdm
Expand All @@ -39,7 +42,7 @@
from basic_langchain.embeddings import OpenAIEmbeddings
from basic_langchain.chat_models import ChatOpenAI, ChatAnthropic
from basic_langchain.schema import SystemMessage, HumanMessage
from dataclasses import dataclass
from dataclasses import dataclass, asdict

random.seed = 567454

Expand All @@ -49,6 +52,16 @@ class SummarizedSource:
summary: str
embedding: np.ndarray = None

def __init__(self, source: Union[TopicPromptSource, dict], summary: str, embedding: np.ndarray = None):
self.source = source if isinstance(source, TopicPromptSource) else TopicPromptSource(**source)
self.summary = summary
self.embedding = np.array(embedding) if embedding is not None else None

def serialize(self) -> dict:
serial = asdict(self)
serial['embedding'] = self.embedding.tolist()
return serial


@dataclass
class SourceCluster:
Expand All @@ -57,6 +70,12 @@ class SourceCluster:
cluster_summary: str
embedding: np.ndarray = None

def __init__(self, label: int, summarized_sources: list[Union[SummarizedSource, dict]], cluster_summary: str = None, embedding: np.ndarray = None):
self.label = label
self.summarized_sources = [s if isinstance(s, SummarizedSource) else SummarizedSource(**s) for s in summarized_sources]
self.cluster_summary = cluster_summary
self.embedding = np.array(embedding) if embedding is not None else None

def __len__(self):
return len(self.summarized_sources)

Expand All @@ -66,6 +85,14 @@ def __hash__(self):
def __eq__(self, other):
return self.label == other.label

def serialize(self) -> dict:
return {
'label': int(self.label),
'cluster_summary': self.cluster_summary,
'summarized_sources': [s.serialize() for s in self.summarized_sources],
"embedding": self.embedding.tolist(),
}


def summarize_topic_page(curated_topic: CuratedTopic) -> list[SummarizedSource]:
llm = ChatAnthropic(model='claude-3-haiku-20240307', temperature=0)
Expand Down Expand Up @@ -206,7 +233,23 @@ def add_embeddings_to_clusters(clusters: list[SourceCluster]) -> list[SourceClus
cluster.embedding = embedding
return cluster

def summarize_and_embed(curated_topic: CuratedTopic):
def source_cluster_cache(func):
@wraps(func)
def wrapper(curated_topic: CuratedTopic):
cache_filename = f"_cache/source-clusters-{curated_topic.topic.title['en']}.json"
if os.path.exists(cache_filename):
with open(cache_filename, "r") as fin:
raw_clusters = json.load(fin)
return [SourceCluster(**raw_cluster) for raw_cluster in raw_clusters]
clusters = func(curated_topic)
with open(cache_filename, "w") as fout:
json.dump([c.serialize() for c in clusters], fout, ensure_ascii=False)
return clusters
return wrapper


@source_cluster_cache
def cluster_by_subtopic(curated_topic: CuratedTopic) -> list[SourceCluster]:
# make unique
source_by_ref = {s.ref: s for s in curated_topic.sources}
curated_topic.sources = list(source_by_ref.values())
Expand All @@ -217,14 +260,12 @@ def summarize_and_embed(curated_topic: CuratedTopic):
source_clusters = get_source_clusters(n_clusters, curated_topic.topic, summarized_sources, 3)
print(f"Num source clusters: {len(source_clusters)}")
# sort_by_interest_to_newcomer(cluster_summaries)
source_clusters = sort_by_highest_avg_pairwise_distance(source_clusters)
for cluster in source_clusters:
print(cluster.cluster_summary)
return sort_by_highest_avg_pairwise_distance(source_clusters)


if __name__ == '__main__':
verbose = True
topic_pages = get_exported_topic_pages()
summarize_and_embed(topic_pages[1])
cluster_by_subtopic(topic_pages[1])


0 comments on commit eeef931

Please sign in to comment.