diff --git a/rllib/algorithms/bc/torch/bc_torch_rl_module.py b/rllib/algorithms/bc/torch/bc_torch_rl_module.py index a547047d7f417..d06c323b124ef 100644 --- a/rllib/algorithms/bc/torch/bc_torch_rl_module.py +++ b/rllib/algorithms/bc/torch/bc_torch_rl_module.py @@ -11,7 +11,7 @@ class BCTorchRLModule(TorchRLModule): @override(RLModule) def setup(self): # __sphinx_doc_begin__ - # Build models from catalog + # Build models from catalog. self.encoder = self.catalog.build_encoder(framework=self.framework) self.pi = self.catalog.build_pi_head(framework=self.framework) diff --git a/rllib/core/rl_module/multi_rl_module.py b/rllib/core/rl_module/multi_rl_module.py index fb3e34f4339dd..43eddb909dea0 100644 --- a/rllib/core/rl_module/multi_rl_module.py +++ b/rllib/core/rl_module/multi_rl_module.py @@ -1,5 +1,5 @@ import copy -from dataclasses import dataclass, field +import dataclasses import logging import pprint from typing import ( @@ -553,7 +553,7 @@ def _check_module_exists(self, module_id: ModuleID) -> None: @PublicAPI(stability="alpha") -@dataclass +@dataclasses.dataclass class MultiRLModuleSpec: """A utility spec class to make it constructing MultiRLModules easier. @@ -666,7 +666,11 @@ def build(self, module_id: Optional[ModuleID] = None) -> RLModule: observation_space=self.observation_space, action_space=self.action_space, inference_only=self.inference_only, - model_config=self.model_config, + model_config=( + dataclasses.asdict(self.model_config) + if dataclasses.is_dataclass(self.model_config) + else self.model_config + ), rl_module_specs=self.rl_module_specs, ) # Older custom model might still require the old `MultiRLModuleConfig` under @@ -859,10 +863,10 @@ def get_rl_module_config(self): "module2: [RLModuleSpec], ..}, inference_only=..)", error=False, ) -@dataclass +@dataclasses.dataclass class MultiRLModuleConfig: inference_only: bool = False - modules: Dict[ModuleID, RLModuleSpec] = field(default_factory=dict) + modules: Dict[ModuleID, RLModuleSpec] = dataclasses.field(default_factory=dict) def to_dict(self): return { diff --git a/rllib/core/rl_module/rl_module.py b/rllib/core/rl_module/rl_module.py index f1fb5b337cc54..42aa0a780ed45 100644 --- a/rllib/core/rl_module/rl_module.py +++ b/rllib/core/rl_module/rl_module.py @@ -98,7 +98,7 @@ def build(self) -> "RLModule": observation_space=self.observation_space, action_space=self.action_space, inference_only=self.inference_only, - model_config=self.model_config, + model_config=self._get_model_config(), catalog_class=self.catalog_class, ) # Older custom model might still require the old `RLModuleConfig` under