!8974 [ME]Fix bug of memory leak of calss PrimitivePy in graph mode

From: @chenfei52
Reviewed-by: 
Signed-off-by:
pull/8974/MERGE
mindspore-ci-bot 4 years ago committed by Gitee
commit ca298b83e4

@ -313,6 +313,7 @@ void ExecutorPy::DelNetRes(const std::string &id) {
void ExecutorPy::ClearRes() {
MS_LOG(INFO) << "Clean executor resource!";
Resource::ClearPrimitivePyPythonObj();
executor_ = nullptr;
}

@ -275,6 +275,39 @@ Any Resource::GetAttrPtr(const TypeId &type, const std::string &name) {
return GetMethodOrAttr(name, type_id, attr_map);
}
std::unordered_map<PrimitivePy *, bool> Resource::py_objs_ = std::unordered_map<PrimitivePy *, bool>();
void Resource::RecordPrimitivePy(PrimitivePy *prim) {
if (prim == nullptr) {
return;
}
py_objs_[prim] = true;
}
void Resource::ErasePrimitivePy(PrimitivePy *prim) {
if (prim == nullptr) {
return;
}
auto it = py_objs_.find(prim);
if (it == py_objs_.end()) {
return;
}
// If flag is false,the pointer hased been released, so it can't be visited.
if (!it->second) {
return;
}
py_objs_[prim] = false;
prim->SetPyObj(py::none());
}
void Resource::ClearPrimitivePyPythonObj() {
for (auto &it : py_objs_) {
if (it.second) {
it.first->SetPyObj(py::none());
}
}
py_objs_.clear();
}
void Resource::Clean() {
// AbstractTensor->elements() will be saved in AbstractBasePtrList
args_spec_.clear();

@ -80,7 +80,9 @@ class Resource : public ResourceBase {
}
bool gpu_loopsink_flag() { return gpu_loopsink_flag_; }
int64_t gpu_loopsink_size() { return gpu_loopsink_size_; }
static void RecordPrimitivePy(PrimitivePy *prim);
static void ErasePrimitivePy(PrimitivePy *prim);
static void ClearPrimitivePyPythonObj();
// Reclaim resource and clear the cache.
// ExecutorPy::Compile() can be called multiple times, so cache
// should be cleared.
@ -94,6 +96,7 @@ class Resource : public ResourceBase {
bool is_cleaned_;
bool gpu_loopsink_flag_{false};
int64_t gpu_loopsink_size_{1};
static std::unordered_map<PrimitivePy *, bool> py_objs_;
};
using ResourcePtr = std::shared_ptr<pipeline::Resource>;

@ -28,12 +28,14 @@
#include "utils/convert_utils_py.h"
#include "utils/ms_context.h"
#include "utils/primitive_utils.h"
#include "pipeline/jit/resource.h"
namespace mindspore {
namespace {
constexpr auto kBpropAttrName = "bprop";
constexpr auto kCellHookAttrName = "cell_hook";
constexpr auto kCellIDAttrName = "cell_id";
void SyncData(const py::object &arg) {
if (py::isinstance<py::tuple>(arg)) {
py::tuple arg_list = py::cast<py::tuple>(arg);
@ -49,6 +51,12 @@ void SyncData(const py::object &arg) {
} // namespace
std::map<std::string, py::object> PrimitivePy::hook_grad_;
PrimitivePy::PrimitivePy(const py::str &name, const py::object &python_obj)
: Primitive(name, false), python_obj_(python_obj), signatures_() {
pipeline::Resource::RecordPrimitivePy(this);
}
PrimitivePy::~PrimitivePy() { pipeline::Resource::ErasePrimitivePy(this); }
void PrimitivePy::SetPyObj(const py::object &obj) { python_obj_ = obj; }
void PrimitivePy::set_signatures(const std::vector<Signature> &signatures) {
signatures_ = signatures;
set_has_signature(true);

@ -36,9 +36,8 @@ namespace py = pybind11;
namespace mindspore {
class PrimitivePy : public Primitive {
public:
PrimitivePy(const py::str &name, const py::object &python_obj)
: Primitive(name, false), python_obj_(python_obj), signatures_() {}
~PrimitivePy() override = default;
PrimitivePy(const py::str &name, const py::object &python_obj);
~PrimitivePy() override;
MS_DECLARE_PARENT(PrimitivePy, Primitive);
py::function GetBpropFunction();
@ -59,6 +58,7 @@ class PrimitivePy : public Primitive {
bool HasComputeFunction() const;
const bool parse_info_ = true;
const py::object &GetPyObj() const { return python_obj_; }
void SetPyObj(const py::object &obj);
py::dict RunInfer(const py::tuple &args);
void RunCheck(const py::tuple &args);
py::object RunInferValue(const py::tuple &args);

Loading…
Cancel
Save