|
|
|
@ -234,7 +234,7 @@ Status CondRemovePass::ReplaceIfCaseNodeWithPartitioncall(const NodePtr &node, c
|
|
|
|
|
const auto &output_desc_size = node->GetOpDesc()->GetOutputsSize();
|
|
|
|
|
// Create subgraph opdesc & node
|
|
|
|
|
auto partitioncall_opdesc =
|
|
|
|
|
CreateSubgraphOpDesc(save_branch->GetName(), input_desc_size - kConditionIndexNum, output_desc_size);
|
|
|
|
|
CreateSubgraphOpDesc(node, save_branch->GetName(), input_desc_size - kConditionIndexNum, output_desc_size);
|
|
|
|
|
auto partitioncall_node = node->GetOwnerComputeGraph()->AddNode(partitioncall_opdesc);
|
|
|
|
|
// Link node's peerout anchors to new node's inanchors
|
|
|
|
|
for (const auto &input_anchor : node->GetAllInAnchors()) {
|
|
|
|
@ -289,7 +289,8 @@ Status CondRemovePass::ReplaceIfCaseNodeWithPartitioncall(const NodePtr &node, c
|
|
|
|
|
/// @param [in] output_num
|
|
|
|
|
/// @return OpDescPtr
|
|
|
|
|
///
|
|
|
|
|
OpDescPtr CondRemovePass::CreateSubgraphOpDesc(const std::string &name, size_t input_num, size_t output_num) {
|
|
|
|
|
OpDescPtr CondRemovePass::CreateSubgraphOpDesc(const NodePtr &node, const std::string &name, size_t input_num,
|
|
|
|
|
size_t output_num) {
|
|
|
|
|
OpDescBuilder op_desc_builder(name, PARTITIONEDCALL);
|
|
|
|
|
op_desc_builder.AddDynamicInput("args", input_num).AddDynamicOutput("output", output_num);
|
|
|
|
|
|
|
|
|
@ -299,6 +300,16 @@ OpDescPtr CondRemovePass::CreateSubgraphOpDesc(const std::string &name, size_t i
|
|
|
|
|
size_t index = op_desc->GetSubgraphInstanceNames().size();
|
|
|
|
|
op_desc->AddSubgraphName("f");
|
|
|
|
|
op_desc->SetSubgraphInstanceName(static_cast<uint32_t>(index), name);
|
|
|
|
|
|
|
|
|
|
auto node_desc = node->GetOpDesc();
|
|
|
|
|
GE_CHECK_NOTNULL_EXEC(node_desc, return nullptr);
|
|
|
|
|
for (size_t i = 0; i < input_num; ++i) {
|
|
|
|
|
(void)op_desc->UpdateInputDesc(i, node_desc->GetInputDesc(i + 1));
|
|
|
|
|
}
|
|
|
|
|
for (size_t i = 0; i < output_num; ++i) {
|
|
|
|
|
(void)op_desc->UpdateOutputDesc(i, node_desc->GetOutputDesc(i));
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
return op_desc;
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|