|
|
|
@ -71,8 +71,16 @@ class BinaryLogicalOpInferShape : public framework::InferShapeBase {
|
|
|
|
|
"Input(Y) of %s operator must not be null", comment.type);
|
|
|
|
|
auto dim_x = context->GetInputDim("X");
|
|
|
|
|
auto dim_y = context->GetInputDim("Y");
|
|
|
|
|
PADDLE_ENFORCE_EQ(framework::product(dim_x), framework::product(dim_y),
|
|
|
|
|
"The number of elements in X and Y should be same");
|
|
|
|
|
|
|
|
|
|
int product_x = framework::product(dim_x);
|
|
|
|
|
int product_y = framework::product(dim_y);
|
|
|
|
|
bool check = context->IsRuntime() || (product_x >= 0 && product_y >= 0);
|
|
|
|
|
if (check) {
|
|
|
|
|
PADDLE_ENFORCE_EQ(
|
|
|
|
|
product_x, product_y,
|
|
|
|
|
"The number of elements in X and Y should be same, %d != %d",
|
|
|
|
|
product_x, product_y);
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
context->SetOutputDim("Out", context->GetInputDim("X"));
|
|
|
|
|
context->ShareLoD("X", "Out");
|
|
|
|
|