
Commit

[gym_jiminy/common] Add PD controller mode to force instantaneous state update. (#847)

* [gym_jiminy/common] Add aggregated 'terminated', 'truncated' info key for composite env
* [gym_jiminy/common] Add PD controller mode to force instantaneous state update.
* [gym_jiminy/common] More robust projected support polygon computation.
* [gym_jiminy/rllib] Log num steps instead of num iterations on the x-axis. Make evaluation robust to unavailable log files.
* [gym_jiminy/rllib] Update acrobot example.
duburcqa authored Dec 8, 2024
1 parent 2298708 commit 58e9b39
Showing 8 changed files with 101 additions and 62 deletions.
11 changes: 8 additions & 3 deletions python/gym_jiminy/common/gym_jiminy/common/bases/compositions.py
@@ -402,6 +402,11 @@ class EpisodeState(IntEnum):
"""


# Define proxies for fast lookup
_CONTINUED, _TERMINATED, _TRUNCATED = ( # pylint: disable=invalid-name
EpisodeState)


class AbstractTerminationCondition(metaclass=ABCMeta):
"""Abstract class from which all termination conditions must derived.
@@ -522,11 +527,11 @@ def __call__(self, info: InfoType) -> Tuple[bool, bool]:
info[self.name] = termination_info
else:
if is_terminated:
episode_state = EpisodeState.TERMINATED
episode_state = _TERMINATED
elif is_truncated:
episode_state = EpisodeState.TRUNCATED
episode_state = _TRUNCATED
else:
episode_state = EpisodeState.CONTINUED
episode_state = _CONTINUED
info[self.name] = episode_state

# Returning terminated and truncated flags
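The module-level proxies above exploit the fact that iterating an IntEnum yields its members in definition order, so tuple unpacking binds each member to a module-level name and skips the enum attribute lookup in the hot path. A minimal standalone sketch of the same trick (an illustrative mirror of the enum above, not part of this commit):

from enum import IntEnum

class EpisodeState(IntEnum):
    CONTINUED = 0
    TERMINATED = 1
    TRUNCATED = 2

# Unpacking iterates the members in definition order
_CONTINUED, _TERMINATED, _TRUNCATED = EpisodeState

assert _TERMINATED is EpisodeState.TERMINATED  # same singleton member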
29 changes: 23 additions & 6 deletions python/gym_jiminy/common/gym_jiminy/common/bases/pipeline.py
@@ -721,13 +721,30 @@ def has_terminated(self, info: InfoType) -> Tuple[bool, bool]:
"""
# Call unwrapped environment implementation
terminated, truncated = self.env.has_terminated(info)

# Evaluate conditions one-by-one as long as none has been triggered
for termination in self.terminations:
if terminated or truncated:
break
if terminated or truncated:
if terminated:
assert "terminated" not in info
info["terminated"] = -1
else:
assert "truncated" not in info
info["terminated"] = -1
return terminated, truncated

# Evaluate the termination conditions one-by-one, in the order they were
# passed to the constructor, with a short-circuit mechanism that skips
# subsequent evaluations as soon as one condition is triggered.
# Termination condition information is aggregated under a single key.
for i, termination in enumerate(self.terminations):
terminated, truncated = termination(info)

if terminated:
assert "terminated" not in info
info["terminated"] = i
break
if truncated:
assert "truncated" not in info
info["truncated"] = i
break
return terminated, truncated

def compute_command(self, action: Act, command: np.ndarray) -> None:
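Downstream code can then recover which condition fired from these aggregated keys. A hedged usage sketch (the 'env' pipeline object and 'action' are placeholders), assuming env.terminations is the tuple of conditions passed at construction, as shown above:

obs, reward, terminated, truncated, info = env.step(action)
if terminated:
    index = info["terminated"]
    if index >= 0:
        # Index into the constructor-order tuple of termination conditions
        print(f"Triggered by: {env.terminations[index].name}")
    else:
        # -1 flags termination by the base environment itself
        print("Base environment terminated")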
@@ -61,13 +61,13 @@ def mahony_filter(q: np.ndarray,
v_x, v_y, v_z = compute_tilt_from_quat(q)

# Compute the angular velocity using Explicit Complementary Filter:
# omega_mes = v_a_hat x v_a, where x is the cross product.
# omega_mes = (- v_a) x v_a_hat, where x is the cross product.
v_x_hat, v_y_hat, v_z_hat = acc / EARTH_SURFACE_GRAVITY
omega_mes = np.stack((
v_y_hat * v_z - v_z_hat * v_y,
v_z_hat * v_x - v_x_hat * v_z,
v_x_hat * v_y - v_y_hat * v_x), 0)
omega[:] = gyro - bias_hat + kp * omega_mes
v_x_hat * v_y - v_y_hat * v_x), 0) # eq. 32c
omega[:] = gyro - bias_hat + kp * omega_mes # eq. 32a (right hand)

# Early return if there is no IMU motion
if (np.abs(omega) < 1e-6).all():
@@ -86,13 +86,13 @@ def mahony_filter(q: np.ndarray,
q_y * p_w + q_z * p_x + q_w * p_y - q_x * p_z,
q_z * p_w - q_y * p_x + q_x * p_y + q_w * p_z,
q_w * p_w - q_x * p_x - q_y * p_y - q_z * p_z,
)
) # eq. 32a (left hand)

# First order quaternion normalization to prevent compounding of errors
q *= (3.0 - np.sum(np.square(q), 0)) / 2

# Update Gyro bias
bias_hat -= dt * ki * omega_mes
bias_hat -= ki * dt * omega_mes # eq. 32b


@nb.jit(nopython=True, cache=True)
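For reference, the equation numbers added to the comments refer to the Explicit Complementary Filter of Mahony et al. In the notation of this function, the continuous-time filter that the discrete update above integrates reads (a sketch, hedged against the code comments rather than the original paper):

\begin{aligned}
\dot{\hat{q}} &= \tfrac{1}{2}\,\hat{q} \otimes p\big(\omega_{\mathrm{gyro}} - \hat{b} + k_p\,\omega_{\mathrm{mes}}\big), && \text{(32a)} \\
\dot{\hat{b}} &= -\,k_i\,\omega_{\mathrm{mes}}, && \text{(32b)} \\
\omega_{\mathrm{mes}} &= (-v_a) \times \hat{v}_a, && \text{(32c)}
\end{aligned}

where p(\cdot) embeds a vector as a pure quaternion, v_a is the tilt computed from the current quaternion estimate, and \hat{v}_a is the accelerometer measurement normalized by gravity.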
@@ -175,6 +175,7 @@ def pd_adapter(action: np.ndarray,
command_state_lower: np.ndarray,
command_state_upper: np.ndarray,
dt: float,
is_instantaneous: bool,
out: np.ndarray) -> None:
"""Compute the target motor accelerations that must be held constant for a
given time interval in order to reach the desired value of some derivative
@@ -191,6 +192,8 @@ def pd_adapter(action: np.ndarray,
:param order: Derivative order of the position associated with the action.
:param command_state: Current command state, namely, all the derivatives of
the target motor positions up to acceleration order.
If 'is_instantaneous=True', then it will be updated
in-place.
:param command_state_lower: Lower bound of the command state that must be
satisfied at all cost.
:param command_state_upper: Upper bound of the command state that must be
@@ -200,21 +203,26 @@ def pd_adapter(action: np.ndarray,
:param out: Pre-allocated memory to store the target motor accelerations.
"""
# Update command accelerations based on the action and its derivative order
if order == 2:
# The action corresponds to the command motor accelerations
out[:] = action
if is_instantaneous:
# Update the command state directly
if order == 0:
command_state[0] = action
command_state[1] = 0.0
else:
command_state[1] = action
out[:] = 0.0
else:
if order == 0:
# Compute command velocity
velocity = (action - command_state[0]) / dt

# Clip command velocity
velocity = np.minimum(np.maximum(
velocity, command_state_lower[1]), command_state_upper[1])
else:
# The action corresponds to the command motor velocities
velocity = action

# Clip command velocity
velocity = np.minimum(np.maximum(
velocity, command_state_lower[1]), command_state_upper[1])

# Compute command acceleration
out[:] = (velocity - command_state[1]) / dt

@@ -520,7 +528,8 @@ def __init__(self,
env: InterfaceJiminyEnv[BaseObs, np.ndarray],
*,
update_ratio: int = -1,
order: int = 1) -> None:
order: int = 1,
is_instantaneous: bool = False) -> None:
"""
:param update_ratio: Ratio between the update period of the controller
and the one of the subsequent controller. -1 to
@@ -529,6 +538,11 @@ def __init__(self,
:param order: Derivative order of the action. It accepts position or
velocity (respectively 0 or 1).
Optional: 1 by default.
:param is_instantaneous: Whether to consider that the command state
must be updated instantaneously, breaking
continuity of higher-order derivatives, or
continuously by updating the target
acceleration instead.
"""
# Make sure that the specified derivative order is valid
assert order in (0, 1), "Derivative order out-of-bounds"
@@ -542,6 +556,7 @@ def __init__(self,

# Backup some user argument(s)
self.order = order
self.is_instantaneous = is_instantaneous

# Define some proxies for convenience
self._pd_controller = controller
@@ -579,4 +594,5 @@ def compute_command(self, action: np.ndarray, command: np.ndarray) -> None:
self._pd_controller._command_state_lower,
self._pd_controller._command_state_upper,
self.control_dt,
self.is_instantaneous,
command)
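A hedged construction sketch of the new option (the import path is assumed from this repository's layout, and the wrapped env must already embed a PD controller, as implied by the proxies above):

from gym_jiminy.common.blocks import PDAdapter  # assumed import path

adapter = PDAdapter(
    env,                    # pipeline env already wrapping a PDController
    order=0,                # actions are target motor positions
    is_instantaneous=True)  # jump to the target, zeroing its derivative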
2 changes: 1 addition & 1 deletion python/gym_jiminy/common/gym_jiminy/common/envs/generic.py
@@ -1504,7 +1504,7 @@ def has_terminated(self, info: InfoType) -> Tuple[bool, bool]:
"No simulation running. Please start one before calling this "
"method.")

# Check if the observation is out-of-bounds in debug mode only
# Check if the observation is out-of-bounds
truncated = not self._contains_observation()

return False, truncated
34 changes: 12 additions & 22 deletions python/gym_jiminy/examples/rllib/acrobot_ppo.py
@@ -17,6 +17,7 @@
import ray
from ray.tune.registry import register_env
from ray.rllib.models import MODEL_DEFAULTS
from ray.rllib.core.rl_module.default_model_config import DefaultModelConfig

from gym_jiminy.rllib.ppo import PPOConfig
from gym_jiminy.rllib.utilities import (initialize,
@@ -218,30 +219,19 @@

# ================== Configure policy and value networks ==================

# Default model configuration
model_config = deepcopy(MODEL_DEFAULTS)

# Fully-connected network settings
model_config.update(
# Nonlinearity for built-in fully connected net
fcnet_activation="tanh",
# Number of hidden layers for fully connected net
fcnet_hiddens=[64, 64],
# The last half of the output layer does not depend on the input
free_log_std=True,
# Whether to share layers between the policy and value function
vf_share_layers=False
)

# Number of preceding steps (incl. current) involved in policy computations
algo_config.env_runners(
episode_lookback_horizon=1,
)

# Model settings
algo_config.training(
algo_config.rl_module(
# Policy model configuration
model=model_config
model_config=DefaultModelConfig(
# Number of hidden layers for fully connected net
fcnet_hiddens=[64, 64],
# Nonlinearity for built-in fully connected net
fcnet_activation="tanh",
# Whether to share layers between the policy and value function
vf_share_layers=False,
# The last half of the output layer does not depend on the input
free_log_std=True,
)
)

# Exploration settings.
29 changes: 18 additions & 11 deletions python/gym_jiminy/rllib/gym_jiminy/rllib/utilities.py
@@ -686,7 +686,7 @@ def train(algo_config: AlgorithmConfig,
result = algo.train()

# Log results
iter_num = result[TRAINING_ITERATION]
step_num = result[NUM_ENV_STEPS_SAMPLED_LIFETIME]
if file_writer is not None:
# Flatten result dict after excluding irrelevant special keys
masked_fields = (
@@ -706,7 +706,7 @@ def train(algo_config: AlgorithmConfig,
full_attr = "/".join(("ray", "tune", attr))
try:
file_writer.add_scalar(
full_attr, value, global_step=iter_num)
full_attr, value, global_step=step_num)
scalar_tags.append(full_attr)
continue
except (TypeError, AssertionError, NotImplementedError):
@@ -716,19 +716,19 @@ def train(algo_config: AlgorithmConfig,
# Assuming single image
if value.ndim == 3:
file_writer.add_image(
full_attr, value, global_step=iter_num)
full_attr, value, global_step=step_num)
continue

# Assuming batch of images
if value.ndim == 4:
file_writer.add_images(
full_attr, value, global_step=iter_num)
full_attr, value, global_step=step_num)
continue

# Assuming video with arbitrary FPS
if value.ndim == 5:
file_writer.add_video(
full_attr, value, global_step=iter_num, fps=20)
full_attr, value, fps=20, global_step=step_num)
continue

# In last resort, try to log the variable as an histogram
Expand All @@ -737,7 +737,7 @@ def train(algo_config: AlgorithmConfig,
continue
try:
file_writer.add_histogram(
full_attr, value, global_step=iter_num)
full_attr, value, global_step=step_num)
continue
except (ValueError, TypeError):
pass
@@ -801,6 +801,7 @@ def train(algo_config: AlgorithmConfig,
print(" - ".join(msg_data))

# Backup the policy
iter_num = result[TRAINING_ITERATION]
if checkpoint_interval > 0 and iter_num % checkpoint_interval == 0:
algo.save(os.path.join(logdir, f"checkpoint_{iter_num:06d}"))

@@ -1220,12 +1221,18 @@ def evaluate_from_algo(algo: Algorithm,
all_returns = np.array([
episode.get_return() for episode in all_episodes])
idx_worst, idx_best = np.argsort(all_returns)[[0, -1]]
log_labels, log_paths = ("best", "worst")[:num_episodes], []
for suffix, idx in zip(log_labels, (idx_best, idx_worst)):
log_labels, log_paths = [], []
for label, idx in (
("best", idx_best), ("worst", idx_worst))[:num_episodes]:
ext = Path(all_log_paths[idx]).suffix
log_path = f"{algo.logdir}/iter_{algo.iteration}-{suffix}{ext}"
shutil.move(all_log_paths[idx], log_path)
log_paths.append(log_path)
log_path = f"{algo.logdir}/iter_{algo.iteration}-{label}{ext}"
try:
shutil.move(all_log_paths[idx], log_path)
except FileNotFoundError:
LOGGER.warning("Failed to save log file during evaluation.")
else:
log_paths.append(log_path)
log_labels.append(label)

# Replay and/or record a video of the best and worst trials if requested.
# Async to enable replaying and recording while training keeps going.
Expand Up @@ -4,7 +4,7 @@
from dataclasses import dataclass

import numpy as np
from scipy.spatial import ConvexHull
from scipy.spatial import ConvexHull, QhullError

from jiminy_py.core import ( # pylint: disable=no-name-in-module
multi_array_copyto)
@@ -124,11 +124,15 @@ def initialize(self) -> None:
# separately rather than all at once.
candidate_xy_refs: List[np.ndarray] = []
for positions in contact_positions:
convhull = ConvexHull(np.stack(positions, axis=0))
candidate_indices = set(
range(len(positions))).intersection(convhull.vertices)
candidate_xy_refs += (
positions[j][:2] for j in candidate_indices)
try:
convhull = ConvexHull(np.stack(positions, axis=0))
candidate_indices = set(
range(len(positions))).intersection(convhull.vertices)
candidate_xy_refs += (
positions[j][:2] for j in candidate_indices)
except QhullError:
# Assuming all the candidate points are part of the convex hull
candidate_xy_refs += (position[:2] for position in positions)
self._candidate_xy_refs = tuple(candidate_xy_refs)

# Allocate memory for stacked position of candidate contact points.
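For context, a minimal standalone illustration of the failure mode handled by the new fallback: ConvexHull raises QhullError on degenerate input such as collinear points, in which case every candidate point is kept:

import numpy as np
from scipy.spatial import ConvexHull, QhullError

points = np.array([[0.0, 0.0], [1.0, 1.0], [2.0, 2.0]])  # collinear
try:
    hull = ConvexHull(points)
except QhullError:
    # Degenerate geometry: keep every point as a hull candidate
    candidates = [point[:2] for point in points]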
