From 9ca102f0a3cd744014273191d7f79d0e9fff24ee Mon Sep 17 00:00:00 2001
From: Million Integrals <jerry@millionintegrals.com>
Date: Tue, 9 Jul 2019 09:35:56 -0700
Subject: [PATCH] Reraising mujoco error exceptions (#435)

* Reraising mujoco error exceptions
---
 mujoco_py/cymj.pyx           | 15 ++++++++++++++-
 mujoco_py/tests/test_cymj.py | 18 ++++++++++++++++++
 mujoco_py/version.py         |  2 +-
 3 files changed, 33 insertions(+), 2 deletions(-)

diff --git a/mujoco_py/cymj.pyx b/mujoco_py/cymj.pyx
index 63d9234..2cc833f 100644
--- a/mujoco_py/cymj.pyx
+++ b/mujoco_py/cymj.pyx
@@ -61,6 +61,7 @@ cdef object py_error_callback
 #   with wrap_mujoco_warning():
 #       mj_somefunc()
 cdef object py_warning_exception = None
+cdef object py_error_exception = None
 
 
 cdef void c_warning_callback(const char *msg) with gil:
@@ -106,7 +107,12 @@ cdef void c_error_callback(const char *msg) with gil:
     MuJoCo error handlers are expected to terminate the program and never return.
     '''
     global py_error_callback
-    (<object> py_error_callback)(msg)
+
+    try:
+        (<object> py_error_callback)(msg)
+    except Exception as e:
+        global py_error_exception
+        py_error_exception = e
 
 
 def set_error_callback(err_callback):
@@ -140,11 +146,18 @@ class wrap_mujoco_warning(object):
     def __enter__(self):
         global py_warning_exception
         py_warning_exception = None
+        global py_error_exception
+        py_error_exception = None
     def __exit__(self, type, value, traceback):
         global py_warning_exception
+        global py_error_exception
+
         if py_warning_exception is not None:
             raise py_warning_exception
 
+        if py_error_exception is not None:
+            raise py_error_exception
+
 
 def load_model_from_path(str path):
     """Loads model from path."""
diff --git a/mujoco_py/tests/test_cymj.py b/mujoco_py/tests/test_cymj.py
index 7a122d7..f8f6489 100644
--- a/mujoco_py/tests/test_cymj.py
+++ b/mujoco_py/tests/test_cymj.py
@@ -348,6 +348,24 @@ def test_mj_error_called():
     assert error_message == "error"
 
 
+def test_mj_error_raises():
+    def error_callback(msg):
+        raise RuntimeError(msg.decode())
+
+    cymj.set_error_callback(error_callback)
+
+    called = False
+
+    try:
+        with cymj.wrap_mujoco_warning():
+            functions.mju_error("error")
+    except RuntimeError as e:
+        assert e.args[0] == "error"
+        called = True
+
+    assert called
+
+
 def test_ignore_mujoco_warnings():
     # Two boxes on a plane need more than 1 contact (nconmax)
     xml = '''
diff --git a/mujoco_py/version.py b/mujoco_py/version.py
index 2a56a45..e29b641 100644
--- a/mujoco_py/version.py
+++ b/mujoco_py/version.py
@@ -1,6 +1,6 @@
 __all__ = ['__version__', 'get_version']
 
-version_info = (2, 0, 2, 3)
+version_info = (2, 0, 2, 4)
 # format:
 # ('mujoco_major', 'mujoco_minor', 'mujoco_py_major', 'mujoco_py_minor')
 
-- 
GitLab