|
|
@ -503,12 +503,24 @@ Status MultiBatchClonePass::SetMaxShapeToData(const NodePtr &data) {
|
|
|
|
|
|
|
|
|
|
|
|
///
|
|
|
|
///
|
|
|
|
/// @ingroup ge
|
|
|
|
/// @ingroup ge
|
|
|
|
/// @brief Set shape to Data node in branch.
|
|
|
|
/// @brief Update Data node in Subgraph.
|
|
|
|
/// @param [in] const NodePtr &data: data in branch.
|
|
|
|
/// @param [in] const NodePtr &data: data in Subgraph.
|
|
|
|
/// @param [in] size_t index: The batch index.
|
|
|
|
/// @param [in] size_t index: The batch index.
|
|
|
|
/// @return 0: SUCCESS / others: FAILED
|
|
|
|
/// @return 0: SUCCESS / others: FAILED
|
|
|
|
///
|
|
|
|
///
|
|
|
|
Status MultiBatchClonePass::UpdateShapeToData(const NodePtr &data, size_t index) {
|
|
|
|
Status MultiBatchClonePass::UpdateSubgraphData(const NodePtr &data, size_t index) {
|
|
|
|
|
|
|
|
int node_index = -1;
|
|
|
|
|
|
|
|
if (!AttrUtils::GetInt(data->GetOpDesc(), ATTR_NAME_INDEX, node_index)) {
|
|
|
|
|
|
|
|
GELOGE(FAILED, "Failed to get index from data[%s]", data->GetName().c_str());
|
|
|
|
|
|
|
|
return FAILED;
|
|
|
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
int parent_index = node_index + 1;
|
|
|
|
|
|
|
|
if (!AttrUtils::SetInt(data->GetOpDesc(), ATTR_NAME_PARENT_NODE_INDEX, parent_index)) {
|
|
|
|
|
|
|
|
GELOGE(FAILED, "Failed to set parent index for node %s", data->GetName().c_str());
|
|
|
|
|
|
|
|
return FAILED;
|
|
|
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
|
|
auto data_shape = NodeUtils::GetOutputDesc(*data, kDataOutIndex).GetShape();
|
|
|
|
auto data_shape = NodeUtils::GetOutputDesc(*data, kDataOutIndex).GetShape();
|
|
|
|
const auto &dims = data_shape.GetDims();
|
|
|
|
const auto &dims = data_shape.GetDims();
|
|
|
|
if (std::all_of(dims.begin(), dims.end(), [](int64_t val) { return val >= 0; })) {
|
|
|
|
if (std::all_of(dims.begin(), dims.end(), [](int64_t val) { return val >= 0; })) {
|
|
|
@ -580,13 +592,15 @@ Status MultiBatchClonePass::CreateSubgraphs(const ComputeGraphPtr &graph, const
|
|
|
|
subgraph->SetParentGraph(graph);
|
|
|
|
subgraph->SetParentGraph(graph);
|
|
|
|
graph->AddSubgraph(subgraph->GetName(), subgraph);
|
|
|
|
graph->AddSubgraph(subgraph->GetName(), subgraph);
|
|
|
|
all_branch_output_[subgraph] = subgraph->FindFirstNodeMatchType(NETOUTPUT);
|
|
|
|
all_branch_output_[subgraph] = subgraph->FindFirstNodeMatchType(NETOUTPUT);
|
|
|
|
|
|
|
|
GE_CHK_STATUS_RET(UpdateSubgraphOutput(all_branch_output_[subgraph]),
|
|
|
|
|
|
|
|
"Update %s failed", all_branch_output_[subgraph]->GetName().c_str());
|
|
|
|
|
|
|
|
|
|
|
|
const string key_name = "branches" + std::to_string(i);
|
|
|
|
const string key_name = "branches" + std::to_string(i);
|
|
|
|
op_desc->AddSubgraphName(key_name);
|
|
|
|
op_desc->AddSubgraphName(key_name);
|
|
|
|
op_desc->SetSubgraphInstanceName(i, subgraph->GetName());
|
|
|
|
op_desc->SetSubgraphInstanceName(i, subgraph->GetName());
|
|
|
|
|
|
|
|
|
|
|
|
for (const auto &data : input_nodes) {
|
|
|
|
for (const auto &data : input_nodes) {
|
|
|
|
GE_CHK_STATUS_RET(UpdateShapeToData(data, i), "Update %s failed", subgraph->GetName().c_str());
|
|
|
|
GE_CHK_STATUS_RET(UpdateSubgraphData(data, i), "Update %s failed", subgraph->GetName().c_str());
|
|
|
|
}
|
|
|
|
}
|
|
|
|
}
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
|
@ -595,55 +609,27 @@ Status MultiBatchClonePass::CreateSubgraphs(const ComputeGraphPtr &graph, const
|
|
|
|
const auto &op_desc = n->GetOpDesc();
|
|
|
|
const auto &op_desc = n->GetOpDesc();
|
|
|
|
op_desc->SetName(n->GetName() + kMultiBatchNodePostfix + "0");
|
|
|
|
op_desc->SetName(n->GetName() + kMultiBatchNodePostfix + "0");
|
|
|
|
if (n->GetType() == DATA) {
|
|
|
|
if (n->GetType() == DATA) {
|
|
|
|
GE_CHK_STATUS_RET(UpdateShapeToData(n, 0), "Update %s failed", branch->GetName().c_str());
|
|
|
|
GE_CHK_STATUS_RET(UpdateSubgraphData(n, 0), "Update %s failed", branch->GetName().c_str());
|
|
|
|
}
|
|
|
|
}
|
|
|
|
}
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
|
|
return PostProcSubgraph(graph);
|
|
|
|
return SUCCESS;
|
|
|
|
}
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
|
|
///
|
|
|
|
///
|
|
|
|
/// @ingroup ge
|
|
|
|
/// @ingroup ge
|
|
|
|
/// @brief Assign parent index for branches.
|
|
|
|
/// @brief Update output_node in Subgraph.
|
|
|
|
/// @param [in] const ComputeGraphPtr &graph: Root/Case graph.
|
|
|
|
/// @param [in] const NodePtr &output_node: output_node in Subgraph.
|
|
|
|
/// @return 0: SUCCESS / others: FAILED
|
|
|
|
/// @return 0: SUCCESS / others: FAILED
|
|
|
|
///
|
|
|
|
///
|
|
|
|
Status MultiBatchClonePass::PostProcSubgraph(const ComputeGraphPtr &graph) {
|
|
|
|
Status MultiBatchClonePass::UpdateSubgraphOutput(const NodePtr &output_node) {
|
|
|
|
auto func_desc = case_node_->GetOpDesc();
|
|
|
|
const auto &op_desc = output_node->GetOpDesc();
|
|
|
|
domi::ParseSubgraphFuncV2 parse_func_v2 = nullptr;
|
|
|
|
GE_CHECK_NOTNULL(op_desc);
|
|
|
|
auto post_func = domi::OpRegistry::Instance()->GetParseSubgraphPostFunc(func_desc->GetType());
|
|
|
|
for (size_t index = 0; index < op_desc->GetInputsSize(); ++index) {
|
|
|
|
if (post_func == nullptr) {
|
|
|
|
GeTensorDescPtr tensor = op_desc->MutableInputDesc(index);
|
|
|
|
GELOGW("The subgraph post func for node %s type %s is null.", case_node_->GetName().c_str(),
|
|
|
|
GE_CHECK_NOTNULL(tensor);
|
|
|
|
case_node_->GetType().c_str());
|
|
|
|
if (!AttrUtils::SetInt(tensor, ATTR_NAME_PARENT_NODE_INDEX, index)) {
|
|
|
|
if (domi::OpRegistry::Instance()->GetParseSubgraphPostFunc(func_desc->GetType(), parse_func_v2) != SUCCESS ||
|
|
|
|
GELOGE(FAILED, "Failed to set parent index for node %s", output_node->GetName().c_str());
|
|
|
|
parse_func_v2 == nullptr) {
|
|
|
|
|
|
|
|
GELOGW("The subgraph new post func v2 for node %s type %s is null", case_node_->GetName().c_str(),
|
|
|
|
|
|
|
|
case_node_->GetType().c_str());
|
|
|
|
|
|
|
|
return FAILED;
|
|
|
|
|
|
|
|
}
|
|
|
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
for (const auto &name : func_desc->GetSubgraphInstanceNames()) {
|
|
|
|
|
|
|
|
const auto &subgraph = graph->GetSubgraph(name);
|
|
|
|
|
|
|
|
if (subgraph == nullptr) {
|
|
|
|
|
|
|
|
GELOGE(FAILED, "Subgraph not found, name: %s", name.c_str());
|
|
|
|
|
|
|
|
return FAILED;
|
|
|
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
std::string subgraph_name;
|
|
|
|
|
|
|
|
GE_CHK_STATUS_RET(func_desc->GetSubgraphNameByInstanceName(subgraph->GetName(), subgraph_name),
|
|
|
|
|
|
|
|
"Subgraph: %s get subgraph name failed.", subgraph->GetName().c_str());
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
auto graph = GraphUtils::CreateGraphFromComputeGraph(subgraph);
|
|
|
|
|
|
|
|
Status ret = FAILED;
|
|
|
|
|
|
|
|
if (post_func != nullptr) {
|
|
|
|
|
|
|
|
ret = post_func(subgraph_name, graph);
|
|
|
|
|
|
|
|
} else if (parse_func_v2 != nullptr) {
|
|
|
|
|
|
|
|
ret = parse_func_v2(subgraph_name.c_str(), graph);
|
|
|
|
|
|
|
|
}
|
|
|
|
|
|
|
|
if (ret != SUCCESS) {
|
|
|
|
|
|
|
|
GELOGE(FAILED, "Failed to post-process subgraph %s on node %s type %s", graph.GetName().c_str(),
|
|
|
|
|
|
|
|
case_node_->GetName().c_str(), case_node_->GetType().c_str());
|
|
|
|
|
|
|
|
return FAILED;
|
|
|
|
return FAILED;
|
|
|
|
}
|
|
|
|
}
|
|
|
|
}
|
|
|
|
}
|
|
|
|