|
|
|
@ -33,8 +33,8 @@ __global__ void KeTemporalShiftFw(const T* input, T* output, const int ntchw,
|
|
|
|
|
int ih = (tid % hw) / w;
|
|
|
|
|
int iw = tid % w;
|
|
|
|
|
|
|
|
|
|
const int c1 = static_cast<T>(c * shift_ratio);
|
|
|
|
|
const int c2 = static_cast<T>(c * 2 * shift_ratio);
|
|
|
|
|
const int c1 = static_cast<int>(c * shift_ratio);
|
|
|
|
|
const int c2 = static_cast<int>(c * 2 * shift_ratio);
|
|
|
|
|
|
|
|
|
|
if (ic < c1) {
|
|
|
|
|
src_it = it - 1;
|
|
|
|
@ -69,8 +69,8 @@ __global__ void KeTemporalShiftBw(const T* output_grad, T* input_grad,
|
|
|
|
|
int ih = (tid % hw) / w;
|
|
|
|
|
int iw = tid % w;
|
|
|
|
|
|
|
|
|
|
const int c1 = static_cast<T>(c * shift_ratio);
|
|
|
|
|
const int c2 = static_cast<T>(c * 2 * shift_ratio);
|
|
|
|
|
const int c1 = static_cast<int>(c * shift_ratio);
|
|
|
|
|
const int c2 = static_cast<int>(c * 2 * shift_ratio);
|
|
|
|
|
|
|
|
|
|
if (ic < c1) {
|
|
|
|
|
src_it = it - 1;
|
|
|
|
@ -163,8 +163,11 @@ class TemporalShiftGradOpCUDAKernel : public framework::OpKernel<T> {
|
|
|
|
|
} // namespace paddle
|
|
|
|
|
|
|
|
|
|
namespace ops = paddle::operators;
|
|
|
|
|
REGISTER_OP_CUDA_KERNEL(temporal_shift, ops::TemporalShiftOpCUDAKernel<float>,
|
|
|
|
|
ops::TemporalShiftOpCUDAKernel<double>);
|
|
|
|
|
REGISTER_OP_CUDA_KERNEL(temporal_shift_grad,
|
|
|
|
|
ops::TemporalShiftGradOpCUDAKernel<float>,
|
|
|
|
|
ops::TemporalShiftGradOpCUDAKernel<double>);
|
|
|
|
|
REGISTER_OP_CUDA_KERNEL(
|
|
|
|
|
temporal_shift, ops::TemporalShiftOpCUDAKernel<float>,
|
|
|
|
|
ops::TemporalShiftOpCUDAKernel<double>,
|
|
|
|
|
ops::TemporalShiftOpCUDAKernel<paddle::platform::float16>);
|
|
|
|
|
REGISTER_OP_CUDA_KERNEL(
|
|
|
|
|
temporal_shift_grad, ops::TemporalShiftGradOpCUDAKernel<float>,
|
|
|
|
|
ops::TemporalShiftGradOpCUDAKernel<double>,
|
|
|
|
|
ops::TemporalShiftGradOpCUDAKernel<paddle::platform::float16>);
|
|
|
|
|