Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Tree search refactor #18

Merged
merged 3 commits into from
Sep 10, 2024
Merged

Tree search refactor #18

merged 3 commits into from
Sep 10, 2024

Conversation

sidnarayanan
Copy link
Collaborator

Our old TreeSearchRollout implicitly managed a tree as a list of trajectories, using string IDs to infer edges. That was brittle and hard to work with.

This PR adds a TransitionTree object (essentially a nx.DiGraph wrapper) to make it easier to manipulate trees.

ldp/data_structures.py Outdated Show resolved Hide resolved
ldp/data_structures.py Outdated Show resolved Hide resolved
ldp/data_structures.py Outdated Show resolved Hide resolved
tests/test_rollouts.py Show resolved Hide resolved
# Then go up the tree
assert tree.get_transition(f"{root_id}:0:0").value == pytest.approx(
1.9, rel=0.001
) # 1 + 0.9 * avg(0, 1, 2)
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

What is avg(0, 1, 2)?

Also, instead of 1.9, why not just put pytest.approx(1 + 0.9)

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I mean average. At this node:

  • The current reward is 1
  • The discount factor is 0.9
  • The expected future return is the mean of 0, 1, 2 (the values checked on the preceding lines).

So the value estimate should be 1 + 0.9 * average(0, 1, 2). I didn't want to write this out for every node I check here, since I left a comment here:

ldp/ldp/data_structures.py

Lines 259 to 271 in c85ec23

if children := list(self.tree.successors(step_id)):
# V_{t+1}(s') = sum_{a'} p(a'|s') * Q_{t+1}(s', a')
# Here we assume p(a'|s') is uniform.
# TODO: don't make that assumption where a logprob is available
v_tp1 = sum(
self.get_transition(child_id).value for child_id in children
) / len(children)
else:
v_tp1 = 0.0
# Q_t(s_t, a_t) = r_{t+1} + gamma * V_{t+1}(s_{t+1})
# (we are assuming the environment is deterministic)
step.value = step.reward + discount_factor * v_tp1

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Ah okay. Imo sometimes it's a bit more understandable in tests to use a formula over a value. You could do: 1 + 0.9 * (0 + 1 + 2) / 3, but feel free to ignore

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Good idea, will implement before merging

@sidnarayanan sidnarayanan merged commit 2bd6745 into main Sep 10, 2024
5 of 6 checks passed
@sidnarayanan sidnarayanan deleted the tree-search branch September 10, 2024 16:27
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

Successfully merging this pull request may close these issues.

2 participants