handle mem leak in pynative mode

use enter and leave construct
pull/9408/head
chenfei 4 years ago
parent 33faba7100
commit 1f20e552ed

@ -304,7 +304,7 @@ void ExecutorPy::DelNetRes(const std::string &id) {
void ExecutorPy::ClearRes() {
MS_LOG(INFO) << "Clean executor resource!";
Resource::ClearPrimitivePyPythonObj();
Resource::mem_cleaner().ClearPrimitivePyPythonObj();
executor_ = nullptr;
}

@ -275,39 +275,78 @@ Any Resource::GetAttrPtr(const TypeId &type, const std::string &name) {
return GetMethodOrAttr(name, type_id, attr_map);
}
std::unordered_map<PrimitivePy *, bool> Resource::py_objs_ = std::unordered_map<PrimitivePy *, bool>();
void Resource::RecordPrimitivePy(PrimitivePy *prim) {
MemoryCleaner Resource::mem_cleaner_ = MemoryCleaner();
void MemoryCleaner::RecordPrimitivePy(PrimitivePy *prim) {
if (prim == nullptr) {
return;
}
py_objs_[prim] = true;
all_primitives_[prim] = true;
}
void Resource::ErasePrimitivePy(PrimitivePy *prim) {
void MemoryCleaner::ErasePrimitivePy(PrimitivePy *prim) {
if (prim == nullptr) {
return;
}
auto it = py_objs_.find(prim);
if (it == py_objs_.end()) {
auto it = all_primitives_.find(prim);
if (it == all_primitives_.end()) {
return;
}
// If flag is false,the pointer hased been released, so it can't be visited.
if (!it->second) {
return;
}
py_objs_[prim] = false;
all_primitives_[prim] = false;
prim->SetPyObj(py::none());
}
void Resource::ClearPrimitivePyPythonObj() {
for (auto &it : py_objs_) {
void MemoryCleaner::ClearPrimitivePyPythonObj() {
for (auto &it : all_primitives_) {
if (it.second) {
it.first->SetPyObj(py::none());
}
}
py_objs_.clear();
all_primitives_.clear();
}
void MemoryCleaner::RecordPynativeShortLifePrimitivePy(PrimitivePy *prim) {
if (prim == nullptr) {
return;
}
if (pynative_short_life_primitives_.find(prim) != pynative_short_life_primitives_.end()) {
return;
}
MS_LOG(DEBUG) << "Record pynative tmp primitve:" << prim->ToString();
pynative_short_life_primitives_.insert(prim);
}
void MemoryCleaner::ErasePynativeShortLifePrimitivePy(PrimitivePy *prim) {
if (prim == nullptr) {
return;
}
if (pynative_short_life_primitives_.find(prim) == pynative_short_life_primitives_.end()) {
return;
}
MS_LOG(DEBUG) << "Erase pynative tmp primitive:" << prim->ToString();
ErasePrimitivePy(prim);
}
void MemoryCleaner::ClearPynativeShortLifePrimitivePy() {
for (auto &primitive : pynative_short_life_primitives_) {
ErasePynativeShortLifePrimitivePy(primitive);
}
pynative_short_life_primitives_.clear();
}
void MemoryCleaner::EnterPynativeConstructProcess() { pynative_in_construct_process_ = true; }
void MemoryCleaner::LeavePynativeConstructProcess() {
pynative_in_construct_process_ = false;
ClearPynativeShortLifePrimitivePy();
}
bool MemoryCleaner::IsInPynativeConstructProcess() const { return pynative_in_construct_process_; }
void MemoryCleaner::EnterPynativeEndGraphProcess() { pynative_in_end_graph_process_ = true; }
void MemoryCleaner::LeavePynativeEndGraphProcess() { pynative_in_end_graph_process_ = false; }
bool MemoryCleaner::IsInPynativeEndGraphProcess() const { return pynative_in_end_graph_process_; }
void Resource::Clean() {
// AbstractTensor->elements() will be saved in AbstractBasePtrList
args_spec_.clear();

@ -22,6 +22,7 @@
#include <string>
#include <unordered_map>
#include <memory>
#include <unordered_set>
#include "pybind11/pybind11.h"
#include "pybind11/stl.h"
@ -52,6 +53,34 @@ BuiltInTypeMap &GetMethodMap();
BuiltInTypeMap &GetAttrMap();
class MemoryCleaner {
public:
MemoryCleaner() = default;
~MemoryCleaner() = default;
void RecordPrimitivePy(PrimitivePy *prim);
void ErasePrimitivePy(PrimitivePy *prim);
void ClearPrimitivePyPythonObj();
void RecordPynativeShortLifePrimitivePy(PrimitivePy *prim);
void ErasePynativeShortLifePrimitivePy(PrimitivePy *prim);
void ClearPynativeShortLifePrimitivePy();
void EnterPynativeConstructProcess();
void LeavePynativeConstructProcess();
bool IsInPynativeConstructProcess() const;
void EnterPynativeEndGraphProcess();
void LeavePynativeEndGraphProcess();
bool IsInPynativeEndGraphProcess() const;
private:
std::unordered_map<PrimitivePy *, bool> all_primitives_;
// PrimitivePy objects that created in pynative construct process.These primitives should be released after construct
// finished.
std::unordered_set<PrimitivePy *> pynative_short_life_primitives_;
bool pynative_in_construct_process_{false};
bool pynative_in_end_graph_process_{false};
};
class Resource : public ResourceBase {
public:
explicit Resource(const py::object &obj = py::none());
@ -80,13 +109,11 @@ class Resource : public ResourceBase {
}
bool gpu_loopsink_flag() { return gpu_loopsink_flag_; }
int64_t gpu_loopsink_size() { return gpu_loopsink_size_; }
static void RecordPrimitivePy(PrimitivePy *prim);
static void ErasePrimitivePy(PrimitivePy *prim);
static void ClearPrimitivePyPythonObj();
// Reclaim resource and clear the cache.
// ExecutorPy::Compile() can be called multiple times, so cache
// should be cleared.
void Clean();
static MemoryCleaner &mem_cleaner() { return mem_cleaner_; }
private:
abstract::AnalysisEnginePtr engine_;
@ -96,7 +123,8 @@ class Resource : public ResourceBase {
bool is_cleaned_;
bool gpu_loopsink_flag_{false};
int64_t gpu_loopsink_size_{1};
static std::unordered_map<PrimitivePy *, bool> py_objs_;
// Used to handle mem leak objects.
static MemoryCleaner mem_cleaner_;
};
using ResourcePtr = std::shared_ptr<pipeline::Resource>;

@ -2476,7 +2476,12 @@ void PynativeExecutor::NewGraph(const py::object &cell, const py::args &args) {
}
void PynativeExecutor::EndGraph(const py::object &cell, const py::object &out, const py::args &args) {
MS_LOG(DEBUG) << "Enter end graph process.";
auto &mem_cleaner = pipeline::Resource::mem_cleaner();
mem_cleaner.EnterPynativeEndGraphProcess();
PynativeExecutorTry(this, &PynativeExecutor::EndGraphInner, cell, out, args);
mem_cleaner.LeavePynativeEndGraphProcess();
MS_LOG(DEBUG) << "Leave end graph process.";
}
void PynativeExecutor::GradNet(const GradOperationPtr &grad, const py::object &cell, const py::object &weights,
@ -2491,6 +2496,24 @@ void PynativeExecutor::Sync() {
session->SyncStream();
}
void PynativeExecutor::EnterConstruct(const py::object &cell) {
if (top_cell_ != nullptr) {
return;
}
top_cell_ = cell.ptr();
pipeline::Resource::mem_cleaner().EnterPynativeConstructProcess();
MS_LOG(DEBUG) << "Enter construct process.";
}
void PynativeExecutor::LeaveConstruct(const py::object &cell) {
if (top_cell_ != cell.ptr()) {
return;
}
top_cell_ = nullptr;
pipeline::Resource::mem_cleaner().LeavePynativeConstructProcess();
MS_LOG(DEBUG) << "Leave construct process.";
}
REGISTER_PYBIND_DEFINE(PynativeExecutor_, ([](const py::module *m) {
(void)py::class_<PynativeExecutor, std::shared_ptr<PynativeExecutor>>(*m, "PynativeExecutor_")
.def_static("get_instance", &PynativeExecutor::GetInstance, "PynativeExecutor get_instance.")
@ -2502,6 +2525,10 @@ REGISTER_PYBIND_DEFINE(PynativeExecutor_, ([](const py::module *m) {
.def("sync", &PynativeExecutor::Sync, "pynative sync stream.")
.def("__call__", &PynativeExecutor::Run, "pynative executor run grad graph.")
.def("set_grad_flag", &PynativeExecutor::set_grad_flag, py::arg("flag") = py::bool_(false),
"Executor set grad flag.");
"Executor set grad flag.")
.def("enter_construct", &PynativeExecutor::EnterConstruct,
"Do something before enter construct function.")
.def("leave_construct", &PynativeExecutor::LeaveConstruct,
"Do something after leave construct function.");
}));
} // namespace mindspore::pynative

@ -108,6 +108,8 @@ class PynativeExecutor : public std::enable_shared_from_this<PynativeExecutor> {
bool need_replace_forward() const { return need_replace_forward_; }
bool grad_flag() const { return grad_flag_; }
void set_grad_flag(bool flag) { grad_flag_ = flag; }
void EnterConstruct(const py::object &cell);
void LeaveConstruct(const py::object &cell);
py::tuple RunOpInner(const OpExecInfoPtr &op_exec_info);
OpExecInfoPtr GenerateOpExecInfo(const py::args &args);
@ -263,6 +265,12 @@ class PynativeExecutor : public std::enable_shared_from_this<PynativeExecutor> {
bool dynamic_cell_{false};
bool grad_is_running_{false};
bool need_replace_forward_{true};
// The pointer of top python Cell object, which is always the network(inherit class Cell) ran in python test script,
// such as Resnet50(Cell),LeNet(Cell).This pointer is used to distinguish temporary primitives from global
// primitives to control memory release. Global primitives are always created in top cell's '__init__' function and
// temporary primitives are always created in other place.Temporary primitives will be released after executing top
// cell's 'construct' function but global primitives will not.
PyObject *top_cell_{nullptr};
// Used for construct grad graph
FuncGraphPtr curr_g_{nullptr};

@ -54,9 +54,18 @@ std::map<std::string, py::object> PrimitivePy::hook_grad_;
PrimitivePy::PrimitivePy(const py::str &name, const py::object &python_obj)
: Primitive(name, false), python_obj_(python_obj), signatures_() {
pipeline::Resource::RecordPrimitivePy(this);
auto &mem_cleaner = pipeline::Resource::mem_cleaner();
mem_cleaner.RecordPrimitivePy(this);
if (mem_cleaner.IsInPynativeConstructProcess() && !mem_cleaner.IsInPynativeEndGraphProcess()) {
mem_cleaner.RecordPynativeShortLifePrimitivePy(this);
}
}
PrimitivePy::~PrimitivePy() {
// Erase primitive here to set released flag false, to avoid calling released pointer when clear primitives in
// resource.
pipeline::Resource::mem_cleaner().ErasePrimitivePy(this);
MS_LOG(DEBUG) << "Release:" << ToString();
}
PrimitivePy::~PrimitivePy() { pipeline::Resource::ErasePrimitivePy(this); }
void PrimitivePy::SetPyObj(const py::object &obj) { python_obj_ = obj; }
void PrimitivePy::set_signatures(const std::vector<Signature> &signatures) {
signatures_ = signatures;

@ -321,6 +321,12 @@ class _PynativeExecutor:
def set_grad_flag(self, flag):
self._executor.set_grad_flag(flag)
def enter_construct(self, cell):
self._executor.enter_construct(cell)
def leave_construct(self, cell):
self._executor.leave_construct(cell)
def __call__(self, obj, *args, **kwargs):
args = args + tuple(kwargs.values())
return self._executor(obj, args, "")

@ -352,9 +352,13 @@ class Cell(Cell_):
if not cast_inputs:
cast_inputs = inputs
if self.enable_hook:
_pynative_exec.enter_construct(self)
output = self._hook_construct(*cast_inputs, **kwargs)
_pynative_exec.leave_construct(self)
else:
_pynative_exec.enter_construct(self)
output = self.construct(*cast_inputs, **kwargs)
_pynative_exec.leave_construct(self)
if isinstance(output, Parameter):
output = output.data
if self.requires_grad is True:

@ -19,7 +19,6 @@ import pytest
import mindspore.context as context
import mindspore.nn as nn
from mindspore import Tensor
from mindspore.common.api import ms_function
from mindspore.ops.operations import _grad_ops as G
from mindspore.ops.composite import GradOperation
@ -29,7 +28,6 @@ class NetSigmoidGrad(nn.Cell):
super(NetSigmoidGrad, self).__init__()
self.sigmoid_grad = G.SigmoidGrad()
@ms_function
def construct(self, y, dy):
return self.sigmoid_grad(y, dy)
@ -40,7 +38,6 @@ class Grad(nn.Cell):
self.grad = GradOperation(get_all=True, sens_param=True)
self.network = network
@ms_function
def construct(self, y, y_grad, dout):
return self.grad(self.network)(y, y_grad, dout)

Loading…
Cancel
Save