Save atomic kernel bin to model.

pull/1442/head
zhaozhixuan 4 years ago
parent 4ee06ea0a6
commit 6c992375aa

@ -35,11 +35,11 @@ void TBEKernelStore::LoadTBEKernelBinToOpDesc(const std::shared_ptr<ge::OpDesc>
GELOGI("Load tbe kernel:%s, %zu", kernel_bin->GetName().c_str(), kernel_bin->GetBinDataSize()); GELOGI("Load tbe kernel:%s, %zu", kernel_bin->GetName().c_str(), kernel_bin->GetBinDataSize());
std::string atomic_kernel_name; std::string atomic_kernel_name;
(void) AttrUtils::GetStr(op_desc, "ATOMIC_ATTR_TBE_KERNEL_NAME", atomic_kernel_name); (void) AttrUtils::GetStr(op_desc, ATOMIC_ATTR_TBE_KERNEL_NAME, atomic_kernel_name);
if (!atomic_kernel_name.empty()) { if (!atomic_kernel_name.empty()) {
GELOGI("Get atomic kernel name is %s", atomic_kernel_name.c_str()); GELOGI("Get atomic kernel name is %s.", atomic_kernel_name.c_str());
auto atomic_kernel_bin = FindKernel(atomic_kernel_name); auto atomic_kernel_bin = FindKernel(atomic_kernel_name);
GE_IF_BOOL_EXEC(!op_desc->SetExtAttr("EXT_ATTR_ATOMIC_TBE_KERNEL", atomic_kernel_bin), GE_IF_BOOL_EXEC(!op_desc->SetExtAttr(EXT_ATTR_ATOMIC_TBE_KERNEL, atomic_kernel_bin),
GELOGW("LoadKernelTBEBinToOpDesc: SetExtAttr for atomic kernel_bin failed");) GELOGW("LoadKernelTBEBinToOpDesc: SetExtAttr for atomic kernel_bin failed");)
} }
} }

@ -654,7 +654,7 @@ Status ModelBuilder::SaveAtomicTBEKernel(const OpDescPtr &op_desc) {
} }
tbe_kernel_store_.AddTBEKernel(tbe_kernel); tbe_kernel_store_.AddTBEKernel(tbe_kernel);
GELOGD("Atomic_clean_node tbe_kernel_name %s!", tbe_kernel->GetName().c_str()); GELOGD("Atomic_clean_node tbe_kernel_name %s!", tbe_kernel->GetName().c_str());
(void) AttrUtils::SetStr(op_desc, "ATOMIC_ATTR_TBE_KERNEL_NAME", tbe_kernel->GetName()); (void) AttrUtils::SetStr(op_desc, ATOMIC_ATTR_TBE_KERNEL_NAME, tbe_kernel->GetName());
std::string kernel_name; std::string kernel_name;
(void) AttrUtils::GetStr(atomic_op_desc, atomic_op_desc->GetName() + "_kernelname", kernel_name); (void) AttrUtils::GetStr(atomic_op_desc, atomic_op_desc->GetName() + "_kernelname", kernel_name);
@ -662,11 +662,11 @@ Status ModelBuilder::SaveAtomicTBEKernel(const OpDescPtr &op_desc) {
std::string meta_data; std::string meta_data;
(void) AttrUtils::GetStr(atomic_op_desc, TVM_ATTR_NAME_METADATA, meta_data); (void) AttrUtils::GetStr(atomic_op_desc, TVM_ATTR_NAME_METADATA, meta_data);
(void) AttrUtils::SetStr(op_desc, "ATOMIC_ATTR_TVM_METADATA", meta_data); (void) AttrUtils::SetStr(op_desc, ATOMIC_ATTR_TVM_METADATA, meta_data);
std::string json_string; std::string json_string;
(void) AttrUtils::GetStr(atomic_op_desc, TVM_ATTR_NAME_MAGIC, json_string); (void) AttrUtils::GetStr(atomic_op_desc, TVM_ATTR_NAME_MAGIC, json_string);
(void) AttrUtils::SetStr(op_desc, "ATOMIC_ATTR_TVM_MAGIC", json_string); (void) AttrUtils::SetStr(op_desc, ATOMIC_ATTR_TVM_MAGIC, json_string);
return SUCCESS; return SUCCESS;
} }

@ -538,15 +538,15 @@ std::string AtomicAddrCleanOpTask::GetKeyForOpParamSize() const {
} }
std::string AtomicAddrCleanOpTask::GetKeyForTbeKernel() const { std::string AtomicAddrCleanOpTask::GetKeyForTbeKernel() const {
return "EXT_ATTR_ATOMIC_TBE_KERNEL"; return EXT_ATTR_ATOMIC_TBE_KERNEL;
} }
std::string AtomicAddrCleanOpTask::GetKeyForTvmMagic() const { std::string AtomicAddrCleanOpTask::GetKeyForTvmMagic() const {
return "ATOMIC_ATTR_TVM_MAGIC"; return ATOMIC_ATTR_TVM_MAGIC;
} }
std::string AtomicAddrCleanOpTask::GetKeyForTvmMetaData() const { std::string AtomicAddrCleanOpTask::GetKeyForTvmMetaData() const {
return "ATOMIC_ATTR_TVM_METADATA"; return ATOMIC_ATTR_TVM_METADATA;
} }
std::string AtomicAddrCleanOpTask::GetKeyForKernelName(const OpDesc &op_desc) const { std::string AtomicAddrCleanOpTask::GetKeyForKernelName(const OpDesc &op_desc) const {

@ -1 +1 @@
Subproject commit 4ff5e3987f2e5d2980019defacaf0891861c84fc Subproject commit 366b15574218befa11454311879a4f436eeb67a9

@ -1 +1 @@
Subproject commit 51fb6c4850906e8342598d47eccfca0b87ffea59 Subproject commit d744541c6ca7f6966c1befacc9f83f53b0829e0a

@ -476,8 +476,8 @@ TEST_F(UtestGeHybrid, test_key_for_kernel_bin) {
EXPECT_EQ(aicore_task->GetKeyForKernelName(op_desc), "Sum_kernelname"); EXPECT_EQ(aicore_task->GetKeyForKernelName(op_desc), "Sum_kernelname");
auto atomic_task = std::unique_ptr<hybrid::AtomicAddrCleanOpTask>(new(std::nothrow)hybrid::AtomicAddrCleanOpTask()); auto atomic_task = std::unique_ptr<hybrid::AtomicAddrCleanOpTask>(new(std::nothrow)hybrid::AtomicAddrCleanOpTask());
EXPECT_EQ(atomic_task->GetKeyForTbeKernel(), "EXT_ATTR_ATOMIC_TBE_KERNEL"); EXPECT_EQ(atomic_task->GetKeyForTbeKernel(), EXT_ATTR_ATOMIC_TBE_KERNEL);
EXPECT_EQ(atomic_task->GetKeyForTvmMagic(), "ATOMIC_ATTR_TVM_MAGIC"); EXPECT_EQ(atomic_task->GetKeyForTvmMagic(), ATOMIC_ATTR_TVM_MAGIC);
EXPECT_EQ(atomic_task->GetKeyForTvmMetaData(), "ATOMIC_ATTR_TVM_METADATA"); EXPECT_EQ(atomic_task->GetKeyForTvmMetaData(), ATOMIC_ATTR_TVM_METADATA);
EXPECT_EQ(atomic_task->GetKeyForKernelName(op_desc), "Sum_atomic_kernelname"); EXPECT_EQ(atomic_task->GetKeyForKernelName(op_desc), "Sum_atomic_kernelname");
} }

Loading…
Cancel
Save