!627 add validation of fmk type for plugin load.

From: @yangyongqiang5033
Reviewed-by: @wqtshg,@xchu42,@wqtshg
Signed-off-by: @wqtshg
pull/627/MERGE
mindspore-ci-bot 4 years ago committed by Gitee
commit 78c8adccc8

@ -184,12 +184,20 @@ void TBEPluginManager::GetCustomOpPath(std::string &customop_path) {
void TBEPluginManager::LoadCustomOpLib() {
LoadPluginSo(options_);
std::string fmk_type = std::to_string(domi::TENSORFLOW);
auto it = options_.find(ge::FRAMEWORK_TYPE);
if (it != options_.end()) {
fmk_type = it->second;
}
GELOGD("frameworkType is %s", fmk_type.c_str());
std::vector<OpRegistrationData> registration_datas = domi::OpRegistry::Instance()->registrationDatas;
GELOGI("The size of registration_datas is: %zu", registration_datas.size());
for (OpRegistrationData reg_data : registration_datas) {
GELOGD("Begin to register optype: %s, imply_type: %s", reg_data.GetOmOptype().c_str(),
TypeUtils::ImplyTypeToSerialString(reg_data.GetImplyType()).c_str());
domi::OpRegistry::Instance()->Register(reg_data);
if (std::to_string(reg_data.GetFrameworkType()) == fmk_type) {
GELOGD("Begin to register optype: %s, imply_type: %s", reg_data.GetOmOptype().c_str(),
TypeUtils::ImplyTypeToSerialString(reg_data.GetImplyType()).c_str());
domi::OpRegistry::Instance()->Register(reg_data);
}
}
}

@ -1 +1 @@
Subproject commit c841458262316866e7bfa7783f7ee3205e12e2c9
Subproject commit befc2aac08de4f1f1c38e476c4d3fd53174653ff
Loading…
Cancel
Save