diff --git a/pyhsmm/util/stats.py b/pyhsmm/util/stats.py
index 7fde82815345c8f00b0c074f236cc6519abdc276..7705d855792bc9b03f2da82751d36d9db096b502 100644
--- a/pyhsmm/util/stats.py
+++ b/pyhsmm/util/stats.py
@@ -100,7 +100,7 @@ def whiten(datalist):
     return general.treemap(apply_whitening, datalist)
 
 def diag_whiten(datalist):
-    mu, l = mean(datalist), np.sqrt(diag_of_cov(datalist))
+    mu, l = mean(datalist), np.sqrt(np.diag(cov(datalist)))
     def apply_whitening(x):
         return (x-mu)/l + mu
     return general.treemap(apply_whitening, datalist)