diff --git a/libs/langchain/langchain/llms/gooseai.py b/libs/langchain/langchain/llms/gooseai.py index 831cff9a54ab0..30947ddb1b48f 100644 --- a/libs/langchain/langchain/llms/gooseai.py +++ b/libs/langchain/langchain/llms/gooseai.py @@ -1,14 +1,21 @@ import logging -from typing import Any, Dict, List, Mapping, Optional +from typing import Any, Dict, List, Mapping, Optional, Union from langchain.callbacks.manager import CallbackManagerForLLMRun from langchain.llms.base import LLM -from langchain.pydantic_v1 import Extra, Field, root_validator +from langchain.pydantic_v1 import Extra, Field, SecretStr, root_validator from langchain.utils import get_from_dict_or_env logger = logging.getLogger(__name__) +def _to_secret(value: Union[SecretStr, str]) -> SecretStr: + """Convert a string to a SecretStr if needed.""" + if isinstance(value, SecretStr): + return value + return SecretStr(value) + + class GooseAI(LLM): """GooseAI large language models. @@ -60,7 +67,7 @@ class GooseAI(LLM): logit_bias: Optional[Dict[str, float]] = Field(default_factory=dict) """Adjust the probability of specific tokens being generated.""" - gooseai_api_key: Optional[str] = None + gooseai_api_key: Optional[SecretStr] = None class Config: """Configuration for this pydantic config.""" @@ -89,13 +96,14 @@ def build_extra(cls, values: Dict[str, Any]) -> Dict[str, Any]: @root_validator() def validate_environment(cls, values: Dict) -> Dict: """Validate that api key and python package exists in environment.""" - gooseai_api_key = get_from_dict_or_env( - values, "gooseai_api_key", "GOOSEAI_API_KEY" + gooseai_api_key = _to_secret( + get_from_dict_or_env(values, "gooseai_api_key", "GOOSEAI_API_KEY") ) + values["gooseai_api_key"] = gooseai_api_key try: import openai - openai.api_key = gooseai_api_key + openai.api_key = gooseai_api_key.get_secret_value() openai.api_base = "https://api.goose.ai/v1" values["client"] = openai.Completion except ImportError: diff --git a/libs/langchain/tests/unit_tests/llms/test_gooseai.py b/libs/langchain/tests/unit_tests/llms/test_gooseai.py new file mode 100644 index 0000000000000..db7a25ee9f92f --- /dev/null +++ b/libs/langchain/tests/unit_tests/llms/test_gooseai.py @@ -0,0 +1,32 @@ +"""Test GooseAI""" + +import pytest +from pytest import MonkeyPatch + +from langchain.llms.gooseai import GooseAI +from langchain.pydantic_v1 import SecretStr + + +@pytest.mark.requires("openai") +def test_api_key_is_secret_string() -> None: + llm = GooseAI(gooseai_api_key="secret-api-key") + assert isinstance(llm.gooseai_api_key, SecretStr) + assert llm.gooseai_api_key.get_secret_value() == "secret-api-key" + + +@pytest.mark.requires("openai") +def test_api_key_masked_when_passed_via_constructor() -> None: + llm = GooseAI(gooseai_api_key="secret-api-key") + assert str(llm.gooseai_api_key) == "**********" + assert "secret-api-key" not in repr(llm.gooseai_api_key) + assert "secret-api-key" not in repr(llm) + + +@pytest.mark.requires("openai") +def test_api_key_masked_when_passed_from_env() -> None: + with MonkeyPatch.context() as mp: + mp.setenv("GOOSEAI_API_KEY", "secret-api-key") + llm = GooseAI() + assert str(llm.gooseai_api_key) == "**********" + assert "secret-api-key" not in repr(llm.gooseai_api_key) + assert "secret-api-key" not in repr(llm)