mem_sbatch_athena.py
import argparse
import os
from pathlib import Path
from experiments_config import all_experiments
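
# This script generates sbatch scripts that benchmark the GPU memory usage of
# several RL algorithms across task universes, then submits them to SLURM.
# Example invocation (illustrative, using values from parse_arguments below):
#
#   python mem_sbatch_athena.py --algos sac td3 --universes dm_control --seed 0

# sbatch script template; {partition}, {gres}, {outputs_name}, {csv_filename}
# and {commands_to_run} are filled in from the parsed arguments further below.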
sbatch_script_template = """#!/bin/bash
#
#SBATCH --partition={partition}
#SBATCH --gres={gres}
#SBATCH --output=./outputs/{outputs_name}
source $SCRATCH/miniconda3/bin/activate base
conda init bash
source ~/.bashrc
export PYTHONPATH=$(pwd)
touch {csv_filename}
echo -e "algo,universe,env,mem" >> {csv_filename}
{commands_to_run}
"""
ALGOS_TO_KWARGS = {
    'sac': {'main': 'src/pytorch_sac/train.py',
            'hydra': 1,
            'run_args_str': 'num_train_steps=15500 eval_frequency=1000000 num_seed_steps=5000 eval_0_step=false csv_writing_enabled=false track_GPU_memory=true'},
    'td3': {'main': 'src/td3/main.py',
            'hydra': 0,
            'run_args_str': '--max_timesteps 15500 --eval_freq 1000000 --start_timesteps 5000 --eval_0_step --csv_writing_enabled --track_GPU_memory'},
    'crossq': {'main': 'src/crossq/train.py',
               'hydra': 0,
               'run_args_str': '--total_timesteps 15500 --log_freq 1000000 --eval_0_step --csv_writing_enabled --wandb_mode online --track_GPU_memory'},
    'tdmpc2': {'main': 'src/tdmpc2/tdmpc2/train.py',
               'hydra': 1,
               'task_syntax': 1,
               'run_args_str': 'steps=15000 disable_wandb=false track_GPU_memory=true eval_freq=1000000 is_time_benchmark=true'},
    'bbf_dac': {'main': 'src/bbf_dac/train_parallel.py',
                'hydra': 0,
                'task_syntax': 2,
                'run_args_str': 'track_GPU_memory=true max_steps=15500'},
    'sr_sac': {'main': 'src/sr_sac/train_parallel.py',
               'hydra': 0,
               'task_syntax': 2,
               'run_args_str': 'track_GPU_memory=true max_steps=15500'},
}
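
# Conda environment to activate for each (algorithm, universe) pair.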
ALGO_UNIVERSE_TO_CONDA = {
    ('sac', 'gym'): 'sac_td3_dm_control',
    ('td3', 'gym'): 'sac_td3_dm_control',
    ('sac', 'dm_control'): 'sac_td3_dm_control',
    ('td3', 'dm_control'): 'sac_td3_dm_control',
    ('crossq', 'dm_control'): 'crossq',  # same name but incompatible with gym
    ('tdmpc2', 'dm_control'): 'tdmpc2',
    ('bbf_dac', 'dm_control'): 'bbf_sac',
    ('sr_sac', 'dm_control'): 'bbf_sac',
    ('sac', 'metaworld'): 'sac_td3_metaworld',
    ('td3', 'metaworld'): 'sac_td3_metaworld',
    ('crossq', 'metaworld'): 'crossq_metaworld',
    ('tdmpc2', 'metaworld'): 'tdmpc2_metaworld',
    ('bbf_dac', 'metaworld'): 'bbf_dac_mw',
    ('sr_sac', 'metaworld'): 'bbf_dac_mw',
    ('sac', 'myo'): 'sac_td3_myo',
    ('td3', 'myo'): 'sac_td3_myo',
    ('crossq', 'myo'): 'crossq_myo',
    ('tdmpc2', 'myo'): 'tdmpc2_myo',
    ('bbf_dac', 'myo'): 'bbf_dac_myo',
    ('sr_sac', 'myo'): 'bbf_dac_myo',
    ('crossq', 'shimmy_dm_control'): 'crossq',
}
def parse_arguments():
    parser = argparse.ArgumentParser(description="Parameters for running sbatch scripts.")
    # Add arguments
    parser.add_argument("--universes", nargs='+', type=str, default=['dm_control'],
                        choices=['gym', 'dm_control', 'metaworld', 'myo', 'shimmy_dm_control'],
                        help="Universes the tasks are in.")
    parser.add_argument("--algos", nargs='+', type=str, default=["sac"],
                        choices=['sac', 'td3', 'crossq', 'tdmpc2', 'bbf_dac', 'sr_sac'],
                        help="Algorithms to run.")
    parser.add_argument("--envs", nargs='+', type=str, default=None, help="Tasks to run. None means all.")
    parser.add_argument("--time", type=str, default="2-23:59:59", help="Max job duration.")
    parser.add_argument("--t", type=int, default=720,
                        help="Max job duration in minutes (this is what is passed to sbatch below).")
    parser.add_argument("--gres", type=str, default="gpu:1", help="Gres for the job.")
    parser.add_argument("--mem", type=str, default="8000", help="Max memory for the job in MB.")
    parser.add_argument("--seed", type=int, default=0, help="Start seed for the job.")
    parser.add_argument('--run_args', type=str, default='', help='Additional arguments for the run command.')
    parser.add_argument('--partition', type=str, default='plgrid-gpu-a100', help='Partition for the job.')
    return parser.parse_args()
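
# Helpers translating the canonical env names from experiments_config into the
# naming schemes expected by the individual codebases.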
def env_name_to_tdmpc2_env_name(env_name: str, universe: str) -> str:
    if universe == 'dm_control':
        if env_name == "finger_turn_hard":
            return "finger-turn_hard"
        return env_name.replace('_', '-')
    if universe == 'metaworld':
        # strip the trailing '-v2' suffix and add the 'mw-' prefix
        return 'mw-' + env_name[:-3]
    if universe == 'myo':
        return env_name
    raise ValueError(f"Universe {universe} not supported by tdmpc2 benchmark.")

def env_name_to_bbf_dac_name(env_name: str, universe: str) -> str:
    if universe == 'dm_control':
        if env_name == "finger_turn_hard":
            return "finger-turn_hard"
        return env_name.replace('_', '-')
    if universe == 'metaworld':
        return env_name + '-goal-observable'
    if universe == 'myo':
        return env_name
    raise ValueError(f"Universe {universe} not supported by the bbf_dac benchmark.")

def bbf_dac_main_from_universe(universe: str, algo_name: str) -> str:
    """
    Returns the main file for the bbf_dac algorithm depending on the universe.
    Also works for sr_sac, as it has a very similar structure.
    :param universe: str: Universe the task is in.
    :param algo_name: str: Name of the algorithm.
    :return: str: Path to the main file for the algorithm.
    """
    assert algo_name in ['bbf_dac', 'sr_sac'], f"Function works only for bbf_dac and sr_sac, not {algo_name}."
    if universe == 'dm_control':
        return f'src/{algo_name}/train_parallel.py'
    if universe == 'metaworld':
        return f'src/{algo_name}/train_parallel_mw.py'
    if universe == 'myo':
        return f'src/{algo_name}/train_parallel_myo.py'
    raise ValueError(f"Universe {universe} not supported by the bbf_dac benchmark.")
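
# Entry point: for every requested (universe, algo, env) combination, build a
# training command, group the commands per conda environment, render the sbatch
# template, and submit the resulting script with sbatch.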
if __name__ == '__main__':
    default_params_dict = vars(parse_arguments())
    scripts_dir = "./sbatch_scripts"
    os.makedirs(scripts_dir, exist_ok=True)
    os.makedirs("./time", exist_ok=True)  # created up front so sbatch does not fail on a missing directory
    print(f"Script will run with parameters {default_params_dict}\n")
    for universe in default_params_dict['universes']:
        assert universe in all_experiments.keys(), \
            f"Universe {universe} not found in experiments_config"
    script_name = ('mem_' +
                   '_'.join(default_params_dict['algos'])
                   + '_'
                   + '_'.join(default_params_dict['universes'])
                   + f"_seed_{default_params_dict['seed']}"
                   + default_params_dict['run_args'].replace(' ', '').replace('--', '_')
                   + ".sh")
    output_name = script_name.replace(".sh", ".out")
    csv_name = str(Path('./mem') / f"{default_params_dict['algos'][0]}_{default_params_dict['universes'][0]}_memory.csv")
    partial_commands_to_run = []
    all_commands_to_run = []
    for universe in default_params_dict['universes']:
        # prepare the names of the tasks to run on
        envs = [env['env_name'] for env in all_experiments[universe]]
        if default_params_dict['envs'] is not None:
            envs = default_params_dict['envs']
            assert len(envs) > 0, f"No experiments found for {default_params_dict['envs']}."
            print(f'Running experiments only for envs: {default_params_dict["envs"]}.')
        # iterate over all requested algos
        for algo in default_params_dict['algos']:
            algo_kwargs = ALGOS_TO_KWARGS[algo]
            use_hydra_syntax = (algo_kwargs['hydra'] == 1)
            seed = default_params_dict['seed']
            # some algos use 'task' or 'env_name' instead of 'env'
            task_syntax = algo_kwargs.get('task_syntax', 0)
            if task_syntax == 0:
                env_or_task_str = 'env'
            elif task_syntax == 1:
                env_or_task_str = 'task'
            else:
                env_or_task_str = 'env_name'
            for env in envs:
                # convert the env name to the syntax used by the given codebase
                if algo == 'tdmpc2':
                    env = env_name_to_tdmpc2_env_name(env, universe)
                if algo in ['bbf_dac', 'sr_sac']:
                    env = env_name_to_bbf_dac_name(env, universe)
                    algo_kwargs['main'] = bbf_dac_main_from_universe(universe, algo)
                # prepare the command to run, with different syntax depending on whether Hydra is used or not
                if use_hydra_syntax:
                    command_to_run = f"python {algo_kwargs['main']} {env_or_task_str}={env} universe={universe} seed={seed}"
                else:
                    command_to_run = f"python {algo_kwargs['main']} --{env_or_task_str} {env} --universe {universe} --seed {seed}"
                # join with a space so user-supplied --run_args do not fuse with the defaults
                extra_args = f"{algo_kwargs['run_args_str']} {default_params_dict['run_args']}".strip()
                command_to_run = f"{command_to_run} {extra_args}\n"
                partial_commands_to_run.append(command_to_run)
            # Handle activating and deactivating conda environments
            conda_env = ALGO_UNIVERSE_TO_CONDA[(algo, universe)]
            big_command_to_run = (f'conda activate {conda_env}\n' +
                                  'python -m wandb online\n\n' +
                                  '\n'.join(partial_commands_to_run) + '\nconda deactivate\n')
            all_commands_to_run.append(big_command_to_run)
            partial_commands_to_run = []
print(f"{len(all_commands_to_run)} big commands to run.")
# save script
sbatch_dict = {**default_params_dict,
'commands_to_run': '\n'.join(all_commands_to_run),
'outputs_name': output_name,
'csv_filename': csv_name} # for the csv file
with open(f"{scripts_dir}/{script_name}", "w") as f:
f.write(sbatch_script_template.format(**sbatch_dict))
print(f"Running sbatch with script {script_name}\n")
exit_code = os.system(
f"sbatch -A plgplgplasticityrl-gpu-a100 -c 8 -t {default_params_dict['t']} --mem 10G {scripts_dir}/{script_name}")
if exit_code != 0:
print(f"Error in sbatch, Exit code: {exit_code}")