Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat(embeddings): Add text embedding feature #376

Merged
merged 12 commits into from
Nov 23, 2023
3 changes: 1 addition & 2 deletions camel/agents/chat_agent.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,7 +38,7 @@
OpenAIBackendRole,
RoleType,
)
from camel.utils import get_model_encoding, openai_api_key_required
from camel.utils import get_model_encoding


@dataclass(frozen=True)
Expand Down Expand Up @@ -257,7 +257,6 @@ def record_message(self, message: BaseMessage) -> None:
"""
self.update_memory(message, OpenAIBackendRole.ASSISTANT)

@openai_api_key_required
def step(
self,
input_message: BaseMessage,
Expand Down
20 changes: 20 additions & 0 deletions camel/embeddings/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,20 @@
# =========== Copyright 2023 @ CAMEL-AI.org. All Rights Reserved. ===========
# Licensed under the Apache License, Version 2.0 (the “License”);
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an “AS IS” BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# =========== Copyright 2023 @ CAMEL-AI.org. All Rights Reserved. ===========
from .base import BaseEmbedding
from .openai_embedding import OpenAIEmbedding

__all__ = [
"BaseEmbedding",
"OpenAIEmbedding",
]
63 changes: 63 additions & 0 deletions camel/embeddings/base.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,63 @@
# =========== Copyright 2023 @ CAMEL-AI.org. All Rights Reserved. ===========
# Licensed under the Apache License, Version 2.0 (the “License”);
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an “AS IS” BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# =========== Copyright 2023 @ CAMEL-AI.org. All Rights Reserved. ===========
from abc import ABC, abstractmethod
from typing import Any, List


class BaseEmbedding(ABC):
    r"""Abstract interface shared by all text embedding backends.

    Concrete subclasses implement :meth:`embed_texts` and
    :meth:`get_output_dim`; :meth:`embed_text` is derived from
    :meth:`embed_texts`.
    """

    @abstractmethod
    def embed_texts(
        self,
        texts: List[str],
        **kwargs: Any,
    ) -> List[List[float]]:
        r"""Generates embeddings for a batch of texts.

        Args:
            texts (List[str]): The texts for which to generate the embeddings.
            **kwargs (Any): Extra kwargs passed to the embedding API.

        Returns:
            List[List[float]]: A list of embeddings, each represented as a
                list of floating-point numbers.
        """
        ...

    def embed_text(
        self,
        text: str,
        **kwargs: Any,
    ) -> List[float]:
        r"""Generates the embedding of a single text.

        Args:
            text (str): The text for which to generate the embedding.
            **kwargs (Any): Extra kwargs passed to the embedding API.

        Returns:
            List[float]: A list of floating-point numbers representing the
                generated embedding.
        """
        # Delegate to the batch method with a one-element batch and unwrap
        # the only result.
        embeddings = self.embed_texts([text], **kwargs)
        return embeddings[0]

    @abstractmethod
    def get_output_dim(self) -> int:
        r"""Returns the output dimension of the embeddings.

        Returns:
            int: The dimensionality of the embedding for the current model.
        """
        ...
74 changes: 74 additions & 0 deletions camel/embeddings/openai_embedding.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,74 @@
# =========== Copyright 2023 @ CAMEL-AI.org. All Rights Reserved. ===========
# Licensed under the Apache License, Version 2.0 (the “License”);
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an “AS IS” BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# =========== Copyright 2023 @ CAMEL-AI.org. All Rights Reserved. ===========
from typing import Any, List

from openai import OpenAI

from camel.embeddings import BaseEmbedding
from camel.types import EmbeddingModelType
from camel.utils import openai_api_key_required


class OpenAIEmbedding(BaseEmbedding):
    r"""Provides text embedding functionalities using OpenAI's models.

    Args:
        model_type (EmbeddingModelType, optional): The model type to be used
            for generating embeddings.
            (default: :obj:`EmbeddingModelType.ADA2`)

    Raises:
        ValueError: If :obj:`model_type` is not an OpenAI embedding model.
    """

    def __init__(
        self,
        model_type: EmbeddingModelType = EmbeddingModelType.ADA2,
    ) -> None:
        if not model_type.is_openai:
            raise ValueError("Invalid OpenAI embedding model type.")
        self.model_type = model_type
        # Cache the dimensionality so `get_output_dim` is a plain lookup.
        self.output_dim = model_type.output_dim
        # NOTE(review): the client is created eagerly here; presumably it
        # resolves the API key from the environment at construction time,
        # i.e. before the `openai_api_key_required` check on `embed_texts`
        # ever runs -- confirm against the OpenAI client documentation.
        self.client = OpenAI()

    @openai_api_key_required
    def embed_texts(
        self,
        texts: List[str],
        **kwargs: Any,
    ) -> List[List[float]]:
        r"""Generates embeddings for the given texts.

        Args:
            texts (List[str]): The texts for which to generate the embeddings.
            **kwargs (Any): Extra kwargs passed to the embedding API.

        Returns:
            List[List[float]]: A list of embeddings, one entry per item of
                the API response, each a list of floating-point numbers.
                An empty :obj:`texts` yields an empty list.
        """
        # Nothing to embed: skip the API round trip entirely.
        if not texts:
            return []
        # TODO: count tokens
        response = self.client.embeddings.create(
            input=texts,
            model=self.model_type.value,
            **kwargs,
        )
        return [data.embedding for data in response.data]

    def get_output_dim(self) -> int:
        r"""Returns the output dimension of the embeddings.

        Returns:
            int: The dimensionality of the embedding for the current model.
        """
        return self.output_dim
7 changes: 6 additions & 1 deletion camel/models/openai_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,11 @@
from camel.messages import OpenAIMessage
from camel.models import BaseModelBackend
from camel.types import ChatCompletion, ChatCompletionChunk, ModelType
from camel.utils import BaseTokenCounter, OpenAITokenCounter
from camel.utils import (
BaseTokenCounter,
OpenAITokenCounter,
openai_api_key_required,
)


class OpenAIModel(BaseModelBackend):
Expand Down Expand Up @@ -53,6 +57,7 @@ def token_counter(self) -> BaseTokenCounter:
self._token_counter = OpenAITokenCounter(self.model_type)
return self._token_counter

@openai_api_key_required
def run(
self,
messages: List[OpenAIMessage],
Expand Down
2 changes: 2 additions & 0 deletions camel/types/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@
TaskType,
TerminationMode,
OpenAIBackendRole,
EmbeddingModelType,
VectorDistance,
)
from .openai_types import (
Expand All @@ -38,6 +39,7 @@
'TaskType',
'TerminationMode',
'OpenAIBackendRole',
'EmbeddingModelType',
'VectorDistance',
'Choice',
'ChatCompletion',
Expand Down
76 changes: 50 additions & 26 deletions camel/types/enums.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,38 +39,28 @@ class ModelType(Enum):

@property
def value_for_tiktoken(self) -> str:
    r"""Returns the model name to be used with tiktoken.

    Returns:
        str: This member's own value, except for :obj:`ModelType.STUB`,
            for which ``"gpt-3.5-turbo"`` is substituted.
    """
    # The test must be `is not`: only the stub model is replaced by the
    # fallback name, every other model reports its own value (same
    # behavior as the former `self.name != "STUB"` comparison).
    return self.value if self is not ModelType.STUB else "gpt-3.5-turbo"

@property
def is_openai(self) -> bool:
    r"""Returns whether this type of models is an OpenAI-released model.

    Returns:
        bool: Whether this type of models belongs to OpenAI.
    """
    openai_models = frozenset((
        ModelType.GPT_3_5_TURBO,
        ModelType.GPT_3_5_TURBO_16K,
        ModelType.GPT_4,
        ModelType.GPT_4_32K,
        ModelType.GPT_4_TURBO,
        ModelType.GPT_4_TURBO_VISION,
    ))
    return self in openai_models

@property
def is_open_source(self) -> bool:
    r"""Returns whether this type of models is open-source.

    Returns:
        bool: Whether this type of models is open-source.
    """
    open_source_models = frozenset((
        ModelType.LLAMA_2,
        ModelType.VICUNA,
        ModelType.VICUNA_16K,
    ))
    return self in open_source_models

@property
def token_limit(self) -> int:
Expand Down Expand Up @@ -123,6 +113,40 @@ def validate_model_name(self, model_name: str) -> bool:
return self.value in model_name.lower()


class EmbeddingModelType(Enum):
    r"""Identifiers of the supported text embedding models.

    Each member's value is the model identifier string.
    """
    ADA2 = "text-embedding-ada-002"
    ADA1 = "text-embedding-ada-001"
    BABBAGE1 = "text-embedding-babbage-001"
    CURIE1 = "text-embedding-curie-001"
    DAVINCI1 = "text-embedding-davinci-001"

    @property
    def is_openai(self) -> bool:
        r"""Returns whether this type of models is an OpenAI-released model.

        Returns:
            bool: Whether this member is one of the listed OpenAI models.
        """
        openai_models = frozenset((
            EmbeddingModelType.ADA1,
            EmbeddingModelType.ADA2,
            EmbeddingModelType.BABBAGE1,
            EmbeddingModelType.CURIE1,
            EmbeddingModelType.DAVINCI1,
        ))
        return self in openai_models

    @property
    def output_dim(self) -> int:
        r"""Returns the length of the embedding vectors of this model.

        Returns:
            int: The embedding dimensionality registered for this member.

        Raises:
            ValueError: If no dimensionality is registered for this member.
        """
        # Built inside the property: a dict assigned in the Enum body would
        # itself be turned into an enum member.
        dims = {
            EmbeddingModelType.ADA2: 1536,
            EmbeddingModelType.ADA1: 1024,
            EmbeddingModelType.BABBAGE1: 2048,
            EmbeddingModelType.CURIE1: 4096,
            EmbeddingModelType.DAVINCI1: 12288,
        }
        if self not in dims:
            raise ValueError(f"Unknown model type {self}.")
        return dims[self]


class TaskType(Enum):
AI_SOCIETY = "ai_society"
CODE = "code"
Expand Down
8 changes: 2 additions & 6 deletions camel/utils/commons.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,7 @@

import requests

from camel.types import ModelType, TaskType
from camel.types import TaskType

F = TypeVar('F', bound=Callable[..., Any])

Expand All @@ -55,11 +55,7 @@ def openai_api_key_required(func: F) -> F:

@wraps(func)
def wrapper(self, *args, **kwargs):
    # Guard clause: refuse to call the wrapped method when no key is
    # present in the environment.
    if 'OPENAI_API_KEY' not in os.environ:
        raise ValueError('OpenAI API key not found.')
    return func(self, *args, **kwargs)
Expand Down
23 changes: 23 additions & 0 deletions test/embeddings/test_openai_embedding.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,23 @@
# =========== Copyright 2023 @ CAMEL-AI.org. All Rights Reserved. ===========
# Licensed under the Apache License, Version 2.0 (the “License”);
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an “AS IS” BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# =========== Copyright 2023 @ CAMEL-AI.org. All Rights Reserved. ===========
import pytest

from camel.embeddings import BaseEmbedding, OpenAIEmbedding


# Parametrize with the class rather than an instance: the argument list is
# evaluated when pytest collects this module, so constructing
# ``OpenAIEmbedding`` there would create the OpenAI client at collection
# time and turn any construction failure into a module-wide collection error.
@pytest.mark.parametrize("embedding_cls", [OpenAIEmbedding])
def test_embedding(embedding_cls: type):
    r"""Checks that an embedded text has the advertised dimensionality."""
    embedding_model: BaseEmbedding = embedding_cls()
    text = "test embedding text."
    vector = embedding_model.embed_text(text)
    assert len(vector) == embedding_model.get_output_dim()
Loading