-
Notifications
You must be signed in to change notification settings - Fork 0
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Tree search refactor #18
Conversation
tests/test_rollouts.py
Outdated
# Then go up the tree | ||
assert tree.get_transition(f"{root_id}:0:0").value == pytest.approx( | ||
1.9, rel=0.001 | ||
) # 1 + 0.9 * avg(0, 1, 2) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
What is avg(0, 1, 2)
?
Also, instead of 1.9
, why not just put pytest.approx(1 + 0.9)
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I mean average. At this node:
- The current reward is 1
- The discount factor is 0.9
- The expected future return is the mean of 0, 1, 2 (the values checked on the preceding lines).
So the value estimate should be 1 + 0.9 * average(0, 1, 2)
. I didn't want to write this out for every node I check here, since I left a comment here:
Lines 259 to 271 in c85ec23
if children := list(self.tree.successors(step_id)): | |
# V_{t+1}(s') = sum_{a'} p(a'|s') * Q_{t+1}(s', a') | |
# Here we assume p(a'|s') is uniform. | |
# TODO: don't make that assumption where a logprob is available | |
v_tp1 = sum( | |
self.get_transition(child_id).value for child_id in children | |
) / len(children) | |
else: | |
v_tp1 = 0.0 | |
# Q_t(s_t, a_t) = r_{t+1} + gamma * V_{t+1}(s_{t+1}) | |
# (we are assuming the environment is deterministic) | |
step.value = step.reward + discount_factor * v_tp1 |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Ah okay. Imo sometimes it's a bit more understandable in tests to use a formula over a value. You could do: 1 + 0.9 * (0 + 1 + 2) / 3
, but feel free to ignore
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Good idea, will implement before merging
c85ec23
to
ff1f648
Compare
Our old
TreeSearchRollout
implicitly managed a tree as a list of trajectories, using string IDs to infer edges. That was brittle and hard to work with.This PR adds a
TransitionTree
object (essentially anx.DiGraph
wrapper) to make it easier to manipulate trees.