diff --git a/paddle/gserver/layers/CudnnBatchNormLayer.cpp b/paddle/gserver/layers/CudnnBatchNormLayer.cpp index 44ba2c4b7d..49a9540c0b 100644 --- a/paddle/gserver/layers/CudnnBatchNormLayer.cpp +++ b/paddle/gserver/layers/CudnnBatchNormLayer.cpp @@ -37,7 +37,7 @@ bool CudnnBatchNormLayer::init(const LayerMap& layerMap, } void CudnnBatchNormLayer::reshape(int batchSize) { - hl_tensor_reshape(ioDesc_, batchSize, channels_, imageH_, imageW_); + hl_tensor_reshape(ioDesc_, batchSize, channels_, imageH_ * imageD_, imageW_); } void CudnnBatchNormLayer::forward(PassType passType) { @@ -104,7 +104,7 @@ void CudnnBatchNormLayer::forward(PassType passType) { EPS, batchSize, channels_, - imageH_, + imageH_ * imageD_, imageW_); } }