diff --git a/mindspore/ccsrc/pipeline/pynative/pynative_execute.cc b/mindspore/ccsrc/pipeline/pynative/pynative_execute.cc index b931196bcd..d9f476f0ca 100644 --- a/mindspore/ccsrc/pipeline/pynative/pynative_execute.cc +++ b/mindspore/ccsrc/pipeline/pynative/pynative_execute.cc @@ -1433,8 +1433,8 @@ std::string PynativeExecutor::GetCellId(const py::object &cell, const py::args & } bool PynativeExecutor::IsNotNestedGrad() const { - MS_LOG(DEBUG) << "Grad nested count is " << grad_count_; - return grad_count_ <= 1; + MS_LOG(DEBUG) << "Grad nested count is " << grad_order_; + return grad_order_ <= 1; } bool PynativeExecutor::IsTopGraph(const std::string &cell_id) { @@ -1446,8 +1446,8 @@ bool PynativeExecutor::IsTopGraph(const std::string &cell_id) { } void PynativeExecutor::SubNestedGradCount() { - if (grad_count_ > 0) { - --grad_count_; + if (grad_order_ > 0) { + --grad_order_; } } @@ -1828,7 +1828,7 @@ void PynativeExecutor::EndGraphByOutId(const py::object &cell, const std::string void PynativeExecutor::UpdateCellGraph(const std::string &cell_id, bool need_cloned, bool is_grad) { FuncGraphPtr tmp = curr_g_; - if (need_cloned) { + if (need_cloned && !IsNotNestedGrad()) { auto cloned_curr_g = BasicClone(curr_g_); graph_info_map_[cloned_curr_g] = graph_info_map_.at(curr_g_); tmp = cloned_curr_g; @@ -2365,7 +2365,7 @@ void PynativeExecutor::Clean() { void PynativeExecutor::ClearRes() { MS_LOG(DEBUG) << "Clear all res"; Clean(); - grad_count_ = 0; + grad_order_ = 0; grad_flag_ = false; dynamic_cell_ = false; grad_is_running_ = false; diff --git a/mindspore/ccsrc/pipeline/pynative/pynative_execute.h b/mindspore/ccsrc/pipeline/pynative/pynative_execute.h index f8445174ba..ba2f1a8372 100644 --- a/mindspore/ccsrc/pipeline/pynative/pynative_execute.h +++ b/mindspore/ccsrc/pipeline/pynative/pynative_execute.h @@ -147,7 +147,7 @@ class PynativeExecutor : public std::enable_shared_from_this { void PopGraphStack(); FuncGraphPtr GetDfbuilder(const std::string &cell_id = ""); ResourcePtr GetResource(const std::string &cell_id = ""); - void AddNestedGradCount() { ++grad_count_; } + void AddNestedGradCount() { ++grad_order_; } void SubNestedGradCount(); bool IsNotNestedGrad() const; bool IsTopGraph(const std::string &cell_id); @@ -204,7 +204,7 @@ class PynativeExecutor : public std::enable_shared_from_this { static std::shared_ptr executor_; static std::mutex instance_lock_; static int64_t graph_id_; - int64_t grad_count_{0}; + int64_t grad_order_{0}; bool grad_flag_{false}; bool dynamic_cell_{false}; bool grad_is_running_{false};