|
|
@ -80,7 +80,7 @@ class ActivationOp : public framework::OperatorWithKernel {
|
|
|
|
using framework::OperatorWithKernel::OperatorWithKernel;
|
|
|
|
using framework::OperatorWithKernel::OperatorWithKernel;
|
|
|
|
|
|
|
|
|
|
|
|
void InferShape(framework::InferShapeContext* ctx) const override {
|
|
|
|
void InferShape(framework::InferShapeContext* ctx) const override {
|
|
|
|
ctx->SetOutputDim("Out", ctx->GetInputDim("X"));
|
|
|
|
ctx->ShareDim("X", /*->*/ "Out");
|
|
|
|
ctx->ShareLoD("X", /*->*/ "Out");
|
|
|
|
ctx->ShareLoD("X", /*->*/ "Out");
|
|
|
|
}
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
|
@ -91,12 +91,26 @@ class ActivationOp : public framework::OperatorWithKernel {
|
|
|
|
}
|
|
|
|
}
|
|
|
|
};
|
|
|
|
};
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
class ActivationOpInferVarType : public framework::VarTypeInference {
|
|
|
|
|
|
|
|
public:
|
|
|
|
|
|
|
|
void operator()(const framework::OpDesc& op_desc,
|
|
|
|
|
|
|
|
framework::BlockDesc* block) const override {
|
|
|
|
|
|
|
|
auto x_name = op_desc.Input("X")[0];
|
|
|
|
|
|
|
|
auto out_name = op_desc.Output("Out")[0];
|
|
|
|
|
|
|
|
auto& x = block->FindRecursiveOrCreateVar(x_name);
|
|
|
|
|
|
|
|
auto& out = block->FindRecursiveOrCreateVar(out_name);
|
|
|
|
|
|
|
|
out.SetType(x.GetType());
|
|
|
|
|
|
|
|
out.SetDataType(x.GetDataType());
|
|
|
|
|
|
|
|
}
|
|
|
|
|
|
|
|
};
|
|
|
|
|
|
|
|
|
|
|
|
class ActivationOpGrad : public framework::OperatorWithKernel {
|
|
|
|
class ActivationOpGrad : public framework::OperatorWithKernel {
|
|
|
|
public:
|
|
|
|
public:
|
|
|
|
using framework::OperatorWithKernel::OperatorWithKernel;
|
|
|
|
using framework::OperatorWithKernel::OperatorWithKernel;
|
|
|
|
|
|
|
|
|
|
|
|
void InferShape(framework::InferShapeContext* ctx) const override {
|
|
|
|
void InferShape(framework::InferShapeContext* ctx) const override {
|
|
|
|
ctx->SetOutputDim(framework::GradVarName("X"), ctx->GetInputDim("Out"));
|
|
|
|
ctx->ShareDim("Out", framework::GradVarName("X"));
|
|
|
|
|
|
|
|
ctx->ShareLoD("Out", framework::GradVarName("X"));
|
|
|
|
}
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
|
|
protected:
|
|
|
|
protected:
|
|
|
@ -525,12 +539,14 @@ namespace ops = paddle::operators;
|
|
|
|
#define REGISTER_INPLACE_ACTIVATION_OP(OP_NAME, KERNEL_TYPE) \
|
|
|
|
#define REGISTER_INPLACE_ACTIVATION_OP(OP_NAME, KERNEL_TYPE) \
|
|
|
|
REGISTER_OPERATOR(KERNEL_TYPE, ::paddle::operators::ActivationOp, \
|
|
|
|
REGISTER_OPERATOR(KERNEL_TYPE, ::paddle::operators::ActivationOp, \
|
|
|
|
::paddle::operators::OP_NAME##OpMaker, \
|
|
|
|
::paddle::operators::OP_NAME##OpMaker, \
|
|
|
|
|
|
|
|
::paddle::operators::ActivationOpInferVarType, \
|
|
|
|
::paddle::operators::OP_NAME##GradMaker); \
|
|
|
|
::paddle::operators::OP_NAME##GradMaker); \
|
|
|
|
REGISTER_OPERATOR(KERNEL_TYPE##_grad, ::paddle::operators::ActivationOpGrad)
|
|
|
|
REGISTER_OPERATOR(KERNEL_TYPE##_grad, ::paddle::operators::ActivationOpGrad)
|
|
|
|
|
|
|
|
|
|
|
|
#define REGISTER_ACTIVATION_OP(OP_NAME, KERNEL_TYPE) \
|
|
|
|
#define REGISTER_ACTIVATION_OP(OP_NAME, KERNEL_TYPE) \
|
|
|
|
REGISTER_OPERATOR(KERNEL_TYPE, ::paddle::operators::ActivationOp, \
|
|
|
|
REGISTER_OPERATOR(KERNEL_TYPE, ::paddle::operators::ActivationOp, \
|
|
|
|
::paddle::operators::OP_NAME##OpMaker, \
|
|
|
|
::paddle::operators::OP_NAME##OpMaker, \
|
|
|
|
|
|
|
|
::paddle::operators::ActivationOpInferVarType, \
|
|
|
|
::paddle::framework::DefaultGradOpDescMaker<true>); \
|
|
|
|
::paddle::framework::DefaultGradOpDescMaker<true>); \
|
|
|
|
REGISTER_OPERATOR(KERNEL_TYPE##_grad, ::paddle::operators::ActivationOpGrad)
|
|
|
|
REGISTER_OPERATOR(KERNEL_TYPE##_grad, ::paddle::operators::ActivationOpGrad)
|
|
|
|
|
|
|
|
|
|
|
|