|
|
|
@ -52,14 +52,25 @@ class ElementwiseOp : public framework::OperatorWithKernel {
|
|
|
|
|
framework::proto::VarType::LOD_TENSOR) {
|
|
|
|
|
auto x_dim = ctx->GetInputDim("X");
|
|
|
|
|
auto y_dim = ctx->GetInputDim("Y");
|
|
|
|
|
PADDLE_ENFORCE_GE(x_dim.size(), y_dim.size(),
|
|
|
|
|
"Rank of first input must >= rank of second input.");
|
|
|
|
|
PADDLE_ENFORCE_GE(
|
|
|
|
|
x_dim.size(), y_dim.size(),
|
|
|
|
|
"ShapeError: the dimension of input X must greater than or equal to "
|
|
|
|
|
"the one of input Y. But received: the shape of input X = [%s], the "
|
|
|
|
|
"dimension of input X = %d, the shape of input Y = [%s], the "
|
|
|
|
|
"dimension of input Y = %d",
|
|
|
|
|
x_dim, x_dim.size(), y_dim, y_dim.size());
|
|
|
|
|
} else if (ctx->GetInputsVarType("X").front() ==
|
|
|
|
|
framework::proto::VarType::SELECTED_ROWS) {
|
|
|
|
|
PADDLE_ENFORCE((ctx->GetInputDim("Y").size() == 1u) &&
|
|
|
|
|
(ctx->GetInputDim("Y")[0] == 1),
|
|
|
|
|
"For elementwise_op, if X is Sparse, "
|
|
|
|
|
"Y must be scalar.");
|
|
|
|
|
PADDLE_ENFORCE_EQ(
|
|
|
|
|
ctx->GetInputDim("Y").size(), 1u,
|
|
|
|
|
"ShapeError: For elementwise_op, if X is Sparse(VarType.SELECTED_ROWS"
|
|
|
|
|
"), Y must be scalar. But reveived the dimension of Y = %s",
|
|
|
|
|
ctx->GetInputDim("Y").size());
|
|
|
|
|
PADDLE_ENFORCE_EQ(
|
|
|
|
|
ctx->GetInputDim("Y")[0], 1,
|
|
|
|
|
"ShapeError: For elementwise_op, if X is Sparse(VarType.SELECTED_ROWS"
|
|
|
|
|
"), Y must be scalar. But reveived the first dimension of Y = %s",
|
|
|
|
|
ctx->GetInputDim("Y")[0]);
|
|
|
|
|
} else {
|
|
|
|
|
PADDLE_THROW("X's type[%s] is not supported by elementwise_op.",
|
|
|
|
|
ctx->GetInputsVarType("X").front());
|
|
|
|
@ -203,8 +214,13 @@ class ElementwiseOpGrad : public framework::OperatorWithKernel {
|
|
|
|
|
auto x_dims = ctx->GetInputDim(out_grad_name);
|
|
|
|
|
auto y_dims = ctx->GetInputDim("Y");
|
|
|
|
|
|
|
|
|
|
PADDLE_ENFORCE_GE(x_dims.size(), y_dims.size(),
|
|
|
|
|
"Rank of first input must >= rank of second input.");
|
|
|
|
|
PADDLE_ENFORCE_GE(
|
|
|
|
|
x_dims.size(), y_dims.size(),
|
|
|
|
|
"ShapeError: the dimension of Out@GRAD must greater than or equal to "
|
|
|
|
|
"the one of input Y. But received: the shape of Out@GRAD = [%s], the "
|
|
|
|
|
"dimension of Out@GRAD = %d, the shape of input Y = [%s], the "
|
|
|
|
|
"dimension of of input Y = %d",
|
|
|
|
|
x_dims, x_dims.size(), y_dims, y_dims.size());
|
|
|
|
|
|
|
|
|
|
auto x_grad_name = framework::GradVarName("X");
|
|
|
|
|
auto y_grad_name = framework::GradVarName("Y");
|
|
|
|
|