|
|
|
@ -244,7 +244,8 @@ class MulDoubleGradOp : public framework::OperatorWithKernel {
|
|
|
|
|
PADDLE_ENFORCE(ctx->HasInput("Y"), "Input(Y) should not be null");
|
|
|
|
|
PADDLE_ENFORCE(ctx->HasInput("DOut"), "Input(DOut) should not be null");
|
|
|
|
|
|
|
|
|
|
if (ctx->HasOutput("DDOut") && ctx->HasInput("DDX")) {
|
|
|
|
|
if (ctx->HasOutput("DDOut") &&
|
|
|
|
|
(ctx->HasInput("DDX") || (ctx->HasInput("DDY")))) {
|
|
|
|
|
ctx->ShareDim("DOut", "DDOut");
|
|
|
|
|
}
|
|
|
|
|
if (ctx->HasOutput("DX") && ctx->HasInput("DDY")) {
|
|
|
|
@ -275,9 +276,9 @@ class MulDoubleGradMaker : public framework::SingleGradOpDescMaker {
|
|
|
|
|
auto ddw = OutputGrad(framework::GradVarName("Y"));
|
|
|
|
|
std::vector<std::string> empty_str = {};
|
|
|
|
|
|
|
|
|
|
retv->SetOutput("DDOut", (ddx.empty())
|
|
|
|
|
? empty_str
|
|
|
|
|
: InputGrad(framework::GradVarName("Out")));
|
|
|
|
|
if (!ddx.empty() || !ddw.empty()) {
|
|
|
|
|
retv->SetOutput("DDOut", InputGrad(framework::GradVarName("Out")));
|
|
|
|
|
}
|
|
|
|
|
retv->SetOutput("DX", ddw.empty() ? empty_str : InputGrad("X"));
|
|
|
|
|
retv->SetOutput("DY", ddx.empty() ? empty_str : InputGrad("Y"));
|
|
|
|
|
|
|
|
|
|