conditional_mouselab.py
Authored by Valkyrie Felso
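"""A hand-tuned variant of MouselabEnv for one specific 3-1-2 environment.

The transition and termination logic is hard coded for ground truths in which
exactly one depth-1 node hides a +1 reward and that branch's two leaves hide
the +100/-100 rewards; every other node is -1 (depth 1) or -5 (depths 2-3).
"""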
from mouselab.mouselab import MouselabEnv
from mouselab.envs.registry import register, registry
from mouselab.distributions import Categorical
import numpy as np
from toolz import get
class ConditionalMouselabEnv(MouselabEnv):
def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
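        # The conditional logic below only supports these six ground truths:
        # exactly one depth-1 node is +1, the two leaves on that branch are
        # +100 and -100, and every other node is -1 (depth 1) or -5 (depths 2-3).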
possible_ground_truths = [
(0, -1, -5, -5, -5, 1, -5, 100, -100, -1, -5, -5, -5),
(0, -1, -5, -5, -5, 1, -5, -100, 100, -1, -5, -5, -5),
(0, -1, -5, -5, -5, -1, -5, -5, -5, 1, -5, -100, 100),
(0, -1, -5, -5, -5, -1, -5, -5, -5, 1, -5, 100, -100),
(0, 1, -5, 100, -100, -1, -5, -5, -5, -1, -5, -5, -5),
(0, 1, -5, -100, 100, -1, -5, -5, -5, -1, -5, -5, -5),
]
if tuple(self.ground_truth) not in possible_ground_truths:
raise ValueError(
"Ground truth does not fit in with hard coded results method"
)
def expected_term_reward(self, state):
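        # Hard-coded shortcuts for this environment's reward structure:
        # once a +/-100 leaf is revealed, the best path is worth 1 - 5 + 100 = 96;
        # if only the +1 depth-1 node is known, both leaves on that branch are
        # equally likely to be +100, so the best path is worth 1 - 5 + 0 = -4.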
if (-100 in state) or (100 in state):
return 96
elif (1 in state):
return -4
else:
return super().expected_term_reward(state)
def results(self, state, action):
"""Returns a list of possible results of taking action in state.
Each outcome is (probability, next_state, reward).
"""
if action == self.term_action:
if (-100 in state) or (100 in state):
yield (1, self.term_state, 96)
elif (1 in state):
yield (1, self.term_state, -4)
else:
yield (1, self.term_state, self.expected_term_reward(state))
elif self.include_last_action:
            # assumes neither a distance cost nor any other cost that depends on the last action
raise NotImplementedError
else:
            # Check whether the rewarding branch has already been discovered:
            # a depth-1 node revealed as +1, or a leaf revealed as +100 or -100.
found_reward_branch = 0
found_rewarding_states = [
action_index
for action_index in range(len(state))
if state[action_index] in [+1, -100, +100]
]
if len(found_rewarding_states) > 0:
                # all revealed rewarding nodes should lie in a single cluster (branch)
assert (
len(
np.unique(
[
self.mdp_graph.nodes[rewarding_state]["cluster"]
for rewarding_state in found_rewarding_states
]
)
)
<= 1
)
found_reward_branch = np.unique(
[
self.mdp_graph.nodes[rewarding_state]["cluster"]
for rewarding_state in found_rewarding_states
]
)[0]
            # If the rewarding branch has not been found yet, or the action is a
            # depth-2 node, use the node's original (unconditional) outcome distribution.
if (found_reward_branch == 0) or (
self.mdp_graph.nodes[action]["depth"] == 2
):
for r, p in state[action]:
s1 = list(state)
s1[action] = r
yield (p, tuple(s1), self.cost(state, action))
# else we're not on the rewarding branch
elif self.mdp_graph.nodes[action]["cluster"] != found_reward_branch:
if self.mdp_graph.nodes[action]["depth"] == 1:
s1 = list(state)
s1[action] = -1
yield (1, tuple(s1), self.cost(state, action))
else:
s1 = list(state)
s1[action] = -5
yield (1, tuple(s1), self.cost(state, action))
# or we are on the rewarding branch
else:
if self.mdp_graph.nodes[action]["depth"] == 1:
s1 = list(state)
s1[action] = 1
yield (1, tuple(s1), self.cost(state, action))
elif self.mdp_graph.nodes[action]["depth"] == 3:
if -100 in state:
s1 = list(state)
s1[action] = 100
yield (1, tuple(s1), self.cost(state, action))
elif 100 in state:
s1 = list(state)
s1[action] = -100
yield (1, tuple(s1), self.cost(state, action))
else:
for r, p in Categorical([100, -100], [1 / 2, 1 / 2]):
s1 = list(state)
s1[action] = r
yield (p, tuple(s1), self.cost(state, action))
else:
raise AssertionError("Did not expect to get here")
@classmethod
def new_symmetric_registered(cls, experiment_setting, seed=None, **kwargs):
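        """Build a symmetric ConditionalMouselabEnv from a registered setting.

        The tree structure and per-node reward distributions are taken from the
        registry entry named by ``experiment_setting``.
        """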
branching = registry(experiment_setting).branching
reward = registry(experiment_setting).reward_function
if not callable(reward):
r = reward
reward = lambda depth: r
init = []
tree = []
def expand(d):
my_idx = len(init)
init.append(reward(d))
children = []
tree.append(children)
for _ in range(get(d, branching, 0)):
child_idx = expand(d + 1)
children.append(child_idx)
return my_idx
expand(0)
return cls(tree, init, seed=seed, **kwargs)
if __name__ == "__main__":
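    # Register the 3-1-2 environment used above: depth-1 nodes are -1 or +1,
    # depth-2 nodes are always -5, and leaves are -5, +100, or -100.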
register(
name="conditional",
branching=[3, 1, 2],
reward_inputs="depth",
reward_dictionary={
1: Categorical([-1, 1], [2 / 3, 1 / 3]),
2: Categorical([-5], [1]),
3: Categorical([-5, +100, -100], [2 / 3, 1 / 6, 1 / 6]),
},
)
    # Instantiate the subclass (with its hard-coded conditional logic) on the registered symmetric environment.
env = ConditionalMouselabEnv.new_symmetric_registered("conditional")
print(list(env.results(env._state, 1)))
env.step(1)
print(list(env.results(env._state, 3)))
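    # A minimal sanity check (assumed usage, not part of the original script):
    # after observing node 1, the overridden expected_term_reward returns -4 if
    # node 1 was revealed as +1, and otherwise falls back to the parent computation.
    print(env.expected_term_reward(env._state))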