|
|
|
@ -119,6 +119,26 @@ struct LRNGradFunctor<platform::CPUDeviceContext, T> {
|
|
|
|
|
template struct LRNGradFunctor<platform::CPUDeviceContext, float>;
|
|
|
|
|
template struct LRNGradFunctor<platform::CPUDeviceContext, double>;
|
|
|
|
|
|
|
|
|
|
namespace {
|
|
|
|
|
framework::OpKernelType GetExpectedLRNKernel(
|
|
|
|
|
const framework::ExecutionContext& ctx) {
|
|
|
|
|
framework::LibraryType library_{framework::LibraryType::kPlain};
|
|
|
|
|
#ifdef PADDLE_WITH_MKLDNN
|
|
|
|
|
if (library_ == framework::LibraryType::kPlain &&
|
|
|
|
|
platform::CanMKLDNNBeUsed(ctx)) {
|
|
|
|
|
library_ = framework::LibraryType::kMKLDNN;
|
|
|
|
|
}
|
|
|
|
|
#endif
|
|
|
|
|
|
|
|
|
|
std::string data_format = ctx.Attr<std::string>("data_format");
|
|
|
|
|
// TODO(pzelazko-intel): enable MKLDNN layout when it's ready
|
|
|
|
|
framework::DataLayout layout_ = framework::StringToDataLayout(data_format);
|
|
|
|
|
return framework::OpKernelType(
|
|
|
|
|
framework::ToDataType(ctx.Input<Tensor>("X")->type()), ctx.GetPlace(),
|
|
|
|
|
layout_, library_);
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
class LRNOp : public framework::OperatorWithKernel {
|
|
|
|
|
public:
|
|
|
|
|
using framework::OperatorWithKernel::OperatorWithKernel;
|
|
|
|
@ -140,21 +160,8 @@ class LRNOp : public framework::OperatorWithKernel {
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
framework::OpKernelType GetExpectedKernelType(
|
|
|
|
|
const framework::ExecutionContext& ctx) const {
|
|
|
|
|
framework::LibraryType library_{framework::LibraryType::kPlain};
|
|
|
|
|
#ifdef PADDLE_WITH_MKLDNN
|
|
|
|
|
if (library_ == framework::LibraryType::kPlain &&
|
|
|
|
|
platform::CanMKLDNNBeUsed(ctx)) {
|
|
|
|
|
library_ = framework::LibraryType::kMKLDNN;
|
|
|
|
|
}
|
|
|
|
|
#endif
|
|
|
|
|
|
|
|
|
|
std::string data_format = ctx.Attr<std::string>("data_format");
|
|
|
|
|
// TODO(pzelazko-intel): enable MKLDNN layout when it's ready
|
|
|
|
|
framework::DataLayout layout_ = framework::StringToDataLayout(data_format);
|
|
|
|
|
return framework::OpKernelType(
|
|
|
|
|
framework::ToDataType(ctx.Input<Tensor>("X")->type()), ctx.GetPlace(),
|
|
|
|
|
layout_, library_);
|
|
|
|
|
const framework::ExecutionContext& ctx) const override {
|
|
|
|
|
return GetExpectedLRNKernel(ctx);
|
|
|
|
|
}
|
|
|
|
|
};
|
|
|
|
|
|
|
|
|
@ -261,21 +268,8 @@ class LRNOpGrad : public framework::OperatorWithKernel {
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
framework::OpKernelType GetExpectedKernelType(
|
|
|
|
|
const framework::ExecutionContext& ctx) const {
|
|
|
|
|
framework::LibraryType library_{framework::LibraryType::kPlain};
|
|
|
|
|
#ifdef PADDLE_WITH_MKLDNN
|
|
|
|
|
if (library_ == framework::LibraryType::kPlain &&
|
|
|
|
|
platform::CanMKLDNNBeUsed(ctx)) {
|
|
|
|
|
library_ = framework::LibraryType::kMKLDNN;
|
|
|
|
|
}
|
|
|
|
|
#endif
|
|
|
|
|
|
|
|
|
|
std::string data_format = ctx.Attr<std::string>("data_format");
|
|
|
|
|
// TODO(pzelazko-intel): enable MKLDNN layout when it's ready
|
|
|
|
|
framework::DataLayout layout_ = framework::StringToDataLayout(data_format);
|
|
|
|
|
return framework::OpKernelType(
|
|
|
|
|
framework::ToDataType(ctx.Input<Tensor>("X")->type()), ctx.GetPlace(),
|
|
|
|
|
layout_, library_);
|
|
|
|
|
const framework::ExecutionContext& ctx) const override {
|
|
|
|
|
return GetExpectedLRNKernel(ctx);
|
|
|
|
|
}
|
|
|
|
|
};
|
|
|
|
|
} // namespace operators
|
|
|
|
|