|
|
|
@ -17,12 +17,106 @@ namespace paddle {
|
|
|
|
|
namespace operators {
|
|
|
|
|
|
|
|
|
|
using Tensor = framework::Tensor;
|
|
|
|
|
using DataLayout = framework::DataLayout;
|
|
|
|
|
|
|
|
|
|
static HOSTDEVICE inline int GetEntryIndex(int in, int it, int ic, int ih,
|
|
|
|
|
int iw, const int tchw,
|
|
|
|
|
const int chw, const int hw,
|
|
|
|
|
const int w) {
|
|
|
|
|
return in * tchw + it * chw + ic * hw + ih * w + iw;
|
|
|
|
|
template <typename T>
|
|
|
|
|
void TemporalShiftFwNCHW(const T* input, T* output, const int ntchw,
|
|
|
|
|
const int tchw, const int chw, const int hw,
|
|
|
|
|
const int t, const int c1, const int c2) {
|
|
|
|
|
int src_it = 0;
|
|
|
|
|
for (int i = 0; i < ntchw; i++) {
|
|
|
|
|
int it = (i % tchw) / chw;
|
|
|
|
|
int ic = (i % chw) / hw;
|
|
|
|
|
|
|
|
|
|
if (ic < c1) {
|
|
|
|
|
src_it = it - 1;
|
|
|
|
|
} else if (ic < c2) {
|
|
|
|
|
src_it = it + 1;
|
|
|
|
|
} else {
|
|
|
|
|
src_it = it;
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
if (src_it < 0 || src_it >= t) {
|
|
|
|
|
output[i] = 0;
|
|
|
|
|
} else {
|
|
|
|
|
output[i] = input[i + (src_it - it) * chw];
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
template <typename T>
|
|
|
|
|
void TemporalShiftFwNHWC(const T* input, T* output, const int nthwc,
|
|
|
|
|
const int thwc, const int hwc, const int t,
|
|
|
|
|
const int c, const int c1, const int c2) {
|
|
|
|
|
int src_it = 0;
|
|
|
|
|
for (int i = 0; i < nthwc; i++) {
|
|
|
|
|
int it = (i % thwc) / hwc;
|
|
|
|
|
int ic = i % c;
|
|
|
|
|
|
|
|
|
|
if (ic < c1) {
|
|
|
|
|
src_it = it - 1;
|
|
|
|
|
} else if (ic < c2) {
|
|
|
|
|
src_it = it + 1;
|
|
|
|
|
} else {
|
|
|
|
|
src_it = it;
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
if (src_it < 0 || src_it >= t) {
|
|
|
|
|
output[i] = 0;
|
|
|
|
|
} else {
|
|
|
|
|
output[i] = input[i + (src_it - it) * hwc];
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
template <typename T>
|
|
|
|
|
void TemporalShiftBwNCHW(const T* output_grad, T* input_grad, const int ntchw,
|
|
|
|
|
const int tchw, const int chw, const int hw,
|
|
|
|
|
const int t, const int c1, const int c2) {
|
|
|
|
|
int src_it = 0;
|
|
|
|
|
for (int i = 0; i < ntchw; i++) {
|
|
|
|
|
int it = (i % tchw) / chw;
|
|
|
|
|
int ic = (i % chw) / hw;
|
|
|
|
|
|
|
|
|
|
if (ic < c1) {
|
|
|
|
|
src_it = it + 1;
|
|
|
|
|
} else if (ic < c2) {
|
|
|
|
|
src_it = it - 1;
|
|
|
|
|
} else {
|
|
|
|
|
src_it = it;
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
if (src_it >= 0 && src_it < t) {
|
|
|
|
|
input_grad[i] = output_grad[i + (src_it - it) * chw];
|
|
|
|
|
} else {
|
|
|
|
|
input_grad[i] = 0;
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
template <typename T>
|
|
|
|
|
void TemporalShiftBwNHWC(const T* output_grad, T* input_grad, const int nthwc,
|
|
|
|
|
const int thwc, const int hwc, const int t,
|
|
|
|
|
const int c, const int c1, const int c2) {
|
|
|
|
|
int src_it = 0;
|
|
|
|
|
for (int i = 0; i < nthwc; i++) {
|
|
|
|
|
int it = (i % thwc) / hwc;
|
|
|
|
|
int ic = i % c;
|
|
|
|
|
|
|
|
|
|
if (ic < c1) {
|
|
|
|
|
src_it = it + 1;
|
|
|
|
|
} else if (ic < c2) {
|
|
|
|
|
src_it = it - 1;
|
|
|
|
|
} else {
|
|
|
|
|
src_it = it;
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
if (src_it >= 0 && src_it < t) {
|
|
|
|
|
input_grad[i] = output_grad[i + (src_it - it) * hwc];
|
|
|
|
|
} else {
|
|
|
|
|
input_grad[i] = 0;
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
template <typename T>
|
|
|
|
@ -33,44 +127,38 @@ class TemporalShiftKernel : public framework::OpKernel<T> {
|
|
|
|
|
auto* output = ctx.Output<Tensor>("Out");
|
|
|
|
|
int t = ctx.Attr<int>("seg_num");
|
|
|
|
|
float shift_ratio = ctx.Attr<float>("shift_ratio");
|
|
|
|
|
const std::string data_format_str = ctx.Attr<std::string>("data_format");
|
|
|
|
|
const DataLayout data_layout =
|
|
|
|
|
framework::StringToDataLayout(data_format_str);
|
|
|
|
|
|
|
|
|
|
const int nt = input->dims()[0];
|
|
|
|
|
const int c = input->dims()[1];
|
|
|
|
|
const int h = input->dims()[2];
|
|
|
|
|
const int w = input->dims()[3];
|
|
|
|
|
|
|
|
|
|
const int c1 = static_cast<int>(c * shift_ratio);
|
|
|
|
|
const int c2 = static_cast<int>(c * 2 * shift_ratio);
|
|
|
|
|
const int c = (data_layout == DataLayout::kNCHW ? input->dims()[1]
|
|
|
|
|
: input->dims()[3]);
|
|
|
|
|
const int h = (data_layout == DataLayout::kNCHW ? input->dims()[2]
|
|
|
|
|
: input->dims()[1]);
|
|
|
|
|
const int w = (data_layout == DataLayout::kNCHW ? input->dims()[3]
|
|
|
|
|
: input->dims()[2]);
|
|
|
|
|
|
|
|
|
|
const int hw = h * w;
|
|
|
|
|
const int chw = c * hw;
|
|
|
|
|
const int tchw = t * chw;
|
|
|
|
|
const int ntchw = nt * chw;
|
|
|
|
|
|
|
|
|
|
const int c1 = static_cast<int>(c * shift_ratio);
|
|
|
|
|
const int c2 = static_cast<int>(c * 2 * shift_ratio);
|
|
|
|
|
|
|
|
|
|
framework::DDim out_dims = (data_layout == DataLayout::kNCHW
|
|
|
|
|
? framework::make_ddim({nt, c, h, w})
|
|
|
|
|
: framework::make_ddim({nt, h, w, c}));
|
|
|
|
|
const T* input_data = input->data<T>();
|
|
|
|
|
T* output_data = output->mutable_data<T>({nt, c, h, w}, ctx.GetPlace());
|
|
|
|
|
|
|
|
|
|
int src_it = 0;
|
|
|
|
|
for (int i = 0; i < output->numel(); i++) {
|
|
|
|
|
int in = i / tchw;
|
|
|
|
|
int it = (i % tchw) / chw;
|
|
|
|
|
int ic = (i % chw) / hw;
|
|
|
|
|
int ih = (i % hw) / w;
|
|
|
|
|
int iw = i % w;
|
|
|
|
|
|
|
|
|
|
if (ic < c1) {
|
|
|
|
|
src_it = it - 1;
|
|
|
|
|
} else if (ic < c2) {
|
|
|
|
|
src_it = it + 1;
|
|
|
|
|
} else {
|
|
|
|
|
src_it = it;
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
if (src_it < 0 || src_it >= t) {
|
|
|
|
|
output_data[i] = 0;
|
|
|
|
|
} else {
|
|
|
|
|
int src_idx = GetEntryIndex(in, src_it, ic, ih, iw, tchw, chw, hw, w);
|
|
|
|
|
output_data[i] = input_data[src_idx];
|
|
|
|
|
}
|
|
|
|
|
T* output_data = output->mutable_data<T>(out_dims, ctx.GetPlace());
|
|
|
|
|
|
|
|
|
|
if (data_layout == DataLayout::kNCHW) {
|
|
|
|
|
TemporalShiftFwNCHW<T>(input_data, output_data, ntchw, tchw, chw, hw, t,
|
|
|
|
|
c1, c2);
|
|
|
|
|
} else {
|
|
|
|
|
TemporalShiftFwNHWC<T>(input_data, output_data, ntchw, tchw, chw, t, c,
|
|
|
|
|
c1, c2);
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
};
|
|
|
|
@ -83,44 +171,39 @@ class TemporalShiftGradKernel : public framework::OpKernel<T> {
|
|
|
|
|
auto* output_grad = ctx.Input<Tensor>(framework::GradVarName("Out"));
|
|
|
|
|
int t = ctx.Attr<int>("seg_num");
|
|
|
|
|
float shift_ratio = ctx.Attr<float>("shift_ratio");
|
|
|
|
|
const std::string data_format_str = ctx.Attr<std::string>("data_format");
|
|
|
|
|
const DataLayout data_layout =
|
|
|
|
|
framework::StringToDataLayout(data_format_str);
|
|
|
|
|
|
|
|
|
|
const int nt = output_grad->dims()[0];
|
|
|
|
|
const int c = output_grad->dims()[1];
|
|
|
|
|
const int h = output_grad->dims()[2];
|
|
|
|
|
const int w = output_grad->dims()[3];
|
|
|
|
|
|
|
|
|
|
const int c1 = static_cast<int>(c * shift_ratio);
|
|
|
|
|
const int c2 = static_cast<int>(c * 2 * shift_ratio);
|
|
|
|
|
const int c = (data_layout == DataLayout::kNCHW ? output_grad->dims()[1]
|
|
|
|
|
: output_grad->dims()[3]);
|
|
|
|
|
const int h = (data_layout == DataLayout::kNCHW ? output_grad->dims()[2]
|
|
|
|
|
: output_grad->dims()[1]);
|
|
|
|
|
const int w = (data_layout == DataLayout::kNCHW ? output_grad->dims()[3]
|
|
|
|
|
: output_grad->dims()[2]);
|
|
|
|
|
|
|
|
|
|
const int hw = h * w;
|
|
|
|
|
const int chw = c * hw;
|
|
|
|
|
const int tchw = t * chw;
|
|
|
|
|
const int ntchw = nt * chw;
|
|
|
|
|
|
|
|
|
|
const int c1 = static_cast<int>(c * shift_ratio);
|
|
|
|
|
const int c2 = static_cast<int>(c * 2 * shift_ratio);
|
|
|
|
|
|
|
|
|
|
framework::DDim in_grad_dims = (data_layout == DataLayout::kNCHW
|
|
|
|
|
? framework::make_ddim({nt, c, h, w})
|
|
|
|
|
: framework::make_ddim({nt, h, w, c}));
|
|
|
|
|
const T* output_grad_data = output_grad->data<T>();
|
|
|
|
|
T* input_grad_data =
|
|
|
|
|
input_grad->mutable_data<T>({nt, c, h, w}, ctx.GetPlace());
|
|
|
|
|
memset(input_grad_data, 0, input_grad->numel() * sizeof(T));
|
|
|
|
|
|
|
|
|
|
int src_it = 0;
|
|
|
|
|
for (int i = 0; i < output_grad->numel(); i++) {
|
|
|
|
|
int in = i / tchw;
|
|
|
|
|
int it = (i % tchw) / chw;
|
|
|
|
|
int ic = (i % chw) / hw;
|
|
|
|
|
int ih = (i % hw) / w;
|
|
|
|
|
int iw = i % w;
|
|
|
|
|
|
|
|
|
|
if (ic < c1) {
|
|
|
|
|
src_it = it - 1;
|
|
|
|
|
} else if (ic < c2) {
|
|
|
|
|
src_it = it + 1;
|
|
|
|
|
} else {
|
|
|
|
|
src_it = it;
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
if (src_it >= 0 && src_it < t) {
|
|
|
|
|
int src_idx = GetEntryIndex(in, src_it, ic, ih, iw, tchw, chw, hw, w);
|
|
|
|
|
input_grad_data[src_idx] = output_grad_data[i];
|
|
|
|
|
}
|
|
|
|
|
input_grad->mutable_data<T>(in_grad_dims, ctx.GetPlace());
|
|
|
|
|
|
|
|
|
|
if (data_layout == DataLayout::kNCHW) {
|
|
|
|
|
TemporalShiftBwNCHW<T>(output_grad_data, input_grad_data, ntchw, tchw,
|
|
|
|
|
chw, hw, t, c1, c2);
|
|
|
|
|
} else {
|
|
|
|
|
TemporalShiftBwNHWC<T>(output_grad_data, input_grad_data, ntchw, tchw,
|
|
|
|
|
chw, t, c, c1, c2);
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
};
|
|
|
|
|