|
|
|
@ -36,14 +36,17 @@ class ConvShiftOp : public framework::OperatorWithKernel {
|
|
|
|
|
auto y_dims = ctx->GetInputDim("Y");
|
|
|
|
|
PADDLE_ENFORCE_EQ(x_dims.size(), 2, "Input(X)'s rank should be 2.");
|
|
|
|
|
PADDLE_ENFORCE_EQ(y_dims.size(), 2, "Input(Y)'s rank should be 2.");
|
|
|
|
|
PADDLE_ENFORCE_EQ(x_dims[0], y_dims[0],
|
|
|
|
|
"The 1st dimension of Input(X) and Input(Y) should "
|
|
|
|
|
"be equal.");
|
|
|
|
|
PADDLE_ENFORCE_EQ(y_dims[1] % 2, 1,
|
|
|
|
|
"The 2nd dimension of Input(Y) should be odd.");
|
|
|
|
|
PADDLE_ENFORCE_LE(y_dims[1], x_dims[1],
|
|
|
|
|
"The 2nd dimension of Input(Y) should be less than or "
|
|
|
|
|
"equal to the 2nd dimension of Input(X).");
|
|
|
|
|
if (ctx->IsRuntime() || (x_dims[0] > 0 && y_dims[0] > 0))
|
|
|
|
|
PADDLE_ENFORCE_EQ(x_dims[0], y_dims[0],
|
|
|
|
|
"The 1st dimension of Input(X) and Input(Y) should "
|
|
|
|
|
"be equal.");
|
|
|
|
|
if (ctx->IsRuntime() || y_dims[1] > 0)
|
|
|
|
|
PADDLE_ENFORCE_EQ(y_dims[1] % 2, 1,
|
|
|
|
|
"The 2nd dimension of Input(Y) should be odd.");
|
|
|
|
|
if (ctx->IsRuntime() || (x_dims[1] > 0 && y_dims[1] > 0))
|
|
|
|
|
PADDLE_ENFORCE_LE(y_dims[1], x_dims[1],
|
|
|
|
|
"The 2nd dimension of Input(Y) should be less than or "
|
|
|
|
|
"equal to the 2nd dimension of Input(X).");
|
|
|
|
|
ctx->ShareDim("X", /*->*/ "Out");
|
|
|
|
|
ctx->ShareLoD("X", /*->*/ "Out");
|
|
|
|
|
}
|
|
|
|
|