From 25af911ed904301847d33ddca8905b65599f69da Mon Sep 17 00:00:00 2001 From: VectorSL Date: Tue, 28 Apr 2020 17:03:12 +0800 Subject: [PATCH] gpu update bn --- mindspore/ccsrc/kernel/gpu/nn/fused_batch_norm_gpu_kernel.h | 5 +++-- .../ccsrc/kernel/gpu/nn/fused_batchnorm_grad_gpu_kernel.h | 2 +- 2 files changed, 4 insertions(+), 3 deletions(-) diff --git a/mindspore/ccsrc/kernel/gpu/nn/fused_batch_norm_gpu_kernel.h b/mindspore/ccsrc/kernel/gpu/nn/fused_batch_norm_gpu_kernel.h index 6f0c59e29a..26f4332273 100644 --- a/mindspore/ccsrc/kernel/gpu/nn/fused_batch_norm_gpu_kernel.h +++ b/mindspore/ccsrc/kernel/gpu/nn/fused_batch_norm_gpu_kernel.h @@ -82,6 +82,7 @@ class FusedBatchNormGpuKernel : public GpuKernel { } bool Init(const CNodePtr &kernel_node) override { InitResource(); + cudnn_data_type_ = kCudnnDtypeMap[TypeIdLabel(AnfAlgo::GetInputDeviceDataType(kernel_node, 0))]; size_t input_num = AnfAlgo::GetInputTensorNum(kernel_node); if (input_num != 5) { MS_LOG(EXCEPTION) << "input tensor size is " << input_num << ", FusedBatchNormGpuKernel should be 5"; @@ -112,11 +113,11 @@ class FusedBatchNormGpuKernel : public GpuKernel { } CHECK_CUDNN_RET_WITH_EXCEPT( - cudnnSetTensor4dDescriptor(x_desc_, CUDNN_TENSOR_NCHW, CUDNN_DATA_FLOAT, batch_, channel_, height_, width_), + cudnnSetTensor4dDescriptor(x_desc_, CUDNN_TENSOR_NCHW, cudnn_data_type_, batch_, channel_, height_, width_), "Set x desc failed"); CHECK_CUDNN_RET_WITH_EXCEPT( - cudnnSetTensor4dDescriptor(y_desc_, CUDNN_TENSOR_NCHW, CUDNN_DATA_FLOAT, batch_, channel_, height_, width_), + cudnnSetTensor4dDescriptor(y_desc_, CUDNN_TENSOR_NCHW, cudnn_data_type_, batch_, channel_, height_, width_), "Set y desc failed"); CHECK_CUDNN_RET_WITH_EXCEPT( diff --git a/mindspore/ccsrc/kernel/gpu/nn/fused_batchnorm_grad_gpu_kernel.h b/mindspore/ccsrc/kernel/gpu/nn/fused_batchnorm_grad_gpu_kernel.h index 08eac28af7..07372ad22d 100644 --- a/mindspore/ccsrc/kernel/gpu/nn/fused_batchnorm_grad_gpu_kernel.h +++ b/mindspore/ccsrc/kernel/gpu/nn/fused_batchnorm_grad_gpu_kernel.h @@ -110,7 +110,7 @@ class FusedBatchNormGradGpuKernel : public GpuKernel { cudnnSetTensor4dDescriptor(dx_desc_, CUDNN_TENSOR_NCHW, cudnn_data_type_, batch_, channel_, height_, width_), "Set dx desc failed"); CHECK_CUDNN_RET_WITH_EXCEPT( - cudnnSetTensor4dDescriptor(scale_bias_desc_, CUDNN_TENSOR_NCHW, cudnn_data_type_, 1, channel_, 1, 1), + cudnnSetTensor4dDescriptor(scale_bias_desc_, CUDNN_TENSOR_NCHW, CUDNN_DATA_FLOAT, 1, channel_, 1, 1), "Set para desc failed"); InitSizeLists();