#pragma once
#include <c10/core/impl/LocalDispatchKeySet.h>
#include <c10/macros/Macros.h>
#include <torch/csrc/WindowsTorchApiMacro.h>
// NOTE [Tracing Mode Switches]
//
// Historically, tracing function was controlled by two switches:
//
// - `AutoNonVariableTypeMode` guard
//
// Tracing function used to be script-generated inside `VariableType_*.cpp`
// kernels, sharing the same `Autograd` dispatch key with autograd function.
// Therefore, before tracing function was moved out of VariableType, the
// `AutoNonVariableTypeMode` guard could also disable tracing as a side
// effect of disabling `Autograd` dispatching.
//
// - `setTracingState()` API in `torch/csrc/jit/frontend/tracer.h`
//
// It stores tracing data in a `TracingState` object in TLS. If the
// `TracingState` object in TLS is `null`, then tracing is paused.
//
// The `TracingState` object is created in `tracer::trace()` - the main
// entrance of tracing function. It's temporarily set to `null` inside
// generated VariableType (now TraceType) to bypass tracing for intermediate
// ops (ops being called by other ops). After the intermediate op call
// finishes it's set back to the original `TracingState` object.
//
// The `TracingState` object in TLS can also be read/written via its Python
// binding in `python_tracer.cpp`, and `get/setTracingState()` C++ APIs,
// which are also exposed as `TORCH_API`.
//
// Two new switches were introduced since tracing function was moved out of
// VariableType:
//
// - `tracer::impl::set_dispatch_enabled()` API
//
// Unlike the special `Autograd` dispatch key which is included in dispatch
// key set by default, `Tracer` dispatch key is off by default. The
// dispatching switch can be toggled via this new API.
//
// - `tracer::impl::NoTracerDispatchMode` guard
//
// It's used to cover the old semantics of `AutoNonVariableTypeMode` after
// tracing was moved out of VariableType.
//
// Before tracing function was moved out of VariableType, tracing was enabled
// when the following conditions are satisfied:
//
// 1) `TracingState` object in TLS != null;
// - Either inside the execution scope of `tracer::trace()`, or
// - Eagerly called `setTracingState()` with non-null object.
// 2) Not inside `AutoNonVariableTypeMode` scope;
//
// After:
//
// 1) `TracingState` object in TLS != null;
// 2) Has called `tracer::impl::set_dispatch_enabled(true)`;
// 3) Not inside `tracer::impl::NoTracerDispatchMode` scope;
//
// [TODOs]
//
// - `setTracingState()` vs. `tracer::impl::set_dispatch_enabled()`
//
// Currently `set_dispatch_enabled()` is set/unset inside `setTracingState()`
// to keep the semantics exactly the same as before - it's confusing to keep
// both switches, though. We should consider simplifying/limiting the exposed
// `setTracingState()` Python/C++ APIs (and other APIs calling it) so that
// these two can be unified.
//
// - `AutoNonVariableTypeMode` vs. `tracer::impl::NoTracerDispatchMode`
//
// We don't need to always set both guards together to keep semantics
// unchanged. For the following use cases of `AutoNonVariableTypeMode` we
// don't need to set the new tracer guard:
//
// * Script-generated VariableType kernels. The guard is not necessary as
// tracing is already disabled explicitly by `setTracingState(null)` in
// generated TraceType kernels - we could keep it as is or use the new guard
// instead.
//
// * Custom ops. Will be handled by fallback kernel for `Tracer`.
//
// * Functions that are not likely to be called in tracing context (no python
// binding / not an operator), e.g.: all mobile forward() wrappers, test
// binaries, etc.
//
// * Where new threads are spawned, e.g.: ATen/native/ConvolutionMM2d.cpp.
// It's not necessary as tracing is off by default.
//
// For the rest of the cases we might need to have both:
//
// * Functions that might be reachable from eager mode python (especially
// factory methods), e.g.:
// `internal_new_from_data()` in `torch/csrc/utils/tensor_new.cpp`.
// Without the new guard it will add `aten::empty` to the traced graph.
//
// * Some manually maintained functions, e.g.:
// `torch/csrc/autograd/VariableTypeManual.cpp`.
// Set the new guard if it's not obvious whether `setTracingState(null)`
// has been called before it reaches the `AutoNonVariableTypeMode` guard.
//
// We might need to tweak the usage of the new guard to optimize/fix things.
// It should only affect the correctness of tracing function, because the
// guard is essentially a no-op when the master `setTracingState()` switch is
// off.
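//
// The pause/restore behavior described for intermediate ops can be sketched
// in isolation. This is a minimal model, not the real tracer: the
// `tracing_note_sketch` namespace, the `PauseTracing` helper, `record_op()`
// and `outer_op()` are hypothetical names, and the `TracingState` here is a
// stand-in that only records op names.

```cpp
#include <memory>
#include <string>
#include <utility>
#include <vector>

namespace tracing_note_sketch {

// Hypothetical stand-in for `TracingState`: it only records op names.
struct TracingState {
  std::vector<std::string> recorded_ops;
};

// Models the `TracingState` object kept in TLS; null means tracing is paused.
thread_local std::shared_ptr<TracingState> tls_tracing_state;

// Hypothetical RAII helper: clears the TLS state for the duration of an
// intermediate op call and restores the original object afterwards.
class PauseTracing {
 public:
  PauseTracing() : saved_(std::move(tls_tracing_state)) {
    tls_tracing_state.reset();
  }
  ~PauseTracing() {
    tls_tracing_state = std::move(saved_);
  }
  PauseTracing(const PauseTracing&) = delete;
  PauseTracing& operator=(const PauseTracing&) = delete;

 private:
  std::shared_ptr<TracingState> saved_;
};

// Records `name` only while a non-null state is installed in TLS.
inline void record_op(const std::string& name) {
  if (tls_tracing_state) {
    tls_tracing_state->recorded_ops.push_back(name);
  }
}

// An op that calls another op. Only the outer call should be recorded.
inline void outer_op() {
  record_op("outer_op");
  PauseTracing pause; // bypass tracing for ops called by this op
  record_op("inner_op"); // skipped: the TLS state is null here
}

} // namespace tracing_note_sketch
```

// With a state installed, calling `outer_op()` records one entry and leaves
// the original state object back in TLS; with no state installed, nothing
// is recorded.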
namespace at {
// TODO: move this from `at::` to `torch::jit::` after
// `aten/src/ATen/cpp_custom_type_hack.h` is removed.
namespace tracer {
namespace impl {
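// Returns true when the `Tracer` dispatch key is included in the thread-local
// key set and is not excluded by an active `NoTracerDispatchMode` guard.
static inline bool is_dispatch_enabled() {
  return c10::impl::tls_is_dispatch_key_included(at::DispatchKey::Tracer) &&
      !c10::impl::tls_is_dispatch_key_excluded(at::DispatchKey::Tracer);
}

// Toggles the thread-local `Tracer` dispatch key. Must not be called inside a
// `NoTracerDispatchMode` scope.
static inline void set_dispatch_enabled(bool enabled) {
  TORCH_INTERNAL_ASSERT(
      !c10::impl::tls_is_dispatch_key_excluded(at::DispatchKey::Tracer),
      "Cannot enable tracing within the scope of NoTracerDispatchMode!");
  c10::impl::tls_set_dispatch_key_included(at::DispatchKey::Tracer, enabled);
}

// RAII guard that excludes the `Tracer` dispatch key for its lifetime.
struct NoTracerDispatchMode {
  c10::impl::ExcludeDispatchKeyGuard guard_{at::DispatchKey::Tracer};
};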
} // namespace impl
} // namespace tracer
} // namespace at
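
// The way the included and excluded bits combine in the helpers above can be
// sketched without any PyTorch dependency. This is a simplified model under
// stated assumptions: the `tracer_dispatch_sketch` namespace and its two
// thread-local booleans are hypothetical stand-ins for the real dispatch key
// sets, the guard is assumed to restore the previous excluded bit on exit,
// and the check throws `std::logic_error` instead of using
// `TORCH_INTERNAL_ASSERT` so that the failure is observable.

```cpp
#include <stdexcept>

namespace tracer_dispatch_sketch {

// Hypothetical model of the thread-local bits for a single dispatch key.
thread_local bool tracer_included = false; // the key is off by default
thread_local bool tracer_excluded = false;

// The key is effective only if it is included and not excluded.
inline bool is_dispatch_enabled() {
  return tracer_included && !tracer_excluded;
}

// Rejects any toggle attempted while the key is excluded, mirroring the
// assertion in the real helper (which fires for both true and false).
inline void set_dispatch_enabled(bool enabled) {
  if (tracer_excluded) {
    throw std::logic_error(
        "Cannot enable tracing within the scope of NoTracerDispatchMode!");
  }
  tracer_included = enabled;
}

// RAII guard: excludes the key on entry and restores the previous bit on
// exit, so nested guards behave as expected.
class NoTracerDispatchMode {
 public:
  NoTracerDispatchMode() : prev_excluded_(tracer_excluded) {
    tracer_excluded = true;
  }
  ~NoTracerDispatchMode() {
    tracer_excluded = prev_excluded_;
  }
  NoTracerDispatchMode(const NoTracerDispatchMode&) = delete;
  NoTracerDispatchMode& operator=(const NoTracerDispatchMode&) = delete;

 private:
  bool prev_excluded_;
};

} // namespace tracer_dispatch_sketch
```

// In this model, enabling dispatch and then entering the guard makes
// `is_dispatch_enabled()` return false until the guard goes out of scope,
// while the included bit itself is left untouched.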