|
|
|
@ -57,7 +57,7 @@ using mindspore::tensor::TensorPy;
|
|
|
|
|
|
|
|
|
|
const char SINGLE_OP_GRAPH[] = "single_op_graph";
|
|
|
|
|
// primitive unable to infer value for constant input in PyNative mode
|
|
|
|
|
const std::set<std::string> vm_operators = {"make_ref", "HookBackward", "stop_gradient"};
|
|
|
|
|
const std::set<std::string> vm_operators = {"make_ref", "HookBackward", "stop_gradient", "mixed_precision_cast"};
|
|
|
|
|
|
|
|
|
|
namespace mindspore {
|
|
|
|
|
namespace pynative {
|
|
|
|
@ -815,6 +815,9 @@ PynativeExecutor::PynativeExecutor() { grad_flag_ = false; }
|
|
|
|
|
void PynativeExecutor::NewGraphInner(const py::object &cell, const py::args &args) {
|
|
|
|
|
auto cell_id = GetId(cell);
|
|
|
|
|
if (cell_graph_map_.count(cell_id) != 0) {
|
|
|
|
|
if (cell_resource_map_.find(cell_id) != cell_resource_map_.end()) {
|
|
|
|
|
resource_ = cell_resource_map_[cell_id];
|
|
|
|
|
}
|
|
|
|
|
MS_LOG(DEBUG) << "Newgraph already compiled";
|
|
|
|
|
return;
|
|
|
|
|
}
|
|
|
|
@ -823,6 +826,8 @@ void PynativeExecutor::NewGraphInner(const py::object &cell, const py::args &arg
|
|
|
|
|
|
|
|
|
|
if (top_g_ == nullptr) {
|
|
|
|
|
top_g_ = curr_g_ = g;
|
|
|
|
|
resource_ = std::make_shared<pipeline::Resource>();
|
|
|
|
|
cell_resource_map_[cell_id] = resource_;
|
|
|
|
|
df_builder_ = std::make_shared<FuncGraph>();
|
|
|
|
|
MS_LOG(DEBUG) << "First new graph" << top_g_.get();
|
|
|
|
|
Pushp();
|
|
|
|
@ -1124,6 +1129,7 @@ void PynativeExecutor::Clear(const std::string &flag) {
|
|
|
|
|
MS_LOG(DEBUG) << "Clear res";
|
|
|
|
|
(void)graph_map_.erase(flag);
|
|
|
|
|
(void)cell_graph_map_.erase(flag);
|
|
|
|
|
(void)cell_resource_map_.erase(flag);
|
|
|
|
|
Clean();
|
|
|
|
|
// Maybe exit in the pynative runing op, so need reset pynative flag.
|
|
|
|
|
auto ms_context = MsContext::GetInstance();
|
|
|
|
@ -1135,6 +1141,7 @@ void PynativeExecutor::Clear(const std::string &flag) {
|
|
|
|
|
|
|
|
|
|
MS_LOG(DEBUG) << "Clear";
|
|
|
|
|
top_g_ = nullptr;
|
|
|
|
|
df_builder_ = nullptr;
|
|
|
|
|
curr_g_ = nullptr;
|
|
|
|
|
graph_info_map_.clear();
|
|
|
|
|
op_id_map_.clear();
|
|
|
|
@ -1146,7 +1153,6 @@ void PynativeExecutor::Clean() {
|
|
|
|
|
Clear();
|
|
|
|
|
grad_flag_ = false;
|
|
|
|
|
op_forward_map_.clear();
|
|
|
|
|
df_builder_ = nullptr;
|
|
|
|
|
ad::CleanRes();
|
|
|
|
|
pipeline::ReclaimOptimizer();
|
|
|
|
|
}
|
|
|
|
|