Skip to content

Commit

Permalink
implement chi square test (#56)
Browse files Browse the repository at this point in the history
  • Loading branch information
rmillikin authored Sep 22, 2023
1 parent 112161b commit 22d5c8a
Show file tree
Hide file tree
Showing 3 changed files with 33 additions and 5 deletions.
13 changes: 13 additions & 0 deletions src/tests/test_kinderminer.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,19 @@ def test_fisher_exact_test():
sort_ratio = km.get_sort_ratio(table)
assert sort_ratio == pytest.approx(15 / 59)

def test_chisq_pvalue():
    """Exercise km.chi_square on three contingency tables."""
    # ten co-occurrences: the p-value is tiny but must match the reference value
    strong_table = [[10, 3000], [2000, 10000000]]
    strong_p = km.chi_square(strong_table)
    assert strong_p == pytest.approx(2.583e-30, abs=1e-30)

    # a single co-occurrence, then a table whose first column is all zeros;
    # both are expected to produce a p-value of exactly 1
    for flat_table in ([[1, 3000], [2000, 10000000]], [[0, 100], [0, 10000000]]):
        flat_p = km.chi_square(flat_table)
        assert flat_p == 1

def test_text_sanitation():
text = 'Testing123****.'
sanitized_text = index.sanitize_term(text)
Expand Down
20 changes: 17 additions & 3 deletions src/workers/kinderminer.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,14 +20,24 @@ def get_contingency_table(a_term_set: set, b_term_set: set, total_n: int):
def fisher_exact(table) -> float:
    """Return the p-value of Fisher's exact test for ``table``.

    The externally defined ``fet_sided`` value is handed to SciPy as the
    ``alternative`` (sidedness) argument.
    """
    test_result = scipy.stats.fisher_exact(table, alternative=fet_sided)
    # element 1 of SciPy's result is the p-value
    return test_result[1]

def chi_square(table) -> float:
    """Return the p-value of a chi-square test of independence for ``table``.

    Args:
        table: contingency table of observed counts (e.g. a 2x2 nested list).

    Returns:
        The p-value as a float, or 1.0 when SciPy rejects the table with a
        ValueError.
    """
    try:
        # The second positional parameter of chi2_contingency is ``correction``
        # (Yates' continuity correction), not a sidedness option. Request the
        # correction explicitly rather than passing the Fisher-test sidedness
        # setting, which only enabled it by being a truthy value.
        result = scipy.stats.chi2_contingency(table, correction=True)
    except ValueError:
        # default to a p-value of 1.0
        # this happens if the sum of a row or column is 0
        return 1.0

    # element 1 of SciPy's result is the p-value
    return float(result[1])

def get_sort_ratio(table) -> float:
    """Return ``table[0][0]`` as a fraction of the first column's total.

    The first column's total is ``table[0][0] + table[1][0]``. When that
    total is zero the ratio is undefined and 0 is returned instead.
    """
    top_left = table[0][0]
    first_column_total = top_left + table[1][0]

    # TODO?: 0 is a placeholder for the undefined (empty first column) case
    return top_left / first_column_total if first_column_total != 0 else 0

def kinderminer_search(a_term: str, b_term: str, idx: Index, censor_year = math.inf, return_pmids = False, top_n_articles = math.inf) -> dict:
def kinderminer_search(a_term: str, b_term: str, idx: Index, censor_year = math.inf,
return_pmids = False, top_n_articles = math.inf,
scoring = 'fet') -> dict:
""""""
start_time = time.perf_counter()
result = dict()
Expand All @@ -48,8 +58,12 @@ def kinderminer_search(a_term: str, b_term: str, idx: Index, censor_year = math.
n_a_and_b = table[0][0]
n_articles = idx.n_articles(censor_year)

# perform fisher's exact test
pvalue = fisher_exact(table)
# perform statistical test (default fisher's exact test)
if scoring == 'chi-square':
pvalue = chi_square(table)
else: # 'fet'
pvalue = fisher_exact(table)

sort_ratio = get_sort_ratio(table)

run_time = time.perf_counter() - start_time
Expand Down
5 changes: 3 additions & 2 deletions src/workers/work.py
Original file line number Diff line number Diff line change
Expand Up @@ -67,6 +67,7 @@ def km_work_all_vs_all(json: dict):

a_terms = json['a_terms']
b_terms = json['b_terms']
scoring = json.get('scoring', 'fet')

if 'c_terms' in json:
# SKiM query
Expand Down Expand Up @@ -107,7 +108,7 @@ def km_work_all_vs_all(json: dict):
b_term = li.the_index.get_highest_priority_term(b_term_set, b_term_token_dict)
b_term_set.remove(b_term)

res = km.kinderminer_search(a_term, b_term, li.the_index, censor_year, return_pmids, top_n_articles)
res = km.kinderminer_search(a_term, b_term, li.the_index, censor_year, return_pmids, top_n_articles, scoring)

if res['pvalue'] <= ab_fet_threshold:
ab_results.append(res)
Expand Down Expand Up @@ -189,7 +190,7 @@ def km_work_all_vs_all(json: dict):
if b_term == c_term:
continue

bc = km.kinderminer_search(b_term, c_term, li.the_index, censor_year, return_pmids, top_n_articles)
bc = km.kinderminer_search(b_term, c_term, li.the_index, censor_year, return_pmids, top_n_articles, scoring)

abc_result['c_term'] = c_term
abc_result['bc_pvalue'] = bc['pvalue']
Expand Down

0 comments on commit 22d5c8a

Please sign in to comment.