Skip to content

Commit

Permalink
Mask API key for AI21 LLM (langchain-ai#12418)
Browse files Browse the repository at this point in the history
- **Description:** Added masking of the API Key for AI21 LLM when
printed and improved the docstring for AI21 LLM.
- Updated the AI21 LLM to utilize SecretStr from pydantic to securely
manage API key.
- Made improvements in the docstring of AI21 LLM. It now mentions that
the API key can also be passed as a named parameter to the constructor.
    - Added unit tests.
  - **Issue:** langchain-ai#12165 
  - **Tag maintainer:** @eyurtsev

---------

Co-authored-by: Anirudh Gautam <anirudh@Anirudhs-Mac-mini.local>
  • Loading branch information
2 people authored and xieqihui committed Nov 21, 2023
1 parent 28eeb7c commit b2a642e
Show file tree
Hide file tree
Showing 4 changed files with 64 additions and 9 deletions.
19 changes: 11 additions & 8 deletions libs/langchain/langchain/llms/ai21.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,11 @@
from typing import Any, Dict, List, Optional
from typing import Any, Dict, List, Optional, cast

import requests

from langchain.callbacks.manager import CallbackManagerForLLMRun
from langchain.llms.base import LLM
from langchain.pydantic_v1 import BaseModel, Extra, root_validator
from langchain.utils import get_from_dict_or_env
from langchain.pydantic_v1 import BaseModel, Extra, SecretStr, root_validator
from langchain.utils import convert_to_secret_str, get_from_dict_or_env


class AI21PenaltyData(BaseModel):
Expand All @@ -23,13 +23,13 @@ class AI21(LLM):
"""AI21 large language models.
To use, you should have the environment variable ``AI21_API_KEY``
set with your API key.
set with your API key or pass it as a named parameter to the constructor.
Example:
.. code-block:: python
from langchain.llms import AI21
ai21 = AI21(model="j2-jumbo-instruct")
ai21 = AI21(ai21_api_key="my-api-key", model="j2-jumbo-instruct")
"""

model: str = "j2-jumbo-instruct"
Expand Down Expand Up @@ -62,7 +62,7 @@ class AI21(LLM):
logitBias: Optional[Dict[str, float]] = None
"""Adjust the probability of specific tokens being generated."""

ai21_api_key: Optional[str] = None
ai21_api_key: Optional[SecretStr] = None

stop: Optional[List[str]] = None

Expand All @@ -77,7 +77,9 @@ class Config:
@root_validator()
def validate_environment(cls, values: Dict) -> Dict:
"""Validate that api key exists in environment."""
ai21_api_key = get_from_dict_or_env(values, "ai21_api_key", "AI21_API_KEY")
ai21_api_key = convert_to_secret_str(
get_from_dict_or_env(values, "ai21_api_key", "AI21_API_KEY")
)
values["ai21_api_key"] = ai21_api_key
return values

Expand Down Expand Up @@ -141,9 +143,10 @@ def _call(
else:
base_url = "https://api.ai21.com/studio/v1"
params = {**self._default_params, **kwargs}
self.ai21_api_key = cast(SecretStr, self.ai21_api_key)
response = requests.post(
url=f"{base_url}/{self.model}/complete",
headers={"Authorization": f"Bearer {self.ai21_api_key}"},
headers={"Authorization": f"Bearer {self.ai21_api_key.get_secret_value()}"},
json={"prompt": prompt, "stopSequences": stop, **params},
)
if response.status_code != 200:
Expand Down
2 changes: 2 additions & 0 deletions libs/langchain/langchain/utils/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@
from langchain.utils.strings import comma_list, stringify_dict, stringify_value
from langchain.utils.utils import (
check_package_version,
convert_to_secret_str,
get_pydantic_field_names,
guard_import,
mock_now,
Expand All @@ -27,6 +28,7 @@
"StrictFormatter",
"check_package_version",
"comma_list",
"convert_to_secret_str",
"cosine_similarity",
"cosine_similarity_top_k",
"formatter",
Expand Down
11 changes: 10 additions & 1 deletion libs/langchain/langchain/utils/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,11 +5,13 @@
import importlib
import warnings
from importlib.metadata import version
from typing import Any, Callable, Dict, Optional, Set, Tuple
from typing import Any, Callable, Dict, Optional, Set, Tuple, Union

from packaging.version import parse
from requests import HTTPError, Response

from langchain.pydantic_v1 import SecretStr


def xor_args(*arg_groups: Tuple[str, ...]) -> Callable:
"""Validate specified keyword args are mutually exclusive."""
Expand Down Expand Up @@ -169,3 +171,10 @@ def build_extra_kwargs(
)

return extra_kwargs


def convert_to_secret_str(value: Union[SecretStr, str]) -> SecretStr:
"""Convert a string to a SecretStr if needed."""
if isinstance(value, SecretStr):
return value
return SecretStr(value)
41 changes: 41 additions & 0 deletions libs/langchain/tests/unit_tests/llms/test_ai21.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,41 @@
"""Test AI21 llm"""
from typing import cast

from pytest import CaptureFixture, MonkeyPatch

from langchain.llms.ai21 import AI21
from langchain.pydantic_v1 import SecretStr


def test_api_key_is_secret_string() -> None:
llm = AI21(ai21_api_key="secret-api-key")
assert isinstance(llm.ai21_api_key, SecretStr)


def test_api_key_masked_when_passed_from_env(
monkeypatch: MonkeyPatch, capsys: CaptureFixture
) -> None:
"""Test initialization with an API key provided via an env variable"""
monkeypatch.setenv("AI21_API_KEY", "secret-api-key")
llm = AI21()
print(llm.ai21_api_key, end="")
captured = capsys.readouterr()

assert captured.out == "**********"


def test_api_key_masked_when_passed_via_constructor(
capsys: CaptureFixture,
) -> None:
"""Test initialization with an API key provided via the initializer"""
llm = AI21(ai21_api_key="secret-api-key")
print(llm.ai21_api_key, end="")
captured = capsys.readouterr()

assert captured.out == "**********"


def test_uses_actual_secret_value_from_secretstr() -> None:
"""Test that actual secret is retrieved using `.get_secret_value()`."""
llm = AI21(ai21_api_key="secret-api-key")
assert cast(SecretStr, llm.ai21_api_key).get_secret_value() == "secret-api-key"

0 comments on commit b2a642e

Please sign in to comment.