diff --git a/mindspore/ccsrc/frontend/optimizer/pass_group.h b/mindspore/ccsrc/frontend/optimizer/pass_group.h index 20cf66649f..498256f7ff 100644 --- a/mindspore/ccsrc/frontend/optimizer/pass_group.h +++ b/mindspore/ccsrc/frontend/optimizer/pass_group.h @@ -49,6 +49,7 @@ class PassGroup { bool Run(const FuncGraphPtr &func_graph, const std::vector &passes, const MatchResultPtr &res) const; std::string name() const { return name_; } void SetRunOnlyOnce(bool run_only_once) { run_only_once_ = run_only_once; } + size_t size() { return passes_.size(); } private: const std::string name_; diff --git a/mindspore/ccsrc/pipeline/jit/action.cc b/mindspore/ccsrc/pipeline/jit/action.cc index 928c317848..43819db23d 100644 --- a/mindspore/ccsrc/pipeline/jit/action.cc +++ b/mindspore/ccsrc/pipeline/jit/action.cc @@ -301,7 +301,12 @@ bool OptimizeAction(const ResourcePtr &res, const std::vector &passes) return true; } -bool OptInlineAction(const ResourcePtr &res) { return OptimizeAction(res, kInlinePasses); } +bool OptInlineAction(const ResourcePtr &res) { + if (opt::python_pass::PyPassManager::GetInstance()->GetPassGroup(opt::python_pass::Phase::PREAD)->size() != 0) { + return OptimizeAction(res, kInlinePasses); + } + return true; +} bool GeOptimizeAction(const ResourcePtr &res) { return OptimizeAction(res, kGePasses); } diff --git a/mindspore/graph_utils/python_pass/__init__.py b/mindspore/graph_utils/python_pass/__init__.py index 5fa33e3502..d9fe61c870 100644 --- a/mindspore/graph_utils/python_pass/__init__.py +++ b/mindspore/graph_utils/python_pass/__init__.py @@ -13,14 +13,14 @@ # limitations under the License. # ============================================================================ """Reference for python pass registration.""" -from .python_pass_register import registe_pass, unregiste_pass, gen_new_parameter, cancel_new_parameter, _set_renorm,\ - _set_reopt +from .python_pass_register import registe_pass, unregiste_pass, gen_new_parameter, cancel_new_parameter, set_renorm,\ + set_reopt __all__ = [ "registe_pass", "unregiste_pass", "gen_new_parameter", "cancel_new_parameter", - "_set_renorm", - "_set_reopt" + "set_renorm", + "set_reopt" ] diff --git a/mindspore/graph_utils/python_pass/python_pass_register.py b/mindspore/graph_utils/python_pass/python_pass_register.py index 37427f177d..55e1a8b59a 100644 --- a/mindspore/graph_utils/python_pass/python_pass_register.py +++ b/mindspore/graph_utils/python_pass/python_pass_register.py @@ -23,8 +23,8 @@ __all__ = [ "unregiste_pass", "gen_new_parameter", "cancel_new_parameter", - "_set_renorm", - "_set_reopt" + "set_renorm", + "set_reopt" ] class PyPassManager(PyPassManager_): r""" @@ -162,7 +162,7 @@ def cancel_new_parameter(pattern): ppm = PyPassManager() ppm.unregiste(pattern.para_name) -def _set_renorm(should_renorm): +def set_renorm(should_renorm): """ Set whether or not to do renormalization after modified graph in python pass(es). @@ -176,7 +176,7 @@ def _set_renorm(should_renorm): ppm = PyPassManager() ppm.set_renorm(should_renorm) -def _set_reopt(do_reopt): +def set_reopt(do_reopt): """ Set whether or not to do optimization after modified graph in python pass(es). diff --git a/tests/ut/python/optimizer/test_python_pass.py b/tests/ut/python/optimizer/test_python_pass.py index 379f83f5f9..b6df9a1a20 100644 --- a/tests/ut/python/optimizer/test_python_pass.py +++ b/tests/ut/python/optimizer/test_python_pass.py @@ -19,8 +19,8 @@ import mindspore.nn as nn from mindspore import context from mindspore.common.tensor import Tensor from mindspore.ops import operations as P -from mindspore.graph_utils.python_pass import registe_pass, unregiste_pass, _set_renorm, gen_new_parameter,\ - cancel_new_parameter, _set_reopt +from mindspore.graph_utils.python_pass import registe_pass, unregiste_pass, set_renorm, gen_new_parameter,\ + cancel_new_parameter, set_reopt from mindspore.common.api import _generate_pip_args from mindspore._c_expression import generate_key, Executor_ from mindspore.graph_utils.graph_pattern import OneOf, Prim, Call, NoneOf, Any, NewTensor, NewParameter, Imm @@ -157,8 +157,8 @@ def test_isnot_pattern_0(): Test IsNot pattern which expresses the IsNot semantics. Case: IsNot pass failed to match """ - _set_renorm(False) - _set_reopt(False) + set_renorm(False) + set_reopt(False) class ConvBN(nn.Cell): def __init__(self): super(ConvBN, self).__init__() @@ -202,7 +202,7 @@ def test_isnot_pattern_0(): unregiste_pass(bn_pass) assert "ReLU6" not in transformed_repr assert "Softmax" in transformed_repr - _set_renorm(True) + set_renorm(True) def test_isnot_pattern_1(): """ @@ -234,8 +234,8 @@ def test_newtensor_pattern(): """ Test NewTensor pattern in the target """ - _set_renorm(False) - _set_reopt(False) + set_renorm(False) + set_reopt(False) inputs = Tensor(np.ones([42]), mindspore.float16) softmax_model = nn.Softmax() @@ -252,7 +252,7 @@ def test_newtensor_pattern(): unregiste_pass(softmax_addn_pass) assert "AddN" in transformed_repr assert "Softmax" not in transformed_repr - _set_renorm(True) + set_renorm(True) def test_newparameter_pattern(): """ @@ -261,8 +261,8 @@ def test_newparameter_pattern(): inputs = Tensor(np.ones([42]), mindspore.float16) softmax_model = nn.Softmax() - _set_renorm(False) - _set_reopt(False) + set_renorm(False) + set_reopt(False) @registe_pass(requires_grad=False, run_only_once=True) def softmax_addn_pass(): x = Any() @@ -288,8 +288,8 @@ def test_imm_target(): inputs = Tensor(np.ones([42]), mindspore.float16) softmax_model = nn.Softmax() - _set_renorm(False) - _set_reopt(False) + set_renorm(False) + set_reopt(False) @registe_pass(requires_grad=False, run_only_once=True) def softmax_pass(): x = Any() @@ -313,8 +313,8 @@ def test_gen_new_parameter(): default_tensor = Tensor(np.ones((4, 4)), mindspore.float32) new_para = NewParameter("Merlin", default_tensor) - _set_renorm(False) - _set_reopt(False) + set_renorm(False) + set_reopt(False) gen_new_parameter(new_para) @registe_pass(requires_grad=False, run_only_once=True) def softmax_make_tuple_pass():