Skip to content

Commit

Permalink
[gym/common] Plot state, action and features of all pipeline blocks.
Browse files Browse the repository at this point in the history
  • Loading branch information
duburcqa committed Apr 2, 2024
1 parent a4485e7 commit 8a4e450
Show file tree
Hide file tree
Showing 4 changed files with 39 additions and 32 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -80,6 +80,7 @@ def apply_safety_limits(command: np.ndarray,

return out


class MotorSafetyLimit(
BaseControllerBlock[np.ndarray, np.ndarray, BaseObsT, np.ndarray]):
"""Safety mechanism primarily designed to prevent hardware damage and
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -62,7 +62,6 @@ def integrate_zoh(state: np.ndarray,
state[0] += dt * state[1]



@nb.jit(nopython=True, cache=True, fastmath=True)
def pd_controller(q_measured: np.ndarray,
v_measured: np.ndarray,
Expand Down Expand Up @@ -423,7 +422,7 @@ def compute_command(self, action: np.ndarray) -> np.ndarray:
# Update the target motor accelerations based on the provided action
self._command_accel[:] = (
(action - self._command_state[1]) / self.control_dt
if self.order == 2 else action)
if self.order == 1 else action)

# Dead band to avoid slow drift of target at rest for evaluation only
if not self.env.is_training:
Expand Down
65 changes: 36 additions & 29 deletions python/gym_jiminy/common/gym_jiminy/common/envs/generic.py
Original file line number Diff line number Diff line change
Expand Up @@ -1009,12 +1009,18 @@ def render(self) -> Optional[Union[RenderFrame, List[RenderFrame]]]:
return self.simulator.render( # type: ignore[return-value]
return_rgb_array=self.render_mode == 'rgb_array')

def plot(self, **kwargs: Any) -> TabbedFigure:
def plot(self,
enable_block_states: bool = False,
**kwargs: Any) -> TabbedFigure:
"""Display common simulation data and action over time.
.. Note:
It adds "Action" tab on top of original `Simulator.plot`.
It adds tabs for the base environment action plus all blocks
((state, action) for controllers and (state, features) for
observers) on top of original `Simulator.plot`.
:param enable_block_states: Whether to display the internal state of
all blocks.
:param kwargs: Extra keyword arguments to forward to `simulator.plot`.
"""
# Call base implementation
Expand All @@ -1027,35 +1033,36 @@ def plot(self, **kwargs: Any) -> TabbedFigure:
"Nothing to plot. Please run a simulation before calling "
"`plot` method.")

# Extract action.
# If telemetry action fieldnames is a dictionary, it cannot be nested.
# In such a case, keys corresponds to subplots, and values are
# individual scalar data over time to be displayed to the same subplot.
t = log_vars["Global.Time"]
tab_data: Dict[str, Union[np.ndarray, Dict[str, np.ndarray]]] = {}
action_fieldnames = self.log_fieldnames.get("action")
if action_fieldnames is None:
# It was impossible to register the action to the telemetry, likely
# because of incompatible dtype. Early return without adding tab.
return figure
if isinstance(action_fieldnames, dict):
for group, fieldnames in action_fieldnames.items():
if not isinstance(fieldnames, list):
LOGGER.error(
"Action space not supported by this method.")
return figure
tab_data[group] = {
# Plot all registered variables
for key, fielnames in self.log_fieldnames.items():
# Filter state if requested
if not enable_block_states and key.endswith(".state"):
continue

# Extract action hierarchical time series.
# Fieldnames stored in a dictionary cannot be nested. In such a
# case, keys corresponds to subplots, and values are individual
# scalar data over time to be displayed to the same subplot.
t = log_vars["Global.Time"]
tab_data: Dict[str, Union[np.ndarray, Dict[str, np.ndarray]]] = {}
if isinstance(fielnames, dict):
for group, fieldnames in fielnames.items():
if not isinstance(fieldnames, list):
LOGGER.error(
"Action space not supported by this method.")
return figure
tab_data[group] = {
key.split(".", 2)[2]: value
for key, value in extract_variables_from_log(
log_vars, fieldnames, as_dict=True).items()}
elif isinstance(fielnames, list):
tab_data.update({
key.split(".", 2)[2]: value
for key, value in extract_variables_from_log(
log_vars, fieldnames, as_dict=True).items()}
elif isinstance(action_fieldnames, list):
tab_data.update({
key.split(".", 2)[2]: value
for key, value in extract_variables_from_log(
log_vars, action_fieldnames, as_dict=True).items()})

# Add action tab
figure.add_tab(" ".join(("Env", "Action")), t, tab_data)
log_vars, fielnames, as_dict=True).items()})

# Add action tab
figure.add_tab(key.replace(".", " "), t, tab_data)

# Return figure for convenience and consistency with Matplotlib
return figure
Expand Down
2 changes: 1 addition & 1 deletion python/gym_jiminy/common/gym_jiminy/common/utils/misc.py
Original file line number Diff line number Diff line change
Expand Up @@ -95,7 +95,7 @@ def get_fieldnames(structure: Union[gym.Space[DataNested], DataNested],
else:
# Tensor: basic numbering
fieldname = np.array([
".".join(map(str, filter(None, (*fieldname_path, i))))
".".join(map(str, (*filter(None, fieldname_path), i)))
for i in range(data.size)]).reshape(data.shape).tolist()
fieldnames.append(fieldname)

Expand Down

0 comments on commit 8a4e450

Please sign in to comment.