Skip to content

Commit

Permalink
chore: improve test speed
Browse files Browse the repository at this point in the history
  • Loading branch information
Gaiejj committed Apr 13, 2023
1 parent 1f4fbb9 commit f51e8fe
Showing 1 changed file with 21 additions and 23 deletions.
44 changes: 21 additions & 23 deletions tests/test_policy.py
Original file line number Diff line number Diff line change
Expand Up @@ -50,12 +50,12 @@ def test_assertion_error():
env_id = 'Simple-v0'
custom_cfgs = {
'train_cfgs': {
'total_steps': 2048,
'total_steps': 200,
'vector_env_nums': 1,
'torch_threads': 4,
},
'algo_cfgs': {
'update_cycle': 1024,
'update_cycle': 100,
'update_iters': 2,
},
'logger_cfgs': {
Expand Down Expand Up @@ -100,12 +100,12 @@ def test_render():
env_id = 'Simple-v0'
custom_cfgs = {
'train_cfgs': {
'total_steps': 2048,
'total_steps': 200,
'vector_env_nums': 1,
'torch_threads': 4,
},
'algo_cfgs': {
'update_cycle': 1024,
'update_cycle': 100,
'update_iters': 2,
},
'logger_cfgs': {
Expand All @@ -125,13 +125,13 @@ def test_off_policy(algo):
env_id = 'Simple-v0'
custom_cfgs = {
'train_cfgs': {
'total_steps': 2048,
'total_steps': 200,
'vector_env_nums': 1,
'torch_threads': 4,
},
'algo_cfgs': {
'update_cycle': 1024,
'steps_per_sample': 1024,
'update_cycle': 100,
'steps_per_sample': 50,
'update_iters': 2,
'start_learning_steps': 0,
'use_critic_norm': True,
Expand All @@ -156,13 +156,13 @@ def test_sac_policy(auto_alpha):
env_id = 'Simple-v0'
custom_cfgs = {
'train_cfgs': {
'total_steps': 2048,
'total_steps': 200,
'vector_env_nums': 1,
'torch_threads': 4,
},
'algo_cfgs': {
'update_cycle': 1024,
'steps_per_sample': 1024,
'update_cycle': 100,
'steps_per_sample': 50,
'update_iters': 2,
'start_learning_steps': 0,
'auto_alpha': auto_alpha,
Expand All @@ -173,7 +173,6 @@ def test_sac_policy(auto_alpha):
'use_wandb': False,
'save_model_freq': 1,
},
'model_cfgs': model_cfgs,
}
agent = omnisafe.Agent('SAC', env_id, custom_cfgs=custom_cfgs)
agent.learn()
Expand All @@ -188,13 +187,13 @@ def test_sac_lag_policy(auto_alpha):
env_id = 'Simple-v0'
custom_cfgs = {
'train_cfgs': {
'total_steps': 2048,
'total_steps': 200,
'vector_env_nums': 1,
'torch_threads': 4,
},
'algo_cfgs': {
'update_cycle': 1024,
'steps_per_sample': 1024,
'update_cycle': 100,
'steps_per_sample': 50,
'update_iters': 2,
'start_learning_steps': 0,
'auto_alpha': auto_alpha,
Expand Down Expand Up @@ -229,19 +228,18 @@ def test_on_policy(algo):
env_id = 'Simple-v0'
custom_cfgs = {
'train_cfgs': {
'total_steps': 2048,
'total_steps': 200,
'vector_env_nums': 1,
'torch_threads': 4,
},
'algo_cfgs': {
'update_cycle': 1024,
'update_cycle': 100,
'update_iters': 2,
},
'logger_cfgs': {
'use_wandb': False,
'save_model_freq': 1,
},
'model_cfgs': model_cfgs,
}
agent = omnisafe.Agent(algo, env_id, custom_cfgs=custom_cfgs)
agent.learn()
Expand All @@ -253,12 +251,12 @@ def test_workflow_for_training(algo):
env_id = 'Simple-v0'
custom_cfgs = {
'train_cfgs': {
'total_steps': 2048,
'total_steps': 200,
'vector_env_nums': 1,
'torch_threads': 4,
},
'algo_cfgs': {
'update_cycle': 1024,
'update_cycle': 100,
'update_iters': 2,
},
'logger_cfgs': {
Expand All @@ -280,11 +278,11 @@ def test_std_anealing():
env_id = 'Simple-v0'
custom_cfgs = {
'train_cfgs': {
'total_steps': 2048,
'total_steps': 200,
'vector_env_nums': 1,
},
'algo_cfgs': {
'update_cycle': 1024,
'update_cycle': 100,
'update_iters': 2,
},
'logger_cfgs': {
Expand All @@ -311,13 +309,13 @@ def test_std_anealing():
# env_id = 'Simple-v0'
# custom_cfgs = {
# 'train_cfgs': {
# 'total_steps': 2048,
# 'total_steps': 200,
# 'vector_env_nums': 1,
# 'torch_threads': 4,
# 'device': 'cuda:0',
# },
# 'algo_cfgs': {
# 'update_cycle': 1024,
# 'update_cycle': 100,
# 'update_iters': 2,
# },
# 'logger_cfgs': {
Expand Down

0 comments on commit f51e8fe

Please sign in to comment.