Skip to content

Commit

Permalink
add additional plots
Browse files Browse the repository at this point in the history
  • Loading branch information
azoz01 committed Aug 20, 2024
1 parent 9265af2 commit ed643d5
Show file tree
Hide file tree
Showing 2 changed files with 67 additions and 6 deletions.
13 changes: 12 additions & 1 deletion bin/analysis/analyze_distance_correlation.py
Original file line number Diff line number Diff line change
Expand Up @@ -47,6 +47,17 @@ def main():
)[0]
),
),
(
"Dataset2Vec reconstruction (representations)",
Dataset2VecForLandmarkerReconstruction.load_from_checkpoint(
list(
(
paths_provider.encoders_results_path
/ "d2v_reconstruction"
).rglob("*.ckpt")
)[0]
),
),
]
logger.info("Analyzing")
results = dict()
Expand Down Expand Up @@ -102,7 +113,7 @@ def main():
correlations.append(
spearmanr(landmarkers_distances, datasets_distances).statistic
)
results["Dataset2Vec reconstruction"] = {
results["Dataset2Vec reconstruction (to landmarkers)"] = {
"mean": np.mean(correlations),
"std": np.std(correlations),
}
Expand Down
60 changes: 55 additions & 5 deletions bin/analysis/analyze_warmstart.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,10 +20,19 @@ def main():
aggregates.columns = ["dataset", "min_value", "max_value"]

plot_data = df.merge(aggregates, "left", "dataset")
plot_data = plot_data.sort_values(["dataset", "warmstart", "number"])
plot_data = plot_data.sort_values(
["dataset", "warmstart", "number", "value"]
)
plot_data["cumulative_max_value"] = plot_data.groupby(
["dataset", "warmstart"]
)["value"].cummax()
plot_data["neg_value"] = -plot_data["value"]
plot_data = plot_data.sort_values(
["dataset", "warmstart", "number", "neg_value"]
)
plot_data["rank"] = plot_data.groupby(["dataset", "number"])[
"neg_value"
].rank()

plot_data["distance"] = (
plot_data["max_value"] - plot_data["cumulative_max_value"]
Expand All @@ -34,18 +43,59 @@ def main():
plot_data["scaled_value"] = (df["value"] - plot_data["min_value"]) / (
plot_data["max_value"] - plot_data["min_value"]
)
plot_data["number"] = plot_data["number"] + 1

logger.info("Generating plot of ranks over time")
fig, ax = plt.subplots(figsize=(10, 5))
sns.lineplot(plot_data, x="number", y="rank", hue="warmstart", ax=ax)
ax.vlines(
x=5,
ymin=1,
ymax=7,
colors="black",
label="End of warm-start phase",
linestyles="dotted",
)
ax.set_ylabel("Rank")
ax.set_xlabel("Number of iteration")
ax.set_title("Rank of the negative ROC AUC over time")
plt.savefig(paths_provider.results_analysis_path / "rank_over_time.png")
plt.clf()

logger.info("Generating plot for raw values")
fig, ax = plt.subplots(figsize=(10, 10))
fig, ax = plt.subplots(figsize=(10, 5))
sns.lineplot(
plot_data, x="number", y="scaled_value", hue="warmstart", ax=ax
)
ax.vlines(
x=5,
ymin=0,
ymax=1,
colors="black",
label="End of warm-start phase",
linestyles="dotted",
)
ax.set_ylabel("Scaled ROC AUC")
ax.set_xlabel("Number of iteration")
ax.set_title("Scaled value of the ROC AUC over time")
plt.savefig(paths_provider.results_analysis_path / "raw_values.png")
plt.clf()

logger.info("Generating ADTM plot")
fig, ax = plt.subplots(figsize=(10, 5))
sns.lineplot(data=plot_data, x="number", y="distance", hue="warmstart")
sns.lineplot(
data=plot_data, x="number", y="distance", hue="warmstart", ax=ax
)
ax.vlines(
x=5,
ymin=0,
ymax=0.25,
colors="black",
label="End of warm-start phase",
linestyles="dotted",
)
ax.set_xlabel("Number of iteration")
ax.set_ylabel("Scaled distance")
plt.savefig(paths_provider.results_analysis_path / "adtm.png")
plt.clf()

Expand All @@ -55,7 +105,7 @@ def main():
plot_data["dataset_name"] = plot_data["dataset"]
draw_cd_diagram(
df_perf=plot_data[["classifier_name", "dataset_name", "accuracy"]].loc[
plot_data.number == 4
plot_data.number == 5
]
)
plt.savefig(
Expand All @@ -65,7 +115,7 @@ def main():

draw_cd_diagram(
df_perf=plot_data[["classifier_name", "dataset_name", "accuracy"]].loc[
plot_data.number == 19
plot_data.number == 20
]
)
plt.savefig(
Expand Down

0 comments on commit ed643d5

Please sign in to comment.