forked from pytorch/pytorch
-
Notifications
You must be signed in to change notification settings - Fork 0
/
net_async_tracing.h
182 lines (149 loc) · 4.97 KB
/
net_async_tracing.h
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
/**
* Copyright (c) 2016-present, Facebook, Inc.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef CAFFE2_CORE_NET_ASYNC_TRACING_H_
#define CAFFE2_CORE_NET_ASYNC_TRACING_H_
#include "caffe2/core/common.h"
#include "caffe2/core/net_async_base.h"
#include "caffe2/core/operator.h"
#include "caffe2/core/timer.h"
// Flags used to configure net tracing. They are only declared in this header.
// NOTE(review): the matching definitions presumably live in the accompanying
// .cc file, and the exact meaning of each flag should be confirmed there.
C10_DECLARE_string(caffe2_net_async_tracing_filepath);
C10_DECLARE_string(caffe2_net_async_names_to_trace);
C10_DECLARE_int(caffe2_net_async_tracing_nth);
namespace caffe2 {
namespace tracing {
// One trace record. Every scalar field starts at a "not set" sentinel
// (-1, nullptr or false) so a freshly constructed event carries no data.
struct TORCH_API TracerEvent {
  // Identifiers attached to the event; -1 means "not provided".
  int op_id_ = -1;
  int task_id_ = -1;
  int stream_id_ = -1;
  // Raw, non-owning string pointers: this struct neither copies nor frees
  // them, so the pointed-to strings must outlive the event.
  const char* name_ = nullptr;
  const char* category_ = nullptr;
  // Integral timestamp. Initialised with an integer literal (the previous
  // -1.0 relied on an implicit double -> long conversion).
  // NOTE(review): the time unit is not visible in this header.
  long timestamp_ = -1;
  // NOTE(review): presumably true for the start record of a begin/end pair.
  bool is_beginning_ = false;
  long thread_label_ = -1;
  // Default-constructed id, i.e. one that does not identify any thread.
  std::thread::id tid_;
  int iter_ = -1;
};
// Selector passed to TracerGuard::addArgument() in front of each value to say
// which attribute the value belongs to. The numeric values are spelled out
// explicitly and match the implicit numbering (0..6).
enum TracingField {
  TRACE_OP = 0,
  TRACE_TASK = 1,
  TRACE_STREAM = 2,
  TRACE_THREAD = 3,
  TRACE_NAME = 4,
  TRACE_CATEGORY = 5,
  TRACE_ITER = 6,
};
// Strategy used to decide when tracing happens; see the per-mode settings in
// TracingConfig.
enum class TracingMode {
  // Driven by the *_every_nth_batch settings.
  EVERY_K_ITERATIONS = 0,
  // Driven by the trace_every_n_ms / trace_for_n_ms settings.
  GLOBAL_TIMESLICE = 1,
};
// Tunable settings for a Tracer. All members carry defaults, so a
// value-initialised TracingConfig{} is ready to use.
struct TracingConfig {
  TracingMode mode = TracingMode::EVERY_K_ITERATIONS;
  std::string filepath = "/tmp";

  // Settings read in TracingMode::EVERY_K_ITERATIONS.
  int64_t trace_every_nth_batch{100};
  int64_t dump_every_nth_batch{10000};

  // Settings read in TracingMode::GLOBAL_TIMESLICE (milliseconds).
  int64_t trace_every_n_ms{2 * 60 * 1000}; // 2 minutes
  int64_t trace_for_n_ms{1000}; // 1 second
};
// Accumulates TracerEvent records for a single net and can dump them to a
// file. Only declarations are visible here; the member functions are defined
// out of line.
class TORCH_API Tracer {
 public:
  // `net` is stored as a non-owning pointer. `config` defaults to a
  // default-constructed TracingConfig.
  Tracer(
      const NetBase* net,
      const std::string& net_name,
      TracingConfig config = TracingConfig{});

  void recordEvent(const TracerEvent& event);
  std::string opTraceName(const OperatorBase* op);
  std::string opBlobsInfo(const OperatorBase& op);
  std::string serializeEvent(const TracerEvent& event);
  void linearizeEvents();
  void renameThreads();
  void setEnabled(bool enabled);
  bool isEnabled() const;

  // Read-only access to the configuration. Const-qualified (like
  // isEnabled()) so it can also be called through a const Tracer.
  const TracingConfig& config() const {
    return config_;
  }

  int bumpIter();
  int getIter();
  int bumpDumpingIter();

  // Dump the tracing result to file with given suffix, and then
  // clear current events.
  void dumpTracingResultAndClearEvents(const std::string& file_suffix);

  virtual ~Tracer();

 private:
  const NetBase* net_ = nullptr;
  std::string filename_;
  std::vector<TracerEvent> events_;
  // NOTE(review): presumably guards events_ — confirm in the implementation.
  std::mutex tracer_mutex_;
  bool enabled_ = false;
  Timer timer_;
  // Given in-class defaults so the counters are never indeterminate, even if
  // a constructor forgets to set them.
  int iter_ = 0;
  int dumping_iter_ = 0;
  TracingConfig config_;

  // TracerGuard is granted access to the private members above.
  friend class TracerGuard;
};
// Scoped helper that builds up a TracerEvent for a Tracer.
// NOTE(review): the destructor presumably records the matching end event when
// the guard is enabled — the implementation is not part of this header.
class TORCH_API TracerGuard {
 public:
  // A default-constructed guard is disabled and not attached to any tracer;
  // init() attaches one.
  TracerGuard() {}

  void init(Tracer* tracer);

  // Zero-argument overload: terminates the variadic recursion below.
  void addArgument();
  void addArgument(TracingField field, const char* value);
  void addArgument(TracingField field, int value);

  // Accepts a flat list of (field, value) pairs: handles the first pair and
  // recurses on the rest until the zero-argument overload is reached.
  template <typename T, typename... Args>
  void addArgument(TracingField field, const T& value, const Args&... args) {
    addArgument(field, value);
    addArgument(args...);
  }

  void recordEventStart();

  virtual ~TracerGuard();

  static TracerGuard* getCurrentTracerGuard();

  void disable();

 private:
  bool enabled_ = false;
  TracerEvent event_;
  // Non-owning. Defaults to nullptr so the pointer is never indeterminate
  // before init() is called (it was previously left uninitialised).
  Tracer* tracer_ = nullptr;
};
// Extract the shard id from name of the form "...shard:123..."
// Return -1 if there is no shard found
TORCH_API int extractShardId(const std::string& name);
// Check if the net name is white-listed for tracing (specified via a command
// line flag)
TORCH_API bool isTraceableNetName(const std::string& net_name);
// Factory returning a shared Tracer for the given net and net name.
// NOTE(review): may presumably return a null pointer (e.g. when the net is
// not traceable) — the TRACE_EVENT macros null-check their tracer; confirm
// against the implementation.
TORCH_API std::shared_ptr<Tracer> create(
const NetBase* net,
const std::string& net_name);
// Called with the tracer at the start of an iteration.
// NOTE(review): the meaning of the returned bool is not visible in this
// header — presumably whether the iteration is being traced; confirm.
TORCH_API bool startIter(const std::shared_ptr<Tracer>& tracer);
} // namespace tracing
// Raw token pasting: the operands of ## are NOT macro-expanded first.
#define TRACE_NAME_CONCATENATE(s1, s2) s1##s2
// Extra indirection level so that arguments such as __LINE__ are expanded
// before they reach the ## operator above.
#define TRACE_NAME_CONCATENATE_EXPANDED(s1, s2) TRACE_NAME_CONCATENATE(s1, s2)
// Produces an identifier that is unique per source line, e.g. trace_guard42.
// (Pasting __LINE__ directly would yield the literal name trace_guard__LINE__.)
#define TRACE_ANONYMOUS_NAME(str) \
  TRACE_NAME_CONCATENATE_EXPANDED(str, __LINE__)

// Attaches the guard declared on the same line to `tracer_`, forwards the
// (field, value) arguments to it and records the start event.
#define TRACE_EVENT_INIT(...)                                 \
  TRACE_ANONYMOUS_NAME(trace_guard).init(tracer_.get());      \
  TRACE_ANONYMOUS_NAME(trace_guard).addArgument(__VA_ARGS__); \
  TRACE_ANONYMOUS_NAME(trace_guard).recordEventStart();

// Supposed to be used only once per scope in AsyncNetBase-derived nets.
// Requires a smart-pointer-like `tracer_` to be in scope. The guard object is
// declared in the enclosing scope (not inside the `if`) so that it lives until
// the end of that scope; it is only initialised when tracing is enabled.
#define TRACE_EVENT(...)                                  \
  tracing::TracerGuard TRACE_ANONYMOUS_NAME(trace_guard); \
  if (tracer_ && tracer_->isEnabled()) {                  \
    TRACE_EVENT_INIT(__VA_ARGS__)                         \
  }

// Same as TRACE_EVENT, but the guard is only initialised when `cond` also
// holds. `cond` is evaluated only if a tracer is present and enabled.
#define TRACE_EVENT_IF(cond, ...)                         \
  tracing::TracerGuard TRACE_ANONYMOUS_NAME(trace_guard); \
  if (tracer_ && tracer_->isEnabled() && (cond)) {        \
    TRACE_EVENT_INIT(__VA_ARGS__)                         \
  }
} // namespace caffe2
#endif // CAFFE2_CORE_NET_ASYNC_TRACING_H_