Skip to content

Commit

Permalink
Inline code completions (#465)
Browse files Browse the repository at this point in the history
* Draft inline completions implementation (server side)

* Implement inline completion provider (front)

* Add default debounce delay and error handling (front)

* Add `gpt-3.5-turbo-instruct` because text- models are deprecated.

OpenAI specifically recommends using `gpt-3.5-turbo-instruct` in
favour of text-davinci, text-ada, etc. See:
https://platform.openai.com/docs/deprecations/

* Improve/fix prompt template and add simple post-processing

* Handle missing `registerInlineProvider`, handle no model in name

* Remove IPython mention to avoid confusing languages

* Disable suggestions in markdown, move language logic

* Remove unused background and clip path from jupyternaut

* Implement toggling the AI completer via statusbar item

also adds the icon for provider re-using jupyternaut icon

* Implement streaming support

* Translate ipython to python for models, remove log

* Move `BaseLLMHandler` to `/completions` rename to `LLMHandlerMixin`

* Move frontend completions code to `/completions`

* Make `IStatusBar` required for now, lint
  • Loading branch information
dlqqq committed Jan 19, 2024
1 parent f56496c commit 7050201
Show file tree
Hide file tree
Showing 22 changed files with 1,934 additions and 1,029 deletions.
1 change: 1 addition & 0 deletions packages/jupyter-ai-magics/jupyter_ai_magics/providers.py
Original file line number Diff line number Diff line change
Expand Up @@ -542,6 +542,7 @@ class OpenAIProvider(BaseProvider, OpenAI):
"text-curie-001",
"text-babbage-001",
"text-ada-001",
"gpt-3.5-turbo-instruct",
"davinci",
"curie",
"babbage",
Expand Down
Empty file.
Original file line number Diff line number Diff line change
@@ -0,0 +1,4 @@
from .base import BaseInlineCompletionHandler
from .default import DefaultInlineCompletionHandler

__all__ = ["BaseInlineCompletionHandler", "DefaultInlineCompletionHandler"]
76 changes: 76 additions & 0 deletions packages/jupyter-ai/jupyter_ai/completions/handlers/base.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,76 @@
import traceback

# necessary to prevent circular import
from typing import TYPE_CHECKING, AsyncIterator, Dict

from jupyter_ai.completions.handlers.llm_mixin import LLMHandlerMixin
from jupyter_ai.completions.models import (
CompletionError,
InlineCompletionList,
InlineCompletionReply,
InlineCompletionRequest,
InlineCompletionStreamChunk,
ModelChangedNotification,
)
from jupyter_ai.config_manager import ConfigManager, Logger

if TYPE_CHECKING:
from jupyter_ai.handlers import InlineCompletionHandler


class BaseInlineCompletionHandler(LLMHandlerMixin):
"""Class implementing completion handling."""

handler_kind = "completion"

def __init__(
self,
log: Logger,
config_manager: ConfigManager,
model_parameters: Dict[str, Dict],
ws_sessions: Dict[str, "InlineCompletionHandler"],
):
super().__init__(log, config_manager, model_parameters)
self.ws_sessions = ws_sessions

async def on_message(
self, message: InlineCompletionRequest
) -> InlineCompletionReply:
try:
return await self.process_message(message)
except Exception as e:
return await self._handle_exc(e, message)

async def process_message(
self, message: InlineCompletionRequest
) -> InlineCompletionReply:
"""
Processes an inline completion request. Completion handlers
(subclasses) must implement this method.
The method definition does not need to be wrapped in a try/except block.
"""
raise NotImplementedError("Should be implemented by subclasses.")

async def stream(
self, message: InlineCompletionRequest
) -> AsyncIterator[InlineCompletionStreamChunk]:
""" "
Stream the inline completion as it is generated. Completion handlers
(subclasses) can implement this method.
"""
raise NotImplementedError()

async def _handle_exc(self, e: Exception, message: InlineCompletionRequest):
error = CompletionError(
type=e.__class__.__name__,
title=e.args[0] if e.args else "Exception",
traceback=traceback.format_exc(),
)
return InlineCompletionReply(
list=InlineCompletionList(items=[]), error=error, reply_to=message.number
)

def broadcast(self, message: ModelChangedNotification):
for session in self.ws_sessions.values():
session.write_message(message.dict())
184 changes: 184 additions & 0 deletions packages/jupyter-ai/jupyter_ai/completions/handlers/default.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,184 @@
from typing import Dict, Type

from jupyter_ai_magics.providers import BaseProvider
from langchain.prompts import (
ChatPromptTemplate,
HumanMessagePromptTemplate,
PromptTemplate,
SystemMessagePromptTemplate,
)
from langchain.schema.output_parser import StrOutputParser
from langchain.schema.runnable import Runnable

from ..models import (
InlineCompletionList,
InlineCompletionReply,
InlineCompletionRequest,
InlineCompletionStreamChunk,
ModelChangedNotification,
)
from .base import BaseInlineCompletionHandler

SYSTEM_PROMPT = """
You are an application built to provide helpful code completion suggestions.
You should only produce code. Keep comments to minimum, use the
programming language comment syntax. Produce clean code.
The code is written in JupyterLab, a data analysis and code development
environment which can execute code extended with additional syntax for
interactive features, such as magics.
""".strip()

AFTER_TEMPLATE = """
The code after the completion request is:
```
{suffix}
```
""".strip()

DEFAULT_TEMPLATE = """
The document is called `{filename}` and written in {language}.
{after}
Complete the following code:
```
{prefix}"""


class DefaultInlineCompletionHandler(BaseInlineCompletionHandler):
llm_chain: Runnable

def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)

def create_llm_chain(
self, provider: Type[BaseProvider], provider_params: Dict[str, str]
):
lm_provider = self.config_manager.lm_provider
lm_provider_params = self.config_manager.lm_provider_params
next_lm_id = (
f'{lm_provider.id}:{lm_provider_params["model_id"]}'
if lm_provider
else None
)
self.broadcast(ModelChangedNotification(model=next_lm_id))

model_parameters = self.get_model_parameters(provider, provider_params)
llm = provider(**provider_params, **model_parameters)

if llm.is_chat_provider:
prompt_template = ChatPromptTemplate.from_messages(
[
SystemMessagePromptTemplate.from_template(SYSTEM_PROMPT),
HumanMessagePromptTemplate.from_template(DEFAULT_TEMPLATE),
]
)
else:
prompt_template = PromptTemplate(
input_variables=["prefix", "suffix", "language", "filename"],
template=SYSTEM_PROMPT + "\n\n" + DEFAULT_TEMPLATE,
)

self.llm = llm
self.llm_chain = prompt_template | llm | StrOutputParser()

async def process_message(
self, request: InlineCompletionRequest
) -> InlineCompletionReply:
if request.stream:
token = self._token_from_request(request, 0)
return InlineCompletionReply(
list=InlineCompletionList(
items=[
{
# insert text starts empty as we do not pre-generate any part
"insertText": "",
"isIncomplete": True,
"token": token,
}
]
),
reply_to=request.number,
)
else:
self.get_llm_chain()
model_arguments = self._template_inputs_from_request(request)
suggestion = await self.llm_chain.ainvoke(input=model_arguments)
suggestion = self._post_process_suggestion(suggestion, request)
return InlineCompletionReply(
list=InlineCompletionList(items=[{"insertText": suggestion}]),
reply_to=request.number,
)

async def stream(self, request: InlineCompletionRequest):
self.get_llm_chain()
token = self._token_from_request(request, 0)
model_arguments = self._template_inputs_from_request(request)
suggestion = ""
async for fragment in self.llm_chain.astream(input=model_arguments):
suggestion += fragment
if suggestion.startswith("```"):
if "\n" not in suggestion:
# we are not ready to apply post-processing
continue
else:
suggestion = self._post_process_suggestion(suggestion, request)
yield InlineCompletionStreamChunk(
type="stream",
response={"insertText": suggestion, "token": token},
reply_to=request.number,
done=False,
)
# at the end send a message confirming that we are done
yield InlineCompletionStreamChunk(
type="stream",
response={"insertText": suggestion, "token": token},
reply_to=request.number,
done=True,
)

def _token_from_request(self, request: InlineCompletionRequest, suggestion: int):
"""Generate a deterministic token (for matching streamed messages)
using request number and suggestion number"""
return f"t{request.number}s{suggestion}"

def _template_inputs_from_request(self, request: InlineCompletionRequest) -> Dict:
suffix = request.suffix.strip()
# only add the suffix template if the suffix is there to save input tokens/computation time
after = AFTER_TEMPLATE.format(suffix=suffix) if suffix else ""
filename = request.path.split("/")[-1] if request.path else "untitled"

return {
"prefix": request.prefix,
"after": after,
"language": request.language,
"filename": filename,
"stop": ["\n```"],
}

def _post_process_suggestion(
self, suggestion: str, request: InlineCompletionRequest
) -> str:
"""Remove spurious fragments from the suggestion.
While most models (especially instruct and infill models do not require
any pre-processing, some models such as gpt-4 which only have chat APIs
may require removing spurious fragments. This function uses heuristics
and request data to remove such fragments.
"""
# gpt-4 tends to add "```python" or similar
language = request.language or "python"
markdown_identifiers = {"ipython": ["ipython", "python", "py"]}
bad_openings = [
f"```{identifier}"
for identifier in markdown_identifiers.get(language, [language])
] + ["```"]
for opening in bad_openings:
if suggestion.startswith(opening):
suggestion = suggestion[len(opening) :].lstrip()
# check for the prefix inclusion (only if there was a bad opening)
if suggestion.startswith(request.prefix):
suggestion = suggestion[len(request.prefix) :]
break
return suggestion
73 changes: 73 additions & 0 deletions packages/jupyter-ai/jupyter_ai/completions/handlers/llm_mixin.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,73 @@
from typing import Dict, Type

from jupyter_ai.config_manager import ConfigManager, Logger
from jupyter_ai_magics.providers import BaseProvider


class LLMHandlerMixin:
"""Base class containing shared methods and attributes used by LLM handler classes."""

# This could be used to derive `BaseChatHandler` too (there is a lot of duplication!),
# but it was decided against it to avoid introducing conflicts for backports against 1.x

handler_kind: str

def __init__(
self,
log: Logger,
config_manager: ConfigManager,
model_parameters: Dict[str, Dict],
):
self.log = log
self.config_manager = config_manager
self.model_parameters = model_parameters
self.llm = None
self.llm_params = None
self.llm_chain = None

def model_changed_callback(self):
"""Method which can be overridden in sub-classes to listen to model change."""
pass

def get_llm_chain(self):
lm_provider = self.config_manager.lm_provider
lm_provider_params = self.config_manager.lm_provider_params

curr_lm_id = (
f'{self.llm.id}:{lm_provider_params["model_id"]}' if self.llm else None
)
next_lm_id = (
f'{lm_provider.id}:{lm_provider_params["model_id"]}'
if lm_provider
else None
)

if not lm_provider or not lm_provider_params:
return None

if curr_lm_id != next_lm_id:
self.log.info(
f"Switching {self.handler_kind} language model from {curr_lm_id} to {next_lm_id}."
)
self.create_llm_chain(lm_provider, lm_provider_params)
self.model_changed_callback()
elif self.llm_params != lm_provider_params:
self.log.info(
f"{self.handler_kind} model params changed, updating the llm chain."
)
self.create_llm_chain(lm_provider, lm_provider_params)

self.llm_params = lm_provider_params
return self.llm_chain

def get_model_parameters(
self, provider: Type[BaseProvider], provider_params: Dict[str, str]
):
return self.model_parameters.get(
f"{provider.id}:{provider_params['model_id']}", {}
)

def create_llm_chain(
self, provider: Type[BaseProvider], provider_params: Dict[str, str]
):
raise NotImplementedError("Should be implemented by subclasses")
Loading

0 comments on commit 7050201

Please sign in to comment.