reward functions are now split out from agents
mginoya committed Nov 7, 2023
1 parent 6fcb36e commit cbcdbc5
Showing 7 changed files with 151 additions and 27 deletions.
62 changes: 37 additions & 25 deletions alfredo/agents/A1/alfredo_1.py
@@ -10,6 +10,10 @@
from jax import numpy as jp

from alfredo.tools import compose_scene
from alfredo.rewards import rConstant
from alfredo.rewards import rHealthy_simple_z
from alfredo.rewards import rSpeed_X
from alfredo.rewards import rControl_act_ss

class Alfredo(PipelineEnv):
# pyformat: disable
@@ -42,7 +46,8 @@ def __init__(
del kwargs["agent_xml_path"]

sys = mjcf.loads(xml_scene)


# this is vestigial - get rid of this someday soon
if "scene_xml_path" in kwargs:
path = kwargs["scene_xml_path"]
del kwargs["scene_xml_path"]
@@ -131,44 +136,51 @@ def step(self, state: State, action: jp.ndarray) -> State:

com_before, *_ = self._com(prev_pipeline_state)
com_after, *_ = self._com(pipeline_state)
a_velocity = (com_after - com_before) / self.dt

reward_vel = math.safe_norm(a_velocity)
forward_reward = self._forward_reward_weight * a_velocity[0] # * reward_vel
ctrl_cost = self._ctrl_cost_weight * jp.sum(jp.square(action))

min_z, max_z = self._healthy_z_range
is_healthy = jp.where(pipeline_state.x.pos[0, 2] < min_z, x=0.0, y=1.0)
is_healthy = jp.where(pipeline_state.x.pos[0, 2] > max_z, x=0.0, y=is_healthy)

if self._terminate_when_unhealthy:
healthy_reward = self._healthy_reward
else:
healthy_reward = self._healthy_reward * is_healthy

reward = healthy_reward - ctrl_cost + forward_reward

done = 1.0 - is_healthy if self._terminate_when_unhealthy else 0.0
x_speed_reward = rSpeed_X(self.sys,
state.pipeline_state,
CoM_prev=com_before,
CoM_now=com_after,
dt=self.dt,
weight=self._forward_reward_weight)

ctrl_cost = rControl_act_ss(self.sys,
state.pipeline_state,
action,
weight=-self._ctrl_cost_weight)

healthy_reward = rHealthy_simple_z(self.sys,
state.pipeline_state,
self._healthy_z_range,
early_terminate=self._terminate_when_unhealthy,
weight=self._healthy_reward,
focus_idx_range=(0, 2))

reward = healthy_reward[0] + ctrl_cost + x_speed_reward[0]

done = 1.0 - healthy_reward[1] if self._terminate_when_unhealthy else 0.0

state.metrics.update(
reward_ctrl=-ctrl_cost,
reward_alive=healthy_reward,
reward_velocity=forward_reward,
reward_ctrl=ctrl_cost,
reward_alive=healthy_reward[0],
reward_velocity=x_speed_reward[0],
agent_x_position=com_after[0],
agent_y_position=com_after[1],
agent_x_velocity=a_velocity[0],
agent_y_velocity=a_velocity[1],
agent_x_velocity=x_speed_reward[1],
agent_y_velocity=x_speed_reward[2],
)

return state.replace(
pipeline_state=pipeline_state, obs=obs, reward=reward, done=done
)

def _get_obs(self, pipeline_state: base.State, action: jp.ndarray) -> jp.ndarray:
"""Observes humanoid body position, velocities, and angles."""
"""Observes Alfredo's body position, velocities, and angles."""

a_positions = pipeline_state.q
a_velocities = pipeline_state.qd
#print(f"a_positions = {a_positions}")
#print(f"a_velocities = {a_velocities}")

if self._exclude_current_positions_from_observation:
a_positions = a_positions[2:]
@@ -205,7 +217,7 @@ def _get_obs(self, pipeline_state: base.State, action: jp.ndarray) -> jp.ndarray
)

def _com(self, pipeline_state: base.State) -> jp.ndarray:
"""Computes Center of Mass of the Humanoid"""
"""Computes Center of Mass of Alfredo"""

inertia = self.sys.link.inertia

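The step() rewrite above composes three calls into alfredo.rewards instead of computing the terms inline. Below is a minimal, standalone sketch of that composition; it uses a SimpleNamespace stand-in for the brax pipeline state (only x.pos is read) and made-up weights, dt, and z-range rather than the values from Alfredo's config.

    from types import SimpleNamespace

    import jax.numpy as jp

    from alfredo.rewards import rControl_act_ss, rHealthy_simple_z, rSpeed_X

    # Stand-in pipeline state: rHealthy_simple_z only reads x.pos, so one
    # root body at z = 1.1 m is enough for this sketch.
    ps = SimpleNamespace(x=SimpleNamespace(pos=jp.array([[0.0, 0.0, 1.1]])))

    action = jp.array([0.1, -0.3, 0.2])
    com_prev = jp.array([0.00, 0.00, 1.10])
    com_now = jp.array([0.02, 0.01, 1.10])
    dt = 0.01  # illustrative timestep

    x_speed_reward = rSpeed_X(None, ps, CoM_prev=com_prev, CoM_now=com_now,
                              dt=dt, weight=1.25)
    ctrl_cost = rControl_act_ss(None, ps, action, weight=-0.1)
    healthy_reward = rHealthy_simple_z(None, ps, (0.8, 2.1),
                                       early_terminate=False, weight=5.0,
                                       focus_idx_range=(0, 2))

    # Same composition as the new step(): ctrl_cost already carries its sign.
    reward = healthy_reward[0] + ctrl_cost + x_speed_reward[0]
    done = 1.0 - healthy_reward[1]
    print(reward, done, x_speed_reward[1], x_speed_reward[2])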
4 changes: 4 additions & 0 deletions alfredo/rewards/__init__.py
@@ -0,0 +1,4 @@
from .rConstant import *
from .rSpeed import *
from .rHealthy import *
from .rControl import *
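Because each reward submodule is star-imported here, the agent code imports every reward function directly off the package, as alfredo_1.py does above:

    from alfredo.rewards import rConstant, rControl_act_ss, rHealthy_simple_z, rSpeed_X, rSpeed_Y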
15 changes: 15 additions & 0 deletions alfredo/rewards/rConstant.py
@@ -0,0 +1,15 @@
from typing import Tuple

import jax
from brax import actuator, base, math
from brax.envs import PipelineEnv, State
from brax.io import mjcf
from etils import epath
from jax import numpy as jp

def rConstant(sys: base.System,
pipeline_state: base.State,
weight=1.0,
focus_idx_range=(1, -1)) -> jp.ndarray:

return jp.array([weight])
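rConstant never touches the system or pipeline state; it simply wraps its weight in a one-element array, which makes it usable as a constant alive-style bonus. A quick check (None stands in for the unused arguments):

    from alfredo.rewards import rConstant

    bonus = rConstant(None, None, weight=0.5)
    print(bonus)  # [0.5]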
18 changes: 18 additions & 0 deletions alfredo/rewards/rControl.py
@@ -0,0 +1,18 @@
from typing import Tuple

import jax
from brax import actuator, base, math
from brax.envs import PipelineEnv, State
from brax.io import mjcf
from etils import epath
from jax import numpy as jp

def rControl_act_ss(sys: base.System,
pipeline_state: base.State,
action: jp.ndarray,
weight=1.0,
focus_idx_range=(1, -1)) -> jp.ndarray:

ctrl_cost = weight * jp.sum(jp.square(action))

return ctrl_cost
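rControl_act_ss returns weight * sum(action**2), so the sign convention is left to the caller: step() above passes weight=-self._ctrl_cost_weight, making the returned value a penalty that is summed directly into the reward (and logged as a negative reward_ctrl metric). A small check with an illustrative weight:

    import jax.numpy as jp

    from alfredo.rewards import rControl_act_ss

    action = jp.array([0.2, -0.4, 0.1])
    penalty = rControl_act_ss(None, None, action, weight=-0.1)
    print(penalty)  # -0.1 * (0.04 + 0.16 + 0.01) = -0.021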
31 changes: 31 additions & 0 deletions alfredo/rewards/rHealthy.py
@@ -0,0 +1,31 @@
from typing import Tuple

import jax
from brax import actuator, base, math
from brax.envs import PipelineEnv, State
from brax.io import mjcf
from etils import epath
from jax import numpy as jp

def rHealthy_simple_z(sys: base.System,
pipeline_state: base.State,
z_range: Tuple,
early_terminate: True,
weight=1.0,
focus_idx_range=(1, -1)) -> jp.ndarray:

min_z, max_z = z_range
focus_s = focus_idx_range[0]
focus_e = focus_idx_range[-1]

focus_x_pos = pipeline_state.x.pos[focus_s, focus_e]

is_healthy = jp.where(focus_x_pos < min_z, x=0.0, y=1.0)
is_healthy = jp.where(focus_x_pos > max_z, x=0.0, y=is_healthy)

if early_terminate:
hr = weight
else:
hr = weight * is_healthy

return jp.array([hr, is_healthy])
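rHealthy_simple_z returns a two-element array: the weighted health reward and the raw is_healthy flag. step() above uses element 0 as the reward term and element 1 to derive done. The focus_idx_range tuple indexes x.pos as [body, axis], so the (0, 2) passed from step() reads the root body's z height; the early_terminate parameter is annotated as True in this commit but behaves as an ordinary boolean flag. A standalone check with a SimpleNamespace stand-in for the pipeline state and an illustrative z-range and weight:

    from types import SimpleNamespace

    import jax.numpy as jp

    from alfredo.rewards import rHealthy_simple_z

    # One body whose z height (axis 2) sits below the healthy range.
    ps = SimpleNamespace(x=SimpleNamespace(pos=jp.array([[0.0, 0.0, 0.6]])))

    hr = rHealthy_simple_z(None, ps, (0.8, 2.1),
                           early_terminate=False, weight=5.0,
                           focus_idx_range=(0, 2))
    print(hr)  # [0. 0.]  -> no reward, and the is_healthy flag is 0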
44 changes: 44 additions & 0 deletions alfredo/rewards/rSpeed.py
@@ -0,0 +1,44 @@
from typing import Tuple

import jax
from brax import actuator, base, math
from brax.envs import PipelineEnv, State
from brax.io import mjcf
from etils import epath
from jax import numpy as jp

def rSpeed_X(sys: base.System,
pipeline_state: base.State,
CoM_prev: jp.ndarray,
CoM_now: jp.ndarray,
dt,
weight=1.0,
focus_idx_range=(1, -1)) -> jp.ndarray:


velocity = (CoM_now - CoM_prev) / dt

focus_s = focus_idx_range[0]
focus_e = focus_idx_range[-1]

sxr = weight * velocity[0]

return jp.array([sxr, velocity[0], velocity[1]])

def rSpeed_Y(sys: base.System,
pipeline_state: base.State,
CoM_prev: jp.ndarray,
CoM_now: jp.ndarray,
dt,
weight=1.0,
focus_idx_range=(1, -1)) -> jp.ndarray:


velocity = (CoM_now - CoM_prev) / dt

focus_s = focus_idx_range[0]
focus_e = focus_idx_range[-1]

syr = weight * velocity[1]

return jp.array([syr, velocity[0], velocity[1]])
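Both speed rewards return a three-element array: the weighted speed term along their axis plus the raw x and y velocities. step() reuses elements 1 and 2 of rSpeed_X's result as the agent_x_velocity / agent_y_velocity metrics. An illustrative check (dt and weight are made up for the example):

    import jax.numpy as jp

    from alfredo.rewards import rSpeed_X, rSpeed_Y

    com_prev = jp.array([0.00, 0.00, 1.10])
    com_now = jp.array([0.03, -0.01, 1.10])
    dt = 0.01

    print(rSpeed_X(None, None, CoM_prev=com_prev, CoM_now=com_now, dt=dt, weight=1.25))
    # [ 3.75  3.   -1.  ]
    print(rSpeed_Y(None, None, CoM_prev=com_prev, CoM_now=com_now, dt=dt, weight=1.25))
    # [-1.25  3.   -1.  ]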
4 changes: 2 additions & 2 deletions experiments/Alfredo-simulate-step/one_physics_step.py
@@ -40,8 +40,8 @@

state = jax.jit(env.reset)(rng=jax.random.PRNGKey(seed=0))

print(f"Alfredo brax env dir: {dir(env)}")
print(f"state: {state}")
#print(f"Alfredo brax env dir: {dir(env)}")
#print(f"state: {state}")

com = env._com(state.pipeline_state)
obs = env._get_obs(state.pipeline_state, jp.zeros(env.action_size))
