diff --git a/src/openjd/adaptor_runtime/_background/backend_runner.py b/src/openjd/adaptor_runtime/_background/backend_runner.py index 50bb7b1..5d4e16e 100644 --- a/src/openjd/adaptor_runtime/_background/backend_runner.py +++ b/src/openjd/adaptor_runtime/_background/backend_runner.py @@ -6,14 +6,16 @@ import logging import os import signal +from pathlib import Path from threading import Thread, Event +import traceback from types import FrameType -from typing import Optional, Union +from typing import Callable, List, Optional, Union from .server_response import ServerResponseGenerator from .._osname import OSName from ..adaptors import AdaptorRunner -from .._http import SocketDirectories +from .._http import SocketPaths from .._utils import secure_open if OSName.is_posix(): @@ -36,12 +38,13 @@ class BackendRunner: def __init__( self, adaptor_runner: AdaptorRunner, - connection_file_path: str, *, + connection_file_path: Path, log_buffer: LogBuffer | None = None, ) -> None: self._adaptor_runner = adaptor_runner self._connection_file_path = connection_file_path + self._log_buffer = log_buffer self._server: Optional[Union[BackgroundHTTPServer, WinBackgroundNamedPipeServer]] = None signal.signal(signal.SIGINT, self._sigint_handler) @@ -68,7 +71,7 @@ def _sigint_handler(self, signum: int, frame: Optional[FrameType]) -> None: self._server, self._adaptor_runner._cancel, force_immediate=True ) - def run(self) -> None: + def run(self, *, on_connection_file_written: List[Callable[[], None]] | None = None) -> None: """ Runs the backend logic for background mode. @@ -79,8 +82,9 @@ def run(self) -> None: shutdown_event: Event = Event() if OSName.is_posix(): # pragma: is-windows - server_path = SocketDirectories.for_os().get_process_socket_path( - "runtime", create_dir=True + server_path = SocketPaths.for_os().get_process_socket_path( + ".openjd_adaptor_runtime", + create_dir=True, ) else: # pragma: is-posix server_path = NamedPipeHelper.generate_pipe_name("AdaptorNamedPipe") @@ -123,6 +127,16 @@ def run(self) -> None: _logger.info("Shutting down server...") shutdown_event.set() raise + except Exception as e: + _logger.critical(f"Unexpected error occurred when writing to connection file: {e}") + _logger.critical(traceback.format_exc()) + _logger.info("Shutting down server") + shutdown_event.set() + else: + if on_connection_file_written: + callbacks = list(on_connection_file_written) + for cb in callbacks: + cb() finally: # Block until the shutdown_event is set shutdown_event.wait() diff --git a/src/openjd/adaptor_runtime/_background/frontend_runner.py b/src/openjd/adaptor_runtime/_background/frontend_runner.py index 23e9234..1b229cb 100644 --- a/src/openjd/adaptor_runtime/_background/frontend_runner.py +++ b/src/openjd/adaptor_runtime/_background/frontend_runner.py @@ -10,16 +10,20 @@ import socket import subprocess import sys +import tempfile import time -from pathlib import Path import urllib.parse as urllib_parse +import uuid +from pathlib import Path from threading import Event from types import FrameType from types import ModuleType -from typing import Optional, Dict +from typing import Optional, Callable, Dict from .._osname import OSName from ..process._logging import _ADAPTOR_OUTPUT_LEVEL +from .._utils._constants import _OPENJD_ENV_STDOUT_PREFIX, _OPENJD_ADAPTOR_SOCKET_ENV +from .loaders import ConnectionSettingsFileLoader from .model import ( AdaptorState, AdaptorStatus, @@ -37,28 +41,39 @@ _logger = logging.getLogger(__name__) +class ConnectionSettingsNotProvidedError(Exception): + """Raised when the connection settings are required but are missing""" + + pass + + class FrontendRunner: """ Class that runs the frontend logic in background mode. """ + connection_settings: ConnectionSettings | None + def __init__( self, - connection_file_path: str, *, timeout_s: float = 5.0, heartbeat_interval: float = 1.0, + connection_settings: ConnectionSettings | None = None, ) -> None: """ Args: - connection_file_path (str): Absolute path to the connection file. timeout_s (float, optional): Timeout for HTTP requests, in seconds. Defaults to 5. heartbeat_interval (float, optional): Interval between heartbeats, in seconds. Defaults to 1. + connection_settings (ConnectionSettings, optional): The connection settings to use. + This option is not required for the "init" command, but is required for everything + else. Defaults to None. """ self._timeout_s = timeout_s self._heartbeat_interval = heartbeat_interval - self._connection_file_path = connection_file_path + self.connection_settings = connection_settings + self._canceled = Event() signal.signal(signal.SIGINT, self._sigint_handler) if OSName.is_posix(): # pragma: is-windows @@ -68,7 +83,9 @@ def __init__( def init( self, + *, adaptor_module: ModuleType, + connection_file_path: Path, init_data: dict | None = None, path_mapping_data: dict | None = None, reentry_exe: Path | None = None, @@ -79,6 +96,8 @@ def init( Args: adaptor_module (ModuleType): The module of the adaptor running the runtime. + connection_file_path (Path): The path to the connection file to use for establishing + a connection with the backend process. init_data (dict): Data to pass to the adaptor during initialization. path_mapping_data (dict): Path mapping rules to make available to the adaptor while it's running. reentry_exe (Path): The path to the binary executable that for adaptor reentry. @@ -86,10 +105,10 @@ def init( if adaptor_module.__package__ is None: raise Exception(f"Adaptor module is not a package: {adaptor_module}") - if os.path.exists(self._connection_file_path): + if connection_file_path.exists(): raise FileExistsError( "Cannot init a new backend process with an existing connection file at: " - + self._connection_file_path + + str(connection_file_path) ) if init_data is None: @@ -107,18 +126,32 @@ def init( ] else: args = [str(reentry_exe)] + args.extend( [ "daemon", "_serve", - "--connection-file", - self._connection_file_path, "--init-data", json.dumps(init_data), "--path-mapping-rules", json.dumps(path_mapping_data), + "--connection-file", + str(connection_file_path), ] ) + + bootstrap_id = uuid.uuid4() + bootstrap_log_dir = tempfile.gettempdir() + bootstrap_log_path = os.path.join( + bootstrap_log_dir, f"adaptor-runtime-background-bootstrap-{bootstrap_id}.log" + ) + args.extend(["--bootstrap-log-file", bootstrap_log_path]) + + _logger.debug(f"Running process with args: {args}") + bootstrap_output_path = os.path.join( + bootstrap_log_dir, f"adaptor-runtime-background-bootstrap-output-{bootstrap_id}.log" + ) + output_log_file = open(bootstrap_output_path, mode="w+") try: process = subprocess.Popen( args, @@ -126,8 +159,8 @@ def init( close_fds=True, start_new_session=True, stdin=subprocess.DEVNULL, - stdout=subprocess.DEVNULL, - stderr=subprocess.DEVNULL, + stdout=output_log_file, + stderr=output_log_file, ) except Exception as e: _logger.error(f"Failed to initialize backend process: {e}") @@ -136,19 +169,61 @@ def init( # Wait for backend process to create connection file try: - _wait_for_file(self._connection_file_path, timeout_s=5) + _wait_for_connection_file(str(connection_file_path), max_retries=5, interval_s=1) except TimeoutError: _logger.error( "Backend process failed to write connection file in time at: " - + self._connection_file_path + + str(connection_file_path) ) + + exit_code = process.poll() + if exit_code is not None: + _logger.info(f"Backend process exited with code: {exit_code}") + else: + _logger.info("Backend process is still running") + raise + finally: + # Close file handle to prevent further writes + # At this point, we have all the logs/output we need from the bootstrap + output_log_file.close() + if process.stdout: + process.stdout.close() + if process.stderr: + process.stderr.close() + + with open(bootstrap_output_path, mode="r") as f: + bootstrap_output = f.readlines() + _logger.info("========== BEGIN BOOTSTRAP OUTPUT CONTENTS ==========") + for line in bootstrap_output: + _logger.info(line.strip()) + _logger.info("========== END BOOTSTRAP OUTPUT CONTENTS ==========") + + _logger.info(f"Checking for bootstrap logs at '{bootstrap_log_path}'") + try: + with open(bootstrap_log_path, mode="r") as f: + bootstrap_logs = f.readlines() + except Exception as e: + _logger.error(f"Failed to get bootstrap logs at '{bootstrap_log_path}': {e}") + else: + _logger.info("========== BEGIN BOOTSTRAP LOG CONTENTS ==========") + for line in bootstrap_logs: + _logger.info(line.strip()) + _logger.info("========== END BOOTSTRAP LOG CONTENTS ==========") + + # Load up connection settings for the heartbeat requests + self.connection_settings = ConnectionSettingsFileLoader(connection_file_path).load() # Heartbeat to ensure backend process is listening for requests _logger.info("Verifying connection to backend...") self._heartbeat() _logger.info("Connected successfully") + # Output the socket path to the environment via OpenJD environments + _logger.info( + f"{_OPENJD_ENV_STDOUT_PREFIX}{_OPENJD_ADAPTOR_SOCKET_ENV}={self.connection_settings.socket}" + ) + def run(self, run_data: dict) -> None: """ Sends a run request to the backend @@ -248,6 +323,11 @@ def _send_request( params: dict | None = None, json_body: dict | None = None, ) -> http_client.HTTPResponse | Dict: + if not self.connection_settings: + raise ConnectionSettingsNotProvidedError( + "Connection settings are required to send requests, but none were provided" + ) + if OSName.is_windows(): # pragma: is-posix if params: # This is used for aligning to the Linux's behavior in order to reuse the code in handler. @@ -287,6 +367,11 @@ def _send_linux_request( params: dict | None = None, json_body: dict | None = None, ) -> http_client.HTTPResponse: # pragma: is-windows + if not self.connection_settings: + raise ConnectionSettingsNotProvidedError( + "Connection settings are required to send requests, but none were provided" + ) + conn = UnixHTTPConnection(self.connection_settings.socket, timeout=self._timeout_s) if params: @@ -311,15 +396,6 @@ def _send_linux_request( return response - @property - def connection_settings(self) -> ConnectionSettings: - """ - Gets the lazy-loaded connection settings. - """ - if not hasattr(self, "_connection_settings"): - self._connection_settings = _load_connection_settings(self._connection_file_path) - return self._connection_settings - def _sigint_handler(self, signum: int, frame: Optional[FrameType]) -> None: """ Signal handler for interrupt signals. @@ -338,51 +414,84 @@ def _sigint_handler(self, signum: int, frame: Optional[FrameType]) -> None: self.cancel() -def _load_connection_settings(path: str) -> ConnectionSettings: - try: - with open(path) as conn_file: - loaded_settings = json.load(conn_file) - except OSError as e: - _logger.error(f"Failed to open connection file: {e}") - raise - except json.JSONDecodeError as e: - _logger.error(f"Failed to decode connection file: {e}") - raise - return DataclassMapper(ConnectionSettings).map(loaded_settings) - - -def _wait_for_file(filepath: str, timeout_s: float, interval_s: float = 1) -> None: +def _wait_for_connection_file( + filepath: str, max_retries: int, interval_s: float = 1 +) -> ConnectionSettings: """ - Waits for a file at the specified path to exist and to be openable. + Waits for a connection file at the specified path to exist, be openable, and have connection settings. Args: filepath (str): The file path to check. - timeout_s (float): The max duration to wait before timing out, in seconds. + max_retries (int): The max number of retries before timing out. interval_s (float, optional): The interval between checks, in seconds. Default is 0.01s. Raises: TimeoutError: Raised when the file does not exist after timeout_s seconds. """ + wait_for( + description=f"File '{filepath}' to exist", + predicate=lambda: os.path.exists(filepath), + interval_s=interval_s, + max_retries=max_retries, + ) - def _wait(): - if time.time() - start < timeout_s: - time.sleep(interval_s) - else: - raise TimeoutError(f"Timed out after {timeout_s}s waiting for file at {filepath}") - - start = time.time() - while not os.path.exists(filepath): - _wait() + # Wait before opening to give the backend time to open it first + time.sleep(interval_s) - while True: - # Wait before opening to give the backend time to open it first - _wait() + def file_is_openable() -> bool: try: open(filepath, mode="r").close() - break except IOError: # File is not available yet - pass + return False + else: + return True + + wait_for( + description=f"File '{filepath}' to be openable", + predicate=file_is_openable, + interval_s=interval_s, + max_retries=max_retries, + ) + + def connection_file_loadable() -> bool: + try: + ConnectionSettingsFileLoader(Path(filepath)).load() + except Exception: + return False + else: + return True + + wait_for( + description=f"File '{filepath}' to have valid ConnectionSettings", + predicate=connection_file_loadable, + interval_s=interval_s, + max_retries=max_retries, + ) + + return ConnectionSettingsFileLoader(Path(filepath)).load() + + +def wait_for( + *, + description: str, + predicate: Callable[[], bool], + interval_s: float, + max_retries: int | None = None, +) -> None: + if max_retries is not None: + assert max_retries >= 0, "max_retries must be a non-negative integer" + assert interval_s > 0, "interval_s must be a positive number" + + _logger.info(f"Waiting for {description}") + retry_count = 0 + while not predicate(): + if max_retries is not None and retry_count >= max_retries: + raise TimeoutError(f"Timed out waiting for {description}") + + _logger.info(f"Retrying in {interval_s}s...") + retry_count += 1 + time.sleep(interval_s) class AdaptorFailedException(Exception): diff --git a/src/openjd/adaptor_runtime/_background/loaders.py b/src/openjd/adaptor_runtime/_background/loaders.py new file mode 100644 index 0000000..e258a29 --- /dev/null +++ b/src/openjd/adaptor_runtime/_background/loaders.py @@ -0,0 +1,69 @@ +# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. + +from __future__ import annotations + +import abc +import dataclasses +import logging +import json +import os +from pathlib import Path + +from .model import ( + ConnectionSettings, + DataclassMapper, +) + +_logger = logging.getLogger(__name__) + + +class ConnectionSettingsLoadingError(Exception): + """Raised when the connection settings cannot be loaded""" + + pass + + +class ConnectionSettingsLoader(abc.ABC): + @abc.abstractmethod + def load(self) -> ConnectionSettings: + pass + + +@dataclasses.dataclass +class ConnectionSettingsFileLoader(ConnectionSettingsLoader): + file_path: Path + + def load(self) -> ConnectionSettings: + try: + with open(self.file_path) as conn_file: + loaded_settings = json.load(conn_file) + except OSError as e: + errmsg = f"Failed to open connection file '{self.file_path}': {e}" + _logger.error(errmsg) + raise ConnectionSettingsLoadingError(errmsg) from e + except json.JSONDecodeError as e: + errmsg = f"Failed to decode connection file '{self.file_path}': {e}" + _logger.error(errmsg) + raise ConnectionSettingsLoadingError(errmsg) from e + return DataclassMapper(ConnectionSettings).map(loaded_settings) + + +@dataclasses.dataclass +class ConnectionSettingsEnvLoader(ConnectionSettingsLoader): + env_map: dict[str, tuple[str, bool]] = dataclasses.field( + default_factory=lambda: {"socket": ("OPENJD_ADAPTOR_SOCKET", True)} + ) + """Mapping of environment variable to a tuple of ConnectionSettings attribute name, and whether it is required""" + + def load(self) -> ConnectionSettings: + kwargs = {} + for attr_name, (env_name, required) in self.env_map.items(): + env_val = os.environ.get(env_name) + if not env_val: + if required: + raise ConnectionSettingsLoadingError( + f"Required attribute '{attr_name}' does not have its corresponding environment variable '{env_name}' set" + ) + else: + kwargs[attr_name] = env_val + return ConnectionSettings(**kwargs) diff --git a/src/openjd/adaptor_runtime/_background/server_response.py b/src/openjd/adaptor_runtime/_background/server_response.py index 9897b2f..31d8b2f 100644 --- a/src/openjd/adaptor_runtime/_background/server_response.py +++ b/src/openjd/adaptor_runtime/_background/server_response.py @@ -17,8 +17,8 @@ from .http_server import BackgroundHTTPServer -from ..adaptors._adaptor_runner import _OPENJD_FAIL_STDOUT_PREFIX from .._http import HTTPResponse +from .._utils._constants import _OPENJD_FAIL_STDOUT_PREFIX from .model import ( AdaptorState, AdaptorStatus, diff --git a/src/openjd/adaptor_runtime/_entrypoint.py b/src/openjd/adaptor_runtime/_entrypoint.py index 6c4ab8c..e7a112d 100644 --- a/src/openjd/adaptor_runtime/_entrypoint.py +++ b/src/openjd/adaptor_runtime/_entrypoint.py @@ -2,28 +2,46 @@ from __future__ import annotations +import contextlib import logging import os import signal import sys +import tempfile from pathlib import Path from argparse import ArgumentParser, Namespace from types import FrameType as FrameType -from typing import TYPE_CHECKING, Any, Optional, Type, TypeVar, NamedTuple, Tuple +from typing import ( + TYPE_CHECKING, + Any, + cast, + Callable, + List, + Optional, + MutableSet, + Type, + TypeVar, + NamedTuple, + Tuple, +) import jsonschema import yaml from .adaptors import AdaptorRunner, BaseAdaptor from ._background import BackendRunner, FrontendRunner, InMemoryLogBuffer, LogBufferHandler +from ._background.loaders import ( + ConnectionSettingsFileLoader, + ConnectionSettingsEnvLoader, +) from .adaptors.configuration import ( RuntimeConfiguration, ConfigurationManager, ) from ._osname import OSName +from ._utils._constants import _OPENJD_ADAPTOR_SOCKET_ENV, _OPENJD_LOG_REGEX from ._utils._logging import ( - _OPENJD_LOG_REGEX, ConditionalFormatter, ) from .adaptors import SemanticVersion @@ -53,7 +71,13 @@ "file://path/to/file.json" ), "show_config": "Prints the adaptor runtime configuration, then the program exits.", - "connection_file": "The file path to the connection file for use in background mode.", + "connection_file": ( + "The file path to the connection file for use in daemon mode. For the 'daemon start' command, this file " + "must not exist. For all other commands, this file must exist. This option is highly " + "recommended if using the adaptor in an interactive terminal. Default is to read the " + f"connection data from the environment variable: {_OPENJD_ADAPTOR_SOCKET_ENV}" + ), + "log_file": "The file to log adaptor output to. Default is to not log to a file.", } _DIR = os.path.dirname(os.path.realpath(__file__)) @@ -64,7 +88,6 @@ os.path.join( _system_config_path_prefix, "openjd", - "worker", "adaptors", "runtime", "configuration.json", @@ -75,14 +98,27 @@ "schema_path": os.path.abspath(os.path.join(_DIR, "configuration.schema.json")), "default_config_path": os.path.abspath(os.path.join(_DIR, "configuration.json")), "system_config_path": _system_config_path, - "user_config_rel_path": os.path.join( - ".openjd", "worker", "adaptors", "runtime", "configuration.json" - ), + "user_config_rel_path": os.path.join(".openjd", "adaptors", "runtime", "configuration.json"), } _logger = logging.getLogger(__name__) +class _ParsedArgs(Namespace): + command: str + + # common args + init_data: str + run_data: str + path_mapping_rules: str + connection_file: str | None + bootstrap_log_file: str | None + + # is-compatible args + openjd_adaptor_cli_version: str | None + integration_data_interface_version: str | None + + class _LogConfig(NamedTuple): formatter: ConditionalFormatter stream_handler: logging.StreamHandler @@ -123,13 +159,21 @@ class EntryPoint: The main entry point of the adaptor runtime. """ + on_bootstrap_complete: MutableSet[Callable[[], None]] + """ + Set of callbacks that are called when daemon mode bootstrapping is complete. + These callbacks are never called when not running in daemon mode. + """ + def __init__(self, adaptor_class: Type[_U]) -> None: self.adaptor_class = adaptor_class # This will be the current AdaptorRunner when using the 'run' command, rather than # 'background' command self._adaptor_runner: Optional[AdaptorRunner] = None - def _init_loggers(self) -> _LogConfig: + self.on_bootstrap_complete = set() + + def _init_loggers(self, *, bootstrap_log_path: str | None = None) -> _LogConfig: "Creates runtime/adaptor loggers" formatter = ConditionalFormatter( "%(levelname)s: %(message)s", ignore_patterns=[_OPENJD_LOG_REGEX] @@ -144,6 +188,22 @@ def _init_loggers(self) -> _LogConfig: adaptor_logger = logging.getLogger(self.adaptor_class.__module__.split(".")[0]) adaptor_logger.addHandler(stream_handler) + if bootstrap_log_path: + file_formatter = logging.Formatter("[%(asctime)s][%(levelname)-8s] %(message)s") + file_handler = logging.FileHandler(bootstrap_log_path) + file_handler.setFormatter(file_formatter) + file_handler.setLevel(0) + runtime_logger.addHandler(file_handler) + adaptor_logger.addHandler(file_handler) + + def disconnect_bootstrap_logging() -> None: + # Remove file logger after bootstrap is complete + runtime_logger.removeHandler(file_handler) + adaptor_logger.removeHandler(file_handler) + self.on_bootstrap_complete.remove(disconnect_bootstrap_logging) + + self.on_bootstrap_complete.add(disconnect_bootstrap_logging) + return _LogConfig(formatter, stream_handler, runtime_logger, adaptor_logger) def _init_config(self) -> None: @@ -193,19 +253,28 @@ def start(self, reentry_exe: Optional[Path] = None) -> None: Args: reentry_exe (Path): The path to the binary executable that for adaptor reentry. """ - log_config = self._init_loggers() parser, parsed_args = self._parse_args() - version_info = self._get_version_info() + log_config = self._init_loggers( + bootstrap_log_path=( + parsed_args.bootstrap_log_file + if hasattr(parsed_args, "bootstrap_log_file") + else None + ) + ) + + interface_version_info = self._get_version_info() if parsed_args.command == "is-compatible": - return self._handle_is_compatible(version_info, parsed_args, parser) + return self._handle_is_compatible(interface_version_info, parsed_args, parser) elif parsed_args.command == "version-info": return print( yaml.dump( { - "OpenJD Adaptor CLI Version": str(version_info.adaptor_cli_version), + "OpenJD Adaptor CLI Version": str( + interface_version_info.adaptor_cli_version + ), f"{self.adaptor_class.__name__} Data Interface Version": str( - version_info.integration_data_interface_version + interface_version_info.integration_data_interface_version ), }, indent=2, @@ -295,17 +364,26 @@ def _handle_run( def _handle_daemon( self, adaptor: BaseAdaptor[AdaptorConfiguration], - parsed_args: Namespace, + parsed_args: _ParsedArgs, log_config: _LogConfig, integration_data: _IntegrationData, reentry_exe: Optional[Path] = None, ): - connection_file = parsed_args.connection_file - if not os.path.isabs(connection_file): - connection_file = os.path.abspath(connection_file) + # Validate args subcommand = parsed_args.subcommand if hasattr(parsed_args, "subcommand") else None + connection_file: Path | None = None + if hasattr(parsed_args, "connection_file") and parsed_args.connection_file: + connection_file = Path(parsed_args.connection_file) + if connection_file and not connection_file.is_absolute(): + connection_file = connection_file.absolute() + if subcommand == "_serve": + if not connection_file: + raise RuntimeError( + "--connection file is required for the '_serve' command but was not provided." + ) + # Replace stream handler with log buffer handler since output will be buffered in # background mode log_buffer = InMemoryLogBuffer(formatter=log_config.formatter) @@ -318,41 +396,61 @@ def _handle_daemon( # forever until a shutdown is requested backend = BackendRunner( AdaptorRunner(adaptor=adaptor), - connection_file, + connection_file_path=connection_file, log_buffer=log_buffer, ) - backend.run() + backend.run( + on_connection_file_written=cast( + List[Callable[[], None]], self.on_bootstrap_complete + ) + ) else: # This process is running in frontend mode. Create the frontend runner and send # the appropriate request to the backend. - frontend = FrontendRunner(connection_file) if subcommand == "start": + frontend = FrontendRunner() adaptor_module = sys.modules.get(self.adaptor_class.__module__) if adaptor_module is None: raise ModuleNotFoundError( f"Adaptor module is not loaded: {self.adaptor_class.__module__}" ) - frontend.init( - adaptor_module, - integration_data.init_data, - integration_data.path_mapping_data, - reentry_exe, - ) + with contextlib.ExitStack() as stack: + if not connection_file: + tmpdir = stack.enter_context(tempfile.TemporaryDirectory(prefix="ojd-ar-")) + connection_file = Path(tmpdir) / "connection.json" + + frontend.init( + adaptor_module=adaptor_module, + connection_file_path=connection_file, + init_data=integration_data.init_data, + path_mapping_data=integration_data.path_mapping_data, + reentry_exe=reentry_exe, + ) frontend.start() - elif subcommand == "run": - frontend.run(integration_data.run_data) - elif subcommand == "stop": - frontend.stop() - frontend.shutdown() - - def _parse_args(self) -> Tuple[ArgumentParser, Namespace]: + else: + conn_settings_loader = ( + ConnectionSettingsFileLoader(connection_file) + if connection_file + else ConnectionSettingsEnvLoader() + ) + conn_settings = conn_settings_loader.load() + frontend = FrontendRunner(connection_settings=conn_settings) + if subcommand == "run": + frontend.run(integration_data.run_data) + elif subcommand == "stop": + frontend.stop() + frontend.shutdown() + + def _parse_args(self) -> Tuple[ArgumentParser, _ParsedArgs]: parser = self._build_argparser() try: - return parser, parser.parse_args(sys.argv[1:]) + parsed_args = parser.parse_args(sys.argv[1:], _ParsedArgs()) except Exception as e: _logger.error(f"Error parsing command line arguments: {e}") raise + else: + return parser, parsed_args def _build_argparser(self) -> ArgumentParser: parser = ArgumentParser( @@ -412,9 +510,15 @@ def _build_argparser(self) -> ArgumentParser: connection_file = ArgumentParser(add_help=False) connection_file.add_argument( "--connection-file", - default="", help=_CLI_HELP_TEXT["connection_file"], - required=True, + required=False, + ) + + log_file = ArgumentParser(add_help=False) + log_file.add_argument( + "--bootstrap-log-file", + help=_CLI_HELP_TEXT["log_file"], + required=False, ) bg_parser = subparser.add_parser("daemon", help="Runs the adaptor in a daemon mode.") @@ -427,8 +531,10 @@ def _build_argparser(self) -> ArgumentParser: ) # "Hidden" command that actually runs the adaptor runtime in background mode - bg_subparser.add_parser("_serve", parents=[init_data, path_mapping_rules, connection_file]) - + bg_subparser.add_parser( + "_serve", + parents=[init_data, path_mapping_rules, connection_file, log_file], + ) bg_subparser.add_parser("start", parents=[init_data, path_mapping_rules, connection_file]) bg_subparser.add_parser("run", parents=[run_data, connection_file]) bg_subparser.add_parser("stop", parents=[connection_file]) diff --git a/src/openjd/adaptor_runtime/_http/__init__.py b/src/openjd/adaptor_runtime/_http/__init__.py index 834c7e8..70b2f0d 100644 --- a/src/openjd/adaptor_runtime/_http/__init__.py +++ b/src/openjd/adaptor_runtime/_http/__init__.py @@ -1,6 +1,6 @@ # Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. from .request_handler import HTTPResponse, RequestHandler, ResourceRequestHandler -from .sockets import SocketDirectories +from .sockets import SocketPaths -__all__ = ["HTTPResponse", "RequestHandler", "ResourceRequestHandler", "SocketDirectories"] +__all__ = ["HTTPResponse", "RequestHandler", "ResourceRequestHandler", "SocketPaths"] diff --git a/src/openjd/adaptor_runtime/_http/sockets.py b/src/openjd/adaptor_runtime/_http/sockets.py index 58e3e34..c4c9011 100644 --- a/src/openjd/adaptor_runtime/_http/sockets.py +++ b/src/openjd/adaptor_runtime/_http/sockets.py @@ -16,37 +16,44 @@ # Max PID on 64-bit systems is 4194304 (2^22) _PID_MAX_LENGTH = 7 -_PID_MAX_LENGTH_PADDED = _PID_MAX_LENGTH + 1 # 1 char for path seperator -class SocketDirectories(abc.ABC): +class SocketPaths(abc.ABC): """ - Base class for determining the base directory for sockets used in the Adaptor Runtime. + Base class for determining the paths for sockets used in the Adaptor Runtime. """ @staticmethod def for_os(osname: OSName = OSName()): # pragma: no cover - """_summary_ + """ + Gets the SocketPaths class for a specific OS. Args: - osname (OSName, optional): The OS to get socket directories for. + osname (OSName, optional): The OS to get socket paths for. Defaults to the current OS. Raises: UnsupportedPlatformException: Raised when this class is requested for an unsupported platform. """ - klass = _get_socket_directories_cls(osname) + klass = _get_socket_paths_cls(osname) if not klass: raise UnsupportedPlatformException(osname) return klass() - def get_process_socket_path(self, namespace: str | None = None, *, create_dir: bool = False): + def get_process_socket_path( + self, + namespace: str | None = None, + *, + base_dir: str | None = None, + create_dir: bool = False, + ): """ Gets the path for this process' socket in the given namespace. Args: namespace (Optional[str]): The optional namespace (subdirectory) where the sockets go. + base_dir (Optional[str]): The base directory to create sockets in. Defaults to the temp directory. create_dir (bool): Whether to create the socket directory. Default is false. Raises: @@ -60,15 +67,29 @@ def get_process_socket_path(self, namespace: str | None = None, *, create_dir: b len(socket_name) <= _PID_MAX_LENGTH ), f"PID too long. Only PIDs up to {_PID_MAX_LENGTH} digits are supported." - return os.path.join(self.get_socket_dir(namespace, create=create_dir), socket_name) + return self.get_socket_path( + socket_name, + namespace, + base_dir=base_dir, + create_dir=create_dir, + ) - def get_socket_dir(self, namespace: str | None = None, *, create: bool = False) -> str: + def get_socket_path( + self, + base_socket_name: str, + namespace: str | None = None, + *, + base_dir: str | None = None, + create_dir: bool = False, + ) -> str: """ - Gets the base directory for sockets used in Adaptor IPC + Gets the path for a socket used in Adaptor IPC Args: + base_socket_name (str): The name of the socket namespace (Optional[str]): The optional namespace (subdirectory) where the sockets go - create (bool): Whether to create the directory or not. Default is false. + base_dir (Optional[str]): The base directory to create sockets in. Defaults to the temp directory. + create_dir (bool): Whether to create the directory or not. Default is false. Raises: NonvalidSocketPathException: Raised if the user has configured a socket base directory @@ -77,48 +98,38 @@ def get_socket_dir(self, namespace: str | None = None, *, create: bool = False) not be raised if the user has configured a socket base directory. """ - def create_dir(path: str) -> str: - if create: + def mkdir(path: str) -> str: + if create_dir: os.makedirs(path, mode=0o700, exist_ok=True) return path - rel_path = os.path.join(".openjd", "adaptors", "sockets") - if namespace: - rel_path = os.path.join(rel_path, namespace) - - reasons: list[str] = [] + def gen_socket_path(dir: str, base_name: str): + name = base_name + i = 0 + while os.path.exists(os.path.join(dir, name)): + i += 1 + name = f"{base_name}_{i}" + return os.path.join(dir, name) - # First try home directory - home_dir = os.path.expanduser("~") - socket_dir = os.path.join(home_dir, rel_path) - try: - self.verify_socket_path(socket_dir) - except NonvalidSocketPathException as e: - reasons.append(f"Cannot create sockets directory in the home directory because: {e}") + if not base_dir: + socket_dir = os.path.realpath(tempfile.gettempdir()) else: - return create_dir(socket_dir) + socket_dir = os.path.realpath(base_dir) - # Last resort is the temp directory - temp_dir = tempfile.gettempdir() - socket_dir = os.path.join(temp_dir, rel_path) + if namespace: + socket_dir = os.path.join(socket_dir, namespace) + + mkdir(socket_dir) + + socket_path = gen_socket_path(socket_dir, base_socket_name) try: - self.verify_socket_path(socket_dir) + self.verify_socket_path(socket_path) except NonvalidSocketPathException as e: - reasons.append(f"Cannot create sockets directory in the temp directory because: {e}") - else: - # Also check that the sticky bit is set on the temp dir - if not os.stat(temp_dir).st_mode & stat.S_ISVTX: - reasons.append( - f"Cannot use temporary directory {temp_dir} because it does not have the " - "sticky bit (restricted deletion flag) set" - ) - else: - return create_dir(socket_dir) + raise NoSocketPathFoundException( + f"Socket path '{socket_path}' failed verification: {e}" + ) from e - raise NoSocketPathFoundException( - "Failed to find a suitable base directory to create sockets in for the following " - f"reasons: {os.linesep.join(reasons)}" - ) + return socket_path @abc.abstractmethod def verify_socket_path(self, path: str) -> None: # pragma: no cover @@ -132,53 +143,86 @@ def verify_socket_path(self, path: str) -> None: # pragma: no cover pass -class LinuxSocketDirectories(SocketDirectories): +class WindowsSocketPaths(SocketPaths): + """ + Specialization for verifying socket paths on Windows systems. + """ + + def verify_socket_path(self, path: str) -> None: + # TODO: Verify Windows permissions of parent directories are least privileged + pass + + +class UnixSocketPaths(SocketPaths): + """ + Specialization for verifying socket paths on Unix systems. + """ + + def verify_socket_path(self, path: str) -> None: + # Walk up directories and check that the sticky bit is set if the dir is world writable + prev_path = path + curr_path = os.path.dirname(path) + while prev_path != curr_path and len(curr_path) > 0: + path_stat = os.stat(curr_path) + if path_stat.st_mode & stat.S_IWOTH and not path_stat.st_mode & stat.S_ISVTX: + raise NoSocketPathFoundException( + f"Cannot use directory {curr_path} because it is world writable and does not " + "have the sticky bit (restricted deletion flag) set" + ) + prev_path = curr_path + curr_path = os.path.dirname(curr_path) + + +class LinuxSocketPaths(UnixSocketPaths): """ Specialization for socket paths in Linux systems. """ # This is based on the max length of socket names to 108 bytes # See unix(7) under "Address format" - _socket_path_max_length = 108 - _socket_dir_max_length = _socket_path_max_length - _PID_MAX_LENGTH_PADDED + # In practice, only 107 bytes are accepted (one byte for null terminator) + _socket_name_max_length = 108 - 1 def verify_socket_path(self, path: str) -> None: + super().verify_socket_path(path) path_length = len(path.encode("utf-8")) - if path_length > self._socket_dir_max_length: + if path_length > self._socket_name_max_length: raise NonvalidSocketPathException( - "Socket base directory path too big. The maximum allowed size is " - f"{self._socket_dir_max_length} bytes, but the directory has a size of " + "Socket name too long. The maximum allowed size is " + f"{self._socket_name_max_length} bytes, but the name has a size of " f"{path_length}: {path}" ) -class MacOSSocketDirectories(SocketDirectories): +class MacOSSocketPaths(UnixSocketPaths): """ Specialization for socket paths in macOS systems. """ # This is based on the max length of socket names to 104 bytes # See https://github.com/apple-oss-distributions/xnu/blob/1031c584a5e37aff177559b9f69dbd3c8c3fd30a/bsd/sys/un.h#L79 - _socket_path_max_length = 104 - _socket_dir_max_length = _socket_path_max_length - _PID_MAX_LENGTH_PADDED + # In practice, only 103 bytes are accepted (one byte for null terminator) + _socket_name_max_length = 104 - 1 def verify_socket_path(self, path: str) -> None: + super().verify_socket_path(path) path_length = len(path.encode("utf-8")) - if path_length > self._socket_dir_max_length: + if path_length > self._socket_name_max_length: raise NonvalidSocketPathException( - "Socket base directory path too big. The maximum allowed size is " - f"{self._socket_dir_max_length} bytes, but the directory has a size of " + "Socket name too long. The maximum allowed size is " + f"{self._socket_name_max_length} bytes, but the name has a size of " f"{path_length}: {path}" ) -_os_map: dict[str, type[SocketDirectories]] = { - OSName.LINUX: LinuxSocketDirectories, - OSName.MACOS: MacOSSocketDirectories, +_os_map: dict[str, type[SocketPaths]] = { + OSName.LINUX: LinuxSocketPaths, + OSName.MACOS: MacOSSocketPaths, + OSName.WINDOWS: WindowsSocketPaths, } -def _get_socket_directories_cls( +def _get_socket_paths_cls( osname: OSName, -) -> type[SocketDirectories] | None: # pragma: no cover +) -> type[SocketPaths] | None: # pragma: no cover return _os_map.get(osname, None) diff --git a/src/openjd/adaptor_runtime/_utils/_constants.py b/src/openjd/adaptor_runtime/_utils/_constants.py new file mode 100644 index 0000000..b7c5bc1 --- /dev/null +++ b/src/openjd/adaptor_runtime/_utils/_constants.py @@ -0,0 +1,13 @@ +# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. + +import re + +_OPENJD_LOG_PATTERN = r"^openjd_\S+: " +_OPENJD_LOG_REGEX = re.compile(_OPENJD_LOG_PATTERN) + +_OPENJD_FAIL_STDOUT_PREFIX = "openjd_fail: " +_OPENJD_PROGRESS_STDOUT_PREFIX = "openjd_progress: " +_OPENJD_STATUS_STDOUT_PREFIX = "openjd_status: " +_OPENJD_ENV_STDOUT_PREFIX = "openjd_env: " + +_OPENJD_ADAPTOR_SOCKET_ENV = "OPENJD_ADAPTOR_SOCKET" diff --git a/src/openjd/adaptor_runtime/_utils/_logging.py b/src/openjd/adaptor_runtime/_utils/_logging.py index 8c4d6e3..2a10f01 100644 --- a/src/openjd/adaptor_runtime/_utils/_logging.py +++ b/src/openjd/adaptor_runtime/_utils/_logging.py @@ -6,9 +6,6 @@ Optional, ) -_OPENJD_LOG_PATTERN = r"^openjd_\S+: " -_OPENJD_LOG_REGEX = re.compile(_OPENJD_LOG_PATTERN) - class ConditionalFormatter(logging.Formatter): """ diff --git a/src/openjd/adaptor_runtime/_utils/_secure_open.py b/src/openjd/adaptor_runtime/_utils/_secure_open.py index a44d246..47269de 100644 --- a/src/openjd/adaptor_runtime/_utils/_secure_open.py +++ b/src/openjd/adaptor_runtime/_utils/_secure_open.py @@ -81,7 +81,7 @@ def get_file_owner_in_windows(filepath: "StrOrBytesPath") -> str: # pragma: is- Returns: str: A string in the format 'DOMAIN\\Username' representing the file's owner. """ - sd = win32security.GetFileSecurity(filepath, win32security.OWNER_SECURITY_INFORMATION) + sd = win32security.GetFileSecurity(str(filepath), win32security.OWNER_SECURITY_INFORMATION) owner_sid = sd.GetSecurityDescriptorOwner() name, domain, _ = win32security.LookupAccountSid(None, owner_sid) return f"{domain}\\{name}" @@ -108,13 +108,13 @@ def set_file_permissions_in_windows(filepath: "StrOrBytesPath") -> None: # prag dacl.AddAccessAllowedAce(win32security.ACL_REVISION, win32con.DELETE, user_sid) # Apply the DACL to the file - sd = win32security.GetFileSecurity(filepath, win32security.DACL_SECURITY_INFORMATION) + sd = win32security.GetFileSecurity(str(filepath), win32security.DACL_SECURITY_INFORMATION) sd.SetSecurityDescriptorDacl( 1, # A flag that indicates the presence of a DACL in the security descriptor. dacl, # An ACL structure that specifies the DACL for the security descriptor. 0, # Don't retrieve the default DACL ) - win32security.SetFileSecurity(filepath, win32security.DACL_SECURITY_INFORMATION, sd) + win32security.SetFileSecurity(str(filepath), win32security.DACL_SECURITY_INFORMATION, sd) def _get_flags_from_mode_str(open_mode: str) -> int: diff --git a/src/openjd/adaptor_runtime/adaptors/_adaptor_runner.py b/src/openjd/adaptor_runtime/adaptors/_adaptor_runner.py index 3d3b1e8..4217364 100644 --- a/src/openjd/adaptor_runtime/adaptors/_adaptor_runner.py +++ b/src/openjd/adaptor_runtime/adaptors/_adaptor_runner.py @@ -6,13 +6,12 @@ from ._adaptor_states import AdaptorState, AdaptorStates from ._base_adaptor import BaseAdaptor as BaseAdaptor +from .._utils._constants import _OPENJD_FAIL_STDOUT_PREFIX __all__ = ["AdaptorRunner"] _logger = logging.getLogger(__name__) -_OPENJD_FAIL_STDOUT_PREFIX: str = "openjd_fail: " - class AdaptorRunner(AdaptorStates): """ diff --git a/src/openjd/adaptor_runtime/adaptors/_base_adaptor.py b/src/openjd/adaptor_runtime/adaptors/_base_adaptor.py index 55ce364..70912b3 100644 --- a/src/openjd/adaptor_runtime/adaptors/_base_adaptor.py +++ b/src/openjd/adaptor_runtime/adaptors/_base_adaptor.py @@ -13,6 +13,7 @@ from typing import Type from typing import TypeVar +from .._utils._constants import _OPENJD_PROGRESS_STDOUT_PREFIX, _OPENJD_STATUS_STDOUT_PREFIX from .configuration import AdaptorConfiguration, ConfigurationManager from .configuration._configuration_manager import ( create_adaptor_configuration_manager as create_adaptor_configuration_manager, @@ -54,8 +55,8 @@ class BaseAdaptor(AdaptorStates, Generic[_T]): Base class for adaptors. """ - _OPENJD_PROGRESS_STDOUT_PREFIX: str = "openjd_progress: " - _OPENJD_STATUS_STDOUT_PREFIX: str = "openjd_status: " + _OPENJD_PROGRESS_STDOUT_PREFIX: str = _OPENJD_PROGRESS_STDOUT_PREFIX + _OPENJD_STATUS_STDOUT_PREFIX: str = _OPENJD_STATUS_STDOUT_PREFIX def __init__( self, diff --git a/src/openjd/adaptor_runtime/application_ipc/_adaptor_server.py b/src/openjd/adaptor_runtime/application_ipc/_adaptor_server.py index 70a88af..d914ad9 100644 --- a/src/openjd/adaptor_runtime/application_ipc/_adaptor_server.py +++ b/src/openjd/adaptor_runtime/application_ipc/_adaptor_server.py @@ -8,7 +8,7 @@ from socketserver import UnixStreamServer # type: ignore[attr-defined] from typing import TYPE_CHECKING -from .._http import SocketDirectories +from .._http import SocketPaths from ._http_request_handler import AdaptorHTTPRequestHandler if TYPE_CHECKING: # pragma: no cover because pytest will think we should test for this. @@ -33,7 +33,10 @@ def __init__( actions_queue: ActionsQueue, adaptor: BaseAdaptor, ) -> None: # pragma: no cover - socket_path = SocketDirectories.for_os().get_process_socket_path("dcc", create_dir=True) + socket_path = SocketPaths.for_os().get_process_socket_path( + ".openjd_adaptor_server", + create_dir=True, + ) super().__init__(socket_path, AdaptorHTTPRequestHandler) self.actions_queue = actions_queue diff --git a/src/openjd/adaptor_runtime/application_ipc/_win_adaptor_server.py b/src/openjd/adaptor_runtime/application_ipc/_win_adaptor_server.py index d269260..0471d91 100644 --- a/src/openjd/adaptor_runtime/application_ipc/_win_adaptor_server.py +++ b/src/openjd/adaptor_runtime/application_ipc/_win_adaptor_server.py @@ -26,7 +26,11 @@ class WinAdaptorServer(NamedPipeServer): actions_queue: ActionsQueue adaptor: BaseAdaptor - def __init__(self, actions_queue: ActionsQueue, adaptor: BaseAdaptor) -> None: + def __init__( + self, + actions_queue: ActionsQueue, + adaptor: BaseAdaptor, + ) -> None: """ Adaptor Server class in Windows. diff --git a/test/openjd/adaptor_runtime/conftest.py b/test/openjd/adaptor_runtime/conftest.py index 83e282a..642d94a 100644 --- a/test/openjd/adaptor_runtime/conftest.py +++ b/test/openjd/adaptor_runtime/conftest.py @@ -1,12 +1,15 @@ # Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. import platform +import random +import string from typing import Generator +from unittest.mock import MagicMock, patch -from openjd.adaptor_runtime._osname import OSName -import string -import random import pytest +from openjd.adaptor_runtime._osname import OSName +from openjd.adaptor_runtime._http import sockets + if OSName.is_windows(): import win32net import win32netcon @@ -92,3 +95,13 @@ def delete_user() -> None: yield username, password # Delete the user after test completes delete_user() + + +@pytest.fixture(scope="session", autouse=OSName().is_macos()) +def mock_sockets_py_tempfile_gettempdir_to_slash_tmp() -> Generator[MagicMock, None, None]: + """ + Mock that is automatically used on Mac to override the tempfile.gettempdir() usages in sockets.py + because the folder returned by it on Mac is too long for the socket name to fit (max 104 bytes, see sockets.py) + """ + with patch.object(sockets.tempfile, "gettempdir", return_value="/tmp") as m: + yield m diff --git a/test/openjd/adaptor_runtime/integ/AdaptorExample/adaptor.py b/test/openjd/adaptor_runtime/integ/AdaptorExample/adaptor.py index 42a6a43..afbf414 100644 --- a/test/openjd/adaptor_runtime/integ/AdaptorExample/adaptor.py +++ b/test/openjd/adaptor_runtime/integ/AdaptorExample/adaptor.py @@ -45,6 +45,7 @@ def on_start(self) -> None: # This example initializes a server thread to interact with a client application, showing command exchange and # execution. _logger.info("on_start") + # Initialize the server thread to manage actions self.server = AdaptorServer( # actions_queue will be used for storing the actions. In the client application, it will keep polling the diff --git a/test/openjd/adaptor_runtime/integ/background/test_background_mode.py b/test/openjd/adaptor_runtime/integ/background/test_background_mode.py index f62ff90..974e736 100644 --- a/test/openjd/adaptor_runtime/integ/background/test_background_mode.py +++ b/test/openjd/adaptor_runtime/integ/background/test_background_mode.py @@ -20,7 +20,10 @@ from openjd.adaptor_runtime._background.frontend_runner import ( FrontendRunner, HTTPError, - _load_connection_settings, +) +from openjd.adaptor_runtime._background.loaders import ( + ConnectionSettingsLoadingError, + ConnectionSettingsFileLoader, ) from openjd.adaptor_runtime._osname import OSName @@ -51,7 +54,7 @@ def mock_runtime_logger_level(self, tmpdir: pathlib.Path): yield @pytest.fixture - def connection_file_path(self, tmp_path: pathlib.Path) -> str: + def connection_file_path(self, tmp_path: pathlib.Path) -> pathlib.Path: connection_dir = os.path.join(tmp_path.absolute(), "connection_dir") os.mkdir(connection_dir) if OSName.is_windows(): @@ -63,18 +66,20 @@ def connection_file_path(self, tmp_path: pathlib.Path) -> str: from openjd.adaptor_runtime._utils._secure_open import set_file_permissions_in_windows set_file_permissions_in_windows(connection_dir) - return os.path.join(connection_dir, "connection.json") + return pathlib.Path(connection_dir) / "connection.json" @pytest.fixture def initialized_setup( self, - connection_file_path: str, + connection_file_path: pathlib.Path, caplog: pytest.LogCaptureFixture, ) -> Generator[tuple[FrontendRunner, psutil.Process], None, None]: caplog.set_level(0) - frontend = FrontendRunner(connection_file_path, timeout_s=5.0) - frontend.init(sys.modules[AdaptorExample.__module__]) - conn_settings = _load_connection_settings(connection_file_path) + frontend = FrontendRunner(timeout_s=5.0) + frontend.init( + adaptor_module=sys.modules[AdaptorExample.__module__], + connection_file_path=connection_file_path, + ) match = re.search("Started backend process. PID: ([0-9]+)", caplog.text) assert match is not None @@ -93,14 +98,21 @@ def initialized_setup( # Once all handles are closed, the system automatically cleans up the named pipe. if OSName.is_posix(): try: - os.remove(conn_settings.socket) - except FileNotFoundError: - pass # Already deleted + conn_settings = ConnectionSettingsFileLoader(connection_file_path).load() + except ConnectionSettingsLoadingError as e: + print( + f"Failed to load connection settings, socket file cleanup will be skipped: {e}" + ) + else: + try: + os.remove(conn_settings.socket) + except FileNotFoundError: + pass # Already deleted def test_init( self, initialized_setup: tuple[FrontendRunner, psutil.Process], - connection_file_path: str, + connection_file_path: pathlib.Path, ) -> None: # GIVEN _, backend_proc = initialized_setup @@ -108,7 +120,7 @@ def test_init( # THEN assert os.path.exists(connection_file_path) - connection_settings = _load_connection_settings(connection_file_path) + connection_settings = ConnectionSettingsFileLoader(connection_file_path).load() if OSName.is_windows(): import pywintypes @@ -140,11 +152,11 @@ def test_init( def test_shutdown( self, initialized_setup: tuple[FrontendRunner, psutil.Process], - connection_file_path: str, + connection_file_path: pathlib.Path, ) -> None: # GIVEN frontend, backend_proc = initialized_setup - conn_settings = _load_connection_settings(connection_file_path) + conn_settings = ConnectionSettingsFileLoader(connection_file_path).load() # WHEN frontend.shutdown() @@ -153,7 +165,7 @@ def test_shutdown( assert all( [ _wait_for_file_deletion(p, timeout_s=1) - for p in [connection_file_path, conn_settings.socket] + for p in [str(connection_file_path), conn_settings.socket] ] ) diff --git a/test/openjd/adaptor_runtime/integ/test_integration_entrypoint.py b/test/openjd/adaptor_runtime/integ/test_integration_entrypoint.py index ca2eb05..d1690ff 100644 --- a/test/openjd/adaptor_runtime/integ/test_integration_entrypoint.py +++ b/test/openjd/adaptor_runtime/integ/test_integration_entrypoint.py @@ -3,6 +3,7 @@ from __future__ import annotations import json +import pathlib import os import re import sys @@ -32,7 +33,7 @@ class TestCommandAdaptorRun: """ def test_runs_command_adaptor( - self, capfd: pytest.CaptureFixture, caplog: pytest.LogCaptureFixture + self, capfd: pytest.CaptureFixture, caplog: pytest.LogCaptureFixture, tmp_path: pathlib.Path ): # GIVEN caplog.set_level(INFO) @@ -75,63 +76,170 @@ class TestCommandAdaptorDaemon: Tests for the CommandAdaptor running using the `daemon` command-line. """ - def test_start_stop(self, caplog: pytest.LogCaptureFixture, tmp_path: Path): - # GIVEN - caplog.set_level(INFO) - connection_file = tmp_path / "connection.json" - test_start_argv = [ - "program_filename.py", - "daemon", - "start", - "--connection-file", - str(connection_file), - "--init-data", - json.dumps( - { - "on_prerun": "on_prerun", - "on_postrun": "on_postrun", - } - ), - ] - test_stop_argv = [ - "program_filename.py", - "daemon", - "stop", - "--connection-file", - str(connection_file), - ] - entrypoint = EntryPoint(CommandAdaptorExample) + class TestUsingConnectionFile: + """ + Daemon tests using the --connection-file option + """ - # WHEN - with ( - patch.object(runtime_entrypoint.sys, "argv", test_start_argv), - patch.object(runtime_entrypoint.logging.Logger, "setLevel"), - ): - entrypoint.start() - with ( - patch.object(runtime_entrypoint.sys, "argv", test_stop_argv), - patch.object(runtime_entrypoint.logging.Logger, "setLevel"), - ): - entrypoint.start() + def test_start_stop(self, caplog: pytest.LogCaptureFixture, tmp_path: Path): + # GIVEN + caplog.set_level(INFO) + connection_file = tmp_path / "connection.json" + entrypoint = EntryPoint(CommandAdaptorExample) - # THEN - assert "Initializing backend process" in caplog.text - assert "Connected successfully" in caplog.text - assert "Running in background daemon mode." in caplog.text - assert "Daemon background process stopped." in caplog.text - assert "on_prerun" not in caplog.text - assert "on_postrun" not in caplog.text - - def test_run(self, caplog: pytest.LogCaptureFixture, tmp_path: Path): - # GIVEN - caplog.set_level(INFO) - connection_file = tmp_path / "connection.json" - test_start_argv = [ + # WHEN + with ( + patch.object( + runtime_entrypoint.sys, + "argv", + TestCommandAdaptorDaemon.get_start_argv(connection_file=connection_file), + ), + patch.object(runtime_entrypoint.logging.Logger, "setLevel"), + ): + entrypoint.start() + with ( + patch.object( + runtime_entrypoint.sys, + "argv", + TestCommandAdaptorDaemon.get_stop_argv(connection_file=connection_file), + ), + patch.object(runtime_entrypoint.logging.Logger, "setLevel"), + ): + entrypoint.start() + + # THEN + assert "Initializing backend process" in caplog.text + assert "Connected successfully" in caplog.text + assert "Running in background daemon mode." in caplog.text + assert "Daemon background process stopped." in caplog.text + assert "on_prerun" not in caplog.text + assert "on_postrun" not in caplog.text + + def test_run(self, caplog: pytest.LogCaptureFixture, tmp_path: Path): + # GIVEN + caplog.set_level(INFO) + connection_file = tmp_path / "connection.json" + test_run_argv = [ + "program_filename.py", + "daemon", + "run", + "--connection-file", + str(connection_file), + "--run-data", + json.dumps( + {"args": ["echo", "hello world"] if OSName.is_windows() else ["hello world"]} + ), + ] + entrypoint = EntryPoint(CommandAdaptorExample) + + # WHEN + with ( + patch.object( + runtime_entrypoint.sys, + "argv", + TestCommandAdaptorDaemon.get_start_argv(connection_file=connection_file), + ), + patch.object(runtime_entrypoint.logging.Logger, "setLevel"), + ): + entrypoint.start() + with ( + patch.object(runtime_entrypoint.sys, "argv", test_run_argv), + patch.object(runtime_entrypoint.logging.Logger, "setLevel"), + ): + entrypoint.start() + with ( + patch.object( + runtime_entrypoint.sys, + "argv", + TestCommandAdaptorDaemon.get_stop_argv(connection_file=connection_file), + ), + patch.object(runtime_entrypoint.logging.Logger, "setLevel"), + ): + entrypoint.start() + + # THEN + assert "on_prerun" in caplog.text + assert "hello world" in caplog.text + assert "on_postrun" in caplog.text + + class TestUsingEnvVar: + """ + Daemon tests that do not use the --connection-file option and instead use the + OPENJD_ADAPTOR_SOCKET environment variable + """ + + def test_full_cycle(self, caplog: pytest.LogCaptureFixture) -> None: + # GIVEN + caplog.set_level(INFO) + entrypoint = EntryPoint(CommandAdaptorExample) + + # WHEN + with ( + patch.object( + runtime_entrypoint.sys, + "argv", + TestCommandAdaptorDaemon.get_start_argv(connection_file=None), + ), + patch.object(runtime_entrypoint.logging.Logger, "setLevel"), + ): + entrypoint.start() + + # THEN + match = re.search( + "openjd_env: OPENJD_ADAPTOR_SOCKET=(.*)$", + caplog.text, + re.MULTILINE, + ) + assert ( + match is not None + ), f"Expected openjd_env statement not found in output: {caplog.text}" + openjd_adaptor_socket = match.group(1) + print( + f"DEBUG: Using OPENJD_ADAPTOR_SOCKET={openjd_adaptor_socket} (exists={os.path.exists(openjd_adaptor_socket)})" + ) + + # WHEN + with ( + patch.object( + runtime_entrypoint.sys, + "argv", + TestCommandAdaptorDaemon.get_run_argv(connection_file=None), + ), + patch.object(runtime_entrypoint.logging.Logger, "setLevel"), + patch.dict( + runtime_entrypoint.os.environ, {"OPENJD_ADAPTOR_SOCKET": openjd_adaptor_socket} + ), + ): + entrypoint.start() + with ( + patch.object( + runtime_entrypoint.sys, + "argv", + TestCommandAdaptorDaemon.get_stop_argv(connection_file=None), + ), + patch.object(runtime_entrypoint.logging.Logger, "setLevel"), + patch.dict( + runtime_entrypoint.os.environ, {"OPENJD_ADAPTOR_SOCKET": openjd_adaptor_socket} + ), + ): + entrypoint.start() + + # THEN + assert "Initializing backend process" in caplog.text + assert "Connected successfully" in caplog.text + assert "Running in background daemon mode." in caplog.text + assert "on_prerun" in caplog.text + assert "hello world" in caplog.text + assert "on_postrun" in caplog.text + assert "Daemon background process stopped." in caplog.text + + @staticmethod + def get_start_argv(*, connection_file: Path | None = None) -> list[str]: + return [ "program_filename.py", "daemon", "start", - "--connection-file", - str(connection_file), + *(["--connection-file", str(connection_file)] if connection_file else []), "--init-data", json.dumps( { @@ -140,44 +248,32 @@ def test_run(self, caplog: pytest.LogCaptureFixture, tmp_path: Path): } ), ] - test_run_argv = [ + + @staticmethod + def get_run_argv(*, connection_file: Path | None = None) -> list[str]: + return [ "program_filename.py", "daemon", "run", - "--connection-file", - str(connection_file), + *(["--connection-file", str(connection_file)] if connection_file else []), "--run-data", json.dumps( {"args": ["echo", "hello world"] if OSName.is_windows() else ["hello world"]} ), ] - test_stop_argv = [ + + @staticmethod + def get_stop_argv(*, connection_file: Path | None = None) -> list[str]: + return [ "program_filename.py", "daemon", "stop", - "--connection-file", - str(connection_file), + *( + [ + "--connection-file", + str(connection_file), + ] + if connection_file + else [] + ), ] - entrypoint = EntryPoint(CommandAdaptorExample) - - # WHEN - with ( - patch.object(runtime_entrypoint.sys, "argv", test_start_argv), - patch.object(runtime_entrypoint.logging.Logger, "setLevel"), - ): - entrypoint.start() - with ( - patch.object(runtime_entrypoint.sys, "argv", test_run_argv), - patch.object(runtime_entrypoint.logging.Logger, "setLevel"), - ): - entrypoint.start() - with ( - patch.object(runtime_entrypoint.sys, "argv", test_stop_argv), - patch.object(runtime_entrypoint.logging.Logger, "setLevel"), - ): - entrypoint.start() - - # THEN - assert "on_prerun" in caplog.text - assert "hello world" in caplog.text - assert "on_postrun" in caplog.text diff --git a/test/openjd/adaptor_runtime/unit/background/test_backend_runner.py b/test/openjd/adaptor_runtime/unit/background/test_backend_runner.py index c4788b7..1ce62a6 100644 --- a/test/openjd/adaptor_runtime/unit/background/test_backend_runner.py +++ b/test/openjd/adaptor_runtime/unit/background/test_backend_runner.py @@ -1,4 +1,5 @@ # Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. +from __future__ import annotations import json import os @@ -23,7 +24,7 @@ class TestBackendRunner: @pytest.fixture(autouse=True) def socket_path(self, tmp_path: pathlib.Path) -> Generator[str, None, None]: if OSName.is_posix(): - with patch.object(backend_runner.SocketDirectories, "get_process_socket_path") as mock: + with patch.object(backend_runner.SocketPaths, "get_process_socket_path") as mock: path = os.path.join(tmp_path, "socket", "1234") mock.return_value = path @@ -67,15 +68,17 @@ def test_run( ): # GIVEN caplog.set_level("DEBUG") - conn_file_path = "/path/to/conn_file" + conn_file = pathlib.Path(os.sep) / "path" / "to" / "conn_file" connection_settings = {"socket": socket_path} adaptor_runner = Mock() - runner = BackendRunner(adaptor_runner, conn_file_path) + runner = BackendRunner(adaptor_runner, connection_file_path=conn_file) # WHEN open_mock: MagicMock with patch.object( - backend_runner, "secure_open", mock_open(read_data=json.dumps(connection_settings)) + backend_runner, + "secure_open", + mock_open(read_data=json.dumps(connection_settings)), ) as open_mock: runner.run() @@ -93,7 +96,7 @@ def test_run( ) mock_thread.assert_called_once() mock_thread.return_value.start.assert_called_once() - open_mock.assert_called_once_with(conn_file_path, open_mode="w") + open_mock.assert_called_once_with(conn_file, open_mode="w") mock_json_dump.assert_called_once_with( ConnectionSettings(socket_path), open_mock.return_value, @@ -101,9 +104,9 @@ def test_run( ) mock_thread.return_value.join.assert_called_once() if OSName.is_posix(): - mock_os_remove.assert_has_calls([call(conn_file_path), call(socket_path)]) + mock_os_remove.assert_has_calls([call(conn_file), call(socket_path)]) else: - mock_os_remove.assert_has_calls([call(conn_file_path)]) + mock_os_remove.assert_has_calls([call(conn_file)]) def test_run_raises_when_http_server_fails_to_start( self, @@ -114,7 +117,10 @@ def test_run_raises_when_http_server_fails_to_start( caplog.set_level("DEBUG") exc = Exception() mock_server_cls.side_effect = exc - runner = BackendRunner(Mock(), "") + runner = BackendRunner( + Mock(), + connection_file_path=pathlib.Path(os.path.sep) / "tmp" / "connection.json", + ) # WHEN with pytest.raises(Exception) as raised_exc: @@ -144,9 +150,9 @@ def test_run_raises_when_writing_connection_file_fails( caplog.set_level("DEBUG") err = OSError() open_mock.side_effect = err - conn_file_path = "/path/to/conn_file" + conn_file = pathlib.Path(os.sep) / "path" / "to" / "conn_file" adaptor_runner = Mock() - runner = BackendRunner(adaptor_runner, conn_file_path) + runner = BackendRunner(adaptor_runner, connection_file_path=conn_file) # WHEN with pytest.raises(OSError) as raised_err: @@ -164,12 +170,12 @@ def test_run_raises_when_writing_connection_file_fails( ] mock_thread.assert_called_once() mock_thread.return_value.start.assert_called_once() - open_mock.assert_called_once_with(conn_file_path, open_mode="w") + open_mock.assert_called_once_with(conn_file, open_mode="w") mock_thread.return_value.join.assert_called_once() if OSName.is_posix(): - mock_os_remove.assert_has_calls([call(conn_file_path), call(socket_path)]) + mock_os_remove.assert_has_calls([call(conn_file), call(socket_path)]) else: - mock_os_remove.assert_has_calls([call(conn_file_path)]) + mock_os_remove.assert_has_calls([call(conn_file)]) @patch.object(backend_runner.signal, "signal") @patch.object(backend_runner.ServerResponseGenerator, "submit_task") @@ -178,9 +184,9 @@ def test_signal_hook(self, mock_submit, signal_mock: MagicMock) -> None: # as expected. # GIVEN - conn_file_path = "/path/to/conn_file" + conn_file_path = pathlib.Path(os.sep) / "path" / "to" / "conn_file" adaptor_runner = Mock() - runner = BackendRunner(adaptor_runner, conn_file_path) + runner = BackendRunner(adaptor_runner, connection_file_path=conn_file_path) server_mock = MagicMock() submit_mock = MagicMock() server_mock.submit = submit_mock diff --git a/test/openjd/adaptor_runtime/unit/background/test_frontend_runner.py b/test/openjd/adaptor_runtime/unit/background/test_frontend_runner.py index 1f2ecc5..9545ea5 100644 --- a/test/openjd/adaptor_runtime/unit/background/test_frontend_runner.py +++ b/test/openjd/adaptor_runtime/unit/background/test_frontend_runner.py @@ -4,32 +4,32 @@ import http.client as http_client import json +import os import re import signal import subprocess import sys +import tempfile from pathlib import Path from types import ModuleType from typing import Generator, Optional -from unittest.mock import MagicMock, PropertyMock, call, mock_open, patch +from unittest.mock import MagicMock, call, patch import pytest -import openjd.adaptor_runtime._background.frontend_runner as frontend_runner +from openjd.adaptor_runtime._background import frontend_runner from openjd.adaptor_runtime._osname import OSName from openjd.adaptor_runtime.adaptors import AdaptorState from openjd.adaptor_runtime._background.frontend_runner import ( AdaptorFailedException, FrontendRunner, HTTPError, - _load_connection_settings, - _wait_for_file, + _wait_for_connection_file, ) from openjd.adaptor_runtime._background.model import ( AdaptorStatus, BufferedOutput, ConnectionSettings, - DataclassMapper, HeartbeatResponse, ) @@ -43,17 +43,77 @@ class TestFrontendRunner: def server_name(self) -> str: return "/path/to/socket" if OSName.is_posix() else r"\\.\pipe\TestPipe" + @pytest.fixture + def connection_settings(self, server_name: str) -> ConnectionSettings: + return ConnectionSettings(server_name) + @pytest.fixture(autouse=True) - def mock_connection_settings(self, server_name: str) -> Generator[MagicMock, None, None]: - with patch.object(FrontendRunner, "connection_settings", new_callable=PropertyMock) as mock: - mock.return_value = ConnectionSettings(server_name) - yield mock + def mock_connection_settings( + self, connection_settings: ConnectionSettings + ) -> Generator[None, None, None]: + with patch.object( + FrontendRunner, + "connection_settings", + return_value=connection_settings, + create=True, + ): + yield class TestInit: """ Tests for the FrontendRunner.init method """ + @pytest.fixture(autouse=True) + def open_mock(self) -> Generator[MagicMock, None, None]: + with patch.object(frontend_runner, "open") as m: + yield m + + @pytest.fixture(autouse=True) + def mock_path_exists(self) -> Generator[MagicMock, None, None]: + with patch.object(frontend_runner.Path, "exists") as m: + yield m + + @pytest.fixture(autouse=True) + def mock_uuid(self) -> Generator[MagicMock, None, None]: + with patch.object(frontend_runner.uuid, "uuid4") as m: + yield m + + @pytest.fixture(autouse=True) + def mock_gettempdir(self) -> Generator[MagicMock, None, None]: + with patch.object(frontend_runner.tempfile, "gettempdir") as m: + yield m + + @pytest.fixture(autouse=True) + def mock_Popen(self) -> Generator[MagicMock, None, None]: + with patch.object(frontend_runner.subprocess, "Popen") as m: + yield m + + @pytest.fixture(autouse=True) + def mock_wait_for_connection_file(self) -> Generator[MagicMock, None, None]: + with patch.object(frontend_runner, "_wait_for_connection_file") as m: + yield m + + @pytest.fixture(autouse=True) + def mock_heartbeat(self) -> Generator[MagicMock, None, None]: + with patch.object(frontend_runner.FrontendRunner, "_heartbeat") as m: + yield m + + @pytest.fixture(autouse=True) + def mock_sys_argv(self) -> Generator[MagicMock, None, None]: + with patch.object(frontend_runner.sys, "argv", return_value=[]) as m: + yield m + + @pytest.fixture(autouse=True) + def mock_sys_executable(self) -> Generator[MagicMock, None, None]: + with patch.object(frontend_runner.sys, "executable", return_value="executable") as m: + yield m + + @pytest.fixture(autouse=True) + def mock_connection_settings_file_load(self) -> Generator[MagicMock, None, None]: + with patch.object(frontend_runner.ConnectionSettingsFileLoader, "load") as m: + yield m + @pytest.mark.parametrize( argnames="reentry_exe", argvalues=[ @@ -61,51 +121,53 @@ class TestInit: (Path("reeentry_exe_value"),), ], ) - @patch.object(frontend_runner.sys, "argv") - @patch.object(frontend_runner.sys, "executable") - @patch.object(frontend_runner.json, "dumps") - @patch.object(FrontendRunner, "_heartbeat") - @patch.object(frontend_runner, "_wait_for_file") - @patch.object(frontend_runner.subprocess, "Popen") - @patch.object(frontend_runner.os.path, "exists") def test_initializes_backend_process( self, - mock_exists: MagicMock, + mock_path_exists: MagicMock, mock_Popen: MagicMock, - mock_wait_for_file: MagicMock, + mock_wait_for_connection_file: MagicMock, mock_heartbeat: MagicMock, - mock_json_dumps: MagicMock, mock_sys_executable: MagicMock, mock_sys_argv: MagicMock, + mock_uuid: MagicMock, + open_mock: MagicMock, caplog: pytest.LogCaptureFixture, reentry_exe: Optional[Path], ): # GIVEN caplog.set_level("DEBUG") - mock_json_dumps.return_value = "test" - mock_exists.return_value = False + mock_path_exists.return_value = False pid = 123 mock_Popen.return_value.pid = pid mock_sys_executable.return_value = "executable" mock_sys_argv.return_value = [] adaptor_module = ModuleType("") adaptor_module.__package__ = "package" - conn_file_path = "/path" init_data = {"init": "data"} path_mapping_data: dict = {} - runner = FrontendRunner(conn_file_path) + connection_file_path = Path("connection.test") + runner = FrontendRunner() # WHEN - runner.init(adaptor_module, init_data, path_mapping_data, reentry_exe) + runner.init( + adaptor_module=adaptor_module, + connection_file_path=connection_file_path, + init_data=init_data, + path_mapping_data=path_mapping_data, + reentry_exe=reentry_exe, + ) # THEN - assert caplog.messages == [ - "Initializing backend process...", - f"Started backend process. PID: {pid}", - "Verifying connection to backend...", - "Connected successfully", - ] - mock_exists.assert_called_once_with(conn_file_path) + assert all( + m in caplog.messages + for m in [ + "Initializing backend process...", + f"Started backend process. PID: {pid}", + "Verifying connection to backend...", + "Connected successfully", + ] + ) + mock_path_exists.assert_called_once_with() if reentry_exe is None: expected_args = [ sys.executable, @@ -113,78 +175,95 @@ def test_initializes_backend_process( adaptor_module.__package__, "daemon", "_serve", - "--connection-file", - conn_file_path, "--init-data", json.dumps(init_data), "--path-mapping-rules", json.dumps(path_mapping_data), + "--connection-file", + str(connection_file_path), ] else: expected_args = [ str(reentry_exe), "daemon", "_serve", - "--connection-file", - conn_file_path, "--init-data", json.dumps(init_data), "--path-mapping-rules", json.dumps(path_mapping_data), + "--connection-file", + str(connection_file_path), + ] + expected_args.extend( + [ + "--bootstrap-log-file", + os.path.join( + tempfile.gettempdir(), + f"adaptor-runtime-background-bootstrap-{mock_uuid.return_value}.log", + ), ] + ) mock_Popen.assert_called_once_with( expected_args, shell=False, close_fds=True, start_new_session=True, stdin=subprocess.DEVNULL, - stdout=subprocess.DEVNULL, - stderr=subprocess.DEVNULL, + stdout=open_mock.return_value, + stderr=open_mock.return_value, + ) + mock_wait_for_connection_file.assert_called_once_with( + str(connection_file_path), + max_retries=5, + interval_s=1, ) - mock_wait_for_file.assert_called_once_with(conn_file_path, timeout_s=5) mock_heartbeat.assert_called_once() def test_raises_when_adaptor_module_not_package(self): # GIVEN adaptor_module = ModuleType("") adaptor_module.__package__ = None - runner = FrontendRunner("") + runner = FrontendRunner() # WHEN with pytest.raises(Exception) as raised_exc: - runner.init(adaptor_module) + runner.init( + adaptor_module=adaptor_module, + connection_file_path="/path", + ) # THEN assert raised_exc.match(f"Adaptor module is not a package: {adaptor_module}") - @patch.object(frontend_runner.os.path, "exists") def test_raises_when_connection_file_exists( self, - mock_exists: MagicMock, + mock_path_exists: MagicMock, ): # GIVEN - mock_exists.return_value = True + mock_path_exists.return_value = True adaptor_module = ModuleType("") adaptor_module.__package__ = "package" - conn_file_path = "/path" - runner = FrontendRunner(conn_file_path) + conn_file_path = Path("/path") + runner = FrontendRunner() # WHEN with pytest.raises(FileExistsError) as raised_err: - runner.init(adaptor_module) + runner.init( + adaptor_module=adaptor_module, + connection_file_path=conn_file_path, + ) # THEN assert raised_err.match( - "Cannot init a new backend process with an existing connection file at: " - + conn_file_path + re.escape( + f"Cannot init a new backend process with an existing connection file at: {conn_file_path}" + ) ) - mock_exists.assert_called_once_with(conn_file_path) + mock_path_exists.assert_called_once_with() - @patch.object(frontend_runner.subprocess, "Popen") - @patch.object(frontend_runner.os.path, "exists") def test_raises_when_failed_to_create_backend_process( self, - mock_exists: MagicMock, + mock_path_exists: MagicMock, mock_Popen: MagicMock, caplog: pytest.LogCaptureFixture, ): @@ -192,102 +271,123 @@ def test_raises_when_failed_to_create_backend_process( caplog.set_level("DEBUG") exc = Exception() mock_Popen.side_effect = exc - mock_exists.return_value = False + mock_path_exists.return_value = False adaptor_module = ModuleType("") adaptor_module.__package__ = "package" - conn_file_path = "/path" - runner = FrontendRunner(conn_file_path) + conn_file_path = Path("/path") + runner = FrontendRunner() # WHEN with pytest.raises(Exception) as raised_exc: - runner.init(adaptor_module) + runner.init( + adaptor_module=adaptor_module, + connection_file_path=conn_file_path, + ) # THEN assert raised_exc.value is exc - assert caplog.messages == [ - "Initializing backend process...", - "Failed to initialize backend process: ", - ] - mock_exists.assert_called_once_with(conn_file_path) + assert all( + m in caplog.messages + for m in [ + "Initializing backend process...", + "Failed to initialize backend process: ", + ] + ) + mock_path_exists.assert_called_once_with() mock_Popen.assert_called_once() - @patch.object(frontend_runner, "_wait_for_file") - @patch.object(frontend_runner.subprocess, "Popen") - @patch.object(frontend_runner.os.path, "exists") def test_raises_when_connection_file_wait_times_out( self, - mock_exists: MagicMock, + mock_path_exists: MagicMock, mock_Popen: MagicMock, - mock_wait_for_file: MagicMock, + mock_wait_for_connection_file: MagicMock, caplog: pytest.LogCaptureFixture, ): # GIVEN caplog.set_level("DEBUG") err = TimeoutError() - mock_wait_for_file.side_effect = err - mock_exists.return_value = False + mock_wait_for_connection_file.side_effect = err + mock_path_exists.return_value = False pid = 123 mock_Popen.return_value.pid = pid adaptor_module = ModuleType("") adaptor_module.__package__ = "package" - conn_file_path = "/path" - runner = FrontendRunner(conn_file_path) + conn_file_path = Path("/path") + runner = FrontendRunner() # WHEN with pytest.raises(TimeoutError) as raised_err: - runner.init(adaptor_module) + runner.init( + adaptor_module=adaptor_module, + connection_file_path=conn_file_path, + ) # THEN assert raised_err.value is err - print(caplog.messages) - assert caplog.messages == [ - "Initializing backend process...", - f"Started backend process. PID: {pid}", - f"Backend process failed to write connection file in time at: {conn_file_path}", - ] - mock_exists.assert_called_once_with(conn_file_path) + assert all( + m in caplog.messages + for m in [ + "Initializing backend process...", + f"Started backend process. PID: {pid}", + f"Backend process failed to write connection file in time at: {conn_file_path}", + ] + ) + mock_path_exists.assert_called_once_with() mock_Popen.assert_called_once() - mock_wait_for_file.assert_called_once_with(conn_file_path, timeout_s=5) + mock_wait_for_connection_file.assert_called_once_with( + str(conn_file_path), + max_retries=5, + interval_s=1, + ) class TestHeartbeat: """ Tests for the FrontendRunner._heartbeat method """ - @patch.object(frontend_runner.json, "load") - @patch.object(DataclassMapper, "map") - @patch.object(FrontendRunner, "_send_request") + @pytest.fixture(autouse=True) + def mock_json_load(self) -> Generator[MagicMock, None, None]: + with patch.object(frontend_runner.json, "load") as m: + yield m + + @pytest.fixture(autouse=True) + def mock_dataclass_mapper_map(self) -> Generator[MagicMock, None, None]: + with patch.object(frontend_runner.DataclassMapper, "map") as m: + yield m + + @pytest.fixture(autouse=True) + def mock_send_request(self) -> Generator[MagicMock, None, None]: + with patch.object(frontend_runner.FrontendRunner, "_send_request") as m: + yield m + def test_sends_heartbeat( self, mock_send_request: MagicMock, - mock_map: MagicMock, + mock_dataclass_mapper_map: MagicMock, mock_json_load: MagicMock, ): # GIVEN if OSName.is_windows(): mock_send_request.return_value = {"body": '{"key1": "value1"}'} mock_response = mock_send_request.return_value - runner = FrontendRunner("") + runner = FrontendRunner() # WHEN response = runner._heartbeat() # THEN - assert response is mock_map.return_value + assert response is mock_dataclass_mapper_map.return_value if OSName.is_posix(): mock_json_load.assert_called_once_with(mock_response.fp) - mock_map.assert_called_once_with(mock_json_load.return_value) + mock_dataclass_mapper_map.assert_called_once_with(mock_json_load.return_value) else: - mock_map.assert_called_once_with({"key1": "value1"}) + mock_dataclass_mapper_map.assert_called_once_with({"key1": "value1"}) mock_send_request.assert_called_once_with("GET", "/heartbeat", params=None) - @patch.object(frontend_runner.json, "load") - @patch.object(DataclassMapper, "map") - @patch.object(FrontendRunner, "_send_request") def test_sends_heartbeat_with_ack_id( self, mock_send_request: MagicMock, - mock_map: MagicMock, + mock_dataclass_mapper_map: MagicMock, mock_json_load: MagicMock, ): # GIVEN @@ -295,18 +395,18 @@ def test_sends_heartbeat_with_ack_id( if OSName.is_windows(): mock_send_request.return_value = {"body": '{"key1": "value1"}'} mock_response = mock_send_request.return_value - runner = FrontendRunner("") + runner = FrontendRunner() # WHEN response = runner._heartbeat(ack_id) # THEN - assert response is mock_map.return_value + assert response is mock_dataclass_mapper_map.return_value if OSName.is_posix(): mock_json_load.assert_called_once_with(mock_response.fp) - mock_map.assert_called_once_with(mock_json_load.return_value) + mock_dataclass_mapper_map.assert_called_once_with(mock_json_load.return_value) else: - mock_map.assert_called_once_with({"key1": "value1"}) + mock_dataclass_mapper_map.assert_called_once_with({"key1": "value1"}) mock_send_request.assert_called_once_with( "GET", "/heartbeat", params={"ack_id": ack_id} ) @@ -338,7 +438,7 @@ def test_heartbeats_until_complete( mock_event.wait = MagicMock() mock_event.is_set = MagicMock(return_value=False) heartbeat_interval = 1 - runner = FrontendRunner("", heartbeat_interval=heartbeat_interval) + runner = FrontendRunner(heartbeat_interval=heartbeat_interval) # WHEN runner._heartbeat_until_state_complete(state) @@ -367,7 +467,7 @@ def test_raises_when_adaptor_fails(self, mock_heartbeat: MagicMock) -> None: failed=False, ), ] - runner = FrontendRunner("") + runner = FrontendRunner() # WHEN with pytest.raises(AdaptorFailedException) as raised_exc: @@ -385,7 +485,7 @@ class TestShutdown: @patch.object(FrontendRunner, "_send_request") def test_sends_shutdown(self, mock_send_request: MagicMock): # GIVEN - runner = FrontendRunner("") + runner = FrontendRunner() # WHEN runner.shutdown() @@ -407,7 +507,7 @@ def test_sends_run( ): # GIVEN run_data = {"run": "data"} - runner = FrontendRunner("") + runner = FrontendRunner() # WHEN runner.run(run_data) @@ -429,7 +529,7 @@ def test_sends_start( mock_heartbeat_until_state_complete: MagicMock, ): # GIVEN - runner = FrontendRunner("") + runner = FrontendRunner() # WHEN runner.start() @@ -451,7 +551,7 @@ def test_sends_end( mock_heartbeat_until_state_complete: MagicMock, ): # GIVEN - runner = FrontendRunner("") + runner = FrontendRunner() # WHEN runner.stop() @@ -471,7 +571,7 @@ def test_sends_cancel( mock_send_request: MagicMock, ): # GIVEN - runner = FrontendRunner("") + runner = FrontendRunner() # WHEN runner.cancel() @@ -497,12 +597,16 @@ def mock_getresponse(self, mock_response: MagicMock) -> Generator[MagicMock, Non yield mock @patch.object(frontend_runner.UnixHTTPConnection, "request") - def test_sends_request(self, mock_request: MagicMock, mock_getresponse: MagicMock): + def test_sends_request( + self, + mock_request: MagicMock, + mock_getresponse: MagicMock, + connection_settings: ConnectionSettings, + ): # GIVEN method = "GET" path = "/path" - conn_file_path = "/conn/file/path" - runner = FrontendRunner(conn_file_path) + runner = FrontendRunner(connection_settings=connection_settings) # WHEN response = runner._send_request(method, path) @@ -521,6 +625,7 @@ def test_raises_when_request_fails( self, mock_request: MagicMock, mock_getresponse: MagicMock, + connection_settings: ConnectionSettings, caplog: pytest.LogCaptureFixture, ): # GIVEN @@ -528,8 +633,7 @@ def test_raises_when_request_fails( mock_getresponse.side_effect = exc method = "GET" path = "/path" - conn_file_path = "/conn/file/path" - runner = FrontendRunner(conn_file_path) + runner = FrontendRunner(connection_settings=connection_settings) # WHEN with pytest.raises(http_client.HTTPException) as raised_exc: @@ -551,6 +655,7 @@ def test_raises_when_error_response_received( mock_request: MagicMock, mock_getresponse: MagicMock, mock_response: MagicMock, + connection_settings: ConnectionSettings, caplog: pytest.LogCaptureFixture, ): # GIVEN @@ -558,8 +663,7 @@ def test_raises_when_error_response_received( mock_response.reason = "Something went wrong" method = "GET" path = "/path" - conn_file_path = "/conn/file/path" - runner = FrontendRunner(conn_file_path) + runner = FrontendRunner(connection_settings=connection_settings) # WHEN with pytest.raises(HTTPError) as raised_err: @@ -579,13 +683,17 @@ def test_raises_when_error_response_received( mock_getresponse.assert_called_once() @patch.object(frontend_runner.UnixHTTPConnection, "request") - def test_formats_query_string(self, mock_request: MagicMock, mock_getresponse: MagicMock): + def test_formats_query_string( + self, + mock_request: MagicMock, + mock_getresponse: MagicMock, + connection_settings: ConnectionSettings, + ): # GIVEN method = "GET" path = "/path" - conn_file_path = "/conn/file/path" params = {"first param": 1, "second_param": ["one", "two three"]} - runner = FrontendRunner(conn_file_path) + runner = FrontendRunner(connection_settings=connection_settings) # WHEN response = runner._send_request(method, path, params=params) @@ -600,13 +708,17 @@ def test_formats_query_string(self, mock_request: MagicMock, mock_getresponse: M assert response is mock_getresponse.return_value @patch.object(frontend_runner.UnixHTTPConnection, "request") - def test_sends_body(self, mock_request: MagicMock, mock_getresponse: MagicMock): + def test_sends_body( + self, + mock_request: MagicMock, + mock_getresponse: MagicMock, + connection_settings: ConnectionSettings, + ): # GIVEN method = "GET" path = "/path" - conn_file_path = "/conn/file/path" json = {"the": "body"} - runner = FrontendRunner(conn_file_path) + runner = FrontendRunner(connection_settings=connection_settings) # WHEN response = runner._send_request(method, path, json_body=json) @@ -638,17 +750,23 @@ def mock_read_from_pipe(self, mock_response: MagicMock) -> Generator[MagicMock, mock_read_from_pipe.return_value = mock_response yield mock_read_from_pipe + @pytest.fixture + def connection_settings(self) -> ConnectionSettings: + return ConnectionSettings("\\\\.\\pipe") + + @pytest.fixture + def runner(self, connection_settings: ConnectionSettings) -> FrontendRunner: + return FrontendRunner(connection_settings=connection_settings) + def test_sends_request( self, mock_read_from_pipe: MagicMock, mock_response: str, + runner: FrontendRunner, ): # GIVEN method = "GET" path = "/path" - conn_file_path = r"C:\conn\file\path" - - runner = FrontendRunner(conn_file_path) # WHEN with patch.object( @@ -670,6 +788,7 @@ def test_raises_when_request_fails( self, mock_read_from_pipe: MagicMock, mock_response: str, + runner: FrontendRunner, caplog: pytest.LogCaptureFixture, ): # GIVEN @@ -679,8 +798,6 @@ def test_raises_when_request_fails( mock_read_from_pipe.side_effect = error_instance method = "GET" path = "/path" - conn_file_path = r"C:\conn\file\path" - runner = FrontendRunner(conn_file_path) # WHEN with patch.object( @@ -703,13 +820,12 @@ def test_raises_when_request_fails( def test_raises_when_error_response_received( self, mock_response: str, + runner: FrontendRunner, caplog: pytest.LogCaptureFixture, ): # GIVEN method = "GET" path = "/path" - conn_file_path = r"C:\conn\file\path" - runner = FrontendRunner(conn_file_path) # WHEN with patch.object( @@ -740,14 +856,13 @@ def test_formats_query_string( self, mock_read_from_pipe, mock_response: str, + runner: FrontendRunner, caplog: pytest.LogCaptureFixture, ): # GIVEN method = "GET" path = "/path" - conn_file_path = r"C:\conn\file\path" params = {"first param": 1, "second_param": ["one", "two three"]} - runner = FrontendRunner(conn_file_path) # WHEN with patch.object( @@ -770,14 +885,13 @@ def test_sends_body( self, mock_read_from_pipe, mock_response: str, + runner: FrontendRunner, caplog: pytest.LogCaptureFixture, ): # GIVEN method = "GET" path = "/path" - conn_file_path = r"C:\conn\file\path" json_body = {"the": "body"} - runner = FrontendRunner(conn_file_path) # WHEN with patch.object( @@ -804,8 +918,7 @@ def test_hook(self, signal_mock: MagicMock, cancel_mock: MagicMock) -> None: # as expected. # GIVEN - conn_file_path = "/path/to/conn_file" - runner = FrontendRunner(conn_file_path) + runner = FrontendRunner() # WHEN runner._sigint_handler(MagicMock(), MagicMock()) @@ -819,144 +932,57 @@ def test_hook(self, signal_mock: MagicMock, cancel_mock: MagicMock) -> None: cancel_mock.assert_called_once() -class TestLoadConnectionSettings: +class TestWaitForConnectionFile: """ - Tests for the _load_connection_settings method + Tests for the _wait_for_connection_file method """ - @patch.object(DataclassMapper, "map") - def test_loads_settings( - self, - mock_map: MagicMock, - ): - # GIVEN - filepath = "/path" - connection_settings = {"port": 123} - - # WHEN - with patch.object( - frontend_runner, "open", mock_open(read_data=json.dumps(connection_settings)) - ): - _load_connection_settings(filepath) - - # THEN - mock_map.assert_called_once_with(connection_settings) - + @patch.object(frontend_runner.ConnectionSettingsFileLoader, "load") @patch.object(frontend_runner, "open") - def test_raises_when_file_open_fails( - self, - open_mock: MagicMock, - caplog: pytest.LogCaptureFixture, - ): - # GIVEN - filepath = "/path" - err = OSError() - open_mock.side_effect = err - - # WHEN - with pytest.raises(OSError) as raised_err: - _load_connection_settings(filepath) - - # THEN - assert raised_err.value is err - assert "Failed to open connection file: " in caplog.text - - @patch.object(frontend_runner.json, "load") - def test_raises_when_json_decode_fails( - self, - mock_json_load: MagicMock, - caplog: pytest.LogCaptureFixture, - ): - # GIVEN - filepath = "/path" - err = json.JSONDecodeError("", "", 0) - mock_json_load.side_effect = err - - # WHEN - with pytest.raises(json.JSONDecodeError) as raised_err: - with patch.object(frontend_runner, "open", mock_open()): - _load_connection_settings(filepath) - - # THEN - assert raised_err.value is err - assert "Failed to decode connection file: " in caplog.text - - -class TestWaitForFile: - """ - Tests for the _wait_for_file method - """ - - @patch.object(frontend_runner, "open") - @patch.object(frontend_runner.time, "time") @patch.object(frontend_runner.time, "sleep") @patch.object(frontend_runner.os.path, "exists") def test_waits_for_file( self, mock_exists: MagicMock, mock_sleep: MagicMock, - mock_time: MagicMock, open_mock: MagicMock, + mock_conn_file_loader_load: MagicMock, ): # GIVEN filepath = "/path" - timeout = sys.float_info.max + max_retries = 9999 interval = 0.01 - mock_time.side_effect = [1, 2, 3, 4] mock_exists.side_effect = [False, True] err = IOError() open_mock.side_effect = [err, MagicMock()] + mock_conn_file_loader_load.return_value = ConnectionSettings("/server") # WHEN - _wait_for_file(filepath, timeout, interval) + _wait_for_connection_file(filepath, max_retries, interval) # THEN - assert mock_time.call_count == 4 mock_exists.assert_has_calls([call(filepath)] * 2) mock_sleep.assert_has_calls([call(interval)] * 3) open_mock.assert_has_calls([call(filepath, mode="r")] * 2) - @patch.object(frontend_runner.time, "time") @patch.object(frontend_runner.time, "sleep") @patch.object(frontend_runner.os.path, "exists") - def test_raises_when_timeout_reached( + def test_raises_when_retries_reached( self, mock_exists: MagicMock, mock_sleep: MagicMock, - mock_time: MagicMock, ): # GIVEN filepath = "/path" - timeout = 0 + max_retries = 0 interval = 0.01 - mock_time.side_effect = [1, 2] - mock_exists.side_effect = [False] + mock_exists.return_value = False # WHEN with pytest.raises(TimeoutError) as raised_err: - _wait_for_file(filepath, timeout, interval) + _wait_for_connection_file(filepath, max_retries, interval) # THEN - assert raised_err.match(f"Timed out after {timeout}s waiting for file at {filepath}") - assert mock_time.call_count == 2 + assert raised_err.match(f"Timed out waiting for File '{filepath}' to exist") mock_exists.assert_called_once_with(filepath) mock_sleep.assert_not_called() - - -@patch.object(frontend_runner, "_load_connection_settings") -def test_connection_settings_lazy_loads(mock_load_connection_settings: MagicMock): - # GIVEN - filepath = "/path" - expected = ConnectionSettings("/socket") - mock_load_connection_settings.return_value = expected - runner = FrontendRunner(filepath) - - # Assert the internal connection settings var is not set yet - assert not hasattr(runner, "_connection_settings") - - # WHEN - actual = runner.connection_settings - - # THEN - assert actual is expected - assert runner._connection_settings is expected diff --git a/test/openjd/adaptor_runtime/unit/background/test_loaders.py b/test/openjd/adaptor_runtime/unit/background/test_loaders.py new file mode 100644 index 0000000..513221b --- /dev/null +++ b/test/openjd/adaptor_runtime/unit/background/test_loaders.py @@ -0,0 +1,148 @@ +# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. + +from __future__ import annotations + +import dataclasses +import pathlib +import json +import re +import typing +from unittest.mock import MagicMock, mock_open, patch + +import pytest + +from openjd.adaptor_runtime._background import loaders +from openjd.adaptor_runtime._background.loaders import ( + ConnectionSettingsEnvLoader, + ConnectionSettingsFileLoader, + ConnectionSettingsLoadingError, +) +from openjd.adaptor_runtime._background.model import ( + ConnectionSettings, +) + + +class TestConnectionSettingsFileLoader: + """ + Tests for the ConnectionsettingsFileLoader class + """ + + @pytest.fixture + def connection_settings(self) -> ConnectionSettings: + return ConnectionSettings(socket="socket") + + @pytest.fixture(autouse=True) + def open_mock( + self, connection_settings: ConnectionSettings + ) -> typing.Generator[MagicMock, None, None]: + with patch.object( + loaders, + "open", + mock_open(read_data=json.dumps(dataclasses.asdict(connection_settings))), + ) as m: + yield m + + @pytest.fixture + def connection_file_path(self) -> pathlib.Path: + return pathlib.Path("test") + + @pytest.fixture + def loader(self, connection_file_path: pathlib.Path) -> ConnectionSettingsFileLoader: + return ConnectionSettingsFileLoader(connection_file_path) + + def test_loads_settings( + self, + connection_settings: ConnectionSettings, + loader: ConnectionSettingsFileLoader, + ): + # WHEN + result = loader.load() + + # THEN + assert result == connection_settings + + def test_raises_when_file_open_fails( + self, + open_mock: MagicMock, + loader: ConnectionSettingsFileLoader, + connection_file_path: pathlib.Path, + caplog: pytest.LogCaptureFixture, + ): + # GIVEN + err = OSError() + open_mock.side_effect = err + + # WHEN + with pytest.raises(ConnectionSettingsLoadingError): + loader.load() + + # THEN + assert f"Failed to open connection file '{connection_file_path}': " in caplog.text + + def test_raises_when_json_decode_fails( + self, + loader: ConnectionSettingsFileLoader, + connection_file_path: pathlib.Path, + caplog: pytest.LogCaptureFixture, + ): + # GIVEN + err = json.JSONDecodeError("", "", 0) + + with patch.object(loaders.json, "load", side_effect=err): + with pytest.raises(ConnectionSettingsLoadingError): + # WHEN + loader.load() + + # THEN + assert f"Failed to decode connection file '{connection_file_path}': " in caplog.text + + +class TestConnectionSettingsEnvLoader: + @pytest.fixture + def connection_settings(self) -> ConnectionSettings: + return ConnectionSettings(socket="socket") + + @pytest.fixture + def mock_env(self, connection_settings: ConnectionSettings) -> dict[str, typing.Any]: + return { + env_name: getattr(connection_settings, attr_name) + for attr_name, (env_name, _) in ConnectionSettingsEnvLoader().env_map.items() + } + + @pytest.fixture(autouse=True) + def mock_os_environ( + self, mock_env: dict[str, typing.Any] + ) -> typing.Generator[dict, None, None]: + with patch.dict(loaders.os.environ, mock_env) as d: + yield d + + @pytest.fixture + def loader(self) -> ConnectionSettingsEnvLoader: + return ConnectionSettingsEnvLoader() + + def test_loads_connection_settings( + self, + loader: ConnectionSettingsEnvLoader, + connection_settings: ConnectionSettings, + ) -> None: + # WHEN + settings = loader.load() + + # THEN + assert connection_settings == settings + + def test_raises_error_when_required_not_provided( + self, + loader: ConnectionSettingsEnvLoader, + ) -> None: + # GIVEN + with patch.object(loaders.os.environ, "get", return_value=None): + with pytest.raises(loaders.ConnectionSettingsLoadingError) as raised_err: + # WHEN + loader.load() + + # THEN + assert re.match( + "^Required attribute '.*' does not have its corresponding environment variable '.*' set", + str(raised_err.value), + ) diff --git a/test/openjd/adaptor_runtime/unit/http/test_sockets.py b/test/openjd/adaptor_runtime/unit/http/test_sockets.py index f6e9c5f..972c461 100644 --- a/test/openjd/adaptor_runtime/unit/http/test_sockets.py +++ b/test/openjd/adaptor_runtime/unit/http/test_sockets.py @@ -1,73 +1,56 @@ # Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. import os +import pathlib import re import stat from typing import Generator -from unittest.mock import ANY, MagicMock, call, patch +from unittest.mock import MagicMock, patch import pytest import openjd.adaptor_runtime._http.sockets as sockets from openjd.adaptor_runtime._http.sockets import ( - LinuxSocketDirectories, - MacOSSocketDirectories, + LinuxSocketPaths, + MacOSSocketPaths, NonvalidSocketPathException, NoSocketPathFoundException, - SocketDirectories, + SocketPaths, + UnixSocketPaths, ) -class SocketDirectoriesStub(SocketDirectories): +class SocketPathsStub(SocketPaths): def verify_socket_path(self, path: str) -> None: pass -class TestSocketDirectories: +class TestSocketPaths: class TestGetProcessSocketPath: """ - Tests for SocketDirectories.get_process_socket_path() + Tests for SocketPaths.get_process_socket_path() """ - @pytest.fixture - def socket_dir(self) -> str: - return "/path/to/socket/dir" - - @pytest.fixture(autouse=True) - def mock_socket_dir(self, socket_dir: str) -> Generator[MagicMock, None, None]: - with patch.object(SocketDirectories, "get_socket_dir") as mock: - mock.return_value = socket_dir - yield mock - - @pytest.mark.parametrize( - argnames=["create_dir"], - argvalues=[[True], [False]], - ids=["creates dir", "does not create dir"], - ) @patch.object(sockets.os, "getpid", return_value=1234) def test_gets_path( self, mock_getpid: MagicMock, - socket_dir: str, - mock_socket_dir: MagicMock, - create_dir: bool, ) -> None: # GIVEN namespace = "my-namespace" - subject = SocketDirectoriesStub() + subject = SocketPathsStub() # WHEN - result = subject.get_process_socket_path(namespace, create_dir=create_dir) + result = subject.get_process_socket_path(namespace) # THEN - assert result == os.path.join(socket_dir, str(mock_getpid.return_value)) + assert result.endswith(os.path.join(namespace, str(mock_getpid.return_value))) mock_getpid.assert_called_once() - mock_socket_dir.assert_called_once_with(namespace, create=create_dir) @patch.object(sockets.os, "getpid", return_value="a" * (sockets._PID_MAX_LENGTH + 1)) def test_asserts_max_pid_length(self, mock_getpid: MagicMock): # GIVEN - subject = SocketDirectoriesStub() + subject = SocketPathsStub() # WHEN with pytest.raises(AssertionError) as raised_err: @@ -79,76 +62,39 @@ def test_asserts_max_pid_length(self, mock_getpid: MagicMock): ) mock_getpid.assert_called_once() - class TestGetSocketDir: + class TestGetSocketPath: """ - Tests for SocketDirectories.get_socket_dir() + Tests for SocketPaths.get_socket_path() """ @pytest.fixture(autouse=True) - def mock_makedirs(self) -> Generator[MagicMock, None, None]: - with patch.object(sockets.os, "makedirs") as mock: + def mock_exists(self) -> Generator[MagicMock, None, None]: + with patch.object(sockets.os.path, "exists") as mock: + mock.return_value = False yield mock - @pytest.fixture - def home_dir(self) -> str: - return os.path.join("home", "user") - @pytest.fixture(autouse=True) - def mock_expanduser(self, home_dir: str) -> Generator[MagicMock, None, None]: - with patch.object(sockets.os.path, "expanduser", return_value=home_dir) as mock: - yield mock - - @pytest.fixture - def temp_dir(self) -> str: - return "tmp" - - @pytest.fixture(autouse=True) - def mock_gettempdir(self, temp_dir: str) -> Generator[MagicMock, None, None]: - with patch.object(sockets.tempfile, "gettempdir", return_value=temp_dir) as mock: + def mock_makedirs(self) -> Generator[MagicMock, None, None]: + with patch.object(sockets.os, "makedirs") as mock: yield mock - def test_gets_home_dir( - self, - mock_expanduser: MagicMock, - home_dir: str, - ) -> None: - # GIVEN - subject = SocketDirectoriesStub() - - # WHEN - result = subject.get_socket_dir() - - # THEN - mock_expanduser.assert_called_once_with("~") - assert result.startswith(home_dir) - + @patch.object(sockets.tempfile, "gettempdir", wraps=sockets.tempfile.gettempdir) @patch.object(sockets.os, "stat") - @patch.object(SocketDirectoriesStub, "verify_socket_path") def test_gets_temp_dir( self, - mock_verify_socket_path: MagicMock, mock_stat: MagicMock, mock_gettempdir: MagicMock, - temp_dir: str, ) -> None: # GIVEN - exc = NonvalidSocketPathException() - mock_verify_socket_path.side_effect = [exc, None] # Raise exc only once mock_stat.return_value.st_mode = stat.S_ISVTX - subject = SocketDirectoriesStub() + subject = UnixSocketPaths() # WHEN - result = subject.get_socket_dir() + subject.get_socket_path("sock") # THEN mock_gettempdir.assert_called_once() - mock_verify_socket_path.assert_has_calls( - [ - call(ANY), # home dir - call(result), # temp dir - ] - ) - mock_stat.assert_called_once_with(temp_dir) + mock_stat.assert_called() @pytest.mark.parametrize( argnames=["create"], @@ -157,154 +103,205 @@ def test_gets_temp_dir( ) def test_create_dir(self, mock_makedirs: MagicMock, create: bool) -> None: # GIVEN - subject = SocketDirectoriesStub() + subject = SocketPathsStub() # WHEN - result = subject.get_socket_dir(create=create) + result = subject.get_socket_path("sock", create_dir=create) # THEN if create: - mock_makedirs.assert_called_once_with(result, mode=0o700, exist_ok=True) + mock_makedirs.assert_called_once_with( + os.path.dirname(result), mode=0o700, exist_ok=True + ) else: mock_makedirs.assert_not_called() + def test_uses_base_dir(self, tmp_path: pathlib.Path) -> None: + # GIVEN + subject = SocketPathsStub() + base_dir = str(tmp_path) + + # WHEN + result = subject.get_socket_path("sock", base_dir=base_dir) + + # THEN + assert result.startswith(base_dir) + def test_uses_namespace(self) -> None: # GIVEN namespace = "my-namespace" - subject = SocketDirectoriesStub() + subject = SocketPathsStub() # WHEN - result = subject.get_socket_dir(namespace) + result = subject.get_socket_path("sock", namespace) # THEN - assert result.endswith(namespace) + assert os.path.dirname(result).endswith(namespace) - @patch.object(SocketDirectoriesStub, "verify_socket_path") - def test_raises_when_no_valid_dir_found(self, mock_verify_socket_path: MagicMock) -> None: + @patch.object(SocketPathsStub, "verify_socket_path") + def test_raises_when_no_valid_path_found(self, mock_verify_socket_path: MagicMock) -> None: # GIVEN mock_verify_socket_path.side_effect = NonvalidSocketPathException() - subject = SocketDirectoriesStub() + subject = SocketPathsStub() # WHEN with pytest.raises(NoSocketPathFoundException) as raised_exc: - subject.get_socket_dir() + subject.get_socket_path("sock") # THEN - assert raised_exc.match( - "Failed to find a suitable base directory to create sockets in for the following " - "reasons: " - ) - assert mock_verify_socket_path.call_count == 2 + assert raised_exc.match("^Socket path '.*' failed verification: .*") + assert mock_verify_socket_path.call_count == 1 - @patch.object(SocketDirectoriesStub, "verify_socket_path") - @patch.object(sockets.os, "stat") - def test_raises_when_no_tmpdir_sticky_bit( + @patch.object(sockets.os.path, "exists") + def test_handles_socket_name_collisions( self, - mock_stat: MagicMock, - mock_verify_socket_path: MagicMock, - temp_dir: str, + mock_exists: MagicMock, ) -> None: # GIVEN - mock_verify_socket_path.side_effect = [NonvalidSocketPathException(), None] - mock_stat.return_value.st_mode = 0 - subject = SocketDirectoriesStub() + sock_name = "sock" + existing_sock_names = [sock_name, f"{sock_name}_1", f"{sock_name}_2"] + mock_exists.side_effect = ([True] * len(existing_sock_names)) + [False] + expected_sock_name = f"{sock_name}_3" + + subject = SocketPathsStub() # WHEN - with pytest.raises(NoSocketPathFoundException) as raised_exc: - subject.get_socket_dir() + result = subject.get_socket_path(sock_name) # THEN - assert raised_exc.match( - re.escape( - f"Cannot use temporary directory {temp_dir} because it does not have the " - "sticky bit (restricted deletion flag) set" - ) - ) + assert result.endswith(expected_sock_name) + assert mock_exists.call_count == (len(existing_sock_names) + 1) + +class TestUnixSocketPaths: + @pytest.fixture(autouse=True) + def mock_stat(self) -> Generator[MagicMock, None, None]: + with patch.object(sockets.os, "stat") as m: + yield m -class TestLinuxSocketDirectories: @pytest.mark.parametrize( - argnames=["path"], + argnames=["mode", "should_fail"], argvalues=[ - ["a"], - ["a" * 100], + [stat.S_IWOTH & stat.S_ISVTX, False], + [stat.S_IWOTH, True], + [stat.S_IROTH, False], + ], + ids=[ + "world-writable, sticky, passes", + "world-writable, not sticky, fails", + "world-readable, not sticky, passes", ], - ids=["one byte", "100 bytes"], ) - def test_accepts_paths_within_100_bytes(self, path: str): - """ - Verifies the function accepts paths up to 100 bytes (108 byte max - 8 byte padding - for socket name portion (path sep + PID)) - """ - # GIVEN - subject = LinuxSocketDirectories() - - try: - # WHEN - subject.verify_socket_path(path) - except NonvalidSocketPathException as e: - pytest.fail(f"verify_socket_path raised an error when it should not have: {e}") - else: - # THEN - pass # success + def test_verifies_all_parent_directories( + self, + mode: int, + should_fail: bool, + mock_stat: MagicMock, + ) -> None: + pass - def test_rejects_paths_over_100_bytes(self): + @patch.object(sockets.os, "stat") + def test_raises_when_no_tmpdir_sticky_bit( + self, + mock_stat: MagicMock, + ) -> None: # GIVEN - length = 101 - path = "a" * length - subject = LinuxSocketDirectories() + mock_stat.return_value.st_mode = stat.S_IWOTH + socket_name = "/sock" + subject = UnixSocketPaths() # WHEN - with pytest.raises(NonvalidSocketPathException) as raised_exc: - subject.verify_socket_path(path) + with pytest.raises(NoSocketPathFoundException) as raised_exc: + subject.verify_socket_path(socket_name) # THEN assert raised_exc.match( - "Socket base directory path too big. The maximum allowed size is " - f"{subject._socket_dir_max_length} bytes, but the directory has a size of " - f"{length}: {path}" + re.escape( + f"Cannot use directory {os.path.dirname(socket_name)} because it is world writable" + " and does not have the sticky bit (restricted deletion flag) set" + ) ) + class TestLinuxSocketPaths: + @pytest.mark.parametrize( + argnames=["path"], + argvalues=[ + ["a"], + ["a" * 107], + ], + ids=["one byte", "107 bytes"], + ) + def test_accepts_names_within_107_bytes(self, path: str): + """ + Verifies the function accepts paths up to 100 bytes (108 byte max - 1 byte null terminator) + """ + # GIVEN + subject = LinuxSocketPaths() -class TestMacOSSocketDirectories: - @pytest.mark.parametrize( - argnames=["path"], - argvalues=[ - ["a"], - ["a" * 96], - ], - ids=["one byte", "96 bytes"], - ) - def test_accepts_paths_within_100_bytes(self, path: str): - """ - Verifies the function accepts paths up to 96 bytes (104 byte max - 8 byte padding - for socket name portion (path sep + PID)) - """ - # GIVEN - subject = MacOSSocketDirectories() + try: + # WHEN + subject.verify_socket_path(path) + except NonvalidSocketPathException as e: + pytest.fail(f"verify_socket_path raised an error when it should not have: {e}") + else: + # THEN + pass # success + + def test_rejects_names_over_107_bytes(self): + # GIVEN + length = 108 + path = "a" * length + subject = LinuxSocketPaths() - try: # WHEN - subject.verify_socket_path(path) - except NonvalidSocketPathException as e: - pytest.fail(f"verify_socket_path raised an error when it should not have: {e}") - else: + with pytest.raises(NonvalidSocketPathException) as raised_exc: + subject.verify_socket_path(path) + # THEN - pass # success + assert raised_exc.match( + "Socket name too long. The maximum allowed size is " + f"{subject._socket_name_max_length} bytes, but the name has a size of " + f"{length}: {path}" + ) - def test_rejects_paths_over_96_bytes(self): - # GIVEN - length = 97 - path = "a" * length - subject = MacOSSocketDirectories() + class TestMacOSSocketPaths: + @pytest.mark.parametrize( + argnames=["path"], + argvalues=[ + ["a"], + ["a" * 103], + ], + ids=["one byte", "103 bytes"], + ) + def test_accepts_paths_within_103_bytes(self, path: str): + """ + Verifies the function accepts paths up to 103 bytes (104 byte max - 1 byte null terminator) + """ + # GIVEN + subject = MacOSSocketPaths() - # WHEN - with pytest.raises(NonvalidSocketPathException) as raised_exc: - subject.verify_socket_path(path) + try: + # WHEN + subject.verify_socket_path(path) + except NonvalidSocketPathException as e: + pytest.fail(f"verify_socket_path raised an error when it should not have: {e}") + else: + # THEN + pass # success - # THEN - assert raised_exc.match( - "Socket base directory path too big. The maximum allowed size is " - f"{subject._socket_dir_max_length} bytes, but the directory has a size of " - f"{length}: {path}" - ) + def test_rejects_paths_over_103_bytes(self): + # GIVEN + length = 104 + path = "a" * length + subject = MacOSSocketPaths() + + # WHEN + with pytest.raises(NonvalidSocketPathException) as raised_exc: + subject.verify_socket_path(path) + + # THEN + assert raised_exc.match( + "Socket name too long. The maximum allowed size is " + f"{subject._socket_name_max_length} bytes, but the name has a size of " + f"{length}: {path}" + ) diff --git a/test/openjd/adaptor_runtime/unit/test_entrypoint.py b/test/openjd/adaptor_runtime/unit/test_entrypoint.py index 8ae766a..80de3dc 100644 --- a/test/openjd/adaptor_runtime/unit/test_entrypoint.py +++ b/test/openjd/adaptor_runtime/unit/test_entrypoint.py @@ -4,6 +4,7 @@ import argparse import json +import os import signal from pathlib import Path from typing import Optional @@ -21,6 +22,7 @@ ) from openjd.adaptor_runtime.adaptors import BaseAdaptor, SemanticVersion from openjd.adaptor_runtime._background import BackendRunner, FrontendRunner +from openjd.adaptor_runtime._background.model import ConnectionSettings from openjd.adaptor_runtime._osname import OSName from openjd.adaptor_runtime._entrypoint import _load_data @@ -214,7 +216,14 @@ def test_creates_adaptor_with_init_data(self, mock_adaptor_cls: MagicMock): # GIVEN init_data = {"init": "data"} with patch.object( - runtime_entrypoint.sys, "argv", ["Adaptor", "run", "--init-data", json.dumps(init_data)] + runtime_entrypoint.sys, + "argv", + [ + "Adaptor", + "run", + "--init-data", + json.dumps(init_data), + ], ): entrypoint = EntryPoint(mock_adaptor_cls) @@ -461,7 +470,7 @@ def test_runs_background_serve( ): # GIVEN init_data = {"init": "data"} - conn_file = "/path/to/conn_file" + conn_file = Path(os.sep) / "path" / "to" / "conn_file" with patch.object( runtime_entrypoint.sys, "argv", @@ -472,7 +481,7 @@ def test_runs_background_serve( "--init-data", json.dumps(init_data), "--connection-file", - conn_file, + str(conn_file), ], ): entrypoint = EntryPoint(mock_adaptor_cls) @@ -487,7 +496,7 @@ def test_runs_background_serve( ) mock_init.assert_called_once_with( mock_adaptor_runner.return_value, - conn_file, + connection_file_path=conn_file.resolve(), log_buffer=mock_log_buffer.return_value, ) mock_run.assert_called_once() @@ -506,7 +515,7 @@ def test_background_serve_no_signal_hook( ): # GIVEN init_data = {"init": "data"} - conn_file = "/path/to/conn_file" + conn_file = os.path.join(os.sep, "path", "to", "conn_file") with patch.object( runtime_entrypoint.sys, "argv", @@ -534,7 +543,7 @@ def test_background_start_raises_when_adaptor_module_not_loaded( mock_magic_init: MagicMock, ): # GIVEN - conn_file = "/path/to/conn_file" + conn_file = Path(os.sep) / "path" / "to" / "conn_file" with patch.object( runtime_entrypoint.sys, "argv", @@ -543,7 +552,7 @@ def test_background_start_raises_when_adaptor_module_not_loaded( "daemon", "start", "--connection-file", - conn_file, + str(conn_file), ], ): entrypoint = EntryPoint(FakeAdaptor) @@ -555,7 +564,7 @@ def test_background_start_raises_when_adaptor_module_not_loaded( # THEN assert raised_err.match(f"Adaptor module is not loaded: {FakeAdaptor.__module__}") - mock_magic_init.assert_called_once_with(conn_file) + mock_magic_init.assert_called_once_with() @pytest.mark.parametrize( argnames="reentry_exe", @@ -570,12 +579,12 @@ def test_background_start_raises_when_adaptor_module_not_loaded( def test_runs_background_start( self, mock_start: MagicMock, + mock_init: MagicMock, mock_magic_init: MagicMock, - mock_magic_start: MagicMock, reentry_exe: Optional[Path], ): # GIVEN - conn_file = "/path/to/conn_file" + conn_file = Path(os.sep) / "path" / "to" / "conn_file" with patch.object( runtime_entrypoint.sys, "argv", @@ -584,7 +593,7 @@ def test_runs_background_start( "daemon", "start", "--connection-file", - conn_file, + str(conn_file), ], ): mock_adaptor_module = Mock() @@ -597,10 +606,17 @@ def test_runs_background_start( entrypoint.start(reentry_exe=reentry_exe) # THEN - mock_magic_init.assert_called_once_with(mock_adaptor_module, {}, {}, reentry_exe) - mock_magic_start.assert_called_once_with(conn_file) + mock_magic_init.assert_called_once_with() + mock_init.assert_called_once_with( + adaptor_module=mock_adaptor_module, + connection_file_path=conn_file.resolve(), + init_data={}, + path_mapping_data={}, + reentry_exe=reentry_exe, + ) mock_start.assert_called_once_with() + @patch.object(runtime_entrypoint.ConnectionSettingsFileLoader, "load") @patch.object(FrontendRunner, "__init__", return_value=None) @patch.object(FrontendRunner, "shutdown") @patch.object(FrontendRunner, "stop") @@ -609,9 +625,11 @@ def test_runs_background_stop( mock_end: MagicMock, mock_shutdown: MagicMock, mock_magic_init: MagicMock, + mock_connection_settings_load: MagicMock, ): # GIVEN - conn_file = "/path/to/conn_file" + connection_settings = ConnectionSettings("socket") + mock_connection_settings_load.return_value = connection_settings with patch.object( runtime_entrypoint.sys, "argv", @@ -620,7 +638,7 @@ def test_runs_background_stop( "daemon", "stop", "--connection-file", - conn_file, + "/path/to/conn/file", ], ): entrypoint = EntryPoint(FakeAdaptor) @@ -629,19 +647,22 @@ def test_runs_background_stop( entrypoint.start() # THEN - mock_magic_init.assert_called_once_with(conn_file) + mock_magic_init.assert_called_once_with(connection_settings=connection_settings) mock_end.assert_called_once() mock_shutdown.assert_called_once_with() + @patch.object(runtime_entrypoint.ConnectionSettingsFileLoader, "load") @patch.object(FrontendRunner, "__init__", return_value=None) @patch.object(FrontendRunner, "run") def test_runs_background_run( self, mock_run: MagicMock, mock_magic_init: MagicMock, + mock_connection_settings_load: MagicMock, ): # GIVEN - conn_file = "/path/to/conn_file" + conn_settings = ConnectionSettings("socket") + mock_connection_settings_load.return_value = conn_settings run_data = {"run": "data"} with patch.object( runtime_entrypoint.sys, @@ -653,7 +674,7 @@ def test_runs_background_run( "--run-data", json.dumps(run_data), "--connection-file", - conn_file, + "/path/to/conn/file", ], ): entrypoint = EntryPoint(FakeAdaptor) @@ -662,9 +683,11 @@ def test_runs_background_run( entrypoint.start() # THEN - mock_magic_init.assert_called_once_with(conn_file) + mock_magic_init.assert_called_once_with(connection_settings=conn_settings) mock_run.assert_called_once_with(run_data) + mock_connection_settings_load.assert_called_once() + @patch.object(runtime_entrypoint.ConnectionSettingsFileLoader, "load") @patch.object(FrontendRunner, "__init__", return_value=None) @patch.object(FrontendRunner, "run") @patch.object(runtime_entrypoint.signal, "signal") @@ -673,9 +696,10 @@ def test_background_no_signal_hook( signal_mock: MagicMock, mock_run: MagicMock, mock_magic_init: MagicMock, + mock_connection_settings_load: MagicMock, ): # GIVEN - conn_file = "/path/to/conn_file" + conn_file = Path(os.sep) / "path" / "to" / "conn_file" run_data = {"run": "data"} with patch.object( runtime_entrypoint.sys, @@ -687,7 +711,7 @@ def test_background_no_signal_hook( "--run-data", json.dumps(run_data), "--connection-file", - conn_file, + str(conn_file), ], ): entrypoint = EntryPoint(FakeAdaptor) @@ -698,13 +722,15 @@ def test_background_no_signal_hook( # THEN signal_mock.assert_not_called() + @patch.object(runtime_entrypoint, "ConnectionSettingsFileLoader") @patch.object(runtime_entrypoint, "FrontendRunner") def test_makes_connection_file_path_absolute( self, mock_runner: MagicMock, + mock_connection_settings_loader: MagicMock, ): # GIVEN - conn_file = "relpath" + conn_file = Path("relpath") with patch.object( runtime_entrypoint.sys, "argv", @@ -713,23 +739,30 @@ def test_makes_connection_file_path_absolute( "daemon", "run", "--connection-file", - conn_file, + str(conn_file), ], ): entrypoint = EntryPoint(FakeAdaptor) # WHEN - mock_isabs: MagicMock + mock_is_absolute: MagicMock with ( - patch.object(runtime_entrypoint.os.path, "isabs", return_value=False) as mock_isabs, - patch.object(runtime_entrypoint.os.path, "abspath") as mock_abspath, + patch.object( + runtime_entrypoint.Path, "is_absolute", return_value=False + ) as mock_is_absolute, + patch.object( + runtime_entrypoint.Path, "absolute", return_value=Path("absolute") + ) as mock_absolute, ): entrypoint.start() # THEN - mock_isabs.assert_any_call(conn_file) - mock_abspath.assert_any_call(conn_file) - mock_runner.assert_called_once_with(mock_abspath.return_value) + mock_is_absolute.assert_called_once() + mock_absolute.assert_any_call() + mock_connection_settings_loader.assert_called_once_with(mock_absolute.return_value) + mock_runner.assert_called_once_with( + connection_settings=mock_connection_settings_loader.return_value.load.return_value + ) class TestLoadData: diff --git a/test/openjd/adaptor_runtime_client/integ/test_integration_client_interface.py b/test/openjd/adaptor_runtime_client/integ/test_integration_client_interface.py index 6fc2bd1..1dd0940 100644 --- a/test/openjd/adaptor_runtime_client/integ/test_integration_client_interface.py +++ b/test/openjd/adaptor_runtime_client/integ/test_integration_client_interface.py @@ -53,7 +53,7 @@ def test_graceful_shutdown(self) -> None: # Ensure the process actually shutdown assert client_subprocess.returncode is not None - def test_client_in_thread_does_not_do_graceful_shutdown(self): + def test_client_in_thread_does_not_do_graceful_shutdown(self) -> None: """Ensures that a client running in a thread does not crash by attempting to register a signal, since they can only be created in the main thread. This means the graceful shutdown is effectively ignored."""