gangliao-patch-1
hedaoyuan 8 years ago
parent c70d3e1af8
commit 3408b4b2f4

@ -58,7 +58,7 @@ public:
CHECK_EQ(outputs[0].shape().ndims(), (size_t)4);
CHECK(inputs[0].shape()[0] == outputs[0].shape()[0]);
CHECK(inputs[0].shape()[1] == inputs[1].shape()[1]);
CHECK(inputs[0].shape()[1] / groups_ == inputs[1].shape()[1]);
CHECK(outputs[0].shape()[1] == inputs[1].shape()[0]);
}

@ -83,9 +83,11 @@ TEST(Convolution, GEMM) {
"GemmConv-CPU");
}
#ifndef PADDLE_ONLY_CPU
TEST(Convolution, GEMM2) {
ConvolutionTest<DEVICE_TYPE_CPU, DEVICE_TYPE_GPU> test("GemmConv-CPU",
"GemmConv-GPU");
}
#endif
} // namespace paddle

@ -101,8 +101,6 @@ public:
size_t outputHeight = outputs[0].shape()[2];
size_t outputWidth = outputs[0].shape()[3];
CHECK_EQ(inputChannels / groups_, inputs[1].shape()[1]);
real* inputData = inputs[0].data<real>();
real* filterData = inputs[1].data<real>();
real* outputData = outputs[0].data<real>();
@ -134,9 +132,9 @@ public:
outputWidth,
colData);
int M = outputChannels;
int M = outputChannels / groups_;
int N = outputHeight * outputWidth;
int K = inputChannels * filterHeight * filterWidth;
int K = inputChannels / groups_ * filterHeight * filterWidth;
gemm(M,
N,
K,

Loading…
Cancel
Save