diff --git a/kedro/extras/extensions/__init__.py b/kedro/extras/extensions/__init__.py new file mode 100644 index 0000000000..128cff94e7 --- /dev/null +++ b/kedro/extras/extensions/__init__.py @@ -0,0 +1,30 @@ +# Copyright 2020 QuantumBlack Visual Analytics Limited +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, +# EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES +# OF MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE, AND +# NONINFRINGEMENT. IN NO EVENT WILL THE LICENSOR OR OTHER CONTRIBUTORS +# BE LIABLE FOR ANY CLAIM, DAMAGES, OR OTHER LIABILITY, WHETHER IN AN +# ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF, OR IN +# CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. +# +# The QuantumBlack Visual Analytics Limited ("QuantumBlack") name and logo +# (either separately or in combination, "QuantumBlack Trademarks") are +# trademarks of QuantumBlack. The License does not grant you any right or +# license to the QuantumBlack Trademarks. You may not use the QuantumBlack +# Trademarks or any confusingly similar mark as a trademark for your product, +# or use the QuantumBlack Trademarks in any other manner that might cause +# confusion in the marketplace, including but not limited to in advertising, +# on websites, or on software. +# +# See the License for the specific language governing permissions and +# limitations under the License. +""" +This module contains an IPython extension. +""" diff --git a/kedro/extras/extensions/ipython.py b/kedro/extras/extensions/ipython.py new file mode 100644 index 0000000000..e72a4d9903 --- /dev/null +++ b/kedro/extras/extensions/ipython.py @@ -0,0 +1,133 @@ +# Copyright 2020 QuantumBlack Visual Analytics Limited +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, +# EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES +# OF MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE, AND +# NONINFRINGEMENT. IN NO EVENT WILL THE LICENSOR OR OTHER CONTRIBUTORS +# BE LIABLE FOR ANY CLAIM, DAMAGES, OR OTHER LIABILITY, WHETHER IN AN +# ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF, OR IN +# CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. +# +# The QuantumBlack Visual Analytics Limited ("QuantumBlack") name and logo +# (either separately or in combination, "QuantumBlack Trademarks") are +# trademarks of QuantumBlack. The License does not grant you any right or +# license to the QuantumBlack Trademarks. You may not use the QuantumBlack +# Trademarks or any confusingly similar mark as a trademark for your product, +# or use the QuantumBlack Trademarks in any other manner that might cause +# confusion in the marketplace, including but not limited to in advertising, +# on websites, or on software. +# +# See the License for the specific language governing permissions and +# limitations under the License. +# pylint: disable=import-outside-toplevel,global-statement,invalid-name +""" +This script creates an IPython extension to load Kedro-related variables in +local scope. +""" +import logging.config +import sys +from pathlib import Path + +from IPython import get_ipython +from IPython.core.magic import needs_local_scope, register_line_magic + +project_path = Path.cwd() +catalog = None +context = None +session = None + + +def _remove_cached_modules(package_name): + to_remove = [mod for mod in sys.modules if mod.startswith(package_name)] + # `del` is used instead of `reload()` because: If the new version of a module does not + # define a name that was defined by the old version, the old definition remains. + for module in to_remove: + del sys.modules[module] # pragma: no cover + + +def _clear_hook_manager(): + from kedro.framework.hooks import get_hook_manager + + hook_manager = get_hook_manager() + name_plugin_pairs = hook_manager.list_name_plugin() + for name, plugin in name_plugin_pairs: + hook_manager.unregister(name=name, plugin=plugin) # pragma: no cover + + +def load_kedro_objects(path, line=None): # pylint: disable=unused-argument + """Line magic which reloads all Kedro default variables.""" + + import kedro.config.default_logger # noqa: F401 # pylint: disable=unused-import + from kedro.framework.cli import load_entry_points + from kedro.framework.context.context import _add_src_to_path + from kedro.framework.project.metadata import _get_project_metadata + from kedro.framework.session import KedroSession + from kedro.framework.session.session import _activate_session + + global context + global catalog + global session + + path = path or project_path + project_metadata = _get_project_metadata(path) + _add_src_to_path(project_metadata.source_dir, path) + + session = KedroSession.create(path) + _activate_session(session) + + _remove_cached_modules(project_metadata.package_name) + + # clear hook manager; hook implementations will be re-registered when the + # context is instantiated again in `session.context` below + _clear_hook_manager() + + logging.debug("Loading the context from %s", str(path)) + # Reload context to fix `pickle` related error (it is unable to serialize reloaded objects) + # Some details can be found here: + # https://modwsgi.readthedocs.io/en/develop/user-guides/issues-with-pickle-module.html#packing-and-script-reloading + context = session.load_context() + catalog = context.catalog + get_ipython().push( + variables={"context": context, "catalog": catalog, "session": session} + ) + + logging.info("** Kedro project %s", str(project_metadata.project_name)) + logging.info("Defined global variable `context`, `session` and `catalog`") + + for line_magic in load_entry_points("line_magic"): + register_line_magic(needs_local_scope(line_magic)) + logging.info("Registered line magic `%s`", line_magic.__name__) + + +def init_kedro(path=""): + """Line magic to set path to Kedro project. + `%reload_kedro` will default to this location. + """ + global project_path + if path: + project_path = Path(path).expanduser().resolve() + logging.info("Updated path to Kedro project: %s", str(project_path)) + else: + logging.info("No path argument was provided. Using: %s", str(project_path)) + + +def load_ipython_extension(ipython): + """Main entry point when %load_ext is executed""" + ipython.register_magic_function(init_kedro, "line") + ipython.register_magic_function(load_kedro_objects, "line", "reload_kedro") + + try: + load_kedro_objects(project_path) + except (ImportError, ModuleNotFoundError): + logging.error("Kedro appears not to be installed in your current environment.") + except Exception: # pylint: disable=broad-except + logging.error( + "Could not register Kedro extension. Make sure you're in a valid Kedro project.", + exc_info=True, + ) diff --git a/requirements.txt b/requirements.txt index ca9b904898..e040f1eb15 100644 --- a/requirements.txt +++ b/requirements.txt @@ -11,4 +11,5 @@ pluggy~=0.13.0 python-json-logger~=0.1.9 PyYAML>=4.2, <6.0 setuptools>=38.0 +toml~=0.10 toposort~=1.5 # Needs to be at least 1.5 to be able to raise CircularDependencyError diff --git a/setup.cfg b/setup.cfg index f9890ff35c..f3860c722a 100644 --- a/setup.cfg +++ b/setup.cfg @@ -29,7 +29,7 @@ layers = framework.context framework.project runner - extras + extras.datasets io pipeline config @@ -53,7 +53,7 @@ forbidden_modules = kedro.runner kedro.io kedro.pipeline - kedro.extras + kedro.extras.datasets [importlinter:contract:4] name = Runner et al cannot import Config @@ -62,7 +62,7 @@ source_modules = kedro.runner kedro.io kedro.pipeline - kedro.extras + kedro.extras.datasets forbidden_modules = kedro.config ignore_imports= diff --git a/setup.py b/setup.py index 932679ee99..3a7126f695 100644 --- a/setup.py +++ b/setup.py @@ -136,6 +136,7 @@ def _collect_requirements(requires): "ipykernel>=4.8.1, <5.0", ], "geopandas": _collect_requirements(geopandas_require), + "ipython": ["ipython~=7.0"], "matplotlib": _collect_requirements(matplotlib_require), "holoviews": _collect_requirements(holoviews_require), "networkx": _collect_requirements(networkx_require), diff --git a/tests/extras/extensions/__init__.py b/tests/extras/extensions/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/tests/extras/extensions/test_ipython.py b/tests/extras/extensions/test_ipython.py new file mode 100644 index 0000000000..616d8b037d --- /dev/null +++ b/tests/extras/extensions/test_ipython.py @@ -0,0 +1,173 @@ +# Copyright 2020 QuantumBlack Visual Analytics Limited +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, +# EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES +# OF MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE, AND +# NONINFRINGEMENT. IN NO EVENT WILL THE LICENSOR OR OTHER CONTRIBUTORS +# BE LIABLE FOR ANY CLAIM, DAMAGES, OR OTHER LIABILITY, WHETHER IN AN +# ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF, OR IN +# CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. +# +# The QuantumBlack Visual Analytics Limited ("QuantumBlack") name and logo +# (either separately or in combination, "QuantumBlack Trademarks") are +# trademarks of QuantumBlack. The License does not grant you any right or +# license to the QuantumBlack Trademarks. You may not use the QuantumBlack +# Trademarks or any confusingly similar mark as a trademark for your product, +# or use the QuantumBlack Trademarks in any other manner that might cause +# confusion in the marketplace, including but not limited to in advertising, +# on websites, or on software. +# +# See the License for the specific language governing permissions and +# limitations under the License. +# pylint: disable=import-outside-toplevel,reimported +import pytest + +from kedro.extras.extensions.ipython import ( + init_kedro, + load_ipython_extension, + load_kedro_objects, +) +from kedro.framework.project import ProjectMetadata +from kedro.framework.session.session import _deactivate_session + + +@pytest.fixture(autouse=True) +def project_path(mocker, tmp_path): + path = tmp_path + mocker.patch("kedro.extras.extensions.ipython.project_path", path) + + +@pytest.fixture(autouse=True) +def cleanup_session(): + yield + _deactivate_session() + + +class TestInitKedro: + def test_init_kedro(self, tmp_path, caplog): + from kedro.extras.extensions.ipython import project_path + + assert project_path == tmp_path + + kedro_path = tmp_path / "here" + init_kedro(str(kedro_path)) + expected_path = kedro_path.expanduser().resolve() + expected_message = f"Updated path to Kedro project: {expected_path}" + + log_messages = [record.getMessage() for record in caplog.records] + assert expected_message in log_messages + from kedro.extras.extensions.ipython import project_path + + # make sure global variable updated + assert project_path == expected_path + + def test_init_kedro_no_path(self, tmp_path, caplog): + from kedro.extras.extensions.ipython import project_path + + assert project_path == tmp_path + + init_kedro() + expected_message = f"No path argument was provided. Using: {tmp_path}" + + log_messages = [record.getMessage() for record in caplog.records] + assert expected_message in log_messages + from kedro.extras.extensions.ipython import project_path + + # make sure global variable stayed the same + assert project_path == tmp_path + + +class TestLoadKedroObjects: + def test_load_kedro_objects(self, tmp_path, mocker): + fake_metadata = ProjectMetadata( + source_dir=tmp_path / "src", # default + config_file=tmp_path / "pyproject.toml", + package_name="fake_package_name", + project_name="fake_project_name", + project_version="0.1", + context_path="hello.there", + ) + mocker.patch( + "kedro.framework.project.metadata._get_project_metadata", + return_value=fake_metadata, + ) + mocker.patch( + "kedro.framework.session.session._get_project_metadata", + return_value=fake_metadata, + ) + mocker.patch("kedro.framework.context.context._add_src_to_path") + mock_line_magic = mocker.MagicMock() + mock_line_magic.__name__ = "abc" + mocker.patch( + "kedro.framework.cli.load_entry_points", return_value=[mock_line_magic] + ) + mock_register_line_magic = mocker.patch( + "kedro.extras.extensions.ipython.register_line_magic" + ) + mock_context = mocker.patch("kedro.framework.session.KedroSession.load_context") + mock_ipython = mocker.patch("kedro.extras.extensions.ipython.get_ipython") + + load_kedro_objects(tmp_path) + + mock_ipython().push.assert_called_once_with( + variables={ + "context": mock_context(), + "catalog": mock_context().catalog, + "session": mocker.ANY, + } + ) + assert mock_register_line_magic.call_count == 1 + + def test_load_kedro_objects_not_in_kedro_project(self, tmp_path, mocker): + mocker.patch( + "kedro.framework.project.metadata._get_project_metadata", + side_effect=[RuntimeError], + ) + mock_ipython = mocker.patch("kedro.extras.extensions.ipython.get_ipython") + + with pytest.raises(RuntimeError): + load_kedro_objects(tmp_path) + assert not mock_ipython().called + assert not mock_ipython().push.called + + +class TestLoadIPythonExtension: + @pytest.mark.parametrize( + "error,expected_log_message", + [ + ( + ImportError, + "Kedro appears not to be installed in your current environment.", + ), + ( + RuntimeError, + "Could not register Kedro extension. Make sure you're in a valid Kedro project.", + ), + ], + ) + def test_load_extension_not_in_kedro_env_or_project( + self, error, expected_log_message, mocker, caplog + ): + mocker.patch( + "kedro.framework.project.metadata._get_project_metadata", + side_effect=[error], + ) + mock_ipython = mocker.patch("kedro.extras.extensions.ipython.get_ipython") + + load_ipython_extension(mocker.MagicMock()) + + assert not mock_ipython().called + assert not mock_ipython().push.called + + log_messages = [ + record.getMessage() + for record in caplog.records + if record.levelname == "ERROR" + ] + assert log_messages == [expected_log_message] diff --git a/tests/framework/session/test_store.py b/tests/framework/session/test_store.py index acda2f47ae..62dc615487 100644 --- a/tests/framework/session/test_store.py +++ b/tests/framework/session/test_store.py @@ -129,7 +129,7 @@ def test_from_config_uncaught_error(self, mocker): with pytest.raises(ValueError, match=re.escape(pattern)): BaseSessionStore.from_config(config) - assert mocked_init.called_once_with(**config) + mocked_init.assert_called_once_with(**config) @pytest.fixture