add func graph cache avoid repeat func graphs

pull/7814/head
yao_yf 4 years ago
parent 65d8e63580
commit 1a2fd0e0b0

@ -34,6 +34,7 @@
#include "utils/ordered_map.h"
#include "base/base_ref.h"
#include "ir/func_graph_cloner.h"
#include "abstract/abstract_value.h"
namespace mindspore {
using BaseRefCounterMap = OrderedMap<BaseRef, int, BaseRefHash>;
@ -417,6 +418,9 @@ class FuncGraph : public FuncGraphBase {
// Design switch_layer_input as a ptr to
// share between derived backpropagator and cloned graphs
std::shared_ptr<bool> switch_layer_input_;
std::unordered_map<AbstractBasePtrList, FuncGraphPtr, abstract::AbstractBasePtrListHasher,
abstract::AbstractBasePtrListEqual>
func_graph_cache_;
};
inline CNodePtr NewCNode(const std::vector<AnfNodePtr> &inputs, const FuncGraphPtr &fg) {

@ -245,6 +245,10 @@ FuncGraphPtr FuncGraph::GenerateGraph(const AbstractBasePtrList &args_spec_list)
if (!NeedGenerate(kwarg_list)) {
return shared_from_base<FuncGraph>();
}
auto iter = func_graph_cache_.find(args_spec_list);
if (iter != func_graph_cache_.end()) {
return iter->second;
}
FuncGraphPtr specialized_graph = BasicClone(shared_from_base<FuncGraph>());
size_t kwarg_count = kwarg_list.size();
int pos_args_input_count = SizeToInt(arguments_count - kwarg_count - hyper_param_count_);
@ -290,6 +294,7 @@ FuncGraphPtr FuncGraph::GenerateGraph(const AbstractBasePtrList &args_spec_list)
specialized_graph->set_kwonlyargs_count(0);
specialized_graph->ClearDefaultValues();
specialized_graph->set_is_generate(true);
func_graph_cache_[args_spec_list] = specialized_graph;
return specialized_graph;
}

Loading…
Cancel
Save