fix cudnnBatchNorm for 3D data

enforce_failed
chengduoZH 8 years ago
parent 4cb2966d7b
commit 5500153a6d

@ -37,7 +37,7 @@ bool CudnnBatchNormLayer::init(const LayerMap& layerMap,
} }
void CudnnBatchNormLayer::reshape(int batchSize) { void CudnnBatchNormLayer::reshape(int batchSize) {
hl_tensor_reshape(ioDesc_, batchSize, channels_, imageH_, imageW_); hl_tensor_reshape(ioDesc_, batchSize, channels_, imageH_ * imageD_, imageW_);
} }
void CudnnBatchNormLayer::forward(PassType passType) { void CudnnBatchNormLayer::forward(PassType passType) {
@ -104,7 +104,7 @@ void CudnnBatchNormLayer::forward(PassType passType) {
EPS, EPS,
batchSize, batchSize,
channels_, channels_,
imageH_, imageH_ * imageD_,
imageW_); imageW_);
} }
} }

Loading…
Cancel
Save