Skip to content
GitLab
Explore
Sign in
Primary navigation
Search or go to…
Project
I
ihmm_behav_states
Manage
Activity
Members
Labels
Code
Merge requests
Repository
Branches
Commits
Tags
Repository graph
Compare revisions
Snippets
Build
Pipelines
Jobs
Pipeline schedules
Artifacts
Deploy
Releases
Package registry
Model registry
Operate
Environments
Terraform modules
Analyze
Value stream analytics
Contributor analytics
CI/CD analytics
Repository analytics
Model experiments
Help
Help
Support
GitLab documentation
Compare GitLab plans
Community forum
Contribute to GitLab
Provide feedback
Keyboard shortcuts
?
Snippets
Groups
Projects
Show more breadcrumbs
Sebastian Bruijns
ihmm_behav_states
Compare revisions
b96f9c76308909edd8b4fc61579877cdac7bd82b to 06f6b5be95295bafa6deb3b08395d89ecacf389f
Compare revisions
Changes are shown as if the
source
revision was being merged into the
target
revision.
Learn more about comparing revisions.
Source
sbruijns/ihmm_behav_states
Select target project
No results found
06f6b5be95295bafa6deb3b08395d89ecacf389f
Select Git revision
Swap
Target
sbruijns/ihmm_behav_states
Select target project
sbruijns/ihmm_behav_states
1 result
b96f9c76308909edd8b4fc61579877cdac7bd82b
Select Git revision
Show changes
Only incoming changes from source
Include changes to target since source was created
Compare
Commits on Source (2)
removed unused calc_state_by_sess
· 74b65abf
SebastianBruijns
authored
2 years ago
74b65abf
code cleanup, removed eg _get_leaves_color_list
· 06f6b5be
SebastianBruijns
authored
2 years ago
06f6b5be
Show whitespace changes
Inline
Side-by-side
Showing
2 changed files
dyn_glm_chain_analysis.py
+18
-91
18 additions, 91 deletions
dyn_glm_chain_analysis.py
dyn_glm_chain_analysis_unused_funcs.py
+35
-0
35 additions, 0 deletions
dyn_glm_chain_analysis_unused_funcs.py
with
53 additions
and
91 deletions
dyn_glm_chain_analysis.py
View file @
06f6b5be
...
...
@@ -441,13 +441,9 @@ class MCMC_result:
self
.
n_contrasts
=
11
self
.
cont_ticks
=
all_cont_ticks
self
.
state_to_color
=
{}
self
.
state_to_ls
=
{}
self
.
session_contrasts
=
[
np
.
unique
(
cont_mapping
(
d
[:,
0
]
-
d
[:,
1
]))
for
d
in
self
.
data
]
self
.
count_assigns
()
# self.calc_state_by_sess()
def
count_assigns
(
self
):
self
.
assign_counts
=
np
.
zeros
((
self
.
n_samples
,
self
.
n_all_states
))
...
...
@@ -457,14 +453,6 @@ class MCMC_result:
for
s
in
range
(
self
.
n_all_states
):
self
.
assign_counts
[
i
,
s
]
=
np
.
sum
(
flat_list
==
s
)
def calc_state_by_sess(self):
    """Compute the fraction of trials each proto-state occupies per session.

    Fills ``self.states_by_session`` (shape ``(n_pstates, n_sessions)``):
    entry ``[p, i]`` is the occupancy of proto-state ``p`` in session ``i``,
    averaged over the posterior samples in ``self.models``.
    """
    occupancy = np.zeros((self.n_pstates, self.n_sessions))
    for sample in self.models:
        for sess_idx, state_seq in enumerate(sample.stateseqs):
            seq_len = len(state_seq)
            for proto in self.proto_states:
                row = self.state_map[proto]
                occupancy[row, sess_idx] += np.sum(state_seq == proto) / seq_len
    # Average the accumulated per-sample occupancies over all samples.
    self.states_by_session = occupancy / self.n_samples
def
state_appearance_posterior
(
self
):
# posterior over when new states appear
state_starts
=
np
.
zeros
(
self
.
n_datapoints
)
...
...
@@ -626,9 +614,11 @@ def contrasts_plot(test, state_sets, subject, save=False, show=False, dpi='figur
if
consistencies
is
None
:
active_trials
[
relevant_trials
-
trial_counter
]
=
1
else
:
active_trials
[
relevant_trials
-
trial_counter
]
=
np
.
mean
(
consistencies
[
tuple
(
np
.
meshgrid
(
relevant_trials
,
trials
))],
axis
=
0
)
print
(
"
fix this by taking the whole array, multiply by n, subtract n, divide by n-1
"
)
input
()
active_trials
[
relevant_trials
-
trial_counter
]
=
np
.
sum
(
consistencies
[
tuple
(
np
.
meshgrid
(
relevant_trials
,
trials
))],
axis
=
0
)
active_trials
[
relevant_trials
-
trial_counter
]
-=
1
active_trials
[
relevant_trials
-
trial_counter
]
=
active_trials
[
relevant_trials
-
trial_counter
]
/
(
trials
.
shape
[
0
]
-
1
)
# print("fix this by taking the whole array, multiply by n, subtract n, divide by n-1")
# input()
label
=
"
State {}
"
.
format
(
state
)
if
np
.
sum
(
relevant_trials
)
>
0.02
*
len
(
test
.
results
[
0
].
models
[
0
].
stateseqs
[
seq_num
])
else
None
...
...
@@ -756,7 +746,6 @@ def lapse_sides(test, state_sets, indices):
def
state_development_single_sample
(
test
,
indices
,
save
=
True
,
save_append
=
''
,
show
=
True
,
dpi
=
'
figure
'
,
separate_pmf
=
False
):
session_contrasts
=
[
np
.
unique
(
cont_mapping
(
d
[:,
0
]
-
d
[:,
1
]))
for
d
in
test
.
results
[
0
].
data
]
for
i
,
m
in
enumerate
([
item
for
sublist
in
test
.
results
for
item
in
sublist
.
models
]):
if
i
not
in
indices
:
...
...
@@ -818,7 +807,7 @@ def state_development_single_sample(test, indices, save=True, save_append='', sh
session_max
=
i
defined_points
=
np
.
zeros
(
test
.
results
[
0
].
n_contrasts
,
dtype
=
bool
)
defined_points
[
session_contrasts
[
session_max
]]
=
True
defined_points
[
test
.
results
[
0
].
session_contrasts
[
session_max
]]
=
True
n_points
=
150
points
=
np
.
linspace
(
1
,
test
.
results
[
0
].
n_sessions
,
n_points
)
...
...
@@ -922,7 +911,6 @@ def state_development_single_sample(test, indices, save=True, save_append='', sh
def
state_development
(
test
,
state_sets
,
indices
,
save
=
True
,
save_append
=
''
,
show
=
True
,
dpi
=
'
figure
'
,
separate_pmf
=
False
):
state_sets
=
[
np
.
array
(
s
)
for
s
in
state_sets
]
session_contrasts
=
[
np
.
unique
(
cont_mapping
(
d
[:,
0
]
-
d
[:,
1
]))
for
d
in
test
.
results
[
0
].
data
]
if
test
.
results
[
0
].
name
.
startswith
(
'
GLM_Sim_
'
):
print
(
"
./glm sim mice/truth_{}.p
"
.
format
(
test
.
results
[
0
].
name
))
...
...
@@ -1028,7 +1016,7 @@ def state_development(test, state_sets, indices, save=True, save_append='', show
trial_counter
+=
len
(
state_seq
)
defined_points
=
np
.
zeros
(
test
.
results
[
0
].
n_contrasts
,
dtype
=
bool
)
defined_points
[
session_contrasts
[
session_max
]]
=
True
defined_points
[
test
.
results
[
0
].
session_contrasts
[
session_max
]]
=
True
if
not
separate_pmf
:
temp
=
np
.
sum
(
pmfs
[:,
defined_points
])
/
(
np
.
sum
(
defined_points
))
state_color
=
colors
[
int
(
temp
*
101
-
1
)]
...
...
@@ -1204,41 +1192,6 @@ def find_good_chains_unsplit(chains1, chains2, chains3, chains4, reduce_to=8, si
return
sol
,
r_hat_min
def find_good_chains_unsplit_fast(chains1, chains2, chains3, chains4, reduce_to=8):
    """Search exhaustively for the subset of MCMC chains with the lowest R-hat.

    Drops ``chains1.shape[0] - reduce_to`` chains by scoring every combination
    of kept chains with an amortized R-hat (per-chain means and variances are
    precomputed once, so each subset evaluation is cheap).

    Parameters
    ----------
    chains1..chains4 : arrays of shape (n_chains, n_samples)
        The same four summary statistics from parallel MCMC runs.
    reduce_to : int
        Number of chains to keep.

    Returns
    -------
    (sol, r_hat_min) : list, float
        ``sol`` is the list of REMOVED chain indices; ``r_hat_min`` is the
        best amortized R-hat found over the kept subsets.
    """
    n_chains = chains1.shape[0]
    delete_n = n_chains - reduce_to
    chains = np.stack([chains1, chains2, chains3, chains4])
    # Compute once; the original called eval_simple_r_hat twice (print + assign)
    # and also filled a `mins` array that was never read — both removed.
    r_hat = eval_simple_r_hat(chains)
    print("Without removals: {}".format(r_hat))
    l, m, n = chains.shape
    # Precompute per-chain means and (unbiased) variances for the amortized R-hat.
    psi_dot_j = np.mean(chains, axis=2)
    s_j_squared = np.sum((chains - psi_dot_j[:, :, None]) ** 2, axis=2) / (n - 1)
    r_hat_min = 10
    sol = 0
    for x in combinations(range(n_chains), n_chains - delete_n):
        r_hat = eval_amortized_r_hat(chains[:, x], psi_dot_j[:, x],
                                     s_j_squared[:, x], l, m - delete_n, n)
        if r_hat < r_hat_min:
            sol = x
        r_hat_min = min(r_hat, r_hat_min)
    # NOTE(review): if no subset scores below the initial r_hat_min of 10,
    # `sol` stays 0 and the complement below raises TypeError — preexisting
    # behavior, kept as-is; confirm R-hat values are always < 10 in practice.
    print("Minimum is {} (removed {})".format(r_hat_min, delete_n))
    # Invert the kept set into the removed set.
    sol = [i for i in range(n_chains) if i not in sol]
    print("Removed: {}".format(sol))
    r_hat_local = eval_r_hat(np.delete(chains1, sol, axis=0),
                             np.delete(chains2, sol, axis=0),
                             np.delete(chains3, sol, axis=0),
                             np.delete(chains4, sol, axis=0))
    print("Minimum over everything is {} (removed {})".format(r_hat_local, delete_n))
    return sol, r_hat_min
def
find_good_chains_unsplit_greedy
(
chains1
,
chains2
,
chains3
,
chains4
,
reduce_to
=
8
,
simple
=
False
):
delete_n
=
-
reduce_to
+
chains1
.
shape
[
0
]
mins
=
np
.
zeros
(
delete_n
+
1
)
...
...
@@ -1329,12 +1282,6 @@ if __name__ == "__main__":
chains3
=
chains3
[:,
160
:]
chains4
=
chains4
[:,
160
:]
# chains1 = chains1[:16]
# chains2 = chains2[:16]
# chains3 = chains3[:16]
# chains4 = chains4[:16]
# mins = find_good_chains(chains[:, :-1].reshape(32, chains.shape[1] // 2))
sol
,
final_r_hat
=
find_good_chains_unsplit_greedy
(
chains1
,
chains2
,
chains3
,
chains4
,
reduce_to
=
chains1
.
shape
[
0
]
//
2
)
r_hats
.
append
((
subject
,
final_r_hat
))
...
...
@@ -1361,22 +1308,6 @@ def dist_helper(dist_matrix, state_hists, inds):
return
dist_matrix
def
_get_leaves_color_list
(
R
):
# copied from latest scipy version
leaves_color_list
=
[
None
]
*
len
(
R
[
'
leaves
'
])
for
link_x
,
link_y
,
link_color
in
zip
(
R
[
'
icoord
'
],
R
[
'
dcoord
'
],
R
[
'
color_list
'
]):
for
(
xi
,
yi
)
in
zip
(
link_x
,
link_y
):
if
yi
==
0.0
:
# if yi is 0.0, the point is a leaf
# xi of leaves are 5, 15, 25, 35, ... (see `iv_ticks`)
# index of leaves are 0, 1, 2, 3, ... as below
leaf_index
=
(
int
(
xi
)
-
5
)
//
10
# each leaf has a same color of its link.
leaves_color_list
[
leaf_index
]
=
link_color
return
leaves_color_list
def
state_size_helper
(
n
=
0
,
mode_specific
=
False
):
if
not
mode_specific
:
def
nth_largest_state_func
(
x
):
...
...
@@ -1505,7 +1436,7 @@ if __name__ == "__main__":
loading_info
=
json
.
load
(
open
(
"
canonical_infos.json
"
,
'
r
'
))
r_hats
=
json
.
load
(
open
(
"
canonical_info_r_hats.json
"
,
'
r
'
))
no_good_pcas
=
[
'
NYU-06
'
,
'
SWC_023
'
]
# no good rhat: 'ibl_witten_13'
subjects
=
list
(
loading_info
.
keys
())
subjects
=
[
'
KS014
'
]
#
list(loading_info.keys())
print
(
subjects
)
fit_variance
=
[
0.03
,
0.002
,
0.0005
,
'
uniform
'
,
0
,
0.008
][
0
]
dur
=
'
yes
'
...
...
@@ -1540,7 +1471,7 @@ if __name__ == "__main__":
'
SWC_023
'
:
(
9
,
'
skip
'
),
# gradual
'
ZM_1897
'
:
(
5
,
'
right
'
),
'
ZM_3003
'
:
(
1
,
'
skip
'
)}
fig
,
ax
=
plt
.
subplots
(
1
,
3
,
sharey
=
True
,
figsize
=
(
16
,
9
))
#
fig, ax = plt.subplots(1, 3, sharey=True, figsize=(16, 9))
thinning
=
25
summary_info
=
{
"
thinning
"
:
thinning
,
"
contains
"
:
[],
"
seeds
"
:
[],
"
fit_nums
"
:
[]}
...
...
@@ -1578,22 +1509,19 @@ if __name__ == "__main__":
test
=
pickle
.
load
(
open
(
"
multi_chain_saves/canonical_result_{}_{}.p
"
.
format
(
subject
,
fit_type
),
'
rb
'
))
print
(
'
loaded canoncial result
'
)
mode_indices
=
pickle
.
load
(
open
(
"
multi_chain_saves/mode_indices_{}.p
"
.
format
(
subject
),
'
rb
'
))
state_sets
=
pickle
.
load
(
open
(
"
multi_chain_saves/state_sets_{}.p
"
.
format
(
subject
),
'
rb
'
))
# state_sets = pickle.load(open("multi_chain_saves/state_sets_{}.p".format(subject), 'rb'))
# lapse_sides(test, [s for s in state_sets if len(s) > 40], mode_indices)
# continue
#
#
mode_indices = pickle.load(open("multi_chain_saves/mode_indices_{}_{}.p".format(subject, fit_type), 'rb'))
#
state_sets = pickle.load(open("multi_chain_saves/state_sets_{}_{}.p".format(subject, fit_type), 'rb'))
#
states, pmfs = state_development(test, [s for s in state_sets if len(s) > 40], mode_indices, show=True, separate_pmf=True)
mode_indices
=
pickle
.
load
(
open
(
"
multi_chain_saves/mode_indices_{}_{}.p
"
.
format
(
subject
,
fit_type
),
'
rb
'
))
state_sets
=
pickle
.
load
(
open
(
"
multi_chain_saves/state_sets_{}_{}.p
"
.
format
(
subject
,
fit_type
),
'
rb
'
))
states
,
pmfs
=
state_development
(
test
,
[
s
for
s
in
state_sets
if
len
(
s
)
>
40
],
mode_indices
,
show
=
True
,
separate_pmf
=
True
)
# state_development_single_sample(test, [mode_indices[0]], show=True, separate_pmf=True, save=False)
# quit(
)
#
consistencies =
pickle.load(open("multi_chain_saves/consistencies_{}.p".format(subject), 'rb'))
#
con
sistencies /= consistencies[0, 0]
# contrasts_plot(test, [s for s in state_sets if len(s) > 40], subject=subject, save=True, show=True, consistencies=consistencies
)
# quit()
consistencies
=
pickle
.
load
(
open
(
"
multi_chain_saves/consistencies_{}_{}.p
"
.
format
(
subject
,
fit_type
),
'
rb
'
)
)
consistencies
/
=
consistencies
[
0
,
0
]
con
trasts_plot
(
test
,
[
s
for
s
in
state_sets
if
len
(
s
)
>
40
],
subject
=
subject
,
save
=
True
,
show
=
True
,
consistencies
=
consistencies
)
quit
(
)
# state_types = np.zeros((3, states.shape[1]))
# for s, pmf in zip(states, pmfs):
...
...
@@ -1811,7 +1739,6 @@ if __name__ == "__main__":
contrasts_plot
(
test
,
[
s
for
s
in
state_sets
if
len
(
s
)
>
40
],
subject
=
subject
,
save_append
=
'
_{}{}
'
.
format
(
string_prefix
,
criterion
),
save
=
True
,
show
=
True
)
# R = hc.dendrogram(linkage, no_plot=True, get_leaves=True, color_threshold=color_threshold)
# R["leaves_color_list"] = _get_leaves_color_list(R)
# leaves_color_list = np.array(R["leaves_color_list"])
# leaves = np.array(R["leaves"])
...
...
This diff is collapsed.
Click to expand it.
dyn_glm_chain_analysis_unused_funcs.py
View file @
06f6b5be
...
...
@@ -64,6 +64,41 @@ class MCMC_result:
return
consistency_mat
def
find_good_chains_unsplit_fast
(
chains1
,
chains2
,
chains3
,
chains4
,
reduce_to
=
8
):
delete_n
=
-
reduce_to
+
chains1
.
shape
[
0
]
mins
=
np
.
zeros
(
delete_n
+
1
)
n_chains
=
chains1
.
shape
[
0
]
chains
=
np
.
stack
([
chains1
,
chains2
,
chains3
,
chains4
])
print
(
"
Without removals: {}
"
.
format
(
eval_simple_r_hat
(
chains
)))
r_hat
=
eval_simple_r_hat
(
chains
)
mins
[
0
]
=
r_hat
l
,
m
,
n
=
chains
.
shape
psi_dot_j
=
np
.
mean
(
chains
,
axis
=
2
)
s_j_squared
=
np
.
sum
((
chains
-
psi_dot_j
[:,
:,
None
])
**
2
,
axis
=
2
)
/
(
n
-
1
)
r_hat_min
=
10
sol
=
0
for
x
in
combinations
(
range
(
n_chains
),
n_chains
-
delete_n
):
temp1
=
chains
[:,
x
]
temp2
=
psi_dot_j
[:,
x
]
temp3
=
s_j_squared
[:,
x
]
r_hat
=
eval_amortized_r_hat
(
temp1
,
temp2
,
temp3
,
l
,
m
-
delete_n
,
n
)
if
r_hat
<
r_hat_min
:
sol
=
x
r_hat_min
=
min
(
r_hat
,
r_hat_min
)
print
(
"
Minimum is {} (removed {})
"
.
format
(
r_hat_min
,
delete_n
))
sol
=
[
i
for
i
in
range
(
n_chains
)
if
i
not
in
sol
]
print
(
"
Removed: {}
"
.
format
(
sol
))
r_hat_local
=
eval_r_hat
(
np
.
delete
(
chains1
,
sol
,
axis
=
0
),
np
.
delete
(
chains2
,
sol
,
axis
=
0
),
np
.
delete
(
chains3
,
sol
,
axis
=
0
),
np
.
delete
(
chains4
,
sol
,
axis
=
0
))
print
(
"
Minimum over everything is {} (removed {})
"
.
format
(
r_hat_local
,
delete_n
))
return
sol
,
r_hat_min
if
__name__
==
"
__main__
"
:
subjects
=
list
(
loading_info
.
keys
())
...
...
This diff is collapsed.
Click to expand it.