diff --git a/MANIFEST.in b/MANIFEST.in index d50b202e4..5c5f176d2 100644 --- a/MANIFEST.in +++ b/MANIFEST.in @@ -4,6 +4,7 @@ include requirements.txt include database.db recursive-include tests * +recursive-include corpus *.json recursive-exclude * *.pyc recursive-exclude * *.py~ diff --git a/chatterbot/adapters/storage/jsondatabase.py b/chatterbot/adapters/storage/jsondatabase.py index 6988fc846..8a8ae0282 100644 --- a/chatterbot/adapters/storage/jsondatabase.py +++ b/chatterbot/adapters/storage/jsondatabase.py @@ -123,5 +123,6 @@ def drop(self): """ import os - os.remove(self.database.path) + if os.path.exists(self.database.path): + os.remove(self.database.path) diff --git a/chatterbot/corpus/__init__.py b/chatterbot/corpus/__init__.py index e69de29bb..2e3bc0d55 100644 --- a/chatterbot/corpus/__init__.py +++ b/chatterbot/corpus/__init__.py @@ -0,0 +1,2 @@ +from .corpus import Corpus + diff --git a/chatterbot/corpus/corpus.py b/chatterbot/corpus/corpus.py new file mode 100644 index 000000000..2b91a209a --- /dev/null +++ b/chatterbot/corpus/corpus.py @@ -0,0 +1,61 @@ +import os, json + + +class Corpus(object): + + def __init__(self): + current_directory = os.path.dirname(__file__) + self.data_directory = os.path.join(current_directory, 'data') + + def get_file_path(self, dotted_path): + """ + Reads a dotted file path and returns the file path. + """ + parts = dotted_path.split(".") + if parts[0] == 'chatterbot': + parts.pop(0) + parts[0] = self.data_directory + + corpus_path = os.path.join(*parts) + + if os.path.exists(corpus_path + ".json"): + corpus_path += ".json" + + return corpus_path + + def read_corpus(self, file_name): + """ + Read and return the data from a corpus json file. + """ + with open(file_name) as data_file: + data = json.load(data_file) + return data + + def load_corpus(self, dotted_path): + """ + Return the data contained within a specified corpus. + """ + + corpus_path = self.get_file_path(dotted_path) + + corpora = [] + + if os.path.isdir(corpus_path): + for dirname, dirnames, filenames in os.walk(corpus_path): + for datafile in filenames: + if datafile.endswith(".json"): + + corpus = self.read_corpus( + os.path.join(dirname, datafile) + ) + + for key in list(corpus.keys()): + corpora.append(corpus[key]) + else: + corpus = self.read_corpus(corpus_path) + + for key in list(corpus.keys()): + corpora.append(corpus[key]) + + return corpora + diff --git a/chatterbot/corpus/english/conversations.json b/chatterbot/corpus/data/english/conversations.json similarity index 100% rename from chatterbot/corpus/english/conversations.json rename to chatterbot/corpus/data/english/conversations.json diff --git a/chatterbot/corpus/english/greetings.json b/chatterbot/corpus/data/english/greetings.json similarity index 100% rename from chatterbot/corpus/english/greetings.json rename to chatterbot/corpus/data/english/greetings.json diff --git a/chatterbot/corpus/english/__init__.py b/chatterbot/corpus/english/__init__.py deleted file mode 100644 index d0d078cff..000000000 --- a/chatterbot/corpus/english/__init__.py +++ /dev/null @@ -1,21 +0,0 @@ -from chatterbot.corpus.utils import read_corpus -import os, sys - - -current_directory = os.path.dirname(__file__) - -_greetings = read_corpus(current_directory + '/greetings.json') -_conversations = read_corpus(current_directory + '/conversations.json') - -setattr( - sys.modules[__name__], - 'greetings', _greetings['greetings'] -) - -setattr( - sys.modules[__name__], - 'conversations', _conversations['conversations'] -) - -modules = [_greetings, _conversations] - diff --git a/chatterbot/corpus/utils.py b/chatterbot/corpus/utils.py deleted file mode 100644 index c212c2599..000000000 --- a/chatterbot/corpus/utils.py +++ /dev/null @@ -1,28 +0,0 @@ -import json - -def read_corpus(file_name): - """ - Read and return the data from a corpus json file. - """ - with open(file_name) as data_file: - data = json.load(data_file) - return data - -def load_corpus(corpus_path): - """ - Return the data contained within a specified corpus. - """ - from chatterbot.utils.module_loading import import_module - from types import ModuleType - - corpus = import_module(corpus_path) - - if isinstance(corpus, ModuleType): - corpora = [] - for module in corpus.modules: - for key in list(module.keys()): - corpora.append(module[key]) - return corpora - - return [corpus] - diff --git a/chatterbot/training.py b/chatterbot/training.py index 799ca3b02..dadb9051c 100644 --- a/chatterbot/training.py +++ b/chatterbot/training.py @@ -1,11 +1,12 @@ -from .corpus.utils import load_corpus from .conversation import Statement +from .corpus import Corpus class Trainer(object): def __init__(self, chatbot, **kwargs): self.chatbot = chatbot + self.corpus = Corpus() def train_from_list(self, conversation): @@ -30,7 +31,7 @@ def train_from_list(self, conversation): def train_from_corpora(self, corpora): for corpus in corpora: - corpus_data = load_corpus(corpus) + corpus_data = self.corpus.load_corpus(corpus) for data in corpus_data: for pair in data: self.train_from_list(pair) diff --git a/tests/base_case.py b/tests/base_case.py index f2e8a0322..9c65336ca 100644 --- a/tests/base_case.py +++ b/tests/base_case.py @@ -1,10 +1,46 @@ from unittest import TestCase from chatterbot import ChatBot +import os -class ChatBotTestCase(TestCase): +class UntrainedChatBotTestCase(TestCase): + + def setUp(self): + self.test_data_directory = 'test_data' + self.test_database_name = self.random_string() + ".db" + + if not os.path.exists(self.test_data_directory): + os.makedirs(self.test_data_directory) + + database_path = self.test_data_directory + '/' + self.test_database_name + + self.chatbot = ChatBot("Test Bot", database=database_path) + + def random_string(self, start=0, end=9000): + """ + Generate a string based on a random number. + """ + from random import randint + return str(randint(start, end)) + + def remove_test_data(self): + import shutil + + if os.path.exists(self.test_data_directory): + shutil.rmtree(self.test_data_directory) + + def tearDown(self): + """ + Remove the test database. + """ + self.chatbot.storage.drop() + self.remove_test_data() + + +class ChatBotTestCase(UntrainedChatBotTestCase): def setUp(self): + super(ChatBotTestCase, self).setUp() """ Set up a database for testing. """ @@ -27,43 +63,7 @@ def setUp(self): "Blue." ] - self.chatbot = ChatBot("Test Bot", database="test-database.db") - self.chatbot.train(data1) self.chatbot.train(data2) self.chatbot.train(data3) - def tearDown(self): - """ - Remove the test database. - """ - self.chatbot.storage.drop() - - -class UntrainedChatBotTestCase(TestCase): - """ - This is a test case for use when the - chat bot should not start with any - prior training. - """ - - def setUp(self): - """ - Set up a database for testing. - """ - test_db = self.random_string() + ".db" - self.chatbot = ChatBot("Test Bot", database=test_db) - - def random_string(self, start=0, end=9000): - """ - Generate a string based on a random number. - """ - from random import randint - return str(randint(start, end)) - - def tearDown(self): - """ - Remove the test database. - """ - self.chatbot.storage.drop() - diff --git a/tests/corpus_tests/test_corpus.py b/tests/corpus_tests/test_corpus.py index ba8a1b870..fd505ed9a 100644 --- a/tests/corpus_tests/test_corpus.py +++ b/tests/corpus_tests/test_corpus.py @@ -1,23 +1,40 @@ from unittest import TestCase -from chatterbot.corpus.utils import read_corpus, load_corpus +from chatterbot.corpus import Corpus +import os class CorpusUtilsTestCase(TestCase): + def setUp(self): + self.corpus = Corpus() + + def test_get_file_path(self): + """ + Test that a dotted path is properly converted to a file address. + """ + path = self.corpus.get_file_path("chatterbot.corpus.english") + self.assertIn( + os.path.join("chatterbot", "corpus", "data", "english"), + path + ) + def test_read_corpus(self): - #data = read_corpus("chatterbot/corpus/english/greetings/conversations.json") - # TODO - pass + corpus_path = os.path.join( + self.corpus.data_directory, + "english", "conversations.json" + ) + data = self.corpus.read_corpus(corpus_path) + self.assertIn("conversations", data) def test_load_corpus(self): - corpus = load_corpus("chatterbot.corpus.english.greetings") + corpus = self.corpus.load_corpus("chatterbot.corpus.english.greetings") self.assertEqual(len(corpus), 1) self.assertIn(["Hi", "Hello"], corpus[0]) def test_load_corpus_general(self): - corpus = load_corpus("chatterbot.corpus.english") + corpus = self.corpus.load_corpus("chatterbot.corpus.english") self.assertEqual(len(corpus), 2) - self.assertIn(["Hi", "Hello"], corpus[0]) + self.assertIn(["Hi", "Hello"], corpus[1])