diff --git a/mindspore/ccsrc/optimizer/irpass/arithmetic_simplify.h b/mindspore/ccsrc/optimizer/irpass/arithmetic_simplify.h index 1836a88dbc..24dea8f1f0 100644 --- a/mindspore/ccsrc/optimizer/irpass/arithmetic_simplify.h +++ b/mindspore/ccsrc/optimizer/irpass/arithmetic_simplify.h @@ -346,10 +346,6 @@ class TensorAddByZero : public AnfVisitor { } void Visit(const AnfNodePtr &node) override { - if (IsPrimitive(node, prim::kPrimZerosLike)) { - is_zero_ = true; - return; - } if (node->isa() && CheckTensorConstant(0).IsTensorScalarConstant(GetValueNode(node))) { is_zero_ = true; return; diff --git a/mindspore/ccsrc/optimizer/pass_group.cc b/mindspore/ccsrc/optimizer/pass_group.cc new file mode 100644 index 0000000000..2d1ab07f7d --- /dev/null +++ b/mindspore/ccsrc/optimizer/pass_group.cc @@ -0,0 +1,69 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#include "optimizer/pass_group.h" + +namespace mindspore { +namespace opt { +namespace python_pass { +void PassGroup::AddPass(const PythonPassPtr &pass) { + if (pass != nullptr) { + passes_.push_back(pass); + } +} + +bool PassGroup::DeletePass(const std::string &pass_name) { + for (auto iter = passes_.begin(); iter != passes_.end(); iter++) { + if ((*iter)->name() == pass_name) { + *iter = nullptr; + passes_.erase(iter); + return true; + } + } + return false; +} + +bool PassGroup::Run(const FuncGraphPtr &func_graph, const std::vector &passes) const { + if (func_graph == nullptr) { + return false; + } + bool changed = false; + for (const auto &pass : passes) { + if (pass != nullptr) { + if (pass->Run(func_graph)) { + changed = true; + } + } + } + return changed; +} + +bool PassGroup::Run(const FuncGraphPtr &func_graph) const { + bool changed = false; + // run all passes + bool change = true; + while (change) { + change = Run(func_graph, passes_); + changed = change || changed; + if (run_only_once_) { + break; + } + } + return changed; +} + +} // namespace python_pass +} // namespace opt +} // namespace mindspore diff --git a/mindspore/ccsrc/optimizer/pass_group.h b/mindspore/ccsrc/optimizer/pass_group.h new file mode 100644 index 0000000000..895f5a4128 --- /dev/null +++ b/mindspore/ccsrc/optimizer/pass_group.h @@ -0,0 +1,61 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#ifndef MINDSPORE_CCSRC_OPTIMIZER_PASS_GROUP_H_ +#define MINDSPORE_CCSRC_OPTIMIZER_PASS_GROUP_H_ + +#include +#include +#include +#include + +#include "optimizer/py_pass.h" + +namespace mindspore { +namespace opt { +namespace python_pass { +class PassGroup { + public: + explicit PassGroup(const std::string &name = "pass_group", bool run_only_once = false) + : name_(name), passes_{}, run_only_once_(run_only_once) {} + virtual ~PassGroup() = default; + // Add graph pass, the pass object will be freed when pass manager freed. + void AddPass(const PythonPassPtr &pass); + // Delete graph pass before the pass manager is freed. + bool DeletePass(const std::string &pass_name); + // Run passes added in pass manager on the input graph + // @param [inout] graph The graph to be optimized + // @return true, graph changed + // @return false, graph not changed + bool Run(const FuncGraphPtr &func_graph) const; + // Run the given graph passes on the input graph + // @param [inout] graph The graph to be optimized + // @param [in] passes The given graph passes + // @return true, graph changed + // @return false, graph not changed + bool Run(const FuncGraphPtr &func_graph, const std::vector &passes) const; + std::string name() const { return name_; } + + private: + const std::string name_; + std::vector passes_; + bool run_only_once_; +}; +using PassGroupPtr = std::shared_ptr; +} // namespace python_pass +} // namespace opt +} // namespace mindspore + +#endif // MINDSPORE_CCSRC_OPTIMIZER_PASS_GROUP_H_ diff --git a/mindspore/ccsrc/optimizer/py_pass.cc b/mindspore/ccsrc/optimizer/py_pass.cc new file mode 100644 index 0000000000..8ce348b22e --- /dev/null +++ b/mindspore/ccsrc/optimizer/py_pass.cc @@ -0,0 +1,236 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#include "optimizer/py_pass.h" +#include +#include +#include +#include +#include + +#include "ir/func_graph.h" +#include "ir/manager.h" +#include "pipeline/parse/parse_base.h" +#include "pipeline/resource.h" + +namespace mindspore { +namespace opt { +namespace python_pass { +namespace internal { +std::string GetNodeRepr(AnfNodePtr node) { + if (node != nullptr) { + if (node->isa()) { + std::string repr = "("; + auto const &inputs = node->cast()->inputs(); + for (auto &input : inputs) { + repr += " "; + repr += GetNodeRepr(input); + repr += " "; + } + repr += ")"; + return repr; + } + if (node->isa()) { + return GetValueNode(node)->ToString(); + } + return node->ToString(); + } + return ""; +} + +void ResolveFuncGraph_(const FuncGraphPtr &fg) { + auto manager = Manage(fg, false); + parse::python_adapter::set_use_signature_in_resolve(false); + parse::ResolveAll(manager); +} + +bool Match(const AnfNodePtr &pattern, const AnfNodePtr &node, const NodeEquivPtr &equiv_ptr) { + if (node == nullptr) { + return false; + } + MS_EXCEPTION_IF_NULL(pattern); + if (pattern->isa()) { + if (!node->isa()) { + return false; + } + if (GetNodeRepr(pattern) == GetNodeRepr(node)) { + // add to equiv_ptr + equiv_ptr->insert(std::make_pair(GetValueNode(pattern)->ToString(), node)); + return true; + } + return false; + } else if (pattern->isa()) { + MS_LOG(DEBUG) << pattern->ToString() + "\n"; + // add to equiv_ptr + equiv_ptr->insert(std::make_pair(pattern->ToString(), node)); + return true; + } else if (pattern->isa()) { + // match every single sub ANode + if (!node->isa()) { + return false; + } + auto pattern_inputs = pattern->cast()->inputs(); + auto node_inputs = node->cast()->inputs(); + if (pattern_inputs.size() != node_inputs.size()) { + return false; + } + for (auto p_item = pattern_inputs.begin(), node_item = node_inputs.begin(); p_item != pattern_inputs.end(); + p_item++, node_item++) { + auto res = Match(*p_item, *node_item, equiv_ptr); + if (!res) { + return false; + } + } + return true; + } + MS_LOG(EXCEPTION) << "Unexpected condition, (" + pattern->ToString() + " , " + node->ToString() + ")\n"; +} + +AnfNodePtr BuildTarget(const FuncGraphPtr &func_graph, const AnfNodePtr cur_raw_dst_node_, + const NodeEquivPtr &equiv_ptr) { + if (cur_raw_dst_node_->isa()) { + auto sub_pair = equiv_ptr->find(cur_raw_dst_node_->ToString()); + if (sub_pair != equiv_ptr->end()) { + return sub_pair->second; + } + MS_LOG(EXCEPTION) << "cur_raw_dst_node_ : " + internal::GetNodeRepr(cur_raw_dst_node_) + "\n"; + } else if (cur_raw_dst_node_->isa()) { + // check primitive ValueNode + auto sub_pair = equiv_ptr->find(cur_raw_dst_node_->cast()->value()->ToString()); + if (sub_pair != equiv_ptr->end()) { + return sub_pair->second; + } + return cur_raw_dst_node_; + } else if (cur_raw_dst_node_->isa()) { + std::vector new_inputs; + auto inputs = cur_raw_dst_node_->cast()->inputs(); + for (auto sub_node = inputs.begin(); sub_node != inputs.end(); sub_node++) { + auto subed = internal::BuildTarget(func_graph, *sub_node, equiv_ptr); + new_inputs.push_back(subed); + } + return func_graph->NewCNode(new_inputs); + } + MS_LOG(EXCEPTION) << "Unexpected node type, got : " + internal::GetNodeRepr(cur_raw_dst_node_); +} + +bool isTraversable(const AnfNodePtr &node) { + if (node == nullptr) { + return false; + } + if (node->isa() || node->isa()) { + return true; + } + if (IsValueNode(node) || IsValueNode(node)) { + return true; + } + return false; +} +} // namespace internal + +void PythonPass::Build(const py::function &src, const py::function &dst) { + // 1. get FuncGraph from py::function + auto src_fg_ = parse::ParsePythonCode(src); + auto dst_fg_ = parse::ParsePythonCode(dst); + if (src_fg_ == nullptr || dst_fg_ == nullptr) { + MS_LOG(EXCEPTION) << "Failed to parse python code.\n"; + } + // 2. Resolve + internal::ResolveFuncGraph_(src_fg_); + internal::ResolveFuncGraph_(dst_fg_); + // 3. from FuncGraphPtr to ValueNode + src_node_ = src_fg_->output(); + dst_node_ = dst_fg_->output(); +} + +PythonPass::PythonPass(const std::string &name, const py::function &src, const py::function &dst, bool run_only_once, + bool multigraph) + : name_(name), run_only_once_(run_only_once), multigraph_(multigraph) { + Build(src, dst); +} + +AnfNodePtr PythonPass::Run(const FuncGraphPtr &func_graph, const AnfNodePtr &node) { + auto equiv_ptr = std::make_shared(); + bool is_a_match = internal::Match(src_node_, node, equiv_ptr); + if (is_a_match) { + auto new_node = internal::BuildTarget(func_graph, dst_node_, equiv_ptr); + MS_LOG(DEBUG) << "To be replaced node: " + internal::GetNodeRepr(new_node) + "\n"; + return new_node; + } + return nullptr; +} + +bool PythonPass::Run(const FuncGraphPtr &func_graph) { + MS_EXCEPTION_IF_NULL(func_graph); + FuncGraphManagerPtr manager = func_graph->manager(); + MS_EXCEPTION_IF_NULL(manager); + manager->AddFuncGraph(func_graph); + auto seen = NewSeenGeneration(); + // 1024 is for the initial capacity of deque + std::deque todo(1024); + todo.push_back(func_graph->output()); + bool changes = false; + + auto &all_nodes = manager->all_nodes(); + while (!todo.empty()) { + AnfNodePtr node = todo.front(); + todo.pop_front(); + + // check whether this node has been matched. + if (node == nullptr || node->seen_ == seen || !internal::isTraversable(node) || !all_nodes.contains(node)) { + continue; + } + node->seen_ = seen; + + // select nodes that this transform can be applied. + AnfNodePtr new_node = Run(func_graph, node); + bool change = (new_node != nullptr); + if (new_node != nullptr && new_node != node) { + (void)manager->Replace(node, new_node); + } else if (new_node == nullptr) { + new_node = node; + } + if (run_only_once_) { + return change; + } + + // find success, and add them to todo list + if (IsValueNode(node)) { + todo.push_back(GetValueNode(node)->output()); + } + + if (node->isa()) { + auto &inputs = node->cast()->inputs(); + (void)std::copy(inputs.begin(), inputs.end(), std::back_inserter(todo)); + } + + auto &node_users = manager->node_users(); + if (change && node_users.find(node) != node_users.end()) { + for (auto &use : node_users[node]) { + auto use_node = use.first; + if (use_node == nullptr) { + continue; + } + todo.push_back(use_node); + if (use_node->seen_ == seen) { + use_node->seen_--; + } + } + } + } + return changes; +} +} // namespace python_pass +} // namespace opt +} // namespace mindspore diff --git a/mindspore/ccsrc/optimizer/py_pass.h b/mindspore/ccsrc/optimizer/py_pass.h new file mode 100644 index 0000000000..b01bf7c942 --- /dev/null +++ b/mindspore/ccsrc/optimizer/py_pass.h @@ -0,0 +1,56 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#ifndef MINDSPORE_CCSRC_OPTIMIZER_PASS_H_ +#define MINDSPORE_CCSRC_OPTIMIZER_PASS_H_ +#include +#include +#include + +#include "ir/anf.h" +#include "pybind_api/api_register.h" +#include "pybind_api/export_flags.h" + +namespace mindspore { +namespace opt { +namespace python_pass { +class PythonPass; +using PythonPassPtr = std::shared_ptr; +using NodeEquiv = std::unordered_map; +using NodeEquivPtr = std::shared_ptr; + +class PythonPass { + public: + explicit PythonPass(const std::string &name, const py::function &src, const py::function &dst, + bool run_only_once = false, bool multigraph = true); + ~PythonPass() = default; + bool Run(const FuncGraphPtr &func_graph); + std::string name() const { return name_; } + AnfNodePtr Run(const FuncGraphPtr &func_graph, const AnfNodePtr &node); + + private: + void Build(const py::function &src, const py::function &dst); + AnfNodePtr src_node_ = nullptr; + AnfNodePtr dst_node_ = nullptr; + const std::string name_; + bool run_only_once_; + bool multigraph_ = true; +}; + +using PythonPassPtr = std::shared_ptr; +} // namespace python_pass +} // namespace opt +} // namespace mindspore +#endif // MINDSPORE_CCSRC_OPTIMIZER_PASS_H_ diff --git a/mindspore/ccsrc/optimizer/py_pass_manager.cc b/mindspore/ccsrc/optimizer/py_pass_manager.cc new file mode 100644 index 0000000000..1c36e93c9a --- /dev/null +++ b/mindspore/ccsrc/optimizer/py_pass_manager.cc @@ -0,0 +1,84 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#include "optimizer/py_pass_manager.h" + +#include +#include +#include +#include + +#include "ir/manager.h" +#include "optimizer/pass_group.h" + +namespace mindspore { +namespace opt { +namespace python_pass { +PyPassManagerPtr PyPassManager::global_instance = nullptr; +std::unordered_map PyPassManager::phase_to_group_; + +PassGroupPtr PyPassManager::GetPassGroup(Phase phase) { + auto pm = phase_to_group_.find(phase); + if (pm == phase_to_group_.end()) { + return nullptr; + } + return pm->second; +} + +PyPassManagerPtr PyPassManager::GetInstance() { + if (global_instance == nullptr) { + global_instance = std::shared_ptr(new (std::nothrow) PyPassManager()); + } + return global_instance; +} + +PyPassManager::PyPassManager() { + phase_to_group_[Phase::RESOLVE] = std::make_shared(); + phase_to_group_[Phase::OPT] = std::make_shared(); +} + +void PyPassManager::Registe(const std::string &pass_name, const py::function &pattern, const py::function &target, + Phase phase, bool run_only_once, bool multigraph) { + auto cur_pm = GetPassGroup(phase); + MS_EXCEPTION_IF_NULL(cur_pm); + PythonPassPtr new_pass = std::make_shared(pass_name, pattern, target, run_only_once, multigraph); + cur_pm->AddPass(new_pass); +} + +void PyPassManager::Unregiste(const std::string &pass_name, Phase phase) { + auto cur_pm = GetPassGroup(phase); + MS_EXCEPTION_IF_NULL(cur_pm); + if (!cur_pm->DeletePass(pass_name)) { + MS_LOG(WARNING) << "No such pass : " + pass_name + "\n"; + } +} + +void PyPassManager::ClearRes() { + MS_LOG(INFO) << "Clear PyPassManager resources!"; + global_instance = nullptr; + phase_to_group_.clear(); +} + +REGISTER_PYBIND_DEFINE( + PyPassManager_, ([](const py::module *m) { + (void)py::enum_(*m, "phase", py::arithmetic()).value("resolve", Phase::RESOLVE).value("opt", Phase::OPT); + (void)py::class_>(*m, "PyPassManager_") + .def(py::init([]() { return PyPassManager::GetInstance(); })) + .def("registe", &PyPassManager::Registe, "Registe python pass") + .def("unregiste", &PyPassManager::Unregiste, "Delete Python Pass"); + })); +} // namespace python_pass +} // namespace opt +} // namespace mindspore diff --git a/mindspore/ccsrc/optimizer/py_pass_manager.h b/mindspore/ccsrc/optimizer/py_pass_manager.h new file mode 100644 index 0000000000..eaeefce213 --- /dev/null +++ b/mindspore/ccsrc/optimizer/py_pass_manager.h @@ -0,0 +1,66 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#ifndef MINDSPORE_CCSRC_OPTIMIZER_PY_PASS_MANAGER_H_ +#define MINDSPORE_CCSRC_OPTIMIZER_PY_PASS_MANAGER_H_ + +#include +#include +#include +#include + +#include "ir/anf.h" +#include "ir/func_graph.h" +#include "ir/primitive.h" +#include "utils/graph_utils.h" +#include "common/utils.h" + +#include "pipeline/parse/resolve.h" +#include "optimizer/py_pass.h" +#include "optimizer/pass_group.h" + +namespace mindspore { +namespace opt { +namespace python_pass { +class PyPassManager; +using PyPassManagerPtr = std::shared_ptr; + +enum Phase { RESOLVE, OPT }; + +class PyPassManager { + protected: + PyPassManager(); + static PyPassManagerPtr global_instance; + + public: + // Singletons should not be cloneable and assignable + PyPassManager(const PyPassManager &other) = delete; + void operator=(const PyPassManager &) = delete; + // Access the only global instance + static PyPassManagerPtr GetInstance(); + virtual ~PyPassManager() = default; + void Registe(const std::string &pass_name, const py::function &pattern, const py::function &target, + Phase phase = Phase::RESOLVE, bool run_only_once = false, bool multigraph = true); + void Unregiste(const std::string &pass_name, Phase phase); + PassGroupPtr GetPassGroup(Phase phase); + void ClearRes(); + + private: + static std::unordered_map phase_to_group_; +}; +} // namespace python_pass +} // namespace opt +} // namespace mindspore +#endif // MINDSPORE_CCSRC_OPTIMIZER_PY_PASS_MANAGER_H_ diff --git a/mindspore/ccsrc/pipeline/action.cc b/mindspore/ccsrc/pipeline/action.cc index 3265905848..4c0bb0f81c 100644 --- a/mindspore/ccsrc/pipeline/action.cc +++ b/mindspore/ccsrc/pipeline/action.cc @@ -39,6 +39,7 @@ #include "optimizer/optimizer.h" #include "vm/transform.h" #include "parse/python_adapter.h" +#include "optimizer/py_pass_manager.h" namespace mindspore { namespace pipeline { @@ -420,6 +421,25 @@ bool RemoveValueNodeDuplicationsAction(const ResourcePtr &res) { bool ValidateAction(const ResourcePtr &res) { return ValidatePass(res); } +void ActionPyStub(const ResourcePtr &res, opt::python_pass::Phase phase) { + MS_EXCEPTION_IF_NULL(res->manager()); + MS_EXCEPTION_IF_NULL(res->func_graph()); + auto ppm = opt::python_pass::PyPassManager::GetInstance(); + if (!ppm->GetPassGroup(phase)->Run(res->func_graph())) { + MS_LOG(DEBUG) << "No match.\n"; + } +} + +bool ResolveActionPyStub(const ResourcePtr &res) { + ActionPyStub(res, opt::python_pass::Phase::RESOLVE); + return true; +} + +bool OptActionPyStub(const ResourcePtr &res) { + ActionPyStub(res, opt::python_pass::Phase::RESOLVE); + return true; +} + static std::vector CommonPipeline() { std::vector actions; @@ -432,6 +452,8 @@ static std::vector CommonPipeline() { if (!multi_graphs) { actions.emplace_back(std::make_pair("combine_like_graphs", CombineLikeGraphs)); } + // Add resolve-stage python pass stub + actions.emplace_back(std::make_pair("py_resolve", ResolveActionPyStub)); actions.emplace_back(std::make_pair("inference_opt_prepare", InferenceOptPrepareAction)); // Evaluate type and shape, and specialize actions.emplace_back(std::make_pair("abstract_specialize", AbstractSpecializeAction)); @@ -443,6 +465,8 @@ std::vector GePipeline() { auto actions = CommonPipeline(); // optimize actions.emplace_back(std::make_pair("optimize", GeOptimizeAction)); + // Add opt-stage python pass stub + actions.emplace_back(std::make_pair("py_opt", OptActionPyStub)); actions.emplace_back(std::make_pair("remove_value_node_duplications", RemoveValueNodeDuplicationsAction)); actions.emplace_back(std::make_pair("validate", ValidateAction)); return actions; @@ -454,6 +478,9 @@ std::vector VmPipeline() { // optimize actions.emplace_back(std::make_pair("optimize", VmOptimizeAction)); + // Add opt-stage python pass stub + actions.emplace_back(std::make_pair("py_opt", OptActionPyStub)); + actions.emplace_back(std::make_pair("validate", ValidateAction)); // compile the ANF graph diff --git a/mindspore/ccsrc/pipeline/pipeline.cc b/mindspore/ccsrc/pipeline/pipeline.cc index bb1f693c6b..47b191fbc2 100644 --- a/mindspore/ccsrc/pipeline/pipeline.cc +++ b/mindspore/ccsrc/pipeline/pipeline.cc @@ -39,6 +39,7 @@ #include "device/kernel_runtime_manager.h" #include "debug/trace.h" #include "pynative/pynative_execute.h" +#include "optimizer/py_pass_manager.h" #if (ENABLE_GE || ENABLE_D) #include "pipeline/pipeline_ge.h" @@ -964,6 +965,7 @@ void ClearResAtexit() { pipeline::ExecutorPy::ClearRes(); pipeline::ReclaimOptimizer(); pynative::PynativeExecutor::GetInstance()->ClearRes(); + opt::python_pass::PyPassManager::GetInstance()->ClearRes(); #ifdef ENABLE_GE transform::DfGraphManager::GetInstance().ClearGraph(); transform::DfGraphConvertor::get_adpt_map().clear(); diff --git a/mindspore/common/python_pass_register.py b/mindspore/common/python_pass_register.py new file mode 100644 index 0000000000..36eb37adc7 --- /dev/null +++ b/mindspore/common/python_pass_register.py @@ -0,0 +1,80 @@ +# Copyright 2020 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================ +"""Python pass register""" +from inspect import isfunction +from mindspore._c_expression import PyPassManager_ +from mindspore._c_expression import phase + +class PyPassManager(PyPassManager_): + r""" + Used to registe and unregiste python passes which can be used to alter graphs. + + Args: + pipeline_phase (phase): Specify the stage in which the pass will run in the pipeline. Default: phase.opt. + run_only_once (bool): Specify whether or not to run pass only once. Default: False. + multigraph (bool): Whether or not the pattern exists across graphs. Default: True. + + Raises: + TypeError: If argument has invalid type. + """ + def __init__(self, pipeline_phase=phase.opt, run_only_once=False, multi_graph=True): + if not isinstance(pipeline_phase, phase): + raise TypeError(f"Expecting phase, got : ({type(pipeline_phase)}){pipeline_phase}") + if not isinstance(run_only_once, bool): + raise TypeError(f"Expecting bool, got : ({type(run_only_once)}){run_only_once}") + if not isinstance(multi_graph, bool): + raise TypeError(f"Expecting bool, got : ({type(multi_graph)}){multi_graph}") + PyPassManager_.__init__(self) + self.phase_ = pipeline_phase + self.run_only_once_ = run_only_once + self.multi_graph_ = multi_graph + + def registe(self, py_pass): + if not isfunction(py_pass): + raise TypeError(f"Expecting function pass, got : ({type(py_pass)}){py_pass}") + pattern, target = py_pass() + pass_name = py_pass.__name__ + if not isfunction(pattern): + raise TypeError(f"Expecting function pattern, got : ({type(pattern)}){pattern}") + if not isfunction(target): + raise TypeError(f"Expecting function target, got : ({type(target)}){target}") + super().registe(pass_name, pattern, target, self.phase_, self.run_only_once_, self.multi_graph_) + + def unregiste(self, py_pass, pipeline_phase=phase.opt): + if not isinstance(pipeline_phase, phase): + raise TypeError(f"Expecting phase, got : ({type(pipeline_phase)}){pipeline_phase}") + if isinstance(py_pass, str): + super().unregiste(py_pass, pipeline_phase) + return + if isfunction(py_pass): + super().unregiste(py_pass.__name__, pipeline_phase) + return + raise TypeError(f"Expecting py_pass to be string or function, got ({type(py_pass)}){py_pass}") + + def __call__(self, py_pass): + self.registe(py_pass) + return py_pass + +def registe_pass(pipeline_phase=phase.opt, run_only_once=False, multi_graph=True): + """ + Examples: + >>> @registe_pass() + >>> def toy_pass(): + >>> def pattern(): + >>> pass + >>> def target(): + >>> pass + """ + return PyPassManager(pipeline_phase, run_only_once, multi_graph) diff --git a/mindspore/nn/layer/basic.py b/mindspore/nn/layer/basic.py index 548fbcec1e..5bc678b003 100644 --- a/mindspore/nn/layer/basic.py +++ b/mindspore/nn/layer/basic.py @@ -170,7 +170,8 @@ class Dense(Cell): bias_init (Union[Tensor, str, Initializer, numbers.Number]): The trainable bias_init parameter. The dtype is same as input x. The values of str refer to the function `initializer`. Default: 'zeros'. has_bias (bool): Specifies whether the layer uses a bias vector. Default: True. - activation (str): Regularizer function applied to the output of the layer, eg. 'relu'. Default: None. + activation (str): activate function applied to the output of the fully connected layer, eg. 'relu'. + Default: None. Raises: ValueError: If weight_init or bias_init shape is incorrect.