diff --git a/CHANGELOG.rst b/CHANGELOG.rst
index 4b31f69..d457b32 100644
--- a/CHANGELOG.rst
+++ b/CHANGELOG.rst
@@ -14,9 +14,14 @@ Change Log
 Unreleased
 **********
 
+[3.10.0] - 2023-05-05
+*********************
+Changed
+=======
 * Switch from ``edx-sphinx-theme`` to ``sphinx-book-theme`` since the former is deprecated
 * Refactored consumer to manually deserialize messages instead of using DeserializingConsumer
+* Make signal argument optional in consumer command (take signal from message headers)
 
 [3.9.6] - 2023-02-24
 ********************
 
diff --git a/edx_event_bus_kafka/__init__.py b/edx_event_bus_kafka/__init__.py
index a6f8c79..4f8cdb8 100644
--- a/edx_event_bus_kafka/__init__.py
+++ b/edx_event_bus_kafka/__init__.py
@@ -9,4 +9,4 @@
 from edx_event_bus_kafka.internal.consumer import KafkaEventConsumer
 from edx_event_bus_kafka.internal.producer import KafkaEventProducer, create_producer
 
-__version__ = '3.9.6'
+__version__ = '3.10.0'
diff --git a/edx_event_bus_kafka/internal/consumer.py b/edx_event_bus_kafka/internal/consumer.py
index 49e9721..3c214e8 100644
--- a/edx_event_bus_kafka/internal/consumer.py
+++ b/edx_event_bus_kafka/internal/consumer.py
@@ -103,21 +103,20 @@ def _reconnect_to_db_if_needed():
 
 class KafkaEventConsumer:
     """
-    Construct consumer for the given topic, group, and signal. The consumer can then
-    emit events from the event bus using the configured signal.
+    Construct consumer for the given topic and group. The consumer can then
+    emit events from the event bus using the signal from the message headers.
 
     Note that the topic should be specified here *without* the optional environment prefix.
 
     Can also consume messages indefinitely off the queue.
     """
 
-    def __init__(self, topic, group_id, signal):
+    def __init__(self, topic, group_id):
         if confluent_kafka is None:  # pragma: no cover
             raise Exception('Library confluent-kafka not available. Cannot create event consumer.')
 
         self.topic = topic
         self.group_id = group_id
-        self.signal = signal
         self.consumer = self._create_consumer()
         self._shut_down_loop = False
         self.schema_registry_client = get_schema_registry_client()
@@ -240,7 +239,6 @@ def _consume_indefinitely(self):
             run_context = {
                 'full_topic': full_topic,
                 'consumer_group': self.group_id,
-                'expected_signal': self.signal,
             }
             self.consumer.subscribe([full_topic])
             logger.info(f"Running consumer for {run_context!r}")
@@ -268,8 +266,9 @@
                     with function_trace('_consume_indefinitely_consume_single_message'):
                         # Before processing, make sure our db connection is still active
                         _reconnect_to_db_if_needed()
-                        msg.set_value(self._deserialize_message_value(msg))
-                        self.emit_signals_from_message(msg)
+                        signal = self.determine_signal(msg)
+                        msg.set_value(self._deserialize_message_value(msg, signal))
+                        self.emit_signals_from_message(msg, signal)
                         consecutive_errors = 0
 
                     self._add_message_monitoring(run_context=run_context, message=msg)
@@ -317,54 +316,42 @@ def consume_indefinitely(self, offset_timestamp=None):
             self.reset_offsets_and_sleep_indefinitely(offset_timestamp)
 
     @function_trace('emit_signals_from_message')
-    def emit_signals_from_message(self, msg):
+    def emit_signals_from_message(self, msg, signal):
         """
-        Determine the correct signal and send the event from the message.
+        Send the event from the message via the given signal.
+
+        Assumes the message has been deserialized and the signal matches the event_type of the message header.
 
         Arguments:
-            msg (Message): Consumed message.
+            msg (Message): Deserialized message.
+            signal (OpenEdxPublicSignal): Signal to emit; must match the event_type of the message header.
         """
         self._log_message_received(msg)
 
-        # DeserializingConsumer.poll() always returns either a valid message
-        # or None, and raises an exception in all other cases. This means
-        # we don't need to check msg.error() ourselves. But... check it here
-        # anyway for robustness against code changes.
         if msg.error() is not None:
             raise UnusableMessageError(
-                f"Polled message had error object (shouldn't happen): {msg.error()!r}"
+                f"Polled message had error object: {msg.error()!r}"
             )
 
-        headers = msg.headers() or []  # treat None as []
-
-        event_types = get_message_header_values(headers, HEADER_EVENT_TYPE)
-        if len(event_types) == 0:
-            raise UnusableMessageError(
-                "Missing ce_type header on message, cannot determine signal"
-            )
-        if len(event_types) > 1:
-            raise UnusableMessageError(
-                "Multiple ce_type headers found on message, cannot determine signal"
-            )
-        event_type = event_types[0]
+        # This should also never happen, since the signal is determined from the message itself,
+        # but it's here to prevent misuse of the method.
+        msg_event_type = self._get_event_type_from_message(msg)
+        if signal.event_type != msg_event_type:
+            raise Exception(f"Error emitting event from Kafka: (UNEXPECTED) message event type {msg_event_type}"
+                            f" does not match signal {signal.event_type}")
 
-        if event_type != self.signal.event_type:
-            raise UnusableMessageError(
-                f"Signal types do not match. Expected {self.signal.event_type}. "
-                f"Received message of type {event_type}."
-            )
         try:
-            event_metadata = _get_metadata_from_headers(headers)
+            event_metadata = _get_metadata_from_headers(msg.headers())
         except Exception as e:
             raise UnusableMessageError(f"Error determining metadata from message headers: {e}") from e
 
         with function_trace('emit_signals_from_message_send_event_with_custom_metadata'):
-            send_results = self.signal.send_event_with_custom_metadata(event_metadata, **msg.value())
+            send_results = signal.send_event_with_custom_metadata(event_metadata, **msg.value())
 
         # Raise an exception if any receivers errored out. This allows logging of the receivers
         # along with partition, offset, etc. in record_event_consuming_error. Hopefully the
         # receiver code is idempotent and we can just replay any messages that were involved.
-        self._check_receiver_results(send_results)
+        self._check_receiver_results(send_results, signal)
 
         # At the very end, log that a message was processed successfully.
         # Since we're single-threaded, no other information is needed;
@@ -373,21 +360,69 @@ def emit_signals_from_message(self, msg):
         if AUDIT_LOGGING_ENABLED.is_enabled():
             logger.info('Message from Kafka processed successfully')
 
-    def _deserialize_message_value(self, msg):
+    def determine_signal(self, msg) -> OpenEdxPublicSignal:
+        """
+        Determine which OpenEdxPublicSignal should be used to emit the event data in a message.
+
+        Arguments:
+            msg (Message): Consumed message
+
+        Returns:
+            The OpenEdxPublicSignal instance corresponding to the ce_type header on the message
+        """
+        event_type = self._get_event_type_from_message(msg)
+        try:
+            return OpenEdxPublicSignal.get_signal_by_type(event_type)
+        except KeyError as ke:
+            raise UnusableMessageError(
+                f"Unrecognized type {event_type} found on message, cannot determine signal"
+            ) from ke
+
+    def _get_event_type_from_message(self, msg):
+        """
+        Return the event type from the ce_type header.
+
+        Arguments:
+            msg (Message): the consumed message
+
+        Returns:
+            The associated event type as a string
+        """
+        headers = msg.headers() or []  # treat None as []
+        event_types = get_message_header_values(headers, HEADER_EVENT_TYPE)
+        if len(event_types) == 0:
+            raise UnusableMessageError(
+                "Missing ce_type header on message, cannot determine signal"
+            )
+        if len(event_types) > 1:
+            raise UnusableMessageError(
+                "Multiple ce_type headers found on message, cannot determine signal"
+            )
+        return event_types[0]
+
+    def _deserialize_message_value(self, msg, signal: OpenEdxPublicSignal):
         """
         Deserialize an Avro message value
 
+        The signal is expected to match the ce_type header on the message.
+
         Arguments:
             msg (Message): the raw message from the consumer
+            signal (OpenEdxPublicSignal): The instance of OpenEdxPublicSignal corresponding to the ce_type header on msg
 
         Returns:
             The deserialized message value
         """
-        signal_deserializer = get_deserializer(self.signal, self.schema_registry_client)
+        msg_event_type = self._get_event_type_from_message(msg)
+        if signal.event_type != msg_event_type:
+            # This should never happen, but it's here to prevent misuse of the method
+            raise Exception(f"Error deserializing event from Kafka: (UNEXPECTED) message event type {msg_event_type}"
+                            f" does not match signal {signal.event_type}")
+        signal_deserializer = get_deserializer(signal, self.schema_registry_client)
         ctx = SerializationContext(msg.topic(), MessageField.VALUE, msg.headers())
         return signal_deserializer(msg.value(), ctx)
 
-    def _check_receiver_results(self, send_results: list):
+    def _check_receiver_results(self, send_results: list, signal: OpenEdxPublicSignal):
         """
         Raises exception if any of the receivers produced an exception.
@@ -415,7 +450,7 @@ def _check_receiver_results(self, send_results: list):
             raise ReceiverError(
                 f"{len(error_descriptions)} receiver(s) out of {len(send_results)} "
                 "produced errors (stack trace elsewhere in logs) "
-                f"when handling signal {self.signal}: {', '.join(error_descriptions)}",
+                f"when handling signal {signal}: {', '.join(error_descriptions)}",
                 errors
             )
 
@@ -582,12 +617,11 @@ class ConsumeEventsCommand(BaseCommand):
     Management command for Kafka consumer workers in the event bus.
     """
     help = """
-    Consume messages of specified signal type from a Kafka topic and send their data to that signal.
+    Consume messages from a Kafka topic and send their data to the correct signal.
 
     Example::
 
-        python3 manage.py cms consume_events -t user-login -g user-activity-service \
-            -s org.openedx.learning.auth.session.login.completed.v1
+        python3 manage.py cms consume_events -t user-login -g user-activity-service
     """
 
     def add_arguments(self, parser):
@@ -605,12 +639,16 @@
             required=True,
             help='Consumer group id'
         )
+
+        # TODO: remove this once callers have been updated. Left optional to avoid the need for lockstep changes.
         parser.add_argument(
            '-s', '--signal',
            nargs=1,
-            required=True,
-            help='Type of signal to emit from consumed messages.'
+            required=False,
+            default=None,
+            help='Deprecated argument. The correct signal will be determined from the event.'
        )
+
        parser.add_argument(
            '-o', '--offset_time',
            nargs=1,
@@ -634,7 +672,6 @@ def handle(self, *args, **options):
 
         try:
             load_all_signals()
-            signal = OpenEdxPublicSignal.get_signal_by_type(options['signal'][0])
             if options['offset_time'] and options['offset_time'][0] is not None:
                 try:
                     offset_timestamp = datetime.fromisoformat(options['offset_time'][0])
@@ -647,7 +684,6 @@
             event_consumer = KafkaEventConsumer(
                 topic=options['topic'][0],
                 group_id=options['group_id'][0],
-                signal=signal,
             )
             if offset_timestamp is None:
                 event_consumer.consume_indefinitely()
diff --git a/edx_event_bus_kafka/internal/tests/test_consumer.py b/edx_event_bus_kafka/internal/tests/test_consumer.py
index 0b5ba41..db7d037 100644
--- a/edx_event_bus_kafka/internal/tests/test_consumer.py
+++ b/edx_event_bus_kafka/internal/tests/test_consumer.py
@@ -14,6 +14,7 @@
 from django.test.utils import override_settings
 from openedx_events.learning.data import UserData, UserPersonalData
 from openedx_events.learning.signals import SESSION_LOGIN_COMPLETED
+from openedx_events.tooling import OpenEdxPublicSignal
 
 from edx_event_bus_kafka.internal.consumer import (
     KafkaEventConsumer,
@@ -98,7 +99,7 @@ def setUp(self):
         self.signal.connect(fake_receiver_returns_quietly)
         self.signal.connect(fake_receiver_raises_error)
         self.signal.connect(self.mock_receiver)
-        self.event_consumer = KafkaEventConsumer('some-topic', 'test_group_id', self.signal)
+        self.event_consumer = KafkaEventConsumer('some-topic', 'test_group_id')
 
     def tearDown(self):
         self.signal.disconnect(fake_receiver_returns_quietly)
@@ -195,7 +196,7 @@ def raise_exception():
         # Check that each of the mocked out methods got called as expected.
         mock_consumer.subscribe.assert_called_once_with(['local-some-topic'])
         # Check that emit was called the expected number of times
-        assert mock_emit.call_args_list == [call(self.normal_message)] * len(mock_emit_side_effects)
+        assert mock_emit.call_args_list == [call(self.normal_message, self.signal)] * len(mock_emit_side_effects)
 
         # Check that there was one error log message and that it contained all the right parts,
         # in some order.
@@ -204,8 +205,6 @@
         assert "Error consuming event from Kafka: Exception('something broke') in context" in exc_log_msg
         assert "full_topic='local-some-topic'" in exc_log_msg
         assert "consumer_group='test_group_id'" in exc_log_msg
-        assert ("expected_signal=") in exc_log_msg
         assert "-- event details: " in exc_log_msg
         assert "'partition': 2" in exc_log_msg
         assert "'offset': 12345" in exc_log_msg
@@ -254,7 +253,7 @@ def raise_exception():
         with pytest.raises(Exception) as exc_info:
             self.event_consumer.consume_indefinitely()
 
-        assert mock_emit.call_args_list == [call(self.normal_message)] * exception_count
+        assert mock_emit.call_args_list == [call(self.normal_message, self.signal)] * exception_count
         assert exc_info.value.args == ("Too many consecutive errors, exiting (4 in a row)",)
 
     @override_settings(
@@ -332,7 +331,7 @@ def raise_exception():
         self.event_consumer.consumer = mock_consumer
         self.event_consumer.consume_indefinitely()  # exits normally
 
-        assert mock_emit.call_args_list == [call(self.normal_message)] * len(mock_emit_side_effects)
+        assert mock_emit.call_args_list == [call(self.normal_message, self.signal)] * len(mock_emit_side_effects)
 
     TEST_FAILED_MESSAGE = FakeMessage(
         partition=7,
@@ -396,8 +395,6 @@ def poll_side_effect(*args, **kwargs):
         assert f"Error consuming event from Kafka: {repr(exception)} in context" in exc_log_msg
         assert "full_topic='local-some-topic'" in exc_log_msg
         assert "consumer_group='test_group_id'" in exc_log_msg
-        assert ("expected_signal=") in exc_log_msg
         if has_message:
             assert "-- event details" in exc_log_msg
         else:
@@ -425,20 +422,18 @@ def poll_side_effect(*args, **kwargs):
         mock_consumer.commit.assert_not_called()
 
     def test_check_event_error(self):
-        """
-        DeserializingConsumer.poll() should never return a Message with an error() object,
-        but we check it anyway as a safeguard. This test exercises that branch.
- """ with pytest.raises(Exception) as exc_info: self.event_consumer.emit_signals_from_message( FakeMessage( partition=2, error=KafkaError(123, "done broke"), - ) + headers=[{'ce_type': 'org.openedx.learning.auth.session.login.completed.v1'}] + ), + self.signal ) assert exc_info.value.args == ( - "Polled message had error object (shouldn't happen): " + "Polled message had error object: " "KafkaError{code=ERR_123?,val=123,str=\"done broke\"}", ) @@ -454,7 +449,7 @@ def test_emit_success(self, audit_logging, mock_logger, mock_set_attribute): self.normal_message.set_value(self.normal_event_data) with override_settings(EVENT_BUS_KAFKA_AUDIT_LOGGING_ENABLED=audit_logging): - self.event_consumer.emit_signals_from_message(self.normal_message) + self.event_consumer.emit_signals_from_message(self.normal_message, self.signal) self.assert_signal_sent_with(self.signal, self.normal_event_data) # Specifically, not called with 'kafka_logging_error' mock_set_attribute.assert_not_called() @@ -481,7 +476,7 @@ def test_emit_success_tolerates_missing_timestamp(self, mock_logger, mock_set_at self.normal_message.set_value(self.normal_event_data) self.normal_message._timestamp = (TIMESTAMP_NOT_AVAILABLE, None) # pylint: disable=protected-access - self.event_consumer.emit_signals_from_message(self.normal_message) + self.event_consumer.emit_signals_from_message(self.normal_message, self.signal) self.assert_signal_sent_with(self.signal, self.normal_event_data) # Specifically, not called with 'kafka_logging_error' mock_set_attribute.assert_not_called() @@ -499,7 +494,7 @@ def test_emit(self, mock_logger): # assume we've already deserialized the data self.normal_message.set_value(self.normal_event_data) with pytest.raises(ReceiverError) as exc_info: - self.event_consumer.emit_signals_from_message(self.normal_message) + self.event_consumer.emit_signals_from_message(self.normal_message, self.signal) self.assert_signal_sent_with(self.signal, self.normal_event_data) assert exc_info.value.args == ( "1 receiver(s) out of 3 produced errors (stack trace elsewhere in logs) " @@ -520,6 +515,34 @@ def test_emit(self, mock_logger): exc_info=receiver_error, ) + def test_emit_type_mismatch(self): + # assume we've already deserialized the data + self.normal_message.set_value(self.normal_event_data) + self.normal_message._headers = [('ce_type', b'xxxx')] # pylint: disable=protected-access + + with pytest.raises(Exception) as excinfo: + self.event_consumer.emit_signals_from_message(self.normal_message, self.signal) + + assert excinfo.value.args == ( + "Error emitting event from Kafka: (UNEXPECTED) message event type xxxx " + "does not match signal org.openedx.learning.auth.session.login.completed.v1", + ) + assert not self.mock_receiver.called + + def test_deserialize_type_mismatch(self): + self.normal_message._headers = [('ce_type', b'xxxx')] # pylint: disable=protected-access + + with pytest.raises(Exception) as excinfo: + self.event_consumer._deserialize_message_value( # pylint: disable=protected-access + self.normal_message, self.signal + ) + + assert excinfo.value.args == ( + "Error deserializing event from Kafka: (UNEXPECTED) message event type xxxx " + "does not match signal org.openedx.learning.auth.session.login.completed.v1", + ) + assert not self.mock_receiver.called + def test_malformed_receiver_errors(self): """ Ensure that even a really messed-up receiver is still reported correctly. 
@@ -529,7 +552,7 @@
             (lambda x:x, Exception("for lambda")),
             # This would actually raise an error inside send_robust(), but it will serve well enough for testing...
             ("not even a function", Exception("just plain bad")),
-        ])
+        ], self.signal)
         assert exc_info.value.args == (
             "2 receiver(s) out of 2 produced errors (stack trace elsewhere in logs) "
             "when handling signal
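
For reference, a minimal sketch of how a worker drives the consumer after this change. It assumes a Django process with the event bus settings and ``confluent-kafka`` available; the topic and group id are taken from the command's own help text, and ``load_all_signals()`` is the same registration step the management command performs::

    from openedx_events.tooling import load_all_signals

    from edx_event_bus_kafka import KafkaEventConsumer

    # Register all OpenEdxPublicSignal instances so that determine_signal()
    # can resolve each message's ce_type header to a signal.
    load_all_signals()

    # No signal argument any more; the consumer resolves it per message.
    consumer = KafkaEventConsumer(topic='user-login', group_id='user-activity-service')
    consumer.consume_indefinitely()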