conditional_mouselab.py
Authored by Valkyrie Felso
    from mouselab.mouselab import MouselabEnv
    from mouselab.envs.registry import register, registry
    from mouselab.distributions import Categorical
    import numpy as np
    from toolz import get
    
    
    class ConditionalMouselabEnv(MouselabEnv):
        def __init__(self, *args, **kwargs):
            super().__init__(*args, **kwargs)
    
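            # Node layout for branching [3, 1, 2], depth-first (see expand()
            # below): index 0 is the root; each branch occupies four
            # consecutive slots (depth-1 node, depth-2 node, two depth-3
            # leaves), i.e. indices 1-4, 5-8, and 9-12. Exactly one branch
            # carries the +1 depth-1 node and the anti-correlated +/-100
            # leaves.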
            possible_ground_truths = [
                (0, -1, -5, -5, -5, 1, -5, 100, -100, -1, -5, -5, -5),
                (0, -1, -5, -5, -5, 1, -5, -100, 100, -1, -5, -5, -5),
                (0, -1, -5, -5, -5, -1, -5, -5, -5, 1, -5, -100, 100),
                (0, -1, -5, -5, -5, -1, -5, -5, -5, 1, -5, 100, -100),
                (0, 1, -5, 100, -100, -1, -5, -5, -5, -1, -5, -5, -5),
                (0, 1, -5, -100, 100, -1, -5, -5, -5, -1, -5, -5, -5),
            ]
    
            if tuple(self.ground_truth) not in possible_ground_truths:
                raise ValueError(
                    "Ground truth is not one of the hard-coded outcomes "
                    "supported by the results method"
                )
    
        def expected_term_reward(self, state):
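            # Hard-coded expectations: once a +/-100 leaf is revealed, the best
            # path is worth 1 - 5 + 100 = 96; once only the +1 depth-1 node is
            # revealed, both unopened leaves have posterior mean 0, so the best
            # path is worth 1 - 5 + 0 = -4.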
            if (-100 in state) or (100 in state):
                return 96
            elif 1 in state:
                return -4
            else:
                return super().expected_term_reward(state)

        def results(self, state, action):
            """Yield the possible results of taking action in state.

            Each outcome is a (probability, next_state, reward) tuple.
            """
            if action == self.term_action:
                # expected_term_reward already covers the hard-coded cases
                yield (1, self.term_state, self.expected_term_reward(state))
            elif self.include_last_action:
                # assumes no distance cost or any other cost that depends on
                # the last action
                raise NotImplementedError
            else:
                # Check whether the rewarding branch has been discovered, i.e.
                # whether a depth-1 node showing +1 or a +/-100 leaf has been
                # revealed.
                found_reward_branch = 0
                found_rewarding_states = [
                    action_index
                    for action_index in range(len(state))
                    if state[action_index] in [+1, -100, +100]
                ]
                if len(found_rewarding_states) > 0:
                    found_clusters = np.unique(
                        [
                            self.mdp_graph.nodes[rewarding_state]["cluster"]
                            for rewarding_state in found_rewarding_states
                        ]
                    )
                    # all revealed rewarding nodes should lie in one cluster
                    assert len(found_clusters) <= 1
                    found_reward_branch = found_clusters[0]
    
                # if the rewarding branch has not been found yet, or the action
                # is at depth 2 (always -5), this is the original problem
                if (found_reward_branch == 0) or (
                    self.mdp_graph.nodes[action]["depth"] == 2
                ):
                    for r, p in state[action]:
                        s1 = list(state)
                        s1[action] = r
                        yield (p, tuple(s1), self.cost(state, action))
                # else we're not on the rewarding branch
                elif self.mdp_graph.nodes[action]["cluster"] != found_reward_branch:
                    if self.mdp_graph.nodes[action]["depth"] == 1:
                        s1 = list(state)
                        s1[action] = -1
                        yield (1, tuple(s1), self.cost(state, action))
                    else:
                        s1 = list(state)
                        s1[action] = -5
                        yield (1, tuple(s1), self.cost(state, action))
                # or we are on the rewarding branch
                else:
                    if self.mdp_graph.nodes[action]["depth"] == 1:
                        s1 = list(state)
                        s1[action] = 1
                        yield (1, tuple(s1), self.cost(state, action))
                    elif self.mdp_graph.nodes[action]["depth"] == 3:
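                        # leaves on the rewarding branch are anti-correlated:
                        # once one is revealed, the sibling is determined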
                        if -100 in state:
                            s1 = list(state)
                            s1[action] = 100
                            yield (1, tuple(s1), self.cost(state, action))
                        elif 100 in state:
                            s1 = list(state)
                            s1[action] = -100
                            yield (1, tuple(s1), self.cost(state, action))
                        else:
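                            # neither leaf revealed yet: 50/50 between +/-100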
                            for r, p in Categorical([100, -100], [1 / 2, 1 / 2]):
                                s1 = list(state)
                                s1[action] = r
                                yield (p, tuple(s1), self.cost(state, action))
                    else:
                        raise AssertionError("Unexpected node depth on the rewarding branch")
    
        @classmethod
        def new_symmetric_registered(cls, experiment_setting, seed=None, **kwargs):
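            """Build a symmetric tree environment from a registered experiment setting."""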
            branching = registry(experiment_setting).branching
            reward = registry(experiment_setting).reward_function
    
            if not callable(reward):
                r = reward
                reward = lambda depth: r
    
            init = []
            tree = []
    
            def expand(d):
                my_idx = len(init)
                init.append(reward(d))
                children = []
                tree.append(children)
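                # branching[d] gives the number of children at depth d;
                # toolz.get falls back to 0 (a leaf) once d is past the list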
                for _ in range(get(d, branching, 0)):
                    child_idx = expand(d + 1)
                    children.append(child_idx)
                return my_idx
    
            expand(0)
            return cls(tree, init, seed=seed, **kwargs)
    
    
    if __name__ == "__main__":
        register(
            name="conditional",
            branching=[3, 1, 2],
            reward_inputs="depth",
            reward_dictionary={
                1: Categorical([-1, 1], [2 / 3, 1 / 3]),
                2: Categorical([-5], [1]),
                3: Categorical([-5, +100, -100], [2 / 3, 1 / 6, 1 / 6]),
            },
        )
    
        # build the symmetric environment, inheriting the hard-coded
        # conditional results
        env = ConditionalMouselabEnv.new_symmetric_registered("conditional")
    
        print(list(env.results(env._state, 1)))  # outcomes before any reveal
        env.step(1)  # reveal node 1, a depth-1 node
        print(list(env.results(env._state, 3)))  # now conditioned on node 1
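
        # A minimal sanity check (a sketch; assumes node 3 is a depth-3 leaf
        # below node 1, per the depth-first layout noted above): once a +/-100
        # leaf is revealed, the hard-coded expected termination rewards apply.
        env.step(3)
        if (-100 in env._state) or (100 in env._state):
            assert env.expected_term_reward(env._state) == 96
        elif 1 in env._state:
            assert env.expected_term_reward(env._state) == -4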