[core][cgraph] Rework DagRef Destruction #49818
Signed-off-by: dayshah <dhyey2019@gmail.com>
6f5ca37 to 7dd356e
7dd356e to 7d10d79
python/ray/dag/compiled_dag_node.py
Outdated
and buffer all results up to that index. If the DAG has already
been executed up to the given index, just return the result
corresponding to the given index and channel.
and buffer all results (except for refs that have been destructed).
"(except for refs that have been destructed)" is hard for the user to understand, elaborate a bit more?
I updated a bit lmk if it's clearer now
will have another pass to see if this is comprehensive enough
python/ray/dag/compiled_dag_node.py
Outdated
max_finished_execution_index + 1 is in the set of destructed indices.
"""
timeout = self._get_timeout
# Keep releasing buffers while the next execution idx is in the destructed set
Explain why we are doing this rather than what the code does, something like "check if native buffers corresponding to destructed CompiledDAGRefs are ready to be released, and release as many as possible".
The main docstring of the function describes what this does, so I removed this comment.
Can you please also write down in the description the pros and cons for the 3 approaches and why this is preferred so that we don't lose context?
self._cache_execution_results(
    self._max_finished_execution_index + 1,
    result,
If a prior execute() returns two refs, ref1 and ref2, ref1 has gone out of scope, and ray.get(ref2) is called, this will cache the values for both ref1 and ref2, then pop only the value for ref2 and leave the value for ref1, which is leaked?
For this to work, I think you will need to make _destructed_execution_idxs a map from int to set of int. i.e., the value is the set of channel indexes.
Can you add a test case?
Added a test case and logic to handle this inside the cache, so we never cache if the ref for that channel idx has been destructed.
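The fix described in this thread can be illustrated with a small, self-contained sketch. It is not the code from the PR: `cache_execution_results`, `result_buffer`, and the plain dict standing in for `_destructed_execution_idxs` are hypothetical names chosen for illustration, and the only assumption is that the destructed set maps an execution index to the set of destructed channel indexes, as suggested above.

```python
from typing import Any, Dict, List, Set


def cache_execution_results(
    result_buffer: Dict[int, Dict[int, Any]],
    destructed_idxs: Dict[int, Set[int]],
    execution_idx: int,
    results: List[Any],
) -> None:
    """Cache per-channel results, skipping channels whose ref was destructed.

    Hypothetical helper: a stand-in for the caching step, not Ray's API.
    """
    destructed_channels = destructed_idxs.get(execution_idx, set())
    kept = {
        channel_idx: value
        for channel_idx, value in enumerate(results)
        if channel_idx not in destructed_channels
    }
    # Only create a cache entry if at least one ref is still alive.
    if kept:
        result_buffer[execution_idx] = kept


# Scenario from the review: execute() returned ref1 (channel 0) and
# ref2 (channel 1); ref1 went out of scope before ray.get(ref2).
result_buffer: Dict[int, Dict[int, Any]] = {}
destructed_idxs = {0: {0}}
cache_execution_results(result_buffer, destructed_idxs, 0, ["v1", "v2"])

# Simulate the get on ref2: pop its value and drop the empty entry.
value_for_ref2 = result_buffer[0].pop(1)
if not result_buffer[0]:
    del result_buffer[0]
```

Because the value for the destructed ref is never cached, nothing is left behind once the live ref's value is popped.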
Otherwise LGTM
reviewing
ctx = DAGContext.get_current()
timeout = ctx.get_timeout
def _try_release_buffers(self):
Maybe in this function we should not only call release_channel_buffers but also check the cached results and clean them up if needed.
I think they don't always have to be called at the same time.
We should never cache if the ref is already destructed, because we know that ahead of time. So the only times we should ever remove from the cache are during __del__ or get.
It’s just a safeguard. It’s fine if you prefer not to add it.
# Test that ray.get() on ref still works properly even if
# ref2 (corresponding to a later execution) is destructed first
This is not consistent with the test name? Or maybe I misunderstood the test name.
I had trouble coming up with the names for these tests lol, but the idea is that we're getting the ref that was executed first and destructing the second ref, so the ref we're getting was made before the ref we destructed. Agree the test_get_ref_before_destructed_ref name isn't easy to understand, but I can't think of a good one for these.
Others LGTM
a = Actor.remote(0)
with InputNode() as i:
    dag = a.echo.bind(i)
class TestDAGRefDestruction:
For the tests, could you use different inputs for different execute calls so that we can ensure get retrieves the expected execution index?
I think it shouldn't matter, because each call increments the counter, so the result will still be different for each execution idx.
For the first example, ref is 1 and ref2 is 2.
test_basic_destruction uses a.echo.bind(i). Maybe it's the only test that doesn't use inc.
Signed-off-by: dayshah <dhyey2019@gmail.com> Signed-off-by: Anson Qian <anson627@gmail.com>
Why are these changes needed?
There are two different approaches we can take when deserializing dagrefs.
Approach #1 is taken in #49781 and is much simpler.
The downside, however, is that destructing a dagref requires executing and caching all previous executions. We have no guarantee that Python destructs in order, so if multiple dagrefs are being destructed, we may still deserialize and cache results even when we don't need to. We are also forced to execute up to the dagref being destructed even if the user hasn't called get on previous dagrefs yet.
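To make the downside of approach #1 concrete, here is a toy model, not the code from #49781. `EagerDestructDag`, its `destruct` method, and the list standing in for channel output are all hypothetical; the sketch only assumes what the paragraph above states, namely that destructing a ref first reads and caches every earlier execution.

```python
from typing import Any, Dict, List


class EagerDestructDag:
    """Toy stand-in for approach #1 (hypothetical, not Ray's class)."""

    def __init__(self, results: List[Any]):
        # Stand-in for channel output, in execution order.
        self._results = list(results)
        self.max_finished_execution_index = -1
        self.cache: Dict[int, Any] = {}

    def _execute_until(self, execution_idx: int) -> None:
        # Read and cache every result up to and including execution_idx.
        while self.max_finished_execution_index < execution_idx:
            self.max_finished_execution_index += 1
            idx = self.max_finished_execution_index
            self.cache[idx] = self._results[idx]

    def destruct(self, execution_idx: int) -> None:
        # Destruction must first drain everything before this index.
        self._execute_until(execution_idx)
        self.cache.pop(execution_idx, None)


dag = EagerDestructDag(["r0", "r1", "r2"])
dag.destruct(2)  # the ref for the latest execution is destructed first
```

In this model, destructing the ref for execution 2 leaves the results of executions 0 and 1 deserialized and cached, even though nothing has asked for them yet.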
Approach #2 is taken in this PR. Here we hold a destructed_ref_idxs dict that maps each execution_idx to the set of channel_idxs of destructed CompiledDagRefs. We release the buffer whenever max_finished_execution_index is 1 less than a destructed execution index with a complete set of channel_idxs. There are three places we check for this: during destruction, during execute, and during get (in the loop of execute_until). The upside is that destructing a dagref never requires executing any previous dagrefs. Execution for destructed dagrefs only happens once max_finished_execution_index has reached the index before the destructed dagref. The downside is the complexity of having the buffer-release check in 3 separate places.
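The bookkeeping for approach #2 can be sketched as a toy model. This is a simplified illustration under stated assumptions, not the implementation in this PR: `ToyCompiledDag`, `mark_destructed`, `try_release_buffers`, and the `released` list are hypothetical names, and releasing a buffer is modeled as simply advancing `max_finished_execution_index`.

```python
from collections import defaultdict
from typing import DefaultDict, List, Set


class ToyCompiledDag:
    """Toy model of the destructed-ref bookkeeping (not Ray's class)."""

    def __init__(self, num_channels: int):
        self.num_channels = num_channels
        self.max_finished_execution_index = -1
        # execution_idx -> set of channel_idxs whose refs were destructed
        self.destructed_ref_idxs: DefaultDict[int, Set[int]] = defaultdict(set)
        # Execution indices whose buffers were released without being read.
        self.released: List[int] = []

    def mark_destructed(self, execution_idx: int, channel_idx: int) -> None:
        # Called when a ref goes out of scope; never forces earlier reads.
        self.destructed_ref_idxs[execution_idx].add(channel_idx)
        self.try_release_buffers()

    def try_release_buffers(self) -> None:
        # Release as long as the next unfinished execution has had
        # every one of its channel refs destructed.
        while True:
            next_idx = self.max_finished_execution_index + 1
            destructed = self.destructed_ref_idxs.get(next_idx, set())
            if len(destructed) != self.num_channels:
                break
            del self.destructed_ref_idxs[next_idx]
            self.released.append(next_idx)
            self.max_finished_execution_index = next_idx


dag = ToyCompiledDag(num_channels=2)
# Both refs of execution 1 are destructed first: nothing can be
# released yet, because execution 0 has not finished.
dag.mark_destructed(1, 0)
dag.mark_destructed(1, 1)
released_after_later_refs = list(dag.released)
# Once both refs of execution 0 are destructed, executions 0 and 1
# are released back to back.
dag.mark_destructed(0, 0)
dag.mark_destructed(0, 1)
```

In the PR description the same check also runs during execute and during get; the toy model only calls it from `mark_destructed` to keep the example short.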
Related issue number
Closes #49782
Checks
- I've signed off every commit (git commit -s) in this PR.
- I've run scripts/format.sh to lint the changes in this PR.
- If I added a method in Tune, I've added it in doc/source/tune/api/ under the corresponding .rst file.