Skip to content

Commit

Permalink
feat: WIP sources rework
Browse files Browse the repository at this point in the history
  • Loading branch information
Saghen committed Aug 30, 2024
1 parent e837718 commit ad347a1
Show file tree
Hide file tree
Showing 5 changed files with 248 additions and 18 deletions.
Empty file.
160 changes: 160 additions & 0 deletions lua/blink/cmp/sources/lib/init.lua
Original file line number Diff line number Diff line change
@@ -0,0 +1,160 @@
local config = require('blink.cmp.config')
local sources = {
registered = {
lsp = require('blink.cmp.sources.lsp'),
buffer = require('blink.cmp.sources.buffer'),
snippets = require('blink.cmp.sources.snippets'),
},

-- hack: sweet mother of all hacks
last_in_flight_id = -1,
in_flight_id = {
lsp = -1,
buffer = -1,
snippets = -1,
},

sources_responses = {},
current_context_id = -1,
on_completions_callback = function(_) end,
}

--- @return string[]
function sources.get_trigger_characters()
local blocked_trigger_characters = {}
for _, char in ipairs(config.trigger.blocked_trigger_characters) do
blocked_trigger_characters[char] = true
end

local trigger_characters = {}
for _, source in pairs(sources.registered) do
if source.get_trigger_characters ~= nil then
local source_trigger_characters = source.get_trigger_characters()
for _, char in ipairs(source_trigger_characters) do
if not blocked_trigger_characters[char] then table.insert(trigger_characters, char) end
end
end
end
return trigger_characters
end

function sources.listen_on_completions(callback) sources.on_completions_callback = callback end

--- @param context blink.cmp.ShowContext
function sources.completions(context)
-- a new context means we should refetch everything
local is_new_context = context.id ~= sources.current_context_id
sources.current_context_id = context.id

if is_new_context then sources.sources_responses = {} end

for source_name, source in pairs(sources.registered) do
-- the source indicates we should refetch when this character is typed
local trigger_characters = source.get_trigger_characters ~= nil and source.get_trigger_characters() or {}
local trigger_character = context.trigger_character
and vim.tbl_contains(trigger_characters, context.trigger_character)
-- the source indicates the previous results were incomplete and should be refetched on the next trigger
local previous_incomplete = sources.sources_responses[source_name] ~= nil
and sources.sources_responses[source_name].isIncomplete
-- check if we have no data and no calls are in flight
local no_data = sources.sources_responses[source_name] == nil and sources.in_flight_id[source_name] == -1

-- if none of these are true, we can use the existing cached results
if is_new_context or trigger_character or previous_incomplete or no_data then
if source.cancel_completions ~= nil then source.cancel_completions() end

-- register the call
sources.last_in_flight_id = sources.last_in_flight_id + 1
local in_flight_id = sources.last_in_flight_id
sources.in_flight_id[source_name] = in_flight_id

-- get the reason for the trigger
local trigger_context = trigger_character
and { kind = vim.lsp.protocol.CompletionTriggerKind.TriggerCharacter, character = context.trigger_character }
or previous_incomplete and { kind = vim.lsp.protocol.CompletionTriggerKind.TriggerForIncompleteCompletions }
or { kind = vim.lsp.protocol.CompletionTriggerKind.Invoked }

-- fetch them completions
-- fixme: what if we refetch due to incomplete items or a trigger_character? the context trigger id wouldnt change
-- change so stale data would be returned if the source doesn't support cancellation
local cursor_column = vim.api.nvim_win_get_cursor(0)[2]
local source_context = vim.fn.deepcopy(context)
source_context.trigger = trigger_context
source.completions(source_context, function(items)
-- a new call was made or this one was cancelled
if sources.in_flight_id[source_name] ~= in_flight_id then return end
sources.in_flight_id[source_name] = -1

sources.add_source_completions(source_name, items, cursor_column)
if not sources.some_in_flight() then sources.send_completions(context) end
end)
end
end

-- no completions will be in flight if none of them ran,
-- so we send the completions
if not sources.some_in_flight() then sources.send_completions(context) end
end

--- @param source_name string
--- @param source_response blink.cmp.CompletionResponse
--- @param cursor_column number
function sources.add_source_completions(source_name, source_response, cursor_column)
for _, item in ipairs(source_response.items) do
item.source = source_name
item.cursor_column = cursor_column
end

sources.sources_responses[source_name] = source_response
end

--- @return boolean
function sources.some_in_flight()
for _, in_flight in pairs(sources.in_flight_id) do
if in_flight ~= -1 then return true end
end
return false
end

--- @param context blink.cmp.ShowContext
function sources.send_completions(context)
local sources_responses = sources.sources_responses
-- apply source filters
for _, source in pairs(sources.registered) do
if source.filter_completions ~= nil then
sources_responses = source.filter_completions(context, sources_responses)
end
end

-- flatten the items
local flattened_items = {}
for source_name, response in pairs(sources.sources_responses) do
local source = sources.registered[source_name]
if source.should_show_completions == nil or source.should_show_completions(context, sources_responses) then
vim.list_extend(flattened_items, response.items)
end
end

sources.on_completions_callback(context, flattened_items)
end

function sources.cancel_completions()
for source_name, source in pairs(sources.registered) do
sources.in_flight_id[source_name] = -1
if source.cancel_completions ~= nil then source.cancel_completions() end
end
end

--- @param item blink.cmp.CompletionItem
--- @param callback fun(resolved_item: blink.cmp.CompletionItem | nil)
--- @return fun(): nil Cancelation function
function sources.resolve(item, callback)
local item_source = sources.registered[item.source]
if item_source == nil or item_source.resolve == nil then
callback(nil)
return function() end
end
return item_source.resolve(item, callback) or function() end
end

return sources
55 changes: 55 additions & 0 deletions lua/blink/cmp/sources/lib/source.lua
Original file line number Diff line number Diff line change
@@ -0,0 +1,55 @@
local source = {}

--- @param config blink.cmp.SourceProviderConfig
--- @return blink.cmp.SourceProvider
function source.new(config)
local self = setmetatable({}, { __index = source })
self.module = require(config[1]).new(config.opts or {})

self.fallback_for = config.fallback_for
self.keyword_length = config.keyword_length
self.score_offset = config.score_offset
self.deduplicate = config.deduplicate
self.override = config.override or {}

return self
end

function source:get_trigger_characters()
if self.override.get_trigger_characters ~= nil then
return self.override.get_trigger_characters(self.module.get_trigger_characters)
end
if self.module.get_trigger_characters == nil then return {} end
return self.module.get_trigger_characters()
end

function source:completions(context, callback)
if self.override.completions ~= nil then
return self.override.completions(context, callback, self.module.completions)
end
self.module.completions(context, callback)
end

function source:filter_completions(context, source_responses)
if self.override.filter_completions ~= nil then
return self.override.filter_completions(context, source_responses, self.module.filter_completions)
end
if self.module.filter_completions == nil then return source_responses end
return self.module.filter_completions(context, source_responses)
end

function source:resolve(item, callback)
if self.override.resolve ~= nil then return self.override.resolve(item, callback, self.module.resolve) end
if self.module.resolve == nil then return callback(item) end
self.module.resolve(item, callback)
end

function source:cancel_completions()
if self.override.cancel_completions ~= nil then
return self.override.cancel_completions(self.module.cancel_completions)
end
if self.module.cancel_completions == nil then return end
self.module.cancel_completions()
end

return source
33 changes: 33 additions & 0 deletions lua/blink/cmp/sources/lib/types.lua
Original file line number Diff line number Diff line change
@@ -0,0 +1,33 @@
--- @class blink.cmp.CompletionTriggerContext
--- @field kind number
--- @field character string
---
--- @class blink.cmp.CompletionContext : blink.cmp.ShowContext
--- @field trigger blink.cmp.CompletionTriggerContext | nil
---
--- @class blink.cmp.CompletionResponse
--- @field isIncomplete boolean
--- @field items blink.cmp.CompletionItem[]
---
--- @class blink.cmp.Source
--- @field get_trigger_characters (fun(): string[]) | nil
--- @field completions fun(context: blink.cmp.CompletionContext, callback: fun(response: blink.cmp.CompletionResponse))
--- @field filter_completions (fun(context: blink.cmp.CompletionContext, source_responses: table<string, blink.cmp.CompletionResponse>): blink.cmp.CompletionItem[]) | nil
--- @field cancel_completions fun() | nil
--- @field should_show_completions (fun(context: blink.cmp.CompletionContext, source_responses: table<string, blink.cmp.CompletionResponse>): boolean) | nil
--- @field resolve (fun(item: blink.cmp.CompletionItem, callback: fun(resolved_item: lsp.CompletionItem | nil)): ((fun(): nil) | nil)) | nil
---
--- @class blink.cmp.SourceProvider
--- @field module blink.cmp.Source
--- @field fallback_for string[]
--- @field keyword_length number
--- @field score_offset number
--- @field deduplicate blink.cmp.DeduplicateConfig
--- @field override blink.cmp.OverrideConfig
---
--- @field get_trigger_characters fun(self: blink.cmp.SourceProvider): string[]
--- @field completions fun(self: blink.cmp.SourceProvider, context: blink.cmp.CompletionContext, callback: fun(response: blink.cmp.CompletionResponse))
--- @field filter_completions fun(self: blink.cmp.SourceProvider, context: blink.cmp.CompletionContext, source_responses: table<string, blink.cmp.CompletionResponse>): blink.cmp.CompletionItem[]
--- @field cancel_completions fun(self: blink.cmp.SourceProvider)
--- @field should_show_completions fun(self: blink.cmp.SourceProvider, context: blink.cmp.CompletionContext, source_responses: table<string, blink.cmp.CompletionResponse>): boolean
--- @field resolve fun(self: blink.cmp.SourceProvider, item: blink.cmp.CompletionItem, callback: fun(resolved_item: lsp.CompletionItem | nil)): ((fun(): nil) | nil)
18 changes: 0 additions & 18 deletions lua/blink/cmp/sources/types.lua

This file was deleted.

0 comments on commit ad347a1

Please sign in to comment.