|
|
|
@ -27,10 +27,12 @@ class LRNMKLDNNOpKernel : public paddle::framework::OpKernel<T> {
|
|
|
|
|
public:
|
|
|
|
|
void Compute(const paddle::framework::ExecutionContext& ctx) const override {
|
|
|
|
|
const bool is_float_type = std::is_same<T, float>::value;
|
|
|
|
|
PADDLE_ENFORCE(is_float_type, "MKLDNN LRN must use float data.");
|
|
|
|
|
PADDLE_ENFORCE(paddle::platform::is_cpu_place(ctx.GetPlace()),
|
|
|
|
|
"MKLDNN LRN must use CPUPlace.");
|
|
|
|
|
|
|
|
|
|
PADDLE_ENFORCE_EQ(
|
|
|
|
|
is_float_type, true,
|
|
|
|
|
platform::errors::PreconditionNotMet("DNNL LRN must use float data."));
|
|
|
|
|
PADDLE_ENFORCE_EQ(platform::is_cpu_place(ctx.GetPlace()), true,
|
|
|
|
|
paddle::platform::errors::PreconditionNotMet(
|
|
|
|
|
"Operator DNNL LRN must use CPUPlace"));
|
|
|
|
|
auto& dev_ctx = ctx.template device_context<MKLDNNDeviceContext>();
|
|
|
|
|
|
|
|
|
|
auto x = ctx.Input<Tensor>("X");
|
|
|
|
@ -93,12 +95,16 @@ class LRNMKLDNNGradOpKernel : public paddle::framework::OpKernel<T> {
|
|
|
|
|
public:
|
|
|
|
|
void Compute(const paddle::framework::ExecutionContext& ctx) const override {
|
|
|
|
|
const bool is_float_type = std::is_same<T, float>::value;
|
|
|
|
|
PADDLE_ENFORCE(is_float_type, "MKLDNN LRN must use float data.");
|
|
|
|
|
PADDLE_ENFORCE(paddle::platform::is_cpu_place(ctx.GetPlace()),
|
|
|
|
|
"MKLDNN LRN must use CPUPlace.");
|
|
|
|
|
PADDLE_ENFORCE(
|
|
|
|
|
!ctx.Attr<bool>("is_test"),
|
|
|
|
|
"is_test attribute should be set to False in training phase.");
|
|
|
|
|
PADDLE_ENFORCE_EQ(is_float_type, true,
|
|
|
|
|
platform::errors::PreconditionNotMet(
|
|
|
|
|
"DNNL LRN GradOpKernl must use float data."));
|
|
|
|
|
PADDLE_ENFORCE_EQ(platform::is_cpu_place(ctx.GetPlace()), true,
|
|
|
|
|
paddle::platform::errors::PreconditionNotMet(
|
|
|
|
|
"Operator DNNL LRNGrad must use CPUPlace"));
|
|
|
|
|
PADDLE_ENFORCE_EQ(
|
|
|
|
|
ctx.Attr<bool>("is_test"), false,
|
|
|
|
|
platform::errors::PreconditionNotMet(
|
|
|
|
|
"is_test attribute should be set to False in training phase."));
|
|
|
|
|
|
|
|
|
|
auto x = ctx.Input<Tensor>("X");
|
|
|
|
|
auto mid = ctx.Input<Tensor>("MidOut");
|
|
|
|
|