Skip to content
Snippets Groups Projects
Unverified Commit f66788ac authored by Million Integrals's avatar Million Integrals Committed by GitHub
Browse files

Add option to override the mujoco error callback. (#434)

Mujoco 2.0.2.3 
parent 320f195e
No related branches found
No related tags found
No related merge requests found
......@@ -53,6 +53,7 @@ cdef extern from "gl/glshim.h":
# This is the python callback function. We save it in the global() context
# so we can access it from a C wrapper function (c_warning_callback)
cdef object py_warning_callback
cdef object py_error_callback
# This is the saved exception. Because the C callback can not propagate
# exceptions, this must be set to None before calling into MuJoCo, and then
# inspected afterwards.
......@@ -99,6 +100,36 @@ def get_warning_callback():
return py_warning_callback
cdef void c_error_callback(const char *msg) with gil:
'''
Wraps the error callback so that we can pass a python function to the callback.
MuJoCo error handlers are expected to terminate the program and never return.
'''
global py_error_callback
(<object> py_error_callback)(msg)
def set_error_callback(err_callback):
'''
Set a user-defined error callback. It should take in a string message
(the warning string) and terminate the program.
See c_warning_callback, which is the C wrapper to the user defined function
'''
global py_error_callback
global mju_user_error
py_error_callback = err_callback
mju_user_error = c_error_callback
def get_error_callback():
'''
Returns the user-defined warning callback, for use in e.g. a context
manager.
'''
global py_error_callback
return py_error_callback
class wrap_mujoco_warning(object):
'''
Class to wrap capturing exceptions raised during warning callbacks.
......
......@@ -14,7 +14,7 @@ from mujoco_py import (MjSim, load_model_from_xml,
load_model_from_path, MjSimState,
ignore_mujoco_warnings,
load_model_from_mjb)
from mujoco_py import const, cymj
from mujoco_py import const, cymj, functions
from mujoco_py.tests.utils import compare_imgs
......@@ -334,6 +334,20 @@ def test_mj_warning_raises():
sim.step()
def test_mj_error_called():
error_message = None
def error_callback(msg):
nonlocal error_message
error_message = msg.decode()
cymj.set_error_callback(error_callback)
functions.mju_error("error")
assert error_message == "error"
def test_ignore_mujoco_warnings():
# Two boxes on a plane need more than 1 contact (nconmax)
xml = '''
......
__all__ = ['__version__', 'get_version']
version_info = (2, 0, 2, 2)
version_info = (2, 0, 2, 3)
# format:
# ('mujoco_major', 'mujoco_minor', 'mujoco_py_major', 'mujoco_py_minor')
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment