Skip to content

Commit

Permalink
fix default model
Browse files Browse the repository at this point in the history
  • Loading branch information
wladpaiva committed Oct 22, 2023
1 parent b32eafa commit aed887d
Show file tree
Hide file tree
Showing 3 changed files with 27 additions and 23 deletions.
6 changes: 6 additions & 0 deletions .changeset/dirty-bulldogs-check.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,6 @@
---
'aibitat': patch
---

Fix an issue where the default model wasnt getting replaced when specifying on
specific agent
42 changes: 20 additions & 22 deletions src/aibitat.ts
Original file line number Diff line number Diff line change
Expand Up @@ -173,7 +173,7 @@ export type FunctionDefinition = {
export class AIbitat {
private emitter = new EventEmitter()

private defaultProvider
private defaultProvider: ProviderConfig
private defaultInterrupt
private maxRounds
private _chats
Expand All @@ -194,10 +194,10 @@ export class AIbitat {
this.defaultInterrupt = interrupt
this.maxRounds = maxRounds

this.defaultProvider = this.getProviderForConfig({
this.defaultProvider = {
provider,
...rest,
})!
}
}

/**
Expand Down Expand Up @@ -285,6 +285,8 @@ export class AIbitat {
throw new Error(`Channel configuration "${channel}" not found`)
}
return {
provider: 'openai' as const,
model: 'gpt-4' as const,
maxRounds: 10,
role: 'Group chat manager.',
...config,
Expand Down Expand Up @@ -606,11 +608,7 @@ export class AIbitat {
// get the provider that will be used for the manager
// if the manager has a provider, use that otherwise
// use the GPT-4 because it has a better reasoning
const nodeProvider = this.getProviderForConfig(channelConfig)
const provider =
nodeProvider ||
this.getProviderForConfig({provider: 'openai', model: 'gpt-4'})!

const provider = this.getProviderForConfig(channelConfig)
const history = this.getHistory({to: channel})

// build the messages to send to the provider
Expand Down Expand Up @@ -704,8 +702,7 @@ ${this.getHistory({to: route.to})
?.map(name => this.functions.get(name))
.filter(a => !!a) as FunctionDefinition[] | undefined

const nodeProvider = this.getProviderForConfig(fromConfig)
const provider = nodeProvider || this.defaultProvider
const provider = this.getProviderForConfig(fromConfig)

// get the chat completion
const content = await provider.create(messages, functions)
Expand Down Expand Up @@ -814,20 +811,21 @@ ${this.getHistory({to: route.to})
* @param config The provider configuration.
*/
private getProviderForConfig(config: ProviderConfig) {
if (typeof config.provider === 'string') {
switch (config.provider) {
case 'openai':
return new OpenAIProvider({model: config.model})

default:
throw new Error(
`Unknown provider: ${config.provider}. Please use "openai"`,
)
}
const x = {
...this.defaultProvider,
...config,
}

if (config.provider) {
return config.provider
if (typeof x.provider === 'object') {
return x.provider
}

switch (x.provider) {
case 'openai':
return new OpenAIProvider({model: x.model})

default:
throw new Error(`Unknown provider: ${x.provider}. Please use "openai"`)
}
}

Expand Down
2 changes: 1 addition & 1 deletion src/providers/openai.ts
Original file line number Diff line number Diff line change
Expand Up @@ -231,6 +231,7 @@ export class OpenAIProvider extends AIProvider<OpenAI> {
call: OpenAI.Chat.ChatCompletionMessage.FunctionCall,
) {
const funcToCall = functions.find(f => f.name === call.name)
log(`calling function "${call.name}" with arguments: `, call.arguments)
if (!funcToCall) {
throw new Error(`Function '${call.name}' not found`)
}
Expand All @@ -245,7 +246,6 @@ export class OpenAIProvider extends AIProvider<OpenAI> {
)
}

log('calling function: ', funcToCall.name, 'with arguments: ', json)
return await funcToCall.handler(json)
}
}

0 comments on commit aed887d

Please sign in to comment.