Skip to content

Commit

Permalink
streamline env var naming and read host value from env vars
Browse files · Browse the repository at this point in the history
  • Loading branch information
fgebhart committed Dec 5, 2024
1 parent 8ee43e7 commit 5ee2758
Show file tree
Hide file tree
Showing 5 changed files with 27 additions and 25 deletions.
8 changes: 4 additions & 4 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -19,8 +19,8 @@ import os
from aleph_alpha_client import Client, CompletionRequest, Prompt

client = Client(
token=os.getenv("AA_TOKEN"),
host="https://inference-api.your-domain.com",
token=os.environ["TEST_TOKEN"],
host=os.environ["TEST_API_URL"],
)
request = CompletionRequest(
prompt=Prompt.from_text("Provide a short description of AI:"),
Expand All @@ -39,8 +39,8 @@ from aleph_alpha_client import AsyncClient, CompletionRequest, Prompt

# Can enter context manager within an async function
async with AsyncClient(
token=os.environ["AA_TOKEN"]
host="https://inference-api.your-domain.com",
token=os.environ["TEST_TOKEN"],
host=os.environ["TEST_API_URL"],
) as client:
request = CompletionRequest(
prompt=Prompt.from_text("Provide a short description of AI:"),
Expand Down
8 changes: 4 additions & 4 deletions aleph_alpha_client/aleph_alpha_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -165,8 +165,8 @@ class Client:
Example usage:
>>> request = CompletionRequest(prompt=Prompt.from_text(f"Request"), maximum_tokens=64)
>>> client = Client(
token=os.environ["AA_TOKEN"],
host="https://inference-api.your-domain.com",
token=os.environ["TEST_TOKEN"],
host=os.environ["TEST_API_URL"],
)
>>> response: CompletionResponse = client.complete(request, "pharia-1-llm-7b-control")
"""
Expand Down Expand Up @@ -743,8 +743,8 @@ class AsyncClient:
Example usage:
>>> request = CompletionRequest(prompt=Prompt.from_text(f"Request"), maximum_tokens=64)
>>> async with AsyncClient(
token=os.environ["AA_TOKEN"],
host="https://inference-api.your-domain.com"
token=os.environ["TEST_TOKEN"],
host=os.environ["TEST_API_URL"],
) as client:
response: CompletionResponse = await client.complete(request, "pharia-1-llm-7b-control")
"""
Expand Down
6 changes: 3 additions & 3 deletions docs/source/index.rst
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@ Synchronous client.
from aleph_alpha_client import Client, CompletionRequest, Prompt
import os
client = Client(token=os.getenv("AA_TOKEN"), host="https://inference-api.your-domain.com")
client = Client(token=os.environ["TEST_TOKEN"], host=os.environ["TEST_API_URL"])
prompt = Prompt.from_text("Provide a short description of AI:")
request = CompletionRequest(prompt=prompt, maximum_tokens=20)
result = client.complete(request, model="luminous-extended")
Expand All @@ -32,7 +32,7 @@ Synchronous client with prompt containing an image.
from aleph_alpha_client import Client, CompletionRequest, PromptTemplate, Image
import os
client = Client(token=os.getenv("AA_TOKEN"), host="https://inference-api.your-domain.com")
client = Client(token=os.environ["TEST_TOKEN"], host=os.environ["TEST_API_URL"])
image = Image.from_file("path-to-an-image")
prompt_template = PromptTemplate("{{image}}This picture shows ")
prompt = prompt_template.to_prompt(image=prompt_template.placeholder(image))
Expand All @@ -50,7 +50,7 @@ Asynchronous client.
from aleph_alpha_client import AsyncClient, CompletionRequest, Prompt
# Can enter context manager within an async function
async with AsyncClient(token=os.environ["AA_TOKEN"], host="https://inference-api.your-domain.com") as client:
async with AsyncClient(token=os.environ["TEST_TOKEN"], host=os.environ["TEST_API_URL"]) as client:
request = CompletionRequest(
prompt=Prompt.from_text("Request"),
maximum_tokens=64,
Expand Down
18 changes: 10 additions & 8 deletions tests/test_clients.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,25 +16,27 @@ def test_api_version_mismatch_client(httpserver: HTTPServer):
httpserver.expect_request("/version").respond_with_data("0.0.0")

with pytest.raises(RuntimeError):
Client(host=httpserver.url_for(""), token="AA_TOKEN").validate_version()
Client(host=httpserver.url_for(""), token="TEST_TOKEN").validate_version()


async def test_api_version_mismatch_async_client(httpserver: HTTPServer):
httpserver.expect_request("/version").respond_with_data("0.0.0")

with pytest.raises(RuntimeError):
async with AsyncClient(host=httpserver.url_for(""), token="AA_TOKEN") as client:
async with AsyncClient(
host=httpserver.url_for(""), token="TEST_TOKEN"
) as client:
await client.validate_version()


def test_api_version_correct_client(httpserver: HTTPServer):
httpserver.expect_request("/version").respond_with_data(MIN_API_VERSION)
Client(host=httpserver.url_for(""), token="AA_TOKEN").validate_version()
Client(host=httpserver.url_for(""), token="TEST_TOKEN").validate_version()


async def test_api_version_correct_async_client(httpserver: HTTPServer):
httpserver.expect_request("/version").respond_with_data(MIN_API_VERSION)
async with AsyncClient(host=httpserver.url_for(""), token="AA_TOKEN") as client:
async with AsyncClient(host=httpserver.url_for(""), token="TEST_TOKEN") as client:
await client.validate_version()


Expand Down Expand Up @@ -71,7 +73,7 @@ def test_nice_flag_on_client(httpserver: HTTPServer):
).to_json()
)

client = Client(host=httpserver.url_for(""), token="AA_TOKEN", nice=True)
client = Client(host=httpserver.url_for(""), token="TEST_TOKEN", nice=True)

request = CompletionRequest(prompt=Prompt.from_text("Hello world"))
client.complete(request, model="luminous")
Expand All @@ -96,7 +98,7 @@ async def test_nice_flag_on_async_client(httpserver: HTTPServer):

async with AsyncClient(
host=httpserver.url_for(""),
token="AA_TOKEN",
token="TEST_TOKEN",
nice=True,
request_timeout_seconds=1,
) as client:
Expand Down Expand Up @@ -127,7 +129,7 @@ def test_tags_on_client(httpserver: HTTPServer):
client = Client(
host=httpserver.url_for(""),
request_timeout_seconds=1,
token="AA_TOKEN",
token="TEST_TOKEN",
tags=["tim-tagger"],
)

Expand All @@ -151,7 +153,7 @@ async def test_tags_on_async_client(httpserver: HTTPServer):
)

async with AsyncClient(
host=httpserver.url_for(""), token="AA_TOKEN", tags=["tim-tagger"]
host=httpserver.url_for(""), token="TEST_TOKEN", tags=["tim-tagger"]
) as client:
await client.complete(request, model="luminous")

Expand Down
12 changes: 6 additions & 6 deletions tests/test_error_handling.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,7 @@ def test_translate_errors():
def test_retry_sync(httpserver: HTTPServer):
num_retries = 2
client = Client(
token="AA_TOKEN", host=httpserver.url_for(""), total_retries=num_retries
token="TEST_TOKEN", host=httpserver.url_for(""), total_retries=num_retries
)
expect_retryable_error(httpserver, num_calls_expected=num_retries)
expect_valid_version(httpserver)
Expand All @@ -40,7 +40,7 @@ def test_retry_sync(httpserver: HTTPServer):
def test_retry_sync_post(httpserver: HTTPServer):
num_retries = 2
client = Client(
host=httpserver.url_for(""), token="AA_TOKEN", total_retries=num_retries
host=httpserver.url_for(""), token="TEST_TOKEN", total_retries=num_retries
)
expect_retryable_error(httpserver, num_calls_expected=num_retries)
expect_valid_completion(httpserver)
Expand All @@ -52,7 +52,7 @@ def test_retry_sync_post(httpserver: HTTPServer):
def test_exhaust_retries_sync(httpserver: HTTPServer):
num_retries = 1
client = Client(
token="AA_TOKEN", host=httpserver.url_for(""), total_retries=num_retries
token="TEST_TOKEN", host=httpserver.url_for(""), total_retries=num_retries
)
expect_retryable_error(
httpserver,
Expand All @@ -69,7 +69,7 @@ async def test_retry_async(httpserver: HTTPServer):
expect_valid_version(httpserver)

async with AsyncClient(
token="AA_TOKEN", host=httpserver.url_for(""), total_retries=num_retries
token="TEST_TOKEN", host=httpserver.url_for(""), total_retries=num_retries
) as client:
await client.get_version()

Expand All @@ -80,7 +80,7 @@ async def test_retry_async_post(httpserver: HTTPServer):
expect_valid_completion(httpserver)

async with AsyncClient(
token="AA_TOKEN", host=httpserver.url_for(""), total_retries=num_retries
token="TEST_TOKEN", host=httpserver.url_for(""), total_retries=num_retries
) as client:
request = CompletionRequest(prompt=Prompt.from_text(""), maximum_tokens=7)
await client.complete(request, model="FOO")
Expand All @@ -95,7 +95,7 @@ async def test_exhaust_retries_async(httpserver: HTTPServer):
)
with pytest.raises(BusyError):
async with AsyncClient(
token="AA_TOKEN", host=httpserver.url_for(""), total_retries=num_retries
token="TEST_TOKEN", host=httpserver.url_for(""), total_retries=num_retries
) as client:
await client.get_version()

Expand Down

0 comments on commit 5ee2758

Please sign in to comment.