|
|
|
@ -736,7 +736,9 @@ Status UpdateDataOpShapeRange(const OpDescPtr &op,
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
auto tensor_input = op->MutableInputDesc(0);
|
|
|
|
|
auto tensor_output = op->MutableOutputDesc(0);
|
|
|
|
|
GE_CHECK_NOTNULL(tensor_input);
|
|
|
|
|
GE_CHECK_NOTNULL(tensor_output);
|
|
|
|
|
string data_op_name = op->GetName();
|
|
|
|
|
auto origin_shape = tensor_input->GetShape();
|
|
|
|
|
auto iter = shape_range_map.find(data_op_name);
|
|
|
|
@ -755,6 +757,8 @@ Status UpdateDataOpShapeRange(const OpDescPtr &op,
|
|
|
|
|
}
|
|
|
|
|
tensor_input->SetShape(origin_shape);
|
|
|
|
|
tensor_input->SetShapeRange(cur_shape_range);
|
|
|
|
|
tensor_output->SetShape(origin_shape);
|
|
|
|
|
tensor_output->SetShapeRange(cur_shape_range);
|
|
|
|
|
GELOGI("Update input [%s] shape range info", data_op_name.c_str());
|
|
|
|
|
} else {
|
|
|
|
|
GELOGI("No need to update input [%s] attr because not found from input_shape_range.", data_op_name.c_str());
|
|
|
|
@ -763,6 +767,29 @@ Status UpdateDataOpShapeRange(const OpDescPtr &op,
|
|
|
|
|
return SUCCESS;
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
static Status CheckInputShapeRangeNode(const ComputeGraphPtr &compute_graph,
|
|
|
|
|
const map<string, vector<pair<int64_t, int64_t>>> &shape_range_map) {
|
|
|
|
|
for (const auto &it : shape_range_map) {
|
|
|
|
|
std::string node_name = it.first;
|
|
|
|
|
ge::NodePtr node = compute_graph->FindNode(node_name);
|
|
|
|
|
if (node == nullptr) {
|
|
|
|
|
REPORT_INPUT_ERROR("E10016", std::vector<std::string>({"parameter", "opname"}),
|
|
|
|
|
std::vector<std::string>({"input_shape_range", node_name}));
|
|
|
|
|
GELOGE(PARAM_INVALID, "[Check][InputNode]Input parameter[--input_shape_range]'s opname[%s] is not exist in model",
|
|
|
|
|
node_name.c_str());
|
|
|
|
|
return PARAM_INVALID;
|
|
|
|
|
}
|
|
|
|
|
if (node->GetType() != DATA) {
|
|
|
|
|
REPORT_INPUT_ERROR("E10017", std::vector<std::string>({"parameter", "opname"}),
|
|
|
|
|
std::vector<std::string>({"input_shape_range", node_name}));
|
|
|
|
|
GELOGE(PARAM_INVALID, "[Check][InputNode]Input parameter[--input_shape_range]'s opname[%s] is not a input opname",
|
|
|
|
|
node_name.c_str());
|
|
|
|
|
return PARAM_INVALID;
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
return SUCCESS;
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
Status UpdateDynamicInputShapeRange(const ge::ComputeGraphPtr &compute_graph, const string &input_shape_range) {
|
|
|
|
|
if (input_shape_range.empty()) {
|
|
|
|
|
return SUCCESS;
|
|
|
|
@ -775,6 +802,11 @@ Status UpdateDynamicInputShapeRange(const ge::ComputeGraphPtr &compute_graph, co
|
|
|
|
|
return PARAM_INVALID;
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
if (CheckInputShapeRangeNode(compute_graph, shape_range_map) != SUCCESS) {
|
|
|
|
|
GELOGE(PARAM_INVALID, "[Check][InputShapeRange]check input shape range:%s failed.", input_shape_range.c_str());
|
|
|
|
|
return PARAM_INVALID;
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
for (NodePtr &input_node : compute_graph->GetDirectNode()) {
|
|
|
|
|
GE_CHECK_NOTNULL(input_node);
|
|
|
|
|
OpDescPtr op = input_node->GetOpDesc();
|
|
|
|
@ -788,5 +820,4 @@ Status UpdateDynamicInputShapeRange(const ge::ComputeGraphPtr &compute_graph, co
|
|
|
|
|
}
|
|
|
|
|
return SUCCESS;
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
} // namespace ge
|
|
|
|
|