From 45c988d86a43bf34667ce7110972fff8dcaf20de Mon Sep 17 00:00:00 2001
From: sabreshao
Date: Fri, 16 Mar 2018 17:27:19 +0800
Subject: [PATCH 1/5] Demonstration of CMake refinement for HIP support.

1. Add option WITH_AMD_GPU.
2. Add cmake/hip.cmake for the HIP toolchain.
3. Some external modules, such as Eigen, may need a HIP port.
4. Add macros hip_library/hip_binary/hip_test to cmake/generic.cmake.
5. Add one HIP source, concat.hip.cu, as an example. Each .cu may have its
   corresponding .hip.cu.
---
 CMakeLists.txt                             |   9 +
 cmake/configure.cmake                      |  15 +-
 cmake/external/eigen.cmake                 |  43 +++-
 cmake/generic.cmake                        |  76 ++++++
 cmake/hip.cmake                            |  46 ++++
 paddle/fluid/operators/CMakeLists.txt      |   3 +
 paddle/fluid/operators/math/CMakeLists.txt |   6 +
 paddle/fluid/operators/math/concat.hip.cu  | 281 +++++++++++++++++++++
 paddle/fluid/pybind/CMakeLists.txt         |  21 +-
 paddle/scripts/docker/build.sh             |   4 +
 10 files changed, 477 insertions(+), 27 deletions(-)
 create mode 100644 cmake/hip.cmake
 create mode 100644 paddle/fluid/operators/math/concat.hip.cu
 mode change 100644 => 100755 paddle/scripts/docker/build.sh

diff --git a/CMakeLists.txt b/CMakeLists.txt
index 0ec65bac84..399bf50748 100644
--- a/CMakeLists.txt
+++ b/CMakeLists.txt
@@ -36,6 +36,7 @@ include(simd)
 ################################ Configurations #######################################
 option(WITH_GPU         "Compile PaddlePaddle with NVIDIA GPU"     ${CUDA_FOUND})
+option(WITH_AMD_GPU     "Compile PaddlePaddle with AMD GPU"        OFF)
 option(WITH_AVX         "Compile PaddlePaddle with AVX intrinsics" ${AVX_FOUND})
 option(WITH_MKL         "Compile PaddlePaddle with MKL support."   ${AVX_FOUND})
 option(WITH_DSO         "Compile PaddlePaddle with dynamic linked CUDA" ON)
@@ -69,6 +70,9 @@ if(NOT CMAKE_BUILD_TYPE)
       FORCE)
 endif()
 
+if(WITH_AMD_GPU)
+endif()
+
 if(ANDROID OR IOS)
   if(ANDROID)
     if(${CMAKE_SYSTEM_VERSION} VERSION_LESS "16")
@@ -180,6 +184,11 @@ if(WITH_GPU)
   include(cuda)
 endif(WITH_GPU)
 
+if(WITH_AMD_GPU)
+  find_package(HIP)
+  include(hip)
+endif(WITH_AMD_GPU)
+
 if(WITH_MKLML)
   list(APPEND EXTERNAL_LIBS ${MKLML_IOMP_LIB})
 endif()
diff --git a/cmake/configure.cmake b/cmake/configure.cmake
index 0f76f55270..f726405c47 100644
--- a/cmake/configure.cmake
+++ b/cmake/configure.cmake
@@ -57,11 +57,7 @@ if(NOT WITH_GOLANG)
     add_definitions(-DPADDLE_WITHOUT_GOLANG)
 endif(NOT WITH_GOLANG)
 
-if(NOT WITH_GPU)
-    add_definitions(-DHPPL_STUB_FUNC)
-
-    list(APPEND CMAKE_CXX_SOURCE_FILE_EXTENSIONS cu)
-else()
+if(WITH_GPU)
     add_definitions(-DPADDLE_WITH_CUDA)
 
     FIND_PACKAGE(CUDA REQUIRED)
@@ -84,7 +80,14 @@ else()
     # Include cuda and cudnn
     include_directories(${CUDNN_INCLUDE_DIR})
     include_directories(${CUDA_TOOLKIT_INCLUDE})
-endif(NOT WITH_GPU)
+elseif(WITH_AMD_GPU)
+    add_definitions(-DPADDLE_WITH_HIP)
+    set(CMAKE_C_FLAGS "${CMAKE_C_FLAGS} -D__HIP_PLATFORM_HCC__")
+    set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -D__HIP_PLATFORM_HCC__")
+else()
+    add_definitions(-DHPPL_STUB_FUNC)
+    list(APPEND CMAKE_CXX_SOURCE_FILE_EXTENSIONS cu)
+endif()
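Taken together, the configure.cmake branches above give every downstream target a three-way dispatch: CUDA, HIP, or CPU-only stubs. A minimal sketch of a consumer honoring the same switches; the demo_kernels target and demo.* sources are hypothetical, not part of this patch:

    if(WITH_GPU)
      # PADDLE_WITH_CUDA is defined; .cu files go through nvcc.
      nv_library(demo_kernels SRCS demo.cu)
    elseif(WITH_AMD_GPU)
      # PADDLE_WITH_HIP and __HIP_PLATFORM_HCC__ are defined; .hip.cu files
      # go through the HIP toolchain (see the hip_library macro added below).
      hip_library(demo_kernels SRCS demo.hip.cu)
    else()
      # HPPL_STUB_FUNC is defined and .cu is registered as a C++ extension,
      # so the same sources build as plain CPU code.
      cc_library(demo_kernels SRCS demo.cu)
    endif()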
diff --git a/cmake/external/eigen.cmake b/cmake/external/eigen.cmake
index 6a701e076c..5d88c5a0b0 100644
--- a/cmake/external/eigen.cmake
+++ b/cmake/external/eigen.cmake
@@ -1,21 +1,36 @@
 INCLUDE(ExternalProject)
 
 SET(EIGEN_SOURCE_DIR ${THIRD_PARTY_PATH}/eigen3)
-SET(EIGEN_INCLUDE_DIR ${EIGEN_SOURCE_DIR}/src/extern_eigen3)
-INCLUDE_DIRECTORIES(${EIGEN_INCLUDE_DIR})
-ExternalProject_Add(
-    extern_eigen3
-    ${EXTERNAL_PROJECT_LOG_ARGS}
-    GIT_REPOSITORY  "https://github.com/RLovelett/eigen.git"
-    GIT_TAG         70661066beef694cadf6c304d0d07e0758825c10
-    PREFIX          ${EIGEN_SOURCE_DIR}
-    UPDATE_COMMAND  ""
-    CONFIGURE_COMMAND ""
-    BUILD_COMMAND   ""
-    INSTALL_COMMAND ""
-    TEST_COMMAND    ""
-)
+INCLUDE_DIRECTORIES(${EIGEN_SOURCE_DIR}/src/extern_eigen3)
+
+if(WITH_AMD_GPU)
+    ExternalProject_Add(
+        extern_eigen3
+        ${EXTERNAL_PROJECT_LOG_ARGS}
+        GIT_REPOSITORY  "https://github.com/sabreshao/hipeigen.git"
+        GIT_TAG         0cba03ff9f8f9f70bbd92ac5857b031aa8fed6f9
+        PREFIX          ${EIGEN_SOURCE_DIR}
+        UPDATE_COMMAND  ""
+        CONFIGURE_COMMAND ""
+        BUILD_COMMAND   ""
+        INSTALL_COMMAND ""
+        TEST_COMMAND    ""
+    )
+else()
+    ExternalProject_Add(
+        extern_eigen3
+        ${EXTERNAL_PROJECT_LOG_ARGS}
+        GIT_REPOSITORY  "https://github.com/RLovelett/eigen.git"
+        GIT_TAG         70661066beef694cadf6c304d0d07e0758825c10
+        PREFIX          ${EIGEN_SOURCE_DIR}
+        UPDATE_COMMAND  ""
+        CONFIGURE_COMMAND ""
+        BUILD_COMMAND   ""
+        INSTALL_COMMAND ""
+        TEST_COMMAND    ""
+    )
+endif()
 
 if (${CMAKE_VERSION} VERSION_LESS "3.3.0")
     set(dummyfile ${CMAKE_CURRENT_BINARY_DIR}/eigen3_dummy.c)
diff --git a/cmake/generic.cmake b/cmake/generic.cmake
index 471e392906..c749c97f13 100644
--- a/cmake/generic.cmake
+++ b/cmake/generic.cmake
@@ -317,6 +317,82 @@ function(nv_test TARGET_NAME)
   endif()
 endfunction(nv_test)
 
+function(hip_library TARGET_NAME)
+  if (WITH_AMD_GPU)
+    set(options STATIC static SHARED shared)
+    set(oneValueArgs "")
+    set(multiValueArgs SRCS DEPS)
+    cmake_parse_arguments(hip_library "${options}" "${oneValueArgs}" "${multiValueArgs}" ${ARGN})
+    set(_sources ${hip_library_SRCS})
+    HIP_PREPARE_TARGET_COMMANDS(${TARGET_NAME} OBJ _generated_files _source_files ${_sources} HIPCC_OPTIONS ${_hipcc_options} HCC_OPTIONS ${_hcc_options} NVCC_OPTIONS ${_nvcc_options})
+    if(_source_files)
+      list(REMOVE_ITEM _sources ${_source_files})
+    endif()
+    if(hip_library_SRCS)
+      if (hip_library_SHARED OR hip_library_shared) # build *.so
+        add_library(${TARGET_NAME} SHARED ${_cmake_options} ${_generated_files} ${_sources})
+        set_target_properties(${TARGET_NAME} PROPERTIES LINKER_LANGUAGE HIP)
+      else()
+        add_library(${TARGET_NAME} STATIC ${_cmake_options} ${_generated_files} ${_sources})
+        set_target_properties(${TARGET_NAME} PROPERTIES LINKER_LANGUAGE CXX)
+        target_link_libraries(${TARGET_NAME} /opt/rocm/hip/lib/libhip_hcc.so /opt/rocm/hip/lib/libhip_device.a)
+        find_fluid_modules(${TARGET_NAME})
+      endif()
+      if (hip_library_DEPS)
+        add_dependencies(${TARGET_NAME} ${hip_library_DEPS})
+        target_link_libraries(${TARGET_NAME} ${hip_library_DEPS})
+      endif()
+      # cpplint code style
+      foreach(source_file ${hip_library_SRCS})
+        string(REGEX REPLACE "\\.[^.]*$" "" source ${source_file})
+        if(EXISTS ${CMAKE_CURRENT_SOURCE_DIR}/${source}.h)
+          list(APPEND hip_library_HEADERS ${CMAKE_CURRENT_SOURCE_DIR}/${source}.h)
+        endif()
+      endforeach()
+      add_style_check_target(${TARGET_NAME} ${hip_library_SRCS} ${hip_library_HEADERS})
+    else(hip_library_SRCS)
+      if (hip_library_DEPS)
+        merge_static_libs(${TARGET_NAME} ${hip_library_DEPS})
+      else()
+        message(FATAL_ERROR "Please specify source file or library in hip_library.")
+      endif()
+    endif(hip_library_SRCS)
+  endif()
+endfunction(hip_library)
+
+function(hip_binary TARGET_NAME)
+  if (WITH_AMD_GPU)
+    set(options "")
+    set(oneValueArgs "")
+    set(multiValueArgs SRCS DEPS)
+    cmake_parse_arguments(hip_binary "${options}" "${oneValueArgs}" "${multiValueArgs}" ${ARGN})
+    hip_add_executable(${TARGET_NAME} ${hip_binary_SRCS})
+    if(hip_binary_DEPS)
+      target_link_libraries(${TARGET_NAME} ${hip_binary_DEPS})
+      add_dependencies(${TARGET_NAME} ${hip_binary_DEPS})
+    endif()
+  endif()
+endfunction(hip_binary)
+
+function(hip_test TARGET_NAME)
+  if (WITH_AMD_GPU AND WITH_TESTING)
+    set(options "")
+    set(oneValueArgs "")
+    set(multiValueArgs SRCS DEPS)
+    cmake_parse_arguments(hip_test "${options}" "${oneValueArgs}" "${multiValueArgs}" ${ARGN})
+    set(_sources ${hip_test_SRCS})
+    HIP_PREPARE_TARGET_COMMANDS(${TARGET_NAME} OBJ _generated_files _source_files ${_sources} HIPCC_OPTIONS ${_hipcc_options} HCC_OPTIONS ${_hcc_options} NVCC_OPTIONS ${_nvcc_options})
+    if(_source_files)
+      list(REMOVE_ITEM _sources ${_source_files})
+    endif()
+    add_executable(${TARGET_NAME} ${_cmake_options} ${_generated_files} ${_sources})
+    set_target_properties(${TARGET_NAME} PROPERTIES LINKER_LANGUAGE HIP)
+    target_link_libraries(${TARGET_NAME} ${hip_test_DEPS} paddle_gtest_main paddle_memory gtest gflags)
+    add_dependencies(${TARGET_NAME} ${hip_test_DEPS} paddle_gtest_main paddle_memory gtest gflags)
+    add_test(${TARGET_NAME} ${TARGET_NAME})
+  endif()
+endfunction(hip_test)
+
 function(go_library TARGET_NAME)
   set(options STATIC static SHARED shared)
   set(oneValueArgs "")
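A hedged sketch of how the three new macros are meant to be called, mirroring the existing nv_library/nv_test call sites; the target and file names below are illustrative only, not part of the patch:

    # Static library from a HIP-ported source, an executable, and a test.
    hip_library(im2col SRCS im2col.hip.cu DEPS device_context)
    hip_binary(im2col_tool SRCS im2col_tool.cc DEPS im2col)
    hip_test(im2col_hip_test SRCS im2col_test.cc DEPS im2col)

Each macro is a no-op unless WITH_AMD_GPU (and, for hip_test, WITH_TESTING) is ON, so call sites need no additional guards.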
diff --git a/cmake/hip.cmake b/cmake/hip.cmake
new file mode 100644
index 0000000000..cd880603a7
--- /dev/null
+++ b/cmake/hip.cmake
@@ -0,0 +1,46 @@
+if(NOT WITH_AMD_GPU)
+    return()
+endif()
+
+include_directories("/opt/rocm/include")
+include_directories("/opt/rocm/hipblas/include")
+include_directories("/opt/rocm/hiprand/include")
+include_directories("/opt/rocm/rocrand/include")
+include_directories("/opt/rocm/rccl/include")
+include_directories("/opt/rocm/thrust")
+
+list(APPEND EXTERNAL_LIBS "-L/opt/rocm/lib/ -lhip_hcc")
+
+set(HIP_HCC_FLAGS "${HIP_HCC_FLAGS} -fPIC -DPADDLE_WITH_HIP -std=c++14" )
+
+if(WITH_DSO)
+    set(HIP_HCC_FLAGS "${HIP_HCC_FLAGS} -DPADDLE_USE_DSO")
+endif(WITH_DSO)
+
+if(WITH_DOUBLE)
+    set(HIP_HCC_FLAGS "${HIP_HCC_FLAGS} -DPADDLE_TYPE_DOUBLE")
+endif(WITH_DOUBLE)
+
+if(WITH_TESTING)
+    set(HIP_HCC_FLAGS "${HIP_HCC_FLAGS} -DPADDLE_WITH_TESTING")
+endif(WITH_TESTING)
+
+if(CMAKE_BUILD_TYPE STREQUAL "Debug")
+    list(APPEND HIP_HCC_FLAGS ${CMAKE_CXX_FLAGS_DEBUG})
+elseif(CMAKE_BUILD_TYPE STREQUAL "Release")
+    # Disable optimization since one eigen symbol will be removed in math_function.cu
+    #list(APPEND HIP_HCC_FLAGS ${CMAKE_CXX_FLAGS_RELEASE})
+elseif(CMAKE_BUILD_TYPE STREQUAL "RelWithDebInfo")
+    list(APPEND HIP_HCC_FLAGS ${CMAKE_CXX_FLAGS_RELWITHDEBINFO})
+elseif(CMAKE_BUILD_TYPE STREQUAL "MinSizeRel")
+    list(APPEND HIP_HCC_FLAGS ${CMAKE_CXX_FLAGS_MINSIZEREL})
+endif()
+
+if("x${HCC_HOME}" STREQUAL "x")
+    set(HCC_HOME "/opt/rocm/hcc")
+endif()
+
+set(CMAKE_HIP_LINK_EXECUTABLE "${HIP_HIPCC_CMAKE_LINKER_HELPER} ${HCC_HOME} <FLAGS> <CMAKE_CXX_LINK_FLAGS> <LINK_FLAGS> <OBJECTS> -o <TARGET> <LINK_LIBRARIES>")
+set(CMAKE_HIP_CREATE_SHARED_LIBRARY "${HIP_HIPCC_CMAKE_LINKER_HELPER} ${HCC_HOME} <CMAKE_CXX_LINK_FLAGS> <LINK_FLAGS> <OBJECTS> -o <TARGET> <LINK_LIBRARIES> -shared")
+set(CMAKE_HIP_CREATE_SHARED_MODULE "${HIP_HIPCC_CMAKE_LINKER_HELPER} ${HCC_HOME} <CMAKE_CXX_LINK_FLAGS> <LINK_FLAGS> <OBJECTS> -o <TARGET> <LINK_LIBRARIES> -shared")
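Everything in hip.cmake assumes a stock /opt/rocm layout, with HCC_HOME as the only documented override. A hedged sketch of a non-default configuration; the install prefix and the GPU target flag are assumptions about the hcc-era ROCm toolchain, not taken from this patch:

    # In a cache or toolchain file, before Paddle's include(hip):
    set(HCC_HOME "/opt/rocm-1.7/hcc")
    # Builds of that era typically selected the GPU ISA explicitly:
    set(HIP_HCC_FLAGS "${HIP_HCC_FLAGS} --amdgpu-target=gfx900")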
diff --git a/paddle/fluid/operators/CMakeLists.txt b/paddle/fluid/operators/CMakeLists.txt
index d30124d4a3..26d1dab1e9 100644
--- a/paddle/fluid/operators/CMakeLists.txt
+++ b/paddle/fluid/operators/CMakeLists.txt
@@ -76,6 +76,9 @@ function(op_library TARGET)
     if (WITH_GPU)
         nv_library(${TARGET} SRCS ${cc_srcs} ${cu_cc_srcs} ${cudnn_cu_cc_srcs} ${mkldnn_cc_srcs} ${cu_srcs} DEPS ${op_library_DEPS}
             ${op_common_deps})
+    elseif (WITH_AMD_GPU)
+        hip_library(${TARGET} SRCS ${cc_srcs} ${hip_cc_srcs} ${miopen_cu_cc_srcs} ${mkldnn_cc_srcs} ${hip_srcs} DEPS
+            ${op_library_DEPS} ${op_common_deps})
     else()
         cc_library(${TARGET} SRCS ${cc_srcs} ${mkldnn_cc_srcs} DEPS ${op_library_DEPS}
             ${op_common_deps})
diff --git a/paddle/fluid/operators/math/CMakeLists.txt b/paddle/fluid/operators/math/CMakeLists.txt
index fba1612d10..1cac62472c 100644
--- a/paddle/fluid/operators/math/CMakeLists.txt
+++ b/paddle/fluid/operators/math/CMakeLists.txt
@@ -6,6 +6,7 @@ function(math_library TARGET)
     # But it handle split GPU/CPU code and link some common library.
     set(cc_srcs)
     set(cu_srcs)
+    set(hip_srcs)
     set(math_common_deps device_context framework_proto)
     set(multiValueArgs DEPS)
     cmake_parse_arguments(math_library "${options}" "${oneValueArgs}"
@@ -17,10 +18,15 @@ function(math_library TARGET)
     if (EXISTS ${CMAKE_CURRENT_SOURCE_DIR}/${TARGET}.cu)
         list(APPEND cu_srcs ${TARGET}.cu)
     endif()
+    if (EXISTS ${CMAKE_CURRENT_SOURCE_DIR}/${TARGET}.hip.cu)
+        list(APPEND hip_srcs ${TARGET}.hip.cu)
+    endif()
 
     list(LENGTH cc_srcs cc_srcs_len)
     if (WITH_GPU)
         nv_library(${TARGET} SRCS ${cc_srcs} ${cu_srcs} DEPS ${math_library_DEPS} ${math_common_deps})
+    elseif (WITH_AMD_GPU)
+        hip_library(${TARGET} SRCS ${cc_srcs} ${hip_srcs} DEPS ${math_library_DEPS} ${math_common_deps})
     elseif(${cc_srcs_len} GREATER 0)
         cc_library(${TARGET} SRCS ${cc_srcs} DEPS ${math_library_DEPS} ${math_common_deps})
     endif()
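With the change above, a math target gains a HIP build simply by providing a TARGET.hip.cu next to its existing sources; no call-site change is needed. The call below mirrors how targets in this directory are already declared:

    # Collects concat.cc, concat.cu and concat.hip.cu (whichever exist) and
    # dispatches to nv_library, hip_library or cc_library accordingly.
    math_library(concat)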
diff --git a/paddle/fluid/operators/math/concat.hip.cu b/paddle/fluid/operators/math/concat.hip.cu
new file mode 100644
index 0000000000..91efd8ea57
--- /dev/null
+++ b/paddle/fluid/operators/math/concat.hip.cu
@@ -0,0 +1,281 @@
+/* Copyright (c) 2018 paddlepaddle Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License. */
+
+#include "hip/hip_runtime.h"
+#include "paddle/fluid/framework/mixed_vector.h"
+#include "paddle/fluid/operators/math/concat.h"
+#include "paddle/fluid/platform/cuda_helper.h"
+
+namespace paddle {
+namespace operators {
+namespace math {
+
+template <typename T>
+__device__ T upper_bound(const T* first, T count, T val) {
+  const T* orig = first;
+  const T* it = nullptr;
+  T step = 0;
+  while (count > 0) {
+    it = first;
+    step = count / 2;
+    it += step;
+    if (!(val < *it)) {
+      first = ++it;
+      count -= step + 1;
+    } else {
+      count = step;
+    }
+  }
+  return first - orig;
+}
+
+template <typename T>
+__global__ void KernelConcat(T** inputs, const int* input_cols, int col_size,
+                             const int output_rows, const int output_cols,
+                             T* output) {
+  int tid_x = blockIdx.x * blockDim.x + threadIdx.x;
+  int segment = upper_bound<int>(input_cols, col_size, tid_x) - 1;
+
+  int curr_offset = input_cols[segment];
+  int curr_segment = segment;
+  for (; tid_x < output_cols; tid_x += blockDim.x * gridDim.x) {
+    T curr_col_offset;
+    while ((curr_col_offset = input_cols[curr_segment + 1]) <= tid_x) {
+      curr_offset = curr_col_offset;
+      ++curr_segment;
+    }
+
+    int local_col = tid_x - curr_offset;
+    int segment_width = curr_col_offset - curr_offset;
+    T* input_ptr = inputs[curr_segment];
+    int tid_y = blockIdx.y * blockDim.y + threadIdx.y;
+    for (; tid_y < output_rows; tid_y += blockDim.y * gridDim.y)
+      output[tid_y * output_cols + tid_x] =
+          input_ptr[tid_y * segment_width + local_col];
+  }
+}
+
+template <typename T>
+__global__ void KernelConcat(T** inputs, const int input_col,
+                             const int output_rows, const int output_cols,
+                             T* output) {
+  int tid_x = blockIdx.x * blockDim.x + threadIdx.x;
+  double inv_input_col = 1.0 / input_col;
+  for (; tid_x < output_cols; tid_x += blockDim.x * gridDim.x) {
+    int split = tid_x * inv_input_col;
+    int in_offset = tid_x - split * input_col;
+    T* input_ptr = inputs[split];
+    int tid_y = blockIdx.y * blockDim.y + threadIdx.y;
+    for (; tid_y < output_rows; tid_y += blockDim.y * gridDim.y) {
+      output[tid_y * output_cols + tid_x] =
+          input_ptr[tid_y * input_col + in_offset];
+    }
+  }
+}
+
+template <typename T>
+__global__ void KernelConcatGrad(const T* input, const int input_row,
+                                 const int input_col, const int* output_cols,
+                                 int col_size, T** outputs) {
+  int tid_x = blockIdx.x * blockDim.x + threadIdx.x;
+  int segment = upper_bound<int>(output_cols, col_size, tid_x) - 1;
+  int curr_offset = output_cols[segment];
+  int curr_segment = segment;
+  for (; tid_x < input_col; tid_x += blockDim.x * gridDim.x) {
+    T curr_col_offset;
+    while ((curr_col_offset = output_cols[curr_segment + 1]) <= tid_x) {
+      curr_offset = curr_col_offset;
+      ++curr_segment;
+    }
+
+    int local_col = tid_x - curr_offset;
+    int segment_width = curr_col_offset - curr_offset;
+    T* output_ptr = outputs[curr_segment];
+    int tid_y = blockIdx.y * blockDim.y + threadIdx.y;
+    for (; tid_y < input_row; tid_y += blockDim.y * gridDim.y)
+      output_ptr[tid_y * segment_width + local_col] =
+          input[tid_y * input_col + tid_x];
+  }
+}
+
+template <typename T>
+__global__ void KernelConcatGrad(const T* input, const int input_row,
+                                 const int input_col, const int output_cols,
+                                 T** outputs) {
+  int tid_x = blockIdx.x * blockDim.x + threadIdx.x;
+  double inv_input_col = 1.0 / input_col;
+  for (; tid_x < input_col; tid_x += blockDim.x * gridDim.x) {
+    int split = tid_x * inv_input_col;
+    int in_offset = tid_x - split * input_col;
+    T* output_ptr = outputs[split];
+    int tid_y = blockIdx.y * blockDim.y + threadIdx.y;
+    for (; tid_y < input_row; tid_y += blockDim.y * gridDim.y)
+      output_ptr[tid_y * output_cols + in_offset] =
+          input[tid_y * input_col + tid_x];
+  }
+}
+
+/*
+ * All tensors' dimension should be the same and the values of
+ * each dimension are the same, except the axis dimension.
+ */
+template <typename T>
+class ConcatFunctor<platform::CUDADeviceContext, T> {
+ public:
+  void operator()(const platform::CUDADeviceContext& context,
+                  const std::vector<framework::Tensor>& input, const int axis,
+                  framework::Tensor* output) {
+    // TODO(zcd): Add input data validity checking
+    int num = input.size();
+    int rows = 1;
+    auto dim_0 = input[0].dims();
+    for (int i = 0; i < axis; ++i) {
+      rows *= dim_0[i];
+    }
+    int cols = input[0].numel() / rows;
+    int out_rows = rows, out_cols = 0;
+
+    framework::Vector<int16_t> inputs_data(num * sizeof(T*) / 2);
+    framework::Vector<int> inputs_cols(num + 1);
+    inputs_cols[0] = 0;
+    T** inputs_ptr = reinterpret_cast<T**>(inputs_data.data());
+
+    bool sameShape = true;
+    for (int i = 0; i < num; ++i) {
+      int t_cols = input[i].numel() / rows;
+      if (sameShape) {
+        if (t_cols != cols) sameShape = false;
+      }
+      out_cols += t_cols;
+      inputs_cols[i + 1] = out_cols;
+      inputs_ptr[i] = const_cast<T*>(input[i].data<T>());
+    }
+
+    T** ins_gpu =
+        reinterpret_cast<T**>(inputs_data.CUDAMutableData(context.GetPlace()));
+    const int* ins_col_gpu = inputs_cols.CUDAData(context.GetPlace());
+
+    // computation
+    // set the thread block and grid according to CurrentDeviceId
+    const int kThreadsPerBlock = 1024;
+    int block_cols = kThreadsPerBlock;
+    if (out_cols < kThreadsPerBlock) {  // block_cols is aligned by 32.
+      block_cols = ((out_cols + 31) >> 5) << 5;
+    }
+    int block_rows = kThreadsPerBlock / block_cols;
+    dim3 block_size = dim3(block_cols, block_rows, 1);
+
+    int max_threads = context.GetMaxPhysicalThreadCount();
+    int max_blocks = std::max(max_threads / kThreadsPerBlock, 1);
+
+    int grid_cols =
+        std::min((out_cols + block_cols - 1) / block_cols, max_blocks);
+    int grid_rows =
+        std::min(max_blocks / grid_cols, std::max(out_rows / block_rows, 1));
+    dim3 grid_size = dim3(grid_cols, grid_rows, 1);
+
+    if (sameShape) {
+      hipLaunchKernelGGL((KernelConcat<T>), dim3(grid_size), dim3(block_size),
+                         0, context.stream(), ins_gpu, cols, out_rows,
+                         out_cols, output->data<T>());
+    } else {
+      hipLaunchKernelGGL((KernelConcat<T>), dim3(grid_size), dim3(block_size),
+                         0, context.stream(), ins_gpu, ins_col_gpu,
+                         static_cast<int>(inputs_cols.size()), out_rows,
+                         out_cols, output->data<T>());
+    }
+  }
+};
+
+/*
+ * All tensors' dimension should be the same and the values of
+ * each dimension are the same, except the axis dimension.
+ */
+template <typename T>
+class ConcatGradFunctor<platform::CUDADeviceContext, T> {
+ public:
+  void operator()(const platform::CUDADeviceContext& context,
+                  const framework::Tensor& input, const int axis,
+                  std::vector<framework::Tensor>& outputs) {
+    // TODO(zcd): Add input data validity checking
+    int num = outputs.size();
+    int input_row = 1;
+    auto dim_0 = outputs[0].dims();
+    for (int i = 0; i < axis; ++i) {
+      input_row *= dim_0[i];
+    }
+
+    int output_col_0 = outputs[0].numel() / input_row;
+    int input_col = 0;
+    bool sameShape = true;
+
+    framework::Vector<int16_t> outputs_data(num * sizeof(T*) / 2);
+    framework::Vector<int> outputs_cols(num + 1);
+    outputs_cols[0] = 0;
+    T** outputs_ptr = reinterpret_cast<T**>(outputs_data.data());
+
+    for (int i = 0; i < num; ++i) {
+      int t_col = outputs[i].numel() / input_row;
+      if (sameShape) {
+        if (t_col != output_col_0) sameShape = false;
+      }
+      input_col += t_col;
+      outputs_cols[i + 1] = input_col;
+      outputs_ptr[i] = outputs[i].data<T>();
+    }
+
+    T** outs_gpu =
+        reinterpret_cast<T**>(outputs_data.CUDAMutableData(context.GetPlace()));
+    const int* outs_col_gpu = outputs_cols.CUDAData(context.GetPlace());
+
+    // computation
+    const int kThreadsPerBlock = 1024;
+    int block_cols = kThreadsPerBlock;
+    if (input_col < kThreadsPerBlock) {  // block_cols is aligned by 32.
+      block_cols = ((input_col + 31) >> 5) << 5;
+    }
+    int block_rows = kThreadsPerBlock / block_cols;
+    dim3 block_size = dim3(block_cols, block_rows, 1);
+
+    int max_threads = context.GetMaxPhysicalThreadCount();
+    int max_blocks = std::max(max_threads / kThreadsPerBlock, 1);
+
+    int grid_cols =
+        std::min((input_col + block_cols - 1) / block_cols, max_blocks);
+    int grid_rows =
+        std::min(max_blocks / grid_cols, std::max(input_row / block_rows, 1));
+    dim3 grid_size = dim3(grid_cols, grid_rows, 1);
+
+    if (sameShape) {
+      hipLaunchKernelGGL((KernelConcatGrad<T>), dim3(grid_size),
+                         dim3(block_size), 0, context.stream(),
+                         input.data<T>(), input_row, input_col, output_col_0,
+                         outs_gpu);
+    } else {
+      hipLaunchKernelGGL((KernelConcatGrad<T>), dim3(grid_size),
+                         dim3(block_size), 0, context.stream(),
+                         input.data<T>(), input_row, input_col, outs_col_gpu,
+                         static_cast<int>(outputs_cols.size()), outs_gpu);
+    }
+  }
+};
+
+template class ConcatFunctor<platform::CUDADeviceContext, int>;
+template class ConcatFunctor<platform::CUDADeviceContext, int64_t>;
+template class ConcatFunctor<platform::CUDADeviceContext, float>;
+template class ConcatFunctor<platform::CUDADeviceContext, double>;
+
+template class ConcatGradFunctor<platform::CUDADeviceContext, int>;
+template class ConcatGradFunctor<platform::CUDADeviceContext, int64_t>;
+template class ConcatGradFunctor<platform::CUDADeviceContext, float>;
+template class ConcatGradFunctor<platform::CUDADeviceContext, double>;
+
+}  // namespace math
+}  // namespace operators
+}  // namespace paddle
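Because math_library(concat) now picks up this file when WITH_AMD_GPU is ON, a device-side unit test can reuse the new generic macro. A hedged sketch; the concat_hip_test target is hypothetical and not added by this series:

    # hip_test already guards on WITH_AMD_GPU and WITH_TESTING internally.
    hip_test(concat_hip_test SRCS concat_test.cc DEPS concat)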
diff --git a/paddle/fluid/pybind/CMakeLists.txt b/paddle/fluid/pybind/CMakeLists.txt
index 8942b5c943..d523ad7f73 100644
--- a/paddle/fluid/pybind/CMakeLists.txt
+++ b/paddle/fluid/pybind/CMakeLists.txt
@@ -1,9 +1,16 @@
 if(WITH_PYTHON)
-  cc_library(paddle_pybind SHARED
-    SRCS pybind.cc exception.cc protobuf.cc const_value.cc recordio.cc
-    DEPS pybind python backward proto_desc paddle_memory executor prune init profiler feed_fetch_method
-    ${GLOB_OP_LIB})
-  if(NOT APPLE AND NOT ANDROID)
-    target_link_libraries(paddle_pybind rt)
-  endif(NOT APPLE AND NOT ANDROID)
+  if(WITH_AMD_GPU)
+    hip_library(paddle_pybind SHARED
+      SRCS pybind.cc exception.cc protobuf.cc const_value.cc
+      DEPS pybind python backward proto_desc paddle_memory executor prune init profiler feed_fetch_method
+      ${GLOB_OP_LIB})
+  else()
+    cc_library(paddle_pybind SHARED
+      SRCS pybind.cc exception.cc protobuf.cc const_value.cc
+      DEPS pybind python backward proto_desc paddle_memory executor prune init profiler feed_fetch_method
+      ${GLOB_OP_LIB})
+    if(NOT APPLE AND NOT ANDROID)
+      target_link_libraries(paddle_pybind rt)
+    endif(NOT APPLE AND NOT ANDROID)
+  endif(WITH_AMD_GPU)
 endif(WITH_PYTHON)
diff --git a/paddle/scripts/docker/build.sh b/paddle/scripts/docker/build.sh
old mode 100644
new mode 100755
index 6be2bd8fad..02f2d7ba12
--- a/paddle/scripts/docker/build.sh
+++ b/paddle/scripts/docker/build.sh
@@ -37,6 +37,7 @@ function cmake_gen() {
         -DWITH_DSO=ON
         -DWITH_DOC=OFF
         -DWITH_GPU=${WITH_GPU:-OFF}
+        -DWITH_AMD_GPU=${WITH_AMD_GPU:-OFF}
         -DWITH_DISTRIBUTE=${WITH_DISTRIBUTE:-OFF}
         -DWITH_MKL=${WITH_MKL:-ON}
         -DWITH_AVX=${WITH_AVX:-OFF}
@@ -50,6 +51,7 @@ function cmake_gen() {
         -DWITH_STYLE_CHECK=${WITH_STYLE_CHECK:-ON}
         -DWITH_TESTING=${WITH_TESTING:-ON}
        -DWITH_FAST_BUNDLE_TEST=ON
+        -DCMAKE_MODULE_PATH=/opt/rocm/hip/cmake
         -DCMAKE_EXPORT_COMPILE_COMMANDS=ON
     ========================================
 EOF
@@ -62,6 +64,7 @@ EOF
         -DWITH_DSO=ON \
         -DWITH_DOC=OFF \
         -DWITH_GPU=${WITH_GPU:-OFF} \
+        -DWITH_AMD_GPU=${WITH_AMD_GPU:-OFF} \
         -DWITH_DISTRIBUTE=${WITH_DISTRIBUTE:-OFF} \
         -DWITH_MKL=${WITH_MKL:-ON} \
         -DWITH_AVX=${WITH_AVX:-OFF} \
@@ -74,6 +77,7 @@ EOF
         -DWITH_STYLE_CHECK=${WITH_STYLE_CHECK:-ON} \
         -DWITH_TESTING=${WITH_TESTING:-ON} \
         -DWITH_FAST_BUNDLE_TEST=ON \
+        -DCMAKE_MODULE_PATH=/opt/rocm/hip/cmake \
         -DCMAKE_EXPORT_COMPILE_COMMANDS=ON
 }

From e50205e744753f5a6c93f49bd74e00aa7cc642d2 Mon Sep 17 00:00:00 2001
From: sabreshao
Date: Tue, 20 Mar 2018 13:46:48 +0800
Subject: [PATCH 2/5] CMake refinement for HIP support.

1. Add option WITH_AMD_GPU.
2. Add cmake/hip.cmake for the HIP toolchain.
3. Some external modules, such as Eigen, may need a HIP port.
4. Add macros hip_library/hip_binary/hip_test to cmake/generic.cmake.
5. Add one HIP source, concat.hip.cu, as an example. Each .cu may have its
   corresponding .hip.cu.
---
 CMakeLists.txt                            |   3 -
 cmake/external/eigen.cmake                |   4 +-
 cmake/hip.cmake                           |   3 -
 paddle/fluid/operators/CMakeLists.txt     |  33 ++-
 paddle/fluid/operators/math/concat.hip.cu | 268 +---------------------
 paddle/scripts/docker/build.sh            |   4 +-
 6 files changed, 33 insertions(+), 282 deletions(-)

diff --git a/CMakeLists.txt b/CMakeLists.txt
index 399bf50748..1e11f86d0e 100644
--- a/CMakeLists.txt
+++ b/CMakeLists.txt
@@ -70,9 +70,6 @@ if(NOT CMAKE_BUILD_TYPE)
       FORCE)
 endif()
 
-if(WITH_AMD_GPU)
-endif()
-
 if(ANDROID OR IOS)
   if(ANDROID)
     if(${CMAKE_SYSTEM_VERSION} VERSION_LESS "16")
diff --git a/cmake/external/eigen.cmake b/cmake/external/eigen.cmake
index 5d88c5a0b0..73d70c34dc 100644
--- a/cmake/external/eigen.cmake
+++ b/cmake/external/eigen.cmake
@@ -1,8 +1,8 @@
 INCLUDE(ExternalProject)
 
 SET(EIGEN_SOURCE_DIR ${THIRD_PARTY_PATH}/eigen3)
-
-INCLUDE_DIRECTORIES(${EIGEN_SOURCE_DIR}/src/extern_eigen3)
+SET(EIGEN_INCLUDE_DIR ${EIGEN_SOURCE_DIR}/src/extern_eigen3)
+INCLUDE_DIRECTORIES(${EIGEN_INCLUDE_DIR})
 
 if(WITH_AMD_GPU)
     ExternalProject_Add(
diff --git a/cmake/hip.cmake b/cmake/hip.cmake
index cd880603a7..bfe491bd6b 100644
--- a/cmake/hip.cmake
+++ b/cmake/hip.cmake
@@ -27,9 +27,6 @@ endif(WITH_TESTING)
 
 if(CMAKE_BUILD_TYPE STREQUAL "Debug")
     list(APPEND HIP_HCC_FLAGS ${CMAKE_CXX_FLAGS_DEBUG})
-elseif(CMAKE_BUILD_TYPE STREQUAL "Release")
-    # Disable optimization since one eigen symbol will be removed in math_function.cu
-    #list(APPEND HIP_HCC_FLAGS ${CMAKE_CXX_FLAGS_RELEASE})
 elseif(CMAKE_BUILD_TYPE STREQUAL "RelWithDebInfo")
     list(APPEND HIP_HCC_FLAGS ${CMAKE_CXX_FLAGS_RELWITHDEBINFO})
 elseif(CMAKE_BUILD_TYPE STREQUAL "MinSizeRel")
     list(APPEND HIP_HCC_FLAGS ${CMAKE_CXX_FLAGS_MINSIZEREL})
diff --git a/paddle/fluid/operators/CMakeLists.txt b/paddle/fluid/operators/CMakeLists.txt
index 26d1dab1e9..c0245379ac 100644
--- a/paddle/fluid/operators/CMakeLists.txt
+++
b/paddle/fluid/operators/CMakeLists.txt @@ -12,6 +12,8 @@ function(op_library TARGET) set(OP_LIBRARY ${TARGET} ${OP_LIBRARY} PARENT_SCOPE) set(cc_srcs) set(cu_srcs) + set(hip_cu_srcs) + set(miopen_hip_cc_srcs) set(cu_cc_srcs) set(cudnn_cu_cc_srcs) set(CUDNN_FILE) @@ -36,10 +38,19 @@ function(op_library TARGET) if (EXISTS ${CMAKE_CURRENT_SOURCE_DIR}/${TARGET}.cu) list(APPEND cu_srcs ${TARGET}.cu) endif() + if (EXISTS ${CMAKE_CURRENT_SOURCE_DIR}/${TARGET}.hip.cu) + list(APPEND hip_cu_srcs ${TARGET}.hip.cu) + endif() string(REPLACE "_op" "_cudnn_op" CUDNN_FILE "${TARGET}") if (EXISTS ${CMAKE_CURRENT_SOURCE_DIR}/${CUDNN_FILE}.cu.cc) list(APPEND cudnn_cu_cc_srcs ${CUDNN_FILE}.cu.cc) endif() + if(WITH_AMD_GPU) + string(REPLACE "_op" "_miopen_op" MIOPEN_FILE "${TARGET}") + if (EXISTS ${CMAKE_CURRENT_SOURCE_DIR}/${MIOPEN_FILE}.hip.cc) + list(APPEND miopen_hip_cc_srcs ${MIOPEN_FILE}.hip.cc) + endif() + endif() if(WITH_MKLDNN) string(REPLACE "_op" "_mkldnn_op" MKLDNN_FILE "${TARGET}") if (EXISTS ${CMAKE_CURRENT_SOURCE_DIR}/${MKLDNN_FILE}.cc) @@ -48,10 +59,14 @@ function(op_library TARGET) endif() else() foreach(src ${op_library_SRCS}) - if (${src} MATCHES ".*\\.cu$") + if (${src} MATCHES ".*\\.hip.cu$") + list(APPEND hip_cu_srcs ${src}) + elseif (${src} MATCHES ".*\\.cu$") list(APPEND cu_srcs ${src}) elseif(${src} MATCHES ".*_cudnn_op.cu.cc$") list(APPEND cudnn_cu_cc_srcs ${src}) + elseif(WITH_AMD_GPU AND ${src} MATCHES ".*_miopen_op.hip.cc$") + list(APPEND miopen_hip_cc_srcs ${src}) elseif(WITH_MKLDNN AND ${src} MATCHES ".*_mkldnn_op.cc$") list(APPEND mkldnn_cc_srcs ${src}) elseif(${src} MATCHES ".*\\.cu.cc$") @@ -77,8 +92,8 @@ function(op_library TARGET) nv_library(${TARGET} SRCS ${cc_srcs} ${cu_cc_srcs} ${cudnn_cu_cc_srcs} ${mkldnn_cc_srcs} ${cu_srcs} DEPS ${op_library_DEPS} ${op_common_deps}) elseif (WITH_AMD_GPU) - hip_library(${TARGET} SRCS ${cc_srcs} ${hip_cc_srcs} ${miopen_cu_cc_srcs} ${mkldnn_cc_srcs} ${hip_srcs} DEPS - ${op_library_DEPS} ${op_common_deps}) + hip_library(${TARGET} SRCS ${cc_srcs} ${hip_cu_srcs} ${miopen_hip_cc_srcs} ${mkldnn_cc_srcs} DEPS ${op_library_DEPS} + ${op_common_deps}) else() cc_library(${TARGET} SRCS ${cc_srcs} ${mkldnn_cc_srcs} DEPS ${op_library_DEPS} ${op_common_deps}) @@ -91,7 +106,7 @@ function(op_library TARGET) endif() endforeach() - # The registration of USE_OP, please refer to paddle/framework/op_registry.h. + # The registration of USE_OP, please refer to paddle/fluid/framework/op_registry.h. # Note that it's enough to just adding one operator to pybind in a *_op.cc file. # And for detail pybind information, please see generated paddle/pybind/pybind.h. 
file(READ ${TARGET}.cc TARGET_CONTENT) @@ -117,7 +132,10 @@ function(op_library TARGET) list(LENGTH cu_srcs cu_srcs_len) list(LENGTH cu_cc_srcs cu_cc_srcs_len) list(LENGTH mkldnn_cc_srcs mkldnn_cc_srcs_len) - if (${pybind_flag} EQUAL 0 AND ${mkldnn_cc_srcs_len} EQUAL 0 AND ${cu_srcs_len} EQUAL 0 AND ${cu_cc_srcs_len} EQUAL 0) + list(LENGTH hip_cu_srcs hip_cu_srcs_len) + list(LENGTH miopen_hip_cc_srcs miopen_hip_cc_srcs_len) + if (${pybind_flag} EQUAL 0 AND ${mkldnn_cc_srcs_len} EQUAL 0 AND ${cu_srcs_len} EQUAL 0 AND ${cu_cc_srcs_len} EQUAL 0 AND + ${hip_cu_srcs_len} EQUAL 0 AND ${miopen_hip_cc_srcs_len} EQUAL 0) file(APPEND ${pybind_file} "USE_CPU_ONLY_OP(${TARGET});\n") set(pybind_flag 1) endif() @@ -128,6 +146,11 @@ function(op_library TARGET) file(APPEND ${pybind_file} "USE_OP_DEVICE_KERNEL(${TARGET}, CUDNN);\n") endif() + # pybind USE_OP_DEVICE_KERNEL for MIOPEN + if (WITH_AMD_GPU AND ${miopen_hip_cc_srcs_len} GREATER 0) + file(APPEND ${pybind_file} "USE_OP_DEVICE_KERNEL(${TARGET}, MIOPEN);\n") + endif() + # pybind USE_OP_DEVICE_KERNEL for MKLDNN if (WITH_MKLDNN AND ${mkldnn_cc_srcs_len} GREATER 0) file(APPEND ${pybind_file} "USE_OP_DEVICE_KERNEL(${TARGET}, MKLDNN);\n") diff --git a/paddle/fluid/operators/math/concat.hip.cu b/paddle/fluid/operators/math/concat.hip.cu index 91efd8ea57..eacef04388 100644 --- a/paddle/fluid/operators/math/concat.hip.cu +++ b/paddle/fluid/operators/math/concat.hip.cu @@ -12,270 +12,4 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. */ -#include "hip/hip_runtime.h" -#include "paddle/fluid/framework/mixed_vector.h" -#include "paddle/fluid/operators/math/concat.h" -#include "paddle/fluid/platform/cuda_helper.h" - -namespace paddle { -namespace operators { -namespace math { - -template -__device__ T upper_bound(const T* first, T count, T val) { - const T* orig = first; - const T* it = nullptr; - T step = 0; - while (count > 0) { - it = first; - step = count / 2; - it += step; - if (!(val < *it)) { - first = ++it; - count -= step + 1; - } else { - count = step; - } - } - return first - orig; -} - -template -__global__ void KernelConcat(T** inputs, const int* input_cols, int col_size, - const int output_rows, const int output_cols, - T* output) { - int tid_x = blockIdx.x * blockDim.x + threadIdx.x; - int segment = upper_bound(input_cols, col_size, tid_x) - 1; - - int curr_offset = input_cols[segment]; - int curr_segment = segment; - for (; tid_x < output_cols; tid_x += blockDim.x * gridDim.x) { - T curr_col_offset; - while ((curr_col_offset = input_cols[curr_segment + 1]) <= tid_x) { - curr_offset = curr_col_offset; - ++curr_segment; - } - - int local_col = tid_x - curr_offset; - int segment_width = curr_col_offset - curr_offset; - T* input_ptr = inputs[curr_segment]; - int tid_y = blockIdx.y * blockDim.y + threadIdx.y; - for (; tid_y < output_rows; tid_y += blockDim.y * gridDim.y) - output[tid_y * output_cols + tid_x] = - input_ptr[tid_y * segment_width + local_col]; - } -} - -template -__global__ void KernelConcat(T** inputs, const int input_col, - const int output_rows, const int output_cols, - T* output) { - int tid_x = blockIdx.x * blockDim.x + threadIdx.x; - double inv_input_col = 1.0 / input_col; - for (; tid_x < output_cols; tid_x += blockDim.x * gridDim.x) { - int split = tid_x * inv_input_col; - int in_offset = tid_x - split * input_col; - T* input_ptr = inputs[split]; - int tid_y = blockIdx.y * blockDim.y + threadIdx.y; - for 
(; tid_y < output_rows; tid_y += blockDim.y * gridDim.y) { - output[tid_y * output_cols + tid_x] = - input_ptr[tid_y * input_col + in_offset]; - } - } -} - -template -__global__ void KernelConcatGrad(const T* input, const int input_row, - const int input_col, const int* output_cols, - int col_size, T** outputs) { - int tid_x = blockIdx.x * blockDim.x + threadIdx.x; - int segment = upper_bound(output_cols, col_size, tid_x) - 1; - int curr_offset = output_cols[segment]; - int curr_segment = segment; - for (; tid_x < input_col; tid_x += blockDim.x * gridDim.x) { - T curr_col_offset; - while ((curr_col_offset = output_cols[curr_segment + 1]) <= tid_x) { - curr_offset = curr_col_offset; - ++curr_segment; - } - - int local_col = tid_x - curr_offset; - int segment_width = curr_col_offset - curr_offset; - T* output_ptr = outputs[curr_segment]; - int tid_y = blockIdx.y * blockDim.y + threadIdx.y; - for (; tid_y < input_row; tid_y += blockDim.y * gridDim.y) - output_ptr[tid_y * segment_width + local_col] = - input[tid_y * input_col + tid_x]; - } -} - -template -__global__ void KernelConcatGrad(const T* input, const int input_row, - const int input_col, const int output_cols, - T** outputs) { - int tid_x = blockIdx.x * blockDim.x + threadIdx.x; - double inv_input_col = 1.0 / input_col; - for (; tid_x < input_col; tid_x += blockDim.x * gridDim.x) { - int split = tid_x * inv_input_col; - int in_offset = tid_x - split * input_col; - T* output_ptr = outputs[split]; - int tid_y = blockIdx.y * blockDim.y + threadIdx.y; - for (; tid_y < input_row; tid_y += blockDim.y * gridDim.y) - output_ptr[tid_y * output_cols + in_offset] = - input[tid_y * input_col + tid_x]; - } -} - -/* - * All tensors' dimension should be the same and the values of - * each dimension are the same, except the axis dimension. - */ -template -class ConcatFunctor { - public: - void operator()(const platform::CUDADeviceContext& context, - const std::vector& input, const int axis, - framework::Tensor* output) { - // TODO(zcd): Add input data validity checking - int num = input.size(); - int rows = 1; - auto dim_0 = input[0].dims(); - for (int i = 0; i < axis; ++i) { - rows *= dim_0[i]; - } - int cols = input[0].numel() / rows; - int out_rows = rows, out_cols = 0; - - framework::Vector inputs_data(num * sizeof(T*) / 2); - framework::Vector inputs_cols(num + 1); - inputs_cols[0] = 0; - T** inputs_ptr = reinterpret_cast(inputs_data.data()); - - bool sameShape = true; - for (int i = 0; i < num; ++i) { - int t_cols = input[i].numel() / rows; - if (sameShape) { - if (t_cols != cols) sameShape = false; - } - out_cols += t_cols; - inputs_cols[i + 1] = out_cols; - inputs_ptr[i] = const_cast(input[i].data()); - } - - T** ins_gpu = - reinterpret_cast(inputs_data.CUDAMutableData(context.GetPlace())); - const int* ins_col_gpu = inputs_cols.CUDAData(context.GetPlace()); - - // computation - // set the thread block and grid according to CurrentDeviceId - const int kThreadsPerBlock = 1024; - int block_cols = kThreadsPerBlock; - if (out_cols < kThreadsPerBlock) { // block_cols is aligned by 32. 
- block_cols = ((out_cols + 31) >> 5) << 5; - } - int block_rows = kThreadsPerBlock / block_cols; - dim3 block_size = dim3(block_cols, block_rows, 1); - - int max_threads = context.GetMaxPhysicalThreadCount(); - int max_blocks = std::max(max_threads / kThreadsPerBlock, 1); - - int grid_cols = - std::min((out_cols + block_cols - 1) / block_cols, max_blocks); - int grid_rows = - std::min(max_blocks / grid_cols, std::max(out_rows / block_rows, 1)); - dim3 grid_size = dim3(grid_cols, grid_rows, 1); - - if (sameShape) { - hipLaunchKernelGGL((KernelConcat), dim3(grid_size), dim3(block_size), 0, context.stream(), - ins_gpu, cols, out_rows, out_cols, output->data()); - } else { - hipLaunchKernelGGL((KernelConcat), dim3(grid_size), dim3(block_size), 0, context.stream(), - ins_gpu, ins_col_gpu, static_cast(inputs_cols.size()), out_rows, - out_cols, output->data()); - } - } -}; - -/* - * All tensors' dimension should be the same and the values of - * each dimension are the same, except the axis dimension. - */ -template -class ConcatGradFunctor { - public: - void operator()(const platform::CUDADeviceContext& context, - const framework::Tensor& input, const int axis, - std::vector& outputs) { - // TODO(zcd): Add input data validity checking - int num = outputs.size(); - int input_row = 1; - auto dim_0 = outputs[0].dims(); - for (int i = 0; i < axis; ++i) { - input_row *= dim_0[i]; - } - - int output_col_0 = outputs[0].numel() / input_row; - int input_col = 0; - bool sameShape = true; - - framework::Vector outputs_data(num * sizeof(T*) / 2); - framework::Vector outputs_cols(num + 1); - outputs_cols[0] = 0; - T** outputs_ptr = reinterpret_cast(outputs_data.data()); - - for (int i = 0; i < num; ++i) { - int t_col = outputs[i].numel() / input_row; - if (sameShape) { - if (t_col != output_col_0) sameShape = false; - } - input_col += t_col; - outputs_cols[i + 1] = input_col; - outputs_ptr[i] = outputs[i].data(); - } - - T** outs_gpu = - reinterpret_cast(outputs_data.CUDAMutableData(context.GetPlace())); - const int* outs_col_gpu = outputs_cols.CUDAData(context.GetPlace()); - - // computation - const int kThreadsPerBlock = 1024; - int block_cols = kThreadsPerBlock; - if (input_col < kThreadsPerBlock) { // block_cols is aligned by 32. 
-      block_cols = ((input_col + 31) >> 5) << 5;
-    }
-    int block_rows = kThreadsPerBlock / block_cols;
-    dim3 block_size = dim3(block_cols, block_rows, 1);
-
-    int max_threads = context.GetMaxPhysicalThreadCount();
-    int max_blocks = std::max(max_threads / kThreadsPerBlock, 1);
-
-    int grid_cols =
-        std::min((input_col + block_cols - 1) / block_cols, max_blocks);
-    int grid_rows =
-        std::min(max_blocks / grid_cols, std::max(input_row / block_rows, 1));
-    dim3 grid_size = dim3(grid_cols, grid_rows, 1);
-
-    if (sameShape) {
-      hipLaunchKernelGGL((KernelConcatGrad), dim3(grid_size), dim3(block_size), 0, context.stream(),
-          input.data(), input_row, input_col, output_col_0, outs_gpu);
-    } else {
-      hipLaunchKernelGGL((KernelConcatGrad), dim3(grid_size), dim3(block_size), 0, context.stream(),
-          input.data(), input_row, input_col, outs_col_gpu,
-          static_cast(outputs_cols.size()), outs_gpu);
-    }
-  }
-};
-
-template class ConcatFunctor;
-template class ConcatFunctor;
-template class ConcatFunctor;
-template class ConcatFunctor;
-
-template class ConcatGradFunctor;
-template class ConcatGradFunctor;
-template class ConcatGradFunctor;
-template class ConcatGradFunctor;
-
-}  // namespace math
-}  // namespace operators
-}  // namespace paddle
+#include
diff --git a/paddle/scripts/docker/build.sh b/paddle/scripts/docker/build.sh
index 02f2d7ba12..a0fc391c7c 100755
--- a/paddle/scripts/docker/build.sh
+++ b/paddle/scripts/docker/build.sh
@@ -51,7 +51,7 @@ function cmake_gen() {
         -DWITH_STYLE_CHECK=${WITH_STYLE_CHECK:-ON}
         -DWITH_TESTING=${WITH_TESTING:-ON}
         -DWITH_FAST_BUNDLE_TEST=ON
-        -DCMAKE_MODULE_PATH=/opt/rocm/hip/cmake 
+        -DCMAKE_MODULE_PATH=/opt/rocm/hip/cmake
         -DCMAKE_EXPORT_COMPILE_COMMANDS=ON
     ========================================
 EOF
@@ -77,7 +77,7 @@ EOF
         -DWITH_STYLE_CHECK=${WITH_STYLE_CHECK:-ON} \
         -DWITH_TESTING=${WITH_TESTING:-ON} \
         -DWITH_FAST_BUNDLE_TEST=ON \
-        -DCMAKE_MODULE_PATH=/opt/rocm/hip/cmake  \
+        -DCMAKE_MODULE_PATH=/opt/rocm/hip/cmake \
         -DCMAKE_EXPORT_COMPILE_COMMANDS=ON
 }

From e0ac6bc436725a7750b46a674b97b89cccdef36b Mon Sep 17 00:00:00 2001
From: sabreshao
Date: Thu, 22 Mar 2018 10:48:27 +0800
Subject: [PATCH 3/5] CMake refinement for HIP support. Fix CI.

---
 paddle/fluid/pybind/CMakeLists.txt | 4 ++--
 1 file changed, 2 insertions(+), 2 deletions(-)

diff --git a/paddle/fluid/pybind/CMakeLists.txt b/paddle/fluid/pybind/CMakeLists.txt
index d523ad7f73..fe991033df 100644
--- a/paddle/fluid/pybind/CMakeLists.txt
+++ b/paddle/fluid/pybind/CMakeLists.txt
@@ -1,12 +1,12 @@
 if(WITH_PYTHON)
   if(WITH_AMD_GPU)
     hip_library(paddle_pybind SHARED
-      SRCS pybind.cc exception.cc protobuf.cc const_value.cc
+      SRCS pybind.cc exception.cc protobuf.cc const_value.cc recordio.cc
       DEPS pybind python backward proto_desc paddle_memory executor prune init profiler feed_fetch_method
       ${GLOB_OP_LIB})
   else()
     cc_library(paddle_pybind SHARED
-      SRCS pybind.cc exception.cc protobuf.cc const_value.cc
+      SRCS pybind.cc exception.cc protobuf.cc const_value.cc recordio.cc
       DEPS pybind python backward proto_desc paddle_memory executor prune init profiler feed_fetch_method
       ${GLOB_OP_LIB})
     if(NOT APPLE AND NOT ANDROID)

From 8f8728635a028e5ef69498cae109366302a048ee Mon Sep 17 00:00:00 2001
From: qingqing01
Date: Thu, 22 Mar 2018 17:00:06 +0800
Subject: [PATCH 4/5] Fix bug for backward transpiler when using parallel_do
 operator. (#9282)

* Temporarily fix bug for backward transpiler when using parallel_do operator.
* Fix bug for backward transpiler when using parallel_do operator
---
 paddle/fluid/operators/box_coder_op.cc          | 3 ++-
 paddle/fluid/operators/detection_map_op.cc      | 4 ++--
 paddle/fluid/operators/iou_similarity_op.cc     | 5 +++--
 paddle/fluid/operators/mine_hard_examples_op.cc | 5 +++--
 paddle/fluid/operators/prior_box_op.cc          | 4 +++-
 paddle/fluid/operators/target_assign_op.cc      | 4 ++--
 python/paddle/fluid/layers/detection.py         | 7 +++++--
 7 files changed, 20 insertions(+), 12 deletions(-)

diff --git a/paddle/fluid/operators/box_coder_op.cc b/paddle/fluid/operators/box_coder_op.cc
index eccdd408a1..ec416f725e 100644
--- a/paddle/fluid/operators/box_coder_op.cc
+++ b/paddle/fluid/operators/box_coder_op.cc
@@ -126,6 +126,7 @@ width and height.
 }  // namespace paddle
 
 namespace ops = paddle::operators;
-REGISTER_OP_WITHOUT_GRADIENT(box_coder, ops::BoxCoderOp, ops::BoxCoderOpMaker);
+REGISTER_OPERATOR(box_coder, ops::BoxCoderOp, ops::BoxCoderOpMaker,
+                  paddle::framework::EmptyGradOpMaker);
 REGISTER_OP_CPU_KERNEL(box_coder, ops::BoxCoderKernel<float>,
                        ops::BoxCoderKernel<double>);
diff --git a/paddle/fluid/operators/detection_map_op.cc b/paddle/fluid/operators/detection_map_op.cc
index 73c84c2fe0..93ef15b933 100644
--- a/paddle/fluid/operators/detection_map_op.cc
+++ b/paddle/fluid/operators/detection_map_op.cc
@@ -188,8 +188,8 @@ The general steps are as follows. First, calculate the true positive and
 }  // namespace paddle
 
 namespace ops = paddle::operators;
-REGISTER_OP_WITHOUT_GRADIENT(detection_map, ops::DetectionMAPOp,
-                             ops::DetectionMAPOpMaker);
+REGISTER_OPERATOR(detection_map, ops::DetectionMAPOp, ops::DetectionMAPOpMaker,
+                  paddle::framework::EmptyGradOpMaker);
 REGISTER_OP_CPU_KERNEL(
     detection_map, ops::DetectionMAPOpKernel<float>,
     ops::DetectionMAPOpKernel<double>);
diff --git a/paddle/fluid/operators/iou_similarity_op.cc b/paddle/fluid/operators/iou_similarity_op.cc
index ffbd7c7814..4b78ec510d 100755
--- a/paddle/fluid/operators/iou_similarity_op.cc
+++ b/paddle/fluid/operators/iou_similarity_op.cc
@@ -87,8 +87,9 @@ $$
 }  // namespace paddle
 
 namespace ops = paddle::operators;
-REGISTER_OP_WITHOUT_GRADIENT(iou_similarity, ops::IOUSimilarityOp,
-                             ops::IOUSimilarityOpMaker);
+REGISTER_OPERATOR(iou_similarity, ops::IOUSimilarityOp,
+                  ops::IOUSimilarityOpMaker,
+                  paddle::framework::EmptyGradOpMaker);
 
 REGISTER_OP_CPU_KERNEL(
     iou_similarity,
diff --git a/paddle/fluid/operators/mine_hard_examples_op.cc b/paddle/fluid/operators/mine_hard_examples_op.cc
index 0e81d60878..277901cff4 100644
--- a/paddle/fluid/operators/mine_hard_examples_op.cc
+++ b/paddle/fluid/operators/mine_hard_examples_op.cc
@@ -324,8 +324,9 @@ MatchIndices elements with value -1.
 }  // namespace paddle
 
 namespace ops = paddle::operators;
-REGISTER_OP_WITHOUT_GRADIENT(mine_hard_examples, ops::MineHardExamplesOp,
-                             ops::MineHardExamplesOpMaker);
+REGISTER_OPERATOR(mine_hard_examples, ops::MineHardExamplesOp,
+                  ops::MineHardExamplesOpMaker,
+                  paddle::framework::EmptyGradOpMaker);
 
 REGISTER_OP_CPU_KERNEL(
     mine_hard_examples,
diff --git a/paddle/fluid/operators/prior_box_op.cc b/paddle/fluid/operators/prior_box_op.cc
index 7ba55437cb..c22a55bce2 100644
--- a/paddle/fluid/operators/prior_box_op.cc
+++ b/paddle/fluid/operators/prior_box_op.cc
@@ -168,7 +168,9 @@ https://arxiv.org/abs/1512.02325.
} // namespace paddle namespace ops = paddle::operators; -REGISTER_OP_WITHOUT_GRADIENT(prior_box, ops::PriorBoxOp, ops::PriorBoxOpMaker); +REGISTER_OPERATOR(prior_box, ops::PriorBoxOp, ops::PriorBoxOpMaker, + paddle::framework::EmptyGradOpMaker); + REGISTER_OP_CPU_KERNEL( prior_box, ops::PriorBoxOpKernel, ops::PriorBoxOpKernel); diff --git a/paddle/fluid/operators/target_assign_op.cc b/paddle/fluid/operators/target_assign_op.cc index a894b12fa3..33ff967e5e 100644 --- a/paddle/fluid/operators/target_assign_op.cc +++ b/paddle/fluid/operators/target_assign_op.cc @@ -153,8 +153,8 @@ template struct NegTargetAssignFunctor, diff --git a/python/paddle/fluid/layers/detection.py b/python/paddle/fluid/layers/detection.py index a889ab6bdc..cd519e1ee0 100644 --- a/python/paddle/fluid/layers/detection.py +++ b/python/paddle/fluid/layers/detection.py @@ -129,13 +129,11 @@ def detection_output(loc, prior_box_var=prior_box_var, target_box=loc, code_type='decode_center_size') - old_shape = scores.shape scores = ops.reshape(x=scores, shape=(-1, old_shape[-1])) scores = nn.softmax(input=scores) scores = ops.reshape(x=scores, shape=old_shape) scores = nn.transpose(scores, perm=[0, 2, 1]) - nmsed_outs = helper.create_tmp_variable(dtype=decoded_box.dtype) helper.append_op( type="multiclass_nms", @@ -475,6 +473,7 @@ def ssd_loss(location, # 2. Compute confidence for mining hard examples # 2.1. Get the target label based on matched indices gt_label = ops.reshape(x=gt_label, shape=gt_label.shape + (1, )) + gt_label.stop_gradient = True target_label, _ = target_assign( gt_label, matched_indices, mismatch_value=background_label) # 2.2. Compute confidence loss. @@ -482,10 +481,12 @@ def ssd_loss(location, confidence = __reshape_to_2d(confidence) target_label = tensor.cast(x=target_label, dtype='int64') target_label = __reshape_to_2d(target_label) + target_label.stop_gradient = True conf_loss = nn.softmax_with_cross_entropy(confidence, target_label) # 3. 
Mining hard examples conf_loss = ops.reshape(x=conf_loss, shape=(num, num_prior)) + conf_loss.stop_gradient = True neg_indices = helper.create_tmp_variable(dtype='int32') dtype = matched_indices.dtype updated_matched_indices = helper.create_tmp_variable(dtype=dtype) @@ -695,6 +696,8 @@ def multi_box_head(inputs, outputs={"Boxes": box, "Variances": var}, attrs=attrs, ) + box.stop_gradient = True + var.stop_gradient = True return box, var def _reshape_with_axis_(input, axis=1): From ee7f1ecd7cb79d34a7f14a45d4c34e4e6db9b7af Mon Sep 17 00:00:00 2001 From: Yancey Date: Thu, 22 Mar 2018 19:21:43 +0800 Subject: [PATCH 5/5] Fix dist compile error (#9320) --- .../operators/detail/bytebuffer_stream.h | 5 +++-- paddle/fluid/operators/detail/grpc_server.h | 2 -- paddle/fluid/operators/detail/test_serde.cc | 21 +++++++++---------- .../operators/detail/variable_response.h | 4 ++-- 4 files changed, 15 insertions(+), 17 deletions(-) diff --git a/paddle/fluid/operators/detail/bytebuffer_stream.h b/paddle/fluid/operators/detail/bytebuffer_stream.h index 0cbe514d04..1791a48aab 100644 --- a/paddle/fluid/operators/detail/bytebuffer_stream.h +++ b/paddle/fluid/operators/detail/bytebuffer_stream.h @@ -146,8 +146,9 @@ class GrpcByteBufferSource class GrpcByteBufferSourceWrapper : public Source { public: - GrpcByteBufferSourceWrapper(GrpcByteBufferSource* source) : source_(source) {} - virtual ::google::protobuf::io::ZeroCopyInputStream* contents() override { + explicit GrpcByteBufferSourceWrapper(GrpcByteBufferSource* source) + : source_(source) {} + ::google::protobuf::io::ZeroCopyInputStream* contents() override { return source_; } diff --git a/paddle/fluid/operators/detail/grpc_server.h b/paddle/fluid/operators/detail/grpc_server.h index 9c21a07432..5c278f0ed7 100644 --- a/paddle/fluid/operators/detail/grpc_server.h +++ b/paddle/fluid/operators/detail/grpc_server.h @@ -29,8 +29,6 @@ limitations under the License. 
*/ #include "paddle/fluid/operators/detail/grpc_service.h" -//#include - namespace paddle { namespace operators { namespace detail { diff --git a/paddle/fluid/operators/detail/test_serde.cc b/paddle/fluid/operators/detail/test_serde.cc index 4be5963794..99c1577223 100644 --- a/paddle/fluid/operators/detail/test_serde.cc +++ b/paddle/fluid/operators/detail/test_serde.cc @@ -81,7 +81,7 @@ void RunSerdeTestSelectedRows(platform::Place place) { // operators::detail::DeserializeFromByteBuffer(msg, ctx, &var2); framework::Scope scope; scope.Var("myvar"); - operators::detail::TensorResponse resp(&scope, &ctx); + operators::detail::VariableResponse resp(&scope, &ctx); EXPECT_EQ(resp.Parse(msg), 0); framework::Variable* var2 = resp.GetVar(); @@ -166,7 +166,7 @@ void RunTestLodTensor(platform::Place place, int from_type = 0) { // deserialize zero-copy framework::Scope scope; scope.Var("myvar"); - operators::detail::TensorResponse resp(&scope, &ctx); + operators::detail::VariableResponse resp(&scope, &ctx); if (from_type == 0) { EXPECT_EQ(resp.Parse(msg), 0); } else { @@ -194,24 +194,23 @@ void RunTestLodTensor(platform::Place place, int from_type = 0) { for (int i = 0; i < tensor_numel; ++i) EXPECT_FLOAT_EQ(tensor_data2[i], 31.9); } -TEST(LodTensor, GPU) { - platform::CUDAPlace place; +TEST(LodTensor, Run) { + platform::CPUPlace place; RunTestLodTensor(place); RunTestLodTensor(place, 1); -} - -TEST(LodTensor, CPU) { - platform::CPUPlace place; +#ifdef PADDLE_WITH_CUDA + platform::CUDAPlace place; RunTestLodTensor(place); RunTestLodTensor(place, 1); +#endif } -TEST(SelectedRows, CPU) { +TEST(SelectedRows, Run) { platform::CPUPlace place; RunSerdeTestSelectedRows(place); -} -TEST(SelectedRows, GPU) { +#ifdef PADDLE_WITH_CUDA platform::CUDAPlace place; RunSerdeTestSelectedRows(place); +#endif } diff --git a/paddle/fluid/operators/detail/variable_response.h b/paddle/fluid/operators/detail/variable_response.h index c7bc7a46e7..e121ed7bce 100644 --- a/paddle/fluid/operators/detail/variable_response.h +++ b/paddle/fluid/operators/detail/variable_response.h @@ -36,9 +36,9 @@ class VariableResponse { public: VariableResponse(const framework::Scope* scope, const platform::DeviceContext* dev_ctx) - : scope_(scope), dev_ctx_(dev_ctx){}; + : scope_(scope), dev_ctx_(dev_ctx) {} - virtual ~VariableResponse(){}; + virtual ~VariableResponse() {} // return: // 0:ok.