Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add Support for Customizing default_headers in Azure OpenAI #1925

Merged
merged 3 commits into from
Oct 19, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
13 changes: 8 additions & 5 deletions docs/components/embedders/models/azure_openai.mdx
Original file line number Diff line number Diff line change
Expand Up @@ -21,11 +21,14 @@ config = {
"provider": "azure_openai",
"config": {
"model": "text-embedding-3-large"
"azure_kwargs" : {
"api_version" : "",
"azure_deployment" : "",
"azure_endpoint" : "",
"api_key": ""
"azure_kwargs": {
"api_version": "",
"azure_deployment": "",
"azure_endpoint": "",
"api_key": "",
"default_headers": {
"CustomHeader": "your-custom-header",
}
}
}
}
Expand Down
26 changes: 16 additions & 10 deletions docs/components/llms/models/azure_openai.mdx
Original file line number Diff line number Diff line change
Expand Up @@ -22,11 +22,14 @@ config = {
"model": "your-deployment-name",
"temperature": 0.1,
"max_tokens": 2000,
"azure_kwargs" : {
"azure_deployment" : "",
"api_version" : "",
"azure_endpoint" : "",
"api_key" : ""
"azure_kwargs": {
"azure_deployment": "",
"api_version": "",
"azure_endpoint": "",
"api_key": "",
"default_headers": {
"CustomHeader": "your-custom-header",
}
}
}
}
Expand Down Expand Up @@ -54,11 +57,14 @@ config = {
"model": "your-deployment-name",
"temperature": 0.1,
"max_tokens": 2000,
"azure_kwargs" : {
"azure_deployment" : "",
"api_version" : "",
"azure_endpoint" : "",
"api_key" : ""
"azure_kwargs": {
"azure_deployment": "",
"api_version": "",
"azure_endpoint": "",
"api_key": "",
"default_headers": {
"CustomHeader": "your-custom-header",
}
}
}
}
Expand Down
2 changes: 2 additions & 0 deletions mem0/configs/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -63,6 +63,7 @@ class AzureConfig(BaseModel):
azure_deployment (str): The name of the Azure deployment.
azure_endpoint (str): The endpoint URL for the Azure service.
api_version (str): The version of the Azure API being used.
default_headers (Dict[str, str]): Headers to include in requests to the Azure API.
"""

api_key: str = Field(
Expand All @@ -72,3 +73,4 @@ class AzureConfig(BaseModel):
azure_deployment: str = Field(description="The name of the Azure deployment.", default=None)
azure_endpoint: str = Field(description="The endpoint URL for the Azure service.", default=None)
api_version: str = Field(description="The version of the Azure API being used.", default=None)
default_headers: Optional[Dict[str, str]] = Field(description="Headers to include in requests to the Azure API.", default=None)
2 changes: 2 additions & 0 deletions mem0/embeddings/azure_openai.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,13 +15,15 @@ def __init__(self, config: Optional[BaseEmbedderConfig] = None):
azure_deployment = self.config.azure_kwargs.azure_deployment or os.getenv("EMBEDDING_AZURE_DEPLOYMENT")
azure_endpoint = self.config.azure_kwargs.azure_endpoint or os.getenv("EMBEDDING_AZURE_ENDPOINT")
api_version = self.config.azure_kwargs.api_version or os.getenv("EMBEDDING_AZURE_API_VERSION")
default_headers = self.config.azure_kwargs.default_headers

self.client = AzureOpenAI(
azure_deployment=azure_deployment,
azure_endpoint=azure_endpoint,
api_version=api_version,
api_key=api_key,
http_client=self.config.http_client,
default_headers=default_headers,
)

def embed(self, text):
Expand Down
2 changes: 2 additions & 0 deletions mem0/llms/azure_openai.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,13 +20,15 @@ def __init__(self, config: Optional[BaseLlmConfig] = None):
azure_deployment = self.config.azure_kwargs.azure_deployment or os.getenv("LLM_AZURE_DEPLOYMENT")
azure_endpoint = self.config.azure_kwargs.azure_endpoint or os.getenv("LLM_AZURE_ENDPOINT")
api_version = self.config.azure_kwargs.api_version or os.getenv("LLM_AZURE_API_VERSION")
default_headers = self.config.azure_kwargs.default_headers

self.client = AzureOpenAI(
azure_deployment=azure_deployment,
azure_endpoint=azure_endpoint,
api_version=api_version,
api_key=api_key,
http_client=self.config.http_client,
default_headers=default_headers,
)

def _parse_response(self, response, tools):
Expand Down
4 changes: 3 additions & 1 deletion mem0/llms/azure_openai_structured.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,14 +20,16 @@ def __init__(self, config: Optional[BaseLlmConfig] = None):
azure_deployment = os.getenv("LLM_AZURE_DEPLOYMENT") or self.config.azure_kwargs.azure_deployment
azure_endpoint = os.getenv("LLM_AZURE_ENDPOINT") or self.config.azure_kwargs.azure_endpoint
api_version = os.getenv("LLM_AZURE_API_VERSION") or self.config.azure_kwargs.api_version
# Can display a warning if API version is of model and api-version
default_headers = self.config.azure_kwargs.default_headers

# Can display a warning if API version is of model and api-version
self.client = AzureOpenAI(
azure_deployment=azure_deployment,
azure_endpoint=azure_endpoint,
api_version=api_version,
api_key=api_key,
http_client=self.config.http_client,
default_headers=default_headers,
)

def _parse_response(self, response, tools):
Expand Down
42 changes: 26 additions & 16 deletions tests/embeddings/test_azure_openai_embeddings.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,9 @@
import pytest
from unittest.mock import Mock, patch
from mem0.embeddings.azure_openai import AzureOpenAIEmbedding

import pytest

from mem0.configs.embeddings.base import BaseEmbedderConfig
from mem0.embeddings.azure_openai import AzureOpenAIEmbedding


@pytest.fixture
Expand Down Expand Up @@ -29,18 +31,26 @@ def test_embed_text(mock_openai_client):
assert embedding == [0.1, 0.2, 0.3]


def test_embed_text_with_newlines(mock_openai_client):
config = BaseEmbedderConfig(model="text-embedding-ada-002")
embedder = AzureOpenAIEmbedding(config)

mock_embedding_response = Mock()
mock_embedding_response.data = [Mock(embedding=[0.4, 0.5, 0.6])]
mock_openai_client.embeddings.create.return_value = mock_embedding_response

text = "Hello,\nthis is a test\nwith newlines."
embedding = embedder.embed(text)

mock_openai_client.embeddings.create.assert_called_once_with(
input=["Hello, this is a test with newlines."], model="text-embedding-ada-002"
@pytest.mark.parametrize(
"default_headers, expected_header",
[
(None, None),
({"Test": "test_value"}, "test_value"),
({}, None)
],
)
def test_embed_text_with_default_headers(default_headers, expected_header):
config = BaseEmbedderConfig(
model="text-embedding-ada-002",
azure_kwargs={
"api_key": "test",
"api_version": "test_version",
"azure_endpoint": "test_endpoint",
"azuer_deployment": "test_deployment",
"default_headers": default_headers
}
)
assert embedding == [0.4, 0.5, 0.6]
embedder = AzureOpenAIEmbedding(config)
assert embedder.client.api_key == "test"
assert embedder.client._api_version == "test_version"
assert embedder.client.default_headers.get("Test") == expected_header
12 changes: 10 additions & 2 deletions tests/llms/test_azure_openai.py
Original file line number Diff line number Diff line change
Expand Up @@ -92,10 +92,17 @@ def test_generate_response_with_tools(mock_openai_client):
assert response["tool_calls"][0]["arguments"] == {"data": "Today is a sunny day."}


def test_generate_with_http_proxies():
@pytest.mark.parametrize(
"default_headers",
[None, {"Firstkey": "FirstVal", "SecondKey": "SecondVal"}],
)
def test_generate_with_http_proxies(default_headers):
mock_http_client = Mock(spec=httpx.Client)
mock_http_client_instance = Mock(spec=httpx.Client)
mock_http_client.return_value = mock_http_client_instance
azure_kwargs = {"api_key": "test"}
if default_headers:
azure_kwargs["default_headers"] = default_headers

with (
patch("mem0.llms.azure_openai.AzureOpenAI") as mock_azure_openai,
Expand All @@ -108,7 +115,7 @@ def test_generate_with_http_proxies():
top_p=TOP_P,
api_key="test",
http_client_proxies="http://testproxy.mem0.net:8000",
azure_kwargs={"api_key": "test"},
azure_kwargs=azure_kwargs,
)

_ = AzureOpenAILLM(config)
Expand All @@ -119,5 +126,6 @@ def test_generate_with_http_proxies():
azure_deployment=None,
azure_endpoint=None,
api_version=None,
default_headers=default_headers,
)
mock_http_client.assert_called_once_with(proxies="http://testproxy.mem0.net:8000")
Loading