add input shape range node check

pull/1298/head
zhengyuanhua 4 years ago
parent 0664647c5a
commit 192dc75f1f

@ -722,7 +722,9 @@ Status UpdateDataOpShapeRange(const OpDescPtr &op,
} }
auto tensor_input = op->MutableInputDesc(0); auto tensor_input = op->MutableInputDesc(0);
auto tensor_output = op->MutableOutputDesc(0);
GE_CHECK_NOTNULL(tensor_input); GE_CHECK_NOTNULL(tensor_input);
GE_CHECK_NOTNULL(tensor_output);
string data_op_name = op->GetName(); string data_op_name = op->GetName();
auto origin_shape = tensor_input->GetShape(); auto origin_shape = tensor_input->GetShape();
auto iter = shape_range_map.find(data_op_name); auto iter = shape_range_map.find(data_op_name);
@ -741,6 +743,8 @@ Status UpdateDataOpShapeRange(const OpDescPtr &op,
} }
tensor_input->SetShape(origin_shape); tensor_input->SetShape(origin_shape);
tensor_input->SetShapeRange(cur_shape_range); tensor_input->SetShapeRange(cur_shape_range);
tensor_output->SetShape(origin_shape);
tensor_output->SetShapeRange(cur_shape_range);
GELOGI("Update input [%s] shape range info", data_op_name.c_str()); GELOGI("Update input [%s] shape range info", data_op_name.c_str());
} else { } else {
GELOGI("No need to update input [%s] attr because not found from input_shape_range.", data_op_name.c_str()); GELOGI("No need to update input [%s] attr because not found from input_shape_range.", data_op_name.c_str());
@ -749,6 +753,29 @@ Status UpdateDataOpShapeRange(const OpDescPtr &op,
return SUCCESS; return SUCCESS;
} }
static Status CheckInputShapeRangeNode(const ComputeGraphPtr &compute_graph,
map<string, vector<pair<int64_t, int64_t>>> &shape_range_map) {
for (const auto &it : shape_range_map) {
std::string node_name = it.first;
ge::NodePtr node = compute_graph->FindNode(node_name);
if (node == nullptr) {
REPORT_INPUT_ERROR("E10016", std::vector<std::string>({"parameter", "opname"}),
std::vector<std::string>({"input_shape_range", node_name}));
GELOGE(PARAM_INVALID, "[Check][InputNode]Input parameter[--input_shape_range]'s opname[%s] is not exist in model",
node_name.c_str());
return PARAM_INVALID;
}
if (node->GetType() != DATA) {
REPORT_INPUT_ERROR("E10017", std::vector<std::string>({"parameter", "opname"}),
std::vector<std::string>({"input_shape_range", node_name}));
GELOGE(PARAM_INVALID, "[Check][InputNode]Input parameter[--input_shape_range]'s opname[%s] is not a input opname",
node_name.c_str());
return PARAM_INVALID;
}
}
return SUCCESS;
}
Status UpdateDynamicInputShapeRange(const ge::ComputeGraphPtr &compute_graph, const string &input_shape_range) { Status UpdateDynamicInputShapeRange(const ge::ComputeGraphPtr &compute_graph, const string &input_shape_range) {
if (input_shape_range.empty()) { if (input_shape_range.empty()) {
return SUCCESS; return SUCCESS;
@ -757,7 +784,12 @@ Status UpdateDynamicInputShapeRange(const ge::ComputeGraphPtr &compute_graph, co
map<string, vector<pair<int64_t, int64_t>>> shape_range_map; map<string, vector<pair<int64_t, int64_t>>> shape_range_map;
if (!ParseInputShapeRange(input_shape_range, shape_range_map)) { if (!ParseInputShapeRange(input_shape_range, shape_range_map)) {
GELOGE(PARAM_INVALID, "Parse input shape range failed."); GELOGE(PARAM_INVALID, "[Update][InputShapeRange]Parse input shape range failed.");
return PARAM_INVALID;
}
if (CheckInputShapeRangeNode(compute_graph, shape_range_map) != SUCCESS) {
GELOGE(PARAM_INVALID, "[Update][InputShapeRange]Parse input shape range failed.");
return PARAM_INVALID; return PARAM_INVALID;
} }
@ -767,7 +799,7 @@ Status UpdateDynamicInputShapeRange(const ge::ComputeGraphPtr &compute_graph, co
GE_CHECK_NOTNULL(op); GE_CHECK_NOTNULL(op);
if (op->GetType() == DATA) { if (op->GetType() == DATA) {
if (UpdateDataOpShapeRange(op, shape_range_map) != SUCCESS) { if (UpdateDataOpShapeRange(op, shape_range_map) != SUCCESS) {
GELOGE(FAILED, "Update data op [%s] input shape range failed.", op->GetName().c_str()); GELOGE(FAILED, "[Update][InputShapeRange]Update data op [%s] input shape range failed.", op->GetName().c_str());
return FAILED; return FAILED;
} }
} }

@ -99,8 +99,9 @@ static void ParseAtcParms(const std::map<std::string, std::string> &atc_params,
} }
} }
static Status CheckInputShapeNode(const ComputeGraphPtr &graph, const bool is_dynamic_input, RunMode run_mode) { static Status CheckInputShapeNode(const ComputeGraphPtr &graph, const bool is_dynamic_input,
if (!is_dynamic_input && run_mode != MODEL_TO_JSON) { const std::string &input_shape_range, RunMode run_mode) {
if (!is_dynamic_input && run_mode != MODEL_TO_JSON && input_shape_range.empty()) {
for (auto node : graph->GetDirectNode()) { for (auto node : graph->GetDirectNode()) {
if (node->GetType() == DATA) { if (node->GetType() == DATA) {
auto data_op_desc = node->GetOpDesc(); auto data_op_desc = node->GetOpDesc();
@ -760,8 +761,9 @@ FMK_FUNC_HOST_VISIBILITY Status ParseGraph(ge::Graph &graph, const std::map<stri
ParseAtcParms(atc_params, "is_input_adjust_hw_layout", is_input_adjust_hw_layout); ParseAtcParms(atc_params, "is_input_adjust_hw_layout", is_input_adjust_hw_layout);
compute_graph = GraphUtils::GetComputeGraph(graph); compute_graph = GraphUtils::GetComputeGraph(graph);
GE_RETURN_IF_ERROR(CheckInputFp16Nodes(compute_graph, input_fp16_nodes, is_input_adjust_hw_layout)); GE_RETURN_IF_ERROR(CheckInputFp16Nodes(compute_graph, input_fp16_nodes, is_input_adjust_hw_layout));
std::string input_shape_range;
GE_RETURN_IF_ERROR(CheckInputShapeNode(compute_graph, is_dynamic_input, run_mode)); ParseAtcParms(atc_params, INPUT_SHAPE_RANGE, input_shape_range);
GE_RETURN_IF_ERROR(CheckInputShapeNode(compute_graph, is_dynamic_input, input_shape_range, run_mode));
// Verify the contents of the op_name_map // Verify the contents of the op_name_map
if (op_conf != nullptr && *op_conf != '\0') { if (op_conf != nullptr && *op_conf != '\0') {
@ -790,8 +792,6 @@ FMK_FUNC_HOST_VISIBILITY Status ParseGraph(ge::Graph &graph, const std::map<stri
GE_CHK_BOOL_RET_STATUS(ret == SUCCESS, ret, "ATC weights parse ret fail."); GE_CHK_BOOL_RET_STATUS(ret == SUCCESS, ret, "ATC weights parse ret fail.");
// parser input shape range and update op shape range // parser input shape range and update op shape range
std::string input_shape_range;
ParseAtcParms(atc_params, INPUT_SHAPE_RANGE, input_shape_range);
GE_RETURN_WITH_LOG_IF_ERROR(UpdateDynamicInputShapeRange(compute_graph, input_shape_range), GE_RETURN_WITH_LOG_IF_ERROR(UpdateDynamicInputShapeRange(compute_graph, input_shape_range),
"Update input shape range failed"); "Update input shape range failed");

@ -97,4 +97,14 @@ TEST(UtestIrCommon, update_dynamic_shape_range_failed) {
input_shape_range = "input1:[1, 2~-3, -1]"; input_shape_range = "input1:[1, 2~-3, -1]";
ret = UpdateDynamicInputShapeRange(graph, input_shape_range); ret = UpdateDynamicInputShapeRange(graph, input_shape_range);
EXPECT_EQ(ret, ge::PARAM_INVALID); EXPECT_EQ(ret, ge::PARAM_INVALID);
//5
input_shape_range = "input:[1, 2~3, -1]";
ret = UpdateDynamicInputShapeRange(graph, input_shape_range);
EXPECT_EQ(ret, ge::PARAM_INVALID);
//6
input_shape_range = "addn1:[1, 2~3, -1]";
ret = UpdateDynamicInputShapeRange(graph, input_shape_range);
EXPECT_EQ(ret, ge::PARAM_INVALID);
} }

Loading…
Cancel
Save