diff --git a/src/LLMProviders/chatModelManager.ts b/src/LLMProviders/chatModelManager.ts index 5ce94814..991043da 100644 --- a/src/LLMProviders/chatModelManager.ts +++ b/src/LLMProviders/chatModelManager.ts @@ -1,7 +1,6 @@ import { CustomModel, LangChainParams, ModelConfig } from "@/aiParams"; import { BUILTIN_CHAT_MODELS, ChatModelProviders } from "@/constants"; import EncryptionService from "@/encryptionService"; -import { ChatAnthropicWrapped, ProxyChatOpenAI } from "@/langchainWrappers"; import { HarmBlockThreshold, HarmCategory } from "@google/generative-ai"; import { ChatCohere } from "@langchain/cohere"; import { BaseChatModel } from "@langchain/core/language_models/chat_models"; @@ -10,6 +9,25 @@ import { ChatGroq } from "@langchain/groq"; import { ChatOllama } from "@langchain/ollama"; import { ChatOpenAI } from "@langchain/openai"; import { Notice } from "obsidian"; +import { safeFetch } from "@/utils"; +import { ChatAnthropic } from "@langchain/anthropic"; + +type ChatConstructorType = new (config: any) => BaseChatModel; + +const CHAT_PROVIDER_CONSTRUCTORS = { + [ChatModelProviders.OPENAI]: ChatOpenAI, + [ChatModelProviders.AZURE_OPENAI]: ChatOpenAI, + [ChatModelProviders.ANTHROPIC]: ChatAnthropic, + [ChatModelProviders.COHEREAI]: ChatCohere, + [ChatModelProviders.GOOGLE]: ChatGoogleGenerativeAI, + [ChatModelProviders.OPENROUTERAI]: ChatOpenAI, + [ChatModelProviders.OLLAMA]: ChatOllama, + [ChatModelProviders.LM_STUDIO]: ChatOpenAI, + [ChatModelProviders.GROQ]: ChatGroq, + [ChatModelProviders.OPENAI_FORMAT]: ChatOpenAI, +} as const; + +type ChatProviderConstructMap = typeof CHAT_PROVIDER_CONSTRUCTORS; export default class ChatModelManager { private encryptionService: EncryptionService; @@ -20,12 +38,24 @@ export default class ChatModelManager { string, { hasApiKey: boolean; - // eslint-disable-next-line @typescript-eslint/no-explicit-any - AIConstructor: new (config: any) => BaseChatModel; + AIConstructor: ChatConstructorType; vendor: string; } >; + private readonly providerApiKeyMap: Record string> = { + [ChatModelProviders.OPENAI]: () => this.getLangChainParams().openAIApiKey, + [ChatModelProviders.GOOGLE]: () => this.getLangChainParams().googleApiKey, + [ChatModelProviders.AZURE_OPENAI]: () => this.getLangChainParams().azureOpenAIApiKey, + [ChatModelProviders.ANTHROPIC]: () => this.getLangChainParams().anthropicApiKey, + [ChatModelProviders.COHEREAI]: () => this.getLangChainParams().cohereApiKey, + [ChatModelProviders.OPENROUTERAI]: () => this.getLangChainParams().openRouterAiApiKey, + [ChatModelProviders.GROQ]: () => this.getLangChainParams().groqApiKey, + [ChatModelProviders.OLLAMA]: () => "default-key", + [ChatModelProviders.LM_STUDIO]: () => "default-key", + [ChatModelProviders.OPENAI_FORMAT]: () => "default-key", + } as const; + private constructor( private getLangChainParams: () => LangChainParams, encryptionService: EncryptionService, @@ -62,16 +92,31 @@ export default class ChatModelManager { enableCors: customModel.enableCors, }; - const providerConfig = { + const providerConfig: { + [K in keyof ChatProviderConstructMap]: ConstructorParameters< + ChatProviderConstructMap[K] + >[0] /*& Record;*/; + } = { [ChatModelProviders.OPENAI]: { modelName: customModel.name, openAIApiKey: decrypt(customModel.apiKey || params.openAIApiKey), + // @ts-ignore openAIOrgId: decrypt(params.openAIOrgId), maxTokens: params.maxTokens, + configuration: { + baseURL: customModel.baseUrl, + fetch: customModel.enableCors ? safeFetch : undefined, + }, }, [ChatModelProviders.ANTHROPIC]: { anthropicApiKey: decrypt(customModel.apiKey || params.anthropicApiKey), modelName: customModel.name, + anthropicApiUrl: customModel.baseUrl, + clientOptions: { + // Required to bypass CORS restrictions + defaultHeaders: { "anthropic-dangerous-direct-browser-access": "true" }, + fetch: customModel.enableCors ? safeFetch : undefined, + }, }, [ChatModelProviders.AZURE_OPENAI]: { maxTokens: params.maxTokens, @@ -79,6 +124,10 @@ export default class ChatModelManager { azureOpenAIApiInstanceName: params.azureOpenAIApiInstanceName, azureOpenAIApiDeploymentName: params.azureOpenAIApiDeploymentName, azureOpenAIApiVersion: params.azureOpenAIApiVersion, + configuration: { + baseURL: customModel.baseUrl, + fetch: customModel.enableCors ? safeFetch : undefined, + }, }, [ChatModelProviders.COHEREAI]: { apiKey: decrypt(customModel.apiKey || params.cohereApiKey), @@ -86,7 +135,7 @@ export default class ChatModelManager { }, [ChatModelProviders.GOOGLE]: { apiKey: decrypt(customModel.apiKey || params.googleApiKey), - modelName: customModel.name, + model: customModel.name, safetySettings: [ { category: HarmCategory.HARM_CATEGORY_SEXUALLY_EXPLICIT, @@ -105,11 +154,15 @@ export default class ChatModelManager { threshold: HarmBlockThreshold.BLOCK_NONE, }, ], + baseUrl: customModel.baseUrl, }, [ChatModelProviders.OPENROUTERAI]: { modelName: customModel.name, openAIApiKey: decrypt(customModel.apiKey || params.openRouterAiApiKey), - openAIProxyBaseUrl: "https://openrouter.ai/api/v1", + configuration: { + baseURL: customModel.baseUrl || "https://openrouter.ai/api/v1", + fetch: customModel.enableCors ? safeFetch : undefined, + }, }, [ChatModelProviders.GROQ]: { apiKey: decrypt(customModel.apiKey || params.groqApiKey), @@ -118,6 +171,7 @@ export default class ChatModelManager { [ChatModelProviders.OLLAMA]: { // ChatOllama has `model` instead of `modelName`!! model: customModel.name, + // @ts-ignore apiKey: customModel.apiKey || "default-key", // MUST NOT use /v1 in the baseUrl for ollama baseUrl: customModel.baseUrl || "http://localhost:11434", @@ -125,13 +179,20 @@ export default class ChatModelManager { [ChatModelProviders.LM_STUDIO]: { modelName: customModel.name, openAIApiKey: customModel.apiKey || "default-key", - openAIProxyBaseUrl: customModel.baseUrl || "http://localhost:1234/v1", + configuration: { + baseURL: customModel.baseUrl || "http://localhost:1234/v1", + fetch: customModel.enableCors ? safeFetch : undefined, + }, }, [ChatModelProviders.OPENAI_FORMAT]: { modelName: customModel.name, openAIApiKey: decrypt(customModel.apiKey || "default-key"), maxTokens: params.maxTokens, - openAIProxyBaseUrl: customModel.baseUrl || "", + configuration: { + baseURL: customModel.baseUrl, + fetch: customModel.enableCors ? safeFetch : undefined, + dangerouslyAllowBrowser: true, + }, }, }; @@ -150,66 +211,35 @@ export default class ChatModelManager { allModels.forEach((model) => { if (model.enabled) { - let constructor; - let apiKey; - - switch (model.provider) { - case ChatModelProviders.OPENAI: - constructor = ChatOpenAI; - apiKey = model.apiKey || this.getLangChainParams().openAIApiKey; - break; - case ChatModelProviders.GOOGLE: - constructor = ChatGoogleGenerativeAI; - apiKey = model.apiKey || this.getLangChainParams().googleApiKey; - break; - case ChatModelProviders.AZURE_OPENAI: - constructor = ChatOpenAI; - apiKey = model.apiKey || this.getLangChainParams().azureOpenAIApiKey; - break; - case ChatModelProviders.ANTHROPIC: - constructor = ChatAnthropicWrapped; - apiKey = model.apiKey || this.getLangChainParams().anthropicApiKey; - break; - case ChatModelProviders.COHEREAI: - constructor = ChatCohere; - apiKey = model.apiKey || this.getLangChainParams().cohereApiKey; - break; - case ChatModelProviders.OPENROUTERAI: - constructor = ProxyChatOpenAI; - apiKey = model.apiKey || this.getLangChainParams().openRouterAiApiKey; - break; - case ChatModelProviders.OLLAMA: - constructor = ChatOllama; - apiKey = model.apiKey || "default-key"; - break; - case ChatModelProviders.LM_STUDIO: - constructor = ProxyChatOpenAI; - apiKey = model.apiKey || "default-key"; - break; - case ChatModelProviders.GROQ: - constructor = ChatGroq; - apiKey = model.apiKey || this.getLangChainParams().groqApiKey; - break; - case ChatModelProviders.OPENAI_FORMAT: - constructor = ProxyChatOpenAI; - apiKey = model.apiKey || "default-key"; - break; - default: - console.warn(`Unknown provider: ${model.provider} for model: ${model.name}`); - return; + if (!Object.values(ChatModelProviders).contains(model.provider as ChatModelProviders)) { + console.warn(`Unknown provider: ${model.provider} for model: ${model.name}`); + return; } + const constructor = this.getProviderConstructor(model); + const getDefaultApiKey = this.providerApiKeyMap[model.provider as ChatModelProviders]; + + const apiKey = model.apiKey || getDefaultApiKey(); const modelKey = `${model.name}|${model.provider}`; modelMap[modelKey] = { hasApiKey: Boolean(model.apiKey || apiKey), - // eslint-disable-next-line @typescript-eslint/no-explicit-any - AIConstructor: constructor as any, + AIConstructor: constructor, vendor: model.provider, }; } }); } + getProviderConstructor(model: CustomModel): ChatConstructorType { + const constructor: ChatConstructorType = + CHAT_PROVIDER_CONSTRUCTORS[model.provider as ChatModelProviders]; + if (!constructor) { + console.warn(`Unknown provider: ${model.provider} for model: ${model.name}`); + throw new Error(`Unknown provider: ${model.provider} for model: ${model.name}`); + } + return constructor; + } + getChatModel(): BaseChatModel { return ChatModelManager.chatModel; } @@ -232,7 +262,7 @@ export default class ChatModelManager { const modelConfig = this.getModelConfig(model); // MUST update it since chatModelManager is a singleton. - this.getLangChainParams().modelKey = `${model.name}|${model.provider}`; + this.getLangChainParams().modelKey = modelKey; new Notice(`Setting model: ${modelConfig.modelName}`); try { const newModelInstance = new selectedModel.AIConstructor({ diff --git a/src/LLMProviders/embeddingManager.ts b/src/LLMProviders/embeddingManager.ts index 8a5532cd..e9f0e63f 100644 --- a/src/LLMProviders/embeddingManager.ts +++ b/src/LLMProviders/embeddingManager.ts @@ -3,12 +3,25 @@ import { CustomModel, LangChainParams } from "@/aiParams"; import { EmbeddingModelProviders } from "@/constants"; import EncryptionService from "@/encryptionService"; import { CustomError } from "@/error"; -import { ProxyOpenAIEmbeddings } from "@/langchainWrappers"; import { CohereEmbeddings } from "@langchain/cohere"; import { Embeddings } from "@langchain/core/embeddings"; import { GoogleGenerativeAIEmbeddings } from "@langchain/google-genai"; import { OllamaEmbeddings } from "@langchain/ollama"; import { OpenAIEmbeddings } from "@langchain/openai"; +import { safeFetch } from "@/utils"; + +type EmbeddingConstructorType = new (config: any) => Embeddings; + +const EMBEDDING_PROVIDER_CONSTRUCTORS = { + [EmbeddingModelProviders.OPENAI]: OpenAIEmbeddings, + [EmbeddingModelProviders.COHEREAI]: CohereEmbeddings, + [EmbeddingModelProviders.GOOGLE]: GoogleGenerativeAIEmbeddings, + [EmbeddingModelProviders.AZURE_OPENAI]: OpenAIEmbeddings, + [EmbeddingModelProviders.OLLAMA]: OllamaEmbeddings, + [EmbeddingModelProviders.OPENAI_FORMAT]: OpenAIEmbeddings, +} as const; + +type EmbeddingProviderConstructorMap = typeof EMBEDDING_PROVIDER_CONSTRUCTORS; export default class EmbeddingManager { private encryptionService: EncryptionService; @@ -19,11 +32,20 @@ export default class EmbeddingManager { string, { hasApiKey: boolean; - EmbeddingConstructor: new (config: any) => Embeddings; + EmbeddingConstructor: EmbeddingConstructorType; vendor: string; } >; + private readonly providerAipKeyMap: Record string> = { + [EmbeddingModelProviders.OPENAI]: () => this.getLangChainParams().openAIApiKey, + [EmbeddingModelProviders.COHEREAI]: () => this.getLangChainParams().cohereApiKey, + [EmbeddingModelProviders.GOOGLE]: () => this.getLangChainParams().googleApiKey, + [EmbeddingModelProviders.AZURE_OPENAI]: () => this.getLangChainParams().azureOpenAIApiKey, + [EmbeddingModelProviders.OLLAMA]: () => "default-key", + [EmbeddingModelProviders.OPENAI_FORMAT]: () => "", + }; + private constructor( private getLangChainParams: () => LangChainParams, encryptionService: EncryptionService, @@ -49,46 +71,34 @@ export default class EmbeddingManager { return EmbeddingManager.instance; } + getProviderConstructor(model: CustomModel): EmbeddingConstructorType { + const constructor = EMBEDDING_PROVIDER_CONSTRUCTORS[model.provider as EmbeddingModelProviders]; + if (!constructor) { + console.warn(`Unknown provider: ${model.provider} for model: ${model.name}`); + throw new Error(`Unknown provider: ${model.provider} for model: ${model.name}`); + } + return constructor; + } + // Build a map of modelKey to model config private buildModelMap(activeEmbeddingModels: CustomModel[]) { EmbeddingManager.modelMap = {}; const modelMap = EmbeddingManager.modelMap; - const params = this.getLangChainParams(); activeEmbeddingModels.forEach((model) => { if (model.enabled) { - let constructor; - let apiKey; - - switch (model.provider) { - case EmbeddingModelProviders.OPENAI: - constructor = OpenAIEmbeddings; - apiKey = params.openAIApiKey; - break; - case EmbeddingModelProviders.COHEREAI: - constructor = CohereEmbeddings; - apiKey = params.cohereApiKey; - break; - case EmbeddingModelProviders.GOOGLE: - constructor = GoogleGenerativeAIEmbeddings; - apiKey = params.googleApiKey; - break; - case EmbeddingModelProviders.AZURE_OPENAI: - constructor = OpenAIEmbeddings; - apiKey = params.azureOpenAIApiKey; - break; - case EmbeddingModelProviders.OLLAMA: - constructor = OllamaEmbeddings; - apiKey = "default-key"; - break; - case EmbeddingModelProviders.OPENAI_FORMAT: - constructor = ProxyOpenAIEmbeddings; - apiKey = model.apiKey; - break; - default: - console.warn(`Unknown provider: ${model.provider} for embedding model: ${model.name}`); - return; + if ( + !Object.values(EmbeddingModelProviders).contains( + model.provider as EmbeddingModelProviders + ) + ) { + console.warn(`Unknown provider: ${model.provider} for embedding model: ${model.name}`); + return; } + const constructor = this.getProviderConstructor(model); + const apiKey = + model.apiKey || this.providerAipKeyMap[model.provider as EmbeddingModelProviders](); + const modelKey = `${model.name}|${model.provider}`; modelMap[modelKey] = { hasApiKey: Boolean(apiKey), @@ -147,7 +157,7 @@ export default class EmbeddingManager { } } - private getEmbeddingConfig(customModel: CustomModel): any { + private getEmbeddingConfig(customModel: CustomModel) { const decrypt = (key: string) => this.encryptionService.getDecryptedKey(key); const params = this.getLangChainParams(); const modelName = customModel.name; @@ -157,25 +167,37 @@ export default class EmbeddingManager { maxConcurrency: 3, }; - const providerConfigs = { + const providerConfig: { + [K in keyof EmbeddingProviderConstructorMap]: ConstructorParameters< + EmbeddingProviderConstructorMap[K] + >[0] /*& Record;*/; + } = { [EmbeddingModelProviders.OPENAI]: { modelName, - openAIApiKey: decrypt(params.openAIApiKey), + openAIApiKey: decrypt(customModel.apiKey || params.openAIApiKey), timeout: 10000, + configuration: { + baseURL: customModel.baseUrl, + fetch: customModel.enableCors ? safeFetch : undefined, + }, }, [EmbeddingModelProviders.COHEREAI]: { model: modelName, - apiKey: decrypt(params.cohereApiKey), + apiKey: decrypt(customModel.apiKey || params.cohereApiKey), }, [EmbeddingModelProviders.GOOGLE]: { modelName: modelName, apiKey: decrypt(params.googleApiKey), }, [EmbeddingModelProviders.AZURE_OPENAI]: { - azureOpenAIApiKey: decrypt(params.azureOpenAIApiKey), + azureOpenAIApiKey: decrypt(customModel.apiKey || params.azureOpenAIApiKey), azureOpenAIApiInstanceName: params.azureOpenAIApiInstanceName, azureOpenAIApiDeploymentName: params.azureOpenAIApiEmbeddingDeploymentName, azureOpenAIApiVersion: params.azureOpenAIApiVersion, + configuration: { + baseURL: customModel.baseUrl, + fetch: customModel.enableCors ? safeFetch : undefined, + }, }, [EmbeddingModelProviders.OLLAMA]: { baseUrl: customModel.baseUrl || "http://localhost:11434", @@ -185,18 +207,17 @@ export default class EmbeddingManager { [EmbeddingModelProviders.OPENAI_FORMAT]: { modelName, openAIApiKey: decrypt(customModel.apiKey || ""), - openAIEmbeddingProxyBaseUrl: customModel.baseUrl, + configuration: { + baseURL: customModel.baseUrl, + fetch: customModel.enableCors ? safeFetch : undefined, + dangerouslyAllowBrowser: true, + }, }, }; - const modelKey = `${modelName}|${customModel.provider}`; - const selectedModel = EmbeddingManager.modelMap[modelKey]; - if (!selectedModel) { - console.error(`No embedding model found for key: ${modelKey}`); - } - const providerConfig = - providerConfigs[selectedModel.vendor as keyof typeof providerConfigs] || {}; + const selectedProviderConfig = + providerConfig[customModel.provider as EmbeddingModelProviders] || {}; - return { ...baseConfig, ...providerConfig }; + return { ...baseConfig, ...selectedProviderConfig }; } } diff --git a/src/langchainWrappers.ts b/src/langchainWrappers.ts index 4674d3e7..1ba790f1 100644 --- a/src/langchainWrappers.ts +++ b/src/langchainWrappers.ts @@ -1,9 +1,8 @@ /* eslint-disable @typescript-eslint/no-explicit-any */ import { AnthropicInput, ChatAnthropic } from "@langchain/anthropic"; -import { ChatOpenAI } from "@langchain/openai"; -import { OpenAIEmbeddings } from "@langchain/openai"; -import { requestUrl } from "obsidian"; +import { ChatOpenAI, OpenAIEmbeddings } from "@langchain/openai"; import OpenAI from "openai"; +import { safeFetch } from "@/utils"; // Migrated to OpenAI v4 client from v3: https://github.com/openai/openai-node/discussions/217 export class ProxyChatOpenAI extends ChatOpenAI { @@ -43,69 +42,3 @@ export class ChatAnthropicWrapped extends ChatAnthropic { }); } } - -/** Proxy function to use in place of fetch() to bypass CORS restrictions. - * It currently doesn't support streaming until this is implemented - * https://forum.obsidian.md/t/support-streaming-the-request-and-requesturl-response-body/87381 */ -async function safeFetch(url: string, options: RequestInit): Promise { - // Necessary to remove 'content-length' in order to make headers compatible with requestUrl() - delete (options.headers as Record)["content-length"]; - - if (typeof options.body === "string") { - const newBody = JSON.parse(options.body ?? {}); - // frequency_penalty: default 0, but perplexity.ai requires 1 by default. - // so, delete this argument for now - delete newBody["frequency_penalty"]; - options.body = JSON.stringify(newBody); - } - - const response = await requestUrl({ - url, - contentType: "application/json", - headers: options.headers as Record, - method: "POST", - body: options.body?.toString(), - }); - - return { - ok: response.status >= 200 && response.status < 300, - status: response.status, - statusText: response.status.toString(), - headers: new Headers(response.headers), - url: url, - type: "basic", - redirected: false, - body: createReadableStreamFromString(response.text), - bodyUsed: true, - json: () => response.json, - text: async () => response.text, - clone: () => { - throw new Error("not implemented"); - }, - arrayBuffer: () => { - throw new Error("not implemented"); - }, - blob: () => { - throw new Error("not implemented"); - }, - formData: () => { - throw new Error("not implemented"); - }, - }; -} - -function createReadableStreamFromString(input: string) { - return new ReadableStream({ - start(controller) { - // Convert the input string to a Uint8Array - const encoder = new TextEncoder(); - const uint8Array = encoder.encode(input); - - // Push the data to the stream - controller.enqueue(uint8Array); - - // Close the stream - controller.close(); - }, - }); -} diff --git a/src/utils.ts b/src/utils.ts index e29d2fe7..0b139de7 100644 --- a/src/utils.ts +++ b/src/utils.ts @@ -6,7 +6,7 @@ import { MemoryVariables } from "@langchain/core/memory"; import { RunnableSequence } from "@langchain/core/runnables"; import { BaseChain, RetrievalQAChain } from "langchain/chains"; import moment from "moment"; -import { TFile, Vault, parseYaml } from "obsidian"; +import { TFile, Vault, parseYaml, requestUrl } from "obsidian"; export const getModelNameFromKey = (modelKey: string): string => { return modelKey.split("|")[0]; @@ -549,3 +549,91 @@ export function extractYoutubeUrl(text: string): string | null { const match = text.match(YOUTUBE_URL_REGEX); return match ? match[0] : null; } + +/** Proxy function to use in place of fetch() to bypass CORS restrictions. + * It currently doesn't support streaming until this is implemented + * https://forum.obsidian.md/t/support-streaming-the-request-and-requesturl-response-body/87381 */ +export async function safeFetch(url: string, options: RequestInit): Promise { + // Necessary to remove 'content-length' in order to make headers compatible with requestUrl() + delete (options.headers as Record)["content-length"]; + + if (typeof options.body === "string") { + const newBody = JSON.parse(options.body ?? {}); + // frequency_penalty: default 0, but perplexity.ai requires 1 by default. + // so, delete this argument for now + delete newBody["frequency_penalty"]; + options.body = JSON.stringify(newBody); + } + + const response = await requestUrl({ + url, + contentType: "application/json", + headers: options.headers as Record, + method: "POST", + body: options.body?.toString(), + }); + + return { + ok: response.status >= 200 && response.status < 300, + status: response.status, + statusText: response.status.toString(), + headers: new Headers(response.headers), + url: url, + type: "basic", + redirected: false, + body: createReadableStreamFromString(response.text), + bodyUsed: true, + json: () => response.json, + text: async () => response.text, + clone: () => { + throw new Error("not implemented"); + }, + arrayBuffer: () => { + throw new Error("not implemented"); + }, + blob: () => { + throw new Error("not implemented"); + }, + formData: () => { + throw new Error("not implemented"); + }, + }; +} + +function createReadableStreamFromString(input: string) { + return new ReadableStream({ + start(controller) { + // Convert the input string to a Uint8Array + const encoder = new TextEncoder(); + const uint8Array = encoder.encode(input); + + // Push the data to the stream + controller.enqueue(uint8Array); + + // Close the stream + controller.close(); + }, + }); +} + +export function err2String(err: any, stack = false) { + // maybe to be improved + return err instanceof Error + ? err.message + + "\n" + + `${err?.cause ? "more message: " + err.cause.message : ""}` + + "\n" + + `${stack ? err.stack : ""}` + : JSON.stringify(err); +} + +export function omit, K extends keyof T>( + obj: T, + keys: K[] +): Omit { + const result = { ...obj }; + keys.forEach((key) => { + delete result[key]; + }); + return result; +}