@@ -15,6 +15,9 @@ limitations under the License. */
 #include "paddle/fluid/operators/batch_norm_op.h"
 #include <string>
 #include "paddle/fluid/framework/data_layout.h"
+#ifdef PADDLE_WITH_MKLDNN
+#include "paddle/fluid/platform/mkldnn_helper.h"
+#endif
 
 namespace paddle {
 namespace operators {
@@ -87,9 +90,13 @@ class BatchNormOp : public framework::OperatorWithKernel {
       const framework::ExecutionContext &ctx) const override {
     auto input_data_type =
         framework::ToDataType(ctx.Input<Tensor>("X")->type());
-    // For float or float16 input tensor, the type of the scale, bias, mean,
-    // and var tensors should both be float.
+    // By default, the type of the scale, bias, mean,
+    // and var tensors should both be float. (For float or float16 input tensor)
+    // or double (For double input tensor).
     auto bn_param_type = framework::proto::VarType::FP32;
+    if (input_data_type == framework::proto::VarType::FP64) {
+      bn_param_type = framework::proto::VarType::FP64;
+    }
     PADDLE_ENFORCE_EQ(bn_param_type,
                       framework::ToDataType(ctx.Input<Tensor>("Scale")->type()),
                       "Scale input should be of float type");
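For reference, the dtype rule this hunk encodes stands on its own: Scale, Bias, Mean, and Variance stay FP32 for float and float16 inputs, and follow the input type only when it is double. A minimal self-contained sketch of that rule (plain C++, not Paddle code; `DataType` and `ExpectedParamType` are illustrative stand-ins for `framework::proto::VarType` and the in-kernel logic):

#include <cassert>

// Illustrative stand-in for framework::proto::VarType's type enum.
enum class DataType { FP16, FP32, FP64 };

// Scale/Bias/Mean/Variance stay FP32 for FP16 and FP32 inputs;
// only an FP64 input promotes them to FP64.
DataType ExpectedParamType(DataType input) {
  return input == DataType::FP64 ? DataType::FP64 : DataType::FP32;
}

int main() {
  assert(ExpectedParamType(DataType::FP16) == DataType::FP32);
  assert(ExpectedParamType(DataType::FP32) == DataType::FP32);
  assert(ExpectedParamType(DataType::FP64) == DataType::FP64);
  return 0;
}

Keeping the parameters in FP32 under FP16 inputs matches the usual mixed-precision practice of holding statistics at higher precision than the activations.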
@@ -102,7 +109,18 @@ class BatchNormOp : public framework::OperatorWithKernel {
     PADDLE_ENFORCE_EQ(bn_param_type, framework::ToDataType(
                                          ctx.Input<Tensor>("Variance")->type()),
                       "Variance input should be of float type");
-    return framework::OpKernelType(input_data_type, ctx.GetPlace());
+
+    framework::LibraryType library_{framework::LibraryType::kPlain};
+#ifdef PADDLE_WITH_MKLDNN
+    if (library_ == framework::LibraryType::kPlain &&
+        platform::CanMKLDNNBeUsed(ctx)) {
+      library_ = framework::LibraryType::kMKLDNN;
+    }
+#endif
+    // TODO(pzelazko-intel): enable MKLDNN layout when it's ready
+    framework::DataLayout layout = framework::DataLayout::kAnyLayout;
+    return framework::OpKernelType(input_data_type, ctx.GetPlace(), layout,
+                                   library_);
   }
 };
 
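Both GetExpectedKernelType bodies in this diff share one selection pattern: default to the plain library, and upgrade to MKLDNN only when the build was compiled with it and the runtime context allows it. A hedged standalone sketch of that fallback (not Paddle's API; `BuildHasMKLDNN` and `ContextAllowsMKLDNN` are hypothetical stand-ins for the `PADDLE_WITH_MKLDNN` guard and `platform::CanMKLDNNBeUsed(ctx)`):

#include <iostream>

enum class LibraryType { kPlain, kMKLDNN };

// Hypothetical stand-ins for the compile-time PADDLE_WITH_MKLDNN guard
// and the runtime platform::CanMKLDNNBeUsed(ctx) check.
bool BuildHasMKLDNN() { return true; }
bool ContextAllowsMKLDNN() { return true; }

// Same shape as the hunk above: start from kPlain and upgrade only
// when both the build and the execution context permit MKLDNN.
LibraryType SelectLibrary() {
  LibraryType library = LibraryType::kPlain;
  if (BuildHasMKLDNN() && library == LibraryType::kPlain &&
      ContextAllowsMKLDNN()) {
    library = LibraryType::kMKLDNN;
  }
  return library;
}

int main() {
  std::cout << (SelectLibrary() == LibraryType::kMKLDNN ? "mkldnn" : "plain")
            << "\n";
  return 0;
}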
@@ -147,6 +165,9 @@ class BatchNormOpMaker : public framework::OpProtoAndCheckerMaker {
               "Variance of the current mini batch, "
               "will apply to output when training")
         .AsIntermediate();
+    AddAttr<bool>("use_mkldnn",
+                  "(bool, default false) Only used in mkldnn kernel")
+        .SetDefault(false);
     AddComment(R"DOC(
 Batch Normalization.
 
@@ -345,8 +366,19 @@ class BatchNormGradOp : public framework::OperatorWithKernel {
     if (t == nullptr) {
       PADDLE_THROW("can't find Y@GRAD");
     }
-    return framework::OpKernelType(framework::ToDataType(t->type()),
-                                   ctx.GetPlace());
+
+    framework::LibraryType library_{framework::LibraryType::kPlain};
+#ifdef PADDLE_WITH_MKLDNN
+    if (library_ == framework::LibraryType::kPlain &&
+        platform::CanMKLDNNBeUsed(ctx)) {
+      library_ = framework::LibraryType::kMKLDNN;
+    }
+#endif
+    // TODO(pzelazko-intel): enable MKLDNN layout when it's ready
+    framework::DataLayout layout = framework::DataLayout::kAnyLayout;
+    return framework::OpKernelType(
+        framework::ToDataType(ctx.Input<Tensor>("X")->type()), ctx.GetPlace(),
+        layout, library_);
   }
 };
 
@@ -470,6 +502,7 @@ class BatchNormGradMaker : public framework::SingleGradOpDescMaker {
     op->SetInput(framework::GradVarName("Y"), OutputGrad("Y"));
 
     op->SetInput("Scale", Input("Scale"));
+    op->SetInput("Bias", Input("Bias"));
     op->SetInput("SavedMean", Output("SavedMean"));
     op->SetInput("SavedVariance", Output("SavedVariance"));
 
@@ -492,8 +525,9 @@ REGISTER_OPERATOR(batch_norm, ops::BatchNormOp, ops::BatchNormOpMaker,
 REGISTER_OPERATOR(batch_norm_grad, ops::BatchNormGradOp);
 
 REGISTER_OP_CPU_KERNEL(
-    batch_norm,
-    ops::BatchNormKernel<paddle::platform::CPUDeviceContext, float>);
+    batch_norm, ops::BatchNormKernel<paddle::platform::CPUDeviceContext, float>,
+    ops::BatchNormKernel<paddle::platform::CPUDeviceContext, double>);
 REGISTER_OP_CPU_KERNEL(
     batch_norm_grad,
-    ops::BatchNormGradKernel<paddle::platform::CPUDeviceContext, float>);
+    ops::BatchNormGradKernel<paddle::platform::CPUDeviceContext, float>,
+    ops::BatchNormGradKernel<paddle::platform::CPUDeviceContext, double>);
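The registration change above follows the standard pattern for widening dtype coverage: a single kernel template instantiated once per supported element type, all listed under one op name. A minimal sketch of the idea outside Paddle's registry machinery (the stub kernel and map-based registry here are illustrative, not Paddle's API):

#include <functional>
#include <iostream>
#include <map>
#include <string>
#include <typeinfo>
#include <vector>

// Illustrative stand-in for a device/dtype-templated kernel such as
// ops::BatchNormKernel<CPUDeviceContext, T>.
template <typename T>
void BatchNormKernelStub(const std::vector<T>& x) {
  std::cout << "batch_norm<" << typeid(T).name() << "> on " << x.size()
            << " elements\n";
}

int main() {
  // One op name, one instantiation per registered element type --
  // the same shape as listing float and double in REGISTER_OP_CPU_KERNEL.
  std::map<std::string, std::function<void()>> registry;
  registry["batch_norm/float"] = [] { BatchNormKernelStub<float>({1.f, 2.f}); };
  registry["batch_norm/double"] = [] { BatchNormKernelStub<double>({1.0, 2.0}); };
  for (auto& kv : registry) kv.second();
  return 0;
}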