Skip to content

Commit

Permalink
Merge branch 'issue-1' into develop
Browse files Browse the repository at this point in the history
  • Loading branch information
wleoncio committed Jul 31, 2024
2 parents 024ff3e + 8784999 commit dcb2665
Show file tree
Hide file tree
Showing 4 changed files with 59 additions and 4 deletions.
3 changes: 3 additions & 0 deletions Makefile
Original file line number Diff line number Diff line change
Expand Up @@ -16,4 +16,7 @@ local-install:
$(ENV_PATH)pip install .

test:
$(ENV_PATH)pytest -m "not slow"

test-full:
$(ENV_PATH)pytest
2 changes: 2 additions & 0 deletions pytest.ini
Original file line number Diff line number Diff line change
@@ -0,0 +1,2 @@
[pytest]
markers = slow: marks tests as slow (deselect with '-m "not slow"')
40 changes: 36 additions & 4 deletions src/pCRscore/discovery_svm.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,8 +2,10 @@
import numpy
from sklearn import preprocessing
from sklearn.metrics import make_scorer, f1_score, accuracy_score
from sklearn.model_selection import GridSearchCV, train_test_split, StratifiedKFold
from sklearn.model_selection import \
GridSearchCV, train_test_split, StratifiedKFold, KFold, cross_val_score
from sklearn.svm import SVC
from sklearn.datasets import make_classification

# TODO: add code from
# https://github.com/YounessAzimzade/XML-TME-NAC-BC/blob/main/Discovery%20SVM.ipynb
Expand Down Expand Up @@ -51,8 +53,7 @@ def extract_features(data):

return X, y

def grid_search(X, y, n_cores = 1):
# TODO: profile (with cProfile?) to possibly add progress bar
def grid_search(X, y, n_cores = 1, verbose = 0):
# Defining the parameter range for the hyperparameter grid search
param_grid = {
'C': numpy.exp(numpy.linspace(-12, 3, num = 50)),
Expand All @@ -77,10 +78,41 @@ def grid_search(X, y, n_cores = 1):
# no verbosity
grid = GridSearchCV(
SVC(class_weight='balanced'),
param_grid, scoring = scoring, refit = 'F1', cv = 10, n_jobs = n_cores
param_grid, scoring = scoring, refit = 'F1', cv = 10, n_jobs = n_cores,
verbose = verbose
)

# Fit the model for grid search using the training data
grid.fit(X_train, y_train)

return grid

def evaluate_model(X, y, verbose = False):
    """Cross-validate a fixed RBF-kernel SVC and return its per-fold scores.

    The classifier is an ``SVC`` with ``C = 1``, ``gamma = 0.1``, an RBF
    kernel, probability estimates enabled and balanced class weights. It is
    scored with a shuffled 5-fold ``KFold`` split (``random_state=1``) on
    three metrics: accuracy, F1 and ROC AUC.

    Parameters
    ----------
    X
        Feature matrix, passed unchanged to ``cross_val_score``.
    y
        Target labels, passed unchanged to ``cross_val_score``.
    verbose
        When true, print the mean and standard deviation of each metric.
        Accuracy is printed as a percentage; F1 and AUC are printed as-is.

    Returns
    -------
    dict
        Maps ``'Accuracy'``, ``'f1 score'`` and ``'AUC'`` to the score
        arrays returned by ``cross_val_score`` (one entry per fold).
    """
    # We normally start with the model that has the best performance and
    # fine tune the parameters to find the best model.
    # Here, the following model was found to have the best performance
    # based on combined score.
    model = SVC(
        C = 1, gamma = 0.1, kernel = 'rbf', probability = True,
        class_weight = 'balanced'
    )

    # The same splitter is reused for every metric so that all scores are
    # computed on identical folds.
    cv = KFold(n_splits=5, random_state=1, shuffle=True)

    def _cross_validate(metric):
        # One cross-validation run for a single scoring metric.
        return cross_val_score(model, X, y, scoring=metric, cv=cv, n_jobs=-1)

    # Local names deliberately avoid `f1_score` / `accuracy_score`, which are
    # imported from sklearn.metrics at module level.
    acc_folds = _cross_validate('accuracy')
    f1_folds = _cross_validate('f1')
    auc_folds = _cross_validate('roc_auc')

    # report performance
    if verbose:
        print('Accuracy: %.3f (%.3f)' % (numpy.mean(acc_folds)*100, numpy.std(acc_folds)*100))
        print('f1 score: %.3f (%.3f)' % (numpy.mean(f1_folds), numpy.std(f1_folds)))
        print('AUC: %.3f (%.3f)' % (numpy.mean(auc_folds), numpy.std(auc_folds)))

    return {'Accuracy': acc_folds, 'f1 score': f1_folds, 'AUC': auc_folds}
18 changes: 18 additions & 0 deletions tests/test_discovery_svm.py
Original file line number Diff line number Diff line change
Expand Up @@ -56,3 +56,21 @@ def test_preprocess(mock_read_csv, mock_data):

X, y = discovery_svm.extract_features(data)
assert X.shape == (100, 44)

@pytest.mark.slow
def test_grid_search():
    """Grid search on random data returns a fitted GridSearchCV object."""
    n_samples, n_features = 100, 44
    features = np.random.randn(n_samples, n_features)
    labels = np.random.choice([0, 1], n_samples)

    fitted = discovery_svm.grid_search(features, labels)

    assert isinstance(fitted, discovery_svm.GridSearchCV)
    # These attributes are checked as evidence that the search was fitted.
    for attribute in ('best_params_', 'best_score_'):
        assert hasattr(fitted, attribute)

def test_evaluate_model():
    """evaluate_model returns a dict of three metrics with one score per fold."""
    X = np.random.randn(100, 44)
    y = np.random.choice([0, 1], 100)
    stats = discovery_svm.evaluate_model(X, y)
    # evaluate_model returns a dict keyed by metric name, not a tuple.
    assert isinstance(stats, dict)
    assert set(stats) == {'Accuracy', 'f1 score', 'AUC'}
    # The model is evaluated with 5-fold cross-validation, so each metric
    # must carry exactly five scores.
    for scores in stats.values():
        assert len(scores) == 5

0 comments on commit dcb2665

Please sign in to comment.