From d3698d28fac133eb51b174b1726ee6723909c8de Mon Sep 17 00:00:00 2001 From: Simon Frank <44034567+SimonBFrank@users.noreply.github.com> Date: Tue, 22 Oct 2024 04:16:39 -0400 Subject: [PATCH] feat(logger): add thread safe logging keys (#5141) * Getting baseline thread safe keys working * Functional tests for thread safe keys * Updating documentation * Cleanup * Small fixes for linting * Clearing thread local keys with clear_state=True * Cleaning up PR * Small fixes to docs * Fixing type annotations for THREAD_LOCAL_KEYS * Replacing '|' with {**dict1, **dict2} due to support of Python < 3.9 * fix types from v2 to v3 * Changing documentation and method names --------- Co-authored-by: Simon Thulbourn Co-authored-by: Leandro Damascena --- aws_lambda_powertools/logging/formatter.py | 83 ++++++++++++- aws_lambda_powertools/logging/logger.py | 19 +++ docs/core/event_handler/api_gateway.md | 3 +- docs/core/logger.md | 80 ++++++++++++- .../logger/src/thread_safe_append_keys.py | 21 ++++ .../src/thread_safe_append_keys_output.json | 20 ++++ examples/logger/src/thread_safe_clear_keys.py | 23 ++++ .../src/thread_safe_clear_keys_output.json | 34 ++++++ .../src/thread_safe_get_current_keys.py | 14 +++ .../logger/src/thread_safe_remove_keys.py | 23 ++++ .../src/thread_safe_remove_keys_output.json | 36 ++++++ .../event_handler/_pydantic/conftest.py | 4 +- .../test_logger_powertools_formatter.py | 110 +++++++++++++++++- 13 files changed, 459 insertions(+), 11 deletions(-) create mode 100644 examples/logger/src/thread_safe_append_keys.py create mode 100644 examples/logger/src/thread_safe_append_keys_output.json create mode 100644 examples/logger/src/thread_safe_clear_keys.py create mode 100644 examples/logger/src/thread_safe_clear_keys_output.json create mode 100644 examples/logger/src/thread_safe_get_current_keys.py create mode 100644 examples/logger/src/thread_safe_remove_keys.py create mode 100644 examples/logger/src/thread_safe_remove_keys_output.json diff --git a/aws_lambda_powertools/logging/formatter.py b/aws_lambda_powertools/logging/formatter.py index 48797f51e2..07db499d1f 100644 --- a/aws_lambda_powertools/logging/formatter.py +++ b/aws_lambda_powertools/logging/formatter.py @@ -7,6 +7,7 @@ import time import traceback from abc import ABCMeta, abstractmethod +from contextvars import ContextVar from datetime import datetime, timezone from functools import partial from typing import TYPE_CHECKING, Any, Callable, Iterable @@ -61,6 +62,21 @@ def clear_state(self) -> None: """Removes any previously added logging keys""" raise NotImplementedError() + # These specific thread-safe methods are necessary to manage shared context in concurrent environments. + # They prevent race conditions and ensure data consistency across multiple threads. + def thread_safe_append_keys(self, **additional_keys) -> None: + raise NotImplementedError() + + def thread_safe_get_current_keys(self) -> dict[str, Any]: + return {} + + def thread_safe_remove_keys(self, keys: Iterable[str]) -> None: + raise NotImplementedError() + + def thread_safe_clear_keys(self) -> None: + """Removes any previously added logging keys in a specific thread""" + raise NotImplementedError() + class LambdaPowertoolsFormatter(BasePowertoolsFormatter): """Powertools for AWS Lambda (Python) Logging formatter. 
@@ -247,6 +263,24 @@ def clear_state(self) -> None: self.log_format = dict.fromkeys(self.log_record_order) self.log_format.update(**self.keys_combined) + # These specific thread-safe methods are necessary to manage shared context in concurrent environments. + # They prevent race conditions and ensure data consistency across multiple threads. + def thread_safe_append_keys(self, **additional_keys) -> None: + # Append additional key-value pairs to the context safely in a thread-safe manner. + set_context_keys(**additional_keys) + + def thread_safe_get_current_keys(self) -> dict[str, Any]: + # Retrieve the current context keys safely in a thread-safe manner. + return _get_context().get() + + def thread_safe_remove_keys(self, keys: Iterable[str]) -> None: + # Remove specified keys from the context safely in a thread-safe manner. + remove_context_keys(keys) + + def thread_safe_clear_keys(self) -> None: + # Clear all keys from the context safely in a thread-safe manner. + clear_context_keys() + @staticmethod def _build_default_keys() -> dict[str, str]: return { @@ -345,14 +379,33 @@ def _extract_log_keys(self, log_record: logging.LogRecord) -> dict[str, Any]: record_dict["asctime"] = self.formatTime(record=log_record) extras = {k: v for k, v in record_dict.items() if k not in RESERVED_LOG_ATTRS} - formatted_log = {} + formatted_log: dict[str, Any] = {} # Iterate over a default or existing log structure # then replace any std log attribute e.g. '%(level)s' to 'INFO', '%(process)d to '4773' + # check if the value is a str if the key is a reserved attribute, the modulo operator only supports string # lastly add or replace incoming keys (those added within the constructor or .structure_logs method) for key, value in self.log_format.items(): if value and key in RESERVED_LOG_ATTRS: - formatted_log[key] = value % record_dict + if isinstance(value, str): + formatted_log[key] = value % record_dict + else: + raise ValueError( + "Logging keys that override reserved log attributes need to be type 'str', " + f"instead got '{type(value).__name__}'", + ) + else: + formatted_log[key] = value + + for key, value in _get_context().get().items(): + if value and key in RESERVED_LOG_ATTRS: + if isinstance(value, str): + formatted_log[key] = value % record_dict + else: + raise ValueError( + "Logging keys that override reserved log attributes need to be type 'str', " + f"instead got '{type(value).__name__}'", + ) else: formatted_log[key] = value @@ -370,3 +423,29 @@ def _strip_none_records(records: dict[str, Any]) -> dict[str, Any]: # Fetch current and future parameters from PowertoolsFormatter that should be reserved RESERVED_FORMATTER_CUSTOM_KEYS: list[str] = inspect.getfullargspec(LambdaPowertoolsFormatter).args[1:] + +# ContextVar for thread local keys +THREAD_LOCAL_KEYS: ContextVar[dict[str, Any]] = ContextVar("THREAD_LOCAL_KEYS", default={}) + + +def _get_context() -> ContextVar[dict[str, Any]]: + return THREAD_LOCAL_KEYS + + +def clear_context_keys() -> None: + _get_context().set({}) + + +def set_context_keys(**kwargs: dict[str, Any]) -> None: + context = _get_context() + context.set({**context.get(), **kwargs}) + + +def remove_context_keys(keys: Iterable[str]) -> None: + context = _get_context() + context_values = context.get() + + for k in keys: + context_values.pop(k, None) + + context.set(context_values) diff --git a/aws_lambda_powertools/logging/logger.py b/aws_lambda_powertools/logging/logger.py index 75a14c6ea2..acefe9757c 100644 --- a/aws_lambda_powertools/logging/logger.py +++ 
b/aws_lambda_powertools/logging/logger.py @@ -589,6 +589,24 @@ def get_current_keys(self) -> dict[str, Any]: def remove_keys(self, keys: Iterable[str]) -> None: self.registered_formatter.remove_keys(keys) + # These specific thread-safe methods are necessary to manage shared context in concurrent environments. + # They prevent race conditions and ensure data consistency across multiple threads. + def thread_safe_append_keys(self, **additional_keys: object) -> None: + # Append additional key-value pairs to the context safely in a thread-safe manner. + self.registered_formatter.thread_safe_append_keys(**additional_keys) + + def thread_safe_get_current_keys(self) -> dict[str, Any]: + # Retrieve the current context keys safely in a thread-safe manner. + return self.registered_formatter.thread_safe_get_current_keys() + + def thread_safe_remove_keys(self, keys: Iterable[str]) -> None: + # Remove specified keys from the context safely in a thread-safe manner. + self.registered_formatter.thread_safe_remove_keys(keys) + + def thread_safe_clear_keys(self) -> None: + # Clear all keys from the context safely in a thread-safe manner. + self.registered_formatter.thread_safe_clear_keys() + def structure_logs(self, append: bool = False, formatter_options: dict | None = None, **keys) -> None: """Sets logging formatting to JSON. @@ -633,6 +651,7 @@ def structure_logs(self, append: bool = False, formatter_options: dict | None = # Mode 3 self.registered_formatter.clear_state() + self.registered_formatter.thread_safe_clear_keys() self.registered_formatter.append_keys(**log_keys) def set_correlation_id(self, value: str | None) -> None: diff --git a/docs/core/event_handler/api_gateway.md b/docs/core/event_handler/api_gateway.md index cba8addfdd..c4082b43ca 100644 --- a/docs/core/event_handler/api_gateway.md +++ b/docs/core/event_handler/api_gateway.md @@ -128,11 +128,12 @@ Here's an example on how we can handle the `/todos` path. When using Amazon API Gateway HTTP API to front your Lambda functions, you can use `APIGatewayHttpResolver`. + ???+ note Using HTTP API v1 payload? Use `APIGatewayRestResolver` instead. `APIGatewayHttpResolver` defaults to v2 payload. - If you're using Terraform to deploy a HTTP API, note that it defaults the [payload_format_version](https://registry.terraform.io/providers/hashicorp/aws/latest/docs/resources/apigatewayv2_integration#payload_format_version){target="_blank" rel="nofollow"} value to 1.0 if not specified. + ```python hl_lines="5 11" title="Using HTTP API resolver" --8<-- "examples/event_handler_rest/src/getting_started_http_api_resolver.py" diff --git a/docs/core/logger.md b/docs/core/logger.md index 2a45ff0828..818d5a6589 100644 --- a/docs/core/logger.md +++ b/docs/core/logger.md @@ -159,13 +159,14 @@ To ease routine tasks like extracting correlation ID from popular event sources, You can append additional keys using either mechanism: -* Persist new keys across all future log messages via `append_keys` method +* New keys persist across all future log messages via `append_keys` method * Add additional keys on a per log message basis as a keyword=value, or via `extra` parameter +* New keys persist across all future logs in a specific thread via `thread_safe_append_keys` method. Check [Working with thread-safe keys](#working-with-thread-safe-keys) section. #### append_keys method ???+ warning - `append_keys` is not thread-safe, please see [RFC](https://github.com/aws-powertools/powertools-lambda-python/issues/991){target="_blank"}. 
+ `append_keys` is not thread-safe, use [thread_safe_append_keys](#appending-thread-safe-additional-keys) instead. You can append your own keys to your existing Logger via `append_keys(**additional_key_values)` method. @@ -228,6 +229,16 @@ It accepts any dictionary, and all keyword arguments will be added as part of th ### Removing additional keys +You can remove additional keys using either mechanism: + +* Remove keys from all future log messages via the `remove_keys` method +* Remove thread-local keys from all future logs in a specific thread via the `thread_safe_remove_keys` method. Check [Working with thread-safe keys](#working-with-thread-safe-keys) section. + +???+ danger + Keys added by `append_keys` can only be removed by `remove_keys`, and thread-local keys added by `thread_safe_append_keys` can only be removed by `thread_safe_remove_keys` or `thread_safe_clear_keys`. Thread-local and normal logger keys are distinct values and can't be manipulated interchangeably. + +#### remove_keys method + You can remove any additional key from Logger state using `remove_keys`. === "remove_keys.py" @@ -284,6 +295,9 @@ You can view all currently configured keys from the Logger state using the `get_ --8<-- "examples/logger/src/get_current_keys.py" ``` +???+ info + For thread-local additional logging keys, use `thread_safe_get_current_keys` instead + ### Log levels The default log level is `INFO`. It can be set using the `level` constructor option, `setLevel()` method or by using the `POWERTOOLS_LOG_LEVEL` environment variable. @@ -473,6 +487,68 @@ You can use any of the following built-in JMESPath expressions as part of [injec | **APPLICATION_LOAD_BALANCER** | `'headers."x-amzn-trace-id"'` | ALB X-Ray Trace ID | | **EVENT_BRIDGE** | `"id"` | EventBridge Event ID | +### Working with thread-safe keys + +#### Appending thread-safe additional keys + +You can append your own thread-local keys to your existing Logger via the `thread_safe_append_keys` method. + +=== "thread_safe_append_keys.py" + + ```python hl_lines="11" + --8<-- "examples/logger/src/thread_safe_append_keys.py" + ``` + +=== "thread_safe_append_keys_output.json" + + ```json hl_lines="8 9 17 18" + --8<-- "examples/logger/src/thread_safe_append_keys_output.json" + ``` + +#### Removing thread-safe additional keys + +You can remove any additional thread-local keys from Logger using either `thread_safe_remove_keys` or `thread_safe_clear_keys`. + +Use the `thread_safe_remove_keys` method to remove a list of thread-local keys that were previously added using the `thread_safe_append_keys` method. + +=== "thread_safe_remove_keys.py" + + ```python hl_lines="13" + --8<-- "examples/logger/src/thread_safe_remove_keys.py" + ``` + +=== "thread_safe_remove_keys_output.json" + + ```json hl_lines="8 9 17 18 26 34" + --8<-- "examples/logger/src/thread_safe_remove_keys_output.json" + ``` + +#### Clearing thread-safe additional keys + +Use the `thread_safe_clear_keys` method to remove all thread-local keys that were previously added using the `thread_safe_append_keys` method. + +=== "thread_safe_clear_keys.py" + + ```python hl_lines="13" + --8<-- "examples/logger/src/thread_safe_clear_keys.py" + ``` + +=== "thread_safe_clear_keys_output.json" + + ```json hl_lines="8 9 17 18" + --8<-- "examples/logger/src/thread_safe_clear_keys_output.json" + ``` + +#### Accessing currently configured thread-safe keys + +You can view all currently configured thread-local keys from the Logger state using the `thread_safe_get_current_keys()` method.
This method is useful when you need to avoid overwriting keys that are already configured. + +=== "thread_safe_get_current_keys.py" + + ```python hl_lines="13" + --8<-- "examples/logger/src/thread_safe_get_current_keys.py" + ``` + ### Reusing Logger across your code Similar to [Tracer](./tracer.md#reusing-tracer-across-your-code){target="_blank"}, a new instance that uses the same `service` name will reuse a previous Logger instance. diff --git a/examples/logger/src/thread_safe_append_keys.py b/examples/logger/src/thread_safe_append_keys.py new file mode 100644 index 0000000000..716d5eef8b --- /dev/null +++ b/examples/logger/src/thread_safe_append_keys.py @@ -0,0 +1,21 @@ +import threading +from typing import List + +from aws_lambda_powertools import Logger +from aws_lambda_powertools.utilities.typing import LambdaContext + +logger = Logger() + + +def threaded_func(order_id: str): + logger.thread_safe_append_keys(order_id=order_id, thread_id=threading.get_ident()) + logger.info("Collecting payment") + + +def lambda_handler(event: dict, context: LambdaContext) -> str: + order_ids: List[str] = event["order_ids"] + + threading.Thread(target=threaded_func, args=(order_ids[0],)).start() + threading.Thread(target=threaded_func, args=(order_ids[1],)).start() + + return "hello world" diff --git a/examples/logger/src/thread_safe_append_keys_output.json b/examples/logger/src/thread_safe_append_keys_output.json new file mode 100644 index 0000000000..bb4a9d2d55 --- /dev/null +++ b/examples/logger/src/thread_safe_append_keys_output.json @@ -0,0 +1,20 @@ +[ + { + "level": "INFO", + "location": "threaded_func:11", + "message": "Collecting payment", + "timestamp": "2024-09-08 03:04:11,316-0400", + "service": "payment", + "order_id": "order_id_value_1", + "thread_id": "3507187776085958" + }, + { + "level": "INFO", + "location": "threaded_func:11", + "message": "Collecting payment", + "timestamp": "2024-09-08 03:04:11,316-0400", + "service": "payment", + "order_id": "order_id_value_2", + "thread_id": "140718447808512" + } +] \ No newline at end of file diff --git a/examples/logger/src/thread_safe_clear_keys.py b/examples/logger/src/thread_safe_clear_keys.py new file mode 100644 index 0000000000..607e9766d0 --- /dev/null +++ b/examples/logger/src/thread_safe_clear_keys.py @@ -0,0 +1,23 @@ +import threading +from typing import List + +from aws_lambda_powertools import Logger +from aws_lambda_powertools.utilities.typing import LambdaContext + +logger = Logger() + + +def threaded_func(order_id: str): + logger.thread_safe_append_keys(order_id=order_id, thread_id=threading.get_ident()) + logger.info("Collecting payment") + logger.thread_safe_clear_keys() + logger.info("Exiting thread") + + +def lambda_handler(event: dict, context: LambdaContext) -> str: + order_ids: List[str] = event["order_ids"] + + threading.Thread(target=threaded_func, args=(order_ids[0],)).start() + threading.Thread(target=threaded_func, args=(order_ids[1],)).start() + + return "hello world" diff --git a/examples/logger/src/thread_safe_clear_keys_output.json b/examples/logger/src/thread_safe_clear_keys_output.json new file mode 100644 index 0000000000..791e2afd45 --- /dev/null +++ b/examples/logger/src/thread_safe_clear_keys_output.json @@ -0,0 +1,34 @@ +[ + { + "level": "INFO", + "location": "threaded_func:11", + "message": "Collecting payment", + "timestamp": "2024-09-08 12:26:10,648-0400", + "service": "payment", + "order_id": "order_id_value_1", + "thread_id": 140077070292544 + }, + { + "level": "INFO", + "location": "threaded_func:11", + 
"message": "Collecting payment", + "timestamp": "2024-09-08 12:26:10,649-0400", + "service": "payment", + "order_id": "order_id_value_2", + "thread_id": 140077061899840 + }, + { + "level": "INFO", + "location": "threaded_func:13", + "message": "Exiting thread", + "timestamp": "2024-09-08 12:26:10,649-0400", + "service": "payment" + }, + { + "level": "INFO", + "location": "threaded_func:13", + "message": "Exiting thread", + "timestamp": "2024-09-08 12:26:10,649-0400", + "service": "payment" + } +] diff --git a/examples/logger/src/thread_safe_get_current_keys.py b/examples/logger/src/thread_safe_get_current_keys.py new file mode 100644 index 0000000000..b9b67a20cf --- /dev/null +++ b/examples/logger/src/thread_safe_get_current_keys.py @@ -0,0 +1,14 @@ +from aws_lambda_powertools import Logger +from aws_lambda_powertools.utilities.typing import LambdaContext + +logger = Logger() + + +@logger.inject_lambda_context +def lambda_handler(event: dict, context: LambdaContext) -> str: + logger.info("Collecting payment") + + if "order" not in logger.thread_safe_get_current_keys(): + logger.thread_safe_append_keys(order=event.get("order")) + + return "hello world" diff --git a/examples/logger/src/thread_safe_remove_keys.py b/examples/logger/src/thread_safe_remove_keys.py new file mode 100644 index 0000000000..b9e4c918da --- /dev/null +++ b/examples/logger/src/thread_safe_remove_keys.py @@ -0,0 +1,23 @@ +import threading +from typing import List + +from aws_lambda_powertools import Logger +from aws_lambda_powertools.utilities.typing import LambdaContext + +logger = Logger() + + +def threaded_func(order_id: str): + logger.thread_safe_append_keys(order_id=order_id, thread_id=threading.get_ident()) + logger.info("Collecting payment") + logger.thread_safe_remove_keys(["order_id"]) + logger.info("Exiting thread") + + +def lambda_handler(event: dict, context: LambdaContext) -> str: + order_ids: List[str] = event["order_ids"] + + threading.Thread(target=threaded_func, args=(order_ids[0],)).start() + threading.Thread(target=threaded_func, args=(order_ids[1],)).start() + + return "hello world" diff --git a/examples/logger/src/thread_safe_remove_keys_output.json b/examples/logger/src/thread_safe_remove_keys_output.json new file mode 100644 index 0000000000..24ff93739b --- /dev/null +++ b/examples/logger/src/thread_safe_remove_keys_output.json @@ -0,0 +1,36 @@ +[ + { + "level": "INFO", + "location": "threaded_func:11", + "message": "Collecting payment", + "timestamp": "2024-09-08 12:26:10,648-0400", + "service": "payment", + "order_id": "order_id_value_1", + "thread_id": 140077070292544 + }, + { + "level": "INFO", + "location": "threaded_func:11", + "message": "Collecting payment", + "timestamp": "2024-09-08 12:26:10,649-0400", + "service": "payment", + "order_id": "order_id_value_2", + "thread_id": 140077061899840 + }, + { + "level": "INFO", + "location": "threaded_func:13", + "message": "Exiting thread", + "timestamp": "2024-09-08 12:26:10,649-0400", + "service": "payment", + "thread_id": 140077070292544 + }, + { + "level": "INFO", + "location": "threaded_func:13", + "message": "Exiting thread", + "timestamp": "2024-09-08 12:26:10,649-0400", + "service": "payment", + "thread_id": 140077061899840 + } +] diff --git a/tests/functional/event_handler/_pydantic/conftest.py b/tests/functional/event_handler/_pydantic/conftest.py index 1d38e2e26b..6dd0b6d14a 100644 --- a/tests/functional/event_handler/_pydantic/conftest.py +++ b/tests/functional/event_handler/_pydantic/conftest.py @@ -97,7 +97,7 @@ def pydanticv2_only(): 
def openapi30_schema(): from urllib.request import urlopen - f = urlopen("https://raw.githubusercontent.com/OAI/OpenAPI-Specification/main/schemas/v3.0/schema.json") + f = urlopen("https://spec.openapis.org/oas/3.0/schema/2021-09-28") data = json.loads(f.read().decode("utf-8")) return fastjsonschema.compile( data, @@ -109,7 +109,7 @@ def openapi30_schema(): def openapi31_schema(): from urllib.request import urlopen - f = urlopen("https://raw.githubusercontent.com/OAI/OpenAPI-Specification/main/schemas/v3.1/schema.json") + f = urlopen("https://spec.openapis.org/oas/3.1/schema/2022-10-07") data = json.loads(f.read().decode("utf-8")) return fastjsonschema.compile( data, diff --git a/tests/functional/logger/required_dependencies/test_logger_powertools_formatter.py b/tests/functional/logger/required_dependencies/test_logger_powertools_formatter.py index fe47e72d59..fdf4c0dd39 100644 --- a/tests/functional/logger/required_dependencies/test_logger_powertools_formatter.py +++ b/tests/functional/logger/required_dependencies/test_logger_powertools_formatter.py @@ -8,6 +8,7 @@ import string import time from collections import namedtuple +from threading import Thread import pytest @@ -40,7 +41,7 @@ def service_name(): def capture_logging_output(stdout): - return json.loads(stdout.getvalue().strip()) + return [json.loads(d.strip()) for d in stdout.getvalue().strip().split("\n")] @pytest.mark.parametrize("level", ["DEBUG", "WARNING", "ERROR", "INFO", "CRITICAL"]) @@ -370,7 +371,7 @@ def test_datadog_formatter_use_rfc3339_date(stdout, service_name): logger.info({}) # THEN the timestamp uses RFC3339 by default - log = capture_logging_output(stdout) + log = capture_logging_output(stdout)[0] assert re.fullmatch(RFC3339_REGEX, log["timestamp"]) # "2022-10-27T17:42:26.841+0200" @@ -389,7 +390,7 @@ def handler(event, context): # THEN we expect a "stack_trace" in log handler({}, lambda_context) - log = capture_logging_output(stdout) + log = capture_logging_output(stdout)[0] assert "stack_trace" in log @@ -410,5 +411,106 @@ def handler(event, context): # THEN we expect a "stack_trace" not in log handler({}, lambda_context) - log = capture_logging_output(stdout) + log = capture_logging_output(stdout)[0] assert "stack_trace" not in log + + +def test_thread_safe_keys_encapsulation(service_name, stdout): + logger = Logger( + service=service_name, + stream=stdout, + ) + + def send_thread_message_with_key(message, keys): + logger.thread_safe_append_keys(**keys) + logger.info(message) + + global_key = {"exampleKey": "globalKey"} + logger.append_keys(**global_key) + logger.info("global key added") + + thread1_keys = {"exampleThread1Key": "thread1"} + Thread(target=send_thread_message_with_key, args=("thread1", thread1_keys)).start() + thread2_keys = {"exampleThread2Key": "thread2"} + Thread(target=send_thread_message_with_key, args=("thread2", thread2_keys)).start() + + logger.info("final log, all thread keys gone") + + logs = capture_logging_output(stdout) + + assert logs[0].get("exampleKey") == "globalKey" + + assert logs[1].get("exampleKey") == "globalKey" + assert logs[1].get("exampleThread1Key") == "thread1" + assert logs[1].get("exampleThread2Key") is None + + assert logs[2].get("exampleKey") == "globalKey" + assert logs[2].get("exampleThread1Key") is None + assert logs[2].get("exampleThread2Key") == "thread2" + + assert logs[3].get("exampleKey") == "globalKey" + assert logs[3].get("exampleThread1Key") is None + assert logs[3].get("exampleThread2Key") is None + + +def test_thread_safe_remove_key(service_name, 
stdout): + logger = Logger( + service=service_name, + stream=stdout, + ) + + def send_message_with_key_and_without(message, keys): + logger.thread_safe_append_keys(**keys) + logger.info(message) + logger.thread_safe_remove_keys(keys.keys()) + logger.info(message) + + thread1_keys = {"exampleThread1Key": "thread1"} + Thread(target=send_message_with_key_and_without, args=("msg", thread1_keys)).start() + + logs = capture_logging_output(stdout) + + assert logs[0].get("exampleThread1Key") == "thread1" + assert logs[1].get("exampleThread1Key") is None + + +def test_thread_safe_clear_key(service_name, stdout): + logger = Logger( + service=service_name, + stream=stdout, + ) + + def send_message_with_key_and_clear(message, keys): + logger.thread_safe_append_keys(**keys) + logger.info(message) + logger.thread_safe_clear_keys() + logger.info(message) + + thread1_keys = {"exampleThread1Key": "thread1"} + Thread(target=send_message_with_key_and_clear, args=("msg", thread1_keys)).start() + + logs = capture_logging_output(stdout) + print(logs) + + assert logs[0].get("exampleThread1Key") == "thread1" + assert logs[1].get("exampleThread1Key") is None + + +def test_thread_safe_getkey(service_name, stdout): + logger = Logger( + service=service_name, + stream=stdout, + ) + + def send_message_with_key_and_get(message, keys): + logger.thread_safe_append_keys(**keys) + logger.info(logger.thread_safe_get_current_keys()) + + thread1_keys = {"exampleThread1Key": "thread1"} + Thread(target=send_message_with_key_and_get, args=("msg", thread1_keys)).start() + + logs = capture_logging_output(stdout) + print(logs) + + assert logs[0].get("exampleThread1Key") == "thread1" + assert logs[0].get("message") == thread1_keys
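Beyond the Lambda-focused examples added in this patch, a minimal sketch of the same API in a plain threaded script may help illustrate the thread-local behaviour. It assumes a Powertools for AWS Lambda (Python) build that includes this change; the `ThreadPoolExecutor` harness, service name, and key names are illustrative only and not part of the PR's example set.

```python
# Minimal sketch (assumes a Powertools build that includes this patch).
# The ThreadPoolExecutor harness, service name, and keys are illustrative only.
from concurrent.futures import ThreadPoolExecutor

from aws_lambda_powertools import Logger

logger = Logger(service="payment")


def process_order(order_id: str) -> None:
    # Stored in a ContextVar, so these keys are only visible to log records
    # emitted from this thread's context.
    logger.thread_safe_append_keys(order_id=order_id)
    logger.info("Collecting payment")
    # Clean up so a reused worker thread does not carry keys into its next task.
    logger.thread_safe_clear_keys()


with ThreadPoolExecutor(max_workers=2) as executor:
    executor.map(process_order, ["order-1", "order-2"])

logger.info("Main thread log without thread-local keys")
```

Because the keys live in a `ContextVar`, each worker thread sees only its own values; clearing them before the task returns keeps reused pool threads from leaking keys into later work, mirroring what `clear_state=True` does for the handler thread.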