diff --git a/mindspore/ccsrc/pipeline/jit/pipeline.cc b/mindspore/ccsrc/pipeline/jit/pipeline.cc index fea180b7b4..817ec6509a 100644 --- a/mindspore/ccsrc/pipeline/jit/pipeline.cc +++ b/mindspore/ccsrc/pipeline/jit/pipeline.cc @@ -304,7 +304,7 @@ void ExecutorPy::DelNetRes(const std::string &id) { void ExecutorPy::ClearRes() { MS_LOG(INFO) << "Clean executor resource!"; - Resource::ClearPrimitivePyPythonObj(); + Resource::mem_cleaner().ClearPrimitivePyPythonObj(); executor_ = nullptr; } diff --git a/mindspore/ccsrc/pipeline/jit/resource.cc b/mindspore/ccsrc/pipeline/jit/resource.cc index ff3168a3b8..ece231bde4 100644 --- a/mindspore/ccsrc/pipeline/jit/resource.cc +++ b/mindspore/ccsrc/pipeline/jit/resource.cc @@ -277,39 +277,78 @@ Any Resource::GetAttrPtr(const TypeId &type, const std::string &name) { return GetMethodOrAttr(name, type_id, attr_map); } -std::unordered_map Resource::py_objs_ = std::unordered_map(); -void Resource::RecordPrimitivePy(PrimitivePy *prim) { +MemoryCleaner Resource::mem_cleaner_ = MemoryCleaner(); +void MemoryCleaner::RecordPrimitivePy(PrimitivePy *prim) { if (prim == nullptr) { return; } - py_objs_[prim] = true; + all_primitives_[prim] = true; } -void Resource::ErasePrimitivePy(PrimitivePy *prim) { +void MemoryCleaner::ErasePrimitivePy(PrimitivePy *prim) { if (prim == nullptr) { return; } - auto it = py_objs_.find(prim); - if (it == py_objs_.end()) { + auto it = all_primitives_.find(prim); + if (it == all_primitives_.end()) { return; } // If flag is false,the pointer hased been released, so it can't be visited. 
if (!it->second) { return; } - py_objs_[prim] = false; + all_primitives_[prim] = false; prim->SetPyObj(py::none()); } -void Resource::ClearPrimitivePyPythonObj() { - for (auto &it : py_objs_) { +void MemoryCleaner::ClearPrimitivePyPythonObj() { + for (auto &it : all_primitives_) { if (it.second) { it.first->SetPyObj(py::none()); } } - py_objs_.clear(); + all_primitives_.clear(); } +void MemoryCleaner::RecordPynativeShortLifePrimitivePy(PrimitivePy *prim) { + if (prim == nullptr) { + return; + } + if (pynative_short_life_primitives_.find(prim) != pynative_short_life_primitives_.end()) { + return; + } + MS_LOG(DEBUG) << "Record pynative tmp primitve:" << prim->ToString(); + pynative_short_life_primitives_.insert(prim); +} + +void MemoryCleaner::ErasePynativeShortLifePrimitivePy(PrimitivePy *prim) { + if (prim == nullptr) { + return; + } + if (pynative_short_life_primitives_.find(prim) == pynative_short_life_primitives_.end()) { + return; + } + MS_LOG(DEBUG) << "Erase pynative tmp primitive:" << prim->ToString(); + ErasePrimitivePy(prim); +} + +void MemoryCleaner::ClearPynativeShortLifePrimitivePy() { + for (auto &primitive : pynative_short_life_primitives_) { + ErasePynativeShortLifePrimitivePy(primitive); + } + pynative_short_life_primitives_.clear(); +} + +void MemoryCleaner::EnterPynativeConstructProcess() { pynative_in_construct_process_ = true; } +void MemoryCleaner::LeavePynativeConstructProcess() { + pynative_in_construct_process_ = false; + ClearPynativeShortLifePrimitivePy(); +} +bool MemoryCleaner::IsInPynativeConstructProcess() const { return pynative_in_construct_process_; } +void MemoryCleaner::EnterPynativeEndGraphProcess() { pynative_in_end_graph_process_ = true; } +void MemoryCleaner::LeavePynativeEndGraphProcess() { pynative_in_end_graph_process_ = false; } +bool MemoryCleaner::IsInPynativeEndGraphProcess() const { return pynative_in_end_graph_process_; } + void Resource::Clean() { // AbstractTensor->elements() will be saved in AbstractBasePtrList 
args_spec_.clear(); diff --git a/mindspore/ccsrc/pipeline/jit/resource.h b/mindspore/ccsrc/pipeline/jit/resource.h index ab81d8c409..ab14c3ab3b 100644 --- a/mindspore/ccsrc/pipeline/jit/resource.h +++ b/mindspore/ccsrc/pipeline/jit/resource.h @@ -22,6 +22,7 @@ #include #include #include +#include #include "pybind11/pybind11.h" #include "pybind11/stl.h" @@ -52,6 +53,34 @@ BuiltInTypeMap &GetMethodMap(); BuiltInTypeMap &GetAttrMap(); +class MemoryCleaner { + public: + MemoryCleaner() = default; + ~MemoryCleaner() = default; + void RecordPrimitivePy(PrimitivePy *prim); + void ErasePrimitivePy(PrimitivePy *prim); + void ClearPrimitivePyPythonObj(); + + void RecordPynativeShortLifePrimitivePy(PrimitivePy *prim); + void ErasePynativeShortLifePrimitivePy(PrimitivePy *prim); + void ClearPynativeShortLifePrimitivePy(); + + void EnterPynativeConstructProcess(); + void LeavePynativeConstructProcess(); + bool IsInPynativeConstructProcess() const; + void EnterPynativeEndGraphProcess(); + void LeavePynativeEndGraphProcess(); + bool IsInPynativeEndGraphProcess() const; + + private: + std::unordered_map all_primitives_; + // PrimitivePy objects that are created in the pynative construct process. These primitives should be released after construct + // finishes. + std::unordered_set pynative_short_life_primitives_; + bool pynative_in_construct_process_{false}; + bool pynative_in_end_graph_process_{false}; +}; + class Resource : public ResourceBase { public: explicit Resource(const py::object &obj = py::none()); @@ -80,13 +109,11 @@ class Resource : public ResourceBase { } bool gpu_loopsink_flag() { return gpu_loopsink_flag_; } int64_t gpu_loopsink_size() { return gpu_loopsink_size_; } - static void RecordPrimitivePy(PrimitivePy *prim); - static void ErasePrimitivePy(PrimitivePy *prim); - static void ClearPrimitivePyPythonObj(); // Reclaim resource and clear the cache. // ExecutorPy::Compile() can be called multiple times, so cache // should be cleared. 
void Clean(); + static MemoryCleaner &mem_cleaner() { return mem_cleaner_; } private: abstract::AnalysisEnginePtr engine_; @@ -96,7 +123,8 @@ class Resource : public ResourceBase { bool is_cleaned_; bool gpu_loopsink_flag_{false}; int64_t gpu_loopsink_size_{1}; - static std::unordered_map py_objs_; + // Used to handle mem leak objects. + static MemoryCleaner mem_cleaner_; }; using ResourcePtr = std::shared_ptr; diff --git a/mindspore/ccsrc/pipeline/pynative/pynative_execute.cc b/mindspore/ccsrc/pipeline/pynative/pynative_execute.cc index c65224df1f..0f539cc6e9 100644 --- a/mindspore/ccsrc/pipeline/pynative/pynative_execute.cc +++ b/mindspore/ccsrc/pipeline/pynative/pynative_execute.cc @@ -2594,7 +2594,12 @@ void PynativeExecutor::NewGraph(const py::object &cell, const py::args &args) { } void PynativeExecutor::EndGraph(const py::object &cell, const py::object &out, const py::args &args) { + MS_LOG(DEBUG) << "Enter end graph process."; + auto &mem_cleaner = pipeline::Resource::mem_cleaner(); + mem_cleaner.EnterPynativeEndGraphProcess(); PynativeExecutorTry(this, &PynativeExecutor::EndGraphInner, cell, out, args); + mem_cleaner.LeavePynativeEndGraphProcess(); + MS_LOG(DEBUG) << "Leave end graph process."; } void PynativeExecutor::GradNet(const GradOperationPtr &grad, const py::object &cell, const py::object &weights, @@ -2609,6 +2614,24 @@ void PynativeExecutor::Sync() { session->SyncStream(); } +void PynativeExecutor::EnterConstruct(const py::object &cell) { + if (top_cell_ != nullptr) { + return; + } + top_cell_ = cell.ptr(); + pipeline::Resource::mem_cleaner().EnterPynativeConstructProcess(); + MS_LOG(DEBUG) << "Enter construct process."; +} + +void PynativeExecutor::LeaveConstruct(const py::object &cell) { + if (top_cell_ != cell.ptr()) { + return; + } + top_cell_ = nullptr; + pipeline::Resource::mem_cleaner().LeavePynativeConstructProcess(); + MS_LOG(DEBUG) << "Leave construct process."; +} + REGISTER_PYBIND_DEFINE(PynativeExecutor_, ([](const py::module *m) { 
(void)py::class_>(*m, "PynativeExecutor_") .def_static("get_instance", &PynativeExecutor::GetInstance, "PynativeExecutor get_instance.") @@ -2620,6 +2643,10 @@ REGISTER_PYBIND_DEFINE(PynativeExecutor_, ([](const py::module *m) { .def("sync", &PynativeExecutor::Sync, "pynative sync stream.") .def("__call__", &PynativeExecutor::Run, "pynative executor run grad graph.") .def("set_grad_flag", &PynativeExecutor::set_grad_flag, py::arg("flag") = py::bool_(false), - "Executor set grad flag."); + "Executor set grad flag.") + .def("enter_construct", &PynativeExecutor::EnterConstruct, + "Do something before enter construct function.") + .def("leave_construct", &PynativeExecutor::LeaveConstruct, + "Do something after leave construct function."); })); } // namespace mindspore::pynative diff --git a/mindspore/ccsrc/pipeline/pynative/pynative_execute.h b/mindspore/ccsrc/pipeline/pynative/pynative_execute.h index 2252b326b6..b74a851096 100644 --- a/mindspore/ccsrc/pipeline/pynative/pynative_execute.h +++ b/mindspore/ccsrc/pipeline/pynative/pynative_execute.h @@ -111,6 +111,8 @@ class PynativeExecutor : public std::enable_shared_from_this { bool need_replace_forward() const { return need_replace_forward_; } bool grad_flag() const { return grad_flag_; } void set_grad_flag(bool flag) { grad_flag_ = flag; } + void EnterConstruct(const py::object &cell); + void LeaveConstruct(const py::object &cell); py::tuple RunOpInner(const OpExecInfoPtr &op_exec_info); OpExecInfoPtr GenerateOpExecInfo(const py::args &args); @@ -272,6 +274,12 @@ class PynativeExecutor : public std::enable_shared_from_this { bool dynamic_cell_{false}; bool grad_is_running_{false}; bool need_replace_forward_{true}; + // Pointer to the top Python Cell object, which is always the network (a class inheriting from Cell) run in the Python test script, + // such as Resnet50(Cell) or LeNet(Cell). This pointer is used to distinguish temporary primitives from global + // primitives to control memory release. 
Global primitives are always created in top cell's '__init__' function and + // temporary primitives are always created in other places. Temporary primitives will be released after executing top + // cell's 'construct' function but global primitives will not. + PyObject *top_cell_{nullptr}; // Used for construct grad graph FuncGraphPtr curr_g_{nullptr}; diff --git a/mindspore/ccsrc/pybind_api/ir/primitive_py.cc b/mindspore/ccsrc/pybind_api/ir/primitive_py.cc index d4788ef4c9..909233acd5 100644 --- a/mindspore/ccsrc/pybind_api/ir/primitive_py.cc +++ b/mindspore/ccsrc/pybind_api/ir/primitive_py.cc @@ -54,9 +54,18 @@ std::map PrimitivePy::hook_grad_; PrimitivePy::PrimitivePy(const py::str &name, const py::object &python_obj) : Primitive(name, false), python_obj_(python_obj), signatures_() { - pipeline::Resource::RecordPrimitivePy(this); + auto &mem_cleaner = pipeline::Resource::mem_cleaner(); + mem_cleaner.RecordPrimitivePy(this); + if (mem_cleaner.IsInPynativeConstructProcess() && !mem_cleaner.IsInPynativeEndGraphProcess()) { + mem_cleaner.RecordPynativeShortLifePrimitivePy(this); + } +} +PrimitivePy::~PrimitivePy() { + // Erase primitive here to set released flag false, to avoid calling released pointer when clear primitives in + // resource. 
+ pipeline::Resource::mem_cleaner().ErasePrimitivePy(this); + MS_LOG(DEBUG) << "Release:" << ToString(); } -PrimitivePy::~PrimitivePy() { pipeline::Resource::ErasePrimitivePy(this); } void PrimitivePy::SetPyObj(const py::object &obj) { python_obj_ = obj; } void PrimitivePy::set_signatures(const std::vector &signatures) { signatures_ = signatures; diff --git a/mindspore/common/api.py b/mindspore/common/api.py index 010ed26561..9a14448fd2 100644 --- a/mindspore/common/api.py +++ b/mindspore/common/api.py @@ -326,6 +326,12 @@ class _PynativeExecutor: def set_grad_flag(self, flag): self._executor.set_grad_flag(flag) + def enter_construct(self, cell): + self._executor.enter_construct(cell) + + def leave_construct(self, cell): + self._executor.leave_construct(cell) + def __call__(self, obj, *args, **kwargs): args = args + tuple(kwargs.values()) return self._executor(obj, args, "") diff --git a/mindspore/nn/cell.py b/mindspore/nn/cell.py index 5e6b69deec..b6a9d5937d 100755 --- a/mindspore/nn/cell.py +++ b/mindspore/nn/cell.py @@ -352,9 +352,13 @@ class Cell(Cell_): if not cast_inputs: cast_inputs = inputs if self.enable_hook: + _pynative_exec.enter_construct(self) output = self._hook_construct(*cast_inputs, **kwargs) + _pynative_exec.leave_construct(self) else: + _pynative_exec.enter_construct(self) output = self.construct(*cast_inputs, **kwargs) + _pynative_exec.leave_construct(self) if isinstance(output, Parameter): output = output.data if self.requires_grad is True: diff --git a/tests/st/ops/gpu/test_sigmoid_grad_grad_op.py b/tests/st/ops/gpu/test_sigmoid_grad_grad_op.py index 209fe7d135..110230f49b 100644 --- a/tests/st/ops/gpu/test_sigmoid_grad_grad_op.py +++ b/tests/st/ops/gpu/test_sigmoid_grad_grad_op.py @@ -19,7 +19,6 @@ import pytest import mindspore.context as context import mindspore.nn as nn from mindspore import Tensor -from mindspore.common.api import ms_function from mindspore.ops.operations import _grad_ops as G from mindspore.ops.composite import 
GradOperation @@ -29,7 +28,6 @@ class NetSigmoidGrad(nn.Cell): super(NetSigmoidGrad, self).__init__() self.sigmoid_grad = G.SigmoidGrad() - @ms_function def construct(self, y, dy): return self.sigmoid_grad(y, dy) @@ -40,7 +38,6 @@ class Grad(nn.Cell): self.grad = GradOperation(get_all=True, sens_param=True) self.network = network - @ms_function def construct(self, y, y_grad, dout): return self.grad(self.network)(y, y_grad, dout)