Skip to content

Commit

Permalink
fix: added optional context param for tasks (#1523)
Browse files Browse the repository at this point in the history
* added optional context param for tasks

* checks for 3.11 or lower

* test fixes

* lint & skipping tests

* only one check needed

* lint + comments

* better tests

* removed comment

---------

Co-authored-by: Victoria Hall <victoria.hall@microsoft.com>
  • Loading branch information
hallvictoria and Victoria Hall authored Aug 14, 2024
1 parent bbc683e commit f4c9c2d
Show file tree
Hide file tree
Showing 8 changed files with 198 additions and 4 deletions.
14 changes: 11 additions & 3 deletions azure_functions_worker/dispatcher.py
Original file line number Diff line number Diff line change
Expand Up @@ -171,8 +171,11 @@ async def dispatch_forever(self): # sourcery skip: swap-if-expression
start_stream=protos.StartStream(
worker_id=self.worker_id)))

# In Python 3.11+, constructing a task has an optional context
# parameter. Allow for this param to be passed to ContextEnabledTask
self._loop.set_task_factory(
lambda loop, coro: ContextEnabledTask(coro, loop=loop))
lambda loop, coro, context=None: ContextEnabledTask(
coro, loop=loop, context=context))

# Detach console logging before enabling GRPC channel logging
logger.info('Detaching console logging.')
Expand Down Expand Up @@ -1068,8 +1071,13 @@ def emit(self, record: LogRecord) -> None:
class ContextEnabledTask(asyncio.Task):
AZURE_INVOCATION_ID = '__azure_function_invocation_id__'

def __init__(self, coro, loop):
super().__init__(coro, loop=loop)
def __init__(self, coro, loop, context=None):
# The context param is only available for 3.11+. If
# not, it can't be sent in the init() call.
if sys.version_info.minor >= 11:
super().__init__(coro, loop=loop, context=context)
else:
super().__init__(coro, loop=loop)

current_task = asyncio.current_task(loop)
if current_task is not None:
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,15 @@
{
"scriptFile": "main.py",
"bindings": [
{
"type": "httpTrigger",
"direction": "in",
"name": "req"
},
{
"type": "http",
"direction": "out",
"name": "$return"
}
]
}
35 changes: 35 additions & 0 deletions tests/unittests/http_functions/create_task_with_context/main.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,35 @@
# Copyright (c) Microsoft Corporation. All rights reserved.
# Licensed under the MIT License.
import asyncio
import contextvars

import azure.functions

num = contextvars.ContextVar('num')


async def count(name: str):
# The number of times the loop is executed
# depends on the val set in context
val = num.get()
for i in range(val):
await asyncio.sleep(0.5)
return f"Finished {name} in {val}"


async def main(req: azure.functions.HttpRequest):
# Create first task with context num = 5
num.set(5)
first_ctx = contextvars.copy_context()
first_count_task = asyncio.create_task(count("Hello World"), context=first_ctx)

# Create second task with context num = 10
num.set(10)
second_ctx = contextvars.copy_context()
second_count_task = asyncio.create_task(count("Hello World"), context=second_ctx)

# Execute tasks
first_count_val = await first_count_task
second_count_val = await second_count_task

return f'{first_count_val + " | " + second_count_val}'
Original file line number Diff line number Diff line change
@@ -0,0 +1,15 @@
{
"scriptFile": "main.py",
"bindings": [
{
"type": "httpTrigger",
"direction": "in",
"name": "req"
},
{
"type": "http",
"direction": "out",
"name": "$return"
}
]
}
20 changes: 20 additions & 0 deletions tests/unittests/http_functions/create_task_without_context/main.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,20 @@
# Copyright (c) Microsoft Corporation. All rights reserved.
# Licensed under the MIT License.
import asyncio

import azure.functions


async def count(name: str, num: int):
# The number of times the loop executes is decided by a
# user-defined param
for i in range(num):
await asyncio.sleep(0.5)
return f"Finished {name} in {num}"


async def main(req: azure.functions.HttpRequest):
# No context is being sent into asyncio.create_task
count_task = asyncio.create_task(count("Hello World", 5))
count_val = await count_task
return f'{count_val}'
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
# Copyright (c) Microsoft Corporation. All rights reserved.
# Licensed under the MIT License.
import asyncio
import contextvars
import hashlib
import json
import logging
Expand All @@ -14,6 +15,25 @@

logger = logging.getLogger("my-function")

num = contextvars.ContextVar('num')


async def count_with_context(name: str):
# The number of times the loop is executed
# depends on the val set in context
val = num.get()
for i in range(val):
await asyncio.sleep(0.5)
return f"Finished {name} in {val}"


async def count_without_context(name: str, number: int):
# The number of times the loop executes is decided by a
# user-defined param
for i in range(number):
await asyncio.sleep(0.5)
return f"Finished {name} in {number}"


@app.route(route="return_str")
def return_str(req: func.HttpRequest) -> str:
Expand Down Expand Up @@ -404,3 +424,32 @@ def set_cookie_resp_header_empty(
resp.headers.add("Set-Cookie", '')

return resp


@app.route('create_task_with_context')
async def create_task_with_context(req: func.HttpRequest):
# Create first task with context num = 5
num.set(5)
first_ctx = contextvars.copy_context()
first_count_task = asyncio.create_task(
count_with_context("Hello World"), context=first_ctx)

# Create second task with context num = 10
num.set(10)
second_ctx = contextvars.copy_context()
second_count_task = asyncio.create_task(
count_with_context("Hello World"), context=second_ctx)

# Execute tasks
first_count_val = await first_count_task
second_count_val = await second_count_task

return f'{first_count_val + " | " + second_count_val}'


@app.route('create_task_without_context')
async def create_task_without_context(req: func.HttpRequest):
# No context is being sent into asyncio.create_task
count_task = asyncio.create_task(count_without_context("Hello World", 5))
count_val = await count_task
return f'{count_val}'
39 changes: 38 additions & 1 deletion tests/unittests/test_dispatcher.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
# Licensed under the MIT License.
import asyncio
import collections as col
import contextvars
import os
import sys
import unittest
Expand All @@ -21,7 +22,7 @@
PYTHON_THREADPOOL_THREAD_COUNT_MAX_37,
PYTHON_THREADPOOL_THREAD_COUNT_MIN,
)
from azure_functions_worker.dispatcher import Dispatcher
from azure_functions_worker.dispatcher import Dispatcher, ContextEnabledTask
from azure_functions_worker.version import VERSION

SysVersionInfo = col.namedtuple("VersionInfo", ["major", "minor", "micro",
Expand Down Expand Up @@ -989,3 +990,39 @@ def test_dispatcher_indexing_in_load_request_with_exception(
self.assertEqual(
response.function_load_response.result.exception.message,
"Exception: Mocked Exception")


class TestContextEnabledTask(unittest.TestCase):
def setUp(self):
self.loop = asyncio.new_event_loop()
asyncio.set_event_loop(self.loop)

def tearDown(self):
self.loop.close()

def test_init_with_context(self):
# Since ContextEnabledTask accepts the context param,
# no errors will be thrown here
num = contextvars.ContextVar('num')
num.set(5)
ctx = contextvars.copy_context()
exception_raised = False
try:
self.loop.set_task_factory(
lambda loop, coro, context=None: ContextEnabledTask(
coro, loop=loop, context=ctx))
except TypeError:
exception_raised = True
self.assertFalse(exception_raised)

async def test_init_without_context(self):
# If the context param is not defined,
# no errors will be thrown for backwards compatibility
exception_raised = False
try:
self.loop.set_task_factory(
lambda loop, coro: ContextEnabledTask(
coro, loop=loop))
except TypeError:
exception_raised = True
self.assertFalse(exception_raised)
15 changes: 15 additions & 0 deletions tests/unittests/test_http_functions.py
Original file line number Diff line number Diff line change
Expand Up @@ -446,6 +446,21 @@ def check_log_hijack_current_event_loop(self, host_out: typing.List[str]):
# System logs should not exist in host_out
self.assertNotIn('parallelly_log_system at disguised_logger', host_out)

@skipIf(sys.version_info.minor < 11,
"The context param is only available for 3.11+")
def test_create_task_with_context(self):
r = self.webhost.request('GET', 'create_task_with_context')

self.assertEqual(r.status_code, 200)
self.assertEqual(r.text, 'Finished Hello World in 5'
' | Finished Hello World in 10')

def test_create_task_without_context(self):
r = self.webhost.request('GET', 'create_task_without_context')

self.assertEqual(r.status_code, 200)
self.assertEqual(r.text, 'Finished Hello World in 5')


class TestHttpFunctionsStein(TestHttpFunctions):

Expand Down

0 comments on commit f4c9c2d

Please sign in to comment.