diff --git a/data/.lfs/xarm7.tar.gz b/data/.lfs/xarm7.tar.gz new file mode 100644 index 0000000000..b19d8d919a --- /dev/null +++ b/data/.lfs/xarm7.tar.gz @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:dc5c96439cc415d7d7b1296363b5684354aaef22c7dbe8e50bce81183401511c +size 6297600 diff --git a/dimos/e2e_tests/test_simulation_module.py b/dimos/e2e_tests/test_simulation_module.py new file mode 100644 index 0000000000..32ac01a183 --- /dev/null +++ b/dimos/e2e_tests/test_simulation_module.py @@ -0,0 +1,86 @@ +# Copyright 2025-2026 Dimensional Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""End-to-end tests for the simulation module.""" + +import os + +import pytest + +from dimos.msgs.sensor_msgs import JointCommand, JointState, RobotState + + +def _positions_within_tolerance( + positions: list[float], + target: list[float], + tolerance: float, +) -> bool: + if len(positions) < len(target): + return False + return all(abs(positions[i] - target[i]) <= tolerance for i in range(len(target))) + + +@pytest.mark.skipif(bool(os.getenv("CI")), reason="LCM doesn't work in CI.") +@pytest.mark.e2e +class TestSimulationModuleE2E: + def test_xarm7_joint_state_published(self, lcm_spy, start_blueprint) -> None: + joint_state_topic = "/xarm/joint_states#sensor_msgs.JointState" + lcm_spy.save_topic(joint_state_topic) + + start_blueprint("simulation-xarm7") + lcm_spy.wait_for_saved_topic(joint_state_topic, timeout=15.0) + + with lcm_spy._messages_lock: + raw_joint_state = lcm_spy.messages[joint_state_topic][0] + + joint_state = JointState.lcm_decode(raw_joint_state) + assert len(joint_state.name) == 8 + assert len(joint_state.position) == 8 + + def test_xarm7_robot_state_published(self, lcm_spy, start_blueprint) -> None: + robot_state_topic = "/xarm/robot_state#sensor_msgs.RobotState" + lcm_spy.save_topic(robot_state_topic) + + start_blueprint("simulation-xarm7") + lcm_spy.wait_for_saved_topic(robot_state_topic, timeout=15.0) + + with lcm_spy._messages_lock: + raw_robot_state = lcm_spy.messages[robot_state_topic][0] + + robot_state = RobotState.lcm_decode(raw_robot_state) + assert robot_state.mt_able in (0, 1) + + def test_xarm7_joint_command_updates_joint_state(self, lcm_spy, start_blueprint) -> None: + joint_state_topic = "/xarm/joint_states#sensor_msgs.JointState" + joint_command_topic = "/xarm/joint_position_command#sensor_msgs.JointCommand" + lcm_spy.save_topic(joint_state_topic) + + start_blueprint("simulation-xarm7") + lcm_spy.wait_for_saved_topic(joint_state_topic, timeout=15.0) + + target_positions = [0.2, -0.2, 0.1, -0.1, 0.15, -0.15, 0.05] + lcm_spy.publish(joint_command_topic, JointCommand(positions=target_positions)) + + tolerance = 0.03 + lcm_spy.wait_for_message_result( + joint_state_topic, + JointState, + predicate=lambda msg: _positions_within_tolerance( + list(msg.position), + target_positions, + tolerance, + ), + fail_message=("joint_state did not reach commanded positions within tolerance"), + timeout=10.0, + ) diff --git a/dimos/robot/all_blueprints.py b/dimos/robot/all_blueprints.py index a4a953062c..c33e56116b 100644 --- a/dimos/robot/all_blueprints.py +++ b/dimos/robot/all_blueprints.py @@ -64,6 +64,7 @@ "xarm-perception": "dimos.manipulation.manipulation_blueprints:xarm_perception", "xarm6-planner-only": "dimos.manipulation.manipulation_blueprints:xarm6_planner_only", "xarm7-planner-coordinator": "dimos.manipulation.manipulation_blueprints:xarm7_planner_coordinator", + "xarm7-trajectory-sim": "dimos.simulation.sim_blueprints:xarm7_trajectory_sim", } @@ -103,6 +104,7 @@ "replanning_a_star_planner": "dimos.navigation.replanning_a_star.module", "rerun_bridge": "dimos.visualization.rerun.bridge", "ros_nav": "dimos.navigation.rosnav", + "simulation": "dimos.simulation.manipulators.sim_module", "spatial_memory": "dimos.perception.spatial_perception", "speak_skill": "dimos.agents.skills.speak_skill", "temporal_memory": "dimos.perception.experimental.temporal_memory.temporal_memory", diff --git a/dimos/simulation/engines/__init__.py b/dimos/simulation/engines/__init__.py new file mode 100644 index 0000000000..d437f9a7cd --- /dev/null +++ b/dimos/simulation/engines/__init__.py @@ -0,0 +1,25 @@ +"""Simulation engines for manipulator backends.""" + +from __future__ import annotations + +from typing import Literal + +from dimos.simulation.engines.base import SimulationEngine +from dimos.simulation.engines.mujoco_engine import MujocoEngine + +EngineType = Literal["mujoco"] + +_ENGINES: dict[EngineType, type[SimulationEngine]] = { + "mujoco": MujocoEngine, +} + + +def get_engine(engine_name: EngineType) -> type[SimulationEngine]: + return _ENGINES[engine_name] + + +__all__ = [ + "EngineType", + "SimulationEngine", + "get_engine", +] diff --git a/dimos/simulation/engines/base.py b/dimos/simulation/engines/base.py new file mode 100644 index 0000000000..d450614c62 --- /dev/null +++ b/dimos/simulation/engines/base.py @@ -0,0 +1,84 @@ +# Copyright 2025-2026 Dimensional Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Base interfaces for simulator engines.""" + +from __future__ import annotations + +from abc import ABC, abstractmethod +from typing import TYPE_CHECKING + +if TYPE_CHECKING: + from pathlib import Path + + from dimos.msgs.sensor_msgs import JointState + + +class SimulationEngine(ABC): + """Abstract base class for a simulator engine instance.""" + + def __init__(self, config_path: Path, headless: bool) -> None: + self._config_path = config_path + self._headless = headless + + @property + def config_path(self) -> Path: + return self._config_path + + @property + def headless(self) -> bool: + return self._headless + + @abstractmethod + def connect(self) -> bool: + """Connect to simulation and start the engine.""" + + @abstractmethod + def disconnect(self) -> bool: + """Disconnect from simulation and stop the engine.""" + + @property + @abstractmethod + def connected(self) -> bool: + """Whether the engine is connected.""" + + @property + @abstractmethod + def num_joints(self) -> int: + """Number of joints for the loaded robot.""" + + @property + @abstractmethod + def joint_names(self) -> list[str]: + """Joint names for the loaded robot.""" + + @abstractmethod + def read_joint_positions(self) -> list[float]: + """Read joint positions in radians.""" + + @abstractmethod + def read_joint_velocities(self) -> list[float]: + """Read joint velocities in rad/s.""" + + @abstractmethod + def read_joint_efforts(self) -> list[float]: + """Read joint efforts in Nm.""" + + @abstractmethod + def write_joint_command(self, command: JointState) -> None: + """Command joints using a JointState message.""" + + @abstractmethod + def hold_current_position(self) -> None: + """Hold current joint positions.""" diff --git a/dimos/simulation/engines/mujoco_engine.py b/dimos/simulation/engines/mujoco_engine.py new file mode 100644 index 0000000000..6dc8e46595 --- /dev/null +++ b/dimos/simulation/engines/mujoco_engine.py @@ -0,0 +1,300 @@ +# Copyright 2025-2026 Dimensional Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""MuJoCo simulation engine implementation.""" + +from __future__ import annotations + +import threading +import time +from typing import TYPE_CHECKING + +import mujoco +import mujoco.viewer as viewer # type: ignore[import-untyped] + +from dimos.simulation.engines.base import SimulationEngine +from dimos.simulation.utils.xml_parser import JointMapping, build_joint_mappings +from dimos.utils.logging_config import setup_logger + +if TYPE_CHECKING: + from pathlib import Path + + from dimos.msgs.sensor_msgs import JointState + +logger = setup_logger() + + +class MujocoEngine(SimulationEngine): + """ + MuJoCo simulation engine. + + - starts MuJoCo simulation engine + - loads robot/environment into simulation + - applies control commands + """ + + def __init__(self, config_path: Path, headless: bool) -> None: + super().__init__(config_path=config_path, headless=headless) + + xml_path = self._resolve_xml_path(config_path) + self._model = mujoco.MjModel.from_xml_path(str(xml_path)) + self._xml_path = xml_path + + self._data = mujoco.MjData(self._model) + self._joint_mappings = build_joint_mappings(self._xml_path, self._model) + self._joint_names = [mapping.name for mapping in self._joint_mappings] + self._num_joints = len(self._joint_names) + timestep = float(self._model.opt.timestep) + self._control_frequency = 1.0 / timestep if timestep > 0.0 else 100.0 + + self._connected = False + self._lock = threading.Lock() + self._stop_event = threading.Event() + self._sim_thread: threading.Thread | None = None + + self._joint_positions = [0.0] * self._num_joints + self._joint_velocities = [0.0] * self._num_joints + self._joint_efforts = [0.0] * self._num_joints + + self._joint_position_targets = [0.0] * self._num_joints + self._joint_velocity_targets = [0.0] * self._num_joints + self._joint_effort_targets = [0.0] * self._num_joints + self._command_mode = "position" + for i, mapping in enumerate(self._joint_mappings): + current_pos = self._current_position(mapping) + self._joint_position_targets[i] = current_pos + self._joint_positions[i] = current_pos + + def _resolve_xml_path(self, config_path: Path) -> Path: + if config_path is None: + raise ValueError("config_path is required for MuJoCo simulation loading") + resolved = config_path.expanduser() + xml_path = resolved / "scene.xml" if resolved.is_dir() else resolved + if not xml_path.exists(): + raise FileNotFoundError(f"MuJoCo XML not found: {xml_path}") + return xml_path + + def _current_position(self, mapping: JointMapping) -> float: + if mapping.joint_id is not None and mapping.qpos_adr is not None: + return float(self._data.qpos[mapping.qpos_adr]) + if mapping.tendon_qpos_adrs: + return float( + sum(self._data.qpos[adr] for adr in mapping.tendon_qpos_adrs) + / len(mapping.tendon_qpos_adrs) + ) + if mapping.actuator_id is not None: + return float(self._data.actuator_length[mapping.actuator_id]) + return 0.0 + + def _apply_control(self) -> None: + with self._lock: + if self._command_mode == "effort": + targets = list(self._joint_effort_targets) + elif self._command_mode == "velocity": + targets = list(self._joint_velocity_targets) + elif self._command_mode == "position": + targets = list(self._joint_position_targets) + for i, mapping in enumerate(self._joint_mappings): + if mapping.actuator_id is None: + continue + if i < len(targets): + self._data.ctrl[mapping.actuator_id] = targets[i] + + def _update_joint_state(self) -> None: + with self._lock: + for i, mapping in enumerate(self._joint_mappings): + if mapping.joint_id is not None: + if mapping.qpos_adr is not None: + self._joint_positions[i] = float(self._data.qpos[mapping.qpos_adr]) + if mapping.dof_adr is not None: + self._joint_velocities[i] = float(self._data.qvel[mapping.dof_adr]) + self._joint_efforts[i] = float(self._data.qfrc_actuator[mapping.dof_adr]) + continue + + if mapping.tendon_qpos_adrs: + pos_sum = sum(self._data.qpos[adr] for adr in mapping.tendon_qpos_adrs) + count = len(mapping.tendon_qpos_adrs) + self._joint_positions[i] = float(pos_sum / count) + if mapping.tendon_dof_adrs: + vel_sum = sum(self._data.qvel[adr] for adr in mapping.tendon_dof_adrs) + self._joint_velocities[i] = float(vel_sum / len(mapping.tendon_dof_adrs)) + else: + self._joint_velocities[i] = 0.0 + elif mapping.actuator_id is not None: + self._joint_positions[i] = float( + self._data.actuator_length[mapping.actuator_id] + ) + self._joint_velocities[i] = 0.0 + + if mapping.actuator_id is not None: + self._joint_efforts[i] = float(self._data.actuator_force[mapping.actuator_id]) + + def connect(self) -> bool: + try: + logger.info(f"{self.__class__.__name__}: connect()") + with self._lock: + self._connected = True + self._stop_event.clear() + + if self._sim_thread is None or not self._sim_thread.is_alive(): + self._sim_thread = threading.Thread( + target=self._sim_loop, + name=f"{self.__class__.__name__}Sim", + daemon=True, + ) + self._sim_thread.start() + return True + except Exception as e: + logger.error(f"{self.__class__.__name__}: connect() failed: {e}") + return False + + def disconnect(self) -> bool: + try: + logger.info(f"{self.__class__.__name__}: disconnect()") + with self._lock: + self._connected = False + self._stop_event.set() + if self._sim_thread and self._sim_thread.is_alive(): + self._sim_thread.join(timeout=2.0) + self._sim_thread = None + return True + except Exception as e: + logger.error(f"{self.__class__.__name__}: disconnect() failed: {e}") + return False + + def _sim_loop(self) -> None: + logger.info(f"{self.__class__.__name__}: sim loop started") + dt = 1.0 / self._control_frequency + + def _step_once(sync_viewer: bool) -> None: + loop_start = time.time() + self._apply_control() + mujoco.mj_step(self._model, self._data) + if sync_viewer: + m_viewer.sync() + self._update_joint_state() + + elapsed = time.time() - loop_start + sleep_time = dt - elapsed + if sleep_time > 0: + time.sleep(sleep_time) + + if self._headless: + while not self._stop_event.is_set(): + _step_once(sync_viewer=False) + else: + with viewer.launch_passive( + self._model, self._data, show_left_ui=False, show_right_ui=False + ) as m_viewer: + while m_viewer.is_running() and not self._stop_event.is_set(): + _step_once(sync_viewer=True) + + logger.info(f"{self.__class__.__name__}: sim loop stopped") + + @property + def connected(self) -> bool: + with self._lock: + return self._connected + + @property + def num_joints(self) -> int: + return self._num_joints + + @property + def joint_names(self) -> list[str]: + return list(self._joint_names) + + @property + def model(self) -> mujoco.MjModel: + return self._model + + @property + def joint_positions(self) -> list[float]: + with self._lock: + return list(self._joint_positions) + + @property + def joint_velocities(self) -> list[float]: + with self._lock: + return list(self._joint_velocities) + + @property + def joint_efforts(self) -> list[float]: + with self._lock: + return list(self._joint_efforts) + + @property + def control_frequency(self) -> float: + return self._control_frequency + + def read_joint_positions(self) -> list[float]: + return self.joint_positions + + def read_joint_velocities(self) -> list[float]: + return self.joint_velocities + + def read_joint_efforts(self) -> list[float]: + return self.joint_efforts + + def write_joint_command(self, command: JointState) -> None: + if command.position: + self._command_mode = "position" + self._set_position_targets(command.position) + return + if command.velocity: + self._command_mode = "velocity" + self._set_velocity_targets(command.velocity) + return + if command.effort: + self._command_mode = "effort" + self._set_effort_targets(command.effort) + return + + def _set_position_targets(self, positions: list[float]) -> None: + if len(positions) > self._num_joints: + raise ValueError( + f"Position command has {len(positions)} joints, expected at most {self._num_joints}" + ) + with self._lock: + for i in range(len(positions)): + self._joint_position_targets[i] = float(positions[i]) + + def _set_velocity_targets(self, velocities: list[float]) -> None: + if len(velocities) > self._num_joints: + raise ValueError( + f"Velocity command has {len(velocities)} joints, expected at most {self._num_joints}" + ) + with self._lock: + for i in range(len(velocities)): + self._joint_velocity_targets[i] = float(velocities[i]) + + def _set_effort_targets(self, efforts: list[float]) -> None: + if len(efforts) > self._num_joints: + raise ValueError( + f"Effort command has {len(efforts)} joints, expected at most {self._num_joints}" + ) + with self._lock: + for i in range(len(efforts)): + self._joint_effort_targets[i] = float(efforts[i]) + + def hold_current_position(self) -> None: + with self._lock: + self._command_mode = "position" + for i, mapping in enumerate(self._joint_mappings): + self._joint_position_targets[i] = self._current_position(mapping) + + +__all__ = [ + "MujocoEngine", +] diff --git a/dimos/simulation/manipulators/__init__.py b/dimos/simulation/manipulators/__init__.py new file mode 100644 index 0000000000..816de0a18d --- /dev/null +++ b/dimos/simulation/manipulators/__init__.py @@ -0,0 +1,54 @@ +# Copyright 2025 Dimensional Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Simulation manipulator utilities.""" + +from __future__ import annotations + +from typing import TYPE_CHECKING + +if TYPE_CHECKING: + from dimos.simulation.manipulators.sim_manip_interface import SimManipInterface + from dimos.simulation.manipulators.sim_module import ( + SimulationModule, + SimulationModuleConfig, + simulation, + ) + +__all__ = [ + "SimManipInterface", + "SimulationModule", + "SimulationModuleConfig", + "simulation", +] + + +def __getattr__(name: str): # type: ignore[no-untyped-def] + if name == "SimManipInterface": + from dimos.simulation.manipulators.sim_manip_interface import SimManipInterface + + return SimManipInterface + if name in {"SimulationModule", "SimulationModuleConfig", "simulation"}: + from dimos.simulation.manipulators.sim_module import ( + SimulationModule, + SimulationModuleConfig, + simulation, + ) + + return { + "SimulationModule": SimulationModule, + "SimulationModuleConfig": SimulationModuleConfig, + "simulation": simulation, + }[name] + raise AttributeError(f"module {__name__!r} has no attribute {name!r}") diff --git a/dimos/simulation/manipulators/sim_manip_interface.py b/dimos/simulation/manipulators/sim_manip_interface.py new file mode 100644 index 0000000000..c829f0c864 --- /dev/null +++ b/dimos/simulation/manipulators/sim_manip_interface.py @@ -0,0 +1,200 @@ +# Copyright 2025-2026 Dimensional Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Simulation-agnostic manipulator interface.""" + +from __future__ import annotations + +import logging +import math +from typing import TYPE_CHECKING + +from dimos.hardware.manipulators.spec import ControlMode, JointLimits, ManipulatorInfo +from dimos.msgs.sensor_msgs import JointState + +if TYPE_CHECKING: + from dimos.simulation.engines.base import SimulationEngine + + +class SimManipInterface: + """Adapter wrapper around a simulation engine to provide a uniform manipulator API.""" + + def __init__(self, engine: SimulationEngine) -> None: + self.logger = logging.getLogger(self.__class__.__name__) + self._engine = engine + self._joint_names = list(engine.joint_names) + self._dof = len(self._joint_names) + self._connected = False + self._servos_enabled = False + self._control_mode = ControlMode.POSITION + self._error_code = 0 + self._error_message = "" + + def connect(self) -> bool: + """Connect to the simulation engine.""" + try: + self.logger.info("Connecting to simulation engine...") + if not self._engine.connect(): + self.logger.error("Failed to connect to simulation engine") + return False + if self._engine.connected: + self._connected = True + self._servos_enabled = True + self._joint_names = list(self._engine.joint_names) + self._dof = len(self._joint_names) + self.logger.info( + "Successfully connected to simulation", + extra={"dof": self._dof}, + ) + return True + self.logger.error("Failed to connect to simulation engine") + return False + except Exception as exc: + self.logger.error(f"Sim connection failed: {exc}") + return False + + def disconnect(self) -> bool: + """Disconnect from simulation.""" + try: + return self._engine.disconnect() + except Exception as exc: + self._connected = False + self.logger.error(f"Sim disconnection failed: {exc}") + return False + + def is_connected(self) -> bool: + return bool(self._connected and self._engine.connected) + + def get_info(self) -> ManipulatorInfo: + vendor = "Simulation" + model = "Simulation" + dof = self._dof + return ManipulatorInfo( + vendor=vendor, + model=model, + dof=dof, + firmware_version=None, + serial_number=None, + ) + + def get_dof(self) -> int: + return self._dof + + def get_joint_names(self) -> list[str]: + return list(self._joint_names) + + def get_limits(self) -> JointLimits: + lower = [-math.pi] * self._dof + upper = [math.pi] * self._dof + max_vel_rad = math.radians(180.0) + return JointLimits( + position_lower=lower, + position_upper=upper, + velocity_max=[max_vel_rad] * self._dof, + ) + + def set_control_mode(self, mode: ControlMode) -> bool: + self._control_mode = mode + return True + + def get_control_mode(self) -> ControlMode: + return self._control_mode + + def read_joint_positions(self) -> list[float]: + positions = self._engine.read_joint_positions() + return positions[: self._dof] + + def read_joint_velocities(self) -> list[float]: + velocities = self._engine.read_joint_velocities() + return velocities[: self._dof] + + def read_joint_efforts(self) -> list[float]: + efforts = self._engine.read_joint_efforts() + return efforts[: self._dof] + + def read_state(self) -> dict[str, int]: + velocities = self.read_joint_velocities() + is_moving = any(abs(v) > 1e-4 for v in velocities) + mode_int = list(ControlMode).index(self._control_mode) + return { + "state": 1 if is_moving else 0, + "mode": mode_int, + } + + def read_error(self) -> tuple[int, str]: + return self._error_code, self._error_message + + def write_joint_positions(self, positions: list[float]) -> bool: + if not self._servos_enabled: + return False + self._control_mode = ControlMode.POSITION + self._engine.write_joint_command(JointState(position=positions[: self._dof])) + return True + + def write_joint_velocities(self, velocities: list[float]) -> bool: + if not self._servos_enabled: + return False + self._control_mode = ControlMode.VELOCITY + self._engine.write_joint_command(JointState(velocity=velocities[: self._dof])) + return True + + def write_joint_efforts(self, efforts: list[float]) -> bool: + if not self._servos_enabled: + return False + self._control_mode = ControlMode.TORQUE + self._engine.write_joint_command(JointState(effort=efforts[: self._dof])) + return True + + def write_stop(self) -> bool: + self._engine.hold_current_position() + return True + + def write_enable(self, enable: bool) -> bool: + self._servos_enabled = enable + return True + + def read_enabled(self) -> bool: + return self._servos_enabled + + def write_clear_errors(self) -> bool: + self._error_code = 0 + self._error_message = "" + return True + + def read_cartesian_position(self) -> dict[str, float] | None: + return None + + def write_cartesian_position( + self, + pose: dict[str, float], + velocity: float = 1.0, + ) -> bool: + _pose = pose + _velocity = velocity + return False + + def read_gripper_position(self) -> float | None: + return None + + def write_gripper_position(self, position: float) -> bool: + _ = position + return False + + def read_force_torque(self) -> list[float] | None: + return None + + +__all__ = [ + "SimManipInterface", +] diff --git a/dimos/simulation/manipulators/sim_module.py b/dimos/simulation/manipulators/sim_module.py new file mode 100644 index 0000000000..4f1bb986d3 --- /dev/null +++ b/dimos/simulation/manipulators/sim_module.py @@ -0,0 +1,247 @@ +# Copyright 2025-2026 Dimensional Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Simulator-agnostic manipulator simulation module.""" + +from __future__ import annotations + +from dataclasses import dataclass +import threading +import time +from typing import TYPE_CHECKING, Any + +from reactivex.disposable import Disposable + +from dimos.core import In, Module, Out, rpc +from dimos.core.module import ModuleConfig +from dimos.msgs.sensor_msgs import JointCommand, JointState, RobotState +from dimos.simulation.engines import EngineType, get_engine +from dimos.simulation.manipulators.sim_manip_interface import SimManipInterface + +if TYPE_CHECKING: + from collections.abc import Callable + from pathlib import Path + + +@dataclass(kw_only=True) +class SimulationModuleConfig(ModuleConfig): + engine: EngineType + config_path: Path | Callable[[], Path] + headless: bool = False + + +class SimulationModule(Module[SimulationModuleConfig]): + """Module wrapper for manipulator simulation across engines.""" + + default_config = SimulationModuleConfig + config: SimulationModuleConfig + + joint_state: Out[JointState] + robot_state: Out[RobotState] + joint_position_command: In[JointCommand] + joint_velocity_command: In[JointCommand] + + MIN_CONTROL_RATE = 1.0 + + def __init__(self, *args: Any, **kwargs: Any) -> None: + super().__init__(*args, **kwargs) + self._backend: SimManipInterface | None = None + self._control_rate = 100.0 + self._monitor_rate = 100.0 + self._joint_prefix = "joint" + self._stop_event = threading.Event() + self._control_thread: threading.Thread | None = None + self._monitor_thread: threading.Thread | None = None + self._command_lock = threading.Lock() + self._pending_positions: list[float] | None = None + self._pending_velocities: list[float] | None = None + + def _create_backend(self) -> SimManipInterface: + engine_cls = get_engine(self.config.engine) + config_path = ( + self.config.config_path() + if callable(self.config.config_path) + else self.config.config_path + ) + engine = engine_cls( + config_path=config_path, + headless=self.config.headless, + ) + return SimManipInterface(engine=engine) + + @rpc + def start(self) -> None: + super().start() + if self._backend is None: + self._backend = self._create_backend() + if not self._backend.connect(): + raise RuntimeError("Failed to connect to simulation backend") + self._backend.write_enable(True) + + self._disposables.add( + Disposable(self.joint_position_command.subscribe(self._on_joint_position_command)) + ) + self._disposables.add( + Disposable(self.joint_velocity_command.subscribe(self._on_joint_velocity_command)) + ) + + self._stop_event.clear() + self._control_thread = threading.Thread( + target=self._control_loop, + daemon=True, + name=f"{self.__class__.__name__}-control", + ) + self._monitor_thread = threading.Thread( + target=self._monitor_loop, + daemon=True, + name=f"{self.__class__.__name__}-monitor", + ) + self._control_thread.start() + self._monitor_thread.start() + + @rpc + def stop(self) -> None: + self._stop_event.set() + if self._control_thread and self._control_thread.is_alive(): + self._control_thread.join(timeout=2.0) + if self._monitor_thread and self._monitor_thread.is_alive(): + self._monitor_thread.join(timeout=2.0) + if self._backend: + self._backend.disconnect() + super().stop() + + @rpc + def enable_servos(self) -> bool: + if not self._backend: + return False + return self._backend.write_enable(True) + + @rpc + def disable_servos(self) -> bool: + if not self._backend: + return False + return self._backend.write_enable(False) + + @rpc + def clear_errors(self) -> bool: + if not self._backend: + return False + return self._backend.write_clear_errors() + + @rpc + def emergency_stop(self) -> bool: + if not self._backend: + return False + return self._backend.write_stop() + + def _on_joint_position_command(self, msg: JointCommand) -> None: + with self._command_lock: + self._pending_positions = list(msg.positions) + self._pending_velocities = None + + def _on_joint_velocity_command(self, msg: JointCommand) -> None: + with self._command_lock: + self._pending_velocities = list(msg.positions) + self._pending_positions = None + + def _control_loop(self) -> None: + period = 1.0 / max(self._control_rate, self.MIN_CONTROL_RATE) + next_tick = time.monotonic() # monotonic time used to avoid time drift + while not self._stop_event.is_set(): + with self._command_lock: + positions = ( + None if self._pending_positions is None else list(self._pending_positions) + ) + velocities = ( + None if self._pending_velocities is None else list(self._pending_velocities) + ) + + if self._backend: + if positions is not None: + self._backend.write_joint_positions(positions) + elif velocities is not None: + self._backend.write_joint_velocities(velocities) + dof = self._backend.get_dof() + names = self._resolve_joint_names(dof) + positions = self._backend.read_joint_positions() + velocities = self._backend.read_joint_velocities() + efforts = self._backend.read_joint_efforts() + self.joint_state.publish( + JointState( + frame_id=self.frame_id, + name=names, + position=positions, + velocity=velocities, + effort=efforts, + ) + ) + next_tick += period + sleep_for = next_tick - time.monotonic() + if sleep_for > 0: + if self._stop_event.wait(sleep_for): + break + else: + next_tick = time.monotonic() + + def _monitor_loop(self) -> None: + period = 1.0 / max(self._monitor_rate, self.MIN_CONTROL_RATE) + next_tick = time.monotonic() # monotonic time used to avoid time drift + while not self._stop_event.is_set(): + if not self._backend: + pass + else: + dof = self._backend.get_dof() + self._resolve_joint_names(dof) + positions = self._backend.read_joint_positions() + self._backend.read_joint_velocities() + self._backend.read_joint_efforts() + state = self._backend.read_state() + error_code, _ = self._backend.read_error() + self.robot_state.publish( + RobotState( + state=state.get("state", 0), + mode=state.get("mode", 0), + error_code=error_code, + warn_code=0, + cmdnum=0, + mt_brake=0, + mt_able=1 if self._backend.read_enabled() else 0, + tcp_pose=[], + tcp_offset=[], + joints=[float(p) for p in positions], + ) + ) + next_tick += period + sleep_for = next_tick - time.monotonic() + if sleep_for > 0: + if self._stop_event.wait(sleep_for): + break + else: + next_tick = time.monotonic() + + def _resolve_joint_names(self, dof: int) -> list[str]: + if self._backend: + names = self._backend.get_joint_names() + if len(names) >= dof: + return list(names[:dof]) + return [f"{self._joint_prefix}{i + 1}" for i in range(dof)] + + +simulation = SimulationModule.blueprint + +__all__ = [ + "SimulationModule", + "SimulationModuleConfig", + "simulation", +] diff --git a/dimos/simulation/manipulators/test_sim_module.py b/dimos/simulation/manipulators/test_sim_module.py new file mode 100644 index 0000000000..334e2ce85f --- /dev/null +++ b/dimos/simulation/manipulators/test_sim_module.py @@ -0,0 +1,123 @@ +# Copyright 2025-2026 Dimensional Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from pathlib import Path +import threading + +import pytest + +from dimos.simulation.manipulators.sim_module import SimulationModule + + +class _DummyRPC: + def serve_module_rpc(self, _module) -> None: # type: ignore[no-untyped-def] + return None + + def start(self) -> None: + return None + + def stop(self) -> None: + return None + + +class _FakeBackend: + def __init__(self) -> None: + self._names = ["joint1", "joint2", "joint3"] + + def get_dof(self) -> int: + return len(self._names) + + def get_joint_names(self) -> list[str]: + return list(self._names) + + def read_joint_positions(self) -> list[float]: + return [0.1, 0.2, 0.3] + + def read_joint_velocities(self) -> list[float]: + return [0.0, 0.0, 0.0] + + def read_joint_efforts(self) -> list[float]: + return [0.0, 0.0, 0.0] + + def read_state(self) -> dict[str, int]: + return {"state": 1, "mode": 2} + + def read_error(self) -> tuple[int, str]: + return 0, "" + + def read_enabled(self) -> bool: + return True + + def disconnect(self) -> None: + return None + + +def _run_single_monitor_iteration(module: SimulationModule, monkeypatch) -> None: # type: ignore[no-untyped-def] + def _wait_once(_: float) -> bool: + module._stop_event.set() + raise StopIteration + + monkeypatch.setattr(module._stop_event, "wait", _wait_once) + with pytest.raises(StopIteration): + module._monitor_loop() + + +def _run_single_control_iteration(module: SimulationModule, monkeypatch) -> None: # type: ignore[no-untyped-def] + def _wait_once(_: float) -> bool: + module._stop_event.set() + raise StopIteration + + monkeypatch.setattr(module._stop_event, "wait", _wait_once) + with pytest.raises(StopIteration): + module._control_loop() + + +def test_simulation_module_publishes_joint_state(monkeypatch) -> None: + module = SimulationModule( + engine="mujoco", + config_path=Path("."), + rpc_transport=_DummyRPC, + ) + module._backend = _FakeBackend() # type: ignore[assignment] + module._stop_event = threading.Event() + + joint_states: list[object] = [] + module.joint_state.subscribe(joint_states.append) + try: + _run_single_control_iteration(module, monkeypatch) + finally: + module.stop() + + assert len(joint_states) == 1 + assert joint_states[0].name == ["joint1", "joint2", "joint3"] + + +def test_simulation_module_publishes_robot_state(monkeypatch) -> None: + module = SimulationModule( + engine="mujoco", + config_path=Path("."), + rpc_transport=_DummyRPC, + ) + module._backend = _FakeBackend() # type: ignore[assignment] + module._stop_event = threading.Event() + + robot_states: list[object] = [] + module.robot_state.subscribe(robot_states.append) + try: + _run_single_monitor_iteration(module, monkeypatch) + finally: + module.stop() + + assert len(robot_states) == 1 + assert robot_states[0].state == 1 diff --git a/dimos/simulation/sim_blueprints.py b/dimos/simulation/sim_blueprints.py new file mode 100644 index 0000000000..5733b19ef2 --- /dev/null +++ b/dimos/simulation/sim_blueprints.py @@ -0,0 +1,49 @@ +# Copyright 2025-2026 Dimensional Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + + +from dimos.core.transport import LCMTransport +from dimos.msgs.sensor_msgs import ( # type: ignore[attr-defined] + JointCommand, + JointState, + RobotState, +) +from dimos.msgs.trajectory_msgs import JointTrajectory +from dimos.simulation.manipulators.sim_module import simulation +from dimos.utils.data import get_data + +xarm7_trajectory_sim = simulation( + engine="mujoco", + config_path=lambda: get_data("xarm7") + / "scene.xml", # avoid triggering LFS downloads during tests + headless=True, +).transports( + { + ("joint_state", JointState): LCMTransport("/xarm/joint_states", JointState), + ("robot_state", RobotState): LCMTransport("/xarm/robot_state", RobotState), + ("joint_position_command", JointCommand): LCMTransport( + "/xarm/joint_position_command", JointCommand + ), + ("trajectory", JointTrajectory): LCMTransport("/trajectory", JointTrajectory), + } +) + + +__all__ = [ + "simulation", + "xarm7_trajectory_sim", +] + +if __name__ == "__main__": + xarm7_trajectory_sim.build().loop() diff --git a/dimos/simulation/utils/xml_parser.py b/dimos/simulation/utils/xml_parser.py new file mode 100644 index 0000000000..052657ea95 --- /dev/null +++ b/dimos/simulation/utils/xml_parser.py @@ -0,0 +1,277 @@ +# Copyright 2025-2026 Dimensional Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""MuJoCo XML parsing helpers for joint/actuator metadata.""" + +from __future__ import annotations + +from dataclasses import dataclass +from typing import TYPE_CHECKING +import xml.etree.ElementTree as ET + +import mujoco + +if TYPE_CHECKING: + from pathlib import Path + + +@dataclass(frozen=True) +class JointMapping: + name: str + joint_id: int | None + actuator_id: int | None + qpos_adr: int | None + dof_adr: int | None + tendon_qpos_adrs: tuple[int, ...] + tendon_dof_adrs: tuple[int, ...] + + +@dataclass(frozen=True) +class _ActuatorSpec: + name: str + joint: str | None + tendon: str | None + + +def build_joint_mappings(xml_path: Path, model: mujoco.MjModel) -> list[JointMapping]: + specs = _parse_actuator_specs(xml_path) + if specs: + return _build_joint_mappings_from_specs(specs, model) + if int(model.nu) > 0: + return _build_joint_mappings_from_actuators(model) + return _build_joint_mappings_from_model(model) + + +def _parse_actuator_specs(xml_path: Path) -> list[_ActuatorSpec]: + return _collect_actuator_specs(xml_path.resolve(), seen=set()) + + +def _collect_actuator_specs(xml_path: Path, seen: set[Path]) -> list[_ActuatorSpec]: + if xml_path in seen: + return [] + seen.add(xml_path) + + root = ET.parse(xml_path).getroot() + base_dir = xml_path.parent + specs: list[_ActuatorSpec] = [] + + def walk(node: ET.Element) -> None: + for child in node: + if child.tag == "include": + include_file = child.attrib.get("file") + if include_file: + include_path = (base_dir / include_file).resolve() + specs.extend(_collect_actuator_specs(include_path, seen)) + continue + if child.tag == "actuator": + specs.extend(_parse_actuator_block(child)) + continue + walk(child) + + walk(root) + return specs + + +def _parse_actuator_block(actuator_elem: ET.Element) -> list[_ActuatorSpec]: + specs: list[_ActuatorSpec] = [] + for child in actuator_elem: + joint = child.attrib.get("joint") + tendon = child.attrib.get("tendon") + if not joint and not tendon: + continue + name = child.attrib.get("name") or joint or tendon or "actuator" + specs.append(_ActuatorSpec(name=name, joint=joint, tendon=tendon)) + return specs + + +def _build_joint_mappings_from_specs( + specs: list[_ActuatorSpec], + model: mujoco.MjModel, +) -> list[JointMapping]: + mappings: list[JointMapping] = [] + for spec in specs: + if spec.joint: + mappings.append(_mapping_for_joint(spec, model)) + elif spec.tendon: + mappings.append(_mapping_for_tendon(spec, model)) + return mappings + + +def _mapping_for_joint(spec: _ActuatorSpec, model: mujoco.MjModel) -> JointMapping: + joint_id = mujoco.mj_name2id(model, mujoco.mjtObj.mjOBJ_JOINT, spec.joint) + if joint_id < 0: + raise ValueError(f"Unknown joint '{spec.joint}' in MuJoCo model") + actuator_id = _find_actuator_id_for_joint(model, joint_id, spec.name) + joint_name = mujoco.mj_id2name(model, mujoco.mjtObj.mjOBJ_JOINT, joint_id) or spec.name + return JointMapping( + name=joint_name, + joint_id=joint_id, + actuator_id=actuator_id, + qpos_adr=int(model.jnt_qposadr[joint_id]), + dof_adr=int(model.jnt_dofadr[joint_id]), + tendon_qpos_adrs=(), + tendon_dof_adrs=(), + ) + + +def _mapping_for_tendon(spec: _ActuatorSpec, model: mujoco.MjModel) -> JointMapping: + name = spec.name or spec.tendon + if not name: + raise ValueError("Tendon actuator is missing a name and tendon reference") + tendon_id = mujoco.mj_name2id(model, mujoco.mjtObj.mjOBJ_TENDON, spec.tendon) + if tendon_id < 0: + raise ValueError(f"Unknown tendon '{spec.tendon}' in MuJoCo model") + actuator_id = _find_actuator_id_for_tendon(model, tendon_id, spec.name) + joint_ids = _tendon_joint_ids(model, tendon_id) + return JointMapping( + name=name, + joint_id=None, + actuator_id=actuator_id, + qpos_adr=None, + dof_adr=None, + tendon_qpos_adrs=tuple(int(model.jnt_qposadr[joint_id]) for joint_id in joint_ids), + tendon_dof_adrs=tuple(int(model.jnt_dofadr[joint_id]) for joint_id in joint_ids), + ) + + +def _find_actuator_id_for_joint( + model: mujoco.MjModel, + joint_id: int, + actuator_name: str | None, +) -> int | None: + if actuator_name: + act_id = mujoco.mj_name2id(model, mujoco.mjtObj.mjOBJ_ACTUATOR, actuator_name) + if act_id >= 0: + return int(act_id) + for act_id in range(int(model.nu)): + trn_type = int(model.actuator_trntype[act_id]) + if trn_type != int(mujoco.mjtTrn.mjTRN_JOINT): + continue + if int(model.actuator_trnid[act_id, 0]) == joint_id: + return act_id + return None + + +def _find_actuator_id_for_tendon( + model: mujoco.MjModel, + tendon_id: int, + actuator_name: str | None, +) -> int | None: + if actuator_name: + act_id = mujoco.mj_name2id(model, mujoco.mjtObj.mjOBJ_ACTUATOR, actuator_name) + if act_id >= 0: + return int(act_id) + for act_id in range(int(model.nu)): + trn_type = int(model.actuator_trntype[act_id]) + if trn_type != int(mujoco.mjtTrn.mjTRN_TENDON): + continue + if int(model.actuator_trnid[act_id, 0]) == tendon_id: + return act_id + return None + + +def _tendon_joint_ids(model: mujoco.MjModel, tendon_id: int) -> tuple[int, ...]: + adr = int(model.tendon_adr[tendon_id]) + num = int(model.tendon_num[tendon_id]) + joint_ids: list[int] = [] + for wrap_id in range(adr, adr + num): + wrap_type = int(model.wrap_type[wrap_id]) + if wrap_type == int(mujoco.mjtWrap.mjWRAP_JOINT): + joint_ids.append(int(model.wrap_objid[wrap_id])) + return tuple(joint_ids) + + +def _build_joint_mappings_from_actuators(model: mujoco.MjModel) -> list[JointMapping]: + mappings: list[JointMapping] = [] + for actuator_id in range(int(model.nu)): + actuator_name = mujoco.mj_id2name(model, mujoco.mjtObj.mjOBJ_ACTUATOR, actuator_id) + name = actuator_name or f"actuator{actuator_id}" + trn_type = int(model.actuator_trntype[actuator_id]) + if trn_type == int(mujoco.mjtTrn.mjTRN_JOINT): + joint_id = int(model.actuator_trnid[actuator_id, 0]) + joint_name = mujoco.mj_id2name(model, mujoco.mjtObj.mjOBJ_JOINT, joint_id) + mappings.append( + JointMapping( + name=joint_name or name, + joint_id=joint_id, + actuator_id=actuator_id, + qpos_adr=int(model.jnt_qposadr[joint_id]), + dof_adr=int(model.jnt_dofadr[joint_id]), + tendon_qpos_adrs=(), + tendon_dof_adrs=(), + ) + ) + continue + + if trn_type == int(mujoco.mjtTrn.mjTRN_TENDON): + tendon_id = int(model.actuator_trnid[actuator_id, 0]) + tendon_name = mujoco.mj_id2name(model, mujoco.mjtObj.mjOBJ_TENDON, tendon_id) + if not actuator_name and tendon_name: + name = tendon_name + joint_ids = _tendon_joint_ids(model, tendon_id) + mappings.append( + JointMapping( + name=name, + joint_id=None, + actuator_id=actuator_id, + qpos_adr=None, + dof_adr=None, + tendon_qpos_adrs=tuple( + int(model.jnt_qposadr[joint_id]) for joint_id in joint_ids + ), + tendon_dof_adrs=tuple( + int(model.jnt_dofadr[joint_id]) for joint_id in joint_ids + ), + ) + ) + continue + + mappings.append( + JointMapping( + name=name, + joint_id=None, + actuator_id=actuator_id, + qpos_adr=None, + dof_adr=None, + tendon_qpos_adrs=(), + tendon_dof_adrs=(), + ) + ) + + return mappings + + +def _build_joint_mappings_from_model(model: mujoco.MjModel) -> list[JointMapping]: + mappings: list[JointMapping] = [] + for joint_id in range(int(model.njnt)): + jnt_type = int(model.jnt_type[joint_id]) + if jnt_type not in ( + int(mujoco.mjtJoint.mjJNT_HINGE), + int(mujoco.mjtJoint.mjJNT_SLIDE), + ): + continue + joint_name = mujoco.mj_id2name(model, mujoco.mjtObj.mjOBJ_JOINT, joint_id) + name = joint_name or f"joint{joint_id}" + mappings.append( + JointMapping( + name=name, + joint_id=joint_id, + actuator_id=None, + qpos_adr=int(model.jnt_qposadr[joint_id]), + dof_adr=int(model.jnt_dofadr[joint_id]), + tendon_qpos_adrs=(), + tendon_dof_adrs=(), + ) + ) + return mappings