Fix pynative order derivetive memory

Signed-off-by: zjun <zhangjun0@huawei.com>
pull/9868/head
zjun 4 years ago
parent 304211e2b6
commit 96b0452c14

@ -1433,8 +1433,8 @@ std::string PynativeExecutor::GetCellId(const py::object &cell, const py::args &
}
bool PynativeExecutor::IsNotNestedGrad() const {
MS_LOG(DEBUG) << "Grad nested count is " << grad_count_;
return grad_count_ <= 1;
MS_LOG(DEBUG) << "Grad nested count is " << grad_order_;
return grad_order_ <= 1;
}
bool PynativeExecutor::IsTopGraph(const std::string &cell_id) {
@ -1446,8 +1446,8 @@ bool PynativeExecutor::IsTopGraph(const std::string &cell_id) {
}
void PynativeExecutor::SubNestedGradCount() {
if (grad_count_ > 0) {
--grad_count_;
if (grad_order_ > 0) {
--grad_order_;
}
}
@ -1828,7 +1828,7 @@ void PynativeExecutor::EndGraphByOutId(const py::object &cell, const std::string
void PynativeExecutor::UpdateCellGraph(const std::string &cell_id, bool need_cloned, bool is_grad) {
FuncGraphPtr tmp = curr_g_;
if (need_cloned) {
if (need_cloned && !IsNotNestedGrad()) {
auto cloned_curr_g = BasicClone(curr_g_);
graph_info_map_[cloned_curr_g] = graph_info_map_.at(curr_g_);
tmp = cloned_curr_g;
@ -2365,7 +2365,7 @@ void PynativeExecutor::Clean() {
void PynativeExecutor::ClearRes() {
MS_LOG(DEBUG) << "Clear all res";
Clean();
grad_count_ = 0;
grad_order_ = 0;
grad_flag_ = false;
dynamic_cell_ = false;
grad_is_running_ = false;

@ -147,7 +147,7 @@ class PynativeExecutor : public std::enable_shared_from_this<PynativeExecutor> {
void PopGraphStack();
FuncGraphPtr GetDfbuilder(const std::string &cell_id = "");
ResourcePtr GetResource(const std::string &cell_id = "");
void AddNestedGradCount() { ++grad_count_; }
void AddNestedGradCount() { ++grad_order_; }
void SubNestedGradCount();
bool IsNotNestedGrad() const;
bool IsTopGraph(const std::string &cell_id);
@ -204,7 +204,7 @@ class PynativeExecutor : public std::enable_shared_from_this<PynativeExecutor> {
static std::shared_ptr<PynativeExecutor> executor_;
static std::mutex instance_lock_;
static int64_t graph_id_;
int64_t grad_count_{0};
int64_t grad_order_{0};
bool grad_flag_{false};
bool dynamic_cell_{false};
bool grad_is_running_{false};

Loading…
Cancel
Save