handle mem leak in pynative mode

use enter and leave construct
pull/9408/head
chenfei 4 years ago
parent 33faba7100
commit 1f20e552ed

@ -304,7 +304,7 @@ void ExecutorPy::DelNetRes(const std::string &id) {
void ExecutorPy::ClearRes() { void ExecutorPy::ClearRes() {
MS_LOG(INFO) << "Clean executor resource!"; MS_LOG(INFO) << "Clean executor resource!";
Resource::ClearPrimitivePyPythonObj(); Resource::mem_cleaner().ClearPrimitivePyPythonObj();
executor_ = nullptr; executor_ = nullptr;
} }

@ -275,39 +275,78 @@ Any Resource::GetAttrPtr(const TypeId &type, const std::string &name) {
return GetMethodOrAttr(name, type_id, attr_map); return GetMethodOrAttr(name, type_id, attr_map);
} }
std::unordered_map<PrimitivePy *, bool> Resource::py_objs_ = std::unordered_map<PrimitivePy *, bool>(); MemoryCleaner Resource::mem_cleaner_ = MemoryCleaner();
void Resource::RecordPrimitivePy(PrimitivePy *prim) { void MemoryCleaner::RecordPrimitivePy(PrimitivePy *prim) {
if (prim == nullptr) { if (prim == nullptr) {
return; return;
} }
py_objs_[prim] = true; all_primitives_[prim] = true;
} }
void Resource::ErasePrimitivePy(PrimitivePy *prim) { void MemoryCleaner::ErasePrimitivePy(PrimitivePy *prim) {
if (prim == nullptr) { if (prim == nullptr) {
return; return;
} }
auto it = py_objs_.find(prim); auto it = all_primitives_.find(prim);
if (it == py_objs_.end()) { if (it == all_primitives_.end()) {
return; return;
} }
// If flag is false,the pointer hased been released, so it can't be visited. // If flag is false,the pointer hased been released, so it can't be visited.
if (!it->second) { if (!it->second) {
return; return;
} }
py_objs_[prim] = false; all_primitives_[prim] = false;
prim->SetPyObj(py::none()); prim->SetPyObj(py::none());
} }
void Resource::ClearPrimitivePyPythonObj() { void MemoryCleaner::ClearPrimitivePyPythonObj() {
for (auto &it : py_objs_) { for (auto &it : all_primitives_) {
if (it.second) { if (it.second) {
it.first->SetPyObj(py::none()); it.first->SetPyObj(py::none());
} }
} }
py_objs_.clear(); all_primitives_.clear();
} }
// Remembers `prim` in the set of pynative short-life primitives so that its python object can be released later by
// ClearPynativeShortLifePrimitivePy(). Null pointers and primitives that are already recorded are ignored.
void MemoryCleaner::RecordPynativeShortLifePrimitivePy(PrimitivePy *prim) {
  if (prim == nullptr || pynative_short_life_primitives_.count(prim) != 0) {
    return;
  }
  MS_LOG(DEBUG) << "Record pynative tmp primitve:" << prim->ToString();
  (void)pynative_short_life_primitives_.insert(prim);
}
// Releases the python object held by a primitive that was recorded as a pynative short-life primitive.
//
// Nothing happens when `prim` is null, was never recorded as short-life, or is no longer marked alive in
// all_primitives_. The entry itself stays in pynative_short_life_primitives_; it is dropped by
// ClearPynativeShortLifePrimitivePy().
void MemoryCleaner::ErasePynativeShortLifePrimitivePy(PrimitivePy *prim) {
  if (prim == nullptr) {
    return;
  }
  if (pynative_short_life_primitives_.find(prim) == pynative_short_life_primitives_.end()) {
    return;
  }
  // ErasePrimitivePy() only flips the flag in all_primitives_ and never removes the pointer from the short-life
  // set, so the set may still contain a pointer whose object has already been released. Check the flag before
  // dereferencing `prim` (the log below calls prim->ToString()).
  auto alive_it = all_primitives_.find(prim);
  if (alive_it == all_primitives_.end() || !alive_it->second) {
    return;
  }
  MS_LOG(DEBUG) << "Erase pynative tmp primitive:" << prim->ToString();
  ErasePrimitivePy(prim);
}
// Runs ErasePynativeShortLifePrimitivePy() on every recorded short-life primitive and then empties the set.
void MemoryCleaner::ClearPynativeShortLifePrimitivePy() {
  for (auto iter = pynative_short_life_primitives_.begin(); iter != pynative_short_life_primitives_.end(); ++iter) {
    ErasePynativeShortLifePrimitivePy(*iter);
  }
  pynative_short_life_primitives_.clear();
}
// Sets the flag that marks a pynative construct process as running.
void MemoryCleaner::EnterPynativeConstructProcess() {
  pynative_in_construct_process_ = true;
}

// Resets the construct flag first, then releases every primitive recorded as short-life.
void MemoryCleaner::LeavePynativeConstructProcess() {
  pynative_in_construct_process_ = false;
  ClearPynativeShortLifePrimitivePy();
}

// Returns the current value of the construct flag.
bool MemoryCleaner::IsInPynativeConstructProcess() const {
  return pynative_in_construct_process_;
}
// Sets the flag that marks the pynative end-graph process as running.
void MemoryCleaner::EnterPynativeEndGraphProcess() {
  pynative_in_end_graph_process_ = true;
}

// Resets the end-graph flag.
void MemoryCleaner::LeavePynativeEndGraphProcess() {
  pynative_in_end_graph_process_ = false;
}

// Returns the current value of the end-graph flag.
bool MemoryCleaner::IsInPynativeEndGraphProcess() const {
  return pynative_in_end_graph_process_;
}
void Resource::Clean() { void Resource::Clean() {
// AbstractTensor->elements() will be saved in AbstractBasePtrList // AbstractTensor->elements() will be saved in AbstractBasePtrList
args_spec_.clear(); args_spec_.clear();

@ -22,6 +22,7 @@
#include <string> #include <string>
#include <unordered_map> #include <unordered_map>
#include <memory> #include <memory>
#include <unordered_set>
#include "pybind11/pybind11.h" #include "pybind11/pybind11.h"
#include "pybind11/stl.h" #include "pybind11/stl.h"
@ -52,6 +53,34 @@ BuiltInTypeMap &GetMethodMap();
BuiltInTypeMap &GetAttrMap(); BuiltInTypeMap &GetAttrMap();
// Bookkeeping helper used to release the python objects held by PrimitivePy instances.
// It records every PrimitivePy, additionally records the ones created during a pynative construct process, and keeps
// two flags describing whether a pynative construct / end-graph process is currently running.
class MemoryCleaner {
public:
MemoryCleaner() = default;
~MemoryCleaner() = default;
// Records `prim` as alive. A null pointer is ignored.
void RecordPrimitivePy(PrimitivePy *prim);
// Marks `prim` as released and resets its python object to py::none(). Does nothing for a null pointer, an
// unrecorded pointer, or a pointer already marked as released.
void ErasePrimitivePy(PrimitivePy *prim);
// Resets the python object of every primitive still marked alive, then forgets all recorded primitives.
void ClearPrimitivePyPythonObj();
// Records `prim` as a short-life primitive. Null pointers and already recorded pointers are ignored.
void RecordPynativeShortLifePrimitivePy(PrimitivePy *prim);
// Releases the python object of `prim` through ErasePrimitivePy() if `prim` was recorded as short-life.
void ErasePynativeShortLifePrimitivePy(PrimitivePy *prim);
// Erases every recorded short-life primitive and empties the short-life set.
void ClearPynativeShortLifePrimitivePy();
// Sets the "in construct process" flag.
void EnterPynativeConstructProcess();
// Resets the "in construct process" flag and clears the short-life primitives.
void LeavePynativeConstructProcess();
bool IsInPynativeConstructProcess() const;
// Sets the "in end graph process" flag.
void EnterPynativeEndGraphProcess();
// Resets the "in end graph process" flag.
void LeavePynativeEndGraphProcess();
bool IsInPynativeEndGraphProcess() const;
private:
// Every recorded PrimitivePy, mapped to a flag: true while the pointer may be visited, false once it was erased.
std::unordered_map<PrimitivePy *, bool> all_primitives_;
// PrimitivePy objects that were created in the pynative construct process. These primitives should be released
// after the construct has finished.
std::unordered_set<PrimitivePy *> pynative_short_life_primitives_;
bool pynative_in_construct_process_{false};
bool pynative_in_end_graph_process_{false};
};
class Resource : public ResourceBase { class Resource : public ResourceBase {
public: public:
explicit Resource(const py::object &obj = py::none()); explicit Resource(const py::object &obj = py::none());
@ -80,13 +109,11 @@ class Resource : public ResourceBase {
} }
bool gpu_loopsink_flag() { return gpu_loopsink_flag_; } bool gpu_loopsink_flag() { return gpu_loopsink_flag_; }
int64_t gpu_loopsink_size() { return gpu_loopsink_size_; } int64_t gpu_loopsink_size() { return gpu_loopsink_size_; }
static void RecordPrimitivePy(PrimitivePy *prim);
static void ErasePrimitivePy(PrimitivePy *prim);
static void ClearPrimitivePyPythonObj();
// Reclaim resource and clear the cache. // Reclaim resource and clear the cache.
// ExecutorPy::Compile() can be called multiple times, so cache // ExecutorPy::Compile() can be called multiple times, so cache
// should be cleared. // should be cleared.
void Clean(); void Clean();
static MemoryCleaner &mem_cleaner() { return mem_cleaner_; }
private: private:
abstract::AnalysisEnginePtr engine_; abstract::AnalysisEnginePtr engine_;
@ -96,7 +123,8 @@ class Resource : public ResourceBase {
bool is_cleaned_; bool is_cleaned_;
bool gpu_loopsink_flag_{false}; bool gpu_loopsink_flag_{false};
int64_t gpu_loopsink_size_{1}; int64_t gpu_loopsink_size_{1};
static std::unordered_map<PrimitivePy *, bool> py_objs_; // Used to handle mem leak objects.
static MemoryCleaner mem_cleaner_;
}; };
using ResourcePtr = std::shared_ptr<pipeline::Resource>; using ResourcePtr = std::shared_ptr<pipeline::Resource>;

@ -2476,7 +2476,12 @@ void PynativeExecutor::NewGraph(const py::object &cell, const py::args &args) {
} }
void PynativeExecutor::EndGraph(const py::object &cell, const py::object &out, const py::args &args) { void PynativeExecutor::EndGraph(const py::object &cell, const py::object &out, const py::args &args) {
MS_LOG(DEBUG) << "Enter end graph process.";
auto &mem_cleaner = pipeline::Resource::mem_cleaner();
mem_cleaner.EnterPynativeEndGraphProcess();
PynativeExecutorTry(this, &PynativeExecutor::EndGraphInner, cell, out, args); PynativeExecutorTry(this, &PynativeExecutor::EndGraphInner, cell, out, args);
mem_cleaner.LeavePynativeEndGraphProcess();
MS_LOG(DEBUG) << "Leave end graph process.";
} }
void PynativeExecutor::GradNet(const GradOperationPtr &grad, const py::object &cell, const py::object &weights, void PynativeExecutor::GradNet(const GradOperationPtr &grad, const py::object &cell, const py::object &weights,
@ -2491,6 +2496,24 @@ void PynativeExecutor::Sync() {
session->SyncStream(); session->SyncStream();
} }
// Called before a cell runs its construct function. Only the first cell seen while no top cell is set is remembered
// as the top cell; for that cell the memory cleaner is switched into the pynative construct process. Calls made
// while a top cell is already set are ignored.
void PynativeExecutor::EnterConstruct(const py::object &cell) {
  if (top_cell_ == nullptr) {
    top_cell_ = cell.ptr();
    pipeline::Resource::mem_cleaner().EnterPynativeConstructProcess();
    MS_LOG(DEBUG) << "Enter construct process.";
  }
}
// Called after a cell has run its construct function. Only the cell currently remembered as the top cell ends the
// pynative construct process in the memory cleaner; calls for any other cell are ignored.
void PynativeExecutor::LeaveConstruct(const py::object &cell) {
  const bool is_top_cell = (top_cell_ == cell.ptr());
  if (is_top_cell) {
    top_cell_ = nullptr;
    pipeline::Resource::mem_cleaner().LeavePynativeConstructProcess();
    MS_LOG(DEBUG) << "Leave construct process.";
  }
}
REGISTER_PYBIND_DEFINE(PynativeExecutor_, ([](const py::module *m) { REGISTER_PYBIND_DEFINE(PynativeExecutor_, ([](const py::module *m) {
(void)py::class_<PynativeExecutor, std::shared_ptr<PynativeExecutor>>(*m, "PynativeExecutor_") (void)py::class_<PynativeExecutor, std::shared_ptr<PynativeExecutor>>(*m, "PynativeExecutor_")
.def_static("get_instance", &PynativeExecutor::GetInstance, "PynativeExecutor get_instance.") .def_static("get_instance", &PynativeExecutor::GetInstance, "PynativeExecutor get_instance.")
@ -2502,6 +2525,10 @@ REGISTER_PYBIND_DEFINE(PynativeExecutor_, ([](const py::module *m) {
.def("sync", &PynativeExecutor::Sync, "pynative sync stream.") .def("sync", &PynativeExecutor::Sync, "pynative sync stream.")
.def("__call__", &PynativeExecutor::Run, "pynative executor run grad graph.") .def("__call__", &PynativeExecutor::Run, "pynative executor run grad graph.")
.def("set_grad_flag", &PynativeExecutor::set_grad_flag, py::arg("flag") = py::bool_(false), .def("set_grad_flag", &PynativeExecutor::set_grad_flag, py::arg("flag") = py::bool_(false),
"Executor set grad flag."); "Executor set grad flag.")
.def("enter_construct", &PynativeExecutor::EnterConstruct,
"Do something before enter construct function.")
.def("leave_construct", &PynativeExecutor::LeaveConstruct,
"Do something after leave construct function.");
})); }));
} // namespace mindspore::pynative } // namespace mindspore::pynative

@ -108,6 +108,8 @@ class PynativeExecutor : public std::enable_shared_from_this<PynativeExecutor> {
bool need_replace_forward() const { return need_replace_forward_; } bool need_replace_forward() const { return need_replace_forward_; }
bool grad_flag() const { return grad_flag_; } bool grad_flag() const { return grad_flag_; }
void set_grad_flag(bool flag) { grad_flag_ = flag; } void set_grad_flag(bool flag) { grad_flag_ = flag; }
void EnterConstruct(const py::object &cell);
void LeaveConstruct(const py::object &cell);
py::tuple RunOpInner(const OpExecInfoPtr &op_exec_info); py::tuple RunOpInner(const OpExecInfoPtr &op_exec_info);
OpExecInfoPtr GenerateOpExecInfo(const py::args &args); OpExecInfoPtr GenerateOpExecInfo(const py::args &args);
@ -263,6 +265,12 @@ class PynativeExecutor : public std::enable_shared_from_this<PynativeExecutor> {
bool dynamic_cell_{false}; bool dynamic_cell_{false};
bool grad_is_running_{false}; bool grad_is_running_{false};
bool need_replace_forward_{true}; bool need_replace_forward_{true};
// Pointer to the top python Cell object, which is always the network (a class inheriting from Cell) run in the
// python test script, such as Resnet50(Cell) or LeNet(Cell). This pointer is used to distinguish temporary
// primitives from global primitives in order to control memory release. Global primitives are always created in
// the top cell's '__init__' function, and temporary primitives are always created elsewhere. Temporary primitives
// are released after the top cell's 'construct' function has executed, but global primitives are not.
PyObject *top_cell_{nullptr};
// Used for construct grad graph // Used for construct grad graph
FuncGraphPtr curr_g_{nullptr}; FuncGraphPtr curr_g_{nullptr};

@ -54,9 +54,18 @@ std::map<std::string, py::object> PrimitivePy::hook_grad_;
PrimitivePy::PrimitivePy(const py::str &name, const py::object &python_obj) PrimitivePy::PrimitivePy(const py::str &name, const py::object &python_obj)
: Primitive(name, false), python_obj_(python_obj), signatures_() { : Primitive(name, false), python_obj_(python_obj), signatures_() {
pipeline::Resource::RecordPrimitivePy(this); auto &mem_cleaner = pipeline::Resource::mem_cleaner();
mem_cleaner.RecordPrimitivePy(this);
if (mem_cleaner.IsInPynativeConstructProcess() && !mem_cleaner.IsInPynativeEndGraphProcess()) {
mem_cleaner.RecordPynativeShortLifePrimitivePy(this);
}
}
PrimitivePy::~PrimitivePy() {
  // Tell the memory cleaner that this primitive is gone, so that its released flag is updated and a later clear of
  // the recorded primitives does not visit this pointer.
  auto &cleaner = pipeline::Resource::mem_cleaner();
  cleaner.ErasePrimitivePy(this);
  MS_LOG(DEBUG) << "Release:" << ToString();
}
PrimitivePy::~PrimitivePy() { pipeline::Resource::ErasePrimitivePy(this); }
void PrimitivePy::SetPyObj(const py::object &obj) { python_obj_ = obj; } void PrimitivePy::SetPyObj(const py::object &obj) { python_obj_ = obj; }
void PrimitivePy::set_signatures(const std::vector<Signature> &signatures) { void PrimitivePy::set_signatures(const std::vector<Signature> &signatures) {
signatures_ = signatures; signatures_ = signatures;

@ -321,6 +321,12 @@ class _PynativeExecutor:
def set_grad_flag(self, flag): def set_grad_flag(self, flag):
self._executor.set_grad_flag(flag) self._executor.set_grad_flag(flag)
def enter_construct(self, cell):
    """Forward `cell` to the backend executor's `enter_construct` hook."""
    backend = self._executor
    backend.enter_construct(cell)
def leave_construct(self, cell):
    """Forward `cell` to the backend executor's `leave_construct` hook."""
    backend = self._executor
    backend.leave_construct(cell)
def __call__(self, obj, *args, **kwargs): def __call__(self, obj, *args, **kwargs):
args = args + tuple(kwargs.values()) args = args + tuple(kwargs.values())
return self._executor(obj, args, "") return self._executor(obj, args, "")

@ -352,9 +352,13 @@ class Cell(Cell_):
if not cast_inputs: if not cast_inputs:
cast_inputs = inputs cast_inputs = inputs
if self.enable_hook: if self.enable_hook:
_pynative_exec.enter_construct(self)
output = self._hook_construct(*cast_inputs, **kwargs) output = self._hook_construct(*cast_inputs, **kwargs)
_pynative_exec.leave_construct(self)
else: else:
_pynative_exec.enter_construct(self)
output = self.construct(*cast_inputs, **kwargs) output = self.construct(*cast_inputs, **kwargs)
_pynative_exec.leave_construct(self)
if isinstance(output, Parameter): if isinstance(output, Parameter):
output = output.data output = output.data
if self.requires_grad is True: if self.requires_grad is True:

@ -19,7 +19,6 @@ import pytest
import mindspore.context as context import mindspore.context as context
import mindspore.nn as nn import mindspore.nn as nn
from mindspore import Tensor from mindspore import Tensor
from mindspore.common.api import ms_function
from mindspore.ops.operations import _grad_ops as G from mindspore.ops.operations import _grad_ops as G
from mindspore.ops.composite import GradOperation from mindspore.ops.composite import GradOperation
@ -29,7 +28,6 @@ class NetSigmoidGrad(nn.Cell):
super(NetSigmoidGrad, self).__init__() super(NetSigmoidGrad, self).__init__()
self.sigmoid_grad = G.SigmoidGrad() self.sigmoid_grad = G.SigmoidGrad()
@ms_function
def construct(self, y, dy): def construct(self, y, dy):
return self.sigmoid_grad(y, dy) return self.sigmoid_grad(y, dy)
@ -40,7 +38,6 @@ class Grad(nn.Cell):
self.grad = GradOperation(get_all=True, sens_param=True) self.grad = GradOperation(get_all=True, sens_param=True)
self.network = network self.network = network
@ms_function
def construct(self, y, y_grad, dout): def construct(self, y, y_grad, dout):
return self.grad(self.network)(y, y_grad, dout) return self.grad(self.network)(y, y_grad, dout)

Loading…
Cancel
Save