|
|
|
@ -117,9 +117,6 @@ class BatchNormKernel<platform::GPUPlace, T> : public framework::OpKernel<T> {
|
|
|
|
|
math::SetConstant<platform::GPUPlace, T> functor;
|
|
|
|
|
functor(ctx.device_context(), saved_mean, 0);
|
|
|
|
|
functor(ctx.device_context(), saved_variance, 0);
|
|
|
|
|
// FIXME(qiao): this kernel should not zero mean_out and variance_out itself.
|
|
|
|
|
functor(ctx.device_context(), mean_out, 0);
|
|
|
|
|
functor(ctx.device_context(), variance_out, 0);
|
|
|
|
|
|
|
|
|
|
auto handle = ctx.cuda_device_context().cudnn_handle();
|
|
|
|
|
|
|
|
|
@ -211,8 +208,15 @@ class BatchNormGradKernel<platform::GPUPlace, T>
|
|
|
|
|
mode_ = CUDNN_BATCHNORM_SPATIAL;
|
|
|
|
|
#endif
|
|
|
|
|
|
|
|
|
|
std::vector<int> dims = {N, C, H, W, D};
|
|
|
|
|
std::vector<int> strides = {H * W * C * D, 1, W * D * C, D * C, C};
|
|
|
|
|
std::vector<int> dims;
|
|
|
|
|
std::vector<int> strides;
|
|
|
|
|
if (tensor_format == TensorFormat::NCHW) {
|
|
|
|
|
dims = {N, C, H, W, D};
|
|
|
|
|
strides = {C * H * W * D, H * W * D, W * D, D, 1};
|
|
|
|
|
} else {
|
|
|
|
|
dims = {N, C, H, W, D};
|
|
|
|
|
strides = {H * W * C * D, 1, W * D * C, D * C, C};
|
|
|
|
|
}
|
|
|
|
|
CUDNN_ENFORCE(platform::dynload::cudnnSetTensorNdDescriptor(
|
|
|
|
|
data_desc_, CudnnDataType<T>::type,
|
|
|
|
|
x_dims.size() > 3 ? x_dims.size() : 4, dims.data(), strides.data()));
|
|
|
|
|