diff --git a/ge/graph/load/new_model_manager/davinci_model.cc b/ge/graph/load/new_model_manager/davinci_model.cc index f3d6f82b..706d4b3b 100755 --- a/ge/graph/load/new_model_manager/davinci_model.cc +++ b/ge/graph/load/new_model_manager/davinci_model.cc @@ -87,6 +87,7 @@ const uint32_t kDumpL1FusionOpMByteSize = 2097152; // 2 * 1024 * 1024 const uint32_t kDumpFlagOfL1Fusion = 0; const char *const kDefaultBatchLable = "Batch_default"; const char *const kGetDynamicDimsName = "ascend_mbatch_get_dynamic_dims_node"; +const char *const kMultiBatchNodePostfix = "_ascend_mbatch_batch_"; const int32_t kInvalidStream = -1; const uint32_t kEndOfSequence = 0x0704000a; const uint32_t kEndOfSequenceNew = 507005; @@ -867,6 +868,10 @@ Status DavinciModel::InitNodes(const ComputeGraphPtr &compute_graph) { GELOGE(PARAM_INVALID, "NetOutput init failed, Name: %s", op_desc->GetName().c_str()); return PARAM_INVALID; } + if (InitRealSizeAndShapeInfo(compute_graph, node) != SUCCESS) { + GELOGE(PARAM_INVALID, "Init real size and shape failed, Name: %s", op_desc->GetName().c_str()); + return PARAM_INVALID; + } continue; } @@ -1143,16 +1148,24 @@ Status DavinciModel::InitNetOutput(const ComputeGraphPtr &graph, const NodePtr & real_virtual_addrs_.insert(real_addr); } } + return SUCCESS; +} +Status DavinciModel::InitRealSizeAndShapeInfo(const ComputeGraphPtr &compute_graph, const NodePtr &node) { + if (node->GetName().find(kMultiBatchNodePostfix) != string::npos) { + GELOGD("No need to get size and shape of netoutput in subgraph."); + return SUCCESS; + } + GELOGD("Start init real size and shape info of %s.", node->GetName().c_str()); GetAllGearsInfo(node); if (is_getnext_sink_dynamic_) { GE_IF_BOOL_EXEC(GetGetDynamicDimsNodeInfo(node) != SUCCESS, GELOGE(PARAM_INVALID, "Failed to get info of getdynamicdims node."); return PARAM_INVALID;); } if (is_online_infer_dynamic_) { - GE_IF_BOOL_EXEC(GetGearAndRealOutSizeInfo(input_count, node) != SUCCESS, + GE_IF_BOOL_EXEC(GetGearAndRealOutSizeInfo(compute_graph, node) != SUCCESS, GELOGE(PARAM_INVALID, "Failed to get gear and real out size info."); return PARAM_INVALID;); - GE_IF_BOOL_EXEC(GetGearAndRealOutShapeInfo(input_count, op_desc) != SUCCESS, + GE_IF_BOOL_EXEC(GetGearAndRealOutShapeInfo(compute_graph, node) != SUCCESS, GELOGE(PARAM_INVALID, "Failed to get gear and real out shape info."); return PARAM_INVALID;); } @@ -1171,7 +1184,7 @@ void DavinciModel::GetAllGearsInfo(const NodePtr &node) { if (shape_str.empty()) { continue; } - std::vector gear_info; + std::vector gear_info; std::vector dims = ge::StringUtils::Split(shape_str, ','); for (const auto &dim : dims) { if (dim.empty()) { @@ -1187,6 +1200,7 @@ void DavinciModel::GetAllGearsInfo(const NodePtr &node) { } } } + Status DavinciModel::GetGetDynamicDimsNodeInfo(const NodePtr &node) { GE_CHECK_NOTNULL(node->GetOpDesc()); size_t input_count = node->GetAllInDataAnchors().size(); @@ -1224,11 +1238,11 @@ Status DavinciModel::GetGetDynamicDimsNodeInfo(const NodePtr &node) { return SUCCESS; } -Status DavinciModel::GetGearAndRealOutSizeInfo(size_t input_count, const NodePtr &node) { - GELOGD("Start get gear and real output size info of %s, input count is %zu.", node->GetName().c_str(), input_count); +Status DavinciModel::GetGearAndRealOutSizeInfo(const ComputeGraphPtr &graph, const NodePtr &node) { + GELOGD("Start get gear and real output size info of %s.", node->GetName().c_str()); merge_nodes_gear_and_real_out_size_info_.clear(); - for (size_t idx = 0; idx < input_count; ++idx) { - auto in_anchor = node->GetAllInDataAnchors().at(idx); + size_t idx = 0; + for (const auto &in_anchor : node->GetAllInDataAnchors()) { auto peer_out_anchor = in_anchor->GetPeerOutAnchor(); if (peer_out_anchor == nullptr) { continue; @@ -1236,89 +1250,106 @@ Status DavinciModel::GetGearAndRealOutSizeInfo(size_t input_count, const NodePtr auto peer_node = peer_out_anchor->GetOwnerNode(); auto op_desc = peer_node->GetOpDesc(); GE_CHECK_NOTNULL(op_desc); - if ((peer_node->GetType() == MERGE) && (op_desc->HasAttr(ATTR_INSERT_BY_MBATCH))) { - if (GetRealOutputSizeOfMerge(idx, peer_node) != SUCCESS) { + if ((peer_node->GetType() == CASE) && (op_desc->HasAttr(ATTR_INSERT_BY_MBATCH))) { + if (GetRealOutputSizeOfCase(graph, idx, peer_node) != SUCCESS) { GELOGE(PARAM_INVALID, "Get real output size of %s failed.", peer_node->GetName().c_str()); return PARAM_INVALID; } } + idx++; } return SUCCESS; } -Status DavinciModel::GetRealOutputSizeOfMerge(size_t input_index, const NodePtr &merge_node) { - GELOGD("Start get output size of %s, which is %zu input to netoutput.", merge_node->GetName().c_str(), input_index); - std::map, int64_t> gear_and_real_out_size_info; - for (auto &in_anchor : merge_node->GetAllInDataAnchors()) { - auto peer_out_anchor = in_anchor->GetPeerOutAnchor(); - if (peer_out_anchor == nullptr) { - continue; - } - auto in_node = peer_out_anchor->GetOwnerNode(); - GELOGD("Input node of merge is %s.", in_node->GetName().c_str()); - auto op_desc = in_node->GetOpDesc(); - GE_CHECK_NOTNULL(op_desc); - string batch_label; - if (AttrUtils::GetStr(op_desc, ATTR_NAME_BATCH_LABEL, batch_label)) { - size_t batch_index = static_cast(stoi(batch_label.substr(batch_label.rfind('_') + 1))); - GELOGD("Batch index of %s is %zu.", op_desc->GetName().c_str(), batch_index); - if (batch_index > all_gears_info_.size()) { - GELOGE(PARAM_INVALID, "The value of ATTR_NAME_BATCH_LABEL is invalid."); - return PARAM_INVALID; - } - - const vector output_size_list = ModelUtils::GetOutputSize(op_desc); - int output_index = ge::AnchorUtils::GetIdx(peer_out_anchor); - auto tensor_desc = op_desc->GetOutputDescPtr(output_index); - GE_CHECK_NOTNULL(tensor_desc); - int64_t data_size = 0; - if (TensorUtils::GetTensorSizeInBytes(*tensor_desc, data_size) != GRAPH_SUCCESS) { - GELOGE(FAILED, "Get tensor size in bytes failed."); - return FAILED; +Status DavinciModel::GetRealOutputSizeOfCase(const ComputeGraphPtr &graph, size_t input_index, + const NodePtr &case_node) { + GELOGD("Start get output size of %s, which is %zu input to netoutput.", case_node->GetName().c_str(), input_index); + const auto &func_desc = case_node->GetOpDesc(); + GE_CHECK_NOTNULL(func_desc); + std::map, int64_t> gear_and_real_out_size_info; + for (const auto &name : func_desc->GetSubgraphInstanceNames()) { + const auto &subgraph = graph->GetSubgraph(name); + if (subgraph == nullptr) { + GELOGE(GE_GRAPH_EMPTY_SUBGRAPH, "Subgraph not found, name: %s.", name.c_str()); + return GE_GRAPH_EMPTY_SUBGRAPH; + } + for (auto &node : subgraph->GetDirectNode()) { + if (node->GetType() == NETOUTPUT) { + auto op_desc = node->GetOpDesc(); + GE_CHECK_NOTNULL(op_desc); + string batch_label; + if (AttrUtils::GetStr(op_desc, ATTR_NAME_BATCH_LABEL, batch_label)) { + size_t batch_index = static_cast(stoi(batch_label.substr(batch_label.rfind('_') + 1))); + GELOGD("Batch index of %s is %zu.", op_desc->GetName().c_str(), batch_index); + if (batch_index > all_gears_info_.size()) { + GELOGE(PARAM_INVALID, "The value of ATTR_NAME_BATCH_LABEL is invalid."); + return PARAM_INVALID; + } + + const vector input_size_list = ModelUtils::GetInputSize(op_desc); + auto tensor_desc = op_desc->GetInputDescPtr(input_index); + GE_CHECK_NOTNULL(tensor_desc); + int64_t data_size = 0; + if (TensorUtils::GetTensorSizeInBytes(*tensor_desc, data_size) != GRAPH_SUCCESS) { + GELOGE(FAILED, "Get tensor size in bytes failed."); + return FAILED; + } + gear_and_real_out_size_info[all_gears_info_[batch_index]] = data_size; + GELOGD("Get real gear index is: %zu, gear info is %s, size is %ld, tensor size is %ld", + batch_index, formats::JoinToString(all_gears_info_[batch_index]).c_str(), + input_size_list[input_index], data_size); + } + break; } - gear_and_real_out_size_info[all_gears_info_[batch_index]] = data_size; - GELOGD("Get real gear index is: %zu, gear info is %s, size is %ld, tensor size is %ld", - batch_index, formats::JoinToString(all_gears_info_[batch_index]).c_str(), - output_size_list[output_index], data_size); } } merge_nodes_gear_and_real_out_size_info_[input_index] = gear_and_real_out_size_info; return SUCCESS; } -Status DavinciModel::GetGearAndRealOutShapeInfo(size_t input_count, const OpDescPtr &op_desc) { - GELOGD("Start to get dynamic output dims of %s.", op_desc->GetName().c_str()); +Status DavinciModel::GetGearAndRealOutShapeInfo(const ComputeGraphPtr &graph, const NodePtr &node) { + GELOGD("Start to get dynamic output dims of %s.", node->GetName().c_str()); merge_nodes_gear_and_real_out_shape_info_.clear(); - std::vector dynamic_output_shape_info; - if (!AttrUtils::GetListStr(op_desc, ATTR_NAME_DYNAMIC_OUTPUT_DIMS, dynamic_output_shape_info)) { - GELOGD("Can not get dynamic output dims attr"); - return SUCCESS; - } - GELOGI("Dynamic output shape info is %s", formats::JoinToString(dynamic_output_shape_info).c_str()); - std::vector> dynamic_output_shape; - ParseDynamicOutShape(dynamic_output_shape_info, dynamic_output_shape); - // idx: input_index to netoutput - for (size_t idx = 0; idx < input_count; ++idx) { - std::map, vector> gear_and_real_out_shape_info; - for (auto &it : dynamic_output_shape) { - auto gear_index = static_cast(it[0]); - if (gear_index > all_gears_info_.size()) { - GELOGE(PARAM_INVALID, "The value of cur index: %zu is invalid.", static_cast(it[0])); - return PARAM_INVALID; + size_t idx = 0; + for (const auto &in_anchor : node->GetAllInDataAnchors()) { + auto peer_out_anchor = in_anchor->GetPeerOutAnchor(); + if (peer_out_anchor == nullptr) { + continue; + } + auto peer_node = peer_out_anchor->GetOwnerNode(); + auto op_desc = peer_node->GetOpDesc(); + GE_CHECK_NOTNULL(op_desc); + if ((peer_node->GetType() == CASE) && (op_desc->HasAttr(ATTR_INSERT_BY_MBATCH))) { + std::vector dynamic_output_shape_info; + if (!AttrUtils::GetListStr(node->GetOpDesc(), ATTR_NAME_DYNAMIC_OUTPUT_DIMS, dynamic_output_shape_info)) { + GELOGD("Can not get dynamic output dims attr from %s.", node->GetName().c_str()); + return SUCCESS; } + GELOGI("Dynamic output shape info is %s", formats::JoinToString(dynamic_output_shape_info).c_str()); + std::vector> dynamic_output_shape; + ParseDynamicOutShape(dynamic_output_shape_info, dynamic_output_shape); + std::map, vector> gear_and_real_out_shape_info; + for (auto &it : dynamic_output_shape) { + auto gear_index = static_cast(it[0]); + if (gear_index > all_gears_info_.size()) { + GELOGE(PARAM_INVALID, "The value of cur index: %zu is invalid.", static_cast(it[0])); + return PARAM_INVALID; + } - if (static_cast(it[1]) == idx) { - vector output_shape; - for (size_t i = 2; i < it.size(); ++i) { - output_shape.emplace_back(it[i]); + if (static_cast(it[1]) == idx) { + vector output_shape; + for (size_t i = 2; i < it.size(); ++i) { + output_shape.emplace_back(it[i]); + } + gear_and_real_out_shape_info[all_gears_info_[gear_index]] = output_shape; + GELOGD("Get real gear index is: %zu, gear info is %s, output shape is %s.", + gear_index, formats::JoinToString(all_gears_info_[gear_index]).c_str(), + formats::JoinToString(output_shape).c_str()); } - gear_and_real_out_shape_info[all_gears_info_[gear_index]] = output_shape; - GELOGD("Get real gear index is: %zu, gear info is %s, output shape is %s.", - gear_index, formats::JoinToString(all_gears_info_[gear_index]).c_str(), - formats::JoinToString(output_shape).c_str()); } + merge_nodes_gear_and_real_out_shape_info_[idx] = gear_and_real_out_shape_info; } - merge_nodes_gear_and_real_out_shape_info_[idx] = gear_and_real_out_shape_info; + idx++; } return SUCCESS; } @@ -1962,7 +1993,7 @@ void DavinciModel::CreateOutput(uint32_t index, const OpDescPtr &op_desc, InputO uint32_t &format_result) { /// netoutput input tensor desc GE_IF_BOOL_EXEC(op_desc->GetInputDescPtr(index) == nullptr, GELOGE(FAILED, "OpDesc GetInputDescPtr is nullptr"); - return ); + return); Format format = op_desc->GetInputDescPtr(index)->GetFormat(); GeShape shape = op_desc->GetInputDescPtr(index)->GetShape(); DataType data_type = op_desc->GetInputDescPtr(index)->GetDataType(); @@ -2567,7 +2598,7 @@ Status DavinciModel::ReturnResult(uint32_t data_id, const bool rslt_flg, const b GELOGD("Reinit cur dynamic dims when getnext sink dynamic."); cur_dynamic_dims_.clear(); cur_dynamic_dims_.resize(shape_of_cur_dynamic_dims_); - auto ret = rtMemcpy(cur_dynamic_dims_.data(), shape_of_cur_dynamic_dims_ * sizeof(int64_t), + auto ret = rtMemcpy(cur_dynamic_dims_.data(), shape_of_cur_dynamic_dims_ * sizeof(int32_t), netoutput_last_input_addr_, netoutput_last_input_size_, RT_MEMCPY_DEVICE_TO_HOST); GE_CHK_RT_RET(ret); } @@ -2668,11 +2699,11 @@ void *DavinciModel::Run(DavinciModel *model) { GE_IF_BOOL_EXEC(current_data.blobs.empty(), break); auto shape_data_buffer_data = current_data.blobs.back().data; auto shape_data_buffer_length = current_data.blobs.back().length; - model->cur_dynamic_dims_.assign(reinterpret_cast(shape_data_buffer_data), - reinterpret_cast(shape_data_buffer_data) + - shape_data_buffer_length / sizeof(int64_t)); + model->cur_dynamic_dims_.assign(reinterpret_cast(shape_data_buffer_data), + reinterpret_cast(shape_data_buffer_data) + + shape_data_buffer_length / sizeof(int32_t)); GELOGD("Data: cur dynamic dims is %s", formats::JoinToString(model->cur_dynamic_dims_).c_str()); - delete[] reinterpret_cast(current_data.blobs.back().data); + delete[] reinterpret_cast(current_data.blobs.back().data); current_data.blobs.pop_back(); } GE_IF_BOOL_EXEC(ProfilingManager::Instance().ProfilingModelExecuteOn(), model->SetProfileTime(MODEL_PRE_PROC_END)); diff --git a/ge/graph/load/new_model_manager/davinci_model.h b/ge/graph/load/new_model_manager/davinci_model.h index 6b930b05..9ff59d4e 100755 --- a/ge/graph/load/new_model_manager/davinci_model.h +++ b/ge/graph/load/new_model_manager/davinci_model.h @@ -864,11 +864,13 @@ class DavinciModel { void ParseDynamicOutShape(const vector &str_info, vector> &vec_info); bool IsGetNextSinkDynamic(const OpDescPtr &op_desc); + + Status InitRealSizeAndShapeInfo(const ComputeGraphPtr &compute_graph, const NodePtr &node); void GetAllGearsInfo(const NodePtr &node); Status GetGetDynamicDimsNodeInfo(const NodePtr &node); - Status GetGearAndRealOutSizeInfo(size_t input_count, const NodePtr &node); - Status GetRealOutputSizeOfMerge(size_t input_index, const NodePtr &merge_node); - Status GetGearAndRealOutShapeInfo(size_t input_count, const OpDescPtr &op_desc); + Status GetGearAndRealOutSizeInfo(const ComputeGraphPtr &graph, const NodePtr &node); + Status GetRealOutputSizeOfCase(const ComputeGraphPtr &graph, size_t input_index, const NodePtr &case_node); + Status GetGearAndRealOutShapeInfo(const ComputeGraphPtr &graph, const NodePtr &node); bool is_weight_mem_has_inited_; bool is_feature_map_mem_has_inited_; @@ -1021,15 +1023,15 @@ class DavinciModel { bool is_new_model_desc_{false}; bool is_online_infer_dynamic_ = false; bool is_getnext_sink_dynamic_ = false; - vector cur_dynamic_dims_; + vector cur_dynamic_dims_; void *netoutput_last_input_addr_ = nullptr; int64_t netoutput_last_input_size_ = 0; size_t shape_of_cur_dynamic_dims_ = 0; // key: input_index: input is merge node; value: each gear info and each output size - map, int64_t>> merge_nodes_gear_and_real_out_size_info_; + map, int64_t>> merge_nodes_gear_and_real_out_size_info_; // key: input_index: input is merge node; value: each gear info and each output shape - map, vector>> merge_nodes_gear_and_real_out_shape_info_; - vector> all_gears_info_; + map, vector>> merge_nodes_gear_and_real_out_shape_info_; + vector> all_gears_info_; multimap op_id_map_; vector profile_list_; diff --git a/ge/graph/load/new_model_manager/model_manager.cc b/ge/graph/load/new_model_manager/model_manager.cc index 6f923236..b2cce73a 100755 --- a/ge/graph/load/new_model_manager/model_manager.cc +++ b/ge/graph/load/new_model_manager/model_manager.cc @@ -460,8 +460,8 @@ Status ModelManager::DataInput(const InputData &input_data, OutputData &output_d Status ModelManager::GetCurDynamicDims(const vector> &user_real_input_dims, const vector>> &user_input_dims, - vector &cur_dynamic_dims) { - GELOGD(" Start get cur dynamic dims."); + vector &cur_dynamic_dims) { + GELOGD("Start get cur dynamic dims."); if (user_real_input_dims.size() != user_input_dims.size()) { GELOGE(INTERNAL_ERROR, "The input count of user: %zu should be equal to the data count of graph: %zu", @@ -478,7 +478,7 @@ Status ModelManager::GetCurDynamicDims(const vector> &user_real_ } for (size_t j = 0; j < user_input_dims.at(i).second.size(); ++j) { if (user_input_dims.at(i).second.at(j) < 0) { - cur_dynamic_dims.emplace_back(user_real_input_dims[i][j]); + cur_dynamic_dims.emplace_back(static_cast(user_real_input_dims[i][j])); } } } @@ -523,7 +523,7 @@ Status ModelManager::DataInputTensor(uint32_t model_id, const std::vector cur_dynamic_dims; + std::vector cur_dynamic_dims; if (!GetLocalOmgContext().user_real_input_dims.empty()) { if (GetCurDynamicDims(GetLocalOmgContext().user_real_input_dims, GetLocalOmgContext().user_input_dims, cur_dynamic_dims) != SUCCESS) { @@ -531,9 +531,9 @@ Status ModelManager::DataInputTensor(uint32_t model_id, const std::vector(cur_dynamic_dims.size() * sizeof(int64_t)); + uint32_t length = static_cast(cur_dynamic_dims.size() * sizeof(int32_t)); GE_CHK_BOOL_EXEC(memcpy_s(data.data, length, cur_dynamic_dims.data(), length) == EOK, return INTERNAL_ERROR, "Failed to memcpy data."); data.length = length; diff --git a/ge/graph/load/new_model_manager/model_manager.h b/ge/graph/load/new_model_manager/model_manager.h index 088ea5fd..500cad31 100755 --- a/ge/graph/load/new_model_manager/model_manager.h +++ b/ge/graph/load/new_model_manager/model_manager.h @@ -126,14 +126,14 @@ class FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY ModelManager { /// /// @ingroup domi_ome /// @brief Get cur_dynamic_dims for all input. - /// @param [in] vector> &user_real_input_dims: dims info of all user_inputs. + /// @param [in] vector> &user_real_input_dims: dims info of all user_inputs. /// @param [in] vector>> &user_input_dims: key:name. value:dynamic dims from option. - /// @param [out] vector &cur_dynamic_dims: real dims gather, where the index of -1. + /// @param [out] vector &cur_dynamic_dims: real dims gather, where the index of -1. /// @return 0: SUCCESS / others: INTERNAL_ERROR /// Status GetCurDynamicDims(const vector> &user_real_input_dims, const vector>> &user_input_dims, - vector &cur_dynamic_dims); + vector &cur_dynamic_dims); /// /// @ingroup domi_ome diff --git a/ge/graph/load/new_model_manager/task_info/hccl_task_info.cc b/ge/graph/load/new_model_manager/task_info/hccl_task_info.cc index df43fd5b..8033c93e 100644 --- a/ge/graph/load/new_model_manager/task_info/hccl_task_info.cc +++ b/ge/graph/load/new_model_manager/task_info/hccl_task_info.cc @@ -145,7 +145,9 @@ Status HcclTaskInfo::SetFollowStream(const ge::ConstOpDescPtr &op_desc, DavinciM } else { GELOGI("need to reuse follow stream and create new follow stream."); size_t created_stream_num = follow_stream_usage.size(); - hccl_stream_list_ = follow_stream_usage; + for (const auto &stream : follow_stream_usage) { + hccl_stream_list_.emplace_back(stream); + } ret = CreateStream(hccl_stream_num - created_stream_num, davinci_model, main_stream_id); if (ret != SUCCESS) { GELOGE(RT_FAILED, "Create hccl stream failed."); diff --git a/ge/graph/manager/graph_manager.cc b/ge/graph/manager/graph_manager.cc index 6372a018..38de6ff7 100755 --- a/ge/graph/manager/graph_manager.cc +++ b/ge/graph/manager/graph_manager.cc @@ -2780,8 +2780,10 @@ Status GraphManager::ParseInputsDims(const std::vector &input_t if (!GetLocalOmgContext().dynamic_node_type.empty()) { vector data_nodes; vector getnext_nosink_nodes; - data_nodes = compute_graph_->TryGetExtAttr(kExtAttrDataNodes, data_nodes); - getnext_nosink_nodes = compute_graph_->TryGetExtAttr(kExtAttrGetNextNoSink, getnext_nosink_nodes); + data_nodes = GetLocalOmgContext().data_nodes; + getnext_nosink_nodes = GetLocalOmgContext().getnext_nosink_nodes; + GELOGD("Data nodes count is %zu, getnext nosink nodes count is %zu.", data_nodes.size(), + getnext_nosink_nodes.size()); if (GetLocalOmgContext().dynamic_node_type == DATA) { if (getnext_nosink_nodes.empty()) { // just data or data+getnext_sink diff --git a/ge/graph/passes/common_subexpression_elimination_pass.cc b/ge/graph/passes/common_subexpression_elimination_pass.cc index a4662d5d..7d9724fc 100644 --- a/ge/graph/passes/common_subexpression_elimination_pass.cc +++ b/ge/graph/passes/common_subexpression_elimination_pass.cc @@ -26,6 +26,10 @@ namespace ge { namespace { +std::set un_compute_attrs = { + {ATTR_NAME_DATA_DUMP_ORIGIN_OP_NAMES}, +}; + std::string GetCseKey(const NodePtr &node) { std::stringstream ss; ss << node->GetType() << "-data-inputs-"; @@ -49,7 +53,7 @@ std::string GetCseKey(const NodePtr &node) { ss << name << "-"; } - ss << "attrs-" << AttrUtils::GetAllAttrsStr(node->GetOpDesc()); + ss << "attrs-" << AttrUtils::GetAttrsStrAfterRid(node->GetOpDesc(), un_compute_attrs); return ss.str(); } diff --git a/ge/graph/passes/multi_batch_clone_pass.cc b/ge/graph/passes/multi_batch_clone_pass.cc index f8451ace..b7efa070 100755 --- a/ge/graph/passes/multi_batch_clone_pass.cc +++ b/ge/graph/passes/multi_batch_clone_pass.cc @@ -25,31 +25,65 @@ #include "graph/utils/tensor_utils.h" #include "graph/utils/type_utils.h" #include "register/op_registry.h" +#include "graph/common/omg_util.h" namespace ge { namespace { constexpr uint8_t kDataInIndex = 0; constexpr uint8_t kDataOutIndex = 0; constexpr uint8_t kCaseArgIndex = 1; +const int kDivisionConst = 2; +const size_t kNumOfGetnextNode = 1; const std::string kMultiBatchCaseNode = "ascend_mbatch_shape_case"; const std::string kMultiBatchDataNode = "ascend_mbatch_shape_data"; +const std::string kMultiBatchGetDynamicDimsNode = "ascend_mbatch_get_dynamic_dims_node"; const std::string kMultiBatchConstNode = "ascend_mbatch_shape_const"; const std::string kMultiBatchMapIndexNode = "ascend_mbatch_shape_mapindex"; const std::string kMultiBatchNodePostfix = "_ascend_mbatch_batch_"; +const char *const kGetNextName = "IteratorV2"; } // namespace +inline bool IsGetNextType(const NodePtr &node) { + std::string original_type; + GE_IF_BOOL_EXEC(GetOriginalType(node, original_type) != SUCCESS, + GELOGW("Get original type failed."); return false); + return (original_type == kGetNextName); +} + Status MultiBatchClonePass::Run(ComputeGraphPtr graph) { + GE_IF_BOOL_EXEC(graph == nullptr, GELOGE(FAILED, "Original graph is nullptr"); return FAILED); if (graph->GetParentGraph() != nullptr) { GELOGD("Subgraph %s skip the MultiBatchClonePass", graph->GetName().c_str()); return SUCCESS; } - + if (!GetLocalOmgContext().need_multi_batch) { + GELOGI("No need to process_multi for no_train graph."); + return SUCCESS; + } + std::vector data_nodes; + std::vector getnext_nosink_nodes; + std::vector getnext_sink_nodes; + if (multibatch::CheckSequenceOfOptions(graph, data_nodes, getnext_nosink_nodes, getnext_sink_nodes) != SUCCESS) { + GELOGE(PARAM_INVALID, "[Train_Dynamic] CheckSequenceOfOptions failed."); + return PARAM_INVALID; + } + if (multibatch::UpdateNameOfInputShape(graph, data_nodes, getnext_nosink_nodes, getnext_sink_nodes) != SUCCESS) { + GELOGE(PARAM_INVALID, "[Train_Dynamic] UpdateNameForInputShapeOfOption failed."); + return PARAM_INVALID; + } + if (multibatch::DeleteIdentityInsertByAdapter(graph) != SUCCESS) { + GELOGE(PARAM_INVALID, "[Train_Dynamic] DeleteIdentityInsertByAdapter failed."); + return PARAM_INVALID; + } if (!multibatch::InitDynamicParams(batch_shapes_)) { GELOGD("There is no multi-batch options, no need clone multi-batch graph"); return SUCCESS; } - + if (multibatch::CheckNegativeCountOfOptions(batch_shapes_) != SUCCESS) { + GELOGE(PARAM_INVALID, "[Train_Dynamic] Input_shape and dynamic_dims should set correct params."); + return PARAM_INVALID; + } GELOGD("Begin to run Multi-batch clone on graph: %s", graph->GetName().c_str()); GE_CHK_STATUS_RET(multibatch::CheckDynamicParams(batch_shapes_), "Invalid multi-batch param"); if (CollectIoNodes(graph) != SUCCESS) { @@ -66,21 +100,14 @@ Status MultiBatchClonePass::Run(ComputeGraphPtr graph) { (void)AttrUtils::GetStr(graph, ATTR_NAME_SESSION_GRAPH_ID, session_graph_id_); ComputeGraphPtr branch = MakeShared(graph->GetName()); - if (branch == nullptr) { - GELOGE(OUT_OF_MEMORY, "Create multi-batch graph failed"); - return OUT_OF_MEMORY; - } + GE_IF_BOOL_EXEC(branch == nullptr, GELOGE(OUT_OF_MEMORY, "Create multi batch graph failed"); return OUT_OF_MEMORY); (void)AttrUtils::SetStr(branch, ATTR_NAME_SESSION_GRAPH_ID, session_graph_id_); graph->InValid(); // Will modify, need topological again. graph->Swap(*branch); - if (CreateRootGraph(graph) != SUCCESS) { - return FAILED; - } - - if (CreateSubgraphs(graph, branch) != SUCCESS) { - return FAILED; - } + GE_CHK_STATUS_RET(CreateRootGraph(graph), "Construct root graph failed."); + GE_CHK_STATUS_RET(CreateOriGraph(branch), "Construct original graph failed.") + GE_CHK_STATUS_RET(CreateSubgraphs(graph, branch), "Construct subgraph failed."); GE_CHK_STATUS_RET(PruneDirectOutput(graph), "Prune direct output failed"); GELOGD("MultiBatchClonePass Leave"); @@ -95,9 +122,13 @@ Status MultiBatchClonePass::Run(ComputeGraphPtr graph) { /// Status MultiBatchClonePass::CollectIoNodes(const ComputeGraphPtr &graph) { for (const auto &node : graph->GetDirectNode()) { + if (!GetLocalOmgContext().dynamic_node_type.empty() && IsGetNextType(node)) { + all_data_nodes_.emplace_back(node); + GE_CHK_STATUS_RET(InitParamsOfGetNext(node), "Init params of %s failed.", node->GetName().c_str()); + } if (node->GetType() == DATA) { all_data_nodes_.emplace_back(node); - } else if (node->GetType() == CONSTANT) { + } else if (node->GetType() == CONSTANT || node->GetType() == CONSTANTOP) { all_const_nodes_.emplace_back(node); } else if (node->GetType() == NETOUTPUT) { all_output_nodes_.emplace_back(node); @@ -114,10 +145,16 @@ Status MultiBatchClonePass::CollectIoNodes(const ComputeGraphPtr &graph) { } int64_t data_index = 0; + size_t getnext_node_count = 0; for (size_t i = 0; i < all_data_nodes_.size(); ++i) { + if (IsGetNextType(all_data_nodes_[i])) { + // just one getnext node in graph + getnext_node_count++; + continue; + } const auto &op_desc = all_data_nodes_[i]->GetOpDesc(); if (!AttrUtils::GetInt(op_desc, ATTR_NAME_INDEX, data_index)) { - (void)AttrUtils::SetInt(op_desc, ATTR_NAME_INDEX, i); + (void)AttrUtils::SetInt(op_desc, ATTR_NAME_INDEX, i - getnext_node_count); } } @@ -133,7 +170,43 @@ Status MultiBatchClonePass::CollectIoNodes(const ComputeGraphPtr &graph) { "Remove edge failed"); } } + GELOGD("Data count is %zu, const count is %zu, getnext count is %zu, output count is %zu, direct out count is %zu.", + all_data_nodes_.size(), all_const_nodes_.size(), getnext_node_count, all_output_nodes_.size(), + direct_output_.size()); + + return SUCCESS; +} +Status MultiBatchClonePass::InitParamsOfGetNext(const NodePtr &node) { + data_count_from_getnext_ = 0; + getnext_sink_dynamic_dims_ = false; + GE_CHECK_NOTNULL(node->GetOpDesc()); + data_count_from_getnext_ = node->GetOpDesc()->GetOutputsSize(); + if (GetLocalOmgContext().dynamic_node_type == GETNEXT) { + data_count_from_getnext_ = data_count_from_getnext_ / kDivisionConst; + for (size_t i = 0; i < data_count_from_getnext_; ++i) { + GeTensorDesc output_desc = node->GetOpDesc()->GetOutputDesc(i); + GELOGD("The %zu data shape from getnext sink is %s.", i, + formats::JoinToString(output_desc.GetShape().GetDims()).c_str()); + const auto &dims = output_desc.GetShape().GetDims(); + if (std::all_of(dims.begin(), dims.end(), [](int64_t val) {return val >= 0; })) { + GELOGD("The %zu data from %s is static.", i, node->GetName().c_str()); + } else { + getnext_sink_dynamic_dims_ = true; + GELOGD("Dynamic dims in the pattern of getnext sink."); + } + } + } + if (node->GetOutControlAnchor() != nullptr) { + for (const auto &peer_in_control_anchor : node->GetOutControlAnchor()->GetPeerInControlAnchors()) { + NodePtr next_node = peer_in_control_anchor->GetOwnerNode(); + GE_CHECK_NOTNULL(next_node); + if (next_node->GetType() == CONSTANTOP) { + out_control_nodes_.insert(next_node); + GELOGD("Control edge: %s connect with %s.", node->GetName().c_str(), next_node->GetName().c_str()); + } + } + } return SUCCESS; } @@ -144,7 +217,11 @@ Status MultiBatchClonePass::CollectIoNodes(const ComputeGraphPtr &graph) { /// @return 0: SUCCESS / others: FAILED /// Status MultiBatchClonePass::CreateRootGraph(const ComputeGraphPtr &graph) { + GELOGD("Start create root graph of %s.", graph->GetName().c_str()); uint32_t input_num = all_data_nodes_.size() + all_const_nodes_.size(); + if (data_count_from_getnext_ != 0) { + input_num = input_num + data_count_from_getnext_ - kNumOfGetnextNode; + } uint32_t output_num = all_output_nodes_[0]->GetAllInDataAnchorsSize(); OpDescBuilder op_builder(kMultiBatchCaseNode, CASE); @@ -185,6 +262,10 @@ Status MultiBatchClonePass::CreateRootGraph(const ComputeGraphPtr &graph) { op_desc->GetName().c_str()); return FAILED; } + if (!AttrUtils::SetBool(op_desc, ATTR_INSERT_BY_MBATCH, true)) { + GELOGE(INTERNAL_ERROR, "Failed to add insert attr on case node %s", op_desc->GetName().c_str()); + return INTERNAL_ERROR; + } GE_CHK_STATUS_RET(multibatch::StampDynamicType(op_desc), "Set dynamic type failed"); GE_CHK_STATUS_RET(CreateIndexNode(graph), "Create index node failed"); @@ -202,7 +283,7 @@ Status MultiBatchClonePass::CreateRootGraph(const ComputeGraphPtr &graph) { /// @param [in] NodePtr node: index data node. /// @return 0: SUCCESS / others: FAILED /// -Status MultiBatchClonePass::CreateIndexDataNode(const ComputeGraphPtr &graph, NodePtr &node) { +Status MultiBatchClonePass::CreateIndexDataNode(const ComputeGraphPtr &graph, NodePtr &shape_node) { const OpDescPtr data_desc = MakeShared(kMultiBatchDataNode, DATA); if (data_desc == nullptr) { GELOGE(OUT_OF_MEMORY, "Create multi-batch data node failed"); @@ -220,11 +301,12 @@ Status MultiBatchClonePass::CreateIndexDataNode(const ComputeGraphPtr &graph, No } size_t data_index = all_data_nodes_.size(); + data_index = data_count_from_getnext_ != 0 ? data_index - kNumOfGetnextNode : data_index; (void)AttrUtils::SetInt(data_desc, ATTR_NAME_INDEX, data_index); (void)AttrUtils::SetBool(data_desc, ATTR_INSERT_BY_MBATCH, true); - node = graph->AddNode(data_desc); - if (node == nullptr) { + shape_node = graph->AddNode(data_desc); + if (shape_node == nullptr) { GELOGE(OUT_OF_MEMORY, "Create multi-batch data node failed"); return OUT_OF_MEMORY; } @@ -286,15 +368,19 @@ Status MultiBatchClonePass::CreateIndexConstNode(const ComputeGraphPtr &graph, N /// @return 0: SUCCESS / others: FAILED /// Status MultiBatchClonePass::CreateIndexNode(const ComputeGraphPtr &graph) { - // Data --> MapIndex --> Case - NodePtr data_node; - GE_CHK_STATUS_RET(CreateIndexDataNode(graph, data_node), "Create data node failed"); + // Data/GetDynamicDims --> MapIndex --> Case + if (!getnext_sink_dynamic_dims_) { + GE_CHK_STATUS_RET(CreateIndexDataNode(graph, shape_node_), "Create data node failed"); + } else { + GE_CHK_STATUS_RET(CreateGetDynamicDimsNode(graph, shape_node_), "Create get dynamic dims node failed"); + } NodePtr const_node; GE_CHK_STATUS_RET(CreateIndexConstNode(graph, const_node), "Create const node failed"); - + GELOGD("Shape node name is %s, type is %s, const node name is %s.", shape_node_->GetName().c_str(), + shape_node_->GetType().c_str(), const_node->GetName().c_str()); OpDescBuilder op_builder(kMultiBatchMapIndexNode, "MapIndex"); - op_builder.AddInput("x", data_node->GetOpDesc()->GetOutputDesc(0)) + op_builder.AddInput("x", shape_node_->GetOpDesc()->GetOutputDesc(0)) .AddInput("data_seq", const_node->GetOpDesc()->GetOutputDesc(0)) .AddOutput("y", GeTensorDesc(GeShape(), FORMAT_ND, DT_INT32)); @@ -309,8 +395,10 @@ Status MultiBatchClonePass::CreateIndexNode(const ComputeGraphPtr &graph) { return OUT_OF_MEMORY; } - if (GraphUtils::AddEdge(data_node->GetOutDataAnchor(0), index_node->GetInDataAnchor(0)) != GRAPH_SUCCESS) { - GELOGE(FAILED, "Failed to add edge between node:%s to MapIndex:%s", data_node->GetName().c_str(), + GE_CHK_STATUS_RET(AddAttrForGetDynamicDims(shape_node_), "Failed to add attr for %s.", + shape_node_->GetName().c_str()); + if (GraphUtils::AddEdge(shape_node_->GetOutDataAnchor(0), index_node->GetInDataAnchor(0)) != GRAPH_SUCCESS) { + GELOGE(FAILED, "Failed to add edge between node:%s to MapIndex:%s", shape_node_->GetName().c_str(), index_node->GetName().c_str()); return FAILED; } @@ -328,6 +416,120 @@ Status MultiBatchClonePass::CreateIndexNode(const ComputeGraphPtr &graph) { return SUCCESS; } +Status MultiBatchClonePass::CreateGetDynamicDimsNode(const ComputeGraphPtr &graph, NodePtr &shape_node) { + const OpDescPtr data_desc = MakeShared(kMultiBatchGetDynamicDimsNode, GETDYNAMICDIMS); + if (data_desc == nullptr) { + GELOGE(OUT_OF_MEMORY, "Create multi-batch get dynamic dims node failed"); + return OUT_OF_MEMORY; + } + + // input of GetDynamicDims is shape_of_each_data, output is gear_info + for (size_t i = 0; i < GetLocalOmgContext().user_input_dims.size(); ++i) { + size_t input_shape_dims = GetLocalOmgContext().user_input_dims.at(i).second.size(); + // add input desc without GeShape for const input, value of input_shape is 1 transferred by adapter + if (input_shape_dims == 1 && GetLocalOmgContext().user_input_dims.at(i).second.at(0) == 0) { + GeTensorDesc tensor_desc; + tensor_desc.SetFormat(FORMAT_ND); + tensor_desc.SetDataType(DT_INT32); + auto ret = data_desc->AddInputDesc(tensor_desc); + GE_IF_BOOL_EXEC(ret != GRAPH_SUCCESS, GELOGE(INTERNAL_ERROR, "Failed to add input desc for created data"); + return FAILED); + continue; + } + GeTensorDesc tensor_desc(GeShape({static_cast(input_shape_dims)}), FORMAT_ND, DT_INT32); + auto ret = data_desc->AddInputDesc(tensor_desc); + GE_IF_BOOL_EXEC(ret != GRAPH_SUCCESS, GELOGE(INTERNAL_ERROR, "Failed to add input desc for created data"); + return FAILED); + } + GeTensorDesc tensor_desc(GeShape({static_cast(batch_shapes_.at(0).size())}), FORMAT_ND, DT_INT32); + auto ret = data_desc->AddOutputDesc(tensor_desc); + GE_IF_BOOL_EXEC(ret != GRAPH_SUCCESS, GELOGE(INTERNAL_ERROR, "Failed to add output desc for created data"); + return FAILED); + + (void)AttrUtils::SetBool(data_desc, ATTR_INSERT_BY_MBATCH, true); + + shape_node = graph->AddNode(data_desc); + if (shape_node == nullptr) { + GELOGE(OUT_OF_MEMORY, "Create multi-batch dynamic dims node failed"); + return OUT_OF_MEMORY; + } + return SUCCESS; +} + +Status MultiBatchClonePass::AddAttrForGetDynamicDims(const NodePtr &shape_node) { + if (!getnext_sink_dynamic_dims_) { + GELOGD("No need to add attr when not insert get dynamic dims node."); + return SUCCESS; + } + GELOGD("Add attr for :%s, type is %s:", shape_node->GetName().c_str(), shape_node->GetType().c_str()); + if (!AttrUtils::SetInt(shape_node->GetOpDesc(), ATTR_GETNEXT_SINK_DATA_COUNT, data_count_from_getnext_)) { + GELOGE(INTERNAL_ERROR, "set ATTR_GETNEXT_SINK_DATA_COUNT failed"); + return INTERNAL_ERROR; + } + vector shape_info; + for (size_t i = 0; i < GetLocalOmgContext().user_input_dims.size(); ++i) { + if (GetLocalOmgContext().user_input_dims.at(i).second.size() == 1 && + GetLocalOmgContext().user_input_dims.at(i).second.at(0) == 0) { + shape_info.emplace_back(0); + continue; + } + shape_info.emplace_back(GetLocalOmgContext().user_input_dims.at(i).second.size()); + for (size_t j = 0; j < GetLocalOmgContext().user_input_dims.at(i).second.size(); ++j) { + shape_info.emplace_back(GetLocalOmgContext().user_input_dims.at(i).second.at(j)); + } + } + if (!AttrUtils::SetListInt(shape_node->GetOpDesc(), ATTR_GETNEXT_SINK_SHAPE_INFO, shape_info)) { + GELOGE(INTERNAL_ERROR, "set ATTR_GETNEXT_SINK_SHAPE_INFO failed"); + return INTERNAL_ERROR; + } + return SUCCESS; +} + +Status MultiBatchClonePass::LinkGetNextToGetDynamicDims(const NodePtr &getnext_node, const NodePtr &shape_node) { + GELOGD("Start relink shape anchor of %s to %s.", getnext_node->GetName().c_str(), shape_node->GetName().c_str()); + size_t input_index = 0; + size_t data_count = getnext_node->GetAllOutDataAnchors().size() / kDivisionConst; + for (size_t out_index = data_count; out_index < getnext_node->GetAllOutDataAnchors().size(); ++out_index, + ++input_index) { + GELOGD("Start add %s of %zu out_anchor to %s of %zu in_anchor.", getnext_node->GetName().c_str(), out_index, + shape_node->GetName().c_str(), input_index); + auto out_data_anchor = getnext_node->GetOutDataAnchor(out_index); + auto ret = GraphUtils::AddEdge(out_data_anchor, shape_node->GetInDataAnchor(input_index)); + GE_IF_BOOL_EXEC(ret != GRAPH_SUCCESS, GELOGE(INTERNAL_ERROR, "Failed to link getnext %s to getdynamicdims %s", + getnext_node->GetName().c_str(), shape_node->GetName().c_str()); + return INTERNAL_ERROR); + } + return SUCCESS; +} + +Status MultiBatchClonePass::LinkGetDynamicDimsToNetOutput(const NodePtr &output_node) { + if (!GetLocalOmgContext().dynamic_node_type.empty()) { + if (!AttrUtils::SetStr(output_node->GetOpDesc(), ATTR_ALL_GEARS_INFO, GetLocalOmgContext().dynamic_dims)) { + GELOGE(INTERNAL_ERROR, "Failed to set all gears info attr on netoutput %s.", output_node->GetName().c_str()); + return INTERNAL_ERROR; + } + } + if (getnext_sink_dynamic_dims_) { + GELOGD("Start link %s to %s.", shape_node_->GetName().c_str(), output_node->GetName().c_str()); + size_t input_index = output_node->GetAllInDataAnchors().size(); + if (NodeUtils::AppendInputAnchor(output_node, input_index + 1) != GRAPH_SUCCESS) { + GELOGE(INTERNAL_ERROR, "Append input anchor of %s of %zu failed.", output_node->GetName().c_str(), input_index); + return INTERNAL_ERROR; + } + auto ret = GraphUtils::AddEdge(shape_node_->GetOutDataAnchor(kDataOutIndex), + output_node->GetInDataAnchor(input_index)); + GE_IF_BOOL_EXEC(ret != GRAPH_SUCCESS, GELOGE(INTERNAL_ERROR, "Failed to link netoutput %s to getdynamicdims %s", + output_node->GetName().c_str(), shape_node_->GetName().c_str()); + return INTERNAL_ERROR); + if (!AttrUtils::SetBool(output_node->GetOpDesc(), ATTR_GETNEXT_SINK_DYNMAIC, true)) { + GELOGE(INTERNAL_ERROR, "Failed to set getnext sink dynamic attr on netoutput %s.", + output_node->GetName().c_str()); + return INTERNAL_ERROR; + } + } + return SUCCESS; +} + /// /// @ingroup ge /// @brief Create input node for root graph. @@ -337,8 +539,10 @@ Status MultiBatchClonePass::CreateIndexNode(const ComputeGraphPtr &graph) { Status MultiBatchClonePass::CreateInputNode(const ComputeGraphPtr &graph) { // Data --> Case std::vector all_data_nodes; - const size_t arg_index = kCaseArgIndex; - for (size_t i = 0; i < all_data_nodes_.size(); ++i) { + size_t case_input_index = kCaseArgIndex; + NodePtr getnext_node = nullptr; + size_t input_index_of_getnext = 0; + for (size_t i = 0; i < all_data_nodes_.size(); ++i, ++case_input_index) { const auto &node = all_data_nodes_[i]; const OpDescPtr op_desc = AttrUtils::CopyOpDesc(node->GetOpDesc()); if (op_desc == nullptr) { @@ -353,22 +557,60 @@ Status MultiBatchClonePass::CreateInputNode(const ComputeGraphPtr &graph) { op_desc->SetName(node->GetName()); const NodePtr &data = graph->AddNode(op_desc); GE_CHK_BOOL_EXEC(data != nullptr, return FAILED, "Add node[%s] to graph failed", op_desc->GetName().c_str()); - if (GraphUtils::AddEdge(data->GetOutDataAnchor(0), case_node_->GetInDataAnchor(arg_index + i)) != GRAPH_SUCCESS) { - GELOGE(FAILED, "Failed to add edge between Data:%s to Case:%s", - data->GetName().c_str(), case_node_->GetName().c_str()); - return FAILED; + if (IsGetNextType(node)) { + getnext_node = data; + input_index_of_getnext = case_input_index; + case_input_index = case_input_index + data_count_from_getnext_; + continue; + } else { + if (GraphUtils::AddEdge(data->GetOutDataAnchor(0), case_node_->GetInDataAnchor(case_input_index)) != + GRAPH_SUCCESS) { + GELOGE(FAILED, "Failed to add edge between Data:%s to Case:%s", data->GetName().c_str(), + case_node_->GetName().c_str()); + return FAILED; + } } - if (SetMaxShapeToData(data) != SUCCESS) { + if (SetMaxShape(data) != SUCCESS) { + GELOGE(FAILED, "Set max shape of %s failed.", data->GetName().c_str()); return FAILED; } all_data_nodes.emplace_back(data); } + if (getnext_node != nullptr) { + if (LinkEdgeForGetNext(getnext_node, input_index_of_getnext) != SUCCESS) { + GELOGE(FAILED, "Failed to link edge for %s.", getnext_node->GetName().c_str()); + return FAILED; + } + if (SetMaxShape(getnext_node) != SUCCESS) { + GELOGE(FAILED, "Set max shape of %s failed.", getnext_node->GetName().c_str()); + return FAILED; + } + all_data_nodes.emplace_back(getnext_node); + } all_data_nodes_.swap(all_data_nodes); return SUCCESS; } +Status MultiBatchClonePass::LinkEdgeForGetNext(const NodePtr &getnext_node, size_t &case_input_index) { + GELOGD("Start link edge for %s, which is the %zu input of %s.", getnext_node->GetName().c_str(), + case_input_index, case_node_->GetName().c_str()); + for (size_t out_index = 0; out_index < data_count_from_getnext_; ++out_index, ++case_input_index) { + if (GraphUtils::AddEdge(getnext_node->GetOutDataAnchor(out_index), + case_node_->GetInDataAnchor(case_input_index)) != GRAPH_SUCCESS) { + GELOGE(FAILED, "Failed to add data edge between %zu Data:%s to %zu Case:%s", out_index, + getnext_node->GetName().c_str(), case_input_index, case_node_->GetName().c_str()); + return FAILED; + } + } + if (getnext_sink_dynamic_dims_) { + GE_CHK_STATUS_RET(LinkGetNextToGetDynamicDims(getnext_node, shape_node_), "Failed to add link for %s.", + shape_node_->GetName().c_str()); + } + return SUCCESS; +} + /// /// @ingroup ge /// @brief Create Const node for root graph. @@ -378,7 +620,11 @@ Status MultiBatchClonePass::CreateInputNode(const ComputeGraphPtr &graph) { Status MultiBatchClonePass::CreateConstNode(const ComputeGraphPtr &graph) { // Const --> Case std::vector all_const_nodes; - const size_t arg_index = kCaseArgIndex + all_data_nodes_.size(); + size_t arg_index = kCaseArgIndex + all_data_nodes_.size(); + if (data_count_from_getnext_ != 0) { + arg_index = arg_index + data_count_from_getnext_ - kNumOfGetnextNode; + } + for (size_t i = 0; i < all_const_nodes_.size(); ++i) { const auto &node = all_const_nodes_[i]; const OpDescPtr op_desc = AttrUtils::CopyOpDesc(node->GetOpDesc()); @@ -395,15 +641,33 @@ Status MultiBatchClonePass::CreateConstNode(const ComputeGraphPtr &graph) { const NodePtr &data = graph->AddNode(op_desc); GE_CHK_BOOL_EXEC(data != nullptr, return FAILED, "Add node[%s] to graph failed", op_desc->GetName().c_str()); if (GraphUtils::AddEdge(data->GetOutDataAnchor(0), case_node_->GetInDataAnchor(arg_index + i)) != GRAPH_SUCCESS) { - GELOGE(FAILED, "Failed to add edge between Const:%s to Case:%s", - data->GetName().c_str(), case_node_->GetName().c_str()); + GELOGE(FAILED, "Failed to add edge between Const:%s to Case:%s", data->GetName().c_str(), + case_node_->GetName().c_str()); return FAILED; } all_const_nodes.emplace_back(data); } + ChangeConstToData(); + all_const_nodes_.swap(all_const_nodes); + return SUCCESS; +} +void MultiBatchClonePass::ChangeConstToData() { size_t data_index = all_data_nodes_.size(); + if (data_count_from_getnext_ != 0) { + data_index = data_index + data_count_from_getnext_ - kNumOfGetnextNode; + } for (size_t i = 0; i < all_const_nodes_.size(); ++i, ++data_index) { // Trans subgraph Const to Data. + auto &const_node = all_const_nodes_[i]; + bool need_change_type = true; + if (out_control_nodes_.find(const_node) != out_control_nodes_.end()) { + GELOGD("No need to change %s to data type.", const_node->GetName().c_str()); + need_change_type = false; + break; + } + if (!need_change_type) { + continue; + } const OpDescPtr &op_desc = all_const_nodes_[i]->GetOpDesc(); op_desc->SetType(DATA); (void)op_desc->DelAttr(ATTR_NAME_WEIGHTS); // Delete weight. @@ -413,9 +677,6 @@ Status MultiBatchClonePass::CreateConstNode(const ComputeGraphPtr &graph) { (void)AttrUtils::SetInt(op_desc, ATTR_NAME_INDEX, data_index); (void)NodeUtils::AppendInputAnchor(all_const_nodes_[i], 1); } - - all_const_nodes_.swap(all_const_nodes); - return SUCCESS; } /// @@ -461,7 +722,8 @@ Status MultiBatchClonePass::CreateOutputNode(const ComputeGraphPtr &graph) { } } } - + GE_CHK_STATUS_RET(LinkGetDynamicDimsToNetOutput(node), "Failed to add edge between %s to netoutput: %s.", + shape_node_->GetName().c_str(), output->GetName().c_str()); all_output_nodes_.clear(); all_output_nodes_.emplace_back(node); return SUCCESS; @@ -473,34 +735,69 @@ Status MultiBatchClonePass::CreateOutputNode(const ComputeGraphPtr &graph) { /// @param [in] const NodePtr &data: data in Root/Case graph. /// @return 0: SUCCESS / others: FAILED /// -Status MultiBatchClonePass::SetMaxShapeToData(const NodePtr &data) { - auto data_shape = NodeUtils::GetOutputDesc(*data, kDataOutIndex).GetShape(); - auto data_name = data->GetName(); +Status MultiBatchClonePass::SetMaxShape(const NodePtr &data) { + GELOGD("Start set max shape for %s.", data->GetName().c_str()); + if (!IsGetNextType(data)) { + if (SetMaxShapeToData(data, kDataOutIndex) != SUCCESS) { + GELOGE(PARAM_INVALID, "Failed to update max shape of %s.", data->GetName().c_str()); + return PARAM_INVALID; + } + } else { + for (size_t out_anchor_index = 0; out_anchor_index < data_count_from_getnext_; ++out_anchor_index) { + if (SetMaxShapeToData(data, out_anchor_index) != SUCCESS) { + GELOGE(PARAM_INVALID, "Failed to update max shape of %s.", data->GetName().c_str()); + return PARAM_INVALID; + } + } + } + return SUCCESS; +} + +Status MultiBatchClonePass::SetMaxShapeToData(const NodePtr &node, size_t out_anchor_index) { + GELOGD("Start update max shape of %s, %zu output.", node->GetName().c_str(), out_anchor_index); + auto data_shape = NodeUtils::GetOutputDesc(*node, out_anchor_index).GetShape(); + string data_name = node->GetName(); + if (IsGetNextType(node)) { + data_name.append("_").append(std::to_string(out_anchor_index)); + } + GELOGD("Update max shape of %s, shape dims is %s.", data_name.c_str(), + formats::JoinToString(data_shape.GetDims()).c_str()); const auto &dims = data_shape.GetDims(); - if (std::all_of(dims.begin(), dims.end(), [](int64_t val) { return val >= 0; })) { - return SUCCESS; + if (!IsGetNextType(node)) { + if (std::all_of(dims.begin(), dims.end(), [](int64_t val) { return val >= 0; })) { + GELOGD("No need to do anything for static data."); + return SUCCESS; + } + } else { + if (std::all_of(dims.begin(), dims.end(), [](int64_t val) { return val >= 0; })) { + if (getnext_sink_dynamic_dims_) { + // need to update shape of Shape_node when getnext node has dynamic data + GE_CHK_STATUS_RET(UpdateShapeOfShapeNode(node, out_anchor_index), "Failed to update shape of shape node"); + } + return SUCCESS; + } } - (void)AttrUtils::SetListInt(data->GetOpDesc(), ATTR_MBATCH_ORIGIN_INPUT_DIMS, data_shape.GetDims()); + (void)AttrUtils::SetListInt(node->GetOpDesc(), ATTR_MBATCH_ORIGIN_INPUT_DIMS, data_shape.GetDims()); - GeTensorDesc tensor(NodeUtils::GetOutputDesc(*data, kDataOutIndex)); + GeTensorDesc tensor(NodeUtils::GetOutputDesc(*node, kDataOutIndex)); std::vector input_dims_str; for (size_t i = 0; i < batch_shapes_.size(); ++i) { auto shape = data_shape; auto ret = multibatch::CalcShape(data_to_dynamic_info_.at(data_name).at(i), shape); if (ret != SUCCESS) { - GELOGE(ret, "Failed to calculate the shape for data node %s, the shape may not match", data->GetName().c_str()); + GELOGE(ret, "Failed to calculate the shape for data node %s, the shape may not match", node->GetName().c_str()); return ret; } tensor.SetShape(shape); int64_t tensor_size = 0; (void)TensorUtils::GetTensorSizeInBytes(tensor, tensor_size); string input_str = TypeUtils::FormatToSerialString(tensor.GetFormat()) + ":" + - TypeUtils::DataTypeToSerialString(tensor.GetDataType()) + ":" + data->GetName() + ":" + + TypeUtils::DataTypeToSerialString(tensor.GetDataType()) + ":" + node->GetName() + ":" + std::to_string(tensor_size) + ":" + std::to_string(tensor.GetShape().GetDimNum()) + ":" + formats::JoinToString(tensor.GetShape().GetDims()); input_dims_str.emplace_back(input_str); } - (void)AttrUtils::SetListStr(data->GetOpDesc(), "_all_origin_gears_inputs", input_dims_str); + (void)AttrUtils::SetListStr(node->GetOpDesc(), "_all_origin_gears_inputs", input_dims_str); size_t max_shape_index = 0; int64_t max_size = 0; @@ -519,18 +816,72 @@ Status MultiBatchClonePass::SetMaxShapeToData(const NodePtr &data) { max_shape_index = i; } } + return SetShapeToData(data_to_dynamic_info_.at(data_name).at(max_shape_index), node, data_shape, out_anchor_index); +} - return SetShapeToData(data_to_dynamic_info_.at(data_name).at(max_shape_index), data, data_shape); +/// +/// @ingroup ge +/// @brief Set max shape to Data/GetNext node in root graph. +/// @param [in] const std::vector &shapes: dims of shape. +/// @param [in] const NodePtr &data: data in Root/Case graph. +/// @param [in] GeShape &data_shape: dims of data node. +/// @param [in] size_t out_anchor_index: out anchor index of data node. +/// @return 0: SUCCESS / others: FAILED +/// +Status MultiBatchClonePass::SetShapeToData(const std::vector &shapes, const NodePtr &data, GeShape &data_shape, + size_t out_anchor_index) { + GELOGD("Start set shape to %zu out of %s.", out_anchor_index, data->GetName().c_str()); + if (multibatch::CalcShape(shapes, data_shape) != SUCCESS) { + GELOGE(INTERNAL_ERROR, "Failed to calculate the batched shape for data node %s, the shapes may not match", + data->GetName().c_str()); + return INTERNAL_ERROR; + } + + if (NodeUtils::UpdateOutputShape(*data, out_anchor_index, data_shape) != GRAPH_SUCCESS) { + GELOGE(INTERNAL_ERROR, "Failed to update output shape for data %s", data->GetName().c_str()); + return INTERNAL_ERROR; + } + if (!IsGetNextType(data)) { + if (NodeUtils::UpdateInputShape(*data, kDataInIndex, data_shape) != GRAPH_SUCCESS) { + GELOGE(INTERNAL_ERROR, "Failed to update input shape for data %s", data->GetName().c_str()); + return INTERNAL_ERROR; + } + } else { + if (getnext_sink_dynamic_dims_) { + // need to update shape of Shape_node when getnext_sink_dynamic + GE_CHK_STATUS_RET(UpdateShapeOfShapeNode(data, out_anchor_index), "Failed to update shape of shape node"); + } + } + + GELOGI("Update the data %s input/output shape to the max %s", data->GetName().c_str(), + formats::ShapeToString(data_shape).c_str()); + return SUCCESS; +} + +Status MultiBatchClonePass::UpdateShapeOfShapeNode(const NodePtr &node, size_t out_anchor_index) { + GELOGD("Start update output shape of shape node insert by adapter, which is the %zu out of %s.", out_anchor_index, + node->GetName().c_str()); + auto data_shape = NodeUtils::GetOutputDesc(*node, out_anchor_index).GetShape(); + size_t shape_index = out_anchor_index + (node->GetAllOutDataAnchors().size() / kDivisionConst); + GeTensorDesc output_desc = node->GetOpDesc()->GetOutputDesc(shape_index); + std::vector output_dims = {static_cast(data_shape.GetDims().size())}; + GeShape output_shape(output_dims); + output_desc.SetShape(output_shape); + if (node->GetOpDesc()->UpdateOutputDesc(shape_index, output_desc) != SUCCESS) { + GELOGE(FAILED, "Update output desc fail."); + return FAILED; + } + return SUCCESS; } /// /// @ingroup ge /// @brief Update Data node in Subgraph. /// @param [in] const NodePtr &data: data in Subgraph. -/// @param [in] size_t index: The batch index. +/// @param [in] size_t batch_index: The batch index. /// @return 0: SUCCESS / others: FAILED /// -Status MultiBatchClonePass::UpdateSubgraphData(const NodePtr &data, size_t index) { +Status MultiBatchClonePass::UpdateSubgraphData(const NodePtr &data, size_t batch_index) { int node_index = -1; if (!AttrUtils::GetInt(data->GetOpDesc(), ATTR_NAME_INDEX, node_index)) { GELOGE(FAILED, "Failed to get index from data[%s]", data->GetName().c_str()); @@ -545,6 +896,8 @@ Status MultiBatchClonePass::UpdateSubgraphData(const NodePtr &data, size_t index auto data_shape = NodeUtils::GetOutputDesc(*data, kDataOutIndex).GetShape(); const auto &dims = data_shape.GetDims(); + GELOGD("Start update shape of %s , batch index is %zu, dims is %s.", data->GetName().c_str(), batch_index, + formats::JoinToString(dims).c_str()); if (std::all_of(dims.begin(), dims.end(), [](int64_t val) { return val >= 0; })) { return SUCCESS; } @@ -559,35 +912,77 @@ Status MultiBatchClonePass::UpdateSubgraphData(const NodePtr &data, size_t index } auto parent_name = data_name.substr(0, pos); - return SetShapeToData(data_to_dynamic_info_.at(parent_name).at(index), data, data_shape); + return SetShapeToData(data_to_dynamic_info_.at(parent_name).at(batch_index), data, data_shape, kDataOutIndex); } -/// -/// @ingroup ge -/// @brief Set max shape to Data node in root graph. -/// @param [in] const std::vector &shapes: dims of shape. -/// @param [in] const NodePtr &data: data in Root/Case graph. -/// @param [in] GeShape &data_shape: dims of data node. -/// @return 0: SUCCESS / others: FAILED -/// -Status MultiBatchClonePass::SetShapeToData(const vector &shapes, const NodePtr &data, GeShape &data_shape) { - // must not be error, the calc result has been checked in function InsertSwitchNForData - if (multibatch::CalcShape(shapes, data_shape) != SUCCESS) { - return INTERNAL_ERROR; +Status MultiBatchClonePass::CreateOriGraph(const ComputeGraphPtr &graph) { + if (data_count_from_getnext_ == 0) { + GELOGD("No need to change original graph without getnext node."); + return SUCCESS; } - - if (NodeUtils::UpdateInputShape(*data, kDataInIndex, data_shape) != GRAPH_SUCCESS) { - GELOGE(INTERNAL_ERROR, "Failed to update input shape for data %s", data->GetName().c_str()); - return INTERNAL_ERROR; + GELOGD("Start change original graph: %s when exit getnext node.", graph->GetName().c_str()); + size_t data_index = all_data_nodes_.size() - kNumOfGetnextNode; + for (const auto &node : graph->GetDirectNode()) { + if (IsGetNextType(node)) { + for (size_t out_index = 0; out_index < data_count_from_getnext_; ++out_index, ++data_index) { + auto out_data_anchor = node->GetOutDataAnchor(out_index); + GE_IF_BOOL_EXEC(out_data_anchor == nullptr, continue); + NodePtr data_node = CreateDataNode(graph, out_data_anchor, data_index); + GE_IF_BOOL_EXEC(data_node == nullptr, GELOGE(INTERNAL_ERROR, "Create %zu data node failed.", + out_data_anchor->GetIdx()); return INTERNAL_ERROR); + for (auto &in_anchor : out_data_anchor->GetPeerInDataAnchors()) { + GE_IF_BOOL_EXEC(in_anchor == nullptr, continue); + NodePtr dst_node = in_anchor->GetOwnerNode(); + if (GraphUtils::RemoveEdge(out_data_anchor, in_anchor) != GRAPH_SUCCESS) { + GELOGE(INTERNAL_ERROR, "Failed to remove edge between %s to %s", node->GetName().c_str(), + dst_node->GetName().c_str()); + return INTERNAL_ERROR; + } + if (GraphUtils::AddEdge(data_node->GetOutDataAnchor(0), dst_node->GetInDataAnchor(in_anchor->GetIdx())) != + GRAPH_SUCCESS) { + GELOGE(INTERNAL_ERROR, "Failed to add edge between %s to %s", data_node->GetName().c_str(), + dst_node->GetName().c_str()); + return INTERNAL_ERROR; + } + } + } + if (graph->RemoveNode(node) != GRAPH_SUCCESS) { + GELOGE(GRAPH_FAILED, "Remove node %s failed!", node->GetName().c_str()); + return GRAPH_FAILED; + } + break; + } } + return SUCCESS; +} - if (NodeUtils::UpdateOutputShape(*data, kDataOutIndex, data_shape) != GRAPH_SUCCESS) { - GELOGE(INTERNAL_ERROR, "Failed to update output shape for data %s", data->GetName().c_str()); - return INTERNAL_ERROR; +NodePtr MultiBatchClonePass::CreateDataNode(const ComputeGraphPtr &graph, const OutDataAnchorPtr &out_data_anchor, + size_t data_index) { + size_t out_anchor_index = out_data_anchor->GetIdx(); + std::string node_name = out_data_anchor->GetOwnerNode()->GetName() + "_" + std::to_string(out_anchor_index); + OpDescPtr op_desc = MakeShared(node_name, DATA); + if (op_desc == nullptr) { + GELOGE(OUT_OF_MEMORY, "Create data node failed."); + return nullptr; } + (void)AttrUtils::SetInt(op_desc, ATTR_NAME_INDEX, data_index); - GELOGI("Update %s input/output shape to %s", data->GetName().c_str(), formats::ShapeToString(data_shape).c_str()); - return SUCCESS; + OpDescPtr getnext_op_desc = out_data_anchor->GetOwnerNode()->GetOpDesc(); + if (getnext_op_desc == nullptr) { + GELOGE(OUT_OF_MEMORY, "Op desc of %s is nullptr.", out_data_anchor->GetOwnerNode()->GetName().c_str()); + return nullptr; + } + if (op_desc->AddInputDesc(getnext_op_desc->GetOutputDesc(out_anchor_index)) != GRAPH_SUCCESS) { + GELOGE(INTERNAL_ERROR, "Add %s input desc failed.", op_desc->GetName().c_str()); + return nullptr; + } + if (op_desc->AddOutputDesc(getnext_op_desc->GetOutputDesc(out_anchor_index)) != GRAPH_SUCCESS) { + GELOGE(INTERNAL_ERROR, "Add %s output desc failed.", op_desc->GetName().c_str()); + return nullptr; + } + NodePtr data_node = graph->AddNode(op_desc); + GELOGD("Success create %s node.", data_node->GetName().c_str()); + return data_node; } /// @@ -598,17 +993,14 @@ Status MultiBatchClonePass::SetShapeToData(const vector &shapes, const /// @return 0: SUCCESS / others: FAILED /// Status MultiBatchClonePass::CreateSubgraphs(const ComputeGraphPtr &graph, const ComputeGraphPtr &branch) { + GELOGD("Start create subgraphs for %s.", graph->GetName().c_str()); const auto &op_desc = case_node_->GetOpDesc(); for (size_t i = 0; i < batch_shapes_.size(); ++i) { std::vector input_nodes; std::vector output_nodes; const std::string postfix = kMultiBatchNodePostfix + std::to_string(i); ComputeGraphPtr subgraph = (i == 0) ? branch : GraphUtils::CloneGraph(branch, postfix, input_nodes, output_nodes); - if (subgraph == nullptr) { - GELOGE(FAILED, "Create multi-batch case node failed"); - return FAILED; - } - + GE_IF_BOOL_EXEC(subgraph == nullptr, GELOGE(FAILED, "Create multi-batch case node failed"); return FAILED); subgraph->SetName("Batch_" + std::to_string(i)); subgraph->SetParentNode(case_node_); subgraph->SetParentGraph(graph); @@ -621,6 +1013,7 @@ Status MultiBatchClonePass::CreateSubgraphs(const ComputeGraphPtr &graph, const op_desc->AddSubgraphName(key_name); op_desc->SetSubgraphInstanceName(i, subgraph->GetName()); + GELOGD("The %s has %zu input, %zu output.", subgraph->GetName().c_str(), input_nodes.size(), output_nodes.size()); for (const auto &data : input_nodes) { GE_CHK_STATUS_RET(UpdateSubgraphData(data, i), "Update %s failed", subgraph->GetName().c_str()); } @@ -666,6 +1059,7 @@ Status MultiBatchClonePass::UpdateSubgraphOutput(const NodePtr &output_node) { /// @return 0: SUCCESS / others: FAILED /// Status MultiBatchClonePass::PruneDirectOutput(const ComputeGraphPtr &graph) { + GELOGD("Start prune direct output."); const auto &func_desc = case_node_->GetOpDesc(); uint32_t unused_num = 0; uint32_t output_num = func_desc->GetOutputsSize(); @@ -710,6 +1104,7 @@ Status MultiBatchClonePass::PruneDirectOutput(const ComputeGraphPtr &graph) { /// Status MultiBatchClonePass::UpdateOutputTensor(uint32_t parent_index, uint32_t unused_num) { if (unused_num == 0) { + GELOGD("No need to update output tensor."); return SUCCESS; } diff --git a/ge/graph/passes/multi_batch_clone_pass.h b/ge/graph/passes/multi_batch_clone_pass.h index ee137b5a..66e92892 100755 --- a/ge/graph/passes/multi_batch_clone_pass.h +++ b/ge/graph/passes/multi_batch_clone_pass.h @@ -36,6 +36,7 @@ class MultiBatchClonePass : public GraphPass { /// @return 0: SUCCESS / others: FAILED /// Status CollectIoNodes(const ComputeGraphPtr &graph); + Status InitParamsOfGetNext(const NodePtr &node); /// /// @ingroup ge @@ -49,10 +50,12 @@ class MultiBatchClonePass : public GraphPass { /// @ingroup ge /// @brief Create index data node for root graph. /// @param [in] const ComputeGraphPtr &graph: Root/Case graph. - /// @param [in] NodePtr node: index data node. + /// @param [in] NodePtr shape_node: index data node, DATA or GETDYNAMICDIMS type. /// @return 0: SUCCESS / others: FAILED /// - Status CreateIndexDataNode(const ComputeGraphPtr &graph, NodePtr &node); + Status CreateIndexDataNode(const ComputeGraphPtr &graph, NodePtr &shape_node); + + Status CreateGetDynamicDimsNode(const ComputeGraphPtr &graph, NodePtr &shape_node); /// /// @ingroup ge @@ -70,6 +73,9 @@ class MultiBatchClonePass : public GraphPass { /// @return 0: SUCCESS / others: FAILED /// Status CreateIndexNode(const ComputeGraphPtr &graph); + Status AddAttrForGetDynamicDims(const NodePtr &shape_node); + Status LinkGetNextToGetDynamicDims(const NodePtr &getnext_node, const NodePtr &shape_node); + Status LinkGetDynamicDimsToNetOutput(const NodePtr &output_node); /// /// @ingroup ge @@ -78,39 +84,54 @@ class MultiBatchClonePass : public GraphPass { /// @return 0: SUCCESS / others: FAILED /// Status CreateInputNode(const ComputeGraphPtr &graph); + Status LinkEdgeForGetNext(const NodePtr &getnext_node, size_t &case_input_index); /// /// @ingroup ge - /// @brief Create Const node for root graph. - /// @param [in] const ComputeGraphPtr &graph: Root/Case graph. + /// @brief Set max shape to Data node in root graph. + /// @param [in] const NodePtr &data: data in Root/Case graph. /// @return 0: SUCCESS / others: FAILED /// - Status CreateConstNode(const ComputeGraphPtr &graph); + Status SetMaxShape(const NodePtr &data); + Status SetMaxShapeToData(const NodePtr &node, size_t out_anchor_index); + /// + /// @ingroup ge + /// @brief Set max shape to Data/GetNext node in root graph. + /// @param [in] const std::vector &shapes: dims of shape. + /// @param [in] const NodePtr &data: data in Root/Case graph. + /// @param [in] GeShape &data_shape: dims of data node. + /// @param [in] size_t out_anchor_index: out anchor index of data node. + /// @return 0: SUCCESS / others: FAILED + /// + Status SetShapeToData(const std::vector &shapes, const NodePtr &data, GeShape &data_shape, + size_t out_anchor_index); + Status UpdateShapeOfShapeNode(const NodePtr &node, size_t out_anchor_index); /// /// @ingroup ge - /// @brief Create output node for root graph. + /// @brief Create Const node for root graph. /// @param [in] const ComputeGraphPtr &graph: Root/Case graph. /// @return 0: SUCCESS / others: FAILED /// - Status CreateOutputNode(const ComputeGraphPtr &graph); + Status CreateConstNode(const ComputeGraphPtr &graph); + void ChangeConstToData(); /// /// @ingroup ge - /// @brief Set max shape to Data node in root graph. - /// @param [in] const NodePtr &data: data in Root/Case graph. + /// @brief Create output node for root graph. + /// @param [in] const ComputeGraphPtr &graph: Root/Case graph. /// @return 0: SUCCESS / others: FAILED /// - Status SetMaxShapeToData(const NodePtr &data); + Status CreateOutputNode(const ComputeGraphPtr &graph); /// /// @ingroup ge /// @brief Update Data node in Subgraph. /// @param [in] const NodePtr &data: data in Subgraph. - /// @param [in] size_t index: The batch index. + /// @param [in] size_t batch_index: The batch index. /// @return 0: SUCCESS / others: FAILED /// - Status UpdateSubgraphData(const NodePtr &data, size_t index); + Status UpdateSubgraphData(const NodePtr &data, size_t batch_index); /// /// @ingroup ge @@ -122,13 +143,12 @@ class MultiBatchClonePass : public GraphPass { /// /// @ingroup ge - /// @brief Set max shape to Data node in root graph. - /// @param [in] const std::vector &shapes: dims of shape. - /// @param [in] const NodePtr &data: data in Root/Case graph. - /// @param [in] GeShape &data_shape: dims of data node. + /// @brief Create nodes for root graph. + /// @param [in] const ComputeGraphPtr &graph: Original graph. /// @return 0: SUCCESS / others: FAILED /// - Status SetShapeToData(const std::vector &shapes, const NodePtr &data, GeShape &data_shape); + Status CreateOriGraph(const ComputeGraphPtr &graph); + NodePtr CreateDataNode(const ComputeGraphPtr &graph, const OutDataAnchorPtr &out_data_anchor, size_t data_index); /// /// @ingroup ge @@ -168,6 +188,10 @@ class MultiBatchClonePass : public GraphPass { std::map>> data_to_dynamic_info_; NodePtr case_node_; + size_t data_count_from_getnext_ = 0; + bool getnext_sink_dynamic_dims_ = false; + NodePtr shape_node_; + std::set out_control_nodes_; }; } // namespace ge #endif // GE_GRAPH_PASSES_MULTI_BATCH_CLONE_PASS_H_ diff --git a/ge/graph/passes/unused_args_clean_pass.cc b/ge/graph/passes/unused_args_clean_pass.cc index 83fd0438..ec66b129 100755 --- a/ge/graph/passes/unused_args_clean_pass.cc +++ b/ge/graph/passes/unused_args_clean_pass.cc @@ -204,6 +204,10 @@ Status UnusedArgsCleanPass::RemoveInputTensor(const mapGetName().c_str(), func_node->GetName().c_str()); + if (out_node->GetInDataNodes().size() == 0 && out_node->GetOutAllNodes().size() == 0) { + GE_CHK_GRAPH_STATUS_RET(out_node->GetOwnerComputeGraph()->RemoveNode(out_node), "Remove node failed: %s", + out_node->GetName().c_str()); + } return SUCCESS; } } // namespace ge \ No newline at end of file diff --git a/ge/graph/preprocess/multi_batch_copy_graph.cc b/ge/graph/preprocess/multi_batch_copy_graph.cc index c8880b2e..5506435e 100644 --- a/ge/graph/preprocess/multi_batch_copy_graph.cc +++ b/ge/graph/preprocess/multi_batch_copy_graph.cc @@ -1692,13 +1692,11 @@ Status MultiBatchGraphCopyer::LinkToNodeOutBranch(const NodePtr &node) { } Status ProcessMultiBatch(ComputeGraphPtr &graph) { - if (GetLocalOmgContext().dynamic_node_type.empty()) { - const char *multi_batch_with_switchn = std::getenv("MULTI_BATCH_WITH_SWITCHN"); - if (multi_batch_with_switchn == nullptr) { - PassManager pass_manager; - GE_CHK_STATUS_RET(pass_manager.AddPass("MultiBatchClonePass", new (std::nothrow) MultiBatchClonePass)); - return pass_manager.Run(graph); - } + const char *multi_batch_with_switchn = std::getenv("MULTI_BATCH_WITH_SWITCHN"); + if (multi_batch_with_switchn == nullptr) { + PassManager pass_manager; + GE_CHK_STATUS_RET(pass_manager.AddPass("MultiBatchClonePass", new (std::nothrow) MultiBatchClonePass)); + return pass_manager.Run(graph); } if (!GetLocalOmgContext().need_multi_batch) { GELOGI("No need to process_multi for no_train graph."); diff --git a/ge/graph/preprocess/multi_batch_options.cc b/ge/graph/preprocess/multi_batch_options.cc index c26b08bc..aba2b88d 100644 --- a/ge/graph/preprocess/multi_batch_options.cc +++ b/ge/graph/preprocess/multi_batch_options.cc @@ -99,9 +99,8 @@ Status DistinguishGetNextAndData(ComputeGraphPtr &graph, vector &data_n } GELOGI("Data count is %zu, getnext nosink count is %zu, getnext sink count is %zu.", data_nodes.size(), getnext_nosink_nodes.size(), getnext_sink_nodes.size()); - GE_IF_BOOL_EXEC(!graph->SetExtAttr(kExtAttrDataNodes, data_nodes), GELOGW("Set data nodes attr failed.");) - GE_IF_BOOL_EXEC(!graph->SetExtAttr(kExtAttrGetNextNoSink, getnext_nosink_nodes), - GELOGW("Set getnext nosink nodes attr failed.");) + GetLocalOmgContext().data_nodes = data_nodes; + GetLocalOmgContext().getnext_nosink_nodes = getnext_nosink_nodes; return SUCCESS; } diff --git a/inc/framework/omg/omg_inner_types.h b/inc/framework/omg/omg_inner_types.h index dab79053..1049b6b5 100644 --- a/inc/framework/omg/omg_inner_types.h +++ b/inc/framework/omg/omg_inner_types.h @@ -26,6 +26,7 @@ #include #include "framework/common/fmk_error_codes.h" #include "register/register_fmk_types.h" +#include "graph/node.h" using domi::DOMI_TENSOR_ND; using domi::DOMI_TENSOR_RESERVED; @@ -120,6 +121,8 @@ struct OmgContext { std::vector> user_real_input_dims; std::vector cur_dynamic_dims; bool need_multi_batch = false; + std::vector data_nodes; + std::vector getnext_nosink_nodes; }; } // namespace ge diff --git a/metadef b/metadef index 44bcbb5e..fe37bc34 160000 --- a/metadef +++ b/metadef @@ -1 +1 @@ -Subproject commit 44bcbb5ea25ada1a5393aa4c7f554d40b6859b18 +Subproject commit fe37bc343ea52c76d35e9e9ec83cea0151bfa900 diff --git a/parser b/parser index 5b93b050..336cd310 160000 --- a/parser +++ b/parser @@ -1 +1 @@ -Subproject commit 5b93b050dd7ca5b77c3001a790031d877fa10956 +Subproject commit 336cd3107253d3fe41cfb9fec2db62b5f3d8a33b diff --git a/tests/ut/ge/CMakeLists.txt b/tests/ut/ge/CMakeLists.txt index dcf389c0..db725dfb 100755 --- a/tests/ut/ge/CMakeLists.txt +++ b/tests/ut/ge/CMakeLists.txt @@ -627,6 +627,7 @@ set(PASS_TEST_FILES "graph/passes/net_output_pass_unittest.cc" "graph/passes/no_use_reshape_remove_pass_unittest.cc" "graph/passes/infershape_pass_unittest.cc" + "graph/passes/multi_batch_clone_pass_unittest.cc" ) set(KERNEL_TEST_FILES diff --git a/tests/ut/ge/graph/load/davinci_model_unittest.cc b/tests/ut/ge/graph/load/davinci_model_unittest.cc index a9efab3d..9e51585b 100644 --- a/tests/ut/ge/graph/load/davinci_model_unittest.cc +++ b/tests/ut/ge/graph/load/davinci_model_unittest.cc @@ -32,6 +32,18 @@ class UtestDavinciModel : public testing::Test { void SetUp() {} void TearDown() {} + public: + NodePtr MakeNode(const ComputeGraphPtr &graph, uint32_t in_num, uint32_t out_num, string name, string type) { + GeTensorDesc test_desc(GeShape(), FORMAT_NCHW, DT_FLOAT); + auto op_desc = std::make_shared(name, type); + for (auto i = 0; i < in_num; ++i) { + op_desc->AddInputDesc(test_desc); + } + for (auto i = 0; i < out_num; ++i) { + op_desc->AddOutputDesc(test_desc); + } + return graph->AddNode(op_desc); + } }; TEST_F(UtestDavinciModel, init_success) { @@ -324,5 +336,94 @@ TEST_F(UtestDavinciModel, SyncVarData_test) { EXPECT_NE(model.SyncVarData(), SUCCESS); } +TEST_F(UtestDavinciModel, InitRealSizeAndShapeInfo_succ1) { + DavinciModel model(0, nullptr); + model.ge_model_ = make_shared(); + ComputeGraphPtr graph = make_shared("default"); + + GeTensorDesc tensor(GeShape(), FORMAT_NCHW, DT_FLOAT); + OpDescPtr op_output = CreateOpDesc("output_ascend_mbatch_batch_1", NETOUTPUT); + op_output->AddInputDesc(tensor); + op_output->SetInputOffset({1024}); + NodePtr node_output = graph->AddNode(op_output); + EXPECT_EQ(model.InitRealSizeAndShapeInfo(graph, node_output), SUCCESS); +} + +TEST_F(UtestDavinciModel, InitRealSizeAndShapeInfo_succ2) { + DavinciModel model(0, nullptr); + ComputeGraphPtr graph = std::make_shared("test_graph"); + + OpDescPtr data1 = CreateOpDesc("data1", DATA); + GeTensorDesc shape_desc(GeShape({4,3,224,224}), FORMAT_NCHW, DT_FLOAT); + data1->AddInputDesc(shape_desc); + data1->AddOutputDesc(shape_desc); + NodePtr data1_node = graph->AddNode(data1); + + OpDescPtr case_node = CreateOpDesc("case1", CASE); + GeTensorDesc tensor(GeShape(), FORMAT_NCHW, DT_FLOAT); + case_node->AddInputDesc(tensor); + case_node->AddOutputDesc(tensor); + NodePtr case1_node = graph->AddNode(case_node); + + OpDescPtr output = CreateOpDesc("output1", NETOUTPUT); + output->AddInputDesc(tensor); + output->SetSrcName( { "case1" } ); + output->SetSrcIndex( { 0 } ); + NodePtr output_node = graph->AddNode(output); + + GraphUtils::AddEdge(data1_node->GetOutDataAnchor(0), case1_node->GetInDataAnchor(0)); + GraphUtils::AddEdge(case1_node->GetOutDataAnchor(0), output_node->GetInDataAnchor(0)); + + (void)AttrUtils::SetStr(output_node->GetOpDesc(), ATTR_ALL_GEARS_INFO, "1;2;4;8"); + (void)AttrUtils::SetBool(case_node, ATTR_INSERT_BY_MBATCH, true); + + model.is_getnext_sink_dynamic_ = false; + model.is_online_infer_dynamic_ = true; + auto ret = model.InitRealSizeAndShapeInfo(graph, output_node); + // GetGearAndRealOutShapeInfo without ATTR_NAME_DYNAMIC_OUTPUT_DIMS + EXPECT_EQ(ret, SUCCESS); + vector dynamic_output_dims = {"0,0,1,1,0,2,2,0,4,3,0,8"}; + (void)AttrUtils::SetListStr(output_node->GetOpDesc(), ATTR_NAME_DYNAMIC_OUTPUT_DIMS, dynamic_output_dims); + ret = model.InitRealSizeAndShapeInfo(graph, output_node); + EXPECT_EQ(ret, SUCCESS); +} + +TEST_F(UtestDavinciModel, InitRealSizeAndShapeInfo_succ3) { + DavinciModel model(0, nullptr); + ComputeGraphPtr graph = std::make_shared("test_graph"); + + OpDescPtr data1 = CreateOpDesc("data1", DATA); + GeTensorDesc shape_desc(GeShape({4,3,224,224}), FORMAT_NCHW, DT_FLOAT); + data1->AddInputDesc(shape_desc); + data1->AddOutputDesc(shape_desc); + NodePtr data1_node = graph->AddNode(data1); + + OpDescPtr shape_node = CreateOpDesc("ascend_mbatch_get_dynamic_dims_node", GETDYNAMICDIMS); + GeTensorDesc in_tensor(GeShape(), FORMAT_NCHW, DT_FLOAT); + GeTensorDesc out_tensor(GeShape({4,3}), FORMAT_NCHW, DT_FLOAT); + shape_node->AddInputDesc(in_tensor); + shape_node->AddOutputDesc(out_tensor); + NodePtr get_dynamic_dims_node = graph->AddNode(shape_node); + + OpDescPtr output = CreateOpDesc("output1", NETOUTPUT); + GeTensorDesc tensor(GeShape(), FORMAT_NCHW, DT_FLOAT); + output->AddInputDesc(tensor); + output->SetSrcName( { "data1", "ascend_mbatch_get_dynamic_dims_node" } ); + output->SetSrcIndex( { 0, 1 } ); + NodePtr output_node = graph->AddNode(output); + GraphUtils::AddEdge(data1_node->GetOutDataAnchor(0), output_node->GetInDataAnchor(0)); + GraphUtils::AddEdge(get_dynamic_dims_node->GetOutDataAnchor(0), output_node->GetInDataAnchor(1)); + + (void)AttrUtils::SetStr(output_node->GetOpDesc(), ATTR_ALL_GEARS_INFO, "1,3;;4,3;,3"); + + model.is_getnext_sink_dynamic_ = true; + model.is_online_infer_dynamic_ = false; + auto ret = model.InitRealSizeAndShapeInfo(graph, output_node); + EXPECT_EQ(ret, SUCCESS); + model.runtime_param_.mem_base = (uint8_t *)0x08000000; + model.runtime_param_.mem_size = 4; + ret = model.InitRealSizeAndShapeInfo(graph, output_node); + EXPECT_EQ(ret, SUCCESS); +} } // namespace ge diff --git a/tests/ut/ge/graph/passes/multi_batch_clone_pass_unittest.cc b/tests/ut/ge/graph/passes/multi_batch_clone_pass_unittest.cc new file mode 100644 index 00000000..b1cd6d4d --- /dev/null +++ b/tests/ut/ge/graph/passes/multi_batch_clone_pass_unittest.cc @@ -0,0 +1,247 @@ +/** + * Copyright 2021 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "graph/passes/multi_batch_clone_pass.h" + +#include +#include +#include + +#include "inc/pass_manager.h" +#include "graph/utils/tensor_utils.h" +#include "graph/common/local_context.h" +#include "graph/passes/multi_batch_pass.h" +#include "graph/preprocess/multi_batch_copy_graph.h" +#include "graph/preprocess/insert_op/util_insert_aipp_op.h" +#include "framework/omg/omg_inner_types.h" +#include "register/op_registry.h" + + +namespace ge{ +class UtestMultiBatchClonePass : public testing::Test { +protected: + void SetUp() { + SetLocalOmgContext(domi::GetContext()); + GetLocalOmgContext().dynamic_image_size.clear(); + GetLocalOmgContext().dynamic_batch_size.clear(); + } + void TearDown() { + GetLocalOmgContext().dynamic_image_size.clear(); + GetLocalOmgContext().dynamic_batch_size.clear(); + GetLocalOmgContext().dynamic_node_type.clear(); + } + +public: + NodePtr MakeNode(const ComputeGraphPtr &graph, uint32_t in_num, uint32_t out_num, string name, string type) { + GeTensorDesc test_desc(GeShape(), FORMAT_NCHW, DT_FLOAT); + auto op_desc = std::make_shared(name, type); + for (auto i = 0; i < in_num; ++i) { + op_desc->AddInputDesc(test_desc); + } + for (auto i = 0; i < out_num; ++i) { + op_desc->AddOutputDesc(test_desc); + } + return graph->AddNode(op_desc); + } + + NodePtr MakeConstNode(const ComputeGraphPtr &graph) { + static uint32_t index = 0; + GeTensorDesc test_desc(GeShape(), FORMAT_NCHW, DT_FLOAT); + auto op_desc = std::make_shared("dynamic_const_" + std::to_string(index++), "Const"); + op_desc->AddOutputDesc(test_desc); + return graph->AddNode(op_desc); + } + + void make_original_graph(const ComputeGraphPtr &graph) { + auto conv2d_node = MakeNode(graph, 3, 1, "conv1", "Conv2D"); + { + auto data1 = MakeNode(graph, 1, 1, "data", "Data"); + GeTensorDesc tensor_desc(GeShape({-1,3,224,224}), FORMAT_NCHW, DT_FLOAT); + data1->GetOpDesc()->UpdateInputDesc(0, tensor_desc); + data1->GetOpDesc()->UpdateOutputDesc(0, tensor_desc); + AttrUtils::SetInt(data1->GetOpDesc(), ATTR_NAME_INDEX, 0); + GetLocalOmgContext().user_input_dims = {std::make_pair(data1->GetOpDesc()->GetName(), vector{-1,3,224,224})}; + + GraphUtils::AddEdge(data1->GetOutDataAnchor(0), conv2d_node->GetInDataAnchor(0)); + auto const1 = MakeConstNode(graph); + GraphUtils::AddEdge(const1->GetOutDataAnchor(0), conv2d_node->GetInDataAnchor(1)); + auto const2 = MakeConstNode(graph); + GraphUtils::AddEdge(const2->GetOutDataAnchor(0), conv2d_node->GetInDataAnchor(2)); + } + + auto bn_conv1 = MakeNode(graph, 4, 1, "bn_conv1", "BNInference"); + { + GraphUtils::AddEdge(conv2d_node->GetOutDataAnchor(0), bn_conv1->GetInDataAnchor(0)); + auto const1 = MakeConstNode(graph); + GraphUtils::AddEdge(const1->GetOutDataAnchor(0), bn_conv1->GetInDataAnchor(1)); + auto const2 = MakeConstNode(graph); + GraphUtils::AddEdge(const2->GetOutDataAnchor(0), bn_conv1->GetInDataAnchor(2)); + auto const3= MakeConstNode(graph); + GraphUtils::AddEdge(const3->GetOutDataAnchor(0), bn_conv1->GetInDataAnchor(3)); + } + + auto scale_conv1 = MakeNode(graph, 4, 1, "scale1", "Scale"); + { + GraphUtils::AddEdge(bn_conv1->GetOutDataAnchor(0), scale_conv1->GetInDataAnchor(0)); + auto const1 = MakeConstNode(graph); + GraphUtils::AddEdge(const1->GetOutDataAnchor(0), scale_conv1->GetInDataAnchor(1)); + auto const2 = MakeConstNode(graph); + GraphUtils::AddEdge(const2->GetOutDataAnchor(0), scale_conv1->GetInDataAnchor(2)); + } + + auto output_node = MakeNode(graph, 1, 0, "output1", "NetOutput"); + GraphUtils::AddEdge(scale_conv1->GetOutDataAnchor(0), output_node->GetInDataAnchor(0)); + } + + void GraphWithJustData(const ComputeGraphPtr &graph) { + auto conv2d_node = MakeNode(graph, 3, 1, "conv1", "Conv2D"); + { + auto data1 = MakeNode(graph, 1, 1, "data", "Data"); + GeTensorDesc tensor_desc(GeShape({-1,3,224,224}), FORMAT_NCHW, DT_FLOAT); + data1->GetOpDesc()->UpdateInputDesc(0, tensor_desc); + data1->GetOpDesc()->UpdateOutputDesc(0, tensor_desc); + AttrUtils::SetInt(data1->GetOpDesc(), ATTR_NAME_INDEX, 0); + GetLocalOmgContext().user_input_dims = {std::make_pair(data1->GetOpDesc()->GetName(), vector{-1,3,224,224})}; + + GraphUtils::AddEdge(data1->GetOutDataAnchor(0), conv2d_node->GetInDataAnchor(0)); + auto const1 = MakeConstNode(graph); + GraphUtils::AddEdge(const1->GetOutDataAnchor(0), conv2d_node->GetInDataAnchor(1)); + auto const2 = MakeConstNode(graph); + GraphUtils::AddEdge(const2->GetOutDataAnchor(0), conv2d_node->GetInDataAnchor(2)); + } + + auto output_node = MakeNode(graph, 1, 0, "output1", "NetOutput"); + GraphUtils::AddEdge(conv2d_node->GetOutDataAnchor(0), output_node->GetInDataAnchor(0)); + } + + void GraphWithGetNextNosink(const ComputeGraphPtr &graph) { + auto conv2d_node = MakeNode(graph, 3, 1, "conv1", "Conv2D"); + { + auto data1 = MakeNode(graph, 1, 1, "IteratorGetNext_data", "Data"); + GeTensorDesc tensor_desc(GeShape({-1,3,224,224}), FORMAT_NCHW, DT_FLOAT); + data1->GetOpDesc()->UpdateInputDesc(0, tensor_desc); + data1->GetOpDesc()->UpdateOutputDesc(0, tensor_desc); + AttrUtils::SetInt(data1->GetOpDesc(), ATTR_NAME_INDEX, 0); + GetLocalOmgContext().user_input_dims = {std::make_pair(data1->GetOpDesc()->GetName(), vector{-1,3,224,224})}; + + GraphUtils::AddEdge(data1->GetOutDataAnchor(0), conv2d_node->GetInDataAnchor(0)); + auto const1 = MakeConstNode(graph); + GraphUtils::AddEdge(const1->GetOutDataAnchor(0), conv2d_node->GetInDataAnchor(1)); + auto const2 = MakeConstNode(graph); + GraphUtils::AddEdge(const2->GetOutDataAnchor(0), conv2d_node->GetInDataAnchor(2)); + } + + auto output_node = MakeNode(graph, 1, 0, "output1", "NetOutput"); + GraphUtils::AddEdge(conv2d_node->GetOutDataAnchor(0), output_node->GetInDataAnchor(0)); + } + + // getnext has one data and has one out of shape + void GraphWithGetNextSink(const ComputeGraphPtr &graph) { + auto conv2d_node = MakeNode(graph, 3, 1, "conv1", "Conv2D"); + { + auto data1 = MakeNode(graph, 1, 2, "data", "IteratorV2"); + GeTensorDesc tensor_desc(GeShape({-1,3,224,224}), FORMAT_NCHW, DT_FLOAT); + GeTensorDesc shape_desc(GeShape({4,3,224,224}), FORMAT_NCHW, DT_FLOAT); + data1->GetOpDesc()->UpdateOutputDesc(0, tensor_desc); + data1->GetOpDesc()->UpdateOutputDesc(1, shape_desc); + AttrUtils::SetInt(data1->GetOpDesc(), ATTR_NAME_INDEX, 0); + GetLocalOmgContext().user_input_dims = {std::make_pair(data1->GetOpDesc()->GetName(), vector{-1,3,224,224})}; + + GraphUtils::AddEdge(data1->GetOutDataAnchor(0), conv2d_node->GetInDataAnchor(0)); + auto identity = MakeNode(graph, 1, 0, "identity", "Identity"); + GraphUtils::AddEdge(data1->GetOutDataAnchor(1), identity->GetInDataAnchor(0)); + auto const1 = MakeConstNode(graph); + GraphUtils::AddEdge(const1->GetOutDataAnchor(0), conv2d_node->GetInDataAnchor(1)); + auto const2 = MakeConstNode(graph); + GraphUtils::AddEdge(const2->GetOutDataAnchor(0), conv2d_node->GetInDataAnchor(2)); + } + + auto output_node = MakeNode(graph, 1, 0, "output1", "NetOutput"); + GraphUtils::AddEdge(conv2d_node->GetOutDataAnchor(0), output_node->GetInDataAnchor(0)); + } +}; + +// graph is nullptr +TEST_F(UtestMultiBatchClonePass, graph_nullptr) { + PassManager pass_manager; + pass_manager.AddPass("MultiBatchClonePass", new (std::nothrow) MultiBatchClonePass); + ComputeGraphPtr graph; + EXPECT_EQ(pass_manager.Run(graph), PARAM_INVALID); +} + +// graph with subgraph +TEST_F(UtestMultiBatchClonePass, graph_with_subgraph) { + PassManager pass_manager; + pass_manager.AddPass("MultiBatchClonePass", new (std::nothrow) MultiBatchClonePass); + ComputeGraphPtr graph = std::make_shared("test_graph"); + make_original_graph(graph); + EXPECT_EQ(pass_manager.Run(graph), SUCCESS); + + ComputeGraphPtr owner = std::make_shared("test_owner"); + auto func_node = MakeNode(owner, 3, 1, "test_if", "If"); + graph->SetParentNode(func_node); + graph->SetParentGraph(owner); + EXPECT_EQ(pass_manager.Run(graph), SUCCESS); +} + +//graph is uncompute graph, not need to do multi batch +TEST_F(UtestMultiBatchClonePass, uncompute_graph) { + MultiBatchClonePass multi_batch_clone; + ComputeGraphPtr graph = std::make_shared("test_graph"); + make_original_graph(graph); + GetLocalOmgContext().need_multi_batch = false; + EXPECT_EQ(multi_batch_clone.Run(graph), SUCCESS); +} + + +//compute_graph with data from DATA +TEST_F(UtestMultiBatchClonePass, compute_graph_with_data) { + MultiBatchClonePass multi_batch_clone; + ComputeGraphPtr graph = std::make_shared("test_graph"); + GraphWithJustData(graph); + GetLocalOmgContext().need_multi_batch = true; + EXPECT_EQ(multi_batch_clone.Run(graph), SUCCESS); + GetLocalOmgContext().dynamic_node_type = DATA; + GetLocalOmgContext().dynamic_dims = "1;2;4;8"; + EXPECT_EQ(multi_batch_clone.Run(graph), SUCCESS); + EXPECT_EQ(GetLocalOmgContext().data_nodes.size(), 1); +} + +//compute_graph with data from GetNext_nosink +TEST_F(UtestMultiBatchClonePass, compute_graph_with_getnext_nosink) { + MultiBatchClonePass multi_batch_clone; + ComputeGraphPtr graph = std::make_shared("test_graph"); + GraphWithGetNextNosink(graph); + GetLocalOmgContext().need_multi_batch = true; + GetLocalOmgContext().dynamic_node_type = GETNEXT; + GetLocalOmgContext().dynamic_dims = "1;2;4;8"; + EXPECT_EQ(multi_batch_clone.Run(graph), SUCCESS); + EXPECT_EQ(GetLocalOmgContext().getnext_nosink_nodes.size(), 1); +} + +//compute_graph with data from GetNext_nosink +TEST_F(UtestMultiBatchClonePass, compute_graph_with_getnext_sink) { + MultiBatchClonePass multi_batch_clone; + ComputeGraphPtr graph = std::make_shared("test_graph"); + GraphWithGetNextSink(graph); + GetLocalOmgContext().need_multi_batch = true; + GetLocalOmgContext().dynamic_node_type = GETNEXT; + GetLocalOmgContext().dynamic_dims = "1;2;4;8"; + EXPECT_EQ(multi_batch_clone.Run(graph), SUCCESS); + EXPECT_EQ(GetLocalOmgContext().getnext_nosink_nodes.size(), 0); +} + +}