Skip to content

Commit 61274e8

Browse files
authored
Merge pull request #38 from rohanpsingh/topic/add-graph
Introduce 2D plotting feature
2 parents 8f6561d + a32db6b commit 61274e8

File tree

3 files changed

+111
-3
lines changed

3 files changed

+111
-3
lines changed

examples/sample.py

Lines changed: 26 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -7,12 +7,38 @@
77
# create the viewer object
88
viewer = mujoco_viewer.MujocoViewer(model, data)
99

10+
viewer.add_line_to_fig(line_name="root-pos-x", fig_idx = 0)
11+
viewer.add_line_to_fig(line_name="root-pos-z", fig_idx = 0)
12+
viewer.add_line_to_fig(line_name="right_ankle_y", fig_idx = 1)
13+
14+
# user has access to mjvFigure
15+
fig = viewer.figs[0]
16+
fig.title = "Root Position"
17+
fig.flg_legend = True
18+
fig.xlabel = "Timesteps"
19+
fig.figurergba[0] = 0.2
20+
fig.figurergba[3] = 0.2
21+
fig.gridsize[0] = 5
22+
fig.gridsize[1] = 5
23+
24+
fig = viewer.figs[1]
25+
fig.title = "Joint position"
26+
fig.flg_legend = True
27+
fig.figurergba[0] = 0.2
28+
fig.figurergba[3] = 0.2
29+
1030
# simulate and render
1131
for _ in range(100000):
32+
viewer.add_data_to_line(line_name="root-pos-x", line_data=data.qpos[0], fig_idx=0)
33+
viewer.add_data_to_line(line_name="root-pos-z", line_data=data.qpos[2], fig_idx=0)
34+
viewer.add_data_to_line(line_name="right_ankle_y", line_data=data.qpos[
35+
mujoco.mj_name2id(model, mujoco.mjtObj.mjOBJ_JOINT, "right_ankle_y")], fig_idx=1)
36+
1237
mujoco.mj_step(model, data)
1338
viewer.render()
1439
if not viewer.is_alive:
1540
break
1641

1742
# close
1843
viewer.close()
44+

mujoco_viewer/callbacks.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,7 @@ def __init__(self, hide_menus):
1818
self._last_mouse_x = 0
1919
self._last_mouse_y = 0
2020
self._paused = False
21+
self._hide_graph = False
2122
self._transparent = False
2223
self._contacts = False
2324
self._joints = False
@@ -97,6 +98,9 @@ def _key_callback(self, window, key, scancode, action, mods):
9798
self.model.geom_rgba[:, 3] /= 5.0
9899
else:
99100
self.model.geom_rgba[:, 3] *= 5.0
101+
# Toggle Graph overlay
102+
elif key == glfw.KEY_G:
103+
self._hide_graph = not self._hide_graph
100104
# Display inertia
101105
elif key == glfw.KEY_I:
102106
self._inertias = not self._inertias

mujoco_viewer/mujoco_viewer.py

Lines changed: 81 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -70,9 +70,23 @@ def __init__(
7070
self.cam = mujoco.MjvCamera()
7171
self.scn = mujoco.MjvScene(self.model, maxgeom=10000)
7272
self.pert = mujoco.MjvPerturb()
73+
7374
self.ctx = mujoco.MjrContext(
7475
self.model, mujoco.mjtFontScale.mjFONTSCALE_150.value)
7576

77+
width, height = glfw.get_framebuffer_size(self.window)
78+
79+
# figures for creating 2D plots
80+
max_num_figs = 3
81+
self.figs = []
82+
width_adjustment = width % 4
83+
fig_w, fig_h = int(width / 4), int(height / 4)
84+
for idx in range(max_num_figs):
85+
fig = mujoco.MjvFigure()
86+
mujoco.mjv_defaultFigure(fig)
87+
fig.flg_extend = 1
88+
self.figs.append(fig)
89+
7690
# load camera from configuration (if available)
7791
pathlib.Path(
7892
self.CONFIG_PATH.parent).mkdir(
@@ -118,6 +132,51 @@ def __init__(
118132
self._overlay = {}
119133
self._markers = []
120134

135+
def add_line_to_fig(self, line_name, fig_idx = 0):
136+
assert isinstance(line_name, str), \
137+
"Line name must be a string."
138+
139+
fig = self.figs[fig_idx]
140+
if line_name.encode('utf8') == b'':
141+
raise Exception(
142+
"Line name cannot be empty."
143+
)
144+
if line_name.encode('utf8') in fig.linename:
145+
raise Exception(
146+
"Line name already exists in this plot."
147+
)
148+
149+
# this assumes all lines added by user have a non-empty name
150+
linecount = fig.linename.tolist().index(b'')
151+
152+
# we want to add the line after the last non-empty index
153+
fig.linename[linecount] = line_name
154+
155+
# assign x values
156+
for i in range(mujoco.mjMAXLINEPNT):
157+
fig.linedata[linecount][2*i] = -float(i)
158+
159+
def add_data_to_line(self, line_name, line_data, fig_idx = 0):
160+
fig = self.figs[fig_idx]
161+
162+
try:
163+
_line_name = line_name.encode('utf8')
164+
linenames = fig.linename.tolist()
165+
line_idx = linenames.index(_line_name)
166+
except ValueError:
167+
raise Exception(
168+
"line name is not valid, add it to list before calling update"
169+
)
170+
171+
pnt = min(mujoco.mjMAXLINEPNT, fig.linepnt[line_idx] + 1)
172+
# shift data
173+
for i in range(pnt-1, 0, -1):
174+
fig.linedata[line_idx][2*i + 1] = fig.linedata[line_idx][2*i - 1]
175+
176+
# assign new
177+
fig.linepnt[line_idx] = pnt;
178+
fig.linedata[line_idx][1] = line_data;
179+
121180
def add_marker(self, **marker_params):
122181
self._markers.append(marker_params)
123182

@@ -205,6 +264,10 @@ def add_overlay(gridpos, text1, text2):
205264
topleft,
206265
"[J]oints",
207266
"On" if self._joints else "Off")
267+
add_overlay(
268+
topleft,
269+
"[G]raph Viewer",
270+
"Off" if self._hide_graph else "On")
208271
add_overlay(
209272
topleft,
210273
"[I]nertia",
@@ -295,7 +358,6 @@ def read_pixels(self, camid=None, depth=False):
295358
mujoco.mjr_render(self.viewport, self.scn, self.ctx)
296359
shape = glfw.get_framebuffer_size(self.window)
297360

298-
299361
if depth:
300362
rgb_img = np.zeros((shape[1], shape[0], 3), dtype=np.uint8)
301363
depth_img = np.zeros((shape[1], shape[0], 1), dtype=np.float32)
@@ -324,8 +386,10 @@ def update():
324386
self._create_overlay()
325387

326388
render_start = time.time()
327-
self.viewport.width, self.viewport.height = glfw.get_framebuffer_size(
328-
self.window)
389+
390+
width, height = glfw.get_framebuffer_size(self.window)
391+
self.viewport.width, self.viewport.height = width, height
392+
329393
with self._gui_lock:
330394
# update scene
331395
mujoco.mjv_updateScene(
@@ -355,6 +419,20 @@ def update():
355419
t1,
356420
t2,
357421
self.ctx)
422+
423+
# handle figures
424+
if not self._hide_graph:
425+
for idx, fig in enumerate(self.figs):
426+
width_adjustment = width % 4
427+
x = int(3 * width / 4) + width_adjustment
428+
y = idx * int(height / 4)
429+
viewport = mujoco.MjrRect(
430+
x, y, int(width / 4), int(height / 4))
431+
432+
has_lines = len([i for i in fig.linename if i!=b''])
433+
if has_lines:
434+
mujoco.mjr_figure(viewport, fig, self.ctx)
435+
358436
glfw.swap_buffers(self.window)
359437
glfw.poll_events()
360438
self._time_per_render = 0.9 * self._time_per_render + \

0 commit comments

Comments
 (0)