From db40cf0c86d846ad07e99a4dd92aaa8015896230 Mon Sep 17 00:00:00 2001
From: Thanasis
Date: Mon, 20 May 2019 10:57:42 +0100
Subject: [PATCH 1/2] resolves #272

---
 scripts/documents_filter.py    | 26 +++++++++-----------------
 scripts/documents_weights.py   | 10 +---------
 scripts/pipeline.py            | 10 +++++-----
 scripts/utils/utils.py         | 16 ++++++++++++++++
 tests/test_documents_filter.py | 12 +++++++-----
 5 files changed, 38 insertions(+), 36 deletions(-)

diff --git a/scripts/documents_filter.py b/scripts/documents_filter.py
index 3f27ef5..63221c9 100644
--- a/scripts/documents_filter.py
+++ b/scripts/documents_filter.py
@@ -5,9 +5,9 @@
 
 class DocumentsFilter(object):
 
-    def __init__(self, df, docs_mask_dict):
+    def __init__(self, df, docs_mask_dict, cpc_dict):
         self.__doc_indices = set([])
-
+        self.__cpc_dict = cpc_dict
         if docs_mask_dict['columns'] is not None:
             self.__doc_indices = self.__filter_column(df, docs_mask_dict['columns'],
                                                       docs_mask_dict['filter_by'])
@@ -40,21 +40,13 @@ def doc_filters(self):
     def doc_indices(self):
         return self.__doc_indices
 
-    @staticmethod
-    def __filter_cpc(df, cpc):
-        cpc_index_list = []
-
-        df = df.reset_index(drop=True)
-        for index, row in tqdm(df.iterrows(), desc='Sifting documents for ' + cpc, unit='document',
-                               total=df.shape[0]):
-            cpc_list = row['classifications_cpc']
-            if not isinstance(cpc_list, list):
-                continue
-            for cpc_item in cpc_list:
-                if cpc_item.startswith(cpc):
-                    cpc_index_list.append(index)
-                    break
-        return cpc_index_list
+    def __filter_cpc(self, df, cpc):
+        indices_set = set()
+        for cpc_item in self.__cpc_dict:
+            if cpc_item.startswith(cpc):
+                indices_set |= self.__cpc_dict[cpc_item]
+
+        return list(indices_set)
 
     @staticmethod
     def __filter_column(df, filter_columns, filter_by):
diff --git a/scripts/documents_weights.py b/scripts/documents_weights.py
index af7b833..0918f0a 100644
--- a/scripts/documents_weights.py
+++ b/scripts/documents_weights.py
@@ -7,7 +7,7 @@
 
 class DocumentsWeights(object):
 
-    def __init__(self, df, time, citation_count_dict, date_header, text_lengths=None, norm_rows=False):
+    def __init__(self, df, time, citation_count_dict, date_header):
        self.__dataframe = df
         self.__date_header = date_header
         self.__weights = [1.0]*len(df)
@@ -22,11 +22,6 @@ def __init__(self, df, time, citation_count_dict, date_header, text_lengths=None
             self.__weights = [a * b for a, b in zip(self.__weights, cite_weights)]
             processed = True
 
-        # normalize rows to text length
-        if norm_rows:
-            self.__normalize_rows(text_lengths)
-            processed = True
-
         if processed:
             self.__weights = ut.normalize_array(self.__weights, return_list=True)
 
@@ -65,6 +60,3 @@ def __citation_weights(self, citation_count_dict):
             citation_count_for_doc_id_dict[doc_id] = citation_count_for_doc_id_dict_std
 
         return list(citation_count_for_doc_id_dict.values())
-
-        # for i, v in enumerate(list_of_citation_counts):
-        #     self.__tfidf_matrix.data[self.__tfidf_matrix.indptr[i]:self.__tfidf_matrix.indptr[i + 1]] *= v
\ No newline at end of file
diff --git a/scripts/pipeline.py b/scripts/pipeline.py
index 699767c..3c7492b 100644
--- a/scripts/pipeline.py
+++ b/scripts/pipeline.py
@@ -61,17 +61,18 @@ def __init__(self, data_filename, docs_mask_dict, pick_method='sum', ngram_range
             self.__text_lengths = self.__dataframe[text_header].map(len).tolist()
             self.__dataframe.drop(columns=[text_header], inplace=True)
 
+            self.__cpc_dict = utils.cpc_dict(self.__dataframe)
             tfidf_filename = path.join('outputs', 'tfidf', output_name + f'-tfidf-mdf-{max_df}.pkl.bz2')
             makedirs(path.dirname(tfidf_filename), exist_ok=True)
             with bz2.BZ2File(tfidf_filename, 'wb') as pickle_file:
                 pickle.dump(
-                    (self.__tfidf_obj, self.__dataframe, self.__text_lengths),
+                    (self.__tfidf_obj, self.__dataframe, self.__cpc_dict),
                     pickle_file, protocol=4, fix_imports=False)
 
         else:
             print(f'Reading document and TFIDF from pickle {pickled_tf_idf_file_name}')
-            self.__tfidf_obj, self.__dataframe, self.__text_lengths = read_pickle(pickled_tf_idf_file_name)
+            self.__tfidf_obj, self.__dataframe, self.__cpc_dict = read_pickle(pickled_tf_idf_file_name)
 
         if docs_mask_dict['date_header'] is None:
             print('Document dates not specified')
         else:
@@ -99,14 +100,13 @@ def __init__(self, data_filename, docs_mask_dict, pick_method='sum', ngram_range
         # then apply mask to tfidf object and df (i.e. remove rows with false or 0); do this in place
 
         # docs weights( column, dates subset + time, citations etc.)
-        doc_filters = DocumentsFilter(self.__dataframe, docs_mask_dict).doc_filters
+        doc_filters = DocumentsFilter(self.__dataframe, docs_mask_dict, self.__cpc_dict).doc_filters
 
         # todo: build up list of weight functions (left with single remaining arg etc via partialfunc)
         #  combine(list, tfidf) => multiplies weights together, then multiplies across tfidf
         #  (if empty, no side effect)
         doc_weights = DocumentsWeights(self.__dataframe, docs_mask_dict['time'], docs_mask_dict['cite'],
-                                       docs_mask_dict['date_header'], self.__text_lengths,
-                                       norm_rows=normalize_rows).weights
+                                       docs_mask_dict['date_header']).weights
         doc_weights = [a * b for a, b in zip(doc_filters, doc_weights)]
 
         # todo: this is another weight function...
diff --git a/scripts/utils/utils.py b/scripts/utils/utils.py
index ea8ed39..3b1af9d 100644
--- a/scripts/utils/utils.py
+++ b/scripts/utils/utils.py
@@ -10,6 +10,22 @@
 from pandas.api.types import is_string_dtype
 
 
+def cpc_dict(df):
+    cpc_list_2d = df['classifications_cpc']
+    cpc_dict={}
+    for idx, cpc_list in enumerate(cpc_list_2d):
+        if not isinstance(cpc_list, list):
+            continue
+        for cpc_item in cpc_list:
+            if cpc_item in cpc_dict:
+                cpc_set = cpc_dict[cpc_item]
+                cpc_set.add(idx)
+                cpc_dict[cpc_item] = cpc_set
+            else:
+                cpc_dict[cpc_item] = {idx}
+    return cpc_dict
+
+
 def l2normvec(csr_tfidf_mat):
     l2normvec=np.zeros((csr_tfidf_mat.shape[0],), dtype=np.float64)
 
diff --git a/tests/test_documents_filter.py b/tests/test_documents_filter.py
index 7a7799e..60f35cb 100644
--- a/tests/test_documents_filter.py
+++ b/tests/test_documents_filter.py
@@ -2,6 +2,7 @@
 import unittest
 from scripts import FilePaths
 from scripts.documents_filter import DocumentsFilter
+from scripts.utils import utils
 
 
 class TestDocumentsFilter(unittest.TestCase):
@@ -18,6 +19,7 @@ def setUp(self):
         # [self.args.year_from, year_to, self.args.month_from, month_to, self.args.date_header]
 
         df = pd.read_pickle(FilePaths.us_patents_random_100_pickle_name)
+        self.__cpc_dict = utils.cpc_dict(df)
         self.__df = df.reset_index()
 
     def test_filter_cpc_Y02(self):
@@ -25,14 +27,14 @@ def setUp(self):
         self.__docs_mask_dict['filter_by'] = 'union'
         self.__docs_mask_dict['cpc'] = 'Y02'
 
-        doc_ids = DocumentsFilter(self.__df, self.__docs_mask_dict).doc_indices
+        doc_ids = DocumentsFilter(self.__df, self.__docs_mask_dict, self.__cpc_dict).doc_indices
         self.assertListEqual(list(doc_ids), [95])
 
     def test_filter_cpc_A61(self):
         self.__docs_mask_dict['filter_by'] = 'union'
         self.__docs_mask_dict['cpc'] = 'A61'
 
-        doc_ids = DocumentsFilter(self.__df, self.__docs_mask_dict).doc_indices
+        doc_ids = DocumentsFilter(self.__df, self.__docs_mask_dict, self.__cpc_dict).doc_indices
         self.assertListEqual(list(doc_ids), [67, 69, 72, 74, 11, 13, 17, 81, 85, 90, 94, 43, 50, 57, 60, 63])
 
     def test_filter_dates(self):
@@ -41,7 +43,7 @@ def test_filter_dates(self):
             'to': pd.to_datetime('today')
         }
         self.__docs_mask_dict['date_header'] = 'publication_date'
-        doc_ids = DocumentsFilter(self.__df, self.__docs_mask_dict).doc_indices
+        doc_ids = DocumentsFilter(self.__df, self.__docs_mask_dict, self.__cpc_dict).doc_indices
         self.assertListEqual(list(doc_ids), [26, 27, 28, 29, 30, 31, 32, 33, 34, 35, 36, 37, 38, 39, 40, 41, 42,
                                              43, 44, 45, 46, 47, 48, 49, 50, 51, 52, 53, 54, 55, 56, 57, 58, 59,
                                              60, 61, 62, 63, 64, 65, 66, 67, 68, 69, 70, 71, 72, 73, 74, 75,
@@ -56,7 +58,7 @@ def test_filter_cpc_A61_union_dates(self):
         self.__docs_mask_dict['filter_by'] = 'union'
         self.__docs_mask_dict['cpc'] = 'A61'
 
-        doc_ids = DocumentsFilter(self.__df, self.__docs_mask_dict).doc_indices
+        doc_ids = DocumentsFilter(self.__df, self.__docs_mask_dict, self.__cpc_dict).doc_indices
         self.assertListEqual(list(doc_ids), [11, 13, 17, 26, 27, 28, 29, 30, 31, 32, 33, 34, 35, 36, 37, 38, 39,
                                              40, 41, 42, 43, 44, 45, 46, 47, 48, 49, 50,
@@ -74,7 +76,7 @@ def test_filter_cpc_A61_intersection_dates(self):
         self.__docs_mask_dict['filter_by'] = 'intersection'
         self.__docs_mask_dict['cpc'] = 'A61'
 
-        doc_ids = DocumentsFilter(self.__df, self.__docs_mask_dict).doc_indices
+        doc_ids = DocumentsFilter(self.__df, self.__docs_mask_dict, self.__cpc_dict).doc_indices
         self.assertListEqual(list(doc_ids), [67, 69, 72, 74, 43, 81, 50, 85, 57, 90, 60, 94, 63])
 

From 882141e19421b37103aa882d0bc825bf8f32fdcd Mon Sep 17 00:00:00 2001
From: Thanasis
Date: Mon, 20 May 2019 11:22:00 +0100
Subject: [PATCH 2/2] tests bug fixed

---
 tests/test_documents_weights.py | 7 -------
 tests/test_tfidf_mask.py        | 3 ++-
 2 files changed, 2 insertions(+), 8 deletions(-)

diff --git a/tests/test_documents_weights.py b/tests/test_documents_weights.py
index fd3b2ba..afbdf1f 100644
--- a/tests/test_documents_weights.py
+++ b/tests/test_documents_weights.py
@@ -124,13 +124,6 @@ def test_citations(self):
         actual = list(doc_weights)
         self.assertListEqual(expected, actual)
 
-    def test_normalized(self):
-        expected_weights = [1 / 5, 1 / 2, 1 / 4, 1]
-
-        actual_weights = DocumentsWeights(self.__df, False, None, 'publication_date', [5, 2, 4, 1],
-                                          norm_rows=True).weights
-
-        self.assertListEqual(expected_weights, actual_weights)
 
     def test_time(self):
         expected = [0.2,
diff --git a/tests/test_tfidf_mask.py b/tests/test_tfidf_mask.py
index e4c95f7..862ae58 100644
--- a/tests/test_tfidf_mask.py
+++ b/tests/test_tfidf_mask.py
@@ -48,8 +48,9 @@ def init_mask(self, cpc, min_n, uni_factor=0.8):
 
         self.__tfidf_obj = tfidf_from_text(self.__df['abstract'], ngram_range=(min_n, self.__max_n),
                                            max_document_frequency=self.__max_df, tokenizer=StemTokenizer())
+        cpc_dict=utils.cpc_dict(self.__df)
 
-        doc_filters = DocumentsFilter(self.__df, docs_mask_dict).doc_filters
+        doc_filters = DocumentsFilter(self.__df, docs_mask_dict, cpc_dict).doc_filters
         doc_weights = DocumentsWeights(self.__df, docs_mask_dict['time'], docs_mask_dict['cite'],
                                        docs_mask_dict['date_header']).weights
         doc_weights = [a * b for a, b in zip(doc_filters, doc_weights)]
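
Illustrative sketch (not part of either patch above): these changes replace DocumentsFilter's per-call scan over every row's CPC codes with a lookup that utils.cpc_dict builds once, mapping each CPC code to the set of row indices that carry it, so prefix filtering becomes a union over matching keys. The snippet below is a minimal, self-contained rendering of that idea, assuming a pandas DataFrame whose 'classifications_cpc' column holds a list of CPC codes per document; the helper name filter_by_cpc_prefix is hypothetical and not taken from the repository.

    # Illustrative sketch with hypothetical helper names, not repository code.
    import pandas as pd

    def cpc_dict(df):
        # Build the lookup once: CPC code -> set of row indices carrying that code.
        mapping = {}
        for idx, cpc_list in enumerate(df['classifications_cpc']):
            if not isinstance(cpc_list, list):
                continue  # rows with no classification data are skipped
            for cpc_item in cpc_list:
                mapping.setdefault(cpc_item, set()).add(idx)
        return mapping

    def filter_by_cpc_prefix(mapping, cpc_prefix):
        # Union the index sets of every code that starts with the requested prefix.
        indices = set()
        for cpc_item, idx_set in mapping.items():
            if cpc_item.startswith(cpc_prefix):
                indices |= idx_set
        return sorted(indices)

    df = pd.DataFrame({'classifications_cpc': [['Y02E 10/50'],
                                               ['A61K 9/00', 'Y02A 50/30'],
                                               None]})
    lookup = cpc_dict(df)
    print(filter_by_cpc_prefix(lookup, 'Y02'))  # [0, 1]
    print(filter_by_cpc_prefix(lookup, 'A61'))  # [1]

Built this way, the mapping is computed once per corpus (and can be pickled alongside the TF-IDF object, as the pipeline change does), while each CPC filter request only touches the dictionary keys instead of re-iterating the whole DataFrame.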