diff --git a/app/api/v2/handlers/contact_api.py b/app/api/v2/handlers/contact_api.py index 6e2e10ac6..c9552ecb4 100644 --- a/app/api/v2/handlers/contact_api.py +++ b/app/api/v2/handlers/contact_api.py @@ -44,7 +44,7 @@ async def get_contact_report(self, request: web.Request): async def get_available_contact_reports(self, request: web.Request): contacts = self._api_manager.get_available_contact_reports() return web.json_response(contacts) - + async def get_contact_list(self, request: web.Request): contacts = [dict(name=c.name, description=c.description) for c in self._api_manager.contact_svc.contacts] return web.json_response(contacts) diff --git a/app/api/v2/security.py b/app/api/v2/security.py index bf0220960..0f70fb3a8 100644 --- a/app/api/v2/security.py +++ b/app/api/v2/security.py @@ -68,12 +68,12 @@ async def authentication_required_middleware(request, handler): return await handler(request) return authentication_required_middleware -"""Allow all 'OPTIONS' request to the server to return 200 -This mitigates CORS issues while developing the UI. -""" @web.middleware async def pass_option_middleware(request, handler): + """Allow all 'OPTIONS' request to the server to return 200 + This mitigates CORS issues while developing the UI. + """ if request.method == 'OPTIONS': raise web.HTTPOk() return await handler(request) diff --git a/app/ascii_banner.py b/app/ascii_banner.py index 074d668cb..0411aa3aa 100644 --- a/app/ascii_banner.py +++ b/app/ascii_banner.py @@ -1,4 +1,5 @@ import os +from rich import print as rich_print RED = "\33[1m" @@ -7,14 +8,14 @@ DARK_BLUE = "\x1b[38;5;20m" GREEN = "\033[32m" YELLOW = "\033[93m" -PURPLE = '\033[0;35m' -DARK_PURPLE = '\x1b[38;5;92m' +PURPLE = "\033[0;35m" +DARK_PURPLE = "\x1b[38;5;92m" CYAN = "\033[36m" END = "\033[0m" _BANNER = """ - ██████╗ █████╗ ██╗ ██████╗ ███████╗██████╗ █████╗ + ██████╗ █████╗ ██╗ ██████╗ ███████╗██████╗ █████╗ ██╔════╝██╔══██╗██║ ██╔══██╗██╔════╝██╔══██╗██╔══██╗ ██║ ███████║██║ ██║ ██║█████╗ ██████╔╝███████║ ██║ ██╔══██║██║ ██║ ██║██╔══╝ ██╔══██╗██╔══██║ @@ -23,7 +24,7 @@ """ -_BANNER_SECTION_1 = "\n\ +_BANNER_SECTION_1 = "\n\ ██████╗ █████╗ ██╗ ██████╗ ███████╗██████╗ █████╗\n\ ██╔════╝██╔══██╗██║ ██╔══██╗██╔════╝██╔══██╗██╔══██╗\n\ " @@ -41,8 +42,21 @@ " -if int(os.environ.get('NO_COLOR', 0)) == 1: +def no_color(): + return int(os.environ.get("NO_COLOR", 0)) == 1 + + +if no_color(): ASCII_BANNER = _BANNER else: ASCII_BANNER = f"{DARK_BLUE}{_BANNER_SECTION_1}{DARK_PURPLE}{_BANNER_SECTION_2}{DARK_RED}{BANNER_SECTION_3}{END}" - + + +def print_rich_banner(): + """Print banner using Python Rich library""" + if no_color(): + rich_print(f"{_BANNER_SECTION_1}{_BANNER_SECTION_2}{BANNER_SECTION_3}") + else: + rich_print( + f"[blue]{_BANNER_SECTION_1}[/blue][purple]{_BANNER_SECTION_2}[/purple][red]{BANNER_SECTION_3}[/red]" + ) diff --git a/requirements.txt b/requirements.txt index 2d1b44ccc..a1890aa10 100644 --- a/requirements.txt +++ b/requirements.txt @@ -16,6 +16,7 @@ marshmallow-enum==1.5.1 ldap3==2.9.1 lxml~=4.9.1 # debrief reportlab==4.0.4 # debrief +rich==13.7.0 svglib==1.5.1 # debrief Markdown==3.4.4 # training dnspython==2.4.2 diff --git a/server.py b/server.py index a1b44ecc4..3637d6197 100644 --- a/server.py +++ b/server.py @@ -2,6 +2,9 @@ import asyncio import logging import os +from rich.console import Console +from rich.logging import RichHandler +from rich.theme import Theme import sys import warnings import subprocess @@ -12,7 +15,7 @@ import app.api.v2 from app import version -from app.ascii_banner import ASCII_BANNER +from app.ascii_banner import ASCII_BANNER, no_color, print_rich_banner from app.api.rest_api import RestApi from app.api.v2.responses import apispec_request_validation_middleware from app.api.v2.security import pass_option_middleware @@ -35,11 +38,19 @@ def setup_logger(level=logging.DEBUG): - logging.basicConfig( - level=level, - format="%(asctime)s - %(levelname)-5s (%(filename)s:%(lineno)s %(funcName)s) %(message)s", - datefmt="%Y-%m-%d %H:%M:%S", - ) + format = "%(message)s" + datefmt = "%Y-%m-%d %H:%M:%S" + if no_color(): + logging.basicConfig(level=level, format=format, datefmt=datefmt) + else: + console = Console(theme=Theme({"logging.level.warning": "yellow"})) + logging.basicConfig( + level=level, + format=format, + datefmt=datefmt, + handlers=[RichHandler(rich_tracebacks=True, markup=True, console=console)] + ) + for logger_name in logging.root.manager.loggerDict.keys(): if logger_name in ("aiohttp.server", "asyncio"): continue @@ -88,7 +99,7 @@ def run_tasks(services, run_vue_server=False): loop.run_until_complete(start_vue_dev_server()) try: logging.info("All systems ready.") - logging.info(ASCII_BANNER) + logging.info(print_rich_banner()) loop.run_forever() except KeyboardInterrupt: loop.run_until_complete( @@ -138,7 +149,10 @@ def _get_parser(): def list_str(values): return values.split(",") - parser = argparse.ArgumentParser("Welcome to the system") + parser = argparse.ArgumentParser( + description=ASCII_BANNER, + formatter_class=argparse.RawDescriptionHelpFormatter + ) parser.add_argument( "-E", "--environment", @@ -191,16 +205,16 @@ def list_str(values): return parser -if __name__ == '__main__': +if __name__ == "__main__": sys.path.append("") - + parser = _get_parser() args = parser.parse_args() setup_logger(getattr(logging, args.logLevel)) if args.insecure: logging.warning( - "--insecure flag set. Caldera will use the default.yml config file." + "[orange_red1]--insecure flag set. Caldera will use the default user accounts in default.yml config file.[/orange_red1]" ) args.environment = "default" elif args.environment == "local": @@ -248,13 +262,20 @@ def list_str(values): subprocess.run(["npm", "install"], cwd="plugins/magma", check=True) subprocess.run(["npm", "run", "build"], cwd="plugins/magma", check=True) logging.info("VueJS front-end build complete.") + else: + if not os.path.exists("./plugins/magma/dist"): + logging.warning( + "[bright_yellow]Built Caldera v5 Vue components not detected, and `--build` flag not supplied." + " If attempting to start Caldera v5 for the first time, the `--build` flag must be" + " supplied to trigger the building of the Vue source components.[/bright_yellow]" + ) if args.fresh: logging.info( - "Fresh startup: resetting server data. See %s directory for data backups.", + "[green]Fresh startup: resetting server data. See %s directory for data backups.[/green]", DATA_BACKUP_DIR, ) asyncio.get_event_loop().run_until_complete(data_svc.destroy()) asyncio.get_event_loop().run_until_complete(knowledge_svc.destroy()) - run_tasks(services=app_svc.get_services(), run_vue_server=args.uiDevHost) \ No newline at end of file + run_tasks(services=app_svc.get_services(), run_vue_server=args.uiDevHost) diff --git a/tests/api/v2/handlers/test_payloads_api.py b/tests/api/v2/handlers/test_payloads_api.py index a56bb4c42..1ada24071 100644 --- a/tests/api/v2/handlers/test_payloads_api.py +++ b/tests/api/v2/handlers/test_payloads_api.py @@ -8,7 +8,7 @@ async def test_get_payloads(self, api_v2_client, api_cookies): payloads_list = await resp.json() assert len(payloads_list) > 0 payload = payloads_list[0] - assert type(payload) == str + assert type(payload) is str async def test_unauthorized_get_payloads(self, api_v2_client): resp = await api_v2_client.get('/api/v2/payloads')