ollama[minor]: Port embeddings to ollama package (#6464)
* ollama[minor]: Port embeddings to ollama package

* deprecate community embeddings
bracesproul authored Aug 8, 2024
1 parent ed01967 commit be246a6
Showing 6 changed files with 215 additions and 34 deletions.
3 changes: 3 additions & 0 deletions libs/langchain-community/src/embeddings/ollama.ts
@@ -30,6 +30,9 @@ interface OllamaEmbeddingsParams extends EmbeddingsParams {
   requestOptions?: CamelCasedRequestOptions;
 }

+/**
+ * @deprecated OllamaEmbeddings have been moved to the `@langchain/ollama` package. Install it with `npm install @langchain/ollama`.
+ */
 export class OllamaEmbeddings extends Embeddings {
   model = "llama2";
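For consumers of the deprecated community class, a minimal migration sketch (assuming `@langchain/ollama` is installed; note the default model also changes):

// Before (deprecated):
// import { OllamaEmbeddings } from "@langchain/community/embeddings/ollama";
// After:
import { OllamaEmbeddings } from "@langchain/ollama";

const embeddings = new OllamaEmbeddings(); // default model is now "mxbai-embed-large" instead of "llama2"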
38 changes: 4 additions & 34 deletions libs/langchain-ollama/src/chat_models.ts
@@ -28,6 +28,7 @@ import {
   convertOllamaMessagesToLangChain,
   convertToOllamaMessages,
 } from "./utils.js";
+import { OllamaCamelCaseOptions } from "./types.js";

 export interface ChatOllamaCallOptions extends BaseChatModelCallOptions {
   /**
@@ -55,7 +56,9 @@ export interface PullModelOptions {
 /**
  * Input to chat model class.
  */
-export interface ChatOllamaInput extends BaseChatModelParams {
+export interface ChatOllamaInput
+  extends BaseChatModelParams,
+    OllamaCamelCaseOptions {
   /**
    * The model to invoke. If the model does not exist, it
    * will be pulled.
@@ -75,40 +78,7 @@ export interface ChatOllamaInput extends BaseChatModelParams {
    */
   checkOrPullModel?: boolean;
   streaming?: boolean;
-  numa?: boolean;
-  numCtx?: number;
-  numBatch?: number;
-  numGpu?: number;
-  mainGpu?: number;
-  lowVram?: boolean;
-  f16Kv?: boolean;
-  logitsAll?: boolean;
-  vocabOnly?: boolean;
-  useMmap?: boolean;
-  useMlock?: boolean;
-  embeddingOnly?: boolean;
-  numThread?: number;
-  numKeep?: number;
-  seed?: number;
-  numPredict?: number;
-  topK?: number;
-  topP?: number;
-  tfsZ?: number;
-  typicalP?: number;
-  repeatLastN?: number;
-  temperature?: number;
-  repeatPenalty?: number;
-  presencePenalty?: number;
-  frequencyPenalty?: number;
-  mirostat?: number;
-  mirostatTau?: number;
-  mirostatEta?: number;
-  penalizeNewline?: boolean;
   format?: string;
-  /**
-   * @default "5m"
-   */
-  keepAlive?: string | number;
 }

 /**
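Because ChatOllamaInput now inherits these fields from OllamaCamelCaseOptions, existing constructor calls compile unchanged; a short sketch (the model name is illustrative):

import { ChatOllama } from "@langchain/ollama";

const model = new ChatOllama({
  model: "llama3", // illustrative; pulled automatically when checkOrPullModel is set
  checkOrPullModel: true,
  temperature: 0.2, // now declared on OllamaCamelCaseOptions
  numCtx: 4096, // camelCase options work exactly as before the refactor
});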
151 changes: 151 additions & 0 deletions libs/langchain-ollama/src/embeddings.ts
@@ -0,0 +1,151 @@
import { Embeddings, EmbeddingsParams } from "@langchain/core/embeddings";
import { Ollama } from "ollama/browser";
import type { Options as OllamaOptions } from "ollama";
import { OllamaCamelCaseOptions } from "./types.js";

/**
 * Interface for OllamaEmbeddings parameters. Extends EmbeddingsParams and
 * defines additional parameters specific to the OllamaEmbeddings class.
 */
interface OllamaEmbeddingsParams extends EmbeddingsParams {
  /**
   * The Ollama model to use for embeddings.
   * @default "mxbai-embed-large"
   */
  model?: string;

  /**
   * Base URL of the Ollama server.
   * @default "http://localhost:11434"
   */
  baseUrl?: string;

  /**
   * Extra headers to include in the Ollama API request.
   */
  headers?: Record<string, string>;

  /**
   * Defaults to "5m".
   */
  keepAlive?: string;

  /**
   * Whether or not to truncate the input text to fit inside the model's
   * context window.
   * @default false
   */
  truncate?: boolean;

  /**
   * Advanced Ollama API request parameters in camelCase, see
   * https://github.com/jmorganca/ollama/blob/main/docs/modelfile.md#valid-parameters-and-values
   * for details of the available parameters.
   */
  requestOptions?: OllamaCamelCaseOptions;
}

export class OllamaEmbeddings extends Embeddings {
  model = "mxbai-embed-large";

  baseUrl = "http://localhost:11434";

  headers?: Record<string, string>;

  keepAlive = "5m";

  requestOptions?: Partial<OllamaOptions>;

  client: Ollama;

  truncate = false;

  constructor(fields?: OllamaEmbeddingsParams) {
    super({ maxConcurrency: 1, ...fields });

    this.client = new Ollama({
      host: fields?.baseUrl,
    });
    this.baseUrl = fields?.baseUrl ?? this.baseUrl;

    this.model = fields?.model ?? this.model;
    this.headers = fields?.headers;
    this.keepAlive = fields?.keepAlive ?? this.keepAlive;
    this.truncate = fields?.truncate ?? this.truncate;
    this.requestOptions = fields?.requestOptions
      ? this._convertOptions(fields?.requestOptions)
      : undefined;
  }

  /**
   * Convert camelCased Ollama request options like "useMmap" to
   * the snake_cased equivalents which the Ollama API actually uses.
   * Used only for consistency with the llms/Ollama and chatModels/Ollama classes.
   */
  _convertOptions(
    requestOptions: OllamaCamelCaseOptions
  ): Partial<OllamaOptions> {
    const snakeCasedOptions: Partial<OllamaOptions> = {};
    const mapping: Record<keyof OllamaCamelCaseOptions, string> = {
      embeddingOnly: "embedding_only",
      frequencyPenalty: "frequency_penalty",
      keepAlive: "keep_alive",
      logitsAll: "logits_all",
      lowVram: "low_vram",
      mainGpu: "main_gpu",
      mirostat: "mirostat",
      mirostatEta: "mirostat_eta",
      mirostatTau: "mirostat_tau",
      numBatch: "num_batch",
      numCtx: "num_ctx",
      numGpu: "num_gpu",
      numKeep: "num_keep",
      numPredict: "num_predict",
      numThread: "num_thread",
      penalizeNewline: "penalize_newline",
      presencePenalty: "presence_penalty",
      repeatLastN: "repeat_last_n",
      repeatPenalty: "repeat_penalty",
      temperature: "temperature",
      stop: "stop",
      tfsZ: "tfs_z",
      topK: "top_k",
      topP: "top_p",
      typicalP: "typical_p",
      useMlock: "use_mlock",
      useMmap: "use_mmap",
      vocabOnly: "vocab_only",
      f16Kv: "f16_kv",
      numa: "numa",
      seed: "seed",
    };

    for (const [key, value] of Object.entries(requestOptions)) {
      const snakeCasedOption = mapping[key as keyof OllamaCamelCaseOptions];
      if (snakeCasedOption) {
        snakeCasedOptions[snakeCasedOption as keyof OllamaOptions] = value;
      }
    }
    return snakeCasedOptions;
  }

  async embedDocuments(texts: string[]): Promise<number[][]> {
    return this.embeddingWithRetry(texts);
  }

  async embedQuery(text: string) {
    return (await this.embeddingWithRetry([text]))[0];
  }

  private async embeddingWithRetry(texts: string[]): Promise<number[][]> {
    const res = await this.caller.call(() =>
      this.client.embed({
        model: this.model,
        input: texts,
        keep_alive: this.keepAlive,
        options: this.requestOptions,
        truncate: this.truncate,
      })
    );
    return res.embeddings;
  }
}
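A usage sketch for the new class, assuming a local Ollama server at the default address with the model already pulled:

import { OllamaEmbeddings } from "@langchain/ollama";

const embeddings = new OllamaEmbeddings({
  model: "mxbai-embed-large",
  baseUrl: "http://localhost:11434",
  truncate: true, // fit oversized inputs to the model's context window
  requestOptions: { useMmap: true, numThread: 4 }, // camelCase; converted to use_mmap / num_thread internally
});

const query = await embeddings.embedQuery("Hello world"); // number[]
const docs = await embeddings.embedDocuments(["first", "second"]); // number[][]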
2 changes: 2 additions & 0 deletions libs/langchain-ollama/src/index.ts
@@ -1 +1,3 @@
 export * from "./chat_models.js";
+export * from "./embeddings.js";
+export * from "./types.js";
19 changes: 19 additions & 0 deletions libs/langchain-ollama/src/tests/embeddings.int.test.ts
@@ -0,0 +1,19 @@
import { test, expect } from "@jest/globals";
import { OllamaEmbeddings } from "../embeddings.js";

test("Test OllamaEmbeddings.embedQuery", async () => {
  const embeddings = new OllamaEmbeddings();
  const res = await embeddings.embedQuery("Hello world");
  expect(res).toHaveLength(1024);
  expect(typeof res[0]).toBe("number");
});

test("Test OllamaEmbeddings.embedDocuments", async () => {
  const embeddings = new OllamaEmbeddings();
  const res = await embeddings.embedDocuments(["Hello world", "Bye bye"]);
  expect(res).toHaveLength(2);
  expect(res[0]).toHaveLength(1024);
  expect(typeof res[0][0]).toBe("number");
  expect(res[1]).toHaveLength(1024);
  expect(typeof res[1][0]).toBe("number");
});
36 changes: 36 additions & 0 deletions libs/langchain-ollama/src/types.ts
@@ -0,0 +1,36 @@
export interface OllamaCamelCaseOptions {
  numa?: boolean;
  numCtx?: number;
  numBatch?: number;
  numGpu?: number;
  mainGpu?: number;
  lowVram?: boolean;
  f16Kv?: boolean;
  logitsAll?: boolean;
  vocabOnly?: boolean;
  useMmap?: boolean;
  useMlock?: boolean;
  embeddingOnly?: boolean;
  numThread?: number;
  numKeep?: number;
  seed?: number;
  numPredict?: number;
  topK?: number;
  topP?: number;
  tfsZ?: number;
  typicalP?: number;
  repeatLastN?: number;
  temperature?: number;
  repeatPenalty?: number;
  presencePenalty?: number;
  frequencyPenalty?: number;
  mirostat?: number;
  mirostatTau?: number;
  mirostatEta?: number;
  penalizeNewline?: boolean;
  /**
   * @default "5m"
   */
  keepAlive?: string | number;
  stop?: string[];
}
