From 758a3fbba3aa96cd9e1b3f516446c7d5842b1d67 Mon Sep 17 00:00:00 2001 From: Piyush Jain Date: Thu, 29 Jun 2023 20:33:54 -0700 Subject: [PATCH 1/3] Fix for propagating provider param updates --- packages/jupyter-ai/jupyter_ai/actors/base.py | 10 +++++++--- packages/jupyter-ai/jupyter_ai/actors/chat_provider.py | 2 +- 2 files changed, 8 insertions(+), 4 deletions(-) diff --git a/packages/jupyter-ai/jupyter_ai/actors/base.py b/packages/jupyter-ai/jupyter_ai/actors/base.py index d1ddddcf3..96eddff3e 100644 --- a/packages/jupyter-ai/jupyter_ai/actors/base.py +++ b/packages/jupyter-ai/jupyter_ai/actors/base.py @@ -88,7 +88,7 @@ def get_llm_chain(self): else None ) - if not lm_provider: + if not lm_provider or not lm_provider_params: return None if curr_lm_id != next_lm_id: @@ -96,6 +96,10 @@ def get_llm_chain(self): f"Switching chat language model from {curr_lm_id} to {next_lm_id}." ) self.create_llm_chain(lm_provider, lm_provider_params) + elif self.llm_params != lm_provider_params: + self.log.info("Chat model params changed, updating the llm chain.") + self.create_llm_chain(lm_provider, lm_provider_params) + return self.llm_chain def get_embeddings(self): @@ -104,10 +108,10 @@ def get_embeddings(self): embedding_params = ray.get(actor.get_provider_params.remote()) embedding_model_id = ray.get(actor.get_model_id.remote()) - if not provider: + if not provider or not embedding_params: return None - if embedding_model_id != self.embedding_model_id: + if embedding_model_id != self.embedding_model_id or self.embeddings_params != embedding_params: self.embeddings = provider(**embedding_params) return self.embeddings diff --git a/packages/jupyter-ai/jupyter_ai/actors/chat_provider.py b/packages/jupyter-ai/jupyter_ai/actors/chat_provider.py index e108f9f2a..e5f60ddcb 100644 --- a/packages/jupyter-ai/jupyter_ai/actors/chat_provider.py +++ b/packages/jupyter-ai/jupyter_ai/actors/chat_provider.py @@ -16,7 +16,7 @@ def update(self, config: GlobalConfig): local_model_id, provider = ray.get( actor.get_model_provider_data.remote(model_id) ) - + if not provider: raise ValueError(f"No provider and model found with '{model_id}'") From 4bad317c7650f978f4efe3c8640c877356de4d83 Mon Sep 17 00:00:00 2001 From: Piyush Jain Date: Thu, 29 Jun 2023 20:41:13 -0700 Subject: [PATCH 2/3] Adds vertical scroll for settings panel --- packages/jupyter-ai/src/components/chat-settings.tsx | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/packages/jupyter-ai/src/components/chat-settings.tsx b/packages/jupyter-ai/src/components/chat-settings.tsx index 834f84ccb..65fb9bbc4 100644 --- a/packages/jupyter-ai/src/components/chat-settings.tsx +++ b/packages/jupyter-ai/src/components/chat-settings.tsx @@ -289,7 +289,8 @@ export function ChatSettings(): JSX.Element { sx={{ padding: 4, boxSizing: 'border-box', - '& > .MuiAlert-root': { marginBottom: 2 } + '& > .MuiAlert-root': { marginBottom: 2 }, + overflowY: 'auto' }} > {state === ChatSettingsState.SubmitError && ( From 5567506eb47003a6cc68619ccc556d7894f7da43 Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Fri, 30 Jun 2023 03:42:26 +0000 Subject: [PATCH 3/3] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- packages/jupyter-ai/jupyter_ai/actors/base.py | 7 +++++-- packages/jupyter-ai/jupyter_ai/actors/chat_provider.py | 2 +- 2 files changed, 6 insertions(+), 3 deletions(-) diff --git a/packages/jupyter-ai/jupyter_ai/actors/base.py b/packages/jupyter-ai/jupyter_ai/actors/base.py index 96eddff3e..fa2a30bb5 100644 --- a/packages/jupyter-ai/jupyter_ai/actors/base.py +++ b/packages/jupyter-ai/jupyter_ai/actors/base.py @@ -99,7 +99,7 @@ def get_llm_chain(self): elif self.llm_params != lm_provider_params: self.log.info("Chat model params changed, updating the llm chain.") self.create_llm_chain(lm_provider, lm_provider_params) - + return self.llm_chain def get_embeddings(self): @@ -111,7 +111,10 @@ def get_embeddings(self): if not provider or not embedding_params: return None - if embedding_model_id != self.embedding_model_id or self.embeddings_params != embedding_params: + if ( + embedding_model_id != self.embedding_model_id + or self.embeddings_params != embedding_params + ): self.embeddings = provider(**embedding_params) return self.embeddings diff --git a/packages/jupyter-ai/jupyter_ai/actors/chat_provider.py b/packages/jupyter-ai/jupyter_ai/actors/chat_provider.py index e5f60ddcb..e108f9f2a 100644 --- a/packages/jupyter-ai/jupyter_ai/actors/chat_provider.py +++ b/packages/jupyter-ai/jupyter_ai/actors/chat_provider.py @@ -16,7 +16,7 @@ def update(self, config: GlobalConfig): local_model_id, provider = ray.get( actor.get_model_provider_data.remote(model_id) ) - + if not provider: raise ValueError(f"No provider and model found with '{model_id}'")