Skip to content

Commit

Permalink
Update config_flow.py
Browse files Browse the repository at this point in the history
  • Loading branch information
ITSpecialist111 committed Oct 19, 2024
1 parent 7a9bb62 commit 8a5a447
Showing 1 changed file with 61 additions and 7 deletions.
68 changes: 61 additions & 7 deletions custom_components/ai_automation_suggester/config_flow.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,11 +20,24 @@ async def async_step_user(self, user_input=None):
"""Handle the initial step."""
errors = {}
if user_input is not None:
# Validate API key if using cloud AI
if not user_input.get("use_local_ai") and not user_input.get("openai_api_key"):
errors["openai_api_key"] = "required"
else:
return self.async_create_entry(title="AI Automation Suggester", data=user_input)
try:
# Validate API key if using cloud AI
if not user_input.get("use_local_ai") and not user_input.get("openai_api_key"):
errors["openai_api_key"] = "required"
else:
# Validate the OpenAI API key
if not user_input.get("use_local_ai"):
await self.hass.async_add_executor_job(
self.validate_openai_api_key,
user_input.get("openai_api_key")
)
return self.async_create_entry(title="AI Automation Suggester", data=user_input)
except ValueError as e:
_LOGGER.error(f"Error during config flow: {e}")
errors["base"] = "invalid_api_key"
except Exception as e:
_LOGGER.error(f"Unexpected error during config flow: {e}")
errors["base"] = "cannot_connect"

data_schema = vol.Schema({
vol.Required("scan_frequency", default=24): vol.All(vol.Coerce(int), vol.Range(min=1)),
Expand All @@ -33,9 +46,21 @@ async def async_step_user(self, user_input=None):
})
return self.async_show_form(step_id="user", data_schema=data_schema, errors=errors)

def validate_openai_api_key(self, api_key):
"""Validate the OpenAI API key."""
import openai
openai.api_key = api_key
try:
openai.Engine.list()
except openai.error.AuthenticationError:
raise ValueError("Invalid OpenAI API key")
except Exception as e:
raise e

@staticmethod
@callback
def async_get_options_flow(config_entry):
"""Get the options flow."""
return AIAutomationOptionsFlowHandler(config_entry)


Expand All @@ -48,8 +73,26 @@ def __init__(self, config_entry):

async def async_step_init(self, user_input=None):
"""Manage the AI Automation Suggester options."""
errors = {}
if user_input is not None:
return self.async_create_entry(title="", data=user_input)
try:
# Validate API key if using cloud AI
if not user_input.get("use_local_ai") and not user_input.get("openai_api_key"):
errors["openai_api_key"] = "required"
else:
# Validate the OpenAI API key
if not user_input.get("use_local_ai"):
await self.hass.async_add_executor_job(
self.validate_openai_api_key,
user_input.get("openai_api_key")
)
return self.async_create_entry(title="", data=user_input)
except ValueError as e:
_LOGGER.error(f"Error during options flow: {e}")
errors["base"] = "invalid_api_key"
except Exception as e:
_LOGGER.error(f"Unexpected error during options flow: {e}")
errors["base"] = "cannot_connect"

data_schema = vol.Schema({
vol.Required("scan_frequency", default=self.config_entry.options.get("scan_frequency", 24)):
Expand All @@ -58,4 +101,15 @@ async def async_step_init(self, user_input=None):
vol.Optional("openai_api_key", default=self.config_entry.options.get("openai_api_key", "")): str,
})

return self.async_show_form(step_id="init", data_schema=data_schema)
return self.async_show_form(step_id="init", data_schema=data_schema, errors=errors)

def validate_openai_api_key(self, api_key):
"""Validate the OpenAI API key."""
import openai
openai.api_key = api_key
try:
openai.Engine.list()
except openai.error.AuthenticationError:
raise ValueError("Invalid OpenAI API key")
except Exception as e:
raise e

0 comments on commit 8a5a447

Please sign in to comment.