diff --git a/mindspore/ccsrc/backend/kernel_compiler/gpu/cuda_impl/batchnorm_grad_impl.cu b/mindspore/ccsrc/backend/kernel_compiler/gpu/cuda_impl/batchnorm_grad_impl.cu
new file mode 100644
index 0000000000..dba71d8f69
--- /dev/null
+++ b/mindspore/ccsrc/backend/kernel_compiler/gpu/cuda_impl/batchnorm_grad_impl.cu
@@ -0,0 +1,120 @@
+/**
+ * Copyright 2020 Huawei Technologies Co., Ltd
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include <stdio.h>
+#include <stdint.h>
+#include <math.h>
+#include <cuda_runtime.h>
+#include <cuda_fp16.h>
+#include "batchnorm_grad_impl.cuh"
+#include "include/cuda_runtime.h"
+
+const int kWarpSize = 32;
+const int kBlockSize = 1024;
+const int kNumWarps = 32;
+
+template <typename T>
+__global__ void BatchNormGradKernel(T *x_input, T *dy, float *scale, float *save_mean, float *save_variance, T *dx,
+                                    float *bn_scale, float *bn_bias, double epsilon, int N, int C, int H, int W) {
+  __shared__ T shared_dy[kNumWarps];
+  __shared__ T shared_p[kNumWarps];
+  int warpId = threadIdx.x / kWarpSize;
+  int laneId = threadIdx.x % kWarpSize;
+
+  int plane = blockIdx.x;
+  int plane_size = N * H * W;
+
+  T invstd = static_cast<T>(1) / static_cast<T>(sqrt(save_variance[plane] + epsilon));
+  T scale_val = scale != nullptr ? static_cast<T>(scale[plane]) : static_cast<T>(1);
+  T grad_scale = invstd * scale_val;
+
+  T mean = static_cast<T>(save_mean[plane]);
+  T dy_sum = static_cast<T>(0);
+  T dot_p = static_cast<T>(0);
+
+  if (threadIdx.x < kNumWarps) {
+    shared_dy[threadIdx.x] = static_cast<T>(0);
+    shared_p[threadIdx.x] = static_cast<T>(0);
+  }
+  __syncthreads();
+
+  // Compute three values across (Batch, Height, Width) in one pass:
+  // 1. dx
+  // 2. Sum(dy)
+  // 3. DotProduct(x - mean, dy)
+  for (int x = threadIdx.x; x < plane_size; x += blockDim.x) {
+    int index = (x / (H * W) * C * H * W) + (plane * H * W) + (x % (H * W));
+    dx[index] = static_cast<T>(dy[index] * grad_scale);
+    dy_sum += dy[index];
+    dot_p += (x_input[index] - mean) * dy[index];
+  }
+  __syncthreads();
+
+  // Warp reduction
+  for (int offset = kWarpSize / 2; offset > 0; offset /= 2) {
+    T other_dy = __shfl_down_sync(0xffffffff, dy_sum, offset);
+    T other_p = __shfl_down_sync(0xffffffff, dot_p, offset);
+    dy_sum += other_dy;
+    dot_p += other_p;
+  }
+  __syncwarp();
+
+  // Move warp-reduction result to shared memory
+  if (laneId == 0) {
+    shared_dy[warpId] = dy_sum;
+    shared_p[warpId] = dot_p;
+  }
+  __syncthreads();
+
+  // Shared memory reduction
+  // There are exactly 32 items in shared memory, so they can be reduced within a single warp.
+  if (warpId == 0) {
+    dy_sum = shared_dy[laneId];
+    dot_p = shared_p[laneId];
+    __syncwarp();
+    for (int offset = kWarpSize / 2; offset > 0; offset /= 2) {
+      T other_dy = __shfl_down_sync(0xffffffff, dy_sum, offset);
+      T other_p = __shfl_down_sync(0xffffffff, dot_p, offset);
+      dy_sum += other_dy;
+      dot_p += other_p;
+    }
+    __syncwarp();
+  }
+
+  // Compute bn_scale & bn_bias
+  if (threadIdx.x == 0) {
+    bn_scale[plane] = static_cast<float>(dot_p * invstd);
+  }
+
+  if (threadIdx.x == 0) {
+    bn_bias[plane] = static_cast<float>(dy_sum);
+  }
+}
+
+template <typename T>
+void CalBatchNormGrad(T *x, T *dy, float *scale, float *save_mean, float *save_variance, T *dx, float *bn_scale,
+                      float *bn_bias, double epsilon, int N, int C, int H, int W, cudaStream_t cuda_stream) {
+  BatchNormGradKernel<<<C, kBlockSize, 0, cuda_stream>>>(x, dy, scale, save_mean, save_variance, dx, bn_scale,
+                                                         bn_bias, epsilon, N, C, H, W);
+}
+
+template void CalBatchNormGrad<float>(float *x, float *dy, float *scale, float *save_mean, float *save_variance,
+                                      float *dx, float *bn_scale, float *bn_bias, double epsilon, int N, int C, int H,
+                                      int W, cudaStream_t cuda_stream);
+
+template void CalBatchNormGrad<half>(half *x, half *dy, float *scale, float *save_mean, float *save_variance, half *dx,
+                                     float *bn_scale, float *bn_bias, double epsilon, int N, int C, int H, int W,
+                                     cudaStream_t cuda_stream);
diff --git a/mindspore/ccsrc/backend/kernel_compiler/gpu/cuda_impl/batchnorm_grad_impl.cuh b/mindspore/ccsrc/backend/kernel_compiler/gpu/cuda_impl/batchnorm_grad_impl.cuh
new file mode 100644
index 0000000000..351d70a31e
--- /dev/null
+++ b/mindspore/ccsrc/backend/kernel_compiler/gpu/cuda_impl/batchnorm_grad_impl.cuh
@@ -0,0 +1,24 @@
+/**
+ * Copyright 2020 Huawei Technologies Co., Ltd
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#ifndef MINDSPORE_CCSRC_KERNEL_GPU_CUDA_IMPL_BATCHNORMGRAD_H_
+#define MINDSPORE_CCSRC_KERNEL_GPU_CUDA_IMPL_BATCHNORMGRAD_H_
+
+#include "runtime/device/gpu/cuda_common.h"
+template <typename T>
+void CalBatchNormGrad(T *x, T *dy, float *scale, float *save_mean, float *save_variance, T *dx, float *bn_scale,
+                      float *bn_bias, double epsilon, int N, int C, int H, int W, cudaStream_t cuda_stream);
+#endif  // MINDSPORE_CCSRC_KERNEL_GPU_CUDA_IMPL_BATCHNORMGRAD_H_
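For reference: in inference mode the saved mean and variance are constants with respect to the input, so the backward pass needs no correction terms from the batch statistics; the kernel above only rescales dy and accumulates two per-channel reductions. A minimal NumPy sketch of the same math for NCHW inputs (the function name and the default epsilon value are mine, not part of the patch):

    import numpy as np

    def batchnorm_grad_infer_ref(x, dy, scale, mean, variance, epsilon=1e-5):
        """NumPy mirror of BatchNormGradKernel (inference mode, NCHW layout)."""
        c = (1, -1, 1, 1)                           # broadcast per-channel (C,) vectors over NCHW
        invstd = 1.0 / np.sqrt(variance + epsilon)  # the kernel's invstd
        dx = dy * (scale * invstd).reshape(c)       # dx = dy * gamma / sqrt(var + eps)
        bn_bias = dy.sum(axis=(0, 2, 3))            # Sum(dy) -> dbeta
        bn_scale = ((x - mean.reshape(c)) * dy).sum(axis=(0, 2, 3)) * invstd  # DotProduct(x - mean, dy) * invstd -> dgamma
        return dx, bn_scale, bn_bias

Each CUDA block handles one channel (plane = blockIdx.x), so the two sums here correspond to the warp-shuffle and shared-memory reduction stages in the kernel.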
diff --git a/mindspore/ccsrc/backend/kernel_compiler/gpu/nn/batchnorm_grad_gpu_kernel.h b/mindspore/ccsrc/backend/kernel_compiler/gpu/nn/batchnorm_grad_gpu_kernel.h
index 8ba2678ff3..2c42781e13 100644
--- a/mindspore/ccsrc/backend/kernel_compiler/gpu/nn/batchnorm_grad_gpu_kernel.h
+++ b/mindspore/ccsrc/backend/kernel_compiler/gpu/nn/batchnorm_grad_gpu_kernel.h
@@ -21,6 +21,7 @@
 #include "backend/kernel_compiler/gpu/gpu_kernel.h"
 #include "backend/kernel_compiler/gpu/gpu_kernel_factory.h"
 #include "backend/kernel_compiler/gpu/kernel_constants.h"
+#include "backend/kernel_compiler/gpu/cuda_impl/batchnorm_grad_impl.cuh"
 
 namespace mindspore {
 namespace kernel {
@@ -66,16 +67,21 @@ class BatchNormGradGpuKernel : public GpuKernel {
     // For CI only, reserved vars can not be unused.
     MS_LOG(DEBUG) << reinterpret_cast<size_t>(reserve_1) << reinterpret_cast<size_t>(reserve_2);  // NOLINT
-    const float alpha_data_diff = 1;
-    const float beta_data_diff = 0;
-    const float alpha_param_diff = 1;
-    const float beta_param_diff = 0;
-    CHECK_CUDNN_RET_WITH_EXCEPT(
-      kernel_node_,
-      cudnnBatchNormalizationBackward(handle_, mode_, &alpha_data_diff, &beta_data_diff, &alpha_param_diff,
-                                      &beta_param_diff, x_desc_, x, dy_desc_, dy, dx_desc_, dx, scale_bias_desc_, scale,
-                                      bn_scale, bn_bias, epsilon_, save_mean, save_variance),
-      "Kernel Launch Failed.");
+    if (is_training_) {
+      const float alpha_data_diff = 1;
+      const float beta_data_diff = 0;
+      const float alpha_param_diff = 1;
+      const float beta_param_diff = 0;
+      CHECK_CUDNN_RET_WITH_EXCEPT(
+        kernel_node_,
+        cudnnBatchNormalizationBackward(handle_, mode_, &alpha_data_diff, &beta_data_diff, &alpha_param_diff,
+                                        &beta_param_diff, x_desc_, x, dy_desc_, dy, dx_desc_, dx, scale_bias_desc_,
+                                        scale, bn_scale, bn_bias, epsilon_, save_mean, save_variance),
+        "Kernel Launch Failed.");
+    } else {
+      CalBatchNormGrad(x, dy, scale, save_mean, save_variance, dx, bn_scale, bn_bias, epsilon_, batch_, channel_,
+                       height_, width_, reinterpret_cast<cudaStream_t>(stream_ptr));
+    }
     return true;
   }
   bool Init(const CNodePtr &kernel_node) override {
@@ -104,6 +110,7 @@
     width_ = SizeToInt(shape[3]);
 
     mode_ = CUDNN_BATCHNORM_SPATIAL;
+    is_training_ = GetAttr<bool>(kernel_node, "is_training");
     epsilon_ = GetAttr<float>(kernel_node, "epsilon");
 
     CHECK_CUDNN_RET_WITH_EXCEPT(
@@ -175,6 +182,7 @@
   int width_;
 
   cudnnBatchNormMode_t mode_;
+  bool is_training_;
   double epsilon_;
   bool is_null_input_;
   cudnnTensorDescriptor_t x_desc_;
diff --git a/tests/st/ops/gpu/test_batchnorm_op.py b/tests/st/ops/gpu/test_batchnorm_op.py
index 0e4e803ffa..d8457c01f2 100644
--- a/tests/st/ops/gpu/test_batchnorm_op.py
+++ b/tests/st/ops/gpu/test_batchnorm_op.py
@@ -178,3 +178,26 @@ def test_train_stats_false_forward():
     diff = output.asnumpy() - expect_output
     assert np.all(diff < error)
     assert np.all(-diff < error)
+
+
+@pytest.mark.level0
+@pytest.mark.platform_x86_gpu_training
+@pytest.mark.env_onecard
+def test_infer_backward():
+    expect_output = np.array([[[[-0.3224156, -0.3840524], [1.1337637, -1.0998858]],
+                               [[-0.1724273, -0.877854], [0.0422135, 0.5828123]],
+                               [[-1.1006137, 1.1447179], [0.9015862, 0.5024918]]]]).astype(np.float32)
+    np.random.seed(1)
+    x_np = np.random.randn(1, 3, 2, 2).astype(np.float32)
+    input_grad_np = np.random.randn(1, 3, 2, 2).astype(np.float32)
+    ms_input = Tensor(x_np)
+    weight = Tensor(np.ones(3).astype(np.float32))
+    bias = Tensor(np.zeros(3).astype(np.float32))
+    moving_mean = Tensor(np.zeros(3).astype(np.float32))
+    moving_var_init = Tensor(np.ones(3).astype(np.float32))
+    context.set_context(mode=context.GRAPH_MODE, device_target="GPU")
+    ms_net = Batchnorm_Net(3, weight, bias, moving_mean, moving_var_init)
+    ms_net.set_train(False)
+    ms_grad = Grad(ms_net)
+    ms_out_grad_np = ms_grad(ms_input, Tensor(input_grad_np))
+    assert np.allclose(ms_out_grad_np[0].asnumpy(), expect_output)
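A note on the hard-coded expect_output: with moving_mean = 0, moving_var = 1, and a weight of 1, the inference-mode dx reduces to dy / sqrt(1 + eps), i.e. the upstream gradient scaled by roughly 0.999995. Assuming Batchnorm_Net wraps nn.BatchNorm2d with its default eps of 1e-5 (the class is defined earlier in this test file and not shown in the hunk), the expected values can be regenerated without running the GPU kernel at all:

    import numpy as np

    np.random.seed(1)
    _x_np = np.random.randn(1, 3, 2, 2).astype(np.float32)          # x_np, drawn first to advance the RNG
    input_grad_np = np.random.randn(1, 3, 2, 2).astype(np.float32)  # the upstream gradient passed to Grad

    # dx = dy * gamma / sqrt(moving_var + eps) with gamma = 1, moving_var = 1
    regenerated = (input_grad_np / np.sqrt(1.0 + 1e-5)).astype(np.float32)
    print(regenerated)  # matches expect_output above to float32 precision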