diff --git a/core/src/types/file/index.ts b/core/src/types/file/index.ts index 61ba614ed2..0b1ea9fe08 100644 --- a/core/src/types/file/index.ts +++ b/core/src/types/file/index.ts @@ -69,6 +69,11 @@ export interface DownloadState2 { */ type: DownloadType2 + /** + * Percentage of the download. + */ + progress: number + /** * The status of the download. */ diff --git a/electron/.eslintrc.js b/electron/.eslintrc.js index d252ec42bd..a8b7c00cb8 100644 --- a/electron/.eslintrc.js +++ b/electron/.eslintrc.js @@ -34,5 +34,5 @@ module.exports = { { name: 'Link', linkAttribute: 'to' }, ], }, - ignorePatterns: ['build', 'renderer', 'node_modules', '@global'], + ignorePatterns: ['build', 'renderer', 'node_modules', '@global', 'playwright-report'], } diff --git a/electron/download.bat b/electron/download.bat index 817ceb9794..04f763cd0e 100644 --- a/electron/download.bat +++ b/electron/download.bat @@ -1,6 +1,24 @@ @echo off +setlocal + +:: Read the version from the version.txt file set /p CORTEX_VERSION=<./resources/version.txt + +:: Set the download URL set DOWNLOAD_URL=https://github.com/janhq/cortex/releases/download/v%CORTEX_VERSION%/cortex-%CORTEX_VERSION%-amd64-windows.tar.gz -echo Downloading from %DOWNLOAD_URL% -.\node_modules\.bin\download %DOWNLOAD_URL% -e -o ./resources/win \ No newline at end of file +:: Set the output directory and file name +set OUTPUT_DIR=./resources/win +set OUTPUT_FILE=%OUTPUT_DIR%/cortex.exe + +echo %OUTPUT_FILE% + +:: Check if the file already exists +if exist %OUTPUT_FILE% ( + echo File %OUTPUT_FILE% already exists. Skipping download. +) else ( + echo Downloading from %DOWNLOAD_URL% + .\node_modules\.bin\download %DOWNLOAD_URL% -e -o %OUTPUT_DIR% +) + +endlocal \ No newline at end of file diff --git a/electron/handlers/native.ts b/electron/handlers/native.ts index fd882f6c1a..ef62eac430 100644 --- a/electron/handlers/native.ts +++ b/electron/handlers/native.ts @@ -14,7 +14,7 @@ import { writeFileSync, readFileSync, existsSync, - mkdirSync, + mkdirSync } from 'fs' import { dump } from 'js-yaml' import os from 'os' @@ -229,25 +229,20 @@ export function handleAppIPCs() { const cortexHomeDir = join(os.homedir(), 'cortex') const cortexModelFolderPath = join(cortexHomeDir, 'models') + + if(!existsSync(cortexModelFolderPath)) + mkdirSync(cortexModelFolderPath) console.log('cortexModelFolderPath', cortexModelFolderPath) const reflect = require('@alumna/reflect') for (const modelName of allModelFolders) { const modelFolderPath = join(janModelFolderPath, modelName) - const filesInModelFolder = readdirSync(modelFolderPath) - if (filesInModelFolder.length <= 1) { - // if only have model.json file or empty folder, we skip it - continue - } + try { - const destinationPath = join(cortexModelFolderPath, modelName) + const filesInModelFolder = readdirSync(modelFolderPath) - // create folder if not exist - if (!existsSync(destinationPath)) { - mkdirSync(destinationPath, { recursive: true }) - } + const destinationPath = join(cortexModelFolderPath, modelName) - try { const modelJsonFullPath = join( janModelFolderPath, modelName, @@ -256,12 +251,25 @@ export function handleAppIPCs() { const model = JSON.parse(readFileSync(modelJsonFullPath, 'utf-8')) const fileNames: string[] = model.sources.map((x: any) => x.filename) - // prepend fileNames with cortexModelFolderPath - const files = fileNames.map((x: string) => - join(cortexModelFolderPath, model.id, x) - ) + let files: string[] = [] - const engine = 'cortex.llamacpp' + if(filesInModelFolder.length > 1) { + // prepend fileNames with cortexModelFolderPath + files = fileNames.map((x: string) => + join(cortexModelFolderPath, model.id, x) + ) + } else if(model.sources.length && !/^(http|https):\/\/[^/]+\/.*/.test(model.sources[0].url)) { + // Symlink case + files = [ model.sources[0].url ] + } else continue; + + // create folder if not exist + // only for local model files + if (!existsSync(destinationPath) && filesInModelFolder.length > 1) { + mkdirSync(destinationPath, { recursive: true }) + } + + const engine = (model.engine === 'nitro' || model.engine === 'cortex') ? 'cortex.llamacpp' : (model.engine ?? 'cortex.llamacpp') const updatedModelFormat = { id: model.id, @@ -288,24 +296,27 @@ export function handleAppIPCs() { max_tokens: model.parameters?.max_tokens ?? 2048, stream: model.parameters?.stream ?? true, } - - const { err } = await reflect({ - src: modelFolderPath, - dest: destinationPath, - recursive: true, - exclude: ['model.json'], - delete: false, - overwrite: true, - errorOnExist: false, - }) - if (err) console.error(err) - else { - // create the model.yml file - const modelYamlData = dump(updatedModelFormat) - const modelYamlPath = join(cortexModelFolderPath, `${modelName}.yaml`) - - writeFileSync(modelYamlPath, modelYamlData) + if(filesInModelFolder.length > 1 ) { + const { err } = await reflect({ + src: modelFolderPath, + dest: destinationPath, + recursive: true, + exclude: ['model.json'], + delete: false, + overwrite: true, + errorOnExist: false, + }) + + if (err) { + console.error(err); + continue; + } } + // create the model.yml file + const modelYamlData = dump(updatedModelFormat) + const modelYamlPath = join(cortexModelFolderPath, `${modelName}.yaml`) + + writeFileSync(modelYamlPath, modelYamlData) } catch (err) { console.error(err) } @@ -316,6 +327,13 @@ export function handleAppIPCs() { NativeRoute.getAllMessagesAndThreads, async (_event): Promise => { const janThreadFolderPath = join(getJanDataFolderPath(), 'threads') + // check if exist + if (!existsSync(janThreadFolderPath)) { + return { + threads: [], + messages: [], + } + } // get children of thread folder const allThreadFolders = readdirSync(janThreadFolderPath) const threads: any[] = [] @@ -335,10 +353,12 @@ export function handleAppIPCs() { threadFolder, 'messages.jsonl' ) + + if(!existsSync(messageFullPath)) continue; const lines = readFileSync(messageFullPath, 'utf-8') - .toString() - .split('\n') - .filter((line: any) => line !== '') + .toString() + .split('\n') + .filter((line: any) => line !== '') for (const line of lines) { messages.push(JSON.parse(line)) } @@ -357,6 +377,10 @@ export function handleAppIPCs() { NativeRoute.getAllLocalModels, async (_event): Promise => { const janModelsFolderPath = join(getJanDataFolderPath(), 'models') + + if (!existsSync(janModelsFolderPath)) { + return false + } // get children of thread folder const allModelsFolders = readdirSync(janModelsFolderPath) let hasLocalModels = false diff --git a/electron/main.ts b/electron/main.ts index c6414c92ac..5ce177e49c 100644 --- a/electron/main.ts +++ b/electron/main.ts @@ -1,7 +1,7 @@ import { app, BrowserWindow } from 'electron' import { join, resolve } from 'path' -import { exec } from 'child_process' +import { exec, execSync, ChildProcess } from 'child_process' import { cortexPath } from './cortex-runner' /** @@ -56,13 +56,18 @@ log.info('Log from the main process') // replace all console.log to log Object.assign(console, log.functions) +let cortexService: ChildProcess | undefined = undefined + app .whenReady() + .then(() => killProcessesOnPort(3929)) + .then(() => killProcessesOnPort(1337)) .then(() => { - log.info('Starting cortex with path:', cortexPath) + const command = `${cortexPath} -a 127.0.0.1 -p 1337` + + log.info('Starting cortex with command:', command) // init cortex - // running shell command cortex init -s - exec(`${cortexPath}`, (error, stdout, stderr) => { + cortexService = exec(`${command}`, (error, stdout, stderr) => { if (error) { log.error(`error: ${error.message}`) return @@ -123,25 +128,37 @@ app.on('open-url', (_event, url) => { }) app.once('quit', async () => { - await stopApiServer() cleanUpAndQuit() }) app.once('window-all-closed', async () => { await stopApiServer() + await stopCortexService() cleanUpAndQuit() }) +async function stopCortexService() { + try { + const pid = cortexService?.pid + if (!pid) { + console.log('No cortex service to stop.') + return + } + process.kill(pid) + console.log(`Service with PID ${pid} has been terminated.`) + } catch (error) { + console.error('Error killing service:', error) + } +} + async function stopApiServer() { + // this function is not meant to be success. It will throw an error. try { - console.log('Stopping API server') - const response = await fetch('http://localhost:1337/v1/process', { + await fetch('http://localhost:1337/v1/system', { method: 'DELETE', }) - - console.log('Response status:', response.status) } catch (error) { - console.error('Error stopping API server:', error) + // do nothing } } @@ -154,6 +171,88 @@ function handleIPCs() { handleAppIPCs() } +function killProcessesOnPort(port: number): void { + try { + console.log(`Killing processes on port ${port}...`) + if (process.platform === 'win32') { + killProcessesOnWindowsPort(port) + } else { + killProcessesOnUnixPort(port) + } + } catch (error) { + console.error( + `Failed to kill process(es) on port ${port}: ${(error as Error).message}` + ) + } +} + +function killProcessesOnWindowsPort(port: number): void { + let result: string + try { + result = execSync(`netstat -ano | findstr :${port}`).toString() + } catch (error) { + console.log(`No processes found on port ${port}.`) + return + } + + const lines = result.split('\n').filter(Boolean) + + if (lines.length === 0) { + console.log(`No processes found on port ${port}.`) + return + } + + const pids = lines + .map((line) => { + const parts = line.trim().split(/\s+/) + return parts[parts.length - 1] + }) + .filter((pid): pid is string => Boolean(pid) && !isNaN(Number(pid))) + + if (pids.length === 0) { + console.log(`No valid PIDs found for port ${port}.`) + return + } + const uniquePids = Array.from(new Set(pids)) + console.log('uniquePids', uniquePids) + + uniquePids.forEach((pid) => { + try { + execSync(`taskkill /PID ${pid} /F`) + console.log( + `Process with PID ${pid} on port ${port} has been terminated.` + ) + } catch (error) { + console.error( + `Failed to kill process with PID ${pid}: ${(error as Error).message}` + ) + } + }) +} + +function killProcessesOnUnixPort(port: number): void { + let pids: string[] + + try { + pids = execSync(`lsof -ti tcp:${port}`) + .toString() + .trim() + .split('\n') + .filter(Boolean) + } catch (error) { + if ((error as { status?: number }).status === 1) { + console.log(`No processes found on port ${port}.`) + return + } + throw error // Re-throw if it's not the "no processes found" error + } + + pids.forEach((pid) => { + process.kill(parseInt(pid), 'SIGTERM') + console.log(`Process with PID ${pid} on port ${port} has been terminated.`) + }) +} + /** * Suppress Node error messages */ diff --git a/electron/managers/window.ts b/electron/managers/window.ts index 16da61c693..be2d0a7b95 100644 --- a/electron/managers/window.ts +++ b/electron/managers/window.ts @@ -32,6 +32,7 @@ class WindowManager { x: bounds.x, y: bounds.y, webPreferences: { + allowRunningInsecureContent: true, nodeIntegration: true, preload: preloadPath, webSecurity: false, diff --git a/electron/resources/version.txt b/electron/resources/version.txt index ec5a9b5e75..3281cfcd80 100644 --- a/electron/resources/version.txt +++ b/electron/resources/version.txt @@ -1 +1 @@ -0.5.0-1 +0.5.0-5 diff --git a/web/containers/Layout/BottomPanel/SystemMonitor/TableActiveModel/index.tsx b/web/containers/Layout/BottomPanel/SystemMonitor/TableActiveModel/index.tsx index 312350adb1..94abec4b79 100644 --- a/web/containers/Layout/BottomPanel/SystemMonitor/TableActiveModel/index.tsx +++ b/web/containers/Layout/BottomPanel/SystemMonitor/TableActiveModel/index.tsx @@ -1,9 +1,11 @@ +import { useCallback } from 'react' + import { Model } from '@janhq/core' import { Button, Badge } from '@janhq/joi' import { useAtomValue } from 'jotai' -import useModels from '@/hooks/useModels' +import useModelStop from '@/hooks/useModelStop' import { activeModelsAtom, @@ -13,7 +15,7 @@ import { const Column = ['Name', 'Engine', ''] const TableActiveModel: React.FC = () => { - const { stopModel } = useModels() + const stopModelMutation = useModelStop() const activeModels = useAtomValue(activeModelsAtom) const downloadedModels = useAtomValue(downloadedModelsAtom) @@ -25,6 +27,13 @@ const TableActiveModel: React.FC = () => { } }) + const onStopModelClick = useCallback( + (modelId: string) => { + stopModelMutation.mutate(modelId) + }, + [stopModelMutation] + ) + return (
@@ -58,7 +67,7 @@ const TableActiveModel: React.FC = () => { diff --git a/web/containers/Layout/BottomPanel/SystemMonitor/index.tsx b/web/containers/Layout/BottomPanel/SystemMonitor/index.tsx index b88cf6bc66..46632d9fd0 100644 --- a/web/containers/Layout/BottomPanel/SystemMonitor/index.tsx +++ b/web/containers/Layout/BottomPanel/SystemMonitor/index.tsx @@ -57,7 +57,7 @@ const SystemMonitor: React.FC = () => { const register = useCallback(async () => { if (abortControllerRef.current) return abortControllerRef.current = new AbortController() - await fetchEventSource(`${host}/events/resources`, { + await fetchEventSource(`${host}/system/events/resources`, { onmessage(ev) { if (!ev.data || ev.data === '') return try { diff --git a/web/containers/Layout/index.tsx b/web/containers/Layout/index.tsx index f33da709f2..b57524b181 100644 --- a/web/containers/Layout/index.tsx +++ b/web/containers/Layout/index.tsx @@ -16,8 +16,10 @@ import TopPanel from '@/containers/Layout/TopPanel' import { getImportModelStageAtom } from '@/hooks/useImportModel' import DownloadLocalModelModal from '@/screens/HubScreen2/components/DownloadLocalModelModal' +import InferenceErrorModal from '@/screens/HubScreen2/components/InferenceErrorModal' import SetUpApiKeyModal from '@/screens/HubScreen2/components/SetUpApiKeyModal' import SetUpRemoteModelModal from '@/screens/HubScreen2/components/SetUpRemoteModelModal' +import WarningMultipleModelModal from '@/screens/HubScreen2/components/WarningMultipleModelModal' import { SUCCESS_SET_NEW_DESTINATION } from '@/screens/Settings/Advanced/DataFolder' import CancelModelImportModal from '@/screens/Settings/CancelModelImportModal' import ChooseWhatToImportModal from '@/screens/Settings/ChooseWhatToImportModal' @@ -82,6 +84,8 @@ const BaseLayout = () => { {importModelStage === 'EDIT_MODEL_INFO' && } {importModelStage === 'CONFIRM_CANCEL' && } + + diff --git a/web/containers/Providers/DownloadEventListener.tsx b/web/containers/Providers/DownloadEventListener.tsx index 3f8fb23555..9ec6382ae7 100644 --- a/web/containers/Providers/DownloadEventListener.tsx +++ b/web/containers/Providers/DownloadEventListener.tsx @@ -6,8 +6,14 @@ import { useAtomValue, useSetAtom } from 'jotai' import { downloadStateListAtom } from '@/hooks/useDownloadState' +import useModels from '@/hooks/useModels' + import { waitingForCortexAtom } from '@/helpers/atoms/App.atom' import { hostAtom } from '@/helpers/atoms/AppConfig.atom' +import { + setImportingModelSuccessAtom, + updateImportingModelProgressAtom, +} from '@/helpers/atoms/Model.atom' const DownloadEventListener: React.FC = () => { const host = useAtomValue(hostAtom) @@ -15,15 +21,52 @@ const DownloadEventListener: React.FC = () => { const abortController = useRef(new AbortController()) const setDownloadStateList = useSetAtom(downloadStateListAtom) const setWaitingForCortex = useSetAtom(waitingForCortexAtom) + const { getModels } = useModels() + + const updateImportingModelProgress = useSetAtom( + updateImportingModelProgressAtom + ) + const setImportingModelSuccess = useSetAtom(setImportingModelSuccessAtom) + + const handleLocalImportModels = useCallback( + (events: DownloadState2[]) => { + if (events.length === 0) return + for (const event of events) { + if (event.progress === 100) { + setImportingModelSuccess(event.id) + } else { + updateImportingModelProgress(event.id, event.progress) + } + } + getModels() + }, + [setImportingModelSuccess, updateImportingModelProgress, getModels] + ) const subscribeDownloadEvent = useCallback(async () => { if (isRegistered.current) return - await fetchEventSource(`${host}/events/download`, { + await fetchEventSource(`${host}/system/events/download`, { onmessage(ev) { if (!ev.data || ev.data === '') return try { - const downloadEvent = JSON.parse(ev.data) as DownloadState2[] - setDownloadStateList(downloadEvent) + const downloadEvents = JSON.parse(ev.data) as DownloadState2[] + const remoteDownloadEvents: DownloadState2[] = [] + const localImportEvents: DownloadState2[] = [] + // filter out the import local events + for (const event of downloadEvents) { + console.debug('Receiving event', event) + if ( + isAbsolutePath(event.id) && + event.type === 'model' && + event.children.length === 0 + ) { + localImportEvents.push(event) + } else { + remoteDownloadEvents.push(event) + } + } + handleLocalImportModels(localImportEvents) + setDownloadStateList(remoteDownloadEvents) } catch (err) { console.error(err) } @@ -40,7 +83,7 @@ const DownloadEventListener: React.FC = () => { }) console.log('Download event subscribed') isRegistered.current = true - }, [host, setDownloadStateList, setWaitingForCortex]) + }, [host, setDownloadStateList, setWaitingForCortex, handleLocalImportModels]) const unsubscribeDownloadEvent = useCallback(() => { if (!isRegistered.current) return @@ -60,4 +103,22 @@ const DownloadEventListener: React.FC = () => { return null } +const isAbsolutePath = (path: string): boolean => { + // Trim any leading or trailing whitespace + const trimmedPath = path.trim() + + // Check for Unix-like absolute path + if (trimmedPath.startsWith('/')) { + return true + } + + // Check for Windows absolute path (with drive letter) + if (/^[A-Za-z]:[/\\]/.test(trimmedPath)) { + return true + } + + // All other paths are not considered absolute local paths + return false +} + export default DownloadEventListener diff --git a/web/containers/Providers/EventListener.tsx b/web/containers/Providers/EventListener.tsx index 7c415bd93d..71dfccd38c 100644 --- a/web/containers/Providers/EventListener.tsx +++ b/web/containers/Providers/EventListener.tsx @@ -9,7 +9,6 @@ import DownloadEventListener from './DownloadEventListener' import KeyListener from './KeyListener' import ModelEventListener from './ModelEventListener' -import ModelImportListener from './ModelImportListener' const EventListenerWrapper: React.FC = () => ( @@ -19,7 +18,6 @@ const EventListenerWrapper: React.FC = () => ( - ) diff --git a/web/containers/Providers/ModelEventListener.tsx b/web/containers/Providers/ModelEventListener.tsx index 02bdb99fbd..110f80c692 100644 --- a/web/containers/Providers/ModelEventListener.tsx +++ b/web/containers/Providers/ModelEventListener.tsx @@ -91,7 +91,7 @@ function ModelEventListener() { if (abortController.current) return abortController.current = new AbortController() - await fetchEventSource(`${host}/events/model`, { + await fetchEventSource(`${host}/system/events/model`, { onmessage(ev) { if (!ev.data || ev.data === '') return try { diff --git a/web/containers/Providers/ModelImportListener.tsx b/web/containers/Providers/ModelImportListener.tsx deleted file mode 100644 index 9a46be6186..0000000000 --- a/web/containers/Providers/ModelImportListener.tsx +++ /dev/null @@ -1,103 +0,0 @@ -import { useCallback, useEffect } from 'react' - -import { ImportingModel, Model } from '@janhq/core' -import { useSetAtom } from 'jotai' - -import { snackbar } from '../Toast' - -import { - setImportingModelErrorAtom, - setImportingModelSuccessAtom, - updateImportingModelProgressAtom, -} from '@/helpers/atoms/Model.atom' - -const ModelImportListener: React.FC = () => { - const updateImportingModelProgress = useSetAtom( - updateImportingModelProgressAtom - ) - const setImportingModelSuccess = useSetAtom(setImportingModelSuccessAtom) - const setImportingModelFailed = useSetAtom(setImportingModelErrorAtom) - - const onImportModelUpdate = useCallback( - async (state: ImportingModel) => { - if (!state.importId) return - updateImportingModelProgress(state.importId, state.percentage ?? 0) - }, - [updateImportingModelProgress] - ) - - const onImportModelFailed = useCallback( - async (state: ImportingModel) => { - if (!state.importId) return - setImportingModelFailed(state.importId, state.error ?? '') - }, - [setImportingModelFailed] - ) - - const onImportModelSuccess = useCallback( - (state: ImportingModel) => { - if (!state.modelId) return - // events.emit(ModelEvent.OnModelsUpdate, {}) - setImportingModelSuccess(state.importId, state.modelId) - }, - [setImportingModelSuccess] - ) - - const onImportModelFinished = useCallback((importedModels: Model[]) => { - const modelText = importedModels.length === 1 ? 'model' : 'models' - snackbar({ - description: `Successfully imported ${importedModels.length} ${modelText}`, - type: 'success', - }) - }, []) - - useEffect(() => { - console.debug('ModelImportListener: registering event listeners..') - - // events.on( - // LocalImportModelEvent.onLocalImportModelUpdate, - // onImportModelUpdate - // ) - // events.on( - // LocalImportModelEvent.onLocalImportModelSuccess, - // onImportModelSuccess - // ) - // events.on( - // LocalImportModelEvent.onLocalImportModelFinished, - // onImportModelFinished - // ) - // events.on( - // LocalImportModelEvent.onLocalImportModelFailed, - // onImportModelFailed - // ) - - // return () => { - // console.debug('ModelImportListener: unregistering event listeners...') - // events.off( - // LocalImportModelEvent.onLocalImportModelUpdate, - // onImportModelUpdate - // ) - // events.off( - // LocalImportModelEvent.onLocalImportModelSuccess, - // onImportModelSuccess - // ) - // events.off( - // LocalImportModelEvent.onLocalImportModelFinished, - // onImportModelFinished - // ) - // events.off( - // LocalImportModelEvent.onLocalImportModelFailed, - // onImportModelFailed - // ) - // } - }, [ - onImportModelUpdate, - onImportModelSuccess, - onImportModelFinished, - onImportModelFailed, - ]) - - return null -} - -export default ModelImportListener diff --git a/web/containers/Providers/index.tsx b/web/containers/Providers/index.tsx index 6e335e8251..e27fcaaf15 100644 --- a/web/containers/Providers/index.tsx +++ b/web/containers/Providers/index.tsx @@ -13,8 +13,6 @@ import ThemeWrapper from '@/containers/Providers/Theme' import { setupCoreServices } from '@/services/coreService' -import Umami from '@/utils/umami' - import DataLoader from './DataLoader' import ModalMigrations from './ModalMigrations' @@ -35,7 +33,7 @@ const Providers = ({ children }: PropsWithChildren) => { - + {/* */} {setupCore && ( diff --git a/web/containers/WaitingCortexModal/index.tsx b/web/containers/WaitingCortexModal/index.tsx index 3168ba4810..19739d90e0 100644 --- a/web/containers/WaitingCortexModal/index.tsx +++ b/web/containers/WaitingCortexModal/index.tsx @@ -1,5 +1,9 @@ +import { useCallback, useEffect } from 'react' + import { Modal } from '@janhq/joi' -import { useAtomValue } from 'jotai' +import { useAtom, useAtomValue } from 'jotai' + +import useCortex from '@/hooks/useCortex' import Spinner from '../Loader/Spinner' @@ -8,12 +12,22 @@ import { hostAtom } from '@/helpers/atoms/AppConfig.atom' const WaitingForCortexModal: React.FC = () => { const host = useAtomValue(hostAtom) - const open = useAtomValue(waitingForCortexAtom) + const [waitingForCortex, setWaitingForCortex] = useAtom(waitingForCortexAtom) + const { isSystemAlive } = useCortex() + + const checkSystemAlive = useCallback(async () => { + setWaitingForCortex(!(await isSystemAlive())) + }, [setWaitingForCortex, isSystemAlive]) + + // Check health for the first time on mount + useEffect(() => { + checkSystemAlive() + }, [checkSystemAlive]) return ( diff --git a/web/helpers/atoms/App.atom.ts b/web/helpers/atoms/App.atom.ts index eacbf3a511..0668f47631 100644 --- a/web/helpers/atoms/App.atom.ts +++ b/web/helpers/atoms/App.atom.ts @@ -11,7 +11,7 @@ export const mainViewStateAtom = atom(MainViewState.Thread) export const defaultJanDataFolderAtom = atom('') -export const waitingForCortexAtom = atom(false) +export const waitingForCortexAtom = atom(true) // Store panel atom export const showLeftPanelAtom = atom(true) diff --git a/web/helpers/atoms/Model.atom.ts b/web/helpers/atoms/Model.atom.ts index 229c8800de..88f685705b 100644 --- a/web/helpers/atoms/Model.atom.ts +++ b/web/helpers/atoms/Model.atom.ts @@ -61,14 +61,14 @@ export const setImportingModelErrorAtom = atom( export const setImportingModelSuccessAtom = atom( null, - (get, set, importId: string, modelId: string) => { + (get, set, importId: string) => { const model = get(importingModelsAtom).find((x) => x.importId === importId) if (!model) return const newModel: ImportingModel = { ...model, - modelId, + modelId: undefined, status: 'IMPORTED', - percentage: 1, + percentage: 100, } const newList = get(importingModelsAtom).map((x) => x.importId === importId ? newModel : x diff --git a/web/helpers/atoms/Thread.atom.ts b/web/helpers/atoms/Thread.atom.ts index 0a44fcdc6d..af5598e213 100644 --- a/web/helpers/atoms/Thread.atom.ts +++ b/web/helpers/atoms/Thread.atom.ts @@ -61,17 +61,16 @@ export const setActiveThreadIdAtom = atom( } ) -export const waitingToSendMessage = atom(undefined) export const isLoadingModelAtom = atom(undefined) -export const isGeneratingResponseAtom = atom(undefined) +export const isGeneratingResponseAtom = atom(false) /** * Stores all threads for the current user */ export const threadsAtom = atom([]) -export const deleteThreadAtom = atom(null, (get, set, threadId: string) => { +export const deleteThreadAtom = atom(null, (_get, set, threadId: string) => { set(threadsAtom, (threads) => { // set active thread to the latest const allThreads = threads.filter((c) => c.id !== threadId) diff --git a/web/hooks/useCortex.ts b/web/hooks/useCortex.ts index 61798d9302..491c646936 100644 --- a/web/hooks/useCortex.ts +++ b/web/hooks/useCortex.ts @@ -1,6 +1,7 @@ -import 'cortexso-node/shims/web' import { useCallback } from 'react' +import { Cortex } from '@cortexso/cortex.js' +import { Engine } from '@cortexso/cortex.js/resources' import { Assistant, Model, @@ -11,10 +12,9 @@ import { AssistantCreateParams, AssistantUpdateParams, LlmEngine, + LlmEngines, } from '@janhq/core' -import { Cortex } from 'cortexso-node' - import { useAtomValue } from 'jotai' import { UpdateConfigMutationVariables } from './useEngineMutation' @@ -24,22 +24,6 @@ import { MessageUpdateMutationVariables } from './useMessageUpdateMutation' import { hostAtom } from '@/helpers/atoms/AppConfig.atom' -const EngineInitStatuses = [ - 'ready', - 'not_initialized', - 'missing_configuration', - 'not_supported', -] as const -export type EngineInitStatus = (typeof EngineInitStatuses)[number] - -export type EngineStatus = { - name: LlmEngine - description: string - version: string - productName: string - status: EngineInitStatus -} - const useCortex = () => { const host = useAtomValue(hostAtom) @@ -49,50 +33,33 @@ const useCortex = () => { dangerouslyAllowBrowser: true, }) - // TODO: put in to cortexso-node? - const getEngineStatuses = useCallback(async (): Promise => { - const response = await fetch(`${host}/engines`, { - method: 'GET', - }) - const data = await response.json() - const engineStatuses: EngineStatus[] = [] - data.data.forEach((engineStatus: EngineStatus) => { - engineStatuses.push(engineStatus) - }) - return engineStatuses - }, [host]) - - // TODO: put in to cortexso-node? - const getEngineStatus = useCallback( - async (engine: LlmEngine): Promise => { - try { - const response = await fetch(`${host}/engines/${engine}`, { - method: 'GET', - }) - const data = (await response.json()) as EngineStatus - return data - } catch (err) { - console.error(err) + const getEngineStatuses = useCallback(async (): Promise => { + const engineResponse = await cortex.engines.list() + // @ts-expect-error incompatible types + const engineStatuses: Engine[] = engineResponse.body.data.map( + (engine: Engine) => { + return { + name: engine.name, + description: engine.description, + version: engine.version, + productName: engine.productName, + status: engine.status, + } } - }, - [host] - ) + ) + + return engineStatuses + }, [cortex.engines]) - // TODO: put in to cortexso-node? const initializeEngine = useCallback( async (engine: LlmEngine) => { try { - await fetch(`${host}/engines/${engine}/init/`, { - method: 'POST', - headers: { - accept: 'application/json', - }, - }) + await cortex.engines.init(engine) } catch (err) { console.error(err) } }, - [host] + [cortex.engines] ) const fetchAssistants = useCallback(async () => { @@ -132,8 +99,15 @@ const useCortex = () => { console.debug('Model id is empty, skipping', model) continue } + const engine = LlmEngines.find((engine) => engine === model.engine) + if (!engine) { + console.error(`Model ${modelId} has an invalid engine ${model.engine}`) + continue + } + models.push({ ...model, + engine: engine, model: modelId, // @ts-expect-error each model must have associated files files: model['files'], @@ -263,26 +237,18 @@ const useCortex = () => { const updateModel = useCallback( async (modelId: string, options: Record) => { try { - return await fetch(`${host}/models/${modelId}`, { - method: 'PATCH', - headers: { - 'accept': 'application/json', - // eslint-disable-next-line @typescript-eslint/naming-convention - 'Content-Type': 'application/json', - }, - body: JSON.stringify(options), - }) + return await cortex.models.update(modelId, options) } catch (err) { console.error(err) } }, - [host] + [cortex.models] ) - // TODO: put this into cortexso-node const downloadModel = useCallback( async (modelId: string, fileName?: string, persistedModelId?: string) => { try { + // return await cortex.models.download(modelId) return await fetch(`${host}/models/${modelId}/pull`, { method: 'POST', headers: { @@ -305,19 +271,12 @@ const useCortex = () => { const abortDownload = useCallback( async (downloadId: string) => { try { - return await fetch(`${host}/models/${downloadId}/pull`, { - method: 'DELETE', - headers: { - 'accept': 'application/json', - // eslint-disable-next-line @typescript-eslint/naming-convention - 'Content-Type': 'application/json', - }, - }) + return await cortex.models.abortDownload(downloadId) } catch (err) { console.error(err) } }, - [host] + [cortex.models] ) const createAssistant = useCallback( @@ -335,22 +294,14 @@ const useCortex = () => { // TODO: add this to cortex-node const registerEngineConfig = useCallback( async (variables: UpdateConfigMutationVariables) => { + const { engine, config } = variables try { - const { engine, config } = variables - await fetch(`${host}/engines/${engine}`, { - method: 'PATCH', - headers: { - 'accept': 'application/json', - // eslint-disable-next-line @typescript-eslint/naming-convention - 'Content-Type': 'application/json', - }, - body: JSON.stringify(config), - }) + await cortex.engines.update(engine, config) } catch (err) { console.error(err) } }, - [host] + [cortex.engines] ) // add this to cortex-node? @@ -368,6 +319,15 @@ const useCortex = () => { [host] ) + const isSystemAlive = useCallback(async () => { + try { + await cortex.system.status() + return true + } catch { + return false + } + }, [cortex.system]) + return { fetchAssistants, fetchThreads, @@ -392,9 +352,9 @@ const useCortex = () => { chatCompletionNonStreaming, registerEngineConfig, createModel, - getEngineStatus, initializeEngine, getEngineStatuses, + isSystemAlive, } } diff --git a/web/hooks/useDownloadState.ts b/web/hooks/useDownloadState.ts index ab8c5676ee..0833c428d1 100644 --- a/web/hooks/useDownloadState.ts +++ b/web/hooks/useDownloadState.ts @@ -10,6 +10,7 @@ export const addDownloadModelStateAtom = atom( id: modelId, title: modelId, type: DownloadType2.Model, + progress: 0, status: DownloadStatus.Downloading, children: [ { diff --git a/web/hooks/useDropModelBinaries.ts b/web/hooks/useDropModelBinaries.ts index d87e96627e..6e593304bc 100644 --- a/web/hooks/useDropModelBinaries.ts +++ b/web/hooks/useDropModelBinaries.ts @@ -3,8 +3,6 @@ import { useCallback } from 'react' import { ImportingModel } from '@janhq/core' import { useSetAtom } from 'jotai' -import { v4 as uuidv4 } from 'uuid' - import { snackbar } from '@/containers/Toast' import { getFileInfoFromFile } from '@/utils/file' @@ -26,17 +24,23 @@ export default function useDropModelBinaries() { ) const supportedFiles = files.filter((file) => file.path.endsWith('.gguf')) - const importingModels: ImportingModel[] = supportedFiles.map((file) => ({ - importId: uuidv4(), - modelId: undefined, - name: file.name.replace('.gguf', ''), - description: '', - path: file.path, - tags: [], - size: file.size, - status: 'PREPARING', - format: 'gguf', - })) + const importingModels: ImportingModel[] = supportedFiles.map((file) => { + const normalizedPath = isWindows + ? file.path.replace(/\\/g, '/') + : file.path + + return { + importId: normalizedPath, + modelId: undefined, + name: normalizedPath.replace('.gguf', ''), + description: '', + path: file.path, + tags: [], + size: file.size, + status: 'PREPARING', + format: 'gguf', + } + }) if (unsupportedFiles.length > 0) { snackbar({ description: `Only files with .gguf extension can be imported.`, diff --git a/web/hooks/useEngineInit.ts b/web/hooks/useEngineInit.ts index 81cdbc5a27..34e21b74aa 100644 --- a/web/hooks/useEngineInit.ts +++ b/web/hooks/useEngineInit.ts @@ -1,6 +1,7 @@ +import { Engine } from '@cortexso/cortex.js/resources' import { useMutation, useQueryClient } from '@tanstack/react-query' -import useCortex, { EngineStatus } from './useCortex' +import useCortex from './useCortex' import { engineQueryKey } from './useEngineQuery' const useEngineInit = () => { @@ -10,21 +11,21 @@ const useEngineInit = () => { return useMutation({ mutationFn: initializeEngine, - onSuccess: async (data, variables) => { - console.debug(`Engine ${variables} initialized`, data) + onSuccess: async (data, engineName) => { + console.debug(`Engine ${engineName} initialized`, data) // optimistically set the engine status to 'ready' const queryCacheData = await queryClient.getQueryData(engineQueryKey) if (!queryCacheData) { return queryClient.invalidateQueries({ queryKey: engineQueryKey }) } - const engineStatuses = queryCacheData as EngineStatus[] + const engineStatuses = queryCacheData as Engine[] engineStatuses.forEach((engine) => { - if (engine.name === variables) { + if (engine.name === engineName) { engine.status = 'ready' } }) - console.log(`Updated engine status: ${engineStatuses}`) + console.debug(`Updated engine status: ${engineStatuses}`) await queryClient.setQueryData(engineQueryKey, engineStatuses) }, diff --git a/web/hooks/useMigratingData.ts b/web/hooks/useMigratingData.ts index 8938846277..e19bff80e4 100644 --- a/web/hooks/useMigratingData.ts +++ b/web/hooks/useMigratingData.ts @@ -68,13 +68,12 @@ const useMigratingData = () => { continue } const threadTitle: string = thread.title ?? 'New Thread' - const instruction: string = thread.assistants[0]?.instruction ?? '' - + const instructions: string = thread.assistants[0]?.instructions ?? '' // currently, we don't have api support for creating thread with messages const cortexThread = await createThread(modelId, assistants[0]) console.log('createThread', cortexThread) // update instruction - cortexThread.assistants[0].instructions = instruction + cortexThread.assistants[0].instructions = instructions cortexThread.title = threadTitle // update thread name diff --git a/web/hooks/useModelStop.ts b/web/hooks/useModelStop.ts new file mode 100644 index 0000000000..891f38b34c --- /dev/null +++ b/web/hooks/useModelStop.ts @@ -0,0 +1,21 @@ +import { useMutation } from '@tanstack/react-query' + +import useCortex from './useCortex' + +const useModelStop = () => { + const { stopModel } = useCortex() + + return useMutation({ + mutationFn: stopModel, + + onSuccess: (data, modelId) => { + console.debug(`Model ${modelId} stopped successfully`, data) + }, + + onError: (error, modelId) => { + console.debug(`Stop model ${modelId} error`, error) + }, + }) +} + +export default useModelStop diff --git a/web/hooks/useModels.ts b/web/hooks/useModels.ts index e8d2714a40..40a42da413 100644 --- a/web/hooks/useModels.ts +++ b/web/hooks/useModels.ts @@ -16,7 +16,6 @@ const useModels = () => { const removeDownloadedModel = useSetAtom(removeDownloadedModelAtom) const { fetchModels, - stopModel: cortexStopModel, deleteModel: cortexDeleteModel, updateModel: cortexUpdateModel, } = useCortex() @@ -29,11 +28,6 @@ const useModels = () => { getDownloadedModels() }, [setDownloadedModels, fetchModels]) - const stopModel = useCallback( - async (modelId: string) => cortexStopModel(modelId), - [cortexStopModel] - ) - const deleteModel = useCallback( async (modelId: string) => { await cortexDeleteModel(modelId) @@ -54,7 +48,7 @@ const useModels = () => { [cortexUpdateModel] ) - return { getModels, stopModel, deleteModel, updateModel } + return { getModels, deleteModel, updateModel } } export default useModels diff --git a/web/hooks/useSendMessage.ts b/web/hooks/useSendMessage.ts index 274cd57ea5..a8f1187677 100644 --- a/web/hooks/useSendMessage.ts +++ b/web/hooks/useSendMessage.ts @@ -17,6 +17,11 @@ import { currentPromptAtom, editPromptAtom } from '@/containers/Providers/Jotai' import { toaster } from '@/containers/Toast' +import { inferenceErrorAtom } from '@/screens/HubScreen2/components/InferenceErrorModal' + +import { showWarningMultipleModelModalAtom } from '@/screens/HubScreen2/components/WarningMultipleModelModal' +import { concurrentModelWarningThreshold } from '@/screens/Settings/MyModels/ModelItem' + import useCortex from './useCortex' import useEngineInit from './useEngineInit' @@ -89,6 +94,11 @@ const useSendMessage = () => { const startModel = useModelStart() const abortControllerRef = useRef(undefined) + const didUserAborted = useRef(false) + const setInferenceErrorAtom = useSetAtom(inferenceErrorAtom) + const setShowWarningMultipleModelModal = useSetAtom( + showWarningMultipleModelModalAtom + ) const validatePrerequisite = useCallback(async (): Promise => { const errorTitle = 'Failed to send message' @@ -195,10 +205,17 @@ const useSendMessage = () => { const stopInference = useCallback(() => { abortControllerRef.current?.abort() + didUserAborted.current = true }, []) const summarizeThread = useCallback( async (messages: string[], modelId: string, thread: Thread) => { + // if its a local model, and is not started, skip summarization + if (LocalEngines.find((e) => e === selectedModel!.engine) != null) { + if (!activeModels.map((model) => model.model).includes(modelId)) { + return + } + } const maxWordForThreadTitle = 10 const summarizeMessages: ChatCompletionMessageParam[] = [ { @@ -223,6 +240,8 @@ const useSendMessage = () => { updateThreadTitle(thread.id, summarizedText) }, [ + activeModels, + selectedModel, addThreadIdShouldAnimateTitle, chatCompletionNonStreaming, updateThreadTitle, @@ -241,6 +260,11 @@ const useSendMessage = () => { if (LocalEngines.find((e) => e === selectedModel!.engine) != null) { // start model if local and not started if (!activeModels.map((model) => model.model).includes(modelId)) { + if (activeModels.length >= concurrentModelWarningThreshold) { + // if max concurrent models reached, stop the first model + // display popup + setShowWarningMultipleModelModal(true) + } await startModel.mutateAsync(modelId) } } @@ -268,7 +292,10 @@ const useSendMessage = () => { case 'assistant': return { role: msg.role, - content: (msg.content[0] as TextContentBlock).text.value, + content: + msg.content[0] != null + ? (msg.content[0] as TextContentBlock).text.value + : '', } // we will need to support other roles in the future @@ -300,6 +327,7 @@ const useSendMessage = () => { ...modelOptions, }) + didUserAborted.current = false abortControllerRef.current = stream.controller const assistantMessage = await createMessage.mutateAsync({ @@ -366,6 +394,7 @@ const useSendMessage = () => { }, }) } else { + didUserAborted.current = false const abortController = new AbortController() const response = await chatCompletionNonStreaming( { @@ -427,9 +456,18 @@ const useSendMessage = () => { } } catch (err) { console.error(err) + // @ts-expect-error error message should be there + const errorMessage = err['message'] + if (errorMessage != null) { + setInferenceErrorAtom({ + engine: selectedModel!.engine, + message: errorMessage, + }) + } toaster({ - title: 'Failed to generate response', + title: `Error with ${selectedModel!.model}`, + description: 'Failed to generate response', type: 'error', }) } @@ -442,13 +480,15 @@ const useSendMessage = () => { selectedModel, updateMessage, createMessage, - validatePrerequisite, startModel, + setInferenceErrorAtom, + validatePrerequisite, updateMessageState, addNewMessage, chatCompletionNonStreaming, chatCompletionStreaming, setIsGeneratingResponse, + setShowWarningMultipleModelModal, ]) const sendMessage = useCallback( @@ -479,6 +519,11 @@ const useSendMessage = () => { if (LocalEngines.find((e) => e === selectedModel!.engine) != null) { // start model if local and not started if (!activeModels.map((model) => model.model).includes(modelId)) { + if (activeModels.length >= concurrentModelWarningThreshold) { + // if max concurrent models reached, stop the first model + // display popup + setShowWarningMultipleModelModal(true) + } await startModel.mutateAsync(modelId) } } @@ -502,7 +547,10 @@ const useSendMessage = () => { case 'assistant': return { role: msg.role, - content: (msg.content[0] as TextContentBlock).text.value, + content: + msg.content[0] != null + ? (msg.content[0] as TextContentBlock).text.value + : '', } // we will need to support other roles in the future @@ -536,7 +584,7 @@ const useSendMessage = () => { top_p: selectedModel!.top_p ?? 1, ...modelOptions, }) - + didUserAborted.current = false abortControllerRef.current = stream.controller const assistantMessage = await createMessage.mutateAsync({ @@ -606,7 +654,10 @@ const useSendMessage = () => { }, }) } else { + didUserAborted.current = false const abortController = new AbortController() + abortControllerRef.current = abortController + const response = await chatCompletionNonStreaming( { messages, @@ -663,7 +714,7 @@ const useSendMessage = () => { content: responseMessage.content, }, }) - + abortControllerRef.current = undefined if (responseMessage) { setIsGeneratingResponse(false) } @@ -672,17 +723,27 @@ const useSendMessage = () => { } } catch (err) { console.error(err) + // @ts-expect-error error message should be there + const errorMessage = err['message'] + if (errorMessage != null) { + setInferenceErrorAtom({ + engine: selectedModel!.engine, + message: errorMessage, + }) + } + setIsGeneratingResponse(false) shouldSummarize = false toaster({ - title: 'Failed to generate response', + title: `Error with ${selectedModel!.model}`, + description: 'Failed to generate response', type: 'error', }) } try { - if (!shouldSummarize) return + if (!shouldSummarize || didUserAborted.current === true) return // summarize if needed const textMessages: string[] = messages .map((msg) => { @@ -702,16 +763,18 @@ const useSendMessage = () => { selectedModel, updateMessage, createMessage, + startModel, + setInferenceErrorAtom, validatePrerequisite, setCurrentPrompt, setEditPrompt, setIsGeneratingResponse, updateMessageState, addNewMessage, - startModel, chatCompletionNonStreaming, chatCompletionStreaming, summarizeThread, + setShowWarningMultipleModelModal, ] ) diff --git a/web/package.json b/web/package.json index ac80448fa3..632cb8f7f5 100644 --- a/web/package.json +++ b/web/package.json @@ -17,7 +17,7 @@ "yaml": "^2.4.5", "@huggingface/hub": "^0.15.1", "embla-carousel-react": "^8.1.5", - "cortexso-node": "^0.0.4", + "@cortexso/cortex.js": "^0.1.6", "@microsoft/fetch-event-source": "^2.0.1", "@janhq/core": "link:./core", "@janhq/joi": "link:./joi", @@ -44,7 +44,6 @@ "sass": "^1.69.4", "tailwind-merge": "^2.0.0", "tailwindcss": "3.3.5", - "uuid": "^9.0.1", "use-debounce": "^10.0.0" }, "devDependencies": { diff --git a/web/screens/HubScreen2/components/InferenceErrorModal.tsx b/web/screens/HubScreen2/components/InferenceErrorModal.tsx new file mode 100644 index 0000000000..eef456e98a --- /dev/null +++ b/web/screens/HubScreen2/components/InferenceErrorModal.tsx @@ -0,0 +1,45 @@ +import { Fragment, useCallback } from 'react' + +import { LlmEngine } from '@janhq/core' +import { Button, Modal, ModalClose } from '@janhq/joi' +import { atom, useAtom } from 'jotai' + +export type InferenceError = { + message: string + engine?: LlmEngine +} + +export const inferenceErrorAtom = atom(undefined) + +const InferenceErrorModal: React.FC = () => { + const [inferenceError, setInferenceError] = useAtom(inferenceErrorAtom) + + const onClose = useCallback(() => { + setInferenceError(undefined) + }, [setInferenceError]) + + return ( + +

+ {inferenceError?.message} +

+
+ + + +
+
+ } + /> + ) +} + +export default InferenceErrorModal diff --git a/web/screens/HubScreen2/components/ModelSearchBar.tsx b/web/screens/HubScreen2/components/ModelSearchBar.tsx index 6ca3da884d..2ad05f6aaa 100644 --- a/web/screens/HubScreen2/components/ModelSearchBar.tsx +++ b/web/screens/HubScreen2/components/ModelSearchBar.tsx @@ -2,7 +2,7 @@ import React, { useCallback, useState } from 'react' import { Button, Input } from '@janhq/joi' import { useSetAtom } from 'jotai' -import { SearchIcon } from 'lucide-react' +import { ImportIcon, SearchIcon } from 'lucide-react' import { FoldersIcon } from 'lucide-react' import { useDebouncedCallback } from 'use-debounce' @@ -10,6 +10,8 @@ import { toaster } from '@/containers/Toast' import { useGetHFRepoData } from '@/hooks/useGetHFRepoData' +import { setImportModelStageAtom } from '@/hooks/useImportModel' + import { MainViewState, mainViewStateAtom } from '@/helpers/atoms/App.atom' import { importHuggingFaceModelStageAtom, @@ -26,6 +28,7 @@ const ModelSearchBar: React.FC = ({ onSearchChanged }) => { const { getHfRepoData } = useGetHFRepoData() const setMainViewState = useSetAtom(mainViewStateAtom) const setSelectedSetting = useSetAtom(selectedSettingAtom) + const setImportModelStage = useSetAtom(setImportModelStageAtom) const setImportingHuggingFaceRepoData = useSetAtom( importingHuggingFaceRepoDataAtom @@ -34,6 +37,10 @@ const ModelSearchBar: React.FC = ({ onSearchChanged }) => { importHuggingFaceModelStageAtom ) + const onImportModelClick = useCallback(() => { + setImportModelStage('SELECTING_MODEL') + }, [setImportModelStage]) + const debounced = useDebouncedCallback(async (searchText: string) => { if (searchText.indexOf('/') === -1) { // If we don't find / in the text, perform a local search @@ -90,6 +97,14 @@ const ModelSearchBar: React.FC = ({ onSearchChanged }) => { My models +
) } diff --git a/web/screens/HubScreen2/components/SetUpApiKeyModal.tsx b/web/screens/HubScreen2/components/SetUpApiKeyModal.tsx index 052e80393d..edeba3187a 100644 --- a/web/screens/HubScreen2/components/SetUpApiKeyModal.tsx +++ b/web/screens/HubScreen2/components/SetUpApiKeyModal.tsx @@ -38,6 +38,9 @@ const SetUpApiKeyModal: React.FC = () => { alert('Does not have engine') return } + const normalizedApiKey = apiKey.trim().replaceAll('*', '') + if (normalizedApiKey.length === 0) return + updateEngineConfig.mutate({ engine: remoteEngine, config: { diff --git a/web/screens/HubScreen2/components/WarningMultipleModelModal.tsx b/web/screens/HubScreen2/components/WarningMultipleModelModal.tsx new file mode 100644 index 0000000000..f824a30f4b --- /dev/null +++ b/web/screens/HubScreen2/components/WarningMultipleModelModal.tsx @@ -0,0 +1,49 @@ +import { Fragment, useCallback, useMemo } from 'react' + +import { Button, Modal, ModalClose } from '@janhq/joi' +import { atom, useAtom, useAtomValue } from 'jotai' + +import { activeModelsAtom } from '@/helpers/atoms/Model.atom' + +export const showWarningMultipleModelModalAtom = atom(false) + +const WarningMultipleModelModal: React.FC = () => { + const [showWarningMultipleModelModal, setShowWarningMultipleModelModal] = + useAtom(showWarningMultipleModelModalAtom) + const activeModels = useAtomValue(activeModelsAtom) + + const onClose = useCallback(() => { + setShowWarningMultipleModelModal(false) + }, [setShowWarningMultipleModelModal]) + + const title = useMemo( + () => `${activeModels.length} models running`, + [activeModels] + ) + + return ( + +

+ This may affect performance. Please review them via System Monitor + in the lower right conner of Jan app. +

+
+ + + +
+ + } + /> + ) +} + +export default WarningMultipleModelModal diff --git a/web/screens/Settings/Advanced/index.tsx b/web/screens/Settings/Advanced/index.tsx index aaf28f9a22..030de9717d 100644 --- a/web/screens/Settings/Advanced/index.tsx +++ b/web/screens/Settings/Advanced/index.tsx @@ -10,7 +10,7 @@ import { useAtom, useAtomValue } from 'jotai' import { toaster } from '@/containers/Toast' -import useModels from '@/hooks/useModels' +import useModelStop from '@/hooks/useModelStop' import { useSettings } from '@/hooks/useSettings' import { @@ -53,7 +53,7 @@ const Advanced = () => { const { readSettings, saveSettings } = useSettings() const activeModels = useAtomValue(activeModelsAtom) // const [open, setOpen] = useState(false) - const { stopModel } = useModels() + const stopModel = useModelStop() // const selectedGpu = gpuList // .filter((x) => gpusInUse.includes(x.id)) @@ -92,7 +92,7 @@ const Advanced = () => { }) for (const model of activeModels) { - await stopModel(model.model) + await stopModel.mutateAsync(model.model) } setVulkanEnabled(e) diff --git a/web/screens/Settings/ChooseWhatToImportModal/index.tsx b/web/screens/Settings/ChooseWhatToImportModal/index.tsx index 889dbbe67e..8213038641 100644 --- a/web/screens/Settings/ChooseWhatToImportModal/index.tsx +++ b/web/screens/Settings/ChooseWhatToImportModal/index.tsx @@ -1,18 +1,22 @@ import { useCallback } from 'react' -import { SelectFileOption } from '@janhq/core' +import { ImportingModel, SelectFileOption } from '@janhq/core' import { Button, Modal } from '@janhq/joi' import { useSetAtom, useAtomValue } from 'jotai' -import useImportModel, { +import { snackbar } from '@/containers/Toast' + +import { setImportModelStageAtom, getImportModelStageAtom, } from '@/hooks/useImportModel' +import { importingModelsAtom } from '@/helpers/atoms/Model.atom' + const ChooseWhatToImportModal = () => { const setImportModelStage = useSetAtom(setImportModelStageAtom) + const setImportingModels = useSetAtom(importingModelsAtom) const importModelStage = useAtomValue(getImportModelStageAtom) - const { sanitizeFilePaths } = useImportModel() const onImportFileClick = useCallback(async () => { const options: SelectFileOption = { @@ -24,10 +28,36 @@ const ChooseWhatToImportModal = () => { { name: 'All Files', extensions: ['*'] }, ], } - const filePaths = await window.core?.api?.selectFiles(options) + const filePaths: string[] = await window.core?.api?.selectFiles(options) if (!filePaths || filePaths.length === 0) return - sanitizeFilePaths(filePaths) - }, [sanitizeFilePaths]) + + const importingModels: ImportingModel[] = filePaths + .filter((path) => path.endsWith('.gguf')) + .map((path) => { + const normalizedPath = isWindows ? path.replace(/\\/g, '/') : path + + return { + importId: normalizedPath, + modelId: undefined, + name: normalizedPath.replace('.gguf', ''), + description: '', + path: path, + tags: [], + size: 0, + status: 'PREPARING', + format: 'gguf', + } + }) + if (importingModels.length < 1) { + snackbar({ + description: `Only files with .gguf extension can be imported.`, + type: 'error', + }) + return + } + setImportingModels(importingModels) + setImportModelStage('MODEL_SELECTED') + }, [setImportingModels, setImportModelStage]) const onImportFolderClick = useCallback(async () => { const options: SelectFileOption = { @@ -36,10 +66,37 @@ const ChooseWhatToImportModal = () => { allowMultiple: true, selectDirectory: true, } - const filePaths = await window.core?.api?.selectFiles(options) + const filePaths: string[] = await window.core?.api?.selectFiles(options) if (!filePaths || filePaths.length === 0) return - sanitizeFilePaths(filePaths) - }, [sanitizeFilePaths]) + + console.log('filePaths folder', filePaths) + const importingModels: ImportingModel[] = filePaths + .filter((path) => path.endsWith('.gguf')) + .map((path) => { + const normalizedPath = isWindows ? path.replace(/\\/g, '/') : path + + return { + importId: normalizedPath, + modelId: undefined, + name: normalizedPath.replace('.gguf', ''), + description: '', + path: path, + tags: [], + size: 0, + status: 'PREPARING', + format: 'gguf', + } + }) + if (importingModels.length < 1) { + snackbar({ + description: `Only files with .gguf extension can be imported.`, + type: 'error', + }) + return + } + setImportingModels(importingModels) + setImportModelStage('MODEL_SELECTED') + }, [setImportingModels, setImportModelStage]) return ( { + const normalized = status.charAt(0).toUpperCase() + status.slice(1) + return normalized.replaceAll('_', ' ') +} + const EngineSetting: React.FC = () => { const { isLoading, data } = useEngineQuery() @@ -56,7 +61,7 @@ const EngineSetting: React.FC = () => { {engineStatus.version} - {engineStatus.status} + {getStatusTitle(engineStatus.status)} ) })} diff --git a/web/screens/Settings/ImportModelOptionModal/index.tsx b/web/screens/Settings/ImportModelOptionModal/index.tsx index 5a2af2335f..ed01c400af 100644 --- a/web/screens/Settings/ImportModelOptionModal/index.tsx +++ b/web/screens/Settings/ImportModelOptionModal/index.tsx @@ -20,12 +20,12 @@ const importOptions: ModelImportOption[] = [ description: 'You maintain your model files outside of Jan. Keeping your files where they are, and Jan will create a smart link to them.', }, - { - type: 'MOVE_BINARY_FILE', - title: 'Move model binary file', - description: - 'Jan will move your model binary file from your current folder into Jan Data Folder.', - }, + // { + // type: 'MOVE_BINARY_FILE', + // title: 'Move model binary file', + // description: + // 'Jan will move your model binary file from your current folder into Jan Data Folder.', + // }, ] const ImportModelOptionModal = () => { diff --git a/web/screens/Settings/ImportSuccessIcon/index.tsx b/web/screens/Settings/ImportSuccessIcon/index.tsx index e574acbf0d..a822ca4d2c 100644 --- a/web/screens/Settings/ImportSuccessIcon/index.tsx +++ b/web/screens/Settings/ImportSuccessIcon/index.tsx @@ -1,6 +1,6 @@ -import React, { useCallback, useState } from 'react' +import React, { useState } from 'react' -import { Check, Pencil } from 'lucide-react' +import { Check } from 'lucide-react' type Props = { onEditModelClick: () => void @@ -9,6 +9,8 @@ type Props = { const ImportSuccessIcon: React.FC = ({ onEditModelClick }) => { const [isHovered, setIsHovered] = useState(false) + console.log(isHovered, onEditModelClick) + const onMouseOver = () => { setIsHovered(true) } @@ -19,34 +21,34 @@ const ImportSuccessIcon: React.FC = ({ onEditModelClick }) => { return (
- {isHovered ? ( + {/* {isHovered ? ( - ) : ( - - )} + ) : ( */} + + {/* )} */}
) } const SuccessIcon = React.memo(() => ( -
+
)) -const EditIcon: React.FC = React.memo(({ onEditModelClick }) => { - const onClick = useCallback(() => { - onEditModelClick() - }, [onEditModelClick]) - - return ( -
- -
- ) -}) +// const EditIcon: React.FC = React.memo(({ onEditModelClick }) => { +// const onClick = useCallback(() => { +// onEditModelClick() +// }, [onEditModelClick]) + +// return ( +//
+// +//
+// ) +// }) export default ImportSuccessIcon diff --git a/web/screens/Settings/ImportingModelModal/ImportingModelItem.tsx b/web/screens/Settings/ImportingModelModal/ImportingModelItem.tsx index c7f6c35f0d..4ac2a4debb 100644 --- a/web/screens/Settings/ImportingModelModal/ImportingModelItem.tsx +++ b/web/screens/Settings/ImportingModelModal/ImportingModelItem.tsx @@ -1,15 +1,11 @@ import { useCallback, useMemo } from 'react' import { ImportingModel } from '@janhq/core' -import { useSetAtom } from 'jotai' import { AlertCircle } from 'lucide-react' -import { setImportModelStageAtom } from '@/hooks/useImportModel' - import { toGibibytes } from '@/utils/converter' -import { editingModelIdAtom } from '../EditModelInfoModal' import ImportInProgressIcon from '../ImportInProgressIcon' import ImportSuccessIcon from '../ImportSuccessIcon' @@ -18,16 +14,13 @@ type Props = { } const ImportingModelItem = ({ model }: Props) => { - const setImportModelStage = useSetAtom(setImportModelStageAtom) - const setEditingModelId = useSetAtom(editingModelIdAtom) - const onEditModelInfoClick = useCallback(() => { - setEditingModelId(model.importId) - setImportModelStage('EDIT_MODEL_INFO') - }, [setImportModelStage, setEditingModelId, model.importId]) + // setEditingModelId(model.importId) + // setImportModelStage('EDIT_MODEL_INFO') + }, []) const onDeleteModelClick = useCallback(() => {}, []) - + console.log('namh model', model) const displayStatus = useMemo(() => { if (model.status === 'FAILED') { return 'Failed' diff --git a/web/screens/Settings/ImportingModelModal/index.tsx b/web/screens/Settings/ImportingModelModal/index.tsx index 355ff49dca..30f7d7a997 100644 --- a/web/screens/Settings/ImportingModelModal/index.tsx +++ b/web/screens/Settings/ImportingModelModal/index.tsx @@ -1,50 +1,47 @@ -import { useCallback, useEffect, useState } from 'react' +import { useEffect } from 'react' -import { Button, Modal } from '@janhq/joi' +import { Modal } from '@janhq/joi' import { useAtomValue, useSetAtom } from 'jotai' import { AlertCircle } from 'lucide-react' +import useCortex from '@/hooks/useCortex' import { getImportModelStageAtom, setImportModelStageAtom, } from '@/hooks/useImportModel' -import { openFileTitle } from '@/utils/titleUtils' - import ImportingModelItem from './ImportingModelItem' -import { janDataFolderPathAtom } from '@/helpers/atoms/AppConfig.atom' import { importingModelsAtom } from '@/helpers/atoms/Model.atom' const ImportingModelModal = () => { + const { downloadModel } = useCortex() + const setImportModelStage = useSetAtom(setImportModelStageAtom) const importingModels = useAtomValue(importingModelsAtom) const importModelStage = useAtomValue(getImportModelStageAtom) - const setImportModelStage = useSetAtom(setImportModelStageAtom) - const janDataFolder = useAtomValue(janDataFolderPathAtom) - - const [modelFolder, setModelFolder] = useState('') - - useEffect(() => { - const getModelPath = async () => { - // const modelPath = await joinPath([janDataFolder, 'models']) - setModelFolder('') - } - getModelPath() - }, [janDataFolder]) const finishedImportModel = importingModels.filter( (model) => model.status === 'IMPORTED' ).length - const onOpenModelFolderClick = useCallback( - () => { - // openFileExplorer(modelFolder) - }, - [ - /*modelFolder*/ - ] - ) + useEffect(() => { + const importModels = async () => { + for (const model of importingModels) { + await downloadModel(model.path) + // const parsedResult = await result?.json() + // if ( + // parsedResult['message'] && + // parsedResult['message'] === 'Download model started successfully.' + // ) { + // // update importingModels + // } + // console.log(`NamH result ${JSON.stringify(parsedResult)}`) + } + } + importModels() + // eslint-disable-next-line react-hooks/exhaustive-deps + }, [downloadModel]) return ( { content={
-
-
+
{importingModels.map((model) => ( ))} diff --git a/web/screens/Settings/MyModels/ModelItem/index.tsx b/web/screens/Settings/MyModels/ModelItem/index.tsx index e0229480c4..4b53f7c552 100644 --- a/web/screens/Settings/MyModels/ModelItem/index.tsx +++ b/web/screens/Settings/MyModels/ModelItem/index.tsx @@ -3,7 +3,7 @@ import { memo, useCallback, useMemo, useState } from 'react' import { LocalEngines, Model } from '@janhq/core' import { Badge, Button, useClickOutside } from '@janhq/joi' -import { useAtomValue } from 'jotai' +import { useAtomValue, useSetAtom } from 'jotai' import { MoreVerticalIcon, PlayIcon, @@ -13,22 +13,32 @@ import { import { twMerge } from 'tailwind-merge' import useModelStart from '@/hooks/useModelStart' +import useModelStop from '@/hooks/useModelStop' import useModels from '@/hooks/useModels' +import { showWarningMultipleModelModalAtom } from '@/screens/HubScreen2/components/WarningMultipleModelModal' + import { activeModelsAtom } from '@/helpers/atoms/Model.atom' type Props = { model: Model } +// If more than this number of models are running, show a warning modal. +export const concurrentModelWarningThreshold = 2 + const ModelItem: React.FC = ({ model }) => { const activeModels = useAtomValue(activeModelsAtom) const startModel = useModelStart() + const stopModel = useModelStop() const [more, setMore] = useState(false) - const { stopModel, deleteModel } = useModels() + const { deleteModel } = useModels() const [menu, setMenu] = useState(null) const [toggle, setToggle] = useState(null) + const setShowWarningMultipleModelModal = useSetAtom( + showWarningMultipleModelModalAtom + ) useClickOutside(() => setMore(false), null, [menu, toggle]) const isActive = useMemo( @@ -39,17 +49,30 @@ const ModelItem: React.FC = ({ model }) => { const onModelActionClick = useCallback( (modelId: string) => { if (isActive) { - stopModel(modelId) - } else { - startModel.mutate(modelId) + // if model already active, stop it + stopModel.mutate(modelId) + return + } + + if (activeModels.length >= concurrentModelWarningThreshold) { + // if max concurrent models reached, stop the first model + // display popup + setShowWarningMultipleModelModal(true) } + startModel.mutate(modelId) }, - [isActive, startModel, stopModel] + [ + isActive, + startModel, + stopModel, + activeModels.length, + setShowWarningMultipleModelModal, + ] ) const onDeleteModelClicked = useCallback( async (modelId: string) => { - await stopModel(modelId) + await stopModel.mutateAsync(modelId) await deleteModel(modelId) }, [stopModel, deleteModel] diff --git a/web/screens/Settings/MyModels/index.tsx b/web/screens/Settings/MyModels/index.tsx index ad3820d3f1..24d657b91d 100644 --- a/web/screens/Settings/MyModels/index.tsx +++ b/web/screens/Settings/MyModels/index.tsx @@ -6,7 +6,7 @@ import { LlmEngines } from '@janhq/core' import { Button, ScrollArea } from '@janhq/joi' import { useAtomValue, useSetAtom } from 'jotai' -import { UploadCloudIcon } from 'lucide-react' +import { ImportIcon, UploadCloudIcon } from 'lucide-react' import { twMerge } from 'tailwind-merge' @@ -16,6 +16,8 @@ import ModelSearch from '@/containers/ModelSearch' import useDropModelBinaries from '@/hooks/useDropModelBinaries' +import { setImportModelStageAtom } from '@/hooks/useImportModel' + import ModelItem from './ModelItem' import { MainViewState, mainViewStateAtom } from '@/helpers/atoms/App.atom' @@ -26,6 +28,11 @@ const MyModels = () => { const downloadedModels = useAtomValue(downloadedModelsAtom) const { onDropModels } = useDropModelBinaries() const [searchText, setSearchText] = useState('') + const setImportModelStage = useSetAtom(setImportModelStageAtom) + + const onImportModelClick = useCallback(() => { + setImportModelStage('SELECTING_MODEL') + }, [setImportModelStage]) const filteredDownloadedModels = useMemo( () => @@ -75,14 +82,14 @@ const MyModels = () => {
- {/* */} +
{!filteredDownloadedModels.length ? ( diff --git a/web/screens/Settings/SelectingModelModal/index.tsx b/web/screens/Settings/SelectingModelModal/index.tsx index c0900827ba..b84f72e3ea 100644 --- a/web/screens/Settings/SelectingModelModal/index.tsx +++ b/web/screens/Settings/SelectingModelModal/index.tsx @@ -1,39 +1,109 @@ import { useCallback } from 'react' import { useDropzone } from 'react-dropzone' +import { ImportingModel, SelectFileOption } from '@janhq/core' import { Modal } from '@janhq/joi' import { useAtomValue, useSetAtom } from 'jotai' import { UploadCloudIcon } from 'lucide-react' +import { snackbar } from '@/containers/Toast' + import useDropModelBinaries from '@/hooks/useDropModelBinaries' import { getImportModelStageAtom, setImportModelStageAtom, } from '@/hooks/useImportModel' -const SelectingModelModal = () => { +import { importingModelsAtom } from '@/helpers/atoms/Model.atom' + +const SelectingModelModal: React.FC = () => { const setImportModelStage = useSetAtom(setImportModelStageAtom) + const setImportingModels = useSetAtom(importingModelsAtom) const importModelStage = useAtomValue(getImportModelStageAtom) const { onDropModels } = useDropModelBinaries() - // const { sanitizeFilePaths } = useImportModel() + + const onImportFileWindowsClick = useCallback(async () => { + const options: SelectFileOption = { + title: 'Select model files', + buttonLabel: 'Select', + allowMultiple: true, + filters: [ + { name: 'GGUF Files', extensions: ['gguf'] }, + { name: 'All Files', extensions: ['*'] }, + ], + } + const filePaths: string[] = await window.core?.api?.selectFiles(options) + if (!filePaths || filePaths.length === 0) return + + const importingModels: ImportingModel[] = filePaths + .filter((path) => path.endsWith('.gguf')) + .map((path) => { + const normalizedPath = isWindows ? path.replace(/\\/g, '/') : path + + return { + importId: normalizedPath, + modelId: undefined, + name: normalizedPath.replace('.gguf', ''), + description: '', + path: path, + tags: [], + size: 0, + status: 'PREPARING', + format: 'gguf', + } + }) + if (importingModels.length < 1) { + snackbar({ + description: `Only files with .gguf extension can be imported.`, + type: 'error', + }) + return + } + setImportingModels(importingModels) + setImportModelStage('MODEL_SELECTED') + }, [setImportingModels, setImportModelStage]) const onSelectFileClick = useCallback(async () => { - // const platform = (await systemInformation()).osInfo?.platform - // if (platform === 'win32') { - // setImportModelStage('CHOOSE_WHAT_TO_IMPORT') - // return - // } - // const options: SelectFileOption = { - // title: 'Select model folders', - // buttonLabel: 'Select', - // allowMultiple: true, - // selectDirectory: true, - // } - // const filePaths = await window.core?.api?.selectFiles(options) - // if (!filePaths || filePaths.length === 0) return - // sanitizeFilePaths(filePaths) - }, []) + if (isWindows) { + return onImportFileWindowsClick() + } + const options: SelectFileOption = { + title: 'Select model folders', + buttonLabel: 'Select', + allowMultiple: true, + selectDirectory: true, + } + const filePaths: string[] = await window.core?.api?.selectFiles(options) + if (!filePaths || filePaths.length === 0) return + + const importingModels: ImportingModel[] = filePaths + .filter((path) => path.endsWith('.gguf')) + .map((path) => { + const normalizedPath = isWindows ? path.replace(/\\/g, '/') : path + + return { + importId: normalizedPath, + modelId: undefined, + name: normalizedPath.replace('.gguf', ''), + description: '', + path: path, + tags: [], + size: 0, + status: 'PREPARING', + format: 'gguf', + } + }) + if (importingModels.length < 1) { + snackbar({ + description: `Only files with .gguf extension can be imported.`, + type: 'error', + }) + return + } + setImportingModels(importingModels) + setImportModelStage('MODEL_SELECTED') + }, [setImportModelStage, setImportingModels, onImportFileWindowsClick]) const { isDragActive, getRootProps } = useDropzone({ noClick: true, @@ -52,9 +122,7 @@ const SelectingModelModal = () => { return ( { - setImportModelStage('NONE') - }} + onOpenChange={() => setImportModelStage('NONE')} title="Import Model" content={
diff --git a/web/screens/Thread/ThreadCenterPanel/AssistantSetting/index.tsx b/web/screens/Thread/ThreadCenterPanel/AssistantSetting/index.tsx index 35de825bec..233308f51b 100644 --- a/web/screens/Thread/ThreadCenterPanel/AssistantSetting/index.tsx +++ b/web/screens/Thread/ThreadCenterPanel/AssistantSetting/index.tsx @@ -3,7 +3,7 @@ import { useCallback } from 'react' import { SettingComponentProps } from '@janhq/core' import { useAtomValue } from 'jotai' -import useModels from '@/hooks/useModels' +import useModelStop from '@/hooks/useModelStop' import SettingComponentBuilder from '../../../../containers/ModelSetting/SettingComponent' @@ -17,7 +17,7 @@ type Props = { const AssistantSetting: React.FC = ({ componentData }) => { const activeThread = useAtomValue(activeThreadAtom) const activeModels = useAtomValue(activeModelsAtom) - const { stopModel } = useModels() + const stopModel = useModelStop() const onValueChanged = useCallback( (key: string, value: string | number | boolean) => { @@ -29,7 +29,7 @@ const AssistantSetting: React.FC = ({ componentData }) => { const model = activeModels.find( (model) => activeThread.assistants[0]?.model === model.model ) - if (model) stopModel(model.model) + if (model) stopModel.mutate(model.model) } // if ( diff --git a/web/screens/Thread/ThreadCenterPanel/ChatBody/index.tsx b/web/screens/Thread/ThreadCenterPanel/ChatBody/index.tsx index ea480c2c90..7ae4b1456c 100644 --- a/web/screens/Thread/ThreadCenterPanel/ChatBody/index.tsx +++ b/web/screens/Thread/ThreadCenterPanel/ChatBody/index.tsx @@ -8,7 +8,11 @@ import EmptyThread from './EmptyThread' import { getCurrentChatMessagesAtom } from '@/helpers/atoms/ChatMessage.atom' -const ChatBody: React.FC = () => { +type Props = { + onResendMessage: () => void +} + +const ChatBody: React.FC = ({ onResendMessage }) => { const messages = useAtomValue(getCurrentChatMessagesAtom) if (!messages.length) return @@ -22,6 +26,7 @@ const ChatBody: React.FC = () => { key={message.id} msg={message} isLatestMessage={isLatestMessage} + onResendMessage={onResendMessage} /> ) })} diff --git a/web/screens/Thread/ThreadCenterPanel/ChatInput/ChatActionButton/SendMessageButton.tsx b/web/screens/Thread/ThreadCenterPanel/ChatInput/ChatActionButton/SendMessageButton.tsx new file mode 100644 index 0000000000..0469338c30 --- /dev/null +++ b/web/screens/Thread/ThreadCenterPanel/ChatInput/ChatActionButton/SendMessageButton.tsx @@ -0,0 +1,47 @@ +import { useMemo } from 'react' + +import React from 'react' + +import { Button } from '@janhq/joi' +import { useAtomValue } from 'jotai' + +import { currentPromptAtom } from '@/containers/Providers/Jotai' + +type Props = { + onSendMessageClick: (message: string) => void +} + +const SendMessageButton: React.FC = ({ onSendMessageClick }) => { + const currentPrompt = useAtomValue(currentPromptAtom) + + const showSendButton = useMemo(() => { + if (currentPrompt.trim().length === 0) return false + return true + }, [currentPrompt]) + + if (!showSendButton) return null + + return ( + + ) +} + +export default React.memo(SendMessageButton) diff --git a/web/screens/Thread/ThreadCenterPanel/ChatInput/ChatActionButton/StopInferenceButton.tsx b/web/screens/Thread/ThreadCenterPanel/ChatInput/ChatActionButton/StopInferenceButton.tsx new file mode 100644 index 0000000000..a8cf9273c0 --- /dev/null +++ b/web/screens/Thread/ThreadCenterPanel/ChatInput/ChatActionButton/StopInferenceButton.tsx @@ -0,0 +1,21 @@ +import React from 'react' + +import { Button } from '@janhq/joi' + +import { StopCircle } from 'lucide-react' + +type Props = { + onStopInferenceClick: () => void +} + +const StopInferenceButton: React.FC = ({ onStopInferenceClick }) => ( + +) + +export default React.memo(StopInferenceButton) diff --git a/web/screens/Thread/ThreadCenterPanel/ChatInput/ChatActionButton/index.tsx b/web/screens/Thread/ThreadCenterPanel/ChatInput/ChatActionButton/index.tsx new file mode 100644 index 0000000000..a55d691411 --- /dev/null +++ b/web/screens/Thread/ThreadCenterPanel/ChatInput/ChatActionButton/index.tsx @@ -0,0 +1,40 @@ +import { useMemo } from 'react' + +import { useAtomValue } from 'jotai' + +import SendMessageButton from './SendMessageButton' +import StopInferenceButton from './StopInferenceButton' + +import { getCurrentChatMessagesAtom } from '@/helpers/atoms/ChatMessage.atom' + +import { isGeneratingResponseAtom } from '@/helpers/atoms/Thread.atom' + +type Props = { + onStopInferenceClick: () => void + onSendMessageClick: (message: string) => void +} + +const ChatActionButton: React.FC = ({ + onStopInferenceClick, + onSendMessageClick, +}) => { + const messages = useAtomValue(getCurrentChatMessagesAtom) + const isGeneratingResponse = useAtomValue(isGeneratingResponseAtom) + + const showStopButton = useMemo(() => { + if (isGeneratingResponse) return true + + const lastMessage = messages[messages.length - 1] + if (!lastMessage) return false + if (lastMessage.status === 'in_progress') return true + return false + }, [isGeneratingResponse, messages]) + + if (showStopButton) { + return + } + + return +} + +export default ChatActionButton diff --git a/web/screens/Thread/ThreadCenterPanel/ChatInput/ChatTextInput/index.tsx b/web/screens/Thread/ThreadCenterPanel/ChatInput/ChatTextInput/index.tsx new file mode 100644 index 0000000000..e9d7a93410 --- /dev/null +++ b/web/screens/Thread/ThreadCenterPanel/ChatInput/ChatTextInput/index.tsx @@ -0,0 +1,95 @@ +import { useCallback, useEffect, useMemo, useRef } from 'react' + +import { TextArea } from '@janhq/joi' +import { useAtom, useAtomValue } from 'jotai' + +import { twMerge } from 'tailwind-merge' + +import { currentPromptAtom } from '@/containers/Providers/Jotai' + +import { getCurrentChatMessagesAtom } from '@/helpers/atoms/ChatMessage.atom' + +import { spellCheckAtom } from '@/helpers/atoms/Setting.atom' +import { + getActiveThreadIdAtom, + isGeneratingResponseAtom, +} from '@/helpers/atoms/Thread.atom' + +type Props = { + isSettingActive: boolean + onSendMessageClick: (message: string) => void +} + +const ChatTextInput: React.FC = ({ + isSettingActive, + onSendMessageClick, +}) => { + const messages = useAtomValue(getCurrentChatMessagesAtom) + const [currentPrompt, setCurrentPrompt] = useAtom(currentPromptAtom) + const textareaRef = useRef(null) + const activeThreadId = useAtomValue(getActiveThreadIdAtom) + const spellCheck = useAtomValue(spellCheckAtom) + + const isGeneratingResponse = useAtomValue(isGeneratingResponseAtom) + + const disabled = useMemo(() => { + return !activeThreadId + }, [activeThreadId]) + + const onChange = useCallback( + (e: React.ChangeEvent) => { + setCurrentPrompt(e.target.value) + }, + [setCurrentPrompt] + ) + + useEffect(() => { + if (textareaRef.current) { + textareaRef.current.focus() + } + }, [activeThreadId]) + + useEffect(() => { + if (textareaRef.current?.clientHeight) { + textareaRef.current.style.height = isSettingActive ? '100px' : '40px' + textareaRef.current.style.height = textareaRef.current.scrollHeight + 'px' + textareaRef.current.style.overflow = + textareaRef.current.clientHeight >= 390 ? 'auto' : 'hidden' + } + }, [textareaRef.current?.clientHeight, currentPrompt, isSettingActive]) + + const onKeyDown = useCallback( + (e: React.KeyboardEvent) => { + if (e.key === 'Enter' && !e.shiftKey && !e.nativeEvent.isComposing) { + e.preventDefault() + if (isGeneratingResponse) return + const lastMessage = messages[messages.length - 1] + if (!lastMessage || lastMessage.status !== 'in_progress') { + onSendMessageClick(currentPrompt) + return + } + } + }, + [messages, isGeneratingResponse, currentPrompt, onSendMessageClick] + ) + + return ( +