diff --git a/pyscal/plotting.py b/pyscal/plotting.py index 0b44373b..f29020e5 100644 --- a/pyscal/plotting.py +++ b/pyscal/plotting.py @@ -127,13 +127,14 @@ def get_satnum_from_tag(string: str) -> int: def get_plot_config_options(curve_type: str, **kwargs) -> dict: """ - Get config data from plot config dictionary based on the curve type + Get config data from plot config dictionary based on the curve (model) type. Args: - curve_type (str): _description_ + curve_type (str): Name of the curve type. Allowed types are given in + the PLOT_CONFIG_OPTIONS dictionary Returns: - dict: _description_ + dict: Config parameters for the chosen model type """ config = PLOT_CONFIG_OPTIONS[curve_type].copy() @@ -306,13 +307,16 @@ def save_figure( plot_type: str, outdir: str, ) -> None: - """_summary_ + """ + + Save the provided figure. Args: fig (plt.Figure): Figure to be saved satnum (int): SATNUM number config (dict): Plot config plot_type (str): Figure type. Allowed types are 'relperm' and 'pc' + outdir (str): Directory where the figure will be saved """ # Get curve name @@ -333,18 +337,23 @@ def save_figure( bbox_inches="tight", ) + print(f"Figure saved to {fout}.png") + # Clear figure so that it is empty for the next SATNUM's plot fig.clear() def wog_plotter(model: WaterOilGas, **kwargs) -> None: - """_summary_ + """ + + Plot a WaterOilGas (WaterOil and GasOil) model. + For a WaterOilGas instance, the WaterOil and GasOil instances can be accessed, then the "table" instance variable. Args: - model (WaterOilGas): _description_ + model (WaterOilGas): WaterOilGas instance """ outdir = kwargs["outdir"] @@ -390,11 +399,13 @@ def wog_plotter(model: WaterOilGas, **kwargs) -> None: def wo_plotter(model: WaterOil, **kwargs) -> None: """ + Plot a WaterOil model. + For a WaterOil instance, the saturation table can be accessed using the "table" instance variable. Args: - model (WaterOil): _description_ + model (WaterOil): WaterOil instance """ config = get_plot_config_options("WaterOil", **kwargs) satnum = get_satnum_from_tag(model.tag) @@ -419,11 +430,13 @@ def wo_plotter(model: WaterOil, **kwargs) -> None: def go_plotter(model: GasOil, **kwargs) -> None: """ + Plot a GasOil model. + For a GasOil instance, the saturation table can be accessed using the "table" instance variable. Args: - model (GasOil): _description_ + model (GasOil): GasOil instance """ config = get_plot_config_options("GasOil", **kwargs) @@ -450,9 +463,16 @@ def go_plotter(model: GasOil, **kwargs) -> None: def gw_plotter(model: GasWater, **kwargs) -> None: - # For GasWater, the format is different, and an additional formatting step is - # required. Use the formatted table as an argument to the plotter function, - # instead of the "table" instance variable + """ + + For GasWater, the format is different, and an additional formatting step is + required. Use the formatted table as an argument to the plotter function, + instead of the "table" instance variable + + Args: + model (GasWater): GasWater instance + """ + table = format_gaswater_table(model) config = get_plot_config_options("GasWater", **kwargs) satnum = get_satnum_from_tag(model.tag) diff --git a/pyscal/pyscalcli.py b/pyscal/pyscalcli.py index 159cf93e..50abc275 100644 --- a/pyscal/pyscalcli.py +++ b/pyscal/pyscalcli.py @@ -320,4 +320,3 @@ def pyscal_main( if plot: plotting.plotter(wog_list, plot_pc, plot_semilog, plot_outdir) - print(f"Plots saved in {plot_outdir}") diff --git a/tests/test_plotting.py b/tests/test_plotting.py index ea52f425..2500b5da 100644 --- a/tests/test_plotting.py +++ b/tests/test_plotting.py @@ -1,3 +1,5 @@ +"""Test the plotting module""" + from pathlib import Path import matplotlib.pyplot as plt @@ -7,6 +9,7 @@ def test_get_satnum_from_tag(): + """Check that the SATNUM number can be retrieved from the model tag""" # Several PyscalLists of different model types to be checked pyscal_lists = [ PyscalList( @@ -48,19 +51,23 @@ def test_get_satnum_from_tag(): def test_plotter(): - # Check if Exception is raised if a model type is not included. This is done - # to check that all models have been implemented in the plotting module. + """Check if an Exception is raised if a model type is not included. This is done + to check that all models have been implemented in the plotting module.""" class DummyPyscalList: - # Can't use the actual PyscalList, as this will raise its own exception - # (DummyModel is not a pyscal object), so a dummy PyscalList is used + """ + Can't use the actual PyscalList, as this will raise its own exception + (DummyModel is not a pyscal object), so a dummy PyscalList is used + + #If the PyscalList.pyscal_list instance variable name changes, this + will still pass...""" - # If the PyscalList.pyscal_list instance variable name changes, this - # will still pass... def __init__(self, models: list) -> None: self.pyscal_list = models class DummyModel: + """Dummy model""" + def __init__(self, tag: str) -> None: self.tag = tag @@ -77,8 +84,8 @@ def __init__(self, tag: str) -> None: def test_pyscal_list_attr(): - # Check that the PyscalList class has an pyscal_list instance variable. - # This is access by the plotting module to loop through models to plot. + """Check that the PyscalList class has an pyscal_list instance variable. + This is accessed by the plotting module to loop through models to plot.""" assert ( hasattr(PyscalList(), "pyscal_list") is True ), "The PyscalList object should have a pyscal_list instance variable.\ @@ -86,7 +93,7 @@ def test_pyscal_list_attr(): def test_plot_relperm(): - # Test that a matplotlib.pyplot Figure instance is returned + """Test that a matplotlib.pyplot Figure instance is returned""" wateroil = WaterOil(swl=0.1, h=0.1) wateroil.add_corey_water() wateroil.add_corey_oil() @@ -101,7 +108,7 @@ def test_plot_relperm(): def test_plot_pc(): - # Test that a matplotlib.pyplot Figure instance is returned + """Test that a matplotlib.pyplot Figure instance is returned""" wateroil = WaterOil(swl=0.1, h=0.1) wateroil.add_corey_water() wateroil.add_corey_oil() @@ -117,7 +124,7 @@ def test_plot_pc(): def test_wog_plotter(tmpdir): - # Test if relative permeability figures are created by the plotter function + """Test that relative permeability figures are created by the plotter function""" wateroil = WaterOil(swl=0.1, h=0.1, tag="SATNUM 1") wateroil.add_corey_water() wateroil.add_corey_oil() @@ -145,7 +152,7 @@ def test_wog_plotter(tmpdir): def test_wo_plotter(tmpdir): - # Test if relative permeability figures are created by the plotter function + """Test that relative permeability figures are created by the plotter function""" wateroil = WaterOil(swl=0.1, h=0.1, tag="SATNUM 1") wateroil.add_corey_water() wateroil.add_corey_oil() @@ -163,7 +170,7 @@ def test_wo_plotter(tmpdir): def test_wo_plotter_relperm_only(tmpdir): - # Test if relative permeability figures are created by the plotter function + """Test that relative permeability figures are created by the plotter function""" wateroil = WaterOil(swl=0.1, h=0.1, tag="SATNUM 1") wateroil.add_corey_water() wateroil.add_corey_oil() @@ -183,7 +190,7 @@ def test_wo_plotter_relperm_only(tmpdir): def test_go_plotter(tmpdir): - # Test if relative permeability figures are created by the plotter function + """Test that relative permeability figures are created by the plotter function""" gasoil = GasOil(swl=0.1, h=0.1, tag="SATNUM 1") gasoil.add_corey_gas() gasoil.add_corey_oil() @@ -204,7 +211,7 @@ def test_go_plotter(tmpdir): def test_gw_plotter(tmpdir): - # Test if relative permeability figures are created by the plotter function + """Test that relative permeability figures are created by the plotter function""" gaswater = GasWater(swl=0.1, h=0.1, tag="SATNUM 1") gaswater.add_corey_water() gaswater.add_corey_gas() @@ -222,7 +229,7 @@ def test_gw_plotter(tmpdir): def test_save_figure(tmpdir): - # Test that figure is saved + """Test that figure is saved""" fig = plt.Figure() config = {"curves": "dummy", "suffix": ""} diff --git a/tests/test_pyscalcli.py b/tests/test_pyscalcli.py index aabb459d..543aea78 100644 --- a/tests/test_pyscalcli.py +++ b/tests/test_pyscalcli.py @@ -586,3 +586,44 @@ def test_pyscal_main(): with pytest.raises(ValueError, match="Interpolation parameter provided"): pyscalcli.pyscal_main(relperm_file, int_param_wo=-1, output=os.devnull) + + +def test_pyscalcli_plot(capsys, mocker, tmpdir): + """Test that plots are created through the CLI. This is done by testing to + see if the print statements in the save_figure function are present in stdout""" + scalrec_file = Path(__file__).absolute().parent / "data/scal-pc-input-example.xlsx" + + mocker.patch( + "sys.argv", + [ + "pyscal", + str(scalrec_file), + "--int_param_wo", + "0", + "--output", + "-", + "--plot", + "--plot_pc", + "--plot_outdir", + str(tmpdir), + ], + ) + + pyscalcli.main() + + expected_plots = [ + "krw_krow_SATNUM_1.png", + "krg_krog_SATNUM_1.png", + "krw_krow_SATNUM_2.png", + "krg_krog_SATNUM_2.png", + "krw_krow_SATNUM_3.png", + "krg_krog_SATNUM_3.png", + "pcow_SATNUM_1.png", + "pcow_SATNUM_2.png", + "pcow_SATNUM_3.png", + ] + + captured = capsys.readouterr() + + for plot in expected_plots: + assert plot in captured.out