diff --git a/docs/source/saving.rst b/docs/source/saving.rst index c8cf995..115e840 100644 --- a/docs/source/saving.rst +++ b/docs/source/saving.rst @@ -2,4 +2,5 @@ saving =============== **saving** contains altair saving functions +.. autofunction:: nesta_ds_utils.viz.altair.saving.webdriver_context .. autofunction:: nesta_ds_utils.viz.altair.saving.save diff --git a/nesta_ds_utils/viz/altair/saving.py b/nesta_ds_utils/viz/altair/saving.py index bc0e74f..6a932c0 100644 --- a/nesta_ds_utils/viz/altair/saving.py +++ b/nesta_ds_utils/viz/altair/saving.py @@ -2,13 +2,11 @@ Module containing utils for styling and exporting figures using Altair. """ -import altair_saver as alt_saver from altair.vegalite import Chart import altair as alt -from selenium import webdriver from webdriver_manager.chrome import ChromeDriverManager from selenium.webdriver.chrome.webdriver import WebDriver -from selenium.webdriver.chrome.options import Options +from selenium.webdriver import Chrome, ChromeOptions, ChromeService import os from typing import Union, List, Type import warnings @@ -16,18 +14,42 @@ from pathlib import Path from nesta_ds_utils.loading_saving import file_ops import yaml +from contextlib import contextmanager def _google_chrome_driver_setup() -> WebDriver: """Set up the driver to save figures""" - chrome_options = Options() - chrome_options.add_argument("--headless") - driver = webdriver.Chrome( - ChromeDriverManager().install(), chrome_options=chrome_options - ) + service = ChromeService(ChromeDriverManager().install()) + chrome_options = ChromeOptions() + chrome_options.add_argument("--headless=new") + driver = Chrome(service=service, options=chrome_options) return driver +@contextmanager +def webdriver_context(driver: WebDriver = None): + """Context Manager for Selenium WebDrivers. + Optionally pass in user-instantiated Selenium Webdriver. + Defaults to setup and yield a ChromeWebDriver. + + Typical usage: + + with webdriver_context(webdriver or None) as driver: + # Do stuff with driver, driver.quit() is then called automatically + + Args: + driver (WebDriver, optional): Webdriver to use. Defaults to 'webdriver.Chrome'. + + Yields: + WebDriver: The optional user-instantiated Selenium WebDriver or a Selenium ChromeWebDriver. + """ + try: + driver = _google_chrome_driver_setup() if driver is None else driver + yield driver + finally: + driver.quit() + + def _save_png( fig: Chart, path: os.PathLike, name: str, scale_factor: int, driver: WebDriver ): @@ -40,8 +62,7 @@ def _save_png( scale_factor (int): Saving scale factor. driver (WebDriver): webdriver to use for saving. """ - alt_saver.save( - fig, + fig.save( f"{path}/{name}.png", method="selenium", webdriver=driver, @@ -73,8 +94,7 @@ def _save_svg( scale_factor (int): Saving scale factor. driver (WebDriver): webdriver to use for saving. """ - alt_saver.save( - fig, + fig.save( f"{path}/{name}.svg", method="selenium", scale_factor=scale_factor, @@ -104,27 +124,26 @@ def save( save_svg (bool, optional): Option to save figure as 'svg'. Default to False. scale_factor (int, optional): Saving scale factor. Default to 5. """ - path = file_ops._convert_str_to_pathlib_path(path) - if not any([save_png, save_html, save_svg]): raise Exception( "At least one format needs to be selected. Example: save(.., save_png=True)." ) + path = file_ops._convert_str_to_pathlib_path(path) + file_ops.make_path_if_not_exist(path) + if save_png or save_svg: - driver = _google_chrome_driver_setup() if driver is None else driver + with webdriver_context(driver): + # Export figures + if save_png: + _save_png(fig, path, name, scale_factor, driver) - file_ops.make_path_if_not_exist(path) - # Export figures - if save_png: - _save_png(fig, path, name, scale_factor, driver) + if save_svg: + _save_svg(fig, path, name, scale_factor, driver) if save_html: _save_html(fig, path, name, scale_factor) - if save_svg: - _save_svg(fig, path, name, scale_factor, driver) - def _find_averta() -> str: """Search for averta font, otherwise return 'Helvetica' and raise a warning. diff --git a/setup.cfg b/setup.cfg index f6ed202..c6d5c43 100644 --- a/setup.cfg +++ b/setup.cfg @@ -12,37 +12,39 @@ packages = ["nesta_ds_utils"] [options] python_requires = >=3.8 install_requires = - numpy==1.23.4 - pandas==1.5.1 + numpy>=1.23.4 + pandas>=1.5.1 pyyaml<5.4.0 - scipy==1.9.3 - pyarrow==10.0.0 + scipy>=1.9.3 + pyarrow>=10.0.0 [options.extras_require] s3 = - boto3==1.24.93 + boto3>=1.24.93 gis = - geopandas==0.13.2 -io_extras = - openpyxl==3.0.9 + geopandas>=0.13.2 +io_extras = + openpyxl>=3.0.9 viz = - altair==4.2.0 - altair-saver==0.5.0 - matplotlib==3.6.2 - selenium==4.2.0 - webdriver_manager==4.0.0 + altair>=4.2.0 + vl-convert-python>=1.2.0 + matplotlib>=3.6.2 + selenium>=4.2.0 + webdriver_manager>=4.0.0 networks = networkx==2.8.8 nlp = - nltk==3.7 -test = - pytest==7.1.3 - moto[s3]==4.0.7 + nltk>=3.7 +all = %(s3)s %(gis)s %(io_extras)s %(viz)s %(networks)s %(nlp)s +test = + pytest==7.1.3 + moto[s3]==4.0.7 + %(all)s dev = Sphinx==5.2.3 sphinxcontrib-applehelp==1.0.2 @@ -51,24 +53,10 @@ dev = sphinxcontrib-jsmath==1.0.1 sphinxcontrib-qthelp==1.0.3 sphinxcontrib-serializinghtml==1.1.5 - pytest==7.1.3 - moto[s3]==4.0.7 pre-commit==2.20.0 pre-commit-hooks==4.3.0 black==22.10.0 - %(s3)s - %(gis)s - %(io_extras)s - %(viz)s - %(networks)s - %(nlp)s -all = - %(s3)s - %(gis)s - %(io_extras)s - %(viz)s - %(networks)s - %(nlp)s + %(test)s [options.package_data] nesta_ds_utils.viz.themes = diff --git a/tests/viz/altair/test_saving.py b/tests/viz/altair/test_saving.py index 538cc9b..426ffae 100644 --- a/tests/viz/altair/test_saving.py +++ b/tests/viz/altair/test_saving.py @@ -1,6 +1,7 @@ import shutil from pathlib import Path from nesta_ds_utils.viz.altair import saving +from selenium.webdriver.chromium.webdriver import ChromiumDriver import pandas as pd import altair as alt import pytest @@ -41,3 +42,18 @@ def test_save_altair_exception(): saving.save( fig, "test_fig", path, save_png=False, save_html=False, save_svg=False ) + + +def test_webdriver(): + """Test that Chrome WebDriver is created by default is a ChromiumDriver, and context manager stops the webdriver.""" + driver = saving._google_chrome_driver_setup() + assert isinstance(driver, ChromiumDriver) + + with saving.webdriver_context(driver) as some_driver: + # No actions needed here, + # just testing that the context manager calls .quit() on driver to terminate. + pass + + # If subprocess not terminated, .poll() returns None + # https://docs.python.org/3/library/subprocess.html#subprocess.Popen.returncode + assert driver.service.process.poll() is not None