Skip to content

Commit

Permalink
Add new PLSR plots (#105)
Browse files Browse the repository at this point in the history
Co-authored-by: Andrew Ramirez <aramirez@aretha.seas.ucla.edu>
  • Loading branch information
andrewram4287 and Andrew Ramirez authored Nov 5, 2024
1 parent 527409e commit e708e36
Show file tree
Hide file tree
Showing 4 changed files with 116 additions and 60 deletions.
1 change: 0 additions & 1 deletion pf2/figures/figureA1.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,6 @@ def makeFigure():
subplotLabel(ax)

X = anndata.read_h5ad("/opt/northwest_bal/full_fitted.h5ad")
X.uns["Pf2_A"] = correct_conditions(X)
add_obs(X, "patient_category")
add_obs(X, "binary_outcome")
add_obs(X, "episode_etiology")
Expand Down
76 changes: 37 additions & 39 deletions pf2/figures/figureA10.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,72 +13,70 @@

def makeFigure():
"""Get a list of the axis objects and create a figure."""
ax, f = getSetup((8, 8), (2, 2))
ax, f = getSetup((5, 4), (2, 2))
subplotLabel(ax)

X = anndata.read_h5ad("/opt/northwest_bal/full_fitted.h5ad")

meta = import_meta()
conversions = convert_to_patients(X)
meta = import_meta(drop_duplicates=False)
conversions = convert_to_patients(X, sample=True)

patient_factor = pd.DataFrame(
X.uns["Pf2_A"],
index=conversions,
columns=np.arange(X.uns["Pf2_A"].shape[1]) + 1,
)
meta = meta.loc[patient_factor.index, :]
meta.set_index("sample_id", inplace=True)

shared_indices = patient_factor.index.intersection(meta.index)
patient_factor = patient_factor.loc[shared_indices, :]
meta = meta.loc[shared_indices, :]

labels, plsr_results_both = plsr_acc(patient_factor, meta)

plot_plsr_loadings(plsr_results_both, ax[0], ax[1], text=False)
ax[0].set(xlim=[-0.4, 0.4], ylim=[-0.4, 0.4])
ax[1].set(xlim=[-0.4, 0.4], ylim=[-0.4, 0.4])
plot_plsr_loadings(plsr_results_both, ax[0], ax[1])
ax[0].set(xlim=[-0.35, 0.35])
ax[1].set(xlim=[-0.4, 0.4])

# plot_plsr_loadings(plsr_results_both, ax[2], ax[3], text=True)
plot_plsr_scores(plsr_results_both, meta, labels, ax[2], ax[3])
ax[2].set(xlim=[-9, 9], ylim=[-9, 9])
ax[3].set(xlim=[-8, 8], ylim=[-8, 8])
ax[2].set(xlim=[-9.5, 9.5])
ax[3].set(xlim=[-8.5, 8.5])

return f


def plsr_acc(patient_factor_matrix, meta_data):
def plsr_acc(patient_factor_matrix, meta_data, n_components=1):
"""Runs PLSR and obtains average prediction accuracy"""

_, labels, [c19_plsr, nc19_plsr] = predict_mortality(
patient_factor_matrix, meta_data, proba=False
patient_factor_matrix, meta_data, n_components=n_components, proba=False
)

return labels, [c19_plsr, nc19_plsr]


def plot_plsr_loadings(plsr_results, ax1, ax2, text=False):
def plot_plsr_loadings(plsr_results, ax1, ax2):
"""Runs PLSR and plots ROC AUC based on actual and prediction labels"""
ax = [ax1, ax2]
type_of_data = ["C19", "nC19"]

for i in range(2):
ax[i].scatter(
plsr_results[i].y_loadings_[0, 0],
plsr_results[i].y_loadings_[0, 1],
c="tab:red",
df_xload = pd.DataFrame(data=plsr_results[i].x_loadings_[:, 0], columns=["PLSR 1"])
df_yload = pd.DataFrame(data=[[plsr_results[i].y_loadings_[0, 0]]], columns=["PLSR 1"])
sns.swarmplot(
data=df_xload,
x="PLSR 1",
ax=ax[i],
color="k",
)
ax[i].scatter(
plsr_results[i].x_loadings_[:, 0], plsr_results[i].x_loadings_[:, 1], c="k"
sns.swarmplot(
data=df_yload,
x="PLSR 1",
ax=ax[i],
color="r",

)
if text:
for index, component in enumerate(plsr_results[i].coef_.index):
ax[i].text(
plsr_results[i].x_loadings_[index, 0],
plsr_results[i].x_loadings_[index, 1] - 0.001,
ha="center",
ma="center",
va="center",
s=component,
c="w",
)

ax[i].set(xlabel="PLSR 1", ylabel="PLSR 2", title=f"{type_of_data[i]}-loadings")
ax[i].set(xlabel="PLSR 1", ylabel="Pf2 Components", title=f"{type_of_data[i]}-loadings")


def plot_plsr_scores(plsr_results, meta_data, labels, ax1, ax2):
Expand All @@ -105,14 +103,14 @@ def plot_plsr_scores(plsr_results, meta_data, labels, ax1, ax2):
numb1=0; numb2=2
else:
numb1=1; numb2=3

sns.scatterplot(
x=plsr_results[i].x_scores_[:, 0],
y=plsr_results[i].x_scores_[:, 1],

df_xscores = pd.DataFrame(data=plsr_results[i].x_scores_[:, 0], columns=["PLSR 1"])
sns.swarmplot(
data=df_xscores,
x="PLSR 1",
ax=ax[i],
hue=score_labels.to_numpy(),
palette=[pal[numb1], pal[numb2]],
hue_order=[1, 0],
ax=ax[i],
)

ax[i].set(xlabel="PLSR 1", ylabel="PLSR 2", title=f"{type_of_data[i]}-loadings")
ax[i].set(xlabel="PLSR 1", ylabel="Samples", title=f"{type_of_data[i]}-scores")
50 changes: 33 additions & 17 deletions pf2/figures/figureA9.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,10 +8,11 @@
from sklearn.metrics import accuracy_score
import seaborn as sns
from ..data_import import convert_to_patients, import_meta
from ..predict import predict_mortality
from ..predict import predict_mortality, predict_mortality_all
from .common import subplotLabel, getSetup
from sklearn.metrics import RocCurveDisplay
from sklearn.metrics import accuracy_score, roc_auc_score
from pf2.figures.commonFuncs.plotGeneral import bal_combine_bo_covid


def makeFigure():
Expand All @@ -21,38 +22,46 @@ def makeFigure():

X = anndata.read_h5ad("/opt/northwest_bal/full_fitted.h5ad")

meta = import_meta()
conversions = convert_to_patients(X)
meta = import_meta(drop_duplicates=False)
conversions = convert_to_patients(X, sample=True)

patient_factor = pd.DataFrame(
X.uns["Pf2_A"],
index=conversions,
columns=np.arange(X.uns["Pf2_A"].shape[1]) + 1,
)
meta = meta.loc[patient_factor.index, :]
meta.set_index("sample_id", inplace=True)

shared_indices = patient_factor.index.intersection(meta.index)
patient_factor = patient_factor.loc[shared_indices, :]
meta = meta.loc[shared_indices, :]

roc_auc = [False, True]
for i in range(2):
plsr_acc_df = pd.DataFrame([])
for j in range(3):
df = plsr_acc_proba(
patient_factor, meta, n_components=j + 1, roc_auc=roc_auc[i]
)
print(df)
df["Component"] = j + 1
plsr_acc_df = pd.concat([plsr_acc_df, df], axis=0)

plsr_acc_df = plsr_acc_df.melt(
id_vars="Component", var_name="Category", value_name="Accuracy"
)
sns.barplot(
data=plsr_acc_df, x="Component", y="Accuracy", hue="Category", ax=ax[i]
data=plsr_acc_df, x="Component", y="Accuracy", hue="Category", ax=ax[i],
hue_order=["C19", "nC19", "Overall"]
)
if roc_auc[i] is True:
ax[i].set(ylim=[0, 1], ylabel="AUC ROC")
else:
ax[i].set(ylim=[0, 1], ylabel="Prediction Accuracy")

plot_plsr_auc_roc(patient_factor, meta, ax[2])
for i in range(2):
plot_plsr_auc_roc(patient_factor, meta, n_components=i + 1, ax=ax[i + 2])
ax[i + 2].set(title=f"PLSR {i + 1} Components")

return f

Expand All @@ -65,6 +74,10 @@ def plsr_acc_proba(patient_factor_matrix, meta_data, n_components=2, roc_auc=Tru
probabilities, labels = predict_mortality(
patient_factor_matrix, n_components=n_components, meta=meta_data, proba=True
)

probabilities_all, labels_all = predict_mortality_all(
patient_factor_matrix, n_components=n_components, meta=meta_data, proba=True
)

probabilities = probabilities.round().astype(int)
meta_data = meta_data.loc[~meta_data.index.duplicated()].loc[labels.index]
Expand All @@ -73,41 +86,44 @@ def plsr_acc_proba(patient_factor_matrix, meta_data, n_components=2, roc_auc=Tru
score = roc_auc_score
else:
score = accuracy_score

covid_acc = score(
labels.loc[meta_data.loc[:, "patient_category"] == "COVID-19"],
probabilities.loc[meta_data.loc[:, "patient_category"] == "COVID-19"],
labels.loc[meta_data.loc[:, "patient_category"] == "COVID-19"].to_numpy().astype(int),
probabilities.loc[meta_data.loc[:, "patient_category"] == "COVID-19"].to_numpy(),
)
nc_acc = score(
labels.loc[meta_data.loc[:, "patient_category"] != "COVID-19"],
labels.loc[meta_data.loc[:, "patient_category"] != "COVID-19"].to_numpy().astype(int),
probabilities.loc[meta_data.loc[:, "patient_category"] != "COVID-19"],
)
acc = score(labels, probabilities)
acc = score(labels_all.to_numpy().astype(int), probabilities_all.round().astype(int))

acc_df.loc[0, :] = [acc, covid_acc, nc_acc]

return acc_df


def plot_plsr_auc_roc(patient_factor_matrix, meta_data, ax):
def plot_plsr_auc_roc(patient_factor_matrix, meta_data, n_components, ax):
"""Runs PLSR and plots ROC AUC based on actual and prediction labels"""
probabilities, labels = predict_mortality(
patient_factor_matrix, meta_data, proba=True
patient_factor_matrix, meta_data, n_components=n_components, proba=True
)
probabilities_all, labels_all = predict_mortality_all(
patient_factor_matrix, n_components=n_components, meta=meta_data, proba=True
)
meta_data = meta_data.loc[~meta_data.index.duplicated()].loc[labels.index]

RocCurveDisplay.from_predictions(
labels.loc[meta_data.loc[:, "patient_category"] == "COVID-19"],
labels.loc[meta_data.loc[:, "patient_category"] == "COVID-19"].to_numpy().astype(int),
probabilities.loc[meta_data.loc[:, "patient_category"] == "COVID-19"],
ax=ax,
name="C19",
)
RocCurveDisplay.from_predictions(
labels.loc[meta_data.loc[:, "patient_category"] != "COVID-19"],
labels.loc[meta_data.loc[:, "patient_category"] != "COVID-19"].to_numpy().astype(int),
probabilities.loc[meta_data.loc[:, "patient_category"] != "COVID-19"],
ax=ax,
name="nC19",
)
RocCurveDisplay.from_predictions(
labels, probabilities, plot_chance_level=True, ax=ax, name="Overall"
)
labels_all.to_numpy().astype(int), probabilities_all, plot_chance_level=True, ax=ax, name="Overall"
)
49 changes: 46 additions & 3 deletions pf2/predict.py
Original file line number Diff line number Diff line change
Expand Up @@ -71,8 +71,10 @@ def predict_mortality(

data = data.loc[meta.loc[:, "patient_category"] != "Non-Pneumonia Control", :]
meta = meta.loc[meta.loc[:, "patient_category"] != "Non-Pneumonia Control", :]
labels = data.index.to_series().replace(meta.loc[:, "binary_outcome"])


meta.loc[:, "binary_outcome"] = meta["binary_outcome"].astype("category")
labels = data.index.to_series().map(meta["binary_outcome"])

covid_data = data.loc[meta.loc[:, "patient_category"] == "COVID-19", :]
covid_labels = meta.loc[
meta.loc[:, "patient_category"] == "COVID-19", "binary_outcome"
Expand All @@ -95,4 +97,45 @@ def predict_mortality(

else:
predicted = predictions.round().astype(int)
return accuracy_score(labels, predicted), labels, (c_plsr, nc_plsr)
return accuracy_score(labels.to_numpy().astype(int), predicted), labels, (c_plsr, nc_plsr)


def predict_mortality_all(
    data: pd.DataFrame, meta: pd.DataFrame, proba: bool = False, n_components: int = 2
):
    """
    Predicts mortality with a single PLSR run over all pneumonia samples.

    Rows whose "patient_category" is "Non-Pneumonia Control" are dropped;
    every remaining sample (COVID-19 or not) is passed to one run_plsr call.

    Parameters:
        data (pd.DataFrame): data to predict; it is filtered with a boolean
            mask built from meta, so meta's index must cover data's index
        meta (pd.DataFrame): patient meta-data; must provide the
            "patient_category" and "binary_outcome" columns
        proba (bool, default:False): return probability of prediction;
            also forwarded to run_plsr
        n_components (int, default:2): forwarded to run_plsr
    Returns:
        if proba:
            probabilities (pd.Series): predicted probability of mortality for
                patients
            labels (pd.Series): classification targets
        else:
            accuracy (float): prediction accuracy
            labels (pd.Series): classification targets
            model: the single fitted object returned by run_plsr
    """
    if not isinstance(data, pd.DataFrame):
        data = pd.DataFrame(data)

    # Build the mask once so data and meta are filtered identically.
    keep = meta.loc[:, "patient_category"] != "Non-Pneumonia Control"
    data = data.loc[keep, :]
    meta = meta.loc[keep, :]

    # Index-based lookup of each sample's outcome (same approach as
    # predict_mortality), then coerce to plain integer targets.
    outcomes = data.index.to_series().map(meta.loc[:, "binary_outcome"])
    labels = pd.Series(outcomes.to_numpy().astype(int), index=data.index)

    # Explicit float dtype: an index-only Series defaults to object dtype on
    # pandas >= 2.0, which would break the numeric round() below.
    predictions = pd.Series(index=data.index, dtype=float)
    fitted_values, all_plsr = run_plsr(
        data, labels, proba=proba, n_components=n_components
    )
    predictions[:] = fitted_values

    if proba:
        return predictions, labels

    predicted = predictions.round().astype(int)
    return accuracy_score(labels, predicted), labels, all_plsr

0 comments on commit e708e36

Please sign in to comment.