|
|
@ -49,7 +49,8 @@ void MultitypeFuncGraph::Register(const TypePtrList &types, specialize_fn s_fn)
|
|
|
|
}
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
|
|
void MultitypeFuncGraph::Register(const TypePtrList &types, const py::function &py_fn) {
|
|
|
|
void MultitypeFuncGraph::Register(const TypePtrList &types, const py::function &py_fn) {
|
|
|
|
MS_LOG(DEBUG) << "Register type (" << ::mindspore::ToString(types) << ", " << std::string(py_fn.str()) << ").";
|
|
|
|
MS_LOG(DEBUG) << "Register type (" << ::mindspore::ToString(types) << ", " << py::str(py_fn.cast<py::object>())
|
|
|
|
|
|
|
|
<< ").";
|
|
|
|
auto fn = fn_cache_.find(types);
|
|
|
|
auto fn = fn_cache_.find(types);
|
|
|
|
if (fn != fn_cache_.end()) {
|
|
|
|
if (fn != fn_cache_.end()) {
|
|
|
|
MS_LOG(EXCEPTION) << "Cannot register as (" << ::mindspore::ToString(types) << ", already registered.";
|
|
|
|
MS_LOG(EXCEPTION) << "Cannot register as (" << ::mindspore::ToString(types) << ", already registered.";
|
|
|
@ -116,7 +117,7 @@ FuncGraphPtr MultitypeFuncGraph::GenerateFromTypes(const TypePtrList &types) {
|
|
|
|
auto py_fn = SignMatch(types);
|
|
|
|
auto py_fn = SignMatch(types);
|
|
|
|
std::ostringstream buffer;
|
|
|
|
std::ostringstream buffer;
|
|
|
|
buffer << types;
|
|
|
|
buffer << types;
|
|
|
|
if (py_fn != py::none()) {
|
|
|
|
if (!py_fn.is_none()) {
|
|
|
|
FuncGraphPtr func_graph = parse::ParsePythonCode(py_fn);
|
|
|
|
FuncGraphPtr func_graph = parse::ParsePythonCode(py_fn);
|
|
|
|
if (func_graph == nullptr) {
|
|
|
|
if (func_graph == nullptr) {
|
|
|
|
MS_LOG(EXCEPTION) << "Fail to parse overload function " << buffer.str();
|
|
|
|
MS_LOG(EXCEPTION) << "Fail to parse overload function " << buffer.str();
|
|
|
|