diff --git a/paddle/function/Im2ColTest.cpp b/paddle/function/Im2ColTest.cpp index d7dbf087c5..acc88a553a 100644 --- a/paddle/function/Im2ColTest.cpp +++ b/paddle/function/Im2ColTest.cpp @@ -20,7 +20,8 @@ limitations under the License. */ namespace paddle { -TEST(Im2ColFunctor, real) { +template +void TestIm2ColFunctor() { for (size_t channels : {1, 5, 32}) { for (size_t inputHeight : {5, 33, 100}) { for (size_t inputWidth : {5, 32, 96}) { @@ -50,16 +51,18 @@ TEST(Im2ColFunctor, real) { filterHeight, filterWidth}); - VectorPtr input = Vector::create(imShape.getElements(), false); size_t height = channels * filterHeight * filterWidth; size_t width = outputHeight * outputWidth; + VectorPtr input1 = Vector::create(imShape.getElements(), false); + VectorPtr input2 = Vector::create(imShape.getElements(), false); MatrixPtr output1 = Matrix::create(height, width, false, false); MatrixPtr output2 = Matrix::create(width, height, false, false); - Im2ColFunctor im2col1; - Im2ColFunctor im2col2; + input1->uniform(0.001, 1); + input2->copyFrom(*input1); - input->uniform(0.001, 1); - im2col1(input->getData(), + Im2ColFunctor im2Col1; + Im2ColFunctor im2Col2; + im2Col1(input1->getData(), imShape, output1->getData(), colShape1, @@ -67,7 +70,7 @@ TEST(Im2ColFunctor, real) { stride, padding, padding); - im2col2(input->getData(), + im2Col2(input2->getData(), imShape, output2->getData(), colShape2, @@ -76,27 +79,32 @@ TEST(Im2ColFunctor, real) { padding, padding); + // The transposition of the result of ColFormat == kCFO + // is equal to the result of ColFormat == kOCF. MatrixPtr test; output2->transpose(test, true); autotest::TensorCheckErr(*output1, *test); - } - } - } - } - } - } - } -} -#if 0 -TEST(Col2ImFunctor, real) { - for (size_t channels : {1, 5, 32}) { - for (size_t inputHeight : {5, 33, 100}) { - for (size_t inputWidth : {5, 32, 96}) { - for (size_t filterHeight : {1, 5}) { - for (size_t filterWidth : {3, 7}) { - for (size_t stride : {1, 2}) { - for (size_t padding : {0, 1}) { + Col2ImFunctor col2Im1; + Col2ImFunctor col2Im2; + col2Im1(input1->getData(), + imShape, + output1->getData(), + colShape1, + stride, + stride, + padding, + padding); + col2Im2(input2->getData(), + imShape, + output2->getData(), + colShape2, + stride, + stride, + padding, + padding); + + autotest::TensorCheckErr(*input1, *input2); } } } @@ -105,6 +113,13 @@ TEST(Col2ImFunctor, real) { } } } + +TEST(Im2ColFunctor, CPU) { TestIm2ColFunctor(); } + +#ifndef PADDLE_ONLY_CPU + +TEST(Im2ColFunctor, GPU) { TestIm2ColFunctor(); } + #endif } // namespace paddle