diff --git a/mindspore/ccsrc/device/ascend/ascend_memory_manager.cc b/mindspore/ccsrc/device/ascend/ascend_memory_manager.cc
index 42830f54fa..c2373d3c7e 100644
--- a/mindspore/ccsrc/device/ascend/ascend_memory_manager.cc
+++ b/mindspore/ccsrc/device/ascend/ascend_memory_manager.cc
@@ -22,7 +22,7 @@ namespace mindspore {
 namespace device {
 namespace ascend {
 const uint64_t kAscendDeviceMemGB = 20;
-const uint64_t kAscendMemPoolGB = 5;
+const uint64_t kAscendMemPoolGB = 10;
 const uint64_t kAscendDeviceMemSize = (kAscendDeviceMemGB << 30);
 const uint64_t kAscendMemPoolSize = (kAscendMemPoolGB << 30);
 
diff --git a/mindspore/ccsrc/pipeline/pipeline.cc b/mindspore/ccsrc/pipeline/pipeline.cc
index 63920cac13..6e2c7be685 100644
--- a/mindspore/ccsrc/pipeline/pipeline.cc
+++ b/mindspore/ccsrc/pipeline/pipeline.cc
@@ -38,6 +38,7 @@
 #include "parallel/graph_util/get_parallel_info.h"
 #include "device/kernel_runtime_manager.h"
 #include "debug/trace.h"
+#include "pynative/pynative_execute.h"
 
 #if (ENABLE_GE || ENABLE_D)
 #include "pipeline/pipeline_ge.h"
@@ -829,6 +830,7 @@ void FinalizeBackend() {
 
 void ClearResAtexit() {
   MS_LOG(DEBUG) << "Pipeline clear all resource";
+  pynative::ClearPyNativeSession();
   device::KernelRuntimeManager::Instance().ClearRuntimeResource();
 
   ad::g_k_prims.clear();
diff --git a/mindspore/ccsrc/pynative/pynative_execute.cc b/mindspore/ccsrc/pynative/pynative_execute.cc
index 821a35d8fb..8d3fe4fbb7 100644
--- a/mindspore/ccsrc/pynative/pynative_execute.cc
+++ b/mindspore/ccsrc/pynative/pynative_execute.cc
@@ -44,6 +44,7 @@ const std::set<std::string> vm_operators = {"partial", "depend", "make_ref", "ze
 
 namespace mindspore {
 namespace pynative {
+static std::shared_ptr<session::SessionBasic> session = nullptr;
 inline ValuePtr PyAttrValue(const py::object &obj) {
   ValuePtr converted_ret = nullptr;
   bool converted = parse::ConvertData(obj, &converted_ret);
@@ -310,7 +311,11 @@ py::object RunOpInMs(const OpExecInfoPtr &op_exec_info, PynativeStatusCode *stat
   if (device_target != kAscendDevice && device_target != kGPUDevice) {
     MS_EXCEPTION(ArgumentError) << "Device target [" << device_target << "] is not supported in Pynative mode";
   }
-  std::shared_ptr<session::SessionBasic> session = session::SessionFactory::Get().Create(device_target);
+
+  if (session == nullptr) {
+    session = session::SessionFactory::Get().Create(device_target);
+  }
+
   MS_EXCEPTION_IF_NULL(session);
   session->Init(ms_context->device_id());
 
@@ -407,5 +412,7 @@ py::tuple RunOp(const py::args &args) {
   MS_LOG(INFO) << "RunOp end";
   return result;
 }
+
+void ClearPyNativeSession() { session = nullptr; }
 }  // namespace pynative
 }  // namespace mindspore
diff --git a/mindspore/ccsrc/pynative/pynative_execute.h b/mindspore/ccsrc/pynative/pynative_execute.h
index c64c6b4b25..65be3b2ab2 100644
--- a/mindspore/ccsrc/pynative/pynative_execute.h
+++ b/mindspore/ccsrc/pynative/pynative_execute.h
@@ -36,6 +36,9 @@ namespace py = pybind11;
 
 py::object RunOpInVM(const OpExecInfoPtr &op_exec_info, PynativeStatusCode *status);
 
 py::tuple RunOp(const py::args &args);
+
+void ClearPyNativeSession();
+
 }  // namespace pynative
 }  // namespace mindspore
diff --git a/mindspore/ccsrc/session/ascend_session.cc b/mindspore/ccsrc/session/ascend_session.cc
index b15637e7be..bd5fba6d4b 100755
--- a/mindspore/ccsrc/session/ascend_session.cc
+++ b/mindspore/ccsrc/session/ascend_session.cc
@@ -249,10 +249,23 @@ void AscendSession::RunOpExecTask(const std::shared_ptr<KernelGraph> &kernel_gra
   MS_LOG(INFO) << "Finish!";
 }
 
+bool AscendSession::GraphCacheExist(const GraphInfo &graph_info) const {
+  if (run_op_graphs_.find(graph_info) != run_op_graphs_.end()) {
+    return true;
+  }
+
+  return false;
+}
+
 void AscendSession::BuildOp(const OpRunInfo &op_run_info, const GraphInfo &graph_info,
                             const std::vector<tensor::TensorPtr> &input_tensors,
                             const std::vector<int> &tensors_mask) {
   MS_LOG(INFO) << "Build op " << op_run_info.op_name << " start !";
+  if (GraphCacheExist(graph_info)) {
+    MS_LOG(INFO) << "Build op " << op_run_info.op_name << " finish !";
+    return;
+  }
+
   // construct graph include one op
   auto graph = ConstructSingleOpGraph(op_run_info, input_tensors, tensors_mask);
   MS_EXCEPTION_IF_NULL(graph);
@@ -267,6 +280,7 @@ void AscendSession::BuildOp(const OpRunInfo &op_run_info, const GraphInfo &graph
   RunOpAdjustKernel(graph);
   BuildKernel(graph);
   run_op_graphs_[graph_info] = graph;
+  MS_LOG(INFO) << "Build op " << op_run_info.op_name << " finish !";
 }
 
 py::tuple AscendSession::RunOp(const OpRunInfo &op_run_info, const GraphInfo &graph_info,
@@ -291,7 +305,6 @@ py::tuple AscendSession::RunOp(const OpRunInfo &op_run_info, const GraphInfo &gr
   }
   py::object tuple_obj = utils::cast<PyObjectRef>(output_tensors).object_;
   py::tuple tuple_tensors = py::cast<py::tuple>(tuple_obj);
-  run_op_graphs_.clear();
   MS_LOG(INFO) << "Run op " << op_run_info.op_name << " finish!";
   return tuple_tensors;
 }
diff --git a/mindspore/ccsrc/session/ascend_session.h b/mindspore/ccsrc/session/ascend_session.h
index eec4e4ea41..1ce236c9c3 100755
--- a/mindspore/ccsrc/session/ascend_session.h
+++ b/mindspore/ccsrc/session/ascend_session.h
@@ -111,6 +111,8 @@ class AscendSession : public SessionBasic {
   std::vector<GraphType> &GetGraphOrderType(GraphId final_graph_id);
   // copy output of if and else
   void CopyOutputOfIf(GraphId false_graph_id);
+  // check if graph cache exist
+  bool GraphCacheExist(const GraphInfo &graph_info) const;
 
   // member variables
   // key is final_graph_id,value is child graph execute order of final graph
diff --git a/mindspore/ccsrc/session/session_basic.cc b/mindspore/ccsrc/session/session_basic.cc
index cb9e5c4dc9..3436d68b81 100755
--- a/mindspore/ccsrc/session/session_basic.cc
+++ b/mindspore/ccsrc/session/session_basic.cc
@@ -125,7 +125,7 @@ BaseRef CreateOneTensor(const AnfNodePtr &node, size_t output_index, const Kerne
   // if in paynative mode,data only copyed to host when user want to print data
   auto ms_context = MsContext::GetInstance();
   MS_EXCEPTION_IF_NULL(ms_context);
-  if (ms_context->enable_pynative_infer()) {
+  if (ms_context->execution_mode() == kPynativeMode) {
     tensor->set_device_address(AnfAlgo::GetMutableOutputAddr(node, output_index));
   } else if (!address->SyncDeviceToHost(trans::GetRuntimePaddingShape(node, output_index),
                                         LongToSize(tensor->data().nbytes()), tensor->data_type(),
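
Taken together, the pynative_execute.cc and ascend_session.cc changes above replace a per-RunOp session construction with a process-wide cached session plus a per-graph build cache keyed by GraphInfo, released explicitly via ClearPyNativeSession() during ClearResAtexit(). Below is a minimal standalone sketch of that caching pattern; Graph, Session, g_session, GetOrCreateSession, and ClearSession are hypothetical stand-ins invented for illustration, not MindSpore's real session::SessionBasic/KernelGraph API.

#include <iostream>
#include <memory>
#include <string>
#include <unordered_map>

// Hypothetical stand-in for KernelGraph.
struct Graph {};

// Hypothetical stand-in for the session: it keeps graphs already built,
// keyed by a graph signature (the role run_op_graphs_ plays in AscendSession).
struct Session {
  std::unordered_map<std::string, std::shared_ptr<Graph>> run_op_graphs_;

  bool GraphCacheExist(const std::string &graph_info) const {
    return run_op_graphs_.find(graph_info) != run_op_graphs_.end();
  }

  void BuildOp(const std::string &graph_info) {
    if (GraphCacheExist(graph_info)) {
      return;  // cache hit: skip the expensive single-op graph build
    }
    std::cout << "building " << graph_info << std::endl;
    run_op_graphs_[graph_info] = std::make_shared<Graph>();
  }
};

// Process-wide session, created on first use and released explicitly at
// teardown (the role of the new static `session` and ClearPyNativeSession()).
static std::shared_ptr<Session> g_session = nullptr;

Session *GetOrCreateSession() {
  if (g_session == nullptr) {
    g_session = std::make_shared<Session>();
  }
  return g_session.get();
}

void ClearSession() { g_session = nullptr; }

int main() {
  GetOrCreateSession()->BuildOp("Add(float32,float32)");
  GetOrCreateSession()->BuildOp("Add(float32,float32)");  // cache hit, no rebuild
  ClearSession();  // mirrors pynative::ClearPyNativeSession() in ClearResAtexit()
  return 0;
}

Note the ordering in ClearResAtexit(): the new pynative::ClearPyNativeSession() call is placed before KernelRuntimeManager::Instance().ClearRuntimeResource(), so the cached session releases its device resources while the runtime is still alive instead of relying on static-destructor order at process exit.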