fusion op insert cache

pull/9385/head
jjfeing 4 years ago
parent b5aec1fd45
commit a8366502e2

@ -25,9 +25,30 @@
#include "backend/kernel_compiler/tbe/tbe_convert_utils.h"
#include "utils/ms_context.h"
namespace mindspore {
namespace kernel {
namespace mindspore::kernel {
using mindspore::kernel::tbe::TbeUtils;
static size_t GenFusionJsonHash(const nlohmann::json &fusion_json) {
// get an copy
nlohmann::json fusion_json_copy = fusion_json;
auto &op_lists = fusion_json_copy["op_list"];
for (auto &op : op_lists) {
op.erase("name");
for (auto &output_desc : op["output_desc"]) {
output_desc.erase("name");
}
if (op["type"] != "Data") {
for (auto &input_desc : op["input_desc"]) {
input_desc.erase("name");
}
for (auto &list_arg : op["prebuild_output_attrs"]["list_args"]) {
list_arg.erase("name");
}
}
}
return std::hash<std::string>()(fusion_json_copy.dump());
}
std::map<int64_t, KernelModPtr> KernelFusion(const std::vector<FusionScopeInfo> &fusion_scopes) {
MS_LOG(INFO) << "kernel fusion build start, scope size:" << fusion_scopes.size();
std::map<int64_t, KernelModPtr> kernel_mod_ret;
@ -41,8 +62,8 @@ std::map<int64_t, KernelModPtr> KernelFusion(const std::vector<FusionScopeInfo>
continue;
}
// gen kernel_name & check cache
std::string json_str = fusion_op.dump();
size_t hash_id = std::hash<std::string>()(json_str);
size_t hash_id = GenFusionJsonHash(fusion_op);
MS_LOG(INFO) << "Fusion op hash id: " << hash_id;
auto context_ptr = MsContext::GetInstance();
MS_EXCEPTION_IF_NULL(context_ptr);
auto device_id = context_ptr->get_param<uint32_t>(MS_CTX_DEVICE_ID);
@ -102,5 +123,4 @@ std::map<int64_t, KernelModPtr> KernelFusion(const std::vector<FusionScopeInfo>
MS_LOG(INFO) << "Build Fusion Kernel Failed Num: " << build_failed_num;
return kernel_mod_ret;
}
} // namespace kernel
} // namespace mindspore
} // namespace mindspore::kernel

Loading…
Cancel
Save