Skip to content

Commit

Permalink
fix: assistant_post_input and assistant_query_input decorator params …
Browse files Browse the repository at this point in the history
…updated to configure custom table storage (#244)

* update assistant decortators for table storage optional params

* update functionapp and tests

* remove underscores

* fix casing

* correct typo

* resolve linting issues

* fix linting issues in two more files

* update default collection name to ChatState
  • Loading branch information
manvkaur authored Oct 7, 2024
1 parent 6356f6f commit a4c07f6
Show file tree
Hide file tree
Showing 3 changed files with 38 additions and 4 deletions.
26 changes: 22 additions & 4 deletions azure/functions/decorators/function_app.py
Original file line number Diff line number Diff line change
Expand Up @@ -80,8 +80,8 @@ def __str__(self):
return self.get_function_json()

def __call__(self, *args, **kwargs):
"""This would allow the Function object to be directly callable and runnable
directly using the interpreter locally.
"""This would allow the Function object to be directly callable
and runnable directly using the interpreter locally.
Example:
@app.route(route="http_trigger")
Expand Down Expand Up @@ -332,8 +332,8 @@ def decorator():
return wrap

def _get_durable_blueprint(self):
"""Attempt to import the Durable Functions SDK from which DF decorators are
implemented.
"""Attempt to import the Durable Functions SDK from which DF
decorators are implemented.
"""

try:
Expand Down Expand Up @@ -3266,6 +3266,8 @@ def assistant_query_input(self,
arg_name: str,
id: str,
timestamp_utc: str,
chat_storage_connection_setting: Optional[str] = "AzureWebJobsStorage", # noqa: E501
collection_name: Optional[str] = "ChatState", # noqa: E501
data_type: Optional[
Union[DataType, str]] = None,
**kwargs) \
Expand All @@ -3278,6 +3280,11 @@ def assistant_query_input(self,
:param timestamp_utc: the timestamp of the earliest message in the chat
history to fetch. The timestamp should be in ISO 8601 format - for
example, 2023-08-01T00:00:00Z.
:param chat_storage_connection_setting: The configuration section name
for the table settings for assistant chat storage. The default value is
"AzureWebJobsStorage".
:param collection_name: The table collection name for assistant chat
storage. The default value is "ChatState".
:param id: The ID of the Assistant to query.
:param data_type: Defines how Functions runtime should treat the
parameter value
Expand All @@ -3295,6 +3302,8 @@ def decorator():
name=arg_name,
id=id,
timestamp_utc=timestamp_utc,
chat_storage_connection_setting=chat_storage_connection_setting, # noqa: E501
collection_name=collection_name,
data_type=parse_singular_param_to_enum(data_type,
DataType),
**kwargs))
Expand All @@ -3308,6 +3317,8 @@ def assistant_post_input(self, arg_name: str,
id: str,
user_message: str,
model: Optional[str] = None,
chat_storage_connection_setting: Optional[str] = "AzureWebJobsStorage", # noqa: E501
collection_name: Optional[str] = "ChatState", # noqa: E501
data_type: Optional[
Union[DataType, str]] = None,
**kwargs) \
Expand All @@ -3321,6 +3332,11 @@ def assistant_post_input(self, arg_name: str,
:param user_message: The user message that user has entered for
assistant to respond to.
:param model: The OpenAI chat model to use.
:param chat_storage_connection_setting: The configuration section name
for the table settings for assistant chat storage. The default value is
"AzureWebJobsStorage".
:param collection_name: The table collection name for assistant chat
storage. The default value is "ChatState".
:param data_type: Defines how Functions runtime should treat the
parameter value
:param kwargs: Keyword arguments for specifying additional binding
Expand All @@ -3338,6 +3354,8 @@ def decorator():
id=id,
user_message=user_message,
model=model,
chat_storage_connection_setting=chat_storage_connection_setting, # noqa: E501
collection_name=collection_name,
data_type=parse_singular_param_to_enum(data_type,
DataType),
**kwargs))
Expand Down
8 changes: 8 additions & 0 deletions azure/functions/decorators/openai.py
Original file line number Diff line number Diff line change
Expand Up @@ -77,10 +77,14 @@ def __init__(self,
name: str,
id: str,
timestamp_utc: str,
chat_storage_connection_setting: Optional[str] = "AzureWebJobsStorage", # noqa: E501
collection_name: Optional[str] = "ChatState",
data_type: Optional[DataType] = None,
**kwargs):
self.id = id
self.timestamp_utc = timestamp_utc
self.chat_storage_connection_setting = chat_storage_connection_setting
self.collection_name = collection_name
super().__init__(name=name, data_type=data_type)


Expand Down Expand Up @@ -165,12 +169,16 @@ def __init__(self, name: str,
id: str,
user_message: str,
model: Optional[str] = None,
chat_storage_connection_setting: Optional[str] = "AzureWebJobsStorage", # noqa: E501
collection_name: Optional[str] = "ChatState",
data_type: Optional[DataType] = None,
**kwargs):
self.name = name
self.id = id
self.user_message = user_message
self.model = model
self.chat_storage_connection_setting = chat_storage_connection_setting
self.collection_name = collection_name
super().__init__(name=name, data_type=data_type)


Expand Down
8 changes: 8 additions & 0 deletions tests/decorators/test_openai.py
Original file line number Diff line number Diff line change
Expand Up @@ -57,6 +57,8 @@ def test_text_completion_input_valid_creation(self):
def test_assistant_query_input_valid_creation(self):
input = AssistantQueryInput(name="test",
timestamp_utc="timestamp_utc",
chat_storage_connection_setting="AzureWebJobsStorage", # noqa: E501
collection_name="ChatState",
data_type=DataType.UNDEFINED,
id="test_id",
type="assistantQueryInput",
Expand All @@ -66,6 +68,8 @@ def test_assistant_query_input_valid_creation(self):
self.assertEqual(input.get_dict_repr(),
{"name": "test",
"timestampUtc": "timestamp_utc",
"chatStorageConnectionSetting": "AzureWebJobsStorage", # noqa: E501
"collectionName": "ChatState",
"dataType": DataType.UNDEFINED,
"direction": BindingDirection.IN,
"type": "assistantQuery",
Expand Down Expand Up @@ -111,6 +115,8 @@ def test_assistant_post_input_valid_creation(self):
input = AssistantPostInput(name="test",
id="test_id",
model="test_model",
chat_storage_connection_setting="AzureWebJobsStorage", # noqa: E501
collection_name="ChatState",
user_message="test_message",
data_type=DataType.UNDEFINED,
dummy_field="dummy")
Expand All @@ -120,6 +126,8 @@ def test_assistant_post_input_valid_creation(self):
{"name": "test",
"id": "test_id",
"model": "test_model",
"chatStorageConnectionSetting": "AzureWebJobsStorage", # noqa: E501
"collectionName": "ChatState",
"userMessage": "test_message",
"dataType": DataType.UNDEFINED,
"direction": BindingDirection.IN,
Expand Down

0 comments on commit a4c07f6

Please sign in to comment.