From 10e703216eaa0fed8de38fa0bb519e724722b84f Mon Sep 17 00:00:00 2001 From: chenyushuo <297086016@qq.com> Date: Wed, 4 Aug 2021 18:19:51 +0800 Subject: [PATCH 1/2] DOC: update documents. --- docs/source/index.rst | 2 +- docs/source/recbole/recbole.config.rst | 8 --- ...ole.data.dataloader.context_dataloader.rst | 4 -- ...cbole.data.dataloader.neg_sample_mixin.rst | 4 -- .../recbole/recbole.data.dataloader.rst | 3 - ....data.dataloader.sequential_dataloader.rst | 4 -- docs/source/recbole/recbole.data.dataset.rst | 1 - .../recbole.evaluator.abstract_evaluator.rst | 4 -- ....rst => recbole.evaluator.base_metric.rst} | 2 +- ...ng.rst => recbole.evaluator.collector.rst} | 2 +- ...rs.rst => recbole.evaluator.evaluator.rst} | 2 +- ...set.rst => recbole.evaluator.register.rst} | 2 +- docs/source/recbole/recbole.evaluator.rst | 7 +- docs/source/user_guide/model_intro.rst | 2 + recbole/evaluator/metrics.py | 68 ++++++++++++++----- 15 files changed, 62 insertions(+), 53 deletions(-) delete mode 100644 docs/source/recbole/recbole.config.rst delete mode 100644 docs/source/recbole/recbole.data.dataloader.context_dataloader.rst delete mode 100644 docs/source/recbole/recbole.data.dataloader.neg_sample_mixin.rst delete mode 100644 docs/source/recbole/recbole.data.dataloader.sequential_dataloader.rst delete mode 100644 docs/source/recbole/recbole.evaluator.abstract_evaluator.rst rename docs/source/recbole/{recbole.evaluator.proxy_evaluator.rst => recbole.evaluator.base_metric.rst} (51%) rename docs/source/recbole/{recbole.config.eval_setting.rst => recbole.evaluator.collector.rst} (55%) rename docs/source/recbole/{recbole.evaluator.evaluators.rst => recbole.evaluator.evaluator.rst} (54%) rename docs/source/recbole/{recbole.data.dataset.social_dataset.rst => recbole.evaluator.register.rst} (50%) diff --git a/docs/source/index.rst b/docs/source/index.rst index 96c20bbe7..4a41dfe09 100644 --- a/docs/source/index.rst +++ b/docs/source/index.rst @@ -38,7 +38,7 @@ RecBole v0.2.0 :maxdepth: 1 :caption: API REFERENCE: - recbole/recbole.config + recbole/recbole.config.configurator recbole/recbole.data recbole/recbole.evaluator recbole/recbole.model diff --git a/docs/source/recbole/recbole.config.rst b/docs/source/recbole/recbole.config.rst deleted file mode 100644 index 1b58676fc..000000000 --- a/docs/source/recbole/recbole.config.rst +++ /dev/null @@ -1,8 +0,0 @@ -recbole.config -====================== - -.. toctree:: - :maxdepth: 4 - - recbole.config.configurator - recbole.config.eval_setting diff --git a/docs/source/recbole/recbole.data.dataloader.context_dataloader.rst b/docs/source/recbole/recbole.data.dataloader.context_dataloader.rst deleted file mode 100644 index f46d5ee0c..000000000 --- a/docs/source/recbole/recbole.data.dataloader.context_dataloader.rst +++ /dev/null @@ -1,4 +0,0 @@ -.. automodule:: recbole.data.dataloader.context_dataloader - :members: - :undoc-members: - :show-inheritance: diff --git a/docs/source/recbole/recbole.data.dataloader.neg_sample_mixin.rst b/docs/source/recbole/recbole.data.dataloader.neg_sample_mixin.rst deleted file mode 100644 index 67fdd0e93..000000000 --- a/docs/source/recbole/recbole.data.dataloader.neg_sample_mixin.rst +++ /dev/null @@ -1,4 +0,0 @@ -.. automodule:: recbole.data.dataloader.neg_sample_mixin - :members: - :undoc-members: - :show-inheritance: diff --git a/docs/source/recbole/recbole.data.dataloader.rst b/docs/source/recbole/recbole.data.dataloader.rst index 1dc37ef3d..5db4daf65 100644 --- a/docs/source/recbole/recbole.data.dataloader.rst +++ b/docs/source/recbole/recbole.data.dataloader.rst @@ -5,9 +5,6 @@ recbole.data.dataloader :maxdepth: 4 recbole.data.dataloader.abstract_dataloader - recbole.data.dataloader.context_dataloader recbole.data.dataloader.general_dataloader recbole.data.dataloader.knowledge_dataloader - recbole.data.dataloader.neg_sample_mixin - recbole.data.dataloader.sequential_dataloader recbole.data.dataloader.user_dataloader diff --git a/docs/source/recbole/recbole.data.dataloader.sequential_dataloader.rst b/docs/source/recbole/recbole.data.dataloader.sequential_dataloader.rst deleted file mode 100644 index 94ab388e9..000000000 --- a/docs/source/recbole/recbole.data.dataloader.sequential_dataloader.rst +++ /dev/null @@ -1,4 +0,0 @@ -.. automodule:: recbole.data.dataloader.sequential_dataloader - :members: - :undoc-members: - :show-inheritance: diff --git a/docs/source/recbole/recbole.data.dataset.rst b/docs/source/recbole/recbole.data.dataset.rst index 64d432fae..58a17e800 100644 --- a/docs/source/recbole/recbole.data.dataset.rst +++ b/docs/source/recbole/recbole.data.dataset.rst @@ -9,4 +9,3 @@ recbole.data.dataset recbole.data.dataset.kg_dataset recbole.data.dataset.kg_seq_dataset recbole.data.dataset.sequential_dataset - recbole.data.dataset.social_dataset diff --git a/docs/source/recbole/recbole.evaluator.abstract_evaluator.rst b/docs/source/recbole/recbole.evaluator.abstract_evaluator.rst deleted file mode 100644 index 66de59e49..000000000 --- a/docs/source/recbole/recbole.evaluator.abstract_evaluator.rst +++ /dev/null @@ -1,4 +0,0 @@ -.. automodule:: recbole.evaluator.abstract_evaluator - :members: - :undoc-members: - :show-inheritance: diff --git a/docs/source/recbole/recbole.evaluator.proxy_evaluator.rst b/docs/source/recbole/recbole.evaluator.base_metric.rst similarity index 51% rename from docs/source/recbole/recbole.evaluator.proxy_evaluator.rst rename to docs/source/recbole/recbole.evaluator.base_metric.rst index c4689d43a..323aacf1b 100644 --- a/docs/source/recbole/recbole.evaluator.proxy_evaluator.rst +++ b/docs/source/recbole/recbole.evaluator.base_metric.rst @@ -1,4 +1,4 @@ -.. automodule:: recbole.evaluator.proxy_evaluator +.. automodule:: recbole.evaluator.base_metric :members: :undoc-members: :show-inheritance: diff --git a/docs/source/recbole/recbole.config.eval_setting.rst b/docs/source/recbole/recbole.evaluator.collector.rst similarity index 55% rename from docs/source/recbole/recbole.config.eval_setting.rst rename to docs/source/recbole/recbole.evaluator.collector.rst index cd78d6dca..b67ab72fc 100644 --- a/docs/source/recbole/recbole.config.eval_setting.rst +++ b/docs/source/recbole/recbole.evaluator.collector.rst @@ -1,4 +1,4 @@ -.. automodule:: recbole.config.eval_setting +.. automodule:: recbole.evaluator.collector :members: :undoc-members: :show-inheritance: diff --git a/docs/source/recbole/recbole.evaluator.evaluators.rst b/docs/source/recbole/recbole.evaluator.evaluator.rst similarity index 54% rename from docs/source/recbole/recbole.evaluator.evaluators.rst rename to docs/source/recbole/recbole.evaluator.evaluator.rst index 8aeca9804..39c00e545 100644 --- a/docs/source/recbole/recbole.evaluator.evaluators.rst +++ b/docs/source/recbole/recbole.evaluator.evaluator.rst @@ -1,4 +1,4 @@ -.. automodule:: recbole.evaluator.evaluators +.. automodule:: recbole.evaluator.evaluator :members: :undoc-members: :show-inheritance: diff --git a/docs/source/recbole/recbole.data.dataset.social_dataset.rst b/docs/source/recbole/recbole.evaluator.register.rst similarity index 50% rename from docs/source/recbole/recbole.data.dataset.social_dataset.rst rename to docs/source/recbole/recbole.evaluator.register.rst index 47db01c07..6c642d099 100644 --- a/docs/source/recbole/recbole.data.dataset.social_dataset.rst +++ b/docs/source/recbole/recbole.evaluator.register.rst @@ -1,4 +1,4 @@ -.. automodule:: recbole.data.dataset.social_dataset +.. automodule:: recbole.evaluator.register :members: :undoc-members: :show-inheritance: diff --git a/docs/source/recbole/recbole.evaluator.rst b/docs/source/recbole/recbole.evaluator.rst index 0f3c99793..49f09b114 100644 --- a/docs/source/recbole/recbole.evaluator.rst +++ b/docs/source/recbole/recbole.evaluator.rst @@ -4,8 +4,9 @@ recbole.evaluator .. toctree:: :maxdepth: 4 - recbole.evaluator.abstract_evaluator - recbole.evaluator.evaluators + recbole.evaluator.base_metric + recbole.evaluator.collector + recbole.evaluator.evaluator recbole.evaluator.metrics - recbole.evaluator.proxy_evaluator + recbole.evaluator.register recbole.evaluator.utils diff --git a/docs/source/user_guide/model_intro.rst b/docs/source/user_guide/model_intro.rst index 8014a7326..33c353061 100644 --- a/docs/source/user_guide/model_intro.rst +++ b/docs/source/user_guide/model_intro.rst @@ -9,6 +9,7 @@ General Recommendation In the class of general recommendation, the interaction of users and items(.inter file) is the only data that can be used by model. Usually, the models are trained on implicit feedback data and evaluated under the task of top-n recommendation. All the collaborative filter(CF) based models are classified in this class. + .. toctree:: :maxdepth: 1 @@ -72,6 +73,7 @@ Sequential Recommendation The task of sequential recommendation(next-item recommendation) is the same as general recommendation which sort a list of items according to preference. While the history interactions are organized in sequences and the model tend to characterize the sequential data. The models of session-based recommendation are also included in this class. + .. toctree:: :maxdepth: 1 diff --git a/recbole/evaluator/metrics.py b/recbole/evaluator/metrics.py index ca1a055aa..986b2cfd1 100644 --- a/recbole/evaluator/metrics.py +++ b/recbole/evaluator/metrics.py @@ -29,9 +29,12 @@ class Hit(TopkMetric): r"""Hit_ (also known as hit ratio at :math:`N`) is a way of calculating how many 'hits' you have in an n-sized list of ranked items. + .. _Hit: https://medium.com/@rishabhbhatia315/recommendation-system-evaluation-metrics-3f6739288870 + .. math:: \mathrm {HR@K} =\frac{Number \space of \space Hits @K}{|GT|} + :math:`HR` is the number of users with a positive sample in the recommendation list. :math:`GT` is the total number of samples in the test set. """ @@ -52,13 +55,16 @@ def metric_info(self, pos_index): class MRR(TopkMetric): r"""The MRR_ (also known as mean reciprocal rank) is a statistic measure for evaluating any process - that produces a list of possible responses to a sample of queries, ordered by probability of correctness. - .. _MRR: https://en.wikipedia.org/wiki/Mean_reciprocal_rank - .. math:: - \mathrm {MRR} = \frac{1}{|{U}|} \sum_{i=1}^{|{U}|} \frac{1}{rank_i} - :math:`U` is the number of users, :math:`rank_i` is the rank of the first item in the recommendation list - in the test set results for user :math:`i`. - """ + that produces a list of possible responses to a sample of queries, ordered by probability of correctness. + + .. _MRR: https://en.wikipedia.org/wiki/Mean_reciprocal_rank + + .. math:: + \mathrm {MRR} = \frac{1}{|{U}|} \sum_{i=1}^{|{U}|} \frac{1}{rank_i} + + :math:`U` is the number of users, :math:`rank_i` is the rank of the first item in the recommendation list + in the test set results for user :math:`i`. + """ def __init__(self, config): super().__init__(config) @@ -82,15 +88,18 @@ def metric_info(self, pos_index): class MAP(TopkMetric): r"""MAP_ (also known as Mean Average Precision) The MAP is meant to calculate Avg. Precision for the relevant items. + Note: In this case the normalization factor used is :math:`\frac{1}{\min (m,N)}`, which prevents your AP score from being unfairly suppressed when your number of recommendations couldn't possibly capture all the correct ones. + .. _MAP: http://sdsawtelle.github.io/blog/output/mean-average-precision-MAP-for-recommender-systems.html#MAP-for-Recommender-Algorithms + .. math:: - \begin{align*} - \mathrm{AP@N} &= \frac{1}{\mathrm{min}(m,N)}\sum_{k=1}^N P(k) \cdot rel(k) \\ - \mathrm{MAP@N}& = \frac{1}{|U|}\sum_{u=1}^{|U|}(\mathrm{AP@N})_u - \end{align*} + \begin{align*} + \mathrm{AP@N} &= \frac{1}{\mathrm{min}(m,N)}\sum_{k=1}^N P(k) \cdot rel(k) \\ + \mathrm{MAP@N}& = \frac{1}{|U|}\sum_{u=1}^{|U|}(\mathrm{AP@N})_u + \end{align*} """ def __init__(self, config): @@ -119,9 +128,12 @@ def metric_info(self, pos_index, pos_len): class Recall(TopkMetric): r"""Recall_ (also known as sensitivity) is the fraction of the total amount of relevant instances that were actually retrieved + .. _recall: https://en.wikipedia.org/wiki/Precision_and_recall#Recall + .. math:: \mathrm {Recall@K} = \frac{|Rel_u\cap Rec_u|}{Rel_u} + :math:`Rel_u` is the set of items relevant to user :math:`U`, :math:`Rec_u` is the top K items recommended to users. We obtain the result by calculating the average :math:`Recall@K` of each user. @@ -143,7 +155,9 @@ def metric_info(self, pos_index, pos_len): class NDCG(TopkMetric): r"""NDCG_ (also known as normalized discounted cumulative gain) is a measure of ranking quality. Through normalizing the score, users and their recommendation list results in the whole test set can be evaluated. + .. _NDCG: https://en.wikipedia.org/wiki/Discounted_cumulative_gain#Normalized_DCG + .. math:: \begin{gather} \mathrm {DCG@K}=\sum_{i=1}^{K} \frac{2^{rel_i}-1}{\log_{2}{(i+1)}}\\ @@ -189,9 +203,12 @@ def metric_info(self, pos_index, pos_len): class Precision(TopkMetric): r"""Precision_ (also called positive predictive value) is the fraction of relevant instances among the retrieved instances + .. _precision: https://en.wikipedia.org/wiki/Precision_and_recall#Precision + .. math:: \mathrm {Precision@K} = \frac{|Rel_u \cap Rec_u|}{Rec_u} + :math:`Rel_u` is the set of items relevant to user :math:`U`, :math:`Rec_u` is the top K items recommended to users. We obtain the result by calculating the average :math:`Precision@K` of each user. @@ -215,16 +232,20 @@ def metric_info(self, pos_index): class GAUC(object): r"""GAUC_ (also known as Group Area Under Curve) is used to evaluate the two-class model, referring to the area under the ROC curve grouped by user. + .. _GAUC: https://dl.acm.org/doi/10.1145/3219819.3219823 + Note: It calculates the AUC score of each user, and finally obtains GAUC by weighting the user AUC. It is also not limited to k. Due to our padding for `scores_tensor` in `RankEvaluator` with `-np.inf`, the padding value will influence the ranks of origin items. Therefore, we use descending sort here and make an identity transformation to the formula of `AUC`, which is shown in `auc_` function. For readability, we didn't do simplification in the code. + .. math:: \mathrm {GAUC} = \frac {{{M} \times {(M+N+1)} - \frac{M \times (M+1)}{2}} - \sum\limits_{i=1}^M rank_{i}} {{M} \times {N}} + :math:`M` is the number of positive samples. :math:`N` is the number of negative samples. :math:`rank_i` is the descending rank of the ith positive sample. @@ -278,14 +299,18 @@ def metric_info(self, pos_rank_sum, user_len_list, pos_len_list): class AUC(LossMetric): r"""AUC_ (also known as Area Under Curve) is used to evaluate the two-class model, referring to the area under the ROC curve + .. _AUC: https://en.wikipedia.org/wiki/Receiver_operating_characteristic#Area_under_the_curve + Note: This metric does not calculate group-based AUC which considers the AUC scores averaged across users. It is also not limited to k. Instead, it calculates the scores on the entire prediction results regardless the users. + .. math:: \mathrm {AUC} = \frac{\sum\limits_{i=1}^M rank_{i} - \frac {{M} \times {(M+1)}}{2}} {{{M} \times {N}}} + :math:`M` is the number of positive samples. :math:`N` is the number of negative samples. :math:`rank_i` is the ascending rank of the ith positive sample. @@ -329,9 +354,12 @@ def metric_info(self, preds, trues): class MAE(LossMetric): r"""`Mean absolute error regression loss`__ + .. __: https://en.wikipedia.org/wiki/Mean_absolute_error + .. math:: \mathrm{MAE}=\frac{1}{|{T}|} \sum_{(u, i) \in {T}}\left|\hat{r}_{u i}-r_{u i}\right| + :math:`T` is the test set, :math:`\hat{r}_{u i}` is the score predicted by the model, and :math:`r_{u i}` the actual score of the test set. """ @@ -348,12 +376,15 @@ def metric_info(self, preds, trues): class RMSE(LossMetric): r"""`Mean std error regression loss`__ - .. __: https://en.wikipedia.org/wiki/Root-mean-square_deviation - .. math:: - \mathrm{RMSE} = \sqrt{\frac{1}{|{T}|} \sum_{(u, i) \in {T}}(\hat{r}_{u i}-r_{u i})^{2}} - :math:`T` is the test set, :math:`\hat{r}_{u i}` is the score predicted by the model, - and :math:`r_{u i}` the actual score of the test set. - """ + + .. __: https://en.wikipedia.org/wiki/Root-mean-square_deviation + + .. math:: + \mathrm{RMSE} = \sqrt{\frac{1}{|{T}|} \sum_{(u, i) \in {T}}(\hat{r}_{u i}-r_{u i})^{2}} + + :math:`T` is the test set, :math:`\hat{r}_{u i}` is the score predicted by the model, + and :math:`r_{u i}` the actual score of the test set. + """ def __init__(self, config): super().__init__(config) @@ -367,9 +398,12 @@ def metric_info(self, preds, trues): class LogLoss(LossMetric): r"""`Log loss`__, aka logistic loss or cross-entropy loss + .. __: http://wiki.fast.ai/index.php/Log_Loss + .. math:: -\log {P(y_t|y_p)} = -(({y_t}\ \log{y_p}) + {(1-y_t)}\ \log{(1 - y_p)}) + For a single sample, :math:`y_t` is true label in :math:`\{0,1\}`. :math:`y_p` is the estimated probability that :math:`y_t = 1`. """ From bbece8fe14ed91609abe997206bd2abf0e9d88a7 Mon Sep 17 00:00:00 2001 From: chenyushuo <297086016@qq.com> Date: Thu, 5 Aug 2021 00:26:29 +0800 Subject: [PATCH 2/2] FORMAT: code format --- tests/model/test_model_auto.py | 63 +++++++++++++++++----------------- 1 file changed, 31 insertions(+), 32 deletions(-) diff --git a/tests/model/test_model_auto.py b/tests/model/test_model_auto.py index 252fc392f..17360cd6f 100644 --- a/tests/model/test_model_auto.py +++ b/tests/model/test_model_auto.py @@ -149,38 +149,38 @@ def test_ease(self): def test_MultiDAE(self): config_dict = { 'model': 'MultiDAE', - 'neg_sampling': None + 'neg_sampling': None } quick_test(config_dict) def test_MultiVAE(self): config_dict = { 'model': 'MultiVAE', - 'neg_sampling': None + 'neg_sampling': None } quick_test(config_dict) def test_enmf(self): config_dict = { 'model': 'ENMF', - 'neg_sampling': None , + 'neg_sampling': None, } quick_test(config_dict) - + def test_MacridVAE(self): config_dict = { 'model': 'MacridVAE', - 'neg_sampling': None + 'neg_sampling': None } quick_test(config_dict) def test_CDAE(self): config_dict = { 'model': 'CDAE', - 'neg_sampling': None + 'neg_sampling': None } quick_test(config_dict) - + def test_NNCF(self): config_dict = { 'model': 'NNCF', @@ -190,7 +190,7 @@ def test_NNCF(self): def test_RecVAE(self): config_dict = { 'model': 'RecVAE', - 'neg_sampling': None + 'neg_sampling': None } quick_test(config_dict) @@ -378,7 +378,7 @@ def test_fpmc(self): def test_gru4rec(self): config_dict = { 'model': 'GRU4Rec', - 'neg_sampling': None + 'neg_sampling': None } quick_test(config_dict) @@ -392,7 +392,7 @@ def test_gru4rec_with_BPR_loss(self): def test_narm(self): config_dict = { 'model': 'NARM', - 'neg_sampling': None + 'neg_sampling': None } quick_test(config_dict) @@ -406,7 +406,7 @@ def test_narm_with_BPR_loss(self): def test_stamp(self): config_dict = { 'model': 'STAMP', - 'neg_sampling': None + 'neg_sampling': None } quick_test(config_dict) @@ -422,7 +422,7 @@ def test_caser(self): 'model': 'Caser', 'MAX_ITEM_LIST_LENGTH': 10, 'reproducibility': False, - 'neg_sampling': None + 'neg_sampling': None } quick_test(config_dict) @@ -439,7 +439,7 @@ def test_nextitnet(self): config_dict = { 'model': 'NextItNet', 'reproducibility': False, - 'neg_sampling': None + 'neg_sampling': None } quick_test(config_dict) @@ -460,7 +460,7 @@ def test_transrec(self): def test_sasrec(self): config_dict = { 'model': 'SASRec', - 'neg_sampling': None + 'neg_sampling': None } quick_test(config_dict) @@ -484,7 +484,7 @@ def test_srgnn(self): config_dict = { 'model': 'SRGNN', 'MAX_ITEM_LIST_LENGTH': 3, - 'neg_sampling': None + 'neg_sampling': None } quick_test(config_dict) @@ -500,7 +500,7 @@ def test_gcsan(self): config_dict = { 'model': 'GCSAN', 'MAX_ITEM_LIST_LENGTH': 3, - 'neg_sampling': None + 'neg_sampling': None } quick_test(config_dict) @@ -516,7 +516,7 @@ def test_gcsan_with_BPR_loss_and_tanh(self): def test_gru4recf(self): config_dict = { 'model': 'GRU4RecF', - 'neg_sampling': None + 'neg_sampling': None } quick_test(config_dict) @@ -524,7 +524,7 @@ def test_gru4recf_with_max_pooling(self): config_dict = { 'model': 'GRU4RecF', 'pooling_mode': 'max', - 'neg_sampling': None + 'neg_sampling': None } quick_test(config_dict) @@ -532,14 +532,14 @@ def test_gru4recf_with_sum_pooling(self): config_dict = { 'model': 'GRU4RecF', 'pooling_mode': 'sum', - 'neg_sampling': None + 'neg_sampling': None } quick_test(config_dict) def test_sasrecf(self): config_dict = { 'model': 'SASRecF', - 'neg_sampling': None + 'neg_sampling': None } quick_test(config_dict) @@ -547,7 +547,7 @@ def test_sasrecf_with_max_pooling(self): config_dict = { 'model': 'SASRecF', 'pooling_mode': 'max', - 'neg_sampling': None + 'neg_sampling': None } quick_test(config_dict) @@ -555,14 +555,14 @@ def test_sasrecf_with_sum_pooling(self): config_dict = { 'model': 'SASRecF', 'pooling_mode': 'sum', - 'neg_sampling': None + 'neg_sampling': None } quick_test(config_dict) def test_hrm(self): config_dict = { 'model': 'HRM', - 'neg_sampling': None + 'neg_sampling': None } quick_test(config_dict) @@ -576,7 +576,7 @@ def test_hrm_with_BPR_loss(self): def test_npe(self): config_dict = { 'model': 'NPE', - 'neg_sampling': None + 'neg_sampling': None } quick_test(config_dict) @@ -590,7 +590,7 @@ def test_npe_with_BPR_loss(self): def test_shan(self): config_dict = { 'model': 'SHAN', - 'neg_sampling': None + 'neg_sampling': None } quick_test(config_dict) @@ -611,14 +611,14 @@ def test_hgn_with_CE_loss(self): config_dict = { 'model': 'HGN', 'loss_type': 'CE', - 'neg_sampling': None + 'neg_sampling': None } quick_test(config_dict) def test_fossil(self): config_dict = { 'model': 'FOSSIL', - 'neg_sampling': None + 'neg_sampling': None } quick_test(config_dict) @@ -631,7 +631,7 @@ def test_repeat_net(self): def test_fdsa(self): config_dict = { 'model': 'FDSA', - 'neg_sampling': None + 'neg_sampling': None } quick_test(config_dict) @@ -639,7 +639,7 @@ def test_fdsa_with_max_pooling(self): config_dict = { 'model': 'FDSA', 'pooling_mode': 'max', - 'neg_sampling': None + 'neg_sampling': None } quick_test(config_dict) @@ -647,14 +647,14 @@ def test_fdsa_with_sum_pooling(self): config_dict = { 'model': 'FDSA', 'pooling_mode': 'sum', - 'neg_sampling': None + 'neg_sampling': None } quick_test(config_dict) def test_bert4rec(self): config_dict = { 'model': 'BERT4Rec', - 'neg_sampling': None + 'neg_sampling': None } objective_function(config_dict=config_dict, config_file_list=config_file_list, saved=False) @@ -808,6 +808,5 @@ def test_kgnnls_with_concat(self): quick_test(config_dict) - if __name__ == '__main__': unittest.main()