fix(adapters): #528 Gemini adapter fails if system prompt is removed
bassamsdata authored Dec 12, 2024
1 parent d4d6b21 commit 7ca4364
Showing 2 changed files with 111 additions and 15 deletions.
40 changes: 25 additions & 15 deletions lua/codecompanion/adapters/gemini.lua
@@ -44,20 +44,25 @@ return {
     form_messages = function(self, messages)
       -- Format system prompts
       local system = utils.pluck_messages(vim.deepcopy(messages), "system")
-      for _, msg in ipairs(system) do
-        msg.text = msg.content
-
-        -- Remove unnecessary fields
-        msg.tag = nil
-        msg.content = nil
-        msg.role = nil
-        msg.id = nil
-        msg.opts = nil
+      local system_instruction
+
+      -- Only create system_instruction if there are system messages
+      if #system > 0 then
+        for _, msg in ipairs(system) do
+          msg.text = msg.content
+
+          -- Remove unnecessary fields
+          msg.tag = nil
+          msg.content = nil
+          msg.role = nil
+          msg.id = nil
+          msg.opts = nil
+        end
+        system_instruction = {
+          role = self.roles.user,
+          parts = system,
+        }
       end
-      local sys_prompts = {
-        role = self.roles.user,
-        parts = system,
-      }
 
       -- Format messages (remove all system prompts)
       local output = {}
@@ -71,10 +76,15 @@ return {
         })
       end
 
-      return {
-        system_instruction = sys_prompts,
+      -- Only include system_instruction if it exists
+      local result = {
         contents = output,
       }
+      if system_instruction then
+        result.system_instruction = system_instruction
+      end
+
+      return result
     end,
 
     ---Returns the number of tokens generated from the LLM
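In practice the request body now only carries a `system_instruction` field when the chat actually contains system messages; previously an empty one was always sent, which appears to be what #528 hit. A minimal sketch of how the new behaviour can be checked by hand (the calls mirror the test file below; the message values are made up):

local adapter = require("codecompanion.adapters").extend("gemini")

-- A conversation with the system prompt removed
local body = adapter.handlers.form_messages(adapter, {
  { content = "hello", id = 1, opts = { visible = true }, role = "user" },
})

-- Before this commit, body.system_instruction was always present (with an
-- empty parts table when no system messages existed); now it is omitted.
assert(body.system_instruction == nil)
assert(body.contents[1].parts[1].text == "hello")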
86 changes: 86 additions & 0 deletions tests/adapters/test_gemini.lua
@@ -65,6 +65,92 @@ describe("Gemini adapter", function()
     h.eq(output, adapter.handlers.form_messages(adapter, messages))
   end)
 
+  it("can form messages with system prompt", function()
+    adapter = require("codecompanion.adapters").extend("gemini")
+    local messages_with_system = {
+      {
+        content = "You are a helpful assistant",
+        role = "system",
+        id = 1,
+        opts = { visible = false },
+      },
+      {
+        content = "hello",
+        id = 2,
+        opts = { visible = true },
+        role = "user",
+      },
+      {
+        content = "Hi, how can I help?",
+        id = 3,
+        opts = { visible = true },
+        role = "llm",
+      },
+    }
+
+    local output = {
+      system_instruction = {
+        role = "user",
+        parts = {
+          { text = "You are a helpful assistant" },
+        },
+      },
+      contents = {
+        {
+          role = "user",
+          parts = {
+            { text = "hello" },
+          },
+        },
+        {
+          role = "user",
+          parts = {
+            { text = "Hi, how can I help?" },
+          },
+        },
+      },
+    }
+
+    h.eq(output, adapter.handlers.form_messages(adapter, messages_with_system))
+  end)
+
+  it("can form messages without system prompt", function()
+    adapter = require("codecompanion.adapters").extend("gemini")
+    local messages_without_system = {
+      {
+        content = "hello",
+        id = 1,
+        opts = { visible = true },
+        role = "user",
+      },
+      {
+        content = "Hi, how can I help?",
+        id = 2,
+        opts = { visible = true },
+        role = "llm",
+      },
+    }
+
+    local output = {
+      contents = {
+        {
+          role = "user",
+          parts = {
+            { text = "hello" },
+          },
+        },
+        {
+          role = "user",
+          parts = {
+            { text = "Hi, how can I help?" },
+          },
+        },
+      },
+    }
+
+    h.eq(output, adapter.handlers.form_messages(adapter, messages_without_system))
+  end)
+
   it("can output streamed data into a format for the chat buffer", function()
     h.eq(stream_response[#stream_response].output, adapter_helpers.chat_buffer_output(stream_response, adapter))
   end)
