!13146 Make the a_2 and b_1 passes global-sensitive, and traverse from SUB to IR.

From: @zh_qh
pull/13146/MERGE
Committed by mindspore-ci-bot via Gitee
commit e1666355a8

@@ -184,9 +184,6 @@ bool SubstitutionList::ApplyIRToSubstitutions(const OptimizerPtr &optimizer, con
   for (auto &substitution : list_) {
     auto res = DoTransform(optimizer, node, substitution);
     if (res != nullptr) {
-      if (is_once_) {
-        return true;
-      }
       change = true;
       changes = true;
       node = res;
@@ -228,9 +225,6 @@ bool SubstitutionList::ApplySubstitutionToIR(const OptimizerPtr &optimizer, cons
   bool change = false;
   auto res = DoTransform(optimizer, node, substitution);
   if (res != nullptr) {
-    if (is_once_) {
-      return true;
-    }
     change = true;
     changes = true;
     node = res;
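For context, a minimal, self-contained sketch (stand-in names, not MindSpore's real types) of what removing the early return changes in the two loops above. Before, a list constructed with is_once = true bailed out of the whole pass at the first successful DoTransform; now the hit is recorded and the remaining substitutions still run against the rewritten node.

#include <functional>
#include <vector>

struct Node {};                                      // stand-in for AnfNodePtr
using Substitution = std::function<Node *(Node *)>;  // stand-in for DoTransform

// Post-change loop shape: a successful transform no longer returns early.
bool ApplyAll(const std::vector<Substitution> &list, Node *node) {
  bool changes = false;
  for (const auto &sub : list) {
    if (Node *res = sub(node)) {
      // Removed here: if (is_once_) { return true; }
      changes = true;
      node = res;  // later substitutions see the rewritten node
    }
  }
  return changes;
}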
@@ -316,9 +310,13 @@ bool SubstitutionList::operator()(const FuncGraphPtr &func_graph, const Optimize
                                   : kOptTraverseFromSubstitutionsToIR);
   if (traverse_mode == kOptTraverseFromIRToSubstitutions &&
       MsContext::GetInstance()->get_param<int>(MS_CTX_EXECUTION_MODE) != kPynativeMode &&
-      optimizer->traverse_nodes_first()) {
+      optimizer->traverse_nodes_first() && !is_once_ && !global_sensitive_) {
+    MS_LOG(DEBUG) << "IR >> SUB, " << optimizer->name() << "(r" << optimizer->CurPass_.counter << ")_"
+                  << optimizer->CurPass_.name;
     changes = ApplyIRToSubstitutions(optimizer, func_graph);
   } else {
+    MS_LOG(DEBUG) << "SUB >> IR, " << optimizer->name() << "(r" << optimizer->CurPass_.counter << ")_"
+                  << optimizer->CurPass_.name;
     changes = ApplySubstitutionsToIR(optimizer, func_graph);
   }
   return changes;
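The fast-path test above now has four conjuncts: graph (non-PyNative) mode, the optimizer opting into node-first traversal, not once-only, and not global-sensitive. A condensed sketch of just that predicate, with the MsContext lookup folded into a bool parameter:

// Node-first (IR >> SUB) visits each node once and tries every substitution
// on it; the two new conjuncts keep once-only and global-sensitive lists on
// the substitution-first (SUB >> IR) path instead.
bool UseNodeFirstTraversal(bool is_graph_mode, bool traverse_nodes_first,
                           bool is_once, bool global_sensitive) {
  return is_graph_mode && traverse_nodes_first && !is_once && !global_sensitive;
}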

@@ -63,8 +63,9 @@ enum OptTraverseSubstitutionsMode { kOptTraverseFromIRToSubstitutions = 0, kOptT
 class SubstitutionList {
  public:
-  explicit SubstitutionList(const std::vector<SubstitutionPtr> &patterns, bool is_once = false)
-      : list_(patterns), is_once_(is_once) {}
+  explicit SubstitutionList(const std::vector<SubstitutionPtr> &patterns, bool is_once = false,
+                            bool global_sensitive = false)
+      : list_(patterns), is_once_(is_once), global_sensitive_(global_sensitive) {}
   ~SubstitutionList() = default;
 
   bool operator()(const FuncGraphPtr &func_graph, const OptimizerPtr &optimizer) const;
@@ -77,6 +78,7 @@ class SubstitutionList {
   std::vector<SubstitutionPtr> list_;
   // a flag to mark this list of Substitution can only be executed only once
   bool is_once_;
+  bool global_sensitive_;
 };
 }  // namespace opt
 }  // namespace mindspore
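Since the new parameter is defaulted, existing SubstitutionList call sites compile unchanged; only lists that opt in pass a third argument. Assuming the declaration above and a std::vector<SubstitutionPtr> named patterns (a placeholder), the three construction shapes are:

SubstitutionList regular(patterns);                 // is_once = false, global_sensitive = false
SubstitutionList once_only(patterns, true);         // once-only, as before this patch
SubstitutionList sensitive(patterns, false, true);  // new: pinned to the SUB >> IR path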

@@ -43,10 +43,10 @@ using OptimizeGraphFunc = std::function<bool(const FuncGraphPtr &func_graph, con
 class OptPassConfig {
  public:
   explicit OptPassConfig(const OptimizeGraphFunc &func) : func_(func) {}
-  explicit OptPassConfig(const std::vector<SubstitutionPtr> &list, bool is_once = false)
-      : list_(list), is_once_(is_once) {}
-  OptPassConfig(const std::initializer_list<SubstitutionPtr> &list, bool is_once = false)
-      : list_(list), is_once_(is_once) {}
+  explicit OptPassConfig(const std::vector<SubstitutionPtr> &list, bool is_once = false, bool global_sensitive = false)
+      : list_(list), is_once_(is_once), global_sensitive_(global_sensitive) {}
+  OptPassConfig(const std::initializer_list<SubstitutionPtr> &list, bool is_once = false, bool global_sensitive = false)
+      : list_(list), is_once_(is_once), global_sensitive_(global_sensitive) {}
   ~OptPassConfig() = default;
 
   const std::vector<SubstitutionPtr> &list() const { return list_; }
@@ -57,6 +57,8 @@ class OptPassConfig {
   const bool is_once() const { return is_once_; }
+  const bool global_sensitive() const { return global_sensitive_; }
+
  private:
   OptPassConfig() : is_renormalize_(true) {}
@@ -64,6 +66,7 @@ class OptPassConfig {
   std::vector<SubstitutionPtr> list_;
   bool is_renormalize_{false};
   bool is_once_{false};
+  bool global_sensitive_{false};
 };
 
 class OptPass {
@@ -115,7 +118,7 @@ class Optimizer : public std::enable_shared_from_this<Optimizer> {
     }
 
     if (config.list().size() > 0) {
-      OptimizeGraphFunc func = SubstitutionList(config.list(), config.is_once());
+      OptimizeGraphFunc func = SubstitutionList(config.list(), config.is_once(), config.global_sensitive());
       passes_.push_back(OptPass(func));
       continue;
     }
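The '+' line above relies on SubstitutionList converting to OptimizeGraphFunc through its bool operator()(const FuncGraphPtr &, const OptimizerPtr &) const (see the header hunk earlier), which matches the std::function signature, so the flags are copied into the pass once at construction time. A minimal sketch of that pattern with stand-in types:

#include <functional>

struct FuncGraph {};  // stand-in for FuncGraphPtr's pointee
struct Optimizer {};  // stand-in for the real Optimizer

using GraphFn = std::function<bool(FuncGraph *, Optimizer *)>;

// A callable class whose call operator matches the signature converts
// implicitly to the std::function, carrying its flags by copy.
class PassList {
 public:
  PassList(bool is_once, bool global_sensitive)
      : is_once_(is_once), global_sensitive_(global_sensitive) {}
  bool operator()(FuncGraph *, Optimizer *) const {
    // Dummy body standing in for "apply every substitution, report changes".
    return !is_once_ && !global_sensitive_;
  }

 private:
  bool is_once_;
  bool global_sensitive_;
};

// Mirrors: OptimizeGraphFunc func = SubstitutionList(list, is_once, global_sensitive);
GraphFn MakePass() { return PassList(/*is_once=*/false, /*global_sensitive=*/true); }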

@@ -136,19 +136,21 @@ OptPassGroupMap GetOptPassesA(const opt::irpass::OptimizeIRPassLib &irpass) {
     irpass.stopgrad_eliminater_,
     irpass.sparse_tensor_eliminate_,
   });
-  opt::OptPassConfig a_2 = opt::OptPassConfig({
-    irpass.merge_addn_,
-    irpass.float_tuple_getitem_switch_,
-    irpass.float_env_getitem_switch_,
-    irpass.incorporate_getitem_set_,
-    irpass.incorporate_call_,
-    irpass.incorporate_call_switch_,
-    irpass.incorporate_env_getitem_bypass_recursive_,
-    irpass.incorporate_env_getitem_switch_,
-    irpass.new_env_get_item_,
-    irpass.depend_value_elim_,
-    irpass.all_reduce_const_elim_,
-  });
+  opt::OptPassConfig a_2 = opt::OptPassConfig(
+    {
+      irpass.merge_addn_,
+      irpass.float_tuple_getitem_switch_,
+      irpass.float_env_getitem_switch_,
+      irpass.incorporate_getitem_set_,
+      irpass.incorporate_call_,
+      irpass.incorporate_call_switch_,
+      irpass.incorporate_env_getitem_bypass_recursive_,
+      irpass.incorporate_env_getitem_switch_,
+      irpass.new_env_get_item_,
+      irpass.depend_value_elim_,
+      irpass.all_reduce_const_elim_,
+    },
+    false, true);
   opt::OptPassConfig a_after_grad = opt::OptPassConfig({
     irpass.inline_without_move_,
   });
@@ -229,7 +231,8 @@ OptPassGroupMap GetOptPassesB(const opt::irpass::OptimizeIRPassLib &irpass) {
     irpass.reset_defer_inline_, irpass.inline_, irpass.updatestate_eliminater_, irpass.load_eliminater_,
     irpass.stopgrad_eliminater_, irpass.special_op_eliminate_, irpass.get_make_ref_eliminate_,
     irpass.incorporate_env_getitem_, irpass.incorporate_env_getitem_switch_, irpass.env_get_item_eliminate_,
-    irpass.incorporate_env_getitem_switch_layer_, irpass.value_based_eliminate_, irpass.receive_eliminate_});
+    irpass.incorporate_env_getitem_switch_layer_, irpass.value_based_eliminate_, irpass.receive_eliminate_},
+    false, true);
   opt::OptPassConfig b_2 = opt::OptPassConfig({
     irpass.replace_refkey_by_param_,
     irpass.make_ref_eliminate_,
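In both call sites above the trailing arguments are positional booleans: false is is_once and true is global_sensitive, so a_2 and b_1 still iterate to a fixed point but are now excluded from the IR >> SUB fast path. Spelled out with comment markers (a sketch, not part of the patch):

opt::OptPassConfig b_1 = opt::OptPassConfig({/* ...irpass members as above... */},
                                            /*is_once=*/false, /*global_sensitive=*/true);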
