Add Cerebras API + node validation for airetry!
svilupp authored Oct 9, 2024
1 parent 188ede9 commit 592eafb
Showing 7 changed files with 133 additions and 38 deletions.
12 changes: 12 additions & 0 deletions CHANGELOG.md
@@ -10,6 +10,18 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0

### Fixed

## [0.58.0]

### Added
- Added support for [Cerebras](https://cloud.cerebras.ai) hosted models (set your ENV `CEREBRAS_API_KEY`). Available model aliases: `cl3` (Llama3.1 8bn), `cl70` (Llama3.1 70bn).
- Added a kwarg to `aiclassify` to provide a custom token ID mapping (`token_ids_map`) to work with custom tokenizers (see the sketch below).
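
A minimal usage sketch of both additions (the classifier token IDs below are hypothetical):

```julia
using PromptingTools

# Cerebras-hosted Llama3.1 8b via the new alias (requires ENV["CEREBRAS_API_KEY"])
msg = aigenerate("Say hi!"; model = "cl3")

# Custom token ID mapping for a non-default tokenizer (IDs are made up for illustration)
aiclassify("Is two plus two equal to four?";
    choices = ["true", "false"],
    token_ids_map = Dict("true" => 837, "false" => 905))
```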

### Updated
- Improved the implementation of `airetry!` to concatenate feedback from all ancestor nodes ONLY IF `feedback_inplace=true` (because otherwise the LLM can see it in the message history).

### Fixed
- Fixed a potential bug in `airetry!` where the `aicall` object was not properly validated to ensure it has been `run!` first.

## [0.57.0]

### Added
2 changes: 1 addition & 1 deletion Project.toml
@@ -1,7 +1,7 @@
name = "PromptingTools"
uuid = "670122d1-24a8-4d70-bfce-740807c42192"
authors = ["J S @svilupp and contributors"]
version = "0.57.0"
version = "0.58.0"

[deps]
AbstractTrees = "1520ce14-60c1-5f80-bbc7-55ef81b5835c"
15 changes: 12 additions & 3 deletions src/Experimental/AgentTools/retry.jl
@@ -4,10 +4,12 @@
verbose::Bool = true, throw::Bool = false, evaluate_all::Bool = true, feedback_expensive::Bool = false,
max_retries::Union{Nothing, Int} = nothing, retry_delay::Union{Nothing, Int} = nothing)
Evaluates the condition `f_cond` on the `aicall` object.
If the condition is not met, it will return the best sample to retry from and provide `feedback` (string or function) to `aicall`. That's why it's mutating.
It will retry at most `max_retries` times; with `throw=true`, an error is thrown if the condition is not met after `max_retries` retries.
Note: `aicall` must be run first via `run!(aicall)` before calling `airetry!`.
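
A minimal sketch of the required call order (the condition and feedback are illustrative):

```julia
aicall = AIGenerate("Say hi!")
run!(aicall)  # must happen before `airetry!`
airetry!(x -> !isempty(last_output(x)), aicall, "The reply must not be empty.")
```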
Function signatures
- `f_cond(aicall::AICallBlock) -> Bool`, ie, it must accept the aicall object and return a boolean value.
- `feedback` can be a string or `feedback(aicall::AICallBlock) -> String`, ie, it must accept the aicall object and return a string.
@@ -286,6 +288,9 @@ function airetry!(f_cond::Function, aicall::AICallBlock,
(; config) = aicall
(; max_calls, feedback_inplace, feedback_template) = aicall.config

## Validate that the aicall has been run first
@assert aicall.success isa Bool "Provided `aicall` has not been run yet. Use `run!(aicall)` first, before calling `airetry!` to check the condition."

max_retries = max_retries isa Nothing ? config.max_retries : max_retries
retry_delay = retry_delay isa Nothing ? config.retry_delay : retry_delay
verbose = min(verbose, get(aicall.kwargs, :verbose, 99))
@@ -505,8 +510,12 @@ conversation[end].content ==
function add_feedback!(conversation::AbstractVector{<:PT.AbstractMessage},
sample::SampleNode; feedback_inplace::Bool = false,
feedback_template::Symbol = :FeedbackFromEvaluator)
##
all_feedback = collect_all_feedback(sample)
## If you use in-place feedback, collect all feedback from ancestors (because you won't see the history otherwise)
all_feedback = if feedback_inplace
collect_all_feedback(sample)
else
sample.feedback
end
## short circuit if no feedback
if strip(all_feedback) == ""
return conversation
14 changes: 14 additions & 0 deletions src/llm_interface.jl
@@ -233,6 +233,20 @@ Requires one environment variable to be set:
"""
struct OpenRouterOpenAISchema <: AbstractOpenAISchema end

"""
CerebrasOpenAISchema
Schema to call the [Cerebras](https://cerebras.ai/) API.
Links:
- [Get your API key](https://cloud.cerebras.ai)
- [API Reference](https://inference-docs.cerebras.ai/api-reference/chat-completions)
Requires one environment variable to be set:
- `CEREBRAS_API_KEY`: Your API key
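
A minimal usage sketch (assumes the API key is set):

```julia
msg = aigenerate(CerebrasOpenAISchema(), "Say hi!"; model = "llama3.1-8b")
```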
"""
struct CerebrasOpenAISchema <: AbstractOpenAISchema end

abstract type AbstractOllamaSchema <: AbstractPromptSchema end

"""
93 changes: 63 additions & 30 deletions src/llm_openai.jl
@@ -237,6 +237,15 @@ function OpenAI.create_chat(schema::OpenRouterOpenAISchema,
api_key = isempty(OPENROUTER_API_KEY) ? api_key : OPENROUTER_API_KEY
OpenAI.create_chat(CustomOpenAISchema(), api_key, model, conversation; url, kwargs...)
end
function OpenAI.create_chat(schema::CerebrasOpenAISchema,
api_key::AbstractString,
model::AbstractString,
conversation;
url::String = "https://api.cerebras.ai/v1",
kwargs...)
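## Cerebras exposes an OpenAI-compatible endpoint, so we delegate to CustomOpenAISchema
## with the Cerebras base URL; a globally set CEREBRAS_API_KEY takes precedence over `api_key`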
api_key = isempty(CEREBRAS_API_KEY) ? api_key : CEREBRAS_API_KEY
OpenAI.create_chat(CustomOpenAISchema(), api_key, model, conversation; url, kwargs...)
end
function OpenAI.create_chat(schema::DatabricksOpenAISchema,
api_key::AbstractString,
model::AbstractString,
@@ -272,19 +281,20 @@ function OpenAI.create_chat(schema::DatabricksOpenAISchema,
end
end
function OpenAI.create_chat(schema::AzureOpenAISchema,
api_key::AbstractString,
model::AbstractString,
conversation;
api_version::String = "2023-03-15-preview",
http_kwargs::NamedTuple = NamedTuple(),
streamcallback::Any = nothing,
url::String = "https://<resource-name>.openai.azure.com",
kwargs...)
api_key::AbstractString,
model::AbstractString,
conversation;
api_version::String = "2023-03-15-preview",
http_kwargs::NamedTuple = NamedTuple(),
streamcallback::Any = nothing,
url::String = "https://<resource-name>.openai.azure.com",
kwargs...)

# Build the corresponding provider object
provider = OpenAI.AzureProvider(;
api_key = isempty(AZURE_OPENAI_API_KEY) ? api_key : AZURE_OPENAI_API_KEY,
base_url = (isempty(AZURE_OPENAI_HOST) ? url : AZURE_OPENAI_HOST) * "/openai/deployments/$model",
base_url = (isempty(AZURE_OPENAI_HOST) ? url : AZURE_OPENAI_HOST) *
"/openai/deployments/$model",
api_version = api_version
)
# Override standard OpenAI request endpoint
@@ -297,7 +307,7 @@ function OpenAI.create_chat(schema::AzureOpenAISchema,
query = Dict("api-version" => provider.api_version),
streamcallback = streamcallback,
kwargs...
)
)
end

# Extend OpenAI create_embeddings to allow for testing
@@ -396,17 +406,18 @@ function OpenAI.create_embeddings(schema::FireworksOpenAISchema,
OpenAI.create_embeddings(provider, docs, model; kwargs...)
end
function OpenAI.create_embeddings(schema::AzureOpenAISchema,
api_key::AbstractString,
docs,
model::AbstractString;
api_version::String = "2023-03-15-preview",
url::String = "https://<resource-name>.openai.azure.com",
kwargs...)
api_key::AbstractString,
docs,
model::AbstractString;
api_version::String = "2023-03-15-preview",
url::String = "https://<resource-name>.openai.azure.com",
kwargs...)

# Build the corresponding provider object
provider = OpenAI.AzureProvider(;
api_key = isempty(AZURE_OPENAI_API_KEY) ? api_key : AZURE_OPENAI_API_KEY,
base_url = (isempty(AZURE_OPENAI_HOST) ? url : AZURE_OPENAI_HOST) * "/openai/deployments/$model",
base_url = (isempty(AZURE_OPENAI_HOST) ? url : AZURE_OPENAI_HOST) *
"/openai/deployments/$model",
api_version = api_version)
# Override standard OpenAI request endpoint
OpenAI.openai_request(
@@ -851,11 +862,15 @@ const OPENAI_TOKEN_IDS_GPT4O = Dict(
"38" => 3150,
"39" => 3255,
"40" => 1723)
## Note: You can provide your own token ID mapping to `encode_choices` via the kwarg `token_ids_map`

function pick_tokenizer(model::AbstractString)
function pick_tokenizer(model::AbstractString;
token_ids_map::Union{Nothing, Dict{<:AbstractString, <:Integer}} = nothing)
global OPENAI_TOKEN_IDS_GPT35_GPT4, OPENAI_TOKEN_IDS_GPT4O
OPENAI_TOKEN_IDS = if model == "gpt-4" || startswith(model, "gpt-3.5") ||
startswith(model, "gpt-4-")
OPENAI_TOKEN_IDS = if !isnothing(token_ids_map)
token_ids_map
elseif (model == "gpt-4" || startswith(model, "gpt-3.5") ||
startswith(model, "gpt-4-"))
OPENAI_TOKEN_IDS_GPT35_GPT4
elseif startswith(model, "gpt-4o")
OPENAI_TOKEN_IDS_GPT4O
@@ -866,10 +881,15 @@ function pick_tokenizer(model::AbstractString)
end
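## Sketch (hypothetical model name and token IDs): a user-supplied mapping takes
## precedence over the built-in tables, eg,
## pick_tokenizer("my-finetuned-model"; token_ids_map = Dict("1" => 16, "2" => 17))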

"""
encode_choices(schema::OpenAISchema, choices::AbstractVector{<:AbstractString}; kwargs...)
encode_choices(schema::OpenAISchema, choices::AbstractVector{<:AbstractString};
model::AbstractString,
token_ids_map::Union{Nothing, Dict{<:AbstractString, <:Integer}} = nothing,
kwargs...)
encode_choices(schema::OpenAISchema, choices::AbstractVector{T};
kwargs...) where {T <: Tuple{<:AbstractString, <:AbstractString}}
model::AbstractString,
token_ids_map::Union{Nothing, Dict{<:AbstractString, <:Integer}} = nothing,
kwargs...) where {T <: Tuple{<:AbstractString, <:AbstractString}}
Encodes the choices into an enumerated list that can be interpolated into the prompt and creates the corresponding logit biases (to choose only from the selected tokens).
@@ -880,6 +900,8 @@ There can be at most 40 choices provided.
# Arguments
- `schema::OpenAISchema`: The OpenAISchema object.
- `choices::AbstractVector{<:Union{AbstractString,Tuple{<:AbstractString, <:AbstractString}}}`: The choices to be encoded, represented as a vector of the choices directly, or tuples where each tuple contains a choice and its description.
- `model::AbstractString`: The model to use for encoding. Can be an alias corresponding to a model ID defined in `MODEL_ALIASES`.
- `token_ids_map::Union{Nothing, Dict{<:AbstractString, <:Integer}} = nothing`: A dictionary mapping each choice to its token ID. If `nothing`, the default token IDs for the given model are used (see the example below).
- `kwargs...`: Additional keyword arguments.
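
For example, a custom mapping could look like this (hypothetical token IDs):

```julia
token_ids_map = Dict("true" => 837, "false" => 905)
```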
# Returns
@@ -908,8 +930,9 @@ logit_bias # Output: Dict(16 => 100, 17 => 100, 18 => 100)
function encode_choices(schema::OpenAISchema,
choices::AbstractVector{<:AbstractString};
model::AbstractString,
token_ids_map::Union{Nothing, Dict{<:AbstractString, <:Integer}} = nothing,
kwargs...)
OPENAI_TOKEN_IDS = pick_tokenizer(model)
OPENAI_TOKEN_IDS = pick_tokenizer(model; token_ids_map)
## if all choices are in the dictionary, use the dictionary
if all(Base.Fix1(haskey, OPENAI_TOKEN_IDS), choices)
choices_prompt = ["$c for \"$c\"" for c in choices]
@@ -927,8 +950,9 @@ end
function encode_choices(schema::OpenAISchema,
choices::AbstractVector{T};
model::AbstractString,
token_ids_map::Union{Nothing, Dict{<:AbstractString, <:Integer}} = nothing,
kwargs...) where {T <: Tuple{<:AbstractString, <:AbstractString}}
OPENAI_TOKEN_IDS = pick_tokenizer(model)
OPENAI_TOKEN_IDS = pick_tokenizer(model; token_ids_map)
## if all choices are in the dictionary, use the dictionary
if all(Base.Fix1(haskey, OPENAI_TOKEN_IDS), first.(choices))
choices_prompt = ["$c for \"$desc\"" for (c, desc) in choices]
@@ -958,6 +982,7 @@ end

function decode_choices(schema::OpenAISchema, choices, conv::AbstractVector;
model::AbstractString,
token_ids_map::Union{Nothing, Dict{<:AbstractString, <:Integer}} = nothing,
kwargs...)
if length(conv) > 0 && last(conv) isa AIMessage && hasproperty(last(conv), :run_id)
## if it is a multi-sample response,
@@ -966,7 +991,7 @@ function decode_choices(schema::OpenAISchema, choices, conv::AbstractVector;
for i in eachindex(conv)
msg = conv[i]
if isaimessage(msg) && msg.run_id == run_id
conv[i] = decode_choices(schema, choices, msg; model)
conv[i] = decode_choices(schema, choices, msg; model, token_ids_map)
end
end
end
@@ -976,16 +1001,20 @@ end
"""
decode_choices(schema::OpenAISchema,
choices::AbstractVector{<:AbstractString},
msg::AIMessage; model::AbstractString, kwargs...)
msg::AIMessage; model::AbstractString,
token_ids_map::Union{Nothing, Dict{<:AbstractString, <:Integer}} = nothing,
kwargs...)
Decodes the underlying AIMessage against the original choices to look up what the category name was.
If it fails, the returned `msg.content` will be `nothing`.
"""
function decode_choices(schema::OpenAISchema,
choices::AbstractVector{<:AbstractString},
msg::AIMessage; model::AbstractString, kwargs...)
OPENAI_TOKEN_IDS = pick_tokenizer(model)
msg::AIMessage; model::AbstractString,
token_ids_map::Union{Nothing, Dict{<:AbstractString, <:Integer}} = nothing,
kwargs...)
OPENAI_TOKEN_IDS = pick_tokenizer(model; token_ids_map)
parsed_digit = tryparse(Int, strip(msg.content))
if !isnothing(parsed_digit) && haskey(OPENAI_TOKEN_IDS, strip(msg.content))
## It's encoded
@@ -1006,6 +1035,7 @@ end
choices::AbstractVector{T} = ["true", "false", "unknown"],
model::AbstractString = MODEL_CHAT,
api_kwargs::NamedTuple = NamedTuple(),
token_ids_map::Union{Nothing, Dict{<:AbstractString, <:Integer}} = nothing,
kwargs...) where {T <: Union{AbstractString, Tuple{<:AbstractString, <:AbstractString}}}
Classifies the given prompt/statement into an arbitrary list of `choices`, provided either as the choices alone (vector of strings) or as choices with descriptions (vector of tuples, ie, `("choice","description")`).
@@ -1025,6 +1055,7 @@ It uses the logit bias trick and limits the output to 1 token to force the model to
- `choices::AbstractVector{T}`: The choices to be classified into. It can be a vector of strings or a vector of tuples, where the first element is the choice and the second is the description.
- `model::AbstractString = MODEL_CHAT`: The model to use for classification. Can be an alias corresponding to a model ID defined in `MODEL_ALIASES`.
- `api_kwargs::NamedTuple = NamedTuple()`: Additional keyword arguments for the API call.
- `token_ids_map::Union{Nothing, Dict{<:AbstractString, <:Integer}} = nothing`: A dictionary mapping each choice to its token ID. If `nothing`, the default token IDs for the given model are used.
- `kwargs`: Additional keyword arguments for the prompt template.
# Example
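
A sketch with a custom tokenizer mapping (model name and token IDs are hypothetical):

```julia
aiclassify("Is the sky blue?";
    choices = ["true", "false", "unknown"],
    model = "my-custom-model",
    token_ids_map = Dict("true" => 837, "false" => 905, "unknown" => 401))
```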
@@ -1085,12 +1116,13 @@ function aiclassify(prompt_schema::AbstractOpenAISchema, prompt::ALLOWED_PROMPT_
choices::AbstractVector{T} = ["true", "false", "unknown"],
model::AbstractString = MODEL_CHAT,
api_kwargs::NamedTuple = NamedTuple(),
token_ids_map::Union{Nothing, Dict{<:AbstractString, <:Integer}} = nothing,
kwargs...) where {T <:
Union{AbstractString, Tuple{<:AbstractString, <:AbstractString}}}
## Encode the choices and the corresponding prompt
model_id = get(MODEL_ALIASES, model, model)
choices_prompt, logit_bias, decode_ids = encode_choices(
prompt_schema, choices; model = model_id)
prompt_schema, choices; model = model_id, token_ids_map)
## We want only 1 token
api_kwargs = merge(api_kwargs, (; logit_bias, max_tokens = 1, temperature = 0))
msg_or_conv = aigenerate(prompt_schema,
@@ -1099,7 +1131,8 @@ function aiclassify(prompt_schema::AbstractOpenAISchema, prompt::ALLOWED_PROMPT_
model = model_id,
api_kwargs,
kwargs...)
return decode_choices(prompt_schema, decode_ids, msg_or_conv; model = model_id)
return decode_choices(
prompt_schema, decode_ids, msg_or_conv; model = model_id, token_ids_map)
end

function response_to_message(schema::AbstractOpenAISchema,
24 changes: 23 additions & 1 deletion src/user_preferences.jl
@@ -24,6 +24,7 @@ Check your preferences by calling `get_preferences(key::String)`.
- `GROQ_API_KEY`: The API key for the Groq API. Free in beta! Get yours from [here](https://console.groq.com/keys).
- `DEEPSEEK_API_KEY`: The API key for the DeepSeek API. Get \$5 credit when you join. Get yours from [here](https://platform.deepseek.com/api_keys).
- `OPENROUTER_API_KEY`: The API key for the OpenRouter API. Get yours from [here](https://openrouter.ai/keys).
- `CEREBRAS_API_KEY`: The API key for the Cerebras API. Get yours from [here](https://cloud.cerebras.ai/).
- `MODEL_CHAT`: The default model to use for aigenerate and most ai* calls. See `MODEL_REGISTRY` for a list of available models or define your own.
- `MODEL_EMBEDDING`: The default model to use for aiembed (embedding documents). See `MODEL_REGISTRY` for a list of available models or define your own.
- `PROMPT_SCHEMA`: The default prompt schema to use for aigenerate and most ai* calls (if not specified in `MODEL_REGISTRY`). Set as a string, eg, `"OpenAISchema"`.
@@ -55,6 +56,7 @@ Define your `register_model!()` calls in your `startup.jl` file to make them ava
- `GROQ_API_KEY`: The API key for the Groq API. Free in beta! Get yours from [here](https://console.groq.com/keys).
- `DEEPSEEK_API_KEY`: The API key for the DeepSeek API. Get \$5 credit when you join. Get yours from [here](https://platform.deepseek.com/api_keys).
- `OPENROUTER_API_KEY`: The API key for the OpenRouter API. Get yours from [here](https://openrouter.ai/keys).
- `CEREBRAS_API_KEY`: The API key for the Cerebras API.
- `LOG_DIR`: The directory to save the logs to, eg, when using `SaverSchema <: AbstractTracerSchema`. Defaults to `joinpath(pwd(), "log")`. Refer to `?SaverSchema` for more information on how it works and examples.
Preferences.jl takes priority over ENV variables, so if you set a preference, it will take precedence over the ENV variable.
Expand All @@ -78,6 +80,7 @@ const ALLOWED_PREFERENCES = ["MISTRALAI_API_KEY",
"GROQ_API_KEY",
"DEEPSEEK_API_KEY",
"OPENROUTER_API_KEY", # Added OPENROUTER_API_KEY
"CEREBRAS_API_KEY",
"MODEL_CHAT",
"MODEL_EMBEDDING",
"MODEL_ALIASES",
@@ -159,6 +162,7 @@ global VOYAGE_API_KEY::String = ""
global GROQ_API_KEY::String = ""
global DEEPSEEK_API_KEY::String = ""
global OPENROUTER_API_KEY::String = ""
global CEREBRAS_API_KEY::String = ""
global LOCAL_SERVER::String = ""
global LOG_DIR::String = ""

@@ -216,6 +220,9 @@ function load_api_keys!()
global OPENROUTER_API_KEY # Added OPENROUTER_API_KEY
OPENROUTER_API_KEY = @load_preference("OPENROUTER_API_KEY",
default=get(ENV, "OPENROUTER_API_KEY", ""))
global CEREBRAS_API_KEY
CEREBRAS_API_KEY = @load_preference("CEREBRAS_API_KEY",
default=get(ENV, "CEREBRAS_API_KEY", ""))
global LOCAL_SERVER
LOCAL_SERVER = @load_preference("LOCAL_SERVER",
default=get(ENV, "LOCAL_SERVER", ""))
@@ -410,6 +417,11 @@ aliases = merge(
"gll" => "llama-3.1-405b-reasoning", #l for large
"gmixtral" => "mixtral-8x7b-32768",
"ggemma9" => "gemma2-9b-it",
## Cerebras
"cl3" => "llama3.1-8b",
"cllama3" => "llama3.1-8b",
"cl70" => "llama3.1-70b",
"cllama70" => "llama3.1-70b",
## DeepSeek
"dschat" => "deepseek-chat",
"dscode" => "deepseek-coder",
@@ -885,7 +897,17 @@ registry = Dict{String, ModelSpec}(
OpenRouterOpenAISchema(),
2e-6,
2e-6,
"Meta's Llama3.1 405b, hosted by OpenRouter. This is a BASE model!! Max output 32K tokens, 131K context. See details [here](https://openrouter.ai/models/meta-llama/llama-3.1-405b)")
"Meta's Llama3.1 405b, hosted by OpenRouter. This is a BASE model!! Max output 32K tokens, 131K context. See details [here](https://openrouter.ai/models/meta-llama/llama-3.1-405b)"),
"llama3.1-8b" => ModelSpec("llama3.1-8b",
CerebrasOpenAISchema(),
1e-7,
1e-7,
"Meta's Llama3.1 8b, hosted by Cerebras.ai. Max 8K context."),
"llama3.1-70b" => ModelSpec("llama3.1-70b",
CerebrasOpenAISchema(),
6e-7,
6e-7,
"Meta's Llama3.1 70b, hosted by Cerebras.ai. Max 8K context.")
)

"""

2 comments on commit 592eafb

@svilupp (Owner, Author) commented on 592eafb Oct 9, 2024


@JuliaRegistrator register

Release notes:

Added

  • Added support for Cerebras hosted models (set your ENV CEREBRAS_API_KEY). Available model aliases: cl3 (Llama3.1 8bn), cl70 (Llama3.1 70bn).
  • Added a kwarg to aiclassify to provide a custom token ID mapping (token_ids_map) to work with custom tokenizers.

Updated

  • Improved the implementation of airetry! to concatenate feedback from all ancestor nodes ONLY IF feedback_inplace=true (because otherwise the LLM can see it in the message history).

Fixed

  • Fixed a potential bug in airetry! where the aicall object was not properly validated to ensure it has been run! first.

Commits

@JuliaRegistrator

Registration pull request created: JuliaRegistries/General/116929

Tagging

After the above pull request is merged, it is recommended that a tag is created on this repository for the registered package version.

This will be done automatically if the Julia TagBot GitHub Action is installed, or can be done manually through the github interface, or via:

git tag -a v0.58.0 -m "<description of version>" 592eafb04fc175acb638546d92584820de312d85
git push origin v0.58.0
