|
|
@ -114,27 +114,30 @@ void CudnnBatchNormLayer::backward(const UpdateCallback& callback) {
|
|
|
|
} else {
|
|
|
|
} else {
|
|
|
|
create(tmpBiasGrad_, 1, channels_, &betaGrad);
|
|
|
|
create(tmpBiasGrad_, 1, channels_, &betaGrad);
|
|
|
|
}
|
|
|
|
}
|
|
|
|
#if CUDNN_VERSION < 5000
|
|
|
|
|
|
|
|
// because of the different api of cudnn v4 and v5.
|
|
|
|
// because of the different api of cudnn v4 and v5.
|
|
|
|
if (weight_->getWGrad()) {
|
|
|
|
if (hl_get_cudnn_lib_version() < 5000) {
|
|
|
|
create(tmpWGrad_, 1, channels_, &gammaGrad);
|
|
|
|
if (weight_->getWGrad()) {
|
|
|
|
}
|
|
|
|
create(tmpWGrad_, 1, channels_, &gammaGrad);
|
|
|
|
if (biases_ && biases_->getWGrad()) {
|
|
|
|
}
|
|
|
|
create(tmpBiasGrad_, 1, channels_, &betaGrad);
|
|
|
|
if (biases_ && biases_->getWGrad()) {
|
|
|
|
|
|
|
|
create(tmpBiasGrad_, 1, channels_, &betaGrad);
|
|
|
|
|
|
|
|
}
|
|
|
|
}
|
|
|
|
}
|
|
|
|
#endif
|
|
|
|
|
|
|
|
hl_batch_norm_backward(ioDesc_, input, ioDesc_, outGrad,
|
|
|
|
hl_batch_norm_backward(ioDesc_, input, ioDesc_, outGrad,
|
|
|
|
ioDesc_, inGrad, bnParamDesc_,
|
|
|
|
ioDesc_, inGrad, bnParamDesc_,
|
|
|
|
gamma, gammaGrad, betaGrad,
|
|
|
|
gamma, gammaGrad, betaGrad,
|
|
|
|
EPS, savedMean, savedInvVar);
|
|
|
|
EPS, savedMean, savedInvVar);
|
|
|
|
|
|
|
|
|
|
|
|
#if CUDNN_VERSION < 5000
|
|
|
|
|
|
|
|
// because of the different api of cudnn v4 and v5.
|
|
|
|
// because of the different api of cudnn v4 and v5.
|
|
|
|
if (weight_->getWGrad() && biases_->getWGrad()) {
|
|
|
|
if (hl_get_cudnn_lib_version() < 5000) {
|
|
|
|
weight_->getWGrad()->add(*tmpWGrad_);
|
|
|
|
if (weight_->getWGrad() && biases_->getWGrad()) {
|
|
|
|
biases_->getWGrad()->add(*tmpBiasGrad_);
|
|
|
|
weight_->getWGrad()->add(*tmpWGrad_);
|
|
|
|
|
|
|
|
biases_->getWGrad()->add(*tmpBiasGrad_);
|
|
|
|
|
|
|
|
}
|
|
|
|
}
|
|
|
|
}
|
|
|
|
#endif
|
|
|
|
|
|
|
|
{
|
|
|
|
{
|
|
|
|
REGISTER_TIMER_INFO("WeightUpdate", getName().c_str());
|
|
|
|
REGISTER_TIMER_INFO("WeightUpdate", getName().c_str());
|
|
|
|
biases_->getParameterPtr()->incUpdate(callback);
|
|
|
|
biases_->getParameterPtr()->incUpdate(callback);
|
|
|
|