Refine Infershape in activation_op for double_grad. (#18485)

* Refine Infershape in activation_op for double_grad.
sum_op
qingqing01 6 years ago committed by GitHub
parent 602cb6a5b4
commit 7ac4818a98
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

@ -604,21 +604,48 @@ class ActivationOpDoubleGrad : public framework::OperatorWithKernel {
void InferShape(framework::InferShapeContext* ctx) const override {
if (static_cast<int>(kDepValue) & static_cast<int>(kDepX)) {
if (HasOutputs("DX") && ctx->HasOutput("DX")) {
if (ctx->HasOutput("DX")) {
ctx->ShareDim("X", "DX");
ctx->ShareLoD("X", "DX");
}
if (HasOutputs("DDOut") && ctx->HasOutput("DDOut")) {
if (ctx->HasOutput("DDOut")) {
ctx->ShareDim("X", "DDOut");
ctx->ShareLoD("X", "DDOut");
}
}
if (static_cast<int>(kDepValue) & static_cast<int>(kDepOut)) {
if (HasOutputs("DOut") && ctx->HasOutput("DOut")) {
if (ctx->HasOutput("DOut")) {
ctx->ShareDim("Out", "DOut");
ctx->ShareLoD("Out", "DOut");
}
if (HasOutputs("DDOut") && ctx->HasOutput("DDOut")) {
if (ctx->HasOutput("DDOut")) {
ctx->ShareDim("Out", "DDOut");
ctx->ShareLoD("Out", "DDOut");
}
}
}
protected:
framework::OpKernelType GetExpectedKernelType(
const framework::ExecutionContext& ctx) const override {
return GetKernelType(ctx, *this, "DDX");
}
};
template <ActBwdOpFwdDeps kDepValue>
class ActivationOpDoubleGrad2 : public framework::OperatorWithKernel {
public:
using framework::OperatorWithKernel::OperatorWithKernel;
void InferShape(framework::InferShapeContext* ctx) const override {
if (static_cast<int>(kDepValue) & static_cast<int>(kDepX)) {
if (ctx->HasOutput("DDOut")) {
ctx->ShareDim("X", "DDOut");
ctx->ShareLoD("X", "DDOut");
}
}
if (static_cast<int>(kDepValue) & static_cast<int>(kDepOut)) {
if (ctx->HasOutput("DDOut")) {
ctx->ShareDim("Out", "DDOut");
ctx->ShareLoD("Out", "DDOut");
}
@ -771,7 +798,7 @@ REGISTER_OPERATOR(relu_grad, ops::ActivationOpGrad,
ops::ReluDoubleGradMaker);
REGISTER_OPERATOR(
relu_grad_grad,
ops::ActivationOpDoubleGrad<ops::ReluGradFunctor<float>::FwdDeps()>);
ops::ActivationOpDoubleGrad2<ops::ReluGradFunctor<float>::FwdDeps()>);
REGISTER_ACTIVATION_CPU_KERNEL(relu, Relu, ReluFunctor, ReluGradFunctor);
@ -796,7 +823,7 @@ REGISTER_OPERATOR(leaky_relu_grad, ops::ActivationOpGrad,
ops::LeakyReluDoubleGradMaker);
REGISTER_OPERATOR(
leaky_relu_grad_grad,
ops::ActivationOpDoubleGrad<ops::LeakyReluGradFunctor<float>::FwdDeps()>);
ops::ActivationOpDoubleGrad2<ops::LeakyReluGradFunctor<float>::FwdDeps()>);
REGISTER_ACTIVATION_CPU_KERNEL(leaky_relu, LeakyRelu, LeakyReluFunctor,
LeakyReluGradFunctor);

Loading…
Cancel
Save