|
|
|
@ -1022,6 +1022,15 @@ void hl_batch_norm_forward_inference(hl_tensor_descriptor inputDesc,
|
|
|
|
|
real alpha = 1.0f;
|
|
|
|
|
real beta = 1.0f;
|
|
|
|
|
cudnnBatchNormMode_t mode = CUDNN_BATCHNORM_SPATIAL;
|
|
|
|
|
|
|
|
|
|
int batch_size = ((cudnn_tensor_descriptor)inputDesc)->batch_size;
|
|
|
|
|
if (batch_size > 1024 && g_cudnn_lib_version < 6000) {
|
|
|
|
|
LOG(INFO) << " To process current batch data with size " << batch_size
|
|
|
|
|
<< " (>1024), cudnnBatchNorm requires cuDNN version >= 6000."
|
|
|
|
|
<< " If there is an error complaining CUDNN_STATUS_NOT_SUPPORTED,"
|
|
|
|
|
<< " just recompile PaddlePaddle with cuDNN >= 6000, replacing"
|
|
|
|
|
<< " current version " << g_cudnn_lib_version;
|
|
|
|
|
}
|
|
|
|
|
CHECK_CUDNN(
|
|
|
|
|
dynload::cudnnBatchNormalizationForwardInference(t_resource.cudnn_handle,
|
|
|
|
|
mode,
|
|
|
|
|