diff --git a/lua/codecompanion/strategies/chat/init.lua b/lua/codecompanion/strategies/chat/init.lua index 203fe2aa..8759387b 100644 --- a/lua/codecompanion/strategies/chat/init.lua +++ b/lua/codecompanion/strategies/chat/init.lua @@ -635,6 +635,7 @@ function Chat:submit(opts) message = self.References:clear(self.messages[#self.messages]) self:apply_tools_and_variables(message) + self:check_references() -- Check if the user has manually overriden the adapter if vim.g.codecompanion_adapter and self.adapter.name ~= vim.g.codecompanion_adapter then @@ -697,8 +698,8 @@ function Chat:done() self:add_message({ role = config.constants.LLM_ROLE, content = buf_parse_message(self.bufnr).content }) self:add_buf_message({ role = config.constants.USER_ROLE, content = "" }) - self.References:render() self.ui:display_tokens() + self.References:render() if self.status == CONSTANTS.STATUS_SUCCESS and self:has_tools() then buf_parse_tools(self) @@ -725,6 +726,49 @@ function Chat:done() end end +---Reconcile the references table to the references in the chat buffer +---This allows users to manually remove references themselves +---@return nil +function Chat:check_references() + local refs = self.References:get_from_chat() + if vim.tbl_isempty(refs) and vim.tbl_isempty(self.refs) then + return + end + + -- Fetch references that exist on the chat object but not in the buffer + local to_remove = vim + .iter(self.refs) + :filter(function(ref) + return not vim.tbl_contains(refs, ref.id) + end) + :map(function(ref) + return ref.id + end) + :totable() + + if vim.tbl_isempty(to_remove) then + return + end + + -- Remove them from the messages table + self.messages = vim + .iter(self.messages) + :filter(function(msg) + if msg.opts and msg.opts.reference and vim.tbl_contains(to_remove, msg.opts.reference) then + return false + end + return true + end) + :totable() + + -- And from the refs table + for i, ref in pairs(self.refs) do + if vim.tbl_contains(to_remove, ref.id) then + table.remove(self.refs, i) + end + end +end + ---Regenerate the response from the LLM ---@return nil function Chat:regenerate() @@ -739,6 +783,7 @@ end ---@return nil function Chat:stop() local job + self.status = CONSTANTS.STATUS_CANCELLING if self.current_tool then job = self.current_tool self.current_tool = nil diff --git a/lua/codecompanion/strategies/chat/references.lua b/lua/codecompanion/strategies/chat/references.lua index 5c2144f6..f86ac979 100644 --- a/lua/codecompanion/strategies/chat/references.lua +++ b/lua/codecompanion/strategies/chat/references.lua @@ -106,6 +106,11 @@ function References:add(ref) end if ref then + if not ref.opts then + ref.opts = { + pinned = false, + } + end table.insert(self.chat.refs, ref) end @@ -193,8 +198,16 @@ function References:render() local lines = {} table.insert(lines, "> Sharing:") - for _, ref in ipairs(self.chat.refs) do - table.insert(lines, string.format("> - %s", ref.id)) + for _, ref in pairs(self.chat.refs) do + if not ref then + goto continue + end + if ref.opts and ref.opts.pinned then + table.insert(lines, string.format("> - %s%s", config.display.chat.icons.pinned_buffer, ref.id)) + else + table.insert(lines, string.format("> - %s", ref.id)) + end + ::continue:: end table.insert(lines, "") @@ -212,4 +225,47 @@ function References:make_id_from_buf(bufnr) return vim.fn.fnamemodify(bufname, ":.") end +---Get the references from the chat buffer +---@return table +function References:get_from_chat() + local refs = {} + local parser = vim.treesitter.get_parser(self.chat.bufnr, "markdown") + local query = vim.treesitter.query.parse( + "markdown", + string.format( + [[( + (section + (atx_heading) @heading + (#match? @heading "## %s") + ) +)]], + user_role + ) + ) + local root = parser:parse()[1]:root() + local last_heading = nil + -- Get the last heading + for id, node in query:iter_captures(root, self.chat.bufnr, 0, -1) do + if query.captures[id] == "heading" then + last_heading = node + end + end + + if last_heading then + local start_row, _, _, _ = last_heading:range() + -- Get the references + local refs_query = + vim.treesitter.query.parse("markdown", [[(block_quote (list (list_item (paragraph (inline) @ref))))]]) + for id, node in refs_query:iter_captures(root, self.chat.bufnr, start_row, -1) do + if refs_query.captures[id] == "ref" then + local ref = vim.treesitter.get_node_text(node, self.chat.bufnr) + ref:gsub("^> %- ", "") + table.insert(refs, vim.trim(ref)) + end + end + end + + return refs +end + return References diff --git a/tests/strategies/chat/test_references.lua b/tests/strategies/chat/test_references.lua index 8f740474..06f6b1e9 100644 --- a/tests/strategies/chat/test_references.lua +++ b/tests/strategies/chat/test_references.lua @@ -7,7 +7,7 @@ local chat T["References"] = new_set({ hooks = { - pre_once = function() + pre_case = function() chat, _ = h.setup_chat_buffer() end, post_once = function() @@ -36,4 +36,70 @@ T["References"]["Can be added to the UI of the chat buffer"] = function() h.eq("> - testing again", buffer[5]) end +T["References"]["Can be deleted"] = function() + -- Add references + chat.References:add({ + source = "test", + name = "test", + id = "test.lua", + }) + chat.References:add({ + source = "test", + name = "test2", + id = "test2.lua", + }) + + -- Add messages with and without references + chat.messages = { + { + role = "user", + content = "Message with reference", + opts = { + reference = "test.lua", + }, + }, + { + role = "user", + content = "Message with another reference", + opts = { + reference = "test2.lua", + }, + }, + + { + role = "user", + content = "Message without reference", + opts = {}, + }, + } + + local initial_count = #chat.messages + h.eq(initial_count, 3, "Should start with 3 messages") + h.eq(vim.tbl_count(chat.refs), 2, "Should have 2 reference") + + -- Mock the get_from_chat method + chat.References.get_from_chat = function() + return { "test2.lua" } + end + + chat:check_references() + + -- Verify results + h.eq(#chat.messages, 2, "Should have 1 messages after reference removal") + h.eq(chat.messages[1].content, "Message with another reference") + h.eq(chat.messages[2].content, "Message without reference") + h.eq(vim.tbl_count(chat.refs), 1, "Should have 1 reference") + + -- Verify the message with reference was removed + local has_ref_message = vim.iter(chat.messages):any(function(msg) + return msg.opts.reference == "test.lua" + end) + h.eq(has_ref_message, false, "Message with first reference should be gone") + + has_ref_message = vim.iter(chat.messages):any(function(msg) + return msg.opts.reference == "test2.lua" + end) + h.eq(has_ref_message, true, "Message with second reference should still be present") +end + return T