Add NewParameter and Imm patterns

pull/4126/head
BowenK 5 years ago
parent 465390e580
commit e7c6b7e66a
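This commit adds two pattern types to the Python pass framework: NewParameter, which introduces a fresh Parameter node into a pass replacement target, and Imm, which wraps an integer immediate (for example the index input of tuple_getitem). A minimal usage sketch, distilled from the tests added at the end of this diff (pass name, parameter names, and tensor shapes are illustrative):

import numpy as np
import mindspore
from mindspore.common.tensor import Tensor
from mindspore.ops import operations as P
from mindspore.graph_utils.python_pass import registe_pass, unregiste_pass
from mindspore.graph_utils.graph_pattern import AnyPattern, CallWith, NewParameter, Imm

@registe_pass(run_only_once=True)
def softmax_to_tuple_getitem_pass():
    x = AnyPattern()
    pattern = CallWith(P.Softmax(), inputs=[x])
    # NewParameter: fresh Parameter nodes built when the target is constructed.
    w0 = NewParameter("w0", Tensor(np.ones((4, 4)), mindspore.float32))
    w1 = NewParameter("w1", Tensor(np.ones((4, 4)), mindspore.float32))
    matmul = CallWith(P.MatMul(), inputs=[w0, w1], should_replace=False)
    # Imm: an integer immediate, here the index input of tuple_getitem.
    wrapped = CallWith("make_tuple", inputs=[matmul], should_replace=False)
    target = CallWith("tuple_getitem", inputs=[wrapped, Imm(0)], should_replace=False)
    return pattern, target

# ... compile a network so the pass runs, then clean up ...
unregiste_pass(softmax_to_tuple_getitem_pass)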

@@ -14,6 +14,7 @@
* limitations under the License.
*/
#include "frontend/optimizer/pass_group.h"
#include "frontend/optimizer/py_pass_manager.h"
namespace mindspore {
namespace opt {
@@ -35,14 +36,15 @@ bool PassGroup::DeletePass(const std::string &pass_name) {
return false;
}
bool PassGroup::Run(const FuncGraphPtr &func_graph, const std::vector<PythonPassPtr> &passes) const {
bool PassGroup::Run(const FuncGraphPtr &func_graph, const std::vector<PythonPassPtr> &passes,
const MatchResultPtr &res) const {
if (func_graph == nullptr) {
return false;
}
bool changed = false;
for (const auto &pass : passes) {
if (pass != nullptr) {
if (pass->Run(func_graph)) {
if (pass->Run(func_graph, res)) {
changed = true;
}
}
@@ -54,8 +56,9 @@ bool PassGroup::Run(const FuncGraphPtr &func_graph) const {
bool changed = false;
// run all passes
bool change = true;
auto res = PyPassManager::GetInstance()->GetMatchResult();
while (change) {
change = Run(func_graph, passes_);
change = Run(func_graph, passes_, res);
changed = change || changed;
if (run_only_once_) {
break;

@@ -41,12 +41,14 @@ class PassGroup {
// @return false, graph not changed
bool Run(const FuncGraphPtr &func_graph) const;
// Run the given graph passes on the input graph
// @param [inout] graph The graph to be optimized
// @param [inout] func_graph The graph to be optimized
// @param [in] passes The given graph passes
// @param [inout] res MatchResult used to collect all matched patterns and nodes
// @return true, graph changed
// @return false, graph not changed
bool Run(const FuncGraphPtr &func_graph, const std::vector<PythonPassPtr> &passes) const;
bool Run(const FuncGraphPtr &func_graph, const std::vector<PythonPassPtr> &passes, const MatchResultPtr &res) const;
std::string name() const { return name_; }
void SetRunOnlyOnce(bool run_only_once) { run_only_once_ = run_only_once; }
private:
const std::string name_;

@@ -96,6 +96,7 @@ MatchResultPtr IsIn::match(const AnfNodePtr &node) {
for (auto &iter : patterns_) {
auto res = iter->match(node);
if (res != nullptr) {
res->add_entry(shared_from_base<IsIn>(), node);
return res;
}
}
@@ -151,6 +152,9 @@ REGISTER_PYBIND_DEFINE(
(void)py::class_<AnyPattern, std::shared_ptr<AnyPattern>, Pattern>(*m, "AnyPattern").def(py::init<>());
(void)py::class_<NewTensor, std::shared_ptr<NewTensor>, Pattern>(*m, "NewTensor_")
.def(py::init<tensor::TensorPtr>());
(void)py::class_<NewParameter, std::shared_ptr<NewParameter>, Pattern>(*m, "NewParameter_")
.def(py::init<string, tensor::TensorPtr, bool, bool, bool>());
(void)py::class_<Imm, std::shared_ptr<Imm>, Pattern>(*m, "Imm").def(py::init<int>());
}));
} // namespace python_pass
} // namespace opt

@@ -42,6 +42,10 @@ class CallWith;
using CallWithPtr = std::shared_ptr<CallWith>;
class NewTensor;
using NewTensorPtr = std::shared_ptr<NewTensor>;
class NewParameter;
using NewParameterPtr = std::shared_ptr<NewParameter>;
class Imm;
using ImmPtr = std::shared_ptr<Imm>;
struct PatternHasher;
struct PatternEqual;
using PatternNodeMap = std::unordered_map<PatternPtr, AnfNodePtr, PatternHasher, PatternEqual>;
@@ -55,6 +59,7 @@ class Pattern : public Base {
string unique_name() const { return unique_name_; }
vector<PatternPtr> inputs() { return inputs_; }
bool should_replace() { return should_replace_; }
void set_should_replace(bool should_replace) { should_replace_ = should_replace; }
virtual void reset() {}
protected:
@@ -86,14 +91,14 @@ class IsPrimTypeOf : public Pattern {
~IsPrimTypeOf() = default;
IsPrimTypeOf(vector<PrimitivePyPtr> prims, string name, bool should_replace)
: primitives_(prims), name_(name), matched_prim_(nullptr) {
unique_name_ = std::to_string(g_id_++) + "_" + name;
unique_name_ = std::to_string(g_id_++) + "IsPrimTypeOf_" + name;
should_replace_ = should_replace;
if (!should_replace) {
matched_prim_ = prims[0];
}
}
IsPrimTypeOf(vector<string> types, string name, bool should_replace) : types_(types), name_(name) {
unique_name_ = std::to_string(g_id_++) + "_" + name;
unique_name_ = std::to_string(g_id_++) + "IsPrimTypeOf_" + name;
// Make primitives_
for (auto &iter : types) {
primitives_.push_back(std::make_shared<PrimitivePy>(iter, py::cast(nullptr)));
@@ -126,19 +131,20 @@ class CallWith : public Pattern {
CallWith(PatternPtr prim_pattern, vector<PatternPtr> inputs, bool should_replace) {
// NOTE: should_replace is ignored in this case, since each sub-pattern has its own setting
prim_pattern_ = prim_pattern;
unique_name_ = std::to_string(g_id_++) + prim_pattern->unique_name();
unique_name_ = std::to_string(g_id_++) + "CallWithPattern_" + prim_pattern->unique_name();
inputs_ = inputs;
should_replace_ = should_replace;
// NOTE: should_replace_ is silently overridden by its prim_pattern (if one exists).
should_replace_ = prim_pattern->should_replace();
}
CallWith(PrimitivePyPtr prim, vector<PatternPtr> inputs, bool should_replace) {
prim_ = prim;
unique_name_ = std::to_string(g_id_++) + prim_->ToString();
unique_name_ = std::to_string(g_id_++) + "CallWithPrim_" + prim_->ToString();
inputs_ = inputs;
should_replace_ = should_replace;
}
CallWith(string prim_str, vector<PatternPtr> inputs, bool should_replace) {
prim_ = std::make_shared<PrimitivePy>(prim_str, py::cast(nullptr));
unique_name_ = std::to_string(g_id_++) + prim_->ToString();
unique_name_ = std::to_string(g_id_++) + "CallWithStr_" + prim_->ToString();
inputs_ = inputs;
should_replace_ = should_replace;
}
@@ -159,7 +165,7 @@ class IsIn : public Pattern {
IsIn() { unique_name_ = std::to_string(g_id_++); }
~IsIn() = default;
explicit IsIn(vector<PatternPtr> patterns) : patterns_(patterns) {
unique_name_ = std::to_string(g_id_++);
unique_name_ = std::to_string(g_id_++) + "IsIn";
for (auto &iter : patterns) {
unique_name_ = unique_name_ + "_" + iter->unique_name();
}
@@ -176,9 +182,9 @@ class IsNot : public Pattern {
IsNot() { unique_name_ = std::to_string(g_id_++); }
~IsNot() = default;
explicit IsNot(vector<PatternPtr> patterns) : patterns_(patterns) {
unique_name_ = std::to_string(g_id_++);
unique_name_ = std::to_string(g_id_++) + "IsNot";
for (auto &iter : patterns) {
unique_name_ = "IsNot_" + unique_name_ + "_" + iter->unique_name();
unique_name_ = unique_name_ + "_" + iter->unique_name();
}
}
MS_DECLARE_PARENT(IsNot, Pattern);
@@ -200,7 +206,10 @@ class NewTensor : public Pattern {
public:
NewTensor() { unique_name_ = std::to_string(g_id_++); }
~NewTensor() = default;
explicit NewTensor(tensor::TensorPtr input_tensor) : input_tensor_(input_tensor) { should_replace_ = false; }
explicit NewTensor(tensor::TensorPtr input_tensor) : input_tensor_(input_tensor) {
should_replace_ = false;
unique_name_ = std::to_string(g_id_++) + "NewTensor";
}
MS_DECLARE_PARENT(NewTensor, Pattern);
MatchResultPtr match(const AnfNodePtr &node) override {
MS_LOG(EXCEPTION) << "Find NewTensor in pattern, NewTensor should only appear in the target.\n";
@@ -211,6 +220,54 @@ class NewTensor : public Pattern {
tensor::TensorPtr input_tensor_;
};
class NewParameter : public Pattern {
public:
NewParameter() { unique_name_ = std::to_string(g_id_++); }
explicit NewParameter(string para_name, tensor::TensorPtr default_tensor, bool requires_grad, bool layerwise_parallel,
bool should_replace)
: para_name_(para_name), requires_grad_(requires_grad), layerwise_parallel_(layerwise_parallel) {
should_replace_ = should_replace;
unique_name_ = std::to_string(g_id_++) + "NewParameter_" + para_name;
// clone input tensor
default_tensor_ = std::make_shared<tensor::Tensor>(*default_tensor.get());
built_ = false;
}
MS_DECLARE_PARENT(NewParameter, Pattern);
MatchResultPtr match(const AnfNodePtr &node) override {
MS_LOG(EXCEPTION) << "Find NewParameter in pattern, NewParameter should only appear in the target.\n";
}
string para_name() { return para_name_; }
tensor::TensorPtr default_tensor() { return default_tensor_; }
bool requires_grad() { return requires_grad_; }
bool layerwise_parallel() { return layerwise_parallel_; }
bool built() { return built_; }
void set_built(bool built) { built_ = built; }
void reset() override { built_ = false; }
private:
string para_name_;
bool requires_grad_;
bool layerwise_parallel_;
bool built_;
tensor::TensorPtr default_tensor_;
};
class Imm : public Pattern {
public:
Imm() { unique_name_ = std::to_string(g_id_++); }
explicit Imm(int value) : value_(value) {
should_replace_ = false;
unique_name_ = std::to_string(g_id_++) + "Imm_" + std::to_string(value);
}
MS_DECLARE_PARENT(Imm, Pattern);
// NOTE: Doesn't support Imm in src pattern currently.
MatchResultPtr match(const AnfNodePtr &node) override { return nullptr; }
int value() { return value_; }
private:
int value_;
};
class MatchResult {
public:
MatchResult() {}

File diff suppressed because it is too large.

@@ -34,20 +34,20 @@ using NodeEquivPtr = std::shared_ptr<NodeEquiv>;
class PythonPass {
public:
explicit PythonPass(const std::string &name, const PatternPtr &src, const PatternPtr &dst, bool run_only_once = false,
bool multigraph = true)
: src_pattern_(src), dst_pattern_(dst), name_(name), run_only_once_(run_only_once), multigraph_(multigraph) {}
explicit PythonPass(const std::string &name, const PatternPtr &src, const PatternPtr &dst, bool run_only_once = false)
: src_pattern_(src), dst_pattern_(dst), name_(name), run_only_once_(run_only_once) {}
~PythonPass() = default;
bool Run(const FuncGraphPtr &func_graph);
bool Run(const FuncGraphPtr &func_graph, const MatchResultPtr &res);
std::string name() const { return name_; }
AnfNodePtr Run(const FuncGraphPtr &func_graph, const AnfNodePtr &node);
AnfNodePtr Run(const FuncGraphPtr &func_graph, const AnfNodePtr &node, const MatchResultPtr &res);
PatternPtr src_pattern() { return src_pattern_; }
PatternPtr dst_pattern() { return dst_pattern_; }
private:
PatternPtr src_pattern_;
PatternPtr dst_pattern_;
const std::string name_;
bool run_only_once_;
bool multigraph_ = true;
};
using PythonPassPtr = std::shared_ptr<PythonPass>;

@@ -45,14 +45,19 @@ PyPassManagerPtr PyPassManager::GetInstance() {
PyPassManager::PyPassManager() {
phase_to_group_[Phase::RESOLVE] = std::make_shared<PassGroup>();
phase_to_group_[Phase::OPT] = std::make_shared<PassGroup>();
res_ = std::make_shared<MatchResult>();
}
void PyPassManager::Registe(const std::string &pass_name, const PatternPtr &pattern, const PatternPtr &target,
Phase phase, bool run_only_once, bool multigraph) {
auto cur_pm = GetPassGroup(phase);
MS_EXCEPTION_IF_NULL(cur_pm);
PythonPassPtr new_pass = std::make_shared<PythonPass>(pass_name, pattern, target, run_only_once, multigraph);
cur_pm->AddPass(new_pass);
Phase phase, bool run_only_once) {
auto cur_pg = GetPassGroup(phase);
MS_EXCEPTION_IF_NULL(cur_pg);
cur_pg->SetRunOnlyOnce(run_only_once);
MS_EXCEPTION_IF_NULL(pattern);
MS_EXCEPTION_IF_NULL(target);
PythonPassPtr new_pass = std::make_shared<PythonPass>(pass_name, pattern, target, run_only_once);
cur_pg->AddPass(new_pass);
}
void PyPassManager::Unregiste(const std::string &pass_name, Phase phase) {
@@ -63,6 +68,21 @@ void PyPassManager::Unregiste(const std::string &pass_name, Phase phase) {
}
}
void PyPassManager::GenNewParameter(const PatternPtr &parameter) {
MS_EXCEPTION_IF_NULL(parameter);
// Add new parameter after resolve
// NOTE: Adding NewParameter at an early stage will cause CSE problems
auto cur_pg = GetPassGroup(Phase::OPT);
MS_EXCEPTION_IF_NULL(cur_pg);
cur_pg->SetRunOnlyOnce(true);
auto new_para_pattern = parameter->cast<NewParameterPtr>();
MS_EXCEPTION_IF_NULL(new_para_pattern);
auto pass_name = new_para_pattern->para_name();
parameter->set_should_replace(false);
auto new_pass = std::make_shared<PythonPass>(pass_name, nullptr, parameter, true);
cur_pg->AddPass(new_pass);
}
void PyPassManager::ClearRes() {
MS_LOG(INFO) << "Clear PyPassManager resources!";
global_instance = nullptr;
@@ -75,7 +95,9 @@ REGISTER_PYBIND_DEFINE(
(void)py::class_<PyPassManager, std::shared_ptr<PyPassManager>>(*m, "PyPassManager_")
.def(py::init([]() { return PyPassManager::GetInstance(); }))
.def("registe", &PyPassManager::Registe, "Registe python pass")
.def("unregiste", &PyPassManager::Unregiste, "Delete Python Pass");
.def("unregiste", &PyPassManager::Unregiste, "Delete Python Pass")
.def("gen_new_parameter", &PyPassManager::GenNewParameter, "Generate new parameter")
.def("set_renorm", &PyPassManager::SetRenorm, "Set whether or not to do renorm after modified graph");
}));
} // namespace python_pass
} // namespace opt

@@ -27,7 +27,7 @@
#include "ir/graph_utils.h"
#include "utils/ms_utils.h"
#include "pipeline/jit/parse/resolve.h"
#include "pipeline/jit/resource.h"
#include "frontend/optimizer/pattern.h"
#include "frontend/optimizer/py_pass.h"
#include "frontend/optimizer/pass_group.h"
@@ -53,12 +53,21 @@ class PyPassManager {
static PyPassManagerPtr GetInstance();
virtual ~PyPassManager() = default;
void Registe(const std::string &pass_name, const PatternPtr &pattern, const PatternPtr &target,
Phase phase = Phase::RESOLVE, bool run_only_once = false, bool multigraph = true);
Phase phase = Phase::RESOLVE, bool run_only_once = false);
void Unregiste(const std::string &pass_name, Phase phase);
void GenNewParameter(const PatternPtr &parameter);
PassGroupPtr GetPassGroup(Phase phase);
void ClearRes();
MatchResultPtr GetMatchResult() { return res_; }
void SetRenorm(bool should_renorm) { should_renorm_ = should_renorm; }
bool ShouldRenorm() { return should_renorm_; }
void SetResource(pipeline::ResourcePtr resource) { resource_ = resource; }
pipeline::ResourcePtr GetResource() { return resource_; }
private:
bool should_renorm_ = true;
MatchResultPtr res_;
pipeline::ResourcePtr resource_;
static std::unordered_map<Phase, PassGroupPtr> phase_to_group_;
};
} // namespace python_pass

@@ -448,8 +448,21 @@ void ActionPyStub(const ResourcePtr &res, opt::python_pass::Phase phase) {
MS_EXCEPTION_IF_NULL(res->manager());
MS_EXCEPTION_IF_NULL(res->func_graph());
auto ppm = opt::python_pass::PyPassManager::GetInstance();
ppm->SetResource(res);
if (!ppm->GetPassGroup(phase)->Run(res->func_graph())) {
MS_LOG(DEBUG) << "No match.\n";
} else if (phase == opt::python_pass::Phase::OPT && opt::python_pass::PyPassManager::GetInstance()->ShouldRenorm()) {
MS_LOG(DEBUG) << "Entered PyStub Renorm";
// Renormalize
MS_EXCEPTION_IF_NULL(res->func_graph());
FuncGraphPtr func_graph = res->func_graph();
abstract::AbstractBasePtrList args_spec;
auto parameters = func_graph->parameters();
(void)std::transform(parameters.begin(), parameters.end(), std::back_inserter(args_spec),
[](const AnfNodePtr &p) -> AbstractBasePtr { return p->abstract(); });
FuncGraphPtr new_fg = Renormalize(res, func_graph, args_spec);
res->set_func_graph(new_fg);
res->set_args_spec(args_spec);
}
}
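ActionPyStub above renormalizes the transformed graph (re-running abstract/shape inference) after OPT-phase python passes, unless renorm is switched off. A short sketch of the Python-side switch, as exercised by test_isnot_pattern_0 and test_newtensor_pattern later in this diff:

from mindspore.graph_utils.python_pass import set_renorm

set_renorm(False)  # skip Renormalize after OPT-phase python passes
# ... register passes and compile ...
set_renorm(True)   # restore the default behavior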
@@ -477,6 +490,7 @@ static std::vector<ActionItem> CommonPipeline() {
}
// Add resolve-stage python pass stub
actions.emplace_back(std::make_pair("py_resolve", ResolveActionPyStub));
actions.emplace_back(std::make_pair("inference_opt_prepare", InferenceOptPrepareAction));
// Evaluate type and shape, and specialize
actions.emplace_back(std::make_pair("abstract_specialize", AbstractSpecializeAction));

@@ -1,81 +0,0 @@
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""Python pass register"""
from inspect import isfunction
from mindspore.common.graph_pattern import Pattern
from mindspore._c_expression import PyPassManager_
from mindspore._c_expression import phase
class PyPassManager(PyPassManager_):
r"""
Used to registe and unregiste python passes which can be used to alter graphs.
Args:
pipeline_phase (phase): Specify the stage in which the pass will run in the pipeline. Default: phase.opt.
run_only_once (bool): Specify whether or not to run pass only once. Default: False.
multigraph (bool): Whether or not the pattern exists across graphs. Default: True.
Raises:
TypeError: If argument has invalid type.
"""
def __init__(self, pipeline_phase=phase.opt, run_only_once=False, multi_graph=True):
if not isinstance(pipeline_phase, phase):
raise TypeError(f"Expecting phase, got : ({type(pipeline_phase)}){pipeline_phase}")
if not isinstance(run_only_once, bool):
raise TypeError(f"Expecting bool, got : ({type(run_only_once)}){run_only_once}")
if not isinstance(multi_graph, bool):
raise TypeError(f"Expecting bool, got : ({type(multi_graph)}){multi_graph}")
PyPassManager_.__init__(self)
self.phase_ = pipeline_phase
self.run_only_once_ = run_only_once
self.multi_graph_ = multi_graph
def registe(self, py_pass):
if not isfunction(py_pass):
raise TypeError(f"Expecting function pass, got : ({type(py_pass)}){py_pass}")
pattern, target = py_pass()
pass_name = py_pass.__name__
if not isinstance(pattern, Pattern):
raise TypeError(f"Expecting pattern of Pattern type, got : ({type(pattern)}){pattern}")
if not isinstance(target, Pattern):
raise TypeError(f"Expecting target of Pattern type, got : ({type(target)}){target}")
super().registe(pass_name, pattern, target, self.phase_, self.run_only_once_, self.multi_graph_)
def unregiste(self, py_pass, pipeline_phase=phase.opt):
if not isinstance(pipeline_phase, phase):
raise TypeError(f"Expecting phase, got : ({type(pipeline_phase)}){pipeline_phase}")
if isinstance(py_pass, str):
super().unregiste(py_pass, pipeline_phase)
return
if isfunction(py_pass):
super().unregiste(py_pass.__name__, pipeline_phase)
return
raise TypeError(f"Expecting py_pass to be string or function, got ({type(py_pass)}){py_pass}")
def __call__(self, py_pass):
self.registe(py_pass)
return py_pass
def registe_pass(pipeline_phase=phase.opt, run_only_once=False, multi_graph=True):
"""
Examples:
>>> @registe_pass()
>>> def toy_pass():
>>> def pattern():
>>> pass
>>> def target():
>>> pass
"""
return PyPassManager(pipeline_phase, run_only_once, multi_graph)

@@ -0,0 +1,15 @@
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""Top-level reference to python pass."""

@@ -15,7 +15,8 @@
"""Patterns for describing graphs"""
from mindspore.ops import Primitive
from mindspore.common.tensor import Tensor
from mindspore._c_expression import Pattern, IsIn_, IsPrimTypeOf_, CallWith_, IsNot_, AnyPattern, NewTensor_
from mindspore._c_expression import Pattern, IsIn_, IsPrimTypeOf_, CallWith_, IsNot_, AnyPattern, NewTensor_,\
NewParameter_, Imm
__all__ = [
"IsIn",
@@ -24,17 +25,25 @@ __all__ = [
"IsNot",
"AnyPattern",
"NewTensor",
"NewParameter",
"Imm"
]
class IsIn(IsIn_):
"""
r"""
Express a pattern which allows a list of patterns.
"""
def __init__(self, patterns=None, should_replace=True):
r"""
Args:
patterns(list/tuple): list of allowed patterns
patterns(Union[tuple[:class:`mindspore.graph_utils.graph_pattern`],
list[:class:`mindspore.graph_utils.graph_pattern`]]): list of allowed patterns,
each element should be one of the exposed Pattern instances.
should_replace(bool): added for interface consistency; should only be set in sub-patterns.
Raises:
ValueError: if should_replace is False.
TypeError: if any input is of invalid type.
"""
if not should_replace:
raise ValueError("IsIn pattern does not have its own should_replace attribute. Set should_replace in \
@@ -52,19 +61,28 @@ class IsIn(IsIn_):
class IsPrimTypeOf(IsPrimTypeOf_):
r"""
Express a pattern of certain primitive type(s).
NOTE: This pattern will match and only match the primitive value node. If matching primitive CNode is needed,
please refer to CallWith pattern.
NOTE:
This pattern will match and only match the primitive value node. If matching primitive CNode is needed,
please refer to CallWith pattern.
"""
def __init__(self, types, name=None, should_replace=True):
r"""
Args:
types (str/(list/tuple of Primitives)): Specify allowed types.
types (Union[str, :class:`mindspore.ops.Primitive`, list[:class:`mindspore.ops.Primitive`],
tuple[:class:`mindspore.ops.Primitive`]]):
Specify allowed types.
If it is a string, the form could be
1) a single primitive type, e.g. 'Conv2D'
2) a set of primitive types separated by '|', e.g. 'MatMul|Conv2D'
It can also be a list of Primitives, e.g. [ops.Conv2D(1, 6)]
name (str): name of the pattern, optional
should_replace
It can also be a Primitive or a list/tuple of Primitives, e.g. [ops.Conv2D(1, 6)]
name (str): name of the pattern, optional. Default: None.
should_replace(bool): If pattern is part of the pass replacement target, this would set how this pattern is
used when building the replacement target node. Use captured node if True, build from scratch otherwise.
Default: True.
Raises:
TypeError: if an argument is of invalid type.
"""
if name is not None and not isinstance(name, str):
raise TypeError(f"Expect string, got : {name}")
@@ -91,12 +109,21 @@ class CallWith(CallWith_):
r"""
Express a primitive CNode.
"""
def __init__(self, prim_pattern, inputs=None, should_replace=False):
def __init__(self, prim_pattern, inputs=None, should_replace=True):
r"""
Args:
prim_pattern (Pattern/Primitive/str): Primitive ValueNode in the Primitive CNode.
inputs (list/tuple): Specify inputs pattern for the primitive(s), optional. If None, accepts any inputs;
if specified, input patterns should be of right order.
prim_pattern (Union[str, :class:`mindspore.graph_utils.graph_pattern.IsPrimTypeOf`,
:class:`mindspore.ops.Primitive`]): Primitive ValueNode in the Primitive CNode.
inputs (Union[list[:class:`mindspore.graph_utils.graph_pattern`],
tuple[:class:`mindspore.graph_utils.graph_pattern`]]):
Specify input patterns for the primitive(s), optional. If None, accepts any inputs; if specified, input
patterns should be in the right order and each element should be one of the exposed Pattern instances.
should_replace(bool): If pattern is part of the pass replacement target, this would set how this pattern is
used when building the replacement target node. Use captured node if True, build from scratch otherwise.
Default: True.
Raises:
TypeError: if an argument is of invalid type.
"""
if not isinstance(prim_pattern, (Pattern, str, Primitive)):
raise TypeError(f"Expect prim_pattern to be Pattern, Primitive or string, got : {prim_pattern}")
@@ -110,17 +137,23 @@ class CallWith(CallWith_):
raise TypeError(f"Expect inputs to be a list of Patterns, got : {inputs}")
CallWith_.__init__(self, self.prim_pattern, self.inputs, should_replace)
class IsNot(IsNot_):
r"""
Express a pattern which forbids a list of patterns.
NOTE: IsNot pattern should not be the root pattern.
NOTE:
IsNot pattern should not be the root pattern.
"""
def __init__(self, patterns=None, should_replace=True):
r"""
Args:
patterns(list/tuple): list of forbiden patterns
patterns(Union[list[:class:`mindspore.graph_utils.graph_pattern`]]): list of forbidden patterns, each element
should be one of the exposed Pattern instances.
should_replace(bool): added for interface consistency; should only be set in sub-patterns.
Raises:
ValueError: if should_replace is False.
TypeError: if an argument is of invalid type.
"""
if not should_replace:
raise ValueError("IsNot pattern does not have its own should_replace attribute. Set should_replace in \
@@ -142,13 +175,48 @@ class NewTensor(NewTensor_):
def __init__(self, input_tensor, should_replace=False):
r"""
Args:
input_tensor(Tensor): new tensor to be used in the target
input_tensor(:class:`mindspore.common.tensor.Tensor`): new tensor to be used in the target
should_replace(bool): added for interface consistency; NewTensor should only appear in the target.
Raises:
ValueError: if should_replace is True.
TypeError: if an argument is of invalid type.
"""
if should_replace:
raise ValueError("NewTensor should only appear in the target, thus should_replace can onlyu be False.")
raise ValueError("NewTensor should only appear in the target, thus should_replace can only be False.")
self.input_tensor = input_tensor
if isinstance(input_tensor, Tensor):
NewTensor_.__init__(self, input_tensor)
else:
raise TypeError(f"Expect input_tensor to be a Tensor got : {input_tensor}")
class NewParameter(NewParameter_):
r"""
New Parameter to be used in the target.
"""
def __init__(self, para_name, default_tensor, requires_grad=False, layerwise_parallel=False, should_replace=False):
r"""
Args:
para_name(str): name for the new Parameter
default_tensor(:class:`mindspore.common.tensor.Tensor`): default value for the new Parameter
requires_grad(bool): True if the parameter requires gradient. Default: False
layerwise_parallel(bool): switch for layerwise parallel mode. Default: False
should_replace(bool): generate the new parameter once and reuse it afterwards if set to True; otherwise build
a new parameter every time a pass target is built. Default: False
Raises:
TypeError: if an argument is of invalid type.
"""
self.para_name = para_name
self.default_tensor = default_tensor
self.requires_grad = requires_grad
self.layerwise_parallel = layerwise_parallel
self.should_replace = should_replace
if isinstance(para_name, str) and isinstance(default_tensor, Tensor) and isinstance(requires_grad, bool) and\
isinstance(layerwise_parallel, bool) and isinstance(should_replace, bool):
NewParameter_.__init__(self, self.para_name, self.default_tensor, self.requires_grad,
self.layerwise_parallel, self.should_replace)
else:
raise TypeError(f"Expect para_name(str), default_tensor(Tensor), requires_grad(bool), \
layerwise_parallel(bool) should_replace(bool) got : {para_name}, {default_tensor}, \
{requires_grad}, {layerwise_parallel}, {should_replace}")
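For orientation, a hedged sketch of the two should_replace modes documented above (tensor shape and parameter names are illustrative):

import numpy as np
import mindspore
from mindspore.common.tensor import Tensor
from mindspore.graph_utils.graph_pattern import NewParameter
from mindspore.graph_utils.python_pass import gen_new_parameter, cancel_new_parameter

t = Tensor(np.ones((4, 4)), mindspore.float32)

# should_replace=False (default): a new Parameter is built every time a pass
# target containing this pattern is constructed.
fresh = NewParameter("w_fresh", t)

# should_replace=True: generate the Parameter once per compilation and reuse
# it across passes; pair it with gen_new_parameter/cancel_new_parameter.
shared = NewParameter("w_shared", t, should_replace=True)
gen_new_parameter(shared)
# ... compilations ...
cancel_new_parameter(shared)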

@@ -0,0 +1,24 @@
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""Top-level reference to python pass."""
from .python_pass_register import registe_pass, unregiste_pass, gen_new_parameter, cancel_new_parameter, set_renorm
__all__ = [
"registe_pass",
"unregiste_pass",
"gen_new_parameter",
"cancel_new_parameter",
"set_renorm"
]

@@ -0,0 +1,170 @@
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""Python pass register"""
from inspect import isfunction
from mindspore.graph_utils.graph_pattern import Pattern, NewParameter
from mindspore._c_expression import PyPassManager_, phase
__all__ = [
"registe_pass",
"unregiste_pass",
"gen_new_parameter",
"cancel_new_parameter",
"set_renorm"
]
class PyPassManager(PyPassManager_):
r"""
Used to register and unregister python passes which can be used to alter graphs.
Args:
pipeline_phase (phase): Specify the stage in which the pass will run in the pipeline. Default: phase.opt.
run_only_once (bool): Specify whether or not to run pass only once. Default: False.
Raises:
TypeError: If argument has invalid type.
"""
def __init__(self, pipeline_phase=phase.opt, run_only_once=False):
if not isinstance(pipeline_phase, phase):
raise TypeError(f"Expect phase, got : ({type(pipeline_phase)}){pipeline_phase}")
if not isinstance(run_only_once, bool):
raise TypeError(f"Expect bool, got : ({type(run_only_once)}){run_only_once}")
PyPassManager_.__init__(self)
self.phase_ = pipeline_phase
self.run_only_once_ = run_only_once
def registe(self, py_pass):
if not isfunction(py_pass):
raise TypeError(f"Expect function pass, got : ({type(py_pass)}){py_pass}")
pattern, target = py_pass()
pass_name = py_pass.__name__
if not isinstance(pattern, Pattern):
raise TypeError(f"Expect pattern of Pattern type, got : ({type(pattern)}){pattern}")
if not isinstance(target, Pattern):
raise TypeError(f"Expect target of Pattern type, got : ({type(target)}){target}")
super().registe(pass_name, pattern, target, self.phase_, self.run_only_once_)
def unregiste(self, py_pass, pipeline_phase=phase.opt):
if not isinstance(pipeline_phase, phase):
raise TypeError(f"Expect phase, got : ({type(pipeline_phase)}){pipeline_phase}")
if isinstance(py_pass, str):
super().unregiste(py_pass, pipeline_phase)
return
if isfunction(py_pass):
super().unregiste(py_pass.__name__, pipeline_phase)
return
raise TypeError(f"Expect py_pass to be string or function, got ({type(py_pass)}){py_pass}")
def __call__(self, py_pass):
self.registe(py_pass)
return py_pass
def gen_new_parameter(self, pattern):
if not isinstance(pattern, NewParameter):
raise TypeError(f"Expect pattern to be a NewParameter Pattern, got {pattern}")
super().gen_new_parameter(pattern)
def set_renorm(self, should_renorm):
if not isinstance(should_renorm, bool):
raise TypeError(f"Expect should_renorm to be a bool, got {should_renorm}")
super().set_renorm(should_renorm)
def registe_pass(pipeline_phase=phase.opt, run_only_once=False):
"""
Register a python pass to the specified pipeline phase, to be used during compilation.
Args:
pipeline_phase(:class:`mindspore._c_expression.phase`): The compilation pipeline stage to which the pass is
registered. Supports phase.resolve and phase.opt. Default: phase.opt.
run_only_once(bool): Run this pass only once if set to True; otherwise run the pass until the graph converges.
Default: False.
Returns:
This function should be used as a decorator; it returns the decorated pass function.
Examples:
>>> from mindspore.graph_utils.graph_pattern import IsPrimTypeOf
>>> @registe_pass()
>>> def toy_pass():
>>> pattern = IsPrimTypeOf("ReLU")
>>> target = IsPrimTypeOf("ReLU6")
>>> return pattern, target
"""
return PyPassManager(pipeline_phase, run_only_once)
def unregiste_pass(py_pass, pipeline_phase=phase.opt):
"""
Unregister a python pass.
Args:
py_pass(Union[str, function]): target python pass to unregister.
pipeline_phase(:class:`mindspore._c_expression.phase`): The compilation pipeline stage from which the pass is
unregistered. Supports phase.resolve and phase.opt. Default: phase.opt.
"""
ppm = PyPassManager()
ppm.unregiste(py_pass, pipeline_phase)
def gen_new_parameter(pattern):
"""
Generate the specified parameter every time a network is compiled.
NOTE:
In this way, every pass that uses this pattern would be using the same Parameter. If NewParameter is used
without gen_new_parameter, every pass match would build a new Parameter.
This registers a pass which adds the new parameter in the compilation pipeline, so later compilations would
ALSO add this parameter unless the pass is unregistered. To unregister this pass, call
cancel_new_parameter(pattern).
Args:
pattern (NewParameter): NewParameter type; can be used to build nested patterns across multiple passes
after gen_new_parameter.
Raises:
TypeError: If argument has invalid type.
Examples:
>>> from mindspore.graph_utils.graph_pattern import NewParameter
>>> from mindspore.common.tensor import Tensor
>>> import numpy as np, mindspore
>>> default_tensor = Tensor(np.ones((4, 4)), mindspore.float32)
>>> abc = NewParameter("abc", default_tensor)
>>> gen_new_parameter(abc)
"""
ppm = PyPassManager()
ppm.gen_new_parameter(pattern)
def cancel_new_parameter(pattern):
"""
Use with gen_new_parameter to unregister the gen_new_parameter pass.
Args:
pattern (NewParameter): NewParameter type; cancels the pass which would add a new parameter as this pattern
describes.
Examples:
>>> from mindspore.graph_utils.graph_pattern import NewParameter
>>> from mindspore.common.tensor import Tensor
>>> import numpy as np, mindspore
>>> default_tensor = Tensor(np.ones((4, 4)), mindspore.float32)
>>> abc = NewParameter("abc", default_tensor)
>>> gen_new_parameter(abc)
>>> # some compilations
>>> cancel_new_parameter(abc)
"""
if not isinstance(pattern, NewParameter):
raise TypeError(f"Expect pattern to be a NewParameter Pattern, got {pattern}")
ppm = PyPassManager()
ppm.unregiste(pattern.para_name)
def set_renorm(should_renorm):
"""
Set whether to renormalize the graph after it is modified by python pass(es).
"""
ppm = PyPassManager()
ppm.set_renorm(should_renorm)
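A hedged end-to-end sketch of the decorator plumbing in this module (the pattern choice mirrors the registe_pass docstring example above):

from mindspore.graph_utils.python_pass import registe_pass, unregiste_pass
from mindspore.graph_utils.graph_pattern import IsPrimTypeOf

@registe_pass(run_only_once=True)  # PyPassManager.__call__ registers the pass
def relu_to_relu6():
    return IsPrimTypeOf("ReLU"), IsPrimTypeOf("ReLU6")

# A registered pass can later be removed either by function or by name:
unregiste_pass(relu_to_relu6)  # or: unregiste_pass("relu_to_relu6")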

@@ -152,7 +152,7 @@ class Primitive(Primitive_):
Check if certain inputs should go to the backend. Subclass in need should override this method.
Args:
*args(Primitive args): Same as arguments of current Primitive.
args(Primitive args): Same as arguments of current Primitive.
Returns:
A tuple consisting of two elements. The first element indicates whether we should filter out current

@@ -19,10 +19,12 @@ import mindspore.nn as nn
from mindspore import context
from mindspore.common.tensor import Tensor
from mindspore.ops import operations as P
from mindspore.common.python_pass_register import registe_pass, PyPassManager
from mindspore.graph_utils.python_pass import registe_pass, unregiste_pass, set_renorm, gen_new_parameter,\
cancel_new_parameter
from mindspore.common.api import _generate_pip_args
from mindspore._c_expression import generate_key, Executor_
from mindspore.common.graph_pattern import IsIn, IsPrimTypeOf, CallWith, IsNot, AnyPattern, NewTensor
from mindspore.graph_utils.graph_pattern import IsIn, IsPrimTypeOf, CallWith, IsNot, AnyPattern, NewTensor,\
NewParameter, Imm
context.set_context(mode=context.GRAPH_MODE)
@@ -56,12 +58,39 @@ def test_softmax_relu():
return pattern, target
transformed_repr = get_func_graph(softmax_model, inputs).get_return().expanded_str(2)
ppm = PyPassManager()
ppm.unregiste(softmax_relu_pass)
unregiste_pass(softmax_relu_pass)
assert "ReLU" in transformed_repr
assert "Softmax" not in transformed_repr
def test_isin_pattern():
def test_softmax_relu_sigmoid():
"""
Use python pass to transform from Softmax(x) to ReLU(Sigmoid(x)).
NOTE:
Sigmoid pattern only exists in the target.
"""
inputs = Tensor(np.ones([42]), mindspore.float16)
softmax_model = nn.Softmax()
@registe_pass(run_only_once=True)
def softmax_relu_pass():
x = AnyPattern()
softmax_pattern = IsPrimTypeOf(P.Softmax())
pattern = CallWith(softmax_pattern, inputs=[x])
sigmoid_pattern = IsPrimTypeOf(P.Sigmoid(), should_replace=False)
call_sigmoid = CallWith(sigmoid_pattern, [x])
relu_pattern = IsPrimTypeOf(P.ReLU(), should_replace=False)
target = CallWith(relu_pattern, inputs=[call_sigmoid])
return pattern, target
transformed_repr = get_func_graph(softmax_model, inputs).get_return().expanded_str(3)
unregiste_pass(softmax_relu_pass)
assert "ReLU" in transformed_repr
assert "Sigmoid" in transformed_repr
assert "Softmax" not in transformed_repr
def test_isin_pattern_0():
"""
Test IsIn pattern which expresses the IsIn/OneOf semantics.
"""
@@ -81,16 +110,41 @@ def test_isin_pattern():
target = CallWith(relu6_pattern, inputs=[x])
return pattern, target
transformed_repr = get_func_graph(softmax_model, inputs).get_return().expanded_str(2)
ppm = PyPassManager()
ppm.unregiste(softmax_relu_pass)
unregiste_pass(softmax_relu_pass)
assert "ReLU6" in transformed_repr
assert "Softmax" not in transformed_repr
def test_isin_pattern_1():
"""
Test IsIn. IsIn is used as nested inputs for the target in this case.
"""
inputs = Tensor(np.ones([42]), mindspore.float16)
softmax_model = nn.Softmax()
@registe_pass(run_only_once=True)
def softmax_neg_pass():
x = AnyPattern()
softmax_pattern = IsPrimTypeOf(P.Softmax())
call_softmax = CallWith(softmax_pattern, inputs=[x])
relu_pattern = IsPrimTypeOf(P.ReLU())
call_relu = CallWith(relu_pattern, inputs=[x])
pattern = IsIn([call_softmax, call_relu])
neg_ops = IsPrimTypeOf(P.Neg(), should_replace=False)
target = CallWith(neg_ops, inputs=[pattern])
return pattern, target
transformed_repr = get_func_graph(softmax_model, inputs).get_return().expanded_str(4)
print(transformed_repr)
unregiste_pass(softmax_neg_pass)
assert "Neg" in transformed_repr
assert "Softmax" in transformed_repr
def test_isnot_pattern_0():
"""
Test IsNot pattern which expresses the IsNot semantics.
Case: IsNot pass failed to match
"""
set_renorm(False)
class ConvBN(nn.Cell):
def __init__(self):
super(ConvBN, self).__init__()
@@ -132,11 +186,11 @@ def test_isnot_pattern_0():
return pattern, target
transformed_repr = get_func_graph(conv_bn_model, inputs).get_return().expanded_str(5)
ppm = PyPassManager()
ppm.unregiste(single_bn_pass)
ppm.unregiste(bn_pass)
unregiste_pass(single_bn_pass)
unregiste_pass(bn_pass)
assert "ReLU6" not in transformed_repr
assert "Softmax" in transformed_repr
set_renorm(True)
def test_isnot_pattern_1():
"""
@@ -160,12 +214,15 @@ def test_isnot_pattern_1():
return pattern, target
transformed_repr = get_func_graph(softmax_model, inputs).get_return().expanded_str(5)
ppm = PyPassManager()
ppm.unregiste(single_bn_pass)
unregiste_pass(single_bn_pass)
assert "ReLU6" in transformed_repr
assert "Softmax" not in transformed_repr
def test_newtensor_pattern():
"""
Test NewTensor pattern in the target
"""
set_renorm(False)
inputs = Tensor(np.ones([42]), mindspore.float16)
softmax_model = nn.Softmax()
@@ -181,7 +238,84 @@ def test_newtensor_pattern():
target = CallWith(addn_ops, inputs=[x, new_weight], should_replace=False)
return pattern, target
transformed_repr = get_func_graph(softmax_model, inputs).get_return().expanded_str(2)
ppm = PyPassManager()
ppm.unregiste(softmax_addn_pass)
unregiste_pass(softmax_addn_pass)
assert "AddN" in transformed_repr
assert "Softmax" not in transformed_repr
set_renorm(True)
def test_newparameter_pattern():
"""
Test NewParameter pattern in the target
"""
inputs = Tensor(np.ones([42]), mindspore.float16)
softmax_model = nn.Softmax()
@registe_pass(run_only_once=True)
def softmax_addn_pass():
x = AnyPattern()
softmax = P.Softmax()
pattern = CallWith(softmax, inputs=[x])
default_tensor0 = Tensor(np.ones((4, 4)), mindspore.float32)
default_tensor1 = Tensor(np.ones((4, 4)), mindspore.float32)
new_para_0 = NewParameter("Merlin", default_tensor0)
new_para_1 = NewParameter("Arthur", default_tensor1)
target_0 = CallWith(P.MatMul(), inputs=[new_para_0, new_para_1], should_replace=False)
target = CallWith("make_tuple", inputs=[target_0], should_replace=False)
return pattern, target
transformed_repr = get_func_graph(softmax_model, inputs).get_return().expanded_str(5)
print(transformed_repr)
unregiste_pass(softmax_addn_pass)
assert "MatMul" in transformed_repr
assert "make_tuple" in transformed_repr
assert "Softmax" not in transformed_repr
def test_imm_pattern():
"""
Test Imm pattern in the target
"""
inputs = Tensor(np.ones([42]), mindspore.float16)
softmax_model = nn.Softmax()
@registe_pass(run_only_once=True)
def softmax_addn_pass():
x = AnyPattern()
softmax = P.Softmax()
pattern = CallWith(softmax, inputs=[x])
imm = Imm(0)
target_0 = CallWith("make_tuple", inputs=[pattern], should_replace=False)
target = CallWith("tuple_getitem", inputs=[target_0, imm], should_replace=False)
return pattern, target
transformed_repr = get_func_graph(softmax_model, inputs).get_return().expanded_str(5)
print(transformed_repr)
unregiste_pass(softmax_addn_pass)
assert "make_tuple" in transformed_repr
assert "tuple_getitem" in transformed_repr
assert "Softmax" in transformed_repr
def test_gen_new_parameter():
"""
Test gen_new_parameter
"""
inputs = Tensor(np.ones([42]), mindspore.float16)
softmax_model = nn.Softmax()
default_tensor = Tensor(np.ones((4, 4)), mindspore.float32)
new_para = NewParameter("Merlin", default_tensor, should_replace=True)
gen_new_parameter(new_para)
@registe_pass(run_only_once=True)
def softmax_make_tuple_pass():
x = AnyPattern()
softmax = P.Softmax()
pattern = CallWith(softmax, inputs=[x])
target = CallWith("make_tuple", inputs=[pattern, new_para], should_replace=False)
return pattern, target
transformed_repr = get_func_graph(softmax_model, inputs).get_return().expanded_str(5)
print(transformed_repr)
assert "Merlin" in transformed_repr
unregiste_pass(softmax_make_tuple_pass)
cancel_new_parameter(new_para)
transformed_repr = get_func_graph(softmax_model, inputs).get_return().expanded_str(5)
print(transformed_repr)
assert "Merlin" not in transformed_repr
