fix input_grad not set zero. test=develop

revert-16555-model_data_cryption_link_all_lib
dengkaipeng 6 years ago
parent c9e0ade530
commit 71101c9cf7

@ -129,6 +129,9 @@ class TemporalShiftGradOpCUDAKernel : public framework::OpKernel<T> {
const T* output_grad_data = output_grad->data<T>();
T* input_grad_data = input_grad->mutable_data<T>({nt, c, h, w}, ctx.GetPlace());
math::SetConstant<platform::CUDADeviceContext, T>()(
ctx.template device_context<platform::CUDADeviceContext>(), input_grad,
static_cast<T>(0));
int pixelNum = nt * chw;
int grid_dim = (pixelNum + 512 - 1) / 512;

@ -88,6 +88,7 @@ class TemporalShiftGradKernel : public framework::OpKernel<T> {
const T* output_grad_data = output_grad->data<T>();
T* input_grad_data = input_grad->mutable_data<T>({nt, c, h, w}, ctx.GetPlace());
memset(input_grad_data, 0, input_grad->numel() * sizeof(T));
int src_it = 0;
for (int i = 0; i < output_grad->numel(); i++) {

Loading…
Cancel
Save