|
|
|
@ -87,23 +87,8 @@ AbstractBasePtr BatchNormFoldInfer(const abstract::AnalysisEnginePtr &, const Pr
|
|
|
|
|
auto global_step_type = input_args[3]->BuildType();
|
|
|
|
|
|
|
|
|
|
std::map<std::string, TypePtr> args = {{"x", x_type}, {"mean", mean_type}, {"variance", variance_type}};
|
|
|
|
|
CheckAndConvertUtils::CheckTensorTypeSame(args, {kNumberTypeFloat16, kNumberTypeFloat32}, op_name);
|
|
|
|
|
CheckAndConvertUtils::CheckTensorTypeValid("gloabal_step", global_step_type, {kNumberTypeInt32}, op_name);
|
|
|
|
|
|
|
|
|
|
auto tensor_type0 = x_type->cast<TensorTypePtr>();
|
|
|
|
|
MS_EXCEPTION_IF_NULL(tensor_type0);
|
|
|
|
|
auto element0 = tensor_type0->element();
|
|
|
|
|
|
|
|
|
|
auto tensor_type1 = mean_type->cast<TensorTypePtr>();
|
|
|
|
|
MS_EXCEPTION_IF_NULL(tensor_type1);
|
|
|
|
|
auto element1 = tensor_type1->element();
|
|
|
|
|
|
|
|
|
|
auto tensor_type2 = variance_type->cast<TensorTypePtr>();
|
|
|
|
|
MS_EXCEPTION_IF_NULL(tensor_type2);
|
|
|
|
|
auto element2 = tensor_type2->element();
|
|
|
|
|
|
|
|
|
|
CheckAndConvertUtils::Check("input type", element0->type_id(), kEqual, "mean_type", element1->type_id(), op_name);
|
|
|
|
|
CheckAndConvertUtils::Check("input type", element0->type_id(), kEqual, "variance_type", element2->type_id(), op_name);
|
|
|
|
|
auto element0 = CheckAndConvertUtils::CheckTensorTypeSame(args, {kFloat16, kFloat32}, op_name);
|
|
|
|
|
CheckAndConvertUtils::CheckTensorTypeValid("gloabal_step", global_step_type, {kInt32}, op_name);
|
|
|
|
|
|
|
|
|
|
auto output = std::make_shared<abstract::AbstractTensor>(element0, mean_shape);
|
|
|
|
|
AbstractBasePtrList output1 = {output, output, output, output};
|
|
|
|
|