parent
465390e580
commit
e7c6b7e66a
File diff suppressed because it is too large
Load Diff
@ -1,81 +0,0 @@
|
||||
# Copyright 2020 Huawei Technologies Co., Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ============================================================================
|
||||
"""Python pass register"""
|
||||
from inspect import isfunction
|
||||
from mindspore.common.graph_pattern import Pattern
|
||||
from mindspore._c_expression import PyPassManager_
|
||||
from mindspore._c_expression import phase
|
||||
|
||||
class PyPassManager(PyPassManager_):
|
||||
r"""
|
||||
Used to registe and unregiste python passes which can be used to alter graphs.
|
||||
|
||||
Args:
|
||||
pipeline_phase (phase): Specify the stage in which the pass will run in the pipeline. Default: phase.opt.
|
||||
run_only_once (bool): Specify whether or not to run pass only once. Default: False.
|
||||
multigraph (bool): Whether or not the pattern exists across graphs. Default: True.
|
||||
|
||||
Raises:
|
||||
TypeError: If argument has invalid type.
|
||||
"""
|
||||
def __init__(self, pipeline_phase=phase.opt, run_only_once=False, multi_graph=True):
|
||||
if not isinstance(pipeline_phase, phase):
|
||||
raise TypeError(f"Expecting phase, got : ({type(pipeline_phase)}){pipeline_phase}")
|
||||
if not isinstance(run_only_once, bool):
|
||||
raise TypeError(f"Expecting bool, got : ({type(run_only_once)}){run_only_once}")
|
||||
if not isinstance(multi_graph, bool):
|
||||
raise TypeError(f"Expecting bool, got : ({type(multi_graph)}){multi_graph}")
|
||||
PyPassManager_.__init__(self)
|
||||
self.phase_ = pipeline_phase
|
||||
self.run_only_once_ = run_only_once
|
||||
self.multi_graph_ = multi_graph
|
||||
|
||||
def registe(self, py_pass):
|
||||
if not isfunction(py_pass):
|
||||
raise TypeError(f"Expecting function pass, got : ({type(py_pass)}){py_pass}")
|
||||
pattern, target = py_pass()
|
||||
pass_name = py_pass.__name__
|
||||
if not isinstance(pattern, Pattern):
|
||||
raise TypeError(f"Expecting pattern of Pattern type, got : ({type(pattern)}){pattern}")
|
||||
if not isinstance(target, Pattern):
|
||||
raise TypeError(f"Expecting target of Pattern type, got : ({type(target)}){target}")
|
||||
super().registe(pass_name, pattern, target, self.phase_, self.run_only_once_, self.multi_graph_)
|
||||
|
||||
def unregiste(self, py_pass, pipeline_phase=phase.opt):
|
||||
if not isinstance(pipeline_phase, phase):
|
||||
raise TypeError(f"Expecting phase, got : ({type(pipeline_phase)}){pipeline_phase}")
|
||||
if isinstance(py_pass, str):
|
||||
super().unregiste(py_pass, pipeline_phase)
|
||||
return
|
||||
if isfunction(py_pass):
|
||||
super().unregiste(py_pass.__name__, pipeline_phase)
|
||||
return
|
||||
raise TypeError(f"Expecting py_pass to be string or function, got ({type(py_pass)}){py_pass}")
|
||||
|
||||
def __call__(self, py_pass):
|
||||
self.registe(py_pass)
|
||||
return py_pass
|
||||
|
||||
def registe_pass(pipeline_phase=phase.opt, run_only_once=False, multi_graph=True):
|
||||
"""
|
||||
Examples:
|
||||
>>> @registe_pass()
|
||||
>>> def toy_pass():
|
||||
>>> def pattern():
|
||||
>>> pass
|
||||
>>> def target():
|
||||
>>> pass
|
||||
"""
|
||||
return PyPassManager(pipeline_phase, run_only_once, multi_graph)
|
@ -0,0 +1,15 @@
|
||||
# Copyright 2020 Huawei Technologies Co., Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ============================================================================
|
||||
"""Top-level reference to python pass."""
|
@ -0,0 +1,24 @@
|
||||
# Copyright 2020 Huawei Technologies Co., Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ============================================================================
|
||||
"""Top-level reference to python pass."""
|
||||
from .python_pass_register import registe_pass, unregiste_pass, gen_new_parameter, cancel_new_parameter, set_renorm
|
||||
|
||||
__all__ = [
|
||||
"registe_pass",
|
||||
"unregiste_pass",
|
||||
"gen_new_parameter",
|
||||
"cancel_new_parameter",
|
||||
"set_renorm"
|
||||
]
|
@ -0,0 +1,170 @@
|
||||
# Copyright 2020 Huawei Technologies Co., Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ============================================================================
|
||||
"""Python pass register"""
|
||||
from inspect import isfunction
|
||||
from mindspore.graph_utils.graph_pattern import Pattern, NewParameter
|
||||
from mindspore._c_expression import PyPassManager_, phase
|
||||
|
||||
|
||||
__all__ = [
|
||||
"registe_pass",
|
||||
"unregiste_pass",
|
||||
"gen_new_parameter",
|
||||
"cancel_new_parameter",
|
||||
"set_renorm"
|
||||
]
|
||||
class PyPassManager(PyPassManager_):
|
||||
r"""
|
||||
Used to registe and unregiste python passes which can be used to alter graphs.
|
||||
|
||||
Args:
|
||||
pipeline_phase (phase): Specify the stage in which the pass will run in the pipeline. Default: phase.opt.
|
||||
run_only_once (bool): Specify whether or not to run pass only once. Default: False.
|
||||
multigraph (bool): Whether or not the pattern exists across graphs. Default: True.
|
||||
|
||||
Raises:
|
||||
TypeError: If argument has invalid type.
|
||||
"""
|
||||
def __init__(self, pipeline_phase=phase.opt, run_only_once=False):
|
||||
if not isinstance(pipeline_phase, phase):
|
||||
raise TypeError(f"Expect phase, got : ({type(pipeline_phase)}){pipeline_phase}")
|
||||
if not isinstance(run_only_once, bool):
|
||||
raise TypeError(f"Expect bool, got : ({type(run_only_once)}){run_only_once}")
|
||||
PyPassManager_.__init__(self)
|
||||
self.phase_ = pipeline_phase
|
||||
self.run_only_once_ = run_only_once
|
||||
|
||||
def registe(self, py_pass):
|
||||
if not isfunction(py_pass):
|
||||
raise TypeError(f"Expect function pass, got : ({type(py_pass)}){py_pass}")
|
||||
pattern, target = py_pass()
|
||||
pass_name = py_pass.__name__
|
||||
if not isinstance(pattern, Pattern):
|
||||
raise TypeError(f"Expect pattern of Pattern type, got : ({type(pattern)}){pattern}")
|
||||
if not isinstance(target, Pattern):
|
||||
raise TypeError(f"Expect target of Pattern type, got : ({type(target)}){target}")
|
||||
super().registe(pass_name, pattern, target, self.phase_, self.run_only_once_)
|
||||
|
||||
def unregiste(self, py_pass, pipeline_phase=phase.opt):
|
||||
if not isinstance(pipeline_phase, phase):
|
||||
raise TypeError(f"Expect phase, got : ({type(pipeline_phase)}){pipeline_phase}")
|
||||
if isinstance(py_pass, str):
|
||||
super().unregiste(py_pass, pipeline_phase)
|
||||
return
|
||||
if isfunction(py_pass):
|
||||
super().unregiste(py_pass.__name__, pipeline_phase)
|
||||
return
|
||||
raise TypeError(f"Expect py_pass to be string or function, got ({type(py_pass)}){py_pass}")
|
||||
|
||||
def __call__(self, py_pass):
|
||||
self.registe(py_pass)
|
||||
return py_pass
|
||||
|
||||
def gen_new_parameter(self, pattern):
|
||||
if not isinstance(pattern, NewParameter):
|
||||
raise TypeError(f"Expect pattern to be a NewParameter Pattern, got {pattern}")
|
||||
super().gen_new_parameter(pattern)
|
||||
|
||||
def set_renorm(self, should_renorm):
|
||||
if not isinstance(should_renorm, bool):
|
||||
raise TypeError(f"Expect should_renorm to be a bool, got {should_renorm}")
|
||||
super().set_renorm(should_renorm)
|
||||
|
||||
def registe_pass(pipeline_phase=phase.opt, run_only_once=False):
|
||||
"""
|
||||
Registe python pass to specified pipeline phase which would be used in compilation.
|
||||
|
||||
Args:
|
||||
pipeline_phase(:class:`mindspore._c_expression.phase`): To which compilation pipeline stage the pass is
|
||||
registed. Support phase.resolve and phase.opt. Default: phase.opt.
|
||||
run_only_once(bool): Run this pass only once if set true. Otherwise run the pass until converge. Default: False.
|
||||
|
||||
Returns:
|
||||
This function should be used as a decorator, return the decoratorated pass function.
|
||||
|
||||
Examples:
|
||||
>>> from mindspore.graph_utils.graph_pattern import IsPrimTypeOf
|
||||
>>> @registe_pass()
|
||||
>>> def toy_pass():
|
||||
>>> pattern = IsPrimTypeOf("ReLU")
|
||||
>>> target = IsPrimTypeOf("ReLU6")
|
||||
>>> return pattern, target
|
||||
"""
|
||||
return PyPassManager(pipeline_phase, run_only_once)
|
||||
|
||||
def unregiste_pass(py_pass, pipeline_phase=phase.opt):
|
||||
"""
|
||||
Unregiste python pass.
|
||||
|
||||
Args:
|
||||
py_pass(Union(str, function)): target python pass to unregiste.
|
||||
pipeline_phase(:class:`mindspore._c_expression.phase`): To which compilation pipeline stage the pass is
|
||||
unregisted. Support phase.resolve and phase.opt. Default: phase.opt.
|
||||
"""
|
||||
ppm = PyPassManager()
|
||||
ppm.unregiste(py_pass, pipeline_phase)
|
||||
|
||||
def gen_new_parameter(pattern):
|
||||
"""
|
||||
Generate specified parameter every time a network gets compiled.
|
||||
|
||||
NOTE:
|
||||
In this way, every pass uses this pattern would be using the same Parameter. If use NewParameter without
|
||||
gen_new_parameter, every pass match would build a new Parameter.
|
||||
This would registe a pass to add new parameter in the compilation pipeline, so later compilation would
|
||||
ALSO add this parameter unless the pass is unregisted. To unregiste this pass, call
|
||||
cancel_new_parameter(pattern)
|
||||
|
||||
Args:
|
||||
pattern (NewParameter): NewParameter type, could be used to build nested patterns across multiple passes
|
||||
after gen_new_parameter.
|
||||
|
||||
Raises:
|
||||
TypeError: If argument has invalid type.
|
||||
|
||||
Examples:
|
||||
>>> from mindspore.graph_utils.graph_pattern import NewParameter
|
||||
>>> abc = NewParameter("abc")
|
||||
>>> gen_new_parameter(abc)
|
||||
"""
|
||||
ppm = PyPassManager()
|
||||
ppm.gen_new_parameter(pattern)
|
||||
|
||||
def cancel_new_parameter(pattern):
|
||||
"""
|
||||
Use with gen_new_parameter to unregiste gen_new_parameter pass.
|
||||
|
||||
Args:
|
||||
pattern (NewParameter): NewParameter type, cancel the pass which would add new parameter as this pattern
|
||||
describes.
|
||||
|
||||
Examples:
|
||||
>>> from mindspore.graph_utils.graph_pattern import NewParameter
|
||||
>>> abc = NewParameter("abc")
|
||||
>>> gen_new_parameter(abs)
|
||||
>>> # some compilations
|
||||
>>> cancel_new_parameter(abc)
|
||||
"""
|
||||
if not isinstance(pattern, NewParameter):
|
||||
raise TypeError(f"Expect pattern to be a NewParameter Pattern, got {pattern}")
|
||||
ppm = PyPassManager()
|
||||
ppm.unregiste(pattern.para_name)
|
||||
|
||||
def set_renorm(should_renorm):
|
||||
"""
|
||||
Set whether or not to do renorm after modified graph in python pass(es).
|
||||
"""
|
||||
ppm = PyPassManager()
|
||||
ppm.set_renorm(should_renorm)
|
Loading…
Reference in new issue