diff --git a/lua/codecompanion/strategies/chat/init.lua b/lua/codecompanion/strategies/chat/init.lua index 203fe2aa..b1941d8f 100644 --- a/lua/codecompanion/strategies/chat/init.lua +++ b/lua/codecompanion/strategies/chat/init.lua @@ -635,8 +635,10 @@ function Chat:submit(opts) message = self.References:clear(self.messages[#self.messages]) self:apply_tools_and_variables(message) + self:check_references() - -- Check if the user has manually overriden the adapter + -- Check if the user has manually overriden the adapter. This is useful if the + -- user loses their internet connection and wants to switch to a local LLM if vim.g.codecompanion_adapter and self.adapter.name ~= vim.g.codecompanion_adapter then self.adapter = adapters.resolve(config.adapters[vim.g.codecompanion_adapter]) end @@ -695,11 +697,11 @@ function Chat:done() end self:add_message({ role = config.constants.LLM_ROLE, content = buf_parse_message(self.bufnr).content }) - self:add_buf_message({ role = config.constants.USER_ROLE, content = "" }) - self.References:render() self.ui:display_tokens() + self.References:render() + if self.status == CONSTANTS.STATUS_SUCCESS and self:has_tools() then buf_parse_tools(self) end @@ -725,6 +727,50 @@ function Chat:done() end end +---Reconcile the references table to the references in the chat buffer +---This allows users to manually remove references themselves +---@return nil +function Chat:check_references() + local refs = self.References:get_from_chat() + if vim.tbl_isempty(refs) and vim.tbl_isempty(self.refs) then + return + end + + -- Fetch references that exist on the chat object but not in the buffer + local to_remove = vim + .iter(pairs(self.refs)) + :filter(function(ref, _) + return not vim.tbl_contains(refs, ref) + end) + :map(function(ref, _) + return ref + end) + :totable() + + if vim.tbl_isempty(to_remove) then + return + end + + -- Remove them from the messages table + self.messages = vim + .iter(self.messages) + :filter(function(msg) + if msg.opts and msg.opts.reference and vim.tbl_contains(to_remove, msg.opts.reference) then + return false + end + return true + end) + :totable() + + -- And from the refs table + self.refs = vim + .iter(self.refs) + :filter(function(ref) + return not vim.tbl_contains(to_remove, ref) + end) + :totable() +end + ---Regenerate the response from the LLM ---@return nil function Chat:regenerate() @@ -739,6 +785,7 @@ end ---@return nil function Chat:stop() local job + self.status = CONSTANTS.STATUS_CANCELLING if self.current_tool then job = self.current_tool self.current_tool = nil diff --git a/lua/codecompanion/strategies/chat/references.lua b/lua/codecompanion/strategies/chat/references.lua index 5c2144f6..6e23c54a 100644 --- a/lua/codecompanion/strategies/chat/references.lua +++ b/lua/codecompanion/strategies/chat/references.lua @@ -106,7 +106,14 @@ function References:add(ref) end if ref then - table.insert(self.chat.refs, ref) + self.chat.refs[ref.id] = { + id = ref.id, + name = ref.name, + source = ref.source, + opts = ref.opts or { + pinned = false, + }, + } end local parsed_buffer = ts_parse_buffer(self.chat) @@ -193,8 +200,16 @@ function References:render() local lines = {} table.insert(lines, "> Sharing:") - for _, ref in ipairs(self.chat.refs) do - table.insert(lines, string.format("> - %s", ref.id)) + for ref, _ in pairs(self.chat.refs) do + if not ref then + goto continue + end + if ref.opts and ref.opts.pinned then + table.insert(lines, string.format("> - %s%s", config.display.chat.icons.pinned_buffer, ref)) + else + table.insert(lines, string.format("> - %s", ref)) + end + ::continue:: end table.insert(lines, "") @@ -212,4 +227,51 @@ function References:make_id_from_buf(bufnr) return vim.fn.fnamemodify(bufname, ":.") end +---Get the references from the chat buffer +---@return table +function References:get_from_chat() + local refs = {} + + local parser = vim.treesitter.get_parser(self.chat.bufnr, "markdown") + local query = vim.treesitter.query.parse( + "markdown", + string.format( + [[( + (section + (atx_heading) @heading + (#match? @heading "## %s") + ) +)]], + user_role + ) + ) + + local root = parser:parse()[1]:root() + local last_heading = nil + + -- Get the last heading + for id, node in query:iter_captures(root, self.chat.bufnr, 0, -1) do + if query.captures[id] == "heading" then + last_heading = node + end + end + + if last_heading then + local start_row, _, _, _ = last_heading:range() + + -- Get the references + local refs_query = vim.treesitter.query.parse("markdown", [[(block_quote (list (list_item (paragraph)? @ref)))]]) + + for id, node in refs_query:iter_captures(root, self.chat.bufnr, start_row, -1) do + if refs_query.captures[id] == "ref" then + local ref = vim.treesitter.get_node_text(node, self.chat.bufnr) + ref:gsub("^> %- ", "") + table.insert(refs, vim.trim(ref)) + end + end + end + + return refs +end + return References diff --git a/lua/codecompanion/types.lua b/lua/codecompanion/types.lua index 8a82c377..c9f2c525 100644 --- a/lua/codecompanion/types.lua +++ b/lua/codecompanion/types.lua @@ -72,9 +72,9 @@ ---@field tokens? table Total tokens spent in the chat buffer so far ---@class CodeCompanion.Chat.Ref ----@field source string The source of the reference e.g. slash_command ----@field name string The name of the source e.g. buffer ---@field id string The unique ID of the reference which links it to a message in the chat buffer and is displayed to the user +---@field name string The name of the source e.g. buffer +---@field source string The source of the reference e.g. slash_command ---@field opts? table ---@class CodeCompanion.Chat.UI diff --git a/tests/strategies/chat/test_references.lua b/tests/strategies/chat/test_references.lua index 8f740474..6b3ba84b 100644 --- a/tests/strategies/chat/test_references.lua +++ b/tests/strategies/chat/test_references.lua @@ -7,7 +7,7 @@ local chat T["References"] = new_set({ hooks = { - pre_once = function() + pre_case = function() chat, _ = h.setup_chat_buffer() end, post_once = function() @@ -36,4 +36,54 @@ T["References"]["Can be added to the UI of the chat buffer"] = function() h.eq("> - testing again", buffer[5]) end +T["References"]["Can be deleted"] = function() + -- First add a reference and a message that uses it + chat.References:add({ + source = "test", + name = "test", + id = "test.lua", + }) + + -- Add messages with and without references + chat.messages = { + { + role = "user", + content = "Message with reference", + opts = { + reference = "test.lua", + }, + }, + { + role = "user", + content = "Message without reference", + opts = {}, + }, + } + + -- Store initial message count + local initial_count = #chat.messages + h.eq(initial_count, 2, "Should start with 2 messages") + h.eq(vim.tbl_count(chat.refs), 1, "Should have 1 reference") + + -- Mock the get_from_chat method to return empty refs + chat.References.get_from_chat = function() + return {} + end + + -- Run the check_references function + chat:check_references() + + -- Verify results + h.eq(#chat.messages, 1, "Should have 1 message after reference removal") + h.eq(chat.messages[1].content, "Message without reference", "Message without reference should remain") + h.eq(vim.tbl_count(chat.refs), 0, "Should have 0 references") + + -- Verify the message with reference was removed + local has_ref_message = vim.iter(chat.messages):any(function(msg) + return msg.opts.reference == "test.lua" + end) + + h.eq(has_ref_message, false, "Message with removed reference should be gone") +end + return T