diff --git a/mdp.py b/mdp.py
index 738ae130b..b9a6eaea0 100644
--- a/mdp.py
+++ b/mdp.py
@@ -104,6 +104,19 @@ def check_consistency(self):
                 assert abs(s - 1) < 0.001
 
 
+class MDP2(MDP):
+
+    """Inherits from MDP. Handles terminal states, and transitions to and from terminal states better."""
+    def __init__(self, init, actlist, terminals, transitions, reward=None, gamma=0.9):
+        MDP.__init__(self, init, actlist, terminals, transitions, reward, gamma=gamma)
+
+    def T(self, state, action):
+        if action is None:
+            return [(0.0, state)]
+        else:
+            return self.transitions[state][action]
+
+
 class GridMDP(MDP):
 
     """A two-dimensional grid MDP, as in [Figure 17.1]. All you have to do is
@@ -186,7 +199,7 @@ def value_iteration(mdp, epsilon=0.001):
             U1[s] = R(s) + gamma * max(sum(p*U[s1] for (p, s1) in T(s, a))
                                        for a in mdp.actions(s))
             delta = max(delta, abs(U1[s] - U[s]))
-        if delta < epsilon*(1 - gamma)/gamma:
+        if delta <= epsilon*(1 - gamma)/gamma:
             return U
 
 