diff --git a/chatterbot/adapters/logic/base_match.py b/chatterbot/adapters/logic/base_match.py index 724845328..e86ae5fa9 100644 --- a/chatterbot/adapters/logic/base_match.py +++ b/chatterbot/adapters/logic/base_match.py @@ -61,7 +61,7 @@ def process(self, input_statement): len(response_list) ) ) - response = self.break_tie(response_list, self.tie_breaking_method) + response = self.break_tie(input_statement, response_list, self.tie_breaking_method) self.logger.info(u'Tie broken. Using "{}"'.format(response.text)) else: response = self.context.storage.get_random() diff --git a/chatterbot/adapters/logic/closest_match.py b/chatterbot/adapters/logic/closest_match.py index d67ab11c0..80be46b25 100644 --- a/chatterbot/adapters/logic/closest_match.py +++ b/chatterbot/adapters/logic/closest_match.py @@ -1,7 +1,8 @@ # -*- coding: utf-8 -*- -from .base_match import BaseMatchAdapter from fuzzywuzzy import fuzz +from .base_match import BaseMatchAdapter + class ClosestMatchAdapter(BaseMatchAdapter): """ @@ -36,7 +37,7 @@ def get(self, input_statement): # Find the closest matching known statement for statement in statement_list: - ratio = fuzz.ratio(input_statement.text, statement.text) + ratio = fuzz.ratio(input_statement.text.lower(), statement.text.lower()) if ratio > confidence: confidence = ratio @@ -46,4 +47,3 @@ def get(self, input_statement): confidence /= 100.0 return confidence, closest_match - diff --git a/chatterbot/adapters/logic/mixins.py b/chatterbot/adapters/logic/mixins.py index 3ebd00192..29176c6de 100644 --- a/chatterbot/adapters/logic/mixins.py +++ b/chatterbot/adapters/logic/mixins.py @@ -7,7 +7,7 @@ class TieBreaking(object): that multiple responses are generated within a logic adapter. """ - def break_tie(self, statement_list, method): + def break_tie(self, input_statement, statement_list, method): METHODS = { "first_response": self.get_first_response, @@ -16,10 +16,10 @@ def break_tie(self, statement_list, method): } if method in METHODS: - return METHODS[method](statement_list) + return METHODS[method](input_statement, statement_list) # Default to the first method if an invalid method is passed in - return METHODS["first_response"](statement_list) + return METHODS["first_response"](input_statement, statement_list) def get_most_frequent_response(self, input_statement, response_list): """ @@ -42,7 +42,7 @@ def get_most_frequent_response(self, input_statement, response_list): # Choose the most commonly occuring matching response return matching_response - def get_first_response(self, response_list): + def get_first_response(self, input_statement, response_list): """ Return the first statement in the response list. """ @@ -52,7 +52,7 @@ def get_first_response(self, response_list): )) return response_list[0] - def get_random_response(self, response_list): + def get_random_response(self, input_statement, response_list): """ Choose a random response from the selection. """ diff --git a/tests/logic_adapter_tests/test_mixins.py b/tests/logic_adapter_tests/test_mixins.py index 3d197564c..cad9d4869 100644 --- a/tests/logic_adapter_tests/test_mixins.py +++ b/tests/logic_adapter_tests/test_mixins.py @@ -1,11 +1,10 @@ from unittest import TestCase -from ..base_case import ChatBotTestCase + from chatterbot.adapters.logic.mixins import TieBreaking from chatterbot.conversation import Statement, Response class TieBreakingTests(TestCase): - def setUp(self): self.mixin = TieBreaking() @@ -31,7 +30,7 @@ def test_get_first_response(self): Statement("A quest.") ] - output = self.mixin.get_first_response(statement_list) + output = self.mixin.get_first_response(Statement("Hello"), statement_list) self.assertEqual("What... is your quest?", output) @@ -42,7 +41,6 @@ def test_get_random_response(self): Statement("A phone.") ] - output = self.mixin.get_random_response(statement_list) + output = self.mixin.get_random_response(Statement("Hello"), statement_list) self.assertTrue(output) -