Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

make selections more robust #54

Merged
merged 1 commit into from
Apr 12, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
13 changes: 3 additions & 10 deletions packages/jupyter-ai/src/commands.ts
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@ import { AiService } from './handler';
import { OpenTaskDialog } from './components/open-task-dialog';
import { ClosableDialog } from './widgets/closable-dialog';
import { InsertionContext, insertOutput } from './inserter';
import { getTextSelection, getEditor } from './utils';
import { getTextSelection, getEditor, getCellIndex } from './utils';

/**
* Creates a placeholder markdown cell either above/below the currently active
Expand Down Expand Up @@ -53,7 +53,7 @@ function insertPlaceholderCell(
* Replaces a cell with a markdown cell containing a string.
*/
function replaceWithMarkdown(notebook: Notebook, cellId: string, body: string) {
const cellIdx = findIndex(notebook, cellId);
const cellIdx = getCellIndex(notebook, cellId);
if (cellIdx === -1) {
return;
}
Expand All @@ -64,15 +64,8 @@ function replaceWithMarkdown(notebook: Notebook, cellId: string, body: string) {
NotebookActions.run(notebook);
}

function findIndex(notebook: Notebook, id: string): number {
const idx = notebook.model?.sharedModel.cells.findIndex(
cell => cell.getId() === id
);
return idx === undefined ? -1 : idx;
}

function deleteCell(notebook: Notebook, id: string): void {
const idx = findIndex(notebook, id);
const idx = getCellIndex(notebook, id);
if (idx !== -1) {
notebook.model?.sharedModel.deleteCell(idx);
}
Expand Down
11 changes: 8 additions & 3 deletions packages/jupyter-ai/src/components/chat.tsx
Original file line number Diff line number Diff line change
Expand Up @@ -102,8 +102,13 @@ function ChatBody(): JSX.Element {
setLoading(false);
}

if (replaceSelection) {
replaceSelectionFn(response.output);
if (replaceSelection && selection) {
const { cellId, ...selectionProps } = selection;
replaceSelectionFn({
...selectionProps,
...(cellId && { cellId }),
text: response.output
});
}
setMessageGroups(messageGroups => [
...messageGroups,
Expand Down Expand Up @@ -147,7 +152,7 @@ function ChatBody(): JSX.Element {
value={input}
onChange={setInput}
onSend={onSend}
hasSelection={!!selection}
hasSelection={!!selection?.text}
includeSelection={includeSelection}
toggleIncludeSelection={() =>
setIncludeSelection(includeSelection => !includeSelection)
Expand Down
10 changes: 5 additions & 5 deletions packages/jupyter-ai/src/contexts/selection-context.tsx
Original file line number Diff line number Diff line change
@@ -1,10 +1,10 @@
import React, { useCallback, useContext, useEffect, useState } from 'react';
import { SelectionWatcher } from '../selection-watcher';
import { Selection, SelectionWatcher } from '../selection-watcher';

const SelectionContext = React.createContext<
[string, (value: string) => unknown]
[Selection | null, (value: Selection) => unknown]
>([
'',
null,
() => {
/* noop */
}
Expand All @@ -23,7 +23,7 @@ export function SelectionContextProvider({
selectionWatcher,
children
}: SelectionContextProviderProps) {
const [selection, setSelection] = useState('');
const [selection, setSelection] = useState<Selection | null>(null);

/**
* Effect: subscribe to SelectionWatcher
Expand All @@ -35,7 +35,7 @@ export function SelectionContextProvider({
}, []);

const replaceSelection = useCallback(
(value: string) => {
(value: Selection) => {
selectionWatcher.replaceSelection(value);
},
[selectionWatcher]
Expand Down
142 changes: 129 additions & 13 deletions packages/jupyter-ai/src/selection-watcher.ts
Original file line number Diff line number Diff line change
@@ -1,52 +1,168 @@
import { JupyterFrontEnd, LabShell } from '@jupyterlab/application';
import { DocumentWidget } from '@jupyterlab/docregistry';
import { CodeEditor } from '@jupyterlab/codeeditor';
import { CodeMirrorEditor } from '@jupyterlab/codemirror';
import { FileEditor } from '@jupyterlab/fileeditor';
import { Notebook } from '@jupyterlab/notebook';

import { find } from '@lumino/algorithm';
import { Widget } from '@lumino/widgets';
import { Signal } from '@lumino/signaling';

import { getEditor, getTextSelection } from './utils';
import { getCellIndex } from './utils';

/**
* Gets the editor instance used by a document widget. Returns `null` if unable.
*/
function getEditor(widget: Widget | null) {
if (!(widget instanceof DocumentWidget)) {
return null;
}

let editor: CodeEditor.IEditor | undefined;
const { content } = widget;

if (content instanceof FileEditor) {
editor = content.editor;
} else if (content instanceof Notebook) {
editor = content.activeCell?.editor;
}

if (!(editor instanceof CodeMirrorEditor)) {
return undefined;
}

return editor;
}

/**
* Gets a Selection object from a document widget. Returns `null` if unable.
*/
function getTextSelection(widget: Widget | null): Selection | null {
const editor = getEditor(widget);
// widget type check is redundant but hints the type to TypeScript
if (!editor || !(widget instanceof DocumentWidget)) {
return null;
}

let cellId: string | undefined = undefined;
if (widget.content instanceof Notebook) {
cellId = widget.content.activeCell?.model.id;
}

let { start, end, ...selectionObj } = editor.getSelection();
const startOffset = editor.getOffsetAt(start);
const endOffset = editor.getOffsetAt(end);
const text = editor.model.sharedModel
.getSource()
.substring(startOffset, endOffset);

// ensure start <= end
// required for editor.model.sharedModel.updateSource()
if (startOffset > endOffset) {
[start, end] = [end, start];
}

return {
...selectionObj,
start,
end,
text,
widgetId: widget.id,
...(cellId && {
cellId
})
};
}

export type Selection = CodeEditor.ITextSelection & {
/**
* The text within the selection as a string.
*/
text: string;
/**
* The ID of the document widget in which the selection was made.
*/
widgetId: string;
/**
* The ID of the cell in which the selection was made, if the original widget
* was a notebook.
*/
cellId?: string;
};

export class SelectionWatcher {
constructor(shell: JupyterFrontEnd.IShell) {
if (!(shell instanceof LabShell)) {
throw 'Shell is not an instance of LabShell. Jupyter AI does not currently support custom shells.';
}

shell.currentChanged.connect((sender, args) => {
this._shell = shell;
this._shell.currentChanged.connect((sender, args) => {
this._mainAreaWidget = args.newValue;
});

setInterval(this._poll.bind(this), 200);
}

get selectionChanged() {
return this._selectionChanged;
}

replaceSelection(value: string) {
if (!(this._mainAreaWidget instanceof DocumentWidget)) {
replaceSelection(selection: Selection) {
// unfortunately shell.currentWidget doesn't update synchronously after
// shell.activateById(), which is why we have to get a reference to the
// widget manually.
const widget = find(
this._shell.widgets(),
widget => widget.id === selection.widgetId
);
if (!(widget instanceof DocumentWidget)) {
return;
}

const editor = getEditor(this._mainAreaWidget.content);
editor?.replaceSelection?.(value);
}
// activate the widget if not already active
this._shell.activateById(selection.widgetId);

protected _poll() {
if (!(this._mainAreaWidget instanceof DocumentWidget)) {
// activate notebook cell if specified
if (widget.content instanceof Notebook && selection.cellId) {
const cellIndex = getCellIndex(widget.content, selection.cellId);
if (cellIndex !== -1) {
widget.content.activeCellIndex = cellIndex;
}
}

// get editor instance
const editor = getEditor(widget);
if (!editor) {
return;
}

editor.model.sharedModel.updateSource(
editor.getOffsetAt(selection.start),
editor.getOffsetAt(selection.end),
selection.text
);
const newPosition = editor.getPositionAt(
editor.getOffsetAt(selection.start) + selection.text.length
);
editor.setSelection({ start: newPosition, end: newPosition });
}

protected _poll() {
const prevSelection = this._selection;
const currSelection = getTextSelection(this._mainAreaWidget.content);
const currSelection = getTextSelection(this._mainAreaWidget);

if (prevSelection === currSelection) {
if (prevSelection?.text === currSelection?.text) {
return;
}

this._selection = currSelection;
this._selectionChanged.emit(currSelection);
}

protected _shell: LabShell;
protected _mainAreaWidget: Widget | null = null;
protected _selection = '';
protected _selectionChanged = new Signal<this, string>(this);
protected _selection: Selection | null = null;
protected _selectionChanged = new Signal<this, Selection | null>(this);
}
10 changes: 10 additions & 0 deletions packages/jupyter-ai/src/utils.ts
Original file line number Diff line number Diff line change
Expand Up @@ -36,3 +36,13 @@ export function getEditor(widget: Widget): CodeEditor.IEditor | undefined {

return editor;
}

/**
* Gets the index of the cell associated with `cellId`.
*/
export function getCellIndex(notebook: Notebook, cellId: string): number {
const idx = notebook.model?.sharedModel.cells.findIndex(
cell => cell.getId() === cellId
);
return idx === undefined ? -1 : idx;
}