From c619829ca9bebc8c69b438ba8d8f407d6197b1e0 Mon Sep 17 00:00:00 2001 From: Bassam Data <105807570+bassamsdata@users.noreply.github.com> Date: Wed, 11 Dec 2024 21:28:34 -0500 Subject: [PATCH 1/2] fix(adapters): fix gemini respect system prompt removed --- lua/codecompanion/adapters/gemini.lua | 40 +++++++++++++++++---------- 1 file changed, 25 insertions(+), 15 deletions(-) diff --git a/lua/codecompanion/adapters/gemini.lua b/lua/codecompanion/adapters/gemini.lua index 32ae5410..b38a5884 100644 --- a/lua/codecompanion/adapters/gemini.lua +++ b/lua/codecompanion/adapters/gemini.lua @@ -44,20 +44,25 @@ return { form_messages = function(self, messages) -- Format system prompts local system = utils.pluck_messages(vim.deepcopy(messages), "system") - for _, msg in ipairs(system) do - msg.text = msg.content - - -- Remove unnecessary fields - msg.tag = nil - msg.content = nil - msg.role = nil - msg.id = nil - msg.opts = nil + local system_instruction + + -- Only create system_instruction if there are system messages + if #system > 0 then + for _, msg in ipairs(system) do + msg.text = msg.content + + -- Remove unnecessary fields + msg.tag = nil + msg.content = nil + msg.role = nil + msg.id = nil + msg.opts = nil + end + system_instruction = { + role = self.roles.user, + parts = system, + } end - local sys_prompts = { - role = self.roles.user, - parts = system, - } -- Format messages (remove all system prompts) local output = {} @@ -71,10 +76,15 @@ return { }) end - return { - system_instruction = sys_prompts, + -- Only include system_instruction if it exists + local result = { contents = output, } + if system_instruction then + result.system_instruction = system_instruction + end + + return result end, ---Returns the number of tokens generated from the LLM From 92193bb16679ae606aa56772540bd084eb3997a0 Mon Sep 17 00:00:00 2001 From: Bassam Data <105807570+bassamsdata@users.noreply.github.com> Date: Wed, 11 Dec 2024 22:35:11 -0500 Subject: [PATCH 2/2] tests: adding some tests for gemini adapter. --- tests/adapters/test_gemini.lua | 86 ++++++++++++++++++++++++++++++++++ 1 file changed, 86 insertions(+) diff --git a/tests/adapters/test_gemini.lua b/tests/adapters/test_gemini.lua index 39b589d6..efd76187 100644 --- a/tests/adapters/test_gemini.lua +++ b/tests/adapters/test_gemini.lua @@ -65,6 +65,92 @@ describe("Gemini adapter", function() h.eq(output, adapter.handlers.form_messages(adapter, messages)) end) + it("can form messages with system prompt", function() + adapter = require("codecompanion.adapters").extend("gemini") + local messages_with_system = { + { + content = "You are a helpful assistant", + role = "system", + id = 1, + opts = { visible = false }, + }, + { + content = "hello", + id = 2, + opts = { visible = true }, + role = "user", + }, + { + content = "Hi, how can I help?", + id = 3, + opts = { visible = true }, + role = "llm", + }, + } + + local output = { + system_instruction = { + role = "user", + parts = { + { text = "You are a helpful assistant" }, + }, + }, + contents = { + { + role = "user", + parts = { + { text = "hello" }, + }, + }, + { + role = "user", + parts = { + { text = "Hi, how can I help?" }, + }, + }, + }, + } + + h.eq(output, adapter.handlers.form_messages(adapter, messages_with_system)) + end) + + it("can form messages without system prompt", function() + adapter = require("codecompanion.adapters").extend("gemini") + local messages_without_system = { + { + content = "hello", + id = 1, + opts = { visible = true }, + role = "user", + }, + { + content = "Hi, how can I help?", + id = 2, + opts = { visible = true }, + role = "llm", + }, + } + + local output = { + contents = { + { + role = "user", + parts = { + { text = "hello" }, + }, + }, + { + role = "user", + parts = { + { text = "Hi, how can I help?" }, + }, + }, + }, + } + + h.eq(output, adapter.handlers.form_messages(adapter, messages_without_system)) + end) + it("can output streamed data into a format for the chat buffer", function() h.eq(stream_response[#stream_response].output, adapter_helpers.chat_buffer_output(stream_response, adapter)) end)