Skip to content

Commit

Permalink
Simplify how source buffer is being saved and unify prompt clearing
Browse files Browse the repository at this point in the history
- Instead of passing state.source around, simply update it only if we are
  outside of chat window and make M.open() noop otherwise
- Unify when last prompt is cleared, clear it in .ask always instead of
  doing it only when explicitely submitting prompt
- Unify how selection is checked when simply showing it

Signed-off-by: Tomas Slusny <slusnucky@gmail.com>
  • Loading branch information
deathbeam committed Nov 17, 2024
1 parent c3518e5 commit 28c8072
Showing 1 changed file with 61 additions and 36 deletions.
97 changes: 61 additions & 36 deletions lua/CopilotChat/init.lua
Original file line number Diff line number Diff line change
Expand Up @@ -62,7 +62,7 @@ end

local function find_lines_between_separator(
lines,
start_line,
current_line,
start_pattern,
end_pattern,
allow_end_of_file
Expand All @@ -72,7 +72,6 @@ local function find_lines_between_separator(
end

local line_count = #lines
local current_line = vim.api.nvim_win_get_cursor(0)[1] - start_line
local separator_line_start = 1
local separator_line_finish = line_count
local found_one = false
Expand Down Expand Up @@ -385,14 +384,20 @@ end

--- Open the chat window.
---@param config CopilotChat.config|CopilotChat.config.prompt|nil
---@param source CopilotChat.config.source?
function M.open(config, source)
function M.open(config)
-- If we are already in chat window, do nothing
if state.chat:active() then
return
end

config = vim.tbl_deep_extend('force', M.config, config or {})
state.config = config
state.source = vim.tbl_extend('keep', source or {}, {

-- Save the source buffer and window (e.g the buffer we are currently asking about)
state.source = {
bufnr = vim.api.nvim_get_current_buf(),
winnr = vim.api.nvim_get_current_win(),
})
}

utils.return_to_normal_mode()
state.chat:open(config)
Expand All @@ -407,12 +412,11 @@ end

--- Toggle the chat window.
---@param config CopilotChat.config|nil
---@param source CopilotChat.config.source?
function M.toggle(config, source)
function M.toggle(config)
if state.chat:visible() then
M.close()
else
M.open(config, source)
M.open(config)
end
end

Expand Down Expand Up @@ -472,11 +476,10 @@ end
--- Ask a question to the Copilot model.
---@param prompt string
---@param config CopilotChat.config|CopilotChat.config.prompt|nil
---@param source CopilotChat.config.source?
function M.ask(prompt, config, source)
function M.ask(prompt, config)
config = vim.tbl_deep_extend('force', M.config, config or {})
vim.diagnostic.reset(vim.api.nvim_create_namespace('copilot_diagnostics'))
M.open(config, source)
M.open(config)

prompt = vim.trim(prompt or '')
if prompt == '' then
Expand All @@ -489,6 +492,14 @@ function M.ask(prompt, config, source)
finish(config, nil, true)
end

-- Clear the current input prompt before asking a new question
local chat_lines = vim.api.nvim_buf_get_lines(state.chat.bufnr, 0, -1, false)
local _, start_line, end_line =
find_lines_between_separator(chat_lines, #chat_lines, M.config.separator .. '$', nil, true)
if #chat_lines == end_line then
vim.api.nvim_buf_set_lines(state.chat.bufnr, start_line, end_line, false, { '' })
end

state.chat:append(prompt)
state.chat:append('\n\n' .. config.answer_header .. config.separator .. '\n\n')

Expand Down Expand Up @@ -884,17 +895,15 @@ function M.setup(config)

map_key(M.config.mappings.submit_prompt, bufnr, function()
local chat_lines = vim.api.nvim_buf_get_lines(bufnr, 0, -1, false)
local lines, start_line, end_line =
find_lines_between_separator(chat_lines, 0, M.config.separator .. '$', nil, true)
local input = vim.trim(table.concat(lines, '\n'))
if input ~= '' then
-- If we are entering the input at the end, replace it
if #chat_lines == end_line then
vim.api.nvim_buf_set_lines(bufnr, start_line, end_line, false, { '' })
end

M.ask(input, state.config, state.source)
end
local current_line = vim.api.nvim_win_get_cursor(0)[1]
local lines = find_lines_between_separator(
chat_lines,
current_line,
M.config.separator .. '$',
nil,
true
)
M.ask(vim.trim(table.concat(lines, '\n')), state.config)
end)

map_key(M.config.mappings.toggle_sticky, bufnr, function()
Expand All @@ -909,7 +918,7 @@ function M.setup(config)

local chat_lines = vim.api.nvim_buf_get_lines(bufnr, 0, -1, false)
local _, start_line, end_line =
find_lines_between_separator(chat_lines, 0, M.config.separator .. '$', nil, true)
find_lines_between_separator(chat_lines, cur_line, M.config.separator .. '$', nil, true)

if vim.startswith(current_line, '> ') then
return
Expand Down Expand Up @@ -942,15 +951,20 @@ function M.setup(config)

map_key(M.config.mappings.accept_diff, bufnr, function()
local selection = get_selection()
if not selection.start_row or not selection.end_row then
if not selection or not selection.start_row or not selection.end_row then
return
end

local chat_lines = vim.api.nvim_buf_get_lines(bufnr, 0, -1, false)
local current_line = vim.api.nvim_win_get_cursor(0)[1]
local section_lines, start_line =
find_lines_between_separator(chat_lines, 0, M.config.separator .. '$')
local lines =
find_lines_between_separator(section_lines, start_line - 1, '^```%w+$', '^```$')
find_lines_between_separator(chat_lines, current_line, M.config.separator .. '$')
local lines = find_lines_between_separator(
section_lines,
current_line - start_line - 1,
'^```%w+$',
'^```$'
)
if #lines > 0 then
vim.api.nvim_buf_set_text(
state.source.bufnr,
Expand All @@ -965,15 +979,20 @@ function M.setup(config)

map_key(M.config.mappings.yank_diff, bufnr, function()
local selection = get_selection()
if not selection.start_row or not selection.end_row then
if not selection or not selection.lines then
return
end

local chat_lines = vim.api.nvim_buf_get_lines(bufnr, 0, -1, false)
local current_line = vim.api.nvim_win_get_cursor(0)[1]
local section_lines, start_line =
find_lines_between_separator(chat_lines, 0, M.config.separator .. '$')
local lines =
find_lines_between_separator(section_lines, start_line - 1, '^```%w+$', '^```$')
find_lines_between_separator(chat_lines, current_line, M.config.separator .. '$')
local lines = find_lines_between_separator(
section_lines,
current_line - start_line - 1,
'^```%w+$',
'^```$'
)
if #lines > 0 then
local content = table.concat(lines, '\n')
vim.fn.setreg(M.config.mappings.yank_diff.register, content)
Expand All @@ -982,15 +1001,21 @@ function M.setup(config)

map_key(M.config.mappings.show_diff, bufnr, function()
local selection = get_selection()
if not selection or not selection.start_row or not selection.end_row then
if not selection or not selection.lines then
return
end

local chat_lines = vim.api.nvim_buf_get_lines(state.chat.bufnr, 0, -1, false)
local current_line = vim.api.nvim_win_get_cursor(0)[1]
local section_lines, start_line =
find_lines_between_separator(chat_lines, 0, M.config.separator .. '$')
find_lines_between_separator(chat_lines, current_line, M.config.separator .. '$')
local lines = table.concat(
find_lines_between_separator(section_lines, start_line - 1, '^```%w+$', '^```$'),
find_lines_between_separator(
section_lines,
current_line - start_line - 1,
'^```%w+$',
'^```$'
),
'\n'
)
if vim.trim(lines) ~= '' then
Expand Down Expand Up @@ -1026,7 +1051,7 @@ function M.setup(config)

map_key(M.config.mappings.show_user_selection, bufnr, function()
local selection = get_selection()
if not selection.start_row or not selection.end_row then
if not selection or not selection.lines then
return
end

Expand Down

0 comments on commit 28c8072

Please sign in to comment.