Skip to content

Commit

Permalink
Merge pull request #8001 from RasaHQ/reminders-lock-fix-2-X
Browse files Browse the repository at this point in the history
Added Locking Mechanism to Reminders Handler
  • Loading branch information
akelad authored Feb 25, 2021
2 parents 3f054c3 + 29f9b58 commit 3fb8f13
Show file tree
Hide file tree
Showing 6 changed files with 70 additions and 19 deletions.
1 change: 1 addition & 0 deletions changelog/8001.bugfix.md
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
Fixed bug where the conversation does not lock before handling a reminder event.
1 change: 1 addition & 0 deletions rasa/core/agent.py
Original file line number Diff line number Diff line change
Expand Up @@ -824,6 +824,7 @@ def create_processor(
self.policy_ensemble,
self.domain,
self.tracker_store,
self.lock_store,
self.nlg,
action_endpoint=self.action_endpoint,
message_preprocessor=preprocessor,
Expand Down
37 changes: 21 additions & 16 deletions rasa/core/processor.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,6 +44,7 @@
UTTER_PREFIX,
)
from rasa.core.nlg import NaturalLanguageGenerator
from rasa.core.lock_store import LockStore
from rasa.core.policies.ensemble import PolicyEnsemble
import rasa.core.tracker_store
import rasa.shared.core.trackers
Expand All @@ -63,6 +64,7 @@ def __init__(
policy_ensemble: PolicyEnsemble,
domain: Domain,
tracker_store: rasa.core.tracker_store.TrackerStore,
lock_store: LockStore,
generator: NaturalLanguageGenerator,
action_endpoint: Optional[EndpointConfig] = None,
max_number_of_predictions: int = MAX_NUMBER_OF_PREDICTIONS,
Expand All @@ -74,6 +76,7 @@ def __init__(
self.policy_ensemble = policy_ensemble
self.domain = domain
self.tracker_store = tracker_store
self.lock_store = lock_store
self.max_number_of_predictions = max_number_of_predictions
self.message_preprocessor = message_preprocessor
self.on_circuit_break = on_circuit_break
Expand Down Expand Up @@ -418,24 +421,26 @@ async def handle_reminder(
output_channel: OutputChannel,
) -> None:
"""Handle a reminder that is triggered asynchronously."""

tracker = await self.fetch_tracker_and_update_session(sender_id, output_channel)

if (
reminder_event.kill_on_user_message
and self._has_message_after_reminder(tracker, reminder_event)
or not self._is_reminder_still_valid(tracker, reminder_event)
):
logger.debug(
f"Canceled reminder because it is outdated ({reminder_event})."
)
else:
intent = reminder_event.intent
entities = reminder_event.entities or {}
await self.trigger_external_user_uttered(
intent, entities, tracker, output_channel
async with self.lock_store.lock(sender_id):
tracker = await self.fetch_tracker_and_update_session(
sender_id, output_channel
)

if (
reminder_event.kill_on_user_message
and self._has_message_after_reminder(tracker, reminder_event)
or not self._is_reminder_still_valid(tracker, reminder_event)
):
logger.debug(
f"Canceled reminder because it is outdated ({reminder_event})."
)
else:
intent = reminder_event.intent
entities = reminder_event.entities or {}
await self.trigger_external_user_uttered(
intent, entities, tracker, output_channel
)

async def trigger_external_user_uttered(
self,
intent_name: Text,
Expand Down
2 changes: 2 additions & 0 deletions tests/core/actions/test_forms.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
from rasa.core.policies.policy import PolicyPrediction
from rasa.core.processor import MessageProcessor
from rasa.core.tracker_store import InMemoryTrackerStore
from rasa.core.lock_store import InMemoryLockStore
from rasa.core.actions import action
from rasa.core.actions.action import ActionExecutionRejection
from rasa.shared.core.constants import ACTION_LISTEN_NAME, REQUESTED_SLOT
Expand Down Expand Up @@ -144,6 +145,7 @@ async def test_switch_forms_with_same_slot(default_agent: Agent):
default_agent.policy_ensemble,
domain,
InMemoryTrackerStore(domain),
InMemoryLockStore(),
TemplatedNaturalLanguageGenerator(domain.templates),
)

Expand Down
3 changes: 3 additions & 0 deletions tests/core/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@
from rasa.core.processor import MessageProcessor
from rasa.shared.core.slots import Slot
from rasa.core.tracker_store import InMemoryTrackerStore, MongoTrackerStore
from rasa.core.lock_store import LockStore, InMemoryLockStore
from rasa.shared.core.trackers import DialogueStateTracker

DEFAULT_DOMAIN_PATH_WITH_SLOTS = "data/test_domains/default_with_slots.yml"
Expand Down Expand Up @@ -142,11 +143,13 @@ def default_channel() -> OutputChannel:
@pytest.fixture
async def default_processor(default_agent: Agent) -> MessageProcessor:
tracker_store = InMemoryTrackerStore(default_agent.domain)
lock_store = InMemoryLockStore()
return MessageProcessor(
default_agent.interpreter,
default_agent.policy_ensemble,
default_agent.domain,
tracker_store,
lock_store,
TemplatedNaturalLanguageGenerator(default_agent.domain.templates),
)

Expand Down
45 changes: 42 additions & 3 deletions tests/core/test_processor.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
import uuid
import json
from _pytest.monkeypatch import MonkeyPatch
from _pytest.logging import LogCaptureFixture
from aioresponses import aioresponses
from typing import Optional, Text, List, Callable, Type, Any, Tuple
from unittest.mock import patch, Mock
Expand Down Expand Up @@ -50,6 +51,7 @@
from rasa.core.processor import MessageProcessor
from rasa.shared.core.slots import Slot, AnySlot
from rasa.core.tracker_store import InMemoryTrackerStore
from rasa.core.lock_store import InMemoryLockStore
from rasa.shared.core.trackers import DialogueStateTracker
from rasa.shared.nlu.constants import INTENT_NAME_KEY
from rasa.utils.endpoints import EndpointConfig
Expand Down Expand Up @@ -134,7 +136,9 @@ async def test_http_parsing():

inter = RasaNLUHttpInterpreter(endpoint_config=endpoint)
try:
await MessageProcessor(inter, None, None, None, None).parse_message(message)
await MessageProcessor(inter, None, None, None, None, None).parse_message(
message
)
except KeyError:
pass # logger looks for intent and entities, so we except

Expand Down Expand Up @@ -204,6 +208,29 @@ async def test_reminder_scheduled(
)


async def test_reminder_lock(
default_channel: CollectingOutputChannel,
default_processor: MessageProcessor,
caplog: LogCaptureFixture,
):
caplog.clear()
with caplog.at_level(logging.DEBUG):
sender_id = uuid.uuid4().hex

reminder = ReminderScheduled("remind", datetime.datetime.now())
tracker = default_processor.tracker_store.get_or_create_tracker(sender_id)

tracker.update(UserUttered("test"))
tracker.update(ActionExecuted("action_schedule_reminder"))
tracker.update(reminder)

default_processor.tracker_store.save(tracker)

await default_processor.handle_reminder(reminder, sender_id, default_channel)

assert f"Deleted lock for conversation '{sender_id}'." in caplog.text


async def test_trigger_external_latest_input_channel(
default_channel: CollectingOutputChannel, default_processor: MessageProcessor
):
Expand Down Expand Up @@ -853,7 +880,12 @@ def predict_action_probabilities(
domain = Domain.empty()

processor = MessageProcessor(
test_interpreter, ensemble, domain, InMemoryTrackerStore(domain), Mock()
test_interpreter,
ensemble,
domain,
InMemoryTrackerStore(domain),
InMemoryLockStore(),
Mock(),
)

# This should not raise
Expand Down Expand Up @@ -883,7 +915,12 @@ def test_get_next_action_probabilities_pass_policy_predictions_without_interpret
domain = Domain.empty()

processor = MessageProcessor(
interpreter, ensemble, domain, InMemoryTrackerStore(domain), Mock()
interpreter,
ensemble,
domain,
InMemoryTrackerStore(domain),
InMemoryLockStore(),
Mock(),
)

with pytest.warns(DeprecationWarning):
Expand Down Expand Up @@ -1173,11 +1210,13 @@ def probabilities_using_best_policy(
return PolicyPrediction.for_action_name(domain, ACTION_LISTEN_NAME)

tracker_store = InMemoryTrackerStore(domain)
lock_store = InMemoryLockStore()
processor = MessageProcessor(
RegexInterpreter(),
ConstantEnsemble(),
domain,
tracker_store,
lock_store,
NaturalLanguageGenerator.create(None, domain),
)

Expand Down

0 comments on commit 3fb8f13

Please sign in to comment.