From 1fdeaeacb8af9af1c363f286ffbb043ae4e6f08f Mon Sep 17 00:00:00 2001 From: Paul Robello Date: Tue, 16 Jul 2024 16:23:27 -0700 Subject: [PATCH] Fixes, Enhancements and Refinement (#14) * added ps support for remote connections * chat tab status bar with current and max context length for current model * chat stop / abort button * added double click to session list to load item * better chat param entry flow * cleanup readme, update TOC * numerous bug fixes --- LICENSE | 2 +- README.md | 98 +++++++---- parllama/__init__.py | 2 +- parllama/app.py | 177 ++++++++++--------- parllama/data_manager.py | 59 +++++-- parllama/dialogs/model_details_dialog.py | 14 ++ parllama/help.md | 10 +- parllama/messages/main.py | 38 ++++- parllama/models/chat.py | 119 ++++++++++--- parllama/models/ollama_data.py | 60 ++++++- parllama/models/ollama_ps.py | 36 ++++ parllama/models/settings_data.py | 14 +- parllama/screens/main_screen.tcss | 7 - parllama/widgets/chat_message_list.py | 30 ++++ parllama/widgets/input_blur_submit.py | 17 ++ parllama/widgets/session_list.py | 11 ++ parllama/widgets/session_list_item.py | 4 +- parllama/widgets/views/chat_tab.py | 190 ++++++++++++++------- parllama/widgets/views/chat_view.py | 89 ++++++++-- parllama/widgets/views/local_model_view.py | 14 ++ parllama/widgets/views/site_model_view.py | 12 +- setup.cfg | 2 +- 22 files changed, 751 insertions(+), 254 deletions(-) create mode 100644 parllama/models/ollama_ps.py create mode 100644 parllama/widgets/chat_message_list.py create mode 100644 parllama/widgets/input_blur_submit.py diff --git a/LICENSE b/LICENSE index 3a43997..4e5db6f 100644 --- a/LICENSE +++ b/LICENSE @@ -1,6 +1,6 @@ MIT License -Copyright (c) 2021 Will McGugan +Copyright (c) 2021 Paul Robello Permission is hereby granted, free of charge, to any person obtaining a copy of this software and associated documentation files (the "Software"), to deal diff --git a/README.md b/README.md index ba12f51..413d263 100644 --- a/README.md +++ b/README.md 
@@ -2,32 +2,38 @@ ## Table of Contents -- [About](#about) -- [Screenshots](#screenshots) -- [Prerequisites](#prerequisites-for-running) - - [For Running](#prerequisites-for-running) - - [For Development](#prerequisites-for-dev) - - [For Model Quantization](#prerequisites-for-model-quantization) -- [Installation](#installing-from-mypi-using-pipx) - - [Using pipx](#installing-from-mypi-using-pipx) - - [Using pip](#installing-from-mypi-using-pip) - - [For Development](#installing-for-dev-mode) -- [Command Line Arguments](#command-line-arguments) -- [Environment Variables](#environment-variables) -- [Running PAR_LLAMA](#running-par_llama) - - [With pipx installation](#with-pipx-installation) - - [With pip installation](#with-pip-installation) - - [Under Windows WSL](#running-under-windows-wsl) - - [In Development Mode](#dev-mode) -- [Example Workflow](#example-workflow) -- [Themes](#themes) -- [Contributing](#contributing) -- [Roadmap](#roadmap) -- [What's New](#whats-new) +1. [About](#about) + 1. [Screenshots](#screenshots) +2. [Prerequisites for running](#prerequisites-for-running) +3. [Prerequisites for dev](#prerequisites-for-dev) +4. [Prerequisites for huggingface model quantization](#prerequisites-for-huggingface-model-quantization) +5. [Installing from mypi using pipx](#installing-from-mypi-using-pipx) +6. [Installing from mypi using pip](#installing-from-mypi-using-pip) +7. [Installing for dev mode](#installing-for-dev-mode) +8. [Command line arguments](#command-line-arguments) +9. [Environment Variables](#environment-variables) +10. [Running PAR_LLAMA](#running-par_llama) + 1. [with pipx installation](#with-pipx-installation) + 2. [with pip installation](#with-pip-installation) +11. [Running against a remote instance](#running-against-a-remote-instance) +12. [Running under Windows WSL](#running-under-windows-wsl) + 1. [Dev mode](#dev-mode) +13. [Example workflow](#example-workflow) +14. [Themes](#themes) +15. [Contributing](#contributing) +16. 
[Roadmap](#roadmap) + 1. [Where we are](#where-we-are) + 2. [Where we're going](#where-were-going) +17. [What's new](#whats-new) + 1. [v0.3.1](#v031) + 2. [v0.3.0](#v030) + 3. [v0.2.51](#v0251) + 4. [v0.2.5](#v025) ## About PAR LLAMA is a TUI application designed for easy management and use of Ollama based LLMs. The application was built with [Textual](https://textual.textualize.io/) and [Rich](https://github.com/Textualize/rich?tab=readme-ov-file) +and runs on all major OS's including but not limited to Windows, Windows WSL, Mac, and Linux. ### Screenshots Supports Dark and Light mode as well as custom themes. @@ -45,14 +51,18 @@ Supports Dark and Light mode as well as custom themes. ## Prerequisites for running * Install and run [Ollama](https://ollama.com/download) * Install Python 3.11 or newer - * On Windows the [Scoop](https://scoop.sh/) tool makes it easy to install and manage things like python. + * [https://www.python.org/downloads/](https://www.python.org/downloads/) has installers for all versions of Python for all os's + * On Windows the [Scoop](https://scoop.sh/) tool makes it easy to install and manage things like python + * Install Scoop then do `scoop install python` ## Prerequisites for dev * Install pipenv + * if you have pip you can install it globally using `pip install pipenv` * Install GNU Compatible Make command + * On windows if you have scoop installed you can install make with `scoop install make` -## Prerequisites for model quantization -If you want to be able to quantize custom models, download the following tool from the releases area: +## Prerequisites for huggingface model quantization +If you want to be able to quantize custom models from huggingface, download the following tool from the releases area: [HuggingFaceModelDownloader](https://github.com/bodaay/HuggingFaceModelDownloader) Install [Docker Desktop](https://www.docker.com/products/docker-desktop/) @@ -72,6 +82,11 @@ Once pipx is installed, run the following: ```bash pipx 
install parllama ``` +To upgrade an existing installation use the --force flag: +```bash +pipx install parllama --force +``` + ## Installing from mypi using pip Create a virtual environment and install using pip @@ -141,6 +156,10 @@ From parent folder of venv source venv/Scripts/activate parllama ``` +## Running against a remote instance +```bash +parllama -u "http://REMOTE_HOST:11434" +``` ## Running under Windows WSL Ollama by default only listens to localhost for connections, so you must set the environment variable OLLAMA_HOST=0.0.0.0:11434 @@ -160,7 +179,7 @@ parllama -u "http://$(hostname).local:11434" ``` Depending on your DNS setup if the above does not work, try this: ```bash - parllama -u "http://$(grep -m 1 nameserver /etc/resolv.conf | awk '{print $2}'):11434" +parllama -u "http://$(grep -m 1 nameserver /etc/resolv.conf | awk '{print $2}'):11434" ``` PAR_LLAMA will remember the -u flag so subsequent runs will not require that you specify it. @@ -175,15 +194,19 @@ make dev * Start parllama. * Click the "Site" tab. * Use ^R to fetch the latest models from Ollama.com. -* User the "Filter Site models" text box and type "llama3". +* Use the "Filter Site models" text box and type "llama3". * Find the entry with title of "llama3". * Click the blue tag "8B" to update the search box to read "llama3:8b". * Press ^P to pull the model from Ollama to your local machine. Depending on the size of the model and your internet connection this can take a few min. -* Click the "Local" tab to see models that have been locally downloaded -* Select the "llama3:8b" entry and press ^C to jump to the "Chat" tab and auto select the model +* Click the "Local" tab to see models that have been locally downloaded. +* Select the "llama3:8b" entry and press ^C to jump to the "Chat" tab and auto select the model. * Type a message to the model such as "Why is the sky blue?". It will take a few seconds for Ollama to load the model. After which the LLMs answer will stream in. 
* Towards the very top of the app you will see what model is loaded and what percent of it is loaded into the GPU / CPU. If a model cant be loaded 100% on the GPU it will run slower. * To export your conversation as a Markdown file type "/session.export" in the message input box. This will open a export dialog. +* Press ^N to add a new chat tab. +* Select a different model or change the temperature and ask the same questions. +* Jump between the tabs to compare responses by clicking the tabs or using slash commands `/tab.1` and `/tab.2` +* Press ^S to see all your past and current sessions. You can recall any past session by selecting it and pressing Enter or ^N if you want to load it into a new tab. * Type "/help" or "/?" to see what other slash commands are available. ## Themes @@ -254,18 +277,27 @@ if anything remains to be fixed before the commit is allowed. ## Roadmap -**Where we are** +### Where we are * Initial release - Find, maintain and create new models -* Basic chat with LLM -* Chat history / conversation management +* Connect to remote instances +* Chat with history / conversation management * Chat tabs allow chat with multiple models at same time -**Where we're going** +### Where we're going * Chat using embeddings for local documents * LLM tool use +* Ability to use other AI providers like OpenAI ## What's new +### v0.3.2 +* Ollama ps stats bar now works with remote connections except for CPU / GPU %'s which ollama's api does not provide +* Chat tabs now have a session info bar with info like current / max context length +* Added conversation stop button to abort llm response +* Added ability to delete messages from session +* More model details displayed on model detail screen +* Better performance when changing session params on chat tab + ### v0.3.1 * Add chat tabs to support multiple sessions * Added cli option to prevent saving chat history to disk diff --git a/parllama/__init__.py b/parllama/__init__.py index c3124a5..1eb8e2b 100644 ---
a/parllama/__init__.py +++ b/parllama/__init__.py @@ -6,7 +6,7 @@ __credits__ = ["Paul Robello"] __maintainer__ = "Paul Robello" __email__ = "probello@gmail.com" -__version__ = "0.3.1" +__version__ = "0.3.2" __licence__ = "MIT" __application_title__ = "PAR LLAMA" __application_binary__ = "parllama" diff --git a/parllama/app.py b/parllama/app.py index 2b4f9a0..d36e86d 100644 --- a/parllama/app.py +++ b/parllama/app.py @@ -7,10 +7,13 @@ from queue import Queue from typing import Any +import humanize import ollama import pyperclip # type: ignore from rich.columns import Columns +from rich.console import ConsoleRenderable from rich.console import RenderableType +from rich.console import RichCast from rich.progress_bar import ProgressBar from rich.style import Style from rich.text import Text @@ -20,6 +23,7 @@ from textual.binding import Binding from textual.color import Color from textual.message import Message +from textual.message_pump import MessagePump from textual.notifications import SeverityLevel from textual.timer import Timer from textual.widget import Widget @@ -32,7 +36,6 @@ from parllama.chat_manager import ChatManager from parllama.data_manager import dm from parllama.dialogs.help_dialog import HelpDialog -from parllama.messages.main import AppRequest from parllama.messages.main import ChangeTab from parllama.messages.main import CreateModelFromExistingRequested from parllama.messages.main import DeleteSession @@ -52,6 +55,7 @@ from parllama.messages.main import NotifyErrorMessage from parllama.messages.main import NotifyInfoMessage from parllama.messages.main import PsMessage +from parllama.messages.main import RegisterForUpdates from parllama.messages.main import SendToClipboard from parllama.messages.main import SessionListChanged from parllama.messages.main import SessionSelected @@ -99,7 +103,7 @@ class ParLlamaApp(App[None]): # DEFAULT_CSS = """ # """ - notify_subs: dict[str, set[Widget]] + notify_subs: dict[str, set[MessagePump]] main_screen: 
MainScreen job_queue: Queue[QueueJob] is_busy: bool = False @@ -111,7 +115,7 @@ class ParLlamaApp(App[None]): def __init__(self) -> None: """Initialize the application.""" super().__init__() - self.notify_subs = {"*": set[Widget]()} + self.notify_subs = {"*": set[MessagePump]()} chat_manager.set_app(self) self.job_timer = None @@ -155,42 +159,48 @@ def notify_error(self, event: NotifyErrorMessage) -> None: async def on_mount(self) -> None: """Display the main or locked screen.""" await self.push_screen(self.main_screen) - self.main_screen.post_message( - StatusMessage(f"Data folder: {settings.data_dir}") - ) - self.main_screen.post_message( - StatusMessage(f"Chat folder: {settings.chat_dir}") - ) - self.main_screen.post_message( + self.post_message_all(StatusMessage(f"Data folder: {settings.data_dir}")) + self.post_message_all(StatusMessage(f"Chat folder: {settings.chat_dir}")) + self.post_message_all( StatusMessage(f"Using Ollama server url: {settings.ollama_host}") ) if settings.ollama_ps_poll_interval: - self.main_screen.post_message( + self.post_message_all( StatusMessage( f"Polling Ollama ps every: {settings.ollama_ps_poll_interval} seconds" ) ) else: - self.main_screen.post_message(StatusMessage("Polling Ollama ps disabled")) + self.post_message_all(StatusMessage("Polling Ollama ps disabled")) - self.main_screen.post_message( + self.post_message_all( StatusMessage( f"""Theme: "{settings.theme_name}" in {settings.theme_mode} mode""" ) ) - self.main_screen.post_message( - StatusMessage(f"Last screen: {settings.last_screen}") - ) - self.main_screen.post_message( + self.post_message_all(StatusMessage(f"Last screen: {settings.last_screen}")) + self.post_message_all( StatusMessage(f"Last chat model: {settings.last_chat_model}") ) - self.main_screen.post_message( + self.post_message_all( StatusMessage(f"Last model temp: {settings.last_chat_temperature}") ) - self.main_screen.post_message( + self.post_message_all( StatusMessage(f"Last session name: 
{settings.last_chat_session_name}") ) + self.app.post_message( + RegisterForUpdates( + widget=self, + event_names=[ + "ModelPulled", + "ModelPushed", + "ModelCreated", + "LocalModelDeleted", + "LocalModelCopied", + ], + ) + ) self.job_timer = self.set_timer(1, self.do_jobs) if settings.ollama_ps_poll_interval > 0: self.ps_timer = self.set_timer(1, self.update_ps) @@ -384,12 +394,10 @@ async def do_progress(self, job: QueueJob, res: Iterator[dict[str, Any]]) -> str if pb: parts.append(pb) - self.main_screen.post_message( - StatusMessage(Columns(parts), log_it=False) - ) + self.post_message_all(StatusMessage(Columns(parts), log_it=False)) return last_status except ollama.ResponseError as e: - self.main_screen.post_message( + self.post_message_all( StatusMessage(Text.assemble(("error:" + str(e), "red"))) ) raise e @@ -403,7 +411,8 @@ async def do_pull(self, job: PullModelJob) -> None: self.post_message_all( ModelPulled(model_name=job.modelName, success=last_status == "success") ) - except ollama.ResponseError: + except ollama.ResponseError as e: + self.log_it(e) self.post_message_all(ModelPulled(model_name=job.modelName, success=False)) async def do_push(self, job: PushModelJob) -> None: @@ -508,21 +517,29 @@ def action_refresh_models(self) -> None: @on(UnRegisterForUpdates) def on_unregister_for_updates(self, event: UnRegisterForUpdates) -> None: - """Unregister for updates event""" + """Unregister widget from all updates""" if not event.widget: return for _, s in self.notify_subs.items(): s.discard(event.widget) - @on(AppRequest) - def on_app_request(self, event: AppRequest) -> None: - """Add any widget that requests an action to notify_subs""" - if not event.widget: - return - self.notify_subs["*"].add(event.widget) - if event.__class__.__name__ not in self.notify_subs: - self.notify_subs[event.__class__.__name__] = set() - self.notify_subs[event.__class__.__name__].add(event.widget) + @on(RegisterForUpdates) + def on_register_for_updates(self, event: 
RegisterForUpdates) -> None: + """Register for updates event""" + for event_name in event.event_names: + if event_name not in self.notify_subs: + self.notify_subs[event_name] = set() + self.notify_subs[event_name].add(event.widget) + + # @on(AppRequest) + # def on_app_request(self, event: AppRequest) -> None: + # """Add any widget that requests an action to notify_subs""" + # if not event.widget: + # return + # # self.notify_subs["*"].add(event.widget) + # # if event.__class__.__name__ not in self.notify_subs: + # # self.notify_subs[event.__class__.__name__] = set() + # # self.notify_subs[event.__class__.__name__].add(event.widget) @on(LocalModelListRefreshRequested) def on_model_list_refresh_requested(self) -> None: @@ -537,21 +554,18 @@ async def refresh_models(self): """Refresh the models.""" self.is_refreshing = True try: - self.main_screen.post_message( - StatusMessage("Local model list refreshing...") - ) + self.post_message_all(StatusMessage("Local model list refreshing...")) dm.refresh_models() - self.main_screen.post_message(StatusMessage("Local model list refreshed")) - self.main_screen.local_view.post_message(LocalModelListLoaded()) - self.main_screen.chat_view.post_message(LocalModelListLoaded()) + self.post_message_all(StatusMessage("Local model list refreshed")) + self.post_message_all(LocalModelListLoaded()) finally: self.is_refreshing = False - @on(LocalModelListLoaded) - def on_model_data_loaded(self) -> None: - """Refresh model completed""" - self.main_screen.post_message(StatusMessage("Local model list refreshed")) - # self.notify("Local models refreshed.") + # @on(LocalModelListLoaded) + # def on_model_data_loaded(self) -> None: + # """Refresh model completed""" + # self.post_message_all(StatusMessage("Local model list refreshed")) + # # self.notify("Local models refreshed.") @on(SiteModelsRefreshRequested) def on_site_models_refresh_requested(self, msg: SiteModelsRefreshRequested) -> None: @@ -571,7 +585,7 @@ async def refresh_site_models(self, 
msg: SiteModelsRefreshRequested): """Refresh the site model.""" self.is_refreshing = True try: - self.main_screen.post_message( + self.post_message_all( StatusMessage( f"Site models for {msg.ollama_namespace or 'models'} refreshing... force={msg.force}" ) @@ -580,7 +594,7 @@ async def refresh_site_models(self, msg: SiteModelsRefreshRequested): self.main_screen.site_view.post_message( SiteModelsLoaded(ollama_namespace=msg.ollama_namespace) ) - self.main_screen.post_message( + self.post_message_all( StatusMessage( f"Site models for {msg.ollama_namespace or 'models'} loaded. force={msg.force}" ) @@ -592,38 +606,33 @@ async def refresh_site_models(self, msg: SiteModelsRefreshRequested): @work(group="update_ps", thread=True) async def update_ps(self) -> None: """Update ps status bar msg""" - if not dm.ollama_bin: - self.notify( - "Ollama binary not found. PS output not available.", - severity="error", - timeout=6, - ) - return was_blank = False while self.is_running: if settings.ollama_ps_poll_interval < 1: - self.main_screen.post_message(PsMessage(msg="")) + self.post_message_all(PsMessage(msg="")) break await asyncio.sleep(settings.ollama_ps_poll_interval) ret = dm.model_ps() - if not ret: + if len(ret.models) < 1: if not was_blank: - self.main_screen.post_message(PsMessage(msg="")) + self.post_message_all(PsMessage(msg="")) was_blank = True continue was_blank = False - info = ret[0] # only take first one since ps status bar is a single line - self.main_screen.post_message( + info = ret.models[ + 0 + ] # only take first one since ps status bar is a single line + self.post_message_all( PsMessage( msg=Text.assemble( "Name: ", - info["name"], + info.name, " Size: ", - info["size"], + humanize.naturalsize(info.size_vram), " Processor: ", - info["processor"], + ret.processor, " Until: ", - info["until"], + humanize.naturaltime(info.expires_at), ) ) ) @@ -633,16 +642,21 @@ def status_notify(self, msg: str, severity: SeverityLevel = "information") -> No self.notify(msg, 
severity=severity) self.main_screen.post_message(StatusMessage(msg)) - def post_message_all(self, msg: Message, sub_name: str = "*") -> None: + def post_message_all(self, event: Message) -> None: """Post a message to all screens""" - if isinstance(msg, StatusMessage): - self.log(msg.msg) - self.last_status = msg.msg + if isinstance(event, StatusMessage): + if event.log_it: + self.log(event.msg) + self.last_status = event.msg + self.main_screen.post_message(event) + return + if isinstance(event, PsMessage): + self.main_screen.post_message(event) + return + sub_name = event.__class__.__name__ if sub_name in self.notify_subs: for w in list(self.notify_subs[sub_name]): - w.post_message(msg) - if self.main_screen: - self.main_screen.post_message(msg) + w.post_message(event) @on(ChangeTab) def on_change_tab(self, event: ChangeTab) -> None: @@ -655,7 +669,9 @@ def on_create_model_from_existing_requested( self, msg: CreateModelFromExistingRequested ) -> None: """Create model from existing event""" - self.main_screen.create_view.name_input.value = f"my-{msg.model_name}:latest" + self.main_screen.create_view.name_input.value = f"my-{msg.model_name}" + if not self.main_screen.create_view.name_input.value.endswith(":latest"): + self.main_screen.create_view.name_input.value += ":latest" self.main_screen.create_view.text_area.text = msg.model_code self.main_screen.create_view.quantize_input.value = msg.quantization_level or "" self.main_screen.change_tab("Create") @@ -670,20 +686,23 @@ async def on_model_interact_requested(self, event: ModelInteractRequested) -> No self.main_screen.chat_view.user_input.focus() @on(SessionListChanged) - def on_session_list_changed(self) -> None: + def on_session_list_changed(self, event: SessionListChanged) -> None: """Session list changed event""" - self.main_screen.chat_view.session_list.post_message(SessionListChanged()) + event.stop() + self.post_message_all(event) @on(SessionSelected) def on_session_selected(self, event: SessionSelected) -> 
None: """Session selected event""" - self.main_screen.chat_view.post_message( - SessionSelected(session_id=event.session_id, new_tab=event.new_tab) - ) + event.stop() + self.post_message_all(event) @on(DeleteSession) def on_delete_session(self, event: DeleteSession) -> None: """Delete session event""" - self.main_screen.chat_view.post_message( - DeleteSession(session_id=event.session_id) - ) + event.stop() + self.post_message_all(event) + + def log_it(self, msg: ConsoleRenderable | RichCast | str | object) -> None: + """Log a message to the log view""" + self.main_screen.log_view.richlog.write(msg) diff --git a/parllama/data_manager.py b/parllama/data_manager.py index d4e2656..9168361 100644 --- a/parllama/data_manager.py +++ b/parllama/data_manager.py @@ -13,6 +13,7 @@ import docker.errors # type: ignore import docker.types # type: ignore +import httpx import requests import simplejson as json from bs4 import BeautifulSoup @@ -24,6 +25,7 @@ from parllama.models.ollama_data import ModelShowPayload from parllama.models.ollama_data import SiteModel from parllama.models.ollama_data import SiteModelData +from parllama.models.ollama_ps import OllamaPsResponse from parllama.models.settings_data import settings from parllama.utils import output_to_dicts from parllama.utils import run_cmd @@ -31,6 +33,21 @@ from parllama.widgets.site_model_list_item import SiteModelListItem +def api_model_ps() -> OllamaPsResponse: + """Get model ps.""" + # fetch data from self.ollama_host as json + res = httpx.get(f"{settings.ollama_host}/api/ps") + if res.status_code != 200: + return OllamaPsResponse() + try: + ret = OllamaPsResponse(**res.json()) + return ret + except Exception as e: # pylint: disable=broad-exception-caught + print(f"Error: {e}") + print(res.text) + return OllamaPsResponse() + + class DataManager: """Data manager for Par Llama.""" @@ -50,6 +67,20 @@ def __init__(self): self.ollama_bin = str(ollama_bin) if ollama_bin is not None else None + def model_ps(self) -> 
OllamaPsResponse: + """Get model ps.""" + api_ret = api_model_ps() + if not self.ollama_bin: + return api_ret + ret = run_cmd([self.ollama_bin, "ps"]) + + if not ret: + return api_ret + local_ret = output_to_dicts(ret) + if len(local_ret) > 0: + api_ret.processor = local_ret[0]["processor"] + return api_ret + def get_model_by_name(self, name: str) -> FullModel | None: """Get a model by name.""" for model in self.models: @@ -65,13 +96,23 @@ def _get_all_model_data() -> list[LocalModelListItem]: pattern = r"^(# Modelfile .*)\n(# To build.*)\n# (FROM .*\n)\n(FROM .*)\n(.*)$" replacement = r"\3\5" for model in res.models: - res2 = ModelShowPayload(**settings.ollama_client.show(model.name)) - + model_data = settings.ollama_client.show(model.name) + res2 = ModelShowPayload(**model_data) res2.modelfile = re.sub( pattern, replacement, res2.modelfile, flags=re.MULTILINE | re.IGNORECASE ) - res3 = FullModel(**model.model_dump(), **res2.model_dump()) + res3 = FullModel( + **model.model_dump(), + parameters=res2.parameters, + template=res2.template, + modelfile=res2.modelfile, + model_info=res2.model_info, + ) + # print(json.dumps(model.model_dump(), indent=2, default=str)) + # print(json.dumps(res2.model_dump(), indent=2, default=str)) + # print(json.dumps(res3.model_dump(), indent=2, default=str)) all_models.append(LocalModelListItem(res3)) + # break return all_models def refresh_models(self) -> list[LocalModelListItem]: @@ -83,16 +124,6 @@ def get_model_select_options(self) -> list[tuple[str, str]]: """Get select options.""" return [(model.model.name, model.model.name) for model in self.models] - def model_ps(self) -> list[dict[str, Any]]: - """Get model ps.""" - if not self.ollama_bin: - return [] - ret = run_cmd([self.ollama_bin, "ps"]) - - if not ret: - return [] - return output_to_dicts(ret) - @staticmethod def pull_model(model_name: str) -> Iterator[dict[str, Any]]: """Pull a model.""" @@ -143,7 +174,7 @@ def refresh_site_models( settings.ensure_cache_folder() if not 
namespace: - namespace = "models" + namespace = "library" namespace = os.path.basename(namespace) file_name = os.path.join(settings.cache_dir, f"site_models-{namespace}.json") diff --git a/parllama/dialogs/model_details_dialog.py b/parllama/dialogs/model_details_dialog.py index a150b4a..e252015 100644 --- a/parllama/dialogs/model_details_dialog.py +++ b/parllama/dialogs/model_details_dialog.py @@ -11,6 +11,7 @@ from textual.screen import ModalScreen from textual.widgets import Button from textual.widgets import MarkdownViewer +from textual.widgets import Pretty from textual.widgets import Static from textual.widgets import TextArea @@ -26,6 +27,14 @@ class ModelDetailsDialog(ModalScreen[None]): ModelDetailsDialog { background: black 75%; align: center middle; + #model_info { + background: $panel; + height: 10; + width: 1fr; + margin-bottom: 1; + border: solid $primary; + border-title-color: $primary; + } &> VerticalScroll { background: $surface; width: 75%; @@ -93,6 +102,11 @@ def compose(self) -> ComposeResult: ta.read_only = True yield ta + with VerticalScroll(id="model_info") as vs2: + vs2.border_title = "Model Info" + info = self.model.model_info.model_dump(mode="json", exclude_unset=True) + yield Pretty(info) + ta = TextArea( self.model.modelfile, id="modelfile", classes="editor height-10" ) diff --git a/parllama/help.md b/parllama/help.md index c7aff83..6eee65f 100644 --- a/parllama/help.md +++ b/parllama/help.md @@ -91,11 +91,11 @@ Chat with local LLMs ### Chat Screen Session Panel keys -| Key | Command | -|----------|------------------------------------------| -| `enter` | Load selected session into current tab | -| `ctrl+n` | Load selected session into new tab | -| `delete` | Delete selected session and related tabs | +| Key | Command | +|-------------------------|------------------------------------------| +| `enter` or `dbl click` | Load selected session into current tab | +| `ctrl+n` | Load selected session into new tab | +| `delete` | Delete selected 
session and related tabs | ### Chat Slash Commands: diff --git a/parllama/messages/main.py b/parllama/messages/main.py index 68a7298..e5ee4e2 100644 --- a/parllama/messages/main.py +++ b/parllama/messages/main.py @@ -6,27 +6,32 @@ from rich.console import RenderableType from textual.message import Message -from textual.widget import Widget +from textual.message_pump import MessagePump -from ..models.ollama_data import FullModel +from parllama.models.ollama_data import FullModel @dataclass class AppRequest(Message): """Request to app to perform an action.""" - widget: Widget | None + widget: MessagePump | None @dataclass -class RegisterForUpdates(AppRequest): +class RegisterForUpdates(Message): """Register widget for updates.""" + widget: MessagePump + event_names: list[str] + @dataclass -class UnRegisterForUpdates(AppRequest): +class UnRegisterForUpdates(Message): """Unregister widget for updates.""" + widget: MessagePump + @dataclass class LocalModelCopied(Message): @@ -233,7 +238,14 @@ class ChatMessageSent(Message): class NewChatSession(Message): """New chat session class""" - id: str + session_id: str + + +@dataclass +class SessionUpdated(Message): + """Session Was Updated""" + + session_id: str @dataclass @@ -263,6 +275,20 @@ class DeleteSession(Message): session_id: str +@dataclass +class StopChatGeneration(Message): + """Request chat generation to be stopped.""" + + session_id: str + + +@dataclass +class ChatGenerationAborted(Message): + """Chat generation has been aborted.""" + + session_id: str + + @dataclass class UpdateChatControlStates(Message): """Notify that chat control states need to be updated.""" diff --git a/parllama/models/chat.py b/parllama/models/chat.py index b26a73c..c25ceee 100644 --- a/parllama/models/chat.py +++ b/parllama/models/chat.py @@ -16,7 +16,9 @@ from ollama import Options as OllamaOptions from textual.widget import Widget +from parllama.messages.main import ChatGenerationAborted from parllama.messages.main import ChatMessage 
+from parllama.messages.main import SessionUpdated from parllama.models.settings_data import settings @@ -103,6 +105,9 @@ class ChatSession: messages: list[OllamaMessage] id_to_msg: dict[str, OllamaMessage] last_updated: datetime.datetime + _name_generated: bool = False + _abort: bool = False + _generating: bool = False # pylint: disable=too-many-arguments def __init__( @@ -170,32 +175,58 @@ def set_temperature(self, temperature: float | None) -> None: async def send_chat(self, from_user: str, widget: Widget) -> bool: """Send a chat message to LLM""" - msg: OllamaMessage = OllamaMessage(content=from_user, role="user") - self.add_message(msg) - widget.post_message( - ChatMessage(session_id=self.session_id, message_id=msg.message_id) - ) - - msg = OllamaMessage(content="", role="assistant") - self.add_message(msg) - widget.post_message( - ChatMessage(session_id=self.session_id, message_id=msg.message_id) - ) - - stream: Iterator[Mapping[str, Any]] = settings.ollama_client.chat( # type: ignore - model=self.llm_model_name, - messages=[m.to_ollama_native() for m in self.messages], - options=self.options, - stream=True, - ) + self._generating = True + try: + msg: OllamaMessage = OllamaMessage(content=from_user, role="user") + self.add_message(msg) + widget.post_message( + ChatMessage(session_id=self.session_id, message_id=msg.message_id) + ) - for chunk in stream: - msg.content += chunk["message"]["content"] + msg = OllamaMessage(content="", role="assistant") + self.add_message(msg) widget.post_message( ChatMessage(session_id=self.session_id, message_id=msg.message_id) ) - msg.save() - return True + + stream: Iterator[Mapping[str, Any]] = settings.ollama_client.chat( # type: ignore + model=self.llm_model_name, + messages=[m.to_ollama_native() for m in self.messages], + options=self.options, + stream=True, + ) + is_aborted = False + for chunk in stream: + msg.content += chunk["message"]["content"] + if self._abort: + is_aborted = True + try: + msg.content += 
"\n\nAborted..." + widget.post_message(ChatGenerationAborted(self.session_id)) + stream.close() # type: ignore + except Exception: # pylint:disable=broad-except + pass + finally: + self._abort = False + break + widget.post_message( + ChatMessage(session_id=self.session_id, message_id=msg.message_id) + ) + + msg.save() + + if ( + not is_aborted + and settings.auto_name_session + and not self._name_generated + ): + self._name_generated = True + self.set_name(self.gen_session_name(self.messages[0].content)) + widget.post_message(SessionUpdated(session_id=self.session_id)) + finally: + self._generating = False + + return not is_aborted def new_session(self, session_name: str = "My Chat"): """Start new session""" @@ -204,6 +235,28 @@ def new_session(self, session_name: str = "My Chat"): self.messages.clear() self.id_to_msg.clear() + def gen_session_name(self, text: str) -> str: + """Generate a session name from the given text using llm""" + ret = settings.ollama_client.generate( + model=settings.auto_name_session_llm or self.llm_model_name, + options={"temperature": 0.1}, + prompt=text, + system="""You are a helpful assistant. + You will be given some text to summarize. + You must follow these instructions: + * Generate a descriptive session name of no more than a 4 words. + * Only output the session name. + * Do not answer any questions or explain anything. + * Do not output any preamble. + Examples: + "Why is grass green" -> "Green Grass" + "Why is the sky blue?" -> "Blue Sky" + "What is the tallest mountain?" -> "Tallest Mountain" + "What is the meaning of life?" 
-> "Meaning of Life" + """, + ) + return ret["response"].strip() # type: ignore + def __iter__(self): """Iterate over messages""" return iter(self.messages) @@ -334,3 +387,25 @@ def export_as_markdown(self, filename: str) -> bool: return True except (OSError, IOError): return False + + def stop_generation(self) -> None: + """Stop LLM model generation""" + self._abort = True + + @property + def abort_pending(self) -> bool: + """Check if LLM model generation is pending""" + return self._abort + + @property + def is_generating(self) -> bool: + """Check if LLM model generation is in progress""" + return self._generating + + @property + def context_length(self) -> int: + """Return current message context length""" + total: int = 0 + for msg in self.messages: + total += len(msg.content) + return total diff --git a/parllama/models/ollama_data.py b/parllama/models/ollama_data.py index 777d4b8..37ec25a 100644 --- a/parllama/models/ollama_data.py +++ b/parllama/models/ollama_data.py @@ -5,10 +5,13 @@ from datetime import datetime from typing import cast from typing import Literal +from typing import Optional from typing import TypeAlias import ollama from pydantic import BaseModel +from pydantic import ConfigDict +from pydantic import Field MessageRoles: TypeAlias = Literal["user", "assistant", "system"] @@ -43,14 +46,65 @@ class ModelDetails(BaseModel): quantization_level: str +class ModelInfo(BaseModel): + """Ollama Model Info.""" + + general_architecture: Optional[str] = Field(None, alias="general.architecture") + general_file_type: Optional[int] = Field(None, alias="general.file_type") + general_parameter_count: Optional[int] = Field( + None, alias="general.parameter_count" + ) + general_quantization_version: Optional[int] = Field( + None, alias="general.quantization_version" + ) + llama_attention_head_count: Optional[int] = Field( + None, alias="llama.attention.head_count" + ) + llama_attention_head_count_kv: Optional[int] = Field( + None, 
alias="llama.attention.head_count_kv" + ) + llama_attention_layer_norm_rms_epsilon: Optional[float] = Field( + None, alias="llama.attention.layer_norm_rms_epsilon" + ) + llama_block_count: Optional[int] = Field(None, alias="llama.block_count") + llama_context_length: Optional[int] = Field(None, alias="llama.context_length") + llama_embedding_length: Optional[int] = Field(None, alias="llama.embedding_length") + llama_feed_forward_length: Optional[int] = Field( + None, alias="llama.feed_forward_length" + ) + llama_rope_dimension_count: Optional[int] = Field( + None, alias="llama.rope.dimension_count" + ) + llama_rope_freq_base: Optional[int] = Field(None, alias="llama.rope.freq_base") + llama_vocab_size: Optional[int] = Field(None, alias="llama.vocab_size") + tokenizer_ggml_bos_token_id: Optional[int] = Field( + None, alias="tokenizer.ggml.bos_token_id" + ) + tokenizer_ggml_eos_token_id: Optional[int] = Field( + None, alias="tokenizer.ggml.eos_token_id" + ) + tokenizer_ggml_merges: Optional[list[str]] = Field( + None, alias="tokenizer.ggml.merges" + ) + tokenizer_ggml_model: Optional[str] = Field(None, alias="tokenizer.ggml.model") + tokenizer_ggml_pre: Optional[str] = Field(None, alias="tokenizer.ggml.pre") + tokenizer_ggml_token_type: Optional[list[str]] = Field( + None, alias="tokenizer.ggml.token_type" + ) + tokenizer_ggml_tokens: Optional[list[str]] = Field( + None, alias="tokenizer.ggml.tokens" + ) + + class ModelShowPayload(BaseModel): """Ollama Model Show Payload.""" - license: str | None = None + model_config = ConfigDict(protected_namespaces=()) modelfile: str parameters: str | None = None template: str - # details: ModelDetails # omit of being combined with Model + details: ModelDetails # omit if being combined with Model + model_info: ModelInfo class Model(BaseModel): @@ -74,10 +128,12 @@ class ModelListPayload(BaseModel): class FullModel(Model): """Ollama Full Model""" + model_config = ConfigDict(protected_namespaces=()) license: str | None = None 
modelfile: str parameters: str | None = None template: str | None = None + model_info: ModelInfo def get_messages(self) -> list[ollama.Message]: """Get messages from the model.""" diff --git a/parllama/models/ollama_ps.py b/parllama/models/ollama_ps.py new file mode 100644 index 0000000..094e2c5 --- /dev/null +++ b/parllama/models/ollama_ps.py @@ -0,0 +1,36 @@ +"""Ollama PS response model.""" +from __future__ import annotations + +import datetime + +from pydantic import BaseModel + + +class OllamaPsModelDetails(BaseModel): + """Ollama PS Model Details.""" + + parent_model: str + format: str + family: str + families: list[str] + parameter_size: str + quantization_level: str + + +class OllamaPsModel(BaseModel): + """Ollama PS Model.""" + + name: str + model: str + size: int + digest: str + details: OllamaPsModelDetails + expires_at: datetime.datetime + size_vram: int + + +class OllamaPsResponse(BaseModel): + """Ollama PS response model.""" + + models: list[OllamaPsModel] = [] + processor: str = "- / -" diff --git a/parllama/models/settings_data.py b/parllama/models/settings_data.py index 6de1495..13af11f 100644 --- a/parllama/models/settings_data.py +++ b/parllama/models/settings_data.py @@ -23,6 +23,7 @@ class Settings(BaseModel): data_dir: str = os.path.expanduser("~/.parllama") cache_dir: str = "" chat_dir: str = "" + chat_tab_max_length: int = 15 settings_file: str = "settings.json" theme_name: str = "par" starting_screen: ScreenType = "Local" @@ -35,7 +36,8 @@ class Settings(BaseModel): max_log_lines: int = 1000 ollama_host: str = "http://localhost:11434" ollama_ps_poll_interval: int = 3 - auto_name_chat: bool = True + auto_name_session: bool = False + auto_name_session_llm: str = "" # pylint: disable=too-many-branches, too-many-statements def __init__(self) -> None: @@ -141,7 +143,15 @@ def load_from_file(self) -> None: self.ollama_ps_poll_interval = data.get( "ollama_ps_poll_interval", self.ollama_ps_poll_interval ) - self.auto_name_chat = 
data.get("auto_name_chat", self.auto_name_chat) + self.auto_name_session = data.get( + "auto_name_chat", self.auto_name_session + ) + self.auto_name_session_llm = data.get( + "auto_name_session_llm", self.auto_name_session_llm + ) + self.chat_tab_max_length = max( + 8, data.get("chat_tab_max_length", self.chat_tab_max_length) + ) except FileNotFoundError: pass # If file does not exist, continue with default settings diff --git a/parllama/screens/main_screen.tcss b/parllama/screens/main_screen.tcss index 03ff1ad..ca84d9f 100644 --- a/parllama/screens/main_screen.tcss +++ b/parllama/screens/main_screen.tcss @@ -31,10 +31,3 @@ ListView:focus > SiteModelListItem.--highlight { background: $surface; } } - -SessionList { - width: 40; - height: 1fr; - dock: left; - padding: 1; -} diff --git a/parllama/widgets/chat_message_list.py b/parllama/widgets/chat_message_list.py new file mode 100644 index 0000000..35880ca --- /dev/null +++ b/parllama/widgets/chat_message_list.py @@ -0,0 +1,30 @@ +"""Chat message list widget.""" +from __future__ import annotations + +from textual.containers import VerticalScroll + + +class ChatMessageList(VerticalScroll, can_focus=False, can_focus_children=True): + """Chat message list widget.""" + + DEFAULT_CSS = """ + ChatMessageList { + background: $primary-background; + ChatMessageWidget { + padding: 1; + border: none; + border-left: blank; + &:focus { + border-left: thick $primary; + } + } + MarkdownH2 { + margin: 0; + padding: 0; + } + } + """ + + def __init__(self, **kwargs) -> None: + """Initialise the view.""" + super().__init__(**kwargs) diff --git a/parllama/widgets/input_blur_submit.py b/parllama/widgets/input_blur_submit.py new file mode 100644 index 0000000..b6395ae --- /dev/null +++ b/parllama/widgets/input_blur_submit.py @@ -0,0 +1,17 @@ +"""Input field that submits when losing focus.""" +from __future__ import annotations + +from textual.events import Blur +from textual.widgets import Input + + +class InputBlurSubmit(Input): + 
"""Input field that submits when losing focus.""" + + def __init__(self, **kwargs) -> None: + """Initialise the widget.""" + super().__init__(**kwargs) + + async def on_blur(self, _: Blur) -> None: + """Submit the input when losing focus.""" + await self.action_submit() diff --git a/parllama/widgets/session_list.py b/parllama/widgets/session_list.py index f9e9b4a..9e732be 100644 --- a/parllama/widgets/session_list.py +++ b/parllama/widgets/session_list.py @@ -13,8 +13,10 @@ from parllama.chat_manager import chat_manager from parllama.messages.main import DeleteSession +from parllama.messages.main import RegisterForUpdates from parllama.messages.main import SessionListChanged from parllama.messages.main import SessionSelected +from parllama.widgets.dbl_click_list_item import DblClickListItem from parllama.widgets.session_list_item import SessionListItem @@ -52,6 +54,14 @@ def __init__(self, **kwargs) -> None: super().__init__(**kwargs) self.list_view = ListView(initial_index=None) + async def on_mount(self) -> None: + """Set up the dialog once the DOM is ready.""" + self.app.post_message( + RegisterForUpdates( + widget=self, event_names=["SessionListChanged", "SessionSelected"] + ) + ) + def compose(self) -> ComposeResult: """Compose the content of the view.""" yield Static("Sessions") @@ -89,6 +99,7 @@ async def on_session_list_changed(self, event: SessionListChanged) -> None: self.list_view.index = self.list_view.children.index(item) break + @on(DblClickListItem.DoubleClicked) def action_load_item(self) -> None: """Handle list view selected event.""" selected_item: SessionListItem = cast( diff --git a/parllama/widgets/session_list_item.py b/parllama/widgets/session_list_item.py index 2ed857b..b2ebf9a 100644 --- a/parllama/widgets/session_list_item.py +++ b/parllama/widgets/session_list_item.py @@ -5,12 +5,12 @@ from textual.app import ComposeResult from textual.containers import Vertical from textual.widgets import Label -from textual.widgets import ListItem from 
parllama.models.chat import ChatSession +from parllama.widgets.dbl_click_list_item import DblClickListItem -class SessionListItem(ListItem, can_focus=False, can_focus_children=True): +class SessionListItem(DblClickListItem, can_focus=False, can_focus_children=True): """Session list item""" DEFAULT_CSS = """ diff --git a/parllama/widgets/views/chat_tab.py b/parllama/widgets/views/chat_tab.py index a1c3b0c..a8490b8 100644 --- a/parllama/widgets/views/chat_tab.py +++ b/parllama/widgets/views/chat_tab.py @@ -2,29 +2,37 @@ from __future__ import annotations import uuid +from typing import cast +import humanize from ollama import Options +from rich.text import Text from textual import on from textual import work from textual.app import ComposeResult +from textual.binding import Binding from textual.containers import Horizontal from textual.containers import Vertical -from textual.containers import VerticalScroll from textual.events import Focus from textual.events import Show +from textual.message import Message from textual.reactive import Reactive from textual.widgets import Button from textual.widgets import Input from textual.widgets import Label from textual.widgets import Select +from textual.widgets import Static from textual.widgets import TabbedContent from textual.widgets import TabPane from parllama.chat_manager import chat_manager from parllama.data_manager import dm from parllama.messages.main import ChatMessage +from parllama.messages.main import ChatMessageSent from parllama.messages.main import DeleteSession +from parllama.messages.main import LocalModelDeleted from parllama.messages.main import LocalModelListLoaded +from parllama.messages.main import RegisterForUpdates from parllama.messages.main import SessionSelected from parllama.messages.main import UnRegisterForUpdates from parllama.messages.main import UpdateChatControlStates @@ -35,57 +43,50 @@ from parllama.screens.save_session import SaveSession from parllama.utils import str_ellipsis from 
parllama.widgets.chat_message import ChatMessageWidget +from parllama.widgets.chat_message_list import ChatMessageList +from parllama.widgets.input_blur_submit import InputBlurSubmit from parllama.widgets.input_tab_complete import InputTabComplete from parllama.widgets.session_list import SessionList -MAX_TAB_TITLE_LENGTH = 12 - class ChatTab(TabPane): """Chat tab""" + BINDINGS = [ + Binding( + key="delete", + action="delete_msg", + description="Delete Msg", + show=True, + ), + ] DEFAULT_CSS = """ ChatTab { - #tool_bar { - height: 3; - background: $surface-darken-1; - #model_name { - max-width: 40; + #tool_bar { + height: 3; + background: $surface-darken-1; + #model_name { + max-width: 40; + } } #temperature_input { - width: 11; + width: 11; } #session_name_input { - min-width: 15; - max-width: 40; - width: auto; + min-width: 15; + max-width: 40; + width: auto; } #new_button { - margin-left: 2; - min-width: 9; - background: $warning-darken-2; - border-top: tall $warning-lighten-1; + margin-left: 2; + min-width: 9; + background: $warning-darken-2; + border-top: tall $warning-lighten-1; } Label { - margin: 1; - background: transparent; - } - } - #messages { - background: $primary-background; - ChatMessageWidget{ - padding: 1; - border: none; - border-left: blank; - &:focus { - border-left: thick $primary; - } + margin: 1; + background: transparent; } - MarkdownH2 { - margin: 0; - padding: 0; - } - } } """ @@ -99,7 +100,7 @@ def __init__( session_name = chat_manager.mk_session_name("New Chat") super().__init__( id=f"tp_{uuid.uuid4().hex}", - title=str_ellipsis(session_name, MAX_TAB_TITLE_LENGTH), + title=str_ellipsis(session_name, settings.chat_tab_max_length), **kwargs, ) self.session_list = session_list @@ -107,7 +108,7 @@ def __init__( self.model_select: Select[str] = Select( id="model_name", options=[], prompt="Select Model" ) - self.temperature_input: Input = Input( + self.temperature_input: InputBlurSubmit = InputBlurSubmit( id="temperature_input", value=( 
f"{settings.last_chat_temperature:.2f}" @@ -116,14 +117,15 @@ def __init__( ), max_length=4, restrict=r"^\d?\.?\d?\d?$", + valid_empty=False, ) - self.session_name_input: Input = Input( + self.session_name_input: InputBlurSubmit = InputBlurSubmit( id="session_name_input", value=session_name, + valid_empty=False, ) - self.vs: VerticalScroll = VerticalScroll(id="messages") - self.vs.can_focus = False + self.vs: ChatMessageList = ChatMessageList(id="messages") self.busy = False self.session = chat_manager.get_or_create_session_name( @@ -132,6 +134,8 @@ def __init__( options=self.get_session_options(), ) + self.session_status_bar = Static("", id="SessionStatusBar") + def _watch_busy(self, value: bool) -> None: """Update controls when busy changes""" self.update_control_states() @@ -147,6 +151,7 @@ def _watch_busy(self, value: bool) -> None: def compose(self) -> ComposeResult: """Compose the content of the view.""" with Vertical(id="main"): + yield self.session_status_bar with Horizontal(id="tool_bar"): yield self.model_select yield Label("Temp") @@ -162,7 +167,11 @@ def compose(self) -> ComposeResult: async def on_mount(self) -> None: """Set up the dialog once the DOM is ready.""" - # self.app.post_message(RegisterForUpdates(widget=self)) + self.app.post_message( + RegisterForUpdates( + widget=self, event_names=["LocalModelDeleted", "LocalModelListLoaded"] + ) + ) async def on_unmount(self) -> None: """Remove dialog from updates when unmounted.""" @@ -185,9 +194,10 @@ def update_session_select(self) -> None: SessionSelected(session_id=self.session.session_id) ) - @on(Input.Changed, "#temperature_input") - def on_temperature_input_changed(self) -> None: + @on(Input.Submitted, "#temperature_input") + def on_temperature_input_changed(self, event: Message) -> None: """Handle temperature input change""" + event.stop() try: if self.temperature_input.value: settings.last_chat_temperature = float(self.temperature_input.value) @@ -198,10 +208,12 @@ def 
on_temperature_input_changed(self) -> None: self.session.set_temperature(settings.last_chat_temperature) settings.save_settings_to_file() chat_manager.notify_changed() + self.user_input.focus() - @on(Input.Changed, "#session_name_input") - def on_session_name_input_changed(self) -> None: + @on(Input.Submitted, "#session_name_input") + def on_session_name_input_changed(self, event: Message) -> None: """Handle session name input change""" + event.stop() session_name: str = self.session_name_input.value.strip() if not session_name: return @@ -210,6 +222,7 @@ def on_session_name_input_changed(self) -> None: self.session.set_name(settings.last_chat_session_name) self.notify_tab_label_changed() chat_manager.notify_changed() + self.user_input.focus() def update_control_states(self): """Update disabled state of controls based on model and user input values""" @@ -222,7 +235,11 @@ def on_model_select_changed(self) -> None: if self.model_select.value not in (Select.BLANK, settings.last_chat_model): settings.last_chat_model = str(self.model_select.value) settings.save_settings_to_file() - self.session.set_llm_model(self.model_select.value) # type: ignore + if self.model_select.value != Select.BLANK: + self.session.set_llm_model(self.model_select.value) # type: ignore + else: + self.session.set_llm_model("") + self.update_session_status_bar() def set_model_name(self, model_name: str) -> None: """ "Set model names""" @@ -248,17 +265,6 @@ def on_local_model_list_loaded(self, evt: LocalModelListLoaded) -> None: if v == old_v: self.model_select.value = old_v - # if self.model_select.value != Select.BLANK: - # self.session.set_llm_model(self.model_select.value) # type: ignore - # # TODO fix this smell - # if ( - # self.parent - # and self.parent.parent - # and self.parent.parent.parent - # and cast(TabbedContent, self.parent.parent.parent).active == "Chat" - # ): - # self.user_input.focus() - @on(Button.Pressed, "#new_button") async def on_new_button_pressed(self, event: 
Button.Pressed) -> None: """New button pressed""" @@ -294,6 +300,7 @@ async def action_new_session(self, session_name: str = "New Chat") -> None: self.session.add_message( OllamaMessage(role=msg["role"], content=msg["content"]) ) + self.update_session_status_bar() self.user_input.focus() def notify_tab_label_changed(self) -> None: @@ -301,20 +308,18 @@ def notify_tab_label_changed(self) -> None: self.post_message( UpdateTabLabel( str(self.id), - str_ellipsis(self.session.session_name, MAX_TAB_TITLE_LENGTH), + str_ellipsis(self.session.session_name, settings.chat_tab_max_length), ) ) @on(ChatMessage) async def on_chat_message(self, event: ChatMessage) -> None: """Handle a chat message""" - event.stop() - ses = chat_manager.get_session(event.session_id) - if not ses: - self.notify("Chat session not found", severity="error") + if self.session.session_id != event.session_id: + self.notify("Chat session id missmatch", severity="error") return - msg: OllamaMessage | None = ses.get_message(event.message_id) + msg: OllamaMessage | None = self.session.get_message(event.message_id) if not msg: self.notify("Chat message not found", severity="error") return @@ -333,6 +338,7 @@ async def on_chat_message(self, event: ChatMessage) -> None: self.set_timer(0.1, self.scroll_to_bottom) chat_manager.notify_changed() + self.update_session_status_bar() def scroll_to_bottom(self) -> None: """Scroll to the bottom of the chat window.""" @@ -373,6 +379,8 @@ async def load_session(self, session_id: str) -> None: ) with self.prevent(Focus, Input.Changed, Select.Changed): self.set_model_name(self.session.llm_model_name) + if self.model_select.value == Select.BLANK: + self.notify("Model defined in session is not installed") self.temperature_input.value = str( self.session.options.get("temperature", "") ) @@ -380,6 +388,7 @@ async def load_session(self, session_id: str) -> None: self.update_control_states() self.notify_tab_label_changed() self.set_timer(0.1, self.scroll_to_bottom) + 
self.update_session_status_bar() self.user_input.focus() @on(SessionSelected) @@ -396,3 +405,60 @@ async def on_delete_session(self, event: DeleteSession) -> None: self.notify("Chat session deleted") if self.session.session_id == event.session_id: await self.action_new_session() + + async def on_session_updated(self) -> None: + """Session updated event""" + self.session_name_input.value = self.session.session_name + self.notify_tab_label_changed() + self.update_session_status_bar() + + def update_session_status_bar(self) -> None: + """Update session status bar""" + model = dm.get_model_by_name(self.session.llm_model_name) + if model: + max_context_length = model.model_info.llama_context_length or 0 + else: + max_context_length = 0 + self.session_status_bar.update( + Text.assemble( + "Context Length: ", + humanize.intcomma(self.session.context_length), + " / ", + humanize.intcomma(max_context_length), + ) + ) + + async def action_delete_msg(self) -> None: + """Handle the delete message action.""" + ret = self.vs.query("ChatMessageWidget:focus") + if len(ret) != 1: + return + msg: ChatMessageWidget = cast(ChatMessageWidget, ret[0]) + del self.session[msg.msg.message_id] + await msg.remove() + self.session.save() + self.update_session_status_bar() + if len(self.session) == 0: + self.user_input.focus() + + @work(thread=True, name="msg_send_worker") + async def do_send_message(self, msg: str) -> None: + """Send the message.""" + self.busy = True + await self.session.send_chat(msg, self) + self.post_message(ChatMessageSent(self.session.session_id)) + + @on(ChatMessageSent) + def on_chat_message_sent(self) -> None: + """Handle a chat message sent""" + self.busy = False + + @on(LocalModelDeleted) + def on_model_deleted(self, event: LocalModelDeleted) -> None: + """Model deleted check if the currently selected model.""" + event.stop() + + if event.model_name == self.model_select.value: + self.model_select.value = Select.BLANK + 
self.on_local_model_list_loaded(LocalModelListLoaded()) + self.update_control_states() diff --git a/parllama/widgets/views/chat_view.py b/parllama/widgets/views/chat_view.py index ebbfe9c..2bd7600 100644 --- a/parllama/widgets/views/chat_view.py +++ b/parllama/widgets/views/chat_view.py @@ -5,7 +5,6 @@ from typing import cast from textual import on -from textual import work from textual.app import ComposeResult from textual.binding import Binding from textual.containers import Horizontal @@ -21,12 +20,14 @@ from parllama.chat_manager import chat_manager from parllama.data_manager import dm from parllama.dialogs.information import InformationDialog +from parllama.messages.main import ChatGenerationAborted from parllama.messages.main import ChatMessage from parllama.messages.main import ChatMessageSent from parllama.messages.main import DeleteSession from parllama.messages.main import LocalModelListLoaded from parllama.messages.main import RegisterForUpdates from parllama.messages.main import SessionSelected +from parllama.messages.main import SessionUpdated from parllama.messages.main import UpdateChatControlStates from parllama.messages.main import UpdateTabLabel from parllama.models.chat import ChatSession @@ -55,6 +56,12 @@ class ChatView(Vertical, can_focus=False, can_focus_children=True): DEFAULT_CSS = """ ChatView { layers: left; + SessionList { + width: 40; + height: 1fr; + dock: left; + padding: 1; + } #chat_tabs { height: 1fr; } @@ -65,6 +72,12 @@ class ChatView(Vertical, can_focus=False, can_focus_children=True): width: 1fr; } #send_button { + min-width: 7; + width: 7; + margin-right: 1; + } + #stop_button { + min-width: 6; width: 6; } } @@ -127,7 +140,12 @@ def __init__(self, **kwargs) -> None: submit_on_complete=False, ) - self.send_button: Button = Button("Send", id="send_button", disabled=True) + self.send_button: Button = Button( + "Send", id="send_button", disabled=True, variant="success" + ) + self.stop_button: Button = Button( + "Stop", 
id="stop_button", disabled=True, variant="error" + ) def compose(self) -> ComposeResult: """Compose the content of the view.""" @@ -138,10 +156,21 @@ def compose(self) -> ComposeResult: with Horizontal(id="send_bar"): yield self.user_input yield self.send_button + yield self.stop_button async def on_mount(self) -> None: """Set up the dialog once the DOM is ready.""" - self.app.post_message(RegisterForUpdates(widget=self)) + self.app.post_message( + RegisterForUpdates( + widget=self, + event_names=[ + "LocalModelDeleted", + "LocalModelListLoaded", + "SessionSelected", + "DeleteSession", + ], + ) + ) @on(Input.Changed, "#user_input") def on_user_input_changed(self) -> None: @@ -155,6 +184,9 @@ def update_control_states(self): or self.active_tab.model_select.value == Select.BLANK or len(self.user_input.value.strip()) == 0 ) + self.stop_button.disabled = ( + not self.active_tab.busy or self.session.abort_pending + ) @on(LocalModelListLoaded) def on_local_model_list_loaded(self, evt: LocalModelListLoaded) -> None: @@ -193,12 +225,13 @@ async def action_send_message(self, event: Message) -> None: self.notify("LLM Busy...", severity="error") return - self.update_control_states() self.user_input.value = "" + self.update_control_states() + if user_msg.startswith("/"): return await self.handle_command(user_msg[1:].lower().strip()) - self.active_tab.busy = True - self.do_send_message(user_msg) + + self.active_tab.do_send_message(user_msg) # pylint: disable=too-many-branches async def handle_command(self, cmd: str) -> None: @@ -273,19 +306,35 @@ async def handle_command(self, cmd: str) -> None: else: self.notify(f"Unknown command: {cmd}", severity="error") - @work(thread=True) - async def do_send_message(self, msg: str) -> None: - """Send the message.""" - await self.session.send_chat(msg, self) - self.post_message(ChatMessageSent(self.session.session_id)) + @on(Button.Pressed, "#stop_button") + def on_stop_button_pressed(self, event: Button.Pressed) -> None: + """Stop button 
pressed""" + event.stop() + self.stop_button.disabled = True + self.active_tab.session.stop_generation() + # self.workers.cancel_group(self, "message_send") + self.workers.cancel_node(self.active_tab) + self.active_tab.busy = False @on(ChatMessageSent) def on_chat_message_sent(self, event: ChatMessageSent) -> None: """Handle a chat message sent""" event.stop() + if self.session.session_id == event.session_id: + self.stop_button.disabled = True + + @on(SessionUpdated) + async def on_session_updated(self, event: SessionUpdated) -> None: + """Session updated event""" + + event.stop() + session = chat_manager.get_session(event.session_id) + if not session: + return + session.set_name(chat_manager.mk_session_name(session.session_name)) for tab in self.chat_tabs.query(ChatTab): if tab.session.session_id == event.session_id: - tab.busy = False + await tab.on_session_updated() def action_toggle_session_list(self) -> None: """Toggle the session list.""" @@ -352,16 +401,20 @@ def re_index_labels(self) -> None: async def on_delete_session(self, event: DeleteSession) -> None: """Delete session event""" event.stop() + tab_removed: bool = False for tab in self.chat_tabs.query(ChatTab): if tab.session.session_id == event.session_id: await self.chat_tabs.remove_pane(str(tab.id)) - self.re_index_labels() - break + tab_removed = True + chat_manager.delete_session(event.session_id) if len(self.chat_tabs.query(ChatTab)) == 0: await self.action_new_tab() + elif tab_removed: + self.re_index_labels() + self.notify("Chat session deleted") - self.user_input.focus() + # self.user_input.focus() async def action_new_tab(self) -> None: """New tab action""" @@ -401,3 +454,9 @@ def focus_tab(self, tab_num: int) -> None: tabs = self.chat_tabs.query(ChatTab) if len(tabs) > tab_num: self.chat_tabs.active = str(tabs[tab_num].id) + + @on(ChatGenerationAborted) + def on_chat_generation_aborted(self, event: ChatGenerationAborted) -> None: + """Chat generation aborted event""" + event.stop() + 
self.notify("Chat Aborted", severity="warning") diff --git a/parllama/widgets/views/local_model_view.py b/parllama/widgets/views/local_model_view.py index d26169e..b1a2a40 100644 --- a/parllama/widgets/views/local_model_view.py +++ b/parllama/widgets/views/local_model_view.py @@ -28,6 +28,7 @@ from parllama.messages.main import ModelPulled from parllama.messages.main import ModelPullRequested from parllama.messages.main import ModelPushRequested +from parllama.messages.main import RegisterForUpdates from parllama.messages.main import SetModelNameLoading from parllama.messages.main import ShowLocalModel from parllama.widgets.filter_input import FilterInput @@ -121,6 +122,19 @@ def compose(self) -> ComposeResult: async def on_mount(self) -> None: """Mount the view.""" + + self.app.post_message( + RegisterForUpdates( + widget=self, + event_names=[ + "LocalModelListLoaded", + "LocalModelDeleted", + "SetModelNameLoading", + "ModelPulled", + "ModelPushed", + ], + ) + ) self.action_refresh_models() def action_refresh_models(self): diff --git a/parllama/widgets/views/site_model_view.py b/parllama/widgets/views/site_model_view.py index 8d95b1b..48a8417 100644 --- a/parllama/widgets/views/site_model_view.py +++ b/parllama/widgets/views/site_model_view.py @@ -19,6 +19,7 @@ from parllama.data_manager import dm from parllama.messages.main import ModelPullRequested +from parllama.messages.main import RegisterForUpdates from parllama.messages.main import SiteModelsLoaded from parllama.messages.main import SiteModelsRefreshRequested from parllama.models.settings_data import settings @@ -119,6 +120,12 @@ def compose(self) -> ComposeResult: def on_mount(self) -> None: """Configure the dialog once the DOM is ready.""" + self.app.post_message( + RegisterForUpdates( + widget=self, + event_names=["SiteModelsLoaded", "SetModelNameLoading"], + ) + ) self.lv.loading = True self.app.post_message( SiteModelsRefreshRequested( @@ -144,9 +151,10 @@ async def on_list_view_highlighted(self, 
event: ListView.Highlighted) -> None: def action_pull_model(self): """Request model pull""" if not self.search_input.value: + self.notify("Please enter a model name", severity="warning") return if self.namespace_input.value: - self.screen.post_message( + self.app.post_message( ModelPullRequested( widget=self, model_name=self.namespace_input.value @@ -155,7 +163,7 @@ def action_pull_model(self): ) ) else: - self.screen.post_message( + self.app.post_message( ModelPullRequested(widget=self, model_name=self.search_input.value) ) diff --git a/setup.cfg b/setup.cfg index 8e36ce5..b491a1d 100644 --- a/setup.cfg +++ b/setup.cfg @@ -10,7 +10,7 @@ author_email = probello@gmail.com maintainer = Paul Robello maintainer_email = probello@gmail.com license = License :: OSI Approved :: MIT License -license_files = LICENCE +license_files = LICENSE keywords = ollama, ai, terminal, tui classifiers = License :: OSI Approved :: MIT License