!1037 fix shape -1 not in input_shape

From: @jiming6
Reviewed-by: @xchu42,@wqtshg
Signed-off-by:
pull/1037/MERGE
mindspore-ci-bot 4 years ago committed by Gitee
commit 6de4616171

@ -92,8 +92,7 @@ Status MultiBatchClonePass::Run(ComputeGraphPtr graph) {
}
// parser data dynamic info from atc parameter --input_shape
if (multibatch::ParserDataToDynmaicInfo(batch_shapes_, GetLocalOmgContext().user_input_dims,
data_to_dynamic_info_) != SUCCESS) {
if (CheckAndParseDynamicData() != SUCCESS) {
GELOGE(PARAM_INVALID, "Parse each data's own dynamic info failed");
return PARAM_INVALID;
}
@ -177,6 +176,58 @@ Status MultiBatchClonePass::CollectIoNodes(const ComputeGraphPtr &graph) {
return SUCCESS;
}
Status MultiBatchClonePass::CheckAndParseDynamicData() {
size_t unknown_shape_count = 0;
auto data_name_and_shape = GetLocalOmgContext().user_input_dims;
std::vector<std::string> data_name_order;
for (auto &item : data_name_and_shape) {
data_name_order.push_back(item.first);
}
if (!getnext_sink_dynamic_dims_) {
for (const auto &node : all_data_nodes_) {
auto data_desc = NodeUtils::GetOutputDesc(*node, kDataOutIndex);
auto data_shape = data_desc.GetShape();
auto data_format = data_desc.GetFormat() == Format::FORMAT_NCHW ? "NCHW" :
data_desc.GetFormat() == Format::FORMAT_NHWC ? "NHWC" : "Others";
auto data_name = node->GetName();
const auto &data_shape_dims = data_shape.GetDims();
if (std::all_of(data_shape_dims.begin(), data_shape_dims.end(), [](int64_t val) { return val >= 0; })) {
continue;
}
++unknown_shape_count;
auto iter = find(data_name_order.begin(), data_name_order.end(), data_name);
if (iter == data_name_order.end()) {
if (!GetLocalOmgContext().dynamic_batch_size.empty()) {
auto ret = multibatch::CheckDynamicBatchShape(data_shape_dims, data_name);
GE_IF_BOOL_EXEC(ret == false, GELOGE(PARAM_INVALID, "Failed to check dynamic batch shape of %s.",
data_name.c_str()); return PARAM_INVALID);
} else if (!GetLocalOmgContext().dynamic_image_size.empty()) {
auto ret = multibatch::CheckDynamicImageSizeShape(data_shape_dims, data_name, data_format);
GE_IF_BOOL_EXEC(ret == false, GELOGE(PARAM_INVALID, "Failed to check dynamic image size shape of %s.",
data_name.c_str()); return PARAM_INVALID);
} else if (!GetLocalOmgContext().dynamic_dims.empty()) {
ErrorManager::GetInstance().ATCReportErrMessage("E10001", {"parameter", "reason"},
{"--input_shape", "all dynamic data must be set in --input_shape"});
GELOGE(INTERNAL_ERROR, "data: %s shape:%s must be set int --input_shape",
node->GetName().c_str(), data_shape.ToString().c_str());
return INTERNAL_ERROR;
}
data_name_and_shape.emplace_back(data_name, data_shape_dims);
}
}
}
auto ret = multibatch::ParserDataToDynamicInfo(batch_shapes_, data_name_and_shape, data_to_dynamic_info_);
GE_CHK_STATUS_RET(ret, "Failed to parse data to dynamic info.");
if (!getnext_sink_dynamic_dims_ && unknown_shape_count == 0) {
ErrorManager::GetInstance().ATCReportErrMessage("E10040");
GELOGE(PARAM_INVALID,
"Need unknow shape data when user set --dynamic_batch_size, --dynamic_image_size or --dynamic_dims");
return PARAM_INVALID;
}
return SUCCESS;
}
Status MultiBatchClonePass::InitParamsOfGetNext(const NodePtr &node) {
data_count_from_getnext_ = 0;
getnext_sink_dynamic_dims_ = false;

@ -175,6 +175,8 @@ class MultiBatchClonePass : public GraphPass {
/// @return 0: SUCCESS / others: FAILED
///
Status UpdateOutputTensor(uint32_t parent_index, uint32_t unused_num);
Status CheckAndParseDynamicData();
std::string session_graph_id_;
std::vector<std::vector<int64_t>> batch_shapes_;

@ -738,7 +738,7 @@ Status MultiBatchGraphCopyer::CheckAndParseDynamicData(){
}
}
}
auto ret = ParserDataToDynmaicInfo(shapes_, data_name_and_shape, data_to_dynamic_info_);
auto ret = ParserDataToDynamicInfo(shapes_, data_name_and_shape, data_to_dynamic_info_);
GE_CHK_STATUS_RET(ret, "Failed to parse data to dynamic info.");
if (!getnext_sink_dynamic_dims_ && unknown_shape_count == 0) {
ErrorManager::GetInstance().ATCReportErrMessage("E10040");

@ -377,7 +377,7 @@ bool InitDynamicParams(vector<vector<int64_t>> &shapes) {
/// @param [out] map<string, vector<vector<int64_t>>> &data_to_dynamic_info: key:data_name. value:dynamic dims.
/// @return true: Configed for Multi batch / false: Not configed for Multi batch.
///
Status ParserDataToDynmaicInfo(const vector<vector<int64_t>> &shapes,
Status ParserDataToDynamicInfo(const vector<vector<int64_t>> &shapes,
vector<pair<string, vector<int64_t>>> &data_name_and_shape,
map<string, vector<vector<int64_t>> > &data_to_dynamic_info) {
size_t cur_data_index = 0;

@ -74,7 +74,7 @@ Status CalcShape(const std::vector<int64_t> &batch_shape, GeShape &data_shape);
/// @param [out] map<string, vector<vector<int64_t>>> &data_to_dynamic_info: key:data_name. value:dynamic dims.
/// @return SUCCESS / PARAM_INVALID
///
Status ParserDataToDynmaicInfo(const vector<vector<int64_t>> &shapes,
Status ParserDataToDynamicInfo(const vector<vector<int64_t>> &shapes,
vector<pair<string, vector<int64_t>>> &data_name_and_shape,
map<string, vector<vector<int64_t>>> &data_to_dynamic_info);

Loading…
Cancel
Save