diff --git a/tests/unit/test_parameters/test_process_parameter_data.py b/tests/unit/test_parameters/test_process_parameter_data.py index 9352894c5c..6767ee8358 100644 --- a/tests/unit/test_parameters/test_process_parameter_data.py +++ b/tests/unit/test_parameters/test_process_parameter_data.py @@ -2,57 +2,57 @@ # Tests for the parameter processing functions # - import os import numpy as np import pybamm - import pytest -class TestProcessParameterData: - def test_process_1D_data(self): - name = "lico2_ocv_example" - path = os.path.abspath(os.path.dirname(__file__)) - processed = pybamm.parameters.process_1D_data(name, path) - assert processed[0] == name - assert isinstance(processed[1], tuple) - assert isinstance(processed[1][0][0], np.ndarray) - assert isinstance(processed[1][1], np.ndarray) +@pytest.fixture +def path(): + return os.path.abspath(os.path.dirname(__file__)) - def test_process_2D_data(self): - name = "lico2_diffusivity_Dualfoil1998_2D" - path = os.path.abspath(os.path.dirname(__file__)) - processed = pybamm.parameters.process_2D_data(name, path) - assert processed[0] == name - assert isinstance(processed[1], tuple) - assert isinstance(processed[1][0][0], np.ndarray) - assert isinstance(processed[1][0][1], np.ndarray) - assert isinstance(processed[1][1], np.ndarray) +@pytest.fixture(params=[ + ("lico2_ocv_example", pybamm.parameters.process_1D_data), + ("lico2_diffusivity_Dualfoil1998_2D", pybamm.parameters.process_2D_data), + ("data_for_testing_2D", pybamm.parameters.process_2D_data_csv), + ("data_for_testing_3D", pybamm.parameters.process_3D_data_csv), +]) +def parameter_data(request, path): + name, processing_function = request.param + processed = processing_function(name, path) + return name, processed - def test_process_2D_data_csv(self): - name = "data_for_testing_2D" - path = os.path.abspath(os.path.dirname(__file__)) - processed = pybamm.parameters.process_2D_data_csv(name, path) +class TestProcessParameterData: + def test_processed_name(self, parameter_data): + name, processed = parameter_data assert processed[0] == name - assert isinstance(processed[1], tuple) - assert isinstance(processed[1][0][0], np.ndarray) - assert isinstance(processed[1][0][1], np.ndarray) - assert isinstance(processed[1][1], np.ndarray) - - def test_process_3D_data_csv(self): - name = "data_for_testing_3D" - path = os.path.abspath(os.path.dirname(__file__)) - processed = pybamm.parameters.process_3D_data_csv(name, path) - assert processed[0] == name + def test_processed_structure(self, parameter_data): + """ + Test that the processed data has the correct structure. + + Args: + parameter_data: A tuple containing the name and processed data. + + Asserts: + - The second element of the processed data is a tuple. + - The first element of the second item in the processed data is a numpy array. + - Additional checks based on the shape of the processed data. + """ + name, processed = parameter_data assert isinstance(processed[1], tuple) assert isinstance(processed[1][0][0], np.ndarray) - assert isinstance(processed[1][0][1], np.ndarray) - assert isinstance(processed[1][0][2], np.ndarray) assert isinstance(processed[1][1], np.ndarray) + if len(processed[1][0]) > 1: + assert isinstance(processed[1][0][1], np.ndarray) + + elif len(processed[1]) == 3: + assert isinstance(processed[1][0][1], np.ndarray) + assert isinstance(processed[1][0][2], np.ndarray) + def test_error(self): with pytest.raises(FileNotFoundError, match="Could not find file"): pybamm.parameters.process_1D_data("not_a_real_file", "not_a_real_path")