From 6f5c924af4d5e2ebd5555d041ba904a0d5345ea2 Mon Sep 17 00:00:00 2001 From: "Visal .In" Date: Mon, 30 Oct 2023 09:29:12 +0700 Subject: [PATCH] feat: Improve auto complete (#183) * improve auto complete * escape id if it is conflicted with keywords * feat: improve the auto complete in the middle of the statement --- .../CodeEditor/SchemaCompletionTree.ts | 66 ++++++- .../components/CodeEditor/SqlCodeEditor.tsx | 18 +- .../CodeEditor/autocomplete_test_utils.ts | 140 +++++++++++++++ .../CodeEditor/handleAutoComplete.test.ts | 163 ++++-------------- .../CodeEditor/handleCustomSqlAutoComplete.ts | 61 +++---- 5 files changed, 273 insertions(+), 175 deletions(-) create mode 100644 src/renderer/components/CodeEditor/autocomplete_test_utils.ts diff --git a/src/renderer/components/CodeEditor/SchemaCompletionTree.ts b/src/renderer/components/CodeEditor/SchemaCompletionTree.ts index 7a1c80a..42fb1ac 100644 --- a/src/renderer/components/CodeEditor/SchemaCompletionTree.ts +++ b/src/renderer/components/CodeEditor/SchemaCompletionTree.ts @@ -1,16 +1,22 @@ import { Completion } from '@codemirror/autocomplete'; +import { SQLDialectSpec } from 'language/dist'; import { DatabaseSchema, DatabaseSchemaList, TableSchema, } from 'types/SqlSchema'; -function buildTableCompletionTree(table: TableSchema): SchemaCompletionTree { +function buildTableCompletionTree( + table: TableSchema, + dialect: SQLDialectSpec, + keywords: Record, +): SchemaCompletionTree { const root = new SchemaCompletionTree(); for (const col of Object.values(table.columns)) { root.addOption(col.name, { label: col.name, + apply: escapeConflictedId(dialect, col.name, keywords), type: 'property', detail: col.dataType, boost: 3, @@ -21,7 +27,9 @@ function buildTableCompletionTree(table: TableSchema): SchemaCompletionTree { } function buildDatabaseCompletionTree( - database: DatabaseSchema + database: DatabaseSchema, + dialect: SQLDialectSpec, + keywords: Record, ): SchemaCompletionTree { const root = new SchemaCompletionTree(); @@ -33,15 +41,30 @@ function buildDatabaseCompletionTree( boost: 1, }); - root.addChild(table.name, buildTableCompletionTree(table)); + root.addChild( + table.name, + buildTableCompletionTree(table, dialect, keywords), + ); } return root; } +function escapeConflictedId( + dialect: SQLDialectSpec, + label: string, + keywords: Record, +): string { + if (keywords[label.toLowerCase()]) + return `${dialect.identifierQuotes}${label}${dialect.identifierQuotes}`; + return label; +} + function buildCompletionTree( schema: DatabaseSchemaList | undefined, - currentDatabase: string | undefined + currentDatabase: string | undefined, + dialect: SQLDialectSpec, + keywords: Record, ): SchemaCompletionTree { const root: SchemaCompletionTree = new SchemaCompletionTree(); if (!schema) return root; @@ -51,23 +74,31 @@ function buildCompletionTree( for (const table of Object.values(schema[currentDatabase].tables)) { root.addOption(table.name, { label: table.name, + apply: escapeConflictedId(dialect, table.name, keywords), type: 'table', detail: 'table', boost: 1, }); - root.addChild(table.name, buildTableCompletionTree(table)); + root.addChild( + table.name, + buildTableCompletionTree(table, dialect, keywords), + ); } } for (const database of Object.values(schema)) { root.addOption(database.name, { label: database.name, + apply: escapeConflictedId(dialect, database.name, keywords), type: 'property', detail: 'database', }); - root.addChild(database.name, buildDatabaseCompletionTree(database)); + root.addChild( + database.name, + buildDatabaseCompletionTree(database, dialect, keywords), + ); } return root; @@ -75,12 +106,31 @@ function buildCompletionTree( export default class SchemaCompletionTree { protected options: Record = {}; protected child: Record = {}; + protected keywords: Record = {}; static build( schema: DatabaseSchemaList | undefined, - currentDatabase: string | undefined + currentDatabase: string | undefined, + dialect: SQLDialectSpec, ) { - return buildCompletionTree(schema, currentDatabase); + const keywords = (dialect.keywords + ' ' + dialect.builtin) + .split(' ') + .filter(Boolean) + .map((s) => s.toLowerCase()); + + const keywordDict = keywords.reduce( + (a, keyword) => { + a[keyword] = true; + return a; + }, + {} as Record, + ); + + return buildCompletionTree(schema, currentDatabase, dialect, keywordDict); + } + + getLength() { + return Object.keys(this.options).length; } addOption(name: string, complete: Completion) { diff --git a/src/renderer/components/CodeEditor/SqlCodeEditor.tsx b/src/renderer/components/CodeEditor/SqlCodeEditor.tsx index 89c434e..39bdba8 100644 --- a/src/renderer/components/CodeEditor/SqlCodeEditor.tsx +++ b/src/renderer/components/CodeEditor/SqlCodeEditor.tsx @@ -30,6 +30,7 @@ import { MySQLDialect, MySQLTooltips } from 'dialects/MySQLDialect'; import { QueryDialetType } from 'libs/QueryBuilder'; import { PgDialect, PgTooltips } from 'dialects/PgDialect copy'; import { useKeybinding } from 'renderer/contexts/KeyBindingProvider'; +import SchemaCompletionTree from './SchemaCompletionTree'; const SqlCodeEditor = forwardRef(function SqlCodeEditor( props: ReactCodeMirrorProps & { @@ -43,16 +44,28 @@ const SqlCodeEditor = forwardRef(function SqlCodeEditor( const { binding } = useKeybinding(); const theme = useCodeEditorTheme(); + const dialect = props.dialect === 'mysql' ? MySQLDialect : PgDialect; + const tooltips = props.dialect === 'mysql' ? MySQLTooltips : PgTooltips; + + const schemaTree = useMemo(() => { + return SchemaCompletionTree.build( + schema?.getSchema(), + currentDatabase, + dialect.spec, + ); + }, [schema, currentDatabase, dialect]); + const customAutoComplete = useCallback( (context: CompletionContext, tree: SyntaxNode): CompletionResult | null => { return handleCustomSqlAutoComplete( context, tree, + schemaTree, schema?.getSchema(), currentDatabase, ); }, - [schema, currentDatabase], + [schema, schemaTree, currentDatabase], ); const tableNameHighlightPlugin = useMemo(() => { @@ -64,9 +77,6 @@ const SqlCodeEditor = forwardRef(function SqlCodeEditor( return createSQLTableNameHighlightPlugin([]); }, [schema, currentDatabase]); - const dialect = props.dialect === 'mysql' ? MySQLDialect : PgDialect; - const tooltips = props.dialect === 'mysql' ? MySQLTooltips : PgTooltips; - const keyExtension = useMemo(() => { return keymap.of([ // Prevent the default behavior if it matches any of diff --git a/src/renderer/components/CodeEditor/autocomplete_test_utils.ts b/src/renderer/components/CodeEditor/autocomplete_test_utils.ts new file mode 100644 index 0000000..9847134 --- /dev/null +++ b/src/renderer/components/CodeEditor/autocomplete_test_utils.ts @@ -0,0 +1,140 @@ +import { EditorState } from '@codemirror/state'; +import { + CompletionContext, + CompletionResult, + CompletionSource, +} from '@codemirror/autocomplete'; +import handleCustomSqlAutoComplete from './handleCustomSqlAutoComplete'; +import { MySQL, genericCompletion } from './../../../language/dist'; +import { + DatabaseSchemaList, + TableColumnSchema, + TableSchema, +} from 'types/SqlSchema'; +import SchemaCompletionTree from './SchemaCompletionTree'; + +export function get_test_autocomplete( + doc: string, + { + schema, + currentDatabase, + }: { schema: DatabaseSchemaList; currentDatabase?: string }, +) { + const cur = doc.indexOf('|'), + dialect = MySQL; + + doc = doc.slice(0, cur) + doc.slice(cur + 1); + const state = EditorState.create({ + doc, + selection: { anchor: cur }, + extensions: [ + dialect, + dialect.language.data.of({ + autocomplete: genericCompletion((context, tree) => + handleCustomSqlAutoComplete( + context, + tree, + SchemaCompletionTree.build(schema, currentDatabase, dialect.spec), + schema, + currentDatabase, + ), + ), + }), + ], + }); + + const result = state.languageDataAt('autocomplete', cur)[0]( + new CompletionContext(state, cur, false), + ); + return result as CompletionResult | null; +} + +export function convert_autocomplete_to_string( + result: CompletionResult | null, +) { + return !result + ? '' + : result.options + .slice() + .sort( + (a, b) => + (b.boost || 0) - (a.boost || 0) || (a.label < b.label ? -1 : 1), + ) + .map((o) => o.apply || o.label) + .join(', '); +} + +function map_column_type( + tableName: string, + name: string, + type: string, +): TableColumnSchema { + const tokens = type.split('('); + let enumValues: string[] | undefined; + + if (tokens[1]) { + // remove ) + enumValues = tokens[1] + .replace(')', '') + .replaceAll("'", '') + .split(',') + .map((a) => a.trim()); + } + + return { + name, + tableName, + schemaName: '', + charLength: 0, + comment: '', + enumValues, + dataType: tokens[0], + nullable: true, + }; +} + +function map_cols( + tableName: string, + cols: Record, +): Record { + return Object.entries(cols).reduce( + (acc, [colName, colType]) => { + acc[colName] = map_column_type(tableName, colName, colType); + return acc; + }, + {} as Record, + ); +} + +function map_table( + tables: Record>, +): Record { + return Object.entries(tables).reduce( + (acc, [tableName, cols]) => { + acc[tableName] = { + columns: map_cols(tableName, cols), + constraints: [], + name: tableName, + type: 'TABLE', + primaryKey: [], + }; + + return acc; + }, + {} as Record, + ); +} + +export function create_test_schema( + schemas: Record>>, +) { + return Object.entries(schemas).reduce((acc, [schema, tables]) => { + acc[schema] = { + name: schema, + events: [], + triggers: [], + tables: map_table(tables), + }; + return acc; + }, {} as DatabaseSchemaList); +} diff --git a/src/renderer/components/CodeEditor/handleAutoComplete.test.ts b/src/renderer/components/CodeEditor/handleAutoComplete.test.ts index 44302ef..8750999 100644 --- a/src/renderer/components/CodeEditor/handleAutoComplete.test.ts +++ b/src/renderer/components/CodeEditor/handleAutoComplete.test.ts @@ -1,155 +1,50 @@ -import { EditorState } from '@codemirror/state'; import { - CompletionContext, - CompletionResult, - CompletionSource, -} from '@codemirror/autocomplete'; -import handleCustomSqlAutoComplete from './handleCustomSqlAutoComplete'; -import { MySQL, genericCompletion } from './../../../language/dist'; -import { DatabaseSchemaList } from 'types/SqlSchema'; + create_test_schema, + get_test_autocomplete as get, + convert_autocomplete_to_string as str, +} from './autocomplete_test_utils'; -function get( - doc: string, - { - schema, - currentDatabase, - }: { schema: DatabaseSchemaList; currentDatabase?: string } -) { - const cur = doc.indexOf('|'), - dialect = MySQL; - - doc = doc.slice(0, cur) + doc.slice(cur + 1); - const state = EditorState.create({ - doc, - selection: { anchor: cur }, - extensions: [ - dialect, - dialect.language.data.of({ - autocomplete: genericCompletion((context, tree) => - handleCustomSqlAutoComplete(context, tree, schema, currentDatabase) - ), - }), - ], - }); - - const result = state.languageDataAt('autocomplete', cur)[0]( - new CompletionContext(state, cur, false) - ); - return result as CompletionResult | null; -} - -function str(result: CompletionResult | null) { - return !result - ? '' - : result.options - .slice() - .sort( - (a, b) => - (b.boost || 0) - (a.boost || 0) || (a.label < b.label ? -1 : 1) - ) - .map((o) => o.apply || o.label) - .join(', '); -} - -const schema1: DatabaseSchemaList = { +const schema1 = create_test_schema({ foo: { - name: 'foo', - events: [], - triggers: [], - tables: { - users: { - name: 'users', - constraints: [], - primaryKey: [], - type: 'TABLE', - columns: { - id: { - schemaName: '', - tableName: '', - name: 'id', - comment: '', - charLength: 0, - dataType: 'int', - nullable: false, - }, - name: { - schemaName: '', - tableName: '', - name: 'name', - comment: '', - charLength: 0, - dataType: 'varchar', - nullable: false, - }, - address: { - schemaName: '', - tableName: '', - name: 'address', - comment: '', - charLength: 0, - dataType: 'varchar', - nullable: false, - }, - }, - }, - products: { - name: 'products', - constraints: [], - primaryKey: [], - type: 'TABLE', - columns: { - id: { - name: 'name', - comment: '', - charLength: 0, - dataType: 'varchar', - nullable: false, - schemaName: '', - tableName: '' - }, - name: { - name: 'description', - comment: '', - charLength: 0, - dataType: 'varchar', - nullable: false, - schemaName: '', - tableName: '' - }, - product_type: { - name: 'product_type', - comment: '', - charLength: 0, - dataType: 'enum', - enumValues: ['HOME', 'BOOK', 'FASHION'], - nullable: false, - schemaName: '', - tableName: '' - }, - }, - }, + users: { id: 'int', name: 'varchar', address: 'varchar' }, + products: { + name: 'varchar', + description: 'varchar', + product_type: "enum('HOME', 'BOOK', 'FASHION')", }, }, -}; +}); describe('SQL completion', () => { it('completes table names', () => { expect( - str(get('select u|', { schema: schema1, currentDatabase: 'foo' })) + str(get('select u|', { schema: schema1, currentDatabase: 'foo' })), ).toBe('products, users, foo'); }); it('completes column based on the FROM table', () => { expect( str( - get('select n| from users', { schema: schema1, currentDatabase: 'foo' }) - ) + get('select n| from users', { + schema: schema1, + currentDatabase: 'foo', + }), + ), ).toBe('address, id, name, products, users, foo'); + + expect( + str( + get('select users.|, name from users', { + schema: schema1, + currentDatabase: 'foo', + }), + ), + ).toBe('address, id, name'); }); it('completes column after specified table with .', () => { expect( - str(get('select users.|', { schema: schema1, currentDatabase: 'foo' })) + str(get('select users.|', { schema: schema1, currentDatabase: 'foo' })), ).toBe('address, id, name'); }); @@ -159,8 +54,8 @@ describe('SQL completion', () => { get("select * from products where product_type = 'H|'", { schema: schema1, currentDatabase: 'foo', - }) - ) + }), + ), ).toBe('BOOK, FASHION, HOME'); }); }); diff --git a/src/renderer/components/CodeEditor/handleCustomSqlAutoComplete.ts b/src/renderer/components/CodeEditor/handleCustomSqlAutoComplete.ts index 00b7bba..204fb21 100644 --- a/src/renderer/components/CodeEditor/handleCustomSqlAutoComplete.ts +++ b/src/renderer/components/CodeEditor/handleCustomSqlAutoComplete.ts @@ -14,7 +14,7 @@ function getNodeString(context: CompletionContext, node: SyntaxNode) { function allowNodeWhenSearchForIdentify( context: CompletionContext, - node: SyntaxNode + node: SyntaxNode, ) { if (node.type.name === 'Operator') return true; if (node.type.name === 'Keyword') { @@ -25,14 +25,14 @@ function allowNodeWhenSearchForIdentify( function getIdentifierParentPath( context: CompletionContext, - node: SyntaxNode | null + node: SyntaxNode | null, ): string[] { const result: string[] = []; let prev = node; while (prev) { if (prev.type.name !== '.') { - const currentStr = getNodeString(context, prev); + const currentStr = getNodeString(context, prev).trim(); if (currentStr[currentStr.length - 1] !== '.') break; } else { prev = prev.prevSibling; @@ -41,12 +41,12 @@ function getIdentifierParentPath( if ( !['Identifier', 'QuotedIdentifier', 'CompositeIdentifier'].includes( - prev.type.name + prev.type.name, ) ) break; - result.push(getNodeString(context, prev).replaceAll('.', '')); + result.push(getNodeString(context, prev).trim().replaceAll('.', '')); prev = prev.prevSibling; } @@ -56,7 +56,7 @@ function getIdentifierParentPath( function searchForIdentifier( context: CompletionContext, - node: SyntaxNode + node: SyntaxNode, ): string | null { let currentNode = node.prevSibling; while (currentNode) { @@ -85,7 +85,7 @@ function handleEnumAutoComplete( node: SyntaxNode, schema: DatabaseSchemaList, currentDatabase: string | undefined, - exposedTable: TableSchema[] + exposedTable: TableSchema[], ): CompletionResult | null { let currentNode = node; @@ -109,7 +109,7 @@ function handleEnumAutoComplete( schema, currentDatabase, exposedTable, - identifier + identifier, ); if ( @@ -123,7 +123,7 @@ function handleEnumAutoComplete( label: value, type: 'enum', detail: 'enum', - }) + }), ); return { @@ -137,13 +137,11 @@ function handleEnumAutoComplete( } function getSchemaSuggestionFromPath( - schema: DatabaseSchemaList | undefined, - currentDatabase: string | undefined, - path: string[] + tree: SchemaCompletionTree, + path: string[], ) { - if (!schema) return []; + if (tree.getLength() === 0) return []; - const tree = SchemaCompletionTree.build(schema, currentDatabase); let treeWalk: SchemaCompletionTree | undefined = tree; for (const currentPath of path) { @@ -160,15 +158,16 @@ function getSchemaSuggestionFromPath( export default function handleCustomSqlAutoComplete( context: CompletionContext, tree: SyntaxNode, + schemaTree: SchemaCompletionTree, schema: DatabaseSchemaList | undefined, - currentDatabase: string | undefined + currentDatabase: string | undefined, ): CompletionResult | null { if (!schema) return null; if (tree.type.name === 'Script') { tree = tree.resolveInner( context.state.doc.sliceString(0, context.pos).trimEnd().length, - -1 + -1, ); } @@ -176,9 +175,11 @@ export default function handleCustomSqlAutoComplete( tree = SqlCompletionHelper.resolveInner(tree, context.pos) || tree; } + console.log(tree); + const currentSelectedTableNames = SqlCompletionHelper.fromTables( context, - tree + tree, ); const currentSelectedTables = currentSelectedTableNames @@ -186,8 +187,8 @@ export default function handleCustomSqlAutoComplete( SqlCompletionHelper.getTableFromIdentifier( schema, currentDatabase, - tableName - ) + tableName, + ), ) .filter(Boolean) as TableSchema[]; @@ -201,7 +202,7 @@ export default function handleCustomSqlAutoComplete( type: 'property', detail: column.dataType, boost: 3, - }) + }), ); if (SqlCompletionHelper.isInsideFrom(context, tree)) { @@ -213,13 +214,17 @@ export default function handleCustomSqlAutoComplete( tree, schema, currentDatabase, - currentSelectedTables + currentSelectedTables, ); if (enumSuggestion) { return enumSuggestion; } - if (['Identifier', 'QuotedIdentifier'].includes(tree.type.name)) { + if ( + ['Identifier', 'QuotedIdentifier', 'Keyword', 'Builtin'].includes( + tree.type.name, + ) + ) { return { from: tree.type.name === 'QuotedIdentifier' ? tree.from + 1 : tree.from, to: tree.type.name === 'QuotedIdentifier' ? tree.to - 1 : tree.to, @@ -228,9 +233,8 @@ export default function handleCustomSqlAutoComplete( options: [ ...currentColumnCompletion, ...getSchemaSuggestionFromPath( - schema, - currentDatabase, - getIdentifierParentPath(context, tree.prevSibling) + schemaTree, + getIdentifierParentPath(context, tree.prevSibling), ), ], }; @@ -243,9 +247,8 @@ export default function handleCustomSqlAutoComplete( validFor: /^\w*$/, options: [ ...getSchemaSuggestionFromPath( - schema, - currentDatabase, - getIdentifierParentPath(context, tree) + schemaTree, + getIdentifierParentPath(context, tree), ), ], }; @@ -257,7 +260,7 @@ export default function handleCustomSqlAutoComplete( from: context.pos, options: [ ...currentColumnCompletion, - ...getSchemaSuggestionFromPath(schema, currentDatabase, []), + ...getSchemaSuggestionFromPath(schemaTree, []), ], validFor: /^\w*$/, };