@@ -87,9 +87,13 @@ class BatchNormOp : public framework::OperatorWithKernel {
       const framework::ExecutionContext &ctx) const override {
     auto input_data_type =
         framework::ToDataType(ctx.Input<Tensor>("X")->type());
-    // For float or float16 input tensor, the type of the scale, bias, mean,
-    // and var tensors should both be float.
+    // The scale, bias, mean, and var tensors should all be float
+    // when the input tensor is float or float16,
+    // and double when the input tensor is double.
     auto bn_param_type = framework::proto::VarType::FP32;
+    if (input_data_type == framework::proto::VarType::FP64) {
+      bn_param_type = framework::proto::VarType::FP64;
+    }
     PADDLE_ENFORCE_EQ(bn_param_type,
                       framework::ToDataType(ctx.Input<Tensor>("Scale")->type()),
                       "Scale input should be of float type");
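The hunk above widens `GetExpectedKernelType` so the scale, bias, mean, and variance parameters are expected in double precision whenever the input `X` is double, while float and float16 inputs keep float parameters. A minimal standalone sketch of that promotion rule (plain C++, not the Paddle source; `VarType` and `BatchNormParamType` here are hypothetical stand-ins for illustration):

```cpp
#include <cassert>

// Hypothetical stand-in for framework::proto::VarType.
enum class VarType { FP16, FP32, FP64 };

// Mirrors the rule in the hunk above: parameters are FP64 only when the
// input tensor itself is FP64; float and float16 inputs keep FP32 params.
VarType BatchNormParamType(VarType input_data_type) {
  return input_data_type == VarType::FP64 ? VarType::FP64 : VarType::FP32;
}

int main() {
  assert(BatchNormParamType(VarType::FP16) == VarType::FP32);
  assert(BatchNormParamType(VarType::FP32) == VarType::FP32);
  assert(BatchNormParamType(VarType::FP64) == VarType::FP64);
  return 0;
}
```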
@@ -492,8 +496,9 @@ REGISTER_OPERATOR(batch_norm, ops::BatchNormOp, ops::BatchNormOpMaker,
 REGISTER_OPERATOR(batch_norm_grad, ops::BatchNormGradOp);

 REGISTER_OP_CPU_KERNEL(
-    batch_norm,
-    ops::BatchNormKernel<paddle::platform::CPUDeviceContext, float>);
+    batch_norm, ops::BatchNormKernel<paddle::platform::CPUDeviceContext, float>,
+    ops::BatchNormKernel<paddle::platform::CPUDeviceContext, double>);
 REGISTER_OP_CPU_KERNEL(
     batch_norm_grad,
-    ops::BatchNormGradKernel<paddle::platform::CPUDeviceContext, float>);
+    ops::BatchNormGradKernel<paddle::platform::CPUDeviceContext, float>,
+    ops::BatchNormGradKernel<paddle::platform::CPUDeviceContext, double>);
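This second hunk registers `double` instantiations of the CPU forward and backward kernels alongside the existing `float` ones, so kernel lookup keyed on the input's element type can now resolve a double-precision kernel. A rough sketch of what such type-keyed registration amounts to (ordinary C++, not Paddle's actual `REGISTER_OP_CPU_KERNEL` machinery; all names below are illustrative):

```cpp
#include <functional>
#include <iostream>
#include <map>
#include <string>
#include <typeindex>
#include <typeinfo>
#include <utility>

// A registry mapping (op name, element type) to a kernel, loosely analogous
// to registering BatchNormKernel<..., float> and <..., double> above.
using Key = std::pair<std::string, std::type_index>;
std::map<Key, std::function<void()>> g_kernels;

template <typename T>
void RegisterCpuKernel(const std::string &op, std::function<void()> kernel) {
  g_kernels[{op, std::type_index(typeid(T))}] = std::move(kernel);
}

int main() {
  // One op name, two element types, as in the hunk above.
  RegisterCpuKernel<float>("batch_norm", [] { std::cout << "fp32 kernel\n"; });
  RegisterCpuKernel<double>("batch_norm", [] { std::cout << "fp64 kernel\n"; });

  // Dispatch on the runtime element type of the input tensor.
  g_kernels.at({"batch_norm", std::type_index(typeid(double))})();  // prints "fp64 kernel"
  return 0;
}
```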