|
|
|
@ -18,6 +18,7 @@
|
|
|
|
|
#include <memory>
|
|
|
|
|
#include <vector>
|
|
|
|
|
#include "utils/utils.h"
|
|
|
|
|
#include "backend/optimizer/common/helper.h"
|
|
|
|
|
#include "backend/optimizer/ascend/ascend_helper.h"
|
|
|
|
|
#include "backend/session/anf_runtime_algorithm.h"
|
|
|
|
|
#include "utils/ms_context.h"
|
|
|
|
@ -30,12 +31,12 @@ const BaseRef InsertTransOp::DefinePattern() const {
|
|
|
|
|
return VectorRef({V, Xs});
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
bool IsGraphOutput(const AnfNodePtr &node, const std::vector<AnfNodePtr> &outputs) {
|
|
|
|
|
bool IsGraphOutput(const AnfNodePtr &node, const FuncGraphPtr &func_graph) {
|
|
|
|
|
auto outputs = AnfAlgo::GetAllOutput(func_graph->output(), {prim::kPrimTupleGetItem});
|
|
|
|
|
auto iter = std::find(outputs.begin(), outputs.end(), node);
|
|
|
|
|
if (iter != outputs.end()) {
|
|
|
|
|
if (iter != outputs.end() && GetRealNodeNum(func_graph, node) == 1) {
|
|
|
|
|
return true;
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
return false;
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
@ -55,7 +56,7 @@ const AnfNodePtr InsertTransOp::Process(const FuncGraphPtr &func_graph, const An
|
|
|
|
|
MS_EXCEPTION_IF_NULL(ms_context);
|
|
|
|
|
if (ms_context->get_param<int>(MS_CTX_EXECUTION_MODE) == kPynativeMode &&
|
|
|
|
|
!ms_context->get_param<bool>(MS_CTX_ENABLE_PYNATIVE_HOOK)) {
|
|
|
|
|
if (IsGraphOutput(node, AnfAlgo::GetAllOutput(func_graph->output(), {prim::kPrimTupleGetItem}))) {
|
|
|
|
|
if (IsGraphOutput(node, func_graph)) {
|
|
|
|
|
return new_node;
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|