[RUNTIME] Add min_repeat_ms to time_evaluator #2200
Force-pushed from 5f99d62 to f588537.
And also, is it possible to add some tests checking that this automatically adjusting time measurement algorithm works as expected?
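One way to test the auto-adjusting loop without depending on real hardware timing is to drive it with a simulated clock. The sketch below is a Python model, not TVM's implementation: `FakeClock`, `measure_once`, and the fixed per-call cost are hypothetical names introduced for illustration, and the growth rule only mirrors the shape of the expression in the diff under review.

```python
class FakeClock:
    """Deterministic clock: every simulated call advances time by a fixed cost."""

    def __init__(self, cost_ms_per_call):
        # Assumed to be strictly positive, so measured durations are never zero.
        self.now_ms = 0.0
        self.cost_ms_per_call = cost_ms_per_call

    def run(self, number):
        # Pretend to call the measured function `number` times back to back.
        self.now_ms += number * self.cost_ms_per_call


def measure_once(clock, number, min_repeat_ms, growth=1.618):
    """Model of one repeat: grow `number` until the batch lasts long enough.

    Returns (final_number, mean_ms_per_call).
    """
    while True:
        begin = clock.now_ms
        clock.run(number)
        duration_ms = clock.now_ms - begin
        if duration_ms >= min_repeat_ms:
            return number, duration_ms / number
        # Estimate the needed count from the observed per-call time,
        # but always grow by at least the factor `growth`.
        estimate = min_repeat_ms / (duration_ms / number) + 1
        number = int(max(estimate, number * growth))


clock = FakeClock(cost_ms_per_call=0.5)
final_number, mean_ms = measure_once(clock, number=1, min_repeat_ms=100)
```

With a fake clock the test can assert both that the final batch met the minimum duration and that the reported mean matches the simulated per-call cost.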
python/tvm/module.py
Outdated
@@ -139,26 +139,38 @@ def time_evaluator(self, func_name, ctx, number, repeat=1):
The context we should run this function on.

number: int
The number of steps used in measuring each time interval
The number of times to run this function for taking average.
We call this as one `repeat` of measurement.
It's not very clear from the description what we call a repeat of measurement. (The description for min_repeat_ms makes everything clearer though)
src/runtime/rpc/rpc_session.cc
Outdated
int number,
int repeat,
int min_repeat_ms) {
auto ftimer = [pf, ctx, &number, repeat, min_repeat_ms](TVMArgs args, TVMRetValue *rv) {
Why is the local variable `number` captured by reference here? It will escape the local scope, which might be a bug.
It turns out that a TVM packed function does not support capturing by reference. I updated it to capture by value.
src/runtime/rpc/rpc_session.cc
Outdated
if (duration_ms < min_repeat_ms) {
number = static_cast<int>(std::max((min_repeat_ms / (duration_ms / number) + 1),
number * 1.618));
What is `1.618`? Also, using ceil here might be better than adding 1.
btw, do we need this branch here if the loop will exit if the condition is met?
Precision is not very important here, as I want to encourage it to set a higher `number`.
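For readers following the thread, the update rule under discussion can be modelled in a few lines of Python. This is an illustrative sketch rather than the code in `rpc_session.cc`: `next_number` and its `use_ceil` flag are hypothetical, introduced only to contrast the `+ 1` form in the diff with the `ceil` form suggested above. The constant 1.618 acts as a minimum growth factor in this model.

```python
import math


def next_number(number, duration_ms, min_repeat_ms, growth=1.618, use_ceil=False):
    """Return the next batch size after a batch that ran too briefly."""
    per_call_ms = duration_ms / number
    needed = min_repeat_ms / per_call_ms
    if use_ceil:
        estimate = math.ceil(needed)
    else:
        estimate = needed + 1  # form used in the diff; truncated by int() below
    # Never grow by less than the factor `growth`, whatever the estimate says.
    return int(max(estimate, number * growth))


# A batch of 10 calls took 5 ms and the target is 100 ms:
# 0.5 ms per call, so 200 calls are needed.
plus_one = next_number(10, 5.0, 100)
with_ceil = next_number(10, 5.0, 100, use_ceil=True)

# The batch was only just too short, so the estimate barely grows
# and the growth factor takes over.
floor_case = next_number(100, 99.0, 100)
```

When the required count is already a whole number, the `+ 1` form overshoots by one call, while `ceil` lands on it exactly. Either way the difference is a single call, which is consistent with the reply that precision matters little here.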
src/runtime/rpc/rpc_session.cc
Outdated
duration_ms = std::chrono::duration_cast<std::chrono::duration<double> >
(tend - tbegin).count() * 100;
If I understand correctly, here we rerun the whole process until we find the right number of iterations. An alternative would be to run only the additional iterations needed (the difference between the necessary number of iterations and the number already run) and add their duration to the total. This approach may behave slightly differently: it may be a bit faster but a bit less precise; I'm not sure. So I would like to see some more comments in the code describing the algorithm and why this particular algorithm was chosen.
@sgrechanik-h We cannot use the accumulation mode, for the reason explained by eqy.
are not precise enough to capture short-running tasks. This parameter is
also critical when devices need a certain minimum running time to "warm
up," such as GPUs that need time to reach a performance power state.
where the first one is warm up and will be discarded.
Maybe change this to "plus an additional warm up run that will be discarded." It currently sounds like it means (number - 1) x repeat.
int number,
int repeat,
int min_repeat_ms) {
auto ftimer = [pf, ctx, &number, repeat, min_repeat_ms](TVMArgs args, TVMRetValue *rv) {
TVMRetValue temp;
std::ostringstream os;
// skip first time call, to activate lazy compilation components.
pf.CallPacked(args, &temp);
I wonder if this definition (1 + number * repeat) is the correct formulation after we have introduced `min_repeat_ms`. The goal is to start measurement in the correct power state, which we will likely do if we bump up `number` over and over again for the same `time_evaluator` call. However, let's say that `number` is now sufficient and we get to a fresh `time_evaluator` call. In this case I am not sure `1+` will be enough to get the hardware into the right state if necessary. Should we consider `number*(1+repeat)`?
Yes, the definition is not correct. I will add a note to the doc string of `min_repeat_ms` but keep this definition here.
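The two formulations differ only in how many untimed calls precede the timed ones. A small Python sketch makes the difference concrete; the function names are hypothetical, and the counts assume `number` stays fixed, which is no longer guaranteed once `min_repeat_ms` can raise it.

```python
def total_calls_current(number, repeat):
    """One discarded warm-up call, then `repeat` timed batches of `number` calls."""
    return 1 + number * repeat


def total_calls_proposed(number, repeat):
    """One discarded warm-up batch of `number` calls, then `repeat` timed batches."""
    return number * (1 + repeat)


number, repeat = 100, 3
current = total_calls_current(number, repeat)
proposed = total_calls_proposed(number, repeat)

# Extra untimed calls the proposed form spends on warming up the device.
warmup_difference = proposed - current
```

For `number = 100` and `repeat = 3` this gives 301 calls under the current definition and 400 under the proposed one, so the proposal spends 99 more calls on warm-up.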
@@ -124,7 +124,8 @@ class RPCModuleNode final : public ModuleNode {
PackedFunc GetTimeEvaluator(const std::string& name,
TVMContext ctx,
int number,
int repeat) {
int repeat,
int min_repeat_ms) {
Does this break some current tests if we do not give a default value for `min_repeat_ms`?
I added a default argument on the Python side.
@merrymercy what is the status of this PR?
Force-pushed from 684edd4 to e58d6c3.
Force-pushed from e58d6c3 to 3587c6b.
src/runtime/rpc/rpc_session.cc
Outdated
TVMRetValue temp;
std::ostringstream os;
// skip first time call, to activate lazy compilation components.
pf.CallPacked(args, &temp);
DeviceAPI::Get(ctx)->StreamSync(ctx, nullptr);
int dynamic_number = number;
You used to modify `number` directly, which had the nice property of remembering the suitable value of `number` between runs. I think you can still achieve this effect by declaring the lambda as `mutable` (it won't be thread-safe though, so I'm not sure).
src/runtime/rpc/rpc_session.cc
Outdated
dynamic_number = static_cast<int>(
std::max((min_repeat_ms / (duration_ms / dynamic_number) + 1),
dynamic_number * 1.618));
The choice of the constant needs an explanation inside the code.
Halide uses 2, but I think there is no "correct" number, so the choice is arbitrary.
@eqy please review again.
ping @eqy @sgrechanik-h please take another look. If there are no further comments in 24 hours, we can go ahead and merge this PR.
Thanks, @merrymercy @eqy @sgrechanik-h, this is merged.
`min_repeat_ms` sets the minimum duration of a measurement and has been used in autotvm for measurement. As it is a useful feature that makes measurement accurate and smart, we'd better move it to the general API `time_evaluator` and encourage people to use it.
cc @eqy @tqchen @sgrechanik-h