| 
						
						
							
								
							
						
						
					 | 
				
				 | 
				 | 
				
					@ -87,9 +87,6 @@ void ConvBaseProjection::initCudnn() {
 | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				 | 
				
					  bwdDataLimitBytes_ = 0;
 | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				 | 
				
					  bwdFilterLimitBytes_ = 0;
 | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				 | 
				
					  workSpaceInBytes_ = 0;
 | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				 | 
				
					
 | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				 | 
				
					  batchNum_ = 0;
 | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				 | 
				
					  isSelectAlgo_ = false;
 | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				 | 
				
					}
 | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				 | 
				
					
 | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				 | 
				
					void ConvBaseProjection::reshapeTensorDesc(int batchSize) {
 | 
				
			
			
		
	
	
		
			
				
					| 
						
							
								
							
						
						
							
								
							
						
						
					 | 
				
				 | 
				 | 
				
					@ -142,32 +139,25 @@ void ConvBaseProjection::reshape(int batchSize) {
 | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				 | 
				
					  CHECK_EQ(width, out_->value->getWidth());
 | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				 | 
				
					  CHECK_EQ(calInputSize(), in_->value->getWidth());
 | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				 | 
				
					
 | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				 | 
				
					  isSelectAlgo_ = (batchSize == batchNum_);
 | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				 | 
				
					  batchNum_ = batchSize;
 | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				 | 
				
					
 | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				 | 
				
					  if (!isSelectAlgo_) {
 | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				 | 
				
					    reshapeTensorDesc(batchSize);
 | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				 | 
				
					    hl_conv_workspace(imageDesc_,
 | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				 | 
				
					                      outputDesc_,
 | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				 | 
				
					                      filterDesc_,
 | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				 | 
				
					                      convDesc_,
 | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				 | 
				
					                      &fwdAlgo_,
 | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				 | 
				
					                      &fwdLimitBytes_,
 | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				 | 
				
					                      &bwdDataAlgo_,
 | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				 | 
				
					                      &bwdDataLimitBytes_,
 | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				 | 
				
					                      &bwdFilterAlgo_,
 | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				 | 
				
					                      &bwdFilterLimitBytes_);
 | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				 | 
				
					
 | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				 | 
				
					    size_t maxWorkSpace = 0;
 | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				 | 
				
					    maxWorkSpace = std::max(fwdLimitBytes_, bwdDataLimitBytes_);
 | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				 | 
				
					    maxWorkSpace = std::max(maxWorkSpace, bwdFilterLimitBytes_);
 | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				 | 
				
					    workSpaceInBytes_ = maxWorkSpace;
 | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				 | 
				
					
 | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				 | 
				
					    VLOG(3) << getName() << " Fwd / BwdData / BwdFilter algo: " << fwdAlgo_
 | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				 | 
				
					            << " / " << bwdDataAlgo_ << " / " << bwdFilterAlgo_;
 | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				 | 
				
					  }
 | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				 | 
				
					
 | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				 | 
				
					  isSelectAlgo_ = true;
 | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				 | 
				
					  reshapeTensorDesc(batchSize);
 | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				 | 
				
					  hl_conv_workspace(imageDesc_,
 | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				 | 
				
					                    outputDesc_,
 | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				 | 
				
					                    filterDesc_,
 | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				 | 
				
					                    convDesc_,
 | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				 | 
				
					                    &fwdAlgo_,
 | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				 | 
				
					                    &fwdLimitBytes_,
 | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				 | 
				
					                    &bwdDataAlgo_,
 | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				 | 
				
					                    &bwdDataLimitBytes_,
 | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				 | 
				
					                    &bwdFilterAlgo_,
 | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				 | 
				
					                    &bwdFilterLimitBytes_);
 | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				 | 
				
					
 | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				 | 
				
					  size_t maxWorkSpace = 0;
 | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				 | 
				
					  maxWorkSpace = std::max(fwdLimitBytes_, bwdDataLimitBytes_);
 | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				 | 
				
					  maxWorkSpace = std::max(maxWorkSpace, bwdFilterLimitBytes_);
 | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				 | 
				
					  workSpaceInBytes_ = maxWorkSpace;
 | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				 | 
				
					
 | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				 | 
				
					  VLOG(3) << getName() << " Fwd / BwdData / BwdFilter algo: " << fwdAlgo_
 | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				 | 
				
					          << " / " << bwdDataAlgo_ << " / " << bwdFilterAlgo_;
 | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				 | 
				
					}
 | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				 | 
				
					
 | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				 | 
				
					void *ConvBaseProjection::getSpaceBytes(size_t size) {
 | 
				
			
			
		
	
	
		
			
				
					| 
						
							
								
							
						
						
						
					 | 
				
				 | 
				 | 
				
					
 
 |