Update azure-ai-inference client library to support sending images as part of chat completions #36022

Merged 5 commits on Jun 11, 2024
10 changes: 7 additions & 3 deletions sdk/ai/azure-ai-inference/README.md
@@ -210,7 +210,8 @@ print(response.choices[0].message.content)

<!-- END SNIPPET -->

The following types of messages are supported: `SystemMessage`, `UserMessage`, `AssistantMessage`, `ToolMessage` (see sample [sample_chat_completions_with_tools.py](https://github.com/Azure/azure-sdk-for-python/blob/main/sdk/ai/azure-ai-inference/samples/sample_chat_completions_with_tools.py) for usage of `ToolMessage`).
The following types of messages are supported: `SystemMessage`, `UserMessage`, `AssistantMessage`, `ToolMessage`. See sample [sample_chat_completions_with_tools.py](https://github.com/Azure/azure-sdk-for-python/blob/main/sdk/ai/azure-ai-inference/samples/sample_chat_completions_with_tools.py) for usage of `ToolMessage`. See sample [sample_chat_completions_with_images.py](https://github.com/Azure/azure-sdk-for-python/blob/main/sdk/ai/azure-ai-inference/samples/sample_chat_completions_with_images.py) for usage of `UserMessage` that includes sending an image.
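
To make the new capability concrete, here is a minimal sketch modeled on that sample; the environment variable names and the `sample.jpg` path are placeholders, not part of this PR:

```python
import os

from azure.ai.inference import ChatCompletionsClient
from azure.ai.inference.models import (
    ImageContentItem,
    ImageDetailLevel,
    ImageUrl,
    SystemMessage,
    TextContentItem,
    UserMessage,
)
from azure.core.credentials import AzureKeyCredential

# Placeholder environment variables for the endpoint and key.
client = ChatCompletionsClient(
    endpoint=os.environ["AZURE_AI_CHAT_ENDPOINT"],
    credential=AzureKeyCredential(os.environ["AZURE_AI_CHAT_KEY"]),
)

response = client.complete(
    messages=[
        SystemMessage(content="You are an AI assistant that describes images."),
        # A UserMessage can now hold a list of content items mixing text and images.
        UserMessage(
            content=[
                TextContentItem(text="What does this image show?"),
                ImageContentItem(
                    # Reads a local file into a base64 data URL; "sample.jpg" is a placeholder.
                    image_url=ImageUrl.load(
                        image_file="sample.jpg",
                        image_format="jpeg",
                        detail=ImageDetailLevel.HIGH,
                    )
                ),
            ],
        ),
    ],
)
print(response.choices[0].message.content)
```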

Alternatively, you can provide the messages as dictionaries instead of using the strongly typed classes like `SystemMessage` and `UserMessage`:

@@ -232,7 +233,10 @@ response = client.complete(
"role": "assistant",
"content": "The main construction of the International Space Station (ISS) was completed between 1998 and 2011. During this period, more than 30 flights by US space shuttles and 40 by Russian rockets were conducted to transport components and modules to the station.",
},
{"role": "user", "content": "And what was the estimated cost to build it?"},
{
"role": "user",
"content": "And what was the estimated cost to build it?"
},
]
}
)
@@ -399,7 +403,7 @@ try:
result = client.complete( ... )
except HttpResponseError as e:
print(f"Status code: {e.status_code} ({e.reason})")
print(f"{e.message}")
print(e.message)
```

For example, when you provide a wrong authentication key:
@@ -405,9 +405,6 @@ def _complete(
response. Required.
"model": "str", # The model used for the chat completion. Required.
"usage": {
"capacity_type": "str", # Indicates whether your capacity has been
affected by the usage amount (token count) reported here. Required. Known
values are: "usage" and "fixed".
"completion_tokens": 0, # The number of tokens generated across all
completions emissions. Required.
"prompt_tokens": 0, # The number of tokens in the provided prompts
@@ -678,9 +675,6 @@ def _embed(
"id": "str", # Unique identifier for the embeddings result. Required.
"model": "str", # The model ID used to generate this result. Required.
"usage": {
"capacity_type": "str", # Indicates whether your capacity has been
affected by the usage amount (token count) reported here. Required. Known
values are: "usage" and "fixed".
"input_tokens": 0, # Number of tokens in the request prompt.
Required.
"prompt_tokens": 0, # Number of tokens used for the prompt sent to
@@ -953,9 +947,6 @@ def _embed(
"id": "str", # Unique identifier for the embeddings result. Required.
"model": "str", # The model ID used to generate this result. Required.
"usage": {
"capacity_type": "str", # Indicates whether your capacity has been
affected by the usage amount (token count) reported here. Required. Known
values are: "usage" and "fixed".
"input_tokens": 0, # Number of tokens in the request prompt.
Required.
"prompt_tokens": 0, # Number of tokens used for the prompt sent to
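
The three hunks above remove the `capacity_type` field from the documented `usage` payloads, leaving plain token accounting. A small sketch of what remains, assuming the generated `CompletionsUsage` model accepts keyword construction as these generated models typically do (values are illustrative):

```python
from azure.ai.inference.models import CompletionsUsage

# Illustrative values; in practice this object arrives as response.usage.
usage = CompletionsUsage(prompt_tokens=19, completion_tokens=91, total_tokens=110)
print(f"prompt={usage.prompt_tokens} completion={usage.completion_tokens} total={usage.total_tokens}")
# There is no capacity_type attribute after this change, only token counts.
```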
31 changes: 9 additions & 22 deletions sdk/ai/azure-ai-inference/azure/ai/inference/_patch.py
@@ -95,8 +95,10 @@ def load_client(
:raises ~azure.core.exceptions.HttpResponseError
"""

with ChatCompletionsClient(endpoint, credential, **kwargs) as client: # Pick any of the clients, it does not matter.
model_info = client.get_model_info() # type: ignore
with ChatCompletionsClient(
endpoint, credential, **kwargs
) as client: # Pick any of the clients, it does not matter.
model_info = client.get_model_info() # type: ignore

_LOGGER.info("model_info=%s", model_info)
if not model_info.model_type:
@@ -142,7 +144,6 @@ def __init__(self, endpoint: str, credential: Union[AzureKeyCredential, "TokenCr
self._model_info: Optional[_models.ModelInfo] = None
super().__init__(endpoint, credential, **kwargs)


@overload
def complete(
self,
@@ -164,9 +165,7 @@ def complete(
] = None,
seed: Optional[int] = None,
**kwargs: Any,
) -> _models.ChatCompletions:
...

) -> _models.ChatCompletions: ...

@overload
def complete(
@@ -189,9 +188,7 @@ def complete(
] = None,
seed: Optional[int] = None,
**kwargs: Any,
) -> _models.StreamingChatCompletions:
...

) -> _models.StreamingChatCompletions: ...

@overload
def complete(
@@ -535,7 +532,6 @@ def complete(

return _deserialize(_models._models.ChatCompletions, response.json()) # pylint: disable=protected-access


@distributed_trace
def get_model_info(self, **kwargs: Any) -> _models.ModelInfo:
# pylint: disable=line-too-long
@@ -546,15 +542,13 @@ def get_model_info(self, **kwargs: Any) -> _models.ModelInfo:
:raises ~azure.core.exceptions.HttpResponseError
"""
if not self._model_info:
self._model_info = self._get_model_info(**kwargs) # pylint: disable=attribute-defined-outside-init
self._model_info = self._get_model_info(**kwargs) # pylint: disable=attribute-defined-outside-init
return self._model_info


def __str__(self) -> str:
# pylint: disable=client-method-name-no-double-underscore
return super().__str__() + f"\n{self._model_info}" if self._model_info else super().__str__()


# Remove this once https://github.com/Azure/autorest.python/issues/2619 is fixed,
# and you see the equivalent auto-generated method in _client.py return "Self"
def __enter__(self) -> Self:
@@ -581,7 +575,6 @@ def __init__(self, endpoint: str, credential: Union[AzureKeyCredential, "TokenCr
self._model_info: Optional[_models.ModelInfo] = None
super().__init__(endpoint, credential, **kwargs)


@overload
def embed(
self,
@@ -791,7 +784,6 @@ def embed(

return deserialized # type: ignore


@distributed_trace
def get_model_info(self, **kwargs: Any) -> _models.ModelInfo:
# pylint: disable=line-too-long
@@ -802,15 +794,13 @@ def get_model_info(self, **kwargs: Any) -> _models.ModelInfo:
:raises ~azure.core.exceptions.HttpResponseError
"""
if not self._model_info:
self._model_info = self._get_model_info(**kwargs) # pylint: disable=attribute-defined-outside-init
self._model_info = self._get_model_info(**kwargs) # pylint: disable=attribute-defined-outside-init
return self._model_info


def __str__(self) -> str:
# pylint: disable=client-method-name-no-double-underscore
return super().__str__() + f"\n{self._model_info}" if self._model_info else super().__str__()


# Remove this once https://github.com/Azure/autorest.python/issues/2619 is fixed,
# and you see the equivalent auto-generated method in _client.py return "Self"
def __enter__(self) -> Self:
@@ -1046,7 +1036,6 @@ def embed(

return deserialized # type: ignore


@distributed_trace
def get_model_info(self, **kwargs: Any) -> _models.ModelInfo:
# pylint: disable=line-too-long
Expand All @@ -1057,15 +1046,13 @@ def get_model_info(self, **kwargs: Any) -> _models.ModelInfo:
:raises ~azure.core.exceptions.HttpResponseError
"""
if not self._model_info:
self._model_info = self._get_model_info(**kwargs) # pylint: disable=attribute-defined-outside-init
self._model_info = self._get_model_info(**kwargs) # pylint: disable=attribute-defined-outside-init
return self._model_info


def __str__(self) -> str:
# pylint: disable=client-method-name-no-double-underscore
return super().__str__() + f"\n{self._model_info}" if self._model_info else super().__str__()


# Remove this once https://github.com/Azure/autorest.python/issues/2619 is fixed,
# and you see the equivalent auto-generated method in _client.py return "Self"
def __enter__(self) -> Self:
@@ -1441,7 +1441,7 @@ def _deserialize(self, target_obj, data):
elif isinstance(response, type) and issubclass(response, Enum):
return self.deserialize_enum(data, response)

if data is None:
if data is None or data is CoreNull:
return data
try:
attributes = response._attribute_map # type: ignore
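
The added `CoreNull` guard matters because azure-core's explicit-null sentinel is falsy yet distinct from `None`, so it must pass through deserialization untouched. A short sketch, assuming `CoreNull` is the usual alias for `azure.core.serialization.NULL`:

```python
from azure.core.serialization import NULL  # presumably imported as CoreNull in this module

# NULL marks an explicit JSON null in request and response bodies.
# It is falsy, but it is not None, so `data is None` alone would miss it
# and the deserializer would try to treat the sentinel as a model.
assert not NULL
assert NULL is not None
```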
@@ -274,9 +274,6 @@ async def _complete(
response. Required.
"model": "str", # The model used for the chat completion. Required.
"usage": {
"capacity_type": "str", # Indicates whether your capacity has been
affected by the usage amount (token count) reported here. Required. Known
values are: "usage" and "fixed".
"completion_tokens": 0, # The number of tokens generated across all
completions emissions. Required.
"prompt_tokens": 0, # The number of tokens in the provided prompts
@@ -547,9 +544,6 @@ async def _embed(
"id": "str", # Unique identifier for the embeddings result. Required.
"model": "str", # The model ID used to generate this result. Required.
"usage": {
"capacity_type": "str", # Indicates whether your capacity has been
affected by the usage amount (token count) reported here. Required. Known
values are: "usage" and "fixed".
"input_tokens": 0, # Number of tokens in the request prompt.
Required.
"prompt_tokens": 0, # Number of tokens used for the prompt sent to
@@ -822,9 +816,6 @@ async def _embed(
"id": "str", # Unique identifier for the embeddings result. Required.
"model": "str", # The model ID used to generate this result. Required.
"usage": {
"capacity_type": "str", # Indicates whether your capacity has been
affected by the usage amount (token count) reported here. Required. Known
values are: "usage" and "fixed".
"input_tokens": 0, # Number of tokens in the request prompt.
Required.
"prompt_tokens": 0, # Number of tokens used for the prompt sent to
32 changes: 10 additions & 22 deletions sdk/ai/azure-ai-inference/azure/ai/inference/aio/_patch.py
@@ -80,8 +80,10 @@ async def load_client(
:raises ~azure.core.exceptions.HttpResponseError
"""

async with ChatCompletionsClient(endpoint, credential, **kwargs) as client: # Pick any of the clients, it does not matter.
model_info = await client.get_model_info() # type: ignore
async with ChatCompletionsClient(
endpoint, credential, **kwargs
) as client: # Pick any of the clients, it does not matter.
model_info = await client.get_model_info() # type: ignore

_LOGGER.info("model_info=%s", model_info)
if not model_info.model_type:
@@ -151,9 +153,7 @@ async def complete(
] = None,
seed: Optional[int] = None,
**kwargs: Any,
) -> _models.ChatCompletions:
...

) -> _models.ChatCompletions: ...

@overload
async def complete(
@@ -177,9 +177,7 @@ async def complete(
] = None,
seed: Optional[int] = None,
**kwargs: Any,
) -> _models.AsyncStreamingChatCompletions:
...

) -> _models.AsyncStreamingChatCompletions: ...

@overload
async def complete(
@@ -539,7 +537,6 @@ async def complete(

return _deserialize(_models.ChatCompletions, response.json()) # pylint: disable=protected-access


@distributed_trace_async
async def get_model_info(self, **kwargs: Any) -> _models.ModelInfo:
# pylint: disable=line-too-long
@@ -550,15 +547,13 @@ async def get_model_info(self, **kwargs: Any) -> _models.ModelInfo:
:raises ~azure.core.exceptions.HttpResponseError
"""
if not self._model_info:
self._model_info = await self._get_model_info(**kwargs) # pylint: disable=attribute-defined-outside-init
self._model_info = await self._get_model_info(**kwargs) # pylint: disable=attribute-defined-outside-init
return self._model_info


def __str__(self) -> str:
# pylint: disable=client-method-name-no-double-underscore
return super().__str__() + f"\n{self._model_info}" if self._model_info else super().__str__()


# Remove this once https://github.com/Azure/autorest.python/issues/2619 is fixed,
# and you see the equivalent auto-generated method in _client.py return "Self"
async def __aenter__(self) -> Self:
@@ -587,7 +582,6 @@ def __init__(
self._model_info: Optional[_models.ModelInfo] = None
super().__init__(endpoint=endpoint, credential=credential, **kwargs)


@overload
async def embed(
self,
@@ -797,7 +791,6 @@ async def embed(

return deserialized # type: ignore


@distributed_trace_async
async def get_model_info(self, **kwargs: Any) -> _models.ModelInfo:
# pylint: disable=line-too-long
@@ -808,15 +801,13 @@ async def get_model_info(self, **kwargs: Any) -> _models.ModelInfo:
:raises ~azure.core.exceptions.HttpResponseError
"""
if not self._model_info:
self._model_info = await self._get_model_info(**kwargs) # pylint: disable=attribute-defined-outside-init
self._model_info = await self._get_model_info(**kwargs) # pylint: disable=attribute-defined-outside-init
return self._model_info


def __str__(self) -> str:
# pylint: disable=client-method-name-no-double-underscore
return super().__str__() + f"\n{self._model_info}" if self._model_info else super().__str__()


# Remove this once https://github.com/Azure/autorest.python/issues/2619 is fixed,
# and you see the equivalent auto-generated method in _client.py return "Self"
async def __aenter__(self) -> Self:
@@ -845,7 +836,6 @@ def __init__(
self._model_info: Optional[_models.ModelInfo] = None
super().__init__(endpoint=endpoint, credential=credential, **kwargs)


@overload
async def embed(
self,
@@ -1055,7 +1045,6 @@ async def embed(

return deserialized # type: ignore


@distributed_trace_async
async def get_model_info(self, **kwargs: Any) -> _models.ModelInfo:
# pylint: disable=line-too-long
@@ -1066,21 +1055,20 @@ async def get_model_info(self, **kwargs: Any) -> _models.ModelInfo:
:raises ~azure.core.exceptions.HttpResponseError
"""
if not self._model_info:
self._model_info = await self._get_model_info(**kwargs) # pylint: disable=attribute-defined-outside-init
self._model_info = await self._get_model_info(**kwargs) # pylint: disable=attribute-defined-outside-init
return self._model_info


def __str__(self) -> str:
# pylint: disable=client-method-name-no-double-underscore
return super().__str__() + f"\n{self._model_info}" if self._model_info else super().__str__()


# Remove this once https://github.com/Azure/autorest.python/issues/2619 is fixed,
# and you see the equivalent auto-generated method in _client.py return "Self"
async def __aenter__(self) -> Self:
await self._client.__aenter__()
return self


__all__: List[str] = [
"load_client",
"ChatCompletionsClient",
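
Mirroring the synchronous README example, a minimal sketch of the patched async client in use; the environment variable names are placeholders again, and `aiohttp` is assumed as the async transport:

```python
import asyncio
import os

from azure.ai.inference.aio import ChatCompletionsClient
from azure.ai.inference.models import SystemMessage, UserMessage
from azure.core.credentials import AzureKeyCredential


async def main() -> None:
    # "async with" exercises the __aenter__ override patched above.
    async with ChatCompletionsClient(
        endpoint=os.environ["AZURE_AI_CHAT_ENDPOINT"],
        credential=AzureKeyCredential(os.environ["AZURE_AI_CHAT_KEY"]),
    ) as client:
        response = await client.complete(
            messages=[
                SystemMessage(content="You are a helpful assistant."),
                UserMessage(content="How many feet are in a mile?"),
            ]
        )
        print(response.choices[0].message.content)


asyncio.run(main())
```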