commit
8871f06f42
@ -0,0 +1,28 @@
|
||||
/**
|
||||
* Copyright 2020 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
#include <string>
|
||||
#include "ir/anf.h"
|
||||
|
||||
#include "pybind_api/api_register.h"
|
||||
|
||||
namespace mindspore {
|
||||
// Define python 'RefKey' class.
|
||||
REGISTER_PYBIND_DEFINE(CNode, ([](const pybind11::module *m) {
|
||||
(void)py::class_<CNode, CNodePtr>(*m, "CNode")
|
||||
.def("expanded_str", (std::string(CNode::*)(int) const) & CNode::DebugString,
|
||||
"Get CNode string representation with specified expansion level.");
|
||||
}));
|
||||
} // namespace mindspore
|
@ -0,0 +1,35 @@
|
||||
/**
|
||||
* Copyright 2020 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
#include <string>
|
||||
#include "ir/meta_func_graph.h"
|
||||
#include "ir/func_graph.h"
|
||||
|
||||
#include "pybind_api/api_register.h"
|
||||
#include "pybind_api/export_flags.h"
|
||||
|
||||
namespace mindspore {
|
||||
REGISTER_PYBIND_DEFINE(FuncGraph, ([](const pybind11::module *m) {
|
||||
// Define python "MetaFuncGraph_" class
|
||||
(void)py::class_<MetaFuncGraph, std::shared_ptr<MetaFuncGraph>>(*m, "MetaFuncGraph_")
|
||||
.def_readonly(PYTHON_METAFUNCGRAPH_FLAG, &MetaFuncGraph::parse_info_)
|
||||
.def(py::init<std::string &>());
|
||||
// Define python "FuncGraph" class
|
||||
(void)py::class_<FuncGraph, FuncGraphPtr>(*m, "FuncGraph")
|
||||
.def(py::init())
|
||||
.def("str", &FuncGraph::ToString, "Get FuncGraph string representation.")
|
||||
.def("get_return", &FuncGraph::get_return, "Get return node of FuncGraph");
|
||||
}));
|
||||
} // namespace mindspore
|
@ -0,0 +1,64 @@
|
||||
# Copyright 2020 Huawei Technologies Co., Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ============================================================================
|
||||
import numpy as np
|
||||
|
||||
import mindspore
|
||||
import mindspore.nn as nn
|
||||
from mindspore import context
|
||||
from mindspore.common.tensor import Tensor
|
||||
from mindspore.ops import operations as P
|
||||
from mindspore.common.python_pass_register import registe_pass, PyPassManager
|
||||
from mindspore.common.api import _generate_pip_args
|
||||
from mindspore._c_expression import generate_key, Executor_
|
||||
|
||||
context.set_context(mode=context.GRAPH_MODE)
|
||||
|
||||
def get_func_graph(obj, *args, phase="predict"):
|
||||
args_names, args_list = _generate_pip_args(obj, *args)
|
||||
dic = dict(zip(args_names, args_list))
|
||||
key = generate_key(phase, dic)
|
||||
phase_prefix = str(key[1])
|
||||
if phase == 'export':
|
||||
phase = phase + '.' + phase_prefix + '.' + str(obj.create_time)
|
||||
else:
|
||||
phase = phase_prefix + phase + '.' + str(obj.create_time)
|
||||
_executor = Executor_.get_instance()
|
||||
_executor.compile(obj, args_list, phase, False)
|
||||
return _executor.get_func_graph(phase)
|
||||
|
||||
def test_softmax_relu():
|
||||
"""
|
||||
Use python pass to transform from Softmax to ReLU.
|
||||
"""
|
||||
inputs = Tensor(np.ones([42]), mindspore.float16)
|
||||
softmax_model = nn.Softmax()
|
||||
|
||||
@registe_pass(run_only_once=True)
|
||||
def softmax_relu_pass():
|
||||
softmax = P.Softmax()
|
||||
relu = P.ReLU()
|
||||
def pattern(x):
|
||||
x = softmax(x)
|
||||
return x
|
||||
def target(x):
|
||||
x = relu(x)
|
||||
return x
|
||||
return pattern, target
|
||||
|
||||
transformed_repr = get_func_graph(softmax_model, inputs).get_return().expanded_str(2)
|
||||
ppm = PyPassManager()
|
||||
ppm.unregiste(softmax_relu_pass)
|
||||
assert "ReLU" in transformed_repr
|
||||
assert "Softmax" not in transformed_repr
|
Loading…
Reference in new issue