|
|
|
@ -67,7 +67,7 @@ class LRNMKLDNNOpKernel : public paddle::framework::OpKernel<T> {
|
|
|
|
|
mid->mutable_data<T>(ctx.GetPlace());
|
|
|
|
|
|
|
|
|
|
const int n = ctx.Attr<int>("n");
|
|
|
|
|
const float alpha = ctx.Attr<float>("alpha");
|
|
|
|
|
const float alpha = ctx.Attr<float>("alpha") * static_cast<float>(n);
|
|
|
|
|
const float beta = ctx.Attr<float>("beta");
|
|
|
|
|
const float k = ctx.Attr<float>("k");
|
|
|
|
|
const bool is_test = ctx.Attr<bool>("is_test");
|
|
|
|
@ -156,7 +156,7 @@ class LRNMKLDNNGradOpKernel : public paddle::framework::OpKernel<T> {
|
|
|
|
|
const std::string key_workspace_memory = key + "@lrn_workspace_memory";
|
|
|
|
|
|
|
|
|
|
const int n = ctx.Attr<int>("n");
|
|
|
|
|
const float alpha = ctx.Attr<float>("alpha");
|
|
|
|
|
const float alpha = ctx.Attr<float>("alpha") * static_cast<float>(n);
|
|
|
|
|
const float beta = ctx.Attr<float>("beta");
|
|
|
|
|
const float k = ctx.Attr<float>("k");
|
|
|
|
|
|
|
|
|
|