From a1237171d2fe86e99b4eed5fd7f6f9578ff24aa9 Mon Sep 17 00:00:00 2001 From: Jericho Tolentino <68654047+jericht@users.noreply.github.com> Date: Mon, 3 Jun 2024 09:50:42 -0700 Subject: [PATCH] fix!: handle socket name collisions (#125) This change adds a fix to address socket name collisions, as well as the following changes: * Sockets are now created in the system temp directory. * Added daemon bootstrap logging (via a new --log-file option only exposed in the _serve command) so that if the daemon bootstrap fails, we are able to look at the logs and determine what went wrong. Daemon bootstrap logging stops logging to the file after bootstrap is complete. * Improved the --connection-file CLI help text * Removed the worker path component from configuration file paths * Add support for configuring connection data via environment variables. Currently, there is only one env var needed, OPENJD_ADAPTOR_SOCKET. * The daemon start command has been updated to emit an openjd_env: OPENJD_ADAPTOR_SOCKET= line to stdout so that, when running in an OpenJD environment, the option is automatically set and subsequent commands within this OpenJD environment no longer need to pass in a --connection-file argument. * The --connection-file argument is no longer necessary if OPENJD_ADAPTOR_SOCKET is provided * The daemon start command will still generate a connection file at a temporary directory when it is not provided, then delete it after daemon bootstrap is complete. Signed-off-by: Jericho Tolentino <68654047+jericht@users.noreply.github.com> --- .../_background/backend_runner.py | 26 +- .../_background/frontend_runner.py | 213 ++++++-- .../adaptor_runtime/_background/loaders.py | 69 +++ .../_background/server_response.py | 2 +- src/openjd/adaptor_runtime/_entrypoint.py | 182 +++++-- src/openjd/adaptor_runtime/_http/__init__.py | 4 +- src/openjd/adaptor_runtime/_http/sockets.py | 166 +++--- .../adaptor_runtime/_utils/_constants.py | 13 + src/openjd/adaptor_runtime/_utils/_logging.py | 3 - .../adaptor_runtime/_utils/_secure_open.py | 6 +- .../adaptors/_adaptor_runner.py | 3 +- .../adaptor_runtime/adaptors/_base_adaptor.py | 5 +- .../application_ipc/_adaptor_server.py | 7 +- .../application_ipc/_win_adaptor_server.py | 6 +- test/openjd/adaptor_runtime/conftest.py | 19 +- .../integ/AdaptorExample/adaptor.py | 1 + .../integ/background/test_background_mode.py | 42 +- .../integ/test_integration_entrypoint.py | 260 +++++++--- .../unit/background/test_backend_runner.py | 36 +- .../unit/background/test_frontend_runner.py | 480 +++++++++--------- .../unit/background/test_loaders.py | 148 ++++++ .../adaptor_runtime/unit/http/test_sockets.py | 349 +++++++------ .../adaptor_runtime/unit/test_entrypoint.py | 91 ++-- .../test_integration_client_interface.py | 2 +- 24 files changed, 1412 insertions(+), 721 deletions(-) create mode 100644 src/openjd/adaptor_runtime/_background/loaders.py create mode 100644 src/openjd/adaptor_runtime/_utils/_constants.py create mode 100644 test/openjd/adaptor_runtime/unit/background/test_loaders.py diff --git a/src/openjd/adaptor_runtime/_background/backend_runner.py b/src/openjd/adaptor_runtime/_background/backend_runner.py index 50bb7b1..5d4e16e 100644 --- a/src/openjd/adaptor_runtime/_background/backend_runner.py +++ b/src/openjd/adaptor_runtime/_background/backend_runner.py @@ -6,14 +6,16 @@ import logging import os import signal +from pathlib import Path from threading import Thread, Event +import traceback from types import FrameType -from typing import Optional, Union +from typing import Callable, List, Optional, Union from .server_response import ServerResponseGenerator from .._osname import OSName from ..adaptors import AdaptorRunner -from .._http import SocketDirectories +from .._http import SocketPaths from .._utils import secure_open if OSName.is_posix(): @@ -36,12 +38,13 @@ class BackendRunner: def __init__( self, adaptor_runner: AdaptorRunner, - connection_file_path: str, *, + connection_file_path: Path, log_buffer: LogBuffer | None = None, ) -> None: self._adaptor_runner = adaptor_runner self._connection_file_path = connection_file_path + self._log_buffer = log_buffer self._server: Optional[Union[BackgroundHTTPServer, WinBackgroundNamedPipeServer]] = None signal.signal(signal.SIGINT, self._sigint_handler) @@ -68,7 +71,7 @@ def _sigint_handler(self, signum: int, frame: Optional[FrameType]) -> None: self._server, self._adaptor_runner._cancel, force_immediate=True ) - def run(self) -> None: + def run(self, *, on_connection_file_written: List[Callable[[], None]] | None = None) -> None: """ Runs the backend logic for background mode. @@ -79,8 +82,9 @@ def run(self) -> None: shutdown_event: Event = Event() if OSName.is_posix(): # pragma: is-windows - server_path = SocketDirectories.for_os().get_process_socket_path( - "runtime", create_dir=True + server_path = SocketPaths.for_os().get_process_socket_path( + ".openjd_adaptor_runtime", + create_dir=True, ) else: # pragma: is-posix server_path = NamedPipeHelper.generate_pipe_name("AdaptorNamedPipe") @@ -123,6 +127,16 @@ def run(self) -> None: _logger.info("Shutting down server...") shutdown_event.set() raise + except Exception as e: + _logger.critical(f"Unexpected error occurred when writing to connection file: {e}") + _logger.critical(traceback.format_exc()) + _logger.info("Shutting down server") + shutdown_event.set() + else: + if on_connection_file_written: + callbacks = list(on_connection_file_written) + for cb in callbacks: + cb() finally: # Block until the shutdown_event is set shutdown_event.wait() diff --git a/src/openjd/adaptor_runtime/_background/frontend_runner.py b/src/openjd/adaptor_runtime/_background/frontend_runner.py index 23e9234..1b229cb 100644 --- a/src/openjd/adaptor_runtime/_background/frontend_runner.py +++ b/src/openjd/adaptor_runtime/_background/frontend_runner.py @@ -10,16 +10,20 @@ import socket import subprocess import sys +import tempfile import time -from pathlib import Path import urllib.parse as urllib_parse +import uuid +from pathlib import Path from threading import Event from types import FrameType from types import ModuleType -from typing import Optional, Dict +from typing import Optional, Callable, Dict from .._osname import OSName from ..process._logging import _ADAPTOR_OUTPUT_LEVEL +from .._utils._constants import _OPENJD_ENV_STDOUT_PREFIX, _OPENJD_ADAPTOR_SOCKET_ENV +from .loaders import ConnectionSettingsFileLoader from .model import ( AdaptorState, AdaptorStatus, @@ -37,28 +41,39 @@ _logger = logging.getLogger(__name__) +class ConnectionSettingsNotProvidedError(Exception): + """Raised when the connection settings are required but are missing""" + + pass + + class FrontendRunner: """ Class that runs the frontend logic in background mode. """ + connection_settings: ConnectionSettings | None + def __init__( self, - connection_file_path: str, *, timeout_s: float = 5.0, heartbeat_interval: float = 1.0, + connection_settings: ConnectionSettings | None = None, ) -> None: """ Args: - connection_file_path (str): Absolute path to the connection file. timeout_s (float, optional): Timeout for HTTP requests, in seconds. Defaults to 5. heartbeat_interval (float, optional): Interval between heartbeats, in seconds. Defaults to 1. + connection_settings (ConnectionSettings, optional): The connection settings to use. + This option is not required for the "init" command, but is required for everything + else. Defaults to None. """ self._timeout_s = timeout_s self._heartbeat_interval = heartbeat_interval - self._connection_file_path = connection_file_path + self.connection_settings = connection_settings + self._canceled = Event() signal.signal(signal.SIGINT, self._sigint_handler) if OSName.is_posix(): # pragma: is-windows @@ -68,7 +83,9 @@ def __init__( def init( self, + *, adaptor_module: ModuleType, + connection_file_path: Path, init_data: dict | None = None, path_mapping_data: dict | None = None, reentry_exe: Path | None = None, @@ -79,6 +96,8 @@ def init( Args: adaptor_module (ModuleType): The module of the adaptor running the runtime. + connection_file_path (Path): The path to the connection file to use for establishing + a connection with the backend process. init_data (dict): Data to pass to the adaptor during initialization. path_mapping_data (dict): Path mapping rules to make available to the adaptor while it's running. reentry_exe (Path): The path to the binary executable that for adaptor reentry. @@ -86,10 +105,10 @@ def init( if adaptor_module.__package__ is None: raise Exception(f"Adaptor module is not a package: {adaptor_module}") - if os.path.exists(self._connection_file_path): + if connection_file_path.exists(): raise FileExistsError( "Cannot init a new backend process with an existing connection file at: " - + self._connection_file_path + + str(connection_file_path) ) if init_data is None: @@ -107,18 +126,32 @@ def init( ] else: args = [str(reentry_exe)] + args.extend( [ "daemon", "_serve", - "--connection-file", - self._connection_file_path, "--init-data", json.dumps(init_data), "--path-mapping-rules", json.dumps(path_mapping_data), + "--connection-file", + str(connection_file_path), ] ) + + bootstrap_id = uuid.uuid4() + bootstrap_log_dir = tempfile.gettempdir() + bootstrap_log_path = os.path.join( + bootstrap_log_dir, f"adaptor-runtime-background-bootstrap-{bootstrap_id}.log" + ) + args.extend(["--bootstrap-log-file", bootstrap_log_path]) + + _logger.debug(f"Running process with args: {args}") + bootstrap_output_path = os.path.join( + bootstrap_log_dir, f"adaptor-runtime-background-bootstrap-output-{bootstrap_id}.log" + ) + output_log_file = open(bootstrap_output_path, mode="w+") try: process = subprocess.Popen( args, @@ -126,8 +159,8 @@ def init( close_fds=True, start_new_session=True, stdin=subprocess.DEVNULL, - stdout=subprocess.DEVNULL, - stderr=subprocess.DEVNULL, + stdout=output_log_file, + stderr=output_log_file, ) except Exception as e: _logger.error(f"Failed to initialize backend process: {e}") @@ -136,19 +169,61 @@ def init( # Wait for backend process to create connection file try: - _wait_for_file(self._connection_file_path, timeout_s=5) + _wait_for_connection_file(str(connection_file_path), max_retries=5, interval_s=1) except TimeoutError: _logger.error( "Backend process failed to write connection file in time at: " - + self._connection_file_path + + str(connection_file_path) ) + + exit_code = process.poll() + if exit_code is not None: + _logger.info(f"Backend process exited with code: {exit_code}") + else: + _logger.info("Backend process is still running") + raise + finally: + # Close file handle to prevent further writes + # At this point, we have all the logs/output we need from the bootstrap + output_log_file.close() + if process.stdout: + process.stdout.close() + if process.stderr: + process.stderr.close() + + with open(bootstrap_output_path, mode="r") as f: + bootstrap_output = f.readlines() + _logger.info("========== BEGIN BOOTSTRAP OUTPUT CONTENTS ==========") + for line in bootstrap_output: + _logger.info(line.strip()) + _logger.info("========== END BOOTSTRAP OUTPUT CONTENTS ==========") + + _logger.info(f"Checking for bootstrap logs at '{bootstrap_log_path}'") + try: + with open(bootstrap_log_path, mode="r") as f: + bootstrap_logs = f.readlines() + except Exception as e: + _logger.error(f"Failed to get bootstrap logs at '{bootstrap_log_path}': {e}") + else: + _logger.info("========== BEGIN BOOTSTRAP LOG CONTENTS ==========") + for line in bootstrap_logs: + _logger.info(line.strip()) + _logger.info("========== END BOOTSTRAP LOG CONTENTS ==========") + + # Load up connection settings for the heartbeat requests + self.connection_settings = ConnectionSettingsFileLoader(connection_file_path).load() # Heartbeat to ensure backend process is listening for requests _logger.info("Verifying connection to backend...") self._heartbeat() _logger.info("Connected successfully") + # Output the socket path to the environment via OpenJD environments + _logger.info( + f"{_OPENJD_ENV_STDOUT_PREFIX}{_OPENJD_ADAPTOR_SOCKET_ENV}={self.connection_settings.socket}" + ) + def run(self, run_data: dict) -> None: """ Sends a run request to the backend @@ -248,6 +323,11 @@ def _send_request( params: dict | None = None, json_body: dict | None = None, ) -> http_client.HTTPResponse | Dict: + if not self.connection_settings: + raise ConnectionSettingsNotProvidedError( + "Connection settings are required to send requests, but none were provided" + ) + if OSName.is_windows(): # pragma: is-posix if params: # This is used for aligning to the Linux's behavior in order to reuse the code in handler. @@ -287,6 +367,11 @@ def _send_linux_request( params: dict | None = None, json_body: dict | None = None, ) -> http_client.HTTPResponse: # pragma: is-windows + if not self.connection_settings: + raise ConnectionSettingsNotProvidedError( + "Connection settings are required to send requests, but none were provided" + ) + conn = UnixHTTPConnection(self.connection_settings.socket, timeout=self._timeout_s) if params: @@ -311,15 +396,6 @@ def _send_linux_request( return response - @property - def connection_settings(self) -> ConnectionSettings: - """ - Gets the lazy-loaded connection settings. - """ - if not hasattr(self, "_connection_settings"): - self._connection_settings = _load_connection_settings(self._connection_file_path) - return self._connection_settings - def _sigint_handler(self, signum: int, frame: Optional[FrameType]) -> None: """ Signal handler for interrupt signals. @@ -338,51 +414,84 @@ def _sigint_handler(self, signum: int, frame: Optional[FrameType]) -> None: self.cancel() -def _load_connection_settings(path: str) -> ConnectionSettings: - try: - with open(path) as conn_file: - loaded_settings = json.load(conn_file) - except OSError as e: - _logger.error(f"Failed to open connection file: {e}") - raise - except json.JSONDecodeError as e: - _logger.error(f"Failed to decode connection file: {e}") - raise - return DataclassMapper(ConnectionSettings).map(loaded_settings) - - -def _wait_for_file(filepath: str, timeout_s: float, interval_s: float = 1) -> None: +def _wait_for_connection_file( + filepath: str, max_retries: int, interval_s: float = 1 +) -> ConnectionSettings: """ - Waits for a file at the specified path to exist and to be openable. + Waits for a connection file at the specified path to exist, be openable, and have connection settings. Args: filepath (str): The file path to check. - timeout_s (float): The max duration to wait before timing out, in seconds. + max_retries (int): The max number of retries before timing out. interval_s (float, optional): The interval between checks, in seconds. Default is 0.01s. Raises: TimeoutError: Raised when the file does not exist after timeout_s seconds. """ + wait_for( + description=f"File '{filepath}' to exist", + predicate=lambda: os.path.exists(filepath), + interval_s=interval_s, + max_retries=max_retries, + ) - def _wait(): - if time.time() - start < timeout_s: - time.sleep(interval_s) - else: - raise TimeoutError(f"Timed out after {timeout_s}s waiting for file at {filepath}") - - start = time.time() - while not os.path.exists(filepath): - _wait() + # Wait before opening to give the backend time to open it first + time.sleep(interval_s) - while True: - # Wait before opening to give the backend time to open it first - _wait() + def file_is_openable() -> bool: try: open(filepath, mode="r").close() - break except IOError: # File is not available yet - pass + return False + else: + return True + + wait_for( + description=f"File '{filepath}' to be openable", + predicate=file_is_openable, + interval_s=interval_s, + max_retries=max_retries, + ) + + def connection_file_loadable() -> bool: + try: + ConnectionSettingsFileLoader(Path(filepath)).load() + except Exception: + return False + else: + return True + + wait_for( + description=f"File '{filepath}' to have valid ConnectionSettings", + predicate=connection_file_loadable, + interval_s=interval_s, + max_retries=max_retries, + ) + + return ConnectionSettingsFileLoader(Path(filepath)).load() + + +def wait_for( + *, + description: str, + predicate: Callable[[], bool], + interval_s: float, + max_retries: int | None = None, +) -> None: + if max_retries is not None: + assert max_retries >= 0, "max_retries must be a non-negative integer" + assert interval_s > 0, "interval_s must be a positive number" + + _logger.info(f"Waiting for {description}") + retry_count = 0 + while not predicate(): + if max_retries is not None and retry_count >= max_retries: + raise TimeoutError(f"Timed out waiting for {description}") + + _logger.info(f"Retrying in {interval_s}s...") + retry_count += 1 + time.sleep(interval_s) class AdaptorFailedException(Exception): diff --git a/src/openjd/adaptor_runtime/_background/loaders.py b/src/openjd/adaptor_runtime/_background/loaders.py new file mode 100644 index 0000000..e258a29 --- /dev/null +++ b/src/openjd/adaptor_runtime/_background/loaders.py @@ -0,0 +1,69 @@ +# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. + +from __future__ import annotations + +import abc +import dataclasses +import logging +import json +import os +from pathlib import Path + +from .model import ( + ConnectionSettings, + DataclassMapper, +) + +_logger = logging.getLogger(__name__) + + +class ConnectionSettingsLoadingError(Exception): + """Raised when the connection settings cannot be loaded""" + + pass + + +class ConnectionSettingsLoader(abc.ABC): + @abc.abstractmethod + def load(self) -> ConnectionSettings: + pass + + +@dataclasses.dataclass +class ConnectionSettingsFileLoader(ConnectionSettingsLoader): + file_path: Path + + def load(self) -> ConnectionSettings: + try: + with open(self.file_path) as conn_file: + loaded_settings = json.load(conn_file) + except OSError as e: + errmsg = f"Failed to open connection file '{self.file_path}': {e}" + _logger.error(errmsg) + raise ConnectionSettingsLoadingError(errmsg) from e + except json.JSONDecodeError as e: + errmsg = f"Failed to decode connection file '{self.file_path}': {e}" + _logger.error(errmsg) + raise ConnectionSettingsLoadingError(errmsg) from e + return DataclassMapper(ConnectionSettings).map(loaded_settings) + + +@dataclasses.dataclass +class ConnectionSettingsEnvLoader(ConnectionSettingsLoader): + env_map: dict[str, tuple[str, bool]] = dataclasses.field( + default_factory=lambda: {"socket": ("OPENJD_ADAPTOR_SOCKET", True)} + ) + """Mapping of environment variable to a tuple of ConnectionSettings attribute name, and whether it is required""" + + def load(self) -> ConnectionSettings: + kwargs = {} + for attr_name, (env_name, required) in self.env_map.items(): + env_val = os.environ.get(env_name) + if not env_val: + if required: + raise ConnectionSettingsLoadingError( + f"Required attribute '{attr_name}' does not have its corresponding environment variable '{env_name}' set" + ) + else: + kwargs[attr_name] = env_val + return ConnectionSettings(**kwargs) diff --git a/src/openjd/adaptor_runtime/_background/server_response.py b/src/openjd/adaptor_runtime/_background/server_response.py index 9897b2f..31d8b2f 100644 --- a/src/openjd/adaptor_runtime/_background/server_response.py +++ b/src/openjd/adaptor_runtime/_background/server_response.py @@ -17,8 +17,8 @@ from .http_server import BackgroundHTTPServer -from ..adaptors._adaptor_runner import _OPENJD_FAIL_STDOUT_PREFIX from .._http import HTTPResponse +from .._utils._constants import _OPENJD_FAIL_STDOUT_PREFIX from .model import ( AdaptorState, AdaptorStatus, diff --git a/src/openjd/adaptor_runtime/_entrypoint.py b/src/openjd/adaptor_runtime/_entrypoint.py index 6c4ab8c..e7a112d 100644 --- a/src/openjd/adaptor_runtime/_entrypoint.py +++ b/src/openjd/adaptor_runtime/_entrypoint.py @@ -2,28 +2,46 @@ from __future__ import annotations +import contextlib import logging import os import signal import sys +import tempfile from pathlib import Path from argparse import ArgumentParser, Namespace from types import FrameType as FrameType -from typing import TYPE_CHECKING, Any, Optional, Type, TypeVar, NamedTuple, Tuple +from typing import ( + TYPE_CHECKING, + Any, + cast, + Callable, + List, + Optional, + MutableSet, + Type, + TypeVar, + NamedTuple, + Tuple, +) import jsonschema import yaml from .adaptors import AdaptorRunner, BaseAdaptor from ._background import BackendRunner, FrontendRunner, InMemoryLogBuffer, LogBufferHandler +from ._background.loaders import ( + ConnectionSettingsFileLoader, + ConnectionSettingsEnvLoader, +) from .adaptors.configuration import ( RuntimeConfiguration, ConfigurationManager, ) from ._osname import OSName +from ._utils._constants import _OPENJD_ADAPTOR_SOCKET_ENV, _OPENJD_LOG_REGEX from ._utils._logging import ( - _OPENJD_LOG_REGEX, ConditionalFormatter, ) from .adaptors import SemanticVersion @@ -53,7 +71,13 @@ "file://path/to/file.json" ), "show_config": "Prints the adaptor runtime configuration, then the program exits.", - "connection_file": "The file path to the connection file for use in background mode.", + "connection_file": ( + "The file path to the connection file for use in daemon mode. For the 'daemon start' command, this file " + "must not exist. For all other commands, this file must exist. This option is highly " + "recommended if using the adaptor in an interactive terminal. Default is to read the " + f"connection data from the environment variable: {_OPENJD_ADAPTOR_SOCKET_ENV}" + ), + "log_file": "The file to log adaptor output to. Default is to not log to a file.", } _DIR = os.path.dirname(os.path.realpath(__file__)) @@ -64,7 +88,6 @@ os.path.join( _system_config_path_prefix, "openjd", - "worker", "adaptors", "runtime", "configuration.json", @@ -75,14 +98,27 @@ "schema_path": os.path.abspath(os.path.join(_DIR, "configuration.schema.json")), "default_config_path": os.path.abspath(os.path.join(_DIR, "configuration.json")), "system_config_path": _system_config_path, - "user_config_rel_path": os.path.join( - ".openjd", "worker", "adaptors", "runtime", "configuration.json" - ), + "user_config_rel_path": os.path.join(".openjd", "adaptors", "runtime", "configuration.json"), } _logger = logging.getLogger(__name__) +class _ParsedArgs(Namespace): + command: str + + # common args + init_data: str + run_data: str + path_mapping_rules: str + connection_file: str | None + bootstrap_log_file: str | None + + # is-compatible args + openjd_adaptor_cli_version: str | None + integration_data_interface_version: str | None + + class _LogConfig(NamedTuple): formatter: ConditionalFormatter stream_handler: logging.StreamHandler @@ -123,13 +159,21 @@ class EntryPoint: The main entry point of the adaptor runtime. """ + on_bootstrap_complete: MutableSet[Callable[[], None]] + """ + Set of callbacks that are called when daemon mode bootstrapping is complete. + These callbacks are never called when not running in daemon mode. + """ + def __init__(self, adaptor_class: Type[_U]) -> None: self.adaptor_class = adaptor_class # This will be the current AdaptorRunner when using the 'run' command, rather than # 'background' command self._adaptor_runner: Optional[AdaptorRunner] = None - def _init_loggers(self) -> _LogConfig: + self.on_bootstrap_complete = set() + + def _init_loggers(self, *, bootstrap_log_path: str | None = None) -> _LogConfig: "Creates runtime/adaptor loggers" formatter = ConditionalFormatter( "%(levelname)s: %(message)s", ignore_patterns=[_OPENJD_LOG_REGEX] @@ -144,6 +188,22 @@ def _init_loggers(self) -> _LogConfig: adaptor_logger = logging.getLogger(self.adaptor_class.__module__.split(".")[0]) adaptor_logger.addHandler(stream_handler) + if bootstrap_log_path: + file_formatter = logging.Formatter("[%(asctime)s][%(levelname)-8s] %(message)s") + file_handler = logging.FileHandler(bootstrap_log_path) + file_handler.setFormatter(file_formatter) + file_handler.setLevel(0) + runtime_logger.addHandler(file_handler) + adaptor_logger.addHandler(file_handler) + + def disconnect_bootstrap_logging() -> None: + # Remove file logger after bootstrap is complete + runtime_logger.removeHandler(file_handler) + adaptor_logger.removeHandler(file_handler) + self.on_bootstrap_complete.remove(disconnect_bootstrap_logging) + + self.on_bootstrap_complete.add(disconnect_bootstrap_logging) + return _LogConfig(formatter, stream_handler, runtime_logger, adaptor_logger) def _init_config(self) -> None: @@ -193,19 +253,28 @@ def start(self, reentry_exe: Optional[Path] = None) -> None: Args: reentry_exe (Path): The path to the binary executable that for adaptor reentry. """ - log_config = self._init_loggers() parser, parsed_args = self._parse_args() - version_info = self._get_version_info() + log_config = self._init_loggers( + bootstrap_log_path=( + parsed_args.bootstrap_log_file + if hasattr(parsed_args, "bootstrap_log_file") + else None + ) + ) + + interface_version_info = self._get_version_info() if parsed_args.command == "is-compatible": - return self._handle_is_compatible(version_info, parsed_args, parser) + return self._handle_is_compatible(interface_version_info, parsed_args, parser) elif parsed_args.command == "version-info": return print( yaml.dump( { - "OpenJD Adaptor CLI Version": str(version_info.adaptor_cli_version), + "OpenJD Adaptor CLI Version": str( + interface_version_info.adaptor_cli_version + ), f"{self.adaptor_class.__name__} Data Interface Version": str( - version_info.integration_data_interface_version + interface_version_info.integration_data_interface_version ), }, indent=2, @@ -295,17 +364,26 @@ def _handle_run( def _handle_daemon( self, adaptor: BaseAdaptor[AdaptorConfiguration], - parsed_args: Namespace, + parsed_args: _ParsedArgs, log_config: _LogConfig, integration_data: _IntegrationData, reentry_exe: Optional[Path] = None, ): - connection_file = parsed_args.connection_file - if not os.path.isabs(connection_file): - connection_file = os.path.abspath(connection_file) + # Validate args subcommand = parsed_args.subcommand if hasattr(parsed_args, "subcommand") else None + connection_file: Path | None = None + if hasattr(parsed_args, "connection_file") and parsed_args.connection_file: + connection_file = Path(parsed_args.connection_file) + if connection_file and not connection_file.is_absolute(): + connection_file = connection_file.absolute() + if subcommand == "_serve": + if not connection_file: + raise RuntimeError( + "--connection file is required for the '_serve' command but was not provided." + ) + # Replace stream handler with log buffer handler since output will be buffered in # background mode log_buffer = InMemoryLogBuffer(formatter=log_config.formatter) @@ -318,41 +396,61 @@ def _handle_daemon( # forever until a shutdown is requested backend = BackendRunner( AdaptorRunner(adaptor=adaptor), - connection_file, + connection_file_path=connection_file, log_buffer=log_buffer, ) - backend.run() + backend.run( + on_connection_file_written=cast( + List[Callable[[], None]], self.on_bootstrap_complete + ) + ) else: # This process is running in frontend mode. Create the frontend runner and send # the appropriate request to the backend. - frontend = FrontendRunner(connection_file) if subcommand == "start": + frontend = FrontendRunner() adaptor_module = sys.modules.get(self.adaptor_class.__module__) if adaptor_module is None: raise ModuleNotFoundError( f"Adaptor module is not loaded: {self.adaptor_class.__module__}" ) - frontend.init( - adaptor_module, - integration_data.init_data, - integration_data.path_mapping_data, - reentry_exe, - ) + with contextlib.ExitStack() as stack: + if not connection_file: + tmpdir = stack.enter_context(tempfile.TemporaryDirectory(prefix="ojd-ar-")) + connection_file = Path(tmpdir) / "connection.json" + + frontend.init( + adaptor_module=adaptor_module, + connection_file_path=connection_file, + init_data=integration_data.init_data, + path_mapping_data=integration_data.path_mapping_data, + reentry_exe=reentry_exe, + ) frontend.start() - elif subcommand == "run": - frontend.run(integration_data.run_data) - elif subcommand == "stop": - frontend.stop() - frontend.shutdown() - - def _parse_args(self) -> Tuple[ArgumentParser, Namespace]: + else: + conn_settings_loader = ( + ConnectionSettingsFileLoader(connection_file) + if connection_file + else ConnectionSettingsEnvLoader() + ) + conn_settings = conn_settings_loader.load() + frontend = FrontendRunner(connection_settings=conn_settings) + if subcommand == "run": + frontend.run(integration_data.run_data) + elif subcommand == "stop": + frontend.stop() + frontend.shutdown() + + def _parse_args(self) -> Tuple[ArgumentParser, _ParsedArgs]: parser = self._build_argparser() try: - return parser, parser.parse_args(sys.argv[1:]) + parsed_args = parser.parse_args(sys.argv[1:], _ParsedArgs()) except Exception as e: _logger.error(f"Error parsing command line arguments: {e}") raise + else: + return parser, parsed_args def _build_argparser(self) -> ArgumentParser: parser = ArgumentParser( @@ -412,9 +510,15 @@ def _build_argparser(self) -> ArgumentParser: connection_file = ArgumentParser(add_help=False) connection_file.add_argument( "--connection-file", - default="", help=_CLI_HELP_TEXT["connection_file"], - required=True, + required=False, + ) + + log_file = ArgumentParser(add_help=False) + log_file.add_argument( + "--bootstrap-log-file", + help=_CLI_HELP_TEXT["log_file"], + required=False, ) bg_parser = subparser.add_parser("daemon", help="Runs the adaptor in a daemon mode.") @@ -427,8 +531,10 @@ def _build_argparser(self) -> ArgumentParser: ) # "Hidden" command that actually runs the adaptor runtime in background mode - bg_subparser.add_parser("_serve", parents=[init_data, path_mapping_rules, connection_file]) - + bg_subparser.add_parser( + "_serve", + parents=[init_data, path_mapping_rules, connection_file, log_file], + ) bg_subparser.add_parser("start", parents=[init_data, path_mapping_rules, connection_file]) bg_subparser.add_parser("run", parents=[run_data, connection_file]) bg_subparser.add_parser("stop", parents=[connection_file]) diff --git a/src/openjd/adaptor_runtime/_http/__init__.py b/src/openjd/adaptor_runtime/_http/__init__.py index 834c7e8..70b2f0d 100644 --- a/src/openjd/adaptor_runtime/_http/__init__.py +++ b/src/openjd/adaptor_runtime/_http/__init__.py @@ -1,6 +1,6 @@ # Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. from .request_handler import HTTPResponse, RequestHandler, ResourceRequestHandler -from .sockets import SocketDirectories +from .sockets import SocketPaths -__all__ = ["HTTPResponse", "RequestHandler", "ResourceRequestHandler", "SocketDirectories"] +__all__ = ["HTTPResponse", "RequestHandler", "ResourceRequestHandler", "SocketPaths"] diff --git a/src/openjd/adaptor_runtime/_http/sockets.py b/src/openjd/adaptor_runtime/_http/sockets.py index 58e3e34..c4c9011 100644 --- a/src/openjd/adaptor_runtime/_http/sockets.py +++ b/src/openjd/adaptor_runtime/_http/sockets.py @@ -16,37 +16,44 @@ # Max PID on 64-bit systems is 4194304 (2^22) _PID_MAX_LENGTH = 7 -_PID_MAX_LENGTH_PADDED = _PID_MAX_LENGTH + 1 # 1 char for path seperator -class SocketDirectories(abc.ABC): +class SocketPaths(abc.ABC): """ - Base class for determining the base directory for sockets used in the Adaptor Runtime. + Base class for determining the paths for sockets used in the Adaptor Runtime. """ @staticmethod def for_os(osname: OSName = OSName()): # pragma: no cover - """_summary_ + """ + Gets the SocketPaths class for a specific OS. Args: - osname (OSName, optional): The OS to get socket directories for. + osname (OSName, optional): The OS to get socket paths for. Defaults to the current OS. Raises: UnsupportedPlatformException: Raised when this class is requested for an unsupported platform. """ - klass = _get_socket_directories_cls(osname) + klass = _get_socket_paths_cls(osname) if not klass: raise UnsupportedPlatformException(osname) return klass() - def get_process_socket_path(self, namespace: str | None = None, *, create_dir: bool = False): + def get_process_socket_path( + self, + namespace: str | None = None, + *, + base_dir: str | None = None, + create_dir: bool = False, + ): """ Gets the path for this process' socket in the given namespace. Args: namespace (Optional[str]): The optional namespace (subdirectory) where the sockets go. + base_dir (Optional[str]): The base directory to create sockets in. Defaults to the temp directory. create_dir (bool): Whether to create the socket directory. Default is false. Raises: @@ -60,15 +67,29 @@ def get_process_socket_path(self, namespace: str | None = None, *, create_dir: b len(socket_name) <= _PID_MAX_LENGTH ), f"PID too long. Only PIDs up to {_PID_MAX_LENGTH} digits are supported." - return os.path.join(self.get_socket_dir(namespace, create=create_dir), socket_name) + return self.get_socket_path( + socket_name, + namespace, + base_dir=base_dir, + create_dir=create_dir, + ) - def get_socket_dir(self, namespace: str | None = None, *, create: bool = False) -> str: + def get_socket_path( + self, + base_socket_name: str, + namespace: str | None = None, + *, + base_dir: str | None = None, + create_dir: bool = False, + ) -> str: """ - Gets the base directory for sockets used in Adaptor IPC + Gets the path for a socket used in Adaptor IPC Args: + base_socket_name (str): The name of the socket namespace (Optional[str]): The optional namespace (subdirectory) where the sockets go - create (bool): Whether to create the directory or not. Default is false. + base_dir (Optional[str]): The base directory to create sockets in. Defaults to the temp directory. + create_dir (bool): Whether to create the directory or not. Default is false. Raises: NonvalidSocketPathException: Raised if the user has configured a socket base directory @@ -77,48 +98,38 @@ def get_socket_dir(self, namespace: str | None = None, *, create: bool = False) not be raised if the user has configured a socket base directory. """ - def create_dir(path: str) -> str: - if create: + def mkdir(path: str) -> str: + if create_dir: os.makedirs(path, mode=0o700, exist_ok=True) return path - rel_path = os.path.join(".openjd", "adaptors", "sockets") - if namespace: - rel_path = os.path.join(rel_path, namespace) - - reasons: list[str] = [] + def gen_socket_path(dir: str, base_name: str): + name = base_name + i = 0 + while os.path.exists(os.path.join(dir, name)): + i += 1 + name = f"{base_name}_{i}" + return os.path.join(dir, name) - # First try home directory - home_dir = os.path.expanduser("~") - socket_dir = os.path.join(home_dir, rel_path) - try: - self.verify_socket_path(socket_dir) - except NonvalidSocketPathException as e: - reasons.append(f"Cannot create sockets directory in the home directory because: {e}") + if not base_dir: + socket_dir = os.path.realpath(tempfile.gettempdir()) else: - return create_dir(socket_dir) + socket_dir = os.path.realpath(base_dir) - # Last resort is the temp directory - temp_dir = tempfile.gettempdir() - socket_dir = os.path.join(temp_dir, rel_path) + if namespace: + socket_dir = os.path.join(socket_dir, namespace) + + mkdir(socket_dir) + + socket_path = gen_socket_path(socket_dir, base_socket_name) try: - self.verify_socket_path(socket_dir) + self.verify_socket_path(socket_path) except NonvalidSocketPathException as e: - reasons.append(f"Cannot create sockets directory in the temp directory because: {e}") - else: - # Also check that the sticky bit is set on the temp dir - if not os.stat(temp_dir).st_mode & stat.S_ISVTX: - reasons.append( - f"Cannot use temporary directory {temp_dir} because it does not have the " - "sticky bit (restricted deletion flag) set" - ) - else: - return create_dir(socket_dir) + raise NoSocketPathFoundException( + f"Socket path '{socket_path}' failed verification: {e}" + ) from e - raise NoSocketPathFoundException( - "Failed to find a suitable base directory to create sockets in for the following " - f"reasons: {os.linesep.join(reasons)}" - ) + return socket_path @abc.abstractmethod def verify_socket_path(self, path: str) -> None: # pragma: no cover @@ -132,53 +143,86 @@ def verify_socket_path(self, path: str) -> None: # pragma: no cover pass -class LinuxSocketDirectories(SocketDirectories): +class WindowsSocketPaths(SocketPaths): + """ + Specialization for verifying socket paths on Windows systems. + """ + + def verify_socket_path(self, path: str) -> None: + # TODO: Verify Windows permissions of parent directories are least privileged + pass + + +class UnixSocketPaths(SocketPaths): + """ + Specialization for verifying socket paths on Unix systems. + """ + + def verify_socket_path(self, path: str) -> None: + # Walk up directories and check that the sticky bit is set if the dir is world writable + prev_path = path + curr_path = os.path.dirname(path) + while prev_path != curr_path and len(curr_path) > 0: + path_stat = os.stat(curr_path) + if path_stat.st_mode & stat.S_IWOTH and not path_stat.st_mode & stat.S_ISVTX: + raise NoSocketPathFoundException( + f"Cannot use directory {curr_path} because it is world writable and does not " + "have the sticky bit (restricted deletion flag) set" + ) + prev_path = curr_path + curr_path = os.path.dirname(curr_path) + + +class LinuxSocketPaths(UnixSocketPaths): """ Specialization for socket paths in Linux systems. """ # This is based on the max length of socket names to 108 bytes # See unix(7) under "Address format" - _socket_path_max_length = 108 - _socket_dir_max_length = _socket_path_max_length - _PID_MAX_LENGTH_PADDED + # In practice, only 107 bytes are accepted (one byte for null terminator) + _socket_name_max_length = 108 - 1 def verify_socket_path(self, path: str) -> None: + super().verify_socket_path(path) path_length = len(path.encode("utf-8")) - if path_length > self._socket_dir_max_length: + if path_length > self._socket_name_max_length: raise NonvalidSocketPathException( - "Socket base directory path too big. The maximum allowed size is " - f"{self._socket_dir_max_length} bytes, but the directory has a size of " + "Socket name too long. The maximum allowed size is " + f"{self._socket_name_max_length} bytes, but the name has a size of " f"{path_length}: {path}" ) -class MacOSSocketDirectories(SocketDirectories): +class MacOSSocketPaths(UnixSocketPaths): """ Specialization for socket paths in macOS systems. """ # This is based on the max length of socket names to 104 bytes # See https://github.com/apple-oss-distributions/xnu/blob/1031c584a5e37aff177559b9f69dbd3c8c3fd30a/bsd/sys/un.h#L79 - _socket_path_max_length = 104 - _socket_dir_max_length = _socket_path_max_length - _PID_MAX_LENGTH_PADDED + # In practice, only 103 bytes are accepted (one byte for null terminator) + _socket_name_max_length = 104 - 1 def verify_socket_path(self, path: str) -> None: + super().verify_socket_path(path) path_length = len(path.encode("utf-8")) - if path_length > self._socket_dir_max_length: + if path_length > self._socket_name_max_length: raise NonvalidSocketPathException( - "Socket base directory path too big. The maximum allowed size is " - f"{self._socket_dir_max_length} bytes, but the directory has a size of " + "Socket name too long. The maximum allowed size is " + f"{self._socket_name_max_length} bytes, but the name has a size of " f"{path_length}: {path}" ) -_os_map: dict[str, type[SocketDirectories]] = { - OSName.LINUX: LinuxSocketDirectories, - OSName.MACOS: MacOSSocketDirectories, +_os_map: dict[str, type[SocketPaths]] = { + OSName.LINUX: LinuxSocketPaths, + OSName.MACOS: MacOSSocketPaths, + OSName.WINDOWS: WindowsSocketPaths, } -def _get_socket_directories_cls( +def _get_socket_paths_cls( osname: OSName, -) -> type[SocketDirectories] | None: # pragma: no cover +) -> type[SocketPaths] | None: # pragma: no cover return _os_map.get(osname, None) diff --git a/src/openjd/adaptor_runtime/_utils/_constants.py b/src/openjd/adaptor_runtime/_utils/_constants.py new file mode 100644 index 0000000..b7c5bc1 --- /dev/null +++ b/src/openjd/adaptor_runtime/_utils/_constants.py @@ -0,0 +1,13 @@ +# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. + +import re + +_OPENJD_LOG_PATTERN = r"^openjd_\S+: " +_OPENJD_LOG_REGEX = re.compile(_OPENJD_LOG_PATTERN) + +_OPENJD_FAIL_STDOUT_PREFIX = "openjd_fail: " +_OPENJD_PROGRESS_STDOUT_PREFIX = "openjd_progress: " +_OPENJD_STATUS_STDOUT_PREFIX = "openjd_status: " +_OPENJD_ENV_STDOUT_PREFIX = "openjd_env: " + +_OPENJD_ADAPTOR_SOCKET_ENV = "OPENJD_ADAPTOR_SOCKET" diff --git a/src/openjd/adaptor_runtime/_utils/_logging.py b/src/openjd/adaptor_runtime/_utils/_logging.py index 8c4d6e3..2a10f01 100644 --- a/src/openjd/adaptor_runtime/_utils/_logging.py +++ b/src/openjd/adaptor_runtime/_utils/_logging.py @@ -6,9 +6,6 @@ Optional, ) -_OPENJD_LOG_PATTERN = r"^openjd_\S+: " -_OPENJD_LOG_REGEX = re.compile(_OPENJD_LOG_PATTERN) - class ConditionalFormatter(logging.Formatter): """ diff --git a/src/openjd/adaptor_runtime/_utils/_secure_open.py b/src/openjd/adaptor_runtime/_utils/_secure_open.py index a44d246..47269de 100644 --- a/src/openjd/adaptor_runtime/_utils/_secure_open.py +++ b/src/openjd/adaptor_runtime/_utils/_secure_open.py @@ -81,7 +81,7 @@ def get_file_owner_in_windows(filepath: "StrOrBytesPath") -> str: # pragma: is- Returns: str: A string in the format 'DOMAIN\\Username' representing the file's owner. """ - sd = win32security.GetFileSecurity(filepath, win32security.OWNER_SECURITY_INFORMATION) + sd = win32security.GetFileSecurity(str(filepath), win32security.OWNER_SECURITY_INFORMATION) owner_sid = sd.GetSecurityDescriptorOwner() name, domain, _ = win32security.LookupAccountSid(None, owner_sid) return f"{domain}\\{name}" @@ -108,13 +108,13 @@ def set_file_permissions_in_windows(filepath: "StrOrBytesPath") -> None: # prag dacl.AddAccessAllowedAce(win32security.ACL_REVISION, win32con.DELETE, user_sid) # Apply the DACL to the file - sd = win32security.GetFileSecurity(filepath, win32security.DACL_SECURITY_INFORMATION) + sd = win32security.GetFileSecurity(str(filepath), win32security.DACL_SECURITY_INFORMATION) sd.SetSecurityDescriptorDacl( 1, # A flag that indicates the presence of a DACL in the security descriptor. dacl, # An ACL structure that specifies the DACL for the security descriptor. 0, # Don't retrieve the default DACL ) - win32security.SetFileSecurity(filepath, win32security.DACL_SECURITY_INFORMATION, sd) + win32security.SetFileSecurity(str(filepath), win32security.DACL_SECURITY_INFORMATION, sd) def _get_flags_from_mode_str(open_mode: str) -> int: diff --git a/src/openjd/adaptor_runtime/adaptors/_adaptor_runner.py b/src/openjd/adaptor_runtime/adaptors/_adaptor_runner.py index 3d3b1e8..4217364 100644 --- a/src/openjd/adaptor_runtime/adaptors/_adaptor_runner.py +++ b/src/openjd/adaptor_runtime/adaptors/_adaptor_runner.py @@ -6,13 +6,12 @@ from ._adaptor_states import AdaptorState, AdaptorStates from ._base_adaptor import BaseAdaptor as BaseAdaptor +from .._utils._constants import _OPENJD_FAIL_STDOUT_PREFIX __all__ = ["AdaptorRunner"] _logger = logging.getLogger(__name__) -_OPENJD_FAIL_STDOUT_PREFIX: str = "openjd_fail: " - class AdaptorRunner(AdaptorStates): """ diff --git a/src/openjd/adaptor_runtime/adaptors/_base_adaptor.py b/src/openjd/adaptor_runtime/adaptors/_base_adaptor.py index 55ce364..70912b3 100644 --- a/src/openjd/adaptor_runtime/adaptors/_base_adaptor.py +++ b/src/openjd/adaptor_runtime/adaptors/_base_adaptor.py @@ -13,6 +13,7 @@ from typing import Type from typing import TypeVar +from .._utils._constants import _OPENJD_PROGRESS_STDOUT_PREFIX, _OPENJD_STATUS_STDOUT_PREFIX from .configuration import AdaptorConfiguration, ConfigurationManager from .configuration._configuration_manager import ( create_adaptor_configuration_manager as create_adaptor_configuration_manager, @@ -54,8 +55,8 @@ class BaseAdaptor(AdaptorStates, Generic[_T]): Base class for adaptors. """ - _OPENJD_PROGRESS_STDOUT_PREFIX: str = "openjd_progress: " - _OPENJD_STATUS_STDOUT_PREFIX: str = "openjd_status: " + _OPENJD_PROGRESS_STDOUT_PREFIX: str = _OPENJD_PROGRESS_STDOUT_PREFIX + _OPENJD_STATUS_STDOUT_PREFIX: str = _OPENJD_STATUS_STDOUT_PREFIX def __init__( self, diff --git a/src/openjd/adaptor_runtime/application_ipc/_adaptor_server.py b/src/openjd/adaptor_runtime/application_ipc/_adaptor_server.py index 70a88af..d914ad9 100644 --- a/src/openjd/adaptor_runtime/application_ipc/_adaptor_server.py +++ b/src/openjd/adaptor_runtime/application_ipc/_adaptor_server.py @@ -8,7 +8,7 @@ from socketserver import UnixStreamServer # type: ignore[attr-defined] from typing import TYPE_CHECKING -from .._http import SocketDirectories +from .._http import SocketPaths from ._http_request_handler import AdaptorHTTPRequestHandler if TYPE_CHECKING: # pragma: no cover because pytest will think we should test for this. @@ -33,7 +33,10 @@ def __init__( actions_queue: ActionsQueue, adaptor: BaseAdaptor, ) -> None: # pragma: no cover - socket_path = SocketDirectories.for_os().get_process_socket_path("dcc", create_dir=True) + socket_path = SocketPaths.for_os().get_process_socket_path( + ".openjd_adaptor_server", + create_dir=True, + ) super().__init__(socket_path, AdaptorHTTPRequestHandler) self.actions_queue = actions_queue diff --git a/src/openjd/adaptor_runtime/application_ipc/_win_adaptor_server.py b/src/openjd/adaptor_runtime/application_ipc/_win_adaptor_server.py index d269260..0471d91 100644 --- a/src/openjd/adaptor_runtime/application_ipc/_win_adaptor_server.py +++ b/src/openjd/adaptor_runtime/application_ipc/_win_adaptor_server.py @@ -26,7 +26,11 @@ class WinAdaptorServer(NamedPipeServer): actions_queue: ActionsQueue adaptor: BaseAdaptor - def __init__(self, actions_queue: ActionsQueue, adaptor: BaseAdaptor) -> None: + def __init__( + self, + actions_queue: ActionsQueue, + adaptor: BaseAdaptor, + ) -> None: """ Adaptor Server class in Windows. diff --git a/test/openjd/adaptor_runtime/conftest.py b/test/openjd/adaptor_runtime/conftest.py index 83e282a..642d94a 100644 --- a/test/openjd/adaptor_runtime/conftest.py +++ b/test/openjd/adaptor_runtime/conftest.py @@ -1,12 +1,15 @@ # Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. import platform +import random +import string from typing import Generator +from unittest.mock import MagicMock, patch -from openjd.adaptor_runtime._osname import OSName -import string -import random import pytest +from openjd.adaptor_runtime._osname import OSName +from openjd.adaptor_runtime._http import sockets + if OSName.is_windows(): import win32net import win32netcon @@ -92,3 +95,13 @@ def delete_user() -> None: yield username, password # Delete the user after test completes delete_user() + + +@pytest.fixture(scope="session", autouse=OSName().is_macos()) +def mock_sockets_py_tempfile_gettempdir_to_slash_tmp() -> Generator[MagicMock, None, None]: + """ + Mock that is automatically used on Mac to override the tempfile.gettempdir() usages in sockets.py + because the folder returned by it on Mac is too long for the socket name to fit (max 104 bytes, see sockets.py) + """ + with patch.object(sockets.tempfile, "gettempdir", return_value="/tmp") as m: + yield m diff --git a/test/openjd/adaptor_runtime/integ/AdaptorExample/adaptor.py b/test/openjd/adaptor_runtime/integ/AdaptorExample/adaptor.py index 42a6a43..afbf414 100644 --- a/test/openjd/adaptor_runtime/integ/AdaptorExample/adaptor.py +++ b/test/openjd/adaptor_runtime/integ/AdaptorExample/adaptor.py @@ -45,6 +45,7 @@ def on_start(self) -> None: # This example initializes a server thread to interact with a client application, showing command exchange and # execution. _logger.info("on_start") + # Initialize the server thread to manage actions self.server = AdaptorServer( # actions_queue will be used for storing the actions. In the client application, it will keep polling the diff --git a/test/openjd/adaptor_runtime/integ/background/test_background_mode.py b/test/openjd/adaptor_runtime/integ/background/test_background_mode.py index f62ff90..974e736 100644 --- a/test/openjd/adaptor_runtime/integ/background/test_background_mode.py +++ b/test/openjd/adaptor_runtime/integ/background/test_background_mode.py @@ -20,7 +20,10 @@ from openjd.adaptor_runtime._background.frontend_runner import ( FrontendRunner, HTTPError, - _load_connection_settings, +) +from openjd.adaptor_runtime._background.loaders import ( + ConnectionSettingsLoadingError, + ConnectionSettingsFileLoader, ) from openjd.adaptor_runtime._osname import OSName @@ -51,7 +54,7 @@ def mock_runtime_logger_level(self, tmpdir: pathlib.Path): yield @pytest.fixture - def connection_file_path(self, tmp_path: pathlib.Path) -> str: + def connection_file_path(self, tmp_path: pathlib.Path) -> pathlib.Path: connection_dir = os.path.join(tmp_path.absolute(), "connection_dir") os.mkdir(connection_dir) if OSName.is_windows(): @@ -63,18 +66,20 @@ def connection_file_path(self, tmp_path: pathlib.Path) -> str: from openjd.adaptor_runtime._utils._secure_open import set_file_permissions_in_windows set_file_permissions_in_windows(connection_dir) - return os.path.join(connection_dir, "connection.json") + return pathlib.Path(connection_dir) / "connection.json" @pytest.fixture def initialized_setup( self, - connection_file_path: str, + connection_file_path: pathlib.Path, caplog: pytest.LogCaptureFixture, ) -> Generator[tuple[FrontendRunner, psutil.Process], None, None]: caplog.set_level(0) - frontend = FrontendRunner(connection_file_path, timeout_s=5.0) - frontend.init(sys.modules[AdaptorExample.__module__]) - conn_settings = _load_connection_settings(connection_file_path) + frontend = FrontendRunner(timeout_s=5.0) + frontend.init( + adaptor_module=sys.modules[AdaptorExample.__module__], + connection_file_path=connection_file_path, + ) match = re.search("Started backend process. PID: ([0-9]+)", caplog.text) assert match is not None @@ -93,14 +98,21 @@ def initialized_setup( # Once all handles are closed, the system automatically cleans up the named pipe. if OSName.is_posix(): try: - os.remove(conn_settings.socket) - except FileNotFoundError: - pass # Already deleted + conn_settings = ConnectionSettingsFileLoader(connection_file_path).load() + except ConnectionSettingsLoadingError as e: + print( + f"Failed to load connection settings, socket file cleanup will be skipped: {e}" + ) + else: + try: + os.remove(conn_settings.socket) + except FileNotFoundError: + pass # Already deleted def test_init( self, initialized_setup: tuple[FrontendRunner, psutil.Process], - connection_file_path: str, + connection_file_path: pathlib.Path, ) -> None: # GIVEN _, backend_proc = initialized_setup @@ -108,7 +120,7 @@ def test_init( # THEN assert os.path.exists(connection_file_path) - connection_settings = _load_connection_settings(connection_file_path) + connection_settings = ConnectionSettingsFileLoader(connection_file_path).load() if OSName.is_windows(): import pywintypes @@ -140,11 +152,11 @@ def test_init( def test_shutdown( self, initialized_setup: tuple[FrontendRunner, psutil.Process], - connection_file_path: str, + connection_file_path: pathlib.Path, ) -> None: # GIVEN frontend, backend_proc = initialized_setup - conn_settings = _load_connection_settings(connection_file_path) + conn_settings = ConnectionSettingsFileLoader(connection_file_path).load() # WHEN frontend.shutdown() @@ -153,7 +165,7 @@ def test_shutdown( assert all( [ _wait_for_file_deletion(p, timeout_s=1) - for p in [connection_file_path, conn_settings.socket] + for p in [str(connection_file_path), conn_settings.socket] ] ) diff --git a/test/openjd/adaptor_runtime/integ/test_integration_entrypoint.py b/test/openjd/adaptor_runtime/integ/test_integration_entrypoint.py index ca2eb05..d1690ff 100644 --- a/test/openjd/adaptor_runtime/integ/test_integration_entrypoint.py +++ b/test/openjd/adaptor_runtime/integ/test_integration_entrypoint.py @@ -3,6 +3,7 @@ from __future__ import annotations import json +import pathlib import os import re import sys @@ -32,7 +33,7 @@ class TestCommandAdaptorRun: """ def test_runs_command_adaptor( - self, capfd: pytest.CaptureFixture, caplog: pytest.LogCaptureFixture + self, capfd: pytest.CaptureFixture, caplog: pytest.LogCaptureFixture, tmp_path: pathlib.Path ): # GIVEN caplog.set_level(INFO) @@ -75,63 +76,170 @@ class TestCommandAdaptorDaemon: Tests for the CommandAdaptor running using the `daemon` command-line. """ - def test_start_stop(self, caplog: pytest.LogCaptureFixture, tmp_path: Path): - # GIVEN - caplog.set_level(INFO) - connection_file = tmp_path / "connection.json" - test_start_argv = [ - "program_filename.py", - "daemon", - "start", - "--connection-file", - str(connection_file), - "--init-data", - json.dumps( - { - "on_prerun": "on_prerun", - "on_postrun": "on_postrun", - } - ), - ] - test_stop_argv = [ - "program_filename.py", - "daemon", - "stop", - "--connection-file", - str(connection_file), - ] - entrypoint = EntryPoint(CommandAdaptorExample) + class TestUsingConnectionFile: + """ + Daemon tests using the --connection-file option + """ - # WHEN - with ( - patch.object(runtime_entrypoint.sys, "argv", test_start_argv), - patch.object(runtime_entrypoint.logging.Logger, "setLevel"), - ): - entrypoint.start() - with ( - patch.object(runtime_entrypoint.sys, "argv", test_stop_argv), - patch.object(runtime_entrypoint.logging.Logger, "setLevel"), - ): - entrypoint.start() + def test_start_stop(self, caplog: pytest.LogCaptureFixture, tmp_path: Path): + # GIVEN + caplog.set_level(INFO) + connection_file = tmp_path / "connection.json" + entrypoint = EntryPoint(CommandAdaptorExample) - # THEN - assert "Initializing backend process" in caplog.text - assert "Connected successfully" in caplog.text - assert "Running in background daemon mode." in caplog.text - assert "Daemon background process stopped." in caplog.text - assert "on_prerun" not in caplog.text - assert "on_postrun" not in caplog.text - - def test_run(self, caplog: pytest.LogCaptureFixture, tmp_path: Path): - # GIVEN - caplog.set_level(INFO) - connection_file = tmp_path / "connection.json" - test_start_argv = [ + # WHEN + with ( + patch.object( + runtime_entrypoint.sys, + "argv", + TestCommandAdaptorDaemon.get_start_argv(connection_file=connection_file), + ), + patch.object(runtime_entrypoint.logging.Logger, "setLevel"), + ): + entrypoint.start() + with ( + patch.object( + runtime_entrypoint.sys, + "argv", + TestCommandAdaptorDaemon.get_stop_argv(connection_file=connection_file), + ), + patch.object(runtime_entrypoint.logging.Logger, "setLevel"), + ): + entrypoint.start() + + # THEN + assert "Initializing backend process" in caplog.text + assert "Connected successfully" in caplog.text + assert "Running in background daemon mode." in caplog.text + assert "Daemon background process stopped." in caplog.text + assert "on_prerun" not in caplog.text + assert "on_postrun" not in caplog.text + + def test_run(self, caplog: pytest.LogCaptureFixture, tmp_path: Path): + # GIVEN + caplog.set_level(INFO) + connection_file = tmp_path / "connection.json" + test_run_argv = [ + "program_filename.py", + "daemon", + "run", + "--connection-file", + str(connection_file), + "--run-data", + json.dumps( + {"args": ["echo", "hello world"] if OSName.is_windows() else ["hello world"]} + ), + ] + entrypoint = EntryPoint(CommandAdaptorExample) + + # WHEN + with ( + patch.object( + runtime_entrypoint.sys, + "argv", + TestCommandAdaptorDaemon.get_start_argv(connection_file=connection_file), + ), + patch.object(runtime_entrypoint.logging.Logger, "setLevel"), + ): + entrypoint.start() + with ( + patch.object(runtime_entrypoint.sys, "argv", test_run_argv), + patch.object(runtime_entrypoint.logging.Logger, "setLevel"), + ): + entrypoint.start() + with ( + patch.object( + runtime_entrypoint.sys, + "argv", + TestCommandAdaptorDaemon.get_stop_argv(connection_file=connection_file), + ), + patch.object(runtime_entrypoint.logging.Logger, "setLevel"), + ): + entrypoint.start() + + # THEN + assert "on_prerun" in caplog.text + assert "hello world" in caplog.text + assert "on_postrun" in caplog.text + + class TestUsingEnvVar: + """ + Daemon tests that do not use the --connection-file option and instead use the + OPENJD_ADAPTOR_SOCKET environment variable + """ + + def test_full_cycle(self, caplog: pytest.LogCaptureFixture) -> None: + # GIVEN + caplog.set_level(INFO) + entrypoint = EntryPoint(CommandAdaptorExample) + + # WHEN + with ( + patch.object( + runtime_entrypoint.sys, + "argv", + TestCommandAdaptorDaemon.get_start_argv(connection_file=None), + ), + patch.object(runtime_entrypoint.logging.Logger, "setLevel"), + ): + entrypoint.start() + + # THEN + match = re.search( + "openjd_env: OPENJD_ADAPTOR_SOCKET=(.*)$", + caplog.text, + re.MULTILINE, + ) + assert ( + match is not None + ), f"Expected openjd_env statement not found in output: {caplog.text}" + openjd_adaptor_socket = match.group(1) + print( + f"DEBUG: Using OPENJD_ADAPTOR_SOCKET={openjd_adaptor_socket} (exists={os.path.exists(openjd_adaptor_socket)})" + ) + + # WHEN + with ( + patch.object( + runtime_entrypoint.sys, + "argv", + TestCommandAdaptorDaemon.get_run_argv(connection_file=None), + ), + patch.object(runtime_entrypoint.logging.Logger, "setLevel"), + patch.dict( + runtime_entrypoint.os.environ, {"OPENJD_ADAPTOR_SOCKET": openjd_adaptor_socket} + ), + ): + entrypoint.start() + with ( + patch.object( + runtime_entrypoint.sys, + "argv", + TestCommandAdaptorDaemon.get_stop_argv(connection_file=None), + ), + patch.object(runtime_entrypoint.logging.Logger, "setLevel"), + patch.dict( + runtime_entrypoint.os.environ, {"OPENJD_ADAPTOR_SOCKET": openjd_adaptor_socket} + ), + ): + entrypoint.start() + + # THEN + assert "Initializing backend process" in caplog.text + assert "Connected successfully" in caplog.text + assert "Running in background daemon mode." in caplog.text + assert "on_prerun" in caplog.text + assert "hello world" in caplog.text + assert "on_postrun" in caplog.text + assert "Daemon background process stopped." in caplog.text + + @staticmethod + def get_start_argv(*, connection_file: Path | None = None) -> list[str]: + return [ "program_filename.py", "daemon", "start", - "--connection-file", - str(connection_file), + *(["--connection-file", str(connection_file)] if connection_file else []), "--init-data", json.dumps( { @@ -140,44 +248,32 @@ def test_run(self, caplog: pytest.LogCaptureFixture, tmp_path: Path): } ), ] - test_run_argv = [ + + @staticmethod + def get_run_argv(*, connection_file: Path | None = None) -> list[str]: + return [ "program_filename.py", "daemon", "run", - "--connection-file", - str(connection_file), + *(["--connection-file", str(connection_file)] if connection_file else []), "--run-data", json.dumps( {"args": ["echo", "hello world"] if OSName.is_windows() else ["hello world"]} ), ] - test_stop_argv = [ + + @staticmethod + def get_stop_argv(*, connection_file: Path | None = None) -> list[str]: + return [ "program_filename.py", "daemon", "stop", - "--connection-file", - str(connection_file), + *( + [ + "--connection-file", + str(connection_file), + ] + if connection_file + else [] + ), ] - entrypoint = EntryPoint(CommandAdaptorExample) - - # WHEN - with ( - patch.object(runtime_entrypoint.sys, "argv", test_start_argv), - patch.object(runtime_entrypoint.logging.Logger, "setLevel"), - ): - entrypoint.start() - with ( - patch.object(runtime_entrypoint.sys, "argv", test_run_argv), - patch.object(runtime_entrypoint.logging.Logger, "setLevel"), - ): - entrypoint.start() - with ( - patch.object(runtime_entrypoint.sys, "argv", test_stop_argv), - patch.object(runtime_entrypoint.logging.Logger, "setLevel"), - ): - entrypoint.start() - - # THEN - assert "on_prerun" in caplog.text - assert "hello world" in caplog.text - assert "on_postrun" in caplog.text diff --git a/test/openjd/adaptor_runtime/unit/background/test_backend_runner.py b/test/openjd/adaptor_runtime/unit/background/test_backend_runner.py index c4788b7..1ce62a6 100644 --- a/test/openjd/adaptor_runtime/unit/background/test_backend_runner.py +++ b/test/openjd/adaptor_runtime/unit/background/test_backend_runner.py @@ -1,4 +1,5 @@ # Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. +from __future__ import annotations import json import os @@ -23,7 +24,7 @@ class TestBackendRunner: @pytest.fixture(autouse=True) def socket_path(self, tmp_path: pathlib.Path) -> Generator[str, None, None]: if OSName.is_posix(): - with patch.object(backend_runner.SocketDirectories, "get_process_socket_path") as mock: + with patch.object(backend_runner.SocketPaths, "get_process_socket_path") as mock: path = os.path.join(tmp_path, "socket", "1234") mock.return_value = path @@ -67,15 +68,17 @@ def test_run( ): # GIVEN caplog.set_level("DEBUG") - conn_file_path = "/path/to/conn_file" + conn_file = pathlib.Path(os.sep) / "path" / "to" / "conn_file" connection_settings = {"socket": socket_path} adaptor_runner = Mock() - runner = BackendRunner(adaptor_runner, conn_file_path) + runner = BackendRunner(adaptor_runner, connection_file_path=conn_file) # WHEN open_mock: MagicMock with patch.object( - backend_runner, "secure_open", mock_open(read_data=json.dumps(connection_settings)) + backend_runner, + "secure_open", + mock_open(read_data=json.dumps(connection_settings)), ) as open_mock: runner.run() @@ -93,7 +96,7 @@ def test_run( ) mock_thread.assert_called_once() mock_thread.return_value.start.assert_called_once() - open_mock.assert_called_once_with(conn_file_path, open_mode="w") + open_mock.assert_called_once_with(conn_file, open_mode="w") mock_json_dump.assert_called_once_with( ConnectionSettings(socket_path), open_mock.return_value, @@ -101,9 +104,9 @@ def test_run( ) mock_thread.return_value.join.assert_called_once() if OSName.is_posix(): - mock_os_remove.assert_has_calls([call(conn_file_path), call(socket_path)]) + mock_os_remove.assert_has_calls([call(conn_file), call(socket_path)]) else: - mock_os_remove.assert_has_calls([call(conn_file_path)]) + mock_os_remove.assert_has_calls([call(conn_file)]) def test_run_raises_when_http_server_fails_to_start( self, @@ -114,7 +117,10 @@ def test_run_raises_when_http_server_fails_to_start( caplog.set_level("DEBUG") exc = Exception() mock_server_cls.side_effect = exc - runner = BackendRunner(Mock(), "") + runner = BackendRunner( + Mock(), + connection_file_path=pathlib.Path(os.path.sep) / "tmp" / "connection.json", + ) # WHEN with pytest.raises(Exception) as raised_exc: @@ -144,9 +150,9 @@ def test_run_raises_when_writing_connection_file_fails( caplog.set_level("DEBUG") err = OSError() open_mock.side_effect = err - conn_file_path = "/path/to/conn_file" + conn_file = pathlib.Path(os.sep) / "path" / "to" / "conn_file" adaptor_runner = Mock() - runner = BackendRunner(adaptor_runner, conn_file_path) + runner = BackendRunner(adaptor_runner, connection_file_path=conn_file) # WHEN with pytest.raises(OSError) as raised_err: @@ -164,12 +170,12 @@ def test_run_raises_when_writing_connection_file_fails( ] mock_thread.assert_called_once() mock_thread.return_value.start.assert_called_once() - open_mock.assert_called_once_with(conn_file_path, open_mode="w") + open_mock.assert_called_once_with(conn_file, open_mode="w") mock_thread.return_value.join.assert_called_once() if OSName.is_posix(): - mock_os_remove.assert_has_calls([call(conn_file_path), call(socket_path)]) + mock_os_remove.assert_has_calls([call(conn_file), call(socket_path)]) else: - mock_os_remove.assert_has_calls([call(conn_file_path)]) + mock_os_remove.assert_has_calls([call(conn_file)]) @patch.object(backend_runner.signal, "signal") @patch.object(backend_runner.ServerResponseGenerator, "submit_task") @@ -178,9 +184,9 @@ def test_signal_hook(self, mock_submit, signal_mock: MagicMock) -> None: # as expected. # GIVEN - conn_file_path = "/path/to/conn_file" + conn_file_path = pathlib.Path(os.sep) / "path" / "to" / "conn_file" adaptor_runner = Mock() - runner = BackendRunner(adaptor_runner, conn_file_path) + runner = BackendRunner(adaptor_runner, connection_file_path=conn_file_path) server_mock = MagicMock() submit_mock = MagicMock() server_mock.submit = submit_mock diff --git a/test/openjd/adaptor_runtime/unit/background/test_frontend_runner.py b/test/openjd/adaptor_runtime/unit/background/test_frontend_runner.py index 1f2ecc5..9545ea5 100644 --- a/test/openjd/adaptor_runtime/unit/background/test_frontend_runner.py +++ b/test/openjd/adaptor_runtime/unit/background/test_frontend_runner.py @@ -4,32 +4,32 @@ import http.client as http_client import json +import os import re import signal import subprocess import sys +import tempfile from pathlib import Path from types import ModuleType from typing import Generator, Optional -from unittest.mock import MagicMock, PropertyMock, call, mock_open, patch +from unittest.mock import MagicMock, call, patch import pytest -import openjd.adaptor_runtime._background.frontend_runner as frontend_runner +from openjd.adaptor_runtime._background import frontend_runner from openjd.adaptor_runtime._osname import OSName from openjd.adaptor_runtime.adaptors import AdaptorState from openjd.adaptor_runtime._background.frontend_runner import ( AdaptorFailedException, FrontendRunner, HTTPError, - _load_connection_settings, - _wait_for_file, + _wait_for_connection_file, ) from openjd.adaptor_runtime._background.model import ( AdaptorStatus, BufferedOutput, ConnectionSettings, - DataclassMapper, HeartbeatResponse, ) @@ -43,17 +43,77 @@ class TestFrontendRunner: def server_name(self) -> str: return "/path/to/socket" if OSName.is_posix() else r"\\.\pipe\TestPipe" + @pytest.fixture + def connection_settings(self, server_name: str) -> ConnectionSettings: + return ConnectionSettings(server_name) + @pytest.fixture(autouse=True) - def mock_connection_settings(self, server_name: str) -> Generator[MagicMock, None, None]: - with patch.object(FrontendRunner, "connection_settings", new_callable=PropertyMock) as mock: - mock.return_value = ConnectionSettings(server_name) - yield mock + def mock_connection_settings( + self, connection_settings: ConnectionSettings + ) -> Generator[None, None, None]: + with patch.object( + FrontendRunner, + "connection_settings", + return_value=connection_settings, + create=True, + ): + yield class TestInit: """ Tests for the FrontendRunner.init method """ + @pytest.fixture(autouse=True) + def open_mock(self) -> Generator[MagicMock, None, None]: + with patch.object(frontend_runner, "open") as m: + yield m + + @pytest.fixture(autouse=True) + def mock_path_exists(self) -> Generator[MagicMock, None, None]: + with patch.object(frontend_runner.Path, "exists") as m: + yield m + + @pytest.fixture(autouse=True) + def mock_uuid(self) -> Generator[MagicMock, None, None]: + with patch.object(frontend_runner.uuid, "uuid4") as m: + yield m + + @pytest.fixture(autouse=True) + def mock_gettempdir(self) -> Generator[MagicMock, None, None]: + with patch.object(frontend_runner.tempfile, "gettempdir") as m: + yield m + + @pytest.fixture(autouse=True) + def mock_Popen(self) -> Generator[MagicMock, None, None]: + with patch.object(frontend_runner.subprocess, "Popen") as m: + yield m + + @pytest.fixture(autouse=True) + def mock_wait_for_connection_file(self) -> Generator[MagicMock, None, None]: + with patch.object(frontend_runner, "_wait_for_connection_file") as m: + yield m + + @pytest.fixture(autouse=True) + def mock_heartbeat(self) -> Generator[MagicMock, None, None]: + with patch.object(frontend_runner.FrontendRunner, "_heartbeat") as m: + yield m + + @pytest.fixture(autouse=True) + def mock_sys_argv(self) -> Generator[MagicMock, None, None]: + with patch.object(frontend_runner.sys, "argv", return_value=[]) as m: + yield m + + @pytest.fixture(autouse=True) + def mock_sys_executable(self) -> Generator[MagicMock, None, None]: + with patch.object(frontend_runner.sys, "executable", return_value="executable") as m: + yield m + + @pytest.fixture(autouse=True) + def mock_connection_settings_file_load(self) -> Generator[MagicMock, None, None]: + with patch.object(frontend_runner.ConnectionSettingsFileLoader, "load") as m: + yield m + @pytest.mark.parametrize( argnames="reentry_exe", argvalues=[ @@ -61,51 +121,53 @@ class TestInit: (Path("reeentry_exe_value"),), ], ) - @patch.object(frontend_runner.sys, "argv") - @patch.object(frontend_runner.sys, "executable") - @patch.object(frontend_runner.json, "dumps") - @patch.object(FrontendRunner, "_heartbeat") - @patch.object(frontend_runner, "_wait_for_file") - @patch.object(frontend_runner.subprocess, "Popen") - @patch.object(frontend_runner.os.path, "exists") def test_initializes_backend_process( self, - mock_exists: MagicMock, + mock_path_exists: MagicMock, mock_Popen: MagicMock, - mock_wait_for_file: MagicMock, + mock_wait_for_connection_file: MagicMock, mock_heartbeat: MagicMock, - mock_json_dumps: MagicMock, mock_sys_executable: MagicMock, mock_sys_argv: MagicMock, + mock_uuid: MagicMock, + open_mock: MagicMock, caplog: pytest.LogCaptureFixture, reentry_exe: Optional[Path], ): # GIVEN caplog.set_level("DEBUG") - mock_json_dumps.return_value = "test" - mock_exists.return_value = False + mock_path_exists.return_value = False pid = 123 mock_Popen.return_value.pid = pid mock_sys_executable.return_value = "executable" mock_sys_argv.return_value = [] adaptor_module = ModuleType("") adaptor_module.__package__ = "package" - conn_file_path = "/path" init_data = {"init": "data"} path_mapping_data: dict = {} - runner = FrontendRunner(conn_file_path) + connection_file_path = Path("connection.test") + runner = FrontendRunner() # WHEN - runner.init(adaptor_module, init_data, path_mapping_data, reentry_exe) + runner.init( + adaptor_module=adaptor_module, + connection_file_path=connection_file_path, + init_data=init_data, + path_mapping_data=path_mapping_data, + reentry_exe=reentry_exe, + ) # THEN - assert caplog.messages == [ - "Initializing backend process...", - f"Started backend process. PID: {pid}", - "Verifying connection to backend...", - "Connected successfully", - ] - mock_exists.assert_called_once_with(conn_file_path) + assert all( + m in caplog.messages + for m in [ + "Initializing backend process...", + f"Started backend process. PID: {pid}", + "Verifying connection to backend...", + "Connected successfully", + ] + ) + mock_path_exists.assert_called_once_with() if reentry_exe is None: expected_args = [ sys.executable, @@ -113,78 +175,95 @@ def test_initializes_backend_process( adaptor_module.__package__, "daemon", "_serve", - "--connection-file", - conn_file_path, "--init-data", json.dumps(init_data), "--path-mapping-rules", json.dumps(path_mapping_data), + "--connection-file", + str(connection_file_path), ] else: expected_args = [ str(reentry_exe), "daemon", "_serve", - "--connection-file", - conn_file_path, "--init-data", json.dumps(init_data), "--path-mapping-rules", json.dumps(path_mapping_data), + "--connection-file", + str(connection_file_path), + ] + expected_args.extend( + [ + "--bootstrap-log-file", + os.path.join( + tempfile.gettempdir(), + f"adaptor-runtime-background-bootstrap-{mock_uuid.return_value}.log", + ), ] + ) mock_Popen.assert_called_once_with( expected_args, shell=False, close_fds=True, start_new_session=True, stdin=subprocess.DEVNULL, - stdout=subprocess.DEVNULL, - stderr=subprocess.DEVNULL, + stdout=open_mock.return_value, + stderr=open_mock.return_value, + ) + mock_wait_for_connection_file.assert_called_once_with( + str(connection_file_path), + max_retries=5, + interval_s=1, ) - mock_wait_for_file.assert_called_once_with(conn_file_path, timeout_s=5) mock_heartbeat.assert_called_once() def test_raises_when_adaptor_module_not_package(self): # GIVEN adaptor_module = ModuleType("") adaptor_module.__package__ = None - runner = FrontendRunner("") + runner = FrontendRunner() # WHEN with pytest.raises(Exception) as raised_exc: - runner.init(adaptor_module) + runner.init( + adaptor_module=adaptor_module, + connection_file_path="/path", + ) # THEN assert raised_exc.match(f"Adaptor module is not a package: {adaptor_module}") - @patch.object(frontend_runner.os.path, "exists") def test_raises_when_connection_file_exists( self, - mock_exists: MagicMock, + mock_path_exists: MagicMock, ): # GIVEN - mock_exists.return_value = True + mock_path_exists.return_value = True adaptor_module = ModuleType("") adaptor_module.__package__ = "package" - conn_file_path = "/path" - runner = FrontendRunner(conn_file_path) + conn_file_path = Path("/path") + runner = FrontendRunner() # WHEN with pytest.raises(FileExistsError) as raised_err: - runner.init(adaptor_module) + runner.init( + adaptor_module=adaptor_module, + connection_file_path=conn_file_path, + ) # THEN assert raised_err.match( - "Cannot init a new backend process with an existing connection file at: " - + conn_file_path + re.escape( + f"Cannot init a new backend process with an existing connection file at: {conn_file_path}" + ) ) - mock_exists.assert_called_once_with(conn_file_path) + mock_path_exists.assert_called_once_with() - @patch.object(frontend_runner.subprocess, "Popen") - @patch.object(frontend_runner.os.path, "exists") def test_raises_when_failed_to_create_backend_process( self, - mock_exists: MagicMock, + mock_path_exists: MagicMock, mock_Popen: MagicMock, caplog: pytest.LogCaptureFixture, ): @@ -192,102 +271,123 @@ def test_raises_when_failed_to_create_backend_process( caplog.set_level("DEBUG") exc = Exception() mock_Popen.side_effect = exc - mock_exists.return_value = False + mock_path_exists.return_value = False adaptor_module = ModuleType("") adaptor_module.__package__ = "package" - conn_file_path = "/path" - runner = FrontendRunner(conn_file_path) + conn_file_path = Path("/path") + runner = FrontendRunner() # WHEN with pytest.raises(Exception) as raised_exc: - runner.init(adaptor_module) + runner.init( + adaptor_module=adaptor_module, + connection_file_path=conn_file_path, + ) # THEN assert raised_exc.value is exc - assert caplog.messages == [ - "Initializing backend process...", - "Failed to initialize backend process: ", - ] - mock_exists.assert_called_once_with(conn_file_path) + assert all( + m in caplog.messages + for m in [ + "Initializing backend process...", + "Failed to initialize backend process: ", + ] + ) + mock_path_exists.assert_called_once_with() mock_Popen.assert_called_once() - @patch.object(frontend_runner, "_wait_for_file") - @patch.object(frontend_runner.subprocess, "Popen") - @patch.object(frontend_runner.os.path, "exists") def test_raises_when_connection_file_wait_times_out( self, - mock_exists: MagicMock, + mock_path_exists: MagicMock, mock_Popen: MagicMock, - mock_wait_for_file: MagicMock, + mock_wait_for_connection_file: MagicMock, caplog: pytest.LogCaptureFixture, ): # GIVEN caplog.set_level("DEBUG") err = TimeoutError() - mock_wait_for_file.side_effect = err - mock_exists.return_value = False + mock_wait_for_connection_file.side_effect = err + mock_path_exists.return_value = False pid = 123 mock_Popen.return_value.pid = pid adaptor_module = ModuleType("") adaptor_module.__package__ = "package" - conn_file_path = "/path" - runner = FrontendRunner(conn_file_path) + conn_file_path = Path("/path") + runner = FrontendRunner() # WHEN with pytest.raises(TimeoutError) as raised_err: - runner.init(adaptor_module) + runner.init( + adaptor_module=adaptor_module, + connection_file_path=conn_file_path, + ) # THEN assert raised_err.value is err - print(caplog.messages) - assert caplog.messages == [ - "Initializing backend process...", - f"Started backend process. PID: {pid}", - f"Backend process failed to write connection file in time at: {conn_file_path}", - ] - mock_exists.assert_called_once_with(conn_file_path) + assert all( + m in caplog.messages + for m in [ + "Initializing backend process...", + f"Started backend process. PID: {pid}", + f"Backend process failed to write connection file in time at: {conn_file_path}", + ] + ) + mock_path_exists.assert_called_once_with() mock_Popen.assert_called_once() - mock_wait_for_file.assert_called_once_with(conn_file_path, timeout_s=5) + mock_wait_for_connection_file.assert_called_once_with( + str(conn_file_path), + max_retries=5, + interval_s=1, + ) class TestHeartbeat: """ Tests for the FrontendRunner._heartbeat method """ - @patch.object(frontend_runner.json, "load") - @patch.object(DataclassMapper, "map") - @patch.object(FrontendRunner, "_send_request") + @pytest.fixture(autouse=True) + def mock_json_load(self) -> Generator[MagicMock, None, None]: + with patch.object(frontend_runner.json, "load") as m: + yield m + + @pytest.fixture(autouse=True) + def mock_dataclass_mapper_map(self) -> Generator[MagicMock, None, None]: + with patch.object(frontend_runner.DataclassMapper, "map") as m: + yield m + + @pytest.fixture(autouse=True) + def mock_send_request(self) -> Generator[MagicMock, None, None]: + with patch.object(frontend_runner.FrontendRunner, "_send_request") as m: + yield m + def test_sends_heartbeat( self, mock_send_request: MagicMock, - mock_map: MagicMock, + mock_dataclass_mapper_map: MagicMock, mock_json_load: MagicMock, ): # GIVEN if OSName.is_windows(): mock_send_request.return_value = {"body": '{"key1": "value1"}'} mock_response = mock_send_request.return_value - runner = FrontendRunner("") + runner = FrontendRunner() # WHEN response = runner._heartbeat() # THEN - assert response is mock_map.return_value + assert response is mock_dataclass_mapper_map.return_value if OSName.is_posix(): mock_json_load.assert_called_once_with(mock_response.fp) - mock_map.assert_called_once_with(mock_json_load.return_value) + mock_dataclass_mapper_map.assert_called_once_with(mock_json_load.return_value) else: - mock_map.assert_called_once_with({"key1": "value1"}) + mock_dataclass_mapper_map.assert_called_once_with({"key1": "value1"}) mock_send_request.assert_called_once_with("GET", "/heartbeat", params=None) - @patch.object(frontend_runner.json, "load") - @patch.object(DataclassMapper, "map") - @patch.object(FrontendRunner, "_send_request") def test_sends_heartbeat_with_ack_id( self, mock_send_request: MagicMock, - mock_map: MagicMock, + mock_dataclass_mapper_map: MagicMock, mock_json_load: MagicMock, ): # GIVEN @@ -295,18 +395,18 @@ def test_sends_heartbeat_with_ack_id( if OSName.is_windows(): mock_send_request.return_value = {"body": '{"key1": "value1"}'} mock_response = mock_send_request.return_value - runner = FrontendRunner("") + runner = FrontendRunner() # WHEN response = runner._heartbeat(ack_id) # THEN - assert response is mock_map.return_value + assert response is mock_dataclass_mapper_map.return_value if OSName.is_posix(): mock_json_load.assert_called_once_with(mock_response.fp) - mock_map.assert_called_once_with(mock_json_load.return_value) + mock_dataclass_mapper_map.assert_called_once_with(mock_json_load.return_value) else: - mock_map.assert_called_once_with({"key1": "value1"}) + mock_dataclass_mapper_map.assert_called_once_with({"key1": "value1"}) mock_send_request.assert_called_once_with( "GET", "/heartbeat", params={"ack_id": ack_id} ) @@ -338,7 +438,7 @@ def test_heartbeats_until_complete( mock_event.wait = MagicMock() mock_event.is_set = MagicMock(return_value=False) heartbeat_interval = 1 - runner = FrontendRunner("", heartbeat_interval=heartbeat_interval) + runner = FrontendRunner(heartbeat_interval=heartbeat_interval) # WHEN runner._heartbeat_until_state_complete(state) @@ -367,7 +467,7 @@ def test_raises_when_adaptor_fails(self, mock_heartbeat: MagicMock) -> None: failed=False, ), ] - runner = FrontendRunner("") + runner = FrontendRunner() # WHEN with pytest.raises(AdaptorFailedException) as raised_exc: @@ -385,7 +485,7 @@ class TestShutdown: @patch.object(FrontendRunner, "_send_request") def test_sends_shutdown(self, mock_send_request: MagicMock): # GIVEN - runner = FrontendRunner("") + runner = FrontendRunner() # WHEN runner.shutdown() @@ -407,7 +507,7 @@ def test_sends_run( ): # GIVEN run_data = {"run": "data"} - runner = FrontendRunner("") + runner = FrontendRunner() # WHEN runner.run(run_data) @@ -429,7 +529,7 @@ def test_sends_start( mock_heartbeat_until_state_complete: MagicMock, ): # GIVEN - runner = FrontendRunner("") + runner = FrontendRunner() # WHEN runner.start() @@ -451,7 +551,7 @@ def test_sends_end( mock_heartbeat_until_state_complete: MagicMock, ): # GIVEN - runner = FrontendRunner("") + runner = FrontendRunner() # WHEN runner.stop() @@ -471,7 +571,7 @@ def test_sends_cancel( mock_send_request: MagicMock, ): # GIVEN - runner = FrontendRunner("") + runner = FrontendRunner() # WHEN runner.cancel() @@ -497,12 +597,16 @@ def mock_getresponse(self, mock_response: MagicMock) -> Generator[MagicMock, Non yield mock @patch.object(frontend_runner.UnixHTTPConnection, "request") - def test_sends_request(self, mock_request: MagicMock, mock_getresponse: MagicMock): + def test_sends_request( + self, + mock_request: MagicMock, + mock_getresponse: MagicMock, + connection_settings: ConnectionSettings, + ): # GIVEN method = "GET" path = "/path" - conn_file_path = "/conn/file/path" - runner = FrontendRunner(conn_file_path) + runner = FrontendRunner(connection_settings=connection_settings) # WHEN response = runner._send_request(method, path) @@ -521,6 +625,7 @@ def test_raises_when_request_fails( self, mock_request: MagicMock, mock_getresponse: MagicMock, + connection_settings: ConnectionSettings, caplog: pytest.LogCaptureFixture, ): # GIVEN @@ -528,8 +633,7 @@ def test_raises_when_request_fails( mock_getresponse.side_effect = exc method = "GET" path = "/path" - conn_file_path = "/conn/file/path" - runner = FrontendRunner(conn_file_path) + runner = FrontendRunner(connection_settings=connection_settings) # WHEN with pytest.raises(http_client.HTTPException) as raised_exc: @@ -551,6 +655,7 @@ def test_raises_when_error_response_received( mock_request: MagicMock, mock_getresponse: MagicMock, mock_response: MagicMock, + connection_settings: ConnectionSettings, caplog: pytest.LogCaptureFixture, ): # GIVEN @@ -558,8 +663,7 @@ def test_raises_when_error_response_received( mock_response.reason = "Something went wrong" method = "GET" path = "/path" - conn_file_path = "/conn/file/path" - runner = FrontendRunner(conn_file_path) + runner = FrontendRunner(connection_settings=connection_settings) # WHEN with pytest.raises(HTTPError) as raised_err: @@ -579,13 +683,17 @@ def test_raises_when_error_response_received( mock_getresponse.assert_called_once() @patch.object(frontend_runner.UnixHTTPConnection, "request") - def test_formats_query_string(self, mock_request: MagicMock, mock_getresponse: MagicMock): + def test_formats_query_string( + self, + mock_request: MagicMock, + mock_getresponse: MagicMock, + connection_settings: ConnectionSettings, + ): # GIVEN method = "GET" path = "/path" - conn_file_path = "/conn/file/path" params = {"first param": 1, "second_param": ["one", "two three"]} - runner = FrontendRunner(conn_file_path) + runner = FrontendRunner(connection_settings=connection_settings) # WHEN response = runner._send_request(method, path, params=params) @@ -600,13 +708,17 @@ def test_formats_query_string(self, mock_request: MagicMock, mock_getresponse: M assert response is mock_getresponse.return_value @patch.object(frontend_runner.UnixHTTPConnection, "request") - def test_sends_body(self, mock_request: MagicMock, mock_getresponse: MagicMock): + def test_sends_body( + self, + mock_request: MagicMock, + mock_getresponse: MagicMock, + connection_settings: ConnectionSettings, + ): # GIVEN method = "GET" path = "/path" - conn_file_path = "/conn/file/path" json = {"the": "body"} - runner = FrontendRunner(conn_file_path) + runner = FrontendRunner(connection_settings=connection_settings) # WHEN response = runner._send_request(method, path, json_body=json) @@ -638,17 +750,23 @@ def mock_read_from_pipe(self, mock_response: MagicMock) -> Generator[MagicMock, mock_read_from_pipe.return_value = mock_response yield mock_read_from_pipe + @pytest.fixture + def connection_settings(self) -> ConnectionSettings: + return ConnectionSettings("\\\\.\\pipe") + + @pytest.fixture + def runner(self, connection_settings: ConnectionSettings) -> FrontendRunner: + return FrontendRunner(connection_settings=connection_settings) + def test_sends_request( self, mock_read_from_pipe: MagicMock, mock_response: str, + runner: FrontendRunner, ): # GIVEN method = "GET" path = "/path" - conn_file_path = r"C:\conn\file\path" - - runner = FrontendRunner(conn_file_path) # WHEN with patch.object( @@ -670,6 +788,7 @@ def test_raises_when_request_fails( self, mock_read_from_pipe: MagicMock, mock_response: str, + runner: FrontendRunner, caplog: pytest.LogCaptureFixture, ): # GIVEN @@ -679,8 +798,6 @@ def test_raises_when_request_fails( mock_read_from_pipe.side_effect = error_instance method = "GET" path = "/path" - conn_file_path = r"C:\conn\file\path" - runner = FrontendRunner(conn_file_path) # WHEN with patch.object( @@ -703,13 +820,12 @@ def test_raises_when_request_fails( def test_raises_when_error_response_received( self, mock_response: str, + runner: FrontendRunner, caplog: pytest.LogCaptureFixture, ): # GIVEN method = "GET" path = "/path" - conn_file_path = r"C:\conn\file\path" - runner = FrontendRunner(conn_file_path) # WHEN with patch.object( @@ -740,14 +856,13 @@ def test_formats_query_string( self, mock_read_from_pipe, mock_response: str, + runner: FrontendRunner, caplog: pytest.LogCaptureFixture, ): # GIVEN method = "GET" path = "/path" - conn_file_path = r"C:\conn\file\path" params = {"first param": 1, "second_param": ["one", "two three"]} - runner = FrontendRunner(conn_file_path) # WHEN with patch.object( @@ -770,14 +885,13 @@ def test_sends_body( self, mock_read_from_pipe, mock_response: str, + runner: FrontendRunner, caplog: pytest.LogCaptureFixture, ): # GIVEN method = "GET" path = "/path" - conn_file_path = r"C:\conn\file\path" json_body = {"the": "body"} - runner = FrontendRunner(conn_file_path) # WHEN with patch.object( @@ -804,8 +918,7 @@ def test_hook(self, signal_mock: MagicMock, cancel_mock: MagicMock) -> None: # as expected. # GIVEN - conn_file_path = "/path/to/conn_file" - runner = FrontendRunner(conn_file_path) + runner = FrontendRunner() # WHEN runner._sigint_handler(MagicMock(), MagicMock()) @@ -819,144 +932,57 @@ def test_hook(self, signal_mock: MagicMock, cancel_mock: MagicMock) -> None: cancel_mock.assert_called_once() -class TestLoadConnectionSettings: +class TestWaitForConnectionFile: """ - Tests for the _load_connection_settings method + Tests for the _wait_for_connection_file method """ - @patch.object(DataclassMapper, "map") - def test_loads_settings( - self, - mock_map: MagicMock, - ): - # GIVEN - filepath = "/path" - connection_settings = {"port": 123} - - # WHEN - with patch.object( - frontend_runner, "open", mock_open(read_data=json.dumps(connection_settings)) - ): - _load_connection_settings(filepath) - - # THEN - mock_map.assert_called_once_with(connection_settings) - + @patch.object(frontend_runner.ConnectionSettingsFileLoader, "load") @patch.object(frontend_runner, "open") - def test_raises_when_file_open_fails( - self, - open_mock: MagicMock, - caplog: pytest.LogCaptureFixture, - ): - # GIVEN - filepath = "/path" - err = OSError() - open_mock.side_effect = err - - # WHEN - with pytest.raises(OSError) as raised_err: - _load_connection_settings(filepath) - - # THEN - assert raised_err.value is err - assert "Failed to open connection file: " in caplog.text - - @patch.object(frontend_runner.json, "load") - def test_raises_when_json_decode_fails( - self, - mock_json_load: MagicMock, - caplog: pytest.LogCaptureFixture, - ): - # GIVEN - filepath = "/path" - err = json.JSONDecodeError("", "", 0) - mock_json_load.side_effect = err - - # WHEN - with pytest.raises(json.JSONDecodeError) as raised_err: - with patch.object(frontend_runner, "open", mock_open()): - _load_connection_settings(filepath) - - # THEN - assert raised_err.value is err - assert "Failed to decode connection file: " in caplog.text - - -class TestWaitForFile: - """ - Tests for the _wait_for_file method - """ - - @patch.object(frontend_runner, "open") - @patch.object(frontend_runner.time, "time") @patch.object(frontend_runner.time, "sleep") @patch.object(frontend_runner.os.path, "exists") def test_waits_for_file( self, mock_exists: MagicMock, mock_sleep: MagicMock, - mock_time: MagicMock, open_mock: MagicMock, + mock_conn_file_loader_load: MagicMock, ): # GIVEN filepath = "/path" - timeout = sys.float_info.max + max_retries = 9999 interval = 0.01 - mock_time.side_effect = [1, 2, 3, 4] mock_exists.side_effect = [False, True] err = IOError() open_mock.side_effect = [err, MagicMock()] + mock_conn_file_loader_load.return_value = ConnectionSettings("/server") # WHEN - _wait_for_file(filepath, timeout, interval) + _wait_for_connection_file(filepath, max_retries, interval) # THEN - assert mock_time.call_count == 4 mock_exists.assert_has_calls([call(filepath)] * 2) mock_sleep.assert_has_calls([call(interval)] * 3) open_mock.assert_has_calls([call(filepath, mode="r")] * 2) - @patch.object(frontend_runner.time, "time") @patch.object(frontend_runner.time, "sleep") @patch.object(frontend_runner.os.path, "exists") - def test_raises_when_timeout_reached( + def test_raises_when_retries_reached( self, mock_exists: MagicMock, mock_sleep: MagicMock, - mock_time: MagicMock, ): # GIVEN filepath = "/path" - timeout = 0 + max_retries = 0 interval = 0.01 - mock_time.side_effect = [1, 2] - mock_exists.side_effect = [False] + mock_exists.return_value = False # WHEN with pytest.raises(TimeoutError) as raised_err: - _wait_for_file(filepath, timeout, interval) + _wait_for_connection_file(filepath, max_retries, interval) # THEN - assert raised_err.match(f"Timed out after {timeout}s waiting for file at {filepath}") - assert mock_time.call_count == 2 + assert raised_err.match(f"Timed out waiting for File '{filepath}' to exist") mock_exists.assert_called_once_with(filepath) mock_sleep.assert_not_called() - - -@patch.object(frontend_runner, "_load_connection_settings") -def test_connection_settings_lazy_loads(mock_load_connection_settings: MagicMock): - # GIVEN - filepath = "/path" - expected = ConnectionSettings("/socket") - mock_load_connection_settings.return_value = expected - runner = FrontendRunner(filepath) - - # Assert the internal connection settings var is not set yet - assert not hasattr(runner, "_connection_settings") - - # WHEN - actual = runner.connection_settings - - # THEN - assert actual is expected - assert runner._connection_settings is expected diff --git a/test/openjd/adaptor_runtime/unit/background/test_loaders.py b/test/openjd/adaptor_runtime/unit/background/test_loaders.py new file mode 100644 index 0000000..513221b --- /dev/null +++ b/test/openjd/adaptor_runtime/unit/background/test_loaders.py @@ -0,0 +1,148 @@ +# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. + +from __future__ import annotations + +import dataclasses +import pathlib +import json +import re +import typing +from unittest.mock import MagicMock, mock_open, patch + +import pytest + +from openjd.adaptor_runtime._background import loaders +from openjd.adaptor_runtime._background.loaders import ( + ConnectionSettingsEnvLoader, + ConnectionSettingsFileLoader, + ConnectionSettingsLoadingError, +) +from openjd.adaptor_runtime._background.model import ( + ConnectionSettings, +) + + +class TestConnectionSettingsFileLoader: + """ + Tests for the ConnectionsettingsFileLoader class + """ + + @pytest.fixture + def connection_settings(self) -> ConnectionSettings: + return ConnectionSettings(socket="socket") + + @pytest.fixture(autouse=True) + def open_mock( + self, connection_settings: ConnectionSettings + ) -> typing.Generator[MagicMock, None, None]: + with patch.object( + loaders, + "open", + mock_open(read_data=json.dumps(dataclasses.asdict(connection_settings))), + ) as m: + yield m + + @pytest.fixture + def connection_file_path(self) -> pathlib.Path: + return pathlib.Path("test") + + @pytest.fixture + def loader(self, connection_file_path: pathlib.Path) -> ConnectionSettingsFileLoader: + return ConnectionSettingsFileLoader(connection_file_path) + + def test_loads_settings( + self, + connection_settings: ConnectionSettings, + loader: ConnectionSettingsFileLoader, + ): + # WHEN + result = loader.load() + + # THEN + assert result == connection_settings + + def test_raises_when_file_open_fails( + self, + open_mock: MagicMock, + loader: ConnectionSettingsFileLoader, + connection_file_path: pathlib.Path, + caplog: pytest.LogCaptureFixture, + ): + # GIVEN + err = OSError() + open_mock.side_effect = err + + # WHEN + with pytest.raises(ConnectionSettingsLoadingError): + loader.load() + + # THEN + assert f"Failed to open connection file '{connection_file_path}': " in caplog.text + + def test_raises_when_json_decode_fails( + self, + loader: ConnectionSettingsFileLoader, + connection_file_path: pathlib.Path, + caplog: pytest.LogCaptureFixture, + ): + # GIVEN + err = json.JSONDecodeError("", "", 0) + + with patch.object(loaders.json, "load", side_effect=err): + with pytest.raises(ConnectionSettingsLoadingError): + # WHEN + loader.load() + + # THEN + assert f"Failed to decode connection file '{connection_file_path}': " in caplog.text + + +class TestConnectionSettingsEnvLoader: + @pytest.fixture + def connection_settings(self) -> ConnectionSettings: + return ConnectionSettings(socket="socket") + + @pytest.fixture + def mock_env(self, connection_settings: ConnectionSettings) -> dict[str, typing.Any]: + return { + env_name: getattr(connection_settings, attr_name) + for attr_name, (env_name, _) in ConnectionSettingsEnvLoader().env_map.items() + } + + @pytest.fixture(autouse=True) + def mock_os_environ( + self, mock_env: dict[str, typing.Any] + ) -> typing.Generator[dict, None, None]: + with patch.dict(loaders.os.environ, mock_env) as d: + yield d + + @pytest.fixture + def loader(self) -> ConnectionSettingsEnvLoader: + return ConnectionSettingsEnvLoader() + + def test_loads_connection_settings( + self, + loader: ConnectionSettingsEnvLoader, + connection_settings: ConnectionSettings, + ) -> None: + # WHEN + settings = loader.load() + + # THEN + assert connection_settings == settings + + def test_raises_error_when_required_not_provided( + self, + loader: ConnectionSettingsEnvLoader, + ) -> None: + # GIVEN + with patch.object(loaders.os.environ, "get", return_value=None): + with pytest.raises(loaders.ConnectionSettingsLoadingError) as raised_err: + # WHEN + loader.load() + + # THEN + assert re.match( + "^Required attribute '.*' does not have its corresponding environment variable '.*' set", + str(raised_err.value), + ) diff --git a/test/openjd/adaptor_runtime/unit/http/test_sockets.py b/test/openjd/adaptor_runtime/unit/http/test_sockets.py index f6e9c5f..972c461 100644 --- a/test/openjd/adaptor_runtime/unit/http/test_sockets.py +++ b/test/openjd/adaptor_runtime/unit/http/test_sockets.py @@ -1,73 +1,56 @@ # Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. import os +import pathlib import re import stat from typing import Generator -from unittest.mock import ANY, MagicMock, call, patch +from unittest.mock import MagicMock, patch import pytest import openjd.adaptor_runtime._http.sockets as sockets from openjd.adaptor_runtime._http.sockets import ( - LinuxSocketDirectories, - MacOSSocketDirectories, + LinuxSocketPaths, + MacOSSocketPaths, NonvalidSocketPathException, NoSocketPathFoundException, - SocketDirectories, + SocketPaths, + UnixSocketPaths, ) -class SocketDirectoriesStub(SocketDirectories): +class SocketPathsStub(SocketPaths): def verify_socket_path(self, path: str) -> None: pass -class TestSocketDirectories: +class TestSocketPaths: class TestGetProcessSocketPath: """ - Tests for SocketDirectories.get_process_socket_path() + Tests for SocketPaths.get_process_socket_path() """ - @pytest.fixture - def socket_dir(self) -> str: - return "/path/to/socket/dir" - - @pytest.fixture(autouse=True) - def mock_socket_dir(self, socket_dir: str) -> Generator[MagicMock, None, None]: - with patch.object(SocketDirectories, "get_socket_dir") as mock: - mock.return_value = socket_dir - yield mock - - @pytest.mark.parametrize( - argnames=["create_dir"], - argvalues=[[True], [False]], - ids=["creates dir", "does not create dir"], - ) @patch.object(sockets.os, "getpid", return_value=1234) def test_gets_path( self, mock_getpid: MagicMock, - socket_dir: str, - mock_socket_dir: MagicMock, - create_dir: bool, ) -> None: # GIVEN namespace = "my-namespace" - subject = SocketDirectoriesStub() + subject = SocketPathsStub() # WHEN - result = subject.get_process_socket_path(namespace, create_dir=create_dir) + result = subject.get_process_socket_path(namespace) # THEN - assert result == os.path.join(socket_dir, str(mock_getpid.return_value)) + assert result.endswith(os.path.join(namespace, str(mock_getpid.return_value))) mock_getpid.assert_called_once() - mock_socket_dir.assert_called_once_with(namespace, create=create_dir) @patch.object(sockets.os, "getpid", return_value="a" * (sockets._PID_MAX_LENGTH + 1)) def test_asserts_max_pid_length(self, mock_getpid: MagicMock): # GIVEN - subject = SocketDirectoriesStub() + subject = SocketPathsStub() # WHEN with pytest.raises(AssertionError) as raised_err: @@ -79,76 +62,39 @@ def test_asserts_max_pid_length(self, mock_getpid: MagicMock): ) mock_getpid.assert_called_once() - class TestGetSocketDir: + class TestGetSocketPath: """ - Tests for SocketDirectories.get_socket_dir() + Tests for SocketPaths.get_socket_path() """ @pytest.fixture(autouse=True) - def mock_makedirs(self) -> Generator[MagicMock, None, None]: - with patch.object(sockets.os, "makedirs") as mock: + def mock_exists(self) -> Generator[MagicMock, None, None]: + with patch.object(sockets.os.path, "exists") as mock: + mock.return_value = False yield mock - @pytest.fixture - def home_dir(self) -> str: - return os.path.join("home", "user") - @pytest.fixture(autouse=True) - def mock_expanduser(self, home_dir: str) -> Generator[MagicMock, None, None]: - with patch.object(sockets.os.path, "expanduser", return_value=home_dir) as mock: - yield mock - - @pytest.fixture - def temp_dir(self) -> str: - return "tmp" - - @pytest.fixture(autouse=True) - def mock_gettempdir(self, temp_dir: str) -> Generator[MagicMock, None, None]: - with patch.object(sockets.tempfile, "gettempdir", return_value=temp_dir) as mock: + def mock_makedirs(self) -> Generator[MagicMock, None, None]: + with patch.object(sockets.os, "makedirs") as mock: yield mock - def test_gets_home_dir( - self, - mock_expanduser: MagicMock, - home_dir: str, - ) -> None: - # GIVEN - subject = SocketDirectoriesStub() - - # WHEN - result = subject.get_socket_dir() - - # THEN - mock_expanduser.assert_called_once_with("~") - assert result.startswith(home_dir) - + @patch.object(sockets.tempfile, "gettempdir", wraps=sockets.tempfile.gettempdir) @patch.object(sockets.os, "stat") - @patch.object(SocketDirectoriesStub, "verify_socket_path") def test_gets_temp_dir( self, - mock_verify_socket_path: MagicMock, mock_stat: MagicMock, mock_gettempdir: MagicMock, - temp_dir: str, ) -> None: # GIVEN - exc = NonvalidSocketPathException() - mock_verify_socket_path.side_effect = [exc, None] # Raise exc only once mock_stat.return_value.st_mode = stat.S_ISVTX - subject = SocketDirectoriesStub() + subject = UnixSocketPaths() # WHEN - result = subject.get_socket_dir() + subject.get_socket_path("sock") # THEN mock_gettempdir.assert_called_once() - mock_verify_socket_path.assert_has_calls( - [ - call(ANY), # home dir - call(result), # temp dir - ] - ) - mock_stat.assert_called_once_with(temp_dir) + mock_stat.assert_called() @pytest.mark.parametrize( argnames=["create"], @@ -157,154 +103,205 @@ def test_gets_temp_dir( ) def test_create_dir(self, mock_makedirs: MagicMock, create: bool) -> None: # GIVEN - subject = SocketDirectoriesStub() + subject = SocketPathsStub() # WHEN - result = subject.get_socket_dir(create=create) + result = subject.get_socket_path("sock", create_dir=create) # THEN if create: - mock_makedirs.assert_called_once_with(result, mode=0o700, exist_ok=True) + mock_makedirs.assert_called_once_with( + os.path.dirname(result), mode=0o700, exist_ok=True + ) else: mock_makedirs.assert_not_called() + def test_uses_base_dir(self, tmp_path: pathlib.Path) -> None: + # GIVEN + subject = SocketPathsStub() + base_dir = str(tmp_path) + + # WHEN + result = subject.get_socket_path("sock", base_dir=base_dir) + + # THEN + assert result.startswith(base_dir) + def test_uses_namespace(self) -> None: # GIVEN namespace = "my-namespace" - subject = SocketDirectoriesStub() + subject = SocketPathsStub() # WHEN - result = subject.get_socket_dir(namespace) + result = subject.get_socket_path("sock", namespace) # THEN - assert result.endswith(namespace) + assert os.path.dirname(result).endswith(namespace) - @patch.object(SocketDirectoriesStub, "verify_socket_path") - def test_raises_when_no_valid_dir_found(self, mock_verify_socket_path: MagicMock) -> None: + @patch.object(SocketPathsStub, "verify_socket_path") + def test_raises_when_no_valid_path_found(self, mock_verify_socket_path: MagicMock) -> None: # GIVEN mock_verify_socket_path.side_effect = NonvalidSocketPathException() - subject = SocketDirectoriesStub() + subject = SocketPathsStub() # WHEN with pytest.raises(NoSocketPathFoundException) as raised_exc: - subject.get_socket_dir() + subject.get_socket_path("sock") # THEN - assert raised_exc.match( - "Failed to find a suitable base directory to create sockets in for the following " - "reasons: " - ) - assert mock_verify_socket_path.call_count == 2 + assert raised_exc.match("^Socket path '.*' failed verification: .*") + assert mock_verify_socket_path.call_count == 1 - @patch.object(SocketDirectoriesStub, "verify_socket_path") - @patch.object(sockets.os, "stat") - def test_raises_when_no_tmpdir_sticky_bit( + @patch.object(sockets.os.path, "exists") + def test_handles_socket_name_collisions( self, - mock_stat: MagicMock, - mock_verify_socket_path: MagicMock, - temp_dir: str, + mock_exists: MagicMock, ) -> None: # GIVEN - mock_verify_socket_path.side_effect = [NonvalidSocketPathException(), None] - mock_stat.return_value.st_mode = 0 - subject = SocketDirectoriesStub() + sock_name = "sock" + existing_sock_names = [sock_name, f"{sock_name}_1", f"{sock_name}_2"] + mock_exists.side_effect = ([True] * len(existing_sock_names)) + [False] + expected_sock_name = f"{sock_name}_3" + + subject = SocketPathsStub() # WHEN - with pytest.raises(NoSocketPathFoundException) as raised_exc: - subject.get_socket_dir() + result = subject.get_socket_path(sock_name) # THEN - assert raised_exc.match( - re.escape( - f"Cannot use temporary directory {temp_dir} because it does not have the " - "sticky bit (restricted deletion flag) set" - ) - ) + assert result.endswith(expected_sock_name) + assert mock_exists.call_count == (len(existing_sock_names) + 1) + +class TestUnixSocketPaths: + @pytest.fixture(autouse=True) + def mock_stat(self) -> Generator[MagicMock, None, None]: + with patch.object(sockets.os, "stat") as m: + yield m -class TestLinuxSocketDirectories: @pytest.mark.parametrize( - argnames=["path"], + argnames=["mode", "should_fail"], argvalues=[ - ["a"], - ["a" * 100], + [stat.S_IWOTH & stat.S_ISVTX, False], + [stat.S_IWOTH, True], + [stat.S_IROTH, False], + ], + ids=[ + "world-writable, sticky, passes", + "world-writable, not sticky, fails", + "world-readable, not sticky, passes", ], - ids=["one byte", "100 bytes"], ) - def test_accepts_paths_within_100_bytes(self, path: str): - """ - Verifies the function accepts paths up to 100 bytes (108 byte max - 8 byte padding - for socket name portion (path sep + PID)) - """ - # GIVEN - subject = LinuxSocketDirectories() - - try: - # WHEN - subject.verify_socket_path(path) - except NonvalidSocketPathException as e: - pytest.fail(f"verify_socket_path raised an error when it should not have: {e}") - else: - # THEN - pass # success + def test_verifies_all_parent_directories( + self, + mode: int, + should_fail: bool, + mock_stat: MagicMock, + ) -> None: + pass - def test_rejects_paths_over_100_bytes(self): + @patch.object(sockets.os, "stat") + def test_raises_when_no_tmpdir_sticky_bit( + self, + mock_stat: MagicMock, + ) -> None: # GIVEN - length = 101 - path = "a" * length - subject = LinuxSocketDirectories() + mock_stat.return_value.st_mode = stat.S_IWOTH + socket_name = "/sock" + subject = UnixSocketPaths() # WHEN - with pytest.raises(NonvalidSocketPathException) as raised_exc: - subject.verify_socket_path(path) + with pytest.raises(NoSocketPathFoundException) as raised_exc: + subject.verify_socket_path(socket_name) # THEN assert raised_exc.match( - "Socket base directory path too big. The maximum allowed size is " - f"{subject._socket_dir_max_length} bytes, but the directory has a size of " - f"{length}: {path}" + re.escape( + f"Cannot use directory {os.path.dirname(socket_name)} because it is world writable" + " and does not have the sticky bit (restricted deletion flag) set" + ) ) + class TestLinuxSocketPaths: + @pytest.mark.parametrize( + argnames=["path"], + argvalues=[ + ["a"], + ["a" * 107], + ], + ids=["one byte", "107 bytes"], + ) + def test_accepts_names_within_107_bytes(self, path: str): + """ + Verifies the function accepts paths up to 100 bytes (108 byte max - 1 byte null terminator) + """ + # GIVEN + subject = LinuxSocketPaths() -class TestMacOSSocketDirectories: - @pytest.mark.parametrize( - argnames=["path"], - argvalues=[ - ["a"], - ["a" * 96], - ], - ids=["one byte", "96 bytes"], - ) - def test_accepts_paths_within_100_bytes(self, path: str): - """ - Verifies the function accepts paths up to 96 bytes (104 byte max - 8 byte padding - for socket name portion (path sep + PID)) - """ - # GIVEN - subject = MacOSSocketDirectories() + try: + # WHEN + subject.verify_socket_path(path) + except NonvalidSocketPathException as e: + pytest.fail(f"verify_socket_path raised an error when it should not have: {e}") + else: + # THEN + pass # success + + def test_rejects_names_over_107_bytes(self): + # GIVEN + length = 108 + path = "a" * length + subject = LinuxSocketPaths() - try: # WHEN - subject.verify_socket_path(path) - except NonvalidSocketPathException as e: - pytest.fail(f"verify_socket_path raised an error when it should not have: {e}") - else: + with pytest.raises(NonvalidSocketPathException) as raised_exc: + subject.verify_socket_path(path) + # THEN - pass # success + assert raised_exc.match( + "Socket name too long. The maximum allowed size is " + f"{subject._socket_name_max_length} bytes, but the name has a size of " + f"{length}: {path}" + ) - def test_rejects_paths_over_96_bytes(self): - # GIVEN - length = 97 - path = "a" * length - subject = MacOSSocketDirectories() + class TestMacOSSocketPaths: + @pytest.mark.parametrize( + argnames=["path"], + argvalues=[ + ["a"], + ["a" * 103], + ], + ids=["one byte", "103 bytes"], + ) + def test_accepts_paths_within_103_bytes(self, path: str): + """ + Verifies the function accepts paths up to 103 bytes (104 byte max - 1 byte null terminator) + """ + # GIVEN + subject = MacOSSocketPaths() - # WHEN - with pytest.raises(NonvalidSocketPathException) as raised_exc: - subject.verify_socket_path(path) + try: + # WHEN + subject.verify_socket_path(path) + except NonvalidSocketPathException as e: + pytest.fail(f"verify_socket_path raised an error when it should not have: {e}") + else: + # THEN + pass # success - # THEN - assert raised_exc.match( - "Socket base directory path too big. The maximum allowed size is " - f"{subject._socket_dir_max_length} bytes, but the directory has a size of " - f"{length}: {path}" - ) + def test_rejects_paths_over_103_bytes(self): + # GIVEN + length = 104 + path = "a" * length + subject = MacOSSocketPaths() + + # WHEN + with pytest.raises(NonvalidSocketPathException) as raised_exc: + subject.verify_socket_path(path) + + # THEN + assert raised_exc.match( + "Socket name too long. The maximum allowed size is " + f"{subject._socket_name_max_length} bytes, but the name has a size of " + f"{length}: {path}" + ) diff --git a/test/openjd/adaptor_runtime/unit/test_entrypoint.py b/test/openjd/adaptor_runtime/unit/test_entrypoint.py index 8ae766a..80de3dc 100644 --- a/test/openjd/adaptor_runtime/unit/test_entrypoint.py +++ b/test/openjd/adaptor_runtime/unit/test_entrypoint.py @@ -4,6 +4,7 @@ import argparse import json +import os import signal from pathlib import Path from typing import Optional @@ -21,6 +22,7 @@ ) from openjd.adaptor_runtime.adaptors import BaseAdaptor, SemanticVersion from openjd.adaptor_runtime._background import BackendRunner, FrontendRunner +from openjd.adaptor_runtime._background.model import ConnectionSettings from openjd.adaptor_runtime._osname import OSName from openjd.adaptor_runtime._entrypoint import _load_data @@ -214,7 +216,14 @@ def test_creates_adaptor_with_init_data(self, mock_adaptor_cls: MagicMock): # GIVEN init_data = {"init": "data"} with patch.object( - runtime_entrypoint.sys, "argv", ["Adaptor", "run", "--init-data", json.dumps(init_data)] + runtime_entrypoint.sys, + "argv", + [ + "Adaptor", + "run", + "--init-data", + json.dumps(init_data), + ], ): entrypoint = EntryPoint(mock_adaptor_cls) @@ -461,7 +470,7 @@ def test_runs_background_serve( ): # GIVEN init_data = {"init": "data"} - conn_file = "/path/to/conn_file" + conn_file = Path(os.sep) / "path" / "to" / "conn_file" with patch.object( runtime_entrypoint.sys, "argv", @@ -472,7 +481,7 @@ def test_runs_background_serve( "--init-data", json.dumps(init_data), "--connection-file", - conn_file, + str(conn_file), ], ): entrypoint = EntryPoint(mock_adaptor_cls) @@ -487,7 +496,7 @@ def test_runs_background_serve( ) mock_init.assert_called_once_with( mock_adaptor_runner.return_value, - conn_file, + connection_file_path=conn_file.resolve(), log_buffer=mock_log_buffer.return_value, ) mock_run.assert_called_once() @@ -506,7 +515,7 @@ def test_background_serve_no_signal_hook( ): # GIVEN init_data = {"init": "data"} - conn_file = "/path/to/conn_file" + conn_file = os.path.join(os.sep, "path", "to", "conn_file") with patch.object( runtime_entrypoint.sys, "argv", @@ -534,7 +543,7 @@ def test_background_start_raises_when_adaptor_module_not_loaded( mock_magic_init: MagicMock, ): # GIVEN - conn_file = "/path/to/conn_file" + conn_file = Path(os.sep) / "path" / "to" / "conn_file" with patch.object( runtime_entrypoint.sys, "argv", @@ -543,7 +552,7 @@ def test_background_start_raises_when_adaptor_module_not_loaded( "daemon", "start", "--connection-file", - conn_file, + str(conn_file), ], ): entrypoint = EntryPoint(FakeAdaptor) @@ -555,7 +564,7 @@ def test_background_start_raises_when_adaptor_module_not_loaded( # THEN assert raised_err.match(f"Adaptor module is not loaded: {FakeAdaptor.__module__}") - mock_magic_init.assert_called_once_with(conn_file) + mock_magic_init.assert_called_once_with() @pytest.mark.parametrize( argnames="reentry_exe", @@ -570,12 +579,12 @@ def test_background_start_raises_when_adaptor_module_not_loaded( def test_runs_background_start( self, mock_start: MagicMock, + mock_init: MagicMock, mock_magic_init: MagicMock, - mock_magic_start: MagicMock, reentry_exe: Optional[Path], ): # GIVEN - conn_file = "/path/to/conn_file" + conn_file = Path(os.sep) / "path" / "to" / "conn_file" with patch.object( runtime_entrypoint.sys, "argv", @@ -584,7 +593,7 @@ def test_runs_background_start( "daemon", "start", "--connection-file", - conn_file, + str(conn_file), ], ): mock_adaptor_module = Mock() @@ -597,10 +606,17 @@ def test_runs_background_start( entrypoint.start(reentry_exe=reentry_exe) # THEN - mock_magic_init.assert_called_once_with(mock_adaptor_module, {}, {}, reentry_exe) - mock_magic_start.assert_called_once_with(conn_file) + mock_magic_init.assert_called_once_with() + mock_init.assert_called_once_with( + adaptor_module=mock_adaptor_module, + connection_file_path=conn_file.resolve(), + init_data={}, + path_mapping_data={}, + reentry_exe=reentry_exe, + ) mock_start.assert_called_once_with() + @patch.object(runtime_entrypoint.ConnectionSettingsFileLoader, "load") @patch.object(FrontendRunner, "__init__", return_value=None) @patch.object(FrontendRunner, "shutdown") @patch.object(FrontendRunner, "stop") @@ -609,9 +625,11 @@ def test_runs_background_stop( mock_end: MagicMock, mock_shutdown: MagicMock, mock_magic_init: MagicMock, + mock_connection_settings_load: MagicMock, ): # GIVEN - conn_file = "/path/to/conn_file" + connection_settings = ConnectionSettings("socket") + mock_connection_settings_load.return_value = connection_settings with patch.object( runtime_entrypoint.sys, "argv", @@ -620,7 +638,7 @@ def test_runs_background_stop( "daemon", "stop", "--connection-file", - conn_file, + "/path/to/conn/file", ], ): entrypoint = EntryPoint(FakeAdaptor) @@ -629,19 +647,22 @@ def test_runs_background_stop( entrypoint.start() # THEN - mock_magic_init.assert_called_once_with(conn_file) + mock_magic_init.assert_called_once_with(connection_settings=connection_settings) mock_end.assert_called_once() mock_shutdown.assert_called_once_with() + @patch.object(runtime_entrypoint.ConnectionSettingsFileLoader, "load") @patch.object(FrontendRunner, "__init__", return_value=None) @patch.object(FrontendRunner, "run") def test_runs_background_run( self, mock_run: MagicMock, mock_magic_init: MagicMock, + mock_connection_settings_load: MagicMock, ): # GIVEN - conn_file = "/path/to/conn_file" + conn_settings = ConnectionSettings("socket") + mock_connection_settings_load.return_value = conn_settings run_data = {"run": "data"} with patch.object( runtime_entrypoint.sys, @@ -653,7 +674,7 @@ def test_runs_background_run( "--run-data", json.dumps(run_data), "--connection-file", - conn_file, + "/path/to/conn/file", ], ): entrypoint = EntryPoint(FakeAdaptor) @@ -662,9 +683,11 @@ def test_runs_background_run( entrypoint.start() # THEN - mock_magic_init.assert_called_once_with(conn_file) + mock_magic_init.assert_called_once_with(connection_settings=conn_settings) mock_run.assert_called_once_with(run_data) + mock_connection_settings_load.assert_called_once() + @patch.object(runtime_entrypoint.ConnectionSettingsFileLoader, "load") @patch.object(FrontendRunner, "__init__", return_value=None) @patch.object(FrontendRunner, "run") @patch.object(runtime_entrypoint.signal, "signal") @@ -673,9 +696,10 @@ def test_background_no_signal_hook( signal_mock: MagicMock, mock_run: MagicMock, mock_magic_init: MagicMock, + mock_connection_settings_load: MagicMock, ): # GIVEN - conn_file = "/path/to/conn_file" + conn_file = Path(os.sep) / "path" / "to" / "conn_file" run_data = {"run": "data"} with patch.object( runtime_entrypoint.sys, @@ -687,7 +711,7 @@ def test_background_no_signal_hook( "--run-data", json.dumps(run_data), "--connection-file", - conn_file, + str(conn_file), ], ): entrypoint = EntryPoint(FakeAdaptor) @@ -698,13 +722,15 @@ def test_background_no_signal_hook( # THEN signal_mock.assert_not_called() + @patch.object(runtime_entrypoint, "ConnectionSettingsFileLoader") @patch.object(runtime_entrypoint, "FrontendRunner") def test_makes_connection_file_path_absolute( self, mock_runner: MagicMock, + mock_connection_settings_loader: MagicMock, ): # GIVEN - conn_file = "relpath" + conn_file = Path("relpath") with patch.object( runtime_entrypoint.sys, "argv", @@ -713,23 +739,30 @@ def test_makes_connection_file_path_absolute( "daemon", "run", "--connection-file", - conn_file, + str(conn_file), ], ): entrypoint = EntryPoint(FakeAdaptor) # WHEN - mock_isabs: MagicMock + mock_is_absolute: MagicMock with ( - patch.object(runtime_entrypoint.os.path, "isabs", return_value=False) as mock_isabs, - patch.object(runtime_entrypoint.os.path, "abspath") as mock_abspath, + patch.object( + runtime_entrypoint.Path, "is_absolute", return_value=False + ) as mock_is_absolute, + patch.object( + runtime_entrypoint.Path, "absolute", return_value=Path("absolute") + ) as mock_absolute, ): entrypoint.start() # THEN - mock_isabs.assert_any_call(conn_file) - mock_abspath.assert_any_call(conn_file) - mock_runner.assert_called_once_with(mock_abspath.return_value) + mock_is_absolute.assert_called_once() + mock_absolute.assert_any_call() + mock_connection_settings_loader.assert_called_once_with(mock_absolute.return_value) + mock_runner.assert_called_once_with( + connection_settings=mock_connection_settings_loader.return_value.load.return_value + ) class TestLoadData: diff --git a/test/openjd/adaptor_runtime_client/integ/test_integration_client_interface.py b/test/openjd/adaptor_runtime_client/integ/test_integration_client_interface.py index 6fc2bd1..1dd0940 100644 --- a/test/openjd/adaptor_runtime_client/integ/test_integration_client_interface.py +++ b/test/openjd/adaptor_runtime_client/integ/test_integration_client_interface.py @@ -53,7 +53,7 @@ def test_graceful_shutdown(self) -> None: # Ensure the process actually shutdown assert client_subprocess.returncode is not None - def test_client_in_thread_does_not_do_graceful_shutdown(self): + def test_client_in_thread_does_not_do_graceful_shutdown(self) -> None: """Ensures that a client running in a thread does not crash by attempting to register a signal, since they can only be created in the main thread. This means the graceful shutdown is effectively ignored."""