-
Notifications
You must be signed in to change notification settings - Fork 0
/
job_handler.py
136 lines (114 loc) · 4.6 KB
/
job_handler.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
import asyncio
import base64
import io
import json
import logging
import queue

import aiohttp
import discord
# Job-type tags; job_selector maps each one to its handler coroutine.
JOB_IMG, JOB_TXT = "img", "txt"
class JobQueue:
    """FIFO of pending bot jobs, processed one at a time by ``run_jobs``.

    Jobs are stored as ``(job_type, channel_id, user, input)`` tuples in a
    ``queue.Queue`` (the stdlib's thread-safe queue).
    """

    def __init__(self):
        self.jobs = queue.Queue()

    async def add_job(
        self,
        job_type: str,
        channel_id: int,
        # The job handlers read ``user.mention``, so this is a Discord user
        # object rather than a plain string.
        user: "discord.abc.User",
        input: dict[str, str],
    ) -> None:
        """Enqueue a job for ``run_jobs`` to pick up.

        Args:
            job_type: Dispatch key for ``job_selector`` (``JOB_IMG`` or ``JOB_TXT``).
            channel_id: ID of the channel the result should be posted to.
            user: The requesting user.
            input: Job-specific parameters, passed through unchanged.
        """
        logging.info("Added job to queue")
        self.jobs.put((job_type, channel_id, user, input))

    async def run_jobs(self, client: "discord.Client", job_function) -> None:
        """Process queued jobs forever, one at a time.

        Polls the queue without blocking and sleeps for a second when it is
        empty, so this coroutine never blocks the event loop it runs on. A job
        that raises is logged with its traceback and the loop moves on.

        Args:
            client: The discord client handed to each job handler.
            job_function: Unused; dispatch always goes through
                ``job_selector``. Kept so existing callers keep working.
        """
        while True:
            try:
                job_type, channel_id, user, input = self.jobs.get(block=False)
            except queue.Empty:
                await asyncio.sleep(1)
                continue

            logging.info(f"Job [{job_type}] found, executing...")
            try:
                await job_selector(job_type, client, input, channel_id, user)
            except Exception:
                # One failing job must not kill the worker; keep the traceback.
                logging.exception(f"Job [{job_type}] failed")
            finally:
                # Always balance the get() above, even on cancellation.
                self.jobs.task_done()
async def job_selector(job_type, client, input, channel_id, user):
    """Dispatch a queued job to the handler registered for ``job_type``.

    Args:
        job_type: ``JOB_IMG`` or ``JOB_TXT``.
        client: The discord client, forwarded to the handler.
        input: Job parameters, forwarded to the handler.
        channel_id: Target channel ID, forwarded to the handler.
        user: The requesting user, forwarded to the handler.

    Raises:
        KeyError: if ``job_type`` has no registered handler.
    """
    # Built at call time because job_img/job_txt are defined after this function.
    job_type_map = {
        JOB_IMG: job_img,
        JOB_TXT: job_txt,
    }
    logging.info("Executing job selector...")
    handler = job_type_map.get(job_type)
    if handler is None:
        # Same exception type as a plain dict lookup, but with a message that
        # is meaningful when run_jobs logs it.
        raise KeyError(f"Unknown job type: {job_type!r}")
    await handler(client, input, channel_id, user)
    logging.info("Job execution - done!")
# TODO: find a model that can handle text
async def job_txt(client, input, channel_id, user):
    """Reply to a ``/txt`` job in the originating channel.

    The text model is currently disabled, so this always posts a fixed
    placeholder reply that mentions ``user``; no backend is contacted.

    Args:
        client: The discord client whose event loop owns the channel.
        input: Request payload; unused while the text model is disabled.
        channel_id: ID of the channel to reply in.
        user: The requesting user; its ``mention`` prefixes the reply.
    """
    logging.info("Executing /txt job function...")
    # Intended backend (disabled until a usable text model is found): POST
    # {"input": input} as JSON to http://192.168.1.123:5001/predictions.
    result = "lol `/txt` model is down right now. :P"

    channel = client.get_channel(channel_id)
    if channel is None:
        # get_channel returns None when the channel is not in the client's cache.
        logging.error(f"Channel {channel_id} not found; dropping /txt reply for {user}.")
        return

    # run_coroutine_threadsafe hands the send to client.loop (presumably not
    # the loop running this job); wrap_future lets us await the outcome
    # without blocking our own loop, which Future.result() would do (and
    # would deadlock if this ran on client.loop itself).
    coro = channel.send(f"{user.mention} : {result}")
    await asyncio.wrap_future(asyncio.run_coroutine_threadsafe(coro, client.loop))
    logging.info(f"Sent response to {user}.")
async def _fetch_img_png(input):
    """POST ``input`` to the image prediction server; return PNG bytes, or None on failure."""
    url = "http://192.168.1.123:5000/predictions"
    headers = {"Content-Type": "application/json"}
    data = {"input": input}
    try:
        async with aiohttp.ClientSession() as session:
            async with session.post(
                url, headers=headers, data=json.dumps(data)
            ) as response:
                if response.status != 200:
                    logging.error(
                        f"Request failed. [Status code: {response.status}]")
                    # text() is a coroutine method; logging `response.text`
                    # itself would only print a bound-method repr.
                    logging.error(await response.text())
                    return None
                response_data = await response.json()
    except (aiohttp.ClientError, asyncio.TimeoutError):
        logging.exception("Request to the image server failed.")
        return None
    try:
        # "output" is presumably a data URL ("data:image/png;base64,<b64>");
        # rpartition also accepts a bare base64 string with no prefix.
        return base64.b64decode(response_data["output"].rpartition(",")[2])
    except (KeyError, TypeError, AttributeError, ValueError):
        logging.exception("Image server returned an unexpected payload.")
        return None


def _build_img_embed(input):
    """Return an embed listing the generation parameters taken from ``input``."""
    embed = discord.Embed()
    # (field label, input key, inline?) in display order.
    for name, key, inline in (
        ("prompt", "prompt", False),
        ("width", "width", True),
        ("height", "height", True),
        ("cfg scale", "cfg_scale", True),
        ("steps", "steps", True),
        ("sampler", "sampler_name", True),
        ("seed", "seed", True),
        ("negatives", "negative_prompt", False),
    ):
        embed.add_field(name=name, value=input[key], inline=inline)
    return embed


async def job_img(client, input, channel_id, user):
    """Generate an image for a ``/img`` job and post it to the originating channel.

    On success this posts a '✅ Done!' notice (via ``bot_send_msg``), then the
    PNG together with an embed of the generation parameters. If the
    prediction server cannot produce an image, the user is told so instead
    and no image is posted.

    Args:
        client: The discord client whose event loop owns the channel.
        input: Generation parameters; sent to the server as
            ``{"input": input}`` and must contain ``prompt``, ``width``,
            ``height``, ``cfg_scale``, ``steps``, ``sampler_name``, ``seed``
            and ``negative_prompt`` for the embed.
        channel_id: ID of the channel to reply in.
        user: The requesting user; its ``mention`` prefixes the reply.
    """
    logging.info("Executing /img job function...")
    channel = client.get_channel(channel_id)
    if channel is None:
        # get_channel returns None when the channel is not in the client's cache.
        logging.error(f"Channel {channel_id} not found; dropping /img job for {user}.")
        return

    png_data = await _fetch_img_png(input)

    # Each send is handed to client.loop via run_coroutine_threadsafe
    # (presumably not the loop running this job); wrap_future lets us await
    # it without blocking our own loop, unlike Future.result().
    if png_data is None:
        # Previously execution fell through and re-sent whatever output.png
        # was on disk (a stale image) or crashed if it did not exist.
        await asyncio.wrap_future(asyncio.run_coroutine_threadsafe(
            channel.send(
                f"{user.mention}, sorry, I couldn't generate your image. "
                "Please try again later."
            ),
            client.loop,
        ))
        return

    # Build the embed before sending anything, so a malformed input fails
    # before the '✅ Done!' notice goes out.
    data_embed = _build_img_embed(input)

    await asyncio.wrap_future(asyncio.run_coroutine_threadsafe(
        bot_send_msg(channel_id, client), client.loop))
    await asyncio.wrap_future(asyncio.run_coroutine_threadsafe(
        channel.send(
            f"{user.mention}, here's your image:",
            # Upload from memory: no shared output.png on disk that could be
            # left stale by a failed job.
            file=discord.File(io.BytesIO(png_data), filename="output.png"),
            embed=data_embed,
        ),
        client.loop,
    ))
    logging.info(f"Response sent to {user}")
async def bot_send_msg(channel_id: int, client: discord.Client):
    """Post the '✅ Done!' completion notice to the channel with ID ``channel_id``."""
    await client.get_channel(channel_id).send('✅ Done!')