Skip to content

Commit

Permalink
Inline completion support (jupyterlab#582)
Browse files Browse the repository at this point in the history
* Inline code completions (jupyterlab#465)

* Draft inline completions implementation (server side)

* Implement inline completion provider (front)

* Add default debounce delay and error handling (front)

* Add `gpt-3.5-turbo-instruct` because text- models are deprecated.

OpenAI specifically recommends using `gpt-3.5-turbo-instruct` in
favour of text-davinci, text-ada, etc. See:
https://platform.openai.com/docs/deprecations/

* Improve/fix prompt template and add simple post-processing

* Handle missing `registerInlineProvider`, handle no model in name

* Remove IPython mention to avoid confusing languages

* Disable suggestions in markdown, move language logic

* Remove unused background and clip path from jupyternaut

* Implement toggling the AI completer via statusbar item

also adds the icon for provider re-using jupyternaut icon

* Implement streaming support

* Translate ipython to python for models, remove log

* Move `BaseLLMHandler` to `/completions` rename to `LLMHandlerMixin`

* Move frontend completions code to `/completions`

* Make `IStatusBar` required for now, lint

* Simplify inline completion backend (jupyterlab#553)

* do not import from pydantic directly

* refactor inline completion backend

* Autocomplete frontend fixes (jupyterlab#583)

* remove duplicate definition of inline completion provider

* rename completion variables, plugins, token to be more accurate

* abbreviate JupyterAIInlineProvider => JaiInlineProvider

* bump @jupyterlab/completer and typescript

* WIP: fix Jupyter AI completion settings

* Fix issues with settings population

* read from settings directly instead of using a cache

* disable Jupyter AI completion by default

* improve completion plugin menu items

* revert unnecessary edits to package manifest

* Update packages/jupyter-ai/src/components/statusbar-item.tsx

Co-authored-by: Michał Krassowski <5832902+krassowski@users.noreply.github.com>

* tweak wording

---------

Co-authored-by: krassowski <5832902+krassowski@users.noreply.github.com>

---------

Co-authored-by: David L. Qiu <david@qiu.dev>
  • Loading branch information
2 people authored and Marchlak committed Oct 28, 2024
1 parent 66f026f commit 0ba1215
Show file tree
Hide file tree
Showing 23 changed files with 2,460 additions and 1,373 deletions.
4 changes: 3 additions & 1 deletion package.json
Original file line number Diff line number Diff line change
Expand Up @@ -38,10 +38,12 @@
"test": "lerna run test"
},
"devDependencies": {
"@jupyterlab/builder": "^4",
"lerna": "^6.4.1",
"nx": "^15.9.2"
},
"resolutions": {
"@jupyterlab/completer": "4.1.0-beta.0"
},
"nx": {
"includedScripts": []
}
Expand Down
1 change: 1 addition & 0 deletions packages/jupyter-ai-magics/jupyter_ai_magics/providers.py
Original file line number Diff line number Diff line change
Expand Up @@ -542,6 +542,7 @@ class OpenAIProvider(BaseProvider, OpenAI):
"text-curie-001",
"text-babbage-001",
"text-ada-001",
"gpt-3.5-turbo-instruct",
"davinci",
"curie",
"babbage",
Expand Down
Empty file.
Original file line number Diff line number Diff line change
@@ -0,0 +1,4 @@
from .base import BaseInlineCompletionHandler
from .default import DefaultInlineCompletionHandler

__all__ = ["BaseInlineCompletionHandler", "DefaultInlineCompletionHandler"]
169 changes: 169 additions & 0 deletions packages/jupyter-ai/jupyter_ai/completions/handlers/base.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,169 @@
import json
import time
import traceback
from asyncio import AbstractEventLoop
from typing import Any, AsyncIterator, Dict, Union

import tornado
from jupyter_ai.completions.handlers.llm_mixin import LLMHandlerMixin
from jupyter_ai.completions.models import (
CompletionError,
InlineCompletionList,
InlineCompletionReply,
InlineCompletionRequest,
InlineCompletionStreamChunk,
)
from jupyter_server.base.handlers import JupyterHandler
from langchain.pydantic_v1 import BaseModel, ValidationError


class BaseInlineCompletionHandler(
LLMHandlerMixin, JupyterHandler, tornado.websocket.WebSocketHandler
):
"""A Tornado WebSocket handler that receives inline completion requests and
fulfills them accordingly. This class is instantiated once per WebSocket
connection."""

##
# Interface for subclasses
##
async def handle_request(
self, message: InlineCompletionRequest
) -> InlineCompletionReply:
"""
Handles an inline completion request, without streaming. Subclasses
must define this method and write a reply via `self.write_message()`.
The method definition does not need to be wrapped in a try/except block.
"""
raise NotImplementedError(
"The required method `self.handle_request()` is not defined by this subclass."
)

async def handle_stream_request(
self, message: InlineCompletionRequest
) -> AsyncIterator[InlineCompletionStreamChunk]:
"""
Handles an inline completion request, **with streaming**.
Implementations may optionally define this method. Implementations that
do so should stream replies via successive calls to
`self.write_message()`.
The method definition does not need to be wrapped in a try/except block.
"""
raise NotImplementedError(
"The optional method `self.handle_stream_request()` is not defined by this subclass."
)

##
# Definition of base class
##
handler_kind = "completion"

@property
def loop(self) -> AbstractEventLoop:
return self.settings["jai_event_loop"]

def write_message(self, message: Union[bytes, str, Dict[str, Any], BaseModel]):
"""
Write a bytes, string, dict, or Pydantic model object to the WebSocket
connection. The base definition of this method is provided by Tornado.
"""
if isinstance(message, BaseModel):
message = message.dict()

super().write_message(message)

def initialize(self):
self.log.debug("Initializing websocket connection %s", self.request.path)

def pre_get(self):
"""Handles authentication/authorization."""
# authenticate the request before opening the websocket
user = self.current_user
if user is None:
self.log.warning("Couldn't authenticate WebSocket connection")
raise tornado.web.HTTPError(403)

# authorize the user.
if not self.authorizer.is_authorized(self, user, "execute", "events"):
raise tornado.web.HTTPError(403)

async def get(self, *args, **kwargs):
"""Get an event socket."""
self.pre_get()
res = super().get(*args, **kwargs)
await res

async def on_message(self, message):
"""Public Tornado method that is called when the client sends a message
over this connection. This should **not** be overriden by subclasses."""

# first, verify that the message is an `InlineCompletionRequest`.
self.log.debug("Message received: %s", message)
try:
message = json.loads(message)
request = InlineCompletionRequest(**message)
except ValidationError as e:
self.log.error(e)
return

# next, dispatch the request to the correct handler and create the
# `handle_request` coroutine object
handle_request = None
if request.stream:
try:
handle_request = self._handle_stream_request(request)
except NotImplementedError:
self.log.error(
"Unable to handle stream request. The current `InlineCompletionHandler` does not implement the `handle_stream_request()` method."
)
return

else:
handle_request = self._handle_request(request)

# finally, wrap `handle_request` in an exception handler, and start the
# task on the event loop.
async def handle_request_and_catch():
try:
await handle_request
except Exception as e:
await self.handle_exc(e, request)

self.loop.create_task(handle_request_and_catch())

async def handle_exc(self, e: Exception, request: InlineCompletionRequest):
"""
Handles an exception raised in either `handle_request()` or
`handle_stream_request()`. This base class provides a default
implementation, which may be overriden by subclasses.
"""
error = CompletionError(
type=e.__class__.__name__,
title=e.args[0] if e.args else "Exception",
traceback=traceback.format_exc(),
)
self.write_message(
InlineCompletionReply(
list=InlineCompletionList(items=[]),
error=error,
reply_to=request.number,
)
)

async def _handle_request(self, request: InlineCompletionRequest):
"""Private wrapper around `self.handle_request()`."""
start = time.time()
await self.handle_request(request)
latency_ms = round((time.time() - start) * 1000)
self.log.info(f"Inline completion handler resolved in {latency_ms} ms.")

async def _handle_stream_request(self, request: InlineCompletionRequest):
"""Private wrapper around `self.handle_stream_request()`."""
start = time.time()
await self._handle_stream_request(request)
async for chunk in self.stream(request):
self.write_message(chunk.dict())
latency_ms = round((time.time() - start) * 1000)
self.log.info(f"Inline completion streaming completed in {latency_ms} ms.")
192 changes: 192 additions & 0 deletions packages/jupyter-ai/jupyter_ai/completions/handlers/default.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,192 @@
from typing import Dict, Type

from jupyter_ai_magics.providers import BaseProvider
from langchain.prompts import (
ChatPromptTemplate,
HumanMessagePromptTemplate,
PromptTemplate,
SystemMessagePromptTemplate,
)
from langchain.schema.output_parser import StrOutputParser
from langchain.schema.runnable import Runnable

from ..models import (
InlineCompletionList,
InlineCompletionReply,
InlineCompletionRequest,
InlineCompletionStreamChunk,
)
from .base import BaseInlineCompletionHandler

SYSTEM_PROMPT = """
You are an application built to provide helpful code completion suggestions.
You should only produce code. Keep comments to minimum, use the
programming language comment syntax. Produce clean code.
The code is written in JupyterLab, a data analysis and code development
environment which can execute code extended with additional syntax for
interactive features, such as magics.
""".strip()

AFTER_TEMPLATE = """
The code after the completion request is:
```
{suffix}
```
""".strip()

DEFAULT_TEMPLATE = """
The document is called `{filename}` and written in {language}.
{after}
Complete the following code:
```
{prefix}"""


class DefaultInlineCompletionHandler(BaseInlineCompletionHandler):
llm_chain: Runnable

def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)

def create_llm_chain(
self, provider: Type[BaseProvider], provider_params: Dict[str, str]
):
model_parameters = self.get_model_parameters(provider, provider_params)
llm = provider(**provider_params, **model_parameters)

if llm.is_chat_provider:
prompt_template = ChatPromptTemplate.from_messages(
[
SystemMessagePromptTemplate.from_template(SYSTEM_PROMPT),
HumanMessagePromptTemplate.from_template(DEFAULT_TEMPLATE),
]
)
else:
prompt_template = PromptTemplate(
input_variables=["prefix", "suffix", "language", "filename"],
template=SYSTEM_PROMPT + "\n\n" + DEFAULT_TEMPLATE,
)

self.llm = llm
self.llm_chain = prompt_template | llm | StrOutputParser()

async def handle_request(
self, request: InlineCompletionRequest
) -> InlineCompletionReply:
"""Handles an inline completion request without streaming."""
self.get_llm_chain()
model_arguments = self._template_inputs_from_request(request)
suggestion = await self.llm_chain.ainvoke(input=model_arguments)
suggestion = self._post_process_suggestion(suggestion, request)
self.write_message(
InlineCompletionReply(
list=InlineCompletionList(items=[{"insertText": suggestion}]),
reply_to=request.number,
)
)

def _write_incomplete_reply(self, request: InlineCompletionRequest):
"""Writes an incomplete `InlineCompletionReply`, indicating to the
client that LLM output is about to streamed across this connection.
Should be called first in `self.handle_stream_request()`."""

token = self._token_from_request(request, 0)
reply = InlineCompletionReply(
list=InlineCompletionList(
items=[
{
# insert text starts empty as we do not pre-generate any part
"insertText": "",
"isIncomplete": True,
"token": token,
}
]
),
reply_to=request.number,
)
self.write_message(reply)

async def handle_stream_request(self, request: InlineCompletionRequest):
# first, send empty initial reply.
self._write_incomplete_reply()

# then, generate and stream LLM output over this connection.
self.get_llm_chain()
token = self._token_from_request(request, 0)
model_arguments = self._template_inputs_from_request(request)
suggestion = ""

async for fragment in self.llm_chain.astream(input=model_arguments):
suggestion += fragment
if suggestion.startswith("```"):
if "\n" not in suggestion:
# we are not ready to apply post-processing
continue
else:
suggestion = self._post_process_suggestion(suggestion, request)
self.write_message(
InlineCompletionStreamChunk(
type="stream",
response={"insertText": suggestion, "token": token},
reply_to=request.number,
done=False,
)
)

# finally, send a message confirming that we are done
self.write_message(
InlineCompletionStreamChunk(
type="stream",
response={"insertText": suggestion, "token": token},
reply_to=request.number,
done=True,
)
)

def _token_from_request(self, request: InlineCompletionRequest, suggestion: int):
"""Generate a deterministic token (for matching streamed messages)
using request number and suggestion number"""
return f"t{request.number}s{suggestion}"

def _template_inputs_from_request(self, request: InlineCompletionRequest) -> Dict:
suffix = request.suffix.strip()
# only add the suffix template if the suffix is there to save input tokens/computation time
after = AFTER_TEMPLATE.format(suffix=suffix) if suffix else ""
filename = request.path.split("/")[-1] if request.path else "untitled"

return {
"prefix": request.prefix,
"after": after,
"language": request.language,
"filename": filename,
"stop": ["\n```"],
}

def _post_process_suggestion(
self, suggestion: str, request: InlineCompletionRequest
) -> str:
"""Remove spurious fragments from the suggestion.
While most models (especially instruct and infill models do not require
any pre-processing, some models such as gpt-4 which only have chat APIs
may require removing spurious fragments. This function uses heuristics
and request data to remove such fragments.
"""
# gpt-4 tends to add "```python" or similar
language = request.language or "python"
markdown_identifiers = {"ipython": ["ipython", "python", "py"]}
bad_openings = [
f"```{identifier}"
for identifier in markdown_identifiers.get(language, [language])
] + ["```"]
for opening in bad_openings:
if suggestion.startswith(opening):
suggestion = suggestion[len(opening) :].lstrip()
# check for the prefix inclusion (only if there was a bad opening)
if suggestion.startswith(request.prefix):
suggestion = suggestion[len(request.prefix) :]
break
return suggestion
Loading

0 comments on commit 0ba1215

Please sign in to comment.