-
Notifications
You must be signed in to change notification settings - Fork 0
/
analyse_sweep.py
53 lines (40 loc) · 1.41 KB
/
analyse_sweep.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
import argparse
import glob
import os
import matplotlib.pyplot as plt
import pandas as pd
import seaborn as sns
# Command-line interface. Built at module level so the ``__main__`` guard can
# parse with it; ArgumentDefaultsHelpFormatter makes ``--help`` show defaults.
parser = argparse.ArgumentParser(
    description=(
        "Plot epoch time against cache limit from the CSV results of a "
        "cache-limit sweep."
    ),
    formatter_class=argparse.ArgumentDefaultsHelpFormatter,
)
parser.add_argument(
    "--csv-dir",
    type=str,
    default="outputs/",
    help="directory containing the sweep result *.csv files",
)
parser.add_argument(
    "--fig-save-dir",
    type=str,
    default="assets/",
    help="directory the sweep.png figure is written to",
)

# Dataset size in GiB; main() divides cache limits by this value to express
# them as a percentage.
DATASET_SIZE_GIB = 9.2
def main(args):
    """Plot epoch time against cache limit for every sweep CSV in a directory.

    Reads all ``*.csv`` files in ``args.csv_dir``, keeps the
    ``cache_limit_gib``, ``epoch_0_time`` and ``epoch_1_time`` columns,
    converts the cache limit to a percentage of ``DATASET_SIZE_GIB`` (capped
    at 100) and saves a per-epoch regression plot as ``sweep.png`` inside
    ``args.fig_save_dir``.

    Args:
        args: Namespace providing the ``csv_dir`` and ``fig_save_dir``
            string attributes.

    Raises:
        ValueError: If ``args.csv_dir`` contains no ``*.csv`` files.
        KeyError: If the combined CSV data lacks one of the required columns.
    """
    # Sort so the concatenated row order does not depend on the arbitrary
    # order in which glob returns directory entries.
    files = sorted(glob.glob(os.path.join(args.csv_dir, "*.csv")))
    if not files:
        # Without this check pd.concat fails with the much less helpful
        # "No objects to concatenate".
        raise ValueError(f"No CSV files found in {args.csv_dir}")

    df = pd.concat((pd.read_csv(f) for f in files), ignore_index=True)

    # Keep only the columns the plot needs, then go from wide (one column per
    # epoch) to long (one row per measurement) with the final axis labels.
    df = df[["cache_limit_gib", "epoch_0_time", "epoch_1_time"]]
    df = df.melt(
        id_vars=["cache_limit_gib"],
        value_vars=["epoch_0_time", "epoch_1_time"],
        var_name="Epoch",
        value_name="Time (s)",
    )

    # Cache limit as a percentage of the dataset size; values above 100 are
    # capped at 100.
    cache_pct = df["cache_limit_gib"] / DATASET_SIZE_GIB * 100
    df["cache_limit_gib"] = cache_pct.clip(upper=100)
    df = df.rename(columns={"cache_limit_gib": "Cache Limit (%)"})
    df["Epoch"] = df["Epoch"].replace({"epoch_0_time": "0", "epoch_1_time": "1"})

    # x_jitter adds horizontal noise to the scatter points only, which keeps
    # runs that share a cache limit from overlapping.
    sns.lmplot(
        data=df,
        x="Cache Limit (%)",
        y="Time (s)",
        hue="Epoch",
        x_jitter=2.0,
    )
    plt.ylim(bottom=0)

    # Create the output directory if needed; an empty string means the
    # current directory, which os.makedirs would reject.
    if args.fig_save_dir:
        os.makedirs(args.fig_save_dir, exist_ok=True)
    plt.savefig(os.path.join(args.fig_save_dir, "sweep.png"))
if __name__ == "__main__":
    # Script entry point: parse the command line and run the analysis.
    main(parser.parse_args())