Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: Add support for code execution in ChatGoogleGenerativeAI #564

Merged
merged 1 commit into from
Oct 8, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
29 changes: 29 additions & 0 deletions docs/modules/model_io/models/chat_models/integrations/googleai.md
Original file line number Diff line number Diff line change
Expand Up @@ -147,3 +147,32 @@ final res = await model.invoke(
PromptValue.string('What’s the weather like in Boston and Madrid right now in celsius?'),
);
```

## Code execution

`ChatGoogleGenerativeAI` supports [code execution](https://ai.google.dev/gemini-api/docs/code-execution?lang=python#billing), just set `enableCodeExecution` to `true` in the options.

```dart
final chatModel = ChatGoogleGenerativeAI(
apiKey: apiKey,
defaultOptions: ChatGoogleGenerativeAIOptions(
model: 'gemini-1.5-flash',
enableCodeExecution: true,
),
);
final res = await chatModel.invoke(
PromptValue.string(
'Calculate the fibonacci sequence up to 10 terms. '
'Return only the last term without explanations.',
),
);

final text = res.output.content;
print(text); // 34

final executableCode = res.metadata['executable_code'] as String;
print(executableCode);

final codeExecutionResult = res.metadata['code_execution_result'] as Map<String, dynamic>;
print(codeExecutionResult);
```
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@ void main(final List<String> arguments) async {
await _chatGoogleGenerativeAI();
await _chatGoogleGenerativeAIMultiModal();
await _chatOpenAIStreaming();
await _codeExecution();
}

Future<void> _chatGoogleGenerativeAI() async {
Expand Down Expand Up @@ -105,3 +106,29 @@ Future<void> _chatOpenAIStreaming() async {

chatModel.close();
}

Future<void> _codeExecution() async {
final apiKey = Platform.environment['GOOGLEAI_API_KEY'];

final chatModel = ChatGoogleGenerativeAI(
apiKey: apiKey,
defaultOptions: const ChatGoogleGenerativeAIOptions(
model: 'gemini-1.5-flash',
enableCodeExecution: true,
),
);

final res = await chatModel.invoke(
PromptValue.string(
'Calculate the fibonacci sequence up to 10 terms. '
'Return only the last term without explanations.',
),
);
final text = res.output.content;
print(text); // 34
final executableCode = res.metadata['executable_code'] as String;
print(executableCode);
final codeExecutionResult =
res.metadata['code_execution_result'] as Map<String, dynamic>;
print(codeExecutionResult);
}
4 changes: 2 additions & 2 deletions examples/docs_examples/pubspec.lock
Original file line number Diff line number Diff line change
Expand Up @@ -155,10 +155,10 @@ packages:
dependency: transitive
description:
name: google_generative_ai
sha256: e2f4c0ac13f0898f670ce5ac0dc4501ebe09b96f9d59163724380d9aa82065be
sha256: "81dae159c89e4d9bdc46955b6f4ee5ae0a291f9e8f990d76f43944e0d6041d4f"
url: "https://pub.dev"
source: hosted
version: "0.4.4"
version: "0.4.6"
google_identity_services_web:
dependency: transitive
description:
Expand Down
4 changes: 2 additions & 2 deletions examples/hello_world_flutter/pubspec.lock
Original file line number Diff line number Diff line change
Expand Up @@ -138,10 +138,10 @@ packages:
dependency: transitive
description:
name: google_generative_ai
sha256: e2f4c0ac13f0898f670ce5ac0dc4501ebe09b96f9d59163724380d9aa82065be
sha256: "81dae159c89e4d9bdc46955b6f4ee5ae0a291f9e8f990d76f43944e0d6041d4f"
url: "https://pub.dev"
source: hosted
version: "0.4.4"
version: "0.4.6"
google_identity_services_web:
dependency: transitive
description:
Expand Down
2 changes: 1 addition & 1 deletion melos.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -41,7 +41,7 @@ command:
flutter_markdown: ^0.7.3
freezed_annotation: ^2.4.2
gcloud: ^0.8.13
google_generative_ai: 0.4.4
google_generative_ai: ^0.4.6
googleapis: ^13.0.0
googleapis_auth: ^1.6.0
http: ^1.2.2
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -332,7 +332,11 @@ class ChatGoogleGenerativeAI
(options?.responseSchema ?? defaultOptions.responseSchema)
?.toSchema(),
),
(options?.tools ?? defaultOptions.tools)?.toToolList(),
(options?.tools ?? defaultOptions.tools).toToolList(
enableCodeExecution: options?.enableCodeExecution ??
defaultOptions.enableCodeExecution ??
false,
),
(options?.toolChoice ?? defaultOptions.toolChoice)?.toToolConfig(),
);
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -92,6 +92,8 @@ extension GenerateContentResponseMapper on g.GenerateContentResponse {
final g.DataPart p => base64Encode(p.bytes),
final g.FilePart p => p.uri.toString(),
g.FunctionResponse() || g.FunctionCall() => '',
g.ExecutableCode() => '',
g.CodeExecutionResult() => '',
_ => throw AssertionError('Unknown part type: $p'),
},
)
Expand Down Expand Up @@ -133,6 +135,14 @@ extension GenerateContentResponseMapper on g.GenerateContentResponse {
)
.toList(growable: false),
'finish_message': candidate.finishMessage,
'executable_code': candidate.content.parts
.whereType<g.ExecutableCode>()
.map((code) => code.toJson())
.toList(growable: false),
'code_execution_result': candidate.content.parts
.whereType<g.CodeExecutionResult>()
.map((result) => result.toJson())
.toList(growable: false),
},
usage: LanguageModelUsage(
promptTokens: usageMetadata?.promptTokenCount,
Expand Down Expand Up @@ -189,17 +199,24 @@ extension SafetySettingsMapper on List<ChatGoogleGenerativeAISafetySetting> {
}
}

extension ChatToolListMapper on List<ToolSpec> {
List<g.Tool> toToolList() {
extension ChatToolListMapper on List<ToolSpec>? {
List<g.Tool>? toToolList({required final bool enableCodeExecution}) {
if (this == null && !enableCodeExecution) {
return null;
}

return [
g.Tool(
functionDeclarations: map(
(tool) => g.FunctionDeclaration(
tool.name,
tool.description,
tool.inputJsonSchema.toSchema(),
),
).toList(growable: false),
functionDeclarations: this
?.map(
(tool) => g.FunctionDeclaration(
tool.name,
tool.description,
tool.inputJsonSchema.toSchema(),
),
)
.toList(growable: false),
codeExecution: enableCodeExecution ? g.CodeExecution() : null,
),
];
}
Expand Down
12 changes: 12 additions & 0 deletions packages/langchain_google/lib/src/chat_models/google_ai/types.dart
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@ class ChatGoogleGenerativeAIOptions extends ChatModelOptions {
this.responseMimeType,
this.responseSchema,
this.safetySettings,
this.enableCodeExecution,
super.tools,
super.toolChoice,
super.concurrencyLimit,
Expand Down Expand Up @@ -115,6 +116,12 @@ class ChatGoogleGenerativeAIOptions extends ChatModelOptions {
/// the default safety setting for that category.
final List<ChatGoogleGenerativeAISafetySetting>? safetySettings;

/// When code execution is enabled the model may generate code and run it in the
/// process of generating a response to the prompt. When this happens the code
/// that was executed and it's output will be included in the response metadata
/// as `metadata['executable_code']` and `metadata['code_execution_result']`.
final bool? enableCodeExecution;

@override
ChatGoogleGenerativeAIOptions copyWith({
final String? model,
Expand All @@ -125,6 +132,7 @@ class ChatGoogleGenerativeAIOptions extends ChatModelOptions {
final double? temperature,
final List<String>? stopSequences,
final List<ChatGoogleGenerativeAISafetySetting>? safetySettings,
final bool? enableCodeExecution,
final List<ToolSpec>? tools,
final ChatToolChoice? toolChoice,
final int? concurrencyLimit,
Expand All @@ -138,6 +146,7 @@ class ChatGoogleGenerativeAIOptions extends ChatModelOptions {
temperature: temperature ?? this.temperature,
stopSequences: stopSequences ?? this.stopSequences,
safetySettings: safetySettings ?? this.safetySettings,
enableCodeExecution: enableCodeExecution ?? this.enableCodeExecution,
tools: tools ?? this.tools,
toolChoice: toolChoice ?? this.toolChoice,
concurrencyLimit: concurrencyLimit ?? this.concurrencyLimit,
Expand All @@ -157,6 +166,7 @@ class ChatGoogleGenerativeAIOptions extends ChatModelOptions {
temperature: other?.temperature,
stopSequences: other?.stopSequences,
safetySettings: other?.safetySettings,
enableCodeExecution: other?.enableCodeExecution,
tools: other?.tools,
toolChoice: other?.toolChoice,
concurrencyLimit: other?.concurrencyLimit,
Expand All @@ -173,6 +183,7 @@ class ChatGoogleGenerativeAIOptions extends ChatModelOptions {
temperature == other.temperature &&
stopSequences == other.stopSequences &&
safetySettings == other.safetySettings &&
enableCodeExecution == other.enableCodeExecution &&
tools == other.tools &&
toolChoice == other.toolChoice &&
concurrencyLimit == other.concurrencyLimit;
Expand All @@ -188,6 +199,7 @@ class ChatGoogleGenerativeAIOptions extends ChatModelOptions {
temperature.hashCode ^
stopSequences.hashCode ^
safetySettings.hashCode ^
enableCodeExecution.hashCode ^
tools.hashCode ^
toolChoice.hashCode ^
concurrencyLimit.hashCode;
Expand Down
2 changes: 1 addition & 1 deletion packages/langchain_google/pubspec.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@ dependencies:
collection: ^1.18.0
fetch_client: ^1.1.2
gcloud: ^0.8.13
google_generative_ai: 0.4.4
google_generative_ai: ^0.4.6
googleapis: ^13.0.0
googleapis_auth: ^1.6.0
http: ^1.2.2
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -248,5 +248,21 @@ void main() {
expect(aiMessage2.content, contains('22'));
expect(aiMessage2.content, contains('25'));
});

test('Test code execution', () async {
final res = await chatModel.invoke(
PromptValue.string(
'Calculate the fibonacci sequence up to 10 terms and output the last one.',
),
options: const ChatGoogleGenerativeAIOptions(
model: 'gemini-1.5-flash',
enableCodeExecution: true,
),
);
final text = res.output.content;
expect(text, contains('34'));
expect(res.metadata['executable_code'], isNotNull);
expect(res.metadata['code_execution_result'], isNotNull);
});
});
}
Loading