diff --git a/test/test_dynamic_shapes_detector.py b/test/test_dynamic_shapes_detector.py
new file mode 100644
index 00000000000..be2a6b4042d
--- /dev/null
+++ b/test/test_dynamic_shapes_detector.py
@@ -0,0 +1,215 @@
+import torch
+import torch_xla
+import test_utils
+import unittest
+
+
+class TestDynamicShapeDetector(test_utils.XlaTestCase):
+
+  def _run_and_compare(self, f, args=None, allowed_traces=None):
+    """Run f and its torch_xla.compile wrapped version, comparing the equality
+    of their results.
+
+    The wrapped version is created by calling torch_xla.compile on f,
+    limiting the number of different traces to allowed_traces.
+    """
+    optf = torch_xla.compile(f, num_different_graphs_allowed=allowed_traces)
+    args = args or []
+
+    out = f(*args)
+    optout = optf(*args)
+
+    self.assertEqual(out, optout)
+
+  def test_single(self):
+    # Test: trace a function once, when only one trace is allowed.
+
+    def foo(x):
+      return x + x
+
+    inp = torch.rand(10, device=torch_xla.device())
+    self._run_and_compare(foo, args=(inp,), allowed_traces=1)
+
+  def test_many_traces(self):
+    # Test: multiple traces of a function.
+    #
+    # Steps 0~2 and 5: create new traces.
+    # Steps 3 and 4: ensure we have already traced these paths.
+
+    def foo(x, step):
+      r0 = x + x
+      r = r0 + x
+      if step in (0, 3):
+        return r + x
+      if step in (1, 4):
+        return r * 2
+      if step == 2:
+        return r * 4
+      return r0
+
+    inp = torch.rand(10, device=torch_xla.device())
+
+    for i in range(6):
+      self._run_and_compare(foo, args=(inp, i), allowed_traces=4)
+
+  def test_trace_limit_exceeded_different_input_shape(self):
+    # Test: catch trace limit exceeded error when running the function with an
+    # input of a different shape.
+
+    allowed_traces = 1
+
+    def foo(x):
+      return x + x
+
+    inp1 = torch.rand(10, device=torch_xla.device())
+    self._run_and_compare(foo, args=(inp1,), allowed_traces=allowed_traces)
+
+    msg = """\
+.* Maximum number of different traces allowed per function exceeded: 1
+Got: [] aten::expand, xla_shape=f32[10]{0}, dynamic_dims: (), size=(10)
+Expected: [] aten::add, xla_shape=f32[10]{0}, dynamic_dims: ()"""
+
+    with self.assertRaises(RuntimeError, msg=msg):
+      inp2 = torch.rand(5, device=torch_xla.device())
+      self._run_and_compare(foo, args=(inp2,), allowed_traces=allowed_traces)
+
+  def test_trace_limit_exceeded_common_sequence_mismatch(self):
+    # Test: catch trace limit exceeded error when the common sequence (i.e.
+    # compressed path) of the trie node mismatches.
+    #
+    # Step 0: creates a trace with one node containing the add operation.
+    #
+    # Step 1: tries to create 2 child nodes with:
+    #   (i) add operation (previous trace); and
+    #   (ii) mul operation.
+    # However, it fails since we have reached the limit.
+
+    allowed_traces = 1
+
+    def foo(x, step):
+      if step == 0:
+        return x + x
+      else:
+        return x * 5
+
+    inp = torch.rand(10, device=torch_xla.device())
+    self._run_and_compare(foo, args=(inp, 0), allowed_traces=allowed_traces)
+
+    msg = """\
+.* Maximum number of different traces allowed per function exceeded: 1
+Got: [] aten::mul, xla_shape=f32[10]{0}, dynamic_dims: ()
+Expected: [] aten::add, xla_shape=f32[10]{0}, dynamic_dims: ()"""
+
+    with self.assertRaises(RuntimeError, msg=msg):
+      self._run_and_compare(foo, args=(inp, 1), allowed_traces=allowed_traces)
+
+  def test_trace_limit_exceeded_children_mismatch(self):
+    # Test: catch trace limit exceeded error when the expected child of the
+    # trie node mismatches.
+    #
+    # Step 0: creates a trace with one node containing 3 operations, the last
+    # being a mul operation.
+    #
+    # Step 1: creates another trace by splitting the node, creating 2 other
+    # child nodes containing the differing operations at the end:
+    #   (i) mul operation; and
+    #   (ii) add operation.
+    #
+    # Step 2: tries to create a 3rd child node: div operation. However, we
+    # can't do it, since we have reached the limit.
+
+    allowed_traces = 2
+
+    def foo(x, step):
+      r = x + x
+      if step == 0:
+        return r * 2
+      if step == 1:
+        return r + x
+      return r / 3
+
+    inp = torch.rand(10, device=torch_xla.device())
+    self._run_and_compare(foo, args=(inp, 0), allowed_traces=allowed_traces)
+    self._run_and_compare(foo, args=(inp, 1), allowed_traces=allowed_traces)
+
+    msg = """\
+.* Maximum number of different traces allowed per function exceeded: 2
+Got: [] aten::expand, xla_shape=f32[10]{0}, dynamic_dims: (), size=(10)
+Expected either of:
+  - [] aten::mul, xla_shape=f32[10]{0}, dynamic_dims: ()
+  - [] aten::add, xla_shape=f32[10]{0}, dynamic_dims: ()"""
+
+    with self.assertRaises(RuntimeError, msg=msg):
+      self._run_and_compare(foo, args=(inp, 2), allowed_traces=allowed_traces)
+
+  def test_trace_limit_exceeded_common_sequence_early_stop(self):
+    # Test: catch trace limit exceeded error when the trace ends unexpectedly
+    # in the common sequence.
+    #
+    # Step 0: creates a trace with one node containing 3 operations.
+    #
+    # Step 1: at the end of this trace, it tries to create a new node
+    # containing the remaining operations of the previous trace, i.e. the mul
+    # operation. However, it fails because we have reached the limit.
+
+    allowed_traces = 1
+
+    def foo(x, mul=False):
+      r = x + x
+      if mul:
+        return r * 10
+      else:
+        return r
+
+    inp = torch.rand(10, device=torch_xla.device())
+    self._run_and_compare(foo, args=(inp, True), allowed_traces=allowed_traces)
+
+    msg = """\
+.* Maximum number of different traces allowed per function exceeded: 1
+Reached the end of the function at: [] aten::add, xla_shape=f32[10]{0}, dynamic_dims: ()
+Expected: [] aten::mul, xla_shape=f32[10]{0}, dynamic_dims: ()"""
+
+    with self.assertRaises(RuntimeError, msg=msg):
+      self._run_and_compare(
+          foo, args=(inp, False), allowed_traces=allowed_traces)
+
+  def test_trace_limit_exceeded_children_early_stop(self):
+    # Test: catch trace limit exceeded error when the trace ends unexpectedly
+    # at a fork point (i.e. the next operation would jump to another trie
+    # node).
+    #
+    # Step 0: creates a trace with one node containing 3 operations.
+    #
+    # Step 1: splits the node, creating 2 child nodes containing:
+    #   (i) the differing operations from the last trace, i.e. mul operation
+    #   (ii) the current last operation, i.e. add operation
+    #
+    # Step 2: at the end of this trace, it tries to turn the current trie node
+    # into a new trace. However, it fails since we have reached the limit.
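+    #
+    # As a rough illustration (hypothetical trie contents, the real node
+    # contents may differ), after steps 0 and 1 the trie looks something like:
+    #
+    #   [add, ...] --+-- [mul, ...]   (trace from step 0)
+    #                +-- [add]        (trace from step 1)
+    #
+    # Step 2 stops right at the fork point, so committing it would require a
+    # 3rd trace.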
+
+    allowed_traces = 2
+
+    def foo(x, step):
+      r = x + x
+      if step == 0:
+        return r * 2
+      if step == 1:
+        return r + x
+      return r
+
+    inp = torch.rand(10, device=torch_xla.device())
+    self._run_and_compare(foo, args=(inp, 0), allowed_traces=allowed_traces)
+    self._run_and_compare(foo, args=(inp, 1), allowed_traces=allowed_traces)
+
+    msg = """\
+.* Maximum number of different traces allowed per function exceeded: 2
+Reached the end of the function at: [] aten::add, xla_shape=f32[10]{0}, dynamic_dims: ()
+Expected either of:
+  - [] aten::mul, xla_shape=f32[10]{0}, dynamic_dims: ()
+  - [] aten::add, xla_shape=f32[10]{0}, dynamic_dims: ()"""
+
+    with self.assertRaises(RuntimeError, msg=msg):
+      self._run_and_compare(foo, args=(inp, 2), allowed_traces=allowed_traces)
+
+
+if __name__ == "__main__":
+  unittest.main()
diff --git a/torch_xla/csrc/BUILD b/torch_xla/csrc/BUILD
index b9c1ff1c397..89fefda457f 100644
--- a/torch_xla/csrc/BUILD
+++ b/torch_xla/csrc/BUILD
@@ -287,11 +287,13 @@ ptxla_cc_library(
 ptxla_cc_library(
     name = "ir",
     srcs = [
+        "dynamic_shape_detector.cpp",
         "ir.cpp",
         "lowering_context.cpp",
         "stack_frame_index_builder.cpp",
     ],
     hdrs = [
+        "dynamic_shape_detector.h",
         "ir.h",
         "lowering_context.h",
         "stack_frame_index_builder.h",
diff --git a/torch_xla/csrc/dynamic_shape_detector.cpp b/torch_xla/csrc/dynamic_shape_detector.cpp
new file mode 100644
index 00000000000..1a7518d7a5e
--- /dev/null
+++ b/torch_xla/csrc/dynamic_shape_detector.cpp
@@ -0,0 +1,261 @@
+#include "torch_xla/csrc/dynamic_shape_detector.h"
+
+#include <sstream>
+
+#include "torch_xla/csrc/runtime/debug_macros.h"
+
+namespace torch_xla {
+
+// Maximum number of allowed traces per function (i.e. session).
+static std::size_t max_allowed_traces_per_function = 1;
+
+TrieNode::TrieNode(absl::Span<const TrieValue> common_sequence,
+                   bool is_trace_boundary)
+    : common_sequence_(common_sequence.begin(), common_sequence.end()),
+      is_trace_boundary_(is_trace_boundary) {}
+
+bool TrieNode::IsLeaf() const { return children_.empty(); }
+
+void TrieNode::NewTraceNotAllowedError(std::optional<TrieValue> value,
+                                       std::size_t matched) {
+  std::ostringstream ostr;
+  ostr << "Maximum number of different traces allowed per function exceeded: "
+       << max_allowed_traces_per_function << std::endl;
+
+  if (value.has_value()) {
+    ostr << "Got: " << value->str << std::endl;
+  } else {
+    ostr << "Reached the end of the function at: "
+         << common_sequence_[matched - 1].str << std::endl;
+  }
+
+  if (common_sequence_.size() > matched) {
+    ostr << "Expected: " << common_sequence_[matched].str << std::endl;
+  } else {
+    ostr << "Expected either of:" << std::endl;
+    for (auto& pair : children_) {
+      ostr << "  - " << pair.second->common_sequence_.front().str << std::endl;
+    }
+  }
+
+  XLA_ERROR() << ostr.str();
+}
+
+bool TrieNode::MarkTraceBoundary(std::size_t matched, bool allow_new_trace) {
+  // No need to do anything here if either:
+  //
+  //   1. nothing was matched yet; or
+  //
+  //   2. we matched everything in this node, and this node is already marked
+  //      as a trace boundary.
+  if (matched == 0 ||
+      (common_sequence_.size() == matched && is_trace_boundary_)) {
+    return false;
+  }
+
+  // From this point on, we will create a new trace.
+  if (!allow_new_trace) {
+    // Raise an error if we have reached the maximum number of traces.
+    NewTraceNotAllowedError(std::nullopt, matched);
+  }
+
+  // If we haven't matched everything in this node, we will have to split
+  // this node. The newly created node will contain the suffix
+  // (common_sequence_ after matched), and this node (i.e. the existing one)
+  // will contain the prefix.
+  if (common_sequence_.size() != matched) {
+    MaybeSplitAt(matched);
+  }
+
+  // Finally, mark this node as a trace boundary.
+  is_trace_boundary_ = true;
+
+  return true;
+}
+
+TrieBuilder TrieNode::AddValue(TrieValue value, std::size_t matched,
+                               bool allow_new_trace) {
+  TF_VLOG(5) << "Adding value: " << value.str << " (" << value.hash << ")";
+
+  // If this node has no children and is not marked as a trace boundary, it
+  // means that TrieBuilder created this node and is incrementally adding
+  // TrieValue to it. Therefore, we just need to keep doing so.
+  if (IsLeaf() && !is_trace_boundary_) {
+    common_sequence_.push_back(value);
+    return {this, matched + 1};
+  }
+
+  // If common_sequence_ still has more elements to be matched, try to match
+  // it with value. If we succeed, we simply increment the number of matched
+  // elements.
+  if (common_sequence_.size() > matched &&
+      common_sequence_[matched].hash == value.hash) {
+    return {this, matched + 1};
+  }
+
+  // If we have matched every element in this node, try to find a child that
+  // corresponds to the given value. If we find it, return a TrieBuilder with
+  // the node found, and set matched_ to 1.
+  if (common_sequence_.size() == matched &&
+      children_.find(value.hash) != children_.end()) {
+    return {children_[value.hash].get(), 1};
+  }
+
+  // Otherwise, we will have to create a new trace. So, first, check whether
+  // we are allowed to do so.
+  if (!allow_new_trace) {
+    NewTraceNotAllowedError(value, matched);
+  }
+
+  // Maybe split the current node into: prefix (before matched) and suffix
+  // (after matched).
+  bool did_split = MaybeSplitAt(matched);
+
+  // Create a new node that contains only the given value.
+  std::unique_ptr<TrieNode> node =
+      std::make_unique<TrieNode>(absl::Span<const TrieValue>(&value, 1));
+
+  // Associate the given value with the created node in the children's map.
+  children_[value.hash] = std::move(node);
+
+  TF_VLOG(5) << "Created new node " << children_[value.hash].get()
+             << " for value: " << value.str << " (" << value.hash << ")";
+
+  // Unmark this node as a trace boundary iff we actually split this node
+  // (i.e. the suffix actually had something). Otherwise, this should still
+  // be a trace boundary.
+  if (did_split) {
+    is_trace_boundary_ = false;
+  }
+
+  return {children_[value.hash].get(), 1};
+}
+
+bool TrieNode::MaybeSplitAt(std::size_t matched) {
+  // Split common_sequence_ into prefix (before matched) and suffix (after
+  // matched). Note that these variables are spans, i.e. they don't own their
+  // contents.
+  absl::Span<const TrieValue> common_sequence(common_sequence_);
+  absl::Span<const TrieValue> prefix = common_sequence.subspan(0, /*len=*/matched);
+  absl::Span<const TrieValue> suffix = common_sequence.subspan(matched);
+
+  // A split only occurs if the suffix is not empty.
+  bool did_split = !suffix.empty();
+  if (did_split) {
+    std::unique_ptr<TrieNode> suffix_node =
+        std::make_unique<TrieNode>(suffix, is_trace_boundary_);
+
+    // The suffix node's children should be what this node's children were
+    // before the split. Therefore, we swap them.
+    std::swap(children_, suffix_node->children_);
+
+    // Create the children_ map entry for the newly created suffix node.
+    children_[suffix.front().hash] = std::move(suffix_node);
+
+    TF_VLOG(5) << "Split node " << children_[suffix.front().hash].get()
+               << " at position " << matched << ": " << suffix.front().str
+               << " (" << suffix.front().hash << ")";
+  }
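+
+  // As a hypothetical illustration: if common_sequence_ was [a, b, c, d] and
+  // matched was 2, the new suffix node gets [c, d] (inheriting this node's
+  // children and trace-boundary flag), while this node keeps [a, b] and ends
+  // up with the suffix node as its only child.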
+
+  // This node's common_sequence_ will be whatever the prefix was.
+  common_sequence_ = std::vector<TrieValue>(prefix.begin(), prefix.end());
+
+  return did_split;
+}
+
+DynamicShapeDetector* DynamicShapeDetector::Get() {
+  static DynamicShapeDetector ds_detector = DynamicShapeDetector();
+  return &ds_detector;
+}
+
+void DynamicShapeDetector::StartSession(const std::string& name) {
+  if (session_infos_.find(name) == session_infos_.end()) {
+    // Create a new session, with a fresh TrieNode.
+    session_infos_[name] = {name, std::make_unique<TrieNode>(), 0};
+    TF_VLOG(5) << "Created new session: " << name;
+  }
+  current_session_ = &session_infos_[name];
+  TF_VLOG(5) << "Started session: " << name;
+  RootBuilder();
+}
+
+void DynamicShapeDetector::SetMaxAllowedTraces(std::size_t value) {
+  max_allowed_traces_per_function = value;
+}
+
+std::size_t DynamicShapeDetector::GetMaxAllowedTraces() {
+  return max_allowed_traces_per_function;
+}
+
+bool DynamicShapeDetector::IsSessionActive() {
+  return current_session_ != nullptr;
+}
+
+bool DynamicShapeDetector::AllowNewTrace() {
+  XLA_CHECK(IsSessionActive());
+  return current_session_->traces_ < max_allowed_traces_per_function;
+}
+
+void DynamicShapeDetector::EndSession() {
+  XLA_CHECK(IsSessionActive());
+
+  try {
+    // Mark the current builder_ node as a trace boundary.
+    // If we did create a new trace, increment the session's trace count.
+    if (builder_.MarkTraceBoundary(AllowNewTrace())) {
+      current_session_->traces_++;
+      TF_VLOG(5) << "Created new trace.";
+    }
+
+    // Log before ResetSession(), since it sets current_session_ to nullptr.
+    TF_VLOG(5) << "Ended session: " << current_session_->name_;
+    ResetSession();
+  } catch (const std::exception& e) {
+    // MarkTraceBoundary might raise an exception if AllowNewTrace() is false.
+    // Catch it here, so that we can correctly end the session.
+    ResetSession();
+    throw;
+  }
+}
+
+void DynamicShapeDetector::ResetSession() {
+  current_session_ = nullptr;
+  builder_ = {};
+}
+
+void DynamicShapeDetector::RootBuilder() {
+  builder_ = current_session_->NewBuilder();
+}
+
+void DynamicShapeDetector::AddNodeInfo(torch::lazy::hash_t hash,
+                                       const std::string& str) {
+  XLA_CHECK(current_session_ != nullptr);
+
+  try {
+    builder_.AddValue({hash, str}, AllowNewTrace());
+  } catch (const std::exception& e) {
+    // AddValue might raise an exception if AllowNewTrace() is false. Catch it
+    // here, so that we can correctly return the builder to the root of the
+    // trie.
+    //
+    // TODO(ysiraichi): we should actually rollback this trace.
+    RootBuilder();
+    throw;
+  }
+}
+
+void DynamicShapeDetector::RemoveSessionIfExists(const std::string& name) {
+  std::size_t removed = session_infos_.erase(name);
+  if (removed == 1) {
+    TF_VLOG(5) << "Removed session: " << name;
+  }
+}
+
+TrieBuilder SessionInfo::NewBuilder() { return {root_.get(), 0}; }
+
+void TrieBuilder::AddValue(TrieValue value, bool allow_new_trace) {
+  *this = node_->AddValue(value, matched_, allow_new_trace);
+}
+
+bool TrieBuilder::MarkTraceBoundary(bool allow_new_trace) {
+  return node_->MarkTraceBoundary(matched_, allow_new_trace);
+}
+
+}  // namespace torch_xla
diff --git a/torch_xla/csrc/dynamic_shape_detector.h b/torch_xla/csrc/dynamic_shape_detector.h
new file mode 100644
index 00000000000..623814fb33c
--- /dev/null
+++ b/torch_xla/csrc/dynamic_shape_detector.h
@@ -0,0 +1,246 @@
+#ifndef XLA_TORCH_XLA_CSRC_DYNAMIC_SHAPE_DETECTOR_H_
+#define XLA_TORCH_XLA_CSRC_DYNAMIC_SHAPE_DETECTOR_H_
+
+#include <torch/csrc/lazy/core/hash.h>
+
+#include <map>
+#include <memory>
+#include <optional>
+#include <string>
+#include <unordered_map>
+#include <vector>
+
+#include "absl/types/span.h"
+
+namespace torch_xla {
+
+struct TrieNode;
+
+// Unit of information stored inside the trie.
+struct TrieValue {
+  // Unique value that identifies an IR node.
+  torch::lazy::hash_t hash;
+
+  // String representation of the IR node.
+  std::string str;
+};
+
+// Helper struct for iteratively building the trie.
+//
+// This structure keeps track of the current state (i.e. TrieNode) and of how
+// much of its common_sequence_ we have matched up to this point (i.e.
+// matched_).
+//
+// Upon calling AddValue, we might change the current state of this builder by
+// increasing the number of matched elements, or by jumping to another
+// TrieNode.
+//
+// You can think of it as if we were jumping around in a DFA, and this struct
+// allows us to walk through and modify said DFA.
+//
+// Main assumptions for every TrieBuilder
+// ======================================
+//
+//   1. matched_ will be, at most, the size of the node's common_sequence_.
+//   2. matched_ will be 0 only in the beginning, for the root node.
+struct TrieBuilder {
+  // Wrappers around the currently pointed-to TrieNode's methods.
+  void AddValue(TrieValue value, bool allow_new_trace);
+  bool MarkTraceBoundary(bool allow_new_trace);
+
+  // Current TrieNode.
+  TrieNode* node_;
+
+  // Number of matched elements in the current node.
+  std::size_t matched_;
+};
+
+// Implementation of a compressed trie.
+//
+// Main idea
+// =========
+//
+// The main interface for interacting with TrieNode is TrieBuilder. We start
+// from the root, incrementally accepting new TrieValue by calling AddValue.
+// Said function will incrementally build the trie. Finally, MarkTraceBoundary
+// will set is_trace_boundary_ and maybe split the current node (if we haven't
+// matched everything in this node's common_sequence_).
+//
+// Main assumption for every TrieNode
+// ==================================
+//
+// Except for the root, common_sequence_ will always have size at least 1. The
+// first element always corresponds to the TrieValue that was used to go from
+// the parent node to this one.
+//
+// Examples
+// ========
+//
+// 1. In an empty trie, we start with TrieBuilder {root, 0}, where root
+// corresponds to the empty trie's only node. Upon calling AddValue, the given
+// value will be appended to common_sequence_. Finally, we return a new
+// TrieBuilder {root, 1} (we have one match!).
+//
+// 2. Consider the TrieBuilder {root, 20}. If AddValue is called, and
+// common_sequence_'s size is greater than 20, we check whether the given
+// value is the same as common_sequence_[20]. If so, we do nothing but return
+// an updated TrieBuilder {root, 21} (we have matched one extra value).
+//
+// 3. Consider the TrieBuilder {root, 20}. If AddValue is called, and
+// common_sequence_'s size is exactly 20, we check whether the given value is
+// one of this node's children. If not, we create a new TrieNode, and add it
+// to the children's map. Finally, we return a new TrieBuilder {newnode, 1}
+// (the 1st element of newnode is the value that was responsible for creating
+// it).
+//
+// 4. Consider the TrieBuilder {root, 20}. If AddValue is called with a value
+// that does not match common_sequence_[20], and the node's common_sequence_
+// has size 35, root will be split. As a result, 2 nodes will be created: (i)
+// a node containing the remaining unmatched 15 elements; and (ii) a node
+// containing the given TrieValue. The returned TrieBuilder will be
+// {node (ii), 1}.
+//
+// 5. Consider the TrieBuilder {root, 20}. If MarkTraceBoundary is called, and
+// root is a leaf (i.e. no children), then root.is_trace_boundary_ is set to
+// true.
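+//
+// As a rough sketch of example 4 (hypothetical values v1..v35, plus the new
+// value v), the split looks like:
+//
+//   before:  root = [v1 ... v35]
+//   after:   root = [v1 ... v20] --+-- [v21 ... v35]   (old suffix)
+//                                  +-- [v]             (given value)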
+struct TrieNode {
+  using ChildrenMap =
+      std::map<torch::lazy::hash_t, std::unique_ptr<TrieNode>>;
+
+  TrieNode(absl::Span<const TrieValue> common_sequence = {},
+           bool is_trace_boundary = false);
+
+  // Maybe adds TrieValue to this TrieNode.
+  //
+  // This function is used to iteratively construct the trie. It does 2
+  // things.
+  //
+  // First, it checks whether the given value actually matches the values
+  // already inside this node, i.e. this trace was seen before. For example,
+  // the given value may match the value inside common_sequence_ (after
+  // `matched` elements) or one of the children (if `matched` equals the size
+  // of common_sequence_).
+  //
+  // Then, if the given value is not inside this node, we have to add it by
+  // either:
+  //
+  //   1. adding it to the common_sequence_
+  //   2. adding it to the children_
+  //   3. splitting this node, creating 2 new nodes containing: (i) the rest
+  //      of the unmatched common_sequence_; and (ii) the given value.
+  TrieBuilder AddValue(TrieValue value, std::size_t matched,
+                       bool allow_new_trace);
+
+  // Marks this node as a trace boundary.
+  //
+  // Given the number of `matched` elements in the common_sequence_, this
+  // function sets `is_trace_boundary_` and possibly moves the rest of the
+  // unmatched common_sequence_ to a new node.
+  //
+  // Returns whether a new trace was created.
+  bool MarkTraceBoundary(std::size_t matched, bool allow_new_trace);
+
+  // Issues an error indicating a new trace is not allowed.
+  //
+  // This function will inspect the TrieNode, building an informative error
+  // message.
+  void NewTraceNotAllowedError(std::optional<TrieValue> value,
+                               std::size_t matched);
+
+  // Maybe splits this node into 2, containing, respectively: (i)
+  // common_sequence_ before `matched`; and (ii) common_sequence_ after
+  // `matched`.
+  //
+  // If `matched` is 0, it means that we have a divergence from the start. The
+  // created suffix node will be a copy of this node, while this node will
+  // have a 0-sized common_sequence_ with populated children.
+  //
+  // If `matched` is the size of common_sequence_, it means that we don't need
+  // a suffix node (since there's no suffix).
+  //
+  // Returns whether the split actually happened.
+  bool MaybeSplitAt(std::size_t matched);
+
+  // Returns true if this node is a leaf, i.e. has no children.
+  bool IsLeaf() const;
+
+  // Sequence of values all children_ of this node share.
+  std::vector<TrieValue> common_sequence_;
+
+  // Flag indicating whether the current node is a trace boundary, i.e.
+  // whether there is a trace that ends with common_sequence_.
+  bool is_trace_boundary_;
+
+  // Children, i.e. forking points, of this node.
+  ChildrenMap children_;
+};
+
+struct SessionInfo {
+  // Instantiates a new TrieBuilder located at its root.
+  TrieBuilder NewBuilder();
+
+  // Name of this session.
+  std::string name_;
+
+  // Root of the trie that stores trace information for this session.
+  std::unique_ptr<TrieNode> root_;
+
+  // Number of recorded traces for this session.
+  std::size_t traces_;
+};
+
+// Surface class for detecting dynamic shapes.
+//
+// Manages the information related to each session, as well as the active
+// session, i.e. the one that we are recording traces for.
+class DynamicShapeDetector {
+ public:
+  static DynamicShapeDetector* Get();
+
+  // Starts recording the created IR nodes into the trie whose root is
+  // associated with the session named `name`.
+  void StartSession(const std::string& name);
+
+  // Stops recording the created IR nodes for the active session.
+  //
+  // Before doing that, we commit the current trace, turning the current
+  // TrieNode being visited into a trace boundary.
+  //
+  // This function may raise an exception if we aren't allowed to create
+  // more traces.
+  void EndSession();
+
+  // Records a newly created IR node (its metadata).
+  //
+  // This function may raise an exception if:
+  //   1. we aren't allowed to create more traces; and
+  //   2. we have to create a new TrieNode because this IR node wasn't
+  //      expected in the trie.
+  void AddNodeInfo(torch::lazy::hash_t hash, const std::string& str);
+
+  // Checks whether there's any session active.
+  bool IsSessionActive();
+
+  // Maybe removes the session entry.
+  void RemoveSessionIfExists(const std::string& name);
+
+  // API for setting the maximum number of traces allowed to be recorded.
+  static void SetMaxAllowedTraces(std::size_t value);
+  static std::size_t GetMaxAllowedTraces();
+
+ private:
+  // Whether the current session allows new traces, i.e. new graph
+  // compilations.
+  bool AllowNewTrace();
+
+  // Moves the TrieBuilder to the root node of this session's trie.
+  void RootBuilder();
+
+  // Resets the data related to the current session.
+  //
+  // Specifically, this function:
+  //   1. resets the builder
+  //   2. assigns current_session_ to nullptr
+  void ResetSession();
+
+  // Stores the information related to each session.
+  std::unordered_map<std::string, SessionInfo> session_infos_;
+
+  // Pointer to the currently active session.
+  SessionInfo* current_session_ = nullptr;
+
+  // Iterative builder for the currently active session.
+  TrieBuilder builder_;
+};
+
+}  // namespace torch_xla
+
+#endif  // XLA_TORCH_XLA_CSRC_DYNAMIC_SHAPE_DETECTOR_H_
diff --git a/torch_xla/csrc/init_python_bindings.cpp b/torch_xla/csrc/init_python_bindings.cpp
index 86cb27a22ef..3f478dce87e 100644
--- a/torch_xla/csrc/init_python_bindings.cpp
+++ b/torch_xla/csrc/init_python_bindings.cpp
@@ -40,6 +40,7 @@
 #include "torch_xla/csrc/device.h"
 #include "torch_xla/csrc/dl_convertor.h"
 #include "torch_xla/csrc/dtype.h"
+#include "torch_xla/csrc/dynamic_shape_detector.h"
 #include "torch_xla/csrc/helpers.h"
 #include "torch_xla/csrc/ir.h"
 #include "torch_xla/csrc/ir_dump_util.h"
@@ -2520,6 +2521,24 @@ void InitXlaModuleBindings(py::module m) {
         });
   m.def("_get_current_graph_name",
         []() { return XLAGraphExecutor::Get()->CurrentGraphName(); });
+  m.def("_dynamic_shape_detector_start_session",
+        [](const std::string& session) {
+          DynamicShapeDetector::Get()->StartSession(session);
+        });
+  m.def("_dynamic_shape_detector_end_session",
+        []() { return DynamicShapeDetector::Get()->EndSession(); });
+  m.def("_dynamic_shape_detector_remove_session",
+        [](const std::string& session) {
+          DynamicShapeDetector::Get()->RemoveSessionIfExists(session);
+        });
+  m.def("_dynamic_shape_detector_set_max_allowed_traces",
+        [](int64_t max_allowed_traces) {
+          DynamicShapeDetector::SetMaxAllowedTraces(max_allowed_traces);
+        });
+  m.def("_dynamic_shape_detector_get_max_allowed_traces",
+        []() { return DynamicShapeDetector::GetMaxAllowedTraces(); });
   m.def("_replace_xla_tensor",
         [](at::Tensor& self, const at::Tensor& source) -> at::Tensor& {
           return XLANativeFunctions::set_(self, source);
diff --git a/torch_xla/csrc/ir.cpp b/torch_xla/csrc/ir.cpp
index 7a0a226c27d..2e4338f50b7 100644
--- a/torch_xla/csrc/ir.cpp
+++ b/torch_xla/csrc/ir.cpp
@@ -43,6 +43,22 @@ torch::lazy::hash_t GetOperandHashes(const torch::lazy::OpList& operands,
 
 }  // namespace
 
+void DetectDynamicShape(torch::lazy::NodePtr node) {
+  DynamicShapeDetector* detector = DynamicShapeDetector::Get();
+
+  if (!detector->IsSessionActive()) {
+    return;
+  }
+
+  // Don't add leaf nodes, e.g. device data and constants.
+  std::unordered_set<std::string> unwanted_nodes = {"xla::device_data",
+                                                    "prim::Constant"};
+
+  if (unwanted_nodes.find(node->op().ToString()) == unwanted_nodes.end()) {
+    detector->AddNodeInfo(node->hash(), node->ToString());
+  }
+}
+
 XlaNode::XlaNode(torch::lazy::OpKind op, torch::lazy::OpList operands,
                  std::vector<torch::lazy::Shape>&& shapes,
                  xla::Shape xla_shape, size_t num_outputs,
                  torch::lazy::hash_t hash_seed)
diff --git a/torch_xla/csrc/ir.h b/torch_xla/csrc/ir.h
index 10f3a8eeb33..1e8ede42006 100644
--- a/torch_xla/csrc/ir.h
+++ b/torch_xla/csrc/ir.h
@@ -19,6 +19,7 @@
 #include "absl/container/inlined_vector.h"
 #include "absl/hash/hash.h"
 #include "absl/types/span.h"
+#include "torch_xla/csrc/dynamic_shape_detector.h"
 #include "torch_xla/csrc/runtime/types.h"
 #include "xla/client/xla_builder.h"
 
@@ -35,9 +36,12 @@
 template <typename T>
 using OutputMap =
     std::unordered_map<torch::lazy::Output, T, torch::lazy::Output::Hasher>;
 
+void DetectDynamicShape(torch::lazy::NodePtr node);
+
 template <typename T, typename... Args>
 torch::lazy::NodePtr MakeNode(Args&&... args) {
   torch::lazy::NodePtr res =
       std::make_shared<T>(std::forward<Args>(args)...);
+  DetectDynamicShape(res);
   return res;
 }
diff --git a/torch_xla/torch_xla.py b/torch_xla/torch_xla.py
index 58958979f04..2d692c7e7a7 100644
--- a/torch_xla/torch_xla.py
+++ b/torch_xla/torch_xla.py
@@ -2,7 +2,9 @@
 import collections
 import contextlib
 import functools
+import uuid
 from typing import Any, Callable, List, Optional, Tuple
+import weakref
 
 import torch
 import torch.distributed as dist
@@ -59,7 +61,7 @@ def device_count() -> int:
 
 def sync(wait: bool = False):
   """Launches all pending graph operations.
-  
+
   Args:
     wait (bool): whether to block the current process until the execution finished.
 
@@ -83,9 +85,17 @@ def step():
   return compile()
 
 
-def compile(f: Optional[Callable] = None,
-            full_graph: Optional[bool] = False,
-            name: Optional[str] = None):
+# Keeps track of the alive functions. This allows us to remove, on the C++
+# side, session entries for functions that are no longer alive.
+_compiled_id_to_functions_ref = weakref.WeakValueDictionary()
+
+
+def compile(
+    f: Optional[Callable] = None,
+    full_graph: Optional[bool] = False,
+    name: Optional[str] = None,
+    num_different_graphs_allowed: Optional[int] = None,
+):
   """
   Optimizes given model/function using torch_xla's LazyTensor tracing mode.
   PyTorch/XLA will trace the given function with given inputs and then generate
@@ -102,6 +112,9 @@ def compile(f: Optional[Callable] = None,
     name (Optional[name]): Name of the compiled program. The name of the function `f`
       will be used if not specified. This name will be used in the
       `PT_XLA_DEBUG` messages as well as HLO/IR dump file.
+    num_different_graphs_allowed (Optional[int]): number of different traced graphs of the given
+      model/function that we are allowed to have. An error will be raised in case this limit
+      is exceeded.
 
   Example::
 
       @torch_xla.compile()
      def foo(x):
        return torch.sin(x) + torch.cos(x)

      def foo2(x):
        return torch.sin(x) + torch.cos(x)

      with torch_xla.compile():
        res = foo2(x)
  """
-  if name == None and f:
+  if name is None and f is not None:
     if hasattr(f, '__name__'):
       name = f.__name__
     elif hasattr(f, '__str__'):
       name = f.__str__()
 
+  if f is not None:
+    current_id = f"{name}_{id(f)}"
+  else:
+    current_id = str(uuid.uuid4())
+
+  # Check whether the function/module that corresponds with current_id is
+  # still alive. If it's not, we can remove it from the session map on the
+  # C++ side, so we can start a fresh session.
+  #
+  # This solves the issue where there are 2 different local-scoped functions
+  # with the same name.
+  # Since they are local-scoped, they might end up with the same id. And,
+  # since they have the same name, their current_id will be the same, even
+  # though they are different functions.
+  #
+  # This issue was observed when running test_dynamic_shape_detector.py.
+  if current_id not in _compiled_id_to_functions_ref:
+    torch_xla._XLAC._dynamic_shape_detector_remove_session(current_id)
+
+  if f is not None:
+    _compiled_id_to_functions_ref[current_id] = f
+
   def _clear_pending_ops_before_compile():
     sync()
 
@@ -134,28 +166,35 @@ def _compile():
     saved_allow_execution = torch_xla._XLAC._get_allow_execution()
     saved_current_graph_name = torch_xla._XLAC._get_current_graph_name()
     torch_xla._XLAC._set_use_eager_mode(False)
-    if name != None:
+    if name is not None:
       torch_xla._XLAC._set_current_graph_name(name + '_clear_pending')
     # Clear pending operations
     _clear_pending_ops_before_compile()
 
-    if name != None:
+    if name is not None:
       torch_xla._XLAC._set_current_graph_name(name)
 
     # if full_graph sets to true execution can not happen before the sync below
     torch_xla._XLAC._set_allow_execution(not full_graph)
 
+    if num_different_graphs_allowed is not None:
+      torch_xla._XLAC._dynamic_shape_detector_set_max_allowed_traces(
+          num_different_graphs_allowed)
+      torch_xla._XLAC._dynamic_shape_detector_start_session(current_id)
+
     try:
       yield
     finally:
       torch_xla._XLAC._set_allow_execution(saved_allow_execution)
+      if num_different_graphs_allowed is not None:
+        torch_xla._XLAC._dynamic_shape_detector_end_session()
      # Collect the traced graph after running the target function and
      # execute the graph.
      sync()
      torch_xla._XLAC._set_use_eager_mode(saved_eager_mode_status)
      torch_xla._XLAC._set_current_graph_name(saved_current_graph_name)
 
-  return _compile() if not f else _compile()(f)
+  return _compile() if f is None else _compile()(f)
 
 
 def manual_seed(seed, device=None):
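
Usage sketch (not part of the diff; shapes and the failure mode are
illustrative): with these changes applied, limiting the number of traced
graphs looks roughly like the test helper above.

    import torch
    import torch_xla

    def foo(x):
        return x + x

    # Allow only one traced graph for foo.
    optfoo = torch_xla.compile(foo, num_different_graphs_allowed=1)

    # First call records the single allowed trace.
    optfoo(torch.rand(10, device=torch_xla.device()))

    # A different input shape produces a different trace, so this raises
    # RuntimeError: "Maximum number of different traces allowed per function
    # exceeded: 1".
    optfoo(torch.rand(5, device=torch_xla.device()))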