CondRemovePass

pull/984/head
lianghao 4 years ago
parent 435795139f
commit b598ea75cd

@ -234,7 +234,7 @@ Status CondRemovePass::ReplaceIfCaseNodeWithPartitioncall(const NodePtr &node, c
const auto &output_desc_size = node->GetOpDesc()->GetOutputsSize();
// Create subgraph opdesc & node
auto partitioncall_opdesc =
CreateSubgraphOpDesc(save_branch->GetName(), input_desc_size - kConditionIndexNum, output_desc_size);
CreateSubgraphOpDesc(node, save_branch->GetName(), input_desc_size - kConditionIndexNum, output_desc_size);
auto partitioncall_node = node->GetOwnerComputeGraph()->AddNode(partitioncall_opdesc);
// Link node's peerout anchors to new node's inanchors
for (const auto &input_anchor : node->GetAllInAnchors()) {
@ -289,7 +289,8 @@ Status CondRemovePass::ReplaceIfCaseNodeWithPartitioncall(const NodePtr &node, c
/// @param [in] output_num
/// @return OpDescPtr
///
OpDescPtr CondRemovePass::CreateSubgraphOpDesc(const std::string &name, size_t input_num, size_t output_num) {
OpDescPtr CondRemovePass::CreateSubgraphOpDesc(const NodePtr &node, const std::string &name, size_t input_num,
size_t output_num) {
OpDescBuilder op_desc_builder(name, PARTITIONEDCALL);
op_desc_builder.AddDynamicInput("args", input_num).AddDynamicOutput("output", output_num);
@ -299,6 +300,16 @@ OpDescPtr CondRemovePass::CreateSubgraphOpDesc(const std::string &name, size_t i
size_t index = op_desc->GetSubgraphInstanceNames().size();
op_desc->AddSubgraphName("f");
op_desc->SetSubgraphInstanceName(static_cast<uint32_t>(index), name);
auto node_desc = node->GetOpDesc();
GE_CHECK_NOTNULL_EXEC(node_desc, return nullptr);
for (size_t i = 0; i < input_num; ++i) {
(void)op_desc->UpdateInputDesc(i, node_desc->GetInputDesc(i + 1));
}
for (size_t i = 0; i < output_num; ++i) {
(void)op_desc->UpdateOutputDesc(i, node_desc->GetOutputDesc(i));
}
return op_desc;
}

@ -70,7 +70,7 @@ class CondRemovePass : public BaseNodePass {
///
Status ReplaceIfCaseNodeWithPartitioncall(const NodePtr &node, const ComputeGraphPtr &save_branch);
OpDescPtr CreateSubgraphOpDesc(const std::string &name, size_t input_num, size_t output_num);
OpDescPtr CreateSubgraphOpDesc(const NodePtr &node, const std::string &name, size_t input_num, size_t output_num);
int32_t GetCondIndex(const ConstGeTensorPtr &tensor);
};

Loading…
Cancel
Save