forked from pytorch/pytorch
-
Notifications
You must be signed in to change notification settings - Fork 0
/
record_function.h
182 lines (146 loc) · 5.15 KB
/
record_function.h
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
#pragma once
#include <cstdint>
#include <cstring>
#include <functional>
#include <memory>
#include <ostream>
#include <string>
#include <utility>
#include <vector>

#include <ATen/core/ivalue.h>
#include <c10/util/SmallVector.h>
#include <torch/csrc/WindowsTorchApiMacro.h>
namespace torch { namespace autograd {
struct Node;
namespace profiler {
// A C-string view that either borrows a caller-owned `const char*` or owns a
// heap copy of a std::string (shared between copies of the view).
struct TORCH_API StringView {
  // A default-constructed view holds a null pointer: str() returns nullptr.
  StringView() : StringView(nullptr) {}

  // Non-owning view: only the pointer is stored, so the pointed-to characters
  // must outlive this StringView and every copy of it.
  explicit StringView(const char* str_ptr)
      : owned_str_ptr_(nullptr), str_ptr_(str_ptr) {}

  // Owning view: the string is moved onto the heap and held by a shared_ptr,
  // so str() stays valid for as long as any copy of this StringView exists.
  explicit StringView(std::string str)
      : owned_str_ptr_(std::make_shared<std::string>(std::move(str))),
        str_ptr_(owned_str_ptr_->c_str()) {}

  // Returns the viewed C string; may be nullptr for a default-constructed view.
  inline const char* str() const {
    return str_ptr_;
  }

  // Writes the viewed string to `os`. A null view writes nothing, because
  // inserting a null `const char*` into an ostream is undefined behavior.
  friend std::ostream& operator<<(std::ostream& os, const StringView& dt) {
    if (dt.str() != nullptr) {
      os << dt.str();
    }
    return os;
  }

  // Content equality. Two null views compare equal; a null view never equals
  // a non-null one (calling strcmp with a null pointer is undefined behavior).
  friend bool operator==(const StringView& lhs, const StringView& rhs) {
    const char* a = lhs.str();
    const char* b = rhs.str();
    if (a == b) {
      return true; // same buffer, or both null
    }
    if (a == nullptr || b == nullptr) {
      return false;
    }
    return std::strcmp(a, b) == 0;
  }
  friend bool operator!=(const StringView& lhs, const StringView& rhs) {
    return !(lhs == rhs);
  }

 private:
  // Keeps the characters alive for views built from std::string; null for
  // borrowed (const char*) views.
  std::shared_ptr<std::string> owned_str_ptr_;
  // The pointer handed out by str().
  const char* str_ptr_;
};
struct TORCH_API RecordFunction {
// Default constructor is used with before function called afterwards
RecordFunction() {}
RecordFunction(const RecordFunction&) = delete;
RecordFunction& operator=(const RecordFunction&) = delete;
// current returns the currently active RecordFunction in this thread.
static RecordFunction* current();
// before function initializes RecordFunction members and calls
// start callbacks
void before(const char* name, int64_t sequence_nr = -1);
void before(std::string name, int64_t sequence_nr = -1);
void before(Node* fn, int64_t sequence_nr = -1);
template<typename F>
void before(
F fn,
c10::ArrayRef<c10::IValue> args,
int64_t current_sequence_nr = -1) {
inputs_ = args.vec();
before(fn, current_sequence_nr);
}
template<typename F>
void before(
F fn,
std::vector<c10::IValue>&& args,
int64_t current_sequence_nr = -1) {
inputs_ = std::move(args);
before(fn, current_sequence_nr);
}
// Destructor calls end callbacks
virtual ~RecordFunction();
inline Node* func() const {
return fn_;
}
inline const StringView& name() const {
return name_;
}
inline int64_t seqNr() const {
return sequence_nr_;
}
const std::vector<c10::IValue>& inputs() const {
return inputs_;
}
inline const RecordFunction* parent() const {
return parent_;
}
bool active() const {
return initialized_;
}
void setRunSampled(bool run_sampled) {
run_sampled_ = run_sampled;
}
void end();
// Saves the thread_id that this RecordFunction was created with. This is
// needed so that we can access Events created by the original thread in a
// different thread, since they are thread-local. This should be used to call
// RecordFunction::end() in a different thread.
void setThreadId();
// Retrieves the thread_id that this RecordFunction was created with. Useful
// if we need to access Events created by the original thread in a different
// thread. The threadId_ should only be set (via setThreadId) in cases where
// RecordFunction::end is called in a different thread.
inline uint16_t getThreadId() const {
return threadId_;
}
private:
void processCallbacks();
Node* fn_ = nullptr;
StringView name_;
int64_t sequence_nr_ = -1;
std::vector<c10::IValue> inputs_;
// parent_ points to the parent RecordFunction and must out live this.
RecordFunction* parent_ = nullptr;
bool initialized_ = false;
bool run_sampled_ = false;
// The thread_id that this RecordFunction was created with. If 0, this means
// that it was not set with setThreadId() and this RecordFunction's callbacks
// cannot be invoked from a separate thread.
uint16_t threadId_ = 0;
};
// Registry queries. RECORD_FUNCTION consults these to decide whether to set
// up a RecordFunction at all (hasCallbacks), whether non-sampled callbacks
// exist (hasNonSampledCallbacks), and whether operator inputs must be passed
// to before() (needsInputs).
TORCH_API bool hasCallbacks();
TORCH_API bool hasNonSampledCallbacks();
TORCH_API bool needsInputs();

// Sampling controls. shouldRunSampledCallbacks() decides whether the sampled
// callbacks run for the current call.
// NOTE(review): presumably driven by the probability configured via
// setSamplingProbability(); confirm against the definitions.
TORCH_API bool shouldRunSampledCallbacks();
TORCH_API void setSamplingProbability(double probability);
TORCH_API double getSamplingProbability();

// Runs the (possibly sampled) start callbacks that were registered through
// pushCallback() for the given record function.
TORCH_API void runBeforeCallbacks(
    RecordFunction* record,
    const std::string& func_name);
// Declares a RecordFunction named `guard` in the enclosing scope. When
// hasCallbacks() is true, shouldRunSampledCallbacks() is queried exactly once.
// If that returns true, or hasNonSampledCallbacks() does, `guard` is
// initialized through before(). `inputs` is evaluated only when needsInputs()
// is true. The optional trailing argument is forwarded to before() as the
// function's sequence number (seq_no).
//
// The macro's own local uses a prefixed name. This keeps it from capturing an
// identifier spelled `run_sampled` that appears in the caller's `fn`,
// `inputs`, or trailing arguments.
#define RECORD_FUNCTION(fn, inputs, ...)                         \
  torch::autograd::profiler::RecordFunction guard;               \
  if (torch::autograd::profiler::hasCallbacks()) {               \
    const bool record_function_run_sampled_ =                    \
        torch::autograd::profiler::shouldRunSampledCallbacks();  \
    if (record_function_run_sampled_ ||                          \
        torch::autograd::profiler::hasNonSampledCallbacks()) {   \
      guard.setRunSampled(record_function_run_sampled_);         \
      if (torch::autograd::profiler::needsInputs()) {            \
        guard.before(fn, inputs, ##__VA_ARGS__);                 \
      } else {                                                   \
        guard.before(fn, ##__VA_ARGS__);                         \
      }                                                          \
    }                                                            \
  }
// Signature shared by the start and end hooks: each hook receives the
// RecordFunction being processed.
using RecordFunctionCallback = std::function<void(const RecordFunction&)>;

// Registers a start callback together with an end callback. The end callback
// defaults to a no-op. needs_inputs and sampled both default to false.
//
// WARNING: pushCallback() and popCallback() are not thread safe. Calls to
// them must not overlap with the execution of any other code.
TORCH_API void pushCallback(
    RecordFunctionCallback start,
    RecordFunctionCallback end = [](const RecordFunction& /*unused*/) {},
    bool needs_inputs = false,
    bool sampled = false);

// Removes a callback registration made by pushCallback(). The same thread
// safety warning applies.
// NOTE(review): presumably pops the most recent registration — confirm.
TORCH_API void popCallback();
} // namespace profiler
}} // namespace torch::autograd