diff --git a/mindspore/core/ir/func_graph.h b/mindspore/core/ir/func_graph.h index d813a958b6..3702ae13ef 100644 --- a/mindspore/core/ir/func_graph.h +++ b/mindspore/core/ir/func_graph.h @@ -34,6 +34,7 @@ #include "utils/ordered_map.h" #include "base/base_ref.h" #include "ir/func_graph_cloner.h" +#include "abstract/abstract_value.h" namespace mindspore { using BaseRefCounterMap = OrderedMap; @@ -417,6 +418,9 @@ class FuncGraph : public FuncGraphBase { // Design switch_layer_input as a ptr to // share between derived backpropagator and cloned graphs std::shared_ptr switch_layer_input_; + std::unordered_map + func_graph_cache_; }; inline CNodePtr NewCNode(const std::vector &inputs, const FuncGraphPtr &fg) { diff --git a/mindspore/core/ir/func_graph_extends.cc b/mindspore/core/ir/func_graph_extends.cc index 217cb5adf6..8095a13102 100644 --- a/mindspore/core/ir/func_graph_extends.cc +++ b/mindspore/core/ir/func_graph_extends.cc @@ -245,6 +245,10 @@ FuncGraphPtr FuncGraph::GenerateGraph(const AbstractBasePtrList &args_spec_list) if (!NeedGenerate(kwarg_list)) { return shared_from_base(); } + auto iter = func_graph_cache_.find(args_spec_list); + if (iter != func_graph_cache_.end()) { + return iter->second; + } FuncGraphPtr specialized_graph = BasicClone(shared_from_base()); size_t kwarg_count = kwarg_list.size(); int pos_args_input_count = SizeToInt(arguments_count - kwarg_count - hyper_param_count_); @@ -290,6 +294,7 @@ FuncGraphPtr FuncGraph::GenerateGraph(const AbstractBasePtrList &args_spec_list) specialized_graph->set_kwonlyargs_count(0); specialized_graph->ClearDefaultValues(); specialized_graph->set_is_generate(true); + func_graph_cache_[args_spec_list] = specialized_graph; return specialized_graph; }