|
|
|
@ -25,6 +25,7 @@
|
|
|
|
|
#include "backend/kernel_compiler/tbe/tbe_convert_utils.h"
|
|
|
|
|
#include "backend/kernel_compiler/tbe/tbe_utils.h"
|
|
|
|
|
#include "utils/ms_context.h"
|
|
|
|
|
#include "runtime/dev.h"
|
|
|
|
|
|
|
|
|
|
namespace mindspore {
|
|
|
|
|
namespace kernel {
|
|
|
|
@ -86,6 +87,8 @@ constexpr auto kJPyModulePath = "py_module_path";
|
|
|
|
|
constexpr auto kJPreBuildOutsAttrs = "prebuild_outs_attrs";
|
|
|
|
|
constexpr auto kJKwdArgs = "kwds_args";
|
|
|
|
|
constexpr auto kJListArgs = "list_args";
|
|
|
|
|
constexpr auto kJSocVersion = "socVersion";
|
|
|
|
|
constexpr auto kSOC_VERSION = "SOC_VERSION";
|
|
|
|
|
|
|
|
|
|
bool TbeKernelJsonCreator::GenTbeSingleKernelJson(const std::shared_ptr<mindspore::AnfNode> &anf_node,
|
|
|
|
|
nlohmann::json *kernel_json) {
|
|
|
|
@ -122,6 +125,8 @@ bool TbeKernelJsonCreator::GenTbeSingleKernelJson(const std::shared_ptr<mindspor
|
|
|
|
|
nlohmann::json attrs_json;
|
|
|
|
|
(void)GenTbeAttrJson(anf_node, op_info_ptr, &attrs_json);
|
|
|
|
|
op_info_json[kJAttrs] = attrs_json;
|
|
|
|
|
auto soc_version = TbeKernelJsonCreator::GetSocVersion();
|
|
|
|
|
op_info_json[kJSocVersion] = soc_version;
|
|
|
|
|
std::string json_str = op_info_json.dump();
|
|
|
|
|
size_t hash_id = std::hash<std::string>()(json_str);
|
|
|
|
|
auto context_ptr = MsContext::GetInstance();
|
|
|
|
@ -414,6 +419,30 @@ bool TbeKernelJsonCreator::GenTbeAttrJson(const std::shared_ptr<AnfNode> &anf_no
|
|
|
|
|
return true;
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
string TbeKernelJsonCreator::GetSocVersion() {
|
|
|
|
|
// Get default soc version.
|
|
|
|
|
const int kSocVersionLen = 50;
|
|
|
|
|
char soc_version[kSocVersionLen] = {0};
|
|
|
|
|
auto ret = rtGetSocVersion(soc_version, kSocVersionLen);
|
|
|
|
|
if (ret != RT_ERROR_NONE) {
|
|
|
|
|
MS_LOG(EXCEPTION) << "GetSocVersion failed.";
|
|
|
|
|
}
|
|
|
|
|
MS_LOG(INFO) << "Default SocVersion is " << soc_version;
|
|
|
|
|
// Get soc version from env value.
|
|
|
|
|
const char *soc_version_env = getenv(kSOC_VERSION);
|
|
|
|
|
if (soc_version_env != nullptr) {
|
|
|
|
|
if (std::strcmp(soc_version, soc_version_env) != 0) {
|
|
|
|
|
MS_LOG(WARNING) << "SocVerison change to " << soc_version_env;
|
|
|
|
|
ret = rtSetSocVersion(soc_version_env);
|
|
|
|
|
if (ret != RT_ERROR_NONE) {
|
|
|
|
|
MS_LOG(EXCEPTION) << "SetSocVersion to " << soc_version_env << " failed, errorno: " << ret;
|
|
|
|
|
}
|
|
|
|
|
return soc_version_env;
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
return soc_version;
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
void TbeKernelJsonCreator::ParseAttrValue(const std::string &type, const mindspore::ValuePtr &value,
|
|
|
|
|
nlohmann::json *attr_obj) {
|
|
|
|
|
MS_EXCEPTION_IF_NULL(value);
|
|
|
|
@ -630,6 +659,8 @@ bool TbeKernelBuild::GenFusionScopeJson(const std::vector<mindspore::AnfNodePtr>
|
|
|
|
|
index = 0;
|
|
|
|
|
data_list.insert(data_list.end(), compute_list.begin(), compute_list.end());
|
|
|
|
|
(*fusion_json)[kFusionOpList] = data_list;
|
|
|
|
|
auto soc_version = TbeKernelJsonCreator::GetSocVersion();
|
|
|
|
|
(*fusion_json)[kJSocVersion] = soc_version;
|
|
|
|
|
return true;
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
@ -859,6 +890,7 @@ bool TbeKernelBuild::GenFusionDataInputJson(const std::shared_ptr<mindspore::Anf
|
|
|
|
|
(*data_str)[kJName] = name;
|
|
|
|
|
nlohmann::json output_desc;
|
|
|
|
|
output_desc[kJName] = name;
|
|
|
|
|
output_desc[kJDataType] = 0;
|
|
|
|
|
output_desc[kJShape] = "NULL";
|
|
|
|
|
output_desc_list.push_back(output_desc);
|
|
|
|
|
(*index)++;
|
|
|
|
@ -991,6 +1023,7 @@ bool TbeKernelBuild::GenFusionComputeInputJson(const mindspore::CNodePtr &cnode,
|
|
|
|
|
for (size_t i = 0; i < optional_num; ++i) {
|
|
|
|
|
nlohmann::json optional_input_desc;
|
|
|
|
|
optional_input_desc[kJName] = std::string(kOptional) + std::to_string(*index);
|
|
|
|
|
optional_input_desc[kJShape] = "NULL";
|
|
|
|
|
(*index)++;
|
|
|
|
|
(*layer_iter)->emplace_back(nullptr);
|
|
|
|
|
input_desc_list_tmp.emplace_back(optional_input_desc);
|
|
|
|
|