From 73192bb12ac78a546ae04aab26db9107719c535a Mon Sep 17 00:00:00 2001 From: dangqingqing Date: Mon, 7 Aug 2017 19:09:34 +0800 Subject: [PATCH 1/4] add a batch norm inference kernel. --- paddle/cuda/CMakeLists.txt | 1 + paddle/cuda/include/hl_batch_norm.h | 50 +++++++++++++ paddle/cuda/src/hl_batch_norm.cu | 68 ++++++++++++++++++ paddle/gserver/layers/CudnnBatchNormLayer.cpp | 37 +++++++--- paddle/gserver/tests/test_BatchNorm.cpp | 70 +++++++++++++++++++ 5 files changed, 216 insertions(+), 10 deletions(-) create mode 100644 paddle/cuda/include/hl_batch_norm.h create mode 100644 paddle/cuda/src/hl_batch_norm.cu diff --git a/paddle/cuda/CMakeLists.txt b/paddle/cuda/CMakeLists.txt index 73ffa690d9..0865b02c4f 100755 --- a/paddle/cuda/CMakeLists.txt +++ b/paddle/cuda/CMakeLists.txt @@ -39,6 +39,7 @@ set(CUDA_CU_SOURCES src/hl_cuda_lstm.cu src/hl_top_k.cu src/hl_batch_transpose.cu + src/hl_batch_norm.cu src/hl_cuda_sequence.cu src/hl_table_apply.cu) diff --git a/paddle/cuda/include/hl_batch_norm.h b/paddle/cuda/include/hl_batch_norm.h new file mode 100644 index 0000000000..e1fea13163 --- /dev/null +++ b/paddle/cuda/include/hl_batch_norm.h @@ -0,0 +1,50 @@ +/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. */ + +#ifndef HL_BATCH_NORM_H_ +#define HL_BATCH_NORM_H_ + +#include "hl_base.h" + +/** + * @brief batch norm inferece. + * + * @param[in] input input data. + * @param[out] output output data. + * @param[in] scale batch normalization scale parameter (in original + * paper scale is referred to as gamma). + * @param[in] bias batch normalization bias parameter (in original + * paper scale is referred to as beta). + * @param[in] estimatedMean + * @param[in] estimatedVar It is suggested that resultRunningMean, + * resultRunningVariance from the + * cudnnBatchNormalizationForwardTraining call + * accumulated during the training phase are passed + * as inputs here. + * @param[in] epsilon Epsilon value used in the batch + * normalization formula. + */ +extern void hl_batch_norm_cuda_inference(const real* input, + real* output, + const real* scale, + const real* bias, + const real* estimatedMean, + const real* estimatedVar, + const double epsilon, + size_t batchSize, + size_t channel, + size_t height, + size_t width); + +#endif // HL_BATCH_NORM_H_ diff --git a/paddle/cuda/src/hl_batch_norm.cu b/paddle/cuda/src/hl_batch_norm.cu new file mode 100644 index 0000000000..57474ee2f7 --- /dev/null +++ b/paddle/cuda/src/hl_batch_norm.cu @@ -0,0 +1,68 @@ +/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. */ + +#include "hl_batch_norm.h" + +__global__ void batchNormInference(real* output, + const real* input, + const real* scale, + const real* bias, + const real* estimatedMean, + const real* estimatedVar, + const double epsilon, + size_t batchSize, + size_t channel, + size_t height, + size_t width) { + const int tid = blockIdx.x * blockDim.x + threadIdx.x; + const int num = channel * height * width; + const int batch = blockIdx.y; + for (int i = tid; i < num; i += blockDim.x) { + const int c = (i / (height * width)) % channel; + const int id = batch * num + i; + real val = input[id] - estimatedMean[c]; + val /= sqrt(estimatedVar[c] + epsilon); + val *= scale[c]; + val += bias[c]; + output[id] = val; + } +} + +void hl_batch_norm_cuda_inference(const real* input, + real* output, + const real* scale, + const real* bias, + const real* estimatedMean, + const real* estimatedVar, + const double epsilon, + size_t batchSize, + size_t channel, + size_t height, + size_t width) { + dim3 block(256, 1); + dim3 grid(1, batchSize); + batchNormInference<<>>(output, + input, + scale, + bias, + estimatedMean, + estimatedVar, + epsilon, + batchSize, + channel, + height, + width); + + CHECK_SYNC("hl_batch_norm_cuda_inference failed!"); +} diff --git a/paddle/gserver/layers/CudnnBatchNormLayer.cpp b/paddle/gserver/layers/CudnnBatchNormLayer.cpp index 09dac05a7a..d99b50385e 100644 --- a/paddle/gserver/layers/CudnnBatchNormLayer.cpp +++ b/paddle/gserver/layers/CudnnBatchNormLayer.cpp @@ -14,6 +14,7 @@ limitations under the License. */ #include "CudnnBatchNormLayer.h" #include "Layer.h" +#include "paddle/cuda/include/hl_batch_norm.h" #include "paddle/utils/Stat.h" namespace paddle { @@ -79,16 +80,32 @@ void CudnnBatchNormLayer::forward(PassType passType) { savedInvVar); } else { // used movingMean and movingVar in testing - hl_batch_norm_forward_inference(ioDesc_, - input, - ioDesc_, - output, - bnParamDesc_, - gamma, - beta, - movingMean, - movingVar, - EPS); + if (batchSize > 1024) { + // when batchSize is larger than 1024, there is a bug + // in cudnn library. + hl_batch_norm_cuda_inference(input, + output, + gamma, + beta, + movingMean, + movingVar, + EPS, + batchSize, + channels_, + imageH_, + imageW_); + } else { + hl_batch_norm_forward_inference(ioDesc_, + input, + ioDesc_, + output, + bnParamDesc_, + gamma, + beta, + movingMean, + movingVar, + EPS); + } } /* activation */ { diff --git a/paddle/gserver/tests/test_BatchNorm.cpp b/paddle/gserver/tests/test_BatchNorm.cpp index 83fcfed46c..659eefa31b 100644 --- a/paddle/gserver/tests/test_BatchNorm.cpp +++ b/paddle/gserver/tests/test_BatchNorm.cpp @@ -21,6 +21,8 @@ limitations under the License. */ #include "paddle/utils/GlobalConstants.h" #include "LayerGradUtil.h" +#include "paddle/cuda/include/hl_batch_norm.h" +#include "paddle/math/tests/TensorCheck.h" #include "paddle/testing/TestUtil.h" using namespace paddle; // NOLINT @@ -117,6 +119,74 @@ TEST(Layer, batchNorm) { CHECK_EQ(static_cast(convLayer->getOutputValue()->getWidth()), 576); } +#ifndef PADDLE_ONLY_CPU +void batchNormInference(int n, int c, int h, int w) { + MatrixPtr input = std::make_shared(n, c * h * w); + MatrixPtr cudnnOut = std::make_shared(n, c * h * w); + MatrixPtr cudaOut = std::make_shared(n, c * h * w); + MatrixPtr cudnnCheck = std::make_shared(n, c * h * w); + MatrixPtr cudaCheck = std::make_shared(n, c * h * w); + input->randomizeUniform(); + cudnnOut->zeroMem(); + cudaOut->zeroMem(); + + MatrixPtr scale = std::make_shared(1, c); + scale->randomizeUniform(); + MatrixPtr bias = std::make_shared(1, c); + bias->randomizeUniform(); + + MatrixPtr movingMean = std::make_shared(1, c); + movingMean->randomizeUniform(); + + MatrixPtr movingVar = std::make_shared(1, c); + movingVar->randomizeUniform(); + movingVar->clip(0.01, 50); + + hl_tensor_descriptor ioDesc; + hl_tensor_descriptor bnDesc; + hl_create_tensor_descriptor(&ioDesc); + hl_create_tensor_descriptor(&bnDesc); + hl_tensor_reshape(ioDesc, n, c, h, w); + hl_tensor_reshape(bnDesc, 1, c, 1, 1); + + double EPS = 1E-5; + hl_batch_norm_forward_inference(ioDesc, + input->getData(), + ioDesc, + cudnnOut->getData(), + bnDesc, + scale->getData(), + bias->getData(), + movingMean->getData(), + movingVar->getData(), + EPS); + + hl_batch_norm_cuda_inference(input->getData(), + cudaOut->getData(), + scale->getData(), + bias->getData(), + movingMean->getData(), + movingVar->getData(), + EPS, + n, + c, + h, + w); + + cudnnCheck->copyFrom(*cudnnOut); + cudaCheck->copyFrom(*cudaOut); + autotest::TensorCheckErr(*cudnnCheck, *cudaCheck); + + hl_destroy_tensor_descriptor(ioDesc); + hl_destroy_tensor_descriptor(bnDesc); +} + +TEST(BatchNorm, Inference) { + batchNormInference(33, 267, 1, 1); + batchNormInference(19, 105, 4, 4); +} +#endif + int main(int argc, char** argv) { testing::InitGoogleTest(&argc, argv); initMain(argc, argv); From bf08e5d985a39f1bb4d9085c042cdc78de8fbecb Mon Sep 17 00:00:00 2001 From: dangqingqing Date: Mon, 7 Aug 2017 19:18:40 +0800 Subject: [PATCH 2/4] modify code comments. --- paddle/cuda/include/hl_batch_norm.h | 24 +++++++++---------- paddle/gserver/layers/CudnnBatchNormLayer.cpp | 4 ++-- 2 files changed, 13 insertions(+), 15 deletions(-) diff --git a/paddle/cuda/include/hl_batch_norm.h b/paddle/cuda/include/hl_batch_norm.h index e1fea13163..afc5e0b2de 100644 --- a/paddle/cuda/include/hl_batch_norm.h +++ b/paddle/cuda/include/hl_batch_norm.h @@ -20,20 +20,18 @@ limitations under the License. */ /** * @brief batch norm inferece. * - * @param[in] input input data. - * @param[out] output output data. - * @param[in] scale batch normalization scale parameter (in original - * paper scale is referred to as gamma). - * @param[in] bias batch normalization bias parameter (in original - * paper scale is referred to as beta). + * @param[in] input input data. + * @param[out] output output data. + * @param[in] scale batch normalization scale parameter (in original + * paper scale is referred to as gamma). + * @param[in] bias batch normalization bias parameter (in original + * paper scale is referred to as beta). * @param[in] estimatedMean - * @param[in] estimatedVar It is suggested that resultRunningMean, - * resultRunningVariance from the - * cudnnBatchNormalizationForwardTraining call - * accumulated during the training phase are passed - * as inputs here. - * @param[in] epsilon Epsilon value used in the batch - * normalization formula. + * @param[in] estimatedVar The moving mean and variance + * accumulated during the training phase are passed + * as inputs here. + * @param[in] epsilon Epsilon value used in the batch + * normalization formula. */ extern void hl_batch_norm_cuda_inference(const real* input, real* output, diff --git a/paddle/gserver/layers/CudnnBatchNormLayer.cpp b/paddle/gserver/layers/CudnnBatchNormLayer.cpp index d99b50385e..cc2cc21cdf 100644 --- a/paddle/gserver/layers/CudnnBatchNormLayer.cpp +++ b/paddle/gserver/layers/CudnnBatchNormLayer.cpp @@ -81,8 +81,8 @@ void CudnnBatchNormLayer::forward(PassType passType) { } else { // used movingMean and movingVar in testing if (batchSize > 1024) { - // when batchSize is larger than 1024, there is a bug - // in cudnn library. + // there is a bug in cudnn library when the batch size + // is larger than 1024. hl_batch_norm_cuda_inference(input, output, gamma, From da7b9a5eb309d936cf836b5201a71962e895e2c4 Mon Sep 17 00:00:00 2001 From: dangqingqing Date: Mon, 7 Aug 2017 19:26:10 +0800 Subject: [PATCH 3/4] Remove the warning in hl_batch_norm_forward_inference function. --- paddle/cuda/src/hl_cuda_cudnn.cc | 8 -------- 1 file changed, 8 deletions(-) diff --git a/paddle/cuda/src/hl_cuda_cudnn.cc b/paddle/cuda/src/hl_cuda_cudnn.cc index 7ad8a39768..78642a1744 100644 --- a/paddle/cuda/src/hl_cuda_cudnn.cc +++ b/paddle/cuda/src/hl_cuda_cudnn.cc @@ -1023,14 +1023,6 @@ void hl_batch_norm_forward_inference(hl_tensor_descriptor inputDesc, real beta = 1.0f; cudnnBatchNormMode_t mode = CUDNN_BATCHNORM_SPATIAL; - int batch_size = ((cudnn_tensor_descriptor)inputDesc)->batch_size; - if (batch_size > 1024 && g_cudnn_lib_version < 6000) { - LOG(INFO) << " To process current batch data with size " << batch_size - << " (>1024), cudnnBatchNorm requires cuDNN version >= 6000." - << " If there is an error complaining CUDNN_STATUS_NOT_SUPPORTED," - << " just recompile PaddlePaddle with cuDNN >= 6000, replacing" - << " current version " << g_cudnn_lib_version; - } CHECK_CUDNN( dynload::cudnnBatchNormalizationForwardInference(t_resource.cudnn_handle, mode, From 7da1db053bc14f3c3f96ba3bae36519f679abcb4 Mon Sep 17 00:00:00 2001 From: dangqingqing Date: Mon, 7 Aug 2017 20:27:08 +0800 Subject: [PATCH 4/4] update cuda kernel. --- paddle/cuda/src/hl_batch_norm.cu | 30 +++++++++---------- paddle/gserver/layers/CudnnBatchNormLayer.cpp | 29 +++++++++--------- 2 files changed, 29 insertions(+), 30 deletions(-) diff --git a/paddle/cuda/src/hl_batch_norm.cu b/paddle/cuda/src/hl_batch_norm.cu index 57474ee2f7..5828ecb8e0 100644 --- a/paddle/cuda/src/hl_batch_norm.cu +++ b/paddle/cuda/src/hl_batch_norm.cu @@ -25,11 +25,11 @@ __global__ void batchNormInference(real* output, size_t channel, size_t height, size_t width) { - const int tid = blockIdx.x * blockDim.x + threadIdx.x; + const int tid = threadIdx.x; const int num = channel * height * width; - const int batch = blockIdx.y; + const int batch = blockIdx.x; for (int i = tid; i < num; i += blockDim.x) { - const int c = (i / (height * width)) % channel; + const int c = i / (height * width); const int id = batch * num + i; real val = input[id] - estimatedMean[c]; val /= sqrt(estimatedVar[c] + epsilon); @@ -50,19 +50,17 @@ void hl_batch_norm_cuda_inference(const real* input, size_t channel, size_t height, size_t width) { - dim3 block(256, 1); - dim3 grid(1, batchSize); - batchNormInference<<>>(output, - input, - scale, - bias, - estimatedMean, - estimatedVar, - epsilon, - batchSize, - channel, - height, - width); + batchNormInference<<>>(output, + input, + scale, + bias, + estimatedMean, + estimatedVar, + epsilon, + batchSize, + channel, + height, + width); CHECK_SYNC("hl_batch_norm_cuda_inference failed!"); } diff --git a/paddle/gserver/layers/CudnnBatchNormLayer.cpp b/paddle/gserver/layers/CudnnBatchNormLayer.cpp index cc2cc21cdf..44ba2c4b7d 100644 --- a/paddle/gserver/layers/CudnnBatchNormLayer.cpp +++ b/paddle/gserver/layers/CudnnBatchNormLayer.cpp @@ -80,9 +80,21 @@ void CudnnBatchNormLayer::forward(PassType passType) { savedInvVar); } else { // used movingMean and movingVar in testing - if (batchSize > 1024) { - // there is a bug in cudnn library when the batch size - // is larger than 1024. + if (batchSize <= 1024) { + hl_batch_norm_forward_inference(ioDesc_, + input, + ioDesc_, + output, + bnParamDesc_, + gamma, + beta, + movingMean, + movingVar, + EPS); + } else { + // There is a limitation in cudnn library. + // When the batch size is larger than 1024 in cuDNN v5.1, + // the cudnnBatchNormalizationForwardInference will fail. hl_batch_norm_cuda_inference(input, output, gamma, @@ -94,17 +106,6 @@ void CudnnBatchNormLayer::forward(PassType passType) { channels_, imageH_, imageW_); - } else { - hl_batch_norm_forward_inference(ioDesc_, - input, - ioDesc_, - output, - bnParamDesc_, - gamma, - beta, - movingMean, - movingVar, - EPS); } }