Optimize TransData insertion for PyNative mode: skip inserting output trans ops for nodes that are graph outputs

pull/1787/head
chujinjin 5 years ago
parent 5812c46ecf
commit 7465abc798

@ -16,11 +16,13 @@
#include "pre_activate/ascend/format_type/insert_trans_op.h"
#include <memory>
#include <vector>
#include "utils/utils.h"
#include "pre_activate/ascend/ascend_helper.h"
#include "session/anf_runtime_algorithm.h"
#include "device/kernel_info.h"
#include "kernel/oplib/oplib.h"
#include "utils/context/ms_context.h"
namespace mindspore {
namespace opt {
@ -30,6 +32,15 @@ const BaseRef InsertTransOp::DefinePattern() const {
return VectorRef({V, Xs});
}
bool IsGraphOutput(const AnfNodePtr &node, const std::vector<AnfNodePtr> &outputs) {
auto iter = std::find(outputs.begin(), outputs.end(), node);
if (iter != outputs.end()) {
return true;
}
return false;
}
const AnfNodePtr InsertTransOp::Process(const FuncGraphPtr &func_graph, const AnfNodePtr &node,
const EquivPtr &) const {
if (node == nullptr || !AnfAlgo::IsRealKernel(node)) {
@ -38,6 +49,13 @@ const AnfNodePtr InsertTransOp::Process(const FuncGraphPtr &func_graph, const An
AnfAlgo::SetNodeAttr(kAttrVisited, MakeValue(true), node);
MS_LOG(DEBUG) << "====process op: " << node->DebugString();
AnfNodePtr new_node = InsertTransOpForInput(func_graph, node, kernel_select_);
auto ms_context = MsContext::GetInstance();
MS_EXCEPTION_IF_NULL(ms_context);
if (ms_context->execution_mode() == kPynativeMode) {
if (IsGraphOutput(node, AnfAlgo::GetAllOutput(func_graph->output(), {prim::kPrimTupleGetItem}))) {
return new_node;
}
}
return InsertTransOpForOutput(func_graph, new_node, kernel_select_);
}
} // namespace opt

@ -21,6 +21,7 @@
#include "pre_activate/common/pass_manager.h"
#include "session/anf_runtime_algorithm.h"
#include "device/kernel_info.h"
#include "utils/context/ms_context.h"
#define private public
#define protected public
@ -103,6 +104,9 @@ TEST_F(TestHWInsertTransOp, test_insert_trans_op_for_single_output) {
* return output
*
*/
auto ms_context = MsContext::GetInstance();
MS_EXCEPTION_IF_NULL(ms_context);
ms_context->set_execution_mode(kGraphMode);
auto fg = GetSingleOutputGraph("test_insert_trans_op_for_single_output", "before", "NC1HWC0");
// Do insert_trans_op_ pass of hardware opt
auto graph_optimizer = std::make_shared<opt::GraphOptimizer>();

@ -20,6 +20,8 @@
#include "session/anf_runtime_algorithm.h"
#include "kernel/oplib/oplib.h"
#include "debug/anf_ir_dump.h"
#include "utils/context/ms_context.h"
#define private public
#define protected public
#include "pre_activate/ascend/format_type/insert_trans_op.h"
@ -91,6 +93,9 @@ TEST_F(TestHWTransdataSplit, test_transdata_split_fraz_nchw) {
* transdata = Transdata(transpose)
* return transdata
*/
auto ms_context = MsContext::GetInstance();
MS_EXCEPTION_IF_NULL(ms_context);
ms_context->set_execution_mode(kGraphMode);
FuncGraphPtr g = get_py_fun_.CallAndParseRet("test_transdata_split_fraz_nchw", "before");
std::vector<int> shp{2, 4, 8, 16};
auto x_abstract = std::make_shared<abstract::AbstractTensor>(kFloat32, shp);

@ -19,6 +19,7 @@
#include "device/kernel_info.h"
#include "session/anf_runtime_algorithm.h"
#include "kernel/oplib/oplib.h"
#include "utils/context/ms_context.h"
#define private public
#define protected public
#include "pre_activate/ascend/format_type/insert_trans_op.h"
@ -76,6 +77,9 @@ TEST_F(TestHWTransposeTransdataFusion, test_transpose_transdata_fusion) {
* transdata = Transdata(transpose)
* return transdata
*/
auto ms_context = MsContext::GetInstance();
MS_EXCEPTION_IF_NULL(ms_context);
ms_context->set_execution_mode(kGraphMode);
FuncGraphPtr g = get_py_fun_.CallAndParseRet("test_transpose_transdata_fusion", "before");
std::vector<int> shp{2, 4, 8, 16};
auto x_abstract = std::make_shared<abstract::AbstractTensor>(kFloat32, shp);

@ -30,6 +30,7 @@
#include "utils/context/ms_context.h"
#include "session/anf_runtime_algorithm.h"
#include "device/kernel_info.h"
#include "utils/context/ms_context.h"
#define private public
#define protected public
@ -71,6 +72,9 @@ TEST_F(TestHWEliminateRedundantOp, test_eliminate_5to4_4to5) {
* output = make_tuple(res)
* return output
*/
auto ms_context = MsContext::GetInstance();
MS_EXCEPTION_IF_NULL(ms_context);
ms_context->set_execution_mode(kGraphMode);
FuncGraphPtr g = getPyFun_.CallAndParseRet("test_eliminate_5to4_4to5", "before");
// Renormalize func_graph to infer and set shape and type information.
std::vector<int> shp{2, 32, 224, 224};

Loading…
Cancel
Save