Skip to content

Commit

Permalink
Merge pull request #7895 from RasaHQ/reminders-lock-fix-1-10-x
Browse files Browse the repository at this point in the history
Added Locking Mechanism to Reminders Handler
  • Loading branch information
b-quachtran authored Feb 19, 2021
2 parents 8291705 + fa472a5 commit 54cc0f9
Show file tree
Hide file tree
Showing 5 changed files with 52 additions and 21 deletions.
1 change: 1 addition & 0 deletions changelog/7895.bugfix.rst
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
Fixed bug where the conversation does not lock before handling a reminder event.
1 change: 1 addition & 0 deletions rasa/core/agent.py
Original file line number Diff line number Diff line change
Expand Up @@ -888,6 +888,7 @@ def create_processor(
self.policy_ensemble,
self.domain,
self.tracker_store,
self.lock_store,
self.nlg,
action_endpoint=self.action_endpoint,
message_preprocessor=preprocessor,
Expand Down
40 changes: 20 additions & 20 deletions rasa/core/processor.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,6 +44,7 @@
RegexInterpreter,
)
from rasa.core.nlg import NaturalLanguageGenerator
from rasa.core.lock_store import LockStore
from rasa.core.policies.ensemble import PolicyEnsemble
from rasa.core.tracker_store import TrackerStore
from rasa.core.trackers import DialogueStateTracker, EventVerbosity
Expand All @@ -69,6 +70,7 @@ def __init__(
policy_ensemble: PolicyEnsemble,
domain: Domain,
tracker_store: TrackerStore,
lock_store: LockStore,
generator: NaturalLanguageGenerator,
action_endpoint: Optional[EndpointConfig] = None,
max_number_of_predictions: int = MAX_NUMBER_OF_PREDICTIONS,
Expand All @@ -80,6 +82,7 @@ def __init__(
self.policy_ensemble = policy_ensemble
self.domain = domain
self.tracker_store = tracker_store
self.lock_store = lock_store
self.max_number_of_predictions = max_number_of_predictions
self.message_preprocessor = message_preprocessor
self.on_circuit_break = on_circuit_break
Expand Down Expand Up @@ -348,28 +351,25 @@ async def handle_reminder(
) -> None:
"""Handle a reminder that is triggered asynchronously."""

tracker = await self.get_tracker_with_session_start(sender_id, output_channel)

if not tracker:
logger.warning(
f"Failed to retrieve tracker for conversation ID '{sender_id}'."
async with self.lock_store.lock(sender_id):
tracker = await self.get_tracker_with_session_start(
sender_id, output_channel
)
return None

if (
reminder_event.kill_on_user_message
and self._has_message_after_reminder(tracker, reminder_event)
or not self._is_reminder_still_valid(tracker, reminder_event)
):
logger.debug(
f"Canceled reminder because it is outdated ({reminder_event})."
)
else:
intent = reminder_event.intent
entities = reminder_event.entities or {}
await self.trigger_external_user_uttered(
intent, entities, tracker, output_channel
)
if (
reminder_event.kill_on_user_message
and self._has_message_after_reminder(tracker, reminder_event)
or not self._is_reminder_still_valid(tracker, reminder_event)
):
logger.debug(
f"Canceled reminder because it is outdated ({reminder_event})."
)
else:
intent = reminder_event.intent
entities = reminder_event.entities or {}
await self.trigger_external_user_uttered(
intent, entities, tracker, output_channel
)

async def trigger_external_user_uttered(
self,
Expand Down
3 changes: 3 additions & 0 deletions tests/core/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@
from rasa.core.processor import MessageProcessor
from rasa.core.slots import Slot
from rasa.core.tracker_store import InMemoryTrackerStore, MongoTrackerStore
from rasa.core.lock_store import LockStore, InMemoryLockStore
from rasa.core.trackers import DialogueStateTracker


Expand Down Expand Up @@ -148,11 +149,13 @@ def default_channel() -> OutputChannel:
@pytest.fixture
async def default_processor(default_agent: Agent) -> MessageProcessor:
tracker_store = InMemoryTrackerStore(default_agent.domain)
lock_store = InMemoryLockStore()
return MessageProcessor(
default_agent.interpreter,
default_agent.policy_ensemble,
default_agent.domain,
tracker_store,
lock_store,
TemplatedNaturalLanguageGenerator(default_agent.domain.templates),
)

Expand Down
28 changes: 27 additions & 1 deletion tests/core/test_processor.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
import uuid
import json
from _pytest.monkeypatch import MonkeyPatch
from _pytest.logging import LogCaptureFixture
from aioresponses import aioresponses
from typing import Optional, Text, List
from unittest.mock import patch
Expand Down Expand Up @@ -109,7 +110,7 @@ async def test_http_parsing():

inter = RasaNLUHttpInterpreter(endpoint_config=endpoint)
try:
await MessageProcessor(inter, None, None, None, None)._parse_message(
await MessageProcessor(inter, None, None, None, None, None)._parse_message(
message
)
except KeyError:
Expand Down Expand Up @@ -181,6 +182,31 @@ async def test_reminder_scheduled(
assert t.events[-1] == ActionExecuted("action_listen")


async def test_reminder_lock(
default_channel: CollectingOutputChannel,
default_processor: MessageProcessor,
caplog: LogCaptureFixture,
):
caplog.clear()
with caplog.at_level(logging.DEBUG):
sender_id = uuid.uuid4().hex

reminder = ReminderScheduled("remind", datetime.datetime.now())
tracker = default_processor.tracker_store.get_or_create_tracker(sender_id)

tracker.update(UserUttered("test"))
tracker.update(ActionExecuted("action_schedule_reminder"))
tracker.update(reminder)

default_processor.tracker_store.save(tracker)

await default_processor.handle_reminder(
reminder, sender_id, default_channel, default_processor.nlg
)

assert f"Deleted lock for conversation '{sender_id}'." in caplog.text


async def test_reminder_aborted(
default_channel: CollectingOutputChannel, default_processor: MessageProcessor
):
Expand Down

0 comments on commit 54cc0f9

Please sign in to comment.