Skip to content

Commit

Permalink
Update Structured Extraction
Browse files Browse the repository at this point in the history
  • Loading branch information
svilupp authored Sep 3, 2024
1 parent 91235fa commit d04c737
Show file tree
Hide file tree
Showing 10 changed files with 408 additions and 61 deletions.
6 changes: 6 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,12 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0

### Fixed

## [0.51.0]

### Added
- Added more flexible structured extraction with `aiextract` -> now you can simply provide the field names and, optionally, their types without specifying the struct itself (in `aiextract`, provide the fields like `return_type = [:field_name => field_type]`).
- Added a way to attach field-level descriptions to the generated JSON schemas to improve structured extraction (see `?update_schema_descriptions!` for the syntax), which was not possible with struct-only extraction.

## [0.50.0]

### Breaking Changes
Expand Down
2 changes: 1 addition & 1 deletion Project.toml
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
name = "PromptingTools"
uuid = "670122d1-24a8-4d70-bfce-740807c42192"
authors = ["J S @svilupp and contributors"]
version = "0.50.0"
version = "0.51.0"

[deps]
AbstractTrees = "1520ce14-60c1-5f80-bbc7-55ef81b5835c"
Expand Down
4 changes: 2 additions & 2 deletions src/Experimental/RAGTools/evaluation.jl
Original file line number Diff line number Diff line change
Expand Up @@ -17,9 +17,9 @@ end
context::AbstractString
question::AbstractString
answer::AbstractString
retrieval_score::Union{Number, Nothing} = nothing
retrieval_score::Union{Float64, Nothing} = nothing
retrieval_rank::Union{Int, Nothing} = nothing
answer_score::Union{Number, Nothing} = nothing
answer_score::Union{Float64, Nothing} = nothing
parameters::Dict{Symbol, Any} = Dict{Symbol, Any}()
end

Expand Down
222 changes: 190 additions & 32 deletions src/extraction.jl
Original file line number Diff line number Diff line change
Expand Up @@ -93,10 +93,159 @@ function to_json_schema(type::Type{<:AbstractDict}; max_description_length::Int
throw(ArgumentError("Dicts are not supported yet as we cannot analyze their keys/values on a type-level. Use a nested Struct instead!"))
end

### Type conversion / Schema generation
"""
    generate_struct(fields::Vector)

Generate a new struct type with an auto-generated unique name (subtype of
`AbstractExtractedData`, keyword-constructible) from the given `fields`.

Each element of `fields` must be one of:
- a `Symbol`: creates a field of that name with the default type `String`;
- a `Pair` of symbol and type, eg, `:temperature => Float64`: creates a typed field;
- a `Pair` whose key ends with `"__description"` and whose value is a string, eg,
  `:myfield__description => "My field description"`: creates no field, but records
  a description for the field `myfield`.

Returns: A tuple of (struct type, descriptions), where `descriptions` is a
`Dict{Symbol, String}` keyed by the base field name (suffix removed).

Throws an error if an element of `fields` matches none of the forms above, or if a
description value is not a string.

# Examples
```julia
Weather, descriptions = generate_struct(
    [:location,
    :temperature=>Float64,
    :temperature__description=>"Temperature in degrees Fahrenheit",
    :condition=>String,
    :condition__description=>"Current weather condition (e.g., sunny, rainy, cloudy)"
])
```
"""
function generate_struct(fields::Vector)
    name = gensym("ExtractedData")
    struct_fields = Expr[]
    descriptions = Dict{Symbol, String}()
    suffix = "__description"

    for field in fields
        if field isa Symbol
            # Bare symbols default to `String` fields
            push!(struct_fields, :($field::String))
        elseif field isa Pair
            field_name, field_value = field
            field_str = string(field_name)
            if endswith(field_str, suffix)
                field_value isa AbstractString ||
                    error("Invalid field description: $(field). The value of a `__description` entry must be a string.")
                # Strip only the trailing suffix, so a field name that itself
                # contains "__description" elsewhere is left intact
                base_field = Symbol(first(field_str, length(field_str) - length(suffix)))
                descriptions[base_field] = field_value
            elseif field_name isa Symbol &&
                   (field_value isa Type || field_value isa AbstractString)
                # NOTE(review): a string value is spliced verbatim as the field type;
                # confirm this is intended, since it is not converted to a `Type` here.
                push!(struct_fields, :($field_name::$field_value))
            else
                error("Invalid field specification: $(field). It must be a Symbol or a Pair{Symbol, Type} or Pair{Symbol, Pair{Type, String}}.")
            end
        else
            error("Invalid field specification: $(field). It must be a Symbol or a Pair{Symbol, Type} or Pair{Symbol, Pair{Type, String}}.")
        end
    end

    struct_def = quote
        @kwdef struct $name <: AbstractExtractedData
            $(struct_fields...)
        end
    end

    # Struct definitions are only allowed at top level, hence the `eval`.
    eval(struct_def)

    # Look up the freshly defined type by its generated name
    return eval(name), descriptions
end

"""
    update_schema_descriptions!(
        schema::Dict{String, <:Any}, descriptions::Dict{Symbol, <:AbstractString};
        max_description_length::Int = 200)

Attach field-level descriptions to a JSON schema, mutating `schema` in place.

For every entry under `schema["parameters"]["properties"]` whose name (as a `Symbol`)
is a key of `descriptions`, the `"description"` key of that property is set to the
corresponding text, truncated to the first `max_description_length` characters.

Note: Only the top-level "properties" are updated; nested properties are not visited.
If `schema` has no "parameters" or no "properties", nothing is changed.

Returns: The same (modified) `schema` dictionary.

# Arguments
- `schema`: A dictionary representing the JSON schema to be updated.
- `descriptions`: A dictionary mapping field names (as symbols) to their descriptions.
- `max_description_length::Int`: Maximum length for descriptions. Defaults to 200.

# Examples
```julia
schema = Dict{String, Any}(
    "name" => "varExtractedData235_extractor",
    "parameters" => Dict{String, Any}(
        "properties" => Dict{String, Any}(
            "location" => Dict{String, Any}("type" => "string"),
            "condition" => Dict{String, Any}("type" => "string"),
            "temperature" => Dict{String, Any}("type" => "number")
        ),
        "required" => ["location", "temperature", "condition"],
        "type" => "object"
    )
)
descriptions = Dict{Symbol, String}(
    :temperature => "Temperature in degrees Fahrenheit",
    :condition => "Current weather condition (e.g., sunny, rainy, cloudy)"
)
update_schema_descriptions!(schema, descriptions)
```
"""
function update_schema_descriptions!(
        schema::Dict{String, <:Any}, descriptions::Dict{Symbol, <:AbstractString};
        max_description_length::Int = 200)
    # Fall back to empty dicts so that a schema without these keys is a no-op
    parameters = get(schema, "parameters", Dict())
    properties = get(parameters, "properties", Dict())

    for (property_name, property_schema) in properties
        text = get(descriptions, Symbol(property_name), nothing)
        isnothing(text) && continue
        property_schema["description"] = first(text, max_description_length)
    end

    return schema
end

"""
    set_properties_strict!(parameters::AbstractDict)

Sets strict mode for the properties of a JSON schema, mutating `parameters` in place.

`parameters` is the schema object holding the "properties" key (and, optionally,
"required"). Nested objects, and arrays whose "items" define "properties", are
processed recursively.

Changes:
- Sets `additionalProperties` to `false`.
- All keys must be included in `required`.
- All optional keys will have `null` added to their type.

Returns: The same (modified) `parameters` dictionary.

Reference: https://platform.openai.com/docs/guides/structured-outputs/supported-schemas
"""
function set_properties_strict!(parameters::AbstractDict)
    parameters["additionalProperties"] = false
    required_fields = get(parameters, "required", String[])
    optional_fields = String[]

    for (key, value) in parameters["properties"]
        if !(key in required_fields)
            push!(optional_fields, key)
            if haskey(value, "type")
                current_type = value["type"]
                # Keep "type" a flat list: extend an existing list instead of nesting it
                value["type"] = if current_type isa AbstractVector
                    "null" in current_type ? current_type : vcat(current_type, "null")
                else
                    [current_type, "null"]
                end
            end
        end

        # Recursively apply to nested properties
        if haskey(value, "properties")
            set_properties_strict!(value)
        elseif haskey(value, "items") && haskey(value["items"], "properties")
            ## if it's an array, we need to skip inside "items"
            set_properties_strict!(value["items"])
        end
    end

    # Originally required keys first, followed by the formerly optional ones
    parameters["required"] = vcat(required_fields, optional_fields)
    return parameters
end

"""
function_call_signature(
datastructtype::Type; strict::Union{Nothing, Bool} = nothing,
max_description_length::Int = 100)
max_description_length::Int = 200)
Extract the argument names, types and docstrings from a struct to create the function call signature in JSON schema.
Expand All @@ -120,7 +269,7 @@ struct MyMeasurement
height::Union{Int,Nothing}
weight::Union{Nothing,Float64}
end
signature = function_call_signature(MyMeasurement)
signature, t = function_call_signature(MyMeasurement)
#
# Dict{String, Any} with 3 entries:
# "name" => "MyMeasurement_extractor"
Expand Down Expand Up @@ -166,7 +315,7 @@ That way, you can handle the error gracefully and get a reason why extraction fa
"""
function function_call_signature(
datastructtype::Type; strict::Union{Nothing, Bool} = nothing,
max_description_length::Int = 100)
max_description_length::Int = 200)
!isstructtype(datastructtype) &&
error("Only Structs are supported (provided type: $datastructtype")
## Standardize the name
Expand All @@ -191,45 +340,54 @@ function function_call_signature(
set_properties_strict!(schema["parameters"])
end
end
return schema
return schema, datastructtype
end

"""
set_properties_strict!(properties::AbstractDict)
function_call_signature(fields::Vector; strict::Union{Nothing, Bool} = nothing, max_description_length::Int = 200)
Sets strict mode for the properties of a JSON schema.
Generate a function call signature schema for a dynamically generated struct based on the provided fields.
Changes:
- Sets `additionalProperties` to `false`.
- All keys must be included in `required`.
- All optional keys will have `null` added to their type.
# Arguments
- `fields::Vector{Union{Symbol, Pair{Symbol, Type}, Pair{Symbol, String}}}`: A vector of field names or pairs of field name and type or string description, eg, `[:field1, :field2, :field3]` or `[:field1 => String, :field2 => Int, :field3 => Float64]` or `[:field1 => String, :field1__description => "Field 1 has the name"]`.
- `strict::Union{Nothing, Bool}`: Whether to enforce strict mode for the schema. Defaults to `nothing`.
- `max_description_length::Int`: Maximum length for descriptions. Defaults to 200.
Reference: https://platform.openai.com/docs/guides/structured-outputs/supported-schemas
# Returns a tuple of (schema, struct type)
- `Dict{String, Any}`: A dictionary representing the function call signature schema.
- `Type`: The struct type used to instantiate the extracted result.
See also `generate_struct`, `aiextract`, `update_schema_descriptions!`.
# Examples
```julia
schema, return_type = function_call_signature([:field1, :field2, :field3])
```
With the field types:
```julia
schema, return_type = function_call_signature([:field1 => String, :field2 => Int, :field3 => Float64])
```
And with the field descriptions:
```julia
schema, return_type = function_call_signature([:field1 => String, :field1__description => "Field 1 has the name"])
```
"""
function set_properties_strict!(parameters::AbstractDict)
parameters["additionalProperties"] = false
required_fields = get(parameters, "required", String[])
optional_fields = String[]
function function_call_signature(fields::Vector;
strict::Union{Nothing, Bool} = nothing, max_description_length::Int = 200)
@assert all(x -> x isa Symbol || x isa Pair, fields) "Invalid return types provided. All fields must be either Symbols or Pairs of Symbol and Type or String"
# Generate the struct and descriptions
datastructtype, descriptions = generate_struct(fields)

for (key, value) in parameters["properties"]
if key ∉ required_fields
push!(optional_fields, key)
if haskey(value, "type")
value["type"] = [value["type"], "null"]
end
end
# Create the schema
schema, _ = function_call_signature(
datastructtype; strict, max_description_length)

# Recursively apply to nested properties
if haskey(value, "properties")
set_properties_strict!(value)
elseif haskey(value, "items") && haskey(value["items"], "properties")
## if it's an array, we need to skip inside "items"
set_properties_strict!(value["items"])
end
end
# Update the schema with descriptions
update_schema_descriptions!(schema, descriptions; max_description_length)

parameters["required"] = vcat(required_fields, optional_fields)
return parameters
return schema, datastructtype
end

######################
Expand Down
31 changes: 27 additions & 4 deletions src/llm_anthropic.jl
Original file line number Diff line number Diff line change
Expand Up @@ -312,7 +312,7 @@ end

"""
aiextract(prompt_schema::AbstractAnthropicSchema, prompt::ALLOWED_PROMPT_TYPE;
return_type::Type,
return_type::Union{Type, Vector},
verbose::Bool = true,
api_key::String = ANTHROPIC_API_KEY,
model::String = MODEL_CHAT,
Expand All @@ -338,6 +338,7 @@ It's effectively a light wrapper around `aigenerate` call, which requires additi
- `prompt`: Can be a string representing the prompt for the AI conversation, a `UserMessage`, a vector of `AbstractMessage` or an `AITemplate`
- `return_type`: A **struct** TYPE representing the information we want to extract. Do not provide a struct instance, only the type.
If the struct has a docstring, it will be provided to the model as well. It's used to enforce structured model outputs or provide more information.
Alternatively, you can provide a vector of field names and their types (see `?generate_struct` function for the syntax).
- `verbose`: A boolean indicating whether to print additional information.
- `api_key`: A string representing the API key for accessing the Anthropic API.
- `model`: A string representing the model to use for generating the response. Can be an alias corresponding to a model ID defined in `MODEL_ALIASES`.
Expand Down Expand Up @@ -443,10 +444,32 @@ Note that when using a prompt template, we provide `data` for the extraction as
Note that the error message refers to a giraffe not being a human,
because in our `MyMeasurement` docstring, we said that it's for people!
Example of using a vector of field names with `aiextract`
```julia
fields = [:location, :temperature => Float64, :condition => String]
msg = aiextract("Extract the following information from the text: location, temperature, condition. Text: The weather in New York is sunny and 72.5 degrees Fahrenheit.";
return_type = fields, model="claudeh")
```
Or simply call `aiextract("some text"; return_type = [:reasoning,:answer], model="claudeh")` to get a Chain of Thought reasoning for extraction task.
It will be returned in a newly generated type, which you can check with `PromptingTools.isextracted(msg.content) == true` to confirm the data has been extracted correctly.
This new syntax also allows you to provide field-level descriptions, which will be passed to the model.
```julia
fields_with_descriptions = [
:location,
:temperature => Float64,
:temperature__description => "Temperature in degrees Fahrenheit",
:condition => String,
:condition__description => "Current weather condition (e.g., sunny, rainy, cloudy)"
]
msg = aiextract("The weather in New York is sunny and 72.5 degrees Fahrenheit."; return_type = fields_with_descriptions, model="claudeh")
```
"""
function aiextract(prompt_schema::AbstractAnthropicSchema, prompt::ALLOWED_PROMPT_TYPE;
return_type::Type,
return_type::Union{Type, Vector},
verbose::Bool = true,
api_key::String = ANTHROPIC_API_KEY,
model::String = MODEL_CHAT,
Expand All @@ -465,7 +488,7 @@ function aiextract(prompt_schema::AbstractAnthropicSchema, prompt::ALLOWED_PROMP
model_id = get(MODEL_ALIASES, model, model)

## Tools definition
sig = function_call_signature(return_type; max_description_length = 100)
sig, datastructtype = function_call_signature(return_type; max_description_length = 100)
tools = [Dict("name" => sig["name"], "description" => get(sig, "description", ""),
"input_schema" => sig["parameters"])]
## update tools to use caching
Expand Down Expand Up @@ -493,7 +516,7 @@ function aiextract(prompt_schema::AbstractAnthropicSchema, prompt::ALLOWED_PROMP
## parse it into object
arguments = JSON3.write(contents[1][:input])
try
JSON3.read(arguments, return_type)
Base.invokelatest(JSON3.read, arguments, datastructtype)
catch e
@warn "There was an error parsing the response: $e. Using the raw response instead."
JSON3.read(arguments) |> copy
Expand Down
7 changes: 7 additions & 0 deletions src/llm_interface.jl
Original file line number Diff line number Diff line change
Expand Up @@ -439,3 +439,10 @@ function response_to_message(schema::AbstractPromptSchema,
sample_id::Union{Nothing, Integer} = nothing) where {T}
throw(ArgumentError("Response unwrapping not implemented for $(typeof(schema)) and $MSG"))
end

### For structured extraction
# Common supertype of all struct types generated on the fly for extraction
abstract type AbstractExtractedData end

# Display extracted data via `dump`, limited to the top level of fields
function Base.show(io::IO, x::AbstractExtractedData)
    return dump(io, x; maxdepth = 1)
end

"Check if the object is an instance of `AbstractExtractedData`"
isextracted(::AbstractExtractedData) = true
isextracted(::Any) = false
Loading

2 comments on commit d04c737

@svilupp
Copy link
Owner Author

@svilupp svilupp commented on d04c737 Sep 3, 2024

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@JuliaRegistrator register

Release notes:

Added

Added more flexible structured extraction with aiextract -> now you can simply provide the field names and, optionally, their types without specifying the struct itself (in aiextract, provide the fields like return_type = [:field_name => field_type]).
Added a way to attach field-level descriptions to the generated JSON schemas to better structured extraction (see ?update_schema_descriptions! to see the syntax), which was not possible with struct-only extraction.

Commits

@JuliaRegistrator
Copy link

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Registration pull request created: JuliaRegistries/General/114405

Tagging

After the above pull request is merged, it is recommended that a tag is created on this repository for the registered package version.

This will be done automatically if the Julia TagBot GitHub Action is installed, or can be done manually through the github interface, or via:

git tag -a v0.51.0 -m "<description of version>" d04c737c161f5b6e7a20da1668335fcdcf0a96c6
git push origin v0.51.0

Please sign in to comment.