Skip to content

Commit

Permalink
refactor: use pytest fixtures to reduce repetition in tests
Browse files Browse the repository at this point in the history
  • Loading branch information
Aswinr24 committed Oct 11, 2024
1 parent 3bf3ea8 commit efb05be
Showing 1 changed file with 36 additions and 36 deletions.
72 changes: 36 additions & 36 deletions tests/unit/test_parameters/test_process_parameter_data.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,57 +2,57 @@
# Tests for the parameter processing functions
#


import os
import numpy as np
import pybamm

import pytest


class TestProcessParameterData:
def test_process_1D_data(self):
name = "lico2_ocv_example"
path = os.path.abspath(os.path.dirname(__file__))
processed = pybamm.parameters.process_1D_data(name, path)
assert processed[0] == name
assert isinstance(processed[1], tuple)
assert isinstance(processed[1][0][0], np.ndarray)
assert isinstance(processed[1][1], np.ndarray)
@pytest.fixture
def path():
return os.path.abspath(os.path.dirname(__file__))

def test_process_2D_data(self):
name = "lico2_diffusivity_Dualfoil1998_2D"
path = os.path.abspath(os.path.dirname(__file__))
processed = pybamm.parameters.process_2D_data(name, path)
assert processed[0] == name
assert isinstance(processed[1], tuple)
assert isinstance(processed[1][0][0], np.ndarray)
assert isinstance(processed[1][0][1], np.ndarray)
assert isinstance(processed[1][1], np.ndarray)
@pytest.fixture(params=[
("lico2_ocv_example", pybamm.parameters.process_1D_data),
("lico2_diffusivity_Dualfoil1998_2D", pybamm.parameters.process_2D_data),
("data_for_testing_2D", pybamm.parameters.process_2D_data_csv),
("data_for_testing_3D", pybamm.parameters.process_3D_data_csv),
])
def parameter_data(request, path):
name, processing_function = request.param
processed = processing_function(name, path)
return name, processed

def test_process_2D_data_csv(self):
name = "data_for_testing_2D"
path = os.path.abspath(os.path.dirname(__file__))
processed = pybamm.parameters.process_2D_data_csv(name, path)

class TestProcessParameterData:
def test_processed_name(self, parameter_data):
name, processed = parameter_data
assert processed[0] == name
assert isinstance(processed[1], tuple)
assert isinstance(processed[1][0][0], np.ndarray)
assert isinstance(processed[1][0][1], np.ndarray)
assert isinstance(processed[1][1], np.ndarray)

def test_process_3D_data_csv(self):
name = "data_for_testing_3D"
path = os.path.abspath(os.path.dirname(__file__))
processed = pybamm.parameters.process_3D_data_csv(name, path)

assert processed[0] == name
def test_processed_structure(self, parameter_data):
"""
Test that the processed data has the correct structure.
Args:
parameter_data: A tuple containing the name and processed data.
Asserts:
- The second element of the processed data is a tuple.
- The first element of the second item in the processed data is a numpy array.
- Additional checks based on the shape of the processed data.
"""
name, processed = parameter_data
assert isinstance(processed[1], tuple)
assert isinstance(processed[1][0][0], np.ndarray)
assert isinstance(processed[1][0][1], np.ndarray)
assert isinstance(processed[1][0][2], np.ndarray)
assert isinstance(processed[1][1], np.ndarray)

if len(processed[1][0]) > 1:
assert isinstance(processed[1][0][1], np.ndarray)

elif len(processed[1]) == 3:
assert isinstance(processed[1][0][1], np.ndarray)
assert isinstance(processed[1][0][2], np.ndarray)

def test_error(self):
with pytest.raises(FileNotFoundError, match="Could not find file"):
pybamm.parameters.process_1D_data("not_a_real_file", "not_a_real_path")

0 comments on commit efb05be

Please sign in to comment.