|
|
|
@ -129,6 +129,9 @@ class TemporalShiftGradOpCUDAKernel : public framework::OpKernel<T> {
|
|
|
|
|
|
|
|
|
|
const T* output_grad_data = output_grad->data<T>();
|
|
|
|
|
T* input_grad_data = input_grad->mutable_data<T>({nt, c, h, w}, ctx.GetPlace());
|
|
|
|
|
math::SetConstant<platform::CUDADeviceContext, T>()(
|
|
|
|
|
ctx.template device_context<platform::CUDADeviceContext>(), input_grad,
|
|
|
|
|
static_cast<T>(0));
|
|
|
|
|
|
|
|
|
|
int pixelNum = nt * chw;
|
|
|
|
|
int grid_dim = (pixelNum + 512 - 1) / 512;
|
|
|
|
|