diff --git a/CMakeLists.txt b/CMakeLists.txt index d43df124bd..24262c1821 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -72,6 +72,7 @@ option(WITH_INFERENCE "Compile fluid inference library" ON) option(WITH_INFERENCE_API_TEST "Test fluid inference high-level api interface" OFF) option(WITH_SYSTEM_BLAS "Use system blas library" OFF) option(PY_VERSION "Compile PaddlePaddle with python3 support" ${PY_VERSION}) +option(WITH_FAST_MATH "Make use of fast math library" OFF) # PY_VERSION if(NOT PY_VERSION) diff --git a/cmake/cuda.cmake b/cmake/cuda.cmake index 03c73786a6..f507bb41a1 100644 --- a/cmake/cuda.cmake +++ b/cmake/cuda.cmake @@ -175,7 +175,10 @@ list(APPEND CUDA_NVCC_FLAGS "-std=c++11") list(APPEND CUDA_NVCC_FLAGS "-Xcompiler -fPIC") endif(NOT WIN32) -list(APPEND CUDA_NVCC_FLAGS "--use_fast_math") +if(WITH_FAST_MATH) + # Make use of fast math library. https://docs.nvidia.com/cuda/cuda-compiler-driver-nvcc/index.html + list(APPEND CUDA_NVCC_FLAGS "--use_fast_math") +endif() # in cuda9, suppress cuda warning on eigen list(APPEND CUDA_NVCC_FLAGS "-w") # Set :expt-relaxed-constexpr to suppress Eigen warnings diff --git a/cmake/external/eigen.cmake b/cmake/external/eigen.cmake index e029300eee..573ad5e5f0 100644 --- a/cmake/external/eigen.cmake +++ b/cmake/external/eigen.cmake @@ -3,6 +3,14 @@ INCLUDE(ExternalProject) SET(EIGEN_SOURCE_DIR ${THIRD_PARTY_PATH}/eigen3) SET(EIGEN_INCLUDE_DIR ${EIGEN_SOURCE_DIR}/src/extern_eigen3) INCLUDE_DIRECTORIES(${EIGEN_INCLUDE_DIR}) +if(NOT WITH_FAST_MATH) + # EIGEN_FAST_MATH: https://eigen.tuxfamily.org/dox/TopicPreprocessorDirectives.html + # enables some optimizations which might affect the accuracy of the result. + # This currently enables the SSE vectorization of sin() and cos(), + # and speedups sqrt() for single precision. + # Defined to 1 by default. Define it to 0 to disable. + add_definitions(-DEIGEN_FAST_MATH=0) +endif() if(WITH_AMD_GPU) ExternalProject_Add( diff --git a/cmake/flags.cmake b/cmake/flags.cmake index 331b2af367..343e44ab4b 100644 --- a/cmake/flags.cmake +++ b/cmake/flags.cmake @@ -157,6 +157,8 @@ if (APPLE) # On Mac OS X build fat binaries with x86_64 architectures by default. set (CMAKE_OSX_ARCHITECTURES "x86_64" CACHE STRING "Build architectures for OSX" FORCE) endif() + # On Mac OS X register class specifier is deprecated and will cause warning error on latest clang 10.0 + set (COMMON_FLAGS -Wno-deprecated-register) endif(APPLE) if(LINUX) diff --git a/paddle/fluid/framework/rw_lock.h b/paddle/fluid/framework/rw_lock.h index da163835e8..dbf00f3a79 100644 --- a/paddle/fluid/framework/rw_lock.h +++ b/paddle/fluid/framework/rw_lock.h @@ -46,6 +46,7 @@ struct RWLock { private: pthread_rwlock_t lock_; }; +// TODO(paddle-dev): Support RWLock for WIN32 for correctness. #else // https://stackoverflow.com/questions/7125250/making-pthread-rwlock-wrlock-recursive // In windows, rw_lock seems like a hack. Use empty object and do nothing. diff --git a/paddle/fluid/inference/tests/api/analyzer_resnet50_tester.cc b/paddle/fluid/inference/tests/api/analyzer_resnet50_tester.cc index 290fb007d8..8add7a59da 100644 --- a/paddle/fluid/inference/tests/api/analyzer_resnet50_tester.cc +++ b/paddle/fluid/inference/tests/api/analyzer_resnet50_tester.cc @@ -27,6 +27,9 @@ void SetConfig(AnalysisConfig *cfg) { cfg->device = 0; cfg->enable_ir_optim = true; cfg->specify_input_name = true; +#ifdef PADDLE_WITH_MKLDNN + cfg->_use_mkldnn = true; +#endif } void SetInput(std::vector> *inputs) { diff --git a/paddle/fluid/operators/top_k_op.cu b/paddle/fluid/operators/top_k_op.cu index 9da8551eb2..8e4a07556f 100644 --- a/paddle/fluid/operators/top_k_op.cu +++ b/paddle/fluid/operators/top_k_op.cu @@ -256,36 +256,65 @@ __device__ __forceinline__ void BlockReduce(Pair* sh_topk, int* maxid, * 3. go to the second setp, until one thread's topk value is null; * 4. go to the first setp, until get the topk value. */ + template __global__ void KeMatrixTopK(T* output, int output_stride, int64_t* indices, - const T* src, int lds, int dim, int k) { + const T* src, int lds, int dim, int k, + int grid_dim, int num) { __shared__ Pair sh_topk[BlockSize]; __shared__ int maxid[BlockSize / 2]; const int tid = threadIdx.x; const int warp = threadIdx.x / 32; - output += blockIdx.x * output_stride; - indices += blockIdx.x * k; - Pair topk[MaxLength]; - int beam = MaxLength; - Pair max; - bool is_empty = false; - bool firststep = true; + const int bid = blockIdx.x; + for (int i = bid; i < num; i += grid_dim) { + output += i * output_stride; + indices += i * k; + + Pair topk[MaxLength]; + int beam = MaxLength; + Pair max; + bool is_empty = false; + bool firststep = true; + + for (int k = 0; k < MaxLength; k++) { + topk[k].set(-INFINITY, -1); + } + while (k) { + ThreadGetTopK( + topk, &beam, k, src + i * lds, &firststep, &is_empty, &max, dim, tid); - for (int k = 0; k < MaxLength; k++) { - topk[k].set(-INFINITY, -1); + sh_topk[tid] = topk[0]; + BlockReduce(sh_topk, maxid, topk, &output, + &indices, &beam, &k, tid, warp); + } } - while (k) { - ThreadGetTopK(topk, &beam, k, - src + blockIdx.x * lds, &firststep, - &is_empty, &max, dim, tid); - - sh_topk[tid] = topk[0]; - BlockReduce(sh_topk, maxid, topk, &output, - &indices, &beam, &k, tid, warp); +} + +inline static int GetDesiredBlockDim(int dim) { + if (dim > 128) { + return 256; + } else if (dim > 64) { + return 128; + } else if (dim > 32) { + return 64; + } else { + return 32; } } +#define FIXED_BLOCK_DIM_BASE(dim, ...) \ + case (dim): { \ + constexpr auto kBlockDim = (dim); \ + __VA_ARGS__; \ + } break + +#define FIXED_BLOCK_DIM(...) \ + FIXED_BLOCK_DIM_BASE(256, ##__VA_ARGS__); \ + FIXED_BLOCK_DIM_BASE(128, ##__VA_ARGS__); \ + FIXED_BLOCK_DIM_BASE(64, ##__VA_ARGS__); \ + FIXED_BLOCK_DIM_BASE(32, ##__VA_ARGS__) + template class TopkOpCUDAKernel : public framework::OpKernel { public: @@ -310,18 +339,26 @@ class TopkOpCUDAKernel : public framework::OpKernel { // NOTE: pass lds and dim same to input width. // NOTE: old matrix implementation of stride is different to eigen. // TODO(typhoonzero): refine this kernel. - dim3 threads(256, 1); - dim3 grid(input_height, 1); - - KeMatrixTopK<<< - grid, threads, 0, reinterpret_cast( - ctx.device_context()) - .stream()>>>( - output_data, output->dims()[1], indices_data, input_data, input_width, - input_width, static_cast(k)); + const int kMaxHeight = 2048; + int gridx = input_height < kMaxHeight ? input_height : kMaxHeight; + auto& dev_ctx = ctx.cuda_device_context(); + + switch (GetDesiredBlockDim(input_width)) { + FIXED_BLOCK_DIM( + KeMatrixTopK<<>>( + output_data, output->dims()[1], indices_data, input_data, + input_width, input_width, static_cast(k), gridx, + input_height)); + default: + PADDLE_THROW("Error"); + } } }; +#undef FIXED_BLOCK_DIM_BASE +#undef FIXED_BLOCK_DIM + } // namespace operators } // namespace paddle diff --git a/paddle/fluid/operators/while_op.cc b/paddle/fluid/operators/while_op.cc index 16eac1ec24..3c8a01b6e4 100644 --- a/paddle/fluid/operators/while_op.cc +++ b/paddle/fluid/operators/while_op.cc @@ -224,10 +224,12 @@ class WhileGradOp : public framework::OperatorBase { if (cur_scope_iter == step_scopes->rbegin()) { auto *var = (*cur_scope_iter)->FindVar(inside_grad_name); PADDLE_ENFORCE_NOT_NULL(var, "Can not find var %s", inside_grad_name); - PADDLE_ENFORCE(var->IsType() || - var->IsType(), - "Currently the type of var only can be LoDTensorArray " - "or LoDTensor."); + PADDLE_ENFORCE( + var->IsType() || + var->IsType(), + "Currently the type of var only can be LoDTensorArray, " + "or LoDTensor, but the received var[%s] is %s.", + inside_grad_name, var->Type().name()); if (var->IsType()) { auto &inside_tensor = var->Get(); diff --git a/paddle/fluid/platform/gpu_info.cc b/paddle/fluid/platform/gpu_info.cc index 126636d879..f599e7fbc8 100644 --- a/paddle/fluid/platform/gpu_info.cc +++ b/paddle/fluid/platform/gpu_info.cc @@ -20,8 +20,11 @@ limitations under the License. */ #include "paddle/fluid/platform/enforce.h" DEFINE_double(fraction_of_gpu_memory_to_use, 0.92, - "Default use 92% of GPU memory for PaddlePaddle," - "reserve the rest for page tables, etc"); + "Allocate a trunk of gpu memory that is this fraction of the " + "total gpu memory size. Future memory usage will be allocated " + "from the trunk. If the trunk doesn't have enough gpu memory, " + "additional trunks of the same size will be requested from gpu " + "until the gpu has no memory left for another trunk."); namespace paddle { namespace platform { diff --git a/paddle/scripts/paddle_build.sh b/paddle/scripts/paddle_build.sh index d9214d0b8c..b434c9f08e 100755 --- a/paddle/scripts/paddle_build.sh +++ b/paddle/scripts/paddle_build.sh @@ -598,7 +598,7 @@ EOF EOF if [[ ${WITH_GPU} == "ON" ]]; then - NCCL_DEPS="apt-get install -y --allow-downgrades libnccl2=2.2.13-1+cuda${CUDA_MAJOR} libnccl-dev=2.2.13-1+cuda${CUDA_MAJOR} &&" + NCCL_DEPS="apt-get install -y --allow-downgrades libnccl2=2.2.13-1+cuda${CUDA_MAJOR} libnccl-dev=2.2.13-1+cuda${CUDA_MAJOR} || true" else NCCL_DEPS="" fi @@ -614,9 +614,8 @@ EOF cat >> ${PADDLE_ROOT}/build/Dockerfile <