Commit cc067bc
add more Yi models, remove extra parameters, fix tokenizer
MuggleJinx committed Oct 30, 2024
1 parent 638b971 commit cc067bc
Showing 6 changed files with 18 additions and 57 deletions.
27 changes: 1 addition & 26 deletions camel/configs/yi_config.py
@@ -13,12 +13,9 @@
 # =========== Copyright 2023 @ CAMEL-AI.org. All Rights Reserved. ===========
 from __future__ import annotations
 
-from typing import Optional, Sequence, Union
-
-from pydantic import Field
+from typing import Optional, Union
 
 from camel.configs.base_config import BaseConfig
-from camel.types import NOT_GIVEN, NotGiven
 
 
 class YiConfig(BaseConfig):
@@ -48,35 +45,13 @@ class YiConfig(BaseConfig):
             while higher values make it more diverse. (default: :obj:`0.3`)
         stream (bool, optional): If True, enables streaming output.
             (default: :obj:`False`)
-        stop (Union[str, Sequence[str], NotGiven], optional): Up to `4`
-            sequences where the API will stop generating further tokens.
-            (default: :obj:`NOT_GIVEN`)
-        presence_penalty (float, optional): Number between :obj:`-2.0` and
-            :obj:`2.0`. Positive values penalize new tokens based on whether
-            they appear in the text so far, increasing the model's likelihood
-            to talk about new topics. (default: :obj:`0.0`)
-        frequency_penalty (float, optional): Number between :obj:`-2.0` and
-            :obj:`2.0`. Positive values penalize new tokens based on their
-            existing frequency in the text so far, decreasing the model's
-            likelihood to repeat the same line verbatim. (default: :obj:`0.0`)
-        logit_bias (dict, optional): Modify the likelihood of specified tokens
-            appearing in the completion. Accepts a json object that maps tokens
-            (specified by their token ID in the tokenizer) to an associated
-            bias value from :obj:`-100` to :obj:`100`. (default: :obj:`{}`)
-        user (str, optional): A unique identifier representing your end-user,
-            which can help monitor and detect abuse. (default: :obj:`""`)
     """
 
     tool_choice: Optional[Union[dict[str, str], str]] = None
     max_tokens: Optional[int] = 5000
     top_p: float = 0.9
     temperature: float = 0.3
     stream: bool = False
-    stop: Union[str, Sequence[str], NotGiven] = NOT_GIVEN
-    presence_penalty: float = 0.0
-    frequency_penalty: float = 0.0
-    logit_bias: dict = Field(default_factory=dict)
-    user: str = ""
 
 
 YI_API_PARAMS = {param for param in YiConfig.model_fields.keys()}
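
With the OpenAI-only parameters gone, the surviving config surface is small. A minimal sketch of the trimmed config in use; the import paths follow this diff, `as_dict()` is the accessor the test file below uses, and the printed contents are an assumption (inherited `BaseConfig` fields such as `tools` may also appear):

from camel.configs import YiConfig
from camel.configs.yi_config import YI_API_PARAMS

config = YiConfig(temperature=0.5)
print(config.as_dict())
# Expected to contain the five surviving fields, e.g.:
# {'tool_choice': None, 'max_tokens': 5000, 'top_p': 0.9,
#  'temperature': 0.5, 'stream': False, ...}

# YI_API_PARAMS is built from the model fields, so the removed
# parameters (stop, presence_penalty, logit_bias, ...) can no
# longer leak into requests sent to the Yi endpoint.
print(YI_API_PARAMS)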
4 changes: 2 additions & 2 deletions camel/models/yi_model.py
@@ -27,7 +27,7 @@
 )
 from camel.utils import (
     BaseTokenCounter,
-    YiTokenCounter,
+    OpenAITokenCounter,
     api_keys_required,
 )
 
@@ -109,7 +109,7 @@ def token_counter(self) -> BaseTokenCounter:
         """
 
         if not self._token_counter:
-            self._token_counter = YiTokenCounter(self.model_type)
+            self._token_counter = OpenAITokenCounter(ModelType.GPT_4O_MINI)
         return self._token_counter
 
     def check_model_config(self):
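
Yi does not ship an official tokenizer, so counts now come from tiktoken via `OpenAITokenCounter` pinned to `GPT_4O_MINI`. A minimal sketch of what the property builds on first access; the message dict follows the OpenAI chat format assumed by the removed `count_tokens_from_messages`:

from camel.types import ModelType
from camel.utils import OpenAITokenCounter

# Mirrors what YiModel.token_counter now constructs internally.
counter = OpenAITokenCounter(ModelType.GPT_4O_MINI)

messages = [{"role": "user", "content": "Hello, Yi!"}]
# GPT-4o-mini's encoding is a stand-in here, so treat the result as an
# approximation of the Yi backend's true token count.
print(counter.count_tokens_from_messages(messages))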
16 changes: 15 additions & 1 deletion camel/types/enums.py
@@ -92,7 +92,11 @@ class ModelType(UnifiedModelType, Enum):
     YI_LARGE = "yi-large"
     YI_MEDIUM = "yi-medium"
     YI_LARGE_TURBO = "yi-large-turbo"
-
+    YI_VISION = "yi-vision"
+    YI_MEDIUM_200K = "yi-medium-200k"
+    YI_SPARK = "yi-spark"
+    YI_LARGE_RAG = "yi-large-rag"
+    YI_LARGE_FC = "yi-large-fc"
 
     def __str__(self):
         return self.value
@@ -239,6 +243,11 @@ def is_yi(self) -> bool:
             ModelType.YI_LARGE,
             ModelType.YI_MEDIUM,
             ModelType.YI_LARGE_TURBO,
+            ModelType.YI_VISION,
+            ModelType.YI_MEDIUM_200K,
+            ModelType.YI_SPARK,
+            ModelType.YI_LARGE_RAG,
+            ModelType.YI_LARGE_FC,
         }
 
     @property
@@ -273,6 +282,9 @@ def token_limit(self) -> int:
             ModelType.YI_LIGHTNING,
             ModelType.YI_MEDIUM,
             ModelType.YI_LARGE_TURBO,
+            ModelType.YI_VISION,
+            ModelType.YI_SPARK,
+            ModelType.YI_LARGE_RAG,
         }:
             return 16_384
         elif self in {
@@ -281,6 +293,7 @@ def token_limit(self) -> int:
             ModelType.MISTRAL_MIXTRAL_8x7B,
             ModelType.GROQ_MIXTRAL_8_7B,
             ModelType.YI_LARGE,
+            ModelType.YI_LARGE_FC,
         }:
             return 32_768
         elif self in {ModelType.MISTRAL_MIXTRAL_8x22B}:
@@ -319,6 +332,7 @@ def token_limit(self) -> int:
             return 200_000
         elif self in {
             ModelType.MISTRAL_CODESTRAL_MAMBA,
+            ModelType.YI_MEDIUM_200K,
         }:
             return 256_000
         elif self in {
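
The five new members flow through the existing `is_yi` and `token_limit` properties, so they behave like the earlier Yi entries. A quick sketch of the lookups this diff enables, with expected values read off the hunks above:

from camel.types import ModelType

for m in (ModelType.YI_VISION, ModelType.YI_MEDIUM_200K, ModelType.YI_LARGE_FC):
    print(m.value, m.is_yi, m.token_limit)
# yi-vision       True 16384
# yi-medium-200k  True 256000  (grouped with MISTRAL_CODESTRAL_MAMBA here)
# yi-large-fc     True 32768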
1 change: 0 additions & 1 deletion camel/utils/__init__.py
@@ -43,7 +43,6 @@
     LiteLLMTokenCounter,
     MistralTokenCounter,
     OpenAITokenCounter,
-    YiTokenCounter,
     get_model_encoding,
 )
 
25 changes: 0 additions & 25 deletions camel/utils/token_counting.py
@@ -405,28 +405,3 @@ def _convert_response_from_openai_to_mistral(
         )
 
         return mistral_request
-
-
-# The API does not provide official token counting for Yi models, using the default OpenAI tokenizer.
-class YiTokenCounter(BaseTokenCounter):
-    def __init__(self, model_type: UnifiedModelType):
-        r"""Constructor for the token counter for Yi models.
-
-        Args:
-            model_type (UnifiedModelType): Model type for which tokens will be
-                counted.
-        """
-        self._internal_tokenizer = OpenAITokenCounter(model_type)
-
-    def count_tokens_from_messages(self, messages: List[OpenAIMessage]) -> int:
-        r"""Count number of tokens in the provided message list using
-        the tokenizer specific to this type of model.
-
-        Args:
-            messages (List[OpenAIMessage]): Message list with the chat history
-                in OpenAI API format.
-
-        Returns:
-            int: Number of tokens in the messages.
-        """
-        return self._internal_tokenizer.count_tokens_from_messages(messages)
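
Because the deleted class was a pure pass-through, any external caller that imported it can migrate with a one-line swap. A hedged sketch, mirroring what YiModel.token_counter now does rather than an officially documented replacement:

from camel.types import ModelType
from camel.utils import OpenAITokenCounter

# Before this commit:  counter = YiTokenCounter(model_type)
# After it:
counter = OpenAITokenCounter(ModelType.GPT_4O_MINI)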
2 changes: 0 additions & 2 deletions test/models/test_yi_model.py
@@ -19,7 +19,6 @@
 from camel.configs import YiConfig
 from camel.models import YiModel
 from camel.types import ModelType
-from camel.utils import YiTokenCounter
 
 
 @pytest.mark.model_backend
@@ -36,7 +35,6 @@ def test_yi_model(model_type: ModelType):
     model = YiModel(model_type)
     assert model.model_type == model_type
     assert model.model_config_dict == YiConfig().as_dict()
-    assert isinstance(model.token_counter, YiTokenCounter)
     assert isinstance(model.model_type.value_for_tiktoken, str)
     assert isinstance(model.model_type.token_limit, int)

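If the test should still pin the counter's concrete type, the dropped assertion has a direct replacement (assuming `OpenAITokenCounter` stays exported from camel.utils, as the earlier hunks show):

from camel.models import YiModel
from camel.types import ModelType
from camel.utils import OpenAITokenCounter

model = YiModel(ModelType.YI_LIGHTNING)  # assumes YI_API_KEY is set
assert isinstance(model.token_counter, OpenAITokenCounter)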
