!481 fix infer_time increase when online_infer with dynamic_dims

From: @zhou_lili
Reviewed-by: @youui,@liujunzhu
Signed-off-by: @liujunzhu
pull/481/MERGE
mindspore-ci-bot 4 years ago committed by Gitee
commit 1780155010

@ -131,6 +131,8 @@ const int64_t kInvalidDynaimcDimsType = -1;
const char *const kSubstrOfGetNextNosinkName = "IteratorGetNext";
const char *const kShapeDataName = "ascend_mbatch_shape_data";
const char *const kGetNextName = "IteratorV2";
const char *const kExtAttrDataNodes = "data_nodes";
const char *const kExtAttrGetNextNoSink = "getnext_no_sink";
bool IsTailingOptimization() {
string is_tailing_optimization_option;
@ -2731,37 +2733,6 @@ void GraphManager::PreRunThread(GraphManager *graph_manager) {
}
}
Status GraphManager::DistinguishGetNextAndData(ComputeGraphPtr &graph, vector<NodePtr> &data_nodes,
vector<NodePtr> &getnext_nosink_nodes,
vector<NodePtr> &getnext_sink_nodes) {
GELOGD("Start distinguish getnext and data node.");
for (NodePtr &input_node : graph->GetDirectNode()) {
GE_CHECK_NOTNULL(input_node);
OpDescPtr op_desc = input_node->GetOpDesc();
GE_CHECK_NOTNULL(op_desc);
if (op_desc->GetType() == DATA && op_desc->GetName() != kShapeDataName) {
if (op_desc->GetName().find(kSubstrOfGetNextNosinkName) == string::npos) {
data_nodes.emplace_back(input_node);
} else {
getnext_nosink_nodes.emplace_back(input_node);
}
}
std::string op_type;
auto ret = GetOriginalType(input_node, op_type);
if (ret != SUCCESS) {
GELOGE(FAILED, "Failed to get node %s original type.", input_node->GetName().c_str());
return FAILED;
}
if (op_type == kGetNextName) {
GELOGD("Name of getnext sink is %s.", op_desc->GetName().c_str());
getnext_sink_nodes.emplace_back(input_node);
}
}
GELOGI("data count is %zu, getnext nosink count is %zu, getnext sink count is %zu.", data_nodes.size(),
getnext_nosink_nodes.size(), getnext_sink_nodes.size());
return SUCCESS;
}
void GraphManager::ParseInputsDimsForData(const std::vector<InputTensorInfo> &input_tensor) {
GELOGD("Start parse input dims from data.");
for (size_t i = 0; i < input_tensor.size(); ++i) {
@ -2804,11 +2775,8 @@ Status GraphManager::ParseInputsDims(const std::vector<InputTensorInfo> &input_t
if (!GetLocalOmgContext().dynamic_node_type.empty()) {
vector<NodePtr> data_nodes;
vector<NodePtr> getnext_nosink_nodes;
vector<NodePtr> getnext_sink_nodes;
if (DistinguishGetNextAndData(compute_graph_, data_nodes, getnext_nosink_nodes, getnext_sink_nodes) != SUCCESS) {
GELOGE(PARAM_INVALID, "Failed to distinguish getnext and data node.");
return PARAM_INVALID;
}
data_nodes = compute_graph_->TryGetExtAttr(kExtAttrDataNodes, data_nodes);
getnext_nosink_nodes = compute_graph_->TryGetExtAttr(kExtAttrGetNextNoSink, getnext_nosink_nodes);
if (GetLocalOmgContext().dynamic_node_type == DATA) {
if (getnext_nosink_nodes.empty()) {
// just data or data+getnext_sink

@ -222,8 +222,6 @@ class GraphManager {
const ComputeGraphPtr &compute_graph, uint64_t session_id,
const GEThreadLocalContext &ge_context);
Status ParseInputsDims(const std::vector<InputTensorInfo> &input_tensor);
Status DistinguishGetNextAndData(ComputeGraphPtr &graph, vector<NodePtr> &data_nodes,
vector<NodePtr> &getnext_nosink_nodes, vector<NodePtr> &getnext_sink_nodes);
void ParseInputsDimsForData(const std::vector<InputTensorInfo> &input_tensor);
Status ParseInputsDimsForGetNexNosinkAndData(const vector<NodePtr> &dynamic_nodes,
const std::vector<InputTensorInfo> &input_tensor);

@ -46,6 +46,8 @@ const int kDivisionConst = 2;
const char *const kSubstrOfGetNextNosinkName = "IteratorGetNext";
const char *const kShapeDataName = "ascend_mbatch_shape_data";
const char *const kGetNextName = "IteratorV2";
const char *const kExtAttrDataNodes = "data_nodes";
const char *const kExtAttrGetNextNoSink = "getnext_no_sink";
inline bool IsGetNextType(const NodePtr &node) {
std::string original_type;
@ -97,6 +99,9 @@ Status DistinguishGetNextAndData(ComputeGraphPtr &graph, vector<NodePtr> &data_n
}
GELOGI("Data count is %zu, getnext nosink count is %zu, getnext sink count is %zu.", data_nodes.size(),
getnext_nosink_nodes.size(), getnext_sink_nodes.size());
GE_IF_BOOL_EXEC(!graph->SetExtAttr(kExtAttrDataNodes, data_nodes), GELOGW("Set data nodes attr failed.");)
GE_IF_BOOL_EXEC(!graph->SetExtAttr(kExtAttrGetNextNoSink, getnext_nosink_nodes),
GELOGW("Set getnext nosink nodes attr failed.");)
return SUCCESS;
}

Loading…
Cancel
Save