Skip to content

Commit

Permalink
Merge remote-tracking branch 'origin/develop-google-firestore-instrum…
Browse files Browse the repository at this point in the history
…entation' into feature-async-wrapper-argument
  • Loading branch information
TimPansino committed Aug 2, 2023
2 parents faf3ccc + edd1f94 commit 479f9e2
Show file tree
Hide file tree
Showing 9 changed files with 774 additions and 12 deletions.
93 changes: 83 additions & 10 deletions newrelic/api/datastore_trace.py
Original file line number Diff line number Diff line change
Expand Up @@ -82,6 +82,9 @@ def __enter__(self):
self.product = transaction._intern_string(self.product)
self.target = transaction._intern_string(self.target)
self.operation = transaction._intern_string(self.operation)
self.host = transaction._intern_string(self.host)
self.port_path_or_id = transaction._intern_string(self.port_path_or_id)
self.database_name = transaction._intern_string(self.database_name)

datastore_tracer_settings = transaction.settings.datastore_tracer
self.instance_reporting_enabled = datastore_tracer_settings.instance_reporting.enabled
Expand All @@ -92,7 +95,14 @@ def __repr__(self):
return "<%s object at 0x%x %s>" % (
self.__class__.__name__,
id(self),
dict(product=self.product, target=self.target, operation=self.operation),
dict(
product=self.product,
target=self.target,
operation=self.operation,
host=self.host,
port_path_or_id=self.port_path_or_id,
database_name=self.database_name,
),
)

def finalize_data(self, transaction, exc=None, value=None, tb=None):
Expand Down Expand Up @@ -125,7 +135,7 @@ def create_node(self):
)


def DatastoreTraceWrapper(wrapped, product, target, operation, async_wrapper=None):
def DatastoreTraceWrapper(wrapped, product, target, operation, host=None, port_path_or_id=None, database_name=None, async_wrapper=None):
"""Wraps a method to time datastore queries.
:param wrapped: The function to apply the trace to.
Expand All @@ -140,9 +150,17 @@ def DatastoreTraceWrapper(wrapped, product, target, operation, async_wrapper=Non
or the name of any API function/method in the client
library.
:type operation: str or callable
:rtype: :class:`newrelic.common.object_wrapper.FunctionWrapper`
:param host: The name of the server hosting the actual datastore.
:type host: str
:param port_path_or_id: The value passed in can represent either the port,
path, or id of the datastore being connected to.
:type port_path_or_id: str
:param database_name: The name of database where the current query is being
executed.
:type database_name: str
:param async_wrapper: An async trace wrapper from newrelic.common.async_wrapper.
:type async_wrapper: callable or None
:rtype: :class:`newrelic.common.object_wrapper.FunctionWrapper`
This is typically used to wrap datastore queries such as calls to Redis or
ElasticSearch.
Expand Down Expand Up @@ -189,7 +207,33 @@ def _nr_datastore_trace_wrapper_(wrapped, instance, args, kwargs):
else:
_operation = operation

trace = DatastoreTrace(_product, _target, _operation, parent=parent, source=wrapped)
if callable(host):
if instance is not None:
_host = host(instance, *args, **kwargs)
else:
_host = host(*args, **kwargs)
else:
_host = host

if callable(port_path_or_id):
if instance is not None:
_port_path_or_id = port_path_or_id(instance, *args, **kwargs)
else:
_port_path_or_id = port_path_or_id(*args, **kwargs)
else:
_port_path_or_id = port_path_or_id

if callable(database_name):
if instance is not None:
_database_name = database_name(instance, *args, **kwargs)
else:
_database_name = database_name(*args, **kwargs)
else:
_database_name = database_name

trace = DatastoreTrace(
_product, _target, _operation, _host, _port_path_or_id, _database_name, parent=parent, source=wrapped
)

if wrapper: # pylint: disable=W0125,W0126
return wrapper(wrapped, trace)(*args, **kwargs)
Expand All @@ -200,7 +244,7 @@ def _nr_datastore_trace_wrapper_(wrapped, instance, args, kwargs):
return FunctionWrapper(wrapped, _nr_datastore_trace_wrapper_)


def datastore_trace(product, target, operation, async_wrapper=None):
def datastore_trace(product, target, operation, host=None, port_path_or_id=None, database_name=None, async_wrapper=None):
"""Decorator allows datastore query to be timed.
:param product: The name of the vendor.
Expand All @@ -213,6 +257,14 @@ def datastore_trace(product, target, operation, async_wrapper=None):
or the name of any API function/method in the client
library.
:type operation: str
:param host: The name of the server hosting the actual datastore.
:type host: str
:param port_path_or_id: The value passed in can represent either the port,
path, or id of the datastore being connected to.
:type port_path_or_id: str
:param database_name: The name of database where the current query is being
executed.
:type database_name: str
:param async_wrapper: An async trace wrapper from newrelic.common.async_wrapper.
:type async_wrapper: callable or None
Expand All @@ -228,10 +280,21 @@ def datastore_trace(product, target, operation, async_wrapper=None):
... time.sleep(*args, **kwargs)
"""
return functools.partial(DatastoreTraceWrapper, product=product, target=target, operation=operation, async_wrapper=async_wrapper)


def wrap_datastore_trace(module, object_path, product, target, operation, async_wrapper=None):
return functools.partial(
DatastoreTraceWrapper,
product=product,
target=target,
operation=operation,
host=host,
port_path_or_id=port_path_or_id,
database_name=database_name,
async_wrapper=None,
)


def wrap_datastore_trace(
module, object_path, product, target, operation, host=None, port_path_or_id=None, database_name=None
):
"""Method applies custom timing to datastore query.
:param module: Module containing the method to be instrumented.
Expand All @@ -248,6 +311,14 @@ def wrap_datastore_trace(module, object_path, product, target, operation, async_
or the name of any API function/method in the client
library.
:type operation: str
:param host: The name of the server hosting the actual datastore.
:type host: str
:param port_path_or_id: The value passed in can represent either the port,
path, or id of the datastore being connected to.
:type port_path_or_id: str
:param database_name: The name of database where the current query is being
executed.
:type database_name: str
:param async_wrapper: An async trace wrapper from newrelic.common.async_wrapper.
:type async_wrapper: callable or None
Expand All @@ -262,4 +333,6 @@ def wrap_datastore_trace(module, object_path, product, target, operation, async_
... 'sleep')
"""
wrap_object(module, object_path, DatastoreTraceWrapper, (product, target, operation, async_wrapper))
wrap_object(
module, object_path, DatastoreTraceWrapper, (product, target, operation, host, port_path_or_id, database_name, async_wrapper)
)
42 changes: 41 additions & 1 deletion newrelic/common/async_wrapper.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@
is_coroutine_callable,
is_asyncio_coroutine,
is_generator_function,
is_async_generator_function,
)


Expand All @@ -29,7 +30,6 @@ def evaluate_wrapper(wrapper_string, wrapped, trace):


def coroutine_wrapper(wrapped, trace):

WRAPPER = textwrap.dedent("""
@functools.wraps(wrapped)
async def wrapper(*args, **kwargs):
Expand Down Expand Up @@ -81,9 +81,49 @@ def wrapper(*args, **kwargs):
return wrapper


def async_generator_wrapper(wrapped, trace):
WRAPPER = textwrap.dedent("""
@functools.wraps(wrapped)
async def wrapper(*args, **kwargs):
g = wrapped(*args, **kwargs)
value = None
with trace:
while True:
try:
yielded = await g.asend(value)
except StopAsyncIteration as e:
# The underlying async generator has finished, return propagates a new StopAsyncIteration
return
except StopIteration as e:
# The call to async_generator_asend.send() should raise a StopIteration containing the yielded value
yielded = e.value
try:
value = yield yielded
except BaseException as e:
# An exception was thrown with .athrow(), propagate to the original async generator.
# Return value logic must be identical to .asend()
try:
value = yield await g.athrow(type(e), e)
except StopAsyncIteration as e:
# The underlying async generator has finished, return propagates a new StopAsyncIteration
return
except StopIteration as e:
# The call to async_generator_athrow.send() should raise a StopIteration containing a yielded value
value = yield e.value
""")

try:
return evaluate_wrapper(WRAPPER, wrapped, trace)
except:
return wrapped


def async_wrapper(wrapped):
if is_coroutine_callable(wrapped):
return coroutine_wrapper
elif is_async_generator_function(wrapped):
return async_generator_wrapper
elif is_generator_function(wrapped):
if is_asyncio_coroutine(wrapped):
return awaitable_generator_wrapper
Expand Down
8 changes: 8 additions & 0 deletions newrelic/common/coroutine.py
Original file line number Diff line number Diff line change
Expand Up @@ -43,3 +43,11 @@ def _iscoroutinefunction_tornado(fn):

def is_coroutine_callable(wrapped):
return is_coroutine_function(wrapped) or is_coroutine_function(getattr(wrapped, "__call__", None))


if hasattr(inspect, 'isasyncgenfunction'):
def is_async_generator_function(wrapped):
return inspect.isasyncgenfunction(wrapped)
else:
def is_async_generator_function(wrapped):
return False
5 changes: 5 additions & 0 deletions tests/agent_features/_test_async_coroutine_trace.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,7 @@
from newrelic.api.datastore_trace import datastore_trace
from newrelic.api.external_trace import external_trace
from newrelic.api.function_trace import function_trace
from newrelic.api.graphql_trace import graphql_operation_trace, graphql_resolver_trace
from newrelic.api.memcache_trace import memcache_trace
from newrelic.api.message_trace import message_trace

Expand All @@ -41,6 +42,8 @@
(functools.partial(datastore_trace, "lib", "foo", "bar"), "Datastore/statement/lib/foo/bar"),
(functools.partial(message_trace, "lib", "op", "typ", "name"), "MessageBroker/lib/typ/op/Named/name"),
(functools.partial(memcache_trace, "cmd"), "Memcache/cmd"),
(functools.partial(graphql_operation_trace), "GraphQL/operation/GraphQL/<unknown>/<anonymous>/<unknown>"),
(functools.partial(graphql_resolver_trace), "GraphQL/resolve/GraphQL/<unknown>"),
],
)
def test_awaitable_timing(event_loop, trace, metric):
Expand Down Expand Up @@ -79,6 +82,8 @@ def _test():
(functools.partial(datastore_trace, "lib", "foo", "bar"), "Datastore/statement/lib/foo/bar"),
(functools.partial(message_trace, "lib", "op", "typ", "name"), "MessageBroker/lib/typ/op/Named/name"),
(functools.partial(memcache_trace, "cmd"), "Memcache/cmd"),
(functools.partial(graphql_operation_trace), "GraphQL/operation/GraphQL/<unknown>/<anonymous>/<unknown>"),
(functools.partial(graphql_resolver_trace), "GraphQL/resolve/GraphQL/<unknown>"),
],
)
@pytest.mark.parametrize("yield_from", [True, False])
Expand Down
Loading

0 comments on commit 479f9e2

Please sign in to comment.