diff --git a/mindspore/ccsrc/CMakeLists.txt b/mindspore/ccsrc/CMakeLists.txt index 95ccf58cb8..5c17fbe036 100644 --- a/mindspore/ccsrc/CMakeLists.txt +++ b/mindspore/ccsrc/CMakeLists.txt @@ -351,7 +351,7 @@ if(CMAKE_SYSTEM_NAME MATCHES "Windows") elseif(CMAKE_SYSTEM_NAME MATCHES "Darwin") target_link_libraries(mindspore mindspore::pybind11_module) target_link_libraries(mindspore mindspore_gvar) - target_link_libraries(_c_expression PRIVATE -Wl,-force_load mindspore -Wl,-noall_load) + target_link_libraries(_c_expression PRIVATE -Wl,-force_load mindspore mindspore_core -Wl,-noall_load) else() if(ENABLE_CPU AND (ENABLE_D OR ENABLE_GPU)) target_link_libraries(mindspore proto_input mindspore::protobuf @@ -361,7 +361,8 @@ else() target_link_libraries(mindspore ibverbs rdmacm) endif() endif() - target_link_libraries(_c_expression PRIVATE -Wl,--whole-archive mindspore proto_input -Wl,--no-whole-archive) + target_link_libraries(_c_expression PRIVATE -Wl,--whole-archive mindspore mindspore_core + proto_input -Wl,--no-whole-archive) target_link_libraries(_c_expression PRIVATE mindspore::pybind11_module) target_link_libraries(_c_expression PRIVATE mindspore_gvar) if(ENABLE_D) diff --git a/mindspore/ccsrc/backend/optimizer/pass/const_input_to_attr_registry.cc b/mindspore/ccsrc/backend/optimizer/common/const_input_to_attr_registry.cc similarity index 98% rename from mindspore/ccsrc/backend/optimizer/pass/const_input_to_attr_registry.cc rename to mindspore/ccsrc/backend/optimizer/common/const_input_to_attr_registry.cc index d4effc3c50..65625600e1 100644 --- a/mindspore/ccsrc/backend/optimizer/pass/const_input_to_attr_registry.cc +++ b/mindspore/ccsrc/backend/optimizer/common/const_input_to_attr_registry.cc @@ -13,7 +13,7 @@ * See the License for the specific language governing permissions and * limitations under the License. */ -#include "backend/optimizer/pass/const_input_to_attr_registry.h" +#include "backend/optimizer/common/const_input_to_attr_registry.h" #include diff --git a/mindspore/ccsrc/backend/optimizer/pass/const_input_to_attr_registry.h b/mindspore/ccsrc/backend/optimizer/common/const_input_to_attr_registry.h similarity index 92% rename from mindspore/ccsrc/backend/optimizer/pass/const_input_to_attr_registry.h rename to mindspore/ccsrc/backend/optimizer/common/const_input_to_attr_registry.h index 5f6b659372..5bcc23dfd2 100644 --- a/mindspore/ccsrc/backend/optimizer/pass/const_input_to_attr_registry.h +++ b/mindspore/ccsrc/backend/optimizer/common/const_input_to_attr_registry.h @@ -13,8 +13,8 @@ * See the License for the specific language governing permissions and * limitations under the License. 
*/ -#ifndef MINDSPORE_CCSRC_BACKEND_OPTIMIZER_PASS_CONST_INPUT_TO_ATTR_REGISTRY_H_ -#define MINDSPORE_CCSRC_BACKEND_OPTIMIZER_PASS_CONST_INPUT_TO_ATTR_REGISTRY_H_ +#ifndef MINDSPORE_CCSRC_BACKEND_OPTIMIZER_COMMON_CONST_INPUT_TO_ATTR_REGISTRY_H_ +#define MINDSPORE_CCSRC_BACKEND_OPTIMIZER_COMMON_CONST_INPUT_TO_ATTR_REGISTRY_H_ #include #include #include @@ -75,4 +75,4 @@ struct ConstInputToAttrInfoReceiver { ::mindspore::opt::ConstInputToAttrInfoRegister(op_name) } // namespace mindspore -#endif // MINDSPORE_CCSRC_BACKEND_OPTIMIZER_PASS_CONST_INPUT_TO_ATTR_REGISTRY_H_ +#endif // MINDSPORE_CCSRC_BACKEND_OPTIMIZER_COMMON_CONST_INPUT_TO_ATTR_REGISTRY_H_ diff --git a/mindspore/ccsrc/backend/optimizer/common/helper.cc b/mindspore/ccsrc/backend/optimizer/common/helper.cc index c534c6fb41..87183e7ff0 100644 --- a/mindspore/ccsrc/backend/optimizer/common/helper.cc +++ b/mindspore/ccsrc/backend/optimizer/common/helper.cc @@ -31,6 +31,8 @@ #include "utils/ms_utils.h" #include "runtime/device/kernel_info.h" #include "utils/ms_context.h" +#include "backend/optimizer/common/const_input_to_attr_registry.h" +#include "abstract/primitive_infer_map.h" namespace mindspore { namespace opt { @@ -700,6 +702,92 @@ AnfNodePtr HandleSexpVector(const BaseRef &sexp, const BaseRef &graph, Primitive } return CreateCNodeWithGraph(input_nodes, graph); } + +// rectify the abstract if the input has been converted to an attr +AbstractBasePtrList RectifyAbstractFromRegAttr(const PrimitivePtr &primitive, + const AbstractBasePtrList &input_abstract) { + MS_EXCEPTION_IF_NULL(primitive); + opt::ConstInputToAttrInfoRegister reg; + if (!opt::ConstInputToAttrInfoRegistry::Instance().GetRegisterByOpName(primitive->name(), &reg)) { + return input_abstract; + } + if (AnfAlgo::HasDynamicShapeFlag(primitive) || + DynamicShapeConstInputToAttr.find(primitive->name()) != DynamicShapeConstInputToAttr.end()) { + return input_abstract; + } + auto convert_input_list = reg.GetConstInputAttrInfo(); + auto input_names = primitive->GetAttr(kAttrInputNames); + if (input_names == nullptr) { + return input_abstract; + } + auto input_names_vec = GetValue<std::vector<std::string>>(input_names); + AbstractBasePtrList rectify_abs_list; + size_t ori_index = 0; + rectify_abs_list.resize(input_names_vec.size()); + for (size_t index = 0; index < rectify_abs_list.size(); ++index) { + // if the index is found in convert_input_list, the input has been converted to an attr + if (convert_input_list.find(index) != convert_input_list.end()) { + AbstractBasePtr rectify_abs = nullptr; + auto input_name = input_names_vec[index]; + auto attr = primitive->GetAttr(input_name); + if (attr != nullptr) { + rectify_abs = attr->ToAbstract(); + } else { + MS_LOG(DEBUG) << "The input of node " << primitive->name() << " at index " << index + << " with input name " << input_name << " has not been converted to an attr"; + rectify_abs = input_abstract[ori_index++]; + } + rectify_abs_list[index] = rectify_abs; + continue; + } + if (ori_index >= input_abstract.size()) { + MS_LOG(EXCEPTION) << "Index is out of range, input abstract size " << input_abstract.size() + << ", got index " << ori_index; + } + rectify_abs_list[index] = input_abstract[ori_index++]; + } + return rectify_abs_list; +} + +AbstractBasePtrList RectifyAbstractFromDynamicInput(const PrimitivePtr &primitive, + const AbstractBasePtrList &input_abstract) { + auto dynamic_inputs_list = primitive->GetAttr(kAttrDynInputSizes); + if (dynamic_inputs_list == nullptr) { + return input_abstract; + } + AbstractBasePtrList rectifyed_abs_list; + const int
kNotDynamicFlag = -1; + auto dynamic_inputs_index = GetValue<std::vector<int64_t>>(dynamic_inputs_list); + size_t input_index = 0; + for (auto item : dynamic_inputs_index) { + if (item == kNotDynamicFlag) { + if (input_index >= input_abstract.size()) { + MS_LOG(EXCEPTION) << "Index " << input_index << " is out of range of the input abstract size " << input_abstract.size(); + } + rectifyed_abs_list.emplace_back(input_abstract[input_index++]); + } else { + if (item < 0) { + MS_LOG(EXCEPTION) << "Invalid dynamic input size, the value should be -1 or a positive number, but got " + << item; + } + AbstractBasePtrList dynamic_inputs_abs; + for (auto index = item; index > 0; --index) { + if (input_index >= input_abstract.size()) { + MS_LOG(EXCEPTION) << "Index " << input_index << " is out of range of the input abstract size " + << input_abstract.size(); + } + dynamic_inputs_abs.emplace_back(input_abstract[input_index++]); + } + rectifyed_abs_list.emplace_back(std::make_shared<abstract::AbstractTuple>(dynamic_inputs_abs)); + } + } + return rectifyed_abs_list; +} + +AbstractBasePtrList RectifyAbstract(const PrimitivePtr &primitive, const AbstractBasePtrList &input_abstract) { + auto rectify_abs_list = RectifyAbstractFromRegAttr(primitive, input_abstract); + return RectifyAbstractFromDynamicInput(primitive, rectify_abs_list); +} } // namespace AnfNodePtr SexpToNode(const BaseRef &sexp, const BaseRef &graph, PrimitiveVarMap *primitive_vars, bool multigraph) { @@ -835,5 +923,24 @@ void TransferDepend(const CNodePtr &old_node, const FuncGraphPtr &graph, const C } } } +AbstractBasePtr CppInferShape(const PrimitivePtr &prim, const AbstractBasePtrList &args_spec_list) { + MS_EXCEPTION_IF_NULL(prim); + auto &prim_eval_implement_map = abstract::GetPrimitiveToEvalImplMap(); + auto ret = prim_eval_implement_map.find(prim); + if (ret != prim_eval_implement_map.end()) { + // find the infer function in the frontend infer map and restore the input abstract from dynamic inputs and registered attrs + auto infer_spec_list = RectifyAbstract(prim, args_spec_list); + return ret->second.impl_(nullptr, prim, infer_spec_list); + } else { + // if the infer function is not found in the frontend infer map, look it up in the backend infer map instead + auto &prim_backend_eval_impl_map = abstract::GetPrimitiveToBackendEvalImplMap(); + auto ret_backend = prim_backend_eval_impl_map.find(prim); + if (ret_backend != prim_backend_eval_impl_map.end()) { + return ret_backend->second.impl_(nullptr, prim, args_spec_list); + } + } + MS_LOG(EXCEPTION) << "Get infer shape function failed, primitive name:" << prim->name() + << " primitive type:" << prim->type_name(); +} } // namespace opt } // namespace mindspore diff --git a/mindspore/ccsrc/backend/optimizer/common/helper.h b/mindspore/ccsrc/backend/optimizer/common/helper.h index a20a4234b5..d55e9e80ce 100644 --- a/mindspore/ccsrc/backend/optimizer/common/helper.h +++ b/mindspore/ccsrc/backend/optimizer/common/helper.h @@ -212,6 +212,8 @@ ValueNodePtr MakeValueNode(const ValueNodePtr &value_node); // Transfer depend or control_depend to the new node void TransferDepend(const CNodePtr &old_node, const FuncGraphPtr &graph, const CNodePtr &new_node); + +AbstractBasePtr CppInferShape(const PrimitivePtr &prim, const AbstractBasePtrList &args_spec_list); } // namespace opt } // namespace mindspore #endif // MINDSPORE_CCSRC_BACKEND_OPTIMIZER_COMMON_HELPER_H_ diff --git a/mindspore/ccsrc/backend/optimizer/graph_kernel/graph_kernel_helper.cc b/mindspore/ccsrc/backend/optimizer/graph_kernel/graph_kernel_helper.cc index 0cd9dd294a..df19410f6c 100644 ---
a/mindspore/ccsrc/backend/optimizer/graph_kernel/graph_kernel_helper.cc +++ b/mindspore/ccsrc/backend/optimizer/graph_kernel/graph_kernel_helper.cc @@ -27,7 +27,7 @@ #include "backend/kernel_compiler/akg/akg_kernel_json_decoder.h" #include "backend/kernel_compiler/kernel.h" #include "backend/session/anf_runtime_algorithm.h" -#include "backend/optimizer/pass/const_input_to_attr_registry.h" +#include "backend/optimizer/common/const_input_to_attr_registry.h" #include "ir/func_graph_cloner.h" #include "ir/func_graph.h" #include "pipeline/jit/parse/python_adapter.h" diff --git a/mindspore/ccsrc/backend/optimizer/pass/convert_const_input_to_attr.cc b/mindspore/ccsrc/backend/optimizer/pass/convert_const_input_to_attr.cc index dfcd8523c3..b12a3e8a4f 100644 --- a/mindspore/ccsrc/backend/optimizer/pass/convert_const_input_to_attr.cc +++ b/mindspore/ccsrc/backend/optimizer/pass/convert_const_input_to_attr.cc @@ -19,7 +19,7 @@ #include #include -#include "backend/optimizer/pass/const_input_to_attr_registry.h" +#include "backend/optimizer/common/const_input_to_attr_registry.h" #include "backend/optimizer/common/helper.h" #include "utils/utils.h" #include "utils/ms_context.h" diff --git a/mindspore/ccsrc/backend/session/anf_runtime_algorithm.cc b/mindspore/ccsrc/backend/session/anf_runtime_algorithm.cc index 41e92b6de4..2e4ba47779 100644 --- a/mindspore/ccsrc/backend/session/anf_runtime_algorithm.cc +++ b/mindspore/ccsrc/backend/session/anf_runtime_algorithm.cc @@ -1534,6 +1534,18 @@ bool AnfRuntimeAlgorithm::GetBooleanAttr(const AnfNodePtr &node, const std::stri return AnfAlgo::GetNodeAttr(node, attr); } +bool AnfRuntimeAlgorithm::HasDynamicShapeFlag(const PrimitivePtr &prim) { + auto get_bool_attr = [](const PrimitivePtr &primitive, const std::string &attr_name) -> bool { + MS_EXCEPTION_IF_NULL(primitive); + if (!primitive->HasAttr(attr_name)) { + return false; + } + return GetValue(primitive->GetAttr(attr_name)); + }; + return get_bool_attr(prim, kAttrInputIsDynamicShape) || get_bool_attr(prim, kAttrOutputIsDynamicShape) || + get_bool_attr(prim, kAttrIsDynamicShape); +} + bool AnfRuntimeAlgorithm::IsDynamicShape(const AnfNodePtr &node) { return GetBooleanAttr(node, kAttrInputIsDynamicShape) || GetBooleanAttr(node, kAttrOutputIsDynamicShape) || GetBooleanAttr(node, kAttrIsDynamicShape); @@ -1805,7 +1817,7 @@ void AnfRuntimeAlgorithm::InferShape(const CNodePtr &node) { args_spec_list.emplace_back(real_input->abstract()); } } - auto eval_result = abstract::CppInferShape(primitive, args_spec_list); + auto eval_result = opt::CppInferShape(primitive, args_spec_list); node->set_abstract(eval_result); } } // namespace session diff --git a/mindspore/ccsrc/backend/session/anf_runtime_algorithm.h b/mindspore/ccsrc/backend/session/anf_runtime_algorithm.h index 3ea2c92f3c..8399e3c90c 100644 --- a/mindspore/ccsrc/backend/session/anf_runtime_algorithm.h +++ b/mindspore/ccsrc/backend/session/anf_runtime_algorithm.h @@ -230,6 +230,7 @@ class AnfRuntimeAlgorithm { // get fix output precision from prev node, input_idx is the input index of current node related to prev node. 
static TypeId GetPrevNodeOutputPrecision(const AnfNodePtr &node, size_t input_idx); static bool IsDynamicShape(const AnfNodePtr &node); + static bool HasDynamicShapeFlag(const PrimitivePtr &prim); static bool IsCondControlKernel(const CNodePtr &node); static bool IsIndependentNode(const CNodePtr &node); static bool GetBooleanAttr(const AnfNodePtr &node, const std::string &attr); diff --git a/mindspore/ccsrc/pipeline/jit/static_analysis/prim.cc b/mindspore/ccsrc/pipeline/jit/static_analysis/prim.cc index 916f441521..95317b6dba 100644 --- a/mindspore/ccsrc/pipeline/jit/static_analysis/prim.cc +++ b/mindspore/ccsrc/pipeline/jit/static_analysis/prim.cc @@ -1311,15 +1311,6 @@ bool IsInWhiteList(const PrimitivePtr &primitive) { return false; } -StandardPrimitiveEvalImpl GetPrimitiveInferImpl(const PrimitivePtr &primitive) { - MS_EXCEPTION_IF_NULL(primitive); - auto iter = GetPrimitiveToEvalImplMap().find(primitive); - if (iter == GetPrimitiveToEvalImplMap().end()) { - return nullptr; - } - return iter->second.impl_; -} - PrimEvaluatorMap &GetPrimEvaluatorConstructors() { PrimEvaluatorMap &constructor = PrimEvaluatorConstructors; if (!constructor.empty()) { diff --git a/mindspore/ccsrc/pipeline/jit/static_analysis/prim.h b/mindspore/ccsrc/pipeline/jit/static_analysis/prim.h index 1dd8468372..868b002f75 100644 --- a/mindspore/ccsrc/pipeline/jit/static_analysis/prim.h +++ b/mindspore/ccsrc/pipeline/jit/static_analysis/prim.h @@ -112,7 +112,6 @@ class MixedPrecisionCastEvaluator : public Evaluator { }; bool IsInWhiteList(const PrimitivePtr &primitive); -StandardPrimitiveEvalImpl GetPrimitiveInferImpl(const PrimitivePtr &primitive); using ValuePtrList = std::vector; using PrimitiveImpl = ValuePtr (*)(const ValuePtrList &); diff --git a/mindspore/ccsrc/pipeline/jit/static_analysis/static_analysis.cc b/mindspore/ccsrc/pipeline/jit/static_analysis/static_analysis.cc index 3e8d2bf414..5b95c02b59 100644 --- a/mindspore/ccsrc/pipeline/jit/static_analysis/static_analysis.cc +++ b/mindspore/ccsrc/pipeline/jit/static_analysis/static_analysis.cc @@ -357,6 +357,13 @@ EvaluatorPtr GetPrimEvaluator(const PrimitivePtr &prim, const AnalysisEnginePtr return std::make_shared(prim); } + // find the infer function in the primitive infer map; if found, return a standard evaluator + StandardPrimitiveEvalImpl eval_impl = GetPrimitiveInferImpl(prim); + if (eval_impl != nullptr) { + return std::make_shared<StandardPrimEvaluator>(prim, eval_impl); + } + + // if no infer function is found in the map, use the python infer function and return a python evaluator EvaluatorPtr evaluator = nullptr; if (prim->HasPyEvaluator()) { auto prim_py = dyn_cast(prim); @@ -376,17 +383,7 @@ EvaluatorPtr GetPrimEvaluator(const PrimitivePtr &prim, const AnalysisEnginePtr MS_LOG(EXCEPTION) << "The primitive with python evaluator should be a python primitive."; } - if (prim->isa() || prim->HasAttr()) { - if (engine == nullptr) { - (void)GetPrimEvaluatorConstructors(); - } - // If a primitive may have attr, try to create a new evaluator. - StandardPrimitiveEvalImpl eval_impl = GetPrimitiveInferImpl(prim); - if (eval_impl != nullptr) { - return std::make_shared<StandardPrimEvaluator>(prim, eval_impl); - } - } - + // otherwise, return a default evaluator if (engine == nullptr) { // If engine is nullptr, get constructor from default.
const PrimEvaluatorMap &prim_evaluator_map = GetPrimEvaluatorConstructors(); @@ -778,16 +775,5 @@ EvalResultPtr EvalOnePrim(const PrimitivePtr &primitive, const AbstractBasePtrLi auto eval_result = trivial_evaluator->EvalPrim(nullptr, arg_specs); return eval_result; } - -AbstractBasePtr CppInferShape(const PrimitivePtr &prim, const AbstractBasePtrList &args_spec_list) { - MS_EXCEPTION_IF_NULL(prim); - auto &prim_eval_implement_map = GetPrimitiveToEvalImplMap(); - auto ret = prim_eval_implement_map.find(prim); - if (ret == prim_eval_implement_map.end()) { - MS_LOG(EXCEPTION) << "Get infer shape function failed, primitive name:" << prim->name() - << " primitive type:" << prim->type_name(); - } - return ret->second.impl_(nullptr, prim, args_spec_list); -} } // namespace abstract } // namespace mindspore diff --git a/mindspore/ccsrc/pipeline/jit/static_analysis/static_analysis.h b/mindspore/ccsrc/pipeline/jit/static_analysis/static_analysis.h index 5b6b462ef7..ebfbfdb768 100644 --- a/mindspore/ccsrc/pipeline/jit/static_analysis/static_analysis.h +++ b/mindspore/ccsrc/pipeline/jit/static_analysis/static_analysis.h @@ -331,8 +331,6 @@ AbstractBasePtr FromValue(const T &value, bool broaden = false) { } EvalResultPtr EvalOnePrim(const PrimitivePtr &p, const AbstractBasePtrList &arg_specs); - -AbstractBasePtr CppInferShape(const PrimitivePtr &prim, const AbstractBasePtrList &args_spec_list); } // namespace abstract } // namespace mindspore diff --git a/mindspore/ccsrc/pipeline/pynative/pynative_execute.cc b/mindspore/ccsrc/pipeline/pynative/pynative_execute.cc index c34c1b0713..333173c95e 100644 --- a/mindspore/ccsrc/pipeline/pynative/pynative_execute.cc +++ b/mindspore/ccsrc/pipeline/pynative/pynative_execute.cc @@ -44,7 +44,7 @@ #include "pipeline/jit/static_analysis/prim.h" #include "pipeline/jit/static_analysis/auto_monad.h" #include "backend/session/session_factory.h" -#include "backend/optimizer/pass/const_input_to_attr_registry.h" +#include "backend/optimizer/common/const_input_to_attr_registry.h" #include "backend/optimizer/common/helper.h" #include "pipeline/jit/action.h" @@ -807,21 +807,13 @@ void PynativeExecutor::GetOpOutputAbstract(const OpExecInfoPtr &op_exec_info, } } // get output dynamic shape info - auto py_abstract = op_exec_info->abstract; - MS_EXCEPTION_IF_NULL(py_abstract); - auto py_shape = py_abstract->BuildShape(); - MS_EXCEPTION_IF_NULL(py_shape); - auto py_shape_info = py_shape->ToString(); - if (py_shape_info.find("-1") != string::npos) { - auto c_abstract = abstract::CppInferShape(prim, args_spec_list); - MS_EXCEPTION_IF_NULL(c_abstract); - auto c_shape = c_abstract->BuildShape(); - MS_EXCEPTION_IF_NULL(c_shape); - auto c_shape_info = c_shape->ToString(); - MS_LOG(DEBUG) << "Final infer output shape: " << c_shape_info; - if (c_shape_info.find("-1") != string::npos) { - op_exec_info->is_dynamic_shape = true; - } + auto abstract = op_exec_info->abstract; + MS_EXCEPTION_IF_NULL(abstract); + auto shape = abstract->BuildShape(); + MS_EXCEPTION_IF_NULL(shape); + auto shape_info = shape->ToString(); + if (shape_info.find("-1") != string::npos) { + op_exec_info->is_dynamic_shape = true; } } diff --git a/mindspore/ccsrc/runtime/device/executor/dynamic_kernel.cc b/mindspore/ccsrc/runtime/device/executor/dynamic_kernel.cc index 79c1de652a..75c92c82ed 100644 --- a/mindspore/ccsrc/runtime/device/executor/dynamic_kernel.cc +++ b/mindspore/ccsrc/runtime/device/executor/dynamic_kernel.cc @@ -123,7 +123,7 @@ void DynamicKernel::InferShape() { } } - auto eval_result = 
abstract::CppInferShape(primitive, args_spec_list); + auto eval_result = opt::CppInferShape(primitive, args_spec_list); cnode_ptr_->set_abstract(eval_result); } diff --git a/mindspore/core/abstract/prim_arrays.cc b/mindspore/core/abstract/prim_arrays.cc index c7f419b09d..d09919eb7f 100644 --- a/mindspore/core/abstract/prim_arrays.cc +++ b/mindspore/core/abstract/prim_arrays.cc @@ -1041,6 +1041,9 @@ AbstractBasePtr InferImplConcat(const AnalysisEnginePtr &, const PrimitivePtr &p AbstractBasePtr InferImplRange(const AnalysisEnginePtr &, const PrimitivePtr &primitive, const AbstractBasePtrList &args_spec_list) { const std::string &op_name = primitive->name(); + if (args_spec_list.size() == 1) { + return args_spec_list[0]->Broaden(); + } CheckArgsSize(op_name, args_spec_list, 3); AbstractTensorPtr range_start = CheckArg(op_name, args_spec_list, 0); AbstractTensorPtr range_end = CheckArg(op_name, args_spec_list, 1); diff --git a/mindspore/core/abstract/prim_nn.cc b/mindspore/core/abstract/prim_nn.cc index fb0d336ed7..c474884d46 100644 --- a/mindspore/core/abstract/prim_nn.cc +++ b/mindspore/core/abstract/prim_nn.cc @@ -292,24 +292,6 @@ AbstractBasePtr InferImplFusedBatchNormEx(const AnalysisEnginePtr &, const Primi return std::make_shared(rets); } -AbstractBasePtr InferImplBatchNormGrad(const AnalysisEnginePtr &, const PrimitivePtr &primitive, - const AbstractBasePtrList &args_spec_list) { - // Inputs: five tensors(y_backprop, x, scale, save_mean, save_inv_variance). - MS_EXCEPTION_IF_NULL(args_spec_list[1]); - MS_EXCEPTION_IF_NULL(args_spec_list[2]); - MS_EXCEPTION_IF_NULL(args_spec_list[3]); - - CheckArgsSize(primitive->name(), args_spec_list, 5); - auto dx = args_spec_list[1]->Broaden(); - auto dscale = args_spec_list[2]->Broaden(); - auto dbias = args_spec_list[3]->Broaden(); - auto reserve_1 = args_spec_list[4]->Broaden(); - auto reserve_2 = args_spec_list[5]->Broaden(); - - AbstractBasePtrList rets = {dx, dscale, dbias, reserve_1, reserve_2}; - return std::make_shared(rets); -} - AbstractBasePtr InferImplReluGrad(const AnalysisEnginePtr &, const PrimitivePtr &primitive, const AbstractBasePtrList &args_spec_list) { // Inputs: two tensors(y_backprop, x). @@ -468,20 +450,6 @@ AbstractBasePtr InferImplConv2D(const AnalysisEnginePtr &, const PrimitivePtr &p return std::make_shared(x_type, output_shape_ptr); } -AbstractBasePtr InferImplConv2DBackpropInput(const AnalysisEnginePtr &, const PrimitivePtr &primitive, - const AbstractBasePtrList &args_spec_list) { - // Inputs: three tensors(doutput, input, filters). - CheckRequiredArgsSize(primitive->name(), args_spec_list, 3); - return args_spec_list[1]->Broaden(); -} - -AbstractBasePtr InferImplConv2DBackpropFilter(const AnalysisEnginePtr &, const PrimitivePtr &primitive, - const AbstractBasePtrList &args_spec_list) { - // Inputs: three tensors(inputs, filter, doutput). 
- CheckArgsSize(primitive->name(), args_spec_list, 3); - return args_spec_list[2]->Broaden(); -} - AbstractBasePtr InferImplBiasAdd(const AnalysisEnginePtr &, const PrimitivePtr &primitive, const AbstractBasePtrList &args_spec_list) { const std::string op_name = primitive->name(); diff --git a/mindspore/core/abstract/primitive_infer_map.cc b/mindspore/core/abstract/primitive_infer_map.cc index 94f7ca10c4..e1cd7eb4d3 100644 --- a/mindspore/core/abstract/primitive_infer_map.cc +++ b/mindspore/core/abstract/primitive_infer_map.cc @@ -17,6 +17,11 @@ */ #include "abstract/primitive_infer_map.h" + +#include +#include +#include + #include "abstract/abstract_function.h" #include "abstract/infer_functions.h" @@ -59,40 +64,21 @@ PrimitiveEvalImplMap &GetPrimitiveToEvalImplMap() { {prim::kPrimNotInDict, {InferImplNotInDict, true}}, {prim::kPrimIsConsant, {InferImplIsConstant, true}}, // Maths - {prim::kPrimMaximumGrad, {InferImplMinOrMaxGrad, true}}, - {prim::kPrimMinimumGrad, {InferImplMinOrMaxGrad, true}}, - {prim::kPrimMul, {InferImplMul, true}}, - {prim::kPrimAdd, {InferImplAdd, true}}, {prim::kPrimSquare, {InferImplSquare, true}}, - {prim::kPrimSqrt, {InferImplSqrt, true}}, - {prim::kPrimSqrtGrad, {InferImplSqrtGrad, true}}, - {prim::kPrimSub, {InferImplSub, true}}, - {prim::kPrimEqual, {InferImplEqual, true}}, - {prim::kPrimReduceSum, {InferImplReduceFunc, true}}, - {prim::kPrimReduceMean, {InferImplReduceFunc, true}}, - {prim::kPrimReduceAll, {InferImplReduceFunc, true}}, - {prim::kPrimReduceAny, {InferImplReduceFunc, true}}, - {prim::kPrimReduceMax, {InferImplReduceFunc, true}}, - {prim::kPrimReduceMin, {InferImplReduceFunc, true}}, - {prim::kPrimMinimum, {InferImplMinimum, true}}, - {prim::kPrimDivNoNan, {InferImplDivNoNan, true}}, - {prim::kPrimLinSpace, {InferImplLinSpace, true}}, - {prim::kPrimAddN, {InferImplAddN, true}}, {prim::kPrimMatMul, {InferImplMatMul, true}}, {prim::kPrimBatchMatMul, {InferImplBatchMatMul, true}}, - {prim::kPrimLess, {InferImplLess, true}}, + {prim::kPrimMaximumGrad, {InferImplMinOrMaxGrad, true}}, + {prim::kPrimMinimumGrad, {InferImplMinOrMaxGrad, true}}, + {prim::kPrimSqrt, {InferImplSqrt, true}}, // Array + {prim::kPrimRange, {InferImplRange, true}}, {prim::kPrimScalarToArray, {InferImplScalarToArray, true}}, {prim::kPrimArrayToScalar, {InferImplArrayToScalar, true}}, {prim::kPrimBroadcastShape, {InferImplBroadCastShape, true}}, - {prim::kPrimStack, {InferImplStack, true}}, - {prim::kPrimPad, {InferImplPad, true}}, {prim::kPrimUnique, {InferImplUnique, true}}, {prim::kPrimUniqueGrad, {InferImplUniqueGrad, true}}, {prim::kPrimGather, {InferImplGatherV2, true}}, {prim::kPrimSparseGatherV2, {InferImplGatherV2, true}}, - {prim::kPrimEmbeddingLookup, {InferImplEmbeddingLookup, true}}, - {prim::kPrimUnsortedSegmentSum, {InferImplUnsortedSegmentSum, true}}, {prim::kPrimUnsortedSegmentMax, {InferImplUnsortedSegmentMax, true}}, {prim::kPrimUnsortedSegmentMin, {InferImplUnsortedSegmentMin, true}}, {prim::kPrimScatterAdd, {InferImplScatterAdd, true}}, @@ -104,18 +90,10 @@ PrimitiveEvalImplMap &GetPrimitiveToEvalImplMap() { {prim::kPrimUpdateCache, {InferImplUpdateCache, true}}, {prim::kPrimComputeAccidentalHits, {InferImplComputeAccidentalHits, true}}, {prim::kPrimPadAndShift, {InferImplPadAndShift, true}}, - {prim::kPrimDiv, {InferImplDiv, true}}, - {prim::kPrimRealDiv, {InferImplRealDiv, true}}, - {prim::kPrimShape, {InferImplShape, false}}, {prim::kPrimDynamicShape, {InferImplDynamicShape, true}}, - {prim::kPrimTranspose, {InferImplTranspose, true}}, - 
{prim::kPrimReshape, {InferImplReshape, true}}, {prim::kPrimMapUniform, {InferImplMapUniform, true}}, {prim::kPrimSplit, {InferImplSplit, true}}, {prim::kPrimSequenceMask, {InferImplSequenceMask, true}}, - {prim::kPrimConcat, {InferImplConcat, true}}, - {prim::kPrimRange, {InferImplRange, true}}, - {prim::kPrimArgMaxWithValue, {InferImplArgMaxWithValue, true}}, // Structure {prim::kPrimMakeTuple, {InferImplMakeTuple, true}}, {prim::kPrimMakeList, {InferImplMakeList, true}}, @@ -139,14 +117,10 @@ PrimitiveEvalImplMap &GetPrimitiveToEvalImplMap() { {prim::kPrimPooling, {InferImplPooling, true}}, {prim::kPrimPoolingGrad, {InferImplPoolingGrad, true}}, {prim::kPrimFusedBatchNorm, {InferImplFusedBatchNorm, true}}, - {prim::kPrimFusedSparseAdam, {InferImplFusedSparseAdam, true}}, {prim::kPrimFusedBatchNormGrad, {InferImplFusedBatchNormGrad, true}}, {prim::kPrimFusedBatchNormEx, {InferImplFusedBatchNormEx, true}}, - {prim::kPrimBatchNormGrad, {InferImplBatchNormGrad, true}}, {prim::kPrimReluGrad, {InferImplReluGrad, true}}, {prim::kPrimConv2D, {InferImplConv2D, true}}, - {prim::kPrimConv2DBackpropInput, {InferImplConv2DBackpropInput, true}}, - {prim::kPrimConv2DBackpropFilter, {InferImplConv2DBackpropFilter, true}}, {prim::kPrimBiasAdd, {InferImplBiasAdd, true}}, {prim::kPrimBiasAddGrad, {InferImplBiasAddGrad, true}}, {prim::kPrimRelu, {InferImplRelu, true}}, @@ -192,18 +166,60 @@ PrimitiveEvalImplMap &GetPrimitiveToEvalImplMap() { {prim::kPrimRowTensorGetDenseShape, {InferImplRowTensorGetDenseShape, true}}, {prim::kPrimRowTensorAdd, {InferImplRowTensorAdd, false}}, // Comm Ops - {prim::kPrimAllReduce, {InferImplAllReduce, true}}, - {prim::kPrimBroadcast, {InferImplBroadcast, true}}, - {prim::kPrimAllGather, {InferImplAllGather, true}}, {prim::kPrimAllSwap, {InferImplAllSwap, true}}, - {prim::kPrimReduceScatter, {InferImplReduceScatter, true}}, {prim::kPrimMemCpyAsync, {InferImplMemCpyAsync, true}}, + }; + return prim_eval_implement_map; +} + +PrimitiveEvalImplMap &GetPrimitiveToBackendEvalImplMap() { + static PrimitiveEvalImplMap prim_backend_eval_implement_map = { + {prim::kPrimMul, {InferImplMul, true}}, + {prim::kPrimAdd, {InferImplAdd, true}}, + {prim::kPrimSqrtGrad, {InferImplSqrtGrad, true}}, + {prim::kPrimSub, {InferImplSub, true}}, + {prim::kPrimEqual, {InferImplEqual, true}}, + {prim::kPrimReduceSum, {InferImplReduceFunc, true}}, + {prim::kPrimReduceMean, {InferImplReduceFunc, true}}, + {prim::kPrimReduceAll, {InferImplReduceFunc, true}}, + {prim::kPrimReduceAny, {InferImplReduceFunc, true}}, + {prim::kPrimReduceMax, {InferImplReduceFunc, true}}, + {prim::kPrimReduceMin, {InferImplReduceFunc, true}}, + {prim::kPrimReduceScatter, {InferImplReduceScatter, true}}, {prim::kPrimCast, {InferImplCast, true}}, {prim::kPrimExpandDims, {InferImplExpandDims, true}}, - {prim::kPrimSparseSoftmaxCrossEntropyWithLogits, {InferImplSparseSoftmaxCrossEntropyWithLogits, true}}, - {prim::kPrimDType, {InferImplDType, true}}, + {prim::kPrimAllReduce, {InferImplAllReduce, true}}, + {prim::kPrimBroadcast, {InferImplBroadcast, true}}, + {prim::kPrimAllGather, {InferImplAllGather, true}}, + {prim::kPrimMinimum, {InferImplMinimum, true}}, + {prim::kPrimDivNoNan, {InferImplDivNoNan, true}}, + {prim::kPrimLinSpace, {InferImplLinSpace, true}}, + {prim::kPrimAddN, {InferImplAddN, true}}, + + {prim::kPrimLess, {InferImplLess, true}}, + {prim::kPrimStack, {InferImplStack, true}}, + {prim::kPrimPad, {InferImplPad, true}}, + {prim::kPrimEmbeddingLookup, {InferImplEmbeddingLookup, true}}, + {prim::kPrimUnsortedSegmentSum, 
{InferImplUnsortedSegmentSum, true}}, + {prim::kPrimDiv, {InferImplDiv, true}}, + {prim::kPrimRealDiv, {InferImplRealDiv, true}}, + {prim::kPrimShape, {InferImplShape, false}}, + {prim::kPrimTranspose, {InferImplTranspose, true}}, + {prim::kPrimReshape, {InferImplReshape, true}}, + {prim::kPrimConcat, {InferImplConcat, true}}, + {prim::kPrimArgMaxWithValue, {InferImplArgMaxWithValue, true}}, + {prim::kPrimFusedSparseAdam, {InferImplFusedSparseAdam, true}}, }; - return prim_eval_implement_map; + return prim_backend_eval_implement_map; +} + +StandardPrimitiveEvalImpl GetPrimitiveInferImpl(const PrimitivePtr &primitive) { + MS_EXCEPTION_IF_NULL(primitive); + auto iter = GetPrimitiveToEvalImplMap().find(primitive); + if (iter == GetPrimitiveToEvalImplMap().end()) { + return nullptr; + } + return iter->second.impl_; } void RegisterStandardPrimitiveImpl(const PrimitivePtr &primitive, const StandardPrimitiveImplReg &impl_reg) { diff --git a/mindspore/core/abstract/primitive_infer_map.h b/mindspore/core/abstract/primitive_infer_map.h index 77329ce901..f175c2e30c 100644 --- a/mindspore/core/abstract/primitive_infer_map.h +++ b/mindspore/core/abstract/primitive_infer_map.h @@ -18,6 +18,7 @@ #ifndef MINDSPORE_CORE_ABSTRACT_PRIMITIVE_INFER_MAP_H_ #define MINDSPORE_CORE_ABSTRACT_PRIMITIVE_INFER_MAP_H_ #include +#include #include "ir/primitive.h" #include "base/core_ops.h" #include "abstract/abstract_value.h" @@ -37,6 +38,10 @@ using PrimitiveEvalImplMap = PrimitiveEvalImplMap &GetPrimitiveToEvalImplMap(); +PrimitiveEvalImplMap &GetPrimitiveToBackendEvalImplMap(); + +StandardPrimitiveEvalImpl GetPrimitiveInferImpl(const PrimitivePtr &primitive); + std::vector GetDependsFormMap(const CNodePtr &cnode); void RegisterStandardPrimitiveImpl(const PrimitivePtr &primitive, const StandardPrimitiveImplReg &impl_reg); diff --git a/mindspore/core/ops/fusion/max_pool_fusion.cc b/mindspore/core/ops/fusion/max_pool_fusion.cc index c8e687342c..b8545e021e 100644 --- a/mindspore/core/ops/fusion/max_pool_fusion.cc +++ b/mindspore/core/ops/fusion/max_pool_fusion.cc @@ -104,6 +104,5 @@ AbstractBasePtr MaxPoolFusionInfer(const abstract::AnalysisEnginePtr &, const Pr return std::make_shared(InferType(primitive, input_args), InferShape(primitive, input_args)->shape()); } -REGISTER_PRIMITIVE_EVAL_IMPL(MaxPoolFusion, prim::kPrimMaxPool, MaxPoolFusionInfer); } // namespace ops } // namespace mindspore diff --git a/mindspore/core/ops/grad/avg_pool_grad.cc b/mindspore/core/ops/grad/avg_pool_grad.cc index 5752bd5378..13a78a904d 100644 --- a/mindspore/core/ops/grad/avg_pool_grad.cc +++ b/mindspore/core/ops/grad/avg_pool_grad.cc @@ -31,8 +31,6 @@ AbstractBasePtr AvgPoolGradInfer(const abstract::AnalysisEnginePtr &, const Prim auto element = tensor_type->element(); return std::make_shared(element, origin_input_shape); } - -REGISTER_PRIMITIVE_EVAL_IMPL(AvgPoolGrad, prim::kPrimAvgPoolGrad, AvgPoolGradInfer); REGISTER_PRIMITIVE_C(kNameAvgPoolGrad, AvgPoolGrad); } // namespace ops } // namespace mindspore diff --git a/mindspore/core/ops/grad/bias_add_grad.cc b/mindspore/core/ops/grad/bias_add_grad.cc index 8c9bcea071..9ff1c387bd 100644 --- a/mindspore/core/ops/grad/bias_add_grad.cc +++ b/mindspore/core/ops/grad/bias_add_grad.cc @@ -58,8 +58,6 @@ AbstractBasePtr BiasAddGradInfer(const abstract::AnalysisEnginePtr &, const Prim return std::make_shared(intype, inshape); } - -REGISTER_PRIMITIVE_EVAL_IMPL(BiasAddGrad, prim::kPrimBiasAddGrad, BiasAddGradInfer); REGISTER_PRIMITIVE_C(kNameBiasAddGrad, BiasAddGrad); } // namespace ops } // namespace 
mindspore diff --git a/mindspore/core/ops/grad/max_pool_grad.cc b/mindspore/core/ops/grad/max_pool_grad.cc index ae6c5a4a29..d3575c24df 100644 --- a/mindspore/core/ops/grad/max_pool_grad.cc +++ b/mindspore/core/ops/grad/max_pool_grad.cc @@ -31,8 +31,6 @@ AbstractBasePtr MaxPoolGradInfer(const abstract::AnalysisEnginePtr &, const Prim auto element = tensor_type->element(); return std::make_shared(element, x1_shape); } - -REGISTER_PRIMITIVE_EVAL_IMPL(MaxPoolGrad, prim::kPrimMaxPoolGrad, MaxPoolGradInfer); REGISTER_PRIMITIVE_C(kNameMaxPoolGrad, MaxPoolGrad); } // namespace ops } // namespace mindspore diff --git a/mindspore/core/ops/lrn.cc b/mindspore/core/ops/lrn.cc index a86dd16c89..09360cab77 100644 --- a/mindspore/core/ops/lrn.cc +++ b/mindspore/core/ops/lrn.cc @@ -102,7 +102,6 @@ AbstractBasePtr LrnInfer(const abstract::AnalysisEnginePtr &, const PrimitivePtr return std::make_shared(InferType(primitive, input_args), InferShape(primitive, input_args)->shape()); } -REGISTER_PRIMITIVE_EVAL_IMPL(LRN, prim::kPrimLrn, LrnInfer); REGISTER_PRIMITIVE_C(kNameLRN, LRN); } // namespace ops } // namespace mindspore diff --git a/tests/ut/cpp/pre_activate/common/restore_abs_input_in_backed_infer_test.cc b/tests/ut/cpp/pre_activate/common/restore_abs_input_in_backed_infer_test.cc new file mode 100644 index 0000000000..bfabc1998d --- /dev/null +++ b/tests/ut/cpp/pre_activate/common/restore_abs_input_in_backed_infer_test.cc @@ -0,0 +1,86 @@ +/** + * This is the C++ adaptation and derivative work of Myia (https://github.com/mila-iqia/myia/). + * + * Copyright 2019-2021 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +#include +#include +#include +#include "ir/primitive.h" +#include "utils/utils.h" +#include "abstract/abstract_value.h" +#include "abstract/primitive_infer_map.h" +#include "backend/optimizer/common/const_input_to_attr_registry.h" +#include "backend/optimizer/common/helper.h" +#include "common/common_test.h" +namespace mindspore { +namespace opt { +constexpr auto kAttrConvertTestName = "attr_convert_test"; +constexpr auto kDynamicInputTestName = "dynamic_input_test"; +inline const PrimitivePtr kPrimAttrConvertTest = std::make_shared<Primitive>(kAttrConvertTestName); +inline const PrimitivePtr kPrimDynamicInputTest = std::make_shared<Primitive>("dynamic_input_test"); +AbstractBasePtr InferImplAttrTest(const abstract::AnalysisEnginePtr &, const PrimitivePtr &primitive, + const AbstractBasePtrList &args_spec_list) { + EXPECT_EQ(args_spec_list.size(), 3); + EXPECT_NE(args_spec_list[1], nullptr); + EXPECT_EQ(args_spec_list[1]->isa<abstract::AbstractTuple>(), true); + return args_spec_list[0]; +} +REGISTER_PRIMITIVE_EVAL_IMPL(TestAttr, kPrimAttrConvertTest, InferImplAttrTest); +AbstractBasePtr InferImplDynamicInputTest(const abstract::AnalysisEnginePtr &, const PrimitivePtr &primitive, + const AbstractBasePtrList &args_spec_list) { + EXPECT_EQ(args_spec_list.size(), 3); + EXPECT_NE(args_spec_list[1], nullptr); + EXPECT_EQ(args_spec_list[1]->isa<abstract::AbstractTuple>(), true); + auto item = args_spec_list[1]->cast<abstract::AbstractTuplePtr>(); + return args_spec_list[0]; +} +REGISTER_PRIMITIVE_EVAL_IMPL(TestDynamicInput, kPrimDynamicInputTest, InferImplDynamicInputTest); +class TestAttrAndDynamicBackendInfer : public UT::Common { + public: + TestAttrAndDynamicBackendInfer() {} + void SetUp() override {} + void TearDown() override {} +}; + +TEST_F(TestAttrAndDynamicBackendInfer, test_attr_and_dynamic_input_infer) { + // Register the const input to attr info for ut + ConstInputToAttrInfoRegistry &reg = ConstInputToAttrInfoRegistry::Instance(); + reg.Register(kAttrConvertTestName, {1}); + // construct primitives + PrimitivePtr prim_attr_test = std::make_shared<Primitive>(kAttrConvertTestName); + PrimitivePtr prim_dynamic_input_test = std::make_shared<Primitive>(kDynamicInputTestName); + // set primitive attrs + auto input_names = std::vector<std::string>{"a", "b", "c"}; + auto attr_name = "b"; + auto attr = MakeValue(std::vector<int64_t>{1, 2, 3}); + prim_attr_test->AddAttr(kAttrInputNames, MakeValue(input_names)); + prim_attr_test->AddAttr(attr_name, attr); + // set the dynamic input list for the primitive + std::vector<int64_t> dynamic_input_list = {-1, 2, -1}; + prim_dynamic_input_test->AddAttr(kAttrDynInputSizes, MakeValue(dynamic_input_list)); + // construct the abstract list + auto abs_a = std::make_shared<abstract::AbstractTensor>(kFloat32, std::vector<int64_t>{2, 2, 2, 2}); + auto abs_c = std::make_shared<abstract::AbstractTensor>(kFloat32, std::vector<int64_t>{2, 2, 2, 2}); + auto attr_infer_result = CppInferShape(prim_attr_test, {abs_a, abs_c}); + auto abs_dynamic_a = std::make_shared<abstract::AbstractTensor>(kFloat32, std::vector<int64_t>{2, 2, 2, 2}); + auto abs_dynamic_b = std::make_shared<abstract::AbstractTensor>(kFloat32, std::vector<int64_t>{2, 2, 2, 2}); + auto abs_dynamic_c = std::make_shared<abstract::AbstractTensor>(kFloat32, std::vector<int64_t>{2, 2, 2, 2}); + auto abs_dynamic_d = std::make_shared<abstract::AbstractTensor>(kFloat32, std::vector<int64_t>{2, 2, 2, 2}); + auto dynamic_infer_result = + CppInferShape(prim_dynamic_input_test, {abs_dynamic_a, abs_dynamic_b, abs_dynamic_c, abs_dynamic_d}); +} +} // namespace opt +} // namespace mindspore \ No newline at end of file