Skip to content

Commit

Permalink
Merge pull request #71 from kakaoenterprise/config/rnd_montezuma
Browse files (browse the repository at this point in the history)
fix error in rnd
  • Loading branch information
ramanuzan authored Dec 1, 2021
2 parents d56f689 + a00ba74 commit 251a0d7
Showing 1 changed file with 48 additions and 48 deletions.
96 changes: 48 additions & 48 deletions jorldy/core/network/rnd.py
Original file line number | Diff line number | Diff line change
Expand Up @@ -10,34 +10,34 @@ def normalize_obs(obs, m, v):


# Attach the MLP-head linear layers to `instance` as attributes; nothing is
# returned. For each of the two branches, fc1 is torch.nn.Linear(D_in, D_hidden)
# and fc2 is torch.nn.Linear(D_hidden, feature_size).
# NOTE(review): this text comes from a rendered commit diff with the +/- markers
# and indentation stripped, so every assignment appears under two attribute
# names. Presumably the short `_p_`/`_t_` names are the removed lines and the
# `_predict_`/`_target_` names are the added ones -- confirm against the repo.
def define_mlp_head_weight(instance, D_in, D_hidden, feature_size):
# "p" / "predict" branch: D_in -> D_hidden -> feature_size
instance.fc1_p_mlp = torch.nn.Linear(D_in, D_hidden)
instance.fc2_p_mlp = torch.nn.Linear(D_hidden, feature_size)
instance.fc1_predict_mlp = torch.nn.Linear(D_in, D_hidden)
instance.fc2_predict_mlp = torch.nn.Linear(D_hidden, feature_size)

# "t" / "target" branch: same layer sizes, separate layer objects
instance.fc1_t_mlp = torch.nn.Linear(D_in, D_hidden)
instance.fc2_t_mlp = torch.nn.Linear(D_hidden, feature_size)
instance.fc1_target_mlp = torch.nn.Linear(D_in, D_hidden)
instance.fc2_target_mlp = torch.nn.Linear(D_hidden, feature_size)


# Attach the BatchNorm1d layers used by the MLP head to `instance`; nothing is
# returned. Each branch gets bn1 = BatchNorm1d(D_hidden) and
# bn2 = BatchNorm1d(feature_size).
# NOTE(review): rendered diff with +/- markers stripped -- each assignment is
# shown under both the short (`_p_`/`_t_`) and the long
# (`_predict_`/`_target_`) attribute name; presumably only the long names
# survive the commit -- confirm against the repo.
def define_mlp_batch_norm(instance, D_hidden, feature_size):
# "p" / "predict" branch
instance.bn1_p_mlp = torch.nn.BatchNorm1d(D_hidden)
instance.bn2_p_mlp = torch.nn.BatchNorm1d(feature_size)
instance.bn1_predict_mlp = torch.nn.BatchNorm1d(D_hidden)
instance.bn2_predict_mlp = torch.nn.BatchNorm1d(feature_size)

# "t" / "target" branch
instance.bn1_t_mlp = torch.nn.BatchNorm1d(D_hidden)
instance.bn2_t_mlp = torch.nn.BatchNorm1d(feature_size)
instance.bn1_target_mlp = torch.nn.BatchNorm1d(D_hidden)
instance.bn2_target_mlp = torch.nn.BatchNorm1d(feature_size)


# Run `s_next` through two separate two-layer heads stored on `instance` and
# return both results as the tuple (p, t).
# When `instance.batch_norm` is truthy, each Linear output is passed through its
# BatchNorm layer before F.relu; otherwise F.relu is applied to the Linear
# output directly. Both heads read the same `s_next` input.
# NOTE(review): rendered diff with +/- markers stripped -- every statement is
# shown twice, once with `_p_`/`_t_` attribute names and once with
# `_predict_`/`_target_` names; presumably the first of each pair is the
# removed line and the second the added line -- confirm against the repo.
def mlp_head(instance, s_next):
if instance.batch_norm:
# "p" / "predict" head: fc -> bn -> relu, twice
p = F.relu(instance.bn1_p_mlp(instance.fc1_p_mlp(s_next)))
p = F.relu(instance.bn2_p_mlp(instance.fc2_p_mlp(p)))
p = F.relu(instance.bn1_predict_mlp(instance.fc1_predict_mlp(s_next)))
p = F.relu(instance.bn2_predict_mlp(instance.fc2_predict_mlp(p)))

# "t" / "target" head: fc -> bn -> relu, twice
t = F.relu(instance.bn1_t_mlp(instance.fc1_t_mlp(s_next)))
t = F.relu(instance.bn2_t_mlp(instance.fc2_t_mlp(t)))
t = F.relu(instance.bn1_target_mlp(instance.fc1_target_mlp(s_next)))
t = F.relu(instance.bn2_target_mlp(instance.fc2_target_mlp(t)))
else:
# Same two heads without the BatchNorm layers: fc -> relu, twice
p = F.relu(instance.fc1_p_mlp(s_next))
p = F.relu(instance.fc2_p_mlp(p))
p = F.relu(instance.fc1_predict_mlp(s_next))
p = F.relu(instance.fc2_predict_mlp(p))

t = F.relu(instance.fc1_t_mlp(s_next))
t = F.relu(instance.fc2_t_mlp(t))
t = F.relu(instance.fc1_target_mlp(s_next))
t = F.relu(instance.fc2_target_mlp(t))

return p, t

Expand All @@ -50,57 +50,57 @@ def define_conv_head_weight(instance, D_in):
feature_size = 64 * dim3[0] * dim3[1]

# Predictor Networks
instance.conv1_p = torch.nn.Conv2d(
instance.conv1_predict = torch.nn.Conv2d(
in_channels=D_in[0], out_channels=32, kernel_size=8, stride=4
)
instance.conv2_p = torch.nn.Conv2d(
instance.conv2_predict = torch.nn.Conv2d(
in_channels=32, out_channels=64, kernel_size=4, stride=2
)
instance.conv3_p = torch.nn.Conv2d(
instance.conv3_predict = torch.nn.Conv2d(
in_channels=64, out_channels=64, kernel_size=3, stride=1
)

# Target Networks
instance.conv1_t = torch.nn.Conv2d(
instance.conv1_target = torch.nn.Conv2d(
in_channels=D_in[0], out_channels=32, kernel_size=8, stride=4
)
instance.conv2_t = torch.nn.Conv2d(
instance.conv2_target = torch.nn.Conv2d(
in_channels=32, out_channels=64, kernel_size=4, stride=2
)
instance.conv3_t = torch.nn.Conv2d(
instance.conv3_target = torch.nn.Conv2d(
in_channels=64, out_channels=64, kernel_size=3, stride=1
)

return feature_size


# Attach three BatchNorm2d layers per branch to `instance`, constructed with
# 32, 64 and 64 features respectively; nothing is returned.
# NOTE(review): the sizes are hard-coded here; presumably they must match the
# out_channels of the conv layers they are paired with -- confirm.
# NOTE(review): rendered diff with +/- markers stripped -- each assignment is
# shown under both the short (`_p_`/`_t_`) and the long
# (`_predict_`/`_target_`) attribute name -- confirm which side is current.
def define_conv_batch_norm(instance):
# "p" / "predict" branch
instance.bn1_p_conv = torch.nn.BatchNorm2d(32)
instance.bn2_p_conv = torch.nn.BatchNorm2d(64)
instance.bn3_p_conv = torch.nn.BatchNorm2d(64)
instance.bn1_predict_conv = torch.nn.BatchNorm2d(32)
instance.bn2_predict_conv = torch.nn.BatchNorm2d(64)
instance.bn3_predict_conv = torch.nn.BatchNorm2d(64)

# "t" / "target" branch
instance.bn1_t_conv = torch.nn.BatchNorm2d(32)
instance.bn2_t_conv = torch.nn.BatchNorm2d(64)
instance.bn3_t_conv = torch.nn.BatchNorm2d(64)
instance.bn1_target_conv = torch.nn.BatchNorm2d(32)
instance.bn2_target_conv = torch.nn.BatchNorm2d(64)
instance.bn3_target_conv = torch.nn.BatchNorm2d(64)


def conv_head(instance, s_next):
if instance.batch_norm:
p = F.relu(instance.bn1_p_conv(instance.conv1_p(s_next)))
p = F.relu(instance.bn2_p_conv(instance.conv2_p(p)))
p = F.relu(instance.bn3_p_conv(instance.conv3_p(p)))
p = F.relu(instance.bn1_predict_conv(instance.conv1_predict(s_next)))
p = F.relu(instance.bn2_predict_conv(instance.conv2_predict(p)))
p = F.relu(instance.bn3_predict_conv(instance.conv3_predict(p)))

t = F.relu(instance.bn1_t_conv(instance.conv1_t(s_next)))
t = F.relu(instance.bn2_t_conv(instance.conv2_t(t)))
t = F.relu(instance.bn3_t_conv(instance.conv3_t(t)))
t = F.relu(instance.bn1_target_conv(instance.conv1_target(s_next)))
t = F.relu(instance.bn2_target_conv(instance.conv2_target(t)))
t = F.relu(instance.bn3_target_conv(instance.conv3_target(t)))
else:
p = F.relu(instance.conv1_p(s_next))
p = F.relu(instance.conv2_p(p))
p = F.relu(instance.conv3_p(p))
p = F.relu(instance.conv1_predict(s_next))
p = F.relu(instance.conv2_predict(p))
p = F.relu(instance.conv3_predict(p))

t = F.relu(instance.conv1_t(s_next))
t = F.relu(instance.conv2_t(t))
t = F.relu(instance.conv3_t(t))
t = F.relu(instance.conv1_target(s_next))
t = F.relu(instance.conv2_target(t))
t = F.relu(instance.conv3_target(t))

p = p.view(p.size(0), -1)
t = t.view(t.size(0), -1)
Expand All @@ -109,19 +109,19 @@ def conv_head(instance, s_next):


# Attach the fully connected layers that follow the head to `instance`; nothing
# is returned. The "p"/"predict" branch gets three Linear layers
# (feature_size -> D_hidden -> D_hidden -> D_hidden); the "t"/"target" branch
# gets a single Linear(feature_size, D_hidden).
# NOTE(review): rendered diff with +/- markers stripped -- each assignment is
# shown under both the short (`_p`/`_t`) and the long (`_predict`/`_target`)
# attribute name -- confirm which side is current.
def define_fc_layers_weight(instance, feature_size, D_hidden):
# "p" / "predict" branch: three layers
instance.fc1_p = torch.nn.Linear(feature_size, D_hidden)
instance.fc2_p = torch.nn.Linear(D_hidden, D_hidden)
instance.fc3_p = torch.nn.Linear(D_hidden, D_hidden)
instance.fc1_predict = torch.nn.Linear(feature_size, D_hidden)
instance.fc2_predict = torch.nn.Linear(D_hidden, D_hidden)
instance.fc3_predict = torch.nn.Linear(D_hidden, D_hidden)

# "t" / "target" branch: one layer only
instance.fc1_t = torch.nn.Linear(feature_size, D_hidden)
instance.fc1_target = torch.nn.Linear(feature_size, D_hidden)


# Apply the fully connected layers stored on `instance` to the two head outputs
# and return the results as the tuple (p, t).
# `p` goes through fc1 and fc2 with F.relu after each, then fc3 with no
# activation; `t` goes through a single linear layer with no activation.
# NOTE(review): rendered diff with +/- markers stripped -- every statement is
# shown twice, once with `_p`/`_t` attribute names and once with
# `_predict`/`_target` names -- confirm which side is current.
def fc_layers(instance, p, t):
# "p" / "predict" branch: relu(fc1) -> relu(fc2) -> fc3
p = F.relu(instance.fc1_p(p))
p = F.relu(instance.fc2_p(p))
p = instance.fc3_p(p)
p = F.relu(instance.fc1_predict(p))
p = F.relu(instance.fc2_predict(p))
p = instance.fc3_predict(p)

# "t" / "target" branch: a single linear projection
t = instance.fc1_t(t)
t = instance.fc1_target(t)

return p, t

Expand Down

0 comments on commit 251a0d7

Please sign in to comment.