|
|
|
@ -287,6 +287,15 @@ class FusedElemwiseActivationGradMaker
|
|
|
|
|
}
|
|
|
|
|
};
|
|
|
|
|
|
|
|
|
|
class FusedElemwiseAddActivationMaker : public FusedElemwiseActivationMaker {};
|
|
|
|
|
|
|
|
|
|
template <typename T>
|
|
|
|
|
class FusedElemwiseAddActivationGradMaker
|
|
|
|
|
: public FusedElemwiseActivationGradMaker<T> {
|
|
|
|
|
public:
|
|
|
|
|
using FusedElemwiseActivationGradMaker<T>::FusedElemwiseActivationGradMaker;
|
|
|
|
|
};
|
|
|
|
|
|
|
|
|
|
class FusedElemwiseActivationOpGrad : public framework::OperatorWithKernel {
|
|
|
|
|
public:
|
|
|
|
|
using framework::OperatorWithKernel::OperatorWithKernel;
|
|
|
|
@ -367,6 +376,53 @@ class FusedElemwiseActivationOpGrad : public framework::OperatorWithKernel {
|
|
|
|
|
}
|
|
|
|
|
};
|
|
|
|
|
|
|
|
|
|
class FusedElemwiseAddActivationOp : public FusedElemwiseActivationOp {
|
|
|
|
|
public:
|
|
|
|
|
using FusedElemwiseActivationOp::FusedElemwiseActivationOp;
|
|
|
|
|
void InferShape(framework::InferShapeContext *ctx) const override {
|
|
|
|
|
FusedElemwiseActivationOp::InferShape(ctx);
|
|
|
|
|
std::vector<std::string> functor_names =
|
|
|
|
|
ctx->Attrs().Get<std::vector<std::string>>("functor_list");
|
|
|
|
|
bool elemntwise_add_detected = false;
|
|
|
|
|
for (auto names : functor_names) {
|
|
|
|
|
if (names == "elementwise_add") {
|
|
|
|
|
elemntwise_add_detected = true;
|
|
|
|
|
break;
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
PADDLE_ENFORCE_EQ(
|
|
|
|
|
elemntwise_add_detected, true,
|
|
|
|
|
platform::errors::InvalidArgument(
|
|
|
|
|
"When the FusedElemwiseAddActivationOp Is used in fused pass, the "
|
|
|
|
|
"elementwise_add Op must be"
|
|
|
|
|
"detected and used, Please check the fuse pass pattern"));
|
|
|
|
|
}
|
|
|
|
|
};
|
|
|
|
|
|
|
|
|
|
class FusedElemwiseAddActivationOpGrad : public FusedElemwiseActivationOpGrad {
|
|
|
|
|
public:
|
|
|
|
|
using FusedElemwiseActivationOpGrad::FusedElemwiseActivationOpGrad;
|
|
|
|
|
|
|
|
|
|
void InferShape(framework::InferShapeContext *ctx) const override {
|
|
|
|
|
FusedElemwiseActivationOpGrad::InferShape(ctx);
|
|
|
|
|
std::vector<std::string> functor_names =
|
|
|
|
|
ctx->Attrs().Get<std::vector<std::string>>("functor_list");
|
|
|
|
|
bool elemntwise_add_grad_detected = false;
|
|
|
|
|
for (auto names : functor_names) {
|
|
|
|
|
if (names == "elementwise_add_grad") {
|
|
|
|
|
elemntwise_add_grad_detected = true;
|
|
|
|
|
break;
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
PADDLE_ENFORCE_EQ(
|
|
|
|
|
elemntwise_add_grad_detected, true,
|
|
|
|
|
platform::errors::InvalidArgument(
|
|
|
|
|
"When the FusedElemwiseAddActivationOpGrad Is used in fused pass, "
|
|
|
|
|
"the elementwise_add_grad Op must be"
|
|
|
|
|
"detected and used, Please check the fuse pass pattern"));
|
|
|
|
|
}
|
|
|
|
|
};
|
|
|
|
|
|
|
|
|
|
DECLARE_NO_NEED_BUFFER_VARS_INFERER(
|
|
|
|
|
FusedElemwiseAddActivationNoNeddBufVarInferer, "X", "Y");
|
|
|
|
|
} // namespace operators
|
|
|
|
@ -397,13 +453,13 @@ REGISTER_OP_CPU_KERNEL(
|
|
|
|
|
|
|
|
|
|
// for memory optimization, we register the fused_elemwise_add_activation OP
|
|
|
|
|
REGISTER_OPERATOR(
|
|
|
|
|
fused_elemwise_add_activation, ops::FusedElemwiseActivationOp,
|
|
|
|
|
ops::FusedElemwiseActivationMaker,
|
|
|
|
|
ops::FusedElemwiseActivationGradMaker<paddle::framework::OpDesc>,
|
|
|
|
|
ops::FusedElemwiseActivationGradMaker<paddle::imperative::OpBase>);
|
|
|
|
|
fused_elemwise_add_activation, ops::FusedElemwiseAddActivationOp,
|
|
|
|
|
ops::FusedElemwiseAddActivationMaker,
|
|
|
|
|
ops::FusedElemwiseAddActivationGradMaker<paddle::framework::OpDesc>,
|
|
|
|
|
ops::FusedElemwiseAddActivationGradMaker<paddle::imperative::OpBase>);
|
|
|
|
|
REGISTER_OPERATOR(fused_elemwise_add_activation_grad,
|
|
|
|
|
ops::FusedElemwiseAddActivationNoNeddBufVarInferer,
|
|
|
|
|
ops::FusedElemwiseActivationOpGrad);
|
|
|
|
|
ops::FusedElemwiseAddActivationOpGrad);
|
|
|
|
|
|
|
|
|
|
REGISTER_OP_CPU_KERNEL(
|
|
|
|
|
fused_elemwise_add_activation,
|
|
|
|
|