From cd77f402641e783d84a5c5fb9486ddb364bc5879 Mon Sep 17 00:00:00 2001 From: North101 Date: Sun, 30 Jun 2024 12:52:51 +0100 Subject: [PATCH] Use phew as a http server. Make code async --- access-front-door/src/lib/phew/__init__.py | 77 ++++ access-front-door/src/lib/phew/dns.py | 32 ++ access-front-door/src/lib/phew/logging.py | 111 +++++ access-front-door/src/lib/phew/server.py | 506 +++++++++++++++++++++ access-front-door/src/lib/phew/template.py | 65 +++ access-front-door/src/main.py | 155 ++++--- 6 files changed, 878 insertions(+), 68 deletions(-) create mode 100644 access-front-door/src/lib/phew/__init__.py create mode 100644 access-front-door/src/lib/phew/dns.py create mode 100644 access-front-door/src/lib/phew/logging.py create mode 100644 access-front-door/src/lib/phew/server.py create mode 100644 access-front-door/src/lib/phew/template.py diff --git a/access-front-door/src/lib/phew/__init__.py b/access-front-door/src/lib/phew/__init__.py new file mode 100644 index 0000000..4131b0f --- /dev/null +++ b/access-front-door/src/lib/phew/__init__.py @@ -0,0 +1,77 @@ +__version__ = "0.0.2" + +# highly recommended to set a lowish garbage collection threshold +# to minimise memory fragmentation as we sometimes want to +# allocate relatively large blocks of ram. +import gc, os, machine +gc.threshold(50000) + +# phew! the Pico (or Python) HTTP Endpoint Wrangler +from . import logging + +# determine if remotely mounted or not, changes some behaviours like +# logging truncation +remote_mount = False +try: + os.statvfs(".") # causes exception if remotely mounted (mpremote/pyboard.py) +except: + remote_mount = True + +def get_ip_address(): + import network + try: + return network.WLAN(network.STA_IF).ifconfig()[0] + except: + return None + +def is_connected_to_wifi(): + import network, time + wlan = network.WLAN(network.STA_IF) + return wlan.isconnected() + +# helper method to quickly get connected to wifi +def connect_to_wifi(ssid, password, timeout_seconds=30): + import network, time + + statuses = { + network.STAT_IDLE: "idle", + network.STAT_CONNECTING: "connecting", + network.STAT_WRONG_PASSWORD: "wrong password", + network.STAT_NO_AP_FOUND: "access point not found", + network.STAT_CONNECT_FAIL: "connection failed", + network.STAT_GOT_IP: "got ip address" + } + + wlan = network.WLAN(network.STA_IF) + wlan.active(True) + wlan.connect(ssid, password) + start = time.ticks_ms() + status = wlan.status() + + logging.debug(f" - {statuses[status]}") + while not wlan.isconnected() and (time.ticks_ms() - start) < (timeout_seconds * 1000): + new_status = wlan.status() + if status != new_status: + logging.debug(f" - {statuses[status]}") + status = new_status + time.sleep(0.25) + + if wlan.status() == network.STAT_GOT_IP: + return wlan.ifconfig()[0] + return None + + +# helper method to put the pico into access point mode +def access_point(ssid, password = None): + import network + + # start up network in access point mode + wlan = network.WLAN(network.AP_IF) + wlan.config(essid=ssid) + if password: + wlan.config(password=password) + else: + wlan.config(security=0) # disable password + wlan.active(True) + + return wlan diff --git a/access-front-door/src/lib/phew/dns.py b/access-front-door/src/lib/phew/dns.py new file mode 100644 index 0000000..2a1a4b3 --- /dev/null +++ b/access-front-door/src/lib/phew/dns.py @@ -0,0 +1,32 @@ +import uasyncio, usocket +from . import logging + +async def _handler(socket, ip_address): + while True: + try: + yield uasyncio.core._io_queue.queue_read(socket) + request, client = socket.recvfrom(256) + response = request[:2] # request id + response += b"\x81\x80" # response flags + response += request[4:6] + request[4:6] # qd/an count + response += b"\x00\x00\x00\x00" # ns/ar count + response += request[12:] # origional request body + response += b"\xC0\x0C" # pointer to domain name at byte 12 + response += b"\x00\x01\x00\x01" # type and class (A record / IN class) + response += b"\x00\x00\x00\x3C" # time to live 60 seconds + response += b"\x00\x04" # response length (4 bytes = 1 ipv4 address) + response += bytes(map(int, ip_address.split("."))) # ip address parts + socket.sendto(response, client) + except Exception as e: + logging.error(e) + +def run_catchall(ip_address, port=53): + logging.info("> starting catch all dns server on port {}".format(port)) + + _socket = usocket.socket(usocket.AF_INET, usocket.SOCK_DGRAM) + _socket.setblocking(False) + _socket.setsockopt(usocket.SOL_SOCKET, usocket.SO_REUSEADDR, 1) + _socket.bind(usocket.getaddrinfo(ip_address, port, 0, usocket.SOCK_DGRAM)[0][-1]) + + loop = uasyncio.get_event_loop() + loop.create_task(_handler(_socket, ip_address)) \ No newline at end of file diff --git a/access-front-door/src/lib/phew/logging.py b/access-front-door/src/lib/phew/logging.py new file mode 100644 index 0000000..dd70394 --- /dev/null +++ b/access-front-door/src/lib/phew/logging.py @@ -0,0 +1,111 @@ +import machine, os, gc + +log_file = "log.txt" + +LOG_INFO = 0b00001 +LOG_WARNING = 0b00010 +LOG_ERROR = 0b00100 +LOG_DEBUG = 0b01000 +LOG_EXCEPTION = 0b10000 +LOG_ALL = LOG_INFO | LOG_WARNING | LOG_ERROR | LOG_DEBUG | LOG_EXCEPTION + +_logging_types = LOG_ALL + +# the log file will be truncated if it exceeds _log_truncate_at bytes in +# size. the defaults values are designed to limit the log to at most +# three blocks on the Pico +_log_truncate_at = 11 * 1024 +_log_truncate_to = 8 * 1024 + +def datetime_string(): + dt = machine.RTC().datetime() + return "{0:04d}-{1:02d}-{2:02d} {4:02d}:{5:02d}:{6:02d}".format(*dt) + +def file_size(file): + try: + return os.stat(file)[6] + except OSError: + return None + +def set_truncate_thresholds(truncate_at, truncate_to): + global _log_truncate_at + global _log_truncate_to + _log_truncate_at = truncate_at + _log_truncate_to = truncate_to + +def enable_logging_types(types): + global _logging_types + _logging_types = _logging_types | types + +def disable_logging_types(types): + global _logging_types + _logging_types = _logging_types & ~types + +# truncates the log file down to a target size while maintaining +# clean line breaks +def truncate(file, target_size): + # get the current size of the log file + size = file_size(file) + + # calculate how many bytes we're aiming to discard + discard = size - target_size + if discard <= 0: + return + + with open(file, "rb") as infile: + with open(file + ".tmp", "wb") as outfile: + # skip a bunch of the input file until we've discarded + # at least enough + while discard > 0: + chunk = infile.read(1024) + discard -= len(chunk) + + # try to find a line break nearby to split first chunk on + break_position = max( + chunk.find (b"\n", -discard), # search forward + chunk.rfind(b"\n", -discard) # search backwards + ) + if break_position != -1: # if we found a line break.. + outfile.write(chunk[break_position + 1:]) + + # now copy the rest of the file + while True: + chunk = infile.read(1024) + if not chunk: + break + outfile.write(chunk) + + # delete the old file and replace with the new + os.remove(file) + os.rename(file + ".tmp", file) + + +def log(level, text): + datetime = datetime_string() + log_entry = "{0} [{1:8} /{2:>4}kB] {3}".format(datetime, level, round(gc.mem_free() / 1024), text) + print(log_entry) + with open(log_file, "a") as logfile: + logfile.write(log_entry + '\n') + + if _log_truncate_at and file_size(log_file) > _log_truncate_at: + truncate(log_file, _log_truncate_to) + +def info(*items): + if _logging_types & LOG_INFO: + log("info", " ".join(map(str, items))) + +def warn(*items): + if _logging_types & LOG_WARNING: + log("warning", " ".join(map(str, items))) + +def error(*items): + if _logging_types & LOG_ERROR: + log("error", " ".join(map(str, items))) + +def debug(*items): + if _logging_types & LOG_DEBUG: + log("debug", " ".join(map(str, items))) + +def exception(*items): + if _logging_types & LOG_EXCEPTION: + log("exception", " ".join(map(str, items))) \ No newline at end of file diff --git a/access-front-door/src/lib/phew/server.py b/access-front-door/src/lib/phew/server.py new file mode 100644 index 0000000..4c445be --- /dev/null +++ b/access-front-door/src/lib/phew/server.py @@ -0,0 +1,506 @@ +import binascii +import gc +import random + +import uasyncio, os, time +from . import logging + + +def file_exists(filename): + try: + return (os.stat(filename)[0] & 0x4000) == 0 + except OSError: + return False + + +def urldecode(text): + text = text.replace("+", " ") + result = "" + token_caret = 0 + # decode any % encoded characters + while True: + start = text.find("%", token_caret) + if start == -1: + result += text[token_caret:] + break + result += text[token_caret:start] + code = int(text[start + 1:start + 3], 16) + result += chr(code) + token_caret = start + 3 + return result + +def _parse_query_string(query_string): + result = {} + for parameter in query_string.split("&"): + key, value = parameter.split("=", 1) + key = urldecode(key) + value = urldecode(value) + result[key] = value + return result + + +class Request: + def __init__(self, method, uri, protocol): + self.method = method + self.uri = uri + self.protocol = protocol + self.form = {} + self.data = {} + self.query = {} + query_string_start = uri.find("?") if uri.find("?") != -1 else len(uri) + self.path = uri[:query_string_start] + self.query_string = uri[query_string_start + 1:] + if self.query_string: + self.query = _parse_query_string(self.query_string) + + def __str__(self): + return f"""\ +request: {self.method} {self.path} {self.protocol} +headers: {self.headers} +form: {self.form} +data: {self.data}""" + + +class Response: + def __init__(self, body, status=200, headers={}): + self.status = status + self.headers = headers + self.body = body + + def add_header(self, name, value): + self.headers[name] = value + + def __str__(self): + return f"""\ +status: {self.status} +headers: {self.headers} +body: {self.body}""" + + +content_type_map = { + "html": "text/html", + "jpg": "image/jpeg", + "jpeg": "image/jpeg", + "svg": "image/svg+xml", + "json": "application/json", + "png": "image/png", + "css": "text/css", + "js": "text/javascript", + "csv": "text/csv", + "txt": "text/plain", + "bin": "application/octet-stream", + "xml": "application/xml", + "gif": "image/gif", +} + + +class FileResponse(Response): + def __init__(self, file, status=200, headers={}): + self.status = 404 + self.headers = headers + self.file = file + + try: + if (os.stat(self.file)[0] & 0x4000) == 0: + self.status = 200 + + # auto set content type + extension = self.file.split(".")[-1].lower() + if extension in content_type_map: + headers["Content-Type"] = content_type_map[extension] + + headers["Content-Length"] = os.stat(self.file)[6] + except OSError: + return False + + +class Route: + def __init__(self, path, handler, methods=["GET"]): + self.path = path + self.methods = methods + self.handler = handler + self.path_parts = path.split("/") + + # returns True if the supplied request matches this route + def matches(self, request): + if request.method not in self.methods: + return False + compare_parts = request.path.split("/") + if len(compare_parts) != len(self.path_parts): + return False + for part, compare in zip(self.path_parts, compare_parts): + if not part.startswith("<") and part != compare: + return False + return True + + # call the route handler passing any named parameters in the path + def call_handler(self, request): + parameters = {} + for part, compare in zip(self.path_parts, request.path.split("/")): + if part.startswith("<"): + name = part[1:-1] + parameters[name] = compare + + return self.handler(request, **parameters) + + def __str__(self): + return f"""\ +path: {self.path} +methods: {self.methods} +""" + + def __repr__(self): + return f"" + + +# parses the headers for a http request (or the headers attached to +# each field in a multipart/form-data) +async def _parse_headers(reader): + headers = {} + while True: + header_line = await reader.readline() + if header_line == b"\r\n": # crlf denotes body start + break + name, value = header_line.decode().strip().split(": ", 1) + headers[name.lower()] = value + return headers + + + +# if the content type is multipart/form-data then parse the fields +async def _parse_form_data(reader, headers): + boundary = headers["content-type"].split("boundary=")[1] + # discard first boundary line + dummy = await reader.readline() + + form = {} + while True: + # get the field name + field_headers = await _parse_headers(reader) + if len(field_headers) == 0: + break + name = field_headers["content-disposition"].split("name=\"")[1][:-1] + # get the field value + value = "" + while True: + line = await reader.readline() + line = line.decode().strip() + # if we hit a boundary then save the value and move to next field + if line == "--" + boundary: + form[name] = value + break + # if we hit end of form data boundary then save value and return + if line == "--" + boundary + "--": + form[name] = value + return form + value += line + return None + + +# if the content type is application/json then parse the body +async def _parse_json_body(reader, headers): + import json + content_length_bytes = int(headers["content-length"]) + body = await reader.readexactly(content_length_bytes) + return json.loads(body.decode()) + + +status_message_map = { + 200: "OK", 201: "Created", 202: "Accepted", + 203: "Non-Authoritative Information", 204: "No Content", + 205: "Reset Content", 206: "Partial Content", 300: "Multiple Choices", + 301: "Moved Permanently", 302: "Found", 303: "See Other", + 304: "Not Modified", 305: "Use Proxy", 306: "Switch Proxy", + 307: "Temporary Redirect", 308: "Permanent Redirect", + 400: "Bad Request", 401: "Unauthorized", 403: "Forbidden", + 404: "Not Found", 405: "Method Not Allowed", 406: "Not Acceptable", + 408: "Request Timeout", 409: "Conflict", 410: "Gone", + 414: "URI Too Long", 415: "Unsupported Media Type", + 416: "Range Not Satisfiable", 418: "I'm a teapot", + 500: "Internal Server Error", 501: "Not Implemented" +} + +class Session: + + ''' + Session class used to store all the attributes of a session. + ''' + + def __init__(self, max_age=86400): + # create a 128 bit session id encoded in hex + n = [] + for i in range(4): + n.append(random.getrandbits(32).to_bytes(4,'big')) + self.session_id = binascii.hexlify(bytearray().join(n)).decode() + self.expires = time.time() + max_age + self.max_age = max_age + + def expired(self): + return self.expires < time.time() + + +class Phew: + + def __init__(self): + self._routes = [] + self._login_required = set() + self.catchall_handler = None + self._login_catchall = None + self.loop = uasyncio.get_event_loop() + self.sessions = [] + + # handle an incoming request to the web server + async def _handle_request(self, reader, writer): + + # Do a GC collect before handling the request + gc.collect() + + response = None + + request_start_time = time.ticks_ms() + + request_line = await reader.readline() + try: + method, uri, protocol = request_line.decode().split() + except Exception as e: + logging.error(e) + return + + request = Request(method, uri, protocol) + request.headers = await _parse_headers(reader) + if "content-length" in request.headers and "content-type" in request.headers: + if request.headers["content-type"].startswith("multipart/form-data"): + request.form = await _parse_form_data(reader, request.headers) + if request.headers["content-type"].startswith("application/json"): + request.data = await _parse_json_body(reader, request.headers) + if request.headers["content-type"].startswith("application/x-www-form-urlencoded"): + form_data = b"" + content_length = int(request.headers["content-length"]) + while content_length > 0: + data = await reader.read(content_length) + if len(data) == 0: + break + content_length -= len(data) + form_data += data + request.form = _parse_query_string(form_data.decode()) + + route = self._match_route(request) + if route and self._login_catchall and self.is_login_required(route.handler) and not self.active_session(request): + response = self._login_catchall(request) + elif route: + response = route.call_handler(request) + elif self.catchall_handler: + if self.is_login_required(self.catchall_handler) and not self.active_session(request): + # handle the case that the catchall handler is annotated with @login_required() + response = self._login_catchall(request) + else: + response = self.catchall_handler(request) + + # if shorthand body generator only notation used then convert to tuple + if type(response).__name__ == "generator": + response = (response,) + + # if shorthand body text only notation used then convert to tuple + if isinstance(response, str): + response = (response,) + + # if shorthand tuple notation used then build full response object + if isinstance(response, tuple): + body = response[0] + status = response[1] if len(response) >= 2 else 200 + content_type = response[2] if len(response) >= 3 else "text/html" + response = Response(body, status=status) + response.add_header("Content-Type", content_type) + if hasattr(body, '__len__'): + response.add_header("Content-Length", len(body)) + + # write status line + status_message = status_message_map.get(response.status, "Unknown") + writer.write(f"HTTP/1.1 {response.status} {status_message}\r\n".encode("ascii")) + + # write headers + for key, value in response.headers.items(): + writer.write(f"{key}: {value}\r\n".encode("ascii")) + + # blank line to denote end of headers + writer.write("\r\n".encode("ascii")) + + if isinstance(response, FileResponse): + # file + with open(response.file, "rb") as f: + while True: + chunk = f.read(1024) + if not chunk: + break + writer.write(chunk) + await writer.drain() + elif type(response.body).__name__ == "generator": + # generator + for chunk in response.body: + writer.write(chunk) + await writer.drain() + else: + # string/bytes + writer.write(response.body) + await writer.drain() + + writer.close() + await writer.wait_closed() + + processing_time = time.ticks_ms() - request_start_time + logging.info(f"> {request.method} {request.path} ({response.status} {status_message}) [{processing_time}ms]") + + + # adds a new route to the routing table + def add_route(self, path, handler, methods=["GET"]): + self._routes.append(Route(path, handler, methods)) + # descending complexity order so most complex routes matched first + self._routes = sorted(self._routes, key=lambda route: len(route.path_parts), reverse=True) + + + def set_callback(self, handler): + self.catchall_handler = handler + + + # decorator shorthand for adding a route + def route(self, path, methods=["GET"]): + def _route(f): + self.add_route(path, f, methods=methods) + return f + return _route + + + # add the handler to the _login_required list + def add_login_required(self, handler): + self._login_required.add(handler) + + + def is_login_required(self, handler): + return handler in self._login_required + + + # decorator indicating that authentication is required for a handler + def login_required(self): + def _login_required(f): + self.add_login_required(f) + return f + return _login_required + + + def set_login_catchall(self, handler): + self._login_catchall = handler + + + # decorator for adding login_handler route + def login_catchall(self): + def _login_catchall(f): + self.set_login_catchall(f) + return f + return _login_catchall + + + # decorator for adding catchall route + def catchall(self): + def _catchall(f): + self.set_callback(f) + return f + return _catchall + + def redirect(self, url, status = 301): + return Response("", status, {"Location": url}) + + def serve_file(self, file): + return FileResponse(file) + + # returns the route matching the supplied path or None + def _match_route(self, request): + for route in self._routes: + if route.matches(request): + return route + return None + + def run_as_task(self, loop, host = "0.0.0.0", port = 80, ssl=None): + loop.create_task(uasyncio.start_server(self._handle_request, host, port, ssl=ssl)) + + def run(self, host = "0.0.0.0", port = 80, ssl=None): + logging.info("> starting web server on port {}".format(port)) + self.loop.create_task(uasyncio.start_server(self._handle_request, host, port, ssl=ssl)) + self.loop.run_forever() + + def stop(self): + self.loop.stop() + + def close(self): + self.loop.close() + + def create_session(self, max_age=86400): + session = Session(max_age=max_age) + self.sessions.append(session) + return session + + def get_session(self, request): + session = None + name = None + value = None + if "cookie" in request.headers: + cookie = request.headers["cookie"] + if cookie: + name, value = cookie.split("=") + if name == "sessionid": + # find session + for s in self.sessions: + if s.session_id == value: + session = s + return session + def remove_session(self, request): + session = self.get_session(request) + if session is not None: + self.sessions.remove(session) + + def active_session(self, request): + session = self.get_session(request) + return session is not None and not session.expired() + +# Compatibility methods +default_phew_app = None + + +def default_phew(): + global default_phew_app + if not default_phew_app: + default_phew_app = Phew() + return default_phew_app + + +def set_callback(handler): + default_phew().set_callback(handler) + + +# decorator shorthand for adding a route +def route(path, methods=["GET"]): + return default_phew().route(path, methods) + + +# decorator for adding catchall route +def catchall(): + return default_phew().catchall() + + +def redirect(url, status=301): + return default_phew().redirect(url, status) + + +def serve_file(file): + return default_phew().serve_file(file) + + +def run(host="0.0.0.0", port=80): + default_phew().run(host, port) + + +def stop(): + default_phew().stop() + + +def close(): + default_phew().close() diff --git a/access-front-door/src/lib/phew/template.py b/access-front-door/src/lib/phew/template.py new file mode 100644 index 0000000..d1d1df4 --- /dev/null +++ b/access-front-door/src/lib/phew/template.py @@ -0,0 +1,65 @@ +from . import logging + +async def render_template(template, **kwargs): + import time + start_time = time.ticks_ms() + + with open(template, "rb") as f: + # read the whole template file, we could work on single lines but + # the performance is much worse - so long as our templates are + # just a handful of kB it's ok to do this + data = f.read() + token_caret = 0 + + while True: + # find the next tag that needs evaluating + start = data.find(b"{{", token_caret) + end = data.find(b"}}", start) + + match = start != -1 and end != -1 + + # no more magic to handle, just return what's left + if not match: + yield data[token_caret:] + break + + expression = data[start + 2:end].strip() + + # output the bit before the tag + yield data[token_caret:start] + + # merge locals with the supplied named arguments and + # the response object + params = {} + params.update(locals()) + params.update(kwargs) + #params["response"] = response + + # parse the expression + try: + if expression.decode("utf-8") in params: + result = params[expression.decode("utf-8")] + result = result.replace("&", "&") + result = result.replace('"', """) + result = result.replace("'", "'") + result = result.replace(">", ">") + result = result.replace("<", "<") + else: + result = eval(expression, globals(), params) + + if type(result).__name__ == "generator": + # if expression returned a generator then iterate it fully + # and yield each result + for chunk in result: + yield chunk + else: + # yield the result of the expression + if result is not None: + yield str(result) + except: + pass + + # discard the parsed bit + token_caret = end + 2 + + logging.debug("> parsed template:", template, "(took", time.ticks_ms() - start_time, "ms)") \ No newline at end of file diff --git a/access-front-door/src/main.py b/access-front-door/src/main.py index 926cf1a..edd1e58 100644 --- a/access-front-door/src/main.py +++ b/access-front-door/src/main.py @@ -1,100 +1,119 @@ -import socket -import time - import network +import uasyncio as asyncio +import utime as time from machine import PWM, Pin +from phew import server +from phew.server import Request from . import env -# Variables -RESPONSE = "HTTP/1.1 200 OK\r\nContent-Length: 0\r\nConnection: close\r\n\r\n" - - -class Door(): +class Wifi(): def __init__(self): - self.pwm = PWM(Pin(23)) # Set up pin D23 to output - self.led = Pin(2, Pin.OUT) # Pin 2 is the built-in LED - self.setup_wifi() - self.setup_socket() - - def setup_wifi(self): self.wifi = network.WLAN(network.STA_IF) - self.wifi.active(True) time.sleep_us(100) self.wifi.config(dhcp_hostname=env.HOSTNAME) - def connect_wifi(self): + async def connect(self, timeout_ms=60*1000): + if self.isconnected(): + return True + # Connect to WiFi + self.wifi.active(True) self.wifi.connect(env.WIFI_SSID, env.WIFI_PASSWORD) - while not self.wifi.isconnected(): - pass - - def setup_socket(self): - # Set up webserver - self.socket = socket.socket(socket.AF_INET, socket.SOCK_STREAM) - self.socket.bind(('', 8080)) - self.socket.listen(5) - - def get_parameters_from(self, request: bytes): - parameters: dict[str, str] = {} - request_str = request.decode('utf-8') - params_index = request_str.find('Content-Length:') - ampersand_split = request_str[params_index:].split("&") - for element in ampersand_split: - equal_split = element.split("=") - parameters[equal_split[0]] = equal_split[1] - - return parameters - - def read_socket(self): - conn: socket.socket = self.socket.accept()[0] - request = conn.recv(1024) - parameters = self.get_parameters_from(request) - conn.send(RESPONSE) - conn.close() - - return parameters + try: + await asyncio.wait_for_ms(self.wait_for_connected(), timeout_ms) + return True + except asyncio.TimeoutError: + self.wifi.disconnect() + self.wifi.active(False) + return False - def lock(self): - self.led.off() - self.pwm.duty(0) - - def unlock(self): - # Output signal on pin to unlock and light up the internal light - self.led.on() - self.pwm.duty(1023) + def ip(self): + return self.wifi.ifconfig()[0] - def run(self): + def isconnected(self): + return self.wifi.isconnected() + + async def wait_for_connected(self): + while not self.isconnected(): + await asyncio.sleep_ms(500) + + async def wait_for_disconnected(self): + while self.isconnected(): + await asyncio.sleep_ms(500) + + async def stay_connected(self): while True: - try: - self.update() - except Exception as e: - print(e) + await self.connect() + await asyncio.sleep_ms(500) - def update(self): - self.lock() - if not self.wifi.isconnected(): - self.connect_wifi() +class DoorServer(): + def __init__(self): + self.pwm = PWM(Pin(23)) # Set up pin D23 to output + self.led = Pin(2, Pin.OUT) # Pin 2 is the built-in LED + self.setup_server() - parameters = self.read_socket() + # Tracks locks / unlocks so that we only actually lock when + # the value is 0 + self.lock_queue = 0 + + def setup_server(self): + self.server = server.Phew() + self.server.add_route('/', self.index, methods=['POST']) + + async def index(self, request: Request): + params = request.data # Bail if the request didn't come from a known source - if parameters.get('psk') != env.SHARED_PASSWORD: + if params.get('psk') != env.SHARED_PASSWORD: return - duration = parameters.get("duration") + duration = params.get("duration") try: duration = int(duration) if duration is not None else env.DEFAULT_UNLOCK_DURATION except ValueError: - return + duration = env.DEFAULT_UNLOCK_DURATION unlock_duration = max(min(duration, 30), 1) + self.lock_queue += 1 self.unlock() - time.sleep(unlock_duration) + asyncio.create_task(self.schedule_lock(unlock_duration)) + return 'OK' + + async def schedule_lock(self, duration: int): + await asyncio.sleep(duration) + + self.lock_queue -= 1 + if self.lock_queue == 0: + self.lock() + + def lock(self): + self.led.off() + self.pwm.duty(0) + + def unlock(self): + # Output signal on pin to unlock and light up the internal light + self.led.on() + self.pwm.duty(1023) + + def run(self): + self.server.run() + + +async def main(): + door = DoorServer() + door.lock() + + wifi = Wifi() + while not await wifi.connect(): + print('Connecting...') + print('Connected:', wifi.ip()) -if __name__ == '__main__': - door = Door() door.run() + + +if __name__ == '__main__': + asyncio.run(main())