diff --git a/inc/common/blocking_queue.h b/inc/common/blocking_queue.h
index 7a5e98cf..12b02773 100644
--- a/inc/common/blocking_queue.h
+++ b/inc/common/blocking_queue.h
@@ -42,7 +42,7 @@ class BlockingQueue {
       return false;
     }
 
-    item = queue_.front();
+    item = std::move(queue_.front());
     queue_.pop_front();
 
     full_cond_.notify_one();
@@ -71,6 +71,27 @@ class BlockingQueue {
     return true;
   }
 
+  bool Push(T &&item, bool is_wait = true) {
+    std::unique_lock<std::mutex> lock(mutex_);
+
+    while (queue_.size() >= max_size_ && !is_stoped_) {
+      if (!is_wait) {
+        return false;
+      }
+      full_cond_.wait(lock);
+    }
+
+    if (is_stoped_) {
+      return false;
+    }
+
+    queue_.emplace_back(std::move(item));
+
+    empty_cond_.notify_one();
+
+    return true;
+  }
+
   void Stop() {
     {
       std::unique_lock<std::mutex> lock(mutex_);
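To see why the rvalue-reference Push overload and the std::move in Pop matter, here is a minimal usage sketch with a move-only element type. Only fragments of the class are visible above, so the default constructor, the ge namespace, the Pop(T &item, bool is_wait) signature, and Pop returning false after Stop() are assumptions.

```cpp
// Hypothetical usage sketch of BlockingQueue with a move-only payload.
// Without Push(T &&) and the std::move in Pop, std::unique_ptr elements
// would not compile, since they cannot be copied.
#include <memory>
#include <thread>
#include "common/blocking_queue.h"  // assumed include path

void Demo() {
  ge::BlockingQueue<std::unique_ptr<int>> queue;  // namespace and default capacity assumed

  std::thread producer([&queue] {
    for (int i = 0; i < 10; ++i) {
      queue.Push(std::make_unique<int>(i));  // binds to the new Push(T &&item)
    }
    queue.Stop();  // assumed to wake and release any blocked consumer
  });

  std::thread consumer([&queue] {
    std::unique_ptr<int> item;
    while (queue.Pop(item)) {  // Pop moves the front element out (see diff above)
      // use *item
    }
  });

  producer.join();
  consumer.join();
}
```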
diff --git a/inc/common/opskernel/ge_task_info.h b/inc/common/opskernel/ge_task_info.h
index 1b8c7584..360f8a5d 100644
--- a/inc/common/opskernel/ge_task_info.h
+++ b/inc/common/opskernel/ge_task_info.h
@@ -26,6 +26,7 @@ using std::string;
 namespace ge {
 // when need to eliminate GETaskKernelHcclInfo, so not need DAVINCI_TRAIN/DAVINCI_CLOUD
 struct GETaskKernelHcclInfo {
+  string input_name;
   string hccl_type;
   void *inputDataAddr;
   void *outputDataAddr;
@@ -35,6 +36,7 @@ struct GETaskKernelHcclInfo {
   int32_t opType;
   int64_t rootId;
   uint64_t workSpaceMemSize;
+  std::vector<int64_t> dims;
   std::vector<rtStream_t> hcclStreamList;
 };
 
@@ -48,7 +50,7 @@ struct GETaskInfo {
   uint32_t privateDefLen;
   void *opsKernelStorePtr;
 
-  GETaskKernelHcclInfo kernelHcclInfo;
+  std::vector<GETaskKernelHcclInfo> kernelHcclInfo;
 };
 }  // namespace ge
 #endif  // INC_COMMON_OPSKERNEL_GE_TASK_INFO_H_
diff --git a/inc/common/opskernel/ops_kernel_info_store.h b/inc/common/opskernel/ops_kernel_info_store.h
index 52ceda91..46338e45 100644
--- a/inc/common/opskernel/ops_kernel_info_store.h
+++ b/inc/common/opskernel/ops_kernel_info_store.h
@@ -73,7 +73,7 @@ class OpsKernelInfoStore {
 
   // only call fe engine interface to compile single op
   virtual Status CompileOp(vector<ge::NodePtr> &node_vec) { return SUCCESS; }
-
+  virtual Status CompileOpRun(vector<ge::NodePtr> &node_vec) { return SUCCESS; }
   // load task for op
   virtual Status LoadTask(GETaskInfo &task) { return SUCCESS; }
diff --git a/inc/external/ge/ge_api_types.h b/inc/external/ge/ge_api_types.h
index 6fa269ce..09561212 100644
--- a/inc/external/ge/ge_api_types.h
+++ b/inc/external/ge/ge_api_types.h
@@ -33,6 +33,7 @@ const char *const OPTION_EXEC_SESSION_ID = "ge.exec.sessionId";
 const char *const OPTION_EXEC_DEVICE_ID = "ge.exec.deviceId";
 const char *const OPTION_EXEC_JOB_ID = "ge.exec.jobId";
 const char *const OPTION_EXEC_IS_USEHCOM = "ge.exec.isUseHcom";
+const char *const OPTION_EXEC_IS_USEHVD = "ge.exec.isUseHvd";
 const char *const OPTION_EXEC_RANK_ID = "ge.exec.rankId";
 const char *const OPTION_EXEC_POD_NAME = "ge.exec.podName";
 const char *const OPTION_EXEC_DEPLOY_MODE = "ge.exec.deployMode";
@@ -52,6 +53,7 @@ const char *const OPTION_EXEC_PROFILING_OPTIONS = "ge.exec.profilingOptions";
 const char *const OPTION_EXEC_HCCL_FLAG = "ge.exec.hcclFlag";
 const char *const OPTION_EXEC_ATOMIC_FLAG = "ge.exec.enable_atomic";
 const char *const OPTION_EXEC_DISABLE_REUSED_MEMORY = "ge.exec.disableReuseMemory";
+const char *const OPTION_EXEC_ENABLE_TAILING_OPTIMIZATION = "ge.exec.isTailingOptimization";
 
 // Option key: memory init
 const char *const GRAPH_MEMORY_MAX_SIZE = "ge.graphMemoryMaxSize";
@@ -153,7 +155,7 @@ const std::string STREAM_MAX_PARALLEL_NUM = "ge.streamMaxParallelNum";
 const std::string OUTPUT_DATATYPE = "ge.outputDatatype";
 
 // configure opSelectImplmode to set op select implmode
-const std::string kOpSelectImplmode = "ge.opSelectImplmode";
+const std::string OP_SELECT_IMPL_MODE = "ge.opSelectImplmode";
 
 // configure whether to enable hcom parallel by session constructor options param,
 // its value should be "0" or "1", default value is "0"
@@ -214,6 +216,9 @@ const char *const ENABLE_PRINT_OP_PASS = "ge.enablePrintOpPass";
 // Its value should be "true" or "false", default value is "false"
 const char *const ENABLE_SINGLE_STREAM = "ge.enableSingleStream";
 
+// Configure input fp16 nodes
+const std::string INPUT_FP16_NODES = "ge.INPUT_NODES_SET_FP16";
+
 // Graph run mode
 enum GraphRunMode { PREDICTION = 0, TRAIN };
 
@@ -263,14 +268,37 @@ static const char *const AUTO_TUNE_MODE = ge::AUTO_TUNE_MODE.c_str();
 static const char *const CORE_TYPE = ge::CORE_TYPE.c_str();
 static const char *const SOC_VERSION = ge::SOC_VERSION.c_str();
 static const char *const ENABLE_SINGLE_STREAM = ge::ENABLE_SINGLE_STREAM;
+static const char *const AICORE_NUM = ge::AICORE_NUM.c_str();
+static const char *const FUSION_SWITCH_FILE = ge::FUSION_SWITCH_FILE.c_str();
+static const char *const ENABLE_SMALL_CHANNEL = ge::ENABLE_SMALL_CHANNEL.c_str();
+static const char *const QUANT_OPTIMIZE = ge::QUANT_OPTIMIZE.c_str();
+static const char *const OP_SELECT_IMPL_MODE = ge::OP_SELECT_IMPL_MODE.c_str();
+static const char *const OUTPUT_TYPE = ge::OUTPUT_DATATYPE.c_str();
+static const char *const BUFFER_OPTIMIZE = ge::BUFFER_OPTIMIZE.c_str();
+static const char *const ENABLE_COMPRESS_WEIGHT = ge::ENABLE_COMPRESS_WEIGHT.c_str();
+static const char *const COMPRESS_WEIGHT_CONF = "compress_weight_conf";
+static const char *const OUT_NODES = ge::OUTPUT_NODE_NAME.c_str();
+static const char *const INPUT_FP16_NODES = ge::INPUT_FP16_NODES.c_str();
+static const char *const LOG_LEVEL = "log";
 
 // for interface: aclgrphBuildModel
-const std::set<std::string> ir_builder_suppported_options = {INPUT_FORMAT, INPUT_SHAPE, DYNAMIC_BATCH_SIZE,
-                                                             DYNAMIC_IMAGE_SIZE, INSERT_OP_FILE};
+const std::set<std::string> ir_builder_suppported_options = {
+    INPUT_FORMAT,         INPUT_SHAPE, DYNAMIC_BATCH_SIZE, DYNAMIC_IMAGE_SIZE,
+    INSERT_OP_FILE,       OUTPUT_TYPE, BUFFER_OPTIMIZE,    ENABLE_COMPRESS_WEIGHT,
+    COMPRESS_WEIGHT_CONF, OUT_NODES,   INPUT_FP16_NODES,   LOG_LEVEL};
 // for interface: aclgrphBuildInitialize
-const std::set<std::string> global_options = {
-    HEAD_STREAM,    CORE_TYPE,          SOC_VERSION, PRECISION_MODE, EXEC_DISABLE_REUSED_MEMORY,
-    AUTO_TUNE_MODE, ENABLE_SINGLE_STREAM};
+const std::set<std::string> global_options = {HEAD_STREAM,
+                                              CORE_TYPE,
+                                              SOC_VERSION,
+                                              PRECISION_MODE,
+                                              EXEC_DISABLE_REUSED_MEMORY,
+                                              AUTO_TUNE_MODE,
+                                              ENABLE_SINGLE_STREAM,
+                                              AICORE_NUM,
+                                              FUSION_SWITCH_FILE,
+                                              ENABLE_SMALL_CHANNEL,
+                                              QUANT_OPTIMIZE,
+                                              OP_SELECT_IMPL_MODE};
 }  // namespace ir_option
 }  // namespace ge
diff --git a/inc/external/graph/operator.h b/inc/external/graph/operator.h
index be7f10db..1deae7d9 100644
--- a/inc/external/graph/operator.h
+++ b/inc/external/graph/operator.h
@@ -48,12 +48,9 @@ class NamedAttrs;
 class Graph;
 class AttrValue;
 
-using SubgraphBuilder = std::function<Graph(const std::string &name)>;
+using SubgraphBuilder = std::function<Graph()>;
 using OperatorImplPtr = std::shared_ptr<OperatorImpl>;
 
-class Graph;
-using GraphBuilderCallback = std::function<ComputeGraphPtr()>;
-
 class OpIO;
 using OutHandler = std::shared_ptr<OpIO>;
 using InHandler = std::shared_ptr<OpIO>;
@@ -139,12 +136,6 @@ class GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY Operator {
   void SetInferenceContext(const InferenceContextPtr &inference_context);
   InferenceContextPtr GetInferenceContext() const;
 
-  void SetGraphBuilder(const GraphBuilderCallback &builder);
-  graphStatus 
GetGraphBuilder(GraphBuilderCallback &builder) const; - - void AddSubgraphName(const string &name); - string GetSubgraphName(int index) const; - graphStatus VerifyAllAttr(bool disable_common_verifier = false); size_t GetInputsSize() const; @@ -265,9 +256,9 @@ class GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY Operator { Operator &SetInput(const string &dst_name, uint32_t dst_index, const Operator &src_oprt, const string &name); - void SubgraphRegister(const std::string &name, bool dynamic); - void SubgraphCountRegister(const std::string &name, uint32_t count); - void SetSubgraphBuilder(const std::string &name, uint32_t index, const SubgraphBuilder &builder); + void SubgraphRegister(const string &ir_name, bool dynamic); + void SubgraphCountRegister(const string &ir_name, uint32_t count); + void SetSubgraphBuilder(const string &ir_name, uint32_t index, const SubgraphBuilder &builder); private: Operator &SetInput(const string &dst_name, const OutHandler &out_handler); diff --git a/inc/external/graph/operator_reg.h b/inc/external/graph/operator_reg.h index 57b1f8fe..dfa21acf 100644 --- a/inc/external/graph/operator_reg.h +++ b/inc/external/graph/operator_reg.h @@ -186,56 +186,54 @@ class OpReg { Operator::OutputRegister(#x); \ (void)OpReg() -#define DYNAMIC_INPUT(x, t) \ - N(); \ - __dy_input_##x(); \ - } \ - \ - public: \ - _THIS_TYPE &create_dynamic_input_##x(unsigned int num, bool isPushBack = true) { \ - Operator::DynamicInputRegister(#x, num, isPushBack); \ - return *this; \ - } \ - _THIS_TYPE &create_dynamic_input_byindex_##x(unsigned int num, size_t index) { \ - Operator::DynamicInputRegisterByIndex(#x, num, index); \ - return *this; \ - } \ - TensorDesc get_dynamic_input_desc_##x(unsigned int index) const { return Operator::GetDynamicInputDesc(#x, index); } \ - graphStatus update_dynamic_input_desc_##x(unsigned int index, const TensorDesc &tensorDesc) { \ - return Operator::UpdateDynamicInputDesc(#x, index, tensorDesc); \ - } \ - _THIS_TYPE &set_dynamic_input_##x(unsigned int dstIndex, Operator &v) { \ - Operator::SetInput(#x, dstIndex, v); \ - return *this; \ - } \ - _THIS_TYPE &set_dynamic_input_##x(unsigned int dstIndex, Operator &v, const string &srcName) { \ - Operator::SetInput(#x, dstIndex, v, srcName); \ - return *this; \ - } \ - \ - private: \ - void __dy_input_##x() { \ +#define DYNAMIC_INPUT(x, t) \ + N(); \ + __dy_input_##x(); \ + } \ + \ + public: \ + _THIS_TYPE &create_dynamic_input_##x(uint32_t num, bool isPushBack = true) { \ + Operator::DynamicInputRegister(#x, num, isPushBack); \ + return *this; \ + } \ + _THIS_TYPE &create_dynamic_input_byindex_##x(uint32_t num, size_t index) { \ + Operator::DynamicInputRegisterByIndex(#x, num, index); \ + return *this; \ + } \ + TensorDesc get_dynamic_input_desc_##x(uint32_t index) const { return Operator::GetDynamicInputDesc(#x, index); } \ + graphStatus update_dynamic_input_desc_##x(uint32_t index, const TensorDesc &tensorDesc) { \ + return Operator::UpdateDynamicInputDesc(#x, index, tensorDesc); \ + } \ + _THIS_TYPE &set_dynamic_input_##x(uint32_t dstIndex, Operator &v) { \ + Operator::SetInput(#x, dstIndex, v); \ + return *this; \ + } \ + _THIS_TYPE &set_dynamic_input_##x(uint32_t dstIndex, Operator &v, const string &srcName) { \ + Operator::SetInput(#x, dstIndex, v, srcName); \ + return *this; \ + } \ + \ + private: \ + void __dy_input_##x() { \ (void)OpReg() -#define DYNAMIC_OUTPUT(x, t) \ - N(); \ - __dy_output_##x(); \ - } \ - \ - public: \ - _THIS_TYPE &create_dynamic_output_##x(unsigned int num, bool isPushBack = true) { 
\ - Operator::DynamicOutputRegister(#x, num, isPushBack); \ - return *this; \ - } \ - TensorDesc get_dynamic_output_desc_##x(unsigned int index) const { \ - return Operator::GetDynamicOutputDesc(#x, index); \ - } \ - graphStatus update_dynamic_output_desc_##x(unsigned int index, const TensorDesc &tensorDesc) { \ - return Operator::UpdateDynamicOutputDesc(#x, index, tensorDesc); \ - } \ - \ - private: \ - void __dy_output_##x() { \ +#define DYNAMIC_OUTPUT(x, t) \ + N(); \ + __dy_output_##x(); \ + } \ + \ + public: \ + _THIS_TYPE &create_dynamic_output_##x(uint32_t num, bool isPushBack = true) { \ + Operator::DynamicOutputRegister(#x, num, isPushBack); \ + return *this; \ + } \ + TensorDesc get_dynamic_output_desc_##x(uint32_t index) const { return Operator::GetDynamicOutputDesc(#x, index); } \ + graphStatus update_dynamic_output_desc_##x(uint32_t index, const TensorDesc &tensorDesc) { \ + return Operator::UpdateDynamicOutputDesc(#x, index, tensorDesc); \ + } \ + \ + private: \ + void __dy_output_##x() { \ (void)OpReg() #define GRAPH(x) \ @@ -258,29 +256,29 @@ class OpReg { Operator::SubgraphCountRegister(#x, 1); \ (void)OpReg() -#define DYNAMIC_GRAPH(x) \ - N(); \ - __graph_##x(); \ - } \ - \ - public: \ - static const string name_graph_##x() { return #x; } \ - _THIS_TYPE &create_dynamic_subgraph_##x(unsigned int num) { \ - Operator::SubgraphCountRegister(#x, num); \ - return *this; \ - } \ - SubgraphBuilder get_dynamic_subgraph_builder_##x(unsigned int index) const { \ - return Operator::GetDynamicSubgraphBuilder(#x, index); \ - } \ - Graph get_dynamic_subgraph_##x(unsigned int index) const { return Operator::GetDynamicSubgraph(#x, index); } \ - _THIS_TYPE &set_dynamic_subgraph_builder_##x(unsigned int index, const SubgraphBuilder &v) { \ - Operator::SetSubgraphBuilder(#x, index, v); \ - return *this; \ - } \ - \ - private: \ - void __graph_##x() { \ - Operator::SubgraphRegister(#x, true); \ +#define DYNAMIC_GRAPH(x) \ + N(); \ + __graph_##x(); \ + } \ + \ + public: \ + static const string name_graph_##x() { return #x; } \ + _THIS_TYPE &create_dynamic_subgraph_##x(uint32_t num) { \ + Operator::SubgraphCountRegister(#x, num); \ + return *this; \ + } \ + SubgraphBuilder get_dynamic_subgraph_builder_##x(uint32_t index) const { \ + return Operator::GetDynamicSubgraphBuilder(#x, index); \ + } \ + Graph get_dynamic_subgraph_##x(uint32_t index) const { return Operator::GetDynamicSubgraph(#x, index); } \ + _THIS_TYPE &set_dynamic_subgraph_builder_##x(uint32_t index, const SubgraphBuilder &v) { \ + Operator::SetSubgraphBuilder(#x, index, v); \ + return *this; \ + } \ + \ + private: \ + void __graph_##x() { \ + Operator::SubgraphRegister(#x, true); \ (void)OpReg() #define PASTE(g_register, y) g_register##y diff --git a/inc/external/graph/types.h b/inc/external/graph/types.h index 46cb34b9..6a8362ba 100644 --- a/inc/external/graph/types.h +++ b/inc/external/graph/types.h @@ -24,7 +24,7 @@ namespace ge { static const int64_t UNKNOWN_DIM = -1; static const int64_t UNKNOWN_DIM_NUM = -2; -static const std::vector UNKNOWN_SHAPE = {0}; +static const std::vector UNKNOWN_SHAPE = {-1}; static const std::vector UNKNOWN_RANK = {-2}; #ifdef HOST_VISIBILITY diff --git a/inc/framework/common/ge_types.h b/inc/framework/common/ge_types.h index 6c70aa4c..ae83e40d 100644 --- a/inc/framework/common/ge_types.h +++ b/inc/framework/common/ge_types.h @@ -40,6 +40,14 @@ enum FrameworkType { FMK_TYPE_RESERVED, }; +enum OpEngineType { + ENGINE_SYS = 0, // default engine + ENGINE_AICORE = 1, + ENGINE_VECTOR = 2, + 
ENGINE_AICUBE = 3, // not support + ENGINE_AIVECTOR = 4 // not support +}; + const char *const GE_ENGINE_ATTR_MEM_TYPE_HBM = "HBM"; // Data cache, including data address and length @@ -141,6 +149,7 @@ struct Options { int32_t device_id; std::string job_id; bool isUseHcom; + bool isUseHvd; bool deployMode; bool isAICPUMode; bool enable_atomic; diff --git a/inc/framework/common/types.h b/inc/framework/common/types.h index 7bb8d5e7..fe5cca62 100644 --- a/inc/framework/common/types.h +++ b/inc/framework/common/types.h @@ -442,6 +442,7 @@ REGISTER_OPTYPE_DECLARE(STREAMMERGE, "StreamMerge"); REGISTER_OPTYPE_DECLARE(ENDGRAPH, "EndGraph"); REGISTER_OPTYPE_DECLARE(SEND, "Send"); REGISTER_OPTYPE_DECLARE(RECV, "Recv"); +REGISTER_OPTYPE_DECLARE(ENDOFSEQUENCE, "EndOfSequence"); REGISTER_OPTYPE_DECLARE(LABELSET, "LabelSet"); REGISTER_OPTYPE_DECLARE(LABELGOTO, "LabelGoto"); @@ -508,6 +509,12 @@ REGISTER_OPTYPE_DECLARE(DEPTHWISEWEIGHT6D24D, "depthwise_weight_6d_2_4d"); REGISTER_OPTYPE_DECLARE(SQRTGRAD, "SqrtGrad"); REGISTER_OPTYPE_DECLARE(SIGMOIDGRAD, "SigmoidGrad"); +// Horovod operator +REGISTER_OPTYPE_DECLARE(HVDCALLBACKALLREDUCE, "HorovodAllreduce"); +REGISTER_OPTYPE_DECLARE(HVDCALLBACKALLGATHER, "HorovodAllgather"); +REGISTER_OPTYPE_DECLARE(HVDCALLBACKBROADCAST, "HorovodBroadcast"); +REGISTER_OPTYPE_DECLARE(HVDWAIT, "HorovodWait"); + enum InputMode { INPUT = 0, CONST }; // Definition of the processing status enum of the process module diff --git a/inc/framework/dlog/log.h b/inc/framework/dlog/log.h deleted file mode 100644 index 8126720c..00000000 --- a/inc/framework/dlog/log.h +++ /dev/null @@ -1,61 +0,0 @@ -/** - * Copyright 2019-2020 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#ifndef INC_FRAMEWORK_DLOG_LOG_H_ -#define INC_FRAMEWORK_DLOG_LOG_H_ - -#include -#if !defined(__ANDROID__) && !defined(ANDROID) -#include "toolchain/slog.h" -#else -#include -#endif - -#ifdef _MSC_VER -#define FUNC_NAME __FUNCTION__ -#else -#define FUNC_NAME __PRETTY_FUNCTION__ -#endif - -#if !defined(__ANDROID__) && !defined(ANDROID) -#define DAV_LOGI(MOD_NAME, fmt, ...) dlog_info(static_cast(GE), "%s:" #fmt, __FUNCTION__, ##__VA_ARGS__) -#define DAV_LOGW(MOD_NAME, fmt, ...) dlog_warn(static_cast(GE), "%s:" #fmt, __FUNCTION__, ##__VA_ARGS__) -#define DAV_LOGE(MOD_NAME, fmt, ...) dlog_error(static_cast(GE), "%s:" #fmt, __FUNCTION__, ##__VA_ARGS__) -#define DAV_LOGD(MOD_NAME, fmt, ...) dlog_debug(static_cast(GE), "%s:" #fmt, __FUNCTION__, ##__VA_ARGS__) -#define DAV_EVENT(MOD_NAME, fmt, ...) dlog_event(static_cast(GE), "%s:" #fmt, __FUNCTION__, ##__VA_ARGS__) -#else -#define DAV_LOGI(MOD_NAME, fmt, ...) \ - __android_log_print(ANDROID_LOG_INFO, MOD_NAME, "%s %s(%d)::" #fmt, __FILE__, __FUNCTION__, __LINE__, ##__VA_ARGS__) -#define DAV_LOGW(MOD_NAME, fmt, ...) \ - __android_log_print(ANDROID_LOG_WARN, MOD_NAME, "%s %s(%d)::" #fmt, __FILE__, __FUNCTION__, __LINE__, ##__VA_ARGS__) -#define DAV_LOGE(MOD_NAME, fmt, ...) 
\ - __android_log_print(ANDROID_LOG_ERROR, MOD_NAME, "%s %s(%d)::" #fmt, __FILE__, __FUNCTION__, __LINE__, ##__VA_ARGS__) -#define DAV_LOGD(MOD_NAME, fmt, ...) \ - __android_log_print(ANDROID_LOG_DEBUG, MOD_NAME, "%s %s(%d)::" #fmt, __FILE__, __FUNCTION__, __LINE__, ##__VA_ARGS__) -#define DAV_EVENT(MOD_NAME, fmt, ...) \ - __android_log_print(ANDROID_LOG_DEBUG, MOD_NAME, "%s %s(%d)::" #fmt, __FILE__, __FUNCTION__, __LINE__, ##__VA_ARGS__) -#endif - -#define DLOG_DECLARE(level) \ - void Log_##level(const char *mod_name, const char *func, const char *file, int line, const char *format, ...) - -namespace domi { -DLOG_DECLARE(INFO); -DLOG_DECLARE(WARNING); -DLOG_DECLARE(ERROR); -} // namespace domi - -#endif // INC_FRAMEWORK_DLOG_LOG_H_ diff --git a/inc/framework/ge_runtime/task_info.h b/inc/framework/ge_runtime/task_info.h old mode 100755 new mode 100644 diff --git a/inc/framework/ge_runtime_dummy/davinci_model.h b/inc/framework/ge_runtime_dummy/davinci_model.h new file mode 100644 index 00000000..91e70159 --- /dev/null +++ b/inc/framework/ge_runtime_dummy/davinci_model.h @@ -0,0 +1,113 @@ +/** + * Copyright 2019-2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef INC_FRAMEWORK_GE_RUNTIME_DAVINCI_MODEL_H_ +#define INC_FRAMEWORK_GE_RUNTIME_DAVINCI_MODEL_H_ + +#include +#include + +#include "ge_runtime/op_info.h" +#include "ge_runtime/task_info.h" + +namespace ge { +namespace model_runner { +class DavinciModel { + public: + DavinciModel(const std::vector> &task_info_list, + const std::vector> &data_info_list, + const std::vector> &output_info_list, + const std::vector> &constant_info_list, + const std::vector &variable_info_list, + const std::vector &wait_active_stream_list, + const std::vector &force_copy_stream_list, uint64_t mem_size = 0, uint64_t weight_size = 0, + uint64_t var_size = 0, uintptr_t logic_mem_base = 0, uintptr_t logic_weight_base = 0, + uintptr_t logic_var_base = 0, uint32_t stream_num = 0, uint32_t batch_num = 0, uint32_t event_num = 0, + int32_t priority = 0) + : task_info_list_(task_info_list), + data_info_list_(data_info_list), + output_info_list_(output_info_list), + constant_info_list_(constant_info_list), + variable_info_list_(variable_info_list), + wait_active_stream_list_(wait_active_stream_list), + force_copy_stream_list_(force_copy_stream_list), + mem_size_(mem_size), + weight_size_(weight_size), + var_size_(var_size), + logic_mem_base_(logic_mem_base), + logic_weight_base_(logic_weight_base), + logic_var_base_(logic_var_base), + stream_num_(stream_num), + batch_num_(batch_num), + event_num_(event_num), + priority_(priority) {} + ~DavinciModel() {} + + uint64_t GetMemSize() const { return mem_size_; } + uint64_t GetWeightSize() const { return weight_size_; } + uint64_t GetVarSize() const { return var_size_; } + + uintptr_t GetLogicMemBase() const { return logic_mem_base_; } + uintptr_t GetLogicWeightBase() const { return logic_weight_base_; } + uintptr_t GetLogicVarBase() const { return logic_var_base_; } + + 
uint32_t GetStreamNum() const { return stream_num_; }
+  uint32_t GetBatchNum() const { return batch_num_; }
+  uint32_t GetEventNum() const { return event_num_; }
+
+  const std::vector<uint32_t> &GetWaitActiveStreams() const { return wait_active_stream_list_; }
+  const std::vector<uint32_t> &GetForceCopyStreams() const { return force_copy_stream_list_; }
+
+  int32_t GetPriority() const { return priority_; }
+
+  const std::vector<std::shared_ptr<TaskInfo>> &GetTaskInfoList() const { return task_info_list_; }
+  const std::vector<std::shared_ptr<OpInfo>> &GetDataInfoList() const { return data_info_list_; }
+  const std::vector<std::shared_ptr<OpInfo>> &GetOutputInfoList() const { return output_info_list_; }
+  const std::vector<std::shared_ptr<OpInfo>> &GetConstantInfoList() const { return constant_info_list_; }
+  const std::vector<std::shared_ptr<OpInfo>> &GetVariableInfoList() const { return variable_info_list_; }
+
+ private:
+  std::vector<std::shared_ptr<TaskInfo>> task_info_list_;
+  std::vector<std::shared_ptr<OpInfo>> data_info_list_;
+  std::vector<std::shared_ptr<OpInfo>> output_info_list_;
+  std::vector<std::shared_ptr<OpInfo>> constant_info_list_;
+  std::vector<std::shared_ptr<OpInfo>> variable_info_list_;
+
+  std::vector<uint32_t> wait_active_stream_list_;
+  std::vector<uint32_t> force_copy_stream_list_;
+
+  uint64_t mem_size_;
+  uint64_t weight_size_;
+  uint64_t var_size_;
+
+  uintptr_t logic_mem_base_;
+  uintptr_t logic_weight_base_;
+  uintptr_t logic_var_base_;
+
+  uint32_t stream_num_;
+  uint32_t batch_num_;
+  uint32_t event_num_;
+
+  int32_t priority_;
+
+  // Disable copy constructor and assignment operator
+  DavinciModel &operator=(const DavinciModel &) = delete;
+  DavinciModel(const DavinciModel &) = delete;
+};
+}  // namespace model_runner
+}  // namespace ge
+
+#endif  // INC_FRAMEWORK_GE_RUNTIME_DAVINCI_MODEL_H_
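A small sketch of constructing and querying a DavinciModel. The constructor's parameter names and defaults are visible earlier in this file, but its element types were stripped in this copy of the diff, so the shared_ptr element types below follow the reconstruction above and should be treated as assumptions.

```cpp
// Hypothetical construction of an empty DavinciModel; all trailing size,
// base-address, count and priority arguments keep their declared defaults.
#include <memory>
#include <vector>
#include "ge_runtime/davinci_model.h"  // include path as used inside the header

std::shared_ptr<ge::model_runner::DavinciModel> MakeEmptyModel() {
  std::vector<std::shared_ptr<ge::model_runner::TaskInfo>> tasks;  // no tasks yet
  std::vector<std::shared_ptr<ge::model_runner::OpInfo>> data, outputs, constants, variables;
  std::vector<uint32_t> wait_streams, copy_streams;

  auto model = std::make_shared<ge::model_runner::DavinciModel>(
      tasks, data, outputs, constants, variables, wait_streams, copy_streams);

  // Accessors are plain getters over the stored lists and scalars:
  // model->GetMemSize();   // 0, from the defaulted mem_size argument
  // model->GetPriority();  // 0
  return model;
}
```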
diff --git a/inc/framework/ge_runtime_dummy/model_runner.h b/inc/framework/ge_runtime_dummy/model_runner.h
new file mode 100644
index 00000000..6e7abcb9
--- /dev/null
+++ b/inc/framework/ge_runtime_dummy/model_runner.h
@@ -0,0 +1,58 @@
+/**
+ * Copyright 2019-2020 Huawei Technologies Co., Ltd
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#ifndef INC_FRAMEWORK_GE_RUNTIME_MODEL_RUNNER_H_
+#define INC_FRAMEWORK_GE_RUNTIME_MODEL_RUNNER_H_
+
+#include <memory>
+#include <unordered_map>
+#include <vector>
+
+#include "common/ge_inner_error_codes.h"
+#include "common/ge_types.h"
+#include "ge_runtime/davinci_model.h"
+
+namespace ge {
+namespace model_runner {
+class RuntimeModel;
+
+class ModelRunner {
+ public:
+  static ModelRunner &Instance();
+
+  bool LoadDavinciModel(uint32_t device_id, uint64_t session_id, uint32_t model_id,
+                        std::shared_ptr<DavinciModel> davinci_model, std::shared_ptr<ModelListener> listener);
+
+  const std::vector<uint32_t> &GetTaskIdList(uint32_t model_id) const;
+
+  bool UnloadModel(uint32_t model_id);
+
+  bool RunModel(uint32_t model_id, const InputData &input_data, OutputData *output_data);
+
+  bool GetInputOutputDescInfo(uint32_t model_id, bool zero_copy, std::vector<InputOutputDescInfo> *input_desc,
+                              std::vector<InputOutputDescInfo> *output_desc, std::vector<uint32_t> *input_format,
+                              std::vector<uint32_t> *output_format);
+
+ private:
+  ModelRunner() = default;
+  ~ModelRunner() = default;
+
+  std::unordered_map<uint32_t, std::shared_ptr<RuntimeModel>> runtime_models_;
+};
+}  // namespace model_runner
+}  // namespace ge
+
+#endif  // INC_FRAMEWORK_GE_RUNTIME_MODEL_RUNNER_H_
diff --git a/inc/framework/ge_runtime_dummy/op_info.h b/inc/framework/ge_runtime_dummy/op_info.h
new file mode 100644
index 00000000..22c16ed6
--- /dev/null
+++ b/inc/framework/ge_runtime_dummy/op_info.h
@@ -0,0 +1,72 @@
+/**
+ * Copyright 2019-2020 Huawei Technologies Co., Ltd
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#ifndef INC_FRAMEWORK_GE_RUNTIME_OP_INFO_H_
+#define INC_FRAMEWORK_GE_RUNTIME_OP_INFO_H_
+
+#include <memory>
+#include <string>
+#include <vector>
+
+namespace ge {
+namespace model_runner {
+struct TensorInfo {
+  int64_t GetShapeSize() const {
+    int64_t res = 1;
+    if (dims.empty()) {
+      return 0;
+    }
+    for (auto dim : dims) {
+      res *= dim;
+    }
+    return res;
+  }
+
+  int64_t GetDim(uint32_t index) {
+    if (index >= dims.size()) {
+      return 0;
+    }
+    return dims[index];
+  }
+
+  std::vector<int64_t> dims;
+  uint32_t datatype;
+  uint32_t format;
+  uint32_t real_dim_cnt;
+  uint32_t size;
+  bool is_output;
+};
+
+struct OpInfo {
+  uint32_t index;
+  std::string name;
+  std::string type;
+  bool var_is_broadcast;
+  std::vector<uintptr_t> input_addrs;
+  std::vector<uintptr_t> output_addrs;
+  std::vector<TensorInfo> input_tensors;
+  std::vector<TensorInfo> output_tensors;
+  std::vector<TensorInfo> weight_tensors;
+  std::vector<std::string> src_name;
+  std::vector<int64_t> src_index;
+  std::string weight_data;
+};
+
+using TensorInfoPtr = std::shared_ptr<TensorInfo>;
+using OpInfoPtr = std::shared_ptr<OpInfo>;
+}  // namespace model_runner
+}  // namespace ge
+#endif  // INC_FRAMEWORK_GE_RUNTIME_OP_INFO_H_
diff --git a/inc/framework/ge_runtime_dummy/task_info.h b/inc/framework/ge_runtime_dummy/task_info.h
new file mode 100644
index 00000000..a48ed68b
--- /dev/null
+++ b/inc/framework/ge_runtime_dummy/task_info.h
@@ -0,0 +1,394 @@
+/**
+ * Copyright 2019-2020 Huawei Technologies Co., Ltd
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License. 
+ * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef INC_FRAMEWORK_GE_RUNTIME_TASK_INFO_H_ +#define INC_FRAMEWORK_GE_RUNTIME_TASK_INFO_H_ + +#include +#include +#include +#include +#include + +#include "cce/taskdown_api.h" + +namespace ge { +namespace model_runner { +enum TaskInfoType { + CCE = 0, + TBE, + AICPU, + LABEL_SET, + LABEL_SWITCH, + LABEL_GOTO, + EVENT_RECORD, + EVENT_WAIT, + FUSION_START, + FUSION_END, + HCCL, + PROFILER_TRACE, + MEMCPY_ASYNC, + STREAM_SWITCH, + STREAM_ACTIVE, + // Insert new task type here + REVSERVED = 23 +}; + +class TaskInfo { + public: + virtual ~TaskInfo() {} + uint32_t stream_id() const { return stream_id_; } + TaskInfoType type() const { return type_; } + + protected: + TaskInfo(uint32_t stream_id, TaskInfoType type) : stream_id_(stream_id), type_(type) {} + + private: + uint32_t stream_id_; + TaskInfoType type_; +}; + +class CceTaskInfo : public TaskInfo { + public: + CceTaskInfo(uint32_t stream_id, const cce::ccOpContext &ctx, const std::string &stub_func, uint32_t block_dim, + const std::vector &args, uint32_t args_size, const std::vector &sm_desc, + const std::vector &flow_table, const std::vector &args_offset, bool is_flowtable) + : TaskInfo(stream_id, TaskInfoType::CCE), + ctx_(ctx), + stub_func_(stub_func), + block_dim_(block_dim), + args_(args), + args_size_(args_size), + sm_desc_(sm_desc), + flow_table_(flow_table), + args_offset_(args_offset), + is_flowtable_(is_flowtable) {} + ~CceTaskInfo() override {} + + cce::ccOpContext cc_context() const { return ctx_; } + std::string stub_func() const { return stub_func_; } + uint32_t block_dim() const { return block_dim_; } + const std::vector &args() const { return args_; } + uint32_t args_size() const { return args_size_; } + const std::vector &sm_desc() const { return sm_desc_; } + const std::vector &flow_table() const { return flow_table_; } + const std::vector &args_offset() const { return args_offset_; } + bool is_flowtable() const { return is_flowtable_; } + + private: + cce::ccOpContext ctx_; + std::string stub_func_; + uint32_t block_dim_; + std::vector args_; + uint32_t args_size_; + std::vector sm_desc_; + std::vector flow_table_; + std::vector args_offset_; + bool is_flowtable_; +}; + +class TbeTaskInfo : public TaskInfo { + public: + TbeTaskInfo(uint32_t stream_id, const std::string &stub_func, uint32_t block_dim, const std::vector &args, + uint32_t args_size, const std::vector &sm_desc, void *binary, uint32_t binary_size, + const std::vector &meta_data, const std::vector &input_data_addrs, + const std::vector &output_data_addrs, const std::vector &workspace_addrs) + : TaskInfo(stream_id, TaskInfoType::TBE), + stub_func_(stub_func), + block_dim_(block_dim), + args_(args), + args_size_(args_size), + sm_desc_(sm_desc), + binary_(binary), + binary_size_(binary_size), + meta_data_(meta_data), + input_data_addrs_(input_data_addrs), + output_data_addrs_(output_data_addrs), + workspace_addrs_(workspace_addrs) {} + ~TbeTaskInfo() override {} + + const std::string &stub_func() const { return stub_func_; } + uint32_t block_dim() const { return block_dim_; } + const std::vector &args() const { return args_; } + uint32_t 
args_size() const { return args_size_; } + const std::vector &sm_desc() const { return sm_desc_; } + void *binary() const { return binary_; } + uint32_t binary_size() const { return binary_size_; } + const std::vector &meta_data() const { return meta_data_; } + const std::vector &input_data_addrs() const { return input_data_addrs_; } + const std::vector &output_data_addrs() const { return output_data_addrs_; } + const std::vector &workspace_addrs() const { return workspace_addrs_; } + + void SetBinary(void *binary, uint32_t binary_size) { + binary_ = binary; + binary_size_ = binary_size; + } + + private: + std::string stub_func_; + uint32_t block_dim_; + std::vector args_; + uint32_t args_size_; + std::vector sm_desc_; + void *binary_; + uint32_t binary_size_; + std::vector meta_data_; + std::vector input_data_addrs_; + std::vector output_data_addrs_; + std::vector workspace_addrs_; +}; + +class AicpuTaskInfo : public TaskInfo { + public: + AicpuTaskInfo(uint32_t stream_id, const string &so_name, const std::string &kernel_name, const std::string &node_def, + const std::vector &input_data_addrs, const std::vector &output_data_addrs) + : TaskInfo(stream_id, TaskInfoType::AICPU), + so_name_(so_name), + kernel_name_(kernel_name), + node_def_(node_def), + input_data_addrs_(input_data_addrs), + output_data_addrs_(output_data_addrs) {} + ~AicpuTaskInfo() override {} + + const std::string &so_name() const { return so_name_; } + const std::string &kernel_name() const { return kernel_name_; } + const std::string &node_def() const { return node_def_; } + const std::vector &input_data_addrs() const { return input_data_addrs_; } + const std::vector &output_data_addrs() const { return output_data_addrs_; } + + private: + std::string so_name_; + std::string kernel_name_; + std::string node_def_; + std::vector input_data_addrs_; + std::vector output_data_addrs_; +}; + +class LabelTaskInfo : public TaskInfo { + public: + uint32_t label_id() const { return label_id_; } + + protected: + LabelTaskInfo(uint32_t stream_id, TaskInfoType type, uint32_t label_id) + : TaskInfo(stream_id, type), label_id_(label_id) {} + virtual ~LabelTaskInfo() override {} + + uint32_t label_id_; +}; + +class LabelSetTaskInfo : public LabelTaskInfo { + public: + LabelSetTaskInfo(uint32_t stream_id, uint32_t label_id) + : LabelTaskInfo(stream_id, TaskInfoType::LABEL_SET, label_id) {} + ~LabelSetTaskInfo() override {} +}; + +class LabelSwitchTaskInfo : public LabelTaskInfo { + public: + LabelSwitchTaskInfo(uint32_t stream_id, uint32_t label_id) + : LabelTaskInfo(stream_id, TaskInfoType::LABEL_SWITCH, label_id) {} + ~LabelSwitchTaskInfo() override {} +}; + +class LabelGotoTaskInfo : public LabelTaskInfo { + public: + LabelGotoTaskInfo(uint32_t stream_id, uint32_t label_id) + : LabelTaskInfo(stream_id, TaskInfoType::LABEL_GOTO, label_id) {} + ~LabelGotoTaskInfo() override {} +}; + +class EventTaskInfo : public TaskInfo { + public: + uint32_t event_id() const { return event_id_; } + + protected: + EventTaskInfo(uint32_t stream_id, TaskInfoType type, uint32_t event_id) + : TaskInfo(stream_id, type), event_id_(event_id) {} + virtual ~EventTaskInfo() override {} + + uint32_t event_id_; +}; + +class EventRecordTaskInfo : public EventTaskInfo { + public: + EventRecordTaskInfo(uint32_t stream_id, uint32_t event_id) + : EventTaskInfo(stream_id, TaskInfoType::EVENT_RECORD, event_id) {} + ~EventRecordTaskInfo() override {} +}; + +class EventWaitTaskInfo : public EventTaskInfo { + public: + EventWaitTaskInfo(uint32_t stream_id, uint32_t event_id) 
+ : EventTaskInfo(stream_id, TaskInfoType::EVENT_WAIT, event_id) {} + ~EventWaitTaskInfo() override {} +}; + +class FusionStartTaskInfo : public TaskInfo { + public: + explicit FusionStartTaskInfo(uint32_t stream_id) : TaskInfo(stream_id, TaskInfoType::FUSION_START) {} + ~FusionStartTaskInfo() override {} +}; + +class FusionEndTaskInfo : public TaskInfo { + public: + explicit FusionEndTaskInfo(uint32_t stream_id) : TaskInfo(stream_id, TaskInfoType::FUSION_END) {} + ~FusionEndTaskInfo() override {} +}; + +class HcclTaskInfo : public TaskInfo { + public: + HcclTaskInfo(uint32_t stream_id, const std::string hccl_type, void *input_data_addr, void *output_data_addr, + void *workspace_addr, int64_t workspace_size, int64_t hccl_stream_num, + const std::vector &private_def, void *ops_kernel_store, int32_t count, int64_t root_id, + int64_t op_type, int64_t data_type, std::function hcom_bind_model, + std::function hcom_unbind_model, + std::function, void *)> hcom_distribute_task) + : TaskInfo(stream_id, TaskInfoType::HCCL), + hccl_type_(hccl_type), + input_data_addr_(input_data_addr), + output_data_addr_(output_data_addr), + workspace_addr_(workspace_addr), + workspace_size_(workspace_size), + hccl_stream_num_(hccl_stream_num), + private_def_(private_def), + ops_kernel_store_(ops_kernel_store), + count_(count), + root_id_(root_id), + op_type_(op_type), + data_type_(data_type), + hcom_bind_model_(hcom_bind_model), + hcom_unbind_model_(hcom_unbind_model), + hcom_distribute_task_(hcom_distribute_task) {} + ~HcclTaskInfo() override {} + + const std::string &hccl_type() const { return hccl_type_; } + void *input_data_addr() const { return input_data_addr_; } + void *output_data_addr() const { return output_data_addr_; } + void *workspace_addr() const { return workspace_addr_; } + int64_t workspace_size() const { return workspace_size_; } + int64_t hccl_stream_num() const { return hccl_stream_num_; } + const std::vector &private_def() const { return private_def_; } + void *ops_kernel_store() const { return ops_kernel_store_; } + int32_t count() const { return count_; } + int64_t root_id() const { return root_id_; } + int64_t op_type() const { return op_type_; } + int64_t data_type() const { return data_type_; } + std::function hcom_bind_model() const { return hcom_bind_model_; } + std::function hcom_unbind_model() const { return hcom_unbind_model_; } + std::function, void *)> hcom_distribute_task() const { + return hcom_distribute_task_; + } + + private: + std::string hccl_type_; + void *input_data_addr_; + void *output_data_addr_; + void *workspace_addr_; + int64_t workspace_size_; + int64_t hccl_stream_num_; + std::vector private_def_; + void *ops_kernel_store_; + int32_t count_; + int64_t root_id_; + int64_t op_type_; + int64_t data_type_; + std::function hcom_bind_model_; + std::function hcom_unbind_model_; + std::function, void *)> hcom_distribute_task_; +}; + +class ProfilerTraceTaskInfo : public TaskInfo { + public: + ProfilerTraceTaskInfo(uint32_t stream_id, uint64_t log_id, bool notify, uint32_t flat) + : TaskInfo(stream_id, TaskInfoType::PROFILER_TRACE), log_id_(log_id), notify_(notify), flat_(flat) {} + ~ProfilerTraceTaskInfo() override {} + + uint64_t log_id() const { return log_id_; } + bool notify() const { return notify_; } + uint32_t flat() const { return flat_; } + + private: + uint64_t log_id_; + bool notify_; + uint32_t flat_; +}; + +class MemcpyAsyncTaskInfo : public TaskInfo { + public: + MemcpyAsyncTaskInfo(uint32_t stream_id, void *dst, uint64_t dst_max, void *src, uint64_t count, 
uint32_t kind)
+      : TaskInfo(stream_id, TaskInfoType::MEMCPY_ASYNC),
+        dst_(dst),
+        dst_max_(dst_max),
+        src_(src),
+        count_(count),
+        kind_(kind) {}
+  ~MemcpyAsyncTaskInfo() override {}
+
+  void *dst() const { return dst_; }
+  uint64_t dst_max() const { return dst_max_; }
+  void *src() const { return src_; }
+  uint64_t count() const { return count_; }
+  uint32_t kind() const { return kind_; }
+
+ private:
+  void *dst_;
+  uint64_t dst_max_;
+  void *src_;
+  uint64_t count_;
+  uint32_t kind_;
+};
+
+class StreamSwitchTaskInfo : public TaskInfo {
+ public:
+  StreamSwitchTaskInfo(uint32_t stream_id, int64_t true_stream_id, void *input_addr, void *value_addr, int64_t cond,
+                       int64_t data_type)
+      : TaskInfo(stream_id, TaskInfoType::STREAM_SWITCH),
+        true_stream_id_(true_stream_id),
+        input_addr_(input_addr),
+        value_addr_(value_addr),
+        cond_(cond),
+        data_type_(data_type) {}
+  ~StreamSwitchTaskInfo() override {}
+
+  int64_t true_stream_id() const { return true_stream_id_; }
+  void *input_addr() const { return input_addr_; }
+  void *value_addr() const { return value_addr_; }
+  int64_t cond() const { return cond_; }
+  int64_t data_type() const { return data_type_; }
+
+ private:
+  int64_t true_stream_id_;
+  void *input_addr_;
+  void *value_addr_;
+  int64_t cond_;
+  int64_t data_type_;
+};
+
+class StreamActiveTaskInfo : public TaskInfo {
+ public:
+  StreamActiveTaskInfo(uint32_t stream_id, uint32_t active_stream_id)
+      : TaskInfo(stream_id, TaskInfoType::STREAM_ACTIVE), active_stream_id_(active_stream_id) {}
+  ~StreamActiveTaskInfo() override {}
+
+  uint32_t active_stream_id() const { return active_stream_id_; }
+
+ private:
+  uint32_t active_stream_id_;
+};
+}  // namespace model_runner
+}  // namespace ge
+
+#endif  // INC_FRAMEWORK_GE_RUNTIME_TASK_INFO_H_
diff --git a/inc/framework/generator/ge_generator.h b/inc/framework/generator/ge_generator.h
index a18e730d..f0707c67 100644
--- a/inc/framework/generator/ge_generator.h
+++ b/inc/framework/generator/ge_generator.h
@@ -23,6 +23,7 @@
 #include
 #include "ge/ge_ir_build.h"
 #include "common/ge_inner_error_codes.h"
+#include "common/ge_types.h"
 #include "graph/ge_tensor.h"
 #include "graph/graph.h"
 #include "graph/op_desc.h"
@@ -30,9 +31,13 @@
 namespace ge {
 class GeGenerator {
  public:
+  static GeGenerator &GetInstance() {
+    static GeGenerator Instance;
+    return Instance;
+  }
   GeGenerator() = default;
 
-  ~GeGenerator() = default;
+  ~GeGenerator() { (void)Finalize(); }
 
   GeGenerator(const GeGenerator &) = delete;
 
@@ -60,10 +65,25 @@ class GeGenerator {
   ///
   Status BuildSingleOpModel(OpDescPtr &op_desc, const std::vector<GeTensor> &inputs,
                             const std::vector<GeTensor> &outputs, const std::string &model_file_name);
+  ///
+  /// @ingroup ge
+  /// @brief: Build single Op into model buff.
+  /// @param [in] op_desc: the OP description.
+  /// @param [in] inputs: input tensors.
+  /// @param [in] outputs: output tensors.
+  /// @param [in] engine_type: specific engine.
+  /// @param [out] model_buff: model buff of single op.
+  /// @return SUCCESS or FAILED
+  Status BuildSingleOpModel(OpDescPtr &op_desc, const vector<GeTensor> &inputs, const vector<GeTensor> &outputs,
+                            OpEngineType engine_type, ModelBufferData &model_buff);
 
  private:
   Status GenerateModel(const Graph &graph, const string &file_name_prefix, const vector<GeTensor> &inputs,
                        ge::ModelBufferData &model, bool is_offline = true);
+  Status BuildSingleOp(OpDescPtr &op_desc, const vector<GeTensor> &inputs, const vector<GeTensor> &outputs,
+                       const string &model_file_name, OpEngineType engine_type, ModelBufferData &model_buff,
+                       bool is_offline = true);
+
   class Impl;
 
   std::shared_ptr<Impl> impl_;
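A hypothetical call sequence for the new engine-aware overload, using the singleton added above together with the OpEngineType enum introduced in ge_types.h. The helper name is made up, and the op_desc construction is elided; inputs/outputs mirror the vector&lt;GeTensor&gt; parameters as reconstructed above.

```cpp
// Sketch: build one op into an in-memory model buffer for the AI Core engine.
#include <vector>
#include "framework/generator/ge_generator.h"  // assumed include path

ge::Status BuildSingleOpForAicore(ge::OpDescPtr &op_desc, const std::vector<ge::GeTensor> &inputs,
                                  const std::vector<ge::GeTensor> &outputs) {
  ge::ModelBufferData buffer;  // receives the serialized single-op model
  // GetInstance() is the singleton added by this change; its destructor now
  // calls Finalize(), so no explicit teardown is required here.
  return ge::GeGenerator::GetInstance().BuildSingleOpModel(op_desc, inputs, outputs,
                                                           ge::ENGINE_AICORE, buffer);
}
```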
diff --git a/inc/framework/omg/omg.h b/inc/framework/omg/omg.h
new file mode 100644
index 00000000..11d94817
--- /dev/null
+++ b/inc/framework/omg/omg.h
@@ -0,0 +1,113 @@
+/**
+ * Copyright 2019-2020 Huawei Technologies Co., Ltd
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#ifndef INC_FRAMEWORK_OMG_OMG_H_
+#define INC_FRAMEWORK_OMG_OMG_H_
+
+#include <google/protobuf/message.h>
+#include <string>
+#include <unordered_map>
+#include <vector>
+#include "framework/common/types.h"
+#include "framework/omg/omg_inner_types.h"
+#include "proto/ge_ir.pb.h"
+#include "proto/om.pb.h"
+
+#include "graph/compute_graph.h"
+#include "graph/graph.h"
+#include "graph/model.h"
+#include "runtime/kernel.h"
+
+using domi::Status;
+using std::pair;
+using std::string;
+using std::unordered_map;
+using std::vector;
+
+namespace ge {
+/**
+ * @ingroup domi_omg
+ * @brief init omg context
+ * @return Status result code
+ */
+Status InitDomiOmgContext(const string &input_shape, const string &input_format, const string &net_format,
+                          bool is_dynamic_input);
+
+/**
+ * @ingroup domi_omg
+ * @brief generate graph based on the input model file and weight file
+ * @param [out] graph graph
+ * @param [in] model_file path of model file
+ * @param [in] weights_file path of weight file
+ * @param [in] type type of the input model
+ * @param [in] op_conf op mapping configuration
+ * @param [in] target type of platform. If a tiny model is generated, set target to tiny
+ * @param [in] run_mode run mode
+ * @param [in] enable_l2dynamic enable l2dynamic
+ * @param [in] is_dynamic_input dynamic input, true or false
+ * @param [in] atc_params multiple atc params
+ * @return Status result code
+ */
+Status ParseGraph(ge::Graph &graph, const std::map<string, string> &atc_params, const char *model_file,
+                  const char *weights_file, domi::FrameworkType type, const char *op_conf = nullptr,
+                  const char *target = nullptr, RunMode run_mode = GEN_OM_MODEL, bool is_dynamic_input = false);
+
+/**
+ * @ingroup domi_omg
+ * @brief generate a simplified JSON file based on the key value of the offline model file in protobuf format
+ * @param [in] model_file path of offline model file
+ * @param [out] json_file path of json file
+ * @param [in] key encrypted key
+ * @return Status result code
+ */
+Status ConvertOmModelToJson(const char *model_file, const char *json_file);
+
+Status ConvertPbtxtToJson(const char *model_file, const char *json_file);
+/**
+ * @ingroup domi_omg
+ * @brief convert the model file in protobuf format into a JSON file. 
+ * @param [in] framework type of model
+ * @param [in] model_file path of offline model file
+ * @param [out] json_file path of json file
+ * @param [in] key encrypted key
+ * @return Status result code
+ */
+Status ConvertFwkModelToJson(domi::FrameworkType framework, const char *model_file, const char *json_file);
+
+void GetGroupName(ge::proto::ModelDef &model);
+
+void FindParserSo(const string &path, vector<string> &fileList, string &caffe_parser_path);
+
+Status CheckCustomAiCpuOpLib();
+
+Status DumpInfershapeJson(const ge::Graph &graph, const char *json_file);
+
+Status SetOutputNodeInfo(ge::Graph &graph, const std::string &output_type, const std::string &output_format);
+
+Status GetOutputLeaf(ge::NodePtr node, std::vector<std::pair<ge::NodePtr, int32_t>> &output_nodes_info,
+                     std::vector<std::string> &output_nodes_name);
+}  // namespace ge
+
+namespace domi {
+/**
+ * @ingroup domi_omg
+ * @brief get omg context
+ * @return reference of OmgContext
+ */
+ge::OmgContext &GetContext();
+}  // namespace domi
+
+#endif  // INC_FRAMEWORK_OMG_OMG_H_
diff --git a/inc/framework/omg/omg_inner_types.h b/inc/framework/omg/omg_inner_types.h
index 547fbe2f..118477b1 100644
--- a/inc/framework/omg/omg_inner_types.h
+++ b/inc/framework/omg/omg_inner_types.h
@@ -44,11 +44,10 @@ namespace ge {
  * @brief run mode
  */
 enum RunMode {
-  GEN_OM_MODEL = 0,              // generate offline model file
-  MODEL_TO_JSON = 1,             // convert to JSON file
-  MODEL_TO_JSON_WITH_SHAPE = 2,  // convert to json file with shape
-  ONLY_PRE_CHECK = 3,            // only for pre-check
-  PBTXT_TO_JSON = 5              // pbtxt to json
+  GEN_OM_MODEL = 0,    // generate offline model file
+  MODEL_TO_JSON = 1,   // convert to JSON file
+  ONLY_PRE_CHECK = 3,  // only for pre-check
+  PBTXT_TO_JSON = 5    // pbtxt to json
 };
 
 ///
@@ -93,6 +92,8 @@ struct OmgContext {
   std::map<std::string, std::vector<int32_t>> out_nodes_map;
   // user-designated out nodes (this is used for determining the order)
   std::vector<std::pair<std::string, int32_t>> user_out_nodes;
+  // net out nodes (either user_out_nodes or leaf nodes)
+  std::vector<std::string> net_out_nodes;
   // path for the aicpu custom operator so_file
   std::vector<std::string> aicpu_op_run_paths;
   // ddk version
diff --git a/inc/graph/compute_graph.h b/inc/graph/compute_graph.h
index dbde46f5..c18b7b5b 100644
--- a/inc/graph/compute_graph.h
+++ b/inc/graph/compute_graph.h
@@ -235,6 +235,8 @@ class ComputeGraph : public std::enable_shared_from_this<ComputeGraph>, public AttrHolder {
                       std::vector<NodePtr> &stack);
   graphStatus BFSTopologicalSorting(std::vector<NodePtr> &node_vec, std::map<NodePtr, uint32_t> &map_in_edge_num,
                                     std::deque<NodePtr> &stack);
+  graphStatus BFSTopologicalSortingWithGroup(std::vector<NodePtr> &node_vec,
+                                             std::map<NodePtr, uint32_t> &map_in_edge_num, std::deque<NodePtr> &stack);
   graphStatus CollectBreadthOutNode(const NodePtr &node, std::map<NodePtr, uint32_t> &map_in_edge_num,
                                     std::map<string, NodePtr> &breadth_node_map);
   graphStatus TopologicalSortingGraph();
diff --git a/inc/graph/debug/ge_attr_define.h b/inc/graph/debug/ge_attr_define.h
index 23bb114a..99dd7774 100644
--- a/inc/graph/debug/ge_attr_define.h
+++ b/inc/graph/debug/ge_attr_define.h
@@ -94,6 +94,10 @@ GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_NAM
 
 GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_NAME_FORMAT;
 
+GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_NAME_STORAGE_FORMAT;
+
+GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_NAME_STORAGE_SHAPE;
+
 GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_NAME_FILTER_FORMAT;
 
 GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_NAME_LRN_K;
 
@@ -133,6 +137,7 @@ GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const 
std::string ATTR_NAM GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string NEW_AIPP_CONV_OP; GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_NAME_SESSION_GRAPH_ID; +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_NAME_PARENT_GRAPH_NAME; GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_NAME_MULTISHAPE_BATCHLIST; GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_NAME_MULTISHAPE_BATCHLIST_SIZE; @@ -692,6 +697,8 @@ GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_MOD GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_MODEL_ZERO_COPY_MEMORY_SIZE; +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_MODEL_OUT_NODES_NAME; + GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_MODEL_WEIGHT_SIZE; GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_MODEL_TASK_GEN_BASE_ADDR; @@ -920,6 +927,7 @@ GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_NAM GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_NAME_STREAM_SWITCH_COND; GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_NAME_ACTIVE_STREAM_LIST; GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_NAME_SWITCHN_PRED_VALUE; +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_NAME_SUBGRAPH_FIRST_ACTIVE; GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_NAME_SWITCH_BRANCH_NODE_LABEL; GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_NAME_SWITCH_TRUE_BRANCH_FLAG; @@ -999,14 +1007,17 @@ GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_NAM GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_NAME_SWITCH_FOR_L2_FUSION; // functional ops attr +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_NAME_IF_THEN_BRANCH; +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_NAME_IF_ELSE_BRANCH; GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_NAME_WHILE_COND; GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_NAME_WHILE_BODY; // used for label switch GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_NAME_LABEL_SWITCH_INDEX; GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_NAME_LABEL_SWITCH_LIST; +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_NAME_SUBGRAPH_END_NODE; -// Varible +// Variable GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string REF_VAR_SRC_VAR_NAME; GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string VAR_ATTR_SRC_VAR_NAME; GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string REF_VAR_PRE_PEER_OUT_INDEX; @@ -1032,6 +1043,13 @@ GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_NAM // Dynamic stitch GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string DYNAMIC_STITCH_ATTR_NAME_NUM; + +// Used for support Horovod +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_INTER_EVENT_IDENTIFY; +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_HOROVOD_ATTR_REDUCE_TYPE; +// for gradient group +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY 
extern const std::string ATTR_NAME_HCCL_FUSED_GROUP;
+GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_NAME_HCCL_FUSED_FLAG;
 }  // namespace ge
 #endif  // INC_GRAPH_DEBUG_GE_ATTR_DEFINE_H_
diff --git a/inc/graph/op_desc.h b/inc/graph/op_desc.h
index 1827e6be..faca2d99 100644
--- a/inc/graph/op_desc.h
+++ b/inc/graph/op_desc.h
@@ -264,6 +264,8 @@ class OpDesc : public std::enable_shared_from_this<OpDesc>, public AttrHolder {
   graphStatus SetSubgraphInstanceName(uint32_t index, const std::string &name);
   void RemoveSubgraphInstanceName(const std::string &name);
 
+  graphStatus GetSubgraphNameByInstanceName(const std::string &instance_name, std::string &subgraph_name) const;
+
  protected:
   ProtoAttrMapHelper MutableAttrMap() override;
   ConstProtoAttrMapHelper GetAttrMap() const override;
@@ -288,7 +290,7 @@ class OpDesc : public std::enable_shared_from_this<OpDesc>, public AttrHolder {
 
   // subgraph ir names to type, for a `if` operator:
   //   then_branch: static
-  //   else_branch: dynamic
+  //   else_branch: static
   // or for a `case` op:
   //   branches: dynamic
   std::map<std::string, std::string> subgraph_ir_names_to_type_;
diff --git a/inc/graph/runtime_inference_context.h b/inc/graph/runtime_inference_context.h
new file mode 100644
index 00000000..6c6c82e7
--- /dev/null
+++ b/inc/graph/runtime_inference_context.h
@@ -0,0 +1,46 @@
+/**
+ * Copyright 2019-2020 Huawei Technologies Co., Ltd
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#ifndef INC_GRAPH_RUNTIME_INFERENCE_CONTEXT_H_
+#define INC_GRAPH_RUNTIME_INFERENCE_CONTEXT_H_
+
+#include <map>
+#include <memory>
+#include <mutex>
+#include <vector>
+#include "external/graph/ge_error_codes.h"
+#include "external/graph/tensor.h"
+
+namespace ge {
+class GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY RuntimeInferenceContext {
+ public:
+  static graphStatus GetContext(const std::string &context_id, RuntimeInferenceContext **ctx);
+  static graphStatus CreateContext(const std::string &context_id);
+  static void DestroyContext(const std::string &context_id);
+
+  graphStatus SetTensor(int64_t node_id, int output_id, Tensor &&tensor);
+  graphStatus GetTensor(int64_t node_id, int output_id, Tensor &tensor);
+
+ private:
+  std::map<int64_t, std::vector<Tensor>> tensors_;
+  std::mutex mu_;
+
+  static std::map<std::string, std::unique_ptr<RuntimeInferenceContext>> contexts_;
+  static std::mutex ctx_mu_;
+};
+}  // namespace ge
+
+#endif  // INC_GRAPH_RUNTIME_INFERENCE_CONTEXT_H_
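A small sketch of the intended lifecycle, using only the methods declared above. The context id would normally come from the executing session/graph; "session_0_graph_1" here is a made-up placeholder, and the map/unique_ptr member types follow the reconstruction above.

```cpp
// Sketch: publish a tensor produced at runtime for node 42 / output 0, then
// read it back (as a shape-inference kernel might), and tear the context down.
#include <string>
#include <utility>
#include "graph/runtime_inference_context.h"

ge::graphStatus RecordAndFetch(ge::Tensor &&produced, ge::Tensor &fetched) {
  const std::string ctx_id = "session_0_graph_1";  // hypothetical id
  if (ge::RuntimeInferenceContext::CreateContext(ctx_id) != ge::GRAPH_SUCCESS) {
    return ge::GRAPH_FAILED;
  }
  ge::RuntimeInferenceContext *ctx = nullptr;
  if (ge::RuntimeInferenceContext::GetContext(ctx_id, &ctx) != ge::GRAPH_SUCCESS) {
    return ge::GRAPH_FAILED;
  }
  (void)ctx->SetTensor(42, 0, std::move(produced));      // keyed by node id + output id
  ge::graphStatus ret = ctx->GetTensor(42, 0, fetched);  // copy back out
  ge::RuntimeInferenceContext::DestroyContext(ctx_id);   // drops stored tensors
  return ret;
}
```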
diff --git a/inc/graph/utils/graph_utils.h b/inc/graph/utils/graph_utils.h
index 904684e7..15d25251 100644
--- a/inc/graph/utils/graph_utils.h
+++ b/inc/graph/utils/graph_utils.h
@@ -29,6 +29,18 @@
 #include "graph/graph.h"
 #include "graph/model.h"
 
+#define GE_DUMP(compute_graph, name)                                                                   \
+  do {                                                                                                 \
+    GraphUtils::DumpGEGraph(compute_graph, name);                                                      \
+    GraphUtils::DumpGEGraphToOnnx(*compute_graph, name);                                               \
+    for (const auto &sub_graph_func : compute_graph->GetAllSubgraphs()) {                              \
+      static int8_t i = 0;                                                                             \
+      auto sub_graph_func_name = std::string(name) + std::string("_sub_graph_") + std::to_string(i++); \
+      GraphUtils::DumpGEGraph(sub_graph_func, sub_graph_func_name);                                    \
+      GraphUtils::DumpGEGraphToOnnx(*sub_graph_func, sub_graph_func_name);                             \
+    }                                                                                                  \
+  } while (0)
+
 #define REFER_ATTR_VALUE(VT_ENUM, DataType, attr, ret) \
   do {                                                 \
     DataType ret;                                      \
@@ -155,6 +167,8 @@ class GraphUtils {
   static graphStatus InsertNodeBetweenDataAnchors(const OutDataAnchorPtr &src, const InDataAnchorPtr &dst,
                                                   const NodePtr &new_node);
 
+  static graphStatus RemoveSubgraphRecursively(const ComputeGraphPtr &compute_graph, const NodePtr &remove_node);
+
   static graphStatus RemoveNodeWithoutRelink(const ComputeGraphPtr &compute_graph, const NodePtr &node);
 
   static graphStatus InsertTransNode(ComputeGraphPtr compute_graph, const InDataAnchorPtr &in_data_anchor,
@@ -299,6 +313,24 @@ class GraphUtils {
                                  std::map> &symbol_to_anchors,
                                  std::map &anchor_to_symbol);
 
+  ///
+  /// Determine if the graph is an UNKNOWN_SHAPE graph based on whether the graph and all subgraphs
+  /// of the graph have UNKNOWN_SHAPE operators or not.
+  /// Note: This function will only look 'down' from the graph, not 'up'. 
For example, the following + /// scenario (K for known shape, U for unknown shape), ROOT graph is UNKNOWN_SHAPE while SUB graph is KNOWN_SHAPE + /// ROOT graph: A -----> B -----> C + /// K subgraph U + /// | + /// V + /// SUB graph: D --> E --> F + /// K K K + /// @param [in] graph + /// @return bool + /// + static bool IsUnknownShapeGraph(const ComputeGraphPtr &graph); + + static NodePtr FindNodeFromAllNodes(ComputeGraphPtr &graph, const std::string &name); + private: /// /// Get reference-mapping for in_data_anchors of node @@ -438,6 +470,11 @@ class ComputeGraphBuilder { /// NodePtr GetNode(const std::string &name); + /// @brief Get all nodes + /// @return std::vector + /// + std::vector GetAllNodes(); + protected: /// /// @brief Build nodes @@ -535,6 +572,13 @@ class CompleteGraphBuilder : public ComputeGraphBuilder { /// CompleteGraphBuilder &AddOutput(const std::string &owner_node_name, uint32_t anchor_ind); + /// + /// @brief Add target for graph + /// @param [in] target_name + /// @return CompleteGraphBuilder + /// + CompleteGraphBuilder &AddTarget(const std::string &target_name); + /// /// @brief Set parent-node of graph /// @param [in] parent_node @@ -590,10 +634,19 @@ class CompleteGraphBuilder : public ComputeGraphBuilder { /// void AddRetValNodes(graphStatus &error_code, std::string &error_msg); + /// + /// @brief Build target-nodes for graph + /// @param [out] error_code + /// @param [out] error_msg + /// @return void + /// + void BuildGraphTargets(graphStatus &error_code, std::string &error_msg); + std::string name_; NodePtr parent_node_; std::map, std::vector>> graph_inputs_; std::vector> graph_outputs_; + std::vector graph_targets_; // index_of_graph_input -> in_anchor_index_of_parent_node std::map input_mapping_; diff --git a/inc/graph/utils/node_utils.h b/inc/graph/utils/node_utils.h index e4c18d51..6e0e655d 100644 --- a/inc/graph/utils/node_utils.h +++ b/inc/graph/utils/node_utils.h @@ -17,10 +17,23 @@ #ifndef INC_GRAPH_UTILS_NODE_UTILS_H_ #define INC_GRAPH_UTILS_NODE_UTILS_H_ +#include #include #include #include "graph/node.h" + namespace ge { +// Op types of Const like Opps. +extern const std::set kConstOpTypes; +// Op types of If like Opps. +extern const std::set kIfOpTypes; +// Op types of While like Opps. +extern const std::set kWhileOpTypes; +// Op types of Case like Opps. +extern const std::set kCaseOpTypes; +// Op types of For like Opps. +extern const std::set kForOpTypes; + class NodeUtils { public: static graphStatus AddSendEventId(const NodePtr &node, const uint32_t &event_id); @@ -94,6 +107,13 @@ class NodeUtils { /// static bool GetConstOpType(const NodePtr &in_node, std::string &op_type); + /// + /// @brief Remove node-related subgraphs, including subgraphs of nodes in the subgraph. + /// @param [in] node + /// @return return GRAPH_SUCCESS if remove successfully, other for failed. 
+ /// + static graphStatus RemoveSubgraphsOnNode(const NodePtr &node); + private: static std::map> map_send_info_; static std::map> map_recv_info_; diff --git a/inc/graph/utils/type_utils.h b/inc/graph/utils/type_utils.h index f5f8234d..35b8cf22 100644 --- a/inc/graph/utils/type_utils.h +++ b/inc/graph/utils/type_utils.h @@ -24,6 +24,7 @@ #include "graph/ge_error_codes.h" #include "graph/types.h" #include "graph/usr_types.h" +#include "register/register_types.h" namespace ge { class TypeUtils { @@ -37,6 +38,7 @@ class TypeUtils { static std::string FormatToSerialString(Format format); static Format SerialStringToFormat(const std::string &str); static Format DataFormatToFormat(const std::string &str); + static Format DomiFormatToFormat(domi::domiTensorFormat_t domi_format); static graphStatus Usr2DefQuantizeFactorParams(const UsrQuantizeFactorParams &usr, QuantizeFactorParams &def); static graphStatus Def2UsrQuantizeFactorParams(const QuantizeFactorParams &def, UsrQuantizeFactorParams &usr); diff --git a/src/common/graph/compute_graph.cc b/src/common/graph/compute_graph.cc index 591ff0b5..0729e64e 100644 --- a/src/common/graph/compute_graph.cc +++ b/src/common/graph/compute_graph.cc @@ -36,6 +36,75 @@ namespace ge { namespace { const size_t OUTPUT_PARAM_SIZE = 2; +bool IsUseBFS() { + string run_mode; + const int base = 10; + if (ge::GetContext().GetOption(ge::OPTION_GRAPH_RUN_MODE, run_mode) == GRAPH_SUCCESS && !run_mode.empty()) { + if (GraphRunMode(std::strtol(run_mode.c_str(), nullptr, base)) >= TRAIN) { + return true; + } + } else { + GELOGW("OPTION_GRAPH_RUN_MODE not set, use BFSTopologicalSorting by default."); + } + return false; +} +bool IsTailingOptimization() { + string is_tailing_optimization_option; + auto ret = GetContext().GetOption(ge::OPTION_EXEC_ENABLE_TAILING_OPTIMIZATION, is_tailing_optimization_option); + if (ret == GRAPH_SUCCESS) { + GELOGI("Option ge.exec.isTailingOptimization is %s", is_tailing_optimization_option.c_str()); + // "1" means it's True from frontend option + return is_tailing_optimization_option == "1"; + } + GELOGW("OPTION_EXEC_ENABLE_TAILING_OPTIMIZATION not set, use BFSTopologicalSorting by default."); + return false; +} +bool IsFusedNode(const NodePtr &node) { + bool is_fused_node = false; + AttrUtils::GetBool(node->GetOpDesc(), ATTR_NAME_HCCL_FUSED_FLAG, is_fused_node); + return is_fused_node; +} +string GetGroupId(const NodePtr &node) { + string group_id; + AttrUtils::GetStr(node->GetOpDesc(), ATTR_NAME_HCCL_FUSED_GROUP, group_id); + return group_id; +} +bool IsGroupEnd(const NodePtr &node) { + if (GetGroupId(node).empty()) { + return false; + } + if (node->GetOutDataNodesSize() == 0) { + return true; + } + for (const auto &out_data_node : node->GetOutDataNodes()) { + if (IsFusedNode(out_data_node)) { + return true; + } + } + return false; +} +void SplitNodeToStack(const std::map &breadth_node_map, string current_group_id, + std::vector &stack_input, std::deque &group_stack, std::deque &stack) { + for (const auto &name_node : breadth_node_map) { + // group first + string group_id; + if (AttrUtils::GetStr(name_node.second->GetOpDesc(), ATTR_NAME_HCCL_FUSED_GROUP, group_id)) { + GELOGI("current node %s, group id: %s , current group id %s", name_node.second->GetName().c_str(), + group_id.c_str(), current_group_id.c_str()); + if (!current_group_id.empty() && group_id != current_group_id) { + GELOGI("node go to input_stack back: %s", name_node.second->GetName().c_str()); + (void)stack_input.insert(stack_input.begin(), name_node.second); + } else { + 
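+        // Either no group is active yet or this node is in the active group:
+        // adopt the group id and queue the node on group_stack, so one fused
+        // group is drained contiguously before any other node is visited.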
current_group_id = group_id; + GELOGI("node go to group_stack: %s", name_node.second->GetName().c_str()); + (void)group_stack.push_front(name_node.second); + } + continue; + } + GELOGI("node go to stack: %s ", name_node.second->GetName().c_str()); + (void)stack.push_front(name_node.second); + } +} } // namespace GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY ComputeGraph::ComputeGraph(const std::string &name) @@ -546,24 +615,21 @@ GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY void ComputeGraph::SetParentNode( /// GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY graphStatus ComputeGraph::UpdateInputMapping(const std::map &input_mapping) { - size_t update_num = 0; for (auto &input : nodes_) { - if (update_num >= input_mapping.size()) { - break; - } - uint32_t cur_index = 0; - if (!ge::AttrUtils::GetInt(input->GetOpDesc(), ATTR_NAME_PARENT_NODE_INDEX, cur_index)) { - continue; - } - auto iter = input_mapping.find(cur_index); - if (iter == input_mapping.end()) { - continue; - } - if (!ge::AttrUtils::SetInt(input->GetOpDesc(), ATTR_NAME_PARENT_NODE_INDEX, iter->second)) { - GE_LOGE("UpdateInputMapping failed: set attr ATTR_NAME_PARENT_NODE_INDEX failed."); - return GRAPH_FAILED; + if (input->GetType() == DATA) { + uint32_t cur_index = 0; + if (!ge::AttrUtils::GetInt(input->GetOpDesc(), ATTR_NAME_PARENT_NODE_INDEX, cur_index)) { + continue; + } + auto iter = input_mapping.find(cur_index); + if (iter == input_mapping.end()) { + continue; + } + if (!ge::AttrUtils::SetInt(input->GetOpDesc(), ATTR_NAME_PARENT_NODE_INDEX, iter->second)) { + GE_LOGE("UpdateInputMapping failed: set attr ATTR_NAME_PARENT_NODE_INDEX failed."); + return GRAPH_FAILED; + } } - update_num++; } return GRAPH_SUCCESS; @@ -719,10 +785,10 @@ graphStatus ComputeGraph::BFSTopologicalSorting(std::vector &node_vec, node = stack_input.back(); stack_input.pop_back(); } + node_vec.push_back(node); GE_CHECK_NOTNULL(node->GetOpDesc()); GELOGD("node_vec.push_back %s", node->GetOpDesc()->GetName().c_str()); - CollectBreadthOutNode(node, map_in_edge_num, breadth_node_map); for (const auto &name_node : breadth_node_map) { @@ -730,7 +796,65 @@ graphStatus ComputeGraph::BFSTopologicalSorting(std::vector &node_vec, } breadth_node_map.clear(); } + return GRAPH_SUCCESS; +} +graphStatus ComputeGraph::BFSTopologicalSortingWithGroup(std::vector &node_vec, + std::map &map_in_edge_num, + std::deque &stack) { + GELOGI("Runing_Bfs_Sort_With_Group"); + std::string current_group_id; + std::vector stack_input; + std::deque group_stack; + std::deque fused_node_stack; + std::map breadth_node_map; + // Record the number of non data nodes but no input nodes + GE_CHK_BOOL_EXEC(SortNodes(stack_input, map_in_edge_num) == GRAPH_SUCCESS, return GRAPH_FAILED, "sort nodes failed"); + + // Only data nodes here + while (!stack_input.empty() || !stack.empty() || !group_stack.empty()) { + NodePtr node = nullptr; + if (!group_stack.empty()) { + // Traversal node in group has priority + node = group_stack.back(); + group_stack.pop_back(); + } else if (!stack.empty()) { + node = stack.back(); + stack.pop_back(); + } else { + node = stack_input.back(); + stack_input.pop_back(); + } + + if (IsFusedNode(node) && current_group_id.empty()) { + current_group_id = node->GetName(); + } + if (GetGroupId(node).empty() || GetGroupId(node) == current_group_id) { + node_vec.push_back(node); + GE_CHECK_NOTNULL(node->GetOpDesc()); + GELOGI("node_vec.push_back %s", node->GetOpDesc()->GetName().c_str()); + } else { + if (current_group_id.empty()) { + current_group_id = GetGroupId(node); + 
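+        // No group was active, so this node opens a new group: keep its group
+        // id and emit the node itself.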
node_vec.push_back(node); + GE_CHECK_NOTNULL(node->GetOpDesc()); + GELOGI("node_vec.push_back %s", node->GetOpDesc()->GetName().c_str()); + } else { + GELOGI("current group id is %s ,node go to input_stack back: %s", current_group_id.c_str(), + node->GetName().c_str()); + (void)stack_input.insert(stack_input.begin(), node); + continue; + } + } + CollectBreadthOutNode(node, map_in_edge_num, breadth_node_map); + SplitNodeToStack(breadth_node_map, current_group_id, stack_input, group_stack, stack); + breadth_node_map.clear(); + // check the end of group + if (IsGroupEnd(node)) { + GELOGI("Current node %s is end of group %s.", node->GetName().c_str(), current_group_id.c_str()); + current_group_id = ""; + } + } return GRAPH_SUCCESS; } @@ -751,15 +875,14 @@ graphStatus ComputeGraph::CollectBreadthOutNode(const NodePtr &node, std::mapGetOutControlAnchor() != nullptr, for (AnchorPtr peer_in_anchor - : node->GetOutControlAnchor()->GetPeerAnchors()) { + if (node->GetOutControlAnchor() != nullptr) { + for (AnchorPtr peer_in_anchor : node->GetOutControlAnchor()->GetPeerAnchors()) { auto iter = map_in_edge_num.find(peer_in_anchor->GetOwnerNode()); if (iter != map_in_edge_num.end() && 0 == --iter->second) { (void)breadth_node_map.emplace(peer_in_anchor->GetOwnerNode()->GetName(), peer_in_anchor->GetOwnerNode()); } - }) + } + } return GRAPH_SUCCESS; } @@ -796,21 +919,18 @@ GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY graphStatus ComputeGraph::Topolog graphStatus ComputeGraph::TopologicalSortingGraph() { std::vector node_vec; std::map map_in_edge_num; - bool use_BFS = false; - string run_mode; - const int base = 10; - if (ge::GetContext().GetOption(ge::OPTION_GRAPH_RUN_MODE, run_mode) == GRAPH_SUCCESS && !run_mode.empty()) { - if (GraphRunMode(std::strtol(run_mode.c_str(), nullptr, base)) >= TRAIN) { - use_BFS = true; - } - } else { - GELOGW("OPTION_GRAPH_RUN_MODE not set, use BFSTopologicalSorting by default."); - } - + bool use_BFS = IsUseBFS(); + bool is_tailing_optimization = IsTailingOptimization(); if (use_BFS) { std::deque stack; - if (BFSTopologicalSorting(node_vec, map_in_edge_num, stack) != GRAPH_SUCCESS) { - return GRAPH_FAILED; + if (is_tailing_optimization) { + if (BFSTopologicalSortingWithGroup(node_vec, map_in_edge_num, stack) != GRAPH_SUCCESS) { + return GRAPH_FAILED; + } + } else { + if (BFSTopologicalSorting(node_vec, map_in_edge_num, stack) != GRAPH_SUCCESS) { + return GRAPH_FAILED; + } } } else { std::vector stack; diff --git a/src/common/graph/debug/ge_op_types.h b/src/common/graph/debug/ge_op_types.h index 3c511bdd..da36f72c 100644 --- a/src/common/graph/debug/ge_op_types.h +++ b/src/common/graph/debug/ge_op_types.h @@ -48,6 +48,12 @@ GE_REGISTER_OPTYPE(VARIABLEV2, "VariableV2"); GE_REGISTER_OPTYPE(INPUT_TYPE, "Input"); +// Horovod operator +GE_REGISTER_OPTYPE(HVDCALLBACKALLREDUCE, "hvdCallbackAllreduce"); +GE_REGISTER_OPTYPE(HVDCALLBACKALLGATHER, "hvdCallbackAllgather"); +GE_REGISTER_OPTYPE(HVDCALLBACKBROADCAST, "hvdCallbackBroadcast"); +GE_REGISTER_OPTYPE(HVDWAIT, "hvdWait"); + GE_REGISTER_OPTYPE(NODE_NAME_NET_OUTPUT, "Node_Output"); GE_REGISTER_OPTYPE(RECV, "Recv"); diff --git a/src/common/graph/ge_attr_define.cc b/src/common/graph/ge_attr_define.cc index 23c1cff0..1f427cf3 100644 --- a/src/common/graph/ge_attr_define.cc +++ b/src/common/graph/ge_attr_define.cc @@ -76,6 +76,10 @@ const std::string ATTR_NAME_ALGO = "algo"; const std::string ATTR_NAME_FORMAT = "format"; +const std::string ATTR_NAME_STORAGE_FORMAT = "storage_format"; + +const std::string ATTR_NAME_STORAGE_SHAPE = 
"storage_shape"; + const std::string ATTR_NAME_FILTER_FORMAT = "filter_format"; const std::string ATTR_NAME_LRN_K = "lrn_k"; @@ -115,6 +119,7 @@ const std::string ATTR_NAME_AIPP = "aipp"; const std::string NEW_AIPP_CONV_OP = "new_conv_op_for_aipp"; const std::string ATTR_NAME_SESSION_GRAPH_ID = "_session_graph_id"; +const std::string ATTR_NAME_PARENT_GRAPH_NAME = "_parent_graph_name"; const std::string ATTR_NAME_MULTISHAPE_BATCHLIST = "multi_shape_batchlist"; const std::string ATTR_NAME_MULTISHAPE_BATCHLIST_SIZE = "multi_shape_batchlist_size"; @@ -697,6 +702,8 @@ const std::string ATTR_MODEL_MEMORY_SIZE = "memory_size"; const std::string ATTR_MODEL_ZERO_COPY_MEMORY_SIZE = "zero_copy_memory_size"; +const std::string ATTR_MODEL_OUT_NODES_NAME = "attr_model_out_nodes_name"; + const std::string ATTR_MODEL_WEIGHT_SIZE = "weight_size"; const std::string ATTR_MODEL_TASK_GEN_BASE_ADDR = "task_gen_base_addr"; @@ -895,6 +902,7 @@ const std::string ATTR_NAME_ACTIVE_STREAM_LIST = "active_stream_list"; const std::string ATTR_NAME_SWITCHN_PRED_VALUE = "switch_pred_value"; const std::string ATTR_NAME_ITERATORS_PER_LOOP = "iterations_per_loop"; const std::string ATTR_NAME_FLOW_CTRL_NODE_FLAG = "is_flow_ctrl_node"; +const std::string ATTR_NAME_SUBGRAPH_FIRST_ACTIVE = "subgraph_first_active"; const std::string ATTR_NAME_SWITCH_BRANCH_NODE_LABEL = "_switch_branch_node_label"; const std::string ATTR_NAME_SWITCH_TRUE_BRANCH_FLAG = "_switch_true_branch_flag"; @@ -973,12 +981,15 @@ const std::string ATTR_NAME_DATA_DUMP_ORIGIN_FORMAT = "_datadump_origin_format"; const std::string ATTR_NAME_DATA_DUMP_ORIGIN_DATA_TYPE = "_datadump_origin_data_type"; // functional ops attr +const std::string ATTR_NAME_IF_THEN_BRANCH = "then_branch"; +const std::string ATTR_NAME_IF_ELSE_BRANCH = "else_branch"; const std::string ATTR_NAME_WHILE_COND = "cond"; const std::string ATTR_NAME_WHILE_BODY = "body"; // used for label switch const std::string ATTR_NAME_LABEL_SWITCH_INDEX = "_label_switch_index"; const std::string ATTR_NAME_LABEL_SWITCH_LIST = "_label_switch_list"; +const std::string ATTR_NAME_SUBGRAPH_END_NODE = "_subgraph_end_node"; const std::string ATTR_NAME_INPUT_DATATYPE = "input_datatype"; const std::string ATTR_NAME_OUTPUT_DATATYPE = "output_datatype"; @@ -990,4 +1001,11 @@ const std::string ATTR_NAME_VALID_INPUT_SHAPE_LIST_LIST = "_valid_input_shape_li const std::string ATTR_NAME_VALID_OUTPUT_SHAPE_LIST_LIST = "_valid_output_shape_list_list"; const std::string ATTR_NAME_SLICE_INPUT_OFFSET_LIST_LIST = "_input_offset_list_list"; const std::string ATTR_NAME_SLICE_OUTPUT_OFFSET_LIST_LIST = "_input_offset_list_list"; + +// used for Horovod +const std::string ATTR_INTER_EVENT_IDENTIFY = "event_id"; +const std::string ATTR_HOROVOD_ATTR_REDUCE_TYPE = "reduce_op"; +// used for allreduce tailing optimization +const std::string ATTR_NAME_HCCL_FUSED_GROUP = "_hccl_fused_group"; +const std::string ATTR_NAME_HCCL_FUSED_FLAG = "_hccl_fused_node"; } // namespace ge diff --git a/src/common/graph/op_desc.cc b/src/common/graph/op_desc.cc index 582cfa9a..ba3c9b33 100644 --- a/src/common/graph/op_desc.cc +++ b/src/common/graph/op_desc.cc @@ -878,7 +878,7 @@ graphStatus OpDesc::CommonVerify() const { // Checking shape of all inputs vector ishape = GetInputDescPtr(iname)->GetShape().GetDims(); for (int64_t dim : ishape) { - GE_CHK_BOOL_RET_STATUS(dim >= -1, GRAPH_FAILED, "operator input %s shape contains negative or zero dimension.", + GE_CHK_BOOL_RET_STATUS(dim >= -2, GRAPH_FAILED, "operator input %s shape contains negative or zero 
dimension.", iname.c_str()); } } @@ -1310,4 +1310,25 @@ OpDesc::GetSubgraphTypeByIrName(const std::string &name) const { } return iter->second; } + +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY graphStatus +OpDesc::GetSubgraphNameByInstanceName(const std::string &instance_name, std::string &subgraph_name) const { + for (size_t idx = 0; idx < subgraph_instance_names_.size(); ++idx) { + if (subgraph_instance_names_[idx] != instance_name) { // find subgraph index. + continue; + } + + for (auto name_to_index : subgraph_names_to_index_) { + if (name_to_index.second != idx) { // find subgraph name. + continue; + } + + subgraph_name = name_to_index.first; + return GRAPH_SUCCESS; + } + } + + return GRAPH_PARAM_INVALID; +} + } // namespace ge diff --git a/src/common/graph/operator.cc b/src/common/graph/operator.cc index c4ff7ac5..8adf56c1 100644 --- a/src/common/graph/operator.cc +++ b/src/common/graph/operator.cc @@ -30,9 +30,11 @@ #include "framework/common/debug/ge_log.h" #include "graph/compute_graph.h" #include "graph/ge_attr_value.h" +#include "graph/ge_context.h" #include "graph/ge_tensor.h" #include "graph/node.h" #include "graph/op_desc.h" +#include "graph/runtime_inference_context.h" #include "graph/usr_types.h" #include "utils/graph_utils.h" #include "utils/op_desc_utils.h" @@ -349,48 +351,54 @@ class OperatorImpl : public std::enable_shared_from_this { InferenceContextPtr GetInferenceContext() const { return inference_context_; } - void SubgraphRegister(const std::string &name, bool dynamic) { - op_desc_->RegisterSubgraphIrName(name, dynamic ? kDynamic : kStatic); + void SubgraphRegister(const std::string &ir_name, bool dynamic) { + op_desc_->RegisterSubgraphIrName(ir_name, dynamic ? kDynamic : kStatic); } - void SubgraphCountRegister(const std::string &name, uint32_t count) { - if (op_desc_->GetSubgraphTypeByIrName(name) == kStatic) { - op_desc_->AddSubgraphName(name); + void SubgraphCountRegister(const std::string &ir_name, uint32_t count) { + if (op_desc_->GetSubgraphTypeByIrName(ir_name) == kStatic) { + op_desc_->AddSubgraphName(ir_name); + subgraph_names_to_builders_[ir_name] = nullptr; } else { for (uint32_t i = 0; i < count; ++i) { - op_desc_->AddSubgraphName(name + std::to_string(i)); + string key_name = ir_name + std::to_string(i); + op_desc_->AddSubgraphName(key_name); + subgraph_names_to_builders_[key_name] = nullptr; } } - - subgraph_names_to_builders_[name].resize(count, nullptr); } - void SetSubgraphBuilder(const std::string &name, uint32_t index, const SubgraphBuilder &builder) { - auto iter = subgraph_names_to_builders_.find(name); - if (iter == subgraph_names_to_builders_.end()) { - GELOGE(PARAM_INVALID, "Failed to set subgraph builder for name %s index %u, invalid name", name.c_str(), index); - return; + void SetSubgraphBuilder(const std::string &ir_name, uint32_t index, const SubgraphBuilder &builder) { + string key_name = ir_name; + if (op_desc_->GetSubgraphTypeByIrName(ir_name) == kDynamic) { + key_name += std::to_string(index); } - if (iter->second.size() <= index) { - GELOGE(PARAM_INVALID, "Failed to set subgraph builder for name %s index %u, excceds the max size %zu", - name.c_str(), index, iter->second.size()); + + auto it = subgraph_names_to_builders_.find(key_name); + if (it == subgraph_names_to_builders_.end()) { + GELOGE(PARAM_INVALID, "Failed to set subgraph builder for name %s index %u.", ir_name.c_str(), index); return; } - iter->second[index] = builder; + it->second = builder; } - SubgraphBuilder GetSubgraphBuilder(const std::string &name, uint32_t 
index) const { + SubgraphBuilder GetSubgraphBuilder(const std::string &ir_name, uint32_t index) const { + string key_name = ir_name; + if (op_desc_->GetSubgraphTypeByIrName(ir_name) == kDynamic) { + key_name += std::to_string(index); + } + + return GetSubgraphBuilder(key_name); + } + + SubgraphBuilder GetSubgraphBuilder(const std::string &name) const { auto iter = subgraph_names_to_builders_.find(name); if (iter == subgraph_names_to_builders_.end()) { - GELOGE(PARAM_INVALID, "Failed to get subgraph builder for name %s index %u, invalid name", name.c_str(), index); - return nullptr; - } - if (iter->second.size() <= index) { - GELOGE(PARAM_INVALID, "Failed to get subgraph builder for name %s index %u, excceds the max size %zu", - name.c_str(), index, iter->second.size()); + GELOGE(PARAM_INVALID, "Failed to get subgraph builder for name %s", name.c_str()); return nullptr; } - return iter->second[index]; + + return iter->second; } std::vector GetSubgraphNames() const { @@ -408,12 +416,11 @@ class OperatorImpl : public std::enable_shared_from_this { private: ge::ConstNodePtr node_{nullptr}; ge::InferenceContextPtr inference_context_; - GraphBuilderCallback graph_builder_callback_; std::map> output_links_{}; std::map input_link_{}; std::vector> control_input_link_{}; std::vector> control_output_link_{}; - std::map> subgraph_names_to_builders_; + std::map subgraph_names_to_builders_; }; // Used to manage OperatorImpl instances created by ge api. @@ -582,6 +589,17 @@ graphStatus Operator::GetInputConstData(const string &dst_name, Tensor &data) co return const_op.GetAttr(op::Const::name_attr_value(), data); } } + + // Try get from runtime inference context + auto session_id = std::to_string(GetContext().SessionId()); + RuntimeInferenceContext *runtime_infer_ctx = nullptr; + if (RuntimeInferenceContext::GetContext(session_id, &runtime_infer_ctx) == GRAPH_SUCCESS) { + GELOGD("To get constant from runtime inference context. 
session_id = %s", session_id.c_str()); + auto ret = runtime_infer_ctx->GetTensor(peer_node_ptr->GetOpDesc()->GetId(), out_data_anchor->GetIdx(), data); + if (ret == GRAPH_SUCCESS) { + return GRAPH_SUCCESS; + } + } } else { // For outer graph return GetInputConstDataOut(dst_name, data); @@ -1204,25 +1222,27 @@ void Operator::SubgraphCountRegister(const std::string &name, uint32_t count) { operator_impl_->SubgraphCountRegister(name, count); } -void Operator::SetSubgraphBuilder(const std::string &name, uint32_t index, const SubgraphBuilder &builder) { +void Operator::SetSubgraphBuilder(const std::string &ir_name, uint32_t index, const SubgraphBuilder &builder) { if (operator_impl_ == nullptr) { - GELOGE(GRAPH_FAILED, "operator impl is nullptr, name %s.", name.c_str()); + GELOGE(GRAPH_FAILED, "operator impl is nullptr, name %s.", ir_name.c_str()); return; } - operator_impl_->SetSubgraphBuilder(name, index, builder); + operator_impl_->SetSubgraphBuilder(ir_name, index, builder); } std::vector Operator::GetSubgraphNames() const { return operator_impl_->GetSubgraphNames(); } -SubgraphBuilder Operator::GetDynamicSubgraphBuilder(const string &name, uint32_t index) const { +SubgraphBuilder Operator::GetDynamicSubgraphBuilder(const string &ir_name, uint32_t index) const { if (operator_impl_ == nullptr) { GELOGE(GRAPH_FAILED, "operator impl is nullptr."); return nullptr; } - return operator_impl_->GetSubgraphBuilder(name, index); + return operator_impl_->GetSubgraphBuilder(ir_name, index); } -SubgraphBuilder Operator::GetSubgraphBuilder(const string &name) const { return GetDynamicSubgraphBuilder(name, 0); } +SubgraphBuilder Operator::GetSubgraphBuilder(const string &ir_name) const { + return GetDynamicSubgraphBuilder(ir_name, 0); +} Graph Operator::GetSubgraph(const string &name) const { if (operator_impl_ == nullptr) { @@ -1307,8 +1327,8 @@ class GraphBuilderImpl { } } GE_CHK_BOOL_EXEC(!vec_inputs.empty(), return nullptr, - "User Input do not include operator such as \ - Data, Variable operator or operator that has output but no input."); + "User Input do not include operator such as " + "Data, Variable operator or operator that has output but no input."); auto ret = WalkAllOperators(vec_inputs); GE_CHK_BOOL_EXEC(ret == GRAPH_SUCCESS, return nullptr, "WalkAllOperators failed."); @@ -1361,8 +1381,67 @@ class GraphBuilderImpl { vec_op_back_forward.push_back(in_link.lock()); } que.push(vec_op_back_forward); + + if (WalkAllSubgraphs(node_ptr, op_impl) != GRAPH_SUCCESS) { + return GRAPH_FAILED; + } } } + return MoveSubgraphToRoot(graph_); + } + + graphStatus WalkAllSubgraphs(const NodePtr &node, const OperatorImplPtr &op_impl) { + const string name = node->GetName(); + for (auto &name_idx : op_impl->op_desc_->GetSubgraphNameIndexes()) { + const SubgraphBuilder &builder = op_impl->GetSubgraphBuilder(name_idx.first); + GE_CHK_BOOL_EXEC(builder != nullptr, return GRAPH_FAILED, "Node: %s, Get builder failed.", name.c_str()); + + Graph graph = builder(); // Build subgraph from user define builder. 
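+      // A SubgraphBuilder is a std::function<Graph()>: it captures whatever it
+      // needs and returns a freshly built Graph on demand. Illustrative
+      // caller-side sketch (hypothetical names, not part of this change):
+      //
+      //   SubgraphBuilder then_builder = []() -> Graph {
+      //     Graph then_graph("then_branch");
+      //     // ... create operators, then then_graph.SetInputs(...)/SetOutputs(...) ...
+      //     return then_graph;
+      //   };
+      //   if_op.SetSubgraphBuilder("then_branch", 0, then_builder);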
+ const ComputeGraphPtr &subgraph = GraphUtils::GetComputeGraph(graph); + GE_CHK_BOOL_EXEC(subgraph != nullptr, return GRAPH_FAILED, "Node: %s, Build graph failed.", name.c_str()); + + subgraph->SetParentNode(node); + subgraph->SetParentGraph(graph_); + if (graph_->AddSubgraph(subgraph->GetName(), subgraph) != GRAPH_SUCCESS) { + return GRAPH_FAILED; + } + + if (op_impl->op_desc_->SetSubgraphInstanceName(name_idx.second, subgraph->GetName()) != GRAPH_SUCCESS) { + GELOGE(GRAPH_FAILED, "Failed to set subgraph %s index %u", subgraph->GetName().c_str(), name_idx.second); + return GRAPH_FAILED; + } + } + + return GRAPH_SUCCESS; + } + + graphStatus MoveSubgraphToRoot(const ComputeGraphPtr &graph) { + const ComputeGraphPtr &root_graph = GraphUtils::FindRootGraph(graph); + if (root_graph == nullptr) { + GELOGE(GRAPH_FAILED, "Graph: %s, Find root graph failed.", graph->GetName().c_str()); + return GRAPH_FAILED; + } + + if (root_graph == graph) { + auto subgraphs = graph->GetAllSubgraphs(); + for (auto &subgraph : subgraphs) { + if (MoveSubgraphToRoot(subgraph) != GRAPH_SUCCESS) { + return GRAPH_FAILED; + } + } + } else { + auto subgraphs = graph->GetAllSubgraphs(); + for (auto &subgraph : subgraphs) { + if (root_graph->AddSubgraph(subgraph->GetName(), subgraph) != GRAPH_SUCCESS) { + return GRAPH_FAILED; + } + graph->RemoveSubgraph(subgraph->GetName()); + if (MoveSubgraphToRoot(subgraph) != GRAPH_SUCCESS) { + return GRAPH_FAILED; + } + } + } + return GRAPH_SUCCESS; } @@ -1423,11 +1502,22 @@ class GraphBuilderImpl { }; inline bool HasSameNameNode(const ComputeGraphPtr &compute_graph) { + for (const auto &graph : compute_graph->GetAllSubgraphs()) { + std::set node_names; + for (auto const &node : graph->GetDirectNode()) { + node_names.insert(node->GetName()); + } + + if (node_names.size() != graph->GetDirectNodesSize()) { + return true; + } + } + std::set node_names; - for (auto const &node : compute_graph->GetAllNodes()) { + for (auto const &node : compute_graph->GetDirectNode()) { node_names.insert(node->GetName()); } - return node_names.size() != compute_graph->GetAllNodes().size(); + return node_names.size() != compute_graph->GetDirectNodesSize(); } ComputeGraphPtr GraphUtils::CreateGraphFromOperator(const string &name, const vector &inputs) { diff --git a/src/common/graph/ref_relation.cc b/src/common/graph/ref_relation.cc index cacf213f..b3cf37af 100644 --- a/src/common/graph/ref_relation.cc +++ b/src/common/graph/ref_relation.cc @@ -136,17 +136,11 @@ graphStatus RefRelations::Impl::BuildRefRelationsForBranch( out_ref_i_all_refs.emplace_back(cell_root); for (const auto &ele : ref_o_net_nodes) { RefCell cell_netoutput_in; - RefCell cell_netoutput_out; cell_netoutput_in.node_name = (ele.first)->GetName(); cell_netoutput_in.node = ele.first; cell_netoutput_in.in_out = NODE_IN; cell_netoutput_in.in_out_idx = ele.second; - cell_netoutput_out.node_name = (ele.first)->GetName(); - cell_netoutput_out.node = ele.first; - cell_netoutput_out.in_out = NODE_OUT; - cell_netoutput_out.in_out_idx = ele.second; out_ref_i_all_refs.emplace_back(cell_netoutput_in); - out_ref_i_all_refs.emplace_back(cell_netoutput_out); } node_refs.emplace_back(out_ref_i_all_refs); ref_o++; @@ -155,6 +149,7 @@ graphStatus RefRelations::Impl::BuildRefRelationsForBranch( } graphStatus RefRelations::Impl::BuildLookUpTables() { + GELOGD("start to build look up table!"); for (size_t i = 0; i < values_.size(); i++) { vector> &val = values_[i]; for (const auto &ele : val) { @@ -216,12 +211,7 @@ graphStatus 
RefRelations::Impl::BuildRefRelationsForWhile( cell_netoutput_in.node = ele.first; cell_netoutput_in.in_out = NODE_IN; cell_netoutput_in.in_out_idx = ele.second; - cell_netoutput_out.node_name = (ele.first)->GetName(); - cell_netoutput_out.node = ele.first; - cell_netoutput_out.in_out = NODE_OUT; - cell_netoutput_out.in_out_idx = ele.second; ref_i_all_refs.emplace_back(cell_netoutput_in); - ref_i_all_refs.emplace_back(cell_netoutput_out); } node_refs.emplace_back(ref_i_all_refs); ref_i++; @@ -237,13 +227,10 @@ graphStatus RefRelations::Impl::BuildRelationsWithFuncNodeType( auto node_type = root_node->GetType(); auto status = GRAPH_SUCCESS; - if (node_type == kIf || node_type == kCase) { + if (node_type != kWhile) { status = BuildRefRelationsForBranch(root_node, classed_data_nodes, classed_netoutput_nodes, node_refs); - } else if (node_type == kWhile) { - status = BuildRefRelationsForWhile(root_node, classed_data_nodes, classed_netoutput_nodes, node_refs); } else { - GELOGE(GRAPH_PARAM_INVALID, "Node type [%s] is not supported for build ref relations!", node_type.c_str()); - status = GRAPH_PARAM_INVALID; + status = BuildRefRelationsForWhile(root_node, classed_data_nodes, classed_netoutput_nodes, node_refs); } return status; } @@ -291,6 +278,7 @@ graphStatus RefRelations::Impl::GetRootGraph(ge::ComputeGraph &graph, ge::Comput graphStatus RefRelations::Impl::ProcessSubgraphDataNodes(vector &data_nodes, vector> &classed_data_nodes) { + GELOGD("start to process subgraph data nodes!"); int max_ref_idx = 0; for (const auto &e : data_nodes) { int i; @@ -315,6 +303,7 @@ graphStatus RefRelations::Impl::ProcessSubgraphDataNodes(vector &data_n graphStatus RefRelations::Impl::ProcessSubgraphNetoutput( const vector &netoutput_nodes, vector>> &classed_netoutput_nodes) { + GELOGD("[RefRelations]Start to process subgraph netoutput!"); for (const auto &sub_netoutput_node : netoutput_nodes) { auto op_desc = sub_netoutput_node->GetOpDesc(); GE_CHECK_NOTNULL(op_desc); @@ -340,6 +329,7 @@ graphStatus RefRelations::Impl::ProcessSubgraphNetoutput( } graphStatus RefRelations::Impl::BuildRefRelations(ge::ComputeGraph &graph) { + GELOGD("Start to build ref relations!"); /* First Step: Get root graph */ ge::ComputeGraph &root_graph = graph; auto status = GetRootGraph(graph, root_graph); @@ -349,12 +339,12 @@ graphStatus RefRelations::Impl::BuildRefRelations(ge::ComputeGraph &graph) { for (const auto &node : graph.GetAllNodes()) { auto node_type = node->GetType(); - if (function_op.find(node_type) == function_op.end()) { - continue; - } std::vector ref_nodes; auto op_desc = node->GetOpDesc(); auto sub_graph_names = op_desc->GetSubgraphInstanceNames(); + if (sub_graph_names.empty()) { + continue; + } vector data_nodes; vector netoutput_nodes; // Get data and netoutput of sub_graph diff --git a/src/common/graph/runtime_inference_context.cc b/src/common/graph/runtime_inference_context.cc new file mode 100644 index 00000000..916da564 --- /dev/null +++ b/src/common/graph/runtime_inference_context.cc @@ -0,0 +1,95 @@ +/** + * Copyright 2019-2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "graph/runtime_inference_context.h"
+#include <cstdint>
+#include "framework/common/debug/ge_log.h"
+
+namespace ge {
+std::map<std::string, std::unique_ptr<RuntimeInferenceContext>> RuntimeInferenceContext::contexts_;
+std::mutex RuntimeInferenceContext::ctx_mu_;
+
+graphStatus RuntimeInferenceContext::CreateContext(const std::string &context_id) {
+  GELOGI("To create context. session id = %s", context_id.c_str());
+  auto ctx = std::unique_ptr<RuntimeInferenceContext>(new (std::nothrow) RuntimeInferenceContext());
+  if (ctx == nullptr) {
+    GELOGE(GRAPH_FAILED, "Failed to create instance of RuntimeInferenceContext. context_id = %s", context_id.c_str());
+    return GRAPH_FAILED;
+  }
+
+  // Guard the shared context map, as DestroyContext/GetContext do.
+  std::lock_guard<std::mutex> lk(ctx_mu_);
+  auto emplace_ret = contexts_.emplace(context_id, std::move(ctx));
+  if (!emplace_ret.second) {
+    GELOGE(GRAPH_FAILED, "Old context not destroyed");
+    return GRAPH_FAILED;
+  }
+
+  return GRAPH_SUCCESS;
+}
+
+void RuntimeInferenceContext::DestroyContext(const std::string &context_id) {
+  GELOGI("To destroy context. session id = %s", context_id.c_str());
+  std::lock_guard<std::mutex> lk(ctx_mu_);
+  contexts_.erase(context_id);
+}
+
+graphStatus RuntimeInferenceContext::GetContext(const std::string &context_id, RuntimeInferenceContext **ctx) {
+  std::lock_guard<std::mutex> lk(ctx_mu_);
+  auto it = contexts_.find(context_id);
+  if (it != contexts_.end()) {
+    *ctx = it->second.get();
+    return GRAPH_SUCCESS;
+  }
+
+  GELOGD("Runtime inference context not created. session id = %s", context_id.c_str());
+  return GRAPH_FAILED;
+}
+
+graphStatus RuntimeInferenceContext::SetTensor(int64_t node_id, int output_id, Tensor &&tensor) {
+  std::lock_guard<std::mutex> lk(mu_);
+  auto &output_tensors = tensors_[node_id];
+  if (static_cast<size_t>(output_id) >= output_tensors.size()) {
+    output_tensors.resize(output_id + 1);
+  }
+
+  GELOGD("Set tensor for node_id = %ld, output_id = %d", node_id, output_id);
+  output_tensors[output_id] = std::move(tensor);
+  return GRAPH_SUCCESS;
+}
+
+graphStatus RuntimeInferenceContext::GetTensor(int64_t node_id, int output_id, Tensor &tensor) {
+  if (output_id < 0) {
+    GELOGE(GRAPH_PARAM_INVALID, "Invalid output index: %d", output_id);
+    return GRAPH_PARAM_INVALID;
+  }
+
+  std::lock_guard<std::mutex> lk(mu_);
+  auto iter = tensors_.find(node_id);
+  if (iter == tensors_.end()) {
+    GELOGE(INTERNAL_ERROR, "Node not registered. Id = %ld", node_id);
+    return INTERNAL_ERROR;
+  }
+
+  auto &output_tensors = iter->second;
+  if (static_cast<size_t>(output_id) >= output_tensors.size()) {
+    GELOGE(GRAPH_FAILED, "Node output is not registered. node_id = %ld, output index = %d", node_id, output_id);
+    return GRAPH_FAILED;
+  }
+
+  GELOGD("Get tensor for node_id = %ld, output_id = %d", node_id, output_id);
+  tensor = output_tensors[output_id];
+  return GRAPH_SUCCESS;
+}
+}  // namespace ge
\ No newline at end of file
diff --git a/src/common/graph/utils/ge_ir_utils.cc b/src/common/graph/utils/ge_ir_utils.cc
index b6367011..c08ea9ab 100644
--- a/src/common/graph/utils/ge_ir_utils.cc
+++ b/src/common/graph/utils/ge_ir_utils.cc
@@ -273,6 +273,9 @@ void OnnxUtils::AddAttrProtoForOpInAndOutDesc(onnx::NodeProto *node_proto, const
   auto data_type = TypeUtils::DataTypeToSerialString(input_desc->GetDataType());
   AddAttrProto(node_proto, onnx::AttributeProto_AttributeType_STRING, "input_desc_dtype:" + std::to_string(i),
                &data_type);
+  auto data_type_origin = TypeUtils::DataTypeToSerialString(input_desc->GetOriginDataType());
+  AddAttrProto(node_proto, onnx::AttributeProto_AttributeType_STRING,
+               "input_desc_origin_dtype:" + std::to_string(i), &data_type_origin);
   auto dims = input_desc->GetShape().GetDims();
   AddAttrProto(node_proto, onnx::AttributeProto_AttributeType_INTS, "input_desc_shape:" + std::to_string(i), &dims);
@@ -346,6 +349,9 @@ void OnnxUtils::AddAttrProtoForOpInAndOutDesc(onnx::NodeProto *node_proto, const
   auto data_type = TypeUtils::DataTypeToSerialString(output_desc->GetDataType());
   AddAttrProto(node_proto, onnx::AttributeProto_AttributeType_STRING, "output_desc_dtype:" + std::to_string(i),
                &data_type);
+  auto origin_data_type = TypeUtils::DataTypeToSerialString(output_desc->GetOriginDataType());
+  AddAttrProto(node_proto, onnx::AttributeProto_AttributeType_STRING,
+               "output_desc_origin_dtype:" + std::to_string(i), &origin_data_type);
   auto dims = output_desc->GetShape().GetDims();
   AddAttrProto(node_proto, onnx::AttributeProto_AttributeType_INTS, "output_desc_shape:" + std::to_string(i), &dims);
diff --git a/src/common/graph/utils/graph_utils.cc b/src/common/graph/utils/graph_utils.cc
index c495ffc9..c4057c95 100644
--- a/src/common/graph/utils/graph_utils.cc
+++ b/src/common/graph/utils/graph_utils.cc
@@ -61,6 +61,7 @@ const char *const kDumpGraphLevel = "DUMP_GRAPH_LEVEL";
 const char *const kDumpStrBuild = "Build";
 const char *const kDumpStrPartition = "partition";
 const char *const kDumpStrOptimizeSubgraph = "OptimizeSubGraph";
+const char *const kDumpStrSubgraphFunc = "sub_graph";
 const char *const kDumpStrAicpu = "Aicpu";
};  // namespace
@@ -202,6 +203,58 @@ GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY graphStatus GraphUtils::InsertNod
   return GRAPH_SUCCESS;
}

+GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY graphStatus
+GraphUtils::RemoveSubgraphRecursively(const ComputeGraphPtr &compute_graph, const NodePtr &remove_node) {
+  GE_CHECK_NOTNULL(compute_graph);
+  if (remove_node == nullptr) {
+    GELOGE(GRAPH_FAILED, "The node ptr should not be null.");
+    return GRAPH_FAILED;
+  }
+
+  // Check whether this node belongs to this compute graph; this lookup may be a little slow
+  const auto &all_nodes_in_graph = compute_graph->GetDirectNode();
+  if (std::find(all_nodes_in_graph.begin(), all_nodes_in_graph.end(), remove_node) == all_nodes_in_graph.end()) {
+    GELOGE(GRAPH_FAILED, "Can not find node %s in graph %s.", remove_node->GetName().c_str(),
+           compute_graph->GetName().c_str());
+    return GRAPH_FAILED;
+  }
+  // Find all subgraphs of this node
+  const auto &root_graph = GraphUtils::FindRootGraph(compute_graph);
+  std::vector<ComputeGraphPtr> subgraphs;
+  std::vector<NodePtr> all_nodes;
+  std::deque<NodePtr> candidates;
+  NodePtr remove_node_new = remove_node;
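+  // Breadth-first collection: starting from the node being removed, follow every
+  // subgraph instance name through the root graph so that subgraphs nested inside
+  // other subgraphs are gathered (and later removed) as well.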
candidates.emplace_back(remove_node_new); + while (!candidates.empty()) { + const NodePtr node = candidates.front(); + all_nodes.emplace_back(node); + candidates.pop_front(); + + OpDescPtr op_desc = node->GetOpDesc(); + if (op_desc == nullptr) { + continue; + } + + const auto &subgraph_names = op_desc->GetSubgraphInstanceNames(); + for (auto name_iter = subgraph_names.rbegin(); name_iter != subgraph_names.rend(); ++name_iter) { + auto subgraph = root_graph->GetSubgraph(*name_iter); + if (subgraph != nullptr) { + subgraphs.emplace_back(subgraph); + candidates.insert(candidates.begin(), subgraph->nodes_.begin(), subgraph->nodes_.end()); + } + } + } + // Remove all subgraph + for (const auto &remove_graph : subgraphs) { + if (root_graph->RemoveSubGraph(remove_graph) != GRAPH_SUCCESS) { + GELOGE(GRAPH_FAILED, "Remove subgraph failed, sub graph name is %s, compute graph is %s.", + remove_node->GetName().c_str(), compute_graph->GetName().c_str()); + return GRAPH_FAILED; + } + } + return GRAPH_SUCCESS; +} + GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY graphStatus GraphUtils::RemoveNodeWithoutRelink(const ComputeGraphPtr &compute_graph, const NodePtr &node) { GE_CHECK_NOTNULL(compute_graph); @@ -217,12 +270,10 @@ GraphUtils::RemoveNodeWithoutRelink(const ComputeGraphPtr &compute_graph, const (void)compute_graph->RemoveOutputNode(node); // If the node has sub-graphs, delete them - auto sub_graph_names = node->GetOpDesc()->GetSubgraphInstanceNames(); - if (!sub_graph_names.empty()) { - auto root_graph = FindRootGraph(compute_graph); - for (const auto &name : sub_graph_names) { - root_graph->RemoveSubgraph(name); - } + auto ret = RemoveSubgraphRecursively(compute_graph, node); + if (ret != GRAPH_SUCCESS) { + GELOGE(GRAPH_FAILED, "Remove subgraph recursively failed."); + return GRAPH_FAILED; } auto iter = find(compute_graph->nodes_.begin(), compute_graph->nodes_.end(), node); @@ -484,9 +535,10 @@ GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY bool GraphUtils::MatchDumpStr(con return false; } - if (dump_graph_level == kDumpLevel2 && ((suffix.find(kDumpStrPartition) != std::string::npos) || - (suffix.find(kDumpStrOptimizeSubgraph) != std::string::npos) || - (suffix.find(kDumpStrAicpu) != std::string::npos))) { + if (dump_graph_level == kDumpLevel2 && + ((suffix.find(kDumpStrPartition) != std::string::npos) || + (suffix.find(kDumpStrOptimizeSubgraph) != std::string::npos) || + (suffix.find(kDumpStrAicpu) != std::string::npos) || (suffix.find(kDumpStrSubgraphFunc) != std::string::npos))) { return true; } @@ -1026,9 +1078,9 @@ graphStatus ReplaceControlAnchors(const NodePtr &new_node, const NodePtr &old_no GE_CHECK_NOTNULL(old_out_control_anchor); auto peer_in_anchors = old_out_control_anchor->GetPeerAnchors(); auto new_out_control_anchor = new_node->GetOutControlAnchor(); + GE_CHECK_NOTNULL(new_out_control_anchor); auto exists_in_anchors = new_out_control_anchor->GetPeerAnchors(); auto exists_in_anchors_set = std::set(exists_in_anchors.begin(), exists_in_anchors.end()); - GE_CHECK_NOTNULL(new_out_control_anchor); for (const auto &peer_in_anchor : peer_in_anchors) { if (peer_in_anchor != nullptr) { if (exists_in_anchors_set.count(peer_in_anchor) > 0) { @@ -1304,6 +1356,26 @@ graphStatus GraphUtils::GetRefMapping(const ComputeGraphPtr &graph, return GRAPH_SUCCESS; } +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY NodePtr GraphUtils::FindNodeFromAllNodes(ComputeGraphPtr &graph, + const std::string &name) { + auto root_graph = FindRootGraph(graph); + if (root_graph == nullptr) { + GE_LOGE("Failed find node 
%s, null root graph", name.c_str()); + return nullptr; + } + + for (const auto &node : root_graph->GetAllNodes()) { + if (node == nullptr) { + continue; + } + if (node->GetName() == name) { + return node; + } + } + + return nullptr; +} + /// /// Get reference-mapping for in_data_anchors of node /// @param [in] node @@ -1668,7 +1740,7 @@ bool GraphUtils::IsRefFromInput(const OutDataAnchorPtr &out_data_anchor, int32_t for (const auto &input_name : op_desc->GetAllInputNames()) { if (!input_name.empty() && (output_name == input_name)) { reuse_in_index = op_desc->GetInputIndexByName(input_name); - GELOGI("Reference name[%s] output[%s][%u] ref to input[%s][%d].", op_desc->GetName().c_str(), + GELOGI("Reference name[%s] output[%s][%d] ref to input[%s][%d].", op_desc->GetName().c_str(), output_name.c_str(), output_index, input_name.c_str(), reuse_in_index); return true; } @@ -1693,6 +1765,43 @@ bool GraphUtils::IsRefFromInput(const OutDataAnchorPtr &out_data_anchor, int32_t return false; } +/// +/// Determine if the graph is a UNKNOWN_SHAPE graph based on whether the graph and all subgraphs +/// of the graph have UNKNOWN_SHAPE operators or not. +/// Note: This function will only look 'down' from the graph, not 'up'. For example, the following +/// scenario (K for known shape, U for unknown shape), ROOT graph is UNKNOWN_SHAPE while SUB graph is KNOWN_SHAPE +/// ROOT graph: A -----> B -----> C +/// K subgraph U +/// | +/// V +/// SUB graph: D --> E --> F +/// K K K +/// @param [in] graph +/// @return bool +/// +bool GraphUtils::IsUnknownShapeGraph(const ComputeGraphPtr &graph) { + if (graph == nullptr) { + GELOGW("Input graph is nullptr."); + return false; + } + for (const auto &node : graph->GetDirectNode()) { + bool is_unknown = false; + auto ret = NodeUtils::GetNodeUnknownShapeStatus(*node, is_unknown); + if (ret != GRAPH_SUCCESS) { + GELOGW("Get node unknown status failed, node name:%s, type:%s.", node->GetName().c_str(), + node->GetType().c_str()); + continue; + } + if (is_unknown) { + GELOGD("Node %s, type %s is unknown shape in graph %s.", node->GetName().c_str(), node->GetType().c_str(), + graph->GetName().c_str()); + return true; + } + } + GELOGD("Graph %s does not have unknown shape node.", graph->GetName().c_str()); + return false; +} + /// /// @brief Add node to graph /// @param [in] op_desc @@ -1868,6 +1977,17 @@ NodePtr ComputeGraphBuilder::GetNode(const std::string &name) { return iter->second; } +/// @brief Get all nodes +/// @return std::vector +/// +std::vector ComputeGraphBuilder::GetAllNodes() { + std::vector nodes; + for (const auto &iter : node_names_) { + nodes.emplace_back(iter.second); + } + return nodes; +} + /// /// @brief Add node to graph /// @param [in] op_desc @@ -1937,6 +2057,16 @@ CompleteGraphBuilder &CompleteGraphBuilder::AddOutput(const std::string &owner_n return *this; } +/// +/// @brief Add target for graph +/// @param [in] target_name +/// @return CompleteGraphBuilder +/// +CompleteGraphBuilder &CompleteGraphBuilder::AddTarget(const std::string &target_name) { + graph_targets_.emplace_back(target_name); + return *this; +} + /// /// @brief Set parent-node of graph /// @param [in] parent_node @@ -2013,6 +2143,11 @@ ComputeGraphPtr CompleteGraphBuilder::Build(graphStatus &error_code, std::string return nullptr; } + BuildGraphTargets(error_code, error_msg); + if (error_code != GRAPH_SUCCESS) { + return nullptr; + } + // ATTR_NAME_SESSION_GRAPH_ID std::string graph_id; if (!AttrUtils::GetStr(parent_node_->GetOwnerComputeGraph(), ATTR_NAME_SESSION_GRAPH_ID, 
graph_id)) {
@@ -2210,6 +2345,27 @@ void CompleteGraphBuilder::AddRetValNodes(graphStatus &error_code, std::string &
   GELOGD("AddRetValNodes succ.");
 }

+///
+/// @brief Build target-nodes for graph
+/// @param [out] error_code
+/// @param [out] error_msg
+/// @return void
+///
+void CompleteGraphBuilder::BuildGraphTargets(graphStatus &error_code, std::string &error_msg) {
+  std::vector<NodePtr> target_nodes;
+  for (const std::string &target_name : graph_targets_) {
+    auto target_iter = node_names_.find(target_name);
+    if ((target_iter == node_names_.end()) || (target_iter->second == nullptr)) {
+      error_code = GRAPH_FAILED;
+      error_msg = "BuildGraphTargets failed: target_node " + target_name + " does not exist in graph.";
+      return;
+    }
+    target_nodes.emplace_back(target_iter->second);
+  }
+  owner_graph_->SetGraphTargetNodesInfo(target_nodes);
+  return;
+}
+
 ///
 /// @brief Add node to graph
 /// @param [in] op_desc
diff --git a/src/common/graph/utils/node_utils.cc b/src/common/graph/utils/node_utils.cc
index 8c2ff244..e4fb8b82 100644
--- a/src/common/graph/utils/node_utils.cc
+++ b/src/common/graph/utils/node_utils.cc
@@ -29,6 +29,13 @@
 namespace ge {
 std::map<NodePtr, std::vector<uint32_t>> NodeUtils::map_send_info_{};
 std::map<NodePtr, std::vector<uint32_t>> NodeUtils::map_recv_info_{};

+const std::set<std::string> kConstOpTypes = {"Const", "Constant"};
+
+const std::set<std::string> kIfOpTypes = {"If", "_If", "StatelessIf"};
+const std::set<std::string> kWhileOpTypes = {"While", "_While", "StatelessWhile"};
+const std::set<std::string> kCaseOpTypes = {"Case"};
+const std::set<std::string> kForOpTypes = {"For"};
+
 bool OpShapeIsUnknown(const OpDescPtr &desc) {
   for (const auto &ptr : desc->GetAllInputsDescPtr()) {
     auto ge_shape = ptr->GetShape();
@@ -315,6 +322,9 @@ GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY graphStatus NodeUtils::UpdatePeer
     peer_input_desc->SetOriginShape(output_tensor.GetOriginShape());
     peer_input_desc->SetDataType(output_tensor.GetDataType());
     peer_input_desc->SetOriginDataType(output_tensor.GetOriginDataType());
+    std::vector<std::pair<int64_t, int64_t>> shape_range;
+    (void)output_tensor.GetShapeRange(shape_range);
+    peer_input_desc->SetShapeRange(shape_range);
     ge::TensorUtils::SetRealDimCnt(*peer_input_desc,
                                    static_cast<uint32_t>(output_tensor.GetShape().GetDims().size()));
     GELOGI("Peer input opdesc name is %s, shape size is %zu, datatype is %d, original datatype is %d",
@@ -477,6 +487,14 @@ bool NodeUtils::IsSubgraphInput(const NodePtr &node) {
     return false;
   }

+  auto parent_op_desc = node->GetOwnerComputeGraph()->GetParentNode()->GetOpDesc();
+  if (parent_op_desc == nullptr) {
+    return false;
+  }
+  if (AttrUtils::HasAttr(parent_op_desc, ATTR_NAME_IS_UNKNOWN_SHAPE)) {
+    return false;
+  }
+
   return node->GetOpDesc()->HasAttr(ATTR_NAME_PARENT_NODE_INDEX);
 }
@@ -491,6 +509,14 @@ bool NodeUtils::IsSubgraphOutput(const NodePtr &node) {
     return false;
   }

+  auto parent_op_desc = node->GetOwnerComputeGraph()->GetParentNode()->GetOpDesc();
+  if (parent_op_desc == nullptr) {
+    return false;
+  }
+  if (AttrUtils::HasAttr(parent_op_desc, ATTR_NAME_IS_UNKNOWN_SHAPE)) {
+    return false;
+  }
+
   for (GeTensorDesc &tensor : node->GetOpDesc()->GetAllInputsDesc()) {
     if (AttrUtils::HasAttr(tensor, ATTR_NAME_PARENT_NODE_INDEX)) {
       return true;
@@ -557,4 +583,58 @@ bool NodeUtils::GetConstOpType(const NodePtr &in_node, std::string &op_type) {

   return false;
 }
+
+///
+/// @brief Remove node-related subgraphs, including subgraphs of nodes in the subgraph.
+/// @param [in] node
+/// @return GRAPH_SUCCESS if the subgraphs are removed successfully; another status code otherwise.
+///
+graphStatus NodeUtils::RemoveSubgraphsOnNode(const NodePtr &node) {
+  GE_CHECK_NOTNULL(node);
+  auto op_desc = node->GetOpDesc();
+  GE_CHECK_NOTNULL(op_desc);
+  auto subgraph_names = op_desc->GetSubgraphInstanceNames();
+  if (subgraph_names.empty()) {
+    return GRAPH_SUCCESS;
+  } else {
+    auto owner_graph = node->GetOwnerComputeGraph();
+    GE_CHECK_NOTNULL(owner_graph);
+    auto root_graph = GraphUtils::FindRootGraph(owner_graph);
+    GE_CHECK_NOTNULL(root_graph);
+
+    std::unordered_set<std::string> subgraph_to_remove;
+    for (auto &subgraph_name : subgraph_names) {
+      std::deque<std::string> queue;
+      queue.push_back(subgraph_name);
+      subgraph_to_remove.insert(subgraph_name);
+      op_desc->RemoveSubgraphInstanceName(subgraph_name);
+      while (!queue.empty()) {
+        auto graph_name = queue.front();
+        queue.pop_front();
+
+        auto subgraph = root_graph->GetSubgraph(graph_name);
+        GE_CHECK_NOTNULL(subgraph);
+        for (const auto &sub_node : subgraph->GetDirectNode()) {
+          auto sub_op_desc = sub_node->GetOpDesc();
+          GE_CHECK_NOTNULL(sub_op_desc);
+          auto sub_names = sub_op_desc->GetSubgraphInstanceNames();
+          // Subgraph and all nodes in it will be removed later,
+          // no need to remove 'SubgraphInstanceName' in op desc here.
+          for (auto &name : sub_names) {
+            if (subgraph_to_remove.insert(name).second) {
+              queue.push_back(name);
+            }
+          }
+        }
+      }
+    }
+    // Remove subgraphs from root_graph
+    for (const auto &name : subgraph_to_remove) {
+      GELOGI("Remove subgraph:%s.", name.c_str());
+      root_graph->RemoveSubgraph(name);
+    }
+  }
+
+  return GRAPH_SUCCESS;
+}
 }  // namespace ge
diff --git a/src/common/graph/utils/op_desc_utils.cc b/src/common/graph/utils/op_desc_utils.cc
index 32ae00cf..886a2952 100644
--- a/src/common/graph/utils/op_desc_utils.cc
+++ b/src/common/graph/utils/op_desc_utils.cc
@@ -199,6 +199,23 @@ GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY vector<ge::NodePtr> OpDescUtils::
     auto in_node = out_anchor->GetOwnerNode();
     if ((in_node->GetType() == CONSTANT) || (in_node->GetType() == CONSTANTOP)) {
       ret.push_back(in_node);
+    } else if (in_node->GetType() == DATA) {
+      const ComputeGraphPtr &graph = node.GetOwnerComputeGraph();
+      GE_CHK_BOOL_EXEC(graph != nullptr, continue, "Owner graph is null");
+
+      const NodePtr &parent_node = graph->GetParentNode();
+      if (parent_node == nullptr) {
+        continue;  // Root graph.
+      }
+
+      if (kWhileOpTypes.count(parent_node->GetType()) > 0) {
+        continue;  // Subgraph of While cond or body.
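+        // (Data nodes inside a While cond/body map to loop-carried values that
+        // can change on every iteration, so they must not be folded as constants.)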
+ } + + NodePtr input_node = NodeUtils::GetParentInput(in_node); + if ((input_node != nullptr) && ((input_node->GetType() == CONSTANT) || (input_node->GetType() == CONSTANTOP))) { + ret.push_back(input_node); + } } } return ret; diff --git a/src/common/graph/utils/type_utils.cc b/src/common/graph/utils/type_utils.cc index e8ad9ed0..7d78db5f 100644 --- a/src/common/graph/utils/type_utils.cc +++ b/src/common/graph/utils/type_utils.cc @@ -17,6 +17,8 @@ #include "graph/utils/type_utils.h" #include "debug/ge_util.h" +using domi::domiTensorFormat_t; + namespace ge { static const std::map kFormatToStringMap = { {FORMAT_NCHW, "NCHW"}, @@ -60,6 +62,25 @@ static const std::map kFormatToStringMap = { {FORMAT_RESERVED, "FORMAT_RESERVED"}, {FORMAT_ALL, "ALL"}}; +static const std::map kDomiFormatToGeFormat = { + {domi::DOMI_TENSOR_NCHW, FORMAT_NCHW}, + {domi::DOMI_TENSOR_NHWC, FORMAT_NHWC}, + {domi::DOMI_TENSOR_ND, FORMAT_ND}, + {domi::DOMI_TENSOR_NC1HWC0, FORMAT_NC1HWC0}, + {domi::DOMI_TENSOR_FRACTAL_Z, FORMAT_FRACTAL_Z}, + {domi::DOMI_TENSOR_NC1C0HWPAD, FORMAT_NC1C0HWPAD}, + {domi::DOMI_TENSOR_NHWC1C0, FORMAT_NHWC1C0}, + {domi::DOMI_TENSOR_FSR_NCHW, FORMAT_FSR_NCHW}, + {domi::DOMI_TENSOR_FRACTAL_DECONV, FORMAT_FRACTAL_DECONV}, + {domi::DOMI_TENSOR_BN_WEIGHT, FORMAT_BN_WEIGHT}, + {domi::DOMI_TENSOR_CHWN, FORMAT_CHWN}, + {domi::DOMI_TENSOR_FILTER_HWCK, FORMAT_FILTER_HWCK}, + {domi::DOMI_TENSOR_NDHWC, FORMAT_NDHWC}, + {domi::DOMI_TENSOR_NCDHW, FORMAT_NCDHW}, + {domi::DOMI_TENSOR_DHWCN, FORMAT_DHWCN}, + {domi::DOMI_TENSOR_DHWNC, FORMAT_DHWNC}, + {domi::DOMI_TENSOR_RESERVED, FORMAT_RESERVED}}; + static const std::unordered_set kInternalFormat = {"NC1HWC0", "FRACTAL_Z", "NC1C0HWPAD", @@ -282,6 +303,15 @@ Format TypeUtils::DataFormatToFormat(const std::string &str) { } } +Format TypeUtils::DomiFormatToFormat(domi::domiTensorFormat_t domi_format) { + auto it = kDomiFormatToGeFormat.find(domi_format); + if (it != kDomiFormatToGeFormat.end()) { + return it->second; + } + GELOGE(GRAPH_FAILED, "do not find domi Format %d from map", domi_format); + return FORMAT_RESERVED; +} + static inline void CopyDataFromBuffer(vector &data, const Buffer &buffer) { data.clear(); if (buffer.GetData() != nullptr && buffer.GetSize() != 0) { diff --git a/src/ge/CMakeLists.txt b/src/ge/CMakeLists.txt index ba6f1d73..97d349c0 100755 --- a/src/ge/CMakeLists.txt +++ b/src/ge/CMakeLists.txt @@ -64,6 +64,7 @@ file(GLOB TRAIN_SRC_LIST RELATIVE ${CMAKE_CURRENT_LIST_DIR} "common/helper/model_cache_helper.cc" "common/profiling/profiling_manager.cc" "engine_manager/dnnengine_manager.cc" + "executor/ge_executor.cc" "ge_local_engine/engine/host_cpu_engine.cc" "generator/ge_generator.cc" "generator/generator_api.cc" @@ -107,47 +108,61 @@ file(GLOB TRAIN_SRC_LIST RELATIVE ${CMAKE_CURRENT_LIST_DIR} "graph/partition/engine_place.cc" "graph/partition/graph_partition.cc" "graph/passes/*.cc" - "graph/passes/folding_kernel/add_kernel.cc" - "graph/passes/folding_kernel/broadcast_args_kernel.cc" - "graph/passes/folding_kernel/broadcast_gradient_args_kernel.cc" - "graph/passes/folding_kernel/cast_kernel.cc" - "graph/passes/folding_kernel/concat_offset_kernel.cc" - "graph/passes/folding_kernel/concat_v2_kernel.cc" - "graph/passes/folding_kernel/dynamic_stitch_kernel.cc" - "graph/passes/folding_kernel/empty_kernel.cc" - "graph/passes/folding_kernel/expanddims_kernel.cc" - "graph/passes/folding_kernel/fill_kernel.cc" - "graph/passes/folding_kernel/floordiv_kernel.cc" - "graph/passes/folding_kernel/floormod_kernel.cc" - 
"graph/passes/folding_kernel/gather_v2_kernel.cc" - "graph/passes/folding_kernel/greater_kernel.cc" - "graph/passes/folding_kernel/kernel_utils.cc" - "graph/passes/folding_kernel/maximum_kernel.cc" - "graph/passes/folding_kernel/mul_kernel.cc" - "graph/passes/folding_kernel/pack_kernel.cc" - "graph/passes/folding_kernel/permute_kernel.cc" - "graph/passes/folding_kernel/range_kernel.cc" - "graph/passes/folding_kernel/rank_kernel.cc" - "graph/passes/folding_kernel/reduce_prod_kernel.cc" - "graph/passes/folding_kernel/reshape_kernel.cc" - "graph/passes/folding_kernel/rsqrt_kernel.cc" - "graph/passes/folding_kernel/shape_kernel.cc" - "graph/passes/folding_kernel/shape_n_kernel.cc" - "graph/passes/folding_kernel/size_kernel.cc" - "graph/passes/folding_kernel/slice_d_kernel.cc" - "graph/passes/folding_kernel/slice_kernel.cc" - "graph/passes/folding_kernel/squeeze_kernel.cc" - "graph/passes/folding_kernel/ssd_prior_box_kernel.cc" - "graph/passes/folding_kernel/strided_slice_kernel.cc" - "graph/passes/folding_kernel/sub_kernel.cc" - "graph/passes/folding_kernel/transdata_kernel.cc" - "graph/passes/folding_kernel/unpack_kernel.cc" + "host_kernels/add_kernel.cc" + "host_kernels/broadcast_args_kernel.cc" + "host_kernels/broadcast_gradient_args_kernel.cc" + "host_kernels/cast_kernel.cc" + "host_kernels/concat_offset_kernel.cc" + "host_kernels/concat_v2_kernel.cc" + "host_kernels/dynamic_stitch_kernel.cc" + "host_kernels/empty_kernel.cc" + "host_kernels/expanddims_kernel.cc" + "host_kernels/fill_kernel.cc" + "host_kernels/floordiv_kernel.cc" + "host_kernels/floormod_kernel.cc" + "host_kernels/gather_v2_kernel.cc" + "host_kernels/greater_kernel.cc" + "host_kernels/kernel_utils.cc" + "host_kernels/maximum_kernel.cc" + "host_kernels/mul_kernel.cc" + "host_kernels/pack_kernel.cc" + "host_kernels/permute_kernel.cc" + "host_kernels/range_kernel.cc" + "host_kernels/rank_kernel.cc" + "host_kernels/reduce_prod_kernel.cc" + "host_kernels/reshape_kernel.cc" + "host_kernels/rsqrt_kernel.cc" + "host_kernels/shape_kernel.cc" + "host_kernels/shape_n_kernel.cc" + "host_kernels/size_kernel.cc" + "host_kernels/slice_d_kernel.cc" + "host_kernels/slice_kernel.cc" + "host_kernels/squeeze_kernel.cc" + "host_kernels/ssd_prior_box_kernel.cc" + "host_kernels/strided_slice_kernel.cc" + "host_kernels/sub_kernel.cc" + "host_kernels/transdata_kernel.cc" + "host_kernels/transpose_kernel.cc" + "host_kernels/unpack_kernel.cc" "graph/preprocess/graph_preprocess.cc" "graph/preprocess/insert_op/ge_aipp_op.cc" "graph/preprocess/insert_op/util_insert_aipp_op.cc" "graph/preprocess/multi_batch_copy_graph.cc" + "hybrid/common/npu_memory_allocator.cc" + "hybrid/common/tensor_value.cc" + "hybrid/executor/*.cc" + "hybrid/executor/worker/*.cc" + "hybrid/hybrid_davinci_model.cc" + "hybrid/model/*.cc" + "hybrid/node_executor/aicore/*.cc" + "hybrid/node_executor/aicpu/aicpu_node_executor.cc" + "hybrid/node_executor/compiledsubgraph/known_node_executor.cc" + "hybrid/node_executor/hostcpu/ge_local_node_executor.cc" + "hybrid/node_executor/node_executor.cc" + "hybrid/node_executor/task_context.cc" "init/gelib.cc" "model/ge_model.cc" + "model/ge_root_model.cc" "omm/csa_interact.cc" "opskernel_manager/ops_kernel_manager.cc" "session/inner_session.cc" @@ -231,42 +246,43 @@ file(GLOB INFER_SRC_LIST RELATIVE ${CMAKE_CURRENT_LIST_DIR} "graph/partition/engine_place.cc" "graph/partition/graph_partition.cc" "graph/passes/*.cc" - "graph/passes/folding_kernel/add_kernel.cc" - "graph/passes/folding_kernel/broadcast_args_kernel.cc" - 
"graph/passes/folding_kernel/broadcast_gradient_args_kernel.cc" - "graph/passes/folding_kernel/cast_kernel.cc" - "graph/passes/folding_kernel/concat_offset_kernel.cc" - "graph/passes/folding_kernel/concat_v2_kernel.cc" - "graph/passes/folding_kernel/dynamic_stitch_kernel.cc" - "graph/passes/folding_kernel/empty_kernel.cc" - "graph/passes/folding_kernel/expanddims_kernel.cc" - "graph/passes/folding_kernel/fill_kernel.cc" - "graph/passes/folding_kernel/floordiv_kernel.cc" - "graph/passes/folding_kernel/floormod_kernel.cc" - "graph/passes/folding_kernel/gather_v2_kernel.cc" - "graph/passes/folding_kernel/greater_kernel.cc" - "graph/passes/folding_kernel/kernel_utils.cc" - "graph/passes/folding_kernel/maximum_kernel.cc" - "graph/passes/folding_kernel/mul_kernel.cc" - "graph/passes/folding_kernel/pack_kernel.cc" - "graph/passes/folding_kernel/permute_kernel.cc" - "graph/passes/folding_kernel/range_kernel.cc" - "graph/passes/folding_kernel/rank_kernel.cc" - "graph/passes/folding_kernel/reduce_prod_kernel.cc" - "graph/passes/folding_kernel/reshape_kernel.cc" - "graph/passes/folding_kernel/rsqrt_kernel.cc" - "graph/passes/folding_kernel/shape_kernel.cc" - "graph/passes/folding_kernel/shape_n_kernel.cc" - "graph/passes/folding_kernel/size_kernel.cc" - "graph/passes/folding_kernel/slice_d_kernel.cc" - "graph/passes/folding_kernel/slice_kernel.cc" - "graph/passes/folding_kernel/squeeze_kernel.cc" - "graph/passes/folding_kernel/ssd_prior_box_kernel.cc" - "graph/passes/folding_kernel/strided_slice_kernel.cc" - "graph/passes/folding_kernel/sub_kernel.cc" - "graph/passes/folding_kernel/transdata_kernel.cc" - "graph/passes/folding_kernel/transpose_kernel.cc" - "graph/passes/folding_kernel/unpack_kernel.cc" + "host_kernels/add_kernel.cc" + "host_kernels/broadcast_args_kernel.cc" + "host_kernels/broadcast_gradient_args_kernel.cc" + "host_kernels/cast_kernel.cc" + "host_kernels/concat_offset_kernel.cc" + "host_kernels/concat_v2_kernel.cc" + "host_kernels/dynamic_stitch_kernel.cc" + "host_kernels/empty_kernel.cc" + "host_kernels/expanddims_kernel.cc" + "host_kernels/fill_kernel.cc" + "host_kernels/floordiv_kernel.cc" + "host_kernels/floormod_kernel.cc" + "host_kernels/gather_v2_kernel.cc" + "host_kernels/greater_kernel.cc" + "host_kernels/kernel_utils.cc" + "host_kernels/maximum_kernel.cc" + "host_kernels/mul_kernel.cc" + "host_kernels/pack_kernel.cc" + "host_kernels/permute_kernel.cc" + "host_kernels/range_kernel.cc" + "host_kernels/rank_kernel.cc" + "host_kernels/reduce_prod_kernel.cc" + "host_kernels/reshape_kernel.cc" + "host_kernels/rsqrt_kernel.cc" + "host_kernels/shape_kernel.cc" + "host_kernels/shape_n_kernel.cc" + "host_kernels/size_kernel.cc" + "host_kernels/slice_d_kernel.cc" + "host_kernels/slice_kernel.cc" + "host_kernels/squeeze_kernel.cc" + "host_kernels/ssd_prior_box_kernel.cc" + "host_kernels/strided_slice_kernel.cc" + "host_kernels/sub_kernel.cc" + "host_kernels/transdata_kernel.cc" + "host_kernels/transpose_kernel.cc" + "host_kernels/unpack_kernel.cc" + "hybrid/hybrid_davinci_model_stub.cc" "graph/preprocess/graph_preprocess.cc" "graph/preprocess/insert_op/ge_aipp_op.cc" "graph/preprocess/insert_op/util_insert_aipp_op.cc" @@ -275,6 +291,7 @@ file(GLOB INFER_SRC_LIST RELATIVE ${CMAKE_CURRENT_LIST_DIR} "ir_build/atc_ir_common.cc" "ir_build/ge_ir_build.cc" "model/ge_model.cc" + "model/ge_root_model.cc" "omm/csa_interact.cc" "opskernel_manager/ops_kernel_manager.cc" "session/inner_session.cc" diff --git a/src/ge/client/ge_api.cc b/src/ge/client/ge_api.cc index 24126de1..51a0accd 100644 --- 
a/src/ge/client/ge_api.cc +++ b/src/ge/client/ge_api.cc @@ -49,7 +49,7 @@ void GetOpsProtoPath(std::string &opsproto_path) { const char *path_env = std::getenv("ASCEND_OPP_PATH"); if (path_env != nullptr) { std::string path = path_env; - opsproto_path = (path + "/op_proto/built-in/" + ":") + (path + "/op_proto/custom/"); + opsproto_path = (path + "/op_proto/custom/" + ":") + (path + "/op_proto/built-in/"); GELOGI("Get opsproto so path from env: %s", path.c_str()); return; } @@ -57,7 +57,7 @@ void GetOpsProtoPath(std::string &opsproto_path) { GELOGI("path_base is %s", path_base.c_str()); path_base = path_base.substr(0, path_base.rfind('/')); path_base = path_base.substr(0, path_base.rfind('/') + 1); - opsproto_path = (path_base + "ops/op_proto/built-in/" + ":") + (path_base + "ops/op_proto/custom/"); + opsproto_path = (path_base + "ops/op_proto/custom/" + ":") + (path_base + "ops/op_proto/built-in/"); } Status CheckDumpAndReuseMemory(const std::map &options) { @@ -103,20 +103,6 @@ Status CheckOptionsValid(const std::map &options) { return SUCCESS; } -void SaveDdkVersion(const std::map &options) { - auto ddk_option = options.find(DDK_VERSION_FLAG); - if (ddk_option != options.end()) { - auto ddk_version = ddk_option->second; - if (!ddk_version.empty()) { - GELOGI("Input ddk version : %s.", ddk_version.c_str()); - domi::GetContext().ddk_version = ddk_version; - } - } else { - GELOGW("No ddkVersion!"); - return; - } -} - // Initialize GE, prepare for execution, call GELib::Initialize Status GEInitialize(const std::map &options) { GELOGT(TRACE_INIT, "GEInitialize start"); @@ -146,9 +132,6 @@ Status GEInitialize(const std::map &options) { } GE_TIMESTAMP_END(CheckOptionsValid, "GEInitialize::CheckOptionsValid"); - GE_TIMESTAMP_START(InitPreparation); - SaveDdkVersion(options); - GE_TIMESTAMP_END(InitPreparation, "GEInitialize::InitPreparation"); // call Initialize GELOGT(TRACE_RUNNING, "Initializing environment"); GE_TIMESTAMP_START(GELibInitialize); diff --git a/src/ge/common/convert/pb2json.cc b/src/ge/common/convert/pb2json.cc index 88b2a332..832a8278 100644 --- a/src/ge/common/convert/pb2json.cc +++ b/src/ge/common/convert/pb2json.cc @@ -22,6 +22,7 @@ #include #include "securec.h" #include "framework/common/fmk_types.h" +#include "framework/common/debug/ge_log.h" using std::set; using std::string; @@ -146,7 +147,10 @@ string Pb2Json::TypeBytes2String(string &field_name, string &type_bytes) { uint8_t *value = 0; value = reinterpret_cast(&temp_value); char str[kSignificantDigits]; - sprintf_s(str, kSignificantDigits, "%d", *value); + if (sprintf_s(str, kSignificantDigits, "%d", *value) == -1) { + GELOGW("Convert bytes to string failed, field name:%s", field_name.c_str()); + continue; + } result += str; } return result; } diff --git a/src/ge/common/formats/format_transfers/format_transfer_transpose.cc b/src/ge/common/formats/format_transfers/format_transfer_transpose.cc index ec309543..3be4d67d 100644 --- a/src/ge/common/formats/format_transfers/format_transfer_transpose.cc +++ b/src/ge/common/formats/format_transfers/format_transfer_transpose.cc @@ -21,6 +21,7 @@ #include "common/formats/utils/formats_trans_utils.h" #include "framework/common/debug/ge_log.h" +#include "framework/common/debug/log.h" #include "graph/utils/type_utils.h" namespace ge { @@ -199,6 +200,23 @@ Status TransposeWithShapeCheck(const uint8_t *data, const std::vector & return Transpose(data, src_shape, src_data_type, perm_arg, result); } +Status GetPermByForamt(Format src_format, Format dst_format, std::vector &perm) { + auto 
dst_iter = perm_args.find(src_format); + if (dst_iter == perm_args.end()) { + GELOGE(UNSUPPORTED, "Failed to trans shape, do not support transpose from format %s to %s", + TypeUtils::FormatToSerialString(src_format).c_str(), TypeUtils::FormatToSerialString(dst_format).c_str()); + return UNSUPPORTED; + } + auto iter = dst_iter->second.find(dst_format); + if (iter == dst_iter->second.end()) { + GELOGE(UNSUPPORTED, "Failed to trans shape, do not support transpose from format %s to %s", + TypeUtils::FormatToSerialString(src_format).c_str(), TypeUtils::FormatToSerialString(dst_format).c_str()); + return UNSUPPORTED; + } + perm = iter->second; + return SUCCESS; +} + Status FormatTransferTranspose::TransFormat(const TransArgs &args, TransResult &result) { std::vector expected_shape; auto ret = TransShape(args.src_format, args.src_shape, args.src_data_type, args.dst_format, expected_shape); @@ -218,23 +236,12 @@ Status FormatTransferTranspose::TransFormat(const TransArgs &args, TransResult & Status FormatTransferTranspose::TransShape(Format src_format, const std::vector &src_shape, DataType data_type, Format dst_format, std::vector &dst_shape) { - auto dst_iter = perm_args.find(src_format); - if (dst_iter == perm_args.end()) { - GELOGE(UNSUPPORTED, "Failed to trans shape, do not support transpose from format %s to %s", - TypeUtils::FormatToSerialString(src_format).c_str(), TypeUtils::FormatToSerialString(dst_format).c_str()); - return UNSUPPORTED; - } - auto iter = dst_iter->second.find(dst_format); - if (iter == dst_iter->second.end()) { - GELOGE(UNSUPPORTED, "Failed to trans shape, do not support transpose from format %s to %s", - TypeUtils::FormatToSerialString(src_format).c_str(), TypeUtils::FormatToSerialString(dst_format).c_str()); - return UNSUPPORTED; - } - - if (!IsShapeArgValid(src_shape, iter->second)) { + std::vector perm_arg; + GE_CHK_STATUS_RET_NOLOG(GetPermByForamt(src_format, dst_format, perm_arg)); + if (!IsShapeArgValid(src_shape, perm_arg)) { return PARAM_INVALID; } - dst_shape = TransShapeByPerm(src_shape, iter->second); + dst_shape = TransShapeByPerm(src_shape, perm_arg); return SUCCESS; } diff --git a/src/ge/common/formats/format_transfers/format_transfer_transpose.h b/src/ge/common/formats/format_transfers/format_transfer_transpose.h index 476ef024..0e84ef8c 100644 --- a/src/ge/common/formats/format_transfers/format_transfer_transpose.h +++ b/src/ge/common/formats/format_transfers/format_transfer_transpose.h @@ -31,6 +31,8 @@ Status TransposeWithShapeCheck(const uint8_t *src, const std::vector &s const std::vector &dst_shape, DataType src_data_type, const std::vector &perm_arg, TransResult &result); +Status GetPermByForamt(Format src_format, Format dst_format, std::vector &perm); + class FormatTransferTranspose : public FormatTransfer { public: Status TransFormat(const TransArgs &args, TransResult &result) override; diff --git a/src/ge/common/helper/model_helper.cc b/src/ge/common/helper/model_helper.cc index 194ea59f..facaabdf 100644 --- a/src/ge/common/helper/model_helper.cc +++ b/src/ge/common/helper/model_helper.cc @@ -180,8 +180,7 @@ ModelHelper::SaveOriginalGraphToOmModel(const ge::Graph &graph, const std::strin GELOGE(FAILED, "SaveModel fail for compute_graph null"); return FAILED; } - ge::GraphUtils::DumpGEGraph(compute_graph, "OriginalGraph"); - ge::GraphUtils::DumpGEGraphToOnnx(*compute_graph, "OriginalGraph"); + GE_DUMP(compute_graph, "OriginalGraph"); // Model ModelPtr model_ptr = ge::MakeShared(); GE_CHECK_NOTNULL_EXEC(model_ptr, return MEMALLOC_FAILED); diff 
--git a/src/ge/common/profiling/profiling_manager.cc b/src/ge/common/profiling/profiling_manager.cc index 8422ebf6..8a29f0b4 100644 --- a/src/ge/common/profiling/profiling_manager.cc +++ b/src/ge/common/profiling/profiling_manager.cc @@ -74,14 +74,17 @@ FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY ge::Status ProfilingManager::In return FAILED; } // profiling startup first time + GELOGI("Begin to init profiling, device num %zu", device_id_.size()); for (size_t i = 0; i < device_id_.size(); ++i) { ret = StartProfiling(0, device_id_[i]); if (ret != SUCCESS) { - GELOGE(ret, "Profiling start failed."); + GELOGE(ret, "Profiling start failed on device %d.", device_id_[i]); return FAILED; } - GELOGI("Profiling init succ."); + GELOGI("Profiling init succ on device %d.", device_id_[i]); } + } else { + GELOGI("The profiling is off, skip the initialization"); } #endif return SUCCESS; @@ -164,7 +167,7 @@ FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY ge::Status ProfilingManager::In } is_profiling_ = true; - } catch (Json::parse_error &e) { + } catch (...) { GELOGE(FAILED, "Json conf is not invalid !"); return ge::PARAM_INVALID; } @@ -274,7 +277,7 @@ FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY ge::Status ProfilingManager::St ss << start_cfg; send_profiling_config_ = ss.str(); GELOGI("Profiling config %s\n", send_profiling_config_.c_str()); - } catch (Json::parse_error &e) { + } catch (...) { GELOGE(FAILED, "Op trace json conf is not invalid !"); return FAILED; } diff --git a/src/ge/common/types.cc b/src/ge/common/types.cc index 26668c70..751e36b7 100644 --- a/src/ge/common/types.cc +++ b/src/ge/common/types.cc @@ -389,6 +389,7 @@ REGISTER_OPTYPE_DEFINE(STREAMMERGE, "StreamMerge"); REGISTER_OPTYPE_DEFINE(ENDGRAPH, "EndGraph"); REGISTER_OPTYPE_DEFINE(SEND, "Send"); REGISTER_OPTYPE_DEFINE(RECV, "Recv"); +REGISTER_OPTYPE_DEFINE(ENDOFSEQUENCE, "EndOfSequence"); REGISTER_OPTYPE_DEFINE(LABELSET, "LabelSet"); REGISTER_OPTYPE_DEFINE(LABELGOTO, "LabelGoto"); @@ -456,6 +457,12 @@ REGISTER_OPTYPE_DEFINE(SIGMOIDGRAD, "SigmoidGrad"); REGISTER_OPTYPE_DEFINE(TRANSSHAPE, "TransShape"); +// Horovod operator +REGISTER_OPTYPE_DEFINE(HVDCALLBACKALLREDUCE, "HorovodAllreduce"); +REGISTER_OPTYPE_DEFINE(HVDCALLBACKALLGATHER, "HorovodAllgather"); +REGISTER_OPTYPE_DEFINE(HVDCALLBACKBROADCAST, "HorovodBroadcast"); +REGISTER_OPTYPE_DEFINE(HVDWAIT, "HorovodWait"); + const std::string MODEL_ATTR_TASKS = "tasks"; const std::string MODEL_ATTR_TASK_GEN_BASE_ADDR = "task_gen_base_addr"; const std::string MODEL_ATTR_TASK_GEN_WEIGHT_ADDR = "task_gen_weight_addr"; diff --git a/src/ge/common/util.cc b/src/ge/common/util.cc index b53a1c43..0a6561c4 100644 --- a/src/ge/common/util.cc +++ b/src/ge/common/util.cc @@ -67,8 +67,9 @@ static bool ReadProtoFromCodedInputStream(CodedInputStream &coded_stream, Messag } FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY bool ReadProtoFromBinaryFile(const char *file, Message *proto) { - GE_CHK_BOOL_TRUE_EXEC_WITH_LOG((file == nullptr || proto == nullptr), return false, - "incorrect parameter. 
nullptr == file || nullptr == proto"); + GE_CHK_BOOL_TRUE_EXEC_WITH_LOG((file == nullptr || proto == nullptr), + ErrorManager::GetInstance().ATCReportErrMessage("E19001"); + return false, "Input parameter file or proto is nullptr!"); std::string real_path = RealPath(file); GE_CHK_BOOL_TRUE_EXEC_WITH_LOG(real_path.empty(), return false, "pb file path '%s' not valid", file); @@ -77,7 +78,8 @@ FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY bool ReadProtoFromBinaryFile(co std::ifstream fs(real_path, std::ifstream::in | std::ifstream::binary); if (!fs.is_open()) { - GELOGE(ge::FAILED, "Open %s failed.", file); + ErrorManager::GetInstance().ATCReportErrMessage("E19004", {"realpath"}, {file}); + GELOGE(ge::FAILED, "Open real path[%s] failed.", file); return false; } @@ -89,7 +91,8 @@ FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY bool ReadProtoFromBinaryFile(co fs.close(); if (!ret) { - GELOGE(ge::FAILED, "Parse %s failed.", file); + ErrorManager::GetInstance().ATCReportErrMessage("E19005", {"filepath"}, {file}); + GELOGE(ge::FAILED, "Parse file[%s] failed.", file); return ret; } @@ -113,17 +116,17 @@ long GetFileLength(const std::string &input_file) { GE_CHK_BOOL_TRUE_EXEC_WITH_LOG(real_path.empty(), return -1, "input_file path '%s' not valid", input_file.c_str()); unsigned long long file_length = 0; GE_CHK_BOOL_TRUE_EXEC_WITH_LOG(mmGetFileSize(input_file.c_str(), &file_length) != EN_OK, - ErrorManager::GetInstance().ATCReportErrMessage("E10037"); - return -1, "open file failed."); + ErrorManager::GetInstance().ATCReportErrMessage("E10037", {"filepath"}, {input_file}); + return -1, "Open file[%s] failed", input_file.c_str()); GE_CHK_BOOL_TRUE_EXEC_WITH_LOG((file_length == 0), ErrorManager::GetInstance().ATCReportErrMessage("E10038"); - return -1, "file length is 0, not valid."); + return -1, "File[%s] length is 0, not valid.", input_file.c_str()); GE_CHK_BOOL_TRUE_EXEC_WITH_LOG( - file_length > kMaxFileSizeLimit, - ErrorManager::GetInstance().ATCReportErrMessage("E10039", {"filesize", "maxlen"}, - {std::to_string(file_length), std::to_string(kMaxFileSizeLimit)}); - return -1, "file size %lld is out of limit: %d.", file_length, kMaxFileSizeLimit); + file_length > kMaxFileSizeLimit, ErrorManager::GetInstance().ATCReportErrMessage( + "E10039", {"filepath", "filesize", "maxlen"}, + {input_file, std::to_string(file_length), std::to_string(kMaxFileSizeLimit)}); + return -1, "File[%s] size %lld is out of limit: %d.", input_file.c_str(), file_length, kMaxFileSizeLimit); return static_cast(file_length); } @@ -202,7 +205,9 @@ FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY int CreateDirectory(const std:: GE_CHK_BOOL_EXEC(!directory_path.empty(), return -1, "directory path is empty."); auto dir_path_len = directory_path.length(); if (dir_path_len >= PATH_MAX) { - GELOGW("Directory path is too long."); + ErrorManager::GetInstance().ATCReportErrMessage("E19002", {"filepath", "size"}, + {directory_path, std::to_string(PATH_MAX)}); + GELOGW("Path[%s] len is too long, it must smaller than %d", directory_path.c_str(), PATH_MAX); return -1; } char tmp_dir_path[PATH_MAX] = {0}; @@ -213,8 +218,8 @@ FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY int CreateDirectory(const std:: int32_t ret = mmMkdir(tmp_dir_path, S_IRUSR | S_IWUSR | S_IXUSR); // 700 if (ret != 0) { if (errno != EEXIST) { - GELOGW("Cannot create directory %s. 
Make sure that the directory exists and writable.", - directory_path.c_str()); + ErrorManager::GetInstance().ATCReportErrMessage("E19006", {"path"}, {directory_path}); + GELOGW("Cannot create directory %s. Make sure the directory exists and is writable.", directory_path.c_str()); return ret; } } @@ -224,7 +229,8 @@ FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY int CreateDirectory(const std:: int32_t ret = mmMkdir(const_cast(directory_path.c_str()), S_IRUSR | S_IWUSR | S_IXUSR); // 700 if (ret != 0) { if (errno != EEXIST) { - GELOGW("Cannot create directory %s. Make sure that the directory exists and writable.", directory_path.c_str()); + ErrorManager::GetInstance().ATCReportErrMessage("E19006", {"path"}, {directory_path}); + GELOGW("Cannot create directory %s. Make sure the directory exists and is writable.", directory_path.c_str()); return ret; } } @@ -253,16 +259,17 @@ FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY bool ReadProtoFromText(const ch std::string real_path = RealPath(file); GE_CHK_BOOL_TRUE_EXEC_WITH_LOG(real_path.empty(), - ErrorManager::GetInstance().ATCReportErrMessage("E10036", {"realpath"}, {file}); - return false, "proto file real path '%s' not valid", file); + ErrorManager::GetInstance().ATCReportErrMessage("E10036", {"filepath"}, {file}); + return false, "Get path[%s]'s real path failed", file); GE_CHK_BOOL_TRUE_EXEC_WITH_LOG(GetFileLength(real_path) == -1, return false, "file size not valid."); std::ifstream fs(real_path.c_str(), std::ifstream::in); if (!fs.is_open()) { - ErrorManager::GetInstance().ATCReportErrMessage("E10040", {"protofile"}, {file}); - GELOGE(ge::FAILED, "Fail to open proto file '%s'.", file); + ErrorManager::GetInstance().ATCReportErrMessage("E10040", {"realpth", "protofile"}, {real_path, file}); + GELOGE(ge::FAILED, "Fail to open proto file, real path is '%s', original file path is '%s'.", real_path.c_str(), + file); return false; } @@ -328,18 +335,21 @@ FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY bool CheckInt64MulOverflow(int6 FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY std::string RealPath(const char *path) { GE_CHK_BOOL_TRUE_EXEC_WITH_LOG(path == nullptr, return "", "path pointer is NULL."); - GE_CHK_BOOL_TRUE_EXEC_WITH_LOG(strlen(path) >= PATH_MAX, return "", "path is invalid"); + GE_CHK_BOOL_TRUE_EXEC_WITH_LOG( + strlen(path) >= PATH_MAX, + ErrorManager::GetInstance().ATCReportErrMessage("E19002", {"filepath", "size"}, {path, std::to_string(PATH_MAX)}); + return "", "Path[%s] len is too long, it must be smaller than %d", path, PATH_MAX); // PATH_MAX is the system's own macro, indicating the maximum file path length supported std::shared_ptr resolved_path(new (std::nothrow) char[PATH_MAX](), std::default_delete()); - if (resolved_path == nullptr) { - GELOGW("new an PATH_MAX string object failed."); - return ""; - } - std::string res; + GE_CHK_BOOL_TRUE_EXEC_WITH_LOG( + resolved_path == nullptr, + ErrorManager::GetInstance().ATCReportErrMessage("E19003", {"filepath", "size"}, {path, std::to_string(PATH_MAX)}); + return "", "Path[%s]: new string object of len[%d] failed.", path, PATH_MAX); // Nullptr is returned when the path does not exist or there is no permission // Return absolute path when path is accessible + std::string res; if (realpath(path, resolved_path.get()) != nullptr) { res = resolved_path.get(); } @@ -360,7 +370,7 @@ FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY bool CheckInputPathValid(const // Unable to get absolute path (does not exist or does not have permission to access) if (real_path.empty()) { 
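// Reviewer note (hedged; assumes the usual shape of the GE_CHK_BOOL_TRUE_EXEC_WITH_LOG
// macro): in the calls above, the structured ATC report and the early return travel
// together in the macro's second argument, glued by ';', while the trailing varargs
// feed GELOGE. A call such as
//   GE_CHK_BOOL_TRUE_EXEC_WITH_LOG(
//       strlen(path) >= PATH_MAX,
//       ErrorManager::GetInstance().ATCReportErrMessage("E19002", {"filepath", "size"}, {path, std::to_string(PATH_MAX)});
//       return "", "Path[%s] len is too long, it must be smaller than %d", path, PATH_MAX);
// therefore expands roughly to:
//   if (strlen(path) >= PATH_MAX) { GELOGE(...); ATCReportErrMessage(...); return ""; }
// The keys vector and the values vector passed to ATCReportErrMessage are expected to
// have the same length, presumably one value per placeholder of the registered message.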
ErrorManager::GetInstance().ATCReportErrMessage("E10002", {"path", "errmsg"}, {file_path.c_str(), strerror(errno)}); - GELOGW("Can not get real path for %s, %s", file_path.c_str(), strerror(errno)); + GELOGW("Path[%s]'s realpath is empty, errmsg[%s]", file_path.c_str(), strerror(errno)); return false; } @@ -381,7 +391,7 @@ FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY bool CheckInputPathValid(const // The absolute path points to a file that is not readable if (access(real_path.c_str(), R_OK) != 0) { ErrorManager::GetInstance().ATCReportErrMessage("E10003", {"path", "errmsg"}, {file_path.c_str(), strerror(errno)}); - GELOGW("Read path[%s] failed, %s", file_path.c_str(), strerror(errno)); + GELOGW("Read path[%s] failed, errmsg[%s]", file_path.c_str(), strerror(errno)); return false; } @@ -416,9 +426,10 @@ FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY bool CheckOutputPathValid(const // File is not readable or writable if (access(real_path.c_str(), W_OK | F_OK) != 0) { - ErrorManager::GetInstance().ATCReportErrMessage("E10004", {"path", "errmsg"}, - {real_path.c_str(), strerror(errno)}); - GELOGW("Write file failed, path[%s], %s", real_path.c_str(), strerror(errno)); + ErrorManager::GetInstance().ATCReportErrMessage("E10004", {"realpath", "path", "errmsg"}, + {real_path, file_path, strerror(errno)}); + GELOGW("Write file[%s] failed, input path is %s, errmsg[%s]", real_path.c_str(), file_path.c_str(), + strerror(errno)); return false; } } else { diff --git a/src/ge/executor/CMakeLists.txt b/src/ge/executor/CMakeLists.txt index 8512904c..90b091d2 100755 --- a/src/ge/executor/CMakeLists.txt +++ b/src/ge/executor/CMakeLists.txt @@ -59,12 +59,15 @@ file(GLOB SRC_LIST RELATIVE ${CMAKE_CURRENT_LIST_DIR} "../graph/load/new_model_manager/tbe_handle_store.cc" "../graph/load/new_model_manager/zero_copy_task.cc" "../graph/load/output/output.cc" + "../graph/manager/graph_caching_allocator.cc" "../graph/manager/graph_manager_utils.cc" "../graph/manager/graph_mem_allocator.cc" "../graph/manager/graph_var_manager.cc" "../graph/manager/trans_var_data_utils.cc" "../graph/manager/util/debug.cc" + "../hybrid/hybrid_davinci_model_stub.cc" "../model/ge_model.cc" + "../model/ge_root_model.cc" "../omm/csa_interact.cc" "../single_op/single_op.cc" "../single_op/single_op_manager.cc" diff --git a/src/ge/ge_local_engine/ops_kernel_store/ge_local_ops_kernel_info.cc b/src/ge/ge_local_engine/ops_kernel_store/ge_local_ops_kernel_info.cc index cde6640f..857cee6b 100644 --- a/src/ge/ge_local_engine/ops_kernel_store/ge_local_ops_kernel_info.cc +++ b/src/ge/ge_local_engine/ops_kernel_store/ge_local_ops_kernel_info.cc @@ -20,6 +20,7 @@ #include "common/ge/ge_util.h" #include "common/ge_inner_error_codes.h" #include "framework/common/debug/ge_log.h" +#include "graph/utils/node_utils.h" #include "graph/utils/tensor_utils.h" #include "graph/utils/type_utils.h" #include "op/op_factory.h" @@ -68,6 +69,15 @@ Status GeLocalOpsKernelInfoStore::CalcOpRunningParam(Node &ge_node) { GELOGE(FAILED, "CalcOpRunningParam failed, as op desc is null"); return FAILED; } + + bool is_shape_unknown = false; + if (NodeUtils::GetNodeUnknownShapeStatus(ge_node, is_shape_unknown) == GRAPH_SUCCESS) { + if (is_shape_unknown) { + GELOGI("op:%s is unknown shape, does not need to calc output size.", ge_node.GetName().c_str()); + return SUCCESS; + } + } + const string node_name = ge_node.GetName(); const string node_type = ge_node.GetType(); size_t output_size = op_desc->GetOutputsSize(); @@ -157,6 +167,13 @@ Status 
GeLocalOpsKernelInfoStore::CalcConstantStrMemSize(const OpDescPtr &op_des void GeLocalOpsKernelInfoStore::GetAllOpsKernelInfo(map &infos) const { infos = op_info_map_; } Status GeLocalOpsKernelInfoStore::GenerateTask(const Node &node, RunContext &context, vector &tasks) { + bool is_shape_unknown = false; + if (NodeUtils::GetNodeUnknownShapeStatus(node, is_shape_unknown) == GRAPH_SUCCESS) { + if (is_shape_unknown) { + GELOGI("op:%s is unknown shape, does not need to generate task", node.GetName().c_str()); + return SUCCESS; + } + } string name = node.GetName(); string type = node.GetType(); GELOGD("Ge local generate task for node:%s(%s) begin, tasks.size()=%zu.", name.c_str(), type.c_str(), tasks.size()); diff --git a/src/ge/ge_runtime/task/hccl_task.cc b/src/ge/ge_runtime/task/hccl_task.cc index 7a513597..3d5f8504 100644 --- a/src/ge/ge_runtime/task/hccl_task.cc +++ b/src/ge/ge_runtime/task/hccl_task.cc @@ -128,17 +128,17 @@ bool HcclTask::Distribute() { ge_task.type = static_cast(RT_MODEL_TASK_HCCL); ge_task.stream = stream_; - ge_task.kernelHcclInfo.hccl_type = task_info_->hccl_type(); - ge_task.kernelHcclInfo.inputDataAddr = task_info_->input_data_addr(); - ge_task.kernelHcclInfo.outputDataAddr = task_info_->output_data_addr(); - ge_task.kernelHcclInfo.workSpaceAddr = task_info_->workspace_addr(); - ge_task.kernelHcclInfo.workSpaceMemSize = task_info_->workspace_size(); - ge_task.kernelHcclInfo.count = task_info_->count(); - ge_task.kernelHcclInfo.dataType = static_cast(task_info_->data_type()); - ge_task.kernelHcclInfo.opType = static_cast(task_info_->op_type()); - ge_task.kernelHcclInfo.rootId = task_info_->root_id(); - - ge_task.kernelHcclInfo.hcclStreamList = slave_stream_list_; + ge_task.kernelHcclInfo[0].hccl_type = task_info_->hccl_type(); + ge_task.kernelHcclInfo[0].inputDataAddr = task_info_->input_data_addr(); + ge_task.kernelHcclInfo[0].outputDataAddr = task_info_->output_data_addr(); + ge_task.kernelHcclInfo[0].workSpaceAddr = task_info_->workspace_addr(); + ge_task.kernelHcclInfo[0].workSpaceMemSize = task_info_->workspace_size(); + ge_task.kernelHcclInfo[0].count = task_info_->count(); + ge_task.kernelHcclInfo[0].dataType = static_cast(task_info_->data_type()); + ge_task.kernelHcclInfo[0].opType = static_cast(task_info_->op_type()); + ge_task.kernelHcclInfo[0].rootId = task_info_->root_id(); + + ge_task.kernelHcclInfo[0].hcclStreamList = slave_stream_list_; ge_task.privateDef = private_def; ge_task.privateDefLen = private_def_len; diff --git a/src/ge/generator/ge_generator.cc b/src/ge/generator/ge_generator.cc index d4a33eec..f25a67cd 100644 --- a/src/ge/generator/ge_generator.cc +++ b/src/ge/generator/ge_generator.cc @@ -27,6 +27,7 @@ #include "graph/opsproto_manager.h" #include "graph/utils/graph_utils.h" #include "model/ge_model.h" +#include "init/gelib.h" using std::map; using std::string; @@ -34,9 +35,79 @@ using std::vector; namespace { const char *const kAttrOpType = "op_type"; -} +const char *const kEngineNameDefault = "default"; +const char *const kVectorEngine = "VectorEngine"; +const char *const kAIcoreEngine = "AIcoreEngine"; +const char *const kFileNameSuffix = "online"; + +std::map engine_type_map{ + {ge::ENGINE_SYS, kEngineNameDefault}, {ge::ENGINE_AICORE, kAIcoreEngine}, {ge::ENGINE_VECTOR, kVectorEngine}}; +} // namespace namespace ge { +static Status CheckEngineTypeSupport(const OpDescPtr &op_desc, OpEngineType engine_type) { + GE_CHECK_NOTNULL_EXEC(op_desc, return PARAM_INVALID); + if (engine_type == ENGINE_SYS) { + GELOGI("CheckEngineType: use default 
engine."); + return SUCCESS; + } + // get op engine name + string op_engine_name; + auto iter = engine_type_map.find(engine_type); + if (iter != engine_type_map.end()) { + op_engine_name = iter->second; + GELOGI("CheckEngineType: engine type: %d", static_cast(engine_type)); + } else { + GELOGE(FAILED, "CheckEngineType: engine type: %d not support", static_cast(engine_type)); + return FAILED; + } + // set op engine name and opkernelLib. when engine support + std::shared_ptr instance_ptr = ge::GELib::GetInstance(); + if ((instance_ptr == nullptr) || (!instance_ptr->InitFlag())) { + GELOGE(GE_CLI_GE_NOT_INITIALIZED, "CheckEngineType failed."); + return FAILED; + } + OpsKernelManager &ops_kernel_manager = instance_ptr->OpsKernelManagerObj(); + std::vector op_infos = ops_kernel_manager.GetOpsKernelInfo(op_desc->GetType()); + if (op_infos.empty()) { + GELOGE(FAILED, "CheckEngineType: Can not get op info by op type %s", op_desc->GetType().c_str()); + return FAILED; + } + string kernel_name; + for (const auto &it : op_infos) { + if (it.engine == op_engine_name) { + kernel_name = it.opKernelLib; + break; + } + } + if (kernel_name.empty()) { + GELOGE(FAILED, "CheckEngineType:Can not find ops kernel,engine name: %s.", op_engine_name.c_str()); + return FAILED; + } + auto &kernel_map = ops_kernel_manager.GetAllOpsKernelInfoStores(); + auto kernel_info_store = kernel_map.find(kernel_name); + if (kernel_info_store != kernel_map.end()) { + std::string unsupported_reason; + if (kernel_info_store->second->CheckSupported(op_desc, unsupported_reason)) { + op_desc->SetOpEngineName(op_engine_name); + op_desc->SetOpKernelLibName(kernel_name); + GELOGI("CheckEngineType:Set OpKernelLibName %s and engine name %s into op_desc %s", kernel_name.c_str(), + op_engine_name.c_str(), op_desc->GetName().c_str()); + return SUCCESS; + } else { + GELOGE(FAILED, "CheckEngineType: check support failed, Op type %s of ops kernel %s is unsupported, reason:%s", + op_desc->GetType().c_str(), kernel_name.c_str(), unsupported_reason.c_str()); + return FAILED; + } + } else { + GELOGE(FAILED, + "CheckEngineType:Can not find any supported ops kernel info store by kernel_name %s," + "op type is %s, op name is %s", + kernel_name.c_str(), op_desc->GetType().c_str(), op_desc->GetName().c_str()); + } + return FAILED; +} + static Status AddInputs(const ComputeGraphPtr &graph, const NodePtr &node, const GeTensorDesc &tensor, int32_t index, bool attr) { GE_CHECK_NOTNULL_EXEC(graph, return PARAM_INVALID); @@ -96,7 +167,7 @@ static Status AddOutputs(const ComputeGraphPtr &graph, const NodePtr &node, cons } static void GetOpsProtoPath(string &opsproto_path) { - GELOGI("Start to get ops proto path schedule"); + GELOGI("Start to get ops proto path schedule."); const char *path_env = std::getenv("ASCEND_OPP_PATH"); if (path_env != nullptr) { string path = path_env; @@ -105,7 +176,7 @@ static void GetOpsProtoPath(string &opsproto_path) { GELOGE(FAILED, "File path %s is invalid.", path.c_str()); return; } - opsproto_path = (path + "/op_proto/built-in/" + ":") + (path + "/op_proto/custom/"); + opsproto_path = (path + "/op_proto/custom/" + ":") + (path + "/op_proto/built-in/"); GELOGI("Get opsproto so path from env : %s", path.c_str()); return; } @@ -113,15 +184,14 @@ static void GetOpsProtoPath(string &opsproto_path) { GELOGI("path_base is %s", path_base.c_str()); path_base = path_base.substr(0, path_base.rfind('/')); path_base = path_base.substr(0, path_base.rfind('/') + 1); - opsproto_path = (path_base + "ops/op_proto/built-in/" + ":") + (path_base + 
"ops/op_proto/custom/"); + opsproto_path = (path_base + "ops/op_proto/custom/" + ":") + (path_base + "ops/op_proto/built-in/"); } class GeGenerator::Impl { public: - Status BuildModel(const Graph &graph, const vector &inputs, GraphId &graph_id, - vector &ge_models); + Status BuildModel(const Graph &graph, const vector &inputs, GraphId &graph_id, GeRootModelPtr &ge_models); - Status SaveModel(const string &file_name_prefix, vector &models, ModelBufferData &model); + Status SaveModel(const string &file_name_prefix, GeModelPtr &models, ModelBufferData &model); Status SaveParams(GeModelPtr &ge_model, const string &type, const map &attrs, const vector &inputs, const vector &outputs); @@ -141,7 +211,7 @@ Status GeGenerator::Initialize(const map &options) { } string opsproto_path; GetOpsProtoPath(opsproto_path); - GELOGI("opsproto_path is %s", opsproto_path.c_str()); + GELOGI("Get opsproto path is %s", opsproto_path.c_str()); OpsProtoManager *manager = OpsProtoManager::Instance(); map option_tmp; option_tmp.emplace(std::pair(string("ge.opsProtoLibPath"), opsproto_path)); @@ -149,7 +219,7 @@ Status GeGenerator::Initialize(const map &options) { Status ret = impl_->graph_manager_.Initialize(options); if (ret != SUCCESS) { - GELOGE(GE_GENERATOR_GRAPH_MANAGER_INIT_FAILED, "Graph manager initialize failed"); + GELOGE(GE_GENERATOR_GRAPH_MANAGER_INIT_FAILED, "Graph manager initialize failed."); return GE_GENERATOR_GRAPH_MANAGER_INIT_FAILED; } // get ek file @@ -179,7 +249,7 @@ Status GeGenerator::Finalize() { GE_CHECK_NOTNULL_EXEC(impl_, return PARAM_INVALID); Status ret = impl_->graph_manager_.Finalize(); if (ret != SUCCESS) { - GELOGE(GE_GENERATOR_GRAPH_MANAGER_FINALIZE_FAILED, "Graph manager finalize failed"); + GELOGE(GE_GENERATOR_GRAPH_MANAGER_FINALIZE_FAILED, "Graph manager finalize failed."); return GE_GENERATOR_GRAPH_MANAGER_FINALIZE_FAILED; } return SUCCESS; @@ -187,7 +257,7 @@ Status GeGenerator::Finalize() { Status GeGenerator::GenerateOfflineModel(const Graph &graph, const string &file_name_prefix, const vector &inputs) { - GELOGI("Start to GenerateOfflineModel."); + GELOGI("Start to generate offline model."); ModelBufferData model; return GenerateModel(graph, file_name_prefix, inputs, model, true); } @@ -208,25 +278,22 @@ Status GeGenerator::GenerateInfershapeGraph(const Graph &graph) { } return ret; } - GELOGI("GenerateInfershapeJson success."); + GELOGI("GenerateInfershapeGraph success."); return SUCCESS; } Status GeGenerator::GenerateModel(const Graph &graph, const string &file_name_prefix, const vector &inputs, ModelBufferData &model, bool is_offline) { GraphId graph_id; - vector ge_models; + GeRootModelPtr ge_root_model = nullptr; GE_CHECK_NOTNULL_EXEC(impl_, return PARAM_INVALID); - - string model_name; - auto compute_graph = GraphUtils::GetComputeGraph(graph); - if (compute_graph == nullptr) { - GELOGW("Get compute graph fail."); - } else { - model_name = compute_graph->GetName(); - } + // using output as model_name (ignore ".om") + int start_position = file_name_prefix.find_last_of('/') + 1; + int end_position = file_name_prefix.length() - 3; + const string model_name = file_name_prefix.substr(start_position, end_position - start_position); + GE_CHK_BOOL_TRUE_EXEC_WITH_LOG(model_name.empty(), return PARAM_INVALID, "om name is not valid!"); impl_->is_offline_ = is_offline; - Status ret = impl_->BuildModel(graph, inputs, graph_id, ge_models); + Status ret = impl_->BuildModel(graph, inputs, graph_id, ge_root_model); if (ret != SUCCESS) { GELOGE(ret, "Build model failed"); if 
(impl_->graph_manager_.Finalize() != SUCCESS) { @@ -234,11 +301,14 @@ Status GeGenerator::GenerateModel(const Graph &graph, const string &file_name_pr } return ret; } - - if (!model_name.empty() && !ge_models.empty()) { - ge_models[0]->SetName(model_name); - } - ret = impl_->SaveModel(file_name_prefix, ge_models, model); + GE_CHECK_NOTNULL(ge_root_model); + GE_CHECK_NOTNULL(ge_root_model->GetRootGraph()); + map name_to_ge_model = ge_root_model->GetSubgraphInstanceNameToModel(); + GeModelPtr &ge_model = name_to_ge_model[ge_root_model->GetRootGraph()->GetName()]; + + GE_RETURN_WITH_LOG_IF_FALSE(ge_model != nullptr, "ge_model can not be null"); + ge_model->SetName(model_name); + ret = impl_->SaveModel(file_name_prefix, ge_model, model); if (ret != SUCCESS) { GELOGE(ret, "Save model failed"); if (impl_->graph_manager_.Finalize() != SUCCESS) { @@ -250,17 +320,9 @@ Status GeGenerator::GenerateModel(const Graph &graph, const string &file_name_pr return SUCCESS; } -/** - * @ingroup ge - * @brief Compiling a single operator into an offline model - * @param [in] OpDescPtr &op_desc: Operator description info that needs to be compiled into an offline model file - * @param [in] vector &inputs: Operator input data description information. - * @param [in] vector &outputs: Operator output data description information. - * @param [in] const string &model_file_name: Offline model filename. - * @return SUCCESS handle successfully / others handle failed - */ -Status GeGenerator::BuildSingleOpModel(OpDescPtr &op_desc, const vector &inputs, - const vector &outputs, const string &model_file_name) { +Status GeGenerator::BuildSingleOp(OpDescPtr &op_desc, const vector &inputs, const vector &outputs, + const string &model_file_name, OpEngineType engine_type, ModelBufferData &model_buff, + bool is_offline) { GE_CHECK_NOTNULL_EXEC(op_desc, return PARAM_INVALID); if (!inputs.empty() && (inputs.size() != op_desc->GetInputsSize())) { GELOGE(PARAM_INVALID, "Tensor size: %zu, Inputs size:%zu", inputs.size(), op_desc->GetInputsSize()); @@ -275,7 +337,16 @@ Status GeGenerator::BuildSingleOpModel(OpDescPtr &op_desc, const vector(name); if (compute_graph == nullptr) { @@ -283,9 +354,11 @@ Status GeGenerator::BuildSingleOpModel(OpDescPtr &op_desc, const vectorAddNode(op_desc); GE_CHECK_NOTNULL_EXEC(op_node, return INTERNAL_ERROR); + + // 4. Create InputData node. 
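// Hedged sketch of step 4 (reviewer annotation; the loop body itself is elided by
// the hunk below): when the caller passes no input tensors, one Data node is
// fabricated per input descriptor of the op and wired to input slot arg_index of
// op_node via the static AddInputs() helper declared earlier in this file, roughly:
//   for (const auto &input_desc : op_desc->GetAllInputsDescPtr()) {
//     GE_CHK_STATUS_RET_NOLOG(AddInputs(compute_graph, op_node, *input_desc, arg_index, false));
//     ++arg_index;
//   }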
int32_t arg_index = 0; if (inputs.empty()) { for (const auto &input_desc : op_desc->GetAllInputsDescPtr()) { @@ -301,7 +374,7 @@ Status GeGenerator::BuildSingleOpModel(OpDescPtr &op_desc, const vector ge_models; + GeRootModelPtr ge_root_model = nullptr; GE_CHECK_NOTNULL_EXEC(impl_, return PARAM_INVALID); - GE_CHK_STATUS_RET_NOLOG(impl_->BuildModel(graph, inputs, graph_id, ge_models)); + impl_->is_offline_ = is_offline; + GE_CHK_STATUS_RET_NOLOG(impl_->BuildModel(graph, inputs, graph_id, ge_root_model)); + map op_attrs = op_desc_tmp->GetAllAttrs(); + GE_CHECK_NOTNULL(ge_root_model); + GE_CHECK_NOTNULL(ge_root_model->GetRootGraph()); + map name_to_ge_model = ge_root_model->GetSubgraphInstanceNameToModel(); + GeModelPtr &ge_model = name_to_ge_model[ge_root_model->GetRootGraph()->GetName()]; + GELOGD("The opType in op_desc_tmp is: %s", op_desc_tmp->GetType().c_str()); + GE_CHK_STATUS_RET_NOLOG(impl_->SaveParams(ge_model, op_desc_tmp->GetType(), op_attrs, inputs, outputs)); + GE_CHK_STATUS_RET_NOLOG(impl_->SaveModel(model_file_name, ge_model, model_buff)); + return SUCCESS; +} - if (!ge_models.empty()) { - map op_attrs = op_desc_tmp->GetAllAttrs(); - GELOGI("The opType in op_desc_tmp is: %s", op_desc_tmp->GetType().c_str()); - GE_CHK_STATUS_RET_NOLOG(impl_->SaveParams(ge_models[0], op_desc_tmp->GetType(), op_attrs, inputs, outputs)); - } +/** + * @ingroup ge + * @brief Compiling a single operator into an offline model + * @param [in] OpDescPtr &op_desc: Operator description info that needs to be compiled into an offline model file + * @param [in] vector &inputs: Operator input data description information. + * @param [in] vector &outputs: Operator output data description information. + * @param [in] const string &model_file_name: Offline model filename. + * @return SUCCESS handle successfully / others handle failed + */ +Status GeGenerator::BuildSingleOpModel(OpDescPtr &op_desc, const vector &inputs, + const vector &outputs, const string &model_file_name) { + GELOGI("Start to Build Single Op Offline Model."); ModelBufferData model_buff; - GE_CHK_STATUS_RET_NOLOG(impl_->SaveModel(model_file_name, ge_models, model_buff)); - return SUCCESS; + OpEngineType engine_type = ENGINE_SYS; + return BuildSingleOp(op_desc, inputs, outputs, model_file_name, engine_type, model_buff, true); +} + +/** + * @ingroup ge + * @brief Compiling a single operator into an online buffer + * @param [in] OpDescPtr &op_desc: Operator description info that needs to be compiled into an online buffer + * @param [in] vector &inputs: Operator input data description information. + * @param [in] vector &outputs: Operator output data description information. + * @param [in] engine_type: specific engine. + * @param [out] ModelBufferData &model_buff: model buffer of the op. 
+ * @return SUCCESS handle successfully / others handle failed + */ +Status GeGenerator::BuildSingleOpModel(OpDescPtr &op_desc, const vector &inputs, + const vector &outputs, OpEngineType engine_type, + ModelBufferData &model_buff) { + GELOGI("Start to Build Single Op Online"); + return BuildSingleOp(op_desc, inputs, outputs, kFileNameSuffix, engine_type, model_buff, false); } Status GeGenerator::Impl::SaveParams(GeModelPtr &ge_model, const string &type, const map &attrs, const vector &inputs, const vector &outputs) { GE_CHECK_NOTNULL_EXEC(ge_model, return PARAM_INVALID); GE_CHK_BOOL_EXEC_NOLOG(graph_manager_.SaveParams(*ge_model, type, attrs, inputs, outputs) == SUCCESS, - graph_manager_.Finalize(); + (void)graph_manager_.Finalize(); return FAILED); return SUCCESS; } -Status GeGenerator::Impl::SaveModel(const string &file_name_prefix, vector &models, - ModelBufferData &model_buff) { - // to be change to ModelHelper interface - if (models.empty()) { - GELOGE(FAILED, "models are empty."); - return FAILED; - } - +Status GeGenerator::Impl::SaveModel(const string &file_name_prefix, GeModelPtr &model, ModelBufferData &model_buff) { ModelHelper model_helper; model_helper.SetSaveMode(is_offline_); - Status ret = model_helper.SaveToOmModel(models[0], save_param_, file_name_prefix, model_buff); + Status ret = model_helper.SaveToOmModel(model, save_param_, file_name_prefix, model_buff); if (ret != SUCCESS) { GELOGE(ret, "Save to Om model failed"); return ret; @@ -355,20 +456,21 @@ Status GeGenerator::Impl::SaveModel(const string &file_name_prefix, vector &inputs, GraphId &graph_id, - vector &ge_models) { + GeRootModelPtr &ge_root_model) { static GraphId id = 0; const std::map options; Status ret = graph_manager_.AddGraph(id, graph, options); if (ret != SUCCESS) { - GELOGE(GE_GENERATOR_GRAPH_MANAGER_ADD_GRAPH_FAILED, "graphManager AddGraph failed, id: %u", id); - graph_manager_.Finalize(); + GELOGE(GE_GENERATOR_GRAPH_MANAGER_ADD_GRAPH_FAILED, "GraphManager add graph failed, id: %u", id); + (void)graph_manager_.Finalize(); return GE_GENERATOR_GRAPH_MANAGER_ADD_GRAPH_FAILED; } - GELOGI("models' inputs.size()=%zu", inputs.size()); - ret = graph_manager_.BuildGraph(id, inputs, ge_models); + GELOGI("models inputs.size()=%zu", inputs.size()); + graph_manager_.SetOptionsRunGraphFlag(false); + ret = graph_manager_.BuildGraph(id, inputs, ge_root_model); if (ret != SUCCESS) { - GELOGE(GE_GENERATOR_GRAPH_MANAGER_BUILD_GRAPH_FAILED, "graphManager BuildGraph failed, id: %u", id); + GELOGE(GE_GENERATOR_GRAPH_MANAGER_BUILD_GRAPH_FAILED, "GraphManager build graph failed, id: %u", id); return GE_GENERATOR_GRAPH_MANAGER_BUILD_GRAPH_FAILED; } @@ -383,14 +485,14 @@ Status GeGenerator::Impl::GenerateInfershapeGraph(const Graph &graph, GraphId &g const std::map options; Status ret = graph_manager_.AddGraph(id, graph, options); if (ret != SUCCESS) { - GELOGE(GE_GENERATOR_GRAPH_MANAGER_ADD_GRAPH_FAILED, "graphManager AddGraph failed, id: %u", id); - graph_manager_.Finalize(); + GELOGE(GE_GENERATOR_GRAPH_MANAGER_ADD_GRAPH_FAILED, "graphManager add graph failed, id: %u", id); + (void)graph_manager_.Finalize(); return GE_GENERATOR_GRAPH_MANAGER_ADD_GRAPH_FAILED; } ret = graph_manager_.GenerateInfershapeGraph(id); if (ret != SUCCESS) { - GELOGE(GE_GENERATOR_GRAPH_MANAGER_BUILD_GRAPH_FAILED, "graphManager BuildGraph failed, id: %u", id); + GELOGE(GE_GENERATOR_GRAPH_MANAGER_BUILD_GRAPH_FAILED, "GraphManager BuildGraph failed, id: %u", id); return GE_GENERATOR_GRAPH_MANAGER_BUILD_GRAPH_FAILED; } diff --git 
a/src/ge/graph/build/graph_builder.cc b/src/ge/graph/build/graph_builder.cc index 8e556ff2..f2fa4ada 100644 --- a/src/ge/graph/build/graph_builder.cc +++ b/src/ge/graph/build/graph_builder.cc @@ -53,6 +53,7 @@ Status GraphBuilder::CalcOpParam(const ge::ComputeGraphPtr &graph) { GELOGE(GE_CLI_GE_NOT_INITIALIZED, "GraphBuilder: GE is not initialized"); return GE_CLI_GE_NOT_INITIALIZED; } + for (const auto &node_ptr : graph->GetAllNodes()) { GE_CHECK_NOTNULL(node_ptr->GetOpDesc()); std::string kernel_lib_name = node_ptr->GetOpDesc()->GetOpKernelLibName(); @@ -84,76 +85,229 @@ Status GraphBuilder::CalcOpParam(const ge::ComputeGraphPtr &graph) { return INTERNAL_ERROR; } } + + auto parent_node = graph->GetParentNode(); + if (parent_node == nullptr) { + GELOGI("Graph[%s] does not have a parent node, no need to update parent node output size.", graph->GetName().c_str()); + return SUCCESS; + } + + GE_CHK_STATUS_RET(UpdateParentNodeOutputSize(graph, parent_node)); GELOGI("Success to calculate op running param."); return SUCCESS; } +Status GraphBuilder::UpdateParentNodeOutputSize(const ge::ComputeGraphPtr &graph, ge::NodePtr &parent_node_ptr) { + GELOGI("Begin to update parent node[%s] of graph[%s] output size.", parent_node_ptr->GetName().c_str(), + graph->GetName().c_str()); + auto parent_op_desc = parent_node_ptr->GetOpDesc(); + GE_CHECK_NOTNULL(parent_op_desc); + bool is_unknown_shape = false; + if (!AttrUtils::GetBool(parent_op_desc, ATTR_NAME_IS_UNKNOWN_SHAPE, is_unknown_shape)) { + GELOGE(PARAM_INVALID, "Get op %s unknown shape attr failed.", parent_op_desc->GetName().c_str()); + return PARAM_INVALID; + } + if (is_unknown_shape) { + GELOGI("Current graph[%s] is unknown, no need to update parent node[%s] output size.", graph->GetName().c_str(), + parent_node_ptr->GetName().c_str()); + return SUCCESS; + } + for (const auto &node_ptr : graph->GetDirectNode()) { + if (node_ptr->GetType() != NETOUTPUT) { + continue; + } + auto op_desc = node_ptr->GetOpDesc(); + GE_CHECK_NOTNULL(op_desc); + for (const auto &in_data_anchor : node_ptr->GetAllInDataAnchors()) { + auto index = in_data_anchor->GetIdx(); + ge::GeTensorDesc desc_temp = op_desc->GetInputDesc(index); + int64_t size = 0; + GE_IF_BOOL_EXEC(ge::TensorUtils::GetSize(desc_temp, size) != SUCCESS, GELOGI("Get size failed!")); + uint32_t parent_index = 0; + if (!AttrUtils::GetInt(desc_temp, ATTR_NAME_PARENT_NODE_INDEX, parent_index)) { + GELOGE(INTERNAL_ERROR, "NetOutput input tensor %d, attr %s not found.", index, + ATTR_NAME_PARENT_NODE_INDEX.c_str()); + return INTERNAL_ERROR; + } + ge::GeTensorDesc parent_desc_temp = parent_op_desc->GetOutputDesc(parent_index); + ge::TensorUtils::SetSize(parent_desc_temp, size); + GE_CHK_STATUS_RET(parent_op_desc->UpdateOutputDesc(parent_index, parent_desc_temp)); + GELOGI("Update parent node[%s] output index[%u] to size[%ld].", parent_node_ptr->GetName().c_str(), parent_index, + size); + } + } + return SUCCESS; +} + Status GraphBuilder::Build(ComputeGraphPtr &comp_graph, std::vector &subgraph_ptr_list, - GeModelPtr &ge_model_ptr, uint64_t session_id) { + GeRootModelPtr &ge_root_model_ptr, uint64_t session_id) { GELOGI("Start to build model."); if (comp_graph == nullptr) { GELOGE(GE_GRAPH_PARAM_NULLPTR, "Graph build comp_graph is null."); return GE_GRAPH_PARAM_NULLPTR; } + ge_root_model_ptr = MakeShared(comp_graph); + if (ge_root_model_ptr == nullptr) { + return MEMALLOC_FAILED; + } + GeModelPtr ge_model_ptr = nullptr; + bool is_dynamic_shape = false; + // To be compatible with the old process, do not verify the return 
value temporarily. + (void)AttrUtils::GetBool(comp_graph, ATTR_NAME_DYNAMIC_SHAPE_PARTITIONED, is_dynamic_shape); + if (is_dynamic_shape) { + GE_CHK_STATUS_RET( + BuildForDynamicShapeGraph(comp_graph, subgraph_ptr_list, ge_root_model_ptr, ge_model_ptr, session_id), + "Build for dynamic shape graph failed."); + return SUCCESS; + } + + GE_CHK_STATUS_RET(BuildForKnownShapeGraph(comp_graph, subgraph_ptr_list, ge_model_ptr, session_id), + "Build for known shape graph failed."); + ge_root_model_ptr->SetSubgraphInstanceNameToModel(comp_graph->GetName(), ge_model_ptr); + return SUCCESS; +} +Status GraphBuilder::BuildForKnownShapeGraph(ComputeGraphPtr &comp_graph, + std::vector &subgraph_ptr_list, GeModelPtr &ge_model_ptr, + uint64_t session_id) { + GELOGI("Begin to build known shape graph[%s].", comp_graph->GetName().c_str()); Status ret = SecondPartition(comp_graph, subgraph_ptr_list); - GE_CHK_STATUS_RET(ret, "Graph second partition Failed."); + GE_CHK_STATUS_RET(ret, "Graph[%s] second partition Failed.", comp_graph->GetName().c_str()); auto subgraph_map = graph_partitioner_.GetSubGraphMap(); GE_TIMESTAMP_START(BuildSubgraph); ge::ModelBuilder builder(comp_graph, subgraph_map, stream_max_parallel_num_, hcom_parallel_, build_mode_); - - GELOGI("[Build] invoke the other opskernel to generate task."); - - GraphUtils::DumpGEGraph(comp_graph, "BeforePreBuildModel"); - GraphUtils::DumpGEGraphToOnnx(*comp_graph, "BeforePreBuildModel"); - + GE_DUMP(comp_graph, "BeforePreBuildModel"); GE_TIMESTAMP_START(PreBuildModel); - GE_CHK_STATUS_RET(builder.PreBuildModel(), "Builder PreBuildModel() return fail."); + GE_CHK_STATUS_RET(builder.PreBuildModel(), "Graph[%s] builder PreBuildModel() return fail.", + comp_graph->GetName().c_str()); GE_TIMESTAMP_END(PreBuildModel, "GraphBuilder::PreBuildModel"); - GraphUtils::DumpGEGraph(comp_graph, "AfterPrebuildmodel"); - GraphUtils::DumpGEGraphToOnnx(*comp_graph, "AfterPrebuildmodel"); - + GE_DUMP(comp_graph, "AfterPreBuildModel"); GE_TIMESTAMP_START(CalcOpParam); - GE_CHK_STATUS_RET(CalcOpParam(comp_graph), "Builder CalcOpParam() return fail."); + GE_CHK_STATUS_RET(CalcOpParam(comp_graph), "Graph[%s] builder CalcOpParam() return fail.", + comp_graph->GetName().c_str()); GE_TIMESTAMP_END(CalcOpParam, "GraphBuilder::CalcOpParam"); - GraphUtils::DumpGEGraph(comp_graph, "AfterCalcOpParam"); - GraphUtils::DumpGEGraphToOnnx(*comp_graph, "AfterCalcOpParam"); + GE_DUMP(comp_graph, "AfterCalcOpParam"); ModelPtr model_ptr = MakeShared(); if (model_ptr == nullptr) { return MEMALLOC_FAILED; } GE_TIMESTAMP_START(BuildModelForGetTask); - GE_CHK_STATUS_RET(builder.BuildModelForGetTask(*model_ptr), "Builder BuildModelForGetTask() return fail."); + GE_CHK_STATUS_RET(builder.BuildModelForGetTask(*model_ptr), "Graph[%s] builder BuildModelForGetTask() return fail.", + comp_graph->GetName().c_str()); GE_TIMESTAMP_END(BuildModelForGetTask, "GraphBuilder::BuildModelForGetTask"); - - GraphUtils::DumpGEGraph(comp_graph, "AfterBuildModel"); - GraphUtils::DumpGEGraphToOnnx(*comp_graph, "AfterBuildModel"); + GE_DUMP(comp_graph, "AfterBuildModel"); GE_TIMESTAMP_START(GetTaskInfo); ret = GetTaskInfo(builder, model_ptr, comp_graph, subgraph_map, session_id); GE_TIMESTAMP_END(GetTaskInfo, "GraphBuilder::GetTaskInfo"); - - GraphUtils::DumpGEGraph(comp_graph, "AfterGetTask"); - GraphUtils::DumpGEGraphToOnnx(*comp_graph, "AfterGetTask"); + GE_DUMP(comp_graph, "AfterGetTask"); if (ret != SUCCESS) { - GELOGE(ret, "Builder GetTaskInfo() return fail."); + GELOGE(ret, "Graph[%s] builder GetTaskInfo() return 
fail.", comp_graph->GetName().c_str()); return ret; } - for (auto graph : comp_graph->GetAllSubgraphs()) { - GraphUtils::DumpGEGraphToOnnx(*graph, "SubgraphGetTask"); + ge_model_ptr = MakeShared(); + if (ge_model_ptr == nullptr) { + return MEMALLOC_FAILED; } + GE_CHK_STATUS_RET(builder.SaveDataToModel(*model_ptr, *ge_model_ptr), + "Graph[%s] builder SaveDataToModel() return fail.", comp_graph->GetName().c_str()); + GELOGI("Success to build graph[%s] model.", comp_graph->GetName().c_str()); + GE_TIMESTAMP_END(BuildSubgraph, "GraphBuilder::Build"); + return SUCCESS; +} +Status GraphBuilder::BuildForUnknownShapeGraph(ComputeGraphPtr &comp_graph, GeModelPtr &ge_model_ptr, + uint64_t session_id) { + GELOGI("Begin to build unknown shape graph[%s].", comp_graph->GetName().c_str()); + GE_TIMESTAMP_START(CalcOpParam); + GE_CHK_STATUS_RET(CalcOpParam(comp_graph), "Graph[%s] builder CalcOpParam() return fail.", + comp_graph->GetName().c_str()); + GE_TIMESTAMP_END(CalcOpParam, "GraphBuilder::CalcOpParam"); + GE_DUMP(comp_graph, "AfterCalcOpParam"); + Graph2SubGraphInfoList subgraph_map; + ge::ModelBuilder builder(comp_graph, subgraph_map, stream_max_parallel_num_, hcom_parallel_, build_mode_); + ModelPtr model_ptr = MakeShared(); + if (model_ptr == nullptr) { + return MEMALLOC_FAILED; + } + GE_TIMESTAMP_START(BuildModelForGetDynShapeTask); + GE_CHK_STATUS_RET(builder.BuildModelForGetDynShapeTask(*model_ptr), + "Graph[%s] builder BuildModelForGetDynShapeTask() return fail.", comp_graph->GetName().c_str()); + GE_TIMESTAMP_END(BuildModelForGetDynShapeTask, "GraphBuilder::BuildModelForGetDynShapeTask"); + GE_TIMESTAMP_START(GetTaskInfo); + Status ret = GetTaskInfo(builder, model_ptr, comp_graph, subgraph_map, session_id); + GE_TIMESTAMP_END(GetTaskInfo, "GraphBuilder::GetTaskInfo"); + + GraphUtils::DumpGEGraph(comp_graph, "AfterGetTask"); + GraphUtils::DumpGEGraphToOnnx(*comp_graph, "AfterGetTask"); + if (ret != SUCCESS) { + GELOGE(ret, "Graph[%s] builder GetTaskInfo() return fail.", comp_graph->GetName().c_str()); + return ret; + } ge_model_ptr = MakeShared(); if (ge_model_ptr == nullptr) { return MEMALLOC_FAILED; } - GE_CHK_STATUS_RET(builder.SaveDataToModel(*model_ptr, *ge_model_ptr), "model builder SaveDataToModel() return fail."); - GELOGI("Success to build model."); - GE_TIMESTAMP_END(BuildSubgraph, "GraphBuilder::Build"); + GE_CHK_STATUS_RET(builder.SaveDataToModel(*model_ptr, *ge_model_ptr), + "Graph[%s] builder SaveDataToModel() return fail.", comp_graph->GetName().c_str()); + GELOGI("Success to build graph[%s] model.", comp_graph->GetName().c_str()); + return SUCCESS; +} + +Status GraphBuilder::BuildForDynamicShapeGraph(ComputeGraphPtr &comp_graph, + std::vector &subgraph_ptr_list, + GeRootModelPtr &ge_root_model_ptr, GeModelPtr &ge_model_ptr, + uint64_t session_id) { + GELOGI("Start to build BuildForDynamicShape for dynamic shape."); + for (const auto &node : comp_graph->GetDirectNode()) { + auto op_desc = node->GetOpDesc(); + GE_CHECK_NOTNULL(op_desc); + if (node->GetType() == DATA) { + GE_CHK_STATUS_RET(CalcDynShapeRootGraphDataSize(op_desc), "Calc dynamic shape root graph data[%s] size failed.", + op_desc->GetName().c_str()); + } + + // ATTR_NAME_IS_UNKNOWN_SHAPE is set on "graph partion" stage, but afer fusion , the graph may + // be changed so here need to renew. 
For example, consider the following scene: + // (known)partitioncall(known) (known)partitioncall(known) + // After fusion + // | --> + // (known)Unique(unknown)--->(unknown)Shape(unknown) (known)FuncDef(known) + // in a scene like this, it should be processed as a known shape graph + bool is_unknown_shape = false; + GE_CHK_STATUS_RET(ge::NodeUtils::GetNodeUnknownShapeStatus(*node, is_unknown_shape), + "Get node[%s] shape status failed!", node->GetName().c_str()); + if (!is_unknown_shape) { + GE_CHK_BOOL_EXEC(ge::AttrUtils::SetBool(op_desc, ATTR_NAME_IS_UNKNOWN_SHAPE, is_unknown_shape), return FAILED, + "Renew node [%s] attr[%s] failed!", node->GetName().c_str(), ATTR_NAME_IS_UNKNOWN_SHAPE.c_str()); + GELOGD("renew node [%s] attr[%s] success! value is %d", node->GetName().c_str(), + ATTR_NAME_IS_UNKNOWN_SHAPE.c_str(), is_unknown_shape); + } + + vector subgraph_names = op_desc->GetSubgraphInstanceNames(); + for (auto subgraph_name : subgraph_names) { + ComputeGraphPtr subgraph = comp_graph->GetSubgraph(subgraph_name); + bool is_unknown_shape = false; + if (!AttrUtils::GetBool(op_desc, ATTR_NAME_IS_UNKNOWN_SHAPE, is_unknown_shape)) { + GELOGE(PARAM_INVALID, "Get op %s unknown shape attr failed.", op_desc->GetName().c_str()); + return PARAM_INVALID; + } + if (is_unknown_shape) { + // unknown shape build flow + GE_CHK_STATUS_RET(BuildForUnknownShapeGraph(subgraph, ge_model_ptr, session_id), + "Build for unknown shape graph failed."); + } else { + // known shape build flow + GE_CHK_STATUS_RET(BuildForKnownShapeGraph(subgraph, subgraph_ptr_list, ge_model_ptr, session_id), + "Build for known shape graph failed."); + } + ge_root_model_ptr->SetSubgraphInstanceNameToModel(subgraph_name, ge_model_ptr); + } + } return SUCCESS; } @@ -199,10 +353,7 @@ Status GraphBuilder::GetTaskInfo(const ge::ModelBuilder &builder, const ModelPtr GELOGE(ret, "Optimize streamed subGraph fail."); return ret; } - - GraphUtils::DumpGEGraph(comp_graph, "AfterOptimizeStreamedSubGraph"); - GraphUtils::DumpGEGraphToOnnx(*comp_graph, "AfterOptimizeStreamedSubGraph"); - + GE_DUMP(comp_graph, "AfterOptimizeStreamedSubGraph"); auto *get_var_mem_base = reinterpret_cast(reinterpret_cast(ge::VarManager::Instance(0)->GetVarMemLogicBase())); uint64_t var_size = (ge::VarManager::Instance(session_id)->GetVarMemSize(RT_MEMORY_HBM) > 0) @@ -289,6 +440,36 @@ Status GraphBuilder::UpdateDataInputSize(const ge::NodePtr &node_ptr) { return SUCCESS; } +Status GraphBuilder::CalcDynShapeRootGraphDataSize(const ge::OpDescPtr &op_desc) { + GELOGI("Begin to calc dynamic shape graph data[%s] size.", op_desc->GetName().c_str()); + // data op only has one output anchor + ge::GeTensorDesc output_desc = op_desc->GetOutputDesc(0); + int64_t output_size = 0; + if (ge::TensorUtils::GetSize(output_desc, output_size) != SUCCESS) { + GELOGW("Get size failed!"); + } + + if (output_size > 0) { + GELOGI("No need to update dynamic shape graph data output size[%ld].", output_size); + return SUCCESS; + } else { + int64_t real_dim_size = 0; + ge::graphStatus graph_status = TensorUtils::GetTensorSizeInBytes(output_desc, real_dim_size); + if (graph_status != GRAPH_SUCCESS) { + GELOGE(FAILED, "Get tensor size in bytes failed."); + return FAILED; + } + + ge::TensorUtils::SetSize(output_desc, real_dim_size); + GELOGI("Update dynamic shape graph data output size to [%ld].", real_dim_size); + if (op_desc->UpdateOutputDesc(0, output_desc) != GRAPH_SUCCESS) { + GELOGE(FAILED, "Update dynamic shape graph data output desc size failed."); + return FAILED; + } + } + return SUCCESS; +} + Status 
GraphBuilder::SecondPartition(ge::ComputeGraphPtr &comp_graph, vector &subgraph_ptr_list) { GELOGI("[SecondPartition] second partition."); GE_TIMESTAMP_START(GraphPartition2); diff --git a/src/ge/graph/build/graph_builder.h b/src/ge/graph/build/graph_builder.h index d0bf26e6..def3a28b 100644 --- a/src/ge/graph/build/graph_builder.h +++ b/src/ge/graph/build/graph_builder.h @@ -38,6 +38,7 @@ #include "graph/partition/graph_partition.h" #include "graph/utils/graph_utils.h" #include "graph/utils/tensor_utils.h" +#include "model/ge_root_model.h" namespace ge { class GraphBuilder { @@ -46,8 +47,8 @@ class GraphBuilder { GraphBuilder(const GraphBuilder &in) = delete; GraphBuilder &operator=(const GraphBuilder &in) = delete; virtual ~GraphBuilder() = default; - Status Build(ComputeGraphPtr &comp_graph, std::vector &subgraph_ptr_list, GeModelPtr &ge_model_ptr, - uint64_t session_id = INVALID_SESSION_ID); + Status Build(ComputeGraphPtr &comp_graph, std::vector &subgraph_ptr_list, + GeRootModelPtr &ge_model_ptr, uint64_t session_id = INVALID_SESSION_ID); void SetOptions(const GraphManagerOptions &options); private: @@ -56,8 +57,16 @@ class GraphBuilder { Graph2SubGraphInfoList &subgraph_map, uint64_t session_id = INVALID_SESSION_ID); Status SetInputSize(const ge::NodePtr &node_ptr); Status UpdateDataInputSize(const ge::NodePtr &node_ptr); + Status UpdateParentNodeOutputSize(const ge::ComputeGraphPtr &graph, ge::NodePtr &parent_node_ptr); + Status CalcDynShapeRootGraphDataSize(const ge::OpDescPtr &op_desc); Status SecondPartition(ge::ComputeGraphPtr &comp_graph, vector &subgraph_ptr_list); - + Status BuildForDynamicShapeGraph(ComputeGraphPtr &comp_graph, std::vector &subgraph_ptr_list, + GeRootModelPtr &ge_root_model_ptr, GeModelPtr &ge_model_ptr, + uint64_t session_id = INVALID_SESSION_ID); + Status BuildForKnownShapeGraph(ComputeGraphPtr &comp_graph, std::vector &subgraph_ptr_list, + GeModelPtr &ge_model_ptr, uint64_t session_id = INVALID_SESSION_ID); + Status BuildForUnknownShapeGraph(ComputeGraphPtr &comp_graph, GeModelPtr &ge_model_ptr, + uint64_t session_id = INVALID_SESSION_ID); int build_mode_; std::map stream_max_parallel_num_; diff --git a/src/ge/graph/build/logical_stream_allocator.cc b/src/ge/graph/build/logical_stream_allocator.cc index ff33e3b7..f57e8086 100644 --- a/src/ge/graph/build/logical_stream_allocator.cc +++ b/src/ge/graph/build/logical_stream_allocator.cc @@ -512,13 +512,14 @@ Status AllReduceParallelPass::Run(ComputeGraphPtr graph, const vector all_reduce_succs; for (const NodePtr &node : graph->GetDirectNode()) { - if (node->GetType() != HCOMALLREDUCE || node->GetInDataNodes().size() <= 1) { + if ((node->GetType() != HCOMALLREDUCE && node->GetType() != HVDCALLBACKALLREDUCE) || + node->GetInDataNodes().size() <= 1) { continue; } @@ -534,7 +535,10 @@ Status AllReduceParallelPass::Run(ComputeGraphPtr graph, const vectorGetOpDesc()); (void)AttrUtils::GetStr(out_node->GetOpDesc(), ATTR_NAME_STREAM_LABEL, out_stream_label); - if (out_stream_label == reduce_stream_label) { + // Normally, Allreduce does not have a streamLabel; 
when in horovod scenario Allreduce will have streamLabel + bool isSuccessorParallel = + (out_stream_label == reduce_stream_label) || (!reduce_stream_label.empty() && out_stream_label.empty()); + if (isSuccessorParallel) { all_reduce_succs.emplace(out_node); all_out_data_nodes.emplace(out_node); } diff --git a/src/ge/graph/build/memory/block_mem_assigner.cc b/src/ge/graph/build/memory/block_mem_assigner.cc index 73d3ee98..6d908155 100644 --- a/src/ge/graph/build/memory/block_mem_assigner.cc +++ b/src/ge/graph/build/memory/block_mem_assigner.cc @@ -54,13 +54,42 @@ using std::unordered_map; using std::unordered_set; using std::vector; +void MemoryBlock::SetHeadOffset(size_t offset) { + head_offset_ = offset; + size_t child_offset = head_offset_; + for (auto block : child_blocks_) { + if (block != nullptr) { + block->SetHeadOffset(child_offset); + child_offset += block->Size(); + } + } +} + +void MemoryBlock::SetTailOffset(size_t offset) { + tail_offset_ = offset; + size_t child_offset = head_offset_; + for (auto block : child_blocks_) { + if (block != nullptr) { + child_offset += block->Size(); + block->SetTailOffset(child_offset - 1); + } + } +} + void MemoryBlock::Resize() { + size_t child_block_size = 0; + for (auto block : child_blocks_) { + if (block != nullptr) { + block->Resize(); + child_block_size += block->Size(); + } + } auto iter = std::max_element(real_size_list_.begin(), real_size_list_.end()); if (iter == real_size_list_.end()) { GELOGW("real_size_list_ is empty"); return; } else { - size_t block_size = *iter; + size_t block_size = (child_block_size > *iter) ? child_block_size : *iter; if ((block_size > 0) && (block_size % MEM_ALIGN_SIZE != 0)) { block_size = (block_size + MEM_ALIGN_SIZE - 1) / MEM_ALIGN_SIZE * MEM_ALIGN_SIZE; } @@ -102,6 +131,68 @@ bool MemoryBlock::IsSameLabel(std::string &first_batch_label) { return all_same_label; } +bool CanNotLifeReuse(MemoryBlock *block) { + if (block == nullptr || !block->reuse_mem_ || block->deleted_block_ || block->continuous_block_ || + block->GetLifeEnd() == kMaxLifeTime) { + return true; + } + return false; +} + +void MemoryBlock::AddLifeReuseBlock(MemoryBlock *block) { + if (CanNotLifeReuse(this) || CanNotLifeReuse(block)) { + return; + } + MemoryBlock *parent = nullptr; + MemoryBlock *child = nullptr; + // merge small block to large block + if ((block->GetLifeBegin() > GetLifeEnd()) && (block->stream_id_ == stream_id_)) { + if ((child_offset_ + block->block_size_) <= block_size_) { + parent = this; + child = block; + } else if ((block->child_offset_ + block_size_) <= block->block_size_) { + parent = block; + child = this; + } + } + if ((parent != nullptr) && (child != nullptr) && child->child_blocks_.empty()) { + parent->child_blocks_.emplace_back(child); + parent->child_offset_ += child->block_size_; + child->deleted_block_ = true; + GELOGI( + "Add block stream id:%ld [size:%zu, life time[begin:%zu, end:%zu]] to" + " block[size:%zu, life time[begin:%zu, end:%zu]]", + stream_id_, child->block_size_, child->GetLifeBegin(), child->GetLifeEnd(), parent->block_size_, + parent->GetLifeBegin(), parent->GetLifeEnd()); + } +} + +size_t MemoryBlock::GetLifeBegin() { + size_t life_time = 0; + if (!node_type_index_list_.empty()) { + if (node_type_index_list_.front().node != nullptr) { + auto node_op_desc = node_type_index_list_.front().node->GetOpDesc(); + if (node_op_desc != nullptr) { + life_time = node_op_desc->GetId(); + } + } + } + return life_time; +} + +size_t MemoryBlock::GetLifeEnd() { + if (!node_type_index_list_.empty()) { + 
return node_type_index_list_.back().life_time_end; + } + return kMaxLifeTime; +} + +void MemoryBlock::SetLifeTimeEnd(size_t time) { + if (!node_type_index_list_.empty()) { + node_type_index_list_.back().life_time_end = time; + } +} + void SetLastUsedInputMemAttr(NodePtr &node, int input_index) { if (node == nullptr) { return; @@ -122,6 +213,27 @@ void SetLastUsedInputMemAttr(NodePtr &node, int input_index) { } } +Status GetNoAlignSize(const ge::OpDesc &desc, uint32_t index, size_t &size) { + // calculate tensor real size + auto output_op_desc = desc.GetOutputDescPtr(index); + if (output_op_desc == nullptr) { + GELOGI("GetNoAlignSize failed. OpName: %s, OpType: %s, index: %d", desc.GetName().c_str(), desc.GetType().c_str(), + index); + return FAILED; + } + int64_t tensor_size = 0; + GeShape shape = output_op_desc->GetShape(); + Format format = output_op_desc->GetFormat(); + DataType data_type = output_op_desc->GetDataType(); + graphStatus graph_status = TensorUtils::CalcTensorMemSize(shape, format, data_type, tensor_size); + if (graph_status != GRAPH_SUCCESS) { + GELOGE(graph_status, "CalcTensorMemSize failed!"); + return FAILED; + } + size = static_cast(tensor_size); + return SUCCESS; +} + string ToString(ge::NodeTypeIndex &x) { stringstream ss; ss << "[" << x.node->GetName() << "(" << x.node->GetType() << "), "; @@ -150,7 +262,7 @@ string MemoryBlock::String() { } BlockMemAssigner::BlockMemAssigner(ge::ComputeGraphPtr compute_graph) - : mem_offset_(0), compute_graph_(std::move(compute_graph)) {} + : mem_offset_(0), compute_graph_(std::move(compute_graph)), life_time_(0) {} BlockMemAssigner::~BlockMemAssigner() { for (MemoryBlock *memory_block : memory_blocks_) { @@ -290,8 +402,9 @@ bool CanReuseBySize(const map &reusable_block_counts, const Me // continuous memory case:only real_size is maximum can be reused and only one continuous memory in one block if (continuous || reusable_block.continuous_block_) { - auto it = std::max_element(std::begin(reusable_block.RealSizeList()), std::end(reusable_block.RealSizeList())); - if (it != std::end(reusable_block.RealSizeList())) { + auto it = + std::max_element(std::begin(reusable_block.NoAlignSizeList()), std::end(reusable_block.NoAlignSizeList())); + if (it != std::end(reusable_block.NoAlignSizeList())) { GE_IF_BOOL_EXEC((continuous && reusable_block.continuous_block_) || (continuous && (real_size < *it)) || (reusable_block.continuous_block_ && (real_size > *it)), GELOGD("Conflict current block size:%zu continuous:%d, reuse block max size:%zu continuous:%d", @@ -498,25 +611,29 @@ void BlockMemAssigner::PrintSymbolMap() { } } -MemoryBlock *BlockMemAssigner::ApplyMemory(size_t block_size, size_t real_size, MemoryType mem_type, const NodePtr &n, - uint32_t out_index, const vector &workspace_reuse_flag, - const bool is_op_reuse_mem, const bool continuous) { +MemoryBlock *BlockMemAssigner::ApplyMemory(size_t block_size, size_t real_size, size_t no_align_size, + MemoryType mem_type, const NodePtr &n, uint32_t out_index, + const vector &workspace_reuse_flag, const bool is_op_reuse_mem, + const bool continuous) { GE_CHK_BOOL_TRUE_EXEC_WITH_LOG(n == nullptr, return nullptr, "Input parameter n is null."); auto node_op_desc = n->GetOpDesc(); GE_IF_BOOL_EXEC(node_op_desc == nullptr, return nullptr); + bool is_reuse_memory = false; string ge_disable_reuse_mem_env = "0"; (void)ge::GetContext().GetOption(kDisableReuseMemory, ge_disable_reuse_mem_env); if (ge_disable_reuse_mem_env != "1") { bool reuse_mem_flag = !((workspace_reuse_flag.size() > out_index) && 
!workspace_reuse_flag[out_index]); - bool is_reuse_memory = !node_op_desc->HasAttr(kL2FusionDynamicConvergeOp) && reuse_mem_flag && is_op_reuse_mem && - (IsPreReuse(n, out_index)); + is_reuse_memory = !node_op_desc->HasAttr(kL2FusionDynamicConvergeOp) && reuse_mem_flag && is_op_reuse_mem && + (IsPreReuse(n, out_index)); auto stream_id = node_op_desc->GetStreamId(); auto map_iter = reusable_streams_map_.find(stream_id); if (is_reuse_memory && map_iter != reusable_streams_map_.end()) { for (auto it = reusable_blocks_.begin(); it != reusable_blocks_.end(); ++it) { MemoryBlock *reusable_block = *it; if (!IsPostReuse(reusable_block)) { + reusable_block->reuse_mem_ = false; + GELOGI("Unreusable block."); continue; } @@ -526,7 +643,7 @@ MemoryBlock *BlockMemAssigner::ApplyMemory(size_t block_size, size_t real_size, CanReuseByStream(map_iter->second, *reusable_block)) { GELOGD("Cross stream mem reuse, target stream:%ld, current stream:%ld", reusable_block->stream_id_, stream_id); - reusable_block->AddNodeTypeIndex({n, mem_type, out_index}, real_size); + reusable_block->AddNodeTypeIndex({n, mem_type, out_index}, real_size, no_align_size); if (mem_type == kOutput) { auto iter = anchor_to_symbol_.find(NodeIndexIO(n, out_index, kOut).ToString()); if (iter != anchor_to_symbol_.end()) { @@ -543,7 +660,7 @@ MemoryBlock *BlockMemAssigner::ApplyMemory(size_t block_size, size_t real_size, } } - auto block = new (std::nothrow) MemoryBlock(block_size, is_op_reuse_mem); + auto block = new (std::nothrow) MemoryBlock(block_size, is_reuse_memory); GE_CHK_BOOL_TRUE_EXEC_WITH_LOG(block == nullptr, return nullptr, "new an object failed."); // Data and netoutput need zero copy block @@ -551,7 +668,7 @@ MemoryBlock *BlockMemAssigner::ApplyMemory(size_t block_size, size_t real_size, block->is_zero_copy_ = true; } - block->Init(real_size, mem_type, n, out_index); + block->Init(real_size, mem_type, n, out_index, no_align_size); block->stream_id_ = node_op_desc->GetStreamId(); block->ref_count_++; block->continuous_block_ = continuous; @@ -577,11 +694,14 @@ MemoryBlock *BlockMemAssigner::ApplyOutMemory(const NodePtr &n, uint32_t index, if (output_op_desc != nullptr) { GE_IF_BOOL_EXEC(ge::TensorUtils::GetSize(*output_op_desc, size) != SUCCESS, GELOGI("Get size failed")); } + size_t no_align_size = 0; + GE_CHK_BOOL_TRUE_EXEC_WITH_LOG(GetNoAlignSize(*node_op_desc, index, no_align_size) != SUCCESS, return nullptr, + "Get no align size failed"); if (IsSymbolExist(node_index_io)) { std::string symbol = anchor_to_symbol_[node_index_io.ToString()]; block = symbol_blocks_[symbol]; - block->AddNodeTypeIndex({n, kOutput, index}, size); + block->AddNodeTypeIndex({n, kOutput, index}, size, no_align_size); block->ref_count_++; } else { int64_t max_size = size; @@ -594,7 +714,8 @@ MemoryBlock *BlockMemAssigner::ApplyOutMemory(const NodePtr &n, uint32_t index, } auto block_size = GetBlockSize(max_size, ranges); vector workspace_reuse_flag; - block = ApplyMemory(block_size, size, kOutput, n, index, workspace_reuse_flag, is_op_reuse_mem, continuous); + block = ApplyMemory(block_size, size, no_align_size, kOutput, n, index, workspace_reuse_flag, is_op_reuse_mem, + continuous); } GE_CHK_BOOL_TRUE_EXEC_WITH_LOG(block == nullptr, return nullptr, "Block is nullptr."); int out_count_reuse_input = block->ref_count_; @@ -628,7 +749,7 @@ MemoryBlock *BlockMemAssigner::ApplyOutMemory(const NodePtr &n, uint32_t index, GE_IF_BOOL_EXEC(ge::TensorUtils::GetReuseInputIndex(*owner_node_op_desc, dst_reuse_input_index) != SUCCESS, GELOGI("Get 
dst_reuse_input_index failed")); if (dst_reuse_input && (dst_reuse_input_index == static_cast(in_anchor->GetIdx()))) { - block->AddNodeTypeIndex({owner_node, kOutput, i}, block->Size()); + block->AddNodeTypeIndex({owner_node, kOutput, i}, block->Size(), block->Size()); out_count_reuse_input += 1; reuse_input = true; } @@ -710,6 +831,7 @@ void BlockMemAssigner::ReleaseMemory(MemoryBlock *to_release, vectorreuse_mem_, return, "doesn't reuse memory"); --to_release->ref_count_; if (to_release->ref_count_ == 0) { + to_release->SetLifeTimeEnd(life_time_); reusable_memory.emplace_back(to_release); AddReusableBlockCount(*to_release, reusable_block_counts_); } @@ -852,12 +974,11 @@ Status BlockMemAssigner::AssignOutputMemoryWithReuse(const NodePtr &node, vector zero_memory_list_.emplace_back(node, kOutput, i); continue; } - bool reuse_mem = is_op_reuse_mem_; // atomic can't be reused if (is_op_reuse_mem_ && out_node_set_continuous_input && is_atomic) { - reuse_mem = false; + is_op_reuse_mem_ = false; } - MemoryBlock *mem_block = ApplyOutMemory(node, i, ranges, reuse_mem, out_node_set_continuous_input); + MemoryBlock *mem_block = ApplyOutMemory(node, i, ranges, is_op_reuse_mem_, out_node_set_continuous_input); if (mem_block != nullptr) { node_out_blocks_[node->GetName()].emplace_back(mem_block); if (out_node_set_continuous_input) { @@ -894,6 +1015,7 @@ void BlockMemAssigner::AssignMemoryWithReuse(vector &ranges) { for (NodePtr &n : compute_graph_->GetAllNodes()) { auto node_op_desc = n->GetOpDesc(); GE_IF_BOOL_EXEC(node_op_desc == nullptr, continue); + life_time_ = node_op_desc->GetId(); int64_t stream_id = node_op_desc->GetStreamId(); if (AssignOutputMemoryWithReuse(n, ranges) != SUCCESS) { return; @@ -930,9 +1052,9 @@ void BlockMemAssigner::AssignMemoryWithReuse(vector &ranges) { zero_memory_list_.emplace_back(n, kWorkspace, static_cast(i)); continue; } - MemoryBlock *mem_block = - ApplyMemory(GetBlockSize(static_cast(temp[i]), ranges), static_cast(temp[i]), kWorkspace, n, - static_cast(i), workspace_reuse_flag, is_op_reuse_mem_, false); + MemoryBlock *mem_block = ApplyMemory(GetBlockSize(static_cast(temp[i]), ranges), + static_cast(temp[i]), static_cast(temp[i]), kWorkspace, n, + static_cast(i), workspace_reuse_flag, is_op_reuse_mem_, false); GE_CHK_BOOL_TRUE_EXEC_WITH_LOG(mem_block == nullptr, continue, "failed to apply memory block."); CheckWorkspaceReuse(workspace_reuse_flag, i, stream_id, mem_block); } @@ -1001,7 +1123,8 @@ void MergeBlocks(std::vector &dest, std::vector &s dest[i]->AddSymbol(symbol); } for (size_t j = 0; j < src[i]->NodeTypeIndexList().size(); ++j) { - dest[i]->AddNodeTypeIndex(src[i]->NodeTypeIndexList()[j], src[i]->RealSizeList()[j]); + dest[i]->AddNodeTypeIndex(src[i]->NodeTypeIndexList()[j], src[i]->RealSizeList()[j], + src[i]->NoAlignSizeList()[j]); src[i]->deleted_block_ = true; } } @@ -1115,6 +1238,21 @@ void BlockMemAssigner::AssignContinuousBlocks() { } } +void BlockMemAssigner::ReuseBlocksByLifeTime() { + for (size_t i = 0; i < memory_blocks_.size(); ++i) { + auto parent = memory_blocks_[i]; + if (parent == nullptr || parent->deleted_block_) { + continue; + } + if (parent->reuse_mem_ && !IsPostReuse(parent)) { + parent->reuse_mem_ = false; + } + for (size_t j = i + 1; j < memory_blocks_.size(); ++j) { + parent->AddLifeReuseBlock(memory_blocks_[j]); + } + } +} + /// /// @ingroup domi_omg /// @brief traverse memory size, resize, calculate offset @@ -1129,8 +1267,8 @@ void BlockMemAssigner::ResizeMemoryBlocks() { memory_block->SetHeadOffset(mem_offset_); mem_offset_ 
+= memory_block->Size(); memory_block->SetTailOffset(mem_offset_ - 1); - GELOGI("mem_offset_ exclude zero_copy_memory is %zu.", mem_offset_); } + GELOGI("mem_offset_ exclude zero_copy_memory is %zu.", mem_offset_); } /// @@ -1142,15 +1280,18 @@ void BlockMemAssigner::ResizeMemoryBlocks() { /// @param [in] real_size memory size in need /// @return Status result /// -void SetOffsetSize(const NodeTypeIndex &node_type_index, int64_t offset, size_t size, size_t real_size) { - ge::OpDescPtr op_desc = node_type_index.node->GetOpDesc(); +void SetOffsetSize(const NodeTypeIndex &node_type, const MemoryBlock *block, size_t real_size, size_t no_align_size, + bool child_block) { + ge::OpDescPtr op_desc = node_type.node->GetOpDesc(); GE_CHK_BOOL_TRUE_EXEC_WITH_LOG(op_desc == nullptr, return, "op_desc is null."); - string graph_name = node_type_index.node->GetOwnerComputeGraph()->GetName(); + string graph_name = node_type.node->GetOwnerComputeGraph()->GetName(); vector memorys_type; + int64_t offset = block->HeadOffset(); + size_t end = node_type.life_time_end; bool has_mem_type_attr = ge::AttrUtils::GetListInt(op_desc, ATTR_NAME_OUTPUT_MEM_TYPE_LIST, memorys_type); - if (node_type_index.mem_type == kOutput) { + if (node_type.mem_type == kOutput) { vector output_list = op_desc->GetOutputOffset(); - for (auto i = static_cast(output_list.size()); i < node_type_index.index + 1; i++) { + for (auto i = static_cast(output_list.size()); i < node_type.index + 1; i++) { output_list.emplace_back(kInvalidOffset); } if (output_list.empty()) { @@ -1160,39 +1301,56 @@ void SetOffsetSize(const NodeTypeIndex &node_type_index, int64_t offset, size_t if ((op_desc->GetType() == DATA) || (op_desc->GetType() == AIPP_DATA_TYPE) || (op_desc->GetType() == MULTISHAPE) || (op_desc->GetType() == NETOUTPUT)) { - if ((output_list[node_type_index.index] == kInvalidOffset) || (output_list[node_type_index.index] < offset)) { - output_list.at(node_type_index.index) = offset; + if ((output_list[node_type.index] == kInvalidOffset) || (output_list[node_type.index] < offset)) { + output_list.at(node_type.index) = offset; } } else { // fusion: keep the original other type offset value from op_desc - bool set_out_offset = (!has_mem_type_attr) || (memorys_type[node_type_index.index] != RT_MEMORY_L1); + bool set_out_offset = (!has_mem_type_attr) || + (memorys_type.size() > node_type.index && memorys_type[node_type.index] != RT_MEMORY_L1); if (set_out_offset) { - output_list.at(node_type_index.index) = offset; + output_list.at(node_type.index) = offset; } } op_desc->SetOutputOffset(output_list); - GELOGI("[IMAS]Set %s name[%s] output[%d] offset to [%ld] streamid[%ld] size[%zu] realsize[%zu].", - graph_name.c_str(), op_desc->GetName().c_str(), node_type_index.index, offset, op_desc->GetStreamId(), size, - real_size); - } else if (node_type_index.mem_type == kWorkspace) { + } else if (node_type.mem_type == kWorkspace) { vector workspace_list; workspace_list = op_desc->GetWorkspace(); - for (auto i = static_cast(workspace_list.size()); i < node_type_index.index + 1; i++) { + for (auto i = static_cast(workspace_list.size()); i < node_type.index + 1; i++) { workspace_list.emplace_back(kInvalidOffset); } - vector workspace_memory_type; - bool has_workspace_mem_type_attr = - ge::AttrUtils::GetListInt(op_desc, TVM_ATTR_NAME_WORKSPACE_TYPE, workspace_memory_type); + vector workspace_mem_type; + bool has_workspace_mem_type = ge::AttrUtils::GetListInt(op_desc, TVM_ATTR_NAME_WORKSPACE_TYPE, workspace_mem_type); // fusion: keep the original other type offset 
value from op_desc - bool set_workspace_offset = - (!has_workspace_mem_type_attr) || (workspace_memory_type[node_type_index.index] != RT_MEMORY_L1); + bool set_workspace_offset = (!has_workspace_mem_type) || (workspace_mem_type.size() > node_type.index && + workspace_mem_type[node_type.index] != RT_MEMORY_L1); if (set_workspace_offset) { - workspace_list.at(node_type_index.index) = offset; + workspace_list.at(node_type.index) = offset; } op_desc->SetWorkspace(workspace_list); - GELOGI("[IMAS]Set %s name[%s] workspace[%u] offset to [%ld] streamid[%ld] size[%zu] realsize[%zu].", - graph_name.c_str(), op_desc->GetName().c_str(), node_type_index.index, offset, op_desc->GetStreamId(), size, - real_size); + } + GELOGI( + "[IMAS]Set %s name[%s] %s[%u] offset to [%ld] streamid[%ld] size[%zu] realsize[%zu]" + " noalignsize[%zu] life time begin[%zu] life time end[%zu] child[%d].", + graph_name.c_str(), op_desc->GetName().c_str(), node_type.GetMemType().c_str(), node_type.index, offset, + op_desc->GetStreamId(), block->Size(), real_size, no_align_size, op_desc->GetId(), end, child_block); +} + +void SetBlockOpMemOffset(MemoryBlock *block, bool child_block) { + if (block == nullptr) { + return; + } + size_t index = 0; + size_t real_size = 0; + size_t no_align_size = 0; + auto real_size_list_size = block->RealSizeList().size(); + for (const NodeTypeIndex &node_type_index : block->NodeTypeIndexList()) { + if (index < real_size_list_size) { + real_size = block->RealSizeList()[index]; + no_align_size = block->NoAlignSizeList()[index]; + } + SetOffsetSize(node_type_index, block, real_size, no_align_size, child_block); + index++; } } @@ -1206,21 +1364,16 @@ void BlockMemAssigner::SetOpMemOffset(bool is_zero_copy) { continue; } - size_t index = 0; - size_t real_size = 0; - auto real_size_list_size = memory_block->RealSizeList().size(); - for (const NodeTypeIndex &node_type_index : memory_block->NodeTypeIndexList()) { - if (index < real_size_list_size) { - real_size = memory_block->RealSizeList()[index]; - } - SetOffsetSize(node_type_index, memory_block->HeadOffset(), memory_block->Size(), real_size); - index++; + SetBlockOpMemOffset(memory_block, false); + for (MemoryBlock *child_block : memory_block->ChildBlockList()) { + SetBlockOpMemOffset(child_block, true); } } if (!is_zero_copy) { for (const NodeTypeIndex &node_type_index : zero_memory_list_) { - SetOffsetSize(node_type_index, 0, 0, 0); + MemoryBlock block(0, 0); + SetOffsetSize(node_type_index, &block, 0, 0, false); } } } @@ -1290,7 +1443,7 @@ void BlockMemAssigner::FindHeadAndTailNodesForStream(mapGetOpDesc()->GetOutputsSize(); i++) { int64_t size = 0; if (ge::TensorUtils::GetSize(*n->GetOpDesc()->GetOutputDescPtr(static_cast(i)), size) != SUCCESS) { - GELOGW("Get output size failed!"); + GELOGW("Get output size failed!"); continue; } stream_mem_map[stream_id] += size; @@ -1375,6 +1528,6 @@ void BlockMemAssigner::FindDependentStreamBetweenGraphs(const NodePtr &pre_node, bool BlockMemAssigner::CheckIsZeroMemNodeType(const string &node_type) const { return (node_type == VARIABLE) || (node_type == CONSTANT) || (node_type == MULTISHAPE) || (node_type == HCOMBROADCAST) || (node_type == HCOMALLREDUCE) || (node_type == CONSTANTOP) || - (node_type == ASSIGNADD) || (node_type == ASSIGNSUB) || (node_type == ASSIGN); + (node_type == ASSIGNADD) || (node_type == ASSIGNSUB) || (node_type == ASSIGN) || (node_type == HVDWAIT); } } // namespace ge diff --git a/src/ge/graph/build/memory/block_mem_assigner.h b/src/ge/graph/build/memory/block_mem_assigner.h index 
97e69431..7382fc72 100644 --- a/src/ge/graph/build/memory/block_mem_assigner.h +++ b/src/ge/graph/build/memory/block_mem_assigner.h @@ -31,6 +31,8 @@ #include "graph/utils/graph_utils.h" namespace ge { +const size_t kMaxLifeTime = 0xffffffff; + enum MemoryType { kOutput, kWorkspace }; struct NodeTypeIndex { @@ -40,6 +42,15 @@ struct NodeTypeIndex { ge::NodePtr node = nullptr; MemoryType mem_type = kOutput; uint32_t index = 0; + size_t life_time_end = kMaxLifeTime; + const string GetMemType() const { + if (mem_type == kOutput) { + return "output"; + } else if (mem_type == kWorkspace) { + return "workspace"; + } + return "unknown"; + } }; class MemoryBlock { @@ -55,7 +66,8 @@ class MemoryBlock { is_zero_copy_(false), block_size_(block_size), head_offset_(0), - tail_offset_(0) {} + tail_offset_(0), + child_offset_(0) {} MemoryBlock(const MemoryBlock &) = delete; @@ -66,23 +78,25 @@ class MemoryBlock { symbol_list_.clear(); } - void Init(size_t real_size, MemoryType type, const ge::NodePtr &node, uint32_t out_index) { + void Init(size_t real_size, MemoryType type, const ge::NodePtr &node, uint32_t out_index, size_t no_align_size) { real_size_list_.emplace_back(real_size); + no_align_size_list_.emplace_back(no_align_size); node_type_index_list_.emplace_back(node, type, out_index); } size_t Size() const { return block_size_; } - void SetHeadOffset(size_t offset) { head_offset_ = offset; } + void SetHeadOffset(size_t offset); - void SetTailOffset(size_t offset) { tail_offset_ = offset; } + void SetTailOffset(size_t offset); size_t HeadOffset() const { return head_offset_; } size_t TailOffset() const { return tail_offset_; } - void AddNodeTypeIndex(const NodeTypeIndex &node_type_index, size_t real_size) { + void AddNodeTypeIndex(const NodeTypeIndex &node_type_index, size_t real_size, size_t no_align_size) { node_type_index_list_.emplace_back(node_type_index); real_size_list_.emplace_back(real_size); + no_align_size_list_.emplace_back(no_align_size); } void AddSymbol(const std::string &symbol) { symbol_list_.emplace_back(symbol); } @@ -90,6 +104,8 @@ class MemoryBlock { const std::vector &NodeTypeIndexList() const { return node_type_index_list_; } const std::vector &SymbolList() const { return symbol_list_; } const std::vector &RealSizeList() const { return real_size_list_; } + const std::vector &ChildBlockList() const { return child_blocks_; } + const std::vector &NoAlignSizeList() const { return no_align_size_list_; } void Resize(); @@ -97,6 +113,14 @@ class MemoryBlock { bool IsSameLabel(std::string &first_batch_label); + void AddLifeReuseBlock(MemoryBlock *block); + + void SetLifeTimeEnd(size_t time); + + size_t GetLifeBegin(); + + size_t GetLifeEnd(); + int ref_count_; int64_t stream_id_; bool deleted_block_; @@ -109,10 +133,13 @@ class MemoryBlock { private: size_t block_size_; std::vector real_size_list_; + std::vector no_align_size_list_; size_t head_offset_; size_t tail_offset_; + size_t child_offset_; std::vector node_type_index_list_; std::vector symbol_list_; + std::vector child_blocks_; }; class BlockMemAssigner : public MemAssigner { @@ -292,8 +319,8 @@ class BlockMemAssigner : public MemAssigner { /// @return MemoryBlock* /// @author /// - MemoryBlock *ApplyMemory(size_t block_size, size_t real_size, MemoryType mem_type, const ge::NodePtr &n, - uint32_t out_index, const std::vector &workspace_reuse_flag, + MemoryBlock *ApplyMemory(size_t block_size, size_t real_size, size_t no_align_size, MemoryType mem_type, + const ge::NodePtr &n, uint32_t out_index, const std::vector 
&workspace_reuse_flag, const bool is_op_reuse_mem, const bool continuous); /// @@ -354,6 +381,17 @@ class BlockMemAssigner : public MemAssigner { bool IsOutNodeSetContinuousInput(const NodePtr &n, uint32_t out_index, std::string &peer_name, uint32_t &peer_input_index); + /// + /// @ingroup GE + /// @|+++++++++block1++++++++| |+++++++++block1++++++++| + /// @|+++++++++block1++++++++||++block2++| |+++++++++block1++++++++||++block2++| + /// @ |++block2++||++block3++| ==> |++block3++| |++block2++| + /// @ |++block3++| |++block3++| + /// @return void + /// @author + /// + void ReuseBlocksByLifeTime(); + std::vector reusable_blocks_; std::map reusable_block_counts_; @@ -380,6 +418,8 @@ class BlockMemAssigner : public MemAssigner { bool is_op_reuse_mem_ = true; + size_t life_time_; + int64_t atomic_addr_clean_id_ = 0; }; } // namespace ge diff --git a/src/ge/graph/build/memory/graph_mem_assigner.cc b/src/ge/graph/build/memory/graph_mem_assigner.cc index c3078bec..931ebba4 100644 --- a/src/ge/graph/build/memory/graph_mem_assigner.cc +++ b/src/ge/graph/build/memory/graph_mem_assigner.cc @@ -245,13 +245,14 @@ Status GraphMemoryAssigner::AssignZeroCopyMemory(size_t &mem_offset, size_t &zer memory_block->SetHeadOffset(mem_offset); mem_offset += memory_block->Size(); memory_block->SetTailOffset(mem_offset - 1); - GELOGI("mem_offset_ include zero_copy_memory is %zu.", mem_offset); } + GELOGI("mem_offset_ include zero_copy_memory is %zu.", mem_offset); // set offset for zero copy nodes priority_assigner->SetOpMemOffset(true); - zero_mem_copy_size = mem_offset - mem_offset_tmp; + memory_offset_[0].mem_offset_ = mem_offset; + GELOGI("max_mem_offset:%zu, mem_offset:%zu, zero_mem_copy_size:%zu.", mem_offset, mem_offset_tmp, zero_mem_copy_size); return SUCCESS; @@ -360,8 +361,11 @@ Status GraphMemoryAssigner::AssignContinuousInputMemory(const ge::NodePtr &node, return PARAM_INVALID;); vector output_list = peer_op_desc->GetOutputOffset(); + std::vector offsets_for_fusion = {}; + bool has_offset_attr = + AttrUtils::GetListInt(peer_op_desc, ATTR_NAME_OUTPUT_OFFSET_FOR_BUFFER_FUSION, offsets_for_fusion); if (peer_out_data_anchor->GetIdx() < static_cast(output_list.size())) { - if (continuous_input_alloc) { + if (continuous_input_alloc && !has_offset_attr) { if (in_data_anchor->GetIdx() == 0) { continuous_mem_start = output_list.at(peer_out_data_anchor->GetIdx()); } @@ -391,9 +395,7 @@ Status GraphMemoryAssigner::AssignContinuousInputMemory(const ge::NodePtr &node, } peer_op_desc->SetOutputOffset(output_list); size_t pre_mem_offset = memory_offset_[0].mem_offset_; - std::vector offsets_for_fusion = {}; - bool has_offset_attr = - AttrUtils::GetListInt(peer_op_desc, ATTR_NAME_OUTPUT_OFFSET_FOR_BUFFER_FUSION, offsets_for_fusion); + int64_t tensor_desc_size = 0; if (has_offset_attr) { if (peer_out_data_anchor->GetIdx() < static_cast(offsets_for_fusion.size())) { @@ -1232,7 +1234,7 @@ ge::Status GraphMemoryAssigner::UpdateOpInputOffset(const NodePtr &node, vector< ge::Status GraphMemoryAssigner::UpdateOpInputOffset(const NodePtr &node) const { GE_CHECK_NOTNULL(node->GetOpDesc()); vector input_list; - if (node->GetType() == HCOMBROADCAST) { + if (node->GetType() == HCOMBROADCAST || node->GetType() == HVDCALLBACKBROADCAST) { for (const auto &anchor : node->GetAllInDataAnchors()) { vector output_list; auto peer_out_anchor = anchor->GetPeerOutAnchor(); diff --git a/src/ge/graph/build/memory/var_mem_assign_util.cc b/src/ge/graph/build/memory/var_mem_assign_util.cc index a71e09b2..111adc7a 100644 --- 
a/src/ge/graph/build/memory/var_mem_assign_util.cc +++ b/src/ge/graph/build/memory/var_mem_assign_util.cc @@ -208,7 +208,7 @@ Status VarMemAssignUtil::DealVariableNode(uint32_t graph_id, const ge::NodePtr & for (const ge::OutDataAnchorPtr &var_out_data_anchor : node->GetAllOutDataAnchors()) { for (const ge::InDataAnchorPtr &dst_in_data_anchor : var_out_data_anchor->GetPeerInDataAnchors()) { ge::NodePtr dst_node = dst_in_data_anchor->GetOwnerNode(); - if (dst_node->GetType() == HCOMBROADCAST) { + if (dst_node->GetType() == HCOMBROADCAST || dst_node->GetType() == HVDCALLBACKBROADCAST) { GE_CHK_STATUS_RET(DealBroadCastNode(graph_id, dst_node, dst_in_data_anchor, node, session_id)); continue; } diff --git a/src/ge/graph/build/model_builder.cc b/src/ge/graph/build/model_builder.cc index 4c3f3ffd..a3ecc63c 100644 --- a/src/ge/graph/build/model_builder.cc +++ b/src/ge/graph/build/model_builder.cc @@ -412,7 +412,9 @@ Status ModelBuilder::BuildModelDef(ge::Model &model) { GE_CHK_BOOL_EXEC(ge::AttrUtils::SetInt(&model, ATTR_MODEL_ZERO_COPY_MEMORY_SIZE, zero_copy_mem_size_), GELOGE(FAILED, "SetInt of ATTR_MODEL_ZERO_COPY_MEMORY_SIZE failed."); return FAILED); - + GE_CHK_BOOL_EXEC(ge::AttrUtils::SetListStr(&model, ATTR_MODEL_OUT_NODES_NAME, domi::GetContext().net_out_nodes), + GELOGE(FAILED, "SetListStr of ATTR_MODEL_OUT_NODES_NAME failed."); + return FAILED); GELOGI("For model, max_mem_offset_: %zu, zero_copy_mem_size_: %zu", max_mem_offset_, zero_copy_mem_size_); string ge_core_type; @@ -651,6 +653,14 @@ Status ModelBuilder::BuildModelForGetTask(ge::Model &model) { return SUCCESS; } +Status ModelBuilder::BuildModelForGetDynShapeTask(ge::Model &model_def) { + GE_TIMESTAMP_START(BuildModelDef); + GE_CHK_STATUS_RET(BuildModelDef(model_def), "BuildModelDef failed!"); + GE_TIMESTAMP_END(BuildModelDef, "GraphBuilder::BuildModelDef"); + SetModelVersion(model_def); + return SUCCESS; +} + ge::Buffer ModelBuilder::GetWeightBuffer() const { return weight_buffer_; } Status ModelBuilder::CompileSingleOp() { GELOGD("Begin to compile single op."); diff --git a/src/ge/graph/build/model_builder.h b/src/ge/graph/build/model_builder.h index 8f0d69b4..21e611ee 100644 --- a/src/ge/graph/build/model_builder.h +++ b/src/ge/graph/build/model_builder.h @@ -50,6 +50,7 @@ class ModelBuilder { Status SaveDataToModel(ge::Model &model, ge::GeModel &ge_model); Status PreBuildModel(); Status BuildModelForGetTask(ge::Model &model_def); + ge::Status BuildModelForGetDynShapeTask(ge::Model &model_def); ge::Buffer GetWeightBuffer() const; diff --git a/src/ge/graph/build/stream_allocator.cc b/src/ge/graph/build/stream_allocator.cc index d1efa221..318134bd 100644 --- a/src/ge/graph/build/stream_allocator.cc +++ b/src/ge/graph/build/stream_allocator.cc @@ -20,12 +20,12 @@ #include "framework/common/debug/ge_log.h" #include "framework/common/fmk_error_codes.h" #include "framework/common/types.h" -#include "graph/ge_context.h" +#include "graph/build/logical_stream_allocator.h" #include "graph/common/omg_util.h" #include "graph/debug/ge_attr_define.h" +#include "graph/ge_context.h" #include "graph/utils/graph_utils.h" #include "init/gelib.h" -#include "graph/build/logical_stream_allocator.h" using std::map; using std::set; @@ -72,8 +72,7 @@ StreamAllocator::StreamAllocator(ComputeGraphPtr whole_graph, const Graph2SubGra Status StreamAllocator::AssignLogicalStreams(const std::map &max_parallel_num, bool hcom_parallel) { GELOGI("Assign logical streams start."); GE_CHECK_NOTNULL(whole_graph_); - GraphUtils::DumpGEGraph(whole_graph_, 
"BeforeAssignedLogicalStreams"); - GraphUtils::DumpGEGraphToOnnx(*whole_graph_, "BeforeAssignedLogicalStreams"); + GE_DUMP(whole_graph_, "BeforeAssignedLogicalStreams"); auto gelib = GELib::GetInstance(); if (gelib == nullptr) { @@ -91,9 +90,7 @@ Status StreamAllocator::AssignLogicalStreams(const std::map &m GELOGE(status, "Assign logical streams failed."); return status; } - - GraphUtils::DumpGEGraph(whole_graph_, "AfterAssignedLogicalStreams"); - GraphUtils::DumpGEGraphToOnnx(*whole_graph_, "AfterAssignedLogicalStreams"); + GE_DUMP(whole_graph_, "AfterAssignedLogicalStreams"); GELOGI("Assign logical streams success."); return SUCCESS; @@ -104,6 +101,7 @@ Status StreamAllocator::AssignLogicalStreams(const std::map &m Status StreamAllocator::RefreshRealStream(int64_t &stream_num, int64_t &event_num) { GELOGI("RefreshRealStream start."); GE_CHECK_NOTNULL(whole_graph_); + GE_DUMP(whole_graph_, "BeforeRefreshRealStream"); Status status = AssignSingleStream(); if (status != SUCCESS) { @@ -117,6 +115,12 @@ Status StreamAllocator::RefreshRealStream(int64_t &stream_num, int64_t &event_nu return status; } + status = SetActiveStreamsForSubgraphs(); + if (status != SUCCESS) { + GELOGE(status, "SetActiveStreamsForSubgraphs failed."); + return status; + } + status = InsertSyncEvents(); if (status != SUCCESS) { GELOGE(status, "InsertSyncEventId failed!"); @@ -161,8 +165,7 @@ Status StreamAllocator::RefreshRealStream(int64_t &stream_num, int64_t &event_nu } DumpEvents(); - GraphUtils::DumpGEGraph(whole_graph_, "RefreshRealStream"); - GraphUtils::DumpGEGraphToOnnx(*whole_graph_, "RefreshRealStream"); + GE_DUMP(whole_graph_, "AfterRefreshRealStream"); for (const NodePtr &node : whole_graph_->GetAllNodes()) { GE_CHECK_NOTNULL(node->GetOpDesc()); @@ -232,657 +235,858 @@ Status StreamAllocator::AssignSingleStream() { return SUCCESS; } -// Split the stream according to the maximum number of nodes in the stream. 
-Status StreamAllocator::SplitStreams(vector> &split_streams) { - if (enable_single_stream_ || stream_num_ == 0) { - GELOGI("The single stream option is enabled or the number of streams is 0, no need to split streams."); - return SUCCESS; - } - - // stream_node_num_vec records the number of all nodes on each stream - // added_stream_num_vec records the number of streams that each stream needs to increase - // new_stream_id_vec records the new physical stream id for each stream - vector stream_node_num_vec(stream_num_); - vector added_stream_num_vec(stream_num_); - vector new_stream_id_vec(stream_num_); - map stream_continuous_2_node_num_map; - map> stream_continuous_2_nodes_map; - map> stream_2_nodes_map; - vector pre_node_vec(stream_num_); - - int64_t last_stream_id = stream_num_ - 1; - for (auto i = 0; i <= last_stream_id; i++) { - stream_node_num_vec[i] = 0; - added_stream_num_vec[i] = 0; - new_stream_id_vec[i] = i; - pre_node_vec[i] = nullptr; +Status StreamAllocator::SetActiveStreamsByLabel() { + for (const auto &node : whole_graph_->GetAllNodes()) { + OpDescPtr op_desc = node->GetOpDesc(); + GE_CHECK_NOTNULL(op_desc); + string stream_label; + if (AttrUtils::GetStr(op_desc, ATTR_NAME_STREAM_LABEL, stream_label) && !stream_label.empty()) { + int64_t stream_id = op_desc->GetStreamId(); + if (stream_id != kInvalidStream) { + labeled_streams_[stream_label].emplace(stream_id); + } + } } - uint32_t max_stream_count = 0; - uint32_t max_task_count = 0; - GE_CHK_STATUS_RET(GetMaxStreamAndTask(false, max_stream_count, max_task_count), - "Get max stream and task count failed."); - - for (const auto &cur_node : whole_graph_->GetAllNodes()) { - GE_CHECK_NOTNULL(cur_node); - auto op_desc = cur_node->GetOpDesc(); - GE_CHECK_NOTNULL(op_desc); - int64_t stream_id = op_desc->GetStreamId(); - if (stream_id == kInvalidStream) { + for (const auto &node : whole_graph_->GetAllNodes()) { + GE_CHECK_NOTNULL(node->GetOpDesc()); + vector activated_label_list; + if (!AttrUtils::GetListStr(node->GetOpDesc(), ATTR_NAME_ACTIVE_LABEL_LIST, activated_label_list) || + activated_label_list.empty()) { continue; } - if (stream_id > last_stream_id) { - GELOGE(FAILED, "SplitStreams:streamid(%ld) > last_stream_id(%ld)", stream_id, last_stream_id); - return FAILED; - } - stream_node_num_vec[stream_id]++; - stream_2_nodes_map[stream_id].push_back(cur_node); - // The maximum number of tasks per stream. 
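// Worked example (illustrative sketch, not patch code) of the per-stream node
// cap computed by GetMaxNodeNumPerStream: the runtime reports a task budget per
// stream, and since an HCCL node emits more tasks than a normal node, the cap
// shrinks accordingly. The constants are assumed stand-ins for
// kTaskNumPerHcclNode / kTaskNumPerNormalNode.
int64_t NodeCap(uint32_t max_task_count, bool is_hccl) {
  const int64_t tasks_per_node = is_hccl ? 5 : 3;  // assumed figures
  int64_t cap = static_cast<int64_t>(max_task_count) / tasks_per_node;
  return (cap == 0) ? 1 : cap;  // the real code also clamps a zero cap to 1
}
// e.g. with max_task_count = 1024: normal nodes -> cap 341, HCCL nodes -> cap 204.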
- int64_t max_node_num_one_stream = GetMaxNodeNumPerStream(cur_node, max_task_count); - std::string continuous_stream_label; - if (HasContinuousStreamLabel(op_desc, continuous_stream_label)) { - stream_continuous_2_node_num_map[continuous_stream_label]++; - // a continuous-stream label must fit on a single stream; exceeding the cap is an error - if (stream_continuous_2_node_num_map[continuous_stream_label] > max_node_num_one_stream) { - GELOGE(FAILED, "SplitStreams:node[%s] stream_id[%ld] continuous stream label[%s] unsatisfied ", - op_desc->GetName().c_str(), stream_id, continuous_stream_label.c_str()); - return FAILED; + + vector<uint32_t> activated_stream_list; + for (string &activated_label : activated_label_list) { + specific_activated_labels_[activated_label].emplace(node); + for (int64_t activated_stream : labeled_streams_[activated_label]) { + activated_stream_list.push_back(static_cast<uint32_t>(activated_stream)); + specific_activated_streams_.emplace(activated_stream); + specific_activated_streams_nodes_map_[activated_stream].emplace(node); + GELOGI("Node %s active stream %ld by %s.", node->GetName().c_str(), activated_stream, activated_label.c_str()); } - stream_continuous_2_nodes_map[continuous_stream_label].push_back(cur_node); } - // Split the stream if it exceeds the maximum number of nodes in the stream. - if (stream_node_num_vec[stream_id] > max_node_num_one_stream) { - last_stream_id++; - GELOGI( - "stream_node_num_vec[%ld]= %ld > max_node_num_one_stream : %ld, " - "It's time to split the stream, split newly-added stream id is %ld", - stream_id, stream_node_num_vec[stream_id], max_node_num_one_stream, last_stream_id); - NodePtr pre_node = pre_node_vec[stream_id]; - stream_node_num_vec[stream_id] = 1; - // try to split off a new stream and move the nodes with the same continuous stream label off this stream - bool not_use_cur = false; - NodePtr not_cur = nullptr; - std::string cur_continuous_stream_label; - if (HasContinuousStreamLabel(op_desc, cur_continuous_stream_label)) { - // get stored nodes - auto nodes = stream_continuous_2_nodes_map[cur_continuous_stream_label]; - GE_RETURN_WITH_LOG_IF_FALSE(!nodes.empty(), "split stream with continuous stream label %s failed", - cur_continuous_stream_label.c_str()); - for (const auto &node : nodes) { - auto stored_op_desc = node->GetOpDesc(); - GE_CHECK_NOTNULL(stored_op_desc); - stored_op_desc->SetStreamId(last_stream_id); - stream_node_num_vec[stream_id]++; - } - not_use_cur = true; - not_cur = nodes.front(); - GE_CHECK_NOTNULL(not_cur); - GELOGI("split from first node %s with continuous stream label %s", not_cur->GetName().c_str(), - cur_continuous_stream_label.c_str()); - auto iter = std::find(stream_2_nodes_map[stream_id].begin(), stream_2_nodes_map[stream_id].end(), not_cur); - GE_RETURN_WITH_LOG_IF_FALSE( - (iter != stream_2_nodes_map[stream_id].end()) && (iter != stream_2_nodes_map[stream_id].begin()), - "split stream with continuous stream label %s failed", cur_continuous_stream_label.c_str()); - iter--; - pre_node = *iter; - } + GE_CHK_BOOL_EXEC(AttrUtils::SetListInt(node->GetOpDesc(), ATTR_NAME_ACTIVE_STREAM_LIST, activated_stream_list), + GELOGE(FAILED, "SetListInt failed."); + return FAILED); - added_stream_num_vec[stream_id]++; - new_stream_id_vec[stream_id] = last_stream_id; - split_streams[stream_id].emplace(last_stream_id); + return SUCCESS; +} - // Add the send/recv event to the first and last nodes of the split stream.
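// Self-contained model (illustrative only; NodeId and the global maps are
// stand-ins for ge::NodePtr and StreamAllocator's event bookkeeping) of what
// AddEventId does for the comment above: the last node kept on the old stream
// sends an event that the first node of the split-off stream receives, so the
// two halves still execute in order.
#include <cstdint>
#include <map>
#include <vector>
using NodeId = int;
static std::map<NodeId, std::vector<uint32_t>> g_send_events, g_recv_events;
void PairSplitEvent(NodeId last_on_old_stream, NodeId first_on_new_stream, uint32_t &event_id) {
  g_send_events[last_on_old_stream].push_back(event_id);   // send on the old stream
  g_recv_events[first_on_new_stream].push_back(event_id);  // recv on the new stream
  ++event_id;                                              // next free event id
}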
- if (pre_node != nullptr) { - GE_CHK_STATUS_RET(AddEventId(pre_node, not_cur, cur_node, not_use_cur), "AddEventId failed."); +Status StreamAllocator::SetActiveStreamsForSubgraphs() { + for (auto &subgraph : whole_graph_->GetAllSubgraphs()) { + GE_CHECK_NOTNULL(subgraph); + NodePtr first_active_node = nullptr; + + // Get all streams in subgraph. + set<int64_t> subgraph_streams; + for (auto &node : subgraph->GetDirectNode()) { + OpDescPtr op_desc = node->GetOpDesc(); + GE_CHECK_NOTNULL(op_desc); + // Skip streams with label + string stream_label; + if (AttrUtils::GetStr(op_desc, ATTR_NAME_STREAM_LABEL, stream_label) && !stream_label.empty()) { + continue; + } + int64_t stream_id = op_desc->GetStreamId(); + if (stream_id != kInvalidStream) { + subgraph_streams.emplace(stream_id); + GELOGI("Add stream %ld to active_stream_list of node %s of graph %s", stream_id, node->GetName().c_str(), + subgraph->GetName().c_str()); + } + bool is_first_active = false; + if (AttrUtils::GetBool(op_desc, ATTR_NAME_SUBGRAPH_FIRST_ACTIVE, is_first_active) && is_first_active) { + first_active_node = node; } } - /// If this stream has been split at least once, the following nodes on the - /// original stream must be reset to the newest stream id. - if (added_stream_num_vec[stream_id] >= 1) { - op_desc->SetStreamId(new_stream_id_vec[stream_id]); + if (first_active_node == nullptr) { + continue; } - pre_node_vec[stream_id] = cur_node; - } + subgraph_first_active_node_map_[subgraph] = first_active_node; - if (last_stream_id >= 0) { - stream_num_ = last_stream_id + 1; - } - return SUCCESS; -} + // Set active streams for StreamActive. + subgraph_streams.erase(first_active_node->GetOpDesc()->GetStreamId()); -Status StreamAllocator::AddEventId(const NodePtr &pre_node, const NodePtr &not_cur, const NodePtr &cur_node, - bool not_use_cur) { - GELOGI("Add send event %u for node %s", event_num_, pre_node->GetName().c_str()); - AddSendEventId(pre_node, event_num_); - if (not_use_cur) { - GE_CHECK_NOTNULL(not_cur); - GELOGI("Add recv event %u for node %s", event_num_, not_cur->GetName().c_str()); - AddRecvEventId(not_cur, event_num_); - } else { - GELOGI("Add recv event %u for node %s", event_num_, cur_node->GetName().c_str()); - AddRecvEventId(cur_node, event_num_); + vector<uint32_t> active_streams; + for (int64_t active_stream : subgraph_streams) { + active_streams.emplace_back(static_cast<uint32_t>(active_stream)); + specific_activated_streams_.emplace(active_stream); + } + + if (!AttrUtils::SetListInt(first_active_node->GetOpDesc(), ATTR_NAME_ACTIVE_STREAM_LIST, active_streams)) { + GELOGE(FAILED, "Set active streams for node %s failed.", first_active_node->GetName().c_str()); + return FAILED; + } } - ++event_num_; return SUCCESS; } -Status StreamAllocator::GetMaxStreamAndTask(bool huge_stream, uint32_t &max_stream_count, uint32_t &max_task_count) { - const char *buffer_optimize_on = std::getenv("BUFFER_OPTIMIZE_ON"); - if (buffer_optimize_on != nullptr) { - rtError_t ret = rtSetPlatformType(PLATFORM_MINI_V1); - if (ret != RT_ERROR_NONE) { - GELOGE(FAILED, "Get max stream and task count by rts failed."); - return FAILED; +// Insert the send/recv event id to the graph +Status StreamAllocator::InsertSyncEvents() { + for (const auto &cur_node : whole_graph_->GetAllNodes()) { + // Take the adjacent nodes, then judge whether an event needs to be inserted + for (const OutDataAnchorPtr &anchor : cur_node->GetAllOutDataAnchors()) { + for (const InDataAnchorPtr &peer_in_anchor : anchor->GetPeerInDataAnchors()) { + NodePtr next_node = peer_in_anchor->GetOwnerNode(); + Status status =
InsertOneEventInTwoNodes(cur_node, next_node); + if (status != SUCCESS) { + GELOGE(status, "InsertOneEventInTwoNodes failed!"); + return status; + } + } } - } - uint32_t stream_type = RT_NORMAL_STREAM; - if (huge_stream) { - stream_type = RT_HUGE_STREAM; + /// If the two nodes connected on the control side belong to two streams, + /// a send/recv event also needs to be added. + if (cur_node->GetOutControlAnchor() != nullptr) { + for (const AnchorPtr &peer_in_anchor : cur_node->GetOutControlAnchor()->GetPeerAnchors()) { + NodePtr next_node = peer_in_anchor->GetOwnerNode(); + Status status = InsertOneEventInTwoNodes(cur_node, next_node); + if (status != SUCCESS) { + GELOGE(status, "InsertOneEventInTwoNodes failed!"); + return status; + } + } + } } - rtError_t ret = rtGetMaxStreamAndTask(stream_type, &max_stream_count, &max_task_count); - if (ret != RT_ERROR_NONE) { - GELOGE(FAILED, "Get max stream and task count by rts failed."); - return FAILED; + + Status status = InsertEventsForSubgraph(); + if (status != SUCCESS) { + GELOGE(status, "InsertEventsBetweenSubAndParentGraphNodes failed!"); + return status; } - GELOGI("Allowed max stream count: %u, max task count per stream: %u.", max_stream_count, max_task_count); return SUCCESS; } -int64_t StreamAllocator::GetMaxNodeNumPerStream(const NodePtr &node, uint32_t max_task_count) { - int64_t max_node_num_one_stream = static_cast<int64_t>(max_task_count); - string op_type = node->GetType(); - if (IsHcclOp(op_type)) { - max_node_num_one_stream /= kTaskNumPerHcclNode; - } else { - max_node_num_one_stream /= kTaskNumPerNormalNode; +// Insert one send/recv event in two nodes +Status StreamAllocator::InsertOneEventInTwoNodes(const NodePtr &cur_node, const NodePtr &next_node) { + GE_CHECK_NOTNULL(cur_node->GetOpDesc()); + GE_CHECK_NOTNULL(next_node->GetOpDesc()); + + // No need to insert events after nodes that are not assigned a stream. + int64_t cur_stream_id = cur_node->GetOpDesc()->GetStreamId(); + if (cur_stream_id == kInvalidStream) { + GELOGD("No need to insert event after node %s.", cur_node->GetName().c_str()); + return SUCCESS; } - if (max_node_num_one_stream == 0) { - max_node_num_one_stream = 1; + + // No need to insert events between nodes in the same stream. + int64_t next_stream_id = next_node->GetOpDesc()->GetStreamId(); + if (cur_stream_id == next_stream_id) { + return SUCCESS; } - return max_node_num_one_stream; -} + if (next_stream_id == kInvalidStream) { + GELOGE(FAILED, "Stream id of next_node %s should not be %ld", next_node->GetName().c_str(), kInvalidStream); + return FAILED; + } -Status StreamAllocator::UpdateActiveStreams(const vector<set<int64_t>> &split_streams) { - UpdateLabelStreams(split_streams); + // No event needs to be inserted between the active node and the activated stream.
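// Stand-alone reduction (illustrative sketch, not patch code) of the decision
// this function makes; the activation check is abstracted into one boolean that
// the real code derives from ATTR_NAME_STREAM_LABEL and specific_activated_labels_:
//   - same stream                 -> stream order already serializes the nodes
//   - stream activated afterwards -> the activation itself serializes them
//   - any other cross-stream edge -> pair a send event with a recv event
inline bool NeedsSyncEvent(int64_t cur_stream, int64_t next_stream, bool next_activated_after_cur) {
  return (cur_stream != next_stream) && !next_activated_after_cur;
}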
+ string next_node_label; + if (AttrUtils::GetStr(next_node->GetOpDesc(), ATTR_NAME_STREAM_LABEL, next_node_label) && !next_node_label.empty()) { + auto iter = specific_activated_labels_.find(next_node_label); + if (iter != specific_activated_labels_.end()) { + for (const auto &active_node : iter->second) { + OpDescPtr active_op = active_node->GetOpDesc(); + GE_CHECK_NOTNULL(active_op); + if ((cur_stream_id == active_op->GetStreamId()) && (cur_node->GetOpDesc()->GetId() <= active_op->GetId())) { + GELOGI("No need to insert event between node %s and %s.", cur_node->GetName().c_str(), + next_node->GetName().c_str()); + return SUCCESS; + } + } + } + } - // Add send and receive events. + AddSendEventId(cur_node, event_num_); + AddRecvEventId(next_node, event_num_); + GELOGD("Insert event %u between node %s(stream %ld) and %s(stream %ld)", event_num_, cur_node->GetName().c_str(), + cur_stream_id, next_node->GetName().c_str(), next_stream_id); + + ++event_num_; + + return SUCCESS; +} + +Status StreamAllocator::InsertEventsForSubgraph() { + for (const auto &subgraph : whole_graph_->GetAllSubgraphs()) { + GE_CHECK_NOTNULL(subgraph); + for (const auto &node : subgraph->GetDirectNode()) { + auto op_desc = node->GetOpDesc(); + GE_CHECK_NOTNULL(op_desc); + bool is_subgraph_end_node = false; + if (!AttrUtils::GetBool(op_desc, ATTR_NAME_SUBGRAPH_END_NODE, is_subgraph_end_node) || !is_subgraph_end_node) { + continue; } + const auto parent_node = subgraph->GetParentNode(); + GE_CHECK_NOTNULL(parent_node); + + // Insert events between subgraph end node and parent node's out nodes + for (const auto &next_node : parent_node->GetOutAllNodes()) { + Status status = InsertOneEventInTwoNodes(node, next_node); + if (status != SUCCESS) { + GELOGE(status, "InsertOneEventInTwoNodes failed!"); + return status; + } + } + + break; + } } + return SUCCESS; +} + +// Optimize the events in the graph: delete redundant sync events according to the stream information +Status StreamAllocator::OptimizeSyncEvents() { + map<int64_t, vector<NodePtr>> stream_nodes; + + for (const auto &node : whole_graph_->GetAllNodes()) { + GE_CHECK_NOTNULL(node->GetOpDesc()); + int64_t stream_id = node->GetOpDesc()->GetStreamId(); + stream_nodes[stream_id].emplace_back(node); + } + + Status status = OptimizeBySendEvents(stream_nodes); + if (status != SUCCESS) { + GELOGE(status,
"OptimizeBySendEvents failed!"); return status; } - status = SetActiveStreamsForLoop(); + status = OptimizeByRecvEvents(stream_nodes); if (status != SUCCESS) { - GELOGE(status, "SetActiveStreamsForLoop failed!"); + GELOGE(status, "OptimizeByRecvEvents failed!"); return status; } - return SUCCESS; -} - -void StreamAllocator::UpdateLabelStreams(const vector> &split_streams) { - for (size_t i = 0; i < split_streams.size(); i++) { - auto &streams = split_streams[i]; - if (streams.empty()) { - continue; - } - if (specific_activated_streams_.count(static_cast(i)) > 0) { - specific_activated_streams_.insert(streams.begin(), streams.end()); - } - for (auto &labeled_stream : labeled_streams_) { - if (labeled_stream.second.count(static_cast(i)) > 0) { - labeled_stream.second.insert(streams.begin(), streams.end()); - break; + status = OptimizeByStreamActivate(); + if (status != SUCCESS) { + GELOGE(status, "OptimizeByStreamActivate failed!"); + return status; + } + for (auto pair : node_to_send_events_) { + if (pair.first->GetType() == STREAMSWITCH) { + for (auto event_id : pair.second) { + GELOGI("Curren switch node is %s, remove send event_id %d.", pair.first->GetName().c_str(), event_id); + RmvSendEventId(pair.first, event_id); + auto recv_node = GetNodeFromRecvEventId(event_id); + GELOGI("Curren recv_node is %s, remove recv event_id %d.", recv_node->GetName().c_str(), event_id); + RmvRecvEventId(recv_node, event_id); } } } + return SUCCESS; } -Status StreamAllocator::SetActiveStreamsForSubgraph() { - for (auto &subgraph : whole_graph_->GetAllSubgraphs()) { - GE_CHECK_NOTNULL(subgraph); - NodePtr first_active_node = nullptr; +/// Optimization scenario: one stream has multiple send events in one node, +/// and multiple nodes for recv events on another stream +/// Example: +/// Stream0 Stream1 +/// N1 - - - event - > N1 +/// \ | +/// \ v +/// - - event - > N2 +Status StreamAllocator::OptimizeBySendEvents(const map> &stream_nodes) { + for (const auto &one_pair : stream_nodes) { + // The nodes on a stream in order + const vector &nodes = one_pair.second; - // Get all streams in subgraph. - set subgraph_streams; - for (auto &node : subgraph->GetDirectNode()) { - OpDescPtr op_desc = node->GetOpDesc(); - GE_CHECK_NOTNULL(op_desc); - int64_t stream_id = op_desc->GetStreamId(); - if (stream_id != kInvalidStream) { - subgraph_streams.emplace(stream_id); - } - if (first_active_node == nullptr && node->GetType() == STREAMACTIVE) { - first_active_node = node; - } - } + map send_node_to_event_id; - if (first_active_node == nullptr) { - continue; - } + for (const auto &recv_node_ptr : nodes) { + GE_CHECK_NOTNULL(recv_node_ptr); + // Get all recv events of the current node, then traverse the event + vector recv_events; + GetRecvEventIdList(recv_node_ptr, recv_events); - // Set active streams for StreamActive. 
- subgraph_streams.erase(first_active_node->GetOpDesc()->GetStreamId()); + for (const auto &event_id : recv_events) { + NodePtr send_node_ptr = GetNodeFromSendEventId(event_id); + GE_CHECK_NOTNULL(send_node_ptr); - vector active_streams; - for (int64_t active_stream : subgraph_streams) { - active_streams.emplace_back(static_cast<uint32_t>(active_stream)); - specific_activated_streams_.emplace(active_stream); + /// A send event from this node to the current stream is already recorded, + /// so this later sync event is redundant; remove it + if (send_node_to_event_id.find(send_node_ptr) != send_node_to_event_id.end()) { + RmvSendEventId(send_node_ptr, event_id); + RmvRecvEventId(recv_node_ptr, event_id); + GELOGI("Remove event %u between node %s and node %s", event_id, send_node_ptr->GetName().c_str(), + recv_node_ptr->GetName().c_str()); + } else { + send_node_to_event_id[send_node_ptr] = event_id; + } + } + } + } - if (!AttrUtils::SetListInt(first_active_node->GetOpDesc(), ATTR_NAME_ACTIVE_STREAM_LIST, active_streams)) { - GELOGE(FAILED, "Set active streams for node %s failed.", first_active_node->GetName().c_str()); - return FAILED; - } + return SUCCESS; +} - // Remove all events after StreamActive. - vector<uint32_t> send_events; - GetSendEventIdList(first_active_node, send_events); +/// Scenario: multiple send nodes on one stream all synchronize with a single recv node on the destination stream +/// Example: +/// Stream0 Stream1 +/// N1 - - +/// | | +/// | - - event - - - +/// | | +/// V V +/// N2 - - - event - > N2 +Status StreamAllocator::OptimizeByRecvEvents(const map<int64_t, vector<NodePtr>> &stream_nodes) { + for (const auto &one_pair : stream_nodes) { + // The nodes on a stream in order + const vector<NodePtr> &nodes = one_pair.second; - for (const auto &event_id : send_events) { - NodePtr recv_node = GetNodeFromRecvEventId(event_id); - GE_CHECK_NOTNULL(recv_node); + map<NodePtr, uint32_t> recv_node_to_event_id; - RmvSendEventId(first_active_node, event_id); - RmvRecvEventId(recv_node, event_id); - GELOGI("Remove event %u between node %s and node %s", event_id, first_active_node->GetName().c_str(), - recv_node->GetName().c_str()); - } - } + for (const auto &send_node_ptr : nodes) { + GE_CHECK_NOTNULL(send_node_ptr); + // Get all send events of the current node, then traverse the event + vector<uint32_t> send_id_list; + GetSendEventIdList(send_node_ptr, send_id_list); - return SUCCESS; -} + for (const auto &event_id : send_id_list) { + NodePtr recv_node_ptr = GetNodeFromRecvEventId(event_id); + GE_CHECK_NOTNULL(recv_node_ptr); -Status StreamAllocator::SetActiveStreamsByLabel() { - for (const auto &node : whole_graph_->GetAllNodes()) { - OpDescPtr op_desc = node->GetOpDesc(); - GE_CHECK_NOTNULL(op_desc); - string stream_label; - if (AttrUtils::GetStr(op_desc, ATTR_NAME_STREAM_LABEL, stream_label) && !stream_label.empty()) { - int64_t stream_id = op_desc->GetStreamId(); - if (stream_id != kInvalidStream) { - labeled_streams_[stream_label].emplace(stream_id); + /// This recv node already waits on an earlier event from this stream; + /// the earlier sync event is redundant, so remove it and keep the latest + auto it = recv_node_to_event_id.find(recv_node_ptr); + if (it != recv_node_to_event_id.end()) { + uint32_t pre_event_id = it->second; + NodePtr pre_send_node_ptr = GetNodeFromSendEventId(pre_event_id); + GE_CHECK_NOTNULL(pre_send_node_ptr); + + RmvSendEventId(pre_send_node_ptr, pre_event_id); + RmvRecvEventId(recv_node_ptr, pre_event_id); + GELOGI("Remove event %u between node %s and node %s.", pre_event_id, pre_send_node_ptr->GetName().c_str(), + recv_node_ptr->GetName().c_str()); + } + recv_node_to_event_id[recv_node_ptr] =
event_id; } } } return SUCCESS; } Status StreamAllocator::OptimizeByStreamActivate() { auto node_to_send_events_temp = node_to_send_events_; for (const auto &node_event_id_pair : node_to_send_events_temp) { const NodePtr &send_node_ptr = node_event_id_pair.first; for (const auto &event_id : node_event_id_pair.second) { NodePtr recv_node_ptr = GetNodeFromRecvEventId(event_id); GE_CHECK_NOTNULL(recv_node_ptr); if (IsRecvNodeActivatedBySendNode(send_node_ptr, recv_node_ptr)) { RmvSendEventId(send_node_ptr, event_id); RmvRecvEventId(recv_node_ptr, event_id); GELOGI("Remove event %u between node %s and node %s.", event_id, send_node_ptr->GetName().c_str(), recv_node_ptr->GetName().c_str()); } } } return SUCCESS; } -Status StreamAllocator::SetActiveStreamsForLoop() { - vector<uint32_t> loop_active_streams; - for (int64_t stream_id = 0; stream_id < stream_num_; stream_id++) { - if (specific_activated_streams_.count(stream_id) == 0) { - loop_active_streams.emplace_back(static_cast<uint32_t>(stream_id)); +// In the situation stream(normal) -> stream(streamActivate) -> +// -> stream(streamSwitch) -> stream(streamActivate) -> stream(true or false branch), +// no event needs to be inserted between a node on stream(normal) and a node on the true/false branch stream +bool StreamAllocator::IsRecvNodeActivatedBySendNode(const NodePtr &send_node_ptr, const NodePtr &recv_node_ptr) const { + GE_CHECK_NOTNULL_EXEC(send_node_ptr->GetOpDesc(), GELOGE(FAILED, "op desc is nullptr"); return false); + GE_CHECK_NOTNULL_EXEC(recv_node_ptr->GetOpDesc(), GELOGE(FAILED, "op desc is nullptr"); return false); + auto cur_stream_id = send_node_ptr->GetOpDesc()->GetStreamId(); + if (AttrUtils::HasAttr(recv_node_ptr->GetOpDesc(), ATTR_NAME_STREAM_LABEL)) { + // find streamActivate node + auto iter = specific_activated_streams_nodes_map_.find(recv_node_ptr->GetOpDesc()->GetStreamId()); + set<NodePtr> activate_stream_nodes; + if (iter != specific_activated_streams_nodes_map_.end()) { + activate_stream_nodes = iter->second; } - // Set the stream that needs to be activated - for (const auto &node : whole_graph_->GetAllNodes()) { - GE_CHECK_NOTNULL(node->GetOpDesc()); - bool is_loop_active = false; - if (AttrUtils::GetBool(node->GetOpDesc(), ATTR_NAME_IS_LOOP_ACTIVE, is_loop_active) && is_loop_active) { - vector<string> activated_label_list; - if (!AttrUtils::GetListStr(node->GetOpDesc(), ATTR_NAME_ACTIVE_LABEL_LIST, activated_label_list) || - activated_label_list.empty()) { - GE_CHK_BOOL_EXEC(AttrUtils::SetListInt(node->GetOpDesc(), ATTR_NAME_ACTIVE_STREAM_LIST,
loop_active_streams), - GELOGE(FAILED, "SetListInt failed."); - return FAILED); - for (const auto &stream_id : loop_active_streams) { - GELOGI("Active stream %u for node: %s", stream_id, node->GetName().c_str()); + set visited_nodes{recv_node_ptr}; + while (!activate_stream_nodes.empty()) { + set activate_stream_nodes_temp; + for (const auto &activate_stream_node : activate_stream_nodes) { + GE_IF_BOOL_EXEC(activate_stream_node->GetOpDesc() == nullptr, continue); + if (visited_nodes.find(activate_stream_node) != visited_nodes.end() || + AttrUtils::HasAttr(activate_stream_node->GetOpDesc(), ATTR_NAME_IS_LOOP_ACTIVE)) { + return false; + } + visited_nodes.insert(activate_stream_node); + // nodes on the send stream that link to the StreamActive node need no event before the activated node + for (const auto &pre_activate_stream_node : activate_stream_node->GetInNodes()) { + GE_IF_BOOL_EXEC(pre_activate_stream_node->GetOpDesc() == nullptr, continue); + if (pre_activate_stream_node->GetOpDesc()->GetStreamId() == cur_stream_id && + pre_activate_stream_node->GetOpDesc()->GetId() >= send_node_ptr->GetOpDesc()->GetId()) { + return true; + } + auto in_nodes_of_pre = pre_activate_stream_node->GetInNodes(); + if (std::find(in_nodes_of_pre.begin(), in_nodes_of_pre.end(), send_node_ptr) != in_nodes_of_pre.end()) { + return true; + } + } + auto iterator = specific_activated_streams_nodes_map_.find(activate_stream_node->GetOpDesc()->GetStreamId()); + if (iterator != specific_activated_streams_nodes_map_.end()) { + auto active_nodes = iterator->second; + for (const auto &active_node : active_nodes) { + activate_stream_nodes_temp.emplace(active_node); + } } - - break; } + activate_stream_nodes = activate_stream_nodes_temp; } } - - return CheckStreamActived(); + return false; } -Status StreamAllocator::CheckStreamActived() const { - for (const auto &node : whole_graph_->GetAllNodes()) { - GE_CHECK_NOTNULL(node->GetOpDesc()); - vector active_streams; - if (AttrUtils::GetListInt(node->GetOpDesc(), ATTR_NAME_ACTIVE_STREAM_LIST, active_streams)) { - uint32_t stream_id = static_cast(node->GetOpDesc()->GetStreamId()); - auto iter = find(active_streams.begin(), active_streams.end(), stream_id); - if (iter != active_streams.end()) { - GELOGE(FAILED, "Node %s cannot active its own stream %u.", node->GetName().c_str(), stream_id); - return FAILED; - } - } +// Split the stream according to the maximum number of nodes in the stream. 
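+// Nodes are visited in topological order while a per-stream node counter is kept; once the counter +// exceeds the per-stream limit, a new stream id is allocated, the following nodes are moved onto it, +// and a send/recv event pair keeps the old and the new stream ordered.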
+Status StreamAllocator::SplitStreams(vector> &split_streams) { + if (enable_single_stream_ || stream_num_ == 0) { + GELOGI("The single stream option is enabled or the number of streams is 0, no need to split streams."); + return SUCCESS; } - return SUCCESS; -} + // stream_node_num_vec records the number of all nodes on each stream + // added_stream_num_vec records how many new streams each stream has been split into + // new_stream_id_vec records the new physical stream id for each stream + vector stream_node_num_vec(stream_num_); + vector added_stream_num_vec(stream_num_); + vector new_stream_id_vec(stream_num_); + map stream_continuous_2_node_num_map; + map> stream_continuous_2_nodes_map; + map> stream_2_nodes_map; + vector pre_node_vec(stream_num_); + + int64_t last_stream_id = stream_num_ - 1; + for (int64_t i = 0; i <= last_stream_id; i++) { + stream_node_num_vec[i] = 0; + added_stream_num_vec[i] = 0; + new_stream_id_vec[i] = i; + pre_node_vec[i] = nullptr; + } + + uint32_t max_stream_count = 0; + uint32_t max_task_count = 0; + GE_CHK_STATUS_RET(GetMaxStreamAndTask(false, max_stream_count, max_task_count), + "Get max stream and task count failed."); -// Insert the send/recv event id to the graph -Status StreamAllocator::InsertSyncEvents() { for (const auto &cur_node : whole_graph_->GetAllNodes()) { - // Take the adjacent points, then judge whether need to insert the event - for (const OutDataAnchorPtr &anchor : cur_node->GetAllOutDataAnchors()) { - for (const InDataAnchorPtr &peer_in_anchor : anchor->GetPeerInDataAnchors()) { - NodePtr next_node = peer_in_anchor->GetOwnerNode(); - Status status = InsertOneEventInTwoNodes(cur_node, next_node); - if (status != SUCCESS) { - GELOGE(status, "InsertOneEventInTwoNodes failed!"); - return status; - } + GE_CHECK_NOTNULL(cur_node); + auto op_desc = cur_node->GetOpDesc(); + GE_CHECK_NOTNULL(op_desc); + int64_t stream_id = op_desc->GetStreamId(); + if (stream_id == kInvalidStream) { + continue; + } + if (stream_id > last_stream_id) { + GELOGE(FAILED, "SplitStreams:streamid(%ld) > last_stream_id(%ld)", stream_id, last_stream_id); + return FAILED; + } + stream_node_num_vec[stream_id]++; + stream_2_nodes_map[stream_id].push_back(cur_node); + // The maximum number of nodes allowed on one stream for this node's op type. + int64_t max_node_num_one_stream = GetMaxNodeNumPerStream(cur_node, max_task_count); + std::string continuous_stream_label; + if (HasContinuousStreamLabel(op_desc, continuous_stream_label)) { + stream_continuous_2_node_num_map[continuous_stream_label]++; + // Nodes sharing one continuous stream label must fit on a single stream. + if (stream_continuous_2_node_num_map[continuous_stream_label] > max_node_num_one_stream) { + GELOGE(FAILED, "SplitStreams:node[%s] stream_id[%ld] node num under continuous stream label[%s] exceeds the per-stream limit", + op_desc->GetName().c_str(), stream_id, continuous_stream_label.c_str()); + return FAILED; + } + stream_continuous_2_nodes_map[continuous_stream_label].push_back(cur_node); } - - /// If the two nodes of the control side belong to two streams, - /// you also need to add the send/recv event. - if (cur_node->GetOutControlAnchor() != nullptr) { - for (const AnchorPtr &peer_in_anchor : cur_node->GetOutControlAnchor()->GetPeerAnchors()) { - NodePtr next_node = peer_in_anchor->GetOwnerNode(); - Status status = InsertOneEventInTwoNodes(cur_node, next_node); - if (status != SUCCESS) { - GELOGE(status, "InsertOneEventInTwoNodes failed!"); - return status; + // Split the stream if it exceeds the maximum number of nodes in the stream. 
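+ // Subgraph owners, first-active nodes and label ops must stay on their current stream, + // so NeedSplitNewStream also checks these conditions before a split is triggered.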
+ if (NeedSplitNewStream(stream_node_num_vec[stream_id], max_node_num_one_stream, op_desc)) { + last_stream_id++; + GELOGI( + "stream_node_num_vec[%ld] = %ld > max_node_num_one_stream(%ld), " + "split the stream, the newly added stream id is %ld", + stream_id, stream_node_num_vec[stream_id], max_node_num_one_stream, last_stream_id); + NodePtr pre_node = pre_node_vec[stream_id]; + stream_node_num_vec[stream_id] = 1; + // try to split a new stream and move the nodes with the same continuous stream label off this stream + bool not_use_cur = false; + NodePtr not_cur = nullptr; + std::string cur_continuous_stream_label; + if (HasContinuousStreamLabel(op_desc, cur_continuous_stream_label)) { + // get stored nodes + auto nodes = stream_continuous_2_nodes_map[cur_continuous_stream_label]; + GE_RETURN_WITH_LOG_IF_FALSE(!nodes.empty(), "split stream with continuous stream label %s failed", + cur_continuous_stream_label.c_str()); + for (const auto &node : nodes) { + auto stored_op_desc = node->GetOpDesc(); + GE_CHECK_NOTNULL(stored_op_desc); + stored_op_desc->SetStreamId(last_stream_id); + stream_node_num_vec[stream_id]++; + } + not_use_cur = true; + not_cur = nodes.front(); + GE_CHECK_NOTNULL(not_cur); + GELOGI("split from first node %s with continuous stream label %s", not_cur->GetName().c_str(), + cur_continuous_stream_label.c_str()); + auto iter = std::find(stream_2_nodes_map[stream_id].begin(), stream_2_nodes_map[stream_id].end(), not_cur); + GE_RETURN_WITH_LOG_IF_FALSE( + (iter != stream_2_nodes_map[stream_id].end()) && (iter != stream_2_nodes_map[stream_id].begin()), + "split stream with continuous stream label %s failed", cur_continuous_stream_label.c_str()); + iter--; + pre_node = *iter; + } + + added_stream_num_vec[stream_id]++; + new_stream_id_vec[stream_id] = last_stream_id; + split_streams[stream_id].emplace(last_stream_id); + node_split_stream_map_[cur_node] = last_stream_id; + + // Add a send/recv event pair between the last node left on the old stream and the first node of the new stream. + if (pre_node != nullptr) { + GE_CHK_STATUS_RET(AddEventId(pre_node, not_cur, cur_node, not_use_cur), "AddEventId failed."); + } + } + + /// If this stream has been split, the subsequent nodes on the same + /// logical stream must be moved to the latest new stream id. + if (added_stream_num_vec[stream_id] >= 1) { + op_desc->SetStreamId(new_stream_id_vec[stream_id]); + } + + pre_node_vec[stream_id] = cur_node; } + if (last_stream_id >= 0) { + stream_num_ = last_stream_id + 1; + } return SUCCESS; } -// Insert one send/recv event in two nodes -Status StreamAllocator::InsertOneEventInTwoNodes(const NodePtr &cur_node, const NodePtr &next_node) { - GE_CHECK_NOTNULL(cur_node->GetOpDesc()); - GE_CHECK_NOTNULL(next_node->GetOpDesc()); +bool StreamAllocator::NeedSplitNewStream(int64_t stream_node_num, int64_t max_node_num_one_stream, + const OpDescPtr &op_desc) const { + const set label_op_types({LABELSET, LABELGOTO, LABELGOTOEX, LABELSWITCH, LABELSWITCHBYINDEX}); + bool is_first_active_node = false; + (void)AttrUtils::GetBool(op_desc, ATTR_NAME_SUBGRAPH_FIRST_ACTIVE, is_first_active_node); + return (stream_node_num > max_node_num_one_stream && op_desc->GetSubgraphInstanceNames().empty() && + !is_first_active_node && label_op_types.count(op_desc->GetType()) == 0); +} - // No need to insert events after node that do not assign streams. 
- int64_t cur_stream_id = cur_node->GetOpDesc()->GetStreamId(); - if (cur_stream_id == kInvalidStream) { - GELOGD("No need to insert event after node %s.", cur_node->GetName().c_str()); - return SUCCESS; +Status StreamAllocator::UpdateActiveStreams(const vector> &split_streams) { + UpdateLabelStreams(split_streams); + + for (auto &node : whole_graph_->GetAllNodes()) { + if ((node->GetType() == STREAMSWITCH) || (node->GetType() == STREAMSWITCHN)) { + if (InsertActiveNodesAfterSwitch(node) != SUCCESS) { + GELOGE(FAILED, "Insert active nodes after switch node failed."); + return FAILED; + } + } else { + vector active_streams; + GE_CHECK_NOTNULL(node->GetOpDesc()); + if (AttrUtils::GetListInt(node->GetOpDesc(), ATTR_NAME_ACTIVE_STREAM_LIST, active_streams)) { + vector new_active_streams = active_streams; + for (const uint32_t logical_stream : active_streams) { + if (static_cast(logical_stream) >= split_streams.size()) { + GELOGE(FAILED, "logical stream is out of range."); + return FAILED; + } + const set &new_split_streams = split_streams[logical_stream]; + if (!new_split_streams.empty()) { + for (int64_t split_stream : new_split_streams) { + new_active_streams.emplace_back(static_cast(split_stream)); + GELOGI("Add stream %ld to active_stream_list of node %s of graph %s", split_stream, + node->GetName().c_str(), node->GetOwnerComputeGraph()->GetName().c_str()); + } + } + } + if (!AttrUtils::SetListInt(node->GetOpDesc(), ATTR_NAME_ACTIVE_STREAM_LIST, new_active_streams)) { + GELOGE(FAILED, "Set active streams for node %s failed.", node->GetName().c_str()); + return FAILED; + } + } + } } - // No need to insert events between nodes in the same stream. - int64_t next_stream_id = next_node->GetOpDesc()->GetStreamId(); - if (cur_stream_id == next_stream_id) { - return SUCCESS; + Status status = UpdateActiveStreamsForSubgraphs(); + if (status != SUCCESS) { + GELOGE(status, "UpdateActiveStreamsForSubgraphs failed!"); + return status; } - if (next_stream_id == kInvalidStream) { - GELOGE(FAILED, "Stream id of next_node %s should not be %ld", next_node->GetName().c_str(), kInvalidStream); - return FAILED; + status = SetActiveStreamsForLoop(); + if (status != SUCCESS) { + GELOGE(status, "SetActiveStreamsForLoop failed!"); + return status; } - // No event needs to be inserted between the active node and the activated stream. 
- string next_node_label; - if (AttrUtils::GetStr(next_node->GetOpDesc(), ATTR_NAME_STREAM_LABEL, next_node_label) && !next_node_label.empty()) { - auto iter = specific_activated_labels_.find(next_node_label); - if (iter != specific_activated_labels_.end()) { - for (const auto &active_node : iter->second) { - OpDescPtr active_op = active_node->GetOpDesc(); - GE_CHECK_NOTNULL(active_op); - if ((cur_stream_id == active_op->GetStreamId()) && (cur_node->GetOpDesc()->GetId() <= active_op->GetId())) { - GELOGI("No need to insert event between node %s and %s.", cur_node->GetName().c_str(), - next_node->GetName().c_str()); - return SUCCESS; - } + return SUCCESS; +} + +void StreamAllocator::UpdateLabelStreams(const vector> &split_streams) { + for (size_t i = 0; i < split_streams.size(); i++) { + auto &streams = split_streams[i]; + if (streams.empty()) { + continue; + } + if (specific_activated_streams_.count(static_cast(i)) > 0) { + specific_activated_streams_.insert(streams.begin(), streams.end()); + } + for (auto &labeled_stream : labeled_streams_) { + if (labeled_stream.second.count(static_cast(i)) > 0) { + labeled_stream.second.insert(streams.begin(), streams.end()); + break; } } } +} - // Add send and receive events. - AddSendEventId(cur_node, event_num_); - AddRecvEventId(next_node, event_num_); - GELOGD("Insert event %u between node %s(stream %ld) and %s(stream %ld)", event_num_, cur_node->GetName().c_str(), - cur_stream_id, next_node->GetName().c_str(), next_stream_id); +Status StreamAllocator::InsertActiveNodesAfterSwitch(NodePtr &switch_node) { + vector active_nodes; + if (InsertActiveNodesAfterSwitch(switch_node, active_nodes) != SUCCESS) { + GELOGE(FAILED, "Insert active nodes after node %s failed.", switch_node->GetName().c_str()); + return FAILED; + } + if (active_nodes.empty()) { + return SUCCESS; + } + vector stream_ids; + for (auto &active_node : active_nodes) { + GE_CHECK_NOTNULL(active_node->GetOpDesc()); + active_node->GetOpDesc()->SetStreamId(stream_num_); + stream_ids.emplace_back(stream_num_); + specific_activated_streams_.emplace(stream_num_); + stream_num_++; + } + auto op_desc = switch_node->GetOpDesc(); + GE_CHECK_NOTNULL(op_desc); - ++event_num_; + if (!AttrUtils::SetListInt(op_desc, ATTR_NAME_ACTIVE_STREAM_LIST, stream_ids)) { + GELOGE(FAILED, "SetListInt failed."); + return FAILED; + } return SUCCESS; } -// Optimize the event in the graph, delete the redundant sync event according to the stream information -Status StreamAllocator::OptimizeSyncEvents() { - map> stream_nodes; - - for (const auto &node : whole_graph_->GetAllNodes()) { - GE_CHECK_NOTNULL(node->GetOpDesc()); - int64_t stream_id = node->GetOpDesc()->GetStreamId(); - stream_nodes[stream_id].emplace_back(node); +Status StreamAllocator::InsertActiveNodesAfterSwitch(NodePtr &switch_node, vector &active_nodes) { + GE_CHECK_NOTNULL(switch_node); + OpDescPtr switch_desc = switch_node->GetOpDesc(); + GE_CHECK_NOTNULL(switch_desc); + vector ori_active_label_list; + if (!AttrUtils::GetListStr(switch_desc, ATTR_NAME_ACTIVE_LABEL_LIST, ori_active_label_list) || + ori_active_label_list.empty()) { + GELOGE(INTERNAL_ERROR, "Get active label list of switch %s failed.", switch_node->GetName().c_str()); + return INTERNAL_ERROR; } - Status status = OptimizeBySendEvents(stream_nodes); - if (status != SUCCESS) { - GELOGE(status, "OptimizeBySendEvents failed!"); - return status; + vector active_label_list; + vector added_active_nodes; + if (AddActiveNodes(switch_node, ori_active_label_list, active_label_list, added_active_nodes) 
!= SUCCESS) { + GELOGE(FAILED, "Add active nodes after node %s failed.", switch_node->GetName().c_str()); + return FAILED; } - status = OptimizeByRecvEvents(stream_nodes); - if (status != SUCCESS) { - GELOGE(status, "OptimizeByRecvEvents failed!"); - return status; + if (SetActiveLabelList(switch_node, active_label_list) != SUCCESS) { + GELOGE(FAILED, "set active label list failed"); + return FAILED; } - status = OptimizeByStreamActivate(); - if (status != SUCCESS) { - GELOGE(status, "OptimizeByStreamActivate failed!"); - return status; + if (added_active_nodes.empty()) { + return SUCCESS; } + for (auto &active_node : added_active_nodes) { + GE_CHECK_NOTNULL(switch_node->GetOutControlAnchor()); + if (switch_node->GetOutControlAnchor()->LinkTo(active_node->GetInControlAnchor()) != GRAPH_SUCCESS) { + GELOGE(FAILED, "Link %s to %s failed.", switch_node->GetName().c_str(), active_node->GetName().c_str()); + return FAILED; + } + active_nodes.emplace_back(active_node); + } return SUCCESS; } -/// Optimization scenario: one stream has multiple send events in one node, -/// and multiple nodes for recv events on another stream -/// Example: -/// Stream0 Stream1 -/// N1 - - - event - > N1 -/// \ | -/// \ v -/// - - event - > N2 -Status StreamAllocator::OptimizeBySendEvents(const map> &stream_nodes) { - for (const auto &one_pair : stream_nodes) { - // The nodes on a stream in order - const vector &nodes = one_pair.second; - - map send_node_to_event_id; - - for (const auto &recv_node_ptr : nodes) { - GE_CHECK_NOTNULL(recv_node_ptr); - // Get all recv events of the current node, then traverse the event - vector recv_events; - GetRecvEventIdList(recv_node_ptr, recv_events); +Status StreamAllocator::UpdateActiveStreamsForSubgraphs() const { + // Update active stream list for active nodes + for (auto &node_stream_pair : node_split_stream_map_) { + auto node = node_stream_pair.first; + auto subgraph = node->GetOwnerComputeGraph(); + if (subgraph->GetParentNode() == nullptr) { + continue; + } + // Skip streams with label + GE_CHECK_NOTNULL(node->GetOpDesc()); + string stream_label; + if (AttrUtils::GetStr(node->GetOpDesc(), ATTR_NAME_STREAM_LABEL, stream_label) && !stream_label.empty()) { + continue; + } + auto it = subgraph_first_active_node_map_.find(subgraph); + if (it == subgraph_first_active_node_map_.end()) { + continue; + } + const auto &active_node = it->second; + GE_CHECK_NOTNULL(active_node); + auto op_desc = active_node->GetOpDesc(); + GE_CHECK_NOTNULL(op_desc); + vector active_streams; + (void)AttrUtils::GetListInt(op_desc, ATTR_NAME_ACTIVE_STREAM_LIST, active_streams); + set new_active_streams(active_streams.begin(), active_streams.end()); + new_active_streams.emplace(static_cast(node_stream_pair.second)); + active_streams.assign(new_active_streams.begin(), new_active_streams.end()); + if (!AttrUtils::SetListInt(op_desc, ATTR_NAME_ACTIVE_STREAM_LIST, active_streams)) { + GELOGE(FAILED, "Set active streams for node %s failed.", active_node->GetName().c_str()); + return FAILED; + } + } - for (const auto &event_id : recv_events) { - NodePtr send_node_ptr = GetNodeFromSendEventId(event_id); - GE_CHECK_NOTNULL(send_node_ptr); + return SUCCESS; +} - /// If the record to the stream is found in the map, - /// and the recv node is the node, then remove sync event - if (send_node_to_event_id.find(send_node_ptr) != send_node_to_event_id.end()) { - RmvSendEventId(send_node_ptr, event_id); - RmvRecvEventId(recv_node_ptr, event_id); - GELOGI("Remove event %u between node %s and node %s", event_id, 
send_node_ptr->GetName().c_str(), - recv_node_ptr->GetName().c_str()); - } else { - send_node_to_event_id[send_node_ptr] = event_id; +Status StreamAllocator::SetActiveStreamsForLoop() { + vector loop_active_streams; + for (int64_t stream_id = 0; stream_id < stream_num_; stream_id++) { + if (specific_activated_streams_.count(stream_id) == 0) { + loop_active_streams.emplace_back(static_cast(stream_id)); + } + } + // Set the stream that needs to be activated + for (const auto &node : whole_graph_->GetAllNodes()) { + GE_CHECK_NOTNULL(node->GetOpDesc()); + bool is_loop_active = false; + if (AttrUtils::GetBool(node->GetOpDesc(), ATTR_NAME_IS_LOOP_ACTIVE, is_loop_active) && is_loop_active) { + vector activated_label_list; + if (!AttrUtils::GetListStr(node->GetOpDesc(), ATTR_NAME_ACTIVE_LABEL_LIST, activated_label_list) || + activated_label_list.empty()) { + GE_CHK_BOOL_EXEC(AttrUtils::SetListInt(node->GetOpDesc(), ATTR_NAME_ACTIVE_STREAM_LIST, loop_active_streams), + GELOGE(FAILED, "SetListInt failed."); + return FAILED); + for (const auto &stream_id : loop_active_streams) { + GELOGI("Active stream %u for node: %s", stream_id, node->GetName().c_str()); } + + break; } } } - return SUCCESS; + return CheckStreamActived(); } -/// Scenario: multiple send nodes on a stream sent to a single recv node on the destination stream -/// Example: -/// Stream0 Stream1 -/// N1 - - -/// | | -/// | - - event - - - -/// | | -/// V V -/// N2 - - - event - > N2 -Status StreamAllocator::OptimizeByRecvEvents(const map> &stream_nodes) { - for (const auto &one_pair : stream_nodes) { - // The nodes on a stream in order - const vector &nodes = one_pair.second; - - map recv_node_to_event_id; +Status StreamAllocator::CheckStreamActived() const { + for (const auto &node : whole_graph_->GetAllNodes()) { + GE_CHECK_NOTNULL(node->GetOpDesc()); + vector active_streams; + if (AttrUtils::GetListInt(node->GetOpDesc(), ATTR_NAME_ACTIVE_STREAM_LIST, active_streams)) { + uint32_t stream_id = static_cast(node->GetOpDesc()->GetStreamId()); + auto iter = find(active_streams.begin(), active_streams.end(), stream_id); + if (iter != active_streams.end()) { + GELOGE(FAILED, "Node %s cannot active its own stream %u.", node->GetName().c_str(), stream_id); + return FAILED; + } + } + } - for (const auto &send_node_ptr : nodes) { - GE_CHECK_NOTNULL(send_node_ptr); - // Get all send events of the current node, then traverse the event - vector send_id_list; - GetSendEventIdList(send_node_ptr, send_id_list); + return SUCCESS; +} - for (const auto &event_id : send_id_list) { - NodePtr recv_node_ptr = GetNodeFromRecvEventId(event_id); - GE_CHECK_NOTNULL(recv_node_ptr); +// Add active entry stream for special env. +Status StreamAllocator::AddActiveEntryStream() { + auto gelib = GELib::GetInstance(); + bool head_stream = (gelib == nullptr) ? false : gelib->HeadStream(); + GELOGI("Configured head stream: %u", head_stream); + if (!head_stream) { + return SUCCESS; + } - /// If the record to the stream is found in the map, - /// and the send node is the node, then remove sync event - auto it = recv_node_to_event_id.find(recv_node_ptr); - if (it != recv_node_to_event_id.end()) { - uint32_t pre_event_id = it->second; - NodePtr pre_send_node_ptr = GetNodeFromSendEventId(pre_event_id); - GE_CHECK_NOTNULL(pre_send_node_ptr); + // Collect streams active by StreamSwitch/StreamActive node. 
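+ // A stream that is already activated by a StreamSwitch/StreamActive node must not be + // re-activated by the entry stream, so such stream ids are excluded from the list below.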
+ std::set deactive_stream; + for (ge::NodePtr &node : whole_graph_->GetAllNodes()) { + GE_CHECK_NOTNULL(node->GetOpDesc()); + Status ret = CollectDeactiveStream(node->GetOpDesc(), deactive_stream); + if (ret != SUCCESS) { + return ret; + } + } - RmvSendEventId(pre_send_node_ptr, pre_event_id); - RmvRecvEventId(recv_node_ptr, pre_event_id); - GELOGI("Remove event %u between node %s and node %s.", event_id, pre_send_node_ptr->GetName().c_str(), - recv_node_ptr->GetName().c_str()); - } - recv_node_to_event_id[recv_node_ptr] = event_id; - } + // Collect default active stream, Add to active entry stream. + std::vector active_stream_list; + for (int64_t stream_id = 0; stream_id < stream_num_; ++stream_id) { + if (deactive_stream.count(stream_id) == 0) { + active_stream_list.push_back(stream_id); } } - return SUCCESS; + int64_t new_stream_id = stream_num_; + stream_num_++; + return InsertActiveEntryStream(active_stream_list, new_stream_id); } -// In situation : stream(normal) -> stream(streamActivate)-> -// -> stream(streamSwitch) -> stream(streamActivate) -> stream(stream true or false) -// No need to insert an event between node in stream(normal) and node in stream(stream true or false) -bool StreamAllocator::IsRecvNodeActivatedBySendNode(const NodePtr &send_node_ptr, const NodePtr &recv_node_ptr) const { - GE_CHECK_NOTNULL_EXEC(send_node_ptr->GetOpDesc(), GELOGE(FAILED, "op desc is nullptr"); return false); - GE_CHECK_NOTNULL_EXEC(recv_node_ptr->GetOpDesc(), GELOGE(FAILED, "op desc is nullptr"); return false); - auto cur_stream_id = send_node_ptr->GetOpDesc()->GetStreamId(); - if (AttrUtils::HasAttr(recv_node_ptr->GetOpDesc(), ATTR_NAME_STREAM_LABEL)) { - // find streamActivate node - auto iter = specific_activated_streams_nodes_map_.find(recv_node_ptr->GetOpDesc()->GetStreamId()); - set activate_stream_nodes; - if (iter != specific_activated_streams_nodes_map_.end()) { - activate_stream_nodes = iter->second; +// Collect deactive stream from flowctrl op. +Status StreamAllocator::CollectDeactiveStream(const OpDescPtr &op_desc, std::set &deactive_streams) const { + GE_CHECK_NOTNULL(op_desc); + std::string op_type = op_desc->GetType(); + if (op_type == STREAMSWITCH) { + std::vector active_stream_list; + // If GetListInt fail, active_stream_list is empty. 
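+ // The size check below then rejects a switch that does not activate exactly kMaxSwitchStreamNum + // streams, and the activated (true) branch stream is recorded as already driven.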
+ (void)ge::AttrUtils::GetListInt(op_desc, ATTR_NAME_ACTIVE_STREAM_LIST, active_stream_list); + if (active_stream_list.size() != kMaxSwitchStreamNum) { + GELOGE(INTERNAL_ERROR, "Stream num of switch true branch must be %u.", kMaxSwitchStreamNum); + return INTERNAL_ERROR; } - set visited_nodes{recv_node_ptr}; - while (!activate_stream_nodes.empty()) { - set activate_stream_nodes_temp; - for (const auto &activate_stream_node : activate_stream_nodes) { - GE_IF_BOOL_EXEC(activate_stream_node->GetOpDesc() == nullptr, continue); - if (visited_nodes.find(activate_stream_node) != visited_nodes.end() || - AttrUtils::HasAttr(activate_stream_node->GetOpDesc(), ATTR_NAME_IS_LOOP_ACTIVE)) { - return false; - } - visited_nodes.insert(activate_stream_node); - // nodes in stream link to streamActivate no need to add event before activated node - for (const auto &pre_activate_stream_node : activate_stream_node->GetInNodes()) { - GE_IF_BOOL_EXEC(pre_activate_stream_node->GetOpDesc() == nullptr, continue); - if (pre_activate_stream_node->GetOpDesc()->GetStreamId() == cur_stream_id && - pre_activate_stream_node->GetOpDesc()->GetId() >= send_node_ptr->GetOpDesc()->GetId()) { - return true; - } - auto in_nodes_of_pre = pre_activate_stream_node->GetInNodes(); - if (std::find(in_nodes_of_pre.begin(), in_nodes_of_pre.end(), send_node_ptr) != in_nodes_of_pre.end()) { - return true; - } - } - auto iterator = specific_activated_streams_nodes_map_.find(activate_stream_node->GetOpDesc()->GetStreamId()); - if (iterator != specific_activated_streams_nodes_map_.end()) { - auto active_nodes = iterator->second; - for (const auto &active_node : active_nodes) { - activate_stream_nodes_temp.emplace(active_node); - } - } + + deactive_streams.insert(active_stream_list[0]); + GELOGI("Flowctrl_op node:%s, flowctrl stream id:%u.", op_desc->GetName().c_str(), active_stream_list[0]); + } else if (op_type == STREAMACTIVE) { + if (op_desc->HasAttr(ATTR_NAME_SWITCH_BRANCH_NODE_LABEL)) { + std::vector active_stream_list; + if (!AttrUtils::GetListInt(op_desc, ATTR_NAME_ACTIVE_STREAM_LIST, active_stream_list)) { + GELOGE(INTERNAL_ERROR, "StreamActiveOp get attr ACTIVE_STREAM fail."); + return INTERNAL_ERROR; + } + + for (uint32_t deactive_stream : active_stream_list) { + deactive_streams.insert(deactive_stream); + GELOGI("Flowctrl_op node:%s, flowctrl stream id:%u.", op_desc->GetName().c_str(), deactive_stream); } - activate_stream_nodes = activate_stream_nodes_temp; } } - return false; + + return SUCCESS; } -Status StreamAllocator::OptimizeByStreamActivate() { - auto node_to_send_events_temp = node_to_send_events_; - for (const auto &node_event_id_pair : node_to_send_events_temp) { - const NodePtr &send_node_ptr = node_event_id_pair.first; - for (const auto &event_id : node_event_id_pair.second) { - NodePtr recv_node_ptr = GetNodeFromRecvEventId(event_id); - GE_CHECK_NOTNULL(recv_node_ptr); - if (IsRecvNodeActivatedBySendNode(send_node_ptr, recv_node_ptr)) { - RmvSendEventId(send_node_ptr, event_id); - RmvRecvEventId(recv_node_ptr, event_id); - GELOGI("Remove event %u between node %s and node %s.", event_id, send_node_ptr->GetName().c_str(), - recv_node_ptr->GetName().c_str()); - } - } +// Insert StreamActive Op for Entry Stream. 
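+// The node is added at the graph front on a dedicated new stream and activates every +// stream id collected in active_streams.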
+Status StreamAllocator::InsertActiveEntryStream(const std::vector &active_streams, int64_t stream_id) { + string node_name = "ActiveEntryStream_" + string(STREAMACTIVE); + OpDescPtr op_desc = ge::MakeShared(node_name, STREAMACTIVE); + if (op_desc == nullptr) { + GELOGE(FAILED, "Failed to new opdesc."); + return FAILED; + } + GELOGI("Create StreamActive op:%s.", op_desc->GetName().c_str()); + + GE_CHK_BOOL_EXEC( + AttrUtils::SetListStr(op_desc, ATTR_NAME_DATA_DUMP_ORIGIN_OP_NAMES, std::move(std::vector())), + GELOGE(FAILED, "SetListStr failed."); + return FAILED); + + NodePtr active_node = whole_graph_->AddNodeFront(op_desc); + GE_IF_BOOL_EXEC(active_node == nullptr, + GELOGE(FAILED, "Create StreamActive op: %s failed.", op_desc->GetName().c_str()); + return INTERNAL_ERROR); + GE_CHECK_NOTNULL(active_node->GetOpDesc()); + // Add one stream for ActiveEntryStream Task. + active_node->GetOpDesc()->SetStreamId(stream_id); + + GE_CHK_BOOL_EXEC(AttrUtils::SetBool(op_desc, "is_aicpu_stream", true), GELOGE(FAILED, "SetBool failed."); + return FAILED); + GE_CHK_BOOL_EXEC(AttrUtils::SetListInt(active_node->GetOpDesc(), ATTR_NAME_ACTIVE_STREAM_LIST, active_streams), + GELOGE(FAILED, "SetListInt failed."); + return FAILED); + + std::vector group_names; + GE_CHK_BOOL_EXEC(AttrUtils::SetListStr(active_node->GetOpDesc(), ATTR_NAME_SWITCH_BRANCH_NODE_LABEL, group_names), + GELOGE(FAILED, "SetListStr failed."); + return FAILED); + return SUCCESS; } @@ -987,38 +1191,104 @@ Status StreamAllocator::InsertSyncEventNodes() { return status; } - GELOGI("Insert send event %u after node: %s", event_id, node->GetName().c_str()); + GELOGI("Insert send event %u after node: %s", event_id, node->GetName().c_str()); } } + + Status status = ReorderEventNodes(); + if (status != SUCCESS) { + GELOGE(status, "Graph ReorderEventNodes failed"); + return status; + } + + return SUCCESS; +} + +Status StreamAllocator::ReorderEventNodes() const { + Status status = whole_graph_->InsertEventNodes(); + if (status != SUCCESS) { + GELOGE(status, "Whole graph InsertEventNodes failed"); + return status; + } + for (const auto &subgraph : whole_graph_->GetAllSubgraphs()) { + status = subgraph->InsertEventNodes(); + if (status != SUCCESS) { + GELOGE(status, "Subgraph %s InsertEventNodes failed", subgraph->GetName().c_str()); + return status; + } + } + return SUCCESS; +} + +void StreamAllocator::DumpEvents() { + map> after_refresh_stream_nodes; + for (const auto &node : whole_graph_->GetAllNodes()) { + GE_IF_BOOL_EXEC(node->GetOpDesc() == nullptr, continue); + int64_t stream_id = node->GetOpDesc()->GetStreamId(); + after_refresh_stream_nodes[stream_id].emplace_back(node); + } + + for (const auto &one_pair : after_refresh_stream_nodes) { + int64_t stream_id = one_pair.first; + GELOGI("After RefreshRealStream: stream %ld.", stream_id); + + for (const auto &node : one_pair.second) { + string send_event_str; + for (const auto &send_event_id : node_to_send_events_[node]) { + send_event_str += " " + to_string(send_event_id); + } + if (!send_event_str.empty()) { + GELOGI("node: %s, send events: %s", node->GetName().c_str(), send_event_str.c_str()); + } + + string recv_event_str; + for (const auto &recv_event_id : node_to_recv_events_[node]) { + recv_event_str += " " + to_string(recv_event_id); + } + if (!recv_event_str.empty()) { + GELOGI("node: %s, recv events: %s", node->GetName().c_str(), recv_event_str.c_str()); + } + } + } +} + +Status StreamAllocator::GetMaxStreamAndTask(bool huge_stream, uint32_t &max_stream_count, uint32_t &max_task_count) { 
+ const char *buffer_optimize_on = std::getenv("BUFFER_OPTIMIZE_ON"); + if (buffer_optimize_on != nullptr) { + rtError_t ret = rtSetPlatformType(PLATFORM_MINI_V1); + if (ret != RT_ERROR_NONE) { + GELOGE(FAILED, "Set platform type by rts failed."); + return FAILED; } } - Status status = ReorderEventNodes(); - if (status != SUCCESS) { - GELOGE(status, "Graph ReorderEventNodes failed"); - return status; + uint32_t stream_type = RT_NORMAL_STREAM; + if (huge_stream) { + stream_type = RT_HUGE_STREAM; + } + rtError_t ret = rtGetMaxStreamAndTask(stream_type, &max_stream_count, &max_task_count); + if (ret != RT_ERROR_NONE) { + GELOGE(FAILED, "Get max stream and task count by rts failed."); + return FAILED; } + GELOGI("Allowed max stream count: %u, max task count per stream: %u.", max_stream_count, max_task_count); return SUCCESS; } -Status StreamAllocator::ReorderEventNodes() const { - Status status = whole_graph_->InsertEventNodes(); - GraphUtils::DumpGEGraph(whole_graph_, "AfterInsertEventNodes", true); - GraphUtils::DumpGEGraphToOnnx(*whole_graph_, "AfterInsertEventNodes"); - if (status != SUCCESS) { - GELOGE(status, "Whole graph InsertEventNodes failed"); - return status; +int64_t StreamAllocator::GetMaxNodeNumPerStream(const NodePtr &node, uint32_t max_task_count) { + int64_t max_node_num_one_stream = static_cast(max_task_count); + string op_type = node->GetType(); + if (IsHcclOp(op_type)) { + max_node_num_one_stream /= kTaskNumPerHcclNode; + } else { + max_node_num_one_stream /= kTaskNumPerNormalNode; } - for (const auto &subgraph : whole_graph_->GetAllSubgraphs()) { - status = subgraph->InsertEventNodes(); - GraphUtils::DumpGEGraph(subgraph, "AfterInsertEventNodes_Subgraph"); - GraphUtils::DumpGEGraphToOnnx(*subgraph, "AfterInsertEventNodes_Subgraph"); - if (status != SUCCESS) { - GELOGE(status, "Subgraph %s InsertEventNodes failed", subgraph->GetName().c_str()); - return status; - } + if (max_node_num_one_stream == 0) { + max_node_num_one_stream = 1; } - return SUCCESS; + return max_node_num_one_stream; } // Insert send event id on a node @@ -1109,203 +1379,19 @@ NodePtr StreamAllocator::GetNodeFromRecvEventId(uint32_t recv_event_id) const { return nullptr; } -void StreamAllocator::DumpEvents() { - map> after_refresh_stream_nodes; - for (const auto &node : whole_graph_->GetAllNodes()) { - GE_IF_BOOL_EXEC(node->GetOpDesc() == nullptr, continue); - int64_t stream_id = node->GetOpDesc()->GetStreamId(); - after_refresh_stream_nodes[stream_id].emplace_back(node); - } - - for (const auto &one_pair : after_refresh_stream_nodes) { - int64_t stream_id = one_pair.first; - GELOGI("After RefreshRealStream: stream %ld.", stream_id); - - for (const auto &node : one_pair.second) { - string send_event_str; - for (const auto &send_event_id : node_to_send_events_[node]) { - send_event_str += " " + to_string(send_event_id); - } - if (!send_event_str.empty()) { - GELOGI("node: %s, send events: %s", node->GetName().c_str(), send_event_str.c_str()); - } - - string recv_event_str; - for (const auto &recv_event_id : node_to_recv_events_[node]) { - recv_event_str += " " + to_string(recv_event_id); - } - if (!recv_event_str.empty()) { - GELOGI("node: %s, recv events: %s", node->GetName().c_str(), recv_event_str.c_str()); - } - } - } -} - -// Add active entry stream for special env. -Status StreamAllocator::AddActiveEntryStream() { - auto gelib = GELib::GetInstance(); - bool head_stream = (gelib == nullptr) ? 
false : gelib->HeadStream(); - GELOGI("Configured head stream: %u", head_stream); - if (!head_stream) { - return SUCCESS; - } - - // Collect streams active by StreamSwitch/StreamActive node. - std::set deactive_stream; - for (ge::NodePtr &node : whole_graph_->GetAllNodes()) { - GE_CHECK_NOTNULL(node->GetOpDesc()); - Status ret = CollectDeactiveStream(node->GetOpDesc(), deactive_stream); - if (ret != SUCCESS) { - return ret; - } - } - - // Collect default active stream, Add to active entry stream. - std::vector active_stream_list; - for (int64_t stream_id = 0; stream_id < stream_num_; ++stream_id) { - if (deactive_stream.count(stream_id) == 0) { - active_stream_list.push_back(stream_id); - } - } - - int64_t new_stream_id = stream_num_; - stream_num_++; - return InsertActiveEntryStream(active_stream_list, new_stream_id); -} - -// Collect deactive stream from flowctrl op. -Status StreamAllocator::CollectDeactiveStream(const OpDescPtr &op_desc, std::set &deactive_streams) const { - GE_CHECK_NOTNULL(op_desc); - std::string op_type = op_desc->GetType(); - if (op_type == STREAMSWITCH) { - std::vector active_stream_list; - // If GetListInt fail, active_stream_list is empty. - (void)ge::AttrUtils::GetListInt(op_desc, ATTR_NAME_ACTIVE_STREAM_LIST, active_stream_list); - if (active_stream_list.size() != kMaxSwitchStreamNum) { - GELOGE(INTERNAL_ERROR, "Stream num of switch true branch must be %u.", kMaxSwitchStreamNum); - return INTERNAL_ERROR; - } - - deactive_streams.insert(active_stream_list[0]); - GELOGI("Flowctrl_op node:%s, flowctrl stream id:%u.", op_desc->GetName().c_str(), active_stream_list[0]); - } else if (op_type == STREAMACTIVE) { - if (op_desc->HasAttr(ATTR_NAME_SWITCH_BRANCH_NODE_LABEL)) { - std::vector active_stream_list; - if (!AttrUtils::GetListInt(op_desc, ATTR_NAME_ACTIVE_STREAM_LIST, active_stream_list)) { - GELOGE(INTERNAL_ERROR, "StreamActiveOp get attr ACTIVE_STREAM fail."); - return INTERNAL_ERROR; - } - - for (uint32_t deactive_stream : active_stream_list) { - deactive_streams.insert(deactive_stream); - GELOGI("Flowctrl_op node:%s, flowctrl stream id:%u.", op_desc->GetName().c_str(), deactive_stream); - } - } - } - - return SUCCESS; -} - -// Insert StreamActive Op for Entry Stream. -Status StreamAllocator::InsertActiveEntryStream(const std::vector &active_streams, int64_t stream_id) { - string node_name = "ActiveEntryStream_" + string(STREAMACTIVE); - OpDescPtr op_desc = ge::MakeShared(node_name, STREAMACTIVE); - if (op_desc == nullptr) { - GELOGE(FAILED, "Failed to new opdesc."); - return FAILED; - } - GELOGI("Create StreamActive op:%s.", op_desc->GetName().c_str()); - - GE_CHK_BOOL_EXEC( - AttrUtils::SetListStr(op_desc, ATTR_NAME_DATA_DUMP_ORIGIN_OP_NAMES, std::move(std::vector())), - GELOGE(FAILED, "SetListStr failed."); - return FAILED); - - NodePtr active_node = whole_graph_->AddNodeFront(op_desc); - GE_IF_BOOL_EXEC(active_node == nullptr, - GELOGE(FAILED, "Create StreamActive op: %s failed.", op_desc->GetName().c_str()); - return INTERNAL_ERROR); - GE_CHECK_NOTNULL(active_node->GetOpDesc()); - // Add one stream for ActiveEntryStream Task. 
- active_node->GetOpDesc()->SetStreamId(stream_id); - - GE_CHK_BOOL_EXEC(AttrUtils::SetBool(op_desc, "is_aicpu_stream", true), GELOGE(FAILED, "SetBool failed."); - return FAILED); - GE_CHK_BOOL_EXEC(AttrUtils::SetListInt(active_node->GetOpDesc(), ATTR_NAME_ACTIVE_STREAM_LIST, active_streams), - GELOGE(FAILED, "SetListInt failed."); - return FAILED); - - std::vector group_names; - GE_CHK_BOOL_EXEC(AttrUtils::SetListStr(active_node->GetOpDesc(), ATTR_NAME_SWITCH_BRANCH_NODE_LABEL, group_names), - GELOGE(FAILED, "SetLisStr failed."); - return FAILED); - - return SUCCESS; -} - -Status StreamAllocator::InsertActiveNodesAfterSwitch(NodePtr &switch_node) { - vector active_nodes; - if (InsertActiveNodesAfterSwitch(switch_node, active_nodes) != SUCCESS) { - GELOGE(FAILED, "Insert active nodes after node %s failed.", switch_node->GetName().c_str()); - return FAILED; - } - if (active_nodes.empty()) { - return SUCCESS; - } - vector stream_ids; - for (auto &active_node : active_nodes) { - GE_CHECK_NOTNULL(active_node->GetOpDesc()); - active_node->GetOpDesc()->SetStreamId(stream_num_); - stream_ids.emplace_back(stream_num_); - specific_activated_streams_.emplace(stream_num_); - stream_num_++; - } - auto op_desc = switch_node->GetOpDesc(); - GE_CHECK_NOTNULL(op_desc); - - if (!AttrUtils::SetListInt(op_desc, ATTR_NAME_ACTIVE_STREAM_LIST, stream_ids)) { - GELOGE(FAILED, "SetListInt failed."); - return FAILED; - } - - return SUCCESS; -} - -Status StreamAllocator::InsertActiveNodesAfterSwitch(NodePtr &switch_node, vector &active_nodes) { - GE_CHECK_NOTNULL(switch_node); - OpDescPtr switch_desc = switch_node->GetOpDesc(); - GE_CHECK_NOTNULL(switch_desc); - vector ori_active_label_list; - if (!AttrUtils::GetListStr(switch_desc, ATTR_NAME_ACTIVE_LABEL_LIST, ori_active_label_list) || - ori_active_label_list.empty()) { - GELOGE(INTERNAL_ERROR, "Get active label list of switch %s failed.", switch_node->GetName().c_str()); - return INTERNAL_ERROR; - } - - vector active_label_list; - vector added_active_nodes; - if (AddActiveNodes(switch_node, ori_active_label_list, active_label_list, added_active_nodes) != SUCCESS) { - GELOGE(FAILED, "Add active nodes after node %s failed.", switch_node->GetName().c_str()); - return FAILED; - } - - if (SetActiveLabelList(switch_node, active_label_list) != SUCCESS) { - GELOGE(FAILED, "set active label list failed"); - return FAILED; - } - - if (added_active_nodes.empty()) { - return SUCCESS; - } - - for (auto &active_node : added_active_nodes) { - GE_CHECK_NOTNULL(switch_node->GetOutControlAnchor()); - if (switch_node->GetOutControlAnchor()->LinkTo(active_node->GetInControlAnchor()) != GRAPH_SUCCESS) { - GELOGE(FAILED, "Link %s to %s failed.", switch_node->GetName().c_str(), active_node->GetName().c_str()); - return FAILED; - } - active_nodes.emplace_back(active_node); +Status StreamAllocator::AddEventId(const NodePtr &pre_node, const NodePtr &not_cur, const NodePtr &cur_node, + bool not_use_cur) { + GELOGI("Add send event %u for node %s", event_num_, pre_node->GetName().c_str()); + AddSendEventId(pre_node, event_num_); + if (not_use_cur) { + GE_CHECK_NOTNULL(not_cur); + GELOGI("Add recv event %u for node %s", event_num_, not_cur->GetName().c_str()); + AddRecvEventId(not_cur, event_num_); + } else { + GELOGI("Add recv event %u for node %s", event_num_, cur_node->GetName().c_str()); + AddRecvEventId(cur_node, event_num_); } + ++event_num_; return SUCCESS; } diff --git a/src/ge/graph/build/stream_allocator.h b/src/ge/graph/build/stream_allocator.h index ea6d08a3..528c22a9 100644 --- 
a/src/ge/graph/build/stream_allocator.h +++ b/src/ge/graph/build/stream_allocator.h @@ -40,41 +40,47 @@ class StreamAllocator { const vector &GetHugeStreams() const { return huge_streams_; } private: - Status SplitStreams(std::vector> &split_streams); - Status AssignSingleStream(); + Status SetActiveStreamsByLabel(); - Status UpdateActiveStreams(const std::vector> &splited_streams); - void UpdateLabelStreams(const std::vector> &split_streams); - Status SetActiveStreamsForSubgraph(); - Status SetActiveStreamsForLoop(); - Status CheckStreamActived() const; - Status GetMaxStreamAndTask(bool huge_stream, uint32_t &max_stream_count, uint32_t &max_task_count); - int64_t GetMaxNodeNumPerStream(const NodePtr &node, uint32_t max_node_num_one_stream); + Status SetActiveStreamsForSubgraphs(); Status InsertSyncEvents(); Status InsertOneEventInTwoNodes(const NodePtr &cur_node_ptr, const NodePtr &next_node_ptr); + Status InsertEventsForSubgraph(); Status OptimizeSyncEvents(); Status OptimizeBySendEvents(const std::map> &stream_nodes); Status OptimizeByRecvEvents(const std::map> &stream_nodes); Status OptimizeByStreamActivate(); + // Determine if the successor node of RecvNode is directly or indirectly activated by the SendNode precursor node + bool IsRecvNodeActivatedBySendNode(const NodePtr &send_node_ptr, const NodePtr &recv_node_ptr) const; - Status RefreshContinuousEvents(); - Status InsertSyncEventNodes(); - Status ReorderEventNodes() const; + Status SplitStreams(std::vector> &split_streams); + bool NeedSplitNewStream(int64_t stream_node_num, int64_t max_node_num_one_stream, const OpDescPtr &op_desc) const; + Status UpdateActiveStreams(const std::vector> &split_streams); + void UpdateLabelStreams(const std::vector> &split_streams); Status InsertActiveNodesAfterSwitch(NodePtr &switch_node); Status InsertActiveNodesAfterSwitch(NodePtr &switch_nodes, std::vector &switch_active_nodes); - Status SetActiveStreamList(NodePtr &active_node, const std::string &active_label); - Status AddActiveNodes(NodePtr &switch_node, const std::vector &ori_active_label_list, - std::vector &active_label_list, std::vector &added_active_nodes); + Status UpdateActiveStreamsForSubgraphs() const; + Status SetActiveStreamsForLoop(); + Status CheckStreamActived() const; Status AddActiveEntryStream(); Status CollectDeactiveStream(const OpDescPtr &op_desc, std::set &deactive_streams) const; Status InsertActiveEntryStream(const std::vector &active_streams, int64_t stream_id); - Status AddEventId(const NodePtr &pre_node, const NodePtr &not_cur, const NodePtr &cur_node, bool not_use_cur); + Status RefreshContinuousEvents(); + + Status InsertSyncEventNodes(); + Status ReorderEventNodes() const; + + void DumpEvents(); + + Status GetMaxStreamAndTask(bool huge_stream, uint32_t &max_stream_count, uint32_t &max_task_count); + int64_t GetMaxNodeNumPerStream(const NodePtr &node, uint32_t max_task_count); + void AddSendEventId(const NodePtr &node, uint32_t event_id); void AddRecvEventId(const NodePtr &node, uint32_t event_id); void RmvSendEventId(const NodePtr &node, uint32_t event_id); @@ -83,10 +89,11 @@ class StreamAllocator { void GetRecvEventIdList(const NodePtr &node, std::vector &recv_list) const; NodePtr GetNodeFromSendEventId(uint32_t send_event_id) const; NodePtr GetNodeFromRecvEventId(uint32_t recv_event_id) const; + Status AddEventId(const NodePtr &pre_node, const NodePtr &not_cur, const NodePtr &cur_node, bool not_use_cur); - void DumpEvents(); - // Determine if the successor node of RecvNode is directly or indirectly 
activated by the SendNode precursor node - bool IsRecvNodeActivatedBySendNode(const NodePtr &send_node_ptr, const NodePtr &recv_node_ptr) const; + Status AddActiveNodes(NodePtr &switch_node, const std::vector &ori_active_label_list, + std::vector &active_label_list, std::vector &added_active_nodes); + Status SetActiveStreamList(NodePtr &active_node, const std::string &active_label); ComputeGraphPtr whole_graph_; const Graph2SubGraphInfoList &subgraphs_; @@ -102,6 +109,9 @@ class StreamAllocator { std::set specific_activated_streams_; std::map> specific_activated_streams_nodes_map_; + std::map node_split_stream_map_; + std::map subgraph_first_active_node_map_; + // send events corresponding to the node std::map> node_to_send_events_; @@ -109,4 +119,4 @@ class StreamAllocator { std::map> node_to_recv_events_; }; } // namespace ge -#endif // GE_GRAPH_BUILD_STREAM_ALLOCATOR_H_ +#endif // GE_GRAPH_BUILD_STREAM_ALLOCATOR_H_ \ No newline at end of file diff --git a/src/ge/graph/build/stream_graph_optimizer.cc b/src/ge/graph/build/stream_graph_optimizer.cc index 204a98b2..a3e8044d 100644 --- a/src/ge/graph/build/stream_graph_optimizer.cc +++ b/src/ge/graph/build/stream_graph_optimizer.cc @@ -30,7 +30,7 @@ namespace ge { StreamGraphOptimizer::~StreamGraphOptimizer() {} void StreamGraphOptimizer::RefreshNodeId(const ComputeGraphPtr &comp_graph, Graph2SubGraphInfoList &subgraph_map) { - size_t node_size = comp_graph->GetDirectNodesSize(); + size_t node_size = comp_graph->GetAllNodesSize(); GELOGI("Refresh placeholder and end nodeId start from node num: %zu", node_size); for (const auto &subgraph_pair : subgraph_map) { for (const auto &subgraph_info : subgraph_pair.second) { diff --git a/src/ge/graph/build/task_generator.cc b/src/ge/graph/build/task_generator.cc index a6bc6128..ec6bf584 100644 --- a/src/ge/graph/build/task_generator.cc +++ b/src/ge/graph/build/task_generator.cc @@ -17,9 +17,9 @@ #include "graph/build/task_generator.h" #include #include +#include "common/profiling/profiling_manager.h" #include "common/types.h" #include "common/util.h" -#include "common/profiling/profiling_manager.h" #include "framework/common/debug/ge_log.h" #include "graph/debug/ge_attr_define.h" @@ -73,12 +73,24 @@ Status TaskGenerator::GetTaskInfo(Model &model, ComputeGraphPtr &graph, uint64_t std::vector task_def_list; std::map op_name_map; + GE_DUMP(graph, "GenerateTaskBefore"); + bool is_unknown_shape = false; + NodePtr parent_node = graph->GetParentNode(); + if (parent_node != nullptr) { + auto op_desc = parent_node->GetOpDesc(); + GE_CHECK_NOTNULL(op_desc); + (void)AttrUtils::GetBool(op_desc, ATTR_NAME_IS_UNKNOWN_SHAPE, is_unknown_shape); + } + Status ret = SUCCESS; + if (is_unknown_shape) { + GELOGI("Begin to generate unknown shape task."); + ret = GenerateUnknownShapeTask(run_context, graph, task_def_list, op_name_map); + } else { + GELOGI("Begin to generate known shape task."); + ret = GenerateTask(run_context, graph, task_def_list, op_name_map); + } + GE_DUMP(graph, "GenerateTaskAfter"); - GraphUtils::DumpGEGraph(graph, "GenerateTaskBefore"); - GraphUtils::DumpGEGraphToOnnx(*graph, "GenerateTaskBefore"); - Status ret = GenerateTask(run_context, graph, task_def_list, op_name_map); - GraphUtils::DumpGEGraph(graph, "GenerateTaskAfter"); - GraphUtils::DumpGEGraphToOnnx(*graph, "GenerateTaskAfter"); if (ret != SUCCESS) { GELOGE(ret, "GenerateTask failed. 
session_id=%lu", session_id); return ret; @@ -251,8 +263,9 @@ Status TaskGenerator::GenerateTask(RunContext &run_context, ComputeGraphPtr &gra GE_TIMESTAMP_CALLNUM_START(GenerateTask); // map store fusion nodes map> fusion_nodes; - const char *buffer_optimize_on = std::getenv("BUFFER_OPTIMIZE_ON"); - if (buffer_optimize_on != nullptr) { + string buffer_optimize = "off_optimize"; + (void)ge::GetContext().GetOption(BUFFER_OPTIMIZE, buffer_optimize); + if (buffer_optimize != "off_optimize") { GE_CHK_STATUS_RET(SaveFusionNodes(fusion_nodes, graph)); } std::unordered_set fusion_nodes_seen; @@ -342,10 +355,125 @@ Status TaskGenerator::GenerateTask(RunContext &run_context, ComputeGraphPtr &gra task_def_ptr->set_ops_kernel_store_ptr(reinterpret_cast(ops_kernel_info_store_ptr)); } - GELOGD("Call %s to generate node[name:%s(%s), id:%ld, stream_id:%ld] task finished, generate %lu task(s).", + GELOGD("Call %s to generate node[name:%s(%s), id:%ld, stream_id:%ld] task finished, generate %zu task(s).", + op_kernel_lib_name.c_str(), name.c_str(), type.c_str(), op_id, stream_id, + task_list_size_after - task_list_size_before); + } + GE_TIMESTAMP_CALLNUM_END(GenerateTask, "GraphBuild::GenerateTask"); + return SUCCESS; +} + +Status TaskGenerator::GenerateUnknownShapeTask(RunContext &run_context, ComputeGraphPtr &graph, + vector &task_def_list, + map &op_name_map) { + std::shared_ptr ge_lib = GELib::GetInstance(); + if ((ge_lib == nullptr) || !ge_lib->InitFlag()) { + GELOGE(GE_CLI_GE_NOT_INITIALIZED, "GenerateTask failed."); + return GE_CLI_GE_NOT_INITIALIZED; + } + GE_CHK_STATUS_RET(MarkNodeAndSetIndex(graph), "MarkNodeAndSetIndex failed."); + ProfilingPoint profiling_point; + vector all_reduce_nodes; + GE_CHK_STATUS_RET(FindProfilingTaskIndex(graph, profiling_point, all_reduce_nodes)); + + const OpsKernelManager &ops_kernel_manager = ge_lib->OpsKernelManagerObj(); + + GE_TIMESTAMP_CALLNUM_START(GenerateTask); + // map store fusion nodes + map> fusion_nodes; + string buffer_optimize = "off_optimize"; + (void)ge::GetContext().GetOption(BUFFER_OPTIMIZE, buffer_optimize); + if (buffer_optimize != "off_optimize") { + GE_CHK_STATUS_RET(SaveFusionNodes(fusion_nodes, graph)); + } + std::unordered_set fusion_nodes_seen; + int64_t group_key; + uint32_t node_index = 0; + rtStream_t stream = nullptr; + GE_CHK_RT_RET(rtStreamCreate(&stream, 0)); + run_context.stream = stream; + GE_CHK_RT_RET(rtModelBindStream(run_context.model, stream, 0)); + for (auto &node : graph->GetAllNodes()) { + OpDescPtr op_desc = node->GetOpDesc(); + GE_CHECK_NOTNULL(op_desc); + node_index++; + string name = node->GetName(); + string type = node->GetType(); + bool attr_notask = false; + bool get_attr_notask_flag = ge::AttrUtils::GetBool(op_desc, ATTR_NAME_NOTASK, attr_notask); + GE_IF_BOOL_EXEC(get_attr_notask_flag && attr_notask, + GELOGI("Node[name:%s, type:%s] does not need to generate task.", name.c_str(), type.c_str()); + continue); + + GE_CHK_STATUS_RET(UpdateOpIsVarAttr(op_desc, graph->GetSessionID())); + string op_kernel_lib_name = op_desc->GetOpKernelLibName(); + // For fusion ddb pass, task def must be continuous. 
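+ // All nodes sharing one fusion group key are generated in a single batch when the first member + // is visited; later members are skipped via the ATTR_NAME_FUSION_GROUP_KEY check below.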
+ // Part2: Call + auto fusion_task_info = + FusionTaskInfo{run_context, graph, node, op_desc, node_index, ge_lib, + ops_kernel_manager, task_def_list, op_name_map, profiling_point, all_reduce_nodes}; + GE_CHK_STATUS_RET(GenerateTaskForFusionNode(fusion_task_info, fusion_nodes, fusion_nodes_seen), + "Call GenerateTaskForFusionNode node:%s(%s) failed", name.c_str(), type.c_str()); + // continue directly + if (ge::AttrUtils::GetInt(op_desc, ATTR_NAME_FUSION_GROUP_KEY, group_key)) { + GELOGI("Fusion node[name:%s, type:%s] do not need generate task again.", name.c_str(), type.c_str()); + continue; + } + if (op_kernel_lib_name.empty()) { + GELOGI("Node[name:%s, type:%s] does not need to generate task.", name.c_str(), type.c_str()); + continue; + } + OpsKernelInfoStorePtr kernel_info_store = ops_kernel_manager.GetOpsKernelInfoStore(op_kernel_lib_name); + if (kernel_info_store == nullptr) { + GELOGE(INTERNAL_ERROR, "No ops kernel store found. node:%s(%s), op_kernel_lib_name=%s.", name.c_str(), + type.c_str(), op_kernel_lib_name.c_str()); + return INTERNAL_ERROR; + } + GE_CHK_STATUS_RET(UpdateAnchorStatus(node), "Call UpdateAnchorStatus node:%s(%s) failed", name.c_str(), + type.c_str()); + int64_t op_id = op_desc->GetId(); + int64_t stream_id = op_desc->GetStreamId(); + // Profiling task + size_t task_list_size_before = task_def_list.size(); + GE_CHK_STATUS_RET(InsertProfilingTaskBefore(op_desc, profiling_point, all_reduce_nodes, node_index, task_def_list)); + + GELOGI("Call %s to generate node[name:%s(%s), id:%ld, stream_id:%ld] task.", op_kernel_lib_name.c_str(), + name.c_str(), type.c_str(), op_id, stream_id); + GE_TIMESTAMP_RESTART(GenerateTask); + auto ret = kernel_info_store->GenerateTask(*node, run_context, task_def_list); + GE_TIMESTAMP_ADD(GenerateTask); + if (ret != SUCCESS) { + GELOGE(ret, "Call %s to generate node[name:%s(%s), id:%ld, stream_id:%ld] task failed.", + op_kernel_lib_name.c_str(), name.c_str(), type.c_str(), op_id, stream_id); + return ret; + } + // Profiling task + GE_CHK_STATUS_RET(InsertProfilingTaskAfter(op_desc, profiling_point, all_reduce_nodes, node_index, task_def_list)); + size_t task_list_size_after = task_def_list.size(); + // If tasks is reduced + if (task_list_size_after < task_list_size_before) { + GELOGE(FAILED, "Call %s to generate node[name:%s(%s), id:%ld, stream_id:%ld] task. 
but task num from %zu to %zu.", + op_kernel_lib_name.c_str(), name.c_str(), type.c_str(), op_id, stream_id, task_list_size_before, + task_list_size_after); + return FAILED; + } + + // Reset stream id to ge stream id, as graph load must use ge stream to reassign stream + void *ops_kernel_info_store_ptr = kernel_info_store.get(); + for (size_t idx = task_list_size_before; idx < task_list_size_after; ++idx) { + op_name_map[idx] = name; + // Set opsKernelInfoStorePtr and op_index, the two fields be use in DistributeTask and InitTaskInfo + TaskDef *task_def_ptr = &task_def_list[idx]; + GE_CHECK_NOTNULL(task_def_ptr); + task_def_ptr->set_ops_kernel_store_ptr(reinterpret_cast(ops_kernel_info_store_ptr)); + } + + GELOGD("Call %s to generate node[name:%s(%s), id:%ld, stream_id:%ld] task finished, generate %zu task(s).", op_kernel_lib_name.c_str(), name.c_str(), type.c_str(), op_id, stream_id, task_list_size_after - task_list_size_before); } + GE_CHK_RT(rtModelUnbindStream(run_context.model, stream)); + GE_CHK_RT(rtStreamDestroy(stream)); GE_TIMESTAMP_CALLNUM_END(GenerateTask, "GraphBuild::GenerateTask"); return SUCCESS; } @@ -381,6 +509,11 @@ Status TaskGenerator::GenerateTaskForFusionNode(FusionTaskInfo &fusion_task_info fusion_node_type.c_str()); continue; } + bool attr_notask = false; + GE_IF_BOOL_EXEC(ge::AttrUtils::GetBool(op_desc, ATTR_NAME_NOTASK, attr_notask) && attr_notask, + GELOGI("Fusion: fusion_node[name:%s, type:%s] does not need to generate task.", + fusion_node_name.c_str(), fusion_node_type.c_str()); + continue); size_t task_list_size_before = task_def_list.size(); OpsKernelInfoStorePtr kernel_info_store = ops_kernel_manager.GetOpsKernelInfoStore(op_kernel_lib_name); @@ -528,6 +661,10 @@ Status TaskGenerator::MarkFirstAndLastOps(const vector &ops, bool is_ vector> continuous_op_lists(1); const set label_op_types({LABELSET, LABELGOTO, LABELGOTOEX, LABELSWITCH, LABELSWITCHBYINDEX}); for (auto &op_desc : ops) { + bool attr_notask = false; + if (ge::AttrUtils::GetBool(op_desc, ATTR_NAME_NOTASK, attr_notask) && attr_notask) { + continue; + } string op_type = op_desc->GetType(); if (!is_single_stream && (!op_desc->GetSubgraphInstanceNames().empty() || label_op_types.count(op_type) != 0)) { continuous_op_lists.emplace_back(vector()); @@ -629,7 +766,7 @@ Status TaskGenerator::AutoFindBpOpIndex(const ComputeGraphPtr &graph, ProfilingP continue; } - if (op_desc->GetType() == HCOMALLREDUCE) { + if (op_desc->GetType() == HCOMALLREDUCE || op_desc->GetType() == HVDCALLBACKALLREDUCE) { bp_node = node; all_reduce_nodes.emplace_back(current_idx); GELOGI("Allreduce name %s, idx %u", op_desc->GetName().c_str(), current_idx); @@ -721,7 +858,7 @@ Status TaskGenerator::FindBpOfEnv(const ComputeGraphPtr &graph, const std::strin iter_end = current_idx; GELOGI("Iter end name %s, idx %u", op_desc->GetName().c_str(), iter_end); } - if (op_desc->GetType() == HCOMALLREDUCE) { + if (op_desc->GetType() == HCOMALLREDUCE || op_desc->GetType() == HVDCALLBACKALLREDUCE) { all_reduce_nodes.emplace_back(current_idx); GELOGI("Allreduce name %s, idx %u", op_desc->GetName().c_str(), current_idx); } diff --git a/src/ge/graph/build/task_generator.h b/src/ge/graph/build/task_generator.h index c666244b..02721e00 100644 --- a/src/ge/graph/build/task_generator.h +++ b/src/ge/graph/build/task_generator.h @@ -82,7 +82,7 @@ class TaskGenerator { Status UpdateOpIsVarAttr(const OpDescPtr &op_desc, uint64_t session_id); /// - /// call engine to generate task. + /// call engine to generate known shape task. 
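+ /// Unknown shape graphs take the GenerateUnknownShapeTask path declared below instead.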
/// @param run_context run context /// @param graph compute graph /// @param task_def_list task def list generate by engine /// @param op_name_map relation of task index and op /// @return SUCCESS:seccess /// Other: failed /// Status GenerateTask(RunContext &run_context, ComputeGraphPtr &graph, std::vector &task_def_list, std::map &op_name_map); + /// + /// call engine to generate unknown shape task. + /// @param run_context run context + /// @param graph compute graph + /// @param task_def_list task def list generated by engine + /// @param op_name_map relation of task index and op + /// @return SUCCESS: success + /// Other: failed + /// + Status GenerateUnknownShapeTask(RunContext &run_context, ComputeGraphPtr &graph, + std::vector &task_def_list, std::map &op_name_map); + /// /// AddModelTaskToModel /// @param model_task_def model task diff --git a/src/ge/graph/execute/graph_execute.cc b/src/ge/graph/execute/graph_execute.cc index 6ef1f671..4173706a 100644 --- a/src/ge/graph/execute/graph_execute.cc +++ b/src/ge/graph/execute/graph_execute.cc @@ -258,7 +258,7 @@ Status GraphExecutor::SyncExecuteModel(uint32_t model_id, const std::vectorGetResultCode(); - if (result_code != SUCCESS) { + if (result_code != SUCCESS && result_code != END_OF_SEQUENCE) { GELOGE(GE_GRAPH_EXECUTE_FAILED, "[GraphExecutor] execute model failed, ret=%u, modelId=%u.", result_code, model_id); return GE_GRAPH_EXECUTE_FAILED; @@ -319,7 +319,7 @@ Status GraphExecutor::FreeExecuteMemory() { return SUCCESS; } -Status GraphExecutor::ExecuteGraph(GraphId graph_id, const GeModelPtr &ge_model, +Status GraphExecutor::ExecuteGraph(GraphId graph_id, const GeRootModelPtr &ge_root_model, const std::vector &input_tensor, std::vector &output_tensor) { if (graph_id != last_graph_id_) { auto ret = FreeExecuteMemory(); @@ -333,8 +333,8 @@ Status GraphExecutor::ExecuteGraph(GraphId graph_id, const GeModelPtr &ge_model, GELOGE(GE_GRAPH_EXECUTE_NOT_INIT, "[GraphExecutor] AI Core Engine without calling SetCondition!"); return GE_GRAPH_EXECUTE_NOT_INIT; } - GE_CHECK_NOTNULL_EXEC(ge_model, return FAILED); - Status ret = SyncExecuteModel(ge_model->GetModelId(), input_tensor, output_tensor); + GE_CHECK_NOTNULL_EXEC(ge_root_model, return FAILED); + Status ret = SyncExecuteModel(ge_root_model->GetModelId(), input_tensor, output_tensor); if (ret != SUCCESS) { GELOGE(GE_GRAPH_SYNC_MODEL_FAILED, "[GraphExecutor] SyncExecuteModel Error!"); return GE_GRAPH_SYNC_MODEL_FAILED; @@ -343,7 +343,7 @@ Status GraphExecutor::ExecuteGraph(GraphId graph_id, const GeModelPtr &ge_model, return SUCCESS; } -Status GraphExecutor::ExecuteGraphAsync(GraphId graph_id, const GeModelPtr &ge_model, +Status GraphExecutor::ExecuteGraphAsync(GraphId graph_id, const GeRootModelPtr &ge_root_model, const std::vector &input_tensor) { GELOGI("[GraphExecutor] Start to async execute graph, graph_id=%u", graph_id); if (graph_id != last_graph_id_) { @@ -353,8 +353,8 @@ Status GraphExecutor::ExecuteGraphAsync(GraphId graph_id, const GeModelPtr &ge_m } } last_graph_id_ = graph_id; - GE_CHECK_NOTNULL_EXEC(ge_model, return FAILED); - Status ret = AsyncExecuteModel(ge_model->GetModelId(), input_tensor); + GE_CHECK_NOTNULL_EXEC(ge_root_model, return FAILED); + Status ret = AsyncExecuteModel(ge_root_model->GetModelId(), input_tensor); if (ret != SUCCESS) { GELOGE(GE_GRAPH_SYNC_MODEL_FAILED, "[GraphExecutor] AsyncExecuteModel Error!"); return GE_GRAPH_SYNC_MODEL_FAILED; diff --git a/src/ge/graph/execute/graph_execute.h b/src/ge/graph/execute/graph_execute.h index e7fd2084..9d4ecc24 100644 --- a/src/ge/graph/execute/graph_execute.h +++ 
b/src/ge/graph/execute/graph_execute.h @@ -46,11 +46,11 @@ class GraphExecutor { virtual ~GraphExecutor(); - Status ExecuteGraph(GraphId graph_id, const GeModelPtr &ge_model, const std::vector &input_tensor, + Status ExecuteGraph(GraphId graph_id, const GeRootModelPtr &ge_root_model, const std::vector &input_tensor, std::vector &output_tensor); - Status ExecuteGraphAsync(GraphId graph_id, const GeModelPtr &ge_model, - const std::vector &input_tensor); + ge::Status ExecuteGraphAsync(GraphId graph_id, const GeRootModelPtr &ge_root_model, + const std::vector &input_tensor); Status SetCondition(std::mutex *mutex, std::condition_variable *cond, std::shared_ptr listener); diff --git a/src/ge/graph/label/label_maker.cc b/src/ge/graph/label/label_maker.cc index 0c3e0adf..88b90199 100644 --- a/src/ge/graph/label/label_maker.cc +++ b/src/ge/graph/label/label_maker.cc @@ -94,42 +94,6 @@ void LabelMaker::SetStreamIdOwner(const ComputeGraphPtr &graph, const OpDescPtr op_desc->SetStreamId(stream_id); } -/** - * @ingroup ge - * @brief Link Node to Graph head. - * @param [in] graph: graph for add node. - * @param [in] lb_node: Node for set link to head. - * @return: SUCCESS / FAILED - */ -Status LabelMaker::AddCtrlLink2Data(const ComputeGraphPtr &graph, const NodePtr &node) { - GE_CHECK_NOTNULL(graph); - GE_CHECK_NOTNULL(node); - - std::set linked_nodes; - for (const NodePtr &n : graph->GetDirectNode()) { - GE_CHECK_NOTNULL(n); - if (n->GetType() != DATA) { - continue; - } - - // Link control edge to graph head. - for (const NodePtr &out_node : n->GetOutAllNodes()) { - if (linked_nodes.count(out_node) > 0) { - continue; - } - - (void)linked_nodes.insert(out_node); - if (GraphUtils::AddEdge(node->GetOutControlAnchor(), out_node->GetInControlAnchor()) != SUCCESS) { - GELOGE(INTERNAL_ERROR, "Add ctrl edge from %s to %s failed.", node->GetName().c_str(), - out_node->GetName().c_str()); - return FAILED; - } - } - } - - return SUCCESS; -} - /** * @ingroup ge * @brief Add StreamActive node at graph front. @@ -154,15 +118,10 @@ NodePtr LabelMaker::AddStreamActive(const ComputeGraphPtr &graph, const std::str vector active_streams; (void)AttrUtils::SetStr(op_desc, ATTR_NAME_SWITCH_BRANCH_NODE_LABEL, op_desc->GetName()); (void)AttrUtils::SetListInt(op_desc, ATTR_NAME_ACTIVE_STREAM_LIST, active_streams); + (void)AttrUtils::SetBool(op_desc, ATTR_NAME_SUBGRAPH_FIRST_ACTIVE, true); NodePtr stream_active = graph->AddNodeFront(op_desc); GE_CHECK_NOTNULL_EXEC(stream_active, return nullptr); - // Link control edge to graph head. 
- if (AddCtrlLink2Data(graph, stream_active) != SUCCESS) { - GELOGE(INTERNAL_ERROR, "Add ctrl edge for graph %s failed.", graph->GetName().c_str()); - return nullptr; - } - return stream_active; } @@ -230,6 +189,7 @@ NodePtr LabelMaker::AddLabelSetLeave(const ComputeGraphPtr &graph, const std::st GELOGI("LabelSet: Create node %s.", op_desc->GetName().c_str()); (void)AttrUtils::SetInt(op_desc, ATTR_NAME_LABEL_SWITCH_INDEX, index); + (void)AttrUtils::SetBool(op_desc, ATTR_NAME_SUBGRAPH_END_NODE, true); NodePtr label_set = graph->AddNode(op_desc); GE_CHECK_NOTNULL_EXEC(label_set, return nullptr); diff --git a/src/ge/graph/label/label_maker.h b/src/ge/graph/label/label_maker.h index f77c3dc9..759bf5cf 100644 --- a/src/ge/graph/label/label_maker.h +++ b/src/ge/graph/label/label_maker.h @@ -60,7 +60,6 @@ class LabelMaker { ComputeGraphPtr parent_graph_; private: - Status AddCtrlLink2Data(const ComputeGraphPtr &graph, const NodePtr &node); void SetStreamIdEnter(const ComputeGraphPtr &graph, const OpDescPtr &op_desc); void SetStreamIdLeave(const ComputeGraphPtr &graph, const OpDescPtr &op_desc); void SetStreamIdOwner(const ComputeGraphPtr &graph, const OpDescPtr &op_desc); diff --git a/src/ge/graph/label/partitioned_call_label_maker.cc b/src/ge/graph/label/partitioned_call_label_maker.cc index 39c88717..64db223b 100644 --- a/src/ge/graph/label/partitioned_call_label_maker.cc +++ b/src/ge/graph/label/partitioned_call_label_maker.cc @@ -50,6 +50,21 @@ Status PartitionedCallLabelMaker::Run(uint32_t &label_index) { return FAILED; } + const std::string stream_active_name = parent_node_->GetName() + "/StreamActive"; // rtStreamActive + NodePtr stream_active = AddStreamActive(sub_graph, stream_active_name); + if (stream_active == nullptr) { + GELOGE(INTERNAL_ERROR, "Subgraph: %s add stream active node failed.", sub_graph->GetName().c_str()); + return FAILED; + } + + for (auto &node : sub_graph->GetDirectNode()) { + if (node->GetType() == NETOUTPUT) { + auto op_desc = node->GetOpDesc(); + GE_CHECK_NOTNULL(op_desc); + (void)AttrUtils::SetBool(op_desc, ATTR_NAME_SUBGRAPH_END_NODE, true); + } + } + return SUCCESS; } diff --git a/src/ge/graph/label/partitioned_call_label_maker.h b/src/ge/graph/label/partitioned_call_label_maker.h index c78a06fc..1c0f0890 100644 --- a/src/ge/graph/label/partitioned_call_label_maker.h +++ b/src/ge/graph/label/partitioned_call_label_maker.h @@ -29,6 +29,8 @@ +---------------+ +---------------+ | Node | +---------------+ + +---------------+ | StreamActive | + | Node | +---------------+ +---------------+ | f | | Node | +---------------+ +---------------+ | u | diff --git a/src/ge/graph/load/graph_loader.cc b/src/ge/graph/load/graph_loader.cc index 87db7f3d..1f4cbcf9 100644 --- a/src/ge/graph/load/graph_loader.cc +++ b/src/ge/graph/load/graph_loader.cc @@ -53,7 +53,7 @@ Status GraphLoader::UnloadModel(uint32_t model_id) { return SUCCESS; } -Status GraphLoader::LoadModelOnline(uint32_t &model_id, const std::shared_ptr &ge_model_ptr, +Status GraphLoader::LoadModelOnline(uint32_t &model_id, const std::shared_ptr &ge_root_model_ptr, const std::shared_ptr &listener) { GELOGI("Load model online begin."); rtError_t rt_ret = rtSetDevice(GetContext().DeviceId()); @@ -62,15 +62,15 @@ Status GraphLoader::LoadModelOnline(uint32_t &model_id, const std::shared_ptrGetModelId(); + model_id = ge_root_model_ptr->GetModelId(); auto model_manager = ModelManager::GetInstance(); GE_CHECK_NOTNULL(model_manager); - Status ret = model_manager->LoadModelOnline(model_id, ge_model_ptr, listener); + Status ret 
= model_manager->LoadModelOnline(model_id, ge_root_model_ptr, listener); if (ret != SUCCESS) { GELOGE(ret, "LoadModel: Load failed. ret = %u", ret); CsaInteract::GetInstance().WriteErrorCode(ret, ERROR_MODULE_FMK, JOBSUBSTATE_GRAPH_LOAD); diff --git a/src/ge/graph/load/graph_loader.h b/src/ge/graph/load/graph_loader.h index 5fe37a36..c887c06b 100644 --- a/src/ge/graph/load/graph_loader.h +++ b/src/ge/graph/load/graph_loader.h @@ -71,7 +71,7 @@ class GraphLoader { static Status DestroyAicpuSessionForInfer(uint32_t model_id); - static Status LoadModelOnline(uint32_t &model_id, const std::shared_ptr &model, + static Status LoadModelOnline(uint32_t &model_id, const std::shared_ptr &ge_root_model, const std::shared_ptr &listener); }; } // namespace ge diff --git a/src/ge/graph/load/new_model_manager/data_dumper.cc b/src/ge/graph/load/new_model_manager/data_dumper.cc index db675132..db9318ec 100644 --- a/src/ge/graph/load/new_model_manager/data_dumper.cc +++ b/src/ge/graph/load/new_model_manager/data_dumper.cc @@ -15,9 +15,12 @@ */ #include "graph/load/new_model_manager/data_dumper.h" + +#include #include #include #include + #include "common/properties_manager.h" #include "framework/common/debug/ge_log.h" #include "framework/common/util.h" @@ -32,6 +35,7 @@ namespace { const uint32_t kAicpuLoadFlag = 1; const uint32_t kAicpuUnloadFlag = 0; +const uint32_t kTimeBufferLen = 80; const char *const kDumpOutput = "output"; const char *const kDumpInput = "input"; const char *const kDumpAll = "all"; @@ -156,10 +160,8 @@ void DataDumper::SaveDumpTask(uint32_t task_id, uint32_t stream_id, const std::s return; } - uintptr_t data_addr = args - sizeof(void *) * op_desc->GetInputOffset().size() + - sizeof(void *) * static_cast(inner_input_mapping.input_anchor_index); GELOGI("Save input dump task %s, id: %u.", data_op->GetName().c_str(), task_id); - op_list_.push_back({task_id, stream_id, data_op, data_addr, false, inner_input_mapping.input_anchor_index, + op_list_.push_back({task_id, stream_id, data_op, args, false, inner_input_mapping.input_anchor_index, inner_input_mapping.output_anchor_index, input_tensor->GetShape().GetDims()}); } } @@ -188,11 +190,24 @@ static void SetOpMappingLoopAddr(uintptr_t step_id, uintptr_t loop_per_iter, uin } } +static std::string GetCurrentTime() { + std::time_t now = std::time(nullptr); + std::tm *ptm = std::localtime(&now); + if (ptm == nullptr) { + return ""; + } + char buffer[kTimeBufferLen] = {0}; + // format: 20171122042550 + std::strftime(buffer, kTimeBufferLen, "%Y%m%d%H%M%S", ptm); + return std::string(buffer); +} + Status DataDumper::DumpOutput(const InnerDumpInfo &inner_dump_info, aicpu::dump::Task &task) { GELOGI("Start dump output"); if (inner_dump_info.is_task) { // tbe or aicpu op const auto &output_descs = inner_dump_info.op->GetAllOutputsDesc(); + const auto input_size = inner_dump_info.op->GetAllInputsDesc().size(); const std::vector output_addrs = ModelUtils::GetOutputDataAddrs(runtime_param_, inner_dump_info.op, false); if (output_descs.size() != output_addrs.size()) { GELOGE(PARAM_INVALID, "Invalid output desc addrs size %zu, op %s has %zu output desc.", output_addrs.size(), @@ -217,8 +232,7 @@ Status DataDumper::DumpOutput(const InnerDumpInfo &inner_dump_info, aicpu::dump: output.set_original_output_index(origin_output_index); output.set_original_output_format(static_cast(output_descs.at(i).GetOriginFormat())); output.set_original_output_data_type(static_cast(output_descs.at(i).GetOriginDataType())); - // due to lhisi virtual addr bug, cannot use args now 
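// Illustrative sketch (not part of this change; helper name hypothetical): the replacement
// just below reads the dump address out of the task's args buffer instead of the real device
// address. It assumes the args buffer holds one pointer-sized slot per input followed by one
// per output, so output i of an op with input_size inputs lives at a fixed offset
// (needs <cstdint> and <cstddef>):
static inline uintptr_t OutputDumpAddr(uintptr_t args, size_t input_size, size_t i) {
  // assumed slot layout: [in_0 ... in_{n-1} | out_0 ... out_{m-1}]
  return args + (input_size + i) * sizeof(void *);
}
// e.g. for an op with 2 inputs, output 0 would be read from args + 2 * sizeof(void *).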
- output.set_address(static_cast(reinterpret_cast(output_addrs[i]))); + output.set_address(static_cast(inner_dump_info.args + (i + input_size) * sizeof(void *))); task.mutable_output()->Add(std::move(output)); } @@ -255,8 +269,8 @@ Status DataDumper::DumpOutput(const InnerDumpInfo &inner_dump_info, aicpu::dump: GELOGE(FAILED, "Index is out of range."); return FAILED; } - output.set_address( - static_cast(reinterpret_cast(output_addrs[inner_dump_info.output_anchor_index]))); + auto data_addr = inner_dump_info.args + sizeof(void *) * static_cast(inner_dump_info.input_anchor_index); + output.set_address(static_cast(data_addr)); task.mutable_output()->Add(std::move(output)); @@ -282,7 +296,7 @@ Status DataDumper::DumpInput(const InnerDumpInfo &inner_dump_info, aicpu::dump:: input.mutable_shape()->add_dim(dim); } - input.set_address(static_cast(reinterpret_cast(input_addrs[i]))); + input.set_address(static_cast(inner_dump_info.args + sizeof(void *) * i)); task.mutable_input()->Add(std::move(input)); } return SUCCESS; @@ -370,7 +384,10 @@ Status DataDumper::LoadDumpInfo() { } aicpu::dump::OpMappingInfo op_mapping_info; - op_mapping_info.set_dump_path(PropertiesManager::Instance().GetDumpOutputPath() + std::to_string(device_id_) + "/"); + std::string time_now = GetCurrentTime(); + GELOGI("Time is %s now", time_now.c_str()); + op_mapping_info.set_dump_path(PropertiesManager::Instance().GetDumpOutputPath() + time_now + "/" + + std::to_string(device_id_) + "/"); op_mapping_info.set_model_name(model_name_); op_mapping_info.set_model_id(model_id_); op_mapping_info.set_flag(kAicpuLoadFlag); diff --git a/src/ge/graph/load/new_model_manager/davinci_model.cc b/src/ge/graph/load/new_model_manager/davinci_model.cc index 33a4fcf4..a0e88f3c 100644 --- a/src/ge/graph/load/new_model_manager/davinci_model.cc +++ b/src/ge/graph/load/new_model_manager/davinci_model.cc @@ -45,6 +45,7 @@ #include "graph/load/output/output.h" #include "graph/manager/graph_mem_allocator.h" #include "graph/manager/graph_var_manager.h" +#include "graph/manager/trans_var_data_utils.h" #include "graph/manager/util/debug.h" #include "graph/model_serialize.h" #include "graph/node.h" @@ -75,6 +76,7 @@ namespace ge { namespace { const uint32_t kDataIndex = 0; +const uint32_t kOutputNum = 1; const uint32_t kTrueBranchStreamNum = 1; const uint32_t kThreadNum = 16; const uint32_t kAddrLen = sizeof(void *); @@ -83,275 +85,6 @@ const int kBytes = 8; const uint32_t kDataMemAlignSizeCompare = 64; const char *const kDefaultBatchLable = "Batch_default"; -class RtContextSwitchGuard { - public: - RtContextSwitchGuard(rtCtxMode_t mode, uint32_t device_id) : last_(nullptr), current_(nullptr) { - auto ret = rtCtxGetCurrent(&last_); - if (ret != RT_ERROR_NONE) { - GELOGE(RT_FAILED, "Failed to get current context from rt, error-code %d", ret); - return; - } - - ret = rtCtxCreate(¤t_, mode, static_cast(device_id)); - if (ret != RT_ERROR_NONE) { - GELOGE(RT_FAILED, "Failed to create new context for device %u, error-code %d", device_id, ret); - return; - } - - ret = rtCtxSetCurrent(current_); - if (ret != RT_ERROR_NONE) { - GELOGE(RT_FAILED, "Failed to switch context to normal, context %p, device %u", current_, device_id); - return; - } - GELOGD("Create and switch rt context %p type %d for device %u, backup last %p.", current_, mode, device_id, last_); - } - - ~RtContextSwitchGuard() { - if (current_ != nullptr) { - auto ret = rtCtxDestroy(current_); - GELOGD("Destory current context %p result %d", current_, ret); - } - if (last_ != nullptr) { - auto ret = 
rtCtxSetCurrent(last_); - GELOGD("Recovery last context %p result %d.", last_, ret); - } - } - - private: - rtContext_t last_; - rtContext_t current_; -}; - -int64_t CalcVarSizeInBytes(const GeTensorDesc &desc) { - int64_t var_size = GetSizeByDataType(desc.GetDataType()); - if (var_size <= 0) { - GELOGE(PARAM_INVALID, "Failed to calc var data size from data type %s", - TypeUtils::DataTypeToSerialString(desc.GetDataType()).c_str()); - return -1; - } - auto shape = desc.GetShape(); - auto dim_num = shape.GetDimNum(); - for (size_t dim_index = 0; dim_index < dim_num; ++dim_index) { - var_size *= shape.GetDim(dim_index); - } - return var_size; -} - -Status CopyVarFromDevice(uint64_t session_id, const NodePtr &var, std::unique_ptr &var_data, - const GeTensorDesc &input_desc) { - uint8_t *var_logic = nullptr; - GE_CHECK_NOTNULL(var); - auto ret = VarManager::Instance(session_id)->GetVarAddr(var->GetName(), input_desc, &var_logic); - if (ret != SUCCESS) { - GELOGE(INTERNAL_ERROR, - "Failed to copy var %s from device, can not find it" - " from var manager %u", - var->GetName().c_str(), ret); - return INTERNAL_ERROR; - } - - uint8_t *var_addr = VarManager::Instance(session_id)->GetVarMemoryAddr(var_logic, RT_MEMORY_HBM); - if (var_addr == nullptr) { - GELOGE(INTERNAL_ERROR, - "Failed to copy var %s from device, cant not get " - "var addr from logic addr %p", - var->GetName().c_str(), var_logic); - return INTERNAL_ERROR; - } - - int64_t var_size_bytes = CalcVarSizeInBytes(input_desc); - if (var_size_bytes <= 0) { - return INTERNAL_ERROR; - } - - std::unique_ptr var_host(new (std::nothrow) uint8_t[var_size_bytes]); - if (var_host == nullptr) { - GELOGE(OUT_OF_MEMORY, "Failed to malloc rt-host memory, size %ld", var_size_bytes); - return OUT_OF_MEMORY; - } - - ret = rtMemcpy(reinterpret_cast(var_host.get()), var_size_bytes, reinterpret_cast(var_addr), - var_size_bytes, RT_MEMCPY_DEVICE_TO_HOST); - if (ret != RT_ERROR_NONE) { - GELOGE(RT_FAILED, - "Failed to copy var memory from device, var %s, size %ld," - " rt-error-code %u", - var->GetName().c_str(), var_size_bytes, ret); - return RT_FAILED; - } - - GELOGD("Copy var %s from device to host, size %ld", var->GetName().c_str(), var_size_bytes); - var_data.swap(var_host); - - GELOGI("var_logic:%p, var_addr:%p", var_logic, var_addr); - - return SUCCESS; -} - -Status CopyVarToDevice(const NodePtr &var, const formats::TransResult &trans_result, void *var_addr) { - GELOGD("Copy var %s from host to device, size %zu", var->GetName().c_str(), trans_result.length); - auto ret = rtMemcpy(var_addr, trans_result.length, reinterpret_cast(trans_result.data.get()), - trans_result.length, RT_MEMCPY_HOST_TO_DEVICE); - if (ret != RT_ERROR_NONE) { - GELOGE(RT_FAILED, "Failed to copy memory to device, size %zu", trans_result.length); - return RT_FAILED; - } - return SUCCESS; -} - -Status TransVarOnHost(uint8_t *var_data, const VarTransRoad &trans_road, formats::TransResult &result) { - formats::TransResult result_last_time{}; - bool use_init_data = true; - for (const auto &trans_info : trans_road) { - if (trans_info.node_type == RESHAPE || trans_info.node_type == REFORMAT) { - GELOGD("Skip to trans variable data on the reshape/reformat node"); - continue; - } - uint8_t *src_data = nullptr; - if (use_init_data) { - src_data = var_data; - use_init_data = false; - } else { - src_data = result_last_time.data.get(); - } - - formats::TransResult tmp_result{}; - if (trans_info.node_type == TRANSDATA) { - auto src_format = trans_info.input.GetFormat(); - auto src_shape = 
trans_info.input.GetShape().GetDims(); - auto dst_format = trans_info.output.GetFormat(); - auto dst_shape = trans_info.output.GetShape().GetDims(); - auto data_type = trans_info.input.GetDataType(); - GELOGD("Trans format from %s to %s, shape %s to %s, data-type %s", - TypeUtils::FormatToSerialString(src_format).c_str(), TypeUtils::FormatToSerialString(dst_format).c_str(), - formats::ShapeToString(src_shape).c_str(), formats::ShapeToString(dst_shape).c_str(), - TypeUtils::DataTypeToSerialString(data_type).c_str()); - auto ret = formats::TransFormat({src_data, src_format, dst_format, src_shape, dst_shape, data_type}, tmp_result); - if (ret != SUCCESS) { - GELOGE(INTERNAL_ERROR, - "Failed to trans format from %s to %s, shape %s to %s, " - "data type %s error code %u", - TypeUtils::FormatToSerialString(src_format).c_str(), TypeUtils::FormatToSerialString(dst_format).c_str(), - formats::ShapeToString(src_shape).c_str(), formats::ShapeToString(dst_shape).c_str(), - TypeUtils::DataTypeToSerialString(data_type).c_str(), ret); - return ret; - } - } else if (trans_info.node_type == CAST) { - auto input_shape = trans_info.input.GetShape(); - auto src_data_size = input_shape.GetShapeSize() == 0 ? 1 : input_shape.GetShapeSize(); - auto src_data_type = trans_info.input.GetDataType(); - auto dst_data_type = trans_info.output.GetDataType(); - GELOGD("Trans data type from %s to %s, input shape %s, data size %ld", - TypeUtils::DataTypeToSerialString(src_data_type).c_str(), - TypeUtils::DataTypeToSerialString(dst_data_type).c_str(), formats::ShapeToString(input_shape).c_str(), - src_data_size); - auto ret = formats::TransDataType({src_data, static_cast(src_data_size), src_data_type, dst_data_type}, - tmp_result); - if (ret != SUCCESS) { - GELOGE(INTERNAL_ERROR, "Failed to trans data type from %s to %s, input shape %s, data size %ld, error code %u", - TypeUtils::DataTypeToSerialString(src_data_type).c_str(), - TypeUtils::DataTypeToSerialString(dst_data_type).c_str(), formats::ShapeToString(input_shape).c_str(), - src_data_size, ret); - return ret; - } - } else { - GELOGE(UNSUPPORTED, "Failed to trans var data, the trans type %s does not supported", - trans_info.node_type.c_str()); - return UNSUPPORTED; - } - result_last_time = tmp_result; - } - - result = result_last_time; - return SUCCESS; -} - -/// re-alloc var memory on device using var-manager -/// free origin var memory(var manager does not support now) -/// @param session_id -/// @param var -/// @param var_size_bytes -/// @param var_device -/// @return -Status ReAssignVarAddr(uint64_t session_id, const std::string &var_name, const GeTensorDesc &tensor_desc, - void **var_device) { - uint8_t *var_logic = nullptr; - Status ret = VarManager::Instance(session_id)->GetVarAddr(var_name, tensor_desc, &var_logic); - if (ret != SUCCESS) { - GELOGE(INTERNAL_ERROR, - "Failed to get var %s device addr, can not find it" - " from var manager %u", - var_name.c_str(), ret); - return INTERNAL_ERROR; - } - - uint8_t *var_addr = VarManager::Instance(session_id)->GetVarMemoryAddr(var_logic, RT_MEMORY_HBM); - if (var_addr == nullptr) { - GELOGE(INTERNAL_ERROR, "Failed to convert var %s logic addr to real addr", var_name.c_str()); - return INTERNAL_ERROR; - } - *var_device = var_addr; - - GELOGI("var_logic:%p, var_addr:%p", var_logic, var_addr); - - return SUCCESS; -} - -Status TransVarData(const NodePtr &var, const VarTransRoad &trans_road, uint64_t session_id) { - // do not need to do anything if only all reshape/reformat node on the trans_road - GE_CHECK_NOTNULL(var); - 
bool need_trans = false; - for (auto &road : trans_road) { - if (road.node_type != RESHAPE && road.node_type != REFORMAT) { - need_trans = true; - break; - } - } - if (!need_trans) { - return SUCCESS; - } - - // Sync var data from device - std::unique_ptr var_data; - if (trans_road.size() == 0) { - GELOGE(INTERNAL_ERROR, "Failed to get trans_road, trans_road is empty."); - return INTERNAL_ERROR; - } - const GeTensorDesc &input_desc = trans_road.begin()->input; - auto ret = CopyVarFromDevice(session_id, var, var_data, input_desc); - if (ret != SUCCESS) { - return ret; - } - - formats::TransResult trans_result{}; - ret = TransVarOnHost(var_data.get(), trans_road, trans_result); - if (ret != SUCCESS) { - GELOGE(ret, "Failed to trans var data on host, error code %u", ret); - return ret; - } - - void *var_device = nullptr; - - /// It is a temporary solution to use the last GeTensorDesc to assign variable memory because the variable manager - /// depends on TensorDesc and it is difficult to be modified. The correct solution is to assign memory based on the - /// size of the converted variable. To complete the final solution, the dependency of the variable manager on - /// TensorDesc needs to be removed. This change is large and needs to be performed step by step. - ret = ReAssignVarAddr(session_id, var->GetName(), trans_road.rbegin()->output, &var_device); - if (ret != SUCCESS) { - GELOGE(ret, "Failed to re-assign memory on device, size %zu", trans_result.length); - return ret; - } - - // sync new data to device - ret = CopyVarToDevice(var, trans_result, var_device); - if (ret != SUCCESS) { - GELOGE(ret, "Failed to send var data to device"); - return ret; - } - - return SUCCESS; -} - inline bool IsDataOp(const std::string &node_type) { return node_type == DATA_TYPE || node_type == AIPP_DATA_TYPE || node_type == ANN_DATA_TYPE; } @@ -474,6 +207,14 @@ DavinciModel::~DavinciModel() { CleanTbeHandle(); var_mem_base_ = nullptr; + if (known_node_) { + if (args_ != nullptr) { + GE_CHK_RT(rtFree(args_)); + } + if (args_host_ != nullptr) { + GE_CHK_RT(rtFreeHost(args_host_)); + } + } } catch (...) 
{ GELOGW("DavinciModel::~DavinciModel: clear op_list catch exception."); } @@ -574,6 +315,14 @@ Status DavinciModel::InitModelMem(void *dev_ptr, size_t mem_size, void *weight_p GELOGI("copy weights data to device"); } + GE_CHK_STATUS_RET(InitVariableMem(), "init variable mem failed."); + runtime_param_.mem_base = mem_base_; + runtime_param_.weight_base = weights_mem_base_; + return SUCCESS; +} + +Status DavinciModel::InitVariableMem() { + // malloc variable memory base var_mem_base_ = VarManager::Instance(session_id_)->GetVarMemoryBase(RT_MEMORY_HBM); if (TotalVarMemSize() && var_mem_base_ == nullptr) { Status ret = VarManager::Instance(session_id_)->MallocVarMemory(TotalVarMemSize()); @@ -582,12 +331,9 @@ Status DavinciModel::InitModelMem(void *dev_ptr, size_t mem_size, void *weight_p return ret; } var_mem_base_ = VarManager::Instance(session_id_)->GetVarMemoryBase(RT_MEMORY_HBM); - GELOGI("[IMAS]InitModelMem graph_%u MallocMemory type[V] memaddr[%p] mem_size[%zu]", runtime_param_.graph_id, + GELOGI("[IMAS]InitVariableMem graph_%u MallocMemory type[V] memaddr[%p] mem_size[%zu]", runtime_param_.graph_id, var_mem_base_, TotalVarMemSize()); } - - runtime_param_.mem_base = mem_base_; - runtime_param_.weight_base = weights_mem_base_; runtime_param_.var_base = var_mem_base_; return SUCCESS; } @@ -618,11 +364,15 @@ void DavinciModel::InitRuntimeParams() { ret = ge::AttrUtils::GetInt(ge_model_, ATTR_MODEL_VAR_SIZE, value); runtime_param_.var_size = ret ? (uint64_t)value : 0; session_id_ = runtime_param_.session_id; - GELOGI("InitRuntimeParams(), memory_size:%lu, weight_size:%lu, stream_num:%u, session_id:%u, var_size:%lu.", - runtime_param_.mem_size, runtime_param_.weight_size, runtime_param_.stream_num, runtime_param_.session_id, - runtime_param_.var_size); - GELOGI("InitRuntimeParams(), event_num:%u, label_num:%u", runtime_param_.event_num, runtime_param_.label_num); + GELOGI( + "InitRuntimeParams(), memory_size:%lu, weight_size:%lu, session_id:%u, var_size:%lu, logic_var_base:%lu, " + "logic_mem_base:%lu.", + runtime_param_.mem_size, runtime_param_.weight_size, runtime_param_.session_id, runtime_param_.var_size, + runtime_param_.logic_var_base, runtime_param_.logic_mem_base); + + GELOGI("InitRuntimeParams(), stream_num:%lu, event_num:%u, label_num:%u", runtime_param_.stream_num, + runtime_param_.event_num, runtime_param_.label_num); } void DavinciModel::CheckHasHcomOp() { @@ -639,7 +389,9 @@ void DavinciModel::CheckHasHcomOp() { GE_IF_BOOL_EXEC(op_desc == nullptr, GELOGW("Node OpDesc is nullptr"); continue); GE_IF_BOOL_EXEC(((op_desc->GetType() == HCOMBROADCAST) || (op_desc->GetType() == HCOMALLGATHER) || (op_desc->GetType() == HCOMALLREDUCE) || (op_desc->GetType() == HCOMSEND) || - (op_desc->GetType() == HCOMRECEIVE) || (op_desc->GetType() == HCOMREDUCESCATTER)), + (op_desc->GetType() == HCOMRECEIVE) || (op_desc->GetType() == HCOMREDUCESCATTER) || + (op_desc->GetType() == HVDCALLBACKALLREDUCE) || (op_desc->GetType() == HVDCALLBACKALLGATHER) || + (op_desc->GetType() == HVDCALLBACKBROADCAST) || (op_desc->GetType() == HVDWAIT)), uint32_t stream_id = static_cast(op_desc->GetStreamId()); (void)hcom_streams_.emplace(stream_id); GELOGD("hcom stream: %u.", stream_id); continue); @@ -692,6 +444,10 @@ Status DavinciModel::DoTaskSink() { GELOGI("do task_sink."); GE_CHK_STATUS_RET(BindModelStream(), "Bind model stream failed."); + if (known_node_) { + GE_CHK_STATUS_RET(MallocKnownArgs(), "Mallloc known node args failed."); + } + GE_CHK_STATUS_RET(InitTaskInfo(*model_task_def_.get()), "InitTaskInfo 
failed."); GE_CHK_STATUS_RET(LoadWithQueue(), "LoadWithQueue failed."); @@ -787,12 +543,14 @@ Status DavinciModel::Init(void *dev_ptr, size_t mem_size, void *weight_ptr, size GE_CHK_STATUS_RET(CopyVarData(compute_graph_), "copy var data failed."); GE_TIMESTAMP_START(InitModelMem); - GE_CHK_STATUS_RET_NOLOG(InitModelMem(dev_ptr, mem_size, weight_ptr, weight_size)); + GELOGI("known_node is %d", known_node_); + if (!known_node_) { + GE_CHK_STATUS_RET_NOLOG(InitModelMem(dev_ptr, mem_size, weight_ptr, weight_size)); + data_inputer_ = new (std::nothrow) DataInputer(); + GE_CHK_BOOL_RET_STATUS(data_inputer_ != nullptr, INTERNAL_ERROR, "data_inputer_ is nullptr."); + } GE_TIMESTAMP_END(InitModelMem, "GraphLoader::InitModelMem"); - data_inputer_ = new (std::nothrow) DataInputer(); - GE_CHK_BOOL_RET_STATUS(data_inputer_ != nullptr, INTERNAL_ERROR, "data_inputer_ is nullptr."); - for (const ge::NodePtr &node : compute_graph_->GetDirectNode()) { GE_IF_BOOL_EXEC(node->GetOpDesc() == nullptr, continue); GE_IF_BOOL_EXEC(node->GetOpDesc()->GetType() != VARIABLE, continue); @@ -817,7 +575,6 @@ Status DavinciModel::Init(void *dev_ptr, size_t mem_size, void *weight_ptr, size } SetDataDumperArgs(); - GE_TIMESTAMP_START(DoTaskSink); auto ret = DoTaskSink(); GE_TIMESTAMP_END(DoTaskSink, "GraphLoader::DoTaskSink"); @@ -832,6 +589,7 @@ Status DavinciModel::Init(void *dev_ptr, size_t mem_size, void *weight_ptr, size } ProfilingManager::Instance().ReportProfilingData(GetTaskDescInfo(), compute_graph_desc_info); } + GELOGI("davinci model init success."); return ret; } @@ -935,6 +693,10 @@ Status DavinciModel::InitNodes(const ComputeGraphPtr &compute_graph) { Status DavinciModel::InitDataOp(const NodePtr &node, uint32_t &data_op_index) { // op_desc Checked by Init: Data, valid. auto op_desc = node->GetOpDesc(); + if (known_node_) { + data_op_list_.push_back(op_desc); + return SUCCESS; + } uint32_t parent_index = 0; // Ignore subgraph Data Node. if (AttrUtils::GetInt(op_desc, ATTR_NAME_PARENT_NODE_INDEX, parent_index)) { GELOGI("Skip subgraph Data node: %s.", op_desc->GetName().c_str()); @@ -1015,6 +777,10 @@ Status DavinciModel::InitInputZeroCopyNodes(const NodePtr &node) { Status DavinciModel::InitNetOutput(const NodePtr &node) { // node->GetOpDesc Checked by Init: NetOutput, valid. auto op_desc = node->GetOpDesc(); + if (known_node_) { + output_op_list_.push_back(op_desc); + return SUCCESS; + } ComputeGraphPtr owner_graph = node->GetOwnerComputeGraph(); GE_CHECK_NOTNULL(owner_graph); if (owner_graph->GetParentGraph() != nullptr) { @@ -1024,7 +790,6 @@ Status DavinciModel::InitNetOutput(const NodePtr &node) { } output_op_list_.push_back(op_desc); - // Make information for copy output data. 
const vector input_size_list = ModelUtils::GetInputSize(op_desc); const vector virtual_addr_list = ModelUtils::GetInputDataAddrs(runtime_param_, op_desc, false); @@ -1048,6 +813,7 @@ Status DavinciModel::InitNetOutput(const NodePtr &node) { GELOGE(PARAM_INVALID, "Output zero copy nodes init failed!"); return PARAM_INVALID; } + GELOGI("DavinciModel::InitNetOutput success."); return SUCCESS; } @@ -1605,7 +1371,9 @@ Status DavinciModel::GetOutputDescInfo(vector &output_desc, for (size_t i = 0; i < output_op_list_.size(); i++) { auto &op_desc = output_op_list_[i]; uint32_t out_size = static_cast(op_desc->GetInputsSize()); - + // get real out nodes from model + vector out_node_name; + (void)ge::AttrUtils::GetListStr(ge_model_, ATTR_MODEL_OUT_NODES_NAME, out_node_name); for (uint32_t index = 0; index < out_size; index++) { string output_name; InputOutputDescInfo output; @@ -1616,10 +1384,14 @@ Status DavinciModel::GetOutputDescInfo(vector &output_desc, std::vector src_index = op_desc->GetSrcIndex(); GE_CHK_BOOL_RET_STATUS(src_name.size() > index && src_index.size() > index, INTERNAL_ERROR, "construct output_name failed."); - output_name = - std::string("output_") + std::to_string(index) + "_" + src_name[index] + "_" + std::to_string(src_index[index]); + // forward compatibility: if an old om has no out_node_name, return output names the original way + if (out_size == out_node_name.size()) { + output_name = out_node_name[index] + ":" + std::to_string(src_index[index]); + } else { + output_name = std::string("output_") + std::to_string(index) + "_" + src_name[index] + "_" + + std::to_string(src_index[index]); + } output.name = output_name; - output_desc.push_back(output); formats.push_back(format_result); } @@ -1653,8 +1425,8 @@ Status DavinciModel::CopyInputData(const InputData &input_data, bool device_data "input data size(%u) does not match model required size(%u), ret failed.", data_buf.length, mem_size); - GELOGI("[IMAS]CopyPlainData memcpy graph_%u type[F] input[%u] memaddr[%p] mem_size[%u] datasize[%u]", - runtime_param_.graph_id, data.first, mem_addr, mem_size, data_buf.length); + GELOGI("[IMAS]CopyPlainData memcpy graph_%u type[F] input[%u] dst[%p] src[%p] mem_size[%u] datasize[%u]", - runtime_param_.graph_id, data.first, mem_addr, data_buf.data, mem_size, data_buf.length); if (data_buf.length == 0) { GELOGW("No data need to memcpy!"); return SUCCESS; @@ -2000,7 +1772,7 @@ Status DavinciModel::CopyOutputData(uint32_t data_id, OutputData &output_data) { uint32_t output_data_index = 0; for (auto &op_desc : output_op_list_) { ret = CopyOutputDataToUser(op_desc, output_data.blobs, output_data_index); - GE_CHK_BOOL_EXEC(ret == SUCCESS, break, "Copy input data to model ret failed, index:%u, model id:%u", + GE_CHK_BOOL_EXEC(ret == SUCCESS, break, "Copy output data to model ret failed, index:%u, model id:%u", output_data.index, output_data.model_id); } } @@ -2032,8 +1804,10 @@ Status DavinciModel::CopyOutputDataToUser(OpDescPtr &op_desc, std::vectorGetName().c_str(), i, data_buf.data, data_buf.length, v_output_size[i]); + GELOGI( + "CopyOutputDataToUser memcpy graph_%u type[F] name[%s] output[%lu] dst[%p] src[%p] mem_size[%u] datasize[%u]", + runtime_param_.graph_id, op_desc->GetName().c_str(), i, data_buf.data, v_output_data_addr[i], data_buf.length, + v_output_size[i]); GE_CHK_RT_RET(rtMemcpy(data_buf.data, size, v_output_data_addr[i], size, RT_MEMCPY_DEVICE_TO_DEVICE)); } @@ -2104,14 +1878,9 @@ Status DavinciModel::ReturnResult(uint32_t data_id, const bool rslt_flg, const b OutputData 
*output_data) { GE_CHK_BOOL_EXEC(listener_ != nullptr, return PARAM_INVALID, "listener_ is null."); std::vector outputs; - if (seq_end_flag) { - GELOGW("End of sequence, model id: %u", model_id_); - GE_CHK_STATUS(listener_->OnComputeDone(model_id_, data_id, END_OF_SEQUENCE, outputs), "OnComputeDone failed"); - return END_OF_SEQUENCE; - } // return result is not required - if (!rslt_flg) { + if (!rslt_flg && !seq_end_flag) { GELOGW("Compute failed, model id: %u", model_id_); GE_CHK_STATUS(listener_->OnComputeDone(model_id_, data_id, INTERNAL_ERROR, outputs), "OnComputeDone failed."); return INTERNAL_ERROR; @@ -2146,7 +1915,11 @@ Status DavinciModel::ReturnResult(uint32_t data_id, const bool rslt_flg, const b } GE_IF_BOOL_EXEC((DumpOpInputOutput() != SUCCESS), GELOGW("dump op failed, model_id: %u", model_id_);); - + if (seq_end_flag) { + GELOGW("End of sequence, model id: %u", model_id_); + GE_CHK_STATUS(listener_->OnComputeDone(model_id_, data_id, END_OF_SEQUENCE, outputs), "OnComputeDone failed."); + return END_OF_SEQUENCE; + } GE_CHK_STATUS(listener_->OnComputeDone(model_id_, data_id, SUCCESS, outputs), "OnComputeDone failed"); return SUCCESS; } @@ -2493,6 +2266,87 @@ void DavinciModel::UnbindTaskSinkStream() { return; } +Status DavinciModel::CreateKnownZeroCopyMap(const vector &inputs, const vector &outputs) { + GELOGI("DavinciModel::CreateKnownZeroCopyMap in."); + if (inputs.size() != data_op_list_.size()) { + GELOGE(FAILED, "input data addr %zu is not equal to input op number %zu.", inputs.size(), data_op_list_.size()); + return FAILED; + } + for (size_t i = 0; i < data_op_list_.size(); ++i) { + const vector addr_list = ModelUtils::GetOutputDataAddrs(runtime_param_, data_op_list_[i]); + knonw_input_data_info_[addr_list[kDataIndex]] = inputs[i]; + GELOGI("DavinciModel::CreateKnownZeroCopyMap input %zu, v addr %p, p addr %p.", i, addr_list[kDataIndex], inputs[i]); + } + if (output_op_list_.size() != kOutputNum) { + GELOGE(FAILED, "output op num is %zu, not equal %u.", output_op_list_.size(), kOutputNum); + return FAILED; + } + const vector addr_list = ModelUtils::GetInputDataAddrs(runtime_param_, output_op_list_[kDataIndex]); + if (outputs.size() != addr_list.size()) { + GELOGE(FAILED, "output data addr %zu is not equal to output op number %zu.", outputs.size(), addr_list.size()); + return FAILED; + } + for (size_t i = 0; i < addr_list.size(); ++i) { + knonw_output_data_info_[addr_list[i]] = outputs[i]; + GELOGI("DavinciModel::CreateKnownZeroCopyMap output %zu, v addr %p, p addr %p.", i, addr_list[i], outputs[i]); + } + GELOGI("DavinciModel::CreateKnownZeroCopyMap success."); + return SUCCESS; +} + +Status DavinciModel::UpdateKnownZeroCopyAddr(vector &io_addrs, uint32_t args_offset) { + for (size_t i = 0; i < io_addrs.size(); ++i) { + auto it_in = knonw_input_data_info_.find(io_addrs[i]); + if (it_in != knonw_input_data_info_.end()) { + GELOGI("DavinciModel::UpdateKnownZeroCopyAddr input %zu, v addr %p, p addr %p.", i, io_addrs[i], + knonw_input_data_info_.at(io_addrs[i])); + io_addrs[i] = knonw_input_data_info_.at(io_addrs[i]); + } + auto it_out = knonw_output_data_info_.find(io_addrs[i]); + if (it_out != knonw_output_data_info_.end()) { + GELOGI("DavinciModel::UpdateKnownZeroCopyAddr output %zu, v addr %p, p addr %p.", i, io_addrs[i], + knonw_output_data_info_.at(io_addrs[i])); + io_addrs[i] = knonw_output_data_info_.at(io_addrs[i]); + } + } + // Note: args_size may be equal to src_args_size. 
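// Illustrative sketch of the copy below (assumed layout, not part of this change):
// each entry of io_addrs, now rewritten to its user-provided address, occupies one
// 8-byte slot in the host-side args table starting at args_offset:
//   args_host_ + args_offset + 0 * sizeof(uint64_t)  <-  io_addrs[0]
//   args_host_ + args_offset + 1 * sizeof(uint64_t)  <-  io_addrs[1]
//   ...
// UpdateKnownNodeArgs later pushes the whole args_host_ block to the device-side
// args_ buffer with a single rtMemcpy.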
+ uint32_t src_args_size = io_addrs.size() * sizeof(uint64_t); + GELOGI("DavinciModel::UpdateKnownZeroCopyAddr args host %p, src_args_size %u, args_offset %u", args_host_, + src_args_size, args_offset); + errno_t sec_ret = + memcpy_s(static_cast(args_host_) + args_offset, src_args_size, io_addrs.data(), src_args_size); + if (sec_ret != EOK) { + GELOGE(FAILED, "Call memcpy_s failed, ret: %d", sec_ret); + return FAILED; + } + GELOGI("DavinciModel::UpdateKnownZeroCopyAddr success."); + return SUCCESS; +} + +Status DavinciModel::UpdateKnownNodeArgs(const vector &inputs, const vector &outputs) { + GELOGI("DavinciModel::UpdateKnownNodeArgs in"); + GE_CHK_STATUS_RET(CreateKnownZeroCopyMap(inputs, outputs), + "DavinciModel::UpdateKnownNodeArgs create map for input/output zero copy."); + for (size_t task_index = 0; task_index < task_list_.size(); ++task_index) { + auto &task = task_list_[task_index]; + if (task != nullptr) { + Status ret = task->UpdateArgs(); + if (ret != SUCCESS) { + GELOGE(FAILED, "task %zu UpdateArgs failed.", task_index); + return FAILED; + } + } + } + GELOGI("DavinciModel::UpdateKnownNodeArgs device args %p, size %u, host args %p, size %u", args_, total_args_size_, + args_host_, total_args_size_); + // copy continuous args from host to device + Status rt_ret = rtMemcpy(args_, total_args_size_, args_host_, total_args_size_, RT_MEMCPY_HOST_TO_DEVICE); + GE_IF_BOOL_EXEC(rt_ret != RT_ERROR_NONE, GELOGE(rt_ret, "rtMemcpy error, ret: 0x%X", rt_ret); return FAILED;) + + GELOGI("DavinciModel::UpdateKnownNodeArgs success"); + return SUCCESS; +} + Status DavinciModel::InitTaskInfo(domi::ModelTaskDef &model_task_def) { GELOGI("InitTaskInfo in,task size %zu", model_task_def.task().size()); task_list_.resize(model_task_def.task_size()); @@ -2513,13 +2367,13 @@ Status DavinciModel::InitTaskInfo(domi::ModelTaskDef &model_task_def) { GELOGE(RT_FAILED, "Failed to set context from rt, error-code 0x%X.", rt_ret); return RT_FAILED; } - - model->task_list_[idx] = TaskInfoFactory::Instance().Create(static_cast(task.type())); - Status ret = FAILED; - if (model->task_list_[idx] != nullptr) { - ret = model->task_list_[idx]->Init(task, model); + // dynamic shape will create task_list_ before + if (model->task_list_[idx] == nullptr) { + model->task_list_[idx] = TaskInfoFactory::Instance().Create(static_cast(task.type())); + GE_CHECK_NOTNULL(model->task_list_[idx]); } + Status ret = model->task_list_[idx]->Init(task, model); return ret; }, model_task_def.task(i), this, ctx, i); @@ -2543,6 +2397,39 @@ Status DavinciModel::InitTaskInfo(domi::ModelTaskDef &model_task_def) { return SUCCESS; } +Status DavinciModel::MallocKnownArgs() { + GELOGI("DavinciModel::MallocKnownArgs in"); + if (model_task_def_->task_size() == 0) { + GELOGW("DavinciModel::MallocKnownArgs davinci model has no task info."); + return SUCCESS; + } + task_list_.resize(model_task_def_->task_size()); + for (int32_t i = 0; i < model_task_def_->task_size(); ++i) { + const domi::TaskDef &taskdef = model_task_def_->task(i); + task_list_[i] = TaskInfoFactory::Instance().Create(static_cast(taskdef.type())); + GE_CHECK_NOTNULL(task_list_[i]); + Status ret = task_list_[i]->CalculateArgs(taskdef, this); + if (ret != SUCCESS) { + GELOGE(ret, "TaskInfo CalculateArgs failed."); + return ret; + } + } + // malloc args memory + rtError_t rt_ret = rtMalloc(&args_, total_args_size_, RT_MEMORY_HBM); + if (rt_ret != RT_ERROR_NONE) { + GELOGE(RT_FAILED, "Call rtMalloc failed, ret: 0x%X", rt_ret); + return RT_FAILED; + } + // malloc args host memory + 
rt_ret = rtMallocHost(&args_host_, total_args_size_); + if (rt_ret != RT_ERROR_NONE) { + GELOGE(RT_FAILED, "Call rtMallocHost failed, ret: 0x%X", rt_ret); + return RT_FAILED; + } + GELOGI("DavinciModel::MallocKnownArgs success, total args size %u.", total_args_size_); + return SUCCESS; +} + Status DavinciModel::DistributeTask() { GELOGI("do Distribute."); for (auto &task : cpu_task_list_) { @@ -3117,7 +3004,7 @@ bool DavinciModel::IsBroadCastOpData(const ge::NodePtr &var_node) { GE_RT_FALSE_CHECK_NOTNULL(in_anchor); ge::NodePtr dst_node = in_anchor->GetOwnerNode(); GE_RT_FALSE_CHECK_NOTNULL(dst_node); - if (dst_node->GetType() == HCOMBROADCAST) { + if (dst_node->GetType() == HCOMBROADCAST || dst_node->GetType() == HVDCALLBACKBROADCAST) { return true; } } @@ -3126,32 +3013,15 @@ bool DavinciModel::IsBroadCastOpData(const ge::NodePtr &var_node) { } void DavinciModel::InitZeroCopyUtil(bool is_dynamic_batch, bool &input_zero_copy, bool &output_zero_copy) { - auto dump_path = PropertiesManager::Instance().GetDumpOutputPath(); - auto enable_dump = !dump_path.empty(); - - auto dump_op_env = std::getenv("DUMP_OP"); - if (dump_op_env != nullptr) { - string dump_op_flag(dump_op_env); - if (dump_op_flag == "1") { - enable_dump = true; - } - } - - GELOGI("dump path: %s, dump_op_env: %s", dump_path.c_str(), dump_op_env); if (!is_dynamic_batch) { zero_copy_batch_label_addrs_.clear(); } - if (enable_dump) { - input_zero_copy = false; - output_zero_copy = false; - } else { - for (const auto &addrs : output_outside_addrs_) { - const auto &used_list = addrs.second; - if (used_list.empty()) { - output_zero_copy = false; - break; - } + for (const auto &addrs : output_outside_addrs_) { + const auto &used_list = addrs.second; + if (used_list.empty()) { + output_zero_copy = false; + break; } } } @@ -3244,11 +3114,11 @@ Status DavinciModel::NnExecute(rtStream_t stream, bool async_mode, const InputDa return SUCCESS; } -uint8_t *DavinciModel::MallocFeatureMapMem(uint64_t data_size) { +uint8_t *DavinciModel::MallocFeatureMapMem(size_t data_size) { uint8_t *mem_base = nullptr; const string purpose("feature map,used for op input and output."); if (std::getenv(kEnvGeuseStaticMemory) != nullptr) { - data_size = static_cast(VarManager::Instance(0)->GetGraphMemoryMaxSize()); + data_size = static_cast(VarManager::Instance(0)->GetGraphMemoryMaxSize()); string memory_key = std::to_string(0) + "_f"; mem_base = MemManager::Instance(RT_MEMORY_HBM)->MallocMemory(purpose, memory_key, data_size, GetDeviceId()); } else { @@ -3261,7 +3131,7 @@ uint8_t *DavinciModel::MallocFeatureMapMem(uint64_t data_size) { return mem_base; } -uint8_t *DavinciModel::MallocWeightsMem(uint32_t weights_size) { +uint8_t *DavinciModel::MallocWeightsMem(size_t weights_size) { uint8_t *weights_mem_base = nullptr; const string purpose("weights memory in inference network."); if (std::getenv(kEnvGeuseStaticMemory) != nullptr) { @@ -3319,10 +3189,6 @@ uint32_t DavinciModel::GetGraphID(const std::string &session_graph_id) { Status DavinciModel::TransAllVarData(ComputeGraphPtr &graph, uint32_t graph_id) { GELOGI("TransAllVarData start: session_id:%lu, graph_id: %u.", session_id_, graph_id); - - ThreadPool executor(kThreadNum); - std::vector> vector_future; - rtContext_t ctx = nullptr; rtError_t rt_ret = rtCtxGetCurrent(&ctx); if (rt_ret != RT_ERROR_NONE) { @@ -3330,6 +3196,7 @@ Status DavinciModel::TransAllVarData(ComputeGraphPtr &graph, uint32_t graph_id) return RT_FAILED; } + std::vector variable_node_list; for (ge::NodePtr &node : graph->GetDirectNode()) { 
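// Note: this loop now only collects VARIABLE nodes. The per-node work that used to
// run here (rtCtxSetCurrent, allocated/changed graph-id checks, trans-road lookup,
// host-side transform) moved to TransVarDataUtils::TransAllVarData, invoked once
// below across kThreadNum worker threads.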
if (node == nullptr) { continue; @@ -3337,63 +3204,13 @@ Status DavinciModel::TransAllVarData(ComputeGraphPtr &graph, uint32_t graph_id) if (node->GetType() != VARIABLE) { continue; } - std::future f = executor.commit( - [](ge::NodePtr &node, DavinciModel *model, rtContext_t ctx, uint32_t graph_id) -> Status { - if (model == nullptr) { - GELOGE(FAILED, "DavinciModel is NULL!"); - return FAILED; - } - rtError_t rt_ret = rtCtxSetCurrent(ctx); - if (rt_ret != RT_ERROR_NONE) { - GELOGE(RT_FAILED, "Failed to set context, error_code is: 0x%X.", rt_ret); - return RT_FAILED; - } - uint32_t allocated_graph_id = 0; - Status ret = VarManager::Instance(model->session_id_)->GetAllocatedGraphId(node->GetName(), allocated_graph_id); - if (ret != SUCCESS) { - GELOGE(INTERNAL_ERROR, "var has not been allocated, node:%s, graph_id:%u.", node->GetName().c_str(), - graph_id); - return INTERNAL_ERROR; - } - uint32_t changed_graph_id = 0; - ret = VarManager::Instance(model->session_id_)->GetChangedGraphId(node->GetName(), changed_graph_id); - bool call_trans_var = - (ret == SUCCESS && changed_graph_id == graph_id && changed_graph_id != allocated_graph_id); - if (call_trans_var) { - GELOGI("VarManager::GetChangedGraphId() success, node:%s, graph_id:%u.", node->GetName().c_str(), graph_id); - VarTransRoad *trans_road = VarManager::Instance(model->session_id_)->GetTransRoad(node->GetName()); - if (trans_road == nullptr) { - GELOGI("The variable %s does not have any trans road", node->GetName().c_str()); - return SUCCESS; - } - ret = TransVarData(node, *trans_road, model->session_id_); - if (ret != SUCCESS) { - GELOGE(INTERNAL_ERROR, "TransVarData failed, node:%s, graph_id:%u.", node->GetName().c_str(), graph_id); - return INTERNAL_ERROR; - } - VarManager::Instance(model->session_id_)->RemoveChangedGraphId(node->GetName()); - } - return SUCCESS; - }, - node, this, ctx, graph_id); - if (!f.valid()) { - GELOGE(FAILED, "Future is invalid"); - return FAILED; - } - vector_future.push_back(std::move(f)); + variable_node_list.emplace_back(node); } - Status ret_status; - for (size_t i = 0; i < vector_future.size(); ++i) { - ret_status = vector_future[i].get(); - if (ret_status != SUCCESS) { - GELOGE(ret_status, "TransAllVarData:: trans %zu vardata failed", i); - return ret_status; - } - } + GE_CHK_STATUS_RET_NOLOG( + TransVarDataUtils::TransAllVarData(variable_node_list, session_id_, ctx, graph_id, kThreadNum)); GELOGI("TransAllVarData success."); - return SUCCESS; } @@ -3457,96 +3274,8 @@ void DavinciModel::ReuseHcclFollowStream(int64_t remain_cap, int64_t &index) { } } -Status TransTensor(uint8_t *var_data, const NodePtr &var_src, const NodePtr &var_dst, formats::TransResult &result) { - GE_CHECK_NOTNULL(var_src); - GE_CHECK_NOTNULL(var_src->GetOpDesc()); - GE_CHECK_NOTNULL(var_dst); - GE_CHECK_NOTNULL(var_dst->GetOpDesc()); - auto src_data_shape_size = var_src->GetOpDesc()->GetOutputDesc(0).GetShape().GetShapeSize(); - auto src_data_datatype = var_src->GetOpDesc()->GetOutputDesc(0).GetDataType(); - auto dst_data_datatype = var_dst->GetOpDesc()->GetOutputDesc(0).GetDataType(); - GE_IF_BOOL_EXEC( - src_data_datatype != dst_data_datatype, - auto ret = formats::TransDataType( - {var_data, static_cast(src_data_shape_size), src_data_datatype, dst_data_datatype}, result); - if (ret != SUCCESS) { - GELOGE(INTERNAL_ERROR, "trans var data on host failed"); - return ret; - }); - return SUCCESS; -} - -Status DavinciModel::CopyTensorFromSrcVarNode(const NodePtr &var_src, const NodePtr &var_dst) { - /// after FE fusion pass, input num 
of applymomentum op was changed, 0th input is var_fp32, 6th input is - /// var_fp16(new). - /// unlink edges between var_fp32 and "dst_node" (need fp16) of var_fp32, add edge between var_fp16 and dst_node. - /// need copy value from var_fp32 to var_fp16. - /// [opdesc of var_src and var_dst are checked before passed in, no need to check if they are nullptr] - GE_IF_BOOL_EXEC(var_src == nullptr || var_dst == nullptr, GELOGE(FAILED, "node var is nullptr"); return FAILED); - // src_node output_desc (fp32) - GeTensorDesc output_desc = var_src->GetOpDesc()->GetOutputDesc(0); - auto src_data_type = output_desc.GetDataType(); - auto src_shape = output_desc.GetShape(); - auto src_format = output_desc.GetFormat(); - GELOGI("src_node %s, src_format %s, src_shape %s, src_type %s", var_src->GetName().c_str(), - TypeUtils::FormatToSerialString(src_format).c_str(), formats::ShapeToString(src_shape).c_str(), - TypeUtils::DataTypeToSerialString(src_data_type).c_str()); - // dst_node output_desc (fp16) - GeTensorDesc dst_tensor_desc = var_dst->GetOpDesc()->GetOutputDesc(0); - auto data_type = dst_tensor_desc.GetDataType(); - auto data_shape = dst_tensor_desc.GetShape(); - auto data_format = dst_tensor_desc.GetFormat(); - GELOGI("dst_node %s, src_format %s, src_shape %s, src_type %s", var_dst->GetName().c_str(), - TypeUtils::FormatToSerialString(data_format).c_str(), formats::ShapeToString(data_shape).c_str(), - TypeUtils::DataTypeToSerialString(data_type).c_str()); - // Sync var data from device - std::unique_ptr var_src_data; - RtContextSwitchGuard switch_context(RT_CTX_NORMAL_MODE, device_id_); - // copy from src_node - auto ret = CopyVarFromDevice(session_id_, var_src, var_src_data, output_desc); - GE_IF_BOOL_EXEC(ret != SUCCESS, GELOGE(FAILED, "Copy Var From Device failed"); return ret); - // trans dtype - formats::TransResult trans_result; - ret = TransTensor(var_src_data.get(), var_src, var_dst, trans_result); - GE_IF_BOOL_EXEC(ret != SUCCESS, GELOGE(INTERNAL_ERROR, "trans var data on host failed"); return ret); - // reset src value. 
- void *var_device = nullptr; - ret = ReAssignVarAddr(session_id_, var_dst->GetName(), dst_tensor_desc, &var_device); - GE_IF_BOOL_EXEC(ret != SUCCESS, GELOGE(INTERNAL_ERROR, "assign mem failed"); return ret); - // copy to device - ret = CopyVarToDevice(var_dst, trans_result, var_device); - GE_IF_BOOL_EXEC(ret != SUCCESS, GELOGE(ret, "Failed to send var data to device"); return ret); - return SUCCESS; -} - Status DavinciModel::CopyVarData(ComputeGraphPtr &compute_graph) { - GELOGI("CopyVarData start: session_id:%lu.", session_id_); - if (compute_graph == nullptr) { - GELOGE(FAILED, "compute_graph is nullptr"); - return FAILED; - } - - string cp_from_node; - bool copy_value = false; - for (auto &node : compute_graph->GetAllNodes()) { - GE_IF_BOOL_EXEC(node->GetOpDesc() == nullptr || node->GetOpDesc()->GetType() != VARIABLE, continue); - GE_IF_BOOL_EXEC(ge::AttrUtils::GetStr(node->GetOpDesc(), "_copy_from_var_node", cp_from_node), - GELOGI("Get original type of cp_from_node")); - if (cp_from_node.length() != 0) { - (void)ge::AttrUtils::GetBool(node->GetOpDesc(), "_copy_value", copy_value); // no need to check value - if (!copy_value) { - auto src_node = compute_graph->FindNode(cp_from_node); - GE_CHECK_NOTNULL(src_node); - GELOGI("current_var_node__: [%s] copy_from_var_node__: [%s].", node->GetName().c_str(), - src_node->GetName().c_str()); - auto ret = CopyTensorFromSrcVarNode(src_node, node); - GE_IF_BOOL_EXEC(ret != SUCCESS, GELOGE(FAILED, "copy tensor failed!"); return FAILED); - // only copy once - (void)ge::AttrUtils::SetBool(node->GetOpDesc(), "_copy_value", true); // no need to check value - } - } - } - return SUCCESS; + return TransVarDataUtils::CopyVarData(compute_graph, session_id_, device_id_); } Status DavinciModel::GetComputeGraphInfo(std::vector &compute_graph_desc_info) { diff --git a/src/ge/graph/load/new_model_manager/davinci_model.h b/src/ge/graph/load/new_model_manager/davinci_model.h index 25cb0a3a..cd532923 100644 --- a/src/ge/graph/load/new_model_manager/davinci_model.h +++ b/src/ge/graph/load/new_model_manager/davinci_model.h @@ -250,6 +250,11 @@ class DavinciModel { return res; } + rtStream_t GetRtModelStream() { + rtStream_t res = rt_model_stream_; + return res; + } + uint64_t GetRtBaseAddr() const { return runtime_param_.logic_mem_base; } uint64_t GetRtWeightAddr() const { return runtime_param_.logic_weight_base; } @@ -427,6 +432,26 @@ class DavinciModel { void CreateHcclFollowStream(rtStream_t stream, int64_t remain_cap); void ReuseHcclFollowStream(int64_t remain_cap, int64_t &index); + void InitRuntimeParams(); + Status InitVariableMem(); + + void UpdateMemBase(uint8_t *mem_base) { + runtime_param_.mem_base = mem_base; + mem_base_ = mem_base; + } + void SetTotalArgsSize(uint32_t args_size) { total_args_size_ += args_size; } + uint32_t GetTotalArgsSize() { return total_args_size_; } + void *GetCurrentArgsAddr(uint32_t offset) { + void *cur_args = static_cast(args_) + offset; + return cur_args; + } + void SetKnownNode(bool known_node) { known_node_ = known_node; } + bool IsKnownNode() { return known_node_; } + Status MallocKnownArgs(); + Status UpdateKnownNodeArgs(const vector &inputs, const vector &outputs); + Status CreateKnownZeroCopyMap(const vector &inputs, const vector &outputs); + Status UpdateKnownZeroCopyAddr(vector &io_addrs, uint32_t args_offset); + private: // memory address of weights uint8_t *weights_mem_base_; @@ -523,9 +548,9 @@ class DavinciModel { Status DistributeTask(); - uint8_t *MallocFeatureMapMem(uint64_t data_size); + uint8_t 
*MallocFeatureMapMem(size_t data_size); - uint8_t *MallocWeightsMem(uint32_t weights_size); + uint8_t *MallocWeightsMem(size_t weights_size); void FreeFeatureMapMem(); @@ -690,8 +715,6 @@ class DavinciModel { /// Status CpuModelRepeat(); - void InitRuntimeParams(); - /// /// @ingroup ge /// @brief set ts device. @@ -709,7 +732,6 @@ class DavinciModel { Status TransAllVarData(ComputeGraphPtr &graph, uint32_t graph_id); Status CopyVarData(ComputeGraphPtr &graph); - Status CopyTensorFromSrcVarNode(const NodePtr &var_src, const NodePtr &var_dst); // get desc info of graph for profiling Status GetComputeGraphInfo(vector &compute_graph_desc_info); @@ -827,6 +849,13 @@ class DavinciModel { DataDumper data_dumper_; uint64_t iterator_count_; bool is_l1_fusion_enable_; + + bool known_node_ = false; + uint32_t total_args_size_ = 0; + void *args_ = nullptr; + void *args_host_ = nullptr; + std::map knonw_input_data_info_; + std::map knonw_output_data_info_; }; } // namespace ge #endif // GE_GRAPH_LOAD_NEW_MODEL_MANAGER_DAVINCI_MODEL_H_ diff --git a/src/ge/graph/load/new_model_manager/model_manager.cc b/src/ge/graph/load/new_model_manager/model_manager.cc index b7bd3deb..8171498a 100644 --- a/src/ge/graph/load/new_model_manager/model_manager.cc +++ b/src/ge/graph/load/new_model_manager/model_manager.cc @@ -24,6 +24,7 @@ #include "framework/common/debug/ge_log.h" #include "graph/load/new_model_manager/davinci_model.h" #include "graph/load/new_model_manager/davinci_model_parser.h" +#include "model/ge_root_model.h" namespace ge { thread_local uint32_t device_count = 0; @@ -68,8 +69,6 @@ Status ModelManager::KernelLaunchEx(aicpu::FWKAdapter::FWKOperateType op_type, u GE_CHK_RT(rtFree(aicpu_kernel_addr)); return FAILED;) uint64_t kernel_id_addr = static_cast(reinterpret_cast(aicpu_kernel_addr)); param_base.fwkKernelBase.fwk_kernel.kernelID = kernel_id_addr; - // Remove model key from map - model_aicpu_kernel_.erase(iter); } } @@ -214,18 +213,38 @@ Status ModelManager::SetDevice(int32_t deviceId) const { return SUCCESS; } +ge::Status ModelManager::DoLoadHybridModelOnline(uint32_t model_id, const shared_ptr &ge_root_model, + const shared_ptr &listener) { + auto hybrid_model = hybrid::HybridDavinciModel::Create(ge_root_model); + GE_CHECK_NOTNULL(hybrid_model); + hybrid_model->SetListener(listener); + hybrid_model->SetModelId(model_id); + hybrid_model->SetDeviceId(GetContext().DeviceId()); + GE_CHK_STATUS_RET(hybrid_model->Init(), "Failed to init hybrid model. 
model_id = %u", model_id); + auto shared_model = std::shared_ptr(hybrid_model.release()); + InsertModel(model_id, shared_model); + return SUCCESS; +} + /// /// @ingroup domi_ome /// @brief load model online /// @return Status run result /// -Status ModelManager::LoadModelOnline(uint32_t &model_id, const shared_ptr &ge_model, +Status ModelManager::LoadModelOnline(uint32_t &model_id, const shared_ptr &ge_root_model, std::shared_ptr listener) { GE_CHK_BOOL_RET_STATUS(listener.get() != nullptr, PARAM_INVALID, "Param incorrect, listener is null"); if (model_id == INVALID_MODEL_ID) { GenModelId(&model_id); } + bool is_shape_unknown = false; + GE_CHK_STATUS_RET(ge_root_model->CheckIsUnknownShape(is_shape_unknown), "CheckIsUnknownShape failed, model id:%u", + model_id); + if (is_shape_unknown) { + return DoLoadHybridModelOnline(model_id, ge_root_model, listener); + } + GE_CHK_STATUS_RET(SetDevice(static_cast(GetContext().DeviceId())), "Set device failed, model id:%u.", model_id); mmTimespec timespec = mmGetTickCount(); @@ -238,6 +257,11 @@ Status ModelManager::LoadModelOnline(uint32_t &model_id, const shared_ptrSetId(model_id); davinci_model->SetDeviceId(GetContext().DeviceId()); + auto root_graph = ge_root_model->GetRootGraph(); + GE_CHECK_NOTNULL(root_graph); + string root_model_name = root_graph->GetName(); + auto name_to_model = ge_root_model->GetSubgraphInstanceNameToModel(); + GeModelPtr ge_model = name_to_model[root_model_name]; Status ret = SUCCESS; do { GE_TIMESTAMP_START(Assign); @@ -274,16 +298,26 @@ void ModelManager::InsertModel(uint32_t id, std::shared_ptr &davin model_map_[id] = davinci_model; } +void ModelManager::InsertModel(uint32_t id, shared_ptr &hybrid_model) { + GE_CHK_BOOL_EXEC(hybrid_model != nullptr, return, "hybrid_model ptr is null, id: %u", id); + std::lock_guard lock(map_mutex_); + hybrid_model_map_[id] = hybrid_model; +} + Status ModelManager::DeleteModel(uint32_t id) { std::lock_guard lock(map_mutex_); auto it = model_map_.find(id); - if (it == model_map_.end()) { + auto hybrid_model_it = hybrid_model_map_.find(id); + if (it != model_map_.end()) { + (void)model_map_.erase(it); + } else if (hybrid_model_it != hybrid_model_map_.end()) { + (void)hybrid_model_map_.erase(hybrid_model_it); + } else { GELOGE(PARAM_INVALID, "model id %u does not exist.", id); return PARAM_INVALID; } - (void)model_map_.erase(it); return SUCCESS; } @@ -294,6 +328,13 @@ std::shared_ptr ModelManager::GetModel(uint32_t id) { return (it == model_map_.end()) ? nullptr : it->second; } +std::shared_ptr ModelManager::GetHybridModel(uint32_t id) { + std::lock_guard lock(map_mutex_); + + auto it = hybrid_model_map_.find(id); + return (it == hybrid_model_map_.end()) ?
nullptr : it->second; +} + Status ModelManager::Unload(uint32_t model_id) { GE_CHK_STATUS_RET(DeleteModel(model_id), "failed to unload model id: %u", model_id); if (device_count > 0) { @@ -349,7 +390,10 @@ Status ModelManager::DataInput(const InputData &input_data, OutputData &output_d /// Status ModelManager::DataInputTensor(uint32_t model_id, const std::vector &inputs) { std::shared_ptr model = GetModel(model_id); - GE_CHECK_NOTNULL(model); + auto hybrid_model = GetHybridModel(model_id); + if (hybrid_model == nullptr) { + GE_CHECK_NOTNULL(model); + } InputData input_data; input_data.model_id = model_id; @@ -374,6 +418,12 @@ Status ModelManager::DataInputTensor(uint32_t model_id, const std::vectorInit(input_data, output_data), return domi::PUSH_DATA_FAILED, "Init InputDataWrapper failed,input data model_id is : %u.", model_id); + if (hybrid_model != nullptr) { + GE_CHK_STATUS_RET(hybrid_model->EnqueueData(data_wrap), "Data queue is full, please call again later, model_id %u ", + model_id); + return SUCCESS; + } + GE_CHK_BOOL_RET_STATUS(model != nullptr, PARAM_INVALID, "Invalid Model ID %u in InputData! ", model_id); DataInputer *inputer = model->GetDataInputer(); @@ -395,6 +445,13 @@ Status ModelManager::DataInputTensor(uint32_t model_id, const std::vectorModelRunStart()); + GELOGI("Start hybrid model %u success.", model_id); + return SUCCESS; + } + std::shared_ptr davinci_model = GetModel(model_id); GE_CHK_BOOL_RET_STATUS(davinci_model != nullptr, PARAM_INVALID, "Invalid Model ID %u to start! ", model_id); @@ -416,6 +473,13 @@ Status ModelManager::Start(uint32_t model_id) { /// @author /// Status ModelManager::Stop(uint32_t model_id) { + auto hybrid_model = GetHybridModel(model_id); + if (hybrid_model != nullptr) { + GE_CHK_STATUS_RET_NOLOG(hybrid_model->ModelRunStop()); + GELOGI("Stop hybrid model %u success.", model_id); + return SUCCESS; + } + std::shared_ptr davinci_model = GetModel(model_id); GE_CHK_BOOL_RET_STATUS(davinci_model != nullptr, PARAM_INVALID, "Invalid Model ID %u to stop!", model_id); @@ -581,6 +645,13 @@ Status ModelManager::HandleDumpCommand(const Command &command) { } Status ModelManager::GetMaxUsedMemory(const uint32_t model_id, uint64_t &max_size) { + auto hybrid_model = GetHybridModel(model_id); + if (hybrid_model != nullptr) { + // TODO: hybrid model uses dynamic memory allocation; report 0 until usage tracking is added + max_size = 0; + return SUCCESS; + } + std::shared_ptr davinci_model = GetModel(model_id); GE_CHK_BOOL_RET_STATUS(davinci_model != nullptr, PARAM_INVALID, "GetMaxUsedMemory Failed, Invalid Model ID %u !", model_id);
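[Editor's note] Every ModelManager entry point touched by this patch repeats one dispatch idiom: probe the hybrid map first, then fall back to the DavinciModel map, so unknown-shape models are routed to the hybrid engine transparently. A minimal sketch of that idiom (the wrapper function is hypothetical; DavinciModel::ModelRunStart() is assumed to exist outside this diff):

    // Sketch only: unknown-shape models take the hybrid path, all others the davinci path.
    Status StartAnyModel(ModelManager &mm, uint32_t model_id) {
      auto hybrid_model = mm.GetHybridModel(model_id);
      if (hybrid_model != nullptr) {
        return hybrid_model->ModelRunStart();  // hybrid engine, dynamic shapes
      }
      auto davinci_model = mm.GetModel(model_id);
      GE_CHK_BOOL_RET_STATUS(davinci_model != nullptr, PARAM_INVALID, "Invalid Model ID %u to start! ", model_id);
      return davinci_model->ModelRunStart();   // classic static-shape path
    }

diff --git a/src/ge/graph/load/new_model_manager/model_manager.h b/src/ge/graph/load/new_model_manager/model_manager.h index ae73c1ce..b79f388a 100644 --- a/src/ge/graph/load/new_model_manager/model_manager.h +++ b/src/ge/graph/load/new_model_manager/model_manager.h @@ -25,6 +25,7 @@ #include #include #include +#include #include "cce/aicpu_engine_struct.h" #include "common/ge_inner_error_codes.h" #include "common/ge_types.h" @@ -34,10 +35,10 @@ #include "ge/ge_api_types.h" #include "graph/ge_context.h" #include "graph/model.h" +#include "hybrid/hybrid_davinci_model.h" #include "runtime/base.h" namespace ge { - class DavinciModel; class FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY ModelManager { @@ -69,9 +70,12 @@ class FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY ModelManager { /// @return Status run result /// @author @ /// - ge::Status LoadModelOnline(uint32_t &model_id, const std::shared_ptr &model, + ge::Status LoadModelOnline(uint32_t &model_id, const std::shared_ptr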
&ge_root_model, std::shared_ptr listener); + ge::Status DoLoadHybridModelOnline(uint32_t model_id, const shared_ptr &ge_root_model, + const std::shared_ptr &listener); + /// /// @ingroup ge /// @brief ACL case, Load task list with queue. @@ -206,6 +210,8 @@ class FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY ModelManager { /// std::shared_ptr GetModel(uint32_t id); + std::shared_ptr GetHybridModel(uint32_t id); + ge::Status KernelLaunchEx(aicpu::FWKAdapter::FWKOperateType op_type, uint64_t session_id, uint32_t model_id); ge::Status CreateAicpuSession(uint64_t session_id); @@ -238,6 +244,7 @@ class FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY ModelManager { /// @brief insert new model into model manager set /// void InsertModel(uint32_t id, std::shared_ptr &davinci_model); + void InsertModel(uint32_t id, std::shared_ptr &hybrid_model); /// /// @ingroup domi_ome @@ -248,6 +255,7 @@ class FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY ModelManager { void GenModelId(uint32_t *id); std::map> model_map_; + std::map> hybrid_model_map_; std::map> model_aicpu_kernel_; uint32_t max_model_id_; std::mutex map_mutex_; diff --git a/src/ge/graph/load/new_model_manager/model_utils.cc b/src/ge/graph/load/new_model_manager/model_utils.cc index c372f528..a807f2a3 100644 --- a/src/ge/graph/load/new_model_manager/model_utils.cc +++ b/src/ge/graph/load/new_model_manager/model_utils.cc @@ -474,7 +474,7 @@ vector ModelUtils::GetWorkspaceDataAddrs(const RuntimeParam &model_param int64_t workspace_bytes = v_workspace_bytes[i]; uint8_t *mem_addr = workspace_bytes == 0 ? nullptr : mem_base + workspace_offset; v_workspace_data_addr.push_back(mem_addr); - GELOGI("[IMAS]GetWorkspaceDataAddrs graph_%u type[F] name[%s] output[%zu] offset[%ld] bytes[%ld] memaddr[%p]", + GELOGI("[IMAS]GetWorkspaceDataAddrs graph_%u type[F] name[%s] workspace[%zu] offset[%ld] bytes[%ld] memaddr[%p]", model_param.graph_id, op_desc->GetName().c_str(), i, workspace_offset, workspace_bytes, mem_addr); } } diff --git a/src/ge/graph/load/new_model_manager/task_info/hccl_task_info.cc b/src/ge/graph/load/new_model_manager/task_info/hccl_task_info.cc index 77825991..8529b90f 100644 --- a/src/ge/graph/load/new_model_manager/task_info/hccl_task_info.cc +++ b/src/ge/graph/load/new_model_manager/task_info/hccl_task_info.cc @@ -37,17 +37,12 @@ HcclTaskInfo::~HcclTaskInfo() { if (ret != RT_ERROR_NONE) { GELOGE(RT_FAILED, "Call rtFree Fail, ret = 0x%X.", ret); } - private_def_ = nullptr; } - input_data_addr_ = nullptr; davinci_model_ = nullptr; ops_kernel_store_ = nullptr; - output_data_addr_ = nullptr; - workspace_addr_ = nullptr; max_node_of_hccl_stream_ = 0; } - Status HcclTaskInfo::Init(const domi::TaskDef &task_def, DavinciModel *davinci_model) { GELOGI("HcclTaskInfo Init Start."); if (davinci_model == nullptr) { @@ -55,63 +50,75 @@ Status HcclTaskInfo::Init(const domi::TaskDef &task_def, DavinciModel *davinci_m return PARAM_INVALID; } davinci_model_ = davinci_model; - Status ret = SetStream(task_def.stream_id(), davinci_model->GetStreamList()); if (ret != SUCCESS) { return ret; } - GetPrivateDefByTaskDef(task_def); - auto hccl_def = task_def.kernel_hccl(); - hcclDataType_t data_type; - int32_t count; uint32_t op_index = hccl_def.op_index(); GELOGI("HcclTaskInfo Init, op_index is: %u", op_index); - std::string hccl_type = hccl_def.hccl_type(); // Get HCCL op OpDescPtr op_desc = davinci_model->GetOpByIndex(op_index); GE_CHECK_NOTNULL(op_desc); - Status dmrt = HcomOmeUtil::GetHcomDataType(op_desc, data_type); + // Create the kernel hccl infos + 
CreateKernelHcclInfo(op_desc); + + // Initialize the hccl_type of all kernel hccl info + HcomOmeUtil::GetHcclType(task_def, kernel_hccl_infos_); + + // Only in Horovod scenario should get the inputName and GeShape + ret = HcomOmeUtil::GetHorovodInputs(op_desc, kernel_hccl_infos_); + if (ret != SUCCESS) { + GELOGE(FAILED, "davinci_model: GetHorovodInputs fail! domi error: %u", ret); + return FAILED; + } + Status dmrt = HcomOmeUtil::GetHcclDataType(op_desc, kernel_hccl_infos_); if (dmrt != SUCCESS) { GELOGE(FAILED, "davinci_model: GetHcomDataType fail! domi error: %u", dmrt); return FAILED; } - - dmrt = HcomOmeUtil::GetHcomCount(op_desc, data_type, (hccl_type == HCOMALLGATHER), count); + dmrt = HcomOmeUtil::GetHcclCount(op_desc, kernel_hccl_infos_); if (dmrt != SUCCESS) { GELOGE(FAILED, "davinci_model: GetHcomCount fail! domi error: %u", dmrt); return FAILED; } - - ret = SetAddrs(hccl_type, op_desc); + // Only HCOMBROADCAST and HVDCALLBACKBROADCAST need to get the rootId + dmrt = HcomOmeUtil::GetAllRootId(op_desc, kernel_hccl_infos_); + if (dmrt != SUCCESS) { + GELOGE(FAILED, "davinci_model: Get rootId fail! domi error: %u", dmrt); + return FAILED; + } + ret = SetAddrs(op_desc, kernel_hccl_infos_); if (ret != SUCCESS) { GELOGE(ret, "SetAddrs Fail."); return ret; } - - count_ = count; - hccl_type_ = hccl_type; - data_type_ = data_type; - // GE's new process: hccl declares the need for Workspace size, and GE allocates Workspace - auto workspace_bytes = op_desc->GetWorkspaceBytes(); - if (!workspace_bytes.empty()) { - uint64_t workspace_mem_size_tmp = workspace_bytes[0]; - GELOGI("hccl need workSpaceMemSize=%lu", workspace_mem_size_tmp); - if (workspace_mem_size_tmp != 0) { - workspace_mem_size_ = workspace_mem_size_tmp; - vector workspace_data_addrs = - ModelUtils::GetWorkspaceDataAddrs(davinci_model->GetRuntimeParam(), op_desc); - if (!workspace_data_addrs.empty()) { - GELOGI("Get workSpaceAddr"); - workspace_addr_ = workspace_data_addrs[0]; - } - } + ret = SetWorkspace(op_desc, kernel_hccl_infos_); + if (ret != SUCCESS) { + GELOGE(ret, "SetWorkspace Fail."); + return ret; } // GE's new process: hccl declares the number of streams required, creates a stream by GE, and sends it to hccl + ret = SetFollowStream(op_desc, davinci_model); + if (ret != SUCCESS) { + GELOGE(ret, "SetFollowStream Fail."); + return ret; + } + + GELOGI("HcclTaskInfo Init Success"); + return SUCCESS; +} + +Status HcclTaskInfo::SetFollowStream(const ge::ConstOpDescPtr &op_desc, DavinciModel *davinci_model) { + if (!HcomOmeUtil::IsHCOMOp(op_desc->GetType())) { + GELOGI("Node %s OpType %s no need to create slave streams.", op_desc->GetName().c_str(), op_desc->GetType().c_str()); + return SUCCESS; + } + Status ret; int64_t hccl_stream_num = 0; if (!ge::AttrUtils::GetInt(op_desc, "used_stream_num", hccl_stream_num)) { GELOGI("op_desc has no attr used_stream_num!"); @@ -142,8 +149,7 @@ Status HcclTaskInfo::Init(const domi::TaskDef &task_def, DavinciModel *davinci_m return FAILED; } } - - GELOGI("HcclTaskInfo Init Success, hcclStreamNum =%ld", hccl_stream_num); + GELOGI("Initialize hccl slave stream success, hcclStreamNum =%ld", hccl_stream_num); return SUCCESS; } @@ -167,14 +173,12 @@ Status HcclTaskInfo::CreateStream(int64_t stream_num, DavinciModel *davinci_mode GELOGE(RT_FAILED, "Call rt api failed, ret: 0x%X", rt_ret); return RT_FAILED; } - // Create slave stream, inactive by default, activated by hccl rt_ret = rtModelBindStream(davinci_model->GetRtModelHandle(), stream, RT_MODEL_WAIT_ACTIVE_STREAM); if (rt_ret != RT_ERROR_NONE) {
GELOGE(RT_FAILED, "Call rt api failed, ret: 0x%X", rt_ret); return RT_FAILED; } - GELOGD("hccl_stream addr is=%p", stream); int64_t remain_cap = max_node_of_hccl_stream_ - 1; davinci_model->CreateHcclFollowStream(stream, remain_cap); @@ -192,7 +196,6 @@ Status HcclTaskInfo::Distribute() { GELOGE(INTERNAL_ERROR, "ops kernel store is null."); return INTERNAL_ERROR; } - OpsKernelInfoStore *ops_kernel_info_store = reinterpret_cast(ops_kernel_store_); GE_CHECK_NOTNULL(ops_kernel_info_store); GETaskInfo ge_task; @@ -202,81 +205,62 @@ Status HcclTaskInfo::Distribute() { GELOGE(INTERNAL_ERROR, "davinci_model : load task fail, return ret: %u", result); return INTERNAL_ERROR; } - GELOGI("HcclTaskInfo Distribute Success."); return SUCCESS; } - -Status HcclTaskInfo::SetAddrs(const std::string &hccl_type, const std::shared_ptr &op_desc) { +Status HcclTaskInfo::SetAddrs(const std::shared_ptr &op_desc, + std::vector &kernel_hccl_infos) { + GE_CHECK_NOTNULL(op_desc); + if (HcomOmeUtil::CheckKernelHcclInfo(op_desc, kernel_hccl_infos) != SUCCESS) { + GELOGE(PARAM_INVALID, "HcomOmeUtil:: the number of GETaskKernelHcclInfo is invalid."); + return PARAM_INVALID; + } + GELOGI("Set hccl task input output address, node[%s], type[%s] kernel_hccl_infos.size[%zu].", + op_desc->GetName().c_str(), op_desc->GetType().c_str(), kernel_hccl_infos.size()); + if (op_desc->GetType() == HVDWAIT) { + return SUCCESS; + } domi::Status dmrt; - hcclRedOp_t op_type; + hcclRedOp_t op_type = HCCL_REP_OP_SUM; GE_CHECK_NOTNULL(davinci_model_); + GELOGI("Calc opType[%s] input address before. Node name[%s]", op_desc->GetType().c_str(), op_desc->GetName().c_str()); auto input_data_addr_list = ModelUtils::GetInputDataAddrs(davinci_model_->GetRuntimeParam(), op_desc); - if (!input_data_addr_list.empty()) { - input_data_addr_ = input_data_addr_list[0]; - } - void *output_data_addr = nullptr; auto output_data_addr_list = ModelUtils::GetOutputDataAddrs(davinci_model_->GetRuntimeParam(), op_desc); - if (!output_data_addr_list.empty()) { - output_data_addr = output_data_addr_list[0]; - } - - if (hccl_type == HCOMBROADCAST) { - int64_t root_id; - dmrt = HcomOmeUtil::GetHcomRootId(op_desc, root_id); - if (dmrt != SUCCESS) { - GELOGE(FAILED, "davinci_model: GetHcomRootId fail! domi error: %u", dmrt); - return FAILED; - } - root_id_ = root_id; - } else if (hccl_type == HCOMALLGATHER || hccl_type == HCOMRECEIVE) { - output_data_addr_ = output_data_addr; - } else if (hccl_type == HCOMALLREDUCE) { - dmrt = HcomOmeUtil::GetHcomOperationType(op_desc, op_type); - if (dmrt != SUCCESS) { - GELOGE(FAILED, "davinci_model: GetHcomOperationType fail! domi error: %u", dmrt); - return FAILED; - } - - output_data_addr_ = output_data_addr; - op_type_ = op_type; - } else if (hccl_type == HCOMREDUCESCATTER) { - dmrt = HcomOmeUtil::GetHcomOperationType(op_desc, op_type); - if (dmrt != SUCCESS) { - GELOGE(FAILED, "davinci_model: GetHcomOperationType fail! domi error: %u", dmrt); - return FAILED + // initialize every kernel_hccl_info inputDataAddr + for (size_t i = 0; i < kernel_hccl_infos.size(); i++) { + std::string hccl_type = kernel_hccl_infos[i].hccl_type; + void *input_data_addr = input_data_addr_list.empty() ? nullptr : input_data_addr_list[i]; + kernel_hccl_infos[i].inputDataAddr = input_data_addr; + + void *output_data_addr = output_data_addr_list.empty() ?
nullptr : output_data_addr_list[i]; + if (hccl_type == HCOMALLGATHER || hccl_type == HCOMRECEIVE || hccl_type == HVDCALLBACKALLGATHER) { + kernel_hccl_infos[i].outputDataAddr = output_data_addr; + } else if (hccl_type == HCOMALLREDUCE || hccl_type == HCOMREDUCESCATTER || hccl_type == HVDCALLBACKALLREDUCE) { + dmrt = HcomOmeUtil::GetHcclOperationType(op_desc, op_type); + if (dmrt != SUCCESS) { + GELOGE(FAILED, "davinci_model: GetHcclOperationType fail! domi error: %u", dmrt); + return FAILED; + } + kernel_hccl_infos[i].outputDataAddr = output_data_addr; + kernel_hccl_infos[i].opType = op_type; } - - output_data_addr_ = output_data_addr; - op_type_ = op_type; + davinci_model_->DisableZeroCopy(input_data_addr); } - - davinci_model_->DisableZeroCopy(input_data_addr_); return SUCCESS; } - void HcclTaskInfo::TransToGETaskInfo(GETaskInfo &ge_task) { ge_task.id = id_; ge_task.type = static_cast(RT_MODEL_TASK_HCCL); ge_task.stream = stream_; - - ge_task.kernelHcclInfo.hccl_type = hccl_type_; - ge_task.kernelHcclInfo.inputDataAddr = input_data_addr_; - ge_task.kernelHcclInfo.outputDataAddr = output_data_addr_; - ge_task.kernelHcclInfo.workSpaceAddr = workspace_addr_; - ge_task.kernelHcclInfo.count = count_; - ge_task.kernelHcclInfo.dataType = static_cast(data_type_); - ge_task.kernelHcclInfo.opType = static_cast(op_type_); - ge_task.kernelHcclInfo.rootId = root_id_; - ge_task.kernelHcclInfo.workSpaceMemSize = workspace_mem_size_; - ge_task.kernelHcclInfo.hcclStreamList = hccl_stream_list_; - + ge_task.kernelHcclInfo = kernel_hccl_infos_; ge_task.privateDef = private_def_; ge_task.privateDefLen = private_def_len_; ge_task.opsKernelStorePtr = ops_kernel_store_; + for (size_t i = 0; i < ge_task.kernelHcclInfo.size(); i++) { + ge_task.kernelHcclInfo[i].hcclStreamList = hccl_stream_list_; + } } - void HcclTaskInfo::GetPrivateDefByTaskDef(const domi::TaskDef &task) { // Get privateDef and opsKernelStorePtr from taskDef and save them in taskInfo GELOGI("get custom info in modelTaskDef."); @@ -299,11 +283,54 @@ void HcclTaskInfo::GetPrivateDefByTaskDef(const domi::TaskDef &task) { GELOGE(RT_FAILED, "Call rtMemcpy Fail, ret = 0x%X.", ret); return; } - GELOGI("The first address of the custom info, privateDef=%p.", private_def_); } } } - +void HcclTaskInfo::CreateKernelHcclInfo(const ge::ConstOpDescPtr &op_desc) { + GE_CHECK_NOTNULL_JUST_RETURN(op_desc); + if (HcomOmeUtil::IsHCOMOp(op_desc->GetType())) { + GETaskKernelHcclInfo kernel_hccl_info; + kernel_hccl_infos_.emplace_back(kernel_hccl_info); + } else if (HcomOmeUtil::IsHorovodOp(op_desc->GetType())) { + // Horovod wait does not have any input, but create one GETaskKernelHcclInfo to record hccl_type.
+ // Other operators need to check that the number of GETaskKernelHcclInfo equals the number of inputs + if (op_desc->GetType() == HVDWAIT) { + GETaskKernelHcclInfo kernel_hccl_info; + kernel_hccl_infos_.emplace_back(kernel_hccl_info); + return; + } + for (size_t i = 0; i < op_desc->GetInputsSize(); i++) { + GETaskKernelHcclInfo kernel_hccl_info; + kernel_hccl_infos_.emplace_back(kernel_hccl_info); + } + } +} +Status HcclTaskInfo::SetWorkspace(const std::shared_ptr &op_desc, + std::vector &kernel_hccl_infos) { + GE_CHECK_NOTNULL(op_desc); + GELOGI("SetWorkspace Node[%s] opType[%s] set workspace.", op_desc->GetName().c_str(), op_desc->GetType().c_str()); + uint64_t workspace_mem_size = 0; + void *workspace_addr = nullptr; + auto workspace_bytes = op_desc->GetWorkspaceBytes(); + if (!workspace_bytes.empty()) { + uint64_t workspace_mem_size_tmp = workspace_bytes[0]; + GELOGI("hccl need workSpaceMemSize=%lu", workspace_mem_size_tmp); + if (workspace_mem_size_tmp != 0) { + workspace_mem_size = workspace_mem_size_tmp; + vector workspace_data_addrs = + ModelUtils::GetWorkspaceDataAddrs(davinci_model_->GetRuntimeParam(), op_desc); + if (!workspace_data_addrs.empty()) { + GELOGI("Get workSpaceAddr"); + workspace_addr = workspace_data_addrs[0]; + } + } + } + for (size_t i = 0; i < kernel_hccl_infos.size(); i++) { + kernel_hccl_infos[i].workSpaceMemSize = workspace_mem_size; + kernel_hccl_infos[i].workSpaceAddr = workspace_addr; + } + return SUCCESS; +} REGISTER_TASK_INFO(RT_MODEL_TASK_HCCL, HcclTaskInfo); } // namespace ge
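[Editor's note] The refactor above replaces the task's single set of HCCL fields with the kernel_hccl_infos_ vector: an HCOM op carries exactly one GETaskKernelHcclInfo, a Horovod op one per input, and HVDWAIT a single placeholder that only records the hccl_type. A condensed restatement of CreateKernelHcclInfo's population rule (sketch only; the helper name is the editor's):

    // How many GETaskKernelHcclInfo entries a task receives, by op type.
    size_t NumKernelHcclInfos(const ge::ConstOpDescPtr &op_desc) {
      const std::string type = op_desc->GetType();
      if (HcomOmeUtil::IsHCOMOp(type)) {
        return 1;                                                 // single-kernel HCOM task
      }
      if (HcomOmeUtil::IsHorovodOp(type)) {
        return (type == HVDWAIT) ? 1 : op_desc->GetInputsSize();  // one info per input
      }
      return 0;                                                   // not an HCCL task
    }

SetAddrs and SetWorkspace then fan the per-op addresses out across every entry, which is why CheckKernelHcclInfo validates the vector length first.

diff --git a/src/ge/graph/load/new_model_manager/task_info/hccl_task_info.h b/src/ge/graph/load/new_model_manager/task_info/hccl_task_info.h index be033fac..bb0a88de 100644 --- a/src/ge/graph/load/new_model_manager/task_info/hccl_task_info.h +++ b/src/ge/graph/load/new_model_manager/task_info/hccl_task_info.h @@ -18,9 +18,9 @@ #define GE_GRAPH_LOAD_NEW_MODEL_MANAGER_TASK_INFO_HCCL_TASK_INFO_H_ #include +#include #include #include -#include #include "common/opskernel/ge_task_info.h" #include "graph/load/new_model_manager/task_info/task_info.h" @@ -30,16 +30,7 @@ class HcclTaskInfo : public TaskInfo { public: HcclTaskInfo() : davinci_model_(nullptr), - hccl_type_(""), - input_data_addr_(nullptr), - output_data_addr_(nullptr), - count_(0), - data_type_(HCCL_DATA_TYPE_INT8), - op_type_(HCCL_REP_OP_SUM), - root_id_(0), id_(0), - workspace_addr_(nullptr), - workspace_mem_size_(0), hccl_stream_list_(), ops_kernel_store_(nullptr), private_def_(nullptr), @@ -56,6 +47,8 @@ class HcclTaskInfo : public TaskInfo { private: ge::Status SetAddrs(const std::string &hccl_type, const std::shared_ptr &op); + Status SetAddrs(const std::shared_ptr &op_desc, std::vector &kernel_hccl_infos); + void TransToGETaskInfo(GETaskInfo &ge_task); void GetPrivateDefByTaskDef(const domi::TaskDef &task); @@ -64,23 +57,21 @@ class HcclTaskInfo : public TaskInfo { ge::Status CreateStream(int64_t stream_num, DavinciModel *davinci_model); + Status SetFollowStream(const ge::ConstOpDescPtr &op_desc, DavinciModel *davinci_model); + + void CreateKernelHcclInfo(const ge::ConstOpDescPtr &op_desc); + + Status SetWorkspace(const std::shared_ptr &op_desc, std::vector &kernel_hccl_infos); + DavinciModel *davinci_model_; - string hccl_type_; - void *input_data_addr_; - void *output_data_addr_; - int32_t count_; - hcclDataType_t data_type_; - hcclRedOp_t op_type_; - int64_t root_id_; uint32_t id_; - void *workspace_addr_; - uint64_t workspace_mem_size_; vector hccl_stream_list_; void *ops_kernel_store_; void *private_def_;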
uint32_t private_def_len_; static std::mutex hccl_follow_stream_mutex_; static uint32_t max_node_of_hccl_stream_; + vector kernel_hccl_infos_; }; } // namespace ge #endif // GE_GRAPH_LOAD_NEW_MODEL_MANAGER_TASK_INFO_HCCL_TASK_INFO_H_ diff --git a/src/ge/graph/load/new_model_manager/task_info/kernel_ex_task_info.cc b/src/ge/graph/load/new_model_manager/task_info/kernel_ex_task_info.cc index 085b3ab4..635fec5d 100644 --- a/src/ge/graph/load/new_model_manager/task_info/kernel_ex_task_info.cc +++ b/src/ge/graph/load/new_model_manager/task_info/kernel_ex_task_info.cc @@ -51,17 +51,7 @@ Status KernelExTaskInfo::Init(const domi::TaskDef &task_def, DavinciModel *davin GELOGE(INTERNAL_ERROR, "Init aicpu task info error, index is out of range!"); return INTERNAL_ERROR; } - - if (CopyTaskInfo(kernel_ex_def, rts_param, op_desc) != SUCCESS) { - GELOGE(FAILED, "copy task info to workspace failed."); - return FAILED; - } - - const vector workspace_data_addrs = ModelUtils::GetWorkspaceDataAddrs(rts_param, op_desc); - if (workspace_data_addrs.empty()) { - GELOGE(FAILED, "workspace_data_addrs is empty."); - return FAILED; - } + op_desc_ = op_desc; // 2. Reconstruct kernelExDef.args to STR_FWK_OP_KERNEL STR_FWK_OP_KERNEL fwk_op_kernel = {0}; @@ -87,7 +77,52 @@ Status KernelExTaskInfo::Init(const domi::TaskDef &task_def, DavinciModel *davin } } + auto session_id = fwk_op_kernel.fwkKernelBase.fwk_kernel.sessionID; + // 2.2 Collect aicpu kernel + uint64_t kernel_id = fwk_op_kernel.fwkKernelBase.fwk_kernel.kernelID; + GE_IF_BOOL_EXEC(ModelManager::GetInstance()->CreateAicpuKernel(session_id, davinci_model->Id(), kernel_id) != SUCCESS, + GELOGE(FAILED, "CreateAicpuKernel error."); + return FAILED;) + + kernel_buf_size_ = sizeof(STR_FWK_OP_KERNEL); + if (davinci_model_->IsKnownNode()) { + void *input_output_addr = davinci_model_->GetCurrentArgsAddr(args_offset_); + fwk_op_kernel.fwkKernelBase.fwk_kernel.inputOutputAddr = + static_cast(reinterpret_cast(input_output_addr)); + void *workspace_base_addr = nullptr; + rtError_t rt_ret = rtMalloc(&workspace_base_addr, kernel_ex_def.task_info_size(), RT_MEMORY_HBM); + GE_IF_BOOL_EXEC(rt_ret != RT_ERROR_NONE, GELOGE(rt_ret, "rtMalloc error, ret: 0x%X", rt_ret); return FAILED;); + rt_ret = rtMemcpy(workspace_base_addr, kernel_ex_def.task_info_size(), kernel_ex_def.task_info().data(), + kernel_ex_def.task_info_size(), RT_MEMCPY_HOST_TO_DEVICE); + fwk_op_kernel.fwkKernelBase.fwk_kernel.workspaceBaseAddr = + static_cast(reinterpret_cast(workspace_base_addr)); + fwk_op_kernel.fwkKernelBase.fwk_kernel.stepIDAddr = step_id_addr; + fwk_op_kernel.fwkKernelBase.fwk_kernel.extInfoNum = 0; + fwk_op_kernel.fwkKernelBase.fwk_kernel.extInfoAddr = 0; + + rt_ret = rtMalloc(&kernel_buf_, kernel_buf_size_, RT_MEMORY_HBM); + GE_IF_BOOL_EXEC(rt_ret != RT_ERROR_NONE, GELOGE(rt_ret, "rtMalloc error: 0x%X", rt_ret); return FAILED;) + + rt_ret = rtMemcpy(kernel_buf_, kernel_buf_size_, static_cast(&fwk_op_kernel), kernel_buf_size_, + RT_MEMCPY_HOST_TO_DEVICE); + GE_IF_BOOL_EXEC(rt_ret != RT_ERROR_NONE, GELOGE(rt_ret, "rtMemcpy error, ret: 0x%X", rt_ret); return FAILED;) + + GELOGI("KernelExTaskInfo known node Init Success."); + return SUCCESS; + } + // 3.
Set workspaceaddr, inputOutputDataAddr + if (CopyTaskInfo(kernel_ex_def, rts_param, op_desc) != SUCCESS) { + GELOGE(FAILED, "copy task info to workspace failed."); + return FAILED; + } + + const vector workspace_data_addrs = ModelUtils::GetWorkspaceDataAddrs(rts_param, op_desc); + if (workspace_data_addrs.empty()) { + GELOGE(FAILED, "workspace_data_addrs is empty."); + return FAILED; + } + uint64_t workspace_base_addr = reinterpret_cast(reinterpret_cast(workspace_data_addrs[0])); const vector input_addrs = ModelUtils::GetInputDataAddrs(rts_param, op_desc); const vector output_addrs = ModelUtils::GetOutputDataAddrs(rts_param, op_desc); @@ -106,8 +141,7 @@ Status KernelExTaskInfo::Init(const domi::TaskDef &task_def, DavinciModel *davin if (PropertiesManager::Instance().IsLayerNeedDump(davinci_model_->Name(), op_desc->GetName())) { dump_flag_ = RT_KERNEL_DUMPFLAG; - dump_args_ = - reinterpret_cast(reinterpret_cast(input_output_addr_) + sizeof(void *) * input_addrs.size()); + dump_args_ = input_output_addr_; } } @@ -119,16 +153,10 @@ Status KernelExTaskInfo::Init(const domi::TaskDef &task_def, DavinciModel *davin fwk_op_kernel.fwkKernelBase.fwk_kernel.extInfoAddr = 0; // 4. Create session - auto session_id = fwk_op_kernel.fwkKernelBase.fwk_kernel.sessionID; GE_CHECK_NOTNULL(ModelManager::GetInstance()); GE_IF_BOOL_EXEC(ModelManager::GetInstance()->CreateAicpuSession(session_id) != SUCCESS, GELOGE(FAILED, "CreateAicpuSession error. session id: %lu", session_id); return FAILED;) - // 4.1 Collect aicpu kernel - uint64_t kernel_id = fwk_op_kernel.fwkKernelBase.fwk_kernel.kernelID; - GE_IF_BOOL_EXEC(ModelManager::GetInstance()->CreateAicpuKernel(session_id, davinci_model->Id(), kernel_id) != SUCCESS, - GELOGE(FAILED, "CreateAicpuKernel error."); - return FAILED;) // 5. Return result rtError_t rt_ret = rtMalloc(&kernel_buf_, sizeof(STR_FWK_OP_KERNEL), RT_MEMORY_HBM); GE_IF_BOOL_EXEC(rt_ret != RT_ERROR_NONE, GELOGE(rt_ret, "rtMalloc error: 0x%X", rt_ret); return FAILED;) @@ -144,12 +172,46 @@ Status KernelExTaskInfo::Init(const domi::TaskDef &task_def, DavinciModel *davin virtual_io_addrs.insert(virtual_io_addrs.end(), virtual_out_addrs.begin(), virtual_out_addrs.end()); davinci_model_->SetZeroCopyAddr(op_desc, virtual_io_addrs, io_addrs.data(), input_output_addr_, addrs_size, 0); - kernel_buf_size_ = sizeof(STR_FWK_OP_KERNEL); - GELOGI("KernelExTaskInfo Init Success. 
session id: %lu", session_id); return SUCCESS; } +Status KernelExTaskInfo::CalculateArgs(const domi::TaskDef &task_def, DavinciModel *davinci_model) { + auto kernel_ex_def = task_def.kernel_ex(); + uint32_t op_index = kernel_ex_def.op_index(); + OpDescPtr op_desc = davinci_model->GetOpByIndex(op_index); + if (op_desc == nullptr) { + GELOGE(INTERNAL_ERROR, "Init aicpu task info error, index is out of range!"); + return INTERNAL_ERROR; + } + args_offset_ = davinci_model->GetTotalArgsSize(); + const size_t inputs_size = op_desc->GetInputsSize(); + const size_t outputs_size = op_desc->GetOutputsSize(); + // aicpu kernel input/output size + size_t mem_length = inputs_size + outputs_size; + uint32_t mem_size = sizeof(uint64_t) * mem_length; + davinci_model->SetTotalArgsSize(mem_size); + GELOGI("kernel task name %s, args_size %u, args_offset %u", op_desc->GetName().c_str(), mem_size, args_offset_); + return SUCCESS; +} + +Status KernelExTaskInfo::UpdateArgs() { + GELOGI("KernelExTaskInfo::UpdateArgs in."); + const RuntimeParam &rts_param = davinci_model_->GetRuntimeParam(); + vector io_addrs; + vector input_data_addrs = ModelUtils::GetInputDataAddrs(rts_param, op_desc_); + vector output_data_addrs = ModelUtils::GetOutputDataAddrs(rts_param, op_desc_); + + io_addrs.insert(io_addrs.end(), input_data_addrs.begin(), input_data_addrs.end()); + io_addrs.insert(io_addrs.end(), output_data_addrs.begin(), output_data_addrs.end()); + + GE_CHK_STATUS_RET(davinci_model_->UpdateKnownZeroCopyAddr(io_addrs, args_offset_), + "update known node %s zero copy addr failed.", op_desc_->GetName().c_str()); + + GELOGI("KernelExTaskInfo::UpdateArgs success."); + return SUCCESS; +} + Status KernelExTaskInfo::CopyTaskInfo(const domi::KernelExDef &kernel_def, const RuntimeParam &rts_param, const OpDescPtr &op_desc) { // Userspace copy need virtual address. 
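[Editor's note] CalculateArgs and UpdateArgs are the two halves of the new "known node" args scheme: at build time each task records an offset into a single model-owned device buffer and grows its total size; before execution the refreshed input/output addresses are written back at that offset. A minimal sketch of the driver side (the loops and variable names are hypothetical; the member calls are the ones introduced by this patch):

    // Phase 1, at load time: reserve a slice per task, then allocate once.
    for (auto &task : task_list) {
      GE_CHK_STATUS_RET(task->CalculateArgs(task_def, davinci_model));  // records args_offset_
    }
    GE_CHK_STATUS_RET(davinci_model->MallocKnownArgs());  // one buffer of GetTotalArgsSize() bytes
    // Phase 2, before each execution with fresh I/O tensors: patch the slices.
    for (auto &task : task_list) {
      GE_CHK_STATUS_RET(task->UpdateArgs());  // UpdateKnownZeroCopyAddr(io_addrs, args_offset_)
    }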
diff --git a/src/ge/graph/load/new_model_manager/task_info/kernel_ex_task_info.h b/src/ge/graph/load/new_model_manager/task_info/kernel_ex_task_info.h index a6419f9f..8903a17c 100644 --- a/src/ge/graph/load/new_model_manager/task_info/kernel_ex_task_info.h +++ b/src/ge/graph/load/new_model_manager/task_info/kernel_ex_task_info.h @@ -41,6 +41,10 @@ class KernelExTaskInfo : public TaskInfo { Status Release() override; + Status UpdateArgs() override; + + Status CalculateArgs(const domi::TaskDef &task_def, DavinciModel *davinci_model) override; + uint32_t GetTaskID() override { return task_id_; } uint32_t GetStreamId() override { return stream_id_; } @@ -61,6 +65,8 @@ class KernelExTaskInfo : public TaskInfo { void *kernel_buf_; void *input_output_addr_; void *dump_args_; + OpDescPtr op_desc_ = nullptr; + uint32_t args_offset_ = 0; }; } // namespace ge #endif // GE_GRAPH_LOAD_NEW_MODEL_MANAGER_TASK_INFO_KERNEL_EX_TASK_INFO_H_ diff --git a/src/ge/graph/load/new_model_manager/task_info/kernel_task_info.cc b/src/ge/graph/load/new_model_manager/task_info/kernel_task_info.cc index 5ed89cc6..df0ed5fd 100644 --- a/src/ge/graph/load/new_model_manager/task_info/kernel_task_info.cc +++ b/src/ge/graph/load/new_model_manager/task_info/kernel_task_info.cc @@ -343,6 +343,10 @@ Status KernelTaskInfo::SuperKernelDistribute() { Status KernelTaskInfo::Distribute() { GELOGD("KernelTaskInfo Distribute Start."); + if (davinci_model_->IsKnownNode()) { + args_ = davinci_model_->GetCurrentArgsAddr(args_offset_); + GELOGI("Known node %s args addr %p, offset %u.", op_desc_->GetName().c_str(), args_, args_offset_); + } rtError_t rt_ret = RT_ERROR_NONE; char *skt_enable_env = getenv("SKT_ENABLE"); int64_t env_flag = (skt_enable_env != nullptr) ? strtol(skt_enable_env, nullptr, 10) : 0; @@ -380,7 +384,29 @@ Status KernelTaskInfo::Distribute() { return SUCCESS; } +Status KernelTaskInfo::UpdateArgs() { + GELOGI("KernelTaskInfo::UpdateArgs in."); + const RuntimeParam &rts_param = davinci_model_->GetRuntimeParam(); + vector input_data_addrs = ModelUtils::GetInputDataAddrs(rts_param, op_desc_); + vector output_data_addrs = ModelUtils::GetOutputDataAddrs(rts_param, op_desc_); + vector workspace_data_addrs = ModelUtils::GetWorkspaceDataAddrs(rts_param, op_desc_); + + vector io_addrs; + io_addrs.insert(io_addrs.end(), input_data_addrs.begin(), input_data_addrs.end()); + io_addrs.insert(io_addrs.end(), output_data_addrs.begin(), output_data_addrs.end()); + io_addrs.insert(io_addrs.end(), workspace_data_addrs.begin(), workspace_data_addrs.end()); + + GE_CHK_STATUS_RET(davinci_model_->UpdateKnownZeroCopyAddr(io_addrs, args_offset_), + "update known node %s zero copy addr failed.", op_desc_->GetName().c_str()); + + GELOGI("KernelTaskInfo::UpdateArgs success."); + return SUCCESS; +} + Status KernelTaskInfo::Release() { + if (davinci_model_ != nullptr && davinci_model_->IsKnownNode()) { + return SUCCESS; + } FreeRtMem(&args_); FreeRtMem(&flowtable_); FreeRtMem(&custom_info_.input_descs); @@ -439,6 +465,15 @@ Status KernelTaskInfo::UpdateL2Data(const domi::KernelDef &kernel_def) { return SUCCESS; } +Status KernelTaskInfo::CalculateArgs(const domi::TaskDef &task_def, DavinciModel *davinci_model) { + domi::KernelDef kernel_def = task_def.kernel(); + uint32_t args_size = kernel_def.args_size(); + args_offset_ = davinci_model->GetTotalArgsSize(); + davinci_model->SetTotalArgsSize(args_size); + GELOGI("kernel task args_size %u, args_offset %u", args_size, args_offset_); + return SUCCESS; +} + Status
KernelTaskInfo::InitTVMTask(uint16_t offset, const domi::KernelDef &kernel_def) { GELOGD("Do InitTVMTask."); GE_CHECK_NOTNULL(davinci_model_); @@ -448,6 +483,9 @@ Status KernelTaskInfo::InitTVMTask(uint16_t offset, const domi::KernelDef &kerne GELOGE(INTERNAL_ERROR, "InitTVMTaskInfo error, index:%u out of range!", ctx_.opIndex); return INTERNAL_ERROR; } + if (davinci_model_->IsKnownNode()) { + return SUCCESS; + } // Update Stub // When training, when the second call to DavinciModel::init() comes here, stub_func_ is already valid, @@ -512,7 +550,7 @@ Status KernelTaskInfo::InitTVMTask(uint16_t offset, const domi::KernelDef &kerne if (PropertiesManager::Instance().IsLayerNeedDump(davinci_model_->Name(), op_desc->GetName())) { dump_flag_ = RT_KERNEL_DUMPFLAG; - dump_args_ = static_cast(args_) + offset + kAddrLen * input_data_addrs.size(); + dump_args_ = static_cast(args_) + offset; } // update origin l2 data @@ -771,7 +809,7 @@ Status KernelTaskInfo::InitAicpuTask(uint32_t op_index, const domi::KernelDef &k if (PropertiesManager::Instance().IsLayerNeedDump(davinci_model_->Name(), op_desc->GetName())) { dump_flag_ = RT_KERNEL_DUMPFLAG; - dump_args_ = static_cast(args_) + sizeof(aicpu::AicpuParamHead) + kAddrLen * input_addrs.size(); + dump_args_ = static_cast(args_) + sizeof(aicpu::AicpuParamHead); } vector virtual_io_addrs; // use virtual address for zero copy key. diff --git a/src/ge/graph/load/new_model_manager/task_info/kernel_task_info.h b/src/ge/graph/load/new_model_manager/task_info/kernel_task_info.h index 234c25f4..e6753b10 100644 --- a/src/ge/graph/load/new_model_manager/task_info/kernel_task_info.h +++ b/src/ge/graph/load/new_model_manager/task_info/kernel_task_info.h @@ -67,6 +67,10 @@ class KernelTaskInfo : public TaskInfo { Status Distribute() override; + Status UpdateArgs() override; + + Status CalculateArgs(const domi::TaskDef &task_def, DavinciModel *davinci_model) override; + Status Release() override; cce::ccOpContext *GetCtx() override { return &ctx_; } @@ -146,6 +150,7 @@ class KernelTaskInfo : public TaskInfo { void *dump_args_; OpDescPtr op_desc_; DavinciModel *davinci_model_; + uint32_t args_offset_ = 0; // For super kernel uint32_t skt_id_; diff --git a/src/ge/graph/load/new_model_manager/task_info/task_info.h b/src/ge/graph/load/new_model_manager/task_info/task_info.h index 2a0b93c7..5d2c89eb 100644 --- a/src/ge/graph/load/new_model_manager/task_info/task_info.h +++ b/src/ge/graph/load/new_model_manager/task_info/task_info.h @@ -62,6 +62,10 @@ class TaskInfo { virtual Status Distribute() = 0; + virtual Status UpdateArgs() { return SUCCESS; } + + virtual Status CalculateArgs(const domi::TaskDef &task_def, DavinciModel *davinci_model) { return SUCCESS; } + virtual Status Release() { return SUCCESS; } virtual cce::ccOpContext *GetCtx() { return nullptr; } diff --git a/src/ge/graph/manager/graph_caching_allocator.cc b/src/ge/graph/manager/graph_caching_allocator.cc new file mode 100644 index 00000000..5df6769b --- /dev/null +++ b/src/ge/graph/manager/graph_caching_allocator.cc @@ -0,0 +1,343 @@ +/** + * Copyright 2019-2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "graph/manager/graph_caching_allocator.h" + +#include +#include +#include + +#include "framework/common/debug/ge_log.h" +#include "graph/manager/graph_mem_allocator.h" + +namespace ge { +const size_t bin_ranges[kNumBins] = {kRoundBlockSize * kKByteSize, + 8 * kMByteSize, + 32 * kMByteSize, + 128 * kMByteSize, + kGByteSize, + 4 * kGByteSize, + 16 * kGByteSize, + 26 * kGByteSize}; + +static bool BlockComparator(const Block *left, const Block *right) { + if (left->device_id != right->device_id) { + return left->device_id < right->device_id; + } + if (left->size != right->size) { + return left->size < right->size; + } + return reinterpret_cast(left->ptr) < reinterpret_cast(right->ptr); +} + +bool CanMerge(Block *block) { + if (block == nullptr || block->allocated || !block->IsSplit()) { + return false; + } + return true; +} + +size_t GetBinIndex(size_t size) { + size_t index = 0; + for (auto range : bin_ranges) { + if (size <= range) { + break; + } + ++index; + } + if (index > kNumBins - 1) { + index = kNumBins - 1; + } + return index; +} + +size_t GetAllocationSize(size_t size) { + size_t index = GetBinIndex(size); + return bin_ranges[index]; +} + +/// +/// @ingroup ge_graph +/// @brief block size based on alignment +/// @param [in] original malloc size +/// @return allocation size +/// +size_t GetBlockSize(size_t size) { + if (size == 0) { + return kRoundBlockSize; + } + return kRoundBlockSize * ((size + kRoundBlockSize - 1) / kRoundBlockSize); +} + +bool ShouldSplit(const Block *block, size_t size) { + return static_cast(size) <= (static_cast(block->size) * kSplitThreshold); +} + +CachingAllocator::CachingAllocator(rtMemType_t memory_type) : memory_type_(memory_type), memory_allocator_(nullptr) { + for (uint32_t i = 0; i < kNumBins; ++i) { + free_block_bins_[i] = nullptr; + } +} + +Status CachingAllocator::Initialize(uint32_t device_id) { + GELOGI("Device id %u", device_id); + // when redo Initialize free old memory + FreeBlocks(); + std::lock_guard lock(mutex_); + for (uint32_t i = 0; i < kNumBins; ++i) { + if (free_block_bins_[i] != nullptr) { + continue; + } + auto bin_ptr = new (std::nothrow) BlockBin(BlockComparator); + if (bin_ptr == nullptr) { + GELOGE(ge::FAILED, "Alloc BlockBin failed."); + return ge::FAILED; + } + free_block_bins_[i] = bin_ptr; + } + memory_allocator_ = MemManager::Instance(memory_type_); + if (memory_allocator_ == nullptr) { + return ge::FAILED; + } + return ge::SUCCESS; +} + +void CachingAllocator::Finalize(uint32_t device_id) { + GELOGI("Device id %u", device_id); + FreeBlocks(); + FreeBlockBins(); +} + +uint8_t *CachingAllocator::Malloc(size_t size, uint8_t *org_ptr, uint32_t device_id) { + uint8_t *ptr = nullptr; + size = GetBlockSize(size); + Block *block = FindFreeBlock(size, org_ptr, device_id); + if (block != nullptr) { + ptr = block->ptr; + } else { + if (ge::SUCCESS == TryExtendCache(size, device_id)) { + block = FindFreeBlock(size, org_ptr, device_id); + if (block != nullptr) { + ptr = block->ptr; + } + } + } + if (ptr == nullptr) { + GELOGE(FAILED, "Malloc failed device id = %u, size= %zu", device_id, size); + } else 
{ + std::lock_guard lock(mutex_); + block->allocated = true; + allocated_blocks_[block->ptr] = block; + GELOGI("Malloc device id = %u, size= %zu", device_id, size); + } + return ptr; +} + +Status CachingAllocator::Free(uint8_t *ptr, uint32_t device_id) { + GELOGI("Free device id = %u", device_id); + if (ptr == nullptr) { + GELOGE(PARAM_INVALID, "Invalid memory pointer"); + return ge::PARAM_INVALID; + } + + std::lock_guard lock(mutex_); + auto it = allocated_blocks_.find(ptr); + if (it == allocated_blocks_.end()) { + GELOGE(PARAM_INVALID, "Invalid memory pointer"); + return ge::PARAM_INVALID; + } + Block *block = it->second; + allocated_blocks_.erase(it); + FreeBlock(block); + return ge::SUCCESS; +} + +void CachingAllocator::FreeBlock(Block *block) { + if (block == nullptr || !block->allocated) { + return; + } + GELOGI("Free block size = %zu", block->size); + + std::lock_guard lock(mutex_); + block->allocated = false; + auto &bin = *block->bin; + Block *merge_blocks[] = {block->prev, block->next}; + for (Block *merge_block : merge_blocks) { + MergeBlocks(block, merge_block, bin); + } + bin.insert(block); +} + +void CachingAllocator::MergeBlocks(Block *dst, Block *src, BlockBin &bin) { + if (!CanMerge(dst) || !CanMerge(src)) { + return; + } + + if (dst->prev == src) { + dst->ptr = src->ptr; + dst->prev = src->prev; + if (dst->prev != nullptr) { + dst->prev->next = dst; + } + } else { + dst->next = src->next; + if (dst->next != nullptr) { + dst->next->prev = dst; + } + } + + dst->size += src->size; + bin.erase(src); + delete src; +} + +BlockBin *CachingAllocator::GetBlockBin(size_t size) { + size_t index = GetBinIndex(size); + return free_block_bins_[index]; +} + +Block *CachingAllocator::FindFreeBlock(size_t size, uint8_t *org_ptr, uint32_t device_id) { + // org_ptr - 1, try to find ptr same as org_ptr + Block key(device_id, size, (org_ptr == nullptr ? 
nullptr : org_ptr - 1)); + BlockBin *bin = GetBlockBin(size); + if (bin == nullptr) { + GELOGE(ge::FAILED, "Get block bin failed size = %zu", size); + return nullptr; + } + std::lock_guard lock(mutex_); + auto it = bin->lower_bound(&key); + if (it != bin->end()) { + Block *block = *it; + bin->erase(it); + if (block != nullptr) { + GELOGI("Find block size = %zu", block->size); + if (ShouldSplit(block, size)) { + return SplitBlock(block, size, *bin, device_id); + } + } + return block; + } + return nullptr; +} + +Block *CachingAllocator::SplitBlock(Block *block, size_t size, BlockBin &bin, uint32_t device_id) { + // block has been checked, should not be nullptr + Block *remaining = block; + Block *new_block = new (std::nothrow) Block(device_id, size, &bin, block->ptr); + if (new_block == nullptr) { + GELOGE(ge::FAILED, "Alloc block failed size = %zu", size); + return block; + } + new_block->prev = remaining->prev; + if (new_block->prev != nullptr) { + new_block->prev->next = new_block; + } + new_block->next = remaining; + remaining->prev = new_block; + remaining->ptr = remaining->ptr + size; + remaining->size -= size; + bin.insert(remaining); + return new_block; +} + +Status CachingAllocator::TryExtendCache(size_t size, uint32_t device_id) { + auto memory_size = GetAllocationSize(size); + const std::string purpose = "Memory for caching."; + auto memory_addr = memory_allocator_->MallocMemory(purpose, memory_size, device_id); + // try to free caches and malloc again when malloc memory failed + if (memory_addr == nullptr) { + FreeCachedBlocks(); + memory_addr = memory_allocator_->MallocMemory(purpose, memory_size, device_id); + if (memory_addr == nullptr) { + GELOGE(ge::FAILED, "TryExtendCache failed, not enough memory for size = %zu, device_id = %u", memory_size, + device_id); + return ge::FAILED; + } + } + if (AddToBlockBin(memory_addr, memory_size) != ge::SUCCESS) { + (void)memory_allocator_->FreeMemory(memory_addr); + return ge::FAILED; + } + return ge::SUCCESS; +} + +Status CachingAllocator::AddToBlockBin(uint8_t *ptr, size_t size) { + BlockBin *bin = GetBlockBin(size); + if (bin == nullptr) { + GELOGE(ge::FAILED, "Get block bin failed size = %zu", size); + return ge::FAILED; + } + Block *block = new (std::nothrow) Block(0, size, bin, nullptr); + if (block == nullptr) { + GELOGE(ge::FAILED, "Alloc block failed size = %zu", size); + return ge::FAILED; + } + + GELOGI("Block size = %zu", size); + block->ptr = ptr; + block->size = size; + + std::lock_guard lock(mutex_); + bin->insert(block); + return ge::SUCCESS; +} + +void CachingAllocator::FreeCachedBlocks() { + GELOGI("Free cached blocks"); + std::lock_guard lock(mutex_); + for (uint32_t i = 0; i < kNumBins; ++i) { + auto pool = free_block_bins_[i]; + if (pool == nullptr) { + continue; + } + for (auto it = pool->begin(); it != pool->end();) { + Block *block = *it; + // free block memory that has not been split + if ((block != nullptr) && (block->ptr != nullptr) && (block->prev == nullptr) && (block->next == nullptr) && + (memory_allocator_->FreeMemory(block->ptr) == ge::SUCCESS)) { + pool->erase(it++); + delete block; + continue; + } + ++it; + } + } +} + +void CachingAllocator::FreeBlocks() { + GELOGI("Free blocks"); + std::lock_guard lock(mutex_); + // free allocated blocks and put to cache + for (auto &it : allocated_blocks_) { + FreeBlock(it.second); + } + allocated_blocks_.clear(); + + FreeCachedBlocks(); +} + +void CachingAllocator::FreeBlockBins() { + GELOGI("Free block bins"); + std::lock_guard lock(mutex_); + for (uint32_t i = 0; i <
kNumBins; ++i) { + if (free_block_bins_[i] != nullptr) { + delete free_block_bins_[i]; + free_block_bins_[i] = nullptr; + } + } +} + +} // namespace ge diff --git a/src/ge/graph/manager/graph_caching_allocator.h b/src/ge/graph/manager/graph_caching_allocator.h new file mode 100644 index 00000000..75864ce7 --- /dev/null +++ b/src/ge/graph/manager/graph_caching_allocator.h @@ -0,0 +1,212 @@ +/** + * Copyright 2019-2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef GE_GRAPH_MANAGER_GRAPH_CACHING_ALLOCATOR_H_ +#define GE_GRAPH_MANAGER_GRAPH_CACHING_ALLOCATOR_H_ + +#include +#include +#include +#include +#include +#include +#include +#include +#include + +#include "framework/common/ge_inner_error_codes.h" +#include "graph/node.h" +#include "runtime/mem.h" + +namespace ge { + +constexpr size_t kRoundBlockSize = 512; // all block sizes are rounded to at least 512 bytes +constexpr double kSplitThreshold = 0.75; // split when malloc size <= small block size * kSplitThreshold +constexpr size_t kKByteSize = 1024; +constexpr size_t kMByteSize = 1024 * 1024; +constexpr size_t kGByteSize = 1024 * 1024 * 1024; + +struct Block; +typedef bool (*Comparison)(const Block *, const Block *); +using BlockBin = std::set; +static const uint32_t kNumBins = 8; + +struct Block { + uint32_t device_id; // npu device id + size_t size; // block size in bytes + BlockBin *bin; // owning block bin + uint8_t *ptr; // memory address + bool allocated; // in-use flag + Block *prev; // prev block if split from a larger allocation + Block *next; // next block if split from a larger allocation + + Block(uint32_t device, size_t size, BlockBin *bin, uint8_t *ptr) + : device_id(device), size(size), bin(bin), ptr(ptr), allocated(0), prev(nullptr), next(nullptr) {} + + // constructor for search key + Block(uint32_t device, size_t size, uint8_t *ptr) + : device_id(device), size(size), bin(nullptr), ptr(ptr), allocated(0), prev(nullptr), next(nullptr) {} + + bool IsSplit() const { return (prev != nullptr) || (next != nullptr); } +}; + +class MemoryAllocator; + +class CachingAllocator { + public: + explicit CachingAllocator(rtMemType_t memory_type); + + virtual ~CachingAllocator() = default; + + /// + /// @ingroup ge_graph + /// @brief caching allocator init + /// @param [in] device id + /// @return Status of init + /// + Status Initialize(uint32_t device_id = 0); + + /// + /// @ingroup ge_graph + /// @brief memory allocator finalize, release cached memory + /// @return void + /// + void Finalize(uint32_t device_id = 0); + + /// + /// @ingroup ge_graph + /// @brief malloc memory + /// @param [in] size memory size + /// @param [in] try to reuse the same memory + /// @param [in] device id + /// @return memory address + /// + uint8_t *Malloc(size_t size, uint8_t *org_ptr = nullptr, uint32_t device_id = 0); + + /// + /// @ingroup ge_graph + /// @brief free memory + /// @param [in] device_id device id + /// @param [out] memory_ptr memory address ptr + /// @return Status result of function +
/// + Status Free(uint8_t *memory_addr, uint32_t device_id = 0); + + private: + /// + /// @ingroup ge_graph + /// @brief extend cache by size + /// @param [in] memory size + /// @param [in] device id + /// @return Status result of function + /// + Status TryExtendCache(size_t size, uint32_t device_id); + + /// + /// @ingroup ge_graph + /// @brief find free block by size + /// @param [in] memory size + /// @param [in] device_id device id + /// @return block ptr + /// + Block *FindFreeBlock(size_t size, uint8_t *org_ptr, uint32_t device_id); + + /// + /// @ingroup ge_graph + /// @brief get the right bin based on size + /// @param [in] original malloc size + /// @return block bin + /// + BlockBin *GetBlockBin(size_t size); + + /// + /// @ingroup ge_graph + /// @brief add memory to right bin based on size + /// @param [in] memory ptr + /// @param [in] memory size + /// @return Status result of function + /// + Status AddToBlockBin(uint8_t *ptr, size_t size); + + /// + /// @ingroup ge_graph + /// @brief free block to right bin + /// @param [in] block ptr + /// @return void + /// + void FreeBlock(Block *block); + + /// + /// @ingroup ge_graph + /// @brief free all cached blocks to right bin and release the memory when memory is not enough + /// @return void + /// + void FreeCachedBlocks(); + + /// + /// @ingroup ge_graph + /// @brief free allocated and cached blocks and release the memory when process exit + /// @return void + /// + void FreeBlocks(); + + /// + /// @ingroup ge_graph + /// @brief free block bins when process exit + /// @return void + /// + void FreeBlockBins(); + + /// + /// @ingroup ge_graph + /// @brief If a split block is freed, try merging with the original block + /// @param [inout] dest block ptr + /// @param [in] src block ptr + /// @param [out] block bin + /// @return void + /// + void MergeBlocks(Block *dst, Block *src, BlockBin &bin); + + /// + /// @ingroup ge_graph + /// @brief If the allocated memory size is too much smaller than the memory block, try to split the memory block + /// @param [in] original block ptr + /// @param [in] allocated memory size + /// @param [in] block bin + /// @param [in] device id + /// @return split block ptr + /// + Block *SplitBlock(Block *block, size_t size, BlockBin &bin, uint32_t device_id); + + private: + rtMemType_t memory_type_; + + // device memory allocator + MemoryAllocator *memory_allocator_; + + // lock around all operations + mutable std::recursive_mutex mutex_; + + // allocated blocks by memory pointer + std::unordered_map allocated_blocks_; + + // block bins by different block size + BlockBin *free_block_bins_[kNumBins]; +}; + +} // namespace ge + +#endif // GE_GRAPH_MANAGER_GRAPH_CACHING_ALLOCATOR_H_
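[Editor's note] The new CachingAllocator keeps freed device blocks in size-ordered bins instead of returning them to the driver: Malloc rounds the request up to a multiple of kRoundBlockSize, takes the best-fitting free block (splitting it when the request is at most kSplitThreshold of the block), and only on a miss extends the cache with one bin-sized allocation; Free re-bins the block and merges it with adjacent free split blocks. A minimal usage sketch (hypothetical caller):

    ge::CachingAllocator allocator(RT_MEMORY_HBM);        // memory type from runtime/mem.h
    if (allocator.Initialize() != ge::SUCCESS) {
      // handle initialization failure
    }
    uint8_t *buf = allocator.Malloc(2 * ge::kMByteSize);  // rounded up, served from a bin when possible
    // ... use buf on the device ...
    (void)allocator.Free(buf);                            // back to its bin; device memory stays cached
    allocator.Finalize();                                 // releases all cached device memory

diff --git a/src/ge/graph/manager/graph_manager.cc b/src/ge/graph/manager/graph_manager.cc index 514a90ce..9b0ceeb3 100644 --- a/src/ge/graph/manager/graph_manager.cc +++ b/src/ge/graph/manager/graph_manager.cc @@ -48,31 +48,41 @@ #include "graph/passes/compile_nodes_pass.h" #include "graph/passes/constant_folding_pass.h" #include "graph/passes/constant_fuse_same_pass.h" +#include "graph/passes/control_trigger_pass.h" #include "graph/passes/dimension_adjust_pass.h" #include "graph/passes/flow_ctrl_pass.h" +#include "graph/passes/hccl_group_pass.h" #include "graph/passes/hccl_memcpy_pass.h" +#include "graph/passes/identity_pass.h" #include "graph/passes/identify_reference_pass.h" #include "graph/passes/iterator_op_pass.h" #include "graph/passes/link_gen_mask_nodes_pass.h" #include "graph/passes/merge_pass.h" #include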
"graph/passes/multi_batch_pass.h" +#include "graph/passes/next_iteration_pass.h" #include "graph/passes/permute_pass.h" #include "graph/passes/prune_pass.h" #include "graph/passes/replace_with_empty_const_pass.h" #include "graph/passes/reshape_remove_pass.h" +#include "graph/passes/reshape_recovery_pass.h" #include "graph/passes/same_transdata_breadth_fusion_pass.h" #include "graph/passes/subgraph_pass.h" +#include "graph/passes/switch_dead_branch_elimination.h" #include "graph/passes/switch_logic_remove_pass.h" -#include "graph/passes/switch_pass.h" +#include "graph/passes/switch_op_pass.h" #include "graph/passes/transop_breadth_fusion_pass.h" #include "graph/passes/transop_depth_fusion_pass.h" #include "graph/passes/transop_nearby_allreduce_fusion_pass.h" #include "graph/passes/transop_symmetry_elimination_pass.h" #include "graph/passes/transop_without_reshape_fusion_pass.h" #include "graph/passes/transpose_transdata_pass.h" +#include "graph/passes/dimension_compute_pass.h" #include "graph/passes/variable_op_pass.h" #include "graph/passes/variable_prepare_op_pass.h" #include "graph/passes/variable_ref_delete_op_pass.h" +#include "graph/passes/variable_ref_useless_control_out_delete_pass.h" +#include "graph/passes/cond_remove_pass.h" +#include "graph/passes/ctrl_edge_transfer_pass.h" #include "graph/partition/dynamic_shape_partition.h" #include "graph/utils/tensor_adapter.h" #include "inc/pass_manager.h" @@ -85,6 +95,18 @@ const char *const kNetOutput = "NetOutput"; const char *const kVariable = "Variable"; const char *const kSend = "Send"; const char *const kRecv = "Recv"; + +bool IsTailingOptimization() { + string is_tailing_optimization_option; + auto ret = ge::GetContext().GetOption(ge::OPTION_EXEC_ENABLE_TAILING_OPTIMIZATION, is_tailing_optimization_option); + if (ret == ge::GRAPH_SUCCESS) { + GELOGI("Option ge.exec.isTailingOptimization is %s", is_tailing_optimization_option.c_str()); + // "1" means it's True from frontend option + return is_tailing_optimization_option == "1"; + } + GELOGW("OPTION_EXEC_ENABLE_TAILING_OPTIMIZATION not set, use BFSTopologicalSorting by default."); + return false; +} } // namespace namespace ge { @@ -173,22 +195,23 @@ Status GraphManager::Finalize() { } // unload model - auto ge_model = graph_node->GetGeModel(); - if (ge_model != nullptr && ge_model->GetModelId() != INVALID_MODEL_ID && graph_node->GetLoadFlag()) { + auto ge_root_model = graph_node->GetGeRootModel(); + if (ge_root_model != nullptr && ge_root_model->GetModelId() != INVALID_MODEL_ID && graph_node->GetLoadFlag()) { rt_ret = rtSetDevice(GetContext().DeviceId()); if (rt_ret != RT_ERROR_NONE) { - GELOGW("[GraphManager] rtSetDevice failed, modelId=%u, graphId=%u.", ge_model->GetModelId(), iter->first); + GELOGW("[GraphManager] rtSetDevice failed, modelId=%u, graphId=%u.", ge_root_model->GetModelId(), iter->first); unload_model_ret = FAILED; continue; } - ret = GraphLoader::UnloadModel(ge_model->GetModelId()); + ret = GraphLoader::UnloadModel(ge_root_model->GetModelId()); if (ret != SUCCESS) { - GELOGW("[GraphManager] unload model failed, modelId=%u, graphId=%u.", ge_model->GetModelId(), iter->first); + GELOGW("[GraphManager] unload model failed, modelId=%u, graphId=%u.", ge_root_model->GetModelId(), iter->first); unload_model_ret = ret; } rt_ret = rtDeviceReset(GetContext().DeviceId()); if (rt_ret != RT_ERROR_NONE) { - GELOGW("[GraphManager] rtDeviceReset failed, modelId=%u, graphId=%u.", ge_model->GetModelId(), iter->first); + GELOGW("[GraphManager] rtDeviceReset failed, modelId=%u, 
graphId=%u.", ge_root_model->GetModelId(), + iter->first); unload_model_ret = FAILED; continue; } @@ -326,12 +349,11 @@ Status GraphManager::SetSubgraph(uint64_t session_id, ComputeGraphPtr compute_gr #define GM_RUN_AND_DUMP(name, func, ...) \ do { \ GE_RUN(GraphManager, func, __VA_ARGS__); \ - GraphUtils::DumpGEGraph(compute_graph, "PreRunAfter" name); \ - GraphUtils::DumpGEGraphToOnnx(*compute_graph, "PreRunAfter" name); \ + GE_DUMP(compute_graph, "PreRunAfter" name); \ GELOGI("Run %s on graph %s(%u) success.", name, compute_graph->GetName().c_str(), graph_node->GetGraphId()); \ } while (0) -Status GraphManager::PreRunDynShape(const GraphNodePtr &graph_node, const std::vector &inputs, - vector &ge_models, GeModelPtr &ge_model, uint64_t session_id) { +Status GraphManager::PreRun(const GraphNodePtr &graph_node, const std::vector &inputs, + GeRootModelPtr &ge_root_model, uint64_t session_id) { GE_CHECK_NOTNULL(graph_node); GE_CHECK_NOTNULL(graph_node->GetGraph()); auto compute_graph = GraphUtils::GetComputeGraph(*graph_node->GetGraph()); @@ -340,22 +362,34 @@ Status GraphManager::PreRunDynShape(const GraphNodePtr &graph_node, const std::v GEEVENT("PreRun start, graph node size %zu, session id %lu, graph id %u, graph name %s", compute_graph->GetDirectNodesSize(), session_id, compute_graph->GetGraphID(), compute_graph->GetName().c_str()); - GraphUtils::DumpGEGraph(compute_graph, "PreRunBegin"); - GraphUtils::DumpGEGraphToOnnx(*compute_graph, "PreRunBegin"); + GE_DUMP(compute_graph, "PreRunBegin"); GM_RUN_AND_DUMP("OptimizeGraphPrepare", graph_optimize_.OptimizeOriginalGraphForQuantize, compute_graph); GM_RUN_AND_DUMP("HandleSummaryOp", graph_optimize_.HandleSummaryOp, compute_graph); GM_RUN_AND_DUMP("Prepare", graph_preparer_.PrepareDynShape, graph_node->GetGraph(), inputs, compute_graph, session_id); - // original graph optimization and running format inference GM_RUN_AND_DUMP("OptimizeOriginalGraph", graph_optimize_.OptimizeOriginalGraph, compute_graph); + GM_RUN_AND_DUMP("PrepareRunningFormatRefiner", graph_preparer_.PrepareRunningFormatRefiner); + GM_RUN_AND_DUMP("RefineRunningFormat", graph_optimize_.OptimizeOriginalGraphJudgeInsert, compute_graph); + if (IsTailingOptimization()) { + GM_RUN_AND_DUMP("OptimizeSwitchOp", graph_preparer_.SwitchOpOptimize, compute_graph); + } GM_RUN_AND_DUMP("Optimize1", OptimizeStage1, compute_graph); GM_RUN_AND_DUMP("InferShape2", compute_graph->InferShapeInNeed); + // TODO: to be deleted + const char *unknown_shape_skip = std::getenv("EXPERIMENTAL_DYNAMIC_PARTITION"); + if (unknown_shape_skip != nullptr) { + PassManager graph_pass; + GE_CHK_STATUS_RET(graph_pass.AddPass("PreRun::CtrlEdgeTransferPass", new (std::nothrow) CtrlEdgeTransferPass)) + GE_CHK_STATUS_RET(graph_pass.Run(compute_graph)); + } + GM_RUN_AND_DUMP("OptimizeSubgraph", OptimizeSubgraph, graph_node, compute_graph, session_id); GM_RUN_AND_DUMP("Optimize2", OptimizeStage2, compute_graph); - GM_RUN_AND_DUMP("Build", Build, graph_node, compute_graph, ge_models, ge_model, session_id); + GM_RUN_AND_DUMP("Build", Build, graph_node, compute_graph, ge_root_model, session_id); // when set incre build, save om model and var manager + GeModelPtr ge_model = nullptr; auto save_ret = SaveCacheAfterBuild(graph_node->GetGraphId(), compute_graph, ge_model); if (save_ret != SUCCESS) { GELOGW("Fail to save cache."); @@ -367,91 +401,8 @@ } #undef GM_RUN_AND_DUMP -Status GraphManager::PreRun(const GraphNodePtr &graph_node, const
std::vector &inputs, - vector &ge_models, GeModelPtr &ge_model, uint64_t session_id) { - GELOGI("Ready For PreRun Start session_id = %lu.", session_id); - GE_TIMESTAMP_START(PreRun); - GE_CHECK_NOTNULL(graph_node); - // it will not execute graph preprocess, optimize, parition, build if the graph has built successful. - GE_CHECK_NOTNULL(graph_node->GetGraph()); - auto compute_graph = GraphUtils::GetComputeGraph(*graph_node->GetGraph()); - GE_IF_BOOL_EXEC(compute_graph == nullptr, GELOGE(FAILED, "compute graph is NULL."); return FAILED); - GraphUtils::DumpGEGraph(compute_graph, "BeforeSummaryHandle"); - GraphUtils::DumpGEGraphToOnnx(*compute_graph, "BeforeSummaryHandle"); - for (auto graph : compute_graph->GetAllSubgraphs()) { - GraphUtils::DumpGEGraph(graph, "BeforeSummaryHandle"); - GraphUtils::DumpGEGraphToOnnx(*graph, "BeforeSummaryHandleSubgraph"); - } - GEEVENT("PreRun start, graph node size is %zu", compute_graph->GetDirectNodesSize()); - // optimize the summary op in graph: store the summary name and replace the summary ops with net_output op. - GE_TIMESTAMP_START(HandleSummaryOp); - auto ret = graph_optimize_.HandleSummaryOp(compute_graph); - GE_TIMESTAMP_END(HandleSummaryOp, "GraphManager::HandleSummaryOp"); - GE_CHK_BOOL_EXEC(ret == SUCCESS, return ret, "[RunTrainGraph] HandleSummaryOp failed."); - GE_TIMESTAMP_START(GraphPrepare); - ret = graph_preparer_.Prepare(graph_node->GetGraph(), inputs, compute_graph, var_acc_ctrl_, session_id); - if (ret != SUCCESS) { - GELOGE(ret, "ATC RunGraph input compute graph is NULL"); - return ret; - } - GE_TIMESTAMP_END(GraphPrepare, "GraphPrepare::Prepare"); - compute_graph->SetSessionID(session_id); - GraphUtils::DumpGEGraph(compute_graph, "OptimizeOriginalGraphAfter"); - GraphUtils::DumpGEGraphToOnnx(*compute_graph, "OptimizeOriginalGraphAfter"); - for (auto graph : compute_graph->GetAllSubgraphs()) { - GraphUtils::DumpGEGraphToOnnx(*graph, "OptimizeOriginalGraphAfterSubgraph"); - } - - GE_TIMESTAMP_START(InferShape); - // Origin graph infershape - GE_CHK_STATUS_EXEC(compute_graph->InferShapeInNeed(), - GELOGE(GE_GRAPH_INFERSHAPE_FAILED, " OriginGraph infershape failed"); - return GE_GRAPH_INFERSHAPE_FAILED;) - GE_TIMESTAMP_END(InferShape, "ComputeGraph::InferShapeInNeed"); - - ret = OptimizeSubgraph(graph_node, compute_graph, session_id); - if (ret != SUCCESS) { - return ret; - } - std::shared_ptr instance_ge = ge::GELib::GetInstance(); - if (instance_ge != nullptr && instance_ge->InitFlag()) { - // optimize after merge subgraph - GE_TIMESTAMP_START(OptimizeAfterMergeSubgraph); - const char *buffer_optimize_on = std::getenv("BUFFER_OPTIMIZE_ON"); - if (buffer_optimize_on != nullptr) { - ret = NewOptimizeAfterMergeSubGraph(compute_graph); - } else { - ret = OptimizeAfterMergeSubGraph(compute_graph); - } - if (ret != SUCCESS) { - GELOGE(ret, "Optimize after merge subgraph failed."); - return ret; - } - GE_TIMESTAMP_END(OptimizeAfterMergeSubgraph, "GraphManager::OptimizeAfterMergeSubGraph"); - } - - GraphUtils::DumpGEGraph(compute_graph, "OptimizeMergeSubGraphAfter"); - GraphUtils::DumpGEGraphToOnnx(*compute_graph, "OptimizeMergeSubGraphAfter"); - - ret = Build(graph_node, compute_graph, ge_models, ge_model, session_id); - if (ret != SUCCESS) { - return ret; - } - - // when set incre build, save om model and var manager - auto save_ret = SaveCacheAfterBuild(graph_node->GetGraphId(), compute_graph, ge_model); - if (save_ret != SUCCESS) { - GELOGW("Fail to save cache."); - } - // release rts generate context - 
RtContextUtil::GetInstance().DestroyrtContexts();
-  GE_TIMESTAMP_END(PreRun, "GraphManager::PreRun");
-  GEEVENT("[GEPERFTRACE] GE PreRun End");
-  return ret;
-}
-
 Status GraphManager::StartForRunGraph(const GraphNodePtr &graph_node, const std::vector<GeTensor> &inputs,
-                                      vector<GeModelPtr> &ge_models, uint64_t session_id) {
+                                      GeRootModelPtr &ge_root_model, uint64_t session_id) {
   // it will not execute graph prreprocess, optimize, parition, build if the graph has built successful.
   Status ret = SUCCESS;
   if (IsGraphNeedBuild(graph_node)) {
@@ -466,13 +417,13 @@ Status GraphManager::StartForRunGraph(const GraphNodePtr &graph_node, const std:
     // check need incre build.
     ret = IncreBuild(graph_node, ge_model);
     if (ret != SUCCESS) {
-      ret = PreRun(graph_node, inputs, ge_models, ge_model, session_id);
+      ret = PreRun(graph_node, inputs, ge_root_model, session_id);
       if (ret != SUCCESS) {
         GELOGE(ret, "PreRun Failed.");
         return ret;
       }
     }
-    ret = LoadGraph(ge_model, graph_node);
+    ret = LoadGraph(ge_root_model, graph_node);
     if (ret != SUCCESS) {
       GELOGE(ret, "LoadGraph Failed.");
       return ret;
@@ -480,8 +431,8 @@ Status GraphManager::StartForRunGraph(const GraphNodePtr &graph_node, const std:
     graph_node->SetBuildFlag(true);
     var_acc_ctrl_.SetGraphBuildEnd(graph_node->GetGraphId());
   } else if (!graph_node->GetLoadFlag()) {
-    GeModelPtr ge_model = graph_node->GetGeModel();
-    ret = LoadGraph(ge_model, graph_node);
+    GeRootModelPtr ge_root_model_ptr = graph_node->GetGeRootModel();
+    ret = LoadGraph(ge_root_model_ptr, graph_node);
     if (ret != SUCCESS) {
       GELOGE(ret, "LoadGraph Failed.");
       return ret;
@@ -489,19 +440,27 @@ Status GraphManager::StartForRunGraph(const GraphNodePtr &graph_node, const std:
   }
   return ret;
 }
-Status GraphManager::LoadGraph(const GeModelPtr &ge_model, const GraphNodePtr &graph_node) {
+Status GraphManager::LoadGraph(const GeRootModelPtr &ge_root_model, const GraphNodePtr &graph_node) {
   GELOGI("[LoadGraph] run_graph_flag[%d], graph_id[%u]", options_.run_graph_flag, graph_node->GetGraphId());
-  if (options_.run_graph_flag && ge_model != nullptr) {
+  if (options_.run_graph_flag && ge_root_model != nullptr) {
     // synchronization run graph with model
     std::shared_ptr<GraphModelListener> model_listener = GetModelListener();
     ModelIdInfo model_id_info;
-    if (getenv(kEnvGeuseStaticMemory) != nullptr) {
-      GELOGI("[LoadGraph] GE_USE_STATIC_MEMORY is seted.");
-    } else {
-      GE_CHK_STATUS_RET(CheckAndReleaseMemory(ge_model, graph_node));
+    bool is_unknown_shape = false;
+    GE_CHK_STATUS_RET(ge_root_model->CheckIsUnknownShape(is_unknown_shape));
+    if (!is_unknown_shape) {
+      if (getenv(kEnvGeuseStaticMemory) != nullptr) {
+        GELOGI("[LoadGraph] GE_USE_STATIC_MEMORY is set.");
+      } else {
+        auto root_graph = ge_root_model->GetRootGraph();
+        GE_CHECK_NOTNULL(root_graph);
+        auto name_to_model = ge_root_model->GetSubgraphInstanceNameToModel();
+        GeModelPtr ge_model = name_to_model[root_graph->GetName()];
+        GE_CHK_STATUS_RET(CheckAndReleaseMemory(ge_model, graph_node));
+      }
     }
     GE_TIMESTAMP_START(LoadGraph);
-    Status ret = GraphLoader::LoadModelOnline(model_id_info.model_id, ge_model, model_listener);
+    Status ret = GraphLoader::LoadModelOnline(model_id_info.model_id, ge_root_model, model_listener);
     GE_TIMESTAMP_END(LoadGraph, "GraphManager::LoadGraph");
     if (ret != SUCCESS) {
       GELOGE(ret, "[StartForRunGraph] LoadGraph Failed");
@@ -509,8 +468,8 @@ Status GraphManager::LoadGraph(const GeModelPtr &ge_model, const GraphNodePtr &g
       return ret;
     }
     graph_node->SetLoadFlag(true);
-    ge_model->SetModelId(model_id_info.model_id);
-    graph_node->SetGeModel(ge_model);
+
ge_root_model->SetModelId(model_id_info.model_id); + graph_node->SetGeRootModel(ge_root_model); } return SUCCESS; } @@ -612,7 +571,7 @@ Status GraphManager::InnerRunGraph(GraphNodePtr &graph_node, const GraphId &grap GE_CHK_STATUS_RET(graph_executor_.SetGraphContext(GetGraphContext())); graph_executor_.SetTrainFlag(options_.train_graph_flag); } - ret = graph_executor_.ExecuteGraph(graph_id, graph_node->GetGeModel(), inputs, outputs); + ret = graph_executor_.ExecuteGraph(graph_id, graph_node->GetGeRootModel(), inputs, outputs); graph_node->SetRunFlag(false); if (ret != SUCCESS) { @@ -661,21 +620,18 @@ Status GraphManager::RunGraph(const GraphId &graph_id, const std::vector ge_models; - if (options_.local_fmk_op_flag) { graph_optimize_.TranFrameOp(compute_graph_tmp); } - ret = StartForRunGraph(graph_node, inputs, ge_models, session_id); + GeRootModelPtr ge_root_model; + ret = StartForRunGraph(graph_node, inputs, ge_root_model, session_id); if (ret != SUCCESS) { GELOGE(ret, "[RunGraph] StartForRunGraph failed!"); graph_node->SetRunFlag(false); return ret; } - const std::vector &all_sub_graph = graph_node->GetAllSubGraph(); - // excute graph ret = InnerRunGraph(graph_node, graph_id, inputs, outputs); if (ret != SUCCESS) { @@ -690,8 +646,10 @@ Status GraphManager::RunGraph(const GraphId &graph_id, const std::vectorGetSubGraph(); + GeRootModelPtr root_model = graph_node->GetGeRootModel(); + if (root_model != nullptr) { + GELOGI("Start CheckpointHandle."); + auto checkPointGraph = root_model->GetRootGraph(); if (IsCheckpointGraph(checkPointGraph)) { ret = CheckpointHandle(graph_id, checkPointGraph, outputs); if (ret != SUCCESS) { @@ -731,7 +689,7 @@ Status GraphManager::GenerateInfershapeGraph(GraphId &graph_id) { } Status GraphManager::BuildGraph(const GraphId &graph_id, const std::vector &inputs, - std::vector &models) { + GeRootModelPtr &ge_root_model) { GELOGI("[BuildGraph] start to build graph, graph_id=%u.", graph_id); if (inputs.empty()) { GELOGW("[BuildGraph] BuildGraph warning: empty GeTensor inputs"); @@ -763,7 +721,7 @@ Status GraphManager::BuildGraph(const GraphId &graph_id, const std::vector(tv.tv_sec * 1000000 + tv.tv_usec); // 1000000us - ret = StartForRunGraph(graph_node, inputs, models, session_id); + ret = StartForRunGraph(graph_node, inputs, ge_root_model, session_id); graph_node->SetRunFlag(false); if (ret != SUCCESS) { GELOGE(GE_GRAPH_PRERUN_FAILED, "[BuildGraph] StartForRunGraph failed!"); @@ -811,8 +769,8 @@ void GraphManager::RemoveModelCacheHelper(const GraphId &graph_id) { } } -bool GraphManager::CheckModelLoad(const GeModelPtr &ge_model, bool load_flag) { - return ((ge_model != nullptr) && (ge_model->GetModelId() != INVALID_MODEL_ID) && load_flag); +bool GraphManager::CheckModelLoad(const GeRootModelPtr &ge_root_model, bool load_flag) { + return ((ge_root_model != nullptr) && (ge_root_model->GetModelId() != INVALID_MODEL_ID) && load_flag); } Status GraphManager::RemoveGraph(const GraphId &graph_id) { @@ -867,24 +825,24 @@ Status GraphManager::RemoveGraph(const GraphId &graph_id) { RemoveModelCacheHelper(graph_id); - auto ge_model = graph_node->GetGeModel(); - if (CheckModelLoad(ge_model, graph_node->GetLoadFlag())) { - GELOGI("Unload model %u.", ge_model->GetModelId()); + auto ge_root_model = graph_node->GetGeRootModel(); + if (CheckModelLoad(ge_root_model, graph_node->GetLoadFlag())) { + GELOGI("Unload model %u.", ge_root_model->GetModelId()); rt_ret = rtSetDevice(GetContext().DeviceId()); if (rt_ret != RT_ERROR_NONE) { - GELOGE(RT_FAILED, "[GraphManager:] rtSetDevice 
failed, modelId=%u, graphId=%u.", ge_model->GetModelId(), + GELOGE(RT_FAILED, "[GraphManager:] rtSetDevice failed, modelId=%u, graphId=%u.", ge_root_model->GetModelId(), graph_id); return FAILED; } - middle_ret = GraphLoader::UnloadModel(ge_model->GetModelId()); + middle_ret = GraphLoader::UnloadModel(ge_root_model->GetModelId()); if (middle_ret != SUCCESS) { - GELOGE(middle_ret, "[GraphManager:] unload model failed, modelId=%u, graph_id=%u.", ge_model->GetModelId(), + GELOGE(middle_ret, "[GraphManager:] unload model failed, modelId=%u, graph_id=%u.", ge_root_model->GetModelId(), graph_id); ret = middle_ret; } rt_ret = rtDeviceReset(GetContext().DeviceId()); if (rt_ret != RT_ERROR_NONE) { - GELOGE(RT_FAILED, "[GraphManager:] rtDeviceReset failed, modelId=%u, graphId=%u.", ge_model->GetModelId(), + GELOGE(RT_FAILED, "[GraphManager:] rtDeviceReset failed, modelId=%u, graphId=%u.", ge_root_model->GetModelId(), graph_id); ret = FAILED; } @@ -1265,7 +1223,7 @@ Status GraphManager::CheckpointHandle(const GraphId &graph_id, const ComputeGrap std::map save_results; NodePtr netoutput = nullptr; - for (const auto &node : compute_graph->GetDirectNode()) { + for (const auto &node : compute_graph->GetAllNodes()) { if (node->GetType() == kNetOutput) { netoutput = node; break; @@ -1436,7 +1394,7 @@ bool GraphManager::IsBroadCastOpData(const ge::NodePtr &var_node) { GE_RT_FALSE_CHECK_NOTNULL(in_anchor); ge::NodePtr dst_node = in_anchor->GetOwnerNode(); GE_RT_FALSE_CHECK_NOTNULL(dst_node); - if (dst_node->GetType() == HCOMBROADCAST) { + if (dst_node->GetType() == HCOMBROADCAST || dst_node->GetType() == HVDCALLBACKBROADCAST) { return true; } } @@ -1444,6 +1402,21 @@ bool GraphManager::IsBroadCastOpData(const ge::NodePtr &var_node) { return false; } +void GraphManager::SetAttrForHcomBroadCastOp(ge::ComputeGraphPtr &compute_graph) { + // add variable attr for hccl broadcast,need to be removed after variable pass online + for (const ge::NodePtr &node : compute_graph->GetDirectNode()) { + if (node->GetOpDesc()->GetType() != ge::VARIABLE) { + continue; + } + if (IsBroadCastOpData(node)) { + AdjustBroadCastOpData(node); + } + if (IsAssignOpData(node)) { + AdjustAssignOpData(node); + } + } +} + void GraphManager::AdjustBroadCastOpData(const ge::NodePtr &var_node) { if (!ge::AttrUtils::SetStr(var_node->GetOpDesc(), VAR_ATTR_VAR_IS_BROADCAST, "var_is_restore")) { GELOGW("set var_is_restore failed"); @@ -1505,8 +1478,8 @@ bool GraphManager::ConfirmUseOpAndIndexByNode(const ge::NodePtr &var_node, return false; } -Status GraphManager::RemoveIsolatedConst(ge::ComputeGraphPtr &compute_graph) { - for (ge::NodePtr &n : compute_graph->GetAllNodes()) { +Status GraphManager::RemoveIsolatedConstInThisGraph(ge::ComputeGraphPtr &compute_graph) { + for (ge::NodePtr &n : compute_graph->GetDirectNode()) { if (n->GetOpDesc() == nullptr) { continue; } @@ -1525,6 +1498,14 @@ Status GraphManager::RemoveIsolatedConst(ge::ComputeGraphPtr &compute_graph) { return SUCCESS; } +Status GraphManager::RemoveIsolatedConst(ge::ComputeGraphPtr &compute_graph) { + GE_CHK_STATUS_RET(RemoveIsolatedConstInThisGraph(compute_graph)); + for (auto &sub_graph : compute_graph->GetAllSubgraphs()) { + GE_CHK_STATUS_RET(RemoveIsolatedConstInThisGraph(sub_graph)); + } + return SUCCESS; +} + Status GraphManager::NewOptimizeAfterMergeSubGraph(ge::ComputeGraphPtr &compute_graph) { GELOGD("NewOptimizeAfterMergeSubGraph in"); @@ -1547,9 +1528,9 @@ Status GraphManager::NewOptimizeAfterMergeSubGraph(ge::ComputeGraphPtr &compute_ } PassManager passes; - 
GE_CHK_STATUS_RET(passes.AddPass(new (std::nothrow) MultiBatchPass));
-  GE_CHK_STATUS_RET(passes.AddPass(new (std::nothrow) CompileNodesPass));
-  GE_CHK_STATUS_RET(passes.AddPass(new (std::nothrow) AtomicAddrCleanPass));
+  GE_CHK_STATUS_RET(passes.AddPass("MultiBatchPass", new (std::nothrow) MultiBatchPass));
+  GE_CHK_STATUS_RET(passes.AddPass("CompileNodesPass", new (std::nothrow) CompileNodesPass));
+  GE_CHK_STATUS_RET(passes.AddPass("AtomicAddrCleanPass", new (std::nothrow) AtomicAddrCleanPass));

   GE_TIMESTAMP_START(passes);
   ret = passes.Run(compute_graph);
@@ -1573,107 +1554,147 @@ Status GraphManager::OptimizeStage1(ge::ComputeGraphPtr &compute_graph) {
     GELOGI("get ge.exec.variable_acc failed. set default value.");
   }
   PassManager after_merge_passes;
-  GE_CHK_STATUS_RET(after_merge_passes.AddPass(new (std::nothrow) ConstantFuseSamePass));
-  GE_CHK_STATUS_RET(after_merge_passes.AddPass(new (std::nothrow) VariablePrepareOpPass));
-  GE_CHK_STATUS_RET(after_merge_passes.AddPass(new (std::nothrow) IteratorOpPass));
-  GE_CHK_STATUS_RET(after_merge_passes.AddPass(new (std::nothrow) CommonSubexpressionEliminationPass));
-  GE_CHK_STATUS_RET(after_merge_passes.AddPass(new (std::nothrow) PermutePass));
-  GE_CHK_STATUS_RET(after_merge_passes.AddPass(new (std::nothrow) VariablePrepareOpPass));
+  GE_CHK_STATUS_RET(
+    after_merge_passes.AddPass("OptimizeStage1_1::ConstantFuseSamePass", new (std::nothrow) ConstantFuseSamePass));
+  GE_CHK_STATUS_RET(after_merge_passes.AddPass("OptimizeStage1_1::CommonSubexpressionEliminationPass",
+                                               new (std::nothrow) CommonSubexpressionEliminationPass));
+  GE_CHK_STATUS_RET(after_merge_passes.AddPass("OptimizeStage1_1::PermutePass", new (std::nothrow) PermutePass))
+  /*
+   * The SameTransdataBreadthFusionPass should be called before VariableOpPass, because of the following scene:
+   *            node3
+   *              |
+   *         transdata1   node2
+   *              |        |
+   *            cast1  transdata2
+   *               \      /
+   *                 var
+   * The node `transdata1` should be moved in front of the node `cast1`,
+   * to ensure that `transdata1` and `transdata2` can be fused with `var`.
+   * But it is a temp solution, because the `SameTransdataBreadthFusionPass`
+   * can only move `TransData` but not `Cast` nodes.
+   * So if we exchange Cast and TransData, the fusion mechanism will fail.
+ */ + GE_CHK_STATUS_RET(after_merge_passes.AddPass("OptimizeStage1_1::SameTransdataBreadthFusionPass", + new (std::nothrow) SameTransdataBreadthFusionPass)) GE_IF_BOOL_EXEC(options == "default" || options == "1", GELOGI("turn on variable accelerator"); - GE_CHK_STATUS_RET(after_merge_passes.AddPass(new (std::nothrow) VariableOpPass(&var_acc_ctrl_)))); - GE_CHK_STATUS_RET(after_merge_passes.AddPass(new (std::nothrow) TransOpDepthFusionPass)); - GE_CHK_STATUS_RET(after_merge_passes.AddPass(new (std::nothrow) TransOpBreadthFusionPass)); - GE_CHK_STATUS_RET(after_merge_passes.AddPass(new (std::nothrow) SameTransdataBreadthFusionPass)); - GE_CHK_STATUS_RET(after_merge_passes.AddPass(new (std::nothrow) TransOpWithoutReshapeFusionPass)); + GE_CHK_STATUS_RET(after_merge_passes.AddPass("OptimizeStage1_1::VariableOpPass", + new (std::nothrow) VariableOpPass(&var_acc_ctrl_)))) + GE_CHK_STATUS_RET( + after_merge_passes.AddPass("OptimizeStage1_1::TransOpDepthFusionPass", new (std::nothrow) TransOpDepthFusionPass)) + GE_CHK_STATUS_RET(after_merge_passes.AddPass("OptimizeStage1_1::TransOpWithoutReshapeFusionPass", + new (std::nothrow) TransOpWithoutReshapeFusionPass)) + GE_CHK_STATUS_RET(after_merge_passes.AddPass("OptimizeStage1_1::TransOpBreadthFusionPass", + new (std::nothrow) TransOpBreadthFusionPass)) GE_TIMESTAMP_START(after_merge_passes); auto ret = after_merge_passes.Run(compute_graph); - GE_TIMESTAMP_END(after_merge_passes, "GraphManager::AfterMergePasses"); + GE_TIMESTAMP_END(after_merge_passes, "GraphManager::OptimizeStage1_1"); if (ret != SUCCESS && ret != NOT_CHANGED) { - GELOGE(ret, "Run passes after merge sub graph failed, ret:%d.", ret); + GELOGE(ret, "Run passes when OptimizeStage1_1 failed, ret:%u.", ret); return ret; } - GEPass ge_passes(compute_graph); + GraphUtils::DumpGEGraphToOnnx(*compute_graph, "OptimizeStage1_1"); + NamesToPass names_to_passes; TransOpNearbyAllreduceFusionPass trans_op_nearby_allreduce_fusion_pass; ReshapeRemovePass reshape_remove_pass; ConstantFoldingPass constant_folding_pass; DimensionAdjustPass dimension_adjust_pass; AddNPass addn_pass; - SwitchPass switch_pass; + SwitchDeadBranchElimination switch_dead_branch_elimination; SwitchLogicRemovePass switch_logic_remove_pass; MergePass merge_pass; IdentifyReferencePass identify_reference_pass; CastRemovePass cast_remove_pass; TransposeTransDataPass transpose_transdata_pass; + TransOpSymmetryEliminationPass symmetry_elimination_pass; + DimensionComputePass dimension_compute_pass; names_to_passes.emplace_back("AddNPass", &addn_pass); - names_to_passes.emplace_back("SwitchPass", &switch_pass); + names_to_passes.emplace_back("SwitchDeadBranchElimination", &switch_dead_branch_elimination); names_to_passes.emplace_back("SwitchLogicRemovePass", &switch_logic_remove_pass); names_to_passes.emplace_back("MergePass", &merge_pass); names_to_passes.emplace_back("IdentifyReferencePass", &identify_reference_pass); names_to_passes.emplace_back("CastRemovePass", &cast_remove_pass); names_to_passes.emplace_back("TransposeTransDataPass", &transpose_transdata_pass); + names_to_passes.emplace_back("TransOpSymmetryEliminationPass", &symmetry_elimination_pass); names_to_passes.emplace_back("TransOpNearbyAllreduceFusionPass", &trans_op_nearby_allreduce_fusion_pass); names_to_passes.emplace_back("ReshapeRemovePass", &reshape_remove_pass); + names_to_passes.emplace_back("DimensionComputePass", &dimension_compute_pass); names_to_passes.emplace_back("ConstantFoldingPass", &constant_folding_pass); 
names_to_passes.emplace_back("DimensionAdjustPass", &dimension_adjust_pass); GE_TIMESTAMP_START(names_to_passes); - ret = ge_passes.Run(names_to_passes); - GE_TIMESTAMP_END(names_to_passes, "GraphManager::MergedGraphNameToPasses"); + ret = GEPass(compute_graph).Run(names_to_passes); + GE_TIMESTAMP_END(names_to_passes, "GraphManager::OptimizeStage1_2"); if (ret != SUCCESS) { - GELOGE(ret, "Run ge_passes optimize for OptimizeAfterMergeSubGraph failed, ret:%d.", ret); + GELOGE(ret, "Run passes when OptimizeStage1_2 failed, ret:%u.", ret); return ret; } + GraphUtils::DumpGEGraphToOnnx(*compute_graph, "OptimizeStage1_2"); PassManager graph_pass; - try { - (void)graph_pass.AddPass(new PrunePass); - } catch (std::bad_alloc &e) { - GELOGE(INTERNAL_ERROR, "Add pass failed, bad memory allocation occurs."); - return INTERNAL_ERROR; + // the prune pass should between SwtichPass and SwitchOpPass + GE_CHK_STATUS_RET(graph_pass.AddPass("OptimizeStage1_3::PrunePass", new (std::nothrow) PrunePass)) + GE_CHK_STATUS_RET(graph_pass.AddPass("OptimizeStage1_3::NextIterationPass", new (std::nothrow) NextIterationPass)) + GE_CHK_STATUS_RET(graph_pass.AddPass("OptimizeStage1_3::ControlTriggerPass", new (std::nothrow) ControlTriggerPass)) + GE_CHK_STATUS_RET(graph_pass.AddPass("OptimizeStage1_3::SwitchOpPass", new (std::nothrow) SwitchOpPass)) + GE_CHK_STATUS_RET(graph_pass.AddPass("OptimizeStage1_3::IteratorOpPass", new (std::nothrow) IteratorOpPass)) + GE_CHK_STATUS_RET(graph_pass.AddPass("OptimizeStage1_3::VariableRefUselessControlOutDeletePass", + new (std::nothrow) VariableRefUselessControlOutDeletePass)) + GE_CHK_STATUS_RET(graph_pass.AddPass("OptimizeStage1_3::ReshapeRecoveryPass", new (std::nothrow) ReshapeRecoveryPass)) + GE_TIMESTAMP_START(graph_pass); + ret = graph_pass.Run(compute_graph); + GE_TIMESTAMP_END(graph_pass, "GraphManager::OptimizeStage1_3"); + if (ret != SUCCESS && ret != NOT_CHANGED) { + GELOGE(ret, "Run passes when OptimizeStage1_3 failed, ret:%u.", ret); + return ret; + } + + NamesToPass identity_remove_pass; + GE_TIMESTAMP_START(identity_remove_pass); + IdentityPass identity_force_pass(true); // after SwitchOpPass + identity_remove_pass.emplace_back("IdentityPass", &identity_force_pass); + ret = GEPass(compute_graph).Run(identity_remove_pass); + GE_TIMESTAMP_END(identity_remove_pass, "GraphPrepare::IdentityRemovePass"); + if (ret != SUCCESS) { + GELOGE(ret, "Run identity remove pass for preprocess failed, ret:%u.", ret); + return ret; } return SUCCESS; } + Status GraphManager::OptimizeStage2(ge::ComputeGraphPtr &compute_graph) { GELOGI("Start optimize after merge sub graph."); PassManager after_merge_passes; - GE_CHK_STATUS_RET(after_merge_passes.AddPass(new (std::nothrow) VariableRefDeleteOpPass)); - GE_CHK_STATUS_RET(after_merge_passes.AddPass(new (std::nothrow) AtomicAddrCleanPass)); - GE_CHK_STATUS_RET( - after_merge_passes.AddPass(new (std::nothrow) LinkGenMaskNodesPass(options_.stream_max_parallel_num))); + GE_CHK_STATUS_RET(after_merge_passes.AddPass("OptimizeStage2::AfterMergePasses::LinkGenMaskNodesPass", + new (std::nothrow) + LinkGenMaskNodesPass(options_.stream_max_parallel_num))); GE_TIMESTAMP_START(after_merge_passes); auto ret = after_merge_passes.Run(compute_graph); - GE_TIMESTAMP_END(after_merge_passes, "GraphManager::AfterMergePasses"); + GE_TIMESTAMP_END(after_merge_passes, "OptimizeStage2::AfterMergePasses"); if (ret != SUCCESS && ret != NOT_CHANGED) { GELOGE(ret, "Run passes after merge sub graph failed, ret:%d.", ret); return ret; } + 
 Status GraphManager::OptimizeStage2(ge::ComputeGraphPtr &compute_graph) {
   GELOGI("Start optimize after merge sub graph.");

   PassManager after_merge_passes;
-  GE_CHK_STATUS_RET(after_merge_passes.AddPass(new (std::nothrow) VariableRefDeleteOpPass));
-  GE_CHK_STATUS_RET(after_merge_passes.AddPass(new (std::nothrow) AtomicAddrCleanPass));
-  GE_CHK_STATUS_RET(
-    after_merge_passes.AddPass(new (std::nothrow) LinkGenMaskNodesPass(options_.stream_max_parallel_num)));
+  GE_CHK_STATUS_RET(after_merge_passes.AddPass("OptimizeStage2::AfterMergePasses::LinkGenMaskNodesPass",
+                                               new (std::nothrow)
+                                                 LinkGenMaskNodesPass(options_.stream_max_parallel_num)));

   GE_TIMESTAMP_START(after_merge_passes);
   auto ret = after_merge_passes.Run(compute_graph);
-  GE_TIMESTAMP_END(after_merge_passes, "GraphManager::AfterMergePasses");
+  GE_TIMESTAMP_END(after_merge_passes, "OptimizeStage2::AfterMergePasses");
   if (ret != SUCCESS && ret != NOT_CHANGED) {
     GELOGE(ret, "Run passes after merge sub graph failed, ret:%d.", ret);
     return ret;
   }

+  SetAttrForHcomBroadCastOp(compute_graph);
-  // add variable attr for hccl broadcast,need to be removed after variable pass online
-  for (const ge::NodePtr &node : compute_graph->GetDirectNode()) {
-    if (node->GetOpDesc()->GetType() != VARIABLE) {
-      continue;
-    }
-
-    if (IsBroadCastOpData(node)) {
-      AdjustBroadCastOpData(node);
-    }
-    if (IsAssignOpData(node)) {
-      AdjustAssignOpData(node);
-    }
-  }
-
-  GEPass ge_passes(compute_graph);
   NamesToPass names_to_passes;
   ConstantFoldingPass constant_folding_pass;
+  ReshapeRemovePass reshape_remove_pass;
+  CondRemovePass condition_remove_pass;
   names_to_passes.emplace_back("ConstantFoldingPass", &constant_folding_pass);
+  names_to_passes.emplace_back("ReshapeRemovePass", &reshape_remove_pass);
+  names_to_passes.emplace_back("CondRemovePass", &condition_remove_pass);
+  HcclGroupPass hccl_group_pass;
+  if (IsTailingOptimization()) {
+    names_to_passes.emplace_back("HcclGroupPass", &hccl_group_pass);
+  }
   GE_TIMESTAMP_START(names_to_passes);
-  ret = ge_passes.Run(names_to_passes);
-  GE_TIMESTAMP_END(names_to_passes, "GraphManager::MergedGraphNameToPasses");
+  ret = GEPass(compute_graph).Run(names_to_passes);
+  GE_TIMESTAMP_END(names_to_passes, "OptimizeStage2::MergedGraphNameToPasses");
   if (ret != SUCCESS) {
     GELOGE(ret, "Run ge_passes optimize for OptimizeAfterMergeSubGraph failed, ret:%d.", ret);
     return ret;
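OptimizeStage2 mixes GE's two pass mechanisms: NamesToPass entries are node-level passes driven by GEPass across every node, while PassManager (used just below for pass_for_control_attr_optimize) runs graph-level passes in registration order. A rough sketch of the node-level half, reusing pass types registered above; the exact re-run semantics of GEPass are not shown in this diff, so treat this as an outline:

    // Node-level passes are instantiated on the stack, registered with a
    // diagnostic name, and driven by GEPass over the nodes of the graph.
    NamesToPass node_passes;
    ConstantFoldingPass folding_pass;
    ReshapeRemovePass reshape_pass;
    node_passes.emplace_back("ConstantFoldingPass", &folding_pass);
    node_passes.emplace_back("ReshapeRemovePass", &reshape_pass);
    Status status = GEPass(compute_graph).Run(node_passes);
    if (status != SUCCESS) {
      return status;
    }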
@@ -1686,31 +1707,67 @@ Status GraphManager::OptimizeStage2(ge::ComputeGraphPtr &compute_graph) {
   }

   PassManager pass_for_control_attr_optimize;
-  GE_CHK_STATUS_RET(pass_for_control_attr_optimize.AddPass(new (std::nothrow) HcclMemcpyPass));
   if (options_.train_graph_flag) {
-    GE_CHK_STATUS_RET(pass_for_control_attr_optimize.AddPass(new (std::nothrow) FlowCtrlPass));
-  }
-  GE_CHK_STATUS_RET(pass_for_control_attr_optimize.AddPass(new (std::nothrow) MultiBatchPass));
-  GE_CHK_STATUS_RET(pass_for_control_attr_optimize.AddPass(new (std::nothrow) CompileNodesPass));
+    // TODO: to be deleted
+    const char *unknown_shape_skip = std::getenv("EXPERIMENTAL_DYNAMIC_PARTITION");
+    if (unknown_shape_skip == nullptr) {
+      GE_CHK_STATUS_RET(pass_for_control_attr_optimize.AddPass("OptimizeStage2::ControlAttrOptimize::FlowCtrlPass",
+                                                               new (std::nothrow) FlowCtrlPass))
+    }
+  }
+  // TODO: to be deleted
+  const char *unknown_shape_skip = std::getenv("EXPERIMENTAL_DYNAMIC_PARTITION");
+  if (unknown_shape_skip == nullptr) {
+    GE_CHK_STATUS_RET(pass_for_control_attr_optimize.AddPass("OptimizeStage2::ControlAttrOptimize::SubgraphPass",
+                                                             new (std::nothrow) SubgraphPass));
+  }
+
+  GE_CHK_STATUS_RET(pass_for_control_attr_optimize.AddPass("OptimizeStage2::ControlAttrOptimize::MultiBatchPass",
+                                                           new (std::nothrow) MultiBatchPass))
+  // The value of the attr is the name of the original variable that the ref-variable refs from.
+  // The attr will be used when allocating memory:
+  // the node marked with the attr will output to a variable instead of newly allocated memory.
+  // Therefore, the ComputeGraph should not delete nodes after `VariableRefDeleteOpPass`,
+  // to prevent unexpected deletion of nodes marked with the attr.
+  GE_CHK_STATUS_RET(pass_for_control_attr_optimize.AddPass("OptimizeStage2::AfterMergePasses::VariableRefDeleteOpPass",
+                                                           new (std::nothrow) VariableRefDeleteOpPass))
+  GE_CHK_STATUS_RET(pass_for_control_attr_optimize.AddPass("OptimizeStage2::ControlAttrOptimize::CompileNodesPass",
+                                                           new (std::nothrow) CompileNodesPass))
+  // When the input node to be cleared is after a `Data` node, the atomic-clean-node should not be inserted,
+  // so the ComputeGraph should not delete nodes after `AtomicAddrCleanPass`,
+  // to prevent unexpected deletion of nodes after a `Data` node.
+  GE_CHK_STATUS_RET(pass_for_control_attr_optimize.AddPass("OptimizeStage2::AfterMergePasses::AtomicAddrCleanPass",
+                                                           new (std::nothrow) AtomicAddrCleanPass))
   GE_TIMESTAMP_START(pass_for_control_attr_optimize);
   ret = pass_for_control_attr_optimize.Run(compute_graph);
-  GE_TIMESTAMP_END(pass_for_control_attr_optimize, "GraphManager::ControlAttrOptimize");
+  GE_TIMESTAMP_END(pass_for_control_attr_optimize, "OptimizeStage2::ControlAttrOptimize");
   if (ret != SUCCESS && ret != NOT_CHANGED) {
     GELOGE(ret, "Run passes when optimize stage 2 failed");
     return ret;
   }

+  ChangeConstTypeWhenTraining(compute_graph);
+
   ret = compute_graph->TopologicalSorting();
   if (ret != SUCCESS) {
     GELOGE(ret, "Graph topological sort failed, ret:%d.", ret);
     return ret;
   }
-  GELOGI("End optimize after merge sub graph.");
   return SUCCESS;
 }
-
+void GraphManager::ChangeConstTypeWhenTraining(const ComputeGraphPtr &compute_graph) {
+  // The constant for train is CONSTANTOP, and is CONSTANT for inference. They will be unified in the future.
+  if (options_.train_graph_flag) {
+    for (NodePtr &n : compute_graph->GetAllNodes()) {
+      // The iterator guarantees that n is not a null pointer
+      if (n->GetOpDesc()->GetType() == CONSTANT) {
+        n->GetOpDesc()->SetType(CONSTANTOP);
+      }
+    }
+  }
+}
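ChangeConstTypeWhenTraining leans on the traversal split this patch establishes elsewhere: GetAllNodes() walks the root graph plus every subgraph, while GetDirectNode() stays inside one graph, which is why RemoveIsolatedConst() earlier in the file pairs GetDirectNode() with an explicit loop over GetAllSubgraphs(). A small sketch of that per-subgraph idiom; HandleOneGraph is a hypothetical helper, not part of this patch:

    // Equivalent coverage to GetAllNodes(), but with an explicit per-graph
    // boundary, so each subgraph can be processed (or skipped) in isolation.
    GE_CHK_STATUS_RET(HandleOneGraph(compute_graph));  // operates via GetDirectNode() internally
    for (auto &sub_graph : compute_graph->GetAllSubgraphs()) {
      GE_CHK_STATUS_RET(HandleOneGraph(sub_graph));
    }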
after_merge_fusion_passes.AddPass("TransOpBreadthFusionPass", new (std::nothrow) TransOpBreadthFusionPass)); + GE_CHK_STATUS_RET( + after_merge_fusion_passes.AddPass("VariableRefDeleteOpPass", new (std::nothrow) VariableRefDeleteOpPass)); + GE_CHK_STATUS_RET(after_merge_fusion_passes.AddPass("SameTransdataBreadthFusionPass", + new (std::nothrow) SameTransdataBreadthFusionPass)); + GE_CHK_STATUS_RET(after_merge_fusion_passes.AddPass("AtomicAddrCleanPass", new (std::nothrow) AtomicAddrCleanPass)); + GE_CHK_STATUS_RET(after_merge_fusion_passes.AddPass( + "LinkGenMaskNodesPass", new (std::nothrow) LinkGenMaskNodesPass(options_.stream_max_parallel_num))); GE_TIMESTAMP_START(after_merge_fusion_passes); ret = after_merge_fusion_passes.Run(compute_graph); GE_TIMESTAMP_END(after_merge_fusion_passes, "GraphManager::AfterMergePasses"); @@ -1790,6 +1851,8 @@ Status GraphManager::OptimizeAfterMergeSubGraph(ge::ComputeGraphPtr &compute_gra names_to_passes.emplace_back("ConstantFoldingPass", &constant_folding_pass); DimensionAdjustPass dimension_adjust_pass; names_to_passes.emplace_back("DimensionAdjustPass", &dimension_adjust_pass); + CondRemovePass condition_remove_pass; + names_to_passes.emplace_back("CondRemovePass", &condition_remove_pass); GE_TIMESTAMP_START(names_to_passes); ret = ge_passes.Run(names_to_passes); GE_TIMESTAMP_END(names_to_passes, "GraphManager::MergedGraphNameToPasses"); @@ -1799,9 +1862,9 @@ Status GraphManager::OptimizeAfterMergeSubGraph(ge::ComputeGraphPtr &compute_gra GE_CHK_STATUS_RET(ret, "Remove isolated Constant failed, ret:%d.", ret); PassManager pass_for_optimize; - GE_CHK_STATUS_RET(pass_for_optimize.AddPass(new (std::nothrow) SubgraphPass)); - GE_CHK_STATUS_RET(pass_for_optimize.AddPass(new (std::nothrow) MultiBatchPass)); - GE_CHK_STATUS_RET(pass_for_optimize.AddPass(new (std::nothrow) CompileNodesPass)); + GE_CHK_STATUS_RET(pass_for_optimize.AddPass("SubgraphPass", new (std::nothrow) SubgraphPass)); + GE_CHK_STATUS_RET(pass_for_optimize.AddPass("MultiBatchPass", new (std::nothrow) MultiBatchPass)); + GE_CHK_STATUS_RET(pass_for_optimize.AddPass("CompileNodesPass", new (std::nothrow) CompileNodesPass)); GE_TIMESTAMP_START(pass_for_optimize); ret = pass_for_optimize.Run(compute_graph); GE_TIMESTAMP_END(pass_for_optimize, "GraphManager::OptimizePass"); @@ -1817,27 +1880,36 @@ Status GraphManager::OptimizeAfterMergeSubGraph(ge::ComputeGraphPtr &compute_gra return SUCCESS; } -Status GraphManager::LoadGraphAsync(const GeModelPtr &ge_model, const GraphNodePtr &graph_node) { +Status GraphManager::LoadGraphAsync(const GeRootModelPtr &ge_root_model, const GraphNodePtr &graph_node) { GELOGI("[LoadGraphAsync] run_graph_flag[%d], graph_id[%u]", options_.run_graph_flag, graph_node->GetGraphId()); - if (options_.run_graph_flag && ge_model != nullptr) { + if (options_.run_graph_flag && ge_root_model != nullptr) { // synchronization run graph with model ModelIdInfo model_id_info; - if (getenv(kEnvGeuseStaticMemory) != nullptr) { - GELOGI("[LoadGraphAsync] GE_USE_STATIC_MEMORY is seted."); - } else { - GE_CHK_STATUS_RET(CheckAndReleaseMemory(ge_model, graph_node)); + bool is_unknown_shape = false; + GE_CHK_STATUS_RET(ge_root_model->CheckIsUnknownShape(is_unknown_shape)); + if (!is_unknown_shape) { + if (getenv(kEnvGeuseStaticMemory) != nullptr) { + GELOGI("[LoadGraphAsync] GE_USE_STATIC_MEMORY is seted."); + } else { + auto root_graph = ge_root_model->GetRootGraph(); + GE_CHECK_NOTNULL(root_graph); + auto name_to_model = ge_root_model->GetSubgraphInstanceNameToModel(); + GeModelPtr 
ge_model = name_to_model[root_graph->GetName()]; + GE_CHK_STATUS_RET(CheckAndReleaseMemory(ge_model, graph_node)); + } } GE_TIMESTAMP_START(LoadGraph); GE_CHECK_NOTNULL(graph_node->graph_run_async_listener_); - Status ret = GraphLoader::LoadModelOnline(model_id_info.model_id, ge_model, graph_node->graph_run_async_listener_); + Status ret = + GraphLoader::LoadModelOnline(model_id_info.model_id, ge_root_model, graph_node->graph_run_async_listener_); GE_TIMESTAMP_END(LoadGraph, "GraphManager::LoadGraphAsync"); if (ret != SUCCESS) { GELOGE(ret, "[LoadGraphAsync] LoadGraphAsync Failed"); graph_node->SetRunFlag(false); return ret; } - ge_model->SetModelId(model_id_info.model_id); - graph_node->SetGeModel(ge_model); + ge_root_model->SetModelId(model_id_info.model_id); + graph_node->SetGeRootModel(ge_root_model); } return SUCCESS; } @@ -1872,7 +1944,7 @@ Status GraphManager::CheckAndReleaseMemory(const GeModelPtr &ge_model, const Gra rtError_t rt_ret; for (auto &it : graph_map_) { auto graph_id = it.second->GetGraphId(); - auto model = it.second->GetGeModel(); + auto model = it.second->GetGeRootModel(); if (model == nullptr) { continue; } @@ -1925,8 +1997,7 @@ Status GraphManager::ProcessSubGraphWithMultiThreads(GraphManager *graph_manager GELOGI("ProcessSubGraphWithMultiThreads start, graph name is %s, engine_name is %s, thread id is %lu", compute_graph_tmp != nullptr ? compute_graph_tmp->GetName().c_str() : "", engine_name.c_str(), pthread_self()); - GraphUtils::DumpGEGraph(compute_graph_tmp, "OptimizeSubGraphBefore"); - GraphUtils::DumpGEGraphToOnnx(*compute_graph_tmp, "OptimizeSubGraphBefore"); + GE_DUMP(compute_graph_tmp, "OptimizeSubGraphBefore"); GE_CHECK_NOTNULL(compute_graph_tmp); compute_graph_tmp->SetSessionID(session_id); ret = graph_manager->graph_optimize_.OptimizeSubGraph(compute_graph_tmp, engine_name); @@ -1936,8 +2007,7 @@ Status GraphManager::ProcessSubGraphWithMultiThreads(GraphManager *graph_manager } else { GELOGI("SubGraph optimize success %s", engine_name.c_str()); } - GraphUtils::DumpGEGraph(compute_graph_tmp, "OptimizeSubGraphAfter"); - GraphUtils::DumpGEGraphToOnnx(*compute_graph_tmp, "OptimizeSubGraphAfter"); + GE_DUMP(compute_graph_tmp, "OptimizeSubGraphAfter"); sub_graph_info_ptr->SetSubGraph(compute_graph_tmp); GELOGI("ProcessSubGraphWithMultiThreads end, graph name is %s, engine_name is %s, thread id is %lu", compute_graph_tmp != nullptr ? compute_graph_tmp->GetName().c_str() : "", engine_name.c_str(), @@ -2075,7 +2145,7 @@ void GraphManager::PreRunThread(GraphManager *graph_manager) { GELOGI("Start for run graph async."); - GeModelPtr ge_model = nullptr; + GeRootModelPtr ge_root_model = nullptr; if (graph_manager->IsGraphNeedBuild(graph_node)) { if (graph_node->GetBuildFlag()) { ReturnError(graph_manager, args.callback, PARAM_INVALID, @@ -2087,8 +2157,9 @@ void GraphManager::PreRunThread(GraphManager *graph_manager) { } // check need incre build. 
+ GeModelPtr ge_model = nullptr; if (graph_manager->IncreBuild(graph_node, ge_model) != SUCCESS) { - ret = graph_manager->PreRun(graph_node, ge_inputs, ge_models, ge_model, args.session_id); + ret = graph_manager->PreRun(graph_node, ge_inputs, ge_root_model, args.session_id); if (ret != SUCCESS) { graph_node->SetRunFlag(false); ReturnError(graph_manager, args.callback, ret, "PreRun Failed, thread exit.."); @@ -2099,11 +2170,11 @@ void GraphManager::PreRunThread(GraphManager *graph_manager) { graph_node->SetBuildFlag(true); graph_manager->var_acc_ctrl_.SetGraphBuildEnd(graph_node->GetGraphId()); } else { - ge_model = graph_node->GetGeModel(); + ge_root_model = graph_node->GetGeRootModel(); } graph_manager->run_args_q_.Push( - RunArgs({graph_node, args.graph_id, args.input_tensor, ge_model, GetThreadLocalContext(), args.callback})); + RunArgs({graph_node, args.graph_id, args.input_tensor, ge_root_model, GetThreadLocalContext(), args.callback})); GELOGI("Loop end."); } } @@ -2126,8 +2197,8 @@ void GraphManager::RunThread(GraphManager *graph_manager) { Status ret; if (!args.graph_node->GetLoadFlag()) { - ret = graph_manager->LoadGraphAsync(args.ge_model, args.graph_node); - if (ret != SUCCESS) { + ret = graph_manager->LoadGraphAsync(args.ge_root_model, args.graph_node); + if (ret != SUCCESS || args.ge_root_model == nullptr) { StopQueue(graph_manager); ReturnError(graph_manager, args.callback, ret, "LoadGraphAsync failed, thread exit."); args.graph_node->Unlock(); @@ -2135,7 +2206,7 @@ void GraphManager::RunThread(GraphManager *graph_manager) { } args.graph_node->SetLoadFlag(true); GELOGI("LoadGraph[%u], model[%u] success and set LoadFlag to true.", args.graph_node->GetGraphId(), - args.ge_model->GetModelId()); + args.ge_root_model->GetModelId()); } if (graph_manager->GetTrainFlag()) { @@ -2146,8 +2217,8 @@ void GraphManager::RunThread(GraphManager *graph_manager) { graph_manager->graph_executor_.SetTrainFlag(graph_manager->options_.train_graph_flag); } - ret = - graph_manager->graph_executor_.ExecuteGraphAsync(args.graph_id, args.graph_node->GetGeModel(), args.input_tensor); + ret = graph_manager->graph_executor_.ExecuteGraphAsync(args.graph_id, args.graph_node->GetGeRootModel(), + args.input_tensor); args.graph_node->SetRunFlag(false); args.graph_node->Unlock(); if (ret != SUCCESS) { @@ -2214,24 +2285,40 @@ const map *GraphManager::GetGraphOptions(uint32_t grap } return &(graph_node->GetOptions()); } + +void GraphManager::SetOptionsRunGraphFlag(bool run_graph_flag) { options_.run_graph_flag = run_graph_flag; } + Status GraphManager::OptimizeSubgraph(const GraphNodePtr &graph_node, ComputeGraphPtr &compute_graph, uint64_t session_id) { // graph partition // all sub graph list of root graph and sub graph + GE_TIMESTAMP_START(GraphPartitionDynamicShape); + DynamicShapePartitioner dynamic_shape_partitioner(compute_graph); + auto ret = dynamic_shape_partitioner.Partition(); + if (ret != SUCCESS) { + GELOGE(ret, "Graph partition by dynamic shape Failed"); + return ret; + } + bool dynamic_shape_partitioned = false; + if (!AttrUtils::GetBool(*compute_graph, ATTR_NAME_DYNAMIC_SHAPE_PARTITIONED, dynamic_shape_partitioned)) { + GELOGE(FAILED, "failed get dynamic shape partitioned flag on partitioned graph."); + return FAILED; + } + GE_TIMESTAMP_END(GraphPartitionDynamicShape, "OptimizeSubgraph::GraphPartitionDynamicShape"); GE_TIMESTAMP_START(GraphPartition); - auto ret = graph_partitioner_.Partition(compute_graph, GraphPartitioner::kPartitioning); + ret = graph_partitioner_.Partition(compute_graph, 
GraphPartitioner::kPartitioning); if (ret != SUCCESS) { GELOGE(ret, "Graph partition Failed"); return ret; } - GE_TIMESTAMP_END(GraphPartition, "GraphPartitioner::Partition1"); + GE_TIMESTAMP_END(GraphPartition, "OptimizeSubgraph::Partition1"); GE_TIMESTAMP_START(SetSubgraph); ret = SetSubgraph(session_id, compute_graph); if (ret != SUCCESS) { GELOGE(ret, "Graph set subgraph Failed"); return ret; } - GE_TIMESTAMP_END(SetSubgraph, "SetSubGraph"); + GE_TIMESTAMP_END(SetSubgraph, "OptimizeSubgraph::SetSubGraph"); ComputeGraphPtr merged_compute_graph = nullptr; std::vector merged_sub_graph_list; @@ -2245,20 +2332,22 @@ Status GraphManager::OptimizeSubgraph(const GraphNodePtr &graph_node, ComputeGra GE_CHECK_NOTNULL(merged_compute_graph); merged_compute_graph->SetSessionID(session_id); merged_compute_graph->SetGraphID(graph_node->GetGraphId()); - GraphUtils::DumpGEGraph(merged_compute_graph, "mergedComputeGraph"); - GraphUtils::DumpGEGraphToOnnx(*merged_compute_graph, "mergedComputeGraph"); + merged_compute_graph->SetNeedIteration(compute_graph->GetNeedIteration()); for (auto &sub_graph : merged_compute_graph->GetAllSubgraphs()) { sub_graph->SetSessionID(session_id); sub_graph->SetGraphID(graph_node->GetGraphId()); - GraphUtils::DumpGEGraph(sub_graph, "mergedComputeGraph_subgraph"); - GraphUtils::DumpGEGraphToOnnx(*sub_graph, "mergedComputeGraph_subgraph"); } - GE_TIMESTAMP_END(MergeSubgraph, "GraphManager::MergeSubGraph"); + GE_TIMESTAMP_END(MergeSubgraph, "OptimizeSubgraph::MergeSubGraph"); + GE_DUMP(merged_compute_graph, "mergedComputeGraph"); compute_graph = merged_compute_graph; + if (!AttrUtils::SetBool(*compute_graph, ATTR_NAME_DYNAMIC_SHAPE_PARTITIONED, dynamic_shape_partitioned)) { + GELOGE(FAILED, "failed set dynamic shape partitioned flag on partitioned graph."); + return FAILED; + } return SUCCESS; } Status GraphManager::Build(const GraphNodePtr &graph_node, ComputeGraphPtr &compute_graph, - vector &ge_models, GeModelPtr &ge_model, uint64_t session_id) { + GeRootModelPtr &ge_root_model, uint64_t session_id) { // build if (compute_graph != nullptr) { std::string graph_name = compute_graph->GetName(); @@ -2267,7 +2356,7 @@ Status GraphManager::Build(const GraphNodePtr &graph_node, ComputeGraphPtr &comp compute_graph->SetName(graph_name); } std::vector sub_graph_list; - auto ret = graph_builder_.Build(compute_graph, sub_graph_list, ge_model, session_id); + auto ret = graph_builder_.Build(compute_graph, sub_graph_list, ge_root_model, session_id); if (ret != SUCCESS) { GELOGE(ret, "SubGraph build Failed."); return ret; @@ -2282,18 +2371,7 @@ Status GraphManager::Build(const GraphNodePtr &graph_node, ComputeGraphPtr &comp GraphUtils::DumpGEGraph(compute_graph, "Build", is_always_dump); GraphUtils::DumpGEGraphToOnnx(*compute_graph, "Build"); - // set modelptr to subgraph - for (const auto &sub_graph_info : sub_graph_list) { - sub_graph_info->SetGeModelPtr(ge_model); - } - - ge_models.push_back(ge_model); - - GE_IF_BOOL_EXEC(sub_graph_list.empty(), GELOGE(FAILED, "Input graph must have at least one calculation op Node"); - return FAILED;); - sub_graph_list[0]->SetSubGraph(compute_graph); - // set subgraphlist to graphnode - graph_node->SetSubGraph(sub_graph_list); + graph_node->SetGeRootModel(ge_root_model); return SUCCESS; } } // namespace ge diff --git a/src/ge/graph/manager/graph_manager.h b/src/ge/graph/manager/graph_manager.h index d13a2929..dec88cdc 100644 --- a/src/ge/graph/manager/graph_manager.h +++ b/src/ge/graph/manager/graph_manager.h @@ -99,7 +99,7 @@ class GraphManager { /// 
@param [out] models build result /// @return Status result of function /// - Status BuildGraph(const GraphId &graph_id, const std::vector &inputs, vector &models); + ge::Status BuildGraph(const GraphId &graph_id, const std::vector &inputs, GeRootModelPtr &models); /// /// @ingroup ge_graph @@ -153,6 +153,8 @@ class GraphManager { const std::map *GetGraphOptions(uint32_t graph_id); + void SetOptionsRunGraphFlag(bool run_graph_flag); + private: struct PreRunArgs { GraphId graph_id; @@ -166,7 +168,7 @@ class GraphManager { GraphNodePtr graph_node; GraphId graph_id; std::vector input_tensor; - GeModelPtr ge_model; + GeRootModelPtr ge_root_model; GEThreadLocalContext context; RunAsyncCallback callback; }; @@ -177,19 +179,16 @@ class GraphManager { static Status ProcessSubGraphWithMultiThreads(GraphManager *graph_manager, const SubGraphInfoPtr &sub_graph_info_ptr, uint64_t session_id, const GEThreadLocalContext &ge_context); - Status PreRun(const GraphNodePtr &graph_node, const std::vector &inputs, vector &ge_models, - GeModelPtr &ge_model, uint64_t session_id = INVALID_SESSION_ID); - - Status PreRunDynShape(const GraphNodePtr &graph_node, const std::vector &inputs, - vector &ge_models, GeModelPtr &ge_model, uint64_t session_id = INVALID_SESSION_ID); + Status PreRun(const GraphNodePtr &graph_node, const std::vector &inputs, GeRootModelPtr &ge_root_model, + uint64_t session_id = INVALID_SESSION_ID); Status OptimizeSubgraph(const GraphNodePtr &graph_node, ComputeGraphPtr &compute_graph, uint64_t session_id); - Status Build(const GraphNodePtr &graph_node, ComputeGraphPtr &compute_graph, vector &ge_models, - GeModelPtr &ge_model, uint64_t session_id); + Status Build(const GraphNodePtr &graph_node, ComputeGraphPtr &compute_graph, GeRootModelPtr &ge_root_model, + uint64_t session_id); Status StartForRunGraph(const GraphNodePtr &graph_node, const std::vector &inputs, - vector &ge_models, uint64_t session_id = INVALID_SESSION_ID); + GeRootModelPtr &ge_root_model, uint64_t session_id = INVALID_SESSION_ID); Status InnerRunGraph(GraphNodePtr &graph_node, const GraphId &graph_id, const std::vector &inputs, std::vector &outputs); @@ -240,6 +239,8 @@ class GraphManager { Status SetSubgraph(uint64_t session_id, ComputeGraphPtr compute_graph); + void SetAttrForHcomBroadCastOp(ge::ComputeGraphPtr &compute_graph); + bool IsBroadCastOpData(const ge::NodePtr &var_node); void AdjustBroadCastOpData(const ge::NodePtr &var_node); @@ -258,6 +259,7 @@ class GraphManager { std::shared_ptr GetGraphContext() const { return graph_context_; } Status RemoveIsolatedConst(ge::ComputeGraphPtr &compute_graph); + Status RemoveIsolatedConstInThisGraph(ge::ComputeGraphPtr &compute_graph); Status OptimizeStage1(ComputeGraphPtr &compute_graph); Status OptimizeStage2(ComputeGraphPtr &compute_graph); @@ -265,13 +267,13 @@ class GraphManager { Status NewOptimizeAfterMergeSubGraph(ge::ComputeGraphPtr &compute_graph); - Status LoadGraphAsync(const GeModelPtr &ge_model, const GraphNodePtr &graph_node); + Status LoadGraphAsync(const GeRootModelPtr &ge_root_model, const GraphNodePtr &graph_node); Status CheckAndReleaseMemory(const GeModelPtr &ge_model, const GraphNodePtr &graph_node); - bool CheckModelLoad(const GeModelPtr &ge_model, bool load_flag); + bool CheckModelLoad(const GeRootModelPtr &ge_model, bool load_flag); - Status LoadGraph(const GeModelPtr &ge_model, const GraphNodePtr &graph_node); + Status LoadGraph(const GeRootModelPtr &ge_root_model, const GraphNodePtr &graph_node); bool IsGraphNeedBuild(const GraphNodePtr &graph_node); @@ 
-287,6 +289,8 @@ class GraphManager { static void StopQueue(GraphManager *graph_manager); static void ReturnError(GraphManager *graph_manager, RunAsyncCallback callback, Status ret, const string &log); + void ChangeConstTypeWhenTraining(const ComputeGraphPtr &compute_graph); + std::atomic_bool thread_run_flag_; BlockingQueue prerun_args_q_{}; BlockingQueue run_args_q_{}; diff --git a/src/ge/graph/manager/graph_manager_utils.h b/src/ge/graph/manager/graph_manager_utils.h index b595e182..746933a9 100644 --- a/src/ge/graph/manager/graph_manager_utils.h +++ b/src/ge/graph/manager/graph_manager_utils.h @@ -36,6 +36,7 @@ #include "graph/graph.h" #include "graph/model.h" #include "model/ge_model.h" +#include "model/ge_root_model.h" #include "register/register_fmk_types.h" #include "external/ge/ge_api_types.h" @@ -160,6 +161,8 @@ class GraphNode { void SetLoadFlag(bool load_flag) { load_flag_ = load_flag; } void SetGeModel(const GeModelPtr &ge_model) { ge_model_ = ge_model; } GeModelPtr GetGeModel() const { return ge_model_; } + void SetGeRootModel(const GeRootModelPtr &ge_root_model) { ge_root_model_ = ge_root_model; } + GeRootModelPtr GetGeRootModel() const { return ge_root_model_; } const std::map &GetOptions() const { return options_; } void SetOptions(const std::map &options) { options_ = options; } void Lock(); @@ -179,6 +182,7 @@ class GraphNode { bool build_flag_; bool load_flag_; GeModelPtr ge_model_; + GeRootModelPtr ge_root_model_; BlockingQueue sem_; }; diff --git a/src/ge/graph/manager/graph_mem_allocator.cc b/src/ge/graph/manager/graph_mem_allocator.cc index 95773f11..e63039dc 100644 --- a/src/ge/graph/manager/graph_mem_allocator.cc +++ b/src/ge/graph/manager/graph_mem_allocator.cc @@ -15,6 +15,7 @@ */ #include "graph/manager/graph_mem_allocator.h" +#include "graph/manager/graph_caching_allocator.h" #include #include @@ -47,7 +48,7 @@ void MemoryAllocator::Finalize(uint32_t device_id) { memory_base_map_.clear(); } -uint8_t *MemoryAllocator::MallocMemory(const string &purpose, uint64_t memory_size, uint32_t device_id) const { +uint8_t *MemoryAllocator::MallocMemory(const string &purpose, size_t memory_size, uint32_t device_id) const { uint8_t *memory_addr = nullptr; if (rtMalloc(reinterpret_cast(&memory_addr), memory_size, memory_type_) != RT_ERROR_NONE) { @@ -74,7 +75,7 @@ Status MemoryAllocator::FreeMemory(uint8_t *memory_addr, uint32_t device_id) con return ge::SUCCESS; } -uint8_t *MemoryAllocator::MallocMemory(const string &purpose, const string &memory_key, uint64_t memory_size, +uint8_t *MemoryAllocator::MallocMemory(const string &purpose, const string &memory_key, size_t memory_size, uint32_t device_id) { auto it = memory_base_map_.find(memory_key); if (it != memory_base_map_.end()) { @@ -147,7 +148,7 @@ uint8_t *MemoryAllocator::GetMemoryAddr(const string &memory_key, uint32_t devic return it->second.memory_addr_; } -MemManager::MemManager() : default_memory_allocator_(nullptr) {} +MemManager::MemManager() {} MemManager::~MemManager() { Finalize(); } @@ -159,7 +160,7 @@ MemManager &MemManager::Instance() { MemoryAllocator *MemManager::Instance(rtMemType_t memory_type) { return Instance().GetMemoryAllocator(memory_type); } Status MemManager::Initialize(const std::vector &memory_type) { - std::lock_guard lock(allocator_mutex_); + std::lock_guard lock(allocator_mutex_); MemoryAllocator *memory_allocator = nullptr; for (unsigned int index : memory_type) { auto it = memory_allocator_map_.find(index); @@ -184,34 +185,34 @@ Status MemManager::Initialize(const std::vector 
&memory_type) { } } - default_memory_allocator_ = new (std::nothrow) MemoryAllocator(RT_MEMORY_RESERVED); - if (default_memory_allocator_ == nullptr) { - GELOGE(ge::INTERNAL_ERROR, "Create MemoryAllocator failed."); - return ge::INTERNAL_ERROR; - } - return ge::SUCCESS; + return InitCachingAllocator(memory_type); } void MemManager::Finalize() noexcept { GELOGI("Finalize."); - std::lock_guard lock(allocator_mutex_); + std::lock_guard lock(allocator_mutex_); + // caching allocator use memory allocator, so finalize it first + for (auto &caching_allocator : caching_allocator_map_) { + if (caching_allocator.second != nullptr) { + caching_allocator.second->Finalize(); + delete caching_allocator.second; + caching_allocator.second = nullptr; + } + } + caching_allocator_map_.clear(); + for (auto &memory_allocator : memory_allocator_map_) { if (memory_allocator.second != nullptr) { - memory_allocator.second->Finalize(0); + memory_allocator.second->Finalize(); delete memory_allocator.second; memory_allocator.second = nullptr; } } - - if (default_memory_allocator_ != nullptr) { - delete default_memory_allocator_; - default_memory_allocator_ = nullptr; - } memory_allocator_map_.clear(); } MemoryAllocator *MemManager::GetMemoryAllocator(rtMemType_t memory_type) { - std::lock_guard lock(allocator_mutex_); + std::lock_guard lock(allocator_mutex_); MemoryAllocator *memory_allocator = nullptr; auto it = memory_allocator_map_.find(memory_type); if (it != memory_allocator_map_.end()) { @@ -221,9 +222,60 @@ MemoryAllocator *MemManager::GetMemoryAllocator(rtMemType_t memory_type) { // Usually impossible if (memory_allocator == nullptr) { GELOGE(ge::INTERNAL_ERROR, "GetMemoryAllocator failed, memory type is %u.", memory_type); - return default_memory_allocator_; + static MemoryAllocator default_memory_allocator(RT_MEMORY_RESERVED); + return &default_memory_allocator; } return memory_allocator; } + +Status MemManager::InitCachingAllocator(const std::vector &memory_type) { + CachingAllocator *caching_allocator = nullptr; + for (unsigned int index : memory_type) { + auto it = caching_allocator_map_.find(index); + if (it == caching_allocator_map_.end()) { + caching_allocator = new (std::nothrow) CachingAllocator(index); + if (caching_allocator != nullptr) { + caching_allocator_map_[index] = caching_allocator; + GELOGI("Create CachingAllocator memory type[%u] success.", index); + } else { + GELOGE(ge::INTERNAL_ERROR, "Alloc CachingAllocator failed."); + } + } else { + caching_allocator = it->second; + } + + if (caching_allocator == nullptr) { + GELOGE(ge::INTERNAL_ERROR, "Create CachingAllocator failed."); + return ge::INTERNAL_ERROR; + } else { + if (caching_allocator->Initialize() != ge::SUCCESS) { + return ge::INTERNAL_ERROR; + } + } + } + return ge::SUCCESS; +} + +CachingAllocator &MemManager::GetCachingAllocator(rtMemType_t memory_type) { + std::lock_guard lock(allocator_mutex_); + CachingAllocator *caching_allocator = nullptr; + auto it = caching_allocator_map_.find(memory_type); + if (it != caching_allocator_map_.end()) { + caching_allocator = it->second; + } + + // Usually impossible + if (caching_allocator == nullptr) { + GELOGE(ge::INTERNAL_ERROR, "GetCachingAllocator failed, memory type is %u.", memory_type); + static CachingAllocator default_caching_allocator(RT_MEMORY_RESERVED); + return default_caching_allocator; + ; + } + return *caching_allocator; +} + +CachingAllocator &MemManager::CachingInstance(rtMemType_t memory_type) { + return Instance().GetCachingAllocator(memory_type); +} } // namespace ge 
diff --git a/src/ge/graph/manager/graph_mem_allocator.h b/src/ge/graph/manager/graph_mem_allocator.h
index 9622e07a..7bf82897 100644
--- a/src/ge/graph/manager/graph_mem_allocator.h
+++ b/src/ge/graph/manager/graph_mem_allocator.h
@@ -88,7 +88,7 @@ class MemoryAllocator {
   /// @param [in] device_id device id
   /// @return memory address
   ///
-  uint8_t *MallocMemory(const string &purpose, uint64_t memory_size, uint32_t device_id = 0) const;
+  uint8_t *MallocMemory(const string &purpose, size_t memory_size, uint32_t device_id = 0) const;

   ///
   /// @ingroup ge_graph
@@ -108,7 +108,7 @@ class MemoryAllocator {
   /// @param [in] device_id device id
   /// @return memory address
   ///
-  uint8_t *MallocMemory(const string &purpose, const string &memory_key, uint64_t memory_size, uint32_t device_id = 0);
+  uint8_t *MallocMemory(const string &purpose, const string &memory_key, size_t memory_size, uint32_t device_id = 0);

   ///
   /// @ingroup ge_graph
@@ -135,6 +135,7 @@ class MemoryAllocator {
 };

 using MemoryAllocatorPtr = std::shared_ptr<MemoryAllocator>;
+class CachingAllocator;

 class MemManager {
  public:
@@ -142,6 +143,7 @@ class MemManager {
   virtual ~MemManager();
   static MemManager &Instance();
   static MemoryAllocator *Instance(rtMemType_t memory_type);
+  static CachingAllocator &CachingInstance(rtMemType_t memory_type);
   MemManager(const MemManager &) = delete;
   MemManager &operator=(const MemManager &) = delete;
   ///
@@ -164,13 +166,29 @@ class MemManager {
   /// @ingroup ge_graph
   /// @brief ge memory allocator
   /// @param [in] memory_type memory type
-  /// @return Status result of function
+  /// @return MemoryAllocator ptr
   ///
   MemoryAllocator *GetMemoryAllocator(rtMemType_t memory_type);

+  ///
+  /// @ingroup ge_graph
+  /// @brief ge caching allocator
+  /// @param [in] memory_type memory type
+  /// @return CachingAllocator ptr
+  ///
+  CachingAllocator &GetCachingAllocator(rtMemType_t memory_type);
+
+  ///
+  /// @ingroup ge_graph
+  /// @brief ge create caching allocator
+  /// @param [in] memory_type memory type
+  /// @return Status result of function
+  ///
+  Status InitCachingAllocator(const std::vector<rtMemType_t> &memory_type);
+
   std::map<rtMemType_t, MemoryAllocator *> memory_allocator_map_;
-  MemoryAllocator *default_memory_allocator_;
-  std::mutex allocator_mutex_;
+  std::map<rtMemType_t, CachingAllocator *> caching_allocator_map_;
+  std::recursive_mutex allocator_mutex_;
 };
};  // namespace ge
diff --git a/src/ge/graph/manager/trans_var_data_utils.cc b/src/ge/graph/manager/trans_var_data_utils.cc
index 6109b120..e8444c53 100644
--- a/src/ge/graph/manager/trans_var_data_utils.cc
+++ b/src/ge/graph/manager/trans_var_data_utils.cc
@@ -25,8 +25,343 @@
 #include "graph/manager/graph_var_manager.h"
 #include "graph/types.h"
 #include "graph/utils/type_utils.h"
+#include "common/thread_pool.h"
+#include

 namespace ge {
+namespace {
+class RtContextSwitchGuard {
+ public:
+  RtContextSwitchGuard(rtCtxMode_t mode, uint32_t device_id) : last_(nullptr), current_(nullptr) {
+    auto ret = rtCtxGetCurrent(&last_);
+    if (ret != RT_ERROR_NONE) {
+      GELOGE(RT_FAILED, "Failed to get current context from rt, error-code %d", ret);
+      return;
+    }
+
+    ret = rtCtxCreate(&current_, mode, static_cast<int32_t>(device_id));
+    if (ret != RT_ERROR_NONE) {
+      GELOGE(RT_FAILED, "Failed to create new context for device %u, error-code %d", device_id, ret);
+      return;
+    }
+
+    ret = rtCtxSetCurrent(current_);
+    if (ret != RT_ERROR_NONE) {
+      GELOGE(RT_FAILED, "Failed to switch context to normal, context %p, device %u", current_, device_id);
+      return;
+    }
+    GELOGD("Create and switch rt context %p type %d for device %u, backup last %p.", current_, mode, device_id,
+           last_);
+  }
+
+  ~RtContextSwitchGuard() {
+    if (current_ != nullptr) {
+      auto ret = rtCtxDestroy(current_);
+      GELOGD("Destroy current context %p result %d", current_, ret);
+    }
+    if (last_ != nullptr) {
+      auto ret = rtCtxSetCurrent(last_);
+      GELOGD("Recover last context %p result %d.", last_, ret);
+    }
+  }
+
+ private:
+  rtContext_t last_;
+  rtContext_t current_;
+};
last_);
+  }
+
+  ~RtContextSwitchGuard() {
+    if (current_ != nullptr) {
+      auto ret = rtCtxDestroy(current_);
+      GELOGD("Destroy current context %p result %d", current_, ret);
+    }
+    if (last_ != nullptr) {
+      auto ret = rtCtxSetCurrent(last_);
+      GELOGD("Restore last context %p result %d.", last_, ret);
+    }
+  }
+
+ private:
+  rtContext_t last_;
+  rtContext_t current_;
+};
+
+int64_t CalcVarSizeInBytes(const GeTensorDesc &desc) {
+  int64_t var_size = GetSizeByDataType(desc.GetDataType());
+  if (var_size <= 0) {
+    GELOGE(PARAM_INVALID, "Failed to calc var data size from data type %s",
+           TypeUtils::DataTypeToSerialString(desc.GetDataType()).c_str());
+    return -1;
+  }
+  auto shape = desc.GetShape();
+  auto dim_num = shape.GetDimNum();
+  for (size_t dim_index = 0; dim_index < dim_num; ++dim_index) {
+    var_size *= shape.GetDim(dim_index);
+  }
+  return var_size;
+}
+
+Status CopyVarToDevice(const NodePtr &var, const formats::TransResult &trans_result, void *var_addr) {
+  GELOGD("Copy var %s from host to device, size %zu", var->GetName().c_str(), trans_result.length);
+  auto ret = rtMemcpy(var_addr, trans_result.length, reinterpret_cast<void *>(trans_result.data.get()),
+                      trans_result.length, RT_MEMCPY_HOST_TO_DEVICE);
+  if (ret != RT_ERROR_NONE) {
+    GELOGE(RT_FAILED, "Failed to copy memory to device, size %zu", trans_result.length);
+    return RT_FAILED;
+  }
+  return SUCCESS;
+}
+
+Status CopyVarFromDevice(uint64_t session_id, const NodePtr &var, std::unique_ptr<uint8_t[]> &var_data,
+                         const GeTensorDesc &input_desc) {
+  uint8_t *var_logic = nullptr;
+  GE_CHECK_NOTNULL(var);
+  auto ret = VarManager::Instance(session_id)->GetVarAddr(var->GetName(), input_desc, &var_logic);
+  if (ret != SUCCESS) {
+    GELOGE(INTERNAL_ERROR,
+           "Failed to copy var %s from device, can not find it"
+           " from var manager %u",
+           var->GetName().c_str(), ret);
+    return INTERNAL_ERROR;
+  }
+
+  uint8_t *var_addr = VarManager::Instance(session_id)->GetVarMemoryAddr(var_logic, RT_MEMORY_HBM);
+  if (var_addr == nullptr) {
+    GELOGE(INTERNAL_ERROR,
+           "Failed to copy var %s from device, cannot get "
+           "var addr from logic addr %p",
+           var->GetName().c_str(), var_logic);
+    return INTERNAL_ERROR;
+  }
+
+  int64_t var_size_bytes = CalcVarSizeInBytes(input_desc);
+  if (var_size_bytes <= 0) {
+    return INTERNAL_ERROR;
+  }
+
+  std::unique_ptr<uint8_t[]> var_host(new (std::nothrow) uint8_t[var_size_bytes]);
+  if (var_host == nullptr) {
+    GELOGE(OUT_OF_MEMORY, "Failed to malloc rt-host memory, size %ld", var_size_bytes);
+    return OUT_OF_MEMORY;
+  }
+
+  ret = rtMemcpy(reinterpret_cast<void *>(var_host.get()), var_size_bytes, reinterpret_cast<void *>(var_addr),
+                 var_size_bytes, RT_MEMCPY_DEVICE_TO_HOST);
+  if (ret != RT_ERROR_NONE) {
+    GELOGE(RT_FAILED,
+           "Failed to copy var memory from device, var %s, size %ld,"
+           " rt-error-code %u",
+           var->GetName().c_str(), var_size_bytes, ret);
+    return RT_FAILED;
+  }
+
+  GELOGD("Copy var %s from device to host, size %ld", var->GetName().c_str(), var_size_bytes);
+  var_data.swap(var_host);
+
+  GELOGI("var_logic:%p, var_addr:%p", var_logic, var_addr);
+
+  return SUCCESS;
+}
+
+Status TransVarOnHost(uint8_t *var_data, const VarTransRoad &trans_road, formats::TransResult &result) {
+  formats::TransResult result_last_time{};
+  bool use_init_data = true;
+  for (const auto &trans_info : trans_road) {
+    if (trans_info.node_type == RESHAPE || trans_info.node_type == REFORMAT) {
+      GELOGD("Skip trans of variable data on the reshape/reformat node");
+      continue;
+    }
+    uint8_t *src_data = nullptr;
+    if (use_init_data) {
+      src_data = var_data;
+      use_init_data = false;
+    } else {
+      src_data = result_last_time.data.get();
+    }
+
+    formats::TransResult tmp_result{};
+    if (trans_info.node_type == TRANSDATA || trans_info.node_type == TRANSPOSED) {
+      auto src_format = trans_info.input.GetFormat();
+      auto src_shape = trans_info.input.GetShape().GetDims();
+      auto dst_format = trans_info.output.GetFormat();
+      auto dst_shape = trans_info.output.GetShape().GetDims();
+      auto data_type = trans_info.input.GetDataType();
+      GELOGD("Trans format from %s to %s, shape %s to %s, data-type %s",
+             TypeUtils::FormatToSerialString(src_format).c_str(), TypeUtils::FormatToSerialString(dst_format).c_str(),
+             formats::ShapeToString(src_shape).c_str(), formats::ShapeToString(dst_shape).c_str(),
+             TypeUtils::DataTypeToSerialString(data_type).c_str());
+      auto ret = formats::TransFormat({src_data, src_format, dst_format, src_shape, dst_shape, data_type}, tmp_result);
+      if (ret != SUCCESS) {
+        GELOGE(INTERNAL_ERROR,
+               "Failed to trans format from %s to %s, shape %s to %s, "
+               "data type %s error code %u",
+               TypeUtils::FormatToSerialString(src_format).c_str(), TypeUtils::FormatToSerialString(dst_format).c_str(),
+               formats::ShapeToString(src_shape).c_str(), formats::ShapeToString(dst_shape).c_str(),
+               TypeUtils::DataTypeToSerialString(data_type).c_str(), ret);
+        return ret;
+      }
+    } else if (trans_info.node_type == CAST) {
+      auto input_shape = trans_info.input.GetShape();
+      auto src_data_size = input_shape.GetShapeSize() == 0 ? 1 : input_shape.GetShapeSize();
+      auto src_data_type = trans_info.input.GetDataType();
+      auto dst_data_type = trans_info.output.GetDataType();
+      GELOGD("Trans data type from %s to %s, input shape %s, data size %ld",
+             TypeUtils::DataTypeToSerialString(src_data_type).c_str(),
+             TypeUtils::DataTypeToSerialString(dst_data_type).c_str(), formats::ShapeToString(input_shape).c_str(),
+             src_data_size);
+      auto ret = formats::TransDataType({src_data, static_cast<size_t>(src_data_size), src_data_type, dst_data_type},
+                                        tmp_result);
+      if (ret != SUCCESS) {
+        GELOGE(INTERNAL_ERROR, "Failed to trans data type from %s to %s, input shape %s, data size %ld, error code %u",
+               TypeUtils::DataTypeToSerialString(src_data_type).c_str(),
+               TypeUtils::DataTypeToSerialString(dst_data_type).c_str(), formats::ShapeToString(input_shape).c_str(),
+               src_data_size, ret);
+        return ret;
+      }
+    } else {
+      GELOGE(UNSUPPORTED, "Failed to trans var data, the trans type %s is not supported",
+             trans_info.node_type.c_str());
+      return UNSUPPORTED;
+    }
+    result_last_time = tmp_result;
+  }
+
+  result = result_last_time;
+  return SUCCESS;
+}
+
+/// Re-allocate var memory on device using the var manager and free the origin
+/// var memory (freeing is not supported by the var manager yet).
+/// @param session_id
+/// @param var_name
+/// @param tensor_desc
+/// @param var_device
+/// @return
+Status ReAssignVarAddr(uint64_t session_id, const std::string &var_name, const GeTensorDesc &tensor_desc,
+                       void **var_device) {
+  uint8_t *var_logic = nullptr;
+  Status ret = VarManager::Instance(session_id)->GetVarAddr(var_name, tensor_desc, &var_logic);
+  if (ret != SUCCESS) {
+    GELOGE(INTERNAL_ERROR,
+           "Failed to get var %s device addr, can not find it"
+           " from var manager %u",
+           var_name.c_str(), ret);
+    return INTERNAL_ERROR;
+  }
+
+  uint8_t *var_addr = VarManager::Instance(session_id)->GetVarMemoryAddr(var_logic, RT_MEMORY_HBM);
+  if (var_addr == nullptr) {
+    GELOGE(INTERNAL_ERROR, "Failed to convert var %s logic addr to real addr", var_name.c_str());
+    return INTERNAL_ERROR;
+  }
+  *var_device = var_addr;
+
+  GELOGI("var_logic:%p, var_addr:%p", var_logic,
var_addr); + + return SUCCESS; +} + +Status TransVarData(const NodePtr &var, const VarTransRoad &trans_road, uint64_t session_id) { + // do not need to do anything if only all reshape/reformat node on the trans_road + GE_CHECK_NOTNULL(var); + bool need_trans = false; + for (auto &road : trans_road) { + if (road.node_type != RESHAPE && road.node_type != REFORMAT) { + need_trans = true; + break; + } + } + if (!need_trans) { + return SUCCESS; + } + + // Sync var data from device + std::unique_ptr var_data; + if (trans_road.empty()) { + GELOGE(INTERNAL_ERROR, "Failed to get trans_road, trans_road is empty."); + return INTERNAL_ERROR; + } + const GeTensorDesc &input_desc = trans_road.begin()->input; + auto ret = CopyVarFromDevice(session_id, var, var_data, input_desc); + if (ret != SUCCESS) { + return ret; + } + + formats::TransResult trans_result{}; + ret = TransVarOnHost(var_data.get(), trans_road, trans_result); + if (ret != SUCCESS) { + GELOGE(ret, "Failed to trans var data on host, error code %u", ret); + return ret; + } + + void *var_device = nullptr; + + /// It is a temporary solution to use the last GeTensorDesc to assign variable memory because the variable manager + /// depends on TensorDesc and it is difficult to be modified. The correct solution is to assign memory based on the + /// size of the converted variable. To complete the final solution, the dependency of the variable manager on + /// TensorDesc needs to be removed. This change is large and needs to be performed step by step. + ret = ReAssignVarAddr(session_id, var->GetName(), trans_road.rbegin()->output, &var_device); + if (ret != SUCCESS) { + GELOGE(ret, "Failed to re-assign memory on device, size %zu", trans_result.length); + return ret; + } + + // sync new data to device + ret = CopyVarToDevice(var, trans_result, var_device); + if (ret != SUCCESS) { + GELOGE(ret, "Failed to send var data to device"); + return ret; + } + + return SUCCESS; +} + +Status TransTensor(uint8_t *var_data, const NodePtr &var_src, const NodePtr &var_dst, formats::TransResult &result) { + GE_CHECK_NOTNULL(var_src); + GE_CHECK_NOTNULL(var_src->GetOpDesc()); + GE_CHECK_NOTNULL(var_dst); + GE_CHECK_NOTNULL(var_dst->GetOpDesc()); + auto src_data_shape_size = var_src->GetOpDesc()->GetOutputDesc(0).GetShape().GetShapeSize(); + auto src_data_datatype = var_src->GetOpDesc()->GetOutputDesc(0).GetDataType(); + auto dst_data_datatype = var_dst->GetOpDesc()->GetOutputDesc(0).GetDataType(); + GE_IF_BOOL_EXEC( + src_data_datatype != dst_data_datatype, + auto ret = formats::TransDataType( + {var_data, static_cast(src_data_shape_size), src_data_datatype, dst_data_datatype}, result); + if (ret != SUCCESS) { + GELOGE(INTERNAL_ERROR, "trans var data on host failed"); + return ret; + }); + return SUCCESS; +} + +Status CopyTensorFromSrcVarNode(const NodePtr &var_src, const NodePtr &var_dst, uint64_t session_id, + uint32_t device_id) { + /// after FE fusion pass, input num of applymomentum op was changed, 0th input is var_fp32, 6th input is + /// var_fp16(new). + /// unlink edges between var_fp32 and "dst_node" (need fp16) of var_fp32, add edge between var_fp16 and dst_node. + /// need copy value from var_fp32 to var_fp16. 
+  /// [opdesc of var_src and var_dst are checked before passed in, no need to check if they are nullptr]
+  GE_IF_BOOL_EXEC(var_src == nullptr || var_dst == nullptr, GELOGE(FAILED, "var node is nullptr"); return FAILED);
+  // src_node output_desc (fp32)
+  GeTensorDesc output_desc = var_src->GetOpDesc()->GetOutputDesc(0);
+  auto src_data_type = output_desc.GetDataType();
+  auto src_shape = output_desc.GetShape();
+  auto src_format = output_desc.GetFormat();
+  GELOGI("src_node %s, src_format %s, src_shape %s, src_type %s", var_src->GetName().c_str(),
+         TypeUtils::FormatToSerialString(src_format).c_str(), formats::ShapeToString(src_shape).c_str(),
+         TypeUtils::DataTypeToSerialString(src_data_type).c_str());
+  // dst_node output_desc (fp16)
+  GeTensorDesc dst_tensor_desc = var_dst->GetOpDesc()->GetOutputDesc(0);
+  auto data_type = dst_tensor_desc.GetDataType();
+  auto data_shape = dst_tensor_desc.GetShape();
+  auto data_format = dst_tensor_desc.GetFormat();
+  GELOGI("dst_node %s, dst_format %s, dst_shape %s, dst_type %s", var_dst->GetName().c_str(),
+         TypeUtils::FormatToSerialString(data_format).c_str(), formats::ShapeToString(data_shape).c_str(),
+         TypeUtils::DataTypeToSerialString(data_type).c_str());
+  // Sync var data from device
+  std::unique_ptr<uint8_t[]> var_src_data;
+  RtContextSwitchGuard switch_context(RT_CTX_NORMAL_MODE, device_id);
+  // copy from src_node
+  auto ret = CopyVarFromDevice(session_id, var_src, var_src_data, output_desc);
+  GE_IF_BOOL_EXEC(ret != SUCCESS, GELOGE(FAILED, "Copy Var From Device failed"); return ret);
+  // trans dtype
+  formats::TransResult trans_result{};
+  ret = TransTensor(var_src_data.get(), var_src, var_dst, trans_result);
+  GE_IF_BOOL_EXEC(ret != SUCCESS, GELOGE(INTERNAL_ERROR, "trans var data on host failed"); return ret);
+  // reset src value.
+  void *var_device = nullptr;
+  ret = ReAssignVarAddr(session_id, var_dst->GetName(), dst_tensor_desc, &var_device);
+  GE_IF_BOOL_EXEC(ret != SUCCESS, GELOGE(INTERNAL_ERROR, "assign mem failed"); return ret);
+  // copy to device
+  ret = CopyVarToDevice(var_dst, trans_result, var_device);
+  GE_IF_BOOL_EXEC(ret != SUCCESS, GELOGE(ret, "Failed to send var data to device"); return ret);
+  return SUCCESS;
+}
+}  // namespace
 
 Status TransVarDataUtils::SyncVarData2BroadCast(const string &var_name, const ge::GeTensorDesc &src_tensor_desc,
                                                 uint8_t *dst_addr, int64_t dst_addr_size, uint64_t session_id) {
   GE_CHK_BOOL_RET_STATUS(dst_addr != nullptr, FAILED, "dst addr is null.
"); @@ -88,4 +423,101 @@ Status TransVarDataUtils::SyncTensorToDevice(const string &var_name, const uint8 return SUCCESS; } + +Status TransVarDataUtils::TransAllVarData(const vector &variable_nodes, uint64_t session_id, + rtContext_t context, uint32_t graph_id, uint32_t thread_num) { + ThreadPool executor(thread_num); + std::vector> vector_future; + for (auto &node : variable_nodes) { + if (node == nullptr) { + continue; + } + + if (node->GetType() != VARIABLE) { + continue; + } + + std::future f = executor.commit( + [](const ge::NodePtr &node, uint64_t session_id, rtContext_t ctx, uint32_t graph_id) -> Status { + rtError_t rt_ret = rtCtxSetCurrent(ctx); + if (rt_ret != RT_ERROR_NONE) { + GELOGE(RT_FAILED, "Failed to set context, error_code is: 0x%X.", rt_ret); + return RT_FAILED; + } + uint32_t allocated_graph_id = 0; + Status ret = VarManager::Instance(session_id)->GetAllocatedGraphId(node->GetName(), allocated_graph_id); + if (ret != SUCCESS) { + GELOGE(INTERNAL_ERROR, "var has not been allocated, node:%s, graph_id:%u.", node->GetName().c_str(), + graph_id); + return INTERNAL_ERROR; + } + uint32_t changed_graph_id = 0; + ret = VarManager::Instance(session_id)->GetChangedGraphId(node->GetName(), changed_graph_id); + bool call_trans_var = + (ret == SUCCESS && changed_graph_id == graph_id && changed_graph_id != allocated_graph_id); + if (call_trans_var) { + GELOGI("VarManager::GetChangedGraphId() success, node:%s, graph_id:%u.", node->GetName().c_str(), graph_id); + VarTransRoad *trans_road = VarManager::Instance(session_id)->GetTransRoad(node->GetName()); + if (trans_road == nullptr) { + GELOGI("The variable %s does not have any trans road", node->GetName().c_str()); + return SUCCESS; + } + ret = TransVarData(node, *trans_road, session_id); + if (ret != SUCCESS) { + GELOGE(INTERNAL_ERROR, "TransVarData failed, node:%s, graph_id:%u.", node->GetName().c_str(), graph_id); + return INTERNAL_ERROR; + } + VarManager::Instance(session_id)->RemoveChangedGraphId(node->GetName()); + } + return SUCCESS; + }, + node, session_id, context, graph_id); + if (!f.valid()) { + GELOGE(FAILED, "Future is invalid"); + return FAILED; + } + vector_future.push_back(std::move(f)); + } + + Status ret_status; + for (size_t i = 0; i < vector_future.size(); ++i) { + ret_status = vector_future[i].get(); + if (ret_status != SUCCESS) { + GELOGE(ret_status, "TransAllVarData:: trans %zu vardata failed", i); + return ret_status; + } + } + + return SUCCESS; +} + +Status TransVarDataUtils::CopyVarData(const ComputeGraphPtr &compute_graph, uint64_t session_id, uint32_t device_id) { + GELOGI("CopyVarData start: session_id:%lu.", session_id); + if (compute_graph == nullptr) { + GELOGE(FAILED, "compute_graph is nullptr"); + return FAILED; + } + + string cp_from_node; + bool copy_value = false; + for (auto &node : compute_graph->GetAllNodes()) { + GE_IF_BOOL_EXEC(node->GetOpDesc() == nullptr || node->GetOpDesc()->GetType() != VARIABLE, continue); + GE_IF_BOOL_EXEC(ge::AttrUtils::GetStr(node->GetOpDesc(), "_copy_from_var_node", cp_from_node), + GELOGI("Get original type of cp_from_node")); + if (cp_from_node.length() != 0) { + (void)ge::AttrUtils::GetBool(node->GetOpDesc(), "_copy_value", copy_value); // no need to check value + if (!copy_value) { + auto src_node = compute_graph->FindNode(cp_from_node); + GE_CHECK_NOTNULL(src_node); + GELOGI("current_var_node__: [%s] copy_from_var_node__: [%s].", node->GetName().c_str(), + src_node->GetName().c_str()); + auto ret = CopyTensorFromSrcVarNode(src_node, node, session_id, device_id); + 
GE_IF_BOOL_EXEC(ret != SUCCESS, GELOGE(FAILED, "copy tensor failed!"); return FAILED); + // only copy once + (void)ge::AttrUtils::SetBool(node->GetOpDesc(), "_copy_value", true); // no need to check value + } + } + } + return SUCCESS; +} } // namespace ge diff --git a/src/ge/graph/manager/trans_var_data_utils.h b/src/ge/graph/manager/trans_var_data_utils.h index 69521dab..efdfa51f 100644 --- a/src/ge/graph/manager/trans_var_data_utils.h +++ b/src/ge/graph/manager/trans_var_data_utils.h @@ -22,6 +22,9 @@ #include "framework/common/ge_inner_error_codes.h" #include "framework/common/ge_types.h" #include "graph/utils/tensor_utils.h" +#include "graph/node.h" +#include "runtime/context.h" +#include "graph_var_manager.h" namespace ge { class TransVarDataUtils { @@ -31,6 +34,11 @@ class TransVarDataUtils { static ge::Status SyncBroadCastData2Var(uint8_t *src_addr, int64_t src_addr_size, const string &var_name, const ge::GeTensorDesc &dst_tensor_desc, uint64_t session_id_); + static ge::Status TransAllVarData(const std::vector &variable_nodes, uint64_t session_id, + rtContext_t context, uint32_t graph_id, uint32_t thread_num = 16); + + static ge::Status CopyVarData(const ComputeGraphPtr &compute_graph, uint64_t session_id, uint32_t device_id); + private: static ge::Status SyncTensorToHost(const string &var_name, const ge::GeTensorDesc &src_tensor_desc, uint8_t **host_addr, int64_t &addr_size, uint64_t session_id_); diff --git a/src/ge/graph/manager/util/hcom_util.cc b/src/ge/graph/manager/util/hcom_util.cc index a1c4d769..4f6fe591 100644 --- a/src/ge/graph/manager/util/hcom_util.cc +++ b/src/ge/graph/manager/util/hcom_util.cc @@ -24,35 +24,49 @@ #include "graph/utils/type_utils.h" namespace ge { -Status HcomOmeUtil::GetHcomDataType(const ge::ConstOpDescPtr &op_desc, hcclDataType_t &data_type) { + +Status HcomOmeUtil::GetHcclDataType(const ge::ConstOpDescPtr &op_desc, + std::vector &kernel_hccl_infos) { GE_CHECK_NOTNULL(op_desc); + if (CheckKernelHcclInfo(op_desc, kernel_hccl_infos) != SUCCESS) { + GELOGE(PARAM_INVALID, "HcomOmeUtil:: the number of GETaskKernelHcclInfo is invalid."); + return PARAM_INVALID; + } + GELOGI("GetHcclDataType start, node[%s], opType[%s].", op_desc->GetName().c_str(), op_desc->GetType().c_str()); + if (op_desc->GetType() == HVDWAIT) { + return SUCCESS; + } ge::DataType src_data_type = ge::DT_FLOAT; - if (op_desc->GetType() == HCOMRECEIVE) { - bool ret = ge::AttrUtils::GetDataType(op_desc, HCOM_ATTR_DATA_TYPE, src_data_type); - if (ret == false) { - GELOGE(PARAM_INVALID, "op:HcomReceive, op desc no attr: dtype."); + for (size_t i = 0; i < kernel_hccl_infos.size(); i++) { + if (op_desc->GetType() == HCOMRECEIVE) { + bool ret = ge::AttrUtils::GetDataType(op_desc, HCOM_ATTR_DATA_TYPE, src_data_type); + if (ret == false) { + GELOGE(PARAM_INVALID, "op:HcomReceive, op desc no attr: dtype."); + return PARAM_INVALID; + } + } else { + auto input_desc_ptr = op_desc->GetInputDescPtr(i); + GE_CHECK_NOTNULL(input_desc_ptr); + src_data_type = input_desc_ptr->GetDataType(); + } + + auto iter = kConstOpHcclDataType.find(static_cast(src_data_type)); + if (iter == kConstOpHcclDataType.end()) { + GELOGE(PARAM_INVALID, + "HcomOmeUtil:: Node: %s Optype: %s HcomDataType cann't support! 
Current Davinci Data Type : %s", + op_desc->GetName().c_str(), op_desc->GetType().c_str(), + ge::TypeUtils::DataTypeToSerialString(src_data_type).c_str()); return PARAM_INVALID; } - } else { - auto input_desc_ptr = op_desc->GetInputDescPtr(0); - GE_CHECK_NOTNULL(input_desc_ptr); - src_data_type = input_desc_ptr->GetDataType(); - } - auto iter = kConstOpHcomDataType.find(static_cast(src_data_type)); - if (iter == kConstOpHcomDataType.end()) { - GELOGE(PARAM_INVALID, "HcomOmeUtil:: HcomDataType cann't support! Current Davinci Data Type : %s", - ge::TypeUtils::DataTypeToSerialString(src_data_type).c_str()); - return PARAM_INVALID; + kernel_hccl_infos[i].dataType = iter->second; } - - data_type = iter->second; return SUCCESS; } -Status HcomOmeUtil::GetHcomTypeSize(hcclDataType_t data_type, int32_t &size) { - auto iter = kConstOpHcomDataTypeSize.find(data_type); - GE_CHK_BOOL_EXEC(iter != kConstOpHcomDataTypeSize.end(), return PARAM_INVALID, +Status HcomOmeUtil::GetHcclTypeSize(hcclDataType_t data_type, int32_t &size) { + auto iter = kConstOpHcclDataTypeSize.find(data_type); + GE_CHK_BOOL_EXEC(iter != kConstOpHcclDataTypeSize.end(), return PARAM_INVALID, "HcomOmeUtil::HcomDataTypeSize , No DataTypeSize!"); size = iter->second; @@ -62,10 +76,14 @@ Status HcomOmeUtil::GetHcomTypeSize(hcclDataType_t data_type, int32_t &size) { Status HcomOmeUtil::GetHcomCount(const ge::ConstOpDescPtr &op_desc, hcclDataType_t data_type, bool is_allgather, int &count) { GE_CHECK_NOTNULL(op_desc); + if (!IsHCOMOp(op_desc->GetType())) { + GELOGE(PARAM_INVALID, "HcomOmeUtil:: operator is not Hcom operator."); + return PARAM_INVALID; + } int64_t total_size = 0; int64_t align_size = 512; int32_t size = 0; - GE_CHK_STATUS_RET(HcomOmeUtil::GetHcomTypeSize(data_type, size), "GetHcomCount: GetHcomTypeSize fail!"); + GE_CHK_STATUS_RET(HcomOmeUtil::GetHcclTypeSize(data_type, size), "GetHcomCount: GetHcclTypeSize fail!"); if (op_desc->GetType() == HCOMRECEIVE) { vector shape_dims; bool ret = ge::AttrUtils::GetListInt(op_desc, HCOM_ATTR_SHAPE, shape_dims); @@ -114,34 +132,207 @@ Status HcomOmeUtil::GetHcomCount(const ge::ConstOpDescPtr &op_desc, hcclDataType return SUCCESS; } -Status HcomOmeUtil::GetHcomOperationType(const ge::ConstOpDescPtr &op_desc, hcclRedOp_t &op_type) { +Status HcomOmeUtil::GetHorovodCount(const ge::ConstOpDescPtr &op_desc, + std::vector &kernel_hccl_infos) { GE_CHECK_NOTNULL(op_desc); + if (!IsHorovodOp(op_desc->GetType())) { + GELOGE(PARAM_INVALID, "HcomOmeUtil:: operator is not Horovod operator."); + return PARAM_INVALID; + } + int64_t align_size = 512; + int32_t size = 0; + for (size_t i = 0; i < op_desc->GetInputsSize(); i++) { + GE_CHK_STATUS_RET(HcomOmeUtil::GetHcclTypeSize(static_cast(kernel_hccl_infos[i].dataType), size), + "GetHorovodCount: GetHcclTypeSize fail!"); + int64_t input_size = 0; + int64_t block_size = 0; + GE_CHECK_NOTNULL(op_desc->GetInputDescPtr(i)); + GE_CHK_STATUS_RET(ge::TensorUtils::GetSize(*op_desc->GetInputDescPtr(i), input_size), + "get size from TensorDesc failed, op : %s, input index : %zu", op_desc->GetName().c_str(), i); - std::string hcom_op_type; - GE_CHK_BOOL_EXEC(ge::AttrUtils::GetStr(op_desc, HCOM_ATTR_REDUCE_TYPE, hcom_op_type), return PARAM_INVALID, - "HcomOmeUtil::Get HCOM_ATTR_REDUCE_TYPE fail, not support!"); - - if (hcom_op_type == "min") { - op_type = HCCL_REP_OP_MIN; - } else if (hcom_op_type == "max") { - op_type = HCCL_REP_OP_MAX; - } else if (hcom_op_type == "prod") { - op_type = HCCL_REP_OP_PROD; - } else if (hcom_op_type == "sum") { - op_type = 
HCCL_REP_OP_SUM; - } else { - GELOGE(PARAM_INVALID, "HcomOmeUtil::Get HCOM_ATTR_REDUCE_TYPE fail, [%s] not support!", hcom_op_type.c_str()); + int64_t shape_size = op_desc->GetInputDescPtr(i)->GetShape().GetShapeSize(); + GE_CHK_STATUS_RET(ge::CheckInt64Int32MulOverflow(shape_size, size), + "Product of shape size and size beyond INT64_MAX"); + if (kernel_hccl_infos[0].hccl_type == HVDCALLBACKALLGATHER) { + block_size = shape_size * size; + } else { + block_size = (input_size + align_size - 1) / align_size * align_size; + } + + GE_CHK_BOOL_RET_STATUS(size != 0, PARAM_INVALID, "Size is zero"); + GE_CHK_BOOL_EXEC(block_size % size == 0, return PARAM_INVALID, "block_size:%ld is not divisiable by size:%d.", + block_size, size); + kernel_hccl_infos[i].count = static_cast(block_size / size); + } + + return SUCCESS; +} + +Status HcomOmeUtil::GetHcclCount(const ge::ConstOpDescPtr &op_desc, + std::vector &kernel_hccl_infos) { + GE_CHECK_NOTNULL(op_desc); + Status ret; + ret = CheckKernelHcclInfo(op_desc, kernel_hccl_infos); + if (ret != SUCCESS) { + GELOGE(PARAM_INVALID, "HcomOmeUtil:: the number of GETaskKernelHcclInfo is invalid."); return PARAM_INVALID; } + GELOGI("GetHcclCount start, node[%s], opType[%s].", op_desc->GetName().c_str(), op_desc->GetType().c_str()); + if (IsHCOMOp(op_desc->GetType())) { + int32_t count = 0; + ret = GetHcomCount(op_desc, static_cast(kernel_hccl_infos[0].dataType), + kernel_hccl_infos[0].hccl_type == HCOMALLGATHER, count); + if (ret != SUCCESS) { + GELOGE(ret, "HcomOmeUtil:: Node: %s Optype: %s get the Hcom operator hccl count fail.", + op_desc->GetName().c_str(), op_desc->GetType().c_str()); + return PARAM_INVALID; + } + kernel_hccl_infos[0].count = count; + } + + if (IsHorovodOp(op_desc->GetType())) { + ret = GetHorovodCount(op_desc, kernel_hccl_infos); + if (ret != SUCCESS) { + GELOGE(PARAM_INVALID, "HcomOmeUtil:: Node: %s Optype: %s get the Horovod hccl operator count fail.", + op_desc->GetName().c_str(), op_desc->GetType().c_str()); + return PARAM_INVALID; + } + } return SUCCESS; } -Status HcomOmeUtil::GetHcomRootId(const ge::ConstOpDescPtr &op_desc, int64_t &root_id) { +Status HcomOmeUtil::GetHcclOperationType(const ge::ConstOpDescPtr &op_desc, hcclRedOp_t &op_type) { + GE_CHECK_NOTNULL(op_desc); + + if (IsHCOMOp(op_desc->GetType())) { + std::string hcom_op_type; + GE_CHK_BOOL_EXEC(ge::AttrUtils::GetStr(op_desc, HCOM_ATTR_REDUCE_TYPE, hcom_op_type), return PARAM_INVALID, + "HcomOmeUtil:: Node: %s Optype: %s Get HCOM_ATTR_REDUCE_TYPE fail, not support!", + op_desc->GetName().c_str(), op_desc->GetType().c_str()); + + if (hcom_op_type == "min") { + op_type = HCCL_REP_OP_MIN; + } else if (hcom_op_type == "max") { + op_type = HCCL_REP_OP_MAX; + } else if (hcom_op_type == "prod") { + op_type = HCCL_REP_OP_PROD; + } else if (hcom_op_type == "sum") { + op_type = HCCL_REP_OP_SUM; + } else { + GELOGE(PARAM_INVALID, "HcomOmeUtil::Get HCOM_ATTR_REDUCE_TYPE fail, [%s] not support!", hcom_op_type.c_str()); + return PARAM_INVALID; + } + } + + if (IsHorovodOp(op_desc->GetType())) { + int64_t horovod_op_type; + GE_CHK_BOOL_EXEC(ge::AttrUtils::GetInt(op_desc, ATTR_HOROVOD_ATTR_REDUCE_TYPE, horovod_op_type), + return PARAM_INVALID, + "HcomOmeUtil:: Node: %s Optype: %s Get ATTR_HOROVOD_ATTR_REDUCE_TYPE fail, not support!", + op_desc->GetName().c_str(), op_desc->GetType().c_str()); + + auto iter = kHorovodRedOpToHcclRedOp.find(static_cast(horovod_op_type)); + if (iter == kHorovodRedOpToHcclRedOp.end()) { + GELOGE(PARAM_INVALID, "HcomOmeUtil:: Node: %s Optype: %s HcomOpType cann't 
support! Current HcomOpType : %ld", + op_desc->GetName().c_str(), op_desc->GetType().c_str(), horovod_op_type); + return PARAM_INVALID; + } + op_type = iter->second; + } + + return SUCCESS; +} + +Status HcomOmeUtil::GetHcclRootId(const ge::ConstOpDescPtr &op_desc, int64_t &root_id) { GE_CHECK_NOTNULL(op_desc); GE_CHK_BOOL_EXEC(ge::AttrUtils::GetInt(op_desc, HCOM_ATTR_ROOT_RANK, root_id), return PARAM_INVALID, - "HcomOmeUtil::Get HCOM_ATTR_ROOT_INDEX fail, not support!"); + "HcomOmeUtil::Node %s Optype: %s Get HCOM_ATTR_ROOT_INDEX fail, not support!", + op_desc->GetName().c_str(), op_desc->GetType().c_str()); + + return SUCCESS; +} + +Status HcomOmeUtil::GetAllRootId(const ge::ConstOpDescPtr &op_desc, + std::vector &kernel_hccl_infos) { + GE_CHECK_NOTNULL(op_desc); + if (op_desc->GetType() == HCOMBROADCAST || op_desc->GetType() == HVDCALLBACKBROADCAST) { + GELOGI("GetAllRootId Node[%s] opType[%s] get hccl rootId.", op_desc->GetName().c_str(), op_desc->GetType().c_str()); + int64_t root_id = 0; + Status dmrt = GetHcclRootId(op_desc, root_id); + if (dmrt != SUCCESS) { + GELOGE(FAILED, "davinci_model: GetHcomRootId fail! domi error: %u", dmrt); + return FAILED; + } + for (size_t i = 0; i < kernel_hccl_infos.size(); i++) { + kernel_hccl_infos[i].rootId = root_id; + } + } + return SUCCESS; +} + +bool HcomOmeUtil::IsHCOMOp(const string &op_type) { + return (op_type == HCOMALLREDUCE) || (op_type == HCOMALLGATHER) || (op_type == HCOMBROADCAST) || + (op_type == HCOMSEND) || (op_type == HCOMRECEIVE) || (op_type == HCOMREDUCESCATTER); +} + +bool HcomOmeUtil::IsHorovodOp(const string &op_type) { + return (op_type == HVDCALLBACKALLREDUCE) || (op_type == HVDCALLBACKALLGATHER) || (op_type == HVDCALLBACKBROADCAST) || + (op_type == HVDWAIT); +} + +Status HcomOmeUtil::CheckKernelHcclInfo(const ge::ConstOpDescPtr &op_desc, + std::vector &kernel_hccl_infos) { + GE_CHECK_NOTNULL(op_desc); + if (IsHCOMOp(op_desc->GetType()) && kernel_hccl_infos.size() != 1) { + GELOGE(PARAM_INVALID, "HcomOmeUtil:: in Hcom scenario, the number of GETaskKernelHcclInfo is invalid."); + return PARAM_INVALID; + } + + if (IsHorovodOp(op_desc->GetType())) { + if (op_desc->GetType() == HVDWAIT) { + return SUCCESS; + } + if (kernel_hccl_infos.empty() || op_desc->GetInputsSize() != kernel_hccl_infos.size()) { + GELOGE(PARAM_INVALID, "HcomOmeUtil:: in Horovod scenario, the number of GETaskKernelHcclInfo is invalid."); + return PARAM_INVALID; + } + } + return SUCCESS; +} + +void HcomOmeUtil::GetHcclType(const domi::TaskDef &task_def, std::vector &kernel_hccl_infos) { + auto hccl_def = task_def.kernel_hccl(); + std::string hccl_type = hccl_def.hccl_type(); + for (size_t i = 0; i < kernel_hccl_infos.size(); i++) { + kernel_hccl_infos[i].hccl_type = hccl_type; + } +} + +Status HcomOmeUtil::GetHorovodInputs(const ge::ConstOpDescPtr &op_desc, + std::vector &kernel_hccl_infos) { + GE_CHECK_NOTNULL(op_desc); + if (!IsHorovodOp(op_desc->GetType())) { + return SUCCESS; + } + + if (CheckKernelHcclInfo(op_desc, kernel_hccl_infos) != SUCCESS) { + GELOGE(PARAM_INVALID, "HcomOmeUtil:: Node: %s Optype: %s the number of GETaskKernelHcclInfo is invalid.", + op_desc->GetName().c_str(), op_desc->GetType().c_str()); + return PARAM_INVALID; + } + + if (op_desc->GetType() == HVDWAIT) { + return SUCCESS; + } + + for (size_t i = 0; i < op_desc->GetInputsSize(); i++) { + ConstGeTensorDescPtr input_desc = op_desc->GetInputDescPtr(i); + GETaskKernelHcclInfo &kernel_hccl_info = kernel_hccl_infos.at(i); + kernel_hccl_info.input_name = op_desc->GetInputNameByIndex(i); + 
kernel_hccl_info.dims = input_desc->GetShape().GetDims(); + } return SUCCESS; } } // namespace ge diff --git a/src/ge/graph/manager/util/hcom_util.h b/src/ge/graph/manager/util/hcom_util.h index 31bf246e..40aac3e5 100644 --- a/src/ge/graph/manager/util/hcom_util.h +++ b/src/ge/graph/manager/util/hcom_util.h @@ -22,72 +22,146 @@ #include #include "common/debug/log.h" +#include "common/opskernel/ge_task_info.h" #include "common/string_util.h" #include "common/types.h" #include "common/util.h" #include "graph/op_desc.h" #include "hccl/hcom.h" +#include "proto/task.pb.h" namespace ge { using std::string; using std::vector; -static std::map kConstOpHcomDataType = { - {ge::DT_FLOAT, HCCL_DATA_TYPE_FLOAT}, - {ge::DT_FLOAT16, HCCL_DATA_TYPE_HALF}, - {ge::DT_INT8, HCCL_DATA_TYPE_INT8}, - {ge::DT_INT32, HCCL_DATA_TYPE_INT}, +static std::map kConstOpHcclDataType = { + {ge::DT_FLOAT, HCCL_DATA_TYPE_FLOAT}, + {ge::DT_FLOAT16, HCCL_DATA_TYPE_HALF}, + {ge::DT_INT8, HCCL_DATA_TYPE_INT8}, + {ge::DT_INT32, HCCL_DATA_TYPE_INT}, }; -static std::map kConstOpHcomDataTypeSize = { - {HCCL_DATA_TYPE_FLOAT, sizeof(float)}, - {HCCL_DATA_TYPE_HALF, sizeof(float) / 2}, - {HCCL_DATA_TYPE_INT8, sizeof(int8_t)}, - {HCCL_DATA_TYPE_INT, sizeof(int32_t)}, +static std::map kConstOpHcclDataTypeSize = { + {HCCL_DATA_TYPE_FLOAT, sizeof(float)}, + {HCCL_DATA_TYPE_HALF, sizeof(float) / 2}, + {HCCL_DATA_TYPE_INT8, sizeof(int8_t)}, + {HCCL_DATA_TYPE_INT, sizeof(int32_t)}, +}; + +static std::map kHorovodRedOpToHcclRedOp = { + {HOROVOD_REP_OP_SUM, HCCL_REP_OP_SUM}, {HOROVOD_REP_OP_MIN, HCCL_REP_OP_MIN}, + {HOROVOD_REP_OP_MAX, HCCL_REP_OP_MAX}, {HOROVOD_REP_OP_PROD, HCCL_REP_OP_PROD}, + {HOROVOD_REP_OP_RESERVED, HCCL_REP_OP_RESERVED}, }; class HcomOmeUtil { public: /// /// @ingroup domi_ome - /// @brief GetHcomDataType + /// @brief GetHcclDataType /// @return SUCCESS /// @return FAIL /// - static Status GetHcomDataType(const ge::ConstOpDescPtr &op_desc, hcclDataType_t &data_type); + static Status GetHcclDataType(const ge::ConstOpDescPtr &op_desc, + std::vector &kernel_hccl_infos); /// /// @ingroup domi_ome - /// @brief GetHcomTypeSize + /// @brief GetHcclTypeSize /// @return SUCCESS /// @return FAIL /// - static Status GetHcomTypeSize(hcclDataType_t data_type, int32_t &size); + static Status GetHcclTypeSize(hcclDataType_t data_type, int32_t &size); /// /// @ingroup domi_ome - /// @brief GetHcomCount + /// @brief GetHcclCount /// @return SUCCESS /// @return FAIL /// - static Status GetHcomCount(const ge::ConstOpDescPtr &op_desc, hcclDataType_t data_type, bool is_allgather, - int &count); + static Status GetHcclCount(const ge::ConstOpDescPtr &op_desc, std::vector &kernel_hccl_infos); + + /// + /// @ingroup domi_ome + /// @brief GetHcclOperationType + /// @return SUCCESS + /// @return FAIL + /// + static Status GetHcclOperationType(const ge::ConstOpDescPtr &op_desc, hcclRedOp_t &op_type); + + /// + /// @ingroup domi_ome + /// @brief GetHcclRootId + /// @return SUCCESS + /// @return FAIL + /// + static Status GetHcclRootId(const ge::ConstOpDescPtr &op_desc, int64_t &root_id); + + /// + /// @ingroup domi_ome + /// @brief GetAllRootId + /// @return SUCCESS + /// @return FAIL + /// + static Status GetAllRootId(const ge::ConstOpDescPtr &op_desc, std::vector &kernel_hccl_infos); + + /// + /// @ingroup domi_ome + /// @brief check the op_type whether is hcom operator or not + /// @return true + /// @return false + /// + static bool IsHCOMOp(const string &op_type); + + /// + /// @ingroup domi_ome + /// @brief check the op_type whether is horovod 
operator or not + /// @return true + /// @return false + /// + static bool IsHorovodOp(const string &op_type); + + /// + /// @ingroup domi_ome + /// @brief GetHcclType + /// @return void + /// + static void GetHcclType(const domi::TaskDef &task_def, std::vector &kernel_hccl_infos); /// /// @ingroup domi_ome - /// @brief GetHcomOperationType + /// @brief CheckKernelHcclInfo + /// @return SUCCESS + /// @return FAIL + /// + static Status CheckKernelHcclInfo(const ge::ConstOpDescPtr &op_desc, + std::vector &kernel_hccl_infos); + /// + /// @ingroup domi_ome + /// @brief GetHorovodInputs /// @return SUCCESS /// @return FAIL /// - static Status GetHcomOperationType(const ge::ConstOpDescPtr &op_desc, hcclRedOp_t &op_type); + static Status GetHorovodInputs(const ge::ConstOpDescPtr &op_desc, + std::vector &kernel_hccl_infos); + private: + /// + /// @ingroup domi_ome + /// @brief GetHcomCount + /// @return SUCCESS + /// @return FAIL + /// + static Status GetHcomCount(const ge::ConstOpDescPtr &op_desc, hcclDataType_t data_type, bool is_allgather, + int &count); /// /// @ingroup domi_ome - /// @brief GetHcomRootId + /// @brief GetHorovodCount /// @return SUCCESS /// @return FAIL /// - static Status GetHcomRootId(const ge::ConstOpDescPtr &op_desc, int64_t &root_id); + static Status GetHorovodCount(const ge::ConstOpDescPtr &op_desc, + std::vector &kernel_hccl_infos); }; } // namespace ge #endif // GE_GRAPH_MANAGER_UTIL_HCOM_UTIL_H_ diff --git a/src/ge/graph/optimize/graph_optimize.cc b/src/ge/graph/optimize/graph_optimize.cc index 84cc77f9..f23ad110 100644 --- a/src/ge/graph/optimize/graph_optimize.cc +++ b/src/ge/graph/optimize/graph_optimize.cc @@ -134,7 +134,7 @@ Status GraphOptimize::OptimizeOriginalGraph(ComputeGraphPtr &compute_graph) { return GE_CLI_GE_NOT_INITIALIZED; } - std::map graph_optimizer = instance_ptr->OpsKernelManagerObj().GetAllGraphOptimizerObjs(); + auto graph_optimizer = instance_ptr->OpsKernelManagerObj().GetAllGraphOptimizerObjsByPriority(); GELOGI("optimize by opskernel in original graph optimize phase. num of graph_optimizer is %lu.", graph_optimizer.size()); string exclude_core_Type = (core_type_ == kVectorCore) ? kAicoreEngine : kVectorEngine; @@ -154,6 +154,37 @@ Status GraphOptimize::OptimizeOriginalGraph(ComputeGraphPtr &compute_graph) { return ret; } +Status GraphOptimize::OptimizeOriginalGraphJudgeInsert(ComputeGraphPtr &compute_graph) { + GELOGD("OptimizeOriginalGraphJudgeInsert in"); + GE_CHECK_NOTNULL(compute_graph); + Status ret = SUCCESS; + std::shared_ptr instance_ptr = ge::GELib::GetInstance(); + if (instance_ptr == nullptr || !instance_ptr->InitFlag()) { + GELOGE(GE_CLI_GE_NOT_INITIALIZED, "OptimizeOriginalGraph failed."); + return GE_CLI_GE_NOT_INITIALIZED; + } + + auto graph_optimizer = instance_ptr->OpsKernelManagerObj().GetAllGraphOptimizerObjsByPriority(); + GELOGI("optimize by opskernel in original graph optimize phase. num of graph_optimizer is %lu.", + graph_optimizer.size()); + string exclude_core_Type = (core_type_ == kVectorCore) ? 
kAicoreEngine : kVectorEngine; + if (graph_optimizer.size() != 0) { + for (auto iter = graph_optimizer.begin(); iter != graph_optimizer.end(); ++iter) { + if (iter->first == exclude_core_Type) { + GELOGI("[OptimizeOriginalGraphJudgeInsert]: engine type will exclude: %s", exclude_core_Type.c_str()); + continue; + } + GELOGI("Begin to refine running format by engine %s", iter->first.c_str()); + ret = (iter->second)->OptimizeOriginalGraphJudgeInsert(*compute_graph); + if (ret != SUCCESS) { + GELOGE(ret, "[OptimizeOriginalGraphJudgeInsert]: graph optimize failed, ret:%d", ret); + return ret; + } + } + } + return ret; +} + Status GraphOptimize::NewOptimizeOriginalGraph(ComputeGraphPtr &compute_graph) { GELOGD("NewOptimizeOriginalGraph in"); if (compute_graph == nullptr) { @@ -168,7 +199,7 @@ Status GraphOptimize::NewOptimizeOriginalGraph(ComputeGraphPtr &compute_graph) { return GE_CLI_GE_NOT_INITIALIZED; } - std::map graph_optimizer = instance_ptr->OpsKernelManagerObj().GetAllGraphOptimizerObjs(); + auto graph_optimizer = instance_ptr->OpsKernelManagerObj().GetAllGraphOptimizerObjsByPriority(); GELOGI("optimize by opskernel in original graph optimize phase. num of graph_optimizer is %lu.", graph_optimizer.size()); string exclude_core_Type = (core_type_ == kVectorCore) ? kAicoreEngine : kVectorEngine; @@ -207,7 +238,7 @@ Status GraphOptimize::OptimizeOriginalGraphForQuantize(ComputeGraphPtr &compute_ return GE_CLI_GE_NOT_INITIALIZED; } - std::map graph_optimizer = instance_ptr->OpsKernelManagerObj().GetAllGraphOptimizerObjs(); + auto graph_optimizer = instance_ptr->OpsKernelManagerObj().GetAllGraphOptimizerObjsByPriority(); GELOGI("optimize by opskernel in original graph optimize quantize phase. num of graph_optimizer is %zu.", graph_optimizer.size()); Status ret = SUCCESS; diff --git a/src/ge/graph/optimize/graph_optimize.h b/src/ge/graph/optimize/graph_optimize.h index 83b5489f..72709932 100644 --- a/src/ge/graph/optimize/graph_optimize.h +++ b/src/ge/graph/optimize/graph_optimize.h @@ -47,6 +47,8 @@ class GraphOptimize { // original graph optimize Status OptimizeOriginalGraph(ComputeGraphPtr &compute_graph); + Status OptimizeOriginalGraphJudgeInsert(ComputeGraphPtr &compute_graph); + // new original graph optimize Status NewOptimizeOriginalGraph(ComputeGraphPtr &compute_graph); diff --git a/src/ge/graph/partition/dynamic_shape_partition.cc b/src/ge/graph/partition/dynamic_shape_partition.cc index bbf31e5c..4ffd37bd 100644 --- a/src/ge/graph/partition/dynamic_shape_partition.cc +++ b/src/ge/graph/partition/dynamic_shape_partition.cc @@ -43,39 +43,44 @@ #define REQUIRE_SUCCESS(cond, ...) REQUIRE(((cond) == SUCCESS), __VA_ARGS__) #define REQUIRE_GRAPH_SUCCESS(cond, ...) 
REQUIRE(((cond) == GRAPH_SUCCESS), __VA_ARGS__) -namespace { -const bool kDebugging = (std::getenv("DEBUG_DYNAMIC_PARTITION") != nullptr); -} // namespace +bool IsExperimental() { + const static bool kIsExperimental = (std::getenv("EXPERIMENTAL_DYNAMIC_PARTITION") != nullptr); + return kIsExperimental; +} -#define DLOG() \ - if (kDebugging) std::cerr namespace ge { using Cluster = DynamicShapePartitioner::Cluster; using ClusterPtr = std::shared_ptr; Status DynamicShapePartitioner::Partition() { REQUIRE_NOT_NULL(root_graph_, "Graph is nullptr."); - DLOG() << "Start dynamic shape partition graph " << root_graph_->GetName() << std::endl; - REQUIRE_SUCCESS(MarkUnknowShapeNodes(), "Failed mark unknow shape nodes."); + if (!IsExperimental()) { + GELOGD("Skip dynamic shape partition as not in experimental mode."); + REQUIRE(AttrUtils::SetBool(*root_graph_, ATTR_NAME_DYNAMIC_SHAPE_PARTITIONED, false), + "Failed set dynamic shape partitioned flag on root graph."); + return SUCCESS; + } + + GELOGD("Start dynamic shape partition graph %s.", root_graph_->GetName().c_str()); + REQUIRE_SUCCESS(MarkUnknownShapeNodes(), "Failed mark unknown shape nodes."); if (unknown_shape_nodes_.empty()) { - DLOG() << "Skip dynamic shape partition of graph " << root_graph_->GetName() << " as all nodes are known shape." - << std::endl; - REQUIRE(AttrUtils::SetBool(*root_graph_, "_dynamic_shape_partitioned", false), + GELOGD("Skip dynamic shape partition of graph %s as all nodes are known shape.", root_graph_->GetName().c_str()); + REQUIRE(AttrUtils::SetBool(*root_graph_, ATTR_NAME_DYNAMIC_SHAPE_PARTITIONED, false), "Failed set dynamic shape partitioned flag on root graph."); return SUCCESS; } - REQUIRE(AttrUtils::SetBool(*root_graph_, "_dynamic_shape_partitioned", true), + REQUIRE(AttrUtils::SetBool(*root_graph_, ATTR_NAME_DYNAMIC_SHAPE_PARTITIONED, true), "Failed set dynamic shape partitioned flag on root graph."); + DumpGraph("_Before_DSP"); auto status = PartitionImpl(); - DLOG() << DebugString() << std::endl; + GELOGD("%s.", DebugString().c_str()); if (status != SUCCESS) { GELOGE(status, "Failed dynamic shape partition graph: %s, status:\n %s", root_graph_->GetName().c_str(), DebugString().c_str()); } DumpGraph("_After_DSP"); - DLOG() << (status == SUCCESS ? 
"Succeed" : "Failed") << " dynamic shape partition graph " << root_graph_->GetName() - << std::endl; + GELOGD("Finish dynamic shape partition graph %s.", root_graph_->GetName().c_str()); ClearResource(); return status; } @@ -122,36 +127,36 @@ Status DynamicShapePartitioner::BuildPartitionSubgraph() { return SUCCESS; } -std::string DynamicShapePartitioner::DebugString() { - size_t unknow = 0; - size_t know = 0; +std::string DynamicShapePartitioner::DebugString() const { + size_t unknown = 0; + size_t known = 0; size_t data = 0; size_t netoutput = 0; std::stringstream ss; - ss << "All unknow shape nodes:" << std::endl; + ss << "All unknown shape nodes:" << std::endl; for (auto node : unknown_shape_nodes_) { ss << " [" << node->GetName() << "](" << node->GetType() << ")" << std::endl; } for (auto cluster : unique_clusters_) { - if (cluster->IsUnknowShape()) { - unknow++; - } else if (cluster->IsKnowShape()) { - know++; + if (cluster->IsUnknownShape()) { + unknown++; + } else if (cluster->IsKnownShape()) { + known++; } else if (cluster->IsData()) { data++; } else if (cluster->IsNetOutput()) { netoutput++; } } - ss << "All clusters:" << unique_clusters_.size() << ", data:" << data << ", know:" << know << ", unknow:" << unknow - << ", netoutput:" << netoutput << std::endl; + ss << "All clusters:" << unique_clusters_.size() << ", data:" << data << ", known:" << known + << ", unknown:" << unknown << ", netoutput:" << netoutput << std::endl; for (auto cluster : unique_clusters_) { ss << " " << cluster->DebugString() << std::endl; } return ss.str(); } -void DynamicShapePartitioner::DumpGraph(std::string suffix) { +void DynamicShapePartitioner::DumpGraph(const std::string &suffix) { GraphUtils::DumpGEGraphToOnnx(*root_graph_, root_graph_->GetName() + suffix); for (auto sub_graph : root_graph_->GetAllSubgraphs()) { GraphUtils::DumpGEGraphToOnnx(*sub_graph, sub_graph->GetName() + suffix); @@ -169,10 +174,10 @@ void DynamicShapePartitioner::ClearResource() { root_graph_.reset(); } -Status DynamicShapePartitioner::MarkUnknowShapeNodes() { +Status DynamicShapePartitioner::MarkUnknownShapeNodes() { auto graph = root_graph_; for (auto &node : graph->GetDirectNode()) { - REQUIRE_SUCCESS(CollectSpreadUnknowShapeNodes(node), "Failed collect spread unknow shape nodes %s.", + REQUIRE_SUCCESS(CollectSpreadUnknownShapeNodes(node), "Failed collect spread unknown shape nodes %s.", node->GetName().c_str()); } return SUCCESS; @@ -188,14 +193,14 @@ Status DynamicShapePartitioner::InitClusters() { } else if (node->GetType() == NETOUTPUT) { type = Cluster::NETOUTPUT; } else if (unknown_shape_nodes_.count(node) > 0) { - type = Cluster::UNKNOW_SHAPE; + type = Cluster::UNKNOWN_SHAPE; } else { - type = Cluster::KNOW_SHAPE; + type = Cluster::KNOWN_SHAPE; } auto cluster = MakeShared(rank++, type, node, this); REQUIRE_NOT_NULL(cluster, "Failed new memory for cluster."); node_2_cluster_[node] = cluster; - if (cluster->IsUnknowShape()) { + if (cluster->IsUnknownShape()) { ordered_cluster_.push_back(cluster); } // Already sorted topologically, so access to the parent cluster is safe @@ -203,18 +208,15 @@ Status DynamicShapePartitioner::InitClusters() { cluster->AddInput(node_2_cluster_[parent]); } } - if (kDebugging) { - for (const auto node : graph->GetDirectNode()) { - DLOG() << "Make cluster for node :" << node->GetName() << ":" << node_2_cluster_[node]->DebugString() - << std::endl; - } + for (const auto node : graph->GetDirectNode()) { + GELOGD("Make cluster for node %s : %s.", node->GetName().c_str(), 
node_2_cluster_[node]->DebugString().c_str()); } return SUCCESS; } Status DynamicShapePartitioner::TopologicalSortClusters() { ordered_cluster_.clear(); - // BFS topological sort clusters for know shape cluster + // BFS topological sort clusters for known shape cluster std::queue ready_clusters; std::unordered_map cluster_pending_count; std::unordered_set seen_clusters; @@ -231,16 +233,17 @@ Status DynamicShapePartitioner::TopologicalSortClusters() { cluster_pending_count[cluster] = pending_count; } } + size_t rank = 0; while (!ready_clusters.empty()) { auto cluster = ready_clusters.front(); ready_clusters.pop(); cluster->UpdateRank(rank++); - if (cluster->IsKnowShape()) { + if (cluster->IsKnownShape()) { ordered_cluster_.push_back(cluster); } for (auto out_cluster : cluster->Outputs()) { - if (--cluster_pending_count[out_cluster] == 0) { + if (cluster_pending_count[out_cluster] > 0 && --cluster_pending_count[out_cluster] == 0) { ready_clusters.push(out_cluster); } } @@ -252,49 +255,58 @@ Status DynamicShapePartitioner::TopologicalSortClusters() { } namespace { -template -static std::string ToString(T vec) { - if (vec.empty()) { +static std::string ToString(const std::vector &clusters) { + if (clusters.empty()) { return "()"; } std::stringstream ss; ss << "("; - auto iter = vec.begin(); - for (size_t i = 0; i < vec.size() - 1; i++) { - ss << (*iter++)->Id() << ","; + auto iter = clusters.begin(); + for (size_t i = 0; i < clusters.size() - 1; i++) { + ss << (*iter)->Id() << ","; + iter++; } - ss << (*iter++)->Id() << ")."; + ss << (*iter)->Id() << ")."; return ss.str(); } } // namespace Status DynamicShapePartitioner::MergeClusters() { - // Merge unknow shape clusters + // Merge unknown shape clusters for (auto cluster : ordered_cluster_) { for (auto in_cluster : cluster->Inputs()) { - if (in_cluster->IsUnknowShape()) { - auto merged_clusters = cluster->MergeAllPathFrom(in_cluster); - DLOG() << "Merge all path cluster from " << in_cluster->Id() << " to " << cluster->Id() - << ToString(merged_clusters) << std::endl; - for (auto merged_cluster : merged_clusters) { - for (auto node : merged_cluster->Nodes()) { - node_2_cluster_[node] = cluster; - } + if (!in_cluster->IsUnknownShape()) { + continue; + } + auto merged_clusters = cluster->MergeAllPathFrom(in_cluster); + GELOGD("Merge all path cluster from %lu to %lu %s.", in_cluster->Id(), cluster->Id(), + ToString(merged_clusters).c_str()); + for (auto merged_cluster : merged_clusters) { + for (auto node : merged_cluster->Nodes()) { + node_2_cluster_[node] = cluster; } } } } - REQUIRE_SUCCESS(TopologicalSortClusters(), "Failed topological sort clusters after merge unknow shape clusters."); - // Merge know shape clusters + + REQUIRE_SUCCESS(TopologicalSortClusters(), "Failed topological sort clusters after merge unknown shape clusters."); + // Merge known shape clusters for (auto cluster : ordered_cluster_) { + if (cluster->IsRefVariable() && cluster->Inputs().size() == 1) { + auto in_cluster = *(cluster->Inputs().begin()); + in_cluster->Merge(cluster); + node_2_cluster_[*(cluster->Nodes().begin())] = in_cluster; + continue; + } + for (auto in_cluster : cluster->Inputs()) { - if (in_cluster->IsKnowShape()) { - if (cluster->TryMerge(in_cluster)) { - DLOG() << "Success merge known shape cluster " << in_cluster->Id() << " to " << cluster->Id() << "." 
- << std::endl; - for (auto node : in_cluster->Nodes()) { - node_2_cluster_[node] = cluster; - } + if (!in_cluster->IsKnownShape()) { + continue; + } + if (cluster->TryMerge(in_cluster)) { + GELOGD("Success merge known shape cluster from %lu to %lu.", in_cluster->Id(), cluster->Id()); + for (auto node : in_cluster->Nodes()) { + node_2_cluster_[node] = cluster; } } } @@ -302,23 +314,30 @@ Status DynamicShapePartitioner::MergeClusters() { return SUCCESS; } -Status DynamicShapePartitioner::CollectSpreadUnknowShapeNodes(NodePtr node) { +Status DynamicShapePartitioner::CollectSpreadUnknownShapeNodes(NodePtr node) { if (unknown_shape_nodes_.count(node) > 0) { return SUCCESS; } auto opdesc = node->GetOpDesc(); + // One can set 'ATTR_NAME_IS_UNKNOWN_SHAPE=true' on node so as to forcing the node flow into the unknown subgraph, + // ignore the actual shape. + bool is_forced_unknown = false; + if (AttrUtils::GetBool(opdesc, ATTR_NAME_IS_UNKNOWN_SHAPE, is_forced_unknown) && is_forced_unknown) { + GELOGD("Collect node %s as unknown as it was marked unknown forcibly.", node->GetName().c_str()); + unknown_shape_nodes_.insert(node); + return SUCCESS; + } size_t anchor_index = 0; - bool is_unknow = false; + bool is_unknown = false; for (auto &out_tensor : opdesc->GetAllOutputsDesc()) { - if (IsUnknowShapeTensor(out_tensor)) { - DLOG() << "Collect node " << node->GetName() << " as unknown as output " << anchor_index << " is unknown" - << std::endl; - is_unknow = true; + if (IsUnknownShapeTensor(out_tensor)) { + GELOGD("Collect node %s as unknown as output %lu is unknown.", node->GetName().c_str(), anchor_index); + is_unknown = true; auto anchor = node->GetOutDataAnchor(anchor_index); for (const auto peer_anchor : anchor->GetPeerInDataAnchors()) { if (peer_anchor != nullptr) { - DLOG() << "Collect node " << peer_anchor->GetOwnerNode()->GetName() << " as has unknown input from " - << node->GetName() << ":" << anchor_index << std::endl; + GELOGD("Collect node %s as has unknown input from %s:%lu.", peer_anchor->GetOwnerNode()->GetName().c_str(), + node->GetName().c_str(), anchor_index); unknown_shape_nodes_.insert(peer_anchor->GetOwnerNode()); } } @@ -327,21 +346,20 @@ Status DynamicShapePartitioner::CollectSpreadUnknowShapeNodes(NodePtr node) { } anchor_index = 0; for (auto &in_tensor : opdesc->GetAllInputsDesc()) { - if (IsUnknowShapeTensor(in_tensor)) { - DLOG() << "Collect node " << node->GetName() << " as unknown as input " << anchor_index << " is unknown" - << std::endl; - is_unknow = true; + if (IsUnknownShapeTensor(in_tensor)) { + GELOGD("Collect node %s as unknown as input %lu is unknown.", node->GetName().c_str(), anchor_index); + is_unknown = true; auto anchor = node->GetInDataAnchor(anchor_index); const auto peer_anchor = anchor->GetPeerOutAnchor(); if (peer_anchor != nullptr) { - DLOG() << "Collect node " << peer_anchor->GetOwnerNode()->GetName() << " as has unknown output to " - << node->GetName() << ":" << anchor_index << std::endl; + GELOGD("Collect node %s as has unknown output to %s:%lu.", peer_anchor->GetOwnerNode()->GetName().c_str(), + node->GetName().c_str(), anchor_index); unknown_shape_nodes_.insert(peer_anchor->GetOwnerNode()); } } anchor_index++; } - if (is_unknow) { + if (is_unknown) { unknown_shape_nodes_.insert(node); } else { auto graph = root_graph_; @@ -350,11 +368,10 @@ Status DynamicShapePartitioner::CollectSpreadUnknowShapeNodes(NodePtr node) { REQUIRE_NOT_NULL(subgraph, "Failed get subgraph %s of node %s on root graph.", subgraph_name.c_str(), node->GetName().c_str()); bool 
is_graph_unknow = false; - REQUIRE_SUCCESS(IsUnknowShapeGraph(subgraph, is_graph_unknow), "Failed check subgraph %s shape of node %s.", + REQUIRE_SUCCESS(IsUnknownShapeGraph(subgraph, is_graph_unknow), "Failed check subgraph %s shape of node %s.", subgraph_name.c_str(), node->GetName().c_str()); if (is_graph_unknow) { - DLOG() << "Collect node " << node->GetName() << " as its subgraph " << subgraph->GetName() << " is unknown." - << std::endl; + GELOGD("Collect node %s as its subgraph %s is unknown.", node->GetName().c_str(), subgraph->GetName().c_str()); unknown_shape_nodes_.insert(node); break; } @@ -363,20 +380,20 @@ Status DynamicShapePartitioner::CollectSpreadUnknowShapeNodes(NodePtr node) { return SUCCESS; } -Status DynamicShapePartitioner::IsUnknowShapeNode(NodePtr node, bool &is_unknow) { +Status DynamicShapePartitioner::IsUnknownShapeNode(NodePtr node, bool &is_unknown) { auto opdesc = node->GetOpDesc(); auto graph = root_graph_; for (auto &out_tensor : opdesc->GetAllOutputsDesc()) { - if (IsUnknowShapeTensor(out_tensor)) { - DLOG() << "Mark node " << node->GetName() << " unknown because unknown output " << std::endl; - is_unknow = true; + if (IsUnknownShapeTensor(out_tensor)) { + GELOGD("Mark node %s unknown as unknown output.", node->GetName().c_str()); + is_unknown = true; return SUCCESS; } } for (auto &in_tensor : opdesc->GetAllInputsDesc()) { - if (IsUnknowShapeTensor(in_tensor)) { - DLOG() << "Mark node " << node->GetName() << " unknown because unknown intput " << std::endl; - is_unknow = true; + if (IsUnknownShapeTensor(in_tensor)) { + GELOGD("Mark node %s unknown as unknown intput.", node->GetName().c_str()); + is_unknown = true; return SUCCESS; } } @@ -384,30 +401,30 @@ Status DynamicShapePartitioner::IsUnknowShapeNode(NodePtr node, bool &is_unknow) auto subgraph = graph->GetSubgraph(subgraph_name); REQUIRE_NOT_NULL(subgraph, "Failed get subgraph %s of node %s on root graph.", subgraph_name.c_str(), node->GetName().c_str()); - REQUIRE_SUCCESS(IsUnknowShapeGraph(subgraph, is_unknow), "Failed check subgraph %s shape of node %s.", + REQUIRE_SUCCESS(IsUnknownShapeGraph(subgraph, is_unknown), "Failed check subgraph %s shape of node %s.", subgraph_name.c_str(), node->GetName().c_str()); - if (is_unknow) { - DLOG() << "Mark node " << node->GetName() << " unknown because unknown subgraph " << std::endl; + if (is_unknown) { + GELOGD("Mark node %s unknown as unknown subgraph.", node->GetName().c_str()); return SUCCESS; } } - is_unknow = false; + is_unknown = false; return SUCCESS; } -Status DynamicShapePartitioner::IsUnknowShapeGraph(ComputeGraphPtr graph, bool &is_unknow) { +Status DynamicShapePartitioner::IsUnknownShapeGraph(ComputeGraphPtr graph, bool &is_unknown) { for (auto &node : graph->GetDirectNode()) { - REQUIRE_SUCCESS(IsUnknowShapeNode(node, is_unknow), "Failed check node %s shape on graph %s.", + REQUIRE_SUCCESS(IsUnknownShapeNode(node, is_unknown), "Failed check node %s shape on graph %s.", node->GetName().c_str(), graph->GetName().c_str()); - if (is_unknow) { - DLOG() << "Mark graph " << graph->GetName() << " unknown because unknown node " << node->GetName() << std::endl; + if (is_unknown) { + GELOGD("Mark graph %s unknown as contains unknown node %s.", graph->GetName().c_str(), node->GetName().c_str()); return SUCCESS; } } return SUCCESS; } -bool DynamicShapePartitioner::IsUnknowShapeTensor(GeTensorDesc &tensor) { +bool DynamicShapePartitioner::IsUnknownShapeTensor(const GeTensorDesc &tensor) { const static int kUnknowShape = -1; const static int kUnknowRank = -2; 
for (auto dim_size : tensor.GetShape().GetDims()) { @@ -418,7 +435,7 @@ bool DynamicShapePartitioner::IsUnknowShapeTensor(GeTensorDesc &tensor) { return false; } -std::string Cluster::DebugString() { +std::string Cluster::DebugString() const { std::stringstream ss; switch (type_) { case DATA: @@ -427,10 +444,10 @@ std::string Cluster::DebugString() { case NETOUTPUT: ss << "NETOUTPUT"; break; - case UNKNOW_SHAPE: + case UNKNOWN_SHAPE: ss << "UNKNOW"; break; - case KNOW_SHAPE: + case KNOWN_SHAPE: ss << "KNOW"; break; } @@ -450,18 +467,22 @@ std::string Cluster::DebugString() { return ss.str(); } -size_t Cluster::Id() { return id_; } +size_t Cluster::Id() const { return id_; } void Cluster::UpdateRank(size_t rank) { max_ = rank; min_ = rank; }; -bool Cluster::IsData() { return type_ == DATA; }; -bool Cluster::IsKnowShape() { return type_ == KNOW_SHAPE; }; -bool Cluster::IsUnknowShape() { return type_ == UNKNOW_SHAPE; }; -bool Cluster::IsNetOutput() { return type_ == NETOUTPUT; }; -bool Cluster::IsolatedConstant() { - return ((nodes_.size() == 1) && (nodes_[0]->GetType() == CONSTANTOP) && (out_clusters_.size() == 1) && - (*out_clusters_.begin())->IsUnknowShape() && in_clusters_.empty()); +bool Cluster::IsData() const { return type_ == DATA; }; +bool Cluster::IsKnownShape() const { return type_ == KNOWN_SHAPE; }; +bool Cluster::IsUnknownShape() const { return type_ == UNKNOWN_SHAPE; }; +bool Cluster::IsNetOutput() const { return type_ == NETOUTPUT; }; +bool Cluster::IsRefVariable() const { + if ((nodes_.size() == 1) && ((nodes_[0]->GetType() == VARIABLE) || (nodes_[0]->GetType() == VARIABLEV2))) { + std::string ref_variable_name; + return (AttrUtils::GetStr(nodes_[0]->GetOpDesc(), REF_VAR_SRC_VAR_NAME, ref_variable_name) && + !ref_variable_name.empty()); + } + return false; } void Cluster::AddInput(ClusterPtr in) { in_clusters_.insert(in); @@ -562,9 +583,9 @@ std::vector Cluster::MergeAllPathFrom(ClusterPtr other) { } return path_clusters; } -std::unordered_set Cluster::Inputs() { return in_clusters_; }; -std::unordered_set Cluster::Outputs() { return out_clusters_; }; -std::vector Cluster::Nodes() { return nodes_; }; +std::unordered_set Cluster::Inputs() const { return in_clusters_; }; +std::unordered_set Cluster::Outputs() const { return out_clusters_; }; +std::vector Cluster::Nodes() const { return nodes_; }; void Cluster::AddFrameInput(InDataAnchorPtr anchor) { inputs_index_[anchor] = inputs_.size(); @@ -589,7 +610,7 @@ InControlAnchorPtr Cluster::GetFrameInControlAnchor() { return partition_node_-> OutControlAnchorPtr Cluster::GetFrameOutControlAnchor() { return partition_node_->GetOutControlAnchor(); }; Status Cluster::BuildFrame() { - if (IsUnknowShape() || IsKnowShape()) { + if (IsUnknownShape() || IsKnownShape()) { return BuildPartitionFrame(); } else { auto node = nodes_.front(); @@ -621,7 +642,7 @@ Status Cluster::BuildFrame() { Status Cluster::BuildPartitionFrame() { auto graph = partitioner_->root_graph_; - bool is_unknown_shape = IsUnknowShape(); + bool is_unknown_shape = IsUnknownShape(); std::string sub_graph_name = graph->GetName() + "_sub_" + std::to_string(unique_id_) + (is_unknown_shape ? 
"_unknow" : "_know"); subgraph_ = MakeShared(sub_graph_name); @@ -727,6 +748,7 @@ Status Cluster::BuildPartitionSubgraph() { auto data_op = MakeShared(std::string("Data_") + std::to_string(parent_node_index), ge::DATA); REQUIRE_NOT_NULL(data_op, "Failed new memory for data op."); auto input_desc = anchor->GetOwnerNode()->GetOpDesc()->GetInputDesc(anchor->GetIdx()); + REQUIRE_GRAPH_SUCCESS(data_op->AddInputDesc(input_desc), "Failed add input desc."); REQUIRE_GRAPH_SUCCESS(data_op->AddOutputDesc(input_desc), "Failed add output desc."); REQUIRE(AttrUtils::SetInt(data_op, ATTR_NAME_PARENT_NODE_INDEX, parent_node_index), "Failed set parent_node_index on subgraph data node."); diff --git a/src/ge/graph/partition/dynamic_shape_partition.h b/src/ge/graph/partition/dynamic_shape_partition.h index 8734d7aa..4cbd20b7 100644 --- a/src/ge/graph/partition/dynamic_shape_partition.h +++ b/src/ge/graph/partition/dynamic_shape_partition.h @@ -29,27 +29,27 @@ class DynamicShapePartitioner { public: // An cluster means set of nodes that can be merged in same partition, // Corresponding relationship between cluster type and node: - // DATA:DATA, UNKNOW_SHAPE:unknowshape, KNOW_SHAPE:knowshape, NETOUTPUT:NETOUTPUT. + // DATA:DATA, UNKNOWN_SHAPE:unknowshape, KNOWN_SHAPE:knowshape, NETOUTPUT:NETOUTPUT. class Cluster : public std::enable_shared_from_this { public: - enum Type { DATA, NETOUTPUT, KNOW_SHAPE, UNKNOW_SHAPE }; - explicit Cluster(size_t rank, Type type, NodePtr node, DynamicShapePartitioner *partitioner) + enum Type { DATA, NETOUTPUT, KNOWN_SHAPE, UNKNOWN_SHAPE }; + Cluster(size_t rank, Type type, NodePtr node, DynamicShapePartitioner *partitioner) : id_(rank), min_(rank), max_(rank), type_(type), partitioner_(partitioner) { nodes_.push_back(node); } ~Cluster() = default; - std::string DebugString(); + std::string DebugString() const; // Basic bean functions - size_t Id(); + size_t Id() const; void UpdateRank(size_t rank); - bool IsData(); - bool IsKnowShape(); - bool IsUnknowShape(); - bool IsNetOutput(); - std::unordered_set> Inputs(); - std::unordered_set> Outputs(); - std::vector Nodes(); - bool IsolatedConstant(); + bool IsData() const; + bool IsKnownShape() const; + bool IsUnknownShape() const; + bool IsNetOutput() const; + std::unordered_set> Inputs() const; + std::unordered_set> Outputs() const; + std::vector Nodes() const; + bool IsRefVariable() const; // Cluster modify functions void AddInput(std::shared_ptr in); void RemoveInput(std::shared_ptr in); @@ -110,16 +110,16 @@ class DynamicShapePartitioner { // Collect nodes that satisfy the unknowshape rules: // 1) The Tensor shape of any input or output is unknow shape(dim_size = -1) or unknow rank(dim_size=-2) // 2) Subgraphs of the node has an operator that satisfies rule 1) - Status MarkUnknowShapeNodes(); + Status MarkUnknownShapeNodes(); // For each node a Cluster structure, and connected according to the connection relationship of the nodes // An cluster means set of nodes that can be merged in same partition, // Corresponding relationship between cluster type and node: - // DATA:DATA, UNKNOW_SHAPE:unknowshape, KNOW_SHAPE:knowshape, NETOUTPUT:NETOUTPUT + // DATA:DATA, UNKNOWN_SHAPE:unknowshape, KNOWN_SHAPE:knowshape, NETOUTPUT:NETOUTPUT Status InitClusters(); // Merge clusters according to the following rules: - // 1) Iterate through the UNKNOW_SHAPE clusters, if the input is UNKNOW_SHAPE, + // 1) Iterate through the UNKNOWN_SHAPE clusters, if the input is UNKNOWN_SHAPE, // merge all the clusters in the path(s) between the two clusters - // 
2) Iterate through the KNOW_SHAPE clusters, if the input is KNOW_SHAPE, and +  //    2) Iterate through the KNOWN_SHAPE clusters, if the input is KNOWN_SHAPE, and //    there's only one path between the two clusters, merge the two clusters Status MergeClusters(); // Topological sort clusters after merge unknow shape clusters. @@ -135,18 +135,18 @@ class DynamicShapePartitioner { // Clear resource and break circular dependency void ClearResource(); // Debug functions -  void DumpGraph(std::string suffix); -  std::string DebugString(); +  void DumpGraph(const std::string &suffix); +  std::string DebugString() const; // Util functions -  Status CollectSpreadUnknowShapeNodes(NodePtr node); -  Status IsUnknowShapeGraph(ge::ComputeGraphPtr graph, bool &is_unknow); -  Status IsUnknowShapeNode(ge::NodePtr node, bool &is_unknow); -  bool IsUnknowShapeTensor(ge::GeTensorDesc &tensor); +  Status CollectSpreadUnknownShapeNodes(NodePtr node); +  Status IsUnknownShapeGraph(ge::ComputeGraphPtr graph, bool &is_unknow); +  Status IsUnknownShapeNode(ge::NodePtr node, bool &is_unknow); +  bool IsUnknownShapeTensor(const ge::GeTensorDesc &tensor); ge::ComputeGraphPtr root_graph_;  // The original graph to partition std::unordered_map> node_2_cluster_;  // Record nodes and the cluster it belongs to // topological sorted clusters, this field will change with the splitting. -  // When partitioning UNKNOW_SHAPE cluster, it is a collection of all topological sorted UNKNOW_SHAPE clusters -  // When partitioning KNOW_SHAPE cluster, it is a collection of all topological sorted KNOW_SHAPE clusters +  // When partitioning UNKNOWN_SHAPE cluster, it is a collection of all topological sorted UNKNOWN_SHAPE clusters +  // When partitioning KNOWN_SHAPE cluster, it is a collection of all topological sorted KNOWN_SHAPE clusters std::vector> ordered_cluster_; // Unique clusters left after merged clusters std::unordered_set> unique_clusters_; diff --git a/src/ge/graph/partition/graph_partition.cc b/src/ge/graph/partition/graph_partition.cc index b408c287..b25de017 100644 --- a/src/ge/graph/partition/graph_partition.cc +++ b/src/ge/graph/partition/graph_partition.cc @@ -254,75 +254,39 @@ Status ge::GraphPartitioner::MergeSubGraph(ge::ComputeGraphPtr &output_merged_co return SUCCESS; } -Status ge::GraphPartitioner::UpdatePldOpDesc(const NodePtr &src_node, int output_index, OpDescPtr &pld_op_desc) { -  if (src_node == nullptr || pld_op_desc == nullptr || src_node->GetOpDesc() == nullptr) { +Status ge::GraphPartitioner::UpdatePldOpDesc(const NodePtr &dst_node, int input_index, OpDescPtr &pld_op_desc) { +  if (dst_node == nullptr || pld_op_desc == nullptr || dst_node->GetOpDesc() == nullptr) { GELOGE(FAILED, "parameter ptr is null."); return FAILED; } -  const auto &output_desc = src_node->GetOpDesc()->GetOutputDesc(static_cast(output_index)); -  GE_IF_BOOL_EXEC(pld_op_desc->AddOutputDesc(output_desc) != GRAPH_SUCCESS, GELOGE(FAILED, "AddOutputDesc failed"); +  const auto &input_desc = dst_node->GetOpDesc()->GetInputDesc(static_cast(input_index)); +  GE_IF_BOOL_EXEC(pld_op_desc->AddOutputDesc(input_desc) != GRAPH_SUCCESS, GELOGE(FAILED, "AddOutputDesc failed"); return FAILED;) if (pld_op_desc->MutableOutputDesc(0) != nullptr) { ge::TensorUtils::SetRealDimCnt(*(pld_op_desc->MutableOutputDesc(0).get()), -                                   static_cast(output_desc.GetShape().GetDims().size())); +                                   static_cast(input_desc.GetShape().GetDims().size())); } else { GELOGE(GE_GRAPH_ADD_PLC_END_FAILED, "[GraphPartitioner]: pld_op_desc is null."); return FAILED; } -  const char *buffer_optimize_on =
std::getenv("BUFFER_OPTIMIZE_ON"); - if (buffer_optimize_on == nullptr) { - // flush pld data type as original data type - if (output_desc.GetOriginDataType() != DT_UNDEFINED) { - pld_op_desc->MutableOutputDesc(0)->SetDataType(output_desc.GetOriginDataType()); - } else { - GELOGW("Original data type of %s is undefined![data type is %s]", src_node->GetName().c_str(), - TypeUtils::DataTypeToSerialString(output_desc.GetDataType()).c_str()); - } - // flush pld format as original format - if (output_desc.GetOriginFormat() != FORMAT_RESERVED) { - pld_op_desc->MutableOutputDesc(0)->SetFormat(output_desc.GetOriginFormat()); - pld_op_desc->MutableOutputDesc(0)->SetShape(output_desc.GetOriginShape()); - } else { - GELOGW("Original format of %s is undefined![format is %s]", src_node->GetName().c_str(), - TypeUtils::FormatToSerialString(output_desc.GetFormat()).c_str()); - } - } return SUCCESS; } -Status ge::GraphPartitioner::UpdateEndOpDesc(const NodePtr &dst_node, int input_index, OpDescPtr &end_op_desc) { - if (dst_node == nullptr || end_op_desc == nullptr || dst_node->GetOpDesc() == nullptr) { +Status ge::GraphPartitioner::UpdateEndOpDesc(const NodePtr &src_node, int output_index, OpDescPtr &end_op_desc) { + if (src_node == nullptr || end_op_desc == nullptr || src_node->GetOpDesc() == nullptr) { GELOGE(FAILED, "parameter ptr is null."); return FAILED; } - const auto &input_desc = dst_node->GetOpDesc()->GetInputDesc(static_cast(input_index)); - GE_IF_BOOL_EXEC(end_op_desc->AddInputDesc(input_desc) != GRAPH_SUCCESS, GELOGE(FAILED, "AddInputDesc failed"); + const auto &output_desc = src_node->GetOpDesc()->GetOutputDesc(static_cast(output_index)); + GE_IF_BOOL_EXEC(end_op_desc->AddInputDesc(output_desc) != GRAPH_SUCCESS, GELOGE(FAILED, "AddInputDesc failed"); return FAILED;) if (end_op_desc->MutableInputDesc(0) != nullptr) { ge::TensorUtils::SetRealDimCnt(*(end_op_desc->MutableInputDesc(0).get()), - static_cast(input_desc.GetShape().GetDims().size())); + static_cast(output_desc.GetShape().GetDims().size())); } else { GELOGE(GE_GRAPH_ADD_PLC_END_FAILED, "[GraphPartitioner]: pld_op_desc is null."); return FAILED; } - const char *buffer_optimize_on = std::getenv("BUFFER_OPTIMIZE_ON"); - if (buffer_optimize_on == nullptr) { - // flush end data type as original data type - if (input_desc.GetOriginDataType() != DT_UNDEFINED) { - end_op_desc->MutableInputDesc(0)->SetDataType(input_desc.GetOriginDataType()); - } else { - GELOGI("Original data type of %s is undefined![data type is %s]", dst_node->GetName().c_str(), - TypeUtils::DataTypeToSerialString(input_desc.GetDataType()).c_str()); - } - // flush end format as original format - if (input_desc.GetOriginFormat() != FORMAT_RESERVED) { - end_op_desc->MutableInputDesc(0)->SetFormat(input_desc.GetOriginFormat()); - end_op_desc->MutableInputDesc(0)->SetShape(input_desc.GetOriginShape()); - } else { - GELOGW("Original format of %s is undefined![format is %s]", dst_node->GetName().c_str(), - TypeUtils::FormatToSerialString(input_desc.GetFormat()).c_str()); - } - } return SUCCESS; } @@ -350,18 +314,18 @@ graphStatus ge::GraphPartitioner::AddPlaceHolderEndInSrcDstGraph(const AnchorPtr GE_IF_BOOL_EXEC(!AttrUtils::SetStr(end_op_desc, "parentOpType", dst_node->GetType()), GELOGW("SetStr parentOpType failed");) // replace input_desc of end with owner node's desc - int input_index = ge::AnchorUtils::GetIdx(peer_in_anchor); - bool is_need_update_desc = (input_index >= 0) && (graph_info_.mode_ == kPartitioning); + int output_index = ge::AnchorUtils::GetIdx(out_anchor); + 
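// [Editorial sketch, not part of this patch] The parameter swap above means each
// End op now mirrors the producer's output tensor while each PlaceHolder mirrors
// the consumer's input tensor across a partition cut. With hypothetical minimal
// types, the pairing looks like this:
#include <vector>
struct MiniTensorDesc { std::vector<long> dims; };
struct MiniOpDesc { std::vector<MiniTensorDesc> inputs, outputs; };
void BuildCutPair(const MiniOpDesc &src, int output_index, const MiniOpDesc &dst,
                  int input_index, MiniOpDesc &end_op, MiniOpDesc &pld_op) {
  end_op.inputs.push_back(src.outputs[output_index]);  // End sits on the producer side
  pld_op.outputs.push_back(dst.inputs[input_index]);   // PlaceHolder on the consumer side
}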
bool is_need_update_desc = (output_index >= 0) && (graph_info_.mode_ == kPartitioning); if (is_need_update_desc) { -    if (UpdateEndOpDesc(dst_node, input_index, end_op_desc) != SUCCESS) { -      GELOGE(GRAPH_PARAM_INVALID, "UpdateEndOpDesc failed, input index %d, engine name is %s", input_index, +    if (UpdateEndOpDesc(src_node, output_index, end_op_desc) != SUCCESS) { +      GELOGE(GRAPH_PARAM_INVALID, "UpdateEndOpDesc failed, output index %d, engine name is %s", output_index, engine_end_name.c_str()); return FAILED; } } else { GeTensorDesc input_desc; if (end_op_desc->AddInputDesc(input_desc) != SUCCESS) { -      GELOGE(GRAPH_PARAM_INVALID, "AddInputDesc failed, input index %d, engine name is %s", input_index, +      GELOGE(GRAPH_PARAM_INVALID, "AddInputDesc failed, output index %d, engine name is %s", output_index, engine_end_name.c_str()); return FAILED; } @@ -402,11 +366,11 @@ graphStatus ge::GraphPartitioner::AddPlaceHolderEndInSrcDstGraph(const AnchorPtr // do not care over flow graph_info_.num_of_pld_end_++; // replace output_desc of pld with input node's output desc -  int output_index = ge::AnchorUtils::GetIdx(out_anchor); -  is_need_update_desc = (output_index >= 0) && (graph_info_.mode_ == kPartitioning); +  int input_index = ge::AnchorUtils::GetIdx(peer_in_anchor); +  is_need_update_desc = (input_index >= 0) && (graph_info_.mode_ == kPartitioning); if (is_need_update_desc) { -    if (UpdatePldOpDesc(src_node, output_index, pld_op_desc) != SUCCESS) { -      GELOGE(GRAPH_PARAM_INVALID, "UpdateEndOpDesc failed, output index %d, engine name is %s", output_index, +    if (UpdatePldOpDesc(dst_node, input_index, pld_op_desc) != SUCCESS) { +      GELOGE(GRAPH_PARAM_INVALID, "UpdatePldOpDesc failed, input index %d, engine name is %s", input_index, engine_pld_name.c_str()); return FAILED; } @@ -596,14 +560,14 @@ Status ge::GraphPartitioner::AddPartitionsToGraphNode(vectorGetName()); -    GraphUtils::DumpGEGraphToOnnx(*sub_graph, sub_graph->GetName()); +    GE_DUMP(sub_graph, sub_graph->GetName()); if (!session_graph_id.empty()) { GE_IF_BOOL_EXEC(!AttrUtils::SetStr(sub_graph, ATTR_NAME_SESSION_GRAPH_ID, session_graph_id), GELOGW("SetStr ATTR_NAME_SESSION_GRAPH_ID failed");) } // flush parent node of subgraph sub_graph->SetParentNode(compute_graph->GetParentNode()); +    (void)AttrUtils::SetStr(*sub_graph, ATTR_NAME_PARENT_GRAPH_NAME, compute_graph->GetName()); if (engine_name != input_subgraph_name) {  // do not add Data subGraph into SubGraphInfo auto sgi = MakeShared(); if (sgi == nullptr) { diff --git a/src/ge/graph/partition/graph_partition.h b/src/ge/graph/partition/graph_partition.h index 51cafb47..26592359 100644 --- a/src/ge/graph/partition/graph_partition.h +++ b/src/ge/graph/partition/graph_partition.h @@ -127,7 +127,7 @@ class GraphPartitioner { /// Split all sub graph and add placeholder, end according to marks /// traverse marked clusters and split them into sub-graphs Status SplitSubGraphs(ComputeGraphPtr compute_graph); -  Status UpdateEndOpDesc(const NodePtr &dst_node, int input_index, OpDescPtr &end_op_desc); +  Status UpdateEndOpDesc(const NodePtr &src_node, int output_index, OpDescPtr &end_op_desc); Status UpdatePldOpDesc(const NodePtr &dst_node, int input_index, OpDescPtr &end_op_desc); // Clear partition data diff --git a/src/ge/graph/passes/atomic_addr_clean_pass.cc b/src/ge/graph/passes/atomic_addr_clean_pass.cc index 63928c53..253ab775 100644 --- a/src/ge/graph/passes/atomic_addr_clean_pass.cc +++ b/src/ge/graph/passes/atomic_addr_clean_pass.cc @@ -36,26 +36,10 @@ namespace ge { namespace { bool GraphShouldBeSkip(const
ge::ComputeGraphPtr &graph) {  // Internal function, guaranteeing graph non-null -  auto parent = graph->GetParentGraph(); -  if (parent == nullptr) { +  if (graph->GetParentGraph() == nullptr) { return false; } -  for (NodePtr &node : graph->GetDirectNode()) { -    bool is_unknown = false; -    auto ret_status = NodeUtils::GetNodeUnknownShapeStatus(*node, is_unknown); -    if (ret_status != GRAPH_SUCCESS) { -      GELOGW("Get node unknown status failed, node name:%s, type:%s.", node->GetName().c_str(), -             node->GetType().c_str()); -      continue; -    } -    if (is_unknown) { -      GELOGI("Node %s, type %s is unknown shape, sub graph %s should be skip.", node->GetName().c_str(), -             node->GetType().c_str(), graph->GetName().c_str()); -      return true; -    } -  } -  GELOGI("Sub graph %s does not have unknown shape node, run the pass.", graph->GetName().c_str()); -  return false; +  return GraphUtils::IsUnknownShapeGraph(graph); } }  // namespace @@ -274,4 +258,12 @@ bool AtomicAddrCleanPass::IsAtomicOp(const NodePtr &node) { GELOGD("Recognized atomic op %s from FE engine.", op_desc->GetName().c_str()); return true; } +/// +/// @brief Clear Status, used for subgraph pass +/// @return SUCCESS +/// +Status AtomicAddrCleanPass::ClearStatus() { +  hcom_node_vec_.clear(); +  return SUCCESS; +} }  // namespace ge diff --git a/src/ge/graph/passes/atomic_addr_clean_pass.h b/src/ge/graph/passes/atomic_addr_clean_pass.h index a4dd2e72..d2d8f2ce 100644 --- a/src/ge/graph/passes/atomic_addr_clean_pass.h +++ b/src/ge/graph/passes/atomic_addr_clean_pass.h @@ -37,6 +37,7 @@ namespace ge { class AtomicAddrCleanPass : public GraphPass { public: Status Run(ComputeGraphPtr graph); +  Status ClearStatus() override; private: /** diff --git a/src/ge/graph/passes/base_pass.cc b/src/ge/graph/passes/base_pass.cc index 53025f6a..629b08ba 100644 --- a/src/ge/graph/passes/base_pass.cc +++ b/src/ge/graph/passes/base_pass.cc @@ -78,7 +78,7 @@ Status RunPasses(NodePtr &node, const NamesToPass &names_to_passes, std::unorder continue; } -    GELOGD("Begin to run pass %s", name_to_pass.first.c_str()); +    GELOGD("Begin to run pass %s for node %s", name_to_pass.first.c_str(), node->GetName().c_str()); name_to_pass.second->init(); auto result = name_to_pass.second->Run(node); if (result != SUCCESS) { diff --git a/src/ge/graph/passes/cast_remove_pass.cc b/src/ge/graph/passes/cast_remove_pass.cc index a0742a03..87caf6e4 100644 --- a/src/ge/graph/passes/cast_remove_pass.cc +++ b/src/ge/graph/passes/cast_remove_pass.cc @@ -124,7 +124,7 @@ Status CastRemovePass::RemoveCast(DataType &type, std::vector &nodes_to return SUCCESS; } -NodePtr CastRemovePass::GetTheEndNode(NodePtr &begin_node, std::vector &nodes_to_fuse) { +NodePtr CastRemovePass::GetTheEndNode(NodePtr begin_node, std::vector &nodes_to_fuse) { while (begin_node->GetOutDataNodes().size() == 1) { auto out_node = begin_node->GetOutDataNodes().at(0); if (!TransOpUtil::IsTransOp(out_node)) { diff --git a/src/ge/graph/passes/cast_remove_pass.h b/src/ge/graph/passes/cast_remove_pass.h index 53318cff..e889781f 100644 --- a/src/ge/graph/passes/cast_remove_pass.h +++ b/src/ge/graph/passes/cast_remove_pass.h @@ -28,7 +28,7 @@ class CastRemovePass : public BaseNodePass { private: bool HasSameDataType(OpDescPtr &begin_op_desc, OpDescPtr &end_op_desc, DataType &type) const; Status RemoveCast(DataType &type, std::vector &nodes_to_fuse); -  NodePtr GetTheEndNode(NodePtr &begin_node, std::vector &nodes_to_fuse); +  NodePtr GetTheEndNode(NodePtr begin_node, std::vector &nodes_to_fuse); }; }  // namespace ge #endif  //
GE_GRAPH_PASSES_CAST_REMOVE_PASS_H_ diff --git a/src/ge/graph/passes/cond_pass.cc b/src/ge/graph/passes/cond_pass.cc index 4052950a..6f47689b 100644 --- a/src/ge/graph/passes/cond_pass.cc +++ b/src/ge/graph/passes/cond_pass.cc @@ -18,17 +18,15 @@ #include "common/op/ge_op_utils.h" #include "graph/utils/graph_utils.h" #include "graph/utils/type_utils.h" +#include "graph/utils/node_utils.h" namespace { -const std::set kIfTypes = {ge::IF, ge::_IF, ge::STATELESSIF}; -const std::set kWhileTypes = {ge::WHILE, ge::_WHILE, ge::STATELESSWHILE}; const std::string kStringLength = "StringLength"; const size_t kScalarDimNum = 1; } // namespace namespace ge { Status CondPass::Run(NodePtr &node) { - GE_CHECK_NOTNULL(node); ComputeGraphPtr graph = nullptr; OutDataAnchorPtr cond_out_anchor = nullptr; InDataAnchorPtr cond_in_anchor = nullptr; @@ -41,7 +39,7 @@ Status CondPass::Run(NodePtr &node) { } /// cond - /// 1. NonScalar: cond->Shape->Shape(int32)->If / NetOutput(while) + /// 1. NonScalar: cond->Size(int32)->If / NetOutput(while) /// 2. String Scalar: cond->StringLength(int32)->If / NetOutput(while) /// 3. bool / float / double / uint8 / int16 / int8 / int64 Scalar: cond->Cast(2int32)->If / NetOutput(while) /// 4. Int32 Scalar: cond->If / NetOutput(while) @@ -100,18 +98,18 @@ Status CondPass::GetCondInfo(const NodePtr &node, ComputeGraphPtr &graph, OutDat InDataAnchorPtr &cond_in_anchor) { GE_CHECK_NOTNULL(node); std::string type = node->GetType(); - if (kIfTypes.count(type) != 0) { + if (kIfOpTypes.count(type) != 0) { if (GetCondInfoForIf(node, graph, cond_out_anchor, cond_in_anchor) != SUCCESS) { GELOGE(FAILED, "Get cond_info for if node failed."); return FAILED; } - } else if (kWhileTypes.count(type) != 0) { + } else if (kWhileOpTypes.count(type) != 0) { if (GetCondInfoForWhile(node, graph, cond_out_anchor, cond_in_anchor) != SUCCESS) { GELOGE(FAILED, "Get cond_info for while node failed."); return FAILED; } } else { - GELOGI("no need cond_pass for node %s.", node->GetName().c_str()); + GELOGD("no need cond_pass for node %s.", node->GetName().c_str()); return NOT_CHANGED; } @@ -180,7 +178,7 @@ Status CondPass::GetCondInfoForWhile(const NodePtr &node, ComputeGraphPtr &graph } /// -/// @brief Process Cond Op with non-scalar cond_input: cond->Shape->Shape->If / NetOutput(while) +/// @brief Process Cond Op with non-scalar cond_input: cond->Size->If / NetOutput(while) /// @param [in] graph /// @param [in] out_anchor: peer_cond_anchor /// @param [in] in_anchor: cond_input @@ -188,17 +186,8 @@ Status CondPass::GetCondInfoForWhile(const NodePtr &node, ComputeGraphPtr &graph /// Status CondPass::HandleNonScalarCond(const ComputeGraphPtr &graph, const OutDataAnchorPtr &out_anchor, const InDataAnchorPtr &in_anchor) { - if (InsertNode(graph, out_anchor, in_anchor, SHAPE) != SUCCESS) { - GELOGE(FAILED, "Insert first Shape node failed."); - return FAILED; - } - - if (InsertNode(graph, in_anchor->GetPeerOutAnchor(), in_anchor, SHAPE) != SUCCESS) { - GELOGE(FAILED, "Insert second Shape node failed."); - return FAILED; - } - - return SUCCESS; + GELOGI("Handle cond with non-scalar cond-input."); + return InsertNode(graph, out_anchor, in_anchor, SIZE); } /// @@ -266,17 +255,8 @@ Status CondPass::InsertNode(const ComputeGraphPtr &graph, const OutDataAnchorPtr GeTensorDesc out_tensor = in_anchor->GetOwnerNode()->GetOpDesc()->GetInputDesc(out_anchor->GetIdx()); out_tensor.SetDataType(DT_INT32); out_tensor.SetOriginDataType(DT_INT32); - if (type == SHAPE) { - int64_t size = static_cast(in_tensor.GetShape().GetDimNum()); 
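// [Editorial sketch, not part of this patch] The removed Shape->Shape chain is
// replaced by a single Size op, so a non-scalar cond reduces to "true iff the
// tensor has elements": Size emits a scalar int32 element count, and the assumed
// If/While semantics treat nonzero as true. Standalone illustration (helper name
// hypothetical):
#include <cstdint>
#include <vector>
int32_t NonScalarCondValue(const std::vector<int64_t> &dims) {
  int64_t num_elements = 1;
  for (int64_t d : dims) {
    num_elements *= d;  // Size == product of all dims
  }
  return static_cast<int32_t>(num_elements);  // any zero dim => 0 => false
}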
- if (size == kScalarDimNum) { - out_tensor.SetShape(GeShape()); - out_tensor.SetOriginShape(GeShape()); - } else { - std::vector size_v{size}; - out_tensor.SetShape(GeShape(size_v)); - out_tensor.SetOriginShape(GeShape(size_v)); - } - } + out_tensor.SetShape(GeShape()); + out_tensor.SetOriginShape(GeShape()); OpDescBuilder op_desc_builder(out_anchor->GetOwnerNode()->GetName() + "_" + type, type); OpDescPtr op_desc = op_desc_builder.AddInput("x", in_tensor).AddOutput("y", out_tensor).Build(); diff --git a/src/ge/graph/passes/cond_remove_pass.cc b/src/ge/graph/passes/cond_remove_pass.cc new file mode 100644 index 00000000..a5ba0a19 --- /dev/null +++ b/src/ge/graph/passes/cond_remove_pass.cc @@ -0,0 +1,336 @@ +/** + * Copyright 2019-2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "graph/passes/cond_remove_pass.h" +#include "common/op/ge_op_utils.h" +#include "graph/utils/graph_utils.h" +#include "graph/utils/type_utils.h" +#include "graph/utils/node_utils.h" + +namespace { +const uint32_t kConditionIndexNum = 1; +const uint32_t kElseBranchIndex = 1; +const uint32_t kTrueIndex = 1; +const uint32_t kFalseIndex = 0; +/// Extra 8 bytes store pointer of string +/// Extra 1 byte store '\0' +const int32_t kStrHeadLen = 9; +const int32_t kInvalidRetVal = -1; +} // namespace + +namespace ge { +Status CondRemovePass::Run(NodePtr &node) { + GE_CHECK_NOTNULL(node); + ComputeGraphPtr graph = nullptr; + OutDataAnchorPtr cond_out_anchor = nullptr; + InDataAnchorPtr cond_in_anchor = nullptr; + Status ret = GetCondInfo(node, graph, cond_out_anchor, cond_in_anchor); + int32_t cond_index = 0; + GELOGD("Handle cond remove for node %s.", node->GetOpDesc()->GetName().c_str()); + bool if_cond_const = CheckIfCondConstInput(cond_out_anchor, cond_in_anchor, cond_index); + if (!if_cond_const || (cond_index < 0)) { + return ge::SUCCESS; + } + ComputeGraphPtr chosen_graph = nullptr; + const std::string &node_type = node->GetType(); + // Keep chosen branch + if (kIfOpTypes.count(node_type) != 0) { + ret = GetIfChosenBranch(node, static_cast(cond_index), chosen_graph); + if (ret != ge::SUCCESS) { + return ge::FAILED; + } + } else if (kCaseOpTypes.count(node_type) != 0) { + ret = GetCaseChosenBranch(node, static_cast(cond_index), chosen_graph); + if (ret != ge::SUCCESS) { + return ge::FAILED; + } + } else { + return ge::SUCCESS; + } + // Remove unused link from cond->node + ret = RemoveDeadCondLink(static_cast(IF_COND_INPUT), node); + if (ret != ge::SUCCESS) { + return ge::FAILED; + } + // Copy If/Case node's relations to the new node + ret = ReplaceIfCaseNodeWithPartitioncall(node, chosen_graph); + if (ret != ge::SUCCESS) { + return ge::FAILED; + } + // Isolate and delete the old node + ret = IsolateAndDeleteNode(node, std::vector()); + return ret; +} + +Status CondRemovePass::RemoveDeadCondLink(const int32_t index, const NodePtr &node) { + const auto &in_anchor = node->GetInDataAnchor(index); + const auto &peerout_anchor = in_anchor->GetPeerOutAnchor(); + if 
(GraphUtils::RemoveEdge(peerout_anchor, in_anchor) != SUCCESS) { +    GELOGE(FAILED, "Remove edge from node %s index %d to node %s index %d.", +           peerout_anchor->GetOwnerNode()->GetName().c_str(), peerout_anchor->GetIdx(), +           in_anchor->GetOwnerNode()->GetName().c_str(), in_anchor->GetIdx()); +    return FAILED; +  } +  return SUCCESS; +} + +Status CondRemovePass::GetCaseChosenBranch(const NodePtr &node, const uint32_t cond_index, +                                           ComputeGraphPtr &compute_graph) { +  uint32_t subgraph_names_size = static_cast(node->GetOpDesc()->GetSubgraphInstanceNames().size()); +  uint32_t cond_index_new = cond_index; +  if (subgraph_names_size == 0) { +    GELOGE(FAILED, "Node %s has no subgraphs.", node->GetName().c_str()); +    return ge::FAILED; +  } +  // If cond index is over the maximum subgraph number, choose the last subgraph +  if (cond_index >= subgraph_names_size) { +    cond_index_new = subgraph_names_size - 1; +  } +  const auto &chosen_branch_name = node->GetOpDesc()->GetSubgraphInstanceName(cond_index_new); +  if (chosen_branch_name.empty()) { +    GELOGE(FAILED, "Node %s has no subgraph, index is %u.", node->GetName().c_str(), cond_index_new); +    return ge::FAILED; +  } +  auto chosen_graph = GraphUtils::FindRootGraph(node->GetOwnerComputeGraph())->GetSubgraph(chosen_branch_name); +  compute_graph = chosen_graph; +  // Remove graph from node, in order to remove the connection from this node to the chosen branch +  node->GetOpDesc()->RemoveSubgraphInstanceName(chosen_branch_name); +  return ge::SUCCESS; +} + +Status CondRemovePass::GetIfChosenBranch(const NodePtr &node, const uint32_t cond, ComputeGraphPtr &compute_graph) { +  uint32_t subgraph_names_size = static_cast(node->GetOpDesc()->GetSubgraphInstanceNames().size()); +  uint32_t cond_index_new = 0; +  if (subgraph_names_size == 0) { +    GELOGE(FAILED, "Node %s has no subgraphs.", node->GetName().c_str()); +    return ge::FAILED; +  } +  // If cond is false, choose the else branch +  if (cond == 0) { +    cond_index_new = kElseBranchIndex; +  } +  const auto &chosen_branch_name = node->GetOpDesc()->GetSubgraphInstanceName(cond_index_new); +  if (chosen_branch_name.empty()) { +    GELOGE(FAILED, "Node %s has no subgraph, index is %u.", node->GetName().c_str(), cond_index_new); +    return ge::FAILED; +  } +  auto chosen_graph = GraphUtils::FindRootGraph(node->GetOwnerComputeGraph())->GetSubgraph(chosen_branch_name); +  if (chosen_graph == nullptr) { +    GELOGE(FAILED, "Can not find branch %s in node %s's parent graph %s.", chosen_branch_name.c_str(), +           node->GetName().c_str(), node->GetOwnerComputeGraph()->GetName().c_str()); +    return ge::FAILED; +  } +  compute_graph = chosen_graph; +  // Remove graph from node, in order to remove the connection from this node to the chosen branch +  node->GetOpDesc()->RemoveSubgraphInstanceName(chosen_branch_name); +  return ge::SUCCESS; +} + +int32_t CondRemovePass::GetCondIndex(const ConstGeTensorPtr &tensor) { +  if (tensor == nullptr) { +    return kInvalidRetVal; +  } +  const uint8_t *data_ptr = tensor->GetData().data(); +  size_t tensor_size = tensor->GetData().size(); +  const auto type = tensor->GetTensorDesc().GetDataType(); +  GELOGD("Data type is %d, tensor_size is %zu.", type, tensor_size); +  switch (type) { +    case DT_STRING: +      return static_cast(((tensor_size - kStrHeadLen) > 0) ?
kTrueIndex : kFalseIndex); +    case DT_BOOL: +      return static_cast(*reinterpret_cast(data_ptr)); +    case DT_FLOAT: +      return static_cast(*reinterpret_cast(data_ptr)); +    case DT_DOUBLE: +      return static_cast(*reinterpret_cast(data_ptr)); +    case DT_INT8: +    case DT_UINT8: +      return static_cast(*data_ptr); +    case DT_FLOAT16: +    case DT_INT16: +    case DT_UINT16: +      return static_cast(*reinterpret_cast(data_ptr)); +    case DT_INT32: +      return static_cast(*reinterpret_cast(data_ptr)); +    case DT_UINT32: +      return *reinterpret_cast(data_ptr); +    case DT_INT64: +    case DT_UINT64: +      return static_cast(*reinterpret_cast(data_ptr)); +    default: +      return static_cast(*data_ptr); +  } +} + +bool CondRemovePass::CheckIfCondConstInput(const OutDataAnchorPtr &cond_out_anchor, +                                           const InDataAnchorPtr &cond_in_anchor, int32_t &cond_index) { +  // if pre or next anchor is null, return +  CHECK_FALSE_EXEC(cond_out_anchor != nullptr, return false); +  CHECK_FALSE_EXEC(cond_in_anchor != nullptr, return false); +  const auto &out_node = cond_out_anchor->GetOwnerNode(); +  const auto &cur_node = cond_in_anchor->GetOwnerNode(); +  OpDescPtr op_desc = cur_node->GetOpDesc(); +  GE_CHECK_NOTNULL_EXEC(op_desc, return false); +  GeTensorDesc cond_tensor = out_node->GetOpDesc()->GetOutputDesc(static_cast(cond_out_anchor->GetIdx())); +  GELOGI("Check if condition is const for node %s.", op_desc->GetName().c_str()); +  if (kConstOpTypes.count(out_node->GetOpDesc()->GetType()) == 0) { +    return false; +  } +  // Case node only supports int32 input +  if ((kCaseOpTypes.count(cur_node->GetType()) != 0) && (cond_tensor.GetDataType() != DT_INT32)) { +    GELOGW("Check input failed, node is %s, condition datatype is %s.", op_desc->GetName().c_str(), +           TypeUtils::DataTypeToSerialString(cond_tensor.GetDataType()).c_str()); +    return false; +  } +  // Get weights from peer node +  auto weights = OpDescUtils::GetWeights(out_node); +  if (weights.size() <= static_cast(cond_out_anchor->GetIdx())) { +    GELOGI("Get weights of node %s out index %d, weight size %zu does not cover data index %d.", +           out_node->GetName().c_str(), cond_out_anchor->GetIdx(), weights.size(), cond_out_anchor->GetIdx()); +    return false; +  } +  ConstGeTensorPtr tensor = weights[cond_out_anchor->GetIdx()]; +  GE_CHECK_NOTNULL_EXEC(tensor, return false); +  bool if_zero_dim = false; +  if (!cond_tensor.GetShape().IsScalar()) { +    for (size_t dim = 0; dim < cond_tensor.GetShape().GetDimNum(); dim++) { +      if (cond_tensor.GetShape().GetDim(dim) == 0) { +        if_zero_dim = true; +        break; +      } +    } +    // If dim num is not zero and there is no zero dim, index is 1; else index is 0 +    cond_index = static_cast((cond_tensor.GetShape().GetDimNum() != 0) && !if_zero_dim); +  } else { +    // Get condition index +    cond_index = GetCondIndex(tensor); +  } +  GELOGD("Condition index is %d, node name is %s, anchor index is %d, dim num is %zu, zero dim flag %d", cond_index, +         op_desc->GetName().c_str(), cond_out_anchor->GetIdx(), cond_tensor.GetShape().GetDimNum(), if_zero_dim); +  return true; +} + +Status CondRemovePass::ReplaceIfCaseNodeWithPartitioncall(const NodePtr &node, const ComputeGraphPtr &save_branch) { +  // Add compute graph to new node +  const auto &input_anchors = node->GetAllInAnchors(); +  const auto &output_anchors = node->GetAllOutAnchors(); +  // Create subgraph opdesc & node +  auto partitioncall_opdesc = +    CreateSubgraphOpDesc(save_branch->GetName(), input_anchors.size() - kConditionIndexNum, output_anchors.size()); +  auto partitioncall_node = node->GetOwnerComputeGraph()->AddNode(partitioncall_opdesc); +  // Link node's peerout
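// [Editorial sketch, not part of this patch] GetCondIndex above decodes the const
// cond tensor by reinterpreting its raw bytes per dtype; for the common int32 case
// a memcpy-based read avoids the alignment pitfalls of reinterpret_cast (helper
// name hypothetical):
#include <cstdint>
#include <cstring>
int32_t ReadScalarInt32(const uint8_t *data_ptr) {
  int32_t value = 0;
  std::memcpy(&value, data_ptr, sizeof(value));  // byte-wise copy, alignment-safe
  return value;  // nonzero selects the true branch
}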
anchors to new node's inanchors +  for (const auto &input_anchor : input_anchors) { +    for (const auto &peerout_anchor : input_anchor->GetPeerAnchors()) { +      if (GraphUtils::AddEdge(peerout_anchor, partitioncall_node->GetInAnchor( +                                                input_anchor->GetIdx() - kConditionIndexNum)) != ge::GRAPH_SUCCESS) { +        GELOGE(FAILED, "Add edge failed, from node:%s idx:%d to node:%s idx:%d, input num:%zu, output num:%zu", +               peerout_anchor->GetOwnerNode()->GetName().c_str(), peerout_anchor->GetIdx(), +               partitioncall_node->GetName().c_str(), input_anchor->GetIdx(), input_anchors.size(), +               output_anchors.size()); +        return FAILED; +      } +    } +  } +  // Remove If / Case anchor and peer in anchor +  // Link new node's out anchors to node's peer inanchors +  for (const auto &output_anchor : output_anchors) { +    for (const auto &peerin_anchor : output_anchor->GetPeerAnchors()) { +      if (GraphUtils::RemoveEdge(node->GetOutAnchor(output_anchor->GetIdx()), peerin_anchor) != ge::GRAPH_SUCCESS) { +        GELOGE(FAILED, "Remove edge failed, from node:%s idx:%d to node:%s idx:%d, input num:%zu, output num:%zu", +               node->GetName().c_str(), output_anchor->GetIdx(), peerin_anchor->GetOwnerNode()->GetName().c_str(), +               peerin_anchor->GetIdx(), input_anchors.size(), output_anchors.size()); +        return FAILED; +      } +      if (GraphUtils::AddEdge(partitioncall_node->GetOutAnchor(output_anchor->GetIdx()), peerin_anchor) != +          ge::GRAPH_SUCCESS) { +        GELOGE(FAILED, "Add edge failed, from node:%s idx:%d to node:%s idx:%d, input num:%zu, output num:%zu", +               partitioncall_node->GetName().c_str(), output_anchor->GetIdx(), +               peerin_anchor->GetOwnerNode()->GetName().c_str(), peerin_anchor->GetIdx(), input_anchors.size(), +               output_anchors.size()); +        return FAILED; +      } +    } +  } +  // update save branch information +  std::map input_mapping; +  uint32_t new_input_num = static_cast(node->GetOpDesc()->GetAllInputsSize()) - kConditionIndexNum; +  for (uint32_t i = 0; i < new_input_num; i++) { +    // original index i+1 maps to new index i +    input_mapping[i + 1] = i; +  } +  save_branch->UpdateInputMapping(input_mapping); +  save_branch->SetParentNode(partitioncall_node); +  save_branch->SetParentGraph(node->GetOwnerComputeGraph()); +  return SUCCESS; +} + +/// +/// @brief Create op_desc for subgraph node +/// @param [in] name +/// @param [in] input_num +/// @param [in] output_num +/// @return OpDescPtr +/// +OpDescPtr CondRemovePass::CreateSubgraphOpDesc(const std::string &name, size_t input_num, size_t output_num) { +  OpDescBuilder op_desc_builder(name, PARTITIONEDCALL); +  op_desc_builder.AddDynamicInput("args", input_num).AddDynamicOutput("output", output_num); + +  OpDescPtr op_desc = op_desc_builder.Build(); +  GE_CHECK_NOTNULL_EXEC(op_desc, return nullptr); + +  size_t index = op_desc->GetSubgraphInstanceNames().size(); +  op_desc->AddSubgraphName("f"); +  op_desc->SetSubgraphInstanceName(static_cast(index), name); +  return op_desc; +} + +/// +/// @brief Get cond info for if/case node +/// @param [in] node: If/Case op +/// @param [out] graph: owner_graph of if node +/// @param [out] cond_out_anchor: peer_cond_anchor +/// @param [out] cond_in_anchor: cond_input of if +/// @return Status +/// +Status CondRemovePass::GetCondInfoForIfCase(const NodePtr &node, ComputeGraphPtr &graph, +                                            OutDataAnchorPtr &cond_out_anchor, InDataAnchorPtr &cond_in_anchor) { +  GE_CHECK_NOTNULL(node); +  graph = node->GetOwnerComputeGraph(); +  GE_CHECK_NOTNULL(graph); +  cond_in_anchor = node->GetInDataAnchor(IF_COND_INPUT); +  GE_CHECK_NOTNULL(cond_in_anchor); +  cond_out_anchor = cond_in_anchor->GetPeerOutAnchor(); +
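// [Editorial sketch, not part of this patch] Dropping the cond input (index 0)
// shifts every surviving data input left by one; the input_mapping loop above
// records exactly that for the kept branch. Standalone form:
#include <cstdint>
#include <map>
std::map<uint32_t, uint32_t> BuildInputMapping(uint32_t new_input_num) {
  std::map<uint32_t, uint32_t> mapping;
  for (uint32_t i = 0; i < new_input_num; ++i) {
    mapping[i + 1] = i;  // old If/Case input i+1 -> PartitionedCall input i
  }
  return mapping;  // e.g. {1:0, 2:1, 3:2} for three data inputs
}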
GE_CHECK_NOTNULL(cond_out_anchor); + return SUCCESS; +} + +Status CondRemovePass::GetCondInfo(const NodePtr &node, ComputeGraphPtr &graph, OutDataAnchorPtr &cond_out_anchor, + InDataAnchorPtr &cond_in_anchor) { + GE_CHECK_NOTNULL(node); + std::string type = node->GetType(); + if ((kIfOpTypes.count(type) != 0) || (kCaseOpTypes.count(type) != 0)) { + if (GetCondInfoForIfCase(node, graph, cond_out_anchor, cond_in_anchor) != SUCCESS) { + GELOGE(FAILED, "Get cond_info for if node failed."); + return FAILED; + } + } else { + GELOGI("no need cond_pass for node %s.", node->GetName().c_str()); + return NOT_CHANGED; + } + + return SUCCESS; +} +} // namespace ge diff --git a/src/ge/graph/passes/cond_remove_pass.h b/src/ge/graph/passes/cond_remove_pass.h new file mode 100644 index 00000000..69dd7195 --- /dev/null +++ b/src/ge/graph/passes/cond_remove_pass.h @@ -0,0 +1,79 @@ +/** + * Copyright 2019-2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef GE_GRAPH_PASSES_COND_REMOVE_PASS_H +#define GE_GRAPH_PASSES_COND_REMOVE_PASS_H + +#include "graph/passes/base_pass.h" + +namespace ge { +class CondRemovePass : public BaseNodePass { + public: + Status Run(NodePtr &node) override; + + private: + /// + /// @brief Get cond info for if/case node + /// @param [in] node: If/Case op + /// @param [out] graph: owner_graph of if node + /// @param [out] cond_out_anchor: peer_cond_anchor + /// @param [out] cond_in_anchor: cond_input of if + /// @return Status + /// + Status GetCondInfoForIfCase(const NodePtr &node, ComputeGraphPtr &graph, OutDataAnchorPtr &cond_out_anchor, + InDataAnchorPtr &cond_in_anchor); + /// + /// @brief Get cond info for if/case node + /// @param [in] node: If/Case op + /// @param [out] graph: owner_graph of if node + /// @param [out] cond_out_anchor: peer_cond_anchor + /// @param [out] cond_in_anchor: cond_input of if + /// @return Status + /// + Status GetCondInfo(const NodePtr &node, ComputeGraphPtr &graph, OutDataAnchorPtr &cond_out_anchor, + InDataAnchorPtr &cond_in_anchor); + /// + /// @brief Check if condition input is const, for if / case / while + /// + bool CheckIfCondConstInput(const OutDataAnchorPtr &cond_out_anchor, const InDataAnchorPtr &cond_in_anchor, + int32_t &cond_index); + + /// + /// @brief Remove if dead branch, for if + /// + Status GetIfChosenBranch(const NodePtr &node, const uint32_t cond_index, ComputeGraphPtr &compute_graph); + + /// + /// @brief Remove if dead branch, for case + /// + Status GetCaseChosenBranch(const NodePtr &node, const uint32_t cond_index, ComputeGraphPtr &compute_graph); + + /// + /// @brief Remove dead condition input, for if / case / while + /// + Status RemoveDeadCondLink(const int32_t index, const NodePtr &node); + + /// + /// @brief Remove if dead branch, for if + /// + Status ReplaceIfCaseNodeWithPartitioncall(const NodePtr &node, const ComputeGraphPtr &save_branch); + + OpDescPtr CreateSubgraphOpDesc(const std::string &name, size_t input_num, size_t output_num); + + int32_t 
GetCondIndex(const ConstGeTensorPtr &tensor); +}; +}  // namespace ge +#endif  // GE_GRAPH_PASSES_COND_REMOVE_PASS_H diff --git a/src/ge/graph/passes/constant_fuse_same_pass.cc b/src/ge/graph/passes/constant_fuse_same_pass.cc index 69726e5d..4197f429 100644 --- a/src/ge/graph/passes/constant_fuse_same_pass.cc +++ b/src/ge/graph/passes/constant_fuse_same_pass.cc @@ -22,10 +22,10 @@ #include #include -#include "graph/debug/ge_attr_define.h" +#include "common/ge/ge_util.h" #include "framework/common/debug/ge_log.h" #include "framework/common/ge_inner_error_codes.h" -#include "common/ge/ge_util.h" +#include "graph/debug/ge_attr_define.h" #include "graph/utils/op_desc_utils.h" #include "graph/utils/type_utils.h" @@ -177,6 +177,8 @@ Status ConstantFuseSamePass::FuseConstNodes(ComputeGraphPtr &graph,   auto first_node = nodes.at(0);   for (size_t i = 1; i < len; ++i) {     auto node = nodes.at(i); + +    GELOGI("Replace redundant const node %s by %s", node->GetName().c_str(), first_node->GetName().c_str());     // the const node which can be fused has none input(both data and control in)     if (GraphUtils::MoveOutCtrlEdges(node, first_node) != SUCCESS) {       return FAILED;     } diff --git a/src/ge/graph/passes/control_trigger_pass.cc b/src/ge/graph/passes/control_trigger_pass.cc index b1218d9f..77fcbd69 100644 --- a/src/ge/graph/passes/control_trigger_pass.cc +++ b/src/ge/graph/passes/control_trigger_pass.cc @@ -30,7 +30,6 @@ namespace ge { Status ControlTriggerPass::Run(ComputeGraphPtr graph) {   GELOGD("ControlTriggerPass Enter"); -   for (NodePtr &node : graph->GetDirectNode()) {     if (node->GetType() != CONTROLTRIGGER) {       continue;     } @@ -444,4 +443,13 @@ Status ControlTriggerPass::FindPredInput(const NodePtr &switch_node) {   switch_cond_map_[switch_node] = pred_cond_anchor->GetOwnerNode();   return SUCCESS; } +/// +/// @brief Clear Status, used for subgraph pass +/// @return SUCCESS +/// +Status ControlTriggerPass::ClearStatus() { +  switch_cond_map_.clear(); +  control_trigger_map_.clear(); +  return SUCCESS; +} }  // namespace ge diff --git a/src/ge/graph/passes/control_trigger_pass.h b/src/ge/graph/passes/control_trigger_pass.h index b9fff9b4..44d11cad 100644 --- a/src/ge/graph/passes/control_trigger_pass.h +++ b/src/ge/graph/passes/control_trigger_pass.h @@ -30,6 +30,7 @@ enum ControlNodeType { kNotControlOp, kCondSwitch, kCondMerge, kLoopSwitchT, kLo class ControlTriggerPass : public GraphPass {  public:   Status Run(ComputeGraphPtr graph); +  Status ClearStatus() override;  private:   Status HandleDynamicCtrlEdges(ComputeGraphPtr &graph, NodePtr &node, NodePtr &in_ctrl_node); diff --git a/src/ge/graph/passes/ctrl_edge_transfer_pass.cc b/src/ge/graph/passes/ctrl_edge_transfer_pass.cc new file mode 100644 index 00000000..9454c00d --- /dev/null +++ b/src/ge/graph/passes/ctrl_edge_transfer_pass.cc @@ -0,0 +1,77 @@ +/** + * Copyright 2019-2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License.
+ */ + +#include "graph/passes/ctrl_edge_transfer_pass.h" + +#include "framework/common/debug/ge_log.h" +#include "framework/common/ge_inner_error_codes.h" +#include "framework/common/util.h" +#include "graph/utils/graph_utils.h" + +namespace ge { +/* Pass Explanation: + * + * After optimizations such as constant folding, the following ctrl relationship may form. + * A scenario like this is unreasonable, and with unknown shapes it leads to errors because + * a constant does not generate a task. So once the graph is stable, transfer the ctrl edge + * to the next op to guarantee the timing relationship: + * + * A(ctrl edge)----constant------(ctrl edge)B or A(ctrl edge)----constant-----(data edge)B + * + * After processing, it looks as follows: + * + *  A   constant + *   \   / + *    B + */ + +Status CtrlEdgeTransferPass::Run(ge::ComputeGraphPtr graph) { +  GELOGD("CtrlEdgeTransferPass start running"); +  GE_CHECK_NOTNULL(graph); + +  for (ge::NodePtr &n : graph->GetDirectNode()) { +    auto op_desc = n->GetOpDesc(); +    if (op_desc == nullptr) { +      continue; +    } +    auto op_type = op_desc->GetType(); +    if (op_type == CONSTANT || op_type == CONSTANTOP) { +      if (n->GetInAllNodes().empty()) { +        GELOGD("[CtrlEdgeTransferPass] node [%s] in nodes is empty", n->GetName().c_str()); +        continue; +      } + +      GELOGD("start to transfer ctrl edge for const node [%s]", n->GetName().c_str()); + +      for (auto &in_control_node : n->GetInControlNodes()) { +        GE_CHECK_NOTNULL(in_control_node); +        GE_CHK_STATUS_RET(ge::GraphUtils::RemoveEdge(in_control_node->GetOutControlAnchor(), n->GetInControlAnchor()), +                          "remove edge failed"); +        for (auto &out_node : n->GetOutNodes()) { +          if (out_node == nullptr) { +            continue; +          } +          GE_CHK_STATUS_RET( +            ge::GraphUtils::AddEdge(in_control_node->GetOutControlAnchor(), out_node->GetInControlAnchor()), +            "add edge failed."); +        } +      } +    } +  } +  GELOGD("CtrlEdgeTransferPass running success!"); +  return SUCCESS; +} +}  // namespace ge diff --git a/src/ge/graph/passes/ctrl_edge_transfer_pass.h b/src/ge/graph/passes/ctrl_edge_transfer_pass.h new file mode 100644 index 00000000..ee981012 --- /dev/null +++ b/src/ge/graph/passes/ctrl_edge_transfer_pass.h @@ -0,0 +1,28 @@ +/** + * Copyright 2019-2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef GE_GRAPH_PASSES_CTRL_EDGE_TRANSFER_PASS_H_ +#define GE_GRAPH_PASSES_CTRL_EDGE_TRANSFER_PASS_H_ + +#include "inc/graph_pass.h" + +namespace ge { +class CtrlEdgeTransferPass : public GraphPass { + public: +  Status Run(ge::ComputeGraphPtr graph) override; +}; +}  // namespace ge +#endif  // GE_GRAPH_PASSES_CTRL_EDGE_TRANSFER_PASS_H_ diff --git a/src/ge/graph/passes/data_pass.cc b/src/ge/graph/passes/data_pass.cc new file mode 100644 index 00000000..517e7737 --- /dev/null +++ b/src/ge/graph/passes/data_pass.cc @@ -0,0 +1,83 @@ +/** + * Copyright 2019-2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "graph/passes/data_pass.h" + +#include + +#include "framework/common/debug/ge_log.h" +#include "graph/utils/graph_utils.h" +#include "register/op_registry.h" + +namespace ge { +Status DataPass::Run(ComputeGraphPtr compute_graph) { + GE_CHECK_NOTNULL(compute_graph); + if (compute_graph->GetParentNode() == nullptr) { // for subgraph post process. + return SUCCESS; + } + + for (const NodePtr &node : compute_graph->GetDirectNode()) { + GE_CHECK_NOTNULL(node->GetOpDesc()); + if (node->GetType() == DATA) { + uint32_t parent_index = 0; + if (!AttrUtils::GetInt(node->GetOpDesc(), ATTR_NAME_PARENT_NODE_INDEX, parent_index)) { + break; // parent_index not set, Graph from IR. + } + + return SUCCESS; // Graph from Parser. + } + } + + std::string subgraph_name; + const auto &parent_node = compute_graph->GetParentNode(); + GE_CHECK_NOTNULL(parent_node->GetOpDesc()); + auto func_desc = parent_node->GetOpDesc(); + GE_CHK_STATUS_RET(func_desc->GetSubgraphNameByInstanceName(compute_graph->GetName(), subgraph_name), + "Subgraph: %s get subgraph name failed.", compute_graph->GetName().c_str()); + + GELOGI("Post process for subgraph %s, Subgraph name: %s, Parent name: %s, Parent type: %s.", + compute_graph->GetName().c_str(), subgraph_name.c_str(), parent_node->GetName().c_str(), + parent_node->GetType().c_str()); + + const auto &parent_graph = compute_graph->GetParentGraph(); + GE_CHECK_NOTNULL(parent_graph); + for (const NodePtr &node : compute_graph->GetDirectNode()) { + GE_CHECK_NOTNULL(node->GetOpDesc()); + if ((node->GetType() == VARIABLE) || (node->GetType() == VARIABLEV2) || (node->GetType() == NETOUTPUT)) { + continue; + } + + node->GetOpDesc()->SetName(parent_node->GetName() + "_" + compute_graph->GetName() + "/" + node->GetName()); + } + + auto post_func = domi::OpRegistry::Instance()->GetParseSubgraphPostFunc(parent_node->GetType()); + if (post_func == nullptr) { + GELOGW("The subgraph post func for node %s type %s is null.", parent_node->GetName().c_str(), + parent_node->GetType().c_str()); + return SUCCESS; + } + + auto graph = GraphUtils::CreateGraphFromComputeGraph(compute_graph); + auto ret = post_func(subgraph_name, graph); + if (ret != SUCCESS) { + GELOGE(FAILED, "Failed to post-process subgraph %s on node %s type %s", graph.GetName().c_str(), + parent_node->GetName().c_str(), parent_node->GetType().c_str()); + return FAILED; + } + + return SUCCESS; +} +} // namespace ge diff --git a/third_party/fwkacllib/inc/ops/decode_bbox.h b/src/ge/graph/passes/data_pass.h similarity index 60% rename from third_party/fwkacllib/inc/ops/decode_bbox.h rename to src/ge/graph/passes/data_pass.h index 9fe95488..1f6d0f0b 100644 --- a/third_party/fwkacllib/inc/ops/decode_bbox.h +++ b/src/ge/graph/passes/data_pass.h @@ -14,20 +14,17 @@ * limitations under the License. 
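// [Editorial sketch, not part of this patch] The renaming step in DataPass::Run
// above qualifies each subgraph node with its parent node and subgraph name; the
// example names below are hypothetical:
#include <string>
std::string QualifySubgraphNodeName(const std::string &parent_node, const std::string &subgraph,
                                    const std::string &node) {
  return parent_node + "_" + subgraph + "/" + node;  // e.g. "my_if_then_branch/add"
}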
*/ - #ifndef GE_OP_DECODE_BBOX_H - #define GE_OP_DECODE_BBOX_H +#ifndef GE_GRAPH_PASSES_DATA_PASS_H_ +#define GE_GRAPH_PASSES_DATA_PASS_H_ - #include "graph/operator_reg.h" +#include "graph/graph.h" +#include "inc/graph_pass.h" - namespace ge { +namespace ge { +class DataPass : public GraphPass { + public: + Status Run(ge::ComputeGraphPtr graph); +}; +} // namespace ge - REG_OP(DecodeBbox) - .INPUT(box_predictions, TensorType{DT_FLOAT16}) - .INPUT(anchors, TensorType{DT_FLOAT16}) - .OUTPUT(decoded_boxes, TensorType{DT_FLOAT16}) - .REQUIRED_ATTR(decode_clip, Float) - .OP_END_FACTORY_REG(DecodeBbox) - - } // namespace ge - - #endif // GE_OP_DECODE_BBOX_H +#endif // GE_GRAPH_PASSES_DATA_PASS_H_ diff --git a/src/ge/graph/passes/dimension_adjust_pass.cc b/src/ge/graph/passes/dimension_adjust_pass.cc index 28ebbb83..a734ddc3 100644 --- a/src/ge/graph/passes/dimension_adjust_pass.cc +++ b/src/ge/graph/passes/dimension_adjust_pass.cc @@ -45,6 +45,7 @@ Status DimensionAdjustPass::Run(ge::NodePtr &node) { GELOGE(ret, "DimensionAdjustPass get originnal type fail."); return ret; } + KernelFactory &factory = KernelFactory::Instance(); shared_ptr op_kernel = factory.Create(type); if (op_kernel == nullptr) { diff --git a/src/ge/graph/passes/flow_ctrl_pass.cc b/src/ge/graph/passes/flow_ctrl_pass.cc index a8c20a79..fb05ca6a 100644 --- a/src/ge/graph/passes/flow_ctrl_pass.cc +++ b/src/ge/graph/passes/flow_ctrl_pass.cc @@ -38,7 +38,7 @@ Status FlowCtrlPass::Run(ComputeGraphPtr compute_graph) { } if (!PassUtils::IsNeedTrainIteFlowCtrl(compute_graph)) { - GELOGI("No need FlowCtrl"); + GELOGI("No need FlowCtrl for graph %u", compute_graph->GetGraphID()); return NOT_CHANGED; } @@ -251,6 +251,7 @@ NodePtr FlowCtrlPass::InsertAssignOp(ge::ComputeGraphPtr &compute_graph, const s GELOGE(FAILED, "Add value_node to %s edge failed, add_ret=%u.", node_name.c_str(), add_ret); return nullptr; } + (void)ge::AttrUtils::SetBool(assign_node->GetOpDesc(), ATTR_NEED_COMPILE, true); return assign_node; } diff --git a/src/ge/graph/passes/folding_pass.cc b/src/ge/graph/passes/folding_pass.cc index 41528ec3..4e51f1ca 100644 --- a/src/ge/graph/passes/folding_pass.cc +++ b/src/ge/graph/passes/folding_pass.cc @@ -142,6 +142,10 @@ Status FoldingPass::Folding(NodePtr &node, vector &outputs) { for (auto iter = in_data_nodes_set.begin(); iter != in_data_nodes_set.end(); ++iter) { auto pre_node = *iter; if (pre_node->GetOutDataNodesSize() == 0) { + if (pre_node->GetType() == DATA) { + GELOGI("No need to remove data, node name:%s.", pre_node->GetName().c_str()); + continue; + } if (IsolateAndDeleteNode(pre_node, {}) != SUCCESS) { GELOGE(INTERNAL_ERROR, "Failed to isolate and delete in data node %s, type %s.", pre_node->GetName().c_str(), pre_node->GetType().c_str()); diff --git a/src/ge/graph/passes/for_pass.cc b/src/ge/graph/passes/for_pass.cc index f63e8627..d9a17509 100644 --- a/src/ge/graph/passes/for_pass.cc +++ b/src/ge/graph/passes/for_pass.cc @@ -29,26 +29,26 @@ namespace { const uint32_t kWhileIInputIndex = 0; -const uint32_t kWhileNInputIndex = 1; -const uint32_t kWhileStartInputIndex = 2; -const uint32_t kWhileDeltaInputIndex = 3; -const uint32_t kWhileDataInputIndex = 4; +const uint32_t kWhileAbsDeltaInputIndex = 1; +const uint32_t kWhileRangeInputIndex = 2; +const uint32_t kWhileStartInputIndex = 3; +const uint32_t kWhileDeltaInputIndex = 4; +const uint32_t kWhileDataInputIndex = 5; const uint32_t kSubgraphLoopVarInputIndex = 0; const uint32_t kSubgraphInputIndex = 1; -const uint32_t kWhileOutputIndex = 4; +const uint32_t 
kWhileOutputIndex = 5; const std::string kAbs = "Abs"; } // namespace namespace ge { Status ForPass::Run(NodePtr &node) { - GE_CHECK_NOTNULL(node); GE_CHECK_NOTNULL(node->GetOpDesc()); if (node->GetType() != FOR) { + GELOGD("no need for_pass for node %s.", node->GetName().c_str()); return SUCCESS; } GELOGI("Begin to transfer for_op to while_op, node:%s.", node->GetName().c_str()); - ComputeGraphPtr graph = node->GetOwnerComputeGraph(); GE_CHECK_NOTNULL(graph); ComputeGraphPtr root_graph = GraphUtils::FindRootGraph(graph); @@ -98,7 +98,7 @@ Status ForPass::BuildForInfo(const ComputeGraphPtr &root_graph, const NodePtr &n OutDataAnchorPtr limit = FindInputWithIndex(node, FOR_LIMIT_INPUT); OutDataAnchorPtr delta = FindInputWithIndex(node, FOR_DELTA_INPUT); if ((start == nullptr) || (limit == nullptr) || (delta == nullptr)) { - GELOGE(FAILED, "BuildForInfo for %s failed: start / limit / delta is NULL.", node->GetName().c_str()); + GELOGE(FAILED, "BuildForInfo for %s failed: start/limit/delta is NULL.", node->GetName().c_str()); return FAILED; } @@ -107,7 +107,7 @@ Status ForPass::BuildForInfo(const ComputeGraphPtr &root_graph, const NodePtr &n std::vector ctrl_inputs; std::vector ctrl_outputs; if (FindInputsAndOutputs(node, data_inputs, data_outputs, ctrl_inputs, ctrl_outputs) != SUCCESS) { - GELOGE(FAILED, "BuildForInfo for %s failed: find inputs /outputs failed.", node->GetName().c_str()); + GELOGE(FAILED, "BuildForInfo for %s failed: find inputs/outputs failed.", node->GetName().c_str()); return FAILED; } NodeUtils::UnlinkAll(*node); @@ -235,20 +235,31 @@ Status ForPass::TranWhileInfo(const ComputeGraphPtr &graph, const ForInfo &for_i } AddRePassNode(i_node); - // Const node has and only has one output - OutDataAnchorPtr i_input = i_node->GetOutDataAnchor(0); + std::string identity_name = i_name + "_Identity"; + NodePtr identity_node = graph->AddNode(CreateOpDesc(identity_name, IDENTITY, true)); + // Const node has and only has one output, Identity node has and only has one input + if ((identity_node == nullptr) || + (GraphUtils::AddEdge(i_node->GetOutDataAnchor(0), identity_node->GetInDataAnchor(0)) != GRAPH_SUCCESS)) { + GELOGE(FAILED, "TranWhileInfo failed: Add data-edge %s:0->%s:0 failed.", i_name.c_str(), identity_name.c_str()); + return FAILED; + } + AddRePassNode(identity_node); + + // Identity node has and only has one output + OutDataAnchorPtr i_input = identity_node->GetOutDataAnchor(0); if (i_input == nullptr) { GELOGE(FAILED, "TranWhileInfo failed: i_input is NULL."); return FAILED; } - OutDataAnchorPtr n_input = CreateLoopCountInput(graph, for_info); - if (n_input == nullptr) { - GELOGE(FAILED, "TranWhileInfo failed: n_input is NULL."); + OutDataAnchorPtr range_input = nullptr; + OutDataAnchorPtr abs_delta_input = nullptr; + if (CreateLoopInput(graph, for_info, range_input, abs_delta_input) != SUCCESS) { + GELOGE(FAILED, "TranWhileInfo failed: create loop input failed."); return FAILED; } - BuildWhileInfo(for_info, i_input, n_input, while_info); + BuildWhileInfo(for_info, i_input, range_input, abs_delta_input, while_info); if (InsertWhileNode(graph, for_name + "_While", while_info) != SUCCESS) { GELOGE(FAILED, "TranWhileInfo failed: insert while node failed."); @@ -273,7 +284,7 @@ OpDescPtr ForPass::CreateConstDesc(const std::string &name, int32_t value) { return nullptr; } - GeTensorDesc data_desc(GeShape(), FORMAT_NCHW, DT_INT32); + GeTensorDesc data_desc(GeShape(), FORMAT_ND, DT_INT32); GeTensorPtr const_value = MakeShared(data_desc, reinterpret_cast(&value), 
sizeof(int32_t)); if (const_value == nullptr) { GELOGE(FAILED, "Create tensor failed, const:%s.", name.c_str()); @@ -294,12 +305,15 @@ OpDescPtr ForPass::CreateConstDesc(const std::string &name, int32_t value) { } /// -/// @brief Create loop_count node +/// @brief Create loop node /// @param [in] graph /// @param [in] for_info -/// @return OutDataAnchorPtr +/// @param [out] range_input +/// @param [out] abs_delta_input +/// @return Status /// -OutDataAnchorPtr ForPass::CreateLoopCountInput(const ComputeGraphPtr &graph, const ForInfo &for_info) { +Status ForPass::CreateLoopInput(const ComputeGraphPtr &graph, const ForInfo &for_info, OutDataAnchorPtr &range_input, + OutDataAnchorPtr &abs_delta_input) { std::string for_name = for_info.for_node->GetName(); GELOGD("Begin to create loop_count input, node:%s", for_name.c_str()); @@ -310,15 +324,8 @@ OutDataAnchorPtr ForPass::CreateLoopCountInput(const ComputeGraphPtr &graph, con std::string sub_name_0 = for_name + "_Sub_0"; std::string abs_name_0 = for_name + "_Abs_0"; std::string abs_name_1 = for_name + "_Abs_1"; - std::string add_name_0 = for_name + "_Add_0"; - std::string const_name = for_name + "_Const"; - std::string sub_name_1 = for_name + "_Sub_1"; - std::string cast_name_0 = for_name + "_Cast_0"; - std::string cast_name_1 = for_name + "_Cast_1"; - std::string div_name = for_name + "_RealDiv"; - std::string cast_name_2 = for_name + "_Cast_2"; - - // n = cast(cast(abs(limit-start) + abs(delta) - 1, float) / cast(abs(delta), float), int32) + + // i * |delta| < |limit-start| PartialGraphBuilder graph_builder; graph_builder.SetOwnerGraph(graph) .AddExistNode(for_info.start->GetOwnerNode()) @@ -327,78 +334,36 @@ OutDataAnchorPtr ForPass::CreateLoopCountInput(const ComputeGraphPtr &graph, con .AddNode(CreateOpDesc(sub_name_0, SUB, false)) .AddNode(CreateOpDesc(abs_name_0, kAbs, true)) .AddNode(CreateOpDesc(abs_name_1, kAbs, true)) - .AddNode(CreateOpDesc(add_name_0, ADD, false)) - .AddNode(CreateConstDesc(const_name, 1)) - .AddNode(CreateOpDesc(sub_name_1, SUB, false)) - .AddNode(CreateCastDesc(cast_name_0, DT_INT32, DT_FLOAT)) - .AddNode(CreateCastDesc(cast_name_1, DT_INT32, DT_FLOAT)) - .AddNode(CreateOpDesc(div_name, REALDIV, false)) - .AddNode(CreateCastDesc(cast_name_2, DT_FLOAT, DT_INT32)) + .AddDataLink(delta->GetOwnerNode()->GetName(), delta->GetIdx(), abs_name_0, 0) .AddDataLink(limit->GetOwnerNode()->GetName(), limit->GetIdx(), sub_name_0, 0) .AddDataLink(start->GetOwnerNode()->GetName(), start->GetIdx(), sub_name_0, 1) - .AddDataLink(sub_name_0, 0, abs_name_0, 0) - .AddDataLink(delta->GetOwnerNode()->GetName(), delta->GetIdx(), abs_name_1, 0) - .AddDataLink(abs_name_0, 0, add_name_0, 0) - .AddDataLink(abs_name_1, 0, add_name_0, 1) - .AddDataLink(add_name_0, 0, sub_name_1, 0) - .AddDataLink(const_name, 0, sub_name_1, 1) - .AddDataLink(sub_name_1, 0, cast_name_0, 0) - .AddDataLink(abs_name_1, 0, cast_name_1, 0) - .AddDataLink(cast_name_0, 0, div_name, 0) - .AddDataLink(cast_name_1, 0, div_name, 1) - .AddDataLink(div_name, 0, cast_name_2, 0); + .AddDataLink(sub_name_0, 0, abs_name_1, 0); graphStatus error_code = GRAPH_SUCCESS; std::string error_msg; if ((graph_builder.Build(error_code, error_msg) == nullptr) || (error_code != GRAPH_SUCCESS)) { GELOGE(FAILED, "Create loop_count node failed: error_code:%u, error_msg:%s.", error_code, error_msg.c_str()); - return nullptr; - } - - NodePtr loop_count_node = graph_builder.GetNode(cast_name_2); - if (loop_count_node == nullptr) { - GELOGE(FAILED, "Create loop_count node failed: node is NULL."); 
- return nullptr; + return FAILED; } - GELOGD("Create loop_count input succ, node:%s", for_name.c_str()); - // loop_count_node is a Cast node, has and only has one output - return loop_count_node->GetOutDataAnchor(0); -} - -/// -/// @brief Create cast op_desc -/// @param [in] name -/// @param [in] src_data_type -/// @param [in] dst_data_type -/// @return OpDescPtr -/// -OpDescPtr ForPass::CreateCastDesc(const std::string &name, DataType src, DataType dst) { - OpDescPtr cast_desc = CreateOpDesc(name, CAST, true); - if (cast_desc == nullptr) { - GELOGE(FAILED, "Create cast op_desc failed, node: %s.", name.c_str()); - return nullptr; + // Add repass_nodes + for (auto &node : graph_builder.GetAllNodes()) { + AddRePassNode(node); } - // cast node has and only has one input /output - GeTensorDesc in_tensor = cast_desc->GetInputDesc(0); - in_tensor.SetDataType(src); - GeTensorDesc out_tensor = cast_desc->GetOutputDesc(0); - out_tensor.SetDataType(dst); - if ((cast_desc->UpdateInputDesc(0, in_tensor) != GRAPH_SUCCESS) || - (cast_desc->UpdateOutputDesc(0, out_tensor) != GRAPH_SUCCESS)) { - GELOGE(FAILED, "Update tensor failed."); - return nullptr; + NodePtr abs_delta_node = graph_builder.GetNode(abs_name_0); + NodePtr loop_count_node = graph_builder.GetNode(abs_name_1); + if ((abs_delta_node == nullptr) || (loop_count_node == nullptr)) { + GELOGE(FAILED, "Create loop node failed: node is NULL."); + return FAILED; } - if (!(AttrUtils::SetInt(cast_desc, CAST_ATTR_SRCT, src) && AttrUtils::SetInt(cast_desc, CAST_ATTR_DSTT, dst) && - AttrUtils::SetInt(cast_desc, CAST_ATTR_DST_TYPE, dst) && - AttrUtils::SetBool(cast_desc, CAST_ATTR_TRUNCATE, false))) { - GELOGE(FAILED, "Set CAST_ATTR failed, node: %s.", name.c_str()); - return nullptr; - } + GELOGD("Create loop_range input succ, node:%s", for_name.c_str()); + // abs_node has and only has one output + abs_delta_input = abs_delta_node->GetOutDataAnchor(0); + range_input = loop_count_node->GetOutDataAnchor(0); - return cast_desc; + return SUCCESS; } /// @@ -423,20 +388,24 @@ OpDescPtr ForPass::CreateOpDesc(const std::string &name, const std::string &type /// @brief Build while-info /// @param [in] for_info /// @param [in] i_input -/// @param [in] n_input +/// @param [in] range_input +/// @param [in] abs_delta_input /// @param [out] while_info /// @return void /// -void ForPass::BuildWhileInfo(const ForInfo &for_info, const OutDataAnchorPtr &i_input, const OutDataAnchorPtr &n_input, +void ForPass::BuildWhileInfo(const ForInfo &for_info, const OutDataAnchorPtr &i_input, + const OutDataAnchorPtr &range_input, const OutDataAnchorPtr &abs_delta_input, WhileInfo &while_info) { while_info.i = i_input; - while_info.n = n_input; + while_info.abs_delta = abs_delta_input; + while_info.range = range_input; while_info.start = for_info.start; while_info.delta = for_info.delta; while_info.for_body_name = for_info.body_name; while_info.for_body = for_info.for_body; while_info.data_inputs.emplace_back(while_info.i); - while_info.data_inputs.emplace_back(while_info.n); + while_info.data_inputs.emplace_back(while_info.abs_delta); + while_info.data_inputs.emplace_back(while_info.range); while_info.data_inputs.emplace_back(while_info.start); while_info.data_inputs.emplace_back(while_info.delta); for (auto &item : for_info.data_inputs) { @@ -553,12 +522,15 @@ ComputeGraphPtr ForPass::BuildCondGraph(WhileInfo &while_info) { graph_builder.SetParentNode(while_info.while_node); // Add Node + const std::string mul_name = "Mul"; + graph_builder.AddNode(CreateOpDesc(mul_name, MUL, 
false)); const std::string less_name = "Less"; graph_builder.AddNode(CreateOpDesc(less_name, LESS, false)); // Set Input - graph_builder.SetInput(kWhileIInputIndex, {less_name}, {0}) - .SetInput(kWhileNInputIndex, {less_name}, {1}) + graph_builder.SetInput(kWhileIInputIndex, {mul_name}, {0}) + .SetInput(kWhileAbsDeltaInputIndex, {mul_name}, {1}) + .SetInput(kWhileRangeInputIndex, {less_name}, {1}) .SetUselessInput(kWhileStartInputIndex) .SetUselessInput(kWhileDeltaInputIndex); size_t input_num = while_info.data_inputs.size(); @@ -569,6 +541,9 @@ ComputeGraphPtr ForPass::BuildCondGraph(WhileInfo &while_info) { // Add Output graph_builder.AddOutput(less_name, 0); + // Add Edges + graph_builder.AddDataLink(mul_name, 0, less_name, 0); + // Add Input-Mapping std::map input_mapping; for (size_t i = 0; i < input_num; i++) { @@ -622,7 +597,8 @@ ComputeGraphPtr ForPass::BuildBodyGraph(WhileInfo &while_info) { // Set Input graph_builder.SetInput(kWhileIInputIndex, {add_name_0, mul_name}, {0, 0}) - .SetUselessInput(kWhileNInputIndex) + .SetUselessInput(kWhileAbsDeltaInputIndex) + .SetUselessInput(kWhileRangeInputIndex) .SetInput(kWhileStartInputIndex, {add_name_1}, {0}) .SetInput(kWhileDeltaInputIndex, {mul_name}, {1}); for (uint32_t i = 0; i < input_num - kWhileDataInputIndex; i++) { @@ -631,7 +607,7 @@ ComputeGraphPtr ForPass::BuildBodyGraph(WhileInfo &while_info) { // Add Outputs graph_builder.AddOutput(add_name_0, 0); - for (uint32_t i = kWhileNInputIndex; i < kWhileDataInputIndex; i++) { + for (uint32_t i = kWhileAbsDeltaInputIndex; i < kWhileDataInputIndex; i++) { graph_builder.AddOutput("Data_" + std::to_string(i), 0); } for (uint32_t i = 0; i < sub_graph_output_num; i++) { diff --git a/src/ge/graph/passes/for_pass.h b/src/ge/graph/passes/for_pass.h index 3611171e..f25655f8 100644 --- a/src/ge/graph/passes/for_pass.h +++ b/src/ge/graph/passes/for_pass.h @@ -34,11 +34,22 @@ struct ForInfo { }; struct WhileInfo { - WhileInfo() : while_node(nullptr), sub_graph_node(nullptr), i(nullptr), n(nullptr), start(nullptr), delta(nullptr) {} + WhileInfo() + : while_node(nullptr), + sub_graph_node(nullptr), + i(nullptr), + abs_delta(nullptr), + range(nullptr), + start(nullptr), + delta(nullptr), + for_body(nullptr), + while_cond(nullptr), + while_body(nullptr) {} ge::NodePtr while_node; ge::NodePtr sub_graph_node; ge::OutDataAnchorPtr i; - ge::OutDataAnchorPtr n; + ge::OutDataAnchorPtr abs_delta; + ge::OutDataAnchorPtr range; ge::OutDataAnchorPtr start; ge::OutDataAnchorPtr delta; std::string for_body_name; @@ -127,21 +138,15 @@ class ForPass : public BaseNodePass { static OpDescPtr CreateConstDesc(const std::string &name, int32_t value); /// - /// @brief Create loop_count input + /// @brief Create loop input /// @param [in] graph /// @param [in] for_info - /// @return OutDataAnchorPtr - /// - OutDataAnchorPtr CreateLoopCountInput(const ComputeGraphPtr &graph, const ForInfo &for_info); - - /// - /// @brief Create cast op_desc - /// @param [in] name - /// @param [in] src_data_type - /// @param [in] dst_data_type - /// @return OpDescPtr + /// @param [out] range_input + /// @param [out] abs_delta_input + /// @return Status /// - static OpDescPtr CreateCastDesc(const std::string &name, DataType src, DataType dst); + Status CreateLoopInput(const ComputeGraphPtr &graph, const ForInfo &for_info, OutDataAnchorPtr &range_input, + OutDataAnchorPtr &abs_delta_input); /// /// @brief Create op_desc @@ -156,11 +161,13 @@ class ForPass : public BaseNodePass { /// @brief Build while-info /// @param [in] for_info /// @param 
[in] i_input - /// @param [in] n_input + /// @param [in] range_input + /// @param [in] abs_delta_input /// @param [out] while_info /// @return void /// - static void BuildWhileInfo(const ForInfo &for_info, const OutDataAnchorPtr &i_input, const OutDataAnchorPtr &n_input, + static void BuildWhileInfo(const ForInfo &for_info, const OutDataAnchorPtr &i_input, + const OutDataAnchorPtr &range_input, const OutDataAnchorPtr &abs_delta_input, WhileInfo &while_info); /// diff --git a/src/ge/graph/passes/hccl_group_pass.cc b/src/ge/graph/passes/hccl_group_pass.cc new file mode 100644 index 00000000..d8f11434 --- /dev/null +++ b/src/ge/graph/passes/hccl_group_pass.cc @@ -0,0 +1,73 @@ +/** + * Copyright 2019-2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "hccl_group_pass.h" +#include +#include "framework/common/debug/ge_log.h" +#include "graph/debug/ge_attr_define.h" +#include "framework/common/util.h" + +namespace ge { +Status HcclGroupPass::Run(NodePtr &node) { + GE_CHECK_NOTNULL(node); + OpDescPtr op_desc = node->GetOpDesc(); + GE_CHECK_NOTNULL(op_desc); + bool is_fused_node = false; + if (!AttrUtils::GetBool(op_desc, ATTR_NAME_HCCL_FUSED_FLAG, is_fused_node)) { + GELOGW("Get attr ATTR_NAME_HCCL_FUSED_FLAG failed."); + return SUCCESS; + } + GELOGI("Recognized fused node %s", node->GetName().c_str()); + if (op_desc->HasAttr(ATTR_NAME_HCCL_FUSED_GROUP)) { + GELOGD("Current node %s already marked group id, ignore it.", node->GetName().c_str()); + return SUCCESS; + } + + if (!is_fused_node) { + GELOGD("Current node %s is not a gradient fused node, ignore it.", node->GetName().c_str()); + return SUCCESS; + } + Status ret = MarkGroupForFusedNode(node); + if (ret != SUCCESS) { + GELOGW("Mark group for fused node %s failed.
It might cause performance problems.", node->GetName().c_str()); + } + return SUCCESS; +} + +Status HcclGroupPass::MarkGroupForFusedNode(NodePtr &fused_node) { + std::deque queue; + queue.push_back(fused_node); + string group_id = fused_node->GetName(); + + while (!queue.empty()) { + NodePtr node = queue.front(); + queue.pop_front(); + for (auto out_data_node : node->GetOutDataNodes()) { + if (out_data_node->GetType() == fused_node->GetType()) { + // if we meet another fused node, it is the end of the current group + break; + } + if (!AttrUtils::SetStr(out_data_node->GetOpDesc(), ATTR_NAME_HCCL_FUSED_GROUP, group_id)) { + GELOGW("Set attr ATTR_NAME_HCCL_FUSED_GROUP failed."); + return FAILED; + } + GELOGI("Set group_id %s for node %s", group_id.c_str(), out_data_node->GetName().c_str()); + queue.emplace_back(out_data_node); + } + } + return SUCCESS; +} +} // namespace ge diff --git a/third_party/fwkacllib/inc/ops/decode_wheels_target.h b/src/ge/graph/passes/hccl_group_pass.h similarity index 54% rename from third_party/fwkacllib/inc/ops/decode_wheels_target.h rename to src/ge/graph/passes/hccl_group_pass.h index 053a6c1a..059710ce 100644 --- a/third_party/fwkacllib/inc/ops/decode_wheels_target.h +++ b/src/ge/graph/passes/hccl_group_pass.h @@ -14,18 +14,18 @@ * limitations under the License. */ - #ifndef GE_OP_DECODE_WHEELS_TARGET_H - #define GE_OP_DECODE_WHEELS_TARGET_H +#ifndef GE_GRAPH_PASSES_HCCL_GROUP_PASS_H_ +#define GE_GRAPH_PASSES_HCCL_GROUP_PASS_H_ - #include "graph/operator_reg.h" +#include "graph/passes/base_pass.h" +namespace ge { +class HcclGroupPass : public BaseNodePass { + public: + Status Run(NodePtr &node) override; - namespace ge { + private: + Status MarkGroupForFusedNode(NodePtr &fused_node); +}; +} // namespace ge - REG_OP(DecodeWheelsTarget) - .INPUT(boundary_predictions, TensorType({DT_FLOAT16})) /* "First operand." */ - .INPUT(anchors, TensorType({DT_FLOAT16})) /* "Second operand."
*/ - .OUTPUT(boundary_encoded, TensorType({DT_FLOAT16})) /* "Result, has same element type as two inputs" */ - .OP_END_FACTORY_REG(DecodeWheelsTarget) - } // namespace ge - - #endif // GE_OP_DECODE_WHEELS_TARGET_H +#endif // GE_GRAPH_PASSES_HCCL_GROUP_PASS_H_ diff --git a/src/ge/graph/passes/hccl_memcpy_pass.cc b/src/ge/graph/passes/hccl_memcpy_pass.cc index ac037d62..44c1b084 100644 --- a/src/ge/graph/passes/hccl_memcpy_pass.cc +++ b/src/ge/graph/passes/hccl_memcpy_pass.cc @@ -181,5 +181,12 @@ Status HcclMemcpyPass::ModifyEdgeConnection(const ComputeGraphPtr &graph, const } return SUCCESS; } - +/// +/// @brief Clear Status, used for subgraph pass +/// @return SUCCESS +/// +Status HcclMemcpyPass::ClearStatus() { + node_num_map_.clear(); + return SUCCESS; +} } // namespace ge diff --git a/src/ge/graph/passes/hccl_memcpy_pass.h b/src/ge/graph/passes/hccl_memcpy_pass.h index 4c4e8ae5..9de96fbf 100644 --- a/src/ge/graph/passes/hccl_memcpy_pass.h +++ b/src/ge/graph/passes/hccl_memcpy_pass.h @@ -27,6 +27,7 @@ namespace ge { class HcclMemcpyPass : public GraphPass { public: Status Run(ge::ComputeGraphPtr graph); + Status ClearStatus() override; private: NodePtr CreateMemcpyNode(const ComputeGraphPtr &graph, const OutDataAnchorPtr &out_data_anchor); diff --git a/src/ge/graph/passes/iterator_op_pass.cc b/src/ge/graph/passes/iterator_op_pass.cc index a5dafdca..540742cf 100644 --- a/src/ge/graph/passes/iterator_op_pass.cc +++ b/src/ge/graph/passes/iterator_op_pass.cc @@ -21,6 +21,8 @@ #include #include +#include "common/debug/log.h" +#include "framework/common/debug/ge_log.h" #include "common/ge/ge_util.h" #include "framework/common/debug/ge_log.h" #include "graph/anchor.h" @@ -29,9 +31,14 @@ #include "graph/node.h" #include "graph/passes/pass_utils.h" #include "graph/utils/graph_utils.h" +#include "runtime/mem.h" +#include "graph/manager/graph_var_manager.h" +#include "graph/ge_context.h" +#include "graph/manager/util/rt_context_util.h" namespace ge { const char *const kGetNext = "GetNext"; +const int kMaxIterationsPerLoop = INT32_MAX - 1; Status IteratorOpPass::Run(ge::ComputeGraphPtr graph) { GELOGD("GetNextOpPass begin"); @@ -58,12 +65,130 @@ node->GetName().c_str()); GELOGI("Set independent loop for iterator node success"); + + int64_t loop_per_iter = 0; + ge::GeTensorDesc ge_tensor_desc; + Status status = + VarManager::Instance(graph->GetSessionID())->GetCurVarDesc(NODE_NAME_FLOWCTRL_LOOP_PER_ITER, ge_tensor_desc); + GE_IF_BOOL_EXEC(status != SUCCESS, GELOGW("Get var_desc of NODE_NAME_FLOWCTRL_LOOP_PER_ITER failed."); + continue); + + status = + GetVariableValue(graph->GetSessionID(), ge_tensor_desc, NODE_NAME_FLOWCTRL_LOOP_PER_ITER, &loop_per_iter); + GE_IF_BOOL_EXEC(status != SUCCESS, GELOGW("Get variable value of NODE_NAME_FLOWCTRL_LOOP_PER_ITER failed."); + continue); + GELOGI("The value of NODE_NAME_FLOWCTRL_LOOP_PER_ITER is %ld", loop_per_iter); + + if (loop_per_iter == kMaxIterationsPerLoop) { + ge::NodePtr end_of_sequence_node = InsertEndOfSequenceNode(node, memcpy_async_node, graph); + GE_CHECK_NOTNULL(end_of_sequence_node); + GE_CHK_STATUS_RET(SetStreamLabel(end_of_sequence_node, end_of_sequence_node->GetName()), + "Set stream label fail, node:%s", node->GetName().c_str()); + GELOGI("Insert EndOfSequence node success."); + } } + GELOGI("GetNextOpPass end"); } GELOGD("GetNextOpPass end"); return SUCCESS; } +Status IteratorOpPass::GetVariableValue(uint64_t session_id, const ge::GeTensorDesc &tensor_desc, + const
std::string &var_name, void *dest) { + // base_addr + uint8_t *var_mem_base = VarManager::Instance(session_id)->GetVarMemoryBase(RT_MEMORY_HBM); + GE_CHECK_NOTNULL(var_mem_base); + // offset + uint8_t *dev_ptr = nullptr; + GE_CHK_STATUS_RET(VarManager::Instance(session_id)->GetVarAddr(var_name, tensor_desc, &dev_ptr), + "Get variable %s address failed.", var_name.c_str()); + int64_t offset = static_cast(reinterpret_cast(dev_ptr)); + // logic_base_addr + auto logic_var_base = VarManager::Instance(session_id)->GetVarMemLogicBase(); + // device_addr + uint8_t *variable_addr = static_cast(var_mem_base + offset - logic_var_base); + Status ret; + ret = SetRtContext(rtContext_t(), RT_CTX_NORMAL_MODE); + if (ret != SUCCESS) { + GELOGE(ret, "Set rt context RT_CTX_NORMAL_MODE failed."); + return ret; + } + GE_CHK_RT_RET(rtMemcpy(dest, sizeof(int64_t), variable_addr, sizeof(int64_t), RT_MEMCPY_DEVICE_TO_HOST)); + ret = SetRtContext(rtContext_t(), RT_CTX_GEN_MODE); + if (ret != SUCCESS) { + GELOGE(ret, "Set rt context RT_CTX_GEN_MODE failed."); + return ret; + } + return SUCCESS; +} + +/// +/// @brief insert EndOfSequence after GetNext +/// +/// @param pre_node +/// @param graph +/// @return ge::NodePtr +/// +ge::NodePtr IteratorOpPass::InsertEndOfSequenceNode(const ge::NodePtr &pre_node, const ge::NodePtr &memcpy_node, + const ge::ComputeGraphPtr &graph) { + GELOGI("Start to insert EndOfSequence node."); + GE_CHK_BOOL_EXEC(pre_node != nullptr, GELOGW("Pre node is null."); return nullptr); + GE_CHK_BOOL_EXEC(graph != nullptr, GELOGW("graph is null."); return nullptr); + ge::OpDescPtr end_of_seq_op_desc = CreateEndOfSequenceOp(pre_node); + GE_CHK_BOOL_EXEC(end_of_seq_op_desc != nullptr, GELOGW("Create EndOfSequence op fail."); return nullptr); + ge::NodePtr end_of_seq_node = graph->AddNode(end_of_seq_op_desc); + GE_CHK_BOOL_EXEC(end_of_seq_node != nullptr, return nullptr, "Insert EndOfSequence node fail."); + + // getnext(data) --> EOS + GE_CHK_BOOL_EXEC(pre_node->GetAllOutDataAnchorsSize() != 0, GELOGW("Pre node has no output."); return nullptr); + auto out_anchor = pre_node->GetOutDataAnchor(0); + ge::graphStatus status; + status = GraphUtils::AddEdge(out_anchor, end_of_seq_node->GetInDataAnchor(0)); + GE_CHK_BOOL_EXEC(status == GRAPH_SUCCESS, return nullptr, "Graph add EndOfSequence op input edge fail, dst node: %s.", + end_of_seq_node->GetName().c_str()); + // EOS(control) --> subsequent of memcpy + OutControlAnchorPtr out_ctrl_anchor = end_of_seq_node->GetOutControlAnchor(); + GE_CHK_BOOL_EXEC(out_ctrl_anchor != nullptr, GELOGW("out_ctrl_anchor is null."); return nullptr); + // add ctrl edge + for (const auto &out_node : memcpy_node->GetOutNodes()) { + auto in_ctrl_anchor = out_node->GetInControlAnchor(); + if (in_ctrl_anchor == nullptr) { + continue; + } + status = GraphUtils::AddEdge(out_ctrl_anchor, in_ctrl_anchor); + GE_CHK_BOOL_EXEC(status == GRAPH_SUCCESS, return nullptr, + "Graph add EndOfSequence op out ctrl edge fail, dst node: %s.", out_node->GetName().c_str()); + GELOGI("Graph add EndOfSequence op out ctrl edge, dst node: %s.", out_node->GetName().c_str()); + } + + return end_of_seq_node; +} + +/// +/// @brief create EndOfSequence +/// +/// @param pre_node +/// @return ge::OpDescPtr +/// +ge::OpDescPtr IteratorOpPass::CreateEndOfSequenceOp(const ge::NodePtr &pre_node) { + GELOGI("Start to create EndOfSequence op."); + GE_CHK_BOOL_EXEC(pre_node != nullptr, return nullptr, "Input param invalid."); + + string node_name = pre_node->GetName() + "_EndOfSequence"; + ge::OpDescPtr op_desc =
MakeShared(node_name, ENDOFSEQUENCE); + if (op_desc == nullptr) { + GELOGE(FAILED, "MakeShared fail."); + return op_desc; + } + ge::OpDescPtr pre_node_op_desc = pre_node->GetOpDesc(); + GE_CHK_BOOL_EXEC(pre_node_op_desc != nullptr, return nullptr, "OpDesc of pre_node is invalid."); + + GELOGI("Create EndOfSequence op:%s.", op_desc->GetName().c_str()); + GE_CHK_BOOL_EXEC(op_desc->AddInputDesc(pre_node_op_desc->GetOutputDesc(0)) == GRAPH_SUCCESS, return nullptr, + "Create EndOfSequence op:add input desc fail."); + return op_desc; +} + /// /// @brief insert memcpy after GetNext /// @@ -153,4 +278,12 @@ ge::OpDescPtr IteratorOpPass::CreateMemcpyAsyncOp(const ge::NodePtr &pre_node) { return op_desc; } + +Status IteratorOpPass::SetRtContext(rtContext_t rt_context, rtCtxMode_t mode) { + GELOGI("set rt_context %d, device id:%u.", static_cast(mode), ge::GetContext().DeviceId()); + GE_CHK_RT_RET(rtCtxCreate(&rt_context, mode, ge::GetContext().DeviceId())); + GE_CHK_RT_RET(rtCtxSetCurrent(rt_context)); + RtContextUtil::GetInstance().AddrtContext(rt_context); + return SUCCESS; +} } // namespace ge diff --git a/src/ge/graph/passes/iterator_op_pass.h b/src/ge/graph/passes/iterator_op_pass.h index be76cc87..e403020c 100644 --- a/src/ge/graph/passes/iterator_op_pass.h +++ b/src/ge/graph/passes/iterator_op_pass.h @@ -29,9 +29,27 @@ class IteratorOpPass : public GraphPass { Status Run(ge::ComputeGraphPtr graph); + Status GetVariableValue(uint64_t session_id, const ge::GeTensorDesc &opdesc, const std::string &var_name, void *dest); + private: /// - /// @brief inset memcpy node + /// @brief Insert EndOfSequence node + /// + /// @param preNode + /// @param graph + /// @return ge::NodePtr + /// + ge::NodePtr InsertEndOfSequenceNode(const ge::NodePtr &pre_node, const ge::NodePtr &memcpy_node, + const ge::ComputeGraphPtr &graph); + /// + /// @brief Create an EndOfSequence Op object + /// + /// @param preNode + /// @return ge::OpDescPtr + /// + ge::OpDescPtr CreateEndOfSequenceOp(const ge::NodePtr &pre_node); + /// + /// @brief Insert memcpy node /// /// @param preNode /// @param graph @@ -45,6 +63,8 @@ /// @return ge::OpDescPtr /// ge::OpDescPtr CreateMemcpyAsyncOp(const ge::NodePtr &pre_node); + + Status SetRtContext(rtContext_t rt_context, rtCtxMode_t mode); }; } // namespace ge #endif // GE_GRAPH_PASSES_ITERATOR_OP_PASS_H_ diff --git a/src/ge/graph/passes/multi_batch_pass.cc b/src/ge/graph/passes/multi_batch_pass.cc index aac72892..bb0050be 100644 --- a/src/ge/graph/passes/multi_batch_pass.cc +++ b/src/ge/graph/passes/multi_batch_pass.cc @@ -32,7 +32,11 @@ namespace ge { Status MultiBatchPass::Run(ComputeGraphPtr graph) { GELOGD("MultiBatchPass Enter"); - + GE_CHECK_NOTNULL(graph); + if (graph->GetParentGraph() != nullptr) { + GELOGI("Subgraph %s skips the MultiBatchPass.", graph->GetName().c_str()); + return SUCCESS; + } OutDataAnchorPtr pred_value = nullptr; Status ret = FindPredValue(graph, pred_value); if (ret == NOT_CHANGED) { diff --git a/src/ge/graph/passes/next_iteration_pass.cc b/src/ge/graph/passes/next_iteration_pass.cc index 030ff6ac..138ad86b 100644 --- a/src/ge/graph/passes/next_iteration_pass.cc +++ b/src/ge/graph/passes/next_iteration_pass.cc @@ -328,4 +328,12 @@ Status NextIterationPass::FindTargetNode(const NodePtr &node, const std::string } return SUCCESS; } +/// +/// @brief Clear Status, used for subgraph pass +/// @return SUCCESS +/// +Status NextIterationPass::ClearStatus() { + loop_group_map_.clear(); + return SUCCESS; +} } // namespace ge diff
--git a/src/ge/graph/passes/next_iteration_pass.h b/src/ge/graph/passes/next_iteration_pass.h index aefcc0f5..4bbced4f 100644 --- a/src/ge/graph/passes/next_iteration_pass.h +++ b/src/ge/graph/passes/next_iteration_pass.h @@ -37,6 +37,7 @@ namespace ge { class NextIterationPass : public GraphPass { public: Status Run(ComputeGraphPtr graph); + Status ClearStatus() override; private: Status HandleEnterNode(const NodePtr &enter_node); diff --git a/src/ge/graph/passes/pass_manager.cc b/src/ge/graph/passes/pass_manager.cc index f62ea160..eec33eef 100644 --- a/src/ge/graph/passes/pass_manager.cc +++ b/src/ge/graph/passes/pass_manager.cc @@ -22,26 +22,29 @@ #include "omg/omg_inner_types.h" namespace ge { -const vector &PassManager::GraphPasses() const { return graph_passes_; } +const vector> &PassManager::GraphPasses() const { return names_to_graph_passes_; } -Status PassManager::AddPass(GraphPass *pass) { +Status PassManager::AddPass(const string &pass_name, GraphPass *pass) { GE_CHECK_NOTNULL(pass); - graph_passes_.push_back(pass); + names_to_graph_passes_.emplace_back(pass_name, pass); return SUCCESS; } Status PassManager::Run(const ComputeGraphPtr &graph) { GE_CHECK_NOTNULL(graph); - return Run(graph, graph_passes_); + return Run(graph, names_to_graph_passes_); } -Status PassManager::Run(const ComputeGraphPtr &graph, vector &passes) { +Status PassManager::Run(const ComputeGraphPtr &graph, vector> &names_to_passes) { GE_CHECK_NOTNULL(graph); bool not_changed = true; - for (auto &pass : passes) { + for (auto &pass_pair : names_to_passes) { + const auto &pass = pass_pair.second; + const auto &pass_name = pass_pair.first; GE_CHECK_NOTNULL(pass); + GE_TIMESTAMP_START(PassRun); Status status = pass->Run(graph); if (status == SUCCESS) { not_changed = false; @@ -51,7 +54,11 @@ Status PassManager::Run(const ComputeGraphPtr &graph, vector &passe } for (const auto &subgraph : graph->GetAllSubgraphs()) { GE_CHECK_NOTNULL(subgraph); + GE_CHK_STATUS_RET(pass->ClearStatus(), "pass clear status failed for subgraph %s", subgraph->GetName().c_str()); + string subgraph_pass_name = pass_name + "::" + graph->GetName(); + GE_TIMESTAMP_START(PassRunSubgraph); status = pass->Run(subgraph); + GE_TIMESTAMP_END(PassRunSubgraph, subgraph_pass_name.c_str()); if (status == SUCCESS) { not_changed = false; } else if (status != NOT_CHANGED) { @@ -59,13 +66,15 @@ Status PassManager::Run(const ComputeGraphPtr &graph, vector &passe return status; } } + GE_TIMESTAMP_END(PassRun, pass_name.c_str()); } return not_changed ? 
NOT_CHANGED : SUCCESS; } PassManager::~PassManager() { - for (auto pass : graph_passes_) { + for (auto &pass_pair : names_to_graph_passes_) { + auto &pass = pass_pair.second; GE_DELETE_NEW_SINGLE(pass); } } diff --git a/src/ge/graph/passes/pass_utils.cc b/src/ge/graph/passes/pass_utils.cc index 9b3f6b5f..a51b4e29 100644 --- a/src/ge/graph/passes/pass_utils.cc +++ b/src/ge/graph/passes/pass_utils.cc @@ -254,6 +254,14 @@ bool PassUtils::IsNeedTrainIteFlowCtrl(const ComputeGraphPtr &compute_graph) { if (compute_graph == nullptr) { return false; } + if (compute_graph->GetParentGraph() != nullptr) { + GELOGI("Subgraph %s no need flow ctrl.", compute_graph->GetName().c_str()); + return false; + } + if (GraphUtils::IsUnknownShapeGraph(compute_graph)) { + GELOGI("Unknown shape graph %s no need flow ctrl.", compute_graph->GetName().c_str()); + return false; + } if (!ge::VarManager::Instance(compute_graph->GetSessionID())->IsVarExist(NODE_NAME_FLOWCTRL_LOOP_PER_ITER)) { return false; } diff --git a/src/ge/graph/passes/prune_pass.cc b/src/ge/graph/passes/prune_pass.cc index f7d09740..af10c54f 100644 --- a/src/ge/graph/passes/prune_pass.cc +++ b/src/ge/graph/passes/prune_pass.cc @@ -23,6 +23,7 @@ #include "common/types.h" #include "framework/common/debug/ge_log.h" #include "framework/common/ge_inner_error_codes.h" +#include "graph/utils/node_utils.h" namespace ge { Status PrunePass::Run(ge::ComputeGraphPtr graph) { @@ -79,6 +80,8 @@ Status PrunePass::Run(ge::ComputeGraphPtr graph) { node_ptr->GetOpDesc()->GetName().c_str(), out_nodes[0]->GetOpDesc()->GetName().c_str()); continue; } + // Remove subgraphs on the node before remove it in graph. + (void)NodeUtils::RemoveSubgraphsOnNode(node_ptr); /// Common function:[RemoveNode] will delete not only input node but its constant input node also will be deleted (void)graph->RemoveNode(node_ptr); GELOGI("[PrunePass] remove graph node [%s]!", node_ptr->GetOpDesc()->GetName().c_str()); diff --git a/src/ge/graph/passes/remove_nodes_pass.cc b/src/ge/graph/passes/remove_nodes_pass.cc new file mode 100644 index 00000000..b29d6af3 --- /dev/null +++ b/src/ge/graph/passes/remove_nodes_pass.cc @@ -0,0 +1,48 @@ +/** + * Copyright 2019-2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +#include "remove_nodes_pass.h" +#include "debug/ge_log.h" +#include "inc/framework/common/util.h" +#include "inc/graph/utils/node_utils.h" + +namespace ge { +Status RemoveNodesPass::Run(NodePtr &node) { + GE_CHECK_NOTNULL(node); + auto node_type = NodeUtils::GetNodeType(*node); + auto type_iter = remove_node_types_to_arg_.find(node_type); + if (type_iter != remove_node_types_to_arg_.end()) { + GELOGI("Remove node %s by type %s", node->GetName().c_str(), node_type.c_str()); + return IsolateAndDeleteNode(node, type_iter->second); + } + for (const auto &attr_name_to_arg : remove_node_attr_names_to_arg_) { + if (AttrUtils::HasAttr(node->GetOpDesc(), attr_name_to_arg.first)) { + GELOGI("Remove node %s by attr name %s", node->GetName().c_str(), attr_name_to_arg.first.c_str()); + return IsolateAndDeleteNode(node, attr_name_to_arg.second); + } + } + + return SUCCESS; +} +RemoveNodesPass &RemoveNodesPass::AddNodeType(const string &node_type, std::initializer_list arg) { + remove_node_types_to_arg_[node_type] = std::move(arg); + return *this; +} +RemoveNodesPass &RemoveNodesPass::AddAttrName(const string &attr_name, std::initializer_list arg) { + remove_node_attr_names_to_arg_[attr_name] = std::move(arg); + return *this; +} +} // namespace ge \ No newline at end of file diff --git a/third_party/fwkacllib/inc/ops/score_filter_pre_sort.h b/src/ge/graph/passes/remove_nodes_pass.h similarity index 50% rename from third_party/fwkacllib/inc/ops/score_filter_pre_sort.h rename to src/ge/graph/passes/remove_nodes_pass.h index 8cfac8cf..32acda1b 100644 --- a/third_party/fwkacllib/inc/ops/score_filter_pre_sort.h +++ b/src/ge/graph/passes/remove_nodes_pass.h @@ -14,23 +14,20 @@ * limitations under the License. */ - #ifndef GE_OP_SCORE_FILTER_PRE_SORT_H - #define GE_OP_SCORE_FILTER_PRE_SORT_H - - #include "graph/operator_reg.h" +#ifndef GE_REMOVE_NODES_PASS_H_ +#define GE_REMOVE_NODES_PASS_H_ +#include "graph/passes/base_pass.h" namespace ge { - REG_OP(ScoreFiltePreSort) - .INPUT(rois, TensorType({DT_FLOAT16})) - .INPUT(cls_bg_prob, TensorType({DT_FLOAT16})) - .OUTPUT(sorted_proposal, TensorType({ DT_FLOAT16})) - .OUTPUT(proposal_num, TensorType({ DT_UINT32})) - .REQUIRED_ATTR(score_threshold, Float) - .REQUIRED_ATTR(k, Int) - .ATTR(score_filter, Bool, true) - .ATTR(core_max_num, Int, 8) - .OP_END_FACTORY_REG(ScoreFiltePreSort) - } // namespace ge - - #endif // GE_OP_SCORE_FILTER_PRE_SORT_H +class RemoveNodesPass : public BaseNodePass { + public: + Status Run(NodePtr &node) override; + RemoveNodesPass &AddNodeType(const std::string &node_type, std::initializer_list arg = {0}); + RemoveNodesPass &AddAttrName(const std::string &attr_name, std::initializer_list arg = {0}); + private: + std::map> remove_node_types_to_arg_; + std::map> remove_node_attr_names_to_arg_; +}; +} // namespace ge +#endif // GE_REMOVE_NODES_PASS_H_ diff --git a/src/ge/graph/passes/replace_with_empty_const_pass.cc b/src/ge/graph/passes/replace_with_empty_const_pass.cc index b76b2cc9..b6d680f7 100644 --- a/src/ge/graph/passes/replace_with_empty_const_pass.cc +++ b/src/ge/graph/passes/replace_with_empty_const_pass.cc @@ -101,6 +101,7 @@ Status ReplaceWithEmptyConstPass::ReplaceWithEmptyConst(NodePtr &node_to_replace const_node->GetName().c_str()); return FAILED; } + AddRePassNodesWithInOut(const_node); GELOGI("Node %s has been replaced by empty const %s.", node_to_replace->GetName().c_str(), const_node->GetName().c_str()); } @@ -115,14 +116,11 @@ Status ReplaceWithEmptyConstPass::ReplaceWithEmptyConst(NodePtr &node_to_replace } Status 
ReplaceWithEmptyConstPass::InsertEmptyConst(const GeTensorDesc &out_desc, NodePtr &const_node, ComputeGraphPtr &graph) { - GeTensorPtr empty_tensor = MakeShared(); + GeTensorPtr empty_tensor = MakeShared(out_desc); if (empty_tensor == nullptr) { GELOGE(OUT_OF_MEMORY, "Failed to create empty tensor."); return OUT_OF_MEMORY; } - empty_tensor->MutableTensorDesc().SetDataType(out_desc.GetDataType()); - empty_tensor->MutableTensorDesc().SetFormat(out_desc.GetFormat()); - empty_tensor->MutableTensorDesc().SetShape(out_desc.GetShape()); auto const_desc = OpDescUtils::CreateConstOp(empty_tensor); if (const_desc == nullptr) { GELOGE(OUT_OF_MEMORY, "Failed to get const desc from tensor"); diff --git a/src/ge/graph/passes/reshape_recovery_pass.cc b/src/ge/graph/passes/reshape_recovery_pass.cc new file mode 100644 index 00000000..787c8d83 --- /dev/null +++ b/src/ge/graph/passes/reshape_recovery_pass.cc @@ -0,0 +1,79 @@ +/** + * Copyright 2019-2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "graph/passes/reshape_recovery_pass.h" +#include "common/ge/ge_util.h" + +namespace ge { +namespace { +NodePtr CreateReshape(const ConstGeTensorDescPtr &src, const ConstGeTensorDescPtr &dst, const ComputeGraphPtr &graph) { + static std::atomic reshape_num(0); + auto next_num = reshape_num.fetch_add(1); + auto reshape = MakeShared("Reshape_ReshapeRecoveryPass_" + std::to_string(next_num), RESHAPE); + if (reshape == nullptr) { + return nullptr; + } + auto ret = reshape->AddInputDesc("x", *src); + if (ret != GRAPH_SUCCESS) { + return nullptr; + } + ret = reshape->AddOutputDesc("y", *dst); + if (ret != GRAPH_SUCCESS) { + return nullptr; + } + + return graph->AddNode(reshape); +} + +Status InsertReshapeIfNeed(const NodePtr &node) { + GE_CHECK_NOTNULL(node); + GE_CHECK_NOTNULL(node->GetOpDesc()); + for (auto src_anchor : node->GetAllOutDataAnchors()) { + auto src_tensor = node->GetOpDesc()->GetOutputDescPtr(src_anchor->GetIdx()); + GE_CHECK_NOTNULL(src_tensor); + for (auto dst_anchor : src_anchor->GetPeerInDataAnchors()) { + auto dst_node = dst_anchor->GetOwnerNode(); + GE_CHECK_NOTNULL(dst_node); + GE_CHECK_NOTNULL(dst_node->GetOpDesc()); + auto dst_tensor = dst_node->GetOpDesc()->GetInputDescPtr(dst_anchor->GetIdx()); + if (src_tensor->GetShape().GetDims() != dst_tensor->GetShape().GetDims()) { + auto reshape = CreateReshape(src_tensor, dst_tensor, node->GetOwnerComputeGraph()); + GE_CHECK_NOTNULL(reshape); + auto ret = GraphUtils::InsertNodeBetweenDataAnchors(src_anchor, dst_anchor, reshape); + if (ret != GRAPH_SUCCESS) { + GELOGE(INTERNAL_ERROR, "Failed to insert reshape between node %s and %s", node->GetName().c_str(), + dst_node->GetName().c_str()); + return INTERNAL_ERROR; + } + GELOGI("Insert reshape between %s and %s to keep the shapes consistent", node->GetName().c_str(), + dst_node->GetName().c_str()); + } + } + } + return SUCCESS; +} +} // namespace + +Status ReshapeRecoveryPass::Run(ComputeGraphPtr graph) { + for (const auto &node : graph->GetDirectNode())
{ + auto ret = InsertReshapeIfNeed(node); + if (ret != SUCCESS) { + return ret; + } + } + return SUCCESS; +} +} // namespace ge \ No newline at end of file diff --git a/src/ge/graph/passes/reshape_recovery_pass.h b/src/ge/graph/passes/reshape_recovery_pass.h new file mode 100644 index 00000000..b3ab1baa --- /dev/null +++ b/src/ge/graph/passes/reshape_recovery_pass.h @@ -0,0 +1,27 @@ +/** + * Copyright 2019-2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef GE_RESHAPE_RECOVERY_PASS_H_ +#define GE_RESHAPE_RECOVERY_PASS_H_ +#include "inc/graph_pass.h" +namespace ge { +class ReshapeRecoveryPass : public GraphPass { + public: + virtual Status Run(ge::ComputeGraphPtr graph) override; +}; +} // namespace ge + +#endif // GE_RESHAPE_RECOVERY_PASS_H_ diff --git a/src/ge/graph/passes/reshape_remove_pass.cc b/src/ge/graph/passes/reshape_remove_pass.cc index bd84882a..0f6d52d1 100644 --- a/src/ge/graph/passes/reshape_remove_pass.cc +++ b/src/ge/graph/passes/reshape_remove_pass.cc @@ -17,12 +17,12 @@ #include "graph/passes/reshape_remove_pass.h" #include "framework/common/util.h" #include "graph/passes/pass_utils.h" +#include "graph/utils/node_utils.h" namespace ge { namespace { const int kReshapeDataIndex = 0; -const int kReshapeShapeIndex = 1; -} // namespace +} Status ReshapeRemovePass::Run(NodePtr &node) { GE_CHECK_NOTNULL(node); @@ -30,15 +30,16 @@ Status ReshapeRemovePass::Run(NodePtr &node) { if (node->GetType() != RESHAPE && node->GetType() != REFORMAT) { return SUCCESS; } - auto op_desc = node->GetOpDesc(); - auto output_desc = op_desc->GetOutputDescPtr(kReshapeDataIndex); - GE_CHECK_NOTNULL(output_desc); - if (output_desc->GetShape().IsUnknownShape()) { - GELOGD("Reshape node %s is unknown shape. It should be remained.", node->GetName().c_str()); - return SUCCESS; + + bool is_shape_unknown = false; + if (NodeUtils::GetNodeUnknownShapeStatus(*node, is_shape_unknown) == GRAPH_SUCCESS) { + if (is_shape_unknown) { + GELOGI("op:%s is unknown shape, can not be deleted.", node->GetName().c_str()); + return SUCCESS; + } } - GELOGD("Remove %s node %s", node->GetType().c_str(), node->GetName().c_str()); + GELOGI("Remove %s node %s", node->GetType().c_str(), node->GetName().c_str()); return IsolateAndDeleteNode(node, {kReshapeDataIndex}); } } // namespace ge diff --git a/src/ge/graph/passes/subgraph_pass.cc b/src/ge/graph/passes/subgraph_pass.cc index 6c4ad385..b677179e 100644 --- a/src/ge/graph/passes/subgraph_pass.cc +++ b/src/ge/graph/passes/subgraph_pass.cc @@ -15,17 +15,12 @@ */ #include "graph/passes/subgraph_pass.h" - +#include #include "graph/utils/node_utils.h" #include "graph/utils/op_desc_utils.h" #include "graph/utils/tensor_utils.h" -namespace { -const std::set kWhileTypes = {ge::WHILE, ge::_WHILE, ge::STATELESSWHILE}; -} - namespace ge { - /** * @ingroup ge * @brief Subgraph optimizer. 
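For context on `IsolateAndDeleteNode(node, {kReshapeDataIndex})` in the ReshapeRemovePass hunk above: the io_map `{0}` means every consumer of the Reshape/ReFormat output gets rewired to whatever fed the node's data input before the node is dropped. A toy sketch of that rewiring over a plain edge list (illustrative model, not the GE graph API):

```cpp
#include <algorithm>
#include <vector>

// One directed data edge: output `src_out` of node `src` feeds input `dst_in` of node `dst`.
struct Edge {
  int src;
  int src_out;
  int dst;
  int dst_in;
};

// io_map {0} semantics: re-source every edge leaving `node` to the producer of
// `node`'s input 0, then drop the node's own incoming edge.
void BypassSingleInputNode(std::vector<Edge> &edges, int node) {
  Edge in{};
  for (const Edge &e : edges) {
    if (e.dst == node && e.dst_in == 0) {
      in = e;  // the producer feeding the node's data input
    }
  }
  for (Edge &e : edges) {
    if (e.src == node) {  // consumer edge: read from the producer instead
      e.src = in.src;
      e.src_out = in.src_out;
    }
  }
  edges.erase(std::remove_if(edges.begin(), edges.end(),
                             [node](const Edge &e) { return e.dst == node; }),
              edges.end());  // node is now fully isolated
}
```

ReshapeRecoveryPass above is the counterpart: after such removals, it re-inserts a Reshape wherever a producer's output dims no longer match a consumer's input dims, keeping the graph self-consistent.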
@@ -35,28 +30,32 @@ Status SubgraphPass::Run(ComputeGraphPtr graph) { const bool is_sub_graph = graph->GetParentNode() != nullptr; for (const NodePtr &node : graph->GetDirectNode()) { - GE_CHECK_NOTNULL(node); - GE_CHECK_NOTNULL(node->GetOpDesc()); - if (is_sub_graph && (node->GetType() == DATA)) { if (SubgraphInputNode(graph, node) != SUCCESS) { + GELOGE(FAILED, "Handle input %s of subgraph failed.", node->GetName().c_str()); return FAILED; } continue; } - // 2. Const->NetOutput in subgraph - // 3. Data->NetOutput in subgraph but not while body + // NetOutput in subgraph if (is_sub_graph && (node->GetType() == NETOUTPUT)) { if (SubgraphOutputNode(graph, node) != SUCCESS) { + GELOGE(FAILED, "Handle output %s of subgraph failed.", node->GetName().c_str()); return FAILED; } continue; } - // 4. Input->While and Input link to other nodes - if (kWhileTypes.count(node->GetType()) > 0) { + if (kWhileOpTypes.count(node->GetType()) > 0) { + // Input->While and Input link to other nodes if (WhileInputNodes(graph, node) != SUCCESS) { + GELOGE(FAILED, "Handle input of while_body failed, while:%s.", node->GetName().c_str()); + return FAILED; + } + // body subgraph of While op + if (WhileBodySubgraph(graph, node) != SUCCESS) { + GELOGE(FAILED, "Handle while_body failed, while:%s.", node->GetName().c_str()); return FAILED; } continue; @@ -74,20 +73,63 @@ Status SubgraphPass::Run(ComputeGraphPtr graph) { * @return: 0 for SUCCESS / others for FAILED */ Status SubgraphPass::SubgraphInputNode(const ComputeGraphPtr &graph, const NodePtr &node) { + GELOGD("Handle input_node %s for graph %s.", node->GetName().c_str(), graph->GetName().c_str()); + // Data has and only has one output + bool input_continues_required_flag = false; + OutDataAnchorPtr out_data_anchor = node->GetOutDataAnchor(0); + std::vector in_anchors; + for (const InDataAnchorPtr &peer_in_anchor : out_data_anchor->GetPeerInDataAnchors()) { + input_continues_required_flag = + input_continues_required_flag || IsInputContinuesRequired(peer_in_anchor->GetOwnerNode()); + in_anchors.emplace_back(peer_in_anchor); + } + // Data->InputContinuesRequiredOp in subgraph needs memcpy. + if (input_continues_required_flag) { + GELOGD("Output node of Data %s requires continuous input.", node->GetName().c_str()); + std::string name = node->GetName() + "_" + MEMCPYASYNC + "_" + std::to_string(memcpy_num_++); + if (InsertMemcpyNode(graph, out_data_anchor, in_anchors, name) != SUCCESS) { + GELOGE(FAILED, "Insert memcpy after %s failed.", node->GetName().c_str()); + return FAILED; + } + } + uint32_t parent_index = 0; if (!AttrUtils::GetInt(node->GetOpDesc(), ATTR_NAME_PARENT_NODE_INDEX, parent_index)) { + GELOGE(FAILED, "Get attr PARENT_NODE_INDEX failed, node:%s.", node->GetName().c_str()); return FAILED; } + NodePtr in_node = NodeUtils::GetParentInput(node); + GE_CHECK_NOTNULL(in_node); // Subgraph Data Node, check for constant input.
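Before the constant-input handling continues below, it is worth summarizing SubgraphInputNode's two independent triggers for a MemcpyAsync; reduced to flags, this is a sketch under assumed semantics, not the GE implementation:

```cpp
// Two independent reasons to copy a subgraph Data input, per the logic in
// SubgraphInputNode (assumed flag names; a sketch, not the GE implementation).
struct MemcpyPlan {
  bool after_data = false;          // Data -> Memcpy -> consumers, inside the subgraph
  bool before_while_input = false;  // Const -> Memcpy -> While, in the parent graph
};

MemcpyPlan PlanSubgraphInputMemcpy(bool any_consumer_needs_continuous_input,
                                   bool parent_input_is_constant,
                                   bool parent_is_while) {
  MemcpyPlan plan;
  // A consumer with continuous-input layout must not alias the parent's buffer.
  plan.after_data = any_consumer_needs_continuous_input;
  // A constant fed straight into While would be overwritten across iterations.
  plan.before_while_input = parent_input_is_constant && parent_is_while;
  return plan;
}
```

The continuous-input test is the OR of ATTR_NAME_CONTINUOUS_INPUT and its no-padding variant, as IsInputContinuesRequired shows further down; the constant-into-While branch is exactly what the next few lines implement.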
std::string const_type; - NodePtr in_node = NodeUtils::GetParentInput(node); if (!NodeUtils::GetConstOpType(in_node, const_type)) { + GELOGD("Parent input %s is not constant, no need to handle.", in_node->GetName().c_str()); return SUCCESS; } - if (!AttrUtils::SetStr(node->GetOpDesc(), ATTR_NAME_PARENT_CONST_TYPE, const_type)) { - return FAILED; + const NodePtr &parent_node = graph->GetParentNode(); + if (kWhileOpTypes.count(parent_node->GetType()) == 0) { + if (!AttrUtils::SetStr(node->GetOpDesc(), ATTR_NAME_PARENT_CONST_TYPE, const_type)) { + GELOGE(FAILED, "Set attr PARENT_CONST_TYPE failed, node:%s.", node->GetName().c_str()); + return FAILED; + } + } else { + // Constant input to While needs memcpy. + const ComputeGraphPtr &parent_graph = parent_node->GetOwnerComputeGraph(); + GE_CHECK_NOTNULL(parent_graph); + const InDataAnchorPtr &in_data_anchor = parent_node->GetInDataAnchor(parent_index); + GE_CHECK_NOTNULL(in_data_anchor); + const OutDataAnchorPtr &peer_out_anchor = in_data_anchor->GetPeerOutAnchor(); + GE_CHECK_NOTNULL(peer_out_anchor); + GELOGD("Constant input %s links to While %s.", peer_out_anchor->GetOwnerNode()->GetName().c_str(), + parent_node->GetName().c_str()); + std::string name = in_node->GetName() + "_" + MEMCPYASYNC + "_" + std::to_string(memcpy_num_++); + if (InsertMemcpyNode(parent_graph, peer_out_anchor, {in_data_anchor}, name) != SUCCESS) { + GELOGE(FAILED, "Insert memcpy between %s and %s failed.", peer_out_anchor->GetOwnerNode()->GetName().c_str(), + parent_node->GetName().c_str()); + return FAILED; + } } return SUCCESS; @@ -109,18 +151,19 @@ Status SubgraphPass::SubgraphOutputNode(const ComputeGraphPtr &graph, const Node GE_CHECK_NOTNULL(in_node); // Need insert memcpy - // 2. Const->NetOutput in subgraph - // 3. Data->NetOutput in subgraph but not while body + // 1. Const->NetOutput in subgraph + // 2. AtomicOp->NetOutput in subgraph + // 3. OutputContinuesRequiredOp->NetOutput in subgraph + // 4. Data->NetOutput in subgraph but not while body std::string op_type; - bool input_const_flag = NodeUtils::GetConstOpType(in_node, op_type); - if ((in_node->GetType() == DATA) && !IsWhileBodyOutput(in_data_anchor)) { - input_const_flag = true; - } - - if (input_const_flag) { + bool insert_flag = NodeUtils::GetConstOpType(in_node, op_type) || + IsAtomicRequired(in_node, peer_out_anchor->GetIdx()) || IsOutputContinuesRequired(in_node) || + ((in_node->GetType() == DATA) && !IsWhileBodyOutput(in_data_anchor)); + if (insert_flag) { GELOGI("Insert MemcpyAsync node between %s and %s.", node->GetName().c_str(), in_node->GetName().c_str()); std::string name = in_node->GetName() + "_" + MEMCPYASYNC + "_" + std::to_string(memcpy_num_++); - if (InsertMemcpyNode(graph, peer_out_anchor, in_data_anchor, name) != SUCCESS) { + if (InsertMemcpyNode(graph, peer_out_anchor, {in_data_anchor}, name) != SUCCESS) { + GELOGE(FAILED, "Insert memcpy between %s and %s failed.", in_node->GetName().c_str(), node->GetName().c_str()); return FAILED; } } @@ -140,16 +183,14 @@ Status SubgraphPass::WhileInputNodes(const ComputeGraphPtr &graph, const NodePtr for (InDataAnchorPtr &in_data_anchor : node->GetAllInDataAnchors()) { const OutDataAnchorPtr &peer_out_anchor = in_data_anchor->GetPeerOutAnchor(); GE_IF_BOOL_EXEC(peer_out_anchor == nullptr, continue); - NodePtr in_node = peer_out_anchor->GetOwnerNode(); GE_CHECK_NOTNULL(in_node); - - // Need insert memcpy - // 4.
Input->While and Input link to other nodes + // Input->While and Input link to other nodes need insert memcpy if (peer_out_anchor->GetPeerInDataAnchors().size() > 1) { - GELOGI("Insert MemcpyAsync node between %s and %s.", node->GetName().c_str(), in_node->GetName().c_str()); + GELOGI("Input %s of While %s links to other nodes.", in_node->GetName().c_str(), node->GetName().c_str()); std::string name = in_node->GetName() + "_" + MEMCPYASYNC + "_" + std::to_string(memcpy_num_++); - if (InsertMemcpyNode(graph, peer_out_anchor, in_data_anchor, name) != SUCCESS) { + if (InsertMemcpyNode(graph, peer_out_anchor, {in_data_anchor}, name) != SUCCESS) { + GELOGE(FAILED, "Insert memcpy between %s and %s failed.", in_node->GetName().c_str(), node->GetName().c_str()); return FAILED; } } @@ -158,6 +199,249 @@ Status SubgraphPass::WhileInputNodes(const ComputeGraphPtr &graph, const NodePtr return SUCCESS; } +/** + * @ingroup ge + * @brief Check body subgraph of While op + * @param [in] graph: ComputeGraph. + * @param [in] node: While node. + * @return: 0 for SUCCESS / others for FAILED + */ +Status SubgraphPass::WhileBodySubgraph(const ComputeGraphPtr &graph, const NodePtr &node) { + ComputeGraphPtr while_body = GetWhileBodySubgraph(graph, node); + if (while_body == nullptr) { + GELOGE(FAILED, "while_body of %s is NULL.", node->GetName().c_str()); + return FAILED; + } + + NodePtr output_node = while_body->FindNode(NODE_NAME_NET_OUTPUT); + if (output_node == nullptr) { + GELOGE(FAILED, "net_output_node does not exist in graph %s.", while_body->GetName().c_str()); + return FAILED; + } + OpDescPtr output_desc = output_node->GetOpDesc(); + GE_CHECK_NOTNULL(output_desc); + std::unordered_map> node_to_attr_index; + for (const InDataAnchorPtr &in_data_anchor : output_node->GetAllInDataAnchors()) { + uint32_t index = 0; + if (!AttrUtils::GetInt(output_desc->GetInputDesc(in_data_anchor->GetIdx()), ATTR_NAME_PARENT_NODE_INDEX, index)) { + GELOGE(FAILED, "Get attr PARENT_NODE_INDEX failed, node %s:%u.", output_node->GetName().c_str(), + in_data_anchor->GetIdx()); + return FAILED; + } + MarkOutputIndex(in_data_anchor->GetPeerOutAnchor(), index, node_to_attr_index); + } + + std::set data_nodes; + std::set netoutput_input_indexes; + GetExchangeInOut(node_to_attr_index, data_nodes, netoutput_input_indexes); + return InsertMemcpyInWhileBody(while_body, data_nodes, output_node, netoutput_input_indexes); +} + +/** + * @ingroup ge + * @brief Get body subgraph of While op + * @param [in] graph: ComputeGraph. + * @param [in] node: While node.
+ * @return: body subgraph + */ +ComputeGraphPtr SubgraphPass::GetWhileBodySubgraph(const ComputeGraphPtr &graph, const NodePtr &node) { + OpDescPtr op_desc = node->GetOpDesc(); + if (op_desc == nullptr) { + GELOGE(FAILED, "op_desc is NULL."); + return nullptr; + } + + const std::vector &subgraph_instance_names = op_desc->GetSubgraphInstanceNames(); + std::string body_instance_name; + for (const std::string &instance_name : subgraph_instance_names) { + std::string subgraph_name; + if (op_desc->GetSubgraphNameByInstanceName(instance_name, subgraph_name) != GRAPH_SUCCESS) { + GELOGE(FAILED, "Get subgraph_name by instance_name %s failed, node:%s.", instance_name.c_str(), + node->GetName().c_str()); + return nullptr; + } + if (subgraph_name == ATTR_NAME_WHILE_BODY) { + body_instance_name = instance_name; + break; + } + } + + ComputeGraphPtr root_graph = GraphUtils::FindRootGraph(graph); + if (root_graph == nullptr) { + GELOGE(FAILED, "root_graph is NULL."); + return nullptr; + } + + return root_graph->GetSubgraph(body_instance_name); +} + +/** + * @ingroup ge + * @brief Mark output parent_node_index + * @param [in] peer_out_anchor: peer_out_anchor of NetOutput + * @param [in] index: parent_node_index of NetOutput + * @param [out] node_to_attr_index: key for node in subgraph, value for parent_node_index + * @return: void + */ +void SubgraphPass::MarkOutputIndex(const OutDataAnchorPtr &peer_out_anchor, uint32_t index, + std::unordered_map> &node_to_attr_index) { + if (peer_out_anchor == nullptr) { + return; + } + std::set visited_nodes; + std::stack nodes; + nodes.emplace(peer_out_anchor->GetOwnerNode()); + while (!nodes.empty()) { + NodePtr cur_node = nodes.top(); + nodes.pop(); + if (visited_nodes.count(cur_node) > 0) { + continue; + } + if (node_to_attr_index.count(cur_node) > 0) { + node_to_attr_index[cur_node].emplace_back(index); + } else { + node_to_attr_index[cur_node] = {index}; + } + for (const NodePtr &in_node : cur_node->GetInDataNodes()) { + nodes.emplace(in_node); + } + visited_nodes.emplace(cur_node); + } +} + +/** + * @ingroup ge + * @brief Get data_nodes / input_indexes of netoutput if need insert memcpy + * @param [in] node_to_attr_index: key for node in subgraph, value for parent_node_index + * @param [out] data_nodes: data_nodes need insert memcpy + * @param [out] netoutput_input_indexes: input_indexes of netoutput need insert memcpy + * @return: void + */ +void SubgraphPass::GetExchangeInOut(const std::unordered_map> &node_to_attr_index, + std::set &data_nodes, std::set &netoutput_input_indexes) { + for (const auto &item : node_to_attr_index) { + NodePtr node = item.first; + uint32_t input_index = 0; + if ((node->GetType() != DATA) || !AttrUtils::GetInt(node->GetOpDesc(), ATTR_NAME_PARENT_NODE_INDEX, input_index)) { + continue; + } + if (item.second.empty() || ((item.second.size() == 1) && (item.second[0] == input_index))) { + continue; + } + data_nodes.emplace(node); + + // Data node has and only has one output + OutDataAnchorPtr out_data_anchor = node->GetOutDataAnchor(0); + if (out_data_anchor == nullptr) { + continue; + } + for (const InDataAnchorPtr &peer_in_anchor : out_data_anchor->GetPeerInDataAnchors()) { + NodePtr out_node = peer_in_anchor->GetOwnerNode(); + if ((out_node->GetType() != NETOUTPUT) || (out_node->GetOpDesc() == nullptr)) { + continue; + } + uint32_t output_index = 0; + GeTensorDesc input_tensor = out_node->GetOpDesc()->GetInputDesc(peer_in_anchor->GetIdx()); + if (!AttrUtils::GetInt(input_tensor, ATTR_NAME_PARENT_NODE_INDEX, output_index)) { + 
continue; + } + if (input_index != output_index) { + netoutput_input_indexes.emplace(peer_in_anchor->GetIdx()); + } + } + } +} + +/** + * @ingroup ge + * @brief Insert memcpy node in while_body + * @param [in] graph: while_body + * @param [in] data_nodes: data_nodes need insert memcpy + * @param [in] output_node: NetOutput in while_body + * @param [in] netoutput_input_indexes: input_indexes of netoutput need insert memcpy + * @return: 0 for SUCCESS / others for FAILED + */ +Status SubgraphPass::InsertMemcpyInWhileBody(const ComputeGraphPtr &graph, const std::set &data_nodes, + const NodePtr &output_node, + const std::set &netoutput_input_indexes) { + for (const NodePtr &data_node : data_nodes) { + // Data node has and only has one output + OutDataAnchorPtr out_data_anchor = data_node->GetOutDataAnchor(0); + std::vector in_anchors; + for (const InDataAnchorPtr &peer_in_anchor : out_data_anchor->GetPeerInDataAnchors()) { + in_anchors.emplace_back(peer_in_anchor); + } + std::string name = data_node->GetName() + "_" + MEMCPYASYNC + "_" + std::to_string(memcpy_num_++); + GELOGD("Insert memcpy after while_body %s input_node %s.", graph->GetName().c_str(), data_node->GetName().c_str()); + if (InsertMemcpyNode(graph, out_data_anchor, in_anchors, name) != SUCCESS) { + GELOGE(FAILED, "Insert MemcpyAsync node %s after %s failed.", name.c_str(), data_node->GetName().c_str()); + return FAILED; + } + } + + for (uint32_t index : netoutput_input_indexes) { + InDataAnchorPtr in_data_anchor = output_node->GetInDataAnchor(index); + GE_CHECK_NOTNULL(in_data_anchor); + OutDataAnchorPtr peer_out_anchor = in_data_anchor->GetPeerOutAnchor(); + GE_CHECK_NOTNULL(peer_out_anchor); + std::string name = + peer_out_anchor->GetOwnerNode()->GetName() + "_" + MEMCPYASYNC + "_" + std::to_string(memcpy_num_++); + GELOGD("Insert memcpy after while_body %s output %u.", graph->GetName().c_str(), index); + if (InsertMemcpyNode(graph, peer_out_anchor, {in_data_anchor}, name) != SUCCESS) { + GELOGE(FAILED, "Insert MemcpyAsync node %s after %s failed.", name.c_str(), + peer_out_anchor->GetOwnerNode()->GetName().c_str()); + return FAILED; + } + } + + std::set memcpy_nodes; + std::set loop_body_nodes; + for (const NodePtr &data_node : data_nodes) { + // data_node has only one output node + NodePtr memcpy_node = data_node->GetOutDataNodes().at(0); + GE_CHECK_NOTNULL(memcpy_node); + memcpy_nodes.emplace(memcpy_node); + for (const NodePtr &out_node : memcpy_node->GetOutDataNodes()) { + loop_body_nodes.insert(out_node); + } + } + return InsertNoOp(graph, memcpy_nodes, loop_body_nodes); +} + +/** + * @ingroup ge + * @brief Insert NoOp node between memcpy_nodes and loop_body_nodes + * @param [in] graph: while_body + * @param [in] memcpy_nodes + * @param [in] loop_body_nodes + * @return: 0 for SUCCESS / others for FAILED + */ +Status SubgraphPass::InsertNoOp(const ComputeGraphPtr &graph, const std::set &memcpy_nodes, + const std::set &loop_body_nodes) { + if (memcpy_nodes.empty() || loop_body_nodes.empty()) { + return SUCCESS; + } + + OpDescBuilder noop_desc_builder("NoOp_for_Control", NOOP); + OpDescPtr noop_desc = noop_desc_builder.Build(); + NodePtr noop_node = graph->AddNode(noop_desc); + GE_CHECK_NOTNULL(noop_node); + for (const NodePtr &memcpy_node : memcpy_nodes) { + if (GraphUtils::AddEdge(memcpy_node->GetOutControlAnchor(), noop_node->GetInControlAnchor()) != GRAPH_SUCCESS) { + GELOGE(FAILED, "Add ctrl edge %s->%s failed.", memcpy_node->GetName().c_str(), noop_node->GetName().c_str()); + return FAILED; + } + } + for (const NodePtr 
&loop_body_node : loop_body_nodes) { + if (GraphUtils::AddEdge(noop_node->GetOutControlAnchor(), loop_body_node->GetInControlAnchor()) != GRAPH_SUCCESS) { + GELOGE(FAILED, "Add ctrl edge %s->%s failed.", noop_node->GetName().c_str(), loop_body_node->GetName().c_str()); + return FAILED; + } + } + + return SUCCESS; +} + /** * @ingroup ge * @brief Check is data->netoutput in while body @@ -172,7 +456,7 @@ bool SubgraphPass::IsWhileBodyOutput(const InDataAnchorPtr &in_data_anchor) { } // Check if parent_node is While - if (kWhileTypes.count(parent_node->GetType()) == 0) { + if (kWhileOpTypes.count(parent_node->GetType()) == 0) { return false; } @@ -184,27 +468,87 @@ bool SubgraphPass::IsWhileBodyOutput(const InDataAnchorPtr &in_data_anchor) { return AttrUtils::HasAttr(op_desc->GetInputDesc(in_data_anchor->GetIdx()), ATTR_NAME_PARENT_NODE_INDEX); } +/** + * @ingroup ge + * @brief Check is AtomicOp->NetOutput + * @param [in] node + * @param [in] out_index + * @return: true for AtomicOp->NetOutput / false for others + */ +bool SubgraphPass::IsAtomicRequired(const NodePtr &node, int64_t out_index) { + auto op_desc = node->GetOpDesc(); + if (op_desc != nullptr) { + bool is_atomic = false; + (void)ge::AttrUtils::GetBool(op_desc, ATOMIC_ATTR_IS_ATOMIC_NODE, is_atomic); + if (is_atomic) { + std::vector atomic_output_index; + // If GetListInt fail, atomic_output_index is empty. + (void)ge::AttrUtils::GetListInt(op_desc, ATOMIC_ATTR_OUTPUT_INDEX, atomic_output_index); + for (int64_t ind : atomic_output_index) { + if (ind == out_index) { + return true; + } + } + } + } + return false; +} + +/** + * @ingroup ge + * @brief Check is OutputContinuesRequiredOp->NetOutput + * @param [in] node + * @return: true for OutputContinuesRequiredOp->NetOutput / false for others + */ +bool SubgraphPass::IsOutputContinuesRequired(const NodePtr &node) { + OpDescPtr op_desc = node->GetOpDesc(); + if (op_desc != nullptr) { + bool continuous_output_flag = false; + (void)ge::AttrUtils::GetBool(op_desc, ATTR_NAME_CONTINUOUS_OUTPUT, continuous_output_flag); + bool no_padding_continuous_output_flag = false; + (void)ge::AttrUtils::GetBool(op_desc, ATTR_NAME_NOPADDING_CONTINUOUS_OUTPUT, no_padding_continuous_output_flag); + return continuous_output_flag || no_padding_continuous_output_flag; + } + return false; +} + +/** + * @ingroup ge + * @brief Check is InputContinuesRequiredOp->NetOutput + * @param [in] node + * @return: true for InputContinuesRequiredOp->NetOutput / false for others + */ +bool SubgraphPass::IsInputContinuesRequired(const NodePtr &node) { + OpDescPtr op_desc = node->GetOpDesc(); + if (op_desc != nullptr) { + bool continuous_input_flag = false; + (void)ge::AttrUtils::GetBool(op_desc, ATTR_NAME_CONTINUOUS_INPUT, continuous_input_flag); + bool no_padding_continuous_input_flag = false; + (void)ge::AttrUtils::GetBool(op_desc, ATTR_NAME_NOPADDING_CONTINUOUS_INPUT, no_padding_continuous_input_flag); + return continuous_input_flag || no_padding_continuous_input_flag; + } + return false; +} + /** * @ingroup ge * @brief Insert memcpy node * @param [in] graph * @param [in] out_anchor - * @param [in] in_anchor + * @param [in] in_anchors * @param [in] name * @return: 0 for success / others for fail */ Status SubgraphPass::InsertMemcpyNode(const ComputeGraphPtr &graph, const OutDataAnchorPtr &out_anchor, - const InDataAnchorPtr &in_anchor, const std::string &name) { + const std::vector &in_anchors, const std::string &name) { GE_CHECK_NOTNULL(out_anchor); - GE_CHECK_NOTNULL(in_anchor); NodePtr in_node = 
out_anchor->GetOwnerNode(); OpDescBuilder op_desc_builder(name, MEMCPYASYNC); OpDescPtr op_desc = op_desc_builder.AddInput("x", in_node->GetOpDesc()->GetOutputDesc(0)) .AddOutput("y", in_node->GetOpDesc()->GetOutputDesc(0)) .Build(); - if (GraphUtils::InsertNodeBefore(out_anchor, {in_anchor}, graph->AddNode(op_desc)) != GRAPH_SUCCESS) { - GELOGE(FAILED, "Insert MemcpyAsync node %s between %s->%s failed.", name.c_str(), in_node->GetName().c_str(), - in_anchor->GetOwnerNode()->GetName().c_str()); + if (GraphUtils::InsertNodeBefore(out_anchor, in_anchors, graph->AddNode(op_desc)) != GRAPH_SUCCESS) { + GELOGE(FAILED, "Insert MemcpyAsync node %s after %s failed.", name.c_str(), in_node->GetName().c_str()); return FAILED; } diff --git a/src/ge/graph/passes/subgraph_pass.h b/src/ge/graph/passes/subgraph_pass.h index 57e4e4c6..2308b1bd 100644 --- a/src/ge/graph/passes/subgraph_pass.h +++ b/src/ge/graph/passes/subgraph_pass.h @@ -66,23 +66,111 @@ class SubgraphPass : public GraphPass { /** * @ingroup ge - * @brief Check is data->netoutput in while body + * @brief Check body subgraph of While op + * @param [in] graph: ComputeGraph. + * @param [in] node: While node. + * @return: 0 for SUCCESS / others for FAILED + */ + Status WhileBodySubgraph(const ComputeGraphPtr &graph, const NodePtr &node); + + /** + * @ingroup ge + * @brief Get body subgraph of While op + * @param [in] graph: ComputeGraph. + * @param [in] node: While node. + * @return: body subgraph + */ + ComputeGraphPtr GetWhileBodySubgraph(const ComputeGraphPtr &graph, const NodePtr &node); + + /** + * @ingroup ge + * @brief Mark output parent_node_index + * @param [in] peer_out_anchor: peer_out_anchor of NetOutput + * @param [in] index: parent_node_index of NetOutput + * @param [out] node_to_attr_index: key for node in subgraph, value for parent_node_index + * @return: void + */ + void MarkOutputIndex(const OutDataAnchorPtr &peer_out_anchor, uint32_t index, + std::unordered_map> &node_to_attr_index); + + /** + * @ingroup ge + * @brief Get data_nodes / input_indexes of netoutput if need insert memcpy + * @param [in] node_to_attr_index: key for node in subgraph, value for parent_node_index + * @param [out] data_nodes: data_nodes need insert memcpy + * @param [out] netoutput_input_indexes: input_indexes of netoutput need insert memcpy + * @return: void + */ + void GetExchangeInOut(const std::unordered_map> &node_to_attr_index, + std::set &data_nodes, std::set &netoutput_input_indexes); + + /** + * @ingroup ge + * @brief Insert memcpy node in while_body + * @param [in] graph: while_body + * @param [in] data_nodes: data_nodes need insert memcpy + * @param [in] output_node: NetOutput in while_body + * @param [in] netoutput_input_indexes: input_indexes of netoutput need insert memcpy + * @return: 0 for SUCCESS / others for FAILED + */ + Status InsertMemcpyInWhileBody(const ComputeGraphPtr &graph, const std::set &data_nodes, + const NodePtr &output_node, const std::set &netoutput_input_indexes); + + /** + * @ingroup ge + * @brief Insert NoOp node between memcpy_nodes and loop_body_nodes + * @param [in] graph: while_body + * @param [in] memcpy_nodes + * @param [in] loop_body_nodes + * @return: 0 for SUCCESS / others for FAILED + */ + Status InsertNoOp(const ComputeGraphPtr &graph, const std::set &memcpy_nodes, + const std::set &loop_body_nodes); + + /** + * @ingroup ge + * @brief Check is Data->NetOutput in while body * @param [in] in_data_anchor - * @return: true for data->netoutput in while body / for false for others + * @return: true for 
Data->NetOutput in while body / false for others */ bool IsWhileBodyOutput(const InDataAnchorPtr &in_data_anchor); + /** + * @ingroup ge + * @brief Check is AtomicOp->NetOutput + * @param [in] node + * @param [in] out_index + * @return: true for AtomicOp->NetOutput / false for others + */ + bool IsAtomicRequired(const NodePtr &node, int64_t out_index); + + /** + * @ingroup ge + * @brief Check is OutputContinuesRequiredOp->NetOutput + * @param [in] node + * @return: true for OutputContinuesRequiredOp->NetOutput / false for others + */ + bool IsOutputContinuesRequired(const NodePtr &node); + + /** + * @ingroup ge + * @brief Check is InputContinuesRequiredOp->NetOutput + * @param [in] node + * @return: true for InputContinuesRequiredOp->NetOutput / false for others + */ + bool IsInputContinuesRequired(const NodePtr &node); + /** * @ingroup ge * @brief Insert memcpy node * @param [in] graph * @param [in] out_anchor - * @param [in] in_anchor + * @param [in] in_anchors * @param [in] name * @return: 0 for success / others for fail */ Status InsertMemcpyNode(const ComputeGraphPtr &graph, const OutDataAnchorPtr &out_anchor, - const InDataAnchorPtr &in_anchor, const std::string &name); + const std::vector &in_anchors, const std::string &name); // Append index for new memcpy node. uint32_t memcpy_num_{0}; diff --git a/src/ge/graph/passes/switch_pass.cc b/src/ge/graph/passes/switch_dead_branch_elimination.cc similarity index 94% rename from src/ge/graph/passes/switch_pass.cc rename to src/ge/graph/passes/switch_dead_branch_elimination.cc index 8230d294..c4ae4647 100644 --- a/src/ge/graph/passes/switch_pass.cc +++ b/src/ge/graph/passes/switch_dead_branch_elimination.cc @@ -14,7 +14,7 @@ * limitations under the License. */ -#include "graph/passes/switch_pass.h" +#include "graph/passes/switch_dead_branch_elimination.h" #include #include @@ -90,7 +90,8 @@ bool ParseOutDataAnchors(const NodePtr &node, const NodePtr &pred_node, OutDataA } } // namespace -Status SwitchPass::DeleteSwitchNode(NodePtr &node, NodePtr &pred_node, const OutDataAnchorPtr &active_out_data_anchor) { +Status SwitchDeadBranchElimination::DeleteSwitchNode(NodePtr &node, NodePtr &pred_node, + const OutDataAnchorPtr &active_out_data_anchor) { if (node == nullptr || active_out_data_anchor == nullptr) { GELOGE(FAILED, "parameter is null."); return FAILED; @@ -116,8 +117,7 @@ Status SwitchPass::DeleteSwitchNode(NodePtr &node, NodePtr &pred_node, const Out return IsolateAndDeleteNode(node, switch_io_map); } -Status SwitchPass::Run(NodePtr &node) { - GELOGD("SwitchPass running"); +Status SwitchDeadBranchElimination::Run(NodePtr &node) { if (node == nullptr) { GELOGE(PARAM_INVALID, "Param [node] must not be null."); return PARAM_INVALID; diff --git a/src/ge/graph/passes/switch_pass.h b/src/ge/graph/passes/switch_dead_branch_elimination.h similarity index 78% rename from src/ge/graph/passes/switch_pass.h rename to src/ge/graph/passes/switch_dead_branch_elimination.h index 04760843..4f2b9f02 100644 --- a/src/ge/graph/passes/switch_pass.h +++ b/src/ge/graph/passes/switch_dead_branch_elimination.h @@ -14,13 +14,13 @@ * limitations under the License. 
*/ -#ifndef GE_GRAPH_PASSES_SWITCH_PASS_H_ -#define GE_GRAPH_PASSES_SWITCH_PASS_H_ +#ifndef GE_GRAPH_PASSES_SWITCH_DEAD_BRANCH_ELIMINATION_H_ +#define GE_GRAPH_PASSES_SWITCH_DEAD_BRANCH_ELIMINATION_H_ #include "graph/passes/base_pass.h" namespace ge { -class SwitchPass : public BaseNodePass { +class SwitchDeadBranchElimination : public BaseNodePass { public: Status Run(NodePtr &node) override; @@ -29,4 +29,4 @@ class SwitchPass : public BaseNodePass { }; } // namespace ge -#endif // GE_GRAPH_PASSES_SWITCH_PASS_H_ +#endif // GE_GRAPH_PASSES_SWITCH_DEAD_BRANCH_ELIMINATION_H_ diff --git a/src/ge/graph/passes/switch_fusion_pass.cc b/src/ge/graph/passes/switch_fusion_pass.cc new file mode 100644 index 00000000..f475d857 --- /dev/null +++ b/src/ge/graph/passes/switch_fusion_pass.cc @@ -0,0 +1,249 @@ +/** + * Copyright 2019-2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "switch_fusion_pass.h" +#include +#include +#include "common/ge/ge_util.h" +#include "framework/common/debug/ge_log.h" +#include "framework/common/util.h" +#include "graph/debug/ge_attr_define.h" +#include "graph/utils/graph_utils.h" +#include "graph/utils/node_utils.h" +namespace ge { +namespace { +const int kSwitchDataInputIdx = 0; +const int kSwitchCondInputIdx = 1; +const int kSwitchFalseOutIdx = 0; +const int kSwitchTrueOutIdx = 1; +int GetSwitchOutDataIdx(const string fusion_group_id) { return std::stoi(fusion_group_id); } +} // namespace + +Status SwitchFusionPass::Run(NodePtr &node) { + GE_CHECK_NOTNULL(node); + GE_CHECK_NOTNULL(node->GetOpDesc()); + if (node->GetOpDesc()->GetType() != SWITCH && node->GetOpDesc()->GetType() != REFSWITCH) { + return SUCCESS; + } + GELOGD("Switch fusion pass in. Current switch node name is %s", node->GetName().c_str()); + // 1. find cond input + auto switch_in_cond_anchor = node->GetInDataAnchor(kSwitchCondInputIdx); + if (switch_in_cond_anchor->GetPeerOutAnchor() == nullptr) { + GELOGI("Switch %s in condition peer out anchor is null.", node->GetName().c_str()); + return FAILED; + } + auto switch_cond_in_node = switch_in_cond_anchor->GetPeerOutAnchor()->GetOwnerNode(); + GELOGD("Switch %s cond in data node is %s.", node->GetName().c_str(), switch_cond_in_node->GetName().c_str()); + if (switch_cond_in_node->GetOutDataNodesSize() == 1) { + GELOGI("This condition only has one switch, no fusion needed."); + return SUCCESS; + } + // 2. find other switch with same condition + for (const auto &out_data_node : switch_cond_in_node->GetOutDataNodes()) { + if (out_data_node->GetType() == SWITCH || out_data_node->GetType() == REFSWITCH) { + // 2.1 collect switch nodes that can be fused with the same cond_in_node + auto true_out_anchor = out_data_node->GetOutDataAnchor(kSwitchTrueOutIdx); + auto false_out_anchor = out_data_node->GetOutDataAnchor(kSwitchFalseOutIdx); + int branch_idx = true_out_anchor == nullptr ? 
kSwitchFalseOutIdx : kSwitchTrueOutIdx; + if (out_data_node->GetOutDataAnchor(branch_idx)->GetPeerInDataNodesSize() > 1) { + GELOGI("Current switch node %s has more than one output, needs to go to switch split first.", + out_data_node->GetName().c_str()); + continue; + } + string fusion_road_id; + fusion_road_id = GetFusionRoadId(std::to_string(branch_idx), out_data_node); + GELOGI("Switch node %s out idx %d, group_id is %s.", out_data_node->GetName().c_str(), branch_idx, + fusion_road_id.c_str()); + auto iter = switch_group_map_.find(fusion_road_id); + if (iter == switch_group_map_.end()) { + switch_group_map_.emplace(std::make_pair(fusion_road_id, std::set{out_data_node})); + } else { + // avoid the case where one cond node also serves as a data node + if (iter->second.count(out_data_node) == 0) { + iter->second.emplace(out_data_node); + } + } + } + } + // 3. fuse switch from different group + auto ret = FuseSwitchGroup(); + if (ret != SUCCESS) { + GELOGE(FAILED, "Fuse switch nodes with same final output to one failed."); + return ret; + } + return SUCCESS; +} +/* + * var1 ALLREDUCE/Cast var3 var1 var2 var3 ALLREDUCE/Cast + * \ / \ \ / \ | / \ / \ + * switch1 switch2 switch3 ======> AdamApplyOne / \--->switch1 + * \ | / \ / | + * AdamApplyOne / mul <--- identity + * \ / + * mul + */ +Status SwitchFusionPass::FuseSwitchGroup() { + for (auto &key_2_switch_group : switch_group_map_) { + if (key_2_switch_group.second.size() == 1) { + // a single-switch group has nothing to fuse + continue; + } + // 1. Insert Identity node + NodePtr remain_switch = *key_2_switch_group.second.begin(); + auto switch_out_anchor_idx = GetSwitchOutDataIdx(key_2_switch_group.first); + auto identity_node = InsertIdentityNode(remain_switch, switch_out_anchor_idx); + if (identity_node == nullptr) { + GELOGE(INTERNAL_ERROR, "Create Identity op for switch %s fail.", remain_switch->GetName().c_str()); + return FAILED; + } + // 2. Remove all switch nodes between data anchors. + string hccl_group_id; + for (const auto &switch_node : key_2_switch_group.second) { + GELOGI("Get corresponding SWITCH node %s. Out data anchor idx is %d.", switch_node->GetName().c_str(), + switch_out_anchor_idx); + // get hccl group id for remain switch + if (AttrUtils::GetStr(switch_node->GetOpDesc(), ATTR_NAME_HCCL_FUSED_GROUP, hccl_group_id)) { + GELOGI("Get hccl group id %s of switch node %s.", hccl_group_id.c_str(), switch_node->GetName().c_str()); + } + auto switch_peer_in_data_anchor = + switch_node->GetOutDataAnchor(switch_out_anchor_idx)->GetPeerInDataAnchors().at(0); + GE_RETURN_WITH_LOG_IF_ERROR(RemoveSwitchBetweenTwoNode(switch_out_anchor_idx, switch_node)); + GE_RETURN_WITH_LOG_IF_ERROR(GraphUtils::AddEdge(identity_node->GetOutControlAnchor(), + switch_peer_in_data_anchor->GetOwnerNode()->GetInControlAnchor()), + "Link control edge from identity %s to out node %s.", + identity_node->GetName().c_str(), + switch_peer_in_data_anchor->GetOwnerNode()->GetName().c_str()); + } + GELOGI("Start fusion switch nodes. 
Switch_nodes_set size is %zu", key_2_switch_group.second.size()); + // 3. Fuse all switches into one; the first one is remain_switch + GE_RETURN_WITH_LOG_IF_ERROR(FuseSwitchNodesToOne(remain_switch, key_2_switch_group.second)); + if (!hccl_group_id.empty()) { + AttrUtils::SetStr(remain_switch->GetOpDesc(), ATTR_NAME_HCCL_FUSED_GROUP, hccl_group_id); + GELOGI("Set attr ATTR_NAME_HCCL_FUSED_GROUP for Stream node %s, value is %s.", remain_switch->GetName().c_str(), + hccl_group_id.c_str()); + } + // Link switch to identity + GraphUtils::AddEdge(remain_switch->GetOutDataAnchor(switch_out_anchor_idx), identity_node->GetInDataAnchor(0)); + } + return SUCCESS; +} +/* + * var1---- + * cond---- + * var2---- + */ +Status SwitchFusionPass::RemoveSwitchBetweenTwoNode(const int switch_out_anchor_idx, const NodePtr &switch_node) { + auto switch_in_data_anchor = switch_node->GetInDataAnchor(kSwitchDataInputIdx); + auto switch_in_cond_anchor = switch_node->GetInDataAnchor(kSwitchCondInputIdx); + // here we assume that after switch split, one switch node has only one data output, so taking the first is OK. + auto switch_peer_in_data_anchor = switch_node->GetOutDataAnchor(switch_out_anchor_idx)->GetPeerInDataAnchors().at(0); + // 2.1 unlink all data edge from switch to out_node + GE_RETURN_WITH_LOG_IF_ERROR( + GraphUtils::RemoveEdge(switch_node->GetOutDataAnchor(switch_out_anchor_idx), switch_peer_in_data_anchor), + "Remove edge from switch %s to out node %s.", switch_node->GetName().c_str(), + switch_peer_in_data_anchor->GetOwnerNode()->GetName().c_str()); + // 2.2 replace data edge from switch_data_in_node to switch_data_out_node + if (switch_in_data_anchor->GetPeerOutAnchor() == nullptr) { + GELOGI("Switch %s in data peer out anchor is null.", switch_node->GetName().c_str()); + return FAILED; + } + auto switch_in_node = switch_in_data_anchor->GetPeerOutAnchor()->GetOwnerNode(); + GELOGI("Switch %s in data node is %s.", switch_node->GetName().c_str(), switch_in_node->GetName().c_str()); + GE_RETURN_WITH_LOG_IF_ERROR(GraphUtils::ReplaceEdgeDst(switch_in_data_anchor->GetPeerOutAnchor(), + switch_in_data_anchor, switch_peer_in_data_anchor), + "ReplaceEdgeDst from switch_data_in_node %s to switch_out_node %s.", + switch_in_node->GetName().c_str(), + switch_peer_in_data_anchor->GetOwnerNode()->GetName().c_str()); + // 2.3 link control edge from switch_data_in_node to switch + GE_RETURN_WITH_LOG_IF_ERROR( + GraphUtils::AddEdge(switch_in_node->GetOutControlAnchor(), switch_node->GetInControlAnchor()), + "Link control edge from switch_data_in_node %s to switch node %s failed.", switch_in_node->GetName().c_str(), + switch_node->GetName().c_str()); + return SUCCESS; +} + +Status SwitchFusionPass::FuseSwitchNodesToOne(NodePtr &remain_switch, const std::set switch_nodes_set) { + auto iter = ++switch_nodes_set.begin(); + while (iter != switch_nodes_set.end()) { + GE_RETURN_WITH_LOG_IF_ERROR(GraphUtils::CopyInCtrlEdges(*iter, remain_switch), + "Copy in control edge from %s to %s failed.", (*iter)->GetName().c_str(), + remain_switch->GetName().c_str()); + GE_RETURN_WITH_LOG_IF_ERROR(NodeUtils::MoveOutputEdges(*iter, remain_switch), + "Move output edges from %s to %s failed.", (*iter)->GetName().c_str(), + remain_switch->GetName().c_str()); + if ((*iter)->GetOutDataNodesSize() == 0) { + auto ret = IsolateAndDeleteNode(const_cast(*iter), {}); + if (ret == SUCCESS) { + GELOGI("IsolateAndDeleteNode Switch node %s", (*iter)->GetName().c_str()); + } + } else { + GELOGI("Switch node %s has more than one out data node, keep it.", 
(*iter)->GetName().c_str()); + } + iter++; + } + // link data input for remain switch + auto cond_node = remain_switch->GetInDataAnchor(kSwitchCondInputIdx)->GetPeerOutAnchor()->GetOwnerNode(); + GELOGI("Get cond node %s of switch node %s.", cond_node->GetName().c_str(), remain_switch->GetName().c_str()); + GE_RETURN_WITH_LOG_IF_ERROR( + GraphUtils::AddEdge(cond_node->GetOutDataAnchor(0), remain_switch->GetInDataAnchor(kSwitchDataInputIdx)), + "Fail to add edge from cond_node %s to remain_switch %s.", cond_node->GetName().c_str(), + remain_switch->GetName().c_str()); + return SUCCESS; +} + +// The fusion road id is "<kept branch idx>-<terminal node name>": walk forward along data edges +// until a NetOutput, another Switch, or a node without data outputs is met. +const string SwitchFusionPass::GetFusionRoadId(const string branch_id, const NodePtr &switch_node) { + std::deque queue; + queue.push_back(switch_node); + std::stringstream group_id; + group_id << branch_id; + + while (!queue.empty()) { + NodePtr node = queue.front(); + queue.pop_front(); + if (node->GetOutDataNodesSize() == 0) { + group_id << "-" << node->GetName(); + GELOGI("Switch node %s, group id is %s", switch_node->GetName().c_str(), group_id.str().c_str()); + return group_id.str(); + } + for (const auto &out_data_node : node->GetOutDataNodes()) { + if (out_data_node->GetType() == NETOUTPUT || out_data_node->GetType() == SWITCH || + out_data_node->GetType() == REFSWITCH) { + // meeting NETOUTPUT or another Switch marks the end of the current road + group_id << "-" << node->GetName(); + GELOGI("Switch node %s, group id is %s", switch_node->GetName().c_str(), group_id.str().c_str()); + return group_id.str(); + } + queue.emplace_back(out_data_node); + } + } + return group_id.str(); +} +NodePtr SwitchFusionPass::InsertIdentityNode(const NodePtr &remain_switch, const int out_data_anchor_idx) { + const std::string identity_name = remain_switch->GetOpDesc()->GetName() + "_" + IDENTITY; + ComputeGraphPtr graph = remain_switch->GetOwnerComputeGraph(); + auto data_desc = remain_switch->GetOpDesc()->GetOutputDesc(out_data_anchor_idx); + OpDescPtr op_desc = MakeShared(identity_name, IDENTITY); + if (op_desc == nullptr) { + GELOGE(FAILED, "Create Identity op %s: create op_desc fail.", identity_name.c_str()); + return nullptr; + } + if ((op_desc->AddInputDesc(data_desc) != GRAPH_SUCCESS) || (op_desc->AddOutputDesc(data_desc) != GRAPH_SUCCESS)) { + GELOGE(INTERNAL_ERROR, "Create Identity op %s: add input/output desc fail.", identity_name.c_str()); + return nullptr; + } + GELOGI("Create Identity op:%s.", identity_name.c_str()); + return graph->AddNode(op_desc); +} +} // namespace ge \ No newline at end of file diff --git a/src/ge/graph/passes/switch_fusion_pass.h b/src/ge/graph/passes/switch_fusion_pass.h new file mode 100644 index 00000000..10ba5dad --- /dev/null +++ b/src/ge/graph/passes/switch_fusion_pass.h @@ -0,0 +1,37 @@ +/** + * Copyright 2019-2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +#ifndef GE_GRAPH_PASSES_SWITCH_FUSION_PASS_H_ +#define GE_GRAPH_PASSES_SWITCH_FUSION_PASS_H_ + +#include +#include "graph/passes/base_pass.h" +namespace ge { +class SwitchFusionPass : public BaseNodePass { + public: + Status Run(NodePtr &node) override; + + private: + Status FuseSwitchGroup(); + Status RemoveSwitchBetweenTwoNode(const int switch_out_anchor_idx, const NodePtr &switch_node); + Status FuseSwitchNodesToOne(NodePtr &remain_switch, const std::set switch_nodes_set); + const string GetFusionRoadId(const string branch_id, const NodePtr &switch_node); + NodePtr InsertIdentityNode(const NodePtr &remain_switch, const int out_data_anchor_idx); + map> switch_group_map_; +}; +} // namespace ge + +#endif // GE_GRAPH_PASSES_SWITCH_FUSION_PASS_H_ diff --git a/src/ge/graph/passes/switch_op_pass.cc b/src/ge/graph/passes/switch_op_pass.cc index 7eae40f8..b501804f 100644 --- a/src/ge/graph/passes/switch_op_pass.cc +++ b/src/ge/graph/passes/switch_op_pass.cc @@ -23,18 +23,19 @@ #include #include #include "common/ge/ge_util.h" +#include "ge/ge_api_types.h" #include "framework/common/debug/ge_log.h" #include "framework/common/debug/log.h" #include "framework/common/ge_inner_error_codes.h" #include "framework/common/types.h" #include "graph/common/omg_util.h" #include "graph/debug/ge_attr_define.h" +#include "graph/ge_context.h" #include "graph/utils/type_utils.h" namespace ge { Status SwitchOpPass::Run(ComputeGraphPtr graph) { GELOGD("SwitchOpPass Enter"); - GE_CHK_STATUS_RET(CheckCycleDependence(graph), "CheckCycleDependence fail."); for (auto &switch_node : switch_nodes_) { @@ -93,7 +94,7 @@ Status SwitchOpPass::ReplaceSwitchNode(ComputeGraphPtr &graph, NodePtr &switch_n OutDataAnchorPtr peer_data_anchor = nullptr; OutDataAnchorPtr peer_cond_anchor = nullptr; GE_CHK_BOOL_EXEC(BypassSwitchNode(switch_node, peer_data_anchor, peer_cond_anchor) == SUCCESS, return FAILED, - "Bypass switch node fail."); + "Bypass switch node %s fail.", switch_node->GetName().c_str()); GE_CHECK_NOTNULL(peer_data_anchor); GE_CHECK_NOTNULL(peer_cond_anchor); OpDescPtr cond_desc = peer_cond_anchor->GetOwnerNode()->GetOpDesc(); @@ -268,6 +269,15 @@ NodePtr SwitchOpPass::CreateStreamSwitchNode(ComputeGraphPtr &graph, const NodeP GELOGE(FAILED, "Create op_desc fail, StreamSwitch:%s.", node_name.c_str()); return nullptr; } + // mark hccl group id + std::string hccl_group_id; + if (AttrUtils::GetStr(switch_node->GetOpDesc(), ATTR_NAME_HCCL_FUSED_GROUP, hccl_group_id)) { + (void)AttrUtils::SetStr(op_desc, ATTR_NAME_HCCL_FUSED_GROUP, hccl_group_id); + GELOGI("Set attr ATTR_NAME_HCCL_FUSED_GROUP for Stream_Switch %s, value is %s.", node_name.c_str(), + hccl_group_id.c_str()); + } else { + GELOGI("Cannot find attr ATTR_NAME_HCCL_FUSED_GROUP for node %s.", switch_node->GetName().c_str()); + } if (!AttrUtils::SetInt(op_desc, ATTR_NAME_SWITCH_DATA_TYPE, RT_SWITCH_INT32) || !AttrUtils::SetInt(op_desc, ATTR_NAME_STREAM_SWITCH_COND, (int64_t)RT_EQUAL)) { @@ -333,62 +343,66 @@ NodePtr SwitchOpPass::CreateMemcpyAsyncNode(ComputeGraphPtr &graph, const OutDat /// Status SwitchOpPass::CombineSwitchNode(ComputeGraphPtr &graph) { for (auto iter = cond_node_map_.begin(); iter != cond_node_map_.end(); ++iter) { - OutDataAnchorPtr peer_cond_anchor = iter->first; - GE_CHECK_NOTNULL(peer_cond_anchor); - std::list false_switch_list = iter->second[SWITCH_FALSE_OUTPUT]; - std::list true_switch_list = iter->second[SWITCH_TRUE_OUTPUT]; - std::set same_cond_switch; - same_cond_switch.insert(false_switch_list.begin(), false_switch_list.end()); - same_cond_switch.insert(true_switch_list.begin(), true_switch_list.end()); + for (auto group_iter = iter->second.begin(); group_iter != iter->second.end(); ++group_iter) { + OutDataAnchorPtr peer_cond_anchor = iter->first; + GE_CHECK_NOTNULL(peer_cond_anchor); + std::list false_switch_list = group_iter->second[SWITCH_FALSE_OUTPUT]; + std::list true_switch_list = group_iter->second[SWITCH_TRUE_OUTPUT]; + std::set same_cond_switch; + same_cond_switch.insert(false_switch_list.begin(), false_switch_list.end()); + same_cond_switch.insert(true_switch_list.begin(), true_switch_list.end()); - NodePtr cond_node = peer_cond_anchor->GetOwnerNode(); - GELOGI("CombineSwitchNode: cond_node=%s", cond_node->GetName().c_str()); + NodePtr cond_node = peer_cond_anchor->GetOwnerNode(); + GELOGI("CombineSwitchNode: cond_node=%s", cond_node->GetName().c_str()); - NodePtr cast_node = CreateCastOp(graph, peer_cond_anchor); - GE_CHK_BOOL_EXEC(cast_node != nullptr, return FAILED, "Create cast_node fail."); + NodePtr cast_node = CreateCastOp(graph, peer_cond_anchor); + GE_CHK_BOOL_EXEC(cast_node != nullptr, return FAILED, "Create cast_node fail."); - NodePtr active_node = CreateActiveNode(graph, cond_node); - GE_CHK_BOOL_EXEC(active_node != nullptr, return FAILED, "Create StreamActive node fail."); - GE_CHK_STATUS(GraphUtils::AddEdge(cast_node->GetOutControlAnchor(), active_node->GetInControlAnchor()), - "StreamActive add ctl edge fail."); - if (SetActiveLabelList(active_node, {cast_node->GetName()}) != SUCCESS) { - GELOGE(FAILED, "SetActiveLabelList for node %s fail.", active_node->GetName().c_str()); - return FAILED; - } - - const std::string cond_group = cond_node->GetName(); - for (uint32_t i = 0; i < SWITCH_OUTPUT_NUM; ++i) { - bool true_branch_flag = (i == SWITCH_TRUE_OUTPUT); - std::list &switch_list = (true_branch_flag ? true_switch_list : false_switch_list); - GE_IF_BOOL_EXEC(switch_list.empty(), continue); - - // select first stream_switch - NodePtr stream_switch = switch_list.front(); - OpDescPtr switch_desc = stream_switch->GetOpDesc(); - GE_CHECK_NOTNULL(switch_desc); - switch_desc->SetName(cond_group + "/" + STREAMSWITCH + (true_branch_flag ? "_t" : "_f")); - stream_switch_nodes_.emplace_back(stream_switch); - need_label_nodes_.emplace_back(stream_switch); - - // 0_input: original pred input, 1_input: constant node - GE_CHK_STATUS_RET(AddConstNode(graph, stream_switch), "Add const node fail"); - GE_CHK_STATUS(GraphUtils::RemoveEdge(peer_cond_anchor, stream_switch->GetInDataAnchor(0)), - "StreamSwitch remove data edge fail."); - GE_CHK_STATUS(GraphUtils::AddEdge(cast_node->GetOutDataAnchor(0), stream_switch->GetInDataAnchor(0)), - "Cast add data edge fail."); - - for (NodePtr &node : switch_list) { - GE_CHECK_NOTNULL(node); - GE_IF_BOOL_EXEC(node != stream_switch, { - GE_CHK_STATUS(GraphUtils::RemoveEdge(peer_cond_anchor, node->GetInDataAnchor(0)), - "StreamSwitch remove data edge fail."); - }); - GE_CHK_STATUS(ModifySwitchInCtlEdges(node, cast_node, same_cond_switch), "ModifySwitchInCtlEdges fail"); - GE_CHK_STATUS(ModifySwitchOutCtlEdges(node, stream_switch, active_node), "ModifySwitchOutCtlEdges fail"); + NodePtr active_node = CreateActiveNode(graph, cond_node); + GE_CHK_BOOL_EXEC(active_node != nullptr, return FAILED, "Create StreamActive node fail."); + GE_CHK_STATUS(GraphUtils::AddEdge(cast_node->GetOutControlAnchor(), active_node->GetInControlAnchor()), + "StreamActive add ctl edge fail."); + if (SetActiveLabelList(active_node, {cast_node->GetName()}) != SUCCESS) { + GELOGE(FAILED, "SetActiveLabelList for node %s fail.", active_node->GetName().c_str()); + return FAILED; } - GE_CHK_STATUS(GraphUtils::AddEdge(active_node->GetOutControlAnchor(), stream_switch->GetInControlAnchor()), - "StreamActive add ctl edge fail."); + const std::string cond_group = cond_node->GetName(); + for (uint32_t i = 0; i < SWITCH_OUTPUT_NUM; ++i) { + bool true_branch_flag = (i == SWITCH_TRUE_OUTPUT); + std::list &switch_list = (true_branch_flag ? true_switch_list : false_switch_list); + GE_IF_BOOL_EXEC(switch_list.empty(), continue); + + // select first stream_switch + NodePtr stream_switch = switch_list.front(); + OpDescPtr switch_desc = stream_switch->GetOpDesc(); + GE_CHECK_NOTNULL(switch_desc); + std::string node_name = cond_group + "/" + STREAMSWITCH + (true_branch_flag ? "_t" : "_f"); + node_name = CheckDuplicateName(node_name); + switch_desc->SetName(node_name); + stream_switch_nodes_.emplace_back(stream_switch); + need_label_nodes_.emplace_back(stream_switch); + + // 0_input: original pred input, 1_input: constant node + GE_CHK_STATUS_RET(AddConstNode(graph, stream_switch), "Add const node fail"); + GE_CHK_STATUS(GraphUtils::RemoveEdge(peer_cond_anchor, stream_switch->GetInDataAnchor(0)), + "StreamSwitch remove data edge fail."); + GE_CHK_STATUS(GraphUtils::AddEdge(cast_node->GetOutDataAnchor(0), stream_switch->GetInDataAnchor(0)), + "Cast add data edge fail."); + + for (NodePtr &node : switch_list) { + GE_CHECK_NOTNULL(node); + GE_IF_BOOL_EXEC(node != stream_switch, { + GE_CHK_STATUS(GraphUtils::RemoveEdge(peer_cond_anchor, node->GetInDataAnchor(0)), + "StreamSwitch remove data edge fail."); + }); + GE_CHK_STATUS(ModifySwitchInCtlEdges(node, cast_node, same_cond_switch), "ModifySwitchInCtlEdges fail"); + GE_CHK_STATUS(ModifySwitchOutCtlEdges(node, stream_switch, active_node), "ModifySwitchOutCtlEdges fail"); + } + + GE_CHK_STATUS(GraphUtils::AddEdge(active_node->GetOutControlAnchor(), stream_switch->GetInControlAnchor()), + "StreamActive add ctl edge fail."); + } } } return SUCCESS; @@ -479,9 +493,11 @@ Status SwitchOpPass::BypassSwitchNode(NodePtr &switch_node, OutDataAnchorPtr &pe GE_CHK_BOOL_EXEC(switch_node != nullptr, return FAILED, "Switch_node is null."); for (uint32_t idx = 0; idx < SWITCH_INPUT_NUM; ++idx) { InDataAnchorPtr in_data_anchor = switch_node->GetInDataAnchor(idx); - GE_CHK_BOOL_EXEC(in_data_anchor != nullptr, return FAILED, "Check Switch input anchor fail."); + GE_CHK_BOOL_EXEC(in_data_anchor != nullptr, return FAILED, "node[%s] Check Switch input anchor fail.", + switch_node->GetName().c_str()); OutDataAnchorPtr peer_out_anchor = in_data_anchor->GetPeerOutAnchor(); - GE_CHK_BOOL_EXEC(peer_out_anchor != nullptr, return FAILED, "node[%s] Check Pre node output anchor fail."); + GE_CHK_BOOL_EXEC(peer_out_anchor != nullptr, return FAILED, "node[%s] Check Pre node output anchor fail.", + switch_node->GetName().c_str()); // Remove Switch data input. The loop unlinks both Switch inputs (data at index 0, predicate at index 1); the peer anchors handed back through peer_data_anchor / peer_cond_anchor let the caller rewire the graph around the bypassed Switch. 
GE_CHK_STATUS_RET(GraphUtils::RemoveEdge(peer_out_anchor, in_data_anchor), "remove edge failed"); @@ -528,6 +544,36 @@ Status SwitchOpPass::FindSwitchCondInput(bool pass_switch_flag, OutDataAnchorPtr return SUCCESS; } +int SwitchOpPass::GetGroupId(const NodePtr &node) { + string tailing_optimization_option; + bool is_tailing_optimization = false; + auto ret = GetContext().GetOption(OPTION_EXEC_ENABLE_TAILING_OPTIMIZATION, tailing_optimization_option); + if (ret == GRAPH_SUCCESS) { + // "1" means it's True from frontend option + is_tailing_optimization = (tailing_optimization_option == "1"); + GELOGI("Option ge.exec.isTailingOptimization is %s", tailing_optimization_option.c_str()); + } + if (!is_tailing_optimization) { + return 0; + } + + string hccl_group_id; + if (!AttrUtils::GetStr(node->GetOpDesc(), ATTR_NAME_HCCL_FUSED_GROUP, hccl_group_id)) { + GELOGI("Node is %s, cannot find hccl group id", node->GetName().c_str()); + return 0; + } + // The hccl group id is expected to end with "_<num>"; the trailing number serves as the group id (0 means ungrouped). + auto key_index = hccl_group_id.find_last_of('_'); + auto key_num = hccl_group_id.substr(key_index + 1, hccl_group_id.length() - key_index); + GELOGI("Node is %s, Hccl group id is %s, key_num is %s", node->GetName().c_str(), hccl_group_id.c_str(), + key_num.c_str()); + int num = atoi(key_num.c_str()); + if (num == 0) { + return 0; + } + GELOGI("Hccl group id is %s, group id is %d", hccl_group_id.c_str(), num); + return num; +} + /// /// @brief Mark Switch Branch /// @param [in] peer_cond_anchor @@ -540,12 +586,27 @@ Status SwitchOpPass::MarkBranchs(OutDataAnchorPtr &peer_cond_anchor, NodePtr &st GE_CHECK_NOTNULL(stream_switch); auto it = cond_node_map_.find(peer_cond_anchor); if (it != cond_node_map_.end()) { - GE_IF_BOOL_EXEC(it->second.size() != SWITCH_OUTPUT_NUM, { - GELOGE(INTERNAL_ERROR, "cond_node_map_ check size fail, node: %s", stream_switch->GetName().c_str()); - return FAILED; - }); - it->second[index].emplace_back(stream_switch); + int switch_group_id = GetGroupId(stream_switch); + auto switch_group_it = it->second.find(switch_group_id); + if (switch_group_it == it->second.end()) { + std::list false_node_list; + std::list true_node_list; + std::list &node_list = true_branch_flag ? true_node_list : false_node_list; + node_list.emplace_back(stream_switch); + std::vector> switch_list; + switch_list.emplace_back(false_node_list); + switch_list.emplace_back(true_node_list); + (void)it->second.emplace(switch_group_id, switch_list); + } else { + GE_IF_BOOL_EXEC(switch_group_it->second.size() != SWITCH_OUTPUT_NUM, { + GELOGE(INTERNAL_ERROR, "cond_node_map_ check size fail, node: %s", stream_switch->GetName().c_str()); + return FAILED; + }); + switch_group_it->second[index].emplace_back(stream_switch); + } } else { + int switch_group_id = GetGroupId(stream_switch); + map>> switch_group_map; std::list false_node_list; std::list true_node_list; std::list &node_list = true_branch_flag ? 
true_node_list : false_node_list; @@ -553,8 +614,9 @@ Status SwitchOpPass::MarkBranchs(OutDataAnchorPtr &peer_cond_anchor, NodePtr &st std::vector> switch_list; switch_list.emplace_back(false_node_list); switch_list.emplace_back(true_node_list); + (void)switch_group_map.emplace(switch_group_id, switch_list); auto result = cond_node_map_.insert( - std::pair>>(peer_cond_anchor, switch_list)); + std::pair>>>(peer_cond_anchor, switch_group_map)); GE_IF_BOOL_EXEC(!result.second, { GELOGE(INTERNAL_ERROR, "cond_node_map_ insert fail, node: %s", stream_switch->GetName().c_str()); return FAILED; @@ -574,7 +636,8 @@ NodePtr SwitchOpPass::CreateCastOp(ComputeGraphPtr &graph, OutDataAnchorPtr &pee OpDescPtr cond_desc = peer_cond_anchor->GetOwnerNode()->GetOpDesc(); GE_CHK_BOOL_EXEC(cond_desc != nullptr, return nullptr, "Get cond_desc fail."); - const std::string cast_name = cond_desc->GetName() + "_" + CAST; + std::string cast_name = cond_desc->GetName() + "_" + CAST; + cast_name = CheckDuplicateName(cast_name); GELOGI("Create cast_node: %s, input datatype:DT_BOOL, out datatype:DT_INT32", cast_name.c_str()); OpDescPtr cast_desc = MakeShared(cast_name, CAST); if (cast_desc == nullptr) { @@ -1116,4 +1179,22 @@ void SwitchOpPass::ReplaceControlEdges(NodePtr &old_node, NodePtr &new_node) { CopyControlEdges(old_node, new_node); RemoveControlEdges(old_node); } +/// +/// @brief Clear Status, used for subgraph pass +/// @return +/// +Status SwitchOpPass::ClearStatus() { + switch_nodes_.clear(); + merge_nodes_.clear(); + enter_nodes_.clear(); + switch_cyclic_map_.clear(); + bypass_nodes_.clear(); + branch_head_nodes_.clear(); + stream_switch_nodes_.clear(); + need_label_nodes_.clear(); + cond_node_map_.clear(); + switch_node_map_.clear(); + node_num_map_.clear(); + return SUCCESS; +} } // namespace ge diff --git a/src/ge/graph/passes/switch_op_pass.h b/src/ge/graph/passes/switch_op_pass.h index 7e107e3b..704adcc1 100644 --- a/src/ge/graph/passes/switch_op_pass.h +++ b/src/ge/graph/passes/switch_op_pass.h @@ -94,6 +94,7 @@ namespace ge { class SwitchOpPass : public GraphPass { public: Status Run(ComputeGraphPtr graph); + Status ClearStatus() override; private: Status ReplaceSwitchNode(ComputeGraphPtr &graph, NodePtr &switch_node); @@ -146,6 +147,8 @@ class SwitchOpPass : public GraphPass { void ReplaceControlEdges(NodePtr &old_node, NodePtr &new_node); + int GetGroupId(const NodePtr &node); + std::vector switch_nodes_; std::vector merge_nodes_; std::vector enter_nodes_; @@ -155,7 +158,7 @@ class SwitchOpPass : public GraphPass { std::set branch_head_nodes_; std::vector stream_switch_nodes_; std::vector need_label_nodes_; - std::unordered_map>> cond_node_map_; + std::unordered_map>>> cond_node_map_; std::unordered_map> switch_node_map_; std::unordered_map node_num_map_; }; diff --git a/src/ge/graph/passes/switch_split_pass.cc b/src/ge/graph/passes/switch_split_pass.cc new file mode 100644 index 00000000..07f59d20 --- /dev/null +++ b/src/ge/graph/passes/switch_split_pass.cc @@ -0,0 +1,145 @@ +/** + * Copyright 2019-2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+ * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "switch_split_pass.h" +#include +#include "common/ge/ge_util.h" +#include "framework/common/debug/ge_log.h" +#include "framework/common/util.h" +#include "graph/utils/graph_utils.h" +#include "graph/utils/node_utils.h" + +using namespace ge; +namespace { +const string output_false = "output_false"; +const string output_true = "output_true"; +string GetOutputDescName(const int idx) { return idx == 0 ? output_false : output_true; } +graphStatus CopyInDataEdges(const NodePtr &src_node, NodePtr &dst_node) { + if ((src_node == nullptr) || (dst_node == nullptr)) { + GELOGE(GRAPH_FAILED, "Parameter is nullptr"); + return GRAPH_PARAM_INVALID; + } + auto src_data_in_nodes = src_node->GetInDataNodes(); + if (src_data_in_nodes.empty()) { + return GRAPH_SUCCESS; + } + for (const auto &in_data_anchor : src_node->GetAllInDataAnchors()) { + auto input_desc = src_node->GetOpDesc()->GetInputDesc(in_data_anchor->GetIdx()); + auto ret = + GraphUtils::AddEdge(in_data_anchor->GetPeerOutAnchor(), dst_node->GetInDataAnchor(in_data_anchor->GetIdx())); + if (ret != GRAPH_SUCCESS) { + GELOGE(GRAPH_FAILED, "Failed to add data edge from %s to %s when copy in data edge from %s to %s", + in_data_anchor->GetPeerOutAnchor()->GetOwnerNode()->GetName().c_str(), dst_node->GetName().c_str(), + src_node->GetName().c_str(), dst_node->GetName().c_str()); + return ret; + } + } + return GRAPH_SUCCESS; +} +NodePtr CreateSwitchFromOld(const int index, const NodePtr &old_switch, const OutDataAnchorPtr &out_data_anchor) { + auto graph = old_switch->GetOwnerComputeGraph(); + // 1. create new switch op desc + string new_switch_name = old_switch->GetName() + "_" + std::to_string(index); + auto new_switch_opdesc = MakeShared(new_switch_name, old_switch->GetType()); + if (new_switch_opdesc == nullptr) { + GELOGE(OUT_OF_MEMORY, "Failed to insert switch node, name %s", new_switch_name.c_str()); + return nullptr; + } + // 2. add input_desc & output_desc for new switch + Status ret; + for (const auto &in_data_anchor : old_switch->GetAllInDataAnchors()) { + auto input_desc = old_switch->GetOpDesc()->GetInputDesc(in_data_anchor->GetIdx()); + ret = new_switch_opdesc->AddInputDesc(in_data_anchor->GetIdx(), input_desc); + if (ret != SUCCESS) { + GELOGE(FAILED, "Add Input desc failed for new switch %s.", new_switch_name.c_str()); + return nullptr; + } + } + auto output_desc = old_switch->GetOpDesc()->GetOutputDesc(out_data_anchor->GetIdx()); + // we got out_data_anchor, another out_data_anchor is (1-idx), because idx is 0 or 1. 
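+ // E.g. for idx == 1 (the surviving true branch): "output_false" is added first as output 0, then "output_true" as output 1, both sharing the old switch's output desc.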
+ auto ret1 = new_switch_opdesc->AddOutputDesc(GetOutputDescName(1 - out_data_anchor->GetIdx()), output_desc); + auto ret2 = new_switch_opdesc->AddOutputDesc(GetOutputDescName(out_data_anchor->GetIdx()), output_desc); + if (ret1 != SUCCESS || ret2 != SUCCESS) { + GELOGE(FAILED, "Add Output desc failed for new switch %s.", new_switch_name.c_str()); + return nullptr; + } + GELOGI("Insert new switch node %s.", new_switch_name.c_str()); + return graph->AddNode(new_switch_opdesc); +} +} // namespace +namespace ge { +Status SwitchSplitPass::Run(NodePtr &node) { + // To handle one out data anchor with multi peer input data anchor + GE_CHECK_NOTNULL(node); + OpDescPtr op_desc = node->GetOpDesc(); + GE_CHECK_NOTNULL(op_desc); + if (op_desc->GetType() != SWITCH && op_desc->GetType() != REFSWITCH) { + return SUCCESS; + } + if (op_desc->GetName().find("apply_one_adam") == string::npos) { + // Currently only for BERT optimization; will be fixed later. + GELOGI("Current switch node name is %s, ignore it.", op_desc->GetName().c_str()); + return SUCCESS; + } + GELOGI("Switch split pass in. Current switch node name is %s", op_desc->GetName().c_str()); + int index = 0; + // 1. find all outputs + for (const auto &out_data_anchor : node->GetAllOutDataAnchors()) { + if (out_data_anchor->GetPeerInDataNodesSize() < 2) { + GELOGI("Switch node %s %d-th out data anchor only has 1 peer_in_data_anchor. Ignore it.", node->GetName().c_str(), + out_data_anchor->GetIdx()); + continue; + } + for (const auto &peer_in_anchor : out_data_anchor->GetPeerInDataAnchors()) { + NodePtr new_switch = CreateSwitchFromOld(index, node, out_data_anchor); + if (new_switch == nullptr) { + GELOGW("Insert switch node failed."); + return FAILED; + } + // 1.3 copy in/out edges from old switch to new switch + auto ret1 = CopyInDataEdges(node, new_switch); + auto ret2 = GraphUtils::CopyInCtrlEdges(node, new_switch); + auto ret3 = GraphUtils::CopyOutCtrlEdges(node, new_switch); + if (ret1 != GRAPH_SUCCESS || ret2 != GRAPH_SUCCESS || ret3 != GRAPH_SUCCESS) { + GELOGE(FAILED, "Copy edge from %s to %s failed.", node->GetName().c_str(), new_switch->GetName().c_str()); + return FAILED; + } + if (out_data_anchor->Unlink(peer_in_anchor) != GRAPH_SUCCESS) { + GELOGE(FAILED, "Unlink from old switch %s out data anchor %d to peer in anchor failed.", + node->GetName().c_str(), out_data_anchor->GetIdx()); + return FAILED; + } + auto ret4 = GraphUtils::AddEdge(new_switch->GetOutDataAnchor(out_data_anchor->GetIdx()), peer_in_anchor); + if (ret4 != GRAPH_SUCCESS) { + GELOGE(FAILED, "Replace out data edge from old switch %s to new switch %s failed.", node->GetName().c_str(), + new_switch->GetName().c_str()); + return FAILED; + } + AddRePassNode(new_switch); + index++; + } + } + // 2. isolate switch node with no data output + if (node->GetOutDataNodesSize() == 0) { + auto ret = IsolateAndDeleteNode(node, {}); + if (ret != SUCCESS) { + GELOGE(FAILED, "IsolateAndDelete switch node %s failed.", node->GetName().c_str()); + return FAILED; + } + GELOGI("IsolateAndDelete switch node %s.", node->GetName().c_str()); + } + return SUCCESS; +} +} // namespace ge diff --git a/src/ge/graph/passes/switch_split_pass.h b/src/ge/graph/passes/switch_split_pass.h new file mode 100644 index 00000000..69ab01c3 --- /dev/null +++ b/src/ge/graph/passes/switch_split_pass.h @@ -0,0 +1,28 @@ +/** + * Copyright 2019-2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. 
+ * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef GE_GRAPH_PASSES_SWITCH_SPLIT_PASS_H_ +#define GE_GRAPH_PASSES_SWITCH_SPLIT_PASS_H_ + +#include +#include "graph/passes/base_pass.h" +namespace ge { +class SwitchSplitPass : public BaseNodePass { + public: + Status Run(NodePtr &node) override; +}; +} // namespace ge +#endif // GE_GRAPH_PASSES_SWITCH_SPLIT_PASS_H_ diff --git a/src/ge/graph/passes/transop_depth_fusion_pass.cc b/src/ge/graph/passes/transop_depth_fusion_pass.cc index 68899e2e..c0c854b6 100644 --- a/src/ge/graph/passes/transop_depth_fusion_pass.cc +++ b/src/ge/graph/passes/transop_depth_fusion_pass.cc @@ -25,6 +25,7 @@ #include "graph/op_desc.h" #include "graph/utils/graph_utils.h" #include "graph/common/transop_util.h" +#include "graph/utils/node_utils.h" namespace ge { graphStatus TransOpDepthFusionPass::Run(ComputeGraphPtr graph) { @@ -38,6 +39,7 @@ graphStatus TransOpDepthFusionPass::Run(ComputeGraphPtr graph) { if (TransOpUtil::IsTransOp(node)) { continue; } + GELOGD("Current normal node is: %s, type: %s, begin in-depth recursion", node->GetName().c_str(), node->GetType().c_str()); for (const auto &out_anchor : node->GetAllOutDataAnchors()) { @@ -164,6 +166,13 @@ graphStatus TransOpDepthFusionPass::RecursiveInDepth(const InDataAnchorPtr &dst_ } bool TransOpDepthFusionPass::CheckNodeCanBeDeleted(const NodePtr &node) { + bool is_shape_unknown = false; + if (NodeUtils::GetNodeUnknownShapeStatus(*node, is_shape_unknown) == GRAPH_SUCCESS) { + if (is_shape_unknown) { + GELOGI("op:%s has unknown shape, cannot be deleted.", node->GetName().c_str()); + return false; + } + } return node->GetType() == RESHAPE || node->GetType() == REFORMAT || node->GetType() == SQUEEZE || node->GetType() == EXPANDDIMS; } diff --git a/src/ge/graph/passes/transop_symmetry_elimination_pass.cc b/src/ge/graph/passes/transop_symmetry_elimination_pass.cc index af75d9d0..0f8f30bf 100644 --- a/src/ge/graph/passes/transop_symmetry_elimination_pass.cc +++ b/src/ge/graph/passes/transop_symmetry_elimination_pass.cc @@ -90,7 +90,8 @@ bool TransOpSymmetryEliminationPass::DescAreSymmetry(const NodePtr &src_node, co const auto &dst_output_shape = dst_output_desc->GetShape().GetDims(); if (src_node->GetType() == CAST && dst_node->GetType() == CAST) { - return (src_input_dtype == dst_output_dtype) && (src_input_format == dst_output_format); + bool is_format_symmetry = (src_input_format == dst_output_format) || (dst_output_format == FORMAT_ND); + return (src_input_dtype == dst_output_dtype) && is_format_symmetry; } else { return (src_input_dtype == dst_output_dtype) && (src_input_shape == dst_output_shape) && (src_input_format == dst_output_format); diff --git a/src/ge/graph/passes/variable_op_pass.cc b/src/ge/graph/passes/variable_op_pass.cc index d5dedbdc..175a049a 100644 --- a/src/ge/graph/passes/variable_op_pass.cc +++ b/src/ge/graph/passes/variable_op_pass.cc @@ -96,7 +96,7 @@ bool IsTransSupport(const TransNodeInfo &trans_info) { } if (trans_info.node_type == RESHAPE || trans_info.node_type == REFORMAT) { return true; - } else if (trans_info.node_type == TRANSDATA) { + } else if (trans_info.node_type == TRANSDATA || 
trans_info.node_type == TRANSPOSED) { formats::TransArgs args{nullptr, trans_info.input.GetFormat(), trans_info.output.GetFormat(), diff --git a/src/ge/graph/passes/variable_ref_delete_op_pass.cc b/src/ge/graph/passes/variable_ref_delete_op_pass.cc index dfdb8335..cd5b9fe9 100644 --- a/src/ge/graph/passes/variable_ref_delete_op_pass.cc +++ b/src/ge/graph/passes/variable_ref_delete_op_pass.cc @@ -72,7 +72,7 @@ Status VariableRefDeleteOpPass::DealVariableRef(ge::ComputeGraphPtr &graph, ge:: GE_CHECK_NOTNULL(peer_node->GetOpDesc()); bool is_set_str = ge::AttrUtils::SetStr(peer_node->GetOpDesc(), REF_VAR_SRC_VAR_NAME, ref_var_src_var_name); - ge::NodePtr ref_var_src_var = graph->FindNode(ref_var_src_var_name); + ge::NodePtr ref_var_src_var = GraphUtils::FindNodeFromAllNodes(graph, ref_var_src_var_name); if (ref_var_src_var == nullptr) { GELOGE(FAILED, "get ref_var_src_var failed"); return FAILED; diff --git a/src/ge/graph/passes/variable_ref_useless_control_out_delete_pass.cc b/src/ge/graph/passes/variable_ref_useless_control_out_delete_pass.cc new file mode 100644 index 00000000..bd153184 --- /dev/null +++ b/src/ge/graph/passes/variable_ref_useless_control_out_delete_pass.cc @@ -0,0 +1,52 @@ +/** + * Copyright 2019-2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +#include "variable_ref_useless_control_out_delete_pass.h" + +namespace ge { + +Status VariableRefUselessControlOutDeletePass::Run(ge::ComputeGraphPtr graph) { + GE_CHECK_NOTNULL(graph); + for (const auto &node : graph->GetDirectNode()) { + if (node->GetType() != VARIABLE) { + continue; + } + std::string src_var_name; + if (!AttrUtils::GetStr(node->GetOpDesc(), REF_VAR_SRC_VAR_NAME, src_var_name)) { + continue; + } + auto src_nodes = node->GetInDataNodes(); + if (src_nodes.empty()) { + GELOGW("The variable ref name %s(ref %s) does not have an input node", node->GetName().c_str(), + src_var_name.c_str()); + continue; + } + auto &src_node = src_nodes.at(0); + auto controlled_nodes_vec = src_node->GetOutNodes(); + std::set controlled_nodes{controlled_nodes_vec.begin(), controlled_nodes_vec.end()}; + + auto out_control_anchor = node->GetOutControlAnchor(); + for (const auto &dst_node_anchor : out_control_anchor->GetPeerInControlAnchors()) { + if (controlled_nodes.count(dst_node_anchor->GetOwnerNode()) > 0) { + GELOGI("Unlink the duplicated control edge from variable ref %s to %s, prev node %s", node->GetName().c_str(), + dst_node_anchor->GetOwnerNode()->GetName().c_str(), src_node->GetName().c_str()); + out_control_anchor->Unlink(dst_node_anchor); + } + } + } + return SUCCESS; +} +} // namespace ge \ No newline at end of file diff --git a/src/ge/graph/passes/variable_ref_useless_control_out_delete_pass.h b/src/ge/graph/passes/variable_ref_useless_control_out_delete_pass.h new file mode 100644 index 00000000..307754da --- /dev/null +++ b/src/ge/graph/passes/variable_ref_useless_control_out_delete_pass.h @@ -0,0 +1,32 @@ +/** + * Copyright 2019-2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +#ifndef GE_VARIABLE_REF_USELESS_CONTROL_OUT_DELETE_PASS_H_ +#define GE_VARIABLE_REF_USELESS_CONTROL_OUT_DELETE_PASS_H_ + +#include +#include + +#include "framework/common/ge_inner_error_codes.h" +#include "inc/graph_pass.h" + +namespace ge { +class VariableRefUselessControlOutDeletePass : public GraphPass { + public: + Status Run(ge::ComputeGraphPtr graph); +}; +} // namespace ge +#endif // GE_VARIABLE_REF_USELESS_CONTROL_OUT_DELETE_PASS_H_ diff --git a/src/ge/graph/preprocess/graph_preprocess.cc b/src/ge/graph/preprocess/graph_preprocess.cc index 9850ef9b..f17d0395 100644 --- a/src/ge/graph/preprocess/graph_preprocess.cc +++ b/src/ge/graph/preprocess/graph_preprocess.cc @@ -21,12 +21,14 @@ #include #include "common/formats/format_transfers/format_transfer_nchw_nc1hwc0.h" #include "common/formats/format_transfers/format_transfer_nhwc_nc1hwc0.h" +#include "common/formats/format_transfers/format_transfer_transpose.h" #include "common/helper/model_helper.h" #include "common/math/math_util.h" +#include "common/util/error_manager/error_manager.h" #include "common/op/ge_op_utils.h" #include "framework/common/debug/ge_log.h" -#include "graph/common/transop_util.h" #include "graph/common/ge_call_wrapper.h" +#include "graph/common/transop_util.h" #include "graph/debug/ge_attr_define.h" #include "graph/ge_context.h" #include "graph/manager/graph_var_manager.h" @@ -36,6 +38,8 @@ #include "graph/passes/aicpu_constant_folding_pass.h" #include "graph/passes/assert_pass.h" #include "graph/passes/base_pass.h" +#include "graph/passes/common_subexpression_elimination_pass.h" +#include "graph/passes/cond_pass.h" #include "graph/passes/constant_folding_pass.h" #include "graph/passes/constant_fuse_same_pass.h" #include "graph/passes/control_trigger_pass.h" @@ -47,9 +51,9 @@ #include "graph/passes/for_pass.h" #include "graph/passes/get_original_format_pass.h" #include "graph/passes/guarantee_const_pass.h" +#include "graph/passes/hccl_group_pass.h" #include "graph/passes/hccl_memcpy_pass.h" #include "graph/passes/identity_pass.h" -#include "graph/passes/cond_pass.h" #include "graph/passes/infershape_pass.h" #include "graph/passes/iterator_op_pass.h" #include "graph/passes/merge_pass.h" @@ -61,15 +65,20 @@ #include "graph/passes/prevent_gradient_pass.h" #include "graph/passes/print_op_pass.h" #include "graph/passes/prune_pass.h" +#include "graph/passes/replace_transshape_pass.h" +#include "graph/passes/replace_with_empty_const_pass.h" #include "graph/passes/resource_pair_add_control_pass.h" #include "graph/passes/resource_pair_remove_control_pass.h" #include "graph/passes/save_pass.h" #include "graph/passes/shape_operate_op_remove_pass.h" #include "graph/passes/snapshot_pass.h" #include "graph/passes/stop_gradient_pass.h" +#include "graph/passes/subgraph_pass.h" +#include "graph/passes/switch_dead_branch_elimination.h" +#include "graph/passes/switch_fusion_pass.h" #include "graph/passes/switch_logic_remove_pass.h" #include "graph/passes/switch_op_pass.h" -#include "graph/passes/switch_pass.h" +#include "graph/passes/switch_split_pass.h" #include "graph/passes/unused_const_pass.h" #include "graph/passes/unused_op_remove_pass.h" #include "graph/passes/var_is_initialized_op_pass.h" @@ -78,6 +87,7 @@ #include "graph/passes/replace_with_empty_const_pass.h" #include "graph/passes/subgraph_pass.h" #include "graph/passes/replace_transshape_pass.h" +#include "graph/passes/cond_remove_pass.h" #include "graph/preprocess/insert_op/util_insert_aipp_op.h" #include "graph/types.h" #include "graph/utils/tensor_utils.h" 
@@ -87,18 +97,19 @@ #include "multi_batch_copy_graph.h" #include "runtime/dev.h" -#include "graph/passes/transop_nearby_allreduce_fusion_pass.h" -#include "graph/passes/reshape_remove_pass.h" #include "graph/passes/dimension_adjust_pass.h" #include "graph/passes/identify_reference_pass.h" #include "graph/passes/link_gen_mask_nodes_pass.h" #include "graph/passes/permute_pass.h" +#include "graph/passes/reshape_remove_pass.h" #include "graph/passes/same_transdata_breadth_fusion_pass.h" #include "graph/passes/transop_breadth_fusion_pass.h" #include "graph/passes/transop_depth_fusion_pass.h" +#include "graph/passes/transop_nearby_allreduce_fusion_pass.h" -#include "graph/passes/transop_without_reshape_fusion_pass.h" #include "graph/passes/cast_remove_pass.h" +#include "graph/passes/data_pass.h" +#include "graph/passes/transop_without_reshape_fusion_pass.h" #include "graph/passes/transpose_transdata_pass.h" #include "graph/passes/variable_op_pass.h" #include "graph/passes/variable_prepare_op_pass.h" @@ -162,6 +173,17 @@ void AddTransNodeAttr(const std::string &node_type, const GeTensorDesc &input, c !AttrUtils::SetStr(op_desc, FORMAT_TRANSFER_DST_FORMAT, TypeUtils::FormatToSerialString(output.GetFormat())), GELOGW("SetStr FORMAT_TRANSFER_DST_FORMAT failed");) } + + // For TransposeD node, the IR definition has perm attrs + if (node_type == TRANSPOSED) { + Format src_format = input.GetFormat(); + Format dst_format = output.GetFormat(); + std::vector perm_arg; + GE_CHK_BOOL_EXEC_WARN(formats::GetPermByForamt(src_format, dst_format, perm_arg) == SUCCESS, return, + "Get perm by format failed."); + GE_CHK_BOOL_EXEC_WARN(AttrUtils::SetListInt(op_desc, PERMUTE_ATTR_PERM, perm_arg), return, + "SetListInt PERMUTE_ATTR_PERM failed") + } // For cast node, the IR definition has src/dst attrs if (node_type == CAST) { GE_IF_BOOL_EXEC(!AttrUtils::SetInt(op_desc, CAST_ATTR_SRCT, static_cast(input.GetDataType())), @@ -326,6 +348,9 @@ Status UpdateVarFormats(const NodePtr &var, const GeTensorDesc &tensor_desc) { output_desc.SetFormat(tensor_desc.GetFormat()); output_desc.SetDataType(tensor_desc.GetDataType()); output_desc.SetShape(tensor_desc.GetShape()); + output_desc.SetOriginFormat(tensor_desc.GetOriginFormat()); + output_desc.SetOriginDataType(tensor_desc.GetOriginDataType()); + output_desc.SetOriginShape(tensor_desc.GetOriginShape()); GE_IF_BOOL_EXEC(var->GetOpDesc()->UpdateOutputDesc(0, output_desc) != GRAPH_SUCCESS, GELOGE(INTERNAL_ERROR, "UpdateOutputDesc failed"); return INTERNAL_ERROR;); @@ -336,6 +361,9 @@ Status UpdateVarFormats(const NodePtr &var, const GeTensorDesc &tensor_desc) { desc.SetFormat(tensor_desc.GetFormat()); desc.SetDataType(tensor_desc.GetDataType()); desc.SetShape(tensor_desc.GetShape()); + desc.SetOriginFormat(tensor_desc.GetOriginFormat()); + desc.SetOriginDataType(tensor_desc.GetOriginDataType()); + desc.SetOriginShape(tensor_desc.GetOriginShape()); GE_IF_BOOL_EXEC(var->GetOpDesc()->UpdateInputDesc(0, desc) != GRAPH_SUCCESS, GELOGE(INTERNAL_ERROR, "UpdateInputDesc failed"); return INTERNAL_ERROR;) @@ -399,7 +427,7 @@ Status RecoverTransRoadForVarRef(const std::set &nodes, const VarTransR GE_CHK_BOOL_EXEC((ge::AttrUtils::SetBool(last_node->GetOpDesc(), ge::ATTR_INSERTED_BY_GE, true)), return INTERNAL_ERROR, "Set attr ATTR_INSERTED_BY_GE failed."); } - if (!(road.empty()) && (UpdateVarFormats(var, road.rbegin()->input) != SUCCESS)) { + if (!(road.empty()) && (UpdateVarFormats(var, road.rbegin()->output) != SUCCESS)) { return INTERNAL_ERROR; } } @@ -589,16 +617,12 @@ Status 
ModifyInputFormatAndShape(NodePtr &node_ptr) { } input->SetFormat(FORMAT_NC1HWC0); - input->SetOriginFormat(FORMAT_NC1HWC0); input->SetShape(ge::GeShape(dst_shape_dims)); - input->SetOriginShape(ge::GeShape(dst_shape_dims)); auto output = op_desc->MutableOutputDesc(0); GE_CHECK_NOTNULL(output); output->SetFormat(FORMAT_NC1HWC0); - output->SetOriginFormat(FORMAT_NC1HWC0); output->SetShape(ge::GeShape(dst_shape_dims)); - output->SetOriginShape(ge::GeShape(dst_shape_dims)); int64_t size = 0; graphStatus graph_status = TensorUtils::GetTensorMemorySizeInBytes(*output, size); @@ -623,9 +647,7 @@ Status ModifyFormatAndShapeForSingleTensor(const GeTensorDescPtr &input_output) return FAILED; } input_output->SetFormat(FORMAT_NC1HWC0); - input_output->SetOriginFormat(FORMAT_NC1HWC0); input_output->SetShape(ge::GeShape(dst_shape_dims)); - input_output->SetOriginShape(ge::GeShape(dst_shape_dims)); return SUCCESS; } @@ -1022,7 +1044,7 @@ Status CheckIfNeedSetNdFormat(const NodePtr &node_ptr) { // In the dynamic shape process, transnode insertion by FE is advanced to the stage of whole // graph optimization, GE only sets the final data_type/format/shape information for variable, // data and netoutput, and no longer inserts the transnode. -Status ProcessInputFP16DynShape(NodePtr &node_ptr) { +Status ProcessInputFP16DynShape(NodePtr &node_ptr, bool &is_dynamic_batch, NodePtr &switchn_node) { GE_CHECK_NOTNULL(node_ptr); auto op_desc = node_ptr->GetOpDesc(); GE_CHECK_NOTNULL(op_desc); @@ -1035,23 +1057,39 @@ Status ProcessInputFP16DynShape(NodePtr &node_ptr) { } input->SetDataType(DT_FLOAT16); input->SetOriginDataType(DT_FLOAT16); - int64_t shape_size = 0; - ge::graphStatus graph_status = ge::TensorUtils::GetTensorMemorySizeInBytes(*input, shape_size); - if (graph_status != ge::GRAPH_SUCCESS) { - GELOGE(graph_status, "GetTensorSizeInBytes failed!"); + int64_t input_shape_size = 0; + int64_t output_shape_size = 0; + ge::graphStatus input_graph_status = ge::TensorUtils::GetTensorSizeInBytes(*input, input_shape_size); + ge::graphStatus output_graph_status = ge::TensorUtils::GetTensorMemorySizeInBytes(*input, output_shape_size); + if (input_graph_status != ge::GRAPH_SUCCESS && output_graph_status != ge::GRAPH_SUCCESS) { + GELOGE(GRAPH_FAILED, "GetTensorSize failed!"); return FAILED; } - ge::TensorUtils::SetSize(*input, shape_size); + ge::TensorUtils::SetSize(*input, input_shape_size); const GeTensorDescPtr &output = op_desc->MutableOutputDesc(0); GE_CHECK_NOTNULL(output); output->SetDataType(DT_FLOAT16); output->SetOriginDataType(DT_FLOAT16); - ge::TensorUtils::SetSize(*output, shape_size); - + ge::TensorUtils::SetSize(*output, output_shape_size); + if (is_dynamic_batch) { + GELOGI("The node [%s] dtype set fp16", switchn_node->GetName().c_str()); + auto switchn_op_desc = switchn_node->GetOpDesc(); + GE_CHECK_NOTNULL(switchn_op_desc); + auto switchn_input = switchn_op_desc->MutableInputDesc(0); + GE_CHECK_NOTNULL(switchn_input); + switchn_input->SetDataType(DT_FLOAT16); + switchn_input->SetOriginDataType(DT_FLOAT16); + for (uint32_t i = 0; i < switchn_node->GetAllOutDataAnchorsSize(); ++i) { + const GeTensorDescPtr &switchn_output = switchn_op_desc->MutableOutputDesc(i); + GE_CHECK_NOTNULL(switchn_output); + switchn_output->SetDataType(DT_FLOAT16); + switchn_output->SetOriginDataType(DT_FLOAT16); + } + } return SUCCESS; } -Status ProcessInputNC1HWC0DynShape(NodePtr &node_ptr) { +Status ProcessInputNC1HWC0DynShape(NodePtr &node_ptr, bool &is_dynamic_batch, NodePtr &switchn_node) { GE_CHECK_NOTNULL(node_ptr); auto 
op_desc = node_ptr->GetOpDesc(); GE_CHECK_NOTNULL(op_desc); @@ -1072,51 +1110,94 @@ Status ProcessInputNC1HWC0DynShape(NodePtr &node_ptr) { GELOGE(INTERNAL_ERROR, "modify format and shape failed"); return FAILED; } - + if (is_dynamic_batch) { + auto switchn_op_desc = switchn_node->GetOpDesc(); + GE_CHECK_NOTNULL(switchn_op_desc); + const GeTensorDescPtr &switchn_input = switchn_op_desc->MutableInputDesc(0); + if (ModifyFormatAndShapeForSingleTensor(switchn_input) != SUCCESS) { + GELOGE(INTERNAL_ERROR, "modify format and shape failed"); + return FAILED; + } + for (uint32_t i = 0; i < switchn_node->GetAllOutDataAnchorsSize(); ++i) { + auto switchn_output = switchn_op_desc->MutableOutputDesc(i); + GE_CHECK_NOTNULL(switchn_output); + old_format = switchn_output->GetFormat(); + old_shape = switchn_output->GetShape(); + if (ModifyFormatAndShapeForSingleTensor(switchn_output) != SUCCESS) { + GELOGE(INTERNAL_ERROR, "modify format and shape failed"); + return FAILED; + } + } + } return SUCCESS; } Status ProcessDataNodeDynShape(NodePtr &node_ptr) { + auto op_desc = node_ptr->GetOpDesc(); + GE_CHECK_NOTNULL(op_desc); + auto data_input = op_desc->MutableInputDesc(0); + GE_CHECK_NOTNULL(data_input); bool set_fp16 = false; if (!ge::AttrUtils::GetBool(node_ptr->GetOpDesc(), "input_fp16", set_fp16) || !set_fp16) { return SUCCESS; } for (auto const &next_node : node_ptr->GetOutNodes()) { if (next_node->GetType() == AIPP) { + ErrorManager::GetInstance().ATCReportErrMessage("E10049", {"opname"}, {node_ptr->GetName()}); GELOGE(INTERNAL_ERROR, - "This input node [%s] is linked to aipp, can not be set to fp16," - "please check your atc parma insert_op_conf, input_fp16_nodes.", + "This input op [%s] is linked to aipp, can not be set to fp16, " + "please check your atc parameter --insert_op_conf, --input_fp16_nodes.", node_ptr->GetName().c_str()); return FAILED; } } GELOGI("input_fp16 is found, the node name is %s.", node_ptr->GetName().c_str()); - if (ProcessInputFP16DynShape(node_ptr) != SUCCESS) { + bool is_dynamic_batch = false; + NodePtr switchn_node = nullptr; + if (CheckIfDynamicBatchScene(node_ptr, is_dynamic_batch, switchn_node)) { + GELOGE(INTERNAL_ERROR, "CheckIfDynamicBatchScene failed"); + return FAILED; + } + if (ProcessInputFP16DynShape(node_ptr, is_dynamic_batch, switchn_node) != SUCCESS) { GELOGE(INTERNAL_ERROR, "ProcessInputFP16 failed"); return FAILED; } // check if need to set format bool set_format = false; - if (!ge::AttrUtils::GetBool(node_ptr->GetOpDesc(), "input_set_nc1hwc0", set_format) || !set_format) { - return SUCCESS; - } - GELOGI("The format of node [%s] should be set NC1HWC0.", node_ptr->GetName().c_str()); - if (ProcessInputNC1HWC0DynShape(node_ptr) != SUCCESS) { - GELOGE(INTERNAL_ERROR, "ProcessInputNC1HWC0 failed"); - return FAILED; + (void)ge::AttrUtils::GetBool(node_ptr->GetOpDesc(), "input_set_nc1hwc0", set_format); + if (set_format) { + GELOGI("The format of node [%s] should be set NC1HWC0.", node_ptr->GetName().c_str()); + if (ProcessInputNC1HWC0DynShape(node_ptr, is_dynamic_batch, switchn_node) != SUCCESS) { + GELOGE(INTERNAL_ERROR, "ProcessInputNC1HWC0 failed"); + return FAILED; + } } return SUCCESS; } Status ProcessNetoutputNodeFp16Nc1hwc0DynShape(GeTensorDesc &src_desc, GeTensorDescPtr &net_output_input_desc, NodePtr &node) { + bool is_dynamic = CheckOpType(node, MERGE); + auto src_op_desc = node->GetOpDesc(); + GE_CHECK_NOTNULL(src_op_desc); ge::GeShape src_shape = src_desc.GetShape(); ge::Format src_format = src_desc.GetFormat(); - ge::DataType src_dtype = 
src_desc.GetDataType(); - if (src_dtype != DT_FLOAT16) { - net_output_input_desc->SetDataType(DT_FLOAT16); - net_output_input_desc->SetOriginDataType(DT_FLOAT16); + + net_output_input_desc->SetDataType(DT_FLOAT16); + net_output_input_desc->SetOriginDataType(DT_FLOAT16); + if (is_dynamic) { + auto merge_output = src_op_desc->MutableOutputDesc(0); + GE_CHECK_NOTNULL(merge_output); + merge_output->SetDataType(DT_FLOAT16); + merge_output->SetOriginDataType(DT_FLOAT16); + for (uint32_t i = 0; i < node->GetAllInDataAnchorsSize(); ++i) { + auto merge_input = src_op_desc->MutableInputDesc(i); + GE_CHECK_NOTNULL(merge_input); + merge_input->SetDataType(DT_FLOAT16); + merge_input->SetOriginDataType(DT_FLOAT16); + } } + if (src_format == FORMAT_NC1HWC0) { GELOGI("Format is NC1HWC0, no need to transfer"); return SUCCESS; @@ -1129,32 +1210,80 @@ Status ProcessNetoutputNodeFp16Nc1hwc0DynShape(GeTensorDesc &src_desc, GeTensorD } ge::GeShape dst_shape(dst_shape_dims); net_output_input_desc->SetFormat(FORMAT_NC1HWC0); - net_output_input_desc->SetOriginFormat(FORMAT_NC1HWC0); net_output_input_desc->SetShape(dst_shape); - net_output_input_desc->SetOriginShape(dst_shape); + if (is_dynamic) { + auto merge_out = src_op_desc->MutableOutputDesc(0); + GE_CHECK_NOTNULL(merge_out); + if (ModifyFormatAndShapeForSingleTensor(merge_out) != SUCCESS) { + GELOGE(INTERNAL_ERROR, "modify format and shape failed"); + return FAILED; + } + for (uint32_t i = 0; i < node->GetAllInDataAnchorsSize(); ++i) { + auto merge_in = src_op_desc->MutableInputDesc(i); + GE_CHECK_NOTNULL(merge_in); + if (ModifyFormatAndShapeForSingleTensor(merge_in) != SUCCESS) { + GELOGE(INTERNAL_ERROR, "modify format and shape failed"); + return FAILED; + } + } + } return SUCCESS; } +bool NeedUpdateOutputByOutputTypeParm(std::string &output_type, NodePtr &src_node, uint32_t src_index, + ge::DataType &dt) { + if (CheckIfSetOutputType(output_type, dt)) { + GELOGI("All output nodes should be set to the specified datatype."); + return true; + } + bool is_dynamic = CheckOpType(src_node, MERGE); + auto op_desc = src_node->GetOpDesc(); + if (is_dynamic) { + const InDataAnchorPtr &merge_input_anchor = src_node->GetInDataAnchor(0); + GE_RT_FALSE_CHECK_NOTNULL(merge_input_anchor); + const OutDataAnchorPtr &src_out_anchor = merge_input_anchor->GetPeerOutAnchor(); + GE_RT_FALSE_CHECK_NOTNULL(src_out_anchor); + src_index = static_cast<uint32_t>(src_out_anchor->GetIdx()); + auto src_merge_node = src_out_anchor->GetOwnerNode(); + GE_RT_FALSE_CHECK_NOTNULL(src_merge_node); + op_desc = src_merge_node->GetOpDesc(); + GE_RT_FALSE_CHECK_NOTNULL(op_desc); + } + vector<ge::DataType> output_data_type_vec; + vector<int64_t> index_vec; + if ((ge::AttrUtils::GetListDataType(op_desc, "_output_dt_list", output_data_type_vec)) && + (ge::AttrUtils::GetListInt(op_desc, "_output_dt_index", index_vec))) { + if (output_data_type_vec.size() != index_vec.size()) { + GELOGW("output_dt_list size does not match output_dt_index size"); + return false; + } + for (uint32_t i = 0; i < index_vec.size(); ++i) { + if (index_vec[i] == src_index) { + dt = output_data_type_vec[i]; + GELOGI("Find node %s output %u datatype should be set to %s.", op_desc->GetName().c_str(), i, + TypeUtils::DataTypeToSerialString(dt).c_str()); + return true; + } + } + } + return false; +} + Status ProcessNetoutputNodeDynShape(NodePtr &node, std::string &output_type) { auto op_desc = node->GetOpDesc(); GE_CHECK_NOTNULL(op_desc); ge::DataType output_data_type = ge::DT_FLOAT; - bool is_set_output_type = false; - if (output_type_str_to_datatype.find(output_type) !=
output_type_str_to_datatype.end()) { - output_data_type = output_type_str_to_datatype[output_type]; - is_set_output_type = true; - } else { - GELOGI("output_type [%s] is not set or set unexpected", output_type.c_str()); - is_set_output_type = false; - } for (const auto &in_anchor : node->GetAllInDataAnchors()) { auto index = static_cast<uint32_t>(in_anchor->GetIdx()); auto peer_out = in_anchor->GetPeerOutAnchor(); GE_CHECK_NOTNULL(peer_out); auto src_index = static_cast<uint32_t>(peer_out->GetIdx()); - auto own_node = peer_out->GetOwnerNode(); - OpDescPtr src_op_desc = own_node->GetOpDesc(); + auto src_node = peer_out->GetOwnerNode(); + GE_CHECK_NOTNULL(src_node); + bool is_dynamic = CheckOpType(src_node, MERGE); + + OpDescPtr src_op_desc = src_node->GetOpDesc(); GE_CHECK_NOTNULL(src_op_desc); auto net_output_input_desc = op_desc->MutableInputDesc(index); GE_CHECK_NOTNULL(net_output_input_desc); @@ -1163,7 +1292,7 @@ Status ProcessNetoutputNodeDynShape(NodePtr &node, std::string &output_type) { ge::Format src_format = src_op_desc->GetOutputDesc(src_index).GetFormat(); ge::DataType src_dtype = src_op_desc->GetOutputDesc(src_index).GetDataType(); // Update datatype - if (is_set_output_type) { + if (NeedUpdateOutputByOutputTypeParm(output_type, src_node, src_index, output_data_type)) { GELOGI("Enter into process output_type schedule"); if (src_dtype == output_data_type) { GELOGI("Data type is same, no need to transfer."); @@ -1171,22 +1300,46 @@ Status ProcessNetoutputNodeDynShape(NodePtr &node, std::string &output_type) { } net_output_input_desc->SetDataType(output_data_type); net_output_input_desc->SetOriginDataType(output_data_type); + if (is_dynamic) { + auto merge_output = src_op_desc->MutableOutputDesc(0); + GE_CHECK_NOTNULL(merge_output); + merge_output->SetDataType(output_data_type); + merge_output->SetOriginDataType(output_data_type); + for (uint32_t i = 0; i < src_node->GetAllInDataAnchorsSize(); ++i) { + auto merge_input = src_op_desc->MutableInputDesc(i); + GE_CHECK_NOTNULL(merge_input); + merge_input->SetDataType(output_data_type); + merge_input->SetOriginDataType(output_data_type); + } + } continue; } // output_node is not set, check if is_output_adjust_hw_layout is set bool set_fp16_nc1hwc0 = false; - if (AttrUtils::GetBool(src_op_desc, "output_set_fp16_nc1hwc0", set_fp16_nc1hwc0)) { - if (set_fp16_nc1hwc0) { - GELOGI("Node [%s] should be set FP16 and NC1HWC0", src_op_desc->GetName().c_str()); - if ((src_format != FORMAT_NCHW) && (src_format != FORMAT_NHWC) && (src_format != FORMAT_NC1HWC0)) { - GELOGE(INTERNAL_ERROR, "Format is not one of NCHW, NHWC, NC1HWC0."); - return FAILED; - } - GeTensorDesc src_desc(src_shape, src_format, src_dtype); - if (ProcessNetoutputNodeFp16Nc1hwc0DynShape(src_desc, net_output_input_desc, node) != SUCCESS) { - GELOGE(INTERNAL_ERROR, "Process netoutput fp16 nc1hwc0."); - return FAILED; - } + if (!is_dynamic) { + (void)AttrUtils::GetBool(src_op_desc, "output_set_fp16_nc1hwc0", set_fp16_nc1hwc0); + } else { + // need check dynamic scene, graph structure: node->merge->netoutput + const InDataAnchorPtr &merge_input_anchor = src_node->GetInDataAnchor(0); + GE_CHECK_NOTNULL(merge_input_anchor); + const OutDataAnchorPtr &src_out_anchor = merge_input_anchor->GetPeerOutAnchor(); + GE_CHECK_NOTNULL(src_out_anchor); + auto src_merge_node = src_out_anchor->GetOwnerNode(); + GE_CHECK_NOTNULL(src_merge_node); + auto src_merge_node_opdesc = src_merge_node->GetOpDesc(); + (void)AttrUtils::GetBool(src_merge_node_opdesc, "output_set_fp16_nc1hwc0", set_fp16_nc1hwc0); + } + + if
(set_fp16_nc1hwc0) { + GELOGI("Node [%s] should be set FP16 and NC1HWC0", src_op_desc->GetName().c_str()); + if ((src_format != FORMAT_NCHW) && (src_format != FORMAT_NHWC) && (src_format != FORMAT_NC1HWC0)) { + GELOGE(INTERNAL_ERROR, "Format is not one of NCHW, NHWC, NC1HWC0."); + return FAILED; + } + GeTensorDesc src_desc(src_shape, src_format, src_dtype); + if (ProcessNetoutputNodeFp16Nc1hwc0DynShape(src_desc, net_output_input_desc, src_node) != SUCCESS) { + GELOGE(INTERNAL_ERROR, "Process netoutput fp16 nc1hwc0."); + return FAILED; } } } @@ -1434,7 +1587,7 @@ Status GraphPrepare::AdjustDataOpOutput(const NodePtr &node) { } Status GraphPrepare::UpdateInput(const std::vector &user_input) { - compute_graph_->SaveDataFormat((ge::Format)(domi::GetContext().format)); + compute_graph_->SaveDataFormat(ge::TypeUtils::DomiFormatToFormat(domi::GetContext().format)); for (NodePtr &input_node : compute_graph_->GetDirectNode()) { GE_CHECK_NOTNULL(input_node); OpDescPtr op = input_node->GetOpDesc(); @@ -1481,14 +1634,14 @@ Status GraphPrepare::UpdateInput(const std::vector &user_input) { GELOGE(PARAM_INVALID, "input data size =%ld, shape_size =%ld.", size, shape_size); return FAILED; } - ge::TensorUtils::SetSize(desc, shape_size); - graphStatus graph_ret = op->UpdateInputDesc(0, desc); if (graph_ret != GRAPH_SUCCESS) { GELOGE(graph_ret, "UpdateInputDesc fail, graph_ret:%u", graph_ret); return graph_ret; } + // Size will be recalculated in the build stage + ge::TensorUtils::SetSize(desc, 0); graph_ret = op->UpdateOutputDesc(0, desc); if (graph_ret != GRAPH_SUCCESS) { GELOGE(graph_ret, "UpdateOutputDesc fail, graph_ret:%u", graph_ret); @@ -1511,8 +1664,7 @@ Status GraphPrepare::UpdateInput(const std::vector &user_input) { Status GraphPrepare::TryDoAipp() { // infer and with aipp configure file, then call aipp insert if ((!options_.train_graph_flag) && (!options_.insert_op_file.empty())) { - GraphUtils::DumpGEGraph(compute_graph_, "Before_insert_aipp"); - GraphUtils::DumpGEGraphToOnnx(*compute_graph_, "Before_insert_aipp"); + GE_DUMP(compute_graph_, "Before_insert_aipp"); Status ret = ge::InsertNewOpUtil::Instance().Init(); if (ret != SUCCESS) { GELOGE(INTERNAL_ERROR, "TryDoAipp: InsertNewOpUtil instance failed."); @@ -1543,18 +1695,16 @@ Status GraphPrepare::FormatAndShapeProcess() { GE_TIMESTAMP_START(InferOriginFormat1); ret = compute_graph_->InferOriginFormat(); GE_TIMESTAMP_END(InferOriginFormat1, "GraphPrepare::InferOriginFormat1"); + GE_DUMP(compute_graph_, "after_first_inferformat"); if (ret != SUCCESS) { GELOGE(ret, "Prepare Graph first inferformat failed"); return ret; } - GraphUtils::DumpGEGraph(compute_graph_, "after_first_inferformat"); - GraphUtils::DumpGEGraphToOnnx(*compute_graph_, "after_first_inferformat"); GE_TIMESTAMP_START(InferShapeForPreprocess); ret = InferShapeForPreprocess(); GE_TIMESTAMP_END(InferShapeForPreprocess, "GraphPrepare::InferShapeForPreprocess"); - GraphUtils::DumpGEGraph(compute_graph_, "after_infershape"); - GraphUtils::DumpGEGraphToOnnx(*compute_graph_, "after_infershape"); + GE_DUMP(compute_graph_, "after_infershape"); if (ret != SUCCESS) { GELOGE(GE_GRAPH_INFERSHAPE_FAILED, "Prepare Graph infershape failed"); return GE_GRAPH_INFERSHAPE_FAILED; @@ -1581,9 +1731,10 @@ Status GraphPrepare::ResourcePairProcess(const std::string &action) { if (options_.train_graph_flag) { try { if (action == "add") { - (void)control_pass.AddPass(new ResourcePairAddControlPass); + (void)control_pass.AddPass("ResourcePairProcess::ResourcePairAddControlPass", new 
ResourcePairAddControlPass); } else { - (void)control_pass.AddPass(new ResourcePairRemoveControlPass); + (void)control_pass.AddPass("ResourcePairProcess::ResourcePairRemoveControlPass", + new ResourcePairRemoveControlPass); } } catch (std::bad_alloc &e) { GELOGE(INTERNAL_ERROR, "Add pass failed, bad memory allocation occur, action:%s.", action.c_str()); @@ -1668,9 +1819,9 @@ Status GraphPrepare::OptimizeBeforeInfershape() { // Graph pass try { if (options_.train_graph_flag) { - (void)graph_passes_before_infershape.AddPass(new SavePass); + (void)graph_passes_before_infershape.AddPass("OptimizeBeforeInfershape::SavePass", new SavePass); } - (void)graph_passes_before_infershape.AddPass(new NetOutputPass); + (void)graph_passes_before_infershape.AddPass("OptimizeBeforeInfershape::NetOutputPass", new NetOutputPass); } catch (std::bad_alloc &e) { GELOGE(INTERNAL_ERROR, "Add pass failed, bad memory allocation occurs."); return INTERNAL_ERROR; @@ -1718,8 +1869,7 @@ Status GraphPrepare::Preprocess(const std::vector &user_input) { GELOGE(ret, "Check user input failed."); return ret; } - GraphUtils::DumpGEGraph(compute_graph_, "after_update_input"); - GraphUtils::DumpGEGraphToOnnx(*compute_graph_, "after_update_input"); + GE_DUMP(compute_graph_, "after_update_input"); GEPass ge_passes(compute_graph_); NamesToPass names_to_passes; @@ -1732,8 +1882,7 @@ Status GraphPrepare::Preprocess(const std::vector &user_input) { GELOGE(ret, "Run ForPass optimize for preprocess failed, ret:%u.", ret); return ret; } - GraphUtils::DumpGEGraph(compute_graph_, "after_for_pass"); - GraphUtils::DumpGEGraphToOnnx(*compute_graph_, "after_for_pass"); + GE_DUMP(compute_graph_, "after_for_pass"); GE_TIMESTAMP_START(netoutput_process); ret = ProcessNetOutput(); @@ -1748,8 +1897,7 @@ Status GraphPrepare::Preprocess(const std::vector &user_input) { GELOGE(ret, "Failed to do multi-batch processing"); return ret; } - GraphUtils::DumpGEGraph(compute_graph_, "after_multibatch_process"); - GraphUtils::DumpGEGraphToOnnx(*compute_graph_, "after_multibatch_process"); + GE_DUMP(compute_graph_, "after_multibatch_process"); ret = TryDoAipp(); if (ret != SUCCESS) { @@ -1763,8 +1911,7 @@ Status GraphPrepare::Preprocess(const std::vector &user_input) { GELOGE(ret, "FormatAndShape process failed"); return ret; } - GraphUtils::DumpGEGraph(compute_graph_, "after_inferformat_before_preprocess"); - GraphUtils::DumpGEGraphToOnnx(*compute_graph_, "after_inferformat_before_preprocess"); + GE_DUMP(compute_graph_, "after_inferformat_before_preprocess"); ProcessCCEFormat(); @@ -1794,16 +1941,14 @@ Status GraphPrepare::Preprocess(const std::vector &user_input) { } GELOGI("Update variable formats success."); - GraphUtils::DumpGEGraph(compute_graph_, "Optimize_after_preprocess"); - GraphUtils::DumpGEGraphToOnnx(*compute_graph_, "Optimize_after_preprocess"); + GE_DUMP(compute_graph_, "Optimize_after_preprocess"); return SUCCESS; } #define PP_RUN_AND_DUMP(name, func, ...) 
\ do { \ GE_RUN(Prepare, func, __VA_ARGS__); \ - GraphUtils::DumpGEGraph(compute_graph, "PrepareAfter" name); \ - GraphUtils::DumpGEGraphToOnnx(*compute_graph, "PrepareAfter" name); \ + GE_DUMP(compute_graph, "PrepareAfter" name); \ GELOGI("Prepare %s on graph %s success.", name, compute_graph->GetName().c_str()); \ } while (0) @@ -1826,20 +1971,72 @@ Status GraphPrepare::PrepareDynShape(ConstGraphPtr graph, const std::vectorGetDirectNode()) { + if (AttrUtils::GetStr(node->GetOpDesc(), ATTR_NAME_HCCL_FUSED_GROUP, group_id)) { + (void)AttrUtils::SetStr(node->GetOpDesc(), ATTR_NAME_HCCL_FUSED_GROUP, ""); + } + } + ret = compute_graph->TopologicalSorting(); + if (ret != SUCCESS) { + GELOGE(ret, "Graph topological sort failed, ret:%u.", ret); + return ret; + } + return SUCCESS; +} #undef PP_RUN_AND_DUMP #undef PP_RUN @@ -1888,8 +2085,7 @@ Status GraphPrepare::Prepare(ConstGraphPtr graph, const std::vector &u GraphOptimize graph_optimize; if (!domi::GetContext().train_flag) { - GraphUtils::DumpGEGraph(compute_graph_, "BeforeOriginalGraphForQuantize"); - GraphUtils::DumpGEGraphToOnnx(*compute_graph_, "BeforeOriginalGraphForQuantize"); + GE_DUMP(compute_graph_, "BeforeOriginalGraphForQuantize"); GE_TIMESTAMP_START(OptimizeOriginalGraphForQuantize); ret = graph_optimize.OptimizeOriginalGraphForQuantize(compute_graph_); GE_TIMESTAMP_END(OptimizeOriginalGraphForQuantize, "GraphPrepare::OptimizeOriginalGraphForQuantize"); @@ -1898,9 +2094,7 @@ Status GraphPrepare::Prepare(ConstGraphPtr graph, const std::vector &u return ret; } } - - GraphUtils::DumpGEGraph(compute_graph_, "BeforePreprocess"); - GraphUtils::DumpGEGraphToOnnx(*compute_graph_, "BeforePreprocess"); + GE_DUMP(compute_graph_, "BeforePreprocess"); GE_TIMESTAMP_START(Preprocess); ret = Preprocess(user_input); @@ -1915,9 +2109,7 @@ Status GraphPrepare::Prepare(ConstGraphPtr graph, const std::vector &u if (options_.local_fmk_op_flag) { graph_optimize.TranFrameOp(compute_graph_); } - - GraphUtils::DumpGEGraph(compute_graph_, "Prepare"); - GraphUtils::DumpGEGraphToOnnx(*compute_graph_, "Prepare"); + GE_DUMP(compute_graph_, "Prepare"); GE_TIMESTAMP_START(OptimizeOriginalGraph); const char *buffer_optimize_on = std::getenv("BUFFER_OPTIMIZE_ON"); @@ -1927,8 +2119,7 @@ Status GraphPrepare::Prepare(ConstGraphPtr graph, const std::vector &u ret = graph_optimize.OptimizeOriginalGraph(compute_graph_); } GE_TIMESTAMP_END(OptimizeOriginalGraph, "GraphPrepare::OptimizeOriginalGraph"); - GraphUtils::DumpGEGraph(compute_graph_, "PreProcessOptimizeOriginalGraphAfter"); - GraphUtils::DumpGEGraphToOnnx(*compute_graph_, "PreProcessOptimizeOriginalGraphAfter"); + GE_DUMP(compute_graph_, "PreProcessOptimizeOriginalGraphAfter"); if (ret != SUCCESS) { GELOGE(ret, "originalGraph optimize Failed"); return ret; @@ -2061,6 +2252,8 @@ Status GraphPrepare::InferShapeForPreprocess() { } InferShapePass infer_shape_pass; names_to_passes.emplace_back("InferShapePass", &infer_shape_pass); + ReplaceWithEmptyConstPass replace_with_empty_const_pass; + names_to_passes.emplace_back("ReplaceWithEmptyConstPass", &replace_with_empty_const_pass); DimensionComputePass dimension_compute_pass; names_to_passes.emplace_back("DimensionComputePass", &dimension_compute_pass); ConstantFoldingPass constant_folding_pass; @@ -2102,7 +2295,8 @@ Status GraphPrepare::PrepareOptimize() { PassManager original_graph_passes; // Graph pass try { - (void)original_graph_passes.AddPass(new ShapeOperateOpRemovePass); + (void)original_graph_passes.AddPass("PrepareOptimize::ShapeOperateOpRemovePass", new 
ShapeOperateOpRemovePass); + (void)original_graph_passes.AddPass("PrepareOptimize::ReplaceTransShapePass", new ReplaceTransShapePass); } catch (std::bad_alloc &e) { GELOGE(INTERNAL_ERROR, "Add pass failed, bad memory allocation occurs."); return INTERNAL_ERROR; @@ -2121,6 +2315,8 @@ EnterPass enter_pass; PrintOpPass print_pass; names_to_passes.emplace_back("EnterPass", &enter_pass); + CondPass cond_pass; + names_to_passes.emplace_back("CondPass", &cond_pass); if (options_.enable_print_op_pass) { names_to_passes.emplace_back("PrintOpPass", &print_pass); } @@ -2135,6 +2331,7 @@ PlaceholderWithDefaultPass placeholder_with_default_pass; GuaranteeConstPass guarantee_const_pass; VarIsInitializedOpPass var_is_initialized_pass; + ParallelConcatStartOpPass parallel_concat_start_op_pass; IdentityPass identity_pass(false); SnapshotPass snapshot_pass; if (!options_.train_graph_flag) { @@ -2148,6 +2345,7 @@ names_to_passes.emplace_back("SnapshotPass", &snapshot_pass); names_to_passes.emplace_back("GuaranteeConstPass", &guarantee_const_pass); names_to_passes.emplace_back("VarIsInitializedOpPass", &var_is_initialized_pass); + names_to_passes.emplace_back("ParallelConcatStartOpPass", &parallel_concat_start_op_pass); names_to_passes.emplace_back("IdentityPass", &identity_pass); GE_TIMESTAMP_START(names_to_passes); ret = ge_passes.Run(names_to_passes); @@ -2159,31 +2357,21 @@ PassManager graph_pass; try { - (void)graph_pass.AddPass(new PrunePass); - (void)graph_pass.AddPass(new NextIterationPass); - (void)graph_pass.AddPass(new ControlTriggerPass); - (void)graph_pass.AddPass(new SwitchOpPass); + (void)graph_pass.AddPass("PrepareOptimize::PrunePass", new PrunePass); + // todo: temporarily put the hccl memcpy insertion into graph preparation, to prevent it from inserting redundant memcpy nodes + (void)graph_pass.AddPass("PrepareOptimize::HcclMemcpyPass", new (std::nothrow) HcclMemcpyPass); } catch (std::bad_alloc &e) { GELOGE(INTERNAL_ERROR, "Add pass failed, bad memory allocation occurs."); return INTERNAL_ERROR; } + GE_TIMESTAMP_START(graph_passes); ret = graph_pass.Run(compute_graph_); + GE_TIMESTAMP_END(graph_passes, "GraphPrepare::GraphPasses"); if (ret != SUCCESS && ret != NOT_CHANGED) { GELOGE(ret, "Run graph passes optimize for preprocess failed, ret:%u.", ret); return ret; } - - NamesToPass identity_remove_pass; - GE_TIMESTAMP_START(identity_remove_pass); - IdentityPass identity_force_pass(true); // after SwitchOpPass - identity_remove_pass.emplace_back("IdentityPass", &identity_force_pass); - ret = ge_passes.Run(identity_remove_pass); - GE_TIMESTAMP_END(identity_remove_pass, "GraphPrepare::IdentityRemovePass"); - if (ret != SUCCESS) { - GELOGE(ret, "Run identity remove pass for preprocess failed, ret:%u.", ret); - return ret; - } // The constant for train is CONSTANTOP, and is CONSTANT for inference. They will be unified in future.
if (options_.train_graph_flag) { for (ge::NodePtr &n : compute_graph_->GetAllNodes()) { @@ -2193,7 +2381,6 @@ Status GraphPrepare::PrepareOptimize() { } } } - ret = compute_graph_->TopologicalSorting(); if (ret != SUCCESS) { GELOGE(ret, "Graph topological sort failed, ret:%u.", ret); @@ -2209,11 +2396,12 @@ Status GraphPrepare::OptimizeForPreprocess() { PassManager original_graph_passes; // Graph pass try { - (void)original_graph_passes.AddPass(new ConstantFuseSamePass); - (void)original_graph_passes.AddPass(new VariablePrepareOpPass); - (void)original_graph_passes.AddPass(new IteratorOpPass); - (void)original_graph_passes.AddPass(new ShapeOperateOpRemovePass); - (void)original_graph_passes.AddPass(new ReplaceTransShapePass); + (void)original_graph_passes.AddPass("OptimizeForPreprocess::ConstantFuseSamePass", new ConstantFuseSamePass); + (void)original_graph_passes.AddPass("OptimizeForPreprocess::VariablePrepareOpPass", new VariablePrepareOpPass); + (void)original_graph_passes.AddPass("OptimizeForPreprocess::IteratorOpPass", new IteratorOpPass); + (void)original_graph_passes.AddPass("OptimizeForPreprocess::ShapeOperateOpRemovePass", + new ShapeOperateOpRemovePass); + (void)original_graph_passes.AddPass("OptimizeForPreprocess::ReplaceTransShapePass", new ReplaceTransShapePass); } catch (std::bad_alloc &e) { GELOGE(INTERNAL_ERROR, "Add pass failed, bad memory allocation occurs."); return INTERNAL_ERROR; @@ -2267,8 +2455,8 @@ Status GraphPrepare::OptimizeForPreprocess() { names_to_passes.emplace_back("ParallelConcatStartOpPass", ¶llel_concat_start_op_pass); IdentityPass identity_pass(false); names_to_passes.emplace_back("IdentityPass", &identity_pass); - SwitchPass switch_pass; - names_to_passes.emplace_back("SwitchPass", &switch_pass); + SwitchDeadBranchElimination switch_dead_branch_elimination; + names_to_passes.emplace_back("SwitchDeadBranchElimination", &switch_dead_branch_elimination); SwitchLogicRemovePass switch_logic_remove_pass; names_to_passes.emplace_back("SwitchLogicRemovePass", &switch_logic_remove_pass); MergePass merge_pass; @@ -2283,12 +2471,13 @@ Status GraphPrepare::OptimizeForPreprocess() { PassManager graph_pass; try { - (void)graph_pass.AddPass(new PrunePass); - (void)graph_pass.AddPass(new NextIterationPass); - (void)graph_pass.AddPass(new ControlTriggerPass); - (void)graph_pass.AddPass(new SwitchOpPass); - (void)graph_pass.AddPass(new HcclMemcpyPass); - GE_IF_BOOL_EXEC(options_.train_graph_flag, (void)graph_pass.AddPass(new FlowCtrlPass);); + (void)graph_pass.AddPass("OptimizeForPreprocess::PrunePass", new PrunePass); + (void)graph_pass.AddPass("OptimizeForPreprocess::NextIterationPass", new NextIterationPass); + (void)graph_pass.AddPass("OptimizeForPreprocess::ControlTriggerPass", new ControlTriggerPass); + (void)graph_pass.AddPass("OptimizeForPreprocess::SwitchOpPass", new SwitchOpPass); + (void)graph_pass.AddPass("OptimizeForPreprocess::HcclMemcpyPass", new HcclMemcpyPass); + GE_IF_BOOL_EXEC(options_.train_graph_flag, + (void)graph_pass.AddPass("OptimizeForPreprocess::FlowCtrlPass", new FlowCtrlPass);); } catch (std::bad_alloc &e) { GELOGE(INTERNAL_ERROR, "Add pass failed, bad memory allocation occurs."); return INTERNAL_ERROR; @@ -2300,26 +2489,6 @@ Status GraphPrepare::OptimizeForPreprocess() { return ret; } - NamesToPass identity_remove_pass; - GE_TIMESTAMP_START(identity_remove_pass); - IdentityPass identity_force_pass(true); // after SwitchOpPass - identity_remove_pass.emplace_back("IdentityPass", &identity_force_pass); - ret = 
ge_passes.Run(identity_remove_pass); - GE_TIMESTAMP_END(identity_remove_pass, "GraphPrepare::IdentityRemovePass"); - if (ret != SUCCESS) { - GELOGE(ret, "Run identity remove pass for preprocess failed, ret:%u.", ret); - return ret; - } - // The constant for train is CONSTANTOP, and is CONSTANT for inference. They will be unified in future. - if (options_.train_graph_flag) { - for (ge::NodePtr &n : compute_graph_->GetAllNodes()) { - // This can ensure that n is not a null pointer - if (n->GetOpDesc()->GetType() == CONSTANT) { - n->GetOpDesc()->SetType(CONSTANTOP); - } - } - } - ret = compute_graph_->TopologicalSorting(); if (ret != SUCCESS) { GELOGE(ret, "Graph topological sort failed, ret:%u.", ret); @@ -2331,13 +2500,36 @@ Status GraphPrepare::OptimizeForPreprocess() { return SUCCESS; } +Status GraphPrepare::GraphEquivalentTransformation() { + NamesToPass names_to_pass; + ForPass for_pass; + names_to_pass.emplace_back("ForToWhilePass", &for_pass); + return GEPass(compute_graph_).Run(names_to_pass); +} + +Status GraphPrepare::ProcessBeforeInfershape() { + NamesToPass names_to_passes; + CondRemovePass condition_remove_pass; + names_to_passes.emplace_back("CondRemovePass", &condition_remove_pass); + GE_TIMESTAMP_START(ProcessCondRemove); + auto ret = GEPass(compute_graph_).Run(names_to_passes); + GE_TIMESTAMP_END(ProcessCondRemove, "GraphManager::ProcessCondRemove"); + if (ret != SUCCESS) { + GELOGE(ret, "Run ge_passes optimize for OptimizeAfterMergeSubGraph failed, ret:%d.", ret); + return ret; + } + return SUCCESS; +} + Status GraphPrepare::ProcessNetOutput() { PassManager graph_passes_before_infershape; try { if (options_.train_graph_flag) { - graph_passes_before_infershape.AddPass(new (std::nothrow) SavePass); + graph_passes_before_infershape.AddPass("ProcessNetOutput::SavePass", new (std::nothrow) SavePass); } - graph_passes_before_infershape.AddPass(new (std::nothrow) NetOutputPass); + graph_passes_before_infershape.AddPass("ProcessNetOutput::NetOutputPass", new (std::nothrow) NetOutputPass); + graph_passes_before_infershape.AddPass("ProcessNetOutput::DataPass", + new (std::nothrow) DataPass); // Add NetOutput first. } catch (std::bad_alloc) { GELOGE(INTERNAL_ERROR, "Add pass failed, bad memory allocation occurs."); return INTERNAL_ERROR; @@ -2354,7 +2546,8 @@ Status GraphPrepare::ProcessNetOutput() { Status GraphPrepare::NewOptimizeGraphBeforeSubGraph(VarAccelerateCtrl &var_acc_ctrl) { GELOGD("NewOptimizeGraphBeforeSubGraph in"); PassManager passes; - (void)passes.AddPass(new (std::nothrow) CommonSubexpressionEliminationPass); + (void)passes.AddPass("NewOptimizeGraphBeforeSubGraph::CommonSubexpressionEliminationPass", + new (std::nothrow) CommonSubexpressionEliminationPass); auto ret = passes.Run(compute_graph_); if (ret != SUCCESS) { GELOGE(ret, "Failed to optimize for graph"); @@ -2382,16 +2575,24 @@ Status GraphPrepare::NewOptimizeGraphBeforeSubGraph(VarAccelerateCtrl &var_acc_c GELOGI("get ge.exec.variable_acc failed. 
set default value."); } PassManager pass_manager; - GE_CHK_STATUS_RET(pass_manager.AddPass(new (std::nothrow) PermutePass)) - GE_CHK_STATUS_RET(pass_manager.AddPass(new (std::nothrow) VariablePrepareOpPass)) + GE_CHK_STATUS_RET(pass_manager.AddPass("NewOptimizeGraphBeforeSubGraph::PermutePass", new (std::nothrow) PermutePass)) + GE_CHK_STATUS_RET(pass_manager.AddPass("NewOptimizeGraphBeforeSubGraph::VariablePrepareOpPass", + new (std::nothrow) VariablePrepareOpPass)) GE_IF_BOOL_EXEC(options == "default" || options == "1", GELOGI("turn on variable accelerator"); - GE_CHK_STATUS_RET(pass_manager.AddPass(new (std::nothrow) VariableOpPass(&var_acc_ctrl)))) - GE_CHK_STATUS_RET(pass_manager.AddPass(new (std::nothrow) TransOpDepthFusionPass)) - GE_CHK_STATUS_RET(pass_manager.AddPass(new (std::nothrow) TransOpBreadthFusionPass)) - GE_CHK_STATUS_RET(pass_manager.AddPass(new (std::nothrow) VariableRefDeleteOpPass)) - GE_CHK_STATUS_RET(pass_manager.AddPass(new (std::nothrow) SameTransdataBreadthFusionPass)) - GE_CHK_STATUS_RET(pass_manager.AddPass(new (std::nothrow) TransOpWithoutReshapeFusionPass)) - GE_CHK_STATUS_RET(pass_manager.AddPass(new (std::nothrow) LinkGenMaskNodesPass(options_.stream_max_parallel_num))) + GE_CHK_STATUS_RET(pass_manager.AddPass("NewOptimizeGraphBeforeSubGraph::VariableOpPass", + new (std::nothrow) VariableOpPass(&var_acc_ctrl)))) + GE_CHK_STATUS_RET(pass_manager.AddPass("NewOptimizeGraphBeforeSubGraph::TransOpWithoutReshapeFusionPass", + new (std::nothrow) TransOpWithoutReshapeFusionPass)) + GE_CHK_STATUS_RET(pass_manager.AddPass("NewOptimizeGraphBeforeSubGraph::TransOpDepthFusionPass", + new (std::nothrow) TransOpDepthFusionPass)) + GE_CHK_STATUS_RET(pass_manager.AddPass("NewOptimizeGraphBeforeSubGraph::TransOpBreadthFusionPass", + new (std::nothrow) TransOpBreadthFusionPass)) + GE_CHK_STATUS_RET(pass_manager.AddPass("NewOptimizeGraphBeforeSubGraph::VariableRefDeleteOpPass", + new (std::nothrow) VariableRefDeleteOpPass)) + GE_CHK_STATUS_RET(pass_manager.AddPass("NewOptimizeGraphBeforeSubGraph::SameTransdataBreadthFusionPass", + new (std::nothrow) SameTransdataBreadthFusionPass)) + GE_CHK_STATUS_RET(pass_manager.AddPass("NewOptimizeGraphBeforeSubGraph::LinkGenMaskNodesPass", + new (std::nothrow) LinkGenMaskNodesPass(options_.stream_max_parallel_num))) GE_TIMESTAMP_START(pass_manager); ret = pass_manager.Run(compute_graph_); @@ -2434,8 +2635,9 @@ Status GraphPrepare::NewOptimizeGraphBeforeSubGraph(VarAccelerateCtrl &var_acc_c Status GraphPrepare::OptimizeGraphBeforeSubGraph() { PassManager passes; - (void)passes.AddPass(new (std::nothrow) VariablePrepareOpPass); - (void)passes.AddPass(new (std::nothrow) CommonSubexpressionEliminationPass); + (void)passes.AddPass("OptimizeGraphBeforeSubGraph::VariablePrepareOpPass", new (std::nothrow) VariablePrepareOpPass); + (void)passes.AddPass("OptimizeGraphBeforeSubGraph::CommonSubexpressionEliminationPass", + new (std::nothrow) CommonSubexpressionEliminationPass); auto ret = passes.Run(compute_graph_); if (ret != SUCCESS) { GELOGE(ret, "Failed to optimize for graph"); @@ -2494,7 +2696,7 @@ Status GraphPrepare::UpdateInputOutputByOptions() { GELOGE(INTERNAL_ERROR, "Set node [%s] format ND failed", node_ptr->GetName().c_str()); return FAILED; } - // todo do not insert trans op + if (node_ptr->GetType() == DATA) { if (ProcessDataNodeDynShape(node_ptr) != SUCCESS) { GELOGE(INTERNAL_ERROR, "Process data node failed"); @@ -2519,7 +2721,7 @@ bool GraphPrepare::IsBroadCastOpData(const ge::NodePtr &var_node) { 
GE_RT_FALSE_CHECK_NOTNULL(in_anchor); ge::NodePtr dst_node = in_anchor->GetOwnerNode(); GE_RT_FALSE_CHECK_NOTNULL(dst_node); - if (dst_node->GetType() == HCOMBROADCAST) { + if (dst_node->GetType() == HCOMBROADCAST || dst_node->GetType() == HVDCALLBACKBROADCAST) { return true; } } @@ -2586,5 +2788,4 @@ void GraphPrepare::AdjustAssignOpData(const ge::NodePtr &var_node) { GELOGW("SetStr var_is_restore failed"); } } - } // namespace ge diff --git a/src/ge/graph/preprocess/graph_preprocess.h b/src/ge/graph/preprocess/graph_preprocess.h index 767ef96e..3c8646f7 100644 --- a/src/ge/graph/preprocess/graph_preprocess.h +++ b/src/ge/graph/preprocess/graph_preprocess.h @@ -49,8 +49,10 @@ class GraphPrepare { VarAccelerateCtrl &var_acc_ctrl, uint64_t session_id = 0); Status PrepareDynShape(ConstGraphPtr graph, const std::vector &user_input, ge::ComputeGraphPtr &compute_graph, uint64_t session_id = 0); + Status PrepareRunningFormatRefiner(); void SetOptions(const GraphManagerOptions &options); Status GenerateInfershapeGraph(ConstGraphPtr graph); + Status SwitchOpOptimize(ComputeGraphPtr &compute_graph); private: Status Init(const ge::Graph &graph, uint64_t session_id = 0); @@ -81,6 +83,7 @@ class GraphPrepare { Status NewOptimizeGraphBeforeSubGraph(VarAccelerateCtrl &var_acc_ctrl); Status SaveOriginalGraphToOmModel(); Status ProcessNetOutput(); + Status ProcessBeforeInfershape(); Status UpdateInputOutputByOptions(); bool IsBroadCastOpData(const ge::NodePtr &var_node); @@ -95,7 +98,7 @@ class GraphPrepare { bool ConfirmUseOpAndIndexByNode(const ge::NodePtr &var_node, const map> &confirm_ops, ge::NodePtr &use_node); - + Status GraphEquivalentTransformation(); ge::ComputeGraphPtr compute_graph_; GraphManagerOptions options_; }; diff --git a/src/ge/graph/preprocess/insert_op/ge_aipp_op.cc b/src/ge/graph/preprocess/insert_op/ge_aipp_op.cc index 277d711a..9ce87d38 100644 --- a/src/ge/graph/preprocess/insert_op/ge_aipp_op.cc +++ b/src/ge/graph/preprocess/insert_op/ge_aipp_op.cc @@ -329,7 +329,7 @@ Status AippOp::GetAndCheckTarget(const ComputeGraphPtr &graph, int rank, NodePtr auto switchn = graph->FindNode(related_node_name); if (switchn == nullptr) { GELOGE(INTERNAL_ERROR, "The data node %s has switchn node %s, but can not find it on the graph", - data_node->GetName().c_str(), switchn->GetName().c_str()); + data_node->GetName().c_str(), related_node_name.c_str()); return INTERNAL_ERROR; } target = switchn; diff --git a/src/ge/graph/passes/folding_kernel/add_kernel.cc b/src/ge/host_kernels/add_kernel.cc similarity index 99% rename from src/ge/graph/passes/folding_kernel/add_kernel.cc rename to src/ge/host_kernels/add_kernel.cc index 89f99938..6d6a049c 100644 --- a/src/ge/graph/passes/folding_kernel/add_kernel.cc +++ b/src/ge/host_kernels/add_kernel.cc @@ -14,7 +14,7 @@ * limitations under the License. 
*/ -#include "graph/passes/folding_kernel/add_kernel.h" +#include "host_kernels/add_kernel.h" #include diff --git a/src/ge/graph/passes/folding_kernel/add_kernel.h b/src/ge/host_kernels/add_kernel.h similarity index 100% rename from src/ge/graph/passes/folding_kernel/add_kernel.h rename to src/ge/host_kernels/add_kernel.h diff --git a/src/ge/graph/passes/folding_kernel/broadcast_args_kernel.cc b/src/ge/host_kernels/broadcast_args_kernel.cc similarity index 97% rename from src/ge/graph/passes/folding_kernel/broadcast_args_kernel.cc rename to src/ge/host_kernels/broadcast_args_kernel.cc index 364fb415..da2d00ab 100644 --- a/src/ge/graph/passes/folding_kernel/broadcast_args_kernel.cc +++ b/src/ge/host_kernels/broadcast_args_kernel.cc @@ -14,7 +14,7 @@ * limitations under the License. */ -#include "graph/passes/folding_kernel/broadcast_args_kernel.h" +#include "host_kernels/broadcast_args_kernel.h" #include diff --git a/src/ge/graph/passes/folding_kernel/broadcast_args_kernel.h b/src/ge/host_kernels/broadcast_args_kernel.h similarity index 100% rename from src/ge/graph/passes/folding_kernel/broadcast_args_kernel.h rename to src/ge/host_kernels/broadcast_args_kernel.h diff --git a/src/ge/graph/passes/folding_kernel/broadcast_gradient_args_kernel.cc b/src/ge/host_kernels/broadcast_gradient_args_kernel.cc similarity index 97% rename from src/ge/graph/passes/folding_kernel/broadcast_gradient_args_kernel.cc rename to src/ge/host_kernels/broadcast_gradient_args_kernel.cc index 0053a9df..ed790dab 100644 --- a/src/ge/graph/passes/folding_kernel/broadcast_gradient_args_kernel.cc +++ b/src/ge/host_kernels/broadcast_gradient_args_kernel.cc @@ -14,7 +14,7 @@ * limitations under the License. */ -#include "graph/passes/folding_kernel/broadcast_gradient_args_kernel.h" +#include "host_kernels/broadcast_gradient_args_kernel.h" #include diff --git a/src/ge/graph/passes/folding_kernel/broadcast_gradient_args_kernel.h b/src/ge/host_kernels/broadcast_gradient_args_kernel.h similarity index 100% rename from src/ge/graph/passes/folding_kernel/broadcast_gradient_args_kernel.h rename to src/ge/host_kernels/broadcast_gradient_args_kernel.h diff --git a/src/ge/graph/passes/folding_kernel/cast_kernel.cc b/src/ge/host_kernels/cast_kernel.cc similarity index 98% rename from src/ge/graph/passes/folding_kernel/cast_kernel.cc rename to src/ge/host_kernels/cast_kernel.cc index 99944c20..106aa1c2 100644 --- a/src/ge/graph/passes/folding_kernel/cast_kernel.cc +++ b/src/ge/host_kernels/cast_kernel.cc @@ -14,7 +14,7 @@ * limitations under the License. 
*/ -#include "graph/passes/folding_kernel/cast_kernel.h" +#include "host_kernels/cast_kernel.h" #include #include @@ -29,7 +29,7 @@ #include "framework/common/debug/ge_log.h" #include "framework/common/ge_inner_error_codes.h" #include "graph/common/bcast.h" -#include "graph/passes/folding_kernel/kernel_utils.h" +#include "host_kernels/kernel_utils.h" #include "graph/utils/type_utils.h" #include "inc/kernel_factory.h" diff --git a/src/ge/graph/passes/folding_kernel/cast_kernel.h b/src/ge/host_kernels/cast_kernel.h similarity index 100% rename from src/ge/graph/passes/folding_kernel/cast_kernel.h rename to src/ge/host_kernels/cast_kernel.h diff --git a/src/ge/graph/passes/folding_kernel/concat_offset_kernel.cc b/src/ge/host_kernels/concat_offset_kernel.cc similarity index 98% rename from src/ge/graph/passes/folding_kernel/concat_offset_kernel.cc rename to src/ge/host_kernels/concat_offset_kernel.cc index f5146b5b..2e609d68 100644 --- a/src/ge/graph/passes/folding_kernel/concat_offset_kernel.cc +++ b/src/ge/host_kernels/concat_offset_kernel.cc @@ -14,7 +14,7 @@ * limitations under the License. */ -#include "graph/passes/folding_kernel/concat_offset_kernel.h" +#include "host_kernels/concat_offset_kernel.h" #include diff --git a/src/ge/graph/passes/folding_kernel/concat_offset_kernel.h b/src/ge/host_kernels/concat_offset_kernel.h similarity index 100% rename from src/ge/graph/passes/folding_kernel/concat_offset_kernel.h rename to src/ge/host_kernels/concat_offset_kernel.h diff --git a/src/ge/graph/passes/folding_kernel/concat_v2_kernel.cc b/src/ge/host_kernels/concat_v2_kernel.cc similarity index 98% rename from src/ge/graph/passes/folding_kernel/concat_v2_kernel.cc rename to src/ge/host_kernels/concat_v2_kernel.cc index 9d06b612..81127302 100644 --- a/src/ge/graph/passes/folding_kernel/concat_v2_kernel.cc +++ b/src/ge/host_kernels/concat_v2_kernel.cc @@ -14,7 +14,7 @@ * limitations under the License. */ -#include "graph/passes/folding_kernel/concat_v2_kernel.h" +#include "host_kernels/concat_v2_kernel.h" #include #include @@ -24,7 +24,7 @@ #include "common/ge_inner_error_codes.h" #include "common/op/ge_op_utils.h" #include "framework/common/debug/ge_log.h" -#include "graph/passes/folding_kernel/kernel_utils.h" +#include "host_kernels/kernel_utils.h" #include "graph/utils/type_utils.h" #include "inc/kernel_factory.h" diff --git a/src/ge/graph/passes/folding_kernel/concat_v2_kernel.h b/src/ge/host_kernels/concat_v2_kernel.h similarity index 100% rename from src/ge/graph/passes/folding_kernel/concat_v2_kernel.h rename to src/ge/host_kernels/concat_v2_kernel.h diff --git a/src/ge/graph/passes/folding_kernel/dynamic_stitch_kernel.cc b/src/ge/host_kernels/dynamic_stitch_kernel.cc similarity index 99% rename from src/ge/graph/passes/folding_kernel/dynamic_stitch_kernel.cc rename to src/ge/host_kernels/dynamic_stitch_kernel.cc index 94576ed1..c8a19e44 100644 --- a/src/ge/graph/passes/folding_kernel/dynamic_stitch_kernel.cc +++ b/src/ge/host_kernels/dynamic_stitch_kernel.cc @@ -14,7 +14,7 @@ * limitations under the License. 
*/ -#include "graph/passes/folding_kernel/dynamic_stitch_kernel.h" +#include "host_kernels/dynamic_stitch_kernel.h" #include #include diff --git a/src/ge/graph/passes/folding_kernel/dynamic_stitch_kernel.h b/src/ge/host_kernels/dynamic_stitch_kernel.h similarity index 100% rename from src/ge/graph/passes/folding_kernel/dynamic_stitch_kernel.h rename to src/ge/host_kernels/dynamic_stitch_kernel.h diff --git a/src/ge/graph/passes/folding_kernel/empty_kernel.cc b/src/ge/host_kernels/empty_kernel.cc similarity index 97% rename from src/ge/graph/passes/folding_kernel/empty_kernel.cc rename to src/ge/host_kernels/empty_kernel.cc index 1b135b9c..856caf50 100644 --- a/src/ge/graph/passes/folding_kernel/empty_kernel.cc +++ b/src/ge/host_kernels/empty_kernel.cc @@ -14,7 +14,7 @@ * limitations under the License. */ -#include "graph/passes/folding_kernel/empty_kernel.h" +#include "host_kernels/empty_kernel.h" #include @@ -23,7 +23,7 @@ #include "common/types.h" #include "framework/common/debug/ge_log.h" #include "framework/common/ge_inner_error_codes.h" -#include "graph/passes/folding_kernel/kernel_utils.h" +#include "host_kernels/kernel_utils.h" #include "graph/passes/pass_utils.h" #include "graph/utils/type_utils.h" #include "inc/kernel_factory.h" diff --git a/src/ge/graph/passes/folding_kernel/empty_kernel.h b/src/ge/host_kernels/empty_kernel.h similarity index 100% rename from src/ge/graph/passes/folding_kernel/empty_kernel.h rename to src/ge/host_kernels/empty_kernel.h diff --git a/src/ge/graph/passes/folding_kernel/expanddims_kernel.cc b/src/ge/host_kernels/expanddims_kernel.cc similarity index 96% rename from src/ge/graph/passes/folding_kernel/expanddims_kernel.cc rename to src/ge/host_kernels/expanddims_kernel.cc index f4091d2d..1d17ad48 100644 --- a/src/ge/graph/passes/folding_kernel/expanddims_kernel.cc +++ b/src/ge/host_kernels/expanddims_kernel.cc @@ -14,7 +14,7 @@ * limitations under the License. */ -#include "graph/passes/folding_kernel/expanddims_kernel.h" +#include "host_kernels/expanddims_kernel.h" #include @@ -22,7 +22,7 @@ #include "common/op/ge_op_utils.h" #include "common/types.h" #include "framework/common/debug/ge_log.h" -#include "graph/passes/folding_kernel/kernel_utils.h" +#include "host_kernels/kernel_utils.h" #include "inc/kernel_factory.h" namespace ge { diff --git a/src/ge/graph/passes/folding_kernel/expanddims_kernel.h b/src/ge/host_kernels/expanddims_kernel.h similarity index 100% rename from src/ge/graph/passes/folding_kernel/expanddims_kernel.h rename to src/ge/host_kernels/expanddims_kernel.h diff --git a/src/ge/graph/passes/folding_kernel/fill_kernel.cc b/src/ge/host_kernels/fill_kernel.cc similarity index 97% rename from src/ge/graph/passes/folding_kernel/fill_kernel.cc rename to src/ge/host_kernels/fill_kernel.cc index 3a3aa597..27bcb9aa 100644 --- a/src/ge/graph/passes/folding_kernel/fill_kernel.cc +++ b/src/ge/host_kernels/fill_kernel.cc @@ -14,7 +14,7 @@ * limitations under the License. 
*/ -#include "graph/passes/folding_kernel/fill_kernel.h" +#include "host_kernels/fill_kernel.h" #include #include @@ -23,7 +23,7 @@ #include "common/ge_inner_error_codes.h" #include "common/op/ge_op_utils.h" #include "framework/common/debug/ge_log.h" -#include "graph/passes/folding_kernel/kernel_utils.h" +#include "host_kernels/kernel_utils.h" #include "graph/utils/type_utils.h" #include "inc/kernel_factory.h" diff --git a/src/ge/graph/passes/folding_kernel/fill_kernel.h b/src/ge/host_kernels/fill_kernel.h similarity index 100% rename from src/ge/graph/passes/folding_kernel/fill_kernel.h rename to src/ge/host_kernels/fill_kernel.h diff --git a/src/ge/graph/passes/folding_kernel/floordiv_kernel.cc b/src/ge/host_kernels/floordiv_kernel.cc similarity index 98% rename from src/ge/graph/passes/folding_kernel/floordiv_kernel.cc rename to src/ge/host_kernels/floordiv_kernel.cc index 81595822..4175df92 100644 --- a/src/ge/graph/passes/folding_kernel/floordiv_kernel.cc +++ b/src/ge/host_kernels/floordiv_kernel.cc @@ -14,7 +14,7 @@ * limitations under the License. */ -#include "graph/passes/folding_kernel/floordiv_kernel.h" +#include "host_kernels/floordiv_kernel.h" #include @@ -24,7 +24,7 @@ #include "common/op/ge_op_utils.h" #include "common/types.h" #include "framework/common/debug/ge_log.h" -#include "graph/passes/folding_kernel/kernel_utils.h" +#include "host_kernels/kernel_utils.h" #include "graph/utils/type_utils.h" #include "inc/kernel_factory.h" diff --git a/src/ge/graph/passes/folding_kernel/floordiv_kernel.h b/src/ge/host_kernels/floordiv_kernel.h similarity index 100% rename from src/ge/graph/passes/folding_kernel/floordiv_kernel.h rename to src/ge/host_kernels/floordiv_kernel.h diff --git a/src/ge/graph/passes/folding_kernel/floormod_kernel.cc b/src/ge/host_kernels/floormod_kernel.cc similarity index 99% rename from src/ge/graph/passes/folding_kernel/floormod_kernel.cc rename to src/ge/host_kernels/floormod_kernel.cc index d7fb3b1c..a8c16c9d 100644 --- a/src/ge/graph/passes/folding_kernel/floormod_kernel.cc +++ b/src/ge/host_kernels/floormod_kernel.cc @@ -14,7 +14,7 @@ * limitations under the License. */ -#include "graph/passes/folding_kernel/floormod_kernel.h" +#include "host_kernels/floormod_kernel.h" #include #include diff --git a/src/ge/graph/passes/folding_kernel/floormod_kernel.h b/src/ge/host_kernels/floormod_kernel.h similarity index 100% rename from src/ge/graph/passes/folding_kernel/floormod_kernel.h rename to src/ge/host_kernels/floormod_kernel.h diff --git a/src/ge/graph/passes/folding_kernel/gather_v2_kernel.cc b/src/ge/host_kernels/gather_v2_kernel.cc similarity index 99% rename from src/ge/graph/passes/folding_kernel/gather_v2_kernel.cc rename to src/ge/host_kernels/gather_v2_kernel.cc index 732e0b53..c8cc3006 100644 --- a/src/ge/graph/passes/folding_kernel/gather_v2_kernel.cc +++ b/src/ge/host_kernels/gather_v2_kernel.cc @@ -14,7 +14,7 @@ * limitations under the License. 
*/ -#include "graph/passes/folding_kernel/gather_v2_kernel.h" +#include "host_kernels/gather_v2_kernel.h" #include #include @@ -25,7 +25,7 @@ #include "common/types.h" #include "common/util.h" #include "framework/common/debug/ge_log.h" -#include "graph/passes/folding_kernel/kernel_utils.h" +#include "host_kernels/kernel_utils.h" #include "graph/utils/type_utils.h" #include "inc/kernel_factory.h" diff --git a/src/ge/graph/passes/folding_kernel/gather_v2_kernel.h b/src/ge/host_kernels/gather_v2_kernel.h similarity index 100% rename from src/ge/graph/passes/folding_kernel/gather_v2_kernel.h rename to src/ge/host_kernels/gather_v2_kernel.h diff --git a/src/ge/graph/passes/folding_kernel/greater_kernel.cc b/src/ge/host_kernels/greater_kernel.cc similarity index 98% rename from src/ge/graph/passes/folding_kernel/greater_kernel.cc rename to src/ge/host_kernels/greater_kernel.cc index 816d3d05..f23eee2f 100644 --- a/src/ge/graph/passes/folding_kernel/greater_kernel.cc +++ b/src/ge/host_kernels/greater_kernel.cc @@ -14,7 +14,7 @@ * limitations under the License. */ -#include "graph/passes/folding_kernel/greater_kernel.h" +#include "host_kernels/greater_kernel.h" #include #include diff --git a/src/ge/graph/passes/folding_kernel/greater_kernel.h b/src/ge/host_kernels/greater_kernel.h similarity index 91% rename from src/ge/graph/passes/folding_kernel/greater_kernel.h rename to src/ge/host_kernels/greater_kernel.h index 84b5bc87..3697a8e8 100644 --- a/src/ge/graph/passes/folding_kernel/greater_kernel.h +++ b/src/ge/host_kernels/greater_kernel.h @@ -37,8 +37,8 @@ class GreaterKernel : public Kernel { Status GreaterCheck(const std::vector &input); const std::set greater_supported_type = { - DT_FLOAT, DT_FLOAT16, DT_INT8, DT_INT16, DT_UINT16, DT_UINT8, - DT_INT32, DT_INT64, DT_UINT32, DT_UINT64, DT_BOOL, DT_DOUBLE, + DT_FLOAT, DT_FLOAT16, DT_INT8, DT_INT16, DT_UINT16, DT_UINT8, + DT_INT32, DT_INT64, DT_UINT32, DT_UINT64, DT_BOOL, DT_DOUBLE, }; }; } // namespace ge diff --git a/src/ge/graph/passes/folding_kernel/kernel_utils.cc b/src/ge/host_kernels/kernel_utils.cc similarity index 98% rename from src/ge/graph/passes/folding_kernel/kernel_utils.cc rename to src/ge/host_kernels/kernel_utils.cc index 2002643a..9bcd9e3c 100644 --- a/src/ge/graph/passes/folding_kernel/kernel_utils.cc +++ b/src/ge/host_kernels/kernel_utils.cc @@ -14,7 +14,7 @@ * limitations under the License. */ -#include "graph/passes/folding_kernel/kernel_utils.h" +#include "host_kernels/kernel_utils.h" #include diff --git a/src/ge/graph/passes/folding_kernel/kernel_utils.h b/src/ge/host_kernels/kernel_utils.h similarity index 100% rename from src/ge/graph/passes/folding_kernel/kernel_utils.h rename to src/ge/host_kernels/kernel_utils.h diff --git a/src/ge/graph/passes/folding_kernel/maximum_kernel.cc b/src/ge/host_kernels/maximum_kernel.cc similarity index 99% rename from src/ge/graph/passes/folding_kernel/maximum_kernel.cc rename to src/ge/host_kernels/maximum_kernel.cc index 5f83f0d5..aca4ec2b 100644 --- a/src/ge/graph/passes/folding_kernel/maximum_kernel.cc +++ b/src/ge/host_kernels/maximum_kernel.cc @@ -14,7 +14,7 @@ * limitations under the License. 
*/ -#include "graph/passes/folding_kernel/maximum_kernel.h" +#include "host_kernels/maximum_kernel.h" #include #include diff --git a/src/ge/graph/passes/folding_kernel/maximum_kernel.h b/src/ge/host_kernels/maximum_kernel.h similarity index 100% rename from src/ge/graph/passes/folding_kernel/maximum_kernel.h rename to src/ge/host_kernels/maximum_kernel.h diff --git a/src/ge/graph/passes/folding_kernel/mul_kernel.cc b/src/ge/host_kernels/mul_kernel.cc similarity index 99% rename from src/ge/graph/passes/folding_kernel/mul_kernel.cc rename to src/ge/host_kernels/mul_kernel.cc index 4ca740d1..8dbe83a5 100644 --- a/src/ge/graph/passes/folding_kernel/mul_kernel.cc +++ b/src/ge/host_kernels/mul_kernel.cc @@ -14,7 +14,7 @@ * limitations under the License. */ -#include "graph/passes/folding_kernel/mul_kernel.h" +#include "host_kernels/mul_kernel.h" #include #include diff --git a/src/ge/graph/passes/folding_kernel/mul_kernel.h b/src/ge/host_kernels/mul_kernel.h similarity index 100% rename from src/ge/graph/passes/folding_kernel/mul_kernel.h rename to src/ge/host_kernels/mul_kernel.h diff --git a/src/ge/graph/passes/folding_kernel/pack_kernel.cc b/src/ge/host_kernels/pack_kernel.cc similarity index 98% rename from src/ge/graph/passes/folding_kernel/pack_kernel.cc rename to src/ge/host_kernels/pack_kernel.cc index 5db3b394..f3f64a6c 100644 --- a/src/ge/graph/passes/folding_kernel/pack_kernel.cc +++ b/src/ge/host_kernels/pack_kernel.cc @@ -14,7 +14,7 @@ * limitations under the License. */ -#include "graph/passes/folding_kernel/pack_kernel.h" +#include "host_kernels/pack_kernel.h" #include #include @@ -25,7 +25,7 @@ #include "common/op/ge_op_utils.h" #include "framework/common/debug/ge_log.h" #include "graph/debug/ge_attr_define.h" -#include "graph/passes/folding_kernel/kernel_utils.h" +#include "host_kernels/kernel_utils.h" #include "graph/utils/type_utils.h" #include "inc/kernel_factory.h" diff --git a/src/ge/graph/passes/folding_kernel/pack_kernel.h b/src/ge/host_kernels/pack_kernel.h similarity index 99% rename from src/ge/graph/passes/folding_kernel/pack_kernel.h rename to src/ge/host_kernels/pack_kernel.h index b32e3fae..708e46c3 100644 --- a/src/ge/graph/passes/folding_kernel/pack_kernel.h +++ b/src/ge/host_kernels/pack_kernel.h @@ -31,6 +31,7 @@ class PackKernel : public Kernel { public: Status Compute(const ge::OpDescPtr op_desc_ptr, const std::vector &input, std::vector &v_output) override; + private: Status ValidateKernelParams(const ge::OpDescPtr &op_desc_ptr, const std::vector &input); Status ValidateInputs(const ge::OpDescPtr &op_desc_ptr, const std::vector &input); diff --git a/src/ge/graph/passes/folding_kernel/permute_kernel.cc b/src/ge/host_kernels/permute_kernel.cc similarity index 98% rename from src/ge/graph/passes/folding_kernel/permute_kernel.cc rename to src/ge/host_kernels/permute_kernel.cc index 4f0225ac..8263d19f 100644 --- a/src/ge/graph/passes/folding_kernel/permute_kernel.cc +++ b/src/ge/host_kernels/permute_kernel.cc @@ -14,7 +14,7 @@ * limitations under the License. 
*/ -#include "graph/passes/folding_kernel/permute_kernel.h" +#include "host_kernels/permute_kernel.h" #include #include @@ -30,7 +30,7 @@ #include "common/formats/formats.h" #include "common/formats/format_transfers/format_transfer_transpose.h" #include "common/formats/utils/formats_trans_utils.h" -#include "graph/passes/folding_kernel/kernel_utils.h" +#include "host_kernels/kernel_utils.h" #include "framework/common/ge_inner_error_codes.h" namespace ge { diff --git a/src/ge/graph/passes/folding_kernel/permute_kernel.h b/src/ge/host_kernels/permute_kernel.h similarity index 100% rename from src/ge/graph/passes/folding_kernel/permute_kernel.h rename to src/ge/host_kernels/permute_kernel.h diff --git a/src/ge/graph/passes/folding_kernel/range_kernel.cc b/src/ge/host_kernels/range_kernel.cc similarity index 99% rename from src/ge/graph/passes/folding_kernel/range_kernel.cc rename to src/ge/host_kernels/range_kernel.cc index 8bcfa254..4ce3725d 100644 --- a/src/ge/graph/passes/folding_kernel/range_kernel.cc +++ b/src/ge/host_kernels/range_kernel.cc @@ -14,7 +14,7 @@ * limitations under the License. */ -#include "graph/passes/folding_kernel/range_kernel.h" +#include "host_kernels/range_kernel.h" #include #include diff --git a/src/ge/graph/passes/folding_kernel/range_kernel.h b/src/ge/host_kernels/range_kernel.h similarity index 100% rename from src/ge/graph/passes/folding_kernel/range_kernel.h rename to src/ge/host_kernels/range_kernel.h diff --git a/src/ge/graph/passes/folding_kernel/rank_kernel.cc b/src/ge/host_kernels/rank_kernel.cc similarity index 97% rename from src/ge/graph/passes/folding_kernel/rank_kernel.cc rename to src/ge/host_kernels/rank_kernel.cc index 43df2ff7..faaf16b8 100644 --- a/src/ge/graph/passes/folding_kernel/rank_kernel.cc +++ b/src/ge/host_kernels/rank_kernel.cc @@ -14,7 +14,7 @@ * limitations under the License. */ -#include "graph/passes/folding_kernel/rank_kernel.h" +#include "host_kernels/rank_kernel.h" #include #include diff --git a/src/ge/graph/passes/folding_kernel/rank_kernel.h b/src/ge/host_kernels/rank_kernel.h similarity index 100% rename from src/ge/graph/passes/folding_kernel/rank_kernel.h rename to src/ge/host_kernels/rank_kernel.h diff --git a/src/ge/graph/passes/folding_kernel/reduce_prod_kernel.cc b/src/ge/host_kernels/reduce_prod_kernel.cc similarity index 99% rename from src/ge/graph/passes/folding_kernel/reduce_prod_kernel.cc rename to src/ge/host_kernels/reduce_prod_kernel.cc index e8ea093f..479b50ab 100644 --- a/src/ge/graph/passes/folding_kernel/reduce_prod_kernel.cc +++ b/src/ge/host_kernels/reduce_prod_kernel.cc @@ -14,7 +14,7 @@ * limitations under the License. 
*/ -#include "graph/passes/folding_kernel/reduce_prod_kernel.h" +#include "host_kernels/reduce_prod_kernel.h" #include #include @@ -24,7 +24,7 @@ #include "common/types.h" #include "framework/common/debug/ge_log.h" #include "framework/common/ge_inner_error_codes.h" -#include "graph/passes/folding_kernel/kernel_utils.h" +#include "host_kernels/kernel_utils.h" #include "graph/utils/type_utils.h" #include "inc/kernel_factory.h" diff --git a/src/ge/graph/passes/folding_kernel/reduce_prod_kernel.h b/src/ge/host_kernels/reduce_prod_kernel.h similarity index 100% rename from src/ge/graph/passes/folding_kernel/reduce_prod_kernel.h rename to src/ge/host_kernels/reduce_prod_kernel.h diff --git a/src/ge/graph/passes/folding_kernel/reformat_kernel.cc b/src/ge/host_kernels/reformat_kernel.cc similarity index 97% rename from src/ge/graph/passes/folding_kernel/reformat_kernel.cc rename to src/ge/host_kernels/reformat_kernel.cc index 8829d4c4..33a13599 100644 --- a/src/ge/graph/passes/folding_kernel/reformat_kernel.cc +++ b/src/ge/host_kernels/reformat_kernel.cc @@ -14,7 +14,7 @@ * limitations under the License. */ -#include "graph/passes/folding_kernel/reformat_kernel.h" +#include "host_kernels/reformat_kernel.h" #include "common/formats/utils/formats_trans_utils.h" #include "common/ge/ge_util.h" #include "common/ge_inner_error_codes.h" @@ -22,7 +22,7 @@ #include "common/types.h" #include "common/util.h" #include "framework/common/debug/ge_log.h" -#include "graph/passes/folding_kernel/kernel_utils.h" +#include "host_kernels/kernel_utils.h" #include "graph/utils/type_utils.h" #include "inc/kernel_factory.h" diff --git a/src/ge/graph/passes/folding_kernel/reformat_kernel.h b/src/ge/host_kernels/reformat_kernel.h similarity index 100% rename from src/ge/graph/passes/folding_kernel/reformat_kernel.h rename to src/ge/host_kernels/reformat_kernel.h diff --git a/src/ge/graph/passes/folding_kernel/reshape_kernel.cc b/src/ge/host_kernels/reshape_kernel.cc similarity index 96% rename from src/ge/graph/passes/folding_kernel/reshape_kernel.cc rename to src/ge/host_kernels/reshape_kernel.cc index 4e925836..906624d2 100644 --- a/src/ge/graph/passes/folding_kernel/reshape_kernel.cc +++ b/src/ge/host_kernels/reshape_kernel.cc @@ -14,13 +14,13 @@ * limitations under the License. */ -#include "graph/passes/folding_kernel/reshape_kernel.h" +#include "host_kernels/reshape_kernel.h" #include "common/ge_inner_error_codes.h" #include "common/op/ge_op_utils.h" #include "common/types.h" #include "framework/common/debug/ge_log.h" -#include "graph/passes/folding_kernel/kernel_utils.h" +#include "host_kernels/kernel_utils.h" #include "inc/kernel_factory.h" namespace ge { diff --git a/src/ge/graph/passes/folding_kernel/reshape_kernel.h b/src/ge/host_kernels/reshape_kernel.h similarity index 100% rename from src/ge/graph/passes/folding_kernel/reshape_kernel.h rename to src/ge/host_kernels/reshape_kernel.h diff --git a/src/ge/graph/passes/folding_kernel/rsqrt_kernel.cc b/src/ge/host_kernels/rsqrt_kernel.cc similarity index 96% rename from src/ge/graph/passes/folding_kernel/rsqrt_kernel.cc rename to src/ge/host_kernels/rsqrt_kernel.cc index 25e81713..3e14fd5f 100644 --- a/src/ge/graph/passes/folding_kernel/rsqrt_kernel.cc +++ b/src/ge/host_kernels/rsqrt_kernel.cc @@ -14,7 +14,7 @@ * limitations under the License. 
*/ -#include "graph/passes/folding_kernel/rsqrt_kernel.h" +#include "host_kernels/rsqrt_kernel.h" #include @@ -25,7 +25,7 @@ #include "common/ge_inner_error_codes.h" #include "common/op/ge_op_utils.h" #include "framework/common/debug/ge_log.h" -#include "graph/passes/folding_kernel/kernel_utils.h" +#include "host_kernels/kernel_utils.h" #include "inc/kernel_factory.h" namespace ge { diff --git a/src/ge/graph/passes/folding_kernel/rsqrt_kernel.h b/src/ge/host_kernels/rsqrt_kernel.h similarity index 100% rename from src/ge/graph/passes/folding_kernel/rsqrt_kernel.h rename to src/ge/host_kernels/rsqrt_kernel.h diff --git a/src/ge/graph/passes/folding_kernel/shape_kernel.cc b/src/ge/host_kernels/shape_kernel.cc similarity index 95% rename from src/ge/graph/passes/folding_kernel/shape_kernel.cc rename to src/ge/host_kernels/shape_kernel.cc index f7475b91..2f20fb24 100644 --- a/src/ge/graph/passes/folding_kernel/shape_kernel.cc +++ b/src/ge/host_kernels/shape_kernel.cc @@ -14,13 +14,13 @@ * limitations under the License. */ -#include "graph/passes/folding_kernel/shape_kernel.h" +#include "host_kernels/shape_kernel.h" #include #include "framework/common/debug/ge_log.h" #include "framework/common/ge_inner_error_codes.h" -#include "graph/passes/folding_kernel/kernel_utils.h" +#include "host_kernels/kernel_utils.h" #include "graph/passes/pass_utils.h" #include "inc/kernel_factory.h" diff --git a/src/ge/graph/passes/folding_kernel/shape_kernel.h b/src/ge/host_kernels/shape_kernel.h similarity index 100% rename from src/ge/graph/passes/folding_kernel/shape_kernel.h rename to src/ge/host_kernels/shape_kernel.h diff --git a/src/ge/graph/passes/folding_kernel/shape_n_kernel.cc b/src/ge/host_kernels/shape_n_kernel.cc similarity index 95% rename from src/ge/graph/passes/folding_kernel/shape_n_kernel.cc rename to src/ge/host_kernels/shape_n_kernel.cc index 8ed546de..33b878cf 100644 --- a/src/ge/graph/passes/folding_kernel/shape_n_kernel.cc +++ b/src/ge/host_kernels/shape_n_kernel.cc @@ -14,13 +14,13 @@ * limitations under the License. */ -#include "graph/passes/folding_kernel/shape_n_kernel.h" +#include "host_kernels/shape_n_kernel.h" #include #include "framework/common/debug/ge_log.h" #include "framework/common/ge_inner_error_codes.h" -#include "graph/passes/folding_kernel/kernel_utils.h" +#include "host_kernels/kernel_utils.h" #include "graph/passes/pass_utils.h" #include "inc/kernel_factory.h" diff --git a/src/ge/graph/passes/folding_kernel/shape_n_kernel.h b/src/ge/host_kernels/shape_n_kernel.h similarity index 100% rename from src/ge/graph/passes/folding_kernel/shape_n_kernel.h rename to src/ge/host_kernels/shape_n_kernel.h diff --git a/src/ge/graph/passes/folding_kernel/size_kernel.cc b/src/ge/host_kernels/size_kernel.cc similarity index 95% rename from src/ge/graph/passes/folding_kernel/size_kernel.cc rename to src/ge/host_kernels/size_kernel.cc index 3b121ba4..65bb21fc 100644 --- a/src/ge/graph/passes/folding_kernel/size_kernel.cc +++ b/src/ge/host_kernels/size_kernel.cc @@ -14,7 +14,7 @@ * limitations under the License. 
*/ -#include "graph/passes/folding_kernel/size_kernel.h" +#include "host_kernels/size_kernel.h" #include #include @@ -25,7 +25,7 @@ #include "framework/common/ge_inner_error_codes.h" #include "framework/common/types.h" #include "framework/common/util.h" -#include "graph/passes/folding_kernel/kernel_utils.h" +#include "host_kernels/kernel_utils.h" #include "graph/passes/pass_utils.h" #include "inc/kernel_factory.h" #include "omg/omg_inner_types.h" diff --git a/src/ge/graph/passes/folding_kernel/size_kernel.h b/src/ge/host_kernels/size_kernel.h similarity index 100% rename from src/ge/graph/passes/folding_kernel/size_kernel.h rename to src/ge/host_kernels/size_kernel.h diff --git a/src/ge/graph/passes/folding_kernel/slice_d_kernel.cc b/src/ge/host_kernels/slice_d_kernel.cc similarity index 98% rename from src/ge/graph/passes/folding_kernel/slice_d_kernel.cc rename to src/ge/host_kernels/slice_d_kernel.cc index 900828c2..ad0a1675 100644 --- a/src/ge/graph/passes/folding_kernel/slice_d_kernel.cc +++ b/src/ge/host_kernels/slice_d_kernel.cc @@ -14,7 +14,7 @@ * limitations under the License. */ -#include "graph/passes/folding_kernel/slice_d_kernel.h" +#include "host_kernels/slice_d_kernel.h" #include @@ -22,7 +22,7 @@ #include "common/op/ge_op_utils.h" #include "common/types.h" #include "framework/common/debug/ge_log.h" -#include "graph/passes/folding_kernel/kernel_utils.h" +#include "host_kernels/kernel_utils.h" #include "graph/utils/type_utils.h" #include "inc/kernel_factory.h" diff --git a/src/ge/graph/passes/folding_kernel/slice_d_kernel.h b/src/ge/host_kernels/slice_d_kernel.h similarity index 100% rename from src/ge/graph/passes/folding_kernel/slice_d_kernel.h rename to src/ge/host_kernels/slice_d_kernel.h diff --git a/src/ge/graph/passes/folding_kernel/slice_kernel.cc b/src/ge/host_kernels/slice_kernel.cc similarity index 97% rename from src/ge/graph/passes/folding_kernel/slice_kernel.cc rename to src/ge/host_kernels/slice_kernel.cc index a1250367..ac2d5cc3 100644 --- a/src/ge/graph/passes/folding_kernel/slice_kernel.cc +++ b/src/ge/host_kernels/slice_kernel.cc @@ -14,14 +14,14 @@ * limitations under the License. */ -#include "graph/passes/folding_kernel/slice_kernel.h" +#include "host_kernels/slice_kernel.h" #include "common/ge_inner_error_codes.h" #include "common/op/ge_op_utils.h" #include "common/types.h" #include "common/util.h" #include "framework/common/debug/ge_log.h" -#include "graph/passes/folding_kernel/kernel_utils.h" +#include "host_kernels/kernel_utils.h" #include "graph/utils/type_utils.h" #include "inc/kernel_factory.h" diff --git a/src/ge/graph/passes/folding_kernel/slice_kernel.h b/src/ge/host_kernels/slice_kernel.h similarity index 100% rename from src/ge/graph/passes/folding_kernel/slice_kernel.h rename to src/ge/host_kernels/slice_kernel.h diff --git a/src/ge/graph/passes/folding_kernel/squeeze_kernel.cc b/src/ge/host_kernels/squeeze_kernel.cc similarity index 96% rename from src/ge/graph/passes/folding_kernel/squeeze_kernel.cc rename to src/ge/host_kernels/squeeze_kernel.cc index b253f9a9..5108d9fa 100644 --- a/src/ge/graph/passes/folding_kernel/squeeze_kernel.cc +++ b/src/ge/host_kernels/squeeze_kernel.cc @@ -14,13 +14,13 @@ * limitations under the License. 
*/ -#include "graph/passes/folding_kernel/squeeze_kernel.h" +#include "host_kernels/squeeze_kernel.h" #include "common/ge_inner_error_codes.h" #include "common/op/ge_op_utils.h" #include "common/types.h" #include "framework/common/debug/ge_log.h" -#include "graph/passes/folding_kernel/kernel_utils.h" +#include "host_kernels/kernel_utils.h" #include "inc/kernel_factory.h" namespace { diff --git a/src/ge/graph/passes/folding_kernel/squeeze_kernel.h b/src/ge/host_kernels/squeeze_kernel.h similarity index 100% rename from src/ge/graph/passes/folding_kernel/squeeze_kernel.h rename to src/ge/host_kernels/squeeze_kernel.h diff --git a/src/ge/graph/passes/folding_kernel/ssd_prior_box_kernel.cc b/src/ge/host_kernels/ssd_prior_box_kernel.cc similarity index 99% rename from src/ge/graph/passes/folding_kernel/ssd_prior_box_kernel.cc rename to src/ge/host_kernels/ssd_prior_box_kernel.cc index 15985c5d..c874d732 100644 --- a/src/ge/graph/passes/folding_kernel/ssd_prior_box_kernel.cc +++ b/src/ge/host_kernels/ssd_prior_box_kernel.cc @@ -14,7 +14,7 @@ * limitations under the License. */ -#include "graph/passes/folding_kernel/ssd_prior_box_kernel.h" +#include "host_kernels/ssd_prior_box_kernel.h" #include #include diff --git a/src/ge/graph/passes/folding_kernel/ssd_prior_box_kernel.h b/src/ge/host_kernels/ssd_prior_box_kernel.h similarity index 100% rename from src/ge/graph/passes/folding_kernel/ssd_prior_box_kernel.h rename to src/ge/host_kernels/ssd_prior_box_kernel.h diff --git a/src/ge/graph/passes/folding_kernel/strided_slice_kernel.cc b/src/ge/host_kernels/strided_slice_kernel.cc similarity index 98% rename from src/ge/graph/passes/folding_kernel/strided_slice_kernel.cc rename to src/ge/host_kernels/strided_slice_kernel.cc index 3448a071..0d70a36a 100644 --- a/src/ge/graph/passes/folding_kernel/strided_slice_kernel.cc +++ b/src/ge/host_kernels/strided_slice_kernel.cc @@ -14,7 +14,7 @@ * limitations under the License. */ -#include "graph/passes/folding_kernel/strided_slice_kernel.h" +#include "host_kernels/strided_slice_kernel.h" #include @@ -23,7 +23,7 @@ #include "common/math/math_util.h" #include "common/op/ge_op_utils.h" #include "framework/common/debug/ge_log.h" -#include "graph/passes/folding_kernel/kernel_utils.h" +#include "host_kernels/kernel_utils.h" #include "graph/utils/type_utils.h" #include "inc/kernel_factory.h" diff --git a/src/ge/graph/passes/folding_kernel/strided_slice_kernel.h b/src/ge/host_kernels/strided_slice_kernel.h similarity index 100% rename from src/ge/graph/passes/folding_kernel/strided_slice_kernel.h rename to src/ge/host_kernels/strided_slice_kernel.h diff --git a/src/ge/graph/passes/folding_kernel/sub_kernel.cc b/src/ge/host_kernels/sub_kernel.cc similarity index 99% rename from src/ge/graph/passes/folding_kernel/sub_kernel.cc rename to src/ge/host_kernels/sub_kernel.cc index 5934c6c1..ed1e5808 100644 --- a/src/ge/graph/passes/folding_kernel/sub_kernel.cc +++ b/src/ge/host_kernels/sub_kernel.cc @@ -14,7 +14,7 @@ * limitations under the License. 
*/ -#include "graph/passes/folding_kernel/sub_kernel.h" +#include "host_kernels/sub_kernel.h" #include #include diff --git a/src/ge/graph/passes/folding_kernel/sub_kernel.h b/src/ge/host_kernels/sub_kernel.h similarity index 100% rename from src/ge/graph/passes/folding_kernel/sub_kernel.h rename to src/ge/host_kernels/sub_kernel.h diff --git a/src/ge/graph/passes/folding_kernel/transdata_kernel.cc b/src/ge/host_kernels/transdata_kernel.cc similarity index 98% rename from src/ge/graph/passes/folding_kernel/transdata_kernel.cc rename to src/ge/host_kernels/transdata_kernel.cc index d3637169..5fe44fe4 100644 --- a/src/ge/graph/passes/folding_kernel/transdata_kernel.cc +++ b/src/ge/host_kernels/transdata_kernel.cc @@ -14,7 +14,7 @@ * limitations under the License. */ -#include "graph/passes/folding_kernel/transdata_kernel.h" +#include "host_kernels/transdata_kernel.h" #include #include @@ -29,7 +29,7 @@ #include "framework/common/debug/ge_log.h" #include "framework/common/ge_inner_error_codes.h" #include "graph/common/bcast.h" -#include "graph/passes/folding_kernel/kernel_utils.h" +#include "host_kernels/kernel_utils.h" #include "graph/utils/type_utils.h" #include "inc/kernel_factory.h" diff --git a/src/ge/graph/passes/folding_kernel/transdata_kernel.h b/src/ge/host_kernels/transdata_kernel.h similarity index 100% rename from src/ge/graph/passes/folding_kernel/transdata_kernel.h rename to src/ge/host_kernels/transdata_kernel.h diff --git a/src/ge/graph/passes/folding_kernel/transpose_kernel.cc b/src/ge/host_kernels/transpose_kernel.cc similarity index 98% rename from src/ge/graph/passes/folding_kernel/transpose_kernel.cc rename to src/ge/host_kernels/transpose_kernel.cc index da5c71d9..574022e0 100644 --- a/src/ge/graph/passes/folding_kernel/transpose_kernel.cc +++ b/src/ge/host_kernels/transpose_kernel.cc @@ -14,7 +14,7 @@ * limitations under the License. */ -#include "graph/passes/folding_kernel/transpose_kernel.h" +#include "host_kernels/transpose_kernel.h" #include #include #include "common/debug/log.h" @@ -26,7 +26,7 @@ #include "common/util.h" #include "framework/common/debug/ge_log.h" #include "framework/common/ge_inner_error_codes.h" -#include "graph/passes/folding_kernel/kernel_utils.h" +#include "host_kernels/kernel_utils.h" #include "graph/utils/type_utils.h" #include "inc/kernel_factory.h" diff --git a/src/ge/graph/passes/folding_kernel/transpose_kernel.h b/src/ge/host_kernels/transpose_kernel.h similarity index 100% rename from src/ge/graph/passes/folding_kernel/transpose_kernel.h rename to src/ge/host_kernels/transpose_kernel.h diff --git a/src/ge/graph/passes/folding_kernel/unpack_kernel.cc b/src/ge/host_kernels/unpack_kernel.cc similarity index 98% rename from src/ge/graph/passes/folding_kernel/unpack_kernel.cc rename to src/ge/host_kernels/unpack_kernel.cc index 44f666fa..fbfd9e16 100644 --- a/src/ge/graph/passes/folding_kernel/unpack_kernel.cc +++ b/src/ge/host_kernels/unpack_kernel.cc @@ -14,7 +14,7 @@ * limitations under the License. 
*/ -#include "graph/passes/folding_kernel/unpack_kernel.h" +#include "host_kernels/unpack_kernel.h" #include "common/debug/ge_log.h" #include "common/op/ge_op_utils.h" #include "common/op/ge_op_utils.h" diff --git a/src/ge/graph/passes/folding_kernel/unpack_kernel.h b/src/ge/host_kernels/unpack_kernel.h similarity index 100% rename from src/ge/graph/passes/folding_kernel/unpack_kernel.h rename to src/ge/host_kernels/unpack_kernel.h diff --git a/src/ge/hybrid/common/npu_memory_allocator.cc b/src/ge/hybrid/common/npu_memory_allocator.cc new file mode 100644 index 00000000..f432318b --- /dev/null +++ b/src/ge/hybrid/common/npu_memory_allocator.cc @@ -0,0 +1,82 @@ +/** + * Copyright 2019-2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "npu_memory_allocator.h" +#include +#include "framework/common/debug/log.h" +#include "graph/manager/graph_mem_allocator.h" +#include "graph/manager/graph_caching_allocator.h" + +namespace ge { +namespace hybrid { +std::map> NpuMemoryAllocator::allocators_; +std::mutex NpuMemoryAllocator::mu_; + +NpuMemoryAllocator *NpuMemoryAllocator::GetAllocator() { + int32_t device_id = 0; + if (rtGetDevice(&device_id) != RT_ERROR_NONE) { + GELOGE(RT_FAILED, "Failed to get device id"); + return nullptr; + } + + GELOGD("Got device id = %d from context", device_id); + return GetAllocator(static_cast(device_id)); +} + +NpuMemoryAllocator::NpuMemoryAllocator(uint32_t device_id) : device_id_(device_id) {} + +void *NpuMemoryAllocator::Allocate(std::size_t size, void *try_reuse_addr) { + void *buffer = + MemManager::CachingInstance(RT_MEMORY_HBM).Malloc(size, reinterpret_cast(try_reuse_addr), device_id_); + if (buffer == nullptr) { + GELOGE(MEMALLOC_FAILED, "Failed to malloc memory, device_id = %u, size = %zu", device_id_, size); + return nullptr; + } + + GELOGI("Allocating buffer of size %u successfully. device_id = %u, address = %p", size, device_id_, buffer); + return buffer; +} + +void NpuMemoryAllocator::Deallocate(void *data) { + GELOGI("To deallocating buffer, addr = %p", data); + if (data != nullptr) { + GELOGI("Deallocating buffer successfully. 
addr = %p", data); + MemManager::CachingInstance(RT_MEMORY_HBM).Free(reinterpret_cast(data), device_id_); + } +} + +NpuMemoryAllocator *NpuMemoryAllocator::GetAllocator(uint32_t device_id) { + std::lock_guard lk(mu_); + auto it = allocators_.find(device_id); + if (it == allocators_.end()) { + auto allocator = std::unique_ptr(new (std::nothrow) NpuMemoryAllocator(device_id)); + if (allocator == nullptr) { + return nullptr; + } + + allocators_.emplace(device_id, std::move(allocator)); + } + + return allocators_[device_id].get(); +} + +void NpuMemoryAllocator::DestroyAllocator() { + std::lock_guard lk(mu_); + int device_id = 0; + allocators_.erase(device_id); +} +} // namespace hybrid +} // namespace ge \ No newline at end of file diff --git a/src/ge/hybrid/common/npu_memory_allocator.h b/src/ge/hybrid/common/npu_memory_allocator.h new file mode 100644 index 00000000..8cfeafa6 --- /dev/null +++ b/src/ge/hybrid/common/npu_memory_allocator.h @@ -0,0 +1,48 @@ +/** + * Copyright 2019-2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef GE_HYBRID_COMMON_MEMORY_ALLOCATOR_H_ +#define GE_HYBRID_COMMON_MEMORY_ALLOCATOR_H_ + +#include +#include +#include +#include +#include +#include "external/ge/ge_api_error_codes.h" + +namespace ge { +namespace hybrid { +class NpuMemoryAllocator { + public: + ~NpuMemoryAllocator() = default; + static NpuMemoryAllocator *GetAllocator(uint32_t device_id); + static NpuMemoryAllocator *GetAllocator(); + static void DestroyAllocator(); + + void *Allocate(std::size_t size, void *try_reuse_addr = nullptr); + void Deallocate(void *data); + + private: + explicit NpuMemoryAllocator(uint32_t device_id); + uint32_t device_id_; + + static std::map> allocators_; + static std::mutex mu_; +}; +} // namespace hybrid +} // namespace ge +#endif // GE_HYBRID_COMMON_MEMORY_ALLOCATOR_H_ diff --git a/src/ge/hybrid/common/tensor_value.cc b/src/ge/hybrid/common/tensor_value.cc new file mode 100644 index 00000000..9544e03a --- /dev/null +++ b/src/ge/hybrid/common/tensor_value.cc @@ -0,0 +1,131 @@ +/** + * Copyright 2019-2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +#include "hybrid/common/tensor_value.h" +#include +#include "framework/common/debug/ge_log.h" +#include "hybrid/common/npu_memory_allocator.h" + +namespace ge { +namespace hybrid { +TensorBuffer::TensorBuffer(NpuMemoryAllocator *allocator, void *buffer, size_t size) + : allocator_(allocator), buffer_(buffer), size_(size) {} + +std::unique_ptr TensorBuffer::Create(NpuMemoryAllocator *allocator, size_t size) { + void *buffer = nullptr; + if (size == 0) { + GELOGD("size is 0"); + return Create(buffer, 0U); + } + + if (allocator == nullptr) { + GELOGE(INTERNAL_ERROR, "allocator is NULL"); + return nullptr; + } + + buffer = allocator->Allocate(size); + if (buffer == nullptr) { + GELOGE(MEMALLOC_FAILED, "Failed to allocate memory. size = %zu", size); + return nullptr; + } + + GELOGD("Tensor created. addr = %p, size = %zu", buffer, size); + return std::unique_ptr(new (std::nothrow) TensorBuffer(allocator, buffer, size)); +} + +std::unique_ptr TensorBuffer::Create(void *buffer, size_t size) { + GELOGD("Tensor created. addr = %p, size = %zu", buffer, size); + return std::unique_ptr(new (std::nothrow) TensorBuffer(nullptr, buffer, size)); +} + +TensorBuffer::~TensorBuffer() { + if (allocator_ != nullptr && buffer_ != nullptr) { + allocator_->Deallocate(buffer_); + } +} + +TensorValue::TensorValue(std::shared_ptr buffer) : buffer_(std::move(buffer)) {} + +TensorValue::TensorValue(void *buffer, size_t size) : ref_buffer_(buffer), ref_size_(size) {} + +TensorValue::~TensorValue() { Destroy(); } + +void TensorValue::Destroy() { + if (buffer_ != nullptr || ref_buffer_ != nullptr) { + GELOGD("Unref tensor: %s", DebugString().c_str()); + buffer_.reset(); + } +} + +size_t TensorValue::GetSize() const { + if (ref_buffer_ != nullptr) { + return ref_size_; + } + + if (buffer_ == nullptr) { + GELOGD("TensorValue[%s] is empty", name_.c_str()); + return 0; + } + + return buffer_->GetSize(); +} + +const void *TensorValue::GetData() const { + if (ref_buffer_ != nullptr) { + return ref_buffer_; + } + + if (buffer_ == nullptr) { + GELOGD("TensorValue[%s] is empty", name_.c_str()); + return nullptr; + } + return buffer_->GetData(); +} + +void *TensorValue::MutableData() { + if (ref_buffer_ != nullptr) { + return ref_buffer_; + } + + if (buffer_ == nullptr) { + GELOGD("TensorValue[%s] is empty", name_.c_str()); + return nullptr; + } + + return buffer_->GetData(); +} + +std::string TensorValue::DebugString() const { + std::stringstream ss; + ss << "TensorValue["; + if (name_.empty()) { + ss << "unnamed] "; + } else { + ss << name_ << "] "; + } + + if (ref_buffer_ != nullptr) { + ss << "ref_addr = " << ref_buffer_ << ", size = " << ref_size_; + } else if (buffer_ != nullptr) { + ss << "addr = " << buffer_->GetData() << ", size = " << buffer_->GetSize(); + } else { + ss << "addr = (nil)"; + } + + return ss.str(); +} +} // namespace hybrid +} // namespace ge diff --git a/src/ge/hybrid/common/tensor_value.h b/src/ge/hybrid/common/tensor_value.h new file mode 100644 index 00000000..18e67534 --- /dev/null +++ b/src/ge/hybrid/common/tensor_value.h @@ -0,0 +1,82 @@ +/** + * Copyright 2019-2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. 
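[Reviewer note] tensor_value.cc above gives TensorBuffer two creation paths: Create(allocator, size) owns the memory and frees it through the allocator in ~TensorBuffer, while Create(buffer, size) merely wraps caller-managed memory (allocator_ stays nullptr, so the destructor leaves it alone). A zero-size request deliberately takes the wrapping path with a null pointer. A sketch of the two modes; GetWeightAddressSomehow is a hypothetical stand-in for however the caller obtains externally managed memory:

// Sketch of the two TensorBuffer ownership modes (types from this patch).
auto *alloc = ge::hybrid::NpuMemoryAllocator::GetAllocator();

// Owned: released via alloc->Deallocate() when the unique_ptr is destroyed.
std::unique_ptr<ge::hybrid::TensorBuffer> owned = ge::hybrid::TensorBuffer::Create(alloc, 1024);

// Borrowed: wraps externally managed memory (e.g. a weight blob); never freed here.
void *weight_blob = GetWeightAddressSomehow();  // hypothetical helper, not from the patch
std::unique_ptr<ge::hybrid::TensorBuffer> borrowed = ge::hybrid::TensorBuffer::Create(weight_blob, 512);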
+ * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef GE_HYBRID_COMMON_TENSOR_VALUE_H_ +#define GE_HYBRID_COMMON_TENSOR_VALUE_H_ + +#include +#include +#include + +namespace ge { +namespace hybrid { +class NpuMemoryAllocator; + +class TensorBuffer { + public: + static std::unique_ptr Create(NpuMemoryAllocator *allocator, size_t size); + + static std::unique_ptr Create(void *buffer, size_t size); + + ~TensorBuffer(); + + void *GetData() { return buffer_; } + + size_t GetSize() const { return size_; } + + private: + TensorBuffer(NpuMemoryAllocator *allocator, void *buffer, size_t size); + + NpuMemoryAllocator *allocator_ = nullptr; + void *buffer_ = nullptr; + size_t size_ = 0; +}; + +class TensorValue { + public: + TensorValue() = default; + + explicit TensorValue(std::shared_ptr buffer); + + TensorValue(void *buffer, size_t size); + + ~TensorValue(); + + void Destroy(); + + bool IsEmpty() { return ref_buffer_ == nullptr && buffer_ == nullptr; } + + const void *GetData() const; + + std::string DebugString() const; + + void SetName(const std::string &name) { name_ = name; } + + void *MutableData(); + + size_t GetSize() const; + + private: + std::shared_ptr buffer_; + std::string name_; + // for weights and variables + void *ref_buffer_ = nullptr; + size_t ref_size_ = 0; + // shape +}; +} // namespace hybrid +} // namespace ge +#endif // GE_HYBRID_COMMON_TENSOR_VALUE_H_ diff --git a/src/ge/hybrid/executor/hybrid_execution_context.cc b/src/ge/hybrid/executor/hybrid_execution_context.cc new file mode 100644 index 00000000..bb8e0195 --- /dev/null +++ b/src/ge/hybrid/executor/hybrid_execution_context.cc @@ -0,0 +1,50 @@ +/** + * Copyright 2019-2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
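[Reviewer note] TensorValue in the header above mirrors that split one level up: it either shares a TensorBuffer through a shared_ptr, so copies stored in the execution context keep the device memory alive until the last copy is destroyed or Destroy()ed, or it records a raw ref_buffer_/ref_size_ pair for weights and variables that it must never free. A short sketch of both, continuing the variables from the previous sketch:

// Sketch: owned vs. reference TensorValue (types from this patch).
std::shared_ptr<ge::hybrid::TensorBuffer> buffer(
    ge::hybrid::TensorBuffer::Create(alloc, 1024).release());
ge::hybrid::TensorValue owned(buffer);
owned.SetName("Input_data");           // the name only shows up in DebugString()/logs

ge::hybrid::TensorValue weight_ref(weight_blob, 512);  // ref mode: never freed by TensorValue
ge::hybrid::TensorValue copy = owned;  // bumps the shared_ptr; memory lives until the last owner goes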
+ */ + +#include "hybrid_execution_context.h" + +namespace ge { +namespace hybrid { +NodeStatePtr GraphExecutionContext::GetOrCreateNodeState(const NodePtr &node) { + auto &node_state = node_states[node]; + if (node_state == nullptr) { + const NodeItem *node_item = model->GetNodeItem(node); + if (node_item == nullptr) { + return nullptr; + } + node_state.reset(new (std::nothrow) NodeState(*node_item)); + } + + return node_state; +} + +void GraphExecutionContext::OnError(Status error_code) { + GELOGE(error_code, "Error occurred while executing model"); + { + std::lock_guard lk(mu_); + this->status = error_code; + } + + compile_queue.Stop(); + execution_queue.Stop(); +} + +Status GraphExecutionContext::GetStatus() { + std::lock_guard lk(mu_); + return status; +} +} // namespace hybrid +} // namespace ge \ No newline at end of file diff --git a/src/ge/hybrid/executor/hybrid_execution_context.h b/src/ge/hybrid/executor/hybrid_execution_context.h new file mode 100644 index 00000000..f7e7af88 --- /dev/null +++ b/src/ge/hybrid/executor/hybrid_execution_context.h @@ -0,0 +1,85 @@ +/** + * Copyright 2019-2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef GE_HYBRID_EXECUTOR_HYBRID_EXECUTION_CONTEXT_H_ +#define GE_HYBRID_EXECUTOR_HYBRID_EXECUTION_CONTEXT_H_ + +#include +#include +#include "common/blocking_queue.h" +#include "hybrid/common/npu_memory_allocator.h" +#include "hybrid/common/tensor_value.h" +#include "hybrid/executor/hybrid_profiler.h" +#include "hybrid/executor/node_done_manager.h" +#include "hybrid/executor/node_state.h" +#include "hybrid/executor/rt_callback_manager.h" +#include "hybrid/model/hybrid_model.h" + +namespace ge { +namespace hybrid { +struct GraphExecutionContext { + uint64_t session_id = 0; + const HybridModel *model = nullptr; + NodeDoneManager cv_manager; + BlockingQueue compile_queue; + BlockingQueue execution_queue; + std::vector all_inputs; + std::vector all_outputs; + std::unordered_map node_states; + rtStream_t stream = nullptr; + std::unique_ptr callback_manager; + NpuMemoryAllocator *allocator; + mutable std::unique_ptr profiler; + bool trace_enabled = false; + int profiling_level = 0; + bool dump_enabled = false; + Status status; + std::mutex mu_; + + NodeStatePtr GetOrCreateNodeState(const NodePtr &node); + void OnError(Status status); + Status GetStatus(); +}; + +#define RECORD_PROFILING_EVENT(context, event_type, fmt, category, node_name, ...) \ + do { \ + if ((context)->profiler != nullptr) { \ + if (node_name != nullptr) { \ + context->profiler->RecordEvent(event_type, "[%s] [%s] " fmt, node_name, category, ##__VA_ARGS__); \ + } else { \ + context->profiler->RecordEvent(event_type, "[%s] " fmt, category, ##__VA_ARGS__); \ + } \ + } \ + } while (0) + +#define RECORD_MODEL_EXECUTION_EVENT(context, fmt, ...) \ + RECORD_PROFILING_EVENT((context), HybridProfiler::GENERAL, fmt, "ModelExecutor", nullptr, ##__VA_ARGS__) + +#define RECORD_SHAPE_INFERENCE_EVENT(context, name, fmt, ...) 
\ + RECORD_PROFILING_EVENT((context), HybridProfiler::SHAPE_INFERENCE, fmt, "ShapeInference", name, ##__VA_ARGS__) + +#define RECORD_COMPILE_EVENT(context, name, fmt, ...) \ + RECORD_PROFILING_EVENT((context), HybridProfiler::COMPILE, fmt, "Compilation", name, ##__VA_ARGS__) + +#define RECORD_EXECUTION_EVENT(context, name, fmt, ...) \ + RECORD_PROFILING_EVENT((context), HybridProfiler::EXECUTION, fmt, "Execution", name, ##__VA_ARGS__) + +#define RECORD_CALLBACK_EVENT(context, name, fmt, ...) \ + RECORD_PROFILING_EVENT((context), HybridProfiler::CALLBACK, fmt, "Callback", name, ##__VA_ARGS__) + +} // namespace hybrid +} // namespace ge +#endif // GE_HYBRID_EXECUTOR_HYBRID_EXECUTION_CONTEXT_H_ diff --git a/src/ge/hybrid/executor/hybrid_model_async_executor.cc b/src/ge/hybrid/executor/hybrid_model_async_executor.cc new file mode 100644 index 00000000..2999daba --- /dev/null +++ b/src/ge/hybrid/executor/hybrid_model_async_executor.cc @@ -0,0 +1,319 @@ +/** + * Copyright 2019-2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "hybrid/executor/hybrid_model_async_executor.h" +#include "graph/load/new_model_manager/model_utils.h" +#include "graph/utils/tensor_utils.h" +#include "graph/utils/type_utils.h" +#include "omm/csa_interact.h" + +namespace ge { +namespace hybrid { +namespace { +int kDataOutputIndex = 0; +} +HybridModelAsyncExecutor::HybridModelAsyncExecutor(HybridModel *model) : model_(model), run_flag_(false) {} + +HybridModelAsyncExecutor::~HybridModelAsyncExecutor() { + if (stream_ != nullptr) { + GE_CHK_RT(rtStreamDestroy(stream_)); + } +} + +void HybridModelAsyncExecutor::SetDeviceId(uint32_t device_id) { device_id_ = device_id; } + +void HybridModelAsyncExecutor::SetModelId(uint32_t model_id) { model_id_ = model_id; } + +Status HybridModelAsyncExecutor::EnqueueData(const shared_ptr &data) { + GE_CHK_STATUS_EXEC(data_inputer_->Push(data), return domi::DATA_QUEUE_ISFULL, + "Data queue is full, please call again later, model_id %u ", model_id_); + GELOGD("EnqueueData successfully. 
model_id = %u, data_index = %u", data->GetInput().model_id, data->GetInput().index); + return SUCCESS; +} + +Status HybridModelAsyncExecutor::Start(const std::shared_ptr &listener) { + GELOGD("HybridModelExecutor::Start IN, listener = %p", listener.get()); + std::lock_guard lk(mu_); + GE_CHK_BOOL_RET_STATUS(!run_flag_, INTERNAL_ERROR, "Model already started."); + + run_flag_ = true; + listener_ = listener; + future_ = std::async([&]() -> Status { return RunInternal(); }); + + GE_CHK_BOOL_RET_STATUS(future_.valid(), INTERNAL_ERROR, "Failed to start."); + GELOGD("HybridModelExecutor::Start successfully"); + return SUCCESS; +} + +Status HybridModelAsyncExecutor::Stop() { + std::lock_guard lk(mu_); + run_flag_ = false; + data_inputer_->Stop(); + auto ret = future_.get(); + + if (stream_ != nullptr) { + GE_CHK_RT(rtStreamDestroy(stream_)); + stream_ = nullptr; + } + + return ret; +} + +Status HybridModelAsyncExecutor::Init() { + data_inputer_ = std::unique_ptr(new (std::nothrow) DataInputer()); + GE_CHECK_NOTNULL(data_inputer_); + GE_CHK_RT_RET(rtStreamCreate(&stream_, RT_STREAM_PRIORITY_DEFAULT)); + + engine_ = std::unique_ptr(new (std::nothrow) HybridModelExecutor(model_, device_id_, stream_)); + GE_CHECK_NOTNULL(engine_); + GE_CHK_STATUS_RET(engine_->Init(), "Failed to init hybrid engine"); + + GE_CHK_STATUS_RET(InitInputTensors(), "Failed to init input tensors"); + return SUCCESS; +} + +Status HybridModelAsyncExecutor::PreRun(InputData ¤t_data) { + GE_CHK_STATUS_RET(SyncVarData(), "Failed to sync var data"); + RECORD_MODEL_EXECUTION_EVENT(engine_->GetContext(), "[SyncVarData] End"); + GE_CHK_STATUS_RET(CopyInputData(current_data), "Failed to copy input data to model"); + RECORD_MODEL_EXECUTION_EVENT(engine_->GetContext(), "[CopyInputData] End"); + return SUCCESS; +} + +Status HybridModelAsyncExecutor::RunInternal() { + auto device_id = static_cast(device_id_); + GELOGD("Hybrid model start. model_id = %u, device_id = %u", model_id_, device_id_); + GE_CHK_RT_RET(rtSetDevice(device_id)); + // DeviceReset before thread run finished! 
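[Reviewer note] The async executor above follows a classic worker-thread lifecycle: Start() flips run_flag_ and launches RunInternal() through std::async, while Stop() clears the flag, stops the DataInputer queue so a blocked Pop() wakes up, and joins via future_.get(); the GE_MAKE_GUARD line right below ties rtDeviceReset to scope exit so the device is reset however the loop ends. One thing worth double-checking: the std::async call in Start() uses the default launch policy, which permits deferred execution; std::launch::async would pin the intent. A condensed sketch of the pattern, reusing the repo's BlockingQueue (common/blocking_queue.h) as a stand-in for DataInputer; the Pop interface shown is an assumption:

// Condensed lifecycle sketch; the real executor adds device setup, listeners
// and error reporting.
#include <atomic>
#include <future>
#include "common/blocking_queue.h"

class AsyncWorker {
 public:
  void Start() {
    run_flag_ = true;
    future_ = std::async(std::launch::async, [this] { return RunLoop(); });
  }
  int Stop() {
    run_flag_ = false;  // observed by RunLoop between iterations
    queue_.Stop();      // unblocks a Pop() that is waiting for input
    return future_.get();
  }

 private:
  int RunLoop() {
    while (run_flag_) {
      // Pop -> PreRun -> Execute -> HandleResult, as in RunInternal() above
    }
    return 0;
  }
  std::atomic_bool run_flag_{false};
  std::future<int> future_;
  BlockingQueue<int> queue_;  // stands in for DataInputer
};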
+ GE_MAKE_GUARD(not_used_var, [&] { GE_CHK_RT(rtDeviceReset(device_id)); }); + + while (run_flag_) { + std::shared_ptr data_wrapper; + Status ret = data_inputer_->Pop(data_wrapper); + if (data_wrapper == nullptr || ret != SUCCESS) { + GELOGI("data_wrapper is null!, ret = %u", ret); + continue; + } + + GELOGI("Getting the input data, model_id:%u", model_id_); + GE_IF_BOOL_EXEC(!run_flag_, break); + InputData current_data = data_wrapper->GetInput(); + GELOGI("Model thread Run begin, model id:%u, data index:%u.", model_id_, current_data.index); + + HybridModelExecutor::ExecuteArgs args; + args.inputs.resize(input_tensors_.size()); + for (auto &it : input_tensors_) { + args.inputs[it.first] = it.second; + } + + RECORD_MODEL_EXECUTION_EVENT(engine_->GetContext(), "[RunInternal] [iteration = %d] Start", iterator_count_); + ret = PreRun(current_data); + GE_CHK_BOOL_TRUE_EXEC_WITH_LOG( + ret != SUCCESS, (void)HandleResult(ret, current_data.index, args.outputs, data_wrapper->GetOutput()); + CsaInteract::GetInstance().StoreInternalErrorCode(ret, ERROR_MODULE_FMK, JOBSUBSTATE_GRAPH_EXEC); + continue, "PreRun failed."); // [No need to check value] + + ret = engine_->Execute(args); + ret = HandleResult(ret, current_data.index, args.outputs, data_wrapper->GetOutput()); + if (ret != SUCCESS) { + CsaInteract::GetInstance().StoreInternalErrorCode(ret, ERROR_MODULE_RUNTIME, JOBSUBSTATE_GRAPH_EXEC); + continue; + } + + RECORD_MODEL_EXECUTION_EVENT(engine_->GetContext(), "[RunInternal] [iteration = %d] End", iterator_count_); + iterator_count_++; + GELOGI("run iterator count is %lu", iterator_count_); + } + + CsaInteract::GetInstance().WriteInternalErrorCode(); + GELOGI("Model run end, model id:%u", model_id_); + return SUCCESS; +} + +Status HybridModelAsyncExecutor::HandleResult(Status exec_ret, uint32_t data_id, + const std::vector &output_tensors, OutputData *output_data) { + GELOGD("Start to handle result. model id = %u, data index = %u, execution ret = %u", model_id_, data_id, exec_ret); + std::vector output_tensor_info_list; + if (exec_ret == END_OF_SEQUENCE) { + GELOGW("End of sequence, model id = %u", model_id_); + return OnComputeDone(data_id, END_OF_SEQUENCE, output_tensor_info_list); + } + + if (exec_ret != SUCCESS) { + GELOGE(exec_ret, "Failed to execute graph. 
model_id = %u", model_id_); + return OnComputeDone(data_id, INTERNAL_ERROR, output_tensor_info_list); + } + + GE_CHECK_NOTNULL(output_data); + auto ret = CopyOutputs(output_tensors, output_data, output_tensor_info_list); + if (ret != SUCCESS) { + OnComputeDone(data_id, INTERNAL_ERROR, output_tensor_info_list); + return INTERNAL_ERROR; + } + + GELOGD("Executed graph successfully, model id = %u, data_index = %u", model_id_, data_id); + return OnComputeDone(data_id, SUCCESS, output_tensor_info_list); +} + +Status HybridModelAsyncExecutor::SyncVarData() { + GELOGI("Sync var data, model id:%u", model_id_); + + TensorValue *global_step_var = model_->GetVariable(NODE_NAME_GLOBAL_STEP); + if (global_step_var != nullptr) { + std::vector v_step; + v_step.push_back(iterator_count_); + GE_CHK_RT_RET(rtMemcpy(global_step_var->MutableData(), global_step_var->GetSize(), v_step.data(), + v_step.size() * sizeof(uint64_t), RT_MEMCPY_HOST_TO_DEVICE)); + } else { + GELOGD("No GLOBAL_STEP variable was found."); + } + + return SUCCESS; +} + +Status HybridModelAsyncExecutor::CopyInputData(const InputData ¤t_data) { + const std::vector &blobs = current_data.blobs; + for (const auto &it : input_tensors_) { + auto input_index = it.first; + auto input_tensor = it.second; + auto data_size = input_tensor.GetSize(); + GELOGD("To copy input data for input[%u]", input_index); + if (input_index >= blobs.size()) { + GELOGE(FAILED, "Blobs not match: blobs=%zu, tensor=%zu, index=%u, size=%ld", blobs.size(), + model_->input_nodes_.size(), input_index, data_size); + return FAILED; + } + + const DataBuffer &data_buf = blobs[input_index]; + auto mem_size = static_cast(data_size); + GE_CHK_BOOL_RET_STATUS(mem_size >= data_buf.length, PARAM_INVALID, + "input data size(%u) does not match model required size(%u), ret failed.", data_buf.length, + mem_size); + + GELOGI("[IMAS]CopyPlainData memcpy graph_%u type[F] output[%u] memaddr[%p] mem_size[%u] datasize[%u]", + model_->root_runtime_param_.graph_id, input_index, input_tensor.GetData(), mem_size, data_buf.length); + GE_CHK_RT_RET( + rtMemcpy(input_tensor.MutableData(), mem_size, data_buf.data, data_buf.length, RT_MEMCPY_HOST_TO_DEVICE)); + } + + return SUCCESS; +} + +Status HybridModelAsyncExecutor::InitInputTensors() { + auto allocator = NpuMemoryAllocator::GetAllocator(device_id_); + GE_CHECK_NOTNULL(allocator); + for (const auto &it : model_->input_nodes_) { + auto input_index = it.first; + auto input_node = it.second; + GELOGD("Init input[%u], node = %s", input_index, input_node->NodeName().c_str()); + auto output_desc = input_node->op_desc->GetOutputDescPtr(kDataOutputIndex); + GE_CHECK_NOTNULL(output_desc); + int64_t tensor_size = 0; + GE_CHK_GRAPH_STATUS_RET(TensorUtils::GetSize(*output_desc, tensor_size), "Failed to get size from %s", + input_node->NodeName().c_str()); + if (tensor_size == 0) { + GELOGW("[%s] Tensor size == 0", input_node->NodeName().c_str()); + GE_CHK_GRAPH_STATUS_RET(TensorUtils::GetTensorMemorySizeInBytes(*output_desc, tensor_size), + "Failed to calc tensor size"); + GELOGD("[%s] Tensor size updated to %ld", input_node->NodeName().c_str(), tensor_size); + } + auto buffer = TensorBuffer::Create(allocator, tensor_size); + GE_CHECK_NOTNULL(buffer); + TensorValue tensor(shared_ptr(buffer.release())); + tensor.SetName("Input_" + input_node->NodeName()); + input_tensors_.emplace(input_index, tensor); + } + + return SUCCESS; +} + +Status HybridModelAsyncExecutor::OnComputeDone(uint32_t data_index, uint32_t result_code, + std::vector &outputs) { + 
GELOGD("OnComputeDone. model id = %u, data index = %u, execution ret = %u", model_id_, data_index, result_code); + if (listener_ != nullptr) { + GE_CHK_STATUS(listener_->OnComputeDone(model_id_, data_index, result_code, outputs), "OnComputeDone failed"); + } + + return result_code; +} + +Status HybridModelAsyncExecutor::CopyOutputs(const std::vector &output_tensors, OutputData *output_data, + std::vector &outputs) { + // copy output data from op to designated position + NodeItem *net_output_node = model_->net_output_node_; + GE_CHECK_NOTNULL(net_output_node); + auto all_input_desc = net_output_node->op_desc->GetAllInputsDescPtr(); + + if (all_input_desc.size() != output_tensors.size()) { + GELOGE(INTERNAL_ERROR, "Output sizes mismatch. From op_desc = %zu, and from output tensors = %zu", + all_input_desc.size(), output_tensors.size()); + return INTERNAL_ERROR; + } + + GELOGD("Number of outputs = %zu", all_input_desc.size()); + for (size_t i = 0; i < output_tensors.size(); ++i) { + GELOGD("Start to process output[%zu]", i); + auto &output_tensor = output_tensors[i]; + auto &tensor_desc = all_input_desc.at(i); + GE_CHECK_NOTNULL(tensor_desc); + int64_t output_size = -1; + GE_CHK_GRAPH_STATUS_RET(TensorUtils::CalcTensorMemSize(tensor_desc->MutableShape(), tensor_desc->GetFormat(), + tensor_desc->GetDataType(), output_size), + "Failed to calc tensor size for output[%zu]. shape = [%s], type = %s, format = %s", i, + tensor_desc->MutableShape().ToString().c_str(), + TypeUtils::DataTypeToSerialString(tensor_desc->GetDataType()).c_str(), + TypeUtils::FormatToSerialString(tensor_desc->GetFormat()).c_str()); + + GELOGD("Got tensor size for output[%zu] successfully. shape = [%s], type = %s, format = %s, size = %ld", i, + tensor_desc->MutableShape().ToString().c_str(), + TypeUtils::DataTypeToSerialString(tensor_desc->GetDataType()).c_str(), + TypeUtils::FormatToSerialString(tensor_desc->GetFormat()).c_str(), output_size); + + GE_CHECK_GE(output_size, 0); + GE_CHECK_LE(output_size, UINT32_MAX); + if (output_tensor.GetSize() < static_cast(output_size)) { + GELOGE(INTERNAL_ERROR, "output[%zu] tensor size(%zu) is not enough for output shape [%s]", i, + output_tensor.GetSize(), tensor_desc->MutableShape().ToString().c_str()); + return INTERNAL_ERROR; + } + + ge::OutputTensorInfo output; + output.data_type = static_cast(tensor_desc->GetDataType()); + output.dims = tensor_desc->GetShape().GetDims(); + output.length = output_size; + if (output_size > 0) { + std::unique_ptr data_buf(new (std::nothrow) uint8_t[output_size]); + GE_CHECK_NOTNULL(data_buf); + GE_CHK_RT_RET( + rtMemcpy(data_buf.get(), output_size, output_tensor.GetData(), output_size, RT_MEMCPY_DEVICE_TO_HOST)); + output.data = std::move(data_buf); + output_data->blobs.emplace_back(data_buf.get(), static_cast(output_size), false); + } else { + GELOGW("Output[%zu] is empty. 
shape = [%s]", i, tensor_desc->MutableShape().ToString().c_str()); + output.data = nullptr; + output_data->blobs.emplace_back(nullptr, 0U, false); + } + + outputs.emplace_back(std::move(output)); + GELOGD("Output[%zu] added, type = %s, shape = [%s], size = %ld", i, + TypeUtils::DataTypeToSerialString(tensor_desc->GetDataType()).c_str(), + tensor_desc->MutableShape().ToString().c_str(), output_size); + } + + return SUCCESS; +} +} // namespace hybrid +} // namespace ge \ No newline at end of file diff --git a/src/ge/hybrid/executor/hybrid_model_async_executor.h b/src/ge/hybrid/executor/hybrid_model_async_executor.h new file mode 100644 index 00000000..cb440ba7 --- /dev/null +++ b/src/ge/hybrid/executor/hybrid_model_async_executor.h @@ -0,0 +1,83 @@ +/** + * Copyright 2019-2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef GE_HYBRID_EXECUTOR_MODEL_HYBRID_MODEL_ASYNC_EXECUTOR_H_ +#define GE_HYBRID_EXECUTOR_MODEL_HYBRID_MODEL_ASYNC_EXECUTOR_H_ +#include +#include +#include +#include "external/ge/ge_api_error_codes.h" +#include "external/ge/ge_api_types.h" +#include "graph/load/new_model_manager/data_inputer.h" +#include "hybrid/executor/hybrid_model_executor.h" +#include "runtime/stream.h" + +namespace ge { +namespace hybrid { +class HybridModel; +class HybridModelAsyncExecutor { + public: + explicit HybridModelAsyncExecutor(HybridModel *model); + ~HybridModelAsyncExecutor(); + + Status Init(); + + Status Start(const std::shared_ptr &listener); + + void SetDeviceId(uint32_t device_id); + + void SetModelId(uint32_t model_id); + + Status Stop(); + + Status EnqueueData(const std::shared_ptr &data); + + private: + Status InitInputTensors(); + + Status RunInternal(); + + Status SyncVarData(); + + Status HandleResult(Status exec_ret, uint32_t data_id, const std::vector &output_tensors, + OutputData *output_data); + + Status CopyOutputs(const std::vector &output_tensors, OutputData *output_data, + std::vector &outputs); + + Status OnComputeDone(uint32_t data_index, uint32_t result_code, std::vector &outputs); + + Status PreRun(InputData ¤t_data); + + Status CopyInputData(const InputData ¤t_data); + + std::mutex mu_; + HybridModel *model_; + uint32_t device_id_ = 0U; + uint32_t model_id_ = 0U; + std::atomic_bool run_flag_; + std::unique_ptr data_inputer_; + std::unique_ptr engine_; + std::future future_; + uint64_t iterator_count_ = 0; + + rtStream_t stream_ = nullptr; + std::map input_tensors_; + std::shared_ptr listener_; +}; +} // namespace hybrid +} // namespace ge +#endif // GE_HYBRID_EXECUTOR_MODEL_HYBRID_MODEL_ASYNC_EXECUTOR_H_ diff --git a/src/ge/hybrid/executor/hybrid_model_executor.cc b/src/ge/hybrid/executor/hybrid_model_executor.cc new file mode 100644 index 00000000..97dc5a36 --- /dev/null +++ b/src/ge/hybrid/executor/hybrid_model_executor.cc @@ -0,0 +1,172 @@ +/** + * Copyright 2019-2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in 
compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "hybrid_model_executor.h" +#include "graph/ge_context.h" +#include "graph/runtime_inference_context.h" + +namespace ge { +namespace hybrid { +HybridModelExecutor::HybridModelExecutor(HybridModel *model, uint32_t device_id, rtStream_t stream) + : model_(model), device_id_(device_id), stream_(stream) {} + +Status HybridModelExecutor::Init() { + GELOGD("Start to init HybridGraphEngine."); + GE_CHK_STATUS_RET_NOLOG(InitExecutionContext()); + infer_shape_engine_.reset(new (std::nothrow) ShapeInferenceEngine(&context_)); + compile_engine_.reset(new (std::nothrow) TaskCompileEngine(&context_)); + execute_engine_.reset(new (std::nothrow) ExecutionEngine(&context_, context_.callback_manager.get())); + GE_CHK_STATUS_RET_NOLOG(compile_engine_->Init()); + GELOGD("HybridGraphEngine initialized successfully."); + return SUCCESS; +} + +Status HybridModelExecutor::Execute(HybridModelExecutor::ExecuteArgs &args) { + GELOGD("Start to execute model."); + auto ret = ExecuteGraphInternal(args); + Cleanup(); + RECORD_MODEL_EXECUTION_EVENT(&context_, "[Cleanup] End"); + GE_CHK_STATUS_RET(ret, "Failed to execute model"); + GELOGD("Model executed successfully."); + + if (context_.profiler != nullptr) { + context_.profiler->Reset(); + } + + return SUCCESS; +} + +Status HybridModelExecutor::ExecuteGraphInternal(HybridModelExecutor::ExecuteArgs &args) { + RECORD_MODEL_EXECUTION_EVENT(&context_, "[InitContext] Start"); + GE_CHK_STATUS_RET_NOLOG(ResetExecutionContext(context_)); + RECORD_MODEL_EXECUTION_EVENT(&context_, "[InitContext] End"); + GE_CHK_STATUS_RET_NOLOG(InitInputsAndOutputs(args, context_)); + RECORD_MODEL_EXECUTION_EVENT(&context_, "[InitInputsAndOutputs] End"); + GE_CHK_STATUS_RET_NOLOG(compile_engine_->Start(pool_)); + RECORD_MODEL_EXECUTION_EVENT(&context_, "[CompileProcess] Started"); + GE_CHK_STATUS_RET_NOLOG(infer_shape_engine_->Start(pool_)); + RECORD_MODEL_EXECUTION_EVENT(&context_, "[InferShapeProcess] Started"); + GE_CHK_STATUS_RET(execute_engine_->Start(), "Run execution engine failed."); + RECORD_MODEL_EXECUTION_EVENT(&context_, "[ExecutionProcess] End"); + GE_CHK_STATUS_RET_NOLOG(Synchronize()); + RECORD_MODEL_EXECUTION_EVENT(&context_, "[Synchronize] End"); + GE_CHK_STATUS_RET_NOLOG(GetOutput(args)); + RECORD_MODEL_EXECUTION_EVENT(&context_, "[GetOutput] End"); + return SUCCESS; +} + +Status HybridModelExecutor::Cleanup() { + GELOGD("Start to cleanup."); + context_.callback_manager->Destroy(); + context_.cv_manager.Reset(); + context_.node_states.clear(); + context_.all_inputs.clear(); + context_.all_outputs.clear(); + context_.compile_queue.Clear(); + context_.execution_queue.Clear(); + RuntimeInferenceContext::DestroyContext(to_string(context_.session_id)); + GELOGD("Cleanup successfully."); + return SUCCESS; +} + +Status HybridModelExecutor::InitExecutionContext() { + context_.stream = stream_; + context_.model = model_; + context_.session_id = ::ge::GetContext().SessionId(); + GELOGD("session id from model = %lu, from context = %lu", model_->GetSessionId(), context_.session_id); + context_.allocator = 
NpuMemoryAllocator::GetAllocator(device_id_); + GE_CHECK_NOTNULL(context_.allocator); + context_.callback_manager = std::unique_ptr(new (std::nothrow) CallbackManager(stream_)); + GE_CHECK_NOTNULL(context_.callback_manager); + if (IsLogEnable(GE_MODULE_NAME, DLOG_DEBUG)) { + context_.trace_enabled = true; + } + + return SUCCESS; +} + +Status HybridModelExecutor::ResetExecutionContext(GraphExecutionContext &context) { + auto &model = *context.model; + context.all_inputs.resize(model.TotalInputs()); + context.all_outputs.resize(model.TotalOutputs()); + context.compile_queue.Restart(); + context.execution_queue.Restart(); + GE_CHK_STATUS_RET_NOLOG(context.callback_manager->Init()); + + // TODO do not re-assign Consts every run + for (auto const_node : model.GetConstNodes()) { + auto weight_tensor = model.GetWeight(const_node); + GE_CHECK_NOTNULL(weight_tensor); + for (auto &dst_aid_and_nid : const_node->outputs[0]) { + auto *dst_node_item = dst_aid_and_nid.second; + auto input_offset = dst_node_item->input_start + dst_aid_and_nid.first; + context.all_inputs[input_offset] = *weight_tensor; + } + } + + string ctx_id = std::to_string(context.session_id); + RuntimeInferenceContext::DestroyContext(ctx_id); + GE_CHK_GRAPH_STATUS_RET(RuntimeInferenceContext::CreateContext(ctx_id), "Failed to Destroy RuntimeInferenceContext"); + return SUCCESS; +} + +Status HybridModelExecutor::InitInputsAndOutputs(HybridModelExecutor::ExecuteArgs &args, + GraphExecutionContext &context) { + for (const auto &it : model_->GetInputNodes()) { + uint32_t input_index = it.first; + if (input_index >= args.inputs.size()) { + GELOGE(PARAM_INVALID, "Not enough inputs. NumInputs = %zu, but input index = %u", args.inputs.size(), + input_index); + return PARAM_INVALID; + } + + auto node_item = it.second; + auto &input_tensor = args.inputs[input_index]; + GELOGD("Set input tensor[%u] to inputs with index = %d, addr = %p, size = %zu", input_index, node_item->input_start, + input_tensor.GetData(), input_tensor.GetSize()); + context.all_inputs[node_item->input_start] = input_tensor; + } + + for (size_t i = 0; i < model_->GetOutputOffsets().size(); ++i) { + auto offset = model_->GetOutputOffsets()[i]; + if (i < args.outputs.size() && args.outputs[i].GetData() != nullptr) { + GELOGD("Use user allocated output memory. output index = %zu, output offset = %d", i, offset); + context.all_outputs[offset] = args.outputs[i]; + } + } + + return SUCCESS; +} + +Status HybridModelExecutor::Synchronize() { + GE_CHK_RT_RET(rtStreamSynchronize(stream_)); + return SUCCESS; +} + +Status HybridModelExecutor::GetOutput(HybridModelExecutor::ExecuteArgs &args) { + auto &net_output_input_offsets = model_->GetNetOutputInputOffsets(); + auto num_outputs = net_output_input_offsets.size(); + args.outputs.resize(num_outputs); + for (size_t i = 0; i < num_outputs; ++i) { + auto offset = net_output_input_offsets[i]; + GELOGI("Get output[%zu] from offset %d", i, offset); + args.outputs[i] = context_.all_inputs[offset]; + } + + return SUCCESS; +} +} // namespace hybrid +} // namespace ge diff --git a/src/ge/hybrid/executor/hybrid_model_executor.h b/src/ge/hybrid/executor/hybrid_model_executor.h new file mode 100644 index 00000000..2bda6331 --- /dev/null +++ b/src/ge/hybrid/executor/hybrid_model_executor.h @@ -0,0 +1,68 @@ +/** + * Copyright 2019-2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. 
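[Reviewer note] ExecuteGraphInternal() above wires a three-stage pipeline: the compile and shape-inference engines are started on the thread pool and exchange work through the context's compile_queue and execution_queue, the execution engine drains them on the calling thread, and rtStreamSynchronize plus GetOutput() finish the run. A generic sketch of one such stage, assuming a Pop(T&)/Push(T&&)/Stop() interface consistent with the BlockingQueue changes earlier in this patch (the exact Pop signature is an assumption):

// Generic pipeline stage: drain the upstream queue, forward downstream,
// and propagate shutdown. Illustrative; not the engines' actual code.
template <typename T>
void RunStage(BlockingQueue<T> &in, BlockingQueue<T> &out) {
  T item;
  while (in.Pop(item)) {  // assumed to return false once Stop() was called
    // ... stage-specific work on item (compile / infer shape / launch task) ...
    if (!out.Push(std::move(item))) {
      break;  // downstream stopped, e.g. after OnError()
    }
  }
  out.Stop();  // wake the next stage so it can drain and finish too
}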
+ * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef GE_HYBRID_EXECUTOR_HYBRID_MODEL_EXECUTOR_H_ +#define GE_HYBRID_EXECUTOR_HYBRID_MODEL_EXECUTOR_H_ +#include "common/thread_pool.h" +#include "graph/load/new_model_manager/data_inputer.h" +#include "hybrid/executor/hybrid_execution_context.h" +#include "hybrid/executor/rt_callback_manager.h" +#include "hybrid/executor/worker/execution_engine.h" +#include "hybrid/executor/worker/shape_inference_engine.h" +#include "hybrid/executor/worker/task_compile_engine.h" + +namespace ge { +namespace hybrid { +class HybridModelExecutor { + public: + struct ExecuteArgs { + std::vector inputs; + std::vector outputs; + }; + + HybridModelExecutor(HybridModel *model, uint32_t device_id, rtStream_t stream); + + ~HybridModelExecutor() = default; + + Status Init(); + + const GraphExecutionContext *GetContext() const { return &context_; } + + Status Execute(ExecuteArgs &args); + + private: + Status ExecuteGraphInternal(ExecuteArgs &args); + Status Cleanup(); + Status InitExecutionContext(); + static Status ResetExecutionContext(GraphExecutionContext &context); + + Status InitInputsAndOutputs(ExecuteArgs &args, GraphExecutionContext &context); + Status GetOutput(ExecuteArgs &args); + + Status Synchronize(); + + ThreadPool pool_; + HybridModel *model_; + uint32_t device_id_; + rtStream_t stream_; + GraphExecutionContext context_; + std::unique_ptr infer_shape_engine_; + std::unique_ptr compile_engine_; + std::unique_ptr execute_engine_; +}; +} // namespace hybrid +} // namespace ge +#endif // GE_HYBRID_EXECUTOR_HYBRID_MODEL_EXECUTOR_H_ diff --git a/src/ge/hybrid/executor/hybrid_profiler.cc b/src/ge/hybrid/executor/hybrid_profiler.cc new file mode 100644 index 00000000..1081a144 --- /dev/null +++ b/src/ge/hybrid/executor/hybrid_profiler.cc @@ -0,0 +1,84 @@ +/** + * Copyright 2019-2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "hybrid_profiler.h" +#include +#include +#include +#include "framework/common/debug/ge_log.h" +#include "securec.h" + +namespace ge { +namespace hybrid { +namespace { +const int kMaxEvents = 10000; +const int kEventDescMax = 256; +const int kMaxEventTypes = 8; +const int kIndent = 8; +} // namespace + +HybridProfiler::HybridProfiler() : counter_(0) { Reset(); } + +void HybridProfiler::RecordEvent(EventType event_type, const char *fmt, ...) 
{ + va_list args; + va_start(args, fmt); + + char buf[kEventDescMax]; + if (vsnprintf_s(buf, kEventDescMax, kEventDescMax - 1, fmt, args) == -1) { + GELOGE(FAILED, "Format %s failed.", fmt); + va_end(args); + return; + } + + va_end(args); + std::string event = buf; + auto index = counter_++; + auto &evt = events_[index]; + evt.timestamp = std::chrono::system_clock::now(); + evt.desc = std::move(event); + evt.event_type = event_type; +} + +void HybridProfiler::Dump(std::ostream &output_stream) { + if (events_.empty()) { + return; + } + + auto first_evt = events_[0]; + auto start = first_evt.timestamp; + output_stream << "Start " << first_evt.desc << std::endl; + std::vector prev_timestamps; + prev_timestamps.resize(kMaxEventTypes, start); + + for (int i = 1; i < counter_; ++i) { + auto &evt = events_[i]; + auto elapsed = std::chrono::duration_cast(evt.timestamp - start).count(); + auto &prev_ts = prev_timestamps[evt.event_type]; + auto cost = std::chrono::duration_cast(evt.timestamp - prev_ts).count(); + prev_ts = evt.timestamp; + output_stream << std::setw(kIndent) << elapsed << "\t\t" << cost << "\t\t" << evt.desc << std::endl; + } + + events_.clear(); +} + +void HybridProfiler::Reset() { + counter_ = 0; + events_.clear(); + events_.resize(kMaxEvents); +} +} // namespace hybrid +} // namespace ge diff --git a/src/ge/hybrid/executor/hybrid_profiler.h b/src/ge/hybrid/executor/hybrid_profiler.h new file mode 100644 index 00000000..6f6794f4 --- /dev/null +++ b/src/ge/hybrid/executor/hybrid_profiler.h @@ -0,0 +1,60 @@ +/** + * Copyright 2019-2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef GE_HYBRID_EXECUTOR_HYBRID_PROFILER_H_ +#define GE_HYBRID_EXECUTOR_HYBRID_PROFILER_H_ + +#include +#include +#include +#include +#include +#include + +namespace ge { +namespace hybrid { +class HybridProfiler { + public: + enum EventType { + GENERAL, + SHAPE_INFERENCE, + COMPILE, + EXECUTION, + CALLBACK, + }; + + struct Event { + std::chrono::system_clock::time_point timestamp; + EventType event_type; + std::string desc; + }; + + HybridProfiler(); + ~HybridProfiler() = default; + + void RecordEvent(EventType event_type, const char *fmt, ...); + + void Reset(); + + void Dump(std::ostream &os); + + private: + std::vector events_; + std::atomic_int counter_; +}; +} // namespace hybrid +} // namespace ge +#endif // GE_HYBRID_EXECUTOR_HYBRID_PROFILER_H_ diff --git a/src/ge/hybrid/executor/node_done_manager.cc b/src/ge/hybrid/executor/node_done_manager.cc new file mode 100644 index 00000000..dfeddb5b --- /dev/null +++ b/src/ge/hybrid/executor/node_done_manager.cc @@ -0,0 +1,80 @@ +/** + * Copyright 2019-2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. 
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "hybrid/executor/node_done_manager.h"
+#include "framework/common/debug/ge_log.h"
+
+namespace ge {
+namespace hybrid {
+bool NodeDoneManager::Cond::Await() {
+  std::unique_lock<std::mutex> lk(mu_);
+  cv_.wait(lk, [&]() { return is_released_ || is_cancelled_; });
+  return is_released_;
+}
+
+void NodeDoneManager::Cond::Release() {
+  std::unique_lock<std::mutex> lk(mu_);
+  is_released_ = true;
+  cv_.notify_all();
+}
+
+void NodeDoneManager::Cond::Cancel() {
+  std::unique_lock<std::mutex> lk(mu_);
+  is_cancelled_ = true;
+  cv_.notify_all();
+}
+
+bool NodeDoneManager::Cond::IsRelease() {
+  std::unique_lock<std::mutex> lk(mu_);
+  return is_released_;
+}
+
+NodeDoneManager::Cond *NodeDoneManager::GetSubject(const NodePtr &node) {
+  std::lock_guard<std::mutex> lk(mu_);
+  auto it = subjects_.find(node);
+  if (it == subjects_.end()) {
+    return &subjects_[node];
+  }
+
+  return &it->second;
+}
+
+void NodeDoneManager::Reset() {
+  std::lock_guard<std::mutex> lk(mu_);
+  for (auto &sub : subjects_) {
+    if (!sub.second.IsRelease()) {
+      sub.second.Cancel();
+      GELOGD("[%s] Node canceled.", sub.first->GetName().c_str());
+    }
+  }
+
+  subjects_.clear();
+}
+
+void NodeDoneManager::NodeDone(const NodePtr &node) {
+  GetSubject(node)->Release();
+  GELOGD("[%s] Node released.", node->GetName().c_str());
+}
+
+bool NodeDoneManager::Await(const NodePtr &node) {
+  auto sub = GetSubject(node);
+  GELOGD("[%s] Await start. is_released = %s", node->GetName().c_str(), sub->IsRelease() ? "true" : "false");
+  bool ret = sub->Await();
+  GELOGD("[%s] Await ended. is_released = %s", node->GetName().c_str(), sub->IsRelease() ? "true" : "false");
+  return ret;
+}
+} // namespace hybrid
+} // namespace ge
diff --git a/src/ge/hybrid/executor/node_done_manager.h b/src/ge/hybrid/executor/node_done_manager.h
new file mode 100644
index 00000000..ccf263d1
--- /dev/null
+++ b/src/ge/hybrid/executor/node_done_manager.h
@@ -0,0 +1,58 @@
+/**
+ * Copyright 2019-2020 Huawei Technologies Co., Ltd
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#ifndef GE_HYBRID_EXECUTOR_NODE_DONE_COND_MANAGER_H_
+#define GE_HYBRID_EXECUTOR_NODE_DONE_COND_MANAGER_H_
+
+#include <condition_variable>
+#include <mutex>
+#include <unordered_map>
+#include "graph/node.h"
+
+namespace ge {
+namespace hybrid {
+class NodeDoneManager {
+ public:
+  void NodeDone(const NodePtr &node);
+
+  bool Await(const NodePtr &node);
+
+  void Reset();
+
+ private:
+  class Cond {
+   public:
+    bool IsRelease();
+    void Release();
+    void Cancel();
+    bool Await();
+
+   private:
+    std::mutex mu_;
+    std::condition_variable cv_;
+    bool is_released_ = false;
+    bool is_cancelled_ = false;
+  };
+
+  Cond *GetSubject(const NodePtr &node);
+  std::mutex mu_;
+  std::unordered_map<NodePtr, Cond> subjects_;
+};
+} // namespace hybrid
+} // namespace ge
+
+#endif // GE_HYBRID_EXECUTOR_NODE_DONE_COND_MANAGER_H_
diff --git a/src/ge/hybrid/executor/node_state.cc b/src/ge/hybrid/executor/node_state.cc
new file mode 100644
index 00000000..6895f158
--- /dev/null
+++ b/src/ge/hybrid/executor/node_state.cc
@@ -0,0 +1,27 @@
+/**
+ * Copyright 2019-2020 Huawei Technologies Co., Ltd
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "hybrid/executor/node_state.h"
+#include "graph/compute_graph.h"
+
+namespace ge {
+namespace hybrid {
+NodeState::NodeState(const NodeItem &node_item) {
+  this->node_item = &node_item;
+  this->op_desc = node_item.node->GetOpDesc();
+}
+} // namespace hybrid
+} // namespace ge
\ No newline at end of file
diff --git a/src/ge/hybrid/executor/node_state.h b/src/ge/hybrid/executor/node_state.h
new file mode 100644
index 00000000..b2811bcb
--- /dev/null
+++ b/src/ge/hybrid/executor/node_state.h
@@ -0,0 +1,58 @@
+/**
+ * Copyright 2019-2020 Huawei Technologies Co., Ltd
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#ifndef GE_HYBRID_EXECUTOR_NODE_STATE_H_
+#define GE_HYBRID_EXECUTOR_NODE_STATE_H_
+
+#include "hybrid/model/node_item.h"
+
+namespace ge {
+namespace hybrid {
+
+class NodeTask;
+
+// Holds the parts of a node's state that change from one execution to the next.
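+// A NodeState is created per node via GraphExecutionContext::GetOrCreateNodeState
+// and flows through the three workers: the shape inference engine pushes it into
+// compile_queue, TaskCompileEngine fills in kernel_task and forwards it to
+// execution_queue, and ExecutionEngine finally launches the task.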
+class NodeState { + public: + NodeState() = default; + explicit NodeState(const NodeItem &node_item); + ~NodeState() = default; + + inline int NodeId() const { return node_item->node_id; } + + inline Node *GetNode() const { return node_item->node.get(); } + + OpDesc *GetOpDesc() const { return op_desc.get(); } + + inline const NodeItem *GetNodeItem() const { return node_item; } + + inline const string &GetName() const { return node_item->NodeName(); } + + inline const string &GetType() const { return node_item->NodeType(); } + + // private: + const NodeItem *node_item = nullptr; + std::shared_ptr kernel_task = nullptr; + + bool is_compiled = false; + OpDescPtr op_desc; +}; + +using NodeStatePtr = std::shared_ptr; +} // namespace hybrid +} // namespace ge + +#endif // GE_HYBRID_EXECUTOR_NODE_STATE_H_ diff --git a/src/ge/hybrid/executor/rt_callback_manager.cc b/src/ge/hybrid/executor/rt_callback_manager.cc new file mode 100644 index 00000000..1787cf77 --- /dev/null +++ b/src/ge/hybrid/executor/rt_callback_manager.cc @@ -0,0 +1,114 @@ +/** + * Copyright 2019-2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "hybrid/executor/rt_callback_manager.h" +#include "framework/common/ge_inner_error_codes.h" +#include "framework/common/debug/ge_log.h" +#include "framework/common/util.h" + +namespace ge { +namespace hybrid { +CallbackManager::CallbackManager(rtStream_t stream) : stream_(stream) {} + +Status CallbackManager::RegisterCallback(rtCallback_t callback, void *user_data) { + GELOGD("To register callback"); + rtEvent_t event = nullptr; + GE_CHK_RT_RET(rtEventCreate(&event)); + GE_CHK_RT_RET(rtEventRecord(event, stream_)); + auto cb = std::pair(callback, user_data); + auto entry = std::pair>(event, std::move(cb)); + if (!callback_queue_.Push(entry)) { + return INTERNAL_ERROR; + } + + GELOGD("Registering callback successfully"); + return SUCCESS; +} + +Status CallbackManager::Init() { + rtContext_t ctx = nullptr; + GE_CHK_RT_RET(rtCtxGetCurrent(&ctx)); + ret_future_ = std::async([&](rtContext_t context) -> Status { return CallbackProcess(context); }, ctx); + + if (!ret_future_.valid()) { + GELOGE(INTERNAL_ERROR, "Failed to init callback manager."); + return INTERNAL_ERROR; + } + + return SUCCESS; +} + +Status CallbackManager::CallbackProcess(rtContext_t context) { + GE_CHK_RT_RET(rtCtxSetCurrent(context)); + std::pair> entry; + while (true) { + if (!callback_queue_.Pop(entry)) { + GELOGI("CallbackManager stopped"); + return INTERNAL_ERROR; + } + + auto event = entry.first; + if (event == nullptr) { + return SUCCESS; + } + + auto rt_err = rtEventSynchronize(event); + if (rt_err != RT_ERROR_NONE) { + GELOGE(RT_FAILED, "rtEventSynchronize failed. 
ret = %d", rt_err); + GE_CHK_RT(rtEventDestroy(event)); + return RT_FAILED; + } + + // TODO reuse event + GE_CHK_RT(rtEventDestroy(event)); + + auto cb_func = entry.second.first; + auto cb_args = entry.second.second; + cb_func(cb_args); + } +} + +Status CallbackManager::Destroy() { + GELOGI("To destroy callback manager."); + if (!ret_future_.valid()) { + GELOGI("CallbackManager not initialized."); + return SUCCESS; + } + + std::pair> eof_entry; + eof_entry.first = nullptr; + callback_queue_.Push(eof_entry); + + auto ret = ret_future_.get(); + GELOGI("Callback manager ended. ret = %u", ret); + return ret; +} + +void CallbackManager::RtCallbackFunc(void *data) { + GELOGD("To invoke callback function"); + auto callback_func = reinterpret_cast *>(data); + (*callback_func)(); + delete callback_func; +} + +Status CallbackManager::RegisterCallback(const std::function &callback) { + auto *func = new (std::nothrow) std::function(callback); + GE_CHECK_NOTNULL(func); + GELOGD("Callback registered"); + return RegisterCallback(RtCallbackFunc, func); +} +} // namespace hybrid +} // namespace ge \ No newline at end of file diff --git a/src/ge/hybrid/executor/rt_callback_manager.h b/src/ge/hybrid/executor/rt_callback_manager.h new file mode 100644 index 00000000..f102d660 --- /dev/null +++ b/src/ge/hybrid/executor/rt_callback_manager.h @@ -0,0 +1,55 @@ +/** + * Copyright 2019-2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef GE_HYBRID_EXECUTOR_RT_CALLBACK_MANAGER_H_ +#define GE_HYBRID_EXECUTOR_RT_CALLBACK_MANAGER_H_ + +#include +#include +#include +#include + +#include "common/blocking_queue.h" +#include "ge/ge_api_error_codes.h" +#include "runtime/rt.h" + +namespace ge { +namespace hybrid { +class CallbackManager { + public: + explicit CallbackManager(rtStream_t stream); + + ~CallbackManager() = default; + + Status Init(); + + Status Destroy(); + + Status RegisterCallback(rtCallback_t callback, void *user_data); + Status RegisterCallback(const std::function &callback); + + private: + Status CallbackProcess(rtContext_t context); + static void RtCallbackFunc(void *data); + + BlockingQueue>> callback_queue_; + rtStream_t stream_; + std::future ret_future_; +}; +} // namespace hybrid +} // namespace ge + +#endif // GE_HYBRID_EXECUTOR_RT_CALLBACK_MANAGER_H_ diff --git a/src/ge/hybrid/executor/worker/execution_engine.cc b/src/ge/hybrid/executor/worker/execution_engine.cc new file mode 100644 index 00000000..f4657cd4 --- /dev/null +++ b/src/ge/hybrid/executor/worker/execution_engine.cc @@ -0,0 +1,201 @@ +/** + * Copyright 2019-2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. 
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "hybrid/executor/worker/execution_engine.h"
+#include <memory>
+#include "graph/runtime_inference_context.h"
+#include "graph/utils/tensor_utils.h"
+#include "graph/utils/tensor_adapter.h"
+#include "hybrid/node_executor/node_executor.h"
+
+namespace ge {
+namespace hybrid {
+class NodeDoneCallback {
+ public:
+  NodeDoneCallback(GraphExecutionContext *graph_context, std::shared_ptr<TaskContext> &task_context);
+  ~NodeDoneCallback() = default;
+  Status OnNodeDone();
+
+ private:
+  Status PrepareConstInputs(const NodeItem &node_item);
+  GraphExecutionContext *graph_context_;
+  std::shared_ptr<TaskContext> context_;
+};
+
+NodeDoneCallback::NodeDoneCallback(GraphExecutionContext *graph_context, std::shared_ptr<TaskContext> &task_context)
+    : graph_context_(graph_context), context_(task_context) {}
+
+Status NodeDoneCallback::PrepareConstInputs(const NodeItem &node_item) {
+  for (auto output_idx : node_item.to_const_output_id_list) {
+    RECORD_CALLBACK_EVENT(graph_context_, node_item.NodeName().c_str(), "[PrepareConstInputs] [index = %d] Start",
+                          output_idx);
+
+    auto output_tensor = context_->GetOutput(output_idx);
+    GE_CHECK_NOTNULL(output_tensor);
+
+    vector<uint8_t> host_buffer(output_tensor->GetSize());
+    GELOGD("[%s] To cache output[%d] to host, size = %zu", node_item.NodeName().c_str(), output_idx,
+           output_tensor->GetSize());
+    GE_CHK_RT_RET(rtMemcpy(host_buffer.data(), host_buffer.size(), output_tensor->GetData(), output_tensor->GetSize(),
+                           RT_MEMCPY_DEVICE_TO_HOST));
+    Tensor tensor;
+    tensor.SetData(host_buffer);
+    auto ge_tensor_desc = node_item.op_desc->MutableOutputDesc(output_idx);
+    GE_CHECK_NOTNULL(ge_tensor_desc);
+    tensor.SetTensorDesc(TensorAdapter::GeTensorDesc2TensorDesc(*ge_tensor_desc));
+
+    string session_id = std::to_string(context_->GetSessionId());
+    RuntimeInferenceContext *runtime_infer_ctx = nullptr;
+    GE_CHK_GRAPH_STATUS_RET(RuntimeInferenceContext::GetContext(session_id, &runtime_infer_ctx),
+                            "Failed to get RuntimeInferenceContext, session_id = %s", session_id.c_str());
+    GE_CHK_STATUS_RET(runtime_infer_ctx->SetTensor(node_item.node_id, output_idx, std::move(tensor)),
+                      "Failed to SetTensor, node = %s, output_index = %d", node_item.NodeName().c_str(), output_idx);
+    GELOGD("[%s] Output[%d] cached successfully in session: %s. node_id = %d, shape = [%s]",
+           node_item.NodeName().c_str(), output_idx, session_id.c_str(), node_item.node_id,
+           ge_tensor_desc->GetShape().ToString().c_str());
+
+    RECORD_CALLBACK_EVENT(graph_context_, node_item.NodeName().c_str(), "[PrepareConstInputs] [index = %d] End",
+                          output_idx);
+  }
+
+  return SUCCESS;
+}
+
+Status NodeDoneCallback::OnNodeDone() {
+  auto &node_item = context_->GetNodeItem();
+  GELOGI("[%s] Start callback process.", node_item.NodeName().c_str());
+  RECORD_CALLBACK_EVENT(graph_context_, context_->GetNodeName(), "Start");
+
+  // release inputs
+  for (int i = 0; i < context_->NumInputs(); ++i) {
+    context_->ReleaseInput(i);
+  }
+
+  GE_CHK_STATUS_RET_NOLOG(PrepareConstInputs(node_item));
+  // PropagateOutputs for type == DEPEND_COMPUTE
+  if (node_item.shape_inference_type == DEPEND_COMPUTE) {
+    GE_CHK_STATUS_RET(context_->PropagateOutputs(), "[%s] Failed to propagate outputs",
+                      node_item.NodeName().c_str());
+
+    RECORD_CALLBACK_EVENT(graph_context_, context_->GetNodeName(), "[PropagateOutputs] End");
+  }
+
+  // release
+  if (node_item.has_observer) {
+    GELOGI("[%s] Notify observer. node_id = %d", node_item.NodeName().c_str(), node_item.node_id);
+    graph_context_->cv_manager.NodeDone(node_item.node);
+  }
+
+  RECORD_CALLBACK_EVENT(graph_context_, context_->GetNodeName(), "[Callback] End");
+  return SUCCESS;
+}
+
+ExecutionEngine::ExecutionEngine(GraphExecutionContext *context, CallbackManager *callback_manager)
+    : context_(context), callback_manager_(callback_manager) {}
+
+Status ExecutionEngine::Start() {
+  GE_CHK_STATUS_RET_NOLOG(ExecutionProcess());
+  return SUCCESS;
+}
+
+Status ExecutionEngine::ExecutionProcess() {
+  GELOGI("ExecutorEngine worker started");
+  auto &ready_queue = context_->execution_queue;
+  while (true) {
+    NodeStatePtr node_state = nullptr;
+    if (!ready_queue.Pop(node_state)) {
+      GELOGE(FAILED, "Pop task failed");
+      return FAILED;
+    }
+
+    // EOF
+    if (node_state == nullptr) {
+      break;
+    }
+
+    RECORD_EXECUTION_EVENT(context_, node_state->GetName().c_str(), "Start");
+    GELOGI("[%s] Node is ready for execution", node_state->GetName().c_str());
+    auto *node_item = node_state->node_item;
+    auto task_context = TaskContext::Create(*node_item, context_);
+    GE_CHECK_NOTNULL(task_context);
+    auto shared_task_context = shared_ptr<TaskContext>(task_context.release());
+
+    auto cb = std::shared_ptr<NodeDoneCallback>(new (std::nothrow) NodeDoneCallback(context_, shared_task_context));
+    GE_CHECK_NOTNULL(cb);
+    auto callback = [&, cb]() {
+      auto ret = cb->OnNodeDone();
+      if (ret != SUCCESS) {
+        context_->OnError(ret);
+      }
+    };
+
+    GE_CHK_STATUS_RET_NOLOG(ExecuteAsync(*node_state, *shared_task_context, callback));
+    GE_CHK_STATUS_RET_NOLOG(PropagateOutputs(*node_item, *shared_task_context));
+  }
+
+  GELOGI("ExecutorEngine worker ended.");
+  return SUCCESS;
+}
+
+Status ExecutionEngine::ExecuteAsync(NodeState &node_state, TaskContext &task_context,
+                                     const std::function<void()> &callback) {
+  const auto &task = node_state.kernel_task;
+  if (task == nullptr) {
+    GELOGE(INTERNAL_ERROR, "[%s] NodeTask is null.", node_state.GetName().c_str());
+    return INTERNAL_ERROR;
+  }
+
+  RECORD_EXECUTION_EVENT(context_, task_context.GetNodeName(), "[PrepareTask] Start");
+  auto executor = node_state.node_item->node_executor;
+  GE_CHK_STATUS_RET(executor->PrepareTask(*task, task_context), "[%s] Failed to prepare task",
+                    node_state.GetName().c_str());
+  RECORD_EXECUTION_EVENT(context_, task_context.GetNodeName(), "[PrepareTask] End");
+  GELOGD("[%s] Done task preparation successfully.", node_state.GetName().c_str());
+
+  if 
(context_->trace_enabled) { + for (auto i = 0; i < task_context.NumInputs(); ++i) { + const auto &input_tensor = task_context.GetInput(i); + GE_CHECK_NOTNULL(input_tensor); + GELOGD("[%s] Tensor of input[%d] = %s", node_state.GetName().c_str(), i, input_tensor->DebugString().c_str()); + } + + for (auto i = 0; i < task_context.NumOutputs(); ++i) { + const auto &output_tensor = task_context.GetOutput(i); + GE_CHECK_NOTNULL(output_tensor); + GELOGD("[%s] Tensor of output[%d] = %s", node_state.GetName().c_str(), i, output_tensor->DebugString().c_str()); + } + } + + RECORD_EXECUTION_EVENT(context_, task_context.GetNodeName(), "[ExecuteTask] Start"); + GE_CHK_STATUS_RET(executor->ExecuteTask(*task, task_context, callback), "[%s] Failed to execute task", + node_state.GetName().c_str()); + RECORD_EXECUTION_EVENT(context_, task_context.GetNodeName(), "[ExecuteTask] End"); + + GELOGD("[%s] Done task launch successfully.", node_state.GetName().c_str()); + return SUCCESS; +} + +Status ExecutionEngine::PropagateOutputs(const NodeItem &node_item, TaskContext &task_context) { + if (node_item.shape_inference_type != DEPEND_COMPUTE) { + GE_CHK_STATUS_RET(task_context.PropagateOutputs(), "[%s] Failed to propagate outputs.", + node_item.NodeName().c_str()); + RECORD_EXECUTION_EVENT(context_, task_context.GetNodeName(), "[PropagateOutputs] End"); + } + + GELOGD("[%s] Done propagating outputs successfully.", node_item.NodeName().c_str()); + return SUCCESS; +} +} // namespace hybrid +} // namespace ge \ No newline at end of file diff --git a/src/ge/hybrid/executor/worker/execution_engine.h b/src/ge/hybrid/executor/worker/execution_engine.h new file mode 100644 index 00000000..f5f317af --- /dev/null +++ b/src/ge/hybrid/executor/worker/execution_engine.h @@ -0,0 +1,47 @@ +/** + * Copyright 2019-2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +#ifndef GE_HYBRID_EXECUTOR_EXECUTOR_EXECUTION_ENGINE_H_ +#define GE_HYBRID_EXECUTOR_EXECUTOR_EXECUTION_ENGINE_H_ + +#include "common/thread_pool.h" +#include "hybrid/common/npu_memory_allocator.h" +#include "hybrid/executor/hybrid_execution_context.h" +#include "hybrid/executor/rt_callback_manager.h" +#include "hybrid/node_executor/task_context.h" + +namespace ge { +namespace hybrid { +class ExecutionEngine { + public: + explicit ExecutionEngine(GraphExecutionContext *context, CallbackManager *callback_manager); + ~ExecutionEngine() = default; + + Status Start(); + + private: + Status PropagateOutputs(const NodeItem &node_item, TaskContext &task_context); + + Status ExecutionProcess(); + + Status ExecuteAsync(NodeState &node_state, TaskContext &task_context, const std::function &callback); + + GraphExecutionContext *context_; + CallbackManager *callback_manager_; +}; +} // namespace hybrid +} // namespace ge +#endif // GE_HYBRID_EXECUTOR_EXECUTOR_EXECUTION_ENGINE_H_ diff --git a/src/ge/hybrid/executor/worker/shape_inference_engine.cc b/src/ge/hybrid/executor/worker/shape_inference_engine.cc new file mode 100644 index 00000000..90082fff --- /dev/null +++ b/src/ge/hybrid/executor/worker/shape_inference_engine.cc @@ -0,0 +1,258 @@ +/** + * Copyright 2019-2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "hybrid/executor/worker/shape_inference_engine.h" + +#include "graph/shape_refiner.h" +#include "graph/runtime_inference_context.h" +#include "graph/utils/node_utils.h" +#include "hybrid/node_executor/node_executor.h" + +namespace ge { +namespace hybrid { + +ShapeInferenceEngine::ShapeInferenceEngine(GraphExecutionContext *context) : context_(context) {} + +Status ShapeInferenceEngine::Start(ThreadPool &pool) { + GELOGI("RuntimeShapeInferenceEngine start."); + pool.commit([&]() { + auto ret = this->InferShapeProcess(); + InferenceDone(ret); + }); + + return SUCCESS; +} + +Status ShapeInferenceEngine::InferShapeProcess() { + GELOGI("RuntimeShapeInferenceEngine worker start."); + const auto &root_nodes = context_->model->RootNodes(); + auto &complete_queue = context_->compile_queue; + std::queue ready_nodes; + for (auto &node_item : root_nodes) { + auto infer_state = GetOrCreateEntry(*node_item); + GE_CHECK_NOTNULL(infer_state); + ready_nodes.emplace(infer_state); + } + + while (!ready_nodes.empty()) { + InferenceState *infer_state = ready_nodes.front(); + ready_nodes.pop(); + auto node_item = infer_state->node_item; + // even for non-dynamic shape node, it is still necessary to wait for pending shapes if got any. + // which indicates that the parent node is of type 4, in which case the inputs will be valid only + // when computing is done. 
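+    // Pending shapes come from UpdateInputShapeFuture below: when a producer's
+    // shape_inference_type is DEPEND_SHAPE_RANGE or DEPEND_COMPUTE, its output
+    // shape is only known after its kernel has run, so AwaitShapeFutures blocks
+    // on the NodeDoneManager until the producer signals completion.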
+ GE_CHK_STATUS_RET(infer_state->AwaitShapeFutures(context_), "Await shape failed."); + GELOGI("[%s] Node is ready for shape inference.", node_item.NodeName().c_str()); + if (node_item.is_dynamic) { + // may block + RECORD_SHAPE_INFERENCE_EVENT(context_, node_item.NodeName().c_str(), "Start"); + GELOGI("[%s] Start to invoke InferShape", node_item.NodeName().c_str()); + auto ret = InferShape(*infer_state); + if (ret != SUCCESS) { + return ret; + } + + RECORD_SHAPE_INFERENCE_EVENT(context_, node_item.NodeName().c_str(), "[CalcOpRunningParam] Start"); + GE_CHK_STATUS_RET(NodeExecutorManager::GetInstance().CalcOpRunningParam(*node_item.node), + "[%s] Failed to invoke CalcOpRunningParam.", node_item.NodeName().c_str()); + RECORD_SHAPE_INFERENCE_EVENT(context_, node_item.NodeName().c_str(), "[CalcOpRunningParam] End"); + } else { + GELOGD("[%s] Skip static shape node", node_item.NodeName().c_str()); + } + + if (node_item.node_type != NETOUTPUT) { + GELOGI("[%s] Push to compile queue", node_item.NodeName().c_str()); + // may block if full + auto node_state = context_->GetOrCreateNodeState(node_item.node); + complete_queue.Push(node_state); + } + + // Propagate + RECORD_SHAPE_INFERENCE_EVENT(context_, node_item.NodeName().c_str(), "[PropagateOutputShapes] Start"); + PropagateOutputShapes(*infer_state, ready_nodes); + RECORD_SHAPE_INFERENCE_EVENT(context_, node_item.NodeName().c_str(), "[PropagateOutputShapes] End"); + } + + return SUCCESS; +} + +void ShapeInferenceEngine::InferenceDone(Status status) { + if (status != SUCCESS) { + GELOGE(status, "Error occurred while shape inference"); + context_->OnError(status); + } else { + context_->compile_queue.Push(nullptr); + } + inference_states_.clear(); + GELOGI("RuntimeShapeInferenceEngine worker END"); +} + +Status ShapeInferenceEngine::InferShape(InferenceState &entry) { + // input shapes are ready, wait for dependent data if has any + const auto &node_item = entry.node_item; + if (!node_item.dependent_node_list.empty()) { + for (auto &src_node : node_item.dependent_node_list) { + auto *src_node_item = context_->model->GetNodeItem(src_node); + GELOGI("[%s] Start to wait for data dependent node: %s", node_item.NodeName().c_str(), + src_node_item->NodeName().c_str()); + RECORD_SHAPE_INFERENCE_EVENT(context_, node_item.NodeName().c_str(), "[AwaitNodeDone] [%s] Start", + src_node->GetName().c_str()); + if (!context_->cv_manager.Await(src_node)) { + GELOGE(INTERNAL_ERROR, "[%s] Await node failed.", src_node_item->NodeName().c_str()); + return INTERNAL_ERROR; + } + + RECORD_SHAPE_INFERENCE_EVENT(context_, node_item.NodeName().c_str(), "[AwaitNodeDone] [%s] End", + src_node->GetName().c_str()); + GELOGI("[%s] Done waiting node.", src_node_item->NodeName().c_str()); + } + } + + if (node_item.shape_inference_type == DEPEND_COMPUTE) { + GELOGD("[%s] Skip node with unknown shape type DEPEND_COMPUTE", node_item.NodeName().c_str()); + return SUCCESS; + } + + if (node_item.shape_inference_type == DEPEND_SHAPE_RANGE) { + // in case InferFunc forgot to reset output shape + for (auto &output_desc : node_item.op_desc->GetAllOutputsDescPtr()) { + output_desc->SetShape(GeShape({UNKNOWN_DIM_NUM})); + } + } + + // do shape inference + RECORD_SHAPE_INFERENCE_EVENT(context_, node_item.NodeName().c_str(), "[InferShape] Start"); + GELOGD("[%s] Start to invoke InferShapeAndType", node_item.NodeName().c_str()); + GE_CHK_STATUS_RET(ShapeRefiner::InferShapeAndType(node_item.node), "Invoke InferShapeAndType failed."); + RECORD_SHAPE_INFERENCE_EVENT(context_, 
node_item.NodeName().c_str(), "[InferShape] End"); + + // Check shape + if (node_item.shape_inference_type != DEPEND_SHAPE_RANGE) { + bool is_unknown_shape = false; + GE_CHK_STATUS_RET(NodeUtils::GetNodeUnknownShapeStatus(*node_item.node, is_unknown_shape), + "Failed to get shape status. node = %s", node_item.NodeName().c_str()); + + GE_CHK_BOOL_RET_STATUS(!is_unknown_shape, INTERNAL_ERROR, "[%s] Shape is still unknown after shape inference.", + node_item.NodeName().c_str()); + } + + GELOGD("[%s] InferShapeAndType finished successfully.", node_item.NodeName().c_str()); + return SUCCESS; +} + +void ShapeInferenceEngine::PropagateOutputShapes(InferenceState &entry, std::queue &queue) { + auto &node_item = entry.node_item; + // output shape will not be valid until compute is done. + bool shape_is_future = + node_item.shape_inference_type == DEPEND_SHAPE_RANGE || node_item.shape_inference_type == DEPEND_COMPUTE; + GELOGD("[%s] Start to propagate output shapes. shape_type = %d", node_item.NodeName().c_str(), + node_item.shape_inference_type); + + // propagate each output + for (int i = 0; i < node_item.num_outputs; ++i) { + auto output_desc = node_item.op_desc->MutableOutputDesc(i); + const auto &shape = output_desc->MutableShape(); + const auto &ori_shape = output_desc->GetOriginShape(); + auto &output_nodes = node_item.outputs[i]; + + // propagate output to all sub-inputs + for (auto &dst_input_index_and_node : output_nodes) { + auto &dst_node_item = dst_input_index_and_node.second; + auto inference_state = GetOrCreateEntry(*dst_node_item); + GELOGI("[%s] Update dst node [%s], input index = %d", node_item.NodeName().c_str(), + dst_node_item->NodeName().c_str(), dst_input_index_and_node.first); + + // in case type 3/4, shape will be valid after computing is done + if (shape_is_future) { + ShapeFuture future(node_item.node, i, &context_->cv_manager); + inference_state->UpdateInputShapeFuture(dst_input_index_and_node.first, std::move(future)); + } else { + inference_state->UpdateInputShape(dst_input_index_and_node.first, ori_shape, shape); + } + + if (inference_state->IsInputShapesReady()) { + GELOGI("[%s] Node input shape is ready, add to queue.", inference_state->node_item.NodeName().c_str()); + queue.emplace(inference_state); + } + } + } + + GELOGD("[%s] Propagating output shapes finished successfully.", node_item.NodeName().c_str()); +} + +ShapeInferenceEngine::InferenceState *ShapeInferenceEngine::GetOrCreateEntry(const NodeItem &node_item) { + auto &node_state = inference_states_[node_item.node_id]; + if (node_state == nullptr) { + node_state.reset(new (std::nothrow) InferenceState(node_item)); + } + + return node_state.get(); +} + +ShapeInferenceEngine::InferenceState::InferenceState(const NodeItem &node_item) : node_item(node_item) { + this->num_pending_shapes = node_item.num_inputs; +} + +void ShapeInferenceEngine::InferenceState::UpdateInputShape(uint32_t idx, const GeShape &ori_shape, + const GeShape &shape) { + if (node_item.const_input_shapes.count(idx) != 0) { + GELOGD("[%s] Trying to update constant shape, idx = %u. 
old shape = [%s], new shape = [%s]", + node_item.NodeName().c_str(), idx, node_item.op_desc->MutableInputDesc(idx)->GetShape().ToString().c_str(), + shape.ToString().c_str()); + } + + GELOGD("[%s] Update input shape [%u] with Shape: [%s] and OriginalShape: [%s]", node_item.NodeName().c_str(), idx, + shape.ToString().c_str(), ori_shape.ToString().c_str()); + num_pending_shapes -= 1; + node_item.op_desc->MutableInputDesc(idx)->SetShape(shape); + node_item.op_desc->MutableInputDesc(idx)->SetOriginShape(ori_shape); +} + +void ShapeInferenceEngine::InferenceState::UpdateInputShapeFuture(uint32_t idx, ShapeFuture &&future) { + if (node_item.const_input_shapes.count(idx) != 0) { + GELOGE(INTERNAL_ERROR, "[%s] Trying to update constant shape, idx = %u", node_item.NodeName().c_str(), idx); + return; + } + + GELOGD("[%s] Update input shape [%u] with ShapeFuture.", node_item.NodeName().c_str(), idx); + num_pending_shapes -= 1; + shape_futures.emplace_back(idx, std::move(future)); +} + +Status ShapeInferenceEngine::InferenceState::AwaitShapeFutures(GraphExecutionContext *context) { + for (auto &p : shape_futures) { + auto idx = p.first; + auto &future = p.second; + GeShape shape; + GeShape ori_shape; + RECORD_SHAPE_INFERENCE_EVENT(context, node_item.NodeName().c_str(), "[AwaitShape] [idx = %u] Start", idx); + GE_CHK_STATUS_RET(future.Get(ori_shape, shape), "[%s] Get shape failed. index = %u", node_item.NodeName().c_str(), + idx); + RECORD_SHAPE_INFERENCE_EVENT(context, node_item.NodeName().c_str(), "[AwaitShape] [idx = %u] End", idx); + + GELOGD("[%s] Update input shape [%u] with shape: [%s] and ori_shape: [%s]", node_item.NodeName().c_str(), idx, + shape.ToString().c_str(), ori_shape.ToString().c_str()); + node_item.op_desc->MutableInputDesc(idx)->SetShape(std::move(shape)); + node_item.op_desc->MutableInputDesc(idx)->SetOriginShape(ori_shape); + } + + return SUCCESS; +} + +ShapeInferenceEngine::ShapeFuture::ShapeFuture(NodePtr src_node, uint32_t src_index, NodeDoneManager *node_done_manager) + : src_node_(std::move(src_node)), src_index_(src_index), node_done_manager_(node_done_manager) {} +} // namespace hybrid +} // namespace ge \ No newline at end of file diff --git a/src/ge/hybrid/executor/worker/shape_inference_engine.h b/src/ge/hybrid/executor/worker/shape_inference_engine.h new file mode 100644 index 00000000..b1e1c879 --- /dev/null +++ b/src/ge/hybrid/executor/worker/shape_inference_engine.h @@ -0,0 +1,92 @@ +/** + * Copyright 2019-2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +#ifndef GE_HYBRID_EXECUTOR_INFERSHAPE_SHAPE_INFERENCE_ENGINE_H_ +#define GE_HYBRID_EXECUTOR_INFERSHAPE_SHAPE_INFERENCE_ENGINE_H_ + +#include +#include +#include +#include "common/thread_pool.h" +#include "hybrid/executor/hybrid_execution_context.h" + +namespace ge { +namespace hybrid { +class ShapeInferenceEngine { + public: + explicit ShapeInferenceEngine(GraphExecutionContext *context); + + ~ShapeInferenceEngine() = default; + + Status Start(ThreadPool &pool); + + private: + class ShapeFuture { + public: + ShapeFuture(NodePtr src_node, uint32_t src_index, NodeDoneManager *node_done_manager); + ~ShapeFuture() = default; + Status Get(GeShape &ori_shape, GeShape &shape) { + GELOGI("Start to wait node: %s for getting shape", src_node_->GetName().c_str()); + if (!node_done_manager_->Await(src_node_)) { + GELOGE(INTERNAL_ERROR, "cancelled"); + return INTERNAL_ERROR; + } + + shape = src_node_->GetOpDesc()->MutableOutputDesc(src_index_)->MutableShape(); + ori_shape = src_node_->GetOpDesc()->MutableOutputDesc(src_index_)->GetOriginShape(); + GELOGI("Get shape from %s:%u. shape = [%s]", src_node_->GetName().c_str(), src_index_, shape.ToString().c_str()); + return SUCCESS; + } + + private: + NodePtr src_node_; + uint32_t src_index_; + NodeDoneManager *node_done_manager_; + }; + + struct InferenceState { + explicit InferenceState(const NodeItem &node_item); + inline bool IsInputShapesReady() const { return num_pending_shapes == 0; } + + void UpdateInputShape(uint32_t idx, const GeShape &ori_shape, const GeShape &shape); + + Status AwaitShapeFutures(GraphExecutionContext *context); + + void UpdateInputShapeFuture(uint32_t idx, ShapeFuture &&future); + + const NodeItem &node_item; + + private: + std::vector> shape_futures; + int num_pending_shapes = 0; + }; + + InferenceState *GetOrCreateEntry(const NodeItem &node_item); + + Status InferShapeProcess(); + + void InferenceDone(Status status); + + Status InferShape(InferenceState &entry); + + void PropagateOutputShapes(InferenceState &entry, std::queue &queue); + + GraphExecutionContext *context_; + std::unordered_map> inference_states_; +}; +} // namespace hybrid +} // namespace ge +#endif // GE_HYBRID_EXECUTOR_INFERSHAPE_SHAPE_INFERENCE_ENGINE_H_ diff --git a/src/ge/hybrid/executor/worker/task_compile_engine.cc b/src/ge/hybrid/executor/worker/task_compile_engine.cc new file mode 100644 index 00000000..07e70a93 --- /dev/null +++ b/src/ge/hybrid/executor/worker/task_compile_engine.cc @@ -0,0 +1,186 @@ +/** + * Copyright 2019-2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */
+
+#include "hybrid/executor/worker/task_compile_engine.h"
+#include "init/gelib.h"
+#include "framework/common/debug/log.h"
+#include "hybrid/node_executor/node_executor.h"
+
+namespace ge {
+namespace hybrid {
+namespace {
+const uint32_t kDefaultWorkerCnt = 4;
+const uint32_t kDefaultDeviceId = 0;
+} // namespace
+TaskCompileEngine::TaskCompileEngine(GraphExecutionContext *context) : context_(context), pool_(kDefaultWorkerCnt) {}
+
+TaskCompileEngine::~TaskCompileEngine() {
+  if (rt_context_ != nullptr) {
+    GELOGD("To destroy compile context: %p.", rt_context_);
+    GE_CHK_RT(rtCtxDestroy(rt_context_));
+  }
+}
+
+Status TaskCompileEngine::Init() {
+  GELOGD("Start to init CompileEngine");
+  rtContext_t current_ctx = nullptr;
+  GE_CHK_RT(rtCtxGetCurrent(&current_ctx));
+  GE_CHK_RT_RET(rtCtxCreate(&rt_context_, RT_CTX_GEN_MODE, kDefaultDeviceId));
+  GELOGD("Context created for compiling. ctx = %p", rt_context_);
+  GE_CHK_RT_RET(rtCtxSetCurrent(current_ctx));
+  return SUCCESS;
+}
+
+void TaskCompileEngine::Reset() {
+  complete_queue_.Push(nullptr);  // ensure iteration can stop
+  unique_ptr<ResultQueueEntry> entry;
+  while (true) {
+    complete_queue_.Pop(entry);
+    if (entry == nullptr) {
+      break;
+    }
+
+    if (entry->future != nullptr) {
+      entry->future->wait();
+    }
+  }
+
+  complete_queue_.Clear();
+}
+
+Status TaskCompileEngine::Start(ThreadPool &pool) {
+  pool.commit([&]() { (void)this->CompileProcess(); });
+
+  worker_future_ = pool_.commit([&]() -> Status { return this->DistributeCompiledTasks(); });
+
+  if (!worker_future_.valid()) {
+    GELOGE(INTERNAL_ERROR, "Failed to start worker thread");
+    return INTERNAL_ERROR;
+  }
+
+  return SUCCESS;
+}
+
+Status TaskCompileEngine::CompileProcess() {
+  auto &compile_queue = context_->compile_queue;
+  while (true) {
+    NodeStatePtr node_state;
+    // Stop() will not be invoked, so Pop() won't fail
+    (void)compile_queue.Pop(node_state);
+
+    // EOF
+    if (node_state == nullptr) {
+      GELOGD("Got EOF");
+      complete_queue_.Push(unique_ptr<ResultQueueEntry>());
+      break;
+    }
+
+    auto entry = unique_ptr<ResultQueueEntry>(new (std::nothrow) ResultQueueEntry());
+    GE_CHECK_NOTNULL(entry);
+    entry->node_state = node_state;
+
+    auto &node_item = *node_state->node_item;
+    if (node_item.kernel_task != nullptr) {
+      GELOGD("use precompiled task. node name = %s", node_item.NodeName().c_str());
+      node_state->kernel_task = node_item.kernel_task;
+      complete_queue_.Push(std::move(entry));
+      continue;
+    }
+
+    auto ret = CompileAsync(*node_state->node_item, *entry);
+    if (ret == SUCCESS) {
+      complete_queue_.Push(std::move(entry));
+      continue;
+    }
+
+    // On Error
+    worker_future_.wait();
+    Reset();
+    return CompileDone(ret);
+  }
+
+  Status ret = worker_future_.get();
+  Reset();
+  return CompileDone(ret);
+}
+
+Status TaskCompileEngine::CompileDone(Status status) {
+  if (status != SUCCESS) {
+    GELOGE(status, "Error occurred while compiling node.");
+    context_->OnError(status);
+  } else {
+    context_->execution_queue.Push(nullptr);
+  }
+  GELOGI("CompileEngine worker END. ret = %u", status);
+  return status;
+}
+
+Status TaskCompileEngine::DoCompile(const NodeItem &node_item, NodeState &node_state) {
+  RECORD_COMPILE_EVENT(context_, node_state.GetName().c_str(), "Start");
+  GE_CHK_RT_RET(rtCtxSetCurrent(rt_context_));
+  auto ret = node_item.node_executor->CompileTask(*context_->model, node_item.node, node_state.kernel_task);
+  RECORD_COMPILE_EVENT(context_, node_state.GetName().c_str(), "End");
+  GE_CHK_STATUS_RET(ret, "Failed to create task for node: %s", node_item.NodeName().c_str());
+  GELOGI("Compiling node %s successfully", node_state.GetName().c_str());
+  return SUCCESS;
+}
+
+Status TaskCompileEngine::CompileAsync(const NodeItem &node_item, ResultQueueEntry &entry) {
+  auto node_state = entry.node_state;
+  auto f = pool_.commit([this, node_item, node_state]() -> Status { return DoCompile(node_item, *node_state); });
+
+  if (!f.valid()) {
+    GELOGE(INTERNAL_ERROR, "Failed to commit compile task");
+    return INTERNAL_ERROR;
+  }
+
+  entry.future = unique_ptr<std::future<Status>>(new (std::nothrow) std::future<Status>(std::move(f)));
+  GE_CHECK_NOTNULL(entry.future);
+  return SUCCESS;
+}
+
+Status TaskCompileEngine::DistributeCompiledTasks() {
+  GELOGD("DistributeCompiledTasks start.");
+  auto &execute_queue = context_->execution_queue;
+  unique_ptr<ResultQueueEntry> entry;
+  Status ret = SUCCESS;
+  while (true) {
+    if (!complete_queue_.Pop(entry)) {
+      GELOGE(INTERNAL_ERROR, "Failed to pop item from queue");
+      ret = INTERNAL_ERROR;
+      break;
+    }
+
+    // EOF
+    if (entry == nullptr) {
+      break;
+    }
+
+    // if has compile future
+    if (entry->future != nullptr) {
+      ret = entry->future->get();
+      if (ret != SUCCESS) {
+        break;
+      }
+    }
+
+    execute_queue.Push(entry->node_state);
+  }
+
+  GELOGD("DistributeCompiledTasks out. ret = %u.", ret);
+  return ret;
+}
+} // namespace hybrid
+} // namespace ge
\ No newline at end of file
diff --git a/src/ge/hybrid/executor/worker/task_compile_engine.h b/src/ge/hybrid/executor/worker/task_compile_engine.h
new file mode 100644
index 00000000..828a1d8c
--- /dev/null
+++ b/src/ge/hybrid/executor/worker/task_compile_engine.h
@@ -0,0 +1,61 @@
+/**
+ * Copyright 2019-2020 Huawei Technologies Co., Ltd
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */ + +#ifndef GE_HYBRID_EXECUTOR_COMPILE_TASK_COMPILE_ENGINE_H_ +#define GE_HYBRID_EXECUTOR_COMPILE_TASK_COMPILE_ENGINE_H_ + +#include +#include +#include "common/thread_pool.h" +#include "hybrid/executor/hybrid_execution_context.h" + +namespace ge { +namespace hybrid { +class TaskCompileEngine { + public: + explicit TaskCompileEngine(GraphExecutionContext *context); + + ~TaskCompileEngine(); + + Status Init(); + + Status Start(ThreadPool &pool); + + private: + struct ResultQueueEntry { + NodeStatePtr node_state; + std::unique_ptr> future; + }; + + Status CompileProcess(); + + Status CompileDone(Status status); + + private: + Status DoCompile(const NodeItem &node_item, NodeState &node_state); + Status CompileAsync(const NodeItem &node_item, ResultQueueEntry &entry); + Status DistributeCompiledTasks(); + void Reset(); + + rtContext_t rt_context_ = nullptr; + GraphExecutionContext *context_; + BlockingQueue> complete_queue_; + ThreadPool pool_; + std::future worker_future_; +}; +} // namespace hybrid +} // namespace ge +#endif // GE_HYBRID_EXECUTOR_COMPILE_TASK_COMPILE_ENGINE_H_ diff --git a/src/ge/hybrid/hybrid_davinci_model.cc b/src/ge/hybrid/hybrid_davinci_model.cc new file mode 100644 index 00000000..58c7d0e3 --- /dev/null +++ b/src/ge/hybrid/hybrid_davinci_model.cc @@ -0,0 +1,112 @@ +/** + * Copyright 2019-2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +#include +#include "hybrid_davinci_model.h" +#include "hybrid/model/hybrid_model.h" +#include "hybrid/executor/hybrid_model_async_executor.h" + +namespace ge { +namespace hybrid { +class HybridDavinciModel::Impl { + public: + explicit Impl(GeRootModelPtr ge_model) : model_(std::move(ge_model)), executor_(&model_) {} + + ~Impl() = default; + + Status Init() { + GE_CHK_STATUS_RET(model_.Init(), "Failed to init model.") + GE_CHK_STATUS_RET(executor_.Init(), "Failed to init model executor.") + return SUCCESS; + } + + Status ModelRunStart() { return executor_.Start(listener_); } + + Status ModelRunStop() { return executor_.Stop(); } + + Status EnqueueData(const std::shared_ptr &data) { return executor_.EnqueueData(data); } + + void SetListener(const shared_ptr &listener) { listener_ = listener; } + + void SetModelId(uint32_t model_id) { + executor_.SetModelId(model_id); + model_.SetModelId(model_id); + } + + void SetDeviceId(uint32_t device_id) { + model_.SetDeviceId(device_id); + executor_.SetDeviceId(device_id); + } + + private: + std::shared_ptr listener_; + HybridModel model_; + HybridModelAsyncExecutor executor_; +}; + +HybridDavinciModel::~HybridDavinciModel() { delete impl_; } + +unique_ptr HybridDavinciModel::Create(const GeRootModelPtr &ge_root_model) { + auto instance = unique_ptr(new (std::nothrow) HybridDavinciModel()); + if (instance != nullptr) { + instance->impl_ = new (std::nothrow) HybridDavinciModel::Impl(ge_root_model); + if (instance->impl_ != nullptr) { + return instance; + } + } + + return nullptr; +} + +Status HybridDavinciModel::Init() { + GE_CHECK_NOTNULL(impl_); + return impl_->Init(); +} + +Status HybridDavinciModel::ModelRunStart() { + GE_CHECK_NOTNULL(impl_); + return impl_->ModelRunStart(); +} + +Status HybridDavinciModel::ModelRunStop() { + GE_CHECK_NOTNULL(impl_); + return impl_->ModelRunStop(); +} + +Status HybridDavinciModel::EnqueueData(const shared_ptr &data) { + GE_CHECK_NOTNULL(impl_); + return impl_->EnqueueData(data); +} + +void HybridDavinciModel::SetListener(const shared_ptr &listener) { + if (impl_ != nullptr) { + impl_->SetListener(listener); + } +} + +void HybridDavinciModel::SetModelId(uint32_t model_id) { + if (impl_ != nullptr) { + impl_->SetModelId(model_id); + } +} + +void HybridDavinciModel::SetDeviceId(uint32_t device_id) { + if (impl_ != nullptr) { + impl_->SetDeviceId(device_id); + } +} +} // namespace hybrid +} // namespace ge diff --git a/src/ge/hybrid/hybrid_davinci_model.h b/src/ge/hybrid/hybrid_davinci_model.h new file mode 100644 index 00000000..866b756b --- /dev/null +++ b/src/ge/hybrid/hybrid_davinci_model.h @@ -0,0 +1,59 @@ +/** + * Copyright 2019-2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +#ifndef HYBRID_HYBRID_DAVINCI_MODEL_H_ +#define HYBRID_HYBRID_DAVINCI_MODEL_H_ + +#include +#include "external/ge/ge_api_error_codes.h" +#include "graph/load/new_model_manager/data_inputer.h" +#include "model/ge_root_model.h" + +namespace ge { +namespace hybrid { +class HybridDavinciModel { + public: + ~HybridDavinciModel(); + + HybridDavinciModel(const HybridDavinciModel &) = delete; + HybridDavinciModel(HybridDavinciModel &&) = delete; + HybridDavinciModel &operator=(const HybridDavinciModel &) = delete; + HybridDavinciModel &operator=(HybridDavinciModel &&) = delete; + + static std::unique_ptr Create(const GeRootModelPtr &ge_root_model); + + Status Init(); + + Status ModelRunStart(); + + Status ModelRunStop(); + + Status EnqueueData(const std::shared_ptr &data); + + void SetListener(const shared_ptr &listener); + + void SetModelId(uint32_t model_id); + + void SetDeviceId(uint32_t device_id); + + private: + HybridDavinciModel() = default; + class Impl; + Impl *impl_ = nullptr; +}; +} // namespace hybrid +} // namespace ge +#endif // HYBRID_HYBRID_DAVINCI_MODEL_H_ diff --git a/src/ge/hybrid/hybrid_davinci_model_stub.cc b/src/ge/hybrid/hybrid_davinci_model_stub.cc new file mode 100644 index 00000000..bca118f8 --- /dev/null +++ b/src/ge/hybrid/hybrid_davinci_model_stub.cc @@ -0,0 +1,41 @@ +/** + * Copyright 2019-2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "hybrid_davinci_model.h" + +namespace ge { +namespace hybrid { +HybridDavinciModel::~HybridDavinciModel() {} + +std::unique_ptr HybridDavinciModel::Create(const GeRootModelPtr &ge_root_model) { + return std::unique_ptr(new (std::nothrow) HybridDavinciModel()); +} + +Status HybridDavinciModel::Init() { return UNSUPPORTED; } + +Status HybridDavinciModel::ModelRunStart() { return UNSUPPORTED; } + +Status HybridDavinciModel::ModelRunStop() { return UNSUPPORTED; } + +Status HybridDavinciModel::EnqueueData(const shared_ptr &data) { return UNSUPPORTED; } + +void HybridDavinciModel::SetListener(const shared_ptr &listener) {} + +void HybridDavinciModel::SetModelId(uint32_t model_id) {} + +void HybridDavinciModel::SetDeviceId(uint32_t device_id) {} +} // namespace hybrid +} // namespace ge \ No newline at end of file diff --git a/src/ge/hybrid/model/hybrid_model.cc b/src/ge/hybrid/model/hybrid_model.cc new file mode 100644 index 00000000..e3726aec --- /dev/null +++ b/src/ge/hybrid/model/hybrid_model.cc @@ -0,0 +1,117 @@ +/** + * Copyright 2019-2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+ * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "hybrid_model.h" +#include +#include "graph/load/new_model_manager/model_utils.h" +#include "graph/debug/ge_attr_define.h" +#include "graph/utils/graph_utils.h" +#include "graph/utils/node_utils.h" +#include "graph/utils/tensor_utils.h" +#include "hybrid/common/npu_memory_allocator.h" +#include "hybrid/model/hybrid_model_builder.h" +#include "hybrid/node_executor/node_executor.h" + +namespace ge { +namespace hybrid { +HybridModel::HybridModel(GeRootModelPtr ge_model) : ge_root_model_(std::move(ge_model)) {} + +Status HybridModel::Init() { + GELOGD("Start to init hybrid model."); + GE_CHK_STATUS_RET(HybridModelBuilder(*this).Build(), "Failed to build hybrid model."); + GELOGD("HybridModel initialized successfully."); + return SUCCESS; +} + +void HybridModel::Print() const { + for (const auto &node : node_items_) { + GELOGD("%s", node->DebugString().c_str()); + } +} + +TensorValue *HybridModel::GetWeight(const NodeItem *const_node) const { + auto it = weights_.find(const_node->node_id); + if (it == weights_.end() || it->second == nullptr) { + GELOGE(INTERNAL_ERROR, "[%s] Failed to get weight", const_node->NodeName().c_str()); + return nullptr; + } + + return it->second.get(); +} + +TensorValue *HybridModel::GetVariable(const string &name) const { + auto it = variable_tensors_.find(name); + if (it == variable_tensors_.end()) { + GELOGI("Failed to get variable tensor. var name = [%s]", name.c_str()); + return nullptr; + } + + GELOGD("Got variable tensor. var name = [%s], tensor = %s", name.c_str(), it->second->DebugString().c_str()); + return it->second.get(); +} + +NodePtr HybridModel::GetVariableNode(const string &name) const { + auto it = variable_nodes_.find(name); + if (it == variable_nodes_.end()) { + GELOGI("Failed to get variable node by name = [%s]", name.c_str()); + return nullptr; + } + + return it->second; +} + +const std::vector *HybridModel::GetTaskDefs(const NodePtr &node) const { + auto it = task_defs_.find(node); + if (it == task_defs_.end()) { + return nullptr; + } + + return &it->second; +} + +NodeItem *HybridModel::MutableNodeItem(const NodePtr &node) { + auto node_id = node->GetOpDesc()->GetId(); + if (node_id < 0 || static_cast(node_id) > node_items_.size()) { + GELOGE(INTERNAL_ERROR, "index out of range. node_id = %ld, num_nodes = %zu", node_id, node_items_.size()); + return nullptr; + } + return node_items_[node_id].get(); +} + +const NodeItem *HybridModel::GetNodeItem(const NodePtr &node) const { + auto node_id = node->GetOpDesc()->GetId(); + if (node_id < 0 || static_cast(node_id) > node_items_.size()) { + GELOGE(INTERNAL_ERROR, "Index out of range. 
node_id = %ld, num_nodes = %zu.", node_id, node_items_.size()); + return nullptr; + } + return node_items_[node_id].get(); +} + +GeModelPtr HybridModel::GetGeModel(const NodePtr &node) const { + auto it = known_shape_sub_graphs_.find(node); + if (it == known_shape_sub_graphs_.end()) { + GELOGE(INTERNAL_ERROR, "[%s] Failed to get GeModel for subgraph node.", node->GetName().c_str()); + return nullptr; + } + + return it->second; +} + +const vector &HybridModel::GetNetOutputInputOffsets() const { return net_output_input_offsets_; } + +void HybridModel::SetDeviceId(uint32_t device_id) { device_id_ = device_id; } +} // namespace hybrid +} // namespace ge diff --git a/src/ge/hybrid/model/hybrid_model.h b/src/ge/hybrid/model/hybrid_model.h new file mode 100644 index 00000000..007f76c6 --- /dev/null +++ b/src/ge/hybrid/model/hybrid_model.h @@ -0,0 +1,119 @@ +/** + * Copyright 2019-2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef GE_HYBRID_HYBRID_GRAPH_H_ +#define GE_HYBRID_HYBRID_GRAPH_H_ + +#include +#include +#include +#include "framework/common/ge_inner_error_codes.h" +#include "graph/load/new_model_manager/data_inputer.h" +#include "graph/load/new_model_manager/task_info/task_info.h" +#include "graph/node.h" +#include "hybrid/common/tensor_value.h" +#include "hybrid/model/node_item.h" +#include "model/ge_root_model.h" + +namespace ge { +namespace hybrid { +class HybridModelAsyncExecutor; +class HybridModel { + public: + explicit HybridModel(GeRootModelPtr ge_model); + + ~HybridModel() = default; + + Status Init(); + + const std::vector &RootNodes() const { return root_nodes_; } + + const NodeItem *GetNodeItem(const NodePtr &node) const; + + size_t NumNodes() const { return node_items_.size(); } + + uint64_t GetSessionId() const { return root_runtime_param_.session_id; } + + int TotalInputs() const { return total_inputs_; } + + const map &GetInputNodes() const { return input_nodes_; } + + const std::map> &GetInputOffsets() const { return input_offsets_; } + + const vector &GetNetOutputInputOffsets() const; + + const std::vector &GetOutputOffsets() const { return output_offsets_; } + + const std::vector &GetConstNodes() const { return const_nodes_; } + + GeModelPtr GetGeModel(const NodePtr &node) const; + + NodeItem *MutableNodeItem(const NodePtr &node); + + size_t TotalVarMemSize() const { return root_runtime_param_.var_size; } + + const uint8_t *GetVarMemBase() const { return var_mem_base_; } + + void SetDeviceId(uint32_t device_id); + + void SetModelId(uint32_t model_id) { model_id_ = model_id; } + + uint32_t GetModelId() const { return model_id_; } + + TensorValue *GetWeight(const NodeItem *const_node) const; + + TensorValue *GetVariable(const string &name) const; + + NodePtr GetVariableNode(const string &name) const; + + const std::vector *GetTaskDefs(const NodePtr &node) const; + + int TotalOutputs() const { return total_outputs_; } + + GeRootModelPtr GetGeRootModel() const { return ge_root_model_; } + void Print() const; + + 
private: + friend class HybridModelBuilder; + friend class HybridModelAsyncExecutor; + + GeRootModelPtr ge_root_model_; + std::vector root_nodes_; + std::map input_nodes_; + std::map> input_offsets_; + std::vector output_offsets_; + std::vector net_output_input_offsets_; + NodeItem *net_output_node_ = nullptr; + std::vector> node_items_; + std::vector const_nodes_; + std::map constant_op_nodes_; + std::map variable_nodes_; + std::map> variable_tensors_; + std::map> weights_; + std::map> task_defs_; + std::map known_shape_sub_graphs_; + int total_inputs_ = 0; + int total_outputs_ = 0; + + // runtime fields + uint32_t device_id_ = 0; + uint32_t model_id_ = 0; + uint8_t *var_mem_base_ = nullptr; + RuntimeParam root_runtime_param_; +}; +} // namespace hybrid +} // namespace ge +#endif // GE_HYBRID_HYBRID_GRAPH_H_ diff --git a/src/ge/hybrid/model/hybrid_model_builder.cc b/src/ge/hybrid/model/hybrid_model_builder.cc new file mode 100644 index 00000000..ce220bde --- /dev/null +++ b/src/ge/hybrid/model/hybrid_model_builder.cc @@ -0,0 +1,956 @@ +/** + * Copyright 2019-2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "hybrid/model/hybrid_model_builder.h" +#include "common/math/math_util.h" +#include "graph/utils/node_utils.h" +#include "graph/debug/ge_attr_define.h" +#include "graph/load/new_model_manager/model_utils.h" +#include "graph/manager/graph_var_manager.h" +#include "graph/manager/trans_var_data_utils.h" +#include "graph/utils/graph_utils.h" +#include "graph/utils/type_utils.h" +#include "framework/common/debug/log.h" +#include "hybrid/common/npu_memory_allocator.h" +#include "hybrid/node_executor/node_executor.h" + +namespace ge { +namespace hybrid { +namespace { +const uint32_t kSubgraphIndex = 0U; +const uint32_t kVarOutputIndex = 0U; +const int kBytes = 8; + +int64_t CalcVarSizeInBytes(const GeTensorDesc &desc) { + int64_t var_size = GetSizeByDataType(desc.GetDataType()); + if (var_size <= 0) { + GELOGE(PARAM_INVALID, "Failed to calc var data size from data type %s", + TypeUtils::DataTypeToSerialString(desc.GetDataType()).c_str()); + return -1; + } + auto shape = desc.GetShape(); + auto dim_num = shape.GetDimNum(); + for (size_t dim_index = 0; dim_index < dim_num; ++dim_index) { + var_size *= shape.GetDim(dim_index); + } + return var_size; +} +} // namespace +HybridModelBuilder::HybridModelBuilder(HybridModel &hybrid_model) + : hybrid_model_(hybrid_model), runtime_param_(hybrid_model.root_runtime_param_) { + ge_root_model_ = hybrid_model_.ge_root_model_; +} + +Status HybridModelBuilder::Build() { + GE_CHK_STATUS_RET(ValidateParams(), "Failed to validate GeRootModel"); + graph_name_ = ge_root_model_->GetRootGraph()->GetName(); + GELOGI("[%s] Start to build hybrid model.", GetGraphName()); + GE_CHK_STATUS_RET(InitRuntimeParams(), "[%s] Failed to InitRuntimeParams", GetGraphName()); + GE_CHK_STATUS_RET(NodeExecutorManager::GetInstance().EnsureInitialized(), "Failed to initialize executors"); + GE_CHK_STATUS_RET(IndexSpecialNodes(), 
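CalcVarSizeInBytes above multiplies every dimension into the element size without guarding against overflow or negative (dynamic) dims; a defensive sketch of the same computation, assuming int64_t dims as returned by GeShape::GetDim:

```cpp
#include <cstdint>
#include <limits>
#include <vector>

// Overflow-aware variant of the dim product in CalcVarSizeInBytes (sketch).
int64_t SafeSizeInBytes(const std::vector<int64_t> &dims, int64_t elem_size) {
  if (elem_size <= 0) {
    return -1;  // unknown or invalid data type
  }
  int64_t size = elem_size;
  for (int64_t dim : dims) {
    if (dim < 0) {
      return -1;  // dynamic dimension: size is not statically known
    }
    if (dim != 0 && size > std::numeric_limits<int64_t>::max() / dim) {
      return -1;  // multiplication would overflow
    }
    size *= dim;
  }
  return size;
}
```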
"[%s] Failed to index nodes", GetGraphName()); + GE_CHK_STATUS_RET(IndexTaskDefs(), "[%s] Failed to index task defs", GetGraphName()); + GE_CHK_STATUS_RET(LoadGraph(), "[%s] Failed to load graph", GetGraphName()); + GE_CHK_STATUS_RET(TransAllVarData(), "[%s] Failed to trans all var data", GetGraphName()); + GE_CHK_STATUS_RET(CopyVarData(), "[%s] Failed to copy var data", GetGraphName()); + GE_CHK_STATUS_RET(InitModelMem(), "[%s] Failed to init memory", GetGraphName()); + // TODO VAR_ATTR_VAR_IS_BROADCAST ??? + GE_CHK_STATUS_RET(InitWeights(), "[%s] Failed to init weights", GetGraphName()); + GE_CHK_STATUS_RET(InitConstantOps(), "[%s] Failed to init constant op", GetGraphName()); + GE_CHK_STATUS_RET(InitVariableTensors(), "[%s] Failed to init variables", GetGraphName()); + GE_CHK_STATUS_RET(ResolveRootNodes(), "[%s] Failed to resolve root nodes", GetGraphName()); + GE_CHK_STATUS_RET(LoadTasks(), "[%s] Failed to load tasks", GetGraphName()); + GELOGI("[%s] Done building hybrid model successfully.", GetGraphName()); + return SUCCESS; +} + +Status HybridModelBuilder::ValidateParams() { + GE_CHECK_NOTNULL(ge_root_model_); + GE_CHECK_NOTNULL(ge_root_model_->GetRootGraph()); + return SUCCESS; +} + +Status HybridModelBuilder::ResolveRootNodes() { + for (auto &node : hybrid_model_.node_items_) { + if (node->node->GetInDataNodes().empty()) { + hybrid_model_.root_nodes_.emplace_back(node.get()); + GELOGI("[%s] Root node added. node name = %s", GetGraphName(), node->NodeName().c_str()); + } + } + + if (hybrid_model_.root_nodes_.empty()) { + GELOGE(PARAM_INVALID, "[%s] Root nodes is empty.", GetGraphName()); + return PARAM_INVALID; + } + + return SUCCESS; +} + +Status HybridModelBuilder::BuildNoteItem(const NodePtr &node, NodeItem &node_item) { + GE_CHK_STATUS_RET(NodeUtils::GetNodeUnknownShapeStatus(*node, node_item.is_dynamic), + "[%s] Failed to get shape status.", node->GetName().c_str()); + + auto op_desc = node->GetOpDesc(); + vector dependencies = node->GetOpDesc()->GetOpInferDepends(); + GE_CHK_STATUS_RET(ParseDependentInputNodes(node_item, dependencies), "[%s] Failed to parse node dependencies.", + node_item.NodeName().c_str()); + + auto it = node_ref_inputs_.find(node); + if (it != node_ref_inputs_.end()) { + for (auto &idx_and_node : it->second) { + // var and constant only have one output + node_item.const_input_shapes[idx_and_node.first] = + idx_and_node.second->GetOpDesc()->MutableOutputDesc(kVarOutputIndex); + } + } + + node_item.outputs.resize(node_item.num_outputs); + for (int i = 0; i < node_item.num_outputs; ++i) { + auto out_data_anchor = node->GetOutDataAnchor(i); + if (out_data_anchor == nullptr) { + GELOGE(INTERNAL_ERROR, "out anchor[%zu] of node %s is nullptr", i, node->GetName().c_str()); + return INTERNAL_ERROR; + } + + for (auto &dst_in_anchor : out_data_anchor->GetPeerInDataAnchors()) { + auto dst_node = dst_in_anchor->GetOwnerNode(); + if (dst_node == nullptr) { + GELOGW("dst node is nullptr. 
out anchor = %d", out_data_anchor->GetIdx()); + continue; + } + + NodeItem *dst_node_item = nullptr; + GE_CHK_STATUS_RET(GetOrCreateNodeItem(dst_node, &dst_node_item), "[%s] Failed to get or create node item.", + dst_node->GetName().c_str()); + node_item.outputs[i].emplace_back(dst_in_anchor->GetIdx(), dst_node_item); + } + } + + return SUCCESS; +} + +Status HybridModelBuilder::GetOrCreateNodeItem(const NodePtr &node, NodeItem **node_item) { + auto &node_items = hybrid_model_.node_items_; + auto node_id = node->GetOpDesc()->GetId(); + if (node_id < 0 || static_cast(node_id) > node_items.size()) { + GELOGE(INTERNAL_ERROR, "[%s] Index out of range. node_id = %ld, num_nodes = %zu", node->GetName().c_str(), node_id, + node_items.size()); + return INTERNAL_ERROR; + } + + auto &node_ptr = node_items[node_id]; + if (node_ptr != nullptr) { + *node_item = node_ptr.get(); + return SUCCESS; + } + + auto new_node = std::unique_ptr(new (std::nothrow) NodeItem(node)); + GE_CHECK_NOTNULL(new_node); + GE_CHECK_NOTNULL(new_node->op_desc); + GE_CHK_STATUS_RET_NOLOG(NodeExecutorManager::GetInstance().GetExecutor(*node, &new_node->node_executor)); + + // we do not need L2 Buffer + const char *const kIsFirstNode = "is_first_node"; + const char *const kIsLastNode = "is_last_node"; + (void)AttrUtils::SetBool(new_node->op_desc, kIsFirstNode, false); + (void)AttrUtils::SetBool(new_node->op_desc, kIsLastNode, false); + + int32_t unknown_shape_type_val = 0; + (void)AttrUtils::GetInt(new_node->op_desc, ::ge::ATTR_NAME_UNKNOWN_SHAPE_TYPE, unknown_shape_type_val); + new_node->shape_inference_type = static_cast(unknown_shape_type_val); + if (new_node->shape_inference_type == DEPEND_SHAPE_RANGE || new_node->shape_inference_type == DEPEND_COMPUTE) { + new_node->has_observer = true; + } + + *node_item = new_node.get(); + node_items[node_id] = std::move(new_node); + return SUCCESS; +} + +Status HybridModelBuilder::ParseDependentInputNodes(NodeItem &node_item, const std::vector &dependencies) { + std::set dependent_input_nodes; + auto &ge_node = node_item.node; + for (const auto &input_name : dependencies) { + int input_index = node_item.op_desc->GetInputIndexByName(input_name); + if (input_index < 0) { + GELOGE(INTERNAL_ERROR, "[%s] Failed to get input index by name: %s", node_item.NodeName().c_str(), + input_name.c_str()); + return INTERNAL_ERROR; + } + + const auto &in_anchor = ge_node->GetInDataAnchor(input_index); + GE_CHECK_NOTNULL(in_anchor); + const auto &peer_out_anchor = in_anchor->GetPeerOutAnchor(); + GE_CHECK_NOTNULL(peer_out_anchor); + const auto &src_node = peer_out_anchor->GetOwnerNode(); + GE_CHECK_NOTNULL(src_node); + auto src_node_item = MutableNodeItem(src_node); + src_node_item->to_const_output_id_list.emplace(peer_out_anchor->GetIdx()); + src_node_item->has_observer = true; + + dependent_input_nodes.emplace(src_node); + GELOGD("[%s] Dependent added from output of [%s:%d]", node_item.NodeName().c_str(), + src_node_item->NodeName().c_str(), peer_out_anchor->GetIdx()); + } + + for (const auto &dep_node : dependent_input_nodes) { + node_item.dependent_node_list.emplace_back(dep_node); + } + + return SUCCESS; +} + +Status HybridModelBuilder::UpdateAnchorStatus(const NodePtr &node) { + if (NodeUtils::SetAllAnchorStatus(node) != GRAPH_SUCCESS) { + GELOGE(INTERNAL_ERROR, "[%s] NodeUtils::SetAllAnchorStatus failed.", node->GetName().c_str()); + return INTERNAL_ERROR; + } + for (auto &anchor : node->GetAllInDataAnchors()) { + auto peer_anchor = anchor->GetPeerOutAnchor(); + if (peer_anchor == nullptr) { + if 
(AnchorUtils::SetStatus(anchor, ANCHOR_SUSPEND) != GRAPH_SUCCESS) { + GELOGE(INTERNAL_ERROR, "[%s] AnchorUtils::SetStatus failed.", node->GetName().c_str()); + return INTERNAL_ERROR; + } + } else if (peer_anchor->GetOwnerNode()->GetType() == CONSTANT) { + if (AnchorUtils::SetStatus(anchor, ANCHOR_CONST) != GRAPH_SUCCESS) { + GELOGE(INTERNAL_ERROR, "[%s] AnchorUtils::SetStatus failed.", node->GetName().c_str()); + return INTERNAL_ERROR; + } + } else { + if (AnchorUtils::SetStatus(anchor, ANCHOR_DATA) != GRAPH_SUCCESS) { + GELOGE(INTERNAL_ERROR, "[%s] AnchorUtils::SetStatus failed.", node->GetName().c_str()); + return INTERNAL_ERROR; + } + } + } + + return SUCCESS; +} + +Status HybridModelBuilder::DoUnlinkDataAnchors(const OutDataAnchorPtr &out_data_anchor, + const InDataAnchorPtr &in_data_anchor) { + GE_CHK_GRAPH_STATUS_RET(out_data_anchor->Unlink(in_data_anchor), "Failed to unlink %s:%d from %s:%d", + out_data_anchor->GetOwnerNode()->GetName().c_str(), out_data_anchor->GetIdx(), + in_data_anchor->GetOwnerNode()->GetName().c_str(), in_data_anchor->GetIdx()); + + GELOGD("Succeeded in unlinking %s:%d from %s:%d", out_data_anchor->GetOwnerNode()->GetName().c_str(), + out_data_anchor->GetIdx(), in_data_anchor->GetOwnerNode()->GetName().c_str(), in_data_anchor->GetIdx()); + return SUCCESS; +} + +Status HybridModelBuilder::DoLinkDataAnchors(OutDataAnchorPtr &out_data_anchor, InDataAnchorPtr &in_data_anchor) { + GE_CHK_GRAPH_STATUS_RET(out_data_anchor->LinkTo(in_data_anchor), "Failed to link %s:%d to %s:%d", + out_data_anchor->GetOwnerNode()->GetName().c_str(), out_data_anchor->GetIdx(), + in_data_anchor->GetOwnerNode()->GetName().c_str(), in_data_anchor->GetIdx()); + + GELOGD("Succeeded in linking %s:%d to %s:%d", out_data_anchor->GetOwnerNode()->GetName().c_str(), + out_data_anchor->GetIdx(), in_data_anchor->GetOwnerNode()->GetName().c_str(), in_data_anchor->GetIdx()); + return SUCCESS; +} + +Status HybridModelBuilder::MergeInputNodes(ComputeGraph &graph) { + const auto &wrapped_node = graph.GetParentNode(); + for (const auto &node : graph.GetDirectNode()) { + GE_CHECK_NOTNULL(node); + if (node->GetType() != DATA_TYPE) { + continue; + } + + auto data_op_desc = node->GetOpDesc(); + GE_CHECK_NOTNULL(data_op_desc); + + uint32_t parent_index = 0; + if (!AttrUtils::GetInt(data_op_desc, ATTR_NAME_PARENT_NODE_INDEX, parent_index)) { + GELOGE(FAILED, "[%s] Failed to get attr [%s]", data_op_desc->GetName().c_str(), + ATTR_NAME_PARENT_NODE_INDEX.c_str()); + return FAILED; + } + + auto wrapped_node_in_anchor = wrapped_node->GetInDataAnchor(parent_index); + GE_CHECK_NOTNULL(wrapped_node_in_anchor); + auto src_out_anchor = wrapped_node_in_anchor->GetPeerOutAnchor(); + if (src_out_anchor == nullptr || src_out_anchor->GetOwnerNode() == nullptr) { + continue; + } + auto src_node = wrapped_node_in_anchor->GetPeerOutAnchor()->GetOwnerNode(); + wrapped_node_in_anchor->UnlinkAll(); + + // link src to outputs of DataNode + for (auto &out_data_anchor : node->GetAllOutDataAnchors()) { + GE_CHECK_NOTNULL(out_data_anchor); + for (auto &peer_in_data_anchor : out_data_anchor->GetPeerInDataAnchors()) { + GE_CHK_STATUS_RET_NOLOG(DoUnlinkDataAnchors(out_data_anchor, peer_in_data_anchor)); + GE_CHK_STATUS_RET_NOLOG(DoLinkDataAnchors(src_out_anchor, peer_in_data_anchor)); + } + } + } + + return SUCCESS; +} + +Status HybridModelBuilder::MergeNetOutputNode(ComputeGraph &graph) { + const auto &parent_node = graph.GetParentNode(); + const NodePtr &net_output_node = graph.FindNode(NODE_NAME_NET_OUTPUT); + 
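MergeInputNodes above inlines a subgraph by rewiring anchors: each consumer of a subgraph Data node is detached and reattached to the producer feeding the corresponding PartitionedCall input. The core rewiring, distilled from the loop above:

```cpp
// Distilled from MergeInputNodes: bypass the Data node entirely.
for (auto &out_data_anchor : data_node->GetAllOutDataAnchors()) {
  for (auto &peer_in_anchor : out_data_anchor->GetPeerInDataAnchors()) {
    out_data_anchor->Unlink(peer_in_anchor);  // detach consumer from Data
    src_out_anchor->LinkTo(peer_in_anchor);   // attach it to the real producer
  }
}
```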
GE_CHECK_NOTNULL(net_output_node); + const auto &net_output_desc = net_output_node->GetOpDesc(); + GE_CHECK_NOTNULL(net_output_desc); + + for (const auto &in_data_anchor : net_output_node->GetAllInDataAnchors()) { + auto src_out_anchor = in_data_anchor->GetPeerOutAnchor(); + GE_CHECK_NOTNULL(src_out_anchor); + GE_CHK_STATUS_RET_NOLOG(DoUnlinkDataAnchors(src_out_anchor, in_data_anchor)); + + auto index = in_data_anchor->GetIdx(); + auto input_desc = net_output_desc->MutableInputDesc(index); + if (input_desc == nullptr) { + GELOGE(INTERNAL_ERROR, "[%s] Failed to get input desc[%d]", net_output_desc->GetName().c_str(), index); + return INTERNAL_ERROR; + } + + uint32_t parent_index = 0; + if (!AttrUtils::GetInt(input_desc, ATTR_NAME_PARENT_NODE_INDEX, parent_index)) { + GELOGW("SubGraph: %s NetOutput input tensor %d, attr %s not found.", graph.GetName().c_str(), index, + ATTR_NAME_PARENT_NODE_INDEX.c_str()); + continue; + } + + const OutDataAnchorPtr &parent_out_anchor = parent_node->GetOutDataAnchor(parent_index); + GE_CHECK_NOTNULL(parent_out_anchor); + for (InDataAnchorPtr &dst_in_anchor : parent_out_anchor->GetPeerInDataAnchors()) { + if (dst_in_anchor == nullptr) { + continue; + } + + GE_CHK_STATUS_RET_NOLOG(DoUnlinkDataAnchors(parent_out_anchor, dst_in_anchor)); + GE_CHK_STATUS_RET_NOLOG(DoLinkDataAnchors(src_out_anchor, dst_in_anchor)); + } + } + + return SUCCESS; +} + +Status HybridModelBuilder::MergeSubgraphs(ComputeGraph &root_graph, ComputeGraphPtr &merged_graph) { + merged_graph = MakeShared("MergedGraph"); + for (const auto &node : root_graph.GetDirectNode()) { + GE_CHECK_NOTNULL(node); + auto op_desc = node->GetOpDesc(); + GE_CHECK_NOTNULL(op_desc); + + const auto &op_type = node->GetType(); + if (op_type == DATA || op_type == AIPP_DATA_TYPE || op_type == NETOUTPUT) { + merged_graph->AddNode(node); + GELOGD("[%s] Node added to merged graph.", op_desc->GetName().c_str()); + continue; + } + + if (op_type != PARTITIONEDCALL) { + GELOGE(INTERNAL_ERROR, "[%s] Unexpected node in root graph. type = %s", op_desc->GetName().c_str(), + op_type.c_str()); + return INTERNAL_ERROR; + } + + bool is_unknown_shape = false; + GE_CHK_GRAPH_STATUS_RET(NodeUtils::GetNodeUnknownShapeStatus(*node, is_unknown_shape), + "Failed to invoke GetNodeUnknownShapeStatus."); + if (!is_unknown_shape) { + merged_graph->AddNode(node); + GELOGD("[%s] Known shape partitioned call added to merged graph.", op_desc->GetName().c_str()); + continue; + } + + auto subgraph = NodeUtils::GetSubgraph(*node, kSubgraphIndex); + GE_CHK_STATUS_RET(MergeInputNodes(*subgraph), "Failed to merge data nodes for subgraph: %s", + subgraph->GetName().c_str()); + GE_CHK_STATUS_RET(MergeNetOutputNode(*subgraph), "Failed to merge net output nodes for subgraph: %s", + subgraph->GetName().c_str()); + GELOGD("Merging subgraph %s successfully.", subgraph->GetName().c_str()); + for (auto &sub_node : subgraph->GetAllNodes()) { + auto sub_op_type = sub_node->GetType(); + if (sub_op_type == DATA_TYPE || sub_op_type == NETOUTPUT) { + continue; + } + + if (sub_op_type == CONSTANT || sub_op_type == CONSTANTOP || sub_op_type == VARIABLE) { + GELOGE(INTERNAL_ERROR, "Unexpected node in unknown subgraph. 
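MergeSubgraphs partitions the root graph by shape knowledge: known-shape PartitionedCall nodes stay whole, since they run later as pre-compiled tasks, while unknown-shape ones are flattened into the merged graph node by node. The per-node decision, distilled:

```cpp
// Distilled control flow of MergeSubgraphs for one PartitionedCall node.
bool is_unknown_shape = false;
GE_CHK_GRAPH_STATUS_RET(NodeUtils::GetNodeUnknownShapeStatus(*node, is_unknown_shape),
                        "Failed to invoke GetNodeUnknownShapeStatus.");
if (!is_unknown_shape) {
  merged_graph->AddNode(node);  // executed later as a single known-shape model
} else {
  auto subgraph = NodeUtils::GetSubgraph(*node, kSubgraphIndex);
  for (auto &sub_node : subgraph->GetAllNodes()) {
    merged_graph->AddNode(sub_node);  // inlined for dynamic-shape scheduling
  }
}
```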
type = %s, node = %s::%s", sub_op_type.c_str(), + subgraph->GetName().c_str(), sub_node->GetName().c_str()); + return INTERNAL_ERROR; + } + + merged_graph->AddNode(sub_node); + GELOGD("%s::%s added to merged graph.", subgraph->GetName().c_str(), sub_node->GetName().c_str()); + } + } + + return SUCCESS; +} + +Status HybridModelBuilder::ParseNetOutput(const NodeItem &node_item) { + for (auto &in_data_anchor : node_item.node->GetAllInDataAnchors()) { + auto peer_out_anchor = in_data_anchor->GetPeerOutAnchor(); + GE_CHECK_NOTNULL(peer_out_anchor); + auto src_node = peer_out_anchor->GetOwnerNode(); + GE_CHECK_NOTNULL(src_node); + + auto src_node_item = GetNodeItem(src_node); + GE_CHECK_NOTNULL(src_node_item); + auto output_offset = src_node_item->output_start + peer_out_anchor->GetIdx(); + GELOGI("Output[%d], node = %s, output_index = %d, output_offset = %d ", in_data_anchor->GetIdx(), + src_node_item->NodeName().c_str(), peer_out_anchor->GetIdx(), output_offset); + hybrid_model_.output_offsets_.emplace_back(output_offset); + } + + for (int i = 0; i < node_item.num_inputs; ++i) { + hybrid_model_.net_output_input_offsets_.emplace_back(node_item.input_start + i); + } + + return SUCCESS; +} + +Status HybridModelBuilder::LoadGraph() { + auto root_graph = ge_root_model_->GetRootGraph(); + GELOGI("Before merge subgraphs DirectNodesSize = %zu, GetAllNodesSize = %zu", root_graph->GetDirectNodesSize(), + root_graph->GetAllNodesSize()); + ComputeGraphPtr merged_graph; + GE_CHK_STATUS_RET_NOLOG(MergeSubgraphs(*root_graph, merged_graph)); + GELOGI("After merge subgraphs DirectNodesSize = %zu, GetAllNodesSize = %zu", merged_graph->GetDirectNodesSize(), + merged_graph->GetAllNodesSize()); + + merged_graph->SetGraphID(runtime_param_.graph_id); + GE_DUMP(merged_graph, "hybrid_merged_graph"); + int input_start = 0; + int output_start = 0; + uint32_t data_op_index = 0; + hybrid_model_.node_items_.resize(merged_graph->GetDirectNodesSize()); + + int64_t node_index = 0; + for (auto &node : merged_graph->GetDirectNode()) { + OpDescPtr op_desc = node->GetOpDesc(); + GE_CHECK_NOTNULL(op_desc); + op_desc->SetId(node_index++); + } + + for (const auto &node : merged_graph->GetDirectNode()) { + GE_CHECK_NOTNULL(node); + GE_CHECK_NOTNULL(node->GetOpDesc()); + const auto &op_type = node->GetType(); + + NodeItem *node_item = nullptr; + GE_CHK_STATUS_RET_NOLOG(GetOrCreateNodeItem(node, &node_item)); + GE_CHK_STATUS_RET_NOLOG(BuildNoteItem(node, *node_item)); + GE_CHK_STATUS_RET_NOLOG(UpdateAnchorStatus(node)); // needed by FE generate task + + node_item->input_start = input_start; + node_item->output_start = output_start; + input_start += node_item->num_inputs; + output_start += node_item->num_outputs; + + if (op_type == DATA_TYPE || op_type == AIPP_DATA_TYPE) { + auto data_index = data_op_index; + if (AttrUtils::GetInt(node->GetOpDesc(), ATTR_NAME_INDEX, data_index)) { + GELOGI("ge_train: get new index %u, old %u", data_index, data_op_index); + } + hybrid_model_.input_nodes_.emplace(data_index, node_item); + data_op_index++; + } else if (op_type == NETOUTPUT) { + hybrid_model_.net_output_node_ = node_item; + GE_CHK_STATUS_RET_NOLOG(ParseNetOutput(*node_item)); + } else if (op_type == PARTITIONEDCALL) { // known graph + GE_CHK_STATUS_RET_NOLOG(ParsePartitionedCall(*node_item)); + } + + GELOGI("NodeItem created: %s", node_item->DebugString().c_str()); + } + + for (auto &it : hybrid_model_.input_nodes_) { + auto input_index = it.first; + auto input_node = it.second; + + if (input_node->outputs.empty()) { + 
GELOGE(INTERNAL_ERROR, "data output anchor is empty"); + return INTERNAL_ERROR; + } + + for (auto &out : input_node->outputs) { + std::vector offsets; + for (auto &dst_anchor_and_node : out) { + auto dst_node_item = dst_anchor_and_node.second; + offsets.emplace_back(dst_node_item->input_start + dst_anchor_and_node.first); + } + + hybrid_model_.input_offsets_.emplace(input_index, std::move(offsets)); + } + } + + hybrid_model_.total_inputs_ = input_start; + hybrid_model_.total_outputs_ = output_start; + GELOGI("HybridGraph::LoadGraph OUT"); + return SUCCESS; +} + +const NodeItem *HybridModelBuilder::GetNodeItem(const NodePtr &node) const { return hybrid_model_.GetNodeItem(node); } + +NodeItem *HybridModelBuilder::MutableNodeItem(const NodePtr &node) { return hybrid_model_.MutableNodeItem(node); } + +Status HybridModelBuilder::VarNodeToTensor(const NodePtr &var_node, std::unique_ptr &tensor) { + string var_name = var_node->GetName(); + auto tensor_desc = var_node->GetOpDesc()->MutableOutputDesc(0); + uint8_t *var_logic = nullptr; + + GE_CHK_STATUS_RET(var_manager_->GetVarAddr(var_name, *tensor_desc, &var_logic), + "Failed to get var addr. var_name = %s, session_id = %ld", var_name.c_str(), + hybrid_model_.GetSessionId()); + + uint8_t *dev_mem = var_manager_->GetVarMemoryAddr(var_logic, RT_MEMORY_HBM); + if (dev_mem == nullptr) { + GELOGE(INTERNAL_ERROR, + "Failed to copy var %s from device, cant not get " + "var addr from logic addr %p", + var_node->GetName().c_str(), var_logic); + return INTERNAL_ERROR; + } + + int64_t var_size = CalcVarSizeInBytes(*tensor_desc); + if (var_size < 0) { + GELOGE(INTERNAL_ERROR, "[%s] Invalid var size: %ld", var_name.c_str(), var_size); + return INTERNAL_ERROR; + } + + tensor.reset(new (std::nothrow) TensorValue(dev_mem, var_size)); + GE_CHECK_NOTNULL(tensor); + return SUCCESS; +} + +Status HybridModelBuilder::HandleDtString(const GeTensor &tensor, void *var_addr) { + auto desc = tensor.GetTensorDesc(); + if (desc.GetDataType() == DT_STRING) { + GeShape tensor_shape = desc.GetShape(); + /// if tensor is a scaler, it's shape size if zero, according ge_tensor.cc. + /// the logic of GetShapeSize is wrong, the scaler tensor's GetShapeSize is zero + /// and that of unknown shape is zero too. + /// unknown shape will not appear here, so we can use zero judge a tensor is scalar or not + int64_t elem_num = tensor_shape.GetShapeSize(); + if (elem_num == 0 && tensor_shape.GetDims().empty()) { + elem_num = 1; + } + + auto &mutable_tensor = const_cast(tensor); + uint64_t *buff = reinterpret_cast(mutable_tensor.MutableData().data()); + GE_CHK_BOOL_RET_STATUS(ge::CheckInt64Uint32MulOverflow(elem_num, kBytes) == SUCCESS, FAILED, + "Shape size is invalid"); + auto offset = static_cast(elem_num * kBytes); + auto hbm_raw_data_base_addr = reinterpret_cast(reinterpret_cast(var_addr) + offset); + for (int64_t i = elem_num - 1; i >= 0; --i) { + buff[i] = hbm_raw_data_base_addr + (buff[i] - buff[0]); + } + } + + return SUCCESS; +} + +Status HybridModelBuilder::InitConstantOps() { + for (auto &it : hybrid_model_.constant_op_nodes_) { + string var_name = it.first; + NodePtr &var_node = it.second; + std::unique_ptr var_tensor; + + GE_CHK_STATUS_RET_NOLOG(VarNodeToTensor(var_node, var_tensor)); + GELOGD("Init const op tensor. 
name = %s, size = %ld", var_name.c_str(), var_tensor->GetSize()); + var_tensor->SetName("ConstOp_" + var_name); + + auto op_desc = var_node->GetOpDesc(); + auto v_weights = ModelUtils::GetWeights(op_desc); + auto v_output_size = var_tensor->GetSize(); + auto v_output_addr = var_tensor->MutableData(); + + auto *ge_tensor = const_cast(v_weights[0].get()); + if (ge_tensor->GetData().size() > 0) { + GE_CHK_STATUS_RET_NOLOG(HandleDtString(*ge_tensor, v_output_addr)); + + GELOGI("[IMAS]InitConstant memcpy graph_%u type[V] name[%s] output[%d] memaddr[%p] mem_size[%u] datasize[%zu]", + runtime_param_.graph_id, op_desc->GetName().c_str(), 0, v_output_addr, v_output_size, + ge_tensor->GetData().size()); + GE_CHK_RT_RET(rtMemcpy(v_output_addr, v_output_size, ge_tensor->GetData().data(), ge_tensor->GetData().size(), + RT_MEMCPY_HOST_TO_DEVICE)); + } else { + GELOGI("[%s] Const op has no weight data.", op_desc->GetName().c_str()); + } + + hybrid_model_.variable_tensors_.emplace(var_name, std::move(var_tensor)); + } + + return SUCCESS; +} + +Status HybridModelBuilder::InitVariableTensors() { + for (auto &it : hybrid_model_.variable_nodes_) { + string var_name = it.first; + NodePtr &var_node = it.second; + std::unique_ptr tensor; + GE_CHK_STATUS_RET_NOLOG(VarNodeToTensor(var_node, tensor)); + GELOGD("Init variable tensor. name = %s, size = %ld, addr = %p", var_name.c_str(), tensor->GetSize(), + tensor->GetData()); + tensor->SetName("Var_" + var_name); + hybrid_model_.variable_tensors_.emplace(var_name, std::move(tensor)); + } + + return SUCCESS; +} + +Status HybridModelBuilder::InitWeights() { + // Train do not have weight. (only got ConstOp) + return SUCCESS; +} + +Status HybridModelBuilder::LoadTasks() { + for (auto &node_item : hybrid_model_.node_items_) { + auto &node_ptr = node_item->node; + if (node_item->node_type == NETOUTPUT) { + continue; + } + + GELOGD("[%s] Start to build kernel task", node_ptr->GetName().c_str()); + auto load_ret = node_item->node_executor->LoadTask(hybrid_model_, node_ptr, node_item->kernel_task); + + if (load_ret != UNSUPPORTED && load_ret != SUCCESS) { + GELOGE(load_ret, "[%s] Failed to load task", node_ptr->GetName().c_str()); + return load_ret; + } + + GELOGD("[%s] Done loading task successfully.", node_ptr->GetName().c_str()); + } + + return SUCCESS; +} + +Status HybridModelBuilder::IndexTaskDefs() { + const auto &root_graph = ge_root_model_->GetRootGraph(); + for (auto &it : ge_root_model_->GetSubgraphInstanceNameToModel()) { + auto &name = it.first; + auto &ge_model = it.second; + GE_CHECK_NOTNULL(ge_model); + + const auto &sub_graph = root_graph->GetSubgraph(name); + if (sub_graph == nullptr) { + continue; + } + + bool is_unknown_shape = false; + GE_CHK_GRAPH_STATUS_RET(NodeUtils::GetNodeUnknownShapeStatus(*sub_graph->GetParentNode(), is_unknown_shape), + "Failed to invoke GetNodeUnknownShapeStatus."); + if (!is_unknown_shape) { + GELOGD("Set ge_model for subgraph: %s", sub_graph->GetName().c_str()); + hybrid_model_.known_shape_sub_graphs_.emplace(sub_graph->GetParentNode(), ge_model); + continue; + } + + // index task defs + GELOGD("To index tasks for subgraph: %s", name.c_str()); + unordered_map node_map; + for (const auto &node : sub_graph->GetDirectNode()) { + GE_CHECK_NOTNULL(node); + GE_CHECK_NOTNULL(node->GetOpDesc()); + auto node_id = node->GetOpDesc()->GetId(); + GELOGD("op_index = %ld, node_name = %s", node_id, node->GetName().c_str()); + node_map.emplace(node_id, node); + } + + auto tasks = ge_model->GetModelTaskDefPtr()->task(); + for (int i = 0; i < 
tasks.size(); ++i) { + const domi::TaskDef &task_def = tasks[i]; + GELOGI("Task id = %d, task type = %d", i, task_def.type()); + auto task_type = static_cast(task_def.type()); + uint32_t op_index = -1; + if (task_type == RT_MODEL_TASK_KERNEL) { + op_index = task_def.kernel().context().op_index(); + } else if (task_type == RT_MODEL_TASK_KERNEL_EX) { + op_index = task_def.kernel_ex().op_index(); + } else { + GELOGD("Skip task type: %d", static_cast(task_type)); + continue; + } + + auto iter = node_map.find(op_index); + if (iter == node_map.end()) { + GELOGE(INTERNAL_ERROR, "Failed to get node by index = %u", op_index); + return INTERNAL_ERROR; + } + + auto &node = iter->second; + if (task_type == RT_MODEL_TASK_KERNEL) { + ge_model->GetTBEKernelStore().LoadTBEKernelBinToOpDesc(node->GetOpDesc()); + } + + GELOGD("Task loaded for node: %s, task type = %d, op_index = %u", node->GetName().c_str(), task_type, op_index); + hybrid_model_.task_defs_[node].emplace_back(task_def); + } + } + + return SUCCESS; +} + +Status HybridModelBuilder::IndexSpecialNodes() { + GELOGD("Start to index special nodes"); + const auto &root_graph = ge_root_model_->GetRootGraph(); + for (auto &node : root_graph->GetAllNodes()) { + GE_CHECK_NOTNULL(node); + GE_CHECK_NOTNULL(node->GetOpDesc()); + auto op_type = node->GetType(); + if (op_type == VARIABLE) { + hybrid_model_.variable_nodes_.emplace(node->GetName(), node); + } else if (op_type == CONSTANTOP) { + hybrid_model_.constant_op_nodes_.emplace(node->GetName(), node); + } else if (op_type == DATA && node->GetOwnerComputeGraph() != root_graph) { + NodePtr src_node; + int peer_out_index = -1; + GE_CHK_STATUS_RET_NOLOG(GetPeerNodeAcrossSubGraphs(node, src_node, peer_out_index)); + GELOGD("Got peer node for data node %s, peer node = %s(%s)", node->GetName().c_str(), src_node->GetName().c_str(), + src_node->GetType().c_str()); + + auto src_op_type = src_node->GetType(); + if (src_op_type == CONSTANTOP || src_op_type == VARIABLE) { + for (auto &dst_node_and_in_anchor : node->GetOutDataNodesAndAnchors()) { + auto &dst_node = dst_node_and_in_anchor.first; + auto &in_anchor = dst_node_and_in_anchor.second; + node_ref_inputs_[dst_node].emplace_back(std::make_pair(in_anchor->GetIdx(), src_node)); + } + } + } + } + + return SUCCESS; +} + +Status HybridModelBuilder::GetPeerNodeAcrossSubGraphs(const NodePtr &data_node, NodePtr &peer_node, + int &peer_out_index) { + auto sub_graph = data_node->GetOwnerComputeGraph(); + GE_CHECK_NOTNULL(sub_graph); + GELOGD("To get peer node of %s::%s", sub_graph->GetName().c_str(), data_node->GetName().c_str()); + auto wrapped_node = data_node->GetOwnerComputeGraph()->GetParentNode(); + if (wrapped_node == nullptr) { + GELOGE(INTERNAL_ERROR, "[%s] Node is in root graph.", data_node->GetName().c_str()); + return INTERNAL_ERROR; + } + auto data_op_desc = data_node->GetOpDesc(); + uint32_t parent_index = 0; + if (!AttrUtils::GetInt(data_op_desc, ATTR_NAME_PARENT_NODE_INDEX, parent_index)) { + GELOGE(INTERNAL_ERROR, "[%s] Failed to get attr [%s]", data_op_desc->GetName().c_str(), + ATTR_NAME_PARENT_NODE_INDEX.c_str()); + return INTERNAL_ERROR; + } + + auto wrapped_node_in_anchor = wrapped_node->GetInDataAnchor(parent_index); + GE_CHECK_NOTNULL(wrapped_node_in_anchor); + auto src_out_anchor = wrapped_node_in_anchor->GetPeerOutAnchor(); + if (src_out_anchor == nullptr || src_out_anchor->GetOwnerNode() == nullptr) { + GELOGE(INTERNAL_ERROR, "[%s] Parent node do not have peer anchor.", data_node->GetName().c_str()); + return INTERNAL_ERROR; + } + + auto 
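IndexTaskDefs matches each kernel task back to its owning node through op_index: every OpDesc in the subgraph was assigned an id, so a per-subgraph map from id to NodePtr resolves task ownership. The lookup, distilled:

```cpp
// Distilled op_index -> node resolution from IndexTaskDefs.
std::unordered_map<int64_t, NodePtr> node_map;
for (const auto &node : sub_graph->GetDirectNode()) {
  node_map.emplace(node->GetOpDesc()->GetId(), node);
}

auto iter = node_map.find(op_index);  // op_index comes from the TaskDef
if (iter != node_map.end()) {
  hybrid_model_.task_defs_[iter->second].emplace_back(task_def);
}
```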
src_wrapped_node_out_anchor = wrapped_node_in_anchor->GetPeerOutAnchor(); + GE_CHECK_NOTNULL(src_wrapped_node_out_anchor); + auto src_wrapped_node = src_wrapped_node_out_anchor->GetOwnerNode(); + GE_CHECK_NOTNULL(src_wrapped_node); + + // connected to root-graph's DATA + auto src_node_type = src_wrapped_node->GetType(); + if (src_node_type != PARTITIONEDCALL) { + peer_node = src_wrapped_node; + peer_out_index = kVarOutputIndex; + GELOGD("[%s] Node is connected to root graph's node: %s", data_node->GetName().c_str(), + peer_node->GetName().c_str()); + return SUCCESS; + } + + auto src_graph = NodeUtils::GetSubgraph(*src_wrapped_node, kSubgraphIndex); + GE_CHECK_NOTNULL(src_graph); + auto src_net_output_node = src_graph->FindNode(NODE_NAME_NET_OUTPUT); + GE_CHK_BOOL_TRUE_EXEC_WITH_LOG(src_net_output_node == nullptr, return INTERNAL_ERROR, + "Failed to find NetOutput in subgraph: %s", src_graph->GetName().c_str()); + auto net_output_desc = src_net_output_node->GetOpDesc(); + GE_CHECK_NOTNULL(net_output_desc); + + auto out_index = static_cast(src_wrapped_node_out_anchor->GetIdx()); + GELOGD("src graph = %s, src parent output index = %d", src_graph->GetName().c_str(), out_index); + + // link src to outputs of DataNode + auto input_size = net_output_desc->GetAllInputsSize(); + GE_CHECK_LE(input_size, UINT32_MAX); + for (uint32_t i = 0; i < static_cast(input_size); ++i) { + uint32_t p_index = 0; + if (!AttrUtils::GetInt(net_output_desc->GetInputDesc(i), ATTR_NAME_PARENT_NODE_INDEX, p_index)) { + GELOGW("SubGraph: %s input tensor %zu attr %s not found.", src_graph->GetName().c_str(), i, + ATTR_NAME_PARENT_NODE_INDEX.c_str()); + continue; + } + + GELOGD("NetOutput's input[%zu], parent_node_index = %u", i, p_index); + if (p_index == out_index) { + auto in_anchor = src_net_output_node->GetInDataAnchor(i); + GE_CHECK_NOTNULL(in_anchor); + auto peer_out_anchor = in_anchor->GetPeerOutAnchor(); + GE_CHECK_NOTNULL(peer_out_anchor); + peer_node = peer_out_anchor->GetOwnerNode(); + GE_CHECK_NOTNULL(peer_node); + peer_out_index = peer_out_anchor->GetIdx(); + GELOGD("Found peer node of Data node: %s::%s is %s::%s", sub_graph->GetName().c_str(), + data_node->GetName().c_str(), src_graph->GetName().c_str(), peer_node->GetName().c_str()); + return SUCCESS; + } + } + + GELOGE(FAILED, "Failed to find peer node for %s::%s", sub_graph->GetName().c_str(), data_node->GetName().c_str()); + return FAILED; +} +Status HybridModelBuilder::InitRuntimeParams() { + int64_t value = 0; + bool ret = false; + if (ge_root_model_->GetSubgraphInstanceNameToModel().empty()) { + GELOGE(INTERNAL_ERROR, "Root model has no sub model"); + return INTERNAL_ERROR; + } + + // session id and var size is same for every model + auto first_model = ge_root_model_->GetSubgraphInstanceNameToModel().begin()->second; + ret = ge::AttrUtils::GetInt(first_model, ge::MODEL_ATTR_SESSION_ID, value); + runtime_param_.session_id = ret ? (uint64_t)value : 0; + ret = ge::AttrUtils::GetInt(first_model, ATTR_MODEL_TASK_GEN_VAR_ADDR, value); + runtime_param_.logic_var_base = ret ? (uint64_t)value : 0; + ret = ge::AttrUtils::GetInt(first_model, ATTR_MODEL_VAR_SIZE, value); + runtime_param_.var_size = ret ? (uint64_t)value : 0; + runtime_param_.graph_id = ge_root_model_->GetRootGraph()->GetGraphID(); + GELOGI("InitRuntimeParams(), session_id:%u, var_size:%lu. 
graph_id = %u", runtime_param_.session_id, + runtime_param_.var_size, runtime_param_.graph_id); + + var_manager_ = VarManager::Instance(runtime_param_.session_id); + GE_CHECK_NOTNULL(var_manager_); + return SUCCESS; +} + +Status HybridModelBuilder::ParsePartitionedCall(NodeItem &node_item) { + GELOGD("Start to parse outputs of node: %s", node_item.NodeName().c_str()); + auto subgraph = NodeUtils::GetSubgraph(*node_item.node, kSubgraphIndex); + GE_CHECK_NOTNULL(subgraph); + auto net_output_node = subgraph->FindNode(NODE_NAME_NET_OUTPUT); + GE_CHECK_NOTNULL(net_output_node); + auto net_output_desc = net_output_node->GetOpDesc(); + GE_CHECK_NOTNULL(net_output_desc); + + for (const auto &in_data_anchor : net_output_node->GetAllInDataAnchors()) { + auto src_node = GetPeerNode(in_data_anchor); + GE_CHECK_NOTNULL(src_node); + auto src_op_type = src_node->GetType(); + GELOGD("Node %s, output %d, src node = %s, src node type = %s", node_item.NodeName().c_str(), + in_data_anchor->GetIdx(), src_node->GetName().c_str(), src_op_type.c_str()); + + if (src_op_type != CONSTANTOP && src_op_type != VARIABLE) { + continue; + } + + uint32_t parent_index = 0; + GE_CHK_STATUS_RET_NOLOG(GetParentNodeOutputIndex(*net_output_desc, in_data_anchor->GetIdx(), parent_index)); + GELOGD("Got parent output index = %u", parent_index); + node_item.ref_outputs.emplace(parent_index, src_node); + } + + for (auto &node : subgraph->GetDirectNode()) { + if (node->GetType() != DATA) { + continue; + } + + string ref_var_name; + (void)AttrUtils::GetStr(node->GetOpDesc(), REF_VAR_SRC_VAR_NAME, ref_var_name); + if (ref_var_name.empty()) { + continue; + } + + GELOGD("Data node ref to variable: %s", ref_var_name.c_str()); + NodePtr src_node; + auto var_node = hybrid_model_.GetVariableNode(ref_var_name); + GE_CHECK_NOTNULL(var_node); + GELOGD("Found var node [%s] by ref_var_name [%s]", var_node->GetName().c_str(), ref_var_name.c_str()); + int peer_output_index = -1; + GE_CHK_STATUS_RET_NOLOG(GetPeerNodeAcrossSubGraphs(node, src_node, peer_output_index)); + auto src_node_item = MutableNodeItem(src_node); + GE_CHECK_NOTNULL(src_node_item); + src_node_item->ref_outputs.emplace(peer_output_index, var_node); + } + + return SUCCESS; +} + +NodePtr HybridModelBuilder::GetPeerNode(const InDataAnchorPtr &in_data_anchor) { + auto peer_out_anchor = in_data_anchor->GetPeerOutAnchor(); + if (peer_out_anchor != nullptr) { + return peer_out_anchor->GetOwnerNode(); + } + + return nullptr; +} + +Status HybridModelBuilder::GetParentNodeOutputIndex(const OpDesc &op_desc, int index, uint32_t &out_index) { + auto input_desc = op_desc.MutableInputDesc(index); + GE_CHECK_NOTNULL(input_desc); + if (!AttrUtils::GetInt(input_desc, ATTR_NAME_PARENT_NODE_INDEX, out_index)) { + GELOGE(INTERNAL_ERROR, "NetOutput input tensor %d, attr %s not found.", index, ATTR_NAME_PARENT_NODE_INDEX.c_str()); + return INTERNAL_ERROR; + } + return SUCCESS; +} + +Status HybridModelBuilder::InitModelMem() { + hybrid_model_.var_mem_base_ = var_manager_->GetVarMemoryBase(RT_MEMORY_HBM); + auto total_var_size = hybrid_model_.TotalVarMemSize(); + if (total_var_size > 0 && hybrid_model_.var_mem_base_ == nullptr) { + GE_CHK_STATUS_RET(var_manager_->MallocVarMemory(total_var_size), "Malloc Var Memory Fail."); + hybrid_model_.var_mem_base_ = var_manager_->GetVarMemoryBase(RT_MEMORY_HBM); + } + + runtime_param_.var_base = hybrid_model_.var_mem_base_; + return SUCCESS; +} + +Status HybridModelBuilder::TransAllVarData() { + GELOGI("TransAllVarData start: session_id:%lu, graph_id: %u.", 
runtime_param_.session_id, runtime_param_.graph_id); + rtContext_t ctx = nullptr; + rtError_t rt_ret = rtCtxGetCurrent(&ctx); + if (rt_ret != RT_ERROR_NONE) { + GELOGE(RT_FAILED, "Failed to get current context, error_code is: 0x%X.", rt_ret); + return RT_FAILED; + } + + std::vector variable_node_list; + for (auto &it : hybrid_model_.variable_nodes_) { + variable_node_list.emplace_back(it.second); + GELOGD("[%s] added for trans var data", it.first.c_str()); + } + + GE_CHK_STATUS_RET( + TransVarDataUtils::TransAllVarData(variable_node_list, runtime_param_.session_id, ctx, runtime_param_.graph_id), + "TransAllVarData failed."); + + GELOGI("TransAllVarData success."); + return SUCCESS; +} + +Status HybridModelBuilder::CopyVarData() { + GE_CHK_STATUS_RET( + TransVarDataUtils::CopyVarData(ge_root_model_->GetRootGraph(), runtime_param_.session_id, hybrid_model_.device_id_), + "CopyVarData failed."); + GELOGI("CopyVarData success."); + return SUCCESS; +} +} // namespace hybrid +} // namespace ge \ No newline at end of file diff --git a/src/ge/hybrid/model/hybrid_model_builder.h b/src/ge/hybrid/model/hybrid_model_builder.h new file mode 100644 index 00000000..33cd1f03 --- /dev/null +++ b/src/ge/hybrid/model/hybrid_model_builder.h @@ -0,0 +1,87 @@ +/** + * Copyright 2019-2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +#ifndef GE_HYBRID_MODEL_HYBRID_MODEL_BUILDER_H_ +#define GE_HYBRID_MODEL_HYBRID_MODEL_BUILDER_H_ + +#include +#include +#include +#include "framework/common/ge_inner_error_codes.h" +#include "graph/load/new_model_manager/task_info/task_info.h" +#include "graph/node.h" +#include "hybrid/model/hybrid_model.h" +#include "hybrid/model/node_item.h" +#include "model/ge_model.h" + +namespace ge { +class VarManager; +namespace hybrid { +class HybridModelBuilder { + public: + explicit HybridModelBuilder(HybridModel &hybrid_model); + ~HybridModelBuilder() = default; + Status Build(); + + private: + static Status UpdateAnchorStatus(const NodePtr &node); + static Status DoUnlinkDataAnchors(const OutDataAnchorPtr &out_data_anchor, const InDataAnchorPtr &in_data_anchor); + static Status DoLinkDataAnchors(OutDataAnchorPtr &out_data_anchor, InDataAnchorPtr &in_data_anchor); + static NodePtr GetPeerNode(const InDataAnchorPtr &in_data_anchor); + static Status GetParentNodeOutputIndex(const OpDesc &op_desc, int index, uint32_t &out_index); + static Status GetPeerNodeAcrossSubGraphs(const NodePtr &data_node, NodePtr &peer_node, int &peer_out_index); + static Status HandleDtString(const GeTensor &tensor, void *var_addr); + static Status MergeInputNodes(ComputeGraph &compute_graph); + static Status MergeNetOutputNode(ComputeGraph &compute_graph); + static Status MergeSubgraphs(ComputeGraph &root_graph, ComputeGraphPtr &merged_graph); + static Status InitWeights(); + + Status ValidateParams(); + Status LoadGraph(); + Status LoadTasks(); + Status ParsePartitionedCall(NodeItem &node_item); + Status ParseNetOutput(const NodeItem &node_item); + Status BuildNoteItem(const NodePtr &node, NodeItem &node_item); + Status GetOrCreateNodeItem(const NodePtr &node, NodeItem **node_item); + Status ParseDependentInputNodes(NodeItem &node_item, const std::vector &dependencies); + Status ResolveRootNodes(); + Status IndexTaskDefs(); + Status IndexSpecialNodes(); + Status InitRuntimeParams(); + Status InitModelMem(); + Status TransAllVarData(); + Status CopyVarData(); + Status VarNodeToTensor(const NodePtr &var_node, std::unique_ptr &tensor); + Status InitConstantOps(); + Status InitVariableTensors(); + + const char *GetGraphName() const { return graph_name_.c_str(); } + + const NodeItem *GetNodeItem(const NodePtr &node) const; + NodeItem *MutableNodeItem(const NodePtr &node); + + GeRootModelPtr ge_root_model_; + std::string graph_name_; + std::map> weights_; + HybridModel &hybrid_model_; + std::map>> node_ref_inputs_; + + RuntimeParam &runtime_param_; + VarManager *var_manager_ = nullptr; +}; +} // namespace hybrid +} // namespace ge +#endif // GE_HYBRID_MODEL_HYBRID_MODEL_BUILDER_H_ diff --git a/src/ge/hybrid/model/node_item.cc b/src/ge/hybrid/model/node_item.cc new file mode 100644 index 00000000..b5d4fbda --- /dev/null +++ b/src/ge/hybrid/model/node_item.cc @@ -0,0 +1,59 @@ +/** + * Copyright 2019-2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +#include "node_item.h" +#include + +namespace ge { +namespace hybrid { +NodeItem::NodeItem(NodePtr node) : node(std::move(node)) { + this->op_desc = this->node->GetOpDesc().get(); + this->node_id = this->op_desc->GetId(); + this->num_inputs = this->op_desc->GetInputsSize(); + this->num_outputs = this->op_desc->GetOutputsSize(); + this->node_name = this->node->GetName(); + this->node_type = this->node->GetType(); +} + +std::string NodeItem::DebugString() const { + std::stringstream ss; + ss << "Node: "; + ss << "id = " << node_id; + ss << ", name = " << node->GetName(); + ss << ", type = " << node->GetType(); + ss << ", is_dynamic = " << (is_dynamic ? "True" : "False"); + ss << ", unknown_shape_op_type = " << shape_inference_type; + ss << ", input_start = " << input_start; + ss << ", num_inputs = " << num_inputs; + ss << ", output_start = " << output_start; + ss << ", num_outputs = " << num_outputs; + ss << ", dependent_nodes = ["; + for (const auto &dep_node : dependent_node_list) { + ss << dep_node->GetName() << ", "; + } + ss << "]"; + int index = 0; + for (auto &items : outputs) { + ss << ", output[" << index++ << "]: "; + for (auto &item : items) { + ss << "(" << item.second->NodeName() << ":" << item.first << "), "; + } + } + + return ss.str(); +} +} // namespace hybrid +} // namespace ge \ No newline at end of file diff --git a/src/ge/hybrid/model/node_item.h b/src/ge/hybrid/model/node_item.h new file mode 100644 index 00000000..b12d100b --- /dev/null +++ b/src/ge/hybrid/model/node_item.h @@ -0,0 +1,70 @@ +/** + * Copyright 2019-2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +#ifndef GE_HYBRID_MODEL_NODE_ITEM_H_ +#define GE_HYBRID_MODEL_NODE_ITEM_H_ + +#include +#include "graph/node.h" +#include "graph/op_desc.h" +#include "framework/common/types.h" +#include "hybrid/common/tensor_value.h" + +namespace ge { +namespace hybrid { +class NodeTask; +class NodeExecutor; + +// for caching static information across execution +struct NodeItem { + explicit NodeItem(NodePtr node); + ~NodeItem() = default; + + const std::string &NodeName() const { return node_name; } + + const std::string &NodeType() const { return node_type; } + + std::string DebugString() const; + + NodePtr node; + OpDesc *op_desc; + int node_id; + int num_inputs; + int num_outputs; + + int input_start = -1; + int output_start = -1; + bool is_dynamic = false; + bool has_observer = false; + UnknowShapeOpType shape_inference_type = DEPEND_IN_SHAPE; + std::string node_name; + std::string node_type; + std::vector dependent_node_list; + std::set to_const_output_id_list; + + // src_output_id, dst_anchor_id, dst_node + vector inputs; + vector>> outputs; + + std::shared_ptr kernel_task; + const NodeExecutor *node_executor = nullptr; + std::map const_input_shapes; + std::map ref_outputs; +}; +} // namespace hybrid +} // namespace ge + +#endif // GE_HYBRID_MODEL_NODE_ITEM_H_ diff --git a/src/ge/hybrid/node_executor/aicore/aicore_node_executor.cc b/src/ge/hybrid/node_executor/aicore/aicore_node_executor.cc new file mode 100644 index 00000000..3f198bba --- /dev/null +++ b/src/ge/hybrid/node_executor/aicore/aicore_node_executor.cc @@ -0,0 +1,310 @@ +/** + * Copyright 2019-2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +#include "aicore_node_executor.h" +#include "cce/taskdown_common.hpp" +#include "graph/debug/ge_attr_define.h" +#include "hybrid/model/hybrid_model.h" +#include "init/gelib.h" +#include "framework/common/debug/log.h" + +namespace ge { +namespace hybrid { +REGISTER_NODE_EXECUTOR_BUILDER(NodeExecutorManager::ExecutorType::AICORE, AiCoreNodeExecutor); + +AiCoreNodeTask::AiCoreNodeTask(std::vector> &&tasks) : tasks_(std::move(tasks)) {} + +Status AiCoreNodeExecutor::LoadTask(const HybridModel &model, const NodePtr &node, shared_ptr &task) const { + GE_CHECK_NOTNULL(node); + GELOGI("AiCoreNodeExecutor[%s] LoadTask Start.", node->GetName().c_str()); + + auto *task_defs = model.GetTaskDefs(node); + Status ret = SUCCESS; + GE_IF_BOOL_EXEC(task_defs != nullptr && !task_defs->empty(), ret = CreateTask(model, *task_defs, node, task)); + + GELOGI("AiCoreNodeExecutor[%s] LoadTask End, ret[%u].", node->GetName().c_str(), ret); + return ret; +} + +Status AiCoreNodeExecutor::GenNodeKey(const NodePtr &node, std::string &node_key) { + GE_CHECK_NOTNULL(node); + auto op_desc = node->GetOpDesc(); + GE_CHECK_NOTNULL(op_desc); + + // make sure unique, (op_id + input_shape) is unique + node_key = std::to_string(op_desc->GetId()) + "/"; + node_key.append(std::to_string(op_desc->GetInputsSize())); + auto input_descs = op_desc->GetAllInputsDesc(); + for (auto input_desc : input_descs) { + node_key.push_back('/'); + std::vector dims = input_desc.GetShape().GetDims(); + GE_IF_BOOL_EXEC(dims.size() == 0, continue); // scalar + for (std::size_t i = 0; i < dims.size() - 1; i++) { + node_key.append(std::to_string(dims[i])); + node_key.push_back(','); + } + node_key.append(std::to_string(dims[dims.size() - 1])); + } + return SUCCESS; +} + +bool AiCoreNodeTaskRegistry::AddTask(const std::string &node_key, const std::shared_ptr task) { + GE_CHECK_NOTNULL(task); + std::lock_guard lock(mutex_); + auto iter = reg_node_tasks_.find(node_key); + GE_CHK_BOOL_TRUE_EXEC_WITH_LOG(iter != reg_node_tasks_.end(), return false, + "AiCoreNodeTaskRegistry[%s] AddTask failed, key already exist.", node_key.c_str()); + auto ret = reg_node_tasks_.emplace(node_key, task); + return ret.second; +} + +std::shared_ptr AiCoreNodeTaskRegistry::GetTask(const std::string &node_key) { + std::lock_guard lock(mutex_); + auto iter = reg_node_tasks_.find(node_key); + return (iter != reg_node_tasks_.end()) ? iter->second : nullptr; +} + +Status AiCoreNodeExecutor::CompileTask(const HybridModel &model, const NodePtr &node, + shared_ptr &task) const { + GE_CHECK_NOTNULL(node); + GELOGI("AiCoreNodeExecutor[%s] CompileTask Start.", node->GetName().c_str()); + + AiCoreNodeTaskRegistry ®istry = AiCoreNodeTaskRegistry::GetInstance(); + std::string node_key; + GE_CHK_STATUS_RET(GenNodeKey(node, node_key), "GenNodeKey failed. op name = %s", node->GetName().c_str()); + + GELOGD("NodeKey for %s = %s", node->GetName().c_str(), node_key.c_str()); + task = registry.GetTask(node_key); + GE_CHK_TRUE_EXEC_INFO(task != nullptr, return SUCCESS, "AiCoreNodeExecutor[%s] CompileTask Skip.", + node->GetName().c_str()); + + std::vector task_defs; + GE_CHK_STATUS_RET_NOLOG(compiler_->CompileOp(node, task_defs)); + GELOGD("successfully generated task_defs: %s", node->GetName().c_str()); + + GE_CHK_STATUS_RET_NOLOG(CreateTask(model, task_defs, node, task)); + GELOGD("successfully created node task: %s", node->GetName().c_str()); + + GE_CHK_BOOL_EXEC(registry.AddTask(node_key, task), return INTERNAL_ERROR, "Add NodeTask failed. 
op name = %s", + node->GetName().c_str()); // should not happen. + GELOGI("AiCoreNodeExecutor[%s] CompileTask End.", node->GetName().c_str()); + return SUCCESS; +} + +Status AiCoreNodeExecutor::BuildAiCoreTask(const domi::KernelDef &kernel_def, const OpDescPtr &op_desc, + AiCoreOpTask **task) { + GE_CHECK_NOTNULL(op_desc); + GE_CHECK_NOTNULL(task); + + const auto &context = kernel_def.context(); + auto kernel_type = static_cast(context.kernel_type()); + GE_CHK_BOOL_TRUE_EXEC_WITH_LOG(kernel_type != cce::ccKernelType::TE, return UNSUPPORTED, + "Only TBE kernel is supported, but [%s] got %u", op_desc->GetName().c_str(), + context.kernel_type()); + + auto *aicore_task = new (std::nothrow) AiCoreOpTask(); + GE_CHK_BOOL_TRUE_EXEC_WITH_LOG(aicore_task == nullptr, return MEMALLOC_FAILED, "Create AiCore op task failed."); + + auto builder = AiCoreTaskBuilder(op_desc, kernel_def); + auto ret = builder.BuildTask(*aicore_task); + GE_IF_BOOL_EXEC(ret != SUCCESS, delete aicore_task; aicore_task = nullptr; return ret); + + *task = aicore_task; + return SUCCESS; +} + +Status AiCoreNodeExecutor::CreateTask(const HybridModel &model, const std::vector &task_defs, + const NodePtr &node, std::shared_ptr &task) { + GE_CHECK_NOTNULL(node); + GELOGD("To CreateTask, task def size = %zu", task_defs.size()); + std::vector> aicore_op_tasks; + aicore_op_tasks.reserve(task_defs.size()); + for (size_t i = 0; i < task_defs.size(); ++i) { + const domi::TaskDef &task_def = task_defs[i]; + GELOGD("Op[%s] Task[%d], type = %u, DebugString = %s", node->GetName().c_str(), i, task_def.type(), + task_def.DebugString().c_str()); + auto task_type = static_cast(task_def.type()); + GE_CHK_BOOL_TRUE_EXEC_WITH_LOG(task_type == RT_MODEL_TASK_KERNEL_EX, return UNSUPPORTED, + "BuildKernelExTask is not supported"); + GE_CHK_BOOL_TRUE_EXEC_INFO(task_type != RT_MODEL_TASK_KERNEL, continue, "Skip task type %d", + static_cast(task_type)); + + const domi::KernelDef &kernel_def = task_def.kernel(); + AiCoreOpTask *aicore_op_task = nullptr; + // not use hybrid model now + GE_CHK_STATUS_RET_NOLOG(BuildAiCoreTask(kernel_def, node->GetOpDesc(), &aicore_op_task)); + GE_CHK_BOOL_TRUE_EXEC_WITH_LOG(aicore_op_task == nullptr, return FAILED, "BuildAiCoreTask[%s] failed.", + node->GetName().c_str()); + + aicore_op_tasks.emplace_back(std::unique_ptr(aicore_op_task)); + } + + if (!aicore_op_tasks.empty()) { + auto aic_task = std::shared_ptr(new AiCoreNodeTask(std::move(aicore_op_tasks))); + task = std::move(aic_task); + GELOGD("Generate AiCoreOpTask success"); + return SUCCESS; + } + + GELOGE(INTERNAL_ERROR, "Failed to build task. 
node = %s", node->GetName().c_str()); + return INTERNAL_ERROR; +} + +Status AiCoreNodeExecutor::Initialize() { + std::shared_ptr ge_lib = GELib::GetInstance(); + GE_CHK_BOOL_TRUE_EXEC_WITH_LOG((ge_lib == nullptr) || !ge_lib->InitFlag(), return GE_CLI_GE_NOT_INITIALIZED, + "Get ge_lib failed."); + + auto &kernel_manager = ge_lib->OpsKernelManagerObj(); + auto aic_ops_store = kernel_manager.GetOpsKernelInfoStore("AIcoreEngine"); + GE_CHK_BOOL_TRUE_EXEC_WITH_LOG(aic_ops_store == nullptr, return GE_CLI_GE_NOT_INITIALIZED, + "Failed to get kernel info store for AIcoreEngine."); + + compiler_.reset(new (std::nothrow) AiCoreTaskCompiler(aic_ops_store)); + GE_CHECK_NOTNULL(compiler_); + return SUCCESS; +} + +Status AiCoreNodeExecutor::Finalize() { return NodeExecutor::Finalize(); } + +Status AiCoreNodeTask::ExecuteAsync(TaskContext &context, std::function done_callback) { + auto op_desc = context.GetNodeItem().op_desc; + GE_CHECK_NOTNULL(op_desc); + GELOGI("AiCoreNodeTask[%s] ExecuteAsync Start.", op_desc->GetName().c_str()); + for (size_t i = 0; i < tasks_.size(); i++) { + GE_CHECK_NOTNULL(tasks_[i]); + GE_CHK_STATUS_RET_NOLOG(tasks_[i]->LaunchKernel(context.GetStream())); + } + + if (done_callback != nullptr) { + GE_CHK_STATUS_RET_NOLOG(context.RegisterCallback(done_callback)); + } + + GELOGI("AiCoreNodeTask[%s] ExecuteAsync End.", op_desc->GetName().c_str()); + return SUCCESS; +} + +Status AiCoreNodeTask::UpdateAtomicArgs(TaskContext &context, std::unique_ptr &task) { + GE_CHECK_NOTNULL(task); + auto op_desc = context.GetNodeItem().op_desc; + GE_CHECK_NOTNULL(op_desc); + + // refresh atomic output addr + std::vector atomic_output_indexes; // here atomic just clean output + (void)ge::AttrUtils::GetListInt(op_desc, ge::ATOMIC_ATTR_OUTPUT_INDEX, atomic_output_indexes); + GE_RETURN_WITH_LOG_IF_TRUE(atomic_output_indexes.size() > static_cast(context.NumOutputs()), + "AtomicAddrClean op's arg_size error."); + auto *arg_off = reinterpret_cast(task->args_.get()) + task->offset_; + auto *arg_base = reinterpret_cast(arg_off); + int index = 0; + for (size_t i = 0; i < atomic_output_indexes.size(); ++i) { + const auto output = context.GetOutput(atomic_output_indexes[i]); + GE_CHECK_NOTNULL(output); + arg_base[index++] = reinterpret_cast(output->GetData()); + } + + // refresh atomic workspace addr + auto workspace_sizes = op_desc->GetWorkspaceBytes(); + uint64_t ops_workspace_num = static_cast(workspace_sizes.size()); + uint64_t workspace_num = static_cast(context.NumWorkspaces()); + GE_CHK_BOOL_EXEC(ops_workspace_num == workspace_num, return PARAM_INVALID, + "The workspace_num in op_desc %lu is not equal to it %lu in context.", ops_workspace_num, + workspace_num); + GE_IF_BOOL_EXEC(workspace_num == 0, return SUCCESS); + + map> workspace_info; + workspace_info = op_desc->TryGetExtAttr(EXT_ATTR_ATOMIC_WORKSPACE_INFO, workspace_info); + if (!workspace_info.empty()) { + bool is_fusion_node = false; + (void)ge::AttrUtils::GetBool(op_desc, ATOMIC_ATTR_IS_FUSION_NODE, is_fusion_node); + GE_CHK_BOOL_TRUE_EXEC_WITH_LOG(is_fusion_node, return PARAM_INVALID, + "Atomic desc[%s] shouldn't be fusion_node in AiCoreNodeTask", + op_desc->GetName().c_str()); + + for (auto iter = workspace_info.begin(); iter != workspace_info.end(); ++iter) { + GE_CHK_BOOL_TRUE_EXEC_WITH_LOG(op_desc->GetName() != iter->first, return PARAM_INVALID, + "The node name %s and the node name %s in workspace info are inconsistent.", + op_desc->GetName().c_str(), iter->first.c_str()); + GE_IF_BOOL_EXEC(iter->second.empty(), continue); + + for (auto 
&info_iter : iter->second) { + auto workspace_index = static_cast(info_iter.first); + + GE_CHK_BOOL_TRUE_EXEC_WITH_LOG(workspace_index >= workspace_num, return PARAM_INVALID, + "The workspace index %lu is more than the size %lu of workspace vector.", + workspace_index, workspace_num); + + const auto workspace = context.MutableWorkspace(workspace_index); + arg_base[index++] = reinterpret_cast(workspace); + } + } + } + + return SUCCESS; +} + +Status AiCoreNodeTask::UpdateAllArgs(TaskContext &context, std::unique_ptr &task) { + GE_CHECK_NOTNULL(task); + auto *arg_off = reinterpret_cast(task->args_.get()) + task->offset_; + auto *arg_base = reinterpret_cast(arg_off); + int index = 0; + for (int i = 0; i < context.NumInputs(); ++i) { + const auto input = context.GetInput(i); + GE_CHECK_NOTNULL(input); + arg_base[index++] = reinterpret_cast(input->GetData()); + } + + for (int i = 0; i < context.NumOutputs(); ++i) { + const auto output = context.GetOutput(i); + GE_CHECK_NOTNULL(output); + arg_base[index++] = reinterpret_cast(output->GetData()); + } + + auto op_desc = context.GetNodeItem().op_desc; + GE_CHECK_NOTNULL(op_desc); + auto workspace_sizes = op_desc->GetWorkspaceBytes(); + int ops_workspace_num = static_cast(workspace_sizes.size()); + int workspace_num = static_cast(context.NumWorkspaces()); + GE_CHK_BOOL_EXEC(ops_workspace_num == workspace_num, return PARAM_INVALID, + "The workspace_num in op_desc %lu is not equal to it %lu in context.", ops_workspace_num, + workspace_num); + for (int i = 0; i < workspace_num; ++i) { + const auto workspace = context.MutableWorkspace(i); + arg_base[index++] = reinterpret_cast(workspace); + } + + return SUCCESS; +} + +Status AiCoreNodeTask::UpdateArgs(TaskContext &context) { + auto op_desc = context.GetNodeItem().op_desc; + GE_CHECK_NOTNULL(op_desc); + GELOGI("AiCoreNodeTask[%s] UpdateArgs Start.", op_desc->GetName().c_str()); + GE_IF_BOOL_EXEC(tasks_.size() == 1, return UpdateAllArgs(context, tasks_[0])); + + std::vector atomic_output_indexes; // here atomic just clean output + (void)ge::AttrUtils::GetListInt(op_desc, ge::ATOMIC_ATTR_OUTPUT_INDEX, atomic_output_indexes); + GE_CHK_BOOL_TRUE_EXEC_WITH_LOG(atomic_output_indexes.empty(), return FAILED, "ATOMIC_ATTR_OUTPUT_INDEX is empty."); + GE_CHK_BOOL_TRUE_EXEC_WITH_LOG(tasks_.size() != 2, return FAILED, "AtomicAddrClean op task num != 2."); + + GE_CHK_STATUS_RET_NOLOG(UpdateAtomicArgs(context, tasks_[0])); + GE_CHK_STATUS_RET_NOLOG(UpdateAllArgs(context, tasks_[1])); + + GELOGI("AiCoreNodeTask[%s] UpdateArgs End.", op_desc->GetName().c_str()); + return SUCCESS; +} +} // namespace hybrid +} // namespace ge diff --git a/src/ge/hybrid/node_executor/aicore/aicore_node_executor.h b/src/ge/hybrid/node_executor/aicore/aicore_node_executor.h new file mode 100644 index 00000000..a8b24e68 --- /dev/null +++ b/src/ge/hybrid/node_executor/aicore/aicore_node_executor.h @@ -0,0 +1,78 @@ +/** + * Copyright 2019-2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */
+
+#ifndef GE_HYBRID_KERNEL_AICORE_NODE_EXECUTOR_H_
+#define GE_HYBRID_KERNEL_AICORE_NODE_EXECUTOR_H_
+
+#include "hybrid/node_executor/aicore/aicore_task_builder.h"
+#include "hybrid/node_executor/aicore/aicore_task_compiler.h"
+#include "hybrid/node_executor/node_executor.h"
+#include <map>
+#include <mutex>
+
+namespace ge {
+namespace hybrid {
+
+class AiCoreNodeTaskRegistry {
+ public:
+  ~AiCoreNodeTaskRegistry() = default;
+
+  static AiCoreNodeTaskRegistry &GetInstance() {
+    static AiCoreNodeTaskRegistry instance;
+    return instance;
+  }
+
+  std::shared_ptr<NodeTask> GetTask(const std::string &node_key);
+  bool AddTask(const std::string &node_key, const std::shared_ptr<NodeTask> &task);
+
+ private:
+  AiCoreNodeTaskRegistry() = default;
+  std::map<std::string, std::shared_ptr<NodeTask>> reg_node_tasks_;
+  std::mutex mutex_;
+};
+
+class AiCoreNodeTask : public NodeTask {
+ public:
+  explicit AiCoreNodeTask(std::vector<std::unique_ptr<AiCoreOpTask>> &&tasks);
+  ~AiCoreNodeTask() = default;
+  Status ExecuteAsync(TaskContext &context, std::function<void()> done_callback) override;
+  Status UpdateArgs(TaskContext &context) override;
+
+ private:
+  static Status UpdateAllArgs(TaskContext &context, std::unique_ptr<AiCoreOpTask> &task);
+  static Status UpdateAtomicArgs(TaskContext &context, std::unique_ptr<AiCoreOpTask> &task);
+  std::vector<std::unique_ptr<AiCoreOpTask>> tasks_;
+};
+
+class AiCoreNodeExecutor : public NodeExecutor {
+ public:
+  Status Initialize() override;
+  Status Finalize() override;
+
+  Status LoadTask(const HybridModel &model, const NodePtr &node, shared_ptr<NodeTask> &task) const override;
+  Status CompileTask(const HybridModel &model, const NodePtr &node, std::shared_ptr<NodeTask> &task) const override;
+
+ private:
+  static Status CreateTask(const HybridModel &model, const std::vector<domi::TaskDef> &task_defs, const NodePtr &node,
+                           std::shared_ptr<NodeTask> &task);
+  static Status BuildAiCoreTask(const domi::KernelDef &kernel_def, const OpDescPtr &op_desc, AiCoreOpTask **task);
+  static Status GenNodeKey(const NodePtr &node, std::string &node_key);
+  std::unique_ptr<AiCoreTaskCompiler> compiler_;
+};
+
+}  // namespace hybrid
+}  // namespace ge
+#endif  // GE_HYBRID_KERNEL_AICORE_NODE_EXECUTOR_H_
diff --git a/third_party/fwkacllib/inc/ops/decode_boundaries_target.h b/src/ge/hybrid/node_executor/aicore/aicore_op_task.cc
similarity index 53%
rename from third_party/fwkacllib/inc/ops/decode_boundaries_target.h
rename to src/ge/hybrid/node_executor/aicore/aicore_op_task.cc
index f6951f9a..27256e9a 100644
--- a/third_party/fwkacllib/inc/ops/decode_boundaries_target.h
+++ b/src/ge/hybrid/node_executor/aicore/aicore_op_task.cc
@@ -14,18 +14,19 @@
  * limitations under the License.
  */
 
- #ifndef GE_OP_DECODE_BOUNDARIES_TARGET_H
- #define GE_OP_DECODE_BOUNDARIES_TARGET_H
+#include "aicore_op_task.h"
+#include "framework/common/debug/log.h"
 
- #include "graph/operator_reg.h"
+namespace ge {
+namespace hybrid {
 
- namespace ge {
+Status AiCoreOpTask::LaunchKernel(rtStream_t stream) {
+  GELOGI("AiCoreOpTask LaunchKernel Start (task = %s, block_dim = %u).", stub_name_.c_str(), block_dim_);
 
- REG_OP(DecodeBoundariesTarget)
-     .INPUT(boundary_predictions, TensorType({DT_FLOAT16})) /* "First operand." */
-     .INPUT(anchors, TensorType({DT_FLOAT16})) /* "Second operand."
*/ - .OUTPUT(boundary_encoded, TensorType({DT_FLOAT16})) /* "Result, has same element type as two inputs" */ - .OP_END_FACTORY_REG(DecodeBoundariesTarget) - } // namespace ge + GE_CHK_RT_RET(rtKernelLaunch(stub_func_, block_dim_, args_.get(), args_size_, nullptr, stream)); + GELOGI("AiCoreOpTask LaunchKernel End (task = %s, block_dim = %u).", stub_name_.c_str(), block_dim_); + return SUCCESS; +} - #endif // GE_OP_DECODE_BOUNDARIES_TARGET_H \ No newline at end of file +} // namespace hybrid +} // namespace ge \ No newline at end of file diff --git a/src/ge/hybrid/node_executor/aicore/aicore_op_task.h b/src/ge/hybrid/node_executor/aicore/aicore_op_task.h new file mode 100644 index 00000000..d23688a5 --- /dev/null +++ b/src/ge/hybrid/node_executor/aicore/aicore_op_task.h @@ -0,0 +1,44 @@ +/** + * Copyright 2019-2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef GE_HYBRID_KERNEL_AICORE_OP_TASK_H_ +#define GE_HYBRID_KERNEL_AICORE_OP_TASK_H_ + +#include +#include "common/ge_inner_error_codes.h" +#include "runtime/stream.h" +namespace ge { +namespace hybrid { +class AiCoreOpTask { + public: + AiCoreOpTask() = default; + ~AiCoreOpTask() = default; + Status LaunchKernel(rtStream_t stream); + + private: + friend class AiCoreTaskBuilder; + friend class AiCoreNodeTask; + std::string stub_name_; + void *stub_func_ = nullptr; + std::unique_ptr args_ = nullptr; + uint32_t args_size_ = 0; + uint32_t block_dim_ = 1; + uint16_t offset_ = 0; +}; + +} // namespace hybrid +} // namespace ge +#endif // GE_HYBRID_KERNEL_AICORE_OP_TASK_H_ diff --git a/src/ge/hybrid/node_executor/aicore/aicore_task_builder.cc b/src/ge/hybrid/node_executor/aicore/aicore_task_builder.cc new file mode 100644 index 00000000..5b263007 --- /dev/null +++ b/src/ge/hybrid/node_executor/aicore/aicore_task_builder.cc @@ -0,0 +1,90 @@ +/** + * Copyright 2019-2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +#include "aicore_task_builder.h" +#include +#include "graph/op_desc.h" +#include "cce/taskdown_common.hpp" +#include "framework/common/debug/log.h" +#include "graph/debug/ge_attr_define.h" + +namespace ge { +namespace hybrid { +std::mutex g_reg_mutex; + +AiCoreTaskBuilder::AiCoreTaskBuilder(const OpDescPtr &op_desc, const domi::KernelDef &kernel_def) + : op_desc_(op_desc), kernel_def_(kernel_def) { + std::string session_graph_id; + GE_IF_BOOL_EXEC(AttrUtils::GetStr(*op_desc_, ATTR_NAME_SESSION_GRAPH_ID, session_graph_id), + GELOGD("Get original type of session_graph_id.")); + // get bin_file_key + stub_name_ = (session_graph_id.empty()) ? op_desc_->GetName() : session_graph_id + "_" + op_desc_->GetName(); +} + +Status AiCoreTaskBuilder::SetKernelArgs(AiCoreOpTask &task) { + const domi::KernelContext &context = kernel_def_.context(); + // get kernel_type + auto kernel_type = static_cast(context.kernel_type()); + GE_CHK_BOOL_TRUE_EXEC_WITH_LOG(kernel_type != cce::ccKernelType::TE, return UNSUPPORTED, + "Invalid kernel type[%d] in AiCore TaskDef.", static_cast(kernel_type)); + + task.args_size_ = kernel_def_.args_size(); + task.block_dim_ = kernel_def_.block_dim(); + + // malloc args memory + task.args_.reset(new (std::nothrow) uint8_t[task.args_size_]); + // task.args_ = std::make_unique(task.args_size_); + GE_CHECK_NOTNULL(task.args_); + errno_t err = memcpy_s(task.args_.get(), task.args_size_, kernel_def_.args().data(), task.args_size_); + GE_CHK_BOOL_TRUE_EXEC_WITH_LOG(err != EOK, return INTERNAL_ERROR, "AiCoreTask memcpy failed."); + + const auto *args_offset_tmp = reinterpret_cast(const_cast(context.args_offset().data())); + GE_CHK_BOOL_TRUE_EXEC_WITH_LOG(context.args_offset().size() / sizeof(uint16_t) < 1, return FAILED, + "context.args_offset().size() / sizeof(uint16_t) less than 1"); + task.offset_ = *args_offset_tmp; + return SUCCESS; +} + +const char *AiCoreKernelRegistry::GetUnique(const string &stub_key) { + std::lock_guard lock(mutex_); + auto it = unique_stubs_.find(stub_key); + GE_IF_BOOL_EXEC(it != unique_stubs_.end(), return it->c_str()); + it = unique_stubs_.insert(unique_stubs_.end(), stub_key); + return it->c_str(); +} + +Status AiCoreTaskBuilder::SetStub(AiCoreOpTask &task) { + AiCoreKernelRegistry ®istry = AiCoreKernelRegistry::GetInstance(); + std::lock_guard lock(g_reg_mutex); + const char *unique_key = registry.GetUnique(stub_name_); + + GE_CHK_RT_RET(rtGetFunctionByName(unique_key, &(task.stub_func_))); + task.stub_name_ = stub_name_; + + return SUCCESS; +} + +Status AiCoreTaskBuilder::BuildTask(AiCoreOpTask &task) { + GE_CHECK_NOTNULL(op_desc_); + GELOGI("AiCoreTaskBuilder[%s] BuildTask Start.", op_desc_->GetName().c_str()); + GE_CHK_STATUS_RET_NOLOG(SetKernelArgs(task)); + GE_CHK_STATUS_RET_NOLOG(SetStub(task)); + GELOGI("AiCoreTaskBuilder[%s] BuildTask End.", op_desc_->GetName().c_str()); + return SUCCESS; +} + +} // namespace hybrid +} // namespace ge diff --git a/src/ge/hybrid/node_executor/aicore/aicore_task_builder.h b/src/ge/hybrid/node_executor/aicore/aicore_task_builder.h new file mode 100644 index 00000000..18cb309c --- /dev/null +++ b/src/ge/hybrid/node_executor/aicore/aicore_task_builder.h @@ -0,0 +1,61 @@ +/** + * Copyright 2019-2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. 
+ * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef GE_HYBRID_KERNEL_AICORE_TASK_BUILDER_H_ +#define GE_HYBRID_KERNEL_AICORE_TASK_BUILDER_H_ + +#include +#include +#include +#include +#include "aicore_op_task.h" +#include "proto/task.pb.h" +#include "graph/utils/attr_utils.h" +#include "graph/op_kernel_bin.h" + +namespace ge { +namespace hybrid { +class AiCoreKernelRegistry { + public: + ~AiCoreKernelRegistry() = default; + static AiCoreKernelRegistry &GetInstance() { + static AiCoreKernelRegistry instance; + return instance; + } + const char *GetUnique(const string &stub_func); + + private: + AiCoreKernelRegistry() = default; + std::set unique_stubs_; + std::mutex mutex_; +}; + +class AiCoreTaskBuilder { + public: + AiCoreTaskBuilder(const OpDescPtr &op_desc, const domi::KernelDef &kernel_def); + ~AiCoreTaskBuilder() = default; + Status BuildTask(AiCoreOpTask &task); + + private: + Status SetKernelArgs(AiCoreOpTask &task); + Status SetStub(AiCoreOpTask &task); + const OpDescPtr &op_desc_; + const domi::KernelDef &kernel_def_; + std::string stub_name_; +}; +} // namespace hybrid +} // namespace ge +#endif // GE_HYBRID_KERNEL_AICORE_TASK_BUILDER_H_ diff --git a/src/ge/hybrid/node_executor/aicore/aicore_task_compiler.cc b/src/ge/hybrid/node_executor/aicore/aicore_task_compiler.cc new file mode 100644 index 00000000..ac89afbd --- /dev/null +++ b/src/ge/hybrid/node_executor/aicore/aicore_task_compiler.cc @@ -0,0 +1,96 @@ +/** + * Copyright 2019-2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +#include "aicore_task_compiler.h" +#include "framework/common/debug/log.h" +#include "graph/debug/ge_attr_define.h" + +namespace ge { +namespace hybrid { +namespace { +uintptr_t kWeightBase = 0x10000000; +uintptr_t kMemBase = 0x20000000; +uint64_t kFakeSize = 0x10000000UL; +} // namespace +std::mutex AiCoreTaskCompiler::mu_; + +AiCoreTaskCompiler::AiCoreTaskCompiler(OpsKernelInfoStorePtr aic_kernel_store) + : aic_kernel_store_(std::move(aic_kernel_store)) {} + +Status AiCoreTaskCompiler::DoCompileOp(OpsKernelInfoStore &ops_store, const NodePtr &node) { + GE_CHECK_NOTNULL(node); + vector node_vec; + node_vec.emplace_back(node); + std::lock_guard lk(mu_); + GE_CHK_STATUS_RET(ops_store.CompileOpRun(node_vec), "Failed to execute CompileOp, node = %s", + node->GetName().c_str()); + GE_CHK_STATUS_RET(ops_store.CalcOpRunningParam(*node), "Failed to execute CalcOpRunningParam, node = %s", + node->GetName().c_str()); + return SUCCESS; +} + +Status AiCoreTaskCompiler::CompileOp(const NodePtr &node, std::vector &tasks) const { + GE_CHECK_NOTNULL(node); + GELOGI("AiCoreTaskCompiler[%s] CompileOp Start.", node->GetName().c_str()); + GE_CHK_BOOL_TRUE_EXEC_WITH_LOG(aic_kernel_store_ == nullptr, return FAILED, + "Failed to get AiCore kernel store, node = %s", node->GetName().c_str()); + + GE_CHK_STATUS_RET_NOLOG(DoCompileOp(*aic_kernel_store_, node)); + GELOGD("successfully compiled op: %s", node->GetName().c_str()); + + auto op_desc = node->GetOpDesc(); + std::vector input_offsets(op_desc->GetInputsSize(), kMemBase); + std::vector output_offsets(op_desc->GetOutputsSize(), kMemBase); + op_desc->SetInputOffset(input_offsets); + op_desc->SetOutputOffset(output_offsets); + GE_CHK_STATUS_RET_NOLOG(DoGenerateTask(*aic_kernel_store_, *node, tasks)); + GELOGD("successfully generated task: %s", node->GetName().c_str()); + GELOGI("AiCoreTaskCompiler[%s] CompileOp End.", node->GetName().c_str()); + return SUCCESS; +} + +Status AiCoreTaskCompiler::DoGenerateTask(OpsKernelInfoStore &store, const Node &node, + std::vector &tasks) { + rtModel_t rt_model_ = nullptr; + GE_CHK_RT_RET(rtModelCreate(&rt_model_, 0)); + rtStream_t stream = nullptr; + GE_CHK_RT_EXEC(rtStreamCreate(&stream, 0), GE_CHK_RT(rtModelDestroy(rt_model_)); return RT_FAILED); + GE_MAKE_GUARD_RTSTREAM(stream); + GE_CHK_RT_EXEC(rtModelBindStream(rt_model_, stream, 0), GE_CHK_RT(rtModelDestroy(rt_model_)); return RT_FAILED); + + RunContext context; + context.stream = stream; + context.model = rt_model_; + context.graphStreamList.emplace_back(stream); + context.weightMemBase = reinterpret_cast(kWeightBase); + context.dataMemBase = reinterpret_cast(kWeightBase); + context.weightMemSize = kFakeSize; + context.dataMemSize = kFakeSize; + + Status ret; + { + std::lock_guard lk(mu_); + ret = store.GenerateTask(node, context, tasks); + } + + GE_CHK_STATUS(ret, "Failed to execute GenerateTask, node = %s", node.GetName().c_str()); + GE_CHK_RT(rtModelUnbindStream(rt_model_, stream)); + GE_CHK_RT(rtModelDestroy(rt_model_)); + return ret; +} + +} // namespace hybrid +} // namespace ge \ No newline at end of file diff --git a/src/ge/hybrid/node_executor/aicore/aicore_task_compiler.h b/src/ge/hybrid/node_executor/aicore/aicore_task_compiler.h new file mode 100644 index 00000000..39673188 --- /dev/null +++ b/src/ge/hybrid/node_executor/aicore/aicore_task_compiler.h @@ -0,0 +1,40 @@ +/** + * Copyright 2019-2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with 
the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef GE_HYBRID_KERNEL_AICORE_TASK_COMPILER_H_ +#define GE_HYBRID_KERNEL_AICORE_TASK_COMPILER_H_ + +#include +#include "opskernel_manager/ops_kernel_manager.h" + +namespace ge { +namespace hybrid { +class AiCoreTaskCompiler { + public: + explicit AiCoreTaskCompiler(OpsKernelInfoStorePtr aic_kernel_store); + ~AiCoreTaskCompiler() = default; + + Status CompileOp(const NodePtr &node, std::vector &tasks) const; + + private: + static Status DoCompileOp(OpsKernelInfoStore &store, const NodePtr &node); + static Status DoGenerateTask(OpsKernelInfoStore &store, const Node &node, std::vector &tasks); + OpsKernelInfoStorePtr aic_kernel_store_; + static std::mutex mu_; +}; +} // namespace hybrid +} // namespace ge +#endif // GE_HYBRID_KERNEL_AICORE_TASK_COMPILER_H_ diff --git a/src/ge/hybrid/node_executor/aicpu/aicpu_node_executor.cc b/src/ge/hybrid/node_executor/aicpu/aicpu_node_executor.cc new file mode 100644 index 00000000..2698f79e --- /dev/null +++ b/src/ge/hybrid/node_executor/aicpu/aicpu_node_executor.cc @@ -0,0 +1,662 @@ +/** + * Copyright 2019-2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "hybrid/node_executor/aicpu/aicpu_node_executor.h" +#include "common/formats/formats.h" +#include "graph/load/new_model_manager/model_manager.h" +#include "hybrid/common/npu_memory_allocator.h" +#include "hybrid/executor/hybrid_execution_context.h" +#include "hybrid/model/hybrid_model.h" +#include "init/gelib.h" + +namespace ge { +namespace hybrid { +using aicpu::FWKAdapter::ExtInfo; +namespace { +// mem need release +constexpr uint64_t kReleaseFlag = 1; + +// max dim count is 8. +constexpr uint32_t kMaxDimCount = 8; + +// if dim count is not reach kMaxDimCount, use INT64_MIN to mark dim end. 
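+// For example, with kMaxDimCount = 8 a shape (3, 4) is encoded as {3, 4, INT64_MIN, ...}; a full
+// 8-dim shape fills every slot and carries no end flag (see SetShapeToBuf/GetShapeFromBuf below).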
+constexpr int64_t kDimEndFlag = INT64_MIN;
+
+struct MaxShape {
+  int64_t dims[kMaxDimCount] = {0};
+};
+}  // namespace
+REGISTER_NODE_EXECUTOR_BUILDER(NodeExecutorManager::ExecutorType::AICPU_TF, AiCpuNodeExecutor);
+
+Status AicpuTfNodeTask::AllocTensorBuffer(size_t size, std::unique_ptr<TensorBuffer> &tensor_buffer) {
+  auto allocator = NpuMemoryAllocator::GetAllocator();
+  GE_CHECK_NOTNULL(allocator);
+  tensor_buffer = TensorBuffer::Create(allocator, size);
+  GE_CHECK_NOTNULL(tensor_buffer);
+  return SUCCESS;
+}
+
+Status AicpuTfNodeTask::InitExtInfo() {
+  // ext info 0: op type
+  size_t ext_info_size = sizeof(ExtInfo) + sizeof(uint32_t);
+  ext_info_num_ = 1;
+  // ext info 1: input shapes
+  if (input_num_ > 0) {
+    ext_info_size += sizeof(ExtInfo) + input_num_ * sizeof(MaxShape);
+    ++ext_info_num_;
+  }
+
+  // ext info 2: output shapes
+  if ((unknown_type_ != DEPEND_COMPUTE) && (output_num_ > 0)) {
+    ext_info_size += sizeof(ExtInfo) + output_num_ * sizeof(MaxShape);
+    ++ext_info_num_;
+  }
+
+  GE_CHK_STATUS_RET(AllocTensorBuffer(ext_info_size, ext_info_addr_dev_),
+                    "Node %s alloc buffer for ext info failed, size=%zu.", node_->GetName().c_str(), ext_info_size);
+
+  auto ext_info_dev_base = reinterpret_cast<uintptr_t>(ext_info_addr_dev_->GetData());
+  ext_info_addr_host_.reset(new (std::nothrow) uint8_t[ext_info_size]);
+  GE_CHECK_NOTNULL(ext_info_addr_host_);
+
+  size_t ext_info_type_offset = ext_info_num_ * sizeof(ExtInfo);
+  size_t ext_info_input_shape_offset = ext_info_type_offset + sizeof(uint32_t);
+
+  auto ext_info_host_buf = ext_info_addr_host_.get();
+
+  auto ext_info_type = reinterpret_cast<ExtInfo *>(ext_info_host_buf);
+  ext_info_type->infoType = aicpu::FWKAdapter::FWK_ADPT_EXT_SHAPE_TYPE;
+  ext_info_type->infoLen = sizeof(uint32_t);
+  ext_info_type->infoAddr = ext_info_dev_base + ext_info_type_offset;
+  // set unknown shape type
+  auto unknown_shape_type_addr = reinterpret_cast<uint32_t *>(ext_info_host_buf + ext_info_type_offset);
+  *unknown_shape_type_addr = unknown_type_;
+
+  if (input_num_ > 0) {
+    auto ext_info_input = reinterpret_cast<ExtInfo *>(ext_info_host_buf + sizeof(ExtInfo));
+    ext_info_input->infoType = aicpu::FWKAdapter::FWK_ADPT_EXT_INPUT_SHAPE;
+    ext_info_input->infoLen = input_num_ * sizeof(MaxShape);
+    ext_info_input->infoAddr = ext_info_dev_base + ext_info_input_shape_offset;
+  }
+  if ((unknown_type_ != DEPEND_COMPUTE) && (output_num_ > 0)) {
+    size_t ext_info_output_shape_offset = ext_info_input_shape_offset + input_num_ * sizeof(MaxShape);
+    auto ext_info_output = reinterpret_cast<ExtInfo *>(ext_info_host_buf + (ext_info_num_ - 1) * sizeof(ExtInfo));
+    ext_info_output->infoType = aicpu::FWKAdapter::FWK_ADPT_EXT_OUTPUT_SHAPE;
+    ext_info_output->infoLen = output_num_ * sizeof(MaxShape);
+    ext_info_output->infoAddr = ext_info_dev_base + ext_info_output_shape_offset;
+  }
+
+  GE_CHK_RT_RET(rtMemcpy(ext_info_addr_dev_->GetData(), ext_info_addr_dev_->GetSize(), ext_info_host_buf, ext_info_size,
+                         RT_MEMCPY_HOST_TO_DEVICE));
+  return SUCCESS;
+}
+
+Status AicpuTfNodeTask::InitForDependComputeTask() {
+  if ((unknown_type_ != DEPEND_COMPUTE) || (output_num_ == 0)) {
+    GELOGI("node %s type %s unknown_type is %d, output num is %zu.", node_->GetName().c_str(), node_->GetType().c_str(),
+           unknown_type_, output_num_);
+    return SUCCESS;
+  }
+
+  output_summary_.resize(output_num_);
+  constexpr auto result_summary_size = sizeof(aicpu::FWKAdapter::ResultSummary);
+  for (size_t i = 0; i < output_num_; ++i) {
+    GE_CHK_STATUS_RET(AllocTensorBuffer(result_summary_size, output_summary_[i]),
+                      "Node %s alloc buffer for ext info
failed, size=%zu.", node_->GetName().c_str(), + result_summary_size); + } + output_summary_host_.resize(output_num_); + + // init for mem copy task + // copy task need copy output_data and output_shape, max len is 2 * output_num + const size_t copy_input_buf_len = output_num_ * 2 * sizeof(uint64_t); + GE_CHK_STATUS_RET(AllocTensorBuffer(copy_input_buf_len, copy_input_release_flag_dev_), + "Node %s alloc copy task input release_flag failed, size=%zu", node_->GetName().c_str(), + copy_input_buf_len); + GE_CHK_STATUS_RET(AllocTensorBuffer(copy_input_buf_len, copy_input_data_size_dev_), + "Node %s alloc copy task input data_size failed, size=%zu", node_->GetName().c_str(), + copy_input_buf_len); + GE_CHK_STATUS_RET(AllocTensorBuffer(copy_input_buf_len, copy_input_src_dev_), + "Node %s alloc copy task input src failed, size=%zu", node_->GetName().c_str(), copy_input_buf_len); + GE_CHK_STATUS_RET(AllocTensorBuffer(copy_input_buf_len, copy_input_dst_dev_), + "Node %s alloc copy task input dst failed, size=%zu", node_->GetName().c_str(), copy_input_buf_len); + + // copy task args buf + GE_CHK_STATUS_RET(AllocTensorBuffer(sizeof(STR_FWK_OP_KERNEL), copy_task_args_buf_), + "Node %s alloc copy task args buf failed, size=%zu", node_->GetName().c_str(), + sizeof(STR_FWK_OP_KERNEL)); + + std::vector copy_io_addr; + copy_io_addr.emplace_back(reinterpret_cast(copy_input_release_flag_dev_->GetData())); + copy_io_addr.emplace_back(reinterpret_cast(copy_input_data_size_dev_->GetData())); + copy_io_addr.emplace_back(reinterpret_cast(copy_input_src_dev_->GetData())); + copy_io_addr.emplace_back(reinterpret_cast(copy_input_dst_dev_->GetData())); + + // mem copy op has 4 inputs and 0 output. + const auto copy_io_addr_size = sizeof(uint64_t) * copy_io_addr.size(); + + // can alloc in init, it can reuse + GE_CHK_STATUS_RET(AllocTensorBuffer(copy_io_addr_size, copy_ioaddr_dev_), + "Node %s alloc copy task io buf failed, size=%zu", node_->GetName().c_str(), copy_io_addr_size); + + GE_CHK_RT_RET(rtMemcpy(copy_ioaddr_dev_->GetData(), copy_io_addr_size, ©_io_addr[0], copy_io_addr_size, + RT_MEMCPY_HOST_TO_DEVICE)); + return SUCCESS; +} + +Status AicpuTfNodeTask::Init(const HybridModel &model) { + auto node_name = node_->GetName(); + GELOGI("AicpuTfNodeTask[%s] Init Start.", node_name.c_str()); + auto op_desc = node_->GetOpDesc(); + GE_CHECK_NOTNULL(op_desc); + + const auto node_item = model.GetNodeItem(node_); + GE_CHECK_NOTNULL(node_item); + unknown_type_ = node_item->shape_inference_type; + + auto &kernel_ex_def = task_def_.kernel_ex(); + + auto kernel_workspace_size = static_cast(kernel_ex_def.task_info_size()); + GE_CHK_STATUS_RET(AllocTensorBuffer(kernel_workspace_size, kernel_workspace_), + "Node %s alloc buffer for kernel workspace failed, size=%zu.", node_name.c_str(), + kernel_workspace_size); + + GE_CHK_RT_RET(rtMemcpy(kernel_workspace_->GetData(), kernel_workspace_size, kernel_ex_def.task_info().data(), + static_cast(kernel_ex_def.task_info_size()), RT_MEMCPY_HOST_TO_DEVICE)); + input_num_ = op_desc->GetInputsSize(); + output_num_ = op_desc->GetOutputsSize(); + size_t input_output_size = (input_num_ + output_num_) * sizeof(uint64_t); + if (input_output_size > 0) { + // alloc input output addr buf + GE_CHK_STATUS_RET(AllocTensorBuffer(input_output_size, input_output_addr_), + "Node %s alloc buffer for input output addr failed, size=%zu.", node_name.c_str(), + input_output_size); + } + + // init ext info + GE_CHK_STATUS_RET(InitExtInfo(), "Task %s init ext info failed.", node_name.c_str()); + 
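  // Illustrative layout of the ext info buffer prepared by InitExtInfo() (derived from the offset math there):
+  //   [ExtInfo x ext_info_num_ | uint32_t unknown-shape type | MaxShape x input_num_ | MaxShape x output_num_]
+  // Each ExtInfo header points via infoAddr into the trailing payload region of the same device buffer.
+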
GE_CHK_STATUS_RET(InitForDependComputeTask(), "Task %s init for depend compute task failed.", node_name.c_str());
+
+  // build fwk_op_kernel.
+  GE_CHK_BOOL_TRUE_EXEC_WITH_LOG(sizeof(STR_FWK_OP_KERNEL) < kernel_ex_def.args_size(), return FAILED,
+                                 "sizeof STR_FWK_OP_KERNEL is: %zu, but args_size is: %u", sizeof(STR_FWK_OP_KERNEL),
+                                 kernel_ex_def.args_size());
+
+  STR_FWK_OP_KERNEL fwk_op_kernel = {0};
+  errno_t sec_ret = memcpy_s(&fwk_op_kernel, sizeof(STR_FWK_OP_KERNEL), kernel_ex_def.args().data(),
+                             static_cast<size_t>(kernel_ex_def.args_size()));
+  GE_CHK_BOOL_EXEC(sec_ret == EOK, return INTERNAL_ERROR, "memcpy fwk_op_kernel failed, ret: %d", sec_ret);
+
+  fwk_op_kernel.fwkKernelBase.fwk_kernel.workspaceBaseAddr = reinterpret_cast<uintptr_t>(kernel_workspace_->GetData());
+  fwk_op_kernel.fwkKernelBase.fwk_kernel.inputOutputAddr = reinterpret_cast<uintptr_t>(input_output_addr_->GetData());
+  // set ext info addr and ext info num
+  fwk_op_kernel.fwkKernelBase.fwk_kernel.extInfoAddr = reinterpret_cast<uintptr_t>(ext_info_addr_dev_->GetData());
+  fwk_op_kernel.fwkKernelBase.fwk_kernel.extInfoNum = ext_info_num_;
+
+  // get step_id_addr
+  auto var_tensor = model.GetVariable(NODE_NAME_GLOBAL_STEP);
+  uint64_t step_id_addr = 0;
+  if (var_tensor != nullptr) {
+    step_id_addr = reinterpret_cast<uintptr_t>(var_tensor->GetData());
+  }
+
+  fwk_op_kernel.fwkKernelBase.fwk_kernel.stepIDAddr = step_id_addr;
+
+  auto session_id = fwk_op_kernel.fwkKernelBase.fwk_kernel.sessionID;
+  GE_CHK_STATUS_RET(EnsureSessionCreated(session_id), "session id %lu create failed.", session_id);
+
+  // alloc kernel_buf_ and copy to device.
+  GE_CHK_STATUS_RET(AllocTensorBuffer(sizeof(STR_FWK_OP_KERNEL), kernel_buf_),
+                    "Node %s alloc buffer for kernel buf failed, size=%zu.", node_name.c_str(),
+                    sizeof(STR_FWK_OP_KERNEL));
+
+  GE_CHK_RT_RET(rtMemcpy(kernel_buf_->GetData(), sizeof(STR_FWK_OP_KERNEL), &fwk_op_kernel, sizeof(STR_FWK_OP_KERNEL),
+                         RT_MEMCPY_HOST_TO_DEVICE));
+
+  GELOGI("AicpuTfNodeTask[%s] init end.", node_name.c_str());
+  return SUCCESS;
+}
+
+Status AicpuTfNodeTask::EnsureSessionCreated(uint64_t session_id) {
+  auto model_manager = ModelManager::GetInstance();
+  GE_CHECK_NOTNULL(model_manager);
+  GE_CHK_STATUS_RET(model_manager->CreateAicpuSession(session_id), "Create aicpu session %lu failed", session_id);
+  return SUCCESS;
+}
+
+Status AicpuTfNodeTask::SetShapeToBuf(const GeShape &shape, int64_t buf[], uint32_t buf_size) {
+  auto node_name = node_->GetName();
+  uint32_t index = 0;
+  int64_t shape_size = shape.GetDimNum();
+  if (shape_size > buf_size) {
+    GELOGI("SetShapeToBuf[%s] failed, as shape size %ld exceeds buf size %u.", node_name.c_str(), shape_size, buf_size);
+    return PARAM_INVALID;
+  }
+  for (; index < shape_size; ++index) {
+    buf[index] = shape.GetDim(index);
+  }
+  if (index < buf_size) {
+    buf[index] = kDimEndFlag;
+  }
+  return SUCCESS;
+}
+
+Status AicpuTfNodeTask::UpdateShapeToOutputDesc(const GeShape &shape_new, size_t output_index,
+                                                GeTensorDescPtr &output_desc) {
+  auto node_name = node_->GetName();
+  auto shape_old = output_desc->GetShape();
+  output_desc->SetShape(shape_new);
+  GELOGI("Update node[%s] out[%zu] shape from %s to %s.", node_name.c_str(), output_index, shape_old.ToString().c_str(),
+         shape_new.ToString().c_str());
+
+  auto origin_shape_old = output_desc->GetOriginShape();
+  auto origin_format = output_desc->GetOriginFormat();
+  auto format = output_desc->GetFormat();
+  if (origin_format == format) {
+    output_desc->SetOriginShape(shape_new);
+    return SUCCESS;
+  }
+  // if the formats differ, the origin shape has to be converted as well
+  std::vector<int64_t> origin_dims_new;
+  auto trans_ret =
+      formats::TransShape(format, shape_new.GetDims(), output_desc->GetDataType(), origin_format, origin_dims_new);
+  GE_CHK_STATUS_RET(trans_ret,
+                    "Node[%s] out[%zu] originFormat[%d] is not same as format[%d], but TransShape failed, shape=%s.",
+                    node_name.c_str(), output_index, origin_format, format, shape_new.ToString().c_str());
+  auto origin_shape_new = GeShape(origin_dims_new);
+  output_desc->SetOriginShape(origin_shape_new);
+  GELOGI("Node[%s] out[%zu] originFormat[%d] is not same as format[%d], need update from %s to %s.", node_name.c_str(),
+         output_index, origin_format, format, origin_shape_old.ToString().c_str(), origin_shape_new.ToString().c_str());
+  return SUCCESS;
+}
+
+Status AicpuTfNodeTask::ReadResultSummaryAndPrepareMemory(TaskContext &context,
+                                                          std::vector<std::unique_ptr<TensorBuffer>> &out_shape_hbm) {
+  for (size_t i = 0; i < output_num_; ++i) {
+    auto &result_summary = output_summary_host_[i];
+    GE_CHK_RT_RET(rtMemcpy(&result_summary, sizeof(aicpu::FWKAdapter::ResultSummary), output_summary_[i]->GetData(),
+                           output_summary_[i]->GetSize(), RT_MEMCPY_DEVICE_TO_HOST));
+
+    GELOGI(
+      "Node[%s] out[%zu] result summary addr=%p,"
+      " shape_data_ptr=0x%lx, shape_data_size=%lu, raw_data_ptr=0x%lx, raw_data_size=%lu.",
+      node_->GetName().c_str(), i, output_summary_[i]->GetData(), result_summary.shape_data_ptr,
+      result_summary.shape_data_size, result_summary.raw_data_ptr, result_summary.raw_data_size);
+
+    auto raw_data_size = result_summary.raw_data_size;
+    std::unique_ptr<TensorBuffer> tensor_buffer;
+    GE_CHK_STATUS_RET(AllocTensorBuffer(raw_data_size, tensor_buffer), "alloc tensor buffer failed, raw_data_size=%lu",
+                      raw_data_size);
+    auto status = context.SetOutput(i, TensorValue(std::shared_ptr<TensorBuffer>(tensor_buffer.release())));
+    GE_CHK_STATUS_RET(status, "SetOutput %zu failed.", i);
+
+    auto shape_data_size = result_summary.shape_data_size;
+    std::unique_ptr<TensorBuffer> shape_buffer;
+    GE_CHK_STATUS_RET(AllocTensorBuffer(shape_data_size, shape_buffer),
+                      "alloc shape buffer failed, shape_data_size=%lu", shape_data_size);
+    out_shape_hbm.emplace_back(std::move(shape_buffer));
+  }
+  return SUCCESS;
+}
+
+Status AicpuTfNodeTask::CopyDataToHbm(TaskContext &context,
+                                      const std::vector<std::unique_ptr<TensorBuffer>> &out_shape_hbm) {
+  GE_CHK_BOOL_RET_STATUS(out_shape_hbm.size() == output_num_, INTERNAL_ERROR,
+                         "Node %s has %zu outputs but out shape is %zu", node_->GetName().c_str(), output_num_,
+                         out_shape_hbm.size());
+
+  std::vector<uint64_t> copy_input_release_flag;
+  std::vector<uint64_t> copy_input_data_size;
+  std::vector<uint64_t> copy_input_src;
+  std::vector<uint64_t> copy_input_dst;
+
+  for (size_t i = 0; i < output_num_; ++i) {
+    const auto &summary = output_summary_host_[i];
+    GELOGI("node[%s] [%zu]th output summary, shape data=%lx, shape data size=%lu, raw data=%lx, raw data size=%lu.",
+           node_->GetName().c_str(), i, summary.shape_data_ptr, summary.shape_data_size, summary.raw_data_ptr,
+           summary.raw_data_size);
+    if (summary.raw_data_size > 0) {
+      auto output = context.GetOutput(i);
+      GE_CHECK_NOTNULL(output);
+      GE_CHECK_NOTNULL(output->GetData());
+      copy_input_release_flag.emplace_back(kReleaseFlag);
+      copy_input_data_size.emplace_back(summary.raw_data_size);
+      copy_input_src.emplace_back(summary.raw_data_ptr);
+      copy_input_dst.emplace_back(reinterpret_cast<uintptr_t>(output->GetData()));
+    }
+
+    if (summary.shape_data_size > 0) {
+      const auto &shape_buffer = out_shape_hbm[i];
+      GE_CHECK_NOTNULL(shape_buffer);
+      GE_CHECK_NOTNULL(shape_buffer->GetData());
+      copy_input_release_flag.emplace_back(kReleaseFlag);
+
copy_input_data_size.emplace_back(summary.shape_data_size); + copy_input_src.emplace_back(summary.shape_data_ptr); + copy_input_dst.emplace_back(reinterpret_cast(shape_buffer->GetData())); + } + } + + GE_CHK_BOOL_TRUE_EXEC_WITH_LOG(copy_input_release_flag.empty(), return INTERNAL_ERROR, "Node %s need copy num is 0", + node_->GetName().c_str()); + + auto copy_num = copy_input_release_flag.size(); + STR_FWK_OP_KERNEL aicpu_task = {0}; + std::string task_info; + RECORD_CALLBACK_EVENT(context.GetExecutionContext(), node_->GetName().c_str(), "[GenMemCopyTask] Start"); + GE_CHK_STATUS_RET_NOLOG(GenMemCopyTask(copy_num, aicpu_task, task_info)); + RECORD_CALLBACK_EVENT(context.GetExecutionContext(), node_->GetName().c_str(), "[GenMemCopyTask] End"); + + // copy task need copy output and output shape + const size_t copy_input_buf_len = copy_num * sizeof(uint64_t); + + GE_CHK_RT_RET(rtMemcpy(copy_input_release_flag_dev_->GetData(), copy_input_release_flag_dev_->GetSize(), + ©_input_release_flag[0], copy_input_buf_len, RT_MEMCPY_HOST_TO_DEVICE)); + GE_CHK_RT_RET(rtMemcpy(copy_input_data_size_dev_->GetData(), copy_input_data_size_dev_->GetSize(), + ©_input_data_size[0], copy_input_buf_len, RT_MEMCPY_HOST_TO_DEVICE)); + GE_CHK_RT_RET(rtMemcpy(copy_input_src_dev_->GetData(), copy_input_src_dev_->GetSize(), ©_input_src[0], + copy_input_buf_len, RT_MEMCPY_HOST_TO_DEVICE)); + GE_CHK_RT_RET(rtMemcpy(copy_input_dst_dev_->GetData(), copy_input_dst_dev_->GetSize(), ©_input_dst[0], + copy_input_buf_len, RT_MEMCPY_HOST_TO_DEVICE)); + + std::unique_ptr kernel_workspace_buf; + GE_CHK_STATUS_RET(AllocTensorBuffer(task_info.size(), kernel_workspace_buf), + "Node %s alloc copy task workspace buf failed, size=%zu", node_->GetName().c_str(), + task_info.size()); + + GE_CHK_RT_RET(rtMemcpy(kernel_workspace_buf->GetData(), task_info.size(), task_info.data(), task_info.size(), + RT_MEMCPY_HOST_TO_DEVICE)); + + aicpu_task.fwkKernelBase.fwk_kernel.inputOutputAddr = reinterpret_cast(copy_ioaddr_dev_->GetData()); + aicpu_task.fwkKernelBase.fwk_kernel.workspaceBaseAddr = reinterpret_cast(kernel_workspace_buf->GetData()); + aicpu_task.fwkKernelBase.fwk_kernel.extInfoAddr = 0; + aicpu_task.fwkKernelBase.fwk_kernel.extInfoNum = 0; + + GE_CHK_RT_RET(rtMemcpy(copy_task_args_buf_->GetData(), sizeof(STR_FWK_OP_KERNEL), &aicpu_task, + sizeof(STR_FWK_OP_KERNEL), RT_MEMCPY_HOST_TO_DEVICE)); + + RECORD_CALLBACK_EVENT(context.GetExecutionContext(), node_->GetName().c_str(), "[LaunchCopy] Start"); + GE_CHK_RT_RET(rtKernelLaunchEx(copy_task_args_buf_->GetData(), sizeof(STR_FWK_OP_KERNEL), RT_KERNEL_DEFAULT, + context.GetStream())); + RECORD_CALLBACK_EVENT(context.GetExecutionContext(), node_->GetName().c_str(), "[LaunchCopy] End"); + + GE_CHK_RT_RET(rtStreamSynchronize(context.GetStream())); + RECORD_CALLBACK_EVENT(context.GetExecutionContext(), node_->GetName().c_str(), "[SynchronizeCopy] End"); + return SUCCESS; +} + +Status AicpuTfNodeTask::GenMemCopyTask(uint64_t copy_num, STR_FWK_OP_KERNEL &task, string &task_info) { + auto instance_ptr = ge::GELib::GetInstance(); + GE_CHK_BOOL_TRUE_EXEC_WITH_LOG(instance_ptr == nullptr || !instance_ptr->InitFlag(), return GE_CLI_GE_NOT_INITIALIZED, + "GE is not initialized"); + + static constexpr const char *const kKernelLibName = "aicpu_kernel"; + OpsKernelInfoStorePtr kernel_info = instance_ptr->OpsKernelManagerObj().GetOpsKernelInfoStore(kKernelLibName); + GE_CHK_BOOL_TRUE_EXEC_WITH_LOG(kernel_info == nullptr, return FAILED, "Get op kernel info store failed"); + auto ret = 
kernel_info->GenMemCopyTask(copy_num, task, task_info); + GE_CHK_STATUS_RET(ret, "call aicpu GenMemCopyTask failed, copy_num=%lu, ret=%u", copy_num, ret); + return SUCCESS; +} + +Status AicpuTfNodeTask::UpdateShapeByHbmBuffer(TaskContext &context, + const std::vector> &out_shape_hbm) { + GE_CHK_BOOL_RET_STATUS(out_shape_hbm.size() == output_num_, INTERNAL_ERROR, + "Node %s has %zu outputs but out shape is %zu", node_->GetName().c_str(), output_num_, + out_shape_hbm.size()); + auto op_desc = node_->GetOpDesc(); + GE_CHECK_NOTNULL(op_desc); + for (size_t i = 0; i < output_num_; ++i) { + const auto &result_summary = output_summary_host_[i]; + auto output_desc = op_desc->MutableOutputDesc(i); + std::vector shape_dims; + if (result_summary.shape_data_size > 0) { + const auto &shape_hbm = out_shape_hbm[i]; + GE_CHK_BOOL_RET_STATUS((result_summary.shape_data_size % sizeof(int64_t) == 0), INTERNAL_ERROR, + "node %s %zuth output shape data size is %lu is not divided by int64_t.", + node_->GetName().c_str(), i, result_summary.shape_data_size); + uint32_t dim_num = result_summary.shape_data_size / sizeof(int64_t); + GELOGI("node %s %zuth output dim num=%lu.", node_->GetName().c_str(), i, dim_num); + std::unique_ptr shape_addr(new (std::nothrow) int64_t[dim_num]()); + GE_CHECK_NOTNULL(shape_addr); + GE_CHK_RT_RET(rtMemcpy(shape_addr.get(), result_summary.shape_data_size, shape_hbm->GetData(), + shape_hbm->GetSize(), RT_MEMCPY_DEVICE_TO_HOST)); + for (uint32_t dim_idx = 0; dim_idx < dim_num; ++dim_idx) { + shape_dims.emplace_back(shape_addr[dim_idx]); + GELOGD("node %s %zuth output dim[%u]=%lu.", node_->GetName().c_str(), i, dim_idx, shape_addr[dim_idx]); + } + } + GE_CHK_STATUS_RET(UpdateShapeToOutputDesc(GeShape(shape_dims), i, output_desc), + "update node %s %uth output shape failed.", node_->GetName().c_str(), i); + } + return SUCCESS; +} + +Status AicpuTfNodeTask::UpdateOutputShapeFromExtInfo() { + auto node_name = node_->GetName(); + if (output_num_ == 0) { + GELOGI("Task [%s] output_num is 0, no need reset output shape.", node_name.c_str()); + return SUCCESS; + } + + auto ext_output_shape_offset = ext_info_num_ * sizeof(ExtInfo) + sizeof(uint32_t) + input_num_ * sizeof(MaxShape); + size_t ext_info_output_shape_len = output_num_ * sizeof(MaxShape); + auto output_shape_host_buf = ext_info_addr_host_.get() + ext_output_shape_offset; + auto output_shape_dev_buf = reinterpret_cast(ext_info_addr_dev_->GetData()) + ext_output_shape_offset; + + GE_CHK_RT_RET(rtMemcpy(output_shape_host_buf, ext_info_output_shape_len, output_shape_dev_buf, + ext_info_output_shape_len, RT_MEMCPY_DEVICE_TO_HOST)); + + auto op_desc = node_->GetOpDesc(); + GE_CHECK_NOTNULL(op_desc); + + auto shapeBuf = reinterpret_cast(output_shape_host_buf); + for (uint32_t i = 0; i < output_num_; ++i) { + std::vector dims; + GetShapeFromBuf(shapeBuf + i * kMaxDimCount, kMaxDimCount, dims); + auto output_desc = op_desc->MutableOutputDesc(i); + GE_CHECK_NOTNULL(output_desc); + GE_CHK_STATUS_RET(UpdateShapeToOutputDesc(GeShape(dims), i, output_desc), + "update node %s %uth output shape failed.", node_name.c_str(), i); + } + + return SUCCESS; +} + +Status AicpuTfNodeTask::UpdateShapeAndDataByResultSummary(TaskContext &context) { + GELOGI("Task [%s] update shape and data by result summary begin.", node_->GetName().c_str()); + + std::vector> out_shape_hbm; + GE_CHK_STATUS_RET(ReadResultSummaryAndPrepareMemory(context, out_shape_hbm), + "node %s read ResultSummary and update output shape failed.", node_->GetName().c_str()); + + 
RECORD_CALLBACK_EVENT(context.GetExecutionContext(), node_->GetName().c_str(), + "[ReadResultSummaryAndPrepareMemory] End"); + + GE_CHK_STATUS_RET(CopyDataToHbm(context, out_shape_hbm), "node %s copy data to output failed.", + node_->GetName().c_str()); + + RECORD_CALLBACK_EVENT(context.GetExecutionContext(), node_->GetName().c_str(), "[CopyDataToHbm] End"); + + GE_CHK_STATUS_RET(UpdateShapeByHbmBuffer(context, out_shape_hbm), "node %s update shape by hbm buffer failed.", + node_->GetName().c_str()); + + GELOGI("Task [%s] update shape and data by result summary end.", node_->GetName().c_str()); + return SUCCESS; +} + +void AicpuTfNodeTask::GetShapeFromBuf(const int64_t buf[], uint32_t buf_size, std::vector &dims) { + for (uint32_t index = 0; index < buf_size; ++index) { + auto tmpDim = buf[index]; + if (tmpDim == kDimEndFlag) { + break; + } + dims.emplace_back(tmpDim); + } +} + +Status AicpuTfNodeTask::UpdateArgs(TaskContext &context) { + auto node_name = node_->GetName(); + GELOGI("AicpuTfNodeTask[%s] UpdateArgs begin. unknown_type=%d", node_name.c_str(), unknown_type_); + auto op_desc = node_->GetOpDesc(); + auto io_nums = input_num_ + output_num_; + if (io_nums == 0) { + GELOGI("Node %s has no input and output, no need update args.", node_name.c_str()); + return SUCCESS; + } + + vector io_addrs(io_nums, 0UL); + size_t ext_shape_nums = (unknown_type_ == DEPEND_COMPUTE) ? input_num_ : io_nums; + vector io_shapes(ext_shape_nums); + + uint32_t index = 0; + for (size_t i = 0; i < input_num_; ++i, ++index) { + auto inputData = context.GetInput(i); + GE_CHECK_NOTNULL(inputData); + auto input_desc = op_desc->MutableInputDesc(i); + GE_CHECK_NOTNULL(input_desc); + auto &shape = input_desc->MutableShape(); + + GELOGD("io_addr[%u] = %p, size = %zu", index, inputData->GetData(), inputData->GetSize()); + io_addrs[index] = reinterpret_cast(inputData->GetData()); + GE_CHK_STATUS_RET(SetShapeToBuf(shape, io_shapes[index].dims, kMaxDimCount), + "task %s input[%zu] SetShapeToBuf failed.", node_name.c_str(), i); + } + + if (unknown_type_ != DEPEND_COMPUTE) { + // unknown type 4 do this in call back. + GE_CHK_STATUS_RET_NOLOG(context.AllocateOutputs()); + for (size_t j = 0; j < output_num_; ++j, ++index) { + auto outputData = context.GetOutput(j); + GE_CHECK_NOTNULL(outputData); + auto output_desc = op_desc->MutableOutputDesc(j); + GE_CHECK_NOTNULL(output_desc); + auto shape = output_desc->GetShape(); + + // shape range need use range update shape + if (unknown_type_ == DEPEND_SHAPE_RANGE) { + std::vector> range; + auto range_ret = output_desc->GetShapeRange(range); + GE_CHK_BOOL_RET_STATUS(range_ret == GRAPH_SUCCESS, INTERNAL_ERROR, + "node %s has is shape range but get GetShapeRange failed, ret=%u.", node_name.c_str(), + range_ret); + for (size_t k = 0; k < range.size(); ++k) { + if (shape.GetDim(k) < 0 && k < range.size()) { + GELOGD("node %s output[%zu] update dim[%zu] from %lu to range max %lu.", node_name.c_str(), j, k, + shape.GetDim(k), range[k].second); + shape.SetDim(k, range[k].second); + } + } + } + + GELOGD("io_addr[%u] = %p, size = %zu", index, outputData->GetData(), outputData->GetSize()); + io_addrs[index] = reinterpret_cast(outputData->GetData()); + GE_CHK_STATUS_RET(SetShapeToBuf(shape, io_shapes[index].dims, kMaxDimCount), + "task %s output[%zu] SetShapeToBuf failed.", node_name.c_str(), j); + } + } else { + // unknown type 4 use result summary update ioaddr. 
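+    // A ResultSummary (written back by the kernel) records {shape_data_ptr, shape_data_size,
+    // raw_data_ptr, raw_data_size} per output; the real output buffers are only allocated and
+    // filled after execution, in UpdateShapeAndDataByResultSummary.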
+ GELOGI("AicpuTfNodeTask[%s] is unknown-shape, use ResultSummary as out-addr.", node_name.c_str()); + GE_CHK_BOOL_RET_STATUS(output_summary_.size() == output_num_, INTERNAL_ERROR, + "node %s has %zu output but %zu output summary.", node_name.c_str(), output_num_, + output_summary_.size()); + + for (size_t j = 0; j < output_num_; ++j, ++index) { + void *summary_addr = output_summary_[j]->GetData(); + io_addrs[index] = reinterpret_cast(summary_addr); + } + } + + // if has input and output, need copy to ioaddr + if (io_nums > 0) { + // copy input and output to device + GE_CHK_RT_RET(rtMemcpy(input_output_addr_->GetData(), input_output_addr_->GetSize(), &io_addrs[0], + sizeof(uint64_t) * io_addrs.size(), RT_MEMCPY_HOST_TO_DEVICE)); + } + + // if has shape ext info, need copy to ext addr + if (ext_shape_nums > 0) { + uint32_t offset = ext_info_num_ * sizeof(ExtInfo) + sizeof(uint32_t); + uint32_t len = sizeof(MaxShape) * ext_shape_nums; + auto ext_addr_dev_base = reinterpret_cast(ext_info_addr_dev_->GetData()) + offset; + // copy input and output shapes to device + GE_CHK_RT_RET(rtMemcpy(ext_addr_dev_base, ext_info_addr_dev_->GetSize() - offset, &io_shapes[0], len, + RT_MEMCPY_HOST_TO_DEVICE)); + } + + GELOGI("AicpuTfNodeTask[%s] UpdateArgs end.", node_name.c_str()); + return SUCCESS; +} + +Status AicpuTfNodeTask::ExecuteAsync(TaskContext &context, std::function done_callback) { + auto node_name = node_->GetName(); + GELOGI("AicpuTfNodeTask[%s] ExecuteAsync Start. unknown_type=%d.", node_name.c_str(), unknown_type_); + + uint32_t flag = RT_KERNEL_DEFAULT; + GE_CHK_RT_RET(rtKernelLaunchEx(kernel_buf_->GetData(), kernel_buf_->GetSize(), flag, context.GetStream())); + + auto callback = [=, &context]() { + GELOGI("AicpuTfNodeTask[%s] callback start.", node_->GetName().c_str()); + RECORD_CALLBACK_EVENT(context.GetExecutionContext(), node_->GetName().c_str(), "[TaskCallback] Start"); + Status callback_ret = SUCCESS; + // check need update shape, call update shape. 
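+    // DEPEND_SHAPE_RANGE: the kernel wrote the real output shapes back into the ext info area;
+    // DEPEND_COMPUTE: shapes and data are recovered from the per-output ResultSummary instead.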
+    if (unknown_type_ == DEPEND_SHAPE_RANGE) {
+      // read the refreshed output shapes back from ext info
+      callback_ret = UpdateOutputShapeFromExtInfo();
+    } else if (unknown_type_ == DEPEND_COMPUTE) {
+      callback_ret = UpdateShapeAndDataByResultSummary(context);
+    }
+
+    GELOGI("AicpuTfNodeTask[%s] refresh output complete, ret = %d.", node_->GetName().c_str(), callback_ret);
+    RECORD_CALLBACK_EVENT(context.GetExecutionContext(), node_->GetName().c_str(), "[TaskCallback] End");
+
+    if (done_callback != nullptr) {
+      context.SetStatus(callback_ret);
+      done_callback();
+    }
+
+    GELOGI("AicpuTfNodeTask[%s] callback end.", node_->GetName().c_str());
+  };
+
+  GE_CHK_STATUS_RET_NOLOG(context.RegisterCallback(callback));
+
+  GELOGI("AicpuTfNodeTask[%s] ExecuteAsync end.", node_name.c_str());
+  return SUCCESS;
+}
+
+Status AiCpuNodeExecutor::PrepareTask(NodeTask &task, TaskContext &context) const {
+  // HBM memory was allocated in Init; here we only refresh the args
+  return task.UpdateArgs(context);
+}
+
+Status AiCpuNodeExecutor::LoadTask(const HybridModel &model, const NodePtr &node,
+                                   std::shared_ptr<NodeTask> &task) const {
+  GE_CHECK_NOTNULL(node);
+  GELOGI("Node[%s] create task start.", node->GetName().c_str());
+  auto task_defs = model.GetTaskDefs(node);
+  GE_CHECK_NOTNULL(task_defs);
+  GE_CHK_BOOL_EXEC((*task_defs).size() == 1, return PARAM_INVALID, "aicpu op[%s] task_def num[%zu] != 1",
+                   node->GetName().c_str(), (*task_defs).size());
+  auto aicpu_task = MakeShared<AicpuTfNodeTask>(node, (*task_defs)[0]);
+  GE_CHK_BOOL_TRUE_EXEC_WITH_LOG(aicpu_task == nullptr, return MEMALLOC_FAILED,
+                                 "create AicpuTfNodeTask for node %s failed", node->GetName().c_str());
+
+  GE_CHK_STATUS_RET(aicpu_task->Init(model), "AicpuTfNodeTask %s Init failed.", node->GetName().c_str());
+
+  task = std::move(aicpu_task);
+  GELOGI("Node[%s] create task end.", node->GetName().c_str());
+  return SUCCESS;
+}
+} // namespace hybrid
+} // namespace ge
diff --git a/src/ge/hybrid/node_executor/aicpu/aicpu_node_executor.h b/src/ge/hybrid/node_executor/aicpu/aicpu_node_executor.h
new file mode 100644
index 00000000..0444c2aa
--- /dev/null
+++ b/src/ge/hybrid/node_executor/aicpu/aicpu_node_executor.h
@@ -0,0 +1,112 @@
+/**
+ * Copyright 2019-2020 Huawei Technologies Co., Ltd
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#ifndef GE_HYBRID_KERNEL_AICPU_NODE_EXECUTOR_H_
+#define GE_HYBRID_KERNEL_AICPU_NODE_EXECUTOR_H_
+
+#include "external/graph/types.h"
+#include "cce/aicpu_engine_struct.h"
+#include "hybrid/node_executor/node_executor.h"
+
+namespace ge {
+namespace hybrid {
+class AicpuTfNodeTask : public NodeTask {
+ public:
+  AicpuTfNodeTask(const NodePtr &node, const domi::TaskDef &task_def) : node_(node), task_def_(task_def) {}
+
+  Status Init(const HybridModel &model);
+
+  ~AicpuTfNodeTask() override = default;
+
+  Status UpdateArgs(TaskContext &context) override;
+  Status ExecuteAsync(TaskContext &context, std::function<void()> done_callback) override;
+
+ private:
+  Status InitExtInfo();
+  Status InitForDependComputeTask();
+
+  Status SetShapeToBuf(const GeShape &shape, int64_t buf[], uint32_t buf_size);
+  void GetShapeFromBuf(const int64_t buf[], uint32_t buf_size, std::vector<int64_t> &dims);
+  Status UpdateOutputShapeFromExtInfo();
+
+  Status UpdateShapeAndDataByResultSummary(TaskContext &context);
+
+  Status UpdateShapeToOutputDesc(const GeShape &shape_new, size_t output_index, GeTensorDescPtr &output_desc);
+
+  ///
+  /// read result summary and prepare memory for the copy task.
+  /// @param context task context
+  /// @param out_shape_hbm for a scalar output, TensorBuffer->data is null and size is 0
+  /// @return SUCCESS on success, other values on failure
+  ///
+  Status ReadResultSummaryAndPrepareMemory(TaskContext &context,
+                                           std::vector<std::unique_ptr<TensorBuffer>> &out_shape_hbm);
+  Status CopyDataToHbm(TaskContext &context, const std::vector<std::unique_ptr<TensorBuffer>> &out_shape_hbm);
+
+  Status UpdateShapeByHbmBuffer(TaskContext &context, const std::vector<std::unique_ptr<TensorBuffer>> &out_shape_hbm);
+
+  // common helpers
+  static Status AllocTensorBuffer(size_t size, std::unique_ptr<TensorBuffer> &tensor_buffer);
+  static Status EnsureSessionCreated(uint64_t session_id);
+  static Status GenMemCopyTask(uint64_t count, STR_FWK_OP_KERNEL &task, string &task_info);
+
+ private:
+  const NodePtr node_;
+  // reference only, not owned.
+  const domi::TaskDef &task_def_;
+
+  UnknowShapeOpType unknown_type_ = DEPEND_IN_SHAPE;
+
+  size_t input_num_ = 0;
+
+  size_t output_num_ = 0;
+
+  // kernel buf, device mem
+  std::unique_ptr<TensorBuffer> kernel_buf_;
+
+  std::unique_ptr<TensorBuffer> kernel_workspace_;
+
+  // input and output addr, device mem
+  std::unique_ptr<TensorBuffer> input_output_addr_;
+
+  // ext info addr, device mem
+  std::unique_ptr<TensorBuffer> ext_info_addr_dev_;
+  std::unique_ptr ext_info_addr_host_;
+  uint32_t ext_info_num_ = 0;
+
+  // used only for DEPEND_COMPUTE ops
+  std::unique_ptr<TensorBuffer> copy_task_args_buf_;
+  std::vector<std::unique_ptr<TensorBuffer>> output_summary_;
+  std::vector output_summary_host_;
+
+  std::unique_ptr<TensorBuffer> copy_ioaddr_dev_;
+
+  std::unique_ptr<TensorBuffer> copy_input_release_flag_dev_;
+  std::unique_ptr<TensorBuffer> copy_input_data_size_dev_;
+  std::unique_ptr<TensorBuffer> copy_input_src_dev_;
+  std::unique_ptr<TensorBuffer> copy_input_dst_dev_;
+};
+
+class AiCpuNodeExecutor : public NodeExecutor {
+ public:
+  Status LoadTask(const HybridModel &model, const NodePtr &node, std::shared_ptr<NodeTask> &task) const override;
+
+  Status PrepareTask(NodeTask &task, TaskContext &context) const override;
+};
+
+} // namespace hybrid
+} // namespace ge
+#endif // GE_HYBRID_KERNEL_AICPU_NODE_EXECUTOR_H_
diff --git a/src/ge/hybrid/node_executor/compiledsubgraph/known_node_executor.cc b/src/ge/hybrid/node_executor/compiledsubgraph/known_node_executor.cc
new file mode 100644
index 00000000..97c8cdbe
--- /dev/null
+++ b/src/ge/hybrid/node_executor/compiledsubgraph/known_node_executor.cc
@@ -0,0 +1,165 @@
+/**
+ * Copyright 2019-2020 Huawei Technologies Co., Ltd
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "hybrid/node_executor/compiledsubgraph/known_node_executor.h"
+#include "cce/aicpu_engine_struct.h"
+#include "framework/common/debug/ge_log.h"
+#include "framework/common/fmk_error_codes.h"
+#include "common/ge/ge_util.h"
+#include "graph/attr_value.h"
+#include "graph/debug/ge_attr_define.h"
+#include "graph/load/new_model_manager/model_utils.h"
+#include "graph/load/new_model_manager/model_manager.h"
+
+namespace ge {
+namespace hybrid {
+
+REGISTER_NODE_EXECUTOR_BUILDER(NodeExecutorManager::ExecutorType::COMPILED_SUBGRAPH, KnownNodeExecutor);
+
+Status KnownNodeTask::ExecuteAsync(TaskContext &context, std::function<void()> done_callback) {
+  GELOGI("[%s] KnownNodeTask::ExecuteAsync in.", context.GetNodeName());
+  if (davinci_model_->GetTaskList().size() == 0) {
+    GELOGW("KnownNodeExecutor::ExecuteAsync davinci model has no task info.");
+
+    // TODO: if data is connected directly to netoutput, forward the address or copy the data?
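
The TODO above marks a real design choice: with no tasks to run, the node can either alias each input TensorValue into the matching output (what the code below does) or materialize a device-to-device copy. A copy-based variant might look like the following sketch; it assumes each output tensor is already allocated large enough, and that RT_MEMCPY_DEVICE_TO_DEVICE is the runtime's device-local copy kind:

// Sketch of the copy alternative (not the implemented path): clone each
// input into its pre-allocated output instead of aliasing the TensorValue.
for (int i = 0; i < context.NumInputs(); ++i) {
  TensorValue *in = context.MutableInput(i);
  TensorValue *out = context.MutableOutput(i);
  GE_CHECK_NOTNULL(in);
  GE_CHECK_NOTNULL(out);
  GE_CHK_RT_RET(rtMemcpy(out->MutableData(), out->GetSize(), in->GetData(), in->GetSize(),
                         RT_MEMCPY_DEVICE_TO_DEVICE));
}
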
+    if (context.NumInputs() == context.NumOutputs()) {
+      GELOGW("[%s] KnownNodeExecutor::ExecuteAsync davinci model has no task info.", context.GetNodeName());
+      for (int i = 0; i < context.NumInputs(); ++i) {
+        auto tensor = context.MutableInput(i);
+        GE_CHK_STATUS_RET(context.SetOutput(i, *tensor), "[%s] Failed to set output[%d]", context.GetNodeName(), i);
+      }
+    }
+
+    context.RegisterCallback(done_callback);
+    return SUCCESS;
+  }
+
+  rtError_t rt_ret;
+  GELOGI("rtModelExecute start.");
+  rt_ret = rtModelExecute(davinci_model_->GetRtModelHandle(), davinci_model_->GetRtModelStream(), 0);
+  GE_IF_BOOL_EXEC(rt_ret != RT_ERROR_NONE, GELOGE(rt_ret, "rtModelExecute error, ret: 0x%X", rt_ret); return FAILED;);
+  GELOGI("rtModelExecute end.");
+
+  GELOGI("rtStreamSynchronize start.");
+  rt_ret = rtStreamSynchronize(davinci_model_->GetRtModelStream());
+  GE_IF_BOOL_EXEC(rt_ret != RT_ERROR_NONE, GELOGE(rt_ret, "rtStreamSynchronize error, ret: 0x%X", rt_ret);
+                  return FAILED;);
+  GELOGI("rtStreamSynchronize end.");
+
+  context.RegisterCallback(done_callback);
+  GELOGI("[%s] KnownNodeTask::ExecuteAsync success.", context.GetNodeName());
+
+  return SUCCESS;
+}
+
+Status KnownNodeTask::UpdateArgs(TaskContext &context) {
+  GELOGI("[%s] KnownNodeExecutor::UpdateArgs in.", context.GetNodeName());
+  if (davinci_model_->GetTaskList().size() == 0) {
+    GELOGW("KnownNodeExecutor::UpdateArgs davinci model has no task info.");
+    return SUCCESS;
+  }
+
+  vector<void *> inputs;
+  for (int i = 0; i < context.NumInputs(); ++i) {
+    TensorValue *tv = context.MutableInput(i);
+    GE_CHECK_NOTNULL(tv);
+    inputs.emplace_back(tv->MutableData());
+  }
+
+  vector<void *> outputs;
+  for (int i = 0; i < context.NumOutputs(); ++i) {
+    TensorValue *tv = context.MutableOutput(i);
+    GE_CHECK_NOTNULL(tv);
+    outputs.emplace_back(tv->MutableData());
+  }
+
+  GE_CHK_STATUS_RET(davinci_model_->UpdateKnownNodeArgs(inputs, outputs),
+                    "known node task update known node args failed.");
+  GELOGI("[%s] KnownNodeExecutor::UpdateArgs success.", context.GetNodeName());
+  return SUCCESS;
+}
+
+Status KnownNodeTask::Init(TaskContext &context) {
+  // allocate output mem
+  GE_CHK_STATUS_RET(context.AllocateOutputs(), "known node task allocate output failed.");
+
+  // init davinci model
+  davinci_model_->InitRuntimeParams();
+  GE_CHK_STATUS_RET(davinci_model_->InitVariableMem(), "init variable mem failed.");
+  // allocate mem base
+  void *buffer = nullptr;
+  if (davinci_model_->TotalMemSize() != 0) {
+    GE_CHK_STATUS_RET(context.AllocateWorkspace(davinci_model_->TotalMemSize(), &buffer),
+                      "known node task allocate workspace failed.");
+    // update mem base
+    davinci_model_->UpdateMemBase(static_cast<uint8_t *>(buffer));
+    GELOGI("KnownNodeTask::Init mem base is %p, size %u.", davinci_model_->GetRuntimeParam().mem_base,
+           davinci_model_->GetRuntimeParam().mem_size);
+  }
+  if (!load_flag_) {
+    GE_CHK_STATUS_RET(davinci_model_->Init(), "KnownNodeExecutor::InitDavinciModel failed.");
+    load_flag_ = true;
+  } else {
+    GE_CHK_STATUS_RET(
+      ModelManager::GetInstance()->DestroyAicpuKernel(davinci_model_->GetSessionId(), davinci_model_->Id()),
+      "KnownNodeTask::Init destroy aicpu kernel failed.");
+  }
+  GELOGI("[%s] KnownNodeExecutor::Init success.", context.GetNodeName());
+  return SUCCESS;
+}
+
+Status KnownNodeExecutor::PrepareTask(NodeTask &task, TaskContext &context) const {
+  GELOGI("[%s] KnownNodeExecutor::PrepareTask in.", context.GetNodeName());
+
+  GE_CHK_STATUS_RET(task.Init(context), "known node init davinci model failed.");
+
+  GE_CHK_STATUS_RET(task.UpdateArgs(context), "known node task update args failed.");
+  GELOGI("[%s] KnownNodeExecutor::PrepareTask success.", context.GetNodeName());
+  return SUCCESS;
+}
+
+Status KnownNodeExecutor::LoadTask(const HybridModel &model, const NodePtr &node, shared_ptr<NodeTask> &task) const {
+  GE_CHECK_NOTNULL(node);
+  GELOGI("[%s] KnownNodeExecutor::LoadTask in.", node->GetName().c_str());
+
+  const GeModelPtr ge_model = model.GetGeModel(node);
+  GE_CHECK_NOTNULL(ge_model);
+
+  std::shared_ptr<DavinciModel> davinci_model = MakeShared<DavinciModel>(0, nullptr);
+  GE_CHECK_NOTNULL(davinci_model);
+
+  // set known node flag as true
+  davinci_model->SetKnownNode(true);
+  // set model id
+  davinci_model->SetId(model.GetModelId());
+
+  GE_CHK_STATUS_RET(davinci_model->Assign(ge_model), "KnownNodeExecutor::LoadTask davinci model assign failed.");
+
+  task = MakeShared<KnownNodeTask>(davinci_model);
+  GE_CHECK_NOTNULL(task);
+  GELOGI("[%s] KnownNodeExecutor::LoadTask success.", node->GetName().c_str());
+  return SUCCESS;
+}
+
+Status KnownNodeExecutor::ExecuteTask(NodeTask &task, TaskContext &context,
+                                      const std::function<void()> &callback) const {
+  GE_CHK_STATUS_RET(task.ExecuteAsync(context, callback), "Failed to execute task. node = %s",
+                    context.GetNodeItem().NodeName().c_str());
+  return SUCCESS;
+}
+
+} // namespace hybrid
+} // namespace ge
diff --git a/src/ge/hybrid/node_executor/compiledsubgraph/known_node_executor.h b/src/ge/hybrid/node_executor/compiledsubgraph/known_node_executor.h
new file mode 100644
index 00000000..5847c833
--- /dev/null
+++ b/src/ge/hybrid/node_executor/compiledsubgraph/known_node_executor.h
@@ -0,0 +1,56 @@
+/**
+ * Copyright 2019-2020 Huawei Technologies Co., Ltd
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#ifndef HYBRID_KNOWN_NODE_EXECUTOR_H_
+#define HYBRID_KNOWN_NODE_EXECUTOR_H_
+#include "hybrid/node_executor/node_executor.h"
+#include "hybrid/model/hybrid_model.h"
+#include "graph/op_desc.h"
+#include "graph/load/new_model_manager/davinci_model.h"
+
+namespace ge {
+namespace hybrid {
+class HybridModel;
+
+class KnownNodeTask : public NodeTask {
+ public:
+  KnownNodeTask(std::shared_ptr<DavinciModel> davinci_model) : davinci_model_(davinci_model) {}
+
+  ~KnownNodeTask() {}
+
+  Status UpdateArgs(TaskContext &context) override;
+  Status ExecuteAsync(TaskContext &context, std::function<void()> done_callback) override;
+  Status Init(TaskContext &context) override;
+
+ private:
+  std::shared_ptr<DavinciModel> davinci_model_ = nullptr;
+  bool load_flag_ = false;
+};
+
+class KnownNodeExecutor : public NodeExecutor {
+ public:
+  Status LoadTask(const HybridModel &model, const NodePtr &node, shared_ptr<NodeTask> &task) const;
+  Status PrepareTask(NodeTask &task, TaskContext &context) const;
+  Status ExecuteTask(NodeTask &task, TaskContext &context, const std::function<void()> &callback) const;
+  ~KnownNodeExecutor() {}
+
+ private:
+  std::shared_ptr<DavinciModel> davinci_model_ = nullptr;
+};
+} // namespace hybrid
+} // namespace ge
+
+#endif // HYBRID_KNOWN_NODE_EXECUTOR_H_
diff --git a/src/ge/hybrid/node_executor/hostcpu/ge_local_node_executor.cc b/src/ge/hybrid/node_executor/hostcpu/ge_local_node_executor.cc
new file mode 100644
index 00000000..2c849b59
--- /dev/null
+++ b/src/ge/hybrid/node_executor/hostcpu/ge_local_node_executor.cc
@@ -0,0 +1,206 @@
+/**
+ * Copyright 2019-2020 Huawei Technologies Co., Ltd
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "hybrid/node_executor/hostcpu/ge_local_node_executor.h"
+#include "graph/debug/ge_attr_define.h"
+#include "framework/common/util.h"
+#include "framework/common/types.h"
+#include "inc/kernel.h"
+#include "inc/kernel_factory.h"
+#include "common/ge/ge_util.h"
+
+namespace ge {
+namespace hybrid {
+
+REGISTER_NODE_EXECUTOR_BUILDER(NodeExecutorManager::ExecutorType::GE_LOCAL, GeLocalNodeExecutor);
+
+const std::unordered_map<std::string, std::vector<uint32_t>> RefInputTask::out_ref_input_index_ = {
+    {DATA, {}}, {AIPPDATA, {}}, {RESHAPE, {}}, {EXPANDDIMS, {}}};
+
+const std::unordered_set<std::string> DependInputShapeTask::depend_input_shape_ops_ = {SHAPE, SHAPEN, RANK, SIZE};
+
+Status RefInputTask::UpdateArgs(TaskContext &) {
+  // no args to update
+  return SUCCESS;
+}
+
+Status RefInputTask::Execute(TaskContext &context) {
+  auto iter = out_ref_input_index_.find(node_type_);
+  if (iter == out_ref_input_index_.end()) {
+    GELOGE(UNSUPPORTED, "node %s type %s can not use RefInputTask.", node_name_.c_str(), node_type_.c_str());
+    return UNSUPPORTED;
+  }
+
+  auto &ref_index = iter->second;
+  if (ref_index.empty()) {
+    return RefOneByOne(context);
+  } else {
+    return RefByOrder(ref_index, context);
+  }
+}
+
+Status RefInputTask::RefOneByOne(TaskContext &context) {
+  GELOGI("node %s type %s ref input one by one begin.", node_name_.c_str(), node_type_.c_str());
+  uint32_t input_num = context.NumInputs();
+  uint32_t output_num = context.NumOutputs();
+  if (output_num > input_num) {
+    GELOGE(INTERNAL_ERROR, "node %s type %s has %u outputs but only %u inputs, can't ref one by one.",
+           node_name_.c_str(), node_type_.c_str(), output_num, input_num);
+    return INTERNAL_ERROR;
+  }
+  for (uint32_t out_index = 0; out_index < output_num; ++out_index) {
+    auto input = context.GetInput(out_index);
+    GE_CHECK_NOTNULL(input);
+    context.SetOutput(out_index, *input);
+    GELOGD("node %s type %s output[%u] ref input[%u] addr=%p.", node_name_.c_str(), node_type_.c_str(), out_index,
+           out_index, input->GetData());
+  }
+  GELOGI("node %s type %s ref input one by one end.", node_name_.c_str(), node_type_.c_str());
+  return SUCCESS;
+}
+
+Status RefInputTask::RefByOrder(const std::vector<uint32_t> &ref_order, TaskContext &context) {
+  GELOGI("node %s type %s ref input by order begin.", node_name_.c_str(), node_type_.c_str());
+  uint32_t output_num = context.NumOutputs();
+  if (ref_order.size() != output_num) {
+    GELOGE(INTERNAL_ERROR, "node %s type %s has %u outputs but only has %zu out ref indices.", node_name_.c_str(),
+           node_type_.c_str(), output_num, ref_order.size());
+    return INTERNAL_ERROR;
+  }
+  for (uint32_t out_index = 0; out_index < output_num; ++out_index) {
+    auto ref_input_index = ref_order[out_index];
+    auto input = context.GetInput(ref_input_index);
+    GE_CHECK_NOTNULL(input);
+    context.SetOutput(out_index, *input);
+    GELOGD("node %s type %s output[%u] ref input[%u] addr=%p.", node_name_.c_str(), node_type_.c_str(), out_index,
+           ref_input_index, input->GetData());
+  }
+  GELOGI("node %s type %s ref input by order end.", node_name_.c_str(), node_type_.c_str());
+  return SUCCESS;
+}
+
+Status RefInputTask::ExecuteAsync(TaskContext &context, std::function<void()> done_callback) {
+  GE_CHK_STATUS_RET(Execute(context), "node:%s type:%s ref input task execute failed", node_name_.c_str(),
+                    node_type_.c_str());
+  if (done_callback != nullptr) {
+    // host CPU work is already done; no stream callback to register, invoke it directly.
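
Host CPU work completes before ExecuteAsync returns, which is why the callback can be invoked inline here, whereas device tasks must register it so it fires only once the stream drains. A small illustrative helper (hypothetical, not part of this patch) that captures the rule:

// Sketch: deliver a completion callback according to where the task ran.
Status DeliverCallback(TaskContext &context, const std::function<void()> &cb, bool runs_on_device) {
  if (runs_on_device) {
    // completion is only observable after the stream drains
    return context.RegisterCallback(cb);
  }
  cb();  // host work already finished; invoke inline
  return SUCCESS;
}
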
+    done_callback();
+  }
+  return SUCCESS;
+}
+
+bool RefInputTask::IsBelong(const std::string &op_type) { return out_ref_input_index_.count(op_type) > 0; }
+
+Status DependInputShapeTask::UpdateArgs(TaskContext &) {
+  // no args to update
+  return SUCCESS;
+}
+
+Status DependInputShapeTask::Execute(TaskContext &context) {
+  KernelFactory &factory = KernelFactory::Instance();
+  std::string node_type = node_->GetType();
+  auto kernel = factory.Create(node_type);
+  if (kernel == nullptr) {
+    GELOGE(UNSUPPORTED, "node %s type %s is not supported by host kernel.", node_->GetName().c_str(),
+           node_type.c_str());
+    return UNSUPPORTED;
+  }
+  std::vector<GeTensorPtr> outputs;
+  Status compute_ret = kernel->Compute(node_, outputs);
+  if (compute_ret != SUCCESS) {
+    GELOGE(compute_ret, "node %s type %s compute failed or is not implemented.", node_->GetName().c_str(),
+           node_type.c_str());
+    return compute_ret;
+  }
+  uint32_t output_num = context.NumOutputs();
+  if (output_num != outputs.size()) {
+    GELOGE(INTERNAL_ERROR, "node %s type %s has %u outputs, but kernel compute only produced %zu outputs.",
+           node_->GetName().c_str(), node_type.c_str(), output_num, outputs.size());
+    return INTERNAL_ERROR;
+  }
+
+  // alloc output
+  GE_CHK_STATUS_RET_NOLOG(context.AllocateOutputs());
+
+  // copy data to output
+  for (uint32_t i = 0; i < output_num; ++i) {
+    GeTensorPtr &tensor = outputs[i];
+    GE_CHECK_NOTNULL(tensor);
+    auto tensor_data = tensor->GetData();
+    auto tensor_value = context.MutableOutput(i);
+    GE_CHECK_NOTNULL(tensor_value);
+    if (tensor_data.GetSize() > tensor_value->GetSize()) {
+      GELOGE(INTERNAL_ERROR, "node:%s type:%s [%u]th compute data size=%zu, but context data size=%zu.",
+             node_->GetName().c_str(), node_type.c_str(), i, tensor_data.GetSize(), tensor_value->GetSize());
+      return INTERNAL_ERROR;
+    }
+
+    GELOGI("node:%s type:%s [%u]th output data=%p, out size=%zu, data size=%zu.", node_->GetName().c_str(),
+           node_type.c_str(), i, tensor_value->GetData(), tensor_value->GetSize(), tensor_data.GetSize());
+
+    if (tensor_data.GetSize() > 0) {
+      GE_CHK_RT_RET(rtMemcpy(tensor_value->MutableData(), tensor_value->GetSize(), tensor_data.GetData(),
+                             tensor_data.GetSize(), RT_MEMCPY_HOST_TO_DEVICE));
+    }
+    GELOGI("node:%s type:%s [%u]th set data success, data size=%zu.", node_->GetName().c_str(), node_type.c_str(), i,
+           tensor_data.GetSize());
+  }
+  return SUCCESS;
+}
+
+Status DependInputShapeTask::ExecuteAsync(TaskContext &context, std::function<void()> done_callback) {
+  GE_CHK_STATUS_RET(Execute(context), "node:%s type:%s depend input shape task execute failed",
+                    node_->GetName().c_str(), node_->GetType().c_str());
+  if (done_callback != nullptr) {
+    // host CPU work is already done; no stream callback to register, invoke it directly.
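
DependInputShapeTask delegates the real computation to a registered host kernel; end to end, the flow implemented above is: look the kernel up by op type, let it compute host-side GeTensor outputs, then copy each result into the pre-allocated device tensors. Condensed into a few lines (a recap of the Execute method above, not new behavior):

// Condensed recap of DependInputShapeTask::Execute:
auto kernel = KernelFactory::Instance().Create(node->GetType());  // host kernel by op type
GE_CHECK_NOTNULL(kernel);
std::vector<GeTensorPtr> outputs;
GE_CHK_STATUS_RET(kernel->Compute(node, outputs), "host kernel compute failed");
// each outputs[i] is then rtMemcpy'ed into context.MutableOutput(i) on device
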
+    done_callback();
+  }
+  return SUCCESS;
+}
+
+bool DependInputShapeTask::IsBelong(const std::string &op_type) { return depend_input_shape_ops_.count(op_type) > 0; }
+
+Status GeLocalNodeExecutor::PrepareTask(NodeTask &task, TaskContext &context) const { return task.UpdateArgs(context); }
+
+Status GeLocalNodeExecutor::LoadTask(const HybridModel &model, const NodePtr &node,
+                                     std::shared_ptr<NodeTask> &task) const {
+  GE_CHECK_NOTNULL(node);
+  std::string node_type = node->GetType();
+  if (RefInputTask::IsBelong(node_type)) {
+    GELOGI("node %s type %s is a ref-input task, use RefInputTask.", node->GetName().c_str(), node_type.c_str());
+    task = MakeShared<RefInputTask>(node);
+    if (task == nullptr) {
+      GELOGE(MEMALLOC_FAILED, "create RefInputTask for node %s failed.", node->GetName().c_str());
+      return MEMALLOC_FAILED;
+    }
+  } else if (DependInputShapeTask::IsBelong(node_type)) {
+    GELOGI("node %s type %s is a depend-input-shape task, use DependInputShapeTask.", node->GetName().c_str(),
+           node_type.c_str());
+    task = MakeShared<DependInputShapeTask>(node);
+    if (task == nullptr) {
+      GELOGE(MEMALLOC_FAILED, "create DependInputShapeTask for node %s type %s failed.", node->GetName().c_str(),
+             node_type.c_str());
+      return MEMALLOC_FAILED;
+    }
+  } else {
+    GELOGE(UNSUPPORTED, "node %s type %s is not supported by GeLocalNodeExecutor yet.", node->GetName().c_str(),
+           node_type.c_str());
+    return UNSUPPORTED;
+  }
+  return SUCCESS;
+}
+
+} // namespace hybrid
+} // namespace ge
\ No newline at end of file
diff --git a/src/ge/hybrid/node_executor/hostcpu/ge_local_node_executor.h b/src/ge/hybrid/node_executor/hostcpu/ge_local_node_executor.h
new file mode 100644
index 00000000..beb1f50d
--- /dev/null
+++ b/src/ge/hybrid/node_executor/hostcpu/ge_local_node_executor.h
@@ -0,0 +1,80 @@
+/**
+ * Copyright 2019-2020 Huawei Technologies Co., Ltd
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#ifndef GE_HYBRID_KERNEL_GE_LOCAL_NODE_EXECUTOR_H_
+#define GE_HYBRID_KERNEL_GE_LOCAL_NODE_EXECUTOR_H_
+
+#include
+#include
+#include "hybrid/node_executor/node_executor.h"
+
+namespace ge {
+namespace hybrid {
+
+class RefInputTask : public NodeTask {
+ public:
+  explicit RefInputTask(const NodePtr &node) : node_name_(node->GetName()), node_type_(node->GetType()) {}
+
+  ~RefInputTask() = default;
+
+  virtual Status UpdateArgs(TaskContext &context) override;
+  virtual Status ExecuteAsync(TaskContext &context, std::function<void()> done_callback) override;
+  static bool IsBelong(const std::string &op_type);
+
+ private:
+  Status Execute(TaskContext &context);
+  Status RefOneByOne(TaskContext &context);
+  Status RefByOrder(const std::vector<uint32_t> &ref_order, TaskContext &context);
+
+ private:
+  const std::string node_name_;
+  const std::string node_type_;
+
+  // key is op type, value is the output-to-input ref order;
+  // e.g. {1,0} means out[0] refs input[1] and out[1] refs input[0]; an empty vector means outputs ref inputs one by one
+  static const std::unordered_map<std::string, std::vector<uint32_t>> out_ref_input_index_;
+};
+
+class DependInputShapeTask : public NodeTask {
+ public:
+  explicit DependInputShapeTask(const NodePtr &node) : node_(node) {}
+
+  ~DependInputShapeTask() = default;
+
+  virtual Status UpdateArgs(TaskContext &context) override;
+  virtual Status ExecuteAsync(TaskContext &context, std::function<void()> done_callback) override;
+  static bool IsBelong(const std::string &op_type);
+
+ private:
+  Status Execute(TaskContext &context);
+
+ private:
+  const NodePtr node_;
+
+  // ops whose outputs depend only on their input shapes
+  static const std::unordered_set<std::string> depend_input_shape_ops_;
+};
+
+class GeLocalNodeExecutor : public NodeExecutor {
+ public:
+  Status PrepareTask(NodeTask &task, TaskContext &context) const override;
+
+  virtual Status LoadTask(const HybridModel &model, const NodePtr &node,
+                          std::shared_ptr<NodeTask> &task) const override;
+};
+} // namespace hybrid
+} // namespace ge
+#endif // GE_HYBRID_KERNEL_GE_LOCAL_NODE_EXECUTOR_H_
diff --git a/src/ge/hybrid/node_executor/node_executor.cc b/src/ge/hybrid/node_executor/node_executor.cc
new file mode 100644
index 00000000..38d37aa1
--- /dev/null
+++ b/src/ge/hybrid/node_executor/node_executor.cc
@@ -0,0 +1,150 @@
+/**
+ * Copyright 2019-2020 Huawei Technologies Co., Ltd
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "hybrid/node_executor/node_executor.h"
+#include "framework/common/debug/log.h"
+#include "init/gelib.h"
+#include "hybrid/model/hybrid_model.h"
+
+namespace ge {
+namespace hybrid {
+namespace {
+const char *const kEngineNameAiCore = "AIcoreEngine";
+const char *const kEngineNameGeLocal = "DNN_VM_GE_LOCAL_OP_STORE";
+const char *const kEngineNameAiCpu = "aicpu_kernel";
+} // namespace
+Status NodeExecutor::PrepareTask(NodeTask &task, TaskContext &context) const {
+  GE_CHK_STATUS_RET_NOLOG(context.AllocateOutputs());
+  GE_CHK_STATUS_RET_NOLOG(context.AllocateWorkspaces());
+  GE_CHK_STATUS_RET_NOLOG(task.UpdateArgs(context));
+  return SUCCESS;
+}
+
+Status NodeExecutor::ExecuteTask(NodeTask &task, TaskContext &context, const std::function<void()> &callback) const {
+  GE_CHK_STATUS_RET(task.ExecuteAsync(context, callback), "Failed to execute task. node = %s",
+                    context.GetNodeItem().NodeName().c_str());
+  return SUCCESS;
+}
+
+Status NodeExecutor::LoadTask(const HybridModel &model, const NodePtr &node, shared_ptr<NodeTask> &task) const {
+  return UNSUPPORTED;
+}
+
+Status NodeExecutor::CompileTask(const HybridModel &model, const NodePtr &node, shared_ptr<NodeTask> &task) const {
+  return UNSUPPORTED;
+}
+
+Status NodeExecutorManager::EnsureInitialized() {
+  std::lock_guard<std::mutex> lk(mu_);
+  if (initialized_) {
+    return SUCCESS;
+  }
+
+  engine_mapping_.emplace(kEngineNameAiCore, NodeExecutorManager::ExecutorType::AICORE);
+  engine_mapping_.emplace(kEngineNameGeLocal, NodeExecutorManager::ExecutorType::GE_LOCAL);
+  engine_mapping_.emplace(kEngineNameAiCpu, NodeExecutorManager::ExecutorType::AICPU_TF);
+
+  std::shared_ptr<GELib> instance_ptr = GELib::GetInstance();
+  if ((instance_ptr == nullptr) || (!instance_ptr->InitFlag())) {
+    GELOGW("GELib not initialized");
+    return FAILED;
+  }
+
+  OpsKernelManager &ops_kernel_manager = instance_ptr->OpsKernelManagerObj();
+  for (auto &it : ops_kernel_manager.GetAllOpsKernelInfoStores()) {
+    GELOGD("add kernel store: %s", it.first.c_str());
+    kernel_stores_.emplace(it.first, it.second);
+  }
+
+  GELOGI("Start to Initialize NodeExecutors");
+  for (auto &it : builders_) {
+    auto engine_type = it.first;
+    auto build_fn = it.second;
+    GE_CHECK_NOTNULL(build_fn);
+    auto executor = std::unique_ptr<NodeExecutor>(build_fn());
+    if (executor == nullptr) {
+      GELOGE(INTERNAL_ERROR, "Failed to create executor for engine type = %d", static_cast<int>(engine_type));
+      return INTERNAL_ERROR;
+    }
+
+    GELOGD("Executor of engine type = %d was created successfully", static_cast<int>(engine_type));
+    GE_CHK_STATUS_RET(executor->Initialize(), "Failed to initialize NodeExecutor of type = %d",
+                      static_cast<int>(engine_type));
+    executors_.emplace(engine_type, std::move(executor));
+  }
+
+  initialized_ = true;
+  GELOGI("Initialized NodeExecutors successfully");
+  return SUCCESS;
+}
+
+NodeExecutorManager::ExecutorType NodeExecutorManager::ResolveExecutorType(Node &node) const {
+  auto op_type = node.GetType();
+  if (op_type == PARTITIONEDCALL) {
+    return ExecutorType::COMPILED_SUBGRAPH;
+  }
+
+  // NetOutput is assigned to the rts kernel store, so route it (and Variable) to the GE-local executor
+  if (op_type == NETOUTPUT || op_type == VARIABLE) {
+    return ExecutorType::GE_LOCAL;
+  }
+
+  auto op_desc = node.GetOpDesc();  // checked before
+  const auto &lib_name = op_desc->GetOpKernelLibName();
+  auto it = engine_mapping_.find(lib_name);
+  if (it == engine_mapping_.end()) {
+    GELOGE(UNSUPPORTED, "KernelLib not supported. node = %s, lib_name = %s", node.GetName().c_str(), lib_name.c_str());
+    return ExecutorType::RESERVED;
+  }
+
+  return it->second;
+}
+
+Status NodeExecutorManager::GetExecutor(Node &node, const NodeExecutor **executor) const {
+  auto executor_type = ResolveExecutorType(node);
+  const auto it = executors_.find(executor_type);
+  if (it == executors_.end()) {
+    GELOGE(INTERNAL_ERROR, "Failed to get executor by type: %d", static_cast<int>(executor_type));
+    return INTERNAL_ERROR;
+  }
+
+  *executor = it->second.get();
+  return SUCCESS;
+}
+
+void NodeExecutorManager::RegisterExecutorBuilder(NodeExecutorManager::ExecutorType executor_type,
+                                                  const std::function<NodeExecutor *()> &builder) {
+  builders_.emplace(executor_type, builder);
+}
+
+Status NodeExecutorManager::CalcOpRunningParam(Node &node) const {
+  auto op_desc = node.GetOpDesc();
+  GE_CHECK_NOTNULL(op_desc);
+  auto it = kernel_stores_.find(op_desc->GetOpKernelLibName());
+  if (it == kernel_stores_.end()) {
+    GELOGE(INTERNAL_ERROR, "Failed to get OpKernelStore. libName = %s, node = %s",
libName = %s, node = %s", + op_desc->GetOpKernelLibName().c_str(), op_desc->GetName().c_str()); + return INTERNAL_ERROR; + } + + return it->second->CalcOpRunningParam(node); +} + +NodeExecutorRegistrar::NodeExecutorRegistrar(NodeExecutorManager::ExecutorType executor_type, + NodeExecutor *(*builder)()) { + NodeExecutorManager::GetInstance().RegisterExecutorBuilder(executor_type, builder); +} +} // namespace hybrid +} // namespace ge \ No newline at end of file diff --git a/src/ge/hybrid/node_executor/node_executor.h b/src/ge/hybrid/node_executor/node_executor.h new file mode 100644 index 00000000..613c0bb1 --- /dev/null +++ b/src/ge/hybrid/node_executor/node_executor.h @@ -0,0 +1,102 @@ +/** + * Copyright 2019-2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef GE_HYBRID_KERNEL_NODE_EXECUTOR_H_ +#define GE_HYBRID_KERNEL_NODE_EXECUTOR_H_ + +#include "external/ge/ge_api_error_codes.h" +#include "common/opskernel/ops_kernel_info_store.h" +#include "graph/node.h" +#include "proto/task.pb.h" +#include "task_context.h" + +namespace ge { +namespace hybrid { +class HybridModel; + +class NodeTask { + public: + NodeTask() = default; + virtual ~NodeTask() = default; + virtual Status UpdateArgs(TaskContext &context) = 0; + virtual Status ExecuteAsync(TaskContext &context, std::function done_callback) = 0; + virtual Status Init(TaskContext &context) { return SUCCESS; } +}; + +class NodeExecutor { + public: + NodeExecutor() = default; + virtual ~NodeExecutor() = default; + + virtual Status Initialize() { return SUCCESS; } + + virtual Status Finalize() { return SUCCESS; } + + virtual Status LoadTask(const HybridModel &model, const NodePtr &node, std::shared_ptr &task) const; + + virtual Status CompileTask(const HybridModel &model, const NodePtr &node, std::shared_ptr &task) const; + + virtual Status PrepareTask(NodeTask &task, TaskContext &context) const; + virtual Status ExecuteTask(NodeTask &task, TaskContext &context, const std::function &callback) const; +}; + +class NodeExecutorManager { + public: + enum class ExecutorType { AICORE, GE_LOCAL, AICPU_TF, AICPU_CUSTOM, COMPILED_SUBGRAPH, HCCL, RESERVED }; + + static NodeExecutorManager &GetInstance() { + static NodeExecutorManager instance; + return instance; + } + + Status CalcOpRunningParam(Node &node) const; + + void RegisterExecutorBuilder(ExecutorType executor_type, const std::function &builder); + + Status EnsureInitialized(); + + Status GetExecutor(Node &node, const NodeExecutor **executor) const; + + ExecutorType ResolveExecutorType(Node &node) const; + + std::map> executors_; + std::map> builders_; + std::map> kernel_stores_; + std::map engine_mapping_; + std::mutex mu_; + bool initialized_ = false; +}; + +class NodeExecutorRegistrar { + public: + NodeExecutorRegistrar(NodeExecutorManager::ExecutorType executor_type, NodeExecutor *(*builder)()); + ~NodeExecutorRegistrar() = default; +}; +} // namespace hybrid +} // namespace ge + +#define REGISTER_NODE_EXECUTOR_BUILDER(engine_type, 
executor) \ + REGISTER_NODE_EXECUTOR_BUILDER_UNIQ_HELPER(__COUNTER__, engine_type, executor) + +#define REGISTER_NODE_EXECUTOR_BUILDER_UNIQ_HELPER(ctr, engine_type, executor) \ + REGISTER_NODE_EXECUTOR_BUILDER_UNIQ(ctr, engine_type, executor) + +#define REGISTER_NODE_EXECUTOR_BUILDER_UNIQ(ctr, engine_type, executor) \ + static ::ge::hybrid::NodeExecutorRegistrar register_##ctr __attribute__((unused)) = \ + ::ge::hybrid::NodeExecutorRegistrar( \ + engine_type, []() -> ::ge::hybrid::NodeExecutor * { return new (std::nothrow) executor(); }) + +#endif // GE_HYBRID_KERNEL_NODE_EXECUTOR_H_ diff --git a/src/ge/hybrid/node_executor/task_context.cc b/src/ge/hybrid/node_executor/task_context.cc new file mode 100644 index 00000000..91bcc402 --- /dev/null +++ b/src/ge/hybrid/node_executor/task_context.cc @@ -0,0 +1,293 @@ +/** + * Copyright 2019-2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "task_context.h" +#include "framework/common/ge_inner_error_codes.h" +#include "framework/common/debug/log.h" +#include "graph/utils/tensor_utils.h" +#include "hybrid/executor/hybrid_execution_context.h" + +namespace ge { +namespace hybrid { +TaskContext::TaskContext(GraphExecutionContext *execution_context) : execution_context_(execution_context) {} +TaskContext::~TaskContext() { + GELOGD("To execute ~TaskContext(). node = %s", node_item_->NodeName().c_str()); + for (auto ws_addr : workspaces_) { + execution_context_->allocator->Deallocate(ws_addr); + } + + // release output + for (int i = 0; i < NumOutputs(); ++i) { + auto output_tensor = MutableOutput(i); + if (output_tensor != nullptr) { + output_tensor->Destroy(); + } + } +} + +std::unique_ptr TaskContext::Create(const NodeItem &node_item, GraphExecutionContext *graph_context) { + GELOGI("To create task context for node %s, input start = %d, num_inputs = %d, output start = %d, num_outputs = %d", + node_item.NodeName().c_str(), node_item.input_start, node_item.num_inputs, node_item.output_start, + node_item.num_outputs); + auto task_context = std::unique_ptr(new (std::nothrow) TaskContext(graph_context)); + if (task_context == nullptr) { + GELOGE(MEMALLOC_FAILED, "Failed to create instance of TaskContext. node = %s", node_item.NodeName().c_str()); + return nullptr; + } + + task_context->node_item_ = &node_item; + task_context->inputs_start_ = graph_context->all_inputs.data() + node_item.input_start; + task_context->outputs_start_ = graph_context->all_outputs.data() + node_item.output_start; + return task_context; +} + +int TaskContext::NumInputs() const { return node_item_->num_inputs; } + +int TaskContext::NumOutputs() const { return node_item_->num_outputs; } + +TensorValue *TaskContext::MutableInput(int index) { + if (index < 0 || index > node_item_->num_inputs) { + GELOGE(PARAM_INVALID, "Index out of range. 
index = %d, num_inputs = %d", index, node_item_->num_inputs); + return nullptr; + } + + return inputs_start_ + index; +} + +const TensorValue *TaskContext::GetOutput(int index) const { + if (index < 0 || index > node_item_->num_outputs) { + GELOGE(PARAM_INVALID, "Index out of range. index = %d, num_outputs = %d", index, node_item_->num_outputs); + return nullptr; + } + + return outputs_start_ + index; +} + +TensorValue *TaskContext::MutableOutput(int index) { + if (index < 0 || index > node_item_->num_outputs) { + GELOGE(PARAM_INVALID, "Index out of range. index = %d, num_outputs = %d", index, node_item_->num_outputs); + return nullptr; + } + + return outputs_start_ + index; +} + +std::size_t TaskContext::NumWorkspaces() const { return workspaces_.size(); } + +void *TaskContext::MutableWorkspace(int index) { + if (index < 0 || static_cast(index) >= workspaces_.size()) { + GELOGE(PARAM_INVALID, "Index out of range. index = %d, num_workspaces = %d", index, node_item_->num_outputs); + return nullptr; + } + + return workspaces_[index]; +} + +const TensorValue *TaskContext::GetInput(int index) const { + if (index < 0 || index > node_item_->num_inputs) { + GELOGE(PARAM_INVALID, "Index out of range. index = %d, num_inputs = %d", index, node_item_->num_inputs); + return nullptr; + } + + return inputs_start_ + index; +} + +Status TaskContext::AllocateWorkspaces() { + auto workspace_sizes = node_item_->node->GetOpDesc()->GetWorkspaceBytes(); + for (auto size : workspace_sizes) { + void *workspace = execution_context_->allocator->Allocate(size); + if (workspace == nullptr) { + GELOGE(MEMALLOC_FAILED, "Failed to allocate workspace of size: %ld", size); + return MEMALLOC_FAILED; + } + + workspaces_.emplace_back(workspace); + } + return SUCCESS; +} + +Status TaskContext::RegisterCallback(const std::function &callback_fun) const { + return execution_context_->callback_manager->RegisterCallback(callback_fun); +} + +string TaskContext::TensorDesc2String(const GeTensorDesc &desc) { + std::stringstream ss; + ss << "[TensorDesc] "; + ss << "DataType = " << desc.GetDataType(); + ss << ", Format = " << desc.GetFormat(); + ss << ", Shape = ["; + for (auto dim : desc.GetShape().GetDims()) { + ss << dim << ", "; + } + ss << "]"; + + return ss.str(); +} + +Status TaskContext::AllocateTensor(const GeTensorDesc &tensor_desc, TensorValue &tensor) { + int64_t size = 0; + if (ge::TensorUtils::GetSize(tensor_desc, size) != GRAPH_SUCCESS) { + GELOGE(INTERNAL_ERROR, "Failed to get tensor size"); + return INTERNAL_ERROR; + } + + if (size == 0) { + GELOGW("size from tensor_desc == 0"); + } + + auto buffer = TensorBuffer::Create(execution_context_->allocator, size); + GE_CHECK_NOTNULL(buffer); + tensor = TensorValue(shared_ptr(buffer.release())); + return SUCCESS; +} + +Status TaskContext::AllocateOutput(int index, const GeTensorDesc &tensor_desc, TensorValue **tensor) { + GELOGI("To allocate output for node: %s. index = %d, tensor desc = %s", node_item_->NodeName().c_str(), index, + TensorDesc2String(tensor_desc).c_str()); + + if (index < 0 || index >= node_item_->num_outputs) { + GELOGE(PARAM_INVALID, "output index out of range. 
num_output = %d, index = %d", node_item_->num_outputs, index); + return PARAM_INVALID; + } + + if (outputs_start_[index].GetData() != nullptr) { + GELOGI("already allocated as net output"); + return SUCCESS; + } + + auto it = node_item_->ref_outputs.find(index); + if (it != node_item_->ref_outputs.end()) { + auto &ref_node = it->second; + GELOGD("source node of %s:%d = %s, op_type = %s", node_item_->NodeName().c_str(), index, + ref_node->GetName().c_str(), ref_node->GetType().c_str()); + + TensorValue *ref_tensor = execution_context_->model->GetVariable(ref_node->GetName()); + GE_CHECK_NOTNULL(ref_tensor); + outputs_start_[index] = *ref_tensor; + } else { + GE_CHK_STATUS_RET_NOLOG(AllocateTensor(tensor_desc, outputs_start_[index])); + GELOGD("Allocating output successfully. node: %s. index = %d, size = %zu", node_item_->NodeName().c_str(), index, + outputs_start_[index].GetSize()); + } + + if (execution_context_->trace_enabled) { + outputs_start_[index].SetName(node_item_->NodeName() + "_out_" + std::to_string(index)); + } + + if (tensor != nullptr) { + *tensor = outputs_start_ + index; + } + + return SUCCESS; +} + +Status TaskContext::AllocateOutputs() { + for (int i = 0; i < node_item_->num_outputs; ++i) { + const auto &output_desc = node_item_->op_desc->MutableOutputDesc(i); + GE_CHECK_NOTNULL(output_desc); + GE_CHK_STATUS_RET_NOLOG(AllocateOutput(i, *output_desc, nullptr)); + } + + return SUCCESS; +} + +Status TaskContext::AllocateTemp(size_t size, TensorValue &tensor) { + auto buffer = TensorBuffer::Create(execution_context_->allocator, size); + if (buffer == nullptr) { + GELOGE(MEMALLOC_FAILED, "Failed to allocate buffer of size: %zu", size); + return MEMALLOC_FAILED; + } + + tensor = TensorValue(shared_ptr(buffer.release())); + return SUCCESS; +} + +const NodeItem &TaskContext::GetNodeItem() const { return *node_item_; } + +Status TaskContext::SetOutput(int index, const TensorValue &tensor) { + if (index < 0 || index >= node_item_->num_outputs) { + GELOGE(PARAM_INVALID, "output index out of range. 
num_output = %d, index = %d", node_item_->num_outputs, index); + return PARAM_INVALID; + } + + GELOGD("Set %s:%d with tensor: %s", node_item_->NodeName().c_str(), index, tensor.DebugString().c_str()); + outputs_start_[index] = tensor; + return SUCCESS; +} + +rtStream_t TaskContext::GetStream() { return execution_context_->stream; } + +int64_t TaskContext::GetSessionId() { return execution_context_->session_id; } + +Status TaskContext::GetStatus() const { return status_; } + +void TaskContext::SetStatus(Status status) { status_ = status; } + +Status TaskContext::AllocateWorkspace(size_t size, void **buffer, void *ori_addr) { + GE_CHECK_NOTNULL(buffer); + *buffer = execution_context_->allocator->Allocate(size, ori_addr); + if (*buffer == nullptr) { + GELOGE(MEMALLOC_FAILED, "Failed to allocate workspace of size = %zu", size); + return MEMALLOC_FAILED; + } + + GELOGD("Allocating workspace of size = %zu successfully", size); + workspaces_.emplace_back(*buffer); + return SUCCESS; +} + +Status TaskContext::PropagateOutputs() { + // propagate outputs + for (int i = 0; i < NumOutputs(); ++i) { + auto tensor = MutableOutput(i); + GE_CHECK_NOTNULL(tensor); + if (tensor->GetData() == nullptr) { + GELOGD("[%s] Node output[%d] is null.", node_item_->NodeName().c_str(), i); + } + auto &output_nodes = node_item_->outputs[i]; + for (auto &dst_input_index_and_node : output_nodes) { + auto dst_input_idx = dst_input_index_and_node.first; + auto dst_node_item = dst_input_index_and_node.second; + GELOGI( + "Propagate output of node %s, output index = %d, dst node = %s, dst_input_index = %d, dst_input_offset = %d, " + "addr = %p", + node_item_->NodeName().c_str(), i, dst_node_item->NodeName().c_str(), dst_input_idx, + dst_node_item->input_start + dst_input_idx, + execution_context_->all_inputs.data() + dst_node_item->input_start + dst_input_idx); + execution_context_->all_inputs[dst_node_item->input_start + dst_input_idx] = *tensor; + if (execution_context_->trace_enabled) { + execution_context_->all_inputs[dst_node_item->input_start + dst_input_idx].SetName(node_item_->NodeName() + + "_in_" + std::to_string(i)); + } + } + } + + return SUCCESS; +} + +const void *TaskContext::GetVarBaseAddr() { return execution_context_->model->GetVarMemBase(); } + +const char *TaskContext::GetNodeName() const { return node_item_->NodeName().c_str(); } + +void TaskContext::ReleaseInput(int index) { + auto input_tensor = MutableInput(index); + if (input_tensor != nullptr) { + input_tensor->Destroy(); + GELOGD("[%s] Tensor of input[%d] released", GetNodeName(), index); + } +} +} // namespace hybrid +} // namespace ge \ No newline at end of file diff --git a/src/ge/hybrid/node_executor/task_context.h b/src/ge/hybrid/node_executor/task_context.h new file mode 100644 index 00000000..841dcb17 --- /dev/null +++ b/src/ge/hybrid/node_executor/task_context.h @@ -0,0 +1,85 @@ +/** + * Copyright 2019-2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */
+
+#ifndef GE_HYBRID_KERNEL_TASK_CONTEXT_H_
+#define GE_HYBRID_KERNEL_TASK_CONTEXT_H_
+
+#include
+#include
+#include
+#include "external/ge/ge_api_error_codes.h"
+#include "hybrid/common/tensor_value.h"
+#include "hybrid/executor/rt_callback_manager.h"
+#include "hybrid/model/node_item.h"
+
+namespace ge {
+namespace hybrid {
+class GraphExecutionContext;
+
+class TaskContext {
+ public:
+  static std::unique_ptr<TaskContext> Create(const NodeItem &node_item, GraphExecutionContext *graph_context);
+
+  ~TaskContext();
+
+  int NumInputs() const;
+  int NumOutputs() const;
+  size_t NumWorkspaces() const;
+  const NodeItem &GetNodeItem() const;
+  const char *GetNodeName() const;
+  TensorValue *MutableInput(int index);
+  void ReleaseInput(int index);
+  const TensorValue *GetInput(int index) const;
+  const TensorValue *GetOutput(int index) const;
+  TensorValue *MutableOutput(int index);
+  rtStream_t GetStream();
+  int64_t GetSessionId();
+
+  Status SetOutput(int index, const TensorValue &tensor);
+  Status AllocateOutput(int index, const GeTensorDesc &tensor_desc, TensorValue **tensor);
+  Status AllocateOutputs();
+  Status AllocateWorkspaces();
+  Status AllocateWorkspace(size_t size, void **buffer, void *ori_addr = nullptr);
+
+  const GraphExecutionContext *GetExecutionContext() { return execution_context_; }
+
+  Status AllocateTemp(size_t size, TensorValue &tensor);
+  void *MutableWorkspace(int index);
+  const void *GetVarBaseAddr();
+
+  Status RegisterCallback(const std::function<void()> &callback_fun) const;
+
+  Status PropagateOutputs();
+
+  Status GetStatus() const;
+
+  void SetStatus(Status status);
+
+ private:
+  explicit TaskContext(GraphExecutionContext *execution_context);
+  TensorValue *inputs_start_ = nullptr;
+  TensorValue *outputs_start_ = nullptr;
+  static string TensorDesc2String(const GeTensorDesc &desc);
+  Status AllocateTensor(const GeTensorDesc &tensor_desc, TensorValue &tensor);
+
+  GraphExecutionContext *execution_context_;
+  const NodeItem *node_item_ = nullptr;
+  Status status_ = SUCCESS;
+  std::vector<void *> workspaces_;
+};
+} // namespace hybrid
+} // namespace ge
+#endif // GE_HYBRID_KERNEL_TASK_CONTEXT_H_
diff --git a/src/ge/inc/graph_pass.h b/src/ge/inc/graph_pass.h
index b50269d0..d4abdd2f 100644
--- a/src/ge/inc/graph_pass.h
+++ b/src/ge/inc/graph_pass.h
@@ -45,6 +45,7 @@ class GraphPass : public Pass {
   /// @author
   ///
   virtual Status Run(ge::ComputeGraphPtr graph) = 0;
+  virtual Status ClearStatus() { return SUCCESS; }
   static void RecordOriginalNames(std::vector<ge::NodePtr> original_nodes, const ge::NodePtr &node) {
     GE_CHECK_NOTNULL_JUST_RETURN(node);
     std::vector<std::string> original_names;
diff --git a/src/ge/inc/pass_manager.h b/src/ge/inc/pass_manager.h
index 6a40b173..4efdaf94 100644
--- a/src/ge/inc/pass_manager.h
+++ b/src/ge/inc/pass_manager.h
@@ -35,14 +35,14 @@ class PassManager {
   /// get graph passes
   /// @author
   ///
-  const vector<GraphPass *> &GraphPasses() const;
+  const vector<std::pair<std::string, GraphPass *>> &GraphPasses() const;
 
   ///
   /// Add graph pass
   /// @param [in] pass Pass to be added, it will be destroyed when pass manager destroys.
  /// @author
  ///
-  Status AddPass(GraphPass *pass);
+  Status AddPass(const string &pass_name, GraphPass *pass);
 
  ///
  /// Optimize graph with added pass
@@ -63,12 +63,12 @@ class PassManager {
  /// @return others optimized failed
  /// @author
  ///
-  static Status Run(const ge::ComputeGraphPtr &graph, vector<GraphPass *> &passes);
+  static Status Run(const ge::ComputeGraphPtr &graph, vector<std::pair<std::string, GraphPass *>> &passes);
 
   ~PassManager();
 
  private:
-  vector<GraphPass *> graph_passes_;
+  vector<std::pair<std::string, GraphPass *>> names_to_graph_passes_;
 };
 } // namespace ge
 #endif // GE_INC_PASS_MANAGER_H_
diff --git a/src/ge/init/gelib.cc b/src/ge/init/gelib.cc
index db12ef79..0a1178b1 100644
--- a/src/ge/init/gelib.cc
+++ b/src/ge/init/gelib.cc
@@ -147,35 +147,6 @@ Status GELib::InnerInitialize(const map<string, string> &options) {
   return SUCCESS;
 }
 
-void GELib::SetIncreBuild(const map<string, string> &options) {
-  auto iter = options.find(OPTION_EXEC_ENABLE_INCRE_BUILD);
-  if (iter != options.end()) {
-    const std::string enable_incre_build = "true";
-    const std::string disable_incre_build = "false";
-    if (iter->second == enable_incre_build) {
-      is_incre_build_ = true;
-      GELOGI("Enable incre build.");
-      auto path_iter = options.find(OPTION_EXEC_INCRE_BUILD_CACHE_PATH);
-      if (path_iter != options.end()) {
-        std::string cache_path = path_iter->second;
-        if (!cache_path.empty() && cache_path[cache_path.size() - 1] != '/') {
-          cache_path += "/";
-        }
-        incre_build_cache_path_ = cache_path;
-      } else {
-        incre_build_cache_path_ = ".ge_cache/";
-      }
-      GELOGD("Using incre build cache path: %s.", incre_build_cache_path_.c_str());
-    } else if (iter->second == disable_incre_build) {
-      is_incre_build_ = false;
-      GELOGI("Disable incre build.");
-    } else {
-      is_incre_build_ = false;
-      GELOGW("Invalid ENABLE_INCRE_BUILD option, it should be true or false.");
-    }
-  }
-}
-
 Status GELib::SystemInitialize(const map<string, string> &options) {
   Status status = FAILED;
   auto iter = options.find(OPTION_GRAPH_RUN_MODE);
@@ -214,8 +185,6 @@ Status GELib::SystemInitialize(const map<string, string> &options) {
       PropertiesManager::Instance().SetDumpMode(dump_mode);
     }
   }
-  // check incre build flag
-  SetIncreBuild(options);
 
   if (is_train_mode_) {
     InitOptions(options);
@@ -257,6 +226,11 @@ void GELib::InitOptions(const map<string, string> &options) {
   if (iter != options.end()) {
     std::istringstream(iter->second) >> this->options_.isUseHcom;
   }
+  this->options_.isUseHvd = false;
+  iter = options.find(OPTION_EXEC_IS_USEHVD);
+  if (iter != options.end()) {
+    std::istringstream(iter->second) >> this->options_.isUseHvd;
+  }
   this->options_.deployMode = false;
   iter = options.find(OPTION_EXEC_DEPLOY_MODE);
   if (iter != options.end()) {
diff --git a/src/ge/init/gelib.h b/src/ge/init/gelib.h
index 60cbc0c0..81d36612 100644
--- a/src/ge/init/gelib.h
+++ b/src/ge/init/gelib.h
@@ -83,7 +83,6 @@ class GELib {
   Status SetRTSocVersion(const map<string, string> &options);
   void RollbackInit();
   void InitOptions(const map<string, string> &options);
-  void SetIncreBuild(const map<string, string> &options);
 
   DNNEngineManager engineManager_;
   OpsKernelManager opsManager_;
diff --git a/src/ge/ir_build/atc_ir_common.cc b/src/ge/ir_build/atc_ir_common.cc
index 109e6e6f..4a5b5bd4 100644
--- a/src/ge/ir_build/atc_ir_common.cc
+++ b/src/ge/ir_build/atc_ir_common.cc
@@ -19,6 +19,7 @@
 #include "framework/common/types.h"
 #include "framework/common/util.h"
 #include "common/util/error_manager/error_manager.h"
+#include "external/ge/ge_api_types.h"
 
 using std::pair;
 using std::string;
@@ -28,6 +29,11 @@ namespace ge {
 namespace {
 const int64_t kDynamicInputDim = -1;
 const int64_t kDynamicImageSizeNum = 2;
+// datatypes/formats from user to GE; to be unified into a util interface file later
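
The tables declared just below double as whitelist and parser: a user-supplied string is valid exactly when it is a key, and the lookup simultaneously yields the GE datatype. The real CheckOutputTypeParamValid declared in atc_ir_common.h is implemented later in this file; the following is only an illustrative sketch of the kind of check such a table enables:

// Illustrative only: validate --output_type against the support table.
Status CheckOutputTypeSketch(const std::string &output_type, ge::DataType &dt) {
  if (output_type.empty()) {
    return ge::SUCCESS;  // option not set, nothing to validate
  }
  auto it = kOutputTypeSupportDatatype.find(output_type);
  if (it == kOutputTypeSupportDatatype.end()) {
    GELOGE(ge::PARAM_INVALID, "Invalid --output_type[%s], only FP32/FP16/UINT8 are supported.",
           output_type.c_str());
    return ge::PARAM_INVALID;
  }
  dt = it->second;  // the parsed GE datatype
  return ge::SUCCESS;
}
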
+const std::map kOutputTypeSupportDatatype = { + {"FP32", ge::DT_FLOAT}, {"FP16", ge::DT_FLOAT16}, {"UINT8", ge::DT_UINT8}}; +const std::set kBufferOptimizeSupportOption = {"l1_optimize", "l2_optimize", "off_optimize", + "l1_and_l2_optimize"}; } // namespace bool CheckDynamicBatchSizeInputShapeValid(unordered_map> shape_map, @@ -37,7 +43,7 @@ bool CheckDynamicBatchSizeInputShapeValid(unordered_map> vector shape = iter->second; if (shape.size() < 1) { ErrorManager::GetInstance().ATCReportErrMessage("E10017"); - GELOGE(ge::PARAM_INVALID, "The input shape size can not be less than 0 in dynamic batchsize scenario."); + GELOGE(ge::PARAM_INVALID, "--input_shape's shape size can not be less than 1 when set --dynamic_batch_size."); return false; } if (shape[0] == kDynamicInputDim) { @@ -45,7 +51,7 @@ bool CheckDynamicBatchSizeInputShapeValid(unordered_map> if (shape[i] < 1) { ErrorManager::GetInstance().ATCReportErrMessage("E10018", {"index", "shape"}, {std::to_string(i), std::to_string(shape[i])}); - GELOGE(ge::PARAM_INVALID, "Only batch N can be -1 in dynamic batchsize scenario, current shape[%zu] is %ld", + GELOGE(ge::PARAM_INVALID, "Only batch N can be -1 when set --dynamic_batch_size, current shape[%zu] is %ld", i, shape[i]); return false; } @@ -56,13 +62,15 @@ bool CheckDynamicBatchSizeInputShapeValid(unordered_map> if (size == 0) { ErrorManager::GetInstance().ATCReportErrMessage("E10043"); - GELOGE(ge::PARAM_INVALID, "At least one batch n must be equal to -1 in dynamic batchsize scenario."); + GELOGE(ge::PARAM_INVALID, "At least one batch n must be equal to -1 when set --dynamic_batch_size."); return false; } for (char c : dynamic_batch_size) { if (!isdigit(c) && (c != ',') && (c != ' ')) { - GELOGE(ge::PARAM_INVALID, "dynamic_batch_size input : %s is invalid.", dynamic_batch_size.c_str()); + ErrorManager::GetInstance().ATCReportErrMessage("E10047", {"value"}, {dynamic_batch_size}); + GELOGE(ge::PARAM_INVALID, "Input parameter[--dynamic_batch_size]'s value[%s] is invalid.", + dynamic_batch_size.c_str()); return false; } } @@ -81,7 +89,8 @@ bool CheckDynamicImagesizeInputShapeValid(unordered_map> if (shape.size() != DIM_DEFAULT_SIZE) { if (std::count(shape.begin(), shape.end(), kDynamicInputDim) > 0) { ErrorManager::GetInstance().ATCReportErrMessage("E10019"); - GELOGE(ge::PARAM_INVALID, "Only height or width can be -1 in dynamic imagesize scenario."); + GELOGE(ge::PARAM_INVALID, + "--input_shape's shape is invalid, only height or width can be -1 when set --dynamic_image_size."); return false; } continue; @@ -106,13 +115,15 @@ bool CheckDynamicImagesizeInputShapeValid(unordered_map> continue; } else { ErrorManager::GetInstance().ATCReportErrMessage("E10019"); - GELOGE(ge::PARAM_INVALID, "Only height or width can be -1 in dynamic imagesize scenario."); + GELOGE(ge::PARAM_INVALID, + "--input_shape's shape is invalid, only height or width can be -1 when set --dynamic_image_size."); return false; } } if (size == 0) { ErrorManager::GetInstance().ATCReportErrMessage("E10019"); - GELOGE(ge::PARAM_INVALID, "Only height or width can be -1 in dynamic imagesize scenario."); + GELOGE(ge::PARAM_INVALID, + "--input_shape's shape is invalid, only height or width can be -1 when set --dynamic_image_size."); return false; } @@ -130,7 +141,7 @@ bool CheckDynamicImagesizeInputShapeValid(unordered_map> ErrorManager::GetInstance().ATCReportErrMessage("E10020", {"DynamicImageSizeNum"}, {std::to_string(kDynamicImageSizeNum)}); GELOGE(ge::PARAM_INVALID, - "Invalid dynamic_image_size : dynamic_image_size's number of 
dimensions of each " + "--dynamic_image_size's number of dimensions of each " "group must be %ld.", kDynamicImageSizeNum); return false; @@ -145,7 +156,7 @@ Status CheckDynamicBatchSizeOrImageSizeParamValid(std::string &dynamic_batch_siz bool &is_dynamic_input) { if (!dynamic_batch_size.empty() && !dynamic_image_size.empty()) { ErrorManager::GetInstance().ATCReportErrMessage("E10009", {"parameter0", "parameter1"}, - {dynamic_batch_size, dynamic_image_size}); + {"dynamic_batch_size", "dynamic_image_size"}); GELOGE(ge::PARAM_INVALID, "dynamic_batch_size and dynamic_image_size can not both exist"); return ge::PARAM_INVALID; } @@ -192,8 +203,10 @@ bool ParseInputShape(const string &input_shape, unordered_map shape_pair_vec = StringUtils::Split(shape, ':'); if (shape_pair_vec.size() != DEFAULT_SHAPE_PAIR_SIZE) { ErrorManager::GetInstance().ATCReportErrMessage("E10010", {"shape"}, {shape}); - GELOGW("Input parameter[--input_shape]’s shape is [%s], correct sample is input_name1:n1,c1,h1,w1", - shape.c_str()); + GELOGW( + "Input parameter[--input_shape]’s shape is [%s], " + "correct sample is input_name1:n1,c1,h1,w1", + shape.c_str()); return false; } if (shape_pair_vec[1].empty()) { @@ -226,8 +239,7 @@ bool ParseInputShape(const string &input_shape, unordered_map #include #include +#include + #include "framework/common/debug/ge_log.h" #include "framework/common/ge_inner_error_codes.h" #include "framework/omg/omg_inner_types.h" namespace ge { + +static std::set caffe_support_input_format = {"NCHW", "ND", "NCDHW"}; +static std::set tf_support_input_format = {"NCHW", "NHWC", "ND", "NCDHW", "NDHWC"}; +static std::map input_format_str_to_geformat = { + {"ND", domi::DOMI_TENSOR_ND}, {"NCHW", domi::DOMI_TENSOR_NCHW}, {"NHWC", domi::DOMI_TENSOR_NHWC}, + {"CHWN", domi::DOMI_TENSOR_CHWN}, {"NC1HWC0", domi::DOMI_TENSOR_NC1HWC0}, {"NHWC1C0", domi::DOMI_TENSOR_NHWC1C0}, + {"NCDHW", domi::DOMI_TENSOR_NCDHW}, {"NDHWC", domi::DOMI_TENSOR_NDHWC}}; +static const std::string kEnableCompressWeightTrue = "1"; +static const std::string kEnableCompressWeightFalse = "0"; + bool CheckDynamicBatchSizeInputShapeValid(unordered_map> shape_map, std::string &dynamic_batch_size); @@ -39,5 +51,10 @@ Status CheckDynamicBatchSizeOrImageSizeParamValid(std::string &dynamic_batch_siz bool ParseInputShape(const std::string &input_shape, std::unordered_map> &shape_map, std::vector>> &user_shape_map, bool is_dynamic_input = false); + +Status CheckOutputTypeParamValid(const std::string output_type); +Status CheckBufferOptimizeParamValid(const std::string buffer_optimize); +Status CheckCompressWeightParamValid(const std::string enable_compress_weight, const std::string compress_weight_conf); +int CheckLogParamValidAndSetLogLevel(const std::string log); } // namespace ge #endif // FRAMEWORK_DOMI_ATC_IR_COMMON_H_ \ No newline at end of file diff --git a/src/ge/ir_build/ge_ir_build.cc b/src/ge/ir_build/ge_ir_build.cc index cf507c42..74b43215 100644 --- a/src/ge/ir_build/ge_ir_build.cc +++ b/src/ge/ir_build/ge_ir_build.cc @@ -40,14 +40,12 @@ using std::string; using namespace std; namespace ge { - -static std::map input_format_str_to_geformat = { - {"ND", domi::DOMI_TENSOR_ND}, {"NCHW", domi::DOMI_TENSOR_NCHW}, {"NHWC", domi::DOMI_TENSOR_NHWC}, - {"CHWN", domi::DOMI_TENSOR_CHWN}, {"NC1HWC0", domi::DOMI_TENSOR_NC1HWC0}, {"NHWC1C0", domi::DOMI_TENSOR_NHWC1C0}, -}; +namespace { const std::string IR_OPTION_TARGET = "target"; const std::string IR_OPTION_MODE = "mode"; const std::string IR_OP_CONF_DELIMITER = ":"; +const std::string 
IR_OPTION_LOG_LEVEL_DEFAULT = "default"; +} // namespace graphStatus aclgrphBuildInitialize(std::map<std::string, std::string> global_options) { GELOGD("Enter aclgrphInitialize start!"); @@ -116,7 +114,12 @@ graphStatus Impl::CheckOptions(const std::map<std::string, std::string> &options GELOGE(GRAPH_PARAM_INVALID, "input options include unsupported option(%s). Please check!", ele.first.c_str()); return GRAPH_PARAM_INVALID; } - options_.insert(ele); + + if (ele.first == ge::ir_option::ENABLE_COMPRESS_WEIGHT) { + continue; // this option will be set after param check. + } else { + options_.insert(ele); + } } return GRAPH_SUCCESS; } @@ -127,6 +130,10 @@ graphStatus Impl::Init(const std::map<std::string, std::string> &options) { GELOGE(ret, "user input options are illegal! Please check!"); return ret; } + // set log level + std::string log = options_.find(ge::ir_option::LOG_LEVEL) == options_.end() ? IR_OPTION_LOG_LEVEL_DEFAULT + : options_[ge::ir_option::LOG_LEVEL]; + GE_CHK_BOOL_RET_STATUS_NOLOG(ge::CheckLogParamValidAndSetLogLevel(log) == 0, GRAPH_PARAM_INVALID); string input_shape = options_.find("input_shape") == options_.end() ? "" : options_["input_shape"]; string input_format = options_.find("input_format") == options_.end() ? "" : options_["input_format"]; @@ -148,6 +155,28 @@ graphStatus Impl::Init(const std::map<std::string, std::string> &options) { dynamic_image_size.c_str()); GetContext().dynamic_batch_size = dynamic_batch_size; GetContext().dynamic_image_size = dynamic_image_size; + // check output_type + std::string output_type = + options_.find(ge::ir_option::OUTPUT_TYPE) == options_.end() ? "" : options_[ge::ir_option::OUTPUT_TYPE]; + GE_CHK_BOOL_EXEC(ge::CheckOutputTypeParamValid(output_type) == ge::SUCCESS, return ge::GRAPH_PARAM_INVALID, + "check output type failed!"); + // check buffer_optimize + std::string buffer_optimize = + options_.find(ge::ir_option::BUFFER_OPTIMIZE) == options_.end() ? "" : options_[ge::ir_option::BUFFER_OPTIMIZE]; + GE_CHK_BOOL_EXEC(ge::CheckBufferOptimizeParamValid(buffer_optimize) == ge::SUCCESS, return ge::GRAPH_PARAM_INVALID, + "check buffer optimize failed!"); + // check compress_weight + std::string enable_compress_weight = options_.find(ge::ir_option::ENABLE_COMPRESS_WEIGHT) == options_.end() + ? "" + : options_[ge::ir_option::ENABLE_COMPRESS_WEIGHT]; + std::string compress_weight_conf = options_.find(ge::ir_option::COMPRESS_WEIGHT_CONF) == options_.end() + ? "" + : options_[ge::ir_option::COMPRESS_WEIGHT_CONF]; + GE_CHK_BOOL_EXEC(ge::CheckCompressWeightParamValid(enable_compress_weight, compress_weight_conf) == ge::SUCCESS, + return ge::FAILED, "check compress weight failed!"); + options_.insert(std::pair<std::string, std::string>( + std::string(ge::ir_option::ENABLE_COMPRESS_WEIGHT), + (enable_compress_weight == "true") ?
ge::kEnableCompressWeightTrue : ge::kEnableCompressWeightFalse)); // for IR builder. Only support om mode, so here fixed; options_.insert(std::pair<std::string, std::string>(string(IR_OPTION_MODE), to_string(0))); @@ -238,8 +267,8 @@ graphStatus Impl::InitDomiOmgContext(const string &input_shape, const string &in // the default value is ND GetContext().format = domi::DOMI_TENSOR_ND; if (!input_format.empty()) { - auto iter = input_format_str_to_geformat.find(input_format); - if (iter != input_format_str_to_geformat.end()) { + auto iter = ge::input_format_str_to_geformat.find(input_format); + if (iter != ge::input_format_str_to_geformat.end()) { GetContext().format = iter->second; } else { GELOGE(GRAPH_PARAM_INVALID, "Input format %s not support , expect ND/NCHW/NHWC/CHWN/NC1HWC0/NHWC1C0.", @@ -275,4 +304,5 @@ graphStatus aclgrphSaveModel(const string &output_file, const ModelBufferData &m return FileSaver::SaveToFile((output_file + ".om"), reinterpret_cast<void *>(model.data.get()), static_cast<uint32_t>(model.length)); } + } // namespace ge diff --git a/src/ge/model/ge_root_model.cc b/src/ge/model/ge_root_model.cc new file mode 100644 index 00000000..aee119fa --- /dev/null +++ b/src/ge/model/ge_root_model.cc @@ -0,0 +1,32 @@ +/** + * Copyright 2019-2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "ge_root_model.h" +#include "graph/debug/ge_attr_define.h" +namespace ge { +void GeRootModel::SetSubgraphInstanceNameToModel(string instance_name, GeModelPtr ge_model) { + subgraph_instance_name_to_model_.insert(std::pair<string, GeModelPtr>(instance_name, ge_model)); +} + +Status GeRootModel::CheckIsUnknownShape(bool &is_dynamic_shape) { + if (root_graph_ == nullptr) { + return FAILED; + } + is_dynamic_shape = false; + (void)AttrUtils::GetBool(root_graph_, ATTR_NAME_DYNAMIC_SHAPE_PARTITIONED, is_dynamic_shape); + return SUCCESS; +} +} // namespace ge \ No newline at end of file diff --git a/src/ge/model/ge_root_model.h b/src/ge/model/ge_root_model.h new file mode 100644 index 00000000..2b73c868 --- /dev/null +++ b/src/ge/model/ge_root_model.h @@ -0,0 +1,47 @@ +/** + * Copyright 2019-2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License.
+ */ + +#include +#include "graph/compute_graph.h" +#include "model/ge_model.h" + +#ifndef GE_MODEL_GE_ROOT_MODEL_H_ +#define GE_MODEL_GE_ROOT_MODEL_H_ + +namespace ge { +class GeRootModel { + public: + explicit GeRootModel(ComputeGraphPtr &root_graph) : root_graph_(root_graph), model_id_(INVALID_MODEL_ID){}; + ~GeRootModel() = default; + + void SetSubgraphInstanceNameToModel(string instance_name, GeModelPtr ge_model); + const std::map<std::string, GeModelPtr> &GetSubgraphInstanceNameToModel() const { + return subgraph_instance_name_to_model_; + }; + + const ComputeGraphPtr &GetRootGraph() const { return root_graph_; }; + void SetModelId(uint32_t model_id) { model_id_ = model_id; } + uint32_t GetModelId() const { return model_id_; } + Status CheckIsUnknownShape(bool &is_dynamic_shape); + + private: + ComputeGraphPtr root_graph_; + std::map<std::string, GeModelPtr> subgraph_instance_name_to_model_; + uint32_t model_id_; +}; +} // namespace ge +using GeRootModelPtr = std::shared_ptr<ge::GeRootModel>; +#endif // GE_MODEL_GE_ROOT_MODEL_H_ diff --git a/src/ge/opskernel_manager/ops_kernel_manager.cc b/src/ge/opskernel_manager/ops_kernel_manager.cc index b5276483..a8a1be88 100644 --- a/src/ge/opskernel_manager/ops_kernel_manager.cc +++ b/src/ge/opskernel_manager/ops_kernel_manager.cc @@ -67,6 +67,12 @@ Status OpsKernelManager::Initialize(const map<string, string> &options_const) { options.emplace("ge.exec.isUseHcom", to_string(0)); } + iter = options.find(OPTION_EXEC_IS_USEHVD); + if (iter == options.end()) { + GELOGI("OPTION_EXEC_IS_USEHVD is not set, default is single P"); + options.emplace("ge.exec.isUseHvd", to_string(0)); + } + GetExternalEnginePath(extern_engine_path); GELOGI("OPTION_EXEC_EXTERN_PLUGIN_PATH=%s.", extern_engine_path.c_str()); @@ -74,13 +80,19 @@ Status OpsKernelManager::Initialize(const map<string, string> &options_const) { if (ret == SUCCESS) { initialize_ = options; Status rst0 = plugin_manager_.InvokeAll<map<string, string> &, Status>(kInitialize, initialize_); + if (rst0 == FAILED) { + GELOGE(GE_OPS_GET_NO_VALID_SO); + return GE_OPS_GET_NO_VALID_SO; + } Status rst1 = plugin_manager_.InvokeAll<map<string, OpsKernelInfoStorePtr> &>(kGetOpsKernelInfoStores, ops_kernel_store_); + if (rst1 != SUCCESS) { + GELOGW("Initialize OpsKernelInfo failed."); + } Status rst2 = plugin_manager_.InvokeAll<map<string, GraphOptimizerPtr> &>(kGetGraphOptimizerObjs, graph_optimizers_); - if ((rst0 != SUCCESS) && (rst1 != SUCCESS) && (rst2 != SUCCESS)) { - GELOGE(GE_OPS_GET_NO_VALID_SO); - return GE_OPS_GET_NO_VALID_SO; + if (rst2 != SUCCESS) { + GELOGW("Initialize GraphOptimizerObjs failed."); } ret = CheckPluginPtr(); @@ -260,6 +272,10 @@ void OpsKernelManager::InitOpsKernelInfo() { } Status OpsKernelManager::InitGraphOptimzers(const map<string, string> &options) { + GELOGI("Init graph optimizers options count %zu", options.size()); + for (const auto &option : options) { + GELOGI("Init graph optimizers option %s: %s", option.first.c_str(), option.second.c_str()); + } GELOGI("The number of GraphOptimzerObjs are %zu.", graph_optimizers_.size()); for (const auto &it : graph_optimizers_) { GELOGI("GraphOptimzer name: %s.", (it.first).c_str()); @@ -280,7 +296,6 @@ Status OpsKernelManager::InitGraphOptimzers(const map<string, string> &options) return GE_OPS_GRAPH_OPTIMIZER_INIT_FAILED; } } - return SUCCESS; } @@ -350,6 +365,10 @@ const map<string, OpsKernelInfoStorePtr> &OpsKernelManager::GetAllOpsKernelInfoS const map<string, GraphOptimizerPtr> &OpsKernelManager::GetAllGraphOptimizerObjs() const { return graph_optimizers_; } +const vector<pair<string, GraphOptimizerPtr>> &OpsKernelManager::GetAllGraphOptimizerObjsByPriority() const { + return graph_optimizers_by_priority_; +} + void OpsKernelManager::GetGraphOptimizerByEngine(const std::string &engine_name, vector<GraphOptimizerPtr> &graph_optimizer) { for (const
auto &it : graph_optimizers_) { @@ -393,13 +412,11 @@ Status OpsKernelManager::InitGraphOptimizerPriority() { return SUCCESS; } // sort optimizer map by priority - map<string, GraphOptimizerPtr> original_optimizers(graph_optimizers_); - graph_optimizers_.clear(); std::stringstream priority_seq; for (const auto optimizer_name : priorities) { - auto name_to_optimizer_pair = original_optimizers.find(optimizer_name); - if (name_to_optimizer_pair != original_optimizers.end()) { - graph_optimizers_.emplace(*name_to_optimizer_pair); + auto name_to_optimizer_pair = graph_optimizers_.find(optimizer_name); + if (name_to_optimizer_pair != graph_optimizers_.end()) { + graph_optimizers_by_priority_.emplace_back(*name_to_optimizer_pair); priority_seq << optimizer_name.c_str() << ' '; } else { GELOGW("Unknown optimizer %s show up in priority config file. Please check.", optimizer_name.c_str()); diff --git a/src/ge/opskernel_manager/ops_kernel_manager.h b/src/ge/opskernel_manager/ops_kernel_manager.h index df7e06b2..8d98ad3f 100644 --- a/src/ge/opskernel_manager/ops_kernel_manager.h +++ b/src/ge/opskernel_manager/ops_kernel_manager.h @@ -58,6 +58,9 @@ class OpsKernelManager { // get all graph_optimizer const map<string, GraphOptimizerPtr> &GetAllGraphOptimizerObjs() const; + // get all graph_optimizer by priority + const vector<pair<string, GraphOptimizerPtr>> &GetAllGraphOptimizerObjsByPriority() const; + // get subgraphOptimizer by engine name void GetGraphOptimizerByEngine(const std::string &engine_name, vector<GraphOptimizerPtr> &graph_optimizer); @@ -106,6 +109,8 @@ class OpsKernelManager { map<string, OpsKernelInfoStorePtr> ops_kernel_store_{}; // graph_optimizer map<string, GraphOptimizerPtr> graph_optimizers_{}; + // ordered graph_optimizer + vector<pair<string, GraphOptimizerPtr>> graph_optimizers_by_priority_{}; // opsKernelInfo map<string, vector<OpInfo>> ops_kernel_info_{}; diff --git a/src/ge/opskernel_manager/optimizer_priority.pbtxt b/src/ge/opskernel_manager/optimizer_priority.pbtxt index 06bcf520..76768817 100644 --- a/src/ge/opskernel_manager/optimizer_priority.pbtxt +++ b/src/ge/opskernel_manager/optimizer_priority.pbtxt @@ -1 +1 @@ -optimizer:["AIcoreEngine","VectorEngine","aicpu_optimizer","hccl_graph_optimizer"] \ No newline at end of file +optimizer:["aicpu_original_optimizer","AIcoreEngine","VectorEngine","aicpu_optimizer","hccl_graph_optimizer", "hvd_graph_optimizer"] \ No newline at end of file diff --git a/src/ge/single_op/single_op.cc b/src/ge/single_op/single_op.cc index 04b09389..9578471a 100644 --- a/src/ge/single_op/single_op.cc +++ b/src/ge/single_op/single_op.cc @@ -134,7 +134,8 @@ Status SingleOp::UpdateArgs(const std::vector &inputs, const std::ve GELOGD("Update aicpu task args"); AiCpuTask *task_aicpu = dynamic_cast<AiCpuTask *>(task); GE_CHECK_NOTNULL(task_aicpu); - auto rt_ret = rtMemcpyAsync(task_aicpu->GetIOAddr(), sizeof(uint64_t) * args_.size(), &args_[0], + auto *dstIOAddr = const_cast<void *>(reinterpret_cast<const void *>(task_aicpu->GetIOAddr())); + auto rt_ret = rtMemcpyAsync(dstIOAddr, sizeof(uint64_t) * args_.size(), &args_[0], sizeof(uint64_t) * args_.size(), RT_MEMCPY_HOST_TO_DEVICE_EX, stream_); if (rt_ret != RT_ERROR_NONE) { GELOGE(RT_FAILED, "rtMemcpyAsync addresses failed, ret = %d", rt_ret); diff --git a/src/ge/single_op/task/aicpu_task_builder.cc b/src/ge/single_op/task/aicpu_task_builder.cc index 3f571d30..e4b7aa80 100644 --- a/src/ge/single_op/task/aicpu_task_builder.cc +++ b/src/ge/single_op/task/aicpu_task_builder.cc @@ -108,7 +108,7 @@ Status AiCpuTaskBuilder::BuildTask(ge::AiCpuTask &task, const SingleOpModelParam STR_FWK_OP_KERNEL fwk_op_kernel; ret = SetFmkOpKernel(io_addr, ws_addr_vec[0], fwk_op_kernel); if (ret != SUCCESS) { - rtFree(io_addr); + (void)rtFree(io_addr); return ret; } //
Create session @@ -119,7 +119,7 @@ Status AiCpuTaskBuilder::BuildTask(ge::AiCpuTask &task, const SingleOpModelParam return FAILED;) ret = SetKernelArgs(&task.args_, fwk_op_kernel); if (ret != SUCCESS) { - rtFree(io_addr); + (void)rtFree(io_addr); return ret; } diff --git a/src/ge/single_op/task/op_task.cc b/src/ge/single_op/task/op_task.cc index d515336b..e93fad71 100644 --- a/src/ge/single_op/task/op_task.cc +++ b/src/ge/single_op/task/op_task.cc @@ -64,7 +64,7 @@ Status TbeOpTask::LaunchKernel(rtStream_t stream) { AiCpuTask::~AiCpuTask() { if (args_ != nullptr) { - rtFree(args_); + (void)rtFree(args_); } if (io_addr_ != nullptr) { @@ -72,11 +72,11 @@ AiCpuTask::~AiCpuTask() { } } -void *AiCpuTask::GetIOAddr() { return io_addr_; } +const void *AiCpuTask::GetIOAddr() const { return io_addr_; } Status AiCpuTask::LaunchKernel(rtStream_t stream) { auto ret = rtMemcpyAsync(workspace_addr_, task_info_.size(), task_info_.data(), task_info_.size(), - RT_MEMCPY_HOST_TO_DEVICE, stream); + RT_MEMCPY_HOST_TO_DEVICE_EX, stream); if (ret != RT_ERROR_NONE) { GELOGE(RT_FAILED, "rtMemcpyAsync workspace data failed. ret = %d, task = %s", ret, this->op_type_.c_str()); return RT_FAILED; diff --git a/src/ge/single_op/task/op_task.h b/src/ge/single_op/task/op_task.h index 95e42772..168a71b3 100644 --- a/src/ge/single_op/task/op_task.h +++ b/src/ge/single_op/task/op_task.h @@ -68,7 +68,7 @@ class AiCpuTask : public OpTask { Status LaunchKernel(rtStream_t stream) override; OpTaskType GetOpTaskType() override { return OP_TASK_AICPU; } - void *GetIOAddr(); + const void *GetIOAddr() const; private: friend class AiCpuTaskBuilder; diff --git a/src/proto/fwk_adapter.proto b/src/proto/fwk_adapter.proto index 96368d55..99333d2e 100644 --- a/src/proto/fwk_adapter.proto +++ b/src/proto/fwk_adapter.proto @@ -31,12 +31,12 @@ message TensorDataInfo { // data point addr int64 data_addr = 3; -}; +} message KernelRunParam { // input repeated TensorDataInfo input = 1; // output repeated TensorDataInfo output = 2; -}; +} diff --git a/src/proto/op_mapping_info.proto b/src/proto/op_mapping_info.proto index 35383c5b..a02af28b 100644 --- a/src/proto/op_mapping_info.proto +++ b/src/proto/op_mapping_info.proto @@ -31,7 +31,7 @@ message Output { int32 original_output_data_type = 7; int32 original_output_format = 8; uint64 size = 9; -}; +} message Input { int32 data_type =1; @@ -39,12 +39,12 @@ message Input { int32 format = 2; Shape shape = 3; uint64 address = 4; uint64 size = 5; -}; +} message Op { string op_name = 1; string op_type = 2; -}; +} message Task { uint32 task_id = 1; @@ -53,7 +53,7 @@ message Task { repeated Output output = 4; bool end_graph = 5; repeated Input input = 6; -}; +} message OpMappingInfo { string dump_path = 1; @@ -75,4 +75,4 @@ message OpMappingInfo { uint32 flag = 7; // 0x01 load, 0x00 unload repeated Task task = 8; string dump_step = 9; -}; \ No newline at end of file +} \ No newline at end of file diff --git a/third_party/fwkacllib/inc/hccl/base.h b/third_party/fwkacllib/inc/hccl/base.h index 89c21f1c..74163baf 100644 --- a/third_party/fwkacllib/inc/hccl/base.h +++ b/third_party/fwkacllib/inc/hccl/base.h @@ -77,6 +77,19 @@ typedef enum tagHcclRedOp { HCCL_REP_OP_RESERVED /**< reserved */ } hcclRedOp_t; +/** + * @brief Horovod Reduction operation + */ +typedef enum tagHorovodRedOp { + HOROVOD_REP_OP_AVERAGE = 0, /**< average */ + HOROVOD_REP_OP_SUM = 1, /**< sum */ + HOROVOD_REP_OP_ADASUM = 2, /**< adasum */ + HOROVOD_REP_OP_MIN = 3, /**< min */ + HOROVOD_REP_OP_MAX = 4, /**< max */ + HOROVOD_REP_OP_PROD = 5, /**< prod */ +
HOROVOD_REP_OP_RESERVED /**< reserved */ +} horovodRedOp_t; + /** * @brief HCCL data type */ @@ -85,6 +98,7 @@ typedef enum tagHcclDataType { HCCL_DATA_TYPE_INT = 1, /**< int32 */ HCCL_DATA_TYPE_HALF = 2, /**< fp16 */ HCCL_DATA_TYPE_FLOAT = 3, /**< fp32 */ + HCCL_DATA_TYPE_INT16 = 4, /**< int16 */ HCCL_DATA_TYPE_RESERVED /**< reserved */ } hcclDataType_t; diff --git a/third_party/fwkacllib/inc/ops/all_ops.h b/third_party/fwkacllib/inc/ops/all_ops.h index d6bd1353..031e955c 100644 --- a/third_party/fwkacllib/inc/ops/all_ops.h +++ b/third_party/fwkacllib/inc/ops/all_ops.h @@ -24,17 +24,10 @@ #include "bitwise_ops.h" #include "boosted_trees_ops.h" #include "candidate_sampling_ops.h" -#include "clip_boxes.h" #include "control_flow_ops.h" #include "ctc_ops.h" #include "data_flow_ops.h" -#include "decode_bbox.h" -#include "decode_boundaries_target.h" -#include "decode_cornerpoints_target_bg.h" -#include "decode_cornerpoints_target_wrt_center_v1.h" -#include "decode_wheels_target.h" #include "elewise_calculation_ops.h" -#include "fastrcnn_predictions.h" #include "functional_ops.h" #include "get_data_ops.h" #include "hcom_ops.h" @@ -64,13 +57,11 @@ #include "resource_variable_ops.h" #include "rnn.h" #include "rpn_ops.h" -#include "rpn_proposals.h" #include "save_ops.h" #include "selection_ops.h" #include "set_ops.h" #include "sparse_ops.h" #include "split_combination_ops.h" -#include "ssddetectionoutput_ops.h" #include "stateful_random_ops.h" #include "stateless_random_ops.h" #include "state_ops.h" @@ -78,4 +69,5 @@ #include "swap_co_ops.h" #include "transformation_ops.h" #include "condtake_ops.h" +#include "warp_perspective_ops.h" #endif // BUILT_IN_OP_PROTO_INC_ALL_OPS_H_ diff --git a/third_party/fwkacllib/inc/ops/clip_boxes.h b/third_party/fwkacllib/inc/ops/clip_boxes.h deleted file mode 100644 index 967dc1b9..00000000 --- a/third_party/fwkacllib/inc/ops/clip_boxes.h +++ /dev/null @@ -1,37 +0,0 @@ -/** - * Copyright 2019-2020 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. 
- */ - - #ifndef GE_OP_CLIP_BOXES_H - #define GE_OP_CLIP_BOXES_H - - #include "graph/operator_reg.h" - - namespace ge { - - REG_OP(ClipBoxes) - .INPUT(boxes_input, TensorType({DT_FLOAT16})) - .INPUT(img_size, TensorType({DT_INT32})) - .OUTPUT(boxes_output, TensorType({DT_FLOAT16})) - .OP_END_FACTORY_REG(ClipBoxes) - - REG_OP(ClipBoxesD) - .INPUT(boxes_input, TensorType({DT_FLOAT16})) - .REQUIRED_ATTR(img_size, ListInt) - .OUTPUT(boxes_output, TensorType({DT_FLOAT16})) - .OP_END_FACTORY_REG(ClipBoxesD) - } // namespace ge - - #endif // GE_OP_CLIP_BOXES_H diff --git a/third_party/fwkacllib/inc/ops/decode_cornerpoints_target_bg.h b/third_party/fwkacllib/inc/ops/decode_cornerpoints_target_bg.h deleted file mode 100644 index ce10175f..00000000 --- a/third_party/fwkacllib/inc/ops/decode_cornerpoints_target_bg.h +++ /dev/null @@ -1,31 +0,0 @@ -/** - * Copyright 2019-2020 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - - #ifndef GE_OP_DECODE_CORNERPOINTS_TARGET_BG_H - #define GE_OP_DECODE_CORNERPOINTS_TARGET_BG_H - - #include "graph/operator_reg.h" - - namespace ge { - - REG_OP(DecodeCornerpointsTargetBG) - .INPUT(keypoints_prediction, TensorType({DT_FLOAT16})) /* "First operand." */ - .INPUT(anchors, TensorType({DT_FLOAT16})) /* "Second operand." */ - .OUTPUT(keypoints_decoded, TensorType({DT_FLOAT16})) /* "Result, has same element type as two inputs" */ - .OP_END_FACTORY_REG(DecodeCornerpointsTargetBG); - } // namespace ge - - #endif // GE_OP_DECODE_CORNERPOINTS_TARGET_BG_H diff --git a/third_party/fwkacllib/inc/ops/decode_cornerpoints_target_wrt_center_v1.h b/third_party/fwkacllib/inc/ops/decode_cornerpoints_target_wrt_center_v1.h deleted file mode 100644 index 0e96bc16..00000000 --- a/third_party/fwkacllib/inc/ops/decode_cornerpoints_target_wrt_center_v1.h +++ /dev/null @@ -1,32 +0,0 @@ -/** - * Copyright 2019-2020 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - - #ifndef GE_OP_DECODE_CORNERPOINTS_TARGET_WRT_CENTER_V1_H - #define GE_OP_DECODE_CORNERPOINTS_TARGET_WRT_CENTER_V1_H - - #include "graph/operator_reg.h" - - namespace ge { - - REG_OP(DecodeCornerpointsTargetWrtCenterV1) - .INPUT(keypoints_prediction, TensorType({DT_FLOAT16})) /* "First operand." */ - .INPUT(anchors, TensorType({DT_FLOAT16})) /* "Second operand." 
*/ - .OUTPUT(keypoints_decoded, TensorType({DT_FLOAT16})) /* "Result, has same element type as two inputs" */ - .OP_END_FACTORY_REG(DecodeCornerpointsTargetWrtCenterV1) - } // namespace ge - - #endif // GE_OP_DECODE_CORNERPOINTS_TARGET_WRT_CENTER_V1_H - diff --git a/third_party/fwkacllib/inc/ops/fastrcnn_predictions.h b/third_party/fwkacllib/inc/ops/fastrcnn_predictions.h deleted file mode 100644 index e7794e45..00000000 --- a/third_party/fwkacllib/inc/ops/fastrcnn_predictions.h +++ /dev/null @@ -1,36 +0,0 @@ -/** - * Copyright 2019-2020 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - - #ifndef GE_OP_FASTRCNN_PREDICTIONS_H - #define GE_OP_FASTRCNN_PREDICTIONS_H - - #include "graph/operator_reg.h" - - namespace ge { - - REG_OP(FastrcnnPredictions) - .INPUT(rois, TensorType({DT_FLOAT16})) - .INPUT(score, TensorType({DT_FLOAT16})) - .REQUIRED_ATTR(nms_threshold, Float) - .REQUIRED_ATTR(score_threshold, Float) - .REQUIRED_ATTR(k, Int) - .OUTPUT(sorted_rois, TensorType({DT_FLOAT16})) - .OUTPUT(sorted_scores, TensorType({DT_FLOAT16})) - .OUTPUT(sorted_classes, TensorType({DT_FLOAT16})) - .OP_END_FACTORY_REG(FastrcnnPredictions) - } // namespace ge - - #endif // GE_OP_FASTRCNN_PREDICTIONS_H diff --git a/third_party/fwkacllib/inc/ops/functional_ops.h b/third_party/fwkacllib/inc/ops/functional_ops.h index ea15dba8..a297bc61 100644 --- a/third_party/fwkacllib/inc/ops/functional_ops.h +++ b/third_party/fwkacllib/inc/ops/functional_ops.h @@ -19,7 +19,6 @@ #include "graph/operator_reg.h" #include "graph/operator.h" -#include "graph/ge_attr_value.h" namespace ge { REG_OP(SymbolicGradient) diff --git a/third_party/fwkacllib/inc/ops/hcom_ops.h b/third_party/fwkacllib/inc/ops/hcom_ops.h index 5a69ed80..bdacebdf 100644 --- a/third_party/fwkacllib/inc/ops/hcom_ops.h +++ b/third_party/fwkacllib/inc/ops/hcom_ops.h @@ -23,7 +23,7 @@ namespace ge { /** * @brief Outputs a tensor gathering all input tensors. * @par Inputs: - * x: A tensor. Must be one of the following types: int8, int32, float16, + * x: A tensor. Must be one of the following types: int8, int16, int32, float16, * float32. * @par Attributes: * @li rank_size: A required integer identifying the number of ranks @@ -37,8 +37,8 @@ namespace ge { * as the name of a world group. */ REG_OP(HcomAllGather) - .INPUT(x, TensorType({DT_FLOAT, DT_INT32, DT_INT8, DT_FLOAT16})) - .OUTPUT(y, TensorType({DT_FLOAT, DT_INT32, DT_INT8, DT_FLOAT16})) + .INPUT(x, TensorType({DT_FLOAT, DT_INT32, DT_INT8, DT_INT16, DT_FLOAT16})) + .OUTPUT(y, TensorType({DT_FLOAT, DT_INT32, DT_INT8, DT_INT16, DT_FLOAT16})) .REQUIRED_ATTR(rank_size, Int) .REQUIRED_ATTR(group, String) .ATTR(alpha, Float, 1.0) @@ -49,7 +49,7 @@ REG_OP(HcomAllGather) * @brief Outputs a tensor containing the reduction across all input tensors * passed to op. * @par Inputs: - * x: A tensor. Must be one of the following types: int8, int32, float16, + * x: A tensor. Must be one of the following types: int8, int16, int32, float16, * float32. 
* @par Attributes: * @li reduction: A required string identifying the reduction operation to @@ -67,8 +67,8 @@ REG_OP(HcomAllGather) * as the name of a world group. */ REG_OP(HcomAllReduce) - .INPUT(x, TensorType({DT_FLOAT, DT_INT32, DT_INT8, DT_FLOAT16})) - .OUTPUT(y, TensorType({DT_FLOAT, DT_INT32, DT_INT8, DT_FLOAT16})) + .INPUT(x, TensorType({DT_FLOAT, DT_INT32, DT_INT8, DT_INT16, DT_FLOAT16})) + .OUTPUT(y, TensorType({DT_FLOAT, DT_INT32, DT_INT8, DT_INT16, DT_FLOAT16})) .REQUIRED_ATTR(reduction, String) .REQUIRED_ATTR(group, String) .ATTR(fusion, Int, 1) @@ -81,7 +81,7 @@ REG_OP(HcomAllReduce) * @brief Broadcasts the input tensor in root rank to all ranks. * @par Inputs: * x: A list of dynamic input tensor. Must be one of the following types: - * int8, int32, float16, float32. + * int8, int16, int32, float16, float32. * @par Attributes: * @li root_rank: A required integer identifying the root rank in the op * input of this rank will be broadcast to other ranks. @@ -94,8 +94,8 @@ REG_OP(HcomAllReduce) * as the name of a world group. */ REG_OP(HcomBroadcast) - .DYNAMIC_INPUT(x, TensorType({DT_FLOAT, DT_INT32, DT_INT8, DT_FLOAT16})) - .DYNAMIC_OUTPUT(y, TensorType({DT_FLOAT, DT_INT32, DT_INT8, DT_FLOAT16})) + .DYNAMIC_INPUT(x, TensorType({DT_FLOAT, DT_INT32, DT_INT8, DT_INT16, DT_FLOAT16})) + .DYNAMIC_OUTPUT(y, TensorType({DT_FLOAT, DT_INT32, DT_INT8, DT_INT16, DT_FLOAT16})) .REQUIRED_ATTR(root_rank, Int) .REQUIRED_ATTR(group, String) .ATTR(alpha, Float, 1.0) @@ -107,7 +107,7 @@ REG_OP(HcomBroadcast) * blocks among ranks, each rank getting a chunk of data based on its rank * index. * @par Inputs: - * x: A tensor. Must be one of the following types: int8, int32, float16, + * x: A tensor. Must be one of the following types: int8, int16, int32, float16, * float32. * @par Attributes: * @li reduction: A required string identifying the reduction operation to @@ -123,8 +123,8 @@ REG_OP(HcomBroadcast) * as the name of a world group. */ REG_OP(HcomReduceScatter) - .INPUT(x, TensorType({DT_FLOAT, DT_INT32, DT_INT8, DT_FLOAT16})) - .OUTPUT(y, TensorType({DT_FLOAT, DT_INT32, DT_INT8, DT_FLOAT16})) + .INPUT(x, TensorType({DT_FLOAT, DT_INT32, DT_INT8, DT_INT16, DT_FLOAT16})) + .OUTPUT(y, TensorType({DT_FLOAT, DT_INT32, DT_INT8, DT_INT16, DT_FLOAT16})) .REQUIRED_ATTR(reduction, String) .REQUIRED_ATTR(group, String) .REQUIRED_ATTR(rank_size, Int) @@ -135,7 +135,7 @@ REG_OP(HcomReduceScatter) /** * @brief Sends the input tensor to destination rank. * @par Inputs: - * x: A tensor. Must be one of the following types: int8, int32, float16, + * x: A tensor. Must be one of the following types: int8, int16, int32, float16, * float32. * @par Attributes: * @li sr_tag: A required integer identifying the send/recv message tag. The @@ -152,7 +152,7 @@ REG_OP(HcomReduceScatter) * @see HcomReceive */ REG_OP(HcomSend) - .INPUT(x, TensorType({DT_FLOAT, DT_INT32, DT_INT8, DT_FLOAT16})) + .INPUT(x, TensorType({DT_FLOAT, DT_INT32, DT_INT8, DT_INT16, DT_FLOAT16})) .REQUIRED_ATTR(group, String) .REQUIRED_ATTR(sr_tag, Int) .REQUIRED_ATTR(dest_rank, Int) @@ -168,12 +168,12 @@ REG_OP(HcomSend) * @li sr_tag: A required integer identifying the send/recv message tag. The * message will be send by the HcomSend op with the same "sr_tag". * @li src_rank: A required integer identifying the source rank. - * @li group: A required string identifying the group name of ranks + * @li group: A required string identifying the group name of ranks * participating in the op. 
* @li shape: A required list identifying the shape of the tensor to be * received. * @li dtype: A required integer identifying the type of the tensor to be - * received. The supported types are: int8, int32, float16, float32. + * received. The supported types are: int8, int16, int32, float16, float32. * @par Outputs: * y: A tensor with type identified in "dtype". * @attention Constraints:\n @@ -185,7 +185,7 @@ REG_OP(HcomSend) * @see HcomSend */ REG_OP(HcomReceive) - .OUTPUT(y, TensorType({DT_FLOAT, DT_INT32, DT_INT8, DT_FLOAT16})) + .OUTPUT(y, TensorType({DT_FLOAT, DT_INT32, DT_INT8, DT_INT16, DT_FLOAT16})) .REQUIRED_ATTR(group, String) .REQUIRED_ATTR(sr_tag, Int) .REQUIRED_ATTR(src_rank, Int) diff --git a/third_party/fwkacllib/inc/ops/math_ops.h b/third_party/fwkacllib/inc/ops/math_ops.h index b75991e2..0bee7097 100644 --- a/third_party/fwkacllib/inc/ops/math_ops.h +++ b/third_party/fwkacllib/inc/ops/math_ops.h @@ -332,6 +332,22 @@ REG_OP(GetNext) .ATTR(output_num, Int, 1) .ATTR(channel_name, String, "") .OP_END_FACTORY_REG(GetNext) + +/** +*@brief End of sequence. + +*@par Inputs: +*x: A Tensor of type uint8. + +*@par Outputs: +*y: A Tensor. Has the same type as "x". +*/ + +REG_OP(EndOfSequence) + .INPUT(x, TensorType({DT_UINT8})) + .OUTPUT(y, TensorType({DT_UINT8})) + .OP_END_FACTORY_REG(EndOfSequence) + /** *@brief: Computes the Gauss error function of `x` element-wise. diff --git a/third_party/fwkacllib/inc/ops/matrix_calculation_ops.h b/third_party/fwkacllib/inc/ops/matrix_calculation_ops.h index 4f0f4557..df4c8359 100644 --- a/third_party/fwkacllib/inc/ops/matrix_calculation_ops.h +++ b/third_party/fwkacllib/inc/ops/matrix_calculation_ops.h @@ -50,6 +50,38 @@ REG_OP(MatMul) .ATTR(transpose_x1, Bool, false) .ATTR(transpose_x2, Bool, false) .OP_END_FACTORY_REG(MatMul) +/** +*@brief Multiplies matrix "a" by matrix "b", producing "a * b". + +*@par Inputs: +*Three inputs, including: +* @li x1: A matrix Tensor. 2D. Must be one of the following types: float16, +* float32, int32. Has format [ND, NHWC, FRACTAL_NZ]. +* @li x2: A matrix Tensor. 2D. Must be one of the following types: float16, +* float32, int32. Has format [ND, NHWC, FRACTAL_NZ]. +* @li bias: A 1D Tensor. Must be one of the following types: float16, +* float32, int32. Has format [ND, NHWC]. + +*@par Attributes: +*@li transpose_x1: A bool. If True, changes the shape of "x1" from [M, K] to [K, M]. +*@li transpose_x2: A bool. If True, changes the shape of "x2" from [M, K] to [K, M]. + +*@par Outputs: +*y: The result matrix Tensor. 2D. Must be one of the following types: float16, +* float32, int32. Has format [ND, NHWC, FRACTAL_NZ]. +*/ +REG_OP(MatMulV2) + .INPUT(x1, TensorType({DT_FLOAT, DT_FLOAT16, DT_INT32, DT_INT8})) + .INPUT(x2, TensorType({DT_FLOAT, DT_FLOAT16, DT_INT32, DT_INT8})) + .OPTIONAL_INPUT(bias, TensorType({DT_FLOAT, DT_FLOAT16, DT_INT32})) + .OUTPUT(y, TensorType({DT_FLOAT, DT_FLOAT16, DT_INT32})) + .OPTIONAL_INPUT(offset_w, TensorType({DT_INT8})) + .ATTR(transpose_x1, Bool, false) + .ATTR(transpose_x2, Bool, false) + .ATTR(offset_x, Int, 0) + .OP_END_FACTORY_REG(MatMulV2) + + /** *@brief Performs Matrix-to-matrix Multiply, producing c=alpha[0]*a*b+beta[0]*c. @@ -309,23 +341,23 @@ REG_OP(ScatterNdUpdate) * Three inputs, including: *@li x: An ND Tensor. \n -*Must be one of the following types: float16, float32, int32, int8, uint8 +*Must be one of the following types: float16, float32, bool, int8, uint8 *@li indices: An ND Tensor. \n *Must be one of the following types: int32 *@li updates: An ND Tensor.
\n -*Must be one of the following types: float16, float32, int32, int8, uint8 +*Must be one of the following types: float16, float32, bool, int8, uint8 *@par Outputs: *y: A Tensor. Has the same type and format as input "x". */ REG_OP(TensorScatterUpdate) - .INPUT(x, TensorType({DT_FLOAT16, DT_FLOAT,DT_INT32,DT_INT8,DT_UINT8})) + .INPUT(x, TensorType::BasicType()) .INPUT(indices, TensorType::IndexNumberType()) - .INPUT(updates, TensorType({DT_FLOAT16, DT_FLOAT,DT_INT32,DT_INT8,DT_UINT8})) - .OUTPUT(y, TensorType({DT_FLOAT16, DT_FLOAT,DT_INT32,DT_INT8,DT_UINT8})) + .INPUT(updates, TensorType::BasicType()) + .OUTPUT(y, TensorType::BasicType()) .OP_END_FACTORY_REG(TensorScatterUpdate) /** diff --git a/third_party/fwkacllib/inc/ops/nn_calculation_ops.h b/third_party/fwkacllib/inc/ops/nn_calculation_ops.h index b7b55fb0..85062248 100644 --- a/third_party/fwkacllib/inc/ops/nn_calculation_ops.h +++ b/third_party/fwkacllib/inc/ops/nn_calculation_ops.h @@ -439,10 +439,11 @@ REG_OP(Conv2DBackpropInputD) * 4D with shape [filter_height, filter_width, in_channels, out_channels], * or [out_channels, filter_height, filter_width, in_channels], * or [out_channels, in_channel, filter_height, filter_width]. - * One optional input: + * Two optional inputs: * @li bias: An optional tensor of type int8 + * @li offset_w: An optional 1D tensor for quantized deconvolution. Reserved.\n *@par Attributes: - * Five attributes: + * Six attributes: * @li strides: A tuple or list of 2 integers. The stride of the sliding window * for H/W dimension. * @li pads: A tuple or list of 4 integers. The [top, bottom, left, right] @@ -453,6 +454,7 @@ REG_OP(Conv2DBackpropInputD) output channels. * @li data_format: An optional string from: "NHWC", "NCHW". Defaults to "NHWC".\n Specify the data format of the input and output data. + * @li offset_x: An optional integer for quantized deconvolution. *@par Outputs: * y: A Tensor. Has the same type as "filter". 4D tensor with shape * [batch, height, width, channels] or [batch, channels, height, width]. @@ -461,12 +463,14 @@ REG_OP(Deconvolution) .INPUT(x, TensorType({DT_FLOAT16, DT_FLOAT, DT_DOUBLE, DT_INT8})) .INPUT(filter, TensorType({DT_FLOAT16, DT_FLOAT, DT_DOUBLE, DT_INT8})) .OPTIONAL_INPUT(bias, TensorType({DT_FLOAT16, DT_FLOAT, DT_DOUBLE, DT_INT32})) + .OPTIONAL_INPUT(offset_w, TensorType({DT_INT8})) .OUTPUT(y, TensorType({DT_FLOAT16, DT_FLOAT, DT_DOUBLE, DT_INT32})) .ATTR(strides, ListInt, {1, 1, 1, 1}) .ATTR(pads, ListInt, {0, 0, 0, 0}) .ATTR(dilations, ListInt, {1, 1, 1, 1}) .ATTR(groups, Int, 1) .ATTR(data_format, String, "NHWC") + .ATTR(offset_x, Int, 0) .OP_END_FACTORY_REG(Deconvolution) /** *@brief Computes the gradients of convolution with respect to the filter diff --git a/third_party/fwkacllib/inc/ops/nn_detect_ops.h b/third_party/fwkacllib/inc/ops/nn_detect_ops.h index ce06a9b2..90b49720 100644 --- a/third_party/fwkacllib/inc/ops/nn_detect_ops.h +++ b/third_party/fwkacllib/inc/ops/nn_detect_ops.h @@ -328,7 +328,10 @@ REG_OP(PSROIPooling) *@li iou_threshold: A required float32, specifying the confidence threshold for box filtering, which is the output "obj" of operator Region. The value range is (0.0, 1.0). *@par Outputs: *box: An NCHW tensor of type float16, describing the information of each output box, including the coordinates, class, and confidence. +Proposals of actual output, with shape [batch, numBoxes, 8], where 8 means [x1, y1, x2, y2, score, label, batchID, NULL]; the maximum value of numBoxes is 1024.
+That is, take min(the maximum number of input boxes, 1024). *actual_bbox_num: An NCHW tensor of type int32, specifying the number of output boxes. +With shape [batch, num_classes], the actual number of bboxes output. *@attention Constraints:\n *@li totalnum < max_rois_num * batch_rois. @@ -349,6 +352,49 @@ REG_OP(FSRDetectionOutput) .REQUIRED_ATTR(iou_threshold, Float) .OP_END_FACTORY_REG(FSRDetectionOutput) +/** +*@brief Returns detection result. + +*@par Inputs: +* Four inputs, including: +*@li bbox_delta: An ND tensor of type float16 or float32, specifying the box loc predictions, used as the input of operator SSDDetectionOutput. +*@li score: An ND tensor of type float16 or float32, specifying the box confidences data, used as the input of operator SSDDetectionOutput. +*@li anchors: An ND tensor of type float16 or float32, output from operator PriorBoxD, used as the input of operator SSDDetectionOutput. +*@par Attributes: +*@li num_classes: An optional int32, specifying the number of classes to be predicted. Defaults to "2". The value must be greater than 1 and less than 1025. +*@li share_location: An optional bool, specifying the shared location. Defaults to True +*@li background_label_id: An optional int32, specifying the background label id. Must be 0 +*@li iou_threshold: An optional float32, specifying the nms threshold +*@li top_k: An optional int32, specifying the topk value. Defaults to 200 +*@li eta: An optional float32, specifying the eta value. Defaults to 1 +*@li variance_encoded_in_target: An optional bool, specifying whether the variance is encoded in the target. Defaults to False +*@li code_type: An optional int32, specifying the code type. Defaults to 1 (only supports 2). The corner is 1, center_size is 2, corner_size is 3 +*@li keep_top_k: An optional int32, specifying the topk value after nms. Defaults to -1 +*@li confidence_threshold: An optional float32, specifying the topk filter threshold. Only detections with confidence greater than the threshold are considered +*@li kernel_name: An optional string, specifying the operator name. Defaults to "ssd_detection_output". +*@par Outputs: +*out_boxnum: An NCHW tensor of type int32, specifying the number of output boxes. +*y: An NCHW tensor of type float16, describing the information of each output box, including the coordinates, class, and confidence. +With shape [batch, keep_top_k, 8], where 8 means (batchID, label (classID), score (class probability), xmin, ymin, xmax, ymax, null) +*/ +REG_OP(SSDDetectionOutput) + .INPUT(bbox_delta, TensorType({DT_FLOAT, DT_FLOAT16})) + .INPUT(score, TensorType({DT_FLOAT, DT_FLOAT16})) + .INPUT(anchors, TensorType({DT_FLOAT, DT_FLOAT16})) + .OUTPUT(out_boxnum, TensorType({DT_INT32})) + .OUTPUT(y, TensorType({DT_FLOAT, DT_FLOAT16})) + .ATTR(num_classes, Int, 2) + .ATTR(share_location, Bool, true) + .ATTR(background_label_id, Int, 0) + .ATTR(iou_threshold, Float, 0.3) + .ATTR(top_k, Int, 200) + .ATTR(eta, Float, 1.0) + .ATTR(variance_encoded_in_target, Bool, false) + .ATTR(code_type, Int, 1) + .ATTR(keep_top_k, Int, -1) + .ATTR(confidence_threshold, Float, 0.0) + .OP_END_FACTORY_REG(SSDDetectionOutput) + /** *@brief Normalizes data. It is called Region on YOLO v2 and Yolo on YOLO v3. @@ -412,12 +458,14 @@ and the actual image height and width. * *@par Outputs: *@li boxout: An NCHW tensor of type float16, describing the information of each output box, including the coordinates, class, and confidence. +With shape [batch, 6, post_nms_topn], where 6 means x1, y1, x2, y2, score, label (class). Output by the number of box_out_num.
*@li boxoutnum: An NCHW tensor of type int32, specifying the number of output boxes. - +With shape [batch, 8, 1, 1]; only the first of the 8 numbers is valid, giving the number of valid boxes in each batch; the maximum number of valid boxes in each batch is 1024 +* *@attention Constraints:\n *@li This operator applies only to the YOLO v2 network. *@li The preceding layer of operator Yolov2DetectionOutput must be one Yolo operator. - +* *@see Yolo() */ REG_OP(YoloV2DetectionOutput) @@ -468,7 +516,9 @@ and the actual image height and width. * *@par Outputs: *@li boxout: An NCHW tensor of type float16, describing the information of each output box, including the coordinates, class, and confidence. +With shape [batch, 6, post_nms_topn], where 6 means x1, y1, x2, y2, score, label (class). Output by the number of box_out_num. *@li boxoutnum: An NCHW tensor of type int32, specifying the number of output boxes. +With shape [batch, 8, 1, 1]; only the first of the 8 numbers is valid, giving the number of valid boxes in each batch; the maximum number of valid boxes in each batch is 1024 * *@attention Constraints:\n *@li This operator applies only to the YOLO v2 network. @@ -524,7 +574,9 @@ and the actual image height and width. * *@par Outputs: *@li boxout: An NCHW tensor of type float16, describing the information of each output box, including the coordinates, class, and confidence. +With shape [batch, 6, post_nms_topn], where 6 means x1, y1, x2, y2, score, label (class). Output by the number of box_out_num. *@li boxoutnum: An NCHW tensor of type int32, specifying the number of output boxes. +With shape [batch, 8, 1, 1]; only the first of the 8 numbers is valid, giving the number of valid boxes in each batch; the maximum number of valid boxes in each batch is 1024 *@attention Constraints:\n *@li This operator applies only to the YOLO v3 network. @@ -587,7 +639,10 @@ and the actual image height and width. * *@par Outputs: *@li boxout: An NCHW tensor of type float16, describing the information of each output box, including the coordinates, class, and confidence. +With shape [batch, 6, post_nms_topn], where 6 means x1, y1, x2, y2, score, label (class). Output by the number of box_out_num. *@li boxoutnum: An NCHW tensor of type int32, specifying the number of output boxes. +With shape [batch, 8, 1, 1]; only the first of the 8 numbers is valid, giving the number of valid boxes in each batch; the maximum number of valid boxes in each batch is 1024 +* *@attention Constraints:\n *@li This operator applies only to the YOLO v3 network. @@ -714,6 +769,264 @@ REG_OP(ROIPooling) .OUTPUT(y, TensorType({DT_FLOAT, DT_FLOAT16})) .OP_END_FACTORY_REG(ROIPooling) +/** +*@brief Computes decode bbox function. +* +*@par Inputs: +*Inputs include: +* @li box_predictions: A Tensor. Must be float16. +* @li anchors: A Tensor. Must have the same type as box_predictions. +* +*@par Attributes: +* @ decode_clip: required, float, threshold of decode process. +* +*@par Outputs: +* @ decoded_boxes: A Tensor. Must have the same type as box_predictions. +* N-D with shape [N, 4]. +*/ +REG_OP(DecodeBbox) + .INPUT(box_predictions, TensorType{DT_FLOAT16}) + .INPUT(anchors, TensorType{DT_FLOAT16}) + .OUTPUT(decoded_boxes, TensorType{DT_FLOAT16}) + .REQUIRED_ATTR(decode_clip, Float) + .OP_END_FACTORY_REG(DecodeBbox) + +/** +*@brief Computes ClipBoxes function. +* +*@par Inputs: +*Inputs include: +* @li boxes_input: A Tensor. Must be float16. N-D with shape [N, 4]. +* @li img_size: A Tensor. Must be int32. shape [H, W]. +* +*@par Outputs: +* @ boxes_output: A Tensor.
Must have the same type as boxes_input. N-D with shape [N, 4]. +*/ +REG_OP(ClipBoxes) + .INPUT(boxes_input, TensorType({DT_FLOAT16})) + .INPUT(img_size, TensorType({DT_INT32})) + .OUTPUT(boxes_output, TensorType({DT_FLOAT16})) + .OP_END_FACTORY_REG(ClipBoxes) +REG_OP(ClipBoxesD) + .INPUT(boxes_input, TensorType({DT_FLOAT16})) + .REQUIRED_ATTR(img_size, ListInt) + .OUTPUT(boxes_output, TensorType({DT_FLOAT16})) + .OP_END_FACTORY_REG(ClipBoxesD) + +/** +*@brief Computes Fastrcnn Predictions function. +* +*@par Inputs: +*Inputs include: +* @li rois: A Tensor. Must be float16. N-D with shape [N*C, 4]. +* @li score: A Tensor. Must be float16. N-D with shape [N, C+1]. +* +*@par Attributes: +* @li nms_threshold: required, float, threshold of nms process. +* @li score_threshold: required, float, threshold of topk process. +* @li k: required, Int, the k value of topk process. +*@par Outputs: +* @li sorted_rois: A Tensor. Must be float16. N-D with shape [N, 4]. +* @li sorted_scores: A Tensor. Must be float16. N-D with shape [N, 1]. +* @li sorted_classes: A Tensor. Must be float16. N-D with shape [N, 1]. +*/ +REG_OP(FastrcnnPredictions) + .INPUT(rois, TensorType({DT_FLOAT16})) + .INPUT(score, TensorType({DT_FLOAT16})) + .REQUIRED_ATTR(nms_threshold, Float) + .REQUIRED_ATTR(score_threshold, Float) + .REQUIRED_ATTR(k, Int) + .OUTPUT(sorted_rois, TensorType({DT_FLOAT16})) + .OUTPUT(sorted_scores, TensorType({DT_FLOAT16})) + .OUTPUT(sorted_classes, TensorType({DT_FLOAT16})) + .OP_END_FACTORY_REG(FastrcnnPredictions) + +/** +*@brief Computes Fastrcnn RpnProposals function. +* +*@par Inputs: +*Inputs include: +* @li rois: A Tensor. Must be float16. N-D with shape [N, 4]. +* @li cls_bg_prob: A Tensor. Must be float16. N-D with shape [N, 1]. +* @li img_size: A Tensor. Must be int32. shape [H, W]. +* +*@par Attributes: +* @li score_threshold: required, float, threshold of topk process. +* @li k: required, Int, the k value of topk process. +* @li min_size: required, float, threshold of nms process. +* @li nms_threshold: required, float, threshold of nms process. +* @li post_nms_num: required, Int, the number of proposals kept after nms process. +* @li score_filter: bool, mark of score_filter. Defaults to "true" +* @li box_filter: bool, mark of box_filter. Defaults to "true" +* @li score_sigmoid: bool, mark of score_sigmoid. Defaults to "false" +*@par Outputs: +* @li sorted_box: A Tensor. Must be float16. N-D with shape [N, 4].
+*/ +REG_OP(RpnProposals) + .INPUT(rois, TensorType({DT_FLOAT16})) + .INPUT(cls_bg_prob, TensorType({DT_FLOAT16})) + .INPUT(img_size, TensorType({DT_INT32})) + .REQUIRED_ATTR(score_threshold, Float) + .REQUIRED_ATTR(k, Int) + .REQUIRED_ATTR(min_size, Float) + .REQUIRED_ATTR(nms_threshold, Float) + .REQUIRED_ATTR(post_nms_num, Int) + .ATTR(score_filter, Bool, true) + .ATTR(box_filter, Bool, true) + .ATTR(score_sigmoid, Bool, false) + .OUTPUT(sorted_box, TensorType({DT_FLOAT16})) + .OP_END_FACTORY_REG(RpnProposals) +REG_OP(RpnProposalsD) + .INPUT(rois, TensorType({DT_FLOAT16})) + .INPUT(cls_bg_prob, TensorType({DT_FLOAT16})) + .REQUIRED_ATTR(img_size, ListInt) + .REQUIRED_ATTR(score_threshold, Float) + .REQUIRED_ATTR(k, Int) + .REQUIRED_ATTR(min_size, Float) + .REQUIRED_ATTR(nms_threshold, Float) + .REQUIRED_ATTR(post_nms_num, Int) + .ATTR(score_filter, Bool, true) + .ATTR(box_filter, Bool, true) + .ATTR(score_sigmoid, Bool, false) + .OUTPUT(sorted_box, TensorType({DT_FLOAT16})) + .OP_END_FACTORY_REG(RpnProposalsD) + +/** +*@brief Computes Score Filter Pre-Sort function. +* +*@par Inputs: +*Inputs include: +* @li rois: A Tensor. Must be float16. N-D with shape [N, 4]. +* @li cls_bg_prob: A Tensor. Must be float16. N-D with shape [N, 1]. +* +*@par Attributes: +* @li score_threshold: required, float, threshold of topk process. +* @li k: required, Int, the k value of topk process. +* @li score_filter: bool, mark of score_filter. Defaults to "true" +* @li core_max_num: int, max number of core. Defaults to "8" +*@par Outputs: +* @li sorted_proposal: A Tensor. Must be float16. +* N-D with shape [8*6002, 8]. +* @li proposal_num: A Tensor. Must be uint32. N-D with shape [8, 8]. +*/ + +REG_OP(ScoreFiltePreSort) + .INPUT(rois, TensorType({DT_FLOAT16})) + .INPUT(cls_bg_prob, TensorType({DT_FLOAT16})) + .OUTPUT(sorted_proposal, TensorType({ DT_FLOAT16})) + .OUTPUT(proposal_num, TensorType({ DT_UINT32})) + .REQUIRED_ATTR(score_threshold, Float) + .REQUIRED_ATTR(k, Int) + .ATTR(score_filter, Bool, true) + .ATTR(core_max_num, Int, 8) + .OP_END_FACTORY_REG(ScoreFiltePreSort) + +/** +*@brief Computes Rpn Proposal Post-Processing function. +* +*@par Inputs: +*Inputs include: +* @li sorted_proposal: A Tensor. Must be float16. +* N-D with shape [8*6002, 8]. +* @li proposal_num: A Tensor. Must be uint32. N-D with shape [8, 8]. +* +*@par Attributes: +* @li img_size: required, ListInt, the size [H, W] of the source image. +* @li score_threshold: required, float, threshold of topk process. +* @li k: required, Int, the k value of topk process. +* @li min_size: required, float, threshold of nms process. +* @li nms_threshold: required, float, threshold of nms process. +* @li post_nms_num: required, Int, the number of proposals kept after nms process. +* @li box_filter: bool, mark of box_filter. Defaults to "true" +* @li core_max_num: int, max number of core. Defaults to "8" +*@par Outputs: +* @li sorted_box: A Tensor. Must be float16. N-D with shape [N, 4].
+*/ +REG_OP(RpnProposalPostProcessing) + .INPUT(sorted_proposal, TensorType({DT_FLOAT16})) + .INPUT(proposal_num, TensorType({DT_UINT32})) + .OUTPUT(sorted_box, TensorType({ DT_FLOAT16})) + .REQUIRED_ATTR(img_size, ListInt) + .REQUIRED_ATTR(score_threshold, Float) + .REQUIRED_ATTR(k, Int) + .REQUIRED_ATTR(min_size, Float) + .REQUIRED_ATTR(nms_threshold, Float) + .REQUIRED_ATTR(post_nms_num, Int) + .ATTR(box_filter, Bool, true) + .ATTR(core_max_num, Int, 8) + .OP_END_FACTORY_REG(RpnProposalPostProcessing) +/** +*@brief Computes DecodeBoundariesTarget function. +* +*@par Inputs: +*Inputs include: +* @li boundary_predictions: A Tensor. Must be float16. +* @li anchors: A Tensor. Must be float16. +* +*@par Outputs: +* @ boundary_encoded: A Tensor. Must be float16. +*/ +REG_OP(DecodeBoundariesTarget) + .INPUT(boundary_predictions, TensorType({DT_FLOAT16})) + .INPUT(anchors, TensorType({DT_FLOAT16})) + .OUTPUT(boundary_encoded, TensorType({DT_FLOAT16})) + .OP_END_FACTORY_REG(DecodeBoundariesTarget) + +/** +*@brief Computes DecodeCornerpointsTargetBG function. +* +*@par Inputs: +*Inputs include: +* @li keypoints_prediction: A Tensor. Must be float16. +* @li anchors: A Tensor. Must be float16. +* +*@par Outputs: +* @ keypoints_decoded: A Tensor. Must be float16. +*/ +REG_OP(DecodeCornerpointsTargetBG) + .INPUT(keypoints_prediction, TensorType({DT_FLOAT16})) + .INPUT(anchors, TensorType({DT_FLOAT16})) + .OUTPUT(keypoints_decoded, TensorType({DT_FLOAT16})) + .OP_END_FACTORY_REG(DecodeCornerpointsTargetBG); + +/** +*@brief Computes DecodeCornerpointsTargetWrtCenterV1 function. +* +*@par Inputs: +*Inputs include: +* @li keypoints_prediction: A Tensor. Must be float16. +* @li anchors: A Tensor. Must be float16. +* +*@par Outputs: +* @ keypoints_decoded: A Tensor. Must be float16. +*/ +REG_OP(DecodeCornerpointsTargetWrtCenterV1) + .INPUT(keypoints_prediction, TensorType({DT_FLOAT16})) + .INPUT(anchors, TensorType({DT_FLOAT16})) + .OUTPUT(keypoints_decoded, TensorType({DT_FLOAT16})) + .OP_END_FACTORY_REG(DecodeCornerpointsTargetWrtCenterV1) + +/** +*@brief Computes DecodeWheelsTarget function. +* +*@par Inputs: +*Inputs include: +* @li boundary_predictions: A Tensor. Must be float16. +* @li anchors: A Tensor. Must be float16. +* +*@par Outputs: +* @ boundary_encoded: A Tensor. Must be float16. 
+*/ +REG_OP(DecodeWheelsTarget) + .INPUT(boundary_predictions, TensorType({DT_FLOAT16})) + .INPUT(anchors, TensorType({DT_FLOAT16})) + .OUTPUT(boundary_encoded, TensorType({DT_FLOAT16})) + .OP_END_FACTORY_REG(DecodeWheelsTarget) + } // namespace ge #endif // GE_OP_NN_DETECT_OPS_H_ diff --git a/third_party/fwkacllib/inc/ops/nn_pooling_ops.h b/third_party/fwkacllib/inc/ops/nn_pooling_ops.h index d3635c3f..0e57334f 100644 --- a/third_party/fwkacllib/inc/ops/nn_pooling_ops.h +++ b/third_party/fwkacllib/inc/ops/nn_pooling_ops.h @@ -173,6 +173,9 @@ REG_OP(MaxPool3D) .REQUIRED_ATTR(ksize, ListInt) .REQUIRED_ATTR(strides, ListInt) .REQUIRED_ATTR(padding, String) + .ATTR(pads, ListInt, {0,0,0}) + .ATTR(dilation, ListInt, {0,0,0}) + .ATTR(ceil_mode, Int, 0) .ATTR(data_format, String, "NDHWC") .OP_END_FACTORY_REG(MaxPool3D) diff --git a/third_party/fwkacllib/inc/ops/nn_training_ops.h b/third_party/fwkacllib/inc/ops/nn_training_ops.h index f09c1a8c..ff93f9fa 100644 --- a/third_party/fwkacllib/inc/ops/nn_training_ops.h +++ b/third_party/fwkacllib/inc/ops/nn_training_ops.h @@ -174,6 +174,7 @@ REG_OP(SparseApplyAdagradD) .INPUT(grad, TensorType({DT_FLOAT})) .INPUT(indices, TensorType({DT_INT32})) .OUTPUT(var, TensorType({DT_FLOAT})) + .OUTPUT(accum, TensorType({DT_FLOAT})) .REQUIRED_ATTR(lr, Float) .ATTR(use_locking, Bool, false) .ATTR(update_slots, Bool, true) diff --git a/third_party/fwkacllib/inc/ops/nonlinear_fuc_ops.h b/third_party/fwkacllib/inc/ops/nonlinear_fuc_ops.h index 15bd8812..0e1c4b22 100644 --- a/third_party/fwkacllib/inc/ops/nonlinear_fuc_ops.h +++ b/third_party/fwkacllib/inc/ops/nonlinear_fuc_ops.h @@ -123,6 +123,25 @@ REG_OP(Relu6) .OUTPUT(y, TensorType::RealNumberType()) .OP_END_FACTORY_REG(Relu6) +/** +* @brief Computes rectified linear 6*scale. +* activations = min(max(x, 0), 6*scale). + +* @par Inputs: +* x: A Tensor of type RealNumberType. + +* @par Attributes: +* scale: An optional float scale factor. Defaults to "1.0". + +* @par Outputs: +* y: A Tensor of type RealNumberType. +*/ +REG_OP(Relu6D) + .INPUT(x, TensorType::RealNumberType()) + .OUTPUT(y, TensorType::RealNumberType()) + .ATTR(scale, Float, 1.0) + .OP_END_FACTORY_REG(Relu6D) + /** * @brief Computes rectified linear 6 gradients for a Relu6 operation. * backprops = gradients * (features > 0) * (features < 6).
@@ -450,15 +469,18 @@ REG_OP(LeakyReluGrad) .OUTPUT(backprops, TensorType({DT_FLOAT16, DT_FLOAT, DT_DOUBLE})) .OP_END_FACTORY_REG(LeakyReluGrad) -REG_OP(threshold_grad_v2_d) - .INPUT(input_x, TensorType({DT_INT32, DT_FLOAT16})) - .INPUT(input_y, TensorType({DT_INT32, DT_FLOAT16})) - .OUTPUT(output_z, TensorType({DT_INT32, DT_FLOAT16})) - .OP_END_FACTORY_REG(threshold_grad_v2_d) +REG_OP(ThresholdGradV2D) + .INPUT(gradients, TensorType({DT_INT32, DT_FLOAT16})) + .INPUT(features, TensorType({DT_INT32, DT_FLOAT16})) + .OUTPUT(backprops, TensorType({DT_INT32, DT_FLOAT16})) + .REQUIRED_ATTR(threshold, Float) + .OP_END_FACTORY_REG(ThresholdGradV2D) REG_OP(ThresholdV2D) .INPUT(x, TensorType::RealNumberType()) .OUTPUT(y, TensorType::RealNumberType()) + .REQUIRED_ATTR(threshold, Float) + .REQUIRED_ATTR(value, Float) .OP_END_FACTORY_REG(ThresholdV2D) } // namespace ge diff --git a/third_party/fwkacllib/inc/ops/rpn_proposal_post_processing.h b/third_party/fwkacllib/inc/ops/rpn_proposal_post_processing.h deleted file mode 100644 index b8861f49..00000000 --- a/third_party/fwkacllib/inc/ops/rpn_proposal_post_processing.h +++ /dev/null @@ -1,39 +0,0 @@ -/** - * Copyright 2019-2020 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - - #ifndef GE_OP_RPN_PROPOSAL_POST_PROCESSING_H - #define GE_OP_RPN_PROPOSAL_POST_PROCESSING_H - - #include "graph/operator_reg.h" - -namespace ge { - REG_OP(RpnProposalPostProcessing) - .INPUT(sorted_proposal, TensorType({DT_FLOAT16})) - .INPUT(proposal_num, TensorType({DT_UINT32})) - .OUTPUT(sorted_box, TensorType({ DT_FLOAT16})) - .REQUIRED_ATTR(img_size, ListInt) - .REQUIRED_ATTR(score_threshold, Float) - .REQUIRED_ATTR(k, Int) - .REQUIRED_ATTR(min_size, Float) - .REQUIRED_ATTR(nms_threshold, Float) - .REQUIRED_ATTR(post_nms_num, Int) - .ATTR(box_filter, Bool, true) - .ATTR(core_max_num, Int, 8) - .OP_END_FACTORY_REG(RpnProposalPostProcessing) - } // namespace ge - - #endif // GE_OP_GENERATE_RPN_PROPOSAL_POST_PROCESSING_H - diff --git a/third_party/fwkacllib/inc/ops/rpn_proposals.h b/third_party/fwkacllib/inc/ops/rpn_proposals.h deleted file mode 100644 index 3ebf7589..00000000 --- a/third_party/fwkacllib/inc/ops/rpn_proposals.h +++ /dev/null @@ -1,54 +0,0 @@ -/** - * Copyright 2019-2020 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. 
- */ - - #ifndef GE_OP_RPN_PROPOSALS_H - #define GE_OP_RPN_PROPOSALS_H - - #include "graph/operator_reg.h" - -namespace ge { -REG_OP(RpnProposals) - .INPUT(rois, TensorType({DT_FLOAT16})) - .INPUT(cls_bg_prob, TensorType({DT_FLOAT16})) - .INPUT(img_size, TensorType({DT_INT32})) - .REQUIRED_ATTR(score_threshold, Float) - .REQUIRED_ATTR(k, Int) - .REQUIRED_ATTR(min_size, Float) - .REQUIRED_ATTR(nms_threshold, Float) - .REQUIRED_ATTR(post_nms_num, Int) - .ATTR(score_filter, Bool, true) - .ATTR(box_filter, Bool, true) - .ATTR(score_sigmoid, Bool, false) - .OUTPUT(sorted_box, TensorType({DT_FLOAT16})) - .OP_END_FACTORY_REG(RpnProposals) - -REG_OP(RpnProposalsD) - .INPUT(rois, TensorType({DT_FLOAT16})) - .INPUT(cls_bg_prob, TensorType({DT_FLOAT16})) - .REQUIRED_ATTR(img_size, ListInt) - .REQUIRED_ATTR(score_threshold, Float) - .REQUIRED_ATTR(k, Int) - .REQUIRED_ATTR(min_size, Float) - .REQUIRED_ATTR(nms_threshold, Float) - .REQUIRED_ATTR(post_nms_num, Int) - .ATTR(score_filter, Bool, true) - .ATTR(box_filter, Bool, true) - .ATTR(score_sigmoid, Bool, false) - .OUTPUT(sorted_box, TensorType({DT_FLOAT16})) - .OP_END_FACTORY_REG(RpnProposalsD) -} // namespace ge - - #endif // GE_OP_GENERATE_RPN_PROPOSALS_H diff --git a/third_party/fwkacllib/inc/ops/ssddetectionoutput_ops.h b/third_party/fwkacllib/inc/ops/ssddetectionoutput_ops.h deleted file mode 100644 index 7c50db14..00000000 --- a/third_party/fwkacllib/inc/ops/ssddetectionoutput_ops.h +++ /dev/null @@ -1,65 +0,0 @@ -/** - * Copyright 2019-2020 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#ifndef GE_OP_SSDDETECTIONOUTPUT_OPS_H_ -#define GE_OP_SSDDETECTIONOUTPUT_OPS_H_ -#include "graph/operator_reg.h" - -namespace ge { -/** -*@brief Returns detection result. - -*@par Inputs: -* Four inputs, including: -*@li mbox_conf: An ND tensor of type floa16 or float32, specifying the box confidences data, used as the input of operator SSDDetectionOutput. -*@li mbox_loc: An ND tensor of type floa16 or float32, specifying the box loc predictions, used as the input of operator SSDDetectionOutput. -*@li mbox_priorbox: An ND tensor of type floa16 or float32, output from operator PriorBoxD, used as the input of operator SSDDetectionOutput. -*@par Attributes: -*@li num_classes: An optional int32, specifying the number of classes to be predicted. Defaults to "2". The value must be greater than 1 and lesser than 1025. -*@li share_location: An option bool, specify the shared location. Defaults to True -*@li background_label_id: An option int32, specify the background label id. Must be 0 -*@li nms_threshold: An option float32, specify the nms threshold -*@li top_k: An option int32, specify the topk value. Defaults to 200 -*@li eta: An option float32, specify the eta value. Defaults to 1 -*@li variance_encoded_in_target: An option bool, specify whether variance encoded in target or not. Defaults to False -*@li code_type: An option int32, specify the code type. Defaults to 1(only supports 2). 
The corner is 1, center_size is 2, corner_size is 3 -*@li keep_top_k: An option int32, specify the topk value after nms. Defaults to -1 -*@li confidence_threshold: An option float32, specify the topk filter threshold. Only consider detections with confidence greater than the threshold -*@li kernel_name: An optional string, specifying the operator name. Defaults to "ssd_detection_output". -*@par Outputs: -*out_boxnum: An NCHW tensor of type int32, specifying the number of output boxes. -*y: An NCHW tensor of type float16, describing the information of each output box, including the coordinates, class, and confidence. - -*/ -REG_OP(SSDDetectionOutput) - .INPUT(bbox_delta, TensorType({DT_FLOAT, DT_FLOAT16})) - .INPUT(score, TensorType({DT_FLOAT, DT_FLOAT16})) - .INPUT(anchors, TensorType({DT_FLOAT, DT_FLOAT16})) - .OUTPUT(out_boxnum, TensorType({DT_INT32})) - .OUTPUT(y, TensorType({DT_FLOAT, DT_FLOAT16})) - .ATTR(num_classes, Int, 2) - .ATTR(share_location, Bool, true) - .ATTR(background_label_id, Int, 0) - .ATTR(iou_threshold, Float, 0.3) - .ATTR(top_k, Int, 200) - .ATTR(eta, Float, 1.0) - .ATTR(variance_encoded_in_target, Bool, false) - .ATTR(code_type, Int, 1) - .ATTR(keep_top_k, Int, -1) - .ATTR(confidence_threshold, Float, 0.0) - .OP_END_FACTORY_REG(SSDDetectionOutput) -} -#endif diff --git a/third_party/fwkacllib/inc/ops/warp_perspective_ops.h b/third_party/fwkacllib/inc/ops/warp_perspective_ops.h new file mode 100644 index 00000000..7da49c1e --- /dev/null +++ b/third_party/fwkacllib/inc/ops/warp_perspective_ops.h @@ -0,0 +1,52 @@ +/** + * Copyright 2019-2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef GE_OP_WARP_PERSPECTIVE_OPS_H_ +#define GE_OP_WARP_PERSPECTIVE_OPS_H_ + +#include "graph/operator_reg.h" +#include "graph/operator.h" + +namespace ge { +/** +*@brief Applies a perspective transformation to an image. + +*@par Inputs: +*@li x: input tensor, format NCHW, type must be float. +*@li matrix: transformation matrix, format ND, shape must be (N, 9), type must be float. + +*@par Attributes: +*@li out_height: output height. +*@li out_width: output width. +*@li border_type: border processing mode; only BORDER_CONSTANT and BORDER_REPLICATE are supported. Defaults to BORDER_CONSTANT. +*@li constant: fill value used for the border when border_type is BORDER_CONSTANT. + +*@par Outputs: +*@li y: output tensor, format NCHW, type must be float.
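To make the attribute interplay concrete, here is a per-pixel host sketch of a perspective warp for a single channel; nearest-neighbor sampling and the output-to-input direction of the 3x3 matrix are assumptions, since the header does not pin down either:

#include <vector>

// For each output pixel (u, v): [xw, yw, w]^T = M * [u, v, 1]^T, then the
// input is sampled at (xw / w, yw / w). Out-of-range samples keep the
// `constant` fill value, mirroring the BORDER_CONSTANT mode described above.
void WarpPerspectiveSketch(const std::vector<float> &src, int src_h, int src_w,
                           const float m[9], int out_height, int out_width,
                           float constant, std::vector<float> &dst) {
  dst.assign(static_cast<size_t>(out_height) * out_width, constant);
  for (int v = 0; v < out_height; ++v) {
    for (int u = 0; u < out_width; ++u) {
      const float w = m[6] * u + m[7] * v + m[8];
      if (w == 0.0f) {
        continue;  // degenerate mapping; keep the border value
      }
      const int x = static_cast<int>((m[0] * u + m[1] * v + m[2]) / w);
      const int y = static_cast<int>((m[3] * u + m[4] * v + m[5]) / w);
      if (x >= 0 && x < src_w && y >= 0 && y < src_h) {
        dst[static_cast<size_t>(v) * out_width + u] = src[static_cast<size_t>(y) * src_w + x];
      }
    }
  }
}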
+*/ + +REG_OP(WarpPerspective) + .INPUT(x, TensorType({DT_FLOAT})) + .INPUT(matrix, TensorType({DT_FLOAT})) + .OUTPUT(y, TensorType({DT_FLOAT})) + .REQUIRED_ATTR(out_height, Int) + .REQUIRED_ATTR(out_width, Int) + .ATTR(border_type, String, "BORDER_CONSTANT") + .ATTR(constant, Float, 0) + .OP_END_FACTORY_REG(WarpPerspective) +} // namespace ge + +#endif // GE_OP_WARP_PERSPECTIVE_OPS_H_ diff --git a/third_party/fwkacllib/inc/runtime/config.h b/third_party/fwkacllib/inc/runtime/config.h index fcdcf2ec..2e48cc57 100644 --- a/third_party/fwkacllib/inc/runtime/config.h +++ b/third_party/fwkacllib/inc/runtime/config.h @@ -59,6 +59,9 @@ typedef enum tagRtPlatformType { PLATFORM_BEGIN = 0, PLATFORM_MINI_V1 = PLATFORM_BEGIN, PLATFORM_CLOUD_V1, + PLATFORM_MINI_V2, + PLATFORM_LHISI_ES, + PLATFORM_LHISI_CS, PLATFORM_END, } rtPlatformType_t; diff --git a/third_party/fwkacllib/inc/runtime/dev.h b/third_party/fwkacllib/inc/runtime/dev.h index 08fa3970..928f2822 100644 --- a/third_party/fwkacllib/inc/runtime/dev.h +++ b/third_party/fwkacllib/inc/runtime/dev.h @@ -218,6 +218,13 @@ RTS_API rtError_t rtGetRunMode(rtRunMode *mode); * @return RT_ERROR_NONE for ok */ RTS_API rtError_t rtSetSocVersion(const char *version); + +/** + * @ingroup dvrt_dev + * @brief get socVersion + * @return RT_ERROR_NONE for ok + */ +rtError_t rtGetSocVersion(char *version, const uint32_t maxLen); #ifdef __cplusplus } #endif diff --git a/third_party/fwkacllib/inc/runtime/event.h b/third_party/fwkacllib/inc/runtime/event.h index fbc4d759..31991cf7 100644 --- a/third_party/fwkacllib/inc/runtime/event.h +++ b/third_party/fwkacllib/inc/runtime/event.h @@ -23,6 +23,13 @@ extern "C" { #endif +/** + * @ingroup event_flags + * @brief event op bit flags + */ +#define RT_EVENT_DEFAULT (0x00) +#define RT_EVENT_WITH_FLAG (0x01) + /** * @ingroup dvrt_event * @brief create event instance @@ -32,6 +39,15 @@ extern "C" { */ RTS_API rtError_t rtEventCreate(rtEvent_t *event); +/** + * @ingroup dvrt_event + * @brief create event instance with flag + * @param [in|out] event created event, [in] flag event op flag + * @return RT_ERROR_NONE for ok + * @return RT_ERROR_INVALID_RESOURCE_HANDLE for error input handle + */ +RTS_API rtError_t rtEventCreateWithFlag(rtEvent_t *event, uint32_t flag); + /** * @ingroup dvrt_event * @brief destroy event instance diff --git a/third_party/fwkacllib/inc/tdt/data_common.h b/third_party/fwkacllib/inc/tdt/data_common.h index da9881ff..687d31d2 100644 --- a/third_party/fwkacllib/inc/tdt/data_common.h +++ b/third_party/fwkacllib/inc/tdt/data_common.h @@ -46,6 +46,7 @@ struct TdtDataItem { TdtDataType dataType_; /**< Input data type*/ uint64_t label_; /**< Input data label*/ uint64_t dataLen_; /**< Input data type length*/ + uint64_t realDataLen_; /**< Real input data length*/ std::string tensorShape_; /**< Tensor shape*/ std::string tensorType_; /**< Tensor type*/ uint32_t cnt_; /**< Data count*/ diff --git a/third_party/fwkacllib/inc/tdt/status.h b/third_party/fwkacllib/inc/tdt/status.h index ec624b35..87ae8f75 100644 --- a/third_party/fwkacllib/inc/tdt/status.h +++ b/third_party/fwkacllib/inc/tdt/status.h @@ -255,6 +255,7 @@ enum { TDT_RECV_MSG_MD5_WRONG_CODE, TDT_RECV_MSG_FAIL_TO_GENERATE_MD5_CODE, TDT_RECV_MSG_SEQUENCE_ERROR_CODE, + TDT_SERVER_MEMORY_COPY_FAILED_CODE, TDT_DEVICEID_ERROR_CODE, TDT_MEMORY_DATA_TYPE_FACTORY_MAKE_SHARED_FAILED_CODE, TDT_PREFETCH_FILELIST_NOT_EXIST_CODE, @@ -298,6 +299,10 @@ enum { TDT_TUNING_DATA_RECEIVE_CHECK_PARA_ERROR_CODE, TDT_TUNING_DATA_TRANSFER_PARAMETER_ERROR_CODE,
TDT_RECV_MSG_CHECKSUM_WRONG_ERROR_CODE, + TDT_SVM_INIT_FAILED_CODE, + TDT_SVM_FREE_PIN_FAILED_CODE, + TDT_SVM_FREE_SVM_FAILED_CODE, + TDT_SVM_ADD_BUFFER_MAP_FAILED_CODE, TDT_STATUS_CODE_TOTAL }; @@ -350,6 +355,7 @@ constexpr uint16_t MODID_TSD_CLIENT = 0x0114; // TSD_CLIENT module ID constexpr uint16_t MODID_CHECKSUM = 0x0115; // Checksum module ID constexpr uint16_t MODID_TDT_MONITOR = 0x0116; // TDT monitor module ID constexpr uint16_t MODID_TDT_HOST = 0x0117; // GE adapts the TDT HOST module ID +constexpr uint16_t MODID_SVM = 0x0118; // SVM Driver module ID constexpr uint32_t TDT_API_MAX_SUB_VERSION = 100; static const int32_t TDT_INVAILED_DEVICE_ID = 0xFFFFFFFF; @@ -676,6 +682,7 @@ TDT_DEF_ERROR_CODE(MODID_TDT_SERVER, TDT_ERROR, TDT_RECV_MSG_MD5_WRONG, "md5 of TDT_DEF_ERROR_CODE(MODID_TDT_SERVER, TDT_ERROR, TDT_RECV_MSG_CHECKSUM_WRONG_ERROR, "checksum of recv msg is wrong"); TDT_DEF_ERROR_CODE(MODID_TDT_SERVER, TDT_ERROR, TDT_RECV_MSG_FAIL_TO_GENERATE_MD5, "md5 of recv msg is wrong"); TDT_DEF_ERROR_CODE(MODID_TDT_SERVER, TDT_ERROR, TDT_RECV_MSG_SEQUENCE_ERROR, "sequence recv msg is wrong"); +TDT_DEF_ERROR_CODE(MODID_TDT_SERVER, TDT_ERROR, TDT_SERVER_MEMORY_COPY_FAILED, "memory copy failed"); TDT_DEF_ERROR_CODE(MODID_TDT_CLIENT, TDT_ERROR, TDT_CHANNEL_HAS_NO_SESSION_ERROR, "channel has no session"); TDT_DEF_ERROR_CODE(MODID_HDC_CLIENT, TDT_ERROR, TDT_HDC_CLIENT_INIT_ERROR, "hdc client init error"); TDT_DEF_ERROR_CODE(MODID_HDC_CLIENT, TDT_ERROR, TDT_HDC_CLIENT_CREATE_SESSION_ERROR, "hdc client create error"); @@ -735,4 +742,8 @@ TDT_DEF_ERROR_CODE(MODID_TDT_CLIENT, TDT_ERROR, TDT_TUNING_DATA_TRANSFER_INIT_FA TDT_DEF_ERROR_CODE(MODID_TDT_CLIENT, TDT_ERROR, TDT_TUNING_DATA_RECEIVE_CHECK_PARA_ERROR, "the index is error"); TDT_DEF_ERROR_CODE(MODID_TDT_CLIENT, TDT_ERROR, TDT_TUNING_DATA_TRANSFER_PARAMETER_ERROR, "the parameter is error"); +TDT_DEF_ERROR_CODE(MODID_SVM, TDT_ERROR, TDT_SVM_INIT_FAILED, "SVM driver init failed"); +TDT_DEF_ERROR_CODE(MODID_SVM, TDT_ERROR, TDT_SVM_FREE_PIN_FAILED, "SVM driver free host pin memory failed"); +TDT_DEF_ERROR_CODE(MODID_SVM, TDT_ERROR, TDT_SVM_FREE_SVM_FAILED, "SVM driver free device svm memory failed"); +TDT_DEF_ERROR_CODE(MODID_SVM, TDT_ERROR, TDT_SVM_ADD_BUFFER_MAP_FAILED, "add svm buffer info to map failed"); #endif // INC_TDT_STATUS_H_
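Finally, a usage sketch for the rtEventCreateWithFlag interface declared in the runtime/event.h hunk above; the fallback-to-legacy policy shown here is illustrative, not something the runtime mandates:

#include "runtime/event.h"

// Prefer the flagged creation path; fall back to plain rtEventCreate if the
// flagged variant reports an error (fallback policy is illustrative only).
rtError_t CreateEventPreferFlagged(rtEvent_t *event) {
  rtError_t ret = rtEventCreateWithFlag(event, RT_EVENT_WITH_FLAG);
  if (ret != RT_ERROR_NONE) {
    ret = rtEventCreate(event);
  }
  return ret;
}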