feat: improve random seed setting. improve summary so it can potentially know if text isn't relevant to topic
nsantacruz committed Apr 11, 2024
1 parent 0981a29 commit 33a0535
Showing 1 changed file with 26 additions and 15 deletions.
41 changes: 26 additions & 15 deletions app/topic_source_curation/summarize_and_embed.py
@@ -44,7 +44,8 @@
 from basic_langchain.schema import SystemMessage, HumanMessage
 from dataclasses import dataclass, asdict
 
-random.seed(567454)
+RANDOM_SEED = 567454
+random.seed(RANDOM_SEED)
 
 @dataclass
 class SummarizedSource:
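The new constant is more than a rename: later hunks in this file show that the old code passed random.seed — the seeding function object itself — as random_state, which is not a valid seed; scikit-learn expects an int, a numpy.random.RandomState instance, or None. Naming the integer once lets the same value drive both the stdlib RNG and every estimator. A minimal sketch of that pattern, assuming only NumPy and scikit-learn (the variable names are illustrative):

import random
import numpy as np
from sklearn.cluster import KMeans

RANDOM_SEED = 567454
random.seed(RANDOM_SEED)     # seed Python's stdlib RNG
np.random.seed(RANDOM_SEED)  # seed NumPy's legacy global RNG

# Estimators take the integer directly, making each fit reproducible.
kmeans = KMeans(n_clusters=3, n_init='auto', random_state=RANDOM_SEED)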
@@ -100,7 +101,13 @@ def summarize_topic_page(curated_topic: CuratedTopic) -> list[SummarizedSource]:
     topic_str = f"Title: '{topic.title}'. Description: '{topic.description.get('en', 'N/A')}'."
     summaries: list[SummarizedSource] = []
     for source in tqdm(curated_topic.sources, desc=f'summarize_topic_page: {topic.title["en"]}', disable=not verbose):
-        summary = summarize_based_on_uniqueness(source.text.get('en', source.text['he']), topic_str, llm)
+        source_text = source.text['en'] if len(source.text['en']) > 0 else source.text['he']
+        if len(source_text) == 0:
+            continue
+        summary = summarize_based_on_uniqueness(source_text, topic_str, llm, "English")
+        if summary is None:
+            print("No summary: {}".format(source_text))
+            continue
         summaries.append(SummarizedSource(source, summary))
     return summaries
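This hunk changes the per-source flow in three ways: it prefers the English text and falls back to Hebrew only when the English is empty, it skips sources with no text at all, and it accepts None from summarize_based_on_uniqueness, which is what lets the pipeline notice that a text isn't relevant to the topic, per the commit message. The summarizer itself sits outside this diff; the following is a hypothetical sketch of a None-returning variant built on the SystemMessage/HumanMessage imports shown above — the prompt wording and the llm call signature are assumptions, not the repository's code:

def summarize_based_on_uniqueness(source_text, topic_str, llm, lang="English"):
    # Hypothetical: ask the model to summarize only if the text is about
    # the topic, and to reply with a sentinel when it is not.
    system = SystemMessage(content=(
        f"Summarize the following text in {lang} as it relates to this topic. "
        f"{topic_str} If the text is not relevant to the topic, respond with exactly: N/A"
    ))
    response = llm([system, HumanMessage(content=source_text)])
    summary = response.content.strip()
    return None if summary == "N/A" else summary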

@@ -116,9 +123,9 @@ def guess_optimal_clustering(embeddings):
     best_num_clusters = 0
     n_clusters = range(2, len(embeddings)//2)
     for n_cluster in tqdm(n_clusters, total=len(n_clusters), desc='guess optimal clustering', disable=not verbose):
-        kmeans = KMeans(n_clusters=n_cluster, n_init='auto', random_state=random.seed).fit(embeddings)
+        kmeans = KMeans(n_clusters=n_cluster, n_init='auto', random_state=RANDOM_SEED).fit(embeddings)
         labels = kmeans.labels_
-        sil_coeff = silhouette_score(embeddings, labels, metric='cosine', random_state=random.seed)
+        sil_coeff = silhouette_score(embeddings, labels, metric='cosine', random_state=RANDOM_SEED)
         if sil_coeff > best_sil_coeff:
             best_sil_coeff = sil_coeff
             best_num_clusters = n_cluster
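guess_optimal_clustering picks the k that maximizes the mean silhouette coefficient: for each sample i, with a(i) the mean distance to the rest of its own cluster and b(i) the mean distance to the nearest other cluster, s(i) = (b(i) - a(i)) / max(a(i), b(i)), and silhouette_score averages s(i) over all samples, so values near 1 mean tight, well-separated clusters. (silhouette_score only consults random_state when sample_size is set, so the argument here is inert but harmless.) A self-contained toy run of the same heuristic, on synthetic data that is purely illustrative:

import numpy as np
from sklearn.cluster import KMeans
from sklearn.metrics import silhouette_score

rng = np.random.default_rng(567454)
# Two tight clusters pointing along different axes, so cosine distance separates them.
toy = np.vstack([
    rng.normal([1, 0, 0, 0, 0], 0.05, (20, 5)),
    rng.normal([0, 1, 0, 0, 0], 0.05, (20, 5)),
])
best_k, best_score = 0, -1.0
for k in range(2, len(toy) // 2):
    labels = KMeans(n_clusters=k, n_init='auto', random_state=567454).fit_predict(toy)
    score = silhouette_score(toy, labels, metric='cosine')
    if score > best_score:
        best_k, best_score = k, score
print(best_k, round(best_score, 3))  # expect k == 2 with a score near 1.0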
@@ -207,14 +214,14 @@ def sort_by_interest_to_newcomer(documents: list[str]):
 
 def get_source_clusters(n_clusters: int, topic: Topic, summarized_sources: list[SummarizedSource], min_size: int) -> list[SourceCluster]:
     embeddings = [s.embedding for s in summarized_sources]
-    kmeans = KMeans(n_clusters=n_clusters, n_init='auto', random_state=random.seed).fit(embeddings)
+    kmeans = KMeans(n_clusters=n_clusters, n_init='auto', random_state=RANDOM_SEED).fit(embeddings)
     clusters = []
     for label in tqdm(set(kmeans.labels_), desc="summarizing clusters", disable=not verbose):
         indices = np.where(kmeans.labels_ == label)[0]
         curr_sources = [summarized_sources[j] for j in indices]
         if len(curr_sources) < min_size:
             continue
-        cluster_summary = summarize_cluster(topic, curr_sources)
+        cluster_summary = curr_sources[0].summary if len(curr_sources) == 1 else summarize_cluster(topic, curr_sources)
         clusters += [SourceCluster(label, curr_sources, cluster_summary)]
     clusters = add_embeddings_to_clusters(clusters)
     return clusters
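Two changes land here: the KMeans call now receives the integer RANDOM_SEED instead of the random.seed function object, and a singleton cluster reuses its lone source's existing summary rather than spending a summarize_cluster LLM call that could only restate it.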
@@ -255,7 +262,7 @@ def cluster_by_subtopic(curated_topic: CuratedTopic) -> list[SourceCluster]:
     summarized_sources = add_embeddings_to_sources(summarized_sources)
     n_clusters = guess_optimal_clustering([s.embedding for s in summarized_sources])
     print(f"Optimal Clustering: {n_clusters}")
-    source_clusters = get_source_clusters(n_clusters, curated_topic.topic, summarized_sources, 3)
+    source_clusters = get_source_clusters(n_clusters, curated_topic.topic, summarized_sources, 1)
     print(f"Num source clusters: {len(source_clusters)}")
     # sort_by_interest_to_newcomer(cluster_summaries)
     return sort_by_highest_avg_pairwise_distance(source_clusters)
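Lowering min_size from 3 to 1 means get_source_clusters no longer discards small clusters: with min_size=1 the len(curr_sources) < min_size check can never fire, so every source survives into some cluster, at the cost of noisier singletons entering the ranking that follows.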
@@ -265,6 +272,7 @@ def get_cluster_diversity_for_dataset(clusters: list[SourceCluster]) -> None:
     from topic_source_curation.common import get_datasets
     from collections import defaultdict
     source_to_label = {s.source.ref: c.label for c in clusters for s in c.summarized_sources}
+    label_to_cluster = {c.label: c for c in clusters}
     bad, good = get_datasets()
     for example in good:
         if example.topic.title['en'] != "Shabbat":
@@ -274,21 +282,24 @@
         for source in example.sources:
             try:
                 cluster_diversity[source_to_label[source.ref]] += [source.ref]
+                print(f"Source {source.ref} found")
             except KeyError:
                 print(f"Source {source.ref} not found")
         print(len(cluster_diversity)/len(example.sources))
         for label, refs in cluster_diversity.items():
+            print(f"{label}({len(label_to_cluster[label])}): {label_to_cluster[label].cluster_summary}")
             for ref in refs:
                 print('\t', ref)
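The ratio printed by len(cluster_diversity)/len(example.sources) is the number of distinct clusters the curated sources landed in, divided by the number of curated sources: values near 1 mean the curation spreads across many subtopics, while low values mean many sources collapsed into the same cluster.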


 if __name__ == '__main__':
     verbose = True
     topic_pages = get_exported_topic_pages()
-    clusters = cluster_by_subtopic(topic_pages[1])
-    get_cluster_diversity_for_dataset(clusters)
-    # for c in clusters:
-    #     print('----')
-    #     print(c.cluster_summary)
-    #     for s in c.summarized_sources[:10]:
-    #         print('\t', s.source.ref)
+    clusters = cluster_by_subtopic(topic_pages[2])
+    # get_cluster_diversity_for_dataset(clusters)
+    for c in clusters:
+        print('----')
+        print(c.cluster_summary)
+        for s in c.summarized_sources[:10]:
+            print('\t', s.source.ref)

