|
|
|
@ -59,15 +59,6 @@ class LRNMKLDNNOpKernel : public paddle::framework::OpKernel<T> {
|
|
|
|
|
{MKLDNN_ARG_DST, *dst_memory},
|
|
|
|
|
{MKLDNN_ARG_WORKSPACE, *workspace_memory}});
|
|
|
|
|
} else {
|
|
|
|
|
// mid has to be allocated and filled
|
|
|
|
|
// k to pass LRN unit tests
|
|
|
|
|
// TODO(jczaja): Disable checking mid in unit tests (Require API change)
|
|
|
|
|
mid->mutable_data<T>(ctx.GetPlace());
|
|
|
|
|
auto e_mid = framework::EigenTensor<T, 4>::From(*mid);
|
|
|
|
|
const float k = ctx.Attr<float>("k");
|
|
|
|
|
e_mid = e_mid.constant(k);
|
|
|
|
|
mid->set_format(platform::GetMKLDNNFormat(*dst_memory));
|
|
|
|
|
|
|
|
|
|
lrn_p->execute(astream, {{MKLDNN_ARG_SRC, *src_memory},
|
|
|
|
|
{MKLDNN_ARG_DST, *dst_memory}});
|
|
|
|
|
}
|
|
|
|
@ -85,7 +76,7 @@ class LRNMKLDNNGradOpKernel : public paddle::framework::OpKernel<T> {
|
|
|
|
|
const bool is_float_type = std::is_same<T, float>::value;
|
|
|
|
|
PADDLE_ENFORCE_EQ(is_float_type, true,
|
|
|
|
|
platform::errors::PreconditionNotMet(
|
|
|
|
|
"DNNL LRN GradOpKernl must use float data."));
|
|
|
|
|
"DNNL LRN GradOpKernel must use float data."));
|
|
|
|
|
PADDLE_ENFORCE_EQ(platform::is_cpu_place(ctx.GetPlace()), true,
|
|
|
|
|
paddle::platform::errors::PreconditionNotMet(
|
|
|
|
|
"Operator DNNL LRNGrad must use CPUPlace"));
|
|
|
|
|