From a4c07f6a9c24bb0c31bed6e42eec207539e6accd Mon Sep 17 00:00:00 2001 From: Manvir Kaur <67894494+manvkaur@users.noreply.github.com> Date: Mon, 7 Oct 2024 18:34:12 +0100 Subject: [PATCH] fix: assistant_post_input and assistant_query_input decorator params updated to configure custom table storage (#244) * update assistant decortators for table storage optional params * update functionapp and tests * remove underscores * fix casing * correct typo * resolve linting issues * fix linting issues in two more files * update default collection name to ChatState --- azure/functions/decorators/function_app.py | 26 ++++++++++++++++++---- azure/functions/decorators/openai.py | 8 +++++++ tests/decorators/test_openai.py | 8 +++++++ 3 files changed, 38 insertions(+), 4 deletions(-) diff --git a/azure/functions/decorators/function_app.py b/azure/functions/decorators/function_app.py index 773bf5da..1378b2cb 100644 --- a/azure/functions/decorators/function_app.py +++ b/azure/functions/decorators/function_app.py @@ -80,8 +80,8 @@ def __str__(self): return self.get_function_json() def __call__(self, *args, **kwargs): - """This would allow the Function object to be directly callable and runnable - directly using the interpreter locally. + """This would allow the Function object to be directly callable + and runnable directly using the interpreter locally. Example: @app.route(route="http_trigger") @@ -332,8 +332,8 @@ def decorator(): return wrap def _get_durable_blueprint(self): - """Attempt to import the Durable Functions SDK from which DF decorators are - implemented. + """Attempt to import the Durable Functions SDK from which DF + decorators are implemented. """ try: @@ -3266,6 +3266,8 @@ def assistant_query_input(self, arg_name: str, id: str, timestamp_utc: str, + chat_storage_connection_setting: Optional[str] = "AzureWebJobsStorage", # noqa: E501 + collection_name: Optional[str] = "ChatState", # noqa: E501 data_type: Optional[ Union[DataType, str]] = None, **kwargs) \ @@ -3278,6 +3280,11 @@ def assistant_query_input(self, :param timestamp_utc: the timestamp of the earliest message in the chat history to fetch. The timestamp should be in ISO 8601 format - for example, 2023-08-01T00:00:00Z. + :param chat_storage_connection_setting: The configuration section name + for the table settings for assistant chat storage. The default value is + "AzureWebJobsStorage". + :param collection_name: The table collection name for assistant chat + storage. The default value is "ChatState". :param id: The ID of the Assistant to query. :param data_type: Defines how Functions runtime should treat the parameter value @@ -3295,6 +3302,8 @@ def decorator(): name=arg_name, id=id, timestamp_utc=timestamp_utc, + chat_storage_connection_setting=chat_storage_connection_setting, # noqa: E501 + collection_name=collection_name, data_type=parse_singular_param_to_enum(data_type, DataType), **kwargs)) @@ -3308,6 +3317,8 @@ def assistant_post_input(self, arg_name: str, id: str, user_message: str, model: Optional[str] = None, + chat_storage_connection_setting: Optional[str] = "AzureWebJobsStorage", # noqa: E501 + collection_name: Optional[str] = "ChatState", # noqa: E501 data_type: Optional[ Union[DataType, str]] = None, **kwargs) \ @@ -3321,6 +3332,11 @@ def assistant_post_input(self, arg_name: str, :param user_message: The user message that user has entered for assistant to respond to. :param model: The OpenAI chat model to use. + :param chat_storage_connection_setting: The configuration section name + for the table settings for assistant chat storage. The default value is + "AzureWebJobsStorage". + :param collection_name: The table collection name for assistant chat + storage. The default value is "ChatState". :param data_type: Defines how Functions runtime should treat the parameter value :param kwargs: Keyword arguments for specifying additional binding @@ -3338,6 +3354,8 @@ def decorator(): id=id, user_message=user_message, model=model, + chat_storage_connection_setting=chat_storage_connection_setting, # noqa: E501 + collection_name=collection_name, data_type=parse_singular_param_to_enum(data_type, DataType), **kwargs)) diff --git a/azure/functions/decorators/openai.py b/azure/functions/decorators/openai.py index df459c1c..2563a78e 100644 --- a/azure/functions/decorators/openai.py +++ b/azure/functions/decorators/openai.py @@ -77,10 +77,14 @@ def __init__(self, name: str, id: str, timestamp_utc: str, + chat_storage_connection_setting: Optional[str] = "AzureWebJobsStorage", # noqa: E501 + collection_name: Optional[str] = "ChatState", data_type: Optional[DataType] = None, **kwargs): self.id = id self.timestamp_utc = timestamp_utc + self.chat_storage_connection_setting = chat_storage_connection_setting + self.collection_name = collection_name super().__init__(name=name, data_type=data_type) @@ -165,12 +169,16 @@ def __init__(self, name: str, id: str, user_message: str, model: Optional[str] = None, + chat_storage_connection_setting: Optional[str] = "AzureWebJobsStorage", # noqa: E501 + collection_name: Optional[str] = "ChatState", data_type: Optional[DataType] = None, **kwargs): self.name = name self.id = id self.user_message = user_message self.model = model + self.chat_storage_connection_setting = chat_storage_connection_setting + self.collection_name = collection_name super().__init__(name=name, data_type=data_type) diff --git a/tests/decorators/test_openai.py b/tests/decorators/test_openai.py index f2ebdaca..c2009c72 100644 --- a/tests/decorators/test_openai.py +++ b/tests/decorators/test_openai.py @@ -57,6 +57,8 @@ def test_text_completion_input_valid_creation(self): def test_assistant_query_input_valid_creation(self): input = AssistantQueryInput(name="test", timestamp_utc="timestamp_utc", + chat_storage_connection_setting="AzureWebJobsStorage", # noqa: E501 + collection_name="ChatState", data_type=DataType.UNDEFINED, id="test_id", type="assistantQueryInput", @@ -66,6 +68,8 @@ def test_assistant_query_input_valid_creation(self): self.assertEqual(input.get_dict_repr(), {"name": "test", "timestampUtc": "timestamp_utc", + "chatStorageConnectionSetting": "AzureWebJobsStorage", # noqa: E501 + "collectionName": "ChatState", "dataType": DataType.UNDEFINED, "direction": BindingDirection.IN, "type": "assistantQuery", @@ -111,6 +115,8 @@ def test_assistant_post_input_valid_creation(self): input = AssistantPostInput(name="test", id="test_id", model="test_model", + chat_storage_connection_setting="AzureWebJobsStorage", # noqa: E501 + collection_name="ChatState", user_message="test_message", data_type=DataType.UNDEFINED, dummy_field="dummy") @@ -120,6 +126,8 @@ def test_assistant_post_input_valid_creation(self): {"name": "test", "id": "test_id", "model": "test_model", + "chatStorageConnectionSetting": "AzureWebJobsStorage", # noqa: E501 + "collectionName": "ChatState", "userMessage": "test_message", "dataType": DataType.UNDEFINED, "direction": BindingDirection.IN,