
Commit 69e6cc6
[RLlib] New API stack: (Multi)RLModule overhaul vol 02 (VPG RLModule, Algo, and Learner example classes). (ray-project#47885)

Signed-off-by: ujjawal-khare <ujjawal.khare@dream11.com>
sven1977 authored and ujjawal-khare committed Oct 15, 2024
1 parent 390e591 commit 69e6cc6
Showing 2 changed files with 46 additions and 19 deletions.
(The second changed file in this commit is empty.)

rllib/examples/rl_modules/classes/vpg_using_shared_encoder_rlm.py (65 changes: 46 additions & 19 deletions)
@@ -19,20 +19,33 @@ def setup(self):
         super().setup()
 
         # Incoming feature dim from the shared encoder.
-        embedding_dim = self.model_config["embedding_dim"]
+        feature_dim = self.model_config["feature_dim"]
         hidden_dim = self.model_config["hidden_dim"]
 
         self._pi_head = nn.Sequential(
-            nn.Linear(embedding_dim, hidden_dim),
+            nn.Linear(feature_dim, hidden_dim),
             nn.ReLU(),
             nn.Linear(hidden_dim, self.action_space.n),
         )
 
     @override(RLModule)
-    def _forward(self, batch, **kwargs):
+    def _forward_inference(self, batch):
+        with torch.no_grad():
+            return self._common_forward(batch)
+
+    @override(RLModule)
+    def _forward_exploration(self, batch):
+        with torch.no_grad():
+            return self._common_forward(batch)
+
+    @override(RLModule)
+    def _forward_train(self, batch):
+        return self._common_forward(batch)
+
+    def _common_forward(self, batch):
         # Features can be found in the batch under the "encoder_features" key.
-        embeddings = batch["encoder_embeddings"]
-        logits = self._pi_head(embeddings)
+        features = batch["encoder_features"]
+        logits = self._pi_head(features)
         return {Columns.ACTION_DIST_INPUTS: logits}
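
For orientation, the three methods added above are the private hooks behind RLModule's public forward API. Below is a minimal sketch of how they are typically reached; it is not part of the commit, and it assumes `module` is an already-built VPGTorchRLModuleUsingSharedEncoder (FEATURE_DIM=64, two actions) whose batch was pre-filled with "encoder_features" by the enclosing MultiRLModule:

# Sketch only, not part of the diff. Assumes `module` is a built
# VPGTorchRLModuleUsingSharedEncoder with feature_dim=64 and 2 actions.
import torch
from ray.rllib.core.columns import Columns

batch = {"encoder_features": torch.randn(32, 64)}  # [batch_size, FEATURE_DIM]

out = module.forward_inference(batch)    # dispatches to _forward_inference (no grad)
out = module.forward_exploration(batch)  # dispatches to _forward_exploration (no grad)
out = module.forward_train(batch)        # dispatches to _forward_train (grads flow)
logits = out[Columns.ACTION_DIST_INPUTS]  # [batch_size, action_space.n]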


@@ -47,7 +60,7 @@ class VPGTorchMultiRLModuleWithSharedEncoder(MultiRLModule):
     from ray.rllib.core.rl_module.multi_rl_module import MultiRLModuleSpec
     from ray.rllib.core.rl_module.rl_module import RLModuleSpec
 
-    EMBEDDING_DIM = 64  # encoder output (feature) dim
+    FEATURE_DIM = 64  # encoder output (feature) dim
     HIDDEN_DIM = 64  # hidden dim for the policy nets
 
     config.rl_module(
@@ -56,20 +69,20 @@ class VPGTorchMultiRLModuleWithSharedEncoder(MultiRLModule):
             # Central/shared encoder net.
             SHARED_ENCODER_ID: RLModuleSpec(
                 module_class=SharedTorchEncoder,
-                model_config={"embedding_dim": EMBEDDING_DIM},
+                model_config_dict={"feature_dim": FEATURE_DIM},
             ),
             # Arbitrary number of policy nets (w/o encoder sub-net).
             "p0": RLModuleSpec(
                 module_class=VPGTorchRLModuleUsingSharedEncoder,
-                model_config={
-                    "embedding_dim": EMBEDDING_DIM,
+                model_config_dict={
+                    "feature_dim": FEATURE_DIM,
                     "hidden_dim": HIDDEN_DIM,
                 },
             ),
             "p1": RLModuleSpec(
                 module_class=VPGTorchRLModuleUsingSharedEncoder,
-                model_config={
-                    "embedding_dim": EMBEDDING_DIM,
+                model_config_dict={
+                    "feature_dim": FEATURE_DIM,
                     "hidden_dim": HIDDEN_DIM,
                 },
             ),
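
The docstring example is truncated by the diff view. Below is a hedged sketch, not part of the commit, of how such a spec might be assembled and built; the keyword names (`multi_rl_module_class`, `module_specs`) and the `spec.build()` call are assumptions that have varied across Ray versions:

# Hypothetical completion, for orientation only (keyword names are assumptions):
spec = MultiRLModuleSpec(
    multi_rl_module_class=VPGTorchMultiRLModuleWithSharedEncoder,
    module_specs={
        SHARED_ENCODER_ID: RLModuleSpec(
            module_class=SharedTorchEncoder,
            model_config_dict={"feature_dim": FEATURE_DIM},
        ),
        # ... "p0"/"p1" policy specs as in the snippet above ...
    },
)
multi_rl_module = spec.build()  # instantiate the module without starting training
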
Expand All @@ -96,7 +109,7 @@ def setup(self):
)

@override(MultiRLModule)
def _forward(self, forward_fn_name, batch, **kwargs):
def _run_forward_pass(self, forward_fn_name, batch, **kwargs):
outputs = {}
encoder_forward_fn = getattr(
self._rl_modules[SHARED_ENCODER_ID], forward_fn_name
@@ -109,9 +122,9 @@ def _forward(self, forward_fn_name, batch, **kwargs):
 
             # Pass policy's observations through shared encoder to get the features for
             # this policy.
-            embeddings = encoder_forward_fn(batch[policy_id])
+            features = encoder_forward_fn(batch[policy_id])
             # Pass the policy's features through the policy net.
-            batch[policy_id]["encoder_embeddings"] = embeddings
+            batch[policy_id]["encoder_features"] = features
             outputs[policy_id] = forward_fn(batch[policy_id], **kwargs)
 
         return outputs
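
Stripped of the RLlib machinery, the routing that `_run_forward_pass` implements is one shared encoder feeding several per-policy heads. Here is a self-contained PyTorch sketch, not part of the commit, with illustrative shapes and names:

# Framework-free sketch of the shared-encoder data flow (all names illustrative).
import torch
import torch.nn as nn

obs_dim, feature_dim, hidden_dim, num_actions = 4, 64, 64, 2
encoder = nn.Linear(obs_dim, feature_dim)  # plays the role of SHARED_ENCODER_ID
pi_heads = {
    "p0": nn.Sequential(nn.Linear(feature_dim, hidden_dim), nn.ReLU(), nn.Linear(hidden_dim, num_actions)),
    "p1": nn.Sequential(nn.Linear(feature_dim, hidden_dim), nn.ReLU(), nn.Linear(hidden_dim, num_actions)),
}

batch = {"p0": torch.randn(8, obs_dim), "p1": torch.randn(8, obs_dim)}
outputs = {}
for policy_id, obs in batch.items():
    features = encoder(obs)                             # shared encoder, called once per policy batch
    outputs[policy_id] = pi_heads[policy_id](features)  # per-policy head on shared features
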
@@ -125,13 +138,27 @@ def setup(self):
         super().setup()
 
         input_dim = self.observation_space.shape[0]
-        embedding_dim = self.model_config["embedding_dim"]
+        feature_dim = self.model_config["feature_dim"]
 
         self._encoder = nn.Sequential(
-            nn.Linear(input_dim, embedding_dim),
+            nn.Linear(input_dim, feature_dim),
         )
 
-    def _forward(self, batch, **kwargs):
+    @override(RLModule)
+    def _forward_inference(self, batch):
+        with torch.no_grad():
+            return self._common_forward(batch)
+
+    @override(RLModule)
+    def _forward_exploration(self, batch):
+        with torch.no_grad():
+            return self._common_forward(batch)
+
+    @override(RLModule)
+    def _forward_train(self, batch):
+        return self._common_forward(batch)
+
+    def _common_forward(self, batch):
         # Pass observations through the encoder and return outputs.
-        embeddings = self._encoder(batch[Columns.OBS])
-        return {"encoder_embeddings": embeddings}
+        features = self._encoder(batch[Columns.OBS])
+        return {"encoder_features": features}
