diff --git a/configs/gym/pour_water/gym_config.json b/configs/gym/pour_water/gym_config.json index 9048f81..b9cd4a4 100644 --- a/configs/gym/pour_water/gym_config.json +++ b/configs/gym/pour_water/gym_config.json @@ -1,6 +1,6 @@ { "id": "PourWater-v3", - "max_episodes": 5, + "max_episodes": 10, "env": { "events": { "random_light": { @@ -258,11 +258,40 @@ } }, "dataset": { - "robot_meta": { - "arm_dofs": 12, - "control_freq": 25, - "control_parts": ["left_arm", "left_eef", "right_arm", "right_eef"], - "min_len_steps": 5 + "lerobot": { + "func": "LeRobotRecorder", + "mode": "save", + "params": { + "save_path": "/home/dex/projects/yuanhaonan/embodichain/outputs/data_example", + "id": 0, + "robot_meta": { + "robot_type": "CobotMagic", + "arm_dofs": 12, + "control_freq": 25, + "control_parts": ["left_arm", "left_eef", "right_arm", "right_eef"], + "observation": { + "vision": { + "cam_high": ["mask"], + "cam_right_wrist": ["mask"], + "cam_left_wrist": ["mask"] + }, + "states": ["qpos"], + "exteroception": ["cam_high", "cam_right_wrist", "cam_left_wrist"] + }, + "action": "qpos_with_eef_pose", + "min_len_steps": 5 + }, + "instruction": { + "lang": "Pour water from bottle to cup" + }, + "extra": { + "scene_type": "Commercial", + "task_description": "Pour water", + "data_type": "sim" + }, + "use_videos": true, + "export_success_only": false + } } } }, diff --git a/configs/gym/pour_water/gym_config_simple.json b/configs/gym/pour_water/gym_config_simple.json new file mode 100644 index 0000000..1ddbad9 --- /dev/null +++ b/configs/gym/pour_water/gym_config_simple.json @@ -0,0 +1,330 @@ +{ + "id": "PourWater-v3", + "max_episodes": 5, + "env": { + "events": { + "random_light": { + "func": "randomize_light", + "mode": "interval", + "interval_step": 10, + "params": { + "entity_cfg": {"uid": "light_1"}, + "position_range": [[-0.5, -0.5, 2], [0.5, 0.5, 2]], + "color_range": [[0.6, 0.6, 0.6], [1, 1, 1]], + "intensity_range": [50.0, 100.0] + } + }, + "init_bottle_pose": { + "func": "randomize_rigid_object_pose", + "mode": "reset", + "params": { + "entity_cfg": {"uid": "bottle"}, + "position_range": [[-0.08, -0.12, 0.0], [0.08, 0.04, 0.0]], + "relative_position": true + } + }, + "init_cup_pose": { + "func": "randomize_rigid_object_pose", + "mode": "reset", + "params": { + "entity_cfg": {"uid": "cup"}, + "position_range": [[-0.08, -0.04, 0.0], [0.08, 0.12, 0.0]], + "relative_position": true + } + }, + "prepare_extra_attr": { + "func": "prepare_extra_attr", + "mode": "reset", + "params": { + "attrs": [ + { + "name": "object_lengths", + "mode": "callable", + "entity_uids": "all_objects", + "func_name": "compute_object_length", + "func_kwargs": { + "is_svd_frame": true, + "sample_points": 5000 + } + }, + { + "name": "grasp_pose_object", + "mode": "static", + "entity_cfg": { + "uid": "bottle" + }, + "value": [[ + [0.32243, 0.03245, 0.94604, 0.025], + [0.00706, -0.99947, 0.03188, -0.0 ], + [0.94657, -0.0036 , -0.32249, 0.0 ], + [0.0 , 0.0 , 0.0 , 1.0 ] + ]] + }, + { + "name": "left_arm_base_pose", + "mode": "callable", + "entity_cfg": { + "uid": "CobotMagic" + }, + "func_name": "get_link_pose", + "func_kwargs": { + "link_name": "left_arm_base", + "to_matrix": true + } + }, + { + "name": "right_arm_base_pose", + "mode": "callable", + "entity_cfg": { + "uid": "CobotMagic" + }, + "func_name": "get_link_pose", + "func_kwargs": { + "link_name": "right_arm_base", + "to_matrix": true + } + } + ] + } + }, + "register_info_to_env": { + "func": "register_info_to_env", + "mode": "reset", + "params": { + "registry": [ + { + 
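Note on the `dataset` change in gym_config.json above: the bare `robot_meta` block is replaced by a named recorder functor. Each key under `env.dataset` ("lerobot" here) maps to a functor whose `func` is resolved by name from the datasets module and whose `params` are forwarded to the recorder. A minimal sketch of the same section as a Python dict, trimmed to the fields `LeRobotRecorder` actually reads (the save path is a placeholder; all field names are taken from this diff):

```python
# Minimal "env.dataset" section for the LeRobot recorder; plug this into the
# gym config dict before calling config_to_cfg. Omitted fields fall back to
# the recorder's defaults (use_videos=False, export_success_only=False, ...).
minimal_dataset_cfg = {
    "lerobot": {
        "func": "LeRobotRecorder",  # looked up by name in managers.datasets
        "mode": "save",             # the only mode DatasetManager dispatches today
        "params": {
            "save_path": "/path/to/output",  # placeholder
            "robot_meta": {
                "robot_type": "CobotMagic",
                "arm_dofs": 12,
                "control_freq": 25,
            },
            "instruction": {"lang": "Pour water from bottle to cup"},
        },
    }
}
```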
"entity_cfg": { + "uid": "bottle" + }, + "pose_register_params": { + "compute_relative": false, + "compute_pose_object_to_arena": true, + "to_matrix": true + } + }, + { + "entity_cfg": { + "uid": "cup" + }, + "pose_register_params": { + "compute_relative": false, + "compute_pose_object_to_arena": true, + "to_matrix": true + } + }, + { + "entity_cfg": { + "uid": "CobotMagic", + "control_parts": ["left_arm"] + }, + "attrs": ["left_arm_base_pose"], + "pose_register_params": { + "compute_relative": "cup", + "compute_pose_object_to_arena": false, + "to_matrix": true + }, + "prefix": false + }, + { + "entity_cfg": { + "uid": "CobotMagic", + "control_parts": ["right_arm"] + }, + "attrs": ["right_arm_base_pose"], + "pose_register_params": { + "compute_relative": "bottle", + "compute_pose_object_to_arena": false, + "to_matrix": true + }, + "prefix": false + } + ], + "registration": "affordance_datas", + "sim_update": true + } + }, + "random_material": { + "func": "randomize_visual_material", + "mode": "interval", + "interval_step": 10, + "params": { + "entity_cfg": {"uid": "table"}, + "random_texture_prob": 0.0, + "texture_path": "CocoBackground/coco", + "base_color_range": [[0.2, 0.2, 0.2], [1.0, 1.0, 1.0]] + } + }, + "random_cup_material": { + "func": "randomize_visual_material", + "mode": "interval", + "interval_step": 10, + "params": { + "entity_cfg": {"uid": "cup"}, + "random_texture_prob": 0.0, + "texture_path": "CocoBackground/coco", + "base_color_range": [[0.2, 0.2, 0.2], [1.0, 1.0, 1.0]] + } + }, + "random_bottle_material": { + "func": "randomize_visual_material", + "mode": "interval", + "interval_step": 10, + "params": { + "entity_cfg": {"uid": "bottle"}, + "random_texture_prob": 0.0, + "texture_path": "CocoBackground/coco", + "base_color_range": [[0.2, 0.2, 0.2], [1.0, 1.0, 1.0]] + } + }, + "random_robot_init_eef_pose": { + "func": "randomize_robot_eef_pose", + "mode": "reset", + "params": { + "entity_cfg": {"uid": "CobotMagic", "control_parts": ["left_arm", "right_arm"]}, + "position_range": [[-0.01, -0.01, -0.01], [0.01, 0.01, 0]] + } + } + }, + "observations": { + "norm_robot_eef_joint": { + "func": "normalize_robot_joint_data", + "mode": "modify", + "name": "robot/qpos", + "params": { + "joint_ids": [12, 13, 14, 15] + } + } + }, + "dataset": { + "lerobot": { + "func": "LeRobotRecorder", + "mode": "save", + "params": { + "save_path": "/home/dex/projects/yuanhaonan/embodichain/outputs/data_example", + "id": 0, + "robot_meta": { + "robot_type": "CobotMagic", + "arm_dofs": 12, + "control_freq": 25, + "control_parts": ["left_arm", "left_eef", "right_arm", "right_eef"], + "observation": { + "vision": { + "cam_high": ["mask"], + "cam_right_wrist": ["mask"], + "cam_left_wrist": ["mask"] + }, + "states": ["qpos"], + "exteroception": ["cam_high", "cam_right_wrist", "cam_left_wrist"] + }, + "action": "qpos_with_eef_pose", + "min_len_steps": 5 + }, + "instruction": { + "lang": "Pour water from bottle to cup" + }, + "extra": { + "scene_type": "Commercial", + "task_description": "Pour water", + "data_type": "sim" + }, + "use_videos": true, + "export_success_only": false + } + } + } + }, + "robot": { + "uid": "CobotMagic", + "robot_type": "CobotMagic", + "init_pos": [0.0, 0.0, 0.7775], + "init_qpos": [-0.3,0.3,1.0,1.0,-1.2,-1.2,0.0,0.0,0.6,0.6,0.0,0.0,0.05,0.05,0.05,0.05] + }, + "sensor": [ + { + "sensor_type": "Camera", + "uid": "cam_high", + "width": 960, + "height": 540, + "intrinsics": [488.1665344238281, 488.1665344238281, 480, 270], + "extrinsics": { + "eye": [0.35368482807598, 
0.014695524383058989, 1.4517046071614774], + "target": [0.8586357573287919, 0, 0.5232553674540066], + "up": [0.9306678549330372, -0.0005600064212467153, 0.3658647703553347] + } + } + ], + "light": { + "direct": [ + { + "uid": "light_1", + "light_type": "point", + "color": [1.0, 1.0, 1.0], + "intensity": 50.0, + "init_pos": [2, 0, 2], + "radius": 10.0 + } + ] + }, + "background": [ + { + "uid": "table", + "shape": { + "shape_type": "Mesh", + "fpath": "CircleTableSimple/circle_table_simple.ply", + "compute_uv": true + }, + "attrs" : { + "mass": 10.0, + "static_friction": 0.95, + "dynamic_friction": 0.9, + "restitution": 0.01 + }, + "body_scale": [1, 1, 1], + "body_type": "kinematic", + "init_pos": [0.725, 0.0, 0.825], + "init_rot": [0, 90, 0] + } + ], + "rigid_object": [ + { + "uid":"cup", + "shape": { + "shape_type": "Mesh", + "fpath": "PaperCup/paper_cup.ply", + "compute_uv": true + }, + "attrs" : { + "mass": 0.01, + "contact_offset": 0.003, + "rest_offset": 0.001, + "restitution": 0.01, + "max_depenetration_velocity": 1e1, + "min_position_iters": 32, + "min_velocity_iters":8 + }, + "init_pos": [0.75, 0.1, 0.9], + "body_scale":[0.75, 0.75, 1.0], + "max_convex_hull_num": 8 + }, + { + "uid":"bottle", + "shape": { + "shape_type": "Mesh", + "fpath": "ScannedBottle/kashijia_processed.ply", + "compute_uv": true + }, + "attrs" : { + "mass": 0.01, + "contact_offset": 0.003, + "rest_offset": 0.001, + "restitution": 0.01, + "max_depenetration_velocity": 1e1, + "min_position_iters": 32, + "min_velocity_iters":8 + }, + "init_pos": [0.75, -0.1, 0.932], + "body_scale":[1, 1, 1], + "max_convex_hull_num": 8 + } + ] +} \ No newline at end of file diff --git a/embodichain/data/enum.py b/embodichain/data/enum.py index 3716b84..6163789 100644 --- a/embodichain/data/enum.py +++ b/embodichain/data/enum.py @@ -15,6 +15,8 @@ # ---------------------------------------------------------------------------- from enum import Enum, IntEnum +import torch +import numpy as np class SemanticMask(IntEnum): @@ -59,3 +61,168 @@ class Hints(Enum): EndEffector.DEXTROUSHAND.value, ) ARM = (ControlParts.LEFT_ARM.value, ControlParts.RIGHT_ARM.value) + + +class Modality(Enum): + STATES = "states" + STATE_INDICATOR = "state_indicator" + ACTIONS = "actions" + ACTION_INDICATOR = "action_indicator" + IMAGES = "images" + LANG = "lang" + LANG_INDICATOR = "lang_indicator" + GEOMAP = "geomap" # e.g., depth, point cloud, etc. + VISION_LANGUAGE = "vision_language" # e.g., image + lang + + +class JointType(Enum): + QPOS = "qpos" + + +class EefType(Enum): + POSE = "eef_pose" + + +class ActionMode(Enum): + ABSOLUTE = "" + RELATIVE = "delta_" # This indicates the action is relative change with respect to last state. + + +SUPPORTED_PROPRIO_TYPES = [ + ControlParts.LEFT_ARM.value + EefType.POSE.value, + ControlParts.RIGHT_ARM.value + EefType.POSE.value, + ControlParts.LEFT_ARM.value + JointType.QPOS.value, + ControlParts.RIGHT_ARM.value + JointType.QPOS.value, + ControlParts.LEFT_EEF.value + EndEffector.DEXTROUSHAND.value, + ControlParts.RIGHT_EEF.value + EndEffector.DEXTROUSHAND.value, + ControlParts.LEFT_EEF.value + EndEffector.GRIPPER.value, + ControlParts.RIGHT_EEF.value + EndEffector.GRIPPER.value, +] +SUPPORTED_ACTION_TYPES = SUPPORTED_PROPRIO_TYPES + [ + ControlParts.LEFT_ARM.value + ActionMode.RELATIVE.value + JointType.QPOS.value, + ControlParts.RIGHT_ARM.value + ActionMode.RELATIVE.value + JointType.QPOS.value, +] + + +class HandQposNormalizer: + """ + A class for normalizing and denormalizing dexterous hand qpos data. 
+ """ + + def __init__(self): + pass + + @staticmethod + def normalize_hand_qpos( + qpos_data: np.ndarray, + key: str, + agent=None, + robot=None, + ) -> np.ndarray: + """ + Clip and normalize dexterous hand qpos data. + + Args: + qpos_data: Raw qpos data + key: Control part key + agent: LearnableRobot instance (for V2 API) + robot: Robot instance (for V3 API) + + Returns: + Normalized qpos data in range [0, 1] + """ + if isinstance(qpos_data, torch.Tensor): + qpos_data = qpos_data.cpu().numpy() + + if agent is not None: + if key not in [ + ControlParts.LEFT_EEF.value + EndEffector.DEXTROUSHAND.value, + ControlParts.RIGHT_EEF.value + EndEffector.DEXTROUSHAND.value, + ]: + return qpos_data + indices = agent.get_data_index(key, warning=False) + full_limits = agent.get_joint_limits(agent.uid) + limits = full_limits[indices] # shape: [num_joints, 2] + elif robot is not None: + if key not in [ + ControlParts.LEFT_EEF.value + EndEffector.DEXTROUSHAND.value, + ControlParts.RIGHT_EEF.value + EndEffector.DEXTROUSHAND.value, + ]: + if key in [ControlParts.LEFT_EEF.value, ControlParts.RIGHT_EEF.value]: + # Note: In V3, robot does not distinguish between GRIPPER EEF and HAND EEF in uid, + # _data_key_to_control_part maps both to EEF. Under current conditions, normalization + # will not be performed. Please confirm if this is intended. + pass + return qpos_data + indices = robot.get_joint_ids(key, remove_mimic=True) + limits = robot.body_data.qpos_limits[0][indices] # shape: [num_joints, 2] + else: + raise ValueError("Either agent or robot must be provided") + + if isinstance(limits, torch.Tensor): + limits = limits.cpu().numpy() + + qpos_min = limits[:, 0] # Lower limits + qpos_max = limits[:, 1] # Upper limits + + # Step 1: Clip to valid range + qpos_clipped = np.clip(qpos_data, qpos_min, qpos_max) + + # Step 2: Normalize to [0, 1] + qpos_normalized = (qpos_clipped - qpos_min) / (qpos_max - qpos_min + 1e-8) + + return qpos_normalized + + @staticmethod + def denormalize_hand_qpos( + normalized_qpos: torch.Tensor, + key: str, # "left" or "right" + agent=None, + robot=None, + ) -> torch.Tensor: + """ + Denormalize normalized dexterous hand qpos back to actual angle values + + Args: + normalized_qpos: Normalized qpos in range [0, 1] + key: Control part key + robot: Robot instance + + Returns: + Denormalized actual qpos values + """ + + if agent is not None: + if key not in [ + ControlParts.LEFT_EEF.value + EndEffector.DEXTROUSHAND.value, + ControlParts.RIGHT_EEF.value + EndEffector.DEXTROUSHAND.value, + ]: + return normalized_qpos + indices = agent.get_data_index(key, warning=False) + full_limits = agent.get_joint_limits(agent.uid) + limits = full_limits[indices] # shape: [num_joints, 2] + elif robot is not None: + if key not in [ + ControlParts.LEFT_EEF.value + EndEffector.DEXTROUSHAND.value, + ControlParts.RIGHT_EEF.value + EndEffector.DEXTROUSHAND.value, + ]: + if key in [ControlParts.LEFT_EEF.value, ControlParts.RIGHT_EEF.value]: + # Note: In V3, robot does not distinguish between GRIPPER EEF and HAND EEF in uid, + # _data_key_to_control_part maps both to EEF. Under current conditions, denormalization + # will not be performed. Please confirm if this is intended. 
+ pass + return normalized_qpos + indices = robot.get_joint_ids(key, remove_mimic=True) + limits = robot.body_data.qpos_limits[0][indices] # shape: [num_joints, 2] + else: + raise ValueError("Either agent or robot must be provided") + + qpos_min = limits[:, 0].cpu().numpy() # Lower limits + qpos_max = limits[:, 1].cpu().numpy() # Upper limits + + if isinstance(normalized_qpos, torch.Tensor): + normalized_qpos = normalized_qpos.cpu().numpy() + + denormalized_qpos = normalized_qpos * (qpos_max - qpos_min) + qpos_min + + return denormalized_qpos diff --git a/embodichain/lab/gym/envs/base_env.py b/embodichain/lab/gym/envs/base_env.py index 1b2e5b5..201d03b 100644 --- a/embodichain/lab/gym/envs/base_env.py +++ b/embodichain/lab/gym/envs/base_env.py @@ -380,7 +380,7 @@ def get_info(self, **kwargs) -> Dict[str, Any]: info.update(self.evaluate(**kwargs)) return info - def check_truncated(self, obs: EnvObs, info: Dict[str, Any]) -> bool: + def check_truncated(self, obs: EnvObs, info: Dict[str, Any]) -> torch.Tensor: """Check if the episode is truncated. Args: @@ -388,7 +388,7 @@ def check_truncated(self, obs: EnvObs, info: Dict[str, Any]) -> bool: info: The info dictionary. Returns: - True if the episode is truncated, False otherwise. + A boolean tensor indicating truncation for each environment in the batch. """ return torch.zeros(self.num_envs, dtype=torch.bool, device=self.device) diff --git a/embodichain/lab/gym/envs/embodied_env.py b/embodichain/lab/gym/envs/embodied_env.py index 1fc524a..730178f 100644 --- a/embodichain/lab/gym/envs/embodied_env.py +++ b/embodichain/lab/gym/envs/embodied_env.py @@ -20,7 +20,7 @@ import gymnasium as gym from dataclasses import MISSING -from typing import Dict, Union, Sequence, Tuple, Any, List +from typing import Dict, Union, Sequence, Tuple, Any, List, Optional from embodichain.lab.sim.cfg import ( RobotCfg, @@ -42,6 +42,7 @@ from embodichain.lab.gym.envs.managers import ( EventManager, ObservationManager, + DatasetManager, ) from embodichain.lab.gym.utils.registration import register_env from embodichain.utils import configclass, logger @@ -90,9 +91,10 @@ class EnvLightCfg: Please refer to the :class:`embodichain.lab.gym.managers.ObservationManager` class for more details. """ - # TODO: This would be changed to a more generic data pipeline configuration. - dataset: Union[Dict[str, Any], None] = None - """Data pipeline configuration. Defaults to None. + dataset: Union[object, None] = None + """Dataset settings. Defaults to None, in which case no dataset collection is performed. + + Please refer to the :class:`embodichain.lab.gym.managers.DatasetManager` class for more details. """ # Some helper attributes @@ -131,6 +133,7 @@ class EmbodiedEnv(BaseEnv): def __init__(self, cfg: EmbodiedEnvCfg, **kwargs): self.affordance_datas = {} self.action_bank = None + self._force_truncated: bool = False extensions = getattr(cfg, "extensions", {}) or {} @@ -165,14 +168,8 @@ def _init_sim_state(self, **kwargs): if self.cfg.observations: self.observation_manager = ObservationManager(self.cfg.observations, self) - # TODO: A workaround for handling dataset saving, which need history data of obs-action pairs. - # We may improve this by implementing a data manager to handle data saving and online streaming. 
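A round-trip sketch of `HandQposNormalizer` for the V3 `robot` path. The stub robot, its six hand joints, and the [0, 1.3] limits are invented for illustration; the key string is composed exactly as in `SUPPORTED_PROPRIO_TYPES` above.

```python
import numpy as np
import torch

from embodichain.data.enum import ControlParts, EndEffector, HandQposNormalizer


class _StubBodyData:
    # Batched per-joint [min, max] limits for a hypothetical 6-DOF hand.
    qpos_limits = torch.tensor([[[0.0, 1.3]] * 6])


class _StubRobot:
    body_data = _StubBodyData()

    def get_joint_ids(self, key, remove_mimic=True):
        return list(range(6))


key = ControlParts.LEFT_EEF.value + EndEffector.DEXTROUSHAND.value
raw = np.array([-0.2, 0.1, 0.5, 0.9, 1.2, 2.0])  # values outside [0, 1.3] get clipped
robot = _StubRobot()

normed = HandQposNormalizer.normalize_hand_qpos(raw, key, robot=robot)
restored = HandQposNormalizer.denormalize_hand_qpos(normed, key, robot=robot)
assert np.allclose(restored, np.clip(raw, 0.0, 1.3), atol=1e-5)
```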
- if self.cfg.dataset is not None: - self.metadata["dataset"] = self.cfg.dataset - self.episode_obs_list = [] - self.episode_action_list = [] - - self.curr_episode = 0 + if self.cfg.dataset: + self.dataset_manager = DatasetManager(self.cfg.dataset, self) def _apply_functor_filter(self) -> None: """Apply functor filters to the environment components based on configuration. @@ -252,29 +249,77 @@ def get_affordance(self, key: str, default: Any = None): """ return self.affordance_datas.get(key, default) + def set_force_truncated(self, value: bool = True): + """ + Set force_truncated flag to trigger episode truncation. + """ + self._force_truncated = value + def reset( self, seed: int | None = None, options: dict | None = None ) -> Tuple[EnvObs, Dict]: obs, info = super().reset(seed=seed, options=options) - - if hasattr(self, "episode_obs_list"): - self.episode_obs_list = [obs] - self.episode_action_list = [] - + self._force_truncated: bool = False return obs, info def step( self, action: EnvAction, **kwargs ) -> Tuple[EnvObs, torch.Tensor, torch.Tensor, torch.Tensor, Dict[str, Any]]: - # TODO: Maybe add action preprocessing manager and its functors. - obs, reward, done, truncated, info = super().step(action, **kwargs) - - if hasattr(self, "episode_action_list"): + """Step the environment with the given action. + + Extends BaseEnv.step() to integrate with DatasetManager for automatic + data collection and saving. The key is to: + 1. Record obs-action pairs as they happen + 2. Detect episode completion + 3. Auto-save episodes BEFORE reset + 4. Then perform the actual reset + """ + self._elapsed_steps += 1 + + action = self._step_action(action=action) + self.sim.update(self.sim_cfg.physics_dt, self.cfg.sim_steps_per_control) + self._update_sim_state(**kwargs) + + obs = self.get_obs(**kwargs) + info = self.get_info(**kwargs) + rewards = self.get_reward(obs=obs, action=action, info=info) + + # Check termination conditions + terminateds = torch.logical_or( + info.get( + "success", + torch.zeros(self.num_envs, dtype=torch.bool, device=self.device), + ), + info.get( + "fail", torch.zeros(self.num_envs, dtype=torch.bool, device=self.device) + ), + ) + truncateds = self.check_truncated(obs=obs, info=info) + if self.cfg.ignore_terminations: + terminateds[:] = False + + # Detect which environments need reset + dones = torch.logical_or(terminateds, truncateds) + reset_env_ids = dones.nonzero(as_tuple=False).squeeze(-1) + + # Call dataset manager with mode="save": it will record and auto-save if dones=True + if self.cfg.dataset: + if "save" in self.dataset_manager.available_modes: + self.dataset_manager.apply( + mode="save", + env_ids=None, + obs=obs, + action=action, + dones=dones, + terminateds=terminateds, + info=info, + ) - self.episode_obs_list.append(obs) - self.episode_action_list.append(action) + # Now perform reset for completed environments + if len(reset_env_ids) > 0: + obs, _ = self.reset(options={"reset_ids": reset_env_ids}) - return obs, reward, done, truncated, info + return obs, rewards, terminateds, truncateds, info def _extend_obs(self, obs: EnvObs, **kwargs) -> EnvObs: if self.observation_manager: @@ -452,29 +497,38 @@ def create_demo_action_list(self, *args, **kwargs) -> Sequence[EnvAction] | None "The method 'create_demo_action_list' must be implemented in subclasses." ) - def to_dataset(self, id: str, save_path: str = None) -> str | None: - """Convert the recorded episode data to a dataset format. 
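The rewritten `step()` above saves completed episodes before the auto-reset, so a scripted rollout only has to flag its final action; `run_env.py` further down in this diff does exactly that. A minimal driver sketch, assuming a registered task that implements `create_demo_action_list`:

```python
import gymnasium

from embodichain.utils.utility import load_json
from embodichain.lab.gym.utils.gym_utils import config_to_cfg

gym_config = load_json("configs/gym/pour_water/gym_config.json")
env = gymnasium.make(id=gym_config["id"], cfg=config_to_cfg(gym_config)).unwrapped

obs, info = env.reset()
action_list = env.create_demo_action_list()
for i, action in enumerate(action_list):
    if i == len(action_list) - 1:
        # check_truncated() now reports True, so the DatasetManager saves
        # the episode inside this very step(), before the auto-reset.
        env.set_force_truncated(True)
    obs, reward, terminated, truncated, info = env.step(action)

env.close()  # flushes any buffered episode via DatasetManager.finalize()
```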
+ def is_task_success(self, **kwargs) -> torch.Tensor: + """ + Determine if the task is successfully completed. This is mainly used in the data generation process + of the imitation learning. Args: - id (str): Unique identifier for the dataset. - save_path (str, optional): Path to save the dataset. If None, use config or default. + **kwargs: Additional arguments for task-specific success criteria. Returns: - str | None: The path to the saved dataset, or None if failed. + torch.Tensor: A boolean tensor indicating success for each environment in the batch. """ - raise NotImplementedError( - "The method 'to_dataset' will be implemented in the near future." - ) - def is_task_success(self, **kwargs) -> torch.Tensor: - """Determine if the task is successfully completed. This is mainly used in the data generation process - of the imitation learning. + return torch.ones(self.num_envs, dtype=torch.bool, device=self.device) + + def check_truncated(self, obs: EnvObs, info: Dict[str, Any]) -> torch.Tensor: + """Check if the episode is truncated. Args: - **kwargs: Additional arguments for task-specific success criteria. + obs: The observation from the environment. + info: The info dictionary. Returns: - torch.Tensor: A boolean tensor indicating success for each environment in the batch. + A boolean tensor indicating truncation for each environment in the batch. """ + if self._force_truncated: + return torch.ones(self.num_envs, dtype=torch.bool, device=self.device) + return super().check_truncated(obs, info) - return torch.ones(self.num_envs, dtype=torch.bool, device=self.device) + def close(self) -> None: + """Close the environment and release resources.""" + # Finalize dataset if present + if self.cfg.dataset: + self.dataset_manager.finalize() + + self.sim.destroy() diff --git a/embodichain/lab/gym/envs/managers/__init__.py b/embodichain/lab/gym/envs/managers/__init__.py index 946165a..b7825cc 100644 --- a/embodichain/lab/gym/envs/managers/__init__.py +++ b/embodichain/lab/gym/envs/managers/__init__.py @@ -14,7 +14,15 @@ # limitations under the License. # ---------------------------------------------------------------------------- -from .cfg import FunctorCfg, SceneEntityCfg, EventCfg, ObservationCfg +from .cfg import ( + FunctorCfg, + SceneEntityCfg, + EventCfg, + ObservationCfg, + DatasetFunctorCfg, +) from .manager_base import Functor, ManagerBase from .event_manager import EventManager from .observation_manager import ObservationManager +from .dataset_manager import DatasetManager +from .datasets import LeRobotRecorder diff --git a/embodichain/lab/gym/envs/managers/cfg.py b/embodichain/lab/gym/envs/managers/cfg.py index 3f5c8da..07888c9 100644 --- a/embodichain/lab/gym/envs/managers/cfg.py +++ b/embodichain/lab/gym/envs/managers/cfg.py @@ -309,3 +309,15 @@ def _resolve_body_names(self, scene: SimulationManager): if isinstance(self.body_ids, int): self.body_ids = [self.body_ids] self.body_names = [entity.body_names[i] for i in self.body_ids] + + +@configclass +class DatasetFunctorCfg(FunctorCfg): + """Configuration for dataset collection functors. 
+ + Dataset functors are called with mode="save" which handles both: + - Recording observation-action pairs on every step + - Auto-saving episodes when dones=True + """ + + mode: Literal["save"] = "save" diff --git a/embodichain/lab/gym/envs/managers/dataset_manager.py b/embodichain/lab/gym/envs/managers/dataset_manager.py new file mode 100644 index 0000000..0a8e9d4 --- /dev/null +++ b/embodichain/lab/gym/envs/managers/dataset_manager.py @@ -0,0 +1,321 @@ +# ---------------------------------------------------------------------------- +# Copyright (c) 2021-2025 DexForce Technology Co., Ltd. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ---------------------------------------------------------------------------- + +"""Dataset manager for orchestrating dataset collection functors.""" + +from __future__ import annotations + +import inspect +from typing import TYPE_CHECKING, Any, Dict, Optional, Union +from collections.abc import Sequence + +import torch +from prettytable import PrettyTable + +from embodichain.utils import logger +from embodichain.lab.sim.types import EnvObs, EnvAction +from .manager_base import ManagerBase +from .cfg import DatasetFunctorCfg + +if TYPE_CHECKING: + from embodichain.lab.gym.envs import EmbodiedEnv + + +class DatasetManager(ManagerBase): + """Manager for orchestrating dataset collection and saving using functors. + + The dataset manager supports multiple dataset formats through a functor system: + - LeRobot format (via LeRobotRecorder) + - HDF5 format (via HDF5Recorder) + - Zarr format (via ZarrRecorder) + - Custom formats (via user-defined functors) + + Each functor's step() method is called once per environment step and handles: + - Recording observation-action pairs + - Detecting episode completion (dones=True) + - Auto-saving completed episodes + + Example configuration: + >>> from embodichain.lab.gym.envs.managers.cfg import DatasetFunctorCfg + >>> from embodichain.lab.gym.envs.managers.datasets import LeRobotRecorder + >>> + >>> @configclass + >>> class MyEnvCfg: + >>> dataset: dict = { + >>> "lerobot": DatasetFunctorCfg( + >>> func=LeRobotRecorder, + >>> params={ + >>> "robot_meta": {...}, + >>> "instruction": {"lang": "pick and place"}, + >>> "extra": {"scene_type": "kitchen"}, + >>> "save_path": "/data/datasets", + >>> "export_success_only": True, + >>> } + >>> ) + >>> } + """ + + _env: EmbodiedEnv + """The environment instance.""" + + def __init__(self, cfg: object, env: EmbodiedEnv): + """Initialize the dataset manager. + + Args: + cfg: Configuration object containing dataset functor configurations. + env: The environment instance. 
+ """ + # Store functors by mode (similar to EventManager) + self._mode_functor_names: dict[str, list[str]] = {} + self._mode_functor_cfgs: dict[str, list[DatasetFunctorCfg]] = {} + self._mode_class_functor_cfgs: dict[str, list[DatasetFunctorCfg]] = {} + + # Call base class to parse functors + super().__init__(cfg, env) + + ## TODO: fix configurable_action.py to avoid getting env.metadata['dataset'] + # Extract robot_meta from first functor and add to env.metadata for backward compatibility + # This allows legacy code (like action_bank) to access robot_meta via env.metadata["dataset"]["robot_meta"] + for mode_cfgs in self._mode_functor_cfgs.values(): + for functor_cfg in mode_cfgs: + if "robot_meta" in functor_cfg.params: + if not hasattr(env, "metadata"): + env.metadata = {} + if "dataset" not in env.metadata: + env.metadata["dataset"] = {} + env.metadata["dataset"]["robot_meta"] = functor_cfg.params[ + "robot_meta" + ] + logger.log_info( + "Added robot_meta to env.metadata for backward compatibility" + ) + break + else: + continue + break + + logger.log_info( + f"DatasetManager initialized with {sum(len(v) for v in self._mode_functor_names.values())} functors" + ) + + def __str__(self) -> str: + """Returns: A string representation for dataset manager.""" + msg = f" contains {len(self._functor_names)} active functors.\n" + + table = PrettyTable() + table.title = "Active Dataset Functors" + table.field_names = ["Index", "Name", "Type"] + table.align["Name"] = "l" + + for index, name in enumerate(self._functor_names): + functor_cfg = self._functor_cfgs[index] + functor_type = ( + functor_cfg.func.__class__.__name__ + if hasattr(functor_cfg.func, "__class__") + else str(functor_cfg.func) + ) + table.add_row([index, name, functor_type]) + + msg += table.get_string() + msg += "\n" + + return msg + + """ + Properties. + """ + + @property + def active_functors(self) -> dict[str, list[str]]: + """Name of active dataset functors by mode. + + The keys are the modes and the values are the names of the dataset functors. + """ + return self._mode_functor_names + + @property + def available_modes(self) -> list[str]: + """List of available modes for the dataset manager.""" + return list(self._mode_functor_names.keys()) + + """ + Operations. + """ + + def reset( + self, env_ids: Union[Sequence[int], torch.Tensor, None] = None + ) -> dict[str, float]: + """Reset all dataset functors. + + Args: + env_ids: The environment ids. Defaults to None. + + Returns: + Empty dict (no logging info). + """ + # Call reset on all class functors across all modes + for mode_cfgs in self._mode_class_functor_cfgs.values(): + for functor_cfg in mode_cfgs: + functor_cfg.func.reset(env_ids=env_ids) + + return {} + + def apply( + self, + mode: str, + env_ids: Union[Sequence[int], torch.Tensor, None] = None, + obs: Optional[EnvObs] = None, + action: Optional[EnvAction] = None, + dones: Optional[torch.Tensor] = None, + terminateds: Optional[torch.Tensor] = None, + info: Optional[Dict[str, Any]] = None, + ) -> None: + """Apply dataset functors for the specified mode. + + This method follows the same pattern as EventManager.apply() for consistency. + Currently only supports mode="save" which handles both recording and auto-saving. + + Args: + mode: The mode to apply (currently only "save" is supported). + env_ids: The indices of the environments to apply the functor to. + Defaults to None, in which case the functor is applied to all environments. + obs: Observation from the environment (batched for all envs). 
+ action: Action applied to the environment (batched for all envs). + dones: Boolean tensor indicating which envs completed episodes. + terminateds: Boolean tensor indicating termination (success/fail). + info: Info dict containing success/fail information. + """ + # check if mode is valid + if mode not in self._mode_functor_names: + logger.log_warning( + f"Dataset mode '{mode}' is not defined. Skipping dataset operation." + ) + return + + # iterate over all the dataset functors for this mode + for functor_cfg in self._mode_functor_cfgs[mode]: + functor_cfg.func( + self._env, + env_ids, + obs, + action, + dones, + terminateds, + info, + **functor_cfg.params, + ) + + def finalize(self) -> Optional[str]: + """Finalize all dataset functors. + + Called when the environment is closed. Saves any remaining episodes + and finalizes all datasets. + + Returns: + Path to the first saved dataset, or None if failed. + """ + dataset_paths = [] + + # Call finalize on all class functors across all modes + for mode_cfgs in self._mode_class_functor_cfgs.values(): + for functor_cfg in mode_cfgs: + if hasattr(functor_cfg.func, "finalize"): + try: + path = functor_cfg.func.finalize() + if path: + dataset_paths.append(path) + except Exception as e: + logger.log_error(f"Failed to finalize functor: {e}") + + if dataset_paths: + logger.log_info(f"Finalized {len(dataset_paths)} datasets") + return dataset_paths[0] + + return None + + """ + Operations - Functor settings. + """ + + def get_functor_cfg(self, functor_name: str) -> DatasetFunctorCfg: + """Gets the configuration for the specified functor. + + Args: + functor_name: The name of the dataset functor. + + Returns: + The configuration of the dataset functor. + + Raises: + ValueError: If the functor name is not found. + """ + for mode, functors in self._mode_functor_names.items(): + if functor_name in functors: + return self._mode_functor_cfgs[mode][functors.index(functor_name)] + logger.log_error(f"Dataset functor '{functor_name}' not found.") + + """ + Helper functions. + """ + + def _prepare_functors(self): + """Prepare dataset functors from configuration. + + This method parses the configuration and initializes all dataset functors, + organizing them by mode (similar to EventManager). + """ + # Check if config is dict already + if isinstance(self.cfg, dict): + cfg_items = self.cfg.items() + else: + cfg_items = self.cfg.__dict__.items() + + # Iterate over all the functors + for functor_name, functor_cfg in cfg_items: + # Check for non config + if functor_cfg is None: + continue + + # Convert dict to DatasetFunctorCfg if needed (for JSON configs) + if isinstance(functor_cfg, dict): + functor_cfg = DatasetFunctorCfg(**functor_cfg) + + # Check for valid config type + if not isinstance(functor_cfg, DatasetFunctorCfg): + raise TypeError( + f"Configuration for '{functor_name}' is not of type DatasetFunctorCfg." + f" Received: '{type(functor_cfg)}'." 
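Because `_prepare_functors` only requires a `DatasetFunctorCfg` whose `func` follows the seven-positional-argument contract (env, env_ids, obs, action, dones, terminateds, info), a custom recorder can be plugged in alongside `LeRobotRecorder`. A sketch with a hypothetical `ConsoleRecorder` (the class and its `verbose` parameter are invented for illustration):

```python
from embodichain.lab.gym.envs.managers import DatasetFunctorCfg, Functor


class ConsoleRecorder(Functor):
    """Hypothetical functor that counts steps and logs episode completions."""

    def __init__(self, cfg: DatasetFunctorCfg, env):
        super().__init__(cfg, env)
        self.num_steps = 0

    def __call__(self, env, env_ids, obs, action, dones, terminateds, info,
                 verbose: bool = True):
        self.num_steps += 1
        if verbose and dones is not None and bool(dones.any()):
            print(f"episode finished after {self.num_steps} recorded steps")
            self.num_steps = 0


# Passed programmatically, the class goes in directly; from a JSON config it
# would be resolved by name, so it would also need to live in managers.datasets.
console_cfg = DatasetFunctorCfg(func=ConsoleRecorder, params={"verbose": True})
```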
+ ) + + # Resolve common parameters + # min_argc=7 to skip: env, env_ids, obs, action, dones, terminateds, info + # These are runtime positional arguments, not config parameters + self._resolve_common_functor_cfg(functor_name, functor_cfg, min_argc=7) + + # Check if mode is a new mode + if functor_cfg.mode not in self._mode_functor_names: + # add new mode + self._mode_functor_names[functor_cfg.mode] = [] + self._mode_functor_cfgs[functor_cfg.mode] = [] + self._mode_class_functor_cfgs[functor_cfg.mode] = [] + + # Add functor name and parameters + self._mode_functor_names[functor_cfg.mode].append(functor_name) + self._mode_functor_cfgs[functor_cfg.mode].append(functor_cfg) + + # Check if the functor is a class + if inspect.isclass(functor_cfg.func): + self._mode_class_functor_cfgs[functor_cfg.mode].append(functor_cfg) diff --git a/embodichain/lab/gym/envs/managers/datasets.py b/embodichain/lab/gym/envs/managers/datasets.py new file mode 100644 index 0000000..32294ec --- /dev/null +++ b/embodichain/lab/gym/envs/managers/datasets.py @@ -0,0 +1,429 @@ +# ---------------------------------------------------------------------------- +# Copyright (c) 2021-2025 DexForce Technology Co., Ltd. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ---------------------------------------------------------------------------- + +"""Dataset functors for collecting and saving episode data.""" + +from __future__ import annotations + +from pathlib import Path +from typing import TYPE_CHECKING, Any, Dict, List, Optional, Union + +import numpy as np +import torch + +from embodichain.utils import logger +from embodichain.lab.sim.types import EnvObs, EnvAction +from embodichain.lab.gym.utils.misc import is_stereocam +from embodichain.utils.utility import get_right_name +from embodichain.data.enum import JointType +from .manager_base import Functor +from .cfg import DatasetFunctorCfg + +if TYPE_CHECKING: + from embodichain.lab.gym.envs import EmbodiedEnv + +try: + from lerobot.datasets.lerobot_dataset import LeRobotDataset, HF_LEROBOT_HOME + + LEROBOT_AVAILABLE = True +except ImportError: + LEROBOT_AVAILABLE = False + + +class LeRobotRecorder(Functor): + """Functor for recording episodes in LeRobot format. + + This functor handles: + - Recording observation-action pairs during episodes + - Converting data to LeRobot format + - Saving episodes when they complete + """ + + def __init__(self, cfg: DatasetFunctorCfg, env: EmbodiedEnv): + """Initialize the LeRobot dataset recorder. 
+ + Args: + cfg: Functor configuration containing params: + - save_path: Root directory for saving datasets + - id: Dataset identifier (repo_id) + - robot_meta: Robot metadata for dataset + - instruction: Optional task instruction + - extra: Optional extra metadata + - use_videos: Whether to save videos + - image_writer_threads: Number of threads for image writing + - image_writer_processes: Number of processes for image writing + - export_success_only: Whether to export only successful episodes + env: The environment instance + """ + super().__init__(cfg, env) + + # Extract parameters from cfg.params + params = cfg.params + + # Required parameters + self.lerobot_data_root = params.get("save_path", "/tmp/lerobot_data") + self.repo_id = params.get( + "id", 0 + ) # Can be int (version counter) or str (dataset name) + self.robot_meta = params.get("robot_meta", {}) + + # Optional parameters + self.instruction = params.get("instruction", None) + self.extra = params.get("extra", {}) + self.use_videos = params.get("use_videos", False) + self.export_success_only = params.get("export_success_only", False) + + # Episode data buffers + self.episode_obs_list: List[Dict] = [] + self.episode_action_list: List[Any] = [] + + # LeRobot dataset instance + self.dataset: Optional[LeRobotDataset] = None + self.dataset_id: int = 0 # Will be set in _initialize_dataset + + # Tracking + self.total_time: float = 0.0 + self.curr_episode: int = 0 + + # Initialize dataset + self._initialize_dataset() + + logger.log_info(f"LeRobotRecorder initialized at: {self.dataset_path}") + + @property + def dataset_path(self) -> str: + """Path to the dataset directory.""" + return str(Path(self.lerobot_data_root) / self.repo_id) + + def reset(self, env_ids: Optional[torch.Tensor] = None) -> None: + """Reset the recorder buffers. + + Args: + env_ids: Environment IDs to reset (currently clears all data). + """ + self._reset_buffer() + + def __call__( + self, + env: EmbodiedEnv, + env_ids: Union[torch.Tensor, None], + obs: EnvObs, + action: EnvAction, + dones: torch.Tensor, + terminateds: torch.Tensor, + info: Dict[str, Any], + save_path: Optional[str] = None, + id: Optional[str] = None, + robot_meta: Optional[Dict] = None, + instruction: Optional[str] = None, + extra: Optional[Dict] = None, + use_videos: bool = False, + export_success_only: bool = False, + ) -> None: + """Main entry point for the recorder functor. + + This method is called by DatasetManager.apply(mode="save") with runtime arguments + as positional parameters and configuration parameters from cfg.params. + + Args: + env: The environment instance. + env_ids: Environment IDs (for consistency with EventManager pattern). + obs: Observation from the environment. + action: Action applied to the environment. + dones: Boolean tensor indicating which envs completed episodes. + terminateds: Termination flags (success/fail). + info: Info dict containing success/fail information. + save_path: Root directory (already set in __init__). + id: Dataset identifier (already set in __init__). + robot_meta: Robot metadata (already set in __init__). + instruction: Task instruction (already set in __init__). + extra: Extra metadata (already set in __init__). + use_videos: Whether to save videos (already set in __init__). + export_success_only: Whether to export only successful episodes (already set in __init__). 
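For reference, a frame handed to `LeRobotDataset.add_frame()` by this recorder has roughly the following shape. This is a sketch assuming the pour-water config above — one 960x540 RGB camera, a 16-dim qpos state, and arm_dofs=12 — with names mirroring `_build_features` and `_convert_frame_to_lerobot` below:

```python
import numpy as np

frame = {
    "task": "Pour water from bottle to cup",              # instruction["lang"]
    "cam_high": np.zeros((540, 960, 3), dtype=np.uint8),  # (height, width, channel)
    "observation.state": np.zeros(16, dtype=np.float32),  # full robot qpos
    "action": np.zeros(12, dtype=np.float32),             # first arm_dofs entries
}
```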
+ """ + # Always record the step + self._record_step(obs, action) + + # Check if any episodes are done and save them + done_env_ids = dones.nonzero(as_tuple=False).squeeze(-1) + if len(done_env_ids) > 0: + # Save completed episodes + self._save_episodes(done_env_ids, terminateds, info) + + def _record_step(self, obs: EnvObs, action: EnvAction) -> None: + """Record a single step.""" + self.episode_obs_list.append(obs) + self.episode_action_list.append(action) + + def _save_episodes( + self, + env_ids: torch.Tensor, + terminateds: Optional[torch.Tensor] = None, + info: Optional[Dict[str, Any]] = None, + ) -> None: + """Save completed episodes.""" + if len(self.episode_obs_list) == 0: + logger.log_warning("No episode data to save") + return + + obs_list = self.episode_obs_list + action_list = self.episode_action_list + + # Align obs and action + if len(obs_list) > len(action_list): + obs_list = obs_list[:-1] + + task = self.instruction.get("lang", "unknown_task") + + # Update metadata + extra_info = self.extra.copy() if self.extra else {} + fps = self.dataset.meta.info.get("fps", 30) + current_episode_time = (len(obs_list) * len(env_ids)) / fps if fps > 0 else 0 + + episode_extra_info = extra_info.copy() + self.total_time += current_episode_time + episode_extra_info["total_time"] = self.total_time + self._update_dataset_info({"extra": episode_extra_info}) + + # Process each environment + for env_id in env_ids.cpu().tolist(): + is_success = False + if info is not None and "success" in info: + success_tensor = info["success"] + if isinstance(success_tensor, torch.Tensor): + is_success = success_tensor[env_id].item() + else: + is_success = success_tensor + elif terminateds is not None: + is_success = terminateds[env_id].item() + + logger.log_info(f"Episode {env_id} success: {is_success}") + if self.export_success_only and not is_success: + logger.log_info(f"Skipping failed episode for env {env_id}") + continue + + try: + for obs, action in zip(obs_list, action_list): + frame = self._convert_frame_to_lerobot(obs, action, task, env_id) + self.dataset.add_frame(frame) + + self.dataset.save_episode() + logger.log_info( + f"Auto-saved {'successful' if is_success else 'failed'} " + f"episode {self.curr_episode} for env {env_id} with {len(obs_list)} frames" + ) + self.curr_episode += 1 + except Exception as e: + logger.log_error(f"Failed to save episode {env_id}: {e}") + + self._reset_buffer() + + def finalize(self) -> Optional[str]: + """Finalize the dataset.""" + if len(self.episode_obs_list) > 0: + active_env_ids = torch.arange(self.num_envs, device=self.device) + self._save_episodes(active_env_ids) + + try: + if self.dataset is not None: + self.dataset.finalize() + logger.log_info(f"Dataset finalized at: {self.dataset_path}") + return self.dataset_path + except Exception as e: + logger.log_error(f"Failed to finalize dataset: {e}") + + return None + + def _reset_buffer(self) -> None: + """Reset episode buffers.""" + self.episode_obs_list.clear() + self.episode_action_list.clear() + logger.log_info("Reset buffers (cleared all batched data)") + + def _initialize_dataset(self) -> None: + """Initialize the LeRobot dataset.""" + robot_type = self.robot_meta.get("robot_type", "robot") + scene_type = self.extra.get("scene_type", "scene") + task_description = self.extra.get("task_description", "task") + + robot_type = str(robot_type).lower().replace(" ", "_") + task_description = str(task_description).lower().replace(" ", "_") + + # Use lerobot_data_root from __init__ + lerobot_data_root = 
Path(self.lerobot_data_root) + + # repo_id from config or generate one + if isinstance(self.repo_id, int): + # If repo_id is an integer, generate a name + dataset_id = self.repo_id + while True: + repo_id = f"{robot_type}_{scene_type}_{task_description}_v{dataset_id}" + dataset_dir = lerobot_data_root / repo_id + if not dataset_dir.exists(): + break + dataset_id += 1 + self.repo_id = repo_id + self.dataset_id = dataset_id + else: + # repo_id is already a string, use it directly + self.dataset_id = 0 + + fps = self.robot_meta.get("control_freq", 30) + features = self._build_features() + + try: + self.dataset = LeRobotDataset.create( + repo_id=self.repo_id, + fps=fps, + root=str(lerobot_data_root), + robot_type=robot_type, + features=features, + use_videos=self.use_videos, + ) + logger.log_info( + f"Created LeRobot dataset at: {lerobot_data_root / self.repo_id}" + ) + except FileExistsError: + self.dataset = LeRobotDataset( + repo_id=self.repo_id, root=str(lerobot_data_root) + ) + logger.log_info( + f"Loaded existing LeRobot dataset at: {lerobot_data_root / self.repo_id}" + ) + except Exception as e: + logger.log_error(f"Failed to create/load LeRobot dataset: {e}") + raise + + def _build_features(self) -> Dict: + """Build LeRobot features dict.""" + features = {} + extra_vision_config = self.robot_meta.get("observation", {}).get("vision", {}) + + for camera_name in extra_vision_config.keys(): + sensor = self._env.get_sensor(camera_name) + is_stereo = is_stereocam(sensor) + img_shape = (sensor.cfg.height, sensor.cfg.width, 3) + + features[camera_name] = { + "dtype": "video" if self.use_videos else "image", + "shape": img_shape, + "names": ["height", "width", "channel"], + } + + if is_stereo: + features[get_right_name(camera_name)] = { + "dtype": "video" if self.use_videos else "image", + "shape": img_shape, + "names": ["height", "width", "channel"], + } + + qpos = self._env.robot.get_qpos() + state_dim = qpos.shape[1] + + if state_dim > 0: + features["observation.state"] = { + "dtype": "float32", + "shape": (state_dim,), + "names": ["state"], + } + + action_dim = self.robot_meta.get("arm_dofs", 7) + features["action"] = { + "dtype": "float32", + "shape": (action_dim,), + "names": ["action"], + } + + return features + + def _convert_frame_to_lerobot( + self, obs: Dict[str, Any], action: Any, task: str, env_id: int + ) -> Dict: + """Convert a single frame to LeRobot format.""" + frame = {"task": task} + extra_vision_config = self.robot_meta.get("observation", {}).get("vision", {}) + arm_dofs = self.robot_meta.get("arm_dofs", 7) + + # Add images + for camera_name in extra_vision_config.keys(): + if camera_name in obs.get("sensor", {}): + sensor = self._env.get_sensor(camera_name) + is_stereo = is_stereocam(sensor) + + color_data = obs["sensor"][camera_name]["color"] + if isinstance(color_data, torch.Tensor): + color_img = color_data[env_id][:, :, :3].cpu().numpy() + else: + color_img = np.array(color_data)[env_id][:, :, :3] + + if color_img.dtype in [np.float32, np.float64]: + color_img = (color_img * 255).astype(np.uint8) + + frame[camera_name] = color_img + + if is_stereo: + color_right_data = obs["sensor"][camera_name]["color_right"] + if isinstance(color_right_data, torch.Tensor): + color_right_img = ( + color_right_data[env_id][:, :, :3].cpu().numpy() + ) + else: + color_right_img = np.array(color_right_data)[env_id][:, :, :3] + + if color_right_img.dtype in [np.float32, np.float64]: + color_right_img = (color_right_img * 255).astype(np.uint8) + + frame[get_right_name(camera_name)] = 
color_right_img + + # Add state + qpos = obs["robot"][JointType.QPOS.value] + if isinstance(qpos, torch.Tensor): + state_data = qpos[env_id].cpu().numpy().astype(np.float32) + else: + state_data = np.array(qpos)[env_id].astype(np.float32) + + frame["observation.state"] = state_data + + # Add action + if isinstance(action, torch.Tensor): + action_data = action[env_id, :arm_dofs].cpu().numpy() + elif isinstance(action, np.ndarray): + action_data = action[env_id, :arm_dofs] + elif isinstance(action, dict): + action_data = action.get("action", action.get("arm_action", action)) + if isinstance(action_data, torch.Tensor): + action_data = action_data[env_id, :arm_dofs].cpu().numpy() + elif isinstance(action_data, np.ndarray): + action_data = action_data[env_id, :arm_dofs] + else: + action_data = np.array(action)[env_id, :arm_dofs] + + frame["action"] = action_data + + return frame + + def _update_dataset_info(self, updates: dict) -> bool: + """Update dataset metadata.""" + if self.dataset is None: + logger.log_error("LeRobotDataset not initialized.") + return False + + try: + self.dataset.meta.info.update(updates) + return True + except Exception as e: + logger.log_error(f"Failed to update dataset info: {e}") + return False diff --git a/embodichain/lab/gym/envs/managers/object/geometry.py b/embodichain/lab/gym/envs/managers/object/geometry.py index db26b18..5ede860 100644 --- a/embodichain/lab/gym/envs/managers/object/geometry.py +++ b/embodichain/lab/gym/envs/managers/object/geometry.py @@ -47,7 +47,7 @@ def get_pcd_svd_frame(pc: torch.Tensor) -> torch.Tensor: pc_centered = pc - pc_center u, s, vt = torch.linalg.svd(pc_centered) rotation = vt.T - pc_pose = torch.eye(4, dtype=torch.float32) + pc_pose = torch.eye(4, dtype=torch.float32, device=pc.device) pc_pose[:3, :3] = rotation pc_pose[:3, 3] = pc_center return pc_pose @@ -90,7 +90,7 @@ def apply_svd_transfer_pcd( standard_verts = [] for object_verts in verts: pc_svd_frame = get_pcd_svd_frame(object_verts) - inv_svd_frame = inv_transform(pc_svd_frame) + inv_svd_frame = torch.linalg.inv(pc_svd_frame) standard_object_verts = ( object_verts @ inv_svd_frame[:3, :3].T + inv_svd_frame[:3, 3] ) diff --git a/embodichain/lab/gym/utils/gym_utils.py b/embodichain/lab/gym/utils/gym_utils.py index 0f92abb..ebe06d6 100644 --- a/embodichain/lab/gym/utils/gym_utils.py +++ b/embodichain/lab/gym/utils/gym_utils.py @@ -364,6 +364,7 @@ def config_to_cfg(config: dict) -> "EmbodiedEnvCfg": SceneEntityCfg, EventCfg, ObservationCfg, + DatasetFunctorCfg, ) from embodichain.utils import configclass from embodichain.data import get_data_path @@ -453,7 +454,32 @@ class ComponentCfg: env_cfg.sim_steps_per_control = config["env"].get("sim_steps_per_control", 4) # load dataset config - env_cfg.dataset = config["env"].get("dataset", None) + env_cfg.dataset = ComponentCfg() + if "dataset" in config["env"]: + # Define modules to search for dataset functions + dataset_modules = [ + "embodichain.lab.gym.envs.managers.datasets", + ] + + for dataset_name, dataset_params in config["env"]["dataset"].items(): + dataset_params_modified = deepcopy(dataset_params) + + # Find the function from multiple modules using the utility function + dataset_func = find_function_from_modules( + dataset_params["func"], + dataset_modules, + raise_if_not_found=True, + ) + + from embodichain.lab.gym.envs.managers import DatasetFunctorCfg + + dataset = DatasetFunctorCfg( + func=dataset_func, + mode=dataset_params_modified["mode"], + params=dataset_params_modified["params"], + ) + + 
setattr(env_cfg.dataset, dataset_name, dataset) # TODO: support more env events, eg, grasp pose generation, mesh preprocessing, etc. diff --git a/embodichain/lab/gym/utils/misc.py b/embodichain/lab/gym/utils/misc.py index b669e6c..b75b70a 100644 --- a/embodichain/lab/gym/utils/misc.py +++ b/embodichain/lab/gym/utils/misc.py @@ -1367,3 +1367,18 @@ def is_eef_hand(robot, control_parts) -> bool: if "gripper" in data_key and is_eef_hand(robot, control_parts) is False: return "right_eef" return None + + +def is_stereocam(sensor) -> bool: + """ + Check if a sensor is a StereoCamera (binocular camera). + + Args: + sensor: The sensor instance to check. + + Returns: + bool: True if the sensor is a StereoCamera, False otherwise. + """ + from embodichain.lab.sim.sensors import StereoCamera + + return isinstance(sensor, StereoCamera) diff --git a/embodichain/lab/scripts/replay_dataset.py b/embodichain/lab/scripts/replay_dataset.py new file mode 100644 index 0000000..96de834 --- /dev/null +++ b/embodichain/lab/scripts/replay_dataset.py @@ -0,0 +1,295 @@ +#!/usr/bin/env python3 +# ---------------------------------------------------------------------------- +# Copyright (c) 2021-2025 DexForce Technology Co., Ltd. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ---------------------------------------------------------------------------- + +""" +Script to replay LeRobot dataset trajectories in EmbodiedEnv. + +This script loads a LeRobot dataset and replays the recorded trajectories +in the EmbodiedEnv environment. It focuses on trajectory replay and uses +sensor configurations from the environment config file. 
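End to end, the parsing added above turns the JSON `env.dataset` section into attributes on `env_cfg.dataset`. A quick sanity check, assuming the pour-water config from this diff:

```python
from embodichain.utils.utility import load_json
from embodichain.lab.gym.utils.gym_utils import config_to_cfg
from embodichain.lab.gym.envs.managers import DatasetFunctorCfg

gym_config = load_json("configs/gym/pour_water/gym_config.json")
env_cfg = config_to_cfg(gym_config)

# "lerobot" from the JSON dict becomes an attribute holding a DatasetFunctorCfg
# whose func was resolved by name from the managers.datasets module.
lerobot_cfg = env_cfg.dataset.lerobot
assert isinstance(lerobot_cfg, DatasetFunctorCfg)
print(lerobot_cfg.mode, lerobot_cfg.params["robot_meta"]["control_freq"])  # save 25
```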
+ +Usage: + python replay_dataset.py --dataset_path /path/to/dataset --config /path/to/gym_config.json + python replay_dataset.py --dataset_path outputs/commercial_cobotmagic_pour_water_001 --config configs/gym/pour_water/gym_config.json --episode 0 +""" + +import os +import argparse +import gymnasium +import torch +import numpy as np +from pathlib import Path + +from embodichain.utils.logger import log_warning, log_info, log_error +from embodichain.utils.utility import load_json +from embodichain.lab.gym.envs import EmbodiedEnvCfg +from embodichain.lab.gym.utils.gym_utils import ( + config_to_cfg, +) + + +def parse_args(): + """Parse command line arguments.""" + parser = argparse.ArgumentParser( + description="Replay LeRobot dataset in EmbodiedEnv" + ) + parser.add_argument( + "--dataset_path", + type=str, + required=True, + help="Path to the LeRobot dataset directory", + ) + parser.add_argument( + "--config", + type=str, + required=True, + help="Path to the gym config JSON file (for environment setup)", + ) + parser.add_argument( + "--episode", + type=int, + default=None, + help="Specific episode index to replay (default: replay all episodes)", + ) + parser.add_argument( + "--headless", action="store_true", help="Run in headless mode without rendering" + ) + parser.add_argument( + "--fps", + type=int, + default=None, + help="Frames per second for replay (default: use dataset fps)", + ) + parser.add_argument( + "--save_video", action="store_true", help="Save replay as video" + ) + parser.add_argument( + "--video_path", + type=str, + default="./replay_videos", + help="Path to save replay videos", + ) + return parser.parse_args() + + +def load_lerobot_dataset(dataset_path): + """Load LeRobot dataset from the given path. + + Args: + dataset_path: Path to the LeRobot dataset directory + + Returns: + LeRobotDataset instance + """ + try: + from lerobot.datasets.lerobot_dataset import LeRobotDataset + except ImportError as e: + log_error( + f"Failed to import LeRobot: {e}. " + "Please install lerobot: pip install lerobot" + ) + return None + + dataset_path = Path(dataset_path) + if not dataset_path.exists(): + log_error(f"Dataset path does not exist: {dataset_path}") + return None + + # Get repo_id from the dataset path (last directory name) + repo_id = dataset_path.name + # root = str(dataset_path.parent) + + log_info(f"Loading LeRobot dataset: {repo_id} from {dataset_path}") + + try: + dataset = LeRobotDataset(repo_id=repo_id, root=dataset_path) + log_info(f"Dataset loaded successfully:") + log_info( + f" - Total episodes: {dataset.meta.info.get('total_episodes', 'N/A')}" + ) + log_info(f" - Total frames: {dataset.meta.info.get('total_frames', 'N/A')}") + log_info(f" - FPS: {dataset.meta.info.get('fps', 'N/A')}") + log_info(f" - Robot type: {dataset.meta.info.get('robot_type', 'N/A')}") + return dataset + except Exception as e: + log_error(f"Failed to load dataset: {e}") + return None + + +def create_replay_env(config_path, headless=False): + """Create EmbodiedEnv for replay based on config. 
+ + Args: + config_path: Path to the gym config JSON file + headless: Whether to run in headless mode + + Returns: + Gymnasium environment instance + """ + # Load configuration + gym_config = load_json(config_path) + + # Disable dataset recording during replay + if "dataset" in gym_config.get("env", {}): + gym_config["env"]["dataset"] = None + + # Convert config to dataclass + cfg: EmbodiedEnvCfg = config_to_cfg(gym_config) + + # Set render mode + if not headless: + cfg.render_mode = "human" + else: + cfg.render_mode = None + + # Create environment + log_info(f"Creating environment: {gym_config['id']}") + env = gymnasium.make(id=gym_config["id"], cfg=cfg) + + return env + + +def replay_episode( + env, dataset, episode_idx, fps=None, save_video=False, video_path=None +): + """Replay a single episode from the dataset. + + Args: + env: EmbodiedEnv instance + dataset: LeRobotDataset instance + episode_idx: Episode index to replay + fps: Frames per second for replay + save_video: Whether to save replay as video + video_path: Path to save video + + Returns: + True if replay was successful, False otherwise + """ + # Get episode data + try: + ep_meta = dataset.meta.episodes[episode_idx] + start_idx = ep_meta["dataset_from_index"] + end_idx = ep_meta["dataset_to_index"] + episode_data = [dataset[i] for i in range(start_idx, end_idx)] + log_info(f"Replaying episode {episode_idx} with {len(episode_data)} frames") + except Exception as e: + log_error(f"Failed to load episode {episode_idx}: {e}") + return False + + # Reset environment + obs, info = env.reset() + + # Setup video recording if needed + if save_video and video_path: + os.makedirs(video_path, exist_ok=True) + video_file = os.path.join(video_path, f"episode_{episode_idx:04d}.mp4") + # TODO: Implement video recording + log_warning("Video recording is not yet implemented") + + # Replay trajectory + for frame_idx in range(len(episode_data)): + # Get action from dataset + frame = episode_data[frame_idx] + + # Extract action based on dataset action space + # The action format depends on the dataset's robot configuration + if "action" in frame: + action = frame["action"] + if isinstance(action, torch.Tensor): + action = action.cpu().numpy() + else: + log_warning(f"No action found in frame {frame_idx}, skipping") + continue + + # Step environment with recorded action + obs, reward, done, truncated, info = env.step(action) + + # Optional: Add delay to match FPS + if fps: + import time + + time.sleep(1.0 / fps) + + # Check if episode ended + if done or truncated: + log_info(f"Episode ended at frame {frame_idx}/{len(episode_data)}") + break + + log_info(f"Successfully replayed episode {episode_idx}") + return True + + +def main(): + """Main function to replay LeRobot dataset.""" + args = parse_args() + + # Load dataset + dataset = load_lerobot_dataset(args.dataset_path) + if dataset is None: + return + + # Create replay environment + env = create_replay_env(args.config, headless=args.headless) + + # Determine FPS + fps = args.fps if args.fps else dataset.meta.info.get("fps", 30) + log_info(f"Replay FPS: {fps}") + + # Replay episodes + if args.episode is not None: + # Replay single episode + log_info(f"Replaying single episode: {args.episode}") + success = replay_episode( + env, + dataset, + args.episode, + fps=fps, + save_video=args.save_video, + video_path=args.video_path, + ) + if not success: + log_error(f"Failed to replay episode {args.episode}") + else: + # Replay all episodes + total_episodes = dataset.meta.info.get("total_episodes", 0) + 
log_info(f"Replaying all {total_episodes} episodes") + + for episode_idx in range(total_episodes): + log_info(f"\n{'='*60}") + log_info(f"Episode {episode_idx + 1}/{total_episodes}") + log_info(f"{'='*60}") + + success = replay_episode( + env, + dataset, + episode_idx, + fps=fps, + save_video=args.save_video, + video_path=args.video_path, + ) + + if not success: + log_warning(f"Skipping episode {episode_idx} due to errors") + continue + + # Cleanup + env.close() + log_info("Replay completed successfully!") + + +if __name__ == "__main__": + main() diff --git a/embodichain/lab/scripts/run_env.py b/embodichain/lab/scripts/run_env.py index 1ad5318..70a4b85 100644 --- a/embodichain/lab/scripts/run_env.py +++ b/embodichain/lab/scripts/run_env.py @@ -40,14 +40,18 @@ def generate_and_execute_action_list(env, idx, debug_mode): log_warning("Action is invalid. Skip to next generation.") return False - for action in tqdm.tqdm( - action_list, desc=f"Executing action list #{idx}", unit="step" + for idx_action, action in enumerate( + tqdm.tqdm(action_list, desc=f"Executing action list #{idx}", unit="step") ): + if idx_action == len(action_list) - 1: + log_info( + f"Setting force_truncated before final step at action index: {idx_action}" + ) + env.set_force_truncated(True) + # Step the environment with the current action obs, reward, terminated, truncated, info = env.step(action) - # TODO: May be add some functions for debug_mode - # TODO: We may assume in export demonstration rollout, there is no truncation from the env. # but truncation is useful to improve the generation efficiency. @@ -84,19 +88,19 @@ def generate_function( valid = True while True: - _, _ = env.reset() + # _, _ = env.reset() + ret = [] for trajectory_idx in range(num_traj): valid = generate_and_execute_action_list(env, trajectory_idx, debug_mode) if not valid: + _, _ = env.reset() break if not debug_mode and env.is_task_success().item(): pass - # TODO: Add data saving and online data streaming logic here. - else: log_warning(f"Task fail, Skip to next generation.") valid = False @@ -188,8 +192,8 @@ def main(args, env, gym_config): args = parser.parse_args() - if args.num_envs != 1: - log_error(f"Currently only support num_envs=1, but got {args.num_envs}.") + # if args.num_envs != 1: + # log_error(f"Currently only support num_envs=1, but got {args.num_envs}.") gym_config = load_json(args.gym_config) cfg: EmbodiedEnvCfg = config_to_cfg(gym_config) diff --git a/pyproject.toml b/pyproject.toml index ad6d6d2..043ee60 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -38,7 +38,7 @@ dependencies = [ "pytorch_kinematics==0.7.6", "polars==1.31.0", "PyYAML>=6.0", - "accelerate==1.2.1", + "accelerate>=1.10.0", "wandb==0.20.1", "tensorboard", "transformers>=4.53.0", @@ -51,6 +51,11 @@ dependencies = [ "h5py", ] +[project.optional-dependencies] +lerobot = [ + "lerobot==0.4.2" +] + [tool.setuptools.dynamic] version = { file = ["VERSION"] } diff --git a/scripts/data_gen.sh b/scripts/data_gen.sh new file mode 100755 index 0000000..598afb8 --- /dev/null +++ b/scripts/data_gen.sh @@ -0,0 +1,13 @@ +#!/bin/bash + +NUM_PROCESSES=3 # Set this to the number of parallel processes you want + +for ((i=0; i