From 04763b8b768b8028781ab4816a170b83f81f3132 Mon Sep 17 00:00:00 2001 From: leopz Date: Wed, 20 May 2020 17:37:44 +0800 Subject: [PATCH] move signature to primitivepy and bprop_func to utils --- mindspore/ccsrc/ir/primitive.cc | 90 +----------- mindspore/ccsrc/ir/primitive.h | 115 ++-------------- mindspore/ccsrc/ir/primitive_base.cc | 71 ++++++++++ mindspore/ccsrc/ir/primitive_base.h | 128 ++++++++++++++++++ mindspore/ccsrc/ir/primitive_base_extends.cc | 25 ++++ .../ccsrc/operator/composite/do_signature.cc | 4 +- mindspore/ccsrc/optimizer/ad/kprim.cc | 4 +- mindspore/ccsrc/utils/primitive_utils.cc | 49 +++++++ mindspore/ccsrc/utils/primitive_utils.h | 33 +++++ mindspore/ccsrc/vm/vmimpl.cc | 3 +- tests/ut/cpp/operator/ops_test.cc | 2 +- 11 files changed, 328 insertions(+), 196 deletions(-) create mode 100644 mindspore/ccsrc/ir/primitive_base.cc create mode 100644 mindspore/ccsrc/ir/primitive_base.h create mode 100644 mindspore/ccsrc/ir/primitive_base_extends.cc create mode 100644 mindspore/ccsrc/utils/primitive_utils.cc create mode 100644 mindspore/ccsrc/utils/primitive_utils.h diff --git a/mindspore/ccsrc/ir/primitive.cc b/mindspore/ccsrc/ir/primitive.cc index d848f9c0d8..4fd000ca59 100644 --- a/mindspore/ccsrc/ir/primitive.cc +++ b/mindspore/ccsrc/ir/primitive.cc @@ -24,75 +24,13 @@ #include "pipeline/parse/data_converter.h" #include "pybind11/pytypes.h" #include "utils/convert_utils.h" +#include "utils/primitive_utils.h" #include "pybind_api/api_register.h" #include "pybind_api/export_flags.h" namespace mindspore { -using mindspore::abstract::AbstractFunction; - -abstract::AbstractBasePtr Primitive::ToPrimAbstract(const AnfNodePtr &anf_node) { - auto prim_func = std::make_shared(shared_from_base(), anf_node); - return prim_func; -} - -static py::function GetBpropFunctionByObj(py::object obj) { - static const std::string get_bprop_fn = "get_bprop_fn"; - static const std::string ad_module = "mindspore.ops._grad"; - py::function fn = parse::python_adapter::GetPyFn(ad_module, get_bprop_fn)(obj); - return fn; -} - -py::function Primitive::GetBpropFunction() { - auto fn = GetBpropFunctionByObj(py::str(name())); - if (fn.is_none()) { - MS_LOG(WARNING) << "Can't find bprop function for " << name(); - } - return fn; -} - -py::function Primitive::GetComputeFunction() { - static const std::string module = "mindspore._extends.builtin_operations"; - py::module mod = py::module::import(common::SafeCStr(module)); - if (!py::hasattr(mod, common::SafeCStr(name()))) { - PyErr_SetString(PyExc_NotImplementedError, common::SafeCStr(name())); - // If raise AttributeError, user can't understand. This case need raise NotImplementedError. - throw py::error_already_set(); - } - py::object fn = mod.attr(common::SafeCStr(name())); - return fn; -} - -bool Primitive::operator==(const Value &other) const { - if (other.isa()) { - auto other_prim = static_cast(other); - return *this == other_prim; - } else { - return false; - } -} - -bool Primitive::operator==(const Primitive &other) const { - if (name() != other.name()) { - return false; - } - if (attrs_.size() != other.attrs_.size()) { - return false; - } - auto all = std::all_of(attrs_.begin(), attrs_.end(), [&other](const std::pair &item) -> bool { - if (item.second == nullptr) { - return false; - } - auto iter = other.attrs_.find(item.first); - if (iter == other.attrs_.end()) { - return false; - } - return *item.second == *iter->second; - }); - return all; -} - -void Primitive::set_signatures( +void PrimitivePy::set_signatures( std::vector> signatures) { signatures_.clear(); for (auto &signature : signatures) { @@ -104,27 +42,7 @@ void Primitive::set_signatures( std::tie(name, rw, kind, default_value, dtype) = signature; signatures_.emplace_back(Signature(name, rw, kind, default_value, dtype)); } -} - -std::string Primitive::GetAttrsText() const { - if (attrs_.empty()) { - return ""; - } - - std::ostringstream oss; - oss << "["; - bool is_first = true; - for (auto &attr : attrs_) { - if (is_first) { - is_first = false; - } else { - oss << ", "; - } - oss << attr.first << "=" << attr.second->DumpText(); - } - oss << "]"; - - return oss.str(); + set_has_signature(true); } py::function PrimitivePy::GetBpropFunction() { @@ -158,7 +76,7 @@ py::function PrimitivePy::GetComputeFunction() { if (py::isinstance(vm_fn)) { MS_LOG(DEBUG) << "Cannot find " << python_obj_.attr("__class__").attr("__name__").cast(); - vm_fn = Primitive::GetComputeFunction(); + vm_fn = mindspore::GetComputeFunction(Primitive::name()); } return vm_fn; } diff --git a/mindspore/ccsrc/ir/primitive.h b/mindspore/ccsrc/ir/primitive.h index 08c6b7dc9b..1959b31993 100644 --- a/mindspore/ccsrc/ir/primitive.h +++ b/mindspore/ccsrc/ir/primitive.h @@ -22,59 +22,26 @@ #include #include #include -#include "pybind11/pybind11.h" +#include "pybind11/pybind11.h" #include "pipeline/static_analysis/abstract_value.h" #include "utils/misc.h" #include "utils/log_adapter.h" +#include "ir/primitive_base.h" #include "ir/signature.h" - #include "parallel/ops_info/operator_info.h" namespace py = pybind11; namespace mindspore { -using abstract::AbstractBasePtr; -using abstract::AbstractBasePtrList; -// Supported meta type -enum PrimType { - kPrimTypeUnknown = 0, - kPrimTypeBegin = kTypeUnknown, - kPrimTypeBuiltIn, // Built-in primitive operator - kPrimTypePyInferShape, // Primitive operator defined by custom - kPrimTypePyInferTensor, // Primitive operator defined by custom - kPrimTypeUserCustom -}; - -class Primitive : public Named { +class PrimitivePy : public Primitive { public: - explicit Primitive(const std::string &name, const PrimType prim_type = kPrimTypeBuiltIn) - : Named(name), signatures_(), prim_type_(prim_type) {} - - Primitive(const Primitive &prim) - : Named(prim), - attrs_(prim.attrs_), - signatures_(prim.signatures_), - instance_name_(prim.instance_name_), - prim_type_(prim.prim_type_) {} - - MS_DECLARE_PARENT(Primitive, Named); - - abstract::AbstractBasePtr ToPrimAbstract(const AnfNodePtr &anf_node); - std::string ToString() const override { return name(); } - virtual py::function GetBpropFunction(); - virtual py::function GetComputeFunction(); - Primitive &AddAttr(const std::string &name, const ValuePtr &attr) { - attrs_[name] = attr; - return *this; - } - - Primitive &SetAttrs(const std::unordered_map &attrs) { - for (auto &attr : attrs) { - attrs_[attr.first] = attr.second; - } - return *this; - } + PrimitivePy(const py::str &name, const py::object &python_obj) + : Primitive(name, false), python_obj_(python_obj), signatures_() {} + ~PrimitivePy() override = default; + MS_DECLARE_PARENT(PrimitivePy, Primitive); + py::function GetBpropFunction(); + py::function GetComputeFunction(); void set_signatures( std::vector> @@ -82,52 +49,6 @@ class Primitive : public Named { const std::vector &signatures() const { return signatures_; } - void set_attr(const std::string &attrName, const ValuePtr &attr) { attrs_[attrName] = attr; } - void EraseAttr(const std::string &attrName) { (void)attrs_.erase(attrName); } - - ValuePtr GetAttr(const std::string &attrName) const { - auto iter = attrs_.find(attrName); - return iter == attrs_.cend() ? nullptr : iter->second; - } - - const std::unordered_map &attrs() const { return attrs_; } - - // if Primitive has any attribute, for Primitives like scalar_add, return, etc, don't have any attribute. - bool HasAttr() const { return !attrs_.empty(); } - bool HasAttr(const std::string &attrName) const { - auto iter = attrs_.find(attrName); - return !(iter == attrs_.cend()); - } - void set_prim_type(const PrimType t) { prim_type_ = t; } - void set_instance_name(const std::string s) { instance_name_ = s; } - bool HasPyEvaluator() const { return prim_type_ == kPrimTypePyInferShape || prim_type_ == kPrimTypeUserCustom; } - bool HasPyInferTensor() const { return prim_type_ == kPrimTypePyInferTensor; } - bool IsCustomPrim() const { return prim_type_ == kPrimTypeUserCustom; } - - PrimType prim_type() const { return prim_type_; } - std::string instance_name() const { return instance_name_; } - std::string GetAttrsText() const; - bool operator==(const Value &other) const override; - bool operator==(const Primitive &other) const; - ~Primitive() override = default; - - protected: - std::unordered_map attrs_; - - private: - std::vector signatures_; - std::string instance_name_; - PrimType prim_type_; -}; - -class PrimitivePy : public Primitive { - public: - PrimitivePy(const py::str &name, const py::object &python_obj) : Primitive(name), python_obj_(python_obj) {} - ~PrimitivePy() override = default; - MS_DECLARE_PARENT(PrimitivePy, Primitive); - py::function GetBpropFunction() override; - py::function GetComputeFunction() override; - void AddPyAttr(const py::str &name, const py::object &obj); py::dict GetAttrDict(); @@ -138,25 +59,9 @@ class PrimitivePy : public Primitive { private: py::object python_obj_; + std::vector signatures_; }; using PrimitivePyPtr = std::shared_ptr; - -inline std::ostream &operator<<(std::ostream &os, const PrimitivePtr &p) { - os << *p; - return os; -} - -struct PrimitiveEqual { - bool operator()(PrimitivePtr const &t1, PrimitivePtr const &t2) const { - MS_EXCEPTION_IF_NULL(t1); - MS_EXCEPTION_IF_NULL(t2); - return t1->name() == t2->name(); - } -}; - -struct PrimitiveHasher { - std::size_t operator()(PrimitivePtr const &prim) const { return prim->Hash(); } -}; } // namespace mindspore #endif // MINDSPORE_CCSRC_IR_PRIMITIVE_H_ diff --git a/mindspore/ccsrc/ir/primitive_base.cc b/mindspore/ccsrc/ir/primitive_base.cc new file mode 100644 index 0000000000..864427fe13 --- /dev/null +++ b/mindspore/ccsrc/ir/primitive_base.cc @@ -0,0 +1,71 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "ir/primitive_base.h" + +#include + +namespace mindspore { +bool Primitive::operator==(const Value &other) const { + if (other.isa()) { + auto other_prim = static_cast(other); + return *this == other_prim; + } else { + return false; + } +} + +bool Primitive::operator==(const Primitive &other) const { + if (name() != other.name()) { + return false; + } + if (attrs_.size() != other.attrs_.size()) { + return false; + } + auto all = std::all_of(attrs_.begin(), attrs_.end(), [&other](const std::pair &item) -> bool { + if (item.second == nullptr) { + return false; + } + auto iter = other.attrs_.find(item.first); + if (iter == other.attrs_.end()) { + return false; + } + return *item.second == *iter->second; + }); + return all; +} + +std::string Primitive::GetAttrsText() const { + if (attrs_.empty()) { + return ""; + } + + std::ostringstream oss; + oss << "["; + bool is_first = true; + for (auto &attr : attrs_) { + if (is_first) { + is_first = false; + } else { + oss << ", "; + } + oss << attr.first << "=" << attr.second->DumpText(); + } + oss << "]"; + + return oss.str(); +} +} // namespace mindspore diff --git a/mindspore/ccsrc/ir/primitive_base.h b/mindspore/ccsrc/ir/primitive_base.h new file mode 100644 index 0000000000..28135c8e75 --- /dev/null +++ b/mindspore/ccsrc/ir/primitive_base.h @@ -0,0 +1,128 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef MINDSPORE_CCSRC_IR_PRIMITIVE_BASE_H_ +#define MINDSPORE_CCSRC_IR_PRIMITIVE_BASE_H_ + +#include +#include +#include +#include +#include + +#include "ir/dtype/type.h" + +namespace mindspore { +// Supported meta type +enum PrimType { + kPrimTypeUnknown = 0, + kPrimTypeBegin = kTypeUnknown, + kPrimTypeBuiltIn, // Built-in primitive operator + kPrimTypePyInferShape, // Primitive operator defined by custom + kPrimTypePyInferTensor, // Primitive operator defined by custom + kPrimTypeUserCustom +}; + +class Primitive : public Named { + public: + explicit Primitive(const std::string &name, const bool is_base = true, const PrimType prim_type = kPrimTypeBuiltIn) + : Named(name), is_base_(is_base), has_signature_(false), prim_type_(prim_type) {} + + Primitive(const Primitive &prim) + : Named(prim), + attrs_(prim.attrs_), + instance_name_(prim.instance_name_), + is_base_(prim.is_base_), + has_signature_(prim.has_signature_), + prim_type_(prim.prim_type_) {} + + MS_DECLARE_PARENT(Primitive, Named); + + abstract::AbstractBasePtr ToPrimAbstract(const AnfNodePtr &anf_node); + std::string ToString() const override { return name(); } + Primitive &AddAttr(const std::string &name, const ValuePtr &attr) { + attrs_[name] = attr; + return *this; + } + + Primitive &SetAttrs(const std::unordered_map &attrs) { + for (auto &attr : attrs) { + attrs_[attr.first] = attr.second; + } + return *this; + } + + void set_attr(const std::string &attrName, const ValuePtr &attr) { attrs_[attrName] = attr; } + void EraseAttr(const std::string &attrName) { (void)attrs_.erase(attrName); } + + ValuePtr GetAttr(const std::string &attrName) const { + auto iter = attrs_.find(attrName); + return iter == attrs_.cend() ? nullptr : iter->second; + } + + const std::unordered_map &attrs() const { return attrs_; } + + // if Primitive has any attribute, for Primitives like scalar_add, return, etc, don't have any attribute. + bool HasAttr() const { return !attrs_.empty(); } + bool HasAttr(const std::string &attrName) const { + auto iter = attrs_.find(attrName); + return !(iter == attrs_.cend()); + } + void set_prim_type(const PrimType t) { prim_type_ = t; } + void set_instance_name(const std::string s) { instance_name_ = s; } + bool HasPyEvaluator() const { return prim_type_ == kPrimTypePyInferShape || prim_type_ == kPrimTypeUserCustom; } + bool HasPyInferTensor() const { return prim_type_ == kPrimTypePyInferTensor; } + bool IsCustomPrim() const { return prim_type_ == kPrimTypeUserCustom; } + + PrimType prim_type() const { return prim_type_; } + std::string instance_name() const { return instance_name_; } + std::string GetAttrsText() const; + bool operator==(const Value &other) const override; + bool operator==(const Primitive &other) const; + ~Primitive() override = default; + + void set_has_signature(bool has_signature) { has_signature_ = has_signature; } + bool has_signature() const { return has_signature_; } + bool is_base() const { return is_base_; } + + protected: + std::unordered_map attrs_; + + private: + std::string instance_name_; + bool is_base_; + bool has_signature_; + PrimType prim_type_; +}; + +inline std::ostream &operator<<(std::ostream &os, const PrimitivePtr &p) { + os << *p; + return os; +} + +struct PrimitiveEqual { + bool operator()(PrimitivePtr const &t1, PrimitivePtr const &t2) const { + MS_EXCEPTION_IF_NULL(t1); + MS_EXCEPTION_IF_NULL(t2); + return t1->name() == t2->name(); + } +}; + +struct PrimitiveHasher { + std::size_t operator()(PrimitivePtr const &prim) const { return prim->Hash(); } +}; +} // namespace mindspore +#endif // MINDSPORE_CCSRC_IR_PRIMITIVE_BASE_H_ diff --git a/mindspore/ccsrc/ir/primitive_base_extends.cc b/mindspore/ccsrc/ir/primitive_base_extends.cc new file mode 100644 index 0000000000..64bdafa4d1 --- /dev/null +++ b/mindspore/ccsrc/ir/primitive_base_extends.cc @@ -0,0 +1,25 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "ir/primitive_base.h" +#include "pipeline/static_analysis/abstract_function.h" + +namespace mindspore { +abstract::AbstractBasePtr Primitive::ToPrimAbstract(const AnfNodePtr &anf_node) { + auto prim_func = std::make_shared(shared_from_base(), anf_node); + return prim_func; +} +} // namespace mindspore diff --git a/mindspore/ccsrc/operator/composite/do_signature.cc b/mindspore/ccsrc/operator/composite/do_signature.cc index 1098ed1520..1fe26d0023 100644 --- a/mindspore/ccsrc/operator/composite/do_signature.cc +++ b/mindspore/ccsrc/operator/composite/do_signature.cc @@ -36,8 +36,8 @@ using PatternListType = std::initializer_list; const std::vector &GetSignature(const ValuePtr &function) { static const auto empty = std::vector(); - if (function->isa()) { - return function->cast()->signatures(); + if (function->isa() && function->cast()->has_signature()) { + return function->cast()->signatures(); } else if (function->isa()) { return function->cast()->signatures(); } diff --git a/mindspore/ccsrc/optimizer/ad/kprim.cc b/mindspore/ccsrc/optimizer/ad/kprim.cc index 6fbb9d1ae8..6b68d15649 100644 --- a/mindspore/ccsrc/optimizer/ad/kprim.cc +++ b/mindspore/ccsrc/optimizer/ad/kprim.cc @@ -20,6 +20,7 @@ #include #include #include "ir/anf.h" +#include "ir/primitive.h" #include "ir/meta_func_graph.h" #include "ir/func_graph_cloner.h" #include "ir/manager.h" @@ -30,6 +31,7 @@ #include "operator/ops.h" #include "operator/composite/composite.h" #include "utils/symbolic.h" +#include "utils/primitive_utils.h" #include "debug/info.h" #include "debug/trace.h" @@ -49,7 +51,7 @@ FuncGraphPtr KPrim::GetBprop(const PrimitivePtr &prim) { auto scope = std::make_shared(gradients_scope + ScopeManager::GetInstance().GetCurrentScope()->name() + grad_op_child_scope_prefix + prim->name()); ScopeGuard scope_guard(scope); - py::function fn = prim->GetBpropFunction(); + py::function fn = prim->is_base() ? GetBpropFunction(prim->name()) : prim->cast()->GetBpropFunction(); if (fn == nullptr || py::isinstance(fn)) { MS_LOG(DEBUG) << "Fail to find bprop function for " << prim->name() << "."; return nullptr; diff --git a/mindspore/ccsrc/utils/primitive_utils.cc b/mindspore/ccsrc/utils/primitive_utils.cc new file mode 100644 index 0000000000..cfbfdebac7 --- /dev/null +++ b/mindspore/ccsrc/utils/primitive_utils.cc @@ -0,0 +1,49 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "utils/primitive_utils.h" +#include "pipeline/parse/python_adapter.h" +#include "utils/log_adapter.h" +#include "common/utils.h" + +namespace mindspore { +py::function GetBpropFunctionByObj(py::object obj) { + static const std::string get_bprop_fn = "get_bprop_fn"; + static const std::string ad_module = "mindspore.ops._grad"; + py::function fn = parse::python_adapter::GetPyFn(ad_module, get_bprop_fn)(obj); + return fn; +} + +py::function GetBpropFunction(std::string name) { + auto fn = GetBpropFunctionByObj(py::str(name)); + if (fn.is_none()) { + MS_LOG(WARNING) << "Can't find bprop function for " << name; + } + return fn; +} + +py::function GetComputeFunction(std::string name) { + static const std::string module = "mindspore._extends.builtin_operations"; + py::module mod = py::module::import(common::SafeCStr(module)); + if (!py::hasattr(mod, common::SafeCStr(name))) { + PyErr_SetString(PyExc_NotImplementedError, common::SafeCStr(name)); + // If raise AttributeError, user can't understand. This case need raise NotImplementedError. + throw py::error_already_set(); + } + py::object fn = mod.attr(common::SafeCStr(name)); + return fn; +} +} // namespace mindspore diff --git a/mindspore/ccsrc/utils/primitive_utils.h b/mindspore/ccsrc/utils/primitive_utils.h new file mode 100644 index 0000000000..b7e2515aea --- /dev/null +++ b/mindspore/ccsrc/utils/primitive_utils.h @@ -0,0 +1,33 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef MINDSPORE_CCSRC_UTILS_PRIMITIVE_UTILS_H_ +#define MINDSPORE_CCSRC_UTILS_PRIMITIVE_UTILS_H_ + +#include +#include "pybind11/pybind11.h" + +namespace py = pybind11; + +namespace mindspore { +py::function GetBpropFunctionByObj(py::object obj); + +py::function GetBpropFunction(std::string name); + +py::function GetComputeFunction(std::string name); +} // namespace mindspore + +#endif // MINDSPORE_CCSRC_UTILS_PRIMITIVE_UTILS_H_ diff --git a/mindspore/ccsrc/vm/vmimpl.cc b/mindspore/ccsrc/vm/vmimpl.cc index d83bb8f190..c17604ff8d 100644 --- a/mindspore/ccsrc/vm/vmimpl.cc +++ b/mindspore/ccsrc/vm/vmimpl.cc @@ -31,6 +31,7 @@ #include "ir/manager.h" #include "ir/func_graph_cloner.h" #include "utils/convert_utils.h" +#include "utils/primitive_utils.h" #include "debug/draw.h" namespace mindspore { @@ -443,7 +444,7 @@ BaseRef RunOperation(const PrimitivePtr &prim, const VectorRef &args) { PrimitivePyPtr operation = dyn_cast(prim); MS_LOG(DEBUG) << "operation start " << prim->name(); - auto func = operation != nullptr ? operation->GetComputeFunction() : prim->GetComputeFunction(); + auto func = operation != nullptr ? operation->GetComputeFunction() : GetComputeFunction(prim->name()); if (py::isinstance(func)) { MS_LOG(EXCEPTION) << prim->name() << " 's compute function is not implemented"; } diff --git a/tests/ut/cpp/operator/ops_test.cc b/tests/ut/cpp/operator/ops_test.cc index bf70fd7b70..99eee16178 100644 --- a/tests/ut/cpp/operator/ops_test.cc +++ b/tests/ut/cpp/operator/ops_test.cc @@ -390,7 +390,7 @@ TEST_F(TestOps, Conv2dAttrTest) { } TEST_F(TestOps, CustomOpAttrTest) { - Primitive prim("CustomOp", kPrimTypePyInferShape); + Primitive prim("CustomOp", true, kPrimTypePyInferShape); prim.SetAttrs({ {"attr1", MakeValue(3)}, {"attr2", MakeValue(1)},