|
|
@ -16,6 +16,7 @@
|
|
|
|
#include "graph/passes/mark_agnostic_pass.h"
|
|
|
|
#include "graph/passes/mark_agnostic_pass.h"
|
|
|
|
|
|
|
|
|
|
|
|
#include "graph/utils/node_utils.h"
|
|
|
|
#include "graph/utils/node_utils.h"
|
|
|
|
|
|
|
|
#include "graph/utils/tensor_utils.h"
|
|
|
|
|
|
|
|
|
|
|
|
namespace ge {
|
|
|
|
namespace ge {
|
|
|
|
Status MarkAgnosticPass::Run(ComputeGraphPtr graph) {
|
|
|
|
Status MarkAgnosticPass::Run(ComputeGraphPtr graph) {
|
|
|
@ -47,6 +48,16 @@ Status MarkAgnosticPass::Run(ComputeGraphPtr graph) {
|
|
|
|
}
|
|
|
|
}
|
|
|
|
if (node_type == MERGE) {
|
|
|
|
if (node_type == MERGE) {
|
|
|
|
GELOGD("Mark format agnostic and continuous for merge node %s", node->GetName().c_str());
|
|
|
|
GELOGD("Mark format agnostic and continuous for merge node %s", node->GetName().c_str());
|
|
|
|
|
|
|
|
auto in_nodes = node->GetInAllNodes();
|
|
|
|
|
|
|
|
vector<NodePtr> input_nodes(in_nodes.begin(), in_nodes.end());
|
|
|
|
|
|
|
|
/// Enter-----------+
|
|
|
|
|
|
|
|
/// +-> Merge
|
|
|
|
|
|
|
|
/// NextIteration---+
|
|
|
|
|
|
|
|
if (input_nodes.size() == 2) {
|
|
|
|
|
|
|
|
if (input_nodes[0]->GetType() == ENTER && input_nodes[1]->GetType() == NEXTITERATION) {
|
|
|
|
|
|
|
|
continue;
|
|
|
|
|
|
|
|
}
|
|
|
|
|
|
|
|
}
|
|
|
|
const OpDescPtr op_desc = node->GetOpDesc();
|
|
|
|
const OpDescPtr op_desc = node->GetOpDesc();
|
|
|
|
const GeTensorDescPtr op_tensor = op_desc->MutableOutputDesc(0);
|
|
|
|
const GeTensorDescPtr op_tensor = op_desc->MutableOutputDesc(0);
|
|
|
|
if (op_tensor == nullptr) {
|
|
|
|
if (op_tensor == nullptr) {
|
|
|
|