Skip to content

Commit

Permalink
Support for GROQ API #65
Browse files Browse the repository at this point in the history
  • Loading branch information
valentinfrlch committed Sep 15, 2024
1 parent d36428d commit a862e6f
Show file tree
Hide file tree
Showing 10 changed files with 155 additions and 98 deletions.
2 changes: 1 addition & 1 deletion README.md
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
</p>
<p align=center>
<img src="https://img.shields.io/badge/HACS-Custom-orange.svg?style=for-the-badge">
<img src=https://img.shields.io/badge/version-1.1.1-blue>
<img src=https://img.shields.io/badge/version-1.1.3-blue>
<a href="https://github.com/valentinfrlch/ha-llmvision/issues">
<img src="https://img.shields.io/maintenance/yes/2024.svg">
<img alt="Issues" src="https://img.shields.io/github/issues/valentinfrlch/ha-llmvision?color=0088ff"/>
Expand Down
10 changes: 5 additions & 5 deletions custom_components/llmvision/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
CONF_OPENAI_API_KEY,
CONF_ANTHROPIC_API_KEY,
CONF_GOOGLE_API_KEY,
CONF_GROQ_API_KEY,
CONF_LOCALAI_IP_ADDRESS,
CONF_LOCALAI_PORT,
CONF_LOCALAI_HTTPS,
Expand All @@ -28,19 +29,15 @@
)
from .request_handlers import RequestHandler
from .media_handlers import MediaProcessor
import logging
from homeassistant.core import SupportsResponse
from homeassistant.exceptions import ServiceValidationError

_LOGGER = logging.getLogger(__name__)


async def async_setup_entry(hass, entry):
"""Save config entry to hass.data"""
# Get all entries from config flow
openai_api_key = entry.data.get(CONF_OPENAI_API_KEY)
anthropic_api_key = entry.data.get(CONF_ANTHROPIC_API_KEY)
google_api_key = entry.data.get(CONF_GOOGLE_API_KEY)
groq_api_key = entry.data.get(CONF_GROQ_API_KEY)
localai_ip_address = entry.data.get(CONF_LOCALAI_IP_ADDRESS)
localai_port = entry.data.get(CONF_LOCALAI_PORT)
localai_https = entry.data.get(CONF_LOCALAI_HTTPS)
Expand All @@ -61,6 +58,7 @@ async def async_setup_entry(hass, entry):
CONF_OPENAI_API_KEY: openai_api_key,
CONF_ANTHROPIC_API_KEY: anthropic_api_key,
CONF_GOOGLE_API_KEY: google_api_key,
CONF_GROQ_API_KEY: groq_api_key,
CONF_LOCALAI_IP_ADDRESS: localai_ip_address,
CONF_LOCALAI_PORT: localai_port,
CONF_LOCALAI_HTTPS: localai_https,
Expand Down Expand Up @@ -106,6 +104,8 @@ def _default_model(self, provider):
return "claude-3-5-sonnet-20240620"
elif provider == "Google":
return "gemini-1.5-flash-latest"
elif provider == "Groq":
return "llava-v1.5-7b-4096-preview"
elif provider == "LocalAI":
return "gpt-4-vision-preview"
elif provider == "Ollama":
Expand Down
52 changes: 50 additions & 2 deletions custom_components/llmvision/config_flow.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
CONF_OPENAI_API_KEY,
CONF_ANTHROPIC_API_KEY,
CONF_GOOGLE_API_KEY,
CONF_GROQ_API_KEY,
CONF_LOCALAI_IP_ADDRESS,
CONF_LOCALAI_PORT,
CONF_LOCALAI_HTTPS,
Expand Down Expand Up @@ -64,11 +65,22 @@ async def _validate_api_key(self, api_key):
payload = {
"contents": [{
"parts": [
{"text": "Hello, world!"}
{"text": "Hello"}
]}
]
}
method = "POST"
elif self.user_input["provider"] == "Groq":
header = {
'Authorization': 'Bearer ' + api_key,
'Content-Type': 'application/json'
}
base_url = "api.groq.com"
endpoint = "/openai/v1/chat/completions"
payload = {"messages": [
{"role": "user", "content": "Hello"}], "model": "gemma-7b-it"}
method = "POST"

return await self._handshake(base_url=base_url, endpoint=endpoint, protocol="https", header=header, payload=payload, expected_status=200, method=method)

def _validate_provider(self):
Expand Down Expand Up @@ -158,6 +170,12 @@ async def google(self):
_LOGGER.error("Could not connect to Google server.")
raise ServiceValidationError("handshake_failed")

async def groq(self):
    """Validate the Groq configuration by performing an API handshake.

    Raises:
        ServiceValidationError: if the handshake with the Groq API fails.
    """
    self._validate_provider()
    key_is_valid = await self._validate_api_key(
        self.user_input[CONF_GROQ_API_KEY])
    if not key_is_valid:
        _LOGGER.error("Could not connect to Groq server.")
        raise ServiceValidationError("handshake_failed")

def get_configured_providers(self):
providers = []
try:
Expand All @@ -177,6 +195,8 @@ def get_configured_providers(self):
providers.append("Ollama")
if CONF_CUSTOM_OPENAI_ENDPOINT in self.hass.data[DOMAIN]:
providers.append("Custom OpenAI")
if CONF_GROQ_API_KEY in self.hass.data[DOMAIN]:
providers.append("Groq")
return providers


Expand All @@ -193,6 +213,7 @@ async def handle_provider(self, provider, configured_providers):
"OpenAI": self.async_step_openai,
"Anthropic": self.async_step_anthropic,
"Google": self.async_step_google,
"Groq": self.async_step_groq,
"Ollama": self.async_step_ollama,
"LocalAI": self.async_step_localai,
"Custom OpenAI": self.async_step_custom_openai,
Expand All @@ -209,7 +230,7 @@ async def async_step_user(self, user_input=None):
data_schema = vol.Schema({
vol.Required("provider", default="OpenAI"): selector({
"select": {
"options": ["OpenAI", "Anthropic", "Google", "Ollama", "LocalAI", "Custom OpenAI"],
"options": ["OpenAI", "Anthropic", "Google", "Groq", "Ollama", "LocalAI", "Custom OpenAI"],
"mode": "dropdown",
"sort": False,
"custom_value": False
Expand Down Expand Up @@ -369,6 +390,33 @@ async def async_step_google(self, user_input=None):
data_schema=data_schema,
)

async def async_step_groq(self, user_input=None):
    """Config-flow step that collects and validates the Groq API key.

    Shows a single form asking for the API key; on submit the key is
    validated via a handshake request, and on success a config entry
    titled "LLM Vision Groq" is created.

    Args:
        user_input: form data on submit, or None on the first visit.

    Returns:
        A form (first visit or validation failure) or a created entry.
    """
    data_schema = vol.Schema({
        vol.Required(CONF_GROQ_API_KEY): str,
    })

    if user_input is not None:
        # Carry the provider chosen in the initial step along with the
        # entry data. (Fix: this was assigned twice before; the second,
        # redundant assignment after validation has been removed.)
        user_input["provider"] = self.init_info["provider"]
        validator = Validator(self.hass, user_input)
        try:
            await validator.groq()
            return self.async_create_entry(
                title="LLM Vision Groq", data=user_input)
        except ServiceValidationError as e:
            _LOGGER.error(f"Validation failed: {e}")
            return self.async_show_form(
                step_id="groq",
                data_schema=data_schema,
                errors={"base": "handshake_failed"}
            )

    # First visit: show the empty form.
    return self.async_show_form(
        step_id="groq",
        data_schema=data_schema,
    )

async def async_step_custom_openai(self, user_input=None):
data_schema = vol.Schema({
vol.Required(CONF_CUSTOM_OPENAI_ENDPOINT): str,
Expand Down
2 changes: 2 additions & 0 deletions custom_components/llmvision/const.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
CONF_OPENAI_API_KEY = 'openai_api_key'
CONF_ANTHROPIC_API_KEY = 'anthropic_api_key'
CONF_GOOGLE_API_KEY = 'google_api_key'
CONF_GROQ_API_KEY = 'groq_api_key'
CONF_LOCALAI_IP_ADDRESS = 'localai_ip'
CONF_LOCALAI_PORT = 'localai_port'
CONF_LOCALAI_HTTPS = 'localai_https'
Expand Down Expand Up @@ -49,5 +50,6 @@
ENDPOINT_OPENAI = "https://api.openai.com/v1/chat/completions"
ENDPOINT_ANTHROPIC = "https://api.anthropic.com/v1/messages"
ENDPOINT_GOOGLE = "https://generativelanguage.googleapis.com/v1beta/models/{model}:generateContent?key={api_key}"
ENDPOINT_GROQ = "https://api.groq.com/openai/v1/chat/completions"
ENDPOINT_LOCALAI = "{protocol}://{ip_address}:{port}/v1/chat/completions"
ENDPOINT_OLLAMA = "{protocol}://{ip_address}:{port}/api/chat"
2 changes: 1 addition & 1 deletion custom_components/llmvision/manifest.json
Original file line number Diff line number Diff line change
Expand Up @@ -6,5 +6,5 @@
"documentation": "https://github.com/valentinfrlch/ha-llmvision",
"iot_class": "cloud_polling",
"issue_tracker": "https://github.com/valentinfrlch/ha-llmvision/issues",
"version": "1.1.2"
"version": "1.1.3"
}
127 changes: 58 additions & 69 deletions custom_components/llmvision/request_handlers.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
CONF_OPENAI_API_KEY,
CONF_ANTHROPIC_API_KEY,
CONF_GOOGLE_API_KEY,
CONF_GROQ_API_KEY,
CONF_LOCALAI_IP_ADDRESS,
CONF_LOCALAI_PORT,
CONF_LOCALAI_HTTPS,
Expand All @@ -17,6 +18,7 @@
CONF_CUSTOM_OPENAI_API_KEY,
VERSION_ANTHROPIC,
ENDPOINT_OPENAI,
ENDPOINT_GROQ,
ERROR_OPENAI_NOT_CONFIGURED,
ERROR_ANTHROPIC_NOT_CONFIGURED,
ERROR_GOOGLE_NOT_CONFIGURED,
Expand All @@ -26,8 +28,6 @@
)

_LOGGER = logging.getLogger(__name__)
# base64_pattern = re.compile(r'([A-Za-z0-9+/=]{1000,})')


def sanitize_data(data):
"""Remove long string data from request data to reduce log size"""
Expand All @@ -53,7 +53,6 @@ def __init__(self, hass, message, max_tokens, temperature, detail):
self.filenames = []

async def make_request(self, call):
_LOGGER.debug(f"Base64 Images: {sanitize_data(self.base64_images)}")
if call.provider == 'OpenAI':
api_key = self.hass.data.get(DOMAIN).get(CONF_OPENAI_API_KEY)
model = call.model
Expand All @@ -75,6 +74,13 @@ async def make_request(self, call):
api_key=api_key,
base64_images=self.base64_images)
response_text = await self.google(model=model, api_key=api_key)
elif call.provider == 'Groq':
api_key = self.hass.data.get(DOMAIN).get(CONF_GROQ_API_KEY)
model = call.model
self._validate_call(provider=call.provider,
api_key=api_key,
base64_images=self.base64_images)
response_text = await self.groq(model=model, api_key=api_key)
elif call.provider == 'LocalAI':
ip_address = self.hass.data.get(
DOMAIN, {}).get(CONF_LOCALAI_IP_ADDRESS)
Expand All @@ -93,9 +99,11 @@ async def make_request(self, call):
port=port,
https=https)
elif call.provider == 'Ollama':
ip_address = self.hass.data.get(DOMAIN, {}).get(CONF_OLLAMA_IP_ADDRESS)
ip_address = self.hass.data.get(
DOMAIN, {}).get(CONF_OLLAMA_IP_ADDRESS)
port = self.hass.data.get(DOMAIN, {}).get(CONF_OLLAMA_PORT)
https = self.hass.data.get(DOMAIN, {}).get(CONF_OLLAMA_HTTPS, False)
https = self.hass.data.get(DOMAIN, {}).get(
CONF_OLLAMA_HTTPS, False)
model = call.model
self._validate_call(provider=call.provider,
api_key=None,
Expand All @@ -109,7 +117,8 @@ async def make_request(self, call):
elif call.provider == 'Custom OpenAI':
api_key = self.hass.data.get(DOMAIN).get(
CONF_CUSTOM_OPENAI_API_KEY, "")
endpoint = self.hass.data.get(DOMAIN).get(CONF_CUSTOM_OPENAI_ENDPOINT)
endpoint = self.hass.data.get(DOMAIN).get(
CONF_CUSTOM_OPENAI_ENDPOINT)

model = call.model
self._validate_call(provider=call.provider,
Expand Down Expand Up @@ -246,6 +255,36 @@ async def google(self, model, api_key):
"content").get("parts")[0].get("text")
return response_text

async def groq(self, model, api_key, endpoint=ENDPOINT_GROQ):
    """Send a vision request to Groq's OpenAI-compatible chat endpoint.

    Args:
        model: Groq model identifier to use.
        api_key: Groq API key, sent as a Bearer token.
        endpoint: Chat-completions URL; defaults to ENDPOINT_GROQ.

    Returns:
        The assistant's reply text from the first choice.
    """
    # NOTE(review): only the first image is sent — presumably a limit of
    # Groq's vision models; confirm before relying on multi-image calls.
    first_image = self.base64_images[0]
    # Set headers and payload
    headers = {'Content-type': 'application/json',
               'Authorization': 'Bearer ' + api_key}
    data = {
        "messages": [
            {
                "role": "user",
                "content": [
                    {"type": "text", "text": self.message},
                    {
                        "type": "image_url",
                        "image_url": {"url": f"data:image/jpeg;base64,{first_image}"}
                    }
                ]
            }
        ],
        "model": model
    }

    response = await self._post(
        url=endpoint, headers=headers, data=data)

    # Fix: replaced a stray debug `print(response)` with proper logging
    # (lazy %-args keep formatting off the hot path).
    _LOGGER.debug("Groq response: %s", response)

    response_text = response.get(
        "choices")[0].get("message").get("content")
    return response_text

async def localai(self, model, ip_address, port, https):
from .const import ENDPOINT_LOCALAI
data = {"model": model,
Expand Down Expand Up @@ -312,8 +351,6 @@ async def ollama(self, model, ip_address, port, https):
async def _post(self, url, headers, data):
"""Post data to url and return response data"""
_LOGGER.info(f"Request data: {sanitize_data(data)}")
_LOGGER.debug(
f"URL type: {type(url)}, Headers type: {type(headers)}, Data type: {type(data)}")

try:
response = await self.session.post(url, headers=headers, json=data)
Expand All @@ -322,7 +359,8 @@ async def _post(self, url, headers, data):

if response.status != 200:
provider = inspect.stack()[1].function
_LOGGER.debug(f"Provider: {provider}")
_LOGGER.error(
f"Provider {provider} failed with status code {response.status}")
parsed_response = self._resolve_error(url, response, provider)
raise ServiceValidationError(parsed_response)
else:
Expand Down Expand Up @@ -366,74 +404,25 @@ def _validate_call(self, provider, api_key, base64_images, ip_address=None, port
elif provider == 'Ollama':
if not ip_address or not port:
raise ServiceValidationError(ERROR_OLLAMA_NOT_CONFIGURED)
elif provider == 'Custom OpenAI':
pass
else:
raise ServiceValidationError("invalid_provider")
# Check media input
if base64_images == []:
raise ServiceValidationError(ERROR_NO_IMAGE_INPUT)

def _resolve_error(self, url, response, provider):
"""Translate response status to error message"""
if provider == "openai":
if response.status == 401:
return "Invalid Authentication. Ensure you are using a valid API key."
if response.status == 403:
return "Country, region, or territory not supported."
if response.status == 404:
return "The requested model does not exist."
if response.status == 429:
return "Rate limit exceeded. You are sending requests too quickly."
if response.status == 500:
return "Issue on OpenAI's servers. Wait a few minutes and try again."
if response.status == 503:
return "OpenAI's Servers are experiencing high traffic. Try again later."
else:
return f"Error: {response}"
return f"Error: {response.get('error', {}).get('message', 'Unknown error')}"
elif provider == "anthropic":
if response.status == 400:
return "Invalid Request. There was an issue with the format or content of your request."
if response.status == 401:
return "Invalid Authentication. Ensure you are using a valid API key."
if response.status == 403:
return "Access Error. Your API key does not have permission to use the specified resource."
if response.status == 404:
return "The requested model does not exist."
if response.status == 429:
return "Rate limit exceeded. You are sending requests too quickly."
if response.status == 500:
return "Issue on Anthropic's servers. Wait a few minutes and try again."
if response.status == 529:
return "Anthropic's Servers are experiencing high traffic. Try again later."
else:
return f"Error: {response}"
return f"Error: {response.get('error', {}).get('message', 'Unknown error')}"
elif provider == "google":
if response.status == 400:
return "Max input tokens exceeded. Reduce target width or number of images."
if response.status == 403:
return "Access Error. Your API key does not have permission to use the specified resource."
if response.status == 404:
return "The requested model does not exist."
if response.status == 406:
return "Insufficient Funds. Ensure you have enough credits to use the service."
if response.status == 429:
return "Rate limit exceeded. You are sending requests too quickly."
if response.status == 503:
return "Google's Servers are temporarily overloaded or down. Try again later."
else:
return f"Error: {response}"
return f"Error: {response.get('error', {}).get('message', 'Unknown error')}"
elif provider == "groq":
return f"Error: {response.get('error', {}).get('message', 'Unknown error')}"
elif provider == "ollama":
if response.status == 400:
return "Invalid Request. There was an issue with the format or content of your request."
if response.status == 404:
return "The requested model does not exist."
if response.status == 500:
return "Internal server issue (on Ollama server)."
else:
return f"Error: {response}"
return f"Error: {response.get('error', {}).get('message', 'Unknown error')}"
elif provider == "localai":
if response.status == 400:
return "Invalid Request. There was an issue with the format or content of your request."
if response.status == 404:
return "The requested model does not exist."
if response.status == 500:
return "Internal server issue (on LocalAI server)."
else:
return f"Error: {response}"
return f"Error: {response.get('error', {}).get('message', 'Unknown error')}"
Loading

0 comments on commit a862e6f

Please sign in to comment.