Allow setting LLM temperature with VoiceAssistant #741

Merged 5 commits on Sep 12, 2024
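
This PR threads a `temperature` option through the LLM plugin constructors so a `VoiceAssistant` can run at a non-default sampling temperature. A minimal sketch of the intended usage, assuming the usual VoiceAssistant wiring from the agents examples (the Silero/Deepgram components and the surrounding job entrypoint are illustrative, not part of this diff):

```python
from livekit.agents.voice_assistant import VoiceAssistant
from livekit.plugins import deepgram, openai, silero

# A temperature set on the constructor becomes the default for every
# chat() call the assistant makes on your behalf.
assistant = VoiceAssistant(
    vad=silero.VAD.load(),
    stt=deepgram.STT(),
    llm=openai.LLM(model="gpt-4o", temperature=0.4),
    tts=openai.TTS(),
)
```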
5 changes: 5 additions & 0 deletions .changeset/shiny-baboons-jog.md
@@ -0,0 +1,5 @@
+---
+"livekit-agents": patch
+---
+
+Allow setting LLM temperature with VoiceAssistant
6 changes: 5 additions & 1 deletion livekit-agents/livekit/agents/voice_assistant/__init__.py
@@ -4,4 +4,8 @@
     VoiceAssistant,
 )
 
-__all__ = ["VoiceAssistant", "AssistantCallContext", "AssistantTranscriptionOptions"]
+__all__ = [
+    "VoiceAssistant",
+    "AssistantCallContext",
+    "AssistantTranscriptionOptions",
+]
@@ -77,7 +77,10 @@ def llm_stream(self) -> LLMStream:
 def _default_before_llm_cb(
     assistant: VoiceAssistant, chat_ctx: ChatContext
 ) -> LLMStream:
-    return assistant.llm.chat(chat_ctx=chat_ctx, fnc_ctx=assistant.fnc_ctx)
+    return assistant.llm.chat(
+        chat_ctx=chat_ctx,
+        fnc_ctx=assistant.fnc_ctx,
+    )
 
 
 def _default_before_tts_cb(
@@ -690,7 +693,8 @@ def _commit_user_question_if_needed() -> None:
         chat_ctx.messages.extend(extra_tools_messages)
 
         answer_llm_stream = self._llm.chat(
-            chat_ctx=chat_ctx, fnc_ctx=self._fnc_ctx
+            chat_ctx=chat_ctx,
+            fnc_ctx=self._fnc_ctx,
         )
         answer_synthesis = self._synthesize_agent_speech(
             speech_handle.id, answer_llm_stream
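
The edits in this file are cosmetic (the `chat()` calls are split across lines), but `before_llm_cb` is also the natural hook for varying temperature per turn rather than per assistant. A hypothetical override, using only the callback signature above and the `temperature` parameter that `chat()` accepts after this PR:

```python
from livekit.agents.llm import ChatContext, LLMStream
from livekit.agents.voice_assistant import VoiceAssistant

def spicy_before_llm_cb(assistant: VoiceAssistant, chat_ctx: ChatContext) -> LLMStream:
    # Same as _default_before_llm_cb, except we pass an explicit temperature;
    # when temperature is None, chat() falls back to the constructor value.
    return assistant.llm.chat(
        chat_ctx=chat_ctx,
        fnc_ctx=assistant.fnc_ctx,
        temperature=0.9,  # hypothetical per-turn override
    )
```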
@@ -37,6 +37,7 @@
 class LLMOptions:
     model: str | ChatModels
     user: str | None
+    temperature: float | None
 
 
 class LLM(llm.LLM):
@@ -48,20 +49,20 @@ def __init__(
         base_url: str | None = None,
         user: str | None = None,
         client: anthropic.AsyncClient | None = None,
+        temperature: float | None = None,
     ) -> None:
         """
         Create a new instance of Anthropic LLM.
 
         ``api_key`` must be set to your Anthropic API key, either using the argument or by setting
         the ``ANTHROPIC_API_KEY`` environmental variable.
         """
 
         # throw an error on our end
         api_key = api_key or os.environ.get("ANTHROPIC_API_KEY")
         if api_key is None:
             raise ValueError("Anthropic API key is required")
 
-        self._opts = LLMOptions(model=model, user=user)
+        self._opts = LLMOptions(model=model, user=user, temperature=temperature)
         self._client = client or anthropic.AsyncClient(
             api_key=api_key,
             base_url=base_url,
@@ -85,6 +86,9 @@ def chat(
         n: int | None = 1,
         parallel_tool_calls: bool | None = None,
     ) -> "LLMStream":
+        if temperature is None:
+            temperature = self._opts.temperature
+
         opts: dict[str, Any] = dict()
         if fnc_ctx and len(fnc_ctx.ai_functions) > 0:
             fncs_desc: list[anthropic.types.ToolParam] = []
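
The fallback above is the crux of the Anthropic change: `chat()` keeps its own `temperature` parameter and only consults the value stored in `LLMOptions` when the caller passes nothing. Sketched behavior (the `ChatContext` setup is illustrative, and `ANTHROPIC_API_KEY` is assumed to be set):

```python
from livekit.agents.llm import ChatContext
from livekit.plugins import anthropic

llm = anthropic.LLM(temperature=0.2)  # stored in LLMOptions.temperature
ctx = ChatContext().append(role="system", text="You are a voice assistant.")

stream = llm.chat(chat_ctx=ctx)                   # None -> falls back to 0.2
stream = llm.chat(chat_ctx=ctx, temperature=0.8)  # explicit value wins
```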
@@ -43,6 +43,7 @@
 class LLMOptions:
     model: str | ChatModels
     user: str | None
+    temperature: float | None
 
 
 class LLM(llm.LLM):
@@ -54,20 +55,20 @@ def __init__(
         base_url: str | None = None,
         user: str | None = None,
         client: openai.AsyncClient | None = None,
+        temperature: float | None = None,
     ) -> None:
         """
         Create a new instance of OpenAI LLM.
 
         ``api_key`` must be set to your OpenAI API key, either using the argument or by setting the
         ``OPENAI_API_KEY`` environmental variable.
         """
 
         # throw an error on our end
         api_key = api_key or os.environ.get("OPENAI_API_KEY")
         if api_key is None:
             raise ValueError("OpenAI API key is required")
 
-        self._opts = LLMOptions(model=model, user=user)
+        self._opts = LLMOptions(model=model, user=user, temperature=temperature)
         self._client = client or openai.AsyncClient(
             api_key=api_key,
             base_url=base_url,
@@ -97,6 +98,7 @@ def with_azure(
         project: str | None = None,
         base_url: str | None = None,
         user: str | None = None,
+        temperature: float | None = None,
     ) -> LLM:
         """
         This automatically infers the following arguments from their corresponding environment variables if they are not provided:
@@ -120,7 +122,7 @@ def with_azure(
             base_url=base_url,
         )  # type: ignore
 
-        return LLM(model=model, client=azure_client, user=user)
+        return LLM(model=model, client=azure_client, user=user, temperature=temperature)
 
     @staticmethod
     def with_cerebras(
@@ -130,6 +132,7 @@ def with_cerebras(
         base_url: str | None = "https://api.cerebras.ai/v1",
         client: openai.AsyncClient | None = None,
         user: str | None = None,
+        temperature: float | None = None,
     ) -> LLM:
         """
         Create a new instance of Cerebras LLM.
@@ -144,7 +147,12 @@
             raise ValueError("Cerebras API key is required")
 
         return LLM(
-            model=model, api_key=api_key, base_url=base_url, client=client, user=user
+            model=model,
+            api_key=api_key,
+            base_url=base_url,
+            client=client,
+            user=user,
+            temperature=temperature,
         )
 
     @staticmethod
@@ -155,6 +163,7 @@ def with_fireworks(
         base_url: str | None = "https://api.fireworks.ai/inference/v1",
         client: openai.AsyncClient | None = None,
         user: str | None = None,
+        temperature: float | None = None,
     ) -> LLM:
         """
         Create a new instance of Fireworks LLM.
@@ -169,7 +178,12 @@
             raise ValueError("Fireworks API key is required")
 
         return LLM(
-            model=model, api_key=api_key, base_url=base_url, client=client, user=user
+            model=model,
+            api_key=api_key,
+            base_url=base_url,
+            client=client,
+            user=user,
+            temperature=temperature,
         )
 
     @staticmethod
@@ -180,6 +194,7 @@ def with_groq(
         base_url: str | None = "https://api.groq.com/openai/v1",
         client: openai.AsyncClient | None = None,
         user: str | None = None,
+        temperature: float | None = None,
     ) -> LLM:
         """
         Create a new instance of Groq LLM.
@@ -194,7 +209,12 @@
             raise ValueError("Groq API key is required")
 
         return LLM(
-            model=model, api_key=api_key, base_url=base_url, client=client, user=user
+            model=model,
+            api_key=api_key,
+            base_url=base_url,
+            client=client,
+            user=user,
+            temperature=temperature,
         )
 
     @staticmethod
@@ -205,6 +225,7 @@ def with_deepseek(
         base_url: str | None = "https://api.deepseek.com/v1",
         client: openai.AsyncClient | None = None,
         user: str | None = None,
+        temperature: float | None = None,
     ) -> LLM:
         """
         Create a new instance of DeepSeek LLM.
@@ -219,7 +240,12 @@
             raise ValueError("DeepSeek API key is required")
 
         return LLM(
-            model=model, api_key=api_key, base_url=base_url, client=client, user=user
+            model=model,
+            api_key=api_key,
+            base_url=base_url,
+            client=client,
+            user=user,
+            temperature=temperature,
         )
 
     @staticmethod
@@ -230,6 +256,7 @@ def with_octo(
         base_url: str | None = "https://text.octoai.run/v1",
         client: openai.AsyncClient | None = None,
         user: str | None = None,
+        temperature: float | None = None,
     ) -> LLM:
         """
         Create a new instance of OctoAI LLM.
@@ -244,7 +271,12 @@
             raise ValueError("OctoAI API key is required")
 
         return LLM(
-            model=model, api_key=api_key, base_url=base_url, client=client, user=user
+            model=model,
+            api_key=api_key,
+            base_url=base_url,
+            client=client,
+            user=user,
+            temperature=temperature,
        )
 
     @staticmethod
@@ -253,12 +285,19 @@ def with_ollama(
         model: str = "llama3.1",
         base_url: str | None = "http://localhost:11434/v1",
         client: openai.AsyncClient | None = None,
+        temperature: float | None = None,
     ) -> LLM:
         """
         Create a new instance of Ollama LLM.
         """
 
-        return LLM(model=model, api_key="ollama", base_url=base_url, client=client)
+        return LLM(
+            model=model,
+            api_key="ollama",
+            base_url=base_url,
+            client=client,
+            temperature=temperature,
+        )
 
     @staticmethod
     def with_perplexity(
@@ -268,9 +307,15 @@ def with_perplexity(
         base_url: str | None = "https://api.perplexity.ai",
         client: openai.AsyncClient | None = None,
         user: str | None = None,
+        temperature: float | None = None,
     ) -> LLM:
         return LLM(
-            model=model, api_key=api_key, base_url=base_url, client=client, user=user
+            model=model,
+            api_key=api_key,
+            base_url=base_url,
+            client=client,
+            user=user,
+            temperature=temperature,
         )
 
     @staticmethod
@@ -281,6 +326,7 @@ def with_together(
         base_url: str | None = "https://api.together.xyz/v1",
         client: openai.AsyncClient | None = None,
         user: str | None = None,
+        temperature: float | None = None,
     ) -> LLM:
         """
         Create a new instance of TogetherAI LLM.
@@ -295,7 +341,12 @@
             raise ValueError("TogetherAI API key is required")
 
         return LLM(
-            model=model, api_key=api_key, base_url=base_url, client=client, user=user
+            model=model,
+            api_key=api_key,
+            base_url=base_url,
+            client=client,
+            user=user,
+            temperature=temperature,
         )
 
     @staticmethod
@@ -312,6 +363,7 @@ def create_azure_client(
         project: str | None = None,
         base_url: str | None = None,
         user: str | None = None,
+        temperature: float | None = None,
     ) -> LLM:
         logger.warning("This alias is deprecated. Use LLM.with_azure() instead")
         return LLM.with_azure(
@@ -325,6 +377,7 @@ def create_azure_client(
             project=project,
             base_url=base_url,
             user=user,
+            temperature=temperature,
         )
 
     def chat(
@@ -348,6 +401,8 @@ def chat(
             opts["parallel_tool_calls"] = parallel_tool_calls
 
         user = self._opts.user or openai.NOT_GIVEN
+        if temperature is None:
+            temperature = self._opts.temperature
 
         messages = _build_oai_context(chat_ctx, id(self))
         cmp = self._client.chat.completions.create(
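
Every OpenAI-compatible factory (`with_azure`, `with_cerebras`, `with_fireworks`, `with_groq`, `with_deepseek`, `with_octo`, `with_ollama`, `with_perplexity`, `with_together`, and the deprecated `create_azure_client` alias) now forwards `temperature` into the `LLM` constructor, and `chat()` applies the same None-fallback as the Anthropic plugin, so the option behaves uniformly across providers. A quick sketch (the default models and the `GROQ_API_KEY` environment variable are assumptions, not part of this diff):

```python
from livekit.plugins import openai

# Assumes GROQ_API_KEY is set; the factory forwards temperature to LLM(...).
groq_llm = openai.LLM.with_groq(temperature=0.3)

# Ollama needs no real key; the factory passes "ollama" as a placeholder.
local_llm = openai.LLM.with_ollama(model="llama3.1", temperature=0.5)
```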