diff --git a/cmake/cblas.cmake b/cmake/cblas.cmake index aebb5d9fcb..0918e6cc63 100644 --- a/cmake/cblas.cmake +++ b/cmake/cblas.cmake @@ -44,7 +44,6 @@ if(MKL_INC_DIR AND MKL_CORE_LIB AND MKL_SEQUENTIAL_LIB AND MKL_INTEL_LP64) message(STATUS "Found MKL (include: ${CBLAS_INC_DIR}, library: ${CBLAS_LIBRARIES})") set(CBLAS_FOUND ON) if(${MKL_LAPACK_INC_DIR}) - add_definitions(-DPADDLE_USE_LAPACK) message(STATUS "Found lapack in MKL (include: ${MKL_LAPACK_INC_DIR})") endif() return() # return file. @@ -80,7 +79,6 @@ if(ATLAS_INC_DIR AND ATLAS_CBLAS_LIB AND ATLAS_LIB AND NOT CBLAS_FOUND) message(STATUS "Found ATLAS (include: ${CBLAS_INC_DIR}, library: ${CBLAS_LIBRARIES})") set(CBLAS_FOUND ON) if(ATLAS_CLAPACK_INC_DIR) - add_definitions(-DPADDLE_USE_LAPACK) set(CBLAS_INC_DIR ${CBLAS_INC_DIR} ${ATLAS_CLAPACK_INC_DIR}) message(STATUS "Found lapack in ATLAS (include: ${ATLAS_CLAPACK_INC_DIR})") endif() @@ -115,7 +113,6 @@ if(OPENBLAS_INC_DIR AND OPENBLAS_LIB) message(STATUS "Found OpenBLAS (include: ${CBLAS_INC_DIR}, library: ${CBLAS_LIBRARIES})") set(CBLAS_FOUND ON) if(OPENBLAS_LAPACKE_INC_DIR) - add_definitions(-DPADDLE_USE_LAPACK) message(STATUS "Found lapack in OpenBLAS (include: ${OPENBLAS_LAPACKE_INC_DIR})") endif() return() diff --git a/cmake/external/openblas.cmake b/cmake/external/openblas.cmake index 92ea23c763..46398b22c2 100644 --- a/cmake/external/openblas.cmake +++ b/cmake/external/openblas.cmake @@ -27,35 +27,6 @@ IF(NOT ${CBLAS_FOUND}) SET(CBLAS_LIBRARIES "${CBLAS_INSTALL_DIR}/lib/libopenblas.a" CACHE FILEPATH "openblas library" FORCE) ENDIF(WIN32) - IF(CMAKE_COMPILER_IS_GNUCC) - ENABLE_LANGUAGE(Fortran) - if (NOT CMAKE_Fortran_COMPILER_VERSION) - # cmake < 3.4 cannot get CMAKE_Fortran_COMPILER_VERSION directly. - execute_process(COMMAND ${CMAKE_Fortran_COMPILER} -dumpversion - OUTPUT_VARIABLE CMAKE_Fortran_COMPILER_VERSION) - endif() - string(REGEX MATCHALL "[0-9]+" Fortran_VERSION ${CMAKE_Fortran_COMPILER_VERSION}) - list(GET Fortran_VERSION 0 Fortran_MAJOR) - list(GET Fortran_VERSION 1 Fortran_MINOR) - find_library(GFORTRAN_LIBRARY NAMES gfortran PATHS - /lib - /usr/lib - /usr/lib/gcc/x86_64-linux-gnu/${Fortran_MAJOR}.${Fortran_MINOR}/ - /usr/lib/gcc/x86_64-linux-gnu/${Fortran_MAJOR}/) - if (NOT GFORTRAN_LIBRARY) - message(FATAL_ERROR "Cannot found gfortran library which it is used by openblas") - endif() - find_package(Threads REQUIRED) - LIST(APPEND CBLAS_LIBRARIES ${GFORTRAN_LIBRARY} ${CMAKE_THREAD_LIBS_INIT}) - ENDIF(CMAKE_COMPILER_IS_GNUCC) - - IF(NOT CMAKE_Fortran_COMPILER) - MESSAGE(FATAL_ERROR "To build lapack in libopenblas, " - "you need to set gfortran compiler: cmake .. 
-DCMAKE_Fortran_COMPILER=...") - ENDIF(NOT CMAKE_Fortran_COMPILER) - - ADD_DEFINITIONS(-DPADDLE_USE_LAPACK) - ExternalProject_Add( openblas ${EXTERNAL_PROJECT_LOG_ARGS} @@ -64,7 +35,7 @@ IF(NOT ${CBLAS_FOUND}) PREFIX ${CBLAS_SOURCES_DIR} INSTALL_DIR ${CBLAS_INSTALL_DIR} BUILD_IN_SOURCE 1 - BUILD_COMMAND ${CMAKE_MAKE_PROGRAM} FC=${CMAKE_Fortran_COMPILER} CC=${CMAKE_C_COMPILER} HOSTCC=${CMAKE_C_COMPILER} DYNAMIC_ARCH=1 NO_SHARED=1 libs netlib + BUILD_COMMAND ${CMAKE_MAKE_PROGRAM} FC=${CMAKE_Fortran_COMPILER} CC=${CMAKE_C_COMPILER} HOSTCC=${CMAKE_C_COMPILER} NO_LAPACK=1 DYNAMIC_ARCH=1 NO_SHARED=1 libs netlib INSTALL_COMMAND ${CMAKE_MAKE_PROGRAM} install NO_SHARED=1 PREFIX= UPDATE_COMMAND "" CONFIGURE_COMMAND "" diff --git a/paddle/cuda/CMakeLists.txt b/paddle/cuda/CMakeLists.txt index a28ccd6f07..f9061e96de 100755 --- a/paddle/cuda/CMakeLists.txt +++ b/paddle/cuda/CMakeLists.txt @@ -21,16 +21,13 @@ set(CUDA_CXX_WITH_GPU_SOURCES if(WITH_GPU) set(CUDA_CXX_SOURCES - src/hl_dso_loader.cc src/hl_warpctc_wrap.cc ${CUDA_CXX_WITH_GPU_SOURCES}) set_source_files_properties(${CUDA_CXX_SOURCES} PROPERTIES COMPILE_FLAGS "-D__NVCC__") else() - set(CUDA_CXX_SOURCES - src/hl_dso_loader.cc - src/hl_warpctc_wrap.cc) + set(CUDA_CXX_SOURCES src/hl_warpctc_wrap.cc) endif() set(CUDA_CU_SOURCES @@ -47,7 +44,6 @@ set(CUDA_CU_SOURCES set(CUDA_HEADERS include/hl_time.h - include/hl_dso_loader.h include/hl_warpctc_wrap.h include/hl_sequence.h include/hl_cuda_cublas.h diff --git a/paddle/cuda/include/hl_activation_functions.h b/paddle/cuda/include/hl_activation_functions.h index cdb2dba06c..93957fd964 100644 --- a/paddle/cuda/include/hl_activation_functions.h +++ b/paddle/cuda/include/hl_activation_functions.h @@ -40,18 +40,18 @@ public: namespace gpu { static __device__ Active::forward forward[] = HPPL_ACTIVE_FUNCTION; static __device__ Active::backward backward[] = HPPL_ACTIVE_FUNCTION; -} +} // namespace gpu #else namespace cpu { static Active::forward forward[] = HPPL_ACTIVE_FUNCTION; static Active::backward backward[] = HPPL_ACTIVE_FUNCTION; -} +} // namespace cpu #ifdef __AVX__ namespace avx { static Active<__m256>::forward forward[] = HPPL_ACTIVE_FUNCTION; static Active<__m256>::backward backward[] = HPPL_ACTIVE_FUNCTION; -} +} // namespace avx #endif #endif diff --git a/paddle/cuda/include/hl_cnn.h b/paddle/cuda/include/hl_cnn.h index c5787630ab..f55197c8c9 100644 --- a/paddle/cuda/include/hl_cnn.h +++ b/paddle/cuda/include/hl_cnn.h @@ -273,23 +273,23 @@ extern void hl_bilinear_forward(const real* inData, const real ratioW); /** -* @brief Bilinear interpolation backward. -* -* @param[out] inGrad input gradient. -* @param[in] inImgH input image height. -* @param[in] inImgW input image width. -* @param[in] inputH input batchSize. -* @param[in] inputW input image data dim. -* @param[in] outGrad output gradient. -* @param[in] outImgH output image height. -* @param[in] outImgW output image width. -* @param[in] outputH output batchSize. -* @param[in] outputW output image data dim. -* @param[in] numChannels number of channels. -* @param[in] ratioH inImgH / outImgH. -* @param[in] ratioW inImgW / outImgW. -* -*/ + * @brief Bilinear interpolation backward. + * + * @param[out] inGrad input gradient. + * @param[in] inImgH input image height. + * @param[in] inImgW input image width. + * @param[in] inputH input batchSize. + * @param[in] inputW input image data dim. + * @param[in] outGrad output gradient. + * @param[in] outImgH output image height. + * @param[in] outImgW output image width. 
+ * @param[in] outputH output batchSize. + * @param[in] outputW output image data dim. + * @param[in] numChannels number of channels. + * @param[in] ratioH inImgH / outImgH. + * @param[in] ratioW inImgW / outImgW. + * + */ extern void hl_bilinear_backward(real* inGrad, const size_t inImgH, const size_t inImgW, diff --git a/paddle/cuda/src/hl_cuda_cublas.cc b/paddle/cuda/src/hl_cuda_cublas.cc index 182e8ab218..6163209e9b 100644 --- a/paddle/cuda/src/hl_cuda_cublas.cc +++ b/paddle/cuda/src/hl_cuda_cublas.cc @@ -14,10 +14,9 @@ limitations under the License. */ #include "hl_cuda_cublas.h" #include -#include #include "hl_cuda.h" -#include "hl_dso_loader.h" #include "hl_thread.ph" +#include "paddle/utils/DynamicLoader.h" #include "paddle/utils/Logging.h" namespace dynload { diff --git a/paddle/cuda/src/hl_cuda_cudnn.cc b/paddle/cuda/src/hl_cuda_cudnn.cc index 6198f067ba..c53a563682 100644 --- a/paddle/cuda/src/hl_cuda_cudnn.cc +++ b/paddle/cuda/src/hl_cuda_cudnn.cc @@ -15,10 +15,9 @@ limitations under the License. */ #include "hl_cuda_cudnn.h" #include #include -#include #include "hl_cuda_cudnn.ph" -#include "hl_dso_loader.h" #include "hl_thread.ph" +#include "paddle/utils/DynamicLoader.h" #include "paddle/utils/Logging.h" DEFINE_int32(cudnn_conv_workspace_limit_in_mb, diff --git a/paddle/cuda/src/hl_cuda_device.cc b/paddle/cuda/src/hl_cuda_device.cc index 6dfb12e00b..4042d9742a 100644 --- a/paddle/cuda/src/hl_cuda_device.cc +++ b/paddle/cuda/src/hl_cuda_device.cc @@ -21,11 +21,10 @@ limitations under the License. */ #include #include #include -#include #include "hl_cuda.ph" #include "hl_thread.ph" -#include "hl_dso_loader.h" #include "paddle/utils/Logging.h" +#include "paddle/utils/DynamicLoader.h" // clang-format on namespace dynload { diff --git a/paddle/cuda/src/hl_warpctc_wrap.cc b/paddle/cuda/src/hl_warpctc_wrap.cc index f57efb2b46..9f812dd0de 100644 --- a/paddle/cuda/src/hl_warpctc_wrap.cc +++ b/paddle/cuda/src/hl_warpctc_wrap.cc @@ -14,7 +14,7 @@ limitations under the License. 
*/ #include "hl_warpctc_wrap.h" #include -#include "hl_dso_loader.h" +#include "paddle/utils/DynamicLoader.h" #include "paddle/utils/Logging.h" namespace dynload { diff --git a/paddle/function/MulOpTest.cpp b/paddle/function/MulOpTest.cpp index 8748eb0d79..8753057ebf 100644 --- a/paddle/function/MulOpTest.cpp +++ b/paddle/function/MulOpTest.cpp @@ -74,9 +74,9 @@ TEST(MulOp, DDDMatrixMul) { } /** - * C += A * B, B, C dense, A sparse - * dense = sparse * dense - */ + * C += A * B, B, C dense, A sparse + * dense = sparse * dense + */ void testFuncDSparseDMatrix( size_t dimM, size_t dimN, size_t dimK, size_t nnz, SparseFormat FORMAT) { real scaleT = 1.0; @@ -119,9 +119,9 @@ TEST(MuLOp, DSparseDMul) { } /** - * C += A * B, A, C dense, B sparse - * dense = dense * sparse - */ + * C += A * B, A, C dense, B sparse + * dense = dense * sparse + */ void testFuncDDSparseMatrix( size_t dimM, size_t dimN, size_t dimK, size_t nnz, SparseFormat FORMAT) { real scaleT = 1.0; @@ -165,9 +165,9 @@ TEST(MulOp, DDSparseMul) { } /** - * C += A * B, A sparse, B, C dense - * sparse = dense * dense - */ + * C += A * B, A sparse, B, C dense + * sparse = dense * dense + */ void testFuncSparseDDMatrix( size_t dimM, size_t dimN, size_t dimK, size_t nnz, SparseFormat FORMAT) { real scaleT = 1.0; diff --git a/paddle/gserver/gradientmachines/GradientMachine.cpp b/paddle/gserver/gradientmachines/GradientMachine.cpp index 3eb87d9b85..b44e4dc202 100644 --- a/paddle/gserver/gradientmachines/GradientMachine.cpp +++ b/paddle/gserver/gradientmachines/GradientMachine.cpp @@ -21,7 +21,6 @@ limitations under the License. */ #include "MultiGradientMachine.h" #include "MultiNetwork.h" #include "NeuralNetwork.h" -#include "NeuralNetwork.h" #include "ParallelNeuralNetwork.h" #include "hl_gpu.h" diff --git a/paddle/gserver/gradientmachines/RecurrentGradientMachine.cpp b/paddle/gserver/gradientmachines/RecurrentGradientMachine.cpp index 2ab964b8fc..01158d1dce 100644 --- a/paddle/gserver/gradientmachines/RecurrentGradientMachine.cpp +++ b/paddle/gserver/gradientmachines/RecurrentGradientMachine.cpp @@ -637,7 +637,7 @@ void RecurrentGradientMachine::removeBeamSearchStatisticsCallbacks() { /* create scattered id infomation for all realLayer of inFrameLines one time. * If hasSubseq, will also create scattered sequenceStartPositions infomation * for all realLayer of inFrameLines one time. -*/ + */ void RecurrentGradientMachine::createInFrameInfo(int inlinkId, const Argument& input, diff --git a/paddle/gserver/gradientmachines/RecurrentGradientMachine.h b/paddle/gserver/gradientmachines/RecurrentGradientMachine.h index 910ca4376b..c2bc52709a 100644 --- a/paddle/gserver/gradientmachines/RecurrentGradientMachine.h +++ b/paddle/gserver/gradientmachines/RecurrentGradientMachine.h @@ -107,18 +107,18 @@ public: DropCallback; /** - * @brief NormOrDropNodeCallback - * - * Normalize a path's probabilities or just drop it by modifying path.logProb - * - * The first parameter is sequence index in a batch - * - * The second parameter is path.ids - * - * The third parameter is probabilites for each node in this path. - * - * The fourth parameter is the probability of the whole path. - */ + * @brief NormOrDropNodeCallback + * + * Normalize a path's probabilities or just drop it by modifying path.logProb + * + * The first parameter is sequence index in a batch + * + * The second parameter is path.ids + * + * The third parameter is probabilites for each node in this path. + * + * The fourth parameter is the probability of the whole path. 
+ */ typedef std::function&, std::vector&, real*)> NormOrDropNodeCallback; @@ -348,9 +348,9 @@ protected: int targetInfoInlinkId_; /* create scattered id infomation for all realLayer of inFrameLines one time. - * If hasSubseq, will also create scattered sequenceStartPositions infomation - * for all realLayer of inFrameLines one time. - */ + * If hasSubseq, will also create scattered sequenceStartPositions infomation + * for all realLayer of inFrameLines one time. + */ void createInFrameInfo(int inlinks_id, const Argument& input, PassType passType); diff --git a/paddle/gserver/layers/Layer.h b/paddle/gserver/layers/Layer.h index 47182c9ecc..0ed482889d 100644 --- a/paddle/gserver/layers/Layer.h +++ b/paddle/gserver/layers/Layer.h @@ -106,9 +106,9 @@ protected: public: /** - * Wait until all input value ready. - * Called before Layer::forward() function. - */ + * Wait until all input value ready. + * Called before Layer::forward() function. + */ virtual void waitInputValue(); /** @@ -118,9 +118,9 @@ public: virtual void copyOutputToOtherDevice(); /** - * Wait until all output grad ready and merge them to output_.grad. - * Called before Layer::backward() function. - */ + * Wait until all output grad ready and merge them to output_.grad. + * Called before Layer::backward() function. + */ virtual void waitAndMergeOutputGrad(); /** diff --git a/paddle/gserver/layers/RotateLayer.h b/paddle/gserver/layers/RotateLayer.h index 1a64d4d5a5..d05c2065cb 100644 --- a/paddle/gserver/layers/RotateLayer.h +++ b/paddle/gserver/layers/RotateLayer.h @@ -29,7 +29,7 @@ namespace paddle { * * The config file api is rotate_layer * -*/ + */ class RotateLayer : public Layer { public: diff --git a/paddle/gserver/layers/SequencePoolLayer.cpp b/paddle/gserver/layers/SequencePoolLayer.cpp index 8c49502011..235d9a9b0f 100644 --- a/paddle/gserver/layers/SequencePoolLayer.cpp +++ b/paddle/gserver/layers/SequencePoolLayer.cpp @@ -60,7 +60,7 @@ void SequencePoolLayer::forward(PassType passType) { * thus, in this case, output_ has no sequenceStartPositions. * If type_ = kSeq, seq has sub-seq degrades to a seq, thus, only in this * case, we should compute the new sequenceStartPositions. - */ + */ if (type_) { CHECK(input.subSequenceStartPositions) << "when trans_type = seq, input must hasSubseq"; diff --git a/paddle/math/MathFunctions.cpp b/paddle/math/MathFunctions.cpp index 6203cd3b9a..178fce5b0a 100644 --- a/paddle/math/MathFunctions.cpp +++ b/paddle/math/MathFunctions.cpp @@ -15,6 +15,54 @@ limitations under the License. */ #include "MathFunctions.h" #include "hl_matrix_apply.cuh" #include "hl_matrix_ops.cuh" +#include "paddle/utils/DynamicLoader.h" + +namespace dynload { + +std::once_flag lapack_dso_flag; +void* lapack_dso_handle = nullptr; + +/** + * The following macro definition can generate structs + * (for each function) to dynamic load lapack routine + * via operator overloading. + * + * note: default dynamic linked libs + */ +#define DYNAMIC_LOAD_LAPACK_WRAP(__name) \ + struct DynLoad__##__name { \ + template \ + auto operator()(Args... 
args) -> decltype(__name(args...)) { \ + using lapack_func = decltype(__name(args...)) (*)(Args...); \ + std::call_once(lapack_dso_flag, GetLapackDsoHandle, &lapack_dso_handle); \ + void* p_##__name = dlsym(lapack_dso_handle, #__name); \ + return reinterpret_cast(p_##__name)(args...); \ + } \ + } __name; // struct DynLoad__##__name + +// clang-format off +#ifdef PADDLE_USE_ATLAS + #define PADDLE_SGETRF clapack_sgetrf + #define PADDLE_DGETRF clapack_dgetrf + #define PADDLE_SGETRI clapack_sgetri + #define PADDLE_DGETRI clapack_dgetri +#else + #define PADDLE_SGETRF LAPACKE_sgetrf + #define PADDLE_DGETRF LAPACKE_dgetrf + #define PADDLE_SGETRI LAPACKE_sgetri + #define PADDLE_DGETRI LAPACKE_dgetri +#endif + +#define LAPACK_ROUTINE_EACH(__macro) \ + __macro(PADDLE_SGETRF) \ + __macro(PADDLE_DGETRF) \ + __macro(PADDLE_SGETRI) \ + __macro(PADDLE_DGETRI) +// clang-format on + +LAPACK_ROUTINE_EACH(DYNAMIC_LOAD_LAPACK_WRAP) + +} // namespace dynload namespace paddle { @@ -85,16 +133,7 @@ int getrf(const CBLAS_ORDER order, float* A, const int lda, int* ipiv) { -#ifdef PADDLE_USE_LAPACK -#ifdef PADDLE_USE_ATLAS - return clapack_sgetrf(order, M, N, A, lda, ipiv); -#else - return LAPACKE_sgetrf(order, M, N, A, lda, ipiv); -#endif -#else - LOG(FATAL) << "Not implemented"; -#endif - return 0; + return dynload::PADDLE_SGETRF(order, M, N, A, lda, ipiv); } template <> @@ -104,16 +143,7 @@ int getrf(const CBLAS_ORDER order, double* A, const int lda, int* ipiv) { -#ifdef PADDLE_USE_LAPACK -#ifdef PADDLE_USE_ATLAS - return clapack_dgetrf(order, M, N, A, lda, ipiv); -#else - return LAPACKE_dgetrf(order, M, N, A, lda, ipiv); -#endif -#else - LOG(FATAL) << "Not implemented"; -#endif - return 0; + return dynload::PADDLE_DGETRF(order, M, N, A, lda, ipiv); } template <> @@ -122,16 +152,7 @@ int getri(const CBLAS_ORDER order, float* A, const int lda, const int* ipiv) { -#ifdef PADDLE_USE_LAPACK -#ifdef PADDLE_USE_ATLAS - return clapack_sgetri(order, N, A, lda, ipiv); -#else - return LAPACKE_sgetri(order, N, A, lda, ipiv); -#endif -#else - LOG(FATAL) << "Not implemented"; -#endif - return 0; + return dynload::PADDLE_SGETRI(order, N, A, lda, ipiv); } template <> @@ -140,15 +161,7 @@ int getri(const CBLAS_ORDER order, double* A, const int lda, const int* ipiv) { -#ifdef PADDLE_USE_LAPACK -#ifdef PADDLE_USE_ATLAS - return clapack_dgetri(order, N, A, lda, ipiv); -#else - return LAPACKE_dgetri(order, N, A, lda, ipiv); -#endif -#else - LOG(FATAL) << "Not implemented"; -#endif + return dynload::PADDLE_DGETRI(order, N, A, lda, ipiv); return 0; } diff --git a/paddle/math/MathFunctions.h b/paddle/math/MathFunctions.h index 9f8f84a87c..c8559eefd8 100644 --- a/paddle/math/MathFunctions.h +++ b/paddle/math/MathFunctions.h @@ -17,14 +17,11 @@ limitations under the License. */ #ifdef PADDLE_USE_MKL #include -#ifdef PADDLE_USE_LAPACK #include -#endif #else extern "C" { #include } -#ifdef PADDLE_USE_LAPACK #ifdef PADDLE_USE_ATLAS extern "C" { #include @@ -33,7 +30,6 @@ extern "C" { #include #endif #endif -#endif #include diff --git a/paddle/math/tests/TestUtils.h b/paddle/math/tests/TestUtils.h index c302096188..713f407f49 100644 --- a/paddle/math/tests/TestUtils.h +++ b/paddle/math/tests/TestUtils.h @@ -37,7 +37,7 @@ limitations under the License. 
*/ * * AutoCompare test; * test.cmpWithoutArg(function, height, width) -*/ + */ #include #include "TensorCheck.h" diff --git a/paddle/math/tests/test_matrixCompare.cpp b/paddle/math/tests/test_matrixCompare.cpp index 5210fe3fa1..3b1b0065af 100644 --- a/paddle/math/tests/test_matrixCompare.cpp +++ b/paddle/math/tests/test_matrixCompare.cpp @@ -21,6 +21,7 @@ limitations under the License. */ #include "paddle/math/Matrix.h" #include "paddle/math/SparseMatrix.h" #include "paddle/testing/TestUtil.h" +#include "paddle/utils/DynamicLoader.h" #include "paddle/utils/Stat.h" #include "paddle/utils/Util.h" @@ -235,10 +236,15 @@ TEST(Matrix, unary) { testMatrixTranspose(height, width); testMatrixRotate(height, width); } -// inverse -#ifdef PADDLE_USE_LAPACK - testMatrixInverse(height); -#endif + // inverse matrix + void** dso_handler = nullptr; + GetLapackDsoHandle(dso_handler); + if (nullptr == *dso_handler) { + LOG(WARNING) << "Failed to find liblapack.so, please specify its path " + "using LD_LIBRARY_PATH."; + } else { + testMatrixInverse(height); + } } } diff --git a/paddle/parameter/FirstOrderOptimizer.h b/paddle/parameter/FirstOrderOptimizer.h index 095019b74f..caa78acd98 100644 --- a/paddle/parameter/FirstOrderOptimizer.h +++ b/paddle/parameter/FirstOrderOptimizer.h @@ -126,7 +126,7 @@ protected: /* * AdaDelta Optimization. * http://www.matthewzeiler.com/pubs/googleTR2012/googleTR2012.pdf -*/ + */ class AdaDeltaParameterOptimizer : public ParameterOptimizer { public: explicit AdaDeltaParameterOptimizer(const OptimizationConfig& optConfig) diff --git a/paddle/trainer/tests/picojson.h b/paddle/trainer/tests/picojson.h index 23bfa16408..4aa64961d0 100644 --- a/paddle/trainer/tests/picojson.h +++ b/paddle/trainer/tests/picojson.h @@ -1059,14 +1059,14 @@ inline bool operator==(const value& x, const value& y) { } inline bool operator!=(const value& x, const value& y) { return !(x == y); } -} +} // namespace picojson namespace std { template <> inline void swap(picojson::value& x, picojson::value& y) { x.swap(y); } -} +} // namespace std inline std::istream& operator>>(std::istream& is, picojson::value& x) { picojson::set_last_error(std::string()); diff --git a/paddle/cuda/src/hl_dso_loader.cc b/paddle/utils/DynamicLoader.cpp similarity index 94% rename from paddle/cuda/src/hl_dso_loader.cc rename to paddle/utils/DynamicLoader.cpp index 53164dd27c..368c35e151 100644 --- a/paddle/cuda/src/hl_dso_loader.cc +++ b/paddle/utils/DynamicLoader.cpp @@ -12,9 +12,9 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. 
*/ -#include "hl_dso_loader.h" +#include "DynamicLoader.h" #include -#include "paddle/utils/Logging.h" +#include "Logging.h" DEFINE_string(cudnn_dir, "", @@ -30,6 +30,8 @@ DEFINE_string(cuda_dir, DEFINE_string(warpctc_dir, "", "Specify path for loading libwarpctc.so."); +DEFINE_string(lapack_dir, "", "Specify path for loading liblapack.so."); + static inline std::string join(const std::string& part1, const std::string& part2) { // directory separator @@ -160,3 +162,11 @@ void GetWarpCTCDsoHandle(void** dso_handle) { GetDsoHandleFromSearchPath(FLAGS_warpctc_dir, "libwarpctc.so", dso_handle); #endif } + +void GetLapackDsoHandle(void** dso_handle) { +#if defined(__APPLE__) || defined(__OSX__) + GetDsoHandleFromSearchPath(FLAGS_lapack_dir, "liblapack.dylib", dso_handle); +#else + GetDsoHandleFromSearchPath(FLAGS_lapack_dir, "liblapack.so", dso_handle); +#endif +} diff --git a/paddle/cuda/include/hl_dso_loader.h b/paddle/utils/DynamicLoader.h similarity index 83% rename from paddle/cuda/include/hl_dso_loader.h rename to paddle/utils/DynamicLoader.h index 276a07d3c7..9b5ad21724 100644 --- a/paddle/cuda/include/hl_dso_loader.h +++ b/paddle/utils/DynamicLoader.h @@ -12,13 +12,13 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. */ -#ifndef HL_DSO_LOADER_H_ -#define HL_DSO_LOADER_H_ +#ifndef DYNAMIC_LOAD_H_ +#define DYNAMIC_LOAD_H_ #include #include +#include #include -#include "hl_base.h" /** * @brief load the DSO of CUBLAS @@ -52,4 +52,12 @@ void GetCurandDsoHandle(void** dso_handle); */ void GetWarpCTCDsoHandle(void** dso_handle); -#endif // HL_DSO_LOADER_H_ +/** + * @brief load the DSO of lapack + * + * @param **dso_handle dso handler + * + */ +void GetLapackDsoHandle(void** dso_handle); + +#endif // DYNAMIC_LOAD_H_
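
GetLapackDsoHandle, declared at the end of DynamicLoader.h above, follows the same convention as the existing CUDA and warp-ctc loaders: try the directory supplied by the new --lapack_dir flag first, then fall back to the dynamic linker's default search path (LD_LIBRARY_PATH and system locations). Below is a rough, self-contained sketch of that lookup; GetDsoHandleFromSearchPathSketch is a simplified stand-in for the real GetDsoHandleFromSearchPath, with the gflags plumbing replaced by a plain string argument.

#include <dlfcn.h>
#include <string>

// Try dlopen() from an explicit directory first, then let the dynamic linker
// search LD_LIBRARY_PATH and the default system locations.
void GetDsoHandleFromSearchPathSketch(const std::string& search_root,
                                      const std::string& dso_name,
                                      void** dso_handle) {
  int flags = RTLD_LAZY | RTLD_GLOBAL;
  *dso_handle = nullptr;
  if (!search_root.empty()) {
    std::string full_path = search_root + "/" + dso_name;
    *dso_handle = dlopen(full_path.c_str(), flags);
  }
  if (*dso_handle == nullptr) {
    // Fall back to the default search path.
    *dso_handle = dlopen(dso_name.c_str(), flags);
  }
}

void GetLapackDsoHandleSketch(void** dso_handle) {
#if defined(__APPLE__) || defined(__OSX__)
  GetDsoHandleFromSearchPathSketch("", "liblapack.dylib", dso_handle);
#else
  GetDsoHandleFromSearchPathSketch("", "liblapack.so", dso_handle);
#endif
}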
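
The DYNAMIC_LOAD_LAPACK_WRAP macro added to MathFunctions.cpp generates one functor per LAPACK routine; the functor resolves the symbol with dlsym on first use, guarded by std::call_once, and then forwards its arguments. A minimal stand-alone version of that pattern for a single routine (LAPACKE_sgetrf is used as the example symbol, and the return type is fixed to int rather than deduced with decltype as the macro does):

#include <dlfcn.h>
#include <mutex>

static std::once_flag lapack_dso_flag;
static void* lapack_dso_handle = nullptr;

// Open liblapack.so exactly once; every wrapper reuses the cached handle.
// Error handling is omitted here; the real loader checks the handle itself.
static void OpenLapackDso() {
  lapack_dso_handle = dlopen("liblapack.so", RTLD_LAZY | RTLD_GLOBAL);
}

struct DynLoadSgetrf {
  template <typename... Args>
  int operator()(Args... args) {
    std::call_once(lapack_dso_flag, OpenLapackDso);
    // Resolve the routine lazily and cast to a matching function pointer.
    using FuncPtr = int (*)(Args...);
    void* sym = dlsym(lapack_dso_handle, "LAPACKE_sgetrf");
    return reinterpret_cast<FuncPtr>(sym)(args...);
  }
} dyn_sgetrf;

// Called exactly like LAPACKE_sgetrf, e.g.
//   int info = dyn_sgetrf(101 /* row major */, m, n, a, lda, ipiv);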
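
For reference, the two routines being dispatched are the standard LAPACK pair behind matrix inversion: getrf computes the LU factorization and getri inverts the matrix from that factorization. A minimal example of the same calls made directly against LAPACKE (as shipped with OpenBLAS), without the dynload indirection:

#include <lapacke.h>
#include <vector>

// Invert an n x n row-major matrix in place: LU factorize, then invert.
// Returns 0 on success, a positive value if the matrix is singular.
int invertInPlace(std::vector<float>& a, int n) {
  std::vector<lapack_int> ipiv(n);
  lapack_int info =
      LAPACKE_sgetrf(LAPACK_ROW_MAJOR, n, n, a.data(), n, ipiv.data());
  if (info != 0) return info;
  return LAPACKE_sgetri(LAPACK_ROW_MAJOR, n, a.data(), n, ipiv.data());
}

// Example: a = {4, 7, 2, 6} (2x2) inverts in place to {0.6, -0.7, -0.2, 0.4}.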
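
Finally, the guard added to test_matrixCompare.cpp makes the inverse test conditional on liblapack actually being loadable at runtime rather than on a compile-time PADDLE_USE_LAPACK flag. Note that GetLapackDsoHandle expects the address of a void* handle; the sketch below shows the intended shape of that check (testMatrixInverse is the helper defined earlier in the same test file):

#include "paddle/utils/DynamicLoader.h"
#include "paddle/utils/Logging.h"

void testMatrixInverse(size_t height);  // helper defined earlier in the test

void runInverseTestIfLapackAvailable(size_t height) {
  void* dso_handle = nullptr;
  GetLapackDsoHandle(&dso_handle);  // fills dso_handle when the DSO is found
  if (dso_handle == nullptr) {
    LOG(WARNING) << "Failed to find liblapack.so, please specify its path "
                    "using LD_LIBRARY_PATH.";
  } else {
    testMatrixInverse(height);
  }
}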