diff --git a/azure/functions/__init__.py b/azure/functions/__init__.py index 86565e37..053d788a 100644 --- a/azure/functions/__init__.py +++ b/azure/functions/__init__.py @@ -23,6 +23,7 @@ from ._queue import QueueMessage from ._servicebus import ServiceBusMessage from ._sql import SqlRow, SqlRowList +from ._mysql import MySqlRow, MySqlRowList # Import binding implementations to register them from . import blob # NoQA @@ -37,6 +38,7 @@ from . import durable_functions # NoQA from . import sql # NoQA from . import warmup # NoQA +from . import mysql # NoQA __all__ = ( @@ -67,6 +69,8 @@ 'SqlRowList', 'TimerRequest', 'WarmUpContext', + 'MySqlRow', + 'MySqlRowList', # Middlewares 'WsgiMiddleware', @@ -98,4 +102,4 @@ 'BlobSource' ) -__version__ = '1.21.0' +__version__ = '1.22.0b2' diff --git a/azure/functions/_mysql.py b/azure/functions/_mysql.py new file mode 100644 index 00000000..9c7515d9 --- /dev/null +++ b/azure/functions/_mysql.py @@ -0,0 +1,71 @@ +# Copyright (c) Microsoft Corporation. All rights reserved. +# Licensed under the MIT License. +import abc +import collections +import json + + +class BaseMySqlRow(abc.ABC): + + @classmethod + @abc.abstractmethod + def from_json(cls, json_data: str) -> 'BaseMySqlRow': + raise NotImplementedError + + @classmethod + @abc.abstractmethod + def from_dict(cls, dct: dict) -> 'BaseMySqlRow': + raise NotImplementedError + + @abc.abstractmethod + def __getitem__(self, key): + raise NotImplementedError + + @abc.abstractmethod + def __setitem__(self, key, value): + raise NotImplementedError + + @abc.abstractmethod + def to_json(self) -> str: + raise NotImplementedError + + +class BaseMySqlRowList(abc.ABC): + pass + + +class MySqlRow(BaseMySqlRow, collections.UserDict): + """A MySql Row. + + MySqlRow objects are ''UserDict'' subclasses and behave like dicts. + """ + + @classmethod + def from_json(cls, json_data: str) -> 'BaseMySqlRow': + """Create a MySqlRow from a JSON string.""" + return cls.from_dict(json.loads(json_data)) + + @classmethod + def from_dict(cls, dct: dict) -> 'BaseMySqlRow': + """Create a MySqlRow from a dict object""" + return cls({k: v for k, v in dct.items()}) + + def to_json(self) -> str: + """Return the JSON representation of the MySqlRow""" + return json.dumps(dict(self)) + + def __getitem__(self, key): + return collections.UserDict.__getitem__(self, key) + + def __setitem__(self, key, value): + return collections.UserDict.__setitem__(self, key, value) + + def __repr__(self) -> str: + return ( + f'' + ) + + +class MySqlRowList(BaseMySqlRowList, collections.UserList): + "A ''UserList'' subclass containing a list of :class:'~MySqlRow' objects" + pass diff --git a/azure/functions/decorators/blob.py b/azure/functions/decorators/blob.py index bd2861fa..eaad9220 100644 --- a/azure/functions/decorators/blob.py +++ b/azure/functions/decorators/blob.py @@ -17,7 +17,10 @@ def __init__(self, **kwargs): self.path = path self.connection = connection - self.source = source.value if source else None + if type(source) is BlobSource: + self.source = source.value if source else None + else: + self.source = source super().__init__(name=name, data_type=data_type) @staticmethod diff --git a/azure/functions/decorators/function_app.py b/azure/functions/decorators/function_app.py index 773bf5da..1378b2cb 100644 --- a/azure/functions/decorators/function_app.py +++ b/azure/functions/decorators/function_app.py @@ -80,8 +80,8 @@ def __str__(self): return self.get_function_json() def __call__(self, *args, **kwargs): - """This would allow the Function object to be directly callable and runnable - directly using the interpreter locally. + """This would allow the Function object to be directly callable + and runnable directly using the interpreter locally. Example: @app.route(route="http_trigger") @@ -332,8 +332,8 @@ def decorator(): return wrap def _get_durable_blueprint(self): - """Attempt to import the Durable Functions SDK from which DF decorators are - implemented. + """Attempt to import the Durable Functions SDK from which DF + decorators are implemented. """ try: @@ -3266,6 +3266,8 @@ def assistant_query_input(self, arg_name: str, id: str, timestamp_utc: str, + chat_storage_connection_setting: Optional[str] = "AzureWebJobsStorage", # noqa: E501 + collection_name: Optional[str] = "ChatState", # noqa: E501 data_type: Optional[ Union[DataType, str]] = None, **kwargs) \ @@ -3278,6 +3280,11 @@ def assistant_query_input(self, :param timestamp_utc: the timestamp of the earliest message in the chat history to fetch. The timestamp should be in ISO 8601 format - for example, 2023-08-01T00:00:00Z. + :param chat_storage_connection_setting: The configuration section name + for the table settings for assistant chat storage. The default value is + "AzureWebJobsStorage". + :param collection_name: The table collection name for assistant chat + storage. The default value is "ChatState". :param id: The ID of the Assistant to query. :param data_type: Defines how Functions runtime should treat the parameter value @@ -3295,6 +3302,8 @@ def decorator(): name=arg_name, id=id, timestamp_utc=timestamp_utc, + chat_storage_connection_setting=chat_storage_connection_setting, # noqa: E501 + collection_name=collection_name, data_type=parse_singular_param_to_enum(data_type, DataType), **kwargs)) @@ -3308,6 +3317,8 @@ def assistant_post_input(self, arg_name: str, id: str, user_message: str, model: Optional[str] = None, + chat_storage_connection_setting: Optional[str] = "AzureWebJobsStorage", # noqa: E501 + collection_name: Optional[str] = "ChatState", # noqa: E501 data_type: Optional[ Union[DataType, str]] = None, **kwargs) \ @@ -3321,6 +3332,11 @@ def assistant_post_input(self, arg_name: str, :param user_message: The user message that user has entered for assistant to respond to. :param model: The OpenAI chat model to use. + :param chat_storage_connection_setting: The configuration section name + for the table settings for assistant chat storage. The default value is + "AzureWebJobsStorage". + :param collection_name: The table collection name for assistant chat + storage. The default value is "ChatState". :param data_type: Defines how Functions runtime should treat the parameter value :param kwargs: Keyword arguments for specifying additional binding @@ -3338,6 +3354,8 @@ def decorator(): id=id, user_message=user_message, model=model, + chat_storage_connection_setting=chat_storage_connection_setting, # noqa: E501 + collection_name=collection_name, data_type=parse_singular_param_to_enum(data_type, DataType), **kwargs)) diff --git a/azure/functions/decorators/openai.py b/azure/functions/decorators/openai.py index df459c1c..2563a78e 100644 --- a/azure/functions/decorators/openai.py +++ b/azure/functions/decorators/openai.py @@ -77,10 +77,14 @@ def __init__(self, name: str, id: str, timestamp_utc: str, + chat_storage_connection_setting: Optional[str] = "AzureWebJobsStorage", # noqa: E501 + collection_name: Optional[str] = "ChatState", data_type: Optional[DataType] = None, **kwargs): self.id = id self.timestamp_utc = timestamp_utc + self.chat_storage_connection_setting = chat_storage_connection_setting + self.collection_name = collection_name super().__init__(name=name, data_type=data_type) @@ -165,12 +169,16 @@ def __init__(self, name: str, id: str, user_message: str, model: Optional[str] = None, + chat_storage_connection_setting: Optional[str] = "AzureWebJobsStorage", # noqa: E501 + collection_name: Optional[str] = "ChatState", data_type: Optional[DataType] = None, **kwargs): self.name = name self.id = id self.user_message = user_message self.model = model + self.chat_storage_connection_setting = chat_storage_connection_setting + self.collection_name = collection_name super().__init__(name=name, data_type=data_type) diff --git a/azure/functions/mysql.py b/azure/functions/mysql.py new file mode 100644 index 00000000..06a04a56 --- /dev/null +++ b/azure/functions/mysql.py @@ -0,0 +1,78 @@ +# Copyright (c) Microsoft Corporation. All rights reserved. +# Licensed under the MIT License. + +import collections.abc +import json +import typing + +from azure.functions import _mysql as mysql + +from . import meta + + +class MySqlConverter(meta.InConverter, meta.OutConverter, + binding='mysql'): + + @classmethod + def check_input_type_annotation(cls, pytype: type) -> bool: + return issubclass(pytype, mysql.BaseMySqlRowList) + + @classmethod + def check_output_type_annotation(cls, pytype: type) -> bool: + return issubclass(pytype, (mysql.BaseMySqlRowList, mysql.BaseMySqlRow)) + + @classmethod + def decode(cls, + data: meta.Datum, + *, + trigger_metadata) -> typing.Optional[mysql.MySqlRowList]: + if data is None or data.type is None: + return None + + data_type = data.type + + if data_type in ['string', 'json']: + body = data.value + + elif data_type == 'bytes': + body = data.value.decode('utf-8') + + else: + raise NotImplementedError( + f'Unsupported payload type: {data_type}') + + rows = json.loads(body) + if not isinstance(rows, list): + rows = [rows] + + return mysql.MySqlRowList( + (None if row is None else mysql.MySqlRow.from_dict(row)) + for row in rows) + + @classmethod + def encode(cls, obj: typing.Any, *, + expected_type: typing.Optional[type]) -> meta.Datum: + if isinstance(obj, mysql.MySqlRow): + data = mysql.MySqlRowList([obj]) + + elif isinstance(obj, mysql.MySqlRowList): + data = obj + + elif isinstance(obj, collections.abc.Iterable): + data = mysql.MySqlRowList() + + for row in obj: + if not isinstance(row, mysql.MySqlRow): + raise NotImplementedError( + f'Unsupported list type: {type(obj)}, \ + lists must contain MySqlRow objects') + else: + data.append(row) + + else: + raise NotImplementedError(f'Unsupported type: {type(obj)}') + + return meta.Datum( + type='json', + value=json.dumps([dict(d) for d in data]) + ) diff --git a/tests/decorators/test_openai.py b/tests/decorators/test_openai.py index f2ebdaca..c2009c72 100644 --- a/tests/decorators/test_openai.py +++ b/tests/decorators/test_openai.py @@ -57,6 +57,8 @@ def test_text_completion_input_valid_creation(self): def test_assistant_query_input_valid_creation(self): input = AssistantQueryInput(name="test", timestamp_utc="timestamp_utc", + chat_storage_connection_setting="AzureWebJobsStorage", # noqa: E501 + collection_name="ChatState", data_type=DataType.UNDEFINED, id="test_id", type="assistantQueryInput", @@ -66,6 +68,8 @@ def test_assistant_query_input_valid_creation(self): self.assertEqual(input.get_dict_repr(), {"name": "test", "timestampUtc": "timestamp_utc", + "chatStorageConnectionSetting": "AzureWebJobsStorage", # noqa: E501 + "collectionName": "ChatState", "dataType": DataType.UNDEFINED, "direction": BindingDirection.IN, "type": "assistantQuery", @@ -111,6 +115,8 @@ def test_assistant_post_input_valid_creation(self): input = AssistantPostInput(name="test", id="test_id", model="test_model", + chat_storage_connection_setting="AzureWebJobsStorage", # noqa: E501 + collection_name="ChatState", user_message="test_message", data_type=DataType.UNDEFINED, dummy_field="dummy") @@ -120,6 +126,8 @@ def test_assistant_post_input_valid_creation(self): {"name": "test", "id": "test_id", "model": "test_model", + "chatStorageConnectionSetting": "AzureWebJobsStorage", # noqa: E501 + "collectionName": "ChatState", "userMessage": "test_message", "dataType": DataType.UNDEFINED, "direction": BindingDirection.IN, diff --git a/tests/test_mysql.py b/tests/test_mysql.py new file mode 100644 index 00000000..514c066a --- /dev/null +++ b/tests/test_mysql.py @@ -0,0 +1,293 @@ +# Copyright (c) Microsoft Corporation. All rights reserved. +# Licensed under the MIT License. + +import json +import unittest + +import azure.functions as func +import azure.functions.mysql as mysql +from azure.functions.meta import Datum + + +class TestMySql(unittest.TestCase): + def test_mysql_decode_none(self): + result: func.MySqlRowList = mysql.MySqlConverter.decode( + data=None, trigger_metadata=None) + self.assertIsNone(result) + + def test_mysql_decode_string(self): + datum: Datum = Datum(""" + { + "id": "1", + "name": "test" + } + """, "string") + result: func.MySqlRowList = mysql.MySqlConverter.decode( + data=datum, trigger_metadata=None) + self.assertIsNotNone(result, + 'MySqlRowList should be non-None') + self.assertEqual(len(result), + 1, + 'MySqlRowList should have exactly 1 item') + self.assertEqual(result[0]['id'], + '1', + 'MySqlRow item should have id 1') + self.assertEqual(result[0]['name'], + 'test', + 'MySqlRow item should have name test') + + def test_mysql_decode_bytes(self): + datum: Datum = Datum(""" + { + "id": "1", + "name": "test" + } + """.encode(), "bytes") + result: func.MySqlRowList = mysql.MySqlConverter.decode( + data=datum, trigger_metadata=None) + self.assertIsNotNone(result, + 'MySqlRowList should be non-None') + self.assertEqual(len(result), + 1, + 'MySqlRowList should have exactly 1 item') + self.assertEqual(result[0]['id'], + '1', + 'MySqlRow item should have id 1') + self.assertEqual(result[0]['name'], + 'test', + 'MySqlRow item should have name test') + + def test_mysql_decode_json(self): + datum: Datum = Datum(""" + { + "id": "1", + "name": "test" + } + """, "json") + result: func.MySqlRowList = mysql.MySqlConverter.decode( + data=datum, trigger_metadata=None) + self.assertIsNotNone(result, + 'MySqlRowList should be non-None') + self.assertEqual(len(result), + 1, + 'MySqlRowList should have exactly 1 item') + self.assertEqual(result[0]['id'], + '1', + 'MySqlRow item should have id 1') + self.assertEqual(result[0]['name'], + 'test', + 'MySqlRow item should have name test') + + def test_mysql_decode_json_name_is_null(self): + datum: Datum = Datum(""" + { + "id": "1", + "name": null + } + """, "json") + result: func.MySqlRowList = mysql.MySqlConverter.decode( + data=datum, trigger_metadata=None) + self.assertIsNotNone(result, + 'MySqlRowList itself should be non-None') + self.assertEqual(len(result), + 1, + 'MySqlRowList should have exactly 1 item') + self.assertEqual(result[0]['name'], + None, + 'Item in MySqlRowList should be None') + + def test_mysql_decode_json_multiple_entries(self): + datum: Datum = Datum(""" + [ + { + "id": "1", + "name": "test1" + }, + { + "id": "2", + "name": "test2" + } + ] + """, "json") + result: func.MySqlRowList = mysql.MySqlConverter.decode( + data=datum, trigger_metadata=None) + self.assertIsNotNone(result) + self.assertEqual(len(result), + 2, + 'MySqlRowList should have exactly 2 items') + self.assertEqual(result[0]['id'], + '1', + 'First MySqlRowList item should have id 1') + self.assertEqual(result[0]['name'], + 'test1', + 'First MySqlRowList item should have name test1') + self.assertEqual(result[1]['id'], + '2', + 'First MySqlRowList item should have id 2') + self.assertEqual(result[1]['name'], + 'test2', + 'Second MySqlRowList item should have name test2') + + def test_mysql_decode_json_multiple_nulls(self): + datum: Datum = Datum("[null]", "json") + result: func.MySqlRowList = mysql.MySqlConverter.decode( + data=datum, trigger_metadata=None) + self.assertIsNotNone(result) + self.assertEqual(len(result), + 1, + 'MySqlRowList should have exactly 1 item') + self.assertEqual(result[0], + None, + 'MySqlRow item should be None') + + def test_mysql_encode_mysqlrow(self): + mysqlRow = func.MySqlRow.from_json(""" + { + "id": "1", + "name": "test" + } + """) + datum = mysql.MySqlConverter.encode(obj=mysqlRow, expected_type=None) + self.assertEqual(datum.type, + 'json', + 'Datum type should be JSON') + self.assertEqual(len(datum.python_value), + 1, + 'Encoded value should be list of length 1') + self.assertEqual(datum.python_value[0]['id'], + '1', + 'id should be 1') + self.assertEqual(datum.python_value[0]['name'], + 'test', + 'name should be test') + + def test_mysql_encode_mysqlrowlist(self): + datum: Datum = Datum(""" + { + "id": "1", + "name": "test" + } + """, "json") + mysqlRowList: func.MySqlRowList = mysql.MySqlConverter.decode( + data=datum, trigger_metadata=None) + datum = mysql.MySqlConverter.encode( + obj=mysqlRowList, expected_type=None) + self.assertEqual(datum.type, + 'json', + 'Datum type should be JSON') + self.assertEqual(len(datum.python_value), + 1, + 'Encoded value should be list of length 1') + self.assertEqual(datum.python_value[0]['id'], + '1', + 'id should be 1') + self.assertEqual(datum.python_value[0]['name'], + 'test', + 'name should be test') + + def test_mysql_encode_list_of_mysqlrows(self): + mysqlRows = [ + func.MySqlRow.from_json(""" + { + "id": "1", + "name": "test" + } + """), + func.MySqlRow.from_json(""" + { + "id": "2", + "name": "test2" + } + """) + ] + datum = mysql.MySqlConverter.encode(obj=mysqlRows, expected_type=None) + self.assertEqual(datum.type, + 'json', + 'Datum type should be JSON') + self.assertEqual(len(datum.python_value), + 2, + 'Encoded value should be list of length 2') + self.assertEqual(datum.python_value[0]['id'], + '1', + 'id should be 1') + self.assertEqual(datum.python_value[0]['name'], + 'test', + 'name should be test') + self.assertEqual(datum.python_value[1]['id'], + '2', + 'id should be 2') + self.assertEqual(datum.python_value[1]['name'], + 'test2', + 'name should be test2') + + def test_mysql_encode_list_of_str_raises(self): + strList = [ + """ + { + "id": "1", + "name": "test" + } + """ + ] + self.assertRaises(NotImplementedError, + mysql.MySqlConverter.encode, + obj=strList, + expected_type=None) + + def test_mysql_encode_list_of_mysqlrowlist_raises(self): + datum: Datum = Datum(""" + { + "id": "1", + "name": "test" + } + """, "json") + mysqlRowListList = [ + mysql.MySqlConverter.decode( + data=datum, trigger_metadata=None) + ] + self.assertRaises(NotImplementedError, + mysql.MySqlConverter.encode, + obj=mysqlRowListList, + expected_type=None) + + def test_mysql_input_type(self): + check_input_type = mysql.MySqlConverter.check_input_type_annotation + self.assertTrue(check_input_type(func.MySqlRowList), + 'MySqlRowList should be accepted') + self.assertFalse(check_input_type(func.MySqlRow), + 'MySqlRow should not be accepted') + self.assertFalse(check_input_type(str), + 'str should not be accepted') + + def test_mysql_output_type(self): + check_output_type = mysql.MySqlConverter.check_output_type_annotation + self.assertTrue(check_output_type(func.MySqlRowList), + 'MySqlRowList should be accepted') + self.assertTrue(check_output_type(func.MySqlRow), + 'MySqlRow should be accepted') + self.assertFalse(check_output_type(str), + 'str should not be accepted') + + def test_mysqlrow_json(self): + # Parse MySqlRow from JSON + mysqlRow = func.MySqlRow.from_json(""" + { + "id": "1", + "name": "test" + } + """) + self.assertEqual(mysqlRow['id'], + '1', + 'Parsed MySqlRow id should be 1') + self.assertEqual(mysqlRow['name'], + 'test', + 'Parsed MySqlRow name should be test') + + # Parse JSON from MySqlRow + mysqlRowJson = json.loads(func.MySqlRow.to_json(mysqlRow)) + self.assertEqual(mysqlRowJson['id'], + '1', + 'Parsed JSON id should be 1') + self.assertEqual(mysqlRowJson['name'], + 'test', + 'Parsed JSON name should be test')