Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[gym_jiminy/common] Remove 'post_fn' argument from drift term cond. #873

Merged
merged 7 commits into from
Jan 28, 2025
6 changes: 2 additions & 4 deletions .github/workflows/linux.yml
Original file line number Diff line number Diff line change
Expand Up @@ -140,15 +140,13 @@ jobs:
if: matrix.BUILD_TYPE != 'Debug' && matrix.PYTHON_VERSION == '3.10'
run: |
# Generate stubs
# FIXME: stubgen does not work with Numpy 2.X
# FIXME: stubgen does not work with Numpy 2.X without option `--no-analysis`
# (see https://github.com/python/mypy/issues/17396)
"${PYTHON_EXECUTABLE}" -m pip install "numpy<2.0"
stubgen -p jiminy_py -o ${RootDir}/build/pypi/jiminy_py/src
stubgen -p jiminy_py -o ${RootDir}/build/pypi/jiminy_py/src --no-analysis
"${PYTHON_EXECUTABLE}" "${RootDir}/build_tools/stubgen.py" \
-o ${RootDir}/build/stubs --ignore-invalid=all jiminy_py
cp ${RootDir}/build/stubs/jiminy_py-stubs/core/__init__.pyi \
${RootDir}/build/pypi/jiminy_py/src/jiminy_py/core/core.pyi
"${PYTHON_EXECUTABLE}" -m pip install --upgrade "numpy>=1.24" numba torch

# Re-install jiminy with stubs
cd "${RootDir}/build"
Expand Down
6 changes: 2 additions & 4 deletions .github/workflows/macos.yml
Original file line number Diff line number Diff line change
Expand Up @@ -142,10 +142,9 @@ jobs:

# Generate stubs
if [[ "${{ matrix.BUILD_TYPE }}" != 'Debug' && "${{ matrix.OS }}" != 'macos-13' ]] ; then
# FIXME: stubgen does not work with Numpy 2.X
# FIXME: stubgen does not work with Numpy 2.X without option `--no-analysis`
# (see https://github.com/python/mypy/issues/17396)
"${PYTHON_EXECUTABLE}" -m pip install "numpy<2.0"
stubgen -p jiminy_py -o ${RootDir}/build/pypi/jiminy_py/src
stubgen -p jiminy_py -o ${RootDir}/build/pypi/jiminy_py/src --no-analysis
# FIXME: Python 3.10 and Python 3.11 crashes when generating stubs without any backtrace...
if [[ "${{ matrix.PYTHON_VERSION }}" != '3.10' && "${{ matrix.PYTHON_VERSION }}" != '3.11' ]] ; then
# lldb --batch -o "settings set target.process.stop-on-exec false" \
Expand All @@ -155,7 +154,6 @@ jobs:
cp ${RootDir}/build/stubs/jiminy_py-stubs/core/__init__.pyi \
${RootDir}/build/pypi/jiminy_py/src/jiminy_py/core/core.pyi
fi
"${PYTHON_EXECUTABLE}" -m pip install --upgrade "numpy>=1.24" numba torch
fi

# Generate wheels
Expand Down
6 changes: 2 additions & 4 deletions .github/workflows/manylinux.yml
Original file line number Diff line number Diff line change
Expand Up @@ -112,15 +112,13 @@ jobs:
export LD_LIBRARY_PATH="$InstallDir/lib:$InstallDir/lib64:/usr/local/lib"

# Generate stubs
# FIXME: stubgen does not work with Numpy 2.X
# FIXME: stubgen does not work with Numpy 2.X without option `--no-analysis`
# (see https://github.com/python/mypy/issues/17396)
"${PYTHON_EXECUTABLE}" -m pip install "numpy<2.0"
stubgen -p jiminy_py -o $RootDir/build/pypi/jiminy_py/src
stubgen -p jiminy_py -o $RootDir/build/pypi/jiminy_py/src --no-analysis
"${PYTHON_EXECUTABLE}" "$RootDir/build_tools/stubgen.py" \
-o $RootDir/build/stubs --ignore-invalid=all jiminy_py
\cp $RootDir/build/stubs/jiminy_py-stubs/core/__init__.pyi \
$RootDir/build/pypi/jiminy_py/src/jiminy_py/core/core.pyi
"${PYTHON_EXECUTABLE}" -m pip install --upgrade "numpy>=1.24" numba

# Generate wheels
cd "$RootDir/build"
Expand Down
18 changes: 7 additions & 11 deletions .github/workflows/win.yml
Original file line number Diff line number Diff line change
Expand Up @@ -123,17 +123,13 @@ jobs:
${env:Path} += ";$InstallDir/lib"

# Generate stubs
if ("${{ matrix.PYTHON_VERSION }}" -ne "3.13") {
# FIXME: stubgen does not work with Numpy 2.X
# (see https://github.com/python/mypy/issues/17396)
python -m pip install "numpy<2.0"
stubgen -p jiminy_py -o $RootDir/build/pypi/jiminy_py/src
python "$RootDir/build_tools/stubgen.py" `
-o $RootDir/build/stubs --ignore-invalid=all jiminy_py
Copy-Item -Force -Path "$RootDir/build/stubs/jiminy_py-stubs/core/__init__.pyi" `
-Destination "$RootDir/build/pypi/jiminy_py/src/jiminy_py/core/core.pyi"
python -m pip install --upgrade "numpy>=1.24" numba torch
}
# FIXME: stubgen does not work with Numpy 2.X without option `--no-analysis`
# (see https://github.com/python/mypy/issues/17396)
stubgen -p jiminy_py -o $RootDir/build/pypi/jiminy_py/src --no-analysis
python "$RootDir/build_tools/stubgen.py" `
-o $RootDir/build/stubs --ignore-invalid=all jiminy_py
Copy-Item -Force -Path "$RootDir/build/stubs/jiminy_py-stubs/core/__init__.pyi" `
-Destination "$RootDir/build/pypi/jiminy_py/src/jiminy_py/core/core.pyi"

# Generate wheels
Set-Location -Path "$RootDir/build"
Expand Down
6 changes: 4 additions & 2 deletions python/gym_jiminy/common/gym_jiminy/common/bases/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,8 @@
MixtureReward,
AbstractTerminationCondition,
QuantityTermination,
EpisodeState)
EpisodeState,
partial_hashable)
from .blocks import (BlockState,
InterfaceBlock,
BaseObserverBlock,
Expand Down Expand Up @@ -74,5 +75,6 @@
'QuantityCreator',
'EpisodeState',
'StateQuantity',
'DatasetTrajectoryQuantity'
'DatasetTrajectoryQuantity',
'partial_hashable'
]
55 changes: 53 additions & 2 deletions python/gym_jiminy/common/gym_jiminy/common/bases/compositions.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,9 +6,13 @@
This modular approach allows for standardization of usual metrics. Overall, it
greatly reduces code duplication and bugs.
"""
import inspect
from functools import partial
from abc import abstractmethod, ABCMeta
from enum import IntEnum
from typing import Tuple, Sequence, Callable, Union, Optional, Generic, TypeVar
from typing import (
Tuple, Sequence, Callable, Union, Optional, Generic, Any, TypeVar,
TYPE_CHECKING)

import numpy as np

Expand All @@ -25,6 +29,53 @@
ArrayLikeOrScalar = Union[ArrayOrScalar, Sequence[Union[Number, np.number]]]


class partial_hashable(partial): # pylint: disable=invalid-name
"""Extends standard `functools.Partial` class with hash and equality
operator.

Two partial instances are equal if they are wrapping the exact same
function (i.e. pointing to the same memory address as per `id` build-in
function), and bindings the same arguments (i.e. all arguments are equal
as per `==` operator). Note that it does not matter if the constructor
arguments of `Partial` itself are positional or keyword-based. Internally,
they will be stored in an ordered list of keyword-only arguments for
equality check.

.. warning::
Try to instantiate this class with invalid arguments for the method
being wrapped (e.g. specifying multiple values for the same argument)
would raise a `TypeError` exception, unlike `functools.partial` that
would only fail when calling the resulting callable object.
"""

if TYPE_CHECKING:
_normalized_args: Tuple[Any, ...]

def __new__(cls,
func: Callable, /,
*args: Any,
**kwargs: Any) -> "partial_hashable":
# Call base implementation
self = super(partial_hashable, cls).__new__(cls, func, *args, **kwargs)

# Pre-compute normalized arguments once and for all
sig = inspect.signature(self.func)
bound = sig.bind_partial(*self.args, **(self.keywords or {}))
bound.apply_defaults()
self._normalized_args = tuple(bound.arguments.values())

return self

def __eq__(self, other: Any) -> bool:
if not isinstance(other, partial_hashable):
return False
return self.func == other.func and (
self._normalized_args == other._normalized_args)

def __hash__(self) -> int:
return hash((self.func, self._normalized_args))


class AbstractReward(metaclass=ABCMeta):
"""Abstract class from which all reward component must derived.

Expand Down Expand Up @@ -514,7 +565,7 @@ def __call__(self, info: InfoType) -> Tuple[bool, bool]:
return is_terminated, is_truncated


class QuantityTermination(AbstractTerminationCondition, Generic[ValueT]):
class QuantityTermination(AbstractTerminationCondition):
"""Convenience class making it easy to derive termination conditions from
generic quantities.

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
from .generic import (SurviveReward,
TrackingQuantityReward,
TrackingActuatedJointPositionsReward,
MinimizeMechanicalPowerConsumption,
DriftTrackingQuantityTermination,
ShiftTrackingQuantityTermination,
MechanicalSafetyTermination,
Expand Down Expand Up @@ -38,6 +39,7 @@
"SurviveReward",
"MinimizeFrictionReward",
"MinimizeAngularMomentumReward",
"MinimizeMechanicalPowerConsumption",
"TrackingQuantityReward",
"TrackingActuatedJointPositionsReward",
"TrackingBaseHeightReward",
Expand Down
Loading
Loading