From 835f90ed252446c40f240d0e632a64efacdad88d Mon Sep 17 00:00:00 2001
From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Date: Mon, 9 Oct 2023 14:50:03 -0700
Subject: [PATCH 1/6] Fixed conflict.

---
 .pre-commit-config.yaml | 8 ++++----
 1 file changed, 4 insertions(+), 4 deletions(-)

diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml
index d0b62dc2b..569395200 100644
--- a/.pre-commit-config.yaml
+++ b/.pre-commit-config.yaml
@@ -1,6 +1,6 @@
 repos:
   - repo: https://github.com/pre-commit/pre-commit-hooks
-    rev: v4.4.0
+    rev: v4.5.0
     hooks:
       - id: end-of-file-fixer
       - id: check-case-conflict
@@ -18,7 +18,7 @@ repos:
       - id: trailing-whitespace

   - repo: https://github.com/psf/black
-    rev: 23.7.0
+    rev: 23.9.1
     hooks:
       - id: black

@@ -30,7 +30,7 @@ repos:
         files: \.py$

   - repo: https://github.com/asottile/pyupgrade
-    rev: v3.9.0
+    rev: v3.15.0
     hooks:
       - id: pyupgrade
         args: [--py37-plus]
@@ -48,7 +48,7 @@ repos:
         stages: [manual]

   - repo: https://github.com/sirosen/check-jsonschema
-    rev: 0.23.3
+    rev: 0.27.0
     hooks:
       - id: check-jsonschema
         name: "Check GitHub Workflows"

From 16ce73225a04c5467d76cee534e9912cd4aa163c Mon Sep 17 00:00:00 2001
From: Piyush Jain
Date: Mon, 2 Oct 2023 20:49:58 -0700
Subject: [PATCH 2/6] Updates for more stable generate feature

(cherry picked from commit d9b26f5729bd12e49711b4b1ed7314c064ef3687)
---
 .../jupyter_ai/chat_handlers/generate.py | 25 +++++++++++++------
 1 file changed, 17 insertions(+), 8 deletions(-)

diff --git a/packages/jupyter-ai/jupyter_ai/chat_handlers/generate.py b/packages/jupyter-ai/jupyter_ai/chat_handlers/generate.py
index e85d5a916..1934c6fa5 100644
--- a/packages/jupyter-ai/jupyter_ai/chat_handlers/generate.py
+++ b/packages/jupyter-ai/jupyter_ai/chat_handlers/generate.py
@@ -1,6 +1,7 @@
 import asyncio
 import json
 import os
+import re
 from typing import Dict, Type

 import nbformat
@@ -48,7 +49,7 @@ def from_llm(cls, llm: BaseLLM, verbose: bool = False) -> LLMChain:
             "Generate the outline as JSON data that will validate against this JSON schema:\n"
             "{schema}\n"
             "Here is a description of the notebook you will create an outline for: {description}\n"
-            "Don't include an introduction or conclusion section in the outline, focus only on sections that will need code."
+            "Don't include an introduction or conclusion section in the outline, focus only on description and sections that will need code.\n"
         )
         prompt = PromptTemplate(
             template=task_creation_template,
@@ -57,10 +58,22 @@ def from_llm(cls, llm: BaseLLM, verbose: bool = False) -> LLMChain:
         return cls(prompt=prompt, llm=llm, verbose=verbose)


+def extract_json(text: str) -> str:
+    """Extract JSON from text using a regex."""
+    # The pattern to find a JSON string enclosed in ```json ... ```
+    pattern = r"```json\n(.*?)\n```"
+
+    # Find all matches in the input text
+    matches = re.findall(pattern, text, re.DOTALL)
+
+    return matches[0] if matches else text
+
+
 async def generate_outline(description, llm=None, verbose=False):
     """Generate an outline of sections given a description of a notebook."""
     chain = NotebookOutlineChain.from_llm(llm=llm, verbose=verbose)
     outline = await chain.apredict(description=description, schema=schema)
+    outline = extract_json(outline)
     return json.loads(outline)


@@ -182,14 +195,10 @@ async def generate_summary(outline, llm=None, verbose: bool = False):
 async def fill_outline(outline, llm, verbose=False):
     shared_kwargs = {"outline": outline, "llm": llm, "verbose": verbose}

-    all_coros = []
-    all_coros.append(generate_title(**shared_kwargs))
-    all_coros.append(generate_summary(**shared_kwargs))
+    await generate_title(**shared_kwargs)
+    await generate_summary(**shared_kwargs)
     for section in outline["sections"]:
-        all_coros.append(
-            generate_code(section, outline["description"], llm=llm, verbose=verbose)
-        )
-    await asyncio.gather(*all_coros)
+        await generate_code(section, outline["description"], llm=llm, verbose=verbose)


 def create_notebook(outline):

From e9d71b8c133cac98f5d940fd4f7cd2df6b397b8f Mon Sep 17 00:00:00 2001
From: Piyush Jain
Date: Wed, 11 Oct 2023 21:48:07 -0700
Subject: [PATCH 3/6] Refactored generate for better stability with all providers/models.
(cherry picked from commit 608ed25a38613d8c64d34c68379464851669cebb) --- .../jupyter_ai_magics/providers.py | 16 +++ .../jupyter_ai/chat_handlers/generate.py | 130 +++++++++--------- 2 files changed, 83 insertions(+), 63 deletions(-) diff --git a/packages/jupyter-ai-magics/jupyter_ai_magics/providers.py b/packages/jupyter-ai-magics/jupyter_ai_magics/providers.py index 5a77926c7..9fdbffa7a 100644 --- a/packages/jupyter-ai-magics/jupyter_ai_magics/providers.py +++ b/packages/jupyter-ai-magics/jupyter_ai_magics/providers.py @@ -227,6 +227,10 @@ def get_prompt_template(self, format) -> PromptTemplate: def is_chat_provider(self): return isinstance(self, BaseChatModel) + @property + def allows_concurrency(self): + return True + class AI21Provider(BaseProvider, AI21): id = "ai21" @@ -267,6 +271,10 @@ class AnthropicProvider(BaseProvider, Anthropic): pypi_package_deps = ["anthropic"] auth_strategy = EnvAuthStrategy(name="ANTHROPIC_API_KEY") + @property + def allows_concurrency(self): + return False + class ChatAnthropicProvider(BaseProvider, ChatAnthropic): id = "anthropic-chat" @@ -285,6 +293,10 @@ class ChatAnthropicProvider(BaseProvider, ChatAnthropic): pypi_package_deps = ["anthropic"] auth_strategy = EnvAuthStrategy(name="ANTHROPIC_API_KEY") + @property + def allows_concurrency(self): + return False + class CohereProvider(BaseProvider, Cohere): id = "cohere" @@ -665,3 +677,7 @@ async def _acall(self, *args, **kwargs) -> Coroutine[Any, Any, str]: async def _agenerate(self, *args, **kwargs) -> Coroutine[Any, Any, LLMResult]: return await self._generate_in_executor(*args, **kwargs) + + @property + def allows_concurrency(self): + return not "anthropic" in self.model_id diff --git a/packages/jupyter-ai/jupyter_ai/chat_handlers/generate.py b/packages/jupyter-ai/jupyter_ai/chat_handlers/generate.py index 1934c6fa5..b3a7c4335 100644 --- a/packages/jupyter-ai/jupyter_ai/chat_handlers/generate.py +++ b/packages/jupyter-ai/jupyter_ai/chat_handlers/generate.py @@ -1,8 +1,6 @@ import asyncio -import json import os -import re -from typing import Dict, Type +from typing import Dict, List, Optional, Type import nbformat from jupyter_ai.chat_handlers import BaseChatHandler @@ -10,71 +8,50 @@ from jupyter_ai_magics.providers import BaseProvider from langchain.chains import LLMChain from langchain.llms import BaseLLM +from langchain.output_parsers import PydanticOutputParser from langchain.prompts import PromptTemplate +from langchain.schema.output_parser import BaseOutputParser +from pydantic import BaseModel -schema = """{ - "$schema": "http://json-schema.org/draft-07/schema#", - "type": "object", - "properties": { - "description": { - "type": "string" - }, - "sections": { - "type": "array", - "items": { - "type": "object", - "properties": { - "title": { - "type": "string" - }, - "content": { - "type": "string" - } - }, - "required": ["title", "content"] - } - } - }, - "required": ["sections"] -}""" + +class OutlineSection(BaseModel): + title: str + content: str + + +class Outline(BaseModel): + description: Optional[str] = None + sections: List[OutlineSection] class NotebookOutlineChain(LLMChain): """Chain to generate a notebook outline, with section titles and descriptions.""" @classmethod - def from_llm(cls, llm: BaseLLM, verbose: bool = False) -> LLMChain: + def from_llm( + cls, llm: BaseLLM, parser: BaseOutputParser[Outline], verbose: bool = False + ) -> LLMChain: task_creation_template = ( "You are an AI that creates a detailed content outline for a Jupyter notebook on a given topic.\n" - "Generate the 
outline as JSON data that will validate against this JSON schema:\n" - "{schema}\n" + "{format_instructions}\n" "Here is a description of the notebook you will create an outline for: {description}\n" "Don't include an introduction or conclusion section in the outline, focus only on description and sections that will need code.\n" ) prompt = PromptTemplate( template=task_creation_template, - input_variables=["description", "schema"], + input_variables=["description"], + partial_variables={"format_instructions": parser.get_format_instructions()}, ) return cls(prompt=prompt, llm=llm, verbose=verbose) -def extract_json(text: str) -> str: - """Extract json from text using Regex.""" - # The pattern to find json string enclosed in ```json```` - pattern = r"```json\n(.*?)\n```" - - # Find all matches in the input text - matches = re.findall(pattern, text, re.DOTALL) - - return matches[0] if matches else text - - async def generate_outline(description, llm=None, verbose=False): """Generate an outline of sections given a description of a notebook.""" - chain = NotebookOutlineChain.from_llm(llm=llm, verbose=verbose) - outline = await chain.apredict(description=description, schema=schema) - outline = extract_json(outline) - return json.loads(outline) + parser = PydanticOutputParser(pydantic_object=Outline) + chain = NotebookOutlineChain.from_llm(llm=llm, parser=parser, verbose=verbose) + outline = await chain.apredict(description=description) + outline = parser.parse(outline) + return outline.dict() class CodeImproverChain(LLMChain): @@ -141,7 +118,8 @@ class NotebookTitleChain(LLMChain): def from_llm(cls, llm: BaseLLM, verbose: bool = False) -> LLMChain: task_creation_template = ( "Create a short, few word, descriptive title for a Jupyter notebook with the following content.\n" - "Content:\n{content}" + "Content:\n{content}\n" + "Don't return anything other than the title." 
) prompt = PromptTemplate( template=task_creation_template, @@ -178,7 +156,7 @@ async def generate_code(section, description, llm=None, verbose=False) -> None: async def generate_title(outline, llm=None, verbose: bool = False): - """Generate a title and summary of a notebook outline using an LLM.""" + """Generate a title of a notebook outline using an LLM.""" title_chain = NotebookTitleChain.from_llm(llm=llm, verbose=verbose) title = await title_chain.apredict(content=outline) title = title.strip() @@ -187,12 +165,14 @@ async def generate_title(outline, llm=None, verbose: bool = False): async def generate_summary(outline, llm=None, verbose: bool = False): + """Generate a summary of a notebook using an LLM.""" summary_chain = NotebookSummaryChain.from_llm(llm=llm, verbose=verbose) summary = await summary_chain.apredict(content=outline) outline["summary"] = summary async def fill_outline(outline, llm, verbose=False): + """Generate title and content of a notebook sections using an LLM.""" shared_kwargs = {"outline": outline, "llm": llm, "verbose": verbose} await generate_title(**shared_kwargs) @@ -201,6 +181,20 @@ async def fill_outline(outline, llm, verbose=False): await generate_code(section, outline["description"], llm=llm, verbose=verbose) +async def afill_outline(outline, llm, verbose=False): + """Concurrently generate title and content of notebook sections using an LLM.""" + shared_kwargs = {"outline": outline, "llm": llm, "verbose": verbose} + + all_coros = [] + all_coros.append(generate_title(**shared_kwargs)) + all_coros.append(generate_summary(**shared_kwargs)) + for section in outline["sections"]: + all_coros.append( + generate_code(section, outline["description"], llm=llm, verbose=verbose) + ) + await asyncio.gather(*all_coros) + + def create_notebook(outline): """Create an nbformat Notebook object for a notebook outline.""" nbf = nbformat.v4 @@ -233,29 +227,39 @@ def create_llm_chain( self.llm = llm return llm - async def _process_message(self, message: HumanChatMessage): - self.get_llm_chain() - - # first send a verification message to user - response = "👍 Great, I will get started on your notebook. It may take a few minutes, but I will reply here when the notebook is ready. In the meantime, you can continue to ask me other questions." - self.reply(response, message) + async def _generate_notebook(self, prompt: str): + """Generate a notebook and save to local disk""" - # generate notebook outline - prompt = message.body + # create outline outline = await generate_outline(prompt, llm=self.llm, verbose=True) # Save the user input prompt, the description property is now LLM generated. outline["prompt"] = prompt - # fill the outline concurrently - await fill_outline(outline, llm=self.llm, verbose=True) + if self.llm.allows_concurrency: + # fill the outline concurrently + await afill_outline(outline, llm=self.llm, verbose=True) + else: + # fill outline + await fill_outline(outline, llm=self.llm, verbose=True) # create and write the notebook to disk notebook = create_notebook(outline) final_path = os.path.join(self.root_dir, outline["title"] + ".ipynb") nbformat.write(notebook, final_path) - response = f"""🎉 I have created your notebook and saved it to the location {final_path}. 
I am still learning how to create notebooks, so please review all code before running it.""" - self.reply(response, message) + return final_path + + async def _process_message(self, message: HumanChatMessage): + self.get_llm_chain() + # first send a verification message to user + response = "👍 Great, I will get started on your notebook. It may take a few minutes, but I will reply here when the notebook is ready. In the meantime, you can continue to ask me other questions." + self.reply(response, message) -# /generate notebook -# Error handling + try: + final_path = await self._generate_notebook(prompt=message.body) + response = f"""🎉 I have created your notebook and saved it to the location {final_path}. I am still learning how to create notebooks, so please review all code before running it.""" + except Exception as e: + self.log.exception(e) + response = "An error occurred while generating the notebook. Try running the /generate task again." + finally: + self.reply(response, message) From ab1285767811e269972512c62503e5a7bbc5966f Mon Sep 17 00:00:00 2001 From: Piyush Jain Date: Fri, 20 Oct 2023 21:04:10 -0700 Subject: [PATCH 4/6] Upgraded LangChain to 0.0.318 (cherry picked from commit 0d247fdf177b6f3ad48b88bb9ceb886e3c07a8a2) --- packages/jupyter-ai-magics/pyproject.toml | 2 +- packages/jupyter-ai/pyproject.toml | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/packages/jupyter-ai-magics/pyproject.toml b/packages/jupyter-ai-magics/pyproject.toml index a310621c1..7eb8a8644 100644 --- a/packages/jupyter-ai-magics/pyproject.toml +++ b/packages/jupyter-ai-magics/pyproject.toml @@ -24,7 +24,7 @@ dependencies = [ "ipython", "pydantic~=1.0", "importlib_metadata>=5.2.0", - "langchain==0.0.308", + "langchain==0.0.318", "typing_extensions>=4.5.0", "click~=8.0", "jsonpath-ng>=1.5.3,<2", diff --git a/packages/jupyter-ai/pyproject.toml b/packages/jupyter-ai/pyproject.toml index fd8a06934..993070c04 100644 --- a/packages/jupyter-ai/pyproject.toml +++ b/packages/jupyter-ai/pyproject.toml @@ -28,7 +28,7 @@ dependencies = [ "openai~=0.26", "aiosqlite>=0.18", "importlib_metadata>=5.2.0", - "langchain==0.0.308", + "langchain==0.0.318", "tiktoken", # required for OpenAIEmbeddings "jupyter_ai_magics", "dask[distributed]", From 5f2190363619e7c0f8f9d5b4b1a8559566afaa99 Mon Sep 17 00:00:00 2001 From: Piyush Jain Date: Fri, 20 Oct 2023 21:48:23 -0700 Subject: [PATCH 5/6] Updated to use memory instead of chat history, fix for Bedrock Anthropic (cherry picked from commit 7c09863b3199d167fe18575194ee81991285d841) --- .../jupyter_ai/chat_handlers/ask.py | 28 +++++++++++++------ 1 file changed, 19 insertions(+), 9 deletions(-) diff --git a/packages/jupyter-ai/jupyter_ai/chat_handlers/ask.py b/packages/jupyter-ai/jupyter_ai/chat_handlers/ask.py index cad14b0e5..dbc2bd679 100644 --- a/packages/jupyter-ai/jupyter_ai/chat_handlers/ask.py +++ b/packages/jupyter-ai/jupyter_ai/chat_handlers/ask.py @@ -4,9 +4,19 @@ from jupyter_ai.models import HumanChatMessage from jupyter_ai_magics.providers import BaseProvider from langchain.chains import ConversationalRetrievalChain +from langchain.memory import ConversationBufferWindowMemory +from langchain.prompts import PromptTemplate from .base import BaseChatHandler +PROMPT_TEMPLATE = """Given the following conversation and a follow up question, rephrase the follow up question to be a standalone question. 
+
+Chat History:
+{chat_history}
+Follow Up Input: {question}
+Standalone question:"""
+CONDENSE_PROMPT = PromptTemplate.from_template(PROMPT_TEMPLATE)
+

 class AskChatHandler(BaseChatHandler):
     """Processes messages prefixed with /ask. This actor will
@@ -27,9 +37,15 @@ def create_llm_chain(
         self, provider: Type[BaseProvider], provider_params: Dict[str, str]
     ):
         self.llm = provider(**provider_params)
-        self.chat_history = []
+        memory = ConversationBufferWindowMemory(
+            memory_key="chat_history", return_messages=True, k=2
+        )
         self.llm_chain = ConversationalRetrievalChain.from_llm(
-            self.llm, self._retriever, verbose=True
+            self.llm,
+            self._retriever,
+            memory=memory,
+            condense_question_prompt=CONDENSE_PROMPT,
+            verbose=False,
         )

     async def _process_message(self, message: HumanChatMessage):
@@ -44,14 +60,8 @@ async def _process_message(self, message: HumanChatMessage):
         self.get_llm_chain()

         try:
-            # limit chat history to last 2 exchanges
-            self.chat_history = self.chat_history[-2:]
-
-            result = await self.llm_chain.acall(
-                {"question": query, "chat_history": self.chat_history}
-            )
+            result = await self.llm_chain.acall({"question": query})
             response = result["answer"]
-            self.chat_history.append((query, response))
             self.reply(response, message)
         except AssertionError as e:
             self.log.error(e)

From c7d54bf028e8674a1b578bf921135d6b9cad8767 Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?Micha=C5=82=20Krassowski?= <5832902+krassowski@users.noreply.github.com>
Date: Mon, 23 Oct 2023 20:42:11 +0100
Subject: [PATCH 6/6] Allow defining block and allow lists for providers (#415)

* Allow blocking or allow-listing providers by id

* Add tests for block/allow-lists

* Fix "No language model is associated with" issue

This was appearing because blocked models were (correctly) not returned,
but the previous validation logic did not know that models may be missing
for a valid reason even if there are existing settings for them.

* Add docs for allow-listing and block-listing providers

* Updated docs

* Added an intro block to docs

* Updated the docs

---------

Co-authored-by: Piyush Jain
(cherry picked from commit 92dab10608ea090b6bd87ca5ffe9ccff7a15b449)
---
 docs/source/users/index.md                    | 31 +++++++++++++++
 .../jupyter_ai_magics/tests/test_utils.py     | 34 ++++++++++++++++
 .../jupyter_ai_magics/utils.py                | 32 +++++++++++++--
 .../jupyter-ai/jupyter_ai/config_manager.py   | 15 +++++++
 packages/jupyter-ai/jupyter_ai/extension.py   | 30 +++++++++++++-
 .../jupyter_ai/tests/test_config_manager.py   |  2 +
 .../jupyter_ai/tests/test_extension.py        | 39 +++++++++++++++++++
 packages/jupyter-ai/pyproject.toml            |  1 +
 8 files changed, 178 insertions(+), 6 deletions(-)
 create mode 100644 packages/jupyter-ai-magics/jupyter_ai_magics/tests/test_utils.py
 create mode 100644 packages/jupyter-ai/jupyter_ai/tests/test_extension.py

diff --git a/docs/source/users/index.md b/docs/source/users/index.md
index 5ef4e4ed6..1a58ba247 100644
--- a/docs/source/users/index.md
+++ b/docs/source/users/index.md
@@ -705,3 +705,34 @@ The `--region-name` parameter is set to the [AWS region code](https://docs.aws.a
 The `--request-schema` parameter is the JSON object the endpoint expects as input, with the prompt being substituted into any value that matches the string literal `""`. For example, the request schema `{"text_inputs":""}` will submit a JSON object with the prompt stored under the `text_inputs` key.
The `--response-path` option is a [JSONPath](https://goessner.net/articles/JsonPath/index.html) string that retrieves the language model's output from the endpoint's JSON response. For example, if your endpoint returns an object with the schema `{"generated_texts":[""]}`, its response path is `generated_texts.[0]`. + + +## Configuration + +You can specify an allowlist, to only allow only a certain list of providers, or a blocklist, to block some providers. + +### Blocklisting providers +This configuration allows for blocking specific providers in the settings panel. This list takes precedence over the allowlist in the next section. + +``` +jupyter lab --AiExtension.blocked_providers=openai +``` + +To block more than one provider in the block-list, repeat the runtime configuration. + +``` +jupyter lab --AiExtension.blocked_providers=openai --AiExtension.blocked_providers=ai21 +``` + +### Allowlisting providers +This configuration allows for filtering the list of providers in the settings panel to only an allowlisted set of providers. + +``` +jupyter lab --AiExtension.allowed_providers=openai +``` + +To allow more than one provider in the allowlist, repeat the runtime configuration. + +``` +jupyter lab --AiExtension.allowed_providers=openai --AiExtension.allowed_providers=ai21 +``` diff --git a/packages/jupyter-ai-magics/jupyter_ai_magics/tests/test_utils.py b/packages/jupyter-ai-magics/jupyter_ai_magics/tests/test_utils.py new file mode 100644 index 000000000..e1c517ebe --- /dev/null +++ b/packages/jupyter-ai-magics/jupyter_ai_magics/tests/test_utils.py @@ -0,0 +1,34 @@ +# Copyright (c) Jupyter Development Team. +# Distributed under the terms of the Modified BSD License. + +import pytest +from jupyter_ai_magics.utils import get_lm_providers + +KNOWN_LM_A = "openai" +KNOWN_LM_B = "huggingface_hub" + + +@pytest.mark.parametrize( + "restrictions", + [ + {"allowed_providers": None, "blocked_providers": None}, + {"allowed_providers": [], "blocked_providers": []}, + {"allowed_providers": [], "blocked_providers": [KNOWN_LM_B]}, + {"allowed_providers": [KNOWN_LM_A], "blocked_providers": []}, + ], +) +def test_get_lm_providers_not_restricted(restrictions): + a_not_restricted = get_lm_providers(None, restrictions) + assert KNOWN_LM_A in a_not_restricted + + +@pytest.mark.parametrize( + "restrictions", + [ + {"allowed_providers": [], "blocked_providers": [KNOWN_LM_A]}, + {"allowed_providers": [KNOWN_LM_B], "blocked_providers": []}, + ], +) +def test_get_lm_providers_restricted(restrictions): + a_not_restricted = get_lm_providers(None, restrictions) + assert KNOWN_LM_A not in a_not_restricted diff --git a/packages/jupyter-ai-magics/jupyter_ai_magics/utils.py b/packages/jupyter-ai-magics/jupyter_ai_magics/utils.py index 6a02c61c8..c651581bc 100644 --- a/packages/jupyter-ai-magics/jupyter_ai_magics/utils.py +++ b/packages/jupyter-ai-magics/jupyter_ai_magics/utils.py @@ -1,5 +1,5 @@ import logging -from typing import Dict, Optional, Tuple, Type, Union +from typing import Dict, List, Literal, Optional, Tuple, Type, Union from importlib_metadata import entry_points from jupyter_ai_magics.aliases import MODEL_ID_ALIASES @@ -11,13 +11,19 @@ EmProvidersDict = Dict[str, BaseEmbeddingsProvider] AnyProvider = Union[BaseProvider, BaseEmbeddingsProvider] ProviderDict = Dict[str, AnyProvider] +ProviderRestrictions = Dict[ + Literal["allowed_providers", "blocked_providers"], Optional[List[str]] +] -def get_lm_providers(log: Optional[Logger] = None) -> LmProvidersDict: +def get_lm_providers( + log: Optional[Logger] = 
None, restrictions: Optional[ProviderRestrictions] = None +) -> LmProvidersDict: if not log: log = logging.getLogger() log.addHandler(logging.NullHandler()) - + if not restrictions: + restrictions = {"allowed_providers": None, "blocked_providers": None} providers = {} eps = entry_points() model_provider_eps = eps.select(group="jupyter_ai.model_providers") @@ -29,6 +35,9 @@ def get_lm_providers(log: Optional[Logger] = None) -> LmProvidersDict: f"Unable to load model provider class from entry point `{model_provider_ep.name}`." ) continue + if not is_provider_allowed(provider.id, restrictions): + log.info(f"Skipping blocked provider `{provider.id}`.") + continue providers[provider.id] = provider log.info(f"Registered model provider `{provider.id}`.") @@ -36,11 +45,13 @@ def get_lm_providers(log: Optional[Logger] = None) -> LmProvidersDict: def get_em_providers( - log: Optional[Logger] = None, + log: Optional[Logger] = None, restrictions: Optional[ProviderRestrictions] = None ) -> EmProvidersDict: if not log: log = logging.getLogger() log.addHandler(logging.NullHandler()) + if not restrictions: + restrictions = {"allowed_providers": None, "blocked_providers": None} providers = {} eps = entry_points() model_provider_eps = eps.select(group="jupyter_ai.embeddings_model_providers") @@ -52,6 +63,9 @@ def get_em_providers( f"Unable to load embeddings model provider class from entry point `{model_provider_ep.name}`." ) continue + if not is_provider_allowed(provider.id, restrictions): + log.info(f"Skipping blocked provider `{provider.id}`.") + continue providers[provider.id] = provider log.info(f"Registered embeddings model provider `{provider.id}`.") @@ -97,6 +111,16 @@ def get_em_provider( return _get_provider(model_id, em_providers) +def is_provider_allowed(provider_id: str, restrictions: ProviderRestrictions) -> bool: + allowed = restrictions["allowed_providers"] + blocked = restrictions["blocked_providers"] + if blocked and provider_id in blocked: + return False + if allowed and provider_id not in allowed: + return False + return True + + def _get_provider(model_id: str, providers: ProviderDict): provider_id, local_model_id = decompose_model_id(model_id, providers) provider = providers.get(provider_id, None) diff --git a/packages/jupyter-ai/jupyter_ai/config_manager.py b/packages/jupyter-ai/jupyter_ai/config_manager.py index f61e07bde..8708b4b94 100644 --- a/packages/jupyter-ai/jupyter_ai/config_manager.py +++ b/packages/jupyter-ai/jupyter_ai/config_manager.py @@ -12,8 +12,10 @@ AnyProvider, EmProvidersDict, LmProvidersDict, + ProviderRestrictions, get_em_provider, get_lm_provider, + is_provider_allowed, ) from jupyter_core.paths import jupyter_data_dir from traitlets import Integer, Unicode @@ -97,6 +99,7 @@ def __init__( log: Logger, lm_providers: LmProvidersDict, em_providers: EmProvidersDict, + restrictions: ProviderRestrictions, *args, **kwargs, ): @@ -106,6 +109,8 @@ def __init__( self._lm_providers = lm_providers """List of EM providers.""" self._em_providers = em_providers + """Provider restrictions.""" + self._restrictions = restrictions """When the server last read the config file. 
If the file was not modified after this time, then we can return the cached @@ -176,6 +181,10 @@ def _validate_config(self, config: GlobalConfig): _, lm_provider = get_lm_provider( config.model_provider_id, self._lm_providers ) + # do not check config for blocked providers + if not is_provider_allowed(config.model_provider_id, self._restrictions): + assert not lm_provider + return if not lm_provider: raise ValueError( f"No language model is associated with '{config.model_provider_id}'." @@ -187,6 +196,12 @@ def _validate_config(self, config: GlobalConfig): _, em_provider = get_em_provider( config.embeddings_provider_id, self._em_providers ) + # do not check config for blocked providers + if not is_provider_allowed( + config.embeddings_provider_id, self._restrictions + ): + assert not em_provider + return if not em_provider: raise ValueError( f"No embedding model is associated with '{config.embeddings_provider_id}'." diff --git a/packages/jupyter-ai/jupyter_ai/extension.py b/packages/jupyter-ai/jupyter_ai/extension.py index 7b6d07b31..50865ed96 100644 --- a/packages/jupyter-ai/jupyter_ai/extension.py +++ b/packages/jupyter-ai/jupyter_ai/extension.py @@ -4,6 +4,7 @@ from jupyter_ai.chat_handlers.learn import Retriever from jupyter_ai_magics.utils import get_em_providers, get_lm_providers from jupyter_server.extension.application import ExtensionApp +from traitlets import List, Unicode from .chat_handlers import ( AskChatHandler, @@ -36,11 +37,35 @@ class AiExtension(ExtensionApp): (r"api/ai/providers/embeddings?", EmbeddingsModelProviderHandler), ] + allowed_providers = List( + Unicode(), + default_value=None, + help="Identifiers of allow-listed providers. If `None`, all are allowed.", + allow_none=True, + config=True, + ) + + blocked_providers = List( + Unicode(), + default_value=None, + help="Identifiers of block-listed providers. If `None`, none are blocked.", + allow_none=True, + config=True, + ) + def initialize_settings(self): start = time.time() + restrictions = { + "allowed_providers": self.allowed_providers, + "blocked_providers": self.blocked_providers, + } - self.settings["lm_providers"] = get_lm_providers(log=self.log) - self.settings["em_providers"] = get_em_providers(log=self.log) + self.settings["lm_providers"] = get_lm_providers( + log=self.log, restrictions=restrictions + ) + self.settings["em_providers"] = get_em_providers( + log=self.log, restrictions=restrictions + ) self.settings["jai_config_manager"] = ConfigManager( # traitlets configuration, not JAI configuration. 
@@ -48,6 +73,7 @@ def initialize_settings(self): log=self.log, lm_providers=self.settings["lm_providers"], em_providers=self.settings["em_providers"], + restrictions=restrictions, ) self.log.info("Registered providers.") diff --git a/packages/jupyter-ai/jupyter_ai/tests/test_config_manager.py b/packages/jupyter-ai/jupyter_ai/tests/test_config_manager.py index f7afbd29a..8cb1808fe 100644 --- a/packages/jupyter-ai/jupyter_ai/tests/test_config_manager.py +++ b/packages/jupyter-ai/jupyter_ai/tests/test_config_manager.py @@ -36,6 +36,7 @@ def common_cm_kwargs(config_path, schema_path): "em_providers": em_providers, "config_path": config_path, "schema_path": schema_path, + "restrictions": {"allowed_providers": None, "blocked_providers": None}, } @@ -112,6 +113,7 @@ def test_init_with_existing_config( em_providers=em_providers, config_path=config_path, schema_path=schema_path, + restrictions={"allowed_providers": None, "blocked_providers": None}, ) diff --git a/packages/jupyter-ai/jupyter_ai/tests/test_extension.py b/packages/jupyter-ai/jupyter_ai/tests/test_extension.py new file mode 100644 index 000000000..d1a10df77 --- /dev/null +++ b/packages/jupyter-ai/jupyter_ai/tests/test_extension.py @@ -0,0 +1,39 @@ +# Copyright (c) Jupyter Development Team. +# Distributed under the terms of the Modified BSD License. +import pytest +from jupyter_ai.extension import AiExtension + +pytest_plugins = ["pytest_jupyter.jupyter_server"] + +KNOWN_LM_A = "openai" +KNOWN_LM_B = "huggingface_hub" + + +@pytest.mark.parametrize( + "argv", + [ + ["--AiExtension.blocked_providers", KNOWN_LM_B], + ["--AiExtension.allowed_providers", KNOWN_LM_A], + ], +) +def test_allows_providers(argv, jp_configurable_serverapp): + server = jp_configurable_serverapp(argv=argv) + ai = AiExtension() + ai._link_jupyter_server_extension(server) + ai.initialize_settings() + assert KNOWN_LM_A in ai.settings["lm_providers"] + + +@pytest.mark.parametrize( + "argv", + [ + ["--AiExtension.blocked_providers", KNOWN_LM_A], + ["--AiExtension.allowed_providers", KNOWN_LM_B], + ], +) +def test_blocks_providers(argv, jp_configurable_serverapp): + server = jp_configurable_serverapp(argv=argv) + ai = AiExtension() + ai._link_jupyter_server_extension(server) + ai.initialize_settings() + assert KNOWN_LM_A not in ai.settings["lm_providers"] diff --git a/packages/jupyter-ai/pyproject.toml b/packages/jupyter-ai/pyproject.toml index 993070c04..291775703 100644 --- a/packages/jupyter-ai/pyproject.toml +++ b/packages/jupyter-ai/pyproject.toml @@ -51,6 +51,7 @@ test = [ "pytest-asyncio", "pytest-cov", "pytest-tornasync", + "pytest-jupyter", "syrupy~=4.0.8" ]