From f267a105b81c9862278add45cae4db3f21b7869b Mon Sep 17 00:00:00 2001 From: BowenK Date: Mon, 13 Jul 2020 19:15:20 +0800 Subject: [PATCH] Add Python Pass UT --- mindspore/ccsrc/ir/anf_py.cc | 28 ++++++++ mindspore/ccsrc/ir/func_graph_py.cc | 35 ++++++++++ mindspore/ccsrc/optimizer/py_pass.cc | 1 + mindspore/ccsrc/pipeline/action.cc | 2 +- mindspore/ccsrc/pipeline/init.cc | 7 -- mindspore/ops/primitive.py | 2 +- tests/ut/python/optimizer/test_python_pass.py | 64 +++++++++++++++++++ 7 files changed, 130 insertions(+), 9 deletions(-) create mode 100644 mindspore/ccsrc/ir/anf_py.cc create mode 100644 mindspore/ccsrc/ir/func_graph_py.cc create mode 100644 tests/ut/python/optimizer/test_python_pass.py diff --git a/mindspore/ccsrc/ir/anf_py.cc b/mindspore/ccsrc/ir/anf_py.cc new file mode 100644 index 0000000000..d033dfff5a --- /dev/null +++ b/mindspore/ccsrc/ir/anf_py.cc @@ -0,0 +1,28 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#include +#include "ir/anf.h" + +#include "pybind_api/api_register.h" + +namespace mindspore { +// Define python 'RefKey' class. +REGISTER_PYBIND_DEFINE(CNode, ([](const pybind11::module *m) { + (void)py::class_(*m, "CNode") + .def("expanded_str", (std::string(CNode::*)(int) const) & CNode::DebugString, + "Get CNode string representation with specified expansion level."); + })); +} // namespace mindspore diff --git a/mindspore/ccsrc/ir/func_graph_py.cc b/mindspore/ccsrc/ir/func_graph_py.cc new file mode 100644 index 0000000000..cff25b5aa1 --- /dev/null +++ b/mindspore/ccsrc/ir/func_graph_py.cc @@ -0,0 +1,35 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#include +#include "ir/meta_func_graph.h" +#include "ir/func_graph.h" + +#include "pybind_api/api_register.h" +#include "pybind_api/export_flags.h" + +namespace mindspore { +REGISTER_PYBIND_DEFINE(FuncGraph, ([](const pybind11::module *m) { + // Define python "MetaFuncGraph_" class + (void)py::class_>(*m, "MetaFuncGraph_") + .def_readonly(PYTHON_METAFUNCGRAPH_FLAG, &MetaFuncGraph::parse_info_) + .def(py::init()); + // Define python "FuncGraph" class + (void)py::class_(*m, "FuncGraph") + .def(py::init()) + .def("str", &FuncGraph::ToString, "Get FuncGraph string representation.") + .def("get_return", &FuncGraph::get_return, "Get return node of FuncGraph"); + })); +} // namespace mindspore diff --git a/mindspore/ccsrc/optimizer/py_pass.cc b/mindspore/ccsrc/optimizer/py_pass.cc index 8ce348b22e..842ccb75b9 100644 --- a/mindspore/ccsrc/optimizer/py_pass.cc +++ b/mindspore/ccsrc/optimizer/py_pass.cc @@ -54,6 +54,7 @@ void ResolveFuncGraph_(const FuncGraphPtr &fg) { auto manager = Manage(fg, false); parse::python_adapter::set_use_signature_in_resolve(false); parse::ResolveAll(manager); + parse::python_adapter::set_use_signature_in_resolve(true); } bool Match(const AnfNodePtr &pattern, const AnfNodePtr &node, const NodeEquivPtr &equiv_ptr) { diff --git a/mindspore/ccsrc/pipeline/action.cc b/mindspore/ccsrc/pipeline/action.cc index a645452cc0..3c2ca3f84b 100644 --- a/mindspore/ccsrc/pipeline/action.cc +++ b/mindspore/ccsrc/pipeline/action.cc @@ -437,7 +437,7 @@ bool ResolveActionPyStub(const ResourcePtr &res) { } bool OptActionPyStub(const ResourcePtr &res) { - ActionPyStub(res, opt::python_pass::Phase::RESOLVE); + ActionPyStub(res, opt::python_pass::Phase::OPT); return true; } diff --git a/mindspore/ccsrc/pipeline/init.cc b/mindspore/ccsrc/pipeline/init.cc index 199e841fc9..06b7fa756f 100644 --- a/mindspore/ccsrc/pipeline/init.cc +++ b/mindspore/ccsrc/pipeline/init.cc @@ -38,7 +38,6 @@ #endif namespace py = pybind11; -using FuncGraph = mindspore::FuncGraph; using EnvInstance = mindspore::EnvInstance; using ExecutorPy = mindspore::pipeline::ExecutorPy; using Pipeline = mindspore::pipeline::Pipeline; @@ -54,10 +53,6 @@ using CostModelContext = mindspore::parallel::CostModelContext; PYBIND11_MODULE(_c_expression, m) { m.doc() = "MindSpore c plugin"; - (void)py::class_>(*m, "MetaFuncGraph_") - .def_readonly(mindspore::PYTHON_METAFUNCGRAPH_FLAG, &mindspore::MetaFuncGraph::parse_info_) - .def(py::init()); - auto fns = mindspore::PybindDefineRegister::AllFuncs(); for (auto &item : fns) { item.second(&m); @@ -85,8 +80,6 @@ PYBIND11_MODULE(_c_expression, m) { py::arg("broadcast_params") = py::dict(), "Build data graph.") .def("has_compiled", &ExecutorPy::HasCompiled, py::arg("phase") = py::str(""), "get if cell compiled.") .def("run_init_graph", &ExecutorPy::RunInitGraph, "Run init Graph."); - // Class Graph interface - (void)py::class_(m, "FuncGraph").def(py::init()); (void)py::class_>(m, "EnvInstance_") .def_readonly(mindspore::PYTHON_ENVINSTANCE_FLAG, &mindspore::EnvInstance::parse_info_) diff --git a/mindspore/ops/primitive.py b/mindspore/ops/primitive.py index 768e9db2db..cb34e9ff24 100644 --- a/mindspore/ops/primitive.py +++ b/mindspore/ops/primitive.py @@ -146,7 +146,7 @@ class Primitive(Primitive_): Check whether or not certain inputs should go into backend. Subclass in need should override this method. Args: - Same as arguments of current Primitive + *args(Primitive args): Same as arguments of current Primitive. Returns: A tuple of two elements, first element indicates whether or not we should filter out current arguments; diff --git a/tests/ut/python/optimizer/test_python_pass.py b/tests/ut/python/optimizer/test_python_pass.py new file mode 100644 index 0000000000..c3ce3d6c4e --- /dev/null +++ b/tests/ut/python/optimizer/test_python_pass.py @@ -0,0 +1,64 @@ +# Copyright 2020 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================ +import numpy as np + +import mindspore +import mindspore.nn as nn +from mindspore import context +from mindspore.common.tensor import Tensor +from mindspore.ops import operations as P +from mindspore.common.python_pass_register import registe_pass, PyPassManager +from mindspore.common.api import _generate_pip_args +from mindspore._c_expression import generate_key, Executor_ + +context.set_context(mode=context.GRAPH_MODE) + +def get_func_graph(obj, *args, phase="predict"): + args_names, args_list = _generate_pip_args(obj, *args) + dic = dict(zip(args_names, args_list)) + key = generate_key(phase, dic) + phase_prefix = str(key[1]) + if phase == 'export': + phase = phase + '.' + phase_prefix + '.' + str(obj.create_time) + else: + phase = phase_prefix + phase + '.' + str(obj.create_time) + _executor = Executor_.get_instance() + _executor.compile(obj, args_list, phase, False) + return _executor.get_func_graph(phase) + +def test_softmax_relu(): + """ + Use python pass to transform from Softmax to ReLU. + """ + inputs = Tensor(np.ones([42]), mindspore.float16) + softmax_model = nn.Softmax() + + @registe_pass(run_only_once=True) + def softmax_relu_pass(): + softmax = P.Softmax() + relu = P.ReLU() + def pattern(x): + x = softmax(x) + return x + def target(x): + x = relu(x) + return x + return pattern, target + + transformed_repr = get_func_graph(softmax_model, inputs).get_return().expanded_str(2) + ppm = PyPassManager() + ppm.unregiste(softmax_relu_pass) + assert "ReLU" in transformed_repr + assert "Softmax" not in transformed_repr