diff --git a/lua/CopilotChat/init.lua b/lua/CopilotChat/init.lua index fec59c1f..99d24f8f 100644 --- a/lua/CopilotChat/init.lua +++ b/lua/CopilotChat/init.lua @@ -62,7 +62,7 @@ end local function find_lines_between_separator( lines, - start_line, + current_line, start_pattern, end_pattern, allow_end_of_file @@ -72,7 +72,6 @@ local function find_lines_between_separator( end local line_count = #lines - local current_line = vim.api.nvim_win_get_cursor(0)[1] - start_line local separator_line_start = 1 local separator_line_finish = line_count local found_one = false @@ -385,14 +384,20 @@ end --- Open the chat window. ---@param config CopilotChat.config|CopilotChat.config.prompt|nil ----@param source CopilotChat.config.source? -function M.open(config, source) +function M.open(config) + -- If we are already in chat window, do nothing + if state.chat:active() then + return + end + config = vim.tbl_deep_extend('force', M.config, config or {}) state.config = config - state.source = vim.tbl_extend('keep', source or {}, { + + -- Save the source buffer and window (e.g the buffer we are currently asking about) + state.source = { bufnr = vim.api.nvim_get_current_buf(), winnr = vim.api.nvim_get_current_win(), - }) + } utils.return_to_normal_mode() state.chat:open(config) @@ -407,12 +412,11 @@ end --- Toggle the chat window. ---@param config CopilotChat.config|nil ----@param source CopilotChat.config.source? -function M.toggle(config, source) +function M.toggle(config) if state.chat:visible() then M.close() else - M.open(config, source) + M.open(config) end end @@ -472,11 +476,10 @@ end --- Ask a question to the Copilot model. ---@param prompt string ---@param config CopilotChat.config|CopilotChat.config.prompt|nil ----@param source CopilotChat.config.source? -function M.ask(prompt, config, source) +function M.ask(prompt, config) config = vim.tbl_deep_extend('force', M.config, config or {}) vim.diagnostic.reset(vim.api.nvim_create_namespace('copilot_diagnostics')) - M.open(config, source) + M.open(config) prompt = vim.trim(prompt or '') if prompt == '' then @@ -489,6 +492,14 @@ function M.ask(prompt, config, source) finish(config, nil, true) end + -- Clear the current input prompt before asking a new question + local chat_lines = vim.api.nvim_buf_get_lines(state.chat.bufnr, 0, -1, false) + local _, start_line, end_line = + find_lines_between_separator(chat_lines, #chat_lines, M.config.separator .. '$', nil, true) + if #chat_lines == end_line then + vim.api.nvim_buf_set_lines(state.chat.bufnr, start_line, end_line, false, { '' }) + end + state.chat:append(prompt) state.chat:append('\n\n' .. config.answer_header .. config.separator .. '\n\n') @@ -884,17 +895,15 @@ function M.setup(config) map_key(M.config.mappings.submit_prompt, bufnr, function() local chat_lines = vim.api.nvim_buf_get_lines(bufnr, 0, -1, false) - local lines, start_line, end_line = - find_lines_between_separator(chat_lines, 0, M.config.separator .. '$', nil, true) - local input = vim.trim(table.concat(lines, '\n')) - if input ~= '' then - -- If we are entering the input at the end, replace it - if #chat_lines == end_line then - vim.api.nvim_buf_set_lines(bufnr, start_line, end_line, false, { '' }) - end - - M.ask(input, state.config, state.source) - end + local current_line = vim.api.nvim_win_get_cursor(0)[1] + local lines = find_lines_between_separator( + chat_lines, + current_line, + M.config.separator .. '$', + nil, + true + ) + M.ask(vim.trim(table.concat(lines, '\n')), state.config) end) map_key(M.config.mappings.toggle_sticky, bufnr, function() @@ -909,7 +918,7 @@ function M.setup(config) local chat_lines = vim.api.nvim_buf_get_lines(bufnr, 0, -1, false) local _, start_line, end_line = - find_lines_between_separator(chat_lines, 0, M.config.separator .. '$', nil, true) + find_lines_between_separator(chat_lines, cur_line, M.config.separator .. '$', nil, true) if vim.startswith(current_line, '> ') then return @@ -947,10 +956,15 @@ function M.setup(config) end local chat_lines = vim.api.nvim_buf_get_lines(bufnr, 0, -1, false) + local current_line = vim.api.nvim_win_get_cursor(0)[1] local section_lines, start_line = - find_lines_between_separator(chat_lines, 0, M.config.separator .. '$') - local lines = - find_lines_between_separator(section_lines, start_line - 1, '^```%w+$', '^```$') + find_lines_between_separator(chat_lines, current_line, M.config.separator .. '$') + local lines = find_lines_between_separator( + section_lines, + current_line - start_line - 1, + '^```%w+$', + '^```$' + ) if #lines > 0 then vim.api.nvim_buf_set_text( state.source.bufnr, @@ -970,10 +984,15 @@ function M.setup(config) end local chat_lines = vim.api.nvim_buf_get_lines(bufnr, 0, -1, false) + local current_line = vim.api.nvim_win_get_cursor(0)[1] local section_lines, start_line = - find_lines_between_separator(chat_lines, 0, M.config.separator .. '$') - local lines = - find_lines_between_separator(section_lines, start_line - 1, '^```%w+$', '^```$') + find_lines_between_separator(chat_lines, current_line, M.config.separator .. '$') + local lines = find_lines_between_separator( + section_lines, + current_line - start_line - 1, + '^```%w+$', + '^```$' + ) if #lines > 0 then local content = table.concat(lines, '\n') vim.fn.setreg(M.config.mappings.yank_diff.register, content) @@ -987,10 +1006,16 @@ function M.setup(config) end local chat_lines = vim.api.nvim_buf_get_lines(state.chat.bufnr, 0, -1, false) + local current_line = vim.api.nvim_win_get_cursor(0)[1] local section_lines, start_line = - find_lines_between_separator(chat_lines, 0, M.config.separator .. '$') + find_lines_between_separator(chat_lines, current_line, M.config.separator .. '$') local lines = table.concat( - find_lines_between_separator(section_lines, start_line - 1, '^```%w+$', '^```$'), + find_lines_between_separator( + section_lines, + current_line - start_line - 1, + '^```%w+$', + '^```$' + ), '\n' ) if vim.trim(lines) ~= '' then