|
|
|
@ -802,6 +802,24 @@ class ExecuteOrderGenerator {
|
|
|
|
|
graph_->set_execution_order(std::move(execution_order));
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
std::set<CNodePtr> GetAllNodes() {
|
|
|
|
|
auto &all_graphs = context_.visited_graphs();
|
|
|
|
|
std::set<CNodePtr> all_nodes;
|
|
|
|
|
for (auto &graph : all_graphs) {
|
|
|
|
|
auto out = graph->get_return();
|
|
|
|
|
MS_EXCEPTION_IF_NULL(out);
|
|
|
|
|
auto nodes = TopoSort(out);
|
|
|
|
|
for (auto &node : nodes) {
|
|
|
|
|
MS_EXCEPTION_IF_NULL(node);
|
|
|
|
|
auto cnode = node->cast<CNodePtr>();
|
|
|
|
|
if (cnode != nullptr) {
|
|
|
|
|
all_nodes.insert(cnode);
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
return all_nodes;
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
static const AnfNodePtr &GetRealNode(const AnfNodePtr &input) {
|
|
|
|
|
if (IsPrimitiveCNode(input, prim::kPrimLoad) || IsPrimitiveCNode(input, prim::kPrimDepend)) {
|
|
|
|
|
return input->cast<CNodePtr>()->inputs().at(1);
|
|
|
|
@ -813,6 +831,7 @@ class ExecuteOrderGenerator {
|
|
|
|
|
void EraseParameter() {
|
|
|
|
|
// Copy out execution order list.
|
|
|
|
|
auto exec_order = graph_->execution_order();
|
|
|
|
|
std::set<CNodePtr> all_nodes = GetAllNodes();
|
|
|
|
|
|
|
|
|
|
// Remove assigns that target and source are same.
|
|
|
|
|
for (auto iter = exec_order.begin(); iter != exec_order.end();) {
|
|
|
|
@ -844,6 +863,18 @@ class ExecuteOrderGenerator {
|
|
|
|
|
auto kg = target->func_graph()->cast<KernelGraphPtr>();
|
|
|
|
|
MS_EXCEPTION_IF_NULL(kg);
|
|
|
|
|
kg->ReplaceNode(NOT_NULL(target), NOT_NULL(source));
|
|
|
|
|
|
|
|
|
|
// replace parameter in node
|
|
|
|
|
for (auto &iter_node : all_nodes) {
|
|
|
|
|
for (size_t i = 0; i < iter_node->size(); ++i) {
|
|
|
|
|
if (iter_node->input(i) == target) {
|
|
|
|
|
MS_LOG(INFO) << "Replace " << iter_node->DebugString() << " input " << i << " by "
|
|
|
|
|
<< source->DebugString();
|
|
|
|
|
iter_node->set_input(i, source);
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
// replace parameter in graph input
|
|
|
|
|
auto &all_graphs = context_.visited_graphs();
|
|
|
|
|
for (auto &g : all_graphs) {
|
|
|
|
|