remove data_pass

pull/704/head
wjm 4 years ago
parent f91fdad16a
commit 113b80da8b

@ -185,17 +185,13 @@ Status DataPass::Run(ComputeGraphPtr compute_graph) {
const auto &parent_graph = compute_graph->GetParentGraph();
GE_CHECK_NOTNULL(parent_graph);
bool flag = false;
(void)AttrUtils::GetBool(compute_graph, "_no_reset_name", flag);
if (!flag) {
for (const NodePtr &node : compute_graph->GetDirectNode()) {
GE_CHECK_NOTNULL(node->GetOpDesc());
if ((node->GetType() == VARIABLE) || (node->GetType() == VARIABLEV2) || (node->GetType() == NETOUTPUT)) {
continue;
}
node->GetOpDesc()->SetName(parent_node->GetName() + "_" + compute_graph->GetName() + "/" + node->GetName());
for (const NodePtr &node : compute_graph->GetDirectNode()) {
GE_CHECK_NOTNULL(node->GetOpDesc());
if ((node->GetType() == VARIABLE) || (node->GetType() == VARIABLEV2) || (node->GetType() == NETOUTPUT)) {
continue;
}
node->GetOpDesc()->SetName(parent_node->GetName() + "_" + compute_graph->GetName() + "/" + node->GetName());
}
return PostParseSubgraph(compute_graph, subgraph_name, parent_node);

@ -503,12 +503,24 @@ Status MultiBatchClonePass::SetMaxShapeToData(const NodePtr &data) {
///
/// @ingroup ge
/// @brief Set shape to Data node in branch.
/// @param [in] const NodePtr &data: data in branch.
/// @brief Update Data node in Subgraph.
/// @param [in] const NodePtr &data: data in Subgraph.
/// @param [in] size_t index: The batch index.
/// @return 0: SUCCESS / others: FAILED
///
Status MultiBatchClonePass::UpdateShapeToData(const NodePtr &data, size_t index) {
Status MultiBatchClonePass::UpdateSubgraphData(const NodePtr &data, size_t index) {
int node_index = -1;
if (!AttrUtils::GetInt(data->GetOpDesc(), ATTR_NAME_INDEX, node_index)) {
GELOGE(FAILED, "Failed to get index from data[%s]", data->GetName().c_str());
return FAILED;
}
int parent_index = node_index + 1;
if (!AttrUtils::SetInt(data->GetOpDesc(), ATTR_NAME_PARENT_NODE_INDEX, parent_index)) {
GELOGE(FAILED, "Failed to set parent index for node %s", data->GetName().c_str());
return FAILED;
}
auto data_shape = NodeUtils::GetOutputDesc(*data, kDataOutIndex).GetShape();
const auto &dims = data_shape.GetDims();
if (std::all_of(dims.begin(), dims.end(), [](int64_t val) { return val >= 0; })) {
@ -581,13 +593,15 @@ Status MultiBatchClonePass::CreateSubgraphs(const ComputeGraphPtr &graph, const
(void)AttrUtils::SetBool(subgraph, "_no_reset_name", true);
graph->AddSubgraph(subgraph->GetName(), subgraph);
all_branch_output_[subgraph] = subgraph->FindFirstNodeMatchType(NETOUTPUT);
GE_CHK_STATUS_RET(UpdateSubgraphOutput(all_branch_output_[subgraph]),
"Update %s failed", all_branch_output_[subgraph]->GetName().c_str());
const string key_name = "branches" + std::to_string(i);
op_desc->AddSubgraphName(key_name);
op_desc->SetSubgraphInstanceName(i, subgraph->GetName());
for (const auto &data : input_nodes) {
GE_CHK_STATUS_RET(UpdateShapeToData(data, i), "Update %s failed", subgraph->GetName().c_str());
GE_CHK_STATUS_RET(UpdateSubgraphData(data, i), "Update %s failed", subgraph->GetName().c_str());
}
}
@ -596,7 +610,28 @@ Status MultiBatchClonePass::CreateSubgraphs(const ComputeGraphPtr &graph, const
const auto &op_desc = n->GetOpDesc();
op_desc->SetName(n->GetName() + kMultiBatchNodePostfix + "0");
if (n->GetType() == DATA) {
GE_CHK_STATUS_RET(UpdateShapeToData(n, 0), "Update %s failed", branch->GetName().c_str());
GE_CHK_STATUS_RET(UpdateSubgraphData(n, 0), "Update %s failed", branch->GetName().c_str());
}
}
return SUCCESS;
}
///
/// @ingroup ge
/// @brief Update output_node in Subgraph.
/// @param [in] const NodePtr &data: output_node in Subgraph.
/// @return 0: SUCCESS / others: FAILED
///
Status MultiBatchClonePass::UpdateSubgraphOutput(const NodePtr &output_node) {
const auto &op_desc = output_node->GetOpDesc();
GE_CHECK_NOTNULL(op_desc);
for (size_t index = 0; index < op_desc->GetInputsSize(); ++index) {
GeTensorDescPtr tensor = op_desc->MutableInputDesc(index);
GE_CHECK_NOTNULL(tensor);
if (!AttrUtils::SetInt(tensor, ATTR_NAME_PARENT_NODE_INDEX, index)) {
GELOGE(FAILED, "Failed to set parent index for node %s", output_node->GetName().c_str());
return FAILED;
}
}

@ -105,12 +105,20 @@ class MultiBatchClonePass : public GraphPass {
///
/// @ingroup ge
/// @brief Set shape to Data node in branch.
/// @param [in] const NodePtr &data: data in branch.
/// @brief Update Data node in Subgraph.
/// @param [in] const NodePtr &data: data in Subgraph.
/// @param [in] size_t index: The batch index.
/// @return 0: SUCCESS / others: FAILED
///
Status UpdateShapeToData(const NodePtr &data, size_t index);
Status UpdateSubgraphData(const NodePtr &data, size_t index);
///
/// @ingroup ge
/// @brief Update output_node in Subgraph.
/// @param [in] const NodePtr &data: output_node in Subgraph.
/// @return 0: SUCCESS / others: FAILED
///
Status UpdateSubgraphOutput(const NodePtr &output_node);
///
/// @ingroup ge

@ -29,7 +29,6 @@
#include "framework/omg/omg_inner_types.h"
#include "graph/debug/ge_attr_define.h"
#include "graph/ge_context.h"
#include "graph/passes/data_pass.h"
#include "graph/passes/multi_batch_clone_pass.h"
#include "graph/passes/prune_pass.h"
#include "graph/preprocess/multi_batch_options.h"
@ -1698,7 +1697,6 @@ Status ProcessMultiBatch(ComputeGraphPtr &graph) {
if (multi_batch_with_switchn == nullptr) {
PassManager pass_manager;
GE_CHK_STATUS_RET(pass_manager.AddPass("MultiBatchClonePass", new (std::nothrow) MultiBatchClonePass));
GE_CHK_STATUS_RET(pass_manager.AddPass("DataPass", new (std::nothrow) DataPass)); // set subgraph parent index
return pass_manager.Run(graph);
}
}

Loading…
Cancel
Save