diff --git a/requirements.txt b/requirements.txt
index 840512e945b..b5dc75f2e04 100644
--- a/requirements.txt
+++ b/requirements.txt
@@ -50,6 +50,7 @@ google-api-core==2.10.1
 google-api-python-client==2.64.0
 google-auth==2.12.0
 google-auth-httplib2==0.1.0
+google-cloud-aiplatform==1.36.4
 googleapis-common-protos==1.56.4
 greenlet==1.1.3
 gunicorn==20.1.0
diff --git a/setup.cfg b/setup.cfg
index 5a06fa34331..53c90227d48 100644
--- a/setup.cfg
+++ b/setup.cfg
@@ -114,6 +114,9 @@ openai =
     openai~=0.27.8
     tiktoken~=0.3.3
 
+google =
+    google-cloud-aiplatform~=1.36.4
+
 tsinghua =
     icetk~=0.0.4
 
diff --git a/src/helm/benchmark/test_model_properties.py b/src/helm/benchmark/test_model_properties.py
index f1aaeffdd87..43d148ee6c3 100644
--- a/src/helm/benchmark/test_model_properties.py
+++ b/src/helm/benchmark/test_model_properties.py
@@ -149,6 +149,12 @@
         end_of_text_token="",
         prefix_token="",
     ),
+    TokenizerConfig(
+        name="google/mt5-base",
+        tokenizer_spec=TokenizerSpec(class_name="helm.proxy.tokenizers.huggingface_tokenizer.HuggingFaceTokenizer"),
+        end_of_text_token="",
+        prefix_token="",
+    ),
     TokenizerConfig(
         name="facebook/opt-66b",
         tokenizer_spec=TokenizerSpec(class_name="helm.proxy.tokenizers.huggingface_tokenizer.HuggingFaceTokenizer"),
@@ -785,6 +791,53 @@
         ),
         max_sequence_length=511,
     ),
+    ModelDeployment(
+        name="google/text-bison@001",
+        client_spec=ClientSpec(class_name="helm.proxy.clients.vertexai_client.VertexAIClient"),
+        tokenizer_name="google/mt5-base",
+        window_service_spec=WindowServiceSpec(
+            class_name="helm.benchmark.window_services.palm_window_service.PaLM2WindowService"
+        ),
+        max_sequence_length=8192,
+    ),
+    ModelDeployment(
+        name="google/text-bison-32k",
+        client_spec=ClientSpec(class_name="helm.proxy.clients.vertexai_client.VertexAIClient"),
+        tokenizer_name="google/mt5-base",
+        window_service_spec=WindowServiceSpec(
+            class_name="helm.benchmark.window_services.palm_window_service.PaLM232KWindowService"
+        ),
+        max_sequence_length=32000,
+        max_sequence_and_generated_tokens_length=32000,
+    ),
+    ModelDeployment(
+        name="google/text-unicorn@001",
+        client_spec=ClientSpec(class_name="helm.proxy.clients.vertexai_client.VertexAIClient"),
+        tokenizer_name="google/mt5-base",
+        window_service_spec=WindowServiceSpec(
+            class_name="helm.benchmark.window_services.palm_window_service.PaLM2WindowService"
+        ),
+        max_sequence_length=8192,
+    ),
+    ModelDeployment(
+        name="google/code-bison@001",
+        client_spec=ClientSpec(class_name="helm.proxy.clients.vertexai_client.VertexAIClient"),
+        tokenizer_name="google/mt5-base",
+        window_service_spec=WindowServiceSpec(
+            class_name="helm.benchmark.window_services.palm_window_service.CodeBisonWindowService"
+        ),
+        max_sequence_length=6144,
+    ),
+    ModelDeployment(
+        name="google/code-bison-32k",
+        client_spec=ClientSpec(class_name="helm.proxy.clients.vertexai_client.VertexAIClient"),
+        tokenizer_name="google/mt5-base",
+        window_service_spec=WindowServiceSpec(
+            class_name="helm.benchmark.window_services.palm_window_service.PaLM232KWindowService"
+        ),
+        max_sequence_length=32000,
+        max_sequence_and_generated_tokens_length=32000,
+    ),
     ModelDeployment(
         name="together/h3-2.7b",
         client_spec=ClientSpec(class_name="helm.proxy.clients.together_client.TogetherClient"),
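Reviewer note: the SDK is pinned as google-cloud-aiplatform==1.36.4 in requirements.txt and exposed through a new "google" extra in setup.cfg, so the optional dependency can be pulled in with pip install "crfm-helm[google]" (assuming the published package name crfm-helm) rather than installed by hand.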
diff --git a/src/helm/benchmark/window_services/palm_window_service.py b/src/helm/benchmark/window_services/palm_window_service.py
new file mode 100644
index 00000000000..9f20a00e895
--- /dev/null
+++ b/src/helm/benchmark/window_services/palm_window_service.py
@@ -0,0 +1,50 @@
+from .local_window_service import LocalWindowService
+from .tokenizer_service import TokenizerService
+
+
+class PaLM2WindowService(LocalWindowService):
+    def __init__(self, service: TokenizerService):
+        super().__init__(service)
+
+    @property
+    def tokenizer_name(self) -> str:
+        """This tokenizer is most likely not the correct one, but there is no official tokenizer.
+        See the comment in model_deployments.yaml for more info."""
+        # TODO #2083: Update this when the tokenizer is known.
+        return "google/mt5-base"
+
+    @property
+    def max_sequence_length(self) -> int:
+        return 8192
+
+    @property
+    def max_request_length(self) -> int:
+        return self.max_sequence_length
+
+    @property
+    def end_of_text_token(self) -> str:
+        # TODO #2083: Update this when the tokenizer is known.
+        # This is purely a guess based on T511bWindowService.
+        return ""
+
+    @property
+    def prefix_token(self) -> str:
+        # TODO #2083: Update this when the tokenizer is known.
+        # This is purely a guess based on T511bWindowService.
+        return ""
+
+
+class PaLM232KWindowService(PaLM2WindowService):
+    @property
+    def max_sequence_length(self) -> int:
+        return 32000
+
+    @property
+    def max_sequence_and_generated_tokens_length(self) -> int:
+        return self.max_request_length
+
+
+class CodeBisonWindowService(PaLM2WindowService):
+    @property
+    def max_sequence_length(self) -> int:
+        return 6144
diff --git a/src/helm/config/model_deployments.yaml b/src/helm/config/model_deployments.yaml
index 9fbd5148aa7..27dd9d3266d 100644
--- a/src/helm/config/model_deployments.yaml
+++ b/src/helm/config/model_deployments.yaml
@@ -366,6 +366,68 @@ model_deployments:
+  # Google
+
+  ## PaLM 2
+
+  - name: google/text-bison@001
+    model_name: google/text-bison@001
+    tokenizer_name: google/mt5-base
+    max_sequence_length: 8192
+    client_spec:
+      class_name: "helm.proxy.clients.vertexai_client.VertexAIClient"
+      args: {}
+    window_service_spec:
+      class_name: "helm.benchmark.window_services.palm_window_service.PaLM2WindowService"
+      args: {}
+
+  - name: google/text-bison-32k
+    model_name: google/text-bison-32k
+    tokenizer_name: google/mt5-base
+    max_sequence_length: 32000
+    max_sequence_and_generated_tokens_length: 32000
+    client_spec:
+      class_name: "helm.proxy.clients.vertexai_client.VertexAIClient"
+      args: {}
+    window_service_spec:
+      class_name: "helm.benchmark.window_services.palm_window_service.PaLM232KWindowService"
+      args: {}
+
+  - name: google/text-unicorn@001
+    model_name: google/text-unicorn@001
+    tokenizer_name: google/mt5-base
+    max_sequence_length: 8192
+    client_spec:
+      class_name: "helm.proxy.clients.vertexai_client.VertexAIClient"
+      args: {}
+    window_service_spec:
+      class_name: "helm.benchmark.window_services.palm_window_service.PaLM2WindowService"
+      args: {}
+
+  - name: google/code-bison@001
+    model_name: google/code-bison@001
+    tokenizer_name: google/mt5-base
+    max_sequence_length: 6144
+    client_spec:
+      class_name: "helm.proxy.clients.vertexai_client.VertexAIClient"
+      args: {}
+    window_service_spec:
+      class_name: "helm.benchmark.window_services.palm_window_service.CodeBisonWindowService"
+      args: {}
+
+  - name: google/code-bison-32k
+    model_name: google/code-bison-32k
+    tokenizer_name: google/mt5-base
+    max_sequence_length: 32000
+    max_sequence_and_generated_tokens_length: 32000
+    client_spec:
+      class_name: "helm.proxy.clients.vertexai_client.VertexAIClient"
+      args: {}
+    window_service_spec:
+      class_name: "helm.benchmark.window_services.palm_window_service.PaLM232KWindowService"
+      args: {}
+
+
   # HuggingFace
 
   ## Bigcode
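Reviewer note: every deployment above points at google/mt5-base, which is only a stand-in tokenizer (see TODO #2083), so the window services' token counts are approximations of whatever PaLM 2 actually uses. A quick way to inspect what the stand-in produces, assuming the transformers and sentencepiece packages are installed:

    # Illustrative only: tokenize a sample string with the stand-in tokenizer
    # that the PaLM 2 window services rely on for counting and truncation.
    from transformers import AutoTokenizer

    tokenizer = AutoTokenizer.from_pretrained("google/mt5-base")
    tokens = tokenizer.tokenize("The quick brown fox jumps over the lazy dog.")
    print(len(tokens), tokens)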
diff --git a/src/helm/config/model_metadata.yaml b/src/helm/config/model_metadata.yaml
index 78f4705f1bd..8b905e1d2ac 100644
--- a/src/helm/config/model_metadata.yaml
+++ b/src/helm/config/model_metadata.yaml
@@ -506,6 +506,46 @@ models:
     release_date: 2023-03-01 # was first announced on 2022-04 but remained private.
     tags: [] # TODO: add tags
 
+  - name: google/text-bison@001
+    display_name: PaLM-2 (Bison)
+    description: The best-value PaLM model. PaLM 2 (Pathways Language Model) is a Transformer-based model trained using a mixture of objectives that was evaluated on English and multilingual language, and reasoning tasks. ([report](https://arxiv.org/pdf/2305.10403.pdf))
+    creator_organization_name: Google
+    access: limited
+    release_date: 2023-06-07 # Source: https://cloud.google.com/vertex-ai/docs/generative-ai/model-reference/text#model_versions
+    tags: [TEXT_MODEL_TAG, FULL_FUNCTIONALITY_TEXT_MODEL_TAG]
+
+  - name: google/text-bison-32k
+    display_name: PaLM-2 (Bison)
+    description: The best-value PaLM model with a 32K context. PaLM 2 (Pathways Language Model) is a Transformer-based model trained using a mixture of objectives that was evaluated on English and multilingual language, and reasoning tasks. ([report](https://arxiv.org/pdf/2305.10403.pdf))
+    creator_organization_name: Google
+    access: limited
+    release_date: 2023-06-07 # Source: https://cloud.google.com/vertex-ai/docs/generative-ai/model-reference/text#model_versions
+    tags: [TEXT_MODEL_TAG, FULL_FUNCTIONALITY_TEXT_MODEL_TAG]
+
+  - name: google/text-unicorn@001
+    display_name: PaLM-2 (Unicorn)
+    description: The largest model in the PaLM family. PaLM 2 (Pathways Language Model) is a Transformer-based model trained using a mixture of objectives that was evaluated on English and multilingual language, and reasoning tasks. ([report](https://arxiv.org/pdf/2305.10403.pdf))
+    creator_organization_name: Google
+    access: limited
+    release_date: 2023-11-30 # Source: https://cloud.google.com/vertex-ai/docs/generative-ai/model-reference/text#model_versions
+    tags: [TEXT_MODEL_TAG, FULL_FUNCTIONALITY_TEXT_MODEL_TAG]
+
+  - name: google/code-bison@001
+    display_name: Codey PaLM-2 (Bison)
+    description: A model fine-tuned to generate code based on a natural language description of the desired code. PaLM 2 (Pathways Language Model) is a Transformer-based model trained using a mixture of objectives that was evaluated on English and multilingual language, and reasoning tasks. ([report](https://arxiv.org/pdf/2305.10403.pdf))
+    creator_organization_name: Google
+    access: limited
+    release_date: 2023-06-29 # Source: https://cloud.google.com/vertex-ai/docs/generative-ai/model-reference/code-generation#model_versions
+    tags: [CODE_MODEL_TAG]
+
+  - name: google/code-bison-32k
+    display_name: Codey PaLM-2 (Bison)
+    description: Codey with a 32K context. PaLM 2 (Pathways Language Model) is a Transformer-based model trained using a mixture of objectives that was evaluated on English and multilingual language, and reasoning tasks. ([report](https://arxiv.org/pdf/2305.10403.pdf))
+    creator_organization_name: Google
+    access: limited
+    release_date: 2023-06-29 # Source: https://cloud.google.com/vertex-ai/docs/generative-ai/model-reference/code-generation#model_versions
+    tags: [CODE_MODEL_TAG]
+
 
   # HazyResearch
diff --git a/src/helm/config/tokenizer_configs.yaml b/src/helm/config/tokenizer_configs.yaml
index c7c0d1446ec..75a323c7a0b 100644
--- a/src/helm/config/tokenizer_configs.yaml
+++ b/src/helm/config/tokenizer_configs.yaml
@@ -108,6 +108,11 @@ tokenizer_configs:
       class_name: "helm.proxy.tokenizers.huggingface_tokenizer.HuggingFaceTokenizer"
     end_of_text_token: ""
     prefix_token: ""
+  - name: google/mt5-base
+    tokenizer_spec:
+      class_name: "helm.proxy.tokenizers.huggingface_tokenizer.HuggingFaceTokenizer"
+    end_of_text_token: ""
+    prefix_token: ""
 
   # Hf-internal-testing
   - name: hf-internal-testing/llama-tokenizer
diff --git a/src/helm/proxy/clients/auto_client.py b/src/helm/proxy/clients/auto_client.py
index 8d5ddcce6e6..02e09bf3727 100644
--- a/src/helm/proxy/clients/auto_client.py
+++ b/src/helm/proxy/clients/auto_client.py
@@ -78,6 +78,8 @@ def _get_client(self, model_deployment_name: str) -> Client:
                     host_organization + "OrgId", None
                 ),  # OpenAI, GooseAI, Microsoft
                 "lock_file_path": lambda: os.path.join(self.cache_path, f"{host_organization}.lock"),  # Microsoft
+                "project_id": lambda: self.credentials.get(host_organization + "ProjectId", None),  # VertexAI
+                "location": lambda: self.credentials.get(host_organization + "Location", None),  # VertexAI
             },
         )
         client = create_object(client_spec)
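Reviewer note: with the auto_client.py change above, the two new constructor arguments are read from the credentials file, keyed by the deployment's host organization ("google" for these deployments) plus a suffix. Assuming the usual prod_env/credentials.conf layout, the entries would look as follows; the key names are inferred from host_organization + "ProjectId" / "Location", and the values are placeholders:

    googleProjectId: my-gcp-project
    googleLocation: us-central1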
diff --git a/src/helm/proxy/clients/vertexai_client.py b/src/helm/proxy/clients/vertexai_client.py
new file mode 100644
index 00000000000..bc3ca2852b8
--- /dev/null
+++ b/src/helm/proxy/clients/vertexai_client.py
@@ -0,0 +1,112 @@
+import requests
+from typing import List
+
+from helm.common.cache import CacheConfig
+from helm.common.optional_dependencies import handle_module_not_found_error
+from helm.common.request import wrap_request_time, Request, RequestResult, Sequence, Token
+from helm.common.tokenization_request import (
+    TokenizationRequest,
+    TokenizationRequestResult,
+)
+from helm.proxy.tokenizers.tokenizer import Tokenizer
+from .client import CachingClient, truncate_sequence
+
+try:
+    import vertexai
+    from vertexai.language_models import TextGenerationModel, TextGenerationResponse
+except ModuleNotFoundError as e:
+    handle_module_not_found_error(e, ["google"])
+
+
+class VertexAIClient(CachingClient):
+    def __init__(self, tokenizer: Tokenizer, cache_config: CacheConfig, project_id: str, location: str) -> None:
+        super().__init__(cache_config=cache_config)
+        self.project_id = project_id
+        self.location = location
+        self.tokenizer = tokenizer
+
+        vertexai.init(project=self.project_id, location=self.location)
+
+    def make_request(self, request: Request) -> RequestResult:
+        """Make a request."""
+        parameters = {
+            "temperature": request.temperature,
+            "max_output_tokens": request.max_tokens,
+            "top_k": request.top_k_per_token,
+            "top_p": request.top_p,
+            "stop_sequences": request.stop_sequences,
+            "candidate_count": request.num_completions,
+            # TODO #2084: Add support for these parameters.
+            # The parameters "echo", "frequency_penalty", and "presence_penalty" are supposed to be supported
+            # in an HTTP request (see https://cloud.google.com/vertex-ai/docs/generative-ai/model-reference/text),
+            # but they are not supported in the Python SDK:
+            # https://github.com/googleapis/python-aiplatform/blob/beae48f63e40ea171c3f1625164569e7311b8e5a/vertexai/language_models/_language_models.py#L968C1-L980C1
+            # "frequency_penalty": request.frequency_penalty,
+            # "presence_penalty": request.presence_penalty,
+            # "echo": request.echo_prompt,
+        }
+
+        completions: List[Sequence] = []
+        model_name: str = request.model_engine
+
+        try:
+
+            def do_it():
+                model = TextGenerationModel.from_pretrained(model_name)
+                response = model.predict(request.prompt, **parameters)
+                candidates: List[TextGenerationResponse] = response.candidates
+                response_dict = {
+                    "predictions": [{"text": completion.text} for completion in candidates],
+                }  # TODO: Extract more information from the response
+                return response_dict
+
+            # We need to include the engine's name to differentiate among requests made for different model
+            # engines, since the engine name is not included in the request itself.
+            # The same goes for the prompt.
+            cache_key = CachingClient.make_cache_key(
+                {
+                    "engine": request.model_engine,
+                    "prompt": request.prompt,
+                    **parameters,
+                },
+                request,
+            )
+
+            response, cached = self.cache.get(cache_key, wrap_request_time(do_it))
+        except (requests.exceptions.RequestException, AssertionError) as e:
+            error: str = f"VertexAIClient error: {e}"
+            return RequestResult(success=False, cached=False, error=error, completions=[], embedding=[])
+
+        for prediction in response["predictions"]:
+            response_text = prediction["text"]
+
+            # The Python SDK does not support echo.
+            # TODO #2084: Add support for echo.
+            text: str = request.prompt + response_text if request.echo_prompt else response_text
+
+            tokenization_result: TokenizationRequestResult = self.tokenizer.tokenize(
+                TokenizationRequest(text, tokenizer="google/mt5-base")
+            )
+
+            # TODO #2085: Add support for log probs.
+            # Once again, log probs seem to be supported by the API but not by the Python SDK.
+            # HTTP response body reference:
+            # https://cloud.google.com/vertex-ai/docs/generative-ai/model-reference/text#response_body
+            # Python SDK reference:
+            # https://github.com/googleapis/python-aiplatform/blob/beae48f63e40ea171c3f1625164569e7311b8e5a/vertexai/language_models/_language_models.py#L868
+            tokens: List[Token] = [
+                Token(text=str(text), logprob=0, top_logprobs={}) for text in tokenization_result.raw_tokens
+            ]
+
+            completion = Sequence(text=response_text, logprob=0, tokens=tokens)
+            sequence = truncate_sequence(completion, request, print_warning=True)
+            completions.append(sequence)
+
+        return RequestResult(
+            success=True,
+            cached=cached,
+            request_time=response["request_time"],
+            request_datetime=response["request_datetime"],
+            completions=completions,
+            embedding=[],
+        )
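Reviewer note: the sketch below is a minimal, illustrative way to smoke-test the Vertex AI wiring outside of HELM. It exercises the same SDK surface that VertexAIClient.make_request wraps (vertexai.init, TextGenerationModel.from_pretrained, model.predict, and the candidates/text fields of the response). The project and location values are placeholders, and the model is addressed by its engine name ("text-bison@001", without the "google/" prefix), matching the client's use of request.model_engine:

    # Illustrative smoke test for the Vertex AI SDK path used by VertexAIClient.
    # "my-gcp-project" and "us-central1" are placeholders; authenticate first,
    # e.g. with `gcloud auth application-default login`.
    import vertexai
    from vertexai.language_models import TextGenerationModel

    vertexai.init(project="my-gcp-project", location="us-central1")

    model = TextGenerationModel.from_pretrained("text-bison@001")
    response = model.predict(
        "Write one sentence about benchmark window services.",
        temperature=0.0,
        max_output_tokens=64,
        candidate_count=1,
    )
    for candidate in response.candidates:
        print(candidate.text)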