diff --git a/CMakeLists.txt b/CMakeLists.txt index fb526f1b..56f6ad63 100755 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -8,6 +8,19 @@ if (NOT BUILD_PATH) set(BUILD_PATH "${CMAKE_SOURCE_DIR}/build") endif() +if(DEFINED ENV{ASCEND_CUSTOM_PATH}) + set(ASCEND_DIR $ENV{ASCEND_CUSTOM_PATH}) +else() + set(ASCEND_DIR /usr/local/Ascend) +endif() +set(ASCEND_DRIVER_DIR ${ASCEND_DIR}/driver/lib64) +set(ASCEND_DRIVER_COMMON_DIR ${ASCEND_DIR}/driver/lib64/common) +set(ASCEND_DRIVER_SHARE_DIR ${ASCEND_DIR}/driver/lib64/share) +set(ASCEND_RUNTIME_DIR ${ASCEND_DIR}/fwkacllib/lib64) +set(ASCEND_ATC_DIR ${ASCEND_DIR}/atc/lib64) +set(ASCEND_ACL_DIR ${ASCEND_DIR}/acllib/lib64) +set(STATIC_ACL_LIB ${ASCEND_ACL_DIR}) + option(ENABLE_OPEN_SRC "Enable graphengine compile in opensource." FALSE) if (ENABLE_OPEN_SRC) @@ -41,7 +54,7 @@ if (ENABLE_OPEN_SRC) message(FATAL_ERROR "Running on a unsupported architecture: ${SYSTEM_TYPE}, build terminated") endif() set(GE_LIB_PATH ${GE_LIB_PATH}/${GE_SYS_ARCH}) - set(STATIC_ACL_LIB ${GE_LIB_PATH}) + set(STATIC_ACL_LIB ${GE_LIB_PATH}) find_module(slog libslog.so ${GE_LIB_PATH}) find_module(mmpa libmmpa.so ${GE_LIB_PATH}) find_module(msprof libmsprof.so ${GE_LIB_PATH}) @@ -56,18 +69,6 @@ if (ENABLE_OPEN_SRC) find_module(msprofiler libmsprofiler.a ${GE_LIB_PATH}) #find_module(ascendcl_static libascendcl.a ${GE_LIB_PATH}) else() - if(DEFINED ENV{ASCEND_CUSTOM_PATH}) - set(ASCEND_DIR $ENV{ASCEND_CUSTOM_PATH}) - else() - set(ASCEND_DIR /usr/local/Ascend) - endif() - set(ASCEND_DRIVER_DIR ${ASCEND_DIR}/driver/lib64) - set(ASCEND_DRIVER_COMMON_DIR ${ASCEND_DIR}/driver/lib64/common) - set(ASCEND_DRIVER_SHARE_DIR ${ASCEND_DIR}/driver/lib64/share) - set(ASCEND_RUNTIME_DIR ${ASCEND_DIR}/fwkacllib/lib64) - set(ASCEND_ATC_DIR ${ASCEND_DIR}/atc/lib64) - set(ASCEND_ACL_DIR ${ASCEND_DIR}/acllib/lib64) - set(STATIC_ACL_LIB ${ASCEND_ACL_DIR}) find_module(slog libslog.so ${ASCEND_ATC_DIR}) find_module(mmpa libmmpa.so ${ASCEND_ATC_DIR}) if(PLATFORM STREQUAL "train") @@ -127,6 +128,36 @@ if (ENABLE_OPEN_SRC) add_subdirectory(parser) #add_subdirectory(metadef/graph) #add_subdirectory(metadef/register) +elseif (ENABLE_D OR ENABLE_ACL) + # compiling with MindSpore + include(cmake/external_libs/protobuf_static.cmake) + include(cmake/external_libs/protoc.cmake) + include(cmake/external_libs/securec.cmake) + include(cmake/external_libs/json.cmake) + include(cmake/FindModule.cmake) + include(cmake/intf_pub_linux.cmake) + + # common modules + find_module(slog libslog.so ${ASCEND_DRIVER_COMMON_DIR}) + find_module(mmpa libmmpa.so ${ASCEND_DRIVER_COMMON_DIR}) + + if (ENABLE_D) + # training + find_module(runtime libruntime.so ${ASCEND_RUNTIME_DIR}) + find_module(error_manager liberror_manager.so ${ASCEND_RUNTIME_DIR}) + find_module(register libregister.so ${ASCEND_RUNTIME_DIR}) + find_module(resource libresource.so ${ASCEND_RUNTIME_DIR}) + elseif(ENABLE_ACL) + # inference + find_module(msprof libmsprof.so ${ASCEND_DRIVER_COMMON_DIR}) + find_module(adump_server libadump_server.a ${ASCEND_ACL_DIR}) + find_module(runtime_compile libruntime_compile.so ${ASCEND_ATC_DIR}) + find_module(resource libresource.so ${ASCEND_ATC_DIR}) + find_module(error_manager liberror_manager.so ${ASCEND_ATC_DIR}) + endif () + + set(METADEF_DIR ${CMAKE_CURRENT_LIST_DIR}/metadef) + add_subdirectory(metadef) else() set(METADEF_DIR ${CMAKE_CURRENT_LIST_DIR}/../metadef) set(PARSER_DIR ${CMAKE_CURRENT_LIST_DIR}/../parser) diff --git a/cmake/external_libs/protobuf_static.cmake b/cmake/external_libs/protobuf_static.cmake index 25570cef..51f4ed25 100755 --- a/cmake/external_libs/protobuf_static.cmake +++ b/cmake/external_libs/protobuf_static.cmake @@ -48,5 +48,8 @@ set_target_properties(ascend_protobuf_static_lib PROPERTIES add_library(ascend_protobuf_static INTERFACE) target_include_directories(ascend_protobuf_static INTERFACE ${PROTOBUF_STATIC_PKG_DIR}/include) target_link_libraries(ascend_protobuf_static INTERFACE ascend_protobuf_static_lib) +if (ENABLE_D OR ENABLE_ACL) +include_directories(${PROTOBUF_STATIC_PKG_DIR}/include) +endif () add_dependencies(ascend_protobuf_static protobuf_static_build) diff --git a/ge/CMakeLists.txt b/ge/CMakeLists.txt index 896c767d..b1dd11b9 100755 --- a/ge/CMakeLists.txt +++ b/ge/CMakeLists.txt @@ -1,10 +1,15 @@ -add_subdirectory(common) -add_subdirectory(plugin/engine) -add_subdirectory(graph/build/memory) -add_subdirectory(ge_local_engine) -add_subdirectory(host_cpu_engine) -add_subdirectory(executor) -add_subdirectory(offline) +if (NOT ENABLE_D AND NOT ENABLE_ACL) + add_subdirectory(common) + add_subdirectory(plugin/engine) + add_subdirectory(graph/build/memory) + add_subdirectory(ge_local_engine) + add_subdirectory(host_cpu_engine) + add_subdirectory(executor) + add_subdirectory(offline) +else() + add_subdirectory(common) + add_subdirectory(ge_runtime) +endif () set(PROTO_LIST "${METADEF_DIR}/proto/fusion_model.proto" @@ -28,7 +33,6 @@ protobuf_generate(ge PROTO_SRCS PROTO_HDRS ${PROTO_LIST}) protobuf_generate(ge PROTO_CLIENT_SRCS PROTO_CLIENT_HDRS ${PROTO_CLIENT_LIST}) protobuf_generate(ge PROTO_HEADER_SRCS PROTO_HEADER_HDRS ${PROTO_HEADER_LIST}) -############ libge_runner.so ############ set(TRAIN_SRC_LIST "common/formats/format_transfers/datatype_transfer.cc" "common/formats/format_transfers/format_transfer_c1hwncoc0_hwcn.cc" @@ -333,72 +337,6 @@ set(TRAIN_SRC_LIST "ir_build/atc_ir_common.cc" ) -add_library(ge_runner SHARED ${TRAIN_SRC_LIST} ${PROTO_SRCS} ${PROTO_CLIENT_SRCS}) - -target_compile_definitions(ge_runner PRIVATE - PROTOBUF_INLINE_NOT_IN_HEADERS=0 - DAVINCI_SUPPORT_PROFILING - REUSE_MEMORY=1 - FMK_SUPPORT_DUMP - DAVINCI_CLOUD - google=ascend_private -) - -target_compile_options(ge_runner PRIVATE - -O2 -) - -target_include_directories(ge_runner PRIVATE - ${GE_CODE_DIR}/ge - ${GE_CODE_DIR}/ge/analyzer - ${GE_CODE_DIR}/inc - ${GE_CODE_DIR}/inc/external - ${GE_CODE_DIR}/inc/framework - ${GE_CODE_DIR}/inc/framework/common - ${METADEF_DIR} - ${METADEF_DIR}/inc - ${METADEF_DIR}/inc/external/graph - ${METADEF_DIR}/inc/external - ${METADEF_DIR}/inc/graph - ${CMAKE_BINARY_DIR} - ${CMAKE_BINARY_DIR}/proto/ge - #### yellow zone #### - ${GE_CODE_DIR}/../inc - ${GE_CODE_DIR}/../inc/external - ${GE_CODE_DIR}/../inc/cce - ${GE_CODE_DIR}/../toolchain/ide/ide-daemon/external - #### blue zone - ${ASCEND_DIR}/driver/include - ${ASCEND_DIR}/fwkacllib/include - ${GE_CODE_DIR}/third_party/fwkacllib/inc - ${GE_CODE_DIR}/third_party/fwkacllib/inc/toolchain -) - -target_link_libraries(ge_runner - $ - ge_memory - adump_server - msprofiler - -Wl,--no-as-needed - graph - ge_common - ascend_protobuf - register - c_sec - slog - mmpa - msprof - runtime - resource - error_manager - ascend_hal_stub - -Wl,--as-needed - json - -lrt - -ldl -) - -############ libge_compiler.so ############ set(INFER_SRC_LIST "graph/manager/trans_var_data_utils.cc" "omm/csa_interact.cc" @@ -662,6 +600,74 @@ set(INFER_SRC_LIST "analyzer/analyzer.cc" ) +if (NOT ENABLE_D AND NOT ENABLE_ACL) +############ libge_runner.so ############ +add_library(ge_runner SHARED ${TRAIN_SRC_LIST} ${PROTO_SRCS} ${PROTO_CLIENT_SRCS}) + +target_compile_definitions(ge_runner PRIVATE + PROTOBUF_INLINE_NOT_IN_HEADERS=0 + DAVINCI_SUPPORT_PROFILING + REUSE_MEMORY=1 + FMK_SUPPORT_DUMP + DAVINCI_CLOUD + google=ascend_private +) + +target_compile_options(ge_runner PRIVATE + -O2 +) + +target_include_directories(ge_runner PRIVATE + ${GE_CODE_DIR}/ge + ${GE_CODE_DIR}/ge/analyzer + ${GE_CODE_DIR}/inc + ${GE_CODE_DIR}/inc/external + ${GE_CODE_DIR}/inc/framework + ${GE_CODE_DIR}/inc/framework/common + ${METADEF_DIR} + ${METADEF_DIR}/inc + ${METADEF_DIR}/inc/external/graph + ${METADEF_DIR}/inc/external + ${METADEF_DIR}/inc/graph + ${CMAKE_BINARY_DIR} + ${CMAKE_BINARY_DIR}/proto/ge + #### yellow zone #### + ${GE_CODE_DIR}/../inc + ${GE_CODE_DIR}/../inc/external + ${GE_CODE_DIR}/../inc/cce + ${GE_CODE_DIR}/../toolchain/ide/ide-daemon/external + #### blue zone + ${ASCEND_DIR}/driver/include + ${ASCEND_DIR}/fwkacllib/include + ${GE_CODE_DIR}/third_party/fwkacllib/inc + ${GE_CODE_DIR}/third_party/fwkacllib/inc/toolchain +) + +target_link_libraries(ge_runner + $ + ge_memory + adump_server + msprofiler + -Wl,--no-as-needed + graph + ge_common + ascend_protobuf + register + c_sec + slog + mmpa + msprof + runtime + resource + error_manager + ascend_hal_stub + -Wl,--as-needed + json + -lrt + -ldl +) + +############ libge_compiler.so ############ add_library(ge_compiler SHARED ${INFER_SRC_LIST} ${PROTO_SRCS}) target_compile_definitions(ge_compiler PRIVATE @@ -919,3 +925,70 @@ install(FILES ${CMAKE_CURRENT_BINARY_DIR}/optimizer_priority.pbtxt OPTIONAL DESTINATION ${INSTALL_LIBRARY_DIR} ) + +elseif (ENABLE_ACL) + +############ libge_compiler.so w/static protobuf ############ +add_library(ge_compiler SHARED ${INFER_SRC_LIST} ${PROTO_SRCS}) + +target_compile_definitions(ge_compiler PRIVATE + PROTOBUF_INLINE_NOT_IN_HEADERS=0 + REUSE_MEMORY=1 + FMK_SUPPORT_DUMP + FMK_HOST_INFER + COMPILE_OMG_PACKAGE + google=ascend_private +) + +target_compile_options(ge_compiler PRIVATE + -O2 +) + +target_include_directories(ge_compiler PRIVATE + ${GE_CODE_DIR}/ge + ${GE_CODE_DIR}/ge/analyzer + ${GE_CODE_DIR}/inc + ${GE_CODE_DIR}/inc/external + ${GE_CODE_DIR}/inc/framework + ${GE_CODE_DIR}/inc/framework/common + ${METADEF_DIR} + ${METADEF_DIR}/inc + ${METADEF_DIR}/inc/external/graph + ${METADEF_DIR}/inc/external + ${METADEF_DIR}/inc/graph + ${CMAKE_BINARY_DIR} + ${CMAKE_BINARY_DIR}/proto/ge + ${ASCEND_DIR}/driver/include + ${ASCEND_DIR}/fwkacllib/include + ${GE_CODE_DIR}/third_party/fwkacllib/inc + ${GE_CODE_DIR}/third_party/fwkacllib/inc/toolchain +) + +target_link_libraries(ge_compiler + $ + ge_memory + -Wl,--no-as-needed + graph + ge_common + static_ascend_protobuf + register + c_sec + error_manager + slog + mmpa + runtime_compile + resource + -Wl,--as-needed + json + -lrt + -ldl +) + +############ install libge_compiler for MindSpore############ +set(INSTALL_BASE_DIR "") +set(INSTALL_LIBRARY_DIR lib) + +install(TARGETS ge_compiler OPTIONAL + LIBRARY DESTINATION ${INSTALL_LIBRARY_DIR} +) +endif() diff --git a/ge/common/CMakeLists.txt b/ge/common/CMakeLists.txt index 4537f8a5..14e9b6af 100755 --- a/ge/common/CMakeLists.txt +++ b/ge/common/CMakeLists.txt @@ -63,6 +63,7 @@ set(SRC_LIST "ge/tbe_plugin_manager.cc" ) +if (NOT ENABLE_D AND NOT ENABLE_ACL) ############ libge_common.so ############ add_library(ge_common SHARED ${SRC_LIST} ${PROTO_HDRS}) target_compile_definitions(ge_common PRIVATE @@ -164,6 +165,57 @@ target_link_libraries(ge_common_static PRIVATE -ldl ) +else () +############ libge_common.so w/static protobuf ############ +add_library(ge_common SHARED ${SRC_LIST} ${PROTO_HDRS}) +target_compile_definitions(ge_common PRIVATE + PROTOBUF_INLINE_NOT_IN_HEADERS=0 + HOST_VISIBILITY + FMK_SUPPORT_DUMP + OS_CENTOS + google=ascend_private +) + +target_compile_options(ge_common PRIVATE + -fvisibility=hidden + -O2 + -Werror +) + +target_include_directories(ge_common PRIVATE + ${GE_CODE_DIR}/ge + ${GE_CODE_DIR}/ge/common + ${GE_CODE_DIR}/ge/common/op + ${GE_CODE_DIR}/inc/external + ${GE_CODE_DIR}/inc + ${GE_CODE_DIR}/inc/framework + ${METADEF_DIR}/inc + ${METADEF_DIR}/inc/external + ${METADEF_DIR}/inc/external/graph + ${METADEF_DIR}/inc/graph + ${CMAKE_BINARY_DIR} + ${CMAKE_BINARY_DIR}/proto/ge + ${GE_CODE_DIR}/third_party/fwkacllib/inc + ${GE_CODE_DIR}/third_party/fwkacllib/inc/toolchain +) + +target_link_libraries(ge_common PRIVATE + $ + ascend_protobuf_static + -Wl,--no-as-needed + graph + register + c_sec + error_manager + slog + mmpa + -Wl,--as-needed + json + -lrt + -ldl +) +endif () + ############ install ############ set(INSTALL_BASE_DIR "") set(INSTALL_LIBRARY_DIR lib) diff --git a/ge/ge_runtime/CMakeLists.txt b/ge/ge_runtime/CMakeLists.txt index f46e2cdc..42d3b344 100644 --- a/ge/ge_runtime/CMakeLists.txt +++ b/ge/ge_runtime/CMakeLists.txt @@ -27,14 +27,22 @@ target_compile_definitions(ge_runtime PRIVATE ) target_include_directories(ge_runtime PRIVATE - ${TOP_DIR} - ${TOP_DIR}/inc - ${TOP_DIR}/inc/graph - ${TOP_DIR}/inc/external - ${TOP_DIR}/inc/framework - ${TOP_DIR}/inc/framework/common - ${TOP_DIR}/inc/framework/ge_runtime - ${TOP_DIR}/inc/cce + ${CMAKE_CURRENT_LIST_DIR} + ${GE_CODE_DIR} + ${GE_CODE_DIR}/ge + ${GE_CODE_DIR}/inc + ${GE_CODE_DIR}/inc/graph + ${GE_CODE_DIR}/inc/external + ${GE_CODE_DIR}/inc/framework + ${GE_CODE_DIR}/inc/framework/common + ${GE_CODE_DIR}/inc/framework/ge_runtime + ${GE_CODE_DIR}/inc/cce + ${GE_CODE_DIR}/third_party/fwkacllib/inc + ${METADEF_DIR} + ${METADEF_DIR}/inc + ${METADEF_DIR}/inc/external/graph + ${METADEF_DIR}/inc/external + ${METADEF_DIR}/inc/graph ${CMAKE_BINARY_DIR} ${CMAKE_BINARY_DIR}/proto/ge ) @@ -45,6 +53,7 @@ target_link_libraries(ge_runtime PRIVATE slog runtime c_sec + graph -Wl,--as-needed -lrt -ldl diff --git a/ge/ge_runtime/model_context.h b/ge/ge_runtime/model_context.h old mode 100755 new mode 100644 index 8860f0da..259ff91f --- a/ge/ge_runtime/model_context.h +++ b/ge/ge_runtime/model_context.h @@ -1,5 +1,5 @@ /** - * Copyright 2020 Huawei Technologies Co., Ltd + * Copyright 2019-2020 Huawei Technologies Co., Ltd * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -27,8 +27,13 @@ class ModelContext { ModelContext(uint32_t device_id, uint64_t session_id, int32_t priority, rtModel_t rt_model_handle, rtStream_t rt_model_stream, const std::vector &stream_list, const std::vector &label_list, const std::vector &event_list) - : device_id_(device_id), session_id_(session_id), priority_(priority), rt_model_handle_(rt_model_handle), - rt_model_stream_(rt_model_stream), stream_list_(stream_list), label_list_(label_list), + : device_id_(device_id), + session_id_(session_id), + priority_(priority), + rt_model_handle_(rt_model_handle), + rt_model_stream_(rt_model_stream), + stream_list_(stream_list), + label_list_(label_list), event_list_(event_list) {} ~ModelContext() {} diff --git a/ge/ge_runtime/model_runner.cc b/ge/ge_runtime/model_runner.cc index 2c2efde4..9961ab4e 100644 --- a/ge/ge_runtime/model_runner.cc +++ b/ge/ge_runtime/model_runner.cc @@ -1,5 +1,5 @@ /** - * Copyright 2020 Huawei Technologies Co., Ltd + * Copyright 2019-2020 Huawei Technologies Co., Ltd * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -24,6 +24,7 @@ namespace ge { namespace model_runner { + using RuntimeModelPtr = std::shared_ptr; using DavinciModelPtr = std::shared_ptr; diff --git a/ge/ge_runtime/output.cc b/ge/ge_runtime/output.cc index eec8d170..90c33bb4 100644 --- a/ge/ge_runtime/output.cc +++ b/ge/ge_runtime/output.cc @@ -1,5 +1,5 @@ /** - * Copyright 2020 Huawei Technologies Co., Ltd + * Copyright 2019-2020 Huawei Technologies Co., Ltd * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -89,5 +89,6 @@ bool Output::SetDataBuf(DataBuffer &data_buf, uint32_t data_begin, uint32_t &dat bool support_mem_share) { return true; } + } // namespace model_runner } // namespace ge diff --git a/ge/ge_runtime/output.h b/ge/ge_runtime/output.h old mode 100755 new mode 100644 index 13ea956d..1f7f91ee --- a/ge/ge_runtime/output.h +++ b/ge/ge_runtime/output.h @@ -1,5 +1,5 @@ /** - * Copyright 2020 Huawei Technologies Co., Ltd + * Copyright 2019-2020 Huawei Technologies Co., Ltd * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -24,6 +24,7 @@ namespace ge { namespace model_runner { + class Output { public: Output(const OpInfoPtr &op_info, const std::shared_ptr &model); @@ -32,8 +33,7 @@ class Output { bool CopyRslt(OutputData *rslt, uint32_t data_begin, uint32_t &data_index, bool support_mem_share); - bool SetDataBuf(DataBuffer &data_buf, uint32_t data_begin, uint32_t &data_count, size_t i, - bool support_mem_share); + bool SetDataBuf(DataBuffer &data_buf, uint32_t data_begin, uint32_t &data_count, size_t i, bool support_mem_share); // Copy assignment operator and copy constructor are deleted Output &operator=(const Output &output) = delete; diff --git a/ge/ge_runtime/runtime_model.cc b/ge/ge_runtime/runtime_model.cc index 0b76cbaf..0ff56ef1 100644 --- a/ge/ge_runtime/runtime_model.cc +++ b/ge/ge_runtime/runtime_model.cc @@ -74,8 +74,8 @@ bool RuntimeModel::InitStream(std::shared_ptr &davinci_model) { for (uint32_t i = 0; i < davinci_model->GetStreamNum(); ++i) { rtStream_t stream = nullptr; uint32_t flag = (force_copy_streams.find(i) != force_copy_streams.end()) - ? (RT_STREAM_PERSISTENT | RT_STREAM_FORCE_COPY) - : (RT_STREAM_PERSISTENT); + ? (RT_STREAM_PERSISTENT | RT_STREAM_FORCE_COPY) + : (RT_STREAM_PERSISTENT); rtError_t rt_ret = rtStreamCreateWithFlags(&stream, davinci_model->GetPriority(), flag); if (rt_ret != RT_ERROR_NONE) { @@ -115,23 +115,34 @@ bool RuntimeModel::InitEvent(uint32_t event_num) { return true; } -bool RuntimeModel::InitLabel(uint32_t batch_num) { - GELOGI("batch number:%u.", batch_num); - for (uint32_t i = 0; (batch_num != 0 && i <= batch_num); ++i) { - rtLabel_t rt_lLabel = nullptr; - rtError_t rt_ret = rtLabelCreate(&rt_lLabel); - if (rt_ret != RT_ERROR_NONE) { - GELOGE(RT_FAILED, "Call rt api rtLabelCreate failed, i; %u; ret: 0x%X", i, rt_ret); - return false; +bool RuntimeModel::InitLabel(std::shared_ptr &davinci_model) { + GELOGI("batch number:%u.", davinci_model->GetBatchNum()); + label_list_.resize(davinci_model->GetBatchNum()); + for (auto &task_info : davinci_model->GetTaskInfoList()) { + if (task_info == nullptr) { + GELOGE(PARAM_INVALID, "task_info is null."); + continue; + } + + if (task_info->type() != TaskInfoType::LABEL_SET) { + continue; } + auto label_set_task_info = std::static_pointer_cast(task_info); - if (rt_lLabel == nullptr) { - GELOGE(RT_FAILED, "rtLabel is nullptr!"); + if (label_set_task_info->stream_id() >= stream_list_.size()) { + GELOGE(PARAM_INVALID, "Invalid stream id."); return false; } - label_list_.emplace_back(rt_lLabel); + rtLabel_t rt_label = nullptr; + rtError_t rt_ret = rtLabelCreateEx(&rt_label, stream_list_[label_set_task_info->stream_id()]); + if (rt_ret != RT_ERROR_NONE) { + GELOGE(RT_FAILED, "Call rt api rtLabelCreate failed, ret: 0x%X", rt_ret); + return false; + } + label_list_[label_set_task_info->label_id()] = rt_label; } + return true; } @@ -163,7 +174,7 @@ bool RuntimeModel::InitResource(std::shared_ptr &davinci_model) { return false; } - if (!InitLabel(davinci_model->GetBatchNum())) { + if (!InitLabel(davinci_model)) { return false; } @@ -281,7 +292,6 @@ bool RuntimeModel::DistributeTask() { GELOGE(FAILED, "DistributeTask failed"); return false; } - return true; } @@ -293,10 +303,14 @@ bool RuntimeModel::Run() { return false; } - GELOGI("Run rtModelExecute success"); + GELOGI("Run rtModelExecute success, ret = 0x%X", ret); ret = rtStreamSynchronize(rt_model_stream_); if (ret != RT_ERROR_NONE) { + if (ret == RT_ERROR_END_OF_SEQUENCE) { + GELOGI("Model stream RT_ERROR_END_OF_SEQUENCE signal received, ret = 0x%X", ret); + return true; + } GELOGE(RT_FAILED, "Model stream sync failed, ret = 0x%X", ret); return false; } @@ -330,6 +344,9 @@ void RuntimeModel::RtStreamDestory() noexcept { void RuntimeModel::RtLabelDestory() noexcept { for (size_t i = 0; i < label_list_.size(); i++) { + if (label_list_[i] == nullptr) { + continue; + } if (rtLabelDestroy(label_list_[i]) != RT_ERROR_NONE) { GELOGE(RT_FAILED, "Destroy label failed! Index: %zu.", i); return; @@ -471,11 +488,8 @@ bool RuntimeModel::InitConstantInfo(std::shared_ptr &davinci_model /// The logic of GetShapeSize is wrong, the scaler tensor's GetShapeSize is zero /// and that of unknown shape is zero too. /// Unknown shape will not appear here, so we can use zero judge a tensor is scaler or not. - int64_t elem_num = constant->weight_tensors[0].GetShapeSize(); - if (elem_num == 0 && constant->weight_tensors[0].size == 0) { - elem_num = 1; - } - + int64_t elem_num = + (constant->weight_tensors[0].GetShapeSize() == 0) ? 1 : constant->weight_tensors[0].GetShapeSize(); if (constant->weight_data.size() < sizeof(uint64_t)) { GELOGE(FAILED, "weight_data size is smaller than sizeof(uint64_t)"); return false; diff --git a/ge/ge_runtime/runtime_model.h b/ge/ge_runtime/runtime_model.h index 6109915f..d0c466d4 100644 --- a/ge/ge_runtime/runtime_model.h +++ b/ge/ge_runtime/runtime_model.h @@ -40,13 +40,11 @@ class RuntimeModel { const std::vector &GetTaskIdList() const; const std::vector &GetStreamIdList() const; const std::map> &GetRuntimeInfoMap() const { return runtime_info_map_; } - const rtModel_t GetModelHandle() const { return rt_model_handle_; } + rtModel_t GetModelHandle() const { return rt_model_handle_; } bool Run(); bool CopyInputData(const InputData &input_data); - bool GetInputOutputDescInfo(bool zero_copy, - std::vector *input_desc, - std::vector *output_desc, - std::vector *input_format, + bool GetInputOutputDescInfo(bool zero_copy, std::vector *input_desc, + std::vector *output_desc, std::vector *input_format, std::vector *output_format); private: @@ -55,7 +53,7 @@ class RuntimeModel { bool LoadTask(); bool InitStream(std::shared_ptr &davinci_model); bool InitEvent(uint32_t event_num); - bool InitLabel(uint32_t batch_num); + bool InitLabel(std::shared_ptr &davinci_model); bool InitDataInfo(std::shared_ptr &davinci_model); bool InitOutputInfo(std::shared_ptr &davinci_model); bool InitConstantInfo(std::shared_ptr &davinci_model); @@ -87,6 +85,7 @@ class RuntimeModel { std::vector stream_id_list_{}; std::map> runtime_info_map_; }; + } // namespace model_runner } // namespace ge diff --git a/ge/ge_runtime/task/aicpu_task.cc b/ge/ge_runtime/task/aicpu_task.cc old mode 100755 new mode 100644 index 61ef7a3c..5b3d8e82 --- a/ge/ge_runtime/task/aicpu_task.cc +++ b/ge/ge_runtime/task/aicpu_task.cc @@ -26,6 +26,7 @@ AicpuTask::AicpuTask(const ModelContext &model_context, const std::shared_ptr(io_addrs.size()); auto io_addrs_size = static_cast(io_addrs_num * sizeof(void *)); constexpr uint32_t io_addr_offset = sizeof(aicpu::AicpuParamHead); - uint32_t node_def_addr_offset = io_addr_offset + io_addrs_size; - uint32_t args_size = - sizeof(aicpu::AicpuParamHead) + io_addrs_size + static_cast(task_info_->node_def().size()); - aicpu::AicpuParamHead aicpu_param_head = {args_size, io_addrs_num}; + uint32_t node_def_len_offset = io_addr_offset + io_addrs_size; + uint32_t node_def_addr_offset = node_def_len_offset + sizeof(uint32_t); + uint32_t args_size = sizeof(aicpu::AicpuParamHead) + io_addrs_size + + static_cast(task_info_->node_def().size()) + sizeof(uint32_t); + + aicpu::AicpuParamHead aicpu_param_head; + aicpu_param_head.length = args_size; + aicpu_param_head.ioAddrNum = io_addrs_num; + auto ext_info = task_info_->ext_info(); + uint32_t ext_size = ext_info.size(); + if (ext_info.empty()) { + aicpu_param_head.extInfoLength = 0; + aicpu_param_head.extInfoAddr = 0; + } else { + rtError_t flag = rtMalloc(&ext_info_, ext_size, RT_MEMORY_HBM); + if (flag != RT_ERROR_NONE) { + GELOGE(RT_FAILED, "Call rt api(rtMalloc) failed, ret: 0x%X.", flag); + return false; + } + + flag = rtMemcpy(ext_info_, ext_size, const_cast(reinterpret_cast(ext_info.data())), ext_size, + RT_MEMCPY_HOST_TO_DEVICE); + if (flag != RT_ERROR_NONE) { + GELOGE(RT_FAILED, "Call rt api(rtMemCpy) failed, ret: 0x%X.", flag); + return false; + } + + GELOGI("ext info size:", ext_size); + aicpu_param_head.extInfoLength = ext_size; + aicpu_param_head.extInfoAddr = reinterpret_cast(ext_info_); + } // Malloc device memory for args rtError_t rt_ret = rtMalloc(&args_, args_size, RT_MEMORY_HBM); @@ -80,6 +111,17 @@ bool AicpuTask::Distribute() { return false; } } + + // Memcpy node def + auto size = task_info_->node_def().size(); + rt_ret = + rtMemcpy(reinterpret_cast(reinterpret_cast(args_) + node_def_len_offset), sizeof(uint32_t), + reinterpret_cast(&size), sizeof(uint32_t), RT_MEMCPY_HOST_TO_DEVICE); + if (rt_ret != RT_ERROR_NONE) { + GELOGE(RT_FAILED, "Call rt api(rtMemcpy) failed, ret: 0x%X.", rt_ret); + return false; + } + // Memcpy node def rt_ret = rtMemcpy(reinterpret_cast(reinterpret_cast(args_) + node_def_addr_offset), task_info_->node_def().size(), reinterpret_cast(task_info_->node_def().data()), diff --git a/ge/ge_runtime/task/aicpu_task.h b/ge/ge_runtime/task/aicpu_task.h old mode 100755 new mode 100644 index cc21af8a..2d3c5040 --- a/ge/ge_runtime/task/aicpu_task.h +++ b/ge/ge_runtime/task/aicpu_task.h @@ -41,6 +41,7 @@ class AicpuTask : public TaskRepeater { std::shared_ptr task_info_; void *stream_; void *args_; + void *ext_info_; void *input_output_addr_; }; } // namespace model_runner diff --git a/ge/ge_runtime/task/cce_task.cc b/ge/ge_runtime/task/cce_task.cc old mode 100755 new mode 100644 index 1c1807b5..04fd5610 --- a/ge/ge_runtime/task/cce_task.cc +++ b/ge/ge_runtime/task/cce_task.cc @@ -103,9 +103,9 @@ bool CceTask::Distribute() { // Modify flowtable addr in args auto args = const_cast(task_info_->args().data()); auto task_offset = reinterpret_cast(const_cast(task_info_->args_offset().data())); + if (task_info_->args().size() < (task_offset[0] + sizeof(uint64_t))) { - GELOGE(FAILED, - "(context.args_offset().data()))[0]:%u + sizeof(uint64_t):%zu > kernelDef.args().size():%zu", + GELOGE(FAILED, "(context.args_offset().data()))[0]:%u + sizeof(uint64_t):%zu > kernelDef.args().size():%zu", static_cast(task_offset[0]), sizeof(uint64_t), task_info_->args().size()); return false; } @@ -136,8 +136,7 @@ bool CceTask::Distribute() { return false; } - rt_ret = rtMemcpy(sm_desc_, task_info_->sm_desc().size(), - task_info_->sm_desc().data(), + rt_ret = rtMemcpy(sm_desc_, task_info_->sm_desc().size(), task_info_->sm_desc().data(), task_info_->sm_desc().size(), RT_MEMCPY_HOST_TO_DEVICE); if (rt_ret != RT_ERROR_NONE) { GELOGE(RT_FAILED, "Call rt api failed, ret: 0x%X", rt_ret); @@ -146,12 +145,8 @@ bool CceTask::Distribute() { } // Kernel launch - rt_ret = rtKernelLaunch(stub_func_, - task_info_->block_dim(), - args_, - task_info_->args_size(), - static_cast(sm_desc_), - stream_); + rt_ret = rtKernelLaunch(stub_func_, task_info_->block_dim(), args_, task_info_->args_size(), + static_cast(sm_desc_), stream_); if (rt_ret != RT_ERROR_NONE) { GELOGE(RT_FAILED, "Call rt api failed, ret: 0x%X", rt_ret); return false; diff --git a/ge/ge_runtime/task/cce_task.h b/ge/ge_runtime/task/cce_task.h old mode 100755 new mode 100644 diff --git a/ge/ge_runtime/task/event_record_task.h b/ge/ge_runtime/task/event_record_task.h old mode 100755 new mode 100644 index b9ae5dba..7c1d4f80 --- a/ge/ge_runtime/task/event_record_task.h +++ b/ge/ge_runtime/task/event_record_task.h @@ -33,7 +33,7 @@ class EventRecordTask : public TaskRepeater { private: std::shared_ptr task_info_; rtStream_t stream_; - rtEvent_t event_; + rtEvent_t event_; }; } // namespace model_runner } // namespace ge diff --git a/ge/ge_runtime/task/event_wait_task.cc b/ge/ge_runtime/task/event_wait_task.cc index 5f1ffaad..558c2a59 100644 --- a/ge/ge_runtime/task/event_wait_task.cc +++ b/ge/ge_runtime/task/event_wait_task.cc @@ -1,5 +1,5 @@ /** - * Copyright 2020 Huawei Technologies Co., Ltd + * Copyright 2019-2020 Huawei Technologies Co., Ltd * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. diff --git a/ge/ge_runtime/task/event_wait_task.h b/ge/ge_runtime/task/event_wait_task.h old mode 100755 new mode 100644 index 685be897..9104bbf8 --- a/ge/ge_runtime/task/event_wait_task.h +++ b/ge/ge_runtime/task/event_wait_task.h @@ -33,7 +33,7 @@ class EventWaitTask : public TaskRepeater { private: std::shared_ptr task_info_; rtStream_t stream_; - rtEvent_t event_; + rtEvent_t event_; }; } // namespace model_runner } // namespace ge diff --git a/ge/ge_runtime/task/hccl_task.cc b/ge/ge_runtime/task/hccl_task.cc index 771341c1..3d5f8504 100644 --- a/ge/ge_runtime/task/hccl_task.cc +++ b/ge/ge_runtime/task/hccl_task.cc @@ -115,7 +115,6 @@ bool HcclTask::Distribute() { rt_ret = rtModelBindStream(rt_model_handle_, stream, RT_HEAD_STREAM); if (rt_ret != RT_ERROR_NONE) { GELOGE(RT_FAILED, "Call rt api failed, ret: 0x%X", rt_ret); - (void)rtStreamDestroy(stream); return false; } @@ -129,8 +128,6 @@ bool HcclTask::Distribute() { ge_task.type = static_cast(RT_MODEL_TASK_HCCL); ge_task.stream = stream_; - GETaskKernelHcclInfo kernel_hccl_info; - ge_task.kernelHcclInfo.emplace_back(kernel_hccl_info); ge_task.kernelHcclInfo[0].hccl_type = task_info_->hccl_type(); ge_task.kernelHcclInfo[0].inputDataAddr = task_info_->input_data_addr(); ge_task.kernelHcclInfo[0].outputDataAddr = task_info_->output_data_addr(); diff --git a/ge/ge_runtime/task/hccl_task.h b/ge/ge_runtime/task/hccl_task.h old mode 100755 new mode 100644 diff --git a/ge/ge_runtime/task/label_goto_task.cc b/ge/ge_runtime/task/label_goto_task.cc new file mode 100644 index 00000000..d357accb --- /dev/null +++ b/ge/ge_runtime/task/label_goto_task.cc @@ -0,0 +1,70 @@ +/** + * Copyright 2019-2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "ge_runtime/task/label_goto_task.h" +#include "ge_runtime/task/task_factory.h" + +namespace ge { +namespace model_runner { +LabelGotoTask::LabelGotoTask(const ModelContext &model_context, const std::shared_ptr &task_info) + : TaskRepeater(model_context, task_info), + task_info_(task_info), + stream_(nullptr), + label_(nullptr) { + if (task_info_ == nullptr) { + GELOGW("task_info_ is null!"); + return; + } + auto stream_list = model_context.stream_list(); + auto label_list = model_context.label_list(); + uint32_t stream_id = task_info->stream_id(); + uint32_t label_id = task_info->label_id(); + GELOGI("Stream list size:%zu, stream id:%u.", stream_list.size(), stream_id); + GELOGI("Label list size:%zu, label id:%u.", label_list.size(), label_id); + if (stream_id >= stream_list.size() || label_id >= label_list.size()) { + GELOGW("Stream/Label id invalid."); + return; + } + stream_ = stream_list[stream_id]; + label_ = label_list[label_id]; +} + +LabelGotoTask::~LabelGotoTask() {} + +bool LabelGotoTask::Distribute() { + GELOGI("LabelGotoTask Distribute start."); + if (stream_ == nullptr) { + GELOGE(PARAM_INVALID, "stream is null!"); + return false; + } + if (label_ == nullptr) { + GELOGE(PARAM_INVALID, "label is null!"); + return false; + } + rtError_t rt_ret = rtLabelGotoEx(label_, stream_); + if (rt_ret != RT_ERROR_NONE) { + GELOGE(RT_FAILED, "Call rt api failed, ret: 0x%X", rt_ret); + return false; + } + + GELOGI("DistributeTask end."); + return true; +} + +REGISTER_TASK(TaskInfoType::LABEL_GOTO, LabelGotoTask, LabelGotoTaskInfo); + +} // namespace model_runner +} // namespace ge diff --git a/ge/ge_runtime/task/label_goto_task.h b/ge/ge_runtime/task/label_goto_task.h new file mode 100644 index 00000000..4fd6d1bc --- /dev/null +++ b/ge/ge_runtime/task/label_goto_task.h @@ -0,0 +1,41 @@ +/** + * Copyright 2019-2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef GE_GE_RUNTIME_TASK_LABEL_GOTO_TASK_H_ +#define GE_GE_RUNTIME_TASK_LABEL_GOTO_TASK_H_ + +#include +#include "ge_runtime/task/task.h" + +namespace ge { +namespace model_runner { +class LabelGotoTask : public TaskRepeater { + public: + LabelGotoTask(const ModelContext &model_context, const std::shared_ptr &task_info); + + ~LabelGotoTask() override; + + bool Distribute() override; + + private: + std::shared_ptr task_info_; + void *stream_; + void *label_; +}; +} // namespace model_runner +} // namespace ge + +#endif // GE_GE_RUNTIME_TASK_LABEL_GOTO_TASK_H_ diff --git a/ge/ge_runtime/task/label_set_task.cc b/ge/ge_runtime/task/label_set_task.cc new file mode 100644 index 00000000..3ab5802c --- /dev/null +++ b/ge/ge_runtime/task/label_set_task.cc @@ -0,0 +1,70 @@ +/** + * Copyright 2019-2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "ge_runtime/task/label_set_task.h" +#include "ge_runtime/task/task_factory.h" + +namespace ge { +namespace model_runner { +LabelSetTask::LabelSetTask(const ModelContext &model_context, const std::shared_ptr &task_info) + : TaskRepeater(model_context, task_info), + task_info_(task_info), + stream_(nullptr), + label_(nullptr) { + if (task_info_ == nullptr) { + GELOGW("task_info_ is null!"); + return; + } + auto stream_list = model_context.stream_list(); + auto label_list = model_context.label_list(); + uint32_t stream_id = task_info->stream_id(); + uint32_t label_id = task_info->label_id(); + GELOGI("Stream list size:%zu, stream id:%u.", stream_list.size(), stream_id); + GELOGI("Label list size:%zu, label id:%u.", label_list.size(), label_id); + if (stream_id >= stream_list.size() || label_id >= label_list.size()) { + GELOGW("Stream/Label id invalid."); + return; + } + stream_ = stream_list[stream_id]; + label_ = label_list[label_id]; +} + +LabelSetTask::~LabelSetTask() {} + +bool LabelSetTask::Distribute() { + GELOGI("LabelSetTask Distribute start."); + if (stream_ == nullptr) { + GELOGE(PARAM_INVALID, "stream is null!"); + return false; + } + if (label_ == nullptr) { + GELOGE(PARAM_INVALID, "label is null!"); + return false; + } + rtError_t rt_ret = rtLabelSet(label_, stream_); + if (rt_ret != RT_ERROR_NONE) { + GELOGE(RT_FAILED, "Call rt api failed, ret: 0x%X", rt_ret); + return false; + } + + GELOGI("DistributeTask end."); + return true; +} + +REGISTER_TASK(TaskInfoType::LABEL_SET, LabelSetTask, LabelSetTaskInfo); + +} // namespace model_runner +} // namespace ge diff --git a/ge/ge_runtime/task/label_set_task.h b/ge/ge_runtime/task/label_set_task.h new file mode 100644 index 00000000..70bf1584 --- /dev/null +++ b/ge/ge_runtime/task/label_set_task.h @@ -0,0 +1,41 @@ +/** + * Copyright 2019-2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef GE_GE_RUNTIME_TASK_LABEL_SET_TASK_H_ +#define GE_GE_RUNTIME_TASK_LABEL_SET_TASK_H_ + +#include +#include "ge_runtime/task/task.h" + +namespace ge { +namespace model_runner { +class LabelSetTask : public TaskRepeater { + public: + LabelSetTask(const ModelContext &model_context, const std::shared_ptr &task_info); + + ~LabelSetTask() override; + + bool Distribute() override; + + private: + std::shared_ptr task_info_; + void *stream_; + void *label_; +}; +} // namespace model_runner +} // namespace ge + +#endif // GE_GE_RUNTIME_TASK_LABEL_SET_TASK_H_ diff --git a/ge/ge_runtime/task/label_switch_task.cc b/ge/ge_runtime/task/label_switch_task.cc new file mode 100644 index 00000000..a3c2d41a --- /dev/null +++ b/ge/ge_runtime/task/label_switch_task.cc @@ -0,0 +1,131 @@ +/** + * Copyright 2019-2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "ge_runtime/task/label_switch_task.h" +#include "ge_runtime/task/task_factory.h" + +namespace ge { +namespace model_runner { +LabelSwitchTask::LabelSwitchTask(const ModelContext &model_context, + const std::shared_ptr &task_info) + : TaskRepeater(model_context, task_info), + task_info_(task_info), + stream_(nullptr), + all_label_resource_(), + label_info_(nullptr) { + if (task_info_ == nullptr) { + GELOGW("task_info_ is null!"); + return; + } + + all_label_resource_ = model_context.label_list(); + auto stream_list = model_context.stream_list(); + uint32_t stream_id = task_info->stream_id(); + GELOGI("Stream list size:%zu, stream id:%u.", stream_list.size(), stream_id); + if (stream_id >= stream_list.size()) { + GELOGW("Stream id invalid."); + return; + } + stream_ = stream_list[stream_id]; +} + +LabelSwitchTask::~LabelSwitchTask() { + if (label_info_ != nullptr) { + rtError_t rt_ret = rtFree(label_info_); + if (rt_ret != RT_ERROR_NONE) { + GELOGE(RT_FAILED, "rtFree fwkOpBuf failed! ret: 0x%X.", rt_ret); + } + label_info_ = nullptr; + } +} + +bool LabelSwitchTask::Distribute() { + GELOGI("LabelSwitchTask Distribute start."); + if (!CheckParamValid()) { + return false; + } + + const std::vector &label_index_list = task_info_->label_list(); + std::vector label_list(task_info_->label_size(), nullptr); + + for (size_t i = 0; i < task_info_->label_size(); ++i) { + uint32_t label_index = label_index_list[i]; + if (label_index >= all_label_resource_.size()) { + GELOGE(PARAM_INVALID, "label %zu index is %u, but there are %zu labels in total.", i, label_index, + all_label_resource_.size()); + return false; + } + label_list[i] = all_label_resource_[label_index]; + GELOGI("Case %zu: label id %zu.", i, label_index); + } + + uint32_t label_info_size = sizeof(rtLabelDevInfo) * task_info_->label_size(); + rtError_t rt_ret = rtMalloc(&label_info_, label_info_size, RT_MEMORY_HBM); + if (rt_ret != RT_ERROR_NONE) { + GELOGE(RT_FAILED, "Call rt api failed, ret: 0x%X", rt_ret); + return false; + } + + rt_ret = rtLabelListCpy(label_list.data(), label_list.size(), label_info_, label_info_size); + if (rt_ret != RT_ERROR_NONE) { + GELOGE(RT_FAILED, "Call rt api failed, ret: 0x%X", rt_ret); + return false; + } + + rt_ret = rtLabelSwitchByIndex(task_info_->cond(), label_list.size(), label_info_, stream_); + if (rt_ret != RT_ERROR_NONE) { + GELOGE(RT_FAILED, "Call rt api failed, ret: 0x%X", rt_ret); + return false; + } + + GELOGI("DistributeTask end."); + return true; +} + +bool LabelSwitchTask::CheckParamValid() { + if (stream_ == nullptr) { + GELOGE(PARAM_INVALID, "stream is null!"); + return false; + } + + if (task_info_->label_list().empty()) { + GELOGE(PARAM_INVALID, "label_list is empty."); + return false; + } + + if (task_info_->label_size() != task_info_->label_list().size()) { + GELOGE(PARAM_INVALID, "label_list size %zu but label_size is %u.", task_info_->label_list().size(), + task_info_->label_size()); + return false; + } + + if (task_info_->label_size() >= UINT32_MAX / sizeof(rtLabelDevInfo)) { + GELOGE(PARAM_INVALID, "label_size %u will overflow.", task_info_->label_size()); + return false; + } + + if (label_info_ != nullptr) { + GELOGE(PARAM_INVALID, "label_info_ has dirty data."); + return false; + } + + return true; +} + +REGISTER_TASK(TaskInfoType::LABEL_SWITCH, LabelSwitchTask, LabelSwitchTaskInfo); + +} // namespace model_runner +} // namespace ge diff --git a/ge/ge_runtime/task/label_switch_task.h b/ge/ge_runtime/task/label_switch_task.h new file mode 100644 index 00000000..463faa31 --- /dev/null +++ b/ge/ge_runtime/task/label_switch_task.h @@ -0,0 +1,44 @@ +/** + * Copyright 2019-2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef GE_GE_RUNTIME_TASK_LABEL_SWITCH_TASK_H_ +#define GE_GE_RUNTIME_TASK_LABEL_SWITCH_TASK_H_ + +#include +#include "ge_runtime/task/task.h" + +namespace ge { +namespace model_runner { +class LabelSwitchTask : public TaskRepeater { + public: + LabelSwitchTask(const ModelContext &model_context, const std::shared_ptr &task_info); + + ~LabelSwitchTask() override; + + bool Distribute() override; + + private: + bool CheckParamValid(); + + std::shared_ptr task_info_; + void *stream_; + std::vector all_label_resource_; + void *label_info_; +}; +} // namespace model_runner +} // namespace ge + +#endif // GE_GE_RUNTIME_TASK_LABEL_SWITCH_TASK_H_ diff --git a/ge/ge_runtime/task/memcpy_async_task.h b/ge/ge_runtime/task/memcpy_async_task.h old mode 100755 new mode 100644 diff --git a/ge/ge_runtime/task/profiler_task.h b/ge/ge_runtime/task/profiler_task.h old mode 100755 new mode 100644 diff --git a/ge/ge_runtime/task/stream_active_task.h b/ge/ge_runtime/task/stream_active_task.h old mode 100755 new mode 100644 diff --git a/ge/ge_runtime/task/stream_switch_task.h b/ge/ge_runtime/task/stream_switch_task.h old mode 100755 new mode 100644 index 2caad200..81c12507 --- a/ge/ge_runtime/task/stream_switch_task.h +++ b/ge/ge_runtime/task/stream_switch_task.h @@ -37,6 +37,7 @@ class StreamSwitchTask : public TaskRepeater { void *stream_; std::vector stream_list_; }; + } // namespace model_runner } // namespace ge #endif // GE_GE_RUNTIME_TASK_STREAM_SWITCH_TASK_H_ diff --git a/ge/ge_runtime/task/task.h b/ge/ge_runtime/task/task.h old mode 100755 new mode 100644 index b8a937b7..6c4df248 --- a/ge/ge_runtime/task/task.h +++ b/ge/ge_runtime/task/task.h @@ -42,7 +42,7 @@ class Task { template class TaskRepeater : public Task { - static_assert(std::is_base_of(), "Wrong TaskInfo Type!"); /*lint !e30*/ + static_assert(std::is_base_of(), "Wrong TaskInfo Type!"); public: TaskRepeater(const ModelContext &model_context, std::shared_ptr task_info) {} diff --git a/ge/ge_runtime/task/task_factory.h b/ge/ge_runtime/task/task_factory.h index 29da1388..670d1fef 100644 --- a/ge/ge_runtime/task/task_factory.h +++ b/ge/ge_runtime/task/task_factory.h @@ -81,6 +81,7 @@ class TaskFactory { std::shared_ptr concrete_task_info = std::static_pointer_cast(task_info); \ return std::make_shared(model_context, concrete_task_info); \ }); + } // namespace model_runner } // namespace ge #endif // GE_GE_RUNTIME_TASK_TASK_FACTORY_H_ diff --git a/ge/ge_runtime/task/tbe_task.cc b/ge/ge_runtime/task/tbe_task.cc old mode 100755 new mode 100644 diff --git a/ge/ge_runtime/task/tbe_task.h b/ge/ge_runtime/task/tbe_task.h old mode 100755 new mode 100644 diff --git a/inc/framework/ge_runtime/davinci_model.h b/inc/framework/ge_runtime/davinci_model.h index 8b6ca978..91e70159 100644 --- a/inc/framework/ge_runtime/davinci_model.h +++ b/inc/framework/ge_runtime/davinci_model.h @@ -27,10 +27,10 @@ namespace ge { namespace model_runner { class DavinciModel { public: - DavinciModel(const std::vector> &task_info_list, /*lint !e151*/ + DavinciModel(const std::vector> &task_info_list, const std::vector> &data_info_list, - const std::vector> &output_info_list, /*lint !e151*/ - const std::vector> &constant_info_list, /*lint !e1049*/ + const std::vector> &output_info_list, + const std::vector> &constant_info_list, const std::vector &variable_info_list, const std::vector &wait_active_stream_list, const std::vector &force_copy_stream_list, uint64_t mem_size = 0, uint64_t weight_size = 0, @@ -68,12 +68,12 @@ class DavinciModel { uint32_t GetBatchNum() const { return batch_num_; } uint32_t GetEventNum() const { return event_num_; } - const std::vector &GetWaitActiveStreams() const { return wait_active_stream_list_; } /*lint !e1413*/ - const std::vector &GetForceCopyStreams() const { return force_copy_stream_list_; } /*lint !e1413*/ + const std::vector &GetWaitActiveStreams() const { return wait_active_stream_list_; } + const std::vector &GetForceCopyStreams() const { return force_copy_stream_list_; } int32_t GetPriority() const { return priority_; } - const std::vector> &GetTaskInfoList() const { return task_info_list_; } /*lint !e151*/ + const std::vector> &GetTaskInfoList() const { return task_info_list_; } const std::vector> &GetDataInfoList() const { return data_info_list_; } const std::vector> &GetOutputInfoList() const { return output_info_list_; } const std::vector> &GetConstantInfoList() const { return output_info_list_; } @@ -81,7 +81,7 @@ class DavinciModel { private: std::vector> task_info_list_; - std::vector> data_info_list_; /*lint !e151*/ + std::vector> data_info_list_; std::vector> output_info_list_; std::vector> constant_info_list_; std::vector variable_info_list_; diff --git a/inc/framework/ge_runtime/model_runner.h b/inc/framework/ge_runtime/model_runner.h index a5256af7..e495dfdf 100644 --- a/inc/framework/ge_runtime/model_runner.h +++ b/inc/framework/ge_runtime/model_runner.h @@ -52,11 +52,8 @@ class ModelRunner { bool RunModel(uint32_t model_id, const InputData &input_data, OutputData *output_data); - bool GetInputOutputDescInfo(uint32_t model_id, - bool zero_copy, - std::vector *input_desc, - std::vector *output_desc, - std::vector *input_format, + bool GetInputOutputDescInfo(uint32_t model_id, bool zero_copy, std::vector *input_desc, + std::vector *output_desc, std::vector *input_format, std::vector *output_format); private: diff --git a/inc/framework/ge_runtime/task_info.h b/inc/framework/ge_runtime/task_info.h index 86119219..e36c4333 100644 --- a/inc/framework/ge_runtime/task_info.h +++ b/inc/framework/ge_runtime/task_info.h @@ -161,12 +161,13 @@ class TbeTaskInfo : public TaskInfo { class AicpuTaskInfo : public TaskInfo { public: AicpuTaskInfo(const std::string &op_name, uint32_t stream_id, const string &so_name, const std::string &kernel_name, - const std::string &node_def, const std::vector &input_data_addrs, + const std::string &node_def, const std::string &ext_info, const std::vector &input_data_addrs, const std::vector &output_data_addrs, bool dump_flag) : TaskInfo(op_name, stream_id, TaskInfoType::AICPU, dump_flag), so_name_(so_name), kernel_name_(kernel_name), node_def_(node_def), + ext_info_(ext_info), input_data_addrs_(input_data_addrs), output_data_addrs_(output_data_addrs) {} ~AicpuTaskInfo() override {} @@ -176,11 +177,13 @@ class AicpuTaskInfo : public TaskInfo { const std::string &node_def() const { return node_def_; } const std::vector &input_data_addrs() const { return input_data_addrs_; } const std::vector &output_data_addrs() const { return output_data_addrs_; } + const std::string &ext_info() const { return ext_info_; } private: std::string so_name_; std::string kernel_name_; std::string node_def_; + std::string ext_info_; std::vector input_data_addrs_; std::vector output_data_addrs_; }; @@ -293,19 +296,19 @@ class HcclTaskInfo : public TaskInfo { hcom_distribute_task_(hcom_distribute_task) {} ~HcclTaskInfo() override {} - const std::string &hccl_type() const { return hccl_type_; } /*lint !e1413*/ + const std::string &hccl_type() const { return hccl_type_; } void *input_data_addr() const { return input_data_addr_; } void *output_data_addr() const { return output_data_addr_; } void *workspace_addr() const { return workspace_addr_; } int64_t workspace_size() const { return workspace_size_; } int64_t hccl_stream_num() const { return hccl_stream_num_; } - const std::vector &private_def() const { return private_def_; } /*lint !e1413*/ + const std::vector &private_def() const { return private_def_; } void *ops_kernel_store() const { return ops_kernel_store_; } int32_t count() const { return count_; } int64_t root_id() const { return root_id_; } int64_t op_type() const { return op_type_; } int64_t data_type() const { return data_type_; } - const std::string group() const { return group_; } + const std::string &group() const { return group_; } std::function hcom_bind_model() const { return hcom_bind_model_; } std::function hcom_unbind_model() const { return hcom_unbind_model_; } std::function, void *)> hcom_distribute_task() const {