Add job=time in trainer, refine cudnn_conv to reduce gpu memory and speed up training. (#218)

* Add benchmark for PaddlePaddle, tensorflow and caffe

* ConvProjection to reduce memory for goolenet

* Add unit test for ConvProjection.
1. unit test in test_LayerGrad.
2. compare the ConvPorjection and CudnnConvLayer, also compare the concat_layer+img_conv_layer and concat_layer_conv_projection.

* Reduce cudnn_conv memory and add benchmark document.
1. Use TmpMatrix as the workspace in cudnn_conv to reduce gpu memory. It reduce lots of memory.
2. Add benchmark document.
3. fix smallnet_mnist_cifar.py in paddle.

* Add job=time and refine cudnn_conv to reduce gpu memroy and speed up

* Refine cudnn_conv and shared biases operation in concat_layer and mixed_layer.

* follow comments

* follow comments

* Use unique_ptr to prevent memory leaks in CudnnConvLayer.
avx_docs
qingqing01 8 years ago committed by GitHub
parent 12945b2c90
commit 45c81a414f

@ -183,7 +183,7 @@ It looks like there are a lot of arguments. However, most of them are for develo
</tr>
<tr>
<td class="left" rowspan = "5">GPU</td><td class="left">gpu_id</td>
<td class="left" rowspan = "6">GPU</td><td class="left">gpu_id</td>
<td class="left"></td><td class="left"></td><td class="left"></td><td class="left"></td>
</tr>
@ -207,6 +207,11 @@ It looks like there are a lot of arguments. However, most of them are for develo
<td class="left"></td><td class="left"></td><td class="left"></td><td class="left"></td>
</tr>
<tr>
<td class="left">cudnn_conv_workspace_limit_in_mb</td>
<td class="left"></td><td class="left"></td><td class="left"></td><td class="left"></td>
</tr>
<tr>
<td class="left" rowspan = "4">RNN</td>
<td class="left">beam_size</td>

@ -163,6 +163,10 @@
- Choose path to dynamic load NVIDIA CUDA library, for instance, /usr/local/cuda/lib64. [Default]: LD_LIBRARY_PATH
- type: string (default: "", null)
* `--cudnn_conv_workspace_limit_in_mb`
- Specify cuDNN max workspace limit, in units MB, 4096MB=4GB by default.
- type: int32 (default: 4096MB=4GB)
## NLP: RNN/LSTM/GRU
* `--rnn_use_batch`
- Whether to use batch method for calculation in simple RecurrentLayer.

@ -48,5 +48,24 @@ inline __device__ double paddleAtomicAdd(double* address, double val) {
}
} // namespace paddle
/**
* @brief sum reduction
*
* @param[in,out] smem input data, better to use __shared__ memory.
* @param[in] tid thread index.
* @param[in] threads the total thread number used to reduce,
* such as, blockDim.x.
*
* @return smem[0]: the sum of each elements in smem.
*/
__device__ __forceinline__
void simpleReduce(real* smem, int tid, int threads) {
for (unsigned int s = threads / 2; s > 0; s >>= 1) {
if (tid < s) {
smem[tid] += smem[tid + s];
}
__syncthreads();
}
}
#endif /* HL_DEVICE_FUNCTIONS_CUH_ */

@ -229,4 +229,40 @@ extern void hl_cossim_derivative(real* grad,
int input2_height,
real scale);
/**
* @brief Matrix addition: A_d[i][j] += scale * B_d[j/channel].
*
* @param[in] A_d input matrix (M x N).
* @param[in] B_d input matrix (1 x channel).
* @param[in] channel width of B.
* @param[in] dimM height of A.
* @param[in] dimN width of A.
* @param[in] scale scalar used for addition.
*
*/
extern void hl_matrix_add_shared_bias(real* A_d,
real* B_d,
const int channel,
const int dimM,
const int dimN,
real scale);
/**
* @brief Matrix addition: A_d[i][j] += scale * B_d[j/channel].
*
* @param[in] B_d input matrix (1 x channel).
* @param[in] A_d input matrix (M x N).
* @param[in] channel width of B.
* @param[in] dimM height of A.
* @param[in] dimN width of A.
* @param[in] scale scalar used for addition.
*
*/
extern void hl_matrix_collect_shared_bias(real* B_d,
real* A_d,
const int channel,
const int dimM,
const int dimN,
real scale);
#endif /* HL_MATRIX_H_ */

@ -101,4 +101,17 @@ inline void hl_cossim_derivative(real* grad,
int input2_height,
real scale) {}
inline void hl_matrix_add_shared_bias(real* A_d,
real* B_d,
const int channel,
const int dimM,
const int dimN,
real scale) {}
inline void hl_matrix_collect_shared_bias(real* B_d,
real* A_d,
const int channel,
const int dimM,
const int dimN,
real scale) {}
#endif // HL_MATRIX_STUB_H_

@ -20,6 +20,11 @@ limitations under the License. */
#include "hl_thread.ph"
#include "hl_dso_loader.h"
#include "paddle/utils/Logging.h"
#include "paddle/utils/CommandLineParser.h"
P_DEFINE_int32(cudnn_conv_workspace_limit_in_mb, 4096,
"Specify cuDNN max workspace limit, in units MB, "
"4096MB=4GB by default.");
namespace dynload {
@ -242,7 +247,7 @@ void hl_conv_workspace(hl_tensor_descriptor input,
CHECK_NOTNULL(conv);
// Specify workspace limit directly
size_t memoryLimitBytes = 8 * 1024 * 1024;
size_t memoryLimitBytes = (1LL << 20) * FLAGS_cudnn_conv_workspace_limit_in_mb;
// cudnn convolution forward configuration
cudnnTensorDescriptor_t fwd_src_desc = GET_TENSOR_DESCRIPTOR(input);

@ -20,6 +20,7 @@ limitations under the License. */
#include "hl_sequence.h"
#include "paddle/utils/Logging.h"
#include "hl_device_functions.cuh"
#include "hl_gpu_matrix_kernel.cuh"
DEFINE_MATRIX_UNARY_OP(Zero, a = 0);
DEFINE_MATRIX_TERNARY_PARAMETER_OP(_add, TWO_PARAMETER, c = p1*a + p2*b);
@ -673,3 +674,89 @@ void hl_cossim_derivative(real* grad,
input1_height, input2_height, scale);
CHECK_SYNC("hl_cossim_derivate failed");
}
__global__ void KeMatrixAddSharedBias(real* A,
real* B,
const int channel,
const int M,
const int N,
real scale) {
int index = blockIdx.x * blockDim.x + threadIdx.x;
int dim = N / channel;
if (index < M * N) {
int i = index % N;
i = i / dim;
A[index] += scale * B[i];
}
}
void hl_matrix_add_shared_bias(real* A_d,
real* B_d,
const int channel,
const int dimM,
const int dimN,
real scale) {
const int blocks = 512;
const int grids = DIVUP(dimM * dimN, blocks);
KeMatrixAddSharedBias<<<grids, blocks, 0, STREAM_DEFAULT>>>
(A_d, B_d, channel, dimM, dimN, scale);
CHECK_SYNC("hl_matrix_add_shared_bias failed");
}
template <int blockSize>
__global__ void KeMatrixCollectSharedBias(real *B,
real *A,
const int channel,
const int M,
const int N,
const int dim,
const int limit,
real scale) {
if (dim < limit) {
int index = blockIdx.x * blockDim.x + threadIdx.x;
if (index < channel) {
real sum = 0.0;
for (int i = 0; i < M; ++i) {
for (int j = 0; j < dim; ++j) {
sum += A[i * N + index * dim + j];
}
}
B[index] += scale * sum;
}
} else {
const int tid = threadIdx.x;
const int bid = blockIdx.x;
__shared__ real smem[blockSize];
real sum = 0.0;
for (int j = 0; j < ((dim * M + blockSize - 1) / blockSize); ++j) {
int n = j * blockSize + tid;
int m = n / dim;
int w = n % dim;
smem[tid] = (m < M && w < dim) ? A[m * N + bid * dim + w] : 0.0;
__syncthreads();
simpleReduce(smem, tid, blockSize);
sum += smem[0];
}
if (tid == 0) {
B[bid] += scale * sum;
}
}
}
void hl_matrix_collect_shared_bias(real* B_d,
real* A_d,
const int channel,
const int dimM,
const int dimN,
real scale) {
const int dim = dimN / channel;
const int blocks = 256;
const int limit = 64;
int grids = (dimM * dim) < limit ? DIVUP(channel, blocks) : channel;
KeMatrixCollectSharedBias<blocks>
<<< grids, blocks, 0, STREAM_DEFAULT>>>
(B_d, A_d, channel, dimM, dimN, dim, limit, scale);
CHECK_SYNC("hl_matrix_collect_shared_bias failed");
}

@ -908,24 +908,6 @@ int findIndex(int* indice, int num, int index) {
return (end - 1);
}
/**
* @brief sum reduction
*
* @param[in,out] smem input data, better to use __shared__ memory.
* @param[in] tid local thread index.
* @param[in] blockDimX the size of blockDim.x.
*
* note: return smem[0]: the sum of each elements of smem.
*/
__device__ __forceinline__
void reduce(real* smem, int tid, int blockDimX) {
for (unsigned int s = blockDimX / 2; s > 0; s >>= 1) {
if (tid < s) {
smem[tid] += smem[tid + s];
}
__syncthreads();
}
}
/**
* @brief sum columns of csr sparse matrix (csr_val), then add to a_val.

@ -97,7 +97,8 @@ void ConcatenateLayer::backward(const UpdateCallback& callback) {
*/
class ConcatenateLayer2 : public Layer {
public:
explicit ConcatenateLayer2(const LayerConfig& config) : Layer(config) {}
explicit ConcatenateLayer2(const LayerConfig& config) :
Layer(config) {}
~ConcatenateLayer2() {}
@ -110,6 +111,8 @@ protected:
std::vector<std::unique_ptr<Projection>> projections_;
std::vector<Argument> projOutput_;
std::vector<std::pair<size_t, size_t>> projCol_;
bool sharedBias_;
std::unique_ptr<Weight> biases_;
};
REGISTER_LAYER(concat2, ConcatenateLayer2);
@ -119,7 +122,6 @@ bool ConcatenateLayer2::init(const LayerMap& layerMap,
/* Initialize the basic parent class */
if (!Layer::init(layerMap, parameterMap)) return false;
CHECK(!biasParameter_);
CHECK_EQ(inputLayers_.size(), parameters_.size());
projections_.reserve(inputLayers_.size());
projCol_.reserve(inputLayers_.size());
@ -137,6 +139,13 @@ bool ConcatenateLayer2::init(const LayerMap& layerMap,
}
CHECK_EQ(getSize(), endCol);
/* initialize biases_ */
if (biasParameter_.get() != NULL) {
sharedBias_ = config_.shared_biases();
size_t psize = config_.bias_size();
biases_ = std::unique_ptr<Weight>(new Weight(1, psize, biasParameter_));
}
return true;
}
@ -154,8 +163,17 @@ void ConcatenateLayer2::forward(PassType passType) {
projOutput_[i].grad = output_.grad->subColMatrix(startCol, endCol);
}
for (size_t i = 0; i != inputLayers_.size(); ++i) {
projections_[i]->forward(&getInput(i), &projOutput_[i], passType);
{
AsyncGpuBlock block;
for (size_t i = 0; i != inputLayers_.size(); ++i) {
projections_[i]->forward(&getInput(i), &projOutput_[i], passType);
}
}
/* add the bias-vector */
if (biases_) {
REGISTER_TIMER_INFO("FwBiasTimer", getName().c_str());
output_.value->addBias(*(biases_->getW()), 1, sharedBias_);
}
/* activation */ {
@ -170,6 +188,13 @@ void ConcatenateLayer2::backward(const UpdateCallback& callback) {
backwardActivation();
}
AsyncGpuBlock block;
if (biases_ && biases_->getWGrad()) {
REGISTER_TIMER_INFO("Concat2BpBiasTimer", getName().c_str());
biases_->getWGrad()->collectBias(*getOutputGrad(), 1, sharedBias_);
biases_->getParameterPtr()->incUpdate(callback);
}
for (size_t i = 0; i != inputLayers_.size(); ++i) {
if (projections_[i]) {
projections_[i]->backward(callback);

@ -35,25 +35,12 @@ bool ConvBaseLayer::init(const LayerMap& layerMap,
filterSizeY_.push_back(conf.filter_size_y());
filterPixels_.push_back(filterSize_.back() * filterSizeY_.back());
channels_.push_back(conf.channels());
imgSize_.push_back(conf.img_size());
imgPixels_.push_back(imgSize_.back() * imgSize_.back());
imgSizeH_.push_back(conf.img_size());
imgSizeW_.push_back(conf.img_size());
groups_.push_back(conf.groups());
filterChannels_.push_back(conf.filter_channels());
outputX_.push_back(conf.output_x());
outputs_.push_back(outputX_.back() * outputX_.back());
}
/* initialize the weightList */
CHECK(inputLayers_.size() == parameters_.size());
for (size_t i = 0; i < inputLayers_.size(); i++) {
size_t height, width;
height = filterPixels_[i] * filterChannels_[i];
width = numFilters_;
// create a new weight
CHECK_EQ(parameters_[i]->getSize(), width * height);
Weight* w = new Weight(height, width, parameters_[i]);
weights_.emplace_back(w);
outputH_.push_back(conf.output_x());
outputW_.push_back(conf.output_x());
}
/* initialize the biases_ */
@ -74,4 +61,34 @@ bool ConvBaseLayer::init(const LayerMap& layerMap,
return true;
}
size_t ConvBaseLayer::calOutputSize() {
auto clearAndReserve = [this](IntV* vec) {
vec->clear();
vec->reserve(this->inputLayers_.size());
};
clearAndReserve(&imgSizeH_);
clearAndReserve(&imgSizeW_);
clearAndReserve(&outputH_);
clearAndReserve(&outputW_);
size_t layerSize = 0;
for (size_t i = 0; i < inputLayers_.size(); i++) {
imgSizeH_.push_back(inputLayers_[i]->getOutput().getFrameHeight());
imgSizeW_.push_back(inputLayers_[i]->getOutput().getFrameWidth());
if (imgSizeH_[i] == 0)
imgSizeH_[i] = config_.inputs(i).conv_conf().img_size();
if (imgSizeW_[i] == 0)
imgSizeW_[i] = config_.inputs(i).conv_conf().img_size();
outputH_.push_back(
outputSize(imgSizeH_[i], filterSizeY_[i], paddingY_[i], strideY_[i]));
outputW_.push_back(
outputSize(imgSizeW_[i], filterSize_[i], padding_[i], stride_[i]));
CHECK_EQ(outputH_[i], outputH_[0]);
CHECK_EQ(outputW_[i], outputW_[0]);
}
getOutput().setFrameHeight(outputH_[0]);
getOutput().setFrameWidth(outputW_[0]);
layerSize = outputH_[0] * outputW_[0] * size_t(numFilters_);
return layerSize;
}
} // namespace paddle

@ -43,19 +43,18 @@ protected:
IntV filterSizeY_;
/// The spatial dimensions of the convolution input.
IntV channels_;
/// The spatial dimensions of input feature map.
IntV imgSize_;
/// The total pixel size of input feature map.
/// imgPixels_ = imgSizeX_ * imgSizeY_.
IntV imgPixels_;
/// The spatial dimensions of input feature map height.
IntV imgSizeH_;
/// The spatial dimensions of input feature map width.
IntV imgSizeW_;
/// filterPixels_ = filterSizeX_ * filterSizeY_.
IntV filterPixels_;
/// filterChannels_ = channels_/groups_.
IntV filterChannels_;
/// The spatial dimensions of output feature map.
IntV outputX_;
/// The spatial dimensions of output feature map.
IntV outputs_;
/// The spatial dimensions of output feature map height.
IntV outputH_;
/// The spatial dimensions of output feature map width.
IntV outputW_;
/// Group size, refer to grouped convolution in
/// Alex Krizhevsky's paper: when group=2, the first half of the
/// filters are only connected to the first half of the input channels,
@ -80,6 +79,13 @@ public:
virtual bool init(const LayerMap& layerMap, const ParameterMap& parameterMap);
/**
* imgSizeH_ and imgSizeW_ will be set according to the previous input layers
* in this function. Then it will calculate outputH_ and outputW_ and set them
* into output argument.
*/
virtual size_t calOutputSize();
Weight& getWeight(int idx) { return *weights_[idx]; }
/**

@ -0,0 +1,210 @@
/* Copyright (c) 2016 Baidu, Inc. All Rights Reserve.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include "paddle/utils/Stat.h"
#include "ConvProjection.h"
namespace paddle {
REGISTER_PROJECTION(conv, ConvProjection);
ThreadLocalD<std::vector<MemoryHandle*>> ConvProjection::convMem_;
ConvProjection::ConvProjection(const ProjectionConfig& config,
ParameterPtr parameter, bool useGpu)
: Projection(config, parameter, useGpu) {
CHECK(useGpu); // only support GPU
getConvParams();
initCudnn();
size_t height = filterH_ * filterW_ * channels_ / groups_;
size_t width = numFilters_;
weight_.reset(new Weight(height, width, parameter));
weightOffset_ = height * width / groups_;
}
void ConvProjection::getConvParams() {
const ConvConfig &conf = config_.conv_conf();
paddingH_ = conf.padding_y();
paddingW_ = conf.padding();
strideH_ = conf.stride_y();
strideW_ = conf.stride();
filterH_ = conf.filter_size_y();
filterW_ = conf.filter_size();
configImgH_ = conf.img_size();
configImgW_ = conf.img_size();
channels_ = conf.channels();
numFilters_ = config_.num_filters();
groups_ = conf.groups();
CHECK_EQ(channels_ % groups_, 0);
CHECK_EQ(numFilters_ % groups_, 0);
}
void ConvProjection::initCudnn() {
hl_create_filter_descriptor(&filterDesc_, channels_, numFilters_,
filterH_, filterW_);
hl_create_tensor_descriptor(&inputDesc_);
hl_create_tensor_descriptor(&outputDesc_);
hl_create_convolution_descriptor(&convDesc_, inputDesc_, filterDesc_,
paddingH_, paddingW_, strideH_, strideW_);
// initialize all to default algorithms
fwdAlgo_ = 0;
bwdFilterAlgo_ = 0;
bwdDataAlgo_ = 0;
fwdLimitBytes_ = 0;
bwdDataLimitBytes_ = 0;
bwdFilterLimitBytes_ = 0;
workSpaceInBytes_ = 0;
batchNum_ = 0;
isSelectAlgo_ = false;
}
void ConvProjection::reshapeTensorDesc(int batchSize) {
hl_tensor_reshape(inputDesc_, batchSize, channels_, imageH_, imageW_,
channels_ * imageH_ * imageW_, imageH_ * imageW_,
imageW_, 1);
hl_reset_convolution_descriptor(convDesc_, inputDesc_, filterDesc_,
paddingH_, paddingW_, strideH_, strideW_);
// The stride between two consecutive images in ConvProjection may not be 1,
// for example, in the case of layer ConcatenateLayer2 with two
// ConvProjection, the stride is the output_size of layer ConcatenateLayer2.
// So the calculation of nStride is different from CudnnConvLayer.
// In fact, only "nStride = out_->value->getStride()" is ok.
size_t nStride = numFilters_ * outputH_ * outputW_;
if (out_->value->isContiguous()) {
CHECK_EQ(nStride, out_->value->getWidth());
} else {
nStride = out_->value->getStride();
}
hl_tensor_reshape(outputDesc_, batchSize, numFilters_, outputH_, outputW_,
nStride, outputH_ * outputW_, outputW_, 1);
}
void ConvProjection::reshape(int batchSize) {
size_t width = calOutputSize();
CHECK_EQ(width, out_->value->getWidth());
isSelectAlgo_ = (batchSize == batchNum_);
batchNum_ = batchSize;
if (!isSelectAlgo_) {
reshapeTensorDesc(batchSize);
hl_conv_workspace(inputDesc_, outputDesc_, filterDesc_,
convDesc_, &fwdAlgo_, &fwdLimitBytes_,
&bwdDataAlgo_, &bwdDataLimitBytes_,
&bwdFilterAlgo_, &bwdFilterLimitBytes_);
size_t maxWorkSpace = 0;
maxWorkSpace = std::max(fwdLimitBytes_, bwdDataLimitBytes_);
maxWorkSpace = std::max(maxWorkSpace, bwdFilterLimitBytes_);
workSpaceInBytes_ = maxWorkSpace;
VLOG(3) << getName() << " Fwd / BwdData / BwdFilter algo: " << fwdAlgo_
<< " / " << bwdDataAlgo_
<< " / " << bwdFilterAlgo_;
}
isSelectAlgo_ = true;
}
void ConvProjection::forward() {
int batchSize = in_->value->getHeight();
reshape(batchSize);
void* workSpace = NULL;
if (workSpaceInBytes_ > 0) {
workSpace = getSpaceBytes(workSpaceInBytes_);
}
for (int g = 0; g < groups_; ++g) {
REGISTER_TIMER_INFO("CudnnConvFwTimer", getName().c_str());
real *inputData = in_->value->getData() + g * inputOffset_;
real *wgtData = weight_->getW()->getData() + g * weightOffset_;
real *outData = out_->value->getData() + g * outputOffset_;
hl_convolution_forward(inputDesc_, inputData, outputDesc_,
outData, filterDesc_, wgtData,
convDesc_, workSpace,
fwdLimitBytes_, fwdAlgo_);
}
}
void ConvProjection::backward(const UpdateCallback& callback) {
REGISTER_TIMER_INFO("CudnnConvBpTimer", getName().c_str());
void* workSpace = NULL;
if (workSpaceInBytes_ > 0) {
workSpace = getSpaceBytes(workSpaceInBytes_);
}
for (int g = 0; g < groups_; ++g) {
real *outGrad = out_->grad->getData() + g * outputOffset_;
if (weight_->getWGrad()) {
real *inputData = in_->value->getData() + g * inputOffset_;
real *weightGrad = weight_->getWGrad()->getData() + g * weightOffset_;
hl_convolution_backward_filter(
inputDesc_, inputData, outputDesc_, outGrad, filterDesc_,
weightGrad, convDesc_, workSpace, bwdFilterLimitBytes_,
bwdFilterAlgo_);
}
MatrixPtr preGrad = in_->grad;
if (NULL != preGrad) {
real *inputGrad = preGrad->getData() + g * inputOffset_;
real *wgtData = weight_->getW()->getData() + g* weightOffset_;
hl_convolution_backward_data(
inputDesc_, inputGrad, outputDesc_, outGrad, filterDesc_,
wgtData, convDesc_, workSpace, bwdDataLimitBytes_,
bwdDataAlgo_);
}
}
weight_->getParameterPtr()->incUpdate(callback);
}
void* ConvProjection::getSpaceBytes(size_t size) {
std::vector<MemoryHandle*>& convMem = *convMem_;
if (convMem.empty()) {
int numDevices = hl_get_device_count();
convMem.resize(numDevices);
}
int devId = hl_get_device();
MemoryHandle** localMem = &(convMem[devId]);
if (NULL == *localMem || size > (*localMem)->getAllocSize()) {
*localMem = new GpuMemoryHandle(size);
}
return (*localMem)->getBuf();
}
ConvProjection::~ConvProjection() {
hl_destroy_tensor_descriptor(inputDesc_);
hl_destroy_tensor_descriptor(outputDesc_);
hl_destroy_filter_descriptor(filterDesc_);
hl_destroy_convolution_descriptor(convDesc_);
}
} // namespace paddle

@ -0,0 +1,125 @@
/* Copyright (c) 2016 Baidu, Inc. All Rights Reserve.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#pragma once
#include "Projection.h"
namespace paddle {
/**
* @brief Convolution projection do the same calculation with CudnnConvLayer.
*/
class ConvProjection : public Projection {
public:
/**
* Constructor.
*/
ConvProjection(const ProjectionConfig& config, ParameterPtr parameter,
bool useGpu);
~ConvProjection();
virtual void forward();
virtual void backward(const UpdateCallback& callback);
protected:
void getConvParams();
void initCudnn();
void reshapeTensorDesc(int batchSize);
void reshape(int batchSize);
int outputSize(int imageSize, int filterSize, int padding, int stride) {
return (imageSize - filterSize + 2 * padding) / stride + 1;
}
size_t calOutputSize() {
imageH_ = in_->getFrameHeight();
imageW_ = in_->getFrameWidth();
if (imageH_ == 0) imageH_ = configImgH_;
if (imageW_ == 0) imageW_ = configImgW_;
outputH_ = outputSize(imageH_, filterH_, paddingH_, strideH_);
outputW_ = outputSize(imageW_, filterW_, paddingW_, strideW_);
const_cast<Argument*>(out_)->setFrameHeight(outputH_);
const_cast<Argument*>(out_)->setFrameWidth(outputW_);
inputOffset_ = (channels_ / groups_) * imageH_ * imageW_;
outputOffset_ = (numFilters_ / groups_) * outputH_ * outputW_;
return outputH_ * outputW_ * numFilters_;
}
static void* getSpaceBytes(size_t size);
/// imageH_ and imageW_ is calculated from the input layer.
int imageH_, imageW_;
/// configImgH_ and configImgW_ is obtained from config.
int configImgH_, configImgW_;
int outputH_, outputW_;
int channels_, numFilters_;
int paddingH_, paddingW_;
int strideH_, strideW_;
int filterH_, filterW_;
/// One group offset of input data.
int inputOffset_;
/// One group offset of output data.
int outputOffset_;
/// One group offset of weight.
int weightOffset_;
int groups_;
/// Cudnn tensor descriptor for input.
hl_tensor_descriptor inputDesc_;
/// Cudnn tensor descriptor for output.
hl_tensor_descriptor outputDesc_;
/// Cudnn tensor descriptor for filter.
hl_filter_descriptor filterDesc_;
/// Cudnn tensor descriptor for a convolution operation.
hl_convolution_descriptor convDesc_;
/// Record the algorithm for forward convolution, which is obtained by cudnn
/// api to search the best suited algorithm.
int fwdAlgo_;
/// Record the algorithm for computing convolution gradient with respect to
/// filter coefficients.
int bwdFilterAlgo_;
/// Record the algorithm for computing convolution gradient with respect to
/// the output.
int bwdDataAlgo_;
/// Amount of GPU memory needed as workspace to be able to execute a
/// forward convolution with the specified algo.
size_t fwdLimitBytes_;
/// Amount of GPU memory needed as workspace to be able to execute a
/// backwardFilter with the specified algo.
size_t bwdDataLimitBytes_;
/// Amount of GPU memory needed as workspace to be able to execute a
/// backwardData with the specified algo.
size_t bwdFilterLimitBytes_;
/// Size of total work space.
size_t workSpaceInBytes_;
/// Whether to call cuDNN api to choose conv algorithm.
bool isSelectAlgo_;
/// batchNum is used to record batch size. If the batch size is changed,
/// the selection algorithm will be called.
int batchNum_;
bool bias_;
std::unique_ptr<Weight> weight_;
static ThreadLocalD<std::vector<MemoryHandle*>> convMem_;
};
} // namespace paddle

File diff suppressed because it is too large Load Diff

@ -17,12 +17,13 @@ limitations under the License. */
#include "ConvBaseLayer.h"
#include "paddle/math/Matrix.h"
#include "Projection.h"
#include <vector>
namespace paddle {
/**
* @brief A subclass of ConvBaseLayer by cuDNN implementation. It only
* @brief A 2-dimension conv layer implemented by cuDNN. It only
* supports GPU mode. We automatic select CudnnConvLayer for GPU
* mode and ExpandConvLayer for CPU mode if you set type of "conv".
* User also can specfiy type of "exconv" or "cudnn_conv" for
@ -31,81 +32,21 @@ namespace paddle {
* The config file api is img_conv_layer.
*/
class CudnnConvLayer : public ConvBaseLayer {
private:
/// resize Cudnn workspace size
void allocConvWorkSpace(size_t maxWorkSpace);
protected:
int imageH_, imageW_, outputH_, outputW_;
/// Cudnn tensor descriptor for bias.
std::vector<std::unique_ptr<ProjectionConfig>> projConf_;
std::vector<std::unique_ptr<Projection>> projections_;
hl_tensor_descriptor biasDesc_;
/// Cudnn tensor descriptor for input.
std::vector<hl_tensor_descriptor> inputDesc_;
/// Cudnn tensor descriptor for output.
std::vector<hl_tensor_descriptor> outputDesc_;
/// Cudnn tensor descriptor for filter.
std::vector<hl_filter_descriptor> filterDesc_;
/// Cudnn tensor descriptor for a convolution operation.
std::vector<hl_convolution_descriptor> convDesc_;
/// One sample offset of input data.
IntV inputOffset_;
/// One sample offset of output data.
IntV outputOffset_;
/// One group offset of weight.
IntV weightOffset_;
/// One group offset of bias.
hl_tensor_descriptor outputDesc_;
int biasOffset_;
/// Save the algorithm for forward convolution, which is obtained by cudnn
/// api to search the best suited algorithm.
std::vector<int> fwdAlgo_;
/// Save the algorithm for computing convolution gradient with respect to
/// filter coefficients.
std::vector<int> bwdFilterAlgo_;
/// Save the algorithm for computing convolution gradient with respect to
/// the output.
std::vector<int> bwdDataAlgo_;
/// Amount of GPU memory needed as workspace to be able to execute a
/// forward convolution with the specified algo.
std::vector<size_t> fwdLimitBytes_;
/// Amount of GPU memory needed as workspace to be able to execute a
/// backwardFilter with the specified algo.
std::vector<size_t> bwdFilterLimitBytes_;
/// Amount of GPU memory needed as workspace to be able to execute a
/// backwardData with the specified algo.
std::vector<size_t> bwdDataLimitBytes_;
/// Device work space address for each group.
std::vector<void*> workSpace_;
/// Max number of groups.
int maxGroups_;
/// Total work space address in device for all groups.
void* workSpaceData_;
/// Size of total work space.
size_t workSpaceInBytes_;
/// Is or not select conv algorihtm.
bool isSelectAlgo_;
/// batchNum is used to record batch size. If the batch size is changed,
/// the selection algorithm will be called.
int batchNum_;
int outputOffset_;
public:
explicit CudnnConvLayer(const LayerConfig& config) : ConvBaseLayer(config) {}
~CudnnConvLayer();
/**
* Intialization. Initialize member variables and create tenor descriptor.
*/
bool init(const LayerMap& layerMap, const ParameterMap& parameterMap);
/**
* Reshape is done each forward. Reshape tensor decriptor
* inputDesc_, outputDesc_, convDesc_. And search the faster algo
* or the fastest algo within a given memeory limit.
*/
void reshape(int batchSize);
void forward(PassType passType);
void backward(const UpdateCallback& callback);
void addBiases();

@ -37,32 +37,29 @@ bool ExpandConvLayer::init(const LayerMap &layerMap,
caffeMode_ = conf.caffe_mode();
}
/* initialize the weightList */
CHECK(inputLayers_.size() == parameters_.size());
for (size_t i = 0; i < inputLayers_.size(); i++) {
size_t height, width;
height = filterPixels_[i] * filterChannels_[i];
width = numFilters_;
// create a new weight
CHECK_EQ(parameters_[i]->getSize(), width * height);
Weight* w = new Weight(height, width, parameters_[i]);
weights_.emplace_back(w);
}
return true;
}
size_t ExpandConvLayer::getSize() {
size_t ExpandConvLayer::getOutputSize() {
CHECK_NE(inputLayers_.size(), 0UL);
imgSizeH_.clear();
imgSizeW_.clear();
outputH_.clear();
outputW_.clear();
size_t layerSize = ConvBaseLayer::calOutputSize();
subN_.clear();
size_t layerSize = 0;
for (size_t i = 0; i < inputLayers_.size(); i++) {
imgSizeH_.push_back(inputLayers_[i]->getOutput().getFrameHeight());
imgSizeW_.push_back(inputLayers_[i]->getOutput().getFrameWidth());
if (imgSizeH_[i] == 0) imgSizeH_[i] = imgSize_[i];
if (imgSizeW_[i] == 0) imgSizeW_[i] = imgSize_[i];
outputH_.push_back(
outputSize(imgSizeH_[i], filterSize_[i], padding_[i], stride_[i]));
outputW_.push_back(
outputSize(imgSizeW_[i], filterSize_[i], padding_[i], stride_[i]));
subN_.push_back(outputH_[i] * outputW_[i]);
CHECK(layerSize == 0 || subN_[i] * size_t(numFilters_) == layerSize);
layerSize = subN_[i] * numFilters_;
}
getOutput().setFrameHeight(outputH_[0]);
getOutput().setFrameWidth(outputW_[0]);
return layerSize;
}
@ -119,7 +116,7 @@ void ExpandConvLayer::expandFwdOnce(MatrixPtr image, int inIdx, int startIdx) {
}
void ExpandConvLayer::addSharedBias() {
size_t mapW = getSize() / numFilters_;
size_t mapW = getOutputValue()->getWidth() / numFilters_;
size_t mapH = getOutputValue()->getElementCnt() / mapW;
MatrixPtr out =
Matrix::create(getOutputValue()->getData(), mapH, mapW, false, useGpu_);
@ -158,7 +155,7 @@ void ExpandConvLayer::forward(PassType passType) {
* transOutValue correspond sample to one row */
int batchSize = inputLayers_[0]->getOutputValue()->getWidth();
batchSize = inputLayers_[0]->getOutputValue()->getHeight();
resetOutput(batchSize, getSize());
resetOutput(batchSize, getOutputSize());
MatrixPtr image = nullptr;
for (size_t i = 0; i != inputLayers_.size(); ++i) {
@ -183,7 +180,7 @@ void ExpandConvLayer::forward(PassType passType) {
}
void ExpandConvLayer::bpropSharedBias(MatrixPtr biases, MatrixPtr v) {
size_t mapW = getSize() / numFilters_;
size_t mapW = v->getWidth() / numFilters_;
size_t mapH = v->getElementCnt() / mapW;
MatrixPtr vTmp = Matrix::create(v->getData(), mapH, mapW, false, useGpu_);

@ -37,14 +37,6 @@ protected:
IntV subN_;
/// subK_ = channels_ * filterPixels_ * groups_.
IntV subK_;
/// The spatial dimensions of height of input feature map.
IntV imgSizeH_;
/// The spatial dimensions of width of input feature map.
IntV imgSizeW_;
/// The spatial dimensions of height of output feature map.
IntV outputH_;
/// The spatial dimensions of width of output feature map.
IntV outputW_;
/// Expand one sample at a time. shape:
/// (numChannels * filterPixels_, outputSizeH * outputSizeW)
MatrixPtr expandInput_;
@ -58,7 +50,7 @@ public:
bool init(const LayerMap& layerMap, const ParameterMap& parameterMap);
size_t getSize();
size_t getOutputSize();
/**
* Create or resize expandInput_.

@ -41,9 +41,13 @@ bool MixedLayer::init(const LayerMap& layerMap,
}
operators_.emplace_back(Operator::create(operator_conf, useGpu_));
}
/* initialize biases_ */
if (biasParameter_.get() != NULL) {
biases_ = std::unique_ptr<Weight>(new Weight(1, getSize(), biasParameter_));
sharedBias_ = config_.shared_biases();
size_t psize = config_.bias_size();
biases_ = std::unique_ptr<Weight>(
new Weight(1, psize, biasParameter_));
}
return true;
@ -119,12 +123,6 @@ void MixedLayer::forward(PassType passType) {
MatrixPtr outV = getOutputValue();
/* add the bias-vector */
if (biases_.get() != NULL) {
REGISTER_TIMER_INFO("FwBiasTimer", getName().c_str());
outV->addBias(*(biases_->getW()), 1);
}
for (size_t i = 0; i != inputLayers_.size(); ++i) {
if (projections_[i]) {
projections_[i]->forward(&getInput(i), &output_, passType);
@ -140,6 +138,12 @@ void MixedLayer::forward(PassType passType) {
op->forward(ins, &output_, passType);
}
/* add the bias-vector */
if (biases_.get() != NULL) {
REGISTER_TIMER_INFO("FwBiasTimer", getName().c_str());
outV->addBias(*(biases_->getW()), 1, sharedBias_);
}
/* activation */ {
REGISTER_TIMER_INFO("FwAtvTimer", getName().c_str());
forwardActivation();
@ -154,7 +158,7 @@ void MixedLayer::backward(const UpdateCallback& callback) {
if (biases_ && biases_->getWGrad()) {
REGISTER_TIMER_INFO("BpBiasTimer", getName().c_str());
biases_->getWGrad()->collectBias(*getOutputGrad(), 1);
biases_->getWGrad()->collectBias(*getOutputGrad(), 1, sharedBias_);
/* Increasing the number of gradient */
biases_->getParameterPtr()->incUpdate(callback);

@ -58,5 +58,6 @@ protected:
/// the matrix size of projection state
std::vector<int> projectionStateMatrixSize_;
std::unique_ptr<Weight> biases_;
bool sharedBias_;
};
} // namespace paddle

@ -669,12 +669,14 @@ void testLayerGrad(TestConfig testConf, string testLayerName, size_t batchSize,
void testProjectionGrad(ProjectionConfig conf, InputType inputType,
size_t parameterSize, size_t batchSize, bool useGpu,
bool testState) {
bool testState, int biasSize, bool sharedBias) {
TestConfig config;
conf.set_name(conf.type());
config.layerConfig.set_type("mixed");
config.layerConfig.set_size(conf.output_size());
config.biasSize = config.layerConfig.size();
config.biasSize = biasSize == 0 ? config.layerConfig.size() : biasSize;
config.layerConfig.set_bias_size(config.biasSize);
config.layerConfig.set_shared_biases(sharedBias);
config.inputDefs.push_back(
{inputType, "layer_0", conf.input_size(), parameterSize});
*config.layerConfig.add_inputs()->mutable_proj_conf() = conf;

@ -217,7 +217,8 @@ void testLayerGrad(TestConfig testConf, string testLayerName, size_t batchSize,
void testProjectionGrad(ProjectionConfig conf, InputType inputType,
size_t parameterSize, size_t batchSize, bool useGpu,
bool testState = false);
bool testState = false, int biasSize = 0,
bool sharedBias = false);
void testOperatorGrad(TestConfig& config, OperatorConfig& operatorConf,
size_t batchSize, bool useGpu, bool testState = false);

@ -0,0 +1,39 @@
#edit-mode: -*- python -*-
# Copyright (c) 2016 Baidu, Inc. All Rights Reserved
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from paddle.trainer_config_helpers import *
settings(batch_size=10)
data = data_layer(name ="input", size=8*16*16)
conv1 = img_conv_layer(input=data, filter_size=1, filter_size_y=1,
num_channels=8,
num_filters=16, stride=1,
bias_attr=False,
act=ReluActivation())
conv2 = img_conv_layer(input=data, filter_size=1, filter_size_y=1,
num_channels=8,
num_filters=16, stride=1,
bias_attr=False,
act=ReluActivation())
concat = concat_layer(input=[conv1, conv2])
conv = img_conv_layer(input=data, filter_size=1, filter_size_y=1,
num_channels=8,
num_filters=16, stride=1,
bias_attr=True,
act=LinearActivation())
outputs(concat, conv)

@ -0,0 +1,32 @@
#edit-mode: -*- python -*-
# Copyright (c) 2016 Baidu, Inc. All Rights Reserved
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from paddle.trainer_config_helpers import *
settings(batch_size=10)
data = data_layer(name ="input", size=8*16*16)
proj1 = conv_projection(input=data, filter_size=1, filter_size_y=1,
num_channels=8, num_filters=16, stride=1)
proj2 = conv_projection(input=data, filter_size=1, filter_size_y=1,
num_channels=8, num_filters=16, stride=1)
concat = concat_layer(input=[proj1, proj2], bias_attr=False, act=ReluActivation())
proj = conv_projection(input=data, filter_size=1, filter_size_y=1,
num_channels=8, num_filters=16, stride=1)
with mixed_layer(bias_attr=True, act=LinearActivation()) as conv:
conv += proj
outputs(concat, conv)

@ -134,6 +134,45 @@ TEST(Projection, identity) {
}
}
#ifndef PADDLE_ONLY_CPU
TEST(Projection, conv) {
const int NUM_FILTERS = 16;
const int FILTER_SIZE = 2;
const int FILTER_SIZE_Y = 3;
const int CHANNELS = 3;
const int IMAGE_SIZE = 16;
ProjectionConfig conf;
conf.set_type("conv");
conf.set_num_filters(NUM_FILTERS);
ConvConfig* conv = conf.mutable_conv_conf();
conv->set_filter_size(FILTER_SIZE);
conv->set_filter_size_y(FILTER_SIZE_Y);
conv->set_channels(CHANNELS);
conv->set_padding(0);
conv->set_padding_y(1);
conv->set_stride(2);
conv->set_stride_y(2);
conv->set_groups(1);
conv->set_filter_channels(conv->channels() / conv->groups());
conv->set_img_size(IMAGE_SIZE);
int outputSize = (2 * conv->padding() + conv->img_size() -
conv->filter_size()) / conv->stride() + 1;
int outputSizeY = (2 * conv->padding_y() + conv->img_size() -
conv->filter_size_y()) / conv->stride_y() + 1;
conv->set_output_x(outputSize);
conf.set_input_size(IMAGE_SIZE * IMAGE_SIZE * CHANNELS);
conf.set_output_size(outputSize * outputSizeY * NUM_FILTERS);
testProjectionGrad(conf, INPUT_DATA,
/* parameterSize */ NUM_FILTERS * CHANNELS * FILTER_SIZE * FILTER_SIZE_Y,
/* batchSize */ 100, true, false, NUM_FILTERS, true);
}
#endif
TEST(Layer, concat) {
TestConfig config;
config.biasSize = 0;

@ -236,6 +236,15 @@ TEST(Compare, img_pool) {
compareNetwork(config_file_a, config_file_b);
FLAGS_use_gpu = useGpu;
}
TEST(Compare, img_conv) {
std::string config_file_a = "./gserver/tests/img_conv_a.conf";
std::string config_file_b = "./gserver/tests/img_conv_b.conf";
bool useGpu = FLAGS_use_gpu;
FLAGS_use_gpu = true;
compareNetwork(config_file_a, config_file_b);
FLAGS_use_gpu = useGpu;
}
#endif

Some files were not shown because too many files have changed in this diff Show More

Loading…
Cancel
Save