diff --git a/.flake8 b/.flake8 index 7da1f960..ad985391 100644 --- a/.flake8 +++ b/.flake8 @@ -1,2 +1,3 @@ [flake8] max-line-length = 100 +extend-ignore = E203 diff --git a/README.md b/README.md index f0489635..185388b0 100644 --- a/README.md +++ b/README.md @@ -5,3 +5,7 @@ ~~一个自产自销的仓库~~ Logging/Debugging/Tracing/Managing/Facilitating your deep learning projects A small part of the documentation at [neetbox.550w.host](https://neetbox.550w.host). (We are not ready for the doc yet) + +## Star History + +[![Star History Chart](https://api.star-history.com/svg?repos=visualDust/neetbox&type=Date)](https://star-history.com/#visualDust/neetbox&Date) diff --git a/neetbox/__init__.py b/neetbox/__init__.py index 43943025..554e74c6 100644 --- a/neetbox/__init__.py +++ b/neetbox/__init__.py @@ -8,7 +8,7 @@ from neetbox.daemon import _try_attach_daemon from neetbox.utils.framing import get_frame_module_traceback -module = get_frame_module_traceback(1).__name__ +module = get_frame_module_traceback(1).__name__ # type: ignore config_file_name = f"{module}.toml" @@ -18,6 +18,11 @@ def post_init(): project_name = get_module_level_config()["name"] setproctitle.setproctitle(project_name) + from neetbox.daemon.client._connection import connection + + # post init ws + connection._init_ws() + def init(path=None, load=False, **kwargs) -> bool: if path: @@ -60,16 +65,16 @@ def init(path=None, load=False, **kwargs) -> bool: from neetbox.logging.logger import Logger logger = Logger("NEETBOX") # builtin standalone logger - logger.ok(f"Loaded workspace config from {config_file_path}.") + logger.ok(f"found workspace config from {config_file_path}.") _try_attach_daemon() # try attach daemon - post_init() + logger.debug(f"running post init...") return True except Exception as e: from neetbox.logging.logger import Logger logger = Logger("NEETBOX") # builtin standalone logger - logger.err(f"Failed to load config from {config_file_path}: {e}") - return False + logger.err(f"failed to load config from {config_file_path}: {e}") + raise e is_in_daemon_process = ( @@ -80,3 +85,5 @@ def init(path=None, load=False, **kwargs) -> bool: success = init(load=True) # init from config file if not success: os._exit(255) + # run post init + post_init() diff --git a/neetbox/core/packing.py b/neetbox/core/packing.py new file mode 100644 index 00000000..e69de29b diff --git a/neetbox/daemon/__init__.py b/neetbox/daemon/__init__.py index 92cc379e..d67777f5 100644 --- a/neetbox/daemon/__init__.py +++ b/neetbox/daemon/__init__.py @@ -8,9 +8,8 @@ import subprocess import time -import neetbox -from neetbox.config import get_module_level_config from neetbox.daemon.client._action_agent import _NeetActionManager as NeetActionManager +from neetbox.daemon.client._connection import connection from neetbox.daemon.client._daemon_client import connect_daemon from neetbox.daemon.server.daemonable_process import DaemonableProcess from neetbox.logging import logger @@ -84,4 +83,5 @@ def _try_attach_daemon(): action = NeetActionManager.register -__all__ = ["watch", "listen", "action", "NeetActionManager", "_try_attach_daemon"] +ws_subscribe = connection.ws_subscribe +__all__ = ["watch", "listen", "action", "ws_subscribe", "NeetActionManager", "_try_attach_daemon"] diff --git a/neetbox/daemon/_agent.py b/neetbox/daemon/_agent.py new file mode 100644 index 00000000..09e6ec2f --- /dev/null +++ b/neetbox/daemon/_agent.py @@ -0,0 +1,71 @@ +import functools +import inspect +from ast import literal_eval +from threading import Thread +from typing import Callable, 
Optional + +from neetbox.core import Registry +from neetbox.logging import logger +from neetbox.utils.mvc import Singleton + + +class PackedAction(Callable): + def __init__(self, function: Callable, name=None, **kwargs): + super().__init__(**kwargs) + self.function = function + self.name = name if name else function.__name__ + self.argspec = inspect.getfullargspec(self.function) + + def __call__(self, **argv): + self.function(argv) + + def eval_call(self, params: dict): + eval_params = dict((k, literal_eval(v)) for k, v in params.items()) + return self.function(**eval_params) + + +class _NeetAction(metaclass=Singleton): + __ACTION_POOL: Registry = Registry("__NEET_ACTIONS") + + def register( + self, + *, + name: Optional[str] = None, + ): + return functools.partial(self._register, name=name) + + def _register(self, function: Callable, name: str = None): + packed = PackedAction(function=function, name=name) + _NeetAction.__ACTION_POOL._register(what=packed, name=packed.name, force=True) + return function + + def get_actions(self): + action_names = _NeetAction.__ACTION_POOL.keys() + actions = {} + for n in action_names: + actions[n] = _NeetAction.__ACTION_POOL[n].argspec + return actions + + def eval_call(self, name: str, params: dict): + if name not in _NeetAction.__ACTION_POOL: + logger.err(f"Could not find action with name {name}, action stopped.") + return False + return _NeetAction.__ACTION_POOL[name].eval_call(params) + + +# singleton +neet_action = _NeetAction() + + +# example +if __name__ == "__main__": + + @neet_action.register(name="some") + def some(a, b): + print(a, b) + + print("registered actions:") + print(neet_action.get_actions()) + + print("calling 'some") + neet_action.eval_call("some", {"a": "3", "b": "4"}) diff --git a/neetbox/daemon/client/_action_agent.py b/neetbox/daemon/client/_action_agent.py index 18ed5e86..25c21eb8 100644 --- a/neetbox/daemon/client/_action_agent.py +++ b/neetbox/daemon/client/_action_agent.py @@ -62,6 +62,7 @@ def run_and_callback(target_action, params, callback): Thread( target=run_and_callback, kwargs={"target_action": target_action, "params": params, "callback": callback}, + daemon=True, ).start() return None else: # blocking run diff --git a/neetbox/daemon/client/_client_apis.py b/neetbox/daemon/client/_client_apis.py index 2598ec60..efe1fa15 100644 --- a/neetbox/daemon/client/_client_apis.py +++ b/neetbox/daemon/client/_client_apis.py @@ -25,7 +25,7 @@ def get_status_of(name=None): name = name or "" - api_addr = f"{base_addr}/status" + api_addr = f"{base_addr}/web/list" logger.info(f"Fetching from {api_addr}") r = connection.http.get(api_addr) _data = r.json() diff --git a/neetbox/daemon/client/_connection.py b/neetbox/daemon/client/_connection.py index 65e54ec9..42c22211 100644 --- a/neetbox/daemon/client/_connection.py +++ b/neetbox/daemon/client/_connection.py @@ -1,12 +1,18 @@ import asyncio import functools +import json import logging -from typing import Callable, Optional +import time +from dataclasses import dataclass +from threading import Thread +from typing import Any, Callable, Optional import httpx +import websocket from neetbox.config import get_module_level_config from neetbox.core import Registry +from neetbox.daemon.server._server import CLIENT_API_ROOT from neetbox.logging import logger from neetbox.utils.mvc import Singleton @@ -14,19 +20,35 @@ httpx_logger.setLevel(logging.ERROR) EVENT_TYPE_NAME_KEY = "event-type" -EVENT_PAYLOAD_NAME_KEY = "payload" +EVENT_ID_NAME_KEY = "event-id" +NAME_NAME_KEY = "name" +PAYLOAD_NAME_KEY 
= "payload" + + +@dataclass +class WsMsg: + name: str + event_type: str + payload: Any + event_id: int = -1 + + def json(self): + return { + NAME_NAME_KEY: self.name, + EVENT_TYPE_NAME_KEY: self.event_type, + EVENT_ID_NAME_KEY: self.event_id, + PAYLOAD_NAME_KEY: self.payload, + } # singleton class ClientConn(metaclass=Singleton): http: httpx.Client = None - __ws_client: None # _websocket_client + __ws_client: websocket.WebSocketApp = None # _websocket_client __ws_subscription = Registry("__client_ws_subscription") # { event-type-name : list(Callable)} def __init__(self) -> None: - cfg = get_module_level_config() - def __load_http_client(): __local_http_client = httpx.Client( proxies={ @@ -38,9 +60,68 @@ def __load_http_client(): # create htrtp client ClientConn.http = __load_http_client() - # todo establishing socket connection - def __on_ws_message(msg): + def _init_ws(): + cfg = get_module_level_config() + _root_config = get_module_level_config("@") + ClientConn._display_name = cfg["displayName"] or _root_config["name"] + + # ws server url + ClientConn.ws_server_addr = f"ws://{cfg['host']}:{cfg['port'] + 1}{CLIENT_API_ROOT}" + + # create websocket app + logger.log(f"creating websocket connection to {ClientConn.ws_server_addr}") + + ws = websocket.WebSocketApp( + ClientConn.ws_server_addr, + on_open=ClientConn.__on_ws_open, + on_message=ClientConn.__on_ws_message, + on_error=ClientConn.__on_ws_err, + on_close=ClientConn.__on_ws_close, + ) + + Thread(target=ws.run_forever, kwargs={"reconnect": True}, daemon=True).start() + + # assign self to websocket log writer + from neetbox.logging._writer import _assign_connection_to_WebSocketLogWriter + + _assign_connection_to_WebSocketLogWriter(ClientConn) + + def __on_ws_open(ws: websocket.WebSocketApp): + _display_name = ClientConn._display_name + logger.ok(f"client websocket connected. sending handshake as '{_display_name}'...") + ws.send( # send handshake request + json.dumps( + { + NAME_NAME_KEY: {_display_name}, + EVENT_TYPE_NAME_KEY: "handshake", + PAYLOAD_NAME_KEY: {"who": "cli"}, + EVENT_ID_NAME_KEY: 0, # todo how does ack work + }, + default=str, + ) + ) + logger.ok(f"handshake succeed.") + ClientConn.__ws_client = ws + + def __on_ws_err(ws: websocket.WebSocketApp, msg): + logger.err(f"client websocket encountered {msg}") + + def __on_ws_close(ws: websocket.WebSocketApp, close_status_code, close_msg): + logger.warn(f"client websocket closed") + if close_status_code or close_msg: + logger.warn(f"ws close status code: {close_status_code}") + logger.warn("ws close message: {close_msg}") + ClientConn.__ws_client = None + + def __on_ws_message(ws: websocket.WebSocketApp, msg): + """EXAMPLE JSON + { + "event-type": "action", + "event-id": 111 (optional?) + "payload": ... + } + """ logger.debug(f"ws received {msg}") # message should be json event_type_name = msg[EVENT_TYPE_NAME_KEY] @@ -50,17 +131,28 @@ def __on_ws_message(msg): ) for subscriber in ClientConn._ws_subscribe[event_type_name]: try: - subscriber(msg[EVENT_PAYLOAD_NAME_KEY]) # pass payload message into subscriber + subscriber(msg) # pass payload message into subscriber except Exception as e: # subscriber throws error logger.err( f"Subscriber {subscriber} crashed on message event {event_type_name}, ignoring." ) - def ws_send(msg): - logger.debug(f"ws sending {msg}") - # send to ws if ws is connected, otherwise drop message? 
idk
-        pass
+    def ws_send(event_type: str, payload):
+        logger.debug(f"ws sending {payload}")
+        if ClientConn.__ws_client:  # if ws client exist
+            ClientConn.__ws_client.send(
+                json.dumps(
+                    {
+                        NAME_NAME_KEY: ClientConn._display_name,
+                        EVENT_TYPE_NAME_KEY: event_type,
+                        PAYLOAD_NAME_KEY: payload,
+                        EVENT_ID_NAME_KEY: -1,  # todo how does ack work
+                    }
+                )
+            )
+        else:
+            logger.debug("ws client does not exist, message dropped.")

     def ws_subscribe(event_type_name: str):
         """let a function subscribe to ws messages with event type name.
@@ -81,5 +173,5 @@ def _ws_subscribe(function: Callable, event_type_name: str):


 # singleton
-ClientConn()  # run init
+ClientConn()  # __init__ sets up the http client only
 connection = ClientConn
diff --git a/neetbox/daemon/client/_daemon_client.py b/neetbox/daemon/client/_daemon_client.py
index 6cbf9ef6..6ffa701e 100644
--- a/neetbox/daemon/client/_daemon_client.py
+++ b/neetbox/daemon/client/_daemon_client.py
@@ -23,7 +23,7 @@ def _upload_thread(daemon_config, base_addr, display_name):
     _ctr = 0
     _api_name = "sync"
-    _api_addr = f"{base_addr}/{CLIENT_API_ROOT}/{_api_name}/{display_name}"
+    _api_addr = f"{base_addr}{CLIENT_API_ROOT}/{_api_name}/{display_name}"
     _disconnect_flag = False
     _disconnect_retries = 10
     while True:
@@ -91,9 +91,7 @@ def _check_daemon_alive(_api_addr):
     global __upload_thread
     if __upload_thread is None or not __upload_thread.is_alive():
         __upload_thread = Thread(
-            target=_upload_thread,
-            daemon=True,
-            args=[cfg, _base_addr, _display_name],
+            target=_upload_thread, args=[cfg, _base_addr, _display_name], daemon=True
         )
         __upload_thread.start()
diff --git a/neetbox/daemon/readme.md b/neetbox/daemon/readme.md
new file mode 100644
index 00000000..d977b70e
--- /dev/null
+++ b/neetbox/daemon/readme.md
@@ -0,0 +1,126 @@
+# DAEMON readme
+
+## How to run server only
+
+At the neetbox project root:
+```bash
+python neetbox/daemon/server/_server.py
+```
+
+## WS message standard
+
+WebSocket messages are described in JSON. There is a dataclass representing a websocket message (the `name` field carries the project name, matching the code in `_connection.py` and `_server.py`):
+
+```python
+@dataclass
+class WsMsg:
+    name: str
+    event_type: str
+    payload: Any
+    event_id: int = -1
+
+    def json(self):
+        return {
+            NAME_NAME_KEY: self.name,
+            EVENT_TYPE_NAME_KEY: self.event_type,
+            EVENT_ID_NAME_KEY: self.event_id,
+            PAYLOAD_NAME_KEY: self.payload,
+        }
+```
+
+```json
+{
+    "name" : ...,
+    "event-type" : ...,
+    "payload" : ...,
+    "event-id" : ...
+}
+```
+
+|    key     | value type |                          description                           |
+| :--------: | :--------: | :------------------------------------------------------------: |
+|    name    |   string   |                 project name of the client                      |
+| event-type |   string   |          indicates the type of data in the payload              |
+|  payload   |   string   |                          actual data                            |
+|  event-id  |    int     | for events that need an ack; the default -1 means no event id   |
+
+## Event types
+
+This table is still growing; check it frequently to stay up to date.
+
+| event-type |      accepting direction      |                             meaning                              |
+| :--------: | :---------------------------: | :--------------------------------------------------------------: |
+| handshake  | cli <--> server <--> frontend | string in `payload` indicates the connection type ('cli'/'web')   |
+|    log     |   cli -> server -> frontend   |                    `payload` contains log data                    |
+|   action   |   cli <- server <- frontend   |                `payload` contains an action trigger               |
+|    ack     | cli <--> server <--> frontend | `payload` contains the ack, and `event-id` should be a valid key  |
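+
+On the client side, `neetbox.daemon` re-exports `ws_subscribe` (from `connection.ws_subscribe`), which registers a handler for one of the event types above. The snippet below is a minimal sketch, assuming `ws_subscribe(event_type_name=...)` is used as a decorator factory; the handler name and body are illustrative only, not part of the codebase:
+
+```python
+from neetbox.daemon import ws_subscribe
+
+
+# illustrative handler: invoked with the received websocket message whenever
+# an "action" event arrives on the client connection
+@ws_subscribe(event_type_name="action")
+def on_action_message(msg):
+    print("got action message:", msg)
+```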
+
+## Examples of websocket data
+
+### handshake
+
+For instance, when a frontend connects to the server, it should report its connection type immediately by sending:
+
+```json
+{
+    "event-type": "handshake",
+    "name": "project name",
+    "payload": {
+        "who": "web"
+    },
+    "event-id": X
+}
+```
+
+Here `event-id` is used to send an ack back to the initiator of the connection; it should be a random int value.
+
+### cli sending logs to frontend
+
+The cli sends logs via websocket; the server receives them and broadcasts the message to the related frontends. The cli should send:
+
+```json
+{
+    "event-type": "log",
+    "name": "project name",
+    "payload": {
+        "log" : {...json representing log data...}
+    },
+    "event-id": -1
+}
+```
+
+Here `event-id` is unused, so leave it at the default; it is okay if nobody receives the log.
+
+### frontend(s) query an action on cli
+
+The frontend sends an action request to the server, and the server forwards the message to the cli. The frontend should send:
+
+```json
+{
+    "event-type" : "action",
+    "name": "project name",
+    "payload" : {
+        "action" : {...json representing action trigger...}
+    },
+    "event-id" : x
+}
+```
+
+The frontend may want to know the result of the action, for example whether it was invoked successfully, so `event-id` is necessary for the cli to shape an ack response.
+
+### cli acks frontend action query
+
+The cli executes action queries from the frontend and responds by sending an ack:
+
+```json
+{
+    "event-type" : "ack",
+    "name": "project name",
+    "payload" : {
+        "action" : {...json representing action result...}
+    },
+    "event-id" : x
+}
+```
+
+Here `event-id` is the same as in the received action query.
+
+---
+
+These are only examples; use them wisely.
diff --git a/neetbox/daemon/server/_fastapi_server.py b/neetbox/daemon/server/_fastapi_server.py
new file mode 100644
index 00000000..45081031
--- /dev/null
+++ b/neetbox/daemon/server/_fastapi_server.py
@@ -0,0 +1,155 @@
+from typing import Any, Dict, List
+
+import uvicorn
+from fastapi import FastAPI, HTTPException, WebSocket
+from pydantic import BaseModel
+from starlette.endpoints import WebSocketEndpoint
+
+FRONTEND_API_ROOT = "/web"
+CLIENT_API_ROOT = "/cli"
+
+app = FastAPI()
+
+
+# ===============================================================
+# Client functions (backend)
+# ===============================================================
+
+
+class ClientEndpoint(WebSocketEndpoint):
+    encoding = "json"
+    subscriptions: Dict[str, List[WebSocket]] = {}
+    socket_pool: Dict[str, WebSocket] = {}
+
+    def __init__(self, scope, receive, send, name: str):
+        super().__init__(scope, receive, send)
+        self.name = name
+
+    async def on_connect(self, websocket: WebSocket):
+        await websocket.accept()
+        ClientEndpoint.socket_pool[self.name] = websocket  # add to socket pool
+
+    async def on_receive(self, websocket: WebSocket, data: Any):
+        """
+             ┌────►Viewer
+             │
+             │
+        Client─────►Center──►Viewer
+             │
+             │
+             └────►Viewer
+        """
+        for ws in ClientEndpoint.subscriptions[self.name]:
+            await ws.send_json(data)
+
+    async def on_disconnect(self, websocket: WebSocket, close_code: int):
+        pass
+
+    @staticmethod
+    def subscribe(websocket: WebSocket, name: str):
+        if name not in ClientEndpoint.subscriptions.keys():
+            ClientEndpoint.subscriptions[name] = []
+        ClientEndpoint.subscriptions[name].append(websocket)
+
+
+@app.websocket(f"{CLIENT_API_ROOT}" + "/ws/")
+async def handle_client_websocket(websocket: WebSocket, name: str):
+    await ClientEndpoint(websocket.scope, websocket.receive, websocket.send, name)
+
+
+class Client(BaseModel):
+    status: dict = {}
+
+
+class ClientManager:
+    def __init__(self) -> None:
+        self._client_registry: Dict[str, Client] = {}
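+        # _client_registry maps a client name to its Client record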
+ def register(self, name: str, client: Client): + if name in self._client_registry.keys(): + raise ValueError(f"Client with name {name} already exists.") + self._client_registry[name] = client + + def get(self, name: str): + return self._client_registry[name] + + def get_all(self): + return self._client_registry + + +client_manager = ClientManager() + + +@app.get("/register/") +async def register_client(name: str): + try: + client_manager.register(name, Client()) + except ValueError: + raise HTTPException(status_code=400, detail="Client already exists.") + + +class Status(BaseModel): + status: dict = {} + + +@app.post(f"{CLIENT_API_ROOT}" + "/sync/") +async def sync_client(name: str, status: Status): + try: + client = client_manager.get(name) + except ValueError: + raise HTTPException(status_code=404, detail="Client not found.") + client.status = status.status + + +# =============================================================== +# Viewer functions (frontend) +# =============================================================== + + +class ViewerEndpoint(WebSocketEndpoint): + encoding = "json" + + def __init__(self, scope, receive, send, name: str): + super().__init__(scope, receive, send) + self.name = name + + async def on_connect(self, websocket: WebSocket): + await websocket.accept() + ClientEndpoint.subscribe(websocket, self.name) + + async def on_receive(self, websocket: WebSocket, data: Any): + """ + Viewer─────►Center─────►Client + """ + await ClientEndpoint.socket_pool[self.name].send_json(data) + + async def on_disconnect(self, websocket: WebSocket, close_code: int): + ClientEndpoint.subscriptions[self.name].remove(websocket) # remove from subscription + + +@app.websocket(f"{FRONTEND_API_ROOT}" + "/ws/") +async def handle_viewer_websocket(websocket: WebSocket, name: str): + await ViewerEndpoint(websocket.scope, websocket.receive, websocket.send, name) + + +class ClientList(BaseModel): + names: List[str] + + +@app.get(f"{FRONTEND_API_ROOT}" + "/list", response_model=ClientList) +async def return_names_of_status(): + return client_manager.get_all().keys() + + +@app.get(f"{FRONTEND_API_ROOT}" + "/status/", response_model=Status) +async def return_status_of(name: str): + try: + client = client_manager.get(name) + except ValueError: + raise HTTPException(status_code=404, detail="Client not found.") + return client.status + + +if __name__ == "__main__": + cfg = {"port": 5000, "host": ""} + uvicorn.run("_fastapi_server:app", host=cfg["host"], port=cfg["port"], reload=True) diff --git a/neetbox/daemon/server/_server.py b/neetbox/daemon/server/_server.py index 455ce905..49a8fedf 100644 --- a/neetbox/daemon/server/_server.py +++ b/neetbox/daemon/server/_server.py @@ -7,18 +7,13 @@ import os import sys import time +from dataclasses import dataclass from threading import Thread -from typing import Dict, Tuple +from typing import Any, Dict, Tuple import setproctitle -from flask import Flask, abort, json, jsonify, request -from flask_socketio import SocketIO -from flask_socketio import emit as ws_emit -from flask_socketio import send as ws_send - -from neetbox.config import get_module_level_config -from neetbox.core import Registry -from neetbox.logging import logger +from flask import Flask, abort, json, request +from websocket_server import WebsocketServer __DAEMON_SHUTDOWN_IF_NO_UPLOAD_TIMEOUT_SEC = 60 * 60 * 12 # 12 Hours __COUNT_DOWN = __DAEMON_SHUTDOWN_IF_NO_UPLOAD_TIMEOUT_SEC @@ -28,35 +23,42 @@ FRONTEND_API_ROOT = "/web" CLIENT_API_ROOT = "/cli" +EVENT_TYPE_NAME_KEY = "event-type" 
+EVENT_ID_NAME_KEY = "event-id" +PAYLOAD_NAME_KEY = "payload" +NAME_NAME_KEY = "name" + + +@dataclass +class WsMsg: + name: str + event_type: str + payload: Any + event_id: int = -1 + + def json(self): + return { + NAME_NAME_KEY: self.name, + EVENT_TYPE_NAME_KEY: self.event_type, + EVENT_ID_NAME_KEY: self.event_id, + PAYLOAD_NAME_KEY: self.payload, + } -def daemon_process(cfg=None, debug=False): - # getting config - cfg = cfg or get_module_level_config() +def daemon_process(cfg, debug=False): # describe a client - class Client: + class Bridge: connected: bool name: str status: dict = {} - cli_ws_sid = None # cli ws sid - web_ws_sids = ( + cli_ws = None # cli ws sid + web_ws_list = ( [] ) # frontend ws sids. client data should be able to be shown on multiple frontend def __init__(self, name) -> None: # initialize non-websocket things self.name = name - pass - - def _ws_post_init(self, websocket): # handle handshakes - # initialize websocket things - pass - - @staticmethod - def from_ws(websocket): - new_client_connection = Client() - new_client_connection.cli_ws_sid = websocket - new_client_connection._ws_post_init(websocket) def handle_ws_recv(self): pass @@ -72,9 +74,10 @@ def ws_send(self): app = APIFlask(__name__) else: app = Flask(__name__) - socketio = SocketIO(app, cors_allowed_origins="*") - __client_registry = Registry("__daemon_server") # manage connections - connected_clients: Dict(str, Tuple(str, str)) = {} # {sid:(name,type)} store connection only + # websocket server + ws_server = WebsocketServer(port=cfg["port"] + 1) + __BRIDGES = {} # manage connections + connected_clients: Dict(int, Tuple(str, str)) = {} # {cid:(name,type)} store connection only # ======================== WS SERVER =========================== @@ -125,79 +128,114 @@ def ws_send(self): - forward message to client """ - @socketio.on("connect") - def handle_ws_connect(): - name = request.args.get("name") - path = request.path - path2type = {f"{FRONTEND_API_ROOT}": "web", f"{CLIENT_API_ROOT}": "cli"} - if not name or path not in path2type: - # connection args not valid, drop connection - return - # TODO (visualdust) check conn type for error handling - conn_type = path2type(path) - if name not in __client_registry: # Client not found. 
create from websocket connection - # must be cli - client = Client(name=name) - __client_registry._register(what=client, name=name) # manage clients - connected_clients[request.sid] = ( - name, - conn_type, - ) # store connection sid for later disconnection handling - if conn_type == "cli": - # add to Client - if __client_registry[name].cli_ws_sid is not None: - # overwriting, show warning - logger.warn(f"cli conn with same name already exist, overwriting...") - __client_registry[name].cli_ws_sid = request.sid - if conn_type == "web": - # add to Client - __client_registry[name].web_ws_sids.append(request.sid) - logger.ok(f"Websocket ({conn_type}) connected for {name} via {path}") - - @socketio.on("disconnect") - def handle_ws_disconnect(): - name, conn_type = connected_clients[request.sid] - # remove sid from Client entity - if conn_type == "cli": # remove client sid from Client - __client_registry[name].cli_ws_sid = None - else: - __client_registry[name].web_ws_sids.remove(request.sid) - del connected_clients[request.sid] - logger.info(f"Websocket ({conn_type}) for {name} disconnected") - - @socketio.on("json") - def handle_ws_json_message(data): - name, conn_type = connected_clients[request.sid] # who - if conn_type == "cli": # json data ws_send by client - for target_sid in __client_registry[name].web_ws_sids: - ws_send(data, to=target_sid) # forward to every client under this name - # no ack, not necessary to ack - if conn_type == "web": # json data ws_send by frontend - cli_ws_sid = __client_registry[name].cli_ws_sid - if cli_ws_sid is None: # client ws disconnected: - # ack err - ws_send({"ack": "failed", "message": "client ws disconnected"}, to=request.sid) - logger.warn( - f"frontend ({request.sid}) under name '{name}' tried to talk to a disconnected client ws." + def handle_ws_connect(client, server): + print(f"client {client} connected. 
waiting for assigning...") + + def handle_ws_disconnect(client, server): + _project_name, _who = connected_clients[client["id"]] + if _who == "cli": # remove client from Bridge + __BRIDGES[_project_name].cli_ws = None + else: # remove frontend from Bridge + _new_web_ws_list = [ + c for c in __BRIDGES[_project_name].web_ws_list if c["id"] != client["id"] + ] + __BRIDGES[_project_name].web_ws_list = _new_web_ws_list + del connected_clients[client["id"]] + print(f"a {_who} disconnected with id {client['id']}") + # logger.info(f"Websocket ({conn_type}) for {name} disconnected") + + def handle_ws_message(client, server: WebsocketServer, message): + message = json.loads(message) + print(message) # debug + # handle event-type + _event_type = message[EVENT_TYPE_NAME_KEY] + _payload = message[PAYLOAD_NAME_KEY] + _event_id = message[EVENT_ID_NAME_KEY] + _project_name = message[NAME_NAME_KEY] + if _event_type == "handshake": # handle handshake + # assign this client to a Bridge + _who = _payload["who"] + if _who == "web": + # new connection from frontend + # check if Bridge with name exist + if _project_name not in __BRIDGES: # there is no such bridge + server.send_message( + client=client, + msg=WsMsg( + event_type="ack", + event_id=_event_id, + payload={"result": "404", "reason": "name not found"}, + ).json(), + ) + else: # assign web to bridge + _target_bridge = __BRIDGES[_project_name] + _target_bridge.web_ws_list.append(client) + connected_clients[client["id"]] = (_project_name, "web") + server.send_message( + client=client, + msg=WsMsg( + event_type="ack", + event_id=_event_id, + payload={"result": "200", "reason": "join success"}, + ).json(), + ) + elif _who == "cli": + # new connection from cli + # check if Bridge with name exist + if _project_name not in __BRIDGES: # there is no such bridge + _target_bridge = Bridge(name=_project_name) # create new bridge for this name + __BRIDGES[_project_name] = _target_bridge + __BRIDGES[_project_name].cli_ws = client # assign cli to bridge + connected_clients[client["id"]] = (_project_name, "web") + server.send_message( + client=client, + msg=WsMsg( + name="_project_name", + event_type="ack", + event_id=_event_id, + payload={"result": "200", "reason": "join success"}, + ).json(), ) - else: # forward to client - target_sid = __client_registry[name].cli_ws_sid - ws_send(data, to=target_sid) + + elif _event_type == "log": # handle log + # forward log to frontend + if _project_name not in __BRIDGES: + # project name must exist + # drop anyway if not exist + return + else: + # forward to frontends + _target_bridge = __BRIDGES[_project_name] + for web_ws in _target_bridge.web_ws_list: + server.send_message( + client=web_ws, msg=message + ) # forward original message to frontend + + elif _event_type == "action": + # todo forward action query to cli + pass + elif _event_type == "ack": + # todo forward ack to waiting acks + pass + + ws_server.set_fn_new_client(handle_ws_connect) + ws_server.set_fn_client_left(handle_ws_disconnect) + ws_server.set_fn_message_received(handle_ws_message) # ======================== HTTP SERVER =========================== - @app.route(f"{FRONTEND_API_ROOT}/wsforward/", methods=["POST"]) - def handle_json_forward_to_client_ws(name): # forward frontend http json to client ws - data = request.json - if name in __client_registry: # client name exist - target_sid = __client_registry[name].cli_ws_sid - if target_sid is None: # no active cli ws connection for this name - logger.warn( - f"frontend tried to talk to forward to a disconnected 
client ws with name {name}." - ) - abort(404) - ws_send(data, to=target_sid) - return "ok" + # @app.route(f"{FRONTEND_API_ROOT}/wsforward/", methods=["POST"]) + # def handle_json_forward_to_client_ws(name): # forward frontend http json to client ws + # data = request.json + # if name in __BRIDGES: # client name exist + # target_sid = __BRIDGES[name].cli_ws + # if target_sid is None: # no active cli ws connection for this name + # # logger.warn( + # # f"frontend tried to talk to forward to a disconnected client ws with name {name}." + # # ) + # abort(404) + # ws_send(data, to=target_sid) + # return "ok" @app.route("/hello", methods=["GET"]) def just_send_hello(): @@ -207,10 +245,8 @@ def just_send_hello(): def return_status_of(name): global __COUNT_DOWN __COUNT_DOWN = __DAEMON_SHUTDOWN_IF_NO_UPLOAD_TIMEOUT_SEC - if not name: - pass # returning full dict - elif name in __client_registry: - _returning_stat = __client_registry[name].status # returning specific status + if name in __BRIDGES: + _returning_stat = __BRIDGES[name].status # returning specific status else: abort(404) return _returning_stat @@ -219,7 +255,7 @@ def return_status_of(name): def return_names_of_status(): global __COUNT_DOWN __COUNT_DOWN = __DAEMON_SHUTDOWN_IF_NO_UPLOAD_TIMEOUT_SEC - _names = {"names": list(__client_registry.keys())} + _names = {"names": list(__BRIDGES.keys())} return _names @app.route(f"{CLIENT_API_ROOT}/sync/", methods=["POST"]) @@ -227,9 +263,9 @@ def sync_status_of(name): # client side function global __COUNT_DOWN __COUNT_DOWN = __DAEMON_SHUTDOWN_IF_NO_UPLOAD_TIMEOUT_SEC _json_data = request.get_json() - if name not in __client_registry: # Client not found. create from sync request - __client_registry._register(what=Client(name=name), name=name) - __client_registry[name].status = _json_data + if name not in __BRIDGES: # Client not found + __BRIDGES[name] = Bridge(name=name) # Create from sync request + __BRIDGES[name].status = _json_data return "ok" @app.route(f"{FRONTEND_API_ROOT}/shutdown", methods=["POST"]) @@ -255,11 +291,19 @@ def _count_down_thread(): count_down_thread = Thread(target=_count_down_thread, daemon=True) count_down_thread.start() - socketio.run(app, host="0.0.0.0", port=cfg["port"], debug=debug) + ws_server.run_forever(threaded=True) + app.run(host="0.0.0.0", port=cfg["port"], debug=debug) if __name__ == "__main__": - import neetbox - - cfg = get_module_level_config(neetbox.daemon) + cfg = { + "enable": True, + "host": "localhost", + "port": 20202, + "displayName": None, + "allowIpython": False, + "mute": True, + "mode": "detached", + "uploadInterval": 10, + } daemon_process(cfg, debug=True) diff --git a/neetbox/integrations/engine.py b/neetbox/integrations/engine.py index 0b0b0bce..87330f9d 100644 --- a/neetbox/integrations/engine.py +++ b/neetbox/integrations/engine.py @@ -6,6 +6,7 @@ import importlib from enum import Enum +from functools import lru_cache from typing import List, Optional from neetbox.logging import logger @@ -22,7 +23,7 @@ def __str__(self) -> str: installed_engines: Optional[List] = None -# todo migrate to python 3.9 after frameworks are supporting it +@lru_cache def get_supported_engines(): global supported_engines if not supported_engines: @@ -42,6 +43,6 @@ def get_installed_engines(): importlib.import_module(engine.value) installed_engines.append(engine) logger.info(f"'{engine.vaule}' was found installed.") - except: + except ImportError: pass return installed_engines.copy() diff --git a/neetbox/integrations/environment/hardware.py 
b/neetbox/integrations/environment/hardware.py index e83c7659..c9ce5223 100644 --- a/neetbox/integrations/environment/hardware.py +++ b/neetbox/integrations/environment/hardware.py @@ -5,16 +5,6 @@ # Date: 20230413 -from neetbox.utils import pkg -from neetbox.utils.framing import get_frame_module_traceback - -module_name = get_frame_module_traceback().__name__ -assert pkg.is_installed( - "psutil", try_install_if_not=True -), f"{module_name} requires psutil which is not installed" -assert pkg.is_installed( - "GPUtil", try_install_if_not=True -), f"{module_name} requires GPUtil which is not installed" import time from threading import Thread @@ -23,8 +13,18 @@ from GPUtil import GPU from neetbox.pipeline import watch +from neetbox.utils import pkg +from neetbox.utils.framing import get_frame_module_traceback from neetbox.utils.mvc import Singleton +module_name = get_frame_module_traceback().__name__ # type: ignore +assert pkg.is_installed( + "psutil", try_install_if_not=True +), f"{module_name} requires psutil which is not installed" +assert pkg.is_installed( + "GPUtil", try_install_if_not=True +), f"{module_name} requires GPUtil which is not installed" + class _CPU_STAT(dict): def __init__(self, id=-1, percent=0.0, freq=0.0) -> None: @@ -97,15 +97,11 @@ def watcher_fun(env_instance: _Hardware, do_update_gpus: bool): freq=cpu_freq[index], ) if do_update_gpus: - env_instance["gpus"] = [ - _GPU_STAT.parse(_gpu) for _gpu in GPUtil.getGPUs() - ] + env_instance["gpus"] = [_GPU_STAT.parse(_gpu) for _gpu in GPUtil.getGPUs()] env_instance[""] = psutil.cpu_stats() time.sleep(env_instance._update_interval) - self._watcher = Thread( - target=watcher_fun, args=(self, self._with_gpu), daemon=True - ) + self._watcher = Thread(target=watcher_fun, args=(self, self._with_gpu), daemon=True) self._watcher.start() diff --git a/neetbox/integrations/environment/platform.py b/neetbox/integrations/environment/platform.py index 1eda7636..f05d980e 100644 --- a/neetbox/integrations/environment/platform.py +++ b/neetbox/integrations/environment/platform.py @@ -18,9 +18,7 @@ def __init__(self): # system self["username"] = getpass.getuser() self["machine"] = platform.machine() - self["processor"] = ( - "unknown" if len(platform.processor()) == 0 else platform.processor() - ) + self["processor"] = "unknown" if len(platform.processor()) == 0 else platform.processor() self["os_name"] = platform.system() self["os_release"] = platform.version() self["architecture"] = platform.architecture() @@ -39,9 +37,7 @@ def exec(self, command): str: The command running results. err: The command error information. 
""" - p = subprocess.Popen( - command, stdout=subprocess.PIPE, stderr=subprocess.PIPE, shell=True - ) + p = subprocess.Popen(command, stdout=subprocess.PIPE, stderr=subprocess.PIPE, shell=True) raw_output, raw_err = p.communicate() rc = p.returncode if self.platform_info["architecture"] == "32bit": diff --git a/neetbox/integrations/resource.py b/neetbox/integrations/resource.py index fd2bc6f9..0b82dc1e 100644 --- a/neetbox/integrations/resource.py +++ b/neetbox/integrations/resource.py @@ -31,9 +31,7 @@ from neetbox.logging import logger from neetbox.utils import pkg -_loader_pool: Dict[ - str, "ResourceLoader" -] = dict() # all ResourceLoaders are stored here +_loader_pool: Dict[str, "ResourceLoader"] = dict() # all ResourceLoaders are stored here class ResourceLoader: @@ -104,9 +102,7 @@ def perform_scan(): glob_str = "**/*" if self._scan_sub_dirs else "*" if not verbose: # do not output self.file_path_list = [ - str(path) - for path in pathlib.Path(self.path).glob(glob_str) - if can_match(path) + str(path) for path in pathlib.Path(self.path).glob(glob_str) if can_match(path) ] else: self.file_path_list = [] @@ -175,7 +171,7 @@ def get_random_image_as_numpy(self): return np.array(image) def get_random_image_as_tensor(self, engine=engine.Torch): - assert engine in [engine.Torch] # todo support other engines + assert engine in [engine.Torch] # TODO support other engines if engine == engine.Torch: assert pkg.is_installed("torchvision") import torchvision.transforms as T @@ -186,12 +182,10 @@ def get_random_image_as_tensor(self, engine=engine.Torch): T.Normalize(mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5)), ] ) - image = tensor_transform(self.get_random_image()).unsqueeze( - 0 - ) # To tensor of NCHW + image = tensor_transform(self.get_random_image()).unsqueeze(0) # To tensor of NCHW return image - # todo to_dataset + # TODO(VisualDust): to_dataset def download( diff --git a/neetbox/logging/__init__.py b/neetbox/logging/__init__.py index 5a00750e..dc8029b4 100644 --- a/neetbox/logging/__init__.py +++ b/neetbox/logging/__init__.py @@ -5,6 +5,5 @@ _cfg = get_module_level_config() logger.set_log_dir(_cfg["logdir"]) set_log_level(_cfg["level"]) -from neetbox.logging.logger import LogSplitStrategies -__all__ = ["logger", "LogSplitStrategies"] +__all__ = ["logger"] diff --git a/neetbox/logging/_writer.py b/neetbox/logging/_writer.py index 6f4f7a01..0db2958e 100644 --- a/neetbox/logging/_writer.py +++ b/neetbox/logging/_writer.py @@ -1,27 +1,192 @@ +import inspect +import io +import json +import os +import pathlib +from dataclasses import dataclass +from datetime import date, datetime +from typing import Any, Callable, Iterable, Optional, Union + +from rich import print as rprint + +from neetbox.logging.formatting import LogStyle, colored_text, styled_text +from neetbox.utils import formatting + + +# Log writer interface class LogWriter: - def write(self, raw_msg): + def write(self, raw_log): pass -class ConsoleLogWriter(metaclass=LogWriter): - def __init__(self) -> None: - pass +# ================== DEFINE LOG TYPE ===================== - def write(self, raw_msg): - pass +@dataclass +class RawLog: + rich_msg: str + style: LogStyle + caller_identity: Any + whom: Any = None + prefix: Optional[str] = None + datetime_format: Optional[str] = None + with_identifier: Optional[bool] = None + with_datetime: Optional[bool] = None + skip_writers: Optional[list[str]] = None -class FileLogWriter(metaclass=LogWriter): - def __init__(self) -> None: - pass + def write_by(self, writer: LogWriter) -> bool: + if 
self.skip_writers: # if skipping any writers + for swr in self.skip_writers: + if isinstance(writer, RawLog.name2writerType[swr]): + return False # skip this writer, do not write + writer.write(self) + return False - def write(self, raw_msg): - pass + def json(self) -> dict: + _default_style = self.style + # prefix + _prefix = self.prefix or _default_style.prefix + # composing datetime + _with_datetime = self.with_datetime or _default_style.with_datetime + _datetime = "" + if _with_datetime: + _datetime_fmt = self.datetime_format or _default_style.datetime_format + _datetime = datetime.now().strftime(_datetime_fmt) + # composing identifier + _whom = "" + _with_identifier = self.with_identifier or _default_style.with_identifier + if _with_identifier: + _caller_identity = self.caller_identity + _whom = str(self.whom) # check identity + id_seq = [] + if self.whom is None: # if using default logger, tracing back to the caller + file_level = True + _whom = "" + if _caller_identity.module_name and _default_style.trace_level >= 2: + # trace as module level + id_seq.append(_caller_identity.module_name) + file_level = False + if _caller_identity.class_name and _default_style.trace_level >= 1: + # trace as class level + id_seq.append(_caller_identity.class_name) + file_level = False + if file_level and _default_style.trace_level >= 1: + id_seq.append(_caller_identity.filename) # not module level and class level + if _caller_identity.func_name != "": + id_seq.append(_caller_identity.func_name) # skip for jupyters + for i in range(len(id_seq)): + if len(_whom) != 0: + _whom += _default_style.split_char_identity + _whom += id_seq[i] + return {"prefix": _prefix, "datetime": _datetime, "whom": _whom, "msg": self.rich_msg} -class WebSocketLogWriter(metaclass=LogWriter): - def __init__(self) -> None: - pass + def __repr__(self) -> str: + return json.dumps(self.json(), default=str) + + +# ================== CONSOLE LOG WRITER ===================== + + +class __ConsoleLogWriter(LogWriter): + def write(self, raw_log: RawLog): + _msg_dict = raw_log.json() + _style = raw_log.style + rich_msg = str( + _msg_dict["prefix"] + + _msg_dict["datetime"] + + _style.split_char_cmd * min(len(_msg_dict["datetime"]), 1) + + styled_text(_msg_dict["whom"], style=_style) + + _style.split_char_cmd * min(len(_msg_dict["whom"]), 1) + + _msg_dict["msg"] + ) + rprint(rich_msg) + + +# console writer singleton +consoleLogWriter = __ConsoleLogWriter() + + +# ================== FILE LOG WRITER ===================== + + +class FileLogWriter(LogWriter): + # class level static pool + PATH_2_FILE_WRITER = {} + + # instance level non-static things + file_writer = None # assign in __init__ - def write(self, raw_msg): + def __new__(cls, path, *args, **kwargs): + # per file, per writer. + file_abs_path = os.path.abspath(path) + if os.path.isdir(file_abs_path): + raise Exception("Target path is not a file.") + filename = formatting.legal_file_name_of(os.path.basename(path)) + dirname = os.path.dirname(path) if len(os.path.dirname(path)) != 0 else "." 
+ if not os.path.exists(dirname): + raise Exception(f"Could not find dictionary {dirname}") + real_path = os.path.join(dirname, filename) + if file_abs_path not in FileLogWriter.PATH_2_FILE_WRITER: + newWriter = LogWriter.__new__(cls) + newWriter.file_writer = open(real_path, "a", encoding="utf-8", buffering=1) + FileLogWriter.PATH_2_FILE_WRITER[file_abs_path] = newWriter + + return FileLogWriter.PATH_2_FILE_WRITER[file_abs_path] + + def __init__(self, path) -> None: + self.file_writer = open(path, "a", encoding="utf-8", buffering=1) + + def write(self, raw_log: RawLog): + _msg_dict = raw_log.json() + _style = raw_log.style + text_msg = str( + _msg_dict["prefix"] + + _msg_dict["datetime"] + + _style.split_char_txt * min(len(_msg_dict["datetime"]), 1) + + _msg_dict["whom"] + + _style.split_char_txt * min(len(_msg_dict["whom"]), 1) + + _msg_dict["msg"] + + "\n" + ) + self.file_writer.write(text_msg) + + +# ================== WS LOG WRITER ===================== + + +class _WebSocketLogWriter(LogWriter): + # class level statics + connection = None # connection should be assigned by neetbox.daemon.client._connection to avoid recursive import + + def write(self, raw_log: RawLog): + json_data = raw_log.json() + + if _WebSocketLogWriter.connection: + _WebSocketLogWriter.connection.ws_send(event_type="log", payload=json_data) + + +def _assign_connection_to_WebSocketLogWriter(conn): + _WebSocketLogWriter.connection = conn + + +webSocketLogWriter = _WebSocketLogWriter() + + +# ================== JSON LOG WRITER ===================== + + +class JsonLogWriter(FileLogWriter): + def write(self, raw_log: RawLog): + # todo convert to json and write to file pass + + +# ================== POST INIT TYPE REF ===================== + +RawLog.name2writerType = { + "stdout": __ConsoleLogWriter, + "file": FileLogWriter, + "ws": _WebSocketLogWriter, + "json": JsonLogWriter, +} diff --git a/neetbox/logging/logger.py b/neetbox/logging/logger.py index f18e7990..92faae4c 100644 --- a/neetbox/logging/logger.py +++ b/neetbox/logging/logger.py @@ -5,18 +5,22 @@ # Date: 20230315 import functools -import io import os -import pathlib from datetime import date, datetime from enum import Enum -from inspect import isclass, iscoroutinefunction, isgeneratorfunction from random import randint -from typing import Any, Callable, Iterable, Optional, Union +from typing import Any, Optional, Union from rich import print as rprint from rich.panel import Panel +from neetbox.logging._writer import ( + FileLogWriter, + JsonLogWriter, + RawLog, + consoleLogWriter, + webSocketLogWriter, +) from neetbox.logging.formatting import LogStyle, colored_text, styled_text from neetbox.utils import formatting from neetbox.utils.framing import get_caller_identity_traceback @@ -48,194 +52,29 @@ def __ge__(self, other): return self.value >= other.value -writers_dict = {} -style_dict = {} -loggers_dict = {} - -_GLOBAL_LOG_LEVEL = LogLevel.INFO - - -def set_log_level(level: LogLevel): - if type(level) is str: - level = { - "ALL": LogLevel.ALL, - "DEBUG": LogLevel.DEBUG, - "INFO": LogLevel.INFO, - "WARNING": LogLevel.WARNING, - "ERROR": LogLevel.ERROR, - }[level] - if type(level) is int: - assert level >= 0 and level <= 3 - level = LogLevel(level) - global _GLOBAL_LOG_LEVEL - _GLOBAL_LOG_LEVEL = level - - -class LogMetadata: - def __init__(self, writer: "_AutoSplitLogWriter"): - self.written_bytes = 0 - self.log_writer = writer - - -SplitStrategyCallable = Callable[[LogMetadata], Union[str, Iterable[str]]] - - -class LogSplitStrategies: - @staticmethod - 
def by_date() -> SplitStrategyCallable: - def _split_strategy(metadata: LogMetadata): - return date.today().strftime("%Y%m%d") - - return _split_strategy - - @staticmethod - def by_hour() -> SplitStrategyCallable: - def _split_strategy(metadata: LogMetadata): - return datetime.now().strftime("%Y%m%d-%H") - - return _split_strategy - - @staticmethod - def by_date_and_size(size_in_bytes: int) -> SplitStrategyCallable: - class DateSizeSplitStrategy: - def __init__(self): - self.file_id = None - - def _already_exists(self, metadata: LogMetadata, file_id: int) -> bool: - f = metadata.log_writer.make_logfile_path(self.make_result(file_id)) - return f.exists() - - def make_result(self, file_id): - return date.today().strftime("%Y%m%d"), str(file_id) - - def __call__(self, metadata: LogMetadata): - if self.file_id is None: - self.file_id = 0 - while self._already_exists(metadata, self.file_id): - self.file_id += 1 - return self.make_result(self.file_id + metadata.written_bytes // size_in_bytes) - - return DateSizeSplitStrategy() - - -class _AutoSplitLogWriter(io.TextIOBase): - class ReentrantCounter: - def __init__(self): - self._count = 0 - - def __enter__(self): - self._count += 1 - - def __exit__(self, exc_type, exc_val, exc_tb): - self._count -= 1 - - def __bool__(self): - return self._count > 0 - - _writer: Union[io.IOBase, None] - _filename_template: str - _split_strategy: Union[SplitStrategyCallable, Callable] - _current_logfile: Union[pathlib.Path, None] - - def __init__( - self, - base_dir, - filename_template, - split_strategy: Optional[SplitStrategyCallable], - *, - encoding="utf-8", - open_on_creation=True, - overwrite_existing=False, - ) -> None: - self._writer = None - self._current_logfile = None - self._filename_template = filename_template - self._base_dir = pathlib.Path(str(base_dir)) - self._encoding = encoding - self._open_mode = "wb" if overwrite_existing else "ab" - self._split_lock = _AutoSplitLogWriter.ReentrantCounter() - - self._split_strategy = (lambda *_: None) if split_strategy is None else split_strategy - - self._stats = LogMetadata(self) - - if open_on_creation: - self.open() - - def _apply_filename_template(self, provider_supplied): - if provider_supplied is None: - return self._filename_template - if isinstance(provider_supplied, str): - return provider_supplied - if isinstance(provider_supplied, Iterable): - return self._filename_template.format(*provider_supplied) - - raise ValueError("Filename provider must return either a string or an iterable of strings") - - def make_logfile_path(self, provider_supplied): - return self._base_dir / self._apply_filename_template(provider_supplied) - - def _create_logfile(self): - expected_logfile = self.make_logfile_path(self._split_strategy(self._stats)) - if expected_logfile != self._current_logfile: - if self._writer is not None: - self._writer.close() - expected_logfile.parent.mkdir(parents=True, exist_ok=True) - self._current_logfile = expected_logfile - self._writer = open(self._current_logfile, self._open_mode) # type: ignore - - def _check_open(self): - if self._writer is None: - raise ValueError("Writer not opened") - - def write(self, __s): - self._check_open() - if not self._split_lock: - self._create_logfile() - - print("writing") - bytes = __s.encode(self._encoding) - self._stats.written_bytes += len(bytes) - self._writer.write(bytes) - - def writelines(self, __lines: Iterable[str]) -> None: - for line in __lines: - self.write(line + "\n") - - def open(self): - self._create_logfile() - - def __enter__(self): - if 
self._writer is None: - self.open() - - def __exit__(self, exc_type, exc_val, exc_tb) -> None: - self.close() - - def flush(self): - self._writer.flush() - - def close(self): - if self._writer is not None: - self._writer.close() - - def split_lock(self): - return self._split_lock +_GLOBAL_LOG_LEVEL = LogLevel.ALL class Logger: + # global static + __WHOM_2_LOGGER = {} + __WHOM_2_STYLE = {} + def __init__(self, whom, style: Optional[LogStyle] = None): self.whom: Any = whom self.style: Optional[LogStyle] = style + # default writing to console and ws + self.console_writer = consoleLogWriter + self.ws_writer = webSocketLogWriter self.file_writer = None def __call__(self, whom: Any = None, style: Optional[LogStyle] = None) -> "Logger": if whom is None: return DEFAULT_LOGGER - if whom in loggers_dict: - return loggers_dict[whom] - loggers_dict[whom] = Logger(whom=whom, style=style) - return loggers_dict[whom] + if whom in Logger.__WHOM_2_LOGGER: + return Logger.__WHOM_2_LOGGER[whom] + Logger.__WHOM_2_LOGGER[whom] = Logger(whom=whom, style=style) + return Logger.__WHOM_2_LOGGER[whom] def log( self, @@ -244,89 +83,44 @@ def log( datetime_format: Optional[str] = None, with_identifier: Optional[bool] = None, with_datetime: Optional[bool] = None, - into_file: bool = True, - into_stdout: bool = True, + skip_writers: Optional[Union[list[str], str]] = None, traceback=2, ): _caller_identity = get_caller_identity_traceback(traceback=traceback) - # getting style + # converting args into a single string + _pure_str_message = "" + for msg in content: + _pure_str_message += str(msg) + " " + + if type(skip_writers) is str: + skip_writers = [skip_writers] + _style = self.style if not _style: # if style not set _style_index = str(_caller_identity) - if _style_index in style_dict: # check for previous style - _style = style_dict[_style_index] + if _style_index in Logger.__WHOM_2_STYLE: # check for previous style + _style = Logger.__WHOM_2_STYLE[_style_index] else: _style = LogStyle().randcolor() - style_dict[_style_index] = _style - - # composing prefix - _prefix = _style.prefix - if prefix is not None: # if using specific prefix - _prefix = prefix - - # composing datetime - _with_datetime = _style.with_datetime - _datetime = "" - if with_datetime is not None: # if explicitly determined wether to log with datetime - _with_datetime = with_datetime - if _with_datetime: - _datetime_fmt = datetime_format if datetime_format else _style.datetime_format - _datetime = datetime.now().strftime(_datetime_fmt) - - # if with identifier - _whom = "" - _with_identifier = _style.with_identifier - if with_identifier is not None: # if explicitly determined wether to log with identifier - _with_identifier = with_identifier - if _with_identifier: - _whom = str(self.whom) # check identity - id_seq = [] - if self.whom is None: # if using default logger, tracing back to the caller - file_level = True - _whom = "" - if _caller_identity.module_name and _style.trace_level >= 2: - # trace as module level - id_seq.append(_caller_identity.module_name) - file_level = False - if _caller_identity.class_name and _style.trace_level >= 1: - # trace as class level - id_seq.append(_caller_identity.class_name) - file_level = False - if file_level and _style.trace_level >= 1: - id_seq.append(_caller_identity.filename) # not module level and class level - if _caller_identity.func_name != "": - id_seq.append(_caller_identity.func_name) # skip for jupyters - for i in range(len(id_seq)): - if len(_whom) != 0: - _whom += _style.split_char_identity - _whom 
+= id_seq[i] - - # converting args into a single string - _pure_str_message = "" - for msg in content: - _pure_str_message += str(msg) + " " + Logger.__WHOM_2_STYLE[_style_index] = _style + + raw_log = RawLog( + rich_msg=_pure_str_message, + caller_identity=_caller_identity, + whom=self.whom, + style=_style, + prefix=prefix, + datetime_format=datetime_format, + with_identifier=with_identifier, + with_datetime=with_datetime, + skip_writers=skip_writers, + ) + + for writer in [self.console_writer, self.ws_writer, self.file_writer]: + if writer: + raw_log.write_by(writer) - # perform log - if into_stdout: - rprint( - _prefix - + _datetime - + _style.split_char_cmd * min(len(_datetime), 1) - + styled_text(_whom, style=_style) - + _style.split_char_cmd * min(len(_whom), 1), - _pure_str_message, - ) - if into_file and self.file_writer: - self.file_writer.write( - _prefix - + _datetime - + _style.split_char_txt * min(len(_datetime), 1) - + _whom - + _style.split_char_txt * min(len(_whom), 1) - + _pure_str_message - + "\n" - ) return self def ok(self, *message, flag="OK"): @@ -334,10 +128,10 @@ def ok(self, *message, flag="OK"): self.log( *message, prefix=f"[{colored_text(flag, 'green')}]", - into_file=False, + skip_writers=["file"], traceback=3, ) - self.log(*message, prefix=flag, into_stdout=False, traceback=3) + self.log(*message, prefix=flag, skip_writers=["stdout"], traceback=3) return self def debug(self, *message, flag="DEBUG"): @@ -345,10 +139,10 @@ def debug(self, *message, flag="DEBUG"): self.log( *message, prefix=f"[{colored_text(flag, 'cyan')}]", - into_file=False, + skip_writers=["file"], traceback=3, ) - self.log(*message, prefix=flag, into_stdout=False, traceback=3) + self.log(*message, prefix=flag, skip_writers=["stdout"], traceback=3) return self def info(self, *message, flag="INFO"): @@ -356,10 +150,10 @@ def info(self, *message, flag="INFO"): self.log( *message, prefix=f"[{colored_text(flag, 'white')}]", - into_file=False, + skip_writers=["file"], traceback=3, ) - self.log(*message, prefix=flag, into_stdout=False, traceback=3) + self.log(*message, prefix=flag, skip_writers=["stdout"], traceback=3) return self def warn(self, *message, flag="WARNING"): @@ -367,10 +161,10 @@ def warn(self, *message, flag="WARNING"): self.log( *message, prefix=f"[{colored_text(flag, 'yellow')}]", - into_file=False, + skip_writers=["file"], traceback=3, ) - self.log(*message, prefix=flag, into_stdout=False, traceback=3) + self.log(*message, prefix=flag, skip_writers=["stdout"], traceback=3) return self def err(self, err, flag="ERROR", reraise=False): @@ -380,74 +174,14 @@ def err(self, err, flag="ERROR", reraise=False): self.log( str(err), prefix=f"[{colored_text(flag,'red')}]", - into_file=False, + skip_writers=["file"], traceback=3, ) - self.log(str(err), prefix=flag, into_stdout=False, traceback=3) - if reraise: + self.log(str(err), prefix=flag, skip_writers=["stdout"], traceback=3) + if reraise or _GLOBAL_LOG_LEVEL >= LogLevel.DEBUG: raise err return self - def catch( - self, exception_type=Exception, *, reraise=True, handler=None - ): # todo add handler interface - if callable(exception_type) and ( - not isclass(exception_type) or not issubclass(exception_type, BaseException) - ): - return self.catch()(exception_type) - - class Catcher: - def __init__(self, from_decorator): - self._from_decorator = from_decorator - - def __enter__(self): - return None - - def __exit__(self, type_, value, traceback_): - if type_ is None: - return - if not issubclass(type_, exception_type): - return False - if 
handler: - handler(traceback_) - # logger.log( - # from_decorator, catch_options, traceback=4 if from_decorator else 3 - # ) - # todo add reraise functions - return not reraise - - def __call__(self, function): - if isclass(function): - raise TypeError( - "Invalid object decorated with 'catch()', it must be a function, " - "not a class (tried to wrap '%s')" % function.__name__ - ) - - catcher = Catcher(True) - - if iscoroutinefunction(function): - - async def catch_wrapper(*args, **kwargs): - with catcher: - return await function(*args, **kwargs) - - elif isgeneratorfunction(function): - - def catch_wrapper(*args, **kwargs): - with catcher: - return (yield from function(*args, **kwargs)) - - else: - - def catch_wrapper(*args, **kwargs): - with catcher: - return function(*args, **kwargs) - - functools.update_wrapper(catch_wrapper, function) - return catch_wrapper - - return Catcher(False) - def mention(self, func): @functools.wraps(func) def with_logging(*args, **kwargs): @@ -541,21 +275,7 @@ def _bind_file(self, path): if not path: self.file_writer = None return self - if not path: - self.file_writer = None - return self - log_file_identity = os.path.abspath(path) - if os.path.isdir(log_file_identity): - raise Exception("Target path is not a file.") - filename = formatting.legal_file_name_of(os.path.basename(path)) - dirname = os.path.dirname(path) if len(os.path.dirname(path)) != 0 else "." - if not os.path.exists(dirname): - raise Exception(f"Could not find dictionary {dirname}") - real_path = os.path.join(dirname, filename) - if log_file_identity not in writers_dict: - # todo add fflush buffer size or time - writers_dict[log_file_identity] = open(real_path, "a", encoding="utf-8", buffering=1) - self.file_writer = writers_dict[log_file_identity] + self.file_writer = FileLogWriter(path=path) return self def file_bend(self) -> bool: @@ -563,3 +283,20 @@ def file_bend(self) -> bool: DEFAULT_LOGGER = Logger(None) + + +def set_log_level(level: LogLevel): + if type(level) is str: + level = { + "ALL": LogLevel.ALL, + "DEBUG": LogLevel.DEBUG, + "INFO": LogLevel.INFO, + "WARNING": LogLevel.WARNING, + "ERROR": LogLevel.ERROR, + }[level] + if type(level) is int: + assert level >= 0 and level <= 3 + level = LogLevel(level) + global _GLOBAL_LOG_LEVEL + _GLOBAL_LOG_LEVEL = level + DEFAULT_LOGGER.debug(f"global log level was set to {_GLOBAL_LOG_LEVEL}") diff --git a/neetbox/pipeline/_signal_and_slot.py b/neetbox/pipeline/_signal_and_slot.py index 5e81d9ad..9f76f6ab 100644 --- a/neetbox/pipeline/_signal_and_slot.py +++ b/neetbox/pipeline/_signal_and_slot.py @@ -9,7 +9,7 @@ from datetime import datetime from functools import partial from threading import Thread -from typing import Any, Callable +from typing import Any, Callable, Optional, Union from neetbox.config import get_module_level_config from neetbox.core import Registry @@ -75,11 +75,13 @@ def _so_update_and_ping_listen(_name, _value, _watch_config): f"Watched value {_name} takes longer time({delta_t:.8f}s) to update than it was expected({expected_time_limit}s)." 
) - Thread(target=_so_update_and_ping_listen, args=(name, _the_value, _watch_config)).start() + Thread( + target=_so_update_and_ping_listen, args=(name, _the_value, _watch_config), daemon=True + ).start() return _the_value -def _watch(func: Callable, name: str, freq: float, initiative=False, force=False): +def _watch(func: Callable, name: Optional[str], freq: float, initiative=False, force=False): """Function decorator to let the daemon watch a value of the function Args: @@ -114,9 +116,9 @@ def watch(name=None, freq=None, initiative=False, force=False): return partial(_watch, name=name, freq=freq, initiative=initiative, force=force) -def _listen(func: Callable, target: str, name: str = None, force=False): +def _listen(func: Callable, target: Union[str, Callable], name: Optional[str] = None, force=False): name = name or func.__name__ - if type(target) is not str: + if not isinstance(target, str): if type(target) is partial: if target.func in [__update_and_get, __get]: target = target.args[0] @@ -134,11 +136,11 @@ def _listen(func: Callable, target: str, name: str = None, force=False): f"There is already a listener called '{name}' lisiting '{target}', overwriting." ) _listen_queue_dict[target][name] = func - logger.log(f"{name} is now lisiting to {target}.") + logger.debug(f"{name} is now lisiting to {target}.") return func -def listen(target, name: str = None, force=False): +def listen(target, name: Optional[str] = None, force=False): return partial(_listen, target=target, name=name, force=force) @@ -151,7 +153,7 @@ def _update_thread(): for _vname, _watched_fun in _watch_queue_dict.items(): _watch_config = _watched_fun.others if not _watch_config["initiative"] and _ctr % _watch_config["freq"] == 0: # do update - _the_value = __update_and_get(_vname) + _ = __update_and_get(_vname) update_thread = Thread(target=_update_thread, daemon=True) diff --git a/pyproject.toml b/pyproject.toml index ee5f1492..46c49a15 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -45,8 +45,8 @@ pytest = "^7.4.3" pyfiglet = "^1.0.2" httpx = "^0.24.0" flask = "^2.2.3" -flask-socketio = "^5.3.6" -websockets = "^12.0" +websocket-client = "^1.6.4" +websocket-server = "^0.6.4" [tool.poetry.group.dev.dependencies] pytest = "^7.0.0" diff --git a/tests/test_logger.py b/tests/test_logger.py index 2c7ea997..adf72148 100644 --- a/tests/test_logger.py +++ b/tests/test_logger.py @@ -37,14 +37,3 @@ def b(self): self.logger.log("from class B") B().b() - - -def test_logger_catch(): - from neetbox.logging import logger - - @logger.catch(reraise=False) - def my_function(x, y, z): - # An error? It's caught anyway! - return 1 / (x + y + z) - - my_function(0, 0, 0)