Bug fix & add test of GemmConvGradFilter.

gangliao-patch-1
hedaoyuan 8 years ago
parent 6a93f0f37a
commit 90326198e9

@ -19,11 +19,18 @@ limitations under the License. */
namespace paddle {
enum TestType {
FORWARD_TEST = 0,
BACKWARD_INPUT_TEST = 1,
BACKWARD_FILTER_TEST = 2,
};
template <DeviceType DType1, DeviceType DType2>
class ConvolutionTest {
public:
ConvolutionTest(const std::string& conv1,
const std::string& conv2,
TestType type,
std::string algo = "auto") {
for (size_t batchSize : {1, 32}) {
for (size_t inputSize : {7, 14, 54}) {
@ -58,16 +65,31 @@ public:
.set("groups", (size_t)1)
.set("algo", algo));
TensorShape shape0{
TensorShape input{
batchSize, inputChannels, inputSize, inputSize};
TensorShape shape1{
TensorShape filter{
outputChannels, inputChannels, filterSize, filterSize};
TensorShape shape2{
TensorShape output{
batchSize, outputChannels, outputSize, outputSize};
test.addInputs(BufferArg(VALUE_TYPE_FLOAT, shape0));
test.addInputs(BufferArg(VALUE_TYPE_FLOAT, shape1));
test.addOutputs(BufferArg(VALUE_TYPE_FLOAT, shape2));
test.run();
if (type == FORWARD_TEST) {
test.addInputs(BufferArg(VALUE_TYPE_FLOAT, input));
test.addInputs(BufferArg(VALUE_TYPE_FLOAT, filter));
test.addOutputs(BufferArg(VALUE_TYPE_FLOAT, output));
test.run();
} else if (type == BACKWARD_INPUT_TEST) {
#if 0
test.addInputs(BufferArg(VALUE_TYPE_FLOAT, output));
test.addInputs(BufferArg(VALUE_TYPE_FLOAT, filter));
test.addOutputs(BufferArg(VALUE_TYPE_FLOAT, input));
test.run();
#endif
} else if (type == BACKWARD_FILTER_TEST) {
test.addInputs(BufferArg(VALUE_TYPE_FLOAT, output));
test.addInputs(BufferArg(VALUE_TYPE_FLOAT, input));
test.addOutputs(BufferArg(VALUE_TYPE_FLOAT, filter));
test.run();
}
}
}
}
@ -78,15 +100,20 @@ public:
}
};
TEST(Convolution, GEMM) {
ConvolutionTest<DEVICE_TYPE_CPU, DEVICE_TYPE_CPU> test("NaiveConv-CPU",
"GemmConv-CPU");
TEST(Forward, GEMM) {
ConvolutionTest<DEVICE_TYPE_CPU, DEVICE_TYPE_CPU> test(
"NaiveConv-CPU", "GemmConv-CPU", FORWARD_TEST);
}
#ifndef PADDLE_ONLY_CPU
TEST(Convolution, GEMM2) {
ConvolutionTest<DEVICE_TYPE_CPU, DEVICE_TYPE_GPU> test("GemmConv-CPU",
"GemmConv-GPU");
TEST(Forward, GEMM2) {
ConvolutionTest<DEVICE_TYPE_CPU, DEVICE_TYPE_GPU> test(
"GemmConv-CPU", "GemmConv-GPU", FORWARD_TEST);
}
TEST(BackwardFilter, GEMM) {
ConvolutionTest<DEVICE_TYPE_CPU, DEVICE_TYPE_GPU> test(
"GemmConvGradFilter-CPU", "GemmConvGradFilter-GPU", BACKWARD_FILTER_TEST);
}
#endif

@ -255,9 +255,9 @@ public:
filterGrad + g * filterOffset,
N);
}
inputData += inputChannels * inputHeight * inputWidth;
outputGrad += outputChannels * outputHeight * outputWidth;
}
inputData += inputChannels * inputHeight * inputWidth;
outputGrad += outputChannels * outputHeight * outputWidth;
}
};

Loading…
Cancel
Save