|
|
|
@ -1433,8 +1433,8 @@ std::string PynativeExecutor::GetCellId(const py::object &cell, const py::args &
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
bool PynativeExecutor::IsNotNestedGrad() const {
|
|
|
|
|
MS_LOG(DEBUG) << "Grad nested count is " << grad_count_;
|
|
|
|
|
return grad_count_ <= 1;
|
|
|
|
|
MS_LOG(DEBUG) << "Grad nested count is " << grad_order_;
|
|
|
|
|
return grad_order_ <= 1;
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
bool PynativeExecutor::IsTopGraph(const std::string &cell_id) {
|
|
|
|
@ -1446,8 +1446,8 @@ bool PynativeExecutor::IsTopGraph(const std::string &cell_id) {
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
void PynativeExecutor::SubNestedGradCount() {
|
|
|
|
|
if (grad_count_ > 0) {
|
|
|
|
|
--grad_count_;
|
|
|
|
|
if (grad_order_ > 0) {
|
|
|
|
|
--grad_order_;
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
@ -1828,7 +1828,7 @@ void PynativeExecutor::EndGraphByOutId(const py::object &cell, const std::string
|
|
|
|
|
|
|
|
|
|
void PynativeExecutor::UpdateCellGraph(const std::string &cell_id, bool need_cloned, bool is_grad) {
|
|
|
|
|
FuncGraphPtr tmp = curr_g_;
|
|
|
|
|
if (need_cloned) {
|
|
|
|
|
if (need_cloned && !IsNotNestedGrad()) {
|
|
|
|
|
auto cloned_curr_g = BasicClone(curr_g_);
|
|
|
|
|
graph_info_map_[cloned_curr_g] = graph_info_map_.at(curr_g_);
|
|
|
|
|
tmp = cloned_curr_g;
|
|
|
|
@ -2365,7 +2365,7 @@ void PynativeExecutor::Clean() {
|
|
|
|
|
void PynativeExecutor::ClearRes() {
|
|
|
|
|
MS_LOG(DEBUG) << "Clear all res";
|
|
|
|
|
Clean();
|
|
|
|
|
grad_count_ = 0;
|
|
|
|
|
grad_order_ = 0;
|
|
|
|
|
grad_flag_ = false;
|
|
|
|
|
dynamic_cell_ = false;
|
|
|
|
|
grad_is_running_ = false;
|
|
|
|
|