Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Replace Http Server + Add WebSockets #9

Merged
merged 9 commits into from
Feb 13, 2023
224 changes: 142 additions & 82 deletions main.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,17 @@
import threading
import heapq
import traceback
import asyncio

try:
import aiohttp
from aiohttp import web
except ImportError:
print("Module 'aiohttp' not installed. Please install it via:")
print("pip install aiohttp")
print("or")
print("pip install -r requirements.txt")
sys.exit()

if __name__ == "__main__":
if '--help' in sys.argv:
Expand All @@ -25,7 +36,6 @@
os.environ['ATTN_PRECISION'] = "fp16"

import torch

import nodes

def get_input_data(inputs, class_def, outputs={}, prompt={}, extra_data={}):
Expand Down Expand Up @@ -286,16 +296,19 @@ def prompt_worker(q):
q.task_done(item_id)

class PromptQueue:
def __init__(self):
def __init__(self, socket_handler):
self.socket_handler = socket_handler
self.mutex = threading.RLock()
self.not_empty = threading.Condition(self.mutex)
self.task_counter = 0
self.queue = []
self.currently_running = {}
socket_handler.prompt_queue = self

def put(self, item):
with self.mutex:
heapq.heappush(self.queue, item)
self.socket_handler.queue_updated(self)
self.not_empty.notify()

def get(self):
Expand All @@ -306,11 +319,13 @@ def get(self):
i = self.task_counter
self.currently_running[i] = copy.deepcopy(item)
self.task_counter += 1
self.socket_handler.queue_updated(self)
return (item, i)

def task_done(self, item_id):
with self.mutex:
self.currently_running.pop(item_id)
self.socket_handler.queue_updated(self)

def get_current_queue(self):
with self.mutex:
Expand All @@ -326,6 +341,7 @@ def get_tasks_remaining(self):
def wipe_queue(self):
with self.mutex:
self.queue = []
self.socket_handler.queue_updated(self)

def delete_queue_item(self, function):
with self.mutex:
Expand All @@ -336,35 +352,80 @@ def delete_queue_item(self, function):
else:
self.queue.pop(x)
heapq.heapify(self.queue)
self.socket_handler.queue_updated(self)
return True
return False

from http.server import BaseHTTPRequestHandler, HTTPServer

class PromptServer(BaseHTTPRequestHandler):
def _set_headers(self, code=200, ct='text/html'):
self.send_response(code)
self.send_header('Content-type', ct)
self.end_headers()
def log_message(self, format, *args):
pass
def do_GET(self):
if self.path == "/prompt":
self._set_headers(ct='application/json')
prompt_info = {}
exec_info = {}
exec_info['queue_remaining'] = self.server.prompt_queue.get_tasks_remaining()
prompt_info['exec_info'] = exec_info
self.wfile.write(json.dumps(prompt_info).encode('utf-8'))
elif self.path == "/queue":
self._set_headers(ct='application/json')
queue_info = {}
current_queue = self.server.prompt_queue.get_current_queue()
queue_info['queue_running'] = current_queue[0]
queue_info['queue_pending'] = current_queue[1]
self.wfile.write(json.dumps(queue_info).encode('utf-8'))
elif self.path == "/object_info":
self._set_headers(ct='application/json')
def get_queue_info(prompt_queue):
prompt_info = {}
exec_info = {}
exec_info['queue_remaining'] = prompt_queue.get_tasks_remaining()
prompt_info['exec_info'] = exec_info
return prompt_info

class SocketHandler():
def __init__(self, loop):
self.connected = set()
self.messages = asyncio.Queue()
self.loop = loop

async def publish_loop(self):
while True:
msg = await self.messages.get()
await self.send(msg)

def queue_updated(self, queue):
# This is called by the queue processing thread so we need to make it thread safe
loop.call_soon_threadsafe(self.messages.put_nowait, { 'type': 'status', 'status': get_queue_info(queue) })

async def send(self, message, socket = None):
if isinstance(message, str) == False:
message = json.dumps(message)

if socket is None:
for ws in self.connected:
await ws.send_str(message)
else:
await socket.send_str(message)

async def process(self, request):
ws = web.WebSocketResponse()
await ws.prepare(request)
self.connected.add(ws)
try:
# Send initial state to the new client
await self.send({ 'type': 'status', 'status': get_queue_info(self.prompt_queue) }, ws)
async for msg in ws:
if msg.type == aiohttp.WSMsgType.ERROR:
print('ws connection closed with exception %s' % ws.exception())
finally:
self.connected.remove(ws)

return ws

class PromptServer():
def __init__(self, prompt_queue, socket_handler):
self.prompt_queue = prompt_queue
self.socket_handler = socket_handler
self.number = 0
self.app = web.Application()
self.web_root = os.path.join(os.path.dirname(os.path.realpath(__file__)), "webshit")
routes = web.RouteTableDef()

@routes.get('/ws')
async def websocket_handler(request):
return await self.socket_handler.process(request)

@routes.get("/")
async def get_root(request):
return web.FileResponse(os.path.join(self.web_root, "index.html"))

@routes.get("/prompt")
async def get_prompt(request):
return web.json_response(get_queue_info(self.prompt_queue))

@routes.get("/object_info")
async def get_object_info(request):
out = {}
for x in nodes.NODE_CLASS_MAPPINGS:
obj_class = nodes.NODE_CLASS_MAPPINGS[x]
Expand All @@ -377,87 +438,87 @@ def do_GET(self):
if hasattr(obj_class, 'CATEGORY'):
info['category'] = obj_class.CATEGORY
out[x] = info
self.wfile.write(json.dumps(out).encode('utf-8'))
elif self.path[1:] in os.listdir(self.server.server_dir):
if self.path[1:].endswith('.css'):
self._set_headers(ct='text/css')
elif self.path[1:].endswith('.js'):
self._set_headers(ct='text/javascript')
else:
self._set_headers()
with open(os.path.join(self.server.server_dir, self.path[1:]), "rb") as f:
self.wfile.write(f.read())
else:
self._set_headers()
with open(os.path.join(self.server.server_dir, "index.html"), "rb") as f:
self.wfile.write(f.read())

def do_HEAD(self):
self._set_headers()

def do_POST(self):
resp_code = 200
out_string = ""
if self.path == "/prompt":
return web.json_response(out)

@routes.get("/queue")
async def get_queue(request):
queue_info = {}
current_queue = self.prompt_queue.get_current_queue()
queue_info['queue_running'] = current_queue[0]
queue_info['queue_pending'] = current_queue[1]
return web.json_response(queue_info)

@routes.post("/prompt")
async def post_prompt(request):
print("got prompt")
data_string = self.rfile.read(int(self.headers['Content-Length']))
json_data = json.loads(data_string)
resp_code = 200
out_string = ""
json_data = await request.json()

if "number" in json_data:
number = float(json_data['number'])
else:
number = self.server.number
number = self.number
if "front" in json_data:
if json_data['front']:
number = -number

self.server.number += 1
self.number += 1
if "prompt" in json_data:
prompt = json_data["prompt"]
valid = validate_prompt(prompt)
extra_data = {}
if "extra_data" in json_data:
extra_data = json_data["extra_data"]
if valid[0]:
self.server.prompt_queue.put((number, id(prompt), prompt, extra_data))
self.prompt_queue.put((number, id(prompt), prompt, extra_data))
else:
resp_code = 400
out_string = valid[1]
print("invalid prompt:", valid[1])
elif self.path == "/queue":
data_string = self.rfile.read(int(self.headers['Content-Length']))
json_data = json.loads(data_string)

return web.Response(body=out_string, status=resp_code)

@routes.post("/queue")
async def post_queue(request):
json_data = await request.json()
if "clear" in json_data:
if json_data["clear"]:
self.server.prompt_queue.wipe_queue()
self.prompt_queue.wipe_queue()
if "delete" in json_data:
to_delete = json_data['delete']
for id_to_delete in to_delete:
delete_func = lambda a: a[1] == int(id_to_delete)
self.server.prompt_queue.delete_queue_item(delete_func)

self._set_headers(code=resp_code)
self.end_headers()
self.wfile.write(out_string.encode('utf8'))
return


def run(prompt_queue, address='', port=8188):
server_address = (address, port)
httpd = HTTPServer(server_address, PromptServer)
httpd.server_dir = os.path.join(os.path.dirname(os.path.realpath(__file__)), "webshit")
httpd.prompt_queue = prompt_queue
httpd.number = 0
if server_address[0] == '':
addr = '0.0.0.0'
else:
addr = server_address[0]
self.prompt_queue.delete_queue_item(delete_func)

return web.Response(status=200)

self.app.add_routes(routes)
self.app.add_routes([
web.static('/', self.web_root),
])

async def start_server(server, address, port):
runner = web.AppRunner(server.app)
await runner.setup()
site = web.TCPSite(runner, address, port)
await site.start()

if address == '':
address = '0.0.0.0'
print("Starting server\n")
print("To see the GUI go to: http://{}:{}".format(addr, server_address[1]))
httpd.serve_forever()
print("To see the GUI go to: http://{}:{}".format(address, port))

async def run(prompt_queue, socket_handler, address='', port=8188):
server = PromptServer(prompt_queue, socket_handler)
await asyncio.gather(start_server(server, address, port), socket_handler.publish_loop())

if __name__ == "__main__":
q = PromptQueue()
loop = asyncio.new_event_loop()
asyncio.set_event_loop(loop)

socket_handler = SocketHandler(loop)
q = PromptQueue(socket_handler)
threading.Thread(target=prompt_worker, daemon=True, args=(q,)).start()
if '--listen' in sys.argv:
address = '0.0.0.0'
Expand All @@ -471,6 +532,5 @@ def run(prompt_queue, address='', port=8188):
except:
pass

run(q, address=address, port=port)

loop.run_until_complete(run(q, socket_handler, address=address, port=port))

2 changes: 1 addition & 1 deletion requirements.txt
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,6 @@ open-clip-torch
transformers
safetensors
pytorch_lightning

aiohttp
accelerate

Loading