From 075da9a4b1ad1fe3c6b8ef00558a0f7bc483f037 Mon Sep 17 00:00:00 2001 From: lvliang Date: Tue, 16 Jun 2020 11:02:14 +0800 Subject: [PATCH] pynative-insert-transdata-for-hook-mode --- .../ccsrc/pre_activate/ascend/format_type/insert_trans_op.cc | 2 +- mindspore/ccsrc/utils/context/ms_context.cc | 1 + mindspore/ccsrc/utils/context/ms_context.h | 4 ++++ mindspore/ccsrc/vm/transform.cc | 5 +++++ 4 files changed, 11 insertions(+), 1 deletion(-) diff --git a/mindspore/ccsrc/pre_activate/ascend/format_type/insert_trans_op.cc b/mindspore/ccsrc/pre_activate/ascend/format_type/insert_trans_op.cc index 95d29eeae0..953f464431 100644 --- a/mindspore/ccsrc/pre_activate/ascend/format_type/insert_trans_op.cc +++ b/mindspore/ccsrc/pre_activate/ascend/format_type/insert_trans_op.cc @@ -51,7 +51,7 @@ const AnfNodePtr InsertTransOp::Process(const FuncGraphPtr &func_graph, const An AnfNodePtr new_node = InsertTransOpForInput(func_graph, node, kernel_select_); auto ms_context = MsContext::GetInstance(); MS_EXCEPTION_IF_NULL(ms_context); - if (ms_context->execution_mode() == kPynativeMode) { + if (ms_context->execution_mode() == kPynativeMode && !ms_context->enable_pynative_hook()) { if (IsGraphOutput(node, AnfAlgo::GetAllOutput(func_graph->output(), {prim::kPrimTupleGetItem}))) { return new_node; } diff --git a/mindspore/ccsrc/utils/context/ms_context.cc b/mindspore/ccsrc/utils/context/ms_context.cc index feb8ed6092..37d11264b6 100644 --- a/mindspore/ccsrc/utils/context/ms_context.cc +++ b/mindspore/ccsrc/utils/context/ms_context.cc @@ -74,6 +74,7 @@ MsContext::MsContext(const std::string &policy, const std::string &target) { precompile_only_ = false; auto_mixed_precision_flag_ = false; enable_pynative_infer_ = false; + enable_pynative_hook_ = false; enable_dynamic_mem_pool_ = true; graph_memory_max_size_ = "0"; variable_memory_max_size_ = "0"; diff --git a/mindspore/ccsrc/utils/context/ms_context.h b/mindspore/ccsrc/utils/context/ms_context.h index 7a3da24acb..cfedefe3d5 100644 --- a/mindspore/ccsrc/utils/context/ms_context.h +++ b/mindspore/ccsrc/utils/context/ms_context.h @@ -64,6 +64,9 @@ class MsContext { bool enable_pynative_infer() const { return enable_pynative_infer_; } void set_enable_pynative_infer(bool enable_pynative_infer) { enable_pynative_infer_ = enable_pynative_infer; } + bool enable_pynative_hook() const { return enable_pynative_hook_; } + void set_enable_pynative_hook(bool enable_pynative_hook) { enable_pynative_hook_ = enable_pynative_hook; } + bool enable_task_sink() const { return enable_task_sink_; } void set_precompile_only(bool precompile_only) { precompile_only_ = precompile_only; } @@ -161,6 +164,7 @@ class MsContext { uint32_t device_id_; int execution_mode_; bool enable_pynative_infer_; + bool enable_pynative_hook_; bool save_graphs_flag_; std::string save_graphs_path_; uint32_t tsd_ref_; diff --git a/mindspore/ccsrc/vm/transform.cc b/mindspore/ccsrc/vm/transform.cc index d65363bb5c..e33a68e1c5 100644 --- a/mindspore/ccsrc/vm/transform.cc +++ b/mindspore/ccsrc/vm/transform.cc @@ -277,6 +277,11 @@ bool CompileGraph::IsCut(const AnfNodePtr &node) { for (auto &prim : cut_list_) { MS_EXCEPTION_IF_NULL(prim); if (prim->name() == node_prim->name()) { + if (prim->name() == prim::kPrimBpropCut->name()) { + auto ms_context = MsContext::GetInstance(); + MS_EXCEPTION_IF_NULL(ms_context); + ms_context->set_enable_pynative_hook(true); + } return true; } }