-
Notifications
You must be signed in to change notification settings - Fork 2.3k
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
ollama[minor]: Port embeddings to ollama package (#6464)
* ollama[minor]: Port embeddings to ollama package * deprecate community embeddings
- Loading branch information
1 parent
ed01967
commit be246a6
Showing
6 changed files
with
215 additions
and
34 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,151 @@ | ||
import { Embeddings, EmbeddingsParams } from "@langchain/core/embeddings"; | ||
import { Ollama } from "ollama/browser"; | ||
import type { Options as OllamaOptions } from "ollama"; | ||
import { OllamaCamelCaseOptions } from "./types.js"; | ||
|
||
/** | ||
* Interface for OllamaEmbeddings parameters. Extends EmbeddingsParams and | ||
* defines additional parameters specific to the OllamaEmbeddings class. | ||
*/ | ||
interface OllamaEmbeddingsParams extends EmbeddingsParams { | ||
/** | ||
* The Ollama model to use for embeddings. | ||
* @default "mxbai-embed-large" | ||
*/ | ||
model?: string; | ||
|
||
/** | ||
* Base URL of the Ollama server | ||
* @default "http://localhost:11434" | ||
*/ | ||
baseUrl?: string; | ||
|
||
/** | ||
* Extra headers to include in the Ollama API request | ||
*/ | ||
headers?: Record<string, string>; | ||
|
||
/** | ||
* Defaults to "5m" | ||
*/ | ||
keepAlive?: string; | ||
|
||
/** | ||
* Whether or not to truncate the input text to fit inside the model's | ||
* context window. | ||
* @default false | ||
*/ | ||
truncate?: boolean; | ||
|
||
/** | ||
* Advanced Ollama API request parameters in camelCase, see | ||
* https://github.com/jmorganca/ollama/blob/main/docs/modelfile.md#valid-parameters-and-values | ||
* for details of the available parameters. | ||
*/ | ||
requestOptions?: OllamaCamelCaseOptions; | ||
} | ||
|
||
export class OllamaEmbeddings extends Embeddings { | ||
model = "mxbai-embed-large"; | ||
|
||
baseUrl = "http://localhost:11434"; | ||
|
||
headers?: Record<string, string>; | ||
|
||
keepAlive = "5m"; | ||
|
||
requestOptions?: Partial<OllamaOptions>; | ||
|
||
client: Ollama; | ||
|
||
truncate = false; | ||
|
||
constructor(fields?: OllamaEmbeddingsParams) { | ||
super({ maxConcurrency: 1, ...fields }); | ||
|
||
this.client = new Ollama({ | ||
host: fields?.baseUrl, | ||
}); | ||
this.baseUrl = fields?.baseUrl ?? this.baseUrl; | ||
|
||
this.model = fields?.model ?? this.model; | ||
this.headers = fields?.headers; | ||
this.keepAlive = fields?.keepAlive ?? this.keepAlive; | ||
this.truncate = fields?.truncate ?? this.truncate; | ||
this.requestOptions = fields?.requestOptions | ||
? this._convertOptions(fields?.requestOptions) | ||
: undefined; | ||
} | ||
|
||
/** convert camelCased Ollama request options like "useMMap" to | ||
* the snake_cased equivalent which the ollama API actually uses. | ||
* Used only for consistency with the llms/Ollama and chatModels/Ollama classes | ||
*/ | ||
_convertOptions( | ||
requestOptions: OllamaCamelCaseOptions | ||
): Partial<OllamaOptions> { | ||
const snakeCasedOptions: Partial<OllamaOptions> = {}; | ||
const mapping: Record<keyof OllamaCamelCaseOptions, string> = { | ||
embeddingOnly: "embedding_only", | ||
frequencyPenalty: "frequency_penalty", | ||
keepAlive: "keep_alive", | ||
logitsAll: "logits_all", | ||
lowVram: "low_vram", | ||
mainGpu: "main_gpu", | ||
mirostat: "mirostat", | ||
mirostatEta: "mirostat_eta", | ||
mirostatTau: "mirostat_tau", | ||
numBatch: "num_batch", | ||
numCtx: "num_ctx", | ||
numGpu: "num_gpu", | ||
numKeep: "num_keep", | ||
numPredict: "num_predict", | ||
numThread: "num_thread", | ||
penalizeNewline: "penalize_newline", | ||
presencePenalty: "presence_penalty", | ||
repeatLastN: "repeat_last_n", | ||
repeatPenalty: "repeat_penalty", | ||
temperature: "temperature", | ||
stop: "stop", | ||
tfsZ: "tfs_z", | ||
topK: "top_k", | ||
topP: "top_p", | ||
typicalP: "typical_p", | ||
useMlock: "use_mlock", | ||
useMmap: "use_mmap", | ||
vocabOnly: "vocab_only", | ||
f16Kv: "f16_kv", | ||
numa: "numa", | ||
seed: "seed", | ||
}; | ||
|
||
for (const [key, value] of Object.entries(requestOptions)) { | ||
const snakeCasedOption = mapping[key as keyof OllamaCamelCaseOptions]; | ||
if (snakeCasedOption) { | ||
snakeCasedOptions[snakeCasedOption as keyof OllamaOptions] = value; | ||
} | ||
} | ||
return snakeCasedOptions; | ||
} | ||
|
||
async embedDocuments(texts: string[]): Promise<number[][]> { | ||
return this.embeddingWithRetry(texts); | ||
} | ||
|
||
async embedQuery(text: string) { | ||
return (await this.embeddingWithRetry([text]))[0]; | ||
} | ||
|
||
private async embeddingWithRetry(texts: string[]): Promise<number[][]> { | ||
const res = await this.caller.call(() => | ||
this.client.embed({ | ||
model: this.model, | ||
input: texts, | ||
keep_alive: this.keepAlive, | ||
options: this.requestOptions, | ||
truncate: this.truncate, | ||
}) | ||
); | ||
return res.embeddings; | ||
} | ||
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1 +1,3 @@ | ||
export * from "./chat_models.js"; | ||
export * from "./embeddings.js"; | ||
export * from "./types.js"; |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,19 @@ | ||
import { test, expect } from "@jest/globals"; | ||
import { OllamaEmbeddings } from "../embeddings.js"; | ||
|
||
test("Test OllamaEmbeddings.embedQuery", async () => { | ||
const embeddings = new OllamaEmbeddings(); | ||
const res = await embeddings.embedQuery("Hello world"); | ||
expect(res).toHaveLength(1024); | ||
expect(typeof res[0]).toBe("number"); | ||
}); | ||
|
||
test("Test OllamaEmbeddings.embedDocuments", async () => { | ||
const embeddings = new OllamaEmbeddings(); | ||
const res = await embeddings.embedDocuments(["Hello world", "Bye bye"]); | ||
expect(res).toHaveLength(2); | ||
expect(res[0]).toHaveLength(1024); | ||
expect(typeof res[0][0]).toBe("number"); | ||
expect(res[1]).toHaveLength(1024); | ||
expect(typeof res[1][0]).toBe("number"); | ||
}); |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,36 @@ | ||
export interface OllamaCamelCaseOptions { | ||
numa?: boolean; | ||
numCtx?: number; | ||
numBatch?: number; | ||
numGpu?: number; | ||
mainGpu?: number; | ||
lowVram?: boolean; | ||
f16Kv?: boolean; | ||
logitsAll?: boolean; | ||
vocabOnly?: boolean; | ||
useMmap?: boolean; | ||
useMlock?: boolean; | ||
embeddingOnly?: boolean; | ||
numThread?: number; | ||
numKeep?: number; | ||
seed?: number; | ||
numPredict?: number; | ||
topK?: number; | ||
topP?: number; | ||
tfsZ?: number; | ||
typicalP?: number; | ||
repeatLastN?: number; | ||
temperature?: number; | ||
repeatPenalty?: number; | ||
presencePenalty?: number; | ||
frequencyPenalty?: number; | ||
mirostat?: number; | ||
mirostatTau?: number; | ||
mirostatEta?: number; | ||
penalizeNewline?: boolean; | ||
/** | ||
* @default "5m" | ||
*/ | ||
keepAlive?: string | number; | ||
stop?: string[]; | ||
} |