|
|
|
@ -121,9 +121,11 @@ class AffineGridOpKernel : public framework::OpKernel<T> {
|
|
|
|
|
// TODO(wanghaoshuang): Refine batched matrix multiply
|
|
|
|
|
auto blas = math::GetBlas<DeviceContext, T>(ctx);
|
|
|
|
|
for (int i = 0; i < n; ++i) {
|
|
|
|
|
Tensor sliced_grid = grid.Slice(i, i + 1).Resize({h * w, 3});
|
|
|
|
|
Tensor sliced_grid = grid.Slice(i, i + 1).Resize(
|
|
|
|
|
{static_cast<int64_t>(h) * static_cast<int64_t>(w), 3});
|
|
|
|
|
Tensor sliced_theta = theta->Slice(i, i + 1).Resize({2, 3});
|
|
|
|
|
Tensor sliced_out = output->Slice(i, i + 1).Resize({h * w, 2});
|
|
|
|
|
Tensor sliced_out = output->Slice(i, i + 1).Resize(
|
|
|
|
|
{static_cast<int64_t>(h) * static_cast<int64_t>(w), 2});
|
|
|
|
|
blas.MatMul(sliced_grid, false, sliced_theta, true, T(1), &sliced_out,
|
|
|
|
|
T(0));
|
|
|
|
|
}
|
|
|
|
@ -161,8 +163,10 @@ class AffineGridGradOpKernel : public framework::OpKernel<T> {
|
|
|
|
|
// TODO(wanghaoshuang): Refine batched matrix multiply
|
|
|
|
|
auto blas = math::GetBlas<DeviceContext, T>(ctx);
|
|
|
|
|
for (int i = 0; i < n; ++i) {
|
|
|
|
|
Tensor sliced_grid = grid.Slice(i, i + 1).Resize({h * w, 3});
|
|
|
|
|
Tensor sliced_out_grad = output_grad->Slice(i, i + 1).Resize({h * w, 2});
|
|
|
|
|
Tensor sliced_grid = grid.Slice(i, i + 1).Resize(
|
|
|
|
|
{static_cast<int64_t>(h) * static_cast<int64_t>(w), 3});
|
|
|
|
|
Tensor sliced_out_grad = output_grad->Slice(i, i + 1).Resize(
|
|
|
|
|
{static_cast<int64_t>(h) * static_cast<int64_t>(w), 2});
|
|
|
|
|
Tensor sliced_theta_grad = theta_grad->Slice(i, i + 1).Resize({2, 3});
|
|
|
|
|
blas.MatMul(sliced_out_grad, true, sliced_grid, false, T(1),
|
|
|
|
|
&sliced_theta_grad, T(0));
|
|
|
|
|