Merge pull request #163 from kywch/isaacgym
Kyoung's isaac fixes
jsuarez5341 authored Feb 11, 2025
2 parents 2cdd770 + 34872ea commit df7d9e8
Showing 13 changed files with 487 additions and 128 deletions.
1 change: 1 addition & 0 deletions .gitignore
@@ -7,6 +7,7 @@ pufferlib/puffernet.c

# Raylib
raylib_wasm/
raylib*

# Byte-compiled / optimized / DLL files
__pycache__/
87 changes: 21 additions & 66 deletions clean_pufferl.py
@@ -5,6 +5,7 @@
import random
import psutil
import time

from threading import Thread
from collections import defaultdict, deque

@@ -118,10 +119,7 @@ def evaluate(data):
actions = actions.cpu().numpy()
mask = torch.as_tensor(mask)# * policy.mask)
o = o if config.cpu_offload else o_device

state = data.vecenv.state
demo = data.vecenv.demo
experience.store(o, state, demo, value, actions, logprob, r, d, env_id, mask)
experience.store(o, value, actions, logprob, r, d, env_id, mask)

for i in info:
for k, v in pufferlib.utils.unroll_nested_dict(i):
@@ -157,62 +155,37 @@ def train(data):
losses = data.losses

with profile.train_misc:
update_obs_stats = getattr(data.policy, "update_obs_stats", None)

idxs = experience.sort_training_data()
dones_np = experience.dones_np[idxs]
values_np = experience.values_np[idxs]
rewards_np = experience.rewards_np[idxs]
experience.flatten_batch()
# TODO: bootstrap between segment bounds
advantages_np = compute_gae(dones_np, values_np,
rewards_np, config.gamma, config.gae_lambda)
experience.flatten_batch(advantages_np)

# Optimizing the policy and value network
total_minibatches = experience.num_minibatches * config.update_epochs
mean_pg_loss, mean_v_loss, mean_entropy_loss = 0, 0, 0
mean_old_kl, mean_kl, mean_clipfrac = 0, 0, 0

# Compute adversarial reward. Note: discriminator doesn't get
# updated as often this way, but GAE is more accurate
state = experience.state.view(experience.num_minibatches,
config.minibatch_size, experience.state.shape[-1])
adversarial_reward = torch.zeros(
experience.num_minibatches, config.minibatch_size).to(config.device)
'''
with torch.no_grad():
for mb in range(experience.num_minibatches):
disc_logits = data.policy.policy.discriminate(state[mb]).squeeze()
prob = 1 / (1 + torch.exp(-disc_logits))
adversarial_reward[mb] = -torch.log(torch.maximum(
1 - prob, torch.tensor(0.0001, device=config.device)))
'''

# TODO: Nans in adversarial reward and gae
adversarial_reward_np = adversarial_reward.cpu().numpy().ravel()
advantages_np = compute_gae(dones_np, values_np,
rewards_np + adversarial_reward_np, config.gamma, config.gae_lambda)
advantages = torch.as_tensor(advantages_np).to(config.device)
experience.b_advantages = advantages.reshape(experience.minibatch_rows,
experience.num_minibatches, experience.bptt_horizon).transpose(0, 1).reshape(
experience.num_minibatches, experience.minibatch_size)
experience.returns_np = advantages_np + experience.values_np
experience.b_returns = experience.b_advantages + experience.b_values

# Clamp action to [-1, 1]
experience.b_actions = torch.clamp(experience.b_actions, -1, 1)

for epoch in range(config.update_epochs):
lstm_state = None
for mb in range(experience.num_minibatches):
# TODO: bootstrap between segment bounds
with profile.train_misc:
obs = experience.b_obs[mb]
obs = obs.to(config.device)
state = experience.b_state[mb]
demo = experience.b_demo[mb].to(config.device)
atn = experience.b_actions[mb]
log_probs = experience.b_logprobs[mb]
val = experience.b_values[mb]
adv = experience.b_advantages[mb]
ret = experience.b_returns[mb]

with profile.train_forward:
if update_obs_stats is not None:
update_obs_stats(obs.reshape(-1, *data.vecenv.single_observation_space.shape))

if experience.lstm_h is not None:
_, newlogprob, entropy, newvalue, lstm_state = data.policy(
obs, state=lstm_state, action=atn)
@@ -223,14 +196,6 @@ def train(data):
action=atn,
)

# Discriminator loss
# BUG: Data shape is wrong for morph. State should have same shape as demo
disc_state = data.policy.policy.discriminate(state)
#disc_demo = data.policy.policy.discriminate(demo)
#disc_loss_agent = torch.nn.BCEWithLogitsLoss()(disc_state, torch.zeros_like(disc_state))
#disc_loss_demo = torch.nn.BCEWithLogitsLoss()(disc_demo, torch.ones_like(disc_demo))
#disc_loss = 0.5 * (disc_loss_agent + disc_loss_demo)

if config.device == 'cuda':
torch.cuda.synchronize()

@@ -271,7 +236,7 @@ def train(data):
v_loss = 0.5 * ((newvalue - ret) ** 2).mean()

entropy_loss = entropy.mean()
loss = pg_loss - config.ent_coef*entropy_loss + config.vf_coef*v_loss# + config.disc_coef*disc_loss
loss = pg_loss - config.ent_coef * entropy_loss + v_loss * config.vf_coef

with profile.learn:
data.optimizer.zero_grad()
@@ -288,7 +253,6 @@ def train(data):
losses.old_approx_kl += old_approx_kl.item() / total_minibatches
losses.approx_kl += approx_kl.item() / total_minibatches
losses.clipfrac += clipfrac.item() / total_minibatches
#losses.discriminator += disc_loss.item() / total_minibatches

if config.target_kl is not None:
if approx_kl > config.target_kl:
@@ -428,7 +392,6 @@ def make_losses():
approx_kl=0,
clipfrac=0,
explained_variance=0,
discriminator=0,
)

class Experience:
@@ -444,10 +407,6 @@ def __init__(self, batch_size, bptt_horizon, minibatch_size, obs_shape, obs_dtyp
obs_device = device if not pin else 'cpu'
self.obs=torch.zeros(batch_size, *obs_shape, dtype=obs_dtype,
pin_memory=pin, device=device if not pin else 'cpu')
self.demo=torch.zeros(batch_size, 358, dtype=obs_dtype,
pin_memory=pin, device=device if not pin else 'cpu')
self.state=torch.zeros(batch_size, 358, dtype=obs_dtype,
pin_memory=pin, device=device if not pin else 'cpu')
self.actions=torch.zeros(batch_size, *atn_shape, dtype=atn_dtype, pin_memory=pin)
self.logprobs=torch.zeros(batch_size, pin_memory=pin)
self.rewards=torch.zeros(batch_size, pin_memory=pin)
@@ -492,15 +451,13 @@ def __init__(self, batch_size, bptt_horizon, minibatch_size, obs_shape, obs_dtyp
def full(self):
return self.ptr >= self.batch_size

def store(self, obs, state, demo, value, action, logprob, reward, done, env_id, mask):
def store(self, obs, value, action, logprob, reward, done, env_id, mask):
# Mask learner and Ensure indices do not exceed batch size
ptr = self.ptr
indices = torch.where(mask)[0].numpy()[:self.batch_size - ptr]
end = ptr + len(indices)

self.obs[ptr:end] = obs.to(self.obs.device)[indices]
self.state[ptr:end] = state.to(self.state.device)[indices]
self.demo[ptr:end] = demo.to(self.demo.device)[indices]
self.values_np[ptr:end] = value.cpu().numpy()[indices]
self.actions_np[ptr:end] = action[indices]
self.logprobs_np[ptr:end] = logprob.cpu().numpy()[indices]
@@ -516,31 +473,29 @@ def sort_training_data(self):
self.b_idxs_obs = torch.as_tensor(idxs.reshape(
self.minibatch_rows, self.num_minibatches, self.bptt_horizon
).transpose(1,0,-1)).to(self.obs.device).long()
self.b_idxs_state = torch.as_tensor(idxs.reshape(
self.minibatch_rows, self.num_minibatches, self.bptt_horizon
).transpose(1,0,-1)).to(self.state.device).long()
self.b_idxs_demo = torch.as_tensor(idxs.reshape(
self.minibatch_rows, self.num_minibatches, self.bptt_horizon
).transpose(1,0,-1)).to(self.demo.device).long()
self.b_idxs = self.b_idxs_obs.to(self.device)
self.b_idxs_flat = self.b_idxs.reshape(
self.num_minibatches, self.minibatch_size)
self.sort_keys = []
return idxs

def flatten_batch(self):
def flatten_batch(self, advantages_np):
advantages = torch.as_tensor(advantages_np).to(self.device)
b_idxs, b_flat = self.b_idxs, self.b_idxs_flat
self.b_actions = self.actions.to(self.device, non_blocking=True)
self.b_logprobs = self.logprobs.to(self.device, non_blocking=True)
self.b_dones = self.dones.to(self.device, non_blocking=True)
self.b_values = self.values.to(self.device, non_blocking=True)
self.b_advantages = advantages.reshape(self.minibatch_rows,
self.num_minibatches, self.bptt_horizon).transpose(0, 1).reshape(
self.num_minibatches, self.minibatch_size)
self.returns_np = advantages_np + self.values_np
self.b_obs = self.obs[self.b_idxs_obs]
self.b_state = self.state[self.b_idxs_state]
self.b_demo = self.demo[self.b_idxs_demo]
self.b_actions = self.b_actions[b_idxs].contiguous()
self.b_logprobs = self.b_logprobs[b_idxs]
self.b_dones = self.b_dones[b_idxs]
self.b_values = self.b_values[b_flat]
self.b_returns = self.b_advantages + self.b_values

class Utilization(Thread):
def __init__(self, delay=1, maxlen=20):
@@ -816,4 +771,4 @@ def print_dashboard(env_name, utilization, global_step, epoch,
with console.capture() as capture:
console.print(dashboard)

print('\033[0;0H' + capture.get())
print('\033[0;0H' + capture.get())
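
For context on the two pieces this file reshuffles: compute_gae is now called in train() before flatten_batch(advantages_np), and the commented-out AMP-style block that turned discriminator logits into an adversarial reward has been dropped. Below is a minimal sketch of both ideas, assuming standard generalized advantage estimation and the sigmoid-based reward shown in the removed comment; the function names are illustrative and the repo's actual compute_gae may be implemented differently.

import numpy as np
import torch

def gae_sketch(dones, values, rewards, gamma, gae_lambda):
    # Standard GAE over a flat rollout, mirroring the
    # compute_gae(dones, values, rewards, gamma, gae_lambda) call in train().
    # The final step is left unbootstrapped, matching the TODO about segment bounds.
    advantages = np.zeros_like(rewards)
    lastgaelam = 0.0
    for t in reversed(range(len(rewards) - 1)):
        nextnonterminal = 1.0 - dones[t + 1]
        delta = rewards[t] + gamma * values[t + 1] * nextnonterminal - values[t]
        lastgaelam = delta + gamma * gae_lambda * nextnonterminal * lastgaelam
        advantages[t] = lastgaelam
    return advantages

def adversarial_reward(disc_logits, eps=1e-4):
    # The removed comment block: reward is large when the discriminator assigns
    # high probability to the state being demo-like.
    prob = torch.sigmoid(disc_logits)
    return -torch.log(torch.clamp(1.0 - prob, min=eps))

# Usage as in the old code path: shape the env reward, then run GAE on the sum.
dones = np.zeros(8, dtype=np.float32)
values = np.random.randn(8).astype(np.float32)
rewards = np.random.randn(8).astype(np.float32)
adv_r = adversarial_reward(torch.randn(8)).numpy()
advantages = gae_sketch(dones, values, rewards + adv_r, gamma=0.99, gae_lambda=0.95)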
16 changes: 9 additions & 7 deletions config/morph.ini
@@ -6,13 +6,13 @@ policy_name = Policy
# rnn_name = Recurrent

[policy]
input_dim = 934
action_dim = 69
demo_dim = 358
hidden = 512
; input_dim = 934
; action_dim = 69
; demo_dim = 358
hidden_size = 2048

[env]
motion_file = "resources/morph/amass_train_take6_upright.pkl"
motion_file = "resources/morph/totalcapture_acting_poses.pkl"
has_self_collision = True
num_envs = 2048
#num_envs = 32
@@ -28,13 +28,15 @@ compile = False
norm_adv = True
target_kl = None

total_timesteps = 1_000_000_000
total_timesteps = 5_000_000_000
eval_timesteps = 100_000

num_workers = 1
num_envs = 1
batch_size = 65536
minibatch_size = 16384
; batch_size = 1024
; minibatch_size = 256

disc_coef = 5.0

@@ -50,4 +52,4 @@ vf_clip_coef = 0.2
max_grad_norm = 1.0
ent_coef = 0.0
learning_rate = 2e-5
checkpoint_interval = 1000
checkpoint_interval = 10000
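
A quick check on how the new [train] settings divide up, assuming the usual batch/minibatch split used by the Experience buffer (illustrative arithmetic, not taken from the repo's config loader):

batch_size = 65536
minibatch_size = 16384
num_minibatches = batch_size // minibatch_size   # 4 gradient minibatches per update epoch
updates = 5_000_000_000 // batch_size            # roughly 76k rollout/train iterations at the new total_timesteps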
5 changes: 3 additions & 2 deletions demo.py
@@ -31,7 +31,8 @@ def make_policy(env, policy_cls, rnn_cls, args):
policy = rnn_cls(env, policy, **args['rnn'])
policy = pufferlib.cleanrl.RecurrentPolicy(policy)
else:
policy = pufferlib.cleanrl.Policy(policy)
if not isinstance(policy, pufferlib.cleanrl.Policy):
policy = pufferlib.cleanrl.Policy(policy)

return policy.to(args['train']['device'])

@@ -361,7 +362,7 @@ def train(args, make_env, policy_cls, rnn_cls, wandb,
' demo options. Shows valid args for your env and policy',
formatter_class=RichHelpFormatter, add_help=False)
parser.add_argument('--env', '--environment', type=str,
default='puffer_squared', help='Name of specific environment to run')
default='morph', help='Name of specific environment to run')
parser.add_argument('--mode', type=str, default='train',
choices='train eval evaluate sweep sweep-carbs autotune profile'.split())
parser.add_argument('--vec-overwork', action='store_true',
19 changes: 15 additions & 4 deletions pufferlib/environments/morph/__init__.py
@@ -1,12 +1,23 @@
from .environment import env_creator

# try:
# import torch
# except ImportError:
# pass
# else:
# from .torch import Policy
# try:
# from .torch import Recurrent
# except:
# Recurrent = None

try:
import torch
import pufferlib.environments.morph.policy as torch
except ImportError:
pass
else:
from .torch import Policy
from .policy import Policy
try:
from .torch import Recurrent
from .policy import Recurrent
except:
Recurrent = None
Recurrent = None
pufferlib/environments/morph/policy.py
@@ -16,6 +16,8 @@ def __init__(self, env, input_dim, action_dim, demo_dim, hidden):
self.actor_mlp = nn.Sequential(
layer_init(nn.Linear(input_dim, hidden)),
nn.ReLU(),
layer_init(nn.Linear(hidden, hidden)),
nn.ReLU(),
)

'''
@@ -50,6 +52,9 @@ def __init__(self, env, input_dim, action_dim, demo_dim, hidden):
self.critic_mlp = nn.Sequential(
layer_init(nn.Linear(input_dim, hidden)),
nn.ReLU(),
layer_init(nn.Linear(hidden, hidden)),
nn.ReLU(),
layer_init(nn.Linear(hidden, 1)),
)

'''
@@ -68,14 +73,14 @@ def __init__(self, env, input_dim, action_dim, demo_dim, hidden):
nn.SiLU(),
)
'''
self.value = nn.Linear(hidden, 1)
# self.value = nn.Linear(hidden, 1)

### Discriminator
self._disc_mlp = nn.Sequential(
layer_init(nn.Linear(demo_dim, 1024)),
nn.ReLU(),
layer_init(nn.Linear(1024, hidden)),
layer_init(nn.Linear(demo_dim, hidden)),
nn.ReLU(),
# layer_init(nn.Linear(1024, hidden)),
# nn.ReLU(),
)
self._disc_logits = layer_init(torch.nn.Linear(hidden, 1))
self.obs_mean = None
@@ -90,7 +95,8 @@ def forward(self, observations):
-10.0, 10.0
)
hidden, lookup = self.encode_observations(observations)
actions, value = self.decode_actions(hidden, lookup)
actions, _ = self.decode_actions(hidden, lookup)
value = self.critic_mlp(observations)
return actions, value

def encode_observations(self, obs):
@@ -100,8 +106,10 @@ def decode_actions(self, hidden, lookup=None):
mu = self.mu(hidden)
std = torch.exp(self.sigma).expand_as(mu)
probs = torch.distributions.Normal(mu, std)
value = self.value(hidden)
return probs, value

# NOTE: Separate critic network takes input directly
# value = self.critic_mlp(hidden)
return probs, 0

def discriminate(self, amp_obs):
disc_mlp_out = self._disc_mlp(amp_obs)
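
The net effect of these policy edits, sketched as a self-contained module: the actor and critic are now separate two-layer MLPs, decode_actions no longer owns a value head, the value is computed directly from the raw observation in forward(), and the discriminator shrinks to a single hidden layer. This is an illustrative sketch, not the repo's exact class; the default dimensions below come from the values commented out in config/morph.ini.

import torch
import torch.nn as nn

class MorphPolicySketch(nn.Module):
    def __init__(self, input_dim=934, action_dim=69, demo_dim=358, hidden=2048):
        super().__init__()
        # Actor: two hidden layers, Gaussian head with state-independent log-std
        self.actor_mlp = nn.Sequential(
            nn.Linear(input_dim, hidden), nn.ReLU(),
            nn.Linear(hidden, hidden), nn.ReLU(),
        )
        self.mu = nn.Linear(hidden, action_dim)
        self.sigma = nn.Parameter(torch.zeros(action_dim))
        # Critic: its own MLP ending in a scalar, fed raw observations in forward()
        self.critic_mlp = nn.Sequential(
            nn.Linear(input_dim, hidden), nn.ReLU(),
            nn.Linear(hidden, hidden), nn.ReLU(),
            nn.Linear(hidden, 1),
        )
        # Discriminator: one hidden layer over the AMP observation
        self._disc_mlp = nn.Sequential(nn.Linear(demo_dim, hidden), nn.ReLU())
        self._disc_logits = nn.Linear(hidden, 1)

    def forward(self, obs):
        obs = torch.clamp(obs, -10.0, 10.0)
        hidden = self.actor_mlp(obs)
        mu = self.mu(hidden)
        std = torch.exp(self.sigma).expand_as(mu)
        dist = torch.distributions.Normal(mu, std)
        value = self.critic_mlp(obs)  # value comes from the observation, not the actor trunk
        return dist, value

    def discriminate(self, amp_obs):
        return self._disc_logits(self._disc_mlp(amp_obs))

# Smoke test with random data
policy = MorphPolicySketch()
dist, value = policy(torch.randn(4, 934))
logits = policy.discriminate(torch.randn(4, 358))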