@ -12,8 +12,8 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
See the License for the specific language governing permissions and
limitations under the License . */
limitations under the License . */
# include "paddle/math/Matrix.h"
# include "paddle/math/Matrix.h"
# include "paddle/math/MathUtils.h"
# include "Operator.h"
# include "Operator.h"
namespace paddle {
namespace paddle {
@ -35,8 +35,8 @@ public:
*/
*/
virtual ~ ConvOperator ( ) {
virtual ~ ConvOperator ( ) {
if ( workSpaceInBytes_ ! = 0 ) {
if ( workSpaceInBytes_ ! = 0 ) {
hl_free_mem_device ( workSpace_ ) ;
hl_free_mem_device ( workSpace_ ) ;
workSpaceInBytes_ = 0 ;
workSpaceInBytes_ = 0 ;
}
}
hl_destroy_tensor_descriptor ( inputDesc_ ) ;
hl_destroy_tensor_descriptor ( inputDesc_ ) ;
@ -83,33 +83,6 @@ private:
filterSize_ * filterSizeY_ * channels_ * numFilters_ ) ;
filterSize_ * filterSizeY_ * channels_ * numFilters_ ) ;
}
}
/**
* Calculate output size .
*/
int outputSize ( int imageSize , int filterSize , int padding , int stride ) {
int outputSize ;
if ( ! caffeMode_ ) {
/* input(+padding): 0123456789
* imageSize ( + padding ) = 10 ;
* filterSize = 3 ;
* stride = 2 ;
* output : ( 012 ) , ( 234 ) , ( 456 ) , ( 678 ) , ( 9 )
* outputSize = 5 ;
*/
outputSize =
( imageSize - filterSize + 2 * padding + stride - 1 ) / stride + 1 ;
} else {
/* input(+padding): 0123456789
* imageSize ( + padding ) = 10 ;
* filterSize = 3 ;
* stride = 2 ;
* output : ( 012 ) , ( 234 ) , ( 456 ) , ( 678 )
* outputSize = 4 ;
*/
outputSize = ( imageSize - filterSize + 2 * padding ) / stride + 1 ;
}
return outputSize ;
}
/// Most of member variables are same with CudnnConvLayer.
/// Most of member variables are same with CudnnConvLayer.
/// There is no explanation here.
/// There is no explanation here.
int imageH_ , imageW_ , outputH_ , outputW_ ;
int imageH_ , imageW_ , outputH_ , outputW_ ;
@ -129,7 +102,7 @@ private:
int fwdAlgo_ , bwdFilterAlgo_ , bwdDataAlgo_ ;
int fwdAlgo_ , bwdFilterAlgo_ , bwdDataAlgo_ ;
size_t fwdLimitBytes_ , bwdDataLimitBytes_ , bwdFilterLimitBytes_ ;
size_t fwdLimitBytes_ , bwdDataLimitBytes_ , bwdFilterLimitBytes_ ;
size_t workSpaceInBytes_ ;
size_t workSpaceInBytes_ ;
void * workSpace_ ;
void * workSpace_ ;
bool isSelectAlgo_ ;
bool isSelectAlgo_ ;
} ;
} ;
@ -160,7 +133,7 @@ ConvOperator::ConvOperator(const OperatorConfig &config, bool useGpu)
void ConvOperator : : allocConvWorkSpace ( size_t maxWorkSpace ) {
void ConvOperator : : allocConvWorkSpace ( size_t maxWorkSpace ) {
if ( maxWorkSpace > workSpaceInBytes_ ) {
if ( maxWorkSpace > workSpaceInBytes_ ) {
if ( workSpaceInBytes_ ! = 0 ) {
if ( workSpaceInBytes_ ! = 0 ) {
hl_free_mem_device ( workSpace_ ) ;
hl_free_mem_device ( workSpace_ ) ;
}
}
// total amount of storage needed
// total amount of storage needed
workSpace_ = hl_malloc_device ( maxWorkSpace ) ;
workSpace_ = hl_malloc_device ( maxWorkSpace ) ;
@ -168,14 +141,13 @@ void ConvOperator::allocConvWorkSpace(size_t maxWorkSpace) {
}
}
}
}
void ConvOperator : : reshape ( int batchSize ) {
void ConvOperator : : reshape ( int batchSize ) {
imageH_ = ins_ [ 0 ] - > getFrameHeight ( ) ;
imageH_ = ins_ [ 0 ] - > getFrameHeight ( ) ;
imageW_ = ins_ [ 0 ] - > getFrameWidth ( ) ;
imageW_ = ins_ [ 0 ] - > getFrameWidth ( ) ;
if ( imageH_ = = 0 ) imageH_ = imgSize_ ;
if ( imageH_ = = 0 ) imageH_ = imgSize_ ;
if ( imageW_ = = 0 ) imageW_ = imgSize_ ;
if ( imageW_ = = 0 ) imageW_ = imgSize_ ;
outputH_ = outputSize ( imageH_ , filterSizeY_ , paddingY_ , strideY_ );
outputH_ = outputSize ( imageH_ , filterSizeY_ , paddingY_ , strideY_ , caffeMode_ );
outputW_ = outputSize ( imageW_ , filterSize_ , padding_ , stride_ );
outputW_ = outputSize ( imageW_ , filterSize_ , padding_ , stride_ , caffeMode_ );
out_ - > setFrameHeight ( outputH_ ) ;
out_ - > setFrameHeight ( outputH_ ) ;
out_ - > setFrameWidth ( outputW_ ) ;
out_ - > setFrameWidth ( outputW_ ) ;
@ -183,10 +155,10 @@ void ConvOperator::reshape(int batchSize) {
reshapeImageDescriptors ( ) ;
reshapeImageDescriptors ( ) ;
if ( ! isSelectAlgo_ ) {
if ( ! isSelectAlgo_ ) {
hl_conv_workspace ( inputDesc_ , outputDesc_ , filterDesc_ ,
hl_conv_workspace ( inputDesc_ , outputDesc_ , filterDesc_ , convDesc_ ,
convDesc_ , & fwdAlgo_ , & fwdLimitBytes_ ,
& fwdAlgo_ , & fwdLimitBytes_ , & bwdDataAlgo_ ,
& bwdDataAlgo_ , & bwdDataLimitBytes _,
& bwdDataLimitBytes_ , & bwdFilterAlgo _,
& bwdFilterAlgo_ , & bwdFilterLimitBytes_ ) ;
& bwdFilterLimitBytes_ ) ;
size_t maxWorkSpace = 0 ;
size_t maxWorkSpace = 0 ;
maxWorkSpace = std : : max ( fwdLimitBytes_ , bwdDataLimitBytes_ ) ;
maxWorkSpace = std : : max ( fwdLimitBytes_ , bwdDataLimitBytes_ ) ;
@ -202,7 +174,8 @@ void ConvOperator::computeConvSizes() {
hl_create_filter_descriptor ( & filterDesc_ , channels_ , numFilters_ ,
hl_create_filter_descriptor ( & filterDesc_ , channels_ , numFilters_ ,
filterSizeY_ , filterSize_ ) ;
filterSizeY_ , filterSize_ ) ;
hl_create_tensor_descriptor ( & inputDesc_ ) ;
hl_create_tensor_descriptor ( & inputDesc_ ) ;
int outputX = outputSize ( imgSize_ , filterSize_ , padding_ , stride_ ) ;
int outputX =
outputSize ( imgSize_ , filterSize_ , padding_ , stride_ , caffeMode_ ) ;
CHECK_EQ ( outputX , outputX_ ) ;
CHECK_EQ ( outputX , outputX_ ) ;
hl_create_tensor_descriptor ( & outputDesc_ ) ;
hl_create_tensor_descriptor ( & outputDesc_ ) ;
hl_create_convolution_descriptor ( & convDesc_ , inputDesc_ , filterDesc_ ,
hl_create_convolution_descriptor ( & convDesc_ , inputDesc_ , filterDesc_ ,
@ -211,13 +184,13 @@ void ConvOperator::computeConvSizes() {
void ConvOperator : : reshapeImageDescriptors ( ) {
void ConvOperator : : reshapeImageDescriptors ( ) {
hl_tensor_reshape ( inputDesc_ , 1 , channels_ , imageH_ , imageW_ ,
hl_tensor_reshape ( inputDesc_ , 1 , channels_ , imageH_ , imageW_ ,
channels_ * imageH_ * imageW_ , imageH_ * imageW_ ,
channels_ * imageH_ * imageW_ , imageH_ * imageW_ , imageW_ ,
imageW_ , 1 ) ;
1 ) ;
hl_tensor_reshape ( outputDesc_ , 1 , numFilters_ , outputH_ , outputW_ ,
hl_tensor_reshape ( outputDesc_ , 1 , numFilters_ , outputH_ , outputW_ ,
numFilters_ * outputH_ * outputW_ , outputH_ * outputW_ ,
numFilters_ * outputH_ * outputW_ , outputH_ * outputW_ ,
outputW_ , 1 ) ;
outputW_ , 1 ) ;
hl_reset_convolution_descriptor ( convDesc_ , inputDesc_ , filterDesc_ ,
hl_reset_convolution_descriptor ( convDesc_ , inputDesc_ , filterDesc_ , paddingY_ ,
padding Y_, padding _, strideY_ , stride_ ) ;
padding _, strideY_ , stride_ ) ;
inputOffset_ = channels_ * imageH_ * imageW_ ;
inputOffset_ = channels_ * imageH_ * imageW_ ;
outputOffset_ = numFilters_ * outputH_ * outputW_ ;
outputOffset_ = numFilters_ * outputH_ * outputW_ ;
weightOffset_ = numFilters_ * channels_ * filterSize_ * filterSize_ ;
weightOffset_ = numFilters_ * channels_ * filterSize_ * filterSize_ ;
@ -273,18 +246,17 @@ void ConvOperator::backward() {
real * weightGrad = ins_ [ 1 ] - > grad - > getData ( ) + weightOffset_ * batchId ;
real * weightGrad = ins_ [ 1 ] - > grad - > getData ( ) + weightOffset_ * batchId ;
hl_convolution_backward_filter ( inputDesc_ , inputData , outputDesc_ ,
hl_convolution_backward_filter ( inputDesc_ , inputData , outputDesc_ ,
outGrad , filterDesc_ , weightGrad ,
outGrad , filterDesc_ , weightGrad ,
convDesc_ , workSpace_ ,
convDesc_ , workSpace_ , workSpaceInBytes_ ,
workSpaceInBytes_, bwdFilterAlgo_) ;
bwdFilterAlgo_) ;
}
}
MatrixPtr preGrad = ins_ [ 0 ] - > grad ;
MatrixPtr preGrad = ins_ [ 0 ] - > grad ;
if ( NULL ! = preGrad ) {
if ( NULL ! = preGrad ) {
real * inputGrad = preGrad - > getData ( ) + inputOffset_ * batchId ;
real * inputGrad = preGrad - > getData ( ) + inputOffset_ * batchId ;
real * wgtData = ins_ [ 1 ] - > value - > getData ( ) + weightOffset_ * batchId ;
real * wgtData = ins_ [ 1 ] - > value - > getData ( ) + weightOffset_ * batchId ;
hl_convolution_backward_data ( inputDesc_ , inputGrad , outputDesc_ ,
hl_convolution_backward_data (
outGrad , filterDesc_ , wgtData ,
inputDesc_ , inputGrad , outputDesc_ , outGrad , filterDesc_ , wgtData ,
convDesc_ , workSpace_ ,
convDesc_ , workSpace_ , workSpaceInBytes_ , bwdDataAlgo_ ) ;
workSpaceInBytes_ , bwdDataAlgo_ ) ;
}
}
}
}
}
}