!1359 Optimize the IR modules.

Merge pull request !1359 from ZhangQinghua/master
pull/1359/MERGE
mindspore-ci-bot 6 years ago committed by Gitee
commit 848d19207f

@ -29,6 +29,7 @@
#include "utils/visible.h" #include "utils/visible.h"
#include "utils/log_adapter.h" #include "utils/log_adapter.h"
#include "utils/ordered_set.h" #include "utils/ordered_set.h"
#include "utils/ordered_map.h"
namespace mindspore { namespace mindspore {
template <typename T> template <typename T>

File diff suppressed because it is too large Load Diff

@ -26,6 +26,7 @@
#include <memory> #include <memory>
#include <unordered_map> #include <unordered_map>
#include <unordered_set> #include <unordered_set>
#include <functional>
#include "ir/anf.h" #include "ir/anf.h"
#include "ir/manager.h" #include "ir/manager.h"
@ -36,8 +37,13 @@
namespace mindspore { namespace mindspore {
using BaseRefCounterMap = OrderedMap<BaseRef, int, BaseRefHash>; using BaseRefCounterMap = OrderedMap<BaseRef, int, BaseRefHash>;
using FuncGraphCounterMap = OrderedMap<FuncGraphPtr, int>; using FuncGraphCounterMap = OrderedMap<FuncGraphPtr, int>;
using AnfNodeCounterMap = OrderedMap<AnfNodePtr, int>;
using CNodeIndexCounterMap = OrderedMap<CNodeIndexPairPtr, int, CNodeIndexHasher, CNodeIndexEqual>; template <typename ValueT, class CounterHash = std::hash<ValueT>, class CounterEqual = std::equal_to<ValueT>>
using CounterOrderedMap = OrderedMap<ValueT, int, CounterHash, CounterEqual>;
using AnfNodeCounterMap = CounterOrderedMap<AnfNodePtr>;
using CNodeIndexCounterMap = CounterOrderedMap<CNodeIndexPairPtr, CNodeIndexHasher, CNodeIndexEqual>;
using FuncGraphMap = OrderedMap<FuncGraphPtr, int>;
const char FUNC_GRAPH_FLAG_IGNORE_VALUES[] = "ignore_values"; const char FUNC_GRAPH_FLAG_IGNORE_VALUES[] = "ignore_values";
const char FUNC_GRAPH_FLAG_DEFER_INLINE[] = "defer_inline"; const char FUNC_GRAPH_FLAG_DEFER_INLINE[] = "defer_inline";
@ -183,12 +189,24 @@ class FuncGraph : public FuncGraphBase {
// get all nodes belonging to this func graph // get all nodes belonging to this func graph
const AnfNodeSet &nodes(); const AnfNodeSet &nodes();
void CopyNodes(const FuncGraphPtr &source);
void ClearNodes();
void AddNode(AnfNodePtr node);
void DropNode(AnfNodePtr node);
// get all value_nodes belonging to this func graph // get all value_nodes belonging to this func graph
const AnfNodeCounterMap &value_nodes(); const AnfNodeCounterMap &value_nodes();
void CopyValueNodes(const FuncGraphPtr &source);
// get all vars directly pointed to in this func graph void ClearValueNodes();
const AnfNodeCounterMap &free_variables_direct(); void AddValueNode(AnfNodePtr node, int count = 1);
void DropValueNode(AnfNodePtr node);
// get all free vars directly used in this func graph
const AnfNodeCounterMap &free_variables();
void CopyFreeVariables(const FuncGraphPtr &source);
void ClearFreeVariables();
bool AddFreeVariable(AnfNodePtr node, int count = 1);
bool DropFreeVariable(AnfNodePtr node);
// get all vars required by this func graph // get all vars required by this func graph
const BaseRefCounterMap &free_variables_total(); const BaseRefCounterMap &free_variables_total();
@ -199,14 +217,29 @@ class FuncGraph : public FuncGraphBase {
// get all vars that are func graphs // get all vars that are func graphs
std::vector<FuncGraphPtr> free_variables_func_graphs(); std::vector<FuncGraphPtr> free_variables_func_graphs();
// get all func graphs directly used by this func graph // get all value nodes of func graph directly used by this func graph
const FuncGraphCounterMap &func_graphs_used(); const FuncGraphCounterMap &func_graphs_used();
void CopyFuncGraphsUsed(const FuncGraphPtr &source);
void ClearFuncGraphsUsed();
bool AddFuncGraphUsed(FuncGraphPtr fg, int count = 1);
bool DropFuncGraphUsed(FuncGraphPtr fg);
// get all value nodes of J func graph directly used by this func graph
const FuncGraphCounterMap &j_func_graphs();
void CopyJFuncGraphs(const FuncGraphPtr &source);
void ClearJFuncGraphs();
void AddJFuncGraph(FuncGraphPtr fg, int count = 1);
void DropJFuncGraph(FuncGraphPtr fg);
// get all func graphs nested used by this func graph // get all func graphs nested used by this func graph
const FuncGraphSet &func_graphs_used_total(); const FuncGraphSet &func_graphs_used_total();
// get all user value nodes of this func graph // get all user value nodes of this func graph, by CNode and its input's index
const CNodeIndexCounterMap &func_graph_cnodes_index(); const CNodeIndexCounterMap &func_graph_cnodes_index();
void CopyFuncGraphCNodesIndex(const FuncGraphPtr &source);
void ClearFuncGraphCNodesIndex();
void AddFuncGraphCNodeIndex(CNodeIndexPairPtr node, int count = 1);
void DropFuncGraphCNodeIndex(CNodeIndexPairPtr node);
// Return the parent of this graph. // Return the parent of this graph.
FuncGraphPtr parent(); FuncGraphPtr parent();
@ -256,6 +289,7 @@ class FuncGraph : public FuncGraphBase {
// parameter default value // parameter default value
std::map<std::string, AnfNodePtr> parameter_default_value_; std::map<std::string, AnfNodePtr> parameter_default_value_;
std::unordered_map<AnfNodePtr, AnfNodePtr> make_ref_params_; std::unordered_map<AnfNodePtr, AnfNodePtr> make_ref_params_;
size_t seen_;
std::list<CNodePtr> GetOrderedCnodes(); std::list<CNodePtr> GetOrderedCnodes();
void EraseUnusedNodeInOrder(const AnfNodePtr &n); void EraseUnusedNodeInOrder(const AnfNodePtr &n);
@ -270,6 +304,24 @@ class FuncGraph : public FuncGraphBase {
// graph is manipulated by manager and others // graph is manipulated by manager and others
friend FuncGraphManager; friend FuncGraphManager;
// all nodes of the function
AnfNodeSet nodes_;
// all value nodes of the function
AnfNodeCounterMap value_nodes_;
// all func graph value nodes of the function
FuncGraphCounterMap func_graphs_used_;
// all free variables of the function
AnfNodeCounterMap free_variables_;
// all value nodes calling J in the function
FuncGraphCounterMap j_func_graphs_;
// all user value nodes of this func graph, recording by CNode and its input's index
CNodeIndexCounterMap func_graph_cnodes_index_;
// parameters of this function // parameters of this function
std::vector<AnfNodePtr> parameters_; std::vector<AnfNodePtr> parameters_;
std::vector<AnfNodePtr> paramter_obj_nodes_; std::vector<AnfNodePtr> paramter_obj_nodes_;
@ -313,6 +365,8 @@ inline CNodePtr NewCNode(const std::vector<AnfNodePtr> &inputs, const FuncGraphP
return fg->NewCNode(inputs); return fg->NewCNode(inputs);
} }
size_t NewFgSeenGeneration();
// Find the root cnodes of a segment of cnodes. // Find the root cnodes of a segment of cnodes.
std::shared_ptr<OrderedSet<CNodePtr>> FindRoots(const std::vector<CNodePtr> &segment); std::shared_ptr<OrderedSet<CNodePtr>> FindRoots(const std::vector<CNodePtr> &segment);
// Find the leaf cnodes of a segment of cnodes. // Find the leaf cnodes of a segment of cnodes.

@ -123,7 +123,7 @@ void Cloner::CloneValueNodes(const FuncGraphPtr &func_graph) {
if (!clone_all_valuenodes_) { if (!clone_all_valuenodes_) {
return; return;
} }
auto &value_nodes = manager_->valuenodes()[func_graph]; auto &value_nodes = func_graph->value_nodes();
for (auto &value_node : value_nodes) { for (auto &value_node : value_nodes) {
auto old_node = value_node.first; auto old_node = value_node.first;
MS_EXCEPTION_IF_NULL(old_node); MS_EXCEPTION_IF_NULL(old_node);
@ -153,9 +153,9 @@ void Cloner::AddTotalGraphs(const FuncGraphPtr &func_graph) {
if (!clone_all_used_graphs_) { if (!clone_all_used_graphs_) {
return; return;
} }
auto &used_graphs = manager_->func_graphs_used()[func_graph]; auto &used = func_graph->func_graphs_used();
for (auto &used_graph : used_graphs) { for (auto &fg : used) {
todo_.push_back({used_graph.first, nullptr, {}}); todo_.push_back({fg.first, nullptr, {}});
} }
} }
@ -185,7 +185,7 @@ void Cloner::CloneFuncGraphValueNodes(const FuncGraphPtr &func_graph, const Func
} }
target_func_graph->set_return(return_node); target_func_graph->set_return(return_node);
auto &cnodes = manager_->func_graph_cnodes_index()[func_graph]; auto &cnodes = func_graph->func_graph_cnodes_index();
for (auto &cnode : cnodes) { for (auto &cnode : cnodes) {
auto parent = cnode.first->first->cast<CNodePtr>(); auto parent = cnode.first->first->cast<CNodePtr>();
auto valuenode = parent->input(cnode.first->second); auto valuenode = parent->input(cnode.first->second);
@ -441,7 +441,7 @@ void Cloner::CloneAllNodes(const FuncGraphPtr &func_graph, const FuncGraphPtr &t
MS_EXCEPTION_IF_NULL(func_graph); MS_EXCEPTION_IF_NULL(func_graph);
MS_EXCEPTION_IF_NULL(target_func_graph); MS_EXCEPTION_IF_NULL(target_func_graph);
MS_EXCEPTION_IF_NULL(manager_); MS_EXCEPTION_IF_NULL(manager_);
const AnfNodeSet &nodes = manager_->nodes()[func_graph]; const AnfNodeSet &nodes = func_graph->nodes();
for (auto &node : nodes) { for (auto &node : nodes) {
CloneNode(node, target_func_graph); CloneNode(node, target_func_graph);
} }

File diff suppressed because it is too large Load Diff

File diff suppressed because it is too large Load Diff

@ -492,7 +492,7 @@ void DFunctor::MapParamObject() {
void DFunctor::MapValueObject() { void DFunctor::MapValueObject() {
// Map ValueNode. // Map ValueNode.
auto manager = resources_->manager(); auto manager = resources_->manager();
auto &value_nodes = manager->valuenodes()[primal_graph_]; auto &value_nodes = primal_graph_->value_nodes();
for (const auto &value_pair : value_nodes) { for (const auto &value_pair : value_nodes) {
auto node = value_pair.first; auto node = value_pair.first;
auto parent_adjoint = FindAdjoint(node); auto parent_adjoint = FindAdjoint(node);

@ -119,7 +119,7 @@ FuncGraphPtr TransformGraphCondBranchNodes(
std::unordered_map<AnfNodePtr, AnfNodePtr> repl_node; std::unordered_map<AnfNodePtr, AnfNodePtr> repl_node;
// record the node input to be replaced // record the node input to be replaced
NodeInputReplMap repl_node_inputs; NodeInputReplMap repl_node_inputs;
const AnfNodeSet &nodes = manager->nodes()[graph]; const AnfNodeSet &nodes = graph->nodes();
for (auto &node : nodes) { for (auto &node : nodes) {
MS_EXCEPTION_IF_NULL(node); MS_EXCEPTION_IF_NULL(node);
if (!node->isa<CNode>()) { if (!node->isa<CNode>()) {
@ -436,7 +436,7 @@ FuncGraphPtr TransformGraphDependNode(
ResetSharedOp(); ResetSharedOp();
std::shared_ptr<std::unordered_map<AnfNodePtr, AnfNodePtr>> repl_node = std::shared_ptr<std::unordered_map<AnfNodePtr, AnfNodePtr>> repl_node =
std::make_shared<std::unordered_map<AnfNodePtr, AnfNodePtr>>(); // record the node to be replaced std::make_shared<std::unordered_map<AnfNodePtr, AnfNodePtr>>(); // record the node to be replaced
const AnfNodeSet &nodes = manager->nodes()[graph]; const AnfNodeSet &nodes = graph->nodes();
for (auto &node : nodes) { for (auto &node : nodes) {
MS_EXCEPTION_IF_NULL(node); MS_EXCEPTION_IF_NULL(node);
if (!node->isa<CNode>()) { if (!node->isa<CNode>()) {

@ -391,7 +391,7 @@ bool RemoveValueNodeDuplicationsAction(const ResourcePtr &res) {
FuncGraphPtr func_graph = res->func_graph(); FuncGraphPtr func_graph = res->func_graph();
auto manager = res->manager(); auto manager = res->manager();
// Remove duplicated value nodes, due to replace operation, can't use reference. // Remove duplicated value nodes, due to replace operation, can't use reference.
auto value_nodes = manager->valuenodes()[func_graph]; auto value_nodes = func_graph->value_nodes();
HashCache hash_cache; HashCache hash_cache;
HashValue hashes; HashValue hashes;
for (const auto &value_pair : value_nodes) { for (const auto &value_pair : value_nodes) {

@ -488,12 +488,12 @@ void CompileGraph::AddExternal(const LinConvertResult &result) {
void TraverseGraphMap( void TraverseGraphMap(
const FuncGraphManagerPtr &manager_ptr, FuncGraphTransaction *const tr, const FuncGraphManagerPtr &manager_ptr, FuncGraphTransaction *const tr,
const FuncGraphToAnfNodeCounterMap<AnfNodePtr> &cts, const FuncGraphSet &fgs,
const std::function<std::shared_ptr<FuncGraph>(const PrimitivePtr, const AbstractFunctionPtr)> &get_prim_graph) { const std::function<std::shared_ptr<FuncGraph>(const PrimitivePtr, const AbstractFunctionPtr)> &get_prim_graph) {
MS_EXCEPTION_IF_NULL(manager_ptr); MS_EXCEPTION_IF_NULL(manager_ptr);
MS_EXCEPTION_IF_NULL(tr); MS_EXCEPTION_IF_NULL(tr);
for (const auto &ct_graphs : cts) { for (const auto &fg : fgs) {
for (const auto &ct_any : ct_graphs.second) { for (const auto &ct_any : fg->value_nodes()) {
AnfNodePtr const_primitive_node = ct_any.first; AnfNodePtr const_primitive_node = ct_any.first;
if (const_primitive_node != nullptr && IsValueNode<Primitive>(const_primitive_node)) { if (const_primitive_node != nullptr && IsValueNode<Primitive>(const_primitive_node)) {
auto users = manager_ptr->node_users()[const_primitive_node]; auto users = manager_ptr->node_users()[const_primitive_node];
@ -553,8 +553,8 @@ FuncGraphPtr WrapPrimitives(const FuncGraphPtr &graph) {
}; };
FuncGraphTransaction tr = manager_ptr->Transact(); FuncGraphTransaction tr = manager_ptr->Transact();
auto &cts = manager_ptr->valuenodes(); auto &fgs = manager_ptr->func_graphs();
TraverseGraphMap(manager_ptr, &tr, cts, get_prim_graph); TraverseGraphMap(manager_ptr, &tr, fgs, get_prim_graph);
return graph; return graph;
} }

@ -132,18 +132,6 @@ class NestingSpecs {
CheckAnfNodeCounter(counter_p); CheckAnfNodeCounter(counter_p);
return; return;
} }
auto counter_pair = dynamic_pointer_cast<CounterAnfNodeCollector<CNodeIndexPairPtr>>(results);
if (counter_pair != nullptr) {
CheckCNodeIndexPairCounter(counter_pair);
return;
}
auto nodes = dynamic_pointer_cast<NodesCollector>(results);
if (nodes != nullptr) {
CheckNodes(nodes);
return;
}
} }
private: private:
@ -205,33 +193,7 @@ class NestingSpecs {
ASSERT_EQ(clean_results, expected_); ASSERT_EQ(clean_results, expected_);
} }
void CheckNodes(std::shared_ptr<NodesCollector> results) {
std::map<std::string, std::set<std::string>> clean_results;
for (auto& iter : results->nodes_analysis()) {
auto key = iter.first;
auto value = iter.second;
if (key == nullptr) {
continue;
}
std::string k = Name(key);
std::set<std::string> v;
for (auto& node : value) {
if (!node->isa<CNode>() && !Name(node).empty()) {
v.insert(Name(node));
}
}
if (!v.empty()) {
clean_results[k] = v;
}
}
ASSERT_EQ(clean_results, expected_);
}
// Add CheckNesting function // Add CheckNesting function
void CheckAnfNodeCounter(std::shared_ptr<CounterAnfNodeCollector<AnfNodePtr>> results) { void CheckAnfNodeCounter(std::shared_ptr<CounterAnfNodeCollector<AnfNodePtr>> results) {
std::map<std::string, std::set<std::string>> clean_results; std::map<std::string, std::set<std::string>> clean_results;
for (auto& iter : results->count_nodes_map()) { for (auto& iter : results->count_nodes_map()) {
@ -258,32 +220,6 @@ class NestingSpecs {
ASSERT_EQ(clean_results, expected_); ASSERT_EQ(clean_results, expected_);
} }
void CheckCNodeIndexPairCounter(std::shared_ptr<CounterAnfNodeCollector<CNodeIndexPairPtr>> results) {
std::map<std::string, std::set<std::string>> clean_results;
for (auto& iter : results->count_nodes_map()) {
auto key = iter.first;
auto value = iter.second;
if (key == nullptr) {
continue;
}
std::string k = Name(key);
std::set<std::string> v;
for (auto& node : value) {
auto fg = node.first->first;
if (!Name(fg).empty()) {
v.insert(Name(fg));
}
}
if (!v.empty()) {
clean_results[k] = v;
}
}
ASSERT_EQ(clean_results, expected_);
}
void CheckGraphCounter(std::shared_ptr<CounterFuncGraphCollector> results) { void CheckGraphCounter(std::shared_ptr<CounterFuncGraphCollector> results) {
std::map<std::string, std::set<std::string>> clean_results; std::map<std::string, std::set<std::string>> clean_results;
for (auto& iter : results->count_func_graphs_map()) { for (auto& iter : results->count_func_graphs_map()) {
@ -471,17 +407,10 @@ std::vector<FuncGraphPtr> MakeNestedGraph2() {
} }
// Add TestManager::CheckManager function to checkout the result // Add TestManager::CheckManager function to checkout the result
void TestManager::CheckAnalysisSize(std::shared_ptr<FuncGraphManager> mng) { void TestManager::CheckAnalysisSize(std::shared_ptr<FuncGraphManager> mng) {
auto size = mng->func_graphs().size(); auto size = mng->func_graphs().size();
ASSERT_EQ(size + 1, mng->nodes().size());
ASSERT_EQ(size, mng->free_variables_total().size()); ASSERT_EQ(size, mng->free_variables_total().size());
ASSERT_EQ(size, mng->valuenodes().size());
ASSERT_EQ(size, mng->free_variables_direct().size());
ASSERT_EQ(size, mng->func_graph_cnodes_index().size());
ASSERT_EQ(size, mng->func_graph_parents_direct().size());
ASSERT_EQ(size, mng->func_graphs_used().size());
} }
TEST_F(TestManager, test_scalar_add_manual) { TEST_F(TestManager, test_scalar_add_manual) {
@ -525,31 +454,26 @@ TEST_F(TestManager, test_nested_manual) {
ASSERT_EQ(1, mng->roots().size()); ASSERT_EQ(1, mng->roots().size());
CheckAnalysisSize(mng); CheckAnalysisSize(mng);
auto nodes = mng->nodes(); ASSERT_EQ(2, f->nodes().size());
ASSERT_EQ(3, nodes[nullptr].size()); ASSERT_EQ(1, g->nodes().size());
ASSERT_EQ(2, nodes[f].size());
ASSERT_EQ(1, nodes[g].size());
auto users = mng->node_users(); auto users = mng->node_users();
for (auto& iter : users) { for (auto& iter : users) {
ASSERT_EQ(1, iter.second.size()); ASSERT_EQ(1, iter.second.size());
} }
auto graphs_used = mng->func_graphs_used(); ASSERT_EQ(1, f->func_graphs_used().size());
ASSERT_EQ(1, graphs_used[f].size()); ASSERT_EQ(0, g->func_graphs_used().size());
ASSERT_EQ(0, graphs_used[g].size());
auto fv_direct = mng->free_variables_direct(); ASSERT_EQ(0, f->free_variables().size());
ASSERT_EQ(0, fv_direct[f].size()); ASSERT_EQ(1, g->free_variables().size());
ASSERT_EQ(1, fv_direct[g].size());
auto fv_total = mng->free_variables_total(); auto fv_total = mng->free_variables_total();
ASSERT_EQ(0, fv_total[f].size()); ASSERT_EQ(0, fv_total[f].size());
ASSERT_EQ(1, fv_total[g].size()); ASSERT_EQ(1, fv_total[g].size());
auto cnodes = mng->func_graph_cnodes_index(); ASSERT_EQ(0, f->func_graph_cnodes_index().size());
ASSERT_EQ(0, cnodes[f].size()); ASSERT_EQ(1, g->func_graph_cnodes_index().size());
ASSERT_EQ(1, cnodes[g].size());
} }
TEST_F(TestManager, test_deep_nested2_manual) { TEST_F(TestManager, test_deep_nested2_manual) {
@ -567,7 +491,7 @@ TEST_F(TestManager, test_deep_nested2_manual) {
ASSERT_EQ(3, mng->func_graphs().size()); ASSERT_EQ(3, mng->func_graphs().size());
ASSERT_EQ(1, mng->roots().size()); ASSERT_EQ(1, mng->roots().size());
ASSERT_EQ(4, mng->nodes().size()); ASSERT_EQ(4, gfn->nodes().size());
ASSERT_EQ(20, mng->all_nodes().size()); ASSERT_EQ(20, mng->all_nodes().size());
ASSERT_EQ(25, mng->node_users().size()); ASSERT_EQ(25, mng->node_users().size());
CheckAnalysisSize(mng); CheckAnalysisSize(mng);
@ -631,7 +555,6 @@ TEST_F(TestManager, test_deep_nested_manual) {
ASSERT_EQ(3, mng->func_graphs().size()); ASSERT_EQ(3, mng->func_graphs().size());
ASSERT_EQ(1, mng->roots().size()); ASSERT_EQ(1, mng->roots().size());
ASSERT_EQ(4, mng->nodes().size());
ASSERT_EQ(20, mng->all_nodes().size()); ASSERT_EQ(20, mng->all_nodes().size());
CheckAnalysisSize(mng); CheckAnalysisSize(mng);
} }
@ -716,12 +639,12 @@ TEST_F(TestManager, test_drop_root) {
FuncGraphPtr fg = getPyFun("ir_get_fn"); FuncGraphPtr fg = getPyFun("ir_get_fn");
auto mng = Manage(fg); auto mng = Manage(fg);
const FuncGraphToAnfNodeMap& nodes = mng->nodes(); const auto &fgs = mng->func_graphs();
ASSERT_TRUE(nodes.find(fg) != nodes.end()); ASSERT_TRUE(fgs.contains(fg));
FuncGraphSet s; FuncGraphSet s;
s.add(fg); s.add(fg);
mng->MaybeDropFuncGraphs(s); mng->MaybeDropFuncGraphs(s);
ASSERT_TRUE(nodes.find(fg) != nodes.end()); ASSERT_TRUE(fgs.contains(fg));
} }
TEST_F(TestManager, test_keep_roots) { TEST_F(TestManager, test_keep_roots) {

@ -26,15 +26,14 @@
namespace mindspore { namespace mindspore {
void CheckNoFreeVariables(FuncGraphPtr root) { void CheckNoFreeVariables(FuncGraphPtr root) {
auto mng = Manage(root); auto mng = Manage(root);
for (auto &iter : mng->nodes()) { for (auto &iter : mng->func_graphs()) {
auto g = iter.first; auto g = iter;
auto nodes = iter.second;
if (g == nullptr) { if (g == nullptr) {
continue; continue;
} }
ASSERT_TRUE(g->parent() == nullptr); ASSERT_TRUE(g->parent() == nullptr);
auto nodes = g->nodes();
for (auto &node : nodes) { for (auto &node : nodes) {
ASSERT_EQ(node->func_graph(), g); ASSERT_EQ(node->func_graph(), g);
auto cnode = node->cast<CNodePtr>(); auto cnode = node->cast<CNodePtr>();

Loading…
Cancel
Save