diff --git a/lua/codecompanion/strategies/chat/init.lua b/lua/codecompanion/strategies/chat/init.lua
index 203fe2aa..b1941d8f 100644
--- a/lua/codecompanion/strategies/chat/init.lua
+++ b/lua/codecompanion/strategies/chat/init.lua
@@ -635,8 +635,10 @@ function Chat:submit(opts)
message = self.References:clear(self.messages[#self.messages])
self:apply_tools_and_variables(message)
+ self:check_references()
- -- Check if the user has manually overriden the adapter
+ -- Check if the user has manually overriden the adapter. This is useful if the
+ -- user loses their internet connection and wants to switch to a local LLM
if vim.g.codecompanion_adapter and self.adapter.name ~= vim.g.codecompanion_adapter then
self.adapter = adapters.resolve(config.adapters[vim.g.codecompanion_adapter])
end
@@ -695,11 +697,11 @@ function Chat:done()
end
self:add_message({ role = config.constants.LLM_ROLE, content = buf_parse_message(self.bufnr).content })
-
self:add_buf_message({ role = config.constants.USER_ROLE, content = "" })
- self.References:render()
self.ui:display_tokens()
+ self.References:render()
+
if self.status == CONSTANTS.STATUS_SUCCESS and self:has_tools() then
buf_parse_tools(self)
end
@@ -725,6 +727,50 @@ function Chat:done()
end
end
+---Reconcile the references table to the references in the chat buffer
+---This allows users to manually remove references themselves
+---@return nil
+function Chat:check_references()
+ local refs = self.References:get_from_chat()
+ if vim.tbl_isempty(refs) and vim.tbl_isempty(self.refs) then
+ return
+ end
+
+ -- Fetch references that exist on the chat object but not in the buffer
+ local to_remove = vim
+ .iter(pairs(self.refs))
+ :filter(function(ref, _)
+ return not vim.tbl_contains(refs, ref)
+ end)
+ :map(function(ref, _)
+ return ref
+ end)
+ :totable()
+
+ if vim.tbl_isempty(to_remove) then
+ return
+ end
+
+ -- Remove them from the messages table
+ self.messages = vim
+ .iter(self.messages)
+ :filter(function(msg)
+ if msg.opts and msg.opts.reference and vim.tbl_contains(to_remove, msg.opts.reference) then
+ return false
+ end
+ return true
+ end)
+ :totable()
+
+ -- And from the refs table
+ self.refs = vim
+ .iter(self.refs)
+ :filter(function(ref)
+ return not vim.tbl_contains(to_remove, ref)
+ end)
+ :totable()
+end
+
---Regenerate the response from the LLM
---@return nil
function Chat:regenerate()
@@ -739,6 +785,7 @@ end
---@return nil
function Chat:stop()
local job
+ self.status = CONSTANTS.STATUS_CANCELLING
if self.current_tool then
job = self.current_tool
self.current_tool = nil
diff --git a/lua/codecompanion/strategies/chat/references.lua b/lua/codecompanion/strategies/chat/references.lua
index 5c2144f6..6e23c54a 100644
--- a/lua/codecompanion/strategies/chat/references.lua
+++ b/lua/codecompanion/strategies/chat/references.lua
@@ -106,7 +106,14 @@ function References:add(ref)
end
if ref then
- table.insert(self.chat.refs, ref)
+ self.chat.refs[ref.id] = {
+ id = ref.id,
+ name = ref.name,
+ source = ref.source,
+ opts = ref.opts or {
+ pinned = false,
+ },
+ }
end
local parsed_buffer = ts_parse_buffer(self.chat)
@@ -193,8 +200,16 @@ function References:render()
local lines = {}
table.insert(lines, "> Sharing:")
- for _, ref in ipairs(self.chat.refs) do
- table.insert(lines, string.format("> - %s", ref.id))
+ for ref, _ in pairs(self.chat.refs) do
+ if not ref then
+ goto continue
+ end
+ if ref.opts and ref.opts.pinned then
+ table.insert(lines, string.format("> - %s%s", config.display.chat.icons.pinned_buffer, ref))
+ else
+ table.insert(lines, string.format("> - %s", ref))
+ end
+ ::continue::
end
table.insert(lines, "")
@@ -212,4 +227,51 @@ function References:make_id_from_buf(bufnr)
return vim.fn.fnamemodify(bufname, ":.")
end
+---Get the references from the chat buffer
+---@return table
+function References:get_from_chat()
+ local refs = {}
+
+ local parser = vim.treesitter.get_parser(self.chat.bufnr, "markdown")
+ local query = vim.treesitter.query.parse(
+ "markdown",
+ string.format(
+ [[(
+ (section
+ (atx_heading) @heading
+ (#match? @heading "## %s")
+ )
+)]],
+ user_role
+ )
+ )
+
+ local root = parser:parse()[1]:root()
+ local last_heading = nil
+
+ -- Get the last heading
+ for id, node in query:iter_captures(root, self.chat.bufnr, 0, -1) do
+ if query.captures[id] == "heading" then
+ last_heading = node
+ end
+ end
+
+ if last_heading then
+ local start_row, _, _, _ = last_heading:range()
+
+ -- Get the references
+ local refs_query = vim.treesitter.query.parse("markdown", [[(block_quote (list (list_item (paragraph)? @ref)))]])
+
+ for id, node in refs_query:iter_captures(root, self.chat.bufnr, start_row, -1) do
+ if refs_query.captures[id] == "ref" then
+ local ref = vim.treesitter.get_node_text(node, self.chat.bufnr)
+ ref:gsub("^> %- ", "")
+ table.insert(refs, vim.trim(ref))
+ end
+ end
+ end
+
+ return refs
+end
+
return References
diff --git a/lua/codecompanion/types.lua b/lua/codecompanion/types.lua
index 8a82c377..c9f2c525 100644
--- a/lua/codecompanion/types.lua
+++ b/lua/codecompanion/types.lua
@@ -72,9 +72,9 @@
---@field tokens? table Total tokens spent in the chat buffer so far
---@class CodeCompanion.Chat.Ref
----@field source string The source of the reference e.g. slash_command
----@field name string The name of the source e.g. buffer
---@field id string The unique ID of the reference which links it to a message in the chat buffer and is displayed to the user
+---@field name string The name of the source e.g. buffer
+---@field source string The source of the reference e.g. slash_command
---@field opts? table
---@class CodeCompanion.Chat.UI
diff --git a/tests/strategies/chat/test_references.lua b/tests/strategies/chat/test_references.lua
index 8f740474..6b3ba84b 100644
--- a/tests/strategies/chat/test_references.lua
+++ b/tests/strategies/chat/test_references.lua
@@ -7,7 +7,7 @@ local chat
T["References"] = new_set({
hooks = {
- pre_once = function()
+ pre_case = function()
chat, _ = h.setup_chat_buffer()
end,
post_once = function()
@@ -36,4 +36,54 @@ T["References"]["Can be added to the UI of the chat buffer"] = function()
h.eq("> - testing again", buffer[5])
end
+T["References"]["Can be deleted"] = function()
+ -- First add a reference and a message that uses it
+ chat.References:add({
+ source = "test",
+ name = "test",
+ id = "test.lua",
+ })
+
+ -- Add messages with and without references
+ chat.messages = {
+ {
+ role = "user",
+ content = "Message with reference",
+ opts = {
+ reference = "test.lua",
+ },
+ },
+ {
+ role = "user",
+ content = "Message without reference",
+ opts = {},
+ },
+ }
+
+ -- Store initial message count
+ local initial_count = #chat.messages
+ h.eq(initial_count, 2, "Should start with 2 messages")
+ h.eq(vim.tbl_count(chat.refs), 1, "Should have 1 reference")
+
+ -- Mock the get_from_chat method to return empty refs
+ chat.References.get_from_chat = function()
+ return {}
+ end
+
+ -- Run the check_references function
+ chat:check_references()
+
+ -- Verify results
+ h.eq(#chat.messages, 1, "Should have 1 message after reference removal")
+ h.eq(chat.messages[1].content, "Message without reference", "Message without reference should remain")
+ h.eq(vim.tbl_count(chat.refs), 0, "Should have 0 references")
+
+ -- Verify the message with reference was removed
+ local has_ref_message = vim.iter(chat.messages):any(function(msg)
+ return msg.opts.reference == "test.lua"
+ end)
+
+ h.eq(has_ref_message, false, "Message with removed reference should be gone")
+end
+
return T