!13146 Make a_2 and b_1 pass as global sensitive, and traverse from SUB to IR.

From: @zh_qh
Reviewed-by: 
Signed-off-by:
pull/13146/MERGE
mindspore-ci-bot 4 years ago committed by Gitee
commit e1666355a8

@ -184,9 +184,6 @@ bool SubstitutionList::ApplyIRToSubstitutions(const OptimizerPtr &optimizer, con
for (auto &substitution : list_) { for (auto &substitution : list_) {
auto res = DoTransform(optimizer, node, substitution); auto res = DoTransform(optimizer, node, substitution);
if (res != nullptr) { if (res != nullptr) {
if (is_once_) {
return true;
}
change = true; change = true;
changes = true; changes = true;
node = res; node = res;
@ -228,9 +225,6 @@ bool SubstitutionList::ApplySubstitutionToIR(const OptimizerPtr &optimizer, cons
bool change = false; bool change = false;
auto res = DoTransform(optimizer, node, substitution); auto res = DoTransform(optimizer, node, substitution);
if (res != nullptr) { if (res != nullptr) {
if (is_once_) {
return true;
}
change = true; change = true;
changes = true; changes = true;
node = res; node = res;
@ -316,9 +310,13 @@ bool SubstitutionList::operator()(const FuncGraphPtr &func_graph, const Optimize
: kOptTraverseFromSubstitutionsToIR); : kOptTraverseFromSubstitutionsToIR);
if (traverse_mode == kOptTraverseFromIRToSubstitutions && if (traverse_mode == kOptTraverseFromIRToSubstitutions &&
MsContext::GetInstance()->get_param<int>(MS_CTX_EXECUTION_MODE) != kPynativeMode && MsContext::GetInstance()->get_param<int>(MS_CTX_EXECUTION_MODE) != kPynativeMode &&
optimizer->traverse_nodes_first()) { optimizer->traverse_nodes_first() && !is_once_ && !global_sensitive_) {
MS_LOG(DEBUG) << "IR >> SUB, " << optimizer->name() << "(r" << optimizer->CurPass_.counter << ")_"
<< optimizer->CurPass_.name;
changes = ApplyIRToSubstitutions(optimizer, func_graph); changes = ApplyIRToSubstitutions(optimizer, func_graph);
} else { } else {
MS_LOG(DEBUG) << "SUB >> IR, " << optimizer->name() << "(r" << optimizer->CurPass_.counter << ")_"
<< optimizer->CurPass_.name;
changes = ApplySubstitutionsToIR(optimizer, func_graph); changes = ApplySubstitutionsToIR(optimizer, func_graph);
} }
return changes; return changes;

@ -63,8 +63,9 @@ enum OptTraverseSubstitutionsMode { kOptTraverseFromIRToSubstitutions = 0, kOptT
class SubstitutionList { class SubstitutionList {
public: public:
explicit SubstitutionList(const std::vector<SubstitutionPtr> &patterns, bool is_once = false) explicit SubstitutionList(const std::vector<SubstitutionPtr> &patterns, bool is_once = false,
: list_(patterns), is_once_(is_once) {} bool global_sensitive = false)
: list_(patterns), is_once_(is_once), global_sensitive_(global_sensitive) {}
~SubstitutionList() = default; ~SubstitutionList() = default;
bool operator()(const FuncGraphPtr &func_graph, const OptimizerPtr &optimizer) const; bool operator()(const FuncGraphPtr &func_graph, const OptimizerPtr &optimizer) const;
@ -77,6 +78,7 @@ class SubstitutionList {
std::vector<SubstitutionPtr> list_; std::vector<SubstitutionPtr> list_;
// a flag to mark this list of Substitution can only be executed only once // a flag to mark this list of Substitution can only be executed only once
bool is_once_; bool is_once_;
bool global_sensitive_;
}; };
} // namespace opt } // namespace opt
} // namespace mindspore } // namespace mindspore

@ -43,10 +43,10 @@ using OptimizeGraphFunc = std::function<bool(const FuncGraphPtr &func_graph, con
class OptPassConfig { class OptPassConfig {
public: public:
explicit OptPassConfig(const OptimizeGraphFunc &func) : func_(func) {} explicit OptPassConfig(const OptimizeGraphFunc &func) : func_(func) {}
explicit OptPassConfig(const std::vector<SubstitutionPtr> &list, bool is_once = false) explicit OptPassConfig(const std::vector<SubstitutionPtr> &list, bool is_once = false, bool global_sensitive = false)
: list_(list), is_once_(is_once) {} : list_(list), is_once_(is_once), global_sensitive_(global_sensitive) {}
OptPassConfig(const std::initializer_list<SubstitutionPtr> &list, bool is_once = false) OptPassConfig(const std::initializer_list<SubstitutionPtr> &list, bool is_once = false, bool global_sensitive = false)
: list_(list), is_once_(is_once) {} : list_(list), is_once_(is_once), global_sensitive_(global_sensitive) {}
~OptPassConfig() = default; ~OptPassConfig() = default;
const std::vector<SubstitutionPtr> &list() const { return list_; } const std::vector<SubstitutionPtr> &list() const { return list_; }
@ -57,6 +57,8 @@ class OptPassConfig {
const bool is_once() const { return is_once_; } const bool is_once() const { return is_once_; }
const bool global_sensitive() const { return global_sensitive_; }
private: private:
OptPassConfig() : is_renormalize_(true) {} OptPassConfig() : is_renormalize_(true) {}
@ -64,6 +66,7 @@ class OptPassConfig {
std::vector<SubstitutionPtr> list_; std::vector<SubstitutionPtr> list_;
bool is_renormalize_{false}; bool is_renormalize_{false};
bool is_once_{false}; bool is_once_{false};
bool global_sensitive_{false};
}; };
class OptPass { class OptPass {
@ -115,7 +118,7 @@ class Optimizer : public std::enable_shared_from_this<Optimizer> {
} }
if (config.list().size() > 0) { if (config.list().size() > 0) {
OptimizeGraphFunc func = SubstitutionList(config.list(), config.is_once()); OptimizeGraphFunc func = SubstitutionList(config.list(), config.is_once(), config.global_sensitive());
passes_.push_back(OptPass(func)); passes_.push_back(OptPass(func));
continue; continue;
} }

@ -136,19 +136,21 @@ OptPassGroupMap GetOptPassesA(const opt::irpass::OptimizeIRPassLib &irpass) {
irpass.stopgrad_eliminater_, irpass.stopgrad_eliminater_,
irpass.sparse_tensor_eliminate_, irpass.sparse_tensor_eliminate_,
}); });
opt::OptPassConfig a_2 = opt::OptPassConfig({ opt::OptPassConfig a_2 = opt::OptPassConfig(
irpass.merge_addn_, {
irpass.float_tuple_getitem_switch_, irpass.merge_addn_,
irpass.float_env_getitem_switch_, irpass.float_tuple_getitem_switch_,
irpass.incorporate_getitem_set_, irpass.float_env_getitem_switch_,
irpass.incorporate_call_, irpass.incorporate_getitem_set_,
irpass.incorporate_call_switch_, irpass.incorporate_call_,
irpass.incorporate_env_getitem_bypass_recursive_, irpass.incorporate_call_switch_,
irpass.incorporate_env_getitem_switch_, irpass.incorporate_env_getitem_bypass_recursive_,
irpass.new_env_get_item_, irpass.incorporate_env_getitem_switch_,
irpass.depend_value_elim_, irpass.new_env_get_item_,
irpass.all_reduce_const_elim_, irpass.depend_value_elim_,
}); irpass.all_reduce_const_elim_,
},
false, true);
opt::OptPassConfig a_after_grad = opt::OptPassConfig({ opt::OptPassConfig a_after_grad = opt::OptPassConfig({
irpass.inline_without_move_, irpass.inline_without_move_,
}); });
@ -229,7 +231,8 @@ OptPassGroupMap GetOptPassesB(const opt::irpass::OptimizeIRPassLib &irpass) {
irpass.reset_defer_inline_, irpass.inline_, irpass.updatestate_eliminater_, irpass.load_eliminater_, irpass.reset_defer_inline_, irpass.inline_, irpass.updatestate_eliminater_, irpass.load_eliminater_,
irpass.stopgrad_eliminater_, irpass.special_op_eliminate_, irpass.get_make_ref_eliminate_, irpass.stopgrad_eliminater_, irpass.special_op_eliminate_, irpass.get_make_ref_eliminate_,
irpass.incorporate_env_getitem_, irpass.incorporate_env_getitem_switch_, irpass.env_get_item_eliminate_, irpass.incorporate_env_getitem_, irpass.incorporate_env_getitem_switch_, irpass.env_get_item_eliminate_,
irpass.incorporate_env_getitem_switch_layer_, irpass.value_based_eliminate_, irpass.receive_eliminate_}); irpass.incorporate_env_getitem_switch_layer_, irpass.value_based_eliminate_, irpass.receive_eliminate_},
false, true);
opt::OptPassConfig b_2 = opt::OptPassConfig({ opt::OptPassConfig b_2 = opt::OptPassConfig({
irpass.replace_refkey_by_param_, irpass.replace_refkey_by_param_,
irpass.make_ref_eliminate_, irpass.make_ref_eliminate_,

Loading…
Cancel
Save