From d9f4659cfd361dd09b5241484feea0a14f675abe Mon Sep 17 00:00:00 2001
From: LianLiguang
Date: Wed, 24 Mar 2021 18:23:06 +0800
Subject: [PATCH] change some check type api
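Unify the CheckAndConvertUtils type-check helpers on TypePtr: CheckTensorTypeSame,
CheckTensorTypeValid, CheckScalarOrTensorTypesSame and CheckTypeValid now take
std::set<TypePtr> (e.g. {kFloat16, kFloat32}) instead of std::set<TypeId>
(e.g. {kNumberTypeFloat16, kNumberTypeFloat32}), and return the checked element
type as a TypePtr instead of a TypeId, so call sites can return the result
directly without TypeIdToType() or TensorType casts. A typical call site,
mirroring the ops/abs.cc hunk below (shown here only as an illustration):

    // before: the helper returned a TypeId that had to be converted back
    auto infer_type = CheckAndConvertUtils::CheckTensorTypeSame(types, common_valid_types, prim->name());
    return TypeIdToType(infer_type);

    // after: the helper returns the TypePtr directly
    return CheckAndConvertUtils::CheckTensorTypeSame(types, common_valid_types, prim->name());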
---
 mindspore/core/ops/abs.cc | 3 +-
 mindspore/core/ops/adam.cc | 12 +-
 mindspore/core/ops/add.cc | 3 +-
 mindspore/core/ops/addn.cc | 8 +-
 mindspore/core/ops/apply_momentum.cc | 2 +-
 mindspore/core/ops/arg_min.cc | 3 -
 mindspore/core/ops/asin.cc | 10 +-
 mindspore/core/ops/assert.cc | 8 +-
 mindspore/core/ops/assign_add.cc | 3 +-
 mindspore/core/ops/atan.cc | 12 +-
 mindspore/core/ops/batch_norm.cc | 10 +-
 mindspore/core/ops/batch_norm_fold.cc | 19 +-
 mindspore/core/ops/batch_to_space.cc | 3 +-
 mindspore/core/ops/bias_add.cc | 3 +-
 mindspore/core/ops/binary_cross_entropy.cc | 4 +-
 mindspore/core/ops/broadcast.cc | 2 +-
 mindspore/core/ops/broadcast_to.cc | 7 +-
 mindspore/core/ops/ceil.cc | 8 +-
 mindspore/core/ops/concat.cc | 4 +-
 mindspore/core/ops/constant.cc | 3 +-
 mindspore/core/ops/conv2d.cc | 9 +-
 mindspore/core/ops/conv2d_transpose.cc | 5 +-
 mindspore/core/ops/cos.cc | 3 +-
 mindspore/core/ops/custom_extract_features.cc | 4 +-
 mindspore/core/ops/custom_predict.cc | 6 +-
 mindspore/core/ops/depthwise_conv2d.cc | 6 +-
 mindspore/core/ops/detection_post_process.cc | 2 +-
 mindspore/core/ops/div.cc | 3 +-
 mindspore/core/ops/dropout.cc | 12 +-
 mindspore/core/ops/elu.cc | 5 +-
 mindspore/core/ops/embedding_lookup.cc | 11 +-
 mindspore/core/ops/equal.cc | 3 +-
 mindspore/core/ops/exp.cc | 3 +-
 mindspore/core/ops/expand_dims.cc | 6 +-
 .../core/ops/fake_quant_with_min_max_vars.cc | 5 +-
 ...ake_quant_with_min_max_vars_per_channel.cc | 3 +-
 mindspore/core/ops/fft_imag.cc | 2 +-
 mindspore/core/ops/fill.cc | 4 +-
 mindspore/core/ops/flatten.cc | 2 +-
 mindspore/core/ops/floor.cc | 7 +-
 mindspore/core/ops/fusion/add_fusion.cc | 3 +-
 mindspore/core/ops/fusion/pow_fusion.cc | 3 +-
 mindspore/core/ops/gather.cc | 10 +-
 mindspore/core/ops/gather_nd.cc | 7 +-
 mindspore/core/ops/gelu.cc | 5 +-
 .../ops/grad/binary_cross_entropy_grad.cc | 4 +-
 .../core/ops/grad/conv2d_backprop_filter.cc | 5 +-
 .../core/ops/grad/conv2d_backprop_input.cc | 2 +-
 mindspore/core/ops/grad/dropout_grad.cc | 4 +-
 .../sigmoid_cross_entropy_with_logits_grad.cc | 11 +-
 .../core/ops/grad/smooth_l1_loss_grad.cc | 10 +-
 mindspore/core/ops/hashtable_lookup.cc | 2 +-
 mindspore/core/ops/l2_normalize.cc | 4 +-
 mindspore/core/ops/leaky_relu.cc | 3 +-
 mindspore/core/ops/less.cc | 2 +-
 mindspore/core/ops/less_equal.cc | 3 +-
 .../core/ops/local_response_normalization.cc | 3 +-
 mindspore/core/ops/log.cc | 7 +-
 mindspore/core/ops/logical_and.cc | 8 +-
 mindspore/core/ops/logical_not.cc | 8 +-
 mindspore/core/ops/logical_or.cc | 8 +-
 mindspore/core/ops/lrn.cc | 5 +-
 mindspore/core/ops/lsh_projection.cc | 3 +-
 mindspore/core/ops/mat_mul.cc | 9 +-
 mindspore/core/ops/matrix_diag.cc | 12 +-
 mindspore/core/ops/maximum.cc | 3 +-
 mindspore/core/ops/merge.cc | 13 +-
 mindspore/core/ops/minimum.cc | 3 +-
 mindspore/core/ops/neg.cc | 3 +-
 mindspore/core/ops/non_max_suppression.cc | 2 +-
 mindspore/core/ops/one_hot.cc | 16 +-
 mindspore/core/ops/ones_like.cc | 7 +-
 mindspore/core/ops/op_utils.h | 12 +-
 mindspore/core/ops/pad.cc | 6 +-
 mindspore/core/ops/pow.cc | 3 +-
 mindspore/core/ops/prelu.cc | 11 +-
 mindspore/core/ops/prior_box.cc | 2 +-
 mindspore/core/ops/range.cc | 3 +-
 mindspore/core/ops/rank.cc | 4 +-
 mindspore/core/ops/real_div.cc | 3 +-
 mindspore/core/ops/reciprocal.cc | 5 +-
 mindspore/core/ops/reduce.cc | 6 +-
 mindspore/core/ops/relu6.cc | 7 +-
 mindspore/core/ops/resize_bilinear.cc | 9 +-
 mindspore/core/ops/reverse_sequence.cc | 13 +-
 mindspore/core/ops/reverse_v2.cc | 18 +-
 mindspore/core/ops/rfft.cc | 2 +-
 mindspore/core/ops/round.cc | 8 +-
 mindspore/core/ops/rsqrt.cc | 7 +-
 mindspore/core/ops/scatter_nd.cc | 6 +-
 .../ops/sigmoid_cross_entropy_with_logits.cc | 10 +-
 mindspore/core/ops/sin.cc | 5 +-
 mindspore/core/ops/smooth_l1_loss.cc | 5 +-
 mindspore/core/ops/softmax.cc | 7 +-
 .../ops/softmax_cross_entropy_with_logits.cc | 5 +-
 mindspore/core/ops/space_to_batch.cc | 3 +-
 mindspore/core/ops/space_to_batch_nd.cc | 3 +-
 mindspore/core/ops/sparse_to_dense.cc | 4 -
 mindspore/core/ops/squared_difference.cc | 5 +-
 mindspore/core/ops/sub.cc | 3 +-
 mindspore/core/ops/tan.cc | 11 +-
 mindspore/core/ops/tensor_list_from_tensor.cc | 6 +-
 mindspore/core/ops/tile.cc | 6 +-
 mindspore/core/ops/topk.cc | 8 +-
 mindspore/core/ops/unsorted_segment_sum.cc | 23 +-
 mindspore/core/ops/unstack.cc | 2 -
 mindspore/core/ops/zeros_like.cc | 12 +-
 mindspore/core/utils/check_convert_utils.cc | 227 ++++------
 mindspore/core/utils/check_convert_utils.h | 49 ++--
 109 files changed, 330 insertions(+), 605 deletions(-)

diff --git a/mindspore/core/ops/abs.cc b/mindspore/core/ops/abs.cc
index 14f55720e0..ed0574c38e 100644
--- a/mindspore/core/ops/abs.cc
+++ b/mindspore/core/ops/abs.cc
@@ -46,8 +46,7 @@ TypePtr InferType(const PrimitivePtr &prim, const std::vector<AbstractBasePtr> &
   }
   std::map<std::string, TypePtr> types;
   types.emplace("input_x", input_args[0]->BuildType());
-  auto infer_type = CheckAndConvertUtils::CheckTensorTypeSame(types, common_valid_types, prim->name());
-  return TypeIdToType(infer_type);
+  return CheckAndConvertUtils::CheckTensorTypeSame(types, common_valid_types, prim->name());
 }
 }  // namespace
diff --git a/mindspore/core/ops/adam.cc b/mindspore/core/ops/adam.cc
index c8071d2e08..6754b84598 100644
--- a/mindspore/core/ops/adam.cc
+++ b/mindspore/core/ops/adam.cc
@@ -42,14 +42,10 @@ abstract::AbstractBasePtr AdamInfer(const PrimitivePtr &primitive, const std::ve
   auto m_type = input_args[1]->BuildType();
   auto v_type = input_args[2]->BuildType();
   auto grad_type = input_args[9]->BuildType();
-  CheckAndConvertUtils::CheckTensorTypeValid("var_type", var_type, common_valid_types, prim_name);
-  CheckAndConvertUtils::CheckTensorTypeValid("m_type", m_type, common_valid_types, prim_name);
-  CheckAndConvertUtils::CheckTensorTypeValid("v_type", v_type, common_valid_types, prim_name);
-  CheckAndConvertUtils::CheckTensorTypeValid("grad_type", grad_type, common_valid_types, prim_name);
-
-  auto infer_var_type = var_type->cast<TensorTypePtr>()->element();
-  auto infer_m_type = m_type->cast<TensorTypePtr>()->element();
-  auto infer_v_type = v_type->cast<TensorTypePtr>()->element();
+  auto infer_var_type = CheckAndConvertUtils::CheckTensorTypeValid("var_type", var_type, common_valid_types, prim_name);
+  auto infer_m_type = CheckAndConvertUtils::CheckTensorTypeValid("m_type", m_type, common_valid_types, prim_name);
+  auto infer_v_type = CheckAndConvertUtils::CheckTensorTypeValid("v_type", v_type, common_valid_types, prim_name);
+  (void)CheckAndConvertUtils::CheckTensorTypeValid("grad_type", grad_type, common_valid_types, prim_name);
   // auto infer_grad_type = grad_type->cast<TensorTypePtr>()->element();
   auto output0 = std::make_shared<abstract::AbstractTensor>(infer_var_type, var_shape);
   auto output1 = std::make_shared<abstract::AbstractTensor>(infer_m_type, m_shape);
diff --git a/mindspore/core/ops/add.cc b/mindspore/core/ops/add.cc
index 58beb2a22a..c041153f7f 100644
--- a/mindspore/core/ops/add.cc
+++ b/mindspore/core/ops/add.cc
@@ -40,8 +40,7 @@ TypePtr InferType(const PrimitivePtr &prim, const std::vector<AbstractBasePtr> &
   std::map<std::string, TypePtr> types;
   types.emplace("x", input_args[0]->BuildType());
   types.emplace("y", input_args[1]->BuildType());
-  auto infer_type = CheckAndConvertUtils::CheckTensorTypeSame(types, common_valid_types, prim->name());
-  return TypeIdToType(infer_type);
+  return CheckAndConvertUtils::CheckTensorTypeSame(types, common_valid_types, prim->name());
 }
 }  // namespace
diff --git a/mindspore/core/ops/addn.cc b/mindspore/core/ops/addn.cc
index fce06ce4c1..de1b4e95d0 100644
--- a/mindspore/core/ops/addn.cc
+++ b/mindspore/core/ops/addn.cc
@@ -56,12 +56,10 @@ AbstractBasePtr AddNInfer(const abstract::AnalysisEnginePtr &, const PrimitivePt
     }
     types.emplace(elementi, elements[i]->BuildType());
   }
-  std::set<TypeId> valid_types = common_valid_types;
-  valid_types.insert(kNumberTypeBool);
+  std::set<TypePtr> valid_types = common_valid_types;
+  valid_types.insert(kBool);
   auto infer_type = CheckAndConvertUtils::CheckTensorTypeSame(types, valid_types, prim_name);
-
-  return std::make_shared<abstract::AbstractTensor>(TypeIdToType(infer_type),
-                                                    std::make_shared<abstract::Shape>(element0_shape));
+  return std::make_shared<abstract::AbstractTensor>(infer_type, std::make_shared<abstract::Shape>(element0_shape));
 }
 REGISTER_PRIMITIVE_C(kNameAddN, AddN);
 }  // namespace ops
diff --git a/mindspore/core/ops/apply_momentum.cc b/mindspore/core/ops/apply_momentum.cc
index 992bf31baa..ebc3962a79 100644
--- a/mindspore/core/ops/apply_momentum.cc
+++ b/mindspore/core/ops/apply_momentum.cc
@@ -68,7 +68,7 @@ AbstractBasePtr ApplyMomentumInfer(const abstract::AnalysisEnginePtr &, const Pr
   auto l_type = input_args[2]->BuildType();
   auto g_type = input_args[3]->BuildType();
   auto m_type = input_args[4]->BuildType();
-  const std::set<TypeId> valid_types = {kNumberTypeFloat16, kNumberTypeFloat32, kNumberTypeFloat64};
+  const std::set<TypePtr> valid_types = {kFloat16, kFloat32, kFloat64};
   CheckAndConvertUtils::CheckTensorTypeValid("v_type", v_tensor_type, valid_types, prim_name);
   CheckAndConvertUtils::CheckTensorTypeValid("a_type", a_tensor_type, valid_types, prim_name);
   std::map<std::string, TypePtr> args;
diff --git a/mindspore/core/ops/arg_min.cc b/mindspore/core/ops/arg_min.cc
index be88497aef..230b61855d 100644
--- a/mindspore/core/ops/arg_min.cc
+++ b/mindspore/core/ops/arg_min.cc
@@ -62,9 +62,6 @@ AbstractBasePtr ArgMinInfer(const abstract::AnalysisEnginePtr &, const Primitive
 
   // Infer type
   auto x_dtype = input_args[0]->BuildType()->cast<TensorTypePtr>()->element();
-  std::set<TypePtr> template_types = {TypeIdToType(kObjectTypeTensorType)};
-  CheckAndConvertUtils::CheckSubClass("x_dtype", x_dtype, template_types, prim_name);
-
   return std::make_shared<abstract::AbstractTensor>(x_dtype, std::make_shared<abstract::Shape>(out_shape));
 }
 REGISTER_PRIMITIVE_C(kNameArgMin, ArgMin);
diff --git a/mindspore/core/ops/asin.cc b/mindspore/core/ops/asin.cc
index 2fc6275e40..32c16249bc 100644
--- a/mindspore/core/ops/asin.cc
+++ b/mindspore/core/ops/asin.cc
@@ -36,14 +36,8 @@ AbstractBasePtr AsinInfer(const abstract::AnalysisEnginePtr &, const PrimitivePt
 
   // Infer Type
   auto dtype = input_args[0]->BuildType();
-  const std::set<TypeId> valid_types = {kNumberTypeFloat16, kNumberTypeFloat32, kNumberTypeInt32};
-  CheckAndConvertUtils::CheckTensorTypeValid("x_dtype", dtype, valid_types, prim_name);
-  auto tensor_type = dtype->cast<TensorTypePtr>();
-  MS_EXCEPTION_IF_NULL(tensor_type);
-  auto element = tensor_type->element();
-  MS_EXCEPTION_IF_NULL(element);
-  auto infer_type = std::make_shared<TensorType>(TypeIdToType(element->type_id()));
-
+  const std::set<TypePtr> valid_types = {kFloat16, kFloat32, kInt32};
+  auto infer_type = CheckAndConvertUtils::CheckTensorTypeValid("x_dtype", dtype, valid_types, prim_name);
   return std::make_shared<abstract::AbstractTensor>(infer_type, infer_shape->shape());
 }
 REGISTER_PRIMITIVE_C(kNameAsin, Asin);
diff --git a/mindspore/core/ops/assert.cc b/mindspore/core/ops/assert.cc
index c1bd60e255..9831df4217 100644
--- a/mindspore/core/ops/assert.cc
+++ b/mindspore/core/ops/assert.cc
@@ -61,15 +61,15 @@ AbstractBasePtr AssertInfer(const abstract::AnalysisEnginePtr &, const Primitive
     condition = input_args[0]->BuildType();
   }
   std::vector<int64_t> output_shape = {1};
-  std::set<TypeId> local_bool = {kNumberTypeBool};
+  std::set<TypePtr> local_bool = {kBool};
   std::map<std::string, TypePtr> args = {{"condition", condition}};
-  CheckAndConvertUtils::CheckScalarOrTensorTypesSame(args, local_bool, op_name);
+  (void)CheckAndConvertUtils::CheckScalarOrTensorTypesSame(args, local_bool, op_name);
   auto inputs_type = input_args[1]->BuildType()->cast<TuplePtr>()->elements();
   for (auto dtype : inputs_type) {
-    std::set<TypePtr> template_types = {TypeIdToType(kObjectTypeTensorType)};
+    std::set<TypePtr> template_types = {kTensorType};
     CheckAndConvertUtils::CheckSubClass("input", dtype, template_types, op_name);
   }
-  return std::make_shared<abstract::AbstractTensor>(TypeIdToType(kNumberTypeInt32), output_shape);
+  return std::make_shared<abstract::AbstractTensor>(kInt32, output_shape);
 }
 REGISTER_PRIMITIVE_C(kNameAssert, Assert);
 }  // namespace ops
diff --git a/mindspore/core/ops/assign_add.cc b/mindspore/core/ops/assign_add.cc
index 762ef6e15b..09aca18ffe 100644
--- a/mindspore/core/ops/assign_add.cc
+++ b/mindspore/core/ops/assign_add.cc
@@ -38,8 +38,7 @@ TypePtr InferType(const PrimitivePtr &primitive, const std::vector<AbstractBase
   types.emplace("x", input_args[0]->BuildType());
   types.emplace("w", input_args[1]->BuildType());
   // check_scalar_or_tensor_types_same
-  auto infer_type = CheckAndConvertUtils::CheckTensorTypeSame(types, common_valid_types, "AssignAdd");
-  return TypeIdToType(infer_type);
+  return CheckAndConvertUtils::CheckTensorTypeSame(types, common_valid_types, "AssignAdd");
 }
 }  // namespace
 AbstractBasePtr AssignAddInfer(const abstract::AnalysisEnginePtr &, const PrimitivePtr &primitive,
diff --git a/mindspore/core/ops/atan.cc b/mindspore/core/ops/atan.cc
index 0d3309efb7..21a014cdf4 100644
--- a/mindspore/core/ops/atan.cc
+++ b/mindspore/core/ops/atan.cc
@@ -34,15 +34,9 @@ AbstractBasePtr AtanInfer(const abstract::AnalysisEnginePtr &, const PrimitivePt
 
   // Infer Type
   auto dtype = input_args[0]->BuildType();
-  const std::set<TypeId> valid_types = {kNumberTypeFloat16, kNumberTypeFloat32, kNumberTypeInt32};
-  CheckAndConvertUtils::CheckTensorTypeValid("x_dtype", dtype, valid_types, prim_name);
-  auto tensor_type = dtype->cast<TensorTypePtr>();
-  MS_EXCEPTION_IF_NULL(tensor_type);
-  auto element = tensor_type->element();
-  MS_EXCEPTION_IF_NULL(element);
-  auto infer_type = std::make_shared<TensorType>(TypeIdToType(element->type_id()));
-
-  return std::make_shared<abstract::AbstractTensor>(infer_type, infer_shape->shape());
+  const std::set<TypePtr> valid_types = {kFloat16, kFloat32, kInt32};
+  auto element = CheckAndConvertUtils::CheckTensorTypeValid("x_dtype", dtype, valid_types, prim_name);
+  return std::make_shared<abstract::AbstractTensor>(element, infer_shape->shape());
 }
 REGISTER_PRIMITIVE_C(kNameAtan, Atan);
 }  // namespace ops
diff --git a/mindspore/core/ops/batch_norm.cc b/mindspore/core/ops/batch_norm.cc
index 507ceea701..75cba64e47 100644
--- a/mindspore/core/ops/batch_norm.cc
+++ b/mindspore/core/ops/batch_norm.cc
@@ -107,20 +107,20 @@ AbstractBasePtr BatchNormInfer(const abstract::AnalysisEnginePtr &, const Primit
   }
 
   // Infer type
-  auto input_x_type = input_args[0]->BuildType()->cast<TensorTypePtr>()->element();
   auto scale_type = input_args[1]->BuildType()->cast<TensorTypePtr>()->element();
   auto bias_type = input_args[2]->BuildType()->cast<TensorTypePtr>()->element();
-  const std::set<TypeId> valid_types = {kNumberTypeFloat16, kNumberTypeFloat32};
-  CheckAndConvertUtils::CheckTensorTypeValid("input_x", input_args[0]->BuildType(), valid_types, prim_name);
+  const std::set<TypePtr> valid_types = {kFloat16, kFloat32};
+  auto input_x_type =
+    CheckAndConvertUtils::CheckTensorTypeValid("input_x", input_args[0]->BuildType(), valid_types, prim_name);
   std::map<std::string, TypePtr> args;
   args.emplace("scale", input_args[1]->BuildType());
   args.emplace("bias", input_args[2]->BuildType());
-  CheckAndConvertUtils::CheckTensorTypeSame(args, valid_types, prim_name);
+  (void)CheckAndConvertUtils::CheckTensorTypeSame(args, valid_types, prim_name);
   std::map<std::string, TypePtr> args_moving;
   args_moving.emplace("scale", input_args[2]->BuildType());
   args_moving.emplace("bias", input_args[3]->BuildType());
-  CheckAndConvertUtils::CheckTensorTypeSame(args_moving, valid_types, prim_name);
+  (void)CheckAndConvertUtils::CheckTensorTypeSame(args_moving, valid_types, prim_name);
 
   auto output0 = std::make_shared<abstract::AbstractTensor>(input_x_type, input_x);
   auto output1 = std::make_shared<abstract::AbstractTensor>(scale_type, scale);
diff --git a/mindspore/core/ops/batch_norm_fold.cc b/mindspore/core/ops/batch_norm_fold.cc
index b98d549f22..dd48bfe72a 100644
--- a/mindspore/core/ops/batch_norm_fold.cc
+++ b/mindspore/core/ops/batch_norm_fold.cc
@@ -87,23 +87,8 @@ AbstractBasePtr BatchNormFoldInfer(const abstract::AnalysisEnginePtr &, const Pr
   auto global_step_type = input_args[3]->BuildType();
 
   std::map<std::string, TypePtr> args = {{"x", x_type}, {"mean", mean_type}, {"variance", variance_type}};
-  CheckAndConvertUtils::CheckTensorTypeSame(args, {kNumberTypeFloat16, kNumberTypeFloat32}, op_name);
-  CheckAndConvertUtils::CheckTensorTypeValid("gloabal_step", global_step_type, {kNumberTypeInt32}, op_name);
-
-  auto tensor_type0 = x_type->cast<TensorTypePtr>();
-  MS_EXCEPTION_IF_NULL(tensor_type0);
-  auto element0 = tensor_type0->element();
-
-  auto tensor_type1 = mean_type->cast<TensorTypePtr>();
-  MS_EXCEPTION_IF_NULL(tensor_type1);
-  auto element1 = tensor_type1->element();
-
-  auto tensor_type2 = variance_type->cast<TensorTypePtr>();
-  MS_EXCEPTION_IF_NULL(tensor_type2);
-  auto element2 = tensor_type2->element();
-
-  CheckAndConvertUtils::Check("input type", element0->type_id(), kEqual, "mean_type", element1->type_id(), op_name);
-  CheckAndConvertUtils::Check("input type", element0->type_id(), kEqual, "variance_type", element2->type_id(), op_name);
+  auto element0 = CheckAndConvertUtils::CheckTensorTypeSame(args, {kFloat16, kFloat32}, op_name);
+  CheckAndConvertUtils::CheckTensorTypeValid("gloabal_step", global_step_type, {kInt32}, op_name);
 
   auto output = std::make_shared<abstract::AbstractTensor>(element0, mean_shape);
   AbstractBasePtrList output1 = {output, output, output, output};
diff --git a/mindspore/core/ops/batch_to_space.cc b/mindspore/core/ops/batch_to_space.cc
index 7969714e0b..2a0a176b94 100644
--- a/mindspore/core/ops/batch_to_space.cc
+++ b/mindspore/core/ops/batch_to_space.cc
@@ -54,7 +54,8 @@ AbstractBasePtr BatchToSpaceInfer(const abstract::AnalysisEnginePtr &, const Pri
   for (const auto &item : input_args) {
     MS_EXCEPTION_IF_NULL(item);
   }
-  CheckAndConvertUtils::CheckTensorTypeValid("input_x", input_args[0]->BuildType(), common_valid_types, prim_name);
+  (void)CheckAndConvertUtils::CheckTensorTypeValid("input_x", input_args[0]->BuildType(), common_valid_types,
+                                                   prim_name);
   auto x_shape = CheckAndConvertUtils::ConvertShapePtrToShape("x_shape", input_args[0]->BuildShape(), prim_name);
   CheckAndConvertUtils::CheckInteger("x rank", x_shape.size(), kEqual, 4, prim_name);
diff --git a/mindspore/core/ops/bias_add.cc b/mindspore/core/ops/bias_add.cc
index 359d18093a..b6660b28e0 100644
--- a/mindspore/core/ops/bias_add.cc
+++ b/mindspore/core/ops/bias_add.cc
@@ -55,8 +55,7 @@ TypePtr BiasAddInferType(const PrimitivePtr &prim, const std::vector<AbstractBa
   std::map<std::string, TypePtr> types;
   types.emplace("input_x", input_args[0]->BuildType());
   types.emplace("bias", input_args[1]->BuildType());
-  auto infer_type = CheckAndConvertUtils::CheckTensorTypeSame(types, common_valid_types, prim->name());
-  return TypeIdToType(infer_type);
+  return CheckAndConvertUtils::CheckTensorTypeSame(types, common_valid_types, prim->name());
 }
 }  // namespace
 void BiasAdd::set_format(const Format &format) {
diff --git a/mindspore/core/ops/binary_cross_entropy.cc b/mindspore/core/ops/binary_cross_entropy.cc
index e1795596ee..52010fd663 100644
--- a/mindspore/core/ops/binary_cross_entropy.cc
+++ b/mindspore/core/ops/binary_cross_entropy.cc
@@ -57,7 +57,7 @@ TypePtr BinaryCrossEntroyInferType(const PrimitivePtr &prim, const std::vector<
-  const std::set<TypeId> valid_types = {kNumberTypeFloat16, kNumberTypeFloat32};
+  const std::set<TypePtr> valid_types = {kFloat16, kFloat32};
   std::map<std::string, TypePtr> types;
   types.emplace("x_shape", input_args[0]->BuildType());
   types.emplace("y_shape", input_args[1]->BuildType());
@@ -67,7 +67,7 @@ TypePtr BinaryCrossEntroyInferType(const PrimitivePtr &prim, const std::vector<
     types.emplace("weight_shape", input_args[2]->BuildType());
     infer_type = CheckAndConvertUtils::CheckTensorTypeSame(types, valid_types, prim->name());
   }
-  return TypeIdToType(infer_type);
+  return infer_type;
 }
 }  // namespace
diff --git a/mindspore/core/ops/broadcast.cc b/mindspore/core/ops/broadcast.cc
index 9f82743c46..4cf62b610e 100644
--- a/mindspore/core/ops/broadcast.cc
+++ b/mindspore/core/ops/broadcast.cc
@@ -56,7 +56,7 @@ AbstractBasePtr BroadcastInfer(const abstract::AnalysisEnginePtr &, const Primit
   // infer type
   auto x_type = input_args[0]->BuildType()->cast<TensorTypePtr>()->element();
   std::vector<TypePtr> output_types;
-  const std::set<TypeId> valid_types = {kNumberTypeInt8, kNumberTypeInt32, kNumberTypeFloat16, kNumberTypeFloat32};
+  const std::set<TypePtr> valid_types = {kInt8, kInt32, kFloat16, kFloat32};
   for (size_t i = 0; i < input_args.size(); i++) {
     auto out_type = input_args[i]->BuildType()->cast<TensorTypePtr>()->element();
     output_types.push_back(out_type);
diff --git a/mindspore/core/ops/broadcast_to.cc b/mindspore/core/ops/broadcast_to.cc
index ba47a49534..85f038867b 100644
--- a/mindspore/core/ops/broadcast_to.cc
+++ b/mindspore/core/ops/broadcast_to.cc
@@ -57,11 +57,10 @@ TypePtr BroadcastToInferType(const PrimitivePtr &prim, const std::vector<Abstra
-  auto x_dtype = input_args[0]->BuildType()->cast<TensorTypePtr>()->element();
-  std::set<TypePtr> template_types = {TypeIdToType(kObjectTypeTensorType)};
+  auto x_dtype = input_args[0]->BuildType()->cast<TensorTypePtr>();
+  std::set<TypePtr> template_types = {kTensorType};
   CheckAndConvertUtils::CheckSubClass("x_dtype", x_dtype, template_types, prim->name());
-  auto infer_dtype = input_args[0]->BuildType()->type_id();
-  return TypeIdToType(infer_dtype);
+  return x_dtype->element();
 }
 }  // namespace
diff --git a/mindspore/core/ops/ceil.cc b/mindspore/core/ops/ceil.cc
index 2cb317a005..7a786a3ae2 100644
--- a/mindspore/core/ops/ceil.cc
+++ b/mindspore/core/ops/ceil.cc
@@ -33,13 +33,9 @@ AbstractBasePtr CeilInfer(const abstract::AnalysisEnginePtr &, const PrimitivePt
     MS_EXCEPTION_IF_NULL(item);
   }
   auto x_shape = CheckAndConvertUtils::ConvertShapePtrToShape("x_shape", input_args[0]->BuildShape(), "Ceil");
-  const std::set<TypeId> valid_types = {kNumberTypeFloat16, kNumberTypeFloat32};
+  const std::set<TypePtr> valid_types = {kFloat16, kFloat32};
   auto infer_type = input_args[0]->BuildType();
-  CheckAndConvertUtils::CheckTensorTypeValid("x type", infer_type, valid_types, primitive->name());
-  MS_EXCEPTION_IF_NULL(infer_type);
-  auto tensor_type = infer_type->cast<TensorTypePtr>();
-  MS_EXCEPTION_IF_NULL(tensor_type);
-  auto data_type = tensor_type->element();
+  auto data_type = CheckAndConvertUtils::CheckTensorTypeValid("x type", infer_type, valid_types, primitive->name());
   MS_EXCEPTION_IF_NULL(data_type);
   return std::make_shared<abstract::AbstractTensor>(data_type, x_shape);
 }
diff --git a/mindspore/core/ops/concat.cc b/mindspore/core/ops/concat.cc
index 7c4e9b225c..ba6906a839 100644
--- a/mindspore/core/ops/concat.cc
+++ b/mindspore/core/ops/concat.cc
@@ -74,9 +74,7 @@ AbstractBasePtr ConcatInfer(const abstract::AnalysisEnginePtr &, const Primitive
   auto infer_type = CheckAndConvertUtils::CheckTensorTypeSame(types, all_types, prim_name);
   auto ret_shape = element0_shape;
   ret_shape[axis] = all_shp;
-
-  return std::make_shared<abstract::AbstractTensor>(TypeIdToType(infer_type),
-                                                    std::make_shared<abstract::Shape>(ret_shape));
+  return std::make_shared<abstract::AbstractTensor>(infer_type, std::make_shared<abstract::Shape>(ret_shape));
 }
 REGISTER_PRIMITIVE_C(kNameConcat, Concat);
 }  // namespace ops
diff --git a/mindspore/core/ops/constant.cc b/mindspore/core/ops/constant.cc
index 03bc5f0ada..05aed4dd57 100644
--- a/mindspore/core/ops/constant.cc
+++ b/mindspore/core/ops/constant.cc
@@ -42,8 +42,7 @@ TypePtr InferType(const PrimitivePtr &prim, const std::vector<AbstractBasePtr> &
   }
   std::map<std::string, TypePtr> types;
   types.emplace("x", input_args[0]->BuildType());
-  auto infer_type = CheckAndConvertUtils::CheckTensorTypeSame(types, common_valid_types, prim->name());
-  return TypeIdToType(infer_type);
+  return CheckAndConvertUtils::CheckTensorTypeSame(types, common_valid_types, prim->name());
 }
 }  // namespace
diff --git a/mindspore/core/ops/conv2d.cc b/mindspore/core/ops/conv2d.cc
index 4aff527387..f7cf304906 100644
--- a/mindspore/core/ops/conv2d.cc
+++ b/mindspore/core/ops/conv2d.cc
@@ -107,16 +107,11 @@ TypePtr Conv2dInferType(const PrimitivePtr &prim, const std::vector<AbstractBas
-  const std::set<TypeId> valid_types = {kNumberTypeInt8, kNumberTypeInt32, kNumberTypeInt64, kNumberTypeFloat16,
-                                        kNumberTypeFloat32};
+  const std::set<TypePtr> valid_types = {kInt8, kInt32, kInt64, kFloat16, kFloat32};
   std::map<std::string, TypePtr> types;
   types.emplace("x", input_args[0]->BuildType());
   types.emplace("w", input_args[1]->BuildType());
-  auto infer_type = CheckAndConvertUtils::CheckTensorTypeSame(types, valid_types, prim->name());
-  if (infer_type == kNumberTypeInt8) {
-    return TypeIdToType(kNumberTypeInt32);
-  }
-  return TypeIdToType(infer_type);
+  return CheckAndConvertUtils::CheckTensorTypeSame(types, valid_types, prim->name());
 }
 }  // namespace
 void Conv2D::Init(int64_t out_channel, const std::vector<int64_t> &kernel_size, int64_t mode, const PadMode &pad_mode,
diff --git a/mindspore/core/ops/conv2d_transpose.cc b/mindspore/core/ops/conv2d_transpose.cc
index 36c08c7d4b..78ce9f82a6 100644
--- a/mindspore/core/ops/conv2d_transpose.cc
+++ b/mindspore/core/ops/conv2d_transpose.cc
@@ -40,12 +40,11 @@ TypePtr Conv2dTransposeInferType(const PrimitivePtr &prim, const std::vector<Ab
-  const std::set<TypeId> valid_types = {kNumberTypeInt8, kNumberTypeInt32, kNumberTypeFloat16, kNumberTypeFloat32};
+  const std::set<TypePtr> valid_types = {kInt8, kInt32, kFloat16, kFloat32};
   std::map<std::string, TypePtr> types;
   types.emplace("doutput_dtye", input_args[0]->BuildType());
   types.emplace("w_dtype", input_args[1]->BuildType());
-  auto infer_type = CheckAndConvertUtils::CheckTensorTypeSame(types, valid_types, prim->name());
-  return TypeIdToType(infer_type);
+  return CheckAndConvertUtils::CheckTensorTypeSame(types, valid_types, prim->name());
 }
 }  // namespace
diff --git a/mindspore/core/ops/cos.cc b/mindspore/core/ops/cos.cc
index dc8b7880b4..77cbef9ce0 100644
--- a/mindspore/core/ops/cos.cc
+++ b/mindspore/core/ops/cos.cc
@@ -40,8 +40,7 @@ TypePtr InferType(const PrimitivePtr &prim, const std::vector<AbstractBasePtr> &
   }
   std::map<std::string, TypePtr> types;
   types.emplace("x", input_args[0]->BuildType());
-  auto infer_type = CheckAndConvertUtils::CheckTensorTypeSame(types, common_valid_types, prim->name());
-  return TypeIdToType(infer_type);
+  return CheckAndConvertUtils::CheckTensorTypeSame(types, common_valid_types, prim->name());
 }
 }  // namespace
diff --git a/mindspore/core/ops/custom_extract_features.cc b/mindspore/core/ops/custom_extract_features.cc
index 1ea5b15650..b7d7d18ac8 100644
--- a/mindspore/core/ops/custom_extract_features.cc
+++ b/mindspore/core/ops/custom_extract_features.cc
@@ -31,8 +31,8 @@ AbstractBasePtr CustomExtractFeaturesInfer(const abstract::AnalysisEnginePtr &,
   // auto input = input_args[0];
 
   // Infer type
-  auto output0_type = TypeIdToType(kNumberTypeInt32);
-  auto output1_type = TypeIdToType(kNumberTypeFloat32);
+  auto output0_type = kInt32;
+  auto output1_type = kFloat32;
 
   // Infer shape
   std::vector<int64_t> out_shape;
diff --git a/mindspore/core/ops/custom_predict.cc b/mindspore/core/ops/custom_predict.cc
index 1ed40fb978..e28c3af1ae 100644
--- a/mindspore/core/ops/custom_predict.cc
+++ b/mindspore/core/ops/custom_predict.cc
@@ -47,14 +47,14 @@ AbstractBasePtr CustomPredictInfer(const abstract::AnalysisEnginePtr &, const Pr
   MS_EXCEPTION_IF_NULL(primitive);
   auto CustomPredict_prim = primitive->cast<PrimCustomPredictPtr>();
   MS_EXCEPTION_IF_NULL(CustomPredict_prim);
-  for (auto input : input_args) {
+  for (const auto &input : input_args) {
     MS_EXCEPTION_IF_NULL(input);
   }
   std::vector<int64_t> shape;
   shape.push_back(CustomPredict_prim->get_output_num());
 
-  auto output0 = std::make_shared<abstract::AbstractTensor>(TypeIdToType(kNumberTypeInt32), shape);
-  auto output1 = std::make_shared<abstract::AbstractTensor>(TypeIdToType(kNumberTypeFloat32), shape);
+  auto output0 = std::make_shared<abstract::AbstractTensor>(kInt32, shape);
+  auto output1 = std::make_shared<abstract::AbstractTensor>(kFloat32, shape);
   AbstractBasePtrList output = {output0, output1};
   return std::make_shared<abstract::AbstractTuple>(output);
 }
diff --git a/mindspore/core/ops/depthwise_conv2d.cc b/mindspore/core/ops/depthwise_conv2d.cc
index bbc5599cf9..3b4b5e5961 100644
--- a/mindspore/core/ops/depthwise_conv2d.cc
+++ b/mindspore/core/ops/depthwise_conv2d.cc
@@ -216,10 +216,10 @@ TypePtr DepthWiseConv2DInferType(const PrimitivePtr &prim, const std::vector<Ab
   types.emplace("x", input_args[0]->BuildType());
   types.emplace("w", input_args[1]->BuildType());
   auto infer_type = CheckAndConvertUtils::CheckTensorTypeSame(types, common_valid_types, prim->name());
-  if (infer_type == kNumberTypeInt8) {
-    return std::make_shared<TensorType>(TypeIdToType(kNumberTypeInt32));
+  if (*infer_type == *kInt8) {
+    return kInt32;
   }
-  return TypeIdToType(infer_type);
+  return infer_type;
 }
 
 AbstractBasePtr DepthWiseConv2DInfer(const abstract::AnalysisEnginePtr &, const PrimitivePtr &primitive,
diff --git a/mindspore/core/ops/detection_post_process.cc b/mindspore/core/ops/detection_post_process.cc
index 708f37df59..19f54ca07b 100644
--- a/mindspore/core/ops/detection_post_process.cc
+++ b/mindspore/core/ops/detection_post_process.cc
@@ -157,7 +157,7 @@ AbstractBasePtr DetectionPostProcessInfer(const abstract::AnalysisEnginePtr &, c
   std::vector<int64_t> output_num_shape = {1};
 
   // Infer type
-  auto output_type = TypeIdToType(kNumberTypeFloat32);
+  auto output_type = kFloat32;
 
   auto output0 = std::make_shared<abstract::AbstractTensor>(output_type, output_boxes_shape);
   auto output1 = std::make_shared<abstract::AbstractTensor>(output_type, output_class_shape);
diff --git a/mindspore/core/ops/div.cc b/mindspore/core/ops/div.cc
index 6e39bb9fad..affcb3782d 100644
--- a/mindspore/core/ops/div.cc
+++ b/mindspore/core/ops/div.cc
@@ -41,8 +41,7 @@ TypePtr InferType(const PrimitivePtr &prim, const std::vector<AbstractBasePtr> &
   std::map<std::string, TypePtr> types;
   types.emplace("x", input_args[0]->BuildType());
   types.emplace("y", input_args[1]->BuildType());
-  auto infer_type = CheckAndConvertUtils::CheckTensorTypeSame(types, common_valid_types, prim->name());
-  return TypeIdToType(infer_type);
+  return CheckAndConvertUtils::CheckTensorTypeSame(types, common_valid_types, prim->name());
 }
 }  // namespace
diff --git a/mindspore/core/ops/dropout.cc b/mindspore/core/ops/dropout.cc
index fe8cff4661..13a7093616 100644
--- a/mindspore/core/ops/dropout.cc
+++ b/mindspore/core/ops/dropout.cc
@@ -53,15 +53,9 @@ AbstractBasePtr DropoutInfer(const abstract::AnalysisEnginePtr &, const Primitiv
   auto infer_shape = std::make_shared<abstract::Shape>(out_shape);
 
   // Infer type
-  auto dtype = input_args[0]->BuildType();
-  const std::set<TypeId> valid_types = {kNumberTypeFloat16, kNumberTypeFloat32};
-  CheckAndConvertUtils::CheckTensorTypeValid("x_dtype", dtype, valid_types, prim_name);
-  auto tensor_type = dtype->cast<TensorTypePtr>();
-  MS_EXCEPTION_IF_NULL(tensor_type);
-  auto element = tensor_type->element();
-  MS_EXCEPTION_IF_NULL(element);
-  auto infer_type = std::make_shared<TensorType>(TypeIdToType(element->type_id()));
-
+  const std::set<TypePtr> valid_types = {kFloat16, kFloat32};
+  auto infer_type =
+    CheckAndConvertUtils::CheckTensorTypeValid("x_dtype", input_args[0]->BuildType(), valid_types, prim_name);
   return std::make_shared<abstract::AbstractTensor>(infer_type, infer_shape->shape());
 }
 REGISTER_PRIMITIVE_C(kNameDropout, Dropout);
diff --git a/mindspore/core/ops/elu.cc b/mindspore/core/ops/elu.cc
index a477d3dafd..7e0c16b197 100644
--- a/mindspore/core/ops/elu.cc
+++ b/mindspore/core/ops/elu.cc
@@ -46,10 +46,9 @@ TypePtr InferType(const PrimitivePtr &prim, const std::vector<AbstractBasePtr> &
     MS_LOG(EXCEPTION) << "nullptr";
   }
   std::map<std::string, TypePtr> types;
-  const std::set<TypeId> valid_types = {kNumberTypeFloat16, kNumberTypeFloat32, kNumberTypeFloat64};
+  const std::set<TypePtr> valid_types = {kFloat16, kFloat32, kFloat64};
   types.emplace("x", input_args[0]->BuildType());
-  auto infer_type = CheckAndConvertUtils::CheckTensorTypeSame(types, valid_types, prim->name());
-  return TypeIdToType(infer_type);
+  return CheckAndConvertUtils::CheckTensorTypeSame(types, valid_types, prim->name());
 }
 }  // namespace
 void Elu::Init(const float alpha) { this->set_alpha(alpha); }
diff --git a/mindspore/core/ops/embedding_lookup.cc b/mindspore/core/ops/embedding_lookup.cc
index 501b220674..1ad5824088 100644
--- a/mindspore/core/ops/embedding_lookup.cc
+++ b/mindspore/core/ops/embedding_lookup.cc
@@ -45,14 +45,9 @@ AbstractBasePtr EmbeddingLookupInfer(const abstract::AnalysisEnginePtr &, const
   MS_EXCEPTION_IF_NULL(params);
   auto indices = input_args[1]->cast<abstract::AbstractTensorPtr>();
   MS_EXCEPTION_IF_NULL(indices);
-  const std::set<TypeId> int_valid_types = {kNumberTypeInt8, kNumberTypeInt16, kNumberTypeInt32, kNumberTypeInt64};
-  CheckAndConvertUtils::CheckTensorTypeValid("indices type", indices->BuildType(), int_valid_types, prim_name);
-  MS_EXCEPTION_IF_NULL(input_args[2]->BuildType());
-  auto offset_type = input_args[2]->BuildType()->type_id();
-  if (int_valid_types.find(offset_type) == int_valid_types.end()) {
-    MS_LOG(EXCEPTION) << "offset must be int.";
-  }
-
+  const std::set<TypePtr> int_valid_types = {kInt8, kInt16, kInt32, kInt64};
+  (void)CheckAndConvertUtils::CheckTensorTypeValid("indices type", indices->BuildType(), int_valid_types, prim_name);
+  (void)CheckAndConvertUtils::CheckTensorTypeValid("offset", input_args[2]->BuildType(), int_valid_types, prim_name);
   MS_EXCEPTION_IF_NULL(params->shape());
   auto params_shp = params->shape()->shape();
   MS_EXCEPTION_IF_NULL(indices->shape());
diff --git a/mindspore/core/ops/equal.cc b/mindspore/core/ops/equal.cc
index 58ce9f3be6..f44d0742c7 100644
--- a/mindspore/core/ops/equal.cc
+++ b/mindspore/core/ops/equal.cc
@@ -42,8 +42,7 @@ TypePtr InferType(const PrimitivePtr &prim, const std::vector<AbstractBasePtr> &
   std::map<std::string, TypePtr> types;
   types.emplace("x", input_args[0]->BuildType());
   types.emplace("y", input_args[1]->BuildType());
-  auto infer_type = CheckAndConvertUtils::CheckTensorTypeSame(types, common_valid_types, prim->name());
-  return TypeIdToType(infer_type);
+  return CheckAndConvertUtils::CheckTensorTypeSame(types, common_valid_types, prim->name());
 }
 }  // namespace
diff --git a/mindspore/core/ops/exp.cc b/mindspore/core/ops/exp.cc
index 6868b90441..c59de0b67a 100644
--- a/mindspore/core/ops/exp.cc
+++ b/mindspore/core/ops/exp.cc
@@ -39,8 +39,7 @@ TypePtr InferType(const PrimitivePtr &prim, const std::vector<AbstractBasePtr> &
   }
   std::map<std::string, TypePtr> types;
   types.emplace("x", input_args[0]->BuildType());
-  auto infer_type = CheckAndConvertUtils::CheckTensorTypeSame(types, common_valid_types, prim->name());
-  return TypeIdToType(infer_type);
+  return CheckAndConvertUtils::CheckTensorTypeSame(types, common_valid_types, prim->name());
 }
 }  // namespace
 AbstractBasePtr ExpInfer(const abstract::AnalysisEnginePtr &, const PrimitivePtr &primitive,
diff --git a/mindspore/core/ops/expand_dims.cc b/mindspore/core/ops/expand_dims.cc
index fc965e3cb4..62d9ad5f68 100644
--- a/mindspore/core/ops/expand_dims.cc
+++ b/mindspore/core/ops/expand_dims.cc
@@ -50,10 +50,10 @@ AbstractBasePtr ExpandDimsInfer(const abstract::AnalysisEnginePtr &, const Primi
   out_shape.insert(out_shape.begin() + dim_val, 1, 1);
 
   // Infer type
-  auto x_type = input_args[0]->BuildType()->cast<TensorTypePtr>()->element();
-  std::set<TypePtr> valid_x_type = {TypeIdToType(kObjectTypeTensorType)};
+  auto x_type = input_args[0]->BuildType()->cast<TensorTypePtr>();
+  std::set<TypePtr> valid_x_type = {kTensorType};
   CheckAndConvertUtils::CheckSubClass("x_type", x_type, valid_x_type, prim_name);
-  return std::make_shared<abstract::AbstractTensor>(x_type, out_shape);
+  return std::make_shared<abstract::AbstractTensor>(x_type->element(), out_shape);
 }
 REGISTER_PRIMITIVE_C(kNameExpandDims, ExpandDims);
 }  // namespace ops
diff --git a/mindspore/core/ops/fake_quant_with_min_max_vars.cc b/mindspore/core/ops/fake_quant_with_min_max_vars.cc
index 19594d71f9..3933f0647d 100644
--- a/mindspore/core/ops/fake_quant_with_min_max_vars.cc
+++ b/mindspore/core/ops/fake_quant_with_min_max_vars.cc
@@ -48,7 +48,7 @@ abstract::ShapePtr InferShape(const PrimitivePtr &primitive, const std::vector<
 
 TypePtr InferType(const PrimitivePtr &prim, const std::vector<AbstractBasePtr> &input_args) {
-  const std::set<TypeId> valid_types = {kNumberTypeFloat16, kNumberTypeFloat32};
+  const std::set<TypePtr> valid_types = {kFloat16, kFloat32};
   if (std::any_of(input_args.begin(), input_args.end(), [](AbstractBasePtr a) { return a == nullptr; })) {
     MS_LOG(EXCEPTION) << "nullptr";
   }
@@ -56,8 +56,7 @@ TypePtr InferType(const PrimitivePtr &prim, const std::vector<AbstractBasePtr> &
   types.emplace("x", input_args[0]->BuildType());
   types.emplace("min", input_args[1]->BuildType());
   types.emplace("max", input_args[2]->BuildType());
-  auto infer_type = CheckAndConvertUtils::CheckTensorTypeSame(types, valid_types, prim->name());
-  return TypeIdToType(infer_type);
+  return CheckAndConvertUtils::CheckTensorTypeSame(types, valid_types, prim->name());
 }
 }  // namespace
 void FakeQuantWithMinMaxVars::Init(const bool narrow_range, const int64_t num_bits) {
diff --git a/mindspore/core/ops/fake_quant_with_min_max_vars_per_channel.cc b/mindspore/core/ops/fake_quant_with_min_max_vars_per_channel.cc
index 681c587f00..80dc976c41 100644
--- a/mindspore/core/ops/fake_quant_with_min_max_vars_per_channel.cc
+++ b/mindspore/core/ops/fake_quant_with_min_max_vars_per_channel.cc
@@ -60,8 +60,7 @@ AbstractBasePtr FakeQuantWithMinMaxVarsPerChannelInfer(const abstract::AnalysisE
   std::vector<std::string> type_name = {"x", "min", "max"};
   std::vector<TypePtr> type = {x_type, min_type, max_type};
   for (int64_t i = 0; i < 3; i++) {
-    CheckAndConvertUtils::CheckTensorTypeValid(type_name[i], type[i], {kNumberTypeFloat16, kNumberTypeFloat32},
-                                               op_name);
+    (void)CheckAndConvertUtils::CheckTensorTypeValid(type_name[i], type[i], {kFloat16, kFloat32}, op_name);
   }
   auto tensor_type = x_type->cast<TensorTypePtr>();
   MS_EXCEPTION_IF_NULL(tensor_type);
diff --git a/mindspore/core/ops/fft_imag.cc b/mindspore/core/ops/fft_imag.cc
index bea9075290..fb2545041c 100644
--- a/mindspore/core/ops/fft_imag.cc
+++ b/mindspore/core/ops/fft_imag.cc
@@ -36,7 +36,7 @@ TypePtr InferType(const PrimitivePtr &prim, const std::vector<AbstractBasePtr> &
   for (const auto &item : input_args) {
     MS_EXCEPTION_IF_NULL(item);
   }
-  return TypeIdToType(kNumberTypeFloat32);
+  return kFloat32;
 }
 }  // namespace
diff --git a/mindspore/core/ops/fill.cc b/mindspore/core/ops/fill.cc
index d332034b49..9578aa6a9b 100644
--- a/mindspore/core/ops/fill.cc
+++ b/mindspore/core/ops/fill.cc
@@ -37,8 +37,8 @@ AbstractBasePtr FillInfer(const abstract::AnalysisEnginePtr &, const PrimitivePt
   auto dtype = dtype_value->cast<TypePtr>();
   MS_EXCEPTION_IF_NULL(dtype);
   auto valid_types = common_valid_types;
-  valid_types.insert(kNumberTypeBool);
-  CheckAndConvertUtils::CheckTypeSame("output datatype", dtype, valid_types, prim_name);
+  valid_types.insert(kBool);
+  (void)CheckAndConvertUtils::CheckTypeValid("output datatype", dtype, valid_types, prim_name);
   auto out_shape = GetValue<std::vector<int64_t>>(input_args[1]->BuildValue());
   auto x_type = input_args[2]->BuildType();
   auto x_type_id = x_type->type_id();
diff --git a/mindspore/core/ops/flatten.cc b/mindspore/core/ops/flatten.cc
index 7d880e5844..9c09f8f7c9 100644
--- a/mindspore/core/ops/flatten.cc
+++ b/mindspore/core/ops/flatten.cc
@@ -42,7 +42,7 @@ TypePtr InferType(const PrimitivePtr &prim, const std::vector<AbstractBasePtr> &
     MS_EXCEPTION_IF_NULL(item);
   }
   auto infer_type = input_args[0]->BuildType()->cast<TensorTypePtr>()->element();
-  const std::set<TypePtr> valid_types = {TypeIdToType(kObjectTypeTensorType)};
+  const std::set<TypePtr> valid_types = {kTensorType};
   CheckAndConvertUtils::CheckSubClass("infer type", input_args[0]->BuildType(), valid_types, prim->name());
   return infer_type;
 }
diff --git a/mindspore/core/ops/floor.cc b/mindspore/core/ops/floor.cc
index eb0d5e41d3..56bc0c6370 100644
--- a/mindspore/core/ops/floor.cc
+++ b/mindspore/core/ops/floor.cc
@@ -39,14 +39,13 @@ abstract::ShapePtr InferShape(const PrimitivePtr &primitive, const std::vector<
 
 TypePtr InferType(const PrimitivePtr &prim, const std::vector<AbstractBasePtr> &input_args) {
-  const std::set<TypeId> valid_types = {kNumberTypeFloat16, kNumberTypeFloat32, kNumberTypeFloat64};
-  if (std::any_of(input_args.begin(), input_args.end(), [](AbstractBasePtr a) { return a == nullptr; })) {
+  const std::set<TypePtr> valid_types = {kFloat16, kFloat32, kFloat64};
+  if (std::any_of(input_args.begin(), input_args.end(), [](const AbstractBasePtr &a) { return a == nullptr; })) {
     MS_LOG(EXCEPTION) << "nullptr";
   }
   std::map<std::string, TypePtr> types;
   types.emplace("x", input_args[0]->BuildType());
-  auto infer_type = CheckAndConvertUtils::CheckTensorTypeSame(types, valid_types, prim->name());
-  return TypeIdToType(infer_type);
+  return CheckAndConvertUtils::CheckTensorTypeSame(types, valid_types, prim->name());
 }
 }  // namespace
 AbstractBasePtr FloorInfer(const abstract::AnalysisEnginePtr &, const PrimitivePtr &primitive,
diff --git a/mindspore/core/ops/fusion/add_fusion.cc b/mindspore/core/ops/fusion/add_fusion.cc
index 9be0433111..384c2f3007 100644
--- a/mindspore/core/ops/fusion/add_fusion.cc
+++ b/mindspore/core/ops/fusion/add_fusion.cc
@@ -52,8 +52,7 @@ TypePtr InferType(const PrimitivePtr &prim, const std::vector<AbstractBasePtr> &
   std::map<std::string, TypePtr> types;
   types.emplace("x", input_args[0]->BuildType());
   types.emplace("y", input_args[1]->BuildType());
-  auto infer_type = CheckAndConvertUtils::CheckTensorTypeSame(types, common_valid_types, prim->name());
-  return TypeIdToType(infer_type);
+  return CheckAndConvertUtils::CheckTensorTypeSame(types, common_valid_types, prim->name());
 }
 }  // namespace
diff --git a/mindspore/core/ops/fusion/pow_fusion.cc b/mindspore/core/ops/fusion/pow_fusion.cc
index 5566d84d07..9fadab92b7 100644
--- a/mindspore/core/ops/fusion/pow_fusion.cc
+++ b/mindspore/core/ops/fusion/pow_fusion.cc
@@ -50,8 +50,7 @@ TypePtr InferType(const PrimitivePtr &prim, const std::vector<AbstractBasePtr> &
   std::map<std::string, TypePtr> types;
   types.emplace("x", input_args[0]->BuildType());
   types.emplace("y", input_args[1]->BuildType());
-  auto infer_type = CheckAndConvertUtils::CheckTensorTypeSame(types, common_valid_types, prim->name());
-  return TypeIdToType(infer_type);
+  return CheckAndConvertUtils::CheckTensorTypeSame(types, common_valid_types, prim->name());
 }
 }  // namespace
diff --git a/mindspore/core/ops/gather.cc b/mindspore/core/ops/gather.cc
index ed61f84ad4..31ec4947b5 100644
--- a/mindspore/core/ops/gather.cc
+++ b/mindspore/core/ops/gather.cc
@@ -27,12 +27,12 @@ AbstractBasePtr GatherInfer(const abstract::AnalysisEnginePtr &, const Primitive
   CheckAndConvertUtils::CheckInteger("gather_infer", input_args.size(), kEqual, 3, prim_name);
 
   // Infer type
-  auto x_type = input_args[0]->BuildType()->cast<TensorTypePtr>()->element();
-  std::set<TypePtr> valid_x_type = {TypeIdToType(kObjectTypeTensorType)};
-  CheckAndConvertUtils::CheckSubClass("x_type", input_args[0]->BuildType(), valid_x_type, prim_name);
-  const std::set<TypeId> valid_index_types = {kNumberTypeInt32, kNumberTypeInt64};
+  std::set<TypePtr> valid_x_type = {kTensorType};
+  auto x_type =
+    CheckAndConvertUtils::CheckTensorTypeValid("x_type", input_args[0]->BuildType(), valid_x_type, prim_name);
+  std::set<TypePtr> valid_index_types = {kInt32, kInt64};
   CheckAndConvertUtils::CheckTensorTypeValid("index_type", input_args[2]->BuildType(), valid_index_types, prim_name);
-  std::set<TypePtr> valid_dim_type = {TypeIdToType(kNumberTypeInt32), TypeIdToType(kNumberTypeInt64)};
+  std::set<TypePtr> valid_dim_type = {kInt32, kInt64};
   CheckAndConvertUtils::CheckSubClass("dim_type", input_args[1]->BuildType(), valid_dim_type, prim_name);
 
   // Infer shape
diff --git a/mindspore/core/ops/gather_nd.cc b/mindspore/core/ops/gather_nd.cc
index 3772b6e1fc..12760e195e 100644
--- a/mindspore/core/ops/gather_nd.cc
+++ b/mindspore/core/ops/gather_nd.cc
@@ -52,14 +52,13 @@ abstract::ShapePtr InferShape(const PrimitivePtr &primitive, const std::vector<
 
 TypePtr InferType(const PrimitivePtr &prim, const std::vector<AbstractBasePtr> &input_args) {
-  const std::set<TypeId> valid_types = {kNumberTypeInt8, kNumberTypeInt16, kNumberTypeInt32, kNumberTypeInt64};
-  if (std::any_of(input_args.begin(), input_args.end(), [](AbstractBasePtr a) { return a == nullptr; })) {
+  const std::set<TypePtr> valid_types = {kInt8, kInt16, kInt32, kInt64};
+  if (std::any_of(input_args.begin(), input_args.end(), [](const AbstractBasePtr &a) { return a == nullptr; })) {
     MS_LOG(EXCEPTION) << "nullptr";
   }
   std::map<std::string, TypePtr> types;
   types.emplace("input_x", input_args[0]->BuildType());
-  auto infer_type = CheckAndConvertUtils::CheckTensorTypeSame(types, valid_types, prim->name());
-  return TypeIdToType(infer_type);
+  return CheckAndConvertUtils::CheckTensorTypeSame(types, valid_types, prim->name());
 }
 }  // namespace
 AbstractBasePtr GatherNdInfer(const abstract::AnalysisEnginePtr &, const PrimitivePtr &primitive,
diff --git a/mindspore/core/ops/gelu.cc b/mindspore/core/ops/gelu.cc
index d4fe5dea1a..e9a429d424 100644
--- a/mindspore/core/ops/gelu.cc
+++ b/mindspore/core/ops/gelu.cc
@@ -39,11 +39,10 @@ TypePtr GeLUInferType(const PrimitivePtr &prim, const std::vector<AbstractBaseP
-  const std::set<TypeId> valid_types = {kNumberTypeFloat16, kNumberTypeFloat32};
+  const std::set<TypePtr> valid_types = {kFloat16, kFloat32};
   std::map<std::string, TypePtr> types;
   types.emplace("input_x", input_args[0]->BuildType());
-  auto infer_type = CheckAndConvertUtils::CheckTensorTypeSame(types, valid_types, prim->name());
-  return TypeIdToType(infer_type);
+  return CheckAndConvertUtils::CheckTensorTypeSame(types, valid_types, prim->name());
 }
 }  // namespace
 AbstractBasePtr GeLUInfer(const abstract::AnalysisEnginePtr &, const PrimitivePtr &primitive,
diff --git a/mindspore/core/ops/grad/binary_cross_entropy_grad.cc b/mindspore/core/ops/grad/binary_cross_entropy_grad.cc
index db167d4f2f..fd20e41e70 100644
--- a/mindspore/core/ops/grad/binary_cross_entropy_grad.cc
+++ b/mindspore/core/ops/grad/binary_cross_entropy_grad.cc
@@ -44,7 +44,7 @@ TypePtr BinaryCrossEntroyGradInferType(const PrimitivePtr &prim, const std::vect
   for (const auto &item : input_args) {
     MS_EXCEPTION_IF_NULL(item);
   }
-  const std::set<TypeId> valid_types = {kNumberTypeFloat16, kNumberTypeFloat32};
+  const std::set<TypePtr> valid_types = {kFloat16, kFloat32};
   std::map<std::string, TypePtr> types;
   types.emplace("x_shape", input_args[0]->BuildType());
   types.emplace("y_shape", input_args[1]->BuildType());
@@ -54,7 +54,7 @@ TypePtr BinaryCrossEntroyGradInferType(const PrimitivePtr &prim, const std::vect
     types.emplace("weight_shape", input_args[2]->BuildType());
     infer_type = CheckAndConvertUtils::CheckTensorTypeSame(types, valid_types, prim->name());
   }
-  return TypeIdToType(infer_type);
+  return infer_type;
 }
 }  // namespace
 void BinaryCrossEntropyGrad::Init(const Reduction &reduction) { set_reduction(reduction); }
diff --git a/mindspore/core/ops/grad/conv2d_backprop_filter.cc b/mindspore/core/ops/grad/conv2d_backprop_filter.cc
index c290c0c359..c6815c5e0d 100644
--- a/mindspore/core/ops/grad/conv2d_backprop_filter.cc
+++ b/mindspore/core/ops/grad/conv2d_backprop_filter.cc
@@ -36,12 +36,11 @@ TypePtr Conv2DBackpropFilterInferType(const PrimitivePtr &prim, const std::vecto
   for (const auto &item : input_args) {
     MS_EXCEPTION_IF_NULL(item);
   }
-  const std::set<TypeId> valid_types = {kNumberTypeInt8, kNumberTypeInt32, kNumberTypeFloat16, kNumberTypeFloat32};
+  const std::set<TypePtr> valid_types = {kInt8, kInt32, kFloat16, kFloat32};
   std::map<std::string, TypePtr> types;
   types.emplace("drotput", input_args[0]->BuildType());
   types.emplace("input_x", input_args[1]->BuildType());
-  auto infer_type = CheckAndConvertUtils::CheckTensorTypeSame(types, valid_types, prim->name());
-  return TypeIdToType(infer_type);
+  return CheckAndConvertUtils::CheckTensorTypeSame(types, valid_types, prim->name());
 }
 }  // namespace
diff --git a/mindspore/core/ops/grad/conv2d_backprop_input.cc b/mindspore/core/ops/grad/conv2d_backprop_input.cc
index b801164cac..cfddf16230 100644
--- a/mindspore/core/ops/grad/conv2d_backprop_input.cc
+++ b/mindspore/core/ops/grad/conv2d_backprop_input.cc
@@ -28,7 +28,7 @@ AbstractBasePtr Conv2DBackpropInputInfer(const abstract::AnalysisEnginePtr &, co
   MS_EXCEPTION_IF_NULL(primitive);
   auto prim_name = primitive->name();
   CheckAndConvertUtils::CheckInteger("input number", input_args.size(), kEqual, 3, prim_name);
-  for (auto item : input_args) {
+  for (const auto &item : input_args) {
     MS_EXCEPTION_IF_NULL(item);
   }
   auto doutput = input_args[0];
diff --git a/mindspore/core/ops/grad/dropout_grad.cc b/mindspore/core/ops/grad/dropout_grad.cc
index dd03cf9368..d73fa9ea05 100644
--- a/mindspore/core/ops/grad/dropout_grad.cc
+++ b/mindspore/core/ops/grad/dropout_grad.cc
@@ -49,8 +49,8 @@ TypePtr DropoutGradInferType(const PrimitivePtr &prim, const std::vector<Abstrac
   auto op_name = prim->name();
   auto mask_dtype = input_args[1]->BuildType();
   auto dy_dtype = input_args[0]->BuildType();
-  CheckAndConvertUtils::CheckSubClass("mask", mask_dtype, {TypeIdToType(kObjectTypeTensorType)}, op_name);
-  CheckAndConvertUtils::CheckTensorTypeValid("dy", dy_dtype, {kNumberTypeFloat16, kNumberTypeFloat32}, op_name);
+  CheckAndConvertUtils::CheckTensorTypeValid("mask", mask_dtype, {kTensorType}, op_name);
+  CheckAndConvertUtils::CheckTensorTypeValid("dy", dy_dtype, {kFloat16, kFloat32}, op_name);
   auto tensor_type = dy_dtype->cast<TensorTypePtr>();
   MS_EXCEPTION_IF_NULL(tensor_type);
   auto data_type = tensor_type->element();
diff --git a/mindspore/core/ops/grad/sigmoid_cross_entropy_with_logits_grad.cc b/mindspore/core/ops/grad/sigmoid_cross_entropy_with_logits_grad.cc
index 47b2748e02..bd559e3798 100644
--- a/mindspore/core/ops/grad/sigmoid_cross_entropy_with_logits_grad.cc
+++ b/mindspore/core/ops/grad/sigmoid_cross_entropy_with_logits_grad.cc
@@ -44,18 +44,13 @@ AbstractBasePtr SigmoidCrossEntropyWithLogitsGradInfer(const abstract::AnalysisE
   CheckAndConvertUtils::Check("x_shape", x_shape, kEqual, "dout_shape", dout_shape, prim_name, TypeError);
 
   // Infer type
-  const std::set<TypeId> valid_types = {
-    kNumberTypeBool,    kNumberTypeInt,     kNumberTypeInt8,    kNumberTypeInt16,
-    kNumberTypeInt32,   kNumberTypeInt64,   kNumberTypeUInt,    kNumberTypeUInt8,
-    kNumberTypeUInt16,  kNumberTypeUInt32,  kNumberTypeUInt64,  kNumberTypeFloat,
-    kNumberTypeFloat16, kNumberTypeFloat32, kNumberTypeFloat64, kNumberTypeComplex64};
+  const std::set<TypePtr> valid_types = {kBool,   kInt,    kInt8,   kInt16, kInt32,   kInt64,   kUInt,    kUInt8,
+                                         kUInt16, kUInt32, kUInt64, kFloat, kFloat16, kFloat32, kFloat64, kComplex64};
   std::map<std::string, TypePtr> args;
   args.emplace("x_type", input_args[0]->BuildType());
   args.emplace("y_type", input_args[1]->BuildType());
   args.emplace("dout_type", input_args[2]->BuildType());
-  CheckAndConvertUtils::CheckTensorTypeSame(args, valid_types, prim_name);
-  auto dout_type = input_args[2]->BuildType()->cast<TensorTypePtr>()->element();
-
+  auto dout_type = CheckAndConvertUtils::CheckTensorTypeSame(args, valid_types, prim_name);
   return std::make_shared<abstract::AbstractTensor>(dout_type, x_shape);
 }
 REGISTER_PRIMITIVE_C(kNameSigmoidCrossEntropyWithLogitsGrad, SigmoidCrossEntropyWithLogitsGrad);
diff --git a/mindspore/core/ops/grad/smooth_l1_loss_grad.cc b/mindspore/core/ops/grad/smooth_l1_loss_grad.cc
index f0e19b48c5..26585f34ed 100644
--- a/mindspore/core/ops/grad/smooth_l1_loss_grad.cc
+++ b/mindspore/core/ops/grad/smooth_l1_loss_grad.cc
@@ -49,17 +49,13 @@ AbstractBasePtr SmoothL1LossGradInfer(const abstract::AnalysisEnginePtr &, const
   CheckAndConvertUtils::Check("prediction shape", prediction, kEqual, "dloss", dloss, prim_name, TypeError);
 
   // Infer type
-  const std::set<TypeId> valid_types = {
-    kNumberTypeBool,    kNumberTypeInt,     kNumberTypeInt8,    kNumberTypeInt16,
-    kNumberTypeInt32,   kNumberTypeInt64,   kNumberTypeUInt,    kNumberTypeUInt8,
-    kNumberTypeUInt16,  kNumberTypeUInt32,  kNumberTypeUInt64,  kNumberTypeFloat,
-    kNumberTypeFloat16, kNumberTypeFloat32, kNumberTypeFloat64, kNumberTypeComplex64};
+  const std::set<TypePtr> valid_types = {kBool,   kInt,    kInt8,   kInt16, kInt32,   kInt64,   kUInt,    kUInt8,
+                                         kUInt16, kUInt32, kUInt64, kFloat, kFloat16, kFloat32, kFloat64, kComplex64};
   std::map<std::string, TypePtr> args;
   args.emplace("prediction", input_args[0]->BuildType());
   args.emplace("target", input_args[1]->BuildType());
   args.emplace("dloss", input_args[2]->BuildType());
-  CheckAndConvertUtils::CheckTensorTypeSame(args, valid_types, prim_name);
-  auto dloss_type = input_args[2]->BuildType()->cast<TensorTypePtr>()->element();
+  auto dloss_type = CheckAndConvertUtils::CheckTensorTypeSame(args, valid_types, prim_name);
 
   return std::make_shared<abstract::AbstractTensor>(dloss_type, prediction);
 }
diff --git a/mindspore/core/ops/hashtable_lookup.cc b/mindspore/core/ops/hashtable_lookup.cc
index 0e8171f828..c7496d1305 100644
--- a/mindspore/core/ops/hashtable_lookup.cc
+++ b/mindspore/core/ops/hashtable_lookup.cc
@@ -41,7 +41,7 @@ AbstractBasePtr HashtableLookupInfer(const abstract::AnalysisEnginePtr &, const
   auto data_type = tensor_type->element();
   std::vector<int64_t> value_shape;
   auto output = std::make_shared<abstract::AbstractTensor>(data_type, value_shape);
-  auto hits = std::make_shared<abstract::AbstractTensor>(TypeIdToType(kNumberTypeInt8), hits_shape);
+  auto hits = std::make_shared<abstract::AbstractTensor>(kInt8, hits_shape);
   AbstractBasePtrList output1 = {output, hits};
 
   if (input_args[0]->BuildValue()->cast<tensor::TensorPtr>()->data_c() == nullptr) {
diff --git a/mindspore/core/ops/l2_normalize.cc b/mindspore/core/ops/l2_normalize.cc
index 1c58c1f8ee..08d4ce599d 100644
--- a/mindspore/core/ops/l2_normalize.cc
+++ b/mindspore/core/ops/l2_normalize.cc
@@ -49,8 +49,8 @@ AbstractBasePtr L2NormalizeInfer(const abstract::AnalysisEnginePtr &, const Prim
   for (const auto &item : input_args) {
     MS_EXCEPTION_IF_NULL(item);
   }
-  const std::set<TypeId> valid_types = {kNumberTypeFloat16, kNumberTypeFloat32};
-  CheckAndConvertUtils::CheckTensorTypeValid("input_x", input_args[0]->BuildType(), valid_types, prim_name);
+  const std::set<TypePtr> valid_types = {kFloat16, kFloat32};
+  (void)CheckAndConvertUtils::CheckTensorTypeValid("input_x", input_args[0]->BuildType(), valid_types, prim_name);
   auto x_shape = CheckAndConvertUtils::ConvertShapePtrToShape("x_shape", input_args[0]->BuildShape(), prim_name);
   auto x_rank = SizeToLong(x_shape.size());
   auto axiss = prim->get_axis();
diff --git a/mindspore/core/ops/leaky_relu.cc b/mindspore/core/ops/leaky_relu.cc
index ee5e28f009..54f8fb78b2 100644
--- a/mindspore/core/ops/leaky_relu.cc
+++ b/mindspore/core/ops/leaky_relu.cc
@@ -35,8 +35,7 @@ TypePtr InferType(const PrimitivePtr &prim, const std::vector<AbstractBasePtr> &
   }
   std::map<std::string, TypePtr> types;
   types.emplace("x", input_args[0]->BuildType());
-  auto infer_type = CheckAndConvertUtils::CheckTensorTypeSame(types, common_valid_types, prim->name());
-  return TypeIdToType(infer_type);
+  return CheckAndConvertUtils::CheckTensorTypeSame(types, common_valid_types, prim->name());
 }
 }  // namespace
 void LeakyRelu::Init(const float negative_slope) { this->set_negative_slope(negative_slope); }
diff --git a/mindspore/core/ops/less.cc b/mindspore/core/ops/less.cc
index 56cfeb67d8..6c5fc4e31a 100644
--- a/mindspore/core/ops/less.cc
+++ b/mindspore/core/ops/less.cc
@@ -41,7 +41,7 @@ TypePtr InferType(const PrimitivePtr &prim, const std::vector<AbstractBasePtr> &
   types.emplace("x", input_args[0]->BuildType());
   types.emplace("y", input_args[1]->BuildType());
   CheckAndConvertUtils::CheckTensorTypeSame(types, common_valid_types, prim->name());
-  return TypeIdToType(kNumberTypeBool);
+  return kBool;
 }
 }  // namespace
diff --git a/mindspore/core/ops/less_equal.cc b/mindspore/core/ops/less_equal.cc
index 228b75caf9..0a8c87f664 100644
--- a/mindspore/core/ops/less_equal.cc
+++ b/mindspore/core/ops/less_equal.cc
@@ -41,8 +41,7 @@ TypePtr InferType(const PrimitivePtr &prim, const std::vector<AbstractBasePtr> &
   std::map<std::string, TypePtr> types;
   types.emplace("x", input_args[0]->BuildType());
   types.emplace("y", input_args[1]->BuildType());
-  auto infer_type = CheckAndConvertUtils::CheckTensorTypeSame(types, common_valid_types, prim->name());
-  return TypeIdToType(infer_type);
+  return CheckAndConvertUtils::CheckTensorTypeSame(types, common_valid_types, prim->name());
 }
 }  // namespace
diff --git a/mindspore/core/ops/local_response_normalization.cc b/mindspore/core/ops/local_response_normalization.cc
index 5165de2e75..3896bf63b1 100644
--- a/mindspore/core/ops/local_response_normalization.cc
+++ b/mindspore/core/ops/local_response_normalization.cc
@@ -43,8 +43,7 @@ TypePtr InferType(const PrimitivePtr &prim, const std::vector<AbstractBasePtr> &
   }
   std::map<std::string, TypePtr> types;
   types.emplace("x", input_args[0]->BuildType());
-  auto infer_type = CheckAndConvertUtils::CheckTensorTypeSame(types, common_valid_types, prim->name());
-  return TypeIdToType(infer_type);
+  return CheckAndConvertUtils::CheckTensorTypeSame(types, common_valid_types, prim->name());
 }
 }  // namespace
diff --git a/mindspore/core/ops/log.cc b/mindspore/core/ops/log.cc
index c9617cf125..f64804c2d0 100644
--- a/mindspore/core/ops/log.cc
+++ b/mindspore/core/ops/log.cc
@@ -29,10 +29,9 @@ abstract::ShapePtr InferShape(const PrimitivePtr &primitive, const std::vector<
 
 TypePtr InferType(const PrimitivePtr &prim, const std::vector<AbstractBasePtr> &input_args) {
-  auto infer_type = input_args[0]->BuildType()->cast<TensorTypePtr>()->element();
-  const std::set<TypePtr> valid_types = {TypeIdToType(kObjectTypeTensorType)};
-  CheckAndConvertUtils::CheckSubClass("infer type", input_args[0]->BuildType(), valid_types, prim->name());
-  return infer_type;
+  const std::set<TypePtr> valid_types = {kTensorType};
+  return CheckAndConvertUtils::CheckTensorTypeValid("infer type", input_args[0]->BuildType(), valid_types,
+                                                    prim->name());
 }
 }  // namespace
diff --git a/mindspore/core/ops/logical_and.cc b/mindspore/core/ops/logical_and.cc
index fb33d56023..4db682377e 100644
--- a/mindspore/core/ops/logical_and.cc
+++ b/mindspore/core/ops/logical_and.cc
@@ -39,14 +39,10 @@ TypePtr InferType(const PrimitivePtr &prim, const std::vector<AbstractBasePtr> &
     MS_EXCEPTION_IF_NULL(item);
   }
   std::map<std::string, TypePtr> types;
-  const std::set<TypeId> valid_types = {kNumberTypeBool};
+  const std::set<TypePtr> valid_types = {kBool};
   types.emplace("x", input_args[0]->BuildType());
   types.emplace("y", input_args[1]->BuildType());
-  auto infer_type = CheckAndConvertUtils::CheckTensorTypeSame(types, valid_types, prim->name());
-  if (infer_type == kNumberTypeBool) {
-    return TypeIdToType(infer_type);
-  }
-  return std::make_shared<TensorType>(TypeIdToType(kNumberTypeBool));
+  return CheckAndConvertUtils::CheckTensorTypeSame(types, valid_types, prim->name());
 }
 }  // namespace
diff --git a/mindspore/core/ops/logical_not.cc b/mindspore/core/ops/logical_not.cc
index 5581a835b1..f6dfc6d2b9 100644
--- a/mindspore/core/ops/logical_not.cc
+++ b/mindspore/core/ops/logical_not.cc
@@ -37,12 +37,8 @@ TypePtr LogicalNotInferType(const PrimitivePtr &prim, const std::vector<Abstract
   auto op_name = prim->name();
   auto infer_dtype = input_args[0]->BuildType();
-  std::set<TypeId> local_bool = {kNumberTypeBool};
-  CheckAndConvertUtils::CheckTensorTypeValid("x", infer_dtype, local_bool, op_name);
-  auto tensor_type = infer_dtype->cast<TensorTypePtr>();
-  MS_EXCEPTION_IF_NULL(tensor_type);
-  auto element = tensor_type->element();
-  return element;
+  std::set<TypePtr> local_bool = {kBool};
+  return CheckAndConvertUtils::CheckTensorTypeValid("x", infer_dtype, local_bool, op_name);
 }
 }  // namespace
 AbstractBasePtr LogicalNotInfer(const abstract::AnalysisEnginePtr &, const PrimitivePtr &primitive,
diff --git a/mindspore/core/ops/logical_or.cc b/mindspore/core/ops/logical_or.cc
index 3614c994ca..e908ebb4b5 100644
--- a/mindspore/core/ops/logical_or.cc
+++ b/mindspore/core/ops/logical_or.cc
@@ -40,14 +40,10 @@ TypePtr InferType(const PrimitivePtr &prim, const std::vector<AbstractBasePtr> &
     MS_EXCEPTION_IF_NULL(item);
   }
   std::map<std::string, TypePtr> types;
-  const std::set<TypeId> valid_types = {kNumberTypeBool};
+  const std::set<TypePtr> valid_types = {kBool};
   types.emplace("x", input_args[0]->BuildType());
   types.emplace("y", input_args[1]->BuildType());
-  auto infer_type = CheckAndConvertUtils::CheckTensorTypeSame(types, valid_types, prim->name());
-  if (infer_type == kNumberTypeBool) {
-    return TypeIdToType(infer_type);
-  }
-  return std::make_shared<TensorType>(TypeIdToType(kNumberTypeBool));
+  return CheckAndConvertUtils::CheckTensorTypeSame(types, valid_types, prim->name());
 }
 }  // namespace
diff --git a/mindspore/core/ops/lrn.cc b/mindspore/core/ops/lrn.cc
index 09360cab77..7760800631 100644
--- a/mindspore/core/ops/lrn.cc
+++ b/mindspore/core/ops/lrn.cc
@@ -86,14 +86,13 @@ abstract::ShapePtr InferShape(const PrimitivePtr &primitive, const std::vector<
 
 TypePtr InferType(const PrimitivePtr &prim, const std::vector<AbstractBasePtr> &input_args) {
-  const std::set<TypeId> valid_types = {kNumberTypeFloat16, kNumberTypeFloat32};
+  const std::set<TypePtr> valid_types = {kFloat16, kFloat32};
   if (std::any_of(input_args.begin(), input_args.end(), [](AbstractBasePtr a) { return a == nullptr; })) {
     MS_LOG(EXCEPTION) << "nullptr";
   }
   std::map<std::string, TypePtr> types;
   types.emplace("x", input_args[0]->BuildType());
-  auto infer_type = CheckAndConvertUtils::CheckTensorTypeSame(types, valid_types, prim->name());
-  return TypeIdToType(infer_type);
+  return CheckAndConvertUtils::CheckTensorTypeSame(types, valid_types, prim->name());
 }
 }  // namespace
diff --git a/mindspore/core/ops/lsh_projection.cc b/mindspore/core/ops/lsh_projection.cc
index c7049817ff..dd9d1b8ebe 100644
--- a/mindspore/core/ops/lsh_projection.cc
+++ b/mindspore/core/ops/lsh_projection.cc
@@ -61,8 +61,7 @@ AbstractBasePtr LshProjectionInfer(const abstract::AnalysisEnginePtr &, const Pr
       out_shape.push_back(input0[0] * input0[1]);
       break;
   }
-  TypePtr infer_type = TypeIdToType(kNumberTypeInt32);
-  return std::make_shared<abstract::AbstractTensor>(infer_type, out_shape);
+  return std::make_shared<abstract::AbstractTensor>(kInt32, out_shape);
 }
 REGISTER_PRIMITIVE_C(kNameLshProjection, LshProjection);
 }  // namespace ops
diff --git a/mindspore/core/ops/mat_mul.cc b/mindspore/core/ops/mat_mul.cc
index 0117868bdf..53749eb77e 100644
--- a/mindspore/core/ops/mat_mul.cc
+++ b/mindspore/core/ops/mat_mul.cc
@@ -55,16 +55,11 @@ TypePtr MatMulInferType(const PrimitivePtr &prim, const std::vector<AbstractBase
-  const std::set<TypeId> valid_types = {kNumberTypeInt8,    kNumberTypeInt16,   kNumberTypeInt32,   kNumberTypeInt64,
-                                        kNumberTypeFloat16, kNumberTypeFloat32, kNumberTypeFloat64};
+  const std::set<TypePtr> valid_types = {kInt8, kInt16, kInt32, kInt64, kFloat16, kFloat32, kFloat64};
   std::map<std::string, TypePtr> types;
   types.emplace("x", input_args[0]->BuildType());
   types.emplace("w", input_args[1]->BuildType());
-  auto infer_type = CheckAndConvertUtils::CheckTensorTypeSame(types, valid_types, prim->name());
-  if (infer_type == kNumberTypeInt8) {
-    return std::make_shared<TensorType>(TypeIdToType(kNumberTypeInt32));
-  }
-  return TypeIdToType(infer_type);
+  return CheckAndConvertUtils::CheckTensorTypeSame(types, valid_types, prim->name());
 }
 }  // namespace
diff --git a/mindspore/core/ops/matrix_diag.cc b/mindspore/core/ops/matrix_diag.cc
index c074c46e0d..a97b1536ad 100644
--- a/mindspore/core/ops/matrix_diag.cc
+++ b/mindspore/core/ops/matrix_diag.cc
@@ -59,19 +59,11 @@ TypePtr InferType(const PrimitivePtr &prim, const std::vector<AbstractBasePtr> &
   for (const auto &item : input_args) {
     MS_EXCEPTION_IF_NULL(item);
  }
const std::set valid_types = {kNumberTypeInt8, kNumberTypeInt32, kNumberTypeUInt8, kNumberTypeFloat16, - kNumberTypeFloat32}; + const std::set valid_types = {kInt8, kInt32, kUInt8, kFloat16, kFloat32}; std::map types; types.emplace("x", input_args[0]->BuildType()); types.emplace("assist", input_args[1]->BuildType()); - CheckAndConvertUtils::CheckTensorTypeSame(types, valid_types, prim->name()); - auto type = input_args[0]->BuildType(); - MS_EXCEPTION_IF_NULL(type); - auto tensor_type = type->cast(); - MS_EXCEPTION_IF_NULL(tensor_type); - auto data_type = tensor_type->element(); - MS_EXCEPTION_IF_NULL(data_type); - return data_type; + return CheckAndConvertUtils::CheckTensorTypeSame(types, valid_types, prim->name()); } } // namespace diff --git a/mindspore/core/ops/maximum.cc b/mindspore/core/ops/maximum.cc index 1647569277..b212667203 100644 --- a/mindspore/core/ops/maximum.cc +++ b/mindspore/core/ops/maximum.cc @@ -38,8 +38,7 @@ TypePtr InferType(const PrimitivePtr &prim, const std::vector & std::map types; types.emplace("x", input_args[0]->BuildType()); types.emplace("y", input_args[1]->BuildType()); - auto infer_type = CheckAndConvertUtils::CheckTensorTypeSame(types, common_valid_types, prim->name()); - return TypeIdToType(infer_type); + return CheckAndConvertUtils::CheckTensorTypeSame(types, common_valid_types, prim->name()); } } // namespace diff --git a/mindspore/core/ops/merge.cc b/mindspore/core/ops/merge.cc index 7ef024794a..540fc8802a 100644 --- a/mindspore/core/ops/merge.cc +++ b/mindspore/core/ops/merge.cc @@ -38,16 +38,13 @@ AbstractBasePtr MergeInfer(const abstract::AnalysisEnginePtr &, const PrimitiveP for (int64_t i = 0; i != (int64_t)inputs_type.size(); i++) { args.insert({"input[" + std::to_string(i) + "]", inputs_type[i]}); } - std::set template_type = {kNumberTypeBool}; - for (auto item : common_valid_types) { - template_type.insert(item); - } - CheckAndConvertUtils::CheckScalarOrTensorTypesSame(args, template_type, op_name); + std::set template_type = common_valid_types; + template_type.emplace(kBool); + auto infered_type = CheckAndConvertUtils::CheckScalarOrTensorTypesSame(args, template_type, op_name); std::vector in_shape0 = inputs_shape[0]->cast()->shape(); - auto output1 = - std::make_shared(inputs_type[0]->cast()->element(), in_shape0); - auto output2 = std::make_shared(TypeIdToType(kNumberTypeInt32), std::vector{1}); + auto output1 = std::make_shared(infered_type, in_shape0); + auto output2 = std::make_shared(kInt32, std::vector{1}); AbstractBasePtrList output = {output1, output2}; return std::make_shared(output); diff --git a/mindspore/core/ops/minimum.cc b/mindspore/core/ops/minimum.cc index 6c38d1e04b..2cc820607e 100644 --- a/mindspore/core/ops/minimum.cc +++ b/mindspore/core/ops/minimum.cc @@ -42,8 +42,7 @@ TypePtr InferType(const PrimitivePtr &prim, const std::vector & std::map types; types.emplace("x", input_args[0]->BuildType()); types.emplace("y", input_args[1]->BuildType()); - auto infer_type = CheckAndConvertUtils::CheckTensorTypeSame(types, common_valid_types, prim->name()); - return TypeIdToType(infer_type); + return CheckAndConvertUtils::CheckTensorTypeSame(types, common_valid_types, prim->name()); } } // namespace diff --git a/mindspore/core/ops/neg.cc b/mindspore/core/ops/neg.cc index e90b4ed8d3..b411d0dea5 100644 --- a/mindspore/core/ops/neg.cc +++ b/mindspore/core/ops/neg.cc @@ -31,7 +31,8 @@ AbstractBasePtr NegInfer(const abstract::AnalysisEnginePtr &, const PrimitivePtr for (const auto &item : input_args) { MS_EXCEPTION_IF_NULL(item); } - 
CheckAndConvertUtils::CheckTensorTypeValid("input_x", input_args[0]->BuildType(), common_valid_types, prim_name); + (void)CheckAndConvertUtils::CheckTensorTypeValid("input_x", input_args[0]->BuildType(), common_valid_types, + prim_name); return input_args[0]->Broaden(); } REGISTER_PRIMITIVE_C(kNameNeg, Neg); diff --git a/mindspore/core/ops/non_max_suppression.cc b/mindspore/core/ops/non_max_suppression.cc index bd36a562e1..195523d5ff 100644 --- a/mindspore/core/ops/non_max_suppression.cc +++ b/mindspore/core/ops/non_max_suppression.cc @@ -36,7 +36,7 @@ AbstractBasePtr NonMaxSuppressionInfer(const abstract::AnalysisEnginePtr &, cons auto non_max_suppression_prim = primitive->cast(); MS_EXCEPTION_IF_NULL(non_max_suppression_prim); MS_LOG(INFO) << "NonMaxSuppression infer shape in runtime."; - return std::make_shared(TypeIdToType(kNumberTypeInt32), std::vector{}); + return std::make_shared(kInt32, std::vector{}); } REGISTER_PRIMITIVE_C(kNameNonMaxSuppression, NonMaxSuppression); } // namespace ops diff --git a/mindspore/core/ops/one_hot.cc b/mindspore/core/ops/one_hot.cc index 375603c420..d1e2c11410 100644 --- a/mindspore/core/ops/one_hot.cc +++ b/mindspore/core/ops/one_hot.cc @@ -53,17 +53,11 @@ TypePtr OneHotInferType(const PrimitivePtr &prim, const std::vectorcast(); MS_EXCEPTION_IF_NULL(OneHot_prim); auto op_name = OneHot_prim->name(); - CheckAndConvertUtils::CheckTensorTypeValid("indices", input_args[0]->BuildType(), {kNumberTypeInt32}, op_name); - CheckAndConvertUtils::CheckTypeSame("depth", input_args[1]->BuildType(), - {kNumberTypeInt8, kNumberTypeInt16, kNumberTypeInt32, kNumberTypeInt64}, op_name); - auto value_type = input_args[2]->BuildType(); - auto tensor_type = value_type->cast(); - MS_EXCEPTION_IF_NULL(tensor_type); - auto element = tensor_type->element(); - MS_EXCEPTION_IF_NULL(element); - std::map args = {{"on_value", value_type}, {"off_dtype", input_args[3]->BuildType()}}; - CheckAndConvertUtils::CheckTensorTypeSame(args, {kNumberTypeFloat16, kNumberTypeFloat32}, op_name); - return element; + CheckAndConvertUtils::CheckTensorTypeValid("indices", input_args[0]->BuildType(), {kInt32}, op_name); + CheckAndConvertUtils::CheckTypeValid("depth", input_args[1]->BuildType(), {kInt8, kInt16, kInt32, kInt64}, op_name); + std::map args = {{"on_value", input_args[2]->BuildType()}, + {"off_dtype", input_args[3]->BuildType()}}; + return CheckAndConvertUtils::CheckTensorTypeSame(args, {kFloat16, kFloat32}, op_name); } } // namespace AbstractBasePtr OneHotInfer(const abstract::AnalysisEnginePtr &, const PrimitivePtr &primitive, diff --git a/mindspore/core/ops/ones_like.cc b/mindspore/core/ops/ones_like.cc index 4872ed2ac9..e1c62d2f6a 100644 --- a/mindspore/core/ops/ones_like.cc +++ b/mindspore/core/ops/ones_like.cc @@ -37,13 +37,8 @@ abstract::ShapePtr InferShape(const PrimitivePtr &primitive, const std::vector &input_args) { - // const std::set valid_types = {kNumberTypeInt8, kNumberTypeInt16, kNumberTypeInt32, kNumberTypeInt64, - // kNumberTypeUInt16, kNumberTypeUInt32, kNumberTypeUInt64, - // kNumberTypeFloat16, kNumberTypeFloat32, kNumberTypeFloat64, - // kNumberTypeBool}; auto infer_type = input_args[0]->BuildType(); - CheckAndConvertUtils::CheckTensorTypeValid("infer_type", infer_type, common_valid_types, "OnesLike"); - return infer_type; + return CheckAndConvertUtils::CheckTensorTypeValid("infer_type", infer_type, common_valid_types, "OnesLike"); } } // namespace AbstractBasePtr OnesLikeInfer(const abstract::AnalysisEnginePtr &, const PrimitivePtr &primitive, diff --git 
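// Illustrative sketch (not part of the patch): with CheckTensorTypeSame now returning the
// checked TypePtr directly, a typical unary-op InferType collapses to a single call and the
// old TypeIdToType() round-trip disappears. A hypothetical op would look like:
//
//   TypePtr InferType(const PrimitivePtr &prim, const std::vector<AbstractBasePtr> &input_args) {
//     std::map<std::string, TypePtr> types;
//     types.emplace("x", input_args[0]->BuildType());
//     // Validates the input is a tensor with an element type in the set, then returns that type.
//     return CheckAndConvertUtils::CheckTensorTypeSame(types, common_valid_types, prim->name());
//   }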
diff --git a/mindspore/core/ops/op_utils.h b/mindspore/core/ops/op_utils.h
index 51436c0042..12af5c8b9a 100644
--- a/mindspore/core/ops/op_utils.h
+++ b/mindspore/core/ops/op_utils.h
@@ -230,14 +230,12 @@ constexpr auto kSpliceContext = "context";
 constexpr auto kSpliceForwardIndexes = "forward_indexes";
 constexpr auto kSpliceOutputDims = "output_dim";
-const std::set<TypeId> common_valid_types = {
-  kNumberTypeInt8,   kNumberTypeInt16,  kNumberTypeInt32,   kNumberTypeInt64,   kNumberTypeUInt8,   kNumberTypeUInt16,
-  kNumberTypeUInt32, kNumberTypeUInt64, kNumberTypeFloat16, kNumberTypeFloat32, kNumberTypeFloat64};
+const std::set<TypePtr> common_valid_types = {kInt8,   kInt16,  kInt32,   kInt64,   kUInt8,   kUInt16,
+                                              kUInt32, kUInt64, kFloat16, kFloat32, kFloat64};
 
-const std::set<TypeId> all_types = {
-  kNumberTypeBool,    kNumberTypeInt,     kNumberTypeInt8,    kNumberTypeInt16,    kNumberTypeInt32,  kNumberTypeInt64,
-  kNumberTypeUInt,    kNumberTypeUInt8,   kNumberTypeUInt16,  kNumberTypeUInt32,   kNumberTypeUInt64, kNumberTypeFloat,
-  kNumberTypeFloat16, kNumberTypeFloat32, kNumberTypeFloat64, kNumberTypeComplex64,
-};
+const std::set<TypePtr> all_types = {
+  kBool,   kInt,    kInt8,   kInt16, kInt32,   kInt64,   kUInt,    kUInt8,
+  kUInt16, kUInt32, kUInt64, kFloat, kFloat16, kFloat32, kFloat64, kComplex64,
+};
 
 abstract::ShapePtr BroadCastInferShape(const std::string &op_name, const std::vector<AbstractBasePtr> &input_args);
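// Illustrative sketch (not part of the patch): common_valid_types is now a std::set<TypePtr>
// of shared type objects (kInt8, kFloat32, ...) instead of a std::set<TypeId>, so call sites
// can copy and widen it directly, as the merge.cc and zeros_like.cc hunks in this patch do:
//
//   std::set<TypePtr> valid_types(common_valid_types);  // copy the common numeric set
//   valid_types.emplace(kBool);                         // then widen it with kBool
//   auto out_type =
//     CheckAndConvertUtils::CheckTensorTypeValid("x", input_args[0]->BuildType(), valid_types, prim->name());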
diff --git a/mindspore/core/ops/pad.cc b/mindspore/core/ops/pad.cc
index 05ffdcdb37..5961fd7486 100644
--- a/mindspore/core/ops/pad.cc
+++ b/mindspore/core/ops/pad.cc
@@ -49,10 +49,8 @@ TypePtr InferType(const PrimitivePtr &prim, const std::vector<AbstractBasePtr> &
   for (const auto &item : input_args) {
     MS_EXCEPTION_IF_NULL(item);
   }
-  const std::set<TypePtr> valid_types = {TypeIdToType(kObjectTypeTensorType)};
-  auto infer_type = input_args[0]->BuildType();
-  CheckAndConvertUtils::CheckSubClass("infer type", infer_type, valid_types, prim->name());
-  return infer_type;
+  const std::set<TypePtr> valid_types = {kTensorType};
+  return CheckAndConvertUtils::CheckSubClass("infer type", input_args[0]->BuildType(), valid_types, prim->name());
 }
 }  // namespace
diff --git a/mindspore/core/ops/pow.cc b/mindspore/core/ops/pow.cc
index faca7068a7..510c97bd4e 100644
--- a/mindspore/core/ops/pow.cc
+++ b/mindspore/core/ops/pow.cc
@@ -37,8 +37,7 @@ TypePtr InferType(const PrimitivePtr &prim, const std::vector<AbstractBasePtr> &
   std::map<std::string, TypePtr> types;
   types.emplace("x", input_args[0]->BuildType());
   types.emplace("y", input_args[1]->BuildType());
-  auto infer_type = CheckAndConvertUtils::CheckTensorTypeSame(types, common_valid_types, prim->name());
-  return TypeIdToType(infer_type);
+  return CheckAndConvertUtils::CheckTensorTypeSame(types, common_valid_types, prim->name());
 }
 }  // namespace
diff --git a/mindspore/core/ops/prelu.cc b/mindspore/core/ops/prelu.cc
index 6aa95f3440..9cd8cc98a8 100644
--- a/mindspore/core/ops/prelu.cc
+++ b/mindspore/core/ops/prelu.cc
@@ -46,13 +46,10 @@ TypePtr InferType(const PrimitivePtr &prim, const std::vector<AbstractBasePtr> &
   for (const auto &item : input_args) {
     MS_EXCEPTION_IF_NULL(item);
   }
-  const std::set<TypeId> valid_types = {kNumberTypeFloat16, kNumberTypeFloat32};
-  CheckAndConvertUtils::CheckTensorTypeValid("input_x", input_args[0]->BuildType(), valid_types, prim->name());
-  CheckAndConvertUtils::CheckTensorTypeValid("weight", input_args[1]->BuildType(), valid_types, prim->name());
-  auto tensor_type = input_args[0]->BuildType()->cast<TensorTypePtr>();
-  MS_EXCEPTION_IF_NULL(tensor_type);
-  auto input_x_type = tensor_type->element();
-  return input_x_type;
+  const std::set<TypePtr> valid_types = {kFloat16, kFloat32};
+  std::map<std::string, TypePtr> check_map = {{"input_x", input_args[0]->BuildType()},
+                                              {"weight", input_args[1]->BuildType()}};
+  return CheckAndConvertUtils::CheckTensorTypeSame(check_map, valid_types, prim->name());
 }
 }  // namespace
 AbstractBasePtr PReLUInfer(const abstract::AnalysisEnginePtr &, const PrimitivePtr &primitive,
diff --git a/mindspore/core/ops/prior_box.cc b/mindspore/core/ops/prior_box.cc
index c5cdb1d637..edb3057226 100644
--- a/mindspore/core/ops/prior_box.cc
+++ b/mindspore/core/ops/prior_box.cc
@@ -143,7 +143,7 @@ AbstractBasePtr PriorBoxInfer(const abstract::AnalysisEnginePtr &, const Primiti
   auto input = CheckAndConvertUtils::ConvertShapePtrToShape("input_shape", input_args[0]->BuildShape(), op_name);
   int64_t h = input[0] * input[1] * num_priors_box * 4;
   std::vector<int64_t> output_shape{1, h, 1, 2};
-  return std::make_shared<abstract::AbstractTensor>(TypeIdToType(kNumberTypeFloat32), output_shape);
+  return std::make_shared<abstract::AbstractTensor>(kFloat32, output_shape);
 }
 REGISTER_PRIMITIVE_C(kNamePriorBox, PriorBox);
 }  // namespace ops
diff --git a/mindspore/core/ops/range.cc b/mindspore/core/ops/range.cc
index 61ee0f9da2..7e866d8d28 100644
--- a/mindspore/core/ops/range.cc
+++ b/mindspore/core/ops/range.cc
@@ -100,12 +100,11 @@ AbstractBasePtr RangeInfer(const abstract::AnalysisEnginePtr &, const PrimitiveP
     int64_t start = prim->get_start();
     int64_t limit = prim->get_limit();
     int64_t delta = prim->get_delta();
-    dtype = kNumberTypeInt32;
     shape_size = std::max(static_cast<int64_t>(std::ceil(LongToDouble(limit - start) / delta)),
                           static_cast<int64_t>(0));
   }
   return std::make_shared<abstract::AbstractTensor>(
-    TypeIdToType(dtype), std::make_shared<abstract::Shape>(std::vector<int64_t>{shape_size}));
+    kInt32, std::make_shared<abstract::Shape>(std::vector<int64_t>{shape_size}));
 }
 REGISTER_PRIMITIVE_C(kNameRange, Range);
 }  // namespace ops
diff --git a/mindspore/core/ops/rank.cc b/mindspore/core/ops/rank.cc
index 6f066f151e..b10e324bee 100644
--- a/mindspore/core/ops/rank.cc
+++ b/mindspore/core/ops/rank.cc
@@ -25,8 +25,8 @@ TypePtr RankInferType(const PrimitivePtr &prim, const std::vector<AbstractBasePt
   auto op_name = prim->name();
   auto infer_dtype = input_args[0]->BuildType();
-  CheckAndConvertUtils::CheckSubClass("x", infer_dtype, {TypeIdToType(kObjectTypeTensorType)}, op_name);
-  return TypeIdToType(kMetaTypeNone);
+  CheckAndConvertUtils::CheckTensorTypeValid("x", infer_dtype, {kTensorType}, op_name);
+  return kTypeNone;
 }
 }  // namespace
 AbstractBasePtr RankInfer(const abstract::AnalysisEnginePtr &, const PrimitivePtr &primitive,
diff --git a/mindspore/core/ops/real_div.cc b/mindspore/core/ops/real_div.cc
index f81da9bde7..bd582a16a1 100644
--- a/mindspore/core/ops/real_div.cc
+++ b/mindspore/core/ops/real_div.cc
@@ -41,8 +41,7 @@ TypePtr InferType(const PrimitivePtr &prim, const std::vector<AbstractBasePtr> &
   std::map<std::string, TypePtr> types;
   types.emplace("x", input_args[0]->BuildType());
   types.emplace("y", input_args[1]->BuildType());
-  auto infer_type = CheckAndConvertUtils::CheckTensorTypeSame(types, common_valid_types, prim->name());
-  return TypeIdToType(infer_type);
+  return CheckAndConvertUtils::CheckTensorTypeSame(types, common_valid_types, prim->name());
 }
 }  // namespace
diff --git a/mindspore/core/ops/reciprocal.cc b/mindspore/core/ops/reciprocal.cc
index 635e94a422..35d3202036 100644
--- a/mindspore/core/ops/reciprocal.cc
+++ b/mindspore/core/ops/reciprocal.cc
@@ -39,9 +39,8 @@ AbstractBasePtr ReciprocalInfer(const abstract::AnalysisEnginePtr &, const Primi
   auto in_shape = CheckAndConvertUtils::ConvertShapePtrToShape("input_shape", input_args[0]->GetShapeTrack(),
                                                                prim_name);
   // infer type
-  auto x_type = input_args[0]->BuildType()->cast<TensorTypePtr>()->element();
-  std::set<TypePtr> valid_x_type = {TypeIdToType(kObjectTypeTensorType)};
-  CheckAndConvertUtils::CheckSubClass("x_type", x_type, valid_x_type, prim_name);
+  std::set<TypePtr> valid_x_type = {kTensorType};
+  auto x_type = CheckAndConvertUtils::CheckTypeValid("x_type", input_args[0]->BuildType(), valid_x_type, prim_name);
   return std::make_shared<abstract::AbstractTensor>(x_type, in_shape);
 }
 REGISTER_PRIMITIVE_C(kNameReciprocal, Reciprocal);
diff --git a/mindspore/core/ops/reduce.cc b/mindspore/core/ops/reduce.cc
index fa46e7622e..0114473032 100644
--- a/mindspore/core/ops/reduce.cc
+++ b/mindspore/core/ops/reduce.cc
@@ -87,10 +87,8 @@ TypePtr InferType(const PrimitivePtr &prim, const std::vector<AbstractBasePtr> &
   for (const auto &item : input_args) {
     MS_EXCEPTION_IF_NULL(item);
   }
-  std::map<std::string, TypePtr> types;
-  types.emplace("input_x", input_args[0]->BuildType());
-  auto infer_type = CheckAndConvertUtils::CheckTensorTypeSame(types, common_valid_types, prim->name());
-  return TypeIdToType(infer_type);
+  return CheckAndConvertUtils::CheckTensorTypeValid("input_x", input_args[0]->BuildType(), common_valid_types,
+                                                    prim->name());
 }
 }  // namespace
diff --git a/mindspore/core/ops/relu6.cc b/mindspore/core/ops/relu6.cc
index 64fc0ff2d5..e74ac04882 100644
--- a/mindspore/core/ops/relu6.cc
+++ b/mindspore/core/ops/relu6.cc
@@ -35,14 +35,13 @@ abstract::ShapePtr InferShape(const PrimitivePtr &primitive, const std::vector<
 TypePtr InferType(const PrimitivePtr &prim, const std::vector<AbstractBasePtr> &input_args) {
-  const std::set<TypeId> valid_types = {kNumberTypeFloat16, kNumberTypeFloat32};
-  if (std::any_of(input_args.begin(), input_args.end(), [](AbstractBasePtr a) { return a == nullptr; })) {
+  const std::set<TypePtr> valid_types = {kFloat16, kFloat32};
+  if (std::any_of(input_args.begin(), input_args.end(), [](const AbstractBasePtr &a) { return a == nullptr; })) {
     MS_LOG(EXCEPTION) << "nullptr";
   }
   std::map<std::string, TypePtr> types;
   types.emplace("x", input_args[0]->BuildType());
-  auto infer_type = CheckAndConvertUtils::CheckTensorTypeSame(types, valid_types, prim->name());
-  return TypeIdToType(infer_type);
+  return CheckAndConvertUtils::CheckTensorTypeSame(types, valid_types, prim->name());
 }
 }  // namespace
 AbstractBasePtr ReLU6Infer(const abstract::AnalysisEnginePtr &, const PrimitivePtr &primitive,
diff --git a/mindspore/core/ops/resize_bilinear.cc b/mindspore/core/ops/resize_bilinear.cc
index 433a9252db..0ed4cacf06 100644
--- a/mindspore/core/ops/resize_bilinear.cc
+++ b/mindspore/core/ops/resize_bilinear.cc
@@ -63,12 +63,9 @@ AbstractBasePtr ResizeBilinearInfer(const abstract::AnalysisEnginePtr &, const P
   out_shape.insert(out_shape.end(), size.begin(), size.end());
   // Infer type
-  auto input_type = input_args[0]->BuildType()->cast<TensorTypePtr>()->element();
-  const std::set<TypeId> valid_types = {kNumberTypeFloat16, kNumberTypeFloat32};
-  CheckAndConvertUtils::CheckTensorTypeValid("input_type", input_type, valid_types, prim_name);
-  auto out_type = TypeIdToType(kNumberTypeFloat32);
-
-  return std::make_shared<abstract::AbstractTensor>(out_type, out_shape);
+  const std::set<TypePtr> valid_types = {kFloat16, kFloat32};
+  (void)CheckAndConvertUtils::CheckTensorTypeValid("input_type", input_args[0]->BuildType(), valid_types, prim_name);
+  return std::make_shared<abstract::AbstractTensor>(kFloat32, out_shape);
 }
 REGISTER_PRIMITIVE_C(kNameResizeBilinear, ResizeBilinear);
 }  // namespace ops
diff --git a/mindspore/core/ops/reverse_sequence.cc b/mindspore/core/ops/reverse_sequence.cc
index 4e556d8bfd..b770c118f0 100644
--- a/mindspore/core/ops/reverse_sequence.cc
+++ b/mindspore/core/ops/reverse_sequence.cc
@@ -62,15 +62,14 @@ AbstractBasePtr ReverseSequenceInfer(const abstract::AnalysisEnginePtr &, const
   CheckAndConvertUtils::CheckInteger("seq_lengths vector size", seq_lengths[0], kEqual, input_shape[batch_dim],
                                      prim_name);
   // infer type
-  std::set<TypeId> tmp(common_valid_types);
-  tmp.insert(kNumberTypeBool);
-  const std::set<TypeId> valid_x_types(tmp);
-  const std::set<TypeId> valid_seq_types = {kNumberTypeInt32, kNumberTypeInt64};
-  auto x_type = input_args[0]->BuildType()->cast<TensorTypePtr>()->element();
-  auto seq_type = input_args[1]->BuildType()->cast<TensorTypePtr>()->element();
-  CheckAndConvertUtils::CheckTensorTypeValid("x_type", x_type, valid_x_types, prim_name);
-  CheckAndConvertUtils::CheckTensorTypeValid("seq_type", seq_type, valid_seq_types, prim_name);
-  return std::make_shared<abstract::AbstractTensor>(x_type, input_shape);
+  std::set<TypePtr> valid_x_types(common_valid_types);
+  valid_x_types.emplace(kBool);
+  const std::set<TypePtr> valid_seq_types = {kInt32, kInt64};
+  // Pass the full tensor types: the new CheckTensorTypeValid expects a TensorType input
+  // and itself returns the unwrapped element type.
+  auto infered_type =
+    CheckAndConvertUtils::CheckTensorTypeValid("x_type", input_args[0]->BuildType(), valid_x_types, prim_name);
+  (void)CheckAndConvertUtils::CheckTensorTypeValid("seq_type", input_args[1]->BuildType(), valid_seq_types, prim_name);
+  return std::make_shared<abstract::AbstractTensor>(infered_type, input_shape);
 }
 REGISTER_PRIMITIVE_C(kNameReverseSequence, ReverseSequence);
 }  // namespace ops
diff --git a/mindspore/core/ops/reverse_v2.cc b/mindspore/core/ops/reverse_v2.cc
index 30a534ff89..3aa2e74f46 100644
--- a/mindspore/core/ops/reverse_v2.cc
+++ b/mindspore/core/ops/reverse_v2.cc
@@ -28,11 +28,6 @@ abstract::ShapePtr InferShape(const PrimitivePtr &primitive, const std::vector<
   auto prim_name = primitive->name();
   auto x_shape = CheckAndConvertUtils::ConvertShapePtrToShape("x_shape", input_args[0]->BuildShape(), prim_name);
-  // auto axis = reverseV2_prim->get_axis();
-  // int dim = x_shape.size();
-  // for (auto &axis_value : axis) {
-  //   CheckAndConvertUtils::CheckInRange("axis value", axis_value, kIncludeLeft, {-dim, dim}, prim_name);
-  // }
   return std::make_shared<abstract::Shape>(x_shape);
 }
@@ -40,17 +35,10 @@ TypePtr InferType(const PrimitivePtr &prim, const std::vector<AbstractBasePtr> &
   for (const auto &item : input_args) {
     MS_EXCEPTION_IF_NULL(item);
   }
-  const std::set<TypeId> valid_types = {kNumberTypeInt8,    kNumberTypeInt16,   kNumberTypeInt32,   kNumberTypeInt64,
-                                        kNumberTypeUInt8,   kNumberTypeUInt16,  kNumberTypeUInt32,  kNumberTypeUInt64,
-                                        kNumberTypeFloat16, kNumberTypeFloat32, kNumberTypeFloat64, kNumberTypeBool};
+  const std::set<TypePtr> valid_types = {kInt8,   kInt16,  kInt32,   kInt64,   kUInt8,   kUInt16,
+                                         kUInt32, kUInt64, kFloat16, kFloat32, kFloat64, kBool};
   auto infer_type = input_args[0]->BuildType();
-  CheckAndConvertUtils::CheckTensorTypeValid("x type", infer_type, valid_types, prim->name());
-  MS_EXCEPTION_IF_NULL(infer_type);
-  auto tensor_type = infer_type->cast<TensorTypePtr>();
-  MS_EXCEPTION_IF_NULL(tensor_type);
-  auto data_type = tensor_type->element();
-  MS_EXCEPTION_IF_NULL(data_type);
-  return data_type;
+  return CheckAndConvertUtils::CheckTensorTypeValid("x type", infer_type, valid_types, prim->name());
 }
 }  // namespace
diff --git a/mindspore/core/ops/rfft.cc b/mindspore/core/ops/rfft.cc
index e65d388986..3dc04871e3 100644
--- a/mindspore/core/ops/rfft.cc
+++ b/mindspore/core/ops/rfft.cc
@@ -39,7 +39,7 @@ TypePtr InferType(const PrimitivePtr &prim, const std::vector<AbstractBasePtr> &
   for (const auto &item : input_args) {
     MS_EXCEPTION_IF_NULL(item);
   }
-  return TypeIdToType(kNumberTypeComplex64);
+  return kComplex64;
 }
 }  // namespace
diff --git a/mindspore/core/ops/round.cc b/mindspore/core/ops/round.cc
index 5d91dba5e4..41a5920894 100644
--- a/mindspore/core/ops/round.cc
+++ b/mindspore/core/ops/round.cc
@@ -29,13 +29,7 @@ abstract::ShapePtr InferShape(const PrimitivePtr &primitive, const std::vector<
 TypePtr InferType(const PrimitivePtr &prim, const std::vector<AbstractBasePtr> &input_args) {
   auto infer_type = input_args[0]->BuildType();
-  CheckAndConvertUtils::CheckTensorTypeValid("x", infer_type, common_valid_types, prim->name());
-  MS_EXCEPTION_IF_NULL(infer_type);
-  auto tensor_type = infer_type->cast<TensorTypePtr>();
-  MS_EXCEPTION_IF_NULL(tensor_type);
-  auto data_type = tensor_type->element();
-  MS_EXCEPTION_IF_NULL(data_type);
-  return data_type;
+  return CheckAndConvertUtils::CheckTensorTypeValid("x", infer_type, common_valid_types, prim->name());
 }
 }  // namespace
diff --git a/mindspore/core/ops/rsqrt.cc b/mindspore/core/ops/rsqrt.cc
index fe54fbe276..c1d84f9b23 100644
--- a/mindspore/core/ops/rsqrt.cc
+++ b/mindspore/core/ops/rsqrt.cc
@@ -38,13 +38,10 @@ abstract::ShapePtr InferShape(const PrimitivePtr &primitive, const std::vector<
 TypePtr InferType(const PrimitivePtr &prim, const std::vector<AbstractBasePtr> &input_args) {
-  if (std::any_of(input_args.begin(), input_args.end(), [](AbstractBasePtr a) { return a == nullptr; })) {
+  if (std::any_of(input_args.begin(), input_args.end(), [](const AbstractBasePtr &a) { return a == nullptr; })) {
     MS_LOG(EXCEPTION) << "nullptr";
   }
-  std::map<std::string, TypePtr> types;
-  types.emplace("x", input_args[0]->BuildType());
-  auto infer_type = CheckAndConvertUtils::CheckTensorTypeSame(types, common_valid_types, prim->name());
-  return TypeIdToType(infer_type);
+  return CheckAndConvertUtils::CheckTensorTypeValid("x", input_args[0]->BuildType(), common_valid_types, prim->name());
 }
 }  // namespace
diff --git a/mindspore/core/ops/scatter_nd.cc b/mindspore/core/ops/scatter_nd.cc
index 784befd7ea..3957f4722c 100644
--- a/mindspore/core/ops/scatter_nd.cc
+++ b/mindspore/core/ops/scatter_nd.cc
@@ -42,11 +42,11 @@ TypePtr InferType(const PrimitivePtr &prim, const std::vector<AbstractBasePtr> &
   for (const auto &item : input_args) {
     MS_EXCEPTION_IF_NULL(item);
   }
-  const std::set<TypeId> indices_valid_types = {kNumberTypeInt32, kNumberTypeInt64};
-  const std::set<TypePtr> update_valid_types = {TypeIdToType(kObjectTypeTensorType)};
+  const std::set<TypePtr> indices_valid_types = {kInt32, kInt64};
+  const std::set<TypePtr> update_valid_types = {kTensorType};
   auto indices_type = input_args[0]->BuildType();
   auto update_type = input_args[1]->BuildType();
-  CheckAndConvertUtils::CheckSubClass("update type", update_type, update_valid_types, prim->name());
+  CheckAndConvertUtils::CheckTypeValid("update type", update_type, update_valid_types, prim->name());
   CheckAndConvertUtils::CheckTensorTypeValid("indices type", indices_type, indices_valid_types, prim->name());
   return input_args[1]->BuildType();
 }
diff --git a/mindspore/core/ops/sigmoid_cross_entropy_with_logits.cc b/mindspore/core/ops/sigmoid_cross_entropy_with_logits.cc
index 053a14b984..83144e7bc5 100644
--- a/mindspore/core/ops/sigmoid_cross_entropy_with_logits.cc
+++ b/mindspore/core/ops/sigmoid_cross_entropy_with_logits.cc
@@ -41,16 +41,12 @@ AbstractBasePtr SigmoidCrossEntropyWithLogitsInfer(const abstract::AnalysisEngin
   CheckAndConvertUtils::Check("x_shape", x_shape, kEqual, "y_shape", y_shape, prim_name, TypeError);
   // Infer type
-  auto x_type = input_args[0]->BuildType()->cast<TensorTypePtr>()->element();
-  const std::set<TypeId> valid_types = {
-    kNumberTypeBool,    kNumberTypeInt,     kNumberTypeInt8,    kNumberTypeInt16,
-    kNumberTypeInt32,   kNumberTypeInt64,   kNumberTypeUInt,    kNumberTypeUInt8,
-    kNumberTypeUInt16,  kNumberTypeUInt32,  kNumberTypeUInt64,  kNumberTypeFloat,
-    kNumberTypeFloat16, kNumberTypeFloat32, kNumberTypeFloat64, kNumberTypeComplex64};
+  const std::set<TypePtr> valid_types = {kBool,   kInt,    kInt8,   kInt16, kInt32,   kInt64,   kUInt,    kUInt8,
+                                         kUInt16, kUInt32, kUInt64, kFloat, kFloat16, kFloat32, kFloat64, kComplex64};
   std::map<std::string, TypePtr> args;
   args.emplace("x_type", input_args[0]->BuildType());
   args.emplace("y_type", input_args[1]->BuildType());
-  CheckAndConvertUtils::CheckTensorTypeSame(args, valid_types, prim_name);
+  auto x_type = CheckAndConvertUtils::CheckTensorTypeSame(args, valid_types, prim_name);
   return std::make_shared<abstract::AbstractTensor>(x_type, x_shape);
 }
diff --git a/mindspore/core/ops/sin.cc b/mindspore/core/ops/sin.cc
index 975224e0bc..6bf39ea2e6 100644
--- a/mindspore/core/ops/sin.cc
+++ b/mindspore/core/ops/sin.cc
@@ -39,9 +39,8 @@ TypePtr InferType(const PrimitivePtr &prim, const std::vector<AbstractBasePtr> &
   for (const auto &item : input_args) {
     MS_EXCEPTION_IF_NULL(item);
   }
-  auto infer_type = input_args[0]->BuildType()->cast<TensorTypePtr>()->element();
-  CheckAndConvertUtils::CheckTensorTypeValid("x type", input_args[0]->BuildType(), common_valid_types, prim->name());
-  return infer_type;
+  return CheckAndConvertUtils::CheckTensorTypeValid("x type", input_args[0]->BuildType(), common_valid_types,
+                                                    prim->name());
 }
 }  // namespace
diff --git a/mindspore/core/ops/smooth_l1_loss.cc b/mindspore/core/ops/smooth_l1_loss.cc
index a0b4057d07..f43b58c708 100644
--- a/mindspore/core/ops/smooth_l1_loss.cc
+++ b/mindspore/core/ops/smooth_l1_loss.cc
@@ -47,12 +47,11 @@ AbstractBasePtr SmoothL1LossInfer(const abstract::AnalysisEnginePtr &, const Pri
   CheckAndConvertUtils::Check("prediction shape", prediction, kEqual, "target shape", target, prim_name, TypeError);
   // Infer type
-  auto prediction_type = input_args[0]->BuildType()->cast<TensorTypePtr>()->element();
-  const std::set<TypeId> valid_types = {kNumberTypeFloat16, kNumberTypeFloat32};
+  const std::set<TypePtr> valid_types = {kFloat16, kFloat32};
   std::map<std::string, TypePtr> args;
   args.emplace("scale", input_args[0]->BuildType());
   args.emplace("bias", input_args[1]->BuildType());
-  CheckAndConvertUtils::CheckTensorTypeSame(args, valid_types, prim_name);
+  auto prediction_type = CheckAndConvertUtils::CheckTensorTypeSame(args, valid_types, prim_name);
   return std::make_shared<abstract::AbstractTensor>(prediction_type, prediction);
 }
diff --git a/mindspore/core/ops/softmax.cc b/mindspore/core/ops/softmax.cc
index f241f68035..7e76bf0872 100644
--- a/mindspore/core/ops/softmax.cc
+++ b/mindspore/core/ops/softmax.cc
@@ -62,11 +62,8 @@ TypePtr SoftMaxInferType(const PrimitivePtr &prim, const std::vector<AbstractBas
-  std::map<std::string, TypePtr> types;
-  const std::set<TypeId> valid_types = {kNumberTypeFloat16, kNumberTypeFloat32, kNumberTypeFloat64};
-  types.emplace("x", input_args[0]->BuildType());
-  auto infer_type = CheckAndConvertUtils::CheckTensorTypeSame(types, valid_types, prim->name());
-  return TypeIdToType(infer_type);
+  const std::set<TypePtr> valid_types = {kFloat16, kFloat32, kFloat64};
+  return CheckAndConvertUtils::CheckTensorTypeValid("x", input_args[0]->BuildType(), valid_types, prim->name());
 }
 AbstractBasePtr SoftmaxInfer(const abstract::AnalysisEnginePtr &, const PrimitivePtr &primitive,
diff --git a/mindspore/core/ops/softmax_cross_entropy_with_logits.cc b/mindspore/core/ops/softmax_cross_entropy_with_logits.cc
index 5cc01db647..ff636678b4 100644
--- a/mindspore/core/ops/softmax_cross_entropy_with_logits.cc
+++ b/mindspore/core/ops/softmax_cross_entropy_with_logits.cc
@@ -46,12 +46,11 @@ AbstractBasePtr SoftmaxCrossEntropyWithLogitsInfer(const abstract::AnalysisEngin
   auto dlogits_shape = logits_shape;
   // Infer type
-  const std::set<TypeId> valid_types = {kNumberTypeFloat16, kNumberTypeFloat32};
+  const std::set<TypePtr> valid_types = {kFloat16, kFloat32};
   std::map<std::string, TypePtr> args;
   args.emplace("logits_type", input_args[0]->BuildType());
   args.emplace("labels_type", input_args[1]->BuildType());
-  CheckAndConvertUtils::CheckTensorTypeSame(args, valid_types, prim_name);
-  auto logits_type = input_args[0]->BuildType()->cast<TensorTypePtr>()->element();
+  auto logits_type = CheckAndConvertUtils::CheckTensorTypeSame(args, valid_types, prim_name);
   auto output0 = std::make_shared<abstract::AbstractTensor>(logits_type, loss_shape);
   auto output1 = std::make_shared<abstract::AbstractTensor>(logits_type, dlogits_shape);
diff --git a/mindspore/core/ops/space_to_batch.cc b/mindspore/core/ops/space_to_batch.cc
index 83f55451c7..bf0ca42118 100644
--- a/mindspore/core/ops/space_to_batch.cc
+++ b/mindspore/core/ops/space_to_batch.cc
@@ -52,8 +52,7 @@ TypePtr InferType(const PrimitivePtr &prim, const std::vector<AbstractBasePtr> &
   }
   std::map<std::string, TypePtr> types;
   types.emplace("x", input_args[0]->BuildType());
-  auto infer_type = CheckAndConvertUtils::CheckTensorTypeSame(types, common_valid_types, prim->name());
-  return TypeIdToType(infer_type);
+  return CheckAndConvertUtils::CheckTensorTypeSame(types, common_valid_types, prim->name());
 }
 }  // namespace
 void SpaceToBatch::set_paddings(const std::vector<std::vector<int64_t>> &paddings) {
diff --git a/mindspore/core/ops/space_to_batch_nd.cc b/mindspore/core/ops/space_to_batch_nd.cc
index f1cdd4fa2e..c66d118ce7 100644
--- a/mindspore/core/ops/space_to_batch_nd.cc
+++ b/mindspore/core/ops/space_to_batch_nd.cc
@@ -56,8 +56,7 @@ TypePtr InferType(const PrimitivePtr &prim, const std::vector<AbstractBasePtr> &
   for (const auto &item : input_args) {
     MS_EXCEPTION_IF_NULL(item);
   }
-  auto infer_type = input_args[0]->BuildType()->cast<TensorTypePtr>()->element();
-  return infer_type;
+  return input_args[0]->BuildType()->cast<TensorTypePtr>()->element();
 }
 }  // namespace
diff --git a/mindspore/core/ops/sparse_to_dense.cc b/mindspore/core/ops/sparse_to_dense.cc
index c459ab098d..ea8fbf896f 100644
--- a/mindspore/core/ops/sparse_to_dense.cc
+++ b/mindspore/core/ops/sparse_to_dense.cc
@@ -38,11 +38,7 @@ AbstractBasePtr SparseToDenseInfer(const abstract::AnalysisEnginePtr &, const Pr
   auto dense_shape = CheckAndConvertUtils::ConvertShapePtrToShape("dense_shape", input_args[3]->BuildShape(),
                                                                   prim_name);
   // infer type
-  auto indices_type = input_args[0]->BuildType()->cast<TensorTypePtr>()->element();
   auto values_type = input_args[1]->BuildType()->cast<TensorTypePtr>()->element();
-  std::set<TypePtr> valid_type = {TypeIdToType(kObjectTypeTensorType)};
-  CheckAndConvertUtils::CheckSubClass("indices_type", indices_type, valid_type, prim_name);
-  CheckAndConvertUtils::CheckSubClass("values_type", values_type, valid_type, prim_name);
   return std::make_shared<abstract::AbstractTensor>(values_type, dense_shape);
 }
 REGISTER_PRIMITIVE_C(kNameSparseToDense, SparseToDense);
diff --git a/mindspore/core/ops/squared_difference.cc b/mindspore/core/ops/squared_difference.cc
index eb0328d617..672da25fbb 100644
--- a/mindspore/core/ops/squared_difference.cc
+++ b/mindspore/core/ops/squared_difference.cc
@@ -37,12 +37,11 @@ TypePtr InferType(const PrimitivePtr &prim, const std::vector<AbstractBasePtr> &
   for (const auto &item : input_args) {
     MS_EXCEPTION_IF_NULL(item);
   }
-  const std::set<TypeId> valid_types = {kNumberTypeInt32, kNumberTypeFloat16, kNumberTypeFloat32};
+  const std::set<TypePtr> valid_types = {kInt32, kFloat16, kFloat32};
   std::map<std::string, TypePtr> types;
   types.emplace("x", input_args[0]->BuildType());
   types.emplace("y", input_args[1]->BuildType());
-  auto infer_type = CheckAndConvertUtils::CheckTensorTypeSame(types, valid_types, prim->name());
-  return TypeIdToType(infer_type);
+  return CheckAndConvertUtils::CheckTensorTypeSame(types, valid_types, prim->name());
 }
 }  // namespace
diff --git a/mindspore/core/ops/sub.cc b/mindspore/core/ops/sub.cc
index a99c0ad2e4..0a6187126e 100644
--- a/mindspore/core/ops/sub.cc
+++ b/mindspore/core/ops/sub.cc
@@ -42,8 +42,7 @@ TypePtr InferType(const PrimitivePtr &prim, const std::vector<AbstractBasePtr> &
   std::map<std::string, TypePtr> types;
   types.emplace("x", input_args[0]->BuildType());
   types.emplace("y", input_args[1]->BuildType());
-  auto infer_type = CheckAndConvertUtils::CheckTensorTypeSame(types, common_valid_types, prim->name());
-  return TypeIdToType(infer_type);
+  return CheckAndConvertUtils::CheckTensorTypeSame(types, common_valid_types, prim->name());
 }
 }  // namespace
diff --git a/mindspore/core/ops/tan.cc b/mindspore/core/ops/tan.cc
index 1739332022..6d21bc1355 100644
--- a/mindspore/core/ops/tan.cc
+++ b/mindspore/core/ops/tan.cc
@@ -40,15 +40,10 @@ AbstractBasePtr TanInfer(const abstract::AnalysisEnginePtr &, const PrimitivePtr
   // Infer Type
   auto dtype = input_args[0]->BuildType();
-  const std::set<TypeId> valid_types = {kNumberTypeFloat16, kNumberTypeFloat32, kNumberTypeInt32};
-  CheckAndConvertUtils::CheckTensorTypeValid("x_dtype", dtype, valid_types, prim_name);
-  auto tensor_type = dtype->cast<TensorTypePtr>();
-  MS_EXCEPTION_IF_NULL(tensor_type);
-  auto element = tensor_type->element();
-  MS_EXCEPTION_IF_NULL(element);
-  auto infer_type = std::make_shared<TensorType>(TypeIdToType(element->type_id()));
+  const std::set<TypePtr> valid_types = {kFloat16, kFloat32, kInt32};
+  auto infered_type = CheckAndConvertUtils::CheckTensorTypeValid("x_dtype", dtype, valid_types, prim_name);
 
-  return std::make_shared<abstract::AbstractTensor>(infer_type, infer_shape->shape());
+  return std::make_shared<abstract::AbstractTensor>(infered_type, infer_shape->shape());
 }
 REGISTER_PRIMITIVE_C(kNameTan, Tan);
 }  // namespace ops
diff --git a/mindspore/core/ops/tensor_list_from_tensor.cc b/mindspore/core/ops/tensor_list_from_tensor.cc
index 46e228c5d5..b6c3e4169c 100644
--- a/mindspore/core/ops/tensor_list_from_tensor.cc
+++ b/mindspore/core/ops/tensor_list_from_tensor.cc
@@ -48,11 +48,7 @@ abstract::ShapePtr TensorListFromTensorInferShape(const PrimitivePtr &primitive,
 }
 TypePtr TensorListFromTensorInferType(const PrimitivePtr &prim, const std::vector<AbstractBasePtr> &input_args) {
-  for (const auto &item : input_args) {
-    MS_EXCEPTION_IF_NULL(item);
-  }
-  TypeId infer_type = kObjectTypeTensorType;
-  return TypeIdToType(infer_type);
+  return kTensorType;
 }
 }  // namespace
diff --git a/mindspore/core/ops/tile.cc b/mindspore/core/ops/tile.cc
index 43f605ba1e..76910c4eba 100644
--- a/mindspore/core/ops/tile.cc
+++ b/mindspore/core/ops/tile.cc
@@ -58,10 +58,8 @@ TypePtr TileInferType(const PrimitivePtr &prim, const std::vector<AbstractBasePt
   auto x_dtype = input_args[0]->BuildType()->cast<TensorTypePtr>();
-  std::set<TypePtr> template_types = {TypeIdToType(kObjectTypeTensorType)};
-  CheckAndConvertUtils::CheckSubClass("x_dtype", x_dtype, template_types, prim->name());
-  auto infer_dtype = x_dtype->element()->type_id();
-  return TypeIdToType(infer_dtype);
+  std::set<TypePtr> template_types = {kTensorType};
+  return CheckAndConvertUtils::CheckTensorTypeValid("x_dtype", x_dtype, template_types, prim->name());
 }
 }  // namespace
diff --git a/mindspore/core/ops/topk.cc b/mindspore/core/ops/topk.cc
index 32cfc75d8d..d7d1b727d7 100644
--- a/mindspore/core/ops/topk.cc
+++ b/mindspore/core/ops/topk.cc
@@ -37,10 +37,10 @@ AbstractBasePtr TopKInfer(const abstract::AnalysisEnginePtr &, const PrimitivePt
   CheckAndConvertUtils::CheckInteger("top_k_infer", input_args.size(), kEqual, 2, prim_name);
   // Infer dtype
-  auto output0_type = input_args[0]->BuildType()->cast<TensorTypePtr>()->element();
-  auto output1_type = TypeIdToType(kNumberTypeInt32);
-  const std::set<TypeId> valid_types = {kNumberTypeFloat16, kNumberTypeFloat32};
-  CheckAndConvertUtils::CheckTensorTypeValid("input_x", input_args[0]->BuildType(), valid_types, prim_name);
+  auto output1_type = kInt32;
+  const std::set<TypePtr> valid_types = {kFloat16, kFloat32};
+  auto output0_type =
+    CheckAndConvertUtils::CheckTensorTypeValid("input_x", input_args[0]->BuildType(), valid_types, prim_name);
   // Infer shape
   auto x_shape = CheckAndConvertUtils::ConvertShapePtrToShape("x_shape", input_args[0]->BuildShape(), prim_name);
diff --git a/mindspore/core/ops/unsorted_segment_sum.cc b/mindspore/core/ops/unsorted_segment_sum.cc
index 2fdea7bd14..fe17df609a 100644
--- a/mindspore/core/ops/unsorted_segment_sum.cc
+++ b/mindspore/core/ops/unsorted_segment_sum.cc
@@ -34,13 +34,6 @@ AbstractBasePtr UnsortedSegmentSumInfer(const abstract::AnalysisEnginePtr &, con
   // Infer type
   auto x_type = input_args[0]->BuildType()->cast<TensorTypePtr>()->element();
-  auto num_segments_type = input_args[2]->BuildType();
-  auto num_segments_v = 4;
-  std::set<TypePtr> valid_x_type = {TypeIdToType(kObjectTypeTensorType)};
-  CheckAndConvertUtils::CheckSubClass("input_x", input_args[0]->BuildType(), valid_x_type, prim_name);
-  std::set<TypePtr> valid_segment_ids_type = {TypeIdToType(kObjectTypeTensorType)};
-  CheckAndConvertUtils::CheckSubClass("segment_ids", input_args[1]->BuildType(), valid_segment_ids_type, prim_name);
-
   // Infer shape
   auto x_shape = CheckAndConvertUtils::ConvertShapePtrToShape("x_shape", input_args[0]->BuildShape(), prim_name);
   CheckAndConvertUtils::CheckInteger("x_shape", x_shape.size(), kGreaterThan, 0, prim_name);
@@ -59,19 +52,9 @@ AbstractBasePtr UnsortedSegmentSumInfer(const abstract::AnalysisEnginePtr &, con
     }
   }
-  const std::set<TypePtr> valid_segments_types = {TypeIdToType(kObjectTypeTensorType)};
-  for (const auto &valid_segments_type : valid_segments_types) {
-    if (IsIdentidityOrSubclass(num_segments_type, valid_segments_type)) {
-      const std::set<TypeId> valid_num_segments_types = {kNumberTypeInt32, kNumberTypeInt64};
-      CheckAndConvertUtils::CheckTensorTypeValid("num_segments", input_args[2]->BuildType(), valid_num_segments_types,
-                                                 prim_name);
-      shp = {-1};
-    } else {
-      CheckAndConvertUtils::CheckInteger("num_segments", num_segments_v, kGreaterThan, 0, prim_name);
-      shp = {num_segments_v};
-    }
-  }
-
+  const std::set<TypePtr> valid_num_segments_types = {kInt32, kInt64};
+  (void)CheckAndConvertUtils::CheckTensorTypeValid("num_segments", input_args[2]->BuildType(),
+                                                   valid_num_segments_types, prim_name);
   int64_t size_segment_ids_shp = segment_ids_shape.size();
   int64_t size_x_shpe = x_shape.size();
   for (int64_t i = size_segment_ids_shp; i < size_x_shpe; ++i) {
diff --git a/mindspore/core/ops/unstack.cc b/mindspore/core/ops/unstack.cc
index 94a1924471..371986dde3 100644
--- a/mindspore/core/ops/unstack.cc
+++ b/mindspore/core/ops/unstack.cc
@@ -31,8 +31,6 @@ AbstractBasePtr UnstackInfer(const abstract::AnalysisEnginePtr &, const Primitiv
   auto unstack_prim = primitive->cast<PrimUnstackPtr>();
   MS_EXCEPTION_IF_NULL(unstack_prim);
   auto prim_name = unstack_prim->name();
-  CheckAndConvertUtils::CheckSubClass("x", input_args[0]->BuildType(), {TypeIdToType(kObjectTypeTensorType)},
-                                      prim_name);
   auto x_shape = CheckAndConvertUtils::ConvertShapePtrToShape("x_shape", input_args[0]->BuildShape(), prim_name);
   int64_t dim = x_shape.size();
   int64_t axis = unstack_prim->get_axis();
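// Illustrative sketch (not part of the patch): because the validators now return a TypePtr
// rather than void, call sites that only validate and discard the result cast it to (void)
// to make the discard explicit, as in the unsorted_segment_sum.cc hunk above:
//
//   (void)CheckAndConvertUtils::CheckTensorTypeValid("num_segments", input_args[2]->BuildType(),
//                                                    {kInt32, kInt64}, prim_name);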
diff --git a/mindspore/core/ops/zeros_like.cc b/mindspore/core/ops/zeros_like.cc
index 9a38ecc988..10009e0dc0 100644
--- a/mindspore/core/ops/zeros_like.cc
+++ b/mindspore/core/ops/zeros_like.cc
@@ -42,16 +42,12 @@ abstract::ShapePtr InferShape(const PrimitivePtr &primitive, const std::vector<
 TypePtr InferType(const PrimitivePtr &prim, const std::vector<AbstractBasePtr> &input_args) {
-  std::set<TypeId> tmp(common_valid_types);
-  tmp.insert(kNumberTypeBool);
-  const std::set<TypeId> valid_types(tmp);
-  if (std::any_of(input_args.begin(), input_args.end(), [](AbstractBasePtr a) { return a == nullptr; })) {
+  std::set<TypePtr> valid_types(common_valid_types);
+  valid_types.emplace(kBool);
+  if (std::any_of(input_args.begin(), input_args.end(), [](const AbstractBasePtr &a) { return a == nullptr; })) {
     MS_LOG(EXCEPTION) << "nullptr";
   }
-  std::map<std::string, TypePtr> types;
-  types.emplace("x", input_args[0]->BuildType());
-  auto infer_type = CheckAndConvertUtils::CheckTensorTypeSame(types, valid_types, prim->name());
-  return TypeIdToType(infer_type);
+  return CheckAndConvertUtils::CheckTensorTypeValid("x", input_args[0]->BuildType(), valid_types, prim->name());
 }
 }  // namespace
diff --git a/mindspore/core/utils/check_convert_utils.cc b/mindspore/core/utils/check_convert_utils.cc
index 9cbe3cdfae..d2734bf367 100644
--- a/mindspore/core/utils/check_convert_utils.cc
+++ b/mindspore/core/utils/check_convert_utils.cc
@@ -420,49 +420,23 @@ void CheckAndConvertUtils::Check(const string &arg_name, int64_t arg_value, Comp
                               << " but got " << arg_value;
 }
-TypeId CheckAndConvertUtils::CheckTensorTypeSame(const std::map<std::string, TypePtr> &types,
-                                                 const std::set<TypeId> &check_list, const std::string &prim_name) {
+TypePtr CheckAndConvertUtils::CheckTensorTypeSame(const std::map<std::string, TypePtr> &types,
+                                                  const std::set<TypePtr> &check_list, const std::string &prim_name) {
   if (types.empty()) {
     MS_EXCEPTION(ArgumentError) << "Trying to use the function to check a empty types map!";
   }
-  std::set<TypeId> types_id;
-  std::ostringstream buffer;
-  buffer << "For " << prim_name;
-  for (const auto &type : types) {
-    MS_EXCEPTION_IF_NULL(type.second);
-    if (!type.second->isa<TensorType>()) {
-      MS_EXCEPTION(TypeError) << "The " << prim_name << "'s" << type.first << " input must be tensor type but got "
-                              << type.second->ToString();
-    }
-    auto tensor_type = type.second->cast<TensorTypePtr>();
-    MS_EXCEPTION_IF_NULL(tensor_type);
-    auto element = tensor_type->element();
-    MS_EXCEPTION_IF_NULL(element);
-    types_id.emplace(element->type_id());
-  }
-  if (types_id.size() > 1) {
-    buffer << "'s input type is not same : ";
-    for (const auto &item : types) {
-      buffer << "[ name : " << item.first << " ,type : " << item.second->ToString() << "]";
-    }
-    MS_EXCEPTION(TypeError) << buffer.str();
-  }
-  if (check_list.find(*types_id.begin()) == check_list.end()) {
-    buffer << " type of ";
-    for (const auto &elem : types) {
-      buffer << elem.first << " should be in [";
-      for (auto type_elem : check_list) {
-        buffer << TypeIdToType(type_elem)->ToString() << " ,";
-      }
-      buffer << "] , but got " << types.begin()->second->ToString();
-    }
-    MS_EXCEPTION(TypeError) << buffer.str();
+  auto type = types.begin()->second;
+  MS_EXCEPTION_IF_NULL(type);
+  if (!type->isa<TensorType>()) {
+    MS_EXCEPTION(TypeError) << "The " << prim_name << "'s " << types.begin()->first
+                            << " input must be a tensor but got " << type->ToString();
   }
-  return *types_id.begin();
+  TypePtr check_type = _CheckTypeSame(types, prim_name, false);
+  return CheckTypeValid(types.begin()->first, check_type, check_list, prim_name);
 }
-void CheckAndConvertUtils::CheckTensorTypeValid(const std::string &type_name, const TypePtr type,
-                                                const std::set<TypeId> &check_list, const std::string &prim_name) {
+TypePtr CheckAndConvertUtils::CheckTensorTypeValid(const std::string &type_name, const TypePtr &type,
+                                                   const std::set<TypePtr> &check_list, const std::string &prim_name) {
   MS_EXCEPTION_IF_NULL(type);
   if (!type->isa<TensorType>()) {
     MS_EXCEPTION(TypeError) << "The " << prim_name << "'s " << type_name << " input must be tensor type but got "
                             << type->ToString();
   }
   auto tensor_type = type->cast<TensorTypePtr>();
   MS_EXCEPTION_IF_NULL(tensor_type);
   auto element = tensor_type->element();
   MS_EXCEPTION_IF_NULL(element);
-  std::ostringstream buffer;
-  if (check_list.find(element->type_id()) == check_list.end()) {
-    buffer << "type of " << type_name << " should be in [";
-    for (auto type_elem : check_list) {
-      buffer << TypeIdToType(type_elem)->ToString() << " ,";
+  for (const TypePtr &item : check_list) {
+    if (item->isa<TensorType>()) {
+      auto item_tensor_type = item->cast<TensorTypePtr>();
+      if (item_tensor_type->element() == nullptr) {
+        return element;
+      }
     }
-    buffer << "], but got " << type->ToString();
-    MS_EXCEPTION(TypeError) << buffer.str();
   }
+  return CheckSubClass(type_name, element, check_list, prim_name);
 }
-void CheckAndConvertUtils::CheckSubClass(const std::string &type_name, const TypePtr type_,
-                                         const std::set<TypePtr> &template_types, const std::string &prim_name) {
-  MS_EXCEPTION_IF_NULL(type_);
-  bool hit = false;
-  for (auto template_type : template_types) {
-    if (type_->isa<TensorType>()) {
-      if (IsIdentidityOrSubclass(type_, template_type)) {
-        hit = true;
-        break;
-      }
-    } else if (type_->type_id() == template_type->type_id()) {
-      hit = true;
-      break;
-    }
-  }
-  if (!hit) {
+TypePtr CheckAndConvertUtils::CheckSubClass(const std::string &type_name, const TypePtr &type_,
+                                            const std::set<TypePtr> &template_types, const std::string &prim_name) {
+  bool ok = std::any_of(template_types.begin(), template_types.end(),
+                        [type_](const TypePtr &accept) -> bool { return IsIdentidityOrSubclass(type_, accept); });
+  if (ok) {
+    return type_;
+  } else {
     std::string type_str = type_->ToString();
     std::ostringstream buffer;
     buffer << "For '" << prim_name << "', the type of `" << type_name << "` should be subclass of ";
-    for (auto template_type : template_types) {
+    for (const auto &template_type : template_types) {
       buffer << template_type->ToString() << ",";
     }
     buffer << " but got " << type_str << ".";
     MS_EXCEPTION(TypeError) << buffer.str();
   }
 }
-void CheckAndConvertUtils::CheckScalarOrTensorTypesSame(const std::map<std::string, TypePtr> &args,
-                                                        const std::set<TypeId> &valid_values,
-                                                        const std::string &prim_name, const bool allow_mix) {
-  std::vector<std::map<std::string, TypePtr>> check_results;
-  for (auto &iter : args) {
-    std::map<std::string, TypePtr> arg = {{iter.first, iter.second}};
-    check_results.push_back(_CheckArgumentType(arg, valid_values, prim_name));
-  }
-
-  std::map<std::string, TypePtr> &arg_ = check_results[0];
-  int64_t size = check_results.size();
-  for (int64_t it = 1; it != size; it++) {
-    arg_ = _CheckTypeSame(arg_, check_results[it], prim_name, allow_mix);
-  }
+TypePtr CheckAndConvertUtils::CheckScalarOrTensorTypesSame(const std::map<std::string, TypePtr> &args,
+                                                           const std::set<TypePtr> &valid_values,
+                                                           const std::string &prim_name, const bool allow_mix) {
+  auto arg_ = _CheckTypeSame(args, prim_name, allow_mix);
+  return CheckTypeValid(args.begin()->first, arg_, valid_values, prim_name);
 }
-std::map<std::string, TypePtr> CheckAndConvertUtils::_CheckArgumentType(const std::map<std::string, TypePtr> &arg,
-                                                                        const std::set<TypeId> &valid_values,
-                                                                        const std::string &prim_name) {
-  std::string arg_key = arg.begin()->first;
-  TypePtr arg_val = arg.begin()->second;
-
-  if (arg_val->isa<TensorType>()) {
-    auto arg_val_ = std::static_pointer_cast<TensorType>(arg_val);
-    arg_val = arg_val_->element();
+TypePtr CheckAndConvertUtils::_CheckTypeSame(const std::map<std::string, TypePtr> &args, const std::string &prim_name,
+                                             const bool allow_mix) {
+  if (args.empty()) {
+    MS_EXCEPTION(ArgumentError) << "Trying to use the function to check a empty types map!";
   }
-
-  auto it = valid_values.find(arg_val->type_id());
-  if (it == valid_values.end()) {
-    std::ostringstream buffer;
-    buffer << "For '" << prim_name << "' , the `" << arg_key << "` should be in { ";
-    for (auto valid_value : valid_values) {
-      buffer << TypeIdToType(valid_value)->ToString() << ",";
+  std::ostringstream buffer;
+  TypePtr return_type = nullptr;
+  buffer << "For " << prim_name;
+  auto first_type = args.begin()->second;
+  MS_EXCEPTION_IF_NULL(first_type);
+  bool tensor_flag = first_type->isa<TensorType>();
+  std::set<TypeId> types_id;
+  for (const auto &elem : args) {
+    auto type = elem.second;
+    MS_EXCEPTION_IF_NULL(type);
+    if (!allow_mix) {
+      // input must be all tensor or all other type
+      if (tensor_flag ^ type->isa<TensorType>()) {
+        buffer << "For " << prim_name << "'s "
+               << "type is not same";
+        for (const auto &error_elem : args) {
+          buffer << " [ name :" << error_elem.first << ", type : " << error_elem.second->ToString() << "]";
+        }
+        MS_EXCEPTION(TypeError) << buffer.str();
+      }
+    }
+    if (type->isa<TensorType>()) {
+      auto tensor_type = type->cast<TensorTypePtr>();
+      MS_EXCEPTION_IF_NULL(tensor_type);
+      auto element = tensor_type->element();
+      MS_EXCEPTION_IF_NULL(element);
+      return_type = element->DeepCopy();
+      types_id.emplace(element->type_id());
+    } else {
+      types_id.emplace(type->type_id());
+      return_type = type->DeepCopy();
+    }
+    if (types_id.size() > 1) {
+      buffer << "'s input type is not same : ";
+      for (const auto &item : args) {
+        buffer << "[ name : " << item.first << " ,type : " << item.second->ToString() << "]";
+      }
+      MS_EXCEPTION(TypeError) << buffer.str();
     }
-    buffer << " },";
-    buffer << "but `" << arg_key << "`"
-           << "is" << arg_val->ToString() << ".";
-    MS_EXCEPTION(TypeError) << buffer.str();
   }
-  return arg;
+  return return_type;
 }
-
-std::map<std::string, TypePtr> CheckAndConvertUtils::_CheckTypeSame(const std::map<std::string, TypePtr> &arg1,
-                                                                    const std::map<std::string, TypePtr> &arg2,
-                                                                    const std::string &prim_name,
-                                                                    const bool allow_mix) {
-  std::string arg1_name = arg1.begin()->first;
-  TypePtr arg1_type = arg1.begin()->second;
-  std::string arg2_name = arg2.begin()->first;
-  TypePtr arg2_type = arg2.begin()->second;
-  bool except_flag = false;
-
-  if (arg1_type->isa<TensorType>() && arg2_type->isa<TensorType>()) {
-    arg1_type = std::static_pointer_cast<TensorType>(arg1_type)->element();
-    arg2_type = std::static_pointer_cast<TensorType>(arg2_type)->element();
-  } else if (allow_mix) {
-    arg1_type = arg1_type->isa<TensorType>() ? std::static_pointer_cast<TensorType>(arg1_type)->element() : arg1_type;
-    arg2_type = arg2_type->isa<TensorType>() ? std::static_pointer_cast<TensorType>(arg2_type)->element() : arg2_type;
-  } else {
-    except_flag = true;
-  }
-
-  if (except_flag || arg1_type->type_id() != arg2_type->type_id()) {
-    std::ostringstream buffer;
-    buffer << "For '" << prim_name << "'"
-           << "type of "
-           << "`" << arg2_name << "` should be same as "
-           << "`" << arg1_name << "`,";
-    buffer << "but `" << arg1_name << "` is " << arg1_type->ToString() << "and `" << arg2_name << "` is "
-           << arg2_type->ToString() << ".";
-    MS_EXCEPTION(TypeError) << buffer.str();
-  }
-  return arg1;
-}
-TypeId CheckAndConvertUtils::CheckTypeSame(const std::string &arg_name, const TypePtr arg_type,
-                                           const std::set<TypeId> &valid_type, const std::string &prim_name) {
+TypePtr CheckAndConvertUtils::CheckTypeValid(const std::string &arg_name, const TypePtr &arg_type,
+                                             const std::set<TypePtr> &valid_type, const std::string &prim_name) {
   if (valid_type.empty()) {
     MS_EXCEPTION(ArgumentError) << "Trying to use the function to check a empty valid_type!";
   }
-  // std::set<TypeId> types_id;
-  std::ostringstream buffer;
-  TypeId arg_type_;
-  arg_type_ = arg_type->isa<TensorType>() ? std::static_pointer_cast<TensorType>(arg_type)->generic_type_id()
-                                          : arg_type->type_id();
-
-  auto it = valid_type.find(arg_type_);
-  if (it == valid_type.end()) {
-    buffer << "For" << prim_name << ", the '" << arg_name << "' should be {' one of '" << valid_type.size() << "'}";
-    for (auto type : valid_type) {
-      buffer << "{" << TypeIdLabel(type);
-    }
-    buffer << "},";
-    buffer << "but got " << arg_type->ToString() << ".";
-    MS_EXCEPTION(TypeError) << buffer.str();
-  }
-  return arg_type_;
+  MS_EXCEPTION_IF_NULL(arg_type);
+  if (arg_type->isa<TensorType>()) {
+    return CheckTensorTypeValid(arg_name, arg_type, valid_type, prim_name);
+  }
+  return CheckSubClass(arg_name, arg_type, valid_type, prim_name);
 }
 bool CheckAndConvertUtils::CheckIrAttrtoOpAttr(const std::string &op_type, const std::string &attr_name,
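// Illustrative sketch (not part of the patch): CheckTypeValid is the funnel the other checks
// feed into; a TensorType argument is routed to CheckTensorTypeValid (which unwraps and
// validates the element type), anything else to CheckSubClass. A hypothetical caller no
// longer needs to know whether the input was a tensor:
//
//   auto t = CheckAndConvertUtils::CheckTypeValid("x", input_args[0]->BuildType(),
//                                                 {kFloat16, kFloat32}, prim_name);
//   // t is the tensor's element type for a tensor input, or the scalar type itself otherwise.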
diff --git a/mindspore/core/utils/check_convert_utils.h b/mindspore/core/utils/check_convert_utils.h
index bb667e5436..9865b91cf2 100644
--- a/mindspore/core/utils/check_convert_utils.h
+++ b/mindspore/core/utils/check_convert_utils.h
@@ -22,6 +22,8 @@
 #include <map>
 #include <set>
 #include <string>
+#include <memory>
+#include "abstract/param_validator.h"
 #include "base/base.h"
 #include "ir/anf.h"
 #include "ir/dtype/type_id.h"
@@ -270,17 +272,34 @@ class CheckAndConvertUtils {
     MS_EXCEPTION(exception_type) << buffer.str();
   }
-  static TypeId CheckTensorTypeSame(const std::map<std::string, TypePtr> &types, const std::set<TypeId> &check_list,
-                                    const std::string &prim_name);
-  static void CheckTensorTypeValid(const std::string &type_name, const TypePtr type,
-                                   const std::set<TypeId> &check_list, const std::string &prim_name);
-  static void CheckSubClass(const std::string &type_name, const TypePtr type,
-                            const std::set<TypePtr> &template_types, const std::string &prim_name);
-  static void CheckScalarOrTensorTypesSame(const std::map<std::string, TypePtr> &args,
-                                           const std::set<TypeId> &valid_values, const std::string &prim_name,
-                                           bool allow_mix = false);
-  static TypeId CheckTypeSame(const std::string &arg_name, const TypePtr arg_type, const std::set<TypeId> &valid_type,
-                              const std::string &prim_name);
+  template <typename T>
+  static std::shared_ptr<T> CheckArgs(const std::string &op, const AbstractBasePtrList &args_spec_list, size_t index) {
+    if (index >= args_spec_list.size()) {
+      MS_EXCEPTION(ValueError) << op << " evaluator args list index out of bound, size " << args_spec_list.size()
+                               << ", index " << index;
+    }
+    auto args_spec = args_spec_list[index];
+    MS_EXCEPTION_IF_NULL(args_spec);
+    auto arg = dyn_cast<T>(args_spec);
+    if (arg == nullptr) {
+      MS_EXCEPTION(TypeError) << "Operator " << op << " input[" << index << "] should be "
+                              << abstract::ReportNameTraits<T>::name << ", but got "
+                              << args_spec_list[index]->BuildType()->ToString() << ".";
+    }
+    return arg;
+  }
+
+  static TypePtr CheckTensorTypeSame(const std::map<std::string, TypePtr> &types, const std::set<TypePtr> &check_list,
+                                     const std::string &prim_name);
+  static TypePtr CheckTensorTypeValid(const std::string &type_name, const TypePtr &type,
+                                      const std::set<TypePtr> &check_list, const std::string &prim_name);
+  static TypePtr CheckSubClass(const std::string &type_name, const TypePtr &type,
+                               const std::set<TypePtr> &template_types, const std::string &prim_name);
+  static TypePtr CheckScalarOrTensorTypesSame(const std::map<std::string, TypePtr> &args,
+                                              const std::set<TypePtr> &valid_values, const std::string &prim_name,
+                                              bool allow_mix = false);
+  static TypePtr CheckTypeValid(const std::string &arg_name, const TypePtr &arg_type,
+                                const std::set<TypePtr> &valid_type, const std::string &prim_name);
   static bool ConvertAttrValueToInt(const std::string &op_type, const std::string &attr_name, ValuePtr *const value);
   static bool ConvertAttrValueToString(const std::string &op_type, const std::string &attr_name,
                                        ValuePtr *const value);
   static void ConvertAttrValueInExport(const std::string &op_type, const std::string &attr_name,
                                        ValuePtr *const value);
@@ -292,12 +311,8 @@ class CheckAndConvertUtils {
  private:
   static bool IsEqualVector(const std::vector<int64_t> &vec_1, const std::vector<int64_t> &vec_2);
-  static std::map<std::string, TypePtr> _CheckArgumentType(const std::map<std::string, TypePtr> &arg,
-                                                           const std::set<TypeId> &valid_values,
-                                                           const std::string &prim_name);
-  static std::map<std::string, TypePtr> _CheckTypeSame(const std::map<std::string, TypePtr> &arg1,
-                                                       const std::map<std::string, TypePtr> &arg2,
-                                                       const std::string &prim_name, const bool allow_mix);
+  static TypePtr _CheckTypeSame(const std::map<std::string, TypePtr> &args, const std::string &prim_name,
+                                const bool allow_mix);
 };
 }  // namespace mindspore
 #endif  // MINDSPORE_CORE_UTILS_CHECK_CONVERT_UTILS_H_
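// Illustrative sketch (not part of the patch): the CheckArgs<T> template added to the header
// might be used from an op infer implementation roughly as below; the index and the choice of
// abstract class are hypothetical:
//
//   auto x = CheckAndConvertUtils::CheckArgs<abstract::AbstractTensor>(op_name, input_args, 0);
//   // Throws ValueError if index 0 is out of range, TypeError if input 0 is not an
//   // AbstractTensor; otherwise returns the down-cast std::shared_ptr<abstract::AbstractTensor>.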