feat: Add support for code execution in ChatGoogleGenerativeAI (#564)
davidmigloz authored Oct 8, 2024
1 parent 9f7406f commit 020bc09
Showing 10 changed files with 121 additions and 16 deletions.
29 changes: 29 additions & 0 deletions docs/modules/model_io/models/chat_models/integrations/googleai.md
@@ -147,3 +147,32 @@ final res = await model.invoke(
  PromptValue.string('What’s the weather like in Boston and Madrid right now in celsius?'),
);
```

## Code execution

`ChatGoogleGenerativeAI` supports [code execution](https://ai.google.dev/gemini-api/docs/code-execution?lang=python#billing). To enable it, set `enableCodeExecution` to `true` in the options.

```dart
final chatModel = ChatGoogleGenerativeAI(
  apiKey: apiKey,
  defaultOptions: ChatGoogleGenerativeAIOptions(
    model: 'gemini-1.5-flash',
    enableCodeExecution: true,
  ),
);
final res = await chatModel.invoke(
  PromptValue.string(
    'Calculate the fibonacci sequence up to 10 terms. '
    'Return only the last term without explanations.',
  ),
);
final text = res.output.content;
print(text); // 34
final executableCode = res.metadata['executable_code'] as String;
print(executableCode);
final codeExecutionResult = res.metadata['code_execution_result'] as Map<String, dynamic>;
print(codeExecutionResult);
```
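Note: the mapper added later in this commit serializes `executable_code` and `code_execution_result` as lists of JSON maps (one entry per returned part), so the casts above may not match every version. A minimal, hypothetical sketch (the helper name is illustrative, not part of the API) that inspects the metadata defensively:

```dart
// Hypothetical helper, not part of langchain_google: print the
// code-execution metadata whether it arrives as a single value or as a
// list of JSON maps (one per ExecutableCode / CodeExecutionResult part).
void printCodeExecutionMetadata(Map<String, dynamic> metadata) {
  for (final key in ['executable_code', 'code_execution_result']) {
    final value = metadata[key];
    if (value == null) continue;
    if (value is List) {
      value.forEach(print);
    } else {
      print(value);
    }
  }
}
```

Called as `printCodeExecutionMetadata(res.metadata)` after the `invoke` call above.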
@@ -9,6 +9,7 @@ void main(final List<String> arguments) async {
  await _chatGoogleGenerativeAI();
  await _chatGoogleGenerativeAIMultiModal();
  await _chatOpenAIStreaming();
  await _codeExecution();
}

Future<void> _chatGoogleGenerativeAI() async {
@@ -105,3 +106,29 @@ Future<void> _chatOpenAIStreaming() async {

  chatModel.close();
}

Future<void> _codeExecution() async {
  final apiKey = Platform.environment['GOOGLEAI_API_KEY'];

  final chatModel = ChatGoogleGenerativeAI(
    apiKey: apiKey,
    defaultOptions: const ChatGoogleGenerativeAIOptions(
      model: 'gemini-1.5-flash',
      enableCodeExecution: true,
    ),
  );

  final res = await chatModel.invoke(
    PromptValue.string(
      'Calculate the fibonacci sequence up to 10 terms. '
      'Return only the last term without explanations.',
    ),
  );
  final text = res.output.content;
  print(text); // 34
  final executableCode = res.metadata['executable_code'] as String;
  print(executableCode);
  final codeExecutionResult =
      res.metadata['code_execution_result'] as Map<String, dynamic>;
  print(codeExecutionResult);
}
4 changes: 2 additions & 2 deletions examples/docs_examples/pubspec.lock
@@ -155,10 +155,10 @@ packages:
dependency: transitive
description:
name: google_generative_ai
-sha256: e2f4c0ac13f0898f670ce5ac0dc4501ebe09b96f9d59163724380d9aa82065be
+sha256: "81dae159c89e4d9bdc46955b6f4ee5ae0a291f9e8f990d76f43944e0d6041d4f"
url: "https://pub.dev"
source: hosted
version: "0.4.4"
version: "0.4.6"
google_identity_services_web:
dependency: transitive
description:
4 changes: 2 additions & 2 deletions examples/hello_world_flutter/pubspec.lock
@@ -138,10 +138,10 @@ packages:
dependency: transitive
description:
name: google_generative_ai
-sha256: e2f4c0ac13f0898f670ce5ac0dc4501ebe09b96f9d59163724380d9aa82065be
+sha256: "81dae159c89e4d9bdc46955b6f4ee5ae0a291f9e8f990d76f43944e0d6041d4f"
url: "https://pub.dev"
source: hosted
version: "0.4.4"
version: "0.4.6"
google_identity_services_web:
dependency: transitive
description:
2 changes: 1 addition & 1 deletion melos.yaml
@@ -41,7 +41,7 @@ command:
flutter_markdown: ^0.7.3
freezed_annotation: ^2.4.2
gcloud: ^0.8.13
-google_generative_ai: 0.4.4
+google_generative_ai: ^0.4.6
googleapis: ^13.0.0
googleapis_auth: ^1.6.0
http: ^1.2.2
@@ -332,7 +332,11 @@ class ChatGoogleGenerativeAI
       (options?.responseSchema ?? defaultOptions.responseSchema)
           ?.toSchema(),
     ),
-    (options?.tools ?? defaultOptions.tools)?.toToolList(),
+    (options?.tools ?? defaultOptions.tools).toToolList(
+      enableCodeExecution: options?.enableCodeExecution ??
+          defaultOptions.enableCodeExecution ??
+          false,
+    ),
     (options?.toolChoice ?? defaultOptions.toolChoice)?.toToolConfig(),
   );
 }
@@ -92,6 +92,8 @@ extension GenerateContentResponseMapper on g.GenerateContentResponse {
final g.DataPart p => base64Encode(p.bytes),
final g.FilePart p => p.uri.toString(),
g.FunctionResponse() || g.FunctionCall() => '',
g.ExecutableCode() => '',
g.CodeExecutionResult() => '',
_ => throw AssertionError('Unknown part type: $p'),
},
)
@@ -133,6 +135,14 @@ extension GenerateContentResponseMapper on g.GenerateContentResponse {
)
.toList(growable: false),
'finish_message': candidate.finishMessage,
'executable_code': candidate.content.parts
.whereType<g.ExecutableCode>()
.map((code) => code.toJson())
.toList(growable: false),
'code_execution_result': candidate.content.parts
.whereType<g.CodeExecutionResult>()
.map((result) => result.toJson())
.toList(growable: false),
},
usage: LanguageModelUsage(
promptTokens: usageMetadata?.promptTokenCount,
@@ -189,17 +199,24 @@ extension SafetySettingsMapper on List<ChatGoogleGenerativeAISafetySetting> {
}
}

-extension ChatToolListMapper on List<ToolSpec> {
-  List<g.Tool> toToolList() {
+extension ChatToolListMapper on List<ToolSpec>? {
+  List<g.Tool>? toToolList({required final bool enableCodeExecution}) {
+    if (this == null && !enableCodeExecution) {
+      return null;
+    }
+
     return [
       g.Tool(
-        functionDeclarations: map(
-          (tool) => g.FunctionDeclaration(
-            tool.name,
-            tool.description,
-            tool.inputJsonSchema.toSchema(),
-          ),
-        ).toList(growable: false),
+        functionDeclarations: this
+            ?.map(
+              (tool) => g.FunctionDeclaration(
+                tool.name,
+                tool.description,
+                tool.inputJsonSchema.toSchema(),
+              ),
+            )
+            .toList(growable: false),
+        codeExecution: enableCodeExecution ? g.CodeExecution() : null,
       ),
     ];
   }
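For illustration only (not part of the commit): under the new mapper, enabling code execution with no `ToolSpec`s configured results in a single `g.Tool` that just turns on code execution. A minimal sketch using the `google_generative_ai` types referenced above:

```dart
import 'package:google_generative_ai/google_generative_ai.dart' as g;

// Sketch of what the mapper above produces when enableCodeExecution is true
// and no function tools are configured: one Tool with only codeExecution set.
List<g.Tool>? toolsForCodeExecutionOnly({required bool enableCodeExecution}) {
  if (!enableCodeExecution) return null; // nothing to send to the API
  return [
    g.Tool(codeExecution: g.CodeExecution()),
  ];
}
```

When function tools are also configured, the same `g.Tool` carries both the `functionDeclarations` and the `codeExecution` flag, as shown in the diff.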
12 changes: 12 additions & 0 deletions packages/langchain_google/lib/src/chat_models/google_ai/types.dart
@@ -21,6 +21,7 @@ class ChatGoogleGenerativeAIOptions extends ChatModelOptions {
this.responseMimeType,
this.responseSchema,
this.safetySettings,
this.enableCodeExecution,
super.tools,
super.toolChoice,
super.concurrencyLimit,
@@ -115,6 +116,12 @@ class ChatGoogleGenerativeAIOptions extends ChatModelOptions {
/// the default safety setting for that category.
final List<ChatGoogleGenerativeAISafetySetting>? safetySettings;

/// When code execution is enabled, the model may generate and run code while
/// producing a response to the prompt. When this happens, the code that was
/// executed and its output are included in the response metadata as
/// `metadata['executable_code']` and `metadata['code_execution_result']`.
final bool? enableCodeExecution;

@override
ChatGoogleGenerativeAIOptions copyWith({
final String? model,
@@ -125,6 +132,7 @@ class ChatGoogleGenerativeAIOptions extends ChatModelOptions {
final double? temperature,
final List<String>? stopSequences,
final List<ChatGoogleGenerativeAISafetySetting>? safetySettings,
final bool? enableCodeExecution,
final List<ToolSpec>? tools,
final ChatToolChoice? toolChoice,
final int? concurrencyLimit,
@@ -138,6 +146,7 @@ class ChatGoogleGenerativeAIOptions extends ChatModelOptions {
temperature: temperature ?? this.temperature,
stopSequences: stopSequences ?? this.stopSequences,
safetySettings: safetySettings ?? this.safetySettings,
enableCodeExecution: enableCodeExecution ?? this.enableCodeExecution,
tools: tools ?? this.tools,
toolChoice: toolChoice ?? this.toolChoice,
concurrencyLimit: concurrencyLimit ?? this.concurrencyLimit,
@@ -157,6 +166,7 @@ class ChatGoogleGenerativeAIOptions extends ChatModelOptions {
temperature: other?.temperature,
stopSequences: other?.stopSequences,
safetySettings: other?.safetySettings,
enableCodeExecution: other?.enableCodeExecution,
tools: other?.tools,
toolChoice: other?.toolChoice,
concurrencyLimit: other?.concurrencyLimit,
@@ -173,6 +183,7 @@ class ChatGoogleGenerativeAIOptions extends ChatModelOptions {
temperature == other.temperature &&
stopSequences == other.stopSequences &&
safetySettings == other.safetySettings &&
enableCodeExecution == other.enableCodeExecution &&
tools == other.tools &&
toolChoice == other.toolChoice &&
concurrencyLimit == other.concurrencyLimit;
@@ -188,6 +199,7 @@ class ChatGoogleGenerativeAIOptions extends ChatModelOptions {
temperature.hashCode ^
stopSequences.hashCode ^
safetySettings.hashCode ^
enableCodeExecution.hashCode ^
tools.hashCode ^
toolChoice.hashCode ^
concurrencyLimit.hashCode;
2 changes: 1 addition & 1 deletion packages/langchain_google/pubspec.yaml
@@ -20,7 +20,7 @@ dependencies:
collection: ^1.18.0
fetch_client: ^1.1.2
gcloud: ^0.8.13
-google_generative_ai: 0.4.4
+google_generative_ai: ^0.4.6
googleapis: ^13.0.0
googleapis_auth: ^1.6.0
http: ^1.2.2
@@ -248,5 +248,21 @@ void main() {
      expect(aiMessage2.content, contains('22'));
      expect(aiMessage2.content, contains('25'));
    });

    test('Test code execution', () async {
      final res = await chatModel.invoke(
        PromptValue.string(
          'Calculate the fibonacci sequence up to 10 terms and output the last one.',
        ),
        options: const ChatGoogleGenerativeAIOptions(
          model: 'gemini-1.5-flash',
          enableCodeExecution: true,
        ),
      );
      final text = res.output.content;
      expect(text, contains('34'));
      expect(res.metadata['executable_code'], isNotNull);
      expect(res.metadata['code_execution_result'], isNotNull);
    });
  });
}
