diff --git a/.env b/.env
index 9b63b9d..0bf600b 100644
--- a/.env
+++ b/.env
@@ -1,5 +1,6 @@
 # change this to a directory on your local machine to store pubmed articles
 PUBMED_DIR=/path/to/pubmed/folder
+NEO4J_DIR=/path/to/neo4j/folder
 
 # password hash (password is 'password' by default; to change it, you need
 # to generate a hash yourself using bcrypt and put it here)
diff --git a/.gitignore b/.gitignore
index 51e40c4..a2c9692 100755
--- a/.gitignore
+++ b/.gitignore
@@ -11,4 +11,5 @@ project/tests/htmlcov
 ./.env
 *Icon*
 **/Index/**
-**venv**
\ No newline at end of file
+**venv**
+src/tests/test_data/indexer/neo4j**
\ No newline at end of file
diff --git a/docker-compose.yml b/docker-compose.yml
index 32fd1d3..caf6402 100644
--- a/docker-compose.yml
+++ b/docker-compose.yml
@@ -37,5 +37,18 @@ services:
     networks:
       - fast_km-network
 
+  neo4j:
+    image: neo4j
+    environment:
+      NEO4J_AUTH: neo4j/mypass
+    volumes:
+      - ${NEO4J_DIR}/data:/data
+      - ${NEO4J_DIR}/logs:/logs
+    ports:
+      - "7474:7474"
+      - "7687:7687"
+    networks:
+      - fast_km-network
+
 networks:
   fast_km-network:
diff --git a/requirements.txt b/requirements.txt
index 49d51df..95f6e2b 100755
--- a/requirements.txt
+++ b/requirements.txt
@@ -12,4 +12,6 @@ rq-dashboard==0.6.1
 scipy==1.7.1
 flask-bcrypt==0.7.1
 pymongo==4.0.1
-Werkzeug==2.0.2
\ No newline at end of file
+Werkzeug==2.0.2
+neo4j==4.4.5
+py2neo==2021.2.3
\ No newline at end of file
diff --git a/src/indexing/index.py b/src/indexing/index.py
index 30345bf..855d54c 100644
--- a/src/indexing/index.py
+++ b/src/indexing/index.py
@@ -10,6 +10,8 @@ from indexing.abstract_catalog import AbstractCatalog
 
 delim = '\t'
+logical_or = '/' # supports '/' to mean 'or'
+logical_and = '&' # supports '&' to mean 'and'
 mongo_cache = None
 
 class Index():
@@ -17,6 +19,7 @@ def __init__(self, pubmed_abstract_dir: str):
         # caches
         self._query_cache = dict()
         self._token_cache = dict()
+        self._date_censored_query_cache = dict()
         self._n_articles_by_pub_year = dict()
 
         _connect_to_mongo()
@@ -26,6 +29,7 @@ def __init__(self, pubmed_abstract_dir: str):
         self._abstract_catalog = util.get_abstract_catalog(pubmed_abstract_dir)
         self._byte_offsets = dict()
         self._publication_years = dict()
+        self._date_censored_pmids = dict()
         self._init_byte_info()
 
         self._open_connection()
@@ -33,41 +37,51 @@ def close_connection(self) -> None:
         self.connection.close()
         self.file_obj.close()
 
-    def query_index(self, query: str) -> 'set[int]':
-        query = query.lower().strip()
-
-        if query in self._query_cache:
-            return self._query_cache[query]
+    def construct_abstract_set(self, term: str) -> set:
+        # TODO: support parentheses for allowing OR and AND at the same time?
+ # e.g., "(cancer/carcinoma) & BRCA1" + + term = sanitize_term(term) + is_cached, pmid_set = self.check_caches_for_term(term) + + if is_cached: + return pmid_set + + if logical_or in term: + terms = term.split(logical_or) + pmid_set = set() + for synonym in terms: + pmid_set.update(self._query_index(synonym)) + elif logical_and in term: + terms = term.split(logical_and) + pmid_set = self._query_index(terms[0]) + for t in terms[1:]: + pmid_set.intersection_update(self._query_index(t)) else: - result = _check_mongo_for_query(query) - if not isinstance(result, type(None)): - self._query_cache[query] = result - return result + pmid_set = self._query_index(term) - tokens = util.get_tokens(query) + if len(pmid_set) < 10000: + _place_in_mongo(term, pmid_set) + self._query_cache[term] = pmid_set - if len(tokens) > 100: - raise ValueError("Query must have <=100 words") - if not tokens: - return set() + return pmid_set - result = self._query_disk(tokens) + def censor_by_year(self, pmids: 'set[int]', censor_year: int, term: str) -> 'set[int]': + if censor_year not in self._date_censored_pmids: + censored_set = set() - if len(result) < 10000 or len(tokens) > 1: - _place_in_mongo(query, result) + for pmid, year in self._publication_years.items(): + if year <= censor_year: + censored_set.add(pmid) + self._date_censored_pmids[censor_year] = censored_set - self._query_cache[query] = result + if (term, censor_year) in self._date_censored_query_cache: + return self._date_censored_query_cache[(term, censor_year)] + + date_censored_pmid_set = self._date_censored_pmids[censor_year] & pmids + self._date_censored_query_cache[(term, censor_year)] = date_censored_pmid_set - return result - - def censor_by_year(self, pmids: 'set[int]', censor_year: int) -> 'set[int]': - censored_set = set() - - for pmid in pmids: - if self._publication_years[pmid] <= censor_year: - censored_set.add(pmid) - - return censored_set + return date_censored_pmid_set def n_articles(self, censor_year = math.inf) -> int: """Returns the number of indexed abstracts, given an optional @@ -97,6 +111,44 @@ def decache_token(self, token: str): if ltoken in self._query_cache: del self._query_cache[ltoken] + def check_caches_for_term(self, term: str): + if term in self._query_cache: + # check RAM cache + return (True, self._query_cache[term]) + else: + # check mongoDB cache + result = _check_mongo_for_query(term) + if not isinstance(result, type(None)): + self._query_cache[term] = result + return (True, result) + + return (False, None) + + def _query_index(self, query: str) -> 'set[int]': + query = util.sanitize_text(query) + + is_cached, result = self.check_caches_for_term(query) + if is_cached: + return result + + tokens = util.get_tokens(query) + + if len(tokens) > 100: + print("Query failed, must have <=100 words; query was " + query) + return set() + # raise ValueError("Query must have <=100 words") + if not tokens: + return set() + + result = self._query_disk(tokens) + + if len(result) < 10000 or len(tokens) > 1: + _place_in_mongo(query, result) + + self._query_cache[query] = result + + return result + def _open_connection(self) -> None: if not os.path.exists(self._bin_path): print('warning: index does not exist and needs to be built') @@ -217,7 +269,26 @@ def _check_if_mongo_should_be_refreshed(self, terms_to_check: 'list[str]' = ['fe return False -def _intersect_dict_keys(dicts: 'list[dict]'): +def sanitize_term(term: str) -> str: + if logical_or in term or logical_and in term: + sanitized_subterms = [] + + if logical_or in term: + 
+            string_joiner = logical_or
+        elif logical_and in term:
+            string_joiner = logical_and
+
+        for subterm in term.split(string_joiner):
+            sanitized_subterms.append(util.sanitize_text(subterm))
+
+        sanitized_subterms.sort()
+        sanitized_term = str.join(string_joiner, sanitized_subterms)
+    else:
+        sanitized_term = util.sanitize_text(term)
+
+    return sanitized_term
+
+def _intersect_dict_keys(dicts: 'list[dict]') -> set:
     lowest_n_keys = sorted(dicts, key=lambda x: len(x))[0]
     key_intersect = set(lowest_n_keys.keys())
@@ -232,7 +303,7 @@ def _intersect_dict_keys(dicts: 'list[dict]'):
 
     return key_intersect
 
-def _connect_to_mongo():
+def _connect_to_mongo() -> None:
     # TODO: set expiration time for cached items (72h, etc.?)
     # mongo_cache.create_index('query', unique=True) #expireafterseconds=72 * 60 * 60,
     global mongo_cache
@@ -246,7 +317,7 @@ def _connect_to_mongo():
         print('warning: could not find a MongoDB instance to use as a query cache')
         mongo_cache = None
 
-def _check_mongo_for_query(query: str):
+def _check_mongo_for_query(query: str) -> 'set[int]':
     if not isinstance(mongo_cache, type(None)):
         try:
             result = mongo_cache.find_one({'query': query})
@@ -261,7 +332,7 @@
     else:
         return None
 
-def _place_in_mongo(query, result):
+def _place_in_mongo(query: str, result: 'set[int]') -> None:
     if not isinstance(mongo_cache, type(None)):
         try:
             mongo_cache.insert_one({'query': query, 'result': list(result)})
@@ -277,7 +348,7 @@
     else:
         pass
 
-def _empty_mongo():
+def _empty_mongo() -> None:
     if not isinstance(mongo_cache, type(None)):
         x = mongo_cache.delete_many({})
         print('mongodb cache cleared, ' + str(x.deleted_count) + ' items were deleted')
\ No newline at end of file
diff --git a/src/indexing/km_util.py b/src/indexing/km_util.py
index 45da689..c7f8f68 100755
--- a/src/indexing/km_util.py
+++ b/src/indexing/km_util.py
@@ -50,6 +50,9 @@ def get_tokens(text: str) -> 'list[str]':
     tokens = tokenizer.tokenize(l_text)
     return tokens
 
+def sanitize_text(text: str) -> str:
+    return str.join(' ', get_tokens(text))
+
 def get_index_dir(abstracts_dir: str) -> str:
     return os.path.join(abstracts_dir, 'Index')
diff --git a/src/knowledge_graph/knowledge_graph.py b/src/knowledge_graph/knowledge_graph.py
new file mode 100644
index 0000000..6afb1c5
--- /dev/null
+++ b/src/knowledge_graph/knowledge_graph.py
@@ -0,0 +1,174 @@
+from py2neo import Graph, Node, Relationship, NodeMatcher
+from py2neo.bulk import create_nodes, create_relationships, merge_relationships
+from itertools import islice
+import indexing.km_util as util
+import indexing.index as index
+
+uri = "bolt://neo4j:7687"
+user = "neo4j"
+password = "mypass"
+
+rel_pvalue_cutoff = 1e-5
+min_pmids_for_rel = 3
+
+class KnowledgeGraph:
+    def __init__(self):
+        self.query_cache = dict()
+
+        try:
+            self.graph = Graph(uri, auth=(user, password))
+        except:
+            self.graph = None
+            print('Could not find a neo4j knowledge graph database')
+
+    def query(self, a_term: str, b_term: str):
+        if not self.graph:
+            return [{'a_term': a_term, 'a_type': '', 'relationship': 'neo4j connection error', 'b_term': b_term, 'b_type': '', 'pmids': []}]
+
+        if index.logical_and in a_term or index.logical_and in b_term:
+            return [self._null_rel_response(a_term, b_term)]
+
+        a_term_stripped = _sanitize_txt(a_term)
+        b_term_stripped = _sanitize_txt(b_term)
+
+        if (a_term_stripped, b_term_stripped) in self.query_cache:
+            return self.query_cache[(a_term_stripped, b_term_stripped)]
+
+        # get nodes from the neo4j database
+        a_matches = self.graph.nodes.match(name=a_term_stripped).all()
+        b_matches = self.graph.nodes.match(name=b_term_stripped).all()
+
+        # get relationship(s) between a and b nodes
+        relation_matches = []
+
+        for a_node in a_matches:
+            for b_node in b_matches:
+                ab_rels = self.graph.match(nodes=(a_node, b_node)).all()
+                relation_matches.extend(ab_rels)
+
+                ba_rels = self.graph.match(nodes=(b_node, a_node)).all()
+                relation_matches.extend(ba_rels)
+
+        result = []
+
+        for relation in relation_matches:
+            # TODO: this is pretty hacky.
+            # need to find a better way to retrieve node/relation types as strings.
+            node1_name = relation.nodes[0]['name']
+            node1_type = str(relation.nodes[0].labels).strip(':')
+
+            node2_name = relation.nodes[1]['name']
+            node2_type = str(relation.nodes[1].labels).strip(':')
+
+            relationship = str(type(relation)).replace("'", "").replace(">", "").split('.')[2]
+
+            relation_json = {'a_term': node1_name, 'a_type': node1_type, 'relationship': relationship, 'b_term': node2_name, 'b_type': node2_type, 'pmids': relation['pmids'][:100]}
+            result.append(relation_json)
+
+        if not result:
+            result.append(self._null_rel_response(a_term, b_term))
+
+        self.query_cache[(a_term_stripped, b_term_stripped)] = result
+        return result
+
+    def populate(self, path_to_tsv_file: str):
+        self.graph.delete_all()
+
+        node_types = ['CHEMICAL', 'CONDITION', 'DRUG', 'GGP', 'BIO_PROCESS']
+        for node_type in node_types:
+            try:
+                self.graph.run("CREATE INDEX ON :" + node_type + "(name)")
+            except:
+                pass
+
+        # add nodes
+        nodes = {}
+
+        with open(path_to_tsv_file, 'r') as f:
+            for i, line in enumerate(f):
+                if i == 0:
+                    continue
+
+                spl = line.strip().split('\t')
+
+                node1_name = _sanitize_txt(spl[0])
+                node1_type = spl[1]
+                rel_txt = spl[2]
+                node2_name = _sanitize_txt(spl[3])
+                node2_type = spl[4]
+
+                pmids = spl[len(spl) - 1].strip('}').strip('{')
+                pmids = [int(x.strip()) for x in pmids.split(',')]
+
+                if len(pmids) < min_pmids_for_rel:
+                    continue
+
+                if node1_type not in nodes:
+                    nodes[node1_type] = set()
+                if node2_type not in nodes:
+                    nodes[node2_type] = set()
+
+                nodes[node1_type].add(node1_name)
+                nodes[node2_type].add(node2_name)
+
+        for node_type, nodes_list in nodes.items():
+            create_nodes(self.graph.auto(), [[x] for x in nodes_list], labels={node_type}, keys=["name"])
+        nodes.clear()
+
+        # add relations
+        rels = {}
+        with open(path_to_tsv_file, 'r') as f:
+            for n_rel, line in enumerate(f):
+                if n_rel == 0:
+                    continue
+
+                spl = line.strip().split('\t')
+
+                node1_name = _sanitize_txt(spl[0])
+                node1_type = spl[1]
+                rel_txt = spl[2]
+                node2_name = _sanitize_txt(spl[3])
+                node2_type = spl[4]
+
+                pmids = spl[len(spl) - 1].strip('}').strip('{')
+                pmids = [int(x.strip()) for x in pmids.split(',')]
+
+                if len(pmids) < min_pmids_for_rel:
+                    continue
+
+                category_txt = node1_type + ',' + rel_txt + ',' + node2_type
+
+                if category_txt not in rels:
+                    rels[category_txt] = []
+
+                rels[category_txt].append(((node1_name), {"pmids": pmids}, (node2_name)))
+
+                if (n_rel + 1) % 20000 == 0:
+                    self._post_rels(rels)
+                    rels.clear()
+
+        self._post_rels(rels)
+        rels.clear()
+
+    def _post_rels(self, rels: dict):
+        for rel, rel_nodes in rels.items():
+            n1_type = rel.split(',')[0]
+            r_type = rel.split(',')[1]
+            n2_type = rel.split(',')[2]
+
+            batch_size = 5000
+
+            for batch in _group_elements(rel_nodes, batch_size):
+                merge_relationships(self.graph.auto(), batch, r_type, start_node_key=(n1_type, "name"), end_node_key=(n2_type, "name"))
+
+    def _null_rel_response(self, a_term, b_term):
+        return {'a_term': a_term, 'a_type': '', 'relationship': '', 'b_term': b_term, 'b_type': '', 'pmids': []}
+
+def _sanitize_txt(term: str):
+    term = term.split(index.logical_or)[0]
+    return str.join(' ', util.get_tokens(term.lower().strip()))
+
+# batches "lst" into "chunk_size" sized elements
+def _group_elements(lst, chunk_size):
+    lst = iter(lst)
+    return iter(lambda: tuple(islice(lst, chunk_size)), ())
\ No newline at end of file
diff --git a/src/server/app.py b/src/server/app.py
index 3c984ff..60e288a 100644
--- a/src/server/app.py
+++ b/src/server/app.py
@@ -3,8 +3,11 @@ import rq_dashboard
 from redis import Redis
 from rq import Queue
+from rq.job import Job
+from rq.command import send_stop_job_command
+from rq.exceptions import InvalidJobOperation
 from flask_restful import Api
-from workers.work import km_work, km_work_all_vs_all, triple_miner_work, update_index_work, clear_mongo_cache
+from workers.work import km_work, km_work_all_vs_all, update_index_work, clear_mongo_cache
 import logging
 from flask_bcrypt import Bcrypt
@@ -114,15 +117,6 @@ def _post_skim_job():
 def _get_skim_job():
     return _get_generic(request)
 
-## ******** TripleMiner Post/Get ********
-@_app.route('/tripleminer/api/jobs/', methods=['POST'])
-def _post_tripleminer_job():
-    return _post_generic(triple_miner_work, request)
-
-@_app.route('/tripleminer/api/jobs/', methods=['GET'])
-def _get_tripleminer_job():
-    return _get_generic(request)
-
 ## ******** Update Index Post/Get ********
 @_app.route('/update_index/api/jobs/', methods=['POST'])
 def _post_update_index_job():
@@ -135,4 +129,34 @@ def _get_update_index_job():
 ## ******** Clear MongoDB Cache Post ********
 @_app.route('/clear_cache/api/jobs/', methods=['POST'])
 def _post_clear_cache_job():
-    return _post_generic(clear_mongo_cache, request)
\ No newline at end of file
+    return _post_generic(clear_mongo_cache, request)
+
+## ******** Cancel Job Post ********
+@_app.route('/cancel_job/api/jobs/', methods=['POST'])
+def _post_cancel_job():
+    json_data = request.get_json(force=True)
+    job_id = json_data['id']
+    response = jsonify(dict())
+
+    job = Job.fetch(job_id, connection=_r)
+
+    if job:
+        job.cancel()
+        try:
+            send_stop_job_command(connection=_r, job_id=job_id)
+
+            # TODO: can't seem to get an error message to display in
+            # rq-dashboard. probably should come back to this at some point.
+
+            # job.exc_info = 'Job was canceled by request'
+            # job.save()
+        except InvalidJobOperation:
+            # probably tried to cancel a job that wasn't in progress.
+            # just ignore the error message.
+            pass
+        status_code = 200
+    else:
+        status_code = 404
+
+    response.status_code = status_code
+    return response
\ No newline at end of file
diff --git a/src/tests/test_container_integration.py b/src/tests/test_container_integration.py
index b5372cf..6d1d416 100644
--- a/src/tests/test_container_integration.py
+++ b/src/tests/test_container_integration.py
@@ -61,10 +61,23 @@ def test_container_integration(data_dir, monkeypatch):
         # run query
         skim_url = api_url + skim_append
 
-        query = {'a_terms': ['cancer'], 'b_terms': ['coffee'], 'c_terms': ['water'], 'ab_fet_threshold': 1, 'top_n': 50}
-        result = _post_job(skim_url, query)['result']
+        query = {'a_terms': ['cancer'], 'b_terms': ['coffee'], 'c_terms': ['water'], 'ab_fet_threshold': 1, 'top_n': 50, 'query_knowledge_graph': 'True'}
+        job_info = _post_job(skim_url, query)
+
+        if job_info['status'] == 'failed':
+            if 'message' in job_info:
+                raise RuntimeError('the job failed because: ' + job_info['message'])
+            raise RuntimeError('the job failed without an annotated reason')
+
+        result = job_info['result']
         assert result[0]['total_count'] == 0
 
+        # TODO: this just tests that the neo4j database can be connected to.
+        # it does not test for actual querying of the knowledge graph. need to
+        # write that into a test.
+        #assert 'ab_relationship' in result[0]
+        #assert 'connection error' not in result[0]['ab_relationship']
+
         # build the index
         _post_job(api_url + update_index_append, {'n_files': 0, 'clear_cache': False})
@@ -72,6 +85,8 @@
         # cache to auto-clear.
         result = _post_job(skim_url, query)['result']
         assert result[0]['total_count'] > 4000
+        #assert 'ab_relationship' in result[0]
+        #assert 'connection error' not in result[0]['ab_relationship']
 
     except Exception as e:
         assert False, str(e)
diff --git a/src/tests/test_index.py b/src/tests/test_index.py
index f193e0d..70c422e 100644
--- a/src/tests/test_index.py
+++ b/src/tests/test_index.py
@@ -37,25 +37,25 @@ def test_index_abstract(tmp_path):
 
     the_index = Index(tmp_path)
 
-    query = the_index.query_index("the")
+    query = the_index._query_index("the")
     assert query == set([abs1.pmid, abs2.pmid])
 
-    query = the_index.query_index("are are are")
+    query = the_index._query_index("are are are")
     assert query == set([abs2.pmid])
 
-    query = the_index.query_index("are are are some")
+    query = the_index._query_index("are are are some")
    assert len(query) == 0
 
-    query = the_index.query_index("are are are quick")
+    query = the_index._query_index("are are are quick")
     assert len(query) == 0
 
-    query = the_index.query_index("brown")
+    query = the_index._query_index("brown")
     assert query == set([abs1.pmid])
 
     assert the_index._publication_years[abs1.pmid] == abs1.pub_year
     assert the_index._publication_years[abs2.pmid] == abs2.pub_year
 
-    query = the_index.query_index("test_test")
+    query = the_index._query_index("test_test")
     assert len(query) == 0
 
     assert the_index.n_articles() == 2
\ No newline at end of file
diff --git a/src/tests/test_index_building.py b/src/tests/test_index_building.py
index 8a96a9d..e6e5a78 100755
--- a/src/tests/test_index_building.py
+++ b/src/tests/test_index_building.py
@@ -76,11 +76,11 @@ def test_indexer(data_dir):
     # query the index
     index = Index(data_dir)
 
-    query = index.query_index("polysaccharide")
-    query = query | index.query_index("polysaccharides")
-    query = query | index.query_index("lipopolysaccharides")
-    query = query | index.query_index("lipopolysaccharide")
-    query = query | index.query_index("exopolysaccharide")
index._query_index("polysaccharide") + query = query | index._query_index("polysaccharides") + query = query | index._query_index("lipopolysaccharides") + query = query | index._query_index("lipopolysaccharide") + query = query | index._query_index("exopolysaccharide") assert len(query) == 37 diff --git a/src/tests/test_kinderminer.py b/src/tests/test_kinderminer.py index 424b894..604ed96 100755 --- a/src/tests/test_kinderminer.py +++ b/src/tests/test_kinderminer.py @@ -5,6 +5,7 @@ from indexing.index_builder import IndexBuilder from workers import kinderminer as km from indexing import km_util as util +import indexing.index as index from .test_index_building import data_dir def test_fisher_exact_test(): @@ -23,6 +24,19 @@ def test_fisher_exact_test(): sort_ratio = km.get_sort_ratio(table) assert sort_ratio == pytest.approx(15 / 59) +def test_text_sanitation(): + text = 'Testing123****.' + sanitized_text = index.sanitize_term(text) + assert sanitized_text == 'testing123' + + text = 'The quick brown fox / jumped over the lazy dog.' + sanitized_text = index.sanitize_term(text) + assert sanitized_text == 'jumped over the lazy dog/the quick brown fox' + + text = 'This&is&a&test.' + sanitized_text = index.sanitize_term(text) + assert sanitized_text == 'a&is&test&this' + def test_kinderminer(data_dir): index_dir = util.get_index_dir(data_dir) @@ -37,8 +51,8 @@ def test_kinderminer(data_dir): idx = Index(data_dir) # test index querying - lung_pmids = idx.query_index('lung') - tissue_pmids = idx.query_index('tissue') + lung_pmids = idx._query_index('lung') + tissue_pmids = idx._query_index('tissue') assert len(lung_pmids) == 109 assert len(tissue_pmids) == 234 @@ -49,6 +63,15 @@ def test_kinderminer(data_dir): [93, 3812] ] + # test censor year + # also tests that the article title is queried properly + km_result = km.kinderminer_search('patients undergoing pancreaticoduodenectomy', 'somatostatin', idx, censor_year=2020, return_pmids=True) + assert km_result['len(a_term_set)'] == 1 + assert km_result['n_articles'] == 6 + km_result = km.kinderminer_search('patients undergoing pancreaticoduodenectomy', 'somatostatin', idx, censor_year=2020, return_pmids=True) + assert km_result['len(a_term_set)'] == 1 + assert km_result['n_articles'] == 6 + # test KM query results km_result = km.kinderminer_search('significant', 'cancer', idx, return_pmids=True) assert km_result['pvalue'] == pytest.approx(0.486007, abs=1e-6) diff --git a/src/tests/test_work.py b/src/tests/test_work.py index 3ab2ed0..49f56e2 100644 --- a/src/tests/test_work.py +++ b/src/tests/test_work.py @@ -32,8 +32,10 @@ def test_skim_work(data_dir): assert result[0]['ab_count'] > 0 # test SKiM with A-B-C terms - result = work.km_work_all_vs_all({'a_terms': ['cancer'], 'b_terms': ['test'], 'c_terms': ['coffee'], 'top_n': 50, 'ab_fet_threshold': 0.8}) + result = work.km_work_all_vs_all({'a_terms': ['cancer'], 'b_terms': ['test'], 'c_terms': ['coffee'], 'top_n': 50, 'ab_fet_threshold': 0.8, 'query_knowledge_graph': True}) assert len(result) == 1 assert result[0]['c_term'] == 'coffee' assert result[0]['ab_count'] > 0 - assert result[0]['bc_count'] > 0 \ No newline at end of file + assert result[0]['bc_count'] > 0 + #assert result[0]['ab_relationship'] == 'neo4j connection error' + #assert result[0]['bc_relationship'] == 'neo4j connection error' \ No newline at end of file diff --git a/src/workers/kinderminer.py b/src/workers/kinderminer.py index 26e4d24..92a029b 100755 --- a/src/workers/kinderminer.py +++ b/src/workers/kinderminer.py @@ -3,8 +3,6 @@ 
 import math
 from indexing.index import Index
 
-logical_or = '/' # supports '/' to mean 'or'
-logical_and = '&' # supports '&' to mean 'and'
 fet_sided = 'greater'
 
 def get_contingency_table(a_term_set: set, b_term_set: set, total_n: int):
@@ -35,13 +33,13 @@ def kinderminer_search(a_term: str, b_term: str, idx: Index, censor_year = math.
     result = dict()
 
     # query the index (handling synonyms if appropriate)
-    a_term_set = _construct_abstract_set(a_term, idx)
-    b_term_set = _construct_abstract_set(b_term, idx)
+    a_term_set = idx.construct_abstract_set(a_term)
+    b_term_set = idx.construct_abstract_set(b_term)
 
     # censor by year if applicable
     if censor_year is not math.inf:
-        a_term_set = idx.censor_by_year(a_term_set, censor_year)
-        b_term_set = idx.censor_by_year(b_term_set, censor_year)
+        a_term_set = idx.censor_by_year(a_term_set, censor_year, a_term)
+        b_term_set = idx.censor_by_year(b_term_set, censor_year, b_term)
 
     # create contingency table
     table = get_contingency_table(a_term_set, b_term_set,
@@ -69,25 +67,10 @@
 
     if return_pmids:
         result['pmid_intersection'] = a_term_set & b_term_set
 
-    return result
-
-def _construct_abstract_set(term: str, idx: Index) -> set:
-    # TODO: support parenthesis for allowing OR and AND at the same time?
-    # e.g., "(cancer/carcinoma) & BRCA1"
-    if logical_or in term:
-        terms = term.split(logical_or)
-        pmid_set = set()
-        for synonym in terms:
-            pmid_set.update(idx.query_index(synonym))
-    elif logical_and in term:
-        terms = term.split(logical_and)
-        pmid_set = idx.query_index(terms[0])
-        for t in terms[1:]:
-            pmid_set.intersection_update(idx.query_index(t))
-    else:
-        pmid_set = idx.query_index(term)
+    if len(result['pmid_intersection']) > 1000:
+        result['pmid_intersection'] = set(list(result['pmid_intersection'])[:1000])
 
-    return pmid_set
+    return result
 
 def get_prediction_score(fet, ratio):
     max_score = 323.0
diff --git a/src/workers/work.py b/src/workers/work.py
index 873bf9c..ae7544f 100644
--- a/src/workers/work.py
+++ b/src/workers/work.py
@@ -9,12 +9,14 @@ import workers.kinderminer as km
 from indexing.index_builder import IndexBuilder
 import indexing.download_abstracts as downloader
+from knowledge_graph.knowledge_graph import KnowledgeGraph, rel_pvalue_cutoff
 
 _r = Redis(host='redis', port=6379)
 _q = Queue(connection=_r)
 
 def km_work(json: list):
     _initialize_mongo_caching()
+    knowledge_graph = connect_to_neo4j()
 
     return_val = []
 
@@ -25,15 +27,7 @@
         a_term = item['a_term']
         b_term = item['b_term']
 
-        if 'censor_year' in item:
-            censor_year = int(item['censor_year'])
-        else:
-            censor_year = math.inf
-
-        if censor_year is None or censor_year > 2100:
-            censor_year = math.inf
-        if censor_year < 0:
-            censor_year = 0
+        censor_year = _get_censor_year(item)
 
         if a_term is None or b_term is None:
             raise TypeError('Must supply a_term and b_term')
@@ -47,12 +41,22 @@
         if 'pmid_intersection' in res:
             res['pmid_intersection'] = str(res['pmid_intersection'])
 
+        # query knowledge graph
+        query_kg = False
+        if 'query_knowledge_graph' in item:
+            query_kg = bool(item['query_knowledge_graph'])
+
+        if query_kg and res['pvalue'] < rel_pvalue_cutoff:
+            rel = knowledge_graph.query(a_term, b_term)
+            res['relationship'] = rel
+
         return_val.append(res)
 
     return return_val
 
 def km_work_all_vs_all(json: dict):
     _initialize_mongo_caching()
+    knowledge_graph = connect_to_neo4j()
 
     return_val = []
     km_only = False
@@ -81,32 +85,40 @@
     else:
         ab_fet_threshold = math.inf
 
-    if 'censor_year' in json:
-        censor_year = json['censor_year']
-    else:
-        censor_year = math.inf
+    censor_year = _get_censor_year(json)
 
     return_pmids = False
     if 'return_pmids' in json:
         return_pmids = bool(json['return_pmids'])
 
+    query_kg = False
+    if 'query_knowledge_graph' in json:
+        query_kg = bool(json['query_knowledge_graph'])
+
     if type(top_n) is str:
         top_n = int(top_n)
-    if type(censor_year) is str:
-        censor_year = int(censor_year)
+
+    _update_job_status('progress', 0)
 
     for a_term_n, a_term in enumerate(a_terms):
         ab_results = []
-        for b_term in b_terms:
+        for b_term_n, b_term in enumerate(b_terms):
             res = km.kinderminer_search(a_term, b_term, li.the_index, censor_year, return_pmids)
 
             if res['pvalue'] <= ab_fet_threshold:
                 ab_results.append(res)
 
+            # report KM progress
+            if km_only:
+                numerator = a_term_n * len(b_terms) + b_term_n + 1
+                denom = len(a_terms) * len(b_terms)
+                progress = round((numerator / denom), 4)
+                _update_job_status('progress', min(progress, 0.9999))
+
         # sort by prediction score, descending
-        ab_results.sort(key=lambda res: 
-            km.get_prediction_score(res['pvalue'], res['sort_ratio']), 
+        ab_results.sort(key=lambda res:
+            km.get_prediction_score(res['pvalue'], res['sort_ratio']),
             reverse=True)
 
         ab_results = ab_results[:top_n]
@@ -121,7 +133,7 @@
                     'ab_pvalue': ab['pvalue'],
                     'ab_sort_ratio': ab['sort_ratio'],
                     'ab_pred_score': km.get_prediction_score(ab['pvalue'], ab['sort_ratio']),
-                    
+
                     'a_count': ab['len(a_term_set)'],
                     'b_count': ab['len(b_term_set)'],
                     'ab_count': ab['len(a_b_intersect)'],
@@ -131,6 +143,10 @@
                 if return_pmids:
                     abc_result['ab_pmid_intersection'] = str(ab['pmid_intersection'])
 
+                if query_kg and abc_result['ab_pvalue'] < rel_pvalue_cutoff:
+                    rel = knowledge_graph.query(abc_result['a_term'], abc_result['b_term'])
+                    abc_result['ab_relationship'] = rel
+
                 # add c-terms and b-c term KM info (SKiM)
                 if not km_only:
                     b_term = ab['b_term']
@@ -142,46 +158,24 @@
                     abc_result['bc_pred_score'] = km.get_prediction_score(bc['pvalue'], bc['sort_ratio'])
                     abc_result['c_count'] = bc['len(b_term_set)']
                     abc_result['bc_count'] = bc['len(a_b_intersect)']
-                    
+
                     if return_pmids:
                         abc_result['bc_pmid_intersection'] = str(bc['pmid_intersection'])
 
+                    if query_kg and abc_result['bc_pvalue'] < rel_pvalue_cutoff:
+                        rel = knowledge_graph.query(abc_result['b_term'], abc_result['c_term'])
+                        abc_result['bc_relationship'] = rel
+
                 return_val.append(abc_result)
 
-            # report percentage of C-terms complete
             if not km_only:
+                # report SKiM progress - percentage of C-terms complete
                 progress = round(((c_term_n + 1) / len(c_terms)), 4)
-            else:
-                # report percentage of A-B pairs complete
-                progress = round(((a_term_n + 1) / len(a_terms)), 4)
-
-            # report progress but never report 100% progress until the job is actually done
-            _update_job_status('progress', min(progress, 0.9999))
+                _update_job_status('progress', min(progress, 0.9999))
 
     _update_job_status('progress', 1.0000)
 
     return return_val
 
-def triple_miner_work(json: list):
-    _initialize_mongo_caching()
-
-    km_set = []
-
-    for query in json:
-        a_term = query['a_term']
-        b_term = query['b_term']
-        c_term = query['c_term']
-
-        km_query = dict()
-        km_query['a_term'] = a_term + '&&' + b_term
-        km_query['b_term'] = c_term
-
-        if 'censor_year' in query:
-            km_query['censor_year'] = query['censor_year']
-
-        km_set.append(km_query)
-
-    return km_work(km_set)
-
 def update_index_work(json: dict):
     indexing.index._connect_to_mongo()
     if 'n_files' in json:
@@ -245,6 +239,9 @@ def _initialize_mongo_caching():
     # such as 'fever' to save the current state of the index
     li.the_index._check_if_mongo_should_be_refreshed()
 
+def connect_to_neo4j():
+    return KnowledgeGraph()
+
 def _restart_workers(requeue_interrupted_jobs = True):
     print('restarting workers...')
     workers = Worker.all(_r)
@@ -282,6 +279,19 @@ def _update_job_status(key, value):
     if job is None:
         print('error: tried to update job status, but could not find job')
         return
-    
+
     job.meta[key] = value
-    job.save_meta()
\ No newline at end of file
+    job.save_meta()
+
+def _get_censor_year(item):
+    if 'censor_year' in item and item['censor_year'] is not None:
+        censor_year = int(item['censor_year'])
+    else:
+        censor_year = math.inf
+
+    if censor_year > 2100:
+        censor_year = math.inf
+    if censor_year < 0:
+        censor_year = 0
+
+    return censor_year
\ No newline at end of file