change some type-check APIs

pull/13901/head
LianLiguang 4 years ago
parent b6bf797ae7
commit d9f4659cfd
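Every hunk in this commit applies the same refactor: CheckAndConvertUtils::CheckTensorTypeSame and CheckTensorTypeValid now return a TypePtr instead of a TypeId, so callers drop the TypeIdToType round trip, the std::set<TypeId> valid-type sets become std::set<TypePtr> built from constants such as kFloat16 and kInt32, and the manual cast<TensorTypePtr>()->element() chains disappear. Below is a minimal before/after sketch of the caller pattern, using only names that appear in the hunks; the InferTypeOld/InferTypeNew names are illustrative only, and the snippet assumes the MindSpore ops headers and namespaces these files already pull in.

// Before this commit: the checker handed back a TypeId that had to be converted again.
TypePtr InferTypeOld(const PrimitivePtr &prim, const std::vector<AbstractBasePtr> &input_args) {
  std::map<std::string, TypePtr> types;
  types.emplace("input_x", input_args[0]->BuildType());
  auto infer_type = CheckAndConvertUtils::CheckTensorTypeSame(types, common_valid_types, prim->name());
  return TypeIdToType(infer_type);
}

// After this commit: the checker returns the TypePtr, so it is returned directly.
TypePtr InferTypeNew(const PrimitivePtr &prim, const std::vector<AbstractBasePtr> &input_args) {
  std::map<std::string, TypePtr> types;
  types.emplace("input_x", input_args[0]->BuildType());
  return CheckAndConvertUtils::CheckTensorTypeSame(types, common_valid_types, prim->name());
}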

@@ -46,8 +46,7 @@ TypePtr InferType(const PrimitivePtr &prim, const std::vector<AbstractBasePtr> &
}
std::map<std::string, TypePtr> types;
types.emplace("input_x", input_args[0]->BuildType());
auto infer_type = CheckAndConvertUtils::CheckTensorTypeSame(types, common_valid_types, prim->name());
return TypeIdToType(infer_type);
return CheckAndConvertUtils::CheckTensorTypeSame(types, common_valid_types, prim->name());
}
} // namespace

@@ -42,14 +42,10 @@ abstract::AbstractBasePtr AdamInfer(const PrimitivePtr &primitive, const std::ve
auto m_type = input_args[1]->BuildType();
auto v_type = input_args[2]->BuildType();
auto grad_type = input_args[9]->BuildType();
CheckAndConvertUtils::CheckTensorTypeValid("var_type", var_type, common_valid_types, prim_name);
CheckAndConvertUtils::CheckTensorTypeValid("m_type", m_type, common_valid_types, prim_name);
CheckAndConvertUtils::CheckTensorTypeValid("v_type", v_type, common_valid_types, prim_name);
CheckAndConvertUtils::CheckTensorTypeValid("grad_type", grad_type, common_valid_types, prim_name);
auto infer_var_type = var_type->cast<TensorTypePtr>()->element();
auto infer_m_type = m_type->cast<TensorTypePtr>()->element();
auto infer_v_type = v_type->cast<TensorTypePtr>()->element();
auto infer_var_type = CheckAndConvertUtils::CheckTensorTypeValid("var_type", var_type, common_valid_types, prim_name);
auto infer_m_type = CheckAndConvertUtils::CheckTensorTypeValid("m_type", m_type, common_valid_types, prim_name);
auto infer_v_type = CheckAndConvertUtils::CheckTensorTypeValid("v_type", v_type, common_valid_types, prim_name);
(void)CheckAndConvertUtils::CheckTensorTypeValid("grad_type", grad_type, common_valid_types, prim_name);
// auto infer_grad_type = grad_type->cast<TensorTypePtr>()->element();
auto output0 = std::make_shared<abstract::AbstractTensor>(infer_var_type, var_shape);
auto output1 = std::make_shared<abstract::AbstractTensor>(infer_m_type, m_shape);
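The AdamInfer hunk above shows the second recurring pattern: CheckTensorTypeValid now yields the checked element type, so the cast<TensorTypePtr>()->element() chains go away, and results that are not needed are explicitly discarded with a (void) cast. A hedged sketch of that caller shape, reusing only the names visible in the hunk:

// The validity check itself now produces the element TypePtr ...
auto infer_var_type = CheckAndConvertUtils::CheckTensorTypeValid("var_type", var_type, common_valid_types, prim_name);
// ... which feeds the output abstract tensor directly,
auto output0 = std::make_shared<abstract::AbstractTensor>(infer_var_type, var_shape);
// while checks whose result is unused are discarded explicitly.
(void)CheckAndConvertUtils::CheckTensorTypeValid("grad_type", grad_type, common_valid_types, prim_name);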

@@ -40,8 +40,7 @@ TypePtr InferType(const PrimitivePtr &prim, const std::vector<AbstractBasePtr> &
std::map<std::string, TypePtr> types;
types.emplace("x", input_args[0]->BuildType());
types.emplace("y", input_args[1]->BuildType());
auto infer_type = CheckAndConvertUtils::CheckTensorTypeSame(types, common_valid_types, prim->name());
return TypeIdToType(infer_type);
return CheckAndConvertUtils::CheckTensorTypeSame(types, common_valid_types, prim->name());
}
} // namespace

@@ -56,12 +56,10 @@ AbstractBasePtr AddNInfer(const abstract::AnalysisEnginePtr &, const PrimitivePt
}
types.emplace(elementi, elements[i]->BuildType());
}
std::set<TypeId> valid_types = common_valid_types;
valid_types.insert(kNumberTypeBool);
std::set<TypePtr> valid_types = common_valid_types;
valid_types.insert(kBool);
auto infer_type = CheckAndConvertUtils::CheckTensorTypeSame(types, valid_types, prim_name);
return std::make_shared<abstract::AbstractTensor>(TypeIdToType(infer_type),
std::make_shared<abstract::Shape>(element0_shape));
return std::make_shared<abstract::AbstractTensor>(infer_type, std::make_shared<abstract::Shape>(element0_shape));
}
REGISTER_PRIMITIVE_C(kNameAddN, AddN);
} // namespace ops
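The AddN hunk is the first of many that rebuild the valid-type sets: std::set<TypeId> filled with kNumberType* enum values becomes std::set<TypePtr> filled with the corresponding type constants. A short sketch of that substitution, with the constant pairs seen elsewhere in the diff collected as a comment for convenience (not part of the committed change):

// Old TypeId-based set, as removed above:
//   std::set<TypeId> valid_types = common_valid_types;
//   valid_types.insert(kNumberTypeBool);
// New TypePtr-based set (common_valid_types is now a std::set<TypePtr> as well):
std::set<TypePtr> valid_types = common_valid_types;
valid_types.insert(kBool);
// The same one-for-one substitution recurs across the other hunks:
//   kNumberTypeInt8 -> kInt8, kNumberTypeInt32 -> kInt32, kNumberTypeInt64 -> kInt64,
//   kNumberTypeFloat16 -> kFloat16, kNumberTypeFloat32 -> kFloat32, kNumberTypeFloat64 -> kFloat64,
//   kObjectTypeTensorType -> kTensorType.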

@@ -68,7 +68,7 @@ AbstractBasePtr ApplyMomentumInfer(const abstract::AnalysisEnginePtr &, const Pr
auto l_type = input_args[2]->BuildType();
auto g_type = input_args[3]->BuildType();
auto m_type = input_args[4]->BuildType();
const std::set<TypeId> valid_types = {kNumberTypeFloat16, kNumberTypeFloat32, kNumberTypeFloat64};
const std::set<TypePtr> valid_types = {kFloat16, kFloat32, kFloat64};
CheckAndConvertUtils::CheckTensorTypeValid("v_type", v_tensor_type, valid_types, prim_name);
CheckAndConvertUtils::CheckTensorTypeValid("a_type", a_tensor_type, valid_types, prim_name);
std::map<std::string, TypePtr> args;

@@ -62,9 +62,6 @@ AbstractBasePtr ArgMinInfer(const abstract::AnalysisEnginePtr &, const Primitive
// Infer type
auto x_dtype = input_args[0]->BuildType()->cast<TensorTypePtr>()->element();
std::set<TypePtr> template_types = {TypeIdToType(kObjectTypeTensorType)};
CheckAndConvertUtils::CheckSubClass("x_dtype", x_dtype, template_types, prim_name);
return std::make_shared<abstract::AbstractTensor>(x_dtype, std::make_shared<abstract::Shape>(out_shape));
}
REGISTER_PRIMITIVE_C(kNameArgMin, ArgMin);

@@ -36,14 +36,8 @@ AbstractBasePtr AsinInfer(const abstract::AnalysisEnginePtr &, const PrimitivePt
// Infer Type
auto dtype = input_args[0]->BuildType();
const std::set<TypeId> valid_types = {kNumberTypeFloat16, kNumberTypeFloat32, kNumberTypeInt32};
CheckAndConvertUtils::CheckTensorTypeValid("x_dtype", dtype, valid_types, prim_name);
auto tensor_type = dtype->cast<TensorTypePtr>();
MS_EXCEPTION_IF_NULL(tensor_type);
auto element = tensor_type->element();
MS_EXCEPTION_IF_NULL(element);
auto infer_type = std::make_shared<TensorType>(TypeIdToType(element->type_id()));
const std::set<TypePtr> valid_types = {kFloat16, kFloat32, kInt32};
auto infer_type = CheckAndConvertUtils::CheckTensorTypeValid("x_dtype", dtype, valid_types, prim_name);
return std::make_shared<abstract::AbstractTensor>(infer_type, infer_shape->shape());
}
REGISTER_PRIMITIVE_C(kNameAsin, Asin);

@@ -61,15 +61,15 @@ AbstractBasePtr AssertInfer(const abstract::AnalysisEnginePtr &, const Primitive
condition = input_args[0]->BuildType();
}
std::vector<int64_t> output_shape = {1};
std::set<TypeId> local_bool = {kNumberTypeBool};
std::set<TypePtr> local_bool = {kBool};
std::map<std::string, TypePtr> args = {{"condition", condition}};
CheckAndConvertUtils::CheckScalarOrTensorTypesSame(args, local_bool, op_name);
(void)CheckAndConvertUtils::CheckScalarOrTensorTypesSame(args, local_bool, op_name);
auto inputs_type = input_args[1]->BuildType()->cast<TuplePtr>()->elements();
for (auto dtype : inputs_type) {
std::set<TypePtr> template_types = {TypeIdToType(kObjectTypeTensorType)};
std::set<TypePtr> template_types = {kTensorType};
CheckAndConvertUtils::CheckSubClass("input", dtype, template_types, op_name);
}
return std::make_shared<abstract::AbstractTensor>(TypeIdToType(kNumberTypeInt32), output_shape);
return std::make_shared<abstract::AbstractTensor>(kInt32, output_shape);
}
REGISTER_PRIMITIVE_C(kNameAssert, Assert);
} // namespace ops

@@ -38,8 +38,7 @@ TypePtr InferType(const PrimitivePtr &primitive, const std::vector<AbstractBaseP
types.emplace("x", input_args[0]->BuildType());
types.emplace("w", input_args[1]->BuildType());
// check_scalar_or_tensor_types_same
auto infer_type = CheckAndConvertUtils::CheckTensorTypeSame(types, common_valid_types, "AssignAdd");
return TypeIdToType(infer_type);
return CheckAndConvertUtils::CheckTensorTypeSame(types, common_valid_types, "AssignAdd");
}
} // namespace
AbstractBasePtr AssignAddInfer(const abstract::AnalysisEnginePtr &, const PrimitivePtr &primitive,

@@ -34,15 +34,9 @@ AbstractBasePtr AtanInfer(const abstract::AnalysisEnginePtr &, const PrimitivePt
// Infer Type
auto dtype = input_args[0]->BuildType();
const std::set<TypeId> valid_types = {kNumberTypeFloat16, kNumberTypeFloat32, kNumberTypeInt32};
CheckAndConvertUtils::CheckTensorTypeValid("x_dtype", dtype, valid_types, prim_name);
auto tensor_type = dtype->cast<TensorTypePtr>();
MS_EXCEPTION_IF_NULL(tensor_type);
auto element = tensor_type->element();
MS_EXCEPTION_IF_NULL(element);
auto infer_type = std::make_shared<TensorType>(TypeIdToType(element->type_id()));
return std::make_shared<abstract::AbstractTensor>(infer_type, infer_shape->shape());
const std::set<TypePtr> valid_types = {kFloat16, kFloat32, kInt32};
auto element = CheckAndConvertUtils::CheckTensorTypeValid("x_dtype", dtype, valid_types, prim_name);
return std::make_shared<abstract::AbstractTensor>(element, infer_shape->shape());
}
REGISTER_PRIMITIVE_C(kNameAtan, Atan);
} // namespace ops

@@ -107,20 +107,20 @@ AbstractBasePtr BatchNormInfer(const abstract::AnalysisEnginePtr &, const Primit
}
// Infer type
auto input_x_type = input_args[0]->BuildType()->cast<TensorTypePtr>()->element();
auto scale_type = input_args[1]->BuildType()->cast<TensorTypePtr>()->element();
auto bias_type = input_args[2]->BuildType()->cast<TensorTypePtr>()->element();
const std::set<TypeId> valid_types = {kNumberTypeFloat16, kNumberTypeFloat32};
const std::set<TypePtr> valid_types = {kFloat16, kFloat32};
auto input_x_type =
CheckAndConvertUtils::CheckTensorTypeValid("input_x", input_args[0]->BuildType(), valid_types, prim_name);
std::map<std::string, TypePtr> args;
args.emplace("scale", input_args[1]->BuildType());
args.emplace("bias", input_args[2]->BuildType());
CheckAndConvertUtils::CheckTensorTypeSame(args, valid_types, prim_name);
(void)CheckAndConvertUtils::CheckTensorTypeSame(args, valid_types, prim_name);
std::map<std::string, TypePtr> args_moving;
args_moving.emplace("scale", input_args[2]->BuildType());
args_moving.emplace("bias", input_args[3]->BuildType());
CheckAndConvertUtils::CheckTensorTypeSame(args_moving, valid_types, prim_name);
(void)CheckAndConvertUtils::CheckTensorTypeSame(args_moving, valid_types, prim_name);
auto output0 = std::make_shared<abstract::AbstractTensor>(input_x_type, input_x);
auto output1 = std::make_shared<abstract::AbstractTensor>(scale_type, scale);

@@ -87,23 +87,8 @@ AbstractBasePtr BatchNormFoldInfer(const abstract::AnalysisEnginePtr &, const Pr
auto global_step_type = input_args[3]->BuildType();
std::map<std::string, TypePtr> args = {{"x", x_type}, {"mean", mean_type}, {"variance", variance_type}};
CheckAndConvertUtils::CheckTensorTypeSame(args, {kNumberTypeFloat16, kNumberTypeFloat32}, op_name);
CheckAndConvertUtils::CheckTensorTypeValid("gloabal_step", global_step_type, {kNumberTypeInt32}, op_name);
auto tensor_type0 = x_type->cast<TensorTypePtr>();
MS_EXCEPTION_IF_NULL(tensor_type0);
auto element0 = tensor_type0->element();
auto tensor_type1 = mean_type->cast<TensorTypePtr>();
MS_EXCEPTION_IF_NULL(tensor_type1);
auto element1 = tensor_type1->element();
auto tensor_type2 = variance_type->cast<TensorTypePtr>();
MS_EXCEPTION_IF_NULL(tensor_type2);
auto element2 = tensor_type2->element();
CheckAndConvertUtils::Check("input type", element0->type_id(), kEqual, "mean_type", element1->type_id(), op_name);
CheckAndConvertUtils::Check("input type", element0->type_id(), kEqual, "variance_type", element2->type_id(), op_name);
auto element0 = CheckAndConvertUtils::CheckTensorTypeSame(args, {kFloat16, kFloat32}, op_name);
CheckAndConvertUtils::CheckTensorTypeValid("gloabal_step", global_step_type, {kInt32}, op_name);
auto output = std::make_shared<abstract::AbstractTensor>(element0, mean_shape);
AbstractBasePtrList output1 = {output, output, output, output};

@@ -54,7 +54,8 @@ AbstractBasePtr BatchToSpaceInfer(const abstract::AnalysisEnginePtr &, const Pri
for (const auto &item : input_args) {
MS_EXCEPTION_IF_NULL(item);
}
CheckAndConvertUtils::CheckTensorTypeValid("input_x", input_args[0]->BuildType(), common_valid_types, prim_name);
(void)CheckAndConvertUtils::CheckTensorTypeValid("input_x", input_args[0]->BuildType(), common_valid_types,
prim_name);
auto x_shape = CheckAndConvertUtils::ConvertShapePtrToShape("x_shape", input_args[0]->BuildShape(), prim_name);
CheckAndConvertUtils::CheckInteger("x rank", x_shape.size(), kEqual, 4, prim_name);

@@ -55,8 +55,7 @@ TypePtr BiasAddInferType(const PrimitivePtr &prim, const std::vector<AbstractBas
std::map<std::string, TypePtr> types;
types.emplace("input_x", input_args[0]->BuildType());
types.emplace("bias", input_args[1]->BuildType());
auto infer_type = CheckAndConvertUtils::CheckTensorTypeSame(types, common_valid_types, prim->name());
return TypeIdToType(infer_type);
return CheckAndConvertUtils::CheckTensorTypeSame(types, common_valid_types, prim->name());
}
} // namespace
void BiasAdd::set_format(const Format &format) {

@@ -57,7 +57,7 @@ TypePtr BinaryCrossEntroyInferType(const PrimitivePtr &prim, const std::vector<A
for (const auto &item : input_args) {
MS_EXCEPTION_IF_NULL(item);
}
const std::set<TypeId> valid_types = {kNumberTypeFloat16, kNumberTypeFloat32};
const std::set<TypePtr> valid_types = {kFloat16, kFloat32};
std::map<std::string, TypePtr> types;
types.emplace("x_shape", input_args[0]->BuildType());
types.emplace("y_shape", input_args[1]->BuildType());
@@ -67,7 +67,7 @@ TypePtr BinaryCrossEntroyInferType(const PrimitivePtr &prim, const std::vector<A
types.emplace("weight_shape", input_args[2]->BuildType());
infer_type = CheckAndConvertUtils::CheckTensorTypeSame(types, valid_types, prim->name());
}
return TypeIdToType(infer_type);
return infer_type;
}
} // namespace

@@ -56,7 +56,7 @@ AbstractBasePtr BroadcastInfer(const abstract::AnalysisEnginePtr &, const Primit
// infer type
auto x_type = input_args[0]->BuildType()->cast<TensorTypePtr>()->element();
std::vector<TypePtr> output_types;
const std::set<TypeId> valid_types = {kNumberTypeInt8, kNumberTypeInt32, kNumberTypeFloat16, kNumberTypeFloat32};
const std::set<TypePtr> valid_types = {kInt8, kInt32, kFloat16, kFloat32};
for (size_t i = 0; i < input_args.size(); i++) {
auto out_type = input_args[i]->BuildType()->cast<TensorTypePtr>()->element();
output_types.push_back(out_type);

@@ -57,11 +57,10 @@ TypePtr BroadcastToInferType(const PrimitivePtr &prim, const std::vector<Abstrac
for (const auto &item : input_args) {
MS_EXCEPTION_IF_NULL(item);
}
auto x_dtype = input_args[0]->BuildType()->cast<TensorTypePtr>()->element();
std::set<TypePtr> template_types = {TypeIdToType(kObjectTypeTensorType)};
auto x_dtype = input_args[0]->BuildType()->cast<TensorTypePtr>();
std::set<TypePtr> template_types = {kTensorType};
CheckAndConvertUtils::CheckSubClass("x_dtype", x_dtype, template_types, prim->name());
auto infer_dtype = input_args[0]->BuildType()->type_id();
return TypeIdToType(infer_dtype);
return x_dtype->element();
}
} // namespace

@@ -33,13 +33,9 @@ AbstractBasePtr CeilInfer(const abstract::AnalysisEnginePtr &, const PrimitivePt
MS_EXCEPTION_IF_NULL(item);
}
auto x_shape = CheckAndConvertUtils::ConvertShapePtrToShape("x_shape", input_args[0]->BuildShape(), "Ceil");
const std::set<TypeId> valid_types = {kNumberTypeFloat16, kNumberTypeFloat32};
const std::set<TypePtr> valid_types = {kFloat16, kFloat32};
auto infer_type = input_args[0]->BuildType();
CheckAndConvertUtils::CheckTensorTypeValid("x type", infer_type, valid_types, primitive->name());
MS_EXCEPTION_IF_NULL(infer_type);
auto tensor_type = infer_type->cast<TensorTypePtr>();
MS_EXCEPTION_IF_NULL(tensor_type);
auto data_type = tensor_type->element();
auto data_type = CheckAndConvertUtils::CheckTensorTypeValid("x type", infer_type, valid_types, primitive->name());
MS_EXCEPTION_IF_NULL(data_type);
return std::make_shared<abstract::AbstractTensor>(data_type, x_shape);
}

@@ -74,9 +74,7 @@ AbstractBasePtr ConcatInfer(const abstract::AnalysisEnginePtr &, const Primitive
auto infer_type = CheckAndConvertUtils::CheckTensorTypeSame(types, all_types, prim_name);
auto ret_shape = element0_shape;
ret_shape[axis] = all_shp;
return std::make_shared<abstract::AbstractTensor>(TypeIdToType(infer_type),
std::make_shared<abstract::Shape>(ret_shape));
return std::make_shared<abstract::AbstractTensor>(infer_type, std::make_shared<abstract::Shape>(ret_shape));
}
REGISTER_PRIMITIVE_C(kNameConcat, Concat);
} // namespace ops

@@ -42,8 +42,7 @@ TypePtr InferType(const PrimitivePtr &prim, const std::vector<AbstractBasePtr> &
}
std::map<std::string, TypePtr> types;
types.emplace("x", input_args[0]->BuildType());
auto infer_type = CheckAndConvertUtils::CheckTensorTypeSame(types, common_valid_types, prim->name());
return TypeIdToType(infer_type);
return CheckAndConvertUtils::CheckTensorTypeSame(types, common_valid_types, prim->name());
}
} // namespace

@@ -107,16 +107,11 @@ TypePtr Conv2dInferType(const PrimitivePtr &prim, const std::vector<AbstractBase
for (const auto &item : input_args) {
MS_EXCEPTION_IF_NULL(item);
}
const std::set<TypeId> valid_types = {kNumberTypeInt8, kNumberTypeInt32, kNumberTypeInt64, kNumberTypeFloat16,
kNumberTypeFloat32};
const std::set<TypePtr> valid_types = {kInt8, kInt32, kInt64, kFloat16, kFloat32};
std::map<std::string, TypePtr> types;
types.emplace("x", input_args[0]->BuildType());
types.emplace("w", input_args[1]->BuildType());
auto infer_type = CheckAndConvertUtils::CheckTensorTypeSame(types, valid_types, prim->name());
if (infer_type == kNumberTypeInt8) {
return TypeIdToType(kNumberTypeInt32);
}
return TypeIdToType(infer_type);
return CheckAndConvertUtils::CheckTensorTypeSame(types, valid_types, prim->name());
}
} // namespace
void Conv2D::Init(int64_t out_channel, const std::vector<int64_t> &kernel_size, int64_t mode, const PadMode &pad_mode,

@@ -40,12 +40,11 @@ TypePtr Conv2dTransposeInferType(const PrimitivePtr &prim, const std::vector<Abs
for (const auto &item : input_args) {
MS_EXCEPTION_IF_NULL(item);
}
const std::set<TypeId> valid_types = {kNumberTypeInt8, kNumberTypeInt32, kNumberTypeFloat16, kNumberTypeFloat32};
const std::set<TypePtr> valid_types = {kInt8, kInt32, kFloat16, kFloat32};
std::map<std::string, TypePtr> types;
types.emplace("doutput_dtye", input_args[0]->BuildType());
types.emplace("w_dtype", input_args[1]->BuildType());
auto infer_type = CheckAndConvertUtils::CheckTensorTypeSame(types, valid_types, prim->name());
return TypeIdToType(infer_type);
return CheckAndConvertUtils::CheckTensorTypeSame(types, valid_types, prim->name());
}
} // namespace

@@ -40,8 +40,7 @@ TypePtr InferType(const PrimitivePtr &prim, const std::vector<AbstractBasePtr> &
}
std::map<std::string, TypePtr> types;
types.emplace("x", input_args[0]->BuildType());
auto infer_type = CheckAndConvertUtils::CheckTensorTypeSame(types, common_valid_types, prim->name());
return TypeIdToType(infer_type);
return CheckAndConvertUtils::CheckTensorTypeSame(types, common_valid_types, prim->name());
}
} // namespace

@@ -31,8 +31,8 @@ AbstractBasePtr CustomExtractFeaturesInfer(const abstract::AnalysisEnginePtr &,
// auto input = input_args[0];
// Infer type
auto output0_type = TypeIdToType(kNumberTypeInt32);
auto output1_type = TypeIdToType(kNumberTypeFloat32);
auto output0_type = kInt32;
auto output1_type = kFloat32;
// Infer shape
std::vector<int64_t> out_shape;

@@ -47,14 +47,14 @@ AbstractBasePtr CustomPredictInfer(const abstract::AnalysisEnginePtr &, const Pr
MS_EXCEPTION_IF_NULL(primitive);
auto CustomPredict_prim = primitive->cast<PrimCustomPredictPtr>();
MS_EXCEPTION_IF_NULL(CustomPredict_prim);
for (auto input : input_args) {
for (const auto &input : input_args) {
MS_EXCEPTION_IF_NULL(input);
}
std::vector<int64_t> shape;
shape.push_back(CustomPredict_prim->get_output_num());
auto output0 = std::make_shared<abstract::AbstractTensor>(TypeIdToType(kNumberTypeInt32), shape);
auto output1 = std::make_shared<abstract::AbstractTensor>(TypeIdToType(kNumberTypeFloat32), shape);
auto output0 = std::make_shared<abstract::AbstractTensor>(kInt32, shape);
auto output1 = std::make_shared<abstract::AbstractTensor>(kFloat32, shape);
AbstractBasePtrList output = {output0, output1};
return std::make_shared<abstract::AbstractTuple>(output);
}
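The CustomExtractFeatures and CustomPredict hunks show the same constants used where an output dtype is fixed rather than inferred; a brief sketch built only from the calls in those hunks (shape is the vector declared in CustomPredictInfer):

// Fixed-dtype outputs take the TypePtr constants directly, with no TypeIdToType call.
auto output0 = std::make_shared<abstract::AbstractTensor>(kInt32, shape);
auto output1 = std::make_shared<abstract::AbstractTensor>(kFloat32, shape);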

Some files were not shown because too many files have changed in this diff.