|
|
|
@ -87,9 +87,6 @@ void ConvBaseProjection::initCudnn() {
|
|
|
|
|
bwdDataLimitBytes_ = 0;
|
|
|
|
|
bwdFilterLimitBytes_ = 0;
|
|
|
|
|
workSpaceInBytes_ = 0;
|
|
|
|
|
|
|
|
|
|
batchNum_ = 0;
|
|
|
|
|
isSelectAlgo_ = false;
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
void ConvBaseProjection::reshapeTensorDesc(int batchSize) {
|
|
|
|
@ -142,32 +139,25 @@ void ConvBaseProjection::reshape(int batchSize) {
|
|
|
|
|
CHECK_EQ(width, out_->value->getWidth());
|
|
|
|
|
CHECK_EQ(calInputSize(), in_->value->getWidth());
|
|
|
|
|
|
|
|
|
|
isSelectAlgo_ = (batchSize == batchNum_);
|
|
|
|
|
batchNum_ = batchSize;
|
|
|
|
|
|
|
|
|
|
if (!isSelectAlgo_) {
|
|
|
|
|
reshapeTensorDesc(batchSize);
|
|
|
|
|
hl_conv_workspace(imageDesc_,
|
|
|
|
|
outputDesc_,
|
|
|
|
|
filterDesc_,
|
|
|
|
|
convDesc_,
|
|
|
|
|
&fwdAlgo_,
|
|
|
|
|
&fwdLimitBytes_,
|
|
|
|
|
&bwdDataAlgo_,
|
|
|
|
|
&bwdDataLimitBytes_,
|
|
|
|
|
&bwdFilterAlgo_,
|
|
|
|
|
&bwdFilterLimitBytes_);
|
|
|
|
|
|
|
|
|
|
size_t maxWorkSpace = 0;
|
|
|
|
|
maxWorkSpace = std::max(fwdLimitBytes_, bwdDataLimitBytes_);
|
|
|
|
|
maxWorkSpace = std::max(maxWorkSpace, bwdFilterLimitBytes_);
|
|
|
|
|
workSpaceInBytes_ = maxWorkSpace;
|
|
|
|
|
|
|
|
|
|
VLOG(3) << getName() << " Fwd / BwdData / BwdFilter algo: " << fwdAlgo_
|
|
|
|
|
<< " / " << bwdDataAlgo_ << " / " << bwdFilterAlgo_;
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
isSelectAlgo_ = true;
|
|
|
|
|
reshapeTensorDesc(batchSize);
|
|
|
|
|
hl_conv_workspace(imageDesc_,
|
|
|
|
|
outputDesc_,
|
|
|
|
|
filterDesc_,
|
|
|
|
|
convDesc_,
|
|
|
|
|
&fwdAlgo_,
|
|
|
|
|
&fwdLimitBytes_,
|
|
|
|
|
&bwdDataAlgo_,
|
|
|
|
|
&bwdDataLimitBytes_,
|
|
|
|
|
&bwdFilterAlgo_,
|
|
|
|
|
&bwdFilterLimitBytes_);
|
|
|
|
|
|
|
|
|
|
size_t maxWorkSpace = 0;
|
|
|
|
|
maxWorkSpace = std::max(fwdLimitBytes_, bwdDataLimitBytes_);
|
|
|
|
|
maxWorkSpace = std::max(maxWorkSpace, bwdFilterLimitBytes_);
|
|
|
|
|
workSpaceInBytes_ = maxWorkSpace;
|
|
|
|
|
|
|
|
|
|
VLOG(3) << getName() << " Fwd / BwdData / BwdFilter algo: " << fwdAlgo_
|
|
|
|
|
<< " / " << bwdDataAlgo_ << " / " << bwdFilterAlgo_;
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
void *ConvBaseProjection::getSpaceBytes(size_t size) {
|
|
|
|
|