diff --git a/mindspore/_extends/parallel_compile/tbe_compiler/common.py b/mindspore/_extends/parallel_compile/tbe_compiler/common.py index 7287bace95..5b259a7da1 100644 --- a/mindspore/_extends/parallel_compile/tbe_compiler/common.py +++ b/mindspore/_extends/parallel_compile/tbe_compiler/common.py @@ -13,7 +13,6 @@ # limitations under the License. # ============================================================================ """tbe common""" -import json import os class TBEException(Exception): @@ -27,23 +26,6 @@ class TBEException(Exception): return self.__error_msg -def get_ddk_version(): - """get ddk version""" - ddk_version = os.environ.get("DDK_VERSION") - if ddk_version is None: - default_ddk_info_file = '/usr/local/HiAI/runtime/ddk_info' - backup_ddk_info_file = '/usr/local/Ascend/fwkacllib/ddk_info' - if os.path.exists(default_ddk_info_file): - with open(default_ddk_info_file, "r") as fp: - ddk_version = json.load(fp)["VERSION"] - elif os.path.exists(backup_ddk_info_file): - with open(backup_ddk_info_file, "r") as fp: - ddk_version = json.load(fp)["VERSION"] - else: - ddk_version = "Ascend910" - return ddk_version - - def get_build_in_impl_path(): """get build-in tbe implement path""" tbe_impl_path = os.environ.get("TBE_IMPL_PATH") diff --git a/mindspore/_extends/parallel_compile/tbe_compiler/compiler.py b/mindspore/_extends/parallel_compile/tbe_compiler/compiler.py index e1f23a7549..c6e39a41a8 100755 --- a/mindspore/_extends/parallel_compile/tbe_compiler/compiler.py +++ b/mindspore/_extends/parallel_compile/tbe_compiler/compiler.py @@ -18,9 +18,8 @@ import os import sys from te.platform.cce_conf import te_set_version from te.platform.fusion_util import fusion_op -from common import check_kernel_info, get_args, get_build_in_impl_path, get_ddk_version +from common import check_kernel_info, get_args, get_build_in_impl_path -ddk_version = get_ddk_version() build_in_impl_path = get_build_in_impl_path() # op function list @@ -30,7 +29,6 @@ fusion_pattern_end_flag = "fusion_pattern_end" def _initialize(impl_path): """Initialize""" - te_set_version(ddk_version) if impl_path == "": op_module_name = build_in_impl_path else: @@ -53,7 +51,7 @@ def build_op(build_type, json_str): """ kernel_info = json.loads(json_str) check_kernel_info(kernel_info) - + te_set_version(kernel_info["op_info"]["socVersion"]) op_name = kernel_info['op_info']['name'] try: @@ -111,7 +109,7 @@ def compile_fusion_op(json_str): Exception: If specific keyword is not found. """ args = json.loads(json_str) - te_set_version(ddk_version) + te_set_version(args['fusion_op']["socVersion"]) if 'fusion_op' not in args or not args['fusion_op']: raise ValueError("Json string Errors, key:fusion_op not found.") fusion_op_arg = args['fusion_op'] diff --git a/mindspore/ccsrc/backend/kernel_compiler/tbe/tbe_kernel_build.cc b/mindspore/ccsrc/backend/kernel_compiler/tbe/tbe_kernel_build.cc index c7754156ab..84d36c4de4 100644 --- a/mindspore/ccsrc/backend/kernel_compiler/tbe/tbe_kernel_build.cc +++ b/mindspore/ccsrc/backend/kernel_compiler/tbe/tbe_kernel_build.cc @@ -25,6 +25,7 @@ #include "backend/kernel_compiler/tbe/tbe_convert_utils.h" #include "backend/kernel_compiler/tbe/tbe_utils.h" #include "utils/ms_context.h" +#include "runtime/dev.h" namespace mindspore { namespace kernel { @@ -86,6 +87,8 @@ constexpr auto kJPyModulePath = "py_module_path"; constexpr auto kJPreBuildOutsAttrs = "prebuild_outs_attrs"; constexpr auto kJKwdArgs = "kwds_args"; constexpr auto kJListArgs = "list_args"; +constexpr auto kJSocVersion = "socVersion"; +constexpr auto kSOC_VERSION = "SOC_VERSION"; bool TbeKernelJsonCreator::GenTbeSingleKernelJson(const std::shared_ptr &anf_node, nlohmann::json *kernel_json) { @@ -122,6 +125,8 @@ bool TbeKernelJsonCreator::GenTbeSingleKernelJson(const std::shared_ptr()(json_str); auto context_ptr = MsContext::GetInstance(); @@ -414,6 +419,30 @@ bool TbeKernelJsonCreator::GenTbeAttrJson(const std::shared_ptr &anf_no return true; } +string TbeKernelJsonCreator::GetSocVersion() { + // Get default soc version. + const int kSocVersionLen = 50; + char soc_version[kSocVersionLen] = {0}; + auto ret = rtGetSocVersion(soc_version, kSocVersionLen); + if (ret != RT_ERROR_NONE) { + MS_LOG(EXCEPTION) << "GetSocVersion failed."; + } + MS_LOG(INFO) << "Default SocVersion is " << soc_version; + // Get soc version from env value. + const char *soc_version_env = getenv(kSOC_VERSION); + if (soc_version_env != nullptr) { + if (std::strcmp(soc_version, soc_version_env) != 0) { + MS_LOG(WARNING) << "SocVerison change to " << soc_version_env; + ret = rtSetSocVersion(soc_version_env); + if (ret != RT_ERROR_NONE) { + MS_LOG(EXCEPTION) << "SetSocVersion to " << soc_version_env << " failed, errorno: " << ret; + } + return soc_version_env; + } + } + return soc_version; +} + void TbeKernelJsonCreator::ParseAttrValue(const std::string &type, const mindspore::ValuePtr &value, nlohmann::json *attr_obj) { MS_EXCEPTION_IF_NULL(value); @@ -630,6 +659,8 @@ bool TbeKernelBuild::GenFusionScopeJson(const std::vector index = 0; data_list.insert(data_list.end(), compute_list.begin(), compute_list.end()); (*fusion_json)[kFusionOpList] = data_list; + auto soc_version = TbeKernelJsonCreator::GetSocVersion(); + (*fusion_json)[kJSocVersion] = soc_version; return true; } @@ -859,6 +890,7 @@ bool TbeKernelBuild::GenFusionDataInputJson(const std::shared_ptremplace_back(nullptr); input_desc_list_tmp.emplace_back(optional_input_desc); diff --git a/mindspore/ccsrc/backend/kernel_compiler/tbe/tbe_kernel_build.h b/mindspore/ccsrc/backend/kernel_compiler/tbe/tbe_kernel_build.h index d4cfe7866d..9c760bdcbb 100644 --- a/mindspore/ccsrc/backend/kernel_compiler/tbe/tbe_kernel_build.h +++ b/mindspore/ccsrc/backend/kernel_compiler/tbe/tbe_kernel_build.h @@ -92,6 +92,7 @@ class TbeKernelJsonCreator { std::string json_name() { return json_name_; } bool GenTbeAttrJson(const std::shared_ptr &anf_node, const std::shared_ptr &op_info, nlohmann::json *attrs_json); + static string GetSocVersion(); private: bool GenTbeInputsJson(const std::shared_ptr &anf_node, const std::shared_ptr &op_info, diff --git a/tests/ut/cpp/stub/runtime/runtime_stub.cc b/tests/ut/cpp/stub/runtime/runtime_stub.cc index f5ca261cdc..400153aa62 100644 --- a/tests/ut/cpp/stub/runtime/runtime_stub.cc +++ b/tests/ut/cpp/stub/runtime/runtime_stub.cc @@ -53,6 +53,10 @@ rtError_t rtFunctionRegister(void *binHandle, const void *stubFunc, const char * return RT_ERROR_NONE; } +RTS_API rtError_t rtSetSocVersion(const char *version) { return RT_ERROR_NONE; } + +rtError_t rtGetSocVersion(char *version, const uint32_t maxLen) { return RT_ERROR_NONE; } + rtError_t rtKernelLaunch(const void *stubFunc, uint32_t blockDim, void *args, uint32_t argsSize, rtSmDesc_t *smDesc, rtStream_t stream) { return RT_ERROR_NONE;