@@ -604,21 +604,48 @@ class ActivationOpDoubleGrad : public framework::OperatorWithKernel {
   void InferShape(framework::InferShapeContext* ctx) const override {
     if (static_cast<int>(kDepValue) & static_cast<int>(kDepX)) {
-      if (HasOutputs("DX") && ctx->HasOutput("DX")) {
+      if (ctx->HasOutput("DX")) {
         ctx->ShareDim("X", "DX");
         ctx->ShareLoD("X", "DX");
       }
-      if (HasOutputs("DDOut") && ctx->HasOutput("DDOut")) {
+      if (ctx->HasOutput("DDOut")) {
         ctx->ShareDim("X", "DDOut");
         ctx->ShareLoD("X", "DDOut");
       }
     }
     if (static_cast<int>(kDepValue) & static_cast<int>(kDepOut)) {
-      if (HasOutputs("DOut") && ctx->HasOutput("DOut")) {
+      if (ctx->HasOutput("DOut")) {
         ctx->ShareDim("Out", "DOut");
         ctx->ShareLoD("Out", "DOut");
       }
-      if (HasOutputs("DDOut") && ctx->HasOutput("DDOut")) {
+      if (ctx->HasOutput("DDOut")) {
         ctx->ShareDim("Out", "DDOut");
         ctx->ShareLoD("Out", "DDOut");
       }
     }
   }

  protected:
   framework::OpKernelType GetExpectedKernelType(
       const framework::ExecutionContext& ctx) const override {
     return GetKernelType(ctx, *this, "DDX");
   }
 };

+template <ActBwdOpFwdDeps kDepValue>
+class ActivationOpDoubleGrad2 : public framework::OperatorWithKernel {
+ public:
+  using framework::OperatorWithKernel::OperatorWithKernel;
+
+  void InferShape(framework::InferShapeContext* ctx) const override {
+    if (static_cast<int>(kDepValue) & static_cast<int>(kDepX)) {
+      if (ctx->HasOutput("DDOut")) {
+        ctx->ShareDim("X", "DDOut");
+        ctx->ShareLoD("X", "DDOut");
+      }
+    }
+    if (static_cast<int>(kDepValue) & static_cast<int>(kDepOut)) {
+      if (ctx->HasOutput("DDOut")) {
+        ctx->ShareDim("Out", "DDOut");
+        ctx->ShareLoD("Out", "DDOut");
+      }
+    }
+  }
+
+ protected:
+  framework::OpKernelType GetExpectedKernelType(
+      const framework::ExecutionContext& ctx) const override {
+    return GetKernelType(ctx, *this, "DDX");
+  }
+};
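
(Note, for readers outside this file: kDepValue is a compile-time bitmask from activation_op.h recording whether an activation's backward pass reads the forward input X, the forward output Out, or both; the InferShape bodies above test it to decide which forward variable DDOut copies its dims and LoD from. Below is a minimal standalone sketch of that dispatch; only the enum mirrors activation_op.h, while InferDoubleGradShape and main are hypothetical illustration, not Paddle code.)

// Standalone sketch of the dependency-mask dispatch used above.
#include <cstdio>

enum ActBwdOpFwdDeps {
  kNoDeps = 0x00,  // backward needs neither forward input nor output
  kDepX = 0x01,    // backward needs the forward input X
  kDepOut = 0x02,  // backward needs the forward output Out
};

template <ActBwdOpFwdDeps kDepValue>
void InferDoubleGradShape() {
  // Mirrors ActivationOpDoubleGrad2::InferShape: DDOut takes its shape
  // from whichever forward variable the grad functor declares it needs.
  if (static_cast<int>(kDepValue) & static_cast<int>(kDepX)) {
    std::printf("share dim/LoD: X -> DDOut\n");
  }
  if (static_cast<int>(kDepValue) & static_cast<int>(kDepOut)) {
    std::printf("share dim/LoD: Out -> DDOut\n");
  }
}

int main() {
  InferDoubleGradShape<kDepOut>();  // relu's grad functor depends on Out
  InferDoubleGradShape<kDepX>();    // leaky_relu's, at this revision, on X
  return 0;
}

Keeping the dependency as a template parameter lets one operator class serve every activation with no runtime flag or branching on op type.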
@@ -771,7 +798,7 @@ REGISTER_OPERATOR(relu_grad, ops::ActivationOpGrad,
                    ops::ReluDoubleGradMaker);
 REGISTER_OPERATOR(
     relu_grad_grad,
-    ops::ActivationOpDoubleGrad<ops::ReluGradFunctor<float>::FwdDeps()>);
+    ops::ActivationOpDoubleGrad2<ops::ReluGradFunctor<float>::FwdDeps()>);

 REGISTER_ACTIVATION_CPU_KERNEL(relu, Relu, ReluFunctor, ReluGradFunctor);
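
(Why relu_grad_grad can move to ActivationOpDoubleGrad2, which wires only DDOut: relu's first derivative is the mask 1[x > 0], which is piecewise constant, so the second-order pass contributes no DX term and reduces to masking DDX by the sign of Out. A minimal numeric sketch of that computation follows; it is not Paddle's actual kernel, and ReluDoubleGrad is a hypothetical name.)

// Numeric sketch: DDOut = DDX * 1[Out > 0]. The mask is piecewise
// constant, so there is no DX output -- which is exactly why
// ActivationOpDoubleGrad2 drops it.
#include <cstdio>
#include <vector>

std::vector<float> ReluDoubleGrad(const std::vector<float>& out,
                                  const std::vector<float>& ddx) {
  std::vector<float> ddout(out.size());
  for (size_t i = 0; i < out.size(); ++i) {
    ddout[i] = out[i] > 0.0f ? ddx[i] : 0.0f;
  }
  return ddout;
}

int main() {
  std::vector<float> out = {0.0f, 2.0f, 0.0f, 1.5f};  // relu outputs
  std::vector<float> ddx = {1.0f, 1.0f, 1.0f, 1.0f};  // incoming DDX
  for (float v : ReluDoubleGrad(out, ddx)) std::printf("%g ", v);
  std::printf("\n");  // prints: 0 1 0 1
  return 0;
}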
@@ -796,7 +823,7 @@ REGISTER_OPERATOR(leaky_relu_grad, ops::ActivationOpGrad,
                    ops::LeakyReluDoubleGradMaker);
 REGISTER_OPERATOR(
     leaky_relu_grad_grad,
-    ops::ActivationOpDoubleGrad<ops::LeakyReluGradFunctor<float>::FwdDeps()>);
+    ops::ActivationOpDoubleGrad2<ops::LeakyReluGradFunctor<float>::FwdDeps()>);

 REGISTER_ACTIVATION_CPU_KERNEL(leaky_relu, LeakyRelu, LeakyReluFunctor,
                                 LeakyReluGradFunctor);
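
(The same reasoning covers leaky_relu_grad_grad: the first derivative is 1 for positive inputs and alpha otherwise, again piecewise constant, so only DDOut is needed. A hypothetical sketch, assuming the grad functor reads the forward input X, as its FwdDeps() suggests at this revision; LeakyReluDoubleGrad is an illustrative name, not Paddle's kernel.)

// Sketch: DDOut = DDX * (X > 0 ? 1 : alpha); no DX term.
#include <cstdio>
#include <vector>

std::vector<float> LeakyReluDoubleGrad(const std::vector<float>& x,
                                       const std::vector<float>& ddx,
                                       float alpha) {
  std::vector<float> ddout(x.size());
  for (size_t i = 0; i < x.size(); ++i) {
    ddout[i] = ddx[i] * (x[i] > 0.0f ? 1.0f : alpha);
  }
  return ddout;
}

int main() {
  std::vector<float> x = {-2.0f, 3.0f};
  std::vector<float> ddx = {1.0f, 1.0f};
  for (float v : LeakyReluDoubleGrad(x, ddx, 0.02f)) std::printf("%g ", v);
  std::printf("\n");  // prints: 0.02 1
  return 0;
}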