【multi precision】multi precision support (fp32 + fp16) #9339
@@ -279,6 +279,7 @@ endif()
 if (LITE_ON_TINY_PUBLISH)
   add_definitions("-DLITE_ON_TINY_PUBLISH")
   add_definitions("-DLITE_ON_FLATBUFFERS_DESC_VIEW")
+  add_definitions("-DLITE_WITH_FLATBUFFERS_DESC")
This may affect model loading time; it needs to be evaluated.
The models tested so far show no impact.
@@ -194,7 +194,8 @@ void LightPredictor::PrepareFeedFetch() {
 }

 void LightPredictor::BuildRuntimeProgram(
-    const std::shared_ptr<const cpp::ProgramDesc>& program_desc) {
+    const std::shared_ptr<const cpp::ProgramDesc>& program_desc,
+    bool use_precision_low) {
   auto* exe_scope = &scope_->NewScope();
This should be named exec_scope.
Someone else named it earlier; I will leave it unchanged for now.
if (op_type != "feed" && op_type != "fetch") {
  if (place.precision == PRECISION(kFloat)) {
    place.precision = PRECISION(kFP16);
  } else if (place.precision == PRECISION(kAny)) {
The two checks can be merged:
if (place.precision == PRECISION(kFloat) || place.precision == PRECISION(kAny)) {
function()...
}
ok
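The merged check suggested above could look roughly like the sketch below. `Precision`, `Place`, and `LowerToFp16` are hypothetical stand-ins, not Paddle-Lite's actual types, and the sketch assumes that both the kFloat and kAny branches switch the place to FP16, which the excerpt does not show.

```cpp
#include <string>

// Hypothetical stand-in for the framework's precision enum.
enum class Precision { kUnk, kFloat, kAny, kFP16 };

// Hypothetical stand-in for a kernel placement.
struct Place {
  Precision precision = Precision::kUnk;
};

// Lowers a place to FP16 for every op except feed/fetch.
// Assumption: kFloat and kAny are handled identically.
inline void LowerToFp16(const std::string& op_type, Place* place) {
  if (op_type == "feed" || op_type == "fetch") return;
  if (place->precision == Precision::kFloat ||
      place->precision == Precision::kAny) {
    place->precision = Precision::kFP16;
  }
}
```

A single combined condition keeps one assignment site, so a later change to the target precision only has to be made once.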
lite/core/op_lite.cc
Outdated
@@ -72,8 +72,8 @@ std::vector<std::unique_ptr<KernelBase>> OpLite::CreateKernels(
   auto pick_kernel = [&](const Place &place) {
     auto ks = KernelRegistry::Global().Create(
         op_type_, place.target, place.precision, place.layout);
-    VLOG(5) << "pick kernel for " << op_info()->Type() << " "
-            << place.DebugString() << " get " << ks.size() << " kernels";
+    // VLOG(5) << "pick kernel for " << op_info()->Type() << " "
The commented-out code can be cleaned up.
ok
lite/core/op_lite.cc
Outdated
@@ -130,6 +130,18 @@ bool OpLite::Attach(const cpp::OpDesc &opdesc, lite::Scope *scope) {
   return AttachImpl(*op_info(), scope);
 }

+#ifdef LITE_ON_FLATBUFFERS_DESC_VIEW
+bool OpLite::Attach(const cpp::OpDescWrite &opdesc, lite::Scope *scope) {
+  // valid_places_.clear();
Same here: if the commented-out code is not used, please clean it up.
ok
lite/core/program.cc
Outdated
int low_precision = 1;
std::string old_op;

if (use_precision_low == true) {
Would it be more intuitive to change the logic here to the following?
use_precision_low_ = use_precision_low;
low_precision = use_precision_low_ ? 1 : 0;
ok
lite/core/program.cc
Outdated
#ifdef ENABLE_ARM_FP16
if (lite::DeviceInfo::Global().has_fp16() && low_precision == 1) {
  if (op_type != "feed" && op_type != "fetch") {
    if (place.precision == static_cast<PrecisionType>(1)) {
Use enum values such as PrecisionType::kFloat directly instead of hard-coded numbers like 1 and 5. The same applies below.
OK, I forgot to change this spot.
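To illustrate the point about hard-coded values, here is a minimal sketch comparing the two styles. The `PrecisionType` enum below is a hypothetical stand-in; its numeric values mirror the 1 and 5 seen in the diff and are an assumption of this sketch, not a statement about the real header.

```cpp
// Hypothetical stand-in enum; the numeric values are assumed to match
// the 1 and 5 used in the diff under review.
enum class PrecisionType : int { kUnk = 0, kFloat = 1, kFP16 = 5 };

// Before: the intent is hidden behind magic numbers.
inline PrecisionType LowerWithCasts(PrecisionType p) {
  if (p == static_cast<PrecisionType>(1)) {
    return static_cast<PrecisionType>(5);
  }
  return p;
}

// After: same behavior, but readable, and still correct if the
// enumerators are ever renumbered.
inline PrecisionType LowerWithNames(PrecisionType p) {
  return p == PrecisionType::kFloat ? PrecisionType::kFP16 : p;
}
```

Both functions behave identically for the stand-in enum; only the named version documents which precisions are involved.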
      place.precision = static_cast<PrecisionType>(5);
    }
  }
  // transfer weight to fp16
Doesn't the code below duplicate the WeightFP32ToFP16 function?
Yes, I will see whether I can call it instead.
There are many differences, so it probably cannot be called here.
// kernels = op->CreateKernels({place});
//}
if (kernels.size() == 0 && place.target == TargetType::kARM) {
  place.target = TargetType::kHost;
Why is the target changed to TargetType::kHost here?
Because some of the fp16 kernels are registered under host.
Even though they are ARM code as well.
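The ARM-to-host fallback discussed in this thread can be sketched as follows. `Target`, `Place`, `CreateFn`, and `CreateWithHostFallback` are hypothetical names introduced for illustration, and the sketch assumes that kernel creation is retried after the target is switched, which the excerpt does not show.

```cpp
#include <functional>
#include <string>
#include <vector>

// Hypothetical stand-ins for the framework's target and place types.
enum class Target { kARM, kHost };

struct Place {
  Target target;
};

// Hypothetical kernel-creation callback: returns kernel names for a place.
using CreateFn = std::function<std::vector<std::string>(const Place&)>;

// Tries the requested place first. If nothing is found for ARM, retries
// with the host target (assumption: the retry follows the target switch).
inline std::vector<std::string> CreateWithHostFallback(const CreateFn& create,
                                                       Place place) {
  std::vector<std::string> kernels = create(place);
  if (kernels.empty() && place.target == Target::kARM) {
    place.target = Target::kHost;
    kernels = create(place);
  }
  return kernels;
}
```

Taking `place` by value keeps the caller's place unchanged, so the fallback stays local to the lookup.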
lite/api/paddle_api.h
Outdated
@@ -559,6 +559,7 @@ class LITE_API MobileConfig : public ConfigBase {
   // whether to load data from memory. Model data will be loaded from memory
   // buffer if model_from_memory_ is true.
   bool model_from_memory_{false};
+  PrecisionMode pre_mode_{LITE_PRECISION_NORMAL};
There is no need to abbreviate it as pre_mode; just use precision_mode_.
ok
template <DataLayoutType DLType>
class CalibComputeFp32ToInt32
    : public KernelLite<TARGET(kARM), PRECISION(kInt32), DLType> {
PRECISION(kInt32)->PRECISION(kFloat)
OK, this is not used here; I will delete it later.
template <DataLayoutType DLType>
class CalibComputeFp32ToInt64
    : public KernelLite<TARGET(kARM), PRECISION(kInt64), DLType> {
PRECISION(kInt64)->PRECISION(kFloat)
OK, this is not used here; I will delete it later.
lite/operators/calib_inplace_op.h
Outdated
#ifdef LITE_ON_FLATBUFFERS_DESC_VIEW
  bool AttachImpl(const cpp::OpDescWrite &opdesc, lite::Scope *scope) override;
#endif
  void *getparam() { return &param_; }
GetParam()
ok
LGTM
1. Support mixed precision at runtime
2. Add the calib inplace operator
Performance:
Library size impact:
+300KB