Skip to content

Commit

Permalink
commiting so I can branch off from here and revert to joystick contro…
Browse files Browse the repository at this point in the history
…l for further experimentation
  • Loading branch information
mginoya committed Aug 28, 2024
1 parent 2d72763 commit d08e692
Show file tree
Hide file tree
Showing 3 changed files with 9 additions and 6 deletions.
4 changes: 2 additions & 2 deletions alfredo/agents/aant/aant.py
Original file line number Diff line number Diff line change
Expand Up @@ -157,12 +157,12 @@ def step(self, state: State, action: jax.Array) -> State:
ctrl_cost = rControl_act_ss(self.sys,
pipeline_state,
action,
weight=0.0)
weight=-self._ctrl_cost_weight)

torque_cost = rTorques(self.sys,
pipeline_state,
action,
weight=0.0)
weight=-0.003)

upright_reward = rUpright(self.sys,
pipeline_state,
Expand Down
9 changes: 6 additions & 3 deletions experiments/AAnt-locomotion/training.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,7 @@
"backend": "positional",
"seed": 1,
"len_training": 1_500_000,
"num_evals": 200,
"num_evals": 500,
"num_envs": 2048,
"batch_size": 2048,
"num_minibatches": 8,
Expand Down Expand Up @@ -81,7 +81,10 @@ def progress(num_steps, metrics):

scenes_fp = os.path.dirname(scenes.__file__)

env_xml_paths = [f"{scenes_fp}/flatworld/flatworld_A1_env.xml"]
env_xml_paths = [f"{scenes_fp}/flatworld/flatworld_A1_env.xml",
f"{scenes_fp}/flatworld/flatworld_A1_env.xml",
f"{scenes_fp}/flatworld/flatworld_A1_env.xml",
f"{scenes_fp}/flatworld/flatworld_A1_env.xml"]

# make and save initial ppo_network
key = jax.random.PRNGKey(wandb.config.seed)
Expand Down Expand Up @@ -115,7 +118,7 @@ def progress(num_steps, metrics):
# ============================
# Training & Saving Params
# ============================
i = 0
i = 8

for p in env_xml_paths:

Expand Down
2 changes: 1 addition & 1 deletion experiments/AAnt-locomotion/vis_traj.py
Original file line number Diff line number Diff line change
Expand Up @@ -79,7 +79,7 @@
#yaw_vel = 0.0 # rad/s
#jcmd = jp.array([x_vel, y_vel, yaw_vel])

wcmd = jp.array([-10.0, 10.0])
wcmd = jp.array([10.0, 10.0])

# generate policy rollout
for _ in range(episode_length):
Expand Down

0 comments on commit d08e692

Please sign in to comment.