from mouselab.mouselab import MouselabEnv
from mouselab.envs.registry import register, registry
from mouselab.distributions import Categorical
import numpy as np
from toolz import get


class ConditionalMouselabEnv(MouselabEnv):
    """Mouselab environment with a hard-coded, conditional transition model.

    Every accepted ground truth (``_POSSIBLE_GROUND_TRUTHS``) has exactly one
    depth-1 node worth +1, and the two leaves under that node's branch are
    +100 and -100. Every other depth-1 node is -1, and every other non-root
    node is -5. ``results`` exploits this structure: once a +1, +100 or -100
    value has been revealed, the remaining values are (partly) determined.

    The tuples assume the 13-node tree produced by branching ``[3, 1, 2]``,
    as registered in the ``__main__`` demo at the bottom of this module.
    """

    # Ground truths for which the hard-coded logic in ``results`` is valid.
    _POSSIBLE_GROUND_TRUTHS = (
        (0, -1, -5, -5, -5, 1, -5, 100, -100, -1, -5, -5, -5),
        (0, -1, -5, -5, -5, 1, -5, -100, 100, -1, -5, -5, -5),
        (0, -1, -5, -5, -5, -1, -5, -5, -5, 1, -5, -100, 100),
        (0, -1, -5, -5, -5, -1, -5, -5, -5, 1, -5, 100, -100),
        (0, 1, -5, 100, -100, -1, -5, -5, -5, -1, -5, -5, -5),
        (0, 1, -5, -100, 100, -1, -5, -5, -5, -1, -5, -5, -5),
    )

    def __init__(self, *args, **kwargs):
        """Build the environment and reject ground truths ``results`` cannot model.

        All arguments are forwarded to ``MouselabEnv.__init__``.

        Raises:
            ValueError: if ``self.ground_truth`` is not one of
                ``_POSSIBLE_GROUND_TRUTHS``.
        """
        super().__init__(*args, **kwargs)
        if tuple(self.ground_truth) not in self._POSSIBLE_GROUND_TRUTHS:
            raise ValueError(
                "Ground truth does not fit in with hard coded results method"
            )

    def expected_term_reward(self, state):
        """Return the expected reward of terminating in ``state``.

        Behaviour depends on what has been revealed:

        - A +/-100 leaf: the rewarding path is fully known, giving
          96 (= 1 - 5 + 100).
        - Only the +1 node: the rewarding branch is known, but its leaves are
          a 50/50 +/-100 gamble (see ``results``), giving -4 (= 1 - 5 + 0).
        - Neither: defer to ``MouselabEnv.expected_term_reward``.
        """
        if -100 in state or 100 in state:
            return 96
        if 1 in state:
            return -4
        return super().expected_term_reward(state)

    def results(self, state, action):
        """Yield the possible results of taking ``action`` in ``state``.

        Each outcome is a ``(probability, next_state, reward)`` tuple.

        Two regimes apply:

        - Until a +1, +100 or -100 value has been revealed (and always for
          depth-2 nodes), outcomes are drawn from the node's own distribution
          ``state[action]``.
        - After that, values are deduced from the hard-coded structure:
          - Nodes off the rewarding branch are -1 (depth 1) or -5 (otherwise).
          - On the rewarding branch the depth-1 node is +1.
          - A depth-3 node is the opposite of an already-revealed +/-100, or
            +/-100 with probability 1/2 each if none is revealed yet.

        Raises:
            NotImplementedError: if ``self.include_last_action`` is set.
            AssertionError: if revealed rewarding values lie in more than one
                cluster, or the action on the rewarding branch has a depth
                other than 1, 2 or 3.
        """
        if action == self.term_action:
            # expected_term_reward encodes the hard-coded terminal values
            # (96 / -4), so delegating keeps the two from drifting apart.
            yield (1, self.term_state, self.expected_term_reward(state))
            return

        if self.include_last_action:
            # assume you are not using the distance cost or a cost that
            # depends on last action
            raise NotImplementedError(
                "ConditionalMouselabEnv does not support include_last_action"
            )

        found_reward_branch = self._found_reward_branch(state)

        # The depth lookup only happens once a rewarding branch is known,
        # matching the original short-circuit order.
        if found_reward_branch == 0 or self.mdp_graph.nodes[action]["depth"] == 2:
            # Nothing conditional known yet (or a depth-2 node): this is the
            # original problem, so sample from the node's own distribution.
            for value, prob in state[action]:
                yield (
                    prob,
                    self._with_value(state, action, value),
                    self.cost(state, action),
                )
            return

        node = self.mdp_graph.nodes[action]
        depth = node["depth"]

        if node["cluster"] != found_reward_branch:
            # Off the rewarding branch every value is fixed.
            yield self._certain(state, action, -1 if depth == 1 else -5)
        elif depth == 1:
            yield self._certain(state, action, 1)
        elif depth == 3:
            if -100 in state:
                yield self._certain(state, action, 100)
            elif 100 in state:
                yield self._certain(state, action, -100)
            else:
                for value, prob in Categorical([100, -100], [1 / 2, 1 / 2]):
                    yield (
                        prob,
                        self._with_value(state, action, value),
                        self.cost(state, action),
                    )
        else:
            raise AssertionError("Did not expect to get here")

    def _found_reward_branch(self, state):
        """Return the cluster containing a revealed +1/+100/-100 value, or 0.

        NOTE(review): 0 doubles as the "nothing found" sentinel. This assumes
        no rewarding branch uses cluster 0; confirm against ``mdp_graph``.

        Raises:
            AssertionError: if such values were revealed in more than one
                cluster.
        """
        rewarding_nodes = [
            index for index, value in enumerate(state) if value in (1, -100, 100)
        ]
        if not rewarding_nodes:
            return 0
        clusters = np.unique(
            [self.mdp_graph.nodes[index]["cluster"] for index in rewarding_nodes]
        )
        # Every accepted ground truth puts these values on a single branch.
        # Raise explicitly: a bare ``assert`` would be stripped under -O.
        if len(clusters) > 1:
            raise AssertionError(
                "Revealed rewarding values span more than one cluster"
            )
        return clusters[0]

    @staticmethod
    def _with_value(state, action, value):
        """Return ``state`` as a new tuple with ``state[action]`` set to ``value``."""
        next_state = list(state)
        next_state[action] = value
        return tuple(next_state)

    def _certain(self, state, action, value):
        """Return the probability-1 outcome that reveals ``value`` at ``action``."""
        return (1, self._with_value(state, action, value), self.cost(state, action))

    @classmethod
    def new_symmetric_registered(cls, experiment_setting, seed=None, **kwargs):
        """Build a symmetric-tree environment from a registered setting.

        Args:
            experiment_setting: name looked up with ``registry`` to obtain
                ``branching`` (children per depth) and ``reward_function``
                (a callable of depth, or a constant used for every depth).
            seed: forwarded to the constructor.
            **kwargs: forwarded to the constructor.

        Returns:
            ``cls(tree, init, seed=seed, **kwargs)``. Nodes are numbered in
            depth-first pre-order; ``tree[i]`` lists the child indices of
            node ``i`` and ``init[i]`` is the reward for that node's depth.
        """
        setting = registry(experiment_setting)
        branching = setting.branching
        reward = setting.reward_function
        if not callable(reward):
            constant_reward = reward
            reward = lambda depth: constant_reward  # noqa: E731

        init = []
        tree = []

        def expand(depth):
            # Pre-order: claim this node's index before recursing into children.
            my_idx = len(init)
            init.append(reward(depth))
            children = []
            tree.append(children)
            # get() returns 0 once depth runs past the branching list (leaves).
            for _ in range(get(depth, branching, 0)):
                children.append(expand(depth + 1))
            return my_idx

        expand(0)
        return cls(tree, init, seed=seed, **kwargs)


if __name__ == "__main__":
    register(
        name="conditional",
        branching=[3, 1, 2],
        reward_inputs="depth",
        reward_dictionary={
            1: Categorical([-1, 1], [2 / 3, 1 / 3]),
            2: Categorical([-5], [1]),
            3: Categorical([-5, +100, -100], [2 / 3, 1 / 6, 1 / 6]),
        },
    )

    # inherit, using hard coded values and symmetric environment
    env = ConditionalMouselabEnv.new_symmetric_registered("conditional")
    print(list(env.results(env._state, 1)))
    env.step(1)
    print(list(env.results(env._state, 3)))