Skip to content

Commit

Permalink
Magic tweaks (jupyterlab#31)
Browse files Browse the repository at this point in the history
* suppress warning when using OpenAIChat models

* do not call IPython.display.display() when returning output

* add raw format mode
  • Loading branch information
dlqqq authored Apr 5, 2023
1 parent 512fa74 commit 6929e82
Showing 1 changed file with 16 additions and 6 deletions.
22 changes: 16 additions & 6 deletions packages/jupyter-ai/jupyter_ai/magics.py
Original file line number Diff line number Diff line change
@@ -1,13 +1,15 @@
import os
import json
import warnings
from typing import Optional
from importlib_metadata import entry_points

from jupyter_ai.providers import BaseProvider
from importlib_metadata import entry_points
from IPython import get_ipython
from IPython.core.magic import Magics, magics_class, line_cell_magic
from IPython.core.magic_arguments import magic_arguments, argument, parse_argstring
from IPython.display import display, HTML, Markdown, Math, JSON
from IPython.display import HTML, Markdown, Math, JSON

from jupyter_ai.providers import BaseProvider


MODEL_ID_ALIASES = {
Expand All @@ -23,6 +25,7 @@
"math": Math,
"md": Markdown,
"json": JSON,
"raw": None
}

class FormatDict(dict):
Expand All @@ -41,6 +44,11 @@ def __init__(self, shell):
super(AiMagics, self).__init__(shell)
self.transcript_openai = []

# suppress warning when using old OpenAIChat provider
warnings.filterwarnings("ignore", message="You are trying to use a chat model. This way of initializing it is "
"no longer supported. Instead, please use: "
"`from langchain.chat_models import ChatOpenAI`")

# load model providers from entry point
self.providers = {}
eps = entry_points()
Expand Down Expand Up @@ -97,7 +105,7 @@ def _get_provider(self, provider_id: Optional[str]) -> BaseProvider:
optionally prefixed with the ID of the model provider, delimited
by a colon.""")
@argument('-f', '--format',
choices=["markdown", "html", "json", "math", "md"],
choices=["markdown", "html", "json", "math", "md", "raw"],
nargs="?",
default="markdown",
help="""IPython display to use when rendering output. [default="markdown"]""")
Expand All @@ -124,7 +132,7 @@ def ai(self, line, cell=None):
provider_id, local_model_id = self._decompose_model_id(args.model_id)
Provider = self._get_provider(provider_id)
if Provider is None:
return display(f"Cannot determine model provider from model ID {args.model_id}.")
return f"Cannot determine model provider from model ID {args.model_id}."

# if `--reset` is specified, reset transcript and return early
if (provider_id == "openai-chat" and args.reset):
Expand Down Expand Up @@ -162,10 +170,12 @@ def ai(self, line, cell=None):

# build output display
DisplayClass = DISPLAYS_BY_FORMAT[args.format]
if DisplayClass is None:
return output
if args.format == 'json':
# JSON display expects a dict, not a JSON string
output = json.loads(output)
output_display = DisplayClass(output)

# finally, display output display
return display(output_display)
return output_display

0 comments on commit 6929e82

Please sign in to comment.