Skip to content

Commit

Permalink
Remove checking for specific pytorch warnings that don't seem to appe…
Browse files Browse the repository at this point in the history
…ar anymore during DAgger tests.
  • Loading branch information
ernestum committed Dec 18, 2023
1 parent a55ff9e commit b74ad0f
Showing 1 changed file with 8 additions and 16 deletions.
24 changes: 8 additions & 16 deletions tests/scripts/test_scripts.py
Original file line number Diff line number Diff line change
Expand Up @@ -266,22 +266,14 @@ def test_train_preference_comparisons_reward_named_config(tmpdir, named_configs)


def test_train_dagger_main(tmpdir):
with pytest.warns(None) as record:
run = train_imitation.train_imitation_ex.run(
command_name="dagger",
named_configs=["seals_cartpole"] + ALGO_FAST_CONFIGS["imitation"],
config_updates=dict(
logging=dict(log_root=tmpdir),
demonstrations=dict(path=CARTPOLE_TEST_ROLLOUT_PATH),
),
)
for warning in record:
# PyTorch wants writeable arrays.
# See https://github.com/HumanCompatibleAI/imitation/issues/219
assert not (
warning.category == UserWarning
and "NumPy array is not writeable" in warning.message.args[0]
)
run = train_imitation.train_imitation_ex.run(
command_name="dagger",
named_configs=["seals_cartpole"] + ALGO_FAST_CONFIGS["imitation"],
config_updates=dict(
logging=dict(log_root=tmpdir),
demonstrations=dict(path=CARTPOLE_TEST_ROLLOUT_PATH),
),
)
assert run.status == "COMPLETED"
assert isinstance(run.result, dict)

Expand Down

0 comments on commit b74ad0f

Please sign in to comment.