Code for the paper "CR-Walker: Conversational Recommender System with Tree-structured Graph Reasoning and Dialog Acts" (EMNLP 2021).
You can find our paper on arXiv.
Cite this paper:
@inproceedings{ma2021crwalker,
title={CR-Walker: Tree-Structured Graph Reasoning and Dialog Acts for Conversational Recommendation},
author={Ma, Wenchang and Takanobu, Ryuichi and Huang, Minlie},
booktitle={Proceedings of the 2021 Conference on Empirical Methods in Natural Language Processing},
pages={1839--1851},
year={2021},
organization={ACL}
}
Data: the raw data and our model checkpoints are available via a Google link. Directory layout:
CR-Walker
├─data
│ ├─gorecdial
│ │ └─raw
│ ├─gorecdial_gpt
│ ├─redial
│ │ └─raw
│ └─redial_gpt
└─saved
Download the files to [your home directory]/CR-Walker/ so that they match the layout above.
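After downloading, a quick sanity check can confirm that the expected folders are in place. The sketch below is a hypothetical helper, not part of the repository; it assumes the repository sits at ~/CR-Walker and only checks for the sub-directories shown in the tree.

```python
from pathlib import Path
from typing import List

# Expected sub-directories, transcribed from the directory tree in this README.
EXPECTED_DIRS = [
    "data/gorecdial/raw",
    "data/gorecdial_gpt",
    "data/redial/raw",
    "data/redial_gpt",
    "saved",
]


def missing_dirs(root: Path) -> List[str]:
    """Return the expected sub-directories that do not exist under root."""
    return [d for d in EXPECTED_DIRS if not (root / d).is_dir()]


if __name__ == "__main__":
    # Assumption: the repository lives at ~/CR-Walker, as the download step suggests.
    root = Path.home() / "CR-Walker"
    missing = missing_dirs(root)
    if missing:
        print("Missing under {}: {}".format(root, ", ".join(missing)))
    else:
        print("Directory layout looks complete.")
```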
Training:
For GoRecdial:
python train_gorecdial.py --option train --model_name <your_model_name> --pretrain
For Redial:
python train_redial.py --option train --model_name <your_model_name> --pretrain
We implemented an MIM (mutual information maximization) pretraining stage, similar to KGSF, to accelerate training. We also provide an option to add WordNet features: pass "--word_net" on the command line.
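If you launch runs from a script, the commands above can be assembled programmatically. This is a minimal sketch; `train_command` and `SCRIPTS` are hypothetical names, and it only uses the flags that appear in this README (`--option`, `--model_name`, `--pretrain`, `--word_net`).

```python
import shlex
from typing import List

# Training scripts named in this README, keyed by dataset (hypothetical mapping).
SCRIPTS = {"gorecdial": "train_gorecdial.py", "redial": "train_redial.py"}


def train_command(dataset: str, model_name: str,
                  pretrain: bool = True, word_net: bool = False) -> List[str]:
    """Build the argument list for a training run from the flags shown in this README."""
    cmd = ["python", SCRIPTS[dataset], "--option", "train", "--model_name", model_name]
    if pretrain:
        cmd.append("--pretrain")  # flag used in the README's training commands
    if word_net:
        cmd.append("--word_net")  # optional WordNet features
    return cmd


if __name__ == "__main__":
    # "my_redial_model" is a placeholder model name.
    cmd = train_command("redial", "my_redial_model", word_net=True)
    print(" ".join(shlex.quote(arg) for arg in cmd))
    # python train_redial.py --option train --model_name my_redial_model --pretrain --word_net
    # To actually launch it from the repository root:
    # import subprocess; subprocess.run(cmd, check=True)
```

Building an argument list rather than a single shell string lets `subprocess.run` start the script without any shell quoting issues in model names.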
Evaluation:
For GoRecdial:
python train_gorecdial.py --option test --model_name gorecdial_reason_128
For Redial:
python train_redial.py --option test --model_name redial_reason_128
You can directly evaluate the best checkpoints we provide for the two datasets. Results may differ slightly from those reported in the paper because we re-trained the models. Note that the reasoning width (the 'sample' argument in conf.py) is set to 1 during training for speed. Increasing it, together with tuning the selection threshold (the 'threshold' argument in conf.py), may yield better performance.
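To build intuition for how the reasoning width and the selection threshold interact, here is a hypothetical, simplified sketch of one selection step. It assumes each candidate entity has a score in [0, 1], keeps candidates whose score exceeds `threshold`, and then retains at most `width` of the highest-scoring ones. It is an illustration only, not the selection code in CR-Walker, and `select_children` is a made-up name.

```python
from typing import Dict, List


def select_children(scores: Dict[str, float], threshold: float, width: int) -> List[str]:
    """Hypothetical selection step: threshold first, then keep the top-`width` candidates."""
    passed = [(name, s) for name, s in scores.items() if s > threshold]
    # Highest scores first; ties keep their original order because sorted() is stable.
    ranked = sorted(passed, key=lambda item: item[1], reverse=True)
    return [name for name, _ in ranked[:width]]


if __name__ == "__main__":
    # Toy scores for candidate entities (made-up numbers).
    scores = {"comedy": 0.9, "Jim Carrey": 0.7, "horror": 0.4}
    print(select_children(scores, threshold=0.5, width=1))  # ['comedy']
    print(select_children(scores, threshold=0.5, width=2))  # ['comedy', 'Jim Carrey']
    print(select_children(scores, threshold=0.3, width=3))  # ['comedy', 'Jim Carrey', 'horror']
```

With `width=1` (the training setting mentioned above) only the single best candidate survives, whereas a larger width allows several entities per step, subject to the threshold.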
Generation evaluation:
For GoRecdial:
python train_gorecdial.py --option test_gen --model_name gorecdial_reason_128
For Redial:
python train_redial.py --option test_gen --model_name redial_reason_128
Similarly, you can tune the selection threshold, the reasoning width, and the maximum number of leaf nodes (the 'max_leaf' argument in conf.py) to control generation.
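Extending the same idea to several hops gives a rough picture of how a leaf limit can bound the size of a reasoning tree. The sketch below is hypothetical: `expand_tree` and `score_fn` are illustrative names, the scores come from a toy dictionary, and capping each hop at `max_leaf` nodes is an assumption about what the limit does, not a description of the repository's implementation.

```python
from typing import Callable, Dict, List, Tuple

Edge = Tuple[str, str]  # (parent, child)


def expand_tree(root: str, score_fn: Callable[[str], Dict[str, float]], depth: int,
                threshold: float, width: int, max_leaf: int) -> Tuple[List[Edge], List[str]]:
    """Hypothetical breadth-first expansion; returns the tree edges and its deepest layer."""
    edges = []  # type: List[Edge]
    frontier = [root]
    for _ in range(depth):
        next_frontier = []  # type: List[str]
        for node in frontier:
            # Same rule at every node: threshold, then at most `width` children.
            passed = [(c, s) for c, s in score_fn(node).items() if s > threshold]
            ranked = sorted(passed, key=lambda item: item[1], reverse=True)[:width]
            for child, _ in ranked:
                if len(next_frontier) >= max_leaf:
                    break  # assumed cap: at most max_leaf nodes per hop
                edges.append((node, child))
                next_frontier.append(child)
        if not next_frontier:
            break  # nothing passed the threshold; stop early
        frontier = next_frontier
    return edges, frontier


if __name__ == "__main__":
    # Toy knowledge-graph scores (made-up numbers).
    graph = {
        "user": {"comedy": 0.9, "action": 0.8, "horror": 0.2},
        "comedy": {"Movie A": 0.95, "Movie B": 0.6},
        "action": {"Movie C": 0.7},
    }
    edges, leaves = expand_tree("user", lambda n: graph.get(n, {}), depth=2,
                                threshold=0.5, width=2, max_leaf=2)
    print(leaves)  # ['Movie A', 'Movie B']
```

Raising `max_leaf` to 3 in this toy example would also admit "Movie C", so the three knobs together decide how many entities end up in the tree.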
Requirements:
python==3.6.10
pytorch==1.4.0
torch_geometric==1.6.0
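To check an environment against these pins, a small script like the following can help. It is a sketch, not part of the repository: it assumes PyTorch is imported as `torch` (its standard module name) and that both `torch` and `torch_geometric` expose `__version__`; local build suffixes such as `+cu100` are ignored when comparing.

```python
import sys
from typing import Dict, List, Optional

# Version pins listed in this README ("pytorch" is imported as `torch`).
PINS = {"python": "3.6.10", "torch": "1.4.0", "torch_geometric": "1.6.0"}


def installed_versions() -> Dict[str, Optional[str]]:
    """Collect the running Python version and, when importable, the library versions."""
    found = {"python": "{}.{}.{}".format(*sys.version_info[:3])}  # type: Dict[str, Optional[str]]
    for module in ("torch", "torch_geometric"):
        try:
            found[module] = __import__(module).__version__
        except ImportError:
            found[module] = None  # not installed
    return found


def mismatches(found: Dict[str, Optional[str]]) -> List[str]:
    """Describe every pin the environment does not satisfy, ignoring suffixes like '+cu100'."""
    problems = []
    for name, wanted in PINS.items():
        version = found.get(name)
        base = version.split("+")[0] if version else None
        if base != wanted:
            problems.append("{}: expected {}, found {}".format(name, wanted, version))
    return problems


if __name__ == "__main__":
    issues = mismatches(installed_versions())
    print("\n".join(issues) if issues else "Environment matches the pinned versions.")
```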