fix `MultitypeFuncGraph` and `HyperMap` in pynative mode

pull/1542/head
Wei Luning 5 years ago
parent c086d91aaf
commit ebf02dd528

@ -98,25 +98,29 @@ TypePtr TensorType::DeepCopy() const {
MS_EXCEPTION_IF_NULL(element_type_);
if (IsGeneric()) {
return std::make_shared<TensorType>();
} else {
return std::make_shared<TensorType>(element_type_->DeepCopy());
}
return std::make_shared<TensorType>(element_type_->DeepCopy());
}
std::string TensorType::ToReprString() const {
if (element_type_ == nullptr) {
return "tensor";
}
return "tensor[" + element_type_->ToReprString() + "]";
}
std::string TensorType::ToString() const {
if (element_type_ == nullptr) {
return "Tensor";
} else {
return "Tensor[" + element_type_->ToString() + "]";
}
return "Tensor[" + element_type_->ToString() + "]";
}
std::string TensorType::DumpText() const {
if (element_type_ == nullptr) {
return "Tensor";
} else {
return "Tensor(" + element_type_->DumpText() + ")";
}
return "Tensor(" + element_type_->DumpText() + ")";
}
bool TensorType::operator==(const Type &other) const {

@ -121,7 +121,7 @@ class TensorType : public Object {
TypePtr DeepCopy() const override;
std::string ToString() const override;
std::string ToReprString() const override { return "tensor"; }
std::string ToReprString() const override;
std::string DumpText() const override;
bool operator==(const Type &other) const override;

@ -363,6 +363,7 @@ REGISTER_PYBIND_DEFINE(
(void)m_sub.def("load_type", &TypeIdToType, "load type");
(void)m_sub.def(
"dump_type", [](const TypePtr &t) { return t->type_id(); }, "dump type");
(void)m_sub.def("str_to_type", &StringToType, "string to typeptr");
(void)py::class_<Type, std::shared_ptr<Type>>(m_sub, "Type")
.def_readonly(PYTHON_DTYPE_FLAG, &mindspore::Type::parse_info_)
.def("__eq__",

@ -649,115 +649,6 @@ REGISTER_PYBIND_DEFINE(GradOperation_, ([](const py::module *m) {
py::arg("get_by_list"), py::arg("sens_param"));
}));
MultitypeFuncGraph::MultitypeFuncGraph(const std::string &name) : MetaFuncGraph(name) {
fn_cache_.clear();
signatures_ = std::vector<Signature>({// def multitype(*args:ref):
{"args", SignatureEnumRW::kRWRef, SignatureEnumKind::kKindVarPositional}});
}
void MultitypeFuncGraph::Register(const TypePtrList &types, specialize_fn s_fn) {
MS_LOG(DEBUG) << "Register type (" << ::mindspore::ToString(types) << ".";
auto fn = fn_cache_.find(types);
if (fn != fn_cache_.end()) {
MS_LOG(EXCEPTION) << "Cannot register as (" << ::mindspore::ToString(types) << ", already registered.";
}
fn_cache_[types] = s_fn;
}
void MultitypeFuncGraph::Register(const TypePtrList &types, const py::function &py_fn) {
MS_LOG(DEBUG) << "Register type (" << ::mindspore::ToString(types) << ", " << std::string(py_fn.str()) << ").";
auto fn = fn_cache_.find(types);
if (fn != fn_cache_.end()) {
MS_LOG(EXCEPTION) << "Cannot register as (" << ::mindspore::ToString(types) << ", already registered.";
}
fn_cache_py_[types] = py_fn;
}
void MultitypeFuncGraph::Register(const std::vector<std::string> &types_name, const py::function &py_fn) {
TypePtrList types;
for (auto &type_name : types_name) {
auto type_ptr = StringToType(type_name);
if (type_ptr == nullptr) {
MS_LOG(EXCEPTION) << type_name << " convert from string error ";
}
types.push_back(type_ptr);
}
Register(types, py_fn);
}
void MultitypeFuncGraph::PyRegister(const py::tuple &tuple, const py::function &py_fn) {
std::vector<std::string> types_name;
for (size_t it = 0; it < tuple.size(); ++it) {
py::object name_py = tuple[it];
if (py::isinstance<py::str>(name_py)) {
types_name.push_back(name_py.cast<std::string>());
continue;
}
MS_LOG(EXCEPTION) << "Register must be string";
}
Register(types_name, py_fn);
}
static TypePtr UnwrapRef(const TypePtr &type) {
if (type->isa<RefType>()) {
return type->cast<RefTypePtr>()->subtype();
}
return type;
}
FuncGraphPtr MultitypeFuncGraph::GenerateFromTypes(const TypePtrList &types) {
bool find_fn = false;
py::function py_fn;
for (auto &item : fn_cache_py_) {
TypePtrList sign = item.first;
if (sign.size() != types.size()) {
continue;
}
bool match = true;
for (size_t i = 0; i < sign.size(); ++i) {
if (!IsIdentidityOrSubclass(UnwrapRef(types[i]), sign[i])) {
match = false;
break;
}
}
if (!match) {
continue;
}
find_fn = true;
py_fn = item.second;
break;
}
std::ostringstream buffer;
buffer << types;
if (find_fn) {
FuncGraphPtr func_graph = parse::ParsePythonCode(py_fn);
if (func_graph == nullptr) {
MS_LOG(EXCEPTION) << "Fail to parse overload function " << buffer.str();
}
MS_LOG(DEBUG) << "Find overload function " << buffer.str() << ", function: " << func_graph->ToString();
return func_graph;
}
std::ostringstream oss;
oss << "There are " << fn_cache_py_.size() << " prototypes for overload function `" << name_
<< "`, corresponding location info:\n";
int idx = 0;
for (auto &item : fn_cache_py_) {
FuncGraphPtr func_graph = parse::ParsePythonCode(item.second);
if (func_graph == nullptr) {
MS_LOG(WARNING) << "Fail to parse Python code for function `" << name_ << "`.";
continue;
}
oss << ++idx << ". " << item.first << "\n " << trace::GetDebugInfo(func_graph->debug_info()) << "\n";
}
MS_LOG(EXCEPTION) << "The '" << name_ << "' operation does not support the type " << buffer.str() << "\n"
<< oss.str();
}
REGISTER_PYBIND_DEFINE(MultitypeFuncGraph_, ([](const py::module *m) {
(void)py::class_<MultitypeFuncGraph, MetaFuncGraph, std::shared_ptr<MultitypeFuncGraph>>(
*m, "MultitypeFuncGraph_")
.def(py::init<std::string &>())
.def("register_fn", &MultitypeFuncGraph::PyRegister);
}));
// Generate the ListMap func graph.
FuncGraphPtr ListMap::GenerateFuncGraph(const AbstractBasePtrList &args_spec_list) {
size_t args_num = args_spec_list.size();

@ -30,6 +30,7 @@
#include "operator/composite/list_append_operation.h"
#include "operator/composite/do_signature.h"
#include "operator/composite/unpack_call.h"
#include "operator/composite/multitype_funcgraph.h"
#include "pipeline/static_analysis/static_analysis.h"
#include "utils/misc.h"
#include "utils/any.h"
@ -45,31 +46,6 @@ using AbstractTensorPtr = abstract::AbstractTensorPtr;
using ElemwiseMap = std::unordered_map<std::string, PrimitivePtr>;
using ArgsPairList = std::vector<std::pair<AnfNodePtr, TypePtr>>;
class MultitypeFuncGraph : public MetaFuncGraph {
public:
explicit MultitypeFuncGraph(const std::string &name);
~MultitypeFuncGraph() override = default;
MS_DECLARE_PARENT(MultitypeFuncGraph, MetaFuncGraph)
using specialize_fn = FuncGraph *(*)(TypePtrList);
// Register a method which specialize based on types vectors;
virtual void Register(const TypePtrList &types, specialize_fn s_fn);
virtual void Register(const TypePtrList &types, const py::function &py_fn);
virtual void Register(const std::vector<std::string> &types_name, const py::function &py_fn);
virtual void PyRegister(const py::tuple &tuple, const py::function &py_fn);
FuncGraphPtr GenerateFromTypes(const TypePtrList &types) override;
size_t GetPyFnCacheSize() const { return fn_cache_py_.size(); }
const std::unordered_map<TypePtrList, py::function, TypeListHasher, TypeListEqual> &GetPyFunctions() const {
return fn_cache_py_;
}
private:
std::unordered_map<TypePtrList, specialize_fn, TypeListHasher, TypeListEqual> fn_cache_;
std::unordered_map<TypePtrList, py::function, TypeListHasher, TypeListEqual> fn_cache_py_;
};
using MultitypeFuncGraphPtr = std::shared_ptr<MultitypeFuncGraph>;
class HyperMap : public MetaFuncGraph {
public:
explicit HyperMap(const std::shared_ptr<MultitypeFuncGraph> &fn_leaf = nullptr);

@ -0,0 +1,153 @@
/**
* This is the C++ adaptation and derivative work of Myia (https://github.com/mila-iqia/myia/).
*
* Copyright 2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "operator/composite/multitype_funcgraph.h"
#include <algorithm>
#include <utility>
#include <sstream>
#include "ir/anf.h"
#include "ir/func_graph.h"
#include "pipeline/static_analysis/abstract_value.h"
#include "pipeline/static_analysis/abstract_function.h"
#include "pipeline/static_analysis/dshape.h"
#include "pipeline/static_analysis/param_validator.h"
#include "operator/cc_implementations.h"
#include "optimizer/opt.h"
#include "utils/symbolic.h"
#include "pybind_api/api_register.h"
#include "./common.h"
#include "ir/signature.h"
#include "debug/trace.h"
namespace mindspore {
// namespace to support composite operators definition
namespace prim {
MultitypeFuncGraph::MultitypeFuncGraph(const std::string &name) : MetaFuncGraph(name) {
fn_cache_.clear();
signatures_ = std::vector<Signature>({// def multitype(*args:ref):
{"args", SignatureEnumRW::kRWRef, SignatureEnumKind::kKindVarPositional}});
}
void MultitypeFuncGraph::Register(const TypePtrList &types, specialize_fn s_fn) {
MS_LOG(DEBUG) << "Register type (" << ::mindspore::ToString(types) << ".";
auto fn = fn_cache_.find(types);
if (fn != fn_cache_.end()) {
MS_LOG(EXCEPTION) << "Cannot register as (" << ::mindspore::ToString(types) << ", already registered.";
}
fn_cache_[types] = s_fn;
}
void MultitypeFuncGraph::Register(const TypePtrList &types, const py::function &py_fn) {
MS_LOG(DEBUG) << "Register type (" << ::mindspore::ToString(types) << ", " << std::string(py_fn.str()) << ").";
auto fn = fn_cache_.find(types);
if (fn != fn_cache_.end()) {
MS_LOG(EXCEPTION) << "Cannot register as (" << ::mindspore::ToString(types) << ", already registered.";
}
fn_cache_py_[types] = py_fn;
}
void MultitypeFuncGraph::Register(const std::vector<std::string> &types_name, const py::function &py_fn) {
TypePtrList types;
for (auto &type_name : types_name) {
auto type_ptr = StringToType(type_name);
if (type_ptr == nullptr) {
MS_LOG(EXCEPTION) << type_name << " convert from string error ";
}
types.push_back(type_ptr);
}
Register(types, py_fn);
}
void MultitypeFuncGraph::PyRegister(const py::tuple &tuple, const py::function &py_fn) {
std::vector<std::string> types_name;
for (size_t it = 0; it < tuple.size(); ++it) {
py::object name_py = tuple[it];
if (py::isinstance<py::str>(name_py)) {
types_name.push_back(name_py.cast<std::string>());
continue;
}
MS_LOG(EXCEPTION) << "Register must be string";
}
Register(types_name, py_fn);
}
static TypePtr UnwrapRef(const TypePtr &type) {
if (type->isa<RefType>()) {
return type->cast<RefTypePtr>()->subtype();
}
return type;
}
FuncGraphPtr MultitypeFuncGraph::GenerateFromTypes(const TypePtrList &types) {
bool find_fn = false;
py::function py_fn;
for (auto &item : fn_cache_py_) {
TypePtrList sign = item.first;
if (sign.size() != types.size()) {
continue;
}
bool match = true;
for (size_t i = 0; i < sign.size(); ++i) {
if (!IsIdentidityOrSubclass(UnwrapRef(types[i]), sign[i])) {
match = false;
break;
}
}
if (!match) {
continue;
}
find_fn = true;
py_fn = item.second;
break;
}
std::ostringstream buffer;
buffer << types;
if (find_fn) {
FuncGraphPtr func_graph = parse::ParsePythonCode(py_fn);
if (func_graph == nullptr) {
MS_LOG(EXCEPTION) << "Fail to parse overload function " << buffer.str();
}
MS_LOG(DEBUG) << "Find overload function " << buffer.str() << ", function: " << func_graph->ToString();
return func_graph;
}
std::ostringstream oss;
oss << "There are " << fn_cache_py_.size() << " prototypes for overload function `" << name_
<< "`, corresponding location info:\n";
int idx = 0;
for (auto &item : fn_cache_py_) {
FuncGraphPtr func_graph = parse::ParsePythonCode(item.second);
if (func_graph == nullptr) {
MS_LOG(WARNING) << "Fail to parse Python code for function `" << name_ << "`.";
continue;
}
oss << ++idx << ". " << item.first << "\n " << trace::GetDebugInfo(func_graph->debug_info()) << "\n";
}
MS_LOG(EXCEPTION) << "The '" << name_ << "' operation does not support the type " << buffer.str() << "\n"
<< oss.str();
}
REGISTER_PYBIND_DEFINE(MultitypeFuncGraph_, ([](const py::module *m) {
(void)py::class_<MultitypeFuncGraph, MetaFuncGraph, std::shared_ptr<MultitypeFuncGraph>>(
*m, "MultitypeFuncGraph_")
.def(py::init<std::string &>())
.def("register_fn", &MultitypeFuncGraph::PyRegister);
}));
} // namespace prim
} // namespace mindspore

@ -0,0 +1,66 @@
/**
* This is the C++ adaptation and derivative work of Myia (https://github.com/mila-iqia/myia/).
*
* Copyright 2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef MINDSPORE_CCSRC_OPERATOR_COMPOSITE_MULTITYPE_FUNCGRAPH_H_
#define MINDSPORE_CCSRC_OPERATOR_COMPOSITE_MULTITYPE_FUNCGRAPH_H_
#include <vector>
#include <string>
#include <unordered_map>
#include <utility>
#include <map>
#include <set>
#include <memory>
#include "pipeline/static_analysis/static_analysis.h"
#include "utils/misc.h"
#include "ir/dtype.h"
#include "ir/meta_func_graph.h"
namespace mindspore {
// namespace to support composite operators definition
namespace prim {
class MultitypeFuncGraph : public MetaFuncGraph {
public:
explicit MultitypeFuncGraph(const std::string &name);
~MultitypeFuncGraph() override = default;
MS_DECLARE_PARENT(MultitypeFuncGraph, MetaFuncGraph)
using specialize_fn = FuncGraph *(*)(TypePtrList);
// Register a method which specialize based on types vectors;
virtual void Register(const TypePtrList &types, specialize_fn s_fn);
virtual void Register(const TypePtrList &types, const py::function &py_fn);
virtual void Register(const std::vector<std::string> &types_name, const py::function &py_fn);
virtual void PyRegister(const py::tuple &tuple, const py::function &py_fn);
FuncGraphPtr GenerateFromTypes(const TypePtrList &types) override;
size_t GetPyFnCacheSize() const { return fn_cache_py_.size(); }
const std::unordered_map<TypePtrList, py::function, TypeListHasher, TypeListEqual> &GetPyFunctions() const {
return fn_cache_py_;
}
private:
std::unordered_map<TypePtrList, specialize_fn, TypeListHasher, TypeListEqual> fn_cache_;
std::unordered_map<TypePtrList, py::function, TypeListHasher, TypeListEqual> fn_cache_py_;
};
using MultitypeFuncGraphPtr = std::shared_ptr<MultitypeFuncGraph>;
} // namespace prim
} // namespace mindspore
#endif // MINDSPORE_CCSRC_OPERATOR_COMPOSITE_H_

@ -156,7 +156,7 @@ def pytype_to_dtype(obj):
return obj
if isinstance(obj, type) and obj in _simple_types:
return _simple_types[obj]
raise NotImplementedError()
raise NotImplementedError(f"Unsupported type {obj} for `pytype_to_dtype`.")
def get_py_obj_dtype(obj):
@ -169,7 +169,11 @@ def get_py_obj_dtype(obj):
Returns:
Type of MindSpore type.
"""
# Tensor
if hasattr(obj, 'dtype'):
return tensor_type(obj.dtype())
if hasattr(obj, '__primitive_flag__') or hasattr(obj, 'construct'):
return function
if isinstance(obj, (typing.Type, type)):
return pytype_to_dtype(obj)
return pytype_to_dtype(type(obj))

@ -359,6 +359,4 @@ def tensor_grad_scale(scale, grad):
"""Get grad with scale."""
if scale == 1.0:
return grad
cast_op = P.Cast()
type_op = P.DType()
return grad * cast_op(F.scalar_to_array(scale), type_op(grad))
return grad * scale

@ -16,6 +16,7 @@
# ============================================================================
"""Basic composite operations."""
from functools import partial
from ..._c_expression import EnvInstance_, GradOperation_, HyperMap_, MultitypeFuncGraph_, Tail_, TensorSlice_, \
TupleAdd_, TupleSlice_, UnpackCall_, ZipOperation_, ListAppend_, TupleGetItemTensor_
@ -23,6 +24,7 @@ from ...common import dtype as mstype
from ...common.api import ms_function
from .. import functional as F
from .. import operations as P
from ...common.parameter import Parameter
__all__ = [EnvInstance_, TensorSlice_, TupleAdd_, TupleSlice_, UnpackCall_, TupleGetItemTensor_]
@ -144,7 +146,6 @@ class MultitypeFuncGraph(MultitypeFuncGraph_):
>>> # `add` is a metagraph object which will add two objects according to
>>> # input type using ".register" decorator.
>>> add = MultitypeFuncGraph('add')
"""
def __init__(self, name):
@ -152,8 +153,15 @@ class MultitypeFuncGraph(MultitypeFuncGraph_):
self.entries = list()
def __call__(self, *args):
for sig, fn in self.entries:
if len(sig) != len(args):
def unwrap(arg):
if isinstance(arg, Parameter):
return arg.data
return arg
types = tuple(map(lambda arg: mstype.get_py_obj_dtype(unwrap(arg)), args))
for sigs, fn in self.entries:
if len(sigs) != len(types):
continue
if any(not mstype.issubclass_(type_, sig) for sig, type_ in zip(sigs, types)):
continue
output = fn(*args)
return output
@ -162,8 +170,9 @@ class MultitypeFuncGraph(MultitypeFuncGraph_):
def register(self, *type_names):
"""Register a function for the given type string."""
def deco(fn):
types = tuple(map(mstype.typing.str_to_type, type_names))
self.register_fn(type_names, fn)
self.entries.append((type_names, fn))
self.entries.append((types, fn))
return fn
return deco
@ -198,38 +207,17 @@ class HyperMap(HyperMap_):
HyperMap_.__init__(self)
def __call__(self, *args):
func = args[0]
count = 0
count_max = 1
args_list = args[1:]
if self.ops is not None:
func = self.ops
args_list = args
for item in args_list:
if isinstance(item, (tuple, list)):
count_max = len(item)
break
def get_item(x):
nonlocal count
if isinstance(x, (tuple, list)):
return x[count]
return x
for i in range(count_max):
true_args = tuple(map(get_item, args_list))
func(*true_args)
count = i + 1
return True
def register(self, *type_names):
"""Register a function for the given type string."""
def deco(fn):
self.register_fn(type_names, fn)
return fn
return deco
func = self.ops
args_list = args
hypermap = self
if self.ops is None:
func = args[0]
args_list = args[1:]
hypermap = partial(self, func)
# is leaf
if not isinstance(args_list[0], (tuple, list)):
return func(*args_list)
return tuple(map(hypermap, *args_list))
class _ListAppend(ListAppend_):
"""

Loading…
Cancel
Save