diff --git a/app/util/sefaria_specific.py b/app/util/sefaria_specific.py index 0174c41..c4cade3 100644 --- a/app/util/sefaria_specific.py +++ b/app/util/sefaria_specific.py @@ -58,3 +58,15 @@ def get_ref_text_with_fallback(oref: Ref, lang: str, auto_translate=False) -> st def convert_trefs_to_sources(trefs) -> list[TopicPromptSource]: return [_make_topic_prompt_source(Ref(tref), '', with_commentary=False) for tref in trefs] + + +def remove_refs_from_same_category(refs: list[Ref], max_category_count: int) -> list[Ref]: + from collections import defaultdict + cat_counts = defaultdict(int) + out_refs = [] + for ref in refs: + cat_counts[ref.primary_category] += 1 + if cat_counts[ref.primary_category] > max_category_count: + continue + out_refs.append(ref) + return out_refs diff --git a/app/util/topic.py b/app/util/topic.py index fd6d2da..f9e6f85 100644 --- a/app/util/topic.py +++ b/app/util/topic.py @@ -1,12 +1,13 @@ import django django.setup() from sefaria.model.topic import Topic as SefariaTopic +from sefaria.model.text import Ref from functools import partial from basic_langchain.schema import SystemMessage, HumanMessage from basic_langchain.chat_models import ChatOpenAI, ChatAnthropic from util.general import get_by_xml_tag, run_parallel, summarize_text from util.webpage import get_webpage_text -from util.sefaria_specific import filter_invalid_refs, convert_trefs_to_sources +from util.sefaria_specific import filter_invalid_refs, convert_trefs_to_sources, remove_refs_from_same_category from sefaria_llm_interface.common.topic import Topic from sefaria.helper.topic import get_topic @@ -54,7 +55,8 @@ def get_topic_description_from_webpages(topic: Topic): def get_topic_description_from_top_sources(topic: Topic, verbose=True): - top_trefs = get_top_trefs_from_slug(topic.slug, top_n=5) + top_trefs = get_top_trefs_from_slug(topic.slug, top_n=15) + top_trefs = [r.normal() for r in remove_refs_from_same_category([Ref(tref) for tref in top_trefs], 2)][:6] top_sources = convert_trefs_to_sources(top_trefs) llm = ChatAnthropic(model='claude-3-haiku-20240307', temperature=0) summaries = run_parallel([source.text['en'] for source in top_sources], diff --git a/experiments/topic_source_curation_v2/choose.py b/experiments/topic_source_curation_v2/choose.py index 10b9098..0d3622f 100644 --- a/experiments/topic_source_curation_v2/choose.py +++ b/experiments/topic_source_curation_v2/choose.py @@ -75,6 +75,7 @@ def choose_ideal_clusters(clusters: list[Cluster], max_clusters: int) -> list[Cl def sort_clusters(clusters: list[Cluster], topic:Topic, max_clusters: int) -> list[Cluster]: # sorted_clusters = _sort_by_highest_avg_pairwise_distance(clusters) + print(f"Sorting {len(clusters)} clusters by interestingness...") sorted_clusters = _sort_clusters_by_instruction(clusters, topic) sorted_cluster_items = run_parallel(sorted_clusters, partial(_sort_within_cluster, topic=topic), max_workers=100, desc="Sorting for interestingness within cluster") for cluster, sorted_items in zip(sorted_clusters, sorted_cluster_items): @@ -141,7 +142,7 @@ def gpt_compare(a, b) -> int: def sort_by_instruction(documents, comparison_instruction, key_extraction_func=lambda x:x): from functools import cmp_to_key message_suffix = " The only output should be either '1' or '2' or '0'" - llm = ChatOpenAI(model="gpt-4o", temperature=0) + llm = ChatOpenAI(model='gpt-3.5-turbo-0125', temperature=0) system = SystemMessage( content= comparison_instruction