From e0982cabed446dfeca1db549cace96619aadb068 Mon Sep 17 00:00:00 2001 From: Wouter De Borger Date: Tue, 31 Jul 2018 14:30:45 +0200 Subject: [PATCH] Issue/564 transport cleanup (#644) * added bootloader * split off session management --- src/inmanta/app.py | 8 +- src/inmanta/deploy.py | 1 + src/inmanta/protocol.py | 860 +++++++---------------------- src/inmanta/server/__init__.py | 10 +- src/inmanta/server/agentmanager.py | 129 +++-- src/inmanta/server/bootloader.py | 46 ++ src/inmanta/server/config.py | 2 +- src/inmanta/server/protocol.py | 602 ++++++++++++++++++++ src/inmanta/server/server.py | 69 +-- src/inmanta/util.py | 45 ++ tests/conftest.py | 35 +- tests/test_2way_protocol.py | 146 +++-- tests/test_agent.py | 5 +- tests/test_agent_manager.py | 16 +- tests/test_server.py | 94 ++-- tests/test_server_agent.py | 83 ++- 16 files changed, 1226 insertions(+), 925 deletions(-) create mode 100644 src/inmanta/server/bootloader.py create mode 100644 src/inmanta/server/protocol.py diff --git a/src/inmanta/app.py b/src/inmanta/app.py index 607a9cc384..1f5821493e 100755 --- a/src/inmanta/app.py +++ b/src/inmanta/app.py @@ -34,23 +34,23 @@ from inmanta.export import cfg_env, ModelExporter from inmanta.ast import CompilerException import yaml +from inmanta.server.bootloader import InmantaBootloader LOGGER = logging.getLogger() @command("server", help_msg="Start the inmanta server") def start_server(options): - from inmanta import server io_loop = IOLoop.current() - s = server.Server(io_loop) - s.start() + ibl = InmantaBootloader() + ibl.start() try: io_loop.start() except KeyboardInterrupt: IOLoop.current().stop() - s.stop() + ibl.stop() @command("agent", help_msg="Start the inmanta agent") diff --git a/src/inmanta/deploy.py b/src/inmanta/deploy.py index 907b83ef3f..e7f13daf1a 100644 --- a/src/inmanta/deploy.py +++ b/src/inmanta/deploy.py @@ -110,6 +110,7 @@ def setup_server(self, no_agent_log): args = [sys.executable, "-m", "inmanta.app", "-vvv", "-c", server_config, "server"] self._server_proc = subprocess.Popen(args, stdout=subprocess.PIPE, stderr=subprocess.PIPE) + LOGGER.debug("Started server on port %d", self._server_port) while self._server_proc.poll() is None: diff --git a/src/inmanta/protocol.py b/src/inmanta/protocol.py index e9b3d214b5..b81664c81f 100644 --- a/src/inmanta/protocol.py +++ b/src/inmanta/protocol.py @@ -1,5 +1,5 @@ """ - Copyright 2017 Inmanta + Copyright 2018 Inmanta Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. @@ -15,7 +15,6 @@ Contact: code@inmanta.com """ - import logging import socket import time @@ -31,182 +30,32 @@ import io import gzip -import tornado.web -from tornado import gen, queues, web +import tornado +from tornado import gen, web from inmanta import methods, const, execute from inmanta import config as inmanta_config -from tornado.httpserver import HTTPServer from tornado.httpclient import HTTPRequest, AsyncHTTPClient, HTTPError from tornado.ioloop import IOLoop -import ssl import jwt -LOGGER = logging.getLogger(__name__) -INMANTA_MT_HEADER = "X-Inmanta-tid" - - -class Result(object): - """ - A result of a method call - """ - - def __init__(self, multiple=False, code=0, result=None): - self._multiple = multiple - if multiple: - self._result = [] - if result is not None: - self._result.append(result) - else: - self._result = result - self.code = code - self._callback = None - - def add_result(self, result): - """ - Add a new result to an instance - - :param result: The result to store - """ - assert(self._multiple) - self._result.append(result) - if self._callback: - self._callback(self) - - def get_result(self): - """ - Only when the result is marked as available the result can be returned - """ - if self.available(): - return self._result - raise Exception("The result is not yet available") - - def set_result(self, value): - if not self.available(): - assert(not self._multiple) - self._result = value - if self._callback: - self._callback(self) - - def available(self): - if self._multiple: - return len(self._result) > 0 is not None or self.code > 0 - else: - return self._result is not None or self.code > 0 - - def wait(self, timeout=60): - """ - Wait for the result to become available - """ - count = 0 - while count < timeout: - time.sleep(0.1) - count += 0.1 - - result = property(get_result, set_result) - - def callback(self, fnc): - """ - Set a callback function that is to be called when the result is ready. When multiple - results are expected, the callback is called for each result. - """ - self._callback = fnc - - -class Transport(object): - """ - This class implements a transport for the Inmanta protocol. - - :param end_point_name: The name of the endpoint to which this transport belongs. This is used - for logging and configuration purposes - """ - @classmethod - def create(cls, transport_class, endpoint=None): - """ - Create an instance of the transport class - """ - return transport_class(endpoint) - - def __init__(self, endpoint=None): - self.__end_point = endpoint - self.daemon = True - self._connected = False - - endpoint = property(lambda x: x.__end_point) +from inmanta.util import Scheduler - def get_id(self): - """ - Returns a unique id for a transport on an endpoint - """ - return "%s_%s_transport" % (self.__end_point.name, self.__class__.__transport_name__) - - id = property(get_id) - - def start_endpoint(self): - """ - Start the transport as endpoint - """ - self.start() - - def stop_endpoint(self): - """ - Stop the transport as endpoint - """ - self.stop() - - def start_client(self): - """ - Start this transport as client - """ - self.start() - def stop_client(self): - """ - Stop this transport as client - """ - self.stop() - - def start(self): - """ - Start the transport as a new thread - """ +LOGGER = logging.getLogger(__name__) +INMANTA_MT_HEADER = "X-Inmanta-tid" - def stop(self): - """ - Stop the transport - """ - self._connected = False +""" - def call(self, method, destination=None, **kwargs): - """ - Perform a method call - """ - raise NotImplementedError() +RestServer => manages tornado/handlers, marshalling, dispatching, and endpoints - def _decode(self, body): - """ - Decode a response body - """ - if body is not None and len(body) > 0: - body = json.loads(tornado.escape.to_basestring(body)) - else: - body = None +ServerSlice => contributes handlers and methods - return body +ServerSlice.server [1] -- RestServer.endpoints [1:] - def set_connected(self): - """ - Mark this transport as connected - """ - LOGGER.debug("Transport %s is connected", self.get_id()) - self._connected = True - - def is_connected(self): - """ - Is this transport connected - """ - return self._connected +""" +# Util functions def custom_json_encoder(o): """ A custom json encoder that knows how to encode other types commonly used by Inmanta @@ -253,8 +102,10 @@ def gzipped_json(value): return True, gzip_value.getvalue() -class UnauhorizedError(Exception): - pass +def sh(msg, max_len=10): + if len(msg) < max_len: + return msg + return msg[0:max_len - 3] + "..." def encode_token(client_types, environment=None, idempotent=False, expire=None): @@ -320,145 +171,6 @@ def decode_token(token): return payload -class RESTHandler(tornado.web.RequestHandler): - """ - A generic class use by the transport - """ - - def initialize(self, transport: Transport, config): - self._transport = transport - self._config = config - - def _get_config(self, http_method): - if http_method.upper() not in self._config: - allowed = ", ".join(self._config.keys()) - self.set_header("Allow", allowed) - self._transport.return_error_msg(405, "%s is not supported for this url. Supported methods: %s" % - (http_method, allowed)) - return - - return self._config[http_method] - - def get_auth_token(self, headers: dict): - """ - Get the auth token provided by the caller. The token is provided as a bearer token. - """ - if "Authorization" not in headers: - return None - - parts = headers["Authorization"].split(" ") - if len(parts) == 0 or parts[0].lower() != "bearer" or len(parts) > 2 or len(parts) == 1: - LOGGER.warning("Invalid authentication header, Inmanta expects a bearer token. (%s was provided)", - headers["Authorization"]) - return None - - return decode_token(parts[1]) - - def respond(self, body, headers, status): - if body is not None: - self.write(json_encode(body)) - - for header, value in headers.items(): - self.set_header(header, value) - - self.set_status(status) - - @gen.coroutine - def _call(self, kwargs, http_method, call_config): - """ - An rpc like call - """ - if call_config is None: - body, headers, status = self._transport.return_error_msg(404, "This method does not exist.") - self.respond(body, headers, status) - return - - self.set_header("Access-Control-Allow-Origin", "*") - try: - message = self._transport._decode(self.request.body) - if message is None: - message = {} - - for key, value in self.request.query_arguments.items(): - if len(value) == 1: - message[key] = value[0].decode("latin-1") - else: - message[key] = [v.decode("latin-1") for v in value] - - request_headers = self.request.headers - - try: - auth_token = self.get_auth_token(request_headers) - except UnauhorizedError as e: - self.respond(*self._transport.return_error_msg(403, "Access denied: " + e.args[0])) - return - - auth_enabled = inmanta_config.Config.get("server", "auth", False) - if not auth_enabled or auth_token is not None: - result = yield self._transport._execute_call(kwargs, http_method, call_config, - message, request_headers, auth_token) - self.respond(*result) - else: - self.respond(*self._transport.return_error_msg(401, "Access to this resource is unauthorized.")) - except ValueError: - LOGGER.exception("An exception occured") - self.respond(*self._transport.return_error_msg(500, "Unable to decode request body")) - - @gen.coroutine - def head(self, *args, **kwargs): - yield self._call(http_method="HEAD", call_config=self._get_config("HEAD"), kwargs=kwargs) - - @gen.coroutine - def get(self, *args, **kwargs): - yield self._call(http_method="GET", call_config=self._get_config("GET"), kwargs=kwargs) - - @gen.coroutine - def post(self, *args, **kwargs): - yield self._call(http_method="POST", call_config=self._get_config("POST"), kwargs=kwargs) - - @gen.coroutine - def delete(self, *args, **kwargs): - yield self._call(http_method="DELETE", call_config=self._get_config("DELETE"), kwargs=kwargs) - - @gen.coroutine - def patch(self, *args, **kwargs): - yield self._call(http_method="PATCH", call_config=self._get_config("PATCH"), kwargs=kwargs) - - @gen.coroutine - def put(self, *args, **kwargs): - yield self._call(http_method="PUT", call_config=self._get_config("PUT"), kwargs=kwargs) - - @gen.coroutine - def options(self, *args, **kwargs): - allow_headers = "Origin, Accept, Content-Type, X-Requested-With, X-CSRF-Token" - if len(self._transport.headers): - allow_headers += ", " + ", ".join(self._transport.headers) - - self.set_header("Access-Control-Allow-Origin", "*") - self.set_header("Access-Control-Allow-Methods", "HEAD, GET, POST, PUT, OPTIONS, DELETE, PATCH") - self.set_header("Access-Control-Allow-Headers", allow_headers) - - self.set_status(200) - - -class StaticContentHandler(tornado.web.RequestHandler): - def initialize(self, transport: Transport, content, content_type): - self._transport = transport - self._content = content - self._content_type = content_type - - def get(self, *args, **kwargs): - self.set_header("Content-Type", self._content_type) - self.write(self._content) - self.set_status(200) - - -def sh(msg, max_len=10): - if len(msg) < max_len: - return msg - return msg[0:max_len - 3] + "..." - - def authorize_request(auth_data, metadata, message, config): """ Authorize a request based on the given data @@ -482,27 +194,90 @@ def authorize_request(auth_data, metadata, message, config): if ct in config[0]["client_types"]: ok = True - if not ok: - raise UnauhorizedError("The authorization token does not have a valid client type for this call." + - " (%s provided, %s expected" % (auth_data[ct_key], config[0]["client_types"])) + if not ok: + raise UnauhorizedError("The authorization token does not have a valid client type for this call." + + " (%s provided, %s expected" % (auth_data[ct_key], config[0]["client_types"])) + + return + + +# API +class UnauhorizedError(Exception): + pass + + +class Result(object): + """ + A result of a method call + """ + + def __init__(self, multiple=False, code=0, result=None): + self._multiple = multiple + if multiple: + self._result = [] + if result is not None: + self._result.append(result) + else: + self._result = result + self.code = code + self._callback = None + + def add_result(self, result): + """ + Add a new result to an instance + + :param result: The result to store + """ + assert(self._multiple) + self._result.append(result) + if self._callback: + self._callback(self) + + def get_result(self): + """ + Only when the result is marked as available the result can be returned + """ + if self.available(): + return self._result + raise Exception("The result is not yet available") + + def set_result(self, value): + if not self.available(): + assert(not self._multiple) + self._result = value + if self._callback: + self._callback(self) + + def available(self): + if self._multiple: + return len(self._result) > 0 is not None or self.code > 0 + else: + return self._result is not None or self.code > 0 + + def wait(self, timeout=60): + """ + Wait for the result to become available + """ + count = 0 + while count < timeout: + time.sleep(0.1) + count += 0.1 + + result = property(get_result, set_result) - return + def callback(self, fnc): + """ + Set a callback function that is to be called when the result is ready. When multiple + results are expected, the callback is called for each result. + """ + self._callback = fnc -class RESTTransport(Transport): - """" - A REST (json body over http) transport. Only methods that operate on resource can use all - HTTP verbs. For other methods the POST verb is used. - """ - __transport_name__ = "rest" +# Tornado Interface - def __init__(self, endpoint, connection_timout=120): - super().__init__(endpoint) - self.set_connected() - self._handlers = [] - self.token = inmanta_config.Config.get(self.id, "token", None) - self.connection_timout = connection_timout - self.headers = set() + +# Shared +class RESTBase(object): def _create_base_url(self, properties, msg=None, versioned=True): """ @@ -522,53 +297,16 @@ def _create_base_url(self, properties, msg=None, versioned=True): return url - def create_op_mapping(self): - """ - Build a mapping between urls, ops and methods - """ - url_map = defaultdict(dict) - headers = set() - for method, method_handlers in self.endpoint.__methods__.items(): - properties = method.__protocol_properties__ - call = (self.endpoint, method_handlers[0]) - - if "arg_options" in properties: - for opts in properties["arg_options"].values(): - if "header" in opts: - headers.add(opts["header"]) - - url = self._create_base_url(properties) - properties["api_version"] = "1" - url_map[url][properties["operation"]] = (properties, call, method.__wrapped__) - - url = self._create_base_url(properties, versioned=False) - properties = properties.copy() - properties["api_version"] = None - url_map[url][properties["operation"]] = (properties, call, method.__wrapped__) - - headers.add("Authorization") - self.headers = headers - return url_map - - def match_call(self, url, method): + def _decode(self, body): """ - Get the method call for the given url and http method + Decode a response body """ - url_map = self.create_op_mapping() - for url_re, handlers in url_map.items(): - if not url_re.endswith("$"): - url_re += "$" - match = re.match(url_re, url) - if match and method in handlers: - return match.groupdict(), handlers[method] - - return None, None + if body is not None and len(body) > 0: + body = json.loads(tornado.escape.to_basestring(body)) + else: + body = None - def return_error_msg(self, status=500, msg="", headers={}): - body = {"message": msg} - headers["Content-Type"] = "application/json" - LOGGER.debug("Signaling error to client: %d, %s, %s", status, body, headers) - return body, headers, status + return body @gen.coroutine def _execute_call(self, kwargs, http_method, config, message, request_headers, auth=None): @@ -719,70 +457,113 @@ def _execute_call(self, kwargs, http_method, config, message, request_headers, a LOGGER.exception("An exception occured during the request.") return self.return_error_msg(500, "An exception occured: " + str(e.args), headers) - def add_static_handler(self, location, path, default_filename=None, start=False): - """ - Configure a static handler to serve data from the specified path. - """ - if location[0] != "/": - location = "/" + location - if location[-1] != "/": - location = location + "/" +# Client side +class RESTTransport(RESTBase): + """" + A REST (json body over http) transport. Only methods that operate on resource can use all + HTTP verbs. For other methods the POST verb is used. + """ + __transport_name__ = "rest" + + def __init__(self, endpoint, connection_timout=120): + self.__end_point = endpoint + self.daemon = True + self._connected = False + self.set_connected() + self._handlers = [] + self.token = inmanta_config.Config.get(self.id, "token", None) + self.connection_timout = connection_timout + self.headers = set() + + endpoint = property(lambda x: x.__end_point) - options = {"path": path} - if default_filename is None: - options["default_filename"] = "index.html" + def get_id(self): + """ + Returns a unique id for a transport on an endpoint + """ + return "%s_%s_transport" % (self.__end_point.name, self.__class__.__transport_name__) - self._handlers.append((r"%s(.*)" % location, tornado.web.StaticFileHandler, options)) - self._handlers.append((r"%s" % location[:-1], tornado.web.RedirectHandler, {"url": location})) + id = property(get_id) - if start: - self._handlers.append((r"/", tornado.web.RedirectHandler, {"url": location})) + def start_client(self): + """ + Start this transport as client + """ + self.start() - def add_static_content(self, path, content, content_type="application/javascript"): - self._handlers.append((r"%s(.*)" % path, StaticContentHandler, {"transport": self, "content": content, - "content_type": content_type})) + def stop_client(self): + """ + Stop this transport as client + """ + self.stop() - def start_endpoint(self): + def start(self): """ - Start the transport + Start the transport as a new thread """ - url_map = self.create_op_mapping() + pass - for url, configs in url_map.items(): - handler_config = {} - for op, cfg in configs.items(): - handler_config[op] = cfg + def stop(self): + """ + Stop the transport + """ + self._connected = False - self._handlers.append((url, RESTHandler, {"transport": self, "config": handler_config})) - LOGGER.debug("Registering handler(s) for url %s and methods %s" % (url, ", ".join(handler_config.keys()))) + def set_connected(self): + """ + Mark this transport as connected + """ + LOGGER.debug("Transport %s is connected", self.get_id()) + self._connected = True - port = 8888 - if self.id in inmanta_config.Config.get() and "port" in inmanta_config.Config.get()[self.id]: - port = inmanta_config.Config.get()[self.id]["port"] + def is_connected(self): + """ + Is this transport connected + """ + return self._connected - application = tornado.web.Application(self._handlers, compress_response=True) + def create_op_mapping(self): + """ + Build a mapping between urls, ops and methods + """ + url_map = defaultdict(dict) + headers = set() + for method, method_handlers in self.endpoint.__methods__.items(): + properties = method.__protocol_properties__ + call = (self.endpoint, method_handlers[0]) - crt = inmanta_config.Config.get("server", "ssl_cert_file", None) - key = inmanta_config.Config.get("server", "ssl_key_file", None) + if "arg_options" in properties: + for opts in properties["arg_options"].values(): + if "header" in opts: + headers.add(opts["header"]) - if(crt is not None and key is not None): - ssl_ctx = ssl.create_default_context(ssl.Purpose.CLIENT_AUTH) - ssl_ctx.load_cert_chain(crt, key) + url = self._create_base_url(properties) + properties["api_version"] = "1" + url_map[url][properties["operation"]] = (properties, call, method.__wrapped__) - self.http_server = HTTPServer(application, decompress_request=True, ssl_options=ssl_ctx) - LOGGER.debug("Created REST transport with SSL") - else: - self.http_server = HTTPServer(application, decompress_request=True) + url = self._create_base_url(properties, versioned=False) + properties = properties.copy() + properties["api_version"] = None + url_map[url][properties["operation"]] = (properties, call, method.__wrapped__) - self.http_server.listen(port) + headers.add("Authorization") + self.headers = headers + return url_map - LOGGER.debug("Start REST transport") - super().start() + def match_call(self, url, method): + """ + Get the method call for the given url and http method + """ + url_map = self.create_op_mapping() + for url_re, handlers in url_map.items(): + if not url_re.endswith("$"): + url_re += "$" + match = re.match(url_re, url) + if match and method in handlers: + return match.groupdict(), handlers[method] - def stop_endpoint(self): - super().stop() - self.http_server.stop() + return None, None def _get_client_config(self): """ @@ -941,51 +722,6 @@ def __new__(cls, class_name, bases, dct): return type.__new__(cls, class_name, bases, dct) -class Scheduler(object): - """ - An event scheduler class - """ - - def __init__(self, io_loop): - self._scheduled = set() - self._io_loop = io_loop - - def add_action(self, action, interval, initial_delay=None): - """ - Add a new action - - :param action A function to call periodically - :param interval The interval between execution of actions - :param initial_delay Delay to the first execution, default to interval - """ - - if initial_delay is None: - initial_delay = interval - - LOGGER.debug("Scheduling action %s every %d seconds with initial delay %d", action, interval, initial_delay) - - def action_function(): - LOGGER.info("Calling %s" % action) - if action in self._scheduled: - try: - action() - except Exception: - LOGGER.exception("Uncaught exception while executing scheduled action") - - finally: - self._io_loop.call_later(interval, action_function) - - self._io_loop.call_later(initial_delay, action_function) - self._scheduled.add(action) - - def remove(self, action): - """ - Remove a scheduled action - """ - if action in self._scheduled: - self._scheduled.remove(action) - - class Endpoint(object): """ An end-point in the rpc framework @@ -1040,215 +776,6 @@ def get_node_name(self): node_name = property(get_node_name) -class Session(object): - """ - An environment that segments agents connected to the server - """ - - def __init__(self, sessionstore, io_loop, sid, hang_interval, timout, tid, endpoint_names, nodename): - self._sid = sid - self._interval = hang_interval - self._timeout = timout - self._sessionstore = sessionstore - self._seen = time.time() - self._callhandle = None - self.expired = False - - self.tid = tid - self.endpoint_names = endpoint_names - self.nodename = nodename - - self._io_loop = io_loop - - self._replies = {} - self.check_expire() - self._queue = queues.Queue() - - self.client = ReturnClient(str(sid), self) - - def check_expire(self): - if self.expired: - LOGGER.exception("Tried to expire session already expired") - ttw = self._timeout + self._seen - time.time() - if ttw < 0: - self.expire(self._seen - time.time()) - else: - self._callhandle = self._io_loop.call_later(ttw, self.check_expire) - - def get_id(self): - return self._sid - - id = property(get_id) - - def expire(self, timeout): - self.expired = True - if self._callhandle is not None: - self._io_loop.remove_timeout(self._callhandle) - self._sessionstore.expire(self, timeout) - - def seen(self): - self._seen = time.time() - - def _set_timeout(self, future, timeout, log_message): - def on_timeout(): - if not self.expired: - LOGGER.warning(log_message) - future.set_exception(gen.TimeoutError()) - - timeout_handle = self._io_loop.add_timeout(self._io_loop.time() + timeout, on_timeout) - future.add_done_callback(lambda _: self._io_loop.remove_timeout(timeout_handle)) - - def put_call(self, call_spec, timeout=10): - future = tornado.concurrent.Future() - - reply_id = uuid.uuid4() - - LOGGER.debug("Putting call %s: %s %s for agent %s in queue", reply_id, call_spec["method"], call_spec["url"], self._sid) - - q = self._queue - call_spec["reply_id"] = reply_id - q.put(call_spec) - self._set_timeout(future, timeout, "Call %s: %s %s for agent %s timed out." % - (reply_id, call_spec["method"], call_spec["url"], self._sid)) - self._replies[call_spec["reply_id"]] = future - - return future - - @gen.coroutine - def get_calls(self): - """ - Get all calls queued for a node. If no work is available, wait until timeout. This method returns none if a call - fails. - """ - try: - q = self._queue - call_list = [] - call = yield q.get(timeout=self._io_loop.time() + self._interval) - call_list.append(call) - while q.qsize() > 0: - call = yield q.get() - call_list.append(call) - - return call_list - - except gen.TimeoutError: - return None - - def set_reply(self, reply_id, data): - LOGGER.log(3, "Received Reply: %s", reply_id) - if reply_id in self._replies: - future = self._replies[reply_id] - del self._replies[reply_id] - if not future.done(): - future.set_result(data) - else: - LOGGER.debug("Received Reply that is unknown: %s", reply_id) - - def get_client(self): - return self.client - - -class ServerEndpoint(Endpoint, metaclass=EndpointMeta): - """ - A service that receives method calls over one or more transports - """ - __methods__ = {} - - def __init__(self, name, io_loop, transport=RESTTransport, interval=60, hangtime=None): - super().__init__(io_loop, name) - self._transport = transport - - self._transport_instance = Transport.create(self._transport, self) - self._sched = Scheduler(self._io_loop) - - self._heartbeat_cb = None - self.agent_handles = {} - self._sessions = {} - self.interval = interval - if hangtime is None: - hangtime = interval * 3 / 4 - self.hangtime = hangtime - - def schedule(self, call, interval=60): - self._sched.add_action(call, interval) - - def start(self): - """ - Start this end-point using the central configuration - """ - LOGGER.debug("Starting transport for endpoint %s", self.name) - if self._transport_instance is not None: - self._transport_instance.start_endpoint() - - def stop(self): - """ - Stop the end-point and all of its transports - """ - if self._transport_instance is not None: - self._transport_instance.stop_endpoint() - LOGGER.debug("Stopped %s", self._transport_instance) - # terminate all sessions cleanly - for session in self._sessions.copy().values(): - session.expire(0) - - def validate_sid(self, sid): - if isinstance(sid, str): - sid = uuid.UUID(sid) - return sid in self._sessions - - def get_or_create_session(self, sid, tid, endpoint_names, nodename): - if isinstance(sid, str): - sid = uuid.UUID(sid) - - if sid not in self._sessions: - session = self.new_session(sid, tid, endpoint_names, nodename) - self._sessions[sid] = session - else: - session = self._sessions[sid] - self.seen(session, endpoint_names) - - return session - - def new_session(self, sid, tid, endpoint_names, nodename): - LOGGER.debug("New session with id %s on node %s for env %s with endpoints %s" % (sid, nodename, tid, endpoint_names)) - return Session(self, self._io_loop, sid, self.hangtime, self.interval, tid, endpoint_names, nodename) - - def expire(self, session: Session, timeout): - LOGGER.debug("Expired session with id %s, last seen %d seconds ago" % (session.get_id(), timeout)) - del self._sessions[session.id] - - def seen(self, session: Session, endpoint_names: list): - LOGGER.debug("Seen session with id %s" % (session.get_id())) - session.seen() - - @handle(methods.HeartBeatMethod.heartbeat, env="tid") - @gen.coroutine - def heartbeat(self, sid, env, endpoint_names, nodename): - LOGGER.debug("Received heartbeat from %s for agents %s in %s", nodename, ",".join(endpoint_names), env.id) - - session = self.get_or_create_session(sid, env.id, endpoint_names, nodename) - - LOGGER.debug("Let node %s wait for method calls to become available. (long poll)", nodename) - call_list = yield session.get_calls() - if call_list is not None: - LOGGER.debug("Pushing %d method calls to node %s", len(call_list), nodename) - return 200, {"method_calls": call_list} - else: - LOGGER.debug("Heartbeat wait expired for %s, returning. (long poll)", nodename) - - return 200 - - @handle(methods.HeartBeatMethod.heartbeat_reply) - @gen.coroutine - def heartbeat_reply(self, sid, reply_id, data): - try: - env = self._sessions[sid] - env.set_reply(reply_id, data) - return 200 - except Exception: - LOGGER.warning("could not deliver agent reply with sid=%s and reply_id=%s" % (sid, reply_id), exc_info=True) - - class AgentEndPoint(Endpoint, metaclass=EndpointMeta): """ An endpoint for clients that make calls to a server and that receive calls back from the server using long-poll @@ -1411,7 +938,7 @@ def __init__(self, name, ioloop=None, transport=RESTTransport): self._transport_instance = None LOGGER.debug("Start transport for client %s", self.name) - tr = Transport.create(self._transport, self) + tr = self._transport(self) self._transport_instance = tr @gen.coroutine @@ -1428,6 +955,7 @@ class SyncClient(object): """ A synchronous client that communicates with end-point based on its configuration """ + def __init__(self, name, timeout=120): self.name = name self.timeout = timeout diff --git a/src/inmanta/server/__init__.py b/src/inmanta/server/__init__.py index 69c49ab020..3de52f0d73 100644 --- a/src/inmanta/server/__init__.py +++ b/src/inmanta/server/__init__.py @@ -1,17 +1,23 @@ """ - Copyright 2017 Inmanta + Copyright 2018 Inmanta + Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at + http://www.apache.org/licenses/LICENSE-2.0 + Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. + Contact: code@inmanta.com """ # flake8: noqa: F401 -from inmanta.server.server import Server \ No newline at end of file +SLICE_SERVER = "server" +SLICE_AGENT_MANAGER = "agentmanager" +SLICE_SESSION_MANAGER = "session" diff --git a/src/inmanta/server/agentmanager.py b/src/inmanta/server/agentmanager.py index beb2c8e53f..2f70b3b901 100644 --- a/src/inmanta/server/agentmanager.py +++ b/src/inmanta/server/agentmanager.py @@ -1,5 +1,5 @@ """ - Copyright 2017 Inmanta + Copyright 2018 Inmanta Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. @@ -21,8 +21,8 @@ from tornado import locks from inmanta.config import Config -from inmanta import data -from inmanta import protocol +from inmanta import data, methods +from inmanta.server import protocol, SLICE_AGENT_MANAGER from inmanta.asyncutil import retry_limited from . import config as server_config @@ -33,9 +33,15 @@ import sys import subprocess import uuid +from inmanta.server.protocol import ServerSlice +from inmanta.server import config as opt +from tornado.ioloop import IOLoop +from inmanta.protocol import encode_token LOGGER = logging.getLogger(__name__) + + agent_lock = locks.Lock() @@ -62,27 +68,36 @@ | | +---------------+ + +get_resources_for_agent + +resource_action_update + +dryrun_update + +set_parameters """ -class AgentManager(object): +class AgentManager(ServerSlice): ''' This class contains all server functionality related to the management of agents ''' - def __init__(self, server, closesessionsonstart=True, fact_back_off=60): - self._server = server + def __init__(self, restserver, closesessionsonstart=True, fact_back_off=None): + super(AgentManager, self).__init__(IOLoop.current(), SLICE_AGENT_MANAGER) + self.restserver = restserver + + if fact_back_off is None: + fact_back_off = opt.server_fact_resource_block.get() self._agent_procs = {} # env uuid -> subprocess.Popen - server.add_future(self.start_agents()) # back-off timer for fact requests self._fact_resource_block = fact_back_off # per resource time of last fact request self._fact_resource_block_set = {} - self._server_storage = server._server_storage - # session lock self.session_lock = locks.Lock() # all sessions @@ -92,13 +107,26 @@ def __init__(self, server, closesessionsonstart=True, fact_back_off=60): self.closesessionsonstart = closesessionsonstart - # From server - def new_session(self, session: protocol.Session): + def prestart(self, server): + ServerSlice.prestart(self, server) + self._server = server.get_endpoint("server") + self._server_storage = self._server._server_storage + server.get_endpoint("session").add_listener(self) + + def new_session(self, session): self.add_future(self.register_session(session, datetime.now())) - def expire(self, session: protocol.Session): + def expire(self, session, timeout): self.add_future(self.expire_session(session, datetime.now())) + def seen(self, session, endpoint_names): + if set(session.endpoint_names) != set(endpoint_names): + LOGGER.warning("Agent endpoint set changed, this should not occur, update ignored (was %s is %s)" % + (set(session.endpoint_names), set(endpoint_names))) + # start async, let it run free + self.add_future(self.flush_agent_presence(session, datetime.now())) + + # From server def get_agent_client(self, tid: uuid.UUID, endpoint): if isinstance(tid, str): tid = uuid.UUID(tid) @@ -107,26 +135,15 @@ def get_agent_client(self, tid: uuid.UUID, endpoint): return self.tid_endpoint_to_session[(tid, endpoint)].get_client() return None - def seen(self, session, endpoint_names): - if set(session.endpoint_names) != set(endpoint_names): - LOGGER.warning("Agent endpoint set changed, this should not occur, update ignored (was %s is %s)" % - (set(session.endpoint_names), set(endpoint_names))) - # start async, let it run free - self.add_future(self.flush_agent_presence(session, datetime.now())) - def start(self): + self.add_future(self.start_agents()) if self.closesessionsonstart: self.add_future(self.clean_db()) def stop(self): self.terminate_agents() - # To Server - def add_future(self, future): - self._server.add_future(future) - # Agent Management - @gen.coroutine def ensure_agent_registered(self, env: data.Environment, nodename: str): """ @@ -320,9 +337,25 @@ def _fork_inmanta(self, args, outfile, errfile, cwd=None): errhandle.close() # External APIS + @protocol.handle(methods.NodeMethod.get_agent_process, agent_id="id") + @gen.coroutine + def get_agent_process(self, agent_id): + return (yield self.get_agent_process_report(agent_id)) + + @protocol.handle(methods.ServerAgentApiMethod.trigger_agent, agent_id="id", env="tid") + @gen.coroutine + def trigger_agent(self, env, agent_id): + raise NotImplemented() + @protocol.handle(methods.NodeMethod.list_agent_processes) @gen.coroutine - def list_agent_processes(self, tid, expired): + def list_agent_processes(self, environment, expired): + if environment is not None: + env = yield data.Environment.get_by_id(environment) + if env is None: + return 404, {"message": "The given environment id does not exist!"} + + tid = environment if tid is not None: if expired: aps = yield data.AgentProcess.get_by_env(tid) @@ -347,6 +380,30 @@ def list_agent_processes(self, tid, expired): return 200, {"processes": processes} + @protocol.handle(methods.ServerAgentApiMethod.list_agents, env="tid") + @gen.coroutine + def list_agents(self, env): + if env is not None: + tid = env.id + ags = yield data.Agent.get_list(environment=tid) + else: + ags = yield data.Agent.get_list() + + return 200, {"agents": [a.to_dict() for a in ags], "servertime": datetime.now().isoformat()} + + @protocol.handle(methods.AgentRecovery.get_state, env="tid") + @gen.coroutine + def get_state(self, env: uuid.UUID, sid: uuid.UUID, agent: str): + tid = env.id + if isinstance(tid, str): + tid = uuid.UUID(tid) + key = (tid, agent) + if key in self.tid_endpoint_to_session: + session = self.tid_endpoint_to_session[(tid, agent)] + if session.id == sid: + return 200, {"enabled": True} + return 200, {"enabled": False} + @gen.coroutine def get_agent_process_report(self, apid: uuid.UUID): ap = yield data.AgentProcess.get_by_id(apid) @@ -359,15 +416,6 @@ def get_agent_process_report(self, apid: uuid.UUID): result = yield client.get_status() return result.code, result.get_result() - @gen.coroutine - def list_agents(self, tid): - if tid is not None: - ags = yield data.Agent.get_list(environment=tid) - else: - ags = yield data.Agent.get_list() - - return 200, {"agents": [a.to_dict() for a in ags], "servertime": datetime.now().isoformat()} - # Start/stop agents @gen.coroutine def _ensure_agents(self, env: data.Environment, agents: list, restart: bool=False): @@ -474,7 +522,7 @@ def _make_agent_config(self, env: data.Environment, agent_names: list, agent_map "statedir": privatestatedir, "agent_splay": agent_splay, "agent_interval": agent_interval} if server_config.server_enable_auth.get(): - token = protocol.encode_token(["agent"], environment_id) + token = encode_token(["agent"], environment_id) config += """ token=%s """ % (token) @@ -541,17 +589,6 @@ def _request_parameter(self, env_id: uuid.UUID, resource_id): else: return 404, {"message": "resource_id parameter is required."} - @gen.coroutine - def get_state(self, tid: uuid.UUID, sid: uuid.UUID, agent: str): - if isinstance(tid, str): - tid = uuid.UUID(tid) - key = (tid, agent) - if key in self.tid_endpoint_to_session: - session = self.tid_endpoint_to_session[(tid, agent)] - if session.id == sid: - return 200, {"enabled": True} - return 200, {"enabled": False} - @gen.coroutine def start_agents(self): """ diff --git a/src/inmanta/server/bootloader.py b/src/inmanta/server/bootloader.py new file mode 100644 index 0000000000..4b8bf28525 --- /dev/null +++ b/src/inmanta/server/bootloader.py @@ -0,0 +1,46 @@ +""" + Copyright 2018 Inmanta + + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. + + Contact: code@inmanta.com +""" +from tornado.ioloop import IOLoop +from inmanta.server import server +from inmanta.server.protocol import RESTServer +from inmanta.server.agentmanager import AgentManager + + +class InmantaBootloader(object): + + def __init__(self, agent_no_log=False): + self.restserver = RESTServer() + self.agent_no_log = agent_no_log + + def get_server_slice(self): + io_loop = IOLoop.current() + return server.Server(io_loop, agent_no_log=self.agent_no_log) + + def get_agent_manager_slice(self): + return AgentManager(self.restserver) + + def get_server_slices(self): + return [self.get_server_slice(), self.get_agent_manager_slice()] + + def start(self): + for mypart in self.get_server_slices(): + self.restserver.add_endpoint(mypart) + self.restserver.start() + + def stop(self): + self.restserver.stop() diff --git a/src/inmanta/server/config.py b/src/inmanta/server/config.py index d6d4355b91..bef8c8c19e 100644 --- a/src/inmanta/server/config.py +++ b/src/inmanta/server/config.py @@ -1,5 +1,5 @@ """ - Copyright 2017 Inmanta + Copyright 2018 Inmanta Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. diff --git a/src/inmanta/server/protocol.py b/src/inmanta/server/protocol.py new file mode 100644 index 0000000000..e7b3bf27f9 --- /dev/null +++ b/src/inmanta/server/protocol.py @@ -0,0 +1,602 @@ +""" + Copyright 2018 Inmanta + + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. + + Contact: code@inmanta.com +""" +from inmanta.util import Scheduler +from inmanta.protocol import RESTBase, decode_token, json_encode, UnauhorizedError, ReturnClient, handle + +from inmanta import config as inmanta_config, methods +from inmanta.server import config as opt, SLICE_SESSION_MANAGER + +import tornado.web +from tornado import gen, queues +from tornado.ioloop import IOLoop +from tornado.httpserver import HTTPServer + +import logging +import ssl +import time +import uuid +from _collections import defaultdict + + +LOGGER = logging.getLogger(__name__) + + +# Server Side +class RESTServer(RESTBase): + + def __init__(self, connection_timout=120): + self.__end_points = [] + self.__endpoint_dict = {} + self._handlers = [] + self.token = inmanta_config.Config.get(self.id, "token", None) + self.connection_timout = connection_timout + self.headers = set() + self.sessions_handler = SessionManager(IOLoop.current()) + self.add_endpoint(self.sessions_handler) + + def add_endpoint(self, endpoint: "ServerSlice"): + self.__end_points.append(endpoint) + self.__endpoint_dict[endpoint.name] = endpoint + + def get_endpoint(self, name): + return self.__endpoint_dict[name] + + def validate_sid(self, sid): + return self.sessions_handler.validate_sid(sid) + + def get_id(self): + """ + Returns a unique id for a transport on an endpoint + """ + return "server_rest_transport" + + id = property(get_id) + + def create_op_mapping(self): + """ + Build a mapping between urls, ops and methods + """ + url_map = defaultdict(dict) + + # TODO: avoid colliding handlers + + for endpoint in self.__end_points: + for method, method_handlers in endpoint.__methods__.items(): + properties = method.__protocol_properties__ + call = (endpoint, method_handlers[0]) + + if "arg_options" in properties: + for opts in properties["arg_options"].values(): + if "header" in opts: + self.headers.add(opts["header"]) + + url = self._create_base_url(properties) + properties["api_version"] = "1" + url_map[url][properties["operation"]] = (properties, call, method.__wrapped__) + url = self._create_base_url(properties, versioned=False) + properties = properties.copy() + properties["api_version"] = None + url_map[url][properties["operation"]] = (properties, call, method.__wrapped__) + return url_map + + def start(self): + """ + Start the transport + """ + LOGGER.debug("Starting Server Rest Endpoint") + + for endpoint in self.__end_points: + endpoint.prestart(self) + + for endpoint in self.__end_points: + endpoint.start() + self._handlers.extend(endpoint.get_handlers()) + + url_map = self.create_op_mapping() + + for url, configs in url_map.items(): + handler_config = {} + for op, cfg in configs.items(): + handler_config[op] = cfg + + self._handlers.append((url, RESTHandler, {"transport": self, "config": handler_config})) + LOGGER.debug("Registering handler(s) for url %s and methods %s" % (url, ", ".join(handler_config.keys()))) + + port = 8888 + if self.id in inmanta_config.Config.get() and "port" in inmanta_config.Config.get()[self.id]: + port = inmanta_config.Config.get()[self.id]["port"] + + application = tornado.web.Application(self._handlers, compress_response=True) + + crt = inmanta_config.Config.get("server", "ssl_cert_file", None) + key = inmanta_config.Config.get("server", "ssl_key_file", None) + + if(crt is not None and key is not None): + ssl_ctx = ssl.create_default_context(ssl.Purpose.CLIENT_AUTH) + ssl_ctx.load_cert_chain(crt, key) + + self.http_server = HTTPServer(application, decompress_request=True, ssl_options=ssl_ctx) + LOGGER.debug("Created REST transport with SSL") + else: + self.http_server = HTTPServer(application, decompress_request=True) + + self.http_server.listen(port) + + LOGGER.debug("Start REST transport") + + def stop(self): + LOGGER.debug("Stoppin Server Rest Endpoint") + self.http_server.stop() + for endpoint in self.__end_points: + endpoint.stop() + + def return_error_msg(self, status=500, msg="", headers={}): + body = {"message": msg} + headers["Content-Type"] = "application/json" + LOGGER.debug("Signaling error to client: %d, %s, %s", status, body, headers) + return body, headers, status + + +class ServerSlice(object): + """ + An API serving part of the server. + """ + + def __init__(self, io_loop, name): + self._name = name + self._io_loop = io_loop + + self.create_endpoint_metadata() + self._end_point_names = [] + self._handlers = [] + self._sched = Scheduler(self._io_loop) + + def prestart(self, server: RESTServer): + """Called by the RestServer host prior to start, can be used to collect references to other server slices""" + pass + + def start(self): + pass + + def stop(self): + pass + + name = property(lambda self: self._name) + + def get_handlers(self): + return self._handlers + + def get_end_point_names(self): + # TODO: why? + return self._end_point_names + + def add_end_point_name(self, name): + """ + Add an additional name to this endpoint to which it reacts and sends out in heartbeats + """ + LOGGER.debug("Adding '%s' as endpoint", name) + self._end_point_names.append(name) + + def add_future(self, future): + """ + Add a future to the ioloop to be handled, but do not require the result. + """ + def handle_result(f): + try: + f.result() + except Exception as e: + LOGGER.exception("An exception occurred while handling a future: %s", str(e)) + + self._io_loop.add_future(future, handle_result) + + def schedule(self, call, interval=60): + self._sched.add_action(call, interval) + + def create_endpoint_metadata(self): + total_dict = {method_name: getattr(self, method_name) + for method_name in dir(self) if callable(getattr(self, method_name))} + + methods = {} + for name, attr in total_dict.items(): + if name[0:2] != "__" and hasattr(attr, "__protocol_method__"): + if attr.__protocol_method__ in methods: + raise Exception("Unable to register multiple handlers for the same method. %s" % attr.__protocol_method__) + + methods[attr.__protocol_method__] = (name, attr) + + self.__methods__ = methods + + def add_static_handler(self, location, path, default_filename=None, start=False): + """ + Configure a static handler to serve data from the specified path. + """ + if location[0] != "/": + location = "/" + location + + if location[-1] != "/": + location = location + "/" + + options = {"path": path} + if default_filename is None: + options["default_filename"] = "index.html" + + self._handlers.append((r"%s(.*)" % location, tornado.web.StaticFileHandler, options)) + self._handlers.append((r"%s" % location[:-1], tornado.web.RedirectHandler, {"url": location})) + + if start: + self._handlers.append((r"/", tornado.web.RedirectHandler, {"url": location})) + + def add_static_content(self, path, content, content_type="application/javascript"): + self._handlers.append((r"%s(.*)" % path, StaticContentHandler, {"transport": self, "content": content, + "content_type": content_type})) + + +class Session(object): + """ + An environment that segments agents connected to the server + """ + + def __init__(self, sessionstore, io_loop, sid, hang_interval, timout, tid, endpoint_names, nodename): + self._sid = sid + self._interval = hang_interval + self._timeout = timout + self._sessionstore = sessionstore + self._seen = time.time() + self._callhandle = None + self.expired = False + + self.tid = tid + self.endpoint_names = endpoint_names + self.nodename = nodename + + self._io_loop = io_loop + + self._replies = {} + self.check_expire() + self._queue = queues.Queue() + + self.client = ReturnClient(str(sid), self) + + def check_expire(self): + if self.expired: + LOGGER.exception("Tried to expire session already expired") + ttw = self._timeout + self._seen - time.time() + if ttw < 0: + self.expire(self._seen - time.time()) + else: + self._callhandle = self._io_loop.call_later(ttw, self.check_expire) + + def get_id(self): + return self._sid + + id = property(get_id) + + def expire(self, timeout): + self.expired = True + if self._callhandle is not None: + self._io_loop.remove_timeout(self._callhandle) + self._sessionstore.expire(self, timeout) + + def seen(self): + self._seen = time.time() + + def _set_timeout(self, future, timeout, log_message): + def on_timeout(): + if not self.expired: + LOGGER.warning(log_message) + future.set_exception(gen.TimeoutError()) + + timeout_handle = self._io_loop.add_timeout(self._io_loop.time() + timeout, on_timeout) + future.add_done_callback(lambda _: self._io_loop.remove_timeout(timeout_handle)) + + def put_call(self, call_spec, timeout=10): + future = tornado.concurrent.Future() + + reply_id = uuid.uuid4() + + LOGGER.debug("Putting call %s: %s %s for agent %s in queue", reply_id, call_spec["method"], call_spec["url"], self._sid) + + q = self._queue + call_spec["reply_id"] = reply_id + q.put(call_spec) + self._set_timeout(future, timeout, "Call %s: %s %s for agent %s timed out." % + (reply_id, call_spec["method"], call_spec["url"], self._sid)) + self._replies[call_spec["reply_id"]] = future + + return future + + @gen.coroutine + def get_calls(self): + """ + Get all calls queued for a node. If no work is available, wait until timeout. This method returns none if a call + fails. + """ + try: + q = self._queue + call_list = [] + call = yield q.get(timeout=self._io_loop.time() + self._interval) + call_list.append(call) + while q.qsize() > 0: + call = yield q.get() + call_list.append(call) + + return call_list + + except gen.TimeoutError: + return None + + def set_reply(self, reply_id, data): + LOGGER.log(3, "Received Reply: %s", reply_id) + if reply_id in self._replies: + future = self._replies[reply_id] + del self._replies[reply_id] + if not future.done(): + future.set_result(data) + else: + LOGGER.debug("Received Reply that is unknown: %s", reply_id) + + def get_client(self): + return self.client + + +class SessionListener(object): + + def new_session(self, session: Session): + pass + + def expire(self, session: Session, timeout): + pass + + def seen(self, session: Session, endpoint_names: list): + pass + + +# Internals +class SessionManager(ServerSlice): + """ + A service that receives method calls over one or more transports + """ + __methods__ = {} + + def __init__(self, io_loop): + super().__init__(io_loop, SLICE_SESSION_MANAGER) + + # Config + interval = opt.agent_timeout.get() + hangtime = opt.agent_hangtime.get() + + if hangtime is None: + hangtime = interval * 3 / 4 + + self.hangtime = hangtime + self.interval = interval + + # Session management + self._heartbeat_cb = None + self.agent_handles = {} + self._sessions = {} + + # Listeners + self.listeners = [] + + def add_listener(self, listener): + self.listeners.append(listener) + + def stop(self): + """ + Stop the end-point and all of its transports + """ + # terminate all sessions cleanly + for session in self._sessions.copy().values(): + session.expire(0) + + def validate_sid(self, sid): + if isinstance(sid, str): + sid = uuid.UUID(sid) + return sid in self._sessions + + def get_or_create_session(self, sid, tid, endpoint_names, nodename): + if isinstance(sid, str): + sid = uuid.UUID(sid) + + if sid not in self._sessions: + session = self.new_session(sid, tid, endpoint_names, nodename) + self._sessions[sid] = session + for listener in self.listeners: + listener.new_session(session) + else: + session = self._sessions[sid] + self.seen(session, endpoint_names) + for listener in self.listeners: + listener.seen(session, endpoint_names) + + return session + + def new_session(self, sid, tid, endpoint_names, nodename): + LOGGER.debug("New session with id %s on node %s for env %s with endpoints %s" % (sid, nodename, tid, endpoint_names)) + return Session(self, self._io_loop, sid, self.hangtime, self.interval, tid, endpoint_names, nodename) + + def expire(self, session: Session, timeout): + LOGGER.debug("Expired session with id %s, last seen %d seconds ago" % (session.get_id(), timeout)) + for listener in self.listeners: + listener.expire(session, timeout) + del self._sessions[session.id] + + def seen(self, session: Session, endpoint_names: list): + LOGGER.debug("Seen session with id %s" % (session.get_id())) + session.seen() + + @handle(methods.HeartBeatMethod.heartbeat, env="tid") + @gen.coroutine + def heartbeat(self, sid, env, endpoint_names, nodename): + LOGGER.debug("Received heartbeat from %s for agents %s in %s", nodename, ",".join(endpoint_names), env.id) + + session = self.get_or_create_session(sid, env.id, endpoint_names, nodename) + + LOGGER.debug("Let node %s wait for method calls to become available. (long poll)", nodename) + call_list = yield session.get_calls() + if call_list is not None: + LOGGER.debug("Pushing %d method calls to node %s", len(call_list), nodename) + return 200, {"method_calls": call_list} + else: + LOGGER.debug("Heartbeat wait expired for %s, returning. (long poll)", nodename) + + return 200 + + @handle(methods.HeartBeatMethod.heartbeat_reply) + @gen.coroutine + def heartbeat_reply(self, sid, reply_id, data): + try: + env = self._sessions[sid] + env.set_reply(reply_id, data) + return 200 + except Exception: + LOGGER.warning("could not deliver agent reply with sid=%s and reply_id=%s" % (sid, reply_id), exc_info=True) + + +class RESTHandler(tornado.web.RequestHandler): + """ + A generic class use by the transport + """ + + def initialize(self, transport: "RESTServer", config): + self._transport = transport + self._config = config + + def _get_config(self, http_method): + if http_method.upper() not in self._config: + allowed = ", ".join(self._config.keys()) + self.set_header("Allow", allowed) + self._transport.return_error_msg(405, "%s is not supported for this url. Supported methods: %s" % + (http_method, allowed)) + return + + return self._config[http_method] + + def get_auth_token(self, headers: dict): + """ + Get the auth token provided by the caller. The token is provided as a bearer token. + """ + if "Authorization" not in headers: + return None + + parts = headers["Authorization"].split(" ") + if len(parts) == 0 or parts[0].lower() != "bearer" or len(parts) > 2 or len(parts) == 1: + LOGGER.warning("Invalid authentication header, Inmanta expects a bearer token. (%s was provided)", + headers["Authorization"]) + return None + + return decode_token(parts[1]) + + def respond(self, body, headers, status): + if body is not None: + self.write(json_encode(body)) + + for header, value in headers.items(): + self.set_header(header, value) + + self.set_status(status) + + @gen.coroutine + def _call(self, kwargs, http_method, call_config): + """ + An rpc like call + """ + if call_config is None: + body, headers, status = self._transport.return_error_msg(404, "This method does not exist.") + self.respond(body, headers, status) + return + + self.set_header("Access-Control-Allow-Origin", "*") + try: + message = self._transport._decode(self.request.body) + if message is None: + message = {} + + for key, value in self.request.query_arguments.items(): + if len(value) == 1: + message[key] = value[0].decode("latin-1") + else: + message[key] = [v.decode("latin-1") for v in value] + + request_headers = self.request.headers + + try: + auth_token = self.get_auth_token(request_headers) + except UnauhorizedError as e: + self.respond(*self._transport.return_error_msg(403, "Access denied: " + e.args[0])) + return + + auth_enabled = inmanta_config.Config.get("server", "auth", False) + if not auth_enabled or auth_token is not None: + result = yield self._transport._execute_call(kwargs, http_method, call_config, + message, request_headers, auth_token) + self.respond(*result) + else: + self.respond(*self._transport.return_error_msg(401, "Access to this resource is unauthorized.")) + except ValueError: + LOGGER.exception("An exception occured") + self.respond(*self._transport.return_error_msg(500, "Unable to decode request body")) + + @gen.coroutine + def head(self, *args, **kwargs): + yield self._call(http_method="HEAD", call_config=self._get_config("HEAD"), kwargs=kwargs) + + @gen.coroutine + def get(self, *args, **kwargs): + yield self._call(http_method="GET", call_config=self._get_config("GET"), kwargs=kwargs) + + @gen.coroutine + def post(self, *args, **kwargs): + yield self._call(http_method="POST", call_config=self._get_config("POST"), kwargs=kwargs) + + @gen.coroutine + def delete(self, *args, **kwargs): + yield self._call(http_method="DELETE", call_config=self._get_config("DELETE"), kwargs=kwargs) + + @gen.coroutine + def patch(self, *args, **kwargs): + yield self._call(http_method="PATCH", call_config=self._get_config("PATCH"), kwargs=kwargs) + + @gen.coroutine + def put(self, *args, **kwargs): + yield self._call(http_method="PUT", call_config=self._get_config("PUT"), kwargs=kwargs) + + @gen.coroutine + def options(self, *args, **kwargs): + allow_headers = "Origin, Accept, Content-Type, X-Requested-With, X-CSRF-Token" + if len(self._transport.headers): + allow_headers += ", " + ", ".join(self._transport.headers) + + self.set_header("Access-Control-Allow-Origin", "*") + self.set_header("Access-Control-Allow-Methods", "HEAD, GET, POST, PUT, OPTIONS, DELETE, PATCH") + self.set_header("Access-Control-Allow-Headers", allow_headers) + + self.set_status(200) + + +class StaticContentHandler(tornado.web.RequestHandler): + def initialize(self, transport: "RESTServer", content, content_type): + self._transport = transport + self._content = content + self._content_type = content_type + + def get(self, *args, **kwargs): + self.set_header("Content-Type", self._content_type) + self.write(self._content) + self.set_status(200) diff --git a/src/inmanta/server/server.py b/src/inmanta/server/server.py index ddd003aa25..6c4f899475 100644 --- a/src/inmanta/server/server.py +++ b/src/inmanta/server/server.py @@ -1,5 +1,5 @@ """ - Copyright 2017 Inmanta + Copyright 2018 Inmanta Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. @@ -39,13 +39,13 @@ from inmanta import const from inmanta import data, config from inmanta import methods -from inmanta import protocol +from inmanta.server import protocol, SLICE_SERVER from inmanta.ast import type from inmanta.resources import Id from inmanta.server import config as opt -from inmanta.server.agentmanager import AgentManager import json from inmanta.util import hash_file +from inmanta.protocol import encode_token LOGGER = logging.getLogger(__name__) agent_lock = locks.Lock() @@ -53,15 +53,16 @@ DBLIMIT = 100000 -class Server(protocol.ServerEndpoint): +class Server(protocol.ServerSlice): """ The central Inmanta server that communicates with clients and agents and persists configuration information """ def __init__(self, io_loop, database_host=None, database_port=None, agent_no_log=False): - super().__init__("server", io_loop=io_loop, interval=opt.agent_timeout.get(), hangtime=opt.agent_hangtime.get()) + super().__init__(io_loop=io_loop, name=SLICE_SERVER) LOGGER.info("Starting server endpoint") + self._server_storage = self.check_storage() self._agent_no_log = agent_no_log @@ -80,8 +81,6 @@ def __init__(self, io_loop, database_host=None, database_port=None, agent_no_log self._fact_expire = opt.server_fact_expire.get() self._fact_renew = opt.server_fact_renew.get() - self.add_end_point_name(self.node_name) - self.schedule(self.renew_expired_facts, self._fact_renew) self.schedule(self._purge_versions, opt.server_purge_version_interval.get()) @@ -89,31 +88,17 @@ def __init__(self, io_loop, database_host=None, database_port=None, agent_no_log self._recompiles = defaultdict(lambda: None) - self.agentmanager = AgentManager(self, fact_back_off=opt.server_fact_resource_block.get()) - self.setup_dashboard() self.dryrun_lock = locks.Lock() - def new_session(self, sid, tid, endpoint_names, nodename): - session = protocol.ServerEndpoint.new_session(self, sid, tid, endpoint_names, nodename) - self.agentmanager.new_session(session) - return session - - def expire(self, session, timeout): - self.agentmanager.expire(session) - protocol.ServerEndpoint.expire(self, session, timeout) - - def seen(self, session, endpoint_names): - self.agentmanager.seen(session, endpoint_names) - protocol.ServerEndpoint.seen(self, session, endpoint_names) + def prestart(self, server): + self.agentmanager = server.get_endpoint("agentmanager") def start(self): super().start() - self.agentmanager.start() def stop(self): super().stop() - self.agentmanager.stop() def get_agent_client(self, tid: UUID, endpoint): return self.agentmanager.get_agent_client(tid, endpoint) @@ -144,8 +129,8 @@ def setup_dashboard(self): 'backend': window.location.origin+'/'%s }); """ % auth - self._transport_instance.add_static_content("/dashboard/config.js", content=content) - self._transport_instance.add_static_handler("/dashboard", dashboard_path, start=True) + self.add_static_content("/dashboard/config.js", content=content) + self.add_static_handler("/dashboard", dashboard_path, start=True) @gen.coroutine def _purge_versions(self): @@ -644,36 +629,6 @@ def file_diff(self, a, b): return 200, {"diff": list(diff)} - @protocol.handle(methods.NodeMethod.get_agent_process, agent_id="id") - @gen.coroutine - def get_agent_process(self, agent_id): - return (yield self.agentmanager.get_agent_process_report(agent_id)) - - @protocol.handle(methods.ServerAgentApiMethod.trigger_agent, agent_id="id", env="tid") - @gen.coroutine - def trigger_agent(self, env, agent_id): - yield self.agentmanager.trigger_agent(env.id, agent_id) - - @protocol.handle(methods.NodeMethod.list_agent_processes) - @gen.coroutine - def list_agent_processes(self, environment, expired): - if environment is not None: - env = yield data.Environment.get_by_id(environment) - if env is None: - return 404, {"message": "The given environment id does not exist!"} - - return (yield self.agentmanager.list_agent_processes(environment, expired)) - - @protocol.handle(methods.ServerAgentApiMethod.list_agents, env="tid") - @gen.coroutine - def list_agents(self, env): - return (yield self.agentmanager.list_agents(env.id)) - - @protocol.handle(methods.AgentRecovery.get_state, env="tid") - @gen.coroutine - def get_state(self, env: uuid.UUID, sid: uuid.UUID, agent: str): - return (yield self.agentmanager.get_state(env.id, sid, agent)) - @protocol.handle(methods.ResourceMethod.get_resource, resource_id="id", env="tid") @gen.coroutine def get_resource(self, env, resource_id, logs, status, log_action, log_limit): @@ -1547,7 +1502,7 @@ def _recompile_environment(self, environment_id, update_repo=False, wait=0, meta cmd = inmanta_path + ["-vvv", "export", "-e", str(environment_id), "--server_address", server_address, "--server_port", opt.transport_port.get(), "--metadata", json.dumps(metadata)] if config.Config.get("server", "auth", False): - token = protocol.encode_token(["compiler", "api"], str(environment_id)) + token = encode_token(["compiler", "api"], str(environment_id)) cmd.append("--token") cmd.append(token) @@ -1819,4 +1774,4 @@ def create_token(self, env, client_types, idempotent): """ Create a new auth token for this environment """ - return 200, {"token": protocol.encode_token(client_types, str(env.id), idempotent)} + return 200, {"token": encode_token(client_types, str(env.id), idempotent)} diff --git a/src/inmanta/util.py b/src/inmanta/util.py index f24d85f700..b459532ac9 100644 --- a/src/inmanta/util.py +++ b/src/inmanta/util.py @@ -64,3 +64,48 @@ def hash_file(content): sha1sum.update(content) return sha1sum.hexdigest() + + +class Scheduler(object): + """ + An event scheduler class + """ + + def __init__(self, io_loop): + self._scheduled = set() + self._io_loop = io_loop + + def add_action(self, action, interval, initial_delay=None): + """ + Add a new action + + :param action A function to call periodically + :param interval The interval between execution of actions + :param initial_delay Delay to the first execution, default to interval + """ + + if initial_delay is None: + initial_delay = interval + + LOGGER.debug("Scheduling action %s every %d seconds with initial delay %d", action, interval, initial_delay) + + def action_function(): + LOGGER.info("Calling %s" % action) + if action in self._scheduled: + try: + action() + except Exception: + LOGGER.exception("Uncaught exception while executing scheduled action") + + finally: + self._io_loop.call_later(interval, action_function) + + self._io_loop.call_later(initial_delay, action_function) + self._scheduled.add(action) + + def remove(self, action): + """ + Remove a scheduled action + """ + if action in self._scheduled: + self._scheduled.remove(action) diff --git a/tests/conftest.py b/tests/conftest.py index e9fdc4376f..ea68c2af2a 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -40,6 +40,7 @@ from tornado import gen import re from tornado.ioloop import IOLoop +from inmanta.server.bootloader import InmantaBootloader DEFAULT_PORT_ENVVAR = 'MONGOBOX_PORT' @@ -142,11 +143,12 @@ def server(inmanta_config, io_loop, mongo_db, mongo_client, motor): # causes handler failure IOLoop._instance = io_loop - from inmanta.server import Server state_dir = tempfile.mkdtemp() port = get_free_tcp_port() - config.Config.get("database", "name", "inmanta-" + ''.join(random.choice(string.ascii_letters) for _ in range(10))) + config.Config.set("database", "name", "inmanta-" + ''.join(random.choice(string.ascii_letters) for _ in range(10))) + config.Config.set("database", "host", "localhost") + config.Config.set("database", "port", str(mongo_db.port)) config.Config.set("config", "state-dir", state_dir) config.Config.set("config", "log-dir", os.path.join(state_dir, "logs")) config.Config.set("server_rest_transport", "port", port) @@ -159,13 +161,13 @@ def server(inmanta_config, io_loop, mongo_db, mongo_client, motor): data.use_motor(motor) - server = Server(database_host="localhost", database_port=int(mongo_db.port), io_loop=io_loop) - server.start() + ibl = InmantaBootloader() + ibl.start() - yield server + yield ibl.restserver + ibl.stop() del IOLoop._instance - server.stop() shutil.rmtree(state_dir) @@ -174,7 +176,8 @@ def server(inmanta_config, io_loop, mongo_db, mongo_client, motor): (False, False, False), (True, True, True)], ids=["SSL and Auth", "SSL", "Auth", "Normal", "SSL and Auth with not self signed certificate"]) def server_multi(inmanta_config, io_loop, mongo_db, mongo_client, request): - from inmanta.server import Server + IOLoop._instance = io_loop + state_dir = tempfile.mkdtemp() ssl, auth, ca = request.param @@ -208,7 +211,9 @@ def server_multi(inmanta_config, io_loop, mongo_db, mongo_client, request): config.Config.set(x, "token", token) port = get_free_tcp_port() - config.Config.get("database", "name", "inmanta-" + ''.join(random.choice(string.ascii_letters) for _ in range(10))) + config.Config.set("database", "name", "inmanta-" + ''.join(random.choice(string.ascii_letters) for _ in range(10))) + config.Config.set("database", "host", "localhost") + config.Config.set("database", "port", str(mongo_db.port)) config.Config.set("config", "state-dir", state_dir) config.Config.set("config", "log-dir", os.path.join(state_dir, "logs")) config.Config.set("server_rest_transport", "port", port) @@ -219,12 +224,18 @@ def server_multi(inmanta_config, io_loop, mongo_db, mongo_client, request): config.Config.set("config", "executable", os.path.abspath(os.path.join(__file__, "../../src/inmanta/app.py"))) config.Config.set("server", "agent-timeout", "2") - server = Server(database_host="localhost", database_port=int(mongo_db.port), io_loop=io_loop) - server.start() + ibl = InmantaBootloader() + ibl.start() + + yield ibl.restserver + + ibl.stop() - yield server + try: + del IOLoop._instance + except Exception: + pass - server.stop() shutil.rmtree(state_dir) diff --git a/tests/test_2way_protocol.py b/tests/test_2way_protocol.py index f041cce40a..ed7a9018e8 100644 --- a/tests/test_2way_protocol.py +++ b/tests/test_2way_protocol.py @@ -21,12 +21,16 @@ import uuid import colorlog -from inmanta import methods +from inmanta import methods, data from tornado import gen import pytest from tornado.gen import sleep from utils import retry_limited from tornado.ioloop import IOLoop +from inmanta.server.protocol import RESTServer, SessionListener, ServerSlice +from inmanta.server import SLICE_SESSION_MANAGER, server +from inmanta.methods import ENV_ARG +import importlib LOGGER = logging.getLogger(__name__) @@ -39,7 +43,7 @@ def get_status_x(self, tid: uuid.UUID): pass @methods.protocol(operation="GET", id=True, server_agent=True, timeout=10) - def get_agent_status(self, id): + def get_agent_status_x(self, id): pass @@ -47,38 +51,55 @@ def get_agent_status(self, id): from inmanta import protocol # NOQA -class Server(protocol.ServerEndpoint): +class SessionSpy(SessionListener, ServerSlice): - def __init__(self, name, io_loop, interval=60): - protocol.ServerEndpoint.__init__(self, name, io_loop, interval=interval) + def __init__(self): + ServerSlice.__init__(self, IOLoop.current(), "sessionspy") self.expires = 0 + self.__sessions = [] + + def new_session(self, session): + self.__sessions.append(session) @protocol.handle(StatusMethod.get_status_x) @gen.coroutine def get_status_x(self, tid): status_list = [] - for session in self._sessions.values(): + for session in self.__sessions: client = session.get_client() - status = yield client.get_agent_status("x") + status = yield client.get_agent_status_x("x") if status is not None and status.code == 200: status_list.append(status.result) return 200, {"agents": status_list} def expire(self, session, timeout): - protocol.ServerEndpoint.expire(self, session, timeout) + self.__sessions.remove(session) print(session._sid) self.expires += 1 + def get_sessions(self): + return self.__sessions + class Agent(protocol.AgentEndPoint): - @protocol.handle(StatusMethod.get_agent_status) + @protocol.handle(StatusMethod.get_agent_status_x) @gen.coroutine - def get_agent_status(self, id): + def get_agent_status_x(self, id): return 200, {"status": "ok", "agents": self.end_point_names} +importlib.reload(protocol) +importlib.reload(server.protocol) + + +@gen.coroutine +def get_environment(env: uuid.UUID, metadata: dict): + return data.Environment(from_mongo=True, _id=env, name="test", project=env, repo_url="xx", repo_branch="xx") + + +@pytest.mark.gen_test(timeout=30) def test_2way_protocol(free_port, logs=False): from inmanta.config import Config @@ -118,17 +139,26 @@ def test_2way_protocol(free_port, logs=False): Config.set("client_rest_transport", "port", free_port) Config.set("cmdline_rest_transport", "port", free_port) - io_loop = IOLoop.current() - server = Server("server", io_loop) - server.start() + # Disable validation of envs + old_get_env = ENV_ARG["getter"] + ENV_ARG["getter"] = get_environment - agent = Agent("agent", io_loop) - agent.add_end_point_name("agent") - agent.set_environment(uuid.uuid4()) - agent.start() + try: + io_loop = IOLoop.current() + rs = RESTServer() + server = SessionSpy() + rs.get_endpoint(SLICE_SESSION_MANAGER).add_listener(server) + rs.add_endpoint(server) + rs.start() + + agent = Agent("agent", io_loop) + agent.add_end_point_name("agent") + agent.set_environment(uuid.uuid4()) + agent.start() + + yield retry_limited(lambda: len(server.get_sessions()) == 1, 0.1) + assert len(server.get_sessions()) == 1 - @gen.coroutine - def do_call(): client = protocol.Client("client") status = yield client.get_status_x(str(agent.environment)) assert status.code == 200 @@ -138,24 +168,21 @@ def do_call(): server.stop() io_loop.stop() - io_loop.add_callback(do_call) - io_loop.add_timeout(io_loop.time() + 2, lambda: io_loop.stop()) - try: - io_loop.start() - except KeyboardInterrupt: - io_loop.stop() - server.stop() - agent.stop() + rs.stop() + agent.stop() + finally: + ENV_ARG["getter"] = old_get_env @gen.coroutine def check_sessions(sessions): for s in sessions: - a = yield s.client.get_agent_status("X") + a = yield s.client.get_agent_status_x("X") assert a.get_result()['status'] == 'ok' @pytest.mark.slowtest +@pytest.mark.gen_test(timeout=30) def test_timeout(free_port): from inmanta.config import Config @@ -171,22 +198,31 @@ def test_timeout(free_port): Config.set("compiler_rest_transport", "port", free_port) Config.set("client_rest_transport", "port", free_port) Config.set("cmdline_rest_transport", "port", free_port) - server = Server("server", io_loop, interval=2) - server.start() + Config.set("server", "agent-timeout", "1") + + # Disable validation of envs + old_get_env = ENV_ARG["getter"] + ENV_ARG["getter"] = get_environment - env = uuid.uuid4() + try: - # agent 1 - agent = Agent("agent", io_loop) - agent.add_end_point_name("agent") - agent.set_environment(env) - agent.start() + rs = RESTServer() + server = SessionSpy() + rs.get_endpoint(SLICE_SESSION_MANAGER).add_listener(server) + rs.add_endpoint(server) + rs.start() + + env = uuid.uuid4() + + # agent 1 + agent = Agent("agent", io_loop) + agent.add_end_point_name("agent") + agent.set_environment(env) + agent.start() - @gen.coroutine - def do_call(): # wait till up - yield retry_limited(lambda: len(server._sessions) == 1, 0.1) - assert len(server._sessions) == 1 + yield retry_limited(lambda: len(server.get_sessions()) == 1, 0.1) + assert len(server.get_sessions()) == 1 # agent 2 agent2 = Agent("agent", io_loop) @@ -195,14 +231,14 @@ def do_call(): agent2.start() # wait till up - yield retry_limited(lambda: len(server._sessions) == 2, 0.1) - assert len(server._sessions) == 2 + yield retry_limited(lambda: len(server.get_sessions()) == 2, 0.1) + assert len(server.get_sessions()) == 2 # see if it stays up - yield(check_sessions(server._sessions.values())) + yield(check_sessions(server.get_sessions())) yield sleep(2) - assert len(server._sessions) == 2 - yield(check_sessions(server._sessions.values())) + assert len(server.get_sessions()) == 2 + yield(check_sessions(server.get_sessions())) # take it down agent2.stop() @@ -210,20 +246,14 @@ def do_call(): # timout yield sleep(2) # check if down - assert len(server._sessions) == 1 - print(server._sessions) - yield(check_sessions(server._sessions.values())) + assert len(server.get_sessions()) == 1 + print(server.get_sessions()) + yield(check_sessions(server.get_sessions())) assert server.expires == 1 agent.stop() server.stop() - io_loop.stop() - - io_loop.add_callback(do_call) - io_loop.add_timeout(io_loop.time() + 2, lambda: io_loop.stop()) - try: - io_loop.start() - except KeyboardInterrupt: - io_loop.stop() - server.stop() - agent.stop() + rs.stop() + agent.stop() + finally: + ENV_ARG["getter"] = old_get_env diff --git a/tests/test_agent.py b/tests/test_agent.py index aaffcc81cc..9f662e8dde 100644 --- a/tests/test_agent.py +++ b/tests/test_agent.py @@ -19,6 +19,7 @@ import pytest from utils import retry_limited from inmanta.agent import reporting +from inmanta.server import SLICE_SESSION_MANAGER @pytest.mark.slowtest @@ -29,8 +30,8 @@ def test_agent_get_status(io_loop, server, environment): myagent.add_end_point_name("agent1") myagent.start() - yield retry_limited(lambda: len(server._sessions) == 1, 0.5) - clients = server._sessions.values() + yield retry_limited(lambda: len(server.get_endpoint(SLICE_SESSION_MANAGER)._sessions) == 1, 0.5) + clients = server.get_endpoint(SLICE_SESSION_MANAGER)._sessions.values() assert len(clients) == 1 clients = [x for x in clients] client = clients[0].get_client() diff --git a/tests/test_agent_manager.py b/tests/test_agent_manager.py index 6710afd120..c36c724008 100644 --- a/tests/test_agent_manager.py +++ b/tests/test_agent_manager.py @@ -84,6 +84,7 @@ def test_primary_selection(motor): futures = Collector() server.add_future.side_effect = futures am = AgentManager(server, False) + am.add_future = futures @gen.coroutine def assert_agent(name: str, state: str, sid: UUID): @@ -134,7 +135,7 @@ def assert_agents(s1, s2, s3, sid1=None, sid2=None, sid3=None): yield assert_agents("paused", "up", "up", sid2=ts1.id, sid3=ts2.id) # expire first - am.expire(ts1) + am.expire(ts1, 100) yield futures.proccess() assert len(am.sessions) == 1 ts2.get_client().set_state.assert_called_with("agent2", True) @@ -142,7 +143,7 @@ def assert_agents(s1, s2, s3, sid1=None, sid2=None, sid3=None): yield assert_agents("paused", "up", "up", sid2=ts2.id, sid3=ts2.id) # expire second - am.expire(ts2) + am.expire(ts2, 100) yield futures.proccess() assert len(am.sessions) == 0 yield assert_agents("paused", "down", "down") @@ -165,6 +166,7 @@ def test_api(motor): futures = Collector() server.add_future.side_effect = futures am = AgentManager(server, False) + am.add_future = futures # one session ts1 = MockSession(uuid4(), env.id, ["agent1", "agent2"], "ts1") @@ -247,7 +249,7 @@ def dummy_status(): 'environment': env2.id, "state": "down"}]} assert_equal_ish(shouldbe, all_agents, ['name']) - code, all_agents = yield am.list_agents(env2.id) + code, all_agents = yield am.list_agents(env2) assert code == 200 shouldbe = { 'agents': [{'name': 'agent4', 'paused': False, 'last_failover': '', 'primary': '', @@ -269,6 +271,7 @@ def test_db_clean(motor): futures = Collector() server.add_future.side_effect = futures am = AgentManager(server, False) + am.add_future = futures @gen.coroutine def assert_agent(name: str, state: str, sid: UUID): @@ -318,7 +321,7 @@ def assert_agents(s1, s2, s3, sid1=None, sid2=None, sid3=None): yield assert_agents("paused", "up", "up", sid2=ts1.id, sid3=ts2.id) # expire first - am.expire(ts1) + am.expire(ts1, 100) yield futures.proccess() assert len(am.sessions) == 1 ts2.get_client().set_state.assert_called_with("agent2", True) @@ -327,6 +330,7 @@ def assert_agents(s1, s2, s3, sid1=None, sid2=None, sid3=None): # failover am = AgentManager(server, False) + am.add_future = futures yield am.clean_db() # one session @@ -354,7 +358,7 @@ def assert_agents(s1, s2, s3, sid1=None, sid2=None, sid3=None): yield assert_agents("paused", "up", "up", sid2=ts1.id, sid3=ts2.id) # expire first - am.expire(ts1) + am.expire(ts1, 100) yield futures.proccess() assert len(am.sessions) == 1 ts2.get_client().set_state.assert_called_with("agent2", True) @@ -362,7 +366,7 @@ def assert_agents(s1, s2, s3, sid1=None, sid2=None, sid3=None): yield assert_agents("paused", "up", "up", sid2=ts2.id, sid3=ts2.id) # expire second - am.expire(ts2) + am.expire(ts2, 100) yield futures.proccess() assert len(am.sessions) == 0 yield assert_agents("paused", "down", "down") diff --git a/tests/test_server.py b/tests/test_server.py index 2bf9230dc5..33a564f99a 100644 --- a/tests/test_server.py +++ b/tests/test_server.py @@ -25,7 +25,7 @@ from inmanta.agent.agent import Agent from inmanta import data, protocol from inmanta import const -from inmanta.server import config as opt +from inmanta.server import config as opt, SLICE_AGENT_MANAGER, SLICE_SESSION_MANAGER from datetime import datetime from uuid import UUID from inmanta.export import upload_code @@ -43,35 +43,38 @@ def test_autostart(server, client, environment): env = yield data.Environment.get_by_id(uuid.UUID(environment)) yield env.set(data.AUTOSTART_AGENT_MAP, {"iaas_agent": "", "iaas_agentx": ""}) - yield server.agentmanager.ensure_agent_registered(env, "iaas_agent") - yield server.agentmanager.ensure_agent_registered(env, "iaas_agentx") + agentmanager = server.get_endpoint(SLICE_AGENT_MANAGER) + sessionendpoint = server.get_endpoint(SLICE_SESSION_MANAGER) - res = yield server.agentmanager._ensure_agents(env, ["iaas_agent"]) + yield agentmanager.ensure_agent_registered(env, "iaas_agent") + yield agentmanager.ensure_agent_registered(env, "iaas_agentx") + + res = yield agentmanager._ensure_agents(env, ["iaas_agent"]) assert res - yield retry_limited(lambda: len(server._sessions) == 1, 20) - assert len(server._sessions) == 1 - res = yield server.agentmanager._ensure_agents(env, ["iaas_agent"]) + yield retry_limited(lambda: len(sessionendpoint._sessions) == 1, 20) + assert len(sessionendpoint._sessions) == 1 + res = yield agentmanager._ensure_agents(env, ["iaas_agent"]) assert not res - assert len(server._sessions) == 1 + assert len(sessionendpoint._sessions) == 1 LOGGER.warning("Killing agent") - server.agentmanager._agent_procs[env.id].terminate() - yield retry_limited(lambda: len(server._sessions) == 0, 20) - res = yield server.agentmanager._ensure_agents(env, ["iaas_agent"]) + agentmanager._agent_procs[env.id].terminate() + yield retry_limited(lambda: len(sessionendpoint._sessions) == 0, 20) + res = yield agentmanager._ensure_agents(env, ["iaas_agent"]) assert res - yield retry_limited(lambda: len(server._sessions) == 1, 3) - assert len(server._sessions) == 1 + yield retry_limited(lambda: len(sessionendpoint._sessions) == 1, 3) + assert len(sessionendpoint._sessions) == 1 # second agent for same env - res = yield server.agentmanager._ensure_agents(env, ["iaas_agentx"]) + res = yield agentmanager._ensure_agents(env, ["iaas_agentx"]) assert res - yield retry_limited(lambda: len(server._sessions) == 1, 20) - assert len(server._sessions) == 1 + yield retry_limited(lambda: len(sessionendpoint._sessions) == 1, 20) + assert len(sessionendpoint._sessions) == 1 # Test stopping all agents - yield server.agentmanager.stop_agents(env) - assert len(server._sessions) == 0 - assert len(server.agentmanager._agent_procs) == 0 + yield agentmanager.stop_agents(env) + assert len(sessionendpoint._sessions) == 0 + assert len(agentmanager._agent_procs) == 0 @pytest.mark.gen_test(timeout=60) @@ -80,6 +83,10 @@ def test_autostart_dual_env(client, server): """ Test auto start of agent """ + + agentmanager = server.get_endpoint("server").agentmanager + sessionendpoint = server.get_endpoint("session") + result = yield client.create_project("env-test") assert result.code == 200 project_id = result.result["project"]["id"] @@ -96,18 +103,18 @@ def test_autostart_dual_env(client, server): env2 = yield data.Environment.get_by_id(uuid.UUID(env_id2)) yield env2.set(data.AUTOSTART_AGENT_MAP, {"iaas_agent": ""}) - yield server.agentmanager.ensure_agent_registered(env, "iaas_agent") - yield server.agentmanager.ensure_agent_registered(env2, "iaas_agent") + yield agentmanager.ensure_agent_registered(env, "iaas_agent") + yield agentmanager.ensure_agent_registered(env2, "iaas_agent") - res = yield server.agentmanager._ensure_agents(env, ["iaas_agent"]) + res = yield agentmanager._ensure_agents(env, ["iaas_agent"]) assert res - yield retry_limited(lambda: len(server._sessions) == 1, 20) - assert len(server._sessions) == 1 + yield retry_limited(lambda: len(sessionendpoint._sessions) == 1, 20) + assert len(sessionendpoint._sessions) == 1 - res = yield server.agentmanager._ensure_agents(env2, ["iaas_agent"]) + res = yield agentmanager._ensure_agents(env2, ["iaas_agent"]) assert res - yield retry_limited(lambda: len(server._sessions) == 2, 20) - assert len(server._sessions) == 2 + yield retry_limited(lambda: len(sessionendpoint._sessions) == 2, 20) + assert len(sessionendpoint._sessions) == 2 @pytest.mark.gen_test(timeout=60) @@ -119,28 +126,31 @@ def test_autostart_batched(client, server, environment): env = yield data.Environment.get_by_id(uuid.UUID(environment)) yield env.set(data.AUTOSTART_AGENT_MAP, {"iaas_agent": "", "iaas_agentx": ""}) - yield server.agentmanager.ensure_agent_registered(env, "iaas_agent") - yield server.agentmanager.ensure_agent_registered(env, "iaas_agentx") + agentmanager = server.get_endpoint(SLICE_AGENT_MANAGER) + sessionendpoint = server.get_endpoint(SLICE_SESSION_MANAGER) + + yield agentmanager.ensure_agent_registered(env, "iaas_agent") + yield agentmanager.ensure_agent_registered(env, "iaas_agentx") - res = yield server.agentmanager._ensure_agents(env, ["iaas_agent", "iaas_agentx"]) + res = yield agentmanager._ensure_agents(env, ["iaas_agent", "iaas_agentx"]) assert res - yield retry_limited(lambda: len(server._sessions) == 1, 20) - assert len(server._sessions) == 1 - res = yield server.agentmanager._ensure_agents(env, ["iaas_agent"]) + yield retry_limited(lambda: len(sessionendpoint._sessions) == 1, 20) + assert len(sessionendpoint._sessions) == 1 + res = yield agentmanager._ensure_agents(env, ["iaas_agent"]) assert not res - assert len(server._sessions) == 1 + assert len(sessionendpoint._sessions) == 1 - res = yield server.agentmanager._ensure_agents(env, ["iaas_agent", "iaas_agentx"]) + res = yield agentmanager._ensure_agents(env, ["iaas_agent", "iaas_agentx"]) assert not res - assert len(server._sessions) == 1 + assert len(sessionendpoint._sessions) == 1 LOGGER.warning("Killing agent") - server.agentmanager._agent_procs[env.id].terminate() - yield retry_limited(lambda: len(server._sessions) == 0, 20) - res = yield server.agentmanager._ensure_agents(env, ["iaas_agent", "iaas_agentx"]) + agentmanager._agent_procs[env.id].terminate() + yield retry_limited(lambda: len(sessionendpoint._sessions) == 0, 20) + res = yield agentmanager._ensure_agents(env, ["iaas_agent", "iaas_agentx"]) assert res - yield retry_limited(lambda: len(server._sessions) == 1, 3) - assert len(server._sessions) == 1 + yield retry_limited(lambda: len(sessionendpoint._sessions) == 1, 3) + assert len(sessionendpoint._sessions) == 1 @pytest.mark.gen_test(timeout=10) @@ -160,7 +170,7 @@ def test_version_removal(client, server): for _i in range(20): version += 1 - yield server._purge_versions() + yield server.get_endpoint("server")._purge_versions() res = yield client.put_version(tid=env_id, version=version, resources=[], unknowns=[], version_info={}) assert res.code == 200 result = yield client.get_project(id=project_id) diff --git a/tests/test_server_agent.py b/tests/test_server_agent.py index bdd622ebdf..b808056d68 100644 --- a/tests/test_server_agent.py +++ b/tests/test_server_agent.py @@ -36,8 +36,9 @@ from inmanta.agent.agent import Agent from utils import retry_limited, assert_equal_ish, UNKWN from inmanta.config import Config -from inmanta.server.server import Server from inmanta.ast import CompilerException +from inmanta.server.bootloader import InmantaBootloader +from inmanta.server import SLICE_AGENT_MANAGER logger = logging.getLogger("inmanta.test.server_agent") @@ -228,6 +229,9 @@ def test_dryrun_and_deploy(io_loop, server_multi, client_multi, resource_contain There is a second agent with an undefined resource. The server will shortcut the dryrun and deploy for this resource without an agent being present. """ + + agentmanager = server_multi.get_endpoint(SLICE_AGENT_MANAGER) + resource_container.Provider.reset() result = yield client_multi.create_project("env-test") project_id = result.result["project"]["id"] @@ -240,7 +244,7 @@ def test_dryrun_and_deploy(io_loop, server_multi, client_multi, resource_contain agent.add_end_point_name("agent1") agent.start() - yield retry_limited(lambda: len(server_multi.agentmanager.sessions) == 1, 10) + yield retry_limited(lambda: len(agentmanager.sessions) == 1, 10) resource_container.Provider.set("agent1", "key2", "incorrect_value") resource_container.Provider.set("agent1", "key3", "value") @@ -369,6 +373,8 @@ def test_server_restart(resource_container, io_loop, server, mongo_db, client): """ dryrun and deploy a configuration model """ + agentmanager = server.get_endpoint(SLICE_AGENT_MANAGER) + resource_container.Provider.reset() result = yield client.create_project("env-test") project_id = result.result["project"]["id"] @@ -380,16 +386,19 @@ def test_server_restart(resource_container, io_loop, server, mongo_db, client): code_loader=False) agent.add_end_point_name("agent1") agent.start() - yield retry_limited(lambda: len(server.agentmanager.sessions) == 1, 10) + yield retry_limited(lambda: len(agentmanager.sessions) == 1, 10) resource_container.Provider.set("agent1", "key2", "incorrect_value") resource_container.Provider.set("agent1", "key3", "value") server.stop() - server = Server(database_host="localhost", database_port=int(mongo_db.port), io_loop=io_loop) - server.start() - yield retry_limited(lambda: len(server.agentmanager.sessions) == 1, 10) + ibl = InmantaBootloader() + server = ibl.restserver + ibl.start() + agentmanager = server.get_endpoint(SLICE_AGENT_MANAGER) + + yield retry_limited(lambda: len(agentmanager.sessions) == 1, 10) version = int(time.time()) @@ -482,7 +491,7 @@ def test_server_restart(resource_container, io_loop, server, mongo_db, client): assert not resource_container.Provider.isset("agent1", "key3") agent.stop() - server.stop() + ibl.stop() @pytest.mark.gen_test(timeout=30) @@ -490,6 +499,8 @@ def test_spontaneous_deploy(resource_container, io_loop, server, client): """ dryrun and deploy a configuration model """ + agentmanager = server.get_endpoint(SLICE_AGENT_MANAGER) + resource_container.Provider.reset() result = yield client.create_project("env-test") project_id = result.result["project"]["id"] @@ -504,7 +515,7 @@ def test_spontaneous_deploy(resource_container, io_loop, server, client): code_loader=False) agent.add_end_point_name("agent1") agent.start() - yield retry_limited(lambda: len(server.agentmanager.sessions) == 1, 10) + yield retry_limited(lambda: len(agentmanager.sessions) == 1, 10) resource_container.Provider.set("agent1", "key2", "incorrect_value") resource_container.Provider.set("agent1", "key3", "value") @@ -583,7 +594,7 @@ def test_dual_agent(resource_container, io_loop, server, client, environment): myagent.add_end_point_name("agent1") myagent.add_end_point_name("agent2") myagent.start() - yield retry_limited(lambda: len(server._sessions) == 1, 10) + yield retry_limited(lambda: len(server.get_endpoint("session")._sessions) == 1, 10) resource_container.Provider.set("agent1", "key1", "incorrect_value") resource_container.Provider.set("agent2", "key1", "incorrect_value") @@ -682,7 +693,7 @@ def test_snapshot_restore(resource_container, client, server, io_loop): code_loader=False) agent.add_end_point_name("agent1") agent.start() - yield retry_limited(lambda: len(server._sessions) == 1, 10) + yield retry_limited(lambda: len(server.get_endpoint("session")._sessions) == 1, 10) resource_container.Provider.set("agent1", "key", "value") @@ -776,6 +787,8 @@ def test_snapshot_restore(resource_container, client, server, io_loop): @pytest.mark.gen_test def test_server_agent_api(resource_container, client, server, io_loop): + agentmanager = server.get_endpoint(SLICE_AGENT_MANAGER) + result = yield client.create_project("env-test") project_id = result.result["project"]["id"] @@ -789,8 +802,8 @@ def test_server_agent_api(resource_container, client, server, io_loop): code_loader=False) agent.start() - yield retry_limited(lambda: len(server.agentmanager.sessions) == 2, 10) - assert len(server.agentmanager.sessions) == 2 + yield retry_limited(lambda: len(agentmanager.sessions) == 2, 10) + assert len(agentmanager.sessions) == 2 result = yield client.list_agent_processes(env_id) assert result.code == 200 @@ -881,7 +894,7 @@ def test_get_facts(resource_container, client, server, io_loop): code_loader=False) agent.add_end_point_name("agent1") agent.start() - yield retry_limited(lambda: len(server._sessions) == 1, 10) + yield retry_limited(lambda: len(server.get_endpoint("session")._sessions) == 1, 10) resource_container.Provider.set("agent1", "key", "value") @@ -929,7 +942,7 @@ def test_purged_facts(resource_container, client, server, io_loop, environment): code_loader=False) agent.add_end_point_name("agent1") agent.start() - yield retry_limited(lambda: len(server._sessions) == 1, 10) + yield retry_limited(lambda: len(server.get_endpoint("session")._sessions) == 1, 10) resource_container.Provider.set("agent1", "key", "value") @@ -1026,7 +1039,7 @@ def test_unkown_parameters(resource_container, client, server, io_loop): code_loader=False) agent.add_end_point_name("agent1") agent.start() - yield retry_limited(lambda: len(server._sessions) == 1, 10) + yield retry_limited(lambda: len(server.get_endpoint("session")._sessions) == 1, 10) resource_container.Provider.set("agent1", "key", "value") @@ -1054,7 +1067,7 @@ def test_unkown_parameters(resource_container, client, server, io_loop): result = yield client.release_version(env_id, version, True) assert result.code == 200 - yield server.renew_expired_facts() + yield server.get_endpoint("server").renew_expired_facts() env_id = uuid.UUID(env_id) params = yield data.Parameter.get_list(environment=env_id, resource_id=resource_id_wov) @@ -1082,7 +1095,7 @@ def test_fail(resource_container, client, server, io_loop): code_loader=False, poolsize=10) agent.add_end_point_name("agent1") agent.start() - yield retry_limited(lambda: len(server._sessions) == 1, 10) + yield retry_limited(lambda: len(server.get_endpoint("session")._sessions) == 1, 10) resource_container.Provider.set("agent1", "key", "value") @@ -1192,7 +1205,7 @@ def test_wait(resource_container, client, server, io_loop): agent.start() # wait for agent - yield retry_limited(lambda: len(server._sessions) == 1, 10) + yield retry_limited(lambda: len(server.get_endpoint("session")._sessions) == 1, 10) # set the deploy environment resource_container.Provider.set("agent1", "key", "value") @@ -1343,7 +1356,7 @@ def test_multi_instance(resource_container, client, server, io_loop): agent.start() # wait for agent - yield retry_limited(lambda: len(server._sessions) == 1, 10) + yield retry_limited(lambda: len(server.get_endpoint("session")._sessions) == 1, 10) # set the deploy environment resource_container.Provider.set("agent1", "key", "value") @@ -1454,6 +1467,8 @@ def test_cross_agent_deps(resource_container, io_loop, server, client): """ deploy a configuration model with cross host dependency """ + agentmanager = server.get_endpoint(SLICE_AGENT_MANAGER) + resource_container.Provider.reset() # config for recovery mechanism Config.set("config", "agent-interval", "10") @@ -1467,13 +1482,13 @@ def test_cross_agent_deps(resource_container, io_loop, server, client): code_loader=False) agent.add_end_point_name("agent1") agent.start() - yield retry_limited(lambda: len(server.agentmanager.sessions) == 1, 10) + yield retry_limited(lambda: len(agentmanager.sessions) == 1, 10) agent2 = Agent(io_loop, hostname="node2", environment=env_id, agent_map={"agent2": "localhost"}, code_loader=False) agent2.add_end_point_name("agent2") agent2.start() - yield retry_limited(lambda: len(server.agentmanager.sessions) == 2, 10) + yield retry_limited(lambda: len(agentmanager.sessions) == 2, 10) resource_container.Provider.set("agent1", "key2", "incorrect_value") resource_container.Provider.set("agent1", "key3", "value") @@ -1559,6 +1574,8 @@ def test_dryrun_scale(resource_container, io_loop, server, client): """ test dryrun scaling """ + agentmanager = server.get_endpoint(SLICE_AGENT_MANAGER) + resource_container.Provider.reset() result = yield client.create_project("env-test") project_id = result.result["project"]["id"] @@ -1570,7 +1587,7 @@ def test_dryrun_scale(resource_container, io_loop, server, client): code_loader=False) agent.add_end_point_name("agent1") agent.start() - yield retry_limited(lambda: len(server.agentmanager.sessions) == 1, 10) + yield retry_limited(lambda: len(agentmanager.sessions) == 1, 10) version = int(time.time()) @@ -1617,12 +1634,14 @@ def test_send_events(resource_container, io_loop, environment, server, client): """ Send and receive events within one agent """ + agentmanager = server.get_endpoint(SLICE_AGENT_MANAGER) + resource_container.Provider.reset() agent = Agent(io_loop, hostname="node1", environment=environment, agent_map={"agent1": "localhost"}, code_loader=False) agent.add_end_point_name("agent1") agent.start() - yield retry_limited(lambda: len(server.agentmanager.sessions) == 1, 10) + yield retry_limited(lambda: len(agentmanager.sessions) == 1, 10) version = int(time.time()) @@ -1680,18 +1699,20 @@ def test_send_events_cross_agent(resource_container, io_loop, environment, serve """ Send and receive events over agents """ + agentmanager = server.get_endpoint(SLICE_AGENT_MANAGER) + resource_container.Provider.reset() agent = Agent(io_loop, hostname="node1", environment=environment, agent_map={"agent1": "localhost"}, code_loader=False) agent.add_end_point_name("agent1") agent.start() - yield retry_limited(lambda: len(server.agentmanager.sessions) == 1, 10) + yield retry_limited(lambda: len(agentmanager.sessions) == 1, 10) agent2 = Agent(io_loop, hostname="node2", environment=environment, agent_map={"agent2": "localhost"}, code_loader=False) agent2.add_end_point_name("agent2") agent2.start() - yield retry_limited(lambda: len(server.agentmanager.sessions) == 2, 10) + yield retry_limited(lambda: len(agentmanager.sessions) == 2, 10) version = int(time.time()) @@ -1753,12 +1774,14 @@ def test_send_events_cross_agent_restart(resource_container, io_loop, environmen """ Send and receive events over agents with agents starting after deploy """ + agentmanager = server.get_endpoint(SLICE_AGENT_MANAGER) + resource_container.Provider.reset() agent2 = Agent(io_loop, hostname="node2", environment=environment, agent_map={"agent2": "localhost"}, code_loader=False) agent2.add_end_point_name("agent2") agent2.start() - yield retry_limited(lambda: len(server.agentmanager.sessions) == 1, 10) + yield retry_limited(lambda: len(agentmanager.sessions) == 1, 10) version = int(time.time()) @@ -1808,7 +1831,7 @@ def test_send_events_cross_agent_restart(resource_container, io_loop, environmen code_loader=False) agent.add_end_point_name("agent1") agent.start() - yield retry_limited(lambda: len(server.agentmanager.sessions) == 2, 10) + yield retry_limited(lambda: len(agentmanager.sessions) == 2, 10) while (result.result["model"]["total"] - result.result["model"]["done"]) > 0: result = yield client.get_version(environment, version) @@ -1834,12 +1857,14 @@ def test_auto_deploy(io_loop, server, client, resource_container, environment): """ dryrun and deploy a configuration model automatically """ + agentmanager = server.get_endpoint(SLICE_AGENT_MANAGER) + resource_container.Provider.reset() agent = Agent(io_loop, hostname="node1", environment=environment, agent_map={"agent1": "localhost"}, code_loader=False) agent.add_end_point_name("agent1") agent.start() - yield retry_limited(lambda: len(server.agentmanager.sessions) == 1, 10) + yield retry_limited(lambda: len(agentmanager.sessions) == 1, 10) resource_container.Provider.set("agent1", "key2", "incorrect_value") resource_container.Provider.set("agent1", "key3", "value") @@ -2178,7 +2203,7 @@ def wait_for_version(cnt): return versions.result - project_dir = os.path.join(server._server_storage["environments"], str(environment)) + project_dir = os.path.join(server.get_endpoint("server")._server_storage["environments"], str(environment)) project_source = os.path.join(os.path.dirname(os.path.abspath(__file__)), "data", "project") shutil.copytree(project_source, project_dir)