diff --git a/embodichain/data/enum.py b/embodichain/data/enum.py index c4643a82..2902045c 100644 --- a/embodichain/data/enum.py +++ b/embodichain/data/enum.py @@ -24,12 +24,14 @@ class SemanticMask(IntEnum): Attributes: BACKGROUND (int): Represents the background region (value: 0). FOREGROUND (int): Represents the foreground objects (value: 1). - ROBOT (int): Represents the robot region (value: 2). + ROBOT_LEFT (int): Represents the left robot region (value: 2). + ROBOT_RIGHT (int): Represents the right robot region (value: 3). """ BACKGROUND = 0 FOREGROUND = 1 - ROBOT = 2 + ROBOT_LEFT = 2 + ROBOT_RIGHT = 3 class EndEffector(Enum): diff --git a/embodichain/lab/gym/envs/managers/observations.py b/embodichain/lab/gym/envs/managers/observations.py index 1eb83fa4..24be9d99 100644 --- a/embodichain/lab/gym/envs/managers/observations.py +++ b/embodichain/lab/gym/envs/managers/observations.py @@ -180,10 +180,11 @@ def compute_semantic_mask( """Compute the semantic mask for the specified scene entity. Note: - The semantic mask is defined as (B, H, W, 3) where the three channels represents: - robot channel: the instance id of the robot is set to 1 (0 if not robot) + The semantic mask is defined as (B, H, W, len(SemanticMask)) where these channels represent: - background channel: the instance id of the background is set to 1 (0 if not background) - foreground channel: the instance id of the foreground objects is set to 1 (0 if not foreground) + - robot left-side channel: the instance id of the robot left-side is set to 1 + - robot right-side channel: the instance id of the robot right-side is set to 1 Args: env: The environment instance. 
@@ -209,13 +210,30 @@ def compute_semantic_mask( else: mask = obs["sensor"][entity_cfg.uid]["mask"] - robot_uids = env.robot.get_user_ids() + left_robot_uids = torch.cat( + [ + env.robot.get_user_ids(link_name) + for link_name in env.robot.link_names + if link_name.startswith("left_") + ], + -1, + ) + right_robot_uids = torch.cat( + [ + env.robot.get_user_ids(link_name) + for link_name in env.robot.link_names + if link_name.startswith("right_") + ], + -1, + ) mask_exp = mask.unsqueeze(-1) - robot_uids_exp = robot_uids.unsqueeze_(1).unsqueeze_(1) + left_robot_uids_exp = left_robot_uids.unsqueeze_(1).unsqueeze_(1) + right_robot_uids_exp = right_robot_uids.unsqueeze_(1).unsqueeze_(1) - robot_mask = (mask_exp == robot_uids_exp).any(-1).squeeze_(-1) + left_robot_mask = (mask_exp == left_robot_uids_exp).any(-1).squeeze_(-1) + right_robot_mask = (mask_exp == right_robot_uids_exp).any(-1).squeeze_(-1) asset_uids = env.sim.asset_uids foreground_assets = [ @@ -239,9 +257,11 @@ def compute_semantic_mask( foreground_mask = (mask_exp == foreground_uids_exp).any(-1).squeeze_(-1) - background_mask = ~(robot_mask | foreground_mask).squeeze_(-1) + background_mask = ~(left_robot_mask | right_robot_mask | foreground_mask).squeeze_( + -1 + ) - masks = [None, None, None] + masks = [None, None, None, None] masks_ids = [member.value for member in SemanticMask] assert len(masks) == len( masks_ids @@ -249,7 +269,8 @@ def compute_semantic_mask( mask_id_to_label = { SemanticMask.BACKGROUND.value: background_mask, SemanticMask.FOREGROUND.value: foreground_mask, - SemanticMask.ROBOT.value: robot_mask, + SemanticMask.ROBOT_LEFT.value: left_robot_mask, + SemanticMask.ROBOT_RIGHT.value: right_robot_mask, } for mask_id in masks_ids: masks[mask_id] = mask_id_to_label[mask_id]