Skip to content

Commit

Permalink
Fix ray-project/Aviary integration (#6607)
Browse files Browse the repository at this point in the history
- Description: The URL scheme used by the Aviary integration has changed. This PR
provides a fix for those changes and also makes passing the URL to the API
optional (since the URL and token can be set via environment variables).
  - Issue: N/A
  - Dependencies: N/A
  - Twitter handle: N/A

---------

Signed-off-by: Kourosh Hakhamaneshi <kourosh@anyscale.com>
  • Loading branch information
kouroshHakha authored Jun 23, 2023
1 parent dbe1d02 commit f6fdabd
Show file tree
Hide file tree
Showing 2 changed files with 103 additions and 56 deletions.
156 changes: 101 additions & 55 deletions langchain/llms/aviary.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,10 @@
"""Wrapper around Aviary"""
from typing import Any, Dict, List, Mapping, Optional
import dataclasses
import os
from typing import Any, Dict, List, Mapping, Optional, Union, cast

import requests
from pydantic import Extra, Field, root_validator
from pydantic import Extra, root_validator

from langchain.callbacks.manager import CallbackManagerForLLMRun
from langchain.llms.base import LLM
Expand All @@ -12,40 +14,99 @@
TIMEOUT = 60


@dataclasses.dataclass
class AviaryBackend:
    """Connection details for an Aviary backend.

    Attributes:
        backend_url: Base URL of the backend. ``from_env`` guarantees it
            ends with "/"; direct construction does not normalize it.
        bearer: Complete ``Authorization`` header value (for example
            ``"Bearer <token>"``), or an empty string when there is no token.
        header: HTTP headers derived from ``bearer`` in ``__post_init__``.
    """

    backend_url: str
    bearer: str

    def __post_init__(self) -> None:
        # Leave the Authorization header out entirely when no token is
        # configured instead of sending a header with an empty value.
        self.header = {"Authorization": self.bearer} if self.bearer else {}

    @classmethod
    def from_env(cls) -> "AviaryBackend":
        """Build a backend from the AVIARY_URL / AVIARY_TOKEN env variables.

        Returns:
            A backend whose URL ends with "/" and whose bearer is
            ``"Bearer <AVIARY_TOKEN>"`` (empty when the token is unset/empty).

        Raises:
            ValueError: If AVIARY_URL is unset or empty.
        """
        aviary_url = os.getenv("AVIARY_URL")
        if not aviary_url:
            # Explicit raise rather than ``assert``: assertions are removed
            # under ``python -O`` and must not guard required configuration.
            raise ValueError("AVIARY_URL must be set")

        aviary_token = os.getenv("AVIARY_TOKEN", "")
        bearer = f"Bearer {aviary_token}" if aviary_token else ""

        # Route paths are appended by plain concatenation, so make sure the
        # base URL carries exactly the trailing slash they rely on.
        if not aviary_url.endswith("/"):
            aviary_url += "/"

        return cls(aviary_url, bearer)


def get_models() -> List[str]:
    """Return the sorted model names served by the configured Aviary backend.

    The backend is built from the environment via ``AviaryBackend.from_env``
    and its ``-/routes`` endpoint is queried. Only route keys containing
    "--" are kept; leading "/" characters are stripped and every "--" is
    turned back into "/" to form the model name.

    Raises:
        RuntimeError: If the response body cannot be decoded as JSON.
    """
    env_backend = AviaryBackend.from_env()
    request_url = env_backend.backend_url + "-/routes"
    response = requests.get(
        request_url,
        headers=env_backend.header,
        timeout=TIMEOUT,
    )

    try:
        routes = response.json()
    except requests.JSONDecodeError as e:
        raise RuntimeError(
            f"Error decoding JSON from {request_url}. Text response: {response.text}"
        ) from e

    model_names = []
    for route in routes.keys():
        # Routes without "--" are not model routes and are ignored.
        if "--" not in route:
            continue
        model_names.append(route.lstrip("/").replace("--", "/"))
    model_names.sort()
    return model_names


def get_completions(
    model: str,
    prompt: str,
    use_prompt_format: bool = True,
    version: str = "",
) -> Dict[str, Union[str, float, int]]:
    """Send a prompt to an Aviary model and return the decoded JSON reply.

    Args:
        model: Model name; every "/" is replaced by "--" to form the route.
        prompt: Prompt text, sent as the "prompt" field of the JSON payload.
        use_prompt_format: Forwarded unchanged as "use_prompt_format".
        version: Text placed immediately before "query" in the URL. No
            separator is inserted, so any trailing "/" must be part of it.

    Returns:
        The JSON body of the response.

    Raises:
        RuntimeError: If the response body cannot be decoded as JSON.
    """
    env_backend = AviaryBackend.from_env()

    # <backend_url><model-route>/<version>query, joined verbatim.
    model_route = model.replace("/", "--")
    url = "".join((env_backend.backend_url, model_route, "/", version, "query"))

    payload = {"prompt": prompt, "use_prompt_format": use_prompt_format}
    response = requests.post(
        url,
        headers=env_backend.header,
        json=payload,
        timeout=TIMEOUT,
    )

    try:
        completion = response.json()
    except requests.JSONDecodeError as e:
        raise RuntimeError(
            f"Error decoding JSON from {url}. Text response: {response.text}"
        ) from e
    return completion


class Aviary(LLM):
"""Allow you to use an Aviary.
Aviary is a backend for hosted models. You can
find out more about aviary at
http://github.com/ray-project/aviary
Has no dependencies, since it connects to backend
directly.
To get a list of the models supported on an
aviary, follow the instructions on the web site to
install the aviary CLI and then use:
`aviary models`
You must at least specify the environment
variable or parameter AVIARY_URL.
You may optionally specify the environment variable
or parameter AVIARY_TOKEN.
AVIARY_URL and AVIARY_TOKEN environment variables must be set.
Example:
.. code-block:: python
from langchain.llms import Aviary
light = Aviary(aviary_url='AVIARY_URL',
model='amazon/LightGPT')
result = light.predict('How do you make fried rice?')
os.environ["AVIARY_URL"] = "<URL>"
os.environ["AVIARY_TOKEN"] = "<TOKEN>"
light = Aviary(model='amazon/LightGPT')
output = light('How do you make fried rice?')
"""

model: str
aviary_url: str
aviary_token: str = Field("", exclude=True)
model: str = "amazon/LightGPT"
aviary_url: Optional[str] = None
aviary_token: Optional[str] = None
# If True the prompt template for the model will be ignored.
use_prompt_format: bool = True
# API version to use for Aviary
version: Optional[str] = None

class Config:
"""Configuration for this pydantic object."""
Expand All @@ -56,49 +117,35 @@ class Config:
def validate_environment(cls, values: Dict) -> Dict:
"""Validate that api key and python package exists in environment."""
aviary_url = get_from_dict_or_env(values, "aviary_url", "AVIARY_URL")
if not aviary_url.endswith("/"):
aviary_url += "/"
values["aviary_url"] = aviary_url
aviary_token = get_from_dict_or_env(
values, "aviary_token", "AVIARY_TOKEN", default=""
)
values["aviary_token"] = aviary_token
aviary_token = get_from_dict_or_env(values, "aviary_token", "AVIARY_TOKEN")

aviary_endpoint = aviary_url + "models"
headers = {"Authorization": f"Bearer {aviary_token}"} if aviary_token else {}
try:
response = requests.get(aviary_endpoint, headers=headers)
result = response.json()
# Confirm model is available
if values["model"] not in result:
raise ValueError(
f"{aviary_url} does not support model {values['model']}."
)
# Set env variables for aviary sdk
os.environ["AVIARY_URL"] = aviary_url
os.environ["AVIARY_TOKEN"] = aviary_token

try:
aviary_models = get_models()
except requests.exceptions.RequestException as e:
raise ValueError(e)

model = values.get("model")
if model and model not in aviary_models:
raise ValueError(f"{aviary_url} does not support model {values['model']}.")

return values

@property
def _identifying_params(self) -> Mapping[str, Any]:
"""Get the identifying parameters."""
return {
"model_name": self.model,
"aviary_url": self.aviary_url,
"aviary_token": self.aviary_token,
}

@property
def _llm_type(self) -> str:
"""Return type of llm."""
return "aviary"

@property
def headers(self) -> Dict[str, str]:
if self.aviary_token:
return {"Authorization": f"Bearer {self.aviary_token}"}
else:
return {}
return f"aviary-{self.model.replace('/', '-')}"

def _call(
self,
Expand All @@ -119,19 +166,18 @@ def _call(
response = aviary("Tell me a joke.")
"""
url = self.aviary_url + "query/" + self.model.replace("/", "--")
response = requests.post(
url,
headers=self.headers,
json={"prompt": prompt},
timeout=TIMEOUT,
kwargs = {"use_prompt_format": self.use_prompt_format}
if self.version:
kwargs["version"] = self.version

output = get_completions(
model=self.model,
prompt=prompt,
**kwargs,
)
try:
text = response.json()[self.model]["generated_text"]
except requests.JSONDecodeError as e:
raise ValueError(
f"Error decoding JSON from {url}. Text response: {response.text}",
) from e

text = cast(str, output["generated_text"])
if stop:
text = enforce_stop_tokens(text, stop)

return text
3 changes: 2 additions & 1 deletion tests/integration_tests/llms/test_aviary.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@

def test_aviary_call() -> None:
"""Test valid call to Anyscale."""
llm = Aviary(model="test/model")
llm = Aviary()
output = llm("Say bar:")
print(f"llm answer:\n{output}")
assert isinstance(output, str)

0 comments on commit f6fdabd

Please sign in to comment.