|
|
|
@ -15,6 +15,9 @@ limitations under the License. */
|
|
|
|
|
#include "paddle/fluid/operators/data_norm_op.h"
|
|
|
|
|
#include <string>
|
|
|
|
|
#include "paddle/fluid/framework/data_layout.h"
|
|
|
|
|
#ifdef PADDLE_WITH_MKLDNN
|
|
|
|
|
#include "paddle/fluid/platform/mkldnn_helper.h"
|
|
|
|
|
#endif
|
|
|
|
|
|
|
|
|
|
namespace paddle {
|
|
|
|
|
namespace operators {
|
|
|
|
@ -94,6 +97,13 @@ class DataNormOp : public framework::OperatorWithKernel {
|
|
|
|
|
// TODO(pzelazko-intel): enable MKLDNN layout when it's ready
|
|
|
|
|
framework::LibraryType library = framework::LibraryType::kPlain;
|
|
|
|
|
framework::DataLayout layout = framework::DataLayout::kAnyLayout;
|
|
|
|
|
#ifdef PADDLE_WITH_MKLDNN
|
|
|
|
|
if (library == framework::LibraryType::kPlain &&
|
|
|
|
|
platform::CanMKLDNNBeUsed(ctx)) {
|
|
|
|
|
library = framework::LibraryType::kMKLDNN;
|
|
|
|
|
layout = framework::DataLayout::kMKLDNN;
|
|
|
|
|
}
|
|
|
|
|
#endif
|
|
|
|
|
|
|
|
|
|
return framework::OpKernelType(input_data_type, ctx.GetPlace(), layout,
|
|
|
|
|
library);
|
|
|
|
@ -251,6 +261,14 @@ class DataNormGradOp : public framework::OperatorWithKernel {
|
|
|
|
|
framework::LibraryType library = framework::LibraryType::kPlain;
|
|
|
|
|
framework::DataLayout layout = framework::DataLayout::kAnyLayout;
|
|
|
|
|
|
|
|
|
|
#ifdef PADDLE_WITH_MKLDNN
|
|
|
|
|
if (library == framework::LibraryType::kPlain &&
|
|
|
|
|
platform::CanMKLDNNBeUsed(ctx)) {
|
|
|
|
|
library = framework::LibraryType::kMKLDNN;
|
|
|
|
|
layout = framework::DataLayout::kMKLDNN;
|
|
|
|
|
}
|
|
|
|
|
#endif
|
|
|
|
|
|
|
|
|
|
return framework::OpKernelType(ctx.Input<Tensor>("X")->type(),
|
|
|
|
|
ctx.GetPlace(), layout, library);
|
|
|
|
|
}
|
|
|
|
|