
Commit 8651cdf
Added skip connection
saidinesh_pola committed Dec 10, 2023
1 parent 66c6b34 commit 8651cdf
Showing 4 changed files with 17 additions and 7 deletions.
2 changes: 1 addition & 1 deletion reinforcement_learning/config.py
@@ -23,7 +23,7 @@ class Config:
     eval_num_steps = 1_000_000  # 1_000_000 # Number of steps to evaluate
     checkpoint_interval = 5_000_000  # Interval to save models
     # f"nmmo_{time.strftime('%Y%m%d_%H%M%S')}_{seed}" # Run name
-    run_name = f"nmmo_rp_cr_attn_skill_bonus_seed{seed}_exp17"
+    run_name = f"nmmo_rp_cr_attn_lstm_seed{seed}_exp17"
     runs_dir = "./runs"  # Directory for runs
     policy_store_dir = None  # Policy store directory
     use_serial_vecenv = False  # Use serial vecenv implementation
22 changes: 16 additions & 6 deletions reinforcement_learning/rl_policy.py
@@ -132,10 +132,11 @@ def __init__(self, input_size):
         self.multihead_attn = MultiheadAttention(
             16, 4)  # Added MultiheadAttention layer

-        self.tile_conv_1 = torch.nn.Conv2d(96, 64, 3)
-        self.tile_conv_2 = torch.nn.Conv2d(64, 32, 3)
-        self.tile_conv_3 = torch.nn.Conv2d(32, 16, 3)
-        self.tile_fc = torch.nn.Linear(16 * 9 * 9, input_size)
+        self.tile_conv_1 = torch.nn.Conv2d(96, 64, 3, padding=1)
+        self.tile_conv_2 = torch.nn.Conv2d(64, 64, 3, padding=1)
+        self.tile_conv_3 = torch.nn.Conv2d(64, 32, 3, padding=1)
+        self.tile_conv_4 = torch.nn.Conv2d(32, 16, 3, padding=1)
+        self.tile_fc = torch.nn.Linear(16 * 15 * 15, input_size)
         self.activation = torch.nn.ReLU()

     def forward(self, tile):
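
Note on the new shapes: with 3x3 kernels and padding=1, every conv now preserves spatial size, so tile_fc sees 16 * 15 * 15 features instead of 16 * 9 * 9 (the old unpadded stack shrank the map 15 -> 13 -> 11 -> 9). A quick shape check, assuming the input to the stack is (N, 96, 15, 15), as the old and new fc sizes imply:

    import torch

    # Conv2d output size: out = (in + 2*pad - kernel) // stride + 1.
    # With kernel=3, pad=1, stride=1: out == in, so 15x15 stays 15x15.
    # ReLU does not change shapes, so it is omitted from this check.
    x = torch.randn(2, 96, 15, 15)  # dummy tile features for 2 agents
    convs = torch.nn.Sequential(
        torch.nn.Conv2d(96, 64, 3, padding=1),
        torch.nn.Conv2d(64, 64, 3, padding=1),
        torch.nn.Conv2d(64, 32, 3, padding=1),
        torch.nn.Conv2d(32, 16, 3, padding=1),
    )
    print(convs(x).shape)  # torch.Size([2, 16, 15, 15]) -> 16 * 15 * 15 = 3600 flat
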
@@ -153,13 +153,22 @@ def forward(self, tile):
         )

         tile = self.activation(self.tile_conv_1(tile))
+        tile_skip_1 = tile.clone()  # Save for skip connection
+
         tile = self.activation(self.tile_conv_2(tile))
-        tile = self.activation(self.tile_conv_3(tile))  # Additional layer
+        tile += tile_skip_1  # Add skip connection
+
+        tile = self.activation(self.tile_conv_3(tile))
+        tile_skip_2 = tile.clone()  # Save for skip connection
+
+        tile = self.activation(self.tile_conv_4(tile))
+        tile += tile_skip_2  # Add skip connection
+
         # Reshape for MultiheadAttention
         tile = tile.view(tile.size(0), tile.size(1), -1).permute(2, 0, 1)
         tile, _ = self.multihead_attn(
             tile, tile, tile)  # Apply MultiheadAttention
-        tile = tile.permute(1, 2, 0).view(agents, 16, 9, 9)  # Reshape back
+        tile = tile.permute(1, 2, 0).view(agents, 16, 15, 15)  # Reshape back
         tile = tile.contiguous().view(agents, -1)
         tile = self.activation(self.tile_fc(tile))

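Note on the skips: the first is shape-consistent (tile_conv_2 maps 64 -> 64 channels, so tile += tile_skip_1 adds like to like), but the second adds the 32-channel tile_skip_2 to the 16-channel output of tile_conv_4, which PyTorch rejects as non-broadcastable. Below is a minimal standalone sketch of the patched encoder, with a hypothetical 1x1 projection (skip_proj, not in this commit) so the second skip lines up; input_size=256 is also an assumption:

    import torch
    from torch.nn import Conv2d, Linear, MultiheadAttention, ReLU

    class TileEncoderSketch(torch.nn.Module):
        """Sketch of the patched tile encoder; skip_proj and input_size are assumptions."""

        def __init__(self, input_size=256):
            super().__init__()
            self.multihead_attn = MultiheadAttention(16, 4)
            self.tile_conv_1 = Conv2d(96, 64, 3, padding=1)
            self.tile_conv_2 = Conv2d(64, 64, 3, padding=1)
            self.tile_conv_3 = Conv2d(64, 32, 3, padding=1)
            self.tile_conv_4 = Conv2d(32, 16, 3, padding=1)
            self.skip_proj = Conv2d(32, 16, 1)  # hypothetical: match channels for skip 2
            self.tile_fc = Linear(16 * 15 * 15, input_size)
            self.activation = ReLU()

        def forward(self, tile):
            agents = tile.size(0)
            tile = self.activation(self.tile_conv_1(tile))
            skip_1 = tile  # 64 channels
            tile = self.activation(self.tile_conv_2(tile))
            tile = tile + skip_1  # 64 + 64 channels: valid as committed
            tile = self.activation(self.tile_conv_3(tile))
            skip_2 = tile  # 32 channels
            tile = self.activation(self.tile_conv_4(tile))
            tile = tile + self.skip_proj(skip_2)  # 16 + 16 after the 1x1 projection
            # Attend over the 225 spatial positions: (seq, batch, embed) layout
            tile = tile.view(agents, 16, -1).permute(2, 0, 1)
            tile, _ = self.multihead_attn(tile, tile, tile)
            tile = tile.permute(1, 2, 0).reshape(agents, -1)
            return self.activation(self.tile_fc(tile))

    enc = TileEncoderSketch()
    print(enc(torch.randn(2, 96, 15, 15)).shape)  # torch.Size([2, 256])

An alternative fix would keep tile_conv_4 at 32 output channels and size tile_fc as 32 * 15 * 15; the sketch above only demonstrates shape consistency, not which variant trains better.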
