Skip to content

Commit

Permalink
Merge pull request #26 from valeriocaporioniunipi/valerio
Browse files Browse the repository at this point in the history
Valerio
  • Loading branch information
valeriocaporioniunipi authored May 21, 2024
2 parents 48b203e + fca9025 commit 7ef159c
Show file tree
Hide file tree
Showing 13 changed files with 740 additions and 383 deletions.
Binary file modified .DS_Store
Binary file not shown.
Binary file added code/.DS_Store
Binary file not shown.
1 change: 1 addition & 0 deletions code/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@

Binary file modified code/__pycache__/csvreader.cpython-311.pyc
Binary file not shown.
Binary file added code/__pycache__/utils.cpython-311.pyc
Binary file not shown.
26 changes: 0 additions & 26 deletions code/abspath.py

This file was deleted.

148 changes: 0 additions & 148 deletions code/csvreader.py

This file was deleted.

100 changes: 61 additions & 39 deletions code/gaussian_reg.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
import numpy as np
import argparse
import numpy as np
from loguru import logger
import matplotlib.pyplot as plt

Expand All @@ -8,10 +8,9 @@
from sklearn.metrics import mean_absolute_error, mean_squared_error, r2_score
from sklearn.preprocessing import StandardScaler

from abspath import abs_path
from csvreader import get_data
from utils import abs_path, get_data

def gaussian_reg(filename, n_splits, ex_cols=0, plot_flag=False):
def gaussian_reg(features, targets, n_splits, **kwargs):
"""
gaussian_reg performs gaussian regression with k-fold cross-validation on the
given dataset and prints evaluation metrics of the gaussian regression model
Expand All @@ -24,33 +23,36 @@ def gaussian_reg(filename, n_splits, ex_cols=0, plot_flag=False):
:param plot_flag: optional (default = False): Whether to plot the actual vs. predicted values
:type plot_flag: bool
:return: None
"""
# Loading data...
#Importing features excluded first three columns: FILE_ID, AGE_AT_SCAN, SEX
x = get_data(filename)[:, ex_cols:]
y = get_data(filename, "AGE_AT_SCAN")
# Definition of keyword arguments
plot_flag = kwargs.get('plot_flag', False)
group = kwargs.get('group', None)

# Standardize features
# Initialize data standardization (done after the k-folding split to avoid leakage)
scaler = StandardScaler()
x_scaled = scaler.fit_transform(x)

# Initialize k-fold cross-validation
kf = KFold(n_splits=n_splits)
kf = KFold(n_splits=n_splits, shuffle = True, random_state= 42)

# Initialize lists to store evaluation metrics
mae_scores = []
mse_scores = []
r2_scores = []
mae_scores, mse_scores, r2_scores = [], [], []

# Initialize figure for plotting
plt.figure(figsize=(10, 8))
if plot_flag:
if group is not None:
fig, (ax, ax_group) = plt.subplots(1, 2, figsize=(20, 8))
else:
fig, ax = plt.subplots(figsize=(10, 8))

# Perform k-fold cross-validation
for i, (train_index, test_index) in enumerate(kf.split(x_scaled), 1):
for i, (train_index, test_index) in enumerate(kf.split(features), 1):
# Split data into training and testing sets
x_train, x_test = x_scaled[train_index], x_scaled[test_index]
y_train, y_test = y[train_index], y[test_index]
x_train, x_test = features[train_index], features[test_index]
y_train, y_test = targets[train_index], targets[test_index]
if group is not None:
group_test = group[test_index]

x_train = scaler.fit_transform(x_train)
x_test = scaler.transform(x_test)

# Initialize and fit Gaussian process regression model
model = GaussianProcessRegressor()
Expand All @@ -69,30 +71,49 @@ def gaussian_reg(filename, n_splits, ex_cols=0, plot_flag=False):
r2_scores.append(r2)

# Plot actual vs. predicted values for current fold
plt.scatter(y_test, y_pred, alpha=0.5, label=f'Fold {i} - MAE = {np.round(mae_scores[i-1], 2)}')
if plot_flag:
ax.scatter(y_test, y_pred, alpha=0.5,
label=f'Fold {i} - MAE = {np.round(mae_scores[i-1], 2)}')
if group is not None:
y_test_exp = y_test[group_test == 1]
y_pred_exp = y_pred[group_test == 1]
y_test_control = y_test[group_test == -1]
y_pred_control = y_pred[group_test == -1]
ax_group.scatter(y_test_exp, y_pred_exp,color = 'r', alpha = 0.5)
ax_group.scatter(y_test_control, y_pred_control, color = 'royalblue', alpha = 0.5)

# Print average evaluation metrics over all folds
print("Mean Absolute Error:", np.mean(mae_scores))
print("Mean Squared Error:", np.mean(mse_scores))
print("R-squared:", np.mean(r2_scores))

if plot_flag:

target_range = [targets.min(), targets.max()]
# Plot the ideal line (y=x)
plt.plot([y.min(), y.max()], [y.min(), y.max()], 'k--', lw=2)
ax.plot(target_range, target_range, 'k--', lw=2)

# Set plot labels and title
plt.xlabel('Actual')
plt.ylabel('Predicted')
plt.title('Actual vs. Predicted Brain Age')
ax.set_xlabel('Actual age [y]', fontsize = 20)
ax.set_ylabel('Predicted age [y]', fontsize = 20)
ax.set_title('Actual vs. predicted age', fontsize = 24)

# Add legend and grid to the plot
plt.legend()
plt.grid(True)


ax.legend(fontsize = 16)
ax.grid(False)
if group is not None:
ax_group.plot(target_range, target_range, 'k--', lw=2)
ax_group.set_xlabel('Actual age [y]', fontsize = 20)
ax_group.set_ylabel('Predicted age [y]', fontsize = 20)
ax_group.set_title('Actual vs. predicted age - exp. vs. control', fontsize = 24)
ax_group.grid(False)
exp_legend = ax_group.scatter([], [], marker = 'o', color = 'r', label = 'exp.', alpha = 0.5)
control_legend = ax_group.scatter([], [], marker = 'o', color = 'royalblue', label = 'control', alpha = 0.5)
ax_group.legend(handles = [exp_legend, control_legend], loc='lower right', fontsize = 16)
# Show the plot
#plt.savefig('/Users/valeriocaporioni/Downloads/gaussian_reg.png', transparent = True)
plt.show()
else:
logger.info("Skipping the plot of actual vs predicted age ")

def gaussian_reg_parsing():
"""
Expand Down Expand Up @@ -134,31 +155,32 @@ def gaussian_reg_parsing():
"""
parser = argparse.ArgumentParser(description=
'Gaussian regression predicting the age of patients from magnetic resonance imaging')
'Linear regression predicting the age of patients from magnetic resonance imaging')

parser.add_argument("filename",
help="Name of the file that has to be analized")
help="Name of the file that has to be analized if --location argument is"
" passed. Otherwise pass to filename the absolutepath of the file")
parser.add_argument("--target", default = "AGE_AT_SCAN",
help="Name of the colums holding target values")
help="Name of the column holding target values")
parser.add_argument("--location",
help="Location of the file, i.e. folder containing it")
parser.add_argument("--folds", type = int, default = 5,
help="Number of folds in the k-folding (>4, default 5)")
parser.add_argument("--ex_cols", type = int, default = 3,
parser.add_argument("--ex_cols", type = int, default = 5,
help="Number of columns excluded when importing (default 3)")
parser.add_argument("--plot", action="store_true",
help="Show the plot of actual vs predicted brain age")
parser.add_argument("--group", default = 'DX_GROUP',
help="Name of the column indicating the group (experimental vs control)")

args = parser.parse_args()

if args.folds > 4:
try:
args.filename = abs_path(args.filename,
args.location) if args.location else args.filename
args.filename = abs_path(args.filename,args.location) if args.location else args.filename
logger.info(f"Opening file : {args.filename}")
features, targets = get_data(args.filename, args.target, args.ex_cols)
gaussian_reg(features, targets, args.epochs, args.folds,
args.summary, args.history, args.plot)
features, targets, group = get_data(args.filename, args.target, args.ex_cols, group_col = args.group)
gaussian_reg(features, targets, args.folds, plot_flag = args.plot, group = group)
except FileNotFoundError:
logger.error("File not found.")
else:
Expand Down
Loading

0 comments on commit 7ef159c

Please sign in to comment.