|
|
|
@ -20,7 +20,8 @@ using framework::Tensor;
|
|
|
|
|
|
|
|
|
|
template <typename T>
|
|
|
|
|
__global__ void KeTemporalShiftFw(const T* input, T* output, const int ntchw,
|
|
|
|
|
const int tchw, const int chw, const int hw, const int w, const int t, const int c) {
|
|
|
|
|
const int tchw, const int chw, const int hw, const int w, const int t, const int c,
|
|
|
|
|
const float shift_ratio) {
|
|
|
|
|
int tid = blockIdx.x * blockDim.x + threadIdx.x;
|
|
|
|
|
int stride = blockDim.x * gridDim.x;
|
|
|
|
|
int src_it = 0;
|
|
|
|
@ -31,9 +32,12 @@ __global__ void KeTemporalShiftFw(const T* input, T* output, const int ntchw,
|
|
|
|
|
int ih = (tid % hw) / w;
|
|
|
|
|
int iw = tid % w;
|
|
|
|
|
|
|
|
|
|
if (ic < c / 4) {
|
|
|
|
|
const int c1 = static_cast<T>(c * shift_ratio);
|
|
|
|
|
const int c2 = static_cast<T>(c * 2 * shift_ratio);
|
|
|
|
|
|
|
|
|
|
if (ic < c1) {
|
|
|
|
|
src_it = it - 1;
|
|
|
|
|
} else if (ic < c / 2) {
|
|
|
|
|
} else if (ic < c2) {
|
|
|
|
|
src_it = it + 1;
|
|
|
|
|
} else {
|
|
|
|
|
src_it = it;
|
|
|
|
@ -50,7 +54,8 @@ __global__ void KeTemporalShiftFw(const T* input, T* output, const int ntchw,
|
|
|
|
|
|
|
|
|
|
template <typename T>
|
|
|
|
|
__global__ void KeTemporalShiftBw(const T* output_grad, T* input_grad, const int ntchw,
|
|
|
|
|
const int tchw, const int chw, const int hw, const int w, const int t, const int c) {
|
|
|
|
|
const int tchw, const int chw, const int hw, const int w, const int t, const int c,
|
|
|
|
|
const float shift_ratio) {
|
|
|
|
|
int tid = blockIdx.x * blockDim.x + threadIdx.x;
|
|
|
|
|
int stride = blockDim.x * gridDim.x;
|
|
|
|
|
int src_it = 0;
|
|
|
|
@ -61,9 +66,12 @@ __global__ void KeTemporalShiftBw(const T* output_grad, T* input_grad, const int
|
|
|
|
|
int ih = (tid % hw) / w;
|
|
|
|
|
int iw = tid % w;
|
|
|
|
|
|
|
|
|
|
if (ic < c / 4) {
|
|
|
|
|
const int c1 = static_cast<T>(c * shift_ratio);
|
|
|
|
|
const int c2 = static_cast<T>(c * 2 * shift_ratio);
|
|
|
|
|
|
|
|
|
|
if (ic < c1) {
|
|
|
|
|
src_it = it - 1;
|
|
|
|
|
} else if (ic < c / 2) {
|
|
|
|
|
} else if (ic < c2) {
|
|
|
|
|
src_it = it + 1;
|
|
|
|
|
} else {
|
|
|
|
|
src_it = it;
|
|
|
|
@ -85,6 +93,7 @@ class TemporalShiftOpCUDAKernel : public framework::OpKernel<T> {
|
|
|
|
|
auto* input = ctx.Input<Tensor>("X");
|
|
|
|
|
auto* output = ctx.Output<Tensor>("Out");
|
|
|
|
|
int t = ctx.Attr<int>("seg_num");
|
|
|
|
|
float shift_ratio = ctx.Attr<float>("shift_ratio");
|
|
|
|
|
|
|
|
|
|
const int nt = input->dims()[0];
|
|
|
|
|
const int c = input->dims()[1];
|
|
|
|
@ -105,7 +114,7 @@ class TemporalShiftOpCUDAKernel : public framework::OpKernel<T> {
|
|
|
|
|
|
|
|
|
|
KeTemporalShiftFw<
|
|
|
|
|
T><<<grid_dim, 512, 0, ctx.cuda_device_context().stream()>>>(
|
|
|
|
|
input_data, output_data, ntchw, tchw, chw, hw, w, t, c);
|
|
|
|
|
input_data, output_data, ntchw, tchw, chw, hw, w, t, c, shift_ratio);
|
|
|
|
|
}
|
|
|
|
|
};
|
|
|
|
|
|
|
|
|
@ -116,6 +125,7 @@ class TemporalShiftGradOpCUDAKernel : public framework::OpKernel<T> {
|
|
|
|
|
auto* input_grad = ctx.Output<Tensor>(framework::GradVarName("X"));
|
|
|
|
|
auto* output_grad = ctx.Input<Tensor>(framework::GradVarName("Out"));
|
|
|
|
|
int t = ctx.Attr<int>("seg_num");
|
|
|
|
|
float shift_ratio = ctx.Attr<float>("shift_ratio");
|
|
|
|
|
|
|
|
|
|
const int nt = output_grad->dims()[0];
|
|
|
|
|
const int c = output_grad->dims()[1];
|
|
|
|
@ -139,7 +149,7 @@ class TemporalShiftGradOpCUDAKernel : public framework::OpKernel<T> {
|
|
|
|
|
|
|
|
|
|
KeTemporalShiftBw<
|
|
|
|
|
T><<<grid_dim, 512, 0, ctx.cuda_device_context().stream()>>>(
|
|
|
|
|
output_grad_data, input_grad_data, ntchw, tchw, chw, hw, w, t, c);
|
|
|
|
|
output_grad_data, input_grad_data, ntchw, tchw, chw, hw, w, t, c, shift_ratio);
|
|
|
|
|
}
|
|
|
|
|
};
|
|
|
|
|
|
|
|
|
|