!12345 Traverse all nodes once, then traverse all Substitutions on each node.

From: @zh_qh
Reviewed-by: 
Signed-off-by:
pull/12345/MERGE
mindspore-ci-bot 4 years ago committed by Gitee
commit 86e3099c05

@ -31,7 +31,6 @@
#include "backend/kernel_compiler/kernel_build_info.h"
#include "common/trans.h"
#include "abstract/param_validator.h"
#include "abstract/primitive_infer_map.h"
#include "pipeline/jit/static_analysis/static_analysis.h"
#include "utils/trace_base.h"
@ -1806,14 +1805,7 @@ void AnfRuntimeAlgorithm::InferShape(const CNodePtr &node) {
args_spec_list.emplace_back(real_input->abstract());
}
}
auto &prim_eval_implement_map = abstract::GetPrimitiveToEvalImplMap();
auto ret = prim_eval_implement_map.find(primitive);
if (ret == prim_eval_implement_map.end()) {
MS_LOG(EXCEPTION) << "Get infer shape function failed, primitive name:" << primitive->name()
<< " primitive type:" << primitive->type_name();
}
auto eval_result = ret->second.impl_(nullptr, primitive, args_spec_list);
auto eval_result = abstract::CppInferShape(primitive, args_spec_list);
node->set_abstract(eval_result);
}
} // namespace session

@ -230,6 +230,8 @@ ResolveIRPassLib::ResolveIRPassLib() {
{prim::kPrimGetAttr, prim::kPrimResolve});
resolver_resolve_ = MakeSubstitution(std::make_shared<ResolverResolve>(), "resolver_resolve", prim::kPrimResolve);
resolver_getattr_ = MakeSubstitution(std::make_shared<ResolverGetAttr>(), "resolver_getattr", prim::kPrimGetAttr);
resolver_getattr_resolve_ =
MakeSubstitution(std::make_shared<ResolverGetAttrResolve>(), "resolver_getattr_resolve", prim::kPrimGetAttr);
}
InferenceOptPrepareLib::InferenceOptPrepareLib() {

@ -154,6 +154,7 @@ class ResolveIRPassLib {
SubstitutionPtr resolver_resolve_and_getattr_;
SubstitutionPtr resolver_resolve_;
SubstitutionPtr resolver_getattr_;
SubstitutionPtr resolver_getattr_resolve_;
};
class InferenceOptPrepareLib {

@ -71,7 +71,7 @@ AnfNodePtr EliminateUpdateStateOnlyUsedNode(const CNodePtr &update_state, const
return nullptr;
}
// Replace UpdateState with the input monad.
return update_state->inputs().at(kInputIndex);
return update_state->input(kInputIndex);
}
// Eliminate UpdateState that attaches a pure (no-side-effect) node.
@ -100,7 +100,7 @@ AnfNodePtr EliminateUpdateStateForPureNode(const CNodePtr &update_state, const A
}
}
// Remove UpdateState by replace it with its input monad.
return update_state->inputs().at(kInputIndex);
return update_state->input(kInputIndex);
}
// Eliminate redundant UpdateState/Depend pair nodes caused by inline.
@ -118,7 +118,7 @@ AnfNodePtr EliminateUpdateStateWithDepend(const CNodePtr &update_state, const CN
// Skip if Depend attach input is not a monad.
return nullptr;
}
auto update_monad = update_state->inputs().at(kInputIndex);
auto update_monad = update_state->input(kInputIndex);
if (!HasAbstractMonad(update_monad)) {
// Skip if UpdateState input is not a monad.
MS_LOG(WARNING) << "Not a monad input: " << update_state->DebugString();
@ -139,7 +139,7 @@ AnfNodePtr EliminateUpdateStateWithDepend(const CNodePtr &update_state, const CN
}
// Replace Depend with its input.
if (depend->size() == kMinDependSize) {
auto depend_input = depend->inputs().at(kInputIndex);
auto depend_input = depend->input(kInputIndex);
mgr->Replace(depend, depend_input);
} else {
auto inputs = depend->inputs();
@ -163,7 +163,7 @@ AnfNodePtr EliminateMakeTupleWithDeadNode(const CNodePtr &update_state, const CN
if (make_tuple->size() != kMakeTupleSize) {
return nullptr;
}
auto &node = make_tuple->inputs().at(kAttachIndex);
auto &node = make_tuple->input(kAttachIndex);
auto node_abs = node->abstract();
if (node_abs == nullptr || !node_abs->isa<abstract::AbstractError>()) {
return nullptr;
@ -173,7 +173,7 @@ AnfNodePtr EliminateMakeTupleWithDeadNode(const CNodePtr &update_state, const CN
return nullptr;
}
// Create a new UpdateState to replace the old one.
const auto &attach = make_tuple->inputs().at(kInputIndex);
const auto &attach = make_tuple->input(kInputIndex);
auto new_update_state = fg->NewCNode({update_state->input(0), update_state->input(1), attach});
new_update_state->set_abstract(update_state->abstract());
new_update_state->set_scope(update_state->scope());
@ -206,42 +206,47 @@ AnfNodePtr EliminateUpdateStateWithMakeTupleFunc(const CNodePtr &update_state, c
if (make_tuple->size() != kMakeTupleSize) {
return nullptr;
}
auto &first_input = make_tuple->inputs().at(kInputIndex);
auto &first_input = make_tuple->input(kInputIndex);
if (IsValueNode<FuncGraph>(first_input) && OnlyMakeTupleUseFunc(make_tuple, first_input)) {
return update_state->input(1);
}
auto &second_input = make_tuple->inputs().at(kAttachIndex);
auto &second_input = make_tuple->input(kAttachIndex);
if (IsValueNode<FuncGraph>(second_input) && OnlyMakeTupleUseFunc(make_tuple, second_input)) {
return update_state->input(1);
}
return nullptr;
}
size_t GetLoadsFollowLoad(const CNodePtr &load, std::vector<CNodePtr> *loads);
size_t GetLoadsFollowTuple(const CNodePtr &update_state, const CNodePtr &make_tuple, std::vector<CNodePtr> *loads);
size_t GetLoadsFollowLoad(const CNodePtr &load, std::vector<CNodePtr> *update_states, std::vector<CNodePtr> *loads);
size_t GetLoadsFollowTuple(const CNodePtr &update_state, const CNodePtr &make_tuple,
std::vector<CNodePtr> *update_states, std::vector<CNodePtr> *loads);
// Search consecutive load nodes from UpdateState node.
size_t GetLoadsFromUpdateState(const CNodePtr &update_state, std::vector<CNodePtr> *loads) {
auto &attach = update_state->inputs().at(kAttachIndex);
size_t GetLoadsFromUpdateState(const CNodePtr &update_state, std::vector<CNodePtr> *update_states,
std::vector<CNodePtr> *loads) {
auto &attach = update_state->input(kAttachIndex);
if (IsPrimitiveCNode(attach, prim::kPrimLoad)) {
return GetLoadsFollowLoad(attach->cast<CNodePtr>(), loads);
update_states->emplace_back(update_state);
return GetLoadsFollowLoad(attach->cast<CNodePtr>(), update_states, loads);
}
if (IsPrimitiveCNode(attach, prim::kPrimMakeTuple)) {
return GetLoadsFollowTuple(update_state, attach->cast<CNodePtr>(), loads);
update_states->emplace_back(update_state);
return GetLoadsFollowTuple(update_state, attach->cast<CNodePtr>(), update_states, loads);
}
return 0;
}
size_t GetLoadsFollowLoad(const CNodePtr &load, std::vector<CNodePtr> *loads) {
loads->push_back(load);
auto &load_attach = load->inputs().at(kAttachIndex);
size_t GetLoadsFollowLoad(const CNodePtr &load, std::vector<CNodePtr> *update_states, std::vector<CNodePtr> *loads) {
loads->emplace_back(load);
auto &load_attach = load->input(kAttachIndex);
if (IsPrimitiveCNode(load_attach, prim::kPrimUpdateState)) {
return GetLoadsFromUpdateState(load_attach->cast<CNodePtr>(), loads) + 1;
return GetLoadsFromUpdateState(load_attach->cast<CNodePtr>(), update_states, loads) + 1;
}
return 1;
}
size_t GetLoadsFollowTuple(const CNodePtr &update_state, const CNodePtr &make_tuple, std::vector<CNodePtr> *loads) {
size_t GetLoadsFollowTuple(const CNodePtr &update_state, const CNodePtr &make_tuple,
std::vector<CNodePtr> *update_states, std::vector<CNodePtr> *loads) {
if (!OnlyUpdateStateUse(update_state, make_tuple)) {
// UpdateState should be the only user of
return 0;
@ -256,12 +261,12 @@ size_t GetLoadsFollowTuple(const CNodePtr &update_state, const CNodePtr &make_tu
// Add load nodes from tuple elements.
for (size_t i = 1; i < inputs.size(); ++i) {
auto &element = inputs.at(i);
loads->push_back(element->cast<CNodePtr>());
loads->emplace_back(element->cast<CNodePtr>());
}
// Follow prev update state if found.
auto prev_node = update_state->inputs().at(kInputIndex);
auto prev_node = update_state->input(kInputIndex);
if (IsPrimitiveCNode(prev_node, prim::kPrimUpdateState)) {
return GetLoadsFromUpdateState(prev_node->cast<CNodePtr>(), loads) + 1;
return GetLoadsFromUpdateState(prev_node->cast<CNodePtr>(), update_states, loads) + 1;
}
return 1;
}
@ -301,7 +306,8 @@ AnfNodePtr MakeTupleForSameNodes(const FuncGraphPtr &fg, const CNodePtr &old_upd
// xN = Load(xN, u)
// t = make_tuple(x1, x2, ... , xN)
// u1 = UpdateState(u, t)
AnfNodePtr EliminateUpdateStateForLoads(const CNodePtr &old_update_state, const std::vector<CNodePtr> &loads) {
AnfNodePtr EliminateUpdateStateForLoads(const CNodePtr &old_update_state, const std::vector<CNodePtr> &update_states,
const std::vector<CNodePtr> &loads) {
auto fg = old_update_state->func_graph();
if (fg == nullptr) {
return nullptr;
@ -315,20 +321,24 @@ AnfNodePtr EliminateUpdateStateForLoads(const CNodePtr &old_update_state, const
std::set<AnfNodePtr> loaded_para_set;
make_tuple_inputs.reserve(loads.size() + 1);
make_tuple_inputs.emplace_back(NewValueNode(prim::kPrimMakeTuple));
auto input_monad = loads.back()->inputs().at(kAttachIndex);
auto input_monad = loads.back()->input(kAttachIndex);
for (auto iter = loads.rbegin(); iter != loads.rend(); ++iter) {
auto &load = *iter;
auto result = loaded_para_set.emplace(load->inputs().at(kInputIndex));
auto result = loaded_para_set.emplace(load->input(kInputIndex));
const bool is_new_load = result.second;
if (is_new_load) {
// Put Load node as a tuple element, if the parameter is not loaded by other Load.
make_tuple_inputs.emplace_back(load);
}
if (load->inputs().at(kAttachIndex) != input_monad) {
if (load->input(kAttachIndex) != input_monad) {
// Set all load use same input monad.
mgr->SetEdge(load, kAttachIndex, input_monad);
}
}
for (auto i = update_states.size() - 1; i > 0; i--) {
auto &us = update_states[i];
mgr->Replace(us, us->input(kInputIndex));
}
if (make_tuple_inputs.size() == 1) {
// This should not happen.
MS_LOG(WARNING) << "No loads for " << old_update_state->DebugString(2);
@ -538,7 +548,7 @@ AnfNodePtr UpdatestateEliminater::operator()(const OptimizerPtr &, const AnfNode
MS_LOG(WARNING) << "UpdatestateEliminater encounter invalid node: " << node->DebugString();
return nullptr;
}
auto &attach = update_state_node->inputs().at(kAttachIndex);
auto &attach = update_state_node->input(kAttachIndex);
if (IsPrimitiveCNode(attach, prim::kPrimDepend)) {
return EliminateUpdateStateWithDepend(update_state_node, attach->cast<CNodePtr>());
}
@ -586,9 +596,10 @@ AnfNodePtr UpdatestateEliminater::operator()(const OptimizerPtr &, const AnfNode
return new_node;
}
}
std::vector<CNodePtr> update_states;
std::vector<CNodePtr> loads;
if (GetLoadsFromUpdateState(update_state_node, &loads) > 1 && loads.size() > 1) {
return EliminateUpdateStateForLoads(update_state_node, loads);
if (GetLoadsFromUpdateState(update_state_node, &update_states, &loads) > 1 && loads.size() > 1) {
return EliminateUpdateStateForLoads(update_state_node, update_states, loads);
}
// Eliminate UpdateStates that attaches a no-side-effect node.
if (!attach_is_load && !attach_is_tuple) {

File diff suppressed because it is too large Load Diff

@ -59,6 +59,8 @@ SubstitutionPtr MakeSubstitution(const OptimizerCallerPtr &transform, const std:
SubstitutionPtr MakeSubstitution(const OptimizerCallerPtr &transform, const std::string &name,
const PredicateFuncType &predicate, const RenormAction &action_renorm = CHECK_RENORM);
enum OptTraverseSubstitutionsMode { kOptTraverseFromIRToSubstitutions = 0, kOptTraverseFromSubstitutionsToIR };
class SubstitutionList {
public:
explicit SubstitutionList(const std::vector<SubstitutionPtr> &patterns, bool is_once = false)
@ -68,7 +70,10 @@ class SubstitutionList {
bool operator()(const FuncGraphPtr &func_graph, const OptimizerPtr &optimizer) const;
private:
bool ApplyTransform(const OptimizerPtr &optimizer, const AnfNodePtr &node, const SubstitutionPtr &transform) const;
bool ApplyIRToSubstitutions(const OptimizerPtr &optimizer, const FuncGraphPtr &func_graph) const;
bool ApplySubstitutionToIR(const OptimizerPtr &optimizer, const AnfNodePtr &node, const SubstitutionPtr &sub) const;
bool ApplySubstitutionsToIR(const OptimizerPtr &optimizer, const FuncGraphPtr &func_graph) const;
bool ApplySubstitutionsToIRForIsolate(const OptimizerPtr &optimizer) const;
std::vector<SubstitutionPtr> list_;
// a flag to mark this list of Substitution can only be executed only once
bool is_once_;

@ -88,13 +88,14 @@ using OptPassGroupMap = std::vector<std::pair<std::string, OptPassConfig>>;
class Optimizer : public std::enable_shared_from_this<Optimizer> {
public:
Optimizer(const std::string &name, const pipeline::ResourceBasePtr &resource_ptr)
Optimizer(const std::string &name, const pipeline::ResourceBasePtr &resource_ptr, bool traverse_nodes_first = true)
: name_(name),
resource_(resource_ptr),
run_only_once_(false),
is_watch_renormalize_(false),
is_enable_(true),
is_untyped_generated_(false) {}
is_untyped_generated_(false),
traverse_nodes_first_(traverse_nodes_first) {}
virtual ~Optimizer() = default;
void Init(const OptPassGroupMap &passes, bool run_only_once) {
@ -129,8 +130,8 @@ class Optimizer : public std::enable_shared_from_this<Optimizer> {
static std::shared_ptr<Optimizer> MakeOptimizer(const std::string &name, const pipeline::ResourceBasePtr resource_ptr,
const OptPassGroupMap &passes, bool run_only_once = false,
bool watch_renormalize = false) {
OptimizerPtr optimizer = std::make_shared<Optimizer>(name, resource_ptr);
bool watch_renormalize = false, bool traverse_nodes_first = true) {
OptimizerPtr optimizer = std::make_shared<Optimizer>(name, resource_ptr, traverse_nodes_first);
optimizer->Init(passes, run_only_once);
if (watch_renormalize) {
optimizer->enable_watch_renormalize();
@ -223,6 +224,8 @@ class Optimizer : public std::enable_shared_from_this<Optimizer> {
bool is_watch_renormalize() { return is_watch_renormalize_; }
void set_enable(bool enable) { is_enable_ = enable; }
bool traverse_nodes_first() { return traverse_nodes_first_; }
struct {
int64_t counter;
std::string name;
@ -239,6 +242,7 @@ class Optimizer : public std::enable_shared_from_this<Optimizer> {
bool is_watch_renormalize_;
bool is_enable_;
bool is_untyped_generated_;
bool traverse_nodes_first_;
};
} // namespace opt
} // namespace mindspore

@ -308,7 +308,8 @@ bool ResolveFuncGraph(const FuncGraphPtr &func_graph, const pipeline::ResourceBa
return false;
}
opt::irpass::ResolveIRPassLib irpass;
opt::OptimizerPtr opt_resolve = opt::Optimizer::MakeOptimizer("opt_resolve", res, GetOptResolvePasses(irpass));
opt::OptimizerPtr opt_resolve =
opt::Optimizer::MakeOptimizer("opt_resolve", res, GetOptResolvePasses(irpass), false, false, false);
(void)parse::python_adapter::set_python_scoped();

@ -246,7 +246,7 @@ class CNode : public AnfNode, public EffectInfoHolder {
bool IsApply(const PrimitivePtr &) const;
const size_t size() const { return inputs_.size(); }
const AnfNodePtr input(size_t i) const { return inputs_[i]; }
const AnfNodePtr &input(size_t i) const { return inputs_.at(i); }
const std::vector<AnfNodePtr> &inputs() const { return inputs_; }
void add_input(const AnfNodePtr &input);
void set_input(size_t i, const AnfNodePtr &input);

Loading…
Cancel
Save