From 6c831346dd6fe79132c478ddf14a3d0c16225cc8 Mon Sep 17 00:00:00 2001 From: Scott <scott.linderman@gmail.com> Date: Mon, 29 Aug 2016 20:48:08 -0400 Subject: [PATCH] adding utils for relabeling state sequences --- pyhsmm/util/general.py | 43 ++++++++++++++++++++++++++++++++++++++++++ 1 file changed, 43 insertions(+) diff --git a/pyhsmm/util/general.py b/pyhsmm/util/general.py index c9a96f9..46a069f 100644 --- a/pyhsmm/util/general.py +++ b/pyhsmm/util/general.py @@ -314,3 +314,46 @@ def treemap(f,l): else: return f(l) +### relabel by usage +def _get_labelset(labelss): + import operator + if isinstance(labelss,np.ndarray): + labelset = np.unique(labelss) + return set(labelset[~np.isnan(labelset)]) + else: + return reduce(operator.or_,(_get_labelset(l) for l in labelss)) + +def _get_N(labelss): + return int(max(_get_labelset(labelss)))+1 + +def relabel_by_permutation(l, perm): + out = np.empty_like(l) + good = ~np.isnan(l) + out[good] = perm[l[good].astype('int32')] + if np.isnan(l).any(): + out[~good] = np.nan + return out + +def relabel_by_usage(labelss, return_mapping=False, N=None): + if isinstance(labelss, np.ndarray): + backwards_compat = True + labelss = [labelss] + else: + backwards_compat = False + + N = _get_N(labelss) if not N else N + usages = sum(np.bincount(l[~np.isnan(l)].astype('int32'),minlength=N) + for l in labelss) + perm = np.argsort(np.argsort(usages)[::-1]) + outs = [relabel_by_permutation(l,perm) for l in labelss] + + if backwards_compat: + if return_mapping: + return outs[0], perm + else: + return outs[0] + else: + if return_mapping: + return outs, perm + else: + return outs -- GitLab