From 6c831346dd6fe79132c478ddf14a3d0c16225cc8 Mon Sep 17 00:00:00 2001
From: Scott <scott.linderman@gmail.com>
Date: Mon, 29 Aug 2016 20:48:08 -0400
Subject: [PATCH] adding utils for relabeling state sequences

---
 pyhsmm/util/general.py | 43 ++++++++++++++++++++++++++++++++++++++++++
 1 file changed, 43 insertions(+)

diff --git a/pyhsmm/util/general.py b/pyhsmm/util/general.py
index c9a96f9..46a069f 100644
--- a/pyhsmm/util/general.py
+++ b/pyhsmm/util/general.py
@@ -314,3 +314,46 @@ def treemap(f,l):
     else:
         return f(l)
 
+### relabel by usage
+def _get_labelset(labelss):
+    """Return the set of non-NaN labels in an ndarray or a (nested) list of ndarrays."""
+    if isinstance(labelss, np.ndarray):
+        uniq = np.unique(labelss)
+        return set(uniq[~np.isnan(uniq)])
+    # Union without reduce(): reduce is not a builtin on Python 3 and fails on an empty list.
+    return set().union(*(_get_labelset(l) for l in labelss))
+
+def _get_N(labelss):
+    return 1 + int(max(_get_labelset(labelss)))  # one more than the largest non-NaN label
+
+def relabel_by_permutation(l, perm):  # new_label = perm[old_label]; NaN entries stay NaN
+    missing = np.isnan(l)
+    relabeled = np.empty_like(l)
+    relabeled[~missing] = perm[l[~missing].astype('int32')]
+    if missing.any():  # only write NaN when one is actually present
+        relabeled[missing] = np.nan
+    return relabeled
+
+def relabel_by_usage(labelss, return_mapping=False, N=None):
+    """Relabel states by decreasing usage: the most frequent label becomes 0.
+
+    labelss is a single label array or a list of them; counts are pooled over
+    all arrays and NaN entries are skipped when counting and kept as NaN.
+    N is the number of labels; when falsy it is taken from the largest label.
+    Returns the relabeled array (or list, mirroring the input form), followed
+    by perm (new_label = perm[old_label]) when return_mapping is true.
+    """
+    single = isinstance(labelss, np.ndarray)
+    seqs = [labelss] if single else labelss
+    if not N:
+        N = _get_N(seqs)
+
+    counts = sum(np.bincount(s[~np.isnan(s)].astype('int32'), minlength=N)
+                 for s in seqs)
+    # argsort of the descending-usage ordering gives each old label its rank
+    perm = np.argsort(np.argsort(counts)[::-1])
+    relabeled = [relabel_by_permutation(s, perm) for s in seqs]
+
+    # a bare ndarray in means a bare ndarray out
+    result = relabeled[0] if single else relabeled
+    return (result, perm) if return_mapping else result
-- 
GitLab