Add Cerebras API + node validation for airetry! #217

Merged: 2 commits, Oct 9, 2024
12 changes: 12 additions & 0 deletions CHANGELOG.md
@@ -10,6 +10,18 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0

### Fixed

## [0.58.0]

### Added
- Added support for [Cerebras](https://cloud.cerebras.ai) hosted models (set your ENV `CEREBRAS_API_KEY`). Available model aliases: `cl3` (Llama3.1 8bn), `cl70` (Llama3.1 70bn).
- Added a kwarg to `aiclassify` to provide a custom token ID mapping (`token_ids_map`) to work with custom tokenizers.

### Updated
- Improved the implementation of `airetry!` to concatenate feedback from all ancestor nodes ONLY IF `feedback_inplace=true` (because otherwise the LLM can see it in the message history).

### Fixed
- Fixed a potential bug in `airetry!` where the `aicall` object was not properly validated to ensure it has been `run!` first.

## [0.57.0]

### Added
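To illustrate the new additions, a minimal usage sketch (prompts are illustrative; assumes `CEREBRAS_API_KEY` is set in your ENV and v0.58.0 is installed):

```julia
using PromptingTools

# New Cerebras aliases resolve through MODEL_ALIASES to the hosted Llama 3.1 models
msg = aigenerate("Say hi in one sentence."; model = "cl3")    # -> llama3.1-8b
msg = aigenerate("Summarize: Julia is fast."; model = "cl70") # -> llama3.1-70b
```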
2 changes: 1 addition & 1 deletion Project.toml
@@ -1,7 +1,7 @@
name = "PromptingTools"
uuid = "670122d1-24a8-4d70-bfce-740807c42192"
authors = ["J S @svilupp and contributors"]
version = "0.57.0"
version = "0.58.0"

[deps]
AbstractTrees = "1520ce14-60c1-5f80-bbc7-55ef81b5835c"
15 changes: 12 additions & 3 deletions src/Experimental/AgentTools/retry.jl
@@ -4,10 +4,12 @@
        verbose::Bool = true, throw::Bool = false, evaluate_all::Bool = true, feedback_expensive::Bool = false,
        max_retries::Union{Nothing, Int} = nothing, retry_delay::Union{Nothing, Int} = nothing)

Evaluates the condition `f_cond` on the `aicall` object.
If the condition is not met, it will return the best sample to retry from and provide `feedback` (string or function) to `aicall`. That's why it's mutating.
It will retry up to `max_retries` times; with `throw=true`, an error is thrown if the condition is not met after `max_retries` retries.

Note: `aicall` must be run first via `run!(aicall)` before calling `airetry!`.

Function signatures
- `f_cond(aicall::AICallBlock) -> Bool`, ie, it must accept the aicall object and return a boolean value.
- `feedback` can be a string or `feedback(aicall::AICallBlock) -> String`, ie, it must accept the aicall object and return a string.
@@ -286,6 +288,9 @@ function airetry!(f_cond::Function, aicall::AICallBlock,
    (; config) = aicall
    (; max_calls, feedback_inplace, feedback_template) = aicall.config

    ## Validate that the aicall has been run first
    @assert aicall.success isa Bool "Provided `aicall` has not been run yet. Use `run!(aicall)` first, before calling `airetry!` to check the condition."

    max_retries = max_retries isa Nothing ? config.max_retries : max_retries
    retry_delay = retry_delay isa Nothing ? config.retry_delay : retry_delay
    verbose = min(verbose, get(aicall.kwargs, :verbose, 99))
@@ -505,8 +510,12 @@ conversation[end].content ==
function add_feedback!(conversation::AbstractVector{<:PT.AbstractMessage},
        sample::SampleNode; feedback_inplace::Bool = false,
        feedback_template::Symbol = :FeedbackFromEvaluator)
-    ##
-    all_feedback = collect_all_feedback(sample)
+    ## If you use in-place feedback, collect all feedback from ancestors (because you won't see the history otherwise)
+    all_feedback = if feedback_inplace
+        collect_all_feedback(sample)
+    else
+        sample.feedback
+    end
    ## short circuit if no feedback
    if strip(all_feedback) == ""
        return conversation
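To make the new validation and feedback behavior concrete, a minimal sketch of the intended call order (prompt and condition are illustrative):

```julia
using PromptingTools
using PromptingTools.Experimental.AgentTools: AIGenerate, run!, airetry!

# Lazy call: nothing is sent to the API until `run!`
aicall = AIGenerate("Answer in one word: what is the capital of France?")
run!(aicall)  # required first; `airetry!` now asserts that `aicall.success isa Bool`

# Retries with feedback until the condition passes (up to `max_retries` times)
airetry!(x -> occursin("Paris", last_output(x)), aicall,
    "The answer must mention Paris."; max_retries = 2)
```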
14 changes: 14 additions & 0 deletions src/llm_interface.jl
@@ -233,6 +233,20 @@ Requires one environment variable to be set:
"""
struct OpenRouterOpenAISchema <: AbstractOpenAISchema end

"""
CerebrasOpenAISchema
Schema to call the [Cerebras](https://cerebras.ai/) API.
Links:
- [Get your API key](https://cloud.cerebras.ai)
- [API Reference](https://inference-docs.cerebras.ai/api-reference/chat-completions)
Requires one environment variable to be set:
- `CEREBRAS_API_KEY`: Your API key
"""
struct CerebrasOpenAISchema <: AbstractOpenAISchema end

abstract type AbstractOllamaSchema <: AbstractPromptSchema end

"""
93 changes: 63 additions & 30 deletions src/llm_openai.jl
@@ -237,6 +237,15 @@
    api_key = isempty(OPENROUTER_API_KEY) ? api_key : OPENROUTER_API_KEY
    OpenAI.create_chat(CustomOpenAISchema(), api_key, model, conversation; url, kwargs...)
end
function OpenAI.create_chat(schema::CerebrasOpenAISchema,
        api_key::AbstractString,
        model::AbstractString,
        conversation;
        url::String = "https://api.cerebras.ai/v1",
        kwargs...)
    api_key = isempty(CEREBRAS_API_KEY) ? api_key : CEREBRAS_API_KEY
    OpenAI.create_chat(CustomOpenAISchema(), api_key, model, conversation; url, kwargs...)
end
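The new method simply re-routes to `CustomOpenAISchema` with the Cerebras base URL, preferring the `CEREBRAS_API_KEY` global whenever it is non-empty. A sketch of the equivalent explicit-schema call (normally the model registry picks the schema for you):

```julia
using PromptingTools
const PT = PromptingTools

msg = aigenerate(PT.CerebrasOpenAISchema(), "Say hi!";
    model = "llama3.1-8b",
    api_key = get(ENV, "CEREBRAS_API_KEY", ""))  # the global key wins if set
```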
function OpenAI.create_chat(schema::DatabricksOpenAISchema,
        api_key::AbstractString,
        model::AbstractString,
@@ -272,19 +281,20 @@
    end
end
function OpenAI.create_chat(schema::AzureOpenAISchema,
        api_key::AbstractString,
        model::AbstractString,
        conversation;
        api_version::String = "2023-03-15-preview",
        http_kwargs::NamedTuple = NamedTuple(),
        streamcallback::Any = nothing,
        url::String = "https://<resource-name>.openai.azure.com",
        kwargs...)

    # Build the corresponding provider object
    provider = OpenAI.AzureProvider(;
        api_key = isempty(AZURE_OPENAI_API_KEY) ? api_key : AZURE_OPENAI_API_KEY,
        base_url = (isempty(AZURE_OPENAI_HOST) ? url : AZURE_OPENAI_HOST) *
                   "/openai/deployments/$model",
        api_version = api_version
    )
    # Override standard OpenAI request endpoint
@@ -297,7 +307,7 @@
        query = Dict("api-version" => provider.api_version),
        streamcallback = streamcallback,
        kwargs...
    )
end

# Extend OpenAI create_embeddings to allow for testing
@@ -396,17 +406,18 @@
    OpenAI.create_embeddings(provider, docs, model; kwargs...)
end
function OpenAI.create_embeddings(schema::AzureOpenAISchema,
        api_key::AbstractString,
        docs,
        model::AbstractString;
        api_version::String = "2023-03-15-preview",
        url::String = "https://<resource-name>.openai.azure.com",
        kwargs...)

    # Build the corresponding provider object
    provider = OpenAI.AzureProvider(;
        api_key = isempty(AZURE_OPENAI_API_KEY) ? api_key : AZURE_OPENAI_API_KEY,
        base_url = (isempty(AZURE_OPENAI_HOST) ? url : AZURE_OPENAI_HOST) *
                   "/openai/deployments/$model",
        api_version = api_version)
    # Override standard OpenAI request endpoint
    OpenAI.openai_request(
@@ -851,11 +862,15 @@
"38" => 3150,
"39" => 3255,
"40" => 1723)
## Note: You can provide your own token IDs map to `encode_choices` to use a custom mapping via kwarg: token_ids_map

function pick_tokenizer(model::AbstractString)
function pick_tokenizer(model::AbstractString;
token_ids_map::Union{Nothing, Dict{<:AbstractString, <:Integer}} = nothing)
global OPENAI_TOKEN_IDS_GPT35_GPT4, OPENAI_TOKEN_IDS_GPT4O
OPENAI_TOKEN_IDS = if model == "gpt-4" || startswith(model, "gpt-3.5") ||
startswith(model, "gpt-4-")
OPENAI_TOKEN_IDS = if !isnothing(token_ids_map)
token_ids_map

Check warning on line 871 in src/llm_openai.jl

View check run for this annotation

Codecov / codecov/patch

src/llm_openai.jl#L871

Added line #L871 was not covered by tests
elseif (model == "gpt-4" || startswith(model, "gpt-3.5") ||
startswith(model, "gpt-4-"))
OPENAI_TOKEN_IDS_GPT35_GPT4
elseif startswith(model, "gpt-4o")
OPENAI_TOKEN_IDS_GPT4O
Expand All @@ -866,10 +881,15 @@
end
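In short, a user-supplied `token_ids_map` short-circuits the model-based lookup entirely. A small sketch (internal, non-exported function; the IDs below are hypothetical):

```julia
using PromptingTools
const PT = PromptingTools

my_map = Dict("true" => 837, "false" => 905)  # hypothetical token IDs
PT.pick_tokenizer("my-finetuned-model"; token_ids_map = my_map) === my_map  # true
```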

"""
-    encode_choices(schema::OpenAISchema, choices::AbstractVector{<:AbstractString}; kwargs...)
+    encode_choices(schema::OpenAISchema, choices::AbstractVector{<:AbstractString};
+        model::AbstractString,
+        token_ids_map::Union{Nothing, Dict{<:AbstractString, <:Integer}} = nothing,
+        kwargs...)

-    encode_choices(schema::OpenAISchema, choices::AbstractVector{T};
-        kwargs...) where {T <: Tuple{<:AbstractString, <:AbstractString}}
+    encode_choices(schema::OpenAISchema, choices::AbstractVector{T};
+        model::AbstractString,
+        token_ids_map::Union{Nothing, Dict{<:AbstractString, <:Integer}} = nothing,
+        kwargs...) where {T <: Tuple{<:AbstractString, <:AbstractString}}

Encode the choices into an enumerated list that can be interpolated into the prompt, and create the corresponding logit biases (to choose only from the selected tokens).

@@ -880,6 +900,8 @@
# Arguments
- `schema::OpenAISchema`: The OpenAISchema object.
- `choices::AbstractVector{<:Union{AbstractString,Tuple{<:AbstractString, <:AbstractString}}}`: The choices to be encoded, represented as a vector of the choices directly, or tuples where each tuple contains a choice and its description.
- `model::AbstractString`: The model to use for encoding. Can be an alias corresponding to a model ID defined in `MODEL_ALIASES`.
- `token_ids_map::Union{Nothing, Dict{<:AbstractString, <:Integer}} = nothing`: A dictionary mapping the choice strings to their corresponding token IDs. If `nothing`, the default token ID mapping for the given model is used.
- `kwargs...`: Additional keyword arguments.

# Returns
@@ -908,8 +930,9 @@
function encode_choices(schema::OpenAISchema,
        choices::AbstractVector{<:AbstractString};
        model::AbstractString,
        token_ids_map::Union{Nothing, Dict{<:AbstractString, <:Integer}} = nothing,
        kwargs...)
-    OPENAI_TOKEN_IDS = pick_tokenizer(model)
+    OPENAI_TOKEN_IDS = pick_tokenizer(model; token_ids_map)
    ## if all choices are in the dictionary, use the dictionary
    if all(Base.Fix1(haskey, OPENAI_TOKEN_IDS), choices)
        choices_prompt = ["$c for \"$c\"" for c in choices]
@@ -927,8 +950,9 @@
function encode_choices(schema::OpenAISchema,
        choices::AbstractVector{T};
        model::AbstractString,
        token_ids_map::Union{Nothing, Dict{<:AbstractString, <:Integer}} = nothing,
        kwargs...) where {T <: Tuple{<:AbstractString, <:AbstractString}}
-    OPENAI_TOKEN_IDS = pick_tokenizer(model)
+    OPENAI_TOKEN_IDS = pick_tokenizer(model; token_ids_map)
    ## if all choices are in the dictionary, use the dictionary
    if all(Base.Fix1(haskey, OPENAI_TOKEN_IDS), first.(choices))
        choices_prompt = ["$c for \"$desc\"" for (c, desc) in choices]
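For intuition, a hedged sketch of what the encoding above yields for choices outside the built-in mapping (return shape per the docstring; exact token IDs depend on the tokenizer):

```julia
using PromptingTools
const PT = PromptingTools

choices_prompt, logit_bias, decode_ids = PT.encode_choices(
    PT.OpenAISchema(), ["animal", "plant"]; model = "gpt-4o")
# choices_prompt -> enumerated list to interpolate into the prompt, eg, 1. "animal" ...
# logit_bias     -> Dict mapping the token IDs of "1" and "2" to a strong positive bias
# decode_ids     -> ["animal", "plant"], used later by `decode_choices`
```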
@@ -958,6 +982,7 @@

function decode_choices(schema::OpenAISchema, choices, conv::AbstractVector;
        model::AbstractString,
        token_ids_map::Union{Nothing, Dict{<:AbstractString, <:Integer}} = nothing,
        kwargs...)
    if length(conv) > 0 && last(conv) isa AIMessage && hasproperty(last(conv), :run_id)
        ## if it is a multi-sample response,
@@ -966,7 +991,7 @@
        for i in eachindex(conv)
            msg = conv[i]
            if isaimessage(msg) && msg.run_id == run_id
-                conv[i] = decode_choices(schema, choices, msg; model)
+                conv[i] = decode_choices(schema, choices, msg; model, token_ids_map)
            end
        end
    end
@@ -976,16 +1001,20 @@
"""
decode_choices(schema::OpenAISchema,
choices::AbstractVector{<:AbstractString},
msg::AIMessage; model::AbstractString, kwargs...)
msg::AIMessage; model::AbstractString,
token_ids_map::Union{Nothing, Dict{<:AbstractString, <:Integer}} = nothing,
kwargs...)

Decodes the underlying AIMessage against the original choices to lookup what the category name was.

If it fails, it will return `msg.content == nothing`
"""
function decode_choices(schema::OpenAISchema,
        choices::AbstractVector{<:AbstractString},
-        msg::AIMessage; model::AbstractString, kwargs...)
+        msg::AIMessage; model::AbstractString,
+        token_ids_map::Union{Nothing, Dict{<:AbstractString, <:Integer}} = nothing,
+        kwargs...)
-    OPENAI_TOKEN_IDS = pick_tokenizer(model)
+    OPENAI_TOKEN_IDS = pick_tokenizer(model; token_ids_map)
    parsed_digit = tryparse(Int, strip(msg.content))
    if !isnothing(parsed_digit) && haskey(OPENAI_TOKEN_IDS, strip(msg.content))
        ## It's encoded
@@ -1006,6 +1035,7 @@
        choices::AbstractVector{T} = ["true", "false", "unknown"],
        model::AbstractString = MODEL_CHAT,
        api_kwargs::NamedTuple = NamedTuple(),
        token_ids_map::Union{Nothing, Dict{<:AbstractString, <:Integer}} = nothing,
        kwargs...) where {T <: Union{AbstractString, Tuple{<:AbstractString, <:AbstractString}}}

Classifies the given prompt/statement into an arbitrary list of `choices`, provided either as the choices alone (vector of strings) or as choices with descriptions (vector of tuples, ie, `("choice","description")`).
@@ -1025,6 +1055,7 @@
- `choices::AbstractVector{T}`: The choices to be classified into. It can be a vector of strings or a vector of tuples, where the first element is the choice and the second is the description.
- `model::AbstractString = MODEL_CHAT`: The model to use for classification. Can be an alias corresponding to a model ID defined in `MODEL_ALIASES`.
- `api_kwargs::NamedTuple = NamedTuple()`: Additional keyword arguments for the API call.
- `token_ids_map::Union{Nothing, Dict{<:AbstractString, <:Integer}} = nothing`: A dictionary mapping the choice strings to their corresponding token IDs. If `nothing`, the default token ID mapping for the given model is used.
- `kwargs`: Additional keyword arguments for the prompt template.

# Example
@@ -1085,12 +1116,13 @@
        choices::AbstractVector{T} = ["true", "false", "unknown"],
        model::AbstractString = MODEL_CHAT,
        api_kwargs::NamedTuple = NamedTuple(),
        token_ids_map::Union{Nothing, Dict{<:AbstractString, <:Integer}} = nothing,
        kwargs...) where {T <:
                          Union{AbstractString, Tuple{<:AbstractString, <:AbstractString}}}
    ## Encode the choices and the corresponding prompt
    model_id = get(MODEL_ALIASES, model, model)
    choices_prompt, logit_bias, decode_ids = encode_choices(
-        prompt_schema, choices; model = model_id)
+        prompt_schema, choices; model = model_id, token_ids_map)
    ## We want only 1 token
    api_kwargs = merge(api_kwargs, (; logit_bias, max_tokens = 1, temperature = 0))
    msg_or_conv = aigenerate(prompt_schema,
@@ -1099,7 +1131,8 @@
        model = model_id,
        api_kwargs,
        kwargs...)
-    return decode_choices(prompt_schema, decode_ids, msg_or_conv; model = model_id)
+    return decode_choices(
+        prompt_schema, decode_ids, msg_or_conv; model = model_id, token_ids_map)
end
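End-to-end, the new kwarg lets `aiclassify` work with models whose tokenizers differ from OpenAI's (the mapping below is hypothetical; real IDs must come from your tokenizer):

```julia
using PromptingTools

my_ids = Dict("true" => 837, "false" => 905, "unknown" => 9987)  # hypothetical IDs
verdict = aiclassify("Is two plus two four?";
    choices = ["true", "false", "unknown"],
    model = "my-custom-model",  # assumes this model is registered and reachable
    token_ids_map = my_ids)
```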

function response_to_message(schema::AbstractOpenAISchema,
24 changes: 23 additions & 1 deletion src/user_preferences.jl
@@ -24,6 +24,7 @@ Check your preferences by calling `get_preferences(key::String)`.
- `GROQ_API_KEY`: The API key for the Groq API. Free in beta! Get yours from [here](https://console.groq.com/keys).
- `DEEPSEEK_API_KEY`: The API key for the DeepSeek API. Get \$5 credit when you join. Get yours from [here](https://platform.deepseek.com/api_keys).
- `OPENROUTER_API_KEY`: The API key for the OpenRouter API. Get yours from [here](https://openrouter.ai/keys).
- `CEREBRAS_API_KEY`: The API key for the Cerebras API. Get yours from [here](https://cloud.cerebras.ai/).
- `MODEL_CHAT`: The default model to use for aigenerate and most ai* calls. See `MODEL_REGISTRY` for a list of available models or define your own.
- `MODEL_EMBEDDING`: The default model to use for aiembed (embedding documents). See `MODEL_REGISTRY` for a list of available models or define your own.
- `PROMPT_SCHEMA`: The default prompt schema to use for aigenerate and most ai* calls (if not specified in `MODEL_REGISTRY`). Set as a string, eg, `"OpenAISchema"`.
@@ -55,6 +56,7 @@ Define your `register_model!()` calls in your `startup.jl` file to make them ava
- `GROQ_API_KEY`: The API key for the Groq API. Free in beta! Get yours from [here](https://console.groq.com/keys).
- `DEEPSEEK_API_KEY`: The API key for the DeepSeek API. Get \$5 credit when you join. Get yours from [here](https://platform.deepseek.com/api_keys).
- `OPENROUTER_API_KEY`: The API key for the OpenRouter API. Get yours from [here](https://openrouter.ai/keys).
- `CEREBRAS_API_KEY`: The API key for the Cerebras API.
- `LOG_DIR`: The directory to save the logs to, eg, when using `SaverSchema <: AbstractTracerSchema`. Defaults to `joinpath(pwd(), "log")`. Refer to `?SaverSchema` for more information on how it works and examples.

Preferences.jl takes priority over ENV variables, so if you set a preference, it will take precedence over the ENV variable.
Expand All @@ -78,6 +80,7 @@ const ALLOWED_PREFERENCES = ["MISTRALAI_API_KEY",
"GROQ_API_KEY",
"DEEPSEEK_API_KEY",
"OPENROUTER_API_KEY", # Added OPENROUTER_API_KEY
"CEREBRAS_API_KEY",
"MODEL_CHAT",
"MODEL_EMBEDDING",
"MODEL_ALIASES",
@@ -159,6 +162,7 @@ global VOYAGE_API_KEY::String = ""
global GROQ_API_KEY::String = ""
global DEEPSEEK_API_KEY::String = ""
global OPENROUTER_API_KEY::String = ""
global CEREBRAS_API_KEY::String = ""
global LOCAL_SERVER::String = ""
global LOG_DIR::String = ""

@@ -216,6 +220,9 @@ function load_api_keys!()
    global OPENROUTER_API_KEY # Added OPENROUTER_API_KEY
    OPENROUTER_API_KEY = @load_preference("OPENROUTER_API_KEY",
        default=get(ENV, "OPENROUTER_API_KEY", ""))
    global CEREBRAS_API_KEY
    CEREBRAS_API_KEY = @load_preference("CEREBRAS_API_KEY",
        default=get(ENV, "CEREBRAS_API_KEY", ""))
    global LOCAL_SERVER
    LOCAL_SERVER = @load_preference("LOCAL_SERVER",
        default=get(ENV, "LOCAL_SERVER", ""))
@@ -410,6 +417,11 @@ aliases = merge(
"gll" => "llama-3.1-405b-reasoning", #l for large
"gmixtral" => "mixtral-8x7b-32768",
"ggemma9" => "gemma2-9b-it",
## Cerebras
"cl3" => "llama3.1-8b",
"cllama3" => "llama3.1-8b",
"cl70" => "llama3.1-70b",
"cllama70" => "llama3.1-70b",
## DeepSeek
"dschat" => "deepseek-chat",
"dscode" => "deepseek-coder",
@@ -885,7 +897,17 @@ registry = Dict{String, ModelSpec}(
        OpenRouterOpenAISchema(),
        2e-6,
        2e-6,
-        "Meta's Llama3.1 405b, hosted by OpenRouter. This is a BASE model!! Max output 32K tokens, 131K context. See details [here](https://openrouter.ai/models/meta-llama/llama-3.1-405b)")
+        "Meta's Llama3.1 405b, hosted by OpenRouter. This is a BASE model!! Max output 32K tokens, 131K context. See details [here](https://openrouter.ai/models/meta-llama/llama-3.1-405b)"),
+    "llama3.1-8b" => ModelSpec("llama3.1-8b",
+        CerebrasOpenAISchema(),
+        1e-7,
+        1e-7,
+        "Meta's Llama3.1 8b, hosted by Cerebras.ai. Max 8K context."),
+    "llama3.1-70b" => ModelSpec("llama3.1-70b",
+        CerebrasOpenAISchema(),
+        6e-7,
+        6e-7,
+        "Meta's Llama3.1 70b, hosted by Cerebras.ai. Max 8K context.")
)
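Putting the preference and registry pieces together, a sketch of the end-user setup (the key value is a placeholder):

```julia
using PromptingTools
const PT = PromptingTools

# Persist the key via Preferences.jl (or simply set ENV["CEREBRAS_API_KEY"])
PT.set_preferences!("CEREBRAS_API_KEY" => "<your-api-key>")

# The aliases resolve to the registered Cerebras models
msg = aigenerate("Hello!"; model = "cl70")  # -> "llama3.1-70b" via CerebrasOpenAISchema
```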

"""
Expand Down
Loading
Loading