### relabel by usage
def _get_labelset(labelss):
    """Return the set of distinct non-NaN labels found in ``labelss``.

    ``labelss`` is either a numpy array of labels (NaN entries are ignored)
    or an iterable, possibly nested, of such arrays.  An empty iterable
    yields an empty set.
    """
    import operator
    # reduce is not a builtin on Python 3; functools provides it on 2.6+ too,
    # so importing it locally keeps this helper working on both.
    from functools import reduce

    if isinstance(labelss, np.ndarray):
        labelset = np.unique(labelss)
        return set(labelset[~np.isnan(labelset)])
    # The explicit initial value keeps reduce from raising on an empty iterable.
    return reduce(operator.or_, (_get_labelset(l) for l in labelss), set())


def _get_N(labelss):
    """Return one more than the largest label in ``labelss``.

    Returns 0 when ``labelss`` contains no non-NaN labels at all.
    """
    labelset = _get_labelset(labelss)
    if not labelset:
        return 0
    return int(max(labelset)) + 1


def relabel_by_permutation(l, perm):
    """Map every non-NaN label ``i`` in ``l`` to ``perm[i]``.

    NaN entries of ``l`` stay NaN.  The result has the same shape and dtype
    as ``l``.  ``perm`` may be any sequence convertible to a numpy array.
    """
    perm = np.asarray(perm)
    missing = np.isnan(l)
    good = ~missing
    out = np.empty_like(l)
    out[good] = perm[l[good].astype('int32')]
    # Only write NaN when there is one to write: an integer-typed ``out``
    # cannot hold NaN, and integer inputs never contain any.
    if missing.any():
        out[missing] = np.nan
    return out


def relabel_by_usage(labelss, return_mapping=False, N=None):
    """Relabel so that the most frequently used label becomes 0, the next 1, ...

    Parameters
    ----------
    labelss : ndarray or iterable of ndarrays
        Label sequence(s) holding non-negative integer values; NaN entries
        are ignored when counting and are preserved in the output.
    return_mapping : bool
        If true, also return the permutation ``perm`` with
        ``perm[old_label] == new_label``.
    N : int, optional
        Number of labels.  When ``None`` (or 0) it is inferred as the largest
        label plus one.  If a sequence contains a label >= N the usage table
        grows to fit it.

    Returns
    -------
    A single relabeled array if ``labelss`` was an ndarray, otherwise a list
    of relabeled arrays; followed by ``perm`` when ``return_mapping`` is true.
    """
    single_array = isinstance(labelss, np.ndarray)
    if single_array:
        labelss = [labelss]

    if not N:
        N = _get_N(labelss)

    # Accumulate usage counts by hand rather than with sum(): per-sequence
    # counts can differ in length when a label exceeds N, and there may be
    # no sequences (or no valid labels) at all.
    usages = np.zeros(N, dtype=np.intp)
    for l in labelss:
        valid = l[~np.isnan(l)].astype('int32')
        if valid.size == 0:
            continue
        counts = np.bincount(valid, minlength=N)
        if counts.size >= usages.size:
            counts[:usages.size] += usages
            usages = counts
        else:
            usages[:counts.size] += counts

    # argsort of the descending-usage order gives, for each old label, its rank.
    perm = np.argsort(np.argsort(usages)[::-1])
    outs = [relabel_by_permutation(l, perm) for l in labelss]

    result = outs[0] if single_array else outs
    if return_mapping:
        return result, perm
    return result