Bump version and fix dependency issue (#524)
Co-authored-by: Rajas Bansal <rajas@refuel.ai>
rajasbansal authored Aug 10, 2023
1 parent 10ff03a commit 3e2e35e
Showing 5 changed files with 39 additions and 15 deletions.
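All four model files below get the same treatment: provider-specific imports move from module scope into __init__, wrapped in a try/except that turns a missing optional dependency into an actionable install hint, and the imported module is stashed on the instance for later use. A minimal sketch of the pattern, using tiktoken as the example dependency (the ExampleLLM class is illustrative, not part of autolabel's API):

class ExampleLLM:
    def __init__(self, model_name: str = "gpt-3.5-turbo") -> None:
        # Defer the optional import until the class is actually used, so a
        # bare `pip install refuel-autolabel` succeeds without provider SDKs.
        try:
            import tiktoken
        except ImportError:
            raise ImportError(
                "tiktoken is required to use this LLM. Please install it with "
                "the following command: pip install 'refuel-autolabel[openai]'"
            )
        # Keep the module on the instance so other methods can reach it
        # without a module-scope import.
        self.tiktoken = tiktoken
        self.model_name = model_name

    def get_cost(self, prompt: str) -> int:
        # Reads the dependency through the instance attribute set in __init__.
        encoding = self.tiktoken.encoding_for_model(self.model_name)
        return len(encoding.encode(prompt))

The instance-attribute handoff is why get_cost in anthropic.py and openai.py switches from tokenizer.count_tokens(...) and tiktoken.encoding_for_model(...) to their self.-qualified forms.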
2 changes: 1 addition & 1 deletion pyproject.toml
@@ -4,7 +4,7 @@ build-backend = "setuptools.build_meta"
 
 [project]
 name = "refuel-autolabel"
-version = "0.0.11"
+version = "0.0.12"
 description = "Label, clean and enrich text datasets with LLMs"
 readme = "README.md"
 authors = [{ name = "Refuel.ai", email = "support@refuel.ai" }]
16 changes: 12 additions & 4 deletions src/autolabel/models/anthropic.py
@@ -1,6 +1,5 @@
 from typing import List, Optional
 
-from anthropic import tokenizer
 from autolabel.configs import AutolabelConfig
 from autolabel.models import BaseModel
 from autolabel.cache import BaseCache
@@ -29,7 +28,14 @@ class AnthropicLLM(BaseModel):
 
     def __init__(self, config: AutolabelConfig, cache: BaseCache = None) -> None:
         super().__init__(config, cache)
-        from langchain.chat_models import ChatAnthropic
+
+        try:
+            from langchain.chat_models import ChatAnthropic
+            from anthropic import tokenizer
+        except ImportError:
+            raise ImportError(
+                "anthropic is required to use the anthropic LLM. Please install it with the following command: pip install 'refuel-autolabel[anthropic]'"
+            )
 
         # populate model name
         self.model_name = config.model_name() or self.DEFAULT_MODEL
@@ -39,6 +45,8 @@ def __init__(self, config: AutolabelConfig, cache: BaseCache = None) -> None:
         # initialize LLM
         self.llm = ChatAnthropic(model=self.model_name, **self.model_params)
 
+        self.tokenizer = tokenizer
+
     def _label(self, prompts: List[str]) -> RefuelLLMResult:
         prompts = [[HumanMessage(content=prompt)] for prompt in prompts]
         try:
@@ -50,9 +58,9 @@ def _label(self, prompts: List[str]) -> RefuelLLMResult:
             return self._label_individually(prompts)
 
     def get_cost(self, prompt: str, label: Optional[str] = "") -> float:
-        num_prompt_toks = tokenizer.count_tokens(prompt)
+        num_prompt_toks = self.tokenizer.count_tokens(prompt)
         if label:
-            num_label_toks = tokenizer.count_tokens(label)
+            num_label_toks = self.tokenizer.count_tokens(label)
         else:
             # get an upper bound
             num_label_toks = self.model_params["max_tokens_to_sample"]
9 changes: 7 additions & 2 deletions src/autolabel/models/cohere.py
@@ -20,8 +20,13 @@ class CohereLLM(BaseModel):
 
     def __init__(self, config: AutolabelConfig, cache: BaseCache = None) -> None:
         super().__init__(config, cache)
-        import cohere
-        from langchain.llms import Cohere
+        try:
+            import cohere
+            from langchain.llms import Cohere
+        except ImportError:
+            raise ImportError(
+                "cohere is required to use the cohere LLM. Please install it with the following command: pip install 'refuel-autolabel[cohere]'"
+            )
 
         # populate model name
         self.model_name = config.model_name() or self.DEFAULT_MODEL
17 changes: 12 additions & 5 deletions src/autolabel/models/openai.py
@@ -1,7 +1,6 @@
 from functools import cached_property
 from typing import List, Optional
 import logging
-import tiktoken
 
 from autolabel.models import BaseModel
 from autolabel.configs import AutolabelConfig
@@ -84,8 +83,14 @@ def _engine(self) -> str:
 
     def __init__(self, config: AutolabelConfig, cache: BaseCache = None) -> None:
         super().__init__(config, cache)
-        from langchain.chat_models import ChatOpenAI
-        from langchain.llms import OpenAI
+        try:
+            from langchain.chat_models import ChatOpenAI
+            from langchain.llms import OpenAI
+            import tiktoken
+        except ImportError:
+            raise ImportError(
+                "openai is required to use the openai LLM. Please install it with the following command: pip install 'refuel-autolabel[openai]'"
+            )
 
         # populate model name
         self.model_name = config.model_name() or self.DEFAULT_MODEL
@@ -111,6 +116,8 @@ def __init__(self, config: AutolabelConfig, cache: BaseCache = None) -> None:
         }
         self.llm = OpenAI(model_name=self.model_name, **self.model_params)
 
+        self.tiktoken = tiktoken
+
     def _generate_logit_bias(self) -> None:
         """Generates logit bias for the labels specified in the config
@@ -122,7 +129,7 @@ def _generate_logit_bias(self) -> None:
                 "No labels specified in the config. Skipping logit bias generation."
             )
             return {}
-        encoding = tiktoken.encoding_for_model(self.model_name)
+        encoding = self.tiktoken.encoding_for_model(self.model_name)
         logit_bias = {}
         max_tokens = 0
         for label in self.config.labels_list():
@@ -149,7 +156,7 @@ def _label(self, prompts: List[str]) -> RefuelLLMResult:
             return self._label_individually(prompts)
 
     def get_cost(self, prompt: str, label: Optional[str] = "") -> float:
-        encoding = tiktoken.encoding_for_model(self.model_name)
+        encoding = self.tiktoken.encoding_for_model(self.model_name)
         num_prompt_toks = len(encoding.encode(prompt))
         if label:
             num_label_toks = len(encoding.encode(label))
10 changes: 7 additions & 3 deletions src/autolabel/models/palm.py
@@ -45,9 +45,13 @@ def __init__(
         cache: BaseCache = None,
     ) -> None:
         super().__init__(config, cache)
-
-        from langchain.chat_models import ChatVertexAI
-        from langchain.llms import VertexAI
+        try:
+            from langchain.chat_models import ChatVertexAI
+            from langchain.llms import VertexAI
+        except ImportError:
+            raise ImportError(
+                "palm is required to use the Palm LLM. Please install it with the following command: pip install 'refuel-autolabel[google]'"
+            )
 
         # populate model name
         self.model_name = config.model_name() or self.DEFAULT_MODEL
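Net effect: a plain pip install refuel-autolabel now succeeds without any provider SDK installed; each provider's dependency is required only when its model class is constructed, and the new error messages name the exact extra to install (e.g. pip install 'refuel-autolabel[anthropic]').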
