Skip to content
Snippets Groups Projects
Unverified Commit 1fe312b0 authored by Tao Xu's avatar Tao Xu Committed by GitHub
Browse files

image segmentation from mujoco-rendering (#537)


Co-authored-by: default avatarTao Xu <tao@openai.com>
parent d73ce6e9
Branches
No related tags found
No related merge requests found
from threading import Lock
from mujoco_py.generated import const
import numpy as np
cimport numpy as np
cdef class MjRenderContext(object):
"""
......@@ -127,7 +128,7 @@ cdef class MjRenderContext(object):
mjr_freeContext(&self._con)
self._set_mujoco_buffers()
def render(self, width, height, camera_id=None):
def render(self, width, height, camera_id=None, segmentation=False):
cdef mjrRect rect
rect.left = 0
rect.bottom = 0
......@@ -157,6 +158,10 @@ cdef class MjRenderContext(object):
mjv_updateScene(self._model_ptr, self._data_ptr, &self._vopt,
&self._pert, &self._cam, mjCAT_ALL, &self._scn)
if segmentation:
self._scn.flags[const.RND_SEGMENT] = 1
self._scn.flags[const.RND_IDCOLOR] = 1
for marker_params in self._markers:
self._add_marker_to_scene(marker_params)
......@@ -164,7 +169,11 @@ cdef class MjRenderContext(object):
for gridpos, (text1, text2) in self._overlay.items():
mjr_overlay(const.FONTSCALE_150, gridpos, rect, text1.encode(), text2.encode(), &self._con)
def read_pixels(self, width, height, depth=True):
if segmentation:
self._scn.flags[const.RND_SEGMENT] = 0
self._scn.flags[const.RND_IDCOLOR] = 0
def read_pixels(self, width, height, depth=True, segmentation=False):
cdef mjrRect rect
rect.left = 0
rect.bottom = 0
......@@ -173,15 +182,32 @@ cdef class MjRenderContext(object):
rgb_arr = np.zeros(3 * rect.width * rect.height, dtype=np.uint8)
depth_arr = np.zeros(rect.width * rect.height, dtype=np.float32)
cdef unsigned char[::view.contiguous] rgb_view = rgb_arr
cdef float[::view.contiguous] depth_view = depth_arr
mjr_readPixels(&rgb_view[0], &depth_view[0], rect, &self._con)
rgb_img = rgb_arr.reshape(rect.height, rect.width, 3)
cdef np.ndarray[np.npy_uint32, ndim=2] seg_img
cdef np.ndarray[np.npy_int32, ndim=2] seg_ids
ret_img = rgb_img
if segmentation:
seg_img = (rgb_img[:, :, 0] + rgb_img[:, :, 1] * (2**8) + rgb_img[:, :, 2] * (2 ** 16))
seg_img[seg_img >= (self._scn.ngeom + 1)] = 0
seg_ids = np.full((self._scn.ngeom + 1, 2), fill_value=-1, dtype=np.int32)
for i in range(self._scn.ngeom):
geom = self._scn.geoms[i]
if geom.segid != -1:
seg_ids[geom.segid + 1, 0] = geom.objtype
seg_ids[geom.segid + 1, 1] = geom.objid
ret_img = seg_ids[seg_img]
if depth:
depth_img = depth_arr.reshape(rect.height, rect.width)
return (rgb_img, depth_img)
return (ret_img, depth_img)
else:
return rgb_img
return ret_img
def read_pixels_depth(self, np.ndarray[np.float32_t, mode="c", ndim=2] buffer):
''' Read depth pixels into a preallocated buffer '''
......
......@@ -129,7 +129,7 @@ cdef class MjSim(object):
mj_step(self.model.ptr, self.data.ptr)
def render(self, width=None, height=None, *, camera_name=None, depth=False,
mode='offscreen', device_id=-1):
mode='offscreen', device_id=-1, segmentation=False):
"""
Renders view from a camera and returns image as an `numpy.ndarray`.
......@@ -161,9 +161,9 @@ cdef class MjSim(object):
render_context = self._render_context_offscreen
render_context.render(
width=width, height=height, camera_id=camera_id)
width=width, height=height, camera_id=camera_id, segmentation=segmentation)
return render_context.read_pixels(
width, height, depth=depth)
width, height, depth=depth, segmentation=segmentation)
elif mode == 'window':
if self._render_context_window is None:
from mujoco_py.mjviewer import MjViewer
......
......@@ -97,8 +97,12 @@ cdef extern from "mjvisualize.h" nogil:
mjRND_SHADOW = 0, # shadows
mjRND_WIREFRAME, # wireframe
mjRND_REFLECTION, # reflections
mjRND_FOG, # fog
mjRND_ADDITIVE, # additive
mjRND_SKYBOX, # skybox
mjRND_FOG, # fog
mjRND_HAZE, # haze
mjRND_SEGMENT, # segment
mjRND_IDCOLOR, # color
enum: mjNRNDFLAG # number of rendering flags
......
__all__ = ['__version__', 'get_version']
version_info = (2, 0, 2, 9)
version_info = (2, 0, 2, 10)
# format:
# ('mujoco_major', 'mujoco_minor', 'mujoco_py_major', 'mujoco_py_minor')
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or to comment