|
|
@ -58,8 +58,8 @@ class CompareOpInferShape : public framework::InferShapeBase {
|
|
|
|
comment.type);
|
|
|
|
comment.type);
|
|
|
|
auto dim_x = context->GetInputDim("X");
|
|
|
|
auto dim_x = context->GetInputDim("X");
|
|
|
|
auto dim_y = context->GetInputDim("Y");
|
|
|
|
auto dim_y = context->GetInputDim("Y");
|
|
|
|
PADDLE_ENFORCE_EQ(framework::product(dim_x), framework::product(dim_y),
|
|
|
|
PADDLE_ENFORCE_GE(dim_x.size(), dim_y.size(),
|
|
|
|
"The number of elements in X and Y should be same");
|
|
|
|
"The size of dim_y should not be greater than dim_x's.");
|
|
|
|
|
|
|
|
|
|
|
|
context->SetOutputDim("Out", context->GetInputDim("X"));
|
|
|
|
context->SetOutputDim("Out", context->GetInputDim("X"));
|
|
|
|
context->ShareLoD("X", "Out");
|
|
|
|
context->ShareLoD("X", "Out");
|
|
|
|