Fix mistake of batch norm op (#21237)

* fix_bn

* revert unittest,test=develop
revert-21172-masked_select_api
Lv Mengsi 6 years ago committed by GitHub
parent 41d13209d7
commit b6ce4f8b2f
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

@ -418,7 +418,7 @@ class BatchNormGradKernel<platform::CUDADeviceContext, T>
}
} else {
if (d_x) {
BNBackwardData<T, block, framework::DataLayout::kNCHW><<<
BNBackwardData<T, block, framework::DataLayout::kNHWC><<<
grid2, block, 0, dev_ctx.stream()>>>(
d_y->data<T>(), scale->data<BatchNormParamType<T>>(),
saved_mean_data, x->data<T>(), saved_var_data, C, N, H * W * D,

Loading…
Cancel
Save