|
|
|
@ -1467,13 +1467,16 @@ bool PynativeExecutor::CheckCellGraph(const std::string &cell_id, bool is_grad)
|
|
|
|
|
});
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
void PynativeExecutor::ClearResidualRes() {
|
|
|
|
|
void PynativeExecutor::ClearResidualRes(const std::string &cell_id) {
|
|
|
|
|
if (top_cell_list_.empty() && !graph_stack_.empty()) {
|
|
|
|
|
graph_id_ = 0;
|
|
|
|
|
graph_info_map_.clear();
|
|
|
|
|
cell_sw_map_.clear();
|
|
|
|
|
cell_graph_list_.clear();
|
|
|
|
|
top_cell_list_.clear();
|
|
|
|
|
std::stack<FuncGraphPtr>().swap(graph_stack_);
|
|
|
|
|
}
|
|
|
|
|
if (dynamic_cell_) {
|
|
|
|
|
VectorClear<std::vector<TopCellInfo>>(&top_cell_list_, cell_id);
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
@ -1486,8 +1489,8 @@ FuncGraphPtr PynativeExecutor::GetDfbuilder(const std::string &cell_id) {
|
|
|
|
|
if (grad_order_ == 0 || grad_order_ == 1) {
|
|
|
|
|
return top_cell_list_.back().df_builder;
|
|
|
|
|
}
|
|
|
|
|
if (top_cell_list_.size() < grad_order_) {
|
|
|
|
|
MS_LOG(EXCEPTION) << "Get wrong grad order";
|
|
|
|
|
if (top_cell_list_.size() < 2) {
|
|
|
|
|
MS_LOG(EXCEPTION) << "Top cell list size must greater than 2";
|
|
|
|
|
}
|
|
|
|
|
MS_LOG(DEBUG) << "Get grad order " << grad_order_ << " top cell list size " << top_cell_list_.size();
|
|
|
|
|
// Grad order greater than 2
|
|
|
|
@ -1517,8 +1520,8 @@ ResourcePtr PynativeExecutor::GetResource(const std::string &cell_id) {
|
|
|
|
|
if (grad_order_ == 0 || grad_order_ == 1) {
|
|
|
|
|
return top_cell_list_.back().resource;
|
|
|
|
|
}
|
|
|
|
|
if (top_cell_list_.size() < grad_order_) {
|
|
|
|
|
MS_LOG(EXCEPTION) << "Get wrong grad order";
|
|
|
|
|
if (top_cell_list_.size() < 2) {
|
|
|
|
|
MS_LOG(EXCEPTION) << "Top cell list size must greater than 2";
|
|
|
|
|
}
|
|
|
|
|
MS_LOG(DEBUG) << "Get grad order " << grad_order_ << " top cell list size " << top_cell_list_.size();
|
|
|
|
|
// Grad order greater than 2
|
|
|
|
@ -1718,7 +1721,7 @@ void PynativeExecutor::NewGraphInner(const py::object &cell, const py::args &arg
|
|
|
|
|
// init resource for constructing forward graph and grad graph
|
|
|
|
|
auto g = std::make_shared<FuncGraph>();
|
|
|
|
|
curr_g_ = g;
|
|
|
|
|
ClearResidualRes();
|
|
|
|
|
ClearResidualRes(cell_id);
|
|
|
|
|
if (graph_stack_.empty() && !IsBpropGraph(cell_id)) {
|
|
|
|
|
MakeNewTopGraph(cell_id, args, g);
|
|
|
|
|
}
|
|
|
|
@ -2030,10 +2033,6 @@ void PynativeExecutor::GradNetInner(const GradOperationPtr &grad, const py::obje
|
|
|
|
|
|
|
|
|
|
// Set all params(input+weights)
|
|
|
|
|
SetGradGraphParams(df_builder, resource, size);
|
|
|
|
|
// Clone df_builder and resource at first time
|
|
|
|
|
if (CloneDfbuiler(cell_id, df_builder, resource)) {
|
|
|
|
|
df_builder = GetDfbuilder(cell_id);
|
|
|
|
|
}
|
|
|
|
|
// Get params(weights) require derivative
|
|
|
|
|
auto w_args = GetWeightsArgs(weights, df_builder);
|
|
|
|
|
// Get the parameters items and add the value to args_spec
|
|
|
|
@ -2239,26 +2238,6 @@ abstract::AbstractBasePtrList PynativeExecutor::GetArgsSpec(const py::args &args
|
|
|
|
|
return args_spec;
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
bool PynativeExecutor::CloneDfbuiler(const std::string &cell_id, const FuncGraphPtr &df_builder,
|
|
|
|
|
const ResourcePtr &resource) {
|
|
|
|
|
bool is_cloned = false;
|
|
|
|
|
auto it = std::find_if(top_cell_list_.begin(), top_cell_list_.end(),
|
|
|
|
|
[&cell_id](const TopCellInfo &value) { return value.cell_id == cell_id; });
|
|
|
|
|
if (it == top_cell_list_.end()) {
|
|
|
|
|
MS_LOG(EXCEPTION) << "Get top cell failed";
|
|
|
|
|
}
|
|
|
|
|
if (it->bg == nullptr) {
|
|
|
|
|
auto cloned_df_newfg = BasicClone(resource->func_graph());
|
|
|
|
|
it->bg = cloned_df_newfg;
|
|
|
|
|
MS_LOG(DEBUG) << "Cloned df newfg";
|
|
|
|
|
is_cloned = false;
|
|
|
|
|
} else {
|
|
|
|
|
resource->set_func_graph(it->bg);
|
|
|
|
|
MS_LOG(DEBUG) << "Used cloned df newfg";
|
|
|
|
|
}
|
|
|
|
|
return is_cloned;
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
void PynativeExecutor::GradGraph(const FuncGraphPtr &g, const GradOperationPtr &grad_op,
|
|
|
|
|
const std::vector<AnfNodePtr> &weights, size_t arg_size, const std::string &cell_id) {
|
|
|
|
|
FuncGraphPtr top_g = nullptr;
|
|
|
|
@ -2433,28 +2412,6 @@ void PynativeExecutor::MakeNestedCnode(const std::string &cell_id, const py::arg
|
|
|
|
|
MS_LOG(DEBUG) << "Nested make cnode is " << cnode->DebugString(4);
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
template <typename T>
|
|
|
|
|
void MapClear(T *map, const std::string &cell_id) {
|
|
|
|
|
for (auto it = map->begin(); it != map->end();) {
|
|
|
|
|
if (it->first.find(cell_id) != std::string::npos) {
|
|
|
|
|
it = map->erase(it);
|
|
|
|
|
} else {
|
|
|
|
|
it++;
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
template <typename T>
|
|
|
|
|
void VectorClear(T *vec, const std::string &cell_id) {
|
|
|
|
|
for (auto it = vec->begin(); it != vec->end();) {
|
|
|
|
|
if (it->cell_id.find(cell_id) != std::string::npos) {
|
|
|
|
|
it = vec->erase(it);
|
|
|
|
|
} else {
|
|
|
|
|
it++;
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
void PynativeExecutor::Clear(const std::string &cell_id) {
|
|
|
|
|
if (cell_id.empty()) {
|
|
|
|
|
Clean();
|
|
|
|
|