Skip to content

Commit

Permalink
fix: better model loading ui feedback and model list update (#954)
Browse files Browse the repository at this point in the history
* fix: better model loading feedback and model list update

* added load on providersettings  update
  • Loading branch information
thecodacus authored Dec 31, 2024
1 parent 55cfd5d commit 389eedc
Show file tree
Hide file tree
Showing 4 changed files with 97 additions and 49 deletions.
113 changes: 71 additions & 42 deletions app/components/chat/BaseChat.tsx
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
* Preventing TS checks with files presented in the video for a better presentation.
*/
import type { Message } from 'ai';
import React, { type RefCallback, useEffect, useState } from 'react';
import React, { type RefCallback, useCallback, useEffect, useState } from 'react';
import { ClientOnly } from 'remix-utils/client-only';
import { Menu } from '~/components/sidebar/Menu.client';
import { IconButton } from '~/components/ui/IconButton';
Expand Down Expand Up @@ -31,6 +31,7 @@ import { toast } from 'react-toastify';
import StarterTemplates from './StarterTemplates';
import type { ActionAlert } from '~/types/actions';
import ChatAlert from './ChatAlert';
import { LLMManager } from '~/lib/modules/llm/manager';

const TEXTAREA_MIN_HEIGHT = 76;

Expand Down Expand Up @@ -100,26 +101,36 @@ export const BaseChat = React.forwardRef<HTMLDivElement, BaseChatProps>(
ref,
) => {
const TEXTAREA_MAX_HEIGHT = chatStarted ? 400 : 200;
const [apiKeys, setApiKeys] = useState<Record<string, string>>(() => {
const savedKeys = Cookies.get('apiKeys');

if (savedKeys) {
try {
return JSON.parse(savedKeys);
} catch (error) {
console.error('Failed to parse API keys from cookies:', error);
return {};
}
}

return {};
});
const [apiKeys, setApiKeys] = useState<Record<string, string>>(getApiKeysFromCookies());
const [modelList, setModelList] = useState(MODEL_LIST);
const [isModelSettingsCollapsed, setIsModelSettingsCollapsed] = useState(false);
const [isListening, setIsListening] = useState(false);
const [recognition, setRecognition] = useState<SpeechRecognition | null>(null);
const [transcript, setTranscript] = useState('');
const [isModelLoading, setIsModelLoading] = useState<string | undefined>('all');

const getProviderSettings = useCallback(() => {
let providerSettings: Record<string, IProviderSetting> | undefined = undefined;

try {
const savedProviderSettings = Cookies.get('providers');

if (savedProviderSettings) {
const parsedProviderSettings = JSON.parse(savedProviderSettings);

if (typeof parsedProviderSettings === 'object' && parsedProviderSettings !== null) {
providerSettings = parsedProviderSettings;
}
}
} catch (error) {
console.error('Error loading Provider Settings from cookies:', error);

// Clear invalid cookie data
Cookies.remove('providers');
}

return providerSettings;
}, []);
useEffect(() => {
console.log(transcript);
}, [transcript]);
Expand Down Expand Up @@ -157,25 +168,7 @@ export const BaseChat = React.forwardRef<HTMLDivElement, BaseChatProps>(
}, []);

useEffect(() => {
let providerSettings: Record<string, IProviderSetting> | undefined = undefined;

try {
const savedProviderSettings = Cookies.get('providers');

if (savedProviderSettings) {
const parsedProviderSettings = JSON.parse(savedProviderSettings);

if (typeof parsedProviderSettings === 'object' && parsedProviderSettings !== null) {
providerSettings = parsedProviderSettings;
}
}
} catch (error) {
console.error('Error loading Provider Settings from cookies:', error);

// Clear invalid cookie data
Cookies.remove('providers');
}

const providerSettings = getProviderSettings();
let parsedApiKeys: Record<string, string> | undefined = {};

try {
Expand All @@ -187,12 +180,49 @@ export const BaseChat = React.forwardRef<HTMLDivElement, BaseChatProps>(
// Clear invalid cookie data
Cookies.remove('apiKeys');
}
setIsModelLoading('all');
initializeModelList({ apiKeys: parsedApiKeys, providerSettings })
.then((modelList) => {
console.log('Model List: ', modelList);
setModelList(modelList);
})
.catch((error) => {
console.error('Error initializing model list:', error);
})
.finally(() => {
setIsModelLoading(undefined);
});
}, [providerList]);

const onApiKeysChange = async (providerName: string, apiKey: string) => {
const newApiKeys = { ...apiKeys, [providerName]: apiKey };
setApiKeys(newApiKeys);
Cookies.set('apiKeys', JSON.stringify(newApiKeys));

const provider = LLMManager.getInstance(import.meta.env || process.env || {}).getProvider(providerName);

if (provider && provider.getDynamicModels) {
setIsModelLoading(providerName);

initializeModelList({ apiKeys: parsedApiKeys, providerSettings }).then((modelList) => {
console.log('Model List: ', modelList);
setModelList(modelList);
});
}, [apiKeys]);
try {
const providerSettings = getProviderSettings();
const staticModels = provider.staticModels;
const dynamicModels = await provider.getDynamicModels(
newApiKeys,
providerSettings,
import.meta.env || process.env || {},
);

setModelList((preModels) => {
const filteredOutPreModels = preModels.filter((x) => x.provider !== providerName);
return [...filteredOutPreModels, ...staticModels, ...dynamicModels];
});
} catch (error) {
console.error('Error loading dynamic models:', error);
}
setIsModelLoading(undefined);
}
};

const startListening = () => {
if (recognition) {
Expand Down Expand Up @@ -381,15 +411,14 @@ export const BaseChat = React.forwardRef<HTMLDivElement, BaseChatProps>(
setProvider={setProvider}
providerList={providerList || (PROVIDER_LIST as ProviderInfo[])}
apiKeys={apiKeys}
modelLoading={isModelLoading}
/>
{(providerList || []).length > 0 && provider && (
<APIKeyManager
provider={provider}
apiKey={apiKeys[provider.name] || ''}
setApiKey={(key) => {
const newApiKeys = { ...apiKeys, [provider.name]: key };
setApiKeys(newApiKeys);
Cookies.set('apiKeys', JSON.stringify(newApiKeys));
onApiKeysChange(provider.name, key);
}}
/>
)}
Expand Down
23 changes: 16 additions & 7 deletions app/components/chat/ModelSelector.tsx
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@ interface ModelSelectorProps {
modelList: ModelInfo[];
providerList: ProviderInfo[];
apiKeys: Record<string, string>;
modelLoading?: string;
}

export const ModelSelector = ({
Expand All @@ -19,6 +20,7 @@ export const ModelSelector = ({
setProvider,
modelList,
providerList,
modelLoading,
}: ModelSelectorProps) => {
// Load enabled providers from cookies

Expand Down Expand Up @@ -83,14 +85,21 @@ export const ModelSelector = ({
value={model}
onChange={(e) => setModel?.(e.target.value)}
className="flex-1 p-2 rounded-lg border border-bolt-elements-borderColor bg-bolt-elements-prompt-background text-bolt-elements-textPrimary focus:outline-none focus:ring-2 focus:ring-bolt-elements-focus transition-all lg:max-w-[70%]"
disabled={modelLoading === 'all' || modelLoading === provider?.name}
>
{[...modelList]
.filter((e) => e.provider == provider?.name && e.name)
.map((modelOption, index) => (
<option key={index} value={modelOption.name}>
{modelOption.label}
</option>
))}
{modelLoading == 'all' || modelLoading == provider?.name ? (
<option key={0} value="">
Loading...
</option>
) : (
[...modelList]
.filter((e) => e.provider == provider?.name && e.name)
.map((modelOption, index) => (
<option key={index} value={modelOption.name}>
{modelOption.label}
</option>
))
)}
</select>
</div>
);
Expand Down
7 changes: 7 additions & 0 deletions app/lib/modules/llm/manager.ts
Original file line number Diff line number Diff line change
Expand Up @@ -79,9 +79,16 @@ export class LLMManager {
}): Promise<ModelInfo[]> {
const { apiKeys, providerSettings, serverEnv } = options;

let enabledProviders = Array.from(this._providers.values()).map((p) => p.name);

if (providerSettings) {
enabledProviders = enabledProviders.filter((p) => providerSettings[p].enabled);
}

// Get dynamic models from all providers that support them
const dynamicModels = await Promise.all(
Array.from(this._providers.values())
.filter((provider) => enabledProviders.includes(provider.name))
.filter(
(provider): provider is BaseProvider & Required<Pick<ProviderInfo, 'getDynamicModels'>> =>
!!provider.getDynamicModels,
Expand Down
3 changes: 3 additions & 0 deletions public/icons/Hyperbolic.svg
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.

0 comments on commit 389eedc

Please sign in to comment.