|
|
|
@ -43,12 +43,14 @@ class GridSampleOp : public framework::OperatorWithKernel {
|
|
|
|
|
PADDLE_ENFORCE(grid_dims[3] == 2, "Input(Grid) dims[3] should be 2.");
|
|
|
|
|
PADDLE_ENFORCE_EQ(grid_dims[0], x_dims[0],
|
|
|
|
|
"Input(X) and Input(Grid) dims[0] should be equal.");
|
|
|
|
|
PADDLE_ENFORCE_EQ(
|
|
|
|
|
grid_dims[1], x_dims[2],
|
|
|
|
|
"Input(X) dims[2] and Input(Grid) dims[1] should be equal.");
|
|
|
|
|
PADDLE_ENFORCE_EQ(
|
|
|
|
|
grid_dims[2], x_dims[3],
|
|
|
|
|
"Input(X) dims[3] and Input(Grid) dims[2] should be equal.");
|
|
|
|
|
if (ctx->IsRuntime()) {
|
|
|
|
|
PADDLE_ENFORCE_EQ(
|
|
|
|
|
grid_dims[1], x_dims[2],
|
|
|
|
|
"Input(X) dims[2] and Input(Grid) dims[1] should be equal.");
|
|
|
|
|
PADDLE_ENFORCE_EQ(
|
|
|
|
|
grid_dims[2], x_dims[3],
|
|
|
|
|
"Input(X) dims[3] and Input(Grid) dims[2] should be equal.");
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
ctx->SetOutputDim("Output", x_dims);
|
|
|
|
|
ctx->ShareLoD("X", "Output");
|
|
|
|
|