|
|
|
@ -13,6 +13,7 @@ See the License for the specific language governing permissions and
|
|
|
|
|
limitations under the License. */
|
|
|
|
|
|
|
|
|
|
#include "paddle/fluid/operators/grid_sampler_op.h"
|
|
|
|
|
#include <memory>
|
|
|
|
|
#include "paddle/fluid/framework/op_registry.h"
|
|
|
|
|
#ifdef PADDLE_WITH_CUDA
|
|
|
|
|
#include "paddle/fluid/platform/cudnn_helper.h"
|
|
|
|
@ -40,10 +41,12 @@ class GridSampleOp : public framework::OperatorWithKernel {
|
|
|
|
|
"Input(X) of GridSampleOp should be 4-D Tensor.");
|
|
|
|
|
PADDLE_ENFORCE(grid_dims.size() == 4,
|
|
|
|
|
"Input(Grid) of GridSampleOp should be 4-D Tensor.");
|
|
|
|
|
PADDLE_ENFORCE(grid_dims[3] == 2, "Input(Grid) dims[3] should be 2.");
|
|
|
|
|
PADDLE_ENFORCE_EQ(grid_dims[0], x_dims[0],
|
|
|
|
|
"Input(X) and Input(Grid) dims[0] should be equal.");
|
|
|
|
|
if (ctx->IsRuntime() || grid_dims[3] > 0) {
|
|
|
|
|
PADDLE_ENFORCE(grid_dims[3] == 2, "Input(Grid) dims[3] should be 2.");
|
|
|
|
|
}
|
|
|
|
|
if (ctx->IsRuntime()) {
|
|
|
|
|
PADDLE_ENFORCE_EQ(grid_dims[0], x_dims[0],
|
|
|
|
|
"Input(X) and Input(Grid) dims[0] should be equal.");
|
|
|
|
|
PADDLE_ENFORCE_EQ(
|
|
|
|
|
grid_dims[1], x_dims[2],
|
|
|
|
|
"Input(X) dims[2] and Input(Grid) dims[1] should be equal.");
|
|
|
|
|