Skip to content

Commit

Permalink
feat: Support certain providers to customize the base URL. (#851)
Browse files Browse the repository at this point in the history
  • Loading branch information
Emt-lin authored Nov 27, 2024
1 parent 50b7b60 commit b63e8dd
Show file tree
Hide file tree
Showing 4 changed files with 248 additions and 176 deletions.
146 changes: 88 additions & 58 deletions src/LLMProviders/chatModelManager.ts
Original file line number Diff line number Diff line change
@@ -1,7 +1,6 @@
import { CustomModel, LangChainParams, ModelConfig } from "@/aiParams";
import { BUILTIN_CHAT_MODELS, ChatModelProviders } from "@/constants";
import EncryptionService from "@/encryptionService";
import { ChatAnthropicWrapped, ProxyChatOpenAI } from "@/langchainWrappers";
import { HarmBlockThreshold, HarmCategory } from "@google/generative-ai";
import { ChatCohere } from "@langchain/cohere";
import { BaseChatModel } from "@langchain/core/language_models/chat_models";
Expand All @@ -10,6 +9,25 @@ import { ChatGroq } from "@langchain/groq";
import { ChatOllama } from "@langchain/ollama";
import { ChatOpenAI } from "@langchain/openai";
import { Notice } from "obsidian";
import { safeFetch } from "@/utils";
import { ChatAnthropic } from "@langchain/anthropic";

// Common constructor shape for every LangChain chat-model wrapper we may
// instantiate. `config: any` is deliberate: each provider's constructor takes
// a different config object; the precise per-provider shapes are recovered
// below via ConstructorParameters over CHAT_PROVIDER_CONSTRUCTORS.
type ChatConstructorType = new (config: any) => BaseChatModel;

// Maps each supported provider to the LangChain chat-model class used for it.
// Note several providers (Azure OpenAI, OpenRouter, LM Studio, generic
// OpenAI-format endpoints) reuse ChatOpenAI, differing only in the base URL
// and credentials supplied at construction time.
const CHAT_PROVIDER_CONSTRUCTORS = {
[ChatModelProviders.OPENAI]: ChatOpenAI,
[ChatModelProviders.AZURE_OPENAI]: ChatOpenAI,
[ChatModelProviders.ANTHROPIC]: ChatAnthropic,
[ChatModelProviders.COHEREAI]: ChatCohere,
[ChatModelProviders.GOOGLE]: ChatGoogleGenerativeAI,
[ChatModelProviders.OPENROUTERAI]: ChatOpenAI,
[ChatModelProviders.OLLAMA]: ChatOllama,
[ChatModelProviders.LM_STUDIO]: ChatOpenAI,
[ChatModelProviders.GROQ]: ChatGroq,
[ChatModelProviders.OPENAI_FORMAT]: ChatOpenAI,
} as const;

// Derived map type used to compute each provider's exact constructor-config
// parameter type (see the mapped `providerConfig` type built from
// ConstructorParameters<ChatProviderConstructMap[K]>).
type ChatProviderConstructMap = typeof CHAT_PROVIDER_CONSTRUCTORS;

export default class ChatModelManager {
private encryptionService: EncryptionService;
Expand All @@ -20,12 +38,24 @@ export default class ChatModelManager {
string,
{
hasApiKey: boolean;
// eslint-disable-next-line @typescript-eslint/no-explicit-any
AIConstructor: new (config: any) => BaseChatModel;
AIConstructor: ChatConstructorType;
vendor: string;
}
>;

// Per-provider fallback API-key lookup, consulted when a custom model does
// not carry its own key. Keys are read lazily from LangChainParams so that
// settings changes are picked up at call time. Ollama, LM Studio, and
// generic OpenAI-format endpoints get the placeholder "default-key"
// (presumably because those endpoints do not validate the key — confirm).
private readonly providerApiKeyMap: Record<ChatModelProviders, () => string> = {
[ChatModelProviders.OPENAI]: () => this.getLangChainParams().openAIApiKey,
[ChatModelProviders.GOOGLE]: () => this.getLangChainParams().googleApiKey,
[ChatModelProviders.AZURE_OPENAI]: () => this.getLangChainParams().azureOpenAIApiKey,
[ChatModelProviders.ANTHROPIC]: () => this.getLangChainParams().anthropicApiKey,
[ChatModelProviders.COHEREAI]: () => this.getLangChainParams().cohereApiKey,
[ChatModelProviders.OPENROUTERAI]: () => this.getLangChainParams().openRouterAiApiKey,
[ChatModelProviders.GROQ]: () => this.getLangChainParams().groqApiKey,
[ChatModelProviders.OLLAMA]: () => "default-key",
[ChatModelProviders.LM_STUDIO]: () => "default-key",
[ChatModelProviders.OPENAI_FORMAT]: () => "default-key",
} as const;

private constructor(
private getLangChainParams: () => LangChainParams,
encryptionService: EncryptionService,
Expand Down Expand Up @@ -62,31 +92,50 @@ export default class ChatModelManager {
enableCors: customModel.enableCors,
};

const providerConfig = {
const providerConfig: {
[K in keyof ChatProviderConstructMap]: ConstructorParameters<
ChatProviderConstructMap[K]
>[0] /*& Record<string, unknown>;*/;
} = {
[ChatModelProviders.OPENAI]: {
modelName: customModel.name,
openAIApiKey: decrypt(customModel.apiKey || params.openAIApiKey),
// @ts-ignore
openAIOrgId: decrypt(params.openAIOrgId),
maxTokens: params.maxTokens,
configuration: {
baseURL: customModel.baseUrl,
fetch: customModel.enableCors ? safeFetch : undefined,
},
},
[ChatModelProviders.ANTHROPIC]: {
anthropicApiKey: decrypt(customModel.apiKey || params.anthropicApiKey),
modelName: customModel.name,
anthropicApiUrl: customModel.baseUrl,
clientOptions: {
// Required to bypass CORS restrictions
defaultHeaders: { "anthropic-dangerous-direct-browser-access": "true" },
fetch: customModel.enableCors ? safeFetch : undefined,
},
},
[ChatModelProviders.AZURE_OPENAI]: {
maxTokens: params.maxTokens,
azureOpenAIApiKey: decrypt(customModel.apiKey || params.azureOpenAIApiKey),
azureOpenAIApiInstanceName: params.azureOpenAIApiInstanceName,
azureOpenAIApiDeploymentName: params.azureOpenAIApiDeploymentName,
azureOpenAIApiVersion: params.azureOpenAIApiVersion,
configuration: {
baseURL: customModel.baseUrl,
fetch: customModel.enableCors ? safeFetch : undefined,
},
},
[ChatModelProviders.COHEREAI]: {
apiKey: decrypt(customModel.apiKey || params.cohereApiKey),
model: customModel.name,
},
[ChatModelProviders.GOOGLE]: {
apiKey: decrypt(customModel.apiKey || params.googleApiKey),
modelName: customModel.name,
model: customModel.name,
safetySettings: [
{
category: HarmCategory.HARM_CATEGORY_SEXUALLY_EXPLICIT,
Expand All @@ -105,11 +154,15 @@ export default class ChatModelManager {
threshold: HarmBlockThreshold.BLOCK_NONE,
},
],
baseUrl: customModel.baseUrl,
},
[ChatModelProviders.OPENROUTERAI]: {
modelName: customModel.name,
openAIApiKey: decrypt(customModel.apiKey || params.openRouterAiApiKey),
openAIProxyBaseUrl: "https://openrouter.ai/api/v1",
configuration: {
baseURL: customModel.baseUrl || "https://openrouter.ai/api/v1",
fetch: customModel.enableCors ? safeFetch : undefined,
},
},
[ChatModelProviders.GROQ]: {
apiKey: decrypt(customModel.apiKey || params.groqApiKey),
Expand All @@ -118,20 +171,28 @@ export default class ChatModelManager {
[ChatModelProviders.OLLAMA]: {
// ChatOllama has `model` instead of `modelName`!!
model: customModel.name,
// @ts-ignore
apiKey: customModel.apiKey || "default-key",
// MUST NOT use /v1 in the baseUrl for ollama
baseUrl: customModel.baseUrl || "http://localhost:11434",
},
[ChatModelProviders.LM_STUDIO]: {
modelName: customModel.name,
openAIApiKey: customModel.apiKey || "default-key",
openAIProxyBaseUrl: customModel.baseUrl || "http://localhost:1234/v1",
configuration: {
baseURL: customModel.baseUrl || "http://localhost:1234/v1",
fetch: customModel.enableCors ? safeFetch : undefined,
},
},
[ChatModelProviders.OPENAI_FORMAT]: {
modelName: customModel.name,
openAIApiKey: decrypt(customModel.apiKey || "default-key"),
maxTokens: params.maxTokens,
openAIProxyBaseUrl: customModel.baseUrl || "",
configuration: {
baseURL: customModel.baseUrl,
fetch: customModel.enableCors ? safeFetch : undefined,
dangerouslyAllowBrowser: true,
},
},
};

Expand All @@ -150,66 +211,35 @@ export default class ChatModelManager {

allModels.forEach((model) => {
if (model.enabled) {
let constructor;
let apiKey;

switch (model.provider) {
case ChatModelProviders.OPENAI:
constructor = ChatOpenAI;
apiKey = model.apiKey || this.getLangChainParams().openAIApiKey;
break;
case ChatModelProviders.GOOGLE:
constructor = ChatGoogleGenerativeAI;
apiKey = model.apiKey || this.getLangChainParams().googleApiKey;
break;
case ChatModelProviders.AZURE_OPENAI:
constructor = ChatOpenAI;
apiKey = model.apiKey || this.getLangChainParams().azureOpenAIApiKey;
break;
case ChatModelProviders.ANTHROPIC:
constructor = ChatAnthropicWrapped;
apiKey = model.apiKey || this.getLangChainParams().anthropicApiKey;
break;
case ChatModelProviders.COHEREAI:
constructor = ChatCohere;
apiKey = model.apiKey || this.getLangChainParams().cohereApiKey;
break;
case ChatModelProviders.OPENROUTERAI:
constructor = ProxyChatOpenAI;
apiKey = model.apiKey || this.getLangChainParams().openRouterAiApiKey;
break;
case ChatModelProviders.OLLAMA:
constructor = ChatOllama;
apiKey = model.apiKey || "default-key";
break;
case ChatModelProviders.LM_STUDIO:
constructor = ProxyChatOpenAI;
apiKey = model.apiKey || "default-key";
break;
case ChatModelProviders.GROQ:
constructor = ChatGroq;
apiKey = model.apiKey || this.getLangChainParams().groqApiKey;
break;
case ChatModelProviders.OPENAI_FORMAT:
constructor = ProxyChatOpenAI;
apiKey = model.apiKey || "default-key";
break;
default:
console.warn(`Unknown provider: ${model.provider} for model: ${model.name}`);
return;
if (!Object.values(ChatModelProviders).contains(model.provider as ChatModelProviders)) {
console.warn(`Unknown provider: ${model.provider} for model: ${model.name}`);
return;
}

const constructor = this.getProviderConstructor(model);
const getDefaultApiKey = this.providerApiKeyMap[model.provider as ChatModelProviders];

const apiKey = model.apiKey || getDefaultApiKey();
const modelKey = `${model.name}|${model.provider}`;
modelMap[modelKey] = {
hasApiKey: Boolean(model.apiKey || apiKey),
// eslint-disable-next-line @typescript-eslint/no-explicit-any
AIConstructor: constructor as any,
AIConstructor: constructor,
vendor: model.provider,
};
}
});
}

/**
 * Resolves the LangChain chat-model class for the given custom model's
 * provider.
 *
 * @param model - The custom model whose `provider` selects the constructor.
 * @returns The matching chat-model constructor from CHAT_PROVIDER_CONSTRUCTORS.
 * @throws Error when the provider has no registered constructor; the same
 *   message is also logged via console.warn before throwing.
 */
getProviderConstructor(model: CustomModel): ChatConstructorType {
  const ctor: ChatConstructorType | undefined =
    CHAT_PROVIDER_CONSTRUCTORS[model.provider as ChatModelProviders];
  if (ctor) {
    return ctor;
  }
  const message = `Unknown provider: ${model.provider} for model: ${model.name}`;
  console.warn(message);
  throw new Error(message);
}

getChatModel(): BaseChatModel {
return ChatModelManager.chatModel;
}
Expand All @@ -232,7 +262,7 @@ export default class ChatModelManager {
const modelConfig = this.getModelConfig(model);

// MUST update it since chatModelManager is a singleton.
this.getLangChainParams().modelKey = `${model.name}|${model.provider}`;
this.getLangChainParams().modelKey = modelKey;
new Notice(`Setting model: ${modelConfig.modelName}`);
try {
const newModelInstance = new selectedModel.AIConstructor({
Expand Down
Loading

0 comments on commit b63e8dd

Please sign in to comment.