|
|
|
@ -186,9 +186,8 @@ framework::OpKernelType BatchNormOp::GetKernelTypeForVar(
|
|
|
|
|
// Some models may have intentionally set "AnyLayout" for pool
|
|
|
|
|
// op. Treat this as NCHW (default data_format value)
|
|
|
|
|
if (dl != framework::DataLayout::kAnyLayout) {
|
|
|
|
|
return framework::OpKernelType(
|
|
|
|
|
expected_kernel_type.data_type_, tensor.place(),
|
|
|
|
|
framework::StringToDataLayout(data_layout));
|
|
|
|
|
return framework::OpKernelType(expected_kernel_type.data_type_,
|
|
|
|
|
tensor.place(), dl);
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
#endif
|
|
|
|
@ -465,8 +464,11 @@ void BatchNormGradOp::InferShape(framework::InferShapeContext *ctx) const {
|
|
|
|
|
const auto x_dims = ctx->GetInputDim("X");
|
|
|
|
|
const DataLayout data_layout = framework::StringToDataLayout(
|
|
|
|
|
ctx->Attrs().Get<std::string>("data_layout"));
|
|
|
|
|
const int C = (data_layout == DataLayout::kNCHW ? x_dims[1]
|
|
|
|
|
: x_dims[x_dims.size() - 1]);
|
|
|
|
|
|
|
|
|
|
const int C =
|
|
|
|
|
((this->IsMKLDNNType() == true) || (data_layout == DataLayout::kNCHW)
|
|
|
|
|
? x_dims[1]
|
|
|
|
|
: x_dims[x_dims.size() - 1]);
|
|
|
|
|
|
|
|
|
|
ctx->SetOutputDim(framework::GradVarName("X"), x_dims);
|
|
|
|
|
// has_scale_grad == has_bias_grad, judge has_scale_grad is enough
|
|
|
|
@ -499,12 +501,6 @@ framework::OpKernelType BatchNormGradOp::GetExpectedKernelType(
|
|
|
|
|
#ifdef PADDLE_WITH_MKLDNN
|
|
|
|
|
if (library == framework::LibraryType::kPlain &&
|
|
|
|
|
platform::CanMKLDNNBeUsed(ctx)) {
|
|
|
|
|
// TODO(jczaja): Add support for NHWC
|
|
|
|
|
const std::string data_layout = ctx.Attr<std::string>("data_layout");
|
|
|
|
|
PADDLE_ENFORCE_NE(
|
|
|
|
|
data_layout, "NHWC",
|
|
|
|
|
platform::errors::Unimplemented(
|
|
|
|
|
"Batch Norm MKLDNN grad does not support NHWC data format yet"));
|
|
|
|
|
library = framework::LibraryType::kMKLDNN;
|
|
|
|
|
layout = framework::DataLayout::kMKLDNN;
|
|
|
|
|
}
|
|
|
|
@ -515,6 +511,31 @@ framework::OpKernelType BatchNormGradOp::GetExpectedKernelType(
|
|
|
|
|
library);
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
framework::OpKernelType BatchNormGradOp::GetKernelTypeForVar(
|
|
|
|
|
const std::string &var_name, const Tensor &tensor,
|
|
|
|
|
const framework::OpKernelType &expected_kernel_type) const {
|
|
|
|
|
#ifdef PADDLE_WITH_MKLDNN
|
|
|
|
|
// Only input require reshaping, weights and
|
|
|
|
|
// bias are having shape in NCHW order
|
|
|
|
|
if (((var_name == "X") || (var_name == framework::GradVarName("Y"))) &&
|
|
|
|
|
(expected_kernel_type.data_layout_ == framework::DataLayout::kMKLDNN) &&
|
|
|
|
|
(tensor.layout() != framework::DataLayout::kMKLDNN)) {
|
|
|
|
|
auto attrs = Attrs();
|
|
|
|
|
auto ar = paddle::framework::AttrReader(attrs);
|
|
|
|
|
const std::string data_layout = ar.Get<std::string>("data_layout");
|
|
|
|
|
auto dl = framework::StringToDataLayout(data_layout);
|
|
|
|
|
// Some models may have intentionally set "AnyLayout" for pool
|
|
|
|
|
// op. Treat this as NCHW (default data_format value)
|
|
|
|
|
if (dl != framework::DataLayout::kAnyLayout) {
|
|
|
|
|
return framework::OpKernelType(expected_kernel_type.data_type_,
|
|
|
|
|
tensor.place(), dl);
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
#endif
|
|
|
|
|
return framework::OpKernelType(expected_kernel_type.data_type_,
|
|
|
|
|
tensor.place(), tensor.layout());
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
template <typename T>
|
|
|
|
|
class BatchNormGradKernel<platform::CPUDeviceContext, T>
|
|
|
|
|
: public framework::OpKernel<T> {
|
|
|
|
|