Fix get_current_graph_task_execution_order accumulate_grads ordering #105353

Status: Closed · 2 commits
12 changes: 12 additions & 0 deletions test/test_autograd.py
@@ -3678,6 +3678,18 @@ def hook(t_):
self.assertEqual(predicted[0], grad_fns(*actual))
actual = []

# Accumulate grad node has more than one input
a = torch.tensor(1., requires_grad=True)
b = a.sin()
c = a.cos()
out = b * c
register_logging_hooks(a, b, c, out)
out.register_hook(hook)
with torch.autograd.set_multithreading_enabled(False):
out.backward()
self.assertEqual(predicted[0], grad_fns(*actual))
actual = []

# Multiple roots are also OK
a = torch.tensor(1., requires_grad=True)
b = a * 2
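The new test exercises the case where one AccumulateGrad node has more than one input edge: both `b` and `c` are functions of the single leaf `a`, so `SinBackward0` and `CosBackward0` feed the same grad accumulator. A minimal standalone illustration of that graph shape (assuming a recent PyTorch build; variable names mirror the test):

```python
import torch

# Same graph as the new test: two branches off one leaf, so the
# AccumulateGrad node for `a` has two input edges.
a = torch.tensor(1., requires_grad=True)
b = a.sin()
c = a.cos()
out = b * c

# Both branches point at the *same* grad accumulator object for `a`.
acc_via_sin = b.grad_fn.next_functions[0][0]
acc_via_cos = c.grad_fn.next_functions[0][0]
print(type(out.grad_fn).__name__)   # MulBackward0
print(acc_via_sin is acc_via_cos)   # True: one node, two input edges
```

Because that node is reachable along two paths, an execution order derived from sequence numbers alone can emit it before both of its producers have run, which is what the engine change below guards against.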
13 changes: 11 additions & 2 deletions torch/csrc/autograd/engine.cpp
@@ -415,6 +415,8 @@ std::vector<Node*> get_current_graph_task_execution_order() {
const bool check_exec_info = !task->exec_info_.empty();
std::vector<Node*> out{};
std::unordered_set<Node*> seen{};
> **soulitzer** (Contributor): With these changes, I think you might be able to delete the `seen` checking here.
>
> **Author**: Yeah, good point.
// Do a copy since we mutate it later
std::unordered_map<Node*, int> dependencies = task->dependencies_;

auto compare_seq_nr = [](Node* n1, Node* n2) {
return n1->sequence_nr() < n2->sequence_nr();
@@ -427,7 +429,9 @@ std::vector<Node*> get_current_graph_task_execution_order() {
}

// Implementation notes:
// - Don't need to count dependencies because we have sequence_nr
// - We need to count dependencies even though we have sequence_nr, because
// in the accumulate_grad case we cannot assume the outputs to have higher
// sequence_nr than the inputs
// - Don't need to check topological_nr because we have exec_info
while (!heap.empty()) {
Node* fn = heap.top();
@@ -450,7 +454,12 @@
continue;
}
}
heap.push(next_ptr);
auto it = dependencies.find(edge.function.get());
TORCH_INTERNAL_ASSERT(it != dependencies.end());
if (--it->second == 0) {
dependencies.erase(it);
heap.push(next_ptr);
}
}
}
return out;
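The change replaces push-on-first-sight with Kahn-style dependency counting: a node enters the sequence-number heap only after every one of its consumers has been popped. A simplified Python model of the fixed traversal (node names, sequence numbers, and the `execution_order` helper are hypothetical; the real engine runs this over `Node*` with `sequence_nr()` and `next_edges`):

```python
import heapq

def execution_order(roots, next_edges, sequence_nr, dependencies):
    """roots: root node names; next_edges: node -> successors (toward leaves);
    dependencies: node -> number of incoming edges in the graph task."""
    deps = dict(dependencies)          # copy, since we mutate counts below
    heap = [(-sequence_nr[r], r) for r in roots]   # max-heap on sequence_nr
    heapq.heapify(heap)
    out = []
    while heap:
        _, fn = heapq.heappop(heap)
        out.append(fn)
        for nxt in next_edges.get(fn, []):
            deps[nxt] -= 1             # one more producer of nxt is done
            if deps[nxt] == 0:         # all producers done: safe to schedule
                del deps[nxt]
                heapq.heappush(heap, (-sequence_nr[nxt], nxt))
    return out

# The graph from the new test. AccumulateGrad's sequence_nr (5 here) is
# deliberately higher than its producers', modeling the case the fixed
# comment describes: outputs need not have higher sequence_nr than inputs.
graph = {
    "MulBackward": ["SinBackward", "CosBackward"],
    "SinBackward": ["AccumulateGrad"],
    "CosBackward": ["AccumulateGrad"],
}
seq = {"MulBackward": 3, "CosBackward": 2, "SinBackward": 1, "AccumulateGrad": 5}
deps = {"SinBackward": 1, "CosBackward": 1, "AccumulateGrad": 2}
order = execution_order(["MulBackward"], graph, seq, deps)
print(order)  # ['MulBackward', 'CosBackward', 'SinBackward', 'AccumulateGrad']
```

With the old push-on-first-sight logic, `AccumulateGrad` would be pushed as soon as `CosBackward` was popped and, having the largest sequence number, would be emitted before `SinBackward`; counting dependencies defers it until both producers are done.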