-
Notifications
You must be signed in to change notification settings - Fork 128
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Browse files
Browse the repository at this point in the history
## Related Docs: - https://ai.google.dev/docs/gemini_api_overview - https://ai.google.dev/docs/function_calling ## How to use this provider ```yaml bridge: ai: server: addr: localhost:8000 provider: gemini providers: gemini: api_key: <your-api-key> ``` ## Be careful that 1. The data format describes in api doc is different from actual api. These can be found from the unit tests. 2. You can not set your function with `-` as this will break your api response. Google will drop characters of the function name in its response. So, I have to check this `-`, and replace to `_` if it presence.
- Loading branch information
1 parent
b5317dc
commit 2012a16
Showing
8 changed files
with
1,392 additions
and
21 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,131 @@ | ||
package gemini | ||
|
||
import ( | ||
"encoding/json" | ||
|
||
"github.com/yomorun/yomo/ai" | ||
"github.com/yomorun/yomo/core/ylog" | ||
) | ||
|
||
func convertStandardToFunctionDeclaration(functionDefinition *ai.FunctionDefinition) *FunctionDeclaration { | ||
if functionDefinition == nil { | ||
return nil | ||
} | ||
|
||
return &FunctionDeclaration{ | ||
Name: functionDefinition.Name, | ||
Description: functionDefinition.Description, | ||
Parameters: convertStandardToFunctionParameters(functionDefinition.Parameters), | ||
} | ||
} | ||
|
||
func convertFunctionDeclarationToStandard(functionDefinition *FunctionDeclaration) *ai.FunctionDefinition { | ||
if functionDefinition == nil { | ||
return nil | ||
} | ||
|
||
return &ai.FunctionDefinition{ | ||
Name: functionDefinition.Name, | ||
Description: functionDefinition.Description, | ||
Parameters: convertFunctionParametersToStandard(functionDefinition.Parameters), | ||
} | ||
} | ||
|
||
func convertStandardToFunctionParameters(parameters *ai.FunctionParameters) *FunctionParameters { | ||
if parameters == nil { | ||
return nil | ||
} | ||
|
||
return &FunctionParameters{ | ||
Type: parameters.Type, | ||
Properties: convertStandardToProperty(parameters.Properties), | ||
Required: parameters.Required, | ||
} | ||
} | ||
|
||
func convertFunctionParametersToStandard(parameters *FunctionParameters) *ai.FunctionParameters { | ||
if parameters == nil { | ||
return nil | ||
} | ||
|
||
return &ai.FunctionParameters{ | ||
Type: parameters.Type, | ||
Properties: convertPropertyToStandard(parameters.Properties), | ||
Required: parameters.Required, | ||
} | ||
} | ||
|
||
func convertStandardToProperty(properties map[string]*ai.ParameterProperty) map[string]*Property { | ||
if properties == nil { | ||
return nil | ||
} | ||
|
||
result := make(map[string]*Property) | ||
for k, v := range properties { | ||
result[k] = &Property{ | ||
Type: v.Type, | ||
Description: v.Description, | ||
} | ||
} | ||
return result | ||
} | ||
|
||
func convertPropertyToStandard(properties map[string]*Property) map[string]*ai.ParameterProperty { | ||
if properties == nil { | ||
return nil | ||
} | ||
|
||
result := make(map[string]*ai.ParameterProperty) | ||
for k, v := range properties { | ||
result[k] = &ai.ParameterProperty{ | ||
Type: v.Type, | ||
Description: v.Description, | ||
} | ||
} | ||
return result | ||
} | ||
|
||
// generateJSONSchemaArguments generates the JSON schema arguments from OpenAPI compatible arguments | ||
// https://ai.google.dev/docs/function_calling#how_it_works | ||
func generateJSONSchemaArguments(args map[string]interface{}) string { | ||
schema := make(map[string]interface{}) | ||
|
||
for k, v := range args { | ||
schema[k] = v | ||
} | ||
|
||
schemaJSON, err := json.Marshal(schema) | ||
if err != nil { | ||
return "" | ||
} | ||
|
||
return string(schemaJSON) | ||
} | ||
|
||
func parseAPIResponseBody(respBody []byte) (*Response, error) { | ||
var response *Response | ||
err := json.Unmarshal(respBody, &response) | ||
if err != nil { | ||
ylog.Error("parseAPIResponseBody", "err", err, "respBody", string(respBody)) | ||
return nil, err | ||
} | ||
return response, nil | ||
} | ||
|
||
func parseToolCallFromResponse(response *Response) []ai.ToolCall { | ||
calls := make([]ai.ToolCall, 0) | ||
for _, candidate := range response.Candidates { | ||
fn := candidate.Content.Parts[0].FunctionCall | ||
fd := &ai.FunctionDefinition{ | ||
Name: fn.Name, | ||
Arguments: generateJSONSchemaArguments(fn.Args), | ||
} | ||
call := ai.ToolCall{ | ||
ID: "cc-gemini-id", | ||
Type: "cc-function", | ||
Function: fd, | ||
} | ||
calls = append(calls, call) | ||
} | ||
return calls | ||
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,43 @@ | ||
package gemini | ||
|
||
// RequestBody is the request body | ||
type RequestBody struct { | ||
Contents Contents `json:"contents"` | ||
Tools []Tool `json:"tools"` | ||
} | ||
|
||
// Contents is the contents in RequestBody | ||
type Contents struct { | ||
Role string `json:"role"` | ||
Parts Parts `json:"parts"` | ||
} | ||
|
||
// Parts is the contents.parts in RequestBody | ||
type Parts struct { | ||
Text string `json:"text"` | ||
} | ||
|
||
// Tool is the element of tools in RequestBody | ||
type Tool struct { | ||
FunctionDeclarations []*FunctionDeclaration `json:"function_declarations"` | ||
} | ||
|
||
// FunctionDeclaration is the element of Tool | ||
type FunctionDeclaration struct { | ||
Name string `json:"name"` | ||
Description string `json:"description"` | ||
Parameters *FunctionParameters `json:"parameters"` | ||
} | ||
|
||
// FunctionParameters is the parameters of FunctionDeclaration | ||
type FunctionParameters struct { | ||
Type string `json:"type"` | ||
Properties map[string]*Property `json:"properties"` | ||
Required []string `json:"required"` | ||
} | ||
|
||
// Property is the element of ParameterProperties | ||
type Property struct { | ||
Type string `json:"type"` | ||
Description string `json:"description"` | ||
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,55 @@ | ||
package gemini | ||
|
||
type Response struct { | ||
Candidates []Candidate `json:"candidates"` | ||
PromptFeedback PromptFeedback `json:"promptFeedback"` | ||
// UsageMetadata UsageMetadata `json:"usageMetadata"` | ||
} | ||
|
||
// Candidate is the element of Response | ||
type Candidate struct { | ||
Content *CandidateContent `json:"content"` | ||
FinishReason string `json:"finishReason"` | ||
Index int `json:"index"` | ||
// SafetyRatings []CandidateSafetyRating `json:"safetyRatings"` | ||
} | ||
|
||
// CandidateContent is the content of Candidate | ||
type CandidateContent struct { | ||
Parts []*Part `json:"parts"` | ||
Role string `json:"role"` | ||
} | ||
|
||
// Part is the element of CandidateContent | ||
type Part struct { | ||
FunctionCall *FunctionCall `json:"functionCall"` | ||
} | ||
|
||
// FunctionCall is the functionCall of Part | ||
type FunctionCall struct { | ||
Name string `json:"name"` | ||
Args map[string]interface{} `json:"args"` | ||
} | ||
|
||
// CandidateSafetyRating is the safetyRatings of Candidate | ||
type CandidateSafetyRating struct { | ||
Category string `json:"category"` | ||
Probability string `json:"probability"` | ||
} | ||
|
||
// UsageMetadata is the token usage in Response | ||
type UsageMetadata struct { | ||
PromptTokenCount int `json:"promptTokenCount"` | ||
TotalTokenCount int `json:"totalTokenCount"` | ||
} | ||
|
||
// SafetyRating is the element of PromptFeedback | ||
type SafetyRating struct { | ||
Category string `json:"category"` | ||
Probability string `json:"probability"` | ||
} | ||
|
||
// PromptFeedback is the feedback of Prompt | ||
type PromptFeedback struct { | ||
SafetyRatings []*SafetyRating `json:"safetyRatings"` | ||
} |
Oops, something went wrong.