commit bc3ec53671
@@ -0,0 +1,188 @@
if(NOT WITH_GPU)
  return()
endif()

set(paddle_known_gpu_archs "30 35 50 52 60 61 70")
set(paddle_known_gpu_archs7 "30 35 50 52")
set(paddle_known_gpu_archs8 "30 35 50 52 60 61")

######################################################################################
# A function for automatic detection of GPUs installed (if autodetection is enabled)
# Usage:
#   detect_installed_gpus(out_variable)
function(detect_installed_gpus out_variable)
  if(NOT CUDA_gpu_detect_output)
    set(cufile ${PROJECT_BINARY_DIR}/detect_cuda_archs.cu)

    file(WRITE ${cufile} ""
      "#include <cstdio>\n"
      "int main() {\n"
      "  int count = 0;\n"
      "  if (cudaSuccess != cudaGetDeviceCount(&count)) return -1;\n"
      "  if (count == 0) return -1;\n"
      "  for (int device = 0; device < count; ++device) {\n"
      "    cudaDeviceProp prop;\n"
      "    if (cudaSuccess == cudaGetDeviceProperties(&prop, device))\n"
      "      std::printf(\"%d.%d \", prop.major, prop.minor);\n"
      "  }\n"
      "  return 0;\n"
      "}\n")

    execute_process(COMMAND "${CUDA_NVCC_EXECUTABLE}" "-ccbin=${CUDA_HOST_COMPILER}"
                    "--run" "${cufile}"
                    WORKING_DIRECTORY "${PROJECT_BINARY_DIR}/CMakeFiles/"
                    RESULT_VARIABLE nvcc_res OUTPUT_VARIABLE nvcc_out
                    ERROR_QUIET OUTPUT_STRIP_TRAILING_WHITESPACE)

    if(nvcc_res EQUAL 0)
      # Only keep the last line of nvcc_out.
      string(REGEX REPLACE ";" "\\\\;" nvcc_out "${nvcc_out}")
      string(REGEX REPLACE "\n" ";" nvcc_out "${nvcc_out}")
      list(GET nvcc_out -1 nvcc_out)
      # An sm_21 device reports "2.1"; build its binary against the 2.0 PTX ISA.
      string(REPLACE "2.1" "2.1(2.0)" nvcc_out "${nvcc_out}")
      set(CUDA_gpu_detect_output ${nvcc_out} CACHE INTERNAL "Returned GPU architectures from detect_installed_gpus tool" FORCE)
    endif()
  endif()

  if(NOT CUDA_gpu_detect_output)
    message(STATUS "Automatic GPU detection failed. Building for all known architectures.")
    set(${out_variable} ${paddle_known_gpu_archs} PARENT_SCOPE)
  else()
    set(${out_variable} ${CUDA_gpu_detect_output} PARENT_SCOPE)
  endif()
endfunction()
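
# A minimal usage sketch (the variable name `archs` is illustrative, not part
# of this commit):
#   detect_installed_gpus(archs)
#   message(STATUS "Detected GPU compute capabilities: ${archs}")
# If the probe cannot be compiled or run, `archs` falls back to
# ${paddle_known_gpu_archs}.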

########################################################################
# Function for selecting GPU arch flags for nvcc based on CUDA_ARCH_NAME
# Usage:
#   select_nvcc_arch_flags(out_variable)
function(select_nvcc_arch_flags out_variable)
  # List of arch names ("Volta" added here so the branch below is reachable)
  set(archs_names "Kepler" "Maxwell" "Pascal" "Volta" "All" "Manual")
  set(archs_name_default "All")
  if(NOT CMAKE_CROSSCOMPILING)
    list(APPEND archs_names "Auto")
  endif()

  # Set CUDA_ARCH_NAME strings (so it will be seen as a drop-down in CMake-GUI)
  set(CUDA_ARCH_NAME ${archs_name_default} CACHE STRING "Select target NVIDIA GPU architecture.")
  set_property(CACHE CUDA_ARCH_NAME PROPERTY STRINGS "" ${archs_names})
  mark_as_advanced(CUDA_ARCH_NAME)

  # Verify CUDA_ARCH_NAME value
  if(NOT ";${archs_names};" MATCHES ";${CUDA_ARCH_NAME};")
    string(REPLACE ";" ", " archs_names "${archs_names}")
    message(FATAL_ERROR "Only ${archs_names} architecture names are supported.")
  endif()

  if(${CUDA_ARCH_NAME} STREQUAL "Manual")
    set(CUDA_ARCH_BIN ${paddle_known_gpu_archs} CACHE STRING "Specify 'real' GPU architectures to build binaries for, BIN(PTX) format is supported")
    set(CUDA_ARCH_PTX "50" CACHE STRING "Specify 'virtual' PTX architectures to build PTX intermediate code for")
    mark_as_advanced(CUDA_ARCH_BIN CUDA_ARCH_PTX)
  else()
    unset(CUDA_ARCH_BIN CACHE)
    unset(CUDA_ARCH_PTX CACHE)
  endif()

  if(${CUDA_ARCH_NAME} STREQUAL "Kepler")
    set(cuda_arch_bin "30 35")
  elseif(${CUDA_ARCH_NAME} STREQUAL "Maxwell")
    set(cuda_arch_bin "50")
  elseif(${CUDA_ARCH_NAME} STREQUAL "Pascal")
    set(cuda_arch_bin "60 61")
  elseif(${CUDA_ARCH_NAME} STREQUAL "Volta")
    set(cuda_arch_bin "70")
  elseif(${CUDA_ARCH_NAME} STREQUAL "All")
    set(cuda_arch_bin ${paddle_known_gpu_archs})
  elseif(${CUDA_ARCH_NAME} STREQUAL "Auto")
    detect_installed_gpus(cuda_arch_bin)
  else()  # (${CUDA_ARCH_NAME} STREQUAL "Manual")
    set(cuda_arch_bin ${CUDA_ARCH_BIN})
  endif()

  # Remove dots and convert to lists
  string(REGEX REPLACE "\\." "" cuda_arch_bin "${cuda_arch_bin}")
  string(REGEX REPLACE "\\." "" cuda_arch_ptx "${CUDA_ARCH_PTX}")
  string(REGEX MATCHALL "[0-9()]+" cuda_arch_bin "${cuda_arch_bin}")
  string(REGEX MATCHALL "[0-9]+" cuda_arch_ptx "${cuda_arch_ptx}")
  list(REMOVE_DUPLICATES cuda_arch_bin)
  list(REMOVE_DUPLICATES cuda_arch_ptx)

  set(nvcc_flags "")
  set(nvcc_archs_readable "")

  # Tell NVCC to add binaries for the specified GPUs
  foreach(arch ${cuda_arch_bin})
    if(arch MATCHES "([0-9]+)\\(([0-9]+)\\)")
      # User explicitly specified PTX for the concrete BIN
      list(APPEND nvcc_flags -gencode arch=compute_${CMAKE_MATCH_2},code=sm_${CMAKE_MATCH_1})
      list(APPEND nvcc_archs_readable sm_${CMAKE_MATCH_1})
    else()
      # User didn't explicitly specify PTX for the concrete BIN, so assume PTX=BIN
      list(APPEND nvcc_flags -gencode arch=compute_${arch},code=sm_${arch})
      list(APPEND nvcc_archs_readable sm_${arch})
    endif()
  endforeach()

  # Tell NVCC to add PTX intermediate code for the specified architectures
  foreach(arch ${cuda_arch_ptx})
    list(APPEND nvcc_flags -gencode arch=compute_${arch},code=compute_${arch})
    list(APPEND nvcc_archs_readable compute_${arch})
  endforeach()

  string(REPLACE ";" " " nvcc_archs_readable "${nvcc_archs_readable}")
  set(${out_variable} ${nvcc_flags} PARENT_SCOPE)
  set(${out_variable}_readable ${nvcc_archs_readable} PARENT_SCOPE)
endfunction()
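
# Illustration (derived from the loops above, with hypothetical inputs): for
# cuda_arch_bin "60 61" and cuda_arch_ptx "50", ${out_variable} holds
#   -gencode arch=compute_60,code=sm_60
#   -gencode arch=compute_61,code=sm_61
#   -gencode arch=compute_50,code=compute_50
# and ${out_variable}_readable holds "sm_60 sm_61 compute_50".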

message(STATUS "CUDA detected: " ${CUDA_VERSION})
if(${CUDA_VERSION} LESS 7.0)
  # CUDA < 7.x: keep the default list unchanged.
  set(paddle_known_gpu_archs ${paddle_known_gpu_archs})
elseif(${CUDA_VERSION} LESS 8.0)  # CUDA 7.x
  set(paddle_known_gpu_archs ${paddle_known_gpu_archs7})
  list(APPEND CUDA_NVCC_FLAGS "-D_MWAITXINTRIN_H_INCLUDED")
  list(APPEND CUDA_NVCC_FLAGS "-D__STRICT_ANSI__")
elseif(${CUDA_VERSION} LESS 9.0)  # CUDA 8.x
  set(paddle_known_gpu_archs ${paddle_known_gpu_archs8})
  list(APPEND CUDA_NVCC_FLAGS "-D_MWAITXINTRIN_H_INCLUDED")
  list(APPEND CUDA_NVCC_FLAGS "-D__STRICT_ANSI__")
  # CUDA 8 may complain that sm_20 is no longer supported. Suppress the
  # warning for now.
  list(APPEND CUDA_NVCC_FLAGS "-Wno-deprecated-gpu-targets")
endif()

include_directories(${CUDA_INCLUDE_DIRS})
list(APPEND EXTERNAL_LIBS ${CUDA_LIBRARIES} ${CUDA_rt_LIBRARY})
if(NOT WITH_DSO)
  list(APPEND EXTERNAL_LIBS ${CUDNN_LIBRARY} ${CUDA_CUBLAS_LIBRARIES} ${CUDA_curand_LIBRARY} ${NCCL_LIBRARY})
endif(NOT WITH_DSO)

# Set nvcc arch flags
select_nvcc_arch_flags(NVCC_FLAGS_EXTRA)
list(APPEND CUDA_NVCC_FLAGS ${NVCC_FLAGS_EXTRA})
message(STATUS "Added CUDA NVCC flags for: ${NVCC_FLAGS_EXTRA_readable}")

# Set C++11 support
set(CUDA_PROPAGATE_HOST_FLAGS OFF)

# Release/Debug flags (e.g. -O3, -g, -DNDEBUG) are already set by cmake,
# so don't set them here.
list(APPEND CUDA_NVCC_FLAGS "-std=c++11")
list(APPEND CUDA_NVCC_FLAGS "--use_fast_math")
list(APPEND CUDA_NVCC_FLAGS "-Xcompiler -fPIC")
# Set --expt-relaxed-constexpr to suppress Eigen warnings
list(APPEND CUDA_NVCC_FLAGS "--expt-relaxed-constexpr")

if(CMAKE_BUILD_TYPE STREQUAL "Debug")
  list(APPEND CUDA_NVCC_FLAGS ${CMAKE_CXX_FLAGS_DEBUG})
elseif(CMAKE_BUILD_TYPE STREQUAL "Release")
  list(APPEND CUDA_NVCC_FLAGS ${CMAKE_CXX_FLAGS_RELEASE})
elseif(CMAKE_BUILD_TYPE STREQUAL "RelWithDebInfo")
  list(APPEND CUDA_NVCC_FLAGS ${CMAKE_CXX_FLAGS_RELWITHDEBINFO})
elseif(CMAKE_BUILD_TYPE STREQUAL "MinSizeRel")
  list(APPEND CUDA_NVCC_FLAGS ${CMAKE_CXX_FLAGS_MINSIZEREL})
endif()

mark_as_advanced(CUDA_BUILD_CUBIN CUDA_BUILD_EMULATION CUDA_VERBOSE_BUILD)
mark_as_advanced(CUDA_SDK_ROOT_DIR CUDA_SEPARABLE_COMPILATION)
@@ -0,0 +1,97 @@
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */

#include "Layer.h"
#include "paddle/math/Matrix.h"
#include "paddle/utils/Logging.h"
#include "paddle/utils/Stat.h"

namespace paddle {

/**
 * @brief A layer for computing the dot product of two vectors.
 * Input1: vector (batchSize * dim)
 * Input2: vector (batchSize * dim)
 * Output: a matrix of size (batchSize * 1)
 */

class DotProdLayer : public Layer {
public:
  explicit DotProdLayer(const LayerConfig& config) : Layer(config) {}

  ~DotProdLayer() {}

  bool init(const LayerMap& layerMap,
            const ParameterMap& parameterMap) override;

  void forward(PassType passType) override;
  void backward(const UpdateCallback& callback = nullptr) override;
};

REGISTER_LAYER(dot_prod, DotProdLayer);

bool DotProdLayer::init(const LayerMap& layerMap,
                        const ParameterMap& parameterMap) {
  Layer::init(layerMap, parameterMap);

  CHECK_EQ(inputLayers_.size(), 2U);
  CHECK_EQ(1UL, getSize())
      << "The output dimensionality of this layer should be fixed to 1.";

  return true;
}

void DotProdLayer::forward(PassType passType) {
  Layer::forward(passType);

  MatrixPtr inV0 = getInputValue(0);
  MatrixPtr inV1 = getInputValue(1);

  size_t batchSize = inV0->getHeight();
  CHECK_EQ(inV1->getHeight(), batchSize);
  CHECK_EQ(inV0->getWidth(), inV1->getWidth());

  {
    REGISTER_TIMER_INFO("FwResetTimer", getName().c_str());
    reserveOutput(batchSize, 1);
  }

  MatrixPtr outV = getOutputValue();
  {
    REGISTER_TIMER_INFO("FwDotProdTimer", getName().c_str());
    // outV = row-wise sum of the element-wise product of the two inputs.
    outV->sumOfProducts(*inV0, *inV1, 1, 0);
  }
}

void DotProdLayer::backward(const UpdateCallback& callback) {
  MatrixPtr inV0 = getInputValue(0);
  MatrixPtr inV1 = getInputValue(1);
  MatrixPtr outG = getOutputGrad();
  MatrixPtr inG0 = getInputGrad(0);
  MatrixPtr inG1 = getInputGrad(1);

  {
    REGISTER_TIMER_INFO("BwDotProdTimer", getName().c_str());

    if (inG0) {
      // dL/dx = dL/dz * y, applied row by row.
      inG0->addRowScale(0, *inV1, *outG);
    }

    if (inG1) {
      // dL/dy = dL/dz * x, applied row by row.
      inG1->addRowScale(0, *inV0, *outG);
    }
  }
}

}  // namespace paddle
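For reference, a sketch of the math DotProdLayer implements, with $b$ indexing the batch and $D$ the input width:

\[
z_b = \sum_{d=1}^{D} x_{b,d}\, y_{b,d}, \qquad
\frac{\partial z_b}{\partial x_{b,d}} = y_{b,d}, \qquad
\frac{\partial z_b}{\partial y_{b,d}} = x_{b,d}
\]

so each input gradient row is the other input's row scaled by the output gradient, which is exactly what addRowScale applies in backward().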
@@ -0,0 +1,91 @@
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */

#include "L2DistanceLayer.h"
#include "paddle/utils/Logging.h"
#include "paddle/utils/Stat.h"

namespace paddle {

REGISTER_LAYER(l2_distance, L2DistanceLayer);

bool L2DistanceLayer::init(const LayerMap& layerMap,
                           const ParameterMap& parameterMap) {
  /* Initialize the basic parent class */
  Layer::init(layerMap, parameterMap);

  CHECK_EQ(inputLayers_.size(), 2UL) << "The L2DistanceLayer accepts two and "
                                     << "only two inputs.";
  CHECK_EQ(getSize(), 1UL) << "The output dimensionality of L2DistanceLayer "
                           << "is fixed to be 1.";

  return true;
}

void L2DistanceLayer::forward(PassType passType) {
  Layer::forward(passType);

  const auto inV1 = getInputValue(0);
  const auto inV2 = getInputValue(1);

  CHECK(inV1 && inV2);
  CHECK_EQ(inV1->getHeight(), inV2->getHeight())
      << "The height of two inputs of this layer must be the same.";
  CHECK_EQ(inV1->getWidth(), inV2->getWidth())
      << "The width of two inputs of this layer must be the same.";

  int batchSize = inV1->getHeight();
  int outputDim = getSize();
  {
    REGISTER_TIMER_INFO("L2DistanceFwAtvTimer", getName().c_str());
    reserveOutput(batchSize, outputDim);
    auto outV = getOutputValue();
    CHECK(outV) << "The output matrix should not be null.";

    Matrix::resizeOrCreate(
        inputSub_, inV1->getHeight(), inV1->getWidth(), false, useGpu_);

    inputSub_->assign(*inV1);
    inputSub_->sub(*inV2);
    outV->sumOfProducts(*inputSub_, *inputSub_, 1, 0);
    outV->sqrt2(*outV);
  }
}

void L2DistanceLayer::backward(const UpdateCallback& callback) {
  const auto outG = getOutputGrad();
  const auto outV = getOutputValue();
  CHECK(outG && outV);

  auto inGrad1 = getInputGrad(0);
  auto inGrad2 = getInputGrad(1);

  {
    REGISTER_TIMER_INFO("L2DistanceBpAtvTimer", getName().c_str());

    if (inGrad1 || inGrad2) {
      // Turn outV into outG / d in place: first 1 / outV, then outG .* (1 / outV).
      outV->scalarDiv(*outV, 1.);
      outV->dotMul(*outG, *outV);
    }

    if (inGrad1) inGrad1->addRowScale(0, *inputSub_, *outV);

    if (inGrad2) {
      inputSub_->mulScalar(-1.);
      inGrad2->addRowScale(0, *inputSub_, *outV);
    }
  }
}

}  // namespace paddle
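The backward pass above follows from differentiating the distance; a sketch (the case $d_b = 0$ is not guarded in the code):

\[
d_b = \sqrt{\sum_{i=1}^{D} (x_{b,i} - y_{b,i})^2}, \qquad
\frac{\partial d_b}{\partial x_{b,i}} = \frac{x_{b,i} - y_{b,i}}{d_b}, \qquad
\frac{\partial d_b}{\partial y_{b,i}} = -\frac{x_{b,i} - y_{b,i}}{d_b}
\]

so scalarDiv and dotMul turn outV into outG / d, addRowScale multiplies that by the cached difference inputSub_, and mulScalar(-1) supplies the sign flip for the second input.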
@@ -0,0 +1,52 @@
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */

#pragma once

#include "Layer.h"
#include "paddle/math/Matrix.h"

namespace paddle {

/**
 * @brief The layer calculates the l2 distance between two input vectors.
 * \f[
 * f(\bf{x}, \bf{y}) = \sqrt{\sum_{i=1}^D(x_i - y_i)^2}
 * \f]
 *
 * - Input1: A vector (batchSize * dataDim)
 * - Input2: A vector (batchSize * dataDim)
 * - Output: A vector (batchSize * 1)
 *
 * The config file api is l2_distance_layer.
 */

class L2DistanceLayer : public Layer {
public:
  explicit L2DistanceLayer(const LayerConfig& config) : Layer(config) {}
  ~L2DistanceLayer() {}

  bool init(const LayerMap& layerMap,
            const ParameterMap& parameterMap) override;

  void forward(PassType passType) override;
  void backward(const UpdateCallback& callback = nullptr) override;

private:
  // Stores the result of subtracting Input2 from Input1 in the forward pass,
  // which is reused in the backward pass.
  MatrixPtr inputSub_;
};

}  // namespace paddle
@@ -0,0 +1,202 @@
/* Copyright (c) 2017 PaddlePaddle Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */

#include "MKLDNNConcatLayer.h"

using namespace mkldnn;  // NOLINT
typedef memory::format format;

namespace paddle {

REGISTER_LAYER(mkldnn_concat, MKLDNNConcatLayer);

bool MKLDNNConcatLayer::init(const LayerMap& layerMap,
                             const ParameterMap& parameterMap) {
  if (!MKLDNNLayer::init(layerMap, parameterMap)) {
    return false;
  }
  CHECK_GT(inputLayers_.size(), 1UL);
  CHECK(!biasParameter_);
  return true;
}

void MKLDNNConcatLayer::reshape(
    int& bs, int& ic, int& ih, int& iw, int oc, int& oh, int& ow) {
  reshapeInput(bs, ih, iw);
  ic = inputLayers_[0]->getSize() / ih / iw;
  CHECK_EQ((size_t)ic * ih * iw, inputLayers_[0]->getSize());
  CHECK_EQ(inputElemenCnt_, (size_t)bs * ic * ih * iw);
  CHECK_GT(inputLayers_.size(), 1UL);
  channels_.resize(inputLayers_.size());
  channels_[0] = ic;
  // Need to change the output channel, so use oc_ instead.
  // TODO(TJ): change API, use &oc
  oc_ = ic;
  for (size_t i = 1; i < inputLayers_.size(); i++) {
    int batchSize, height, width;
    reshapeInput(batchSize, height, width, i);
    CHECK_EQ(bs, batchSize);
    CHECK_EQ(ih, height);
    CHECK_EQ(iw, width);

    channels_[i] = inputLayers_[i]->getSize() / height / width;
    CHECK_EQ((size_t)channels_[i] * height * width, inputLayers_[i]->getSize());
    oc_ += channels_[i];
  }
  oh = ih;
  ow = iw;
  reshapeOutput(oh, ow);
  resizeOutput(bs, oc_ * oh * ow);
}
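
// Example with hypothetical sizes: concatenating two inputs of shape
// {bs, 16, h, w} and {bs, 32, h, w} along the channel axis yields
// oc_ = 16 + 32 = 48, while oh = ih and ow = iw stay unchanged.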

void MKLDNNConcatLayer::resetFwd(std::vector<primitive>& pipeline,
                                 MKLDNNMatrixPtr& in,
                                 MKLDNNMatrixPtr& wgt,
                                 MKLDNNMatrixPtr& bias,
                                 MKLDNNMatrixPtr& out) {
  resetFwdBuffers(inVals_, out);
  in = inVals_[0];

  std::shared_ptr<concat::primitive_desc> fwdPD;
  resetFwdPD(fwdPD, inVals_, out);

  resetFwdPipeline(pipeline, fwdPD, inVals_, out);
}

void MKLDNNConcatLayer::resetBwd(std::vector<primitive>& pipeline,
                                 MKLDNNMatrixPtr& in,
                                 MKLDNNMatrixPtr& wgt,
                                 MKLDNNMatrixPtr& bias,
                                 MKLDNNMatrixPtr& out) {
  resetBwdBuffers(inGrads_, out);
  in = inGrads_[0];

  resetBwdPipeline(pipeline, bwds_, inGrads_, out);
}

void MKLDNNConcatLayer::resetFwdBuffers(std::vector<MKLDNNMatrixPtr>& inputs,
                                        MKLDNNMatrixPtr& out) {
  inputs.resize(inputLayers_.size());
  bool has8c = false, has16c = false, hasnc = false;
  for (size_t i = 0; i < inputs.size(); i++) {
    // resetInValue will use ic_, so temporarily set it to the current
    // input's channel count.
    // TODO(TJ): change ic_ to a vector, then channels_ can be removed
    ic_ = channels_[i];
    resetInValue(inputs[i], nullptr, i);
    CHECK(inputs[i]);
    auto dm = inputs[i]->getDims();
    // Input formats can differ, but the number of dims must be equal.
    CHECK(i == 0 || dm.size() == inputs[0]->getDims().size());
    CHECK_EQ(bs_, dm[0]);
    CHECK_EQ(channels_[i], dm[1]);
    if (dm.size() > 2) {
      CHECK_EQ(ih_, dm[2]);
      CHECK_EQ(iw_, dm[3]);
    }
    if (inputs[i]->getFormat() == format::nc) {
      hasnc = true;
    }
    if (inputs[i]->getFormat() == format::nChw8c) {
      has8c = true;
    }
    if (inputs[i]->getFormat() == format::nChw16c) {
      has16c = true;
    }
  }
  // Change back: ic_ always holds the channel count of input 0.
  ic_ = channels_[0];

  format outFmt;
  if (has16c && oc_ % 16 == 0) {
    outFmt = format::nChw16c;
  } else if (has8c && oc_ % 8 == 0) {
    outFmt = format::nChw8c;
  } else if (hasnc) {
    CHECK(oh_ == 1 && ow_ == 1);
    outFmt = format::nc;
  } else {
    outFmt = format::nchw;
  }
  memory::dims outDims =
      hasnc ? memory::dims{bs_, oc_} : memory::dims{bs_, oc_, oh_, ow_};
  auto outPD = MKLDNNMatrix::createPrimitiveDesc(outDims, outFmt, engine_);
  resetOutValue(out, outPD);
}

void MKLDNNConcatLayer::resetFwdPD(std::shared_ptr<concat::primitive_desc>& pd,
                                   std::vector<MKLDNNMatrixPtr>& inputs,
                                   MKLDNNMatrixPtr out) {
  std::vector<memory::primitive_desc> srcPDs;
  for (size_t i = 0; i < inputs.size(); i++) {
    srcPDs.push_back(inputs[i]->getPrimitiveDesc());
  }
  CHECK(out);
  pd.reset(new concat::primitive_desc(out->getMemoryDesc(), axis_, srcPDs));
  CHECK_PRIMITIVE_DESC_EQ(out, pd->dst_primitive_desc());
}

void MKLDNNConcatLayer::resetFwdPipeline(
    std::vector<primitive>& pipeline,
    std::shared_ptr<concat::primitive_desc>& pd,
    std::vector<MKLDNNMatrixPtr>& inputs,
    MKLDNNMatrixPtr& out) {
  std::vector<primitive::at> srcs;
  for (size_t i = 0; i < inputs.size(); i++) {
    srcs.push_back(*(inputs[i]));
  }
  fwd_.reset(new concat(*pd, srcs, *out));
  pipeline.push_back(*fwd_);
}

void MKLDNNConcatLayer::resetBwdBuffers(std::vector<MKLDNNMatrixPtr>& inputs,
                                        MKLDNNMatrixPtr& out) {
  CHECK(outVal_);
  resetOutGrad(out, outVal_->getPrimitiveDesc());
  CHECK(out);

  inputs.resize(inputLayers_.size());
  for (size_t i = 0; i < inputs.size(); i++) {
    CHECK(inVals_[i]);
    // resetInGrad will use inVal_.
    // TODO(TJ): move inVals_ to MKLDNNLayer and remove inVal_
    inVal_ = inVals_[i];
    resetInGrad(inputs[i], inVals_[i]->getPrimitiveDesc(), i);
    CHECK_PRIMITIVE_DESC_EQ(inputs[i], inVals_[i]->getPrimitiveDesc());
  }
  // Change back: inVal_ always holds input 0.
  inVal_ = inVals_[0];
}

void MKLDNNConcatLayer::resetBwdPipeline(
    std::vector<mkldnn::primitive>& pipeline,
    std::vector<std::shared_ptr<mkldnn::primitive>>& prims,
    std::vector<MKLDNNMatrixPtr>& inputs,
    MKLDNNMatrixPtr& out) {
  // Reset the backward primitives.
  memory::dims offsets = {0, 0, 0, 0};
  prims.resize(inputs.size());
  CHECK_EQ(inputs.size(), channels_.size());
  for (size_t i = 0; i < inputs.size(); i++) {
    auto viewPD = view::primitive_desc(
        out->getPrimitiveDesc(), inputs[i]->getDims(), offsets);
    auto bwdPD = reorder::primitive_desc(viewPD.dst_primitive_desc(),
                                         inputs[i]->getPrimitiveDesc());
    prims[i].reset(new reorder(bwdPD, *out, *(inputs[i])));
    offsets[axis_] += channels_[i];
    // Push to pipeline.
    pipeline.push_back(*prims[i]);
  }
}
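
// Note: the backward pass is one reorder per input, each reading a
// channel-sliced view of the output gradient. Since offsets[axis_] advances
// by channels_[i], input i receives the slice of out starting at channel
// sum(channels_[0..i-1]) with channels_[i] channels.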

}  // namespace paddle
@@ -0,0 +1,129 @@
/* Copyright (c) 2017 PaddlePaddle Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */

#pragma once

#include "MKLDNNLayer.h"
#include "mkldnn.hpp"

namespace paddle {

/**
 * @brief A subclass of MKLDNNLayer that implements the concatenation layer.
 *
 * The config file api is mkldnn_concat.
 */
class MKLDNNConcatLayer : public MKLDNNLayer {
protected:
  std::vector<MKLDNNMatrixPtr> inVals_;
  std::vector<MKLDNNMatrixPtr> inGrads_;
  std::vector<std::shared_ptr<mkldnn::primitive>> bwds_;
  // Input channel numbers.
  std::vector<int> channels_;

  // The concat dimension in MKLDNN:
  // if axis_ == 0, concatenate along the batch dimension;
  // if axis_ == 1, concatenate along the channel dimension (default).
  int axis_;

public:
  explicit MKLDNNConcatLayer(const LayerConfig& config)
      : MKLDNNLayer(config), axis_(1) {}

  ~MKLDNNConcatLayer() {}

  bool init(const LayerMap& layerMap,
            const ParameterMap& parameterMap) override;

  void reshape(
      int& bs, int& ic, int& ih, int& iw, int oc, int& oh, int& ow) override;

  void resetFwd(std::vector<mkldnn::primitive>& pipeline,
                MKLDNNMatrixPtr& in,
                MKLDNNMatrixPtr& wgt,
                MKLDNNMatrixPtr& bias,
                MKLDNNMatrixPtr& out) override;

  void resetBwd(std::vector<mkldnn::primitive>& pipeline,
                MKLDNNMatrixPtr& in,
                MKLDNNMatrixPtr& wgt,
                MKLDNNMatrixPtr& bias,
                MKLDNNMatrixPtr& out) override;

  void printSizeInfo() override {
    CHECK_EQ(channels_.size(), inputLayers_.size());
    for (size_t i = 0; i < channels_.size(); ++i) {
      VLOG(MKLDNN_SIZES) << "Input " << i << ", " << inputLayers_[i]->getName()
                         << ": " << bs_ << ", " << channels_[i] << ", " << ih_
                         << ", " << iw_;
    }
    VLOG(MKLDNN_SIZES) << "Output: " << bs_ << ", " << oc_ << ", " << oh_
                       << ", " << ow_;
  }

  void printValueFormat() override {
    for (size_t i = 0; i < inVals_.size(); ++i) {
      VLOG(MKLDNN_FMTS) << "Input " << i << ", " << inputLayers_[i]->getName()
                        << ": " << inVals_[i]->getFormat() << " >>>";
    }
    if (outVal_) {
      VLOG(MKLDNN_FMTS) << outVal_->getFormat() << " >>> ";
    }
    if (extOutVal_) {
      VLOG(MKLDNN_FMTS) << extOutVal_->getFormat();
    }
  }

  void printGradFormat() override {
    if (extOutGrad_) {
      VLOG(MKLDNN_FMTS) << extOutGrad_->getFormat();
    }
    if (outGrad_) {
      VLOG(MKLDNN_FMTS) << outGrad_->getFormat() << " <<< ";
    }
    for (size_t i = 0; i < inGrads_.size(); ++i) {
      VLOG(MKLDNN_FMTS) << "Input " << i << ", " << inputLayers_[i]->getName()
                        << ": " << inGrads_[i]->getFormat() << " <<<";
    }
  }

protected:
  /**
   * Forward functions: reset buffers (inputs, output, bias),
   * reset the primitive descriptor,
   * reset the pipeline.
   */
  void resetFwdBuffers(std::vector<MKLDNNMatrixPtr>& inputs,
                       MKLDNNMatrixPtr& out);
  void resetFwdPD(std::shared_ptr<mkldnn::concat::primitive_desc>& pd,
                  std::vector<MKLDNNMatrixPtr>& inputs,
                  MKLDNNMatrixPtr out);
  void resetFwdPipeline(std::vector<mkldnn::primitive>& pipeline,
                        std::shared_ptr<mkldnn::concat::primitive_desc>& pd,
                        std::vector<MKLDNNMatrixPtr>& inputs,
                        MKLDNNMatrixPtr& out);

  /**
   * Backward functions: reset buffers (inputs, output, bias),
   * reset primitives and the pipeline.
   */
  void resetBwdBuffers(std::vector<MKLDNNMatrixPtr>& inputs,
                       MKLDNNMatrixPtr& out);
  void resetBwdPipeline(std::vector<mkldnn::primitive>& pipeline,
                        std::vector<std::shared_ptr<mkldnn::primitive>>& prims,
                        std::vector<MKLDNNMatrixPtr>& inputs,
                        MKLDNNMatrixPtr& out);
};

}  // namespace paddle