Skip to content

Commit

Permalink
Lazily import client modules in AutoClient (#1679)
Browse files Browse the repository at this point in the history
  • Loading branch information
yifanmai authored Jun 21, 2023
1 parent c4b7a07 commit 2bb4d78
Show file tree
Hide file tree
Showing 4 changed files with 122 additions and 44 deletions.
139 changes: 102 additions & 37 deletions src/helm/proxy/clients/auto_client.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
import os
from dataclasses import replace
from typing import Dict, Optional
from typing import Dict, Optional, TYPE_CHECKING

from retrying import RetryError, Attempt

Expand All @@ -14,45 +14,33 @@
DecodeRequestResult,
)
from helm.proxy.retry import retry_request
from .critique_client import CritiqueClient, RandomCritiqueClient
from .model_critique_client import ModelCritiqueClient
from .scale_critique_client import ScaleCritiqueClient
from .surge_ai_critique_client import SurgeAICritiqueClient
from .mechanical_turk_critique_client import MechanicalTurkCritiqueClient
from .client import Client
from .ai21_client import AI21Client
from .aleph_alpha_client import AlephAlphaClient
from .anthropic_client import AnthropicClient
from .chat_gpt_client import ChatGPTClient
from .cohere_client import CohereClient
from .together_client import TogetherClient
from .google_client import GoogleClient
from .goose_ai_client import GooseAIClient
from .huggingface_client import HuggingFaceClient
from .ice_tokenizer_client import ICETokenizerClient
from .megatron_client import MegatronClient
from .openai_client import OpenAIClient
from .microsoft_client import MicrosoftClient
from .perspective_api_client import PerspectiveAPIClient
from .palmyra_client import PalmyraClient
from .yalm_tokenizer_client import YaLMTokenizerClient
from .simple_client import SimpleClient
from helm.proxy.clients.critique_client import CritiqueClient
from helm.proxy.clients.client import Client
from helm.proxy.clients.huggingface_model_registry import get_huggingface_model_config
from helm.proxy.clients.toxicity_classifier_client import ToxicityClassifierClient


if TYPE_CHECKING:
import helm.proxy.clients.huggingface_client


class AutoClient(Client):
"""Automatically dispatch to the proper `Client` based on the organization."""
"""Automatically dispatch to the proper `Client` based on the organization.
The modules for each client are lazily imported when the respective client is created.
This greatly speeds up the import time of this module, and allows the client modules to
use optional dependencies."""

def __init__(self, credentials: Dict[str, str], cache_path: str, mongo_uri: str = ""):
self.credentials = credentials
self.cache_path = cache_path
self.mongo_uri = mongo_uri
self.clients: Dict[str, Client] = {}
self.tokenizer_clients: Dict[str, Client] = {}
# self.critique_client is lazily instantiated by get_critique_client()
self.critique_client: Optional[CritiqueClient] = None
huggingface_cache_config = self._build_cache_config("huggingface")
self.huggingface_client = HuggingFaceClient(huggingface_cache_config)
# self._huggingface_client is lazily instantiated by get_huggingface_client()
self._huggingface_client: Optional["helm.proxy.clients.huggingface_client.HuggingFaceClient"] = None
# self._critique_client is lazily instantiated by get_critique_client()
self._critique_client: Optional[CritiqueClient] = None
hlog(f"AutoClient: cache_path = {cache_path}")
hlog(f"AutoClient: mongo_uri = {mongo_uri}")

Expand All @@ -73,8 +61,13 @@ def _get_client(self, model: str) -> Client:
cache_config: CacheConfig = self._build_cache_config(organization)

if get_huggingface_model_config(model):
from helm.proxy.clients.huggingface_client import HuggingFaceClient

client = HuggingFaceClient(cache_config=cache_config)
elif organization == "openai":
from helm.proxy.clients.chat_gpt_client import ChatGPTClient
from helm.proxy.clients.openai_client import OpenAIClient

# TODO: add ChatGPT to the OpenAIClient when it's supported.
# We're using a separate client for now since we're using an unofficial Python library.
# See https://github.com/acheong08/ChatGPT/wiki/Setup on how to get a valid session token.
Expand All @@ -96,24 +89,38 @@ def _get_client(self, model: str) -> Client:
org_id=org_id,
)
elif organization == "AlephAlpha":
from helm.proxy.clients.aleph_alpha_client import AlephAlphaClient

client = AlephAlphaClient(api_key=self.credentials["alephAlphaKey"], cache_config=cache_config)
elif organization == "ai21":
from helm.proxy.clients.ai21_client import AI21Client

client = AI21Client(api_key=self.credentials["ai21ApiKey"], cache_config=cache_config)
elif organization == "cohere":
from helm.proxy.clients.cohere_client import CohereClient

client = CohereClient(api_key=self.credentials["cohereApiKey"], cache_config=cache_config)
elif organization == "gooseai":
from helm.proxy.clients.goose_ai_client import GooseAIClient

org_id = self.credentials.get("gooseaiOrgId", None)
client = GooseAIClient(
api_key=self.credentials["gooseaiApiKey"], cache_config=cache_config, org_id=org_id
)
elif organization == "huggingface" or organization == "mosaicml":
client = self.huggingface_client
from helm.proxy.clients.huggingface_client import HuggingFaceClient

client = HuggingFaceClient(cache_config)
elif organization == "anthropic":
from helm.proxy.clients.anthropic_client import AnthropicClient

client = AnthropicClient(
api_key=self.credentials.get("anthropicApiKey", None),
cache_config=cache_config,
)
elif organization == "microsoft":
from helm.proxy.clients.microsoft_client import MicrosoftClient

org_id = self.credentials.get("microsoftOrgId", None)
lock_file_path: str = os.path.join(self.cache_path, f"{organization}.lock")
client = MicrosoftClient(
Expand All @@ -123,17 +130,27 @@ def _get_client(self, model: str) -> Client:
org_id=org_id,
)
elif organization == "google":
from helm.proxy.clients.google_client import GoogleClient

client = GoogleClient(cache_config=cache_config)
elif organization == "together":
from helm.proxy.clients.together_client import TogetherClient

client = TogetherClient(api_key=self.credentials.get("togetherApiKey", None), cache_config=cache_config)
elif organization == "simple":
from helm.proxy.clients.simple_client import SimpleClient

client = SimpleClient(cache_config=cache_config)
elif organization == "writer":
from helm.proxy.clients.palmyra_client import PalmyraClient

client = PalmyraClient(
api_key=self.credentials["writerApiKey"],
cache_config=cache_config,
)
elif organization == "nvidia":
from helm.proxy.clients.megatron_client import MegatronClient

client = MegatronClient(cache_config=cache_config)
else:
raise ValueError(f"Could not find client for model: {model}")
Expand Down Expand Up @@ -173,6 +190,8 @@ def _get_tokenizer_client(self, tokenizer: str) -> Client:
if client is None:
cache_config: CacheConfig = self._build_cache_config(organization)
if get_huggingface_model_config(tokenizer):
from helm.proxy.clients.huggingface_client import HuggingFaceClient

client = HuggingFaceClient(cache_config=cache_config)
elif organization in [
"bigscience",
Expand All @@ -185,30 +204,52 @@ def _get_tokenizer_client(self, tokenizer: str) -> Client:
"microsoft",
"hf-internal-testing",
]:
from helm.proxy.clients.huggingface_client import HuggingFaceClient

client = HuggingFaceClient(cache_config=cache_config)
elif organization == "openai":
from helm.proxy.clients.openai_client import OpenAIClient

client = OpenAIClient(
cache_config=cache_config,
)
elif organization == "AlephAlpha":
from helm.proxy.clients.aleph_alpha_client import AlephAlphaClient

client = AlephAlphaClient(api_key=self.credentials["alephAlphaKey"], cache_config=cache_config)
elif organization == "anthropic":
from helm.proxy.clients.anthropic_client import AnthropicClient

client = AnthropicClient(
api_key=self.credentials.get("anthropicApiKey", None), cache_config=cache_config
)
elif organization == "TsinghuaKEG":
from helm.proxy.clients.ice_tokenizer_client import ICETokenizerClient

client = ICETokenizerClient(cache_config=cache_config)
elif organization == "Yandex":
from helm.proxy.clients.yalm_tokenizer_client import YaLMTokenizerClient

client = YaLMTokenizerClient(cache_config=cache_config)
elif organization == "ai21":
from helm.proxy.clients.ai21_client import AI21Client

client = AI21Client(api_key=self.credentials["ai21ApiKey"], cache_config=cache_config)
elif organization == "cohere":
from helm.proxy.clients.cohere_client import CohereClient

client = CohereClient(api_key=self.credentials["cohereApiKey"], cache_config=cache_config)
elif organization == "simple":
from helm.proxy.clients.simple_client import SimpleClient

client = SimpleClient(cache_config=cache_config)
elif organization == "nvidia":
from helm.proxy.clients.megatron_client import MegatronClient

client = MegatronClient(cache_config=cache_config)
elif organization == "writer":
from helm.proxy.clients.palmyra_client import PalmyraClient

client = PalmyraClient(
api_key=self.credentials["writerApiKey"],
cache_config=cache_config,
Expand Down Expand Up @@ -250,42 +291,66 @@ def decode_with_retry(client: Client, request: DecodeRequest) -> DecodeRequestRe
hlog(retry_error)
return replace(last_attempt.value, error=f"{retry_error}. Error: {last_attempt.value.error}")

def get_toxicity_classifier_client(self) -> PerspectiveAPIClient:
def get_toxicity_classifier_client(self) -> ToxicityClassifierClient:
"""Get the toxicity classifier client. We currently only support Perspective API."""
from helm.proxy.clients.perspective_api_client import PerspectiveAPIClient

cache_config: CacheConfig = self._build_cache_config("perspectiveapi")
return PerspectiveAPIClient(self.credentials.get("perspectiveApiKey", ""), cache_config)

def get_critique_client(self) -> CritiqueClient:
"""Get the critique client."""
if self._critique_client:
return self._critique_client
critique_type = self.credentials.get("critiqueType")
if critique_type == "random":
self.critique_client = RandomCritiqueClient()
from helm.proxy.clients.critique_client import RandomCritiqueClient

self._critique_client = RandomCritiqueClient()
elif critique_type == "mturk":
self.critique_client = MechanicalTurkCritiqueClient()
from helm.proxy.clients.mechanical_turk_critique_client import MechanicalTurkCritiqueClient

self._critique_client = MechanicalTurkCritiqueClient()
elif critique_type == "surgeai":
from helm.proxy.clients.surge_ai_critique_client import SurgeAICritiqueClient

surgeai_credentials = self.credentials.get("surgeaiApiKey")
if not surgeai_credentials:
raise ValueError("surgeaiApiKey credentials are required for SurgeAICritiqueClient")
self.critique_client = SurgeAICritiqueClient(surgeai_credentials, self._build_cache_config("surgeai"))
self._critique_client = SurgeAICritiqueClient(surgeai_credentials, self._build_cache_config("surgeai"))
elif critique_type == "model":
from helm.proxy.clients.model_critique_client import ModelCritiqueClient

model_name: Optional[str] = self.credentials.get("critiqueModelName")
if model_name is None:
raise ValueError("critiqueModelName is required for ModelCritiqueClient")
client: Client = self._get_client(model_name)
self.critique_client = ModelCritiqueClient(client, model_name)
self._critique_client = ModelCritiqueClient(client, model_name)
elif critique_type == "scale":
from helm.proxy.clients.scale_critique_client import ScaleCritiqueClient

scale_credentials = self.credentials.get("scaleApiKey")
scale_project = self.credentials.get("scaleProject", None)
if not scale_project:
raise ValueError("scaleProject is required for ScaleCritiqueClient.")
if not scale_credentials:
raise ValueError("scaleApiKey is required for ScaleCritiqueClient")
self.critique_client = ScaleCritiqueClient(
self._critique_client = ScaleCritiqueClient(
scale_credentials, self._build_cache_config("scale"), scale_project
)
else:
raise ValueError(
"CritiqueClient is not configured; set critiqueType to 'mturk',"
"'mturk-sandbox', 'surgeai', 'scale' or 'random'"
)
return self.critique_client
return self._critique_client

def get_huggingface_client(self) -> "helm.proxy.clients.huggingface_client.HuggingFaceClient":
"""Get the Hugging Face client."""
from helm.proxy.clients.huggingface_client import HuggingFaceClient

if self._huggingface_client:
assert isinstance(self._huggingface_client, HuggingFaceClient)
return self._huggingface_client
self._huggingface_client = HuggingFaceClient(self._build_cache_config("huggingface"))
return self._huggingface_client
3 changes: 2 additions & 1 deletion src/helm/proxy/clients/perspective_api_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
from googleapiclient.errors import BatchError, HttpError
from googleapiclient.http import BatchHttpRequest
from httplib2 import HttpLib2Error
from helm.proxy.clients.toxicity_classifier_client import ToxicityClassifierClient
from helm.proxy.retry import NonRetriableException

from helm.common.cache import Cache, CacheConfig
Expand All @@ -18,7 +19,7 @@ class PerspectiveAPIClientCredentialsError(NonRetriableException):
pass


class PerspectiveAPIClient:
class PerspectiveAPIClient(ToxicityClassifierClient):
"""
Perspective API predicts the perceived impact a comment may have on a conversation by evaluating that comment
across a range of emotional concepts, called attributes. When you send a request to the API, you’ll request the
Expand Down
12 changes: 12 additions & 0 deletions src/helm/proxy/clients/toxicity_classifier_client.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,12 @@
from abc import ABC, abstractmethod

from helm.common.perspective_api_request import PerspectiveAPIRequest, PerspectiveAPIRequestResult


class ToxicityClassifierClient(ABC):
"""A client that gets toxicity attributes and scores"""

@abstractmethod
def get_toxicity_scores(self, request: PerspectiveAPIRequest) -> PerspectiveAPIRequestResult:
"""Get the toxicity attributes and scores for a batch of text."""
raise NotImplementedError()
12 changes: 6 additions & 6 deletions src/helm/proxy/services/server_service.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@
from helm.common.hierarchical_logger import hlog
from helm.proxy.accounts import Accounts, Account
from helm.proxy.clients.auto_client import AutoClient
from helm.proxy.clients.perspective_api_client import PerspectiveAPIClient
from helm.proxy.clients.toxicity_classifier_client import ToxicityClassifierClient
from helm.proxy.example_queries import example_queries
from helm.proxy.models import ALL_MODELS, get_model_group
from helm.proxy.query import Query, QueryResult
Expand Down Expand Up @@ -53,10 +53,10 @@ def __init__(self, base_path: str = ".", root_mode=False, mongo_uri: str = ""):
credentials = {}

self.client = AutoClient(credentials, cache_path, mongo_uri)
self.token_counter = AutoTokenCounter(self.client.huggingface_client)
self.token_counter = AutoTokenCounter(self.client.get_huggingface_client())
self.accounts = Accounts(accounts_path, root_mode=root_mode)
# Lazily instantiated by get_toxicity_scores()
self.perspective_api_client: Optional[PerspectiveAPIClient] = None
self.toxicity_classifier_client: Optional[ToxicityClassifierClient] = None

def get_general_info(self) -> GeneralInfo:
return GeneralInfo(version=VERSION, example_queries=example_queries, all_models=ALL_MODELS)
Expand Down Expand Up @@ -123,9 +123,9 @@ def decode(self, auth: Authentication, request: DecodeRequest) -> DecodeRequestR
def get_toxicity_scores(self, auth: Authentication, request: PerspectiveAPIRequest) -> PerspectiveAPIRequestResult:
@retry_request
def get_toxicity_scores_with_retry(request: PerspectiveAPIRequest) -> PerspectiveAPIRequestResult:
if not self.perspective_api_client:
self.perspective_api_client = self.client.get_toxicity_classifier_client()
return self.perspective_api_client.get_toxicity_scores(request)
if not self.toxicity_classifier_client:
self.toxicity_classifier_client = self.client.get_toxicity_classifier_client()
return self.toxicity_classifier_client.get_toxicity_scores(request)

self.accounts.authenticate(auth)
return get_toxicity_scores_with_retry(request)
Expand Down

0 comments on commit 2bb4d78

Please sign in to comment.