reduce pass traversals if no pre-ad python pass exists

pull/6002/head
BowenK 4 years ago
parent 2f529c149a
commit e482e4e8bf

@ -49,6 +49,7 @@ class PassGroup {
bool Run(const FuncGraphPtr &func_graph, const std::vector<PythonPassPtr> &passes, const MatchResultPtr &res) const;
std::string name() const { return name_; }
void SetRunOnlyOnce(bool run_only_once) { run_only_once_ = run_only_once; }
size_t size() { return passes_.size(); }
private:
const std::string name_;

@ -301,7 +301,12 @@ bool OptimizeAction(const ResourcePtr &res, const std::vector<PassItem> &passes)
return true;
}
bool OptInlineAction(const ResourcePtr &res) { return OptimizeAction(res, kInlinePasses); }
bool OptInlineAction(const ResourcePtr &res) {
if (opt::python_pass::PyPassManager::GetInstance()->GetPassGroup(opt::python_pass::Phase::PREAD)->size() != 0) {
return OptimizeAction(res, kInlinePasses);
}
return true;
}
bool GeOptimizeAction(const ResourcePtr &res) { return OptimizeAction(res, kGePasses); }

@ -13,14 +13,14 @@
# limitations under the License.
# ============================================================================
"""Reference for python pass registration."""
from .python_pass_register import registe_pass, unregiste_pass, gen_new_parameter, cancel_new_parameter, _set_renorm,\
_set_reopt
from .python_pass_register import registe_pass, unregiste_pass, gen_new_parameter, cancel_new_parameter, set_renorm,\
set_reopt
__all__ = [
"registe_pass",
"unregiste_pass",
"gen_new_parameter",
"cancel_new_parameter",
"_set_renorm",
"_set_reopt"
"set_renorm",
"set_reopt"
]

@ -23,8 +23,8 @@ __all__ = [
"unregiste_pass",
"gen_new_parameter",
"cancel_new_parameter",
"_set_renorm",
"_set_reopt"
"set_renorm",
"set_reopt"
]
class PyPassManager(PyPassManager_):
r"""
@ -162,7 +162,7 @@ def cancel_new_parameter(pattern):
ppm = PyPassManager()
ppm.unregiste(pattern.para_name)
def _set_renorm(should_renorm):
def set_renorm(should_renorm):
"""
Set whether or not to do renormalization after modified graph in python pass(es).
@ -176,7 +176,7 @@ def _set_renorm(should_renorm):
ppm = PyPassManager()
ppm.set_renorm(should_renorm)
def _set_reopt(do_reopt):
def set_reopt(do_reopt):
"""
Set whether or not to do optimization after modified graph in python pass(es).

@ -19,8 +19,8 @@ import mindspore.nn as nn
from mindspore import context
from mindspore.common.tensor import Tensor
from mindspore.ops import operations as P
from mindspore.graph_utils.python_pass import registe_pass, unregiste_pass, _set_renorm, gen_new_parameter,\
cancel_new_parameter, _set_reopt
from mindspore.graph_utils.python_pass import registe_pass, unregiste_pass, set_renorm, gen_new_parameter,\
cancel_new_parameter, set_reopt
from mindspore.common.api import _generate_pip_args
from mindspore._c_expression import generate_key, Executor_
from mindspore.graph_utils.graph_pattern import OneOf, Prim, Call, NoneOf, Any, NewTensor, NewParameter, Imm
@ -157,8 +157,8 @@ def test_isnot_pattern_0():
Test IsNot pattern which expresses the IsNot semantics.
Case: IsNot pass failed to match
"""
_set_renorm(False)
_set_reopt(False)
set_renorm(False)
set_reopt(False)
class ConvBN(nn.Cell):
def __init__(self):
super(ConvBN, self).__init__()
@ -202,7 +202,7 @@ def test_isnot_pattern_0():
unregiste_pass(bn_pass)
assert "ReLU6" not in transformed_repr
assert "Softmax" in transformed_repr
_set_renorm(True)
set_renorm(True)
def test_isnot_pattern_1():
"""
@ -234,8 +234,8 @@ def test_newtensor_pattern():
"""
Test NewTensor pattern in the target
"""
_set_renorm(False)
_set_reopt(False)
set_renorm(False)
set_reopt(False)
inputs = Tensor(np.ones([42]), mindspore.float16)
softmax_model = nn.Softmax()
@ -252,7 +252,7 @@ def test_newtensor_pattern():
unregiste_pass(softmax_addn_pass)
assert "AddN" in transformed_repr
assert "Softmax" not in transformed_repr
_set_renorm(True)
set_renorm(True)
def test_newparameter_pattern():
"""
@ -261,8 +261,8 @@ def test_newparameter_pattern():
inputs = Tensor(np.ones([42]), mindspore.float16)
softmax_model = nn.Softmax()
_set_renorm(False)
_set_reopt(False)
set_renorm(False)
set_reopt(False)
@registe_pass(requires_grad=False, run_only_once=True)
def softmax_addn_pass():
x = Any()
@ -288,8 +288,8 @@ def test_imm_target():
inputs = Tensor(np.ones([42]), mindspore.float16)
softmax_model = nn.Softmax()
_set_renorm(False)
_set_reopt(False)
set_renorm(False)
set_reopt(False)
@registe_pass(requires_grad=False, run_only_once=True)
def softmax_pass():
x = Any()
@ -313,8 +313,8 @@ def test_gen_new_parameter():
default_tensor = Tensor(np.ones((4, 4)), mindspore.float32)
new_para = NewParameter("Merlin", default_tensor)
_set_renorm(False)
_set_reopt(False)
set_renorm(False)
set_reopt(False)
gen_new_parameter(new_para)
@registe_pass(requires_grad=False, run_only_once=True)
def softmax_make_tuple_pass():

Loading…
Cancel
Save