add mark agnostic pass

pull/395/head
zhou_chao1993 4 years ago
parent 4d3355152d
commit 0c111a4da6

@ -16,6 +16,7 @@
#include "graph/passes/mark_agnostic_pass.h"
#include "graph/utils/node_utils.h"
#include "graph/utils/tensor_utils.h"
namespace ge {
Status MarkAgnosticPass::Run(ComputeGraphPtr graph) {
@ -47,6 +48,16 @@ Status MarkAgnosticPass::Run(ComputeGraphPtr graph) {
}
if (node_type == MERGE) {
GELOGD("Mark format agnostic and continuous for merge node %s", node->GetName().c_str());
auto in_nodes = node->GetInAllNodes();
vector<NodePtr> input_nodes(in_nodes.begin(), in_nodes.end());
/// Enter-----------+
/// +-> Merge
/// NextIteration---+
if (input_nodes.size() == 2) {
if (input_nodes[0]->GetType() == ENTER && input_nodes[1]->GetType() == NEXTITERATION) {
continue;
}
}
const OpDescPtr op_desc = node->GetOpDesc();
const GeTensorDescPtr op_tensor = op_desc->MutableOutputDesc(0);
if (op_tensor == nullptr) {

@ -117,6 +117,7 @@
#include "graph/passes/variable_op_pass.h"
#include "graph/passes/variable_prepare_op_pass.h"
#include "graph/passes/variable_ref_delete_op_pass.h"
#include "graph/passes/mark_agnostic_pass.h"
namespace ge {
@ -1700,6 +1701,7 @@ Status GraphPrepare::PrepareOptimize() {
try {
(void)original_graph_passes.AddPass("PrepareOptimize::ShapeOperateOpRemovePass", new ShapeOperateOpRemovePass);
(void)original_graph_passes.AddPass("PrepareOptimize::ReplaceTransShapePass", new ReplaceTransShapePass);
(void)original_graph_passes.AddPass("PrepareOptimize::MarkAgnosticPass" , new MarkAgnosticPass);
} catch (std::bad_alloc &e) {
GELOGE(INTERNAL_ERROR, "Add pass failed, bad memory allocation occurs.");
return INTERNAL_ERROR;

Loading…
Cancel
Save