TD3 with a hybrid action space raises AssertionError #789

Closed
dajianer opened this issue Apr 11, 2024 · 2 comments
Labels
bug Something isn't working

Comments

dajianer commented Apr 11, 2024

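When I train an environment with a hybrid action space using TD3, the run fails with assert isinstance(action, torch.Tensor). Looking at the source, I found that the return value of HybridArgmaxSampleWrapper's forward could indeed trigger this error. How should I resolve it?
The code is as follows: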

    logging.getLogger().setLevel(logging.INFO)
    cfg = compile_config(main_config, create_cfg=create_config, auto=True)
    ding_init(cfg)
    ctx = OnlineRLContext(collect_kwargs={'eps': 0.01})
    with task.start(async_mode=False, ctx=ctx):
        collector_env = BaseEnvManagerV2(
            env_fn=[lambda: DI_UAV_AoI(cfg.env) for _ in range(cfg.env.collector_env_num)], cfg=cfg.env.manager
        )
        evaluator_env = BaseEnvManagerV2(
            env_fn=[lambda: DI_UAV_AoI(cfg.env) for _ in range(cfg.env.evaluator_env_num)], cfg=cfg.env.manager
        )

        set_pkg_seed(cfg.seed, use_cuda=cfg.policy.cuda)

        model = ContinuousQAC(**cfg.policy.model)
        buffer_ = DequeBuffer(size=cfg.policy.other.replay_buffer.replay_buffer_size)
        policy = TD3Policy(cfg.policy, model=model)

        task.use(interaction_evaluator(cfg, policy.eval_mode, evaluator_env))
        task.use(
            StepCollector(cfg, policy.collect_mode, collector_env, random_collect_size=cfg.policy.random_collect_size)
        )
        task.use(data_pusher(cfg, buffer_))
        task.use(OffPolicyLearner(cfg, policy.learn_mode, buffer_))
        task.use(CkptSaver(policy, cfg.exp_name, train_freq=100))
        task.use(termination_checker(max_train_iter=int(1e5)))
        task.use(online_logger())
        task.run(max_step=int(1e5))

PaParaZz1 added the bug label Apr 12, 2024

PaParaZz1 (Member) commented

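When using the hybrid action space, did you specify action_space='hybrid' in TD3Policy?

The relevant implementation is here (link)

For a complete DDPG/TD3-style hybrid action space configuration file, you can refer to this example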
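
As a rough illustration of the setting asked about above, the fragment below sketches what a policy config with action_space='hybrid' might look like. It is not taken from the linked example: the variable name, the observation shape, the action shapes, and the nested key names action_type_shape and action_args_shape are assumptions to be checked against the reference config.

```python
# Hypothetical policy config fragment for a hybrid action space.
# All shapes are placeholders, not values from this issue.
td3_hybrid_policy_cfg = dict(
    # Tells the policy to treat actions as hybrid (discrete type + continuous args).
    action_space='hybrid',
    model=dict(
        obs_shape=10,  # placeholder observation dimension
        action_shape=dict(
            action_type_shape=3,  # assumed key: number of discrete action types
            action_args_shape=2,  # assumed key: dimension of the continuous arguments
        ),
        # The model is assumed to need the same flag as the policy.
        action_space='hybrid',
    ),
)
```

Setting the flag on both the policy and its model is shown here because the model built from cfg.policy.model would otherwise have no way to know the action layout; whether both are strictly required is an assumption.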

MarkHolmstrom (Contributor) commented

The issue appears to be TD3's target policy smoothing: the action noise wrapper does not support the hybrid action space. As a workaround, set noise=False in the policy configuration, as in the reference DDPG config, which disables target policy smoothing.
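
To make the failure mode described above concrete, here is a toy sketch in plain Python. It is not DI-engine's code: smooth_target_action is a hypothetical stand-in for the noise wrapper, a list of floats stands in for a torch.Tensor, and the action_type / action_args keys of the hybrid action are an assumption about its layout.

```python
import random


def smooth_target_action(action, noise=True, sigma=0.2, clip=0.5):
    """Toy stand-in for target policy smoothing on a flat continuous action."""
    if not noise:
        # With smoothing disabled the action passes through untouched,
        # so its type is never checked.
        return action
    # Same spirit as the failing assertion in the issue: only a flat
    # continuous action (here a list, there a torch.Tensor) is accepted.
    assert isinstance(action, list), "noise expects a flat continuous action"
    # Add clipped Gaussian noise to every component.
    return [a + max(-clip, min(clip, random.gauss(0.0, sigma))) for a in action]


# Hypothetical hybrid action: a discrete type plus continuous arguments.
hybrid_action = {"action_type": 1, "action_args": [0.3, -0.7]}

try:
    smooth_target_action(hybrid_action, noise=True)
    failed = False
except AssertionError:
    failed = True  # the dict-shaped action trips the type assertion

# The workaround: with noise disabled, the hybrid action is returned as-is.
passthrough = smooth_target_action(hybrid_action, noise=False)
```

In config terms the workaround would presumably be a noise=False entry in the learn section of the policy config. The exact location of that key is an assumption here and should be checked against the reference DDPG config.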
