Skip to content

Commit

Permalink
Minor bugfix in rl_component_bundle
Browse files Browse the repository at this point in the history
  • Loading branch information
lihuoran committed Jun 14, 2022
1 parent b53547a commit d32d8e0
Showing 1 changed file with 9 additions and 4 deletions.
13 changes: 9 additions & 4 deletions maro/rl/rl_component/rl_component_bundle.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
from dataclasses import dataclass
from typing import Any, Dict, List

from maro.rl.policy import AbsPolicy
from maro.rl.policy import AbsPolicy, RLPolicy
from maro.rl.rollout import AbsEnvSampler
from maro.rl.training import AbsTrainer

Expand Down Expand Up @@ -45,7 +45,7 @@ def __post_init__(self) -> None:
kept_policies = []
for policy in self.policies:
if policy.name not in self.agent2policy.values():
raise Warning(f"Policy {policy.name} if removed since it is not used by any agent.")
raise Warning(f"Policy {policy.name} is removed since it is not used by any agent.")
else:
kept_policies.append(policy)
self.policies = kept_policies
Expand Down Expand Up @@ -91,5 +91,10 @@ def trainable_agent2policy(self) -> Dict[Any, str]:
}

@property
def trainable_policies(self) -> List[AbsPolicy]: # TODO: Abs or RL?
return [policy for policy in self.policies if policy.name in self.policy_trainer_mapping]
def trainable_policies(self) -> List[RLPolicy]:
policies = []
for policy in self.policies:
if policy.name in self.policy_trainer_mapping:
assert isinstance(policy, RLPolicy)
policies.append(policy)
return policies

0 comments on commit d32d8e0

Please sign in to comment.