From df17218f413a9aba4b36ddab209a414e2e2f1cfa Mon Sep 17 00:00:00 2001 From: wjm Date: Thu, 28 Jan 2021 14:10:55 +0800 Subject: [PATCH 1/5] fix shape -1 not in input_shape --- ge/graph/passes/multi_batch_clone_pass.cc | 56 ++++++++++++++++++++++- ge/graph/passes/multi_batch_clone_pass.h | 2 + 2 files changed, 56 insertions(+), 2 deletions(-) diff --git a/ge/graph/passes/multi_batch_clone_pass.cc b/ge/graph/passes/multi_batch_clone_pass.cc index b7efa070..3bfa5727 100755 --- a/ge/graph/passes/multi_batch_clone_pass.cc +++ b/ge/graph/passes/multi_batch_clone_pass.cc @@ -92,8 +92,7 @@ Status MultiBatchClonePass::Run(ComputeGraphPtr graph) { } // parser data dynamic info from atc parameter --input_shape - if (multibatch::ParserDataToDynmaicInfo(batch_shapes_, GetLocalOmgContext().user_input_dims, - data_to_dynamic_info_) != SUCCESS) { + if (CheckAndParseDynamicData() != SUCCESS) { GELOGE(PARAM_INVALID, "Parse each data's own dynamic info failed"); return PARAM_INVALID; } @@ -177,6 +176,59 @@ Status MultiBatchClonePass::CollectIoNodes(const ComputeGraphPtr &graph) { return SUCCESS; } +Status MultiBatchClonePass::CheckAndParseDynamicData(){ + size_t unknown_shape_count = 0; + auto data_name_and_shape = GetLocalOmgContext().user_input_dims; + std::vector data_name_order; + for (auto &item : GetLocalOmgContext().user_input_dims) { + data_name_order.push_back(item.first); + } + if (!getnext_sink_dynamic_dims_) { + for (const auto &node : all_data_nodes_) { + auto data_desc = NodeUtils::GetOutputDesc(*node, kDataOutIndex); + auto data_shape = data_desc.GetShape(); + auto data_format = data_desc.GetFormat() == Format::FORMAT_NCHW ? "NCHW" : + data_desc.GetFormat() == Format::FORMAT_NHWC ? "NHWC" : "Others"; + auto data_name = node->GetName(); + GELOGI("CheckAndParseDynamicData shape_dims is %s.", formats::JoinToString(data_shape.GetDims()).c_str()); + + std::vector data_shape_dims = data_shape.GetDims(); + if (std::all_of(data_shape_dims.begin(), data_shape_dims.end(), [](int64_t val) { return val >= 0; })) { + continue; + } + ++unknown_shape_count; + auto iter = find(data_name_order.begin(), data_name_order.end(), data_name); + if (iter == data_name_order.end()) { + if (!GetLocalOmgContext().dynamic_batch_size.empty()) { + auto ret = multibatch::CheckDynamicBatchShape(data_shape_dims, data_name); + GE_IF_BOOL_EXEC(ret == false, GELOGE(PARAM_INVALID, "Failed to check dynamic batch shape of %s.", + data_name.c_str()); return PARAM_INVALID); + } else if (!GetLocalOmgContext().dynamic_image_size.empty()) { + auto ret = multibatch::CheckDynamicImageSizeShape(data_shape_dims, data_name, data_format); + GE_IF_BOOL_EXEC(ret == false, GELOGE(PARAM_INVALID, "Failed to check dynamic image size shape of %s.", + data_name.c_str()); return PARAM_INVALID); + } else if (!GetLocalOmgContext().dynamic_dims.empty()) { + ErrorManager::GetInstance().ATCReportErrMessage("E10001",{"parameter", "reason"}, + {"--input_shape","all dynamic data must be set in --input_shape"}); + GELOGE(INTERNAL_ERROR, "data: %s shape:%s must be set int --input_shape", + node->GetName().c_str(), data_shape.ToString().c_str()); + return INTERNAL_ERROR; + } + data_name_and_shape.emplace_back(data_name, data_shape_dims); + } + } + } + auto ret = multibatch::ParserDataToDynmaicInfo(batch_shapes_, data_name_and_shape, data_to_dynamic_info_); + GE_CHK_STATUS_RET(ret, "Failed to parse data to dynamic info."); + if (!getnext_sink_dynamic_dims_ && unknown_shape_count == 0) { + ErrorManager::GetInstance().ATCReportErrMessage("E10040"); + GELOGE(PARAM_INVALID, + "Need unknow shape data when user set --dynamic_batch_size, --dynamic_image_size or --dynamic_dims"); + return PARAM_INVALID; + } + return SUCCESS; +} + Status MultiBatchClonePass::InitParamsOfGetNext(const NodePtr &node) { data_count_from_getnext_ = 0; getnext_sink_dynamic_dims_ = false; diff --git a/ge/graph/passes/multi_batch_clone_pass.h b/ge/graph/passes/multi_batch_clone_pass.h index 66e92892..0dae88ca 100755 --- a/ge/graph/passes/multi_batch_clone_pass.h +++ b/ge/graph/passes/multi_batch_clone_pass.h @@ -175,6 +175,8 @@ class MultiBatchClonePass : public GraphPass { /// @return 0: SUCCESS / others: FAILED /// Status UpdateOutputTensor(uint32_t parent_index, uint32_t unused_num); + + Status CheckAndParseDynamicData(); std::string session_graph_id_; std::vector> batch_shapes_; From 9284635a22ff08f8b3ff0125e8515c639b4f3a34 Mon Sep 17 00:00:00 2001 From: wjm Date: Thu, 28 Jan 2021 14:44:02 +0800 Subject: [PATCH 2/5] fix format --- ge/graph/passes/multi_batch_clone_pass.cc | 7 +++---- 1 file changed, 3 insertions(+), 4 deletions(-) diff --git a/ge/graph/passes/multi_batch_clone_pass.cc b/ge/graph/passes/multi_batch_clone_pass.cc index 3bfa5727..0b734b60 100755 --- a/ge/graph/passes/multi_batch_clone_pass.cc +++ b/ge/graph/passes/multi_batch_clone_pass.cc @@ -180,7 +180,7 @@ Status MultiBatchClonePass::CheckAndParseDynamicData(){ size_t unknown_shape_count = 0; auto data_name_and_shape = GetLocalOmgContext().user_input_dims; std::vector data_name_order; - for (auto &item : GetLocalOmgContext().user_input_dims) { + for (auto &item : data_name_and_shape) { data_name_order.push_back(item.first); } if (!getnext_sink_dynamic_dims_) { @@ -190,7 +190,6 @@ Status MultiBatchClonePass::CheckAndParseDynamicData(){ auto data_format = data_desc.GetFormat() == Format::FORMAT_NCHW ? "NCHW" : data_desc.GetFormat() == Format::FORMAT_NHWC ? "NHWC" : "Others"; auto data_name = node->GetName(); - GELOGI("CheckAndParseDynamicData shape_dims is %s.", formats::JoinToString(data_shape.GetDims()).c_str()); std::vector data_shape_dims = data_shape.GetDims(); if (std::all_of(data_shape_dims.begin(), data_shape_dims.end(), [](int64_t val) { return val >= 0; })) { @@ -208,8 +207,8 @@ Status MultiBatchClonePass::CheckAndParseDynamicData(){ GE_IF_BOOL_EXEC(ret == false, GELOGE(PARAM_INVALID, "Failed to check dynamic image size shape of %s.", data_name.c_str()); return PARAM_INVALID); } else if (!GetLocalOmgContext().dynamic_dims.empty()) { - ErrorManager::GetInstance().ATCReportErrMessage("E10001",{"parameter", "reason"}, - {"--input_shape","all dynamic data must be set in --input_shape"}); + ErrorManager::GetInstance().ATCReportErrMessage("E10001", {"parameter", "reason"}, + {"--input_shape", "all dynamic data must be set in --input_shape"}); GELOGE(INTERNAL_ERROR, "data: %s shape:%s must be set int --input_shape", node->GetName().c_str(), data_shape.ToString().c_str()); return INTERNAL_ERROR; From 6cbb1d3f02efb794620ad347e6fbbdffd3b98ce4 Mon Sep 17 00:00:00 2001 From: wjm Date: Thu, 28 Jan 2021 14:46:18 +0800 Subject: [PATCH 3/5] fix format --- ge/graph/passes/multi_batch_clone_pass.cc | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/ge/graph/passes/multi_batch_clone_pass.cc b/ge/graph/passes/multi_batch_clone_pass.cc index 0b734b60..56895991 100755 --- a/ge/graph/passes/multi_batch_clone_pass.cc +++ b/ge/graph/passes/multi_batch_clone_pass.cc @@ -176,7 +176,7 @@ Status MultiBatchClonePass::CollectIoNodes(const ComputeGraphPtr &graph) { return SUCCESS; } -Status MultiBatchClonePass::CheckAndParseDynamicData(){ +Status MultiBatchClonePass::CheckAndParseDynamicData() { size_t unknown_shape_count = 0; auto data_name_and_shape = GetLocalOmgContext().user_input_dims; std::vector data_name_order; From 03becccde38606cbd0e1629ccefcb9daa16376c5 Mon Sep 17 00:00:00 2001 From: wjm Date: Thu, 28 Jan 2021 19:32:54 +0800 Subject: [PATCH 4/5] fix --- ge/graph/passes/multi_batch_clone_pass.cc | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/ge/graph/passes/multi_batch_clone_pass.cc b/ge/graph/passes/multi_batch_clone_pass.cc index 56895991..f4a55419 100755 --- a/ge/graph/passes/multi_batch_clone_pass.cc +++ b/ge/graph/passes/multi_batch_clone_pass.cc @@ -191,7 +191,7 @@ Status MultiBatchClonePass::CheckAndParseDynamicData() { data_desc.GetFormat() == Format::FORMAT_NHWC ? "NHWC" : "Others"; auto data_name = node->GetName(); - std::vector data_shape_dims = data_shape.GetDims(); + const auto &data_shape_dims = data_shape.GetDims(); if (std::all_of(data_shape_dims.begin(), data_shape_dims.end(), [](int64_t val) { return val >= 0; })) { continue; } From f1e2d6cef543840152886bf2a4a432ed2dc59268 Mon Sep 17 00:00:00 2001 From: wjm Date: Thu, 28 Jan 2021 20:36:16 +0800 Subject: [PATCH 5/5] fix name --- ge/graph/passes/multi_batch_clone_pass.cc | 2 +- ge/graph/preprocess/multi_batch_copy_graph.cc | 2 +- ge/graph/preprocess/multi_batch_options.cc | 2 +- ge/graph/preprocess/multi_batch_options.h | 2 +- 4 files changed, 4 insertions(+), 4 deletions(-) diff --git a/ge/graph/passes/multi_batch_clone_pass.cc b/ge/graph/passes/multi_batch_clone_pass.cc index f4a55419..fe820df4 100755 --- a/ge/graph/passes/multi_batch_clone_pass.cc +++ b/ge/graph/passes/multi_batch_clone_pass.cc @@ -217,7 +217,7 @@ Status MultiBatchClonePass::CheckAndParseDynamicData() { } } } - auto ret = multibatch::ParserDataToDynmaicInfo(batch_shapes_, data_name_and_shape, data_to_dynamic_info_); + auto ret = multibatch::ParserDataToDynamicInfo(batch_shapes_, data_name_and_shape, data_to_dynamic_info_); GE_CHK_STATUS_RET(ret, "Failed to parse data to dynamic info."); if (!getnext_sink_dynamic_dims_ && unknown_shape_count == 0) { ErrorManager::GetInstance().ATCReportErrMessage("E10040"); diff --git a/ge/graph/preprocess/multi_batch_copy_graph.cc b/ge/graph/preprocess/multi_batch_copy_graph.cc index 5506435e..d1be70a8 100644 --- a/ge/graph/preprocess/multi_batch_copy_graph.cc +++ b/ge/graph/preprocess/multi_batch_copy_graph.cc @@ -738,7 +738,7 @@ Status MultiBatchGraphCopyer::CheckAndParseDynamicData(){ } } } - auto ret = ParserDataToDynmaicInfo(shapes_, data_name_and_shape, data_to_dynamic_info_); + auto ret = ParserDataToDynamicInfo(shapes_, data_name_and_shape, data_to_dynamic_info_); GE_CHK_STATUS_RET(ret, "Failed to parse data to dynamic info."); if (!getnext_sink_dynamic_dims_ && unknown_shape_count == 0) { ErrorManager::GetInstance().ATCReportErrMessage("E10040"); diff --git a/ge/graph/preprocess/multi_batch_options.cc b/ge/graph/preprocess/multi_batch_options.cc index 8aab0981..a0b0be8d 100644 --- a/ge/graph/preprocess/multi_batch_options.cc +++ b/ge/graph/preprocess/multi_batch_options.cc @@ -377,7 +377,7 @@ bool InitDynamicParams(vector> &shapes) { /// @param [out] map>> &data_to_dynamic_info: key:data_name. value:dynamic dims. /// @return true: Configed for Multi batch / false: Not configed for Multi batch. /// -Status ParserDataToDynmaicInfo(const vector> &shapes, +Status ParserDataToDynamicInfo(const vector> &shapes, vector>> &data_name_and_shape, map> > &data_to_dynamic_info) { size_t cur_data_index = 0; diff --git a/ge/graph/preprocess/multi_batch_options.h b/ge/graph/preprocess/multi_batch_options.h index 9baf4f43..bfe96ea2 100644 --- a/ge/graph/preprocess/multi_batch_options.h +++ b/ge/graph/preprocess/multi_batch_options.h @@ -74,7 +74,7 @@ Status CalcShape(const std::vector &batch_shape, GeShape &data_shape); /// @param [out] map>> &data_to_dynamic_info: key:data_name. value:dynamic dims. /// @return SUCCESS / PARAM_INVALID /// -Status ParserDataToDynmaicInfo(const vector> &shapes, +Status ParserDataToDynamicInfo(const vector> &shapes, vector>> &data_name_and_shape, map>> &data_to_dynamic_info);