4848from physics_simulator .utils .data_types import JointTrajectory
4949import time
5050import os
51+ from typing import Dict , List , Tuple , Optional , Any
52+ from dataclasses import dataclass
5153
5254from physics_simulator .utils .state_machine import SimpleStateMachine
5355
@dataclass
class DetectedObject:
    """Pose and metadata for a single object reported by a vision model."""

    # Object category label, e.g. "cube" or "bin".
    class_name: str
    # [x, y, z] translation in the camera frame.
    position: np.ndarray
    # [qx, qy, qz, qw] quaternion in the camera frame.
    orientation: np.ndarray
    # Detector confidence score.
    confidence: float
    # Optional [x1, y1, x2, y2] pixel bounding box, when the detector provides one.
    bbox: Optional[np.ndarray] = None
64+
class VisionModelInterface:
    """Interface for vision models that detect objects and estimate their
    poses in the camera frame.

    Concrete models must override :meth:`detect_objects`. (The original
    no-op ``__init__`` was removed; ``object.__init__`` suffices and
    subclasses calling ``super().__init__()`` are unaffected.)
    """

    def detect_objects(self, rgb_image: np.ndarray, depth_image: Optional[np.ndarray] = None) -> List[DetectedObject]:
        """
        Detect objects in the image and return their poses in camera frame.

        Args:
            rgb_image: RGB image from camera.
            depth_image: Depth image from camera (optional).

        Returns:
            List of detected objects with their poses in camera frame.

        Raises:
            NotImplementedError: Always; subclasses must implement this.
        """
        raise NotImplementedError("Subclass must implement detect_objects method")
86+
class DummyYoloSegmentationModel(VisionModelInterface):
    """Dummy "YOLO" segmentation model backed by simulator ground truth.

    Stands in for a real detector during development: for each supported
    class it reads the object's true world pose from the simulator and
    re-expresses it in the camera frame, with a fixed high confidence and
    a dummy bounding box.
    """

    # Prim path of the camera whose frame detections are expressed in.
    DEFAULT_CAMERA_PRIM_PATH = (
        "/World/Galbot/head_link2/head_end_effector_mount_link/front_head_rgb_camera"
    )

    def __init__(self, simulator, robot, object_classes=None, camera_prim_path=None):
        """Create the dummy model.

        Args:
            simulator: Simulator handle exposing ``get_object_state()`` and
                ``get_sensor_state()``.
            robot: Robot handle (currently unused here; kept for parity with
                real detectors that may need it).
            object_classes: Optional iterable of class names to report.
                Defaults to ``["cube", "bin"]``; prim paths are derived as
                ``f"/World/{name.capitalize()}"``.
            camera_prim_path: Optional camera prim path override; defaults
                to the front head RGB camera.
        """
        super().__init__()
        self.simulator = simulator
        self.robot = robot
        # Generalized from previously hard-coded values; the defaults
        # preserve the original behavior exactly.
        self.object_classes = ["cube", "bin"] if object_classes is None else list(object_classes)
        self.camera_prim_path = camera_prim_path or self.DEFAULT_CAMERA_PRIM_PATH

    def detect_objects(self, rgb_image: np.ndarray, depth_image: Optional[np.ndarray] = None) -> List[DetectedObject]:
        """Fake a detection pass using ground-truth object states.

        Args:
            rgb_image: RGB image (ignored; kept for interface compatibility).
            depth_image: Optional depth image (ignored).

        Returns:
            One DetectedObject per supported class, posed in the camera frame.
        """
        detected_objects = []

        for obj_class in self.object_classes:
            # Ground-truth world pose; assumes prims are named with a
            # capitalized class name (e.g. "/World/Cube", "/World/Bin").
            obj_state = self.simulator.get_object_state(f"/World/{obj_class.capitalize()}")

            # Re-express the world pose in the camera frame.
            camera_position, camera_orientation = self._world_to_camera_frame(
                obj_state["position"], obj_state["orientation"]
            )

            detected_objects.append(
                DetectedObject(
                    class_name=obj_class,
                    position=camera_position,
                    orientation=camera_orientation,
                    confidence=0.95,  # High confidence for ground truth
                    bbox=np.array([100, 100, 200, 200]),  # Dummy bbox
                )
            )

        return detected_objects

    def _world_to_camera_frame(self, world_position, world_orientation):
        """Transform a pose from the world frame into the camera frame."""
        from scipy.spatial.transform import Rotation

        # Camera pose in the world frame. NOTE(review): read from the
        # "transform_to_base_link" entry — presumably the camera extrinsics;
        # confirm against the simulator's sensor-state schema.
        camera_state = self.simulator.get_sensor_state(self.camera_prim_path)
        camera_world_position = camera_state["transform_to_base_link"]["position"]
        camera_world_orientation = camera_state["transform_to_base_link"]["orientation"]

        camera_world_rot = Rotation.from_quat(camera_world_orientation)
        world_rot = Rotation.from_quat(world_orientation)

        # Position: world-frame offset expressed in camera coordinates.
        relative_position = world_position - camera_world_position
        camera_position = camera_world_rot.inv().apply(relative_position)

        # Orientation: inverse camera rotation composed with the world rotation.
        camera_orientation = (camera_world_rot.inv() * world_rot).as_quat()

        return camera_position, camera_orientation
148+
def interpolate_joint_positions(start_positions, end_positions, steps):
    """Linearly interpolate between two joint configurations.

    Args:
        start_positions: Starting joint positions.
        end_positions: Target joint positions.
        steps: Number of waypoints, inclusive of both endpoints.

    Returns:
        A list of ``steps`` joint-position lists.
    """
    waypoints = np.linspace(start_positions, end_positions, steps)
    return waypoints.tolist()
56151
class IoaiGraspEnv:
    def __init__(self, headless=False, vision_model: Optional[VisionModelInterface] = None):
        """
        Initialize the grasp environment.

        Args:
            headless: Whether to run in headless mode (without visualization)
            vision_model: Vision model for object detection (optional). When
                None, a ground-truth-backed DummyYoloSegmentationModel is
                created once the simulator exists.
        """
        self.simulator = None
        self.robot = None

        # Vision model used for object detection (was a redundant
        # `x if x is not None else None`; plain assignment is equivalent).
        self.vision_model = vision_model

        # Vision-related state: cached detections plus a simple rate limiter.
        self.detected_objects = []
        self.last_detection_time = 0
        self.detection_interval = 0.1  # Detection frequency in seconds

        # Setup the simulator
        self._setup_simulator(headless=headless)

        # The default vision model needs live simulator/robot handles, so it
        # can only be constructed after simulator setup.
        if self.vision_model is None:
            self.vision_model = DummyYoloSegmentationModel(self.simulator, self.robot)

        # Setup the interface
        self._setup_interface()
        self._init_pose()
@@ -181,7 +290,7 @@ def _setup_simulator(self, headless=False):
181290
182291 # Add bin
183292 bin_config = MeshConfig (
184- prim_path = "/World/bin " ,
293+ prim_path = "/World/Bin " ,
185294 mjcf_path = Path ()
186295 .joinpath (self .simulator .synthnova_assets_directory )
187296 .joinpath ("synthnova_assets" )
@@ -206,7 +315,7 @@ def _setup_simulator(self, headless=False):
206315 # Initialize the simulator
207316 self .simulator .initialize ()
208317
209- bin_state = self .simulator .get_object_state ("/World/bin " )
318+ bin_state = self .simulator .get_object_state ("/World/Bin " )
210319 self .bin_position = bin_state ["position" ]
211320 self .bin_orientation = bin_state ["orientation" ]
212321
@@ -392,6 +501,135 @@ def robot_to_world_frame(self, robot_position, robot_orientation):
392501
393502 return world_position , world_orientation
394503
504+ def camera_to_world_frame (self , camera_position , camera_orientation ):
505+ """Transform pose from camera frame to world frame.
506+
507+ Args:
508+ camera_position: Position in camera frame [x, y, z]
509+ camera_orientation: Orientation in camera frame [qx, qy, qz, qw]
510+
511+ Returns:
512+ Tuple of (world_position, world_orientation) in world frame
513+ """
514+ from scipy .spatial .transform import Rotation
515+
516+ # Get camera pose in world frame
517+ camera_prim_path = self .front_head_rgb_camera_path
518+ camera_state = self .simulator .get_sensor_state (camera_prim_path )
519+ camera_world_position = camera_state ["transform_to_base_link" ]["position" ]
520+ camera_world_orientation = camera_state ["transform_to_base_link" ]["orientation" ]
521+
522+ # Create transformation matrices
523+ camera_world_rot = Rotation .from_quat (camera_world_orientation )
524+ camera_local_rot = Rotation .from_quat (camera_orientation )
525+
526+ # Transform position: rotate and add camera world position
527+ world_position = camera_world_rot .apply (camera_position ) + camera_world_position
528+
529+ # Transform orientation: compose rotations
530+ world_orientation = (camera_world_rot * camera_local_rot ).as_quat ()
531+
532+ return world_position , world_orientation
533+
534+ def world_to_camera_frame (self , world_position , world_orientation ):
535+ """Transform pose from world frame to camera frame.
536+
537+ Args:
538+ world_position: Position in world frame [x, y, z]
539+ world_orientation: Orientation in world frame [qx, qy, qz, qw]
540+
541+ Returns:
542+ Tuple of (camera_position, camera_orientation) in camera frame
543+ """
544+ from scipy .spatial .transform import Rotation
545+
546+ # Get camera pose in world frame
547+ camera_prim_path = self .front_head_rgb_camera_path
548+ camera_state = self .simulator .get_sensor_state (camera_prim_path )
549+ camera_world_position = camera_state ["position" ]
550+ camera_world_orientation = camera_state ["orientation" ]
551+
552+ # Create transformation matrices
553+ camera_world_rot = Rotation .from_quat (camera_world_orientation )
554+ world_rot = Rotation .from_quat (world_orientation )
555+
556+ # Transform position: subtract camera position and rotate
557+ relative_position = world_position - camera_world_position
558+ camera_position = camera_world_rot .inv ().apply (relative_position )
559+
560+ # Transform orientation: compose rotations
561+ camera_orientation = (camera_world_rot .inv () * world_rot ).as_quat ()
562+
563+ return camera_position , camera_orientation
564+
565+ def get_camera_images (self ):
566+ """Get RGB and depth images from the front head camera.
567+
568+ Returns:
569+ Tuple of (rgb_image, depth_image) or (rgb_image, None) if depth not available
570+ """
571+ try :
572+ # Get RGB image
573+ rgb_image = self .interface .front_head_camera .get_rgb ()
574+
575+ # Get depth image if available
576+ depth_image = None
577+ try :
578+ depth_image = self .interface .front_head_camera .get_depth ()
579+ except :
580+ pass # Depth image not available
581+
582+ return rgb_image , depth_image
583+ except Exception as e :
584+ print (f"Error getting camera images: { e } " )
585+ return None , None
586+
587+ def detect_objects_vision (self ) -> List [DetectedObject ]:
588+ """Detect objects using vision model"""
589+ current_time = time .time ()
590+
591+ # Check detection frequency
592+ if current_time - self .last_detection_time < self .detection_interval :
593+ return self .detected_objects
594+
595+ # Get camera images
596+ rgb_image , depth_image = self .get_camera_images ()
597+
598+ if rgb_image is None :
599+ return self .detected_objects
600+
601+ # Run vision model detection
602+ detected_objects = self .vision_model .detect_objects (rgb_image , depth_image )
603+
604+ # Update detection results
605+ self .detected_objects = detected_objects
606+ self .last_detection_time = current_time
607+
608+ return detected_objects
609+
610+ def get_object_pose_from_vision (self , target_class : str = "cube" ) -> Optional [Tuple [np .ndarray , np .ndarray ]]:
611+ """Get object pose from vision detection"""
612+ # Detect objects using vision
613+ detected_objects = self .detect_objects_vision ()
614+
615+ # Find target object
616+ target_object = None
617+ for obj in detected_objects :
618+ if obj .class_name .lower () == target_class .lower ():
619+ target_object = obj
620+ break
621+
622+ if target_object is None :
623+ print (f"Target object '{ target_class } ' not detected" )
624+ return None
625+
626+ # Transform from camera frame to world frame
627+ world_position , world_orientation = self .camera_to_world_frame (
628+ target_object .position , target_object .orientation
629+ )
630+
631+ return world_position , world_orientation
632+
395633 def compute_simple_ik (self , start_joint , target_pose , arm_id = "left_arm" ):
396634 """Compute inverse kinematics using Mink.
397635
@@ -669,22 +907,42 @@ def init_state():
        def move_to_pre_pick_state():
            """Move the left arm to a hover pose above the cube.

            On first entry, localize the cube with the vision model; if
            detection fails, fall back to the simulator's ground truth.
            """
            if self.state_first_entry:
                # Use vision model to detect object pose instead of ground truth
                vision_result = self.get_object_pose_from_vision("cube")
                if vision_result is not None:
                    world_pos, world_ori = vision_result
                    self.cube_position = world_pos.copy()
                    self.cube_orientation = world_ori.copy()
                    print(f"Vision detected cube at position: {world_pos}")
                else:
                    # Fallback to ground truth if vision fails
                    cube_state = self.simulator.get_object_state("/World/Cube")
                    self.cube_position = cube_state["position"].copy()
                    self.cube_orientation = cube_state["orientation"].copy()
                    print("Using ground truth fallback for cube position")
                # Localize only once per attempt.
                self.state_first_entry = False

            # Convert world frame pose to robot frame; hover 15 cm above the cube.
            world_pos = self.cube_position + np.array([0, 0, 0.15])
            # presumably [qx, qy, qz, qw] (scipy convention used elsewhere in
            # this file), i.e. ~90 deg about Y — TODO confirm.
            world_ori = np.array([0, 0.7071, 0, 0.7071])  # Fixed orientation for grasping
            robot_pos, robot_ori = self.world_to_robot_frame(world_pos, world_ori)
            return self._move_left_arm_to_pose(robot_pos, robot_ori)
681930
        def move_to_pick_state():
            """Descend to the pick pose just above the cube."""
            # Re-detect object position for more accurate pick; note the
            # detected orientation is discarded in favor of the fixed grasp
            # orientation below.
            vision_result = self.get_object_pose_from_vision("cube")
            if vision_result is not None:
                world_pos, world_ori = vision_result
                # Use detected position for more accurate pick (3 cm above).
                pick_pos = world_pos + np.array([0, 0, 0.03])
            else:
                # Fallback to stored position from the pre-pick localization.
                pick_pos = self.cube_position + np.array([0, 0, 0.03])

            # Convert world frame pose to robot frame
            world_ori = np.array([0, 0.7071, 0, 0.7071])  # Fixed orientation for grasping
            robot_pos, robot_ori = self.world_to_robot_frame(pick_pos, world_ori)
            return self._move_left_arm_to_pose(robot_pos, robot_ori)
689947
690948 def grasp_state ():
@@ -710,13 +968,24 @@ def move_to_pre_place_state():
        def move_to_place_state():
            """Move the held cube above the bin for release.

            On first entry, localize the bin with the vision model; fall back
            to the simulator's ground truth if detection fails.
            """
            if self.state_first_entry:
                # Use vision model to detect bin pose instead of ground truth
                vision_result = self.get_object_pose_from_vision("bin")
                if vision_result is not None:
                    world_pos, world_ori = vision_result
                    self.bin_position = world_pos.copy()
                    self.bin_orientation = world_ori.copy()
                    print(f"Vision detected bin at position: {world_pos}")
                else:
                    # Fallback to ground truth if vision fails
                    bin_state = self.simulator.get_object_state("/World/Bin")
                    self.bin_position = bin_state["position"].copy()
                    self.bin_orientation = bin_state["orientation"].copy()
                    print("Using ground truth fallback for bin position")
                # Localize only once per attempt.
                self.state_first_entry = False

            # Convert world frame pose to robot frame; hover 30 cm above the bin.
            world_pos = self.bin_position + np.array([0, 0, 0.3])
            world_ori = np.array([0, 0.7071, 0, 0.7071])  # Fixed orientation for placing
            robot_pos, robot_ori = self.world_to_robot_frame(world_pos, world_ori)
            return self._move_left_arm_to_pose(robot_pos, robot_ori)
722991
0 commit comments