Skip to content

Commit

Permalink
Merge pull request #160 from F33RNI/next
Browse files Browse the repository at this point in the history
Next
  • Loading branch information
F33RNI authored Apr 22, 2024
2 parents 111f9ce + f6331d7 commit f8a36e8
Show file tree
Hide file tree
Showing 27 changed files with 1,408 additions and 50 deletions.
2 changes: 2 additions & 0 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -24,3 +24,5 @@ venv
env
.env
__pycache__
certificate.*
private.*
2 changes: 1 addition & 1 deletion _version.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@

from packaging import version

__version__ = "5.2.10"
__version__ = "5.4.2"


def version_major() -> int:
Expand Down
212 changes: 205 additions & 7 deletions bot_handler.py
Original file line number Diff line number Diff line change
Expand Up @@ -64,6 +64,7 @@
BOT_COMMAND_CHAT = "chat"
BOT_COMMAND_MODULE = "module"
BOT_COMMAND_STYLE = "style"
BOT_COMMAND_MODEL = "model"
BOT_COMMAND_CLEAR = "clear"
BOT_COMMAND_LANG = "lang"
BOT_COMMAND_CHAT_ID = "chatid"
Expand Down Expand Up @@ -121,6 +122,8 @@ def __init__(
logging_queue: multiprocessing.Queue,
queue_handler_: queue_handler.QueueHandler,
modules: Dict,
web_cooldown_timer: multiprocessing.Value,
web_request_lock: multiprocessing.Lock,
):
self.config = config
self.config_file = config_file
Expand All @@ -130,6 +133,10 @@ def __init__(
self.queue_handler = queue_handler_
self.modules = modules

# LMAO
self.web_cooldown_timer = web_cooldown_timer
self.web_request_lock = web_request_lock

self.prevent_shutdown_flag = multiprocessing.Value(c_bool, False)

self._application = None
Expand Down Expand Up @@ -159,7 +166,17 @@ def start_bot(self):

# Build bot
telegram_config = self.config.get("telegram")
builder = ApplicationBuilder().token(telegram_config.get("api_key"))
proxy = telegram_config.get("proxy")
if proxy:
logging.info(f"Using proxy {proxy} for Telegram bot")
builder = (
ApplicationBuilder()
.token(telegram_config.get("api_key"))
.proxy(proxy)
.get_updates_proxy(proxy)
)
else:
builder = ApplicationBuilder().token(telegram_config.get("api_key"))
self._application = builder.build()

# Set commands
Expand All @@ -172,6 +189,7 @@ def start_bot(self):
self._application.add_handler(CaptionCommandHandler(BOT_COMMAND_CHAT, self.bot_module_request))
self._application.add_handler(CaptionCommandHandler(BOT_COMMAND_MODULE, self.bot_command_module))
self._application.add_handler(CaptionCommandHandler(BOT_COMMAND_STYLE, self.bot_command_style))
self._application.add_handler(CaptionCommandHandler(BOT_COMMAND_MODEL, self.bot_command_model))
self._application.add_handler(CaptionCommandHandler(BOT_COMMAND_CLEAR, self.bot_command_clear))
self._application.add_handler(CaptionCommandHandler(BOT_COMMAND_LANG, self.bot_command_lang))
self._application.add_handler(CaptionCommandHandler(BOT_COMMAND_CHAT_ID, self.bot_command_chatid))
Expand Down Expand Up @@ -302,15 +320,13 @@ async def query_callback(self, update: Update, context: ContextTypes.DEFAULT_TYP
return

# Parse data from markup
action, data_, reply_message_id = data_.split("|")
action, data_, argument_ = data_.split("|")
if not action:
raise Exception("No action in callback data")
if not data_:
data_ = None
if not reply_message_id:
reply_message_id = None
else:
reply_message_id = int(reply_message_id.strip())
if not argument_:
argument_ = None

# Get user
banned, user = await self._user_get_check(update, context, prompt_language_selection=False)
Expand All @@ -329,6 +345,12 @@ async def query_callback(self, update: Update, context: ContextTypes.DEFAULT_TYP

# Regenerate request
if action == "regenerate":
# Parse message ID
if not argument_:
reply_message_id = None
else:
reply_message_id = int(argument_.strip())

# Get last message ID
reply_message_id_last = self.users_handler.get_key(0, "reply_message_id_last", user=user)
if reply_message_id_last is None or reply_message_id_last != reply_message_id:
Expand Down Expand Up @@ -364,6 +386,12 @@ async def query_callback(self, update: Update, context: ContextTypes.DEFAULT_TYP

# Continue generating
elif action == "continue":
# Parse message ID
if not argument_:
reply_message_id = None
else:
reply_message_id = int(argument_.strip())

# Get last message ID
reply_message_id_last = self.users_handler.get_key(0, "reply_message_id_last", user=user)
if reply_message_id_last is None or reply_message_id_last != reply_message_id:
Expand All @@ -385,6 +413,12 @@ async def query_callback(self, update: Update, context: ContextTypes.DEFAULT_TYP

# Send suggestion
elif action == "suggestion":
# Parse message ID
if not argument_:
reply_message_id = None
else:
reply_message_id = int(argument_.strip())

# Get last message ID
reply_message_id_last = self.users_handler.get_key(0, "reply_message_id_last", user=user)
if reply_message_id_last is None or reply_message_id_last != reply_message_id:
Expand Down Expand Up @@ -421,6 +455,12 @@ async def query_callback(self, update: Update, context: ContextTypes.DEFAULT_TYP

# Stop generating
elif action == "stop":
# Parse message ID
if not argument_:
reply_message_id = None
else:
reply_message_id = int(argument_.strip())

# Get last message ID
reply_message_id_last = self.users_handler.get_key(0, "reply_message_id_last", user=user)
if reply_message_id_last is None or reply_message_id_last != reply_message_id:
Expand Down Expand Up @@ -466,6 +506,10 @@ async def query_callback(self, update: Update, context: ContextTypes.DEFAULT_TYP
elif action == "style":
await self._bot_command_style_raw(data_, user, context)

# Change model
elif action == "model":
await self._bot_command_model_raw(data_, argument_, user, context)

# Change language
elif action == "lang":
await self._bot_command_lang_raw(data_, user, context)
Expand Down Expand Up @@ -752,8 +796,20 @@ async def bot_command_restart(self, update: Update, context: ContextTypes.DEFAUL
continue
logging.info(f"Trying to load and initialize {module_name} module")
try:
use_web = (
module_name.startswith("lmao_")
and module_name in self.config.get("modules").get("lmao_web_for_modules", [])
and "lmao_web_api_url" in self.config.get("modules")
)
module = module_wrapper_global.ModuleWrapperGlobal(
module_name, self.config, self.messages, self.users_handler, self.logging_queue
module_name,
self.config,
self.messages,
self.users_handler,
self.logging_queue,
use_web=use_web,
web_cooldown_timer=self.web_cooldown_timer,
web_request_lock=self.web_request_lock,
)
self.modules[module_name] = module
reload_logs += f"Intialized and loaded {module_name} module\n"
Expand Down Expand Up @@ -1042,6 +1098,148 @@ async def _bot_command_style_raw(self, style: str or None, user: Dict, context:
context,
)

async def bot_command_model(self, update: Update, context: ContextTypes.DEFAULT_TYPE) -> None:
"""/model commands callback
Args:
update (Update): update object from bot's callback
context (ContextTypes.DEFAULT_TYPE): context object from bot's callback
"""
# Get user
banned, user = await self._user_get_check(update, context)
if user is None:
return
user_id = user.get("user_id")
user_name = self.users_handler.get_key(0, "user_name", "", user=user)
lang_id = self.users_handler.get_key(0, "lang_id", user=user)

# Log command
logging.info(f"/model command from {user_name} ({user_id})")

# Exit if banned
if banned:
return

module_id = self.users_handler.get_key(0, "module", self.config.get("modules").get("default"), user=user)

model = None

# User specified model
if context.args and len(context.args) >= 1:
try:
model = context.args[0].strip().lower()

# Get available models
current_module_id = self.users_handler.get_key(
0, "module", self.config.get("modules").get("default"), user=user
)
available_models = self.config.get(current_module_id).get("models", [])

# Get current model
model_current = self.config.get(module_id).get("model_default")
model_current = self.users_handler.get_key(0, f"{module_id}_model", model_current, user=user)

# Check
if not model_current or len(available_models) == 0:
await _send_safe(
user_id,
self.messages.get_message("model_no_models", lang_id=lang_id),
context,
)
return

# Check
if model not in available_models:
raise Exception(f"No model {model} in {' '.join(available_models)}")
except Exception as e:
logging.error("Error retrieving requested model", exc_info=e)
await _send_safe(
user["user_id"],
self.messages.get_message("model_change_error", lang_id=lang_id).format(error_text=str(e)),
context,
)
return

# Change model or ask the user
await self._bot_command_model_raw(module_id, model, user, context)

async def _bot_command_model_raw(
self, module_id: str or None, model: str or None, user: Dict, context: ContextTypes.DEFAULT_TYPE
) -> None:
"""Changes model of module
Args:
module_id (str or None): id of module to change model of
model (str or None): model name or None to ask user
user (Dict): user's data as dictionary
context (ContextTypes.DEFAULT_TYPE): context object from bot's callback
"""
user_id = user.get("user_id")
lang_id = self.users_handler.get_key(0, "lang_id", user=user)

# Extract current user's module and model
module_icon_names = self.messages.get_message("modules", lang_id=lang_id)
if not module_id:
module_id = self.users_handler.get_key(0, "module", self.config.get("modules").get("default"), user=user)
current_module_name = module_icon_names.get(module_id).get("name")
current_module_icon = module_icon_names.get(module_id).get("icon")
current_module_name = f"{current_module_icon} {current_module_name}"

# Get available models
available_models = self.config.get(module_id).get("models", [])

# Get current model
model_current = self.config.get(module_id).get("model_default")
model_current = self.users_handler.get_key(0, f"{module_id}_model", model_current, user=user)

# Check
if not model_current or len(available_models) == 0:
await _send_safe(
user_id,
self.messages.get_message("model_no_models", lang_id=lang_id),
context,
)
return

# Ask user
if not model:
buttons = []
for model_ in available_models:
buttons.append(InlineKeyboardButton(model_, callback_data=f"model|{module_id}|{model_}"))

await _send_safe(
user_id,
self.messages.get_message("model_select", lang_id=lang_id).format(
module_name=current_module_name, current_model=model_current
),
context,
reply_markup=InlineKeyboardMarkup(bot_sender.build_menu(buttons)),
)
return

# Change model
try:
# Change model of user
self.users_handler.set_key(user_id, f"{module_id}_model", model)

# Send confirmation
await _send_safe(
user_id,
self.messages.get_message("model_changed", lang_id=lang_id).format(
module_name=current_module_name, changed_model=model
),
context,
)

# Error changing model
except Exception as e:
logging.error("Error changing model", exc_info=e)
await _send_safe(
user_id,
self.messages.get_message("model_change_error", lang_id=lang_id).format(error_text=str(e)),
context,
)

########################################
# General (non-modules) commands below #
########################################
Expand Down
8 changes: 8 additions & 0 deletions bot_sender.py
Original file line number Diff line number Diff line change
Expand Up @@ -226,6 +226,14 @@ def build_markup(
)
buttons.append(button_style)

# Add change model button
if request_response.module_name in module_wrapper_global.MODULES_WITH_MODELS:
button_model = InlineKeyboardButton(
messages_.get_message("button_model_change", user_id=user_id),
callback_data=f"model|{request_response.module_name}|",
)
buttons.append(button_model)

# Add change module button for all modules
button_module = InlineKeyboardButton(
messages_.get_message("button_module", user_id=user_id),
Expand Down
Loading

0 comments on commit f8a36e8

Please sign in to comment.