diff --git a/changelog/6571.improvement.md b/changelog/6571.improvement.md
new file mode 100644
index 000000000000..1f98fc652472
--- /dev/null
+++ b/changelog/6571.improvement.md
@@ -0,0 +1 @@
+Log the model's relative path when using CLI commands.
\ No newline at end of file
diff --git a/changelog/6760.bugfix.md b/changelog/6760.bugfix.md
new file mode 100644
index 000000000000..8b44f62b6c45
--- /dev/null
+++ b/changelog/6760.bugfix.md
@@ -0,0 +1,3 @@
+Update Pika event broker to be a separate process and make it use a
+`multiprocessing.Queue` to send and process messages. This change should help
+avoid situations when events stop being sent after a while.
\ No newline at end of file
diff --git a/changelog/7001.bugfix.md b/changelog/7001.bugfix.md
new file mode 100644
index 000000000000..9994c3276af4
--- /dev/null
+++ b/changelog/7001.bugfix.md
@@ -0,0 +1 @@
+Update Rasa Playground "Download" button to work correctly depending on the current chat state.
\ No newline at end of file
diff --git a/docs/themes/theme-custom/theme/Prototyper/download-button.jsx b/docs/themes/theme-custom/theme/Prototyper/download-button.jsx
index 236590da1311..5f9e5bfafae2 100644
--- a/docs/themes/theme-custom/theme/Prototyper/download-button.jsx
+++ b/docs/themes/theme-custom/theme/Prototyper/download-button.jsx
@@ -9,7 +9,7 @@ const DownloadButton = (props) => {
   return (
      Download project
diff --git a/rasa/core/brokers/pika.py b/rasa/core/brokers/pika.py
index 229baeb2e753..5f5e5e19116f 100644
--- a/rasa/core/brokers/pika.py
+++ b/rasa/core/brokers/pika.py
@@ -1,14 +1,12 @@
 import json
 import logging
 import os
+import sys
 import time
-import typing
-from collections import deque
+import multiprocessing
 from contextlib import contextmanager
-from threading import Thread
 from typing import (
     Callable,
-    Deque,
     Dict,
     Optional,
     Text,
@@ -17,6 +15,7 @@
     List,
     Tuple,
     Generator,
+    TYPE_CHECKING,
 )
 
 from rasa.constants import DEFAULT_LOG_LEVEL_LIBRARIES, ENV_LOG_LEVEL_LIBRARIES
@@ -26,11 +25,11 @@
 from rasa.utils.endpoints import EndpointConfig
 from rasa.shared.utils.io import DEFAULT_ENCODING
 
-if typing.TYPE_CHECKING:
+if TYPE_CHECKING:
+    import pika
     from pika.adapters.blocking_connection import BlockingChannel
     from pika import SelectConnection, BlockingConnection, BasicProperties
     from pika.channel import Channel
-    import pika
     from pika.connection import Parameters, Connection
 
 logger = logging.getLogger(__name__)
@@ -92,6 +91,49 @@ def _pika_log_level(temporary_log_level: int) -> Generator[None, None, None]:
     pika_logger.setLevel(old_log_level)
 
 
+def create_rabbitmq_ssl_options(
+    rabbitmq_host: Optional[Text] = None,
+) -> Optional["pika.SSLOptions"]:
+    """Create RabbitMQ SSL options.
+
+    Requires the following environment variables to be set:
+
+    RABBITMQ_SSL_CLIENT_CERTIFICATE - path to the SSL client certificate (required)
+    RABBITMQ_SSL_CLIENT_KEY - path to the SSL client key (required)
+    RABBITMQ_SSL_CA_FILE - path to the SSL CA file for verification (optional)
+    RABBITMQ_SSL_KEY_PASSWORD - SSL private key password (optional)
+
+    Details on how to enable RabbitMQ TLS support can be found here:
+    https://www.rabbitmq.com/ssl.html#enabling-tls
+
+    Args:
+        rabbitmq_host: RabbitMQ hostname
+
+    Returns:
+        Pika SSL context of type `pika.SSLOptions` if
+        the RABBITMQ_SSL_CLIENT_CERTIFICATE and RABBITMQ_SSL_CLIENT_KEY
+        environment variables are valid paths, else `None`.
+ """ + client_certificate_path = os.environ.get("RABBITMQ_SSL_CLIENT_CERTIFICATE") + client_key_path = os.environ.get("RABBITMQ_SSL_CLIENT_KEY") + + if client_certificate_path and client_key_path: + import pika + import rasa.server + + logger.debug(f"Configuring SSL context for RabbitMQ host '{rabbitmq_host}'.") + + ca_file_path = os.environ.get("RABBITMQ_SSL_CA_FILE") + key_password = os.environ.get("RABBITMQ_SSL_KEY_PASSWORD") + + ssl_context = rasa.server.create_ssl_context( + client_certificate_path, client_key_path, ca_file_path, key_password + ) + return pika.SSLOptions(ssl_context, rabbitmq_host) + else: + return None + + def _get_pika_parameters( host: Text, username: Text, @@ -249,71 +291,50 @@ def close_pika_connection(connection: "Connection") -> None: logger.exception("Failed to close Pika connection with host.") -class PikaEventBroker(EventBroker): - """Pika-based event broker for publishing messages to RabbitMQ.""" +MessageHeaders = Optional[Dict[Text, Text]] +Message = Tuple[Text, MessageHeaders] + + +class PikaMessageProcessor: + """A class that holds all the Pika connection details and processes Pika messages.""" def __init__( self, - host: Text, - username: Text, - password: Text, - port: Union[int, Text] = 5672, - queues: Union[List[Text], Tuple[Text], Text, None] = None, - should_keep_unpublished_messages: bool = True, - raise_on_failure: bool = False, - log_level: Union[Text, int] = os.environ.get( - ENV_LOG_LEVEL_LIBRARIES, DEFAULT_LOG_LEVEL_LIBRARIES - ), - **kwargs: Any, - ): - """Initialise RabbitMQ event broker. + parameters: "Parameters", + get_message: Callable[[], Message], + queues: Union[List[Text], Tuple[Text], Text, None], + ) -> None: + """Initialise Pika connector. Args: - host: Pika host. - username: Username for authentication with Pika host. - password: Password for authentication with Pika host. - port: port of the Pika host. - queues: Pika queues to declare and publish to. - should_keep_unpublished_messages: Whether or not the event broker should - maintain a queue of unpublished messages to be published later in - case of errors. - raise_on_failure: Whether to raise an exception if publishing fails. If - `False`, keep retrying. - log_level: Logging level. 
+            parameters: Pika connection parameters
+            queues: Pika queues to declare and publish to
         """
-        logging.getLogger("pika").setLevel(log_level)
-
-        self.host = host
-        self.username = username
-        self.password = password
-        self.port = port
-        self.channel: Optional["Channel"] = None
-        self.queues = self._get_queues_from_args(queues)
-        self.should_keep_unpublished_messages = should_keep_unpublished_messages
-        self.raise_on_failure = raise_on_failure
+        self.parameters: "Parameters" = parameters
+        self.queues: List[Text] = self._get_queues_from_args(queues)
+        self.get_message: Callable[[], Message] = get_message
 
-        # List to store unpublished messages which hopefully will be published later
-        self._unpublished_messages: Deque[Text] = deque()
-        self._run_pika()
+        self._connection: Optional["SelectConnection"] = None
+        self._channel: Optional["Channel"] = None
+        self._closing = False
 
     def __del__(self) -> None:
-        if self.channel:
-            close_pika_channel(self.channel)
-            close_pika_connection(self.channel.connection)
+        if self._channel:
+            logger.warning("Closing connection...")
+            close_pika_channel(self._channel)
+            close_pika_connection(self._channel.connection)
+
+        if self.is_connected:
+            self._connection.close()
 
     def close(self) -> None:
-        """Close the pika channel and connection."""
+        """Close the Pika connection."""
         self.__del__()
 
-    @property
-    def rasa_environment(self) -> Optional[Text]:
-        """Get value of the `RASA_ENVIRONMENT` environment variable."""
-        return os.environ.get("RASA_ENVIRONMENT")
-
     @staticmethod
     def _get_queues_from_args(
         queues_arg: Union[List[Text], Tuple[Text], Text, None]
-    ) -> Union[List[Text], Tuple[Text]]:
+    ) -> List[Text]:
         """Get queues for this event broker.
 
         The preferred argument defining the RabbitMQ queues the `PikaEventBroker` should
@@ -331,7 +352,7 @@ def _get_queues_from_args(
             `ValueError` if no valid `queues` argument was found.
         """
         if queues_arg and isinstance(queues_arg, (list, tuple)):
-            return queues_arg
+            return list(queues_arg)
 
         if queues_arg and isinstance(queues_arg, str):
             logger.debug(
@@ -350,76 +371,295 @@ def _get_queues_from_args(
 
         return [DEFAULT_QUEUE_NAME]
 
-    @classmethod
-    def from_endpoint_config(
-        cls, broker_config: Optional["EndpointConfig"]
-    ) -> Optional["PikaEventBroker"]:
-        """Initialise `PikaEventBroker` from `EndpointConfig`.
+    @staticmethod
+    def _get_message_properties(headers: MessageHeaders = None) -> "BasicProperties":
+        """Create RabbitMQ message `BasicProperties`.
+
+        The `app_id` property is set to the value of `RASA_ENVIRONMENT` env variable
+        if present, and the message delivery mode is set to 2 (persistent).
+        In addition, the `headers` property is set if supplied.
 
         Args:
-            broker_config: `EndpointConfig` to read.
+            headers: Message headers to add to the message properties of the
+                published message (key-value dictionary). The headers can be retrieved
+                in the consumer from the `headers` attribute of the message's
+                `BasicProperties`.
 
         Returns:
-            `PikaEventBroker` if `broker_config` was supplied, else `None`.
+            `pika.spec.BasicProperties` with the `RASA_ENVIRONMENT` environment variable
+            as the properties' `app_id` value, `delivery_mode=2` and `headers` as the
+            properties' headers.
""" - if broker_config is None: - return None + from pika.spec import BasicProperties - return cls(broker_config.url, **broker_config.kwargs) + # make message persistent + kwargs = {"delivery_mode": 2} - def _run_pika(self) -> None: - parameters = _get_pika_parameters( - self.host, self.username, self.password, self.port - ) - self._pika_connection = initialise_pika_select_connection( - parameters, self._on_open_connection, self._on_open_connection_error + env = os.environ.get("RASA_ENVIRONMENT") + if env: + kwargs["app_id"] = env + + if headers: + kwargs["headers"] = headers + + return BasicProperties(**kwargs) + + @property + def is_connected(self) -> bool: + """Indicates if Pika is connected and the channel is initialized. + + Returns: + A boolean value indicating if the connection is established. + """ + return self._connection and self._channel + + def is_ready( + self, attempts: int = 1000, wait_time_between_attempts_in_seconds: float = 0.01 + ) -> bool: + """Spin until the connector is ready to process messages. + + It typically takes 50 ms or so for the pika channel to open. We'll wait up + to 10 seconds just in case. + + Args: + attempts: Number of retries. + wait_time_between_attempts_in_seconds: Wait time between retries. + + Returns: + `True` if the channel is available, `False` otherwise. + """ + while attempts: + if self.is_connected: + return True + time.sleep(wait_time_between_attempts_in_seconds) + attempts -= 1 + + return False + + def _connect(self) -> "SelectConnection": + """Establish a connection to Pika.""" + return initialise_pika_select_connection( + self.parameters, self._on_open_connection, self._on_open_connection_error ) - # Run Pika io loop in extra thread so it's not blocking - self._run_pika_io_loop_in_thread() def _on_open_connection(self, connection: "SelectConnection") -> None: - logger.debug(f"RabbitMQ connection to '{self.host}' was established.") + logger.debug( + f"RabbitMQ connection to '{self.parameters.host}' was established." + ) + connection.add_on_close_callback(self._on_connection_closed) connection.channel(on_open_callback=self._on_channel_open) def _on_open_connection_error(self, _, error: Text) -> None: logger.warning( - f"Connecting to '{self.host}' failed with error '{error}'. Trying again." + f"Connecting to '{self.parameters.host}' failed with error '{error}'. Trying again." ) + def _on_connection_closed(self, _, reason: Any): + self._channel = None + if self._closing: + # noinspection PyUnresolvedReferences + self._connection.ioloop.stop() + else: + logger.warning(f"Connection closed, reopening in 5 seconds: {reason}") + # noinspection PyUnresolvedReferences + self._connection.ioloop.call_later(5, self._reconnect) + + def _reconnect(self): + # noinspection PyUnresolvedReferences + self._connection.ioloop.stop() + + if not self._closing: + self._connection = self._connect() + # noinspection PyUnresolvedReferences + self._connection.ioloop.start() + def _on_channel_open(self, channel: "Channel") -> None: logger.debug("RabbitMQ channel was opened. 
Declaring fanout exchange.") + self._channel = channel + self._channel.add_on_close_callback(self._on_channel_closed) + # declare exchange of type 'fanout' in order to publish to multiple queues # (https://www.rabbitmq.com/tutorials/amqp-concepts.html#exchange-fanout) - channel.exchange_declare(RABBITMQ_EXCHANGE, exchange_type="fanout") + self._channel.exchange_declare(RABBITMQ_EXCHANGE, exchange_type="fanout") for queue in self.queues: - channel.queue_declare(queue=queue, durable=True) - channel.queue_bind(exchange=RABBITMQ_EXCHANGE, queue=queue) + self._channel.queue_declare(queue=queue, durable=True) + self._channel.queue_bind(exchange=RABBITMQ_EXCHANGE, queue=queue) + + self.process_messages() + + def _on_channel_closed(self, channel: "Channel", reason: Any): + logger.warning(f"Channel {channel} was closed: {reason}") + self._connection.close() + + def _publish(self, message: Message) -> None: + body, headers = message + + self._channel.basic_publish( + exchange=RABBITMQ_EXCHANGE, + routing_key="", + body=body.encode(DEFAULT_ENCODING), + properties=self._get_message_properties(headers), + ) - self.channel = channel + def process_messages(self) -> None: + """Start to process messages.""" - while self._unpublished_messages: - # Send unpublished messages - message = self._unpublished_messages.popleft() - self._publish(message) + try: + while True: + message = self.get_message() + self._publish(message) + logger.debug( + f"Published Pika events to exchange '{RABBITMQ_EXCHANGE}' on host " + f"'{self.parameters.host}':\n{message[0]}" + ) + except EOFError: + # Will most likely happen when shutting down Rasa X. logger.debug( - f"Published message from queue of unpublished messages. " - f"Remaining unpublished messages: {len(self._unpublished_messages)}." + "Pika message queue of worker was closed. Stopping to listen for more " + "messages on this worker." ) - def _run_pika_io_loop_in_thread(self) -> None: - thread = Thread(target=self._run_pika_io_loop, daemon=True) - thread.start() + def run(self): + """Run the message processor by connecting to RabbitMQ and then + starting the IOLoop to block and allow the SelectConnection to operate. + + This function is blocking and indefinite thus it + should be started in a separate process. + """ + self._connection = self._connect() - def _run_pika_io_loop(self) -> None: # noinspection PyUnresolvedReferences - self._pika_connection.ioloop.start() + self._connection.ioloop.start() + + +class PikaEventBroker(EventBroker): + """Pika-based event broker for publishing messages to RabbitMQ.""" + + NUMBER_OF_MP_WORKERS = 1 + MP_CONTEXT = None + + if sys.platform == "darwin" and sys.version_info < (3, 8): + # On macOS, Python 3.8 has switched the default start method to "spawn". To + # quote the documentation: "The fork start method should be considered + # unsafe as it can lead to crashes of the subprocess". Apply this fix when + # running on macOS on Python <= 3.7.x as well. + + # See: + # https://docs.python.org/3/library/multiprocessing.html#contexts-and-start-methods + MP_CONTEXT = "spawn" + + def __init__( + self, + host: Text, + username: Text, + password: Text, + port: Union[int, Text] = 5672, + queues: Union[List[Text], Tuple[Text], Text, None] = None, + should_keep_unpublished_messages: bool = True, + raise_on_failure: bool = False, + log_level: Union[Text, int] = os.environ.get( + ENV_LOG_LEVEL_LIBRARIES, DEFAULT_LOG_LEVEL_LIBRARIES + ), + **kwargs: Any, + ) -> None: + """Initialise RabbitMQ event broker. + + Args: + host: Pika host. 
+            username: Username for authentication with Pika host.
+            password: Password for authentication with Pika host.
+            port: port of the Pika host.
+            queues: Pika queues to declare and publish to.
+            should_keep_unpublished_messages: Whether or not the event broker should
+                maintain a queue of unpublished messages to be published later in
+                case of errors.
+            raise_on_failure: Whether to raise an exception if publishing fails. If
+                `False`, keep retrying.
+            log_level: Logging level.
+        """
+        logging.getLogger("pika").setLevel(log_level)
+
+        self.host = host
+        self.username = username
+        self.password = password
+        self.port = port
+        self.queues = queues
+        self.process: Optional[multiprocessing.Process] = None
+        self.should_keep_unpublished_messages = should_keep_unpublished_messages
+        self.raise_on_failure = raise_on_failure
+        self.pika_message_processor: Optional[PikaMessageProcessor] = None
+
+        self._connect()
+
+    def __del__(self) -> None:
+        if self.pika_message_processor:
+            self.pika_message_processor.close()
+
+        if self.process and self.process.is_alive():
+            self.process.terminate()
+
+    def close(self) -> None:
+        """Close the Pika connector."""
+        self.__del__()
+
+    @classmethod
+    def from_endpoint_config(
+        cls, broker_config: Optional["EndpointConfig"]
+    ) -> Optional["PikaEventBroker"]:
+        """Initialise `PikaEventBroker` from `EndpointConfig`.
+
+        Args:
+            broker_config: `EndpointConfig` to read.
+
+        Returns:
+            `PikaEventBroker` if `broker_config` was supplied, else `None`.
+        """
+        if broker_config is None:
+            return None
+
+        return cls(broker_config.url, **broker_config.kwargs)
+
+    def _connect(self) -> None:
+        parameters = _get_pika_parameters(
+            self.host, self.username, self.password, self.port
+        )
+
+        self.process_queue = self._get_mp_context().Queue()
+        self.pika_message_processor = PikaMessageProcessor(
+            parameters,
+            queues=self.queues,
+            get_message=lambda: self.process_queue.get(),
+        )
+        self.process = self._start_pika_process()
+
+    def _get_mp_context(self) -> multiprocessing.context.BaseContext:
+        return multiprocessing.get_context(self.MP_CONTEXT)
+
+    def _start_pika_process(self) -> Optional[multiprocessing.Process]:
+        if self.pika_message_processor:
+            process = multiprocessing.Process(
+                target=self.pika_message_processor.run, daemon=True
+            )
+            process.start()
+            return process
+
+        return None
+
+    def _publish(self, body: Text, headers: MessageHeaders = None) -> None:
+        if not self.pika_message_processor:
+            self._connect()
+
+        if (
+            self.process and self.process.is_alive()
+        ) or self.should_keep_unpublished_messages:
+            self.process_queue.put((body, headers))
 
     def is_ready(
         self, attempts: int = 1000, wait_time_between_attempts_in_seconds: float = 0.01
     ) -> bool:
-        """Spin until the pika channel is open.
+        """Spin until Pika is ready to process messages.
 
         It typically takes 50 ms or so for the pika channel to open. We'll wait up
         to 10 seconds just in case.
@@ -431,13 +671,9 @@ def is_ready(
         Returns:
             `True` if the channel is available, `False` otherwise.
         """
-        while attempts:
-            if self.channel:
-                return True
-            time.sleep(wait_time_between_attempts_in_seconds)
-            attempts -= 1
-
-        return False
+        return self.pika_message_processor and self.pika_message_processor.is_ready(
+            attempts, wait_time_between_attempts_in_seconds
+        )
 
     def publish(
         self,
@@ -467,7 +703,7 @@ def publish(
                    f"Could not open Pika channel at host '{self.host}'. "
                    f"Failed with error: {e}"
                )
-                self.channel = None
+                self.close()
 
                if self.raise_on_failure:
                    raise e
@@ -475,110 +711,3 @@ def publish(
             time.sleep(retry_delay_in_seconds)
 
         logger.error(f"Failed to publish Pika event on host '{self.host}':\n{body}")
-
-    def _get_message_properties(
-        self, headers: Optional[Dict[Text, Text]] = None
-    ) -> "BasicProperties":
-        """Create RabbitMQ message `BasicProperties`.
-
-        The `app_id` property is set to the value of `self.rasa_environment` if
-        present, and the message delivery mode is set to 2 (persistent). In
-        addition, the `headers` property is set if supplied.
-
-        Args:
-            headers: Message headers to add to the message properties of the
-                published message (key-value dictionary). The headers can be retrieved in
-                the consumer from the `headers` attribute of the message's
-                `BasicProperties`.
-
-        Returns:
-            `pika.spec.BasicProperties` with the `RASA_ENVIRONMENT` environment variable
-            as the properties' `app_id` value, `delivery_mode`=2 and `headers` as the
-            properties' headers.
-        """
-        from pika.spec import BasicProperties
-
-        # make message persistent
-        kwargs = {"delivery_mode": 2}
-
-        if self.rasa_environment:
-            kwargs["app_id"] = self.rasa_environment
-
-        if headers:
-            kwargs["headers"] = headers
-
-        return BasicProperties(**kwargs)
-
-    def _basic_publish(
-        self, body: Text, headers: Optional[Dict[Text, Text]] = None
-    ) -> None:
-        self.channel.basic_publish(
-            exchange=RABBITMQ_EXCHANGE,
-            routing_key="",
-            body=body.encode(DEFAULT_ENCODING),
-            properties=self._get_message_properties(headers),
-        )
-
-        logger.debug(
-            f"Published Pika events to exchange '{RABBITMQ_EXCHANGE}' on host "
-            f"'{self.host}':\n{body}"
-        )
-
-    def _publish(self, body: Text, headers: Optional[Dict[Text, Text]] = None) -> None:
-        if self._pika_connection.is_closed:
-            # Try to reset connection
-            self._run_pika()
-            self._basic_publish(body, headers)
-        elif not self.channel and self.should_keep_unpublished_messages:
-            logger.warning(
-                f"RabbitMQ channel has not been assigned. Adding message to "
-                f"list of unpublished messages and trying to publish them "
-                f"later. Current number of unpublished messages is "
-                f"{len(self._unpublished_messages)}."
-            )
-            self._unpublished_messages.append(body)
-        else:
-            self._basic_publish(body, headers)
-
-
-def create_rabbitmq_ssl_options(
-    rabbitmq_host: Optional[Text] = None,
-) -> Optional["pika.SSLOptions"]:
-    """Create RabbitMQ SSL options.
-
-    Requires the following environment variables to be set:
-
-    RABBITMQ_SSL_CLIENT_CERTIFICATE - path to the SSL client certificate (required)
-    RABBITMQ_SSL_CLIENT_KEY - path to the SSL client key (required)
-    RABBITMQ_SSL_CA_FILE - path to the SSL CA file for verification (optional)
-    RABBITMQ_SSL_KEY_PASSWORD - SSL private key password (optional)
-
-    Details on how to enable RabbitMQ TLS support can be found here:
-    https://www.rabbitmq.com/ssl.html#enabling-tls
-
-    Args:
-        rabbitmq_host: RabbitMQ hostname
-
-    Returns:
-        Pika SSL context of type `pika.SSLOptions` if
-        the RABBITMQ_SSL_CLIENT_CERTIFICATE and RABBITMQ_SSL_CLIENT_KEY
-        environment variables are valid paths, else `None`.
-    """
-    client_certificate_path = os.environ.get("RABBITMQ_SSL_CLIENT_CERTIFICATE")
-    client_key_path = os.environ.get("RABBITMQ_SSL_CLIENT_KEY")
-
-    if client_certificate_path and client_key_path:
-        import pika
-        import rasa.server
-
-        logger.debug(f"Configuring SSL context for RabbitMQ host '{rabbitmq_host}'.")
-
-        ca_file_path = os.environ.get("RABBITMQ_SSL_CA_FILE")
-        key_password = os.environ.get("RABBITMQ_SSL_KEY_PASSWORD")
-
-        ssl_context = rasa.server.create_ssl_context(
-            client_certificate_path, client_key_path, ca_file_path, key_password
-        )
-        return pika.SSLOptions(ssl_context, rabbitmq_host)
-    else:
-        return None
diff --git a/rasa/core/policies/mapping_policy.py b/rasa/core/policies/mapping_policy.py
index ac3fa5686a99..bca8cb9b9d2c 100644
--- a/rasa/core/policies/mapping_policy.py
+++ b/rasa/core/policies/mapping_policy.py
@@ -75,7 +75,7 @@ def validate_against_domain(
                 "You have defined triggers in your domain, but haven't "
                 "added the MappingPolicy to your policy ensemble. "
                 "Either remove the triggers from your domain or "
-                "exclude the MappingPolicy from your policy configuration."
+                "include the MappingPolicy in your policy configuration."
             )
 
     def train(
diff --git a/rasa/model.py b/rasa/model.py
index 51a570e3bc00..76b67719a298 100644
--- a/rasa/model.py
+++ b/rasa/model.py
@@ -150,6 +150,12 @@ def get_model(model_path: Text = DEFAULT_MODELS_PATH) -> TempDirectoryPath:
     elif not model_path.endswith(".tar.gz"):
         raise ModelNotFound(f"Path '{model_path}' does not point to a Rasa model file.")
 
+    try:
+        model_relative_path = os.path.relpath(model_path)
+    except ValueError:
+        model_relative_path = model_path
+    logger.info(f"Loading model {model_relative_path}...")
+
     return unpack_model(model_path)
 
diff --git a/tests/core/test_broker.py b/tests/core/test_broker.py
index 91f722d2087c..f3e493d171db 100644
--- a/tests/core/test_broker.py
+++ b/tests/core/test_broker.py
@@ -15,19 +15,32 @@
 from rasa.core.brokers.broker import EventBroker
 from rasa.core.brokers.file import FileEventBroker
 from rasa.core.brokers.kafka import KafkaEventBroker
-from rasa.core.brokers.pika import PikaEventBroker, DEFAULT_QUEUE_NAME
+from rasa.core.brokers.pika import (
+    PikaEventBroker,
+    PikaMessageProcessor,
+    DEFAULT_QUEUE_NAME,
+)
 from rasa.core.brokers.sql import SQLEventBroker
 from rasa.shared.core.events import Event, Restarted, SlotSet, UserUttered
 from rasa.utils.endpoints import EndpointConfig, read_endpoint_config
 
+import pika.connection
+
 TEST_EVENTS = [
     UserUttered("/greet", {"name": "greet", "confidence": 1.0}, []),
     SlotSet("name", "rasa"),
     Restarted(),
 ]
 
+TEST_CONNECTION_PARAMETERS = pika.connection.ConnectionParameters(
+    "amqp://username:password@host:port"
+)
+
+
+def test_pika_broker_from_config(monkeypatch: MonkeyPatch):
+    # patch PikaEventBroker so it doesn't try to connect to RabbitMQ on init
+    monkeypatch.setattr(PikaEventBroker, "_connect", lambda _: None)
 
-def test_pika_broker_from_config():
     cfg = read_endpoint_config(
         "data/test_endpoints/event_brokers/pika_endpoint.yml", "event_broker"
     )
@@ -41,18 +54,18 @@ def test_pika_broker_from_config():
 
 # noinspection PyProtectedMember
 def test_pika_message_property_app_id(monkeypatch: MonkeyPatch):
-    # patch PikaEventBroker so it doesn't try to connect to RabbitMQ on init
-    monkeypatch.setattr(PikaEventBroker, "_run_pika", lambda _: None)
-    pika_producer = PikaEventBroker("", "", "")
+    pika_processor = PikaMessageProcessor(
+        TEST_CONNECTION_PARAMETERS, queues=None, get_message=lambda: ("", None)
+    )
 
     # unset RASA_ENVIRONMENT env var results in empty App ID
     monkeypatch.delenv("RASA_ENVIRONMENT", raising=False)
-    assert not pika_producer._get_message_properties().app_id
+    assert not pika_processor._get_message_properties().app_id
 
     # setting it to some value results in that value as the App ID
     rasa_environment = "some-test-environment"
     monkeypatch.setenv("RASA_ENVIRONMENT", rasa_environment)
-    assert pika_producer._get_message_properties().app_id == rasa_environment
+    assert pika_processor._get_message_properties().app_id == rasa_environment
 
 
 @pytest.mark.parametrize(
@@ -70,15 +83,15 @@ def test_pika_queues_from_args(
     queues_arg: Union[Text, List[Text], None],
     expected: List[Text],
     warning: Optional[Type[Warning]],
-    monkeypatch: MonkeyPatch,
 ):
-    # patch PikaEventBroker so it doesn't try to connect to RabbitMQ on init
-    monkeypatch.setattr(PikaEventBroker, "_run_pika", lambda _: None)
-
     with pytest.warns(warning):
-        pika_producer = PikaEventBroker("", "", "", queues=queues_arg)
+        pika_processor = PikaMessageProcessor(
+            TEST_CONNECTION_PARAMETERS,
+            queues=queues_arg,
+            get_message=lambda: ("", None),
+        )
 
-    assert pika_producer.queues == expected
+    assert pika_processor.queues == expected
 
 
 def test_no_broker_in_config():
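The core change in `rasa/core/brokers/pika.py` is that publishing moves from a background thread to a separate worker process fed by a `multiprocessing.Queue`. A minimal, self-contained sketch of that producer/consumer pattern, simplified from the patch: the real worker drives a Pika `SelectConnection` instead of printing, and stops on `EOFError` rather than on a sentinel value.

```python
import multiprocessing
from typing import Dict, Optional, Text, Tuple

# Same shape as the patch's `Message` alias: a JSON body plus optional headers.
Message = Tuple[Text, Optional[Dict[Text, Text]]]


def worker(queue: "multiprocessing.queues.Queue") -> None:
    """Drain the queue; the real processor publishes each message to RabbitMQ."""
    while True:
        message = queue.get()  # blocks until the parent enqueues something
        if message is None:  # sentinel used only in this sketch to stop cleanly
            break
        body, headers = message
        print(f"would publish {body!r} with headers {headers}")


if __name__ == "__main__":
    ctx = multiprocessing.get_context()  # the broker may pin this to "spawn"
    queue = ctx.Queue()
    process = ctx.Process(target=worker, args=(queue,), daemon=True)
    process.start()

    # This mirrors what `PikaEventBroker._publish()` does: it only puts on the queue.
    queue.put(('{"event": "restart"}', {"sender": "test"}))

    queue.put(None)
    process.join(timeout=5)
```

Because the parent never talks to RabbitMQ directly, a slow or broken connection in the worker cannot block the main Rasa process.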
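`PikaEventBroker` pins `MP_CONTEXT` to `"spawn"` only on macOS with Python older than 3.8; everywhere else it passes `None`, which keeps the platform's default start method. A small sketch of what `_get_mp_context()` resolves to:

```python
import multiprocessing
import sys

start_method = None
if sys.platform == "darwin" and sys.version_info < (3, 8):
    # Mirrors the broker's MP_CONTEXT: fork is considered unsafe on macOS.
    start_method = "spawn"

ctx = multiprocessing.get_context(start_method)
print(ctx.get_start_method())  # "spawn" on older macOS setups, the default elsewhere
```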
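`PikaMessageProcessor._get_message_properties` stamps every published message with `delivery_mode=2`, the `RASA_ENVIRONMENT` value as `app_id`, and any caller-supplied headers. A minimal consumer sketch showing where those properties surface on the receiving side; it assumes RabbitMQ on localhost, and the queue name is a placeholder that should match whichever queue the broker declared.

```python
import pika


def on_message(channel, method, properties, body) -> None:
    # `app_id` carries RASA_ENVIRONMENT, `headers` whatever was passed to publish().
    print(properties.app_id, properties.delivery_mode, properties.headers)
    print(body.decode("utf-8"))
    channel.basic_ack(delivery_tag=method.delivery_tag)


connection = pika.BlockingConnection(pika.ConnectionParameters("localhost"))
channel = connection.channel()
# Placeholder queue name -- use the queue configured for the event broker.
channel.basic_consume(queue="rasa_events", on_message_callback=on_message)
channel.start_consuming()
```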
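`create_rabbitmq_ssl_options` (now at the top of the module) is driven entirely by environment variables. A hedged usage sketch: the hostname and paths below are placeholders, and the certificate and key files must actually exist for `rasa.server.create_ssl_context` to load them.

```python
import os

from rasa.core.brokers.pika import create_rabbitmq_ssl_options

# Placeholder paths -- point these at real PEM files in an actual deployment.
os.environ["RABBITMQ_SSL_CLIENT_CERTIFICATE"] = "/etc/rabbitmq/ssl/client_cert.pem"
os.environ["RABBITMQ_SSL_CLIENT_KEY"] = "/etc/rabbitmq/ssl/client_key.pem"
# Optional:
# os.environ["RABBITMQ_SSL_CA_FILE"] = "/etc/rabbitmq/ssl/ca.pem"
# os.environ["RABBITMQ_SSL_KEY_PASSWORD"] = "key-password"

ssl_options = create_rabbitmq_ssl_options("rabbitmq.example.com")
if ssl_options is None:
    print("Certificate or key variable missing; connecting without TLS.")
```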
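The `rasa/model.py` change wraps `os.path.relpath` in a `try/except` because `relpath` raises `ValueError` on Windows when the model archive and the current working directory sit on different drives; in that case the path is logged as given. A standalone sketch of the same guard:

```python
import os


def loggable_model_path(model_path: str) -> str:
    """Prefer the relative path for logging, fall back to the path as given."""
    try:
        return os.path.relpath(model_path)
    except ValueError:  # e.g. a "C:\\" path while running from "D:\\" on Windows
        return model_path


print(f"Loading model {loggable_model_path('models/20201026-152021.tar.gz')}...")
```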
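The tests now patch `PikaEventBroker._connect` (previously `_run_pika`) so the broker can be constructed without spawning a worker process or reaching RabbitMQ. A hypothetical extra test in the same style; the test name and assertions are illustrative, not part of the patch.

```python
from rasa.core.brokers.pika import PikaEventBroker


def test_broker_construction_without_rabbitmq(monkeypatch) -> None:
    # Stub out the method that would create the queue and start the process.
    monkeypatch.setattr(PikaEventBroker, "_connect", lambda _: None)

    broker = PikaEventBroker("localhost", "user", "password")

    assert broker.host == "localhost"
    assert broker.process is None
```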