Skip to content

Commit

Permalink
feat(logger): add thread safe logging keys (#5141)
Browse files Browse the repository at this point in the history
* Getting baseline thread safe keys working

* Functional tests for thread safe keys

* Updating documentation

* Cleanup

* Small fixes for linting

* Clearing thread local keys with clear_state=True

* Cleaning up PR

* Small fixes to docs

* Fixing type annotations for THREAD_LOCAL_KEYS

* Replacing '|' with {**dict1, **dict2} due to support of Python < 3.9

* fix types from v2 to v3

* Changing documentation and method names

---------

Co-authored-by: Simon Thulbourn <sthulb@users.noreply.github.com>
Co-authored-by: Leandro Damascena <lcdama@amazon.pt>
  • Loading branch information
3 people authored Oct 22, 2024
1 parent fd609bc commit d3698d2
Show file tree
Hide file tree
Showing 13 changed files with 459 additions and 11 deletions.
83 changes: 81 additions & 2 deletions aws_lambda_powertools/logging/formatter.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
import time
import traceback
from abc import ABCMeta, abstractmethod
from contextvars import ContextVar
from datetime import datetime, timezone
from functools import partial
from typing import TYPE_CHECKING, Any, Callable, Iterable
Expand Down Expand Up @@ -61,6 +62,21 @@ def clear_state(self) -> None:
"""Removes any previously added logging keys"""
raise NotImplementedError()

# These specific thread-safe methods are necessary to manage shared context in concurrent environments.
# They prevent race conditions and ensure data consistency across multiple threads.
def thread_safe_append_keys(self, **additional_keys) -> None:
raise NotImplementedError()

def thread_safe_get_current_keys(self) -> dict[str, Any]:
return {}

def thread_safe_remove_keys(self, keys: Iterable[str]) -> None:
raise NotImplementedError()

def thread_safe_clear_keys(self) -> None:
"""Removes any previously added logging keys in a specific thread"""
raise NotImplementedError()


class LambdaPowertoolsFormatter(BasePowertoolsFormatter):
"""Powertools for AWS Lambda (Python) Logging formatter.
Expand Down Expand Up @@ -247,6 +263,24 @@ def clear_state(self) -> None:
self.log_format = dict.fromkeys(self.log_record_order)
self.log_format.update(**self.keys_combined)

# These specific thread-safe methods are necessary to manage shared context in concurrent environments.
# They prevent race conditions and ensure data consistency across multiple threads.
def thread_safe_append_keys(self, **additional_keys) -> None:
# Append additional key-value pairs to the context safely in a thread-safe manner.
set_context_keys(**additional_keys)

def thread_safe_get_current_keys(self) -> dict[str, Any]:
# Retrieve the current context keys safely in a thread-safe manner.
return _get_context().get()

def thread_safe_remove_keys(self, keys: Iterable[str]) -> None:
# Remove specified keys from the context safely in a thread-safe manner.
remove_context_keys(keys)

def thread_safe_clear_keys(self) -> None:
# Clear all keys from the context safely in a thread-safe manner.
clear_context_keys()

@staticmethod
def _build_default_keys() -> dict[str, str]:
return {
Expand Down Expand Up @@ -345,14 +379,33 @@ def _extract_log_keys(self, log_record: logging.LogRecord) -> dict[str, Any]:
record_dict["asctime"] = self.formatTime(record=log_record)
extras = {k: v for k, v in record_dict.items() if k not in RESERVED_LOG_ATTRS}

formatted_log = {}
formatted_log: dict[str, Any] = {}

# Iterate over a default or existing log structure
# then replace any std log attribute e.g. '%(level)s' to 'INFO', '%(process)d to '4773'
# check if the value is a str if the key is a reserved attribute, the modulo operator only supports string
# lastly add or replace incoming keys (those added within the constructor or .structure_logs method)
for key, value in self.log_format.items():
if value and key in RESERVED_LOG_ATTRS:
formatted_log[key] = value % record_dict
if isinstance(value, str):
formatted_log[key] = value % record_dict
else:
raise ValueError(
"Logging keys that override reserved log attributes need to be type 'str', "
f"instead got '{type(value).__name__}'",
)
else:
formatted_log[key] = value

for key, value in _get_context().get().items():
if value and key in RESERVED_LOG_ATTRS:
if isinstance(value, str):
formatted_log[key] = value % record_dict
else:
raise ValueError(
"Logging keys that override reserved log attributes need to be type 'str', "
f"instead got '{type(value).__name__}'",
)
else:
formatted_log[key] = value

Expand All @@ -370,3 +423,29 @@ def _strip_none_records(records: dict[str, Any]) -> dict[str, Any]:

# Fetch current and future parameters from PowertoolsFormatter that should be reserved
RESERVED_FORMATTER_CUSTOM_KEYS: list[str] = inspect.getfullargspec(LambdaPowertoolsFormatter).args[1:]

# ContextVar for thread local keys
THREAD_LOCAL_KEYS: ContextVar[dict[str, Any]] = ContextVar("THREAD_LOCAL_KEYS", default={})


def _get_context() -> ContextVar[dict[str, Any]]:
return THREAD_LOCAL_KEYS


def clear_context_keys() -> None:
_get_context().set({})


def set_context_keys(**kwargs: dict[str, Any]) -> None:
context = _get_context()
context.set({**context.get(), **kwargs})


def remove_context_keys(keys: Iterable[str]) -> None:
context = _get_context()
context_values = context.get()

for k in keys:
context_values.pop(k, None)

context.set(context_values)
19 changes: 19 additions & 0 deletions aws_lambda_powertools/logging/logger.py
Original file line number Diff line number Diff line change
Expand Up @@ -589,6 +589,24 @@ def get_current_keys(self) -> dict[str, Any]:
def remove_keys(self, keys: Iterable[str]) -> None:
self.registered_formatter.remove_keys(keys)

# These specific thread-safe methods are necessary to manage shared context in concurrent environments.
# They prevent race conditions and ensure data consistency across multiple threads.
def thread_safe_append_keys(self, **additional_keys: object) -> None:
# Append additional key-value pairs to the context safely in a thread-safe manner.
self.registered_formatter.thread_safe_append_keys(**additional_keys)

def thread_safe_get_current_keys(self) -> dict[str, Any]:
# Retrieve the current context keys safely in a thread-safe manner.
return self.registered_formatter.thread_safe_get_current_keys()

def thread_safe_remove_keys(self, keys: Iterable[str]) -> None:
# Remove specified keys from the context safely in a thread-safe manner.
self.registered_formatter.thread_safe_remove_keys(keys)

def thread_safe_clear_keys(self) -> None:
# Clear all keys from the context safely in a thread-safe manner.
self.registered_formatter.thread_safe_clear_keys()

def structure_logs(self, append: bool = False, formatter_options: dict | None = None, **keys) -> None:
"""Sets logging formatting to JSON.
Expand Down Expand Up @@ -633,6 +651,7 @@ def structure_logs(self, append: bool = False, formatter_options: dict | None =

# Mode 3
self.registered_formatter.clear_state()
self.registered_formatter.thread_safe_clear_keys()
self.registered_formatter.append_keys(**log_keys)

def set_correlation_id(self, value: str | None) -> None:
Expand Down
3 changes: 2 additions & 1 deletion docs/core/event_handler/api_gateway.md
Original file line number Diff line number Diff line change
Expand Up @@ -128,11 +128,12 @@ Here's an example on how we can handle the `/todos` path.

When using Amazon API Gateway HTTP API to front your Lambda functions, you can use `APIGatewayHttpResolver`.

<!-- markdownlint-disable MD013 -->
???+ note
Using HTTP API v1 payload? Use `APIGatewayRestResolver` instead. `APIGatewayHttpResolver` defaults to v2 payload.

<!-- markdownlint-disable-next-line MD013 -->
If you're using Terraform to deploy a HTTP API, note that it defaults the [payload_format_version](https://registry.terraform.io/providers/hashicorp/aws/latest/docs/resources/apigatewayv2_integration#payload_format_version){target="_blank" rel="nofollow"} value to 1.0 if not specified.
<!-- markdownlint-enable MD013 -->

```python hl_lines="5 11" title="Using HTTP API resolver"
--8<-- "examples/event_handler_rest/src/getting_started_http_api_resolver.py"
Expand Down
80 changes: 78 additions & 2 deletions docs/core/logger.md
Original file line number Diff line number Diff line change
Expand Up @@ -159,13 +159,14 @@ To ease routine tasks like extracting correlation ID from popular event sources,

You can append additional keys using either mechanism:

* Persist new keys across all future log messages via `append_keys` method
* New keys persist across all future log messages via `append_keys` method
* Add additional keys on a per log message basis as a keyword=value, or via `extra` parameter
* New keys persist across all future logs in a specific thread via `thread_safe_append_keys` method. Check [Working with thread-safe keys](#working-with-thread-safe-keys) section.

#### append_keys method

???+ warning
`append_keys` is not thread-safe, please see [RFC](https://github.com/aws-powertools/powertools-lambda-python/issues/991){target="_blank"}.
`append_keys` is not thread-safe, use [thread_safe_append_keys](#appending-thread-safe-additional-keys) instead

You can append your own keys to your existing Logger via `append_keys(**additional_key_values)` method.

Expand Down Expand Up @@ -228,6 +229,16 @@ It accepts any dictionary, and all keyword arguments will be added as part of th

### Removing additional keys

You can remove additional keys using either mechanism:

* Remove new keys across all future log messages via `remove_keys` method
* Remove keys persist across all future logs in a specific thread via `thread_safe_remove_keys` method. Check [Working with thread-safe keys](#working-with-thread-safe-keys) section.

???+ danger
Keys added by `append_keys` can only be removed by `remove_keys` and thread-local keys added by `thread_safe_append_keys` can only be removed by `thread_safe_remove_keys` or `thread_safe_clear_keys`. Thread-local and normal logger keys are distinct values and can't be manipulated interchangeably.

#### remove_keys method

You can remove any additional key from Logger state using `remove_keys`.

=== "remove_keys.py"
Expand Down Expand Up @@ -284,6 +295,9 @@ You can view all currently configured keys from the Logger state using the `get_
--8<-- "examples/logger/src/get_current_keys.py"
```

???+ info
For thread-local additional logging keys, use `get_current_thread_keys` instead

### Log levels

The default log level is `INFO`. It can be set using the `level` constructor option, `setLevel()` method or by using the `POWERTOOLS_LOG_LEVEL` environment variable.
Expand Down Expand Up @@ -473,6 +487,68 @@ You can use any of the following built-in JMESPath expressions as part of [injec
| **APPLICATION_LOAD_BALANCER** | `'headers."x-amzn-trace-id"'` | ALB X-Ray Trace ID |
| **EVENT_BRIDGE** | `"id"` | EventBridge Event ID |

### Working with thread-safe keys

#### Appending thread-safe additional keys

You can append your own thread-local keys in your existing Logger via the `thread_safe_append_keys` method

=== "thread_safe_append_keys.py"

```python hl_lines="11"
--8<-- "examples/logger/src/thread_safe_append_keys.py"
```

=== "thread_safe_append_keys_output.json"

```json hl_lines="8 9 17 18"
--8<-- "examples/logger/src/thread_safe_append_keys_output.json"
```

#### Removing thread-safe additional keys

You can remove any additional thread-local keys from Logger using either `thread_safe_remove_keys` or `thread_safe_clear_keys`.

Use the `thread_safe_remove_keys` method to remove a list of thread-local keys that were previously added using the `thread_safe_append_keys` method.

=== "thread_safe_remove_keys.py"

```python hl_lines="13"
--8<-- "examples/logger/src/thread_safe_remove_keys.py"
```

=== "thread_safe_remove_keys_output.json"

```json hl_lines="8 9 17 18 26 34"
--8<-- "examples/logger/src/thread_safe_remove_keys_output.json"
```

#### Clearing thread-safe additional keys

Use the `thread_safe_clear_keys` method to remove all thread-local keys that were previously added using the `thread_safe_append_keys` method.

=== "thread_safe_clear_keys.py"

```python hl_lines="13"
--8<-- "examples/logger/src/thread_safe_clear_keys.py"
```

=== "thread_safe_clear_keys_output.json"

```json hl_lines="8 9 17 18"
--8<-- "examples/logger/src/thread_safe_clear_keys_output.json"
```

#### Accessing thread-safe currently keys

You can view all currently thread-local keys from the Logger state using the `thread_safe_get_current_keys()` method. This method is useful when you need to avoid overwriting keys that are already configured.

=== "thread_safe_get_current_keys.py"

```python hl_lines="13"
--8<-- "examples/logger/src/thread_safe_get_current_keys.py"
```

### Reusing Logger across your code

Similar to [Tracer](./tracer.md#reusing-tracer-across-your-code){target="_blank"}, a new instance that uses the same `service` name will reuse a previous Logger instance.
Expand Down
21 changes: 21 additions & 0 deletions examples/logger/src/thread_safe_append_keys.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,21 @@
import threading
from typing import List

from aws_lambda_powertools import Logger
from aws_lambda_powertools.utilities.typing import LambdaContext

logger = Logger()


def threaded_func(order_id: str):
logger.thread_safe_append_keys(order_id=order_id, thread_id=threading.get_ident())
logger.info("Collecting payment")


def lambda_handler(event: dict, context: LambdaContext) -> str:
order_ids: List[str] = event["order_ids"]

threading.Thread(target=threaded_func, args=(order_ids[0],)).start()
threading.Thread(target=threaded_func, args=(order_ids[1],)).start()

return "hello world"
20 changes: 20 additions & 0 deletions examples/logger/src/thread_safe_append_keys_output.json
Original file line number Diff line number Diff line change
@@ -0,0 +1,20 @@
[
{
"level": "INFO",
"location": "threaded_func:11",
"message": "Collecting payment",
"timestamp": "2024-09-08 03:04:11,316-0400",
"service": "payment",
"order_id": "order_id_value_1",
"thread_id": "3507187776085958"
},
{
"level": "INFO",
"location": "threaded_func:11",
"message": "Collecting payment",
"timestamp": "2024-09-08 03:04:11,316-0400",
"service": "payment",
"order_id": "order_id_value_2",
"thread_id": "140718447808512"
}
]
23 changes: 23 additions & 0 deletions examples/logger/src/thread_safe_clear_keys.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,23 @@
import threading
from typing import List

from aws_lambda_powertools import Logger
from aws_lambda_powertools.utilities.typing import LambdaContext

logger = Logger()


def threaded_func(order_id: str):
logger.thread_safe_append_keys(order_id=order_id, thread_id=threading.get_ident())
logger.info("Collecting payment")
logger.thread_safe_clear_keys()
logger.info("Exiting thread")


def lambda_handler(event: dict, context: LambdaContext) -> str:
order_ids: List[str] = event["order_ids"]

threading.Thread(target=threaded_func, args=(order_ids[0],)).start()
threading.Thread(target=threaded_func, args=(order_ids[1],)).start()

return "hello world"
34 changes: 34 additions & 0 deletions examples/logger/src/thread_safe_clear_keys_output.json
Original file line number Diff line number Diff line change
@@ -0,0 +1,34 @@
[
{
"level": "INFO",
"location": "threaded_func:11",
"message": "Collecting payment",
"timestamp": "2024-09-08 12:26:10,648-0400",
"service": "payment",
"order_id": "order_id_value_1",
"thread_id": 140077070292544
},
{
"level": "INFO",
"location": "threaded_func:11",
"message": "Collecting payment",
"timestamp": "2024-09-08 12:26:10,649-0400",
"service": "payment",
"order_id": "order_id_value_2",
"thread_id": 140077061899840
},
{
"level": "INFO",
"location": "threaded_func:13",
"message": "Exiting thread",
"timestamp": "2024-09-08 12:26:10,649-0400",
"service": "payment"
},
{
"level": "INFO",
"location": "threaded_func:13",
"message": "Exiting thread",
"timestamp": "2024-09-08 12:26:10,649-0400",
"service": "payment"
}
]
14 changes: 14 additions & 0 deletions examples/logger/src/thread_safe_get_current_keys.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,14 @@
from aws_lambda_powertools import Logger
from aws_lambda_powertools.utilities.typing import LambdaContext

logger = Logger()


@logger.inject_lambda_context
def lambda_handler(event: dict, context: LambdaContext) -> str:
logger.info("Collecting payment")

if "order" not in logger.thread_safe_get_current_keys():
logger.thread_safe_append_keys(order=event.get("order"))

return "hello world"
Loading

0 comments on commit d3698d2

Please sign in to comment.