|
|
|
@ -22,29 +22,34 @@ namespace operators {
|
|
|
|
|
|
|
|
|
|
void FusionSquaredMatSubOp::InferShape(
|
|
|
|
|
framework::InferShapeContext* ctx) const {
|
|
|
|
|
PADDLE_ENFORCE(ctx->HasInput("X"),
|
|
|
|
|
"Input(X) of FusionSquaredMatSubOp should not be null.");
|
|
|
|
|
PADDLE_ENFORCE(ctx->HasInput("Y"),
|
|
|
|
|
"Input(Y) of FusionSquaredMatSubOp should not be null.");
|
|
|
|
|
PADDLE_ENFORCE(
|
|
|
|
|
ctx->HasOutput("SquaredX"),
|
|
|
|
|
"Output(SquaredX) of FusionSquaredMatSubOp should not be null.");
|
|
|
|
|
PADDLE_ENFORCE(
|
|
|
|
|
ctx->HasOutput("SquaredY"),
|
|
|
|
|
"Output(SquaredY) of FusionSquaredMatSubOp should not be null.");
|
|
|
|
|
PADDLE_ENFORCE(
|
|
|
|
|
ctx->HasOutput("SquaredXY"),
|
|
|
|
|
"Output(SquaredXY) of FusionSquaredMatSubOp should not be null.");
|
|
|
|
|
PADDLE_ENFORCE(ctx->HasOutput("Out"),
|
|
|
|
|
"Output(Out) of FusionSquaredMatSubOp should not be null.");
|
|
|
|
|
|
|
|
|
|
OP_INOUT_CHECK(ctx->HasInput("X"), "Input", "X", "FusionSquaredMatSub");
|
|
|
|
|
OP_INOUT_CHECK(ctx->HasInput("Y"), "Input", "Y", "FusionSquaredMatSub");
|
|
|
|
|
OP_INOUT_CHECK(ctx->HasOutput("SquaredX"), "SquaredX", "Out",
|
|
|
|
|
"FusionSquaredMatSub");
|
|
|
|
|
OP_INOUT_CHECK(ctx->HasOutput("SquaredY"), "SquaredY", "Out",
|
|
|
|
|
"FusionSquaredMatSub");
|
|
|
|
|
OP_INOUT_CHECK(ctx->HasOutput("SquaredXY"), "SquaredXY", "Out",
|
|
|
|
|
"FusionSquaredMatSub");
|
|
|
|
|
OP_INOUT_CHECK(ctx->HasOutput("Out"), "Out", "Out", "FusionSquaredMatSub");
|
|
|
|
|
auto x_dims = ctx->GetInputDim("X");
|
|
|
|
|
auto y_dims = ctx->GetInputDim("Y");
|
|
|
|
|
PADDLE_ENFORCE_EQ(x_dims.size(), y_dims.size(),
|
|
|
|
|
"Input tensors dims size should be equal.");
|
|
|
|
|
PADDLE_ENFORCE_EQ(x_dims.size(), 2, "Input tensors should be a Matrix.");
|
|
|
|
|
PADDLE_ENFORCE_EQ(x_dims[1], y_dims[0], "Inputs Matrix should be multiply.");
|
|
|
|
|
|
|
|
|
|
PADDLE_ENFORCE_EQ(
|
|
|
|
|
x_dims.size(), y_dims.size(),
|
|
|
|
|
platform::errors::InvalidArgument("The input tensor X's dims size should "
|
|
|
|
|
"be equal to Y's. But received X's "
|
|
|
|
|
"dims size = %d, Y's dims size = %d.",
|
|
|
|
|
x_dims.size(), y_dims.size()));
|
|
|
|
|
PADDLE_ENFORCE_EQ(x_dims.size(), 2,
|
|
|
|
|
platform::errors::InvalidArgument(
|
|
|
|
|
"The input tensor X's dims size should be 2. But "
|
|
|
|
|
"received X's dims size = %d.",
|
|
|
|
|
x_dims.size()));
|
|
|
|
|
PADDLE_ENFORCE_EQ(
|
|
|
|
|
x_dims[1], y_dims[0],
|
|
|
|
|
platform::errors::InvalidArgument("The input tensor X's dims[1] should "
|
|
|
|
|
"be equal to Y's dims[0]. But received "
|
|
|
|
|
"X's dims[1] = %d, Y's dims[0] = %d.",
|
|
|
|
|
x_dims[1], y_dims[0]));
|
|
|
|
|
ctx->SetOutputDim("SquaredX", x_dims);
|
|
|
|
|
ctx->SetOutputDim("SquaredY", y_dims);
|
|
|
|
|
ctx->SetOutputDim("SquaredXY", {x_dims[0], y_dims[1]});
|
|
|
|
|