feat: ensuring that max_context_tokens is never larger than what supported by models #3519

Merged: 3 commits, Dec 11, 2024
Changes from all commits
133 changes: 100 additions & 33 deletions core/quivr_core/rag/entities/config.py
@@ -75,89 +75,139 @@ class DefaultModelSuppliers(str, Enum):
 
 
 class LLMConfig(QuivrBaseConfig):
-    context: int | None = None
+    max_context_tokens: int | None = None
+    max_output_tokens: int | None = None
     tokenizer_hub: str | None = None
 
 
 class LLMModelConfig:
     _model_defaults: Dict[DefaultModelSuppliers, Dict[str, LLMConfig]] = {
         DefaultModelSuppliers.OPENAI: {
-            "gpt-4o": LLMConfig(context=128000, tokenizer_hub="Xenova/gpt-4o"),
-            "gpt-4o-mini": LLMConfig(context=128000, tokenizer_hub="Xenova/gpt-4o"),
-            "gpt-4-turbo": LLMConfig(context=128000, tokenizer_hub="Xenova/gpt-4"),
-            "gpt-4": LLMConfig(context=8192, tokenizer_hub="Xenova/gpt-4"),
+            "gpt-4o": LLMConfig(
+                max_context_tokens=128000,
+                max_output_tokens=16384,
+                tokenizer_hub="Xenova/gpt-4o",
+            ),
+            "gpt-4o-mini": LLMConfig(
+                max_context_tokens=128000,
+                max_output_tokens=16384,
+                tokenizer_hub="Xenova/gpt-4o",
+            ),
+            "gpt-4-turbo": LLMConfig(
+                max_context_tokens=128000,
+                max_output_tokens=4096,
+                tokenizer_hub="Xenova/gpt-4",
+            ),
+            "gpt-4": LLMConfig(
+                max_context_tokens=8192,
+                max_output_tokens=8192,
+                tokenizer_hub="Xenova/gpt-4",
+            ),
             "gpt-3.5-turbo": LLMConfig(
-                context=16385, tokenizer_hub="Xenova/gpt-3.5-turbo"
+                max_context_tokens=16385,
+                max_output_tokens=4096,
+                tokenizer_hub="Xenova/gpt-3.5-turbo",
             ),
             "text-embedding-3-large": LLMConfig(
-                context=8191, tokenizer_hub="Xenova/text-embedding-ada-002"
+                max_context_tokens=8191, tokenizer_hub="Xenova/text-embedding-ada-002"
             ),
             "text-embedding-3-small": LLMConfig(
-                context=8191, tokenizer_hub="Xenova/text-embedding-ada-002"
+                max_context_tokens=8191, tokenizer_hub="Xenova/text-embedding-ada-002"
             ),
             "text-embedding-ada-002": LLMConfig(
-                context=8191, tokenizer_hub="Xenova/text-embedding-ada-002"
+                max_context_tokens=8191, tokenizer_hub="Xenova/text-embedding-ada-002"
             ),
         },
         DefaultModelSuppliers.ANTHROPIC: {
             "claude-3-5-sonnet": LLMConfig(
-                context=200000, tokenizer_hub="Xenova/claude-tokenizer"
+                max_context_tokens=200000,
+                max_output_tokens=8192,
+                tokenizer_hub="Xenova/claude-tokenizer",
             ),
             "claude-3-opus": LLMConfig(
-                context=200000, tokenizer_hub="Xenova/claude-tokenizer"
+                max_context_tokens=200000,
+                max_output_tokens=4096,
+                tokenizer_hub="Xenova/claude-tokenizer",
             ),
             "claude-3-sonnet": LLMConfig(
-                context=200000, tokenizer_hub="Xenova/claude-tokenizer"
+                max_context_tokens=200000,
+                max_output_tokens=4096,
+                tokenizer_hub="Xenova/claude-tokenizer",
             ),
             "claude-3-haiku": LLMConfig(
-                context=200000, tokenizer_hub="Xenova/claude-tokenizer"
+                max_context_tokens=200000,
+                max_output_tokens=4096,
+                tokenizer_hub="Xenova/claude-tokenizer",
             ),
             "claude-2-1": LLMConfig(
-                context=200000, tokenizer_hub="Xenova/claude-tokenizer"
+                max_context_tokens=200000,
+                max_output_tokens=4096,
+                tokenizer_hub="Xenova/claude-tokenizer",
             ),
             "claude-2-0": LLMConfig(
-                context=100000, tokenizer_hub="Xenova/claude-tokenizer"
+                max_context_tokens=100000,
+                max_output_tokens=4096,
+                tokenizer_hub="Xenova/claude-tokenizer",
            ),
             "claude-instant-1-2": LLMConfig(
-                context=100000, tokenizer_hub="Xenova/claude-tokenizer"
+                max_context_tokens=100000,
+                max_output_tokens=4096,
+                tokenizer_hub="Xenova/claude-tokenizer",
             ),
         },
         # Unclear for LLAMA models...
         # see https://huggingface.co/meta-llama/Llama-3.1-405B-Instruct/discussions/6
         DefaultModelSuppliers.META: {
             "llama-3.1": LLMConfig(
-                context=128000, tokenizer_hub="Xenova/Meta-Llama-3.1-Tokenizer"
+                max_context_tokens=128000,
+                max_output_tokens=4096,
+                tokenizer_hub="Xenova/Meta-Llama-3.1-Tokenizer",
             ),
             "llama-3": LLMConfig(
-                context=8192, tokenizer_hub="Xenova/llama3-tokenizer-new"
+                max_context_tokens=8192,
+                max_output_tokens=2048,
+                tokenizer_hub="Xenova/llama3-tokenizer-new",
             ),
             "llama-2": LLMConfig(context=4096, tokenizer_hub="Xenova/llama2-tokenizer"),
             "code-llama": LLMConfig(
-                context=16384, tokenizer_hub="Xenova/llama-code-tokenizer"
+                max_context_tokens=16384, tokenizer_hub="Xenova/llama-code-tokenizer"
             ),
         },
         DefaultModelSuppliers.GROQ: {
-            "llama-3.1": LLMConfig(
-                context=128000, tokenizer_hub="Xenova/Meta-Llama-3.1-Tokenizer"
+            "llama-3.3-70b": LLMConfig(
+                max_context_tokens=128000,
+                max_output_tokens=32768,
+                tokenizer_hub="Xenova/Meta-Llama-3.1-Tokenizer",
+            ),
+            "llama-3.1-70b": LLMConfig(
+                max_context_tokens=128000,
+                max_output_tokens=32768,
+                tokenizer_hub="Xenova/Meta-Llama-3.1-Tokenizer",
             ),
             "llama-3": LLMConfig(
-                context=8192, tokenizer_hub="Xenova/llama3-tokenizer-new"
+                max_context_tokens=8192, tokenizer_hub="Xenova/llama3-tokenizer-new"
             ),
             "llama-2": LLMConfig(context=4096, tokenizer_hub="Xenova/llama2-tokenizer"),
             "code-llama": LLMConfig(
-                context=16384, tokenizer_hub="Xenova/llama-code-tokenizer"
+                max_context_tokens=16384, tokenizer_hub="Xenova/llama-code-tokenizer"
             ),
         },
         DefaultModelSuppliers.MISTRAL: {
             "mistral-large": LLMConfig(
-                context=128000, tokenizer_hub="Xenova/mistral-tokenizer-v3"
+                max_context_tokens=128000,
+                max_output_tokens=4096,
+                tokenizer_hub="Xenova/mistral-tokenizer-v3",
             ),
             "mistral-small": LLMConfig(
-                context=128000, tokenizer_hub="Xenova/mistral-tokenizer-v3"
+                max_context_tokens=128000,
+                max_output_tokens=4096,
+                tokenizer_hub="Xenova/mistral-tokenizer-v3",
             ),
             "mistral-nemo": LLMConfig(
-                context=128000, tokenizer_hub="Xenova/Mistral-Nemo-Instruct-Tokenizer"
+                max_context_tokens=128000,
+                max_output_tokens=4096,
+                tokenizer_hub="Xenova/Mistral-Nemo-Instruct-Tokenizer",
            ),
             "codestral": LLMConfig(
-                context=32000, tokenizer_hub="Xenova/mistral-tokenizer-v3"
+                max_context_tokens=32000, tokenizer_hub="Xenova/mistral-tokenizer-v3"
             ),
         },
     }
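
The table above gives each known model a max_context_tokens / max_output_tokens ceiling, keyed by supplier and model-name prefix. The diff collapses the body of LLMModelConfig.get_llm_model_config, which resolves a concrete model string to one of these entries (it is called further down as LLMModelConfig.get_llm_model_config(self.supplier, self.model)). The following is only a rough, self-contained sketch of that kind of prefix lookup with simplified stand-in types; it is not the actual quivr_core implementation, and the longest-prefix rule is an assumption for illustration.

from dataclasses import dataclass
from typing import Dict, Optional


@dataclass
class LLMConfig:
    # Simplified stand-in for quivr_core's pydantic LLMConfig.
    max_context_tokens: Optional[int] = None
    max_output_tokens: Optional[int] = None
    tokenizer_hub: Optional[str] = None


def lookup_model_config(defaults: Dict[str, LLMConfig], model: str) -> Optional[LLMConfig]:
    # Exact key first, then the longest key that prefixes the model name,
    # so "gpt-4o-mini-2024-07-18" resolves to "gpt-4o-mini" rather than "gpt-4o".
    if model in defaults:
        return defaults[model]
    for key in sorted(defaults, key=len, reverse=True):
        if model.startswith(key):
            return defaults[key]
    return None


openai_defaults = {
    "gpt-4o": LLMConfig(128000, 16384, "Xenova/gpt-4o"),
    "gpt-4o-mini": LLMConfig(128000, 16384, "Xenova/gpt-4o"),
}
assert lookup_model_config(openai_defaults, "gpt-4o-mini-2024-07-18").max_output_tokens == 16384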
@@ -193,13 +243,12 @@ def get_llm_model_config(
 class LLMEndpointConfig(QuivrBaseConfig):
     supplier: DefaultModelSuppliers = DefaultModelSuppliers.OPENAI
     model: str = "gpt-4o"
-    context_length: int | None = None
     tokenizer_hub: str | None = None
     llm_base_url: str | None = None
     env_variable_name: str | None = None
     llm_api_key: str | None = None
-    max_context_tokens: int = 2000
-    max_output_tokens: int = 2000
+    max_context_tokens: int = 10000
+    max_output_tokens: int = 4000
     temperature: float = 0.7
     streaming: bool = True
     prompt: CustomPromptsModel | None = None
@@ -240,7 +289,25 @@ def set_llm_model_config(self):
             self.supplier, self.model
         )
         if llm_model_config:
-            self.context_length = llm_model_config.context
+            if llm_model_config.max_context_tokens:
+                _max_context_tokens = (
+                    llm_model_config.max_context_tokens
+                    - llm_model_config.max_output_tokens
+                    if llm_model_config.max_output_tokens
+                    else llm_model_config.max_context_tokens
+                )
+                if self.max_context_tokens > _max_context_tokens:
+                    logger.warning(
+                        f"Lowering max_context_tokens from {self.max_context_tokens} to {_max_context_tokens}"
+                    )
+                    self.max_context_tokens = _max_context_tokens
+            if llm_model_config.max_output_tokens:
+                if self.max_output_tokens > llm_model_config.max_output_tokens:
+                    logger.warning(
+                        f"Lowering max_output_tokens from {self.max_output_tokens} to {llm_model_config.max_output_tokens}"
+                    )
+                    self.max_output_tokens = llm_model_config.max_output_tokens
+
             self.tokenizer_hub = llm_model_config.tokenizer_hub
 
     def set_llm_model(self, model: str):
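
Taken together with the new defaults table, the rewritten set_llm_model_config validator caps whatever the caller configured: the usable context budget becomes the model's window minus the model's output budget, and max_output_tokens is capped at the model's own limit, with a warning logged whenever a value is lowered. Below is a minimal standalone sketch of that behavior using the gpt-3.5-turbo numbers from the table above; it is illustrative only, the real logic being the validator shown in the diff.

import logging

logging.basicConfig(level=logging.WARNING)
logger = logging.getLogger(__name__)

# gpt-3.5-turbo entry from the defaults table above.
MODEL_MAX_CONTEXT = 16385
MODEL_MAX_OUTPUT = 4096


def clamp_limits(max_context_tokens: int, max_output_tokens: int) -> tuple[int, int]:
    # Usable context = model window minus the output budget reserved for generation.
    context_cap = MODEL_MAX_CONTEXT - MODEL_MAX_OUTPUT  # 12289
    if max_context_tokens > context_cap:
        logger.warning(
            f"Lowering max_context_tokens from {max_context_tokens} to {context_cap}"
        )
        max_context_tokens = context_cap
    if max_output_tokens > MODEL_MAX_OUTPUT:
        logger.warning(
            f"Lowering max_output_tokens from {max_output_tokens} to {MODEL_MAX_OUTPUT}"
        )
        max_output_tokens = MODEL_MAX_OUTPUT
    return max_context_tokens, max_output_tokens


# The new LLMEndpointConfig defaults (10000, 4000) pass through unchanged;
# oversized requests get lowered with a warning.
print(clamp_limits(20000, 8000))  # -> (12289, 4096)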