|
|
|
@ -150,13 +150,12 @@ class HierarchicalSigmoidGradOp : public framework::OperatorWithKernel {
|
|
|
|
|
"Output(W@Grad should not be null.");
|
|
|
|
|
PADDLE_ENFORCE(ctx->HasOutput(framework::GradVarName("X")),
|
|
|
|
|
"Output(X@Grad should not be null.");
|
|
|
|
|
if (!ctx->Attrs().Get<bool>("is_sparse")) {
|
|
|
|
|
if (ctx->HasOutput(framework::GradVarName("Bias"))) {
|
|
|
|
|
ctx->SetOutputDim(framework::GradVarName("Bias"),
|
|
|
|
|
ctx->GetInputDim("Bias"));
|
|
|
|
|
}
|
|
|
|
|
ctx->SetOutputDim(framework::GradVarName("W"), ctx->GetInputDim("W"));
|
|
|
|
|
|
|
|
|
|
if (ctx->HasOutput(framework::GradVarName("Bias"))) {
|
|
|
|
|
ctx->SetOutputDim(framework::GradVarName("Bias"),
|
|
|
|
|
ctx->GetInputDim("Bias"));
|
|
|
|
|
}
|
|
|
|
|
ctx->SetOutputDim(framework::GradVarName("W"), ctx->GetInputDim("W"));
|
|
|
|
|
ctx->SetOutputDim(framework::GradVarName("X"), ctx->GetInputDim("X"));
|
|
|
|
|
ctx->ShareLoD("X", /*->*/ framework::GradVarName("X"));
|
|
|
|
|
}
|
|
|
|
|