@@ -268,7 +268,7 @@ def apply_perturbations(self):
268268 mujoco .mjv_applyPerturbPose (self .model , self .data , self .pert , 0 )
269269 mujoco .mjv_applyPerturbForce (self .model , self .data , self .pert )
270270
271- def read_pixels (self , camid = None ):
271+ def read_pixels (self , camid = None , depth = False ):
272272 if self .render_mode == 'window' :
273273 raise NotImplementedError (
274274 "Use 'render()' in 'window' mode." )
@@ -293,13 +293,20 @@ def read_pixels(self, camid=None):
293293 self .scn )
294294 # render
295295 mujoco .mjr_render (self .viewport , self .scn , self .ctx )
296+ shape = glfw .get_framebuffer_size (self .window )
297+
298+
299+ if depth :
300+ img = np .zeros ((shape [1 ], shape [0 ], 3 ), dtype = np .uint8 )
301+ depth_img = np .zeros ((shape [1 ], shape [0 ], 1 ), dtype = np .float32 )
302+ mujoco .mjr_readPixels (img , None , self .viewport , self .ctx )
303+ mujoco .mjr_readPixels (None , depth_img , self .viewport , self .ctx )
304+ return (np .flipud (img ),np .flipud (depth_img ))
305+ else :
306+ img = np .zeros ((shape [1 ], shape [0 ], 3 ), dtype = np .uint8 )
307+ mujoco .mjr_readPixels (img , None , self .viewport , self .ctx )
308+ return np .flipud (img )
296309
297- img = np .zeros (
298- (glfw .get_framebuffer_size (
299- self .window )[1 ], glfw .get_framebuffer_size (
300- self .window )[0 ], 3 ), dtype = np .uint8 )
301- mujoco .mjr_readPixels (img , None , self .viewport , self .ctx )
302- return np .flipud (img )
303310
304311 def render (self ):
305312 if self .render_mode == 'offscreen' :
0 commit comments