|
|
|
@ -16,11 +16,13 @@
|
|
|
|
|
|
|
|
|
|
#include "pre_activate/ascend/format_type/insert_trans_op.h"
|
|
|
|
|
#include <memory>
|
|
|
|
|
#include <vector>
|
|
|
|
|
#include "utils/utils.h"
|
|
|
|
|
#include "pre_activate/ascend/ascend_helper.h"
|
|
|
|
|
#include "session/anf_runtime_algorithm.h"
|
|
|
|
|
#include "device/kernel_info.h"
|
|
|
|
|
#include "kernel/oplib/oplib.h"
|
|
|
|
|
#include "utils/context/ms_context.h"
|
|
|
|
|
|
|
|
|
|
namespace mindspore {
|
|
|
|
|
namespace opt {
|
|
|
|
@ -30,6 +32,15 @@ const BaseRef InsertTransOp::DefinePattern() const {
|
|
|
|
|
return VectorRef({V, Xs});
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
bool IsGraphOutput(const AnfNodePtr &node, const std::vector<AnfNodePtr> &outputs) {
|
|
|
|
|
auto iter = std::find(outputs.begin(), outputs.end(), node);
|
|
|
|
|
if (iter != outputs.end()) {
|
|
|
|
|
return true;
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
return false;
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
const AnfNodePtr InsertTransOp::Process(const FuncGraphPtr &func_graph, const AnfNodePtr &node,
|
|
|
|
|
const EquivPtr &) const {
|
|
|
|
|
if (node == nullptr || !AnfAlgo::IsRealKernel(node)) {
|
|
|
|
@ -38,6 +49,13 @@ const AnfNodePtr InsertTransOp::Process(const FuncGraphPtr &func_graph, const An
|
|
|
|
|
AnfAlgo::SetNodeAttr(kAttrVisited, MakeValue(true), node);
|
|
|
|
|
MS_LOG(DEBUG) << "====process op: " << node->DebugString();
|
|
|
|
|
AnfNodePtr new_node = InsertTransOpForInput(func_graph, node, kernel_select_);
|
|
|
|
|
auto ms_context = MsContext::GetInstance();
|
|
|
|
|
MS_EXCEPTION_IF_NULL(ms_context);
|
|
|
|
|
if (ms_context->execution_mode() == kPynativeMode) {
|
|
|
|
|
if (IsGraphOutput(node, AnfAlgo::GetAllOutput(func_graph->output(), {prim::kPrimTupleGetItem}))) {
|
|
|
|
|
return new_node;
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
return InsertTransOpForOutput(func_graph, new_node, kernel_select_);
|
|
|
|
|
}
|
|
|
|
|
} // namespace opt
|
|
|
|
|