|
|
|
@ -21,8 +21,6 @@ namespace paddle {
|
|
|
|
|
|
|
|
|
|
REGISTER_LAYER(cudnn_batch_norm, CudnnBatchNormLayer);
|
|
|
|
|
|
|
|
|
|
const double CudnnBatchNormLayer::MIN_EPS = 1E-5;
|
|
|
|
|
|
|
|
|
|
bool CudnnBatchNormLayer::init(const LayerMap& layerMap,
|
|
|
|
|
const ParameterMap& parameterMap) {
|
|
|
|
|
/* Initialize the basic parent class */
|
|
|
|
@ -61,14 +59,8 @@ void CudnnBatchNormLayer::forward(PassType passType) {
|
|
|
|
|
real* movingMean = movingMean_->getW()->getData();
|
|
|
|
|
real* movingVar = movingVar_->getW()->getData();
|
|
|
|
|
|
|
|
|
|
/**
|
|
|
|
|
* If epsilon_ equals to 1e-5 and eps_ is assigned the value of
|
|
|
|
|
* static_cast<double>(epsilon_), The CUDNN_STATUS_BAD_PARAM error
|
|
|
|
|
* will occur due to eps_ value is less than
|
|
|
|
|
* CUDNN_BN_MIN_EPSILON.
|
|
|
|
|
* The following code is to ensure that the eps_ meets requirement.
|
|
|
|
|
*/
|
|
|
|
|
eps_ = std::max(MIN_EPS, static_cast<double>(epsilon_));
|
|
|
|
|
// cuDNN does not allow an epsilon value less than CUDNN_BN_MIN_EPSILON.
|
|
|
|
|
eps_ = std::max(CUDNN_BN_MIN_EPSILON, static_cast<double>(epsilon_));
|
|
|
|
|
|
|
|
|
|
if (!useGlobalStats_) {
|
|
|
|
|
REGISTER_TIMER_INFO("CudnnBatchFwTimer", getName().c_str());
|
|
|
|
@ -137,14 +129,8 @@ void CudnnBatchNormLayer::backward(const UpdateCallback& callback) {
|
|
|
|
|
real* savedMean = savedMean_->getData();
|
|
|
|
|
real* savedInvVar = savedInvVar_->getData();
|
|
|
|
|
|
|
|
|
|
/**
|
|
|
|
|
* If epsilon_ equals to 1e-5 and eps_ is assigned the value of
|
|
|
|
|
* static_cast<double>(epsilon_), The CUDNN_STATUS_BAD_PARAM error
|
|
|
|
|
* will occur due to eps_ value is less than
|
|
|
|
|
* CUDNN_BN_MIN_EPSILON.
|
|
|
|
|
* The following code is to ensure that the eps_ meets requirement.
|
|
|
|
|
*/
|
|
|
|
|
eps_ = std::max(MIN_EPS, static_cast<double>(epsilon_));
|
|
|
|
|
// cuDNN does not allow an epsilon value less than CUDNN_BN_MIN_EPSILON.
|
|
|
|
|
eps_ = std::max(CUDNN_BN_MIN_EPSILON, static_cast<double>(epsilon_));
|
|
|
|
|
|
|
|
|
|
auto create = [](MatrixPtr& m, size_t h, size_t w, real** p) {
|
|
|
|
|
Matrix::resizeOrCreate(m, h, w, false, true);
|
|
|
|
|