Skip to content

Commit

Permalink
feat(chat): can delete references
Browse files Browse the repository at this point in the history
  • Loading branch information
olimorris committed Dec 18, 2024
1 parent 49d38f7 commit 367eccc
Show file tree
Hide file tree
Showing 3 changed files with 171 additions and 4 deletions.
47 changes: 46 additions & 1 deletion lua/codecompanion/strategies/chat/init.lua
Original file line number Diff line number Diff line change
Expand Up @@ -635,6 +635,7 @@ function Chat:submit(opts)
message = self.References:clear(self.messages[#self.messages])

self:apply_tools_and_variables(message)
self:check_references()

-- Check if the user has manually overriden the adapter
if vim.g.codecompanion_adapter and self.adapter.name ~= vim.g.codecompanion_adapter then
Expand Down Expand Up @@ -697,8 +698,8 @@ function Chat:done()
self:add_message({ role = config.constants.LLM_ROLE, content = buf_parse_message(self.bufnr).content })

self:add_buf_message({ role = config.constants.USER_ROLE, content = "" })
self.References:render()
self.ui:display_tokens()
self.References:render()

if self.status == CONSTANTS.STATUS_SUCCESS and self:has_tools() then
buf_parse_tools(self)
Expand All @@ -725,6 +726,49 @@ function Chat:done()
end
end

---Reconcile the references table to the references in the chat buffer
---This allows users to manually remove references themselves
---@return nil
function Chat:check_references()
local refs = self.References:get_from_chat()
if vim.tbl_isempty(refs) and vim.tbl_isempty(self.refs) then
return
end

-- Fetch references that exist on the chat object but not in the buffer
local to_remove = vim
.iter(self.refs)
:filter(function(ref)
return not vim.tbl_contains(refs, ref.id)
end)
:map(function(ref)
return ref.id
end)
:totable()

if vim.tbl_isempty(to_remove) then
return
end

-- Remove them from the messages table
self.messages = vim
.iter(self.messages)
:filter(function(msg)
if msg.opts and msg.opts.reference and vim.tbl_contains(to_remove, msg.opts.reference) then
return false
end
return true
end)
:totable()

-- And from the refs table
for i, ref in pairs(self.refs) do
if vim.tbl_contains(to_remove, ref.id) then
table.remove(self.refs, i)
end
end
end

---Regenerate the response from the LLM
---@return nil
function Chat:regenerate()
Expand All @@ -739,6 +783,7 @@ end
---@return nil
function Chat:stop()
local job
self.status = CONSTANTS.STATUS_CANCELLING
if self.current_tool then
job = self.current_tool
self.current_tool = nil
Expand Down
60 changes: 58 additions & 2 deletions lua/codecompanion/strategies/chat/references.lua
Original file line number Diff line number Diff line change
Expand Up @@ -106,6 +106,11 @@ function References:add(ref)
end

if ref then
if not ref.opts then
ref.opts = {
pinned = false,
}
end
table.insert(self.chat.refs, ref)
end

Expand Down Expand Up @@ -193,8 +198,16 @@ function References:render()
local lines = {}
table.insert(lines, "> Sharing:")

for _, ref in ipairs(self.chat.refs) do
table.insert(lines, string.format("> - %s", ref.id))
for _, ref in pairs(self.chat.refs) do
if not ref then
goto continue
end
if ref.opts and ref.opts.pinned then
table.insert(lines, string.format("> - %s%s", config.display.chat.icons.pinned_buffer, ref.id))
else
table.insert(lines, string.format("> - %s", ref.id))
end
::continue::
end
table.insert(lines, "")

Expand All @@ -212,4 +225,47 @@ function References:make_id_from_buf(bufnr)
return vim.fn.fnamemodify(bufname, ":.")
end

---Get the references from the chat buffer
---@return table
function References:get_from_chat()
local refs = {}
local parser = vim.treesitter.get_parser(self.chat.bufnr, "markdown")
local query = vim.treesitter.query.parse(
"markdown",
string.format(
[[(
(section
(atx_heading) @heading
(#match? @heading "## %s")
)
)]],
user_role
)
)
local root = parser:parse()[1]:root()
local last_heading = nil
-- Get the last heading
for id, node in query:iter_captures(root, self.chat.bufnr, 0, -1) do
if query.captures[id] == "heading" then
last_heading = node
end
end

if last_heading then
local start_row, _, _, _ = last_heading:range()
-- Get the references
local refs_query =
vim.treesitter.query.parse("markdown", [[(block_quote (list (list_item (paragraph (inline) @ref))))]])
for id, node in refs_query:iter_captures(root, self.chat.bufnr, start_row, -1) do
if refs_query.captures[id] == "ref" then
local ref = vim.treesitter.get_node_text(node, self.chat.bufnr)
ref:gsub("^> %- ", "")
table.insert(refs, vim.trim(ref))
end
end
end

return refs
end

return References
68 changes: 67 additions & 1 deletion tests/strategies/chat/test_references.lua
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@ local chat

T["References"] = new_set({
hooks = {
pre_once = function()
pre_case = function()
chat, _ = h.setup_chat_buffer()
end,
post_once = function()
Expand Down Expand Up @@ -36,4 +36,70 @@ T["References"]["Can be added to the UI of the chat buffer"] = function()
h.eq("> - testing again", buffer[5])
end

T["References"]["Can be deleted"] = function()
-- Add references
chat.References:add({
source = "test",
name = "test",
id = "<buf>test.lua</buf>",
})
chat.References:add({
source = "test",
name = "test2",
id = "<buf>test2.lua</buf>",
})

-- Add messages with and without references
chat.messages = {
{
role = "user",
content = "Message with reference",
opts = {
reference = "<buf>test.lua</buf>",
},
},
{
role = "user",
content = "Message with another reference",
opts = {
reference = "<buf>test2.lua</buf>",
},
},

{
role = "user",
content = "Message without reference",
opts = {},
},
}

local initial_count = #chat.messages
h.eq(initial_count, 3, "Should start with 3 messages")
h.eq(vim.tbl_count(chat.refs), 2, "Should have 2 reference")

-- Mock the get_from_chat method
chat.References.get_from_chat = function()
return { "<buf>test2.lua</buf>" }
end

chat:check_references()

-- Verify results
h.eq(#chat.messages, 2, "Should have 1 messages after reference removal")
h.eq(chat.messages[1].content, "Message with another reference")
h.eq(chat.messages[2].content, "Message without reference")
h.eq(vim.tbl_count(chat.refs), 1, "Should have 1 reference")

-- Verify the message with reference was removed
local has_ref_message = vim.iter(chat.messages):any(function(msg)
return msg.opts.reference == "<buf>test.lua</buf>"
end)
h.eq(has_ref_message, false, "Message with first reference should be gone")

has_ref_message = vim.iter(chat.messages):any(function(msg)
return msg.opts.reference == "<buf>test2.lua</buf>"
end)
h.eq(has_ref_message, true, "Message with second reference should still be present")
end

return T

0 comments on commit 367eccc

Please sign in to comment.