Commit

fix
Signed-off-by: sven1977 <svenmika1977@gmail.com>
sven1977 committed Sep 25, 2024
1 parent 87acd25 commit 50a43e6
Showing 5 changed files with 22 additions and 123 deletions.
13 changes: 6 additions & 7 deletions rllib/BUILD
@@ -1020,25 +1020,24 @@ py_test(
)

# IMPALA
# @OldAPIStack
py_test(
name = "test_impala_old_api_stack",
name = "test_impala",
tags = ["team:rllib", "algorithms_dir"],
size = "large",
srcs = ["algorithms/impala/tests/test_impala.py"]
)
# @OldAPIStack
py_test(
name = "test_vtrace_old_api_stack",
name = "test_vtrace_v2",
tags = ["team:rllib", "algorithms_dir"],
size = "small",
srcs = ["algorithms/impala/tests/test_vtrace_old_api_stack.py"]
srcs = ["algorithms/impala/tests/test_vtrace_v2.py"]
)
# @OldAPIStack
py_test(
name = "test_vtrace_v2",
name = "test_vtrace_old_api_stack",
tags = ["team:rllib", "algorithms_dir"],
size = "small",
srcs = ["algorithms/impala/tests/test_vtrace_v2.py"]
srcs = ["algorithms/impala/tests/test_vtrace_old_api_stack.py"]
)

# MARWIL
64 changes: 12 additions & 52 deletions rllib/algorithms/impala/tests/test_impala.py
@@ -3,12 +3,8 @@
import ray
import ray.rllib.algorithms.impala as impala
from ray.rllib.policy.sample_batch import DEFAULT_POLICY_ID
from ray.rllib.utils.metrics.learner_info import LEARNER_INFO, LEARNER_STATS_KEY
from ray.rllib.utils.test_utils import (
check,
check_compute_single_action,
check_train_results,
)
from ray.rllib.utils.metrics import LEARNER_RESULTS
from ray.rllib.utils.test_utils import check


class TestIMPALA(unittest.TestCase):
@@ -20,54 +20,18 @@ def setUpClass(cls) -> None:
def tearDownClass(cls) -> None:
ray.shutdown()

def test_impala_compilation(self):
"""Test whether IMPALA can be built with both frameworks."""
config = (
impala.IMPALAConfig()
.environment("CartPole-v1")
.resources(num_gpus=0)
.env_runners(num_env_runners=2)
.training(
model={
"lstm_use_prev_action": True,
"lstm_use_prev_reward": True,
},
)
)
num_iterations = 2

for lstm in [False, True]:
config.num_aggregation_workers = 0 if not lstm else 1
config.model["use_lstm"] = lstm
print(
"lstm={} aggregation-workers={}".format(
lstm, config.num_aggregation_workers
)
)
# Test with and w/o aggregation workers (this has nothing
# to do with LSTMs, though).
algo = config.build()
for i in range(num_iterations):
results = algo.train()
print(results)
check_train_results(results)

check_compute_single_action(
algo,
include_state=lstm,
include_prev_action_reward=lstm,
)
algo.stop()

def test_impala_lr_schedule(self):
# Test whether we correctly ignore the "lr" setting.
# The first lr should be 0.05.
config = (
impala.IMPALAConfig()
.resources(num_gpus=0)
.api_stack(
enable_rl_module_and_learner=True,
enable_env_runner_and_connector_v2=True,
)
.learners(num_learners=0)
.training(
lr=0.1,
lr_schedule=[
lr=[
[0, 0.05],
[100000, 0.000001],
],
@@ -78,15 +78,15 @@ def test_impala_lr_schedule(self):
)

def get_lr(result):
return result["info"][LEARNER_INFO][DEFAULT_POLICY_ID][LEARNER_STATS_KEY][
"cur_lr"
return result[LEARNER_RESULTS][DEFAULT_POLICY_ID][
"default_optimizer_learning_rate"
]

algo = config.build()
policy = algo.get_policy()
optim = algo.learner_group._learner.get_optimizer()

try:
check(policy.cur_lr, 0.05)
check(optim.param_groups[0]["lr"], 0.05)
for _ in range(1):
r1 = algo.train()
for _ in range(2):
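Note: the updated test_impala_lr_schedule passes the learning-rate schedule directly through the lr argument as a list of [timestep, value] pairs and reads the current rate back from the learner results. Below is a minimal, standalone sketch of that pattern on the new API stack; the CartPole-v1 environment and the single training iteration are illustrative assumptions, not taken from this commit.

import ray
import ray.rllib.algorithms.impala as impala
from ray.rllib.policy.sample_batch import DEFAULT_POLICY_ID
from ray.rllib.utils.metrics import LEARNER_RESULTS

ray.init()

config = (
    impala.IMPALAConfig()
    # The environment choice is an assumption for this sketch.
    .environment("CartPole-v1")
    .api_stack(
        enable_rl_module_and_learner=True,
        enable_env_runner_and_connector_v2=True,
    )
    .learners(num_learners=0)
    .training(
        # On the new API stack, lr itself may be a piecewise schedule:
        # a list of [timestep, learning_rate] pairs.
        lr=[
            [0, 0.05],
            [100000, 0.000001],
        ],
    )
)

algo = config.build()
result = algo.train()
# The current optimizer learning rate is reported per module under LEARNER_RESULTS.
print(result[LEARNER_RESULTS][DEFAULT_POLICY_ID]["default_optimizer_learning_rate"])
algo.stop()
ray.shutdown()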
2 changes: 1 addition & 1 deletion rllib/algorithms/impala/tests/test_vtrace_v2.py
@@ -8,7 +8,7 @@
make_time_major,
)
from ray.rllib.algorithms.impala.torch.vtrace_torch_v2 import vtrace_torch
from ray.rllib.algorithms.impala.tests.test_vtrace import (
from ray.rllib.algorithms.impala.tests.test_vtrace_old_api_stack import (
_ground_truth_vtrace_calculation,
)
from ray.rllib.utils.torch_utils import convert_to_torch_tensor
3 changes: 2 additions & 1 deletion rllib/algorithms/tests/test_algorithm.py
@@ -425,7 +425,7 @@ def test_evaluation_option(self):
)

algo = config.build()
# Given evaluation_interval=2, r0, r2, r4 should not contain
# Given evaluation_interval=2, r0, r2 should not contain
# evaluation metrics, while r1, r3 should.
r0 = algo.train()
print(r0)
@@ -437,6 +437,7 @@ def test_evaluation_option(self):
print(r3)
algo.stop()

# No eval results yet in first iteration (eval has not run yet).
self.assertFalse(EVALUATION_RESULTS in r0)
self.assertTrue(EVALUATION_RESULTS in r1)
self.assertFalse(EVALUATION_RESULTS in r2)
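Note: the adjusted comment and the new assertion pin down how evaluation_interval=2 interacts with iteration numbering: evaluation first runs in the second iteration, so the first result (r0) carries no evaluation metrics, r1 and r3 do, and r2 does not. A rough sketch of that check follows; the algorithm, environment, evaluation_duration, and the EVALUATION_RESULTS import path are assumptions for illustration, not taken from this diff.

import ray
import ray.rllib.algorithms.ppo as ppo
# Import path assumed for this sketch.
from ray.rllib.utils.metrics import EVALUATION_RESULTS

ray.init()

config = (
    ppo.PPOConfig()
    .environment("CartPole-v1")
    # Evaluate every 2nd training iteration.
    .evaluation(evaluation_interval=2, evaluation_duration=2)
)

algo = config.build()
results = [algo.train() for _ in range(4)]
algo.stop()

# Evaluation runs on the 2nd and 4th iterations, so r0 and r2 lack
# evaluation metrics while r1 and r3 contain them.
assert EVALUATION_RESULTS not in results[0]
assert EVALUATION_RESULTS in results[1]
assert EVALUATION_RESULTS not in results[2]
assert EVALUATION_RESULTS in results[3]

ray.shutdown()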
63 changes: 1 addition & 62 deletions rllib/policy/tests/test_compute_log_likelihoods.py
@@ -4,10 +4,8 @@
from scipy.stats import norm

import ray
import ray.rllib.algorithms.dqn as dqn
import ray.rllib.algorithms.ppo as ppo
import ray.rllib.algorithms.sac as sac
from ray.rllib.utils.numpy import MAX_LOG_NN_OUTPUT, MIN_LOG_NN_OUTPUT, fc, one_hot
from ray.rllib.utils.numpy import fc, one_hot
from ray.rllib.utils.test_utils import check


@@ -135,14 +133,6 @@ def setUpClass(cls) -> None:
def tearDownClass(cls) -> None:
ray.shutdown()

def test_dqn(self):
"""Tests, whether DQN correctly computes logp in soft-q mode."""
config = dqn.DQNConfig()
# Soft-Q for DQN.
config.env_runners(exploration_config={"type": "SoftQ", "temperature": 0.5})
config.debugging(seed=42)
do_test_log_likelihood(dqn.DQN, config)

def test_ppo_cont(self):
"""Tests PPO's (cont. actions) compute_log_likelihoods method."""
config = ppo.PPOConfig()
@@ -163,57 +153,6 @@ def test_ppo_discr(self):
prev_a = np.array(0)
do_test_log_likelihood(ppo.PPO, config, prev_a)

def test_sac_cont(self):
"""Tests SAC's (cont. actions) compute_log_likelihoods method."""
config = sac.SACConfig()
config.training(
policy_model_config={
"fcnet_hiddens": [10],
"fcnet_activation": "linear",
}
)
config.debugging(seed=42)
prev_a = np.array([0.0])

# SAC cont uses a squashed normal distribution. Implement its logp
# logic here in numpy for comparing results.
def logp_func(means, log_stds, values, low=-1.0, high=1.0):
stds = np.exp(np.clip(log_stds, MIN_LOG_NN_OUTPUT, MAX_LOG_NN_OUTPUT))
unsquashed_values = np.arctanh((values - low) / (high - low) * 2.0 - 1.0)
log_prob_unsquashed = np.sum(
np.log(norm.pdf(unsquashed_values, means, stds)), -1
)
return log_prob_unsquashed - np.sum(
np.log(1 - np.tanh(unsquashed_values) ** 2), axis=-1
)

do_test_log_likelihood(
sac.SAC,
config,
prev_a,
continuous=True,
layer_key=(
"fc",
(0, 2),
("action_model._hidden_layers.0.", "action_model._logits."),
),
logp_func=logp_func,
)

def test_sac_discr(self):
"""Tests SAC's (discrete actions) compute_log_likelihoods method."""
config = sac.SACConfig()
config.training(
policy_model_config={
"fcnet_hiddens": [10],
"fcnet_activation": "linear",
}
)
config.debugging(seed=42)
prev_a = np.array(0)

do_test_log_likelihood(sac.SAC, config, prev_a)


if __name__ == "__main__":
import sys
