From 5d64ba5da1f55af815944ff58d477048b2e27236 Mon Sep 17 00:00:00 2001 From: Remi Cadene Date: Thu, 24 Oct 2024 10:26:25 +0200 Subject: [PATCH] fix unit tests --- .../robots/feetech_calibration.py | 9 +++++ .../robot_devices/robots/manipulator.py | 23 +++-------- tests/mock_dynamixel_sdk.py | 29 ++++++++------ tests/mock_scservo_sdk.py | 38 +++++++++++++------ 4 files changed, 59 insertions(+), 40 deletions(-) diff --git a/lerobot/common/robot_devices/robots/feetech_calibration.py b/lerobot/common/robot_devices/robots/feetech_calibration.py index 4c7fa5e06..1fca07f43 100644 --- a/lerobot/common/robot_devices/robots/feetech_calibration.py +++ b/lerobot/common/robot_devices/robots/feetech_calibration.py @@ -126,6 +126,15 @@ def apply_offset(calib, offset): return calib +def run_arm_auto_calibration(arm: MotorsBus, robot_type: str, arm_name: str, arm_type: str): + if robot_type == "so100": + return run_arm_auto_calibration_so100(arm, robot_type, arm_name, arm_type) + elif robot_type == "moss": + return run_arm_auto_calibration_moss(arm, robot_type, arm_name, arm_type) + else: + raise ValueError(robot_type) + + def run_arm_auto_calibration_so100(arm: MotorsBus, robot_type: str, arm_name: str, arm_type: str): """All the offsets and magic numbers are hand tuned, and are unique to SO-100 follower arms""" if (arm.read("Torque_Enable") != TorqueMode.DISABLED.value).any(): diff --git a/lerobot/common/robot_devices/robots/manipulator.py b/lerobot/common/robot_devices/robots/manipulator.py index 0ad219f9e..fc21e64a3 100644 --- a/lerobot/common/robot_devices/robots/manipulator.py +++ b/lerobot/common/robot_devices/robots/manipulator.py @@ -335,29 +335,17 @@ def load_or_run_calibration_(name, arm, arm_type): calibration = run_arm_calibration(arm, self.robot_type, name, arm_type) - elif self.robot_type in ["so100"]: + elif self.robot_type in ["so100", "moss"]: from lerobot.common.robot_devices.robots.feetech_calibration import ( - run_arm_auto_calibration_so100, + run_arm_auto_calibration, run_arm_manual_calibration, ) - if arm_type == "leader": + # TODO(rcadene): better way to handle mocking + test run_arm_auto_calibration + if arm_type == "leader" or arm.mock: calibration = run_arm_manual_calibration(arm, self.robot_type, name, arm_type) elif arm_type == "follower": - calibration = run_arm_auto_calibration_so100(arm, self.robot_type, name, arm_type) - else: - raise ValueError(arm_type) - - elif self.robot_type in ["moss"]: - from lerobot.common.robot_devices.robots.feetech_calibration import ( - run_arm_auto_calibration_moss, - run_arm_manual_calibration, - ) - - if arm_type == "leader": - calibration = run_arm_manual_calibration(arm, self.robot_type, name, arm_type) - elif arm_type == "follower": - calibration = run_arm_auto_calibration_moss(arm, self.robot_type, name, arm_type) + calibration = run_arm_auto_calibration(arm, self.robot_type, name, arm_type) else: raise ValueError(arm_type) @@ -482,7 +470,6 @@ def set_so100_robot_preset(self): # the motors. Note: this configuration is not in the official STS3215 Memory Table self.follower_arms[name].write("Maximum_Acceleration", 254) self.follower_arms[name].write("Acceleration", 254) - time.sleep(1) def teleop_step( self, record_data=False diff --git a/tests/mock_dynamixel_sdk.py b/tests/mock_dynamixel_sdk.py index 6d0ed20e5..a790dff05 100644 --- a/tests/mock_dynamixel_sdk.py +++ b/tests/mock_dynamixel_sdk.py @@ -18,6 +18,19 @@ def convert_to_bytes(value, bytes): return value +def get_default_motor_values(motor_index): + return { + # Key (int) are from X_SERIES_CONTROL_TABLE + 7: motor_index, # ID + 8: DEFAULT_BAUDRATE, # Baud_rate + 10: 0, # Drive_Mode + 64: 0, # Torque_Enable + # Set 2560 since calibration values for Aloha gripper is between start_pos=2499 and end_pos=3144 + # For other joints, 2560 will be autocorrected to be in calibration range + 132: 2560, # Present_Position + } + + class PortHandler: def __init__(self, port): self.port = port @@ -52,18 +65,9 @@ def __init__(self, port_handler, packet_handler, address, bytes): self.packet_handler = packet_handler def addParam(self, motor_index): # noqa: N802 + # Initialize motor default values if motor_index not in self.packet_handler.data: - # Initialize motor default values - self.packet_handler.data[motor_index] = { - # Key (int) are from X_SERIES_CONTROL_TABLE - 7: motor_index, # ID - 8: DEFAULT_BAUDRATE, # Baud_rate - 10: 0, # Drive_Mode - 64: 0, # Torque_Enable - # Set 2560 since calibration values for Aloha gripper is between start_pos=2499 and end_pos=3144 - # For other joints, 2560 will be autocorrected to be in calibration range - 132: 2560, # Present_Position - } + self.packet_handler.data[motor_index] = get_default_motor_values(motor_index) def txRxPacket(self): # noqa: N802 return COMM_SUCCESS @@ -78,6 +82,9 @@ def __init__(self, port_handler, packet_handler, address, bytes): self.address = address def addParam(self, index, data): # noqa: N802 + # Initialize motor default values + if index not in self.packet_handler.data: + self.packet_handler.data[index] = get_default_motor_values(index) self.changeParam(index, data) def txPacket(self): # noqa: N802 diff --git a/tests/mock_scservo_sdk.py b/tests/mock_scservo_sdk.py index 06c4283a6..596978c00 100644 --- a/tests/mock_scservo_sdk.py +++ b/tests/mock_scservo_sdk.py @@ -18,6 +18,29 @@ def convert_to_bytes(value, bytes): return value +def get_default_motor_values(motor_index): + return { + # Key (int) are from SCS_SERIES_CONTROL_TABLE + 5: motor_index, # ID + 6: DEFAULT_BAUDRATE, # Baud_rate + 10: 0, # Drive_Mode + 21: 32, # P_Coefficient + 22: 32, # D_Coefficient + 23: 0, # I_Coefficient + 40: 0, # Torque_Enable + 41: 254, # Acceleration + 31: -2047, # Offset + 33: 0, # Mode + 55: 1, # Lock + # Set 2560 since calibration values for Aloha gripper is between start_pos=2499 and end_pos=3144 + # For other joints, 2560 will be autocorrected to be in calibration range + 56: 2560, # Present_Position + 58: 0, # Present_Speed + 69: 0, # Present_Current + 85: 150, # Maximum_Acceleration + } + + class PortHandler: def __init__(self, port): self.port = port @@ -52,18 +75,9 @@ def __init__(self, port_handler, packet_handler, address, bytes): self.packet_handler = packet_handler def addParam(self, motor_index): # noqa: N802 + # Initialize motor default values if motor_index not in self.packet_handler.data: - # Initialize motor default values - self.packet_handler.data[motor_index] = { - # Key (int) are from X_SERIES_CONTROL_TABLE - 7: motor_index, # ID - 8: DEFAULT_BAUDRATE, # Baud_rate - 10: 0, # Drive_Mode - 64: 0, # Torque_Enable - # Set 2560 since calibration values for Aloha gripper is between start_pos=2499 and end_pos=3144 - # For other joints, 2560 will be autocorrected to be in calibration range - 132: 2560, # Present_Position - } + self.packet_handler.data[motor_index] = get_default_motor_values(motor_index) def txRxPacket(self): # noqa: N802 return COMM_SUCCESS @@ -78,6 +92,8 @@ def __init__(self, port_handler, packet_handler, address, bytes): self.address = address def addParam(self, index, data): # noqa: N802 + if index not in self.packet_handler.data: + self.packet_handler.data[index] = get_default_motor_values(index) self.changeParam(index, data) def txPacket(self): # noqa: N802