|
|
|
@ -1,4 +1,4 @@
|
|
|
|
|
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved.
|
|
|
|
|
/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
|
|
|
|
|
|
|
|
|
|
Licensed under the Apache License, Version 2.0 (the "License");
|
|
|
|
|
you may not use this file except in compliance with the License.
|
|
|
|
@ -33,7 +33,7 @@ using Array4 = Eigen::DSizes<int64_t, 4>;
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
template <typename T>
|
|
|
|
|
inline bool isInBound(T x, T y, T x_max, T y_max) {
|
|
|
|
|
static inline bool isInBound(T x, T y, T x_max, T y_max) {
|
|
|
|
|
if (x < 0 || x > x_max || y < 0 || y > y_max) {
|
|
|
|
|
return false;
|
|
|
|
|
}
|
|
|
|
@ -41,10 +41,10 @@ inline bool isInBound(T x, T y, T x_max, T y_max) {
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
template <typename DeviceContext, typename T>
|
|
|
|
|
void CalcGridLocations(const framework::ExecutionContext& ctx, const Tensor& grid,
|
|
|
|
|
static void CalcGridLocations(const DeviceContext& ctx, const Tensor& grid,
|
|
|
|
|
Tensor* x_w, Tensor* x_e, Tensor* y_n, Tensor* y_s,
|
|
|
|
|
Tensor* d_w, Tensor* d_e, Tensor* d_n, Tensor* d_s) {
|
|
|
|
|
auto& place = *ctx.template device_context<DeviceContext>().eigen_device();
|
|
|
|
|
auto& place = *ctx.eigen_device();
|
|
|
|
|
const int n = grid.dims()[0];
|
|
|
|
|
const int h = grid.dims()[1];
|
|
|
|
|
const int w = grid.dims()[2];
|
|
|
|
@ -71,6 +71,7 @@ void CalcGridLocations(const framework::ExecutionContext& ctx, const Tensor& gri
|
|
|
|
|
grid_x_t.device(place) = 0.5 * ((grid_x_t + ones_t) * x_max);
|
|
|
|
|
grid_y_t.device(place) = 0.5 * ((grid_y_t + ones_t) * y_max);
|
|
|
|
|
|
|
|
|
|
// calculate coords of 4 corner points
|
|
|
|
|
x_w->mutable_data<T>({n, h, w}, ctx.GetPlace());
|
|
|
|
|
x_e->mutable_data<T>({n, h, w}, ctx.GetPlace());
|
|
|
|
|
y_n->mutable_data<T>({n, h, w}, ctx.GetPlace());
|
|
|
|
@ -84,6 +85,7 @@ void CalcGridLocations(const framework::ExecutionContext& ctx, const Tensor& gri
|
|
|
|
|
y_n_t.device(place) = grid_y_t.floor();
|
|
|
|
|
y_s_t.device(place) = y_n_t + ones_t;
|
|
|
|
|
|
|
|
|
|
// calculate distances to 4 sides
|
|
|
|
|
d_w->mutable_data<T>({n, h, w}, ctx.GetPlace());
|
|
|
|
|
d_e->mutable_data<T>({n, h, w}, ctx.GetPlace());
|
|
|
|
|
d_n->mutable_data<T>({n, h, w}, ctx.GetPlace());
|
|
|
|
@ -99,7 +101,7 @@ void CalcGridLocations(const framework::ExecutionContext& ctx, const Tensor& gri
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
template <typename T>
|
|
|
|
|
void GetGridPointValue(const Tensor& input, Tensor* output,
|
|
|
|
|
static void GetGridPointValue(const Tensor& input, Tensor* output,
|
|
|
|
|
const Tensor& x, const Tensor& y) {
|
|
|
|
|
const int n = input.dims()[0];
|
|
|
|
|
const int c = input.dims()[1];
|
|
|
|
@ -124,7 +126,7 @@ void GetGridPointValue(const Tensor& input, Tensor* output,
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
template <typename T>
|
|
|
|
|
void GatherOutputGradToInputGrad(const Tensor& output_grad, Tensor* input_grad,
|
|
|
|
|
static void GatherOutputGradToInputGrad(const Tensor& output_grad, Tensor* input_grad,
|
|
|
|
|
const Tensor& x, const Tensor& y,
|
|
|
|
|
const Tensor& d1, const Tensor& d2) {
|
|
|
|
|
const int n = output_grad.dims()[0];
|
|
|
|
@ -170,9 +172,10 @@ class GridSampleOpKernel : public framework::OpKernel<T> {
|
|
|
|
|
// calc locations and distances of 4 corner points
|
|
|
|
|
Tensor x_w, x_e, y_n, y_s;
|
|
|
|
|
Tensor d_w, d_e, d_n, d_s;
|
|
|
|
|
CalcGridLocations<DeviceContext, T>(ctx, *grid,
|
|
|
|
|
&x_w, &x_e, &y_n, &y_s,
|
|
|
|
|
&d_w, &d_e, &d_n, &d_s);
|
|
|
|
|
CalcGridLocations<DeviceContext, T>(ctx.template device_context<DeviceContext>(),
|
|
|
|
|
*grid,
|
|
|
|
|
&x_w, &x_e, &y_n, &y_s,
|
|
|
|
|
&d_w, &d_e, &d_n, &d_s);
|
|
|
|
|
|
|
|
|
|
auto* output = ctx.Output<Tensor>("Output");
|
|
|
|
|
output->mutable_data<T>({n, c, h, w}, ctx.GetPlace());
|
|
|
|
@ -239,9 +242,10 @@ class GridSampleGradOpKernel : public framework::OpKernel<T> {
|
|
|
|
|
|
|
|
|
|
Tensor x_w, x_e, y_n, y_s;
|
|
|
|
|
Tensor d_w, d_e, d_n, d_s;
|
|
|
|
|
CalcGridLocations<DeviceContext, T>(ctx, *grid,
|
|
|
|
|
&x_w, &x_e, &y_n, &y_s,
|
|
|
|
|
&d_w, &d_e, &d_n, &d_s);
|
|
|
|
|
CalcGridLocations<DeviceContext, T>(ctx.template device_context<DeviceContext>(),
|
|
|
|
|
*grid,
|
|
|
|
|
&x_w, &x_e, &y_n, &y_s,
|
|
|
|
|
&d_w, &d_e, &d_n, &d_s);
|
|
|
|
|
|
|
|
|
|
// gather output grad value to input grad by corner point coords and weight
|
|
|
|
|
GatherOutputGradToInputGrad<T>(*output_grad, input_grad, x_w, y_n, d_e, d_s);
|
|
|
|
|