
fix: Temporary fix for o1-xx models: convert systemMessage to aiMessage.
Emt-lin committed Nov 27, 2024
1 parent b2441c0 commit c354202
Showing 2 changed files with 56 additions and 18 deletions.
20 changes: 17 additions & 3 deletions src/LLMProviders/chainManager.ts
@@ -1,6 +1,6 @@
import { CustomModel, LangChainParams, SetChainOptions } from "@/aiParams";
import ChainFactory, { ChainType, Document } from "@/chainFactory";
-import { BUILTIN_CHAT_MODELS, USER_SENDER } from "@/constants";
+import { AI_SENDER, BUILTIN_CHAT_MODELS, USER_SENDER } from "@/constants";
import EncryptionService from "@/encryptionService";
import {
ChainRunner,
@@ -317,12 +317,26 @@ export default class ChainManager {
this.validateChatModel();
this.validateChainInitialization();

+const chatModel = this.chatModelManager.getChatModel();
+const modelName = (chatModel as any).modelName || (chatModel as any).model || "";
+const isO1Model = modelName.startsWith("o1");

// Handle ignoreSystemMessage
-if (ignoreSystemMessage) {
-const effectivePrompt = ChatPromptTemplate.fromMessages([
+if (ignoreSystemMessage || isO1Model) {
+let effectivePrompt = ChatPromptTemplate.fromMessages([
new MessagesPlaceholder("history"),
HumanMessagePromptTemplate.fromTemplate("{input}"),
]);

+// TODO: hack for o1 models, to be removed when they support system prompts
+if (isO1Model) {
+// Temporary fix: for o1-xx models, convert systemMessage to aiMessage
+effectivePrompt = ChatPromptTemplate.fromMessages([
+[AI_SENDER, this.getLangChainParams().systemMessage || ""],
+effectivePrompt,
+]);
+}

this.setChain(this.getLangChainParams().chainType, {
...this.getLangChainParams().options,
prompt: effectivePrompt,
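For readers skimming the diff: below is a minimal standalone sketch of the workaround above, not code from the repository. It assumes LangChain's prompt classes from "@langchain/core/prompts" and that AI_SENDER resolves to the "ai" message role; both are assumptions based on the imports visible in the diff. Because o1-xx models rejected the "system" role at the time of this commit, the system message is prepended as an assistant ("ai") message instead.

import {
  ChatPromptTemplate,
  HumanMessagePromptTemplate,
  MessagesPlaceholder,
} from "@langchain/core/prompts";

// Hypothetical helper mirroring the change above (buildEffectivePrompt is
// illustrative, not a function in the repo).
function buildEffectivePrompt(systemMessage: string, isO1Model: boolean): ChatPromptTemplate {
  // Base prompt: conversation history followed by the user's input.
  let effectivePrompt = ChatPromptTemplate.fromMessages([
    new MessagesPlaceholder("history"),
    HumanMessagePromptTemplate.fromTemplate("{input}"),
  ]);

  if (isO1Model) {
    // o1 models reject "system" messages, so re-send the system prompt
    // as an assistant ("ai") message ahead of the history.
    effectivePrompt = ChatPromptTemplate.fromMessages([
      ["ai", systemMessage],
      effectivePrompt,
    ]);
  }
  return effectivePrompt;
}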
54 changes: 39 additions & 15 deletions src/LLMProviders/chatModelManager.ts
@@ -83,8 +83,13 @@ export default class ChatModelManager {
private getModelConfig(customModel: CustomModel): ModelConfig {
const decrypt = (key: string) => this.encryptionService.getDecryptedKey(key);
const params = this.getLangChainParams();

+// Check if the model starts with "o1"
+const modelName = customModel.name;
+const isO1Model = modelName.startsWith("o1");

const baseConfig: ModelConfig = {
-modelName: customModel.name,
+modelName: modelName,
temperature: params.temperature,
streaming: true,
maxRetries: 3,
@@ -98,19 +103,19 @@
>[0] /*& Record<string, unknown>;*/;
} = {
[ChatModelProviders.OPENAI]: {
-modelName: customModel.name,
+modelName: modelName,
openAIApiKey: decrypt(customModel.apiKey || params.openAIApiKey),
-// @ts-ignore
-openAIOrgId: decrypt(params.openAIOrgId),
-maxTokens: params.maxTokens,
configuration: {
baseURL: customModel.baseUrl,
fetch: customModel.enableCors ? safeFetch : undefined,
},
+// @ts-ignore
+openAIOrgId: decrypt(params.openAIOrgId),
+...this.handleOpenAIExtraArgs(isO1Model, params.maxTokens, params.temperature, true),
},
[ChatModelProviders.ANTHROPIC]: {
anthropicApiKey: decrypt(customModel.apiKey || params.anthropicApiKey),
-modelName: customModel.name,
+modelName: modelName,
anthropicApiUrl: customModel.baseUrl,
clientOptions: {
// Required to bypass CORS restrictions
@@ -119,7 +124,6 @@
},
},
[ChatModelProviders.AZURE_OPENAI]: {
-maxTokens: params.maxTokens,
azureOpenAIApiKey: decrypt(customModel.apiKey || params.azureOpenAIApiKey),
azureOpenAIApiInstanceName: params.azureOpenAIApiInstanceName,
azureOpenAIApiDeploymentName: params.azureOpenAIApiDeploymentName,
@@ -128,14 +132,15 @@
baseURL: customModel.baseUrl,
fetch: customModel.enableCors ? safeFetch : undefined,
},
+...this.handleOpenAIExtraArgs(isO1Model, params.maxTokens, params.temperature, true),
},
[ChatModelProviders.COHEREAI]: {
apiKey: decrypt(customModel.apiKey || params.cohereApiKey),
-model: customModel.name,
+model: modelName,
},
[ChatModelProviders.GOOGLE]: {
apiKey: decrypt(customModel.apiKey || params.googleApiKey),
-model: customModel.name,
+modelName: modelName,
safetySettings: [
{
category: HarmCategory.HARM_CATEGORY_SEXUALLY_EXPLICIT,
@@ -157,7 +162,7 @@
baseUrl: customModel.baseUrl,
},
[ChatModelProviders.OPENROUTERAI]: {
-modelName: customModel.name,
+modelName: modelName,
openAIApiKey: decrypt(customModel.apiKey || params.openRouterAiApiKey),
configuration: {
baseURL: customModel.baseUrl || "https://openrouter.ai/api/v1",
@@ -166,33 +171,33 @@
},
[ChatModelProviders.GROQ]: {
apiKey: decrypt(customModel.apiKey || params.groqApiKey),
-modelName: customModel.name,
+modelName: modelName,
},
[ChatModelProviders.OLLAMA]: {
// ChatOllama has `model` instead of `modelName`!!
-model: customModel.name,
+model: modelName,
// @ts-ignore
apiKey: customModel.apiKey || "default-key",
// MUST NOT use /v1 in the baseUrl for ollama
baseUrl: customModel.baseUrl || "http://localhost:11434",
},
[ChatModelProviders.LM_STUDIO]: {
-modelName: customModel.name,
+modelName: modelName,
openAIApiKey: customModel.apiKey || "default-key",
configuration: {
baseURL: customModel.baseUrl || "http://localhost:1234/v1",
fetch: customModel.enableCors ? safeFetch : undefined,
},
},
[ChatModelProviders.OPENAI_FORMAT]: {
-modelName: customModel.name,
+modelName: modelName,
openAIApiKey: decrypt(customModel.apiKey || "default-key"),
-maxTokens: params.maxTokens,
configuration: {
baseURL: customModel.baseUrl,
fetch: customModel.enableCors ? safeFetch : undefined,
dangerouslyAllowBrowser: true,
},
+...this.handleOpenAIExtraArgs(isO1Model, params.maxTokens, params.temperature, true),
},
};

@@ -202,6 +207,25 @@
return { ...baseConfig, ...selectedProviderConfig };
}

+private handleOpenAIExtraArgs(
+isO1Model: boolean,
+maxTokens: number,
+temperature: number,
+streaming: boolean
+) {
+return isO1Model
+? {
+maxCompletionTokens: maxTokens,
+temperature: 1,
+streaming: false,
+}
+: {
+maxTokens: maxTokens,
+temperature: temperature,
+streaming: streaming,
+};
+}

// Build a map of modelKey to model config
public buildModelMap(activeModels: CustomModel[]) {
ChatModelManager.modelMap = {};
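A quick standalone illustration of the helper added above (again a sketch, not repo code): because the spread of handleOpenAIExtraArgs comes last in each provider config, its fields override the defaults set earlier. The o1-specific values reflect OpenAI's o1 API constraints at the time of the commit: max_completion_tokens replaces max_tokens, temperature is pinned to 1, and streaming is unsupported.

// Sketch of the helper and how its result merges via object spread;
// the interface is an assumption for illustration, not a repo type.
interface OpenAIExtraArgs {
  maxTokens?: number;
  maxCompletionTokens?: number;
  temperature: number;
  streaming: boolean;
}

function handleOpenAIExtraArgs(
  isO1Model: boolean,
  maxTokens: number,
  temperature: number,
  streaming: boolean
): OpenAIExtraArgs {
  // o1 models: use maxCompletionTokens, force temperature to 1, disable streaming.
  return isO1Model
    ? { maxCompletionTokens: maxTokens, temperature: 1, streaming: false }
    : { maxTokens, temperature, streaming };
}

// Spread last, so these values win over the base config:
const config = {
  modelName: "o1-preview", // hypothetical model name for illustration
  temperature: 0.7,
  streaming: true,
  ...handleOpenAIExtraArgs(true, 4096, 0.7, true),
};
// => { modelName: "o1-preview", temperature: 1, streaming: false, maxCompletionTokens: 4096 }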
