diff --git a/.changeset/dirty-bulldogs-check.md b/.changeset/dirty-bulldogs-check.md new file mode 100644 index 0000000..4cf563f --- /dev/null +++ b/.changeset/dirty-bulldogs-check.md @@ -0,0 +1,6 @@ +--- +'aibitat': patch +--- + +Fix an issue where the default model wasnt getting replaced when specifying on +specific agent diff --git a/src/aibitat.ts b/src/aibitat.ts index 02d6967..a31b51a 100644 --- a/src/aibitat.ts +++ b/src/aibitat.ts @@ -173,7 +173,7 @@ export type FunctionDefinition = { export class AIbitat { private emitter = new EventEmitter() - private defaultProvider + private defaultProvider: ProviderConfig private defaultInterrupt private maxRounds private _chats @@ -194,10 +194,10 @@ export class AIbitat { this.defaultInterrupt = interrupt this.maxRounds = maxRounds - this.defaultProvider = this.getProviderForConfig({ + this.defaultProvider = { provider, ...rest, - })! + } } /** @@ -285,6 +285,8 @@ export class AIbitat { throw new Error(`Channel configuration "${channel}" not found`) } return { + provider: 'openai' as const, + model: 'gpt-4' as const, maxRounds: 10, role: 'Group chat manager.', ...config, @@ -606,11 +608,7 @@ export class AIbitat { // get the provider that will be used for the manager // if the manager has a provider, use that otherwise // use the GPT-4 because it has a better reasoning - const nodeProvider = this.getProviderForConfig(channelConfig) - const provider = - nodeProvider || - this.getProviderForConfig({provider: 'openai', model: 'gpt-4'})! - + const provider = this.getProviderForConfig(channelConfig) const history = this.getHistory({to: channel}) // build the messages to send to the provider @@ -704,8 +702,7 @@ ${this.getHistory({to: route.to}) ?.map(name => this.functions.get(name)) .filter(a => !!a) as FunctionDefinition[] | undefined - const nodeProvider = this.getProviderForConfig(fromConfig) - const provider = nodeProvider || this.defaultProvider + const provider = this.getProviderForConfig(fromConfig) // get the chat completion const content = await provider.create(messages, functions) @@ -814,20 +811,21 @@ ${this.getHistory({to: route.to}) * @param config The provider configuration. */ private getProviderForConfig(config: ProviderConfig) { - if (typeof config.provider === 'string') { - switch (config.provider) { - case 'openai': - return new OpenAIProvider({model: config.model}) - - default: - throw new Error( - `Unknown provider: ${config.provider}. Please use "openai"`, - ) - } + const x = { + ...this.defaultProvider, + ...config, } - if (config.provider) { - return config.provider + if (typeof x.provider === 'object') { + return x.provider + } + + switch (x.provider) { + case 'openai': + return new OpenAIProvider({model: x.model}) + + default: + throw new Error(`Unknown provider: ${x.provider}. Please use "openai"`) } } diff --git a/src/providers/openai.ts b/src/providers/openai.ts index 3e659fb..48ebf1a 100644 --- a/src/providers/openai.ts +++ b/src/providers/openai.ts @@ -231,6 +231,7 @@ export class OpenAIProvider extends AIProvider { call: OpenAI.Chat.ChatCompletionMessage.FunctionCall, ) { const funcToCall = functions.find(f => f.name === call.name) + log(`calling function "${call.name}" with arguments: `, call.arguments) if (!funcToCall) { throw new Error(`Function '${call.name}' not found`) } @@ -245,7 +246,6 @@ export class OpenAIProvider extends AIProvider { ) } - log('calling function: ', funcToCall.name, 'with arguments: ', json) return await funcToCall.handler(json) } }