|
|
|
@ -2594,7 +2594,12 @@ void PynativeExecutor::NewGraph(const py::object &cell, const py::args &args) {
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
void PynativeExecutor::EndGraph(const py::object &cell, const py::object &out, const py::args &args) {
|
|
|
|
|
MS_LOG(DEBUG) << "Enter end graph process.";
|
|
|
|
|
auto &mem_cleaner = pipeline::Resource::mem_cleaner();
|
|
|
|
|
mem_cleaner.EnterPynativeEndGraphProcess();
|
|
|
|
|
PynativeExecutorTry(this, &PynativeExecutor::EndGraphInner, cell, out, args);
|
|
|
|
|
mem_cleaner.LeavePynativeEndGraphProcess();
|
|
|
|
|
MS_LOG(DEBUG) << "Leave end graph process.";
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
void PynativeExecutor::GradNet(const GradOperationPtr &grad, const py::object &cell, const py::object &weights,
|
|
|
|
@ -2609,6 +2614,24 @@ void PynativeExecutor::Sync() {
|
|
|
|
|
session->SyncStream();
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
void PynativeExecutor::EnterConstruct(const py::object &cell) {
|
|
|
|
|
if (top_cell_ != nullptr) {
|
|
|
|
|
return;
|
|
|
|
|
}
|
|
|
|
|
top_cell_ = cell.ptr();
|
|
|
|
|
pipeline::Resource::mem_cleaner().EnterPynativeConstructProcess();
|
|
|
|
|
MS_LOG(DEBUG) << "Enter construct process.";
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
void PynativeExecutor::LeaveConstruct(const py::object &cell) {
|
|
|
|
|
if (top_cell_ != cell.ptr()) {
|
|
|
|
|
return;
|
|
|
|
|
}
|
|
|
|
|
top_cell_ = nullptr;
|
|
|
|
|
pipeline::Resource::mem_cleaner().LeavePynativeConstructProcess();
|
|
|
|
|
MS_LOG(DEBUG) << "Leave construct process.";
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
REGISTER_PYBIND_DEFINE(PynativeExecutor_, ([](const py::module *m) {
|
|
|
|
|
(void)py::class_<PynativeExecutor, std::shared_ptr<PynativeExecutor>>(*m, "PynativeExecutor_")
|
|
|
|
|
.def_static("get_instance", &PynativeExecutor::GetInstance, "PynativeExecutor get_instance.")
|
|
|
|
@ -2620,6 +2643,10 @@ REGISTER_PYBIND_DEFINE(PynativeExecutor_, ([](const py::module *m) {
|
|
|
|
|
.def("sync", &PynativeExecutor::Sync, "pynative sync stream.")
|
|
|
|
|
.def("__call__", &PynativeExecutor::Run, "pynative executor run grad graph.")
|
|
|
|
|
.def("set_grad_flag", &PynativeExecutor::set_grad_flag, py::arg("flag") = py::bool_(false),
|
|
|
|
|
"Executor set grad flag.");
|
|
|
|
|
"Executor set grad flag.")
|
|
|
|
|
.def("enter_construct", &PynativeExecutor::EnterConstruct,
|
|
|
|
|
"Do something before enter construct function.")
|
|
|
|
|
.def("leave_construct", &PynativeExecutor::LeaveConstruct,
|
|
|
|
|
"Do something after leave construct function.");
|
|
|
|
|
}));
|
|
|
|
|
} // namespace mindspore::pynative
|
|
|
|
|