|
|
@ -37,7 +37,7 @@ bool CudnnBatchNormLayer::init(const LayerMap& layerMap,
|
|
|
|
}
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
|
|
void CudnnBatchNormLayer::reshape(int batchSize) {
|
|
|
|
void CudnnBatchNormLayer::reshape(int batchSize) {
|
|
|
|
hl_tensor_reshape(ioDesc_, batchSize, channels_, imageH_, imageW_);
|
|
|
|
hl_tensor_reshape(ioDesc_, batchSize, channels_, imageH_ * imageD_, imageW_);
|
|
|
|
}
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
|
|
void CudnnBatchNormLayer::forward(PassType passType) {
|
|
|
|
void CudnnBatchNormLayer::forward(PassType passType) {
|
|
|
@ -104,7 +104,7 @@ void CudnnBatchNormLayer::forward(PassType passType) {
|
|
|
|
EPS,
|
|
|
|
EPS,
|
|
|
|
batchSize,
|
|
|
|
batchSize,
|
|
|
|
channels_,
|
|
|
|
channels_,
|
|
|
|
imageH_,
|
|
|
|
imageH_ * imageD_,
|
|
|
|
imageW_);
|
|
|
|
imageW_);
|
|
|
|
}
|
|
|
|
}
|
|
|
|
}
|
|
|
|
}
|
|
|
|