diff --git a/docs/conf.py b/docs/conf.py index 162f075..0e32d37 100644 --- a/docs/conf.py +++ b/docs/conf.py @@ -14,21 +14,22 @@ # import os import sys -sys.path.insert(0, os.path.abspath('..')) + +sys.path.insert(0, os.path.abspath("..")) # -- Project information ----------------------------------------------------- import os -project = 'Memento' -copyright = '2018, AlphaStudio' -author = 'AlphaStudio' +project = "Memento" +copyright = "2018, AlphaStudio" +author = "AlphaStudio" # The full version, including alpha/beta/rc tags root_dir = os.path.dirname(os.path.realpath(__file__)) -release = open(os.path.join(root_dir, '../../version.txt')).read().strip() +release = open(os.path.join(root_dir, "../../version.txt")).read().strip() # The short X.Y version -version = release[0:release.rindex(".")] +version = release[0 : release.rindex(".")] # -- General configuration --------------------------------------------------- @@ -41,24 +42,24 @@ # extensions coming with Sphinx (named 'sphinx.ext.*') or your custom # ones. extensions = [ - 'sphinx.ext.autodoc', - 'sphinx.ext.autosummary', - 'sphinx.ext.doctest', - 'sphinx.ext.coverage', - 'sphinx.ext.viewcode', + "sphinx.ext.autodoc", + "sphinx.ext.autosummary", + "sphinx.ext.doctest", + "sphinx.ext.coverage", + "sphinx.ext.viewcode", ] # Add any paths that contain templates here, relative to this directory. -templates_path = ['_templates'] +templates_path = ["_templates"] # The suffix(es) of source filenames. # You can specify multiple suffix as a list of string: # # source_suffix = ['.rst', '.md'] -source_suffix = '.rst' +source_suffix = ".rst" # The master toctree document. -master_doc = 'index' +master_doc = "index" # The language for content autogenerated by Sphinx. Refer to documentation # for a list of supported languages. @@ -70,10 +71,10 @@ # List of patterns, relative to source directory, that match files and # directories to ignore when looking for source files. # This pattern also affects html_static_path and html_extra_path . -exclude_patterns = ['_build', 'Thumbs.db', '.DS_Store'] +exclude_patterns = ["_build", "Thumbs.db", ".DS_Store"] # The name of the Pygments (syntax highlighting) style to use. -pygments_style = 'sphinx' +pygments_style = "sphinx" # -- Options for HTML output ------------------------------------------------- @@ -81,7 +82,7 @@ # The theme to use for HTML and HTML Help pages. See the documentation for # a list of builtin themes. # -html_theme = 'sphinx_rtd_theme' +html_theme = "sphinx_rtd_theme" # Theme options are theme-specific and customize the look and feel of a theme # further. For a list of options available for each theme, see the @@ -92,7 +93,7 @@ # Add any paths that contain custom static files (such as style sheets) here, # relative to this directory. They are copied after the builtin static files, # so a file named "default.css" will overwrite the builtin "default.css". -html_static_path = ['_static'] +html_static_path = ["_static"] # Custom sidebar templates, must be a dictionary that maps document names # to template names. @@ -105,8 +106,8 @@ # html_sidebars = {} html_context = { - 'css_files': [ - '_static/fix_table_text_wrap.css', # fix table width + "css_files": [ + "_static/fix_table_text_wrap.css", # fix table width ], } @@ -114,7 +115,7 @@ # -- Options for HTMLHelp output --------------------------------------------- # Output file base name for HTML help builder. -htmlhelp_basename = 'Mementodoc' +htmlhelp_basename = "Mementodoc" # -- Options for LaTeX output ------------------------------------------------ @@ -123,15 +124,12 @@ # The paper size ('letterpaper' or 'a4paper'). # # 'papersize': 'letterpaper', - # The font size ('10pt', '11pt' or '12pt'). # # 'pointsize': '10pt', - # Additional stuff for the LaTeX preamble. # # 'preamble': '', - # Latex figure (float) alignment # # 'figure_align': 'htbp', @@ -141,8 +139,7 @@ # (source start file, target name, title, # author, documentclass [howto, manual, or own class]). latex_documents = [ - (master_doc, 'Memento.tex', 'Memento Documentation', - 'AlphaStudio', 'manual'), + (master_doc, "Memento.tex", "Memento Documentation", "AlphaStudio", "manual"), ] @@ -150,10 +147,7 @@ # One entry per manual page. List of tuples # (source start file, name, description, authors, manual section). -man_pages = [ - (master_doc, 'memento', 'Memento Documentation', - [author], 1) -] +man_pages = [(master_doc, "memento", "Memento Documentation", [author], 1)] # -- Options for Texinfo output ---------------------------------------------- @@ -162,16 +156,22 @@ # (source start file, target name, title, author, # dir menu entry, description, category) texinfo_documents = [ - (master_doc, 'Memento', 'Memento Documentation', - author, 'Memento', 'One line description of project.', - 'Miscellaneous'), + ( + master_doc, + "Memento", + "Memento Documentation", + author, + "Memento", + "One line description of project.", + "Miscellaneous", + ), ] # -- Extension configuration ------------------------------------------------- # Include Python objects as they appear in source files -autodoc_member_order = 'bysource' +autodoc_member_order = "bysource" # Default flags used by autodoc directives -autodoc_default_flags = ['members', 'show-inheritance'] +autodoc_default_flags = ["members", "show-inheritance"] autosummary_generate = True diff --git a/tests/conftest.py b/tests/conftest.py index 5fd945e..34a0b7b 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -17,7 +17,9 @@ def pytest_addoption(parser): - parser.addoption("--runslow", action="store_true", default=False, help="run slow tests") + parser.addoption( + "--runslow", action="store_true", default=False, help="run slow tests" + ) def pytest_configure(config): @@ -31,7 +33,8 @@ def pytest_collection_modifyitems(config, items): skip_slow = pytest.mark.skip(reason="need --runslow option to run") skip_non_canonical_version = pytest.mark.skip( - reason=f"need python canonical version {canonical_version} to run (running {current_version})") + reason=f"need python canonical version {canonical_version} to run (running {current_version})" + ) for item in items: if "slow" in item.keywords and not config.getoption("--runslow"): diff --git a/tests/test_call_stack.py b/tests/test_call_stack.py index c4d8a03..4180994 100644 --- a/tests/test_call_stack.py +++ b/tests/test_call_stack.py @@ -63,10 +63,16 @@ def test_push_pop_caller(self): recursive_context = RecursiveContext() recursive_context.update("correlation_id", corr_id) call_stack = CallStack.get() - frame1 = StackFrame(fn_sample_1.fn_reference().with_args(), - Environment.get().default_cluster.runner, recursive_context) - frame2 = StackFrame(fn_sample_1.fn_reference().with_args(), - Environment.get().default_cluster.runner, recursive_context) + frame1 = StackFrame( + fn_sample_1.fn_reference().with_args(), + Environment.get().default_cluster.runner, + recursive_context, + ) + frame2 = StackFrame( + fn_sample_1.fn_reference().with_args(), + Environment.get().default_cluster.runner, + recursive_context, + ) call_stack.push_frame(frame1) assert call_stack.depth() == 1 assert frame1 is call_stack.get_calling_frame() diff --git a/tests/test_code_hash.py b/tests/test_code_hash.py index aa3fb41..6cf43f9 100644 --- a/tests/test_code_hash.py +++ b/tests/test_code_hash.py @@ -23,7 +23,11 @@ from twosigma.memento import MementoFunction from twosigma.memento.exception import UndeclaredDependencyError from twosigma.memento import memento_function, Environment -from twosigma.memento.code_hash import fn_code_hash, list_dotted_names, resolve_to_symbolic_names +from twosigma.memento.code_hash import ( + fn_code_hash, + list_dotted_names, + resolve_to_symbolic_names, +) @memento_function() @@ -79,6 +83,7 @@ def _non_memento_fn_3(): def dep_with_embedded_fn(): def embedded_fn(): return dep_b() + return embedded_fn() @@ -95,6 +100,7 @@ def fn_with_cell_vars(): def inner(): y = x return y + return inner() @@ -133,7 +139,7 @@ def setup_method(self): self.env_before = Environment.get() self.env_dir = tempfile.mkdtemp(prefix="memoizeTest") env_file = "{}/env.json".format(self.env_dir) - with open(env_file, 'w') as f: + with open(env_file, "w") as f: print("""{"name": "test"}""", file=f) Environment.set(env_file) @@ -195,19 +201,26 @@ def test_non_memento_fn(self): global _floating_fn try: - assert {dep_b} == dep_floating_fn.dependencies().transitive_memento_fn_dependencies() + assert { + dep_b + } == dep_floating_fn.dependencies().transitive_memento_fn_dependencies() version_before = dep_floating_fn.version() _floating_fn = _non_memento_fn_2 version_after = dep_floating_fn.version() assert version_before != version_after - assert {dep_a, dep_b} == dep_floating_fn.dependencies().transitive_memento_fn_dependencies() + assert { + dep_a, + dep_b, + } == dep_floating_fn.dependencies().transitive_memento_fn_dependencies() finally: _floating_fn = _non_memento_fn_1 def test_dep_with_embedded_fn(self): - assert {dep_b} == dep_with_embedded_fn.dependencies().transitive_memento_fn_dependencies() + assert { + dep_b + } == dep_with_embedded_fn.dependencies().transitive_memento_fn_dependencies() def test_redefine_memento_fn_as_non_memento_fn(self): """ @@ -254,14 +267,18 @@ def test_fn_with_local_vars(self): Make sure local variables are not included in the function hash """ - assert not any("UndefinedSymbol;x" in r.describe() for r in fn_with_local_vars.hash_rules()) + assert not any( + "UndefinedSymbol;x" in r.describe() for r in fn_with_local_vars.hash_rules() + ) def test_fn_with_cell_vars(self): """ Make sure cell variables are not included in the function hash """ - assert not any("UndefinedSymbol;x" in r.describe() for r in fn_with_cell_vars.hash_rules()) + assert not any( + "UndefinedSymbol;x" in r.describe() for r in fn_with_cell_vars.hash_rules() + ) def test_cluster_lock_prevents_version_update(self): """ @@ -279,7 +296,9 @@ def test_cluster_lock_prevents_version_update(self): assert prev_value == dep_global_var() global_var = prev_value + 1 v2 = dep_floating_fn.version() - assert prev_value == dep_global_var() # Should be memoized from previous call + assert ( + prev_value == dep_global_var() + ) # Should be memoized from previous call assert v1 == v2 finally: global_var = prev_value @@ -303,6 +322,8 @@ def test_safe_to_call_memento_fn_wrappers(self): function. """ - result = fn_calls_wrapped_one_plus_one.dependencies().transitive_memento_fn_dependencies() + result = ( + fn_calls_wrapped_one_plus_one.dependencies().transitive_memento_fn_dependencies() + ) # noinspection PyUnresolvedReferences assert {_wrapped_one_plus_one.__wrapped__} == result diff --git a/tests/test_configuration.py b/tests/test_configuration.py index a4cced4..8b35ede 100644 --- a/tests/test_configuration.py +++ b/tests/test_configuration.py @@ -53,10 +53,7 @@ def test_environment_default_init(self): assert env.base_dir is None def test_environment_set_static(self): - m.Environment.set({ - "name": "test1", - "repos": [] - }) + m.Environment.set({"name": "test1", "repos": []}) env = m.Environment.get() assert "test1" == env.name assert [] == env.repos @@ -102,7 +99,7 @@ def get_sample_config(): "maintainer": "maintainer1", "documentation": "doc1", "clusters": {}, - "modules": ["a.b.c"] + "modules": ["a.b.c"], } @staticmethod @@ -118,10 +115,7 @@ def test_config_repo_set_static(self): self.assert_sample_config(m.ConfigurationRepository(self.get_sample_config())) def test_config_repo_env_static(self): - m.Environment.set({ - "name": "test", - "repos": [self.get_sample_config()] - }) + m.Environment.set({"name": "test", "repos": [self.get_sample_config()]}) self.assert_sample_config(m.Environment.get().repos[0]) def test_config_repo_env_file(self): @@ -131,10 +125,7 @@ def test_config_repo_env_file(self): config_file = "{}/config.json".format(d) with open(config_file, "w") as f: json.dump(self.get_sample_config(), f) - m.Environment.set({ - "name": "test", - "repos": [config_file] - }) + m.Environment.set({"name": "test", "repos": [config_file]}) self.assert_sample_config(m.Environment.get().repos[0]) finally: if d: @@ -146,10 +137,7 @@ def test_config_repo_env_file_relative_path(self): d = tempfile.mkdtemp(prefix="memento_test_configuration") env_file = "{}/env.json".format(d) with open(env_file, "w") as f: - json.dump({ - "name": "test", - "repos": ["subdir/config.json"] - }, f) + json.dump({"name": "test", "repos": ["subdir/config.json"]}, f) os.mkdir("{}/subdir".format(d)) config_file = "{}/subdir/config.json".format(d) with open(config_file, "w") as f: @@ -167,10 +155,7 @@ def get_sample_cluster(): "description": "description1", "maintainer": "maintainer1", "documentation": "doc1", - "storage": { - "type": "null", - "readonly": True - } + "storage": {"type": "null", "readonly": True}, } @staticmethod @@ -191,14 +176,17 @@ def test_function_cluster_env_file(self): cluster_file = "{}/cluster.json".format(d) with open(cluster_file, "w") as f: json.dump(self.get_sample_cluster(), f) - m.Environment.set({ - "name": "test", - "repos": [{ - "name": "config1", - "clusters": {"cluster1": cluster_file} - }] - }) - self.assert_sample_cluster(m.Environment.get().repos[0].clusters["cluster1"]) + m.Environment.set( + { + "name": "test", + "repos": [ + {"name": "config1", "clusters": {"cluster1": cluster_file}} + ], + } + ) + self.assert_sample_cluster( + m.Environment.get().repos[0].clusters["cluster1"] + ) finally: if d: shutil.rmtree(d) @@ -209,54 +197,58 @@ def test_function_cluster_env_file_relative_path(self): d = tempfile.mkdtemp(prefix="memento_test_configuration") config_file = "{}/config.json".format(d) with open(config_file, "w") as f: - json.dump({ - "name": "config1", - "clusters": {"cluster1": "cluster_subdir/cluster.json"} - }, f) + json.dump( + { + "name": "config1", + "clusters": {"cluster1": "cluster_subdir/cluster.json"}, + }, + f, + ) cluster_file = "{}/cluster_subdir/cluster.json".format(d) os.mkdir("{}/cluster_subdir".format(d)) with open(cluster_file, "w") as f: json.dump(self.get_sample_cluster(), f) - m.Environment.set({ - "name": "test", - "repos": [config_file] - }) - self.assert_sample_cluster(m.Environment.get().repos[0].clusters["cluster1"]) + m.Environment.set({"name": "test", "repos": [config_file]}) + self.assert_sample_cluster( + m.Environment.get().repos[0].clusters["cluster1"] + ) finally: if d: shutil.rmtree(d) def test_get_cluster(self): - m.Environment.set({ - "name": "test", - "repos": [ - { - "name": "repo1", - "clusters": { - "A": { - "name": "A", - "description": "1", - "storage": {"type": "null"} - } - } - }, - { - "name": "repo2", - "clusters": { - "A": { - "name": "A", - "description": "2", - "storage": {"type": "null"} + m.Environment.set( + { + "name": "test", + "repos": [ + { + "name": "repo1", + "clusters": { + "A": { + "name": "A", + "description": "1", + "storage": {"type": "null"}, + } }, - "B": { - "name": "B", - "description": "3", - "storage": {"type": "null"} - } - } - } - ] - }) + }, + { + "name": "repo2", + "clusters": { + "A": { + "name": "A", + "description": "2", + "storage": {"type": "null"}, + }, + "B": { + "name": "B", + "description": "3", + "storage": {"type": "null"}, + }, + }, + }, + ], + } + ) env = m.Environment.get() assert "1" == env.get_cluster("A").description assert "3" == env.get_cluster("B").description @@ -270,20 +262,17 @@ def test_default_cluster_storage_is_filesystem(self): assert "filesystem" == cluster.storage.storage_type def test_default_storage_is_filesystem(self): - m.Environment.set({ - "name": "test", - "repos": [ - { - "name": "repo1", - "clusters": { - "A": { - "name": "A", - "description": "1" - } + m.Environment.set( + { + "name": "test", + "repos": [ + { + "name": "repo1", + "clusters": {"A": {"name": "A", "description": "1"}}, } - } - ] - }) + ], + } + ) env = m.Environment.get() cluster = env.get_cluster("A") assert "filesystem" == cluster.storage.storage_type @@ -297,58 +286,49 @@ def test_get_registered_clusters(self): assert "test_cluster" in m.Environment.get_registered_clusters() def test_get_repo(self): - m.Environment.set({ - "name": "test", - "repos": [ - { - "name": "repo1", - "clusters": { - "A": { - "name": "A", - "description": "1" - } + m.Environment.set( + { + "name": "test", + "repos": [ + { + "name": "repo1", + "clusters": {"A": {"name": "A", "description": "1"}}, } - } - ] - }) + ], + } + ) env = m.Environment.get() assert "repo1" == env.get_repo("repo1").name assert env.get_repo("repo2") is None def test_append_repo(self): - m.Environment.set({ - "name": "test", - "repos": [ - { - "name": "repo1", - "clusters": { - "A": { - "name": "A", - "description": "1" - } + m.Environment.set( + { + "name": "test", + "repos": [ + { + "name": "repo1", + "clusters": {"A": {"name": "A", "description": "1"}}, } - } - ] - }) + ], + } + ) env = m.Environment.get() env.append_repo(ConfigurationRepository(name="repo2")) assert "repo2" == env.repos[-1].name def test_prepend_repo(self): - m.Environment.set({ - "name": "test", - "repos": [ - { - "name": "repo1", - "clusters": { - "A": { - "name": "A", - "description": "1" - } + m.Environment.set( + { + "name": "test", + "repos": [ + { + "name": "repo1", + "clusters": {"A": {"name": "A", "description": "1"}}, } - } - ] - }) + ], + } + ) env = m.Environment.get() env.prepend_repo(ConfigurationRepository(name="repo2")) assert "repo2" == env.repos[0].name @@ -371,33 +351,30 @@ def test_is_function_registered(self): assert not Environment.is_function_registered(qn + "x") def test_env_to_dict(self): - m.Environment.set({ - "name": "e_name", - "base_dir": "e_basedir", - "repos": [ - { - "name": "r_name", - "base_dir": "r_basedir", - "description": "r_description", - "maintainer": "r_maintainer", - "clusters": { - "c_name": { - "name": "c_name", - "description": "c_description", - "maintainer": "c_maintainer", - "documentation": "c_doc", - "storage": { - "type": "filesystem", - "path": "/tmp" - }, - "runner": { - "type": "local" + m.Environment.set( + { + "name": "e_name", + "base_dir": "e_basedir", + "repos": [ + { + "name": "r_name", + "base_dir": "r_basedir", + "description": "r_description", + "maintainer": "r_maintainer", + "clusters": { + "c_name": { + "name": "c_name", + "description": "c_description", + "maintainer": "c_maintainer", + "documentation": "c_doc", + "storage": {"type": "filesystem", "path": "/tmp"}, + "runner": {"type": "local"}, } - } + }, } - } - ] - }) + ], + } + ) env = m.Environment.get() env_dict = env.to_dict() env2 = m.Environment(env_dict) @@ -432,6 +409,8 @@ def test_locked_cluster(self): cluster.locked = True try: with pytest.raises(ValueError): - Environment.register_function(None, m.memento_function(not_yet_registered_fn)) + Environment.register_function( + None, m.memento_function(not_yet_registered_fn) + ) finally: cluster.locked = False diff --git a/tests/test_dependency_graph.py b/tests/test_dependency_graph.py index 92617ba..01361f1 100644 --- a/tests/test_dependency_graph.py +++ b/tests/test_dependency_graph.py @@ -71,7 +71,7 @@ def setup_method(self): self.env_before = Environment.get() self.env_dir = tempfile.mkdtemp(prefix="dependencyGraphTest") env_file = "{}/env.json".format(self.env_dir) - with open(env_file, 'w') as f: + with open(env_file, "w") as f: print("""{"name": "test"}""", file=f) Environment.set(env_file) @@ -97,20 +97,30 @@ def test_verbose_graph(self): def test_df(self): deps = dep_a.dependencies(verbose=True) # type: DependencyGraphType - expected_df = pd.DataFrame(data=[ - {"src": dep_a.qualified_name_without_version, - "target": dep_b.qualified_name_without_version, - "type": "MementoFunction"}, - {"src": dep_a.qualified_name_without_version, - "target": "no_such_symbol", - "type": "UndefinedSymbol"}, - {"src": dep_b.qualified_name_without_version, - "target": dep_c.qualified_name_without_version, - "type": "MementoFunction"}, - {"src": dep_c.qualified_name_without_version, - "target": "global_var", - "type": "GlobalVariable"} - ]).sort_values(by="src") + expected_df = pd.DataFrame( + data=[ + { + "src": dep_a.qualified_name_without_version, + "target": dep_b.qualified_name_without_version, + "type": "MementoFunction", + }, + { + "src": dep_a.qualified_name_without_version, + "target": "no_such_symbol", + "type": "UndefinedSymbol", + }, + { + "src": dep_b.qualified_name_without_version, + "target": dep_c.qualified_name_without_version, + "type": "MementoFunction", + }, + { + "src": dep_c.qualified_name_without_version, + "target": "global_var", + "type": "GlobalVariable", + }, + ] + ).sort_values(by="src") actual_df = deps.df().sort_values(by="src") pd.testing.assert_frame_equal(expected_df, actual_df) @@ -125,8 +135,15 @@ def test_cycles_do_not_prevent_transitive_dep(self): dep_df = dep.df() assert any("_test_cycle_2" in name for name in dep_df.src.values) assert any("_test_cycle_2" in name for name in dep_df.target.values) - assert len( - dep_df[dep_df["src"].str.contains("_test_cycle_2") & dep_df["target"].str.contains("_test_cycle_3")]) == 1 + assert ( + len( + dep_df[ + dep_df["src"].str.contains("_test_cycle_2") + & dep_df["target"].str.contains("_test_cycle_3") + ] + ) + == 1 + ) def test_label_filter(self): graph = _fn_test_cycle_1.dependencies(label_filter=lambda x: "FILTERED").graph() @@ -136,17 +153,27 @@ def test_label_filter(self): def test_rules_until_first_memento_fn(self): def validate(expected, deps): for pair in expected: - assert 1 == len([r for r in deps if f"fn_{pair[0]}" in r.parent_symbol and - f"fn_{pair[1]}" in r.symbol]), f"{pair} not found" + assert 1 == len( + [ + r + for r in deps + if f"fn_{pair[0]}" in r.parent_symbol + and f"fn_{pair[1]}" in r.symbol + ] + ), f"{pair} not found" assert len(expected) == len(deps) # noinspection PyTypeChecker - validate((("e", "c"), ("e", "d"), ("e", "G")), - DependencyGraph._rules_until_first_memento_fn(fn_e)) + validate( + (("e", "c"), ("e", "d"), ("e", "G")), + DependencyGraph._rules_until_first_memento_fn(fn_e), + ) # noinspection PyTypeChecker - validate((("c", "b"), ("b", "e"), ("b", "f"), ("b", "a")), - DependencyGraph._rules_until_first_memento_fn(fn_c)) + validate( + (("c", "b"), ("b", "e"), ("b", "f"), ("b", "a")), + DependencyGraph._rules_until_first_memento_fn(fn_c), + ) def test_complex_graph_verbose(self): # noinspection PyTypeChecker @@ -155,35 +182,59 @@ def test_complex_graph_verbose(self): @staticmethod def do_test_complex_graph_verbose(a, c, d, e, f): deps = c.dependencies(verbose=True) # type: DependencyGraph - expected_df = pd.DataFrame(data=[ - {"src": a.qualified_name_without_version, - "target": e.qualified_name_without_version, - "type": "MementoFunction"}, - {"src": "twosigma.memento.runner_test:fn_b", - "target": a.qualified_name_without_version, - "type": "MementoFunction"}, - {"src": "twosigma.memento.runner_test:fn_b", - "target": e.qualified_name_without_version, - "type": "MementoFunction"}, - {"src": "twosigma.memento.runner_test:fn_b", - "target": f.qualified_name_without_version, - "type": "MementoFunction"}, - {"src": c.qualified_name_without_version, - "target": "twosigma.memento.runner_test:fn_b", - "type": "Function"}, - {"src": d.qualified_name_without_version, - "target": a.qualified_name_without_version, - "type": "MementoFunction"}, - {"src": e.qualified_name_without_version, - "target": "fn_G", - "type": "GlobalVariable"}, - {"src": e.qualified_name_without_version, - "target": c.qualified_name_without_version, - "type": "MementoFunction"}, - {"src": e.qualified_name_without_version, - "target": d.qualified_name_without_version, - "type": "MementoFunction"} - ]).sort_values(by=["src", "target"]).reset_index(drop=True) + expected_df = ( + pd.DataFrame( + data=[ + { + "src": a.qualified_name_without_version, + "target": e.qualified_name_without_version, + "type": "MementoFunction", + }, + { + "src": "twosigma.memento.runner_test:fn_b", + "target": a.qualified_name_without_version, + "type": "MementoFunction", + }, + { + "src": "twosigma.memento.runner_test:fn_b", + "target": e.qualified_name_without_version, + "type": "MementoFunction", + }, + { + "src": "twosigma.memento.runner_test:fn_b", + "target": f.qualified_name_without_version, + "type": "MementoFunction", + }, + { + "src": c.qualified_name_without_version, + "target": "twosigma.memento.runner_test:fn_b", + "type": "Function", + }, + { + "src": d.qualified_name_without_version, + "target": a.qualified_name_without_version, + "type": "MementoFunction", + }, + { + "src": e.qualified_name_without_version, + "target": "fn_G", + "type": "GlobalVariable", + }, + { + "src": e.qualified_name_without_version, + "target": c.qualified_name_without_version, + "type": "MementoFunction", + }, + { + "src": e.qualified_name_without_version, + "target": d.qualified_name_without_version, + "type": "MementoFunction", + }, + ] + ) + .sort_values(by=["src", "target"]) + .reset_index(drop=True) + ) actual_df = deps.df().sort_values(by=["src", "target"]).reset_index(drop=True) pd.testing.assert_frame_equal(expected_df, actual_df) @@ -194,28 +245,48 @@ def test_complex_graph(self): @staticmethod def do_test_complex_graph(a, c, d, e, f): deps = c.dependencies() # type: DependencyGraph - expected_df = pd.DataFrame(data=[ - {"src": a.qualified_name_without_version, - "target": e.qualified_name_without_version, - "type": "MementoFunction"}, - {"src": c.qualified_name_without_version, - "target": a.qualified_name_without_version, - "type": "MementoFunction"}, - {"src": c.qualified_name_without_version, - "target": e.qualified_name_without_version, - "type": "MementoFunction"}, - {"src": c.qualified_name_without_version, - "target": f.qualified_name_without_version, - "type": "MementoFunction"}, - {"src": d.qualified_name_without_version, - "target": a.qualified_name_without_version, - "type": "MementoFunction"}, - {"src": e.qualified_name_without_version, - "target": c.qualified_name_without_version, - "type": "MementoFunction"}, - {"src": e.qualified_name_without_version, - "target": d.qualified_name_without_version, - "type": "MementoFunction"}, - ]).sort_values(by=["src", "target"]).reset_index(drop=True) + expected_df = ( + pd.DataFrame( + data=[ + { + "src": a.qualified_name_without_version, + "target": e.qualified_name_without_version, + "type": "MementoFunction", + }, + { + "src": c.qualified_name_without_version, + "target": a.qualified_name_without_version, + "type": "MementoFunction", + }, + { + "src": c.qualified_name_without_version, + "target": e.qualified_name_without_version, + "type": "MementoFunction", + }, + { + "src": c.qualified_name_without_version, + "target": f.qualified_name_without_version, + "type": "MementoFunction", + }, + { + "src": d.qualified_name_without_version, + "target": a.qualified_name_without_version, + "type": "MementoFunction", + }, + { + "src": e.qualified_name_without_version, + "target": c.qualified_name_without_version, + "type": "MementoFunction", + }, + { + "src": e.qualified_name_without_version, + "target": d.qualified_name_without_version, + "type": "MementoFunction", + }, + ] + ) + .sort_values(by=["src", "target"]) + .reset_index(drop=True) + ) actual_df = deps.df().sort_values(by=["src", "target"]).reset_index(drop=True) pd.testing.assert_frame_equal(expected_df, actual_df) diff --git a/tests/test_exception.py b/tests/test_exception.py index 20891ea..e4c664f 100644 --- a/tests/test_exception.py +++ b/tests/test_exception.py @@ -21,7 +21,9 @@ class TestMementoException: """Class to test MementoException.""" def test_attr(self): - e = MementoException("python::builtins:ValueError", "test message", "test stack") + e = MementoException( + "python::builtins:ValueError", "test message", "test stack" + ) assert "python::builtins:ValueError" == e.exception_name assert "test message" == e.message assert "test stack" == e.stack_trace @@ -31,7 +33,9 @@ def test_validate_format_exception_name(self): MementoException("bad_format", "message", "stack") def test_to_exception(self): - e = MementoException("python::builtins:ValueError", "test message", "stack trace") + e = MementoException( + "python::builtins:ValueError", "test message", "stack trace" + ) assert isinstance(e.to_exception(), ValueError) assert -1 != str(e.to_exception()).find("test message") @@ -42,7 +46,8 @@ def test_from_exception(self): assert -1 != me.message.find("test message") def test_exception_from_another_language(self): - e = MementoException("java::java.lang.IllegalArgumentException", "test message", - "stack trace") + e = MementoException( + "java::java.lang.IllegalArgumentException", "test message", "stack trace" + ) assert isinstance(e.to_exception(), MementoException) assert -1 != str(e.to_exception()).find("test message") diff --git a/tests/test_external.py b/tests/test_external.py index 52b2a75..e0b38d0 100644 --- a/tests/test_external.py +++ b/tests/test_external.py @@ -26,12 +26,20 @@ def __init__(self, fn_reference: FunctionReference, context: InvocationContext): super().__init__(fn_reference, context, "test", hash_rules=list()) def clone_with( - self, fn: Callable = None, src_fn: Callable = None, cluster_name: str = None, - version: str = None, calculated_version: str = None, context: InvocationContext = None, - partial_args: Tuple[Any] = None, partial_kwargs: Dict[str, Any] = None, - auto_dependencies: bool = True, - dependencies: List[Union[str, MementoFunctionType]] = None, - version_code_hash: str = None, version_salt: str = None) -> MementoFunctionType: + self, + fn: Callable = None, + src_fn: Callable = None, + cluster_name: str = None, + version: str = None, + calculated_version: str = None, + context: InvocationContext = None, + partial_args: Tuple[Any] = None, + partial_kwargs: Dict[str, Any] = None, + auto_dependencies: bool = True, + dependencies: List[Union[str, MementoFunctionType]] = None, + version_code_hash: str = None, + version_salt: str = None, + ) -> MementoFunctionType: pass diff --git a/tests/test_memento_function.py b/tests/test_memento_function.py index 270a4ad..dc208dd 100644 --- a/tests/test_memento_function.py +++ b/tests/test_memento_function.py @@ -27,8 +27,14 @@ import twosigma.memento as m from pandas.testing import assert_frame_equal -from twosigma.memento import Environment, ConfigurationRepository, FunctionCluster, Memento, \ - FunctionReference, MementoFunction +from twosigma.memento import ( + Environment, + ConfigurationRepository, + FunctionCluster, + Memento, + FunctionReference, + MementoFunction, +) from twosigma.memento.metadata import ResultType from twosigma.memento.partition import InMemoryPartition from twosigma.memento.code_hash import fn_code_hash @@ -126,8 +132,7 @@ def fn_test_memoize_7(a): def fn_test_memoize_df(): global _called _called = True - return pd.DataFrame([{"name": "a", "value": 1}, - {"name": "b", "value": 2}]) + return pd.DataFrame([{"name": "a", "value": 1}, {"name": "b", "value": 2}]) # Turn off auto-dependencies, else global variable _called will introduce version change @@ -413,7 +418,7 @@ def setup_method(self): self.env_before = m.Environment.get() self.env_dir = tempfile.mkdtemp(prefix="memoizeTest") env_file = "{}/env.json".format(self.env_dir) - with open(env_file, 'w') as f: + with open(env_file, "w") as f: print("""{"name": "test"}""", file=f) m.Environment.set(env_file) _called = False @@ -425,21 +430,27 @@ def teardown_method(self): def test_fn_reference(self): ref0a = m.FunctionReference(fn_test_memoize_0a) qual_name_0a = "tests.test_memento_function:fn_test_memoize_0a" - assert qual_name_0a == ref0a.qualified_name[0:ref0a.qualified_name.find("#")] + assert qual_name_0a == ref0a.qualified_name[0 : ref0a.qualified_name.find("#")] assert ref0a.cluster_name is None ref0a = m.FunctionReference.from_qualified_name(qual_name_0a) - assert qual_name_0a == ref0a.qualified_name[0:ref0a.qualified_name.find("#")] + assert qual_name_0a == ref0a.qualified_name[0 : ref0a.qualified_name.find("#")] assert ref0a.cluster_name is None ref0b = m.FunctionReference(fn_test_memoize_0b) qual_name_0b = "tests.test_memento_function:fn_test_memoize_0b" - assert "a::" + qual_name_0b == ref0b.qualified_name[0:ref0b.qualified_name.find("#")] + assert ( + "a::" + qual_name_0b + == ref0b.qualified_name[0 : ref0b.qualified_name.find("#")] + ) assert "a" == ref0b.cluster_name ref0b = m.FunctionReference.from_qualified_name("a::" + qual_name_0b) assert ref0b.memento_fn.version() is not None - assert "a::" + qual_name_0b == ref0b.qualified_name[0:ref0b.qualified_name.find("#")] + assert ( + "a::" + qual_name_0b + == ref0b.qualified_name[0 : ref0b.qualified_name.find("#")] + ) assert "a" == ref0b.cluster_name def test_memoize(self): @@ -579,8 +590,7 @@ def test_dataframe(self): fn_test_memoize_df.forget_all() - expected = pd.DataFrame([{"name": "a", "value": 1}, - {"name": "b", "value": 2}]) + expected = pd.DataFrame([{"name": "a", "value": 1}, {"name": "b", "value": 2}]) assert not _called @@ -627,20 +637,59 @@ def test_return_type_recorded(self): assert 2 == partition.get("b") assert _called - assert ResultType.null == fn_test_memoize_null.memento().invocation_metadata.result_type - assert ResultType.boolean == fn_test_memoize_boolean.memento().invocation_metadata.result_type - assert ResultType.string == fn_test_memoize_str.memento().invocation_metadata.result_type - assert ResultType.binary == fn_test_memoize_bin.memento().invocation_metadata.result_type - assert ResultType.number == fn_test_memoize_int.memento().invocation_metadata.result_type - assert ResultType.number == fn_test_memoize_float.memento().invocation_metadata.result_type - assert ResultType.date == fn_test_memoize_date.memento().invocation_metadata.result_type - assert ResultType.timestamp == fn_test_memoize_datetime.memento().invocation_metadata.result_type - assert ResultType.dictionary == fn_test_memoize_dict.memento().invocation_metadata.result_type + assert ( + ResultType.null + == fn_test_memoize_null.memento().invocation_metadata.result_type + ) + assert ( + ResultType.boolean + == fn_test_memoize_boolean.memento().invocation_metadata.result_type + ) + assert ( + ResultType.string + == fn_test_memoize_str.memento().invocation_metadata.result_type + ) + assert ( + ResultType.binary + == fn_test_memoize_bin.memento().invocation_metadata.result_type + ) + assert ( + ResultType.number + == fn_test_memoize_int.memento().invocation_metadata.result_type + ) + assert ( + ResultType.number + == fn_test_memoize_float.memento().invocation_metadata.result_type + ) + assert ( + ResultType.date + == fn_test_memoize_date.memento().invocation_metadata.result_type + ) + assert ( + ResultType.timestamp + == fn_test_memoize_datetime.memento().invocation_metadata.result_type + ) + assert ( + ResultType.dictionary + == fn_test_memoize_dict.memento().invocation_metadata.result_type + ) # note for future test maintainers: array would also be acceptable - assert ResultType.list_result == fn_test_memoize_list.memento().invocation_metadata.result_type - assert ResultType.series == fn_test_memoize_series.memento().invocation_metadata.result_type - assert ResultType.data_frame == fn_test_memoize_data_frame.memento().invocation_metadata.result_type - assert ResultType.partition == fn_test_memoize_partition.memento().invocation_metadata.result_type + assert ( + ResultType.list_result + == fn_test_memoize_list.memento().invocation_metadata.result_type + ) + assert ( + ResultType.series + == fn_test_memoize_series.memento().invocation_metadata.result_type + ) + assert ( + ResultType.data_frame + == fn_test_memoize_data_frame.memento().invocation_metadata.result_type + ) + assert ( + ResultType.partition + == fn_test_memoize_partition.memento().invocation_metadata.result_type + ) def test_robust_handling_of_io_error_during_memoization(self): global _called @@ -653,9 +702,23 @@ def test_robust_handling_of_io_error_during_memoization(self): try: # Use a path that is illegal on both Linux and Windows m.Environment.set( - Environment(name="bad_env", repos=[ConfigurationRepository(name="bad_repo", clusters={ - "bad_cluster": FunctionCluster(name="bad_cluster", - storage=FilesystemStorageBackend(path="/proc/test<"))})])) + Environment( + name="bad_env", + repos=[ + ConfigurationRepository( + name="bad_repo", + clusters={ + "bad_cluster": FunctionCluster( + name="bad_cluster", + storage=FilesystemStorageBackend( + path="/proc/test<" + ), + ) + }, + ) + ], + ) + ) _called = False assert "memoize me" == fn_test_memoize_string() @@ -679,15 +742,25 @@ def test_validate_args(self): sample_a(datetime.date.today()) sample_a(datetime.datetime.now(datetime.timezone.utc)) sample_a([1, 2, 3]) - multi_list = [None, True, "a", 1, 2.0, datetime.date.today(), - datetime.datetime.now(datetime.timezone.utc), - [1, 2, 3], {"a": "b", "c": "d"}] + multi_list = [ + None, + True, + "a", + 1, + 2.0, + datetime.date.today(), + datetime.datetime.now(datetime.timezone.utc), + [1, 2, 3], + {"a": "b", "c": "d"}, + ] sample_a(multi_list) sample_a({"a": "b", "c": "d"}) sample_a({"a": "b", "c": multi_list}) # Memento cannot serialize a function as a return value fn_test_fn_with_arg_noreturn(fn_test_memoize_0a) - fn_test_fn_with_arg_noreturn([1, 2, fn_test_memoize_0a, {"a": fn_test_memoize_0a}]) + fn_test_fn_with_arg_noreturn( + [1, 2, fn_test_memoize_0a, {"a": fn_test_memoize_0a}] + ) # These should not be accepted class CustomClass: @@ -822,8 +895,10 @@ def test_list_memoized_functions(self): sample_b(2) fns = m.list_memoized_functions() assert 2 == len(fns) - expected_names = {FunctionReference(sample_a).qualified_name, - FunctionReference(sample_b).qualified_name} + expected_names = { + FunctionReference(sample_a).qualified_name, + FunctionReference(sample_b).qualified_name, + } actual_names = set([x.qualified_name for x in fns]) assert expected_names == actual_names @@ -931,7 +1006,9 @@ def test_memento_forget_exceptions_recursively(self): def test_with_args(self): ref_a = fn_double.fn_reference().with_args(3) - assert fn_double.fn_reference().qualified_name == ref_a.fn_reference.qualified_name + assert ( + fn_double.fn_reference().qualified_name == ref_a.fn_reference.qualified_name + ) assert (3,) == ref_a.args assert {} == ref_a.kwargs @@ -950,29 +1027,18 @@ def test_map_over_range(self): def test_call_batch(self): # Make sure call_batch raises an error if kwarg keys are not strings - kwarg_list = [ - {"x": 1, "y": 2}, - {1: 2, "b": 4} - ] + kwarg_list = [{"x": 1, "y": 2}, {1: 2, "b": 4}] with pytest.raises(TypeError): add.call_batch(kwarg_list) # Try a successful call: - kwarg_list = [ - {"x": 1, "y": 2}, - {"x": 3, "y": 4} - ] + kwarg_list = [{"x": 1, "y": 2}, {"x": 3, "y": 4}] results = add.call_batch(kwarg_list) expected = [3, 7] assert expected == results def test_call_batch_raise_first_exception(self): - kwarg_list = [ - {"x": 0}, - {"x": 1}, - {"x": 2}, - {"x": 3} - ] + kwarg_list = [{"x": 0}, {"x": 1}, {"x": 2}, {"x": 3}] with pytest.raises(ValueError): fn_raise_on_odd.call_batch(kwarg_list) @@ -985,21 +1051,21 @@ def test_call_batch_raise_first_exception(self): def test_monitor_progress(self): # Try a successful call: - kwarg_list = [ - {"x": 1, "y": 2}, - {"x": 3, "y": 4} - ] + kwarg_list = [{"x": 1, "y": 2}, {"x": 3, "y": 4}] results = add.monitor_progress().call_batch(kwarg_list) expected = [3, 7] assert expected == results def test_does_not_accept_local_functions(self): try: + @m.memento_function def local_fn(): pass - pytest.fail("Should not have allowed declaration of a local memento function") + pytest.fail( + "Should not have allowed declaration of a local memento function" + ) except ValueError: pass @@ -1048,8 +1114,10 @@ def test_new_version_invalidates_dependent_memoized_data(self): finally: self.redefine_version(fn_current_time, "1") - @unittest.skipUnless((sys.version_info.major, sys.version_info.minor) == (3, 11), - "Code hash test requires Python is 3.11") + @unittest.skipUnless( + (sys.version_info.major, sys.version_info.minor) == (3, 11), + "Code hash test requires Python is 3.11", + ) def test_code_hash(self): """ Test the stability of code hashing @@ -1081,8 +1149,9 @@ def test_version_code_hash_attribute(self): assert f1.version() == f2.version() # Change the dependency of f2 and make sure version does change - f2 = m.memento_function(fn_unbound_2, version_code_hash=f1.code_hash, - dependencies=[sample_a]) + f2 = m.memento_function( + fn_unbound_2, version_code_hash=f1.code_hash, dependencies=[sample_a] + ) assert f1.version() != f2.version() finally: fn_unbound_2.__name__ = orig_fn_unbound_2_name @@ -1110,23 +1179,21 @@ def test_list(self): add(1, 1) df = add.list() - expected = DataFrame(data=[ - {"x": 1, "y": 1, "result_type": "number"} - ]) + expected = DataFrame(data=[{"x": 1, "y": 1, "result_type": "number"}]) assert_frame_equal(expected, df) add(1, 2) df = add.list() - expected = DataFrame(data=[ - {"x": 1, "y": 1, "result_type": "number"}, - {"x": 1, "y": 2, "result_type": "number"} - ]) + expected = DataFrame( + data=[ + {"x": 1, "y": 1, "result_type": "number"}, + {"x": 1, "y": 2, "result_type": "number"}, + ] + ) assert_frame_equal(expected, df) df = add.list(y=2) - expected = DataFrame(data=[ - {"x": 1, "y": 2, "result_type": "number"} - ]) + expected = DataFrame(data=[{"x": 1, "y": 2, "result_type": "number"}]) assert_frame_equal(expected, df) assert add.list(y=3) is None diff --git a/tests/test_metadata.py b/tests/test_metadata.py index 93172b2..6d7b59c 100644 --- a/tests/test_metadata.py +++ b/tests/test_metadata.py @@ -28,7 +28,8 @@ class TestMetadata: def test_result_type_from_object(self): assert ResultType.exception == ResultType.from_object( - MementoException("python::builtins:ValueError", "message", "stack_trace")) + MementoException("python::builtins:ValueError", "message", "stack_trace") + ) assert ResultType.null == ResultType.from_object(None) assert ResultType.boolean == ResultType.from_object(True) assert ResultType.string == ResultType.from_object("foo") @@ -37,22 +38,41 @@ def test_result_type_from_object(self): assert ResultType.number == ResultType.from_object(1.2) with pytest.raises(ValueError): ResultType.from_object((3 + 4j)) - assert ResultType.timestamp == ResultType.from_object(datetime.datetime.now(datetime.timezone.utc)) + assert ResultType.timestamp == ResultType.from_object( + datetime.datetime.now(datetime.timezone.utc) + ) assert ResultType.date == ResultType.from_object(datetime.date.today()) assert ResultType.list_result == ResultType.from_object([1, 2]) assert ResultType.dictionary == ResultType.from_object({"a": "b"}) assert ResultType.index == ResultType.from_object(pd.Index([1, 2])) assert ResultType.series == ResultType.from_object(pd.Series([1, 2])) - assert ResultType.data_frame == ResultType.from_object(pd.DataFrame({"a": [1, 2]})) + assert ResultType.data_frame == ResultType.from_object( + pd.DataFrame({"a": [1, 2]}) + ) assert ResultType.array_boolean == ResultType.from_object( - np.array([True, False], dtype=np.dtype("?"))) - assert ResultType.array_int8 == ResultType.from_object(np.array([1, 2], dtype=np.dtype("i1"))) - assert ResultType.array_int16 == ResultType.from_object(np.array([1, 2], dtype=np.dtype("i2"))) - assert ResultType.array_int32 == ResultType.from_object(np.array([1, 2], dtype=np.dtype("i4"))) - assert ResultType.array_int64 == ResultType.from_object(np.array([1, 2], dtype=np.dtype("i8"))) - assert ResultType.array_float32 == ResultType.from_object(np.array([1, 2], dtype=np.dtype("f4"))) - assert ResultType.array_float64 == ResultType.from_object(np.array([1, 2], dtype=np.dtype("f8"))) - assert ResultType.partition == ResultType.from_object(InMemoryPartition({"a": 1})) + np.array([True, False], dtype=np.dtype("?")) + ) + assert ResultType.array_int8 == ResultType.from_object( + np.array([1, 2], dtype=np.dtype("i1")) + ) + assert ResultType.array_int16 == ResultType.from_object( + np.array([1, 2], dtype=np.dtype("i2")) + ) + assert ResultType.array_int32 == ResultType.from_object( + np.array([1, 2], dtype=np.dtype("i4")) + ) + assert ResultType.array_int64 == ResultType.from_object( + np.array([1, 2], dtype=np.dtype("i8")) + ) + assert ResultType.array_float32 == ResultType.from_object( + np.array([1, 2], dtype=np.dtype("f4")) + ) + assert ResultType.array_float64 == ResultType.from_object( + np.array([1, 2], dtype=np.dtype("f8")) + ) + assert ResultType.partition == ResultType.from_object( + InMemoryPartition({"a": 1}) + ) def test_result_type_from_annotation(self): assert ResultType.exception == ResultType.from_annotation(MementoException) @@ -76,4 +96,6 @@ def test_result_type_from_annotation(self): assert ResultType.data_frame == ResultType.from_annotation(pd.DataFrame) # Not sure how to represent np.ndarray type hints assert ResultType.partition == ResultType.from_annotation(InMemoryPartition) - assert ResultType.memento_function == ResultType.from_annotation(MementoFunction) + assert ResultType.memento_function == ResultType.from_annotation( + MementoFunction + ) diff --git a/tests/test_reference.py b/tests/test_reference.py index cc69387..7d63a7d 100644 --- a/tests/test_reference.py +++ b/tests/test_reference.py @@ -24,8 +24,11 @@ from twosigma.memento import memento_function, Environment, MementoFunction from twosigma.memento.exception import DependencyNotFoundError from twosigma.memento.external import UnboundExternalMementoFunction -from twosigma.memento.reference import FunctionReference, FunctionReferenceWithArguments, \ - ArgumentHasher +from twosigma.memento.reference import ( + FunctionReference, + FunctionReferenceWithArguments, + ArgumentHasher, +) _called = False @@ -58,6 +61,7 @@ def _test_dep_depends_on_nonexistent_auto() -> int: dependency auto-detection logic, but does not exist as a function. """ + def internal_fn(): return 42 @@ -116,7 +120,7 @@ def setup_method(self): self.env_before = Environment.get() self.env_dir = tempfile.mkdtemp(prefix="memoizeTest") env_file = "{}/env.json".format(self.env_dir) - with open(env_file, 'w') as f: + with open(env_file, "w") as f: print("""{"name": "test"}""", file=f) Environment.set(env_file) _called = False @@ -126,7 +130,9 @@ def teardown_method(self): Environment.set(self.env_before) def test_parse_qualified_name(self): - parts = FunctionReference.parse_qualified_name("cluster::module.name:fn_name#hash") + parts = FunctionReference.parse_qualified_name( + "cluster::module.name:fn_name#hash" + ) assert "cluster" == parts["cluster"] assert "module.name" == parts["module"] assert "fn_name" == parts["function"] @@ -134,27 +140,48 @@ def test_parse_qualified_name(self): def test_reference(self): ref = FunctionReference.from_qualified_name("tests.test_reference:_test_method") - assert "tests.test_reference:_test_method" == ref.qualified_name[0:ref.qualified_name.find("#")] - assert "tests.test_reference:_test_method" == \ - ref.qualified_name_without_cluster[0:ref.qualified_name_without_cluster.find("#")] + assert ( + "tests.test_reference:_test_method" + == ref.qualified_name[0 : ref.qualified_name.find("#")] + ) + assert ( + "tests.test_reference:_test_method" + == ref.qualified_name_without_cluster[ + 0 : ref.qualified_name_without_cluster.find("#") + ] + ) assert "tests.test_reference" == ref.module assert "_test_method" == ref.function_name assert ref.cluster_name is None assert ref.memento_fn.version() is not None ref = FunctionReference(_test_method) - assert "tests.test_reference:_test_method" == ref.qualified_name[0:ref.qualified_name.find("#")] - assert "tests.test_reference:_test_method" == \ - ref.qualified_name_without_cluster[0:ref.qualified_name_without_cluster.find("#")] + assert ( + "tests.test_reference:_test_method" + == ref.qualified_name[0 : ref.qualified_name.find("#")] + ) + assert ( + "tests.test_reference:_test_method" + == ref.qualified_name_without_cluster[ + 0 : ref.qualified_name_without_cluster.find("#") + ] + ) assert "tests.test_reference" == ref.module assert "_test_method" == ref.function_name assert ref.cluster_name is None assert ref.memento_fn.version() is not None ref = FunctionReference(_test_method, cluster_name="cluster1") - assert "cluster1::tests.test_reference:_test_method" == ref.qualified_name[0:ref.qualified_name.find("#")] - assert "tests.test_reference:_test_method" == \ - ref.qualified_name_without_cluster[0:ref.qualified_name_without_cluster.find("#")] + assert ( + "cluster1::tests.test_reference:_test_method" + == ref.qualified_name[0 : ref.qualified_name.find("#")] + ) + assert ( + "tests.test_reference:_test_method" + == ref.qualified_name_without_cluster[ + 0 : ref.qualified_name_without_cluster.find("#") + ] + ) assert "tests.test_reference" == ref.module assert "_test_method" == ref.function_name assert "cluster1" == ref.cluster_name @@ -163,7 +190,10 @@ def test_reference(self): def test_version(self): ref = _test_method_2.fn_reference() assert "tests.test_reference:_test_method_2#2" == ref.qualified_name - assert "tests.test_reference:_test_method_2#2" == ref.qualified_name_without_cluster + assert ( + "tests.test_reference:_test_method_2#2" + == ref.qualified_name_without_cluster + ) assert "tests.test_reference" == ref.module assert "_test_method_2" == ref.function_name assert "2" == ref.memento_fn.version() @@ -171,15 +201,23 @@ def test_version(self): ref = FunctionReference(_test_method_2, cluster_name="cluster1", version="2") assert "cluster1::tests.test_reference:_test_method_2#2" == ref.qualified_name - assert "tests.test_reference:_test_method_2#2" == ref.qualified_name_without_cluster + assert ( + "tests.test_reference:_test_method_2#2" + == ref.qualified_name_without_cluster + ) assert "tests.test_reference" == ref.module assert "_test_method_2" == ref.function_name assert "2" == ref.memento_fn.version() assert ref.memento_fn.code_hash is None - ref = FunctionReference.from_qualified_name("tests.test_reference:_test_method_2#2") + ref = FunctionReference.from_qualified_name( + "tests.test_reference:_test_method_2#2" + ) assert "tests.test_reference:_test_method_2#2" == ref.qualified_name - assert "tests.test_reference:_test_method_2#2" == ref.qualified_name_without_cluster + assert ( + "tests.test_reference:_test_method_2#2" + == ref.qualified_name_without_cluster + ) assert "tests.test_reference" == ref.module assert "_test_method_2" == ref.function_name assert ref.cluster_name is None @@ -191,7 +229,8 @@ def test_version(self): # If the version does not match, check that reference is treated as external assert FunctionReference.from_qualified_name( "unknown_cluster::test_reference:_test_method_2#3", - parameter_names=["x", "y"]).external + parameter_names=["x", "y"], + ).external def test_find(self): ref = FunctionReference(_test_method, cluster_name="cluster1") @@ -204,16 +243,18 @@ def test_find(self): def test_reference_with_args(self): ref1 = FunctionReference(_test_method_2, cluster_name="cluster1") - ref1a = FunctionReferenceWithArguments(fn_reference=ref1, args=(1,), kwargs={"y": 2}, - context_args={"z": 3}) + ref1a = FunctionReferenceWithArguments( + fn_reference=ref1, args=(1,), kwargs={"y": 2}, context_args={"z": 3} + ) assert ref1.qualified_name == ref1a.fn_reference.qualified_name assert (1,) == ref1a.args assert {"y": 2} == ref1a.kwargs assert {"z": 3} == ref1a.context_args - ref1b = FunctionReferenceWithArguments(fn_reference=ref1, args=(), kwargs={"x": 1, "y": 2}, - context_args={"z": 3}) + ref1b = FunctionReferenceWithArguments( + fn_reference=ref1, args=(), kwargs={"x": 1, "y": 2}, context_args={"z": 3} + ) assert ref1.qualified_name == ref1b.fn_reference.qualified_name assert () == ref1b.args @@ -224,27 +265,41 @@ def test_reference_with_args(self): def test_context_args_affect_hash(self): ref1 = FunctionReference(_test_method_2, cluster_name="cluster1") - ref1a = FunctionReferenceWithArguments(fn_reference=ref1, args=(1,), kwargs={"y": 2}) - ref1b = FunctionReferenceWithArguments(fn_reference=ref1, args=(1,), kwargs={"y": 2}, - context_args={"z": 3}) + ref1a = FunctionReferenceWithArguments( + fn_reference=ref1, args=(1,), kwargs={"y": 2} + ) + ref1b = FunctionReferenceWithArguments( + fn_reference=ref1, args=(1,), kwargs={"y": 2}, context_args={"z": 3} + ) assert ref1a.arg_hash != ref1b.arg_hash def test_compute_args(self): ref1 = FunctionReference(_test_method_2, cluster_name="cluster1") - hash1 = FunctionReferenceWithArguments(fn_reference=ref1, args=(1, 2), kwargs={}).arg_hash - hash2 = FunctionReferenceWithArguments(fn_reference=ref1, args=(2, 3), kwargs={}).arg_hash + hash1 = FunctionReferenceWithArguments( + fn_reference=ref1, args=(1, 2), kwargs={} + ).arg_hash + hash2 = FunctionReferenceWithArguments( + fn_reference=ref1, args=(2, 3), kwargs={} + ).arg_hash assert hash1 != hash2 - hash3 = FunctionReferenceWithArguments(fn_reference=ref1, args=(2, 3), kwargs={}).arg_hash + hash3 = FunctionReferenceWithArguments( + fn_reference=ref1, args=(2, 3), kwargs={} + ).arg_hash assert hash2 == hash3 hash4 = FunctionReferenceWithArguments( - fn_reference=_test_method_2.partial(2).fn_reference(), args=(3,), kwargs={}).arg_hash + fn_reference=_test_method_2.partial(2).fn_reference(), args=(3,), kwargs={} + ).arg_hash assert hash3 == hash4 # Test that args are mapped to kwargs - hash5 = FunctionReferenceWithArguments(fn_reference=ref1, args=(2,), kwargs={"y": 3}).arg_hash + hash5 = FunctionReferenceWithArguments( + fn_reference=ref1, args=(2,), kwargs={"y": 3} + ).arg_hash assert hash2 == hash5 hash6 = FunctionReferenceWithArguments( fn_reference=_test_method_2.partial(x=2).fn_reference(), - args=(), kwargs={"y": 3}).arg_hash + args=(), + kwargs={"y": 3}, + ).arg_hash assert hash2 == hash6 def test_timestamp_arg(self): @@ -253,7 +308,10 @@ def test_timestamp_arg(self): t2 = t1.to_pydatetime() assert fn1(t1) == fn1(t2) - assert fn1.fn_reference().with_args(t1).arg_hash == fn1.fn_reference().with_args(t2).arg_hash + assert ( + fn1.fn_reference().with_args(t1).arg_hash + == fn1.fn_reference().with_args(t2).arg_hash + ) def test_arg_hasher_normalize(self): assert ArgumentHasher.normalize(None) is None @@ -262,13 +320,20 @@ def test_arg_hasher_normalize(self): assert 42 == ArgumentHasher.normalize(42) f = cast(float, ArgumentHasher.normalize(123.45)) assert pytest.approx(123.45) == f - assert datetime.date(2019, 4, 3) == ArgumentHasher.normalize(datetime.date(2019, 4, 3)) - assert datetime.datetime(2019, 4, 3, 12, 34, 56) == \ - ArgumentHasher.normalize(datetime.datetime(2019, 4, 3, 12, 34, 56)) - assert datetime.datetime(2019, 4, 3, 12, 34, 56, 500000) == \ - ArgumentHasher.normalize(datetime.datetime(2019, 4, 3, 12, 34, 56, 500000)) - assert datetime.datetime(2019, 4, 3, 12, 34, 56, tzinfo=pytz.UTC) == \ - ArgumentHasher.normalize(datetime.datetime(2019, 4, 3, 12, 34, 56, tzinfo=pytz.UTC)) + assert datetime.date(2019, 4, 3) == ArgumentHasher.normalize( + datetime.date(2019, 4, 3) + ) + assert datetime.datetime(2019, 4, 3, 12, 34, 56) == ArgumentHasher.normalize( + datetime.datetime(2019, 4, 3, 12, 34, 56) + ) + assert datetime.datetime( + 2019, 4, 3, 12, 34, 56, 500000 + ) == ArgumentHasher.normalize(datetime.datetime(2019, 4, 3, 12, 34, 56, 500000)) + assert datetime.datetime( + 2019, 4, 3, 12, 34, 56, tzinfo=pytz.UTC + ) == ArgumentHasher.normalize( + datetime.datetime(2019, 4, 3, 12, 34, 56, tzinfo=pytz.UTC) + ) normalized_list = cast(list, ArgumentHasher.normalize([1, 2, 3])) assert [1, 2, 3] == normalized_list in_dict = {"a": 1, "b": 2, "c": [1, 2, 3], "d": {"e": "f"}} @@ -288,38 +353,56 @@ def test_arg_hasher_encode(self): assert 42 == ArgumentHasher._encode(42) f = cast(float, ArgumentHasher._encode(123.45)) assert pytest.approx(123.45) == f - assert {"_mementoType": "date", "iso8601": "2019-04-03"} == ArgumentHasher._encode(datetime.date(2019, 4, 3)) - assert {"_mementoType": "datetime", "iso8601": "2019-04-03T12:34:56"} == \ - ArgumentHasher._encode(datetime.datetime(2019, 4, 3, 12, 34, 56)) - assert {"_mementoType": "datetime", "iso8601": "2019-04-03T12:34:56.500000"} == \ - ArgumentHasher._encode(datetime.datetime(2019, 4, 3, 12, 34, 56, 500000)) - assert {"_mementoType": "datetime", "iso8601": "2019-04-03T12:34:56+00:00"} == \ - ArgumentHasher._encode(datetime.datetime(2019, 4, 3, 12, 34, 56, tzinfo=pytz.UTC)) + assert { + "_mementoType": "date", + "iso8601": "2019-04-03", + } == ArgumentHasher._encode(datetime.date(2019, 4, 3)) + assert { + "_mementoType": "datetime", + "iso8601": "2019-04-03T12:34:56", + } == ArgumentHasher._encode(datetime.datetime(2019, 4, 3, 12, 34, 56)) + assert { + "_mementoType": "datetime", + "iso8601": "2019-04-03T12:34:56.500000", + } == ArgumentHasher._encode(datetime.datetime(2019, 4, 3, 12, 34, 56, 500000)) + assert { + "_mementoType": "datetime", + "iso8601": "2019-04-03T12:34:56+00:00", + } == ArgumentHasher._encode( + datetime.datetime(2019, 4, 3, 12, 34, 56, tzinfo=pytz.UTC) + ) encoded_list = cast(list, ArgumentHasher._encode([1, 2, 3])) assert [1, 2, 3] == encoded_list in_dict = {"a": 1, "b": 2, "c": [1, 2, 3], "d": {"e": "f"}} d = cast(dict, ArgumentHasher._encode(in_dict)) assert in_dict == d - assert {"_mementoType": "FunctionReference", - "parameterNames": ["a"], - "qualifiedName": fn1.fn_reference().qualified_name, - "partialArgs": None, - "partialKwargs": {}} == ArgumentHasher._encode(fn1) - assert {"_mementoType": "FunctionReference", - "parameterNames": ["a"], - "qualifiedName": fn1.partial(a=7).fn_reference().qualified_name, - "partialArgs": None, - "partialKwargs": {"a": 7}} == ArgumentHasher._encode(fn1.partial(a=7)) + assert { + "_mementoType": "FunctionReference", + "parameterNames": ["a"], + "qualifiedName": fn1.fn_reference().qualified_name, + "partialArgs": None, + "partialKwargs": {}, + } == ArgumentHasher._encode(fn1) + assert { + "_mementoType": "FunctionReference", + "parameterNames": ["a"], + "qualifiedName": fn1.partial(a=7).fn_reference().qualified_name, + "partialArgs": None, + "partialKwargs": {"a": 7}, + } == ArgumentHasher._encode(fn1.partial(a=7)) def test_arg_hasher_normalized_json(self): assert "null" == ArgumentHasher._normalized_json(None) - assert "\"abc123\"" == ArgumentHasher._normalized_json("abc123") + assert '"abc123"' == ArgumentHasher._normalized_json("abc123") assert "42" == ArgumentHasher._normalized_json(42) assert "123.45" == ArgumentHasher._normalized_json(123.45) assert "true" == ArgumentHasher._normalized_json(True) assert "[1,2,3]" == ArgumentHasher._normalized_json([1, 2, 3]) in_dict = {"c": [1, 2, 3], "a": 1, "d": {"e": "f"}, "b": 2} - assert '{"a":1,"b":2,"c":[1,2,3],"d":{"e":"f"}}' == ArgumentHasher._normalized_json(in_dict) + assert ( + '{"a":1,"b":2,"c":[1,2,3],"d":{"e":"f"}}' + == ArgumentHasher._normalized_json(in_dict) + ) # noinspection SpellCheckingInspection def test_arg_hasher_stability(self): @@ -327,20 +410,35 @@ def test_arg_hasher_stability(self): Ensure arg hasher is stable from release to release. """ - assert "44136fa355b3678a1146ad16f7e8649e94fb4fc21fe77e8310c060f61caaff8a" == ArgumentHasher.compute_hash({}) - assert "4cc66ba3de661a1a9319c150542555d384af8d81724a3b64dab2001d85df06df" == \ - ArgumentHasher.compute_hash({"a": 42}) - assert "d091f9c83c091f79652fe8786375b3fe4ce0861a56f5bfbafedbe431877ff0e8" == \ - ArgumentHasher.compute_hash({"a": None}) - assert "f7f851f4ba8ef23c0a3f2c20548bcc4bac24c46bc1c2c9332f7be4a695f22275" == \ - ArgumentHasher.compute_hash({"a": 123.45}) - assert "70621113b1eb7b8fbec0b1cb896e5f6adb32a9dbe08b5032c5edef18fca6002c" == \ - ArgumentHasher.compute_hash({"a": "abc123"}) - assert "730bc329ebcd24c6c9663ca4bb0e199a090dbf9d9d1058651d8560236abb1095" == \ - ArgumentHasher.compute_hash({"a": [1, 2, 3]}) + assert ( + "44136fa355b3678a1146ad16f7e8649e94fb4fc21fe77e8310c060f61caaff8a" + == ArgumentHasher.compute_hash({}) + ) + assert ( + "4cc66ba3de661a1a9319c150542555d384af8d81724a3b64dab2001d85df06df" + == ArgumentHasher.compute_hash({"a": 42}) + ) + assert ( + "d091f9c83c091f79652fe8786375b3fe4ce0861a56f5bfbafedbe431877ff0e8" + == ArgumentHasher.compute_hash({"a": None}) + ) + assert ( + "f7f851f4ba8ef23c0a3f2c20548bcc4bac24c46bc1c2c9332f7be4a695f22275" + == ArgumentHasher.compute_hash({"a": 123.45}) + ) + assert ( + "70621113b1eb7b8fbec0b1cb896e5f6adb32a9dbe08b5032c5edef18fca6002c" + == ArgumentHasher.compute_hash({"a": "abc123"}) + ) + assert ( + "730bc329ebcd24c6c9663ca4bb0e199a090dbf9d9d1058651d8560236abb1095" + == ArgumentHasher.compute_hash({"a": [1, 2, 3]}) + ) in_dict = {"c": [1, 2, 3], "a": 1, "d": {"e": "f"}, "b": 2} - assert "52ea3bee36356ba2a31ff7931c95d69aee13f8f3d24727aa4fd9d456885ea00f" == \ - ArgumentHasher.compute_hash({"a": in_dict}) + assert ( + "52ea3bee36356ba2a31ff7931c95d69aee13f8f3d24727aa4fd9d456885ea00f" + == ArgumentHasher.compute_hash({"a": in_dict}) + ) def test_pickling(self): assert fn1 == loads(dumps(fn1.fn_reference())).memento_fn @@ -355,14 +453,18 @@ def test_required_dependencies_fail_if_not_present(self): orig_test_dep_nonexistent = _test_dep_nonexistent try: - _test_dep_depends_on_nonexistent.required_dependencies.add("_test_dep_nonexistent") + _test_dep_depends_on_nonexistent.required_dependencies.add( + "_test_dep_nonexistent" + ) del globals()["_test_dep_nonexistent"] MementoFunction.increment_global_fn_generation() with pytest.raises(DependencyNotFoundError): _test_dep_depends_on_nonexistent() finally: globals()["_test_dep_nonexistent"] = orig_test_dep_nonexistent - _test_dep_depends_on_nonexistent.required_dependencies.remove("_test_dep_nonexistent") + _test_dep_depends_on_nonexistent.required_dependencies.remove( + "_test_dep_nonexistent" + ) def test_detected_dependencies_do_not_fail_if_not_present(self): """Test that detected dependencies do not cause evaluation to fail if not found.""" @@ -410,10 +512,18 @@ def test_circular_references_work(self): assert 42 == _test_dep_circular_1() assert 42 == _test_dep_circular_2() - assert {_test_dep_circular_2} == _test_dep_circular_1.dependencies().transitive_memento_fn_dependencies() - assert {_test_dep_circular_2} == _test_dep_circular_1.dependencies().direct_memento_fn_dependencies() - assert {_test_dep_circular_1} == _test_dep_circular_2.dependencies().transitive_memento_fn_dependencies() - assert {_test_dep_circular_1} == _test_dep_circular_2.dependencies().transitive_memento_fn_dependencies() + assert { + _test_dep_circular_2 + } == _test_dep_circular_1.dependencies().transitive_memento_fn_dependencies() + assert { + _test_dep_circular_2 + } == _test_dep_circular_1.dependencies().direct_memento_fn_dependencies() + assert { + _test_dep_circular_1 + } == _test_dep_circular_2.dependencies().transitive_memento_fn_dependencies() + assert { + _test_dep_circular_1 + } == _test_dep_circular_2.dependencies().transitive_memento_fn_dependencies() # Test the version numbers are stable if there are circular dependencies v1a = _test_dep_circular_1.version() @@ -456,12 +566,17 @@ def _static_method_test(): def test_static_method_dependencies(self): assert 42 == self._static_method_test() - assert {_test_dep_a, _test_dep_c} == self._static_method_test.dependencies().\ - transitive_memento_fn_dependencies() + assert { + _test_dep_a, + _test_dep_c, + } == self._static_method_test.dependencies().transitive_memento_fn_dependencies() def test_qualified_name_without_version(self): ref1 = FunctionReference(_test_method_2, cluster_name="cluster1") - assert "cluster1::tests.test_reference:_test_method_2" == ref1.qualified_name_without_version + assert ( + "cluster1::tests.test_reference:_test_method_2" + == ref1.qualified_name_without_version + ) def test_parameter_names(self): fn_ref = _test_method_2.fn_reference() @@ -484,19 +599,20 @@ def test_auto_external_reference(self): # Unknown module ref1 = FunctionReference.from_qualified_name( - "unknown_cluster::unknown.module:fn1#" + v, parameter_names=["x", "y"]) + "unknown_cluster::unknown.module:fn1#" + v, parameter_names=["x", "y"] + ) assert ref1.external # Unknown function ref2 = FunctionReference.from_qualified_name( - "unknown_cluster::tests.test_reference:fn1a#" + v, - parameter_names=["a"]) + "unknown_cluster::tests.test_reference:fn1a#" + v, parameter_names=["a"] + ) assert ref2.external # Known function print(fn1.__module__) print(fn1.fn_reference()) ref3 = FunctionReference.from_qualified_name( - "unknown_cluster::tests.test_reference:fn1#" + v, - parameter_names=["a"]) + "unknown_cluster::tests.test_reference:fn1#" + v, parameter_names=["a"] + ) assert not ref3.external, str(fn1.fn_reference()) + " " + str(fn1.__module__) diff --git a/tests/test_resource.py b/tests/test_resource.py index aaa444e..4d38c5f 100644 --- a/tests/test_resource.py +++ b/tests/test_resource.py @@ -45,7 +45,7 @@ def setup_method(self): self.env_before = m.Environment.get() self.env_dir = tempfile.mkdtemp(prefix="resourceTest") env_file = "{}/env.json".format(self.env_dir) - with open(env_file, 'w') as f: + with open(env_file, "w") as f: print("""{"name": "test"}""", file=f) m.Environment.set(env_file) _called = False diff --git a/tests/test_runner_backend.py b/tests/test_runner_backend.py index 9c27886..31f54b6 100644 --- a/tests/test_runner_backend.py +++ b/tests/test_runner_backend.py @@ -23,10 +23,19 @@ from unittest import TestCase # noqa: F401 import twosigma.memento as m # noqa: F401 -from twosigma.memento.runner_test import runner_fn_test_1, runner_fn_test_apply_and_double, \ - runner_fn_test_add, runner_fn_test_sum_double_batch, fn_calls_undeclared_dependency, \ - fn_with_explicit_version_calls_undeclared_dependency, fn_returns_key_override_result, \ - runner_fn_calls_runner_fn_f, fn_recursive_a, fn_recursive_b, fn_recursive_c +from twosigma.memento.runner_test import ( + runner_fn_test_1, + runner_fn_test_apply_and_double, + runner_fn_test_add, + runner_fn_test_sum_double_batch, + fn_calls_undeclared_dependency, + fn_with_explicit_version_calls_undeclared_dependency, + fn_returns_key_override_result, + runner_fn_calls_runner_fn_f, + fn_recursive_a, + fn_recursive_b, + fn_recursive_c, +) class RunnerBackendTester(ABC): @@ -37,7 +46,7 @@ class RunnerBackendTester(ABC): """ - backend = None # type: m.RunnerBackend + backend = None # type: m.RunnerBackend def setup_method(self): pass @@ -48,8 +57,16 @@ def teardown_method(self): @staticmethod def test_memoize(): # This also tests serializing function references - assert {"a": 1, "b": 2, "c": [{"three": 3}], "d": None, "e": True, "f": True} == \ - runner_fn_test_1(1, c=[{"three": 3}], b=2, e=runner_fn_test_1, f=[{"a": runner_fn_test_1}]) + assert { + "a": 1, + "b": 2, + "c": [{"three": 3}], + "d": None, + "e": True, + "f": True, + } == runner_fn_test_1( + 1, c=[{"three": 3}], b=2, e=runner_fn_test_1, f=[{"a": runner_fn_test_1}] + ) @staticmethod def test_call_stack_invocations_tracked(): @@ -61,9 +78,14 @@ def test_call_stack_invocations_tracked(): runner_fn_test_apply_and_double(add_2_and_double, 2) memento = runner_fn_test_apply_and_double.memento(add_2_and_double, 2) - invocations = memento.invocation_metadata.invocations # type: List[FunctionReferenceWithArguments] + invocations = ( + memento.invocation_metadata.invocations + ) # type: List[FunctionReferenceWithArguments] assert 1 == len(invocations) - assert "runner_fn_test_apply_and_double" == invocations[0].fn_reference.function_name + assert ( + "runner_fn_test_apply_and_double" + == invocations[0].fn_reference.function_name + ) # Try with a batch run - make sure all invocations from the batch are recorded in invocations # runner_fn_test_sum_double_batch --> runner_fn_test_apply_and_double --> runner_fn_test_add @@ -71,8 +93,14 @@ def test_call_stack_invocations_tracked(): memento = runner_fn_test_sum_double_batch.memento(add_2_and_double, 10, 12) invocations = memento.invocation_metadata.invocations assert 2 == len(invocations) - assert "runner_fn_test_apply_and_double" == invocations[0].fn_reference.function_name - assert "runner_fn_test_apply_and_double" == invocations[1].fn_reference.function_name + assert ( + "runner_fn_test_apply_and_double" + == invocations[0].fn_reference.function_name + ) + assert ( + "runner_fn_test_apply_and_double" + == invocations[1].fn_reference.function_name + ) @staticmethod def test_correlation_id(): @@ -144,11 +172,17 @@ def test_memento_dependencies(): assert mem_a is not None assert 2 == len(mem_a.invocation_metadata.invocations) for i in mem_a.invocation_metadata.invocations: - assert fn_recursive_b.fn_reference().qualified_name == i.fn_reference.qualified_name + assert ( + fn_recursive_b.fn_reference().qualified_name + == i.fn_reference.qualified_name + ) mem_b = i.fn_reference.memento_fn.memento(x=i.kwargs["x"]) # type: Memento i_b = mem_b.invocation_metadata assert 1 == len(i_b.invocations) - assert fn_recursive_c.fn_reference().qualified_name == i_b.invocations[0].fn_reference.qualified_name + assert ( + fn_recursive_c.fn_reference().qualified_name + == i_b.invocations[0].fn_reference.qualified_name + ) @staticmethod def test_memento_graph(): diff --git a/tests/test_runner_local.py b/tests/test_runner_local.py index bb22cd1..4e34726 100644 --- a/tests/test_runner_local.py +++ b/tests/test_runner_local.py @@ -21,11 +21,21 @@ import twosigma.memento as m -from twosigma.memento import RunnerBackend, Environment, ConfigurationRepository, FunctionCluster # noqa: F401 +from twosigma.memento import ( + RunnerBackend, + Environment, + ConfigurationRepository, + FunctionCluster, +) # noqa: F401 from twosigma.memento.context import InvocationContext # noqa: F401 from twosigma.memento.runner_local import LocalRunnerBackend -from twosigma.memento.runner_test import set_runner_fn_test_1_called, get_runner_fn_test_1_called, runner_fn_test_1,\ - fn_A, fn_B +from twosigma.memento.runner_test import ( + set_runner_fn_test_1_called, + get_runner_fn_test_1_called, + runner_fn_test_1, + fn_A, + fn_B, +) from twosigma.memento.storage_filesystem import FilesystemStorageBackend from tests.test_runner_backend import RunnerBackendTester @@ -58,28 +68,36 @@ class TestRunnerLocal(RunnerBackendTester): """ - backend = None # type: RunnerBackend + backend = None # type: RunnerBackend def setup_method(self): super().setup_method() self.original_env = m.Environment.get() self.base_path = tempfile.mkdtemp(prefix="memento_runner_local_test") self.data_path = "{}/data".format(self.base_path) - m.Environment.set(Environment(name="test1", base_dir=self.base_path, repos=[ - ConfigurationRepository( - name="repo1", - clusters={ - "cluster1": FunctionCluster( - name="cluster1", - storage=FilesystemStorageBackend(path=self.data_path), - runner=LocalRunnerBackend()), - "memento.unit_test": FunctionCluster( - name="memento.unit_test", - storage=FilesystemStorageBackend(path=self.data_path), - runner=LocalRunnerBackend()) - } + m.Environment.set( + Environment( + name="test1", + base_dir=self.base_path, + repos=[ + ConfigurationRepository( + name="repo1", + clusters={ + "cluster1": FunctionCluster( + name="cluster1", + storage=FilesystemStorageBackend(path=self.data_path), + runner=LocalRunnerBackend(), + ), + "memento.unit_test": FunctionCluster( + name="memento.unit_test", + storage=FilesystemStorageBackend(path=self.data_path), + runner=LocalRunnerBackend(), + ), + }, + ) + ], ) - ])) + ) self.cluster = m.Environment.get().get_cluster("cluster1") self.backend = self.cluster.runner diff --git a/tests/test_runner_null.py b/tests/test_runner_null.py index 51a7d7a..3bcec05 100644 --- a/tests/test_runner_null.py +++ b/tests/test_runner_null.py @@ -18,7 +18,12 @@ import pytest import twosigma.memento as m -from twosigma.memento import RunnerBackend, Environment, ConfigurationRepository, FunctionCluster # noqa: F401 +from twosigma.memento import ( + RunnerBackend, + Environment, + ConfigurationRepository, + FunctionCluster, +) # noqa: F401 from twosigma.memento.runner_null import NullRunnerBackend from twosigma.memento.storage_null import NullStorageBackend @@ -36,22 +41,30 @@ class TestRunnerNull: """ - backend = None # type: RunnerBackend + backend = None # type: RunnerBackend def setup_method(self): self.original_env = m.Environment.get() self.base_path = tempfile.mkdtemp(prefix="memento_runner_null_test") self.data_path = "{}/data".format(self.base_path) - m.Environment.set(Environment(name="test1", base_dir=self.base_path, repos=[ - ConfigurationRepository( - name="repo1", - clusters={ - "cluster1": FunctionCluster(name="cluster1", - storage=NullStorageBackend(), - runner=NullRunnerBackend()) - } + m.Environment.set( + Environment( + name="test1", + base_dir=self.base_path, + repos=[ + ConfigurationRepository( + name="repo1", + clusters={ + "cluster1": FunctionCluster( + name="cluster1", + storage=NullStorageBackend(), + runner=NullRunnerBackend(), + ) + }, + ) + ], ) - ])) + ) self.cluster = m.Environment.get().get_cluster("cluster1") self.backend = self.cluster.runner diff --git a/tests/test_serialization.py b/tests/test_serialization.py index f61c296..5cd8901 100644 --- a/tests/test_serialization.py +++ b/tests/test_serialization.py @@ -16,8 +16,14 @@ import tempfile import twosigma.memento as m -from twosigma.memento import Environment, ConfigurationRepository, FunctionCluster, \ - memento_function, file_resource, Memento # noqa: F401 +from twosigma.memento import ( + Environment, + ConfigurationRepository, + FunctionCluster, + memento_function, + file_resource, + Memento, +) # noqa: F401 from twosigma.memento.context import RecursiveContext from twosigma.memento.reference import FunctionReferenceWithArguments from twosigma.memento.runner_null import NullRunnerBackend @@ -34,16 +40,24 @@ class TestSerialization: def setup_method(self): self.original_env = m.Environment.get() self.base_path = tempfile.mkdtemp(prefix="memento_serialization_test") - m.Environment.set(Environment(name="test1", base_dir=self.base_path, repos=[ - ConfigurationRepository( - name="repo1", - clusters={ - "cluster1": FunctionCluster(name="cluster1", - storage=NullStorageBackend(), - runner=NullRunnerBackend()) - } + m.Environment.set( + Environment( + name="test1", + base_dir=self.base_path, + repos=[ + ConfigurationRepository( + name="repo1", + clusters={ + "cluster1": FunctionCluster( + name="cluster1", + storage=NullStorageBackend(), + runner=NullRunnerBackend(), + ) + }, + ) + ], ) - ])) + ) def teardown_method(self): shutil.rmtree(self.base_path) @@ -64,26 +78,33 @@ def test_invocation_metadata(self): fn_test(1) obj = fn_test.memento(1).invocation_metadata result = MementoCodec.decode_invocation_metadata( - MementoCodec.encode_invocation_metadata(obj)) + MementoCodec.encode_invocation_metadata(obj) + ) assert repr(obj) == repr(result) def test_fn_reference_with_args(self): obj = FunctionReferenceWithArguments( - fn_reference=fn_test.fn_reference(), args=(1,), kwargs={}) + fn_reference=fn_test.fn_reference(), args=(1,), kwargs={} + ) result = MementoCodec.decode_fn_reference_with_args( - MementoCodec.encode_fn_reference_with_args(obj)) + MementoCodec.encode_fn_reference_with_args(obj) + ) assert repr(obj) == repr(result) def test_fn_reference_with_arg_hash(self): - obj = FunctionReferenceWithArguments(fn_reference=fn_test.fn_reference(), args=(1,), - kwargs={}).fn_reference_with_arg_hash() + obj = FunctionReferenceWithArguments( + fn_reference=fn_test.fn_reference(), args=(1,), kwargs={} + ).fn_reference_with_arg_hash() result = MementoCodec.decode_fn_reference_with_arg_hash( - MementoCodec.encode_fn_reference_with_arg_hash(obj)) + MementoCodec.encode_fn_reference_with_arg_hash(obj) + ) assert repr(obj) == repr(result) def test_resource_handle(self): obj = file_resource("/dev/null") - result = MementoCodec.decode_resource_handle(MementoCodec.encode_resource_handle(obj)) + result = MementoCodec.decode_resource_handle( + MementoCodec.encode_resource_handle(obj) + ) assert repr(obj) == repr(result) def test_fn_reference(self): @@ -93,7 +114,9 @@ def test_fn_reference(self): def test_recursive_context(self): obj = RecursiveContext(correlation_id="123", retry_on_remote_call=True) - result = MementoCodec.decode_recursive_context(MementoCodec.encode_recursive_context(obj)) + result = MementoCodec.decode_recursive_context( + MementoCodec.encode_recursive_context(obj) + ) assert obj.__dict__ == result.__dict__ def test_arg(self): diff --git a/tests/test_storage_backend.py b/tests/test_storage_backend.py index b60b364..db6de31 100644 --- a/tests/test_storage_backend.py +++ b/tests/test_storage_backend.py @@ -23,6 +23,7 @@ import pandas as pd import pytest from pandas.testing import assert_series_equal + # noinspection PyUnresolvedReferences from pandas import Timestamp import numpy as np @@ -33,8 +34,13 @@ from twosigma.memento.metadata import ResultType, InvocationMetadata, Memento from twosigma.memento.partition import InMemoryPartition from twosigma.memento.reference import FunctionReferenceWithArguments -from twosigma.memento.storage_base import DataSource, MetadataSource, DataSourceKey, MemoryCache, \ - ResultIsWithData +from twosigma.memento.storage_base import ( + DataSource, + MetadataSource, + DataSourceKey, + MemoryCache, + ResultIsWithData, +) now_date = Timestamp.today(None).date() now_time = Timestamp.today(None) @@ -99,10 +105,12 @@ def fn_return_time(): @m.memento_function(cluster="cluster1") def fn_return_dictionary(): - return {"a": 2, - "b": {"c": "d"}, - "e": {"f": [1.0, 2.1, 3.2]}, - "g": pd.Series([1, 2, 3])} + return { + "a": 2, + "b": {"c": "d"}, + "e": {"f": [1.0, 2.1, 3.2]}, + "g": pd.Series([1, 2, 3]), + } @m.memento_function(cluster="cluster1") @@ -122,38 +130,39 @@ def fn_return_numpy_array_bool(): @m.memento_function(cluster="cluster1") def fn_return_numpy_array_int8(): - return np.array([1, 2, 3], dtype='int8') + return np.array([1, 2, 3], dtype="int8") @m.memento_function(cluster="cluster1") def fn_return_numpy_array_int16(): - return np.array([1, 2, 3], dtype='int16') + return np.array([1, 2, 3], dtype="int16") @m.memento_function(cluster="cluster1") def fn_return_numpy_array_int32(): - return np.array([1, 2, 3], dtype='int32') + return np.array([1, 2, 3], dtype="int32") @m.memento_function(cluster="cluster1") def fn_return_numpy_array_int64(): - return np.array([1, 2, 3], dtype='int64') + return np.array([1, 2, 3], dtype="int64") @m.memento_function(cluster="cluster1") def fn_return_numpy_array_float32(): - return np.array([1.0, 2.0, 3.0], dtype='float32') + return np.array([1.0, 2.0, 3.0], dtype="float32") @m.memento_function(cluster="cluster1") def fn_return_numpy_array_float64(): - return np.array([1.0, 2.0, 3.0], dtype='float64') + return np.array([1.0, 2.0, 3.0], dtype="float64") @m.memento_function(cluster="cluster1") def fn_return_index(): - return pd.DatetimeIndex([Timestamp.now() - datetime.timedelta(days=1), - Timestamp.now()]) + return pd.DatetimeIndex( + [Timestamp.now() - datetime.timedelta(days=1), Timestamp.now()] + ) @m.memento_function(cluster="cluster1") @@ -169,9 +178,13 @@ def fn_return_series_with_index(): @m.memento_function(cluster="cluster1") def fn_return_series_with_multiindex(): - idx = pd.DatetimeIndex([Timestamp.now() - datetime.timedelta(days=2), - Timestamp.now() - datetime.timedelta(days=1), - Timestamp.now()]) + idx = pd.DatetimeIndex( + [ + Timestamp.now() - datetime.timedelta(days=2), + Timestamp.now() - datetime.timedelta(days=1), + Timestamp.now(), + ] + ) idx2 = pd.Index([11, 22, 33]) idx3 = pd.date_range("2020-01-01", periods=3, freq="D") return pd.Series([1, 2, 3], index=[idx, idx2, idx3], name="foo") @@ -179,28 +192,42 @@ def fn_return_series_with_multiindex(): @m.memento_function(cluster="cluster1") def fn_return_dataframe(): - return pd.DataFrame([ - {"name": "a", "val": "1"}, - {"name": "b", "val": "2"}, - {"name": "c", "val": "3"} - ]) + return pd.DataFrame( + [ + {"name": "a", "val": "1"}, + {"name": "b", "val": "2"}, + {"name": "c", "val": "3"}, + ] + ) @m.memento_function(cluster="cluster1") def fn_return_dataframe_with_index(): - return pd.DataFrame([ - {"name": "a", "val": "1"}, - {"name": "b", "val": "2"}, - {"name": "c", "val": "3"} - ], index=pd.Index(['x', 'y', 'z'])) + return pd.DataFrame( + [ + {"name": "a", "val": "1"}, + {"name": "b", "val": "2"}, + {"name": "c", "val": "3"}, + ], + index=pd.Index(["x", "y", "z"]), + ) @m.memento_function(cluster="cluster1") def fn_return_partition(): - return InMemoryPartition({"a": None, "b": True, "c": 1, "d": 2.0, "e": [1, 2, 3], - "f": {"a": "b"}, "g": np.array([1, 2, 3], dtype='int8'), - "h": pd.DataFrame([{"name": "a", "val": "1"}]), - "i": InMemoryPartition({"j": "k"})}) + return InMemoryPartition( + { + "a": None, + "b": True, + "c": 1, + "d": 2.0, + "e": [1, 2, 3], + "f": {"a": "b"}, + "g": np.array([1, 2, 3], dtype="int8"), + "h": pd.DataFrame([{"name": "a", "val": "1"}]), + "i": InMemoryPartition({"j": "k"}), + } + ) @m.memento_function(cluster="cluster1") @@ -248,8 +275,8 @@ class StorageBackendTester(ABC): """ - backend = None # type: StorageBackend - test = None # type: TestCase + backend = None # type: StorageBackend + test = None # type: TestCase def setup_method(self): pass @@ -258,7 +285,9 @@ def teardown_method(self): pass @staticmethod - def get_dummy_memento(fn_reference_with_args: FunctionReferenceWithArguments) -> Memento: + def get_dummy_memento( + fn_reference_with_args: FunctionReferenceWithArguments, + ) -> Memento: now = datetime.datetime.now(datetime.timezone.utc) return Memento( time=now, @@ -267,47 +296,67 @@ def get_dummy_memento(fn_reference_with_args: FunctionReferenceWithArguments) -> fn_reference_with_args=fn_reference_with_args, result_type=ResultType.string, invocations=[], - resources=[] + resources=[], ), function_dependencies={fn_reference_with_args.fn_reference}, runner={}, correlation_id="abc123", - content_key=VersionedDataSourceKey("def456", "0") + content_key=VersionedDataSourceKey("def456", "0"), ) def test_list_functions(self): assert 0 == len(self.backend.list_functions()) fn1(1) retry_until(lambda: self.backend.list_functions(), lambda x: len(x) == 1) - assert {FunctionReference(fn1).qualified_name} == \ - set([x.qualified_name for x in self.backend.list_functions()]) + assert {FunctionReference(fn1).qualified_name} == set( + [x.qualified_name for x in self.backend.list_functions()] + ) fn1(1) fn1(2) fn_return_none_1() retry_until(lambda: self.backend.list_functions(), lambda x: len(x) == 2) - assert {FunctionReference(fn1).qualified_name, - FunctionReference(fn_return_none_1).qualified_name} == \ - set([x.qualified_name for x in self.backend.list_functions()]) + assert { + FunctionReference(fn1).qualified_name, + FunctionReference(fn_return_none_1).qualified_name, + } == set([x.qualified_name for x in self.backend.list_functions()]) def test_list_mementos(self): assert 0 == len(self.backend.list_mementos(FunctionReference(fn1))) assert 0 == len(self.backend.list_mementos(FunctionReference(fn2))) fn1(1) - retry_until(lambda: self.backend.list_mementos(FunctionReference(fn1)), - lambda x: len(x) == 1) - assert {1} == set([x.invocation_metadata.fn_reference_with_args.args[0] - for x in self.backend.list_mementos(FunctionReference(fn1))]) + retry_until( + lambda: self.backend.list_mementos(FunctionReference(fn1)), + lambda x: len(x) == 1, + ) + assert {1} == set( + [ + x.invocation_metadata.fn_reference_with_args.args[0] + for x in self.backend.list_mementos(FunctionReference(fn1)) + ] + ) fn1(1) fn1(2) - retry_until(lambda: self.backend.list_mementos(FunctionReference(fn1)), - lambda x: len(x) == 2) - assert {1, 2} == set([x.invocation_metadata.fn_reference_with_args.args[0] - for x in self.backend.list_mementos(FunctionReference(fn1))]) + retry_until( + lambda: self.backend.list_mementos(FunctionReference(fn1)), + lambda x: len(x) == 2, + ) + assert {1, 2} == set( + [ + x.invocation_metadata.fn_reference_with_args.args[0] + for x in self.backend.list_mementos(FunctionReference(fn1)) + ] + ) fn2(3) - retry_until(lambda: self.backend.list_mementos(FunctionReference(fn2)), - lambda x: len(x) == 1) - assert {3} == set([x.invocation_metadata.fn_reference_with_args.args[0] - for x in self.backend.list_mementos(FunctionReference(fn2))]) + retry_until( + lambda: self.backend.list_mementos(FunctionReference(fn2)), + lambda x: len(x) == 1, + ) + assert {3} == set( + [ + x.invocation_metadata.fn_reference_with_args.args[0] + for x in self.backend.list_mementos(FunctionReference(fn2)) + ] + ) def test_memoize(self): fn1_reference = fn_return_none_1.fn_reference().with_args() @@ -321,12 +370,15 @@ def test_memoize(self): fn_reference_with_args=fn1_reference, result_type=ResultType.string, invocations=[fn2_reference], - resources=[] + resources=[], ), - function_dependencies={fn1_reference.fn_reference, fn2_reference.fn_reference}, + function_dependencies={ + fn1_reference.fn_reference, + fn2_reference.fn_reference, + }, runner={}, correlation_id="abc123", - content_key=VersionedDataSourceKey("def456", "0") + content_key=VersionedDataSourceKey("def456", "0"), ) result = None self.backend.memoize(None, memento, result) @@ -334,29 +386,55 @@ def test_memoize(self): # Validate the content hash was changed during the call to memoize assert VersionedDataSourceKey("def456", "0") != memento.content_key - actual_memento = self.backend.get_memento(fn1_reference.fn_reference_with_arg_hash()) + actual_memento = self.backend.get_memento( + fn1_reference.fn_reference_with_arg_hash() + ) actual_result = self.backend.read_result(actual_memento) assert memento.time == actual_memento.time, "now={}".format(now) - assert memento.invocation_metadata.runtime == actual_memento.invocation_metadata.runtime - assert memento.invocation_metadata.fn_reference_with_args.fn_reference.qualified_name == \ - actual_memento.invocation_metadata.fn_reference_with_args.fn_reference.qualified_name - assert memento.invocation_metadata.fn_reference_with_args.args == \ - actual_memento.invocation_metadata.fn_reference_with_args.args - assert memento.invocation_metadata.fn_reference_with_args.kwargs == \ - actual_memento.invocation_metadata.fn_reference_with_args.kwargs - assert memento.invocation_metadata.fn_reference_with_args.arg_hash == \ - actual_memento.invocation_metadata.fn_reference_with_args.arg_hash - assert memento.invocation_metadata.result_type == actual_memento.invocation_metadata.result_type - assert memento.invocation_metadata.invocations[0].fn_reference.qualified_name == \ - actual_memento.invocation_metadata.invocations[0].fn_reference.qualified_name - assert memento.invocation_metadata.invocations[0].arg_hash == \ - actual_memento.invocation_metadata.invocations[0].arg_hash + assert ( + memento.invocation_metadata.runtime + == actual_memento.invocation_metadata.runtime + ) + assert ( + memento.invocation_metadata.fn_reference_with_args.fn_reference.qualified_name + == actual_memento.invocation_metadata.fn_reference_with_args.fn_reference.qualified_name + ) + assert ( + memento.invocation_metadata.fn_reference_with_args.args + == actual_memento.invocation_metadata.fn_reference_with_args.args + ) + assert ( + memento.invocation_metadata.fn_reference_with_args.kwargs + == actual_memento.invocation_metadata.fn_reference_with_args.kwargs + ) + assert ( + memento.invocation_metadata.fn_reference_with_args.arg_hash + == actual_memento.invocation_metadata.fn_reference_with_args.arg_hash + ) + assert ( + memento.invocation_metadata.result_type + == actual_memento.invocation_metadata.result_type + ) + assert ( + memento.invocation_metadata.invocations[0].fn_reference.qualified_name + == actual_memento.invocation_metadata.invocations[ + 0 + ].fn_reference.qualified_name + ) + assert ( + memento.invocation_metadata.invocations[0].arg_hash + == actual_memento.invocation_metadata.invocations[0].arg_hash + ) assert result == actual_result - assert self.backend.is_memoized(fn1_reference.fn_reference, fn1_reference.arg_hash) + assert self.backend.is_memoized( + fn1_reference.fn_reference, fn1_reference.arg_hash + ) self.backend.forget_call(fn1_reference.fn_reference_with_arg_hash()) - assert not self.backend.is_memoized(fn1_reference.fn_reference, fn1_reference.arg_hash) + assert not self.backend.is_memoized( + fn1_reference.fn_reference, fn1_reference.arg_hash + ) @staticmethod def test_memoize_exception(): @@ -566,15 +644,20 @@ def test_memoize_partition(): assert 2.0 == second.get("d") assert [1, 2, 3] == second.get("e") assert {"a": "b"} == second.get("f") - assert np.array_equal(np.array([1, 2, 3], dtype='int8'), cast(np.array, second.get("g"))) + assert np.array_equal( + np.array([1, 2, 3], dtype="int8"), cast(np.array, second.get("g")) + ) # noinspection PyUnresolvedReferences assert first.get("h").equals(second.get("h")) # noinspection PyUnresolvedReferences assert "k" == second.get("i").get("j") def test_memoize_partition_with_parent(self): - memory_cache = self.backend._memory_cache if hasattr( - self.backend, "_memory_cache") else None # type: Optional[MemoryCache] + memory_cache = ( + self.backend._memory_cache + if hasattr(self.backend, "_memory_cache") + else None + ) # type: Optional[MemoryCache] # Ensure a is memoized to disk assert fn_return_partition_with_parent_a.memento() is None @@ -631,7 +714,7 @@ def test_make_url_for_result(self): url = self.backend.make_url_for_result(memento) assert url is not None # The URL should probably have the content key in it somewhere (though this is not technically a requirement) - ck = memento.content_key.key[memento.content_key.key.rfind("/")+1:] + ck = memento.content_key.key[memento.content_key.key.rfind("/") + 1 :] assert url.find(ck) != -1, url def test_forget_call(self): @@ -691,7 +774,9 @@ def test_readonly(self): with pytest.raises(ValueError): self.backend.write_metadata( memento.invocation_metadata.fn_reference_with_args.fn_reference_with_arg_hash(), - "key1", "value1".encode("utf-8")) + "key1", + "value1".encode("utf-8"), + ) # Make the cluster writable and then ensure memoization occurs cluster1.storage.read_only = False @@ -718,17 +803,26 @@ def test_metadata_rw(): storage = m.Environment.get().get_cluster("cluster1").storage storage.write_metadata( memento.invocation_metadata.fn_reference_with_args.fn_reference_with_arg_hash(), - "key1", "value1".encode("utf-8")) + "key1", + "value1".encode("utf-8"), + ) # noinspection PyUnresolvedReferences assert "value1" == storage.read_metadata( memento.invocation_metadata.fn_reference_with_args.fn_reference_with_arg_hash(), - "key1", retry_on_none=True).decode("utf-8") + "key1", + retry_on_none=True, + ).decode("utf-8") # test that deleting the memento deletes the metadata fn_add_one.forget(1) - assert storage.read_metadata( - memento.invocation_metadata.fn_reference_with_args.fn_reference_with_arg_hash(), - "key1", retry_on_none=True) is None + assert ( + storage.read_metadata( + memento.invocation_metadata.fn_reference_with_args.fn_reference_with_arg_hash(), + "key1", + retry_on_none=True, + ) + is None + ) @staticmethod def test_metadata_rw_store_with_data(): @@ -738,17 +832,26 @@ def test_metadata_rw_store_with_data(): storage = m.Environment.get().get_cluster("cluster1").storage storage.write_metadata( memento.invocation_metadata.fn_reference_with_args.fn_reference_with_arg_hash(), - "key1", "value1".encode("utf-8")) + "key1", + "value1".encode("utf-8"), + ) # noinspection PyUnresolvedReferences assert "value1" == storage.read_metadata( memento.invocation_metadata.fn_reference_with_args.fn_reference_with_arg_hash(), - "key1", retry_on_none=True).decode("utf-8") + "key1", + retry_on_none=True, + ).decode("utf-8") # test that deleting the memento deletes the metadata fn_add_one.forget(1) - assert storage.read_metadata( - memento.invocation_metadata.fn_reference_with_args.fn_reference_with_arg_hash(), - "key1", retry_on_none=True) is None + assert ( + storage.read_metadata( + memento.invocation_metadata.fn_reference_with_args.fn_reference_with_arg_hash(), + "key1", + retry_on_none=True, + ) + is None + ) @staticmethod def test_content_addressable_storage(): @@ -757,7 +860,9 @@ def test_content_addressable_storage(): assert r1 == r2 m1 = fn1.memento(100) m2 = fn2.memento(100) - assert (m1.content_key is None and m2.content_key is None) or (m1.content_key.key == m2.content_key.key) + assert (m1.content_key is None and m2.content_key is None) or ( + m1.content_key.key == m2.content_key.key + ) class DataSourceTester(ABC): @@ -769,11 +874,11 @@ class DataSourceTester(ABC): """ data_source = None # type: DataSource - test = None # type: TestCase - data = None # type: BytesIO - data2 = None # type: BytesIO - data3 = None # type: BytesIO - data4 = None # type: BytesIO + test = None # type: TestCase + data = None # type: BytesIO + data2 = None # type: BytesIO + data3 = None # type: BytesIO + data4 = None # type: BytesIO def setup_method(self): self.data = BytesIO("abc".encode("utf-8")) @@ -984,7 +1089,9 @@ def test_list_keys(self): # Some storage implementations need to set up an initial key with which to associate # the data source. This will be added to the expected results in the following tests. baseline_set = set(self.data_source.list_keys_nonversioned(root)) - baseline_set_recursive = set(self.data_source.list_keys_nonversioned(root, recursive=True)) + baseline_set_recursive = set( + self.data_source.list_keys_nonversioned(root, recursive=True) + ) self.data_source.output(a_b, data) self.data_source.output(a_c, data2) @@ -992,68 +1099,165 @@ def test_list_keys(self): self.data_source.output(f, data4) # Test recursive=False and limit - assert baseline_set.union({a, f}) == \ - set(self.data_source.list_keys_nonversioned(directory=root, file_prefix="", recursive=False)) - assert 2 == len(list(self.data_source.list_keys_nonversioned(directory=root, file_prefix="", - recursive=False, limit=2))) + assert baseline_set.union({a, f}) == set( + self.data_source.list_keys_nonversioned( + directory=root, file_prefix="", recursive=False + ) + ) + assert 2 == len( + list( + self.data_source.list_keys_nonversioned( + directory=root, file_prefix="", recursive=False, limit=2 + ) + ) + ) - assert {a_b, a_c, a_d} == \ - set(self.data_source.list_keys_nonversioned(directory=a, file_prefix="", recursive=False)) - assert 2 == len(list(self.data_source.list_keys_nonversioned( - directory=a, file_prefix="", recursive=False, limit=2))) + assert {a_b, a_c, a_d} == set( + self.data_source.list_keys_nonversioned( + directory=a, file_prefix="", recursive=False + ) + ) + assert 2 == len( + list( + self.data_source.list_keys_nonversioned( + directory=a, file_prefix="", recursive=False, limit=2 + ) + ) + ) - assert {a_d_ef} == \ - set(self.data_source.list_keys_nonversioned(directory=a_d, file_prefix="e", recursive=False)) + assert {a_d_ef} == set( + self.data_source.list_keys_nonversioned( + directory=a_d, file_prefix="e", recursive=False + ) + ) - assert 1 == len(list(self.data_source.list_keys_nonversioned(directory=a_d, file_prefix="e", - recursive=False, limit=2))) + assert 1 == len( + list( + self.data_source.list_keys_nonversioned( + directory=a_d, file_prefix="e", recursive=False, limit=2 + ) + ) + ) - assert set() == \ - set(self.data_source.list_keys_nonversioned(directory=a, file_prefix="g", - recursive=False)) - assert 0 == len(list(self.data_source.list_keys_nonversioned(directory=a, file_prefix="g", - recursive=False, limit=2))) + assert set() == set( + self.data_source.list_keys_nonversioned( + directory=a, file_prefix="g", recursive=False + ) + ) + assert 0 == len( + list( + self.data_source.list_keys_nonversioned( + directory=a, file_prefix="g", recursive=False, limit=2 + ) + ) + ) - assert {a_d_ef} == set(self.data_source.list_keys_nonversioned( - directory=a_d, file_prefix="e", recursive=False)) + assert {a_d_ef} == set( + self.data_source.list_keys_nonversioned( + directory=a_d, file_prefix="e", recursive=False + ) + ) - assert 1 == len(list(self.data_source.list_keys_nonversioned( - directory=a_d, file_prefix="e", recursive=False, limit=2))) + assert 1 == len( + list( + self.data_source.list_keys_nonversioned( + directory=a_d, file_prefix="e", recursive=False, limit=2 + ) + ) + ) # Test recursive=True and limit - assert baseline_set_recursive.union({a_b, a_c, a_d_ef, f}) == \ - set(self.data_source.list_keys_nonversioned(directory=root, file_prefix="", recursive=True)) - - assert min(2, len(baseline_set) + 2) == \ - len(list(self.data_source.list_keys_nonversioned(directory=root, file_prefix="", - recursive=True, limit=2))) - assert {a_b, a_c, a_d_ef} == set(self.data_source.list_keys_nonversioned( - directory=a, file_prefix="", recursive=True)) - - assert 2 == len(list(self.data_source.list_keys_nonversioned( - directory=a, file_prefix="", recursive=True, limit=2))) - assert {a_d_ef} == set(self.data_source.list_keys_nonversioned(directory=a_d, file_prefix="e", recursive=True)) - assert 1 == len(list(self.data_source.list_keys_nonversioned( - directory=a_d, file_prefix="e", recursive=True, limit=2))) - assert set() == set(self.data_source.list_keys_nonversioned(directory=a, file_prefix="g", recursive=True)) - assert 0 == len(list(self.data_source.list_keys_nonversioned( - directory=a, file_prefix="g", recursive=True, limit=2))) - assert {a_d_ef} == set(self.data_source.list_keys_nonversioned(directory=a_d, file_prefix="e", recursive=True)) - assert 1 == len(list(self.data_source.list_keys_nonversioned( - directory=a_d, file_prefix="e", recursive=True, limit=2))) + assert baseline_set_recursive.union({a_b, a_c, a_d_ef, f}) == set( + self.data_source.list_keys_nonversioned( + directory=root, file_prefix="", recursive=True + ) + ) + + assert min(2, len(baseline_set) + 2) == len( + list( + self.data_source.list_keys_nonversioned( + directory=root, file_prefix="", recursive=True, limit=2 + ) + ) + ) + assert {a_b, a_c, a_d_ef} == set( + self.data_source.list_keys_nonversioned( + directory=a, file_prefix="", recursive=True + ) + ) + + assert 2 == len( + list( + self.data_source.list_keys_nonversioned( + directory=a, file_prefix="", recursive=True, limit=2 + ) + ) + ) + assert {a_d_ef} == set( + self.data_source.list_keys_nonversioned( + directory=a_d, file_prefix="e", recursive=True + ) + ) + assert 1 == len( + list( + self.data_source.list_keys_nonversioned( + directory=a_d, file_prefix="e", recursive=True, limit=2 + ) + ) + ) + assert set() == set( + self.data_source.list_keys_nonversioned( + directory=a, file_prefix="g", recursive=True + ) + ) + assert 0 == len( + list( + self.data_source.list_keys_nonversioned( + directory=a, file_prefix="g", recursive=True, limit=2 + ) + ) + ) + assert {a_d_ef} == set( + self.data_source.list_keys_nonversioned( + directory=a_d, file_prefix="e", recursive=True + ) + ) + assert 1 == len( + list( + self.data_source.list_keys_nonversioned( + directory=a_d, file_prefix="e", recursive=True, limit=2 + ) + ) + ) # Remove some files and test self.data_source.delete_all_versions(a_b, recursive=False) self.data_source.delete_all_versions(a_c, recursive=False) self.data_source.delete_all_versions(a_d_ef, recursive=False) - assert baseline_set.union({f}) == set(self.data_source.list_keys_nonversioned( - directory=root, file_prefix="", recursive=False)) - assert baseline_set_recursive.union({f}) == set(self.data_source.list_keys_nonversioned( - directory=root, file_prefix="", recursive=True)) - assert min(2, len(baseline_set) + 1) == len(list(self.data_source.list_keys_nonversioned( - directory=root, file_prefix="", recursive=False, limit=2))) - assert min(2, len(baseline_set) + 1) == len(list(self.data_source.list_keys_nonversioned( - directory=root, file_prefix="", recursive=True, limit=2))) + assert baseline_set.union({f}) == set( + self.data_source.list_keys_nonversioned( + directory=root, file_prefix="", recursive=False + ) + ) + assert baseline_set_recursive.union({f}) == set( + self.data_source.list_keys_nonversioned( + directory=root, file_prefix="", recursive=True + ) + ) + assert min(2, len(baseline_set) + 1) == len( + list( + self.data_source.list_keys_nonversioned( + directory=root, file_prefix="", recursive=False, limit=2 + ) + ) + ) + assert min(2, len(baseline_set) + 1) == len( + list( + self.data_source.list_keys_nonversioned( + directory=root, file_prefix="", recursive=True, limit=2 + ) + ) + ) def test_get_versioned_key(self): data = self.data @@ -1092,49 +1296,70 @@ def setup_method(self): self.test_memento_1_0 = StorageBackendTester.get_dummy_memento(self.fn_ref_1_0) self.fn_ref_2 = fn2.fn_reference().with_args(2) self.test_memento_2 = StorageBackendTester.get_dummy_memento(self.fn_ref_2) - self.fn_ref_date = fn1.fn_reference().with_args(datetime.datetime(2019, 1, 1).date()) - self.test_memento_date = StorageBackendTester.get_dummy_memento(self.fn_ref_date) + self.fn_ref_date = fn1.fn_reference().with_args( + datetime.datetime(2019, 1, 1).date() + ) + self.test_memento_date = StorageBackendTester.get_dummy_memento( + self.fn_ref_date + ) def teardown_method(self): pass def test_put_and_get_memento(self): self.metadata_source.put_memento(self.test_memento_1) - result = self.metadata_source.get_mementos([self.fn_ref_1.fn_reference_with_arg_hash()])[0] - assert self.test_memento_1.invocation_metadata.fn_reference_with_args.args[0] == \ - result.invocation_metadata.fn_reference_with_args.args[0] + result = self.metadata_source.get_mementos( + [self.fn_ref_1.fn_reference_with_arg_hash()] + )[0] + assert ( + self.test_memento_1.invocation_metadata.fn_reference_with_args.args[0] + == result.invocation_metadata.fn_reference_with_args.args[0] + ) def test_put_and_get_memento_with_date(self): self.metadata_source.put_memento(self.test_memento_date) result = self.metadata_source.get_mementos( - [self.fn_ref_date.fn_reference_with_arg_hash()])[0] - assert self.test_memento_date.invocation_metadata.fn_reference_with_args.args[0] == \ - result.invocation_metadata.fn_reference_with_args.args[0] + [self.fn_ref_date.fn_reference_with_arg_hash()] + )[0] + assert ( + self.test_memento_date.invocation_metadata.fn_reference_with_args.args[0] + == result.invocation_metadata.fn_reference_with_args.args[0] + ) def test_get_mementos(self): - all_three = [self.fn_ref_1.fn_reference_with_arg_hash(), - self.fn_ref_1_0.fn_reference_with_arg_hash(), - self.fn_ref_2.fn_reference_with_arg_hash()] + all_three = [ + self.fn_ref_1.fn_reference_with_arg_hash(), + self.fn_ref_1_0.fn_reference_with_arg_hash(), + self.fn_ref_2.fn_reference_with_arg_hash(), + ] assert [None, None, None] == self.metadata_source.get_mementos(all_three) self.metadata_source.put_memento(self.test_memento_1) - listed_args = [mem.invocation_metadata.fn_reference_with_args.args[0] if mem else None - for mem in self.metadata_source.get_mementos(all_three)] + listed_args = [ + mem.invocation_metadata.fn_reference_with_args.args[0] if mem else None + for mem in self.metadata_source.get_mementos(all_three) + ] assert [1, None, None] == listed_args self.metadata_source.put_memento(self.test_memento_1_0) - listed_args = [mem.invocation_metadata.fn_reference_with_args.args[0] if mem else None - for mem in self.metadata_source.get_mementos(all_three)] + listed_args = [ + mem.invocation_metadata.fn_reference_with_args.args[0] if mem else None + for mem in self.metadata_source.get_mementos(all_three) + ] assert [1, 0, None] == listed_args self.metadata_source.put_memento(self.test_memento_2) - listed_args = [mem.invocation_metadata.fn_reference_with_args.args[0] if mem else None - for mem in self.metadata_source.get_mementos(all_three)] + listed_args = [ + mem.invocation_metadata.fn_reference_with_args.args[0] if mem else None + for mem in self.metadata_source.get_mementos(all_three) + ] assert [1, 0, 2] == listed_args def test_all_mementos_exist(self): - full_list = [self.fn_ref_1.fn_reference_with_arg_hash(), - self.fn_ref_2.fn_reference_with_arg_hash()] + full_list = [ + self.fn_ref_1.fn_reference_with_arg_hash(), + self.fn_ref_2.fn_reference_with_arg_hash(), + ] assert not self.metadata_source.all_mementos_exist(full_list) self.metadata_source.put_memento(self.test_memento_1) assert not self.metadata_source.all_mementos_exist(full_list) @@ -1146,116 +1371,217 @@ def test_list_functions(self): self.metadata_source.put_memento(self.test_memento_1) - fn_list = retry_until(lambda: self.metadata_source.list_functions(), lambda x: len(x) > 0) - assert [self.fn_ref_1.fn_reference.qualified_name] == [f.qualified_name for f in fn_list] + fn_list = retry_until( + lambda: self.metadata_source.list_functions(), lambda x: len(x) > 0 + ) + assert [self.fn_ref_1.fn_reference.qualified_name] == [ + f.qualified_name for f in fn_list + ] self.metadata_source.put_memento(self.test_memento_2) - fn_list = retry_until(lambda: self.metadata_source.list_functions(), lambda x: len(x) > 0) - assert [self.fn_ref_1.fn_reference.qualified_name, self.fn_ref_2.fn_reference.qualified_name] ==\ - [f.qualified_name for f in fn_list] + fn_list = retry_until( + lambda: self.metadata_source.list_functions(), lambda x: len(x) > 0 + ) + assert [ + self.fn_ref_1.fn_reference.qualified_name, + self.fn_ref_2.fn_reference.qualified_name, + ] == [f.qualified_name for f in fn_list] def test_list_mementos(self): - assert [] == self.metadata_source.list_mementos(self.fn_ref_1.fn_reference, None) + assert [] == self.metadata_source.list_mementos( + self.fn_ref_1.fn_reference, None + ) self.metadata_source.put_memento(self.test_memento_1) - retry_until(lambda: self.metadata_source.list_mementos(self.fn_ref_1.fn_reference, None), - lambda x: len(x) == 1) - listed_args = [mem.invocation_metadata.fn_reference_with_args.args[0] for mem in - self.metadata_source.list_mementos(self.fn_ref_1.fn_reference, None)] + retry_until( + lambda: self.metadata_source.list_mementos( + self.fn_ref_1.fn_reference, None + ), + lambda x: len(x) == 1, + ) + listed_args = [ + mem.invocation_metadata.fn_reference_with_args.args[0] + for mem in self.metadata_source.list_mementos( + self.fn_ref_1.fn_reference, None + ) + ] assert {1} == set(listed_args) self.metadata_source.put_memento(self.test_memento_1_0) - retry_until(lambda: self.metadata_source.list_mementos(self.fn_ref_1.fn_reference, None), - lambda x: len(x) == 2) - listed_args = [mem.invocation_metadata.fn_reference_with_args.args[0] for mem in - self.metadata_source.list_mementos(self.fn_ref_1.fn_reference, None)] + retry_until( + lambda: self.metadata_source.list_mementos( + self.fn_ref_1.fn_reference, None + ), + lambda x: len(x) == 2, + ) + listed_args = [ + mem.invocation_metadata.fn_reference_with_args.args[0] + for mem in self.metadata_source.list_mementos( + self.fn_ref_1.fn_reference, None + ) + ] assert {0, 1} == set(listed_args) # Test limit - listed_args = [mem.invocation_metadata.fn_reference_with_args.args[0] for mem in - self.metadata_source.list_mementos(self.fn_ref_1.fn_reference, 1)] + listed_args = [ + mem.invocation_metadata.fn_reference_with_args.args[0] + for mem in self.metadata_source.list_mementos(self.fn_ref_1.fn_reference, 1) + ] assert 1 == len(listed_args) def test_write_and_read_metadata(self): self.metadata_source.put_memento(self.test_memento_1) - self.metadata_source.write_metadata(self.test_memento_1.invocation_metadata. - fn_reference_with_args.fn_reference_with_arg_hash(), - "key1", "value1".encode("utf-8"), - stored_with_data=False) + self.metadata_source.write_metadata( + self.test_memento_1.invocation_metadata.fn_reference_with_args.fn_reference_with_arg_hash(), + "key1", + "value1".encode("utf-8"), + stored_with_data=False, + ) assert "value1" == self.metadata_source.read_metadata( - self.test_memento_1.invocation_metadata.fn_reference_with_args. - fn_reference_with_arg_hash(), "key1", retry_on_none=True).decode("utf-8") + self.test_memento_1.invocation_metadata.fn_reference_with_args.fn_reference_with_arg_hash(), + "key1", + retry_on_none=True, + ).decode("utf-8") def test_write_and_read_large_metadata(self): # ~ 5 MB of data: data = "0123456789" * (1024 * 512) self.metadata_source.put_memento(self.test_memento_1) - self.metadata_source.write_metadata(self.test_memento_1.invocation_metadata. - fn_reference_with_args.fn_reference_with_arg_hash(), - "key1", data.encode("utf-8"), - stored_with_data=False) + self.metadata_source.write_metadata( + self.test_memento_1.invocation_metadata.fn_reference_with_args.fn_reference_with_arg_hash(), + "key1", + data.encode("utf-8"), + stored_with_data=False, + ) assert data == self.metadata_source.read_metadata( - self.test_memento_1.invocation_metadata.fn_reference_with_args. - fn_reference_with_arg_hash(), "key1", retry_on_none=True).decode("utf-8") + self.test_memento_1.invocation_metadata.fn_reference_with_args.fn_reference_with_arg_hash(), + "key1", + retry_on_none=True, + ).decode("utf-8") def test_write_and_read_metadata_store_with_data(self): self.metadata_source.put_memento(self.test_memento_1) - self.metadata_source.write_metadata(self.test_memento_1.invocation_metadata. - fn_reference_with_args.fn_reference_with_arg_hash(), - "key1", "value1".encode("utf-8"), - stored_with_data=True) - assert isinstance(self.metadata_source.read_metadata( - self.test_memento_1.invocation_metadata.fn_reference_with_args. - fn_reference_with_arg_hash(), "key1", retry_on_none=True), ResultIsWithData) + self.metadata_source.write_metadata( + self.test_memento_1.invocation_metadata.fn_reference_with_args.fn_reference_with_arg_hash(), + "key1", + "value1".encode("utf-8"), + stored_with_data=True, + ) + assert isinstance( + self.metadata_source.read_metadata( + self.test_memento_1.invocation_metadata.fn_reference_with_args.fn_reference_with_arg_hash(), + "key1", + retry_on_none=True, + ), + ResultIsWithData, + ) def test_forget_call(self): self.metadata_source.put_memento(self.test_memento_1) self.metadata_source.put_memento(self.test_memento_1_0) - listed_args = [mem.invocation_metadata.fn_reference_with_args.args[0] for mem in - self.metadata_source.list_mementos(self.fn_ref_1.fn_reference, None)] + listed_args = [ + mem.invocation_metadata.fn_reference_with_args.args[0] + for mem in self.metadata_source.list_mementos( + self.fn_ref_1.fn_reference, None + ) + ] assert {0, 1} == set(listed_args) self.metadata_source.forget_call(self.fn_ref_1_0.fn_reference_with_arg_hash()) - listed_args = [mem.invocation_metadata.fn_reference_with_args.args[0] for mem in - self.metadata_source.list_mementos(self.fn_ref_1.fn_reference, None)] + listed_args = [ + mem.invocation_metadata.fn_reference_with_args.args[0] + for mem in self.metadata_source.list_mementos( + self.fn_ref_1.fn_reference, None + ) + ] assert {1} == set(listed_args) self.metadata_source.forget_call(self.fn_ref_1.fn_reference_with_arg_hash()) - listed_args = [mem.invocation_metadata.fn_reference_with_args.args[0] for mem in - self.metadata_source.list_mementos(self.fn_ref_1.fn_reference, None)] + listed_args = [ + mem.invocation_metadata.fn_reference_with_args.args[0] + for mem in self.metadata_source.list_mementos( + self.fn_ref_1.fn_reference, None + ) + ] assert set() == set(listed_args) def test_forget_everything(self): self.metadata_source.put_memento(self.test_memento_1) self.metadata_source.put_memento(self.test_memento_1_0) self.metadata_source.put_memento(self.test_memento_2) - retry_until(lambda: self.metadata_source.list_mementos(self.fn_ref_1.fn_reference, None), - lambda x: set(mem.invocation_metadata.fn_reference_with_args.args[0] - for mem in x) == {0, 1}) - retry_until(lambda: self.metadata_source.list_mementos(self.fn_ref_2.fn_reference, None), - lambda x: set(mem.invocation_metadata.fn_reference_with_args.args[0] - for mem in x) == {2}) + retry_until( + lambda: self.metadata_source.list_mementos( + self.fn_ref_1.fn_reference, None + ), + lambda x: set( + mem.invocation_metadata.fn_reference_with_args.args[0] for mem in x + ) + == {0, 1}, + ) + retry_until( + lambda: self.metadata_source.list_mementos( + self.fn_ref_2.fn_reference, None + ), + lambda x: set( + mem.invocation_metadata.fn_reference_with_args.args[0] for mem in x + ) + == {2}, + ) self.metadata_source.forget_everything() - retry_until(lambda: self.metadata_source.list_mementos(self.fn_ref_1.fn_reference, None), - lambda x: set(mem.invocation_metadata.fn_reference_with_args.args[0] - for mem in x) == set()) - retry_until(lambda: self.metadata_source.list_mementos(self.fn_ref_2.fn_reference, None), - lambda x: set(mem.invocation_metadata.fn_reference_with_args.args[0] - for mem in x) == set()) + retry_until( + lambda: self.metadata_source.list_mementos( + self.fn_ref_1.fn_reference, None + ), + lambda x: set( + mem.invocation_metadata.fn_reference_with_args.args[0] for mem in x + ) + == set(), + ) + retry_until( + lambda: self.metadata_source.list_mementos( + self.fn_ref_2.fn_reference, None + ), + lambda x: set( + mem.invocation_metadata.fn_reference_with_args.args[0] for mem in x + ) + == set(), + ) def test_forget_function(self): self.metadata_source.put_memento(self.test_memento_1) self.metadata_source.put_memento(self.test_memento_1_0) self.metadata_source.put_memento(self.test_memento_2) - listed_args = [mem.invocation_metadata.fn_reference_with_args.args[0] for mem in - self.metadata_source.list_mementos(self.fn_ref_1.fn_reference, None)] + listed_args = [ + mem.invocation_metadata.fn_reference_with_args.args[0] + for mem in self.metadata_source.list_mementos( + self.fn_ref_1.fn_reference, None + ) + ] assert {0, 1} == set(listed_args) - list_result = retry_until(lambda: self.metadata_source.list_mementos(self.fn_ref_2.fn_reference, None), - lambda x: len(x) == 1) - listed_args = [mem.invocation_metadata.fn_reference_with_args.args[0] for mem in list_result] + list_result = retry_until( + lambda: self.metadata_source.list_mementos( + self.fn_ref_2.fn_reference, None + ), + lambda x: len(x) == 1, + ) + listed_args = [ + mem.invocation_metadata.fn_reference_with_args.args[0] + for mem in list_result + ] assert {2} == set(listed_args) self.metadata_source.forget_function(self.fn_ref_1.fn_reference) - listed_result = retry_until(lambda: self.metadata_source.list_mementos(self.fn_ref_1.fn_reference, None), - lambda x: len(x) == 0) - listed_args = [mem.invocation_metadata.fn_reference_with_args.args[0] for mem in listed_result] + listed_result = retry_until( + lambda: self.metadata_source.list_mementos( + self.fn_ref_1.fn_reference, None + ), + lambda x: len(x) == 0, + ) + listed_args = [ + mem.invocation_metadata.fn_reference_with_args.args[0] + for mem in listed_result + ] assert set() == set(listed_args) - listed_args = [mem.invocation_metadata.fn_reference_with_args.args[0] for mem in - self.metadata_source.list_mementos(self.fn_ref_2.fn_reference, None)] + listed_args = [ + mem.invocation_metadata.fn_reference_with_args.args[0] + for mem in self.metadata_source.list_mementos( + self.fn_ref_2.fn_reference, None + ) + ] assert {2} == set(listed_args) diff --git a/tests/test_storage_base.py b/tests/test_storage_base.py index 5094c96..21d0d14 100644 --- a/tests/test_storage_base.py +++ b/tests/test_storage_base.py @@ -18,6 +18,7 @@ import numpy as np import pytest + # There is no public equivalent, so we import the protected function. # If it is ever removed, we can replace it in the unit test. # noinspection PyProtectedMember @@ -38,12 +39,12 @@ def fn_test1(): @memento_function def fn_test_long(): - return bytes([1]) * (2**31-12) + return bytes([1]) * (2**31 - 12) @memento_function def fn_test_long_str(): - return "1" * (2**31-12) + return "1" * (2**31 - 12) class TestMemoryCache: @@ -56,15 +57,16 @@ def get_dummy_memento() -> Memento: invocation_metadata=InvocationMetadata( runtime=datetime.timedelta(seconds=123.0), fn_reference_with_args=FunctionReferenceWithArguments( - fn_test1.fn_reference(), (), {}), + fn_test1.fn_reference(), (), {} + ), result_type=ResultType.number, invocations=[], - resources=[] + resources=[], ), function_dependencies={fn_test1.fn_reference()}, runner={}, correlation_id="abc123", - content_key=VersionedDataSourceKey("key", "def456") + content_key=VersionedDataSourceKey("key", "def456"), ) @pytest.mark.skip(reason="test needs further investigation") @@ -95,16 +97,20 @@ def test_cache_key_for_memento(self): memento = self.get_dummy_memento() arg_hash = memento.invocation_metadata.fn_reference_with_args.arg_hash # Note: This also tests the stability of the function code hash - assert "tests.test_storage_base:fn_test1#de769e9c8c9b500e/{}".format(arg_hash) == \ - cache._cache_key_for_memento(memento) + assert "tests.test_storage_base:fn_test1#de769e9c8c9b500e/{}".format( + arg_hash + ) == cache._cache_key_for_memento(memento) @pytest.mark.needs_canonical_version def test_cache_key_for_fn(self): cache = MemoryCache(1) - arg_hash = FunctionReferenceWithArguments(fn_test1.fn_reference(), (), {}).arg_hash + arg_hash = FunctionReferenceWithArguments( + fn_test1.fn_reference(), (), {} + ).arg_hash # Note: This also tests the stability of the function code hash - assert "tests.test_storage_base:fn_test1#de769e9c8c9b500e/{}".format(arg_hash) == \ - cache._cache_key_for_fn(fn_test1.fn_reference(), arg_hash) + assert "tests.test_storage_base:fn_test1#de769e9c8c9b500e/{}".format( + arg_hash + ) == cache._cache_key_for_fn(fn_test1.fn_reference(), arg_hash) def test_put_read_result(self): cache = MemoryCache(1) @@ -141,7 +147,9 @@ def test_get_mementos(self): def test_is_memoized(self): cache = MemoryCache(1) memento = self.get_dummy_memento() - arg_hash = FunctionReferenceWithArguments(fn_test1.fn_reference(), (), {}).arg_hash + arg_hash = FunctionReferenceWithArguments( + fn_test1.fn_reference(), (), {} + ).arg_hash assert not cache.is_memoized(fn_test1.fn_reference(), arg_hash) cache.put(memento, 1, has_result=True) assert cache.is_memoized(fn_test1.fn_reference(), arg_hash) @@ -149,7 +157,9 @@ def test_is_memoized(self): def test_is_all_memoized(self): cache = MemoryCache(1) memento = self.get_dummy_memento() - fn_reference_with_args = FunctionReferenceWithArguments(fn_test1.fn_reference(), (), {}) + fn_reference_with_args = FunctionReferenceWithArguments( + fn_test1.fn_reference(), (), {} + ) assert not cache.is_all_memoized([fn_reference_with_args]) cache.put(memento, 1, has_result=True) assert cache.is_all_memoized([fn_reference_with_args]) @@ -158,28 +168,46 @@ def test_forget_call(self): cache = MemoryCache(1) memento = self.get_dummy_memento() cache.put(memento, 1, has_result=True) - fn_reference_with_args = FunctionReferenceWithArguments(fn_test1.fn_reference(), (), {}) - assert cache.is_memoized(fn_test1.fn_reference(), fn_reference_with_args.arg_hash) + fn_reference_with_args = FunctionReferenceWithArguments( + fn_test1.fn_reference(), (), {} + ) + assert cache.is_memoized( + fn_test1.fn_reference(), fn_reference_with_args.arg_hash + ) cache.forget_call(fn_reference_with_args.fn_reference_with_arg_hash()) - assert not cache.is_memoized(fn_test1.fn_reference(), fn_reference_with_args.arg_hash) + assert not cache.is_memoized( + fn_test1.fn_reference(), fn_reference_with_args.arg_hash + ) def test_forget_everything(self): cache = MemoryCache(1) memento = self.get_dummy_memento() cache.put(memento, 1, has_result=True) - fn_reference_with_args = FunctionReferenceWithArguments(fn_test1.fn_reference(), (), {}) - assert cache.is_memoized(fn_test1.fn_reference(), fn_reference_with_args.arg_hash) + fn_reference_with_args = FunctionReferenceWithArguments( + fn_test1.fn_reference(), (), {} + ) + assert cache.is_memoized( + fn_test1.fn_reference(), fn_reference_with_args.arg_hash + ) cache.forget_everything() - assert not cache.is_memoized(fn_test1.fn_reference(), fn_reference_with_args.arg_hash) + assert not cache.is_memoized( + fn_test1.fn_reference(), fn_reference_with_args.arg_hash + ) def test_forget_function(self): cache = MemoryCache(1) memento = self.get_dummy_memento() cache.put(memento, 1, has_result=True) - fn_reference_with_args = FunctionReferenceWithArguments(fn_test1.fn_reference(), (), {}) - assert cache.is_memoized(fn_test1.fn_reference(), fn_reference_with_args.arg_hash) + fn_reference_with_args = FunctionReferenceWithArguments( + fn_test1.fn_reference(), (), {} + ) + assert cache.is_memoized( + fn_test1.fn_reference(), fn_reference_with_args.arg_hash + ) cache.forget_function(fn_test1.fn_reference()) - assert not cache.is_memoized(fn_test1.fn_reference(), fn_reference_with_args.arg_hash) + assert not cache.is_memoized( + fn_test1.fn_reference(), fn_reference_with_args.arg_hash + ) class TestSerializationStrategy: @@ -190,9 +218,11 @@ def test_super_large_dataframe(self): # DataFrame to ensure it can be encoded. arr = np.random.randint(0, 1000000, size=250) * 10 - df = pd.DataFrame({'test': arr}) + df = pd.DataFrame({"test": arr}) total_mem = df.memory_usage(deep=True).sum() - assert total_mem >= 1024 * 1024 * 1024 * 2, 'Memory used: {}'.format(total_mem) # 2 GB + assert total_mem >= 1024 * 1024 * 1024 * 2, "Memory used: {}".format( + total_mem + ) # 2 GB # IPC strategy should work strategy = DefaultCodec.ValuePickleStrategy() @@ -211,15 +241,25 @@ def assert_equality(self, first, second, message): elif isinstance(first, np.ndarray): assert_array_equal(first, second, message) elif isinstance(first, list): - assert len(first) == len(second), f'Different lengths: {len(first)} vs {len(second)}: {message}' + assert len(first) == len( + second + ), f"Different lengths: {len(first)} vs {len(second)}: {message}" for i in range(0, len(first)): - self.assert_equality(first[i], second[i], - f'Difference at index {i}: {first[i]} vs {second[i]}: {message}') + self.assert_equality( + first[i], + second[i], + f"Difference at index {i}: {first[i]} vs {second[i]}: {message}", + ) elif isinstance(first, dict): - assert first.keys() == second.keys(), f'Different keys: {first.keys()} vs {second.keys()}: {message}' + assert ( + first.keys() == second.keys() + ), f"Different keys: {first.keys()} vs {second.keys()}: {message}" for key in first: - self.assert_equality(first[key], second[key], - f'Difference at key {key}: {first[key]} vs {second[key]}: {message}') + self.assert_equality( + first[key], + second[key], + f"Difference at key {key}: {first[key]} vs {second[key]}: {message}", + ) else: assert first == second, message @@ -227,36 +267,36 @@ def test_encode_values(self): test_values = [ False, True, - '', - 'abc', - b'\x80\x02\x03', - '42', - '-9.2718', + "", + "abc", + b"\x80\x02\x03", + "42", + "-9.2718", datetime.date.today(), datetime.datetime.now(), datetime.timedelta(days=3, minutes=7, milliseconds=200), {}, - {'a': [4, 'foo', None]}, + {"a": [4, "foo", None]}, [], - [5, 'wat', True, [{}, []]], - pd.DataFrame({'a': np.random.rand(5)}), + [5, "wat", True, [{}, []]], + pd.DataFrame({"a": np.random.rand(5)}), np.random.rand(5), - pd.Series([1, 2, 3], index=pd.Index(['a', 'b', 'c']), name='foo'), - pd.Index([1, 2, 3], name='foo'), + pd.Series([1, 2, 3], index=pd.Index(["a", "b", "c"]), name="foo"), + pd.Index([1, 2, 3], name="foo"), ] strategy = DefaultCodec.ValuePickleStrategy() for val in test_values: encoded = strategy.encode(val) decoded = strategy.decode(encoded) - self.assert_equality(val, decoded, f'Failed on: {val}, decoded={decoded}') + self.assert_equality(val, decoded, f"Failed on: {val}, decoded={decoded}") def test_encode_list(self): test_values = [ [], [7], [[42]], - ['hello', [42], [[[datetime.datetime.now()], -8.6], [['abc'], 'xyz']]], - [True, [], None, {}, 'z', 42, 3.14159, b'\x80\x02\x03'], + ["hello", [42], [[[datetime.datetime.now()], -8.6], [["abc"], "xyz"]]], + [True, [], None, {}, "z", 42, 3.14159, b"\x80\x02\x03"], [pd.DataFrame([1, 2]), [pd.DataFrame([3, 4])]], ] strategy = DefaultCodec.ValuePickleStrategy() @@ -265,16 +305,16 @@ def test_encode_list(self): encoded = strategy.encode(val) decoded = strategy.decode(encoded) assert isinstance(decoded, list) - self.assert_equality(val, decoded, f'Failed on: {val}, decoded={decoded}') + self.assert_equality(val, decoded, f"Failed on: {val}, decoded={decoded}") def test_encode_dict(self): test_values = [ {}, - {'a': None}, - {'a': 'foo'}, - {'a': 'foo', 'b': 'bar', 'c': 42, 'd': False}, - {'a': -6.1, 'b': None, 'c': [], 'd': {'e': {}, 'f': ['nested', 'things']}}, - {'a': pd.DataFrame([1, 2]), 'b': {'c': [pd.DataFrame([3, 4])]}}, + {"a": None}, + {"a": "foo"}, + {"a": "foo", "b": "bar", "c": 42, "d": False}, + {"a": -6.1, "b": None, "c": [], "d": {"e": {}, "f": ["nested", "things"]}}, + {"a": pd.DataFrame([1, 2]), "b": {"c": [pd.DataFrame([3, 4])]}}, ] strategy = DefaultCodec.ValuePickleStrategy() for val in test_values: @@ -282,15 +322,15 @@ def test_encode_dict(self): encoded = strategy.encode(val) decoded = strategy.decode(encoded) assert isinstance(decoded, dict) - self.assert_equality(val, decoded, f'Failed on: {val}, decoded={decoded}') + self.assert_equality(val, decoded, f"Failed on: {val}, decoded={decoded}") def test_encode_series(self): test_values = [ pd.Series([], dtype=int), pd.Series([1, 2, 3]), - pd.Series([1, 2, 3], name='foo'), - pd.Series([1, 2, 3], index=pd.Index(['a', 'b', 'c'])), - pd.Series([1, 2, 3], index=pd.Index(['a', 'b', 'c']), name='foo'), + pd.Series([1, 2, 3], name="foo"), + pd.Series([1, 2, 3], index=pd.Index(["a", "b", "c"])), + pd.Series([1, 2, 3], index=pd.Index(["a", "b", "c"]), name="foo"), ] strategy = DefaultCodec.ValuePickleStrategy() for val in test_values: @@ -298,15 +338,17 @@ def test_encode_series(self): encoded = strategy.encode(val) decoded = strategy.decode(encoded) assert isinstance(decoded, pd.Series) - self.assert_equality(val, decoded, f'Failed on: {val}, decoded={decoded}') + self.assert_equality(val, decoded, f"Failed on: {val}, decoded={decoded}") def test_encode_index(self): test_values = [ pd.Index([]), pd.Index([1, 2, 3]), - pd.Index([1, 2, 3], name='foo'), - pd.date_range('2020-01-01', periods=10, freq='D'), - pd.MultiIndex.from_arrays([[7, 8, 9], ['red', 'green', 'blue']], names=('n', 'c')), + pd.Index([1, 2, 3], name="foo"), + pd.date_range("2020-01-01", periods=10, freq="D"), + pd.MultiIndex.from_arrays( + [[7, 8, 9], ["red", "green", "blue"]], names=("n", "c") + ), ] strategy = DefaultCodec.ValuePickleStrategy() for val in test_values: @@ -314,7 +356,7 @@ def test_encode_index(self): encoded = strategy.encode(val) decoded = strategy.decode(encoded) assert isinstance(decoded, pd.Index) - self.assert_equality(val, decoded, f'Failed on: {val}, decoded={decoded}') + self.assert_equality(val, decoded, f"Failed on: {val}, decoded={decoded}") def test_encode_numpy_array(self): test_values = [ @@ -329,7 +371,7 @@ def test_encode_numpy_array(self): encoded = strategy.encode(val) decoded = strategy.decode(encoded) assert isinstance(decoded, np.ndarray) - self.assert_equality(val, decoded, f'Failed on: {val}, decoded={decoded}') + self.assert_equality(val, decoded, f"Failed on: {val}, decoded={decoded}") @pytest.mark.slow def test_serialize_longbytes(self): @@ -338,8 +380,8 @@ def test_serialize_longbytes(self): encoded = strategy.encode(val) decoded = strategy.decode(encoded) - assert val == decoded, f'Failed test on: {val}, decoded={decoded}' - self.assert_equality(val, decoded, f'Failed test on: {val}, decoded={decoded}') + assert val == decoded, f"Failed test on: {val}, decoded={decoded}" + self.assert_equality(val, decoded, f"Failed test on: {val}, decoded={decoded}") @pytest.mark.slow def test_serialize_longstring(self): @@ -347,7 +389,7 @@ def test_serialize_longstring(self): val = "\u0009\u000A\u0026\u2022\u25E6\u2219\u2023\u2043" * (2**27) encoded = strategy.encode(val) decoded = strategy.decode(encoded) - assert val == decoded, f'Failed test on: {val}, decoded={decoded}' + assert val == decoded, f"Failed test on: {val}, decoded={decoded}" @pytest.mark.slow def test_long_bytes(self): diff --git a/tests/test_storage_filesystem.py b/tests/test_storage_filesystem.py index 577f69e..900d758 100644 --- a/tests/test_storage_filesystem.py +++ b/tests/test_storage_filesystem.py @@ -23,7 +23,12 @@ from twosigma.memento import Environment, ConfigurationRepository, FunctionCluster from twosigma.memento.reference import FunctionReferenceWithArguments from twosigma.memento.storage_filesystem import FilesystemStorageBackend -from tests.test_storage_backend import StorageBackendTester, DataSourceTester, MetadataSourceTester, fn1 +from tests.test_storage_backend import ( + StorageBackendTester, + DataSourceTester, + MetadataSourceTester, + fn1, +) class TestStorageFilesystem(StorageBackendTester): @@ -34,18 +39,27 @@ def setup_method(self): self.original_env = m.Environment.get() self.base_path = tempfile.mkdtemp(prefix="memento_storage_filesystem_test") self.data_path = "{}/data".format(self.base_path) - m.Environment.set(Environment(name="test1", base_dir=self.base_path, repos=[ - ConfigurationRepository( - name="repo1", - clusters={ - "cluster1": FunctionCluster(name="cluster1", - storage=FilesystemStorageBackend( - path="{}/data".format(self.base_path), - # test with a different metadata path from data path - metadata_path="{}/metadata".format(self.base_path))) - } + m.Environment.set( + Environment( + name="test1", + base_dir=self.base_path, + repos=[ + ConfigurationRepository( + name="repo1", + clusters={ + "cluster1": FunctionCluster( + name="cluster1", + storage=FilesystemStorageBackend( + path="{}/data".format(self.base_path), + # test with a different metadata path from data path + metadata_path="{}/metadata".format(self.base_path), + ), + ) + }, + ) + ], ) - ])) + ) self.cluster = m.Environment.get().get_cluster("cluster1") self.backend = self.cluster.storage @@ -68,21 +82,32 @@ class TestStorageFilesystemWithMemoryCache(StorageBackendTester): def setup_method(self): super().setup_method() self.original_env = m.Environment.get() - self.base_path = tempfile.mkdtemp(prefix="memento_storage_filesystem_with_memory_cache_test") + self.base_path = tempfile.mkdtemp( + prefix="memento_storage_filesystem_with_memory_cache_test" + ) self.data_path = "{}/data".format(self.base_path) - m.Environment.set(Environment(name="test1", base_dir=self.base_path, repos=[ - ConfigurationRepository( - name="repo1", - clusters={ - "cluster1": FunctionCluster(name="cluster1", - storage=FilesystemStorageBackend( - path="{}/data".format(self.base_path), - # test with a different metadata path from data path - metadata_path="{}/metadata".format(self.base_path), - memory_cache_mb=16)) - } + m.Environment.set( + Environment( + name="test1", + base_dir=self.base_path, + repos=[ + ConfigurationRepository( + name="repo1", + clusters={ + "cluster1": FunctionCluster( + name="cluster1", + storage=FilesystemStorageBackend( + path="{}/data".format(self.base_path), + # test with a different metadata path from data path + metadata_path="{}/metadata".format(self.base_path), + memory_cache_mb=16, + ), + ) + }, + ) + ], ) - ])) + ) self.cluster = m.Environment.get().get_cluster("cluster1") self.backend = self.cluster.storage @@ -95,7 +120,9 @@ def teardown_method(self): def test_cache_eviction(self): cache = cast(FilesystemStorageBackend, self.backend)._memory_cache for i in range(0, 32): - fn_ref = fn1.fn_reference().with_args(i) # type: FunctionReferenceWithArguments + fn_ref = fn1.fn_reference().with_args( + i + ) # type: FunctionReferenceWithArguments mm = self.get_dummy_memento(fn_ref) self.backend.memoize(None, mm, "." * 1024000) assert self.backend.is_memoized(fn_ref.fn_reference, fn_ref.arg_hash) @@ -116,16 +143,25 @@ def setup_method(self): self.original_env = m.Environment.get() self.base_path = tempfile.mkdtemp(prefix="memento_storage_filesystem_test") self.data_path = "{}/data".format(self.base_path) - m.Environment.set(Environment(name="test1", base_dir=self.base_path, repos=[ - ConfigurationRepository( - name="repo1", - clusters={ - "cluster1": FunctionCluster(name="cluster1", - storage=FilesystemStorageBackend( - path="{}/data".format(self.base_path))) - } + m.Environment.set( + Environment( + name="test1", + base_dir=self.base_path, + repos=[ + ConfigurationRepository( + name="repo1", + clusters={ + "cluster1": FunctionCluster( + name="cluster1", + storage=FilesystemStorageBackend( + path="{}/data".format(self.base_path) + ), + ) + }, + ) + ], ) - ])) + ) self.cluster = m.Environment.get().get_cluster("cluster1") # noinspection PyUnresolvedReferences self.data_source = self.cluster.storage._data_source @@ -146,16 +182,25 @@ def setup_method(self): self.original_env = m.Environment.get() self.base_path = tempfile.mkdtemp(prefix="memento_storage_filesystem_test") self.data_path = "{}/metadata".format(self.base_path) - m.Environment.set(Environment(name="test1", base_dir=self.base_path, repos=[ - ConfigurationRepository( - name="repo1", - clusters={ - "cluster1": FunctionCluster(name="cluster1", - storage=FilesystemStorageBackend( - path="{}/data".format(self.base_path))) - } + m.Environment.set( + Environment( + name="test1", + base_dir=self.base_path, + repos=[ + ConfigurationRepository( + name="repo1", + clusters={ + "cluster1": FunctionCluster( + name="cluster1", + storage=FilesystemStorageBackend( + path="{}/data".format(self.base_path) + ), + ) + }, + ) + ], ) - ])) + ) self.cluster = m.Environment.get().get_cluster("cluster1") # noinspection PyUnresolvedReferences self.metadata_source = self.cluster.storage._metadata_source diff --git a/tests/test_storage_memory.py b/tests/test_storage_memory.py index d1f905a..01db7b3 100644 --- a/tests/test_storage_memory.py +++ b/tests/test_storage_memory.py @@ -29,14 +29,22 @@ def setup_method(self): super().setup_method() self.original_env = m.Environment.get() self.base_path = tempfile.mkdtemp(prefix="memento_storage_memory_test") - m.Environment.set(Environment(name="test1", base_dir=self.base_path, repos=[ - ConfigurationRepository( - name="repo1", - clusters={ - "cluster1": FunctionCluster(name="cluster1", storage=MemoryStorageBackend()) - } + m.Environment.set( + Environment( + name="test1", + base_dir=self.base_path, + repos=[ + ConfigurationRepository( + name="repo1", + clusters={ + "cluster1": FunctionCluster( + name="cluster1", storage=MemoryStorageBackend() + ) + }, + ) + ], ) - ])) + ) self.cluster = m.Environment.get().get_cluster("cluster1") self.backend = self.cluster.storage diff --git a/tests/test_storage_null.py b/tests/test_storage_null.py index 1501a6c..923b6ce 100644 --- a/tests/test_storage_null.py +++ b/tests/test_storage_null.py @@ -18,7 +18,12 @@ import twosigma.memento as m -from twosigma.memento import StorageBackend, FunctionCluster, ConfigurationRepository, Environment # noqa: F401 +from twosigma.memento import ( + StorageBackend, + FunctionCluster, + ConfigurationRepository, + Environment, +) # noqa: F401 from twosigma.memento.metadata import ResultType, InvocationMetadata, Memento from twosigma.memento.storage_null import NullStorageBackend from twosigma.memento.types import VersionedDataSourceKey @@ -42,20 +47,28 @@ class TestStorageNull: """ - backend = None # type: StorageBackend + backend = None # type: StorageBackend def setup_method(self): self.original_env = m.Environment.get() self.base_path = tempfile.mkdtemp(prefix="memento_storage_null_test") self.data_path = "{}/data".format(self.base_path) - m.Environment.set(Environment(name="test1", base_dir=self.base_path, repos=[ - ConfigurationRepository( - name="repo1", - clusters={ - "cluster1": FunctionCluster(name="cluster1", storage=NullStorageBackend()) - } + m.Environment.set( + Environment( + name="test1", + base_dir=self.base_path, + repos=[ + ConfigurationRepository( + name="repo1", + clusters={ + "cluster1": FunctionCluster( + name="cluster1", storage=NullStorageBackend() + ) + }, + ) + ], ) - ])) + ) self.cluster = m.Environment.get().get_cluster("cluster1") self.backend = self.cluster.storage @@ -77,15 +90,16 @@ def test_memoize(self): fn_reference_with_args=fn1_reference, runtime=datetime.timedelta(test_runtime), result_type=ResultType.string, - invocations=[ - fn2_reference - ], - resources=[] + invocations=[fn2_reference], + resources=[], ), - function_dependencies={fn1_reference.fn_reference, fn2_reference.fn_reference}, + function_dependencies={ + fn1_reference.fn_reference, + fn2_reference.fn_reference, + }, runner={}, correlation_id="abc123", - content_key=VersionedDataSourceKey("key", "def456") + content_key=VersionedDataSourceKey("key", "def456"), ) result = fn_return_none_1() self.backend.memoize(None, memento, result) @@ -93,8 +107,14 @@ def test_memoize(self): # The null storage should not waste compute cycles computing the content hash assert memento.content_key is None - assert self.backend.get_memento(fn1_reference.fn_reference_with_arg_hash()) is None + assert ( + self.backend.get_memento(fn1_reference.fn_reference_with_arg_hash()) is None + ) - assert not self.backend.is_memoized(fn1_reference.fn_reference, fn1_reference.arg_hash) + assert not self.backend.is_memoized( + fn1_reference.fn_reference, fn1_reference.arg_hash + ) self.backend.forget_call(fn1_reference.fn_reference_with_arg_hash()) - assert not self.backend.is_memoized(fn1_reference.fn_reference, fn1_reference.arg_hash) + assert not self.backend.is_memoized( + fn1_reference.fn_reference, fn1_reference.arg_hash + ) diff --git a/twosigma/memento/__about__.py b/twosigma/memento/__about__.py index caca0e7..381bc76 100644 --- a/twosigma/memento/__about__.py +++ b/twosigma/memento/__about__.py @@ -12,4 +12,4 @@ # See the License for the specific language governing permissions and # limitations under the License. -__version__ = '0.28' +__version__ = "0.28" diff --git a/twosigma/memento/__init__.py b/twosigma/memento/__init__.py index 2926862..ae35b0b 100644 --- a/twosigma/memento/__init__.py +++ b/twosigma/memento/__init__.py @@ -32,20 +32,16 @@ from .reference import FunctionReference -from .configuration import \ - ConfigurationRepository, \ - FunctionCluster, \ - Environment +from .configuration import ConfigurationRepository, FunctionCluster, Environment -from .metadata import \ - Memento, \ - InvocationMetadata +from .metadata import Memento, InvocationMetadata -from .memento import \ - forget_cluster, \ - list_memoized_functions, \ - memento_function, \ - MementoFunction +from .memento import ( + forget_cluster, + list_memoized_functions, + memento_function, + MementoFunction, +) from .resource_function import file_resource @@ -63,7 +59,7 @@ "InvocationMetadata", "ConfigurationRepository", "FunctionCluster", - "Environment" + "Environment", ] diff --git a/twosigma/memento/base.py b/twosigma/memento/base.py index ecb9d88..b6d7ebe 100644 --- a/twosigma/memento/base.py +++ b/twosigma/memento/base.py @@ -34,18 +34,26 @@ def call(self, *args, **kwargs): cluster_config = Environment.get().get_cluster(fn_reference.cluster_name) if cluster_config is None: raise ValueError( - "No cluster found with name {}".format(fn_reference.cluster_name)) + "No cluster found with name {}".format(fn_reference.cluster_name) + ) storage_backend = cluster_config.storage runner_backend = cluster_config.runner # Call function as a batch function of size one. - results = memento_run_batch(context=self.context, - fn_reference_with_args=[FunctionReferenceWithArguments( - fn_reference, args, kwargs, - context_args=self.context.recursive.context_args)], - storage_backend=storage_backend, - runner_backend=runner_backend, - log_runner_backend=runner_backend) + results = memento_run_batch( + context=self.context, + fn_reference_with_args=[ + FunctionReferenceWithArguments( + fn_reference, + args, + kwargs, + context_args=self.context.recursive.context_args, + ) + ], + storage_backend=storage_backend, + runner_backend=runner_backend, + log_runner_backend=runner_backend, + ) result = results[0] if isinstance(result, Exception): @@ -53,8 +61,9 @@ def call(self, *args, **kwargs): return result - def call_batch(self, kwargs_list: List[Dict[str, Any]], - raise_first_exception=True) -> List[Any]: + def call_batch( + self, kwargs_list: List[Dict[str, Any]], raise_first_exception=True + ) -> List[Any]: """ Evaluates this function several times, in batch with the provided arguments. @@ -76,28 +85,39 @@ def call_batch(self, kwargs_list: List[Dict[str, Any]], if any([type(key) for key in kwargs.keys() if type(key) != str]): raise TypeError( "Keys must be strings for all kwargs in kwargs list. Got {}".format( - kwargs_list)) + kwargs_list + ) + ) # Get cluster configuration fn_reference = self.fn_reference() cluster_config = Environment.get().get_cluster(fn_reference.cluster_name) if cluster_config is None: raise ValueError( - "No cluster found with name {}".format(fn_reference.cluster_name)) + "No cluster found with name {}".format(fn_reference.cluster_name) + ) storage_backend = cluster_config.storage runner_backend = cluster_config.runner # Construct a list of FunctionReferenceWithArguments - fns = [FunctionReferenceWithArguments(fn_reference, args=(), kwargs=kwargs, - context_args=self.context.recursive.context_args) - for kwargs in kwargs_list] + fns = [ + FunctionReferenceWithArguments( + fn_reference, + args=(), + kwargs=kwargs, + context_args=self.context.recursive.context_args, + ) + for kwargs in kwargs_list + ] # Invoke the functions in a batch - result = memento_run_batch(context=self.context, - fn_reference_with_args=fns, - storage_backend=storage_backend, - runner_backend=runner_backend, - log_runner_backend=runner_backend) + result = memento_run_batch( + context=self.context, + fn_reference_with_args=fns, + storage_backend=storage_backend, + runner_backend=runner_backend, + log_runner_backend=runner_backend, + ) if raise_first_exception: for r in result: @@ -142,8 +162,9 @@ def map_over_range(self, **kwargs) -> Dict: assert kwargs is not None, "kwargs must not be None" keys = kwargs.keys() - assert len( - keys) == 1, "kwargs must contain exactly one key, corresponding to a fn parameter name" + assert ( + len(keys) == 1 + ), "kwargs must contain exactly one key, corresponding to a fn parameter name" name = next(x for x in keys) values = kwargs[name] @@ -172,11 +193,15 @@ def forget(self, *args, **kwargs): # Forget only the results for the provided parameters fn_reference_with_args = fn_reference.with_args( - *args, **kwargs, _memento_context_args=self.context.recursive.context_args) + *args, **kwargs, _memento_context_args=self.context.recursive.context_args + ) arg_hash = fn_reference_with_args.arg_hash log.info( - "Forgetting {} for arg hash {}".format(fn_reference.qualified_name, arg_hash)) + "Forgetting {} for arg hash {}".format( + fn_reference.qualified_name, arg_hash + ) + ) storage_backend.forget_call(fn_reference_with_args.fn_reference_with_arg_hash()) def forget_all(self): @@ -211,15 +236,21 @@ def get_metadata(self, key: str, args=None, kwargs=None) -> Optional[bytes]: cluster_config = Environment.get().get_cluster(fn_reference.cluster_name) if cluster_config is None: raise ValueError( - "No cluster found with name {}".format(fn_reference.cluster_name)) + "No cluster found with name {}".format(fn_reference.cluster_name) + ) storage_backend = cluster_config.storage - fa = FunctionReferenceWithArguments(fn_reference, args=args, kwargs=kwargs, - context_args=self.context.recursive.context_args) + fa = FunctionReferenceWithArguments( + fn_reference, + args=args, + kwargs=kwargs, + context_args=self.context.recursive.context_args, + ) memento = storage_backend.get_memento(fa.fn_reference_with_arg_hash()) if memento: return storage_backend.read_metadata( memento.invocation_metadata.fn_reference_with_args.fn_reference_with_arg_hash(), - key) + key, + ) else: return None @@ -228,7 +259,9 @@ def with_prevent_further_calls(self, prevent_calls: bool): Sets whether calls after this one are prevented (`True`) or allowed (`False`). """ - new_context = self.context.update_recursive("prevent_further_calls", prevent_calls) + new_context = self.context.update_recursive( + "prevent_further_calls", prevent_calls + ) return self.clone_with(context=new_context) def with_context_args(self, context_args: Dict[str, Any]): @@ -242,8 +275,9 @@ def with_context_args(self, context_args: Dict[str, Any]): Never modify the dict in place. """ - assert context_args is not self.context.recursive.context_args, \ - "Context arg dict must be cloned, not modified in-place" + assert ( + context_args is not self.context.recursive.context_args + ), "Context arg dict must be cloned, not modified in-place" new_context = self.context.update_recursive("context_args", context_args) return self.clone_with(context=new_context) @@ -295,8 +329,12 @@ def memento(self, *args, **kwargs) -> Memento: fn_reference = self.fn_reference() cluster_config = Environment.get().get_cluster(fn_reference.cluster_name) storage = cluster_config.storage - fa = FunctionReferenceWithArguments(fn_reference, args=args, kwargs=kwargs, - context_args=self.context.recursive.context_args) + fa = FunctionReferenceWithArguments( + fn_reference, + args=args, + kwargs=kwargs, + context_args=self.context.recursive.context_args, + ) return storage.get_memento(fa.fn_reference_with_arg_hash()) def monitor_progress(self, monitor: bool = True): @@ -339,13 +377,19 @@ def partial(self, *partial_args, **partial_kwargs) -> MementoFunctionType: fn_reference = self.fn_reference() new_partial_args = fn_reference.partial_args or () new_partial_args += partial_args - new_partial_kwargs = dict(fn_reference.partial_kwargs) if \ - fn_reference.partial_kwargs is not None else {} + new_partial_kwargs = ( + dict(fn_reference.partial_kwargs) + if fn_reference.partial_kwargs is not None + else {} + ) new_partial_kwargs.update(partial_kwargs) - return self.clone_with(partial_args=new_partial_args, - partial_kwargs=new_partial_kwargs) + return self.clone_with( + partial_args=new_partial_args, partial_kwargs=new_partial_kwargs + ) - def put_metadata(self, key: str, value: bytes, *args, store_with_data: bool = False, **kwargs): + def put_metadata( + self, key: str, value: bytes, *args, store_with_data: bool = False, **kwargs + ): """ Write custom metadata for the given arguments to the given key. This is useful, for example, for writing logs to be @@ -366,10 +410,15 @@ def put_metadata(self, key: str, value: bytes, *args, store_with_data: bool = Fa cluster_config = Environment.get().get_cluster(fn_reference.cluster_name) if cluster_config is None: raise ValueError( - "No cluster found with name {}".format(fn_reference.cluster_name)) + "No cluster found with name {}".format(fn_reference.cluster_name) + ) storage_backend = cluster_config.storage - fa = FunctionReferenceWithArguments(fn_reference, args=args, kwargs=kwargs, - context_args=self.context.recursive.context_args) + fa = FunctionReferenceWithArguments( + fn_reference, + args=args, + kwargs=kwargs, + context_args=self.context.recursive.context_args, + ) memento = storage_backend.get_memento(fa.fn_reference_with_arg_hash()) if memento is None: raise MementoNotFoundError("No memento found with provided arguments") @@ -380,7 +429,10 @@ def put_metadata(self, key: str, value: bytes, *args, store_with_data: bool = Fa storage_backend.write_metadata( memento.invocation_metadata.fn_reference_with_args.fn_reference_with_arg_hash(), - key, value, store_with_content_key=store_with_content_key) + key, + value, + store_with_content_key=store_with_content_key, + ) def force_local(self, local: bool = True): """ @@ -413,7 +465,9 @@ def list(self, *args, **kwargs) -> DataFrame: If no results match, `None` is returned. """ results = [] - effective_kwargs = self.fn_reference().with_args(*args, **kwargs).effective_kwargs + effective_kwargs = ( + self.fn_reference().with_args(*args, **kwargs).effective_kwargs + ) context_args = self.context.recursive.context_args if context_args is None: context_args = {} diff --git a/twosigma/memento/call_stack.py b/twosigma/memento/call_stack.py index 91349e5..9d4eaac 100644 --- a/twosigma/memento/call_stack.py +++ b/twosigma/memento/call_stack.py @@ -45,8 +45,12 @@ class StackFrame: memento = None # type: Memento recursive_context = None # type: RecursiveContext - def __init__(self, fn_reference_with_args: FunctionReferenceWithArguments, - runner: RunnerBackend, recursive_context: RecursiveContext): + def __init__( + self, + fn_reference_with_args: FunctionReferenceWithArguments, + runner: RunnerBackend, + recursive_context: RecursiveContext, + ): self.memento = Memento( time=datetime.datetime.now(datetime.timezone.utc), invocation_metadata=InvocationMetadata( @@ -54,12 +58,12 @@ def __init__(self, fn_reference_with_args: FunctionReferenceWithArguments, fn_reference_with_args=fn_reference_with_args, result_type=None, invocations=[], - resources=[] + resources=[], ), function_dependencies={fn_reference_with_args.fn_reference}, runner=runner.to_dict(), correlation_id=recursive_context.correlation_id, - content_key=None + content_key=None, ) self.recursive_context = recursive_context @@ -79,6 +83,7 @@ class CallStack: a new `CallStack` is created. """ + _frames = None # type: List[StackFrame] def __init__(self): diff --git a/twosigma/memento/code_hash.py b/twosigma/memento/code_hash.py index a54737b..9667ed7 100644 --- a/twosigma/memento/code_hash.py +++ b/twosigma/memento/code_hash.py @@ -70,7 +70,8 @@ def hash_if_code_object(o): o.co_names, o.co_nlocals, o.co_stacksize, - o.co_varnames] + o.co_varnames, + ] if salt: sha256.update(salt.encode("utf-8")) sha256.update(json.dumps(attr_values, sort_keys=True).encode("utf-8")) @@ -93,7 +94,9 @@ def hash_if_code_object(o): return repr(fn) -def resolve_to_symbolic_names(dependencies: List[Union[str, MementoFunctionType]]) -> Set[str]: +def resolve_to_symbolic_names( + dependencies: List[Union[str, MementoFunctionType]] +) -> Set[str]: """ Takes a set of str and MementoFunctionType and resolves each to a symbolic name, represented as a str. @@ -107,11 +110,13 @@ def resolve_to_symbol(dep) -> str: return dep elif isinstance(dep, MementoFunctionType): q_name = dep.fn_reference().qualified_name - q_name = q_name[0:q_name.find("#")] if "#" in q_name else q_name + q_name = q_name[0 : q_name.find("#")] if "#" in q_name else q_name return q_name else: - raise ValueError("Each dependency must be either a str or " - "MementoFunctionType. Got {}".format(dep)) + raise ValueError( + "Each dependency must be either a str or " + "MementoFunctionType. Got {}".format(dep) + ) return set(resolve_to_symbol(dep) for dep in dependencies) @@ -153,6 +158,7 @@ def eval_attr(n) -> Optional[str]: elif isinstance(n, ast.Name): return n.id return None + eval_attr_result = eval_attr(node) if eval_attr_result is not None: self.references.add(eval_attr_result) @@ -181,8 +187,9 @@ def visit_Name(self, node): result.difference_update(local_vars) # Also remove anything that dereferences a local variable to_remove = { - symbol for symbol in result - if "." in symbol and symbol[0:symbol.find(".")] in local_vars + symbol + for symbol in result + if "." in symbol and symbol[0 : symbol.find(".")] in local_vars } result.difference_update(to_remove) @@ -203,6 +210,7 @@ class HashRule(ABC): that are detected to be called from that symbol. """ + all_rules = [] # type: List[Type[HashRule]] """All known HashRule instances, in order of evaluation""" @@ -221,7 +229,9 @@ class HashRule(ABC): rule_hash = None # type: Optional[str] """The hash computed for this hash rule, in the format of a hex string, or `None`""" - def __init__(self, key: str, parent_symbol: Optional[str], symbol: str, first_level: bool): + def __init__( + self, key: str, parent_symbol: Optional[str], symbol: str, first_level: bool + ): self.key = key self.parent_symbol = parent_symbol self.symbol = symbol @@ -230,14 +240,20 @@ def __init__(self, key: str, parent_symbol: Optional[str], symbol: str, first_le def describe(self) -> str: """Return a textual description of the component of the hash""" - return "{}{}{}".format(self.key, - "#" + self.rule_hash[0:16] if self.rule_hash is not None else "", - " (direct)" if self.first_level else "") + return "{}{}{}".format( + self.key, + "#" + self.rule_hash[0:16] if self.rule_hash is not None else "", + " (direct)" if self.first_level else "", + ) @abstractmethod - def collect_transitive_dependencies(self, result: Set["HashRule"], - root_fn: MementoFunctionType, - package_scope: Set[str], blacklist: List[object]): + def collect_transitive_dependencies( + self, + result: Set["HashRule"], + root_fn: MementoFunctionType, + package_scope: Set[str], + blacklist: List[object], + ): """ Analyze the entity behind this symbol and recursively descend to collect a full tree of transitive dependencies into the `results` variable. Cycles are broken by @@ -274,10 +290,17 @@ def did_change(self) -> bool: pass @staticmethod - def _visit_dependency(result: Set["HashRule"], src_fn: Callable, - parent_symbol: str, symbol: str, required: bool, - root_fn: MementoFunctionType, first_level: bool, - package_scope: Set[str], blacklist: List[object]): + def _visit_dependency( + result: Set["HashRule"], + src_fn: Callable, + parent_symbol: str, + symbol: str, + required: bool, + root_fn: MementoFunctionType, + first_level: bool, + package_scope: Set[str], + blacklist: List[object], + ): """ Evaluates the given symbol (of the form a.b.c) using the globals scope of the provided function and adds a `HashRule` to the result set, if found. The `HashRule` also gets @@ -311,15 +334,25 @@ def _visit_dependency(result: Set["HashRule"], src_fn: Callable, # If ":" is in the symbol name, this is a qualified name from memento. # Construct via a FunctionReference if ":" in symbol: + def memento_fn_resolver(): return FunctionReference.from_qualified_name(symbol).memento_fn + memento_fn = memento_fn_resolver() - rule = MementoFunctionHashRule(parent_symbol=parent_symbol, symbol=symbol, - resolver=lambda: memento_fn_resolver, - obj=memento_fn, first_level=first_level) + rule = MementoFunctionHashRule( + parent_symbol=parent_symbol, + symbol=symbol, + resolver=lambda: memento_fn_resolver, + obj=memento_fn, + first_level=first_level, + ) # collect_transitive_dependencies will add this rule to result - rule.collect_transitive_dependencies(result=result, root_fn=root_fn, - package_scope=package_scope, blacklist=blacklist) + rule.collect_transitive_dependencies( + result=result, + root_fn=root_fn, + package_scope=package_scope, + blacklist=blacklist, + ) return # Otherwise, treat this as a "dotted name" (e.g. a.b.c.fn()) @@ -328,11 +361,19 @@ def memento_fn_resolver(): if required: raise DependencyNotFoundError( "Could not find required dependency {} for function {}. " - "Failed to resolve {} in global table.". - format(symbol, src_fn.__name__, parts[0])) - result.add(UndefinedSymbolHashRule(global_table, parent_symbol=parent_symbol, - symbol=parts[0], first_level=first_level, - ref_is_global_table=True)) + "Failed to resolve {} in global table.".format( + symbol, src_fn.__name__, parts[0] + ) + ) + result.add( + UndefinedSymbolHashRule( + global_table, + parent_symbol=parent_symbol, + symbol=parts[0], + first_level=first_level, + ref_is_global_table=True, + ) + ) return first_part = parts[0] @@ -342,16 +383,18 @@ def resolver(): ref = resolver() - def resolve_symbol(parent_sym: str, sym: str, resolver_fn, - reference: object) -> Optional[HashRule]: + def resolve_symbol( + parent_sym: str, sym: str, resolver_fn, reference: object + ) -> Optional[HashRule]: if any(x is reference for x in blacklist): # Do not include some symbols (e.g. memento library itself) return None for strategy in HashRule.all_rules: # noinspection PyUnresolvedReferences - found = strategy.try_resolve(parent_sym, sym, resolver_fn, - reference, first_level=first_level) + found = strategy.try_resolve( + parent_sym, sym, resolver_fn, reference, first_level=first_level + ) if found is not None: return found return None @@ -359,8 +402,12 @@ def resolve_symbol(parent_sym: str, sym: str, resolver_fn, rule = resolve_symbol(parent_symbol, symbol, resolver, ref) if rule is not None: # collect_transitive_dependencies will add this rule to result - rule.collect_transitive_dependencies(result=result, root_fn=root_fn, - package_scope=package_scope, blacklist=blacklist) + rule.collect_transitive_dependencies( + result=result, + root_fn=root_fn, + package_scope=package_scope, + blacklist=blacklist, + ) return symbol_part = parts[0] @@ -369,11 +416,19 @@ def resolve_symbol(parent_sym: str, sym: str, resolver_fn, if required: raise DependencyNotFoundError( "Could not find required dependency {} for function {}. " - "Failed to resolve {}.{}.". - format(symbol, src_fn.__name__, symbol_part, parts[i])) - result.add(UndefinedSymbolHashRule(ref, parent_symbol=parent_symbol, - symbol=parts[i], first_level=first_level, - ref_is_global_table=False)) + "Failed to resolve {}.{}.".format( + symbol, src_fn.__name__, symbol_part, parts[i] + ) + ) + result.add( + UndefinedSymbolHashRule( + ref, + parent_symbol=parent_symbol, + symbol=parts[i], + first_level=first_level, + ref_is_global_table=False, + ) + ) return symbol_part += "." part_i = parts[i] @@ -387,15 +442,19 @@ def resolver(): rule = resolve_symbol(parent_symbol, symbol_part, resolver, ref) if rule is not None: # collect_transitive_dependencies will add this rule to result - rule.collect_transitive_dependencies(result=result, root_fn=root_fn, - package_scope=package_scope, - blacklist=blacklist) + rule.collect_transitive_dependencies( + result=result, + root_fn=root_fn, + package_scope=package_scope, + blacklist=blacklist, + ) return if required: raise DependencyNotFoundError( "Could not find required dependency {} for function {}-{}. " - "No hash rules matched.".format(symbol, src_fn.__name__, symbol_part)) + "No hash rules matched.".format(symbol, src_fn.__name__, symbol_part) + ) @abstractmethod def clone(self) -> "HashRule": @@ -429,20 +488,40 @@ class UndefinedSymbolHashRule(HashRule): ref_is_global_table = None # type: bool """If true, ref is the global table, so we should use an `in` check instead of `hasattr`""" - def __init__(self, ref: object, parent_symbol: str, symbol: str, first_level: bool, - ref_is_global_table: bool): + def __init__( + self, + ref: object, + parent_symbol: str, + symbol: str, + first_level: bool, + ref_is_global_table: bool, + ): # noinspection PyUnresolvedReferences - super().__init__(key="UndefinedSymbol;{};{}".format(parent_symbol, symbol), - parent_symbol=parent_symbol, symbol=symbol, first_level=first_level) + super().__init__( + key="UndefinedSymbol;{};{}".format(parent_symbol, symbol), + parent_symbol=parent_symbol, + symbol=symbol, + first_level=first_level, + ) self.ref = ref self.ref_is_global_table = ref_is_global_table def clone(self) -> HashRule: - return UndefinedSymbolHashRule(self.ref, self.parent_symbol, self.symbol, - self.first_level, self.ref_is_global_table) - - def collect_transitive_dependencies(self, result: Set[HashRule], root_fn: MementoFunctionType, - package_scope: Set[str], blacklist: List[object]): + return UndefinedSymbolHashRule( + self.ref, + self.parent_symbol, + self.symbol, + self.first_level, + self.ref_is_global_table, + ) + + def collect_transitive_dependencies( + self, + result: Set[HashRule], + root_fn: MementoFunctionType, + package_scope: Set[str], + blacklist: List[object], + ): # An undefined symbol cannot have transitive dependencies pass @@ -459,7 +538,8 @@ def did_change(self) -> bool: def __repr__(self): return "UndefinedSymbolHashRule(parent_symbol={parent_symbol}, symbol={symbol})".format( - parent_symbol=repr(self.parent_symbol), symbol=repr(self.symbol)) + parent_symbol=repr(self.parent_symbol), symbol=repr(self.symbol) + ) class MementoFunctionHashRule(HashRule): @@ -467,40 +547,68 @@ class MementoFunctionHashRule(HashRule): Hash rule for the case where a variable points to a Memento Function """ + memento_fn = None # type: MementoFunctionType resolver = None # type: Callable parent_symbol = None # type: str # noinspection PyUnusedLocal @staticmethod - def try_resolve(parent_symbol: str, symbol: str, resolver: Callable, ref: object, - first_level: bool) -> \ - Optional["MementoFunctionHashRule"]: + def try_resolve( + parent_symbol: str, + symbol: str, + resolver: Callable, + ref: object, + first_level: bool, + ) -> Optional["MementoFunctionHashRule"]: # Memento functions may be wrapped by decorators, so check each level of wrapping # to decide if this is a MementoFunctionType. while True: if isinstance(ref, MementoFunctionType): - return MementoFunctionHashRule(parent_symbol, symbol, resolver, ref, first_level) + return MementoFunctionHashRule( + parent_symbol, symbol, resolver, ref, first_level + ) elif hasattr(ref, "__wrapped__"): ref = ref.__wrapped__ else: return None - def __init__(self, parent_symbol: Optional[str], symbol: str, resolver: Callable, - obj: MementoFunctionType, first_level: bool): + def __init__( + self, + parent_symbol: Optional[str], + symbol: str, + resolver: Callable, + obj: MementoFunctionType, + first_level: bool, + ): # noinspection PyUnresolvedReferences - super().__init__(key="MementoFunction;{};{}".format( - parent_symbol, obj.qualified_name_without_version), - parent_symbol=parent_symbol, symbol=symbol, first_level=first_level) + super().__init__( + key="MementoFunction;{};{}".format( + parent_symbol, obj.qualified_name_without_version + ), + parent_symbol=parent_symbol, + symbol=symbol, + first_level=first_level, + ) self.memento_fn = obj self.resolver = resolver def clone(self) -> "HashRule": - return MementoFunctionHashRule(self.parent_symbol, self.symbol, self.resolver, - self.memento_fn, self.first_level) - - def collect_transitive_dependencies(self, result: Set[HashRule], root_fn: MementoFunctionType, - package_scope: Set[str], blacklist: List[object]): + return MementoFunctionHashRule( + self.parent_symbol, + self.symbol, + self.resolver, + self.memento_fn, + self.first_level, + ) + + def collect_transitive_dependencies( + self, + result: Set[HashRule], + root_fn: MementoFunctionType, + package_scope: Set[str], + blacklist: List[object], + ): # Make sure self is not already accounted for: if self in result: return @@ -513,24 +621,37 @@ def collect_transitive_dependencies(self, result: Set[HashRule], root_fn: Mement memento_fn = self.memento_fn for dep in memento_fn.required_dependencies: - HashRule._visit_dependency(result=result, src_fn=memento_fn.src_fn, - parent_symbol=memento_fn.qualified_name_without_version, - symbol=dep, - required=True, root_fn=root_fn, - first_level=memento_fn is root_fn, - package_scope=package_scope, blacklist=blacklist) + HashRule._visit_dependency( + result=result, + src_fn=memento_fn.src_fn, + parent_symbol=memento_fn.qualified_name_without_version, + symbol=dep, + required=True, + root_fn=root_fn, + first_level=memento_fn is root_fn, + package_scope=package_scope, + blacklist=blacklist, + ) for dep in memento_fn.detected_dependencies: - HashRule._visit_dependency(result=result, src_fn=memento_fn.src_fn, - parent_symbol=memento_fn.qualified_name_without_version, - symbol=dep, - required=False, root_fn=root_fn, - first_level=memento_fn is root_fn, - package_scope=package_scope, blacklist=blacklist) + HashRule._visit_dependency( + result=result, + src_fn=memento_fn.src_fn, + parent_symbol=memento_fn.qualified_name_without_version, + symbol=dep, + required=False, + root_fn=root_fn, + first_level=memento_fn is root_fn, + package_scope=package_scope, + blacklist=blacklist, + ) def compute_hash(self) -> Optional[str]: - return self.memento_fn.explicit_version \ - if self.memento_fn.explicit_version is not None else self.memento_fn.code_hash + return ( + self.memento_fn.explicit_version + if self.memento_fn.explicit_version is not None + else self.memento_fn.code_hash + ) def did_change(self) -> bool: # Changes to the definition of a MementoFunctionType are more robust and detected using a @@ -549,33 +670,64 @@ class GlobalVariableHashRule(HashRule): Hash rule for the case where a variable points to a global variable. """ + var = None # type: object resolver = None # type: Callable last_value = None # type: bytes @staticmethod - def try_resolve(parent_symbol: str, symbol: str, resolver: Callable, ref: object, - first_level: bool) \ - -> Optional["GlobalVariableHashRule"]: + def try_resolve( + parent_symbol: str, + symbol: str, + resolver: Callable, + ref: object, + first_level: bool, + ) -> Optional["GlobalVariableHashRule"]: val = GlobalVariableHashRule._serialize_value(ref) - return GlobalVariableHashRule(parent_symbol, symbol, resolver, ref, val, first_level) \ - if val is not None else None - - def __init__(self, parent_symbol: str, symbol: str, resolver: Callable, ref: object, - last_value: bytes, first_level: bool): - super().__init__(key="GlobalVariable;{};{}".format(parent_symbol, symbol), - parent_symbol=parent_symbol, symbol=symbol, first_level=first_level) + return ( + GlobalVariableHashRule( + parent_symbol, symbol, resolver, ref, val, first_level + ) + if val is not None + else None + ) + + def __init__( + self, + parent_symbol: str, + symbol: str, + resolver: Callable, + ref: object, + last_value: bytes, + first_level: bool, + ): + super().__init__( + key="GlobalVariable;{};{}".format(parent_symbol, symbol), + parent_symbol=parent_symbol, + symbol=symbol, + first_level=first_level, + ) self.var = ref self.resolver = resolver self.last_value = last_value def clone(self): - return GlobalVariableHashRule(self.parent_symbol, self.symbol, self.resolver, - self.var, self.last_value, self.first_level) - - def collect_transitive_dependencies(self, result: Set["HashRule"], - root_fn: MementoFunctionType, - package_scope: Set[str], blacklist: List[object]): + return GlobalVariableHashRule( + self.parent_symbol, + self.symbol, + self.resolver, + self.var, + self.last_value, + self.first_level, + ) + + def collect_transitive_dependencies( + self, + result: Set["HashRule"], + root_fn: MementoFunctionType, + package_scope: Set[str], + blacklist: List[object], + ): # Variables cannot have transitive dependencies, so just add self and return result.add(self) @@ -595,7 +747,9 @@ def did_change(self) -> bool: @staticmethod def _serialize_value(var: object) -> Optional[bytes]: try: - return json.dumps(MementoCodec.encode_arg(var), sort_keys=True).encode("utf-8") + return json.dumps(MementoCodec.encode_arg(var), sort_keys=True).encode( + "utf-8" + ) except (TypeError, ValueError): # not a type that Memento understands or can hash. Do not hash. return None @@ -612,31 +766,62 @@ class NonMementoFunctionHashRule(HashRule): Hash rule for the case where a variable points to a non-memento function. """ + src_fn = None # type: Callable resolver = None # type: Callable @staticmethod - def try_resolve(parent_symbol: str, symbol: str, resolver: Callable, ref: object, - first_level: bool) -> \ - Optional["NonMementoFunctionHashRule"]: - return NonMementoFunctionHashRule(parent_symbol, symbol, resolver, ref, first_level) \ - if callable(ref) and hasattr(ref, "__globals__") else None - - def __init__(self, parent_symbol: str, symbol: str, resolver: Callable, obj: Callable, - first_level: bool): + def try_resolve( + parent_symbol: str, + symbol: str, + resolver: Callable, + ref: object, + first_level: bool, + ) -> Optional["NonMementoFunctionHashRule"]: + return ( + NonMementoFunctionHashRule( + parent_symbol, symbol, resolver, ref, first_level + ) + if callable(ref) and hasattr(ref, "__globals__") + else None + ) + + def __init__( + self, + parent_symbol: str, + symbol: str, + resolver: Callable, + obj: Callable, + first_level: bool, + ): # noinspection PyUnresolvedReferences - super().__init__(key="Function;{};{}".format(parent_symbol, - obj.__module__ + ":" + obj.__qualname__), - parent_symbol=parent_symbol, symbol=symbol, first_level=first_level) + super().__init__( + key="Function;{};{}".format( + parent_symbol, obj.__module__ + ":" + obj.__qualname__ + ), + parent_symbol=parent_symbol, + symbol=symbol, + first_level=first_level, + ) self.src_fn = obj self.resolver = resolver def clone(self) -> HashRule: - return NonMementoFunctionHashRule(self.parent_symbol, self.symbol, self.resolver, - self.src_fn, self.first_level) - - def collect_transitive_dependencies(self, result: Set[HashRule], root_fn: MementoFunctionType, - package_scope: Set[str], blacklist: List[object]): + return NonMementoFunctionHashRule( + self.parent_symbol, + self.symbol, + self.resolver, + self.src_fn, + self.first_level, + ) + + def collect_transitive_dependencies( + self, + result: Set[HashRule], + root_fn: MementoFunctionType, + package_scope: Set[str], + blacklist: List[object], + ): # Make sure self is not already accounted for: if self in result: return @@ -654,10 +839,17 @@ def collect_transitive_dependencies(self, result: Set[HashRule], root_fn: Mement for dep in list_dotted_names(src_fn): # noinspection PyUnresolvedReferences symbol_parent = src_fn.__module__ + ":" + src_fn.__qualname__ - HashRule._visit_dependency(result=result, src_fn=src_fn, - parent_symbol=symbol_parent, symbol=dep, - required=False, root_fn=root_fn, first_level=False, - package_scope=package_scope, blacklist=blacklist) + HashRule._visit_dependency( + result=result, + src_fn=src_fn, + parent_symbol=symbol_parent, + symbol=dep, + required=False, + root_fn=root_fn, + first_level=False, + package_scope=package_scope, + blacklist=blacklist, + ) def compute_hash(self) -> Optional[str]: return fn_code_hash(self.src_fn) @@ -678,5 +870,5 @@ def __repr__(self): HashRule.all_rules = [ MementoFunctionHashRule, NonMementoFunctionHashRule, - GlobalVariableHashRule # must go last + GlobalVariableHashRule, # must go last ] diff --git a/twosigma/memento/configuration.py b/twosigma/memento/configuration.py index 7b872ff..662abdb 100644 --- a/twosigma/memento/configuration.py +++ b/twosigma/memento/configuration.py @@ -46,21 +46,26 @@ # Bytes used to salt all function hashes. # WARNING: Changing anything in this environment hash will force re-evaluation of everything! -ENVIRONMENT_HASH_BYTES =\ - hashlib.sha256(json.dumps( +ENVIRONMENT_HASH_BYTES = hashlib.sha256( + json.dumps( { "memento_serialization_version": 7, # Updated when a backwards-incompatible - # change is made to the file format + # change is made to the file format "packages": { - "pandas": pandas.__version__[0:pandas.__version__.find(".")] # major version - } + "pandas": pandas.__version__[ + 0 : pandas.__version__.find(".") + ] # major version + }, }, - sort_keys=True - ).encode("utf-8")).digest() + sort_keys=True, + ).encode("utf-8") +).digest() # Global collection of registered functions, persists beyond environment creation and destruction _registered_function_names = set() # type: Set[str] -_registered_functions = defaultdict(lambda: list()) # type: Dict[str, List[MementoFunctionType]] +_registered_functions = defaultdict( + lambda: list() +) # type: Dict[str, List[MementoFunctionType]] # environment is lazily-loaded the first time Environment.get() is called. environment = None @@ -87,16 +92,22 @@ def _load_config(base_dir: str, config: Union[str, Dict], **kwargs) -> Dict: config_path = Path(config) if not config_path.is_absolute(): if base_dir_path is None: - raise FileNotFoundError("Could not evaluate relative path '{}' " - "as there is no base directory defined".format(config_path)) + raise FileNotFoundError( + "Could not evaluate relative path '{}' " + "as there is no base directory defined".format(config_path) + ) # this is a relative path. Prepend base_dir config_path = base_dir_path.joinpath(config_path) if config_path.is_dir(): - raise FileNotFoundError("Expected file but got directory: {}".format(config_path)) + raise FileNotFoundError( + "Expected file but got directory: {}".format(config_path) + ) if not config_path.is_file(): - raise FileNotFoundError("Could not find configuration file: {}".format(config_path)) + raise FileNotFoundError( + "Could not find configuration file: {}".format(config_path) + ) # Open as a jinja2 template and perform parameter substitution with config_path.open("r") as f: @@ -104,10 +115,16 @@ def _load_config(base_dir: str, config: Union[str, Dict], **kwargs) -> Dict: if config_path.name.endswith(".json"): result = json.load(substituted_file) - elif config_path.name.endswith(".yaml") or config_path.name.endswith(".yml") or config_path.name.endswith(".jinja"): + elif ( + config_path.name.endswith(".yaml") + or config_path.name.endswith(".yml") + or config_path.name.endswith(".jinja") + ): result = yaml.safe_load(substituted_file) else: - raise IOError("Unknown config file extension for file {}".format(config_path)) + raise IOError( + "Unknown config file extension for file {}".format(config_path) + ) result["base_dir"] = str(config_path.parent) @@ -157,9 +174,16 @@ class FunctionCluster: """ - def __init__(self, config: Dict = None, name: str = None, description: str = None, - maintainer: str = None, documentation: str = None, storage: StorageBackend = None, - runner: RunnerBackend = None): + def __init__( + self, + config: Dict = None, + name: str = None, + description: str = None, + maintainer: str = None, + documentation: str = None, + storage: StorageBackend = None, + runner: RunnerBackend = None, + ): """ Create a new DataFunctionCluster from the provided configuration. @@ -207,24 +231,32 @@ def __init__(self, config: Dict = None, name: str = None, description: str = Non if storage is not None: self.storage = storage elif "storage" not in self.config: - self.storage = StorageBackend.create(_DEFAULT_STORAGE_TYPE, _DEFAULT_STORAGE_CONFIG) + self.storage = StorageBackend.create( + _DEFAULT_STORAGE_TYPE, _DEFAULT_STORAGE_CONFIG + ) else: storage_config = self.config["storage"] if "type" not in storage_config: - raise ValueError("Missing required parameter 'type' in storage " - "configuration {}".format(self.name)) + raise ValueError( + "Missing required parameter 'type' in storage " + "configuration {}".format(self.name) + ) storage_type = storage_config["type"] self.storage = StorageBackend.create(storage_type, storage_config) if runner is not None: self.runner = runner elif "runner" not in self.config: - self.runner = RunnerBackend.create(_DEFAULT_RUNNER_TYPE, _DEFAULT_RUNNER_CONFIG) + self.runner = RunnerBackend.create( + _DEFAULT_RUNNER_TYPE, _DEFAULT_RUNNER_CONFIG + ) else: runner_config = self.config["runner"] if "type" not in runner_config: - raise ValueError("Missing required parameter 'type' in runner " - "configuration {}".format(self.name)) + raise ValueError( + "Missing required parameter 'type' in runner " + "configuration {}".format(self.name) + ) runner_type = runner_config["type"] self.runner = RunnerBackend.create(runner_type, runner_config) @@ -237,7 +269,7 @@ def to_dict(self): config = { "name": self.name, "storage": self.storage.to_dict(), - "runner": self.runner.to_dict() + "runner": self.runner.to_dict(), } if self.description is not None: config["description"] = self.description @@ -254,19 +286,23 @@ class _DefaultFunctionCluster(FunctionCluster): """ - def __init__(self, env: 'Environment'): + def __init__(self, env: "Environment"): - env_path = env.get_base_dir() or str(Path("~").expanduser().joinpath(".memento", "env", env.name)) + env_path = env.get_base_dir() or str( + Path("~").expanduser().joinpath(".memento", "env", env.name) + ) base_path = Path(env_path).joinpath("cluster", "default") - super().__init__({ - "name": "default", - "description": "Default function cluster", - "storage": { - "type": "filesystem", - "path": str(base_path), - "readonly": False + super().__init__( + { + "name": "default", + "description": "Default function cluster", + "storage": { + "type": "filesystem", + "path": str(base_path), + "readonly": False, + }, } - }) + ) class ConfigurationRepository: @@ -307,11 +343,17 @@ class ConfigurationRepository: clusters = None # type: Dict[str, FunctionCluster] modules = None # type: List[str] - def __init__(self, config: Dict = None, name: str = None, base_dir: str = None, - description: str = None, - maintainer: str = None, documentation: str = None, - clusters: Dict[str, FunctionCluster] = None, - modules: List[str] = None): + def __init__( + self, + config: Dict = None, + name: str = None, + base_dir: str = None, + description: str = None, + maintainer: str = None, + documentation: str = None, + clusters: Dict[str, FunctionCluster] = None, + modules: List[str] = None, + ): """ Create a new ConfigurationRepository from the config file located at the provided path. @@ -344,7 +386,9 @@ def __init__(self, config: Dict = None, name: str = None, base_dir: str = None, self.documentation = self.config.get("documentation", None) self.clusters = dict() for cluster_name, config_path in self.config.get("clusters", {}).items(): - self.clusters[cluster_name] = FunctionCluster(_load_config(self.base_dir, config_path)) + self.clusters[cluster_name] = FunctionCluster( + _load_config(self.base_dir, config_path) + ) self.modules = self.config.get("modules", None) if name is not None: @@ -376,9 +420,7 @@ def to_dict(self): config = { "name": self.name, "modules": self.modules, - "clusters": { - k: v.to_dict() for (k, v) in self.clusters.items() - } + "clusters": {k: v.to_dict() for (k, v) in self.clusters.items()}, } if self.base_dir is not None: config["base_dir"] = self.base_dir @@ -398,7 +440,8 @@ def from_file(path: str, **kwargs) -> "ConfigurationRepository": """ return ConfigurationRepository( - _load_config(os.path.dirname(path), os.path.basename(path), **kwargs)) + _load_config(os.path.dirname(path), os.path.basename(path), **kwargs) + ) class Environment: @@ -427,8 +470,13 @@ class Environment: a message when the environment changes """ - def __init__(self, config: Dict = None, name: str = None, base_dir: str = None, - repos: List[ConfigurationRepository] = None): + def __init__( + self, + config: Dict = None, + name: str = None, + base_dir: str = None, + repos: List[ConfigurationRepository] = None, + ): """ Create a new environment from the provided configuration object. The expected form of the config object is: @@ -463,8 +511,10 @@ def __init__(self, config: Dict = None, name: str = None, base_dir: str = None, self.config = config self.name = config.get("name", "default") self.base_dir = config.get("base_dir", None) - self.repos = [ConfigurationRepository(_load_config(self.base_dir, repo_config)) - for repo_config in config.get("repos", [])] + self.repos = [ + ConfigurationRepository(_load_config(self.base_dir, repo_config)) + for repo_config in config.get("repos", []) + ] if name is not None: self.name = name @@ -522,7 +572,9 @@ def register_function(cluster_name: Optional[str], fn: MementoFunctionType): # Check if cluster is locked cluster = Environment.get().get_cluster(cluster_name) if cluster is not None and cluster.locked: - raise ValueError("Cluster {} is locked to new functions".format(cluster_name)) + raise ValueError( + "Cluster {} is locked to new functions".format(cluster_name) + ) # Use "" as a key if this is the default cluster if cluster_name is None: @@ -584,16 +636,13 @@ def to_dict(self): Return a dict representation of this environment """ - config = { - "name": self.name, - "repos": [repo.to_dict() for repo in self.repos] - } + config = {"name": self.name, "repos": [repo.to_dict() for repo in self.repos]} if self.base_dir is not None: config["base_dir"] = self.base_dir return config @classmethod - def set(cls, config: ['Environment', str, object]): + def set(cls, config: ["Environment", str, object]): """ Switch Memento's default environment. @@ -605,11 +654,14 @@ def set(cls, config: ['Environment', str, object]): """ global environment - environment = config if isinstance(config, Environment) \ + environment = ( + config + if isinstance(config, Environment) else Environment(_load_config(os.getcwd(), config)) + ) @classmethod - def get(cls) -> 'Environment': + def get(cls) -> "Environment": """ Return Memento's current default environment. @@ -634,7 +686,9 @@ def _load_environment() -> Environment: memento_env = os.getenv("MEMENTO_ENV") if not memento_env: # Search for default environment file - default_config_file = Path("~").expanduser().joinpath(".memento", "env", "default", "env") + default_config_file = ( + Path("~").expanduser().joinpath(".memento", "env", "default", "env") + ) for ext in [".json", ".yaml"]: filename = default_config_file.with_name(default_config_file.name + ext) if filename.is_file(): diff --git a/twosigma/memento/context.py b/twosigma/memento/context.py index c0d888a..5b2a913 100644 --- a/twosigma/memento/context.py +++ b/twosigma/memento/context.py @@ -83,10 +83,13 @@ class RecursiveContext(ScopedContext): # Note: If additional attributes ar added, be sure to update # serialization.encode_recursive_context. - def __init__(self, correlation_id: str = None, - retry_on_remote_call: bool = False, - prevent_further_calls: bool = False, - context_args: Dict[str, Any] = None) -> None: + def __init__( + self, + correlation_id: str = None, + retry_on_remote_call: bool = False, + prevent_further_calls: bool = False, + context_args: Dict[str, Any] = None, + ) -> None: super().__init__() self.__dict__["correlation_id"] = correlation_id self.__dict__["retry_on_remote_call"] = retry_on_remote_call @@ -102,12 +105,12 @@ def __str__(self): def __repr__(self): return "RecursiveContext({})".format(self.__dict__) - def copy(self) -> 'RecursiveContext': + def copy(self) -> "RecursiveContext": result = RecursiveContext() result.__dict__.update(self.__dict__) return result - def update(self, key: str, value: Any) -> 'RecursiveContext': + def update(self, key: str, value: Any) -> "RecursiveContext": if key not in self.__dict__: raise ValueError("No such property {}".format(key)) result = self.copy() @@ -140,10 +143,12 @@ class LocalContext(ScopedContext): of this result. """ - def __init__(self, - ignore_result: bool = False, - force_local: bool = False, - monitor_progress: bool = False) -> None: + def __init__( + self, + ignore_result: bool = False, + force_local: bool = False, + monitor_progress: bool = False, + ) -> None: super().__init__() self.__dict__["ignore_result"] = ignore_result self.__dict__["force_local"] = force_local @@ -158,12 +163,12 @@ def __str__(self): def __repr__(self): return "LocalContext({})".format(self.__dict__) - def copy(self) -> 'LocalContext': + def copy(self) -> "LocalContext": result = LocalContext() result.__dict__.update(self.__dict__) return result - def update(self, key: str, value: Any) -> 'LocalContext': + def update(self, key: str, value: Any) -> "LocalContext": if key not in self.__dict__: raise ValueError("No such property {}".format(key)) result = self.copy() @@ -180,7 +185,9 @@ class InvocationContext: recursive = None # type: RecursiveContext local = None # type: LocalContext - def __init__(self, recursive: RecursiveContext = None, local: LocalContext = None) -> None: + def __init__( + self, recursive: RecursiveContext = None, local: LocalContext = None + ) -> None: self.__dict__["recursive"] = recursive or RecursiveContext() self.__dict__["local"] = local or LocalContext() @@ -190,14 +197,14 @@ def __setattr__(self, key, value): def __str__(self): return "InvocationContext({}, {})".format(self.recursive, self.local) - def update_local(self, key: str, value: Any) -> 'InvocationContext': + def update_local(self, key: str, value: Any) -> "InvocationContext": """ Create a copy of this context with the given property updated in the local part of the context. """ return InvocationContext(self.recursive, self.local.update(key, value)) - def update_recursive(self, key: str, value: Any) -> 'InvocationContext': + def update_recursive(self, key: str, value: Any) -> "InvocationContext": """ Create a copy of this context with the given property updated in the recursive part of the context. diff --git a/twosigma/memento/dependency_graph.py b/twosigma/memento/dependency_graph.py index 9088165..5a829f8 100644 --- a/twosigma/memento/dependency_graph.py +++ b/twosigma/memento/dependency_graph.py @@ -72,8 +72,12 @@ class DependencyGraph(DependencyGraphType): _all_rules = None # type: List[HashRule] _label_filter = None # type: Callable[[str], str] - def __init__(self, memento_fn: MementoFunctionType = False, verbose: bool = False, - label_filter: Callable[[str], str] = None): + def __init__( + self, + memento_fn: MementoFunctionType = False, + verbose: bool = False, + label_filter: Callable[[str], str] = None, + ): """ Creates a new dependency graph for the given Memento Function @@ -93,28 +97,40 @@ def __init__(self, memento_fn: MementoFunctionType = False, verbose: bool = Fals def with_verbose(self, verbose: bool) -> "DependencyGraphType": return DependencyGraph(self.memento_fn, verbose, self._label_filter) - def with_label_filter(self, label_filter: Callable[[str], str]) -> "DependencyGraphType": + def with_label_filter( + self, label_filter: Callable[[str], str] + ) -> "DependencyGraphType": return DependencyGraph(self.memento_fn, self._verbose, label_filter) def transitive_memento_fn_dependencies(self) -> Set[MementoFunctionType]: """The set of all transitive dependencies on other Memento functions""" # noinspection PyUnresolvedReferences return set( - rule.memento_fn for rule in self._all_rules + rule.memento_fn + for rule in self._all_rules if hasattr(rule, "memento_fn") and rule.memento_fn != self.memento_fn ) def direct_memento_fn_dependencies(self) -> Set[MementoFunctionType]: """The set of all direct dependencies on other Memento functions""" + def is_direct_dependency(rule): is_memento_fn = hasattr(rule, "memento_fn") and rule.memento_fn is not None - return is_memento_fn and rule.memento_fn != self.memento_fn and rule.first_level + return ( + is_memento_fn + and rule.memento_fn != self.memento_fn + and rule.first_level + ) # noinspection PyTypeChecker,PyUnresolvedReferences - return set(rule.memento_fn for rule in self._all_rules if is_direct_dependency(rule)) + return set( + rule.memento_fn for rule in self._all_rules if is_direct_dependency(rule) + ) @classmethod - def _rules_until_first_memento_fn(cls, memento_fn: MementoFunctionType) -> List[HashRule]: + def _rules_until_first_memento_fn( + cls, memento_fn: MementoFunctionType + ) -> List[HashRule]: """ Set of all hash rules up through the first MementoFunctions in the stack but not beyond. @@ -132,8 +148,10 @@ def _rules_until_first_memento_fn(cls, memento_fn: MementoFunctionType) -> List[ (node_type, _, node_name) = cls.parse_key(rule.key) # noinspection PyUnresolvedReferences if hasattr(rule, "memento_fn") and rule.memento_fn is not None: - if rule.memento_fn.qualified_name_without_version ==\ - memento_fn.qualified_name_without_version: + if ( + rule.memento_fn.qualified_name_without_version + == memento_fn.qualified_name_without_version + ): # Exclude current function from result continue else: @@ -142,7 +160,9 @@ def _rules_until_first_memento_fn(cls, memento_fn: MementoFunctionType) -> List[ # Note: Converges to O(n^2) in degenerate very flat graphs # This should be very small n in almost all realistic cases. if node_name == r.parent_symbol: - if r.key not in already_processed: # prevent infinite loop on cycle + if ( + r.key not in already_processed + ): # prevent infinite loop on cycle already_processed.add(r.key) more_rules.append(r) result.append(rule) @@ -150,18 +170,20 @@ def _rules_until_first_memento_fn(cls, memento_fn: MementoFunctionType) -> List[ return result @staticmethod - def generate_graphviz(graph: Dict[str, Node], qualified_name_without_version: str) -> Digraph: + def generate_graphviz( + graph: Dict[str, Node], qualified_name_without_version: str + ) -> Digraph: digraph = Digraph( format="svg", graph_attr={"rankdir": "BT"}, - node_attr={"shape": "box", "fontname": "Helvetica", "fontsize": "10"} + node_attr={"shape": "box", "fontname": "Helvetica", "fontsize": "10"}, ) type_to_attrs = { "MementoFunction": {"shape": "rectangle"}, "Function": {"shape": "rectangle", "style": "rounded"}, "GlobalVariable": {"shape": "ellipse"}, - "UndefinedSymbol": {"shape": "octagon"} + "UndefinedSymbol": {"shape": "octagon"}, } def hash_name(name: str) -> str: @@ -174,14 +196,21 @@ def hash_name(name: str) -> str: # Parse name and only show cluster and module if different from root node (cluster, module, fn_name) = _parse_name(node.id) label = "<" - label += html.escape(cluster) + "
" if cluster is not None and \ - cluster != root_cluster else "" - label += html.escape(module) + "
" if module is not None and \ - module != root_module else "" + label += ( + html.escape(cluster) + "
" + if cluster is not None and cluster != root_cluster + else "" + ) + label += ( + html.escape(module) + "
" + if module is not None and module != root_module + else "" + ) label += html.escape(node.label) label += ">" - attrs = type_to_attrs[node.node_type if node.node_type is - not None else "MementoFunction"] + attrs = type_to_attrs[ + node.node_type if node.node_type is not None else "MementoFunction" + ] digraph.node(hash_name(node.id), label=label, **attrs) for edge in node.edges: @@ -191,22 +220,24 @@ def hash_name(name: str) -> str: def graph(self) -> Digraph: graph = self._get_graph() - return DependencyGraph.generate_graphviz(graph, - self.memento_fn.qualified_name_without_version) + return DependencyGraph.generate_graphviz( + graph, self.memento_fn.qualified_name_without_version + ) @staticmethod def generate_df(graph: Dict[str, Node]) -> pd.DataFrame: rows = [] for node in graph.values(): for edge in node.edges: - rows.append({ - "src": node.id, - "target": edge, - "type": graph[edge].node_type - }) - - return pd.DataFrame(data=rows, columns=["src", "target", "type"]).\ - sort_values(by=["src", "type"]).reset_index(drop=True) + rows.append( + {"src": node.id, "target": edge, "type": graph[edge].node_type} + ) + + return ( + pd.DataFrame(data=rows, columns=["src", "target", "type"]) + .sort_values(by=["src", "type"]) + .reset_index(drop=True) + ) def df(self): """ @@ -233,22 +264,28 @@ def parse_key(key: str) -> Tuple[str, str, str]: semi_index = key.find(";") semi_index_2 = key.find(";", semi_index + 1) node_type = key[0:semi_index] - node_parent = key[semi_index + 1:semi_index_2] - node_name = key[semi_index_2 + 1:] + node_parent = key[semi_index + 1 : semi_index_2] + node_name = key[semi_index_2 + 1 :] return node_type, node_parent, node_name def _get_graph(self) -> Dict[str, Node]: if self._graph is None: self._graph = dict() - self.generate_graph(set(), self._graph, self.memento_fn, self._label_filter, - self._verbose) + self.generate_graph( + set(), self._graph, self.memento_fn, self._label_filter, self._verbose + ) return self._graph @classmethod def generate_graph( - cls, processed: Set[str], graph: Dict[str, Node], - root_fn: MementoFunctionType, label_filter: Callable[[str], str], verbose: bool): + cls, + processed: Set[str], + graph: Dict[str, Node], + root_fn: MementoFunctionType, + label_filter: Callable[[str], str], + verbose: bool, + ): """ Generates a tree of Nodes and returns the root node (which is always a MementoFunction node). @@ -300,8 +337,9 @@ def get_node(name: str, lbl_filter: Callable[[str], str]): parent_node.edges.add(node_name) # noinspection PyUnresolvedReferences if hasattr(rule, "memento_fn") and rule.memento_fn is not None: - cls.generate_graph(processed, graph, rule.memento_fn, label_filter, - verbose) + cls.generate_graph( + processed, graph, rule.memento_fn, label_filter, verbose + ) else: # Only add edges from root to MementoFunctions root_node = get_node(root_node_name, label_filter) @@ -313,4 +351,6 @@ def get_node(name: str, lbl_filter: Callable[[str], str]): node = get_node(node_name, label_filter) node.node_type = node_type root_node.edges.add(node_name) - cls.generate_graph(processed, graph, rule.memento_fn, label_filter, verbose) + cls.generate_graph( + processed, graph, rule.memento_fn, label_filter, verbose + ) diff --git a/twosigma/memento/exception.py b/twosigma/memento/exception.py index 66a511b..c593ad2 100644 --- a/twosigma/memento/exception.py +++ b/twosigma/memento/exception.py @@ -25,7 +25,7 @@ from twosigma.memento.reference import FunctionReferenceWithArgHash -_MEMENTO_EXCEPTION_REGEX = r'([^:]*)::([^:]*):?([^:]*)' +_MEMENTO_EXCEPTION_REGEX = r"([^:]*)::([^:]*):?([^:]*)" class MementoException(RuntimeError): @@ -49,8 +49,10 @@ class MementoException(RuntimeError): def __init__(self, exception_name: str, message: str, stack_trace: str): super().__init__( - "{}: {}. Original stack trace follows:\n{}".format(exception_name, message, - stack_trace)) + "{}: {}. Original stack trace follows:\n{}".format( + exception_name, message, stack_trace + ) + ) self.exception_name = exception_name self.message = message self.stack_trace = stack_trace @@ -80,8 +82,11 @@ def to_exception(self) -> Exception: return self try: # noinspection PyCallingNonCallable - return ref("{}. Original stack trace follows:\n{}".format(self.message, - self.stack_trace)) + return ref( + "{}. Original stack trace follows:\n{}".format( + self.message, self.stack_trace + ) + ) except TypeError: # If we couldn't construct the exception (e.g. it has required parameters), # just return this as a MementoException @@ -107,8 +112,11 @@ def from_exception(e: Exception) -> "MementoException": module = exc_class.__module__ qual_name = exc_class.__qualname__ full_qual_name = "{}::{}:{}".format(language, module, qual_name) - return MementoException(full_qual_name, str(e), "".join( - traceback.format_exception(type(e), e, e.__traceback__))) + return MementoException( + full_qual_name, + str(e), + "".join(traceback.format_exception(type(e), e, e.__traceback__)), + ) class NonMemoizedException(RuntimeError): @@ -126,6 +134,7 @@ class MementoNotFoundError(NonMemoizedException): but couldn't associate it with a Memento) """ + def __init__(self, message: str): super().__init__(message) diff --git a/twosigma/memento/external.py b/twosigma/memento/external.py index ea2f914..6f09369 100644 --- a/twosigma/memento/external.py +++ b/twosigma/memento/external.py @@ -62,8 +62,12 @@ def cluster_name(self): _hash_rules = None # type: List[HashRule] def __init__( - self, fn_reference: FunctionReference, context: InvocationContext, - function_type: str, hash_rules: List[HashRule]): + self, + fn_reference: FunctionReference, + context: InvocationContext, + function_type: str, + hash_rules: List[HashRule], + ): """ Creates a new ExternalMementoFunction for the given function reference. @@ -78,25 +82,44 @@ def __init__( parts = FunctionReference.parse_qualified_name(fn_reference.qualified_name) self._version = parts["version"] self.context = context - self.qualified_name_without_version = self._fn_reference.qualified_name_without_version + self.qualified_name_without_version = ( + self._fn_reference.qualified_name_without_version + ) self.code_hash = None self.function_type = function_type self._hash_rules = hash_rules def _clone_fn_ref( - self, fn: Callable = None, src_fn: Callable = None, cluster_name: str = None, - version: str = None, calculated_version: str = None, - partial_args: Tuple[Any] = None, partial_kwargs: Dict[str, Any] = None, - auto_dependencies: bool = True, - dependencies: List[Union[str, "MementoFunctionType"]] = None, - version_code_hash: str = None, version_salt: str = None) -> FunctionReference: - assert fn is None, "External function may not refer to a function in the local process" - assert src_fn is None, "External function may not refer to a source function in the " \ - "local process" - assert calculated_version is None, "External functions always have fixed versions" - assert auto_dependencies, "Cannot disable auto_dependencies for external functions" + self, + fn: Callable = None, + src_fn: Callable = None, + cluster_name: str = None, + version: str = None, + calculated_version: str = None, + partial_args: Tuple[Any] = None, + partial_kwargs: Dict[str, Any] = None, + auto_dependencies: bool = True, + dependencies: List[Union[str, "MementoFunctionType"]] = None, + version_code_hash: str = None, + version_salt: str = None, + ) -> FunctionReference: + assert ( + fn is None + ), "External function may not refer to a function in the local process" + assert src_fn is None, ( + "External function may not refer to a source function in the " + "local process" + ) + assert ( + calculated_version is None + ), "External functions always have fixed versions" + assert ( + auto_dependencies + ), "Cannot disable auto_dependencies for external functions" assert dependencies is None, "Cannot set dependencies for external functions" - assert version_code_hash is None, "Cannot set version code hash for external functions" + assert ( + version_code_hash is None + ), "Cannot set version code hash for external functions" assert version_salt is None, "Cannot set version_salt for external functions" return FunctionReference( memento_fn=self, @@ -107,7 +130,7 @@ def _clone_fn_ref( module_name=self._fn_reference.module, function_name=self._fn_reference.function_name, parameter_names=self._fn_reference.parameter_names, - external=True + external=True, ) def hash_rules(self) -> List[HashRule]: @@ -120,7 +143,8 @@ def fn_reference(self): return self._fn_reference def dependencies( - self, verbose=False, label_filter: Callable[[str], str] = None) -> DependencyGraphType: + self, verbose=False, label_filter: Callable[[str], str] = None + ) -> DependencyGraphType: return DependencyGraph(self, verbose=verbose, label_filter=label_filter) def _filter_call(self, *args, **kwargs) -> Any: @@ -149,7 +173,9 @@ def __repr__(self) -> str: return "ExternalMementoFunction({})".format(repr(self.fn_reference())) @classmethod - def get_registered_function_type_classes(cls) -> Dict[str, type(MementoFunctionBase)]: + def get_registered_function_type_classes( + cls, + ) -> Dict[str, type(MementoFunctionBase)]: return _registered_function_type_classes @@ -157,24 +183,32 @@ class UnboundExternalMementoFunction(ExternalMementoFunctionBase): """ ExternalMementoFunction which is not bound to a particular server endpoint. """ + def __init__( - self, - context: Optional[InvocationContext] = None, - cluster_name: Optional[str] = None, module_name: Optional[str] = None, - function_name: Optional[str] = None, version: Optional[str] = None, - partial_args: Optional[Tuple[Any]] = None, - partial_kwargs: Optional[Dict[str, Any]] = None, - parameter_names: Optional[List[str]] = None, - fn_reference: Optional[FunctionReference] = None): + self, + context: Optional[InvocationContext] = None, + cluster_name: Optional[str] = None, + module_name: Optional[str] = None, + function_name: Optional[str] = None, + version: Optional[str] = None, + partial_args: Optional[Tuple[Any]] = None, + partial_kwargs: Optional[Dict[str, Any]] = None, + parameter_names: Optional[List[str]] = None, + fn_reference: Optional[FunctionReference] = None, + ): assert fn_reference or cluster_name is not None, "Cluster name is required" if fn_reference is None: fn_reference = FunctionReference( memento_fn=self, - cluster_name=cluster_name, module_name=module_name, - function_name=function_name, version=version, - partial_args=partial_args, partial_kwargs=partial_kwargs, - parameter_names=parameter_names, external=True + cluster_name=cluster_name, + module_name=module_name, + function_name=function_name, + version=version, + partial_args=partial_args, + partial_kwargs=partial_kwargs, + parameter_names=parameter_names, + external=True, ) if context is None: @@ -183,21 +217,36 @@ def __init__( super().__init__(fn_reference, context, "unbound", hash_rules=list()) def clone_with( - self, fn: Callable = None, src_fn: Callable = None, cluster_name: str = None, - version: str = None, calculated_version: str = None, context: InvocationContext = None, - partial_args: Tuple[Any] = None, partial_kwargs: Dict[str, Any] = None, - auto_dependencies: bool = True, - dependencies: List[Union[str, MementoFunctionType]] = None, - version_code_hash: str = None, version_salt: str = None) -> MementoFunctionType: + self, + fn: Callable = None, + src_fn: Callable = None, + cluster_name: str = None, + version: str = None, + calculated_version: str = None, + context: InvocationContext = None, + partial_args: Tuple[Any] = None, + partial_kwargs: Dict[str, Any] = None, + auto_dependencies: bool = True, + dependencies: List[Union[str, MementoFunctionType]] = None, + version_code_hash: str = None, + version_salt: str = None, + ) -> MementoFunctionType: fn_ref = self._clone_fn_ref( - fn=fn, src_fn=src_fn, cluster_name=cluster_name, version=version, - calculated_version=calculated_version, partial_args=partial_args, - partial_kwargs=partial_kwargs, auto_dependencies=auto_dependencies, - dependencies=dependencies, version_code_hash=version_code_hash, - version_salt=version_salt + fn=fn, + src_fn=src_fn, + cluster_name=cluster_name, + version=version, + calculated_version=calculated_version, + partial_args=partial_args, + partial_kwargs=partial_kwargs, + auto_dependencies=auto_dependencies, + dependencies=dependencies, + version_code_hash=version_code_hash, + version_salt=version_salt, ) return UnboundExternalMementoFunction( - context=context or self.context, fn_reference=fn_ref) + context=context or self.context, fn_reference=fn_ref + ) ExternalMementoFunctionBase.register("unbound", UnboundExternalMementoFunction) diff --git a/twosigma/memento/memento.py b/twosigma/memento/memento.py index de9f556..7b2b8f0 100644 --- a/twosigma/memento/memento.py +++ b/twosigma/memento/memento.py @@ -29,13 +29,19 @@ from .logging import log from .types import MementoFunctionType, DependencyGraphType from .reference import FunctionReference -from .code_hash import fn_code_hash, resolve_to_symbolic_names, \ - HashRule, MementoFunctionHashRule, list_dotted_names +from .code_hash import ( + fn_code_hash, + resolve_to_symbolic_names, + HashRule, + MementoFunctionHashRule, + list_dotted_names, +) from .metadata import ResultType -_MementoFunctionVersionCacheEntry =\ - namedtuple("_MementoFunctionVersionCacheEntry", ["as_of_generation", "version"]) +_MementoFunctionVersionCacheEntry = namedtuple( + "_MementoFunctionVersionCacheEntry", ["as_of_generation", "version"] +) class MementoFunction(MementoFunctionBase): @@ -81,7 +87,9 @@ class MementoFunction(MementoFunctionBase): forces all version hashes to be recomputed. """ - _global_fn_version_cache = dict() # type: Dict[str, _MementoFunctionVersionCacheEntry] + _global_fn_version_cache = ( + dict() + ) # type: Dict[str, _MementoFunctionVersionCacheEntry] """ Cache that maps from function name to a version cache entry that contains the generation number as of when this was current and the function version number. @@ -123,7 +131,9 @@ class MementoFunction(MementoFunctionBase): auto_dependencies = None # type: bool "If True, dependencies will be searched for automatically" - _constructor_provided_dependencies = None # type: List[Union[str, MementoFunctionType]] + _constructor_provided_dependencies = ( + None + ) # type: List[Union[str, MementoFunctionType]] "The explicit list of dependencies provided by the user" _constructor_provided_version_code_hash = None # type: str @@ -162,7 +172,9 @@ def supports_kwargs(self) -> bool: """True if function supports kwargs, False otherwise""" fn = self.fn # noinspection PyUnresolvedReferences - return hasattr(fn, "__code__") and (fn.__code__.co_flags & inspect.CO_VARKEYWORDS != 0) + return hasattr(fn, "__code__") and ( + fn.__code__.co_flags & inspect.CO_VARKEYWORDS != 0 + ) def get_args(self) -> List[Dict[str, str]]: """ @@ -177,25 +189,27 @@ def get_args(self) -> List[Dict[str, str]]: param = params[param_name] arg = { "name": param.name, - "argumentType": ResultType.from_annotation(param.annotation).name + "argumentType": ResultType.from_annotation(param.annotation).name, } args.append(arg) return args def __init__( - self, fn: Callable, - src_fn: Callable = None, - cluster_name: str = None, - version: str = None, - calculated_version: str = None, - context: InvocationContext = None, - partial_args: Tuple[Any] = None, - partial_kwargs: Dict[str, Any] = None, - auto_dependencies: bool = True, - dependencies: List[Union[str, MementoFunctionType]] = None, - version_code_hash: str = None, - version_salt: str = None, - register_fn: bool = True): + self, + fn: Callable, + src_fn: Callable = None, + cluster_name: str = None, + version: str = None, + calculated_version: str = None, + context: InvocationContext = None, + partial_args: Tuple[Any] = None, + partial_kwargs: Dict[str, Any] = None, + auto_dependencies: bool = True, + dependencies: List[Union[str, MementoFunctionType]] = None, + version_code_hash: str = None, + version_salt: str = None, + register_fn: bool = True, + ): """ Creates a new MementoFunction that wraps the provided `fn`. @@ -242,8 +256,9 @@ def __init__( """ assert inspect.isfunction(fn), "fn {} is not a function".format(fn) - assert not isinstance(fn, MementoFunctionType), \ - "Cannot create a MementoFunction that wraps another MementoFunction" + assert not isinstance( + fn, MementoFunctionType + ), "Cannot create a MementoFunction that wraps another MementoFunction" self._hash_rules = [] # type: List[HashRule] self.fn = fn @@ -258,7 +273,9 @@ def __init__( elif version_code_hash is not None: code_hash = version_code_hash else: - code_hash = fn_code_hash(fn, salt=version_salt, environment=ENVIRONMENT_HASH_BYTES) + code_hash = fn_code_hash( + fn, salt=version_salt, environment=ENVIRONMENT_HASH_BYTES + ) self.code_hash = code_hash self.context = context or InvocationContext() @@ -274,12 +291,15 @@ def __init__( # Resolve required dependencies to symbolic names so evaluation can be deferred. # This allows re-binding of functions later. self._constructor_provided_dependencies = dependencies - self.required_dependencies = resolve_to_symbolic_names(dependencies) \ - if dependencies else set() - assert self.required_dependencies is None or \ - all(isinstance(dep, str) for dep in self.required_dependencies),\ - "Could not resolve all functions in dependencies to symbolic names" - self.detected_dependencies = list_dotted_names(self.src_fn) if auto_dependencies else set() + self.required_dependencies = ( + resolve_to_symbolic_names(dependencies) if dependencies else set() + ) + assert self.required_dependencies is None or all( + isinstance(dep, str) for dep in self.required_dependencies + ), "Could not resolve all functions in dependencies to symbolic names" + self.detected_dependencies = ( + list_dotted_names(self.src_fn) if auto_dependencies else set() + ) self.partial_args = partial_args self.partial_kwargs = partial_kwargs @@ -287,8 +307,10 @@ def __init__( self._fn_reference = None if "" in self.qualified_name_without_version: - raise ValueError("Memento functions must be top-level functions, " - "not local to another function.") + raise ValueError( + "Memento functions must be top-level functions, " + "not local to another function." + ) functools.update_wrapper(self, fn) @@ -296,23 +318,27 @@ def __init__( # Increase generation number so other functions can update their version number # if necessary MementoFunction.increment_global_fn_generation( - reason="registered new function {}".format(self.qualified_name_without_version)) + reason="registered new function {}".format( + self.qualified_name_without_version + ) + ) Environment.register_function(cluster_name, self) def clone_with( - self, - fn: Callable = None, - src_fn: Callable = None, - cluster_name: str = None, - version: str = None, - calculated_version: str = None, - context: InvocationContext = None, - partial_args: Tuple[Any] = None, - partial_kwargs: Dict[str, Any] = None, - auto_dependencies: bool = True, - dependencies: List[Union[str, MementoFunctionType]] = None, - version_code_hash: str = None, - version_salt: str = None) -> MementoFunctionType: + self, + fn: Callable = None, + src_fn: Callable = None, + cluster_name: str = None, + version: str = None, + calculated_version: str = None, + context: InvocationContext = None, + partial_args: Tuple[Any] = None, + partial_kwargs: Dict[str, Any] = None, + auto_dependencies: bool = True, + dependencies: List[Union[str, MementoFunctionType]] = None, + version_code_hash: str = None, + version_salt: str = None, + ) -> MementoFunctionType: """Re-constructs a clone of this function, modifying one or more attributes""" return MementoFunction( fn=fn or self.fn, @@ -325,22 +351,27 @@ def clone_with( partial_kwargs=partial_kwargs or self.partial_kwargs, auto_dependencies=auto_dependencies or self.auto_dependencies, dependencies=dependencies or self._constructor_provided_dependencies, - version_code_hash=version_code_hash or self._constructor_provided_version_code_hash, + version_code_hash=version_code_hash + or self._constructor_provided_version_code_hash, version_salt=version_salt or self._constructor_provided_version_salt, - register_fn=False) + register_fn=False, + ) def call(self, *args, **kwargs): self._validate_dependency() return super(MementoFunction, self).call(*args, **kwargs) - def call_batch(self, kwargs_list: List[Dict[str, Any]], - raise_first_exception=True) -> List[Any]: + def call_batch( + self, kwargs_list: List[Dict[str, Any]], raise_first_exception=True + ) -> List[Any]: self._validate_dependency() return super(MementoFunction, self).call_batch( - kwargs_list=kwargs_list, raise_first_exception=raise_first_exception) + kwargs_list=kwargs_list, raise_first_exception=raise_first_exception + ) - def dependencies(self, verbose=False, - label_filter: Callable[[str], str] = None) -> DependencyGraphType: + def dependencies( + self, verbose=False, label_filter: Callable[[str], str] = None + ) -> DependencyGraphType: """ Return an object that allows the caller to explore the dependencies of this function on other memento functions, plain functions and global variables. When invoked from @@ -368,11 +399,13 @@ def _filter_call(self, *args, **kwargs) -> Any: def _update_fn_reference(self): """Update the _fn_reference attribute based on the latest computed version""" - self._fn_reference = FunctionReference(self, - cluster_name=self.cluster_name, - version=self.version(), - partial_args=self.partial_args, - partial_kwargs=self.partial_kwargs) + self._fn_reference = FunctionReference( + self, + cluster_name=self.cluster_name, + version=self.version(), + partial_args=self.partial_args, + partial_kwargs=self.partial_kwargs, + ) def _update_dependencies(self): """Assemble dependencies and update the version and fn_reference""" @@ -391,8 +424,13 @@ def _update_dependencies(self): # Check the version cache to see if we need to recompute the version entry = None # type: Optional[_MementoFunctionVersionCacheEntry] - if self.qualified_name_without_version in MementoFunction._global_fn_version_cache: - entry = MementoFunction._global_fn_version_cache[self.qualified_name_without_version] + if ( + self.qualified_name_without_version + in MementoFunction._global_fn_version_cache + ): + entry = MementoFunction._global_fn_version_cache[ + self.qualified_name_without_version + ] if entry.as_of_generation == MementoFunction._global_fn_generation: changed_rules = [rule for rule in self._hash_rules if rule.did_change()] if len(changed_rules) > 0: @@ -402,7 +440,8 @@ def _update_dependencies(self): MementoFunction.increment_global_fn_generation( reason="function {} hash rules changed: {}".format( self.qualified_name_without_version, - [rule.describe() for rule in changed_rules]) + [rule.describe() for rule in changed_rules], + ) ) else: if self._calculated_version is None: @@ -419,15 +458,20 @@ def _update_dependencies(self): if entry is None or (entry is not None and entry.version != version): # Notify the user about the new version - log.debug("At generation {}, calculated version for fn {} as {}.".format( - MementoFunction._global_fn_generation, - self.qualified_name_without_version, version)) + log.debug( + "At generation {}, calculated version for fn {} as {}.".format( + MementoFunction._global_fn_generation, + self.qualified_name_without_version, + version, + ) + ) # Update the cache entry MementoFunction._global_fn_version_cache[ - self.qualified_name_without_version] = _MementoFunctionVersionCacheEntry( - as_of_generation=MementoFunction._global_fn_generation, - version=version) + self.qualified_name_without_version + ] = _MementoFunctionVersionCacheEntry( + as_of_generation=MementoFunction._global_fn_generation, version=version + ) def _recompute_version(self): """Collect dependencies and [re]compute the version of this function""" @@ -436,13 +480,19 @@ def _recompute_version(self): hash_rules = set() # type: Set[HashRule] # Collect dependencies - self_rule = MementoFunctionHashRule(parent_symbol=None, - symbol=self.qualified_name_without_version, - resolver=lambda: self, obj=self, first_level=True) - self_rule.collect_transitive_dependencies(result=hash_rules, root_fn=self, - package_scope={inspect.getmodule( - self.src_fn).__package__}, - blacklist=[memento_function]) + self_rule = MementoFunctionHashRule( + parent_symbol=None, + symbol=self.qualified_name_without_version, + resolver=lambda: self, + obj=self, + first_level=True, + ) + self_rule.collect_transitive_dependencies( + result=hash_rules, + root_fn=self, + package_scope={inspect.getmodule(self.src_fn).__package__}, + blacklist=[memento_function], + ) # Order hash rules ordered_hash_rules = sorted(hash_rules) @@ -460,8 +510,9 @@ def _recompute_version(self): return version @staticmethod - def _extract_fn_ref_args(caller_args: Tuple, caller_kwargs: Dict, - caller_context_args: Dict) -> Set[str]: + def _extract_fn_ref_args( + caller_args: Tuple, caller_kwargs: Dict, caller_context_args: Dict + ) -> Set[str]: """ Assemble a list of arguments that are FunctionReferences, as these are valid to call. Note these can be nested in data structures like Lists or Dicts. @@ -514,7 +565,9 @@ def _validate_dependency(self): # Top of stack, so no caller. Any call is allowed. return - caller_ref = frame.memento.invocation_metadata.fn_reference_with_args.fn_reference + caller_ref = ( + frame.memento.invocation_metadata.fn_reference_with_args.fn_reference + ) caller = cast(MementoFunctionType, caller_ref.memento_fn) if caller.explicit_version is not None: # Caller has declared version explicitly, so there is no need to worry that @@ -523,23 +576,33 @@ def _validate_dependency(self): caller_args = frame.memento.invocation_metadata.fn_reference_with_args.args caller_kwargs = frame.memento.invocation_metadata.fn_reference_with_args.kwargs - caller_context_args = frame.memento.invocation_metadata.fn_reference_with_args.context_args - fn_ref_args = self._extract_fn_ref_args(caller_args, caller_kwargs, caller_context_args) + caller_context_args = ( + frame.memento.invocation_metadata.fn_reference_with_args.context_args + ) + fn_ref_args = self._extract_fn_ref_args( + caller_args, caller_kwargs, caller_context_args + ) # Any dependency is a valid function - valid_fns = {fn.fn_reference().qualified_name - for fn in caller.dependencies().transitive_memento_fn_dependencies()} + valid_fns = { + fn.fn_reference().qualified_name + for fn in caller.dependencies().transitive_memento_fn_dependencies() + } # Any argument that is a function reference is also valid valid_fns |= fn_ref_args - if caller.qualified_name_without_version != self.qualified_name_without_version and \ - self.fn_reference().qualified_name not in valid_fns: + if ( + caller.qualified_name_without_version != self.qualified_name_without_version + and self.fn_reference().qualified_name not in valid_fns + ): raise UndeclaredDependencyError( "{target} is not declared or detected to be a dependency of {src}. " - "Solution: Add @memento_function(dependencies=[{target}]) to {src}.". - format(src=caller.qualified_name_without_version, - target=self.qualified_name_without_version)) + "Solution: Add @memento_function(dependencies=[{target}]) to {src}.".format( + src=caller.qualified_name_without_version, + target=self.qualified_name_without_version, + ) + ) @classmethod def increment_global_fn_generation(cls, reason=None): @@ -550,8 +613,11 @@ def increment_global_fn_generation(cls, reason=None): """ cls._global_fn_generation += 1 - log.debug("New global generation{}: {}".format( - " ({})".format(reason) if reason else "", cls._global_fn_generation)) + log.debug( + "New global generation{}: {}".format( + " ({})".format(reason) if reason else "", cls._global_fn_generation + ) + ) def __str__(self) -> str: return self.__repr__() @@ -560,12 +626,15 @@ def __repr__(self) -> str: return "{}.memento_fn".format(repr(self.fn_reference())) -def memento_function(*plain_fn, cluster: str = None, version: Any = None, - auto_dependencies: bool = True, - dependencies: List[Union[Callable, MementoFunctionType]] = None, - version_code_hash: str = None, - version_salt: str = None) -> \ - Union[MementoFunctionType, Callable[..., MementoFunctionType]]: +def memento_function( + *plain_fn, + cluster: str = None, + version: Any = None, + auto_dependencies: bool = True, + dependencies: List[Union[Callable, MementoFunctionType]] = None, + version_code_hash: str = None, + version_salt: str = None, +) -> Union[MementoFunctionType, Callable[..., MementoFunctionType]]: """ Decorator that causes a function to be treated as a memento function. If it is called with the same parameters in the future, the result will be memoized and not @@ -614,11 +683,15 @@ def c(): # This logic is to support both @memento_function and @memento_function() consistently: def decorator(fn) -> MementoFunction: - return MementoFunction(fn=fn, cluster_name=cluster, version=version, - auto_dependencies=auto_dependencies, - dependencies=dependencies, - version_code_hash=version_code_hash, - version_salt=version_salt) + return MementoFunction( + fn=fn, + cluster_name=cluster, + version=version, + auto_dependencies=auto_dependencies, + dependencies=dependencies, + version_code_hash=version_code_hash, + version_salt=version_salt, + ) if cluster is None and len(plain_fn) == 1: # Decorator Invoked without arguments @@ -628,28 +701,47 @@ def decorator(fn) -> MementoFunction: class ExternalMementoFunction(ExternalMementoFunctionBase): - def __init__(self, fn_reference: FunctionReference, context: InvocationContext, - hash_rules: List[HashRule]): + def __init__( + self, + fn_reference: FunctionReference, + context: InvocationContext, + hash_rules: List[HashRule], + ): super().__init__(fn_reference, context, "memento_function", hash_rules) def clone_with( - self, fn: Callable = None, src_fn: Callable = None, cluster_name: str = None, - version: str = None, calculated_version: str = None, context: InvocationContext = None, - partial_args: Tuple[Any] = None, partial_kwargs: Dict[str, Any] = None, - auto_dependencies: bool = True, - dependencies: List[Union[str, MementoFunctionType]] = None, - version_code_hash: str = None, version_salt: str = None) -> MementoFunctionType: + self, + fn: Callable = None, + src_fn: Callable = None, + cluster_name: str = None, + version: str = None, + calculated_version: str = None, + context: InvocationContext = None, + partial_args: Tuple[Any] = None, + partial_kwargs: Dict[str, Any] = None, + auto_dependencies: bool = True, + dependencies: List[Union[str, MementoFunctionType]] = None, + version_code_hash: str = None, + version_salt: str = None, + ) -> MementoFunctionType: fn_ref = self._clone_fn_ref( - fn=fn, src_fn=src_fn, cluster_name=cluster_name, version=version, - calculated_version=calculated_version, partial_args=partial_args, - partial_kwargs=partial_kwargs, auto_dependencies=auto_dependencies, - dependencies=dependencies, version_code_hash=version_code_hash, - version_salt=version_salt + fn=fn, + src_fn=src_fn, + cluster_name=cluster_name, + version=version, + calculated_version=calculated_version, + partial_args=partial_args, + partial_kwargs=partial_kwargs, + auto_dependencies=auto_dependencies, + dependencies=dependencies, + version_code_hash=version_code_hash, + version_salt=version_salt, ) return ExternalMementoFunction( fn_reference=fn_ref, context=context or self.context, - hash_rules=self._hash_rules) + hash_rules=self._hash_rules, + ) ExternalMementoFunctionBase.register("memento_function", ExternalMementoFunction) @@ -667,7 +759,11 @@ def forget_cluster(cluster_name: str = None): if cluster_config is None: raise ValueError("Cluster with name '{}' not found".format(cluster_name)) storage_backend = cluster_config.storage - log.info("Forgetting all functions for all arg hashes in cluster_name {}".format(cluster_name)) + log.info( + "Forgetting all functions for all arg hashes in cluster_name {}".format( + cluster_name + ) + ) storage_backend.forget_everything() diff --git a/twosigma/memento/metadata.py b/twosigma/memento/metadata.py index 8facf95..3ca55fe 100644 --- a/twosigma/memento/metadata.py +++ b/twosigma/memento/metadata.py @@ -27,33 +27,36 @@ from .resource import ResourceHandle from .exception import MementoException from .partition import Partition -from .reference import FunctionReferenceWithArguments, FunctionReference, \ - FunctionReferenceWithArgHash +from .reference import ( + FunctionReferenceWithArguments, + FunctionReference, + FunctionReferenceWithArgHash, +) from .types import VersionedDataSourceKey, MementoFunctionType class ResultType(Enum): - exception = 0, - null = 1, - boolean = 2, - string = 3, - binary = 4, - number = 5, - date = 6, - timestamp = 7, - list_result = 8, # list is a reserved word - dictionary = 9, - array_boolean = 10, - array_int8 = 11, - array_int16 = 12, - array_int32 = 13, - array_int64 = 14, - array_float32 = 15, - array_float64 = 16, - index = 17, - series = 18, - data_frame = 19, - partition = 20, + exception = (0,) + null = (1,) + boolean = (2,) + string = (3,) + binary = (4,) + number = (5,) + date = (6,) + timestamp = (7,) + list_result = (8,) # list is a reserved word + dictionary = (9,) + array_boolean = (10,) + array_int8 = (11,) + array_int16 = (12,) + array_int32 = (13,) + array_int64 = (14,) + array_float32 = (15,) + array_float64 = (16,) + index = (17,) + series = (18,) + data_frame = (19,) + partition = (20,) memento_function = 21 # not a valid return type, but valid argument type @staticmethod @@ -72,7 +75,9 @@ def from_object(obj) -> "ResultType": return ResultType.number if isinstance(obj, complex): raise ValueError("Memento cannot [de]serialize a complex") - if isinstance(obj, datetime.datetime): # pd.Timestamp also extends datetime.datetime + if isinstance( + obj, datetime.datetime + ): # pd.Timestamp also extends datetime.datetime return ResultType.timestamp if isinstance(obj, datetime.date): return ResultType.date @@ -87,21 +92,23 @@ def from_object(obj) -> "ResultType": if isinstance(obj, pd.DataFrame): return ResultType.data_frame if isinstance(obj, np.ndarray): - if obj.dtype == 'bool': + if obj.dtype == "bool": return ResultType.array_boolean - if obj.dtype == 'int8': + if obj.dtype == "int8": return ResultType.array_int8 - if obj.dtype == 'int16': + if obj.dtype == "int16": return ResultType.array_int16 - if obj.dtype == 'int32': + if obj.dtype == "int32": return ResultType.array_int32 - if obj.dtype == 'int64': + if obj.dtype == "int64": return ResultType.array_int64 - if obj.dtype == 'float32': + if obj.dtype == "float32": return ResultType.array_float32 - if obj.dtype == 'float64': + if obj.dtype == "float64": return ResultType.array_float64 - raise ValueError("Memento cannot [de]serialize a ndarray of type {}".format(obj.dtype)) + raise ValueError( + "Memento cannot [de]serialize a ndarray of type {}".format(obj.dtype) + ) if isinstance(obj, Partition): return ResultType.partition raise ValueError("Memento cannot [de]serialize a {}".format(type(obj))) @@ -180,18 +187,22 @@ class InvocationMetadata: result_type = None # type: Optional[ResultType] def __sizeof__(self): - return sys.getsizeof(self.fn_reference_with_args) + \ - sum([sys.getsizeof(x) for x in self.invocations]) + \ - sum([sys.getsizeof(x) for x in self.resources]) + \ - sys.getsizeof(self.runtime) + \ - sys.getsizeof(self.result_type) - - def __init__(self, - fn_reference_with_args: FunctionReferenceWithArguments, - invocations: List[FunctionReferenceWithArguments], - resources: List[ResourceHandle], - runtime: Optional[datetime.timedelta], - result_type: Optional[ResultType]): + return ( + sys.getsizeof(self.fn_reference_with_args) + + sum([sys.getsizeof(x) for x in self.invocations]) + + sum([sys.getsizeof(x) for x in self.resources]) + + sys.getsizeof(self.runtime) + + sys.getsizeof(self.result_type) + ) + + def __init__( + self, + fn_reference_with_args: FunctionReferenceWithArguments, + invocations: List[FunctionReferenceWithArguments], + resources: List[ResourceHandle], + runtime: Optional[datetime.timedelta], + result_type: Optional[ResultType], + ): self.fn_reference_with_args = fn_reference_with_args self.invocations = invocations self.resources = resources @@ -199,9 +210,15 @@ def __init__(self, self.result_type = result_type def __repr__(self): - return "InvocationMetadata(fn_reference_with_args={}, invocations={}, runtime={}, " \ - "result_type={})".format(repr(self.fn_reference_with_args), repr(self.invocations), - repr(self.runtime), repr(self.result_type)) + return ( + "InvocationMetadata(fn_reference_with_args={}, invocations={}, runtime={}, " + "result_type={})".format( + repr(self.fn_reference_with_args), + repr(self.invocations), + repr(self.runtime), + repr(self.result_type), + ) + ) def __str__(self): return self.__repr__() @@ -231,13 +248,15 @@ class Memento: content_key = None # type: Optional[VersionedDataSourceKey] """If specified, this key overrides the default key for a Memento result""" - def __init__(self, - time: datetime.datetime, - invocation_metadata: InvocationMetadata, - function_dependencies: MutableSet[FunctionReference], - runner: Dict[str, object], - correlation_id: str, - content_key: Optional[VersionedDataSourceKey]): + def __init__( + self, + time: datetime.datetime, + invocation_metadata: InvocationMetadata, + function_dependencies: MutableSet[FunctionReference], + runner: Dict[str, object], + correlation_id: str, + content_key: Optional[VersionedDataSourceKey], + ): """ Creates a new Memento, tracking the metadata for the function invocation. @@ -266,6 +285,7 @@ def forget(self): """ from .configuration import Environment + env = Environment.get() fn_reference_with_args = self.invocation_metadata.fn_reference_with_args @@ -275,7 +295,11 @@ def forget(self): # Forget only the results for the provided parameters arg_hash = fn_reference_with_args.fn_reference_with_arg_hash() - log.info("Forgetting {} for arg hash {}".format(fn_reference.qualified_name, arg_hash)) + log.info( + "Forgetting {} for arg hash {}".format( + fn_reference.qualified_name, arg_hash + ) + ) storage_backend.forget_call(arg_hash) def forget_exceptions_recursively(self, dry_run=False): @@ -298,6 +322,7 @@ def forget_exceptions_recursively(self, dry_run=False): """ from .configuration import Environment + env = Environment.get() warned = set() # type: Set[str] @@ -310,12 +335,15 @@ def get_cluster(fn_reference: FunctionReference): if result is None: if cluster_name not in warned: warned.add(cluster_name) - log.warning(f"Cannot find cluster {cluster_name} in default environment") + log.warning( + f"Cannot find cluster {cluster_name} in default environment" + ) return result @lru_cache(maxsize=10240) def get_memento( - cluster_storage, fn_with_arg_hash: FunctionReferenceWithArgHash) -> Memento: + cluster_storage, fn_with_arg_hash: FunctionReferenceWithArgHash + ) -> Memento: return cluster_storage.get_memento(fn_with_arg_hash) def add_result(invocation: InvocationMetadata): @@ -327,7 +355,9 @@ def add_result(invocation: InvocationMetadata): for inv in invocation.invocations: fn_cluster = get_cluster(inv.fn_reference) if fn_cluster is not None: - memento = get_memento(fn_cluster.storage, inv.fn_reference_with_arg_hash()) + memento = get_memento( + fn_cluster.storage, inv.fn_reference_with_arg_hash() + ) if memento is not None: add_result(memento.invocation_metadata) @@ -337,7 +367,9 @@ def add_result(invocation: InvocationMetadata): if not dry_run: cluster = get_cluster(fn_with_args.fn_reference) if cluster is not None: - cluster.storage.forget_call(fn_with_args.fn_reference_with_arg_hash()) + cluster.storage.forget_call( + fn_with_args.fn_reference_with_arg_hash() + ) log.warning(f"{'Would forget' if dry_run else 'Forgot'} {fn_with_args}") def trace(self, max_depth=None, only_exceptions=False) -> str: @@ -350,6 +382,7 @@ def trace(self, max_depth=None, only_exceptions=False) -> str: """ from .configuration import Environment + env = Environment.get() warned = set() # type: Set[str] @@ -361,11 +394,15 @@ def label(m: InvocationMetadata) -> str: """ return "{}({}) [{}] -> {}".format( m.fn_reference_with_args.fn_reference.qualified_name, - _label_fn_ref_args(m.fn_reference_with_args), m.runtime, str(m.result_type)) + _label_fn_ref_args(m.fn_reference_with_args), + m.runtime, + str(m.result_type), + ) @lru_cache(maxsize=10240) def get_memento( - cluster_storage, fn_with_arg_hash: FunctionReferenceWithArgHash) -> Memento: + cluster_storage, fn_with_arg_hash: FunctionReferenceWithArgHash + ) -> Memento: return cluster_storage.get_memento(fn_with_arg_hash) def add_row(rows: List[str], indent: int, invocation: InvocationMetadata): @@ -377,16 +414,29 @@ def add_row(rows: List[str], indent: int, invocation: InvocationMetadata): if cluster is None: if cluster_name not in warned: warned.add(cluster_name) - log.warning(f"Cannot find cluster {cluster_name} in default environment") + log.warning( + f"Cannot find cluster {cluster_name} in default environment" + ) elif max_depth is None or indent < (max_depth - 1): - memento = get_memento(cluster.storage, invocation.fn_reference_with_arg_hash()) - if memento is not None and (not only_exceptions or memento. - invocation_metadata. - result_type == ResultType.exception): + memento = get_memento( + cluster.storage, invocation.fn_reference_with_arg_hash() + ) + if memento is not None and ( + not only_exceptions + or memento.invocation_metadata.result_type + == ResultType.exception + ): add_row(rows, indent + 1, memento.invocation_metadata) - elif not only_exceptions and (max_depth is not None and indent >= (max_depth - 1)): - rows.append((" " * ((indent + 1) * 4)) + "{}({})".format( - invocation.fn_reference.qualified_name, _label_fn_ref_args(invocation))) + elif not only_exceptions and ( + max_depth is not None and indent >= (max_depth - 1) + ): + rows.append( + (" " * ((indent + 1) * 4)) + + "{}({})".format( + invocation.fn_reference.qualified_name, + _label_fn_ref_args(invocation), + ) + ) result_rows = [] add_row(result_rows, 0, self.invocation_metadata) @@ -402,23 +452,27 @@ def graph(self, max_depth=None, only_exceptions=False) -> graphviz.Digraph: """ from .configuration import Environment + env = Environment.get() - graph = graphviz.Digraph(graph_attr={"rankdir": "LR", "splines": "ortho"}, - node_attr={"shape": "box", "fontname": "Helvetica", - "fontsize": "10"}) + graph = graphviz.Digraph( + graph_attr={"rankdir": "LR", "splines": "ortho"}, + node_attr={"shape": "box", "fontname": "Helvetica", "fontsize": "10"}, + ) node_id_holder = [0] warned = set() # type: Set[str] @lru_cache(maxsize=10240) def get_memento( - cluster_storage, fn_with_arg_hash: FunctionReferenceWithArgHash) -> Memento: + cluster_storage, fn_with_arg_hash: FunctionReferenceWithArgHash + ) -> Memento: return cluster_storage.get_memento(fn_with_arg_hash) def graph_node( - fn_reference_with_args: FunctionReferenceWithArguments, - m: Optional[InvocationMetadata]): + fn_reference_with_args: FunctionReferenceWithArguments, + m: Optional[InvocationMetadata], + ): node_id_holder[0] += 1 node_id = "n" + str(node_id_holder[0]) kwarg_str = _label_fn_ref_args(fn_reference_with_args) @@ -428,13 +482,18 @@ def graph_node( kwarg_short = kwarg_str[0:max_len] + "..." if m is not None: label = "{}({})\n-> {}\n{}".format( - fn_reference_with_args.fn_reference.function_name, kwarg_short, - m.result_type.name, m.runtime) + fn_reference_with_args.fn_reference.function_name, + kwarg_short, + m.result_type.name, + m.runtime, + ) else: label = "{}({})".format( - fn_reference_with_args.fn_reference.function_name, kwarg_short) - tooltip = "{}({})".format(fn_reference_with_args.fn_reference.function_name, - kwarg_str) + fn_reference_with_args.fn_reference.function_name, kwarg_short + ) + tooltip = "{}({})".format( + fn_reference_with_args.fn_reference.function_name, kwarg_str + ) graph.node(node_id, label=label, tooltip=tooltip) return node_id @@ -447,17 +506,25 @@ def add_node(m: InvocationMetadata, depth: int) -> str: if cluster is None: if cluster_name not in warned: warned.add(cluster_name) - log.warning(f"Cannot find cluster {cluster_name} in default environment") + log.warning( + f"Cannot find cluster {cluster_name} in default environment" + ) elif max_depth is None or depth < (max_depth - 1): cluster_storage = cluster.storage - memento = get_memento(cluster_storage, invocation.fn_reference_with_arg_hash()) - - if memento is not None and (not only_exceptions or memento. - invocation_metadata. - result_type == ResultType.exception): + memento = get_memento( + cluster_storage, invocation.fn_reference_with_arg_hash() + ) + + if memento is not None and ( + not only_exceptions + or memento.invocation_metadata.result_type + == ResultType.exception + ): other_node_id = add_node(memento.invocation_metadata, depth + 1) graph.edge(node_id, other_node_id) - elif not only_exceptions and (max_depth is not None and depth >= (max_depth - 1)): + elif not only_exceptions and ( + max_depth is not None and depth >= (max_depth - 1) + ): other_node_id = graph_node(invocation, None) graph.edge(node_id, other_node_id) @@ -468,18 +535,26 @@ def add_node(m: InvocationMetadata, depth: int) -> str: return graph def __sizeof__(self): - return sys.getsizeof(self.time) +\ - sys.getsizeof(self.invocation_metadata) +\ - sys.getsizeof(self.runner) +\ - sys.getsizeof(self.correlation_id) +\ - sys.getsizeof(self.content_key) + return ( + sys.getsizeof(self.time) + + sys.getsizeof(self.invocation_metadata) + + sys.getsizeof(self.runner) + + sys.getsizeof(self.correlation_id) + + sys.getsizeof(self.content_key) + ) def __repr__(self): - return "Memento(time={}, invocation_metadata={}, function_dependencies={}, runner={}, " \ - "correlation_id={}, content_key={})". \ - format(repr(self.time), repr(self.invocation_metadata), - repr(self.function_dependencies), repr(self.runner), repr(self.correlation_id), - repr(self.content_key)) + return ( + "Memento(time={}, invocation_metadata={}, function_dependencies={}, runner={}, " + "correlation_id={}, content_key={})".format( + repr(self.time), + repr(self.invocation_metadata), + repr(self.function_dependencies), + repr(self.runner), + repr(self.correlation_id), + repr(self.content_key), + ) + ) def __str__(self): return self.__repr__() diff --git a/twosigma/memento/partition.py b/twosigma/memento/partition.py index d5ee151..d36dc9f 100644 --- a/twosigma/memento/partition.py +++ b/twosigma/memento/partition.py @@ -103,7 +103,7 @@ class InMemoryPartition(Partition): """ - _results = None # type: Dict[str, object] + _results = None # type: Dict[str, object] _output_keys = None # type: Optional[Dict] """ diff --git a/twosigma/memento/reference.py b/twosigma/memento/reference.py index 10c650f..0ff7898 100644 --- a/twosigma/memento/reference.py +++ b/twosigma/memento/reference.py @@ -87,24 +87,20 @@ def _encode(arg: object) -> object: Encodes the provided argument as an object using the memento standard encoding rules. """ - if arg is None \ - or isinstance(arg, bool) \ - or isinstance(arg, str) \ - or isinstance(arg, int) \ - or isinstance(arg, float): + if ( + arg is None + or isinstance(arg, bool) + or isinstance(arg, str) + or isinstance(arg, int) + or isinstance(arg, float) + ): return arg if isinstance(arg, datetime.datetime): - return { - "_mementoType": "datetime", - "iso8601": arg.isoformat() - } + return {"_mementoType": "datetime", "iso8601": arg.isoformat()} if isinstance(arg, datetime.date): - return { - "_mementoType": "date", - "iso8601": arg.isoformat() - } + return {"_mementoType": "date", "iso8601": arg.isoformat()} if isinstance(arg, list): return [ArgumentHasher._encode(x) for x in arg] @@ -119,12 +115,15 @@ def _encode(arg: object) -> object: "_mementoType": "FunctionReference", "qualifiedName": fn_reference.qualified_name, "partialArgs": ArgumentHasher._encode( - list(partial_args) if partial_args else None), + list(partial_args) if partial_args else None + ), "partialKwargs": ArgumentHasher._encode(fn_reference.partial_kwargs), - "parameterNames": fn_reference.parameter_names + "parameterNames": fn_reference.parameter_names, } - raise ValueError("Illegal argument type for memento argument: {}".format(type(arg))) + raise ValueError( + "Illegal argument type for memento argument: {}".format(type(arg)) + ) @staticmethod def _decode(arg: object) -> object: @@ -132,11 +131,13 @@ def _decode(arg: object) -> object: Decodes the provided argument as an object using the memento standard encoding rules. """ - if arg is None \ - or isinstance(arg, bool) \ - or isinstance(arg, str) \ - or isinstance(arg, int) \ - or isinstance(arg, float): + if ( + arg is None + or isinstance(arg, bool) + or isinstance(arg, str) + or isinstance(arg, int) + or isinstance(arg, float) + ): return arg if isinstance(arg, list): @@ -147,14 +148,21 @@ def _decode(arg: object) -> object: memento_type = arg["_mementoType"] if memento_type == "FunctionReference": qualified_name = arg["qualifiedName"] - partial_args_list = cast(list, ArgumentHasher._decode(arg["partialArgs"])) - partial_kwargs = cast(dict, ArgumentHasher._decode(arg["partialKwargs"])) + partial_args_list = cast( + list, ArgumentHasher._decode(arg["partialArgs"]) + ) + partial_kwargs = cast( + dict, ArgumentHasher._decode(arg["partialKwargs"]) + ) parameter_names = cast(list, arg["parameterNames"]) return FunctionReference.from_qualified_name( qualified_name=qualified_name, - partial_args=tuple(partial_args_list) if partial_args_list else None, + partial_args=( + tuple(partial_args_list) if partial_args_list else None + ), partial_kwargs=partial_kwargs, - parameter_names=parameter_names).memento_fn + parameter_names=parameter_names, + ).memento_fn elif memento_type == "datetime": return date_parser.isoparse(arg["iso8601"]) elif memento_type == "date": @@ -164,7 +172,9 @@ def _decode(arg: object) -> object: else: return {k: ArgumentHasher._decode(v) for (k, v) in arg.items()} - raise ValueError("Illegal argument type for memento argument: {}".format(type(arg))) + raise ValueError( + "Illegal argument type for memento argument: {}".format(type(arg)) + ) @staticmethod def _normalized_json(obj: object) -> str: @@ -172,23 +182,35 @@ def _normalized_json(obj: object) -> str: Compute the normalized json version of the given object """ - if obj is None \ - or isinstance(obj, bool) \ - or isinstance(obj, str) \ - or isinstance(obj, int) \ - or isinstance(obj, float): + if ( + obj is None + or isinstance(obj, bool) + or isinstance(obj, str) + or isinstance(obj, int) + or isinstance(obj, float) + ): return json.dumps(obj) if isinstance(obj, list): - return "[" + ",".join([ArgumentHasher._normalized_json(x) for x in obj]) + "]" + return ( + "[" + ",".join([ArgumentHasher._normalized_json(x) for x in obj]) + "]" + ) if isinstance(obj, dict): - return "{" + ",".join([ - json.dumps(k) + ":" + ArgumentHasher._normalized_json(v) - for (k, v) in list(sorted(obj.items(), key=lambda t: t[0])) - ]) + "}" + return ( + "{" + + ",".join( + [ + json.dumps(k) + ":" + ArgumentHasher._normalized_json(v) + for (k, v) in list(sorted(obj.items(), key=lambda t: t[0])) + ] + ) + + "}" + ) - raise ValueError("Illegal object type for normalized json: {}".format(type(obj))) + raise ValueError( + "Illegal object type for normalized json: {}".format(type(obj)) + ) @staticmethod def compute_hash(effective_kwargs: dict) -> str: @@ -199,11 +221,15 @@ def compute_hash(effective_kwargs: dict) -> str: arg_hash = hashlib.sha256() encoded_effective_kwargs = ArgumentHasher._encode(effective_kwargs) effective_kwargs_normalized_json = ArgumentHasher._normalized_json( - encoded_effective_kwargs) + encoded_effective_kwargs + ) arg_hash.update(effective_kwargs_normalized_json.encode("utf-8")) result = arg_hash.hexdigest() - log.debug("Computed hash of normalized kwargs {} = {}".format( - effective_kwargs_normalized_json, result)) + log.debug( + "Computed hash of normalized kwargs {} = {}".format( + effective_kwargs_normalized_json, result + ) + ) return result @@ -300,10 +326,18 @@ def partial_kwargs(self): parameter_names = None # type: List[str] """Formal parameter names of this function, in order""" - def __init__(self, memento_fn: Optional[MementoFunctionType], cluster_name: str = None, - version: str = None, partial_args: Tuple[Any] = None, - partial_kwargs: Dict[str, Any] = None, module_name: str = None, - function_name: str = None, parameter_names: List[str] = None, external=False): + def __init__( + self, + memento_fn: Optional[MementoFunctionType], + cluster_name: str = None, + version: str = None, + partial_args: Tuple[Any] = None, + partial_kwargs: Dict[str, Any] = None, + module_name: str = None, + function_name: str = None, + parameter_names: List[str] = None, + external=False, + ): """ Construct a function reference, either by passing in the function or the qualified name of the function. @@ -314,15 +348,18 @@ def __init__(self, memento_fn: Optional[MementoFunctionType], cluster_name: str """ # Get the information from the MementoFunction assert memento_fn is not None, "memento_fn must not be None" - assert isinstance(memento_fn, MementoFunctionType),\ - "memento_fn must be a MementoFunctionType" + assert isinstance( + memento_fn, MementoFunctionType + ), "memento_fn must be a MementoFunctionType" self.external = external # Get cluster_name if cluster_name is not None: self._cluster_name = cluster_name else: - self._cluster_name = memento_fn.cluster_name if memento_fn is not None else None + self._cluster_name = ( + memento_fn.cluster_name if memento_fn is not None else None + ) # Get module # noinspection PyUnresolvedReferences @@ -331,7 +368,9 @@ def __init__(self, memento_fn: Optional[MementoFunctionType], cluster_name: str elif memento_fn is not None and hasattr(memento_fn.fn, "__module__"): self._module = memento_fn.fn.__module__ else: - assert module_name is not None, "Either memento_fn or module_name must be specified" + assert ( + module_name is not None + ), "Either memento_fn or module_name must be specified" # Get function name if function_name is not None: @@ -339,37 +378,45 @@ def __init__(self, memento_fn: Optional[MementoFunctionType], cluster_name: str elif memento_fn is not None: self._function_name = memento_fn.fn.__name__ else: - assert function_name is not None,\ - "Either memento_fn or function_name must be specified" + assert ( + function_name is not None + ), "Either memento_fn or function_name must be specified" # Get version number if version is None: version = memento_fn.version() # Construct the qualified name - qualified_name = memento_fn.qualified_name_without_version if \ - memento_fn is not None and memento_fn.fn is not None \ + qualified_name = ( + memento_fn.qualified_name_without_version + if memento_fn is not None and memento_fn.fn is not None else self._module + ":" + self._function_name + ) if version is not None: qualified_name += "#" + version if cluster_name is not None and "::" not in qualified_name: qualified_name = cluster_name + "::" + qualified_name self._qualified_name = qualified_name - self._qualified_name_without_cluster = self.qualified_name \ - if "::" not in self.qualified_name \ - else self.qualified_name[self.qualified_name.find("::") + 2:] + self._qualified_name_without_cluster = ( + self.qualified_name + if "::" not in self.qualified_name + else self.qualified_name[self.qualified_name.find("::") + 2 :] + ) self.qualified_name_without_version = self.module + ":" + self.function_name if cluster_name is not None: - self.qualified_name_without_version = self.cluster_name + "::" +\ - self.qualified_name_without_version + self.qualified_name_without_version = ( + self.cluster_name + "::" + self.qualified_name_without_version + ) - normalized_partial_args = ArgumentHasher.normalize( - list(partial_args)) if partial_args else [] # type: list + normalized_partial_args = ( + ArgumentHasher.normalize(list(partial_args)) if partial_args else [] + ) # type: list self._partial_args = tuple(normalized_partial_args) - normalized_partial_kwargs = ArgumentHasher.normalize( - partial_kwargs) if partial_kwargs else {} # type: dict + normalized_partial_kwargs = ( + ArgumentHasher.normalize(partial_kwargs) if partial_kwargs else {} + ) # type: dict self._partial_kwargs = normalized_partial_kwargs # Record function @@ -380,10 +427,12 @@ def __init__(self, memento_fn: Optional[MementoFunctionType], cluster_name: str self.parameter_names = parameter_names elif memento_fn is not None: self.parameter_names = list( - inspect.signature(memento_fn.fn).parameters.keys()) + inspect.signature(memento_fn.fn).parameters.keys() + ) else: - assert parameter_names is not None,\ - "Must specify parameter_names if memento_fn not provided" + assert ( + parameter_names is not None + ), "Must specify parameter_names if memento_fn not provided" @staticmethod def parse_qualified_name(qualified_name: str) -> Dict: @@ -395,17 +444,22 @@ def parse_qualified_name(qualified_name: str) -> Dict: # Parse information from the string match = re.match( r"((?P.*)::)?(?P.*):(?P[^#]*)(#(?P.*))?", - qualified_name) + qualified_name, + ) if not match: raise ValueError( - "fn_or_name '{}' is not a valid qualified name".format(qualified_name)) + "fn_or_name '{}' is not a valid qualified name".format(qualified_name) + ) return match.groupdict() @staticmethod - def from_qualified_name(qualified_name: str, partial_args: Tuple[Any] = None, - partial_kwargs: Dict[str, Any] = None, - parameter_names: List[str] = None, external=False) -> \ - "FunctionReference": + def from_qualified_name( + qualified_name: str, + partial_args: Tuple[Any] = None, + partial_kwargs: Dict[str, Any] = None, + parameter_names: List[str] = None, + external=False, + ) -> "FunctionReference": """ Attempts to find a function with the given qualified name. @@ -420,32 +474,48 @@ def from_qualified_name(qualified_name: str, partial_args: Tuple[Any] = None, if not external: try: - memento_fn = FunctionReference._find_function(module=module, - function_name=function_name, - version=version, - partial_args=partial_args, - partial_kwargs=partial_kwargs) - return FunctionReference(memento_fn, cluster_name=cluster_name, version=version, - partial_args=partial_args, partial_kwargs=partial_kwargs) + memento_fn = FunctionReference._find_function( + module=module, + function_name=function_name, + version=version, + partial_args=partial_args, + partial_kwargs=partial_kwargs, + ) + return FunctionReference( + memento_fn, + cluster_name=cluster_name, + version=version, + partial_args=partial_args, + partial_kwargs=partial_kwargs, + ) except (ModuleNotFoundError, ValueError, AttributeError): # Cannot find module or function. Treat as an external function reference. external = True if external: from .external import UnboundExternalMementoFunction + # We don't know which server this function is bound to, yet. # We also may not know the parameter names, so pass [] if unknown unbound_fn = UnboundExternalMementoFunction( - cluster_name=cluster_name, module_name=module, function_name=function_name, - version=version, partial_args=partial_args, partial_kwargs=partial_kwargs, - parameter_names=parameter_names if parameter_names is not None else [] + cluster_name=cluster_name, + module_name=module, + function_name=function_name, + version=version, + partial_args=partial_args, + partial_kwargs=partial_kwargs, + parameter_names=parameter_names if parameter_names is not None else [], ) return unbound_fn.fn_reference() @staticmethod - def _find_function(module: str, function_name: str, version: str, - partial_args: Tuple[Any] = None, - partial_kwargs: Dict[str, Any] = None) -> MementoFunctionType: + def _find_function( + module: str, + function_name: str, + version: str, + partial_args: Tuple[Any] = None, + partial_kwargs: Dict[str, Any] = None, + ) -> MementoFunctionType: """ Find the function to which this reference points. @@ -463,37 +533,54 @@ def _find_function(module: str, function_name: str, version: str, module = importlib.import_module(module) ref = module if function_name.find("") != -1: - raise ValueError("Memento functions must be top-level. Cannot find a " - "function that is local to another function.") + raise ValueError( + "Memento functions must be top-level. Cannot find a " + "function that is local to another function." + ) for part in function_name.split("."): ref = getattr(ref, part) if not callable(ref): raise ValueError("{} does not refer to a function".format(function_name)) if not isinstance(ref, MementoFunctionType): - raise ValueError("{} did not resolve to a MementoFunctionType".format(function_name)) + raise ValueError( + "{} did not resolve to a MementoFunctionType".format(function_name) + ) memento_fn = ref # Check version if version is not None and memento_fn.version() != version: raise ValueError( "Function version does not match for {}: Expected {} but " - "registered function is {} with dependencies {}". - format(function_name, version, memento_fn.version(), - memento_fn.dependencies().transitive_memento_fn_dependencies())) + "registered function is {} with dependencies {}".format( + function_name, + version, + memento_fn.version(), + memento_fn.dependencies().transitive_memento_fn_dependencies(), + ) + ) if partial_args or partial_kwargs: - normalized_partial_args = ArgumentHasher.normalize( - list(partial_args)) if partial_args else [] # type: list - normalized_partial_kwargs = ArgumentHasher.normalize( - partial_kwargs) if partial_kwargs else {} # type: dict - return memento_fn.partial(*tuple(normalized_partial_args), **normalized_partial_kwargs) + normalized_partial_args = ( + ArgumentHasher.normalize(list(partial_args)) if partial_args else [] + ) # type: list + normalized_partial_kwargs = ( + ArgumentHasher.normalize(partial_kwargs) if partial_kwargs else {} + ) # type: dict + return memento_fn.partial( + *tuple(normalized_partial_args), **normalized_partial_kwargs + ) return memento_fn - def with_args(self, *args, _memento_context_args: Dict[str, Any] = None, **kwargs) -> \ - 'FunctionReferenceWithArguments': - return FunctionReferenceWithArguments(fn_reference=self, args=args, kwargs=kwargs, - context_args=_memento_context_args) + def with_args( + self, *args, _memento_context_args: Dict[str, Any] = None, **kwargs + ) -> "FunctionReferenceWithArguments": + return FunctionReferenceWithArguments( + fn_reference=self, + args=args, + kwargs=kwargs, + context_args=_memento_context_args, + ) def __str__(self) -> str: return self.__repr__() @@ -503,9 +590,12 @@ def __repr__(self) -> str: return "FunctionReference({}, partial_args={}, partial_kwargs={}, external={})".format( repr(self.qualified_name), repr(self.partial_args), - repr(collections.OrderedDict( - sorted(self.partial_kwargs.items())) if self.partial_kwargs else None), - repr(self.external) + repr( + collections.OrderedDict(sorted(self.partial_kwargs.items())) + if self.partial_kwargs + else None + ), + repr(self.external), ) def __eq__(self, other): @@ -513,28 +603,41 @@ def __eq__(self, other): return False fn_ref = cast(FunctionReference, other) - return fn_ref.qualified_name == self.qualified_name and \ - fn_ref.partial_args == self.partial_args and \ - fn_ref.partial_kwargs == self.partial_kwargs and \ - fn_ref.external == self.external + return ( + fn_ref.qualified_name == self.qualified_name + and fn_ref.partial_args == self.partial_args + and fn_ref.partial_kwargs == self.partial_kwargs + and fn_ref.external == self.external + ) def __hash__(self): return hash(self.qualified_name) def __getstate__(self): - return self.qualified_name, self.partial_args, self.partial_kwargs, self.external + return ( + self.qualified_name, + self.partial_args, + self.partial_kwargs, + self.external, + ) def __setstate__(self, state): qualified_name, partial_args, partial_kwargs, external = state - fn_ref = FunctionReference.from_qualified_name(qualified_name, partial_args=partial_args, - partial_kwargs=partial_kwargs, - external=external) - FunctionReference.__init__(self, fn_ref.memento_fn, - cluster_name=fn_ref.cluster_name, - version=fn_ref.memento_fn.version(), - partial_args=fn_ref.partial_args, - partial_kwargs=fn_ref.partial_kwargs, - external=fn_ref.external) + fn_ref = FunctionReference.from_qualified_name( + qualified_name, + partial_args=partial_args, + partial_kwargs=partial_kwargs, + external=external, + ) + FunctionReference.__init__( + self, + fn_ref.memento_fn, + cluster_name=fn_ref.cluster_name, + version=fn_ref.memento_fn.version(), + partial_args=fn_ref.partial_args, + partial_kwargs=fn_ref.partial_kwargs, + external=fn_ref.external, + ) class FunctionReferenceWithArgHash: @@ -554,9 +657,11 @@ def __init__(self, fn_reference: FunctionReference, arg_hash: str): self.arg_hash = arg_hash def __eq__(self, o: "FunctionReferenceWithArgHash") -> bool: - return isinstance(o, FunctionReferenceWithArgHash) and \ - self.fn_reference.qualified_name == o.fn_reference.qualified_name and \ - self.arg_hash == o.arg_hash + return ( + isinstance(o, FunctionReferenceWithArgHash) + and self.fn_reference.qualified_name == o.fn_reference.qualified_name + and self.arg_hash == o.arg_hash + ) def __hash__(self): return hash((self.fn_reference.qualified_name, self.arg_hash)) @@ -566,8 +671,8 @@ def __str__(self) -> str: def __repr__(self) -> str: return "FunctionReferenceWithArgHash(fn_reference={}, arg_hash={})".format( - repr(self.fn_reference), - repr(self.arg_hash)) + repr(self.fn_reference), repr(self.arg_hash) + ) class FunctionReferenceWithArguments: @@ -583,27 +688,43 @@ class FunctionReferenceWithArguments: effective_kwargs = None # type: Dict[str, Any] arg_hash = None # type: str - def __init__(self, fn_reference: FunctionReference, args: Tuple, kwargs: Dict[str, Any], - context_args: Optional[Dict[str, Any]] = None): + def __init__( + self, + fn_reference: FunctionReference, + args: Tuple, + kwargs: Dict[str, Any], + context_args: Optional[Dict[str, Any]] = None, + ): if not fn_reference.memento_fn: raise FunctionNotFoundError( "Cannot create a FunctionReferenceWithArguments if the underlying function " "reference cannot be mapped to a real function. This could be because the " "memento function is not in the path or a function version mismatch. " - "Reference: {}".format(repr(fn_reference))) + "Reference: {}".format(repr(fn_reference)) + ) self.fn_reference = fn_reference - normalized_args = ArgumentHasher.normalize(list(args)) if args else [] # type: list + normalized_args = ( + ArgumentHasher.normalize(list(args)) if args else [] + ) # type: list self.args = tuple(normalized_args) - normalized_kwargs = ArgumentHasher.normalize(kwargs) if kwargs else {} # type: dict + normalized_kwargs = ( + ArgumentHasher.normalize(kwargs) if kwargs else {} + ) # type: dict self.kwargs = normalized_kwargs - normalized_context_args = ArgumentHasher.normalize(context_args) \ - if context_args else {} # type: dict + normalized_context_args = ( + ArgumentHasher.normalize(context_args) if context_args else {} + ) # type: dict self.context_args = normalized_context_args self.effective_kwargs = self._compute_effective_kwargs() - self.effective_kwargs_with_context_args =\ + self.effective_kwargs_with_context_args = ( self._compute_effective_kwargs_with_context_args() - self.arg_hash = ArgumentHasher.compute_hash(self.effective_kwargs_with_context_args) - validate_args(*self.args, **self.kwargs, _memento_context_args=self.context_args) + ) + self.arg_hash = ArgumentHasher.compute_hash( + self.effective_kwargs_with_context_args + ) + validate_args( + *self.args, **self.kwargs, _memento_context_args=self.context_args + ) def fn_reference_with_arg_hash(self) -> FunctionReferenceWithArgHash: """ @@ -620,19 +741,25 @@ def _compute_effective_kwargs(self) -> Dict[str, Any]: parameter_names = self.fn_reference.parameter_names partial_args = self.fn_reference.partial_args if len(parameter_names) < len(partial_args): - raise ValueError(f"More partial arguments provided ({len(partial_args)} " - f"than the arguments for the function ({parameter_names})") + raise ValueError( + f"More partial arguments provided ({len(partial_args)} " + f"than the arguments for the function ({parameter_names})" + ) for i in range(0, len(partial_args)): result[parameter_names[i]] = partial_args[i] # Which parameter names are left after partial? - remaining_parameter_names = [name for name in parameter_names if name not in result] + remaining_parameter_names = [ + name for name in parameter_names if name not in result + ] # Now fill in args if len(remaining_parameter_names) < len(self.args): - raise ValueError(f"More arguments provided ({len(self.args)} " - f"than the remaining arguments for the " - f"function ({remaining_parameter_names})") + raise ValueError( + f"More arguments provided ({len(self.args)} " + f"than the remaining arguments for the " + f"function ({remaining_parameter_names})" + ) for i in range(0, len(self.args)): result[remaining_parameter_names[i]] = self.args[i] @@ -652,9 +779,11 @@ def _compute_effective_kwargs_with_context_args(self) -> Dict[str, Any]: return hash_kwargs def __eq__(self, o: "FunctionReferenceWithArguments") -> bool: - return isinstance(o, FunctionReferenceWithArguments) and \ - self.fn_reference.qualified_name == o.fn_reference.qualified_name and \ - self.arg_hash == o.arg_hash + return ( + isinstance(o, FunctionReferenceWithArguments) + and self.fn_reference.qualified_name == o.fn_reference.qualified_name + and self.arg_hash == o.arg_hash + ) def __hash__(self): return hash((self.fn_reference.qualified_name, self.arg_hash)) @@ -663,12 +792,15 @@ def __str__(self) -> str: return self.__repr__() def __repr__(self) -> str: - return "FunctionReferenceWithArguments(fn_reference={}, args={}, kwargs={}, " \ - "context_args={})".format( - repr(self.fn_reference), - repr(self.args), - repr(self.kwargs), - repr(self.context_args)) + return ( + "FunctionReferenceWithArguments(fn_reference={}, args={}, kwargs={}, " + "context_args={})".format( + repr(self.fn_reference), + repr(self.args), + repr(self.kwargs), + repr(self.context_args), + ) + ) def validate_args(*args, _memento_context_args: Dict[str, Any] = None, **kwargs): @@ -678,29 +810,41 @@ def validate_args(*args, _memento_context_args: Dict[str, Any] = None, **kwargs) """ def validate_arg(a): - return a is None or \ - isinstance(a, bool) or \ - isinstance(a, str) or \ - isinstance(a, int) or \ - isinstance(a, float) or \ - isinstance(a, datetime.date) or \ - isinstance(a, datetime.datetime) or \ - (isinstance(a, dict) and False not in [validate_arg(v) for (k, v) in a.items()]) or \ - (isinstance(a, list) and False not in [validate_arg(v) for v in a]) or \ - isinstance(a, MementoFunctionType) + return ( + a is None + or isinstance(a, bool) + or isinstance(a, str) + or isinstance(a, int) + or isinstance(a, float) + or isinstance(a, datetime.date) + or isinstance(a, datetime.datetime) + or ( + isinstance(a, dict) + and False not in [validate_arg(v) for (k, v) in a.items()] + ) + or (isinstance(a, list) and False not in [validate_arg(v) for v in a]) + or isinstance(a, MementoFunctionType) + ) if _memento_context_args is not None and not validate_arg(_memento_context_args): raise AssertionError( - "Memento cannot handle context arg. Value: {}".format(type(_memento_context_args))) + "Memento cannot handle context arg. Value: {}".format( + type(_memento_context_args) + ) + ) for idx, arg in enumerate(args): if not validate_arg(arg): raise AssertionError( "Memento cannot handle function argument type {} at index {}. Value: {}".format( - type(arg), idx, args)) + type(arg), idx, args + ) + ) for key, arg in kwargs.items(): if not validate_arg(arg): raise AssertionError( "Memento cannot handle function argument type {} for kwarg {}. Value: {}".format( - type(arg), key, kwargs)) + type(arg), key, kwargs + ) + ) diff --git a/twosigma/memento/resource.py b/twosigma/memento/resource.py index f44de7b..16973f7 100644 --- a/twosigma/memento/resource.py +++ b/twosigma/memento/resource.py @@ -35,16 +35,15 @@ def __init__(self, resource_type: str, url: str, version: str): self.version = version def __eq__(self, o: "ResourceHandle"): - return isinstance(o, ResourceHandle) and \ - self.__dict__ == o.__dict__ + return isinstance(o, ResourceHandle) and self.__dict__ == o.__dict__ def __hash__(self): return hash((self.resource_type, self.url, self.version)) def __repr__(self): return "ResourceHandle(resource_type={}, url={}, version={})".format( - repr(self.resource_type), - repr(self.url), repr(self.version)) + repr(self.resource_type), repr(self.url), repr(self.version) + ) def __str__(self): return self.__repr__() diff --git a/twosigma/memento/resource_function.py b/twosigma/memento/resource_function.py index 80c6c48..cfcd133 100644 --- a/twosigma/memento/resource_function.py +++ b/twosigma/memento/resource_function.py @@ -53,8 +53,9 @@ def __call__(self, *args, **kwargs) -> ResourceHandle: return handle -def resource_function(resource_type: str) -> \ - Callable[[Callable[[str], ResourceHandle]], ResourceFunction]: +def resource_function( + resource_type: str, +) -> Callable[[Callable[[str], ResourceHandle]], ResourceFunction]: """ Decorator that causes a function to be treated as a Memento resource function. A resource function is a special type of function in Memento that represents a diff --git a/twosigma/memento/runner.py b/twosigma/memento/runner.py index 47a301c..621861d 100644 --- a/twosigma/memento/runner.py +++ b/twosigma/memento/runner.py @@ -36,13 +36,14 @@ # The set of known runners. Register new runners using RunnerBackend.register. _registered_runner_backends = {} -ExistingMementoResult = NamedTuple('ExistingMementoResult', - [('result', Any), ('valid_result', bool)]) +ExistingMementoResult = NamedTuple( + "ExistingMementoResult", [("result", Any), ("valid_result", bool)] +) -def process_existing_memento(storage_backend: StorageBackend, - existing_memento: Memento, - ignore_result: bool) -> ExistingMementoResult: +def process_existing_memento( + storage_backend: StorageBackend, existing_memento: Memento, ignore_result: bool +) -> ExistingMementoResult: """ This logic is reused several times by runners when processing an existing memento for a function invocation. @@ -66,8 +67,11 @@ def process_existing_memento(storage_backend: StorageBackend, try: # If result already exists, deserialize and return if ignore_result: - log.debug("Result of {} was already memoized and is ignored".format( - str(fn_reference_with_args))) + log.debug( + "Result of {} was already memoized and is ignored".format( + str(fn_reference_with_args) + ) + ) return ExistingMementoResult(result=None, valid_result=True) # Unwrap exception from MementoException @@ -77,12 +81,17 @@ def process_existing_memento(storage_backend: StorageBackend, log.warning("Processing a MemoizedException") result = e.to_exception() - log.info("Previous result for {} was memoized and is of type {}.".format( - str(fn_reference_with_args), - existing_memento.invocation_metadata.result_type.name)) + log.info( + "Previous result for {} was memoized and is of type {}.".format( + str(fn_reference_with_args), + existing_memento.invocation_metadata.result_type.name, + ) + ) return ExistingMementoResult(result=result, valid_result=True) except IOError: - log.warning("IO Error while reading memoized result. Recomputing.", exc_info=True) + log.warning( + "IO Error while reading memoized result. Recomputing.", exc_info=True + ) return ExistingMementoResult(result=None, valid_result=False) @@ -111,9 +120,10 @@ def __init__(self, runner_type: str, config: dict = None): self.config = config @staticmethod - def ensure_correlation_id(context: InvocationContext, - fn_reference_with_args: List[FunctionReferenceWithArguments]) ->\ - InvocationContext: + def ensure_correlation_id( + context: InvocationContext, + fn_reference_with_args: List[FunctionReferenceWithArguments], + ) -> InvocationContext: """ Utility method for subclasses to generate a new correlation id, if one is not already present. @@ -121,16 +131,23 @@ def ensure_correlation_id(context: InvocationContext, """ if not context.recursive.correlation_id: correlation_id = "cid_" + uuid.uuid4().hex[0:12] - log.debug("{}: Generating new correlation id for {}".format(correlation_id, - fn_reference_with_args)) + log.debug( + "{}: Generating new correlation id for {}".format( + correlation_id, fn_reference_with_args + ) + ) return context.update_recursive("correlation_id", correlation_id) return context @abstractmethod - def batch_run(self, context: InvocationContext, storage_backend: StorageBackend, - fn_reference_with_args: List[FunctionReferenceWithArguments], - log_runner_backend: 'RunnerBackend', - caller_memento: Optional[Memento]) -> List[Any]: + def batch_run( + self, + context: InvocationContext, + storage_backend: StorageBackend, + fn_reference_with_args: List[FunctionReferenceWithArguments], + log_runner_backend: "RunnerBackend", + caller_memento: Optional[Memento], + ) -> List[Any]: """ Run a series of memento functions using this runner. If the runner is capable, these may be run in parallel. diff --git a/twosigma/memento/runner_local.py b/twosigma/memento/runner_local.py index 8dbe13c..7729485 100644 --- a/twosigma/memento/runner_local.py +++ b/twosigma/memento/runner_local.py @@ -34,12 +34,16 @@ from .runner import RunnerBackend, process_existing_memento, ExistingMementoResult from .storage import StorageBackend -_memento_fn_mutex_lock = RLock() # Lock to protect the defaultdict since it is not ThreadSafe +_memento_fn_mutex_lock = ( + RLock() +) # Lock to protect the defaultdict since it is not ThreadSafe # with the lambda _memento_fn_mutex = defaultdict(lambda: RLock()) # type: Dict[Tuple[str, str], RLock] -def _mutex_for_invocation(fn_reference_with_args: FunctionReferenceWithArguments) -> RLock: +def _mutex_for_invocation( + fn_reference_with_args: FunctionReferenceWithArguments, +) -> RLock: """ Check if any other callers in this process are calling this function at the same time and, if so, wait for them to complete. @@ -49,8 +53,12 @@ def _mutex_for_invocation(fn_reference_with_args: FunctionReferenceWithArguments """ with _memento_fn_mutex_lock: - return _memento_fn_mutex[(fn_reference_with_args.fn_reference.qualified_name, - fn_reference_with_args.arg_hash)] + return _memento_fn_mutex[ + ( + fn_reference_with_args.fn_reference.qualified_name, + fn_reference_with_args.arg_hash, + ) + ] class LocalRunnerBackend(RunnerBackend): @@ -63,10 +71,14 @@ class LocalRunnerBackend(RunnerBackend): def __init__(self, config: dict = None): super().__init__("local", config=config) - def batch_run(self, context: InvocationContext, storage_backend: StorageBackend, - fn_reference_with_args: List[FunctionReferenceWithArguments], - log_runner_backend: RunnerBackend, - caller_memento: Optional[Memento]) -> List[Any]: + def batch_run( + self, + context: InvocationContext, + storage_backend: StorageBackend, + fn_reference_with_args: List[FunctionReferenceWithArguments], + log_runner_backend: RunnerBackend, + caller_memento: Optional[Memento], + ) -> List[Any]: context = self.ensure_correlation_id(context, fn_reference_with_args) arg_list = fn_reference_with_args @@ -74,21 +86,27 @@ def batch_run(self, context: InvocationContext, storage_backend: StorageBackend, # Bulk query for existing mementos existing_mementos = storage_backend.get_mementos( - [f.fn_reference_with_arg_hash() for f in fn_reference_with_args]) + [f.fn_reference_with_arg_hash() for f in fn_reference_with_args] + ) - if context.local.monitor_progress and CallStack.get().get_calling_frame() is None: + if ( + context.local.monitor_progress + and CallStack.get().get_calling_frame() is None + ): # Only show progress bar if monitor_progress is set and this is the root call arg_list = tqdm(arg_list) # Since memento_run_local handles updating the invocation list of the caller already, the # local runner can ignore caller_memento. for idx, f in enumerate(arg_list): - existing_memento_result = ExistingMementoResult(result=None, valid_result=False) + existing_memento_result = ExistingMementoResult( + result=None, valid_result=False + ) existing_memento = existing_mementos[idx] if existing_memento: - existing_memento_result = process_existing_memento(storage_backend, - existing_memento, - context.local.ignore_result) + existing_memento_result = process_existing_memento( + storage_backend, existing_memento, context.local.ignore_result + ) if existing_memento_result.valid_result: results.append(existing_memento_result.result) @@ -97,34 +115,40 @@ def batch_run(self, context: InvocationContext, storage_backend: StorageBackend, call_stack = CallStack.get() calling_frame = call_stack.get_calling_frame() if calling_frame: - propagate_dependencies(caller_memento=calling_frame.memento, - result_memento=existing_memento) + propagate_dependencies( + caller_memento=calling_frame.memento, + result_memento=existing_memento, + ) else: try: - results.append(memento_run_local(context=context, - fn_reference_with_args=f, - storage_backend=storage_backend, - log_runner_backend=log_runner_backend)) + results.append( + memento_run_local( + context=context, + fn_reference_with_args=f, + storage_backend=storage_backend, + log_runner_backend=log_runner_backend, + ) + ) except Exception as e: results.append(e) return results def to_dict(self): - config = { - "type": "local" - } + config = {"type": "local"} return config RunnerBackend.register("local", LocalRunnerBackend) -def memento_run_batch(context: InvocationContext, - fn_reference_with_args: List[FunctionReferenceWithArguments], - storage_backend: StorageBackend, - runner_backend: RunnerBackend, - log_runner_backend: RunnerBackend) -> List[Any]: +def memento_run_batch( + context: InvocationContext, + fn_reference_with_args: List[FunctionReferenceWithArguments], + storage_backend: StorageBackend, + runner_backend: RunnerBackend, + log_runner_backend: RunnerBackend, +) -> List[Any]: """ Run a batch of Memento functions, using the given runner, with arguments. All calls to MementoFunctionTypes pass through this function, on the client side. @@ -156,27 +180,37 @@ def memento_run_batch(context: InvocationContext, if caller_memento: # If the caller memento exists, pass down its correlation id and other fields - context = context.update_recursive("correlation_id", caller_memento.correlation_id) - context = context.update_recursive("retry_on_remote_call", - calling_frame.recursive_context.retry_on_remote_call) + context = context.update_recursive( + "correlation_id", caller_memento.correlation_id + ) + context = context.update_recursive( + "retry_on_remote_call", calling_frame.recursive_context.retry_on_remote_call + ) if context.recursive.context_args is None: # Only update the context args from the call stack if not overridden in this call - context = context.update_recursive("context_args", - calling_frame.recursive_context.context_args) + context = context.update_recursive( + "context_args", calling_frame.recursive_context.context_args + ) # Update the fn_reference_with_args since the context args could change the # argument hash fn_reference_with_args = [ - FunctionReferenceWithArguments(ref.fn_reference, ref.args, ref.kwargs, - context.recursive.context_args) + FunctionReferenceWithArguments( + ref.fn_reference, + ref.args, + ref.kwargs, + context.recursive.context_args, + ) for ref in fn_reference_with_args ] - return runner.batch_run(context=context, - storage_backend=storage_backend, - fn_reference_with_args=fn_reference_with_args, - log_runner_backend=log_runner_backend, - caller_memento=caller_memento) + return runner.batch_run( + context=context, + storage_backend=storage_backend, + fn_reference_with_args=fn_reference_with_args, + log_runner_backend=log_runner_backend, + caller_memento=caller_memento, + ) def propagate_dependencies(caller_memento: Memento, result_memento: Memento): @@ -197,10 +231,12 @@ def propagate_dependencies(caller_memento: Memento, result_memento: Memento): parent_dependencies |= result_memento.function_dependencies -def memento_run_local(context: InvocationContext, - fn_reference_with_args: FunctionReferenceWithArguments, - storage_backend: StorageBackend, - log_runner_backend: RunnerBackend) -> Any: +def memento_run_local( + context: InvocationContext, + fn_reference_with_args: FunctionReferenceWithArguments, + storage_backend: StorageBackend, + log_runner_backend: RunnerBackend, +) -> Any: """ Run a single Memento function in the local process, with arguments. This is typically not called directly, but by a runner implementation. @@ -227,11 +263,16 @@ def memento_run_local(context: InvocationContext, correlation_id = context.recursive.correlation_id # Log what we're calling - log.debug("{}: Calling {} with context {}".format(correlation_id, fn_reference_with_args, - context)) + log.debug( + "{}: Calling {} with context {}".format( + correlation_id, fn_reference_with_args, context + ) + ) # Create a stack frame for the invocation - log_runner = LocalRunnerBackend() if context.local.force_local else log_runner_backend + log_runner = ( + LocalRunnerBackend() if context.local.force_local else log_runner_backend + ) call_stack = CallStack.get() stack_frame = StackFrame(fn_reference_with_args, log_runner, context.recursive) @@ -241,11 +282,12 @@ def memento_run_local(context: InvocationContext, call_stack.push_frame(stack_frame) existing_memento = storage_backend.get_memento( - fn_reference_with_args.fn_reference_with_arg_hash()) + fn_reference_with_args.fn_reference_with_arg_hash() + ) if existing_memento: - existing_memento_result = process_existing_memento(storage_backend, - existing_memento, - context.local.ignore_result) + existing_memento_result = process_existing_memento( + storage_backend, existing_memento, context.local.ignore_result + ) if existing_memento_result.valid_result: stack_frame.memento = existing_memento return existing_memento_result.result @@ -258,12 +300,15 @@ def memento_run_local(context: InvocationContext, log.debug("{}: Function call begins".format(correlation_id)) # noinspection PyProtectedMember result = fn_reference_with_args.fn_reference.memento_fn._filter_call( - **fn_reference_with_args.effective_kwargs) + **fn_reference_with_args.effective_kwargs + ) except RemoteCallException: # Special exception thrown if a remote call is made while processing # the function. This should immediately stop processing and not memoize # the result. The function will be retried by the framework later. - log.debug("Remote call detected during function execution. Retrying later.") + log.debug( + "Remote call detected during function execution. Retrying later." + ) raise except NonMemoizedException: # If the exception is marked not to be memoized, just raise it @@ -283,7 +328,9 @@ def memento_run_local(context: InvocationContext, key_override = result.key_override result = result.result - stack_frame.memento.invocation_metadata.result_type = ResultType.from_object(result) + stack_frame.memento.invocation_metadata.result_type = ( + ResultType.from_object(result) + ) # Memoize the result if it is not already present. # There are two primary reasons a memoized result could have appeared while @@ -291,39 +338,60 @@ def memento_run_local(context: InvocationContext, # 1. This function was invoked in another thread or process and we lost the race # 2. A runner is being used that memoized the result in another process (possibly # on another machine in a compute cluster) and the result is already memoized. - if not storage_backend.is_memoized(fn_reference_with_args.fn_reference, - fn_reference_with_args.arg_hash): + if not storage_backend.is_memoized( + fn_reference_with_args.fn_reference, fn_reference_with_args.arg_hash + ): try: storage_backend.memoize(key_override, stack_frame.memento, result) memoization_status = "successfully memoized" except IOError: - log.warning("IO Error while writing memoized result.", exc_info=True) + log.warning( + "IO Error while writing memoized result.", exc_info=True + ) memoization_status = "memoization failed to write result" else: - memoization_status = "memoized elsewhere while we were computing the result" - - if context.local.ignore_result and stack_frame.memento.invocation_metadata.result_type\ - != ResultType.exception: + memoization_status = ( + "memoized elsewhere while we were computing the result" + ) + + if ( + context.local.ignore_result + and stack_frame.memento.invocation_metadata.result_type + != ResultType.exception + ): log.debug( - "{}: Result was computed, {} and is ignored".format(correlation_id, - memoization_status)) + "{}: Result was computed, {} and is ignored".format( + correlation_id, memoization_status + ) + ) return # If an exception occurred, raise it instead of returning the result if exception_result is not None: - log.debug("{}: Result was computed, {} and is an exception: {}: {}".format( - correlation_id, memoization_status, type(exception_result).__name__, - exception_result)) + log.debug( + "{}: Result was computed, {} and is an exception: {}: {}".format( + correlation_id, + memoization_status, + type(exception_result).__name__, + exception_result, + ) + ) return exception_result - log.debug("{}: Result was computed, {} and is of type {}".format( - correlation_id, memoization_status, - stack_frame.memento.invocation_metadata.result_type.name)) + log.debug( + "{}: Result was computed, {} and is of type {}".format( + correlation_id, + memoization_status, + stack_frame.memento.invocation_metadata.result_type.name, + ) + ) return result finally: log.debug("{}: Function call ends".format(correlation_id)) call_stack.pop_frame() calling_frame = call_stack.get_calling_frame() if calling_frame: - propagate_dependencies(caller_memento=calling_frame.memento, - result_memento=stack_frame.memento) + propagate_dependencies( + caller_memento=calling_frame.memento, + result_memento=stack_frame.memento, + ) diff --git a/twosigma/memento/runner_null.py b/twosigma/memento/runner_null.py index fed418b..0b5cac9 100644 --- a/twosigma/memento/runner_null.py +++ b/twosigma/memento/runner_null.py @@ -30,16 +30,18 @@ class NullRunnerBackend(RunnerBackend): def __init__(self, config: dict = None): super().__init__("null", config=config) - def batch_run(self, context: InvocationContext, storage_backend: StorageBackend, - fn_reference_with_args: List[FunctionReferenceWithArguments], - log_runner_backend: RunnerBackend, - caller_memento: Optional[Memento]) -> List[Any]: + def batch_run( + self, + context: InvocationContext, + storage_backend: StorageBackend, + fn_reference_with_args: List[FunctionReferenceWithArguments], + log_runner_backend: RunnerBackend, + caller_memento: Optional[Memento], + ) -> List[Any]: raise RuntimeError("Null runner refusing to run functions") def to_dict(self): - config = { - "type": "null" - } + config = {"type": "null"} return config diff --git a/twosigma/memento/runner_test.py b/twosigma/memento/runner_test.py index 9ed7959..5f1b1b8 100644 --- a/twosigma/memento/runner_test.py +++ b/twosigma/memento/runner_test.py @@ -49,7 +49,7 @@ def runner_fn_test_1(a, b, c=None, d=None, e=None, f=None): "c": c, "d": d, "e": e == runner_fn_test_1, - "f": (f[0]["a"] == runner_fn_test_1) if f is not None else None + "f": (f[0]["a"] == runner_fn_test_1) if f is not None else None, } diff --git a/twosigma/memento/serialization.py b/twosigma/memento/serialization.py index 82ce493..cc647e2 100644 --- a/twosigma/memento/serialization.py +++ b/twosigma/memento/serialization.py @@ -26,11 +26,17 @@ from twosigma.memento.context import RecursiveContext from twosigma.memento.metadata import ResultType, InvocationMetadata, Memento -from twosigma.memento.reference import FunctionReferenceWithArguments, FunctionReference, \ - FunctionReferenceWithArgHash +from twosigma.memento.reference import ( + FunctionReferenceWithArguments, + FunctionReference, + FunctionReferenceWithArgHash, +) from twosigma.memento.resource import ResourceHandle -from twosigma.memento.types import FunctionNotFoundError, MementoFunctionType, \ - VersionedDataSourceKey +from twosigma.memento.types import ( + FunctionNotFoundError, + MementoFunctionType, + VersionedDataSourceKey, +) class MementoCodec: @@ -54,88 +60,134 @@ def decode_datetime(cls, state: str) -> Union[datetime.date, datetime.datetime]: def encode_memento(cls, memento: Memento) -> Dict: return { "time": cls.encode_datetime(memento.time), - "invocationMetadata": cls.encode_invocation_metadata(memento.invocation_metadata), - "functionDependencies": [cls.encode_fn_reference(x) for x - in memento.function_dependencies] - if memento is not None else [], + "invocationMetadata": cls.encode_invocation_metadata( + memento.invocation_metadata + ), + "functionDependencies": ( + [cls.encode_fn_reference(x) for x in memento.function_dependencies] + if memento is not None + else [] + ), "runner": memento.runner, "correlationId": memento.correlation_id, - "contentKey": cls.encode_versioned_data_source_key(memento.content_key) + "contentKey": cls.encode_versioned_data_source_key(memento.content_key), } @classmethod def decode_memento(cls, state: Dict) -> Memento: return Memento( time=cls.decode_datetime(state["time"]), - invocation_metadata=cls.decode_invocation_metadata(state["invocationMetadata"]), - function_dependencies=set(cls.decode_fn_reference(x) for x - in state["functionDependencies"]) - if state["functionDependencies"] is not None else {}, + invocation_metadata=cls.decode_invocation_metadata( + state["invocationMetadata"] + ), + function_dependencies=( + set(cls.decode_fn_reference(x) for x in state["functionDependencies"]) + if state["functionDependencies"] is not None + else {} + ), runner=state["runner"], correlation_id=state["correlationId"], - content_key=cls.decode_versioned_data_source_key(state["contentKey"]) + content_key=cls.decode_versioned_data_source_key(state["contentKey"]), ) @classmethod def encode_invocation_metadata(cls, obj: InvocationMetadata) -> Dict: return { - "fnReferenceWithArgs": cls.encode_fn_reference_with_args(obj.fn_reference_with_args), - "invocations": [cls.encode_fn_reference_with_args(x) for x in obj.invocations] - if obj.invocations is not None else None, - "resources": [cls.encode_resource_handle(x) for x in obj.resources] - if obj.resources is not None else None, + "fnReferenceWithArgs": cls.encode_fn_reference_with_args( + obj.fn_reference_with_args + ), + "invocations": ( + [cls.encode_fn_reference_with_args(x) for x in obj.invocations] + if obj.invocations is not None + else None + ), + "resources": ( + [cls.encode_resource_handle(x) for x in obj.resources] + if obj.resources is not None + else None + ), "runtimeSeconds": obj.runtime.total_seconds(), - "resultType": obj.result_type.name + "resultType": obj.result_type.name, } @classmethod def decode_invocation_metadata(cls, state: Dict) -> InvocationMetadata: return InvocationMetadata( - fn_reference_with_args=cls.decode_fn_reference_with_args(state["fnReferenceWithArgs"]), - invocations=[cls.decode_fn_reference_with_args(x) for x in state["invocations"]] - if state["invocations"] is not None else None, - resources=[cls.decode_resource_handle(x) for x in state["resources"]] - if state["resources"] is not None else None, + fn_reference_with_args=cls.decode_fn_reference_with_args( + state["fnReferenceWithArgs"] + ), + invocations=( + [cls.decode_fn_reference_with_args(x) for x in state["invocations"]] + if state["invocations"] is not None + else None + ), + resources=( + [cls.decode_resource_handle(x) for x in state["resources"]] + if state["resources"] is not None + else None + ), runtime=datetime.timedelta(seconds=state["runtimeSeconds"]), - result_type=ResultType[state["resultType"]] + result_type=ResultType[state["resultType"]], ) @classmethod def encode_fn_reference_with_args(cls, obj: FunctionReferenceWithArguments) -> Dict: return { "fnReference": cls.encode_fn_reference(obj.fn_reference), - "args": [cls.encode_arg(x) for x in obj.args] - if obj.args is not None else None, - "kwargs": {k: cls.encode_arg(v) for (k, v) in obj.kwargs.items()} - if obj.kwargs is not None else None, - "contextArgs": {k: cls.encode_arg(v) for (k, v) in obj.context_args.items()} - if obj.context_args is not None else None + "args": ( + [cls.encode_arg(x) for x in obj.args] if obj.args is not None else None + ), + "kwargs": ( + {k: cls.encode_arg(v) for (k, v) in obj.kwargs.items()} + if obj.kwargs is not None + else None + ), + "contextArgs": ( + {k: cls.encode_arg(v) for (k, v) in obj.context_args.items()} + if obj.context_args is not None + else None + ), } @classmethod - def decode_fn_reference_with_args(cls, state: Dict) -> FunctionReferenceWithArguments: + def decode_fn_reference_with_args( + cls, state: Dict + ) -> FunctionReferenceWithArguments: return FunctionReferenceWithArguments( fn_reference=cls.decode_fn_reference(state["fnReference"]), - args=tuple([cls.decode_arg(x) for x in state["args"]]) - if state["args"] is not None else None, - kwargs={k: cls.decode_arg(v) for (k, v) in state["kwargs"].items()} - if state["kwargs"] is not None else None, - context_args={k: cls.decode_arg(v) for (k, v) in state["contextArgs"].items()} - if state["contextArgs"] is not None else None + args=( + tuple([cls.decode_arg(x) for x in state["args"]]) + if state["args"] is not None + else None + ), + kwargs=( + {k: cls.decode_arg(v) for (k, v) in state["kwargs"].items()} + if state["kwargs"] is not None + else None + ), + context_args=( + {k: cls.decode_arg(v) for (k, v) in state["contextArgs"].items()} + if state["contextArgs"] is not None + else None + ), ) @classmethod - def encode_fn_reference_with_arg_hash(cls, obj: FunctionReferenceWithArgHash) -> Dict: + def encode_fn_reference_with_arg_hash( + cls, obj: FunctionReferenceWithArgHash + ) -> Dict: return { "fnReference": cls.encode_fn_reference(obj.fn_reference), - "argHash": obj.arg_hash + "argHash": obj.arg_hash, } @classmethod - def decode_fn_reference_with_arg_hash(cls, state: Dict) -> FunctionReferenceWithArgHash: + def decode_fn_reference_with_arg_hash( + cls, state: Dict + ) -> FunctionReferenceWithArgHash: return FunctionReferenceWithArgHash( fn_reference=cls.decode_fn_reference(state["fnReference"]), - arg_hash=state["argHash"] + arg_hash=state["argHash"], ) @classmethod @@ -143,7 +195,7 @@ def encode_resource_handle(cls, obj: ResourceHandle) -> Dict: return { "resourceType": obj.resource_type, "url": obj.url, - "version": obj.version + "version": obj.version, } @classmethod @@ -151,29 +203,41 @@ def decode_resource_handle(cls, state: Dict) -> ResourceHandle: return ResourceHandle( resource_type=state["resourceType"], url=state["url"], - version=state["version"] + version=state["version"], ) @classmethod def encode_fn_reference(cls, obj: FunctionReference) -> Dict: return { "qualifiedName": obj.qualified_name, - "partialArgs": [cls.encode_arg(x) for x in obj.partial_args] - if obj.partial_args is not None else None, - "partialKwargs": {k: cls.encode_arg(v) for (k, v) in obj.partial_kwargs.items()} - if obj.partial_kwargs is not None else None, - "parameterNames": obj.parameter_names + "partialArgs": ( + [cls.encode_arg(x) for x in obj.partial_args] + if obj.partial_args is not None + else None + ), + "partialKwargs": ( + {k: cls.encode_arg(v) for (k, v) in obj.partial_kwargs.items()} + if obj.partial_kwargs is not None + else None + ), + "parameterNames": obj.parameter_names, } @classmethod def decode_fn_reference(cls, state: Dict) -> FunctionReference: return FunctionReference.from_qualified_name( qualified_name=state["qualifiedName"], - partial_args=tuple([cls.decode_arg(x) for x in state["partialArgs"]]) - if state["partialArgs"] is not None else None, - partial_kwargs={k: cls.decode_arg(v) for (k, v) in state["partialKwargs"].items()} - if state["partialKwargs"] is not None else None, - parameter_names=state["parameterNames"] + partial_args=( + tuple([cls.decode_arg(x) for x in state["partialArgs"]]) + if state["partialArgs"] is not None + else None + ), + partial_kwargs=( + {k: cls.decode_arg(v) for (k, v) in state["partialKwargs"].items()} + if state["partialKwargs"] is not None + else None + ), + parameter_names=state["parameterNames"], ) @classmethod @@ -182,7 +246,7 @@ def encode_recursive_context(cls, obj: RecursiveContext) -> Dict: "correlationId": obj.correlation_id, "retryOnRemoteCall": obj.retry_on_remote_call, "preventFurtherCalls": obj.prevent_further_calls, - "contextArgs": cls.encode_arg(obj.context_args) + "contextArgs": cls.encode_arg(obj.context_args), } @classmethod @@ -191,23 +255,27 @@ def decode_recursive_context(cls, state: Dict) -> RecursiveContext: correlation_id=state["correlationId"], retry_on_remote_call=state["retryOnRemoteCall"], prevent_further_calls=state["preventFurtherCalls"], - context_args=cls.decode_arg(state["contextArgs"]) + context_args=cls.decode_arg(state["contextArgs"]), ) @classmethod - def encode_versioned_data_source_key(cls, - content_key: VersionedDataSourceKey) -> Optional[str]: + def encode_versioned_data_source_key( + cls, content_key: VersionedDataSourceKey + ) -> Optional[str]: if content_key is None: return None return "{}#{}".format(content_key.key, content_key.version) @classmethod - def decode_versioned_data_source_key(cls, - state: Optional[str]) -> Optional[VersionedDataSourceKey]: + def decode_versioned_data_source_key( + cls, state: Optional[str] + ) -> Optional[VersionedDataSourceKey]: if state is None: return None hash_index = state.rfind("#") - return VersionedDataSourceKey(key=state[0:hash_index], version=state[hash_index + 1:]) + return VersionedDataSourceKey( + key=state[0:hash_index], version=state[hash_index + 1 :] + ) @classmethod def encode_arg(cls, obj: Any) -> Dict: @@ -217,33 +285,21 @@ def encode_arg(cls, obj: Any) -> Dict: """ if obj is None: - return { - "type": ResultType.null.name - } + return {"type": ResultType.null.name} elif isinstance(obj, bool): - return { - "type": ResultType.boolean.name, - "value": obj - } + return {"type": ResultType.boolean.name, "value": obj} elif isinstance(obj, str): - return { - "type": ResultType.string.name, - "value": obj - } + return {"type": ResultType.string.name, "value": obj} elif isinstance(obj, bytes): - return { - "type": ResultType.binary.name, - "value": base64.b64encode(obj) - } + return {"type": ResultType.binary.name, "value": base64.b64encode(obj)} elif isinstance(obj, int) or isinstance(obj, float): - return { - "type": ResultType.number.name, - "value": obj - } + return {"type": ResultType.number.name, "value": obj} elif isinstance(obj, np.ndarray): if obj.ndim > 1: - raise ValueError("Memento does not support serializing array arguments of " - "more than 1 dimension") + raise ValueError( + "Memento does not support serializing array arguments of " + "more than 1 dimension" + ) if obj.dtype == np.bool: type_str = ResultType.array_boolean.name @@ -262,37 +318,33 @@ def encode_arg(cls, obj: Any) -> Dict: else: raise ValueError("Unknown numpy array type: {}".format(obj.dtype)) - return { - "type": type_str, - "value": [x for x in obj] - } + return {"type": type_str, "value": [x for x in obj]} elif isinstance(obj, Callable): if not isinstance(obj, MementoFunctionType): - raise ValueError("{} is callable but not a MementoFunctionType".format(obj)) + raise ValueError( + "{} is callable but not a MementoFunctionType".format(obj) + ) return { "type": "twosigma.memento.FunctionReference", - "value": cls.encode_fn_reference(obj.fn_reference()) + "value": cls.encode_fn_reference(obj.fn_reference()), } elif isinstance(obj, list) or isinstance(obj, tuple): return { "type": ResultType.list_result.name, - "value": [cls.encode_arg(x) for x in obj] + "value": [cls.encode_arg(x) for x in obj], } elif isinstance(obj, dict): return { "type": ResultType.dictionary.name, - "value": {k: cls.encode_arg(v) for (k, v) in obj.items()} + "value": {k: cls.encode_arg(v) for (k, v) in obj.items()}, } elif isinstance(obj, datetime.datetime): return { "type": ResultType.timestamp.name, - "value": cls.encode_datetime(obj) + "value": cls.encode_datetime(obj), } elif isinstance(obj, datetime.date): - return { - "type": ResultType.date.name, - "value": cls.encode_datetime(obj) - } + return {"type": ResultType.date.name, "value": cls.encode_datetime(obj)} else: raise ValueError("Cannot encode argument of type {}".format(type(obj))) @@ -305,21 +357,26 @@ def decode_arg(cls, state: Dict) -> Any: """ if type(state) is not dict: - raise ValueError("During deserialization of argument, state was not a dict.") + raise ValueError( + "During deserialization of argument, state was not a dict." + ) obj_type = state["type"] if not obj_type: raise ValueError( - "During deserialization of argument, state did not have 'type' attribute.") + "During deserialization of argument, state did not have 'type' attribute." + ) if obj_type == ResultType.null.name: return None value = state["value"] - if obj_type == ResultType.boolean.name or \ - obj_type == ResultType.string.name or \ - obj_type == ResultType.number.name: + if ( + obj_type == ResultType.boolean.name + or obj_type == ResultType.string.name + or obj_type == ResultType.number.name + ): return value elif obj_type == ResultType.binary.name: return base64.b64decode(value) @@ -343,7 +400,8 @@ def decode_arg(cls, state: Dict) -> Any: raise FunctionNotFoundError( "Could not deserialize: Could not replace fn_reference with fn for {}. " "This could be because the function is not in the path or because of a " - "version mismatch.".format(fn_reference.qualified_name)) + "version mismatch.".format(fn_reference.qualified_name) + ) return fn_reference.memento_fn elif obj_type == ResultType.list_result.name: return [cls.decode_arg(x) for x in value] diff --git a/twosigma/memento/storage.py b/twosigma/memento/storage.py index ffd4360..0e88f47 100644 --- a/twosigma/memento/storage.py +++ b/twosigma/memento/storage.py @@ -26,8 +26,11 @@ from abc import ABC, abstractmethod from typing import List, Iterable, Optional -from .reference import FunctionReference, FunctionReferenceWithArguments,\ - FunctionReferenceWithArgHash +from .reference import ( + FunctionReference, + FunctionReferenceWithArguments, + FunctionReferenceWithArgHash, +) from .metadata import Memento # The set of known backends. Register new backends using StorageBackend.register. @@ -78,7 +81,9 @@ def get_memento(self, fn: FunctionReferenceWithArgHash) -> Memento: return self.get_mementos([fn])[0] @abstractmethod - def get_mementos(self, fns: List[FunctionReferenceWithArgHash]) -> List[Optional[Memento]]: + def get_mementos( + self, fns: List[FunctionReferenceWithArgHash] + ) -> List[Optional[Memento]]: """ For each invocation in the list, returns the call's Memento (memoization metadata) if the function is memoized for the given arg hash. @@ -112,8 +117,12 @@ def make_url_for_result(self, memento: Memento) -> Optional[str]: pass @abstractmethod - def read_metadata(self, fn_with_arg_hash: FunctionReferenceWithArgHash, key: str, - retry_on_none=False) -> Optional[bytes]: + def read_metadata( + self, + fn_with_arg_hash: FunctionReferenceWithArgHash, + key: str, + retry_on_none=False, + ) -> Optional[bytes]: """ Read custom metadata for the given arguments from the given key. This is useful, for example, for reading logs @@ -130,9 +139,13 @@ def read_metadata(self, fn_with_arg_hash: FunctionReferenceWithArgHash, key: str pass @abstractmethod - def write_metadata(self, fn_with_arg_hash: FunctionReferenceWithArgHash, key: str, - value: bytes, - store_with_content_key: Optional[VersionedDataSourceKey] = None): + def write_metadata( + self, + fn_with_arg_hash: FunctionReferenceWithArgHash, + key: str, + value: bytes, + store_with_content_key: Optional[VersionedDataSourceKey] = None, + ): """ Write custom metadata for the given arguments for the given key. This is useful, for example, for writing logs @@ -198,7 +211,9 @@ def list_mementos(self, fn: FunctionReference, limit: int = None) -> List[Mement pass @abstractmethod - def memoize(self, key_override: Optional[str], memento: Memento, result: object) -> None: + def memoize( + self, key_override: Optional[str], memento: Memento, result: object + ) -> None: """ Remember the result. As a side effect of this call, the content_hash is computed and set in the provided memento object. @@ -260,7 +275,9 @@ def create(cls, storage_type, config): global _registered_storage_backends if storage_type not in _registered_storage_backends: - raise ValueError("Unrecognized storage backend type {}".format(storage_type)) + raise ValueError( + "Unrecognized storage backend type {}".format(storage_type) + ) return _registered_storage_backends.get(storage_type)(config) @classmethod diff --git a/twosigma/memento/storage_base.py b/twosigma/memento/storage_base.py index 06327fc..2d0eaab 100644 --- a/twosigma/memento/storage_base.py +++ b/twosigma/memento/storage_base.py @@ -34,11 +34,18 @@ from .logging import log from .metadata import Memento, ResultType from .partition import Partition, InMemoryPartition -from .reference import FunctionReference, FunctionReferenceWithArgHash, \ - FunctionReferenceWithArguments +from .reference import ( + FunctionReference, + FunctionReferenceWithArgHash, + FunctionReferenceWithArguments, +) from .serialization import MementoCodec -from .types import FunctionNotFoundError, VersionedDataSourceKey, DataSourceKey,\ - ContentAddressableHash +from .types import ( + FunctionNotFoundError, + VersionedDataSourceKey, + DataSourceKey, + ContentAddressableHash, +) # The set of known codecs. Register new codecs using Codec.register. _registered_codecs = {} @@ -65,12 +72,13 @@ ResultType.index: "pickle", ResultType.series: "pickle", ResultType.data_frame: "pickle", - ResultType.partition: "partition" + ResultType.partition: "partition", } # type: Dict[ResultType, str] -_ResultTypeAndContentKey = namedtuple("_ResultTypeAndContentKey", - ["result_type", "content_key", "from_parent"]) +_ResultTypeAndContentKey = namedtuple( + "_ResultTypeAndContentKey", ["result_type", "content_key", "from_parent"] +) class DataSource(ABC): @@ -123,7 +131,9 @@ def input_versioned(self, key: VersionedDataSourceKey) -> BytesIO: pass @abstractmethod - def input_metadata(self, content_key: VersionedDataSourceKey, metadata_key: str) -> BytesIO: + def input_metadata( + self, content_key: VersionedDataSourceKey, metadata_key: str + ) -> BytesIO: """ Return a binary file-like object that reads data for the given metadata key for a given content key. This is for metadata that is stored in the object store @@ -150,9 +160,12 @@ def output(self, key: DataSourceKey, data: BytesIO) -> VersionedDataSourceKey: pass @abstractmethod - def reference(self, src_data_source: "DataSource", - src_key: VersionedDataSourceKey, - target_key: VersionedDataSourceKey) -> VersionedDataSourceKey: + def reference( + self, + src_data_source: "DataSource", + src_key: VersionedDataSourceKey, + target_key: VersionedDataSourceKey, + ) -> VersionedDataSourceKey: """ Mark a reference to the given data source key coming from the given data source. This is a NOP for data sources that do not perform reference counting for garbage @@ -161,8 +174,9 @@ def reference(self, src_data_source: "DataSource", pass @abstractmethod - def output_metadata(self, content_key: VersionedDataSourceKey, metadata_key: str, - value: bytes): + def output_metadata( + self, content_key: VersionedDataSourceKey, metadata_key: str, value: bytes + ): """ Store data for the given metadata key for a given content key. This is for metadata that is stored in the object store alongside the object. @@ -233,9 +247,14 @@ def all_exist_nonversioned(self, keys: List[DataSourceKey]) -> List[bool]: pass @abstractmethod - def list_keys_nonversioned(self, directory: DataSourceKey, file_prefix: str = "", - recursive: bool = False, limit: int = None, - endswith: str = None) -> Iterable[DataSourceKey]: + def list_keys_nonversioned( + self, + directory: DataSourceKey, + file_prefix: str = "", + recursive: bool = False, + limit: int = None, + endswith: str = None, + ) -> Iterable[DataSourceKey]: """ List the keys in the path provided by the 'prefix' parameter. @@ -281,18 +300,22 @@ def load(self, data_source: DataSource, key: VersionedDataSourceKey) -> object: pass @staticmethod - def make_url_for_key(data_source: DataSource, key: VersionedDataSourceKey) -> str: + def make_url_for_key( + data_source: DataSource, key: VersionedDataSourceKey + ) -> str: return data_source.make_url_for_key(key) @abstractmethod - def store(self, data_source: DataSource, key_override: str, obj: object) -> \ - VersionedDataSourceKey: + def store( + self, data_source: DataSource, key_override: str, obj: object + ) -> VersionedDataSourceKey: """Store the object to the store and return the key under which the data was stored""" pass @staticmethod - def output_key_for_content_key(content_key: ContentAddressableHash) -> \ - DataSourceKey: + def output_key_for_content_key( + content_key: ContentAddressableHash, + ) -> DataSourceKey: """Converts a content key to the key at which it should be stored""" return DataSourceKey("c/{}".format(content_key.key)) @@ -315,8 +338,9 @@ class NullStrategy(Strategy): def __init__(self): super().__init__() - def store(self, data_source: DataSource, key_override: str, obj: object) -> \ - Optional[VersionedDataSourceKey]: + def store( + self, data_source: DataSource, key_override: str, obj: object + ) -> Optional[VersionedDataSourceKey]: if key_override: key = self.output_key_for_override_key(key_override) data_source.delete_nonversioned_key(key) @@ -339,8 +363,9 @@ class BlobStrategy(Strategy): def __init__(self): super(Codec.BlobStrategy, self).__init__() - def store(self, data_source: DataSource, key_override: str, obj: object) -> \ - VersionedDataSourceKey: + def store( + self, data_source: DataSource, key_override: str, obj: object + ) -> VersionedDataSourceKey: data = self.encode(obj) content_hash = hashlib.sha256(data).hexdigest() key = self.output_key_for_content_key(ContentAddressableHash(content_hash)) @@ -364,8 +389,12 @@ def __init__(self, config: dict, strategy: Dict[ResultType, Type[Strategy]]): self.config = config self._strategy = strategy - def load(self, result_type: ResultType, data_source: DataSource, - key: VersionedDataSourceKey) -> object: + def load( + self, + result_type: ResultType, + data_source: DataSource, + key: VersionedDataSourceKey, + ) -> object: """ Loads a memento function result from the given data source and returns the result as a file-like object. @@ -373,8 +402,12 @@ def load(self, result_type: ResultType, data_source: DataSource, """ return self._strategy[result_type].load(data_source, key) - def make_url_for_result(self, result_type: ResultType, data_source: DataSource, - key: VersionedDataSourceKey) -> str: + def make_url_for_result( + self, + result_type: ResultType, + data_source: DataSource, + key: VersionedDataSourceKey, + ) -> str: """ Returns the url from which the memento function result from the given data source can be retrieved. @@ -382,8 +415,13 @@ def make_url_for_result(self, result_type: ResultType, data_source: DataSource, """ return self._strategy[result_type].make_url_for_key(data_source, key) - def store(self, result_type: ResultType, data_source: DataSource, key_override: Optional[str], - obj: object) -> VersionedDataSourceKey: + def store( + self, + result_type: ResultType, + data_source: DataSource, + key_override: Optional[str], + obj: object, + ) -> VersionedDataSourceKey: """ Encode and store the memento function result in the provided object to the datastore, returning the key under which the data was stored. @@ -433,29 +471,32 @@ class DefaultCodec(Codec): """ def __init__(self, config): - super().__init__(config, { - ResultType.exception: self.JsonExceptionStrategy(), - ResultType.null: self.NullStrategy(), - ResultType.boolean: self.ValuePickleStrategy(), - ResultType.string: self.ValuePickleStrategy(), - ResultType.binary: self.ValuePickleStrategy(), - ResultType.number: self.ValuePickleStrategy(), - ResultType.date: self.ValuePickleStrategy(), - ResultType.timestamp: self.ValuePickleStrategy(), - ResultType.list_result: self.ValuePickleStrategy(), - ResultType.dictionary: self.ValuePickleStrategy(), - ResultType.array_boolean: self.ValuePickleStrategy(), - ResultType.array_int8: self.ValuePickleStrategy(), - ResultType.array_int16: self.ValuePickleStrategy(), - ResultType.array_int32: self.ValuePickleStrategy(), - ResultType.array_int64: self.ValuePickleStrategy(), - ResultType.array_float32: self.ValuePickleStrategy(), - ResultType.array_float64: self.ValuePickleStrategy(), - ResultType.index: self.ValuePickleStrategy(), - ResultType.series: self.ValuePickleStrategy(), - ResultType.data_frame: self.ValuePickleStrategy(), - ResultType.partition: self.PicklePartitionStrategy(self) - }) + super().__init__( + config, + { + ResultType.exception: self.JsonExceptionStrategy(), + ResultType.null: self.NullStrategy(), + ResultType.boolean: self.ValuePickleStrategy(), + ResultType.string: self.ValuePickleStrategy(), + ResultType.binary: self.ValuePickleStrategy(), + ResultType.number: self.ValuePickleStrategy(), + ResultType.date: self.ValuePickleStrategy(), + ResultType.timestamp: self.ValuePickleStrategy(), + ResultType.list_result: self.ValuePickleStrategy(), + ResultType.dictionary: self.ValuePickleStrategy(), + ResultType.array_boolean: self.ValuePickleStrategy(), + ResultType.array_int8: self.ValuePickleStrategy(), + ResultType.array_int16: self.ValuePickleStrategy(), + ResultType.array_int32: self.ValuePickleStrategy(), + ResultType.array_int64: self.ValuePickleStrategy(), + ResultType.array_float32: self.ValuePickleStrategy(), + ResultType.array_float64: self.ValuePickleStrategy(), + ResultType.index: self.ValuePickleStrategy(), + ResultType.series: self.ValuePickleStrategy(), + ResultType.data_frame: self.ValuePickleStrategy(), + ResultType.partition: self.PicklePartitionStrategy(self), + }, + ) class JsonExceptionStrategy(Codec.BlobStrategy): """JSON encoding, returning as a MementoException""" @@ -463,19 +504,25 @@ class JsonExceptionStrategy(Codec.BlobStrategy): def __init__(self): super().__init__() - def load(self, data_source: DataSource, key: VersionedDataSourceKey) -> MementoException: + def load( + self, data_source: DataSource, key: VersionedDataSourceKey + ) -> MementoException: with data_source.input_versioned(key) as f: with io.TextIOWrapper(f, encoding="utf-8") as t: result_dict = json.load(t) - return MementoException(result_dict["exception_name"], result_dict["message"], - result_dict["stack_trace"]) + return MementoException( + result_dict["exception_name"], + result_dict["message"], + result_dict["stack_trace"], + ) def encode(self, obj: MementoException) -> bytes: store_dict = { "exception_name": obj.exception_name, "message": obj.message, "stack_trace": "".join( - traceback.format_exception(type(obj), obj, obj.__traceback__)) + traceback.format_exception(type(obj), obj, obj.__traceback__) + ), } return json.dumps(store_dict).encode("utf-8") @@ -494,8 +541,12 @@ class PicklePartition(Partition): _index = None # type: Dict[str, _ResultTypeAndContentKey] """""" - def __init__(self, - codec: Codec, data_source: DataSource, base_key: VersionedDataSourceKey): + def __init__( + self, + codec: Codec, + data_source: DataSource, + base_key: VersionedDataSourceKey, + ): super().__init__() self._codec = codec self._data_source = data_source @@ -509,8 +560,10 @@ def _deserialize_index(data: bytes) -> Dict[str, _ResultTypeAndContentKey]: return { k: _ResultTypeAndContentKey( result_type=ResultType[v["result_type"]], - content_key=MementoCodec.decode_versioned_data_source_key(v["content_key"]), - from_parent=v.get("from_parent", False) + content_key=MementoCodec.decode_versioned_data_source_key( + v["content_key"] + ), + from_parent=v.get("from_parent", False), ) for (k, v) in d.items() } @@ -520,8 +573,10 @@ def _serialize_index(index: Dict[str, _ResultTypeAndContentKey]) -> bytes: d = { k: { "result_type": v.result_type.name, - "content_key": MementoCodec.encode_versioned_data_source_key(v.content_key), - "from_parent": v.from_parent + "content_key": MementoCodec.encode_versioned_data_source_key( + v.content_key + ), + "from_parent": v.from_parent, } for (k, v) in index.items() } @@ -529,16 +584,23 @@ def _serialize_index(index: Dict[str, _ResultTypeAndContentKey]) -> bytes: def get(self, key: str) -> object: if key not in self._index: - raise ValueError("Key '{}' is not in key list for partition".format(key)) + raise ValueError( + "Key '{}' is not in key list for partition".format(key) + ) index_entry = self._index[key] - return self._codec.load(index_entry.result_type, self._data_source, - index_entry.content_key) + return self._codec.load( + index_entry.result_type, self._data_source, index_entry.content_key + ) def list_keys(self, _include_merge_parent: bool = True) -> Iterable[str]: if _include_merge_parent: keys = self._index.keys() else: - keys = [key for key in self._index.keys() if not self._index.get(key).from_parent] + keys = [ + key + for key in self._index.keys() + if not self._index.get(key).from_parent + ] return sorted(keys) class PicklePartitionStrategy(Codec.BlobStrategy): @@ -556,11 +618,14 @@ def __init__(self, codec: Codec): super().__init__() self._codec = codec - def load(self, data_source: DataSource, key: VersionedDataSourceKey) -> Partition: + def load( + self, data_source: DataSource, key: VersionedDataSourceKey + ) -> Partition: return DefaultCodec.PicklePartition(self._codec, data_source, key) - def store(self, data_source: DataSource, key_override: str, obj: Partition) -> \ - VersionedDataSourceKey: + def store( + self, data_source: DataSource, key_override: str, obj: Partition + ) -> VersionedDataSourceKey: # build a dict of key to result type index = dict() # type: Dict[str, _ResultTypeAndContentKey] @@ -569,28 +634,37 @@ def store(self, data_source: DataSource, key_override: str, obj: Partition) -> \ merge_parent = obj._merge_parent if merge_parent: if isinstance(merge_parent, DefaultCodec.PicklePartition): - pickle_partition_parent = cast(DefaultCodec.PicklePartition, merge_parent) + pickle_partition_parent = cast( + DefaultCodec.PicklePartition, merge_parent + ) # noinspection PyProtectedMember parent_index = pickle_partition_parent._index # noinspection PyProtectedMember parent_data_source = pickle_partition_parent._data_source - elif hasattr(merge_parent, "_output_keys") \ - and hasattr(merge_parent, "_data_source"): + elif hasattr(merge_parent, "_output_keys") and hasattr( + merge_parent, "_data_source" + ): # noinspection PyProtectedMember parent_index = merge_parent._output_keys # noinspection PyProtectedMember parent_data_source = merge_parent._data_source else: - raise IOError("Could not merge partitions: parent is not " - "a PicklePartition or has never been serialized") + raise IOError( + "Could not merge partitions: parent is not " + "a PicklePartition or has never been serialized" + ) for k, v in parent_index.items(): # Mark a reference to all values that come from parents. This is for storage # backends that do reference counting. - data_source.reference(parent_data_source, v.content_key, v.content_key) + data_source.reference( + parent_data_source, v.content_key, v.content_key + ) - index[k] = _ResultTypeAndContentKey(result_type=v.result_type, - content_key=v.content_key, - from_parent=True) + index[k] = _ResultTypeAndContentKey( + result_type=v.result_type, + content_key=v.content_key, + from_parent=True, + ) # Layer current keys on top of parent's keys output_keys = dict() @@ -600,15 +674,21 @@ def store(self, data_source: DataSource, key_override: str, obj: Partition) -> \ result_type = ResultType.from_object(result) # Store - content_key_override = "{}/{}".format(key_override, k)\ - if key_override is not None else None - partition_content_key = self._codec.store(result_type, data_source, - content_key_override, result) + content_key_override = ( + "{}/{}".format(key_override, k) + if key_override is not None + else None + ) + partition_content_key = self._codec.store( + result_type, data_source, content_key_override, result + ) # Update map of key to result info - index_entry = _ResultTypeAndContentKey(result_type=result_type, - content_key=partition_content_key, - from_parent=False) + index_entry = _ResultTypeAndContentKey( + result_type=result_type, + content_key=partition_content_key, + from_parent=False, + ) output_keys[k] = index_entry index[k] = index_entry @@ -621,10 +701,14 @@ def store(self, data_source: DataSource, key_override: str, obj: Partition) -> \ # noinspection PyProtectedMember obj._index_bytes = DefaultCodec.PicklePartition._serialize_index(index) - index_key_override = "{}/index.json".format(key_override) \ - if key_override is not None else None - return super().store(data_source=data_source, - key_override=index_key_override, obj=obj) + index_key_override = ( + "{}/index.json".format(key_override) + if key_override is not None + else None + ) + return super().store( + data_source=data_source, key_override=index_key_override, obj=obj + ) def encode(self, obj: Partition) -> bytes: # noinspection PyProtectedMember @@ -651,6 +735,7 @@ class ResultIsWithData: """ Marker class to indicate that metadata is stored with the data. """ + pass @@ -665,7 +750,9 @@ def __init__(self): pass @abstractmethod - def get_mementos(self, fns: List[FunctionReferenceWithArgHash]) -> List[Optional[Memento]]: + def get_mementos( + self, fns: List[FunctionReferenceWithArgHash] + ) -> List[Optional[Memento]]: """ See StorageBackend.get_mementos() @@ -690,7 +777,9 @@ def list_functions(self) -> List[FunctionReference]: pass @abstractmethod - def list_mementos(self, fn: FunctionReference, limit: Optional[int]) -> List[Memento]: + def list_mementos( + self, fn: FunctionReference, limit: Optional[int] + ) -> List[Memento]: """ List all mementos for the given function @@ -709,8 +798,12 @@ def put_memento(self, memento: Memento): pass @abstractmethod - def read_metadata(self, fn_with_arg_hash: FunctionReferenceWithArgHash, - key: str, retry_on_none=False) -> Optional[Union[bytes, ResultIsWithData]]: + def read_metadata( + self, + fn_with_arg_hash: FunctionReferenceWithArgHash, + key: str, + retry_on_none=False, + ) -> Optional[Union[bytes, ResultIsWithData]]: """ Read metadata (e.g. logs) associated with the Memento for the provided `fn_with_arg_hash`. @@ -729,8 +822,13 @@ def read_metadata(self, fn_with_arg_hash: FunctionReferenceWithArgHash, pass @abstractmethod - def write_metadata(self, fn_with_arg_hash: FunctionReferenceWithArgHash, - key: str, value: bytes, stored_with_data: bool): + def write_metadata( + self, + fn_with_arg_hash: FunctionReferenceWithArgHash, + key: str, + value: bytes, + stored_with_data: bool, + ): """ Write metadata (e.g. logs) associated with the Memento for the given `fn_with_arg_hash` @@ -779,41 +877,65 @@ def __init__(self, data_source: DataSource): @staticmethod def _get_function_path(fn_reference: FunctionReference) -> DataSourceKey: - return DataSourceKey("{}/{}".format(DataSourceMetadataSource._function_path_prefix.key, - fn_reference.qualified_name)) + return DataSourceKey( + "{}/{}".format( + DataSourceMetadataSource._function_path_prefix.key, + fn_reference.qualified_name, + ) + ) @staticmethod def _get_path(fn_reference: FunctionReference, arg_hash: str) -> str: - return "{}/{}".format(DataSourceMetadataSource._get_function_path(fn_reference).key, - arg_hash) + return "{}/{}".format( + DataSourceMetadataSource._get_function_path(fn_reference).key, arg_hash + ) @staticmethod - def _get_metadata_path(fn_with_arg_hash: FunctionReferenceWithArgHash) -> DataSourceKey: - return DataSourceKey("{}.memento.json".format(DataSourceMetadataSource._get_path( - fn_with_arg_hash.fn_reference, fn_with_arg_hash.arg_hash))) + def _get_metadata_path( + fn_with_arg_hash: FunctionReferenceWithArgHash, + ) -> DataSourceKey: + return DataSourceKey( + "{}.memento.json".format( + DataSourceMetadataSource._get_path( + fn_with_arg_hash.fn_reference, fn_with_arg_hash.arg_hash + ) + ) + ) @staticmethod - def _get_metadata_key(fn_with_arg_hash: FunctionReferenceWithArgHash, key: str, - stored_with_data: bool) -> DataSourceKey: + def _get_metadata_key( + fn_with_arg_hash: FunctionReferenceWithArgHash, key: str, stored_with_data: bool + ) -> DataSourceKey: return DataSourceKey( - "{}.metadata.{}{}".format(DataSourceMetadataSource._get_path( - fn_with_arg_hash.fn_reference, fn_with_arg_hash.arg_hash), key, - ".with_data" if stored_with_data else "")) + "{}.metadata.{}{}".format( + DataSourceMetadataSource._get_path( + fn_with_arg_hash.fn_reference, fn_with_arg_hash.arg_hash + ), + key, + ".with_data" if stored_with_data else "", + ) + ) def _read_memento(self, path: DataSourceKey) -> Memento: with self.data_source.input_nonversioned(path) as f: with TextIOWrapper(f, encoding="utf-8") as t: return MementoCodec.decode_memento(json.load(t)) - def get_mementos(self, fns: List[FunctionReferenceWithArgHash]) -> List[Optional[Memento]]: + def get_mementos( + self, fns: List[FunctionReferenceWithArgHash] + ) -> List[Optional[Memento]]: results = [] for fn_ref_with_arg_hash in fns: - metadata_path = DataSourceMetadataSource._get_metadata_path(fn_ref_with_arg_hash) + metadata_path = DataSourceMetadataSource._get_metadata_path( + fn_ref_with_arg_hash + ) try: memento = self._read_memento(metadata_path) except FunctionNotFoundError as e: - log.debug("Ignoring memoized result: while decoding Memento for {}, " - "could not find function: {}".format(fn_ref_with_arg_hash, e)) + log.debug( + "Ignoring memoized result: while decoding Memento for {}, " + "could not find function: {}".format(fn_ref_with_arg_hash, e) + ) memento = None except IOError: memento = None @@ -828,32 +950,48 @@ def list_functions(self) -> List[FunctionReference]: fpp = DataSourceMetadataSource._function_path_prefix return [ FunctionReference.from_qualified_name( - x.key[len(DataSourceMetadataSource._function_path_prefix) + 1:]) for - x in self.data_source.list_keys_nonversioned(directory=fpp, file_prefix="", - recursive=False) + x.key[len(DataSourceMetadataSource._function_path_prefix) + 1 :] + ) + for x in self.data_source.list_keys_nonversioned( + directory=fpp, file_prefix="", recursive=False + ) ] - def list_mementos(self, fn: FunctionReference, limit: Optional[int]) -> List[Memento]: + def list_mementos( + self, fn: FunctionReference, limit: Optional[int] + ) -> List[Memento]: path = DataSourceMetadataSource._get_function_path(fn) result = [] - for key in self.data_source.list_keys_nonversioned(directory=path, file_prefix="", - recursive=False, limit=limit, - endswith=".memento.json"): + for key in self.data_source.list_keys_nonversioned( + directory=path, + file_prefix="", + recursive=False, + limit=limit, + endswith=".memento.json", + ): result.append(self._read_memento(key)) return result def put_memento(self, memento: Memento): metadata_path = DataSourceMetadataSource._get_metadata_path( - memento.invocation_metadata.fn_reference_with_args.fn_reference_with_arg_hash()) + memento.invocation_metadata.fn_reference_with_args.fn_reference_with_arg_hash() + ) log.debug("Writing metadata to {}...".format(metadata_path)) memento_json = json.dumps(MementoCodec.encode_memento(memento)) self.data_source.output(metadata_path, io.BytesIO(memento_json.encode("utf-8"))) - def read_metadata(self, fn_with_arg_hash: FunctionReferenceWithArgHash, - key: str, retry_on_none=False) -> Optional[Union[bytes, ResultIsWithData]]: - metadata_key = DataSourceMetadataSource._get_metadata_key(fn_with_arg_hash, key, False) - metadata_key_with_data =\ - DataSourceMetadataSource._get_metadata_key(fn_with_arg_hash, key, True) + def read_metadata( + self, + fn_with_arg_hash: FunctionReferenceWithArgHash, + key: str, + retry_on_none=False, + ) -> Optional[Union[bytes, ResultIsWithData]]: + metadata_key = DataSourceMetadataSource._get_metadata_key( + fn_with_arg_hash, key, False + ) + metadata_key_with_data = DataSourceMetadataSource._get_metadata_key( + fn_with_arg_hash, key, True + ) def data_source_exists(): retries = 3 if retry_on_none else 1 @@ -876,20 +1014,30 @@ def data_source_exists(): else: return None - def write_metadata(self, fn_with_arg_hash: FunctionReferenceWithArgHash, - key: str, value: bytes, stored_with_data: bool): - metadata_key = DataSourceMetadataSource._get_metadata_key(fn_with_arg_hash, key, - stored_with_data) + def write_metadata( + self, + fn_with_arg_hash: FunctionReferenceWithArgHash, + key: str, + value: bytes, + stored_with_data: bool, + ): + metadata_key = DataSourceMetadataSource._get_metadata_key( + fn_with_arg_hash, key, stored_with_data + ) log.debug("Writing metadata to key {}...".format(metadata_key)) - self.data_source.output(metadata_key, io.BytesIO(bytes() if stored_with_data else value)) + self.data_source.output( + metadata_key, io.BytesIO(bytes() if stored_with_data else value) + ) def forget_call(self, fn_with_arg_hash: FunctionReferenceWithArgHash): - call_path_prefix = DataSourceMetadataSource._get_path(fn_with_arg_hash.fn_reference, - fn_with_arg_hash.arg_hash) + call_path_prefix = DataSourceMetadataSource._get_path( + fn_with_arg_hash.fn_reference, fn_with_arg_hash.arg_hash + ) for key in self.data_source.list_keys_nonversioned( - directory=DataSourceKey(os.path.dirname(call_path_prefix)), - file_prefix=os.path.basename(call_path_prefix), - recursive=False): + directory=DataSourceKey(os.path.dirname(call_path_prefix)), + file_prefix=os.path.basename(call_path_prefix), + recursive=False, + ): self.data_source.delete_all_versions(key, False) def forget_everything(self): @@ -904,6 +1052,7 @@ class _CacheEntry: """ Each entry stores a memento and, optionally, a value. Note that `None` is a valid value. """ + obj_size = None # type: int memento = None # type: Memento value = None # type: object @@ -921,6 +1070,7 @@ class MemoryCache: Write-through memory cache for memoized data """ + memory_cache_bytes = None # type: int memory_usage = None # type: int lru_deque = None # type: deque @@ -943,7 +1093,9 @@ def _pd_mem_usage(obj: Union[pd.DataFrame, pd.Series]) -> int: raise TypeError("Can't compute the memory usage for type {}".format(type(obj))) @staticmethod - def _pd_linreg_mem_usage(obj: Union[pd.DataFrame, pd.Series], sample_size: int = 100) -> int: + def _pd_linreg_mem_usage( + obj: Union[pd.DataFrame, pd.Series], sample_size: int = 100 + ) -> int: """ A good estimate of memory usage for DataFrames and Series without taking too much time. Methodology splits the sample unevenly, deeply computes the memory usage of each split, @@ -994,8 +1146,11 @@ def _estimate_object_size(obj: object) -> int: result = 0 if hasattr(obj, "__len__"): length = len(obj) - result += length * MemoryCache._estimate_object_size( - next(iter(obj))) if length > 0 else 0 + result += ( + length * MemoryCache._estimate_object_size(next(iter(obj))) + if length > 0 + else 0 + ) else: i = obj # type: Iterable result += sum([MemoryCache._estimate_object_size(x) for x in i]) @@ -1006,7 +1161,8 @@ def _estimate_object_size(obj: object) -> int: def _cache_key_for_memento(memento: Memento) -> str: return MemoryCache._cache_key_for_fn( memento.invocation_metadata.fn_reference_with_args.fn_reference, - memento.invocation_metadata.fn_reference_with_args.arg_hash) + memento.invocation_metadata.fn_reference_with_args.arg_hash, + ) @staticmethod def _cache_key_for_fn(fn_ref: FunctionReference, arg_hash: str) -> str: @@ -1028,7 +1184,9 @@ def _evict(self, cache_key: str): if cache_key in self.lru_deque: self.lru_deque.remove(cache_key) - def get_mementos(self, fns: List[FunctionReferenceWithArgHash]) -> List[Optional[Memento]]: + def get_mementos( + self, fns: List[FunctionReferenceWithArgHash] + ) -> List[Optional[Memento]]: result = [] for fn in fns: cache_key = self._cache_key_for_fn(fn.fn_reference, fn.arg_hash) @@ -1091,7 +1249,10 @@ def put(self, memento: Memento, result: object, has_result: bool): self._evict(cache_key) # Free up memory in the cache (if needed) by discarding LRU - while len(self.lru_deque) > 0 and self.memory_usage + obj_size > self.memory_cache_bytes: + while ( + len(self.lru_deque) > 0 + and self.memory_usage + obj_size > self.memory_cache_bytes + ): self._evict(self.lru_deque.popleft()) # Add to cache @@ -1101,8 +1262,9 @@ def put(self, memento: Memento, result: object, has_result: bool): self.memory_usage += obj_size def forget_call(self, fn_with_arg_hash: FunctionReferenceWithArgHash): - cache_key = self._cache_key_for_fn(fn_with_arg_hash.fn_reference, - fn_with_arg_hash.arg_hash) + cache_key = self._cache_key_for_fn( + fn_with_arg_hash.fn_reference, fn_with_arg_hash.arg_hash + ) self.refs.pop(cache_key, None) self._evict(cache_key) @@ -1117,10 +1279,14 @@ def forget_function(self, fn_reference: FunctionReference): qualified_name_slash = qualified_name + "/" # This is O(n) but should be a rare operation and spares us the complexity of maintaining # a second map - ref_list = [key for key in self.refs.keys() if key.startswith(qualified_name_slash)] + ref_list = [ + key for key in self.refs.keys() if key.startswith(qualified_name_slash) + ] for key in ref_list: del self.refs[key] - evict_list = [key for key in self.cache.keys() if key.startswith(qualified_name_slash)] + evict_list = [ + key for key in self.cache.keys() if key.startswith(qualified_name_slash) + ] for key in evict_list: self._evict(key) @@ -1140,8 +1306,15 @@ class StorageBackendBase(StorageBackend, ABC): _memory_cache = None # type: MemoryCache codec = None # type: Codec - def __init__(self, storage_type: str, data_source: DataSource, metadata_source: MetadataSource, - memory_cache_mb: int = None, config: dict = None, read_only: bool = None): + def __init__( + self, + storage_type: str, + data_source: DataSource, + metadata_source: MetadataSource, + memory_cache_mb: int = None, + config: dict = None, + read_only: bool = None, + ): super().__init__(storage_type, config=config, read_only=read_only) self._data_source = data_source self._metadata_source = metadata_source @@ -1157,10 +1330,15 @@ def _get_function_path(fn_reference: FunctionReference) -> str: @staticmethod def _get_path(fn_reference: FunctionReference, arg_hash: str) -> DataSourceKey: - return DataSourceKey("{}/{}".format(StorageBackendBase._get_function_path( - fn_reference), arg_hash)) - - def get_mementos(self, fns: List[FunctionReferenceWithArgHash]) -> List[Optional[Memento]]: + return DataSourceKey( + "{}/{}".format( + StorageBackendBase._get_function_path(fn_reference), arg_hash + ) + ) + + def get_mementos( + self, fns: List[FunctionReferenceWithArgHash] + ) -> List[Optional[Memento]]: # First, consult cache if self._memory_cache: cache_result = self._memory_cache.get_mementos(fns) @@ -1196,16 +1374,22 @@ def read_result(self, memento: Memento) -> object: # Fall back to storage pass - result = self.codec.load(memento.invocation_metadata.result_type, self._data_source, - memento.content_key) + result = self.codec.load( + memento.invocation_metadata.result_type, + self._data_source, + memento.content_key, + ) if self._memory_cache: self._memory_cache.put(memento, result, has_result=True) return result def make_url_for_result(self, memento: Memento) -> str: - return self.codec.make_url_for_result(memento.invocation_metadata.result_type, - self._data_source, memento.content_key) + return self.codec.make_url_for_result( + memento.invocation_metadata.result_type, + self._data_source, + memento.content_key, + ) def is_memoized(self, fn_reference: FunctionReference, arg_hash: str) -> bool: if self._memory_cache: @@ -1213,7 +1397,8 @@ def is_memoized(self, fn_reference: FunctionReference, arg_hash: str) -> bool: return True # if not in memory cache, fall back to storage return self._metadata_source.all_mementos_exist( - [FunctionReferenceWithArgHash(fn_reference, arg_hash)]) + [FunctionReferenceWithArgHash(fn_reference, arg_hash)] + ) def is_all_memoized(self, fns: Iterable[FunctionReferenceWithArguments]) -> bool: if self._memory_cache: @@ -1221,7 +1406,8 @@ def is_all_memoized(self, fns: Iterable[FunctionReferenceWithArguments]) -> bool return True # if not in memory cache, fall back to storage return self._metadata_source.all_mementos_exist( - [fn.fn_reference_with_arg_hash() for fn in fns]) + [fn.fn_reference_with_arg_hash() for fn in fns] + ) def list_functions(self) -> List[FunctionReference]: # Do not consult cache since it would not give us the full picture @@ -1241,7 +1427,9 @@ def memoize(self, key_override: str, memento: Memento, result: object) -> None: # Write data result_type = memento.invocation_metadata.result_type - content_key = self.codec.store(result_type, self._data_source, key_override, result) + content_key = self.codec.store( + result_type, self._data_source, key_override, result + ) log.debug("Wrote data to {}".format(content_key)) assert (result_type == ResultType.null) or (content_key is not None) memento.content_key = content_key @@ -1249,33 +1437,46 @@ def memoize(self, key_override: str, memento: Memento, result: object) -> None: # Write metadata self._metadata_source.put_memento(memento) - def read_metadata(self, fn_with_arg_hash: FunctionReferenceWithArgHash, key: str, - retry_on_none=False) -> bytes: + def read_metadata( + self, + fn_with_arg_hash: FunctionReferenceWithArgHash, + key: str, + retry_on_none=False, + ) -> bytes: # Metadata is not currently cached - result = self._metadata_source.read_metadata(fn_with_arg_hash, key, - retry_on_none=retry_on_none) + result = self._metadata_source.read_metadata( + fn_with_arg_hash, key, retry_on_none=retry_on_none + ) if isinstance(result, ResultIsWithData): # Metadata is stored alongside the data object memento = self.get_mementos([fn_with_arg_hash])[0] if memento is None: - raise IOError("Metadata shows metadata should exist with data, but could " - "not retrieve Memento: {}".format(fn_with_arg_hash)) + raise IOError( + "Metadata shows metadata should exist with data, but could " + "not retrieve Memento: {}".format(fn_with_arg_hash) + ) result = self._data_source.input_metadata(memento.content_key, key) return result - def write_metadata(self, fn_with_arg_hash: FunctionReferenceWithArgHash, - key: str, value: bytes, - store_with_content_key: Optional[VersionedDataSourceKey] = None): + def write_metadata( + self, + fn_with_arg_hash: FunctionReferenceWithArgHash, + key: str, + value: bytes, + store_with_content_key: Optional[VersionedDataSourceKey] = None, + ): assert fn_with_arg_hash is not None if not self.read_only: if store_with_content_key: - self._metadata_source.write_metadata(fn_with_arg_hash, key, bytes(), - stored_with_data=True) + self._metadata_source.write_metadata( + fn_with_arg_hash, key, bytes(), stored_with_data=True + ) self._data_source.output_metadata(store_with_content_key, key, value) else: - self._metadata_source.write_metadata(fn_with_arg_hash, key, value, - stored_with_data=False) + self._metadata_source.write_metadata( + fn_with_arg_hash, key, value, stored_with_data=False + ) else: raise ValueError("Cannot write metadata to a read-only storage backend") diff --git a/twosigma/memento/storage_filesystem.py b/twosigma/memento/storage_filesystem.py index ccfece9..4a03cff 100644 --- a/twosigma/memento/storage_filesystem.py +++ b/twosigma/memento/storage_filesystem.py @@ -40,8 +40,13 @@ from .metadata import ResultType from .partition import Partition from .storage import StorageBackend -from .storage_base import DataSource, DataSourceMetadataSource, StorageBackendBase, DefaultCodec, \ - Codec +from .storage_base import ( + DataSource, + DataSourceMetadataSource, + StorageBackendBase, + DefaultCodec, + Codec, +) from .types import DataSourceKey, VersionedDataSourceKey @@ -78,12 +83,16 @@ def _write_non_versioned_link(self, versioned_key: VersionedDataSourceKey): f.write(str(versioned_path)) def _delete_non_versioned_link(self, key: DataSourceKey): - non_versioned_path = self._get_non_versioned_link_path(self._escape_key(key.key)) + non_versioned_path = self._get_non_versioned_link_path( + self._escape_key(key.key) + ) if os.path.isfile(non_versioned_path): os.unlink(non_versioned_path) def _read_non_versioned_link(self, key: DataSourceKey) -> Path: - non_versioned_path = self._get_non_versioned_link_path(self._escape_key(key.key)) + non_versioned_path = self._get_non_versioned_link_path( + self._escape_key(key.key) + ) with open(str(non_versioned_path), "r") as f: versioned_path = Path(f.read()) return versioned_path @@ -93,8 +102,9 @@ def _get_versions_directory(self, key: DataSourceKey) -> Path: dirname = path.parent return dirname.joinpath(".versions") - def _get_path_versioned(self, key: VersionedDataSourceKey, - metadata_key: Optional[str] = None) -> Path: + def _get_path_versioned( + self, key: VersionedDataSourceKey, metadata_key: Optional[str] = None + ) -> Path: escaped_key = self._escape_key(key.key) dirname = os.path.dirname(escaped_key) basename = os.path.basename(escaped_key) @@ -116,7 +126,9 @@ def input_nonversioned(self, key: DataSourceKey) -> IO: def input_versioned(self, key: VersionedDataSourceKey) -> IO: return self._do_input(self._get_path_versioned(key)) - def input_metadata(self, content_key: VersionedDataSourceKey, metadata_key: str) -> bytes: + def input_metadata( + self, content_key: VersionedDataSourceKey, metadata_key: str + ) -> bytes: path = self._get_path_versioned(content_key, metadata_key=metadata_key) with self._do_input(path) as f: return f.read() @@ -131,7 +143,9 @@ def exists_versioned(self, key: VersionedDataSourceKey) -> bool: return result def exists_nonversioned(self, key: DataSourceKey) -> bool: - non_versioned_path = self._get_non_versioned_link_path(self._escape_key(key.key)) + non_versioned_path = self._get_non_versioned_link_path( + self._escape_key(key.key) + ) if not os.path.exists(non_versioned_path): result = False else: @@ -157,14 +171,21 @@ def output(self, key: DataSourceKey, data: IO) -> VersionedDataSourceKey: self._write_non_versioned_link(versioned_key) return versioned_key - def reference(self, src_data_source: DataSource, src_key: VersionedDataSourceKey, - target_key: VersionedDataSourceKey): + def reference( + self, + src_data_source: DataSource, + src_key: VersionedDataSourceKey, + target_key: VersionedDataSourceKey, + ): # This data source does not perform reference counting pass - def output_metadata(self, content_key: VersionedDataSourceKey, metadata_key: str, - value: bytes): - versioned_path = self._get_path_versioned(content_key, metadata_key=metadata_key) + def output_metadata( + self, content_key: VersionedDataSourceKey, metadata_key: str, value: bytes + ): + versioned_path = self._get_path_versioned( + content_key, metadata_key=metadata_key + ) log.debug("Writing {}".format(versioned_path)) with versioned_path.open(mode="wb") as f: f.write(value) @@ -210,39 +231,53 @@ def delete_all_versions(self, key: DataSourceKey, recursive: bool): path.rmdir() path = path.parent - def list_keys_nonversioned(self, directory: DataSourceKey, file_prefix: str = "", - recursive: bool = False, limit: int = None, - endswith: str = None) -> Iterable[DataSourceKey]: + def list_keys_nonversioned( + self, + directory: DataSourceKey, + file_prefix: str = "", + recursive: bool = False, + limit: int = None, + endswith: str = None, + ) -> Iterable[DataSourceKey]: dir_path = self._get_non_versioned_path(directory) if not dir_path.is_dir(): return [] escaped_key = self._escape_key(directory.key) - dir_prefix = (escaped_key + "/") if escaped_key and not escaped_key.endswith("/") else "" + dir_prefix = ( + (escaped_key + "/") if escaped_key and not escaped_key.endswith("/") else "" + ) if recursive: + def walk_path_recursive(): count = 0 dir_path_str = str(dir_path) for dirpath, dirname, filenames in os.walk(dir_path_str): - if "{}.versions{}".format(os.sep, os.sep) in dirpath or \ - "{}.tmp{}".format(os.sep, os.sep) in dirpath: + if ( + "{}.versions{}".format(os.sep, os.sep) in dirpath + or "{}.tmp{}".format(os.sep, os.sep) in dirpath + ): continue for filename in filenames: - entry = dir_prefix + os.path.join( - dirpath, filename)[len(dir_path_str) + 1:] + entry = ( + dir_prefix + + os.path.join(dirpath, filename)[len(dir_path_str) + 1 :] + ) if entry.endswith(".link"): entry = entry[0:-5] # strip .link off end of string # Filter down to files that begin with file_prefix if os.path.basename(entry).startswith(file_prefix): - if endswith is not None and not os.path.basename(entry).endswith( - endswith): + if endswith is not None and not os.path.basename( + entry + ).endswith(endswith): continue count += 1 - yield DataSourceKey(entry.replace(os.sep, '/')) + yield DataSourceKey(entry.replace(os.sep, "/")) if count == limit: return entries = list(walk_path_recursive()) else: + def walk_path(): count = 0 for entry in dir_path.iterdir(): @@ -252,7 +287,9 @@ def walk_path(): if entry.name.startswith(file_prefix): entry_name = unquote(entry.name) if entry_name.endswith(".link"): - entry_name = entry_name[0:-5] # strip .link off end of string + entry_name = entry_name[ + 0:-5 + ] # strip .link off end of string if endswith is not None and not entry_name.endswith(endswith): continue count += 1 @@ -273,9 +310,14 @@ class FilesystemStorageBackend(StorageBackendBase): config_path = None # type: str metadata_config_path = None # type: str - def __init__(self, config: dict = None, path: str = None, metadata_path: str = None, - memory_cache_mb: int = None, - read_only: bool = None): + def __init__( + self, + config: dict = None, + path: str = None, + metadata_path: str = None, + memory_cache_mb: int = None, + read_only: bool = None, + ): """ Create a storage backend that reads from the filesystem. See module documentation for parameters. Parameters that follow @@ -299,23 +341,30 @@ def __init__(self, config: dict = None, path: str = None, metadata_path: str = N data_source = _FilesystemDataSource(self.config_path) metadata_source = DataSourceMetadataSource( - _FilesystemDataSource( - self.metadata_config_path) if self.metadata_config_path != self.config_path - else data_source) - - super().__init__("filesystem", data_source=data_source, metadata_source=metadata_source, - memory_cache_mb=memory_cache_mb, config=config, read_only=read_only) + _FilesystemDataSource(self.metadata_config_path) + if self.metadata_config_path != self.config_path + else data_source + ) + + super().__init__( + "filesystem", + data_source=data_source, + metadata_source=metadata_source, + memory_cache_mb=memory_cache_mb, + config=config, + read_only=read_only, + ) def to_dict(self): - config = { - "type": "filesystem" - } + config = {"type": "filesystem"} if self.read_only is not None: config["readonly"] = self.read_only if self.config_path is not None: config["path"] = self.config_path if self._memory_cache is not None: - config["memory_cache_mb"] = self._memory_cache.memory_cache_bytes / 1024 / 1024 + config["memory_cache_mb"] = ( + self._memory_cache.memory_cache_bytes / 1024 / 1024 + ) return config @@ -389,7 +438,9 @@ def __setitem__(self, key: str, value: object): """ result_type = ResultType.from_object(value) self._result_types[key] = result_type - self._result_keys[key] = self._codec.store(result_type, self._data_source, None, value) + self._result_keys[key] = self._codec.store( + result_type, self._data_source, None, value + ) def __getitem__(self, item: str) -> object: """ @@ -402,8 +453,9 @@ def get(self, key: str) -> object: if self._merge_parent: return self._merge_parent.get(key) raise ValueError("Key '{}' not in key list for partition".format(key)) - return self._codec.load(self._result_types[key], self._data_source, - self._result_keys[key]) + return self._codec.load( + self._result_types[key], self._data_source, self._result_keys[key] + ) def list_keys(self, _include_merge_parent: bool = True) -> Iterable[str]: if _include_merge_parent and self._merge_parent: diff --git a/twosigma/memento/storage_memory.py b/twosigma/memento/storage_memory.py index 6d66141..8f9687e 100644 --- a/twosigma/memento/storage_memory.py +++ b/twosigma/memento/storage_memory.py @@ -28,8 +28,11 @@ from typing import Iterable, List, Optional, Dict # noqa: F401 from .metadata import Memento -from .reference import FunctionReference, FunctionReferenceWithArguments, \ - FunctionReferenceWithArgHash +from .reference import ( + FunctionReference, + FunctionReferenceWithArguments, + FunctionReferenceWithArgHash, +) from .types import FunctionNotFoundError, VersionedDataSourceKey from .storage import StorageBackend @@ -59,13 +62,21 @@ def __init__(self, config: dict = None, read_only: bool = None): @staticmethod def _get_memento_key(fn_with_arg_hash: FunctionReferenceWithArgHash) -> str: - return fn_with_arg_hash.fn_reference.qualified_name + "/" + fn_with_arg_hash.arg_hash - - def get_mementos(self, fns: List[FunctionReferenceWithArgHash]) -> List[Optional[Memento]]: + return ( + fn_with_arg_hash.fn_reference.qualified_name + + "/" + + fn_with_arg_hash.arg_hash + ) + + def get_mementos( + self, fns: List[FunctionReferenceWithArgHash] + ) -> List[Optional[Memento]]: results = [] for fn_ref_with_arg_hash in fns: try: - memento_dict = self.mementos[fn_ref_with_arg_hash.fn_reference.qualified_name] + memento_dict = self.mementos[ + fn_ref_with_arg_hash.fn_reference.qualified_name + ] memento = memento_dict.get(fn_ref_with_arg_hash.arg_hash) except FunctionNotFoundError: memento = None @@ -73,22 +84,33 @@ def get_mementos(self, fns: List[FunctionReferenceWithArgHash]) -> List[Optional return results def read_result(self, memento: Memento) -> object: - return self.result[self._get_memento_key( - memento.invocation_metadata.fn_reference_with_args.fn_reference_with_arg_hash())] + return self.result[ + self._get_memento_key( + memento.invocation_metadata.fn_reference_with_args.fn_reference_with_arg_hash() + ) + ] def make_url_for_result(self, memento: Memento) -> Optional[str]: return memento.content_key - def read_metadata(self, fn_with_arg_hash: FunctionReferenceWithArgHash, key: str, - retry_on_none=False) -> bytes: + def read_metadata( + self, + fn_with_arg_hash: FunctionReferenceWithArgHash, + key: str, + retry_on_none=False, + ) -> bytes: # Ignore retry_on_none since the in-memory metadata store is consistent. memento_key = self._get_memento_key(fn_with_arg_hash) metadata_dict = self.metadata[memento_key] # type: Dict[str, bytes] return metadata_dict.get(key) - def write_metadata(self, fn_with_arg_hash: FunctionReferenceWithArgHash, - key: str, value: bytes, - store_with_content_key: Optional[VersionedDataSourceKey] = None): + def write_metadata( + self, + fn_with_arg_hash: FunctionReferenceWithArgHash, + key: str, + value: bytes, + store_with_content_key: Optional[VersionedDataSourceKey] = None, + ): if self.read_only: raise ValueError("Cannot write metadata to a read-only storage backend") memento_key = self._get_memento_key(fn_with_arg_hash) @@ -100,11 +122,17 @@ def is_memoized(self, fn_reference: FunctionReference, arg_hash: str) -> bool: return memento_dict and arg_hash in memento_dict def is_all_memoized(self, fns: Iterable[FunctionReferenceWithArguments]) -> bool: - return all([self.is_memoized(fn_ref_with_arg.fn_reference, fn_ref_with_arg.arg_hash) - for fn_ref_with_arg in fns]) + return all( + [ + self.is_memoized(fn_ref_with_arg.fn_reference, fn_ref_with_arg.arg_hash) + for fn_ref_with_arg in fns + ] + ) def list_functions(self) -> List[FunctionReference]: - return [FunctionReference.from_qualified_name(key) for key in self.mementos.keys()] + return [ + FunctionReference.from_qualified_name(key) for key in self.mementos.keys() + ] def list_mementos(self, fn: FunctionReference, limit: int = None) -> List[Memento]: return list(self.mementos[fn.qualified_name].values())[0:limit] @@ -112,12 +140,16 @@ def list_mementos(self, fn: FunctionReference, limit: int = None) -> List[Mement def memoize(self, key_override: str, memento: Memento, result: object) -> None: if self.read_only: return - memento_key = self._get_memento_key(memento.invocation_metadata.fn_reference_with_args. - fn_reference_with_arg_hash()) - memento.content_key = VersionedDataSourceKey(key_override, "") if key_override else None + memento_key = self._get_memento_key( + memento.invocation_metadata.fn_reference_with_args.fn_reference_with_arg_hash() + ) + memento.content_key = ( + VersionedDataSourceKey(key_override, "") if key_override else None + ) self.result[memento_key] = result - qualified_name = memento.invocation_metadata.fn_reference_with_args.\ - fn_reference.qualified_name + qualified_name = ( + memento.invocation_metadata.fn_reference_with_args.fn_reference.qualified_name + ) arg_hash = memento.invocation_metadata.fn_reference_with_args.arg_hash self.mementos[qualified_name][arg_hash] = memento @@ -147,14 +179,13 @@ def forget_function(self, fn_reference: FunctionReference) -> None: raise ValueError("Cannot forget with a storage backend that is read-only") qualified_name = fn_reference.qualified_name for memento in self.list_mementos(fn_reference): - self.forget_call(memento.invocation_metadata.fn_reference_with_args. - fn_reference_with_arg_hash()) + self.forget_call( + memento.invocation_metadata.fn_reference_with_args.fn_reference_with_arg_hash() + ) self.mementos[qualified_name].clear() def to_dict(self): - config = { - "type": "memory" - } + config = {"type": "memory"} if self.read_only is not None: config["readonly"] = self.read_only return config diff --git a/twosigma/memento/storage_null.py b/twosigma/memento/storage_null.py index 2ad224e..2da65a0 100644 --- a/twosigma/memento/storage_null.py +++ b/twosigma/memento/storage_null.py @@ -18,8 +18,11 @@ """ from typing import List, Iterable, Optional -from .reference import FunctionReference, FunctionReferenceWithArguments, \ - FunctionReferenceWithArgHash +from .reference import ( + FunctionReference, + FunctionReferenceWithArguments, + FunctionReferenceWithArgHash, +) from .metadata import Memento from .storage import StorageBackend from .types import VersionedDataSourceKey @@ -29,7 +32,9 @@ class NullStorageBackend(StorageBackend): def __init__(self, config: dict = None): super().__init__("null", config=config) - def get_mementos(self, fns: List[FunctionReferenceWithArgHash]) -> List[Optional[Memento]]: + def get_mementos( + self, fns: List[FunctionReferenceWithArgHash] + ) -> List[Optional[Memento]]: return [None] * len(fns) def read_result(self, memento: Memento) -> object: @@ -38,13 +43,21 @@ def read_result(self, memento: Memento) -> object: def make_url_for_result(self, memento: Memento) -> Optional[str]: raise ValueError("Null backend has no memoized results for any invocations") - def read_metadata(self, fn_with_arg_hash: FunctionReferenceWithArgHash, key: str, - retry_on_none=False) -> object: + def read_metadata( + self, + fn_with_arg_hash: FunctionReferenceWithArgHash, + key: str, + retry_on_none=False, + ) -> object: return None - def write_metadata(self, fn_with_arg_hash: FunctionReferenceWithArgHash, key: str, - value: bytes, - store_with_content_key: Optional[VersionedDataSourceKey] = None): + def write_metadata( + self, + fn_with_arg_hash: FunctionReferenceWithArgHash, + key: str, + value: bytes, + store_with_content_key: Optional[VersionedDataSourceKey] = None, + ): pass def is_memoized(self, fn_reference: FunctionReference, arg_hash: str) -> bool: @@ -73,9 +86,7 @@ def forget_function(self, fn_reference: FunctionReference) -> None: pass def to_dict(self): - config = { - "type": "null" - } + config = {"type": "null"} if self.read_only is not None: config["readonly"] = self.read_only return config diff --git a/twosigma/memento/types.py b/twosigma/memento/types.py index 5dc44cd..eb356b6 100644 --- a/twosigma/memento/types.py +++ b/twosigma/memento/types.py @@ -52,7 +52,9 @@ def with_verbose(self, verbose: bool) -> "DependencyGraphType": pass @abstractmethod - def with_label_filter(self, label_filter: Callable[[str], str]) -> "DependencyGraphType": + def with_label_filter( + self, label_filter: Callable[[str], str] + ) -> "DependencyGraphType": pass @@ -91,19 +93,20 @@ def fn_reference(self): @abstractmethod def clone_with( - self, - fn: Callable = None, - src_fn: Callable = None, - cluster_name: str = None, - version: str = None, - calculated_version: str = None, - context: InvocationContext = None, - partial_args: Tuple[Any] = None, - partial_kwargs: Dict[str, Any] = None, - auto_dependencies: bool = True, - dependencies: List[Union[str, "MementoFunctionType"]] = None, - version_code_hash: str = None, - version_salt: str = None) -> "MementoFunctionType": + self, + fn: Callable = None, + src_fn: Callable = None, + cluster_name: str = None, + version: str = None, + calculated_version: str = None, + context: InvocationContext = None, + partial_args: Tuple[Any] = None, + partial_kwargs: Dict[str, Any] = None, + auto_dependencies: bool = True, + dependencies: List[Union[str, "MementoFunctionType"]] = None, + version_code_hash: str = None, + version_salt: str = None, + ) -> "MementoFunctionType": pass @abstractmethod @@ -115,8 +118,9 @@ def __call__(self, *args, **kwargs): pass @abstractmethod - def call_batch(self, kwargs_list: List[Dict[str, Any]], - raise_first_exception=True) -> List[Any]: + def call_batch( + self, kwargs_list: List[Dict[str, Any]], raise_first_exception=True + ) -> List[Any]: pass @abstractmethod @@ -177,7 +181,8 @@ def force_local(self, local: bool = True): @abstractmethod def dependencies( - self, verbose=False, label_filter: Callable[[str], str] = None) -> DependencyGraphType: + self, verbose=False, label_filter: Callable[[str], str] = None + ) -> DependencyGraphType: pass @abstractmethod @@ -199,4 +204,5 @@ class FunctionNotFoundError(ValueError): to an actual function. """ + pass