Commit

Add curriculum changes
saidinesh_pola committed Dec 10, 2023
1 parent 8651cdf commit 8a1cf2f
Showing 4 changed files with 202 additions and 143 deletions.
1 change: 1 addition & 0 deletions manually-create-your-curriculum.ipynb

Large diffs are not rendered by default.

59 changes: 41 additions & 18 deletions reinforcement_learning/rl_policy.py
@@ -132,11 +132,10 @@ def __init__(self, input_size):
         self.multihead_attn = MultiheadAttention(
             16, 4)  # Added MultiheadAttention layer

-        self.tile_conv_1 = torch.nn.Conv2d(96, 64, 3, padding=1)
-        self.tile_conv_2 = torch.nn.Conv2d(64, 64, 3, padding=1)
-        self.tile_conv_3 = torch.nn.Conv2d(64, 32, 3, padding=1)
-        self.tile_conv_4 = torch.nn.Conv2d(32, 16, 3, padding=1)
-        self.tile_fc = torch.nn.Linear(16 * 15 * 15, input_size)
+        self.tile_conv_1 = torch.nn.Conv2d(96, 64, 3)
+        self.tile_conv_2 = torch.nn.Conv2d(64, 32, 3)
+        self.tile_conv_3 = torch.nn.Conv2d(32, 16, 3)
+        self.tile_fc = torch.nn.Linear(16 * 9 * 9, input_size)
         self.activation = torch.nn.ReLU()

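Editor's note on the layer change: dropping padding=1 from the 3x3 convolutions (and removing tile_conv_4) means each remaining conv shrinks every spatial dimension by 2, so the 15x15 tile crop implied by the old 16 * 15 * 15 flatten size ends up 9x9, matching the new 16 * 9 * 9 Linear. A quick shape check, not part of the commit; the 15x15 input size is inferred from the old code:

    import torch

    x = torch.zeros(8, 96, 15, 15)      # (agents, channels, H, W)
    conv1 = torch.nn.Conv2d(96, 64, 3)  # 3x3, no padding: 15 -> 13
    conv2 = torch.nn.Conv2d(64, 32, 3)  # 13 -> 11
    conv3 = torch.nn.Conv2d(32, 16, 3)  # 11 -> 9
    out = conv3(conv2(conv1(x)))
    print(out.shape)                    # torch.Size([8, 16, 9, 9]) -> flattens to 16 * 9 * 9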
@@ -154,22 +153,13 @@ def forward(self, tile):
         )

         tile = self.activation(self.tile_conv_1(tile))
-        tile_skip_1 = tile.clone()  # Save for skip connection
-
         tile = self.activation(self.tile_conv_2(tile))
-        tile += tile_skip_1  # Add skip connection
-
         tile = self.activation(self.tile_conv_3(tile))
-        tile_skip_2 = tile.clone()  # Save for skip connection
-
-        tile = self.activation(self.tile_conv_4(tile))
-        tile += tile_skip_2  # Add skip connection
-
-        tile = self.activation(self.tile_conv_3(tile))  # Additional layer
         # Reshape for MultiheadAttention
         tile = tile.view(tile.size(0), tile.size(1), -1).permute(2, 0, 1)
         tile, _ = self.multihead_attn(
             tile, tile, tile)  # Apply MultiheadAttention
-        tile = tile.permute(1, 2, 0).view(agents, 16, 15, 15)  # Reshape back
+        tile = tile.permute(1, 2, 0).view(agents, 16, 9, 9)  # Reshape back
         tile = tile.contiguous().view(agents, -1)
         tile = self.activation(self.tile_fc(tile))

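The reshape around the attention call follows torch.nn.MultiheadAttention's default sequence-first layout (batch_first is not set in the constructor above): the 9x9 feature map is flattened into 81 spatial tokens of width 16. A minimal round-trip sketch, not part of the commit:

    import torch

    agents = 8
    attn = torch.nn.MultiheadAttention(16, 4)  # embed_dim=16, num_heads=4, seq-first by default
    tile = torch.zeros(agents, 16, 9, 9)
    seq = tile.view(agents, 16, -1).permute(2, 0, 1)    # (81, agents, 16): 81 spatial tokens
    seq, _ = attn(seq, seq, seq)                        # self-attention over the 9x9 grid
    tile = seq.permute(1, 2, 0).view(agents, 16, 9, 9)  # back to (agents, 16, 9, 9)
    print(tile.shape)                                   # torch.Size([8, 16, 9, 9])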
@@ -312,6 +302,15 @@ def __init__(self, input_size, hidden_size):
         )
         self.attn = SelfAttention(hidden_size, 4, hidden_size//4)
         self.fc = torch.nn.Linear(hidden_size * 2, hidden_size)
+        self.rnn = torch.nn.LSTM(
+            input_size=hidden_size,
+            hidden_size=hidden_size,
+            num_layers=5,
+            batch_first=True,
+        )
+        self.prev_player_states = None
+        self.prev_inventory_states = None
+        self.prev_hidden_states = None

     def apply_layer(self, layer, embeddings, mask, hidden):
         hidden = layer(hidden)
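The LSTM added above is shared across the player, inventory, and hidden streams, and its (h, c) states persist across forward calls. Below is a condensed sketch of the lazy-initialization pattern the forward pass uses; StatefulLSTM is a hypothetical name, and the reset rule is the one inferred from the commit (zero the states whenever the batch size changes, so stale states from a different number of agents are never reused):

    import torch

    class StatefulLSTM(torch.nn.Module):
        def __init__(self, hidden_size=64, num_layers=5):
            super().__init__()
            self.rnn = torch.nn.LSTM(hidden_size, hidden_size,
                                     num_layers=num_layers, batch_first=True)
            self.prev_states = None  # carried (h, c) across calls

        def forward(self, x):  # x: (batch, seq, hidden_size)
            expected = (self.rnn.num_layers, x.shape[0], self.rnn.hidden_size)
            if self.prev_states is None or self.prev_states[0].shape != expected:
                # First call, or the batch size changed: start from zero states
                h0 = torch.zeros(*expected, device=x.device)
                c0 = torch.zeros(*expected, device=x.device)
                self.prev_states = (h0, c0)
            out, self.prev_states = self.rnn(x, self.prev_states)
            return out

One caveat: states returned by the LSTM stay attached to the autograd graph; training loops usually detach them between updates, which the diff below does not appear to do.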
@@ -338,11 +337,35 @@ def forward(self, hidden, lookup):
         player_embeddings_before = player_embeddings.clone()
         inventory_embeddings_before = inventory_embeddings.clone()
         hidden_before = hidden.clone()
-        hidden = hidden.unsqueeze(1)  # make it 3d shape for self attn
+        hidden = hidden.unsqueeze(1)
+        # Reinitialize the previous LSTM states if their shape no longer matches the batch
+        if self.prev_player_states is None or self.prev_player_states[0].shape != (self.rnn.num_layers, player_embeddings.shape[0], self.rnn.hidden_size):
+            h_0 = torch.zeros(
+                self.rnn.num_layers, player_embeddings.shape[0], self.rnn.hidden_size, device=player_embeddings.device)
+            c_0 = torch.zeros(
+                self.rnn.num_layers, player_embeddings.shape[0], self.rnn.hidden_size, device=player_embeddings.device)
+            self.prev_player_states = (h_0, c_0)
+        if self.prev_inventory_states is None or self.prev_inventory_states[0].shape != (self.rnn.num_layers, inventory_embeddings.shape[0], self.rnn.hidden_size):
+            h_0 = torch.zeros(
+                self.rnn.num_layers, inventory_embeddings.shape[0], self.rnn.hidden_size, device=inventory_embeddings.device)
+            c_0 = torch.zeros(
+                self.rnn.num_layers, inventory_embeddings.shape[0], self.rnn.hidden_size, device=inventory_embeddings.device)
+            self.prev_inventory_states = (h_0, c_0)
+        if self.prev_hidden_states is None or self.prev_hidden_states[0].shape != (self.rnn.num_layers, hidden.shape[0], self.rnn.hidden_size):
+            h_0 = torch.zeros(
+                self.rnn.num_layers, hidden.shape[0], self.rnn.hidden_size, device=hidden.device)
+            c_0 = torch.zeros(
+                self.rnn.num_layers, hidden.shape[0], self.rnn.hidden_size, device=hidden.device)
+            self.prev_hidden_states = (h_0, c_0)
+        player_embeddings, self.prev_player_states = self.rnn(
+            player_embeddings, self.prev_player_states)
+        inventory_embeddings, self.prev_inventory_states = self.rnn(
+            inventory_embeddings, self.prev_inventory_states)

         player_embeddings = self.attn(player_embeddings)
         inventory_embeddings = self.attn(inventory_embeddings)

+        hidden, self.prev_hidden_states = self.rnn(
+            hidden, self.prev_hidden_states)
         hidden = self.attn(hidden)
         hidden = hidden.squeeze(1)
         # Concat the features from before and after attention and use an MLP
         # (advice communication) to get the same shape as before
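On that closing comment: the Linear(hidden_size * 2, hidden_size) defined in __init__ above is presumably what projects the concatenated before/after-attention features back to hidden_size. A hypothetical sketch of that step (names and sizes are illustrative, not taken from the file):

    import torch

    hidden_size = 64
    fc = torch.nn.Linear(hidden_size * 2, hidden_size)
    before = torch.zeros(8, hidden_size)             # features saved before attention
    after = torch.zeros(8, hidden_size)              # features after attention
    merged = fc(torch.cat([before, after], dim=-1))  # project back to hidden_size
    print(merged.shape)                              # torch.Size([8, 64])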
11 changes: 11 additions & 0 deletions results/task.csv
@@ -12,3 +12,14 @@ Task_AttainSkill_(skill:Mage_level:10_num_agent:1)_reward_to:agent,0,0.033854166
 Task_AttainSkill_(skill:Range_level:10_num_agent:1)_reward_to:agent,0,0.03732638888888889,
 Task_AttainSkill_(skill:Fishing_level:10_num_agent:1)_reward_to:agent,0,0.044270833333333336,
 Task_AttainSkill_(skill:Herbalism_level:10_num_agent:1)_reward_to:agent,0,0.03298611111111111,
+Task_CountEvent_(event:EARN_GOLD_N:10)_reward_to:agent,1,0.26171875,91.0
+Task_CountEvent_(event:PLAYER_KILL_N:10)_reward_to:agent,2,0.3265625,81.0
+Task_CountEvent_(event:PLAYER_KILL_N:10)_reward_to:agent,28,0.459375,186.28571428571428
+Task_CountEvent_(event:PLAYER_KILL_N:10)_reward_to:agent,2,0.3265625,81.0
+Task_CountEvent_(event:PLAYER_KILL_N:10)_reward_to:agent,2,0.3265625,81.0
+Task_CountEvent_(event:PLAYER_KILL_N:10)_reward_to:agent,28,0.459375,186.28571428571428
+Task_CountEvent_(event:PLAYER_KILL_N:10)_reward_to:agent,2,0.3265625,81.0
+Task_CountEvent_(event:PLAYER_KILL_N:10)_reward_to:agent,2,0.3265625,81.0
+Task_CountEvent_(event:PLAYER_KILL_N:10)_reward_to:agent,2,0.3265625,81.0
+Task_CountEvent_(event:PLAYER_KILL_N:10)_reward_to:agent,2,0.3265625,81.0
+Task_CountEvent_(event:PLAYER_KILL_N:10)_reward_to:agent,1,0.27265625000000004,54.0
