Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

fix(adapters): make the Gemini adapter respect removal of the system prompt #529

Merged
merged 2 commits into from
Dec 12, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
40 changes: 25 additions & 15 deletions lua/codecompanion/adapters/gemini.lua
Original file line number Diff line number Diff line change
Expand Up @@ -44,20 +44,25 @@ return {
form_messages = function(self, messages)
-- Format system prompts
local system = utils.pluck_messages(vim.deepcopy(messages), "system")
for _, msg in ipairs(system) do
msg.text = msg.content

-- Remove unnecessary fields
msg.tag = nil
msg.content = nil
msg.role = nil
msg.id = nil
msg.opts = nil
local system_instruction

-- Only create system_instruction if there are system messages
if #system > 0 then
for _, msg in ipairs(system) do
msg.text = msg.content

-- Remove unnecessary fields
msg.tag = nil
msg.content = nil
msg.role = nil
msg.id = nil
msg.opts = nil
end
system_instruction = {
role = self.roles.user,
parts = system,
}
end
local sys_prompts = {
role = self.roles.user,
parts = system,
}

-- Format messages (remove all system prompts)
local output = {}
Expand All @@ -71,10 +76,15 @@ return {
})
end

return {
system_instruction = sys_prompts,
-- Only include system_instruction if it exists
local result = {
contents = output,
}
if system_instruction then
result.system_instruction = system_instruction
end

return result
end,

---Returns the number of tokens generated from the LLM
Expand Down
86 changes: 86 additions & 0 deletions tests/adapters/test_gemini.lua
Original file line number Diff line number Diff line change
Expand Up @@ -65,6 +65,92 @@ describe("Gemini adapter", function()
h.eq(output, adapter.handlers.form_messages(adapter, messages))
end)

it("can form messages with system prompt", function()
  adapter = require("codecompanion.adapters").extend("gemini")

  -- Input: a system prompt followed by a user/llm exchange.
  local input = {
    {
      role = "system",
      id = 1,
      content = "You are a helpful assistant",
      opts = { visible = false },
    },
    {
      role = "user",
      id = 2,
      content = "hello",
      opts = { visible = true },
    },
    {
      role = "llm",
      id = 3,
      content = "Hi, how can I help?",
      opts = { visible = true },
    },
  }

  -- The system message must surface as a dedicated system_instruction,
  -- while the remaining messages land in contents.
  local expected = {
    system_instruction = {
      role = "user",
      parts = {
        { text = "You are a helpful assistant" },
      },
    },
    contents = {
      {
        role = "user",
        parts = {
          { text = "hello" },
        },
      },
      {
        role = "user",
        parts = {
          { text = "Hi, how can I help?" },
        },
      },
    },
  }

  h.eq(expected, adapter.handlers.form_messages(adapter, input))
end)

it("can form messages without system prompt", function()
  adapter = require("codecompanion.adapters").extend("gemini")

  -- Input: a plain user/llm exchange with no system message.
  local input = {
    {
      role = "user",
      id = 1,
      content = "hello",
      opts = { visible = true },
    },
    {
      role = "llm",
      id = 2,
      content = "Hi, how can I help?",
      opts = { visible = true },
    },
  }

  -- With no system message, the payload must omit system_instruction
  -- entirely and only carry contents.
  local expected = {
    contents = {
      {
        role = "user",
        parts = {
          { text = "hello" },
        },
      },
      {
        role = "user",
        parts = {
          { text = "Hi, how can I help?" },
        },
      },
    },
  }

  h.eq(expected, adapter.handlers.form_messages(adapter, input))
end)

it("can output streamed data into a format for the chat buffer", function()
h.eq(stream_response[#stream_response].output, adapter_helpers.chat_buffer_output(stream_response, adapter))
end)
Expand Down