|
|
|
@ -37,18 +37,65 @@ using Array4 = Eigen::DSizes<int64_t, 4>;
|
|
|
|
|
*/
|
|
|
|
|
// Functor that produces `count` evenly spaced values from `start` to `end`
// (a 1-D linspace), used below to build normalized [-1, 1] grid coordinates.
// NOTE(review): only the declarations are visible here; definitions are
// presumably device-specific (CPU/CUDA) specializations elsewhere — confirm.
template <typename DeviceContext, typename T>
struct Linspace {
  // Returns a newly created tensor holding the generated sequence.
  framework::Tensor operator()(T start, T end, int count,
                               const framework::ExecutionContext& ctx);

  // Fills the caller-provided `numbers` tensor with the sequence in place.
  void operator()(T start, T end, int count, framework::Tensor* numbers,
                  const framework::ExecutionContext& ctx);
};
|
|
|
|
|
|
|
|
|
|
template <typename DeviceContext, typename T>
|
|
|
|
|
inline void GetIdxMap(int n, int h, int w, Tensor* grid,
|
|
|
|
|
const framework::ExecutionContext& ctx) {
|
|
|
|
|
auto& place = *ctx.template device_context<DeviceContext>().eigen_device();
|
|
|
|
|
grid->mutable_data<T>({n, h, w, 3}, ctx.GetPlace());
|
|
|
|
|
auto grid_t = EigenTensor<T, 4>::From(*grid);
|
|
|
|
|
// Get indexes of height with shape [height, width, 1]
|
|
|
|
|
Tensor h_idx;
|
|
|
|
|
Linspace<DeviceContext, T> linspace;
|
|
|
|
|
linspace((T)-1, (T)1, h, &h_idx, ctx);
|
|
|
|
|
auto h_idx_t = EigenTensor<T, 1>::From(h_idx);
|
|
|
|
|
// Get indexes of width with shape [height, width, 1]
|
|
|
|
|
Tensor w_idx;
|
|
|
|
|
linspace((T)-1, (T)1, w, &w_idx, ctx);
|
|
|
|
|
auto w_idx_t = EigenTensor<T, 1>::From(w_idx);
|
|
|
|
|
// Get constant ones tensor with shape [height, width, 1]
|
|
|
|
|
Tensor ones;
|
|
|
|
|
ones.mutable_data<T>({h, w, 1}, ctx.GetPlace());
|
|
|
|
|
auto ones_t = EigenTensor<T, 3>::From(ones).setConstant((T)1);
|
|
|
|
|
// Get grid tensor with shape [n, h, w, 3] by concatenating h_idx, w_idx and
|
|
|
|
|
// ones
|
|
|
|
|
Tensor w_idx_map;
|
|
|
|
|
w_idx_map.mutable_data<T>({h, w, 1}, ctx.GetPlace());
|
|
|
|
|
auto w_idx_map_t = EigenTensor<T, 3>::From(w_idx_map);
|
|
|
|
|
Tensor h_idx_map;
|
|
|
|
|
h_idx_map.mutable_data<T>({h, w, 1}, ctx.GetPlace());
|
|
|
|
|
auto h_idx_map_t = EigenTensor<T, 3>::From(h_idx_map);
|
|
|
|
|
Tensor w_h_idx_map;
|
|
|
|
|
w_h_idx_map.mutable_data<T>({h, w, 2}, ctx.GetPlace());
|
|
|
|
|
auto w_h_idx_map_t = EigenTensor<T, 3>::From(w_h_idx_map);
|
|
|
|
|
Tensor w_h_one_idx_map;
|
|
|
|
|
w_h_one_idx_map.mutable_data<T>({h, w, 3}, ctx.GetPlace());
|
|
|
|
|
auto w_h_one_idx_map_t = EigenTensor<T, 3>::From(w_h_one_idx_map);
|
|
|
|
|
|
|
|
|
|
w_idx_map_t.device(place) = w_idx_t.reshape(Array2(1, w))
|
|
|
|
|
.broadcast(Array2(h, 1))
|
|
|
|
|
.reshape(Array3(h, w, 1));
|
|
|
|
|
|
|
|
|
|
h_idx_map_t.device(place) = h_idx_t.reshape(Array2(1, h))
|
|
|
|
|
.broadcast(Array2(w, 1))
|
|
|
|
|
.shuffle(Array2(1, 0))
|
|
|
|
|
.reshape(Array3(h, w, 1));
|
|
|
|
|
|
|
|
|
|
w_h_idx_map_t.device(place) = w_idx_map_t.concatenate(h_idx_map_t, 2);
|
|
|
|
|
w_h_one_idx_map_t.device(place) = w_h_idx_map_t.concatenate(ones_t, 2);
|
|
|
|
|
grid_t.device(place) = w_h_one_idx_map_t.reshape(Array4(1, h, w, 3))
|
|
|
|
|
.broadcast(Array4(n, 1, 1, 1));
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
template <typename DeviceContext, typename T>
|
|
|
|
|
class AffineGridOpKernel : public framework::OpKernel<T> {
|
|
|
|
|
public:
|
|
|
|
|
void Compute(const framework::ExecutionContext& ctx) const override {
|
|
|
|
|
auto& place = *ctx.template device_context<DeviceContext>().eigen_device();
|
|
|
|
|
auto* theta = ctx.Input<Tensor>("Theta");
|
|
|
|
|
int n = theta->dims()[0];
|
|
|
|
|
|
|
|
|
|
auto size_attr = ctx.Attr<std::vector<int>>("output_shape");
|
|
|
|
|
int h = 0;
|
|
|
|
|
int w = 0;
|
|
|
|
@ -63,44 +110,13 @@ class AffineGridOpKernel : public framework::OpKernel<T> {
|
|
|
|
|
h = size_attr[2];
|
|
|
|
|
w = size_attr[3];
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
auto* output = ctx.Output<Tensor>("Output");
|
|
|
|
|
output->mutable_data<T>({n, h, w, 2}, ctx.GetPlace());
|
|
|
|
|
|
|
|
|
|
math::SetConstant<DeviceContext, T>()(
|
|
|
|
|
ctx.template device_context<DeviceContext>(), output,
|
|
|
|
|
static_cast<T>(0));
|
|
|
|
|
|
|
|
|
|
Linspace<DeviceContext, T> linspace;
|
|
|
|
|
// Get indexes of height with shape [height, width, 1]
|
|
|
|
|
auto h_idx = linspace((T)-1, (T)1, h, ctx);
|
|
|
|
|
auto h_idx_t = EigenTensor<T, 1>::From(h_idx);
|
|
|
|
|
// Get indexes of width with shape [height, width, 1]
|
|
|
|
|
auto w_idx = linspace((T)-1, (T)1, w, ctx);
|
|
|
|
|
auto w_idx_t = EigenTensor<T, 1>::From(w_idx);
|
|
|
|
|
// Get constant ones tensor with shape [height, width, 1]
|
|
|
|
|
Tensor ones;
|
|
|
|
|
ones.mutable_data<T>({h, w, 1}, ctx.GetPlace());
|
|
|
|
|
auto ones_t = EigenTensor<T, 3>::From(ones).setConstant((T)1);
|
|
|
|
|
// Get grid tensor with shape [n, h, w, 3] by concatenating h_idx, w_idx and
|
|
|
|
|
// ones
|
|
|
|
|
Tensor grid;
|
|
|
|
|
grid.mutable_data<T>({n, h, w, 3}, ctx.GetPlace());
|
|
|
|
|
auto grid_t = EigenTensor<T, 4>::From(grid);
|
|
|
|
|
|
|
|
|
|
grid_t.device(place) = w_idx_t.reshape(Array2(1, w))
|
|
|
|
|
.broadcast(Array2(h, 1))
|
|
|
|
|
.reshape(Array3(h, w, 1))
|
|
|
|
|
.concatenate(h_idx_t.reshape(Array2(1, h))
|
|
|
|
|
.broadcast(Array2(w, 1))
|
|
|
|
|
.shuffle(Array2(1, 0))
|
|
|
|
|
.reshape(Array3(h, w, 1)),
|
|
|
|
|
2)
|
|
|
|
|
.eval()
|
|
|
|
|
.concatenate(ones_t, 2)
|
|
|
|
|
.reshape(Array4(1, h, w, 3))
|
|
|
|
|
.broadcast(Array4(n, 1, 1, 1));
|
|
|
|
|
|
|
|
|
|
GetIdxMap<DeviceContext, T>(n, h, w, &grid, ctx);
|
|
|
|
|
// output = grid * theta.T
|
|
|
|
|
// TODO(wanghaoshuang): Refine batched matrix multiply
|
|
|
|
|
auto blas = math::GetBlas<DeviceContext, T>(ctx);
|
|
|
|
@ -118,10 +134,8 @@ template <typename DeviceContext, typename T>
|
|
|
|
|
class AffineGridGradOpKernel : public framework::OpKernel<T> {
|
|
|
|
|
public:
|
|
|
|
|
void Compute(const framework::ExecutionContext& ctx) const override {
|
|
|
|
|
auto& place = *ctx.template device_context<DeviceContext>().eigen_device();
|
|
|
|
|
auto output_grad = ctx.Input<Tensor>(framework::GradVarName("Output"));
|
|
|
|
|
auto theta_grad = ctx.Output<Tensor>(framework::GradVarName("Theta"));
|
|
|
|
|
|
|
|
|
|
int n = output_grad->dims()[0];
|
|
|
|
|
auto size_attr = ctx.Attr<std::vector<int>>("output_shape");
|
|
|
|
|
int h = 0;
|
|
|
|
@ -137,42 +151,12 @@ class AffineGridGradOpKernel : public framework::OpKernel<T> {
|
|
|
|
|
h = size_attr[2];
|
|
|
|
|
w = size_attr[3];
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
theta_grad->mutable_data<T>({n, 2, 3}, ctx.GetPlace());
|
|
|
|
|
|
|
|
|
|
math::SetConstant<DeviceContext, T>()(
|
|
|
|
|
ctx.template device_context<DeviceContext>(), theta_grad,
|
|
|
|
|
static_cast<T>(0));
|
|
|
|
|
|
|
|
|
|
Linspace<DeviceContext, T> linspace;
|
|
|
|
|
|
|
|
|
|
// Get indexes of height with shape [height, width, 1]
|
|
|
|
|
auto h_idx = linspace((T)-1, (T)1, h, ctx);
|
|
|
|
|
auto h_idx_t = EigenTensor<T, 1>::From(h_idx);
|
|
|
|
|
// Get indexes of width with shape [height, width, 1]
|
|
|
|
|
auto w_idx = linspace((T)-1, (T)1, w, ctx);
|
|
|
|
|
auto w_idx_t = EigenTensor<T, 1>::From(w_idx);
|
|
|
|
|
// Get constant ones tensor with shape [height, width, 1]
|
|
|
|
|
Tensor ones;
|
|
|
|
|
ones.mutable_data<T>({h, w, 1}, ctx.GetPlace());
|
|
|
|
|
auto ones_t = EigenTensor<T, 3>::From(ones).setConstant((T)1);
|
|
|
|
|
// Get grid tensor with shape [n, h, w, 3] by concatenating h_idx, w_idx and
|
|
|
|
|
// ones
|
|
|
|
|
Tensor grid;
|
|
|
|
|
grid.mutable_data<T>({n, h, w, 3}, ctx.GetPlace());
|
|
|
|
|
auto grid_t = EigenTensor<T, 4>::From(grid);
|
|
|
|
|
grid_t.device(place) = w_idx_t.reshape(Array2(1, w))
|
|
|
|
|
.broadcast(Array2(h, 1))
|
|
|
|
|
.reshape(Array3(h, w, 1))
|
|
|
|
|
.concatenate(h_idx_t.reshape(Array2(1, h))
|
|
|
|
|
.broadcast(Array2(w, 1))
|
|
|
|
|
.shuffle(Array2(1, 0))
|
|
|
|
|
.reshape(Array3(h, w, 1)),
|
|
|
|
|
2)
|
|
|
|
|
.eval()
|
|
|
|
|
.concatenate(ones_t, 2)
|
|
|
|
|
.reshape(Array4(1, h, w, 3))
|
|
|
|
|
.broadcast(Array4(n, 1, 1, 1));
|
|
|
|
|
GetIdxMap<DeviceContext, T>(n, h, w, &grid, ctx);
|
|
|
|
|
// output = grid * theta.T
|
|
|
|
|
// TODO(wanghaoshuang): Refine batched matrix multiply
|
|
|
|
|
auto blas = math::GetBlas<DeviceContext, T>(ctx);
|
|
|
|
|