!1359 Optimize the IR modules.

Merge pull request !1359 from ZhangQinghua/master
pull/1359/MERGE
mindspore-ci-bot 5 years ago committed by Gitee
commit 848d19207f

@ -29,6 +29,7 @@
#include "utils/visible.h"
#include "utils/log_adapter.h"
#include "utils/ordered_set.h"
#include "utils/ordered_map.h"
namespace mindspore {
template <typename T>

File diff suppressed because it is too large Load Diff

@ -26,6 +26,7 @@
#include <memory>
#include <unordered_map>
#include <unordered_set>
#include <functional>
#include "ir/anf.h"
#include "ir/manager.h"
@ -36,8 +37,13 @@
namespace mindspore {
using BaseRefCounterMap = OrderedMap<BaseRef, int, BaseRefHash>;
using FuncGraphCounterMap = OrderedMap<FuncGraphPtr, int>;
using AnfNodeCounterMap = OrderedMap<AnfNodePtr, int>;
using CNodeIndexCounterMap = OrderedMap<CNodeIndexPairPtr, int, CNodeIndexHasher, CNodeIndexEqual>;
template <typename ValueT, class CounterHash = std::hash<ValueT>, class CounterEqual = std::equal_to<ValueT>>
using CounterOrderedMap = OrderedMap<ValueT, int, CounterHash, CounterEqual>;
using AnfNodeCounterMap = CounterOrderedMap<AnfNodePtr>;
using CNodeIndexCounterMap = CounterOrderedMap<CNodeIndexPairPtr, CNodeIndexHasher, CNodeIndexEqual>;
using FuncGraphMap = OrderedMap<FuncGraphPtr, int>;
const char FUNC_GRAPH_FLAG_IGNORE_VALUES[] = "ignore_values";
const char FUNC_GRAPH_FLAG_DEFER_INLINE[] = "defer_inline";
@ -183,12 +189,24 @@ class FuncGraph : public FuncGraphBase {
// get all nodes belonging to this func graph
const AnfNodeSet &nodes();
void CopyNodes(const FuncGraphPtr &source);
void ClearNodes();
void AddNode(AnfNodePtr node);
void DropNode(AnfNodePtr node);
// get all value_nodes belonging to this func graph
const AnfNodeCounterMap &value_nodes();
// get all vars directly pointed to in this func graph
const AnfNodeCounterMap &free_variables_direct();
void CopyValueNodes(const FuncGraphPtr &source);
void ClearValueNodes();
void AddValueNode(AnfNodePtr node, int count = 1);
void DropValueNode(AnfNodePtr node);
// get all free vars directly used in this func graph
const AnfNodeCounterMap &free_variables();
void CopyFreeVariables(const FuncGraphPtr &source);
void ClearFreeVariables();
bool AddFreeVariable(AnfNodePtr node, int count = 1);
bool DropFreeVariable(AnfNodePtr node);
// get all vars required by this func graph
const BaseRefCounterMap &free_variables_total();
@ -199,14 +217,29 @@ class FuncGraph : public FuncGraphBase {
// get all vars that are func graphs
std::vector<FuncGraphPtr> free_variables_func_graphs();
// get all func graphs directly used by this func graph
// get all value nodes of func graph directly used by this func graph
const FuncGraphCounterMap &func_graphs_used();
void CopyFuncGraphsUsed(const FuncGraphPtr &source);
void ClearFuncGraphsUsed();
bool AddFuncGraphUsed(FuncGraphPtr fg, int count = 1);
bool DropFuncGraphUsed(FuncGraphPtr fg);
// get all value nodes of J func graph directly used by this func graph
const FuncGraphCounterMap &j_func_graphs();
void CopyJFuncGraphs(const FuncGraphPtr &source);
void ClearJFuncGraphs();
void AddJFuncGraph(FuncGraphPtr fg, int count = 1);
void DropJFuncGraph(FuncGraphPtr fg);
// get all func graphs nested used by this func graph
const FuncGraphSet &func_graphs_used_total();
// get all user value nodes of this func graph
// get all user value nodes of this func graph, by CNode and its input's index
const CNodeIndexCounterMap &func_graph_cnodes_index();
void CopyFuncGraphCNodesIndex(const FuncGraphPtr &source);
void ClearFuncGraphCNodesIndex();
void AddFuncGraphCNodeIndex(CNodeIndexPairPtr node, int count = 1);
void DropFuncGraphCNodeIndex(CNodeIndexPairPtr node);
// Return the parent of this graph.
FuncGraphPtr parent();
@ -256,6 +289,7 @@ class FuncGraph : public FuncGraphBase {
// parameter default value
std::map<std::string, AnfNodePtr> parameter_default_value_;
std::unordered_map<AnfNodePtr, AnfNodePtr> make_ref_params_;
size_t seen_;
std::list<CNodePtr> GetOrderedCnodes();
void EraseUnusedNodeInOrder(const AnfNodePtr &n);
@ -270,6 +304,24 @@ class FuncGraph : public FuncGraphBase {
// graph is manipulated by manager and others
friend FuncGraphManager;
// all nodes of the function
AnfNodeSet nodes_;
// all value nodes of the function
AnfNodeCounterMap value_nodes_;
// all func graph value nodes of the function
FuncGraphCounterMap func_graphs_used_;
// all free variables of the function
AnfNodeCounterMap free_variables_;
// all value nodes calling J in the function
FuncGraphCounterMap j_func_graphs_;
// all user value nodes of this func graph, recording by CNode and its input's index
CNodeIndexCounterMap func_graph_cnodes_index_;
// parameters of this function
std::vector<AnfNodePtr> parameters_;
std::vector<AnfNodePtr> paramter_obj_nodes_;
@ -313,6 +365,8 @@ inline CNodePtr NewCNode(const std::vector<AnfNodePtr> &inputs, const FuncGraphP
return fg->NewCNode(inputs);
}
size_t NewFgSeenGeneration();
// Find the root cnodes of a segment of cnodes.
std::shared_ptr<OrderedSet<CNodePtr>> FindRoots(const std::vector<CNodePtr> &segment);
// Find the leaf cnodes of a segment of cnodes.

@ -123,7 +123,7 @@ void Cloner::CloneValueNodes(const FuncGraphPtr &func_graph) {
if (!clone_all_valuenodes_) {
return;
}
auto &value_nodes = manager_->valuenodes()[func_graph];
auto &value_nodes = func_graph->value_nodes();
for (auto &value_node : value_nodes) {
auto old_node = value_node.first;
MS_EXCEPTION_IF_NULL(old_node);
@ -153,9 +153,9 @@ void Cloner::AddTotalGraphs(const FuncGraphPtr &func_graph) {
if (!clone_all_used_graphs_) {
return;
}
auto &used_graphs = manager_->func_graphs_used()[func_graph];
for (auto &used_graph : used_graphs) {
todo_.push_back({used_graph.first, nullptr, {}});
auto &used = func_graph->func_graphs_used();
for (auto &fg : used) {
todo_.push_back({fg.first, nullptr, {}});
}
}
@ -185,7 +185,7 @@ void Cloner::CloneFuncGraphValueNodes(const FuncGraphPtr &func_graph, const Func
}
target_func_graph->set_return(return_node);
auto &cnodes = manager_->func_graph_cnodes_index()[func_graph];
auto &cnodes = func_graph->func_graph_cnodes_index();
for (auto &cnode : cnodes) {
auto parent = cnode.first->first->cast<CNodePtr>();
auto valuenode = parent->input(cnode.first->second);
@ -441,7 +441,7 @@ void Cloner::CloneAllNodes(const FuncGraphPtr &func_graph, const FuncGraphPtr &t
MS_EXCEPTION_IF_NULL(func_graph);
MS_EXCEPTION_IF_NULL(target_func_graph);
MS_EXCEPTION_IF_NULL(manager_);
const AnfNodeSet &nodes = manager_->nodes()[func_graph];
const AnfNodeSet &nodes = func_graph->nodes();
for (auto &node : nodes) {
CloneNode(node, target_func_graph);
}

File diff suppressed because it is too large Load Diff

File diff suppressed because it is too large Load Diff

@ -492,7 +492,7 @@ void DFunctor::MapParamObject() {
void DFunctor::MapValueObject() {
// Map ValueNode.
auto manager = resources_->manager();
auto &value_nodes = manager->valuenodes()[primal_graph_];
auto &value_nodes = primal_graph_->value_nodes();
for (const auto &value_pair : value_nodes) {
auto node = value_pair.first;
auto parent_adjoint = FindAdjoint(node);

@ -119,7 +119,7 @@ FuncGraphPtr TransformGraphCondBranchNodes(
std::unordered_map<AnfNodePtr, AnfNodePtr> repl_node;
// record the node input to be replaced
NodeInputReplMap repl_node_inputs;
const AnfNodeSet &nodes = manager->nodes()[graph];
const AnfNodeSet &nodes = graph->nodes();
for (auto &node : nodes) {
MS_EXCEPTION_IF_NULL(node);
if (!node->isa<CNode>()) {
@ -436,7 +436,7 @@ FuncGraphPtr TransformGraphDependNode(
ResetSharedOp();
std::shared_ptr<std::unordered_map<AnfNodePtr, AnfNodePtr>> repl_node =
std::make_shared<std::unordered_map<AnfNodePtr, AnfNodePtr>>(); // record the node to be replaced
const AnfNodeSet &nodes = manager->nodes()[graph];
const AnfNodeSet &nodes = graph->nodes();
for (auto &node : nodes) {
MS_EXCEPTION_IF_NULL(node);
if (!node->isa<CNode>()) {

@ -391,7 +391,7 @@ bool RemoveValueNodeDuplicationsAction(const ResourcePtr &res) {
FuncGraphPtr func_graph = res->func_graph();
auto manager = res->manager();
// Remove duplicated value nodes, due to replace operation, can't use reference.
auto value_nodes = manager->valuenodes()[func_graph];
auto value_nodes = func_graph->value_nodes();
HashCache hash_cache;
HashValue hashes;
for (const auto &value_pair : value_nodes) {

@ -488,12 +488,12 @@ void CompileGraph::AddExternal(const LinConvertResult &result) {
void TraverseGraphMap(
const FuncGraphManagerPtr &manager_ptr, FuncGraphTransaction *const tr,
const FuncGraphToAnfNodeCounterMap<AnfNodePtr> &cts,
const FuncGraphSet &fgs,
const std::function<std::shared_ptr<FuncGraph>(const PrimitivePtr, const AbstractFunctionPtr)> &get_prim_graph) {
MS_EXCEPTION_IF_NULL(manager_ptr);
MS_EXCEPTION_IF_NULL(tr);
for (const auto &ct_graphs : cts) {
for (const auto &ct_any : ct_graphs.second) {
for (const auto &fg : fgs) {
for (const auto &ct_any : fg->value_nodes()) {
AnfNodePtr const_primitive_node = ct_any.first;
if (const_primitive_node != nullptr && IsValueNode<Primitive>(const_primitive_node)) {
auto users = manager_ptr->node_users()[const_primitive_node];
@ -553,8 +553,8 @@ FuncGraphPtr WrapPrimitives(const FuncGraphPtr &graph) {
};
FuncGraphTransaction tr = manager_ptr->Transact();
auto &cts = manager_ptr->valuenodes();
TraverseGraphMap(manager_ptr, &tr, cts, get_prim_graph);
auto &fgs = manager_ptr->func_graphs();
TraverseGraphMap(manager_ptr, &tr, fgs, get_prim_graph);
return graph;
}

@ -132,18 +132,6 @@ class NestingSpecs {
CheckAnfNodeCounter(counter_p);
return;
}
auto counter_pair = dynamic_pointer_cast<CounterAnfNodeCollector<CNodeIndexPairPtr>>(results);
if (counter_pair != nullptr) {
CheckCNodeIndexPairCounter(counter_pair);
return;
}
auto nodes = dynamic_pointer_cast<NodesCollector>(results);
if (nodes != nullptr) {
CheckNodes(nodes);
return;
}
}
private:
@ -205,33 +193,7 @@ class NestingSpecs {
ASSERT_EQ(clean_results, expected_);
}
void CheckNodes(std::shared_ptr<NodesCollector> results) {
std::map<std::string, std::set<std::string>> clean_results;
for (auto& iter : results->nodes_analysis()) {
auto key = iter.first;
auto value = iter.second;
if (key == nullptr) {
continue;
}
std::string k = Name(key);
std::set<std::string> v;
for (auto& node : value) {
if (!node->isa<CNode>() && !Name(node).empty()) {
v.insert(Name(node));
}
}
if (!v.empty()) {
clean_results[k] = v;
}
}
ASSERT_EQ(clean_results, expected_);
}
// Add CheckNesting function
void CheckAnfNodeCounter(std::shared_ptr<CounterAnfNodeCollector<AnfNodePtr>> results) {
std::map<std::string, std::set<std::string>> clean_results;
for (auto& iter : results->count_nodes_map()) {
@ -258,32 +220,6 @@ class NestingSpecs {
ASSERT_EQ(clean_results, expected_);
}
void CheckCNodeIndexPairCounter(std::shared_ptr<CounterAnfNodeCollector<CNodeIndexPairPtr>> results) {
std::map<std::string, std::set<std::string>> clean_results;
for (auto& iter : results->count_nodes_map()) {
auto key = iter.first;
auto value = iter.second;
if (key == nullptr) {
continue;
}
std::string k = Name(key);
std::set<std::string> v;
for (auto& node : value) {
auto fg = node.first->first;
if (!Name(fg).empty()) {
v.insert(Name(fg));
}
}
if (!v.empty()) {
clean_results[k] = v;
}
}
ASSERT_EQ(clean_results, expected_);
}
void CheckGraphCounter(std::shared_ptr<CounterFuncGraphCollector> results) {
std::map<std::string, std::set<std::string>> clean_results;
for (auto& iter : results->count_func_graphs_map()) {
@ -471,17 +407,10 @@ std::vector<FuncGraphPtr> MakeNestedGraph2() {
}
// Add TestManager::CheckManager function to checkout the result
void TestManager::CheckAnalysisSize(std::shared_ptr<FuncGraphManager> mng) {
auto size = mng->func_graphs().size();
ASSERT_EQ(size + 1, mng->nodes().size());
ASSERT_EQ(size, mng->free_variables_total().size());
ASSERT_EQ(size, mng->valuenodes().size());
ASSERT_EQ(size, mng->free_variables_direct().size());
ASSERT_EQ(size, mng->func_graph_cnodes_index().size());
ASSERT_EQ(size, mng->func_graph_parents_direct().size());
ASSERT_EQ(size, mng->func_graphs_used().size());
}
TEST_F(TestManager, test_scalar_add_manual) {
@ -525,31 +454,26 @@ TEST_F(TestManager, test_nested_manual) {
ASSERT_EQ(1, mng->roots().size());
CheckAnalysisSize(mng);
auto nodes = mng->nodes();
ASSERT_EQ(3, nodes[nullptr].size());
ASSERT_EQ(2, nodes[f].size());
ASSERT_EQ(1, nodes[g].size());
ASSERT_EQ(2, f->nodes().size());
ASSERT_EQ(1, g->nodes().size());
auto users = mng->node_users();
for (auto& iter : users) {
ASSERT_EQ(1, iter.second.size());
}
auto graphs_used = mng->func_graphs_used();
ASSERT_EQ(1, graphs_used[f].size());
ASSERT_EQ(0, graphs_used[g].size());
ASSERT_EQ(1, f->func_graphs_used().size());
ASSERT_EQ(0, g->func_graphs_used().size());
auto fv_direct = mng->free_variables_direct();
ASSERT_EQ(0, fv_direct[f].size());
ASSERT_EQ(1, fv_direct[g].size());
ASSERT_EQ(0, f->free_variables().size());
ASSERT_EQ(1, g->free_variables().size());
auto fv_total = mng->free_variables_total();
ASSERT_EQ(0, fv_total[f].size());
ASSERT_EQ(1, fv_total[g].size());
auto cnodes = mng->func_graph_cnodes_index();
ASSERT_EQ(0, cnodes[f].size());
ASSERT_EQ(1, cnodes[g].size());
ASSERT_EQ(0, f->func_graph_cnodes_index().size());
ASSERT_EQ(1, g->func_graph_cnodes_index().size());
}
TEST_F(TestManager, test_deep_nested2_manual) {
@ -567,7 +491,7 @@ TEST_F(TestManager, test_deep_nested2_manual) {
ASSERT_EQ(3, mng->func_graphs().size());
ASSERT_EQ(1, mng->roots().size());
ASSERT_EQ(4, mng->nodes().size());
ASSERT_EQ(4, gfn->nodes().size());
ASSERT_EQ(20, mng->all_nodes().size());
ASSERT_EQ(25, mng->node_users().size());
CheckAnalysisSize(mng);
@ -631,7 +555,6 @@ TEST_F(TestManager, test_deep_nested_manual) {
ASSERT_EQ(3, mng->func_graphs().size());
ASSERT_EQ(1, mng->roots().size());
ASSERT_EQ(4, mng->nodes().size());
ASSERT_EQ(20, mng->all_nodes().size());
CheckAnalysisSize(mng);
}
@ -716,12 +639,12 @@ TEST_F(TestManager, test_drop_root) {
FuncGraphPtr fg = getPyFun("ir_get_fn");
auto mng = Manage(fg);
const FuncGraphToAnfNodeMap& nodes = mng->nodes();
ASSERT_TRUE(nodes.find(fg) != nodes.end());
const auto &fgs = mng->func_graphs();
ASSERT_TRUE(fgs.contains(fg));
FuncGraphSet s;
s.add(fg);
mng->MaybeDropFuncGraphs(s);
ASSERT_TRUE(nodes.find(fg) != nodes.end());
ASSERT_TRUE(fgs.contains(fg));
}
TEST_F(TestManager, test_keep_roots) {

@ -26,15 +26,14 @@
namespace mindspore {
void CheckNoFreeVariables(FuncGraphPtr root) {
auto mng = Manage(root);
for (auto &iter : mng->nodes()) {
auto g = iter.first;
auto nodes = iter.second;
for (auto &iter : mng->func_graphs()) {
auto g = iter;
if (g == nullptr) {
continue;
}
ASSERT_TRUE(g->parent() == nullptr);
auto nodes = g->nodes();
for (auto &node : nodes) {
ASSERT_EQ(node->func_graph(), g);
auto cnode = node->cast<CNodePtr>();

Loading…
Cancel
Save