diff --git a/CHANGELOG.md b/CHANGELOG.md index 8a9a3fa743..678c6ec401 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -5,6 +5,7 @@ Changes in this release: - Removal of snapshot and restore functionality from the server (#789) - Replace virtualenv by python standard venv - Updated to Tornado 5, moving from tornado ioloop to the standard python async framework (#765) +- Extend mypy type annotations v 2018.3 (2018-12-07) Changes in this release: diff --git a/src/inmanta/module.py b/src/inmanta/module.py index bb629b6e8f..8194eef9d6 100644 --- a/src/inmanta/module.py +++ b/src/inmanta/module.py @@ -21,7 +21,7 @@ from io import BytesIO import logging import os -from os.path import sys +import sys import re from subprocess import CalledProcessError import subprocess @@ -32,13 +32,28 @@ from inmanta import env from inmanta import plugins -from inmanta.ast import Namespace, CompilerException, ModuleNotFoundException, Location, LocatableString +from inmanta.ast import Namespace, CompilerException, ModuleNotFoundException, Location, LocatableString, Range from inmanta.ast.blocks import BasicBlock -from inmanta.ast.statements import DefinitionStatement, BiStatement, Statement +from inmanta.ast.statements import DefinitionStatement, BiStatement, Statement, DynamicStatement from inmanta.ast.statements.define import DefineImport from inmanta.parser import plyInmantaParser -from inmanta.util import memoize, get_compiler_version +from inmanta.util import get_compiler_version from typing import Tuple, List, Dict +from typing import Optional +from typing import Union +from typing import Any +from functools import lru_cache + +try: + from typing import TYPE_CHECKING +except ImportError: + TYPE_CHECKING = False + + +if TYPE_CHECKING: + from typing import Iterable, Set # noqa: F401 + from pkg_resources.packaging.version import Version # noqa: F401 + from pkg_resources import Requirement # noqa: F401 LOGGER = logging.getLogger(__name__) @@ -64,55 +79,55 @@ class ProjectNotFoundExcpetion(Exception): class GitProvider(object): - def clone(self, src, dest): + def clone(self, src: str, dest: str) -> None: pass - def fetch(self, repo): + def fetch(self, repo: str) -> None: pass - def get_all_tags(self, repo): + def get_all_tags(self, repo: str) -> List[str]: pass - def get_file_for_version(self, repo, tag, file): + def get_file_for_version(self, repo: str, tag: str, file: str) -> str: pass - def checkout_tag(self, repo, tag): + def checkout_tag(self, repo: str, tag: str) -> None: pass - def commit(self, repo, message, commit_all, add=[]): + def commit(self, repo: str, message: str, commit_all: bool, add: List[str]=[]) -> None: pass - def tag(self, repo, tag): + def tag(self, repo: str, tag: str) -> None: pass - def push(self, repo): + def push(self, repo: str) -> str: pass class CLIGitProvider(GitProvider): - def clone(self, src, dest): + def clone(self, src: str, dest: str) -> None: env = os.environ.copy() env["GIT_ASKPASS"] = "true" subprocess.check_call(["git", "clone", src, dest], stdout=subprocess.DEVNULL, stderr=subprocess.DEVNULL, env=env) - def fetch(self, repo): + def fetch(self, repo: str) -> None: env = os.environ.copy() env["GIT_ASKPASS"] = "true" subprocess.check_call(["git", "fetch", "--tags"], cwd=repo, stdout=subprocess.DEVNULL, stderr=subprocess.DEVNULL, env=env) - def status(self, repo): + def status(self, repo: str) -> str: return subprocess.check_output(["git", "status", "--porcelain"], cwd=repo).decode("utf-8") - def get_all_tags(self, repo): + def get_all_tags(self, repo: str) -> List[str]: return subprocess.check_output(["git", "tag"], cwd=repo).decode("utf-8").splitlines() - def checkout_tag(self, repo, tag): + def checkout_tag(self, repo: str, tag: str) -> None: subprocess.check_call(["git", "checkout", tag], cwd=repo, stdout=subprocess.DEVNULL, stderr=subprocess.DEVNULL) - def commit(self, repo, message, commit_all, add=[]): + def commit(self, repo: str, message: str, commit_all: bool, add: List[str]=[]) -> None: for file in add: subprocess.check_call(["git", "add", file], cwd=repo, stdout=subprocess.DEVNULL, stderr=subprocess.DEVNULL) if not commit_all: @@ -122,102 +137,25 @@ def commit(self, repo, message, commit_all, add=[]): subprocess.check_call(["git", "commit", "-a", "-m", message], cwd=repo, stdout=subprocess.DEVNULL, stderr=subprocess.DEVNULL) - def tag(self, repo, tag): + def tag(self, repo: str, tag: str) -> None: subprocess.check_call(["git", "tag", "-a", "-m", "auto tag by module tool", tag], cwd=repo, stdout=subprocess.DEVNULL, stderr=subprocess.DEVNULL) - def push(self, repo): + def push(self, repo: str) -> str: return subprocess.check_output(["git", "push", "--follow-tags", "--porcelain"], cwd=repo, stderr=subprocess.DEVNULL).decode("utf-8") - def get_file_for_version(self, repo, tag, file): + def get_file_for_version(self, repo: str, tag: str, file: str) -> str: data = subprocess.check_output(["git", "archive", "--format=tar", tag, file], cwd=repo, stderr=subprocess.DEVNULL) tf = TarFile(fileobj=BytesIO(data)) tfile = tf.next() + assert tfile is not None b = tf.extractfile(tfile) + assert b is not None return b.read().decode("utf-8") -# try: -# import pygit2 -# import re -# -# class LibGitProvider(GitProvider): -# -# def clone(self, src, dest): -# pygit2.clone_repository(src, dest) -# -# def fetch(self, repo): -# repoh = pygit2.Repository(repo) -# repoh.remotes["origin"].fetch() -# -# def status(self, repo): -# # todo -# return subprocess.check_output(["git", "status", "--porcelain"], cwd=repo).decode("utf-8") -# -# def get_all_tags(self, repo): -# repoh = pygit2.Repository(repo) -# regex = re.compile('^refs/tags/(.*)') -# return [m.group(1) for m in [regex.match(t) for t in repoh.listall_references()] if m] -# -# def checkout_tag(self, repo, tag): -# repoh = pygit2.Repository(repo) -# repoh.checkout("refs/tags/" + tag) -# -# def commit(self, repo, message, commit_all, add=[]): -# repoh = pygit2.Repository(repo) -# index = repoh.index -# index.read() -# -# for file in add: -# index.add(os.path.relpath(file, repo)) -# -# if commit_all: -# index.add_all() -# -# index.write() -# tree = index.write_tree() -# -# config = pygit2.Config.get_global_config() -# try: -# email = config["user.email"] -# except KeyError: -# email = "inmanta@example.com" -# LOGGER.warn("user.email not set in git config") -# -# try: -# username = config["user.name"] -# except KeyError: -# username = "Inmanta Moduletool" -# LOGGER.warn("user.name not set in git config") -# -# author = pygit2.Signature(username, email) -# -# return repoh.create_commit("HEAD", author, author, message, tree, [repoh.head.get_object().hex]) -# -# def tag(self, repo, tag): -# repoh = pygit2.Repository(repo) -# -# config = pygit2.Config.get_global_config() -# try: -# email = config["user.email"] -# except KeyError: -# email = "inmanta@example.com" -# LOGGER.warn("user.email not set in git config") -# -# try: -# username = config["user.name"] -# except KeyError: -# username = "Inmanta Moduletool" -# LOGGER.warn("user.name not set in git config") -# -# author = pygit2.Signature(username, email) -# -# repoh.create_tag(tag, repoh.head.target, pygit2.GIT_OBJ_COMMIT, author, "auto tag by module tool") -# -# gitprovider = LibGitProvider() -# except ImportError as e: gitprovider = CLIGitProvider() @@ -226,14 +164,14 @@ class ModuleRepo(object): def clone(self, name: str, dest: str) -> bool: raise NotImplementedError("Abstract method") - def path_for(self, name: str): + def path_for(self, name: str) -> Optional[str]: # same class is used for search parh and remote repos, perhaps not optimal raise NotImplementedError("Abstract method") class CompositeModuleRepo(ModuleRepo): - def __init__(self, children): + def __init__(self, children: List[ModuleRepo]) -> None: self.children = children def clone(self, name: str, dest: str) -> bool: @@ -242,7 +180,7 @@ def clone(self, name: str, dest: str) -> bool: return True return False - def path_for(self, name: str): + def path_for(self, name: str) -> Optional[str]: for child in self.children: result = child.path_for(name) if result is not None: @@ -252,7 +190,7 @@ def path_for(self, name: str): class LocalFileRepo(ModuleRepo): - def __init__(self, root, parent_root=None): + def __init__(self, root: str, parent_root: Optional[str] = None) -> None: if parent_root is None: self.root = os.path.abspath(root) else: @@ -266,7 +204,7 @@ def clone(self, name: str, dest: str) -> bool: LOGGER.debug("could not clone repo", exc_info=True) return False - def path_for(self, name: str): + def path_for(self, name: str) -> Optional[str]: path = os.path.join(self.root, name) if os.path.exists(path): return path @@ -275,7 +213,7 @@ def path_for(self, name: str): class RemoteRepo(ModuleRepo): - def __init__(self, baseurl): + def __init__(self, baseurl: str) -> None: self.baseurl = baseurl def clone(self, name: str, dest: str) -> bool: @@ -290,18 +228,18 @@ def clone(self, name: str, dest: str) -> bool: LOGGER.debug("could not clone repo", exc_info=True) return False - def path_for(self, name: str): + def path_for(self, name: str) -> Optional[str]: raise NotImplementedError("Should only be called on local repos") -def make_repo(path, root=None): +def make_repo(path: str, root: Optional[str] = None) -> Union[LocalFileRepo, RemoteRepo]: if ":" in path: return RemoteRepo(path) else: return LocalFileRepo(path, parent_root=root) -def merge_specs(mainspec, new): +def merge_specs(mainspec: "Dict[str, List[Requirement]]", new: "List[Requirement]") -> None: """Merge two maps str->[T] by concatting their lists.""" for req in new: key = req.project_name @@ -316,21 +254,21 @@ class ModuleLike(object): Commons superclass for projects and modules, which are both versioned by git """ - def __init__(self, path): + def __init__(self, path: str) -> None: """ @param path: root git directory """ self._path = path - self._meta = {} + self._meta = {} # type: Dict[str, Any] - def get_name(self): - raise NotImplemented() + def get_name(self) -> str: + raise NotImplementedError() name = property(get_name) - def _load_file(self, ns, file) -> Tuple[List[Statement], BasicBlock]: + def _load_file(self, ns: Namespace, file: str) -> Tuple[List[Statement], BasicBlock]: ns.location = Location(file, 1) - statements = [] + statements = [] # type: List[Statement] stmts = plyInmantaParser.parse(ns, file) block = BasicBlock(ns) for s in stmts: @@ -342,10 +280,11 @@ def _load_file(self, ns, file) -> Tuple[List[Statement], BasicBlock]: elif isinstance(s, str) or isinstance(s, LocatableString): pass else: + assert isinstance(s, DynamicStatement) block.add(s) return (statements, block) - def requires(self) -> "List[List[Requirement]]": + def requires(self) -> "List[Requirement]": """ Get the requires for this module """ @@ -359,11 +298,11 @@ def requires(self) -> "List[List[Requirement]]": req = [x for x in parse_requirements(spec)] if len(req) > 1: print("Module file for %s has bad line in requirements specification %s" % (self._path, spec)) - req = req[0] - reqs.append(req) + reqe = req[0] + reqs.append(reqe) return reqs - def get_config(self, name, default): + def get_config(self, name: str, default: Any) -> Any: if name not in self._meta: return default else: @@ -383,7 +322,7 @@ class Project(ModuleLike): PROJECT_FILE = "project.yml" _project = None - def __init__(self, path, autostd=True, main_file="main.cf"): + def __init__(self, path: str, autostd: bool = True, main_file: str = "main.cf") -> None: """ Initialize the project, this includes * Loading the project.yaml (into self._meta) @@ -441,7 +380,7 @@ def __init__(self, path, autostd=True, main_file="main.cf"): self.virtualenv = env.VirtualEnv(os.path.join(path, ".env")) self.loaded = False - self.modules = {} + self.modules = {} # type: Dict[str, Module] self.root_ns = Namespace("__root__") @@ -455,7 +394,7 @@ def __init__(self, path, autostd=True, main_file="main.cf"): self._install_mode = mode @classmethod - def get_project_dir(cls, cur_dir): + def get_project_dir(cls, cur_dir: str) -> str: """ Find the project directory where we are working in. Traverse up until we find Project.PROJECT_FILE or reach / """ @@ -471,7 +410,7 @@ def get_project_dir(cls, cur_dir): return cls.get_project_dir(parent_dir) @classmethod - def get(cls, main_file="main.cf"): + def get(cls, main_file: str = "main.cf") -> "Project": """ Get the instance of the project """ @@ -481,7 +420,7 @@ def get(cls, main_file="main.cf"): return cls._project @classmethod - def set(cls, project): + def set(cls, project: "Project") -> None: """ Get the instance of the project """ @@ -489,7 +428,7 @@ def set(cls, project): os.chdir(project._path) plugins.PluginMeta.clear() - def load(self): + def load(self) -> None: if not self.loaded: self.get_complete_ast() self.use_virtual_env() @@ -512,20 +451,21 @@ def load(self): else: self.load_plugins() - @memoize + @lru_cache() def get_ast(self) -> Tuple[List[Statement], BasicBlock]: return self.__load_ast() - @memoize - def get_imports(self): + @lru_cache() + def get_imports(self) -> List[DefineImport]: (statements, _) = self.get_ast() imports = [x for x in statements if isinstance(x, DefineImport)] if self.autostd: - imports.insert(0, DefineImport("std", "std")) + std_locatable = LocatableString("std", Range("internal", 0, 0, 0, 0), 0, self.root_ns) + imports.insert(0, DefineImport(std_locatable, std_locatable)) return imports - @memoize - def get_complete_ast(self): + @lru_cache() + def get_complete_ast(self) -> Tuple[List[Statement], List[BasicBlock]]: # load ast (statements, block) = self.get_ast() blocks = [block] @@ -539,7 +479,7 @@ def get_complete_ast(self): return (statements, blocks) - def __load_ast(self): + def __load_ast(self) -> Tuple[List[Statement], BasicBlock]: main_ns = Namespace("__config__", self.root_ns) return self._load_file(main_ns, os.path.join(self.project_path, self.main_file)) @@ -547,7 +487,7 @@ def get_modules(self) -> Dict[str, "Module"]: self.load() return self.modules - def get_module(self, full_module_name): + def get_module(self, full_module_name: str) -> "Module": parts = full_module_name.split("::") module_name = parts[0] @@ -566,7 +506,7 @@ def load_module_recursive(self, imports: List[DefineImport]) -> List[Tuple[str, # get imports imports = [x for x in self.get_imports()] - done = set() + done = set() # type: Set[str] while len(imports) > 0: imp = imports.pop() ns = imp.name @@ -596,7 +536,7 @@ def load_module_recursive(self, imports: List[DefineImport]) -> List[Tuple[str, return out - def load_module(self, module_name) -> "Module": + def load_module(self, module_name: str) -> "Module": try: path = self.resolver.path_for(module_name) if path is not None: @@ -606,7 +546,10 @@ def load_module(self, module_name) -> "Module": if module_name in reqs: module = Module.install(self, module_name, reqs[module_name], install_mode=self._install_mode) else: - module = Module.install(self, module_name, parse_requirements(module_name), install_mode=self._install_mode) + module = Module.install(self, + module_name, + list(parse_requirements(module_name)), + install_mode=self._install_mode) self.modules[module_name] = module return module except Exception: @@ -638,7 +581,7 @@ def sorted_modules(self) -> list: """ Return a list of all modules, sorted on their name """ - names = self.modules.keys() + names = list(self.modules.keys()) names = sorted(names) mod_list = [] @@ -647,26 +590,26 @@ def sorted_modules(self) -> list: return mod_list - def collect_requirements(self): + def collect_requirements(self) -> "Dict[str, List[Requirement]]": """ Collect the list of all requirements of all modules in the project. """ if not self.loaded: LOGGER.warning("collecting reqs on project that has not been loaded completely") - specs = {} + specs = {} # type: Dict[str, List[Requirement]] merge_specs(specs, self.requires()) for module in self.modules.values(): reqs = module.requires() merge_specs(specs, reqs) return specs - def collect_imported_requirements(self): + def collect_imported_requirements(self) -> "Dict[str, Iterable[Requirement]]": imports = set([x.name.split("::")[0] for x in self.get_complete_ast()[0] if isinstance(x, DefineImport)]) imports.add("std") specs = self.collect_requirements() - def get_spec(name): + def get_spec(name: str) -> "Iterable[Requirement]": if name in specs: return specs[name] return parse_requirements(name) @@ -695,27 +638,27 @@ def verify_requires(self) -> bool: return good - def collect_python_requirements(self): + def collect_python_requirements(self) -> List[str]: """ Collect the list of all python requirements off all modules in this project """ pyreq = [x.strip() for x in [mod.get_python_requirements() for mod in self.modules.values()] if x is not None] - pyreq = '\n'.join(pyreq).split("\n") - pyreq = [x for x in pyreq if len(x.strip()) > 0] - return list(set(pyreq)) + pyreqa = '\n'.join(pyreq).split("\n") + pyreqb = [x for x in pyreqa if len(x.strip()) > 0] + return list(set(pyreqb)) - def get_name(self): + def get_name(self) -> str: return "project.yml" name = property(get_name) - def get_config_file_name(self): + def get_config_file_name(self) -> str: return os.path.join(self._path, "project.yml") - def get_root_namespace(self): + def get_root_namespace(self) -> Namespace: return self.root_ns - def get_freeze(self, mode="==", recursive=False): + def get_freeze(self, mode: str = "==", recursive: bool = False) -> Dict[str, str]: # collect in scope modules if not recursive: modules = {m.name: m for m in (self.get_module(imp.name) for imp in self.get_imports())} @@ -737,7 +680,7 @@ class Module(ModuleLike): MODEL_DIR = "model" requires_fields = ["name", "license", "version"] - def __init__(self, project: Project, path: str, **kwmeta: dict): + def __init__(self, project: Project, path: str, **kwmeta: dict) -> None: """ Create a new configuration module @@ -748,7 +691,7 @@ def __init__(self, project: Project, path: str, **kwmeta: dict): super().__init__(path) self._project = project self._meta = kwmeta - self._plugin_namespaces = [] + self._plugin_namespaces = [] # type: List[str] if not Module.is_valid_module(self._path): raise InvalidModuleException(("Module %s is not a valid inmanta configuration module. Make sure that a " + @@ -757,7 +700,7 @@ def __init__(self, project: Project, path: str, **kwmeta: dict): self.load_module_file() self.is_versioned() - def rewrite_version(self, new_version): + def rewrite_version(self, new_version: str) -> None: new_version = str(new_version) # make sure it is a string! with open(self.get_config_file_name(), "r") as fd: module_def = fd.read() @@ -787,7 +730,7 @@ def rewrite_version(self, new_version): self._meta = new_info - def get_name(self): + def get_name(self) -> str: """ Returns the name of the module (if the meta data is set) """ @@ -820,7 +763,12 @@ def compiler_version(self) -> str: return None @classmethod - def install(cls, project, modulename, requirements, install=True, install_mode=INSTALL_RELEASES): + def install(cls, + project: Project, + modulename: str, + requirements: "List[Requirement]", + install: bool = True, + install_mode: str = INSTALL_RELEASES) -> "Module": """ Install a module, return module object """ @@ -840,7 +788,13 @@ def install(cls, project, modulename, requirements, install=True, install_mode=I return cls.update(project, modulename, requirements, path, False, install_mode=install_mode) @classmethod - def update(cls, project, modulename, requirements, path=None, fetch=True, install_mode=INSTALL_RELEASES): + def update(cls, + project: Project, + modulename: str, + requirements: "List[Requirement]", + path: str = None, + fetch: bool = True, + install_mode: str = INSTALL_RELEASES) -> "Module": """ Update a module, return module object """ @@ -865,10 +819,14 @@ def update(cls, project, modulename, requirements, path=None, fetch=True, instal return Module(project, path) @classmethod - def get_suitable_version_for(cls, modulename, requirements, path, release_only=True): + def get_suitable_version_for(cls, + modulename: str, + requirements: "List[Requirement]", + path: str, + release_only: bool = True) -> "Optional[Version]": versions = gitprovider.get_all_tags(path) - def try_parse(x): + def try_parse(x: str) -> "Version": try: return parse_version(x) except Exception: @@ -891,8 +849,12 @@ def try_parse(x): return versions[0] if len(versions) > 0 else None @classmethod - def __best_for_compiler_version(cls, modulename, versions, path, comp_version): - def get_cv_for(best): + def __best_for_compiler_version(cls, + modulename: str, + versions: "List[Version]", + path: str, + comp_version: "Version") -> "Optional[Version]": + def get_cv_for(best: "Version") -> "Optional[Version]": cfg = gitprovider.get_file_for_version(path, str(best), "module.yml") cfg = yaml.load(cfg) if "compiler_version" not in cfg: @@ -925,7 +887,7 @@ def get_cv_for(best): return None return versions[lo] - def is_versioned(self): + def is_versioned(self) -> bool: """ Check if this module is versioned, and if so the version number in the module file should have a tag. If the version has + the current revision can be a child otherwise the current @@ -938,7 +900,7 @@ def is_versioned(self): return True @classmethod - def is_valid_module(cls, module_path): + def is_valid_module(cls, module_path: str) -> bool: """ Checks if this module is a valid configuration module. A module should contain a module.yml file. @@ -948,7 +910,7 @@ def is_valid_module(cls, module_path): return True - def load_module_file(self): + def load_module_file(self) -> None: """ Load the module definition file """ @@ -971,10 +933,10 @@ def load_module_file(self): LOGGER.warning("The name in the module file (%s) does not match the directory name (%s)" % (self._meta["name"], os.path.basename(self._path))) - def get_config_file_name(self): + def get_config_file_name(self) -> str: return os.path.join(self._path, "module.yml") - def get_module_files(self): + def get_module_files(self) -> List[str]: """ Returns the path of all model files in this module, relative to the module root """ @@ -984,8 +946,8 @@ def get_module_files(self): return files - @memoize - def get_ast(self, name) -> Tuple[List[Statement], BasicBlock]: + @lru_cache() + def get_ast(self, name: str) -> Tuple[List[Statement], BasicBlock]: if name == self.name: file = os.path.join(self._path, Module.MODEL_DIR, "_init.cf") else: @@ -1004,7 +966,7 @@ def get_ast(self, name) -> Tuple[List[Statement], BasicBlock]: except FileNotFoundError: raise InvalidModuleException("could not locate module with name: %s", name) - def get_freeze(self, submodule, recursive=False, mode=">="): + def get_freeze(self, submodule: str, recursive: bool = False, mode: str = ">=") -> Dict[str, str]: imports = [statement.name for statement in self.get_imports(submodule)] out = {} @@ -1023,15 +985,16 @@ def get_freeze(self, submodule, recursive=False, mode=">="): # drop submodules return {x: v for x, v in out.items() if "::" not in x} - @memoize - def get_imports(self, name): - (statements, _) = self.get_ast(name) + @lru_cache() + def get_imports(self, name: str) -> List[DefineImport]: + (statements, block) = self.get_ast(name) imports = [x for x in statements if isinstance(x, DefineImport)] if self._project.autostd: - imports.insert(0, DefineImport("std", "std")) + std_locatable = LocatableString("std", Range("internal", 0, 0, 0, 0), 0, block.namespace) + imports.insert(0, DefineImport(std_locatable, std_locatable)) return imports - def _get_model_files(self, curdir): + def _get_model_files(self, curdir: str) -> List[str]: files = [] init_cf = os.path.join(curdir, "_init.cf") if not os.path.exists(init_cf): @@ -1068,7 +1031,7 @@ def get_all_submodules(self) -> List[str]: return modules - def load_plugins(self): + def load_plugins(self) -> None: """ Load all plug-ins from a configuration module """ @@ -1116,7 +1079,7 @@ def try_parse(x): return versions - def status(self): + def status(self) -> None: """ Run a git status on this module """ @@ -1137,7 +1100,7 @@ def status(self): print("Failed to get status of module") LOGGER.exception("Failed to get status of module %s") - def push(self): + def push(self) -> None: """ Run a git status on this module """ @@ -1151,7 +1114,7 @@ def push(self): print("done") print() - def get_python_requirements(self): + def get_python_requirements(self) -> Optional[str]: """ Install python requirements with pip in a virtual environment """ @@ -1162,15 +1125,15 @@ def get_python_requirements(self): else: return None - @memoize - def get_python_requirements_as_list(self): + @lru_cache() + def get_python_requirements_as_list(self) -> List[str]: raw = self.get_python_requirements() if raw is None: return [] else: return [y for y in [x.strip() for x in raw.split("\n")] if len(y) != 0] - def execute_command(self, cmd): + def execute_command(self, cmd: str) -> None: print("executing %s on %s in %s" % (cmd, self.get_name(), self._path)) print("=" * 10) subprocess.call(cmd, shell=True, cwd=self._path) diff --git a/src/inmanta/plugins.py b/src/inmanta/plugins.py index e8f91823b0..473cf84c4c 100644 --- a/src/inmanta/plugins.py +++ b/src/inmanta/plugins.py @@ -148,7 +148,7 @@ def get_functions(cls): return cls.__functions @classmethod - def clear(cls): + def clear(cls) -> None: cls.__functions = {}