Skip to content

Commit

Permalink
AzureOpenAI access from behind company proxies. (#1459)
Browse files Browse the repository at this point in the history
  • Loading branch information
PranavPuranik authored and shlokkhemani committed Sep 7, 2024
1 parent e7d2ab6 commit 448c672
Show file tree
Hide file tree
Showing 8 changed files with 173 additions and 6 deletions.
9 changes: 7 additions & 2 deletions embedchain/docs/api-reference/advanced/configuration.mdx
Original file line number Diff line number Diff line change
Expand Up @@ -55,6 +55,7 @@ embedder:
config:
model: 'text-embedding-ada-002'
api_key: sk-xxx
http_client_proxies: http://testproxy.mem0.net:8000

chunker:
chunk_size: 2000
Expand Down Expand Up @@ -106,7 +107,8 @@ cache:
"provider": "openai",
"config": {
"model": "text-embedding-ada-002",
"api_key": "sk-xxx"
"api_key": "sk-xxx",
"http_client_proxies": "http://testproxy.mem0.net:8000",
}
},
"chunker": {
Expand Down Expand Up @@ -168,7 +170,8 @@ config = {
'provider': 'openai',
'config': {
'model': 'text-embedding-ada-002',
'api_key': 'sk-xxx'
'api_key': 'sk-xxx',
"http_client_proxies": "http://testproxy.mem0.net:8000",
}
},
'chunker': {
Expand Down Expand Up @@ -236,6 +239,8 @@ Alright, let's dive into what each key means in the yaml config above:
- `title` (String): The title for the embedding model for Google Embedder.
- `task_type` (String): The task type for the embedding model for Google Embedder.
- `model_kwargs` (Dict): Used to pass extra arguments to embedders.
- `http_client_proxies` (Dict | String): The proxy server settings used to create `self.http_client` via `httpx.Client(proxies=http_client_proxies)`. Defaults to `None`.
- `http_async_client_proxies` (Dict | String): The proxy server settings for async calls, used to create `self.http_async_client` via `httpx.AsyncClient(proxies=http_async_client_proxies)`. Defaults to `None`.
5. `chunker` Section:
- `chunk_size` (Integer): The size of each chunk of text that is sent to the language model.
- `chunk_overlap` (Integer): The amount of overlap between each chunk of text.
Expand Down
15 changes: 14 additions & 1 deletion embedchain/embedchain/config/embedder/base.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,6 @@
from typing import Any, Dict, Optional
from typing import Any, Dict, Optional, Union

import httpx

from embedchain.helpers.json_serializable import register_deserializable

Expand All @@ -14,6 +16,8 @@ def __init__(
api_key: Optional[str] = None,
api_base: Optional[str] = None,
model_kwargs: Optional[Dict[str, Any]] = None,
http_client_proxies: Optional[Union[Dict, str]] = None,
http_async_client_proxies: Optional[Union[Dict, str]] = None,
):
"""
Initialize a new instance of an embedder config class.
Expand All @@ -32,6 +36,11 @@ def __init__(
:type api_base: Optional[str], optional
:param model_kwargs: key-value arguments for the embedding model, defaults a dict inside init.
:type model_kwargs: Optional[Dict[str, Any]], defaults a dict inside init.
:param http_client_proxies: The proxy server settings used to create self.http_client, defaults to None
:type http_client_proxies: Optional[Dict | str], optional
:param http_async_client_proxies: The proxy server settings for async calls used to create
self.http_async_client, defaults to None
:type http_async_client_proxies: Optional[Dict | str], optional
"""
self.model = model
self.deployment_name = deployment_name
Expand All @@ -40,3 +49,7 @@ def __init__(
self.api_key = api_key
self.api_base = api_base
self.model_kwargs = model_kwargs or {}
self.http_client = httpx.Client(proxies=http_client_proxies) if http_client_proxies else None
self.http_async_client = (
httpx.AsyncClient(proxies=http_async_client_proxies) if http_async_client_proxies else None
)
8 changes: 6 additions & 2 deletions embedchain/embedchain/embedder/azure_openai.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
from typing import Optional

from langchain_community.embeddings import AzureOpenAIEmbeddings
from langchain_openai import AzureOpenAIEmbeddings

from embedchain.config import BaseEmbedderConfig
from embedchain.embedder.base import BaseEmbedder
Expand All @@ -14,7 +14,11 @@ def __init__(self, config: Optional[BaseEmbedderConfig] = None):
if self.config.model is None:
self.config.model = "text-embedding-ada-002"

embeddings = AzureOpenAIEmbeddings(deployment=self.config.deployment_name)
embeddings = AzureOpenAIEmbeddings(
deployment=self.config.deployment_name,
http_client=self.config.http_client,
http_async_client=self.config.http_async_client,
)
embedding_fn = BaseEmbedder._langchain_default_concept(embeddings)

self.set_embedding_fn(embedding_fn=embedding_fn)
Expand Down
2 changes: 2 additions & 0 deletions embedchain/embedchain/llm/azure_openai.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,8 @@ def _get_answer(prompt: str, config: BaseLlmConfig) -> str:
temperature=config.temperature,
max_tokens=config.max_tokens,
streaming=config.stream,
http_client=config.http_client,
http_async_client=config.http_async_client,
)

if config.top_p and config.top_p != 1:
Expand Down
2 changes: 2 additions & 0 deletions embedchain/embedchain/utils/misc.py
Original file line number Diff line number Diff line change
Expand Up @@ -479,6 +479,8 @@ def validate_config(config_data):
Optional("base_url"): str,
Optional("endpoint"): str,
Optional("model_kwargs"): dict,
Optional("http_client_proxies"): Or(str, dict),
Optional("http_async_client_proxies"): Or(str, dict),
},
},
Optional("embedding_model"): {
Expand Down
12 changes: 12 additions & 0 deletions embedchain/poetry.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

52 changes: 52 additions & 0 deletions embedchain/tests/embedder/test_azure_openai_embedder.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,52 @@
from unittest.mock import patch, Mock

import httpx

from embedchain.config import BaseEmbedderConfig
from embedchain.embedder.azure_openai import AzureOpenAIEmbedder


def test_azure_openai_embedder_with_http_client():
    """A string `http_client_proxies` config builds an `httpx.Client` and forwards it.

    Verifies that `BaseEmbedderConfig` constructs `httpx.Client(proxies=...)` exactly
    once and that `AzureOpenAIEmbedder` passes the resulting client to
    `AzureOpenAIEmbeddings` as `http_client` (with `http_async_client` left as None).
    """
    # Stand-in for the httpx.Client *class*; calling it yields the instance mock.
    mock_http_client = Mock(spec=httpx.Client)
    mock_http_client_instance = Mock(spec=httpx.Client)
    mock_http_client.return_value = mock_http_client_instance

    # Patch the embeddings class where the embedder looks it up, and replace the
    # httpx.Client constructor with our mock. No `as` alias on the second patch:
    # with `new=` it would only rebind the name to the same object.
    with patch("embedchain.embedder.azure_openai.AzureOpenAIEmbeddings") as mock_embeddings, patch(
        "httpx.Client", new=mock_http_client
    ):
        config = BaseEmbedderConfig(
            deployment_name="text-embedding-ada-002",
            http_client_proxies="http://testproxy.mem0.net:8000",
        )

        _ = AzureOpenAIEmbedder(config=config)

        mock_embeddings.assert_called_once_with(
            deployment="text-embedding-ada-002",
            http_client=mock_http_client_instance,
            http_async_client=None,
        )
        mock_http_client.assert_called_once_with(proxies="http://testproxy.mem0.net:8000")


def test_azure_openai_embedder_with_http_async_client():
    """A dict `http_async_client_proxies` config builds an `httpx.AsyncClient` and forwards it.

    Verifies that `BaseEmbedderConfig` constructs `httpx.AsyncClient(proxies=...)`
    exactly once and that `AzureOpenAIEmbedder` passes the resulting client to
    `AzureOpenAIEmbeddings` as `http_async_client` (with `http_client` left as None).
    """
    # Stand-in for the httpx.AsyncClient *class*; calling it yields the instance mock.
    mock_http_async_client = Mock(spec=httpx.AsyncClient)
    mock_http_async_client_instance = Mock(spec=httpx.AsyncClient)
    mock_http_async_client.return_value = mock_http_async_client_instance

    # Patch the embeddings class where the embedder looks it up, and replace the
    # httpx.AsyncClient constructor with our mock. No `as` alias on the second
    # patch: with `new=` it would only rebind the name to the same object.
    with patch("embedchain.embedder.azure_openai.AzureOpenAIEmbeddings") as mock_embeddings, patch(
        "httpx.AsyncClient", new=mock_http_async_client
    ):
        config = BaseEmbedderConfig(
            deployment_name="text-embedding-ada-002",
            http_async_client_proxies={"http://": "http://testproxy.mem0.net:8000"},
        )

        _ = AzureOpenAIEmbedder(config=config)

        mock_embeddings.assert_called_once_with(
            deployment="text-embedding-ada-002",
            http_client=None,
            http_async_client=mock_http_async_client_instance,
        )
        mock_http_async_client.assert_called_once_with(proxies={"http://": "http://testproxy.mem0.net:8000"})
79 changes: 78 additions & 1 deletion embedchain/tests/llm/test_azure_openai.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
from unittest.mock import MagicMock, patch
from unittest.mock import Mock, MagicMock, patch

import httpx
import pytest
from langchain.schema import HumanMessage, SystemMessage

Expand Down Expand Up @@ -43,6 +44,8 @@ def test_get_answer(azure_openai_llm):
temperature=azure_openai_llm.config.temperature,
max_tokens=azure_openai_llm.config.max_tokens,
streaming=azure_openai_llm.config.stream,
http_client=None,
http_async_client=None,
)


Expand Down Expand Up @@ -84,4 +87,78 @@ def test_with_api_version():
temperature=0.7,
max_tokens=50,
streaming=False,
http_client=None,
http_async_client=None,
)


def test_get_llm_model_answer_with_http_client_proxies():
    """A string `http_client_proxies` config builds an `httpx.Client` and reaches AzureChatOpenAI.

    Verifies that `BaseLlmConfig` constructs `httpx.Client(proxies=...)` exactly once
    and that `AzureOpenAILlm` forwards the resulting client to `AzureChatOpenAI` as
    `http_client` (with `http_async_client` left as None).
    """
    # Stand-in for the httpx.Client *class*; calling it yields the instance mock.
    mock_http_client = Mock(spec=httpx.Client)
    mock_http_client_instance = Mock(spec=httpx.Client)
    mock_http_client.return_value = mock_http_client_instance

    # Replace the httpx.Client constructor with our mock. No `as` alias on the
    # second patch: with `new=` it would only rebind the name to the same object.
    with patch("langchain_openai.AzureChatOpenAI") as mock_chat, patch(
        "httpx.Client", new=mock_http_client
    ):
        mock_chat.return_value.invoke.return_value.content = "Mocked response"

        config = BaseLlmConfig(
            deployment_name="azure_deployment",
            temperature=0.7,
            max_tokens=50,
            stream=False,
            system_prompt="System prompt",
            model="gpt-3.5-turbo",
            http_client_proxies="http://testproxy.mem0.net:8000",
        )

        llm = AzureOpenAILlm(config)
        llm.get_llm_model_answer("Test query")

        mock_chat.assert_called_once_with(
            deployment_name="azure_deployment",
            openai_api_version="2024-02-01",
            model_name="gpt-3.5-turbo",
            temperature=0.7,
            max_tokens=50,
            streaming=False,
            http_client=mock_http_client_instance,
            http_async_client=None,
        )
        mock_http_client.assert_called_once_with(proxies="http://testproxy.mem0.net:8000")


def test_get_llm_model_answer_with_http_async_client_proxies():
    """A dict `http_async_client_proxies` config builds an `httpx.AsyncClient` and reaches AzureChatOpenAI.

    Verifies that `BaseLlmConfig` constructs `httpx.AsyncClient(proxies=...)` exactly
    once and that `AzureOpenAILlm` forwards the resulting client to `AzureChatOpenAI`
    as `http_async_client` (with `http_client` left as None).
    """
    # Stand-in for the httpx.AsyncClient *class*; calling it yields the instance mock.
    mock_http_async_client = Mock(spec=httpx.AsyncClient)
    mock_http_async_client_instance = Mock(spec=httpx.AsyncClient)
    mock_http_async_client.return_value = mock_http_async_client_instance

    # Replace the httpx.AsyncClient constructor with our mock. No `as` alias on
    # the second patch: with `new=` it would only rebind the name to the same object.
    with patch("langchain_openai.AzureChatOpenAI") as mock_chat, patch(
        "httpx.AsyncClient", new=mock_http_async_client
    ):
        mock_chat.return_value.invoke.return_value.content = "Mocked response"

        config = BaseLlmConfig(
            deployment_name="azure_deployment",
            temperature=0.7,
            max_tokens=50,
            stream=False,
            system_prompt="System prompt",
            model="gpt-3.5-turbo",
            http_async_client_proxies={"http://": "http://testproxy.mem0.net:8000"},
        )

        llm = AzureOpenAILlm(config)
        llm.get_llm_model_answer("Test query")

        mock_chat.assert_called_once_with(
            deployment_name="azure_deployment",
            openai_api_version="2024-02-01",
            model_name="gpt-3.5-turbo",
            temperature=0.7,
            max_tokens=50,
            streaming=False,
            http_client=None,
            http_async_client=mock_http_async_client_instance,
        )
        mock_http_async_client.assert_called_once_with(proxies={"http://": "http://testproxy.mem0.net:8000"})

0 comments on commit 448c672

Please sign in to comment.