MKLDNN layout: Support for activation operator

revert-11610-move_hooks
mozga-intel 7 years ago
parent d734595978
commit 792d3b2406

File diff suppressed because it is too large.

@@ -19,18 +19,20 @@ limitations under the License. */
 namespace paddle {
 namespace operators {
 
-#define REGISTER_ACTIVATION_OP_MAKER(OP_NAME, OP_COMMENT)                  \
-  class OP_NAME##OpMaker                                                   \
-      : public ::paddle::framework::OpProtoAndCheckerMaker {               \
-   public:                                                                 \
-    void Make() override {                                                 \
-      AddInput("X", "Input of " #OP_NAME " operator");                     \
-      AddOutput("Out", "Output of " #OP_NAME " operator").Reuse("X");      \
-      AddAttr<bool>("use_mkldnn",                                          \
-                    "(default false) Only used in mkldnn kernel")          \
-          .SetDefault(false);                                              \
-      AddComment(OP_COMMENT);                                              \
-    }                                                                      \
+using paddle::framework::Tensor;
+
+#define REGISTER_ACTIVATION_OP_MAKER(OP_NAME, OP_COMMENT)                  \
+  class OP_NAME##OpMaker                                                   \
+      : public ::paddle::framework::OpProtoAndCheckerMaker {               \
+   public:                                                                 \
+    void Make() override {                                                 \
+      AddInput("X", "Input of " #OP_NAME " operator");                     \
+      AddOutput("Out", "Output of " #OP_NAME " operator").Reuse("X");      \
+      AddAttr<bool>("use_mkldnn",                                          \
+                    "(bool, default false) Only used in mkldnn kernel")    \
+          .SetDefault(false);                                              \
+      AddComment(#OP_COMMENT);                                             \
+    }                                                                      \
   }
 
 #define REGISTER_ACTIVATION_OP_GRAD_MAKER(OP_NAME, KERNEL_TYPE)            \
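Note on the hunk above: the docstring is now registered as AddComment(#OP_COMMENT), so the OP_COMMENT argument is stringized by the preprocessor instead of being passed through as-is, and the use_mkldnn attribute description gains a "(bool, ...)" type hint. The snippet below is a minimal standalone sketch of what the '#' stringizing operator does; it is not Paddle code, and MAKE_COMMENT / Relu_comment are made-up names used only for illustration.

#include <iostream>
#include <string>

// '#ARG' turns the macro argument into a string literal at preprocessing time.
#define MAKE_COMMENT(OP_NAME, OP_COMMENT) \
  const std::string OP_NAME##_comment = #OP_COMMENT;

MAKE_COMMENT(Relu, Relu activation operator)  // defines Relu_comment = "Relu activation operator"

int main() {
  std::cout << Relu_comment << std::endl;  // prints: Relu activation operator
  return 0;
}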
@@ -58,7 +60,6 @@ framework::OpKernelType GetKernelType(const framework::ExecutionContext& ctx,
                                       const framework::OperatorWithKernel& oper,
                                       const std::string& name) {
   framework::LibraryType library{framework::LibraryType::kPlain};
-
   framework::DataLayout layout = framework::DataLayout::kAnyLayout;
 #ifdef PADDLE_WITH_MKLDNN
   auto it = oper.Attrs().find("use_mkldnn");
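The hunk above shows only the beginning of GetKernelType, up to the lookup of the use_mkldnn attribute; the rest of the function lies outside the hunk. Below is a standalone sketch of the selection pattern this helper is responsible for, in line with the commit title (dispatch to an MKLDNN kernel and layout when the op was created with use_mkldnn=true and the build defines PADDLE_WITH_MKLDNN). It is not Paddle code: the enums, SelectKernel, and the map-based attribute lookup are stand-ins for the framework types.

#include <iostream>
#include <map>
#include <string>

// Stand-ins for paddle::framework::LibraryType / DataLayout.
enum class LibraryType { kPlain, kMKLDNN };
enum class DataLayout { kAnyLayout, kMKLDNN };

struct KernelType {
  LibraryType library;
  DataLayout layout;
};

// Approximates the decision sketched in the hunk: default to the plain
// library with an unspecified layout, and switch to MKLDNN when the op
// carries use_mkldnn=true and the build has MKLDNN support enabled.
KernelType SelectKernel(const std::map<std::string, bool>& attrs) {
  LibraryType library = LibraryType::kPlain;
  DataLayout layout = DataLayout::kAnyLayout;
#ifdef PADDLE_WITH_MKLDNN
  auto it = attrs.find("use_mkldnn");
  if (it != attrs.end() && it->second) {
    library = LibraryType::kMKLDNN;
    layout = DataLayout::kMKLDNN;
  }
#endif
  return {library, layout};
}

int main() {
  KernelType kt = SelectKernel({{"use_mkldnn", true}});
  std::cout << (kt.library == LibraryType::kMKLDNN ? "mkldnn kernel" : "plain kernel")
            << std::endl;
  return 0;
}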
@@ -82,6 +83,7 @@ class ActivationOp : public framework::OperatorWithKernel {
     ctx->ShareLoD("X", /*->*/ "Out");
   }
 
+ protected:
   framework::OpKernelType GetExpectedKernelType(
       const framework::ExecutionContext& ctx) const override {
     return GetKernelType(ctx, *this, "X");
@@ -96,6 +98,7 @@ class ActivationOpGrad : public framework::OperatorWithKernel {
     ctx->SetOutputDim(framework::GradVarName("X"), ctx->GetInputDim("Out"));
   }
 
+ protected:
   framework::OpKernelType GetExpectedKernelType(
       const framework::ExecutionContext& ctx) const override {
     return GetKernelType(ctx, *this, "Out");
