|
|
|
@ -41,7 +41,8 @@ class ElementwiseOp : public framework::OperatorWithKernel {
|
|
|
|
|
auto y_dim = ctx->GetInputDim("Y");
|
|
|
|
|
PADDLE_ENFORCE_GE(x_dim.size(), y_dim.size(),
|
|
|
|
|
"Rank of first input must >= rank of second input.");
|
|
|
|
|
ctx->SetOutputDim("Out", x_dim);
|
|
|
|
|
|
|
|
|
|
ctx->ShareDim("X", /*->*/ "Out");
|
|
|
|
|
ctx->ShareLoD("X", /*->*/ "Out");
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
@ -70,6 +71,7 @@ class ElementwiseOpInferVarType : public framework::VarTypeInference {
|
|
|
|
|
auto& x = block->FindRecursiveOrCreateVar(x_name);
|
|
|
|
|
auto& out = block->FindRecursiveOrCreateVar(out_name);
|
|
|
|
|
out.SetType(x.GetType());
|
|
|
|
|
out.SetDataType(x.GetDataType());
|
|
|
|
|
}
|
|
|
|
|
};
|
|
|
|
|
|
|
|
|
@ -157,10 +159,12 @@ class ElementwiseOpGrad : public framework::OperatorWithKernel {
|
|
|
|
|
auto x_grad_name = framework::GradVarName("X");
|
|
|
|
|
auto y_grad_name = framework::GradVarName("Y");
|
|
|
|
|
if (ctx->HasOutput(x_grad_name)) {
|
|
|
|
|
ctx->SetOutputDim(x_grad_name, x_dims);
|
|
|
|
|
ctx->ShareDim("X", /*->*/ x_grad_name);
|
|
|
|
|
ctx->ShareLoD("X", /*->*/ x_grad_name);
|
|
|
|
|
}
|
|
|
|
|
if (ctx->HasOutput(y_grad_name)) {
|
|
|
|
|
ctx->SetOutputDim(y_grad_name, y_dims);
|
|
|
|
|
ctx->ShareDim("Y", /*->*/ y_grad_name);
|
|
|
|
|
ctx->ShareLoD("Y", /*->*/ y_grad_name);
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
@ -193,14 +197,15 @@ class ElementwiseOpExplicitGrad : public ElementwiseOpGrad {
|
|
|
|
|
|
|
|
|
|
auto x_grad_name = framework::GradVarName("X");
|
|
|
|
|
if (ctx->HasOutput(x_grad_name)) {
|
|
|
|
|
auto out_dims = ctx->GetInputDim(framework::GradVarName("Out"));
|
|
|
|
|
ctx->SetOutputDim(x_grad_name, out_dims);
|
|
|
|
|
ctx->ShareDim(framework::GradVarName("Out"), /*->*/ x_grad_name);
|
|
|
|
|
ctx->ShareLoD(framework::GradVarName("Out"), /*->*/ x_grad_name);
|
|
|
|
|
}
|
|
|
|
|
auto y_grad_name = framework::GradVarName("Y");
|
|
|
|
|
if (ctx->HasOutput(y_grad_name)) {
|
|
|
|
|
PADDLE_ENFORCE(ctx->HasInput("Y"), "Input(Y) should not be null");
|
|
|
|
|
auto y_dims = ctx->GetInputDim("Y");
|
|
|
|
|
ctx->SetOutputDim(y_grad_name, y_dims);
|
|
|
|
|
|
|
|
|
|
ctx->ShareDim("Y", /*->*/ y_grad_name);
|
|
|
|
|
ctx->ShareLoD("Y", /*->*/ y_grad_name);
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
};
|
|
|
|
|