gangliao-patch-1
hedaoyuan 8 years ago
parent c70d3e1af8
commit 3408b4b2f4

@ -58,7 +58,7 @@ public:
CHECK_EQ(outputs[0].shape().ndims(), (size_t)4); CHECK_EQ(outputs[0].shape().ndims(), (size_t)4);
CHECK(inputs[0].shape()[0] == outputs[0].shape()[0]); CHECK(inputs[0].shape()[0] == outputs[0].shape()[0]);
CHECK(inputs[0].shape()[1] == inputs[1].shape()[1]); CHECK(inputs[0].shape()[1] / groups_ == inputs[1].shape()[1]);
CHECK(outputs[0].shape()[1] == inputs[1].shape()[0]); CHECK(outputs[0].shape()[1] == inputs[1].shape()[0]);
} }

@ -83,9 +83,11 @@ TEST(Convolution, GEMM) {
"GemmConv-CPU"); "GemmConv-CPU");
} }
#ifndef PADDLE_ONLY_CPU
TEST(Convolution, GEMM2) { TEST(Convolution, GEMM2) {
ConvolutionTest<DEVICE_TYPE_CPU, DEVICE_TYPE_GPU> test("GemmConv-CPU", ConvolutionTest<DEVICE_TYPE_CPU, DEVICE_TYPE_GPU> test("GemmConv-CPU",
"GemmConv-GPU"); "GemmConv-GPU");
} }
#endif
} // namespace paddle } // namespace paddle

@ -101,8 +101,6 @@ public:
size_t outputHeight = outputs[0].shape()[2]; size_t outputHeight = outputs[0].shape()[2];
size_t outputWidth = outputs[0].shape()[3]; size_t outputWidth = outputs[0].shape()[3];
CHECK_EQ(inputChannels / groups_, inputs[1].shape()[1]);
real* inputData = inputs[0].data<real>(); real* inputData = inputs[0].data<real>();
real* filterData = inputs[1].data<real>(); real* filterData = inputs[1].data<real>();
real* outputData = outputs[0].data<real>(); real* outputData = outputs[0].data<real>();
@ -134,9 +132,9 @@ public:
outputWidth, outputWidth,
colData); colData);
int M = outputChannels; int M = outputChannels / groups_;
int N = outputHeight * outputWidth; int N = outputHeight * outputWidth;
int K = inputChannels * filterHeight * filterWidth; int K = inputChannels / groups_ * filterHeight * filterWidth;
gemm(M, gemm(M,
N, N,
K, K,

Loading…
Cancel
Save