@@ -33,10 +33,12 @@ class MKLDNNActivationKernel
  public:
   void Compute(const framework::ExecutionContext &ctx) const override {
     const auto *x = ctx.Input<Tensor>("X");
-    PADDLE_ENFORCE_EQ(x->layout(), DataLayout::kMKLDNN,
-                      "Wrong layout set for X tensor");
-    PADDLE_ENFORCE_NE(x->format(), MKLDNNMemoryFormat::undef,
-                      "Wrong format set for X tensor");
+    PADDLE_ENFORCE_EQ(
+        x->layout(), DataLayout::kMKLDNN,
+        platform::errors::InvalidArgument("Wrong layout set for X tensor"));
+    PADDLE_ENFORCE_NE(
+        x->format(), MKLDNNMemoryFormat::undef,
+        platform::errors::InvalidArgument("Wrong format set for X tensor"));
 
     Functor functor;
     functor(ctx);
@@ -50,9 +52,11 @@ class MKLDNNActivationGradKernel
   void Compute(const framework::ExecutionContext &ctx) const override {
     const auto *diff_y = ctx.Input<Tensor>(framework::GradVarName("Out"));
     PADDLE_ENFORCE_EQ(diff_y->layout(), DataLayout::kMKLDNN,
-                      "Wrong layout set for Input OutGrad tensor");
+                      platform::errors::InvalidArgument(
+                          "Wrong layout set for Input OutGrad tensor"));
     PADDLE_ENFORCE_NE(diff_y->format(), MKLDNNMemoryFormat::undef,
-                      "Wrong format set for Input OutGrad tensor");
+                      platform::errors::InvalidArgument(
+                          "Wrong format set for Input OutGrad tensor"));
 
     Functor functor;
     functor(ctx);
@@ -82,7 +86,7 @@ void eltwise_forward(const framework::ExecutionContext &ctx,
 
   PADDLE_ENFORCE(
       x->dims().size() == 2 || x->dims().size() == 3 || x->dims().size() == 4,
-      "Input dim must be with 2, 3 or 4");
+      platform::errors::Unimplemented("Input dim must be with 2, 3 or 4"));
 
   auto src_tz = framework::vectorize<int64_t>(x->dims());
 
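Across all three hunks the pattern is the same: the bare message string passed to PADDLE_ENFORCE* is wrapped in a typed error builder from platform::errors (InvalidArgument for a bad tensor layout or format, Unimplemented for unsupported input ranks), so a failed check reports an error class as well as the text. As a rough, self-contained sketch of that idea only, and not Paddle's real macros, ENFORCE_EQ and InvalidArgument below are hypothetical stand-ins:

#include <iostream>
#include <sstream>
#include <stdexcept>
#include <string>

// Hypothetical stand-in for platform::errors::InvalidArgument:
// a typed exception that still carries a readable message.
struct InvalidArgument : std::invalid_argument {
  explicit InvalidArgument(const std::string &msg)
      : std::invalid_argument(msg) {}
};

// Hypothetical stand-in for PADDLE_ENFORCE_EQ: compare two values and,
// on mismatch, throw the typed error given as the third argument,
// augmented with which comparison failed.
#define ENFORCE_EQ(a, b, err)                                    \
  do {                                                           \
    if (!((a) == (b))) {                                         \
      std::ostringstream oss;                                    \
      oss << (err).what() << " [" << #a << " != " << #b << "]";  \
      throw InvalidArgument(oss.str());                          \
    }                                                            \
  } while (0)

int main() {
  int layout = 1;
  const int kMKLDNN = 2;
  try {
    // Same shape as the updated checks in the diff: the message is wrapped
    // in a typed error object instead of being passed as a raw string.
    ENFORCE_EQ(layout, kMKLDNN,
               InvalidArgument("Wrong layout set for X tensor"));
  } catch (const InvalidArgument &e) {
    std::cerr << e.what() << std::endl;
  }
  return 0;
}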