!8974 [ME]Fix bug of memory leak of class PrimitivePy in graph mode

From: @chenfei52
Reviewed-by: 
Signed-off-by:
pull/8974/MERGE
mindspore-ci-bot 4 years ago committed by Gitee
commit ca298b83e4

@ -313,6 +313,7 @@ void ExecutorPy::DelNetRes(const std::string &id) {
// Logs the cleanup, asks Resource to reset the Python object of every tracked
// PrimitivePy (Resource::ClearPrimitivePyPythonObj), and then resets executor_
// to nullptr.
void ExecutorPy::ClearRes() { void ExecutorPy::ClearRes() {
MS_LOG(INFO) << "Clean executor resource!"; MS_LOG(INFO) << "Clean executor resource!";
Resource::ClearPrimitivePyPythonObj();
executor_ = nullptr; executor_ = nullptr;
} }

@ -275,6 +275,39 @@ Any Resource::GetAttrPtr(const TypeId &type, const std::string &name) {
return GetMethodOrAttr(name, type_id, attr_map); return GetMethodOrAttr(name, type_id, attr_map);
} }
// Registry of PrimitivePy instances known to Resource. The mapped flag is true
// while the pointer may still be dereferenced, and false once
// ErasePrimitivePy() has processed it (the key must not be dereferenced then).
// NOTE(review): ~PrimitivePy() reads this map; confirm no PrimitivePy can be
// destroyed after this static object has been destroyed at program exit.
std::unordered_map<PrimitivePy *, bool> Resource::py_objs_;
// Starts tracking `prim` by flagging it as live in py_objs_.
// A null pointer is ignored. Recording a pointer that is already present
// simply sets its flag back to true.
void Resource::RecordPrimitivePy(PrimitivePy *prim) {
  if (prim != nullptr) {
    py_objs_[prim] = true;
  }
}
// Marks `prim` as released and replaces the Python object it holds with None.
// Does nothing when `prim` is null, is not tracked, or is already flagged as
// released.
//
// The entry is intentionally left in py_objs_ with its flag set to false; it
// is only removed when ClearPrimitivePyPythonObj() clears the whole map.
// NOTE(review): presumably this avoids structurally modifying the map while
// it may be iterated elsewhere -- confirm before changing this to erase().
void Resource::ErasePrimitivePy(PrimitivePy *prim) {
  if (prim == nullptr) {
    return;
  }
  const auto entry = py_objs_.find(prim);
  // A false flag means the pointer has already been released, so it must not
  // be dereferenced again.
  if (entry == py_objs_.end() || !entry->second) {
    return;
  }
  // Flip the flag before touching the object, same order as before.
  entry->second = false;
  prim->SetPyObj(py::none());
}
// Replaces the Python object of every tracked PrimitivePy that is still
// flagged live with None, then empties the registry.
void Resource::ClearPrimitivePyPythonObj() {
  for (auto &entry : py_objs_) {
    if (!entry.second) {
      // Flagged as released: the pointer must not be dereferenced.
      continue;
    }
    entry.first->SetPyObj(py::none());
  }
  py_objs_.clear();
}
void Resource::Clean() { void Resource::Clean() {
// AbstractTensor->elements() will be saved in AbstractBasePtrList // AbstractTensor->elements() will be saved in AbstractBasePtrList
args_spec_.clear(); args_spec_.clear();

@ -80,7 +80,9 @@ class Resource : public ResourceBase {
} }
bool gpu_loopsink_flag() { return gpu_loopsink_flag_; } bool gpu_loopsink_flag() { return gpu_loopsink_flag_; }
int64_t gpu_loopsink_size() { return gpu_loopsink_size_; } int64_t gpu_loopsink_size() { return gpu_loopsink_size_; }
static void RecordPrimitivePy(PrimitivePy *prim);
static void ErasePrimitivePy(PrimitivePy *prim);
static void ClearPrimitivePyPythonObj();
// Reclaim resource and clear the cache. // Reclaim resource and clear the cache.
// ExecutorPy::Compile() can be called multiple times, so cache // ExecutorPy::Compile() can be called multiple times, so cache
// should be cleared. // should be cleared.
@ -94,6 +96,7 @@ class Resource : public ResourceBase {
bool is_cleaned_; bool is_cleaned_;
bool gpu_loopsink_flag_{false}; bool gpu_loopsink_flag_{false};
int64_t gpu_loopsink_size_{1}; int64_t gpu_loopsink_size_{1};
static std::unordered_map<PrimitivePy *, bool> py_objs_;
}; };
using ResourcePtr = std::shared_ptr<pipeline::Resource>; using ResourcePtr = std::shared_ptr<pipeline::Resource>;

@ -28,12 +28,14 @@
#include "utils/convert_utils_py.h" #include "utils/convert_utils_py.h"
#include "utils/ms_context.h" #include "utils/ms_context.h"
#include "utils/primitive_utils.h" #include "utils/primitive_utils.h"
#include "pipeline/jit/resource.h"
namespace mindspore { namespace mindspore {
namespace { namespace {
constexpr auto kBpropAttrName = "bprop"; constexpr auto kBpropAttrName = "bprop";
constexpr auto kCellHookAttrName = "cell_hook"; constexpr auto kCellHookAttrName = "cell_hook";
constexpr auto kCellIDAttrName = "cell_id"; constexpr auto kCellIDAttrName = "cell_id";
void SyncData(const py::object &arg) { void SyncData(const py::object &arg) {
if (py::isinstance<py::tuple>(arg)) { if (py::isinstance<py::tuple>(arg)) {
py::tuple arg_list = py::cast<py::tuple>(arg); py::tuple arg_list = py::cast<py::tuple>(arg);
@ -49,6 +51,12 @@ void SyncData(const py::object &arg) {
} // namespace } // namespace
std::map<std::string, py::object> PrimitivePy::hook_grad_; std::map<std::string, py::object> PrimitivePy::hook_grad_;
// Constructs a primitive called `name` that stores `python_obj` in
// python_obj_, then registers the new instance with pipeline::Resource so its
// Python object can be reset later.
PrimitivePy::PrimitivePy(const py::str &name, const py::object &python_obj)
: Primitive(name, false), python_obj_(python_obj), signatures_() {
pipeline::Resource::RecordPrimitivePy(this);
}
// Tells pipeline::Resource that this instance is being destroyed, so the
// registry stops treating the pointer as live.
PrimitivePy::~PrimitivePy() {
  pipeline::Resource::ErasePrimitivePy(this);
}
// Replaces the Python object held by this primitive with `obj`.
void PrimitivePy::SetPyObj(const py::object &obj) {
  python_obj_ = obj;
}
void PrimitivePy::set_signatures(const std::vector<Signature> &signatures) { void PrimitivePy::set_signatures(const std::vector<Signature> &signatures) {
signatures_ = signatures; signatures_ = signatures;
set_has_signature(true); set_has_signature(true);

@ -36,9 +36,8 @@ namespace py = pybind11;
namespace mindspore { namespace mindspore {
class PrimitivePy : public Primitive { class PrimitivePy : public Primitive {
public: public:
PrimitivePy(const py::str &name, const py::object &python_obj) PrimitivePy(const py::str &name, const py::object &python_obj);
: Primitive(name, false), python_obj_(python_obj), signatures_() {} ~PrimitivePy() override;
~PrimitivePy() override = default;
MS_DECLARE_PARENT(PrimitivePy, Primitive); MS_DECLARE_PARENT(PrimitivePy, Primitive);
py::function GetBpropFunction(); py::function GetBpropFunction();
@ -59,6 +58,7 @@ class PrimitivePy : public Primitive {
bool HasComputeFunction() const; bool HasComputeFunction() const;
const bool parse_info_ = true; const bool parse_info_ = true;
const py::object &GetPyObj() const { return python_obj_; } const py::object &GetPyObj() const { return python_obj_; }
void SetPyObj(const py::object &obj);
py::dict RunInfer(const py::tuple &args); py::dict RunInfer(const py::tuple &args);
void RunCheck(const py::tuple &args); void RunCheck(const py::tuple &args);
py::object RunInferValue(const py::tuple &args); py::object RunInferValue(const py::tuple &args);

Loading…
Cancel
Save