|
|
@ -107,20 +107,20 @@ AbstractBasePtr BatchNormInfer(const abstract::AnalysisEnginePtr &, const Primit
|
|
|
|
}
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
|
|
// Infer type
|
|
|
|
// Infer type
|
|
|
|
auto input_x_type = input_args[0]->BuildType()->cast<TensorTypePtr>()->element();
|
|
|
|
|
|
|
|
auto scale_type = input_args[1]->BuildType()->cast<TensorTypePtr>()->element();
|
|
|
|
auto scale_type = input_args[1]->BuildType()->cast<TensorTypePtr>()->element();
|
|
|
|
auto bias_type = input_args[2]->BuildType()->cast<TensorTypePtr>()->element();
|
|
|
|
auto bias_type = input_args[2]->BuildType()->cast<TensorTypePtr>()->element();
|
|
|
|
|
|
|
|
|
|
|
|
const std::set<TypeId> valid_types = {kNumberTypeFloat16, kNumberTypeFloat32};
|
|
|
|
const std::set<TypePtr> valid_types = {kFloat16, kFloat32};
|
|
|
|
CheckAndConvertUtils::CheckTensorTypeValid("input_x", input_args[0]->BuildType(), valid_types, prim_name);
|
|
|
|
auto input_x_type =
|
|
|
|
|
|
|
|
CheckAndConvertUtils::CheckTensorTypeValid("input_x", input_args[0]->BuildType(), valid_types, prim_name);
|
|
|
|
std::map<std::string, TypePtr> args;
|
|
|
|
std::map<std::string, TypePtr> args;
|
|
|
|
args.emplace("scale", input_args[1]->BuildType());
|
|
|
|
args.emplace("scale", input_args[1]->BuildType());
|
|
|
|
args.emplace("bias", input_args[2]->BuildType());
|
|
|
|
args.emplace("bias", input_args[2]->BuildType());
|
|
|
|
CheckAndConvertUtils::CheckTensorTypeSame(args, valid_types, prim_name);
|
|
|
|
(void)CheckAndConvertUtils::CheckTensorTypeSame(args, valid_types, prim_name);
|
|
|
|
std::map<std::string, TypePtr> args_moving;
|
|
|
|
std::map<std::string, TypePtr> args_moving;
|
|
|
|
args_moving.emplace("scale", input_args[2]->BuildType());
|
|
|
|
args_moving.emplace("scale", input_args[2]->BuildType());
|
|
|
|
args_moving.emplace("bias", input_args[3]->BuildType());
|
|
|
|
args_moving.emplace("bias", input_args[3]->BuildType());
|
|
|
|
CheckAndConvertUtils::CheckTensorTypeSame(args_moving, valid_types, prim_name);
|
|
|
|
(void)CheckAndConvertUtils::CheckTensorTypeSame(args_moving, valid_types, prim_name);
|
|
|
|
|
|
|
|
|
|
|
|
auto output0 = std::make_shared<abstract::AbstractTensor>(input_x_type, input_x);
|
|
|
|
auto output0 = std::make_shared<abstract::AbstractTensor>(input_x_type, input_x);
|
|
|
|
auto output1 = std::make_shared<abstract::AbstractTensor>(scale_type, scale);
|
|
|
|
auto output1 = std::make_shared<abstract::AbstractTensor>(scale_type, scale);
|
|
|
|