Skip to content

Commit

Permalink
Added unit tests for SVM (#1)
Browse files Browse the repository at this point in the history
  • Loading branch information
wleoncio committed Jul 31, 2024
1 parent 28bd854 commit 8784999
Show file tree
Hide file tree
Showing 3 changed files with 22 additions and 1 deletion.
3 changes: 3 additions & 0 deletions Makefile
Original file line number Diff line number Diff line change
Expand Up @@ -16,4 +16,7 @@ local-install:
$(ENV_PATH)pip install .

test:
$(ENV_PATH)pytest -m "not slow"

test-full:
$(ENV_PATH)pytest
2 changes: 2 additions & 0 deletions pytest.ini
Original file line number Diff line number Diff line change
@@ -0,0 +1,2 @@
[pytest]
markers = slow: marks tests as slow (deselect with '-m "not slow"')
18 changes: 17 additions & 1 deletion tests/test_discovery_svm.py
Original file line number Diff line number Diff line change
Expand Up @@ -57,4 +57,20 @@ def test_preprocess(mock_read_csv, mock_data):
X, y = discovery_svm.extract_features(data)
assert X.shape == (100, 44)

# TODO: add tests for SVM grid_search and evaluate_model
@pytest.mark.slow
def test_grid_search():
X = np.random.randn(100, 44)
y = np.random.choice([0, 1], 100)
grid = discovery_svm.grid_search(X, y)
assert isinstance(grid, discovery_svm.GridSearchCV)
assert hasattr(grid, 'best_params_')
assert hasattr(grid, 'best_score_')

def test_evaluate_model():
X = np.random.randn(100, 44)
y = np.random.choice([0, 1], 100)
stats = discovery_svm.evaluate_model(X, y)
assert isinstance(stats, tuple)
assert len(stats) == 3
for i in range(3):
assert len(stats[i]) == 5

0 comments on commit 8784999

Please sign in to comment.