Skip to content

Commit

Permalink
feat: Handle refusal in OpenAI's Structured Outputs API (#533)
Browse files Browse the repository at this point in the history
  • Loading branch information
davidmigloz authored Aug 21, 2024
1 parent 68d8011 commit f4c4ed9
Show file tree
Hide file tree
Showing 3 changed files with 129 additions and 74 deletions.
53 changes: 4 additions & 49 deletions packages/langchain_openai/lib/src/chat_models/chat_openai.dart
Original file line number Diff line number Diff line change
Expand Up @@ -248,9 +248,10 @@ class ChatOpenAI extends BaseChatModel<ChatOpenAIOptions> {
final ChatOpenAIOptions? options,
}) async {
final completion = await _client.createChatCompletion(
request: _createChatCompletionRequest(
request: createChatCompletionRequest(
input.toChatMessages(),
options: options,
defaultOptions: defaultOptions,
),
);
return completion.toChatResult(completion.id ?? _uuid.v4());
Expand All @@ -263,9 +264,10 @@ class ChatOpenAI extends BaseChatModel<ChatOpenAIOptions> {
}) {
return _client
.createChatCompletionStream(
request: _createChatCompletionRequest(
request: createChatCompletionRequest(
input.toChatMessages(),
options: options,
defaultOptions: defaultOptions,
stream: true,
),
)
Expand All @@ -275,53 +277,6 @@ class ChatOpenAI extends BaseChatModel<ChatOpenAIOptions> {
);
}

/// Creates a [CreateChatCompletionRequest] from the given input.
CreateChatCompletionRequest _createChatCompletionRequest(
final List<ChatMessage> messages, {
final ChatOpenAIOptions? options,
final bool stream = false,
}) {
final messagesDtos = messages.toChatCompletionMessages();
final toolsDtos =
(options?.tools ?? defaultOptions.tools)?.toChatCompletionTool();
final toolChoice = (options?.toolChoice ?? defaultOptions.toolChoice)
?.toChatCompletionToolChoice();
final responseFormatDto =
(options?.responseFormat ?? defaultOptions.responseFormat)
?.toChatCompletionResponseFormat();
final serviceTierDto = (options?.serviceTier ?? defaultOptions.serviceTier)
.toCreateChatCompletionRequestServiceTier();

return CreateChatCompletionRequest(
model: ChatCompletionModel.modelId(
options?.model ?? defaultOptions.model ?? defaultModel,
),
messages: messagesDtos,
tools: toolsDtos,
toolChoice: toolChoice,
frequencyPenalty:
options?.frequencyPenalty ?? defaultOptions.frequencyPenalty,
logitBias: options?.logitBias ?? defaultOptions.logitBias,
maxTokens: options?.maxTokens ?? defaultOptions.maxTokens,
n: options?.n ?? defaultOptions.n,
presencePenalty:
options?.presencePenalty ?? defaultOptions.presencePenalty,
responseFormat: responseFormatDto,
seed: options?.seed ?? defaultOptions.seed,
stop: (options?.stop ?? defaultOptions.stop) != null
? ChatCompletionStop.listString(options?.stop ?? defaultOptions.stop!)
: null,
temperature: options?.temperature ?? defaultOptions.temperature,
topP: options?.topP ?? defaultOptions.topP,
parallelToolCalls:
options?.parallelToolCalls ?? defaultOptions.parallelToolCalls,
serviceTier: serviceTierDto,
user: options?.user ?? defaultOptions.user,
streamOptions:
stream ? const ChatCompletionStreamOptions(includeUsage: true) : null,
);
}

/// Tokenizes the given prompt using tiktoken with the encoding used by the
/// [model]. If an encoding model is specified in [encoding] field, that
/// encoding is used instead.
Expand Down
128 changes: 103 additions & 25 deletions packages/langchain_openai/lib/src/chat_models/mappers.dart
Original file line number Diff line number Diff line change
Expand Up @@ -6,45 +6,93 @@ import 'package:langchain_core/language_models.dart';
import 'package:langchain_core/tools.dart';
import 'package:openai_dart/openai_dart.dart';

import 'chat_openai.dart';
import 'types.dart';

/// Creates a [CreateChatCompletionRequest] from the given input.
CreateChatCompletionRequest createChatCompletionRequest(
final List<ChatMessage> messages, {
required final ChatOpenAIOptions? options,
required final ChatOpenAIOptions defaultOptions,
final bool stream = false,
}) {
final messagesDtos = messages.toChatCompletionMessages();
final toolsDtos =
(options?.tools ?? defaultOptions.tools)?.toChatCompletionTool();
final toolChoice = (options?.toolChoice ?? defaultOptions.toolChoice)
?.toChatCompletionToolChoice();
final responseFormatDto =
(options?.responseFormat ?? defaultOptions.responseFormat)
?.toChatCompletionResponseFormat();
final serviceTierDto = (options?.serviceTier ?? defaultOptions.serviceTier)
.toCreateChatCompletionRequestServiceTier();

return CreateChatCompletionRequest(
model: ChatCompletionModel.modelId(
options?.model ?? defaultOptions.model ?? ChatOpenAI.defaultModel,
),
messages: messagesDtos,
tools: toolsDtos,
toolChoice: toolChoice,
frequencyPenalty:
options?.frequencyPenalty ?? defaultOptions.frequencyPenalty,
logitBias: options?.logitBias ?? defaultOptions.logitBias,
maxTokens: options?.maxTokens ?? defaultOptions.maxTokens,
n: options?.n ?? defaultOptions.n,
presencePenalty: options?.presencePenalty ?? defaultOptions.presencePenalty,
responseFormat: responseFormatDto,
seed: options?.seed ?? defaultOptions.seed,
stop: (options?.stop ?? defaultOptions.stop) != null
? ChatCompletionStop.listString(options?.stop ?? defaultOptions.stop!)
: null,
temperature: options?.temperature ?? defaultOptions.temperature,
topP: options?.topP ?? defaultOptions.topP,
parallelToolCalls:
options?.parallelToolCalls ?? defaultOptions.parallelToolCalls,
serviceTier: serviceTierDto,
user: options?.user ?? defaultOptions.user,
streamOptions:
stream ? const ChatCompletionStreamOptions(includeUsage: true) : null,
);
}

extension ChatMessageListMapper on List<ChatMessage> {
List<ChatCompletionMessage> toChatCompletionMessages() {
return map(_mapMessage).toList(growable: false);
}

ChatCompletionMessage _mapMessage(final ChatMessage msg) {
return switch (msg) {
final SystemChatMessage systemChatMessage => ChatCompletionMessage.system(
content: systemChatMessage.content,
),
final HumanChatMessage humanChatMessage => ChatCompletionMessage.user(
content: switch (humanChatMessage.content) {
final ChatMessageContentText c => _mapMessageContentString(c),
final ChatMessageContentImage c =>
ChatCompletionUserMessageContent.parts(
[_mapMessageContentPartImage(c)],
),
final ChatMessageContentMultiModal c => _mapMessageContentPart(c),
},
),
final AIChatMessage aiChatMessage => ChatCompletionMessage.assistant(
content: aiChatMessage.content,
toolCalls: aiChatMessage.toolCalls.isNotEmpty
? aiChatMessage.toolCalls
.map(_mapMessageToolCall)
.toList(growable: false)
: null,
),
final ToolChatMessage toolChatMessage => ChatCompletionMessage.tool(
toolCallId: toolChatMessage.toolCallId,
content: toolChatMessage.content,
),
final SystemChatMessage msg => _mapSystemMessage(msg),
final HumanChatMessage msg => _mapHumanMessage(msg),
final AIChatMessage msg => _mapAIMessage(msg),
final ToolChatMessage msg => _mapToolMessage(msg),
CustomChatMessage() =>
throw UnsupportedError('OpenAI does not support custom messages'),
};
}

ChatCompletionMessage _mapSystemMessage(
final SystemChatMessage systemChatMessage,
) {
return ChatCompletionMessage.system(content: systemChatMessage.content);
}

ChatCompletionMessage _mapHumanMessage(
final HumanChatMessage humanChatMessage,
) {
return ChatCompletionMessage.user(
content: switch (humanChatMessage.content) {
final ChatMessageContentText c => _mapMessageContentString(c),
final ChatMessageContentImage c =>
ChatCompletionUserMessageContent.parts(
[_mapMessageContentPartImage(c)],
),
final ChatMessageContentMultiModal c => _mapMessageContentPart(c),
},
);
}

ChatCompletionUserMessageContentString _mapMessageContentString(
final ChatMessageContentText c,
) {
Expand Down Expand Up @@ -105,6 +153,17 @@ extension ChatMessageListMapper on List<ChatMessage> {
return ChatCompletionMessageContentParts(partsList);
}

ChatCompletionMessage _mapAIMessage(final AIChatMessage aiChatMessage) {
return ChatCompletionMessage.assistant(
content: aiChatMessage.content,
toolCalls: aiChatMessage.toolCalls.isNotEmpty
? aiChatMessage.toolCalls
.map(_mapMessageToolCall)
.toList(growable: false)
: null,
);
}

ChatCompletionMessageToolCall _mapMessageToolCall(
final AIChatMessageToolCall toolCall,
) {
Expand All @@ -117,12 +176,26 @@ extension ChatMessageListMapper on List<ChatMessage> {
),
);
}

ChatCompletionMessage _mapToolMessage(
final ToolChatMessage toolChatMessage,
) {
return ChatCompletionMessage.tool(
toolCallId: toolChatMessage.toolCallId,
content: toolChatMessage.content,
);
}
}

extension CreateChatCompletionResponseMapper on CreateChatCompletionResponse {
ChatResult toChatResult(final String id) {
final choice = choices.first;
final msg = choice.message;

if (msg.refusal != null && msg.refusal!.isNotEmpty) {
throw OpenAIRefusalException(msg.refusal!);
}

return ChatResult(
id: id,
output: AIChatMessage(
Expand Down Expand Up @@ -211,6 +284,11 @@ extension CreateChatCompletionStreamResponseMapper
ChatResult toChatResult(final String id) {
final choice = choices.firstOrNull;
final delta = choice?.delta;

if (delta?.refusal != null && delta!.refusal!.isNotEmpty) {
throw OpenAIRefusalException(delta.refusal!);
}

return ChatResult(
id: id,
output: AIChatMessage(
Expand Down
22 changes: 22 additions & 0 deletions packages/langchain_openai/lib/src/chat_models/types.dart
Original file line number Diff line number Diff line change
Expand Up @@ -379,3 +379,25 @@ enum ChatOpenAIServiceTier {
/// uptime SLA and no latency guarantee.
vDefault,
}

/// {@template openai_refusal_exception}
/// Exception thrown when OpenAI Structured Outputs API returns a refusal.
///
/// When using OpenAI's Structured Outputs API with user-generated input, the
/// model may occasionally refuse to fulfill the request for safety reasons.
///
/// See here for more on refusals:
/// https://platform.openai.com/docs/guides/structured-outputs/refusals
/// {@endtemplate}
class OpenAIRefusalException implements Exception {
/// {@macro openai_refusal_exception}
const OpenAIRefusalException(this.message);

/// The refusal message.
final String message;

@override
String toString() {
return 'OpenAIRefusalException: $message';
}
}

0 comments on commit f4c4ed9

Please sign in to comment.