From 9ab8faaf76057c21be368adb0e23999b3acc5028 Mon Sep 17 00:00:00 2001 From: xzl Date: Thu, 3 May 2018 13:05:01 +0800 Subject: [PATCH] fix pool with mask layer bug --- paddle/math/Matrix.cpp | 28 +++++++++++----------------- 1 file changed, 11 insertions(+), 17 deletions(-) diff --git a/paddle/math/Matrix.cpp b/paddle/math/Matrix.cpp index 0e84cb3739..bcd6dfe1fd 100644 --- a/paddle/math/Matrix.cpp +++ b/paddle/math/Matrix.cpp @@ -2157,26 +2157,20 @@ void CpuMatrix::maxPoolForward(Matrix& inputMat, int wend = wstart + sizeX; wstart = wstart < 0 ? 0 : wstart; wend = wend < (int)imgSizeW ? wend : (int)imgSizeW; - if (maskData == NULL) { - real tmp = -(real)FLT_MAX; - for (int h = hstart; h < hend; ++h) { - for (int w = wstart; w < wend; ++w) { - tmp = tmp < inputData[h * imgSizeW + w] - ? inputData[h * imgSizeW + w] - : tmp; - } - } - outData[ph * outputW + pw] = tmp; - } else { - for (int h = hstart; h < hend; ++h) { - for (int w = wstart; w < wend; ++w) { - if (outData[ph * outputW + pw] < inputData[h * imgSizeW + w]) { - outData[ph * outputW + pw] = inputData[h * imgSizeW + w]; - maskData[ph * outputW + pw] = h * imgSizeW + w; - } + + real maxval = -(real)FLT_MAX; + int max_index = -1; + for (int h = hstart; h < hend; ++h) { + for (int w = wstart; w < wend; ++w) { + if (maxval < inputData[h * imgSizeW + w]) { + maxval = inputData[h * imgSizeW + w]; + max_index = h * imgSizeW + w; } } } + + outData[ph * outputW + pw] = maxval; + if (maskData != NULL) maskData[ph * outputW + pw] = max_index; } } // compute offset