Skip to content

Commit

Permalink
feat: Support OpenAI's strict mode for tool calling in ChatOpenAI (#536)
Browse files Browse the repository at this point in the history
  • Loading branch information
davidmigloz authored Aug 21, 2024
1 parent 1834b3a commit 71623f4
Show file tree
Hide file tree
Showing 2 changed files with 41 additions and 5 deletions.
40 changes: 35 additions & 5 deletions packages/langchain_core/lib/src/tools/base.dart
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@ class ToolSpec {
required this.name,
required this.description,
required this.inputJsonSchema,
this.strict = false,
});

/// The unique name of the tool that clearly communicates its purpose.
Expand Down Expand Up @@ -50,18 +51,31 @@ class ToolSpec {
/// ```
final Map<String, dynamic> inputJsonSchema;

/// Whether to enable strict schema adherence when generating the tool call.
/// If set to true, the model will follow the exact schema defined in the
/// [inputJsonSchema] field.
///
/// This is only supported by some providers (e.g. OpenAI). Mind that when
/// enabled, only a subset of JSON Schema may be supported. Check out the
/// provider's tool calling documentation for more information.
final bool strict;

@override
bool operator ==(covariant final ToolSpec other) {
final mapEquals = const DeepCollectionEquality().equals;
return identical(this, other) ||
name == other.name &&
description == other.description &&
mapEquals(inputJsonSchema, other.inputJsonSchema);
mapEquals(inputJsonSchema, other.inputJsonSchema) &&
strict == other.strict;
}

@override
int get hashCode =>
name.hashCode ^ description.hashCode ^ inputJsonSchema.hashCode;
name.hashCode ^
description.hashCode ^
inputJsonSchema.hashCode ^
strict.hashCode;

@override
String toString() {
Expand All @@ -70,6 +84,7 @@ ToolSpec{
name: $name,
description: $description,
inputJsonSchema: $inputJsonSchema,
strict: $strict,
}
''';
}
Expand All @@ -80,6 +95,7 @@ ToolSpec{
'name': name,
'description': description,
'inputJsonSchema': inputJsonSchema,
'strict': strict,
};
}
}
Expand All @@ -102,6 +118,7 @@ abstract base class Tool<Input extends Object, Options extends ToolOptions,
required this.name,
required this.description,
required this.inputJsonSchema,
this.strict = false,
this.returnDirect = false,
this.handleToolError,
final Options? defaultOptions,
Expand All @@ -118,6 +135,9 @@ abstract base class Tool<Input extends Object, Options extends ToolOptions,
@override
final Map<String, dynamic> inputJsonSchema;

@override
final bool strict;

/// Whether to return the tool's output directly.
/// Setting this to true means that after the tool is called,
/// the AgentExecutor will stop looping.
Expand All @@ -132,7 +152,9 @@ abstract base class Tool<Input extends Object, Options extends ToolOptions,
/// purpose.
/// - [description] is used to tell the model how/when/why to use the tool.
/// You can provide few-shot examples as a part of the description.
/// - [inputJsonSchema] is the schema to parse and validate tool's input
/// - [inputJsonSchema] is the schema to parse and validate tool's input.
/// - [strict] whether to enable strict schema adherence when generating the
/// tool call (only supported by some providers).
/// - [func] is the function that will be called when the tool is run.
/// arguments.
/// - [getInputFromJson] is a function that parses the input JSON to the
Expand All @@ -148,6 +170,7 @@ abstract base class Tool<Input extends Object, Options extends ToolOptions,
required final String name,
required final String description,
required final Map<String, dynamic> inputJsonSchema,
final bool strict = false,
required final FutureOr<Output> Function(Input input) func,
Input Function(Map<String, dynamic> json)? getInputFromJson,
final bool returnDirect = false,
Expand All @@ -157,6 +180,7 @@ abstract base class Tool<Input extends Object, Options extends ToolOptions,
name: name,
description: description,
inputJsonSchema: inputJsonSchema,
strict: strict,
function: func,
getInputFromJson: getInputFromJson ?? (json) => json['input'] as Input,
returnDirect: returnDirect,
Expand Down Expand Up @@ -217,19 +241,24 @@ abstract base class Tool<Input extends Object, Options extends ToolOptions,
return identical(this, other) ||
name == other.name &&
description == other.description &&
mapEquals(inputJsonSchema, other.inputJsonSchema);
mapEquals(inputJsonSchema, other.inputJsonSchema) &&
strict == other.strict;
}

@override
int get hashCode =>
name.hashCode ^ description.hashCode ^ inputJsonSchema.hashCode;
name.hashCode ^
description.hashCode ^
inputJsonSchema.hashCode ^
strict.hashCode;

@override
Map<String, dynamic> toJson() {
return {
'name': name,
'description': description,
'inputJsonSchema': inputJsonSchema,
'strict': strict,
};
}
}
Expand All @@ -245,6 +274,7 @@ final class _ToolFunc<Input extends Object, Output extends Object>
required super.name,
required super.description,
required super.inputJsonSchema,
required super.strict,
required FutureOr<Output> Function(Input input) function,
required Input Function(Map<String, dynamic> json) getInputFromJson,
super.returnDirect = false,
Expand Down
6 changes: 6 additions & 0 deletions packages/langchain_core/lib/src/tools/string.dart
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@ abstract base class StringTool<Options extends ToolOptions>
required super.name,
required super.description,
final String inputDescription = 'The input to the tool',
super.strict = false,
super.returnDirect = false,
super.handleToolError,
super.defaultOptions,
Expand All @@ -36,6 +37,8 @@ abstract base class StringTool<Options extends ToolOptions>
/// purpose.
/// - [description] is used to tell the model how/when/why to use the tool.
/// You can provide few-shot examples as a part of the description.
/// - [strict] whether to enable strict schema adherence when generating the
/// tool call (only supported by some providers).
/// - [func] is the function that will be called when the tool is run.
/// - [returnDirect] whether to return the tool's output directly.
/// Setting this to true means that after the tool is called,
Expand All @@ -46,6 +49,7 @@ abstract base class StringTool<Options extends ToolOptions>
required final String name,
required final String description,
final String inputDescription = 'The input to the tool',
final bool strict = false,
required final FutureOr<String> Function(String input) func,
final bool returnDirect = false,
final String Function(ToolException)? handleToolError,
Expand All @@ -54,6 +58,7 @@ abstract base class StringTool<Options extends ToolOptions>
name: name,
description: description,
inputDescription: inputDescription,
strict: strict,
func: func,
returnDirect: returnDirect,
handleToolError: handleToolError,
Expand Down Expand Up @@ -84,6 +89,7 @@ final class _StringToolFunc<Options extends ToolOptions>
required super.name,
required super.description,
super.inputDescription,
required super.strict,
required FutureOr<String> Function(String) func,
super.returnDirect = false,
super.handleToolError,
Expand Down

0 comments on commit 71623f4

Please sign in to comment.