Ollama LLM provider tools support #14623

Open · wants to merge 1 commit into master
124 changes: 95 additions & 29 deletions packages/ai-ollama/src/node/ollama-language-model.ts
@@ -20,6 +20,7 @@ import {
LanguageModelRequest,
LanguageModelRequestMessage,
LanguageModelResponse,
LanguageModelStreamResponse,
LanguageModelStreamResponsePart,
ToolRequest
} from '@theia/ai-core';
@@ -61,51 +62,84 @@ export class OllamaModel implements LanguageModel {
const settings = this.getSettings(request);
const ollama = this.initializeOllama();

if (request.response_format?.type === 'json_schema') {
return this.handleStructuredOutputRequest(ollama, request);
}
const response = await ollama.chat({
const ollamaRequest: ExtendedChatRequest = {
model: this.model,
...this.DEFAULT_REQUEST_SETTINGS,
...settings,
messages: request.messages.map(this.toOllamaMessage),
stream: true,
tools: request.tools?.map(this.toOllamaTool),
tools: request.tools?.map(this.toOllamaTool)
};
const structured = request.response_format?.type === 'json_schema';
return this.dispatchRequest(ollama, ollamaRequest, structured, cancellationToken);
}

protected async dispatchRequest(ollama: Ollama, ollamaRequest: ExtendedChatRequest, structured: boolean, cancellation?: CancellationToken): Promise<LanguageModelResponse> {

// Handle structured output request
if (structured) {
return this.handleStructuredOutputRequest(ollama, ollamaRequest);
}

// Handle tool request - response may call tools
if (ollamaRequest.tools && ollamaRequest.tools.length > 0) {
return this.handleToolsRequest(ollama, ollamaRequest);
}

// Handle standard chat request
const response = await ollama.chat({
...ollamaRequest,
stream: true
});
return this.handleCancellationAndWrapIterator(response, cancellation);
}

cancellationToken?.onCancellationRequested(() => {
response.abort();
protected async handleToolsRequest(ollama: Ollama, chatRequest: ExtendedChatRequest): Promise<LanguageModelResponse> {
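// Issue a blocking (non-streaming) request first: tool calls are only
// reported on complete responses, and their results must be fed back
// to the model before a final answer can be produced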
const response = await ollama.chat({
...chatRequest,
stream: false
});

async function* wrapAsyncIterator<T>(inputIterable: AsyncIterable<ChatResponse>): AsyncIterable<LanguageModelStreamResponsePart> {
for await (const item of inputIterable) {
// TODO handle tool calls
yield { content: item.message.content };
const tools: ToolWithHandler[] = chatRequest.tools ?? [];
if (response.message.tool_calls) {
// Record the assistant message that requested the tool calls once,
// before appending one 'tool' message per executed call
chatRequest.messages.push(response.message);
for (const toolCall of response.message.tool_calls) {
const functionToCall = tools.find(tool => tool.function.name === toolCall.function?.name);
if (functionToCall) {
// Handlers expect the arguments as a JSON string
const args = JSON.stringify(toolCall.function?.arguments);
const funcResult = await functionToCall.handler(args);
chatRequest.messages.push({
role: 'tool',
content: String(funcResult),
});
}
}
// Get the final response from the model, now including the tool outputs
const finalResponse = await ollama.chat({ ...chatRequest, stream: false });
if (finalResponse.message.tool_calls) {
// The model requested further tool calls; handle them recursively
return this.handleToolsRequest(ollama, chatRequest);
}
return { text: finalResponse.message.content };
}
return { stream: wrapAsyncIterator(response) };
return { text: response.message.content };
}

protected async handleStructuredOutputRequest(ollama: Ollama, request: LanguageModelRequest): Promise<LanguageModelParsedResponse> {
const settings = this.getSettings(request);
const result = await ollama.chat({
...settings,
...this.DEFAULT_REQUEST_SETTINGS,
model: this.model,
messages: request.messages.map(this.toOllamaMessage),
protected async handleStructuredOutputRequest(ollama: Ollama, chatRequest: ChatRequest): Promise<LanguageModelParsedResponse> {
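// format: 'json' makes Ollama emit syntactically valid JSON; note that
// the requested JSON schema itself is not enforced by the model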
const response = await ollama.chat({
...chatRequest,
format: 'json',
stream: false,
});
try {
return {
content: result.message.content,
parsed: JSON.parse(result.message.content)
content: response.message.content,
parsed: JSON.parse(response.message.content)
};
} catch (error) {
// TODO use ILogger
console.log('Failed to parse structured response from the language model.', error);
return {
content: result.message.content,
content: response.message.content,
parsed: {}
};
}
@@ -119,11 +153,22 @@ export class OllamaModel implements LanguageModel {
return new Ollama({ host: host });
}

protected toOllamaTool(tool: ToolRequest): Tool {
const transform = (props: Record<string, {
[key: string]: unknown;
type: string;
}> | undefined) => {
protected handleCancellationAndWrapIterator(response: AbortableAsyncIterable<ChatResponse>, token?: CancellationToken): LanguageModelStreamResponse {
token?.onCancellationRequested(() => {
// It may be better to call ollama.abort() here, since we create one client per request
response.abort();
});

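// Map each streamed ChatResponse chunk to a LanguageModelStreamResponsePart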
async function* wrapAsyncIterator(inputIterable: AsyncIterable<ChatResponse>): AsyncIterable<LanguageModelStreamResponsePart> {
for await (const item of inputIterable) {
yield { content: item.message.content };
}
}
return { stream: wrapAsyncIterator(response) };
}

protected toOllamaTool(tool: ToolRequest): ToolWithHandler {
const transform = (props: Record<string, { [key: string]: unknown; type: string; }> | undefined) => {
if (!props) {
return undefined;
}
@@ -148,7 +193,8 @@
required: Object.keys(tool.parameters?.properties ?? {}),
properties: transform(tool.parameters?.properties) ?? {}
},
}
},
handler: tool.handler
};
}

@@ -165,3 +211,23 @@
return { role: 'system', content: '' };
}
}

/**
* Extended Tool containing a handler
* @see Tool
*/
type ToolWithHandler = Tool & { handler: (arg_string: string) => Promise<unknown> };
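
// A hypothetical example of a ToolWithHandler (names are illustrative,
// not part of this change):
//
// const randomNumberTool: ToolWithHandler = {
//     type: 'function',
//     function: {
//         name: 'getRandomNumber',
//         description: 'Returns a random number between 0 and the given maximum',
//         parameters: {
//             type: 'object',
//             required: ['max'],
//             properties: {
//                 max: { type: 'number', description: 'Upper bound (exclusive)' }
//             }
//         }
//     },
//     // The handler receives the arguments as a JSON string (see handleToolsRequest)
//     handler: async argString => {
//         const { max } = JSON.parse(argString);
//         return Math.random() * max;
//     }
// };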

/**
* Extended chat request with mandatory messages and ToolWithHandler tools
*
* @see ChatRequest
* @see ToolWithHandler
*/
type ExtendedChatRequest = ChatRequest & {
messages: Message[]
tools?: ToolWithHandler[]
};

// Ollama doesn't export this type, so we have to define it here
type AbortableAsyncIterable<T> = AsyncIterable<T> & { abort: () => void };
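
// Hypothetical usage sketch (assumes the LanguageModel#request signature from
// '@theia/ai-core'; not part of this change):
//
// const response = await model.request({ messages, tools }, cancellationToken);
// if ('stream' in response) {
//     for await (const part of response.stream) {
//         process.stdout.write(part.content ?? '');
//     }
// } else if ('text' in response) {
//     console.log(response.text);
// }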