fix mul double grad (#20040)

fix-python-transpose
lvmengsi 6 years ago committed by GitHub
parent 8f0b3c0516
commit 647ff784e2
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

@ -244,7 +244,8 @@ class MulDoubleGradOp : public framework::OperatorWithKernel {
PADDLE_ENFORCE(ctx->HasInput("Y"), "Input(Y) should not be null");
PADDLE_ENFORCE(ctx->HasInput("DOut"), "Input(DOut) should not be null");
if (ctx->HasOutput("DDOut") && ctx->HasInput("DDX")) {
if (ctx->HasOutput("DDOut") &&
(ctx->HasInput("DDX") || (ctx->HasInput("DDY")))) {
ctx->ShareDim("DOut", "DDOut");
}
if (ctx->HasOutput("DX") && ctx->HasInput("DDY")) {
@ -275,9 +276,9 @@ class MulDoubleGradMaker : public framework::SingleGradOpDescMaker {
auto ddw = OutputGrad(framework::GradVarName("Y"));
std::vector<std::string> empty_str = {};
retv->SetOutput("DDOut", (ddx.empty())
? empty_str
: InputGrad(framework::GradVarName("Out")));
if (!ddx.empty() || !ddw.empty()) {
retv->SetOutput("DDOut", InputGrad(framework::GradVarName("Out")));
}
retv->SetOutput("DX", ddw.empty() ? empty_str : InputGrad("X"));
retv->SetOutput("DY", ddx.empty() ? empty_str : InputGrad("Y"));

Loading…
Cancel
Save