|
|
|
@ -11,6 +11,7 @@
|
|
|
|
|
|
|
|
|
|
#include "paddle/fluid/operators/temporal_shift_op.h"
|
|
|
|
|
#include "paddle/fluid/platform/cuda_primitives.h"
|
|
|
|
|
#include "paddle/fluid/platform/gpu_launch_config.h"
|
|
|
|
|
|
|
|
|
|
namespace paddle {
|
|
|
|
|
namespace operators {
|
|
|
|
@ -112,11 +113,11 @@ class TemporalShiftOpCUDAKernel : public framework::OpKernel<T> {
|
|
|
|
|
T* output_data = output->mutable_data<T>({nt, c, h, w}, ctx.GetPlace());
|
|
|
|
|
|
|
|
|
|
int pixelNum = nt * chw;
|
|
|
|
|
int grid_dim = (pixelNum + 512 - 1) / 512;
|
|
|
|
|
grid_dim = grid_dim > 8 ? 8 : grid_dim;
|
|
|
|
|
platform::GpuLaunchConfig config =
|
|
|
|
|
platform::GetGpuLaunchConfig1D(ctx.cuda_device_context(), pixelNum);
|
|
|
|
|
|
|
|
|
|
KeTemporalShiftFw<
|
|
|
|
|
T><<<grid_dim, 512, 0, ctx.cuda_device_context().stream()>>>(
|
|
|
|
|
KeTemporalShiftFw<T><<<config.block_per_grid, config.thread_per_block, 0,
|
|
|
|
|
ctx.cuda_device_context().stream()>>>(
|
|
|
|
|
input_data, output_data, ntchw, tchw, chw, hw, w, t, c, shift_ratio);
|
|
|
|
|
}
|
|
|
|
|
};
|
|
|
|
@ -148,11 +149,11 @@ class TemporalShiftGradOpCUDAKernel : public framework::OpKernel<T> {
|
|
|
|
|
static_cast<T>(0));
|
|
|
|
|
|
|
|
|
|
int pixelNum = nt * chw;
|
|
|
|
|
int grid_dim = (pixelNum + 512 - 1) / 512;
|
|
|
|
|
grid_dim = grid_dim > 8 ? 8 : grid_dim;
|
|
|
|
|
platform::GpuLaunchConfig config =
|
|
|
|
|
platform::GetGpuLaunchConfig1D(ctx.cuda_device_context(), pixelNum);
|
|
|
|
|
|
|
|
|
|
KeTemporalShiftBw<
|
|
|
|
|
T><<<grid_dim, 512, 0, ctx.cuda_device_context().stream()>>>(
|
|
|
|
|
KeTemporalShiftBw<T><<<config.block_per_grid, config.thread_per_block, 0,
|
|
|
|
|
ctx.cuda_device_context().stream()>>>(
|
|
|
|
|
output_grad_data, input_grad_data, ntchw, tchw, chw, hw, w, t, c,
|
|
|
|
|
shift_ratio);
|
|
|
|
|
}
|
|
|
|
|