From 06f2a54ab8298021bcc639c0636971624ba0a69d Mon Sep 17 00:00:00 2001
From: Scott Linderman <scott.linderman@gmail.com>
Date: Fri, 21 Oct 2016 21:42:18 -0400
Subject: [PATCH] bernoulli music generation!

---
 examples/hmm.py | 27 +++++++++++++++++++++++++--
 1 file changed, 25 insertions(+), 2 deletions(-)

diff --git a/examples/hmm.py b/examples/hmm.py
index e15810e..6ed561c 100644
--- a/examples/hmm.py
+++ b/examples/hmm.py
@@ -2,6 +2,8 @@ from __future__ import division
 from builtins import range
 import numpy as np
 np.seterr(divide='ignore') # these warnings are usually harmless for this code
+np.random.seed(0)
+
 from matplotlib import pyplot as plt
 import matplotlib
 import os
@@ -24,6 +26,7 @@ fit the model). Maybe this demo should use multinomial emissions...
 ###############
 
 data = np.loadtxt(os.path.join(os.path.dirname(__file__),'example-data.txt'))[:2500]
+T = data.shape[0]
 
 #########################
 #  posterior inference  #
@@ -40,9 +43,9 @@ obs_hypparams = {'mu_0':np.zeros(obs_dim),
                 'nu_0':obs_dim+2}
 
 ### HDP-HMM without the sticky bias
-
 obs_distns = [pyhsmm.distributions.Gaussian(**obs_hypparams) for state in range(Nmax)]
-posteriormodel = pyhsmm.models.WeakLimitHDPHMM(alpha=6.,gamma=6.,init_state_concentration=1.,
+posteriormodel = pyhsmm.models.WeakLimitHDPHMM(alpha=6.,gamma=6.,
+                                               init_state_concentration=1.,
                                    obs_distns=obs_distns)
 posteriormodel.add_data(data)
 
@@ -52,6 +55,26 @@ for idx in progprint_xrange(100):
 posteriormodel.plot()
 plt.gcf().suptitle('HDP-HMM sampled model after 100 iterations')
 
+### HDP-HMM with "sticky" initialization
+obs_distns = [pyhsmm.distributions.Gaussian(**obs_hypparams) for state in range(Nmax)]
+posteriormodel = pyhsmm.models.WeakLimitHDPHMM(alpha=6.,gamma=6.,
+                                               init_state_concentration=1.,
+                                   obs_distns=obs_distns)
+
+# Start with a "sticky" state sequence
+z_init = np.random.randint(0, Nmax, size=(T//5)).repeat(5)
+posteriormodel.add_data(data, stateseq=z_init)
+
+# Initialize the parameters of the model, holding the stateseq fixed
+for _ in progprint_xrange(10):
+    posteriormodel.resample_parameters()
+
+for idx in progprint_xrange(100):
+    posteriormodel.resample_model()
+
+posteriormodel.plot()
+plt.gcf().suptitle('HDP-HMM (sticky initialization) sampled model after 100 iterations')
+
 ### Sticky-HDP-HMM
 
 obs_distns = [pyhsmm.distributions.Gaussian(**obs_hypparams) for state in range(Nmax)]
-- 
GitLab