From 33a053524e1652264936a970937f4c68aab2bb14 Mon Sep 17 00:00:00 2001
From: nsantacruz
Date: Thu, 11 Apr 2024 14:19:30 +0300
Subject: [PATCH] feat: improve random seed setting. improve summary so it can
 detect when text isn't relevant to the topic

---
 .../summarize_and_embed.py | 41 ++++++++++++-------
 1 file changed, 26 insertions(+), 15 deletions(-)

diff --git a/app/topic_source_curation/summarize_and_embed.py b/app/topic_source_curation/summarize_and_embed.py
index e84f657..ad87f09 100644
--- a/app/topic_source_curation/summarize_and_embed.py
+++ b/app/topic_source_curation/summarize_and_embed.py
@@ -44,7 +44,8 @@
 from basic_langchain.schema import SystemMessage, HumanMessage
 from dataclasses import dataclass, asdict
 
-random.seed(567454)
+RANDOM_SEED = 567454
+random.seed(RANDOM_SEED)
 
 @dataclass
 class SummarizedSource:
@@ -100,7 +101,13 @@ def summarize_topic_page(curated_topic: CuratedTopic) -> list[SummarizedSource]:
     topic_str = f"Title: '{topic.title}'. Description: '{topic.description.get('en', 'N/A')}'."
     summaries: list[SummarizedSource] = []
     for source in tqdm(curated_topic.sources, desc=f'summarize_topic_page: {topic.title["en"]}', disable=not verbose):
-        summary = summarize_based_on_uniqueness(source.text.get('en', source.text['he']), topic_str, llm)
+        source_text = source.text['en'] if len(source.text['en']) > 0 else source.text['he']
+        if len(source_text) == 0:
+            continue
+        summary = summarize_based_on_uniqueness(source_text, topic_str, llm, "English")
+        if summary is None:
+            print("No summary: {}".format(source_text))
+            continue
         summaries.append(SummarizedSource(source, summary))
     return summaries
 
@@ -116,9 +123,9 @@ def guess_optimal_clustering(embeddings):
     best_num_clusters = 0
     n_clusters = range(2, len(embeddings)//2)
     for n_cluster in tqdm(n_clusters, total=len(n_clusters), desc='guess optimal clustering', disable=not verbose):
-        kmeans = KMeans(n_clusters=n_cluster, n_init='auto', random_state=random.seed).fit(embeddings)
+        kmeans = KMeans(n_clusters=n_cluster, n_init='auto', random_state=RANDOM_SEED).fit(embeddings)
         labels = kmeans.labels_
-        sil_coeff = silhouette_score(embeddings, labels, metric='cosine', random_state=random.seed)
+        sil_coeff = silhouette_score(embeddings, labels, metric='cosine', random_state=RANDOM_SEED)
         if sil_coeff > best_sil_coeff:
             best_sil_coeff = sil_coeff
             best_num_clusters = n_cluster
@@ -207,14 +214,14 @@ def sort_by_interest_to_newcomer(documents: list[str]):
 
 def get_source_clusters(n_clusters: int, topic: Topic, summarized_sources: list[SummarizedSource], min_size: int) -> list[SourceCluster]:
     embeddings = [s.embedding for s in summarized_sources]
-    kmeans = KMeans(n_clusters=n_clusters, n_init='auto', random_state=random.seed).fit(embeddings)
+    kmeans = KMeans(n_clusters=n_clusters, n_init='auto', random_state=RANDOM_SEED).fit(embeddings)
     clusters = []
     for label in tqdm(set(kmeans.labels_), desc="summarizing clusters", disable=not verbose):
         indices = np.where(kmeans.labels_ == label)[0]
         curr_sources = [summarized_sources[j] for j in indices]
         if len(curr_sources) < min_size:
             continue
-        cluster_summary = summarize_cluster(topic, curr_sources)
+        cluster_summary = curr_sources[0].summary if len(curr_sources) == 1 else summarize_cluster(topic, curr_sources)
         clusters += [SourceCluster(label, curr_sources, cluster_summary)]
     clusters = add_embeddings_to_clusters(clusters)
     return clusters
@@ -255,7 +262,7 @@ def cluster_by_subtopic(curated_topic: CuratedTopic) -> list[SourceCluster]:
     summarized_sources = add_embeddings_to_sources(summarized_sources)
     n_clusters = guess_optimal_clustering([s.embedding for s in summarized_sources])
     print(f"Optimal Clustering: {n_clusters}")
-    source_clusters = get_source_clusters(n_clusters, curated_topic.topic, summarized_sources, 3)
+    source_clusters = get_source_clusters(n_clusters, curated_topic.topic, summarized_sources, 1)
     print(f"Num source clusters: {len(source_clusters)}")
     # sort_by_interest_to_newcomer(cluster_summaries)
     return sort_by_highest_avg_pairwise_distance(source_clusters)
@@ -265,6 +272,7 @@ def get_cluster_diversity_for_dataset(clusters: list[SourceCluster]) -> None:
     from topic_source_curation.common import get_datasets
     from collections import defaultdict
     source_to_label = {s.source.ref: c.label for c in clusters for s in c.summarized_sources}
+    label_to_cluster = {c.label: c for c in clusters}
     bad, good = get_datasets()
     for example in good:
         if example.topic.title['en'] != "Shabbat":
@@ -274,21 +282,24 @@ def get_cluster_diversity_for_dataset(clusters: list[SourceCluster]) -> None:
         for source in example.sources:
             try:
                 cluster_diversity[source_to_label[source.ref]] += [source.ref]
-                print(f"Source {source.ref} found")
             except KeyError:
                 print(f"Source {source.ref} not found")
         print(len(cluster_diversity)/len(example.sources))
+        for label, refs in cluster_diversity.items():
+            print(f"{label}({len(label_to_cluster[label])}): {label_to_cluster[label].cluster_summary}")
+            for ref in refs:
+                print('\t', ref)
 
 
 if __name__ == '__main__':
     verbose = True
     topic_pages = get_exported_topic_pages()
-    clusters = cluster_by_subtopic(topic_pages[1])
-    get_cluster_diversity_for_dataset(clusters)
-    # for c in clusters:
-    #     print('----')
-    #     print(c.cluster_summary)
-    #     for s in c.summarized_sources[:10]:
-    #         print('\t', s.source.ref)
+    clusters = cluster_by_subtopic(topic_pages[2])
+    # get_cluster_diversity_for_dataset(clusters)
+    for c in clusters:
+        print('----')
+        print(c.cluster_summary)
+        for s in c.summarized_sources[:10]:
+            print('\t', s.source.ref)
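
Note on the random_state fix above: the removed calls passed `random.seed`,
which is the stdlib seeding function object rather than a seed value, where
scikit-learn's `random_state` expects an int, a numpy.random.RandomState, or
None. A minimal sketch of the reproducibility the named constant buys, using
made-up toy data rather than anything from this repo:

    import random

    import numpy as np
    from sklearn.cluster import KMeans

    RANDOM_SEED = 567454          # the constant the patch introduces
    random.seed(RANDOM_SEED)      # seeds Python's stdlib RNG

    X = np.random.RandomState(RANDOM_SEED).rand(20, 4)  # toy "embeddings"

    # With an integer seed, repeated fits give identical clusterings:
    a = KMeans(n_clusters=3, n_init='auto', random_state=RANDOM_SEED).fit(X)
    b = KMeans(n_clusters=3, n_init='auto', random_state=RANDOM_SEED).fit(X)
    assert (a.labels_ == b.labels_).all()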
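Note on the summarize_topic_page hunk: `summarize_based_on_uniqueness` now
takes a language argument and, per the commit message, apparently returns None
when the text isn't relevant to the topic, so the loop skips those sources.
The new text selection also assumes `source.text` always carries an 'en' key,
where the removed line used `.get('en', ...)`. A sketch of a hypothetical
helper (`pick_source_text` is not in the repo) that keeps the old tolerance
for a missing key while preserving the new empty-string fallback:

    # Hypothetical helper, not in the repo: empty or missing English text
    # falls back to Hebrew, mirroring the patched selection logic.
    def pick_source_text(text: dict) -> str:
        en = text.get('en', '')
        return en if len(en) > 0 else text.get('he', '')

    assert pick_source_text({'en': '', 'he': 'בראשית'}) == 'בראשית'
    assert pick_source_text({'he': 'בראשית'}) == 'בראשית'  # no 'en' key at all
    assert pick_source_text({'en': 'In the beginning'}) == 'In the beginning'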