-
Notifications
You must be signed in to change notification settings - Fork 5.7k
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
delay tensorrt registry #45824
delay tensorrt registry #45824
Changes from 3 commits
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -38,6 +38,13 @@ namespace inference { | |
namespace tensorrt { | ||
namespace plugin { | ||
|
||
#if defined(_WIN32) | ||
#define UNUSED | ||
#define __builtin_expect(EXP, C) (EXP) | ||
#else | ||
#define UNUSED __attribute__((unused)) | ||
#endif | ||
|
||
class PluginTensorRT; | ||
|
||
typedef std::function<PluginTensorRT*(const void*, size_t)> | ||
|
@@ -372,6 +379,26 @@ class TensorRTPluginCreator : public nvinfer1::IPluginCreator { | |
std::vector<nvinfer1::PluginField> plugin_attributes_; | ||
}; | ||
|
||
class TrtPluginRegistry { | ||
public: | ||
static TrtPluginRegistry* Global() { | ||
static TrtPluginRegistry registry; | ||
return ®istry; | ||
} | ||
bool Regist(const std::string& name, std::function<void()> func) { | ||
map.emplace(name, func); | ||
return true; | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. 返回值的语义是什么?什么时候会返回 false? There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. 为了在 namespace 里调用该函数,创建了一个 UNUSED 变量 |
||
} | ||
void RegistToTrt() { | ||
for (auto it : map) { | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. by reference There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. DONE |
||
it.second(); | ||
} | ||
} | ||
|
||
private: | ||
std::unordered_map<std::string, std::function<void()>> map; | ||
}; | ||
|
||
template <typename T> | ||
class TrtPluginRegistrarV2 { | ||
public: | ||
|
@@ -386,9 +413,14 @@ class TrtPluginRegistrarV2 { | |
T creator; | ||
}; | ||
|
||
#define REGISTER_TRT_PLUGIN_V2(name) \ | ||
static paddle::inference::tensorrt::plugin::TrtPluginRegistrarV2<name> \ | ||
plugin_registrar_##name {} | ||
#define REGISTER_TRT_PLUGIN_V2(name) REGISTER_TRT_PLUGIN_V2_HELPER(name) | ||
|
||
#define REGISTER_TRT_PLUGIN_V2_HELPER(name) \ | ||
UNUSED static bool REGISTER_TRT_PLUGIN_V2_HELPER##name = \ | ||
TrtPluginRegistry::Global()->Regist(#name, []() -> void { \ | ||
static paddle::inference::tensorrt::plugin::TrtPluginRegistrarV2<name> \ | ||
plugin_registrar_##name{}; \ | ||
}); | ||
|
||
} // namespace plugin | ||
} // namespace tensorrt | ||
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
func 为什么不传引用?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
DONE