Add job=time in trainer, refine cudnn_conv to reduce GPU memory and speed up training. (#218)

* Add benchmark for PaddlePaddle, TensorFlow, and Caffe.
* Add ConvProjection to reduce memory for GoogLeNet.
* Add unit tests for ConvProjection: 1. a unit test in test_LayerGrad; 2. compare ConvProjection with CudnnConvLayer, and compare concat_layer+img_conv_layer with concat_layer+conv_projection.
* Reduce cudnn_conv memory and add a benchmark document: 1. use TmpMatrix as the workspace in cudnn_conv, which saves a large amount of GPU memory; 2. add the benchmark document; 3. fix smallnet_mnist_cifar.py in paddle.
* Add job=time and refine cudnn_conv to reduce GPU memory and speed up training.
* Refine cudnn_conv and the shared-biases operation in concat_layer and mixed_layer.
* Follow review comments.
* Use unique_ptr to prevent memory leaks in CudnnConvLayer.
parent 12945b2c90
commit 45c81a414f
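The main memory saving in this change comes from letting every cuDNN convolution on a device share one lazily grown workspace (see ConvProjection::getSpaceBytes in the diff below) instead of each layer keeping its own. The following is a minimal, standalone Python sketch of that idea only; the names (WorkspacePool, request) are illustrative and not part of PaddlePaddle's API, and the real implementation uses per-thread GPU MemoryHandle buffers rather than Python bytearrays.

```python
# Hypothetical sketch of a shared, lazily grown per-device workspace pool.
# Not PaddlePaddle code: it only illustrates why sharing one buffer sized to
# the largest request uses far less memory than one workspace per conv layer.

class WorkspacePool:
    def __init__(self):
        # one buffer (here just a bytearray) per device id
        self._buffers = {}

    def request(self, device_id, size_in_bytes):
        buf = self._buffers.get(device_id)
        if buf is None or len(buf) < size_in_bytes:
            # grow only when a larger workspace is needed
            buf = bytearray(size_in_bytes)
            self._buffers[device_id] = buf
        return buf


pool = WorkspacePool()
# Three conv layers on device 0 asking for 40 MB, 10 MB and 25 MB workspaces
# end up sharing a single 40 MB buffer instead of allocating 75 MB in total.
for need in (40 << 20, 10 << 20, 25 << 20):
    ws = pool.request(device_id=0, size_in_bytes=need)
    assert len(ws) >= need
```

Sharing is safe in this setting because the convolutions on one device run sequentially within a thread, so they never need their workspaces at the same time.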
@@ -0,0 +1,210 @@
/* Copyright (c) 2016 Baidu, Inc. All Rights Reserve.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */

#include "paddle/utils/Stat.h"
#include "ConvProjection.h"

namespace paddle {

REGISTER_PROJECTION(conv, ConvProjection);

ThreadLocalD<std::vector<MemoryHandle*>> ConvProjection::convMem_;

ConvProjection::ConvProjection(const ProjectionConfig& config,
                               ParameterPtr parameter, bool useGpu)
    : Projection(config, parameter, useGpu) {
  CHECK(useGpu);  // only support GPU
  getConvParams();
  initCudnn();

  size_t height = filterH_ * filterW_ * channels_ / groups_;
  size_t width = numFilters_;
  weight_.reset(new Weight(height, width, parameter));
  weightOffset_ = height * width / groups_;
}

void ConvProjection::getConvParams() {
  const ConvConfig &conf = config_.conv_conf();
  paddingH_ = conf.padding_y();
  paddingW_ = conf.padding();

  strideH_ = conf.stride_y();
  strideW_ = conf.stride();

  filterH_ = conf.filter_size_y();
  filterW_ = conf.filter_size();

  configImgH_ = conf.img_size();
  configImgW_ = conf.img_size();

  channels_ = conf.channels();
  numFilters_ = config_.num_filters();

  groups_ = conf.groups();
  CHECK_EQ(channels_ % groups_, 0);
  CHECK_EQ(numFilters_ % groups_, 0);
}

void ConvProjection::initCudnn() {
  hl_create_filter_descriptor(&filterDesc_, channels_, numFilters_,
                              filterH_, filterW_);
  hl_create_tensor_descriptor(&inputDesc_);
  hl_create_tensor_descriptor(&outputDesc_);
  hl_create_convolution_descriptor(&convDesc_, inputDesc_, filterDesc_,
                                   paddingH_, paddingW_, strideH_, strideW_);

  // initialize all to default algorithms
  fwdAlgo_ = 0;
  bwdFilterAlgo_ = 0;
  bwdDataAlgo_ = 0;
  fwdLimitBytes_ = 0;
  bwdDataLimitBytes_ = 0;
  bwdFilterLimitBytes_ = 0;
  workSpaceInBytes_ = 0;

  batchNum_ = 0;
  isSelectAlgo_ = false;
}

void ConvProjection::reshapeTensorDesc(int batchSize) {
  hl_tensor_reshape(inputDesc_, batchSize, channels_, imageH_, imageW_,
                    channels_ * imageH_ * imageW_, imageH_ * imageW_,
                    imageW_, 1);
  hl_reset_convolution_descriptor(convDesc_, inputDesc_, filterDesc_,
                                  paddingH_, paddingW_, strideH_, strideW_);

  // The stride between two consecutive samples in a ConvProjection may not
  // equal numFilters_ * outputH_ * outputW_. For example, when a layer such
  // as ConcatenateLayer2 holds two ConvProjections, the stride is the
  // output_size of that layer. So the calculation of nStride here differs
  // from CudnnConvLayer; in fact, "nStride = out_->value->getStride()" alone
  // would be sufficient.
  size_t nStride = numFilters_ * outputH_ * outputW_;
  if (out_->value->isContiguous()) {
    CHECK_EQ(nStride, out_->value->getWidth());
  } else {
    nStride = out_->value->getStride();
  }

  hl_tensor_reshape(outputDesc_, batchSize, numFilters_, outputH_, outputW_,
                    nStride, outputH_ * outputW_, outputW_, 1);
}

void ConvProjection::reshape(int batchSize) {
  size_t width = calOutputSize();
  CHECK_EQ(width, out_->value->getWidth());

  isSelectAlgo_ = (batchSize == batchNum_);
  batchNum_ = batchSize;

  if (!isSelectAlgo_) {
    reshapeTensorDesc(batchSize);
    hl_conv_workspace(inputDesc_, outputDesc_, filterDesc_,
                      convDesc_, &fwdAlgo_, &fwdLimitBytes_,
                      &bwdDataAlgo_, &bwdDataLimitBytes_,
                      &bwdFilterAlgo_, &bwdFilterLimitBytes_);

    size_t maxWorkSpace = 0;
    maxWorkSpace = std::max(fwdLimitBytes_, bwdDataLimitBytes_);
    maxWorkSpace = std::max(maxWorkSpace, bwdFilterLimitBytes_);
    workSpaceInBytes_ = maxWorkSpace;

    VLOG(3) << getName() << " Fwd / BwdData / BwdFilter algo: " << fwdAlgo_
            << " / " << bwdDataAlgo_
            << " / " << bwdFilterAlgo_;
  }

  isSelectAlgo_ = true;
}

void ConvProjection::forward() {
  int batchSize = in_->value->getHeight();
  reshape(batchSize);

  void* workSpace = NULL;
  if (workSpaceInBytes_ > 0) {
    workSpace = getSpaceBytes(workSpaceInBytes_);
  }

  for (int g = 0; g < groups_; ++g) {
    REGISTER_TIMER_INFO("CudnnConvFwTimer", getName().c_str());

    real *inputData = in_->value->getData() + g * inputOffset_;
    real *wgtData = weight_->getW()->getData() + g * weightOffset_;
    real *outData = out_->value->getData() + g * outputOffset_;
    hl_convolution_forward(inputDesc_, inputData, outputDesc_,
                           outData, filterDesc_, wgtData,
                           convDesc_, workSpace,
                           fwdLimitBytes_, fwdAlgo_);
  }
}

void ConvProjection::backward(const UpdateCallback& callback) {
  REGISTER_TIMER_INFO("CudnnConvBpTimer", getName().c_str());

  void* workSpace = NULL;
  if (workSpaceInBytes_ > 0) {
    workSpace = getSpaceBytes(workSpaceInBytes_);
  }

  for (int g = 0; g < groups_; ++g) {
    real *outGrad = out_->grad->getData() + g * outputOffset_;
    if (weight_->getWGrad()) {
      real *inputData = in_->value->getData() + g * inputOffset_;
      real *weightGrad = weight_->getWGrad()->getData() + g * weightOffset_;
      hl_convolution_backward_filter(
          inputDesc_, inputData, outputDesc_, outGrad, filterDesc_,
          weightGrad, convDesc_, workSpace, bwdFilterLimitBytes_,
          bwdFilterAlgo_);
    }

    MatrixPtr preGrad = in_->grad;
    if (NULL != preGrad) {
      real *inputGrad = preGrad->getData() + g * inputOffset_;
      real *wgtData = weight_->getW()->getData() + g * weightOffset_;
      hl_convolution_backward_data(
          inputDesc_, inputGrad, outputDesc_, outGrad, filterDesc_,
          wgtData, convDesc_, workSpace, bwdDataLimitBytes_,
          bwdDataAlgo_);
    }
  }

  weight_->getParameterPtr()->incUpdate(callback);
}

void* ConvProjection::getSpaceBytes(size_t size) {
  std::vector<MemoryHandle*>& convMem = *convMem_;
  if (convMem.empty()) {
    int numDevices = hl_get_device_count();
    convMem.resize(numDevices);
  }

  int devId = hl_get_device();
  MemoryHandle** localMem = &(convMem[devId]);
  if (NULL == *localMem || size > (*localMem)->getAllocSize()) {
    *localMem = new GpuMemoryHandle(size);
  }
  return (*localMem)->getBuf();
}

ConvProjection::~ConvProjection() {
  hl_destroy_tensor_descriptor(inputDesc_);
  hl_destroy_tensor_descriptor(outputDesc_);
  hl_destroy_filter_descriptor(filterDesc_);
  hl_destroy_convolution_descriptor(convDesc_);
}

}  // namespace paddle
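The nStride logic in reshapeTensorDesc above is easiest to see with concrete numbers. The sketch below is plain Python arithmetic with illustrative values (not taken from the tests); it shows why a ConvProjection that writes into a concat_layer's output must use the concatenated row width, not its own numFilters_ * outputH_ * outputW_, as the stride between samples.

```python
# Illustrative numbers: two conv projections, each with 16 filters on a 16x16
# output, concatenated into one output matrix (one row per sample).
num_filters, out_h, out_w = 16, 16, 16
per_projection = num_filters * out_h * out_w   # 4096 values per sample
concat_width = 2 * per_projection              # 8192: both branches side by side

# A standalone projection is contiguous: its sample stride equals its own width.
standalone_stride = per_projection             # 4096

# Inside the concat output, sample i of this projection starts at
# i * concat_width + branch_offset, so the stride between consecutive samples
# is the full concatenated width, which is what out_->value->getStride() returns.
concat_stride = concat_width                   # 8192

assert concat_stride != per_projection
print(standalone_stride, concat_stride)
```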
@@ -0,0 +1,125 @@
/* Copyright (c) 2016 Baidu, Inc. All Rights Reserve.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */

#pragma once

#include "Projection.h"

namespace paddle {

/**
 * @brief ConvProjection performs the same calculation as CudnnConvLayer.
 */
class ConvProjection : public Projection {
public:
  /**
   * Constructor.
   */
  ConvProjection(const ProjectionConfig& config, ParameterPtr parameter,
                 bool useGpu);

  ~ConvProjection();

  virtual void forward();
  virtual void backward(const UpdateCallback& callback);

protected:
  void getConvParams();
  void initCudnn();

  void reshapeTensorDesc(int batchSize);
  void reshape(int batchSize);

  int outputSize(int imageSize, int filterSize, int padding, int stride) {
    return (imageSize - filterSize + 2 * padding) / stride + 1;
  }

  size_t calOutputSize() {
    imageH_ = in_->getFrameHeight();
    imageW_ = in_->getFrameWidth();
    if (imageH_ == 0) imageH_ = configImgH_;
    if (imageW_ == 0) imageW_ = configImgW_;
    outputH_ = outputSize(imageH_, filterH_, paddingH_, strideH_);
    outputW_ = outputSize(imageW_, filterW_, paddingW_, strideW_);

    const_cast<Argument*>(out_)->setFrameHeight(outputH_);
    const_cast<Argument*>(out_)->setFrameWidth(outputW_);

    inputOffset_ = (channels_ / groups_) * imageH_ * imageW_;
    outputOffset_ = (numFilters_ / groups_) * outputH_ * outputW_;
    return outputH_ * outputW_ * numFilters_;
  }

  static void* getSpaceBytes(size_t size);

  /// imageH_ and imageW_ are calculated from the input layer.
  int imageH_, imageW_;
  /// configImgH_ and configImgW_ are obtained from the config.
  int configImgH_, configImgW_;
  int outputH_, outputW_;
  int channels_, numFilters_;
  int paddingH_, paddingW_;
  int strideH_, strideW_;
  int filterH_, filterW_;
  /// One-group offset of the input data.
  int inputOffset_;
  /// One-group offset of the output data.
  int outputOffset_;
  /// One-group offset of the weight.
  int weightOffset_;
  int groups_;

  /// cuDNN tensor descriptor for the input.
  hl_tensor_descriptor inputDesc_;
  /// cuDNN tensor descriptor for the output.
  hl_tensor_descriptor outputDesc_;
  /// cuDNN filter descriptor.
  hl_filter_descriptor filterDesc_;
  /// cuDNN descriptor for the convolution operation.
  hl_convolution_descriptor convDesc_;

  /// Algorithm for the forward convolution, obtained from the cuDNN API
  /// that searches for the best-suited algorithm.
  int fwdAlgo_;
  /// Algorithm for computing the convolution gradient with respect to the
  /// filter coefficients.
  int bwdFilterAlgo_;
  /// Algorithm for computing the convolution gradient with respect to the
  /// input data.
  int bwdDataAlgo_;
  /// Amount of GPU memory needed as workspace to execute a forward
  /// convolution with the specified algorithm.
  size_t fwdLimitBytes_;
  /// Amount of GPU memory needed as workspace to execute backwardData
  /// with the specified algorithm.
  size_t bwdDataLimitBytes_;
  /// Amount of GPU memory needed as workspace to execute backwardFilter
  /// with the specified algorithm.
  size_t bwdFilterLimitBytes_;
  /// Size of the total workspace.
  size_t workSpaceInBytes_;

  /// Whether the cuDNN convolution algorithms have already been selected;
  /// if false, the cuDNN API is called to choose them.
  bool isSelectAlgo_;
  /// batchNum_ records the batch size. If the batch size changes, the
  /// algorithm selection is run again.
  int batchNum_;
  bool bias_;

  std::unique_ptr<Weight> weight_;
  static ThreadLocalD<std::vector<MemoryHandle*>> convMem_;
};

}  // namespace paddle
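As a quick check of the shape bookkeeping in the header above, here is the outputSize formula and the per-group offsets evaluated for one illustrative configuration (the numbers are made up for the example, not taken from the tests).

```python
def output_size(image_size, filter_size, padding, stride):
    # mirrors ConvProjection::outputSize
    return (image_size - filter_size + 2 * padding) // stride + 1

# Example: 16x16 input, 8 channels, 16 filters, 3x3 filter, padding 1,
# stride 1, groups = 2.
image_h = image_w = 16
channels, num_filters, groups = 8, 16, 2
filter_hw, padding, stride = 3, 1, 1

out_h = output_size(image_h, filter_hw, padding, stride)  # 16
out_w = output_size(image_w, filter_hw, padding, stride)  # 16

input_offset = (channels // groups) * image_h * image_w    # 1024 values per group
output_offset = (num_filters // groups) * out_h * out_w    # 2048 values per group
row_width = out_h * out_w * num_filters                    # 4096, the value calOutputSize returns

print(out_h, out_w, input_offset, output_offset, row_width)
```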
File diff suppressed because it is too large.
@@ -0,0 +1,39 @@
#edit-mode: -*- python -*-
# Copyright (c) 2016 Baidu, Inc. All Rights Reserved
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from paddle.trainer_config_helpers import *

settings(batch_size=10)
data = data_layer(name="input", size=8 * 16 * 16)
conv1 = img_conv_layer(input=data, filter_size=1, filter_size_y=1,
                       num_channels=8,
                       num_filters=16, stride=1,
                       bias_attr=False,
                       act=ReluActivation())
conv2 = img_conv_layer(input=data, filter_size=1, filter_size_y=1,
                       num_channels=8,
                       num_filters=16, stride=1,
                       bias_attr=False,
                       act=ReluActivation())

concat = concat_layer(input=[conv1, conv2])

conv = img_conv_layer(input=data, filter_size=1, filter_size_y=1,
                      num_channels=8,
                      num_filters=16, stride=1,
                      bias_attr=True,
                      act=LinearActivation())

outputs(concat, conv)
@@ -0,0 +1,32 @@
#edit-mode: -*- python -*-
# Copyright (c) 2016 Baidu, Inc. All Rights Reserved
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from paddle.trainer_config_helpers import *

settings(batch_size=10)
data = data_layer(name="input", size=8 * 16 * 16)
proj1 = conv_projection(input=data, filter_size=1, filter_size_y=1,
                        num_channels=8, num_filters=16, stride=1)
proj2 = conv_projection(input=data, filter_size=1, filter_size_y=1,
                        num_channels=8, num_filters=16, stride=1)
concat = concat_layer(input=[proj1, proj2], bias_attr=False, act=ReluActivation())

proj = conv_projection(input=data, filter_size=1, filter_size_y=1,
                       num_channels=8, num_filters=16, stride=1)

with mixed_layer(bias_attr=True, act=LinearActivation()) as conv:
    conv += proj

outputs(concat, conv)
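The commit message presents ConvProjection as a memory reduction for GoogLeNet, where many convolution branches feed a single concatenation. As a hedged sketch of that usage pattern, the config below combines two conv_projection branches in one concat_layer, in the same style as the test config above. The layer sizes are illustrative, and the padding argument is assumed to be accepted by conv_projection (the underlying ConvConfig carries a padding field); treat it as an assumption rather than verified API.

```python
from paddle.trainer_config_helpers import *

settings(batch_size=10)
data = data_layer(name="input", size=8 * 16 * 16)

# Two branches of an inception-style block, both producing 16x16 feature maps
# so they can be concatenated along the channel dimension.
branch1x1 = conv_projection(input=data, filter_size=1, filter_size_y=1,
                            num_channels=8, num_filters=16, stride=1)
branch3x3 = conv_projection(input=data, filter_size=3, filter_size_y=3,
                            num_channels=8, num_filters=16, stride=1,
                            padding=1)  # assumed parameter; keeps the 16x16 size

# One concat_layer applies the activation (and optional shared biases) once,
# while all branches share the per-device cuDNN workspace of ConvProjection.
inception_out = concat_layer(input=[branch1x1, branch3x3],
                             bias_attr=False, act=ReluActivation())

outputs(inception_out)
```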
Some files were not shown because too many files have changed in this diff.