From c9d8cb4e90597409257da63c3d788ad067382772 Mon Sep 17 00:00:00 2001
From: hedaoyuan
Date: Mon, 11 Sep 2017 21:25:30 +0800
Subject: [PATCH 01/19] Convolution op and forward calculation.

---
 paddle/operators/conv_op.cc                   |  96 ++++++++++++++++
 paddle/operators/conv_op.cu                   |  22 ++++
 paddle/operators/gemm_conv_op.h               | 103 ++++++++++++++++++
 paddle/pybind/pybind.cc                       |   1 +
 .../paddle/v2/framework/tests/CMakeLists.txt  |   1 +
 .../v2/framework/tests/test_conv2d_op.py      |  62 +++++++++++
 6 files changed, 285 insertions(+)
 create mode 100644 paddle/operators/conv_op.cc
 create mode 100644 paddle/operators/conv_op.cu
 create mode 100644 paddle/operators/gemm_conv_op.h
 create mode 100644 python/paddle/v2/framework/tests/test_conv2d_op.py

diff --git a/paddle/operators/conv_op.cc b/paddle/operators/conv_op.cc
new file mode 100644
index 0000000000..873366394d
--- /dev/null
+++ b/paddle/operators/conv_op.cc
@@ -0,0 +1,96 @@
+/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve.
+
+   Licensed under the Apache License, Version 2.0 (the "License");
+   you may not use this file except in compliance with the License.
+   You may obtain a copy of the License at
+
+   http://www.apache.org/licenses/LICENSE-2.0
+
+   Unless required by applicable law or agreed to in writing, software
+   distributed under the License is distributed on an "AS IS" BASIS,
+   WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+   See the License for the specific language governing permissions and
+   limitations under the License. */
+
+#include "paddle/operators/gemm_conv_op.h"
+
+namespace paddle {
+namespace operators {
+
+int outputSize(int input_size, int filter_size, int padding, int stride) {
+  int output_size = (input_size - filter_size + 2 * padding) / stride + 1;
+  return output_size;
+}
+
+class Conv2DOp : public framework::OperatorWithKernel {
+ public:
+  using framework::OperatorWithKernel::OperatorWithKernel;
+
+ protected:
+  void InferShape(const framework::InferShapeContext &ctx) const override {
+    auto *in = ctx.Input<Tensor>("Input");
+    auto *filter = ctx.Input<Tensor>("Filter");
+    auto *out = ctx.Output<Tensor>("Output");
+    PADDLE_ENFORCE_EQ(in->dims().size(), 4, "Conv2DOp intput should be 4-D.");
+    PADDLE_ENFORCE_EQ(filter->dims().size(), 4,
+                      "Conv2DOp filter should be 4-D.");
+
+    std::vector<int> strides = Attr<std::vector<int>>("strides");
+    std::vector<int> paddings = Attr<std::vector<int>>("paddings");
+    auto output_height =
+        outputSize(in->dims()[2], filter->dims()[2], paddings[0], strides[0]);
+    auto output_width =
+        outputSize(in->dims()[3], filter->dims()[3], paddings[1], strides[1]);
+    out->Resize(
+        {in->dims()[0], filter->dims()[0], output_height, output_width});
+  }
+};
+
+class Conv2DOppMaker : public framework::OpProtoAndCheckerMaker {
+ public:
+  Conv2DOppMaker(framework::OpProto *proto,
+                 framework::OpAttrChecker *op_checker)
+      : OpProtoAndCheckerMaker(proto, op_checker) {
+    AddInput(
+        "Input",
+        "The input tensor of convolution operator. "
+        "The format of input tensor is NCHW. Where N is batch size, C is the "
+        "number of channels, H and W is the height and width of image.");
+    AddInput(
+        "Filter",
+        "The filter tensor of convolution operator."
+        "The format of the filter tensor is MCHW, where M is the number of "
+        "output "
+        "image channels, C is the number of input image channels, H and W is "
+        " height and width of filter.");
+    AddOutput("Output",
+              "The output tensor of convolution operator."
+              "The format of output tensor is also NCHW.");
+    AddComment(R"DOC(
+The convolution operation calculates the output based on
+the input, filter and strides, paddings parameters.
+)DOC");
+    AddAttr<std::vector<int>>("strides", "strides of convolution operator.");
+    AddAttr<std::vector<int>>("paddings", "paddings of convolution operator.");
+  }
+};
+
+class Conv2DOpGrad : public framework::OperatorWithKernel {
+ public:
+  using framework::OperatorWithKernel::OperatorWithKernel;
+
+ protected:
+  void InferShape(const framework::InferShapeContext &ctx) const override {}
+};
+
+}  // namespace operators
+}  // namespace paddle
+
+namespace ops = paddle::operators;
+REGISTER_OP(conv2d, ops::Conv2DOp, ops::Conv2DOppMaker, conv2d_grad,
+            ops::Conv2DOpGrad);
+
+REGISTER_OP_CPU_KERNEL(conv2d,
+                       ops::GemmConvKernel<paddle::platform::CPUPlace, float>);
+REGISTER_OP_CPU_KERNEL(
+    conv2d_grad, ops::GemmConvGradKernel<paddle::platform::CPUPlace, float>);
diff --git a/paddle/operators/conv_op.cu b/paddle/operators/conv_op.cu
new file mode 100644
index 0000000000..a15adecda4
--- /dev/null
+++ b/paddle/operators/conv_op.cu
@@ -0,0 +1,22 @@
+/* Copyright (c) 2016 PaddlePaddle Authors All Rights Reserve.
+
+   Licensed under the Apache License, Version 2.0 (the "License");
+   you may not use this file except in compliance with the License.
+   You may obtain a copy of the License at
+
+   http://www.apache.org/licenses/LICENSE-2.0
+
+   Unless required by applicable law or agreed to in writing, software
+   distributed under the License is distributed on an "AS IS" BASIS,
+   WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+   See the License for the specific language governing permissions and
+   limitations under the License. */
+
+#include "paddle/operators/gemm_conv_op.h"
+
+namespace ops = paddle::operators;
+
+REGISTER_OP_GPU_KERNEL(conv2d,
+                       ops::GemmConvKernel<paddle::platform::GPUPlace, float>);
+REGISTER_OP_GPU_KERNEL(
+    conv2d_grad, ops::GemmConvGradKernel<paddle::platform::GPUPlace, float>);
diff --git a/paddle/operators/gemm_conv_op.h b/paddle/operators/gemm_conv_op.h
new file mode 100644
index 0000000000..16ea5ff74c
--- /dev/null
+++ b/paddle/operators/gemm_conv_op.h
@@ -0,0 +1,103 @@
+/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License. */
+
+#pragma once
+
+#include "paddle/framework/op_registry.h"
+#include "paddle/operators/math/im2col.h"
+#include "paddle/operators/math/math_function.h"
+
+namespace paddle {
+namespace operators {
+
+using Tensor = framework::Tensor;
+
+template <typename Place, typename T>
+class GemmConvKernel : public framework::OpKernel {
+ public:
+  void Compute(const framework::ExecutionContext& context) const override {
+    const Tensor* input = context.Input<Tensor>("Input");
+    Tensor* filter = const_cast<Tensor*>(context.Input<Tensor>("Filter"));
+    Tensor* output = context.Output<Tensor>("Output");
+    output->mutable_data<T>(context.GetPlace());
+    paddle::framework::Tensor col;
+    paddle::framework::Tensor in_slice;
+    paddle::framework::Tensor out_slice;
+
+    std::vector<int> strides = context.Attr<std::vector<int>>("strides");
+    std::vector<int> paddings = context.Attr<std::vector<int>>("paddings");
+
+    int batch_size = input->dims()[0];
+    int input_channels = input->dims()[1];
+    int filter_height = filter->dims()[filter->dims().size() - 2];
+    int filter_width = filter->dims()[filter->dims().size() - 1];
+    int output_height = output->dims()[2];
+    int output_width = output->dims()[3];
+
+    paddle::operators::math::Im2ColFunctor<
+        paddle::operators::math::ColFormat::kCFO, Place, T>
+        im2col;
+    framework::DDim col_shape = {input_channels, filter_height, filter_width,
+                                 output_height, output_width};
+    col.mutable_data<float>(col_shape, context.GetPlace());
+
+    auto* device_context =
+        const_cast<platform::DeviceContext*>(context.device_context_);
+
+    framework::DDim input_shape = {input->dims()[1], input->dims()[2],
+                                   input->dims()[3]};
+    framework::DDim filter_matrix_shape = {
+        filter->dims()[0],
+        filter->dims()[1] * filter->dims()[2] * filter->dims()[3]};
+    framework::DDim col_matrix_shape = {
+        input_channels * filter_height * filter_width,
+        output_height * output_width};
+    framework::DDim output_matrix_shape = {
+        output->dims()[1], output->dims()[2] * output->dims()[3]};
+    filter->Resize(filter_matrix_shape);
+
+    // convolution opperator: im2col + gemm
+    for (int i = 0; i < batch_size; i++) {
+      // im2col
+      in_slice = input->Slice<T>(i, i + 1);
+      in_slice.Resize(input_shape);
+      col.Resize(col_shape);
+      im2col(in_slice, col, strides[0], strides[1], paddings[0], paddings[1],
+             device_context);
+
+      // gemm
+      out_slice = output->Slice<T>(i, i + 1);
+      out_slice.Resize(output_matrix_shape);
+      col.Resize(col_matrix_shape);
+      math::matmul<Place, T>(*filter, false, col, false, T(1.0), &out_slice,
+                             T(0.0), device_context);
+    }
+  }
+};
+
+template <typename Place, typename T>
+class GemmConvGradKernel : public framework::OpKernel {
+ public:
+  void Compute(const framework::ExecutionContext& context) const override {
+#if 0
+    auto input = context.Input<Tensor>("Input");
+    auto filter = context.Input<Tensor>("Filter");
+    auto output = context.Output<Tensor>("Output");
+    output->mutable_data<T>(context.GetPlace());
+#endif
+  }
+};
+
+}  // namespace operators
+}  // namespace paddle
diff --git a/paddle/pybind/pybind.cc b/paddle/pybind/pybind.cc
index 53985933ed..ef72c86cbd 100644
--- a/paddle/pybind/pybind.cc
+++ b/paddle/pybind/pybind.cc
@@ -51,6 +51,7 @@ USE_CPU_ONLY_OP(gather);
 USE_CPU_ONLY_OP(scatter);
 USE_OP(top_k);
 USE_OP(squared_l2_distance);
+USE_OP(conv2d);
 
 namespace paddle {
 namespace framework {
diff --git a/python/paddle/v2/framework/tests/CMakeLists.txt b/python/paddle/v2/framework/tests/CMakeLists.txt
index ef910f939b..11290e042d 100644
--- a/python/paddle/v2/framework/tests/CMakeLists.txt
+++ b/python/paddle/v2/framework/tests/CMakeLists.txt
@@ -35,3 +35,4 @@ py_test(test_lookup_table SRCS test_lookup_table.py)
 py_test(test_scale_and_identity_op SRCS test_scale_and_identity_op.py)
 py_test(mnist SRCS 
mnist.py) py_test(test_squared_l2_distance_op SRCS test_squared_l2_distance_op.py) +py_test(test_conv2d SRCS test_conv2d_op.py) diff --git a/python/paddle/v2/framework/tests/test_conv2d_op.py b/python/paddle/v2/framework/tests/test_conv2d_op.py new file mode 100644 index 0000000000..d2015d0ce5 --- /dev/null +++ b/python/paddle/v2/framework/tests/test_conv2d_op.py @@ -0,0 +1,62 @@ +import unittest +import numpy as np +from gradient_checker import GradientChecker, create_op +from op_test_util import OpTestMeta + + +class TestConv2dOp(unittest.TestCase): + __metaclass__ = OpTestMeta + + def setUp(self): + self.type = "conv2d" + batch_size = 2 + input_channels = 3 + input_height = 5 + input_width = 5 + output_channels = 6 + filter_height = 3 + filter_width = 3 + stride = 1 + padding = 0 + output_height = (input_height - filter_height + 2 * padding + ) / stride + 1 + output_width = (input_width - filter_width + 2 * padding) / stride + 1 + input = np.random.random((batch_size, input_channels, input_height, + input_width)).astype("float32") + filter = np.random.random( + (output_channels, input_channels, filter_height, + filter_width)).astype("float32") + output = np.ndarray( + (batch_size, output_channels, output_height, output_width)) + + for batchid in xrange(batch_size): + for channelid in xrange(output_channels): + for rowid in xrange(output_height): + for colid in xrange(output_width): + start_h = (rowid * stride) - padding + start_w = (colid * stride) - padding + output_value = 0.0 + for inchannelid in xrange(input_channels): + for frowid in xrange(filter_height): + for fcolid in xrange(filter_width): + input_value = 0.0 + inrowid = start_h + frowid + incolid = start_w + fcolid + if ((inrowid >= 0 and + inrowid < input_height) and + (incolid >= 0 and + incolid < input_width)): + input_value = input[batchid][ + inchannelid][inrowid][incolid] + filter_value = filter[channelid][ + inchannelid][frowid][fcolid] + output_value += input_value * filter_value + output[batchid][channelid][rowid][colid] = output_value + + self.inputs = {'Input': input, 'Filter': filter} + self.outputs = {'Output': output} + self.attrs = {'strides': [1, 1], 'paddings': [0, 0]} + + +if __name__ == '__main__': + unittest.main() From 40fe0a8c47cb9613f3e2db462ca74886754f41fe Mon Sep 17 00:00:00 2001 From: hedaoyuan Date: Tue, 12 Sep 2017 18:08:32 +0800 Subject: [PATCH 02/19] Add backward of convolution. 
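
The backward pass reuses the forward's im2col + GEMM scheme: the input
gradient is a GEMM followed by col2im (the adjoint of im2col), and the
filter gradient is an im2col followed by a GEMM. Below is a minimal NumPy
sketch of the single-sample, single-group case with stride 1 and zero
padding; every name in it is an illustrative placeholder, not code from
this patch.

    import numpy as np

    def im2col(x, kh, kw):
        # (C, H, W) -> (C*kh*kw, oh*ow), stride 1, no padding
        C, H, W = x.shape
        oh, ow = H - kh + 1, W - kw + 1
        col = np.empty((C * kh * kw, oh * ow), x.dtype)
        r = 0
        for c in range(C):
            for i in range(kh):
                for j in range(kw):
                    col[r] = x[c, i:i + oh, j:j + ow].ravel()
                    r += 1
        return col

    def col2im(col, shape, kh, kw):
        # scatter-add adjoint of im2col: overlapping patches accumulate
        C, H, W = shape
        oh, ow = H - kh + 1, W - kw + 1
        x = np.zeros(shape, col.dtype)
        r = 0
        for c in range(C):
            for i in range(kh):
                for j in range(kw):
                    x[c, i:i + oh, j:j + ow] += col[r].reshape(oh, ow)
                    r += 1
        return x

    # w_mat: (M, C*kh*kw) filter matrix, d_out: (M, oh*ow) output gradient
    # d_x = col2im(w_mat.T @ d_out, x.shape, kh, kw)  # gemm + col2im
    # d_w = d_out @ im2col(x, kh, kw).T               # im2col + gemm

The two commented lines correspond to the two per-sample branches of the
GemmConvGradKernel added below.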
--- paddle/operators/conv_op.cc | 24 ++-- paddle/operators/gemm_conv_op.h | 105 ++++++++++++++++-- .../v2/framework/tests/test_conv2d_op.py | 38 +++++++ 3 files changed, 146 insertions(+), 21 deletions(-) diff --git a/paddle/operators/conv_op.cc b/paddle/operators/conv_op.cc index 873366394d..107682848b 100644 --- a/paddle/operators/conv_op.cc +++ b/paddle/operators/conv_op.cc @@ -28,9 +28,9 @@ class Conv2DOp : public framework::OperatorWithKernel { protected: void InferShape(const framework::InferShapeContext &ctx) const override { - auto *in = ctx.Input("Input"); - auto *filter = ctx.Input("Filter"); - auto *out = ctx.Output("Output"); + auto in = ctx.Input("Input"); + auto filter = ctx.Input("Filter"); + auto out = ctx.Output("Output"); PADDLE_ENFORCE_EQ(in->dims().size(), 4, "Conv2DOp intput should be 4-D."); PADDLE_ENFORCE_EQ(filter->dims().size(), 4, "Conv2DOp filter should be 4-D."); @@ -46,10 +46,9 @@ class Conv2DOp : public framework::OperatorWithKernel { } }; -class Conv2DOppMaker : public framework::OpProtoAndCheckerMaker { +class Conv2DOpMaker : public framework::OpProtoAndCheckerMaker { public: - Conv2DOppMaker(framework::OpProto *proto, - framework::OpAttrChecker *op_checker) + Conv2DOpMaker(framework::OpProto *proto, framework::OpAttrChecker *op_checker) : OpProtoAndCheckerMaker(proto, op_checker) { AddInput( "Input", @@ -62,7 +61,7 @@ class Conv2DOppMaker : public framework::OpProtoAndCheckerMaker { "The format of the filter tensor is MCHW, where M is the number of " "output " "image channels, C is the number of input image channels, H and W is " - " height and width of filter."); + "height and width of filter."); AddOutput("Output", "The output tensor of convolution operator." "The format of output tensor is also NCHW."); @@ -80,14 +79,21 @@ class Conv2DOpGrad : public framework::OperatorWithKernel { using framework::OperatorWithKernel::OperatorWithKernel; protected: - void InferShape(const framework::InferShapeContext &ctx) const override {} + void InferShape(const framework::InferShapeContext &ctx) const override { + auto in = ctx.Input("Input"); + auto filter = ctx.Input("Filter"); + auto d_in = ctx.Output(framework::GradVarName("Input")); + auto d_filter = ctx.Output(framework::GradVarName("Filter")); + d_in->Resize(in->dims()); + d_filter->Resize(filter->dims()); + } }; } // namespace operators } // namespace paddle namespace ops = paddle::operators; -REGISTER_OP(conv2d, ops::Conv2DOp, ops::Conv2DOppMaker, conv2d_grad, +REGISTER_OP(conv2d, ops::Conv2DOp, ops::Conv2DOpMaker, conv2d_grad, ops::Conv2DOpGrad); REGISTER_OP_CPU_KERNEL(conv2d, diff --git a/paddle/operators/gemm_conv_op.h b/paddle/operators/gemm_conv_op.h index 16ea5ff74c..6c72362195 100644 --- a/paddle/operators/gemm_conv_op.h +++ b/paddle/operators/gemm_conv_op.h @@ -14,6 +14,7 @@ limitations under the License. 
*/ #pragma once +#include "paddle/framework/eigen.h" #include "paddle/framework/op_registry.h" #include "paddle/operators/math/im2col.h" #include "paddle/operators/math/math_function.h" @@ -31,12 +32,10 @@ class GemmConvKernel : public framework::OpKernel { Tensor* filter = const_cast(context.Input("Filter")); Tensor* output = context.Output("Output"); output->mutable_data(context.GetPlace()); - paddle::framework::Tensor col; - paddle::framework::Tensor in_slice; - paddle::framework::Tensor out_slice; std::vector strides = context.Attr>("strides"); std::vector paddings = context.Attr>("paddings"); + auto filter_dims = filter->dims(); int batch_size = input->dims()[0]; int input_channels = input->dims()[1]; @@ -50,6 +49,7 @@ class GemmConvKernel : public framework::OpKernel { im2col; framework::DDim col_shape = {input_channels, filter_height, filter_width, output_height, output_width}; + Tensor col; col.mutable_data(col_shape, context.GetPlace()); auto* device_context = @@ -67,22 +67,23 @@ class GemmConvKernel : public framework::OpKernel { output->dims()[1], output->dims()[2] * output->dims()[3]}; filter->Resize(filter_matrix_shape); - // convolution opperator: im2col + gemm + // convolution operator: im2col + gemm for (int i = 0; i < batch_size; i++) { // im2col - in_slice = input->Slice(i, i + 1); + Tensor in_slice = input->Slice(i, i + 1); in_slice.Resize(input_shape); col.Resize(col_shape); im2col(in_slice, col, strides[0], strides[1], paddings[0], paddings[1], device_context); // gemm - out_slice = output->Slice(i, i + 1); + Tensor out_slice = output->Slice(i, i + 1); out_slice.Resize(output_matrix_shape); col.Resize(col_matrix_shape); math::matmul(*filter, false, col, false, T(1.0), &out_slice, T(0.0), device_context); } + filter->Resize(filter_dims); } }; @@ -90,12 +91,92 @@ template class GemmConvGradKernel : public framework::OpKernel { public: void Compute(const framework::ExecutionContext& context) const override { -#if 0 - auto input = context.Input("Input"); - auto filter = context.Input("Filter"); - auto output = context.Output("Output"); - output->mutable_data(context.GetPlace()); -#endif + const Tensor* input = context.Input("Input"); + Tensor* filter = const_cast(context.Input("Filter")); + const Tensor* output_grad = + context.Input(framework::GradVarName("Output")); + Tensor* input_grad = + context.Output(framework::GradVarName("Input")); + Tensor* filter_grad = + context.Output(framework::GradVarName("Filter")); + input_grad->mutable_data(context.GetPlace()); + filter_grad->mutable_data(context.GetPlace()); + + std::vector strides = context.Attr>("strides"); + std::vector paddings = context.Attr>("paddings"); + auto filter_dims = filter->dims(); + + int batch_size = input->dims()[0]; + int input_channels = input->dims()[1]; + int filter_height = filter->dims()[filter->dims().size() - 2]; + int filter_width = filter->dims()[filter->dims().size() - 1]; + int output_height = output_grad->dims()[2]; + int output_width = output_grad->dims()[3]; + + paddle::operators::math::Col2ImFunctor< + paddle::operators::math::ColFormat::kCFO, Place, T> + col2im; + paddle::operators::math::Im2ColFunctor< + paddle::operators::math::ColFormat::kCFO, Place, T> + im2col; + Tensor col; + framework::DDim col_shape = {input_channels, filter_height, filter_width, + output_height, output_width}; + col.mutable_data(col_shape, context.GetPlace()); + + auto* device_context = + const_cast(context.device_context_); + + framework::DDim input_shape = {input->dims()[1], input->dims()[2], + 
input->dims()[3]}; + framework::DDim filter_matrix_shape = { + filter->dims()[0], + filter->dims()[1] * filter->dims()[2] * filter->dims()[3]}; + framework::DDim col_matrix_shape = { + input_channels * filter_height * filter_width, + output_height * output_width}; + framework::DDim output_matrix_shape = { + output_grad->dims()[1], + output_grad->dims()[2] * output_grad->dims()[3]}; + filter->Resize(filter_matrix_shape); + filter_grad->Resize(filter_matrix_shape); + + auto t1 = framework::EigenVector::Flatten(*filter_grad); + t1.device(context.GetEigenDevice()) = t1.constant(static_cast(0)); + auto t2 = framework::EigenVector::Flatten(*input_grad); + t2.device(context.GetEigenDevice()) = t2.constant(static_cast(0)); + + // convolution backward input operator: gemm + col2im + // convolution backward weight operator: im2col + gemm + for (int i = 0; i < batch_size; i++) { + // gemm + Tensor out_slice = output_grad->Slice(i, i + 1); + out_slice.Resize(output_matrix_shape); + col.Resize(col_matrix_shape); + math::matmul(*filter, true, out_slice, false, T(1.0), &col, + T(0.0), device_context); + + // col2im + Tensor in_grad_slice = input_grad->Slice(i, i + 1); + in_grad_slice.Resize(input_shape); + col.Resize(col_shape); + col2im(in_grad_slice, col, strides[0], strides[1], paddings[0], + paddings[1], device_context); + + // im2col + Tensor in_slice = input->Slice(i, i + 1); + in_slice.Resize(input_shape); + col.Resize(col_shape); + im2col(in_slice, col, strides[0], strides[1], paddings[0], paddings[1], + device_context); + + // gemm + col.Resize(col_matrix_shape); + math::matmul(out_slice, false, col, true, T(1.0), filter_grad, + T(1.0), device_context); + } + filter->Resize(filter_dims); + filter_grad->Resize(filter_dims); } }; diff --git a/python/paddle/v2/framework/tests/test_conv2d_op.py b/python/paddle/v2/framework/tests/test_conv2d_op.py index d2015d0ce5..43f328ca03 100644 --- a/python/paddle/v2/framework/tests/test_conv2d_op.py +++ b/python/paddle/v2/framework/tests/test_conv2d_op.py @@ -2,6 +2,7 @@ import unittest import numpy as np from gradient_checker import GradientChecker, create_op from op_test_util import OpTestMeta +from paddle.v2.framework.op import Operator class TestConv2dOp(unittest.TestCase): @@ -58,5 +59,42 @@ class TestConv2dOp(unittest.TestCase): self.attrs = {'strides': [1, 1], 'paddings': [0, 0]} +class TestConv2dGradOp(GradientChecker): + def setUp(self): + batch_size = 2 + input_channels = 3 + input_height = 5 + input_width = 5 + output_channels = 6 + filter_height = 3 + filter_width = 3 + stride = 1 + padding = 0 + output_height = (input_height - filter_height + 2 * padding + ) / stride + 1 + output_width = (input_width - filter_width + 2 * padding) / stride + 1 + input = np.random.random((batch_size, input_channels, input_height, + input_width)).astype("float32") + filter = np.random.random( + (output_channels, input_channels, filter_height, + filter_width)).astype("float32") + + self.inputs = {'Input': input, 'Filter': filter} + self.op = Operator( + "conv2d", + Input='Input', + Filter='Filter', + Output='Output', + strides=[1, 1], + paddings=[0, 0]) + + def test_compare_grad(self): + self.compare_grad(self.op, self.inputs) + + def test_check_grad(self): + self.check_grad(self.op, self.inputs, + set(['Input', 'Filter']), 'Output') + + if __name__ == '__main__': unittest.main() From c671189d7fc34a25165e70018f2ce85e85ce205d Mon Sep 17 00:00:00 2001 From: hedaoyuan Date: Tue, 12 Sep 2017 20:49:51 +0800 Subject: [PATCH 03/19] Fix test_conv2d_op.py. 
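
The test now goes through the unified OpTest harness, whose check_grad
compares the operator's analytic gradient against a numerically estimated
one. As a rough sketch of how such an estimate is built (generic central
differences, not the actual OpTest implementation; f stands for any scalar
loss over the operator output):

    import numpy as np

    def numeric_grad(f, x, eps=1e-3):
        # perturb each element of x and take central differences
        grad = np.zeros_like(x)
        for i in np.ndindex(x.shape):
            orig = x[i]
            x[i] = orig + eps
            f_pos = f(x)
            x[i] = orig - eps
            f_neg = f(x)
            x[i] = orig
            grad[i] = (f_pos - f_neg) / (2.0 * eps)
        return grad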
--- .../v2/framework/tests/test_conv2d_op.py | 34 ++++++++----------- 1 file changed, 15 insertions(+), 19 deletions(-) diff --git a/python/paddle/v2/framework/tests/test_conv2d_op.py b/python/paddle/v2/framework/tests/test_conv2d_op.py index 43f328ca03..01513be66e 100644 --- a/python/paddle/v2/framework/tests/test_conv2d_op.py +++ b/python/paddle/v2/framework/tests/test_conv2d_op.py @@ -1,15 +1,11 @@ import unittest import numpy as np -from gradient_checker import GradientChecker, create_op -from op_test_util import OpTestMeta -from paddle.v2.framework.op import Operator +from op_test import OpTest -class TestConv2dOp(unittest.TestCase): - __metaclass__ = OpTestMeta - +class TestConv2dOp(OpTest): def setUp(self): - self.type = "conv2d" + self.op_type = "conv2d" batch_size = 2 input_channels = 3 input_height = 5 @@ -58,8 +54,11 @@ class TestConv2dOp(unittest.TestCase): self.outputs = {'Output': output} self.attrs = {'strides': [1, 1], 'paddings': [0, 0]} + def test_check_output(self): + self.check_output() + -class TestConv2dGradOp(GradientChecker): +class TestConv2dGradOp(OpTest): def setUp(self): batch_size = 2 input_channels = 3 @@ -79,21 +78,18 @@ class TestConv2dGradOp(GradientChecker): (output_channels, input_channels, filter_height, filter_width)).astype("float32") + self.op_type = 'conv2d' self.inputs = {'Input': input, 'Filter': filter} - self.op = Operator( - "conv2d", - Input='Input', - Filter='Filter', - Output='Output', - strides=[1, 1], - paddings=[0, 0]) + output = np.ndarray( + (batch_size, output_channels, output_height, output_width)) + self.outputs = {'Output': output} + self.attrs = {'strides': [1, 1], 'paddings': [0, 0]} - def test_compare_grad(self): - self.compare_grad(self.op, self.inputs) + #def test_compare_grad(self): + # self.compare_grad(self.op, self.inputs) def test_check_grad(self): - self.check_grad(self.op, self.inputs, - set(['Input', 'Filter']), 'Output') + self.check_grad(set(['Input', 'Filter']), 'Output') if __name__ == '__main__': From a7c1872206cf11ba968a932a0fc880a03e8a4c28 Mon Sep 17 00:00:00 2001 From: hedaoyuan Date: Tue, 12 Sep 2017 21:05:54 +0800 Subject: [PATCH 04/19] Refine test_conv2d_op.py --- .../v2/framework/tests/test_conv2d_op.py | 36 ++----------------- 1 file changed, 3 insertions(+), 33 deletions(-) diff --git a/python/paddle/v2/framework/tests/test_conv2d_op.py b/python/paddle/v2/framework/tests/test_conv2d_op.py index 01513be66e..29a637a382 100644 --- a/python/paddle/v2/framework/tests/test_conv2d_op.py +++ b/python/paddle/v2/framework/tests/test_conv2d_op.py @@ -26,6 +26,9 @@ class TestConv2dOp(OpTest): output = np.ndarray( (batch_size, output_channels, output_height, output_width)) + self.inputs = {'Input': input, 'Filter': filter} + self.attrs = {'strides': [1, 1], 'paddings': [0, 0]} + for batchid in xrange(batch_size): for channelid in xrange(output_channels): for rowid in xrange(output_height): @@ -50,44 +53,11 @@ class TestConv2dOp(OpTest): output_value += input_value * filter_value output[batchid][channelid][rowid][colid] = output_value - self.inputs = {'Input': input, 'Filter': filter} self.outputs = {'Output': output} - self.attrs = {'strides': [1, 1], 'paddings': [0, 0]} def test_check_output(self): self.check_output() - -class TestConv2dGradOp(OpTest): - def setUp(self): - batch_size = 2 - input_channels = 3 - input_height = 5 - input_width = 5 - output_channels = 6 - filter_height = 3 - filter_width = 3 - stride = 1 - padding = 0 - output_height = (input_height - filter_height + 2 * padding - ) / stride + 1 - 
output_width = (input_width - filter_width + 2 * padding) / stride + 1 - input = np.random.random((batch_size, input_channels, input_height, - input_width)).astype("float32") - filter = np.random.random( - (output_channels, input_channels, filter_height, - filter_width)).astype("float32") - - self.op_type = 'conv2d' - self.inputs = {'Input': input, 'Filter': filter} - output = np.ndarray( - (batch_size, output_channels, output_height, output_width)) - self.outputs = {'Output': output} - self.attrs = {'strides': [1, 1], 'paddings': [0, 0]} - - #def test_compare_grad(self): - # self.compare_grad(self.op, self.inputs) - def test_check_grad(self): self.check_grad(set(['Input', 'Filter']), 'Output') From 67db9d3521ee3423f9d86004860662e12a601303 Mon Sep 17 00:00:00 2001 From: hedaoyuan Date: Wed, 13 Sep 2017 00:11:39 +0800 Subject: [PATCH 05/19] Refine the GemmConvKernel. --- paddle/operators/gemm_conv_op.h | 47 +++++++++++++++++++-------------- 1 file changed, 27 insertions(+), 20 deletions(-) diff --git a/paddle/operators/gemm_conv_op.h b/paddle/operators/gemm_conv_op.h index 6c72362195..560dfd311f 100644 --- a/paddle/operators/gemm_conv_op.h +++ b/paddle/operators/gemm_conv_op.h @@ -29,61 +29,68 @@ class GemmConvKernel : public framework::OpKernel { public: void Compute(const framework::ExecutionContext& context) const override { const Tensor* input = context.Input("Input"); - Tensor* filter = const_cast(context.Input("Filter")); + // The filter will be reshaped in the calculations, + // so here use an assignment operation, + // that avoids modifying the variable in the Scope. + Tensor filter = *context.Input("Filter"); Tensor* output = context.Output("Output"); output->mutable_data(context.GetPlace()); std::vector strides = context.Attr>("strides"); std::vector paddings = context.Attr>("paddings"); - auto filter_dims = filter->dims(); int batch_size = input->dims()[0]; int input_channels = input->dims()[1]; - int filter_height = filter->dims()[filter->dims().size() - 2]; - int filter_width = filter->dims()[filter->dims().size() - 1]; + int filter_height = filter.dims()[filter.dims().size() - 2]; + int filter_width = filter.dims()[filter.dims().size() - 1]; + int output_channels = output->dims()[1]; int output_height = output->dims()[2]; int output_width = output->dims()[3]; paddle::operators::math::Im2ColFunctor< paddle::operators::math::ColFormat::kCFO, Place, T> im2col; + // use col_shape in the im2col calculation framework::DDim col_shape = {input_channels, filter_height, filter_width, output_height, output_width}; + // use col_matrix_shape in the gemm calculation + framework::DDim col_matrix_shape = { + input_channels * filter_height * filter_width, + output_height * output_width}; Tensor col; col.mutable_data(col_shape, context.GetPlace()); - - auto* device_context = - const_cast(context.device_context_); + // col_matrix shares the same piece of data with col, + // but will be reshaped into a two-dimensional matrix shape + // to call the matrix multiplication interface. 
+ Tensor col_matrix = col; + col_matrix.Resize(col_matrix_shape); framework::DDim input_shape = {input->dims()[1], input->dims()[2], input->dims()[3]}; framework::DDim filter_matrix_shape = { - filter->dims()[0], - filter->dims()[1] * filter->dims()[2] * filter->dims()[3]}; - framework::DDim col_matrix_shape = { - input_channels * filter_height * filter_width, - output_height * output_width}; - framework::DDim output_matrix_shape = { - output->dims()[1], output->dims()[2] * output->dims()[3]}; - filter->Resize(filter_matrix_shape); + output_channels, framework::product(filter.dims()) / output_channels}; + filter.Resize(filter_matrix_shape); + + framework::DDim output_matrix_shape = {output_channels, + output_height * output_width}; + + auto* device_context = + const_cast(context.device_context_); // convolution operator: im2col + gemm for (int i = 0; i < batch_size; i++) { // im2col Tensor in_slice = input->Slice(i, i + 1); in_slice.Resize(input_shape); - col.Resize(col_shape); im2col(in_slice, col, strides[0], strides[1], paddings[0], paddings[1], device_context); // gemm Tensor out_slice = output->Slice(i, i + 1); out_slice.Resize(output_matrix_shape); - col.Resize(col_matrix_shape); - math::matmul(*filter, false, col, false, T(1.0), &out_slice, - T(0.0), device_context); + math::matmul(filter, false, col_matrix, false, T(1.0), + &out_slice, T(0.0), device_context); } - filter->Resize(filter_dims); } }; From db33ff12a5517fb1c3f10abbcdd84d8b071cf92f Mon Sep 17 00:00:00 2001 From: hedaoyuan Date: Wed, 13 Sep 2017 00:38:18 +0800 Subject: [PATCH 06/19] Refine the GemmConvGradKernel. --- paddle/operators/gemm_conv_op.h | 65 ++++++++++++++++++--------------- 1 file changed, 35 insertions(+), 30 deletions(-) diff --git a/paddle/operators/gemm_conv_op.h b/paddle/operators/gemm_conv_op.h index 560dfd311f..cdcc0039b0 100644 --- a/paddle/operators/gemm_conv_op.h +++ b/paddle/operators/gemm_conv_op.h @@ -68,7 +68,7 @@ class GemmConvKernel : public framework::OpKernel { framework::DDim input_shape = {input->dims()[1], input->dims()[2], input->dims()[3]}; framework::DDim filter_matrix_shape = { - output_channels, framework::product(filter.dims()) / output_channels}; + filter.dims()[0], framework::product(filter.dims()) / filter.dims()[0]}; filter.Resize(filter_matrix_shape); framework::DDim output_matrix_shape = {output_channels, @@ -99,24 +99,28 @@ class GemmConvGradKernel : public framework::OpKernel { public: void Compute(const framework::ExecutionContext& context) const override { const Tensor* input = context.Input("Input"); - Tensor* filter = const_cast(context.Input("Filter")); const Tensor* output_grad = context.Input(framework::GradVarName("Output")); Tensor* input_grad = context.Output(framework::GradVarName("Input")); - Tensor* filter_grad = + Tensor* filter_grad_ = context.Output(framework::GradVarName("Filter")); input_grad->mutable_data(context.GetPlace()); - filter_grad->mutable_data(context.GetPlace()); + filter_grad_->mutable_data(context.GetPlace()); + + // The filter and filter_grad will be reshaped in the calculations, + // so here use an assignment operation, + // that avoids modifying the variable in the Scope. 
+ Tensor filter = *context.Input("Filter"); + Tensor filter_grad = *filter_grad_; std::vector strides = context.Attr>("strides"); std::vector paddings = context.Attr>("paddings"); - auto filter_dims = filter->dims(); int batch_size = input->dims()[0]; int input_channels = input->dims()[1]; - int filter_height = filter->dims()[filter->dims().size() - 2]; - int filter_width = filter->dims()[filter->dims().size() - 1]; + int filter_height = filter.dims()[filter.dims().size() - 2]; + int filter_width = filter.dims()[filter.dims().size() - 1]; int output_height = output_grad->dims()[2]; int output_width = output_grad->dims()[3]; @@ -126,64 +130,65 @@ class GemmConvGradKernel : public framework::OpKernel { paddle::operators::math::Im2ColFunctor< paddle::operators::math::ColFormat::kCFO, Place, T> im2col; - Tensor col; + // use col_shape in the im2col and col2im calculation framework::DDim col_shape = {input_channels, filter_height, filter_width, output_height, output_width}; + // use col_matrix_shape in the gemm calculation + framework::DDim col_matrix_shape = { + input_channels * filter_height * filter_width, + output_height * output_width}; + Tensor col; col.mutable_data(col_shape, context.GetPlace()); - - auto* device_context = - const_cast(context.device_context_); + // col_matrix shares the same piece of data with col, + // but will be reshaped into a two-dimensional matrix shape + // to call the matrix multiplication interface. + Tensor col_matrix = col; + col_matrix.Resize(col_matrix_shape); framework::DDim input_shape = {input->dims()[1], input->dims()[2], input->dims()[3]}; - framework::DDim filter_matrix_shape = { - filter->dims()[0], - filter->dims()[1] * filter->dims()[2] * filter->dims()[3]}; - framework::DDim col_matrix_shape = { - input_channels * filter_height * filter_width, - output_height * output_width}; framework::DDim output_matrix_shape = { output_grad->dims()[1], output_grad->dims()[2] * output_grad->dims()[3]}; - filter->Resize(filter_matrix_shape); - filter_grad->Resize(filter_matrix_shape); - auto t1 = framework::EigenVector::Flatten(*filter_grad); + framework::DDim filter_matrix_shape = { + filter.dims()[0], framework::product(filter.dims()) / filter.dims()[0]}; + filter.Resize(filter_matrix_shape); + filter_grad.Resize(filter_matrix_shape); + + auto t1 = framework::EigenVector::Flatten(filter_grad); t1.device(context.GetEigenDevice()) = t1.constant(static_cast(0)); auto t2 = framework::EigenVector::Flatten(*input_grad); t2.device(context.GetEigenDevice()) = t2.constant(static_cast(0)); + auto* device_context = + const_cast(context.device_context_); + // convolution backward input operator: gemm + col2im // convolution backward weight operator: im2col + gemm for (int i = 0; i < batch_size; i++) { // gemm Tensor out_slice = output_grad->Slice(i, i + 1); out_slice.Resize(output_matrix_shape); - col.Resize(col_matrix_shape); - math::matmul(*filter, true, out_slice, false, T(1.0), &col, - T(0.0), device_context); + math::matmul(filter, true, out_slice, false, T(1.0), + &col_matrix, T(0.0), device_context); // col2im Tensor in_grad_slice = input_grad->Slice(i, i + 1); in_grad_slice.Resize(input_shape); - col.Resize(col_shape); col2im(in_grad_slice, col, strides[0], strides[1], paddings[0], paddings[1], device_context); // im2col Tensor in_slice = input->Slice(i, i + 1); in_slice.Resize(input_shape); - col.Resize(col_shape); im2col(in_slice, col, strides[0], strides[1], paddings[0], paddings[1], device_context); // gemm - col.Resize(col_matrix_shape); - 
math::matmul(out_slice, false, col, true, T(1.0), filter_grad, - T(1.0), device_context); + math::matmul(out_slice, false, col_matrix, true, T(1.0), + &filter_grad, T(1.0), device_context); } - filter->Resize(filter_dims); - filter_grad->Resize(filter_dims); } }; From 5860150d96eefc11f55fe9e8408734001ab0483c Mon Sep 17 00:00:00 2001 From: hedaoyuan Date: Wed, 13 Sep 2017 10:44:53 +0800 Subject: [PATCH 07/19] Fix Tensor::Slice with dims[0] == 1. --- paddle/framework/tensor_impl.h | 22 +++++++++++++--------- 1 file changed, 13 insertions(+), 9 deletions(-) diff --git a/paddle/framework/tensor_impl.h b/paddle/framework/tensor_impl.h index 642b53efc7..3fcbc5447f 100644 --- a/paddle/framework/tensor_impl.h +++ b/paddle/framework/tensor_impl.h @@ -130,15 +130,19 @@ inline Tensor Tensor::Slice(const int& begin_idx, const int& end_idx) const { PADDLE_ENFORCE_LE(end_idx, dims_[0], "Slice end index is out of bound."); PADDLE_ENFORCE_LT(begin_idx, end_idx, "Begin index must be less than end index."); - PADDLE_ENFORCE_NE(dims_[0], 1, "Can not slice a tensor with dims_[0] = 1."); - size_t base = numel() / dims_[0]; - Tensor dst; - dst.holder_ = holder_; - DDim dst_dims = dims_; - dst_dims[0] = end_idx - begin_idx; - dst.Resize(dst_dims); - dst.offset_ = offset_ + begin_idx * base * sizeof(T); - return dst; + + if (dims_[0] == 1) { + return *this; + } else { + size_t base = numel() / dims_[0]; + Tensor dst; + dst.holder_ = holder_; + DDim dst_dims = dims_; + dst_dims[0] = end_idx - begin_idx; + dst.Resize(dst_dims); + dst.offset_ = offset_ + begin_idx * base * sizeof(T); + return dst; + } } inline Tensor& Tensor::Resize(const DDim& dims) { From 8219f20672dcb660174ab9c96f54d7214f248f7a Mon Sep 17 00:00:00 2001 From: hedaoyuan Date: Wed, 13 Sep 2017 11:01:24 +0800 Subject: [PATCH 08/19] Refine gemm convolution kernel. --- paddle/operators/gemm_conv_op.h | 28 ++++++++++++---------------- 1 file changed, 12 insertions(+), 16 deletions(-) diff --git a/paddle/operators/gemm_conv_op.h b/paddle/operators/gemm_conv_op.h index cdcc0039b0..3b7ba685c8 100644 --- a/paddle/operators/gemm_conv_op.h +++ b/paddle/operators/gemm_conv_op.h @@ -58,7 +58,7 @@ class GemmConvKernel : public framework::OpKernel { input_channels * filter_height * filter_width, output_height * output_width}; Tensor col; - col.mutable_data(col_shape, context.GetPlace()); + col.mutable_data(col_shape, context.GetPlace()); // col_matrix shares the same piece of data with col, // but will be reshaped into a two-dimensional matrix shape // to call the matrix multiplication interface. 
@@ -67,8 +67,8 @@ class GemmConvKernel : public framework::OpKernel { framework::DDim input_shape = {input->dims()[1], input->dims()[2], input->dims()[3]}; - framework::DDim filter_matrix_shape = { - filter.dims()[0], framework::product(filter.dims()) / filter.dims()[0]}; + framework::DDim filter_matrix_shape = {filter.dims()[0], + filter.numel() / filter.dims()[0]}; filter.Resize(filter_matrix_shape); framework::DDim output_matrix_shape = {output_channels, @@ -80,14 +80,12 @@ class GemmConvKernel : public framework::OpKernel { // convolution operator: im2col + gemm for (int i = 0; i < batch_size; i++) { // im2col - Tensor in_slice = input->Slice(i, i + 1); - in_slice.Resize(input_shape); + Tensor in_slice = input->Slice(i, i + 1).Resize(input_shape); im2col(in_slice, col, strides[0], strides[1], paddings[0], paddings[1], device_context); // gemm - Tensor out_slice = output->Slice(i, i + 1); - out_slice.Resize(output_matrix_shape); + Tensor out_slice = output->Slice(i, i + 1).Resize(output_matrix_shape); math::matmul(filter, false, col_matrix, false, T(1.0), &out_slice, T(0.0), device_context); } @@ -138,7 +136,7 @@ class GemmConvGradKernel : public framework::OpKernel { input_channels * filter_height * filter_width, output_height * output_width}; Tensor col; - col.mutable_data(col_shape, context.GetPlace()); + col.mutable_data(col_shape, context.GetPlace()); // col_matrix shares the same piece of data with col, // but will be reshaped into a two-dimensional matrix shape // to call the matrix multiplication interface. @@ -151,8 +149,8 @@ class GemmConvGradKernel : public framework::OpKernel { output_grad->dims()[1], output_grad->dims()[2] * output_grad->dims()[3]}; - framework::DDim filter_matrix_shape = { - filter.dims()[0], framework::product(filter.dims()) / filter.dims()[0]}; + framework::DDim filter_matrix_shape = {filter.dims()[0], + filter.numel() / filter.dims()[0]}; filter.Resize(filter_matrix_shape); filter_grad.Resize(filter_matrix_shape); @@ -168,20 +166,18 @@ class GemmConvGradKernel : public framework::OpKernel { // convolution backward weight operator: im2col + gemm for (int i = 0; i < batch_size; i++) { // gemm - Tensor out_slice = output_grad->Slice(i, i + 1); - out_slice.Resize(output_matrix_shape); + Tensor out_slice = + output_grad->Slice(i, i + 1).Resize(output_matrix_shape); math::matmul(filter, true, out_slice, false, T(1.0), &col_matrix, T(0.0), device_context); // col2im - Tensor in_grad_slice = input_grad->Slice(i, i + 1); - in_grad_slice.Resize(input_shape); + Tensor in_grad_slice = input_grad->Slice(i, i + 1).Resize(input_shape); col2im(in_grad_slice, col, strides[0], strides[1], paddings[0], paddings[1], device_context); // im2col - Tensor in_slice = input->Slice(i, i + 1); - in_slice.Resize(input_shape); + Tensor in_slice = input->Slice(i, i + 1).Resize(input_shape); im2col(in_slice, col, strides[0], strides[1], paddings[0], paddings[1], device_context); From fb46345f007e7c989d8c5d635dc0ff9d24bbbf31 Mon Sep 17 00:00:00 2001 From: hedaoyuan Date: Wed, 13 Sep 2017 14:15:58 +0800 Subject: [PATCH 09/19] Add groups in convolution operator. 
--- paddle/operators/conv_op.cc | 22 ++++++++++++++++++-- paddle/operators/gemm_conv_op.h | 36 ++++++++++++++++++++++----------- 2 files changed, 44 insertions(+), 14 deletions(-) diff --git a/paddle/operators/conv_op.cc b/paddle/operators/conv_op.cc index 107682848b..174f777f0e 100644 --- a/paddle/operators/conv_op.cc +++ b/paddle/operators/conv_op.cc @@ -31,12 +31,22 @@ class Conv2DOp : public framework::OperatorWithKernel { auto in = ctx.Input("Input"); auto filter = ctx.Input("Filter"); auto out = ctx.Output("Output"); + std::vector strides = Attr>("strides"); + std::vector paddings = Attr>("paddings"); + int groups = context.Attr("groups"); + int input_channels = in->dims()[1]; + int output_channels = filter->dims()[0]; + PADDLE_ENFORCE_EQ(in->dims().size(), 4, "Conv2DOp intput should be 4-D."); PADDLE_ENFORCE_EQ(filter->dims().size(), 4, "Conv2DOp filter should be 4-D."); + PADDLE_ENFORCE_EQ(input_channels, filter->dims()[1] * groups, + "The number of input channels should be equal to filter " + "channels * groups."); + PADDLE_ENFORCE_EQ( + output_channels % groups, 0, + "The number of output channels should be divided by groups."); - std::vector strides = Attr>("strides"); - std::vector paddings = Attr>("paddings"); auto output_height = outputSize(in->dims()[2], filter->dims()[2], paddings[0], strides[0]); auto output_width = @@ -71,6 +81,14 @@ the input, filter and strides, paddings parameters. )DOC"); AddAttr>("strides", "strides of convolution operator."); AddAttr>("paddings", "paddings of convolution operator."); + AddAttr( + "groups", + "group size of convolution operator. " + "Refer to grouped convolution in Alex Krizhevsky's paper: " + "when group=2, the first half of the filters are only connected to the " + "first half of the input channels, and the second half only connected " + "to the second half.") + .SetDefault(1); } }; diff --git a/paddle/operators/gemm_conv_op.h b/paddle/operators/gemm_conv_op.h index 3b7ba685c8..8ac92d3bd2 100644 --- a/paddle/operators/gemm_conv_op.h +++ b/paddle/operators/gemm_conv_op.h @@ -38,6 +38,7 @@ class GemmConvKernel : public framework::OpKernel { std::vector strides = context.Attr>("strides"); std::vector paddings = context.Attr>("paddings"); + int groups = context.Attr("groups"); int batch_size = input->dims()[0]; int input_channels = input->dims()[1]; @@ -51,11 +52,11 @@ class GemmConvKernel : public framework::OpKernel { paddle::operators::math::ColFormat::kCFO, Place, T> im2col; // use col_shape in the im2col calculation - framework::DDim col_shape = {input_channels, filter_height, filter_width, - output_height, output_width}; + framework::DDim col_shape = {input_channels / groups, filter_height, + filter_width, output_height, output_width}; // use col_matrix_shape in the gemm calculation framework::DDim col_matrix_shape = { - input_channels * filter_height * filter_width, + input_channels / groups * filter_height * filter_width, output_height * output_width}; Tensor col; col.mutable_data(col_shape, context.GetPlace()); @@ -78,16 +79,26 @@ class GemmConvKernel : public framework::OpKernel { const_cast(context.device_context_); // convolution operator: im2col + gemm + int in_step = input_channels / groups; + int out_step = output_channels / groups; for (int i = 0; i < batch_size; i++) { - // im2col - Tensor in_slice = input->Slice(i, i + 1).Resize(input_shape); - im2col(in_slice, col, strides[0], strides[1], paddings[0], paddings[1], - device_context); - - // gemm - Tensor out_slice = output->Slice(i, i + 1).Resize(output_matrix_shape); 
- math::matmul(filter, false, col_matrix, false, T(1.0), - &out_slice, T(0.0), device_context); + Tensor in_slice_batch = input->Slice(i, i + 1).Resize(input_shape); + Tensor out_slice_batch = + output->Slice(i, i + 1).Resize(output_matrix_shape); + for (int g = 0; g < groups; g++) { + // im2col + Tensor in_slice = + in_slice_batch.Slice(g * in_step, (g + 1) * in_step); + im2col(in_slice, col, strides[0], strides[1], paddings[0], paddings[1], + device_context); + + // gemm + Tensor out_slice = + out_slice_batch.Slice(g * out_step, (g + 1) * out_step); + Tensor filter_slice = filter.Slice(g * out_step, (g + 1) * out_step); + math::matmul(filter_slice, false, col_matrix, false, T(1.0), + &out_slice, T(0.0), device_context); + } } } }; @@ -114,6 +125,7 @@ class GemmConvGradKernel : public framework::OpKernel { std::vector strides = context.Attr>("strides"); std::vector paddings = context.Attr>("paddings"); + // int groups = context.Attr("groups"); int batch_size = input->dims()[0]; int input_channels = input->dims()[1]; From 2340cedaf604191f16f646bfbb0bf9cb6b7e1934 Mon Sep 17 00:00:00 2001 From: hedaoyuan Date: Wed, 13 Sep 2017 16:45:04 +0800 Subject: [PATCH 10/19] Add groups in convolution GemmConvGradKernel. --- paddle/operators/gemm_conv_op.h | 68 +++++++++++++++++++-------------- 1 file changed, 39 insertions(+), 29 deletions(-) diff --git a/paddle/operators/gemm_conv_op.h b/paddle/operators/gemm_conv_op.h index 8ac92d3bd2..b125698c6d 100644 --- a/paddle/operators/gemm_conv_op.h +++ b/paddle/operators/gemm_conv_op.h @@ -82,19 +82,16 @@ class GemmConvKernel : public framework::OpKernel { int in_step = input_channels / groups; int out_step = output_channels / groups; for (int i = 0; i < batch_size; i++) { - Tensor in_slice_batch = input->Slice(i, i + 1).Resize(input_shape); - Tensor out_slice_batch = - output->Slice(i, i + 1).Resize(output_matrix_shape); + Tensor in_batch = input->Slice(i, i + 1).Resize(input_shape); + Tensor out_batch = output->Slice(i, i + 1).Resize(output_matrix_shape); for (int g = 0; g < groups; g++) { // im2col - Tensor in_slice = - in_slice_batch.Slice(g * in_step, (g + 1) * in_step); + Tensor in_slice = in_batch.Slice(g * in_step, (g + 1) * in_step); im2col(in_slice, col, strides[0], strides[1], paddings[0], paddings[1], device_context); // gemm - Tensor out_slice = - out_slice_batch.Slice(g * out_step, (g + 1) * out_step); + Tensor out_slice = out_batch.Slice(g * out_step, (g + 1) * out_step); Tensor filter_slice = filter.Slice(g * out_step, (g + 1) * out_step); math::matmul(filter_slice, false, col_matrix, false, T(1.0), &out_slice, T(0.0), device_context); @@ -125,12 +122,13 @@ class GemmConvGradKernel : public framework::OpKernel { std::vector strides = context.Attr>("strides"); std::vector paddings = context.Attr>("paddings"); - // int groups = context.Attr("groups"); + int groups = context.Attr("groups"); int batch_size = input->dims()[0]; int input_channels = input->dims()[1]; int filter_height = filter.dims()[filter.dims().size() - 2]; int filter_width = filter.dims()[filter.dims().size() - 1]; + int output_channels = output_grad->dims()[1]; int output_height = output_grad->dims()[2]; int output_width = output_grad->dims()[3]; @@ -141,11 +139,11 @@ class GemmConvGradKernel : public framework::OpKernel { paddle::operators::math::ColFormat::kCFO, Place, T> im2col; // use col_shape in the im2col and col2im calculation - framework::DDim col_shape = {input_channels, filter_height, filter_width, - output_height, output_width}; + framework::DDim col_shape = 
{input_channels / groups, filter_height, + filter_width, output_height, output_width}; // use col_matrix_shape in the gemm calculation framework::DDim col_matrix_shape = { - input_channels * filter_height * filter_width, + input_channels / groups * filter_height * filter_width, output_height * output_width}; Tensor col; col.mutable_data(col_shape, context.GetPlace()); @@ -176,26 +174,38 @@ class GemmConvGradKernel : public framework::OpKernel { // convolution backward input operator: gemm + col2im // convolution backward weight operator: im2col + gemm + int in_step = input_channels / groups; + int out_step = output_channels / groups; for (int i = 0; i < batch_size; i++) { - // gemm - Tensor out_slice = + Tensor out_grad_batch = output_grad->Slice(i, i + 1).Resize(output_matrix_shape); - math::matmul(filter, true, out_slice, false, T(1.0), - &col_matrix, T(0.0), device_context); - - // col2im - Tensor in_grad_slice = input_grad->Slice(i, i + 1).Resize(input_shape); - col2im(in_grad_slice, col, strides[0], strides[1], paddings[0], - paddings[1], device_context); - - // im2col - Tensor in_slice = input->Slice(i, i + 1).Resize(input_shape); - im2col(in_slice, col, strides[0], strides[1], paddings[0], paddings[1], - device_context); - - // gemm - math::matmul(out_slice, false, col_matrix, true, T(1.0), - &filter_grad, T(1.0), device_context); + Tensor in_grad_batch = input_grad->Slice(i, i + 1).Resize(input_shape); + Tensor in_batch = input->Slice(i, i + 1).Resize(input_shape); + for (int g = 0; g < groups; g++) { + // gemm + Tensor out_grad_slice = + out_grad_batch.Slice(g * out_step, (g + 1) * out_step); + Tensor filter_slice = filter.Slice(g * out_step, (g + 1) * out_step); + math::matmul(filter_slice, true, out_grad_slice, false, + T(1.0), &col_matrix, T(0.0), device_context); + + // col2im + Tensor in_grad_slice = + in_grad_batch.Slice(g * in_step, (g + 1) * in_step); + col2im(in_grad_slice, col, strides[0], strides[1], paddings[0], + paddings[1], device_context); + + // im2col + Tensor in_slice = in_batch.Slice(g * in_step, (g + 1) * in_step); + im2col(in_slice, col, strides[0], strides[1], paddings[0], paddings[1], + device_context); + + // gemm + Tensor filter_grad_slice = + filter_grad.Slice(g * out_step, (g + 1) * out_step); + math::matmul(out_grad_slice, false, col_matrix, true, T(1.0), + &filter_grad_slice, T(1.0), device_context); + } } } }; From 1dd639ebbe0763bc0fa36bbe713c8f4ce319e46b Mon Sep 17 00:00:00 2001 From: hedaoyuan Date: Wed, 13 Sep 2017 17:02:32 +0800 Subject: [PATCH 11/19] Bug fix. --- paddle/operators/conv_op.cc | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/paddle/operators/conv_op.cc b/paddle/operators/conv_op.cc index 174f777f0e..593fdc0e7e 100644 --- a/paddle/operators/conv_op.cc +++ b/paddle/operators/conv_op.cc @@ -33,7 +33,7 @@ class Conv2DOp : public framework::OperatorWithKernel { auto out = ctx.Output("Output"); std::vector strides = Attr>("strides"); std::vector paddings = Attr>("paddings"); - int groups = context.Attr("groups"); + int groups = Attr("groups"); int input_channels = in->dims()[1]; int output_channels = filter->dims()[0]; From b4ba35caeb248136461b33c7d47977e09dfb4286 Mon Sep 17 00:00:00 2001 From: hedaoyuan Date: Wed, 13 Sep 2017 17:11:34 +0800 Subject: [PATCH 12/19] Add groups test. 
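
The shapes exercised below follow the outputSize() formula in conv_op.cc;
a small worked check using the exact values from the test:

    def conv_out_size(input_size, filter_size, padding, stride):
        # (input_size - filter_size + 2 * padding) / stride + 1
        return (input_size - filter_size + 2 * padding) // stride + 1

    # 5x5 input, 3x3 filter, stride 1, padding 0 -> 3x3 output
    assert conv_out_size(5, 3, 0, 1) == 3
    # groups=3: Input (2, 3, 5, 5), Filter (6, 1, 3, 3) -> Output (2, 6, 3, 3)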
--- .../v2/framework/tests/test_conv2d_op.py | 58 +++++++++++-------- 1 file changed, 34 insertions(+), 24 deletions(-) diff --git a/python/paddle/v2/framework/tests/test_conv2d_op.py b/python/paddle/v2/framework/tests/test_conv2d_op.py index 29a637a382..660eb31962 100644 --- a/python/paddle/v2/framework/tests/test_conv2d_op.py +++ b/python/paddle/v2/framework/tests/test_conv2d_op.py @@ -15,43 +15,53 @@ class TestConv2dOp(OpTest): filter_width = 3 stride = 1 padding = 0 + groups = 3 output_height = (input_height - filter_height + 2 * padding ) / stride + 1 output_width = (input_width - filter_width + 2 * padding) / stride + 1 input = np.random.random((batch_size, input_channels, input_height, input_width)).astype("float32") + filter = np.random.random( - (output_channels, input_channels, filter_height, + (output_channels, input_channels / groups, filter_height, filter_width)).astype("float32") output = np.ndarray( (batch_size, output_channels, output_height, output_width)) self.inputs = {'Input': input, 'Filter': filter} - self.attrs = {'strides': [1, 1], 'paddings': [0, 0]} + self.attrs = {'strides': [1, 1], 'paddings': [0, 0], 'groups': groups} + output_group_channels = output_channels / groups + input_group_channels = input_channels / groups for batchid in xrange(batch_size): - for channelid in xrange(output_channels): - for rowid in xrange(output_height): - for colid in xrange(output_width): - start_h = (rowid * stride) - padding - start_w = (colid * stride) - padding - output_value = 0.0 - for inchannelid in xrange(input_channels): - for frowid in xrange(filter_height): - for fcolid in xrange(filter_width): - input_value = 0.0 - inrowid = start_h + frowid - incolid = start_w + fcolid - if ((inrowid >= 0 and - inrowid < input_height) and - (incolid >= 0 and - incolid < input_width)): - input_value = input[batchid][ - inchannelid][inrowid][incolid] - filter_value = filter[channelid][ - inchannelid][frowid][fcolid] - output_value += input_value * filter_value - output[batchid][channelid][rowid][colid] = output_value + for group in xrange(groups): + for outchannelid in range(group * output_group_channels, + (group + 1) * output_group_channels): + for rowid in xrange(output_height): + for colid in xrange(output_width): + start_h = (rowid * stride) - padding + start_w = (colid * stride) - padding + output_value = 0.0 + for inchannelid in range( + group * input_group_channels, + (group + 1) * input_group_channels): + for frowid in xrange(filter_height): + for fcolid in xrange(filter_width): + input_value = 0.0 + inrowid = start_h + frowid + incolid = start_w + fcolid + if ((inrowid >= 0 and + inrowid < input_height) and + (incolid >= 0 and + incolid < input_width)): + input_value = input[batchid][ + inchannelid][inrowid][incolid] + filter_value = filter[outchannelid][ + inchannelid % input_group_channels][ + frowid][fcolid] + output_value += input_value * filter_value + output[batchid][outchannelid][rowid][ + colid] = output_value self.outputs = {'Output': output} From 656f775c293480f5cb00dc1983dd9d004df2b578 Mon Sep 17 00:00:00 2001 From: hedaoyuan Date: Wed, 13 Sep 2017 21:25:58 +0800 Subject: [PATCH 13/19] Fix the doc. 
--- paddle/operators/conv_op.cc | 12 +++++++----- 1 file changed, 7 insertions(+), 5 deletions(-) diff --git a/paddle/operators/conv_op.cc b/paddle/operators/conv_op.cc index 593fdc0e7e..934f153e72 100644 --- a/paddle/operators/conv_op.cc +++ b/paddle/operators/conv_op.cc @@ -69,15 +69,17 @@ class Conv2DOpMaker : public framework::OpProtoAndCheckerMaker { "Filter", "The filter tensor of convolution operator." "The format of the filter tensor is MCHW, where M is the number of " - "output " - "image channels, C is the number of input image channels, H and W is " - "height and width of filter."); + "output image channels, C is the number of input image channels, " + "H and W is height and width of filter. " + "If the groups attribute is greater than 1, C equal the number of " + "input image channels divided by the groups."); AddOutput("Output", "The output tensor of convolution operator." "The format of output tensor is also NCHW."); AddComment(R"DOC( -The convolution operation calculates the output based on -the input, filter and strides, paddings parameters. +The convolution operation calculates the output based on the input, filter +and strides, paddings, groups parameters. The size of each dimension of the +parameters is checked in the infer-shape. )DOC"); AddAttr>("strides", "strides of convolution operator."); AddAttr>("paddings", "paddings of convolution operator."); From 09c65b6d4fc3e5e9106c7b3fefc1d04c2c99596b Mon Sep 17 00:00:00 2001 From: hedaoyuan Date: Mon, 18 Sep 2017 00:02:34 +0800 Subject: [PATCH 14/19] Follow comments. --- paddle/operators/{conv_op.cc => conv2d_op.cc} | 8 ++++---- paddle/operators/{conv_op.cu => conv2d_op.cu} | 8 ++++---- paddle/operators/{gemm_conv_op.h => gemm_conv2d_op.h} | 4 ++-- 3 files changed, 10 insertions(+), 10 deletions(-) rename paddle/operators/{conv_op.cc => conv2d_op.cc} (95%) rename paddle/operators/{conv_op.cu => conv2d_op.cu} (76%) rename paddle/operators/{gemm_conv_op.h => gemm_conv2d_op.h} (98%) diff --git a/paddle/operators/conv_op.cc b/paddle/operators/conv2d_op.cc similarity index 95% rename from paddle/operators/conv_op.cc rename to paddle/operators/conv2d_op.cc index 934f153e72..b74b42546d 100644 --- a/paddle/operators/conv_op.cc +++ b/paddle/operators/conv2d_op.cc @@ -12,7 +12,7 @@ See the License for the specific language governing permissions and limitations under the License. */ -#include "paddle/operators/gemm_conv_op.h" +#include "paddle/operators/gemm_conv2d_op.h" namespace paddle { namespace operators { @@ -116,7 +116,7 @@ namespace ops = paddle::operators; REGISTER_OP(conv2d, ops::Conv2DOp, ops::Conv2DOpMaker, conv2d_grad, ops::Conv2DOpGrad); -REGISTER_OP_CPU_KERNEL(conv2d, - ops::GemmConvKernel); REGISTER_OP_CPU_KERNEL( - conv2d_grad, ops::GemmConvGradKernel); + conv2d, ops::GemmConv2dKernel); +REGISTER_OP_CPU_KERNEL( + conv2d_grad, ops::GemmConvGrad2dKernel); diff --git a/paddle/operators/conv_op.cu b/paddle/operators/conv2d_op.cu similarity index 76% rename from paddle/operators/conv_op.cu rename to paddle/operators/conv2d_op.cu index a15adecda4..7666f4c4c1 100644 --- a/paddle/operators/conv_op.cu +++ b/paddle/operators/conv2d_op.cu @@ -12,11 +12,11 @@ See the License for the specific language governing permissions and limitations under the License. 
 */

-#include "paddle/operators/gemm_conv_op.h"
+#include "paddle/operators/gemm_conv2d_op.h"

 namespace ops = paddle::operators;

-REGISTER_OP_GPU_KERNEL(conv2d,
-                       ops::GemmConvKernel<paddle::platform::GPUPlace, float>);
 REGISTER_OP_GPU_KERNEL(
-    conv2d_grad, ops::GemmConvGradKernel<paddle::platform::GPUPlace, float>);
+    conv2d, ops::GemmConv2dKernel<paddle::platform::GPUPlace, float>);
+REGISTER_OP_GPU_KERNEL(
+    conv2d_grad, ops::GemmConvGrad2dKernel<paddle::platform::GPUPlace, float>);

diff --git a/paddle/operators/gemm_conv_op.h b/paddle/operators/gemm_conv2d_op.h
similarity index 98%
rename from paddle/operators/gemm_conv_op.h
rename to paddle/operators/gemm_conv2d_op.h
index b125698c6d..71bf09bb7e 100644
--- a/paddle/operators/gemm_conv_op.h
+++ b/paddle/operators/gemm_conv2d_op.h
@@ -25,7 +25,7 @@ namespace operators {
 using Tensor = framework::Tensor;

 template <typename Place, typename T>
-class GemmConvKernel : public framework::OpKernel {
+class GemmConv2dKernel : public framework::OpKernel {
  public:
   void Compute(const framework::ExecutionContext& context) const override {
     const Tensor* input = context.Input<Tensor>("Input");
@@ -101,7 +101,7 @@ class GemmConvKernel : public framework::OpKernel {
 };

 template <typename Place, typename T>
-class GemmConvGradKernel : public framework::OpKernel {
+class GemmConvGrad2dKernel : public framework::OpKernel {
  public:
   void Compute(const framework::ExecutionContext& context) const override {
     const Tensor* input = context.Input<Tensor>("Input");

From 91afa0d877bd28535c62a361a947b669cf16ed09 Mon Sep 17 00:00:00 2001
From: hedaoyuan
Date: Mon, 18 Sep 2017 13:45:24 +0800
Subject: [PATCH 15/19] Some bug fix.

---
 paddle/operators/conv2d_op.cc     | 12 +++++++-----
 paddle/operators/conv2d_op.cu     |  4 ++--
 paddle/operators/gemm_conv2d_op.h |  4 ++--
 3 files changed, 11 insertions(+), 9 deletions(-)

diff --git a/paddle/operators/conv2d_op.cc b/paddle/operators/conv2d_op.cc
index b74b42546d..3aedab4992 100644
--- a/paddle/operators/conv2d_op.cc
+++ b/paddle/operators/conv2d_op.cc
@@ -30,7 +30,7 @@ class Conv2DOp : public framework::OperatorWithKernel {
   void InferShape(const framework::InferShapeContext &ctx) const override {
     auto in = ctx.Input<Tensor>("Input");
     auto filter = ctx.Input<Tensor>("Filter");
-    auto out = ctx.Output<Tensor>("Output");
+    auto out = ctx.Output<framework::LoDTensor>("Output");
     std::vector<int> strides = Attr<std::vector<int>>("strides");
     std::vector<int> paddings = Attr<std::vector<int>>("paddings");
     int groups = Attr<int>("groups");
@@ -102,8 +102,10 @@ class Conv2DOpGrad : public framework::OperatorWithKernel {
   void InferShape(const framework::InferShapeContext &ctx) const override {
     auto in = ctx.Input<Tensor>("Input");
     auto filter = ctx.Input<Tensor>("Filter");
-    auto d_in = ctx.Output<Tensor>(framework::GradVarName("Input"));
-    auto d_filter = ctx.Output<Tensor>(framework::GradVarName("Filter"));
+    auto d_in =
+        ctx.Output<framework::LoDTensor>(framework::GradVarName("Input"));
+    auto d_filter =
+        ctx.Output<framework::LoDTensor>(framework::GradVarName("Filter"));
     d_in->Resize(in->dims());
     d_filter->Resize(filter->dims());
   }
@@ -117,6 +119,6 @@ REGISTER_OP(conv2d, ops::Conv2DOp, ops::Conv2DOpMaker, conv2d_grad,
             ops::Conv2DOpGrad);

 REGISTER_OP_CPU_KERNEL(
-    conv2d, ops::GemmConv2dKernel<paddle::platform::CPUPlace, float>);
+    conv2d, ops::GemmConv2DKernel<paddle::platform::CPUPlace, float>);
 REGISTER_OP_CPU_KERNEL(
-    conv2d_grad, ops::GemmConvGrad2dKernel<paddle::platform::CPUPlace, float>);
+    conv2d_grad, ops::GemmConvGrad2DKernel<paddle::platform::CPUPlace, float>);

diff --git a/paddle/operators/conv2d_op.cu b/paddle/operators/conv2d_op.cu
index 7666f4c4c1..5df818ba04 100644
--- a/paddle/operators/conv2d_op.cu
+++ b/paddle/operators/conv2d_op.cu
@@ -17,6 +17,6 @@
 namespace ops = paddle::operators;

 REGISTER_OP_GPU_KERNEL(
-    conv2d, ops::GemmConv2dKernel<paddle::platform::GPUPlace, float>);
+    conv2d, ops::GemmConv2DKernel<paddle::platform::GPUPlace, float>);
 REGISTER_OP_GPU_KERNEL(
-    conv2d_grad, ops::GemmConvGrad2dKernel<paddle::platform::GPUPlace, float>);
+    conv2d_grad, ops::GemmConvGrad2DKernel<paddle::platform::GPUPlace, float>);

diff --git a/paddle/operators/gemm_conv2d_op.h b/paddle/operators/gemm_conv2d_op.h
index 71bf09bb7e..a4df7b9cb9 100644
--- a/paddle/operators/gemm_conv2d_op.h
+++ b/paddle/operators/gemm_conv2d_op.h
@@ -25,7 +25,7 @@ namespace operators {
 using Tensor = framework::Tensor;

 template <typename Place, typename T>
-class GemmConv2dKernel : public framework::OpKernel {
+class GemmConv2DKernel : public framework::OpKernel {
  public:
   void Compute(const framework::ExecutionContext& context) const override {
     const Tensor* input = context.Input<Tensor>("Input");
@@ -101,7 +101,7 @@ class GemmConv2dKernel : public framework::OpKernel {
 };

 template <typename Place, typename T>
-class GemmConvGrad2dKernel : public framework::OpKernel {
+class GemmConvGrad2DKernel : public framework::OpKernel {
  public:
   void Compute(const framework::ExecutionContext& context) const override {
     const Tensor* input = context.Input<Tensor>("Input");

From 5a4138b66b588d05d5d9c7a518fcf407f8cbf693 Mon Sep 17 00:00:00 2001
From: hedaoyuan
Date: Mon, 18 Sep 2017 13:47:34 +0800
Subject: [PATCH 16/19] Add test with groups=1.

---
 .../v2/framework/tests/test_conv2d_op.py | 24 ++++++++++++++-----
 1 file changed, 18 insertions(+), 6 deletions(-)

diff --git a/python/paddle/v2/framework/tests/test_conv2d_op.py b/python/paddle/v2/framework/tests/test_conv2d_op.py
index 660eb31962..64aeb6e8a9 100644
--- a/python/paddle/v2/framework/tests/test_conv2d_op.py
+++ b/python/paddle/v2/framework/tests/test_conv2d_op.py
@@ -5,6 +5,7 @@ from op_test import OpTest

 class TestConv2dOp(OpTest):
     def setUp(self):
+        self.init_groups()
         self.op_type = "conv2d"
         batch_size = 2
         input_channels = 3
@@ -15,7 +16,6 @@ class TestConv2dOp(OpTest):
         filter_width = 3
         stride = 1
         padding = 0
-        groups = 3
         output_height = (input_height - filter_height + 2 * padding
                          ) / stride + 1
         output_width = (input_width - filter_width + 2 * padding) / stride + 1
@@ -23,18 +23,22 @@ class TestConv2dOp(OpTest):
                                   input_width)).astype("float32")

         filter = np.random.random(
-            (output_channels, input_channels / groups, filter_height,
+            (output_channels, input_channels / self.groups, filter_height,
             filter_width)).astype("float32")
         output = np.ndarray(
             (batch_size, output_channels, output_height, output_width))

         self.inputs = {'Input': input, 'Filter': filter}
-        self.attrs = {'strides': [1, 1], 'paddings': [0, 0], 'groups': groups}
+        self.attrs = {
+            'strides': [1, 1],
+            'paddings': [0, 0],
+            'groups': self.groups
+        }

-        output_group_channels = output_channels / groups
-        input_group_channels = input_channels / groups
+        output_group_channels = output_channels / self.groups
+        input_group_channels = input_channels / self.groups
         for batchid in xrange(batch_size):
-            for group in xrange(groups):
+            for group in xrange(self.groups):
                 for outchannelid in range(group * output_group_channels,
                                           (group + 1) * output_group_channels):
                     for rowid in xrange(output_height):
@@ -71,6 +75,14 @@ class TestConv2dOp(OpTest):
     def test_check_grad(self):
         self.check_grad(set(['Input', 'Filter']), 'Output')

+    def init_groups(self):
+        self.groups = 1
+
+
+class TestWithGroup(TestConv2dOp):
+    def init_groups(self):
+        self.groups = 3
+

 if __name__ == '__main__':
     unittest.main()

From 64b0b7568511b7bc72b98098d502a48e068266d2 Mon Sep 17 00:00:00 2001
From: hedaoyuan
Date: Mon, 18 Sep 2017 19:18:00 +0800
Subject: [PATCH 17/19] Follow comments fix conv2d_op.cc

---
 paddle/operators/conv2d_op.cc | 18 ++++++++++--------
 1 file changed, 10 insertions(+), 8 deletions(-)

diff --git a/paddle/operators/conv2d_op.cc b/paddle/operators/conv2d_op.cc
index 3aedab4992..10091ec6a5 100644
--- a/paddle/operators/conv2d_op.cc
+++ b/paddle/operators/conv2d_op.cc
@@ -37,7 +37,7 @@ class Conv2DOp : public framework::OperatorWithKernel {
     int input_channels = in->dims()[1];
     int output_channels = filter->dims()[0];

-    PADDLE_ENFORCE_EQ(in->dims().size(), 4, "Conv2DOp intput should be 4-D.");
+    PADDLE_ENFORCE_EQ(in->dims().size(), 4, "Conv2DOp input should be 4-D.");
     PADDLE_ENFORCE_EQ(filter->dims().size(), 4,
                       "Conv2DOp filter should be 4-D.");
     PADDLE_ENFORCE_EQ(input_channels, filter->dims()[1] * groups,
@@ -76,13 +76,10 @@ class Conv2DOpMaker : public framework::OpProtoAndCheckerMaker {
     AddOutput("Output",
               "The output tensor of convolution operator."
              "The format of output tensor is also NCHW.");
-    AddComment(R"DOC(
-The convolution operation calculates the output based on the input, filter
-and strides, paddings, groups parameters. The size of each dimension of the
-parameters is checked in the infer-shape.
-)DOC");
-    AddAttr<std::vector<int>>("strides", "strides of convolution operator.");
-    AddAttr<std::vector<int>>("paddings", "paddings of convolution operator.");
+    AddAttr<std::vector<int>>("strides", "strides of convolution operator.")
+        .SetDefault({1, 1});
+    AddAttr<std::vector<int>>("paddings", "paddings of convolution operator.")
+        .SetDefault({0, 0});
     AddAttr<int>(
         "groups",
         "group size of convolution operator. "
         "Refer to grouped convolution in Alex Krizhevsky's Deep CNN paper: "
         "when group=2, the first half of the filters is only connected to the "
         "first half of the input channels, and the second half only connected "
         "to the second half.")
         .SetDefault(1);
+    AddComment(R"DOC(
+The convolution operation calculates the output based on the input, filter
+and strides, paddings, groups parameters. The size of each dimension of the
+parameters is checked in the infer-shape.
+)DOC");
   }
 };

From f3669ca3f18eee7c817f4b72f163734f0daaa001 Mon Sep 17 00:00:00 2001
From: hedaoyuan
Date: Mon, 18 Sep 2017 23:48:49 +0800
Subject: [PATCH 18/19] Support input_grad = null or filter_grad = null.
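
In the Python op tests this corresponds to gradient checks that exclude one
variable through no_grad_set: the excluded gradient tensor is never created,
so the kernel receives a null pointer for it and must skip the corresponding
branch instead of dereferencing it. Illustrative usage inside a TestConv2dOp
style test case (a hedged sketch mirroring the two test methods added below):

# Check the gradient w.r.t. Input only; the Filter gradient tensor is
# never allocated, so the grad kernel sees a null pointer and skips it.
self.check_grad(['Input'], 'Output', no_grad_set=set(['Filter']))
# Same with the roles swapped: the Input gradient pointer is null.
self.check_grad(['Filter'], 'Output', no_grad_set=set(['Input']))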
---
 paddle/operators/conv2d_op.cc            | 11 ++-
 paddle/operators/gemm_conv2d_op.h        | 84 ++++++++++++-------
 .../v2/framework/tests/test_conv2d_op.py |  6 ++
 3 files changed, 68 insertions(+), 33 deletions(-)

diff --git a/paddle/operators/conv2d_op.cc b/paddle/operators/conv2d_op.cc
index 10091ec6a5..12db65b5cb 100644
--- a/paddle/operators/conv2d_op.cc
+++ b/paddle/operators/conv2d_op.cc
@@ -28,6 +28,13 @@ class Conv2DOp : public framework::OperatorWithKernel {

  protected:
   void InferShape(const framework::InferShapeContext &ctx) const override {
+    PADDLE_ENFORCE_NOT_NULL(ctx.InputVar("Input"),
+                            "Input(Input) of Conv2DOp should not be null.");
+    PADDLE_ENFORCE_NOT_NULL(ctx.InputVar("Filter"),
+                            "Input(Filter) of Conv2DOp should not be null.");
+    PADDLE_ENFORCE_NOT_NULL(ctx.OutputVar("Output"),
+                            "Output(Output) of Conv2DOp should not be null.");
+
     auto in = ctx.Input<Tensor>("Input");
     auto filter = ctx.Input<Tensor>("Filter");
     auto out = ctx.Output<framework::LoDTensor>("Output");
@@ -108,8 +115,8 @@ class Conv2DOpGrad : public framework::OperatorWithKernel {
         ctx.Output<framework::LoDTensor>(framework::GradVarName("Input"));
     auto d_filter =
         ctx.Output<framework::LoDTensor>(framework::GradVarName("Filter"));
-    d_in->Resize(in->dims());
-    d_filter->Resize(filter->dims());
+    if (d_in) d_in->Resize(in->dims());
+    if (d_filter) d_filter->Resize(filter->dims());
   }
 };

diff --git a/paddle/operators/gemm_conv2d_op.h b/paddle/operators/gemm_conv2d_op.h
index a4df7b9cb9..96f4c06005 100644
--- a/paddle/operators/gemm_conv2d_op.h
+++ b/paddle/operators/gemm_conv2d_op.h
@@ -111,14 +111,16 @@ class GemmConvGrad2DKernel : public framework::OpKernel {
         context.Output<Tensor>(framework::GradVarName("Input"));
     Tensor* filter_grad_ =
         context.Output<Tensor>(framework::GradVarName("Filter"));
-    input_grad->mutable_data<T>(context.GetPlace());
-    filter_grad_->mutable_data<T>(context.GetPlace());

     // The filter and filter_grad will be reshaped in the calculations,
     // so here use an assignment operation,
     // that avoids modifying the variable in the Scope.
     Tensor filter = *context.Input<Tensor>("Filter");
-    Tensor filter_grad = *filter_grad_;
+    Tensor filter_grad;
+    if (filter_grad_) {
+      filter_grad_->mutable_data<T>(context.GetPlace());
+      filter_grad = *filter_grad_;
+    }

     std::vector<int> strides = context.Attr<std::vector<int>>("strides");
     std::vector<int> paddings = context.Attr<std::vector<int>>("paddings");
@@ -162,12 +164,20 @@ class GemmConvGrad2DKernel : public framework::OpKernel {
     framework::DDim filter_matrix_shape = {filter.dims()[0],
                                            filter.numel() / filter.dims()[0]};
     filter.Resize(filter_matrix_shape);
-    filter_grad.Resize(filter_matrix_shape);

-    auto t1 = framework::EigenVector<T>::Flatten(filter_grad);
-    t1.device(context.GetEigenDevice<Place>()) = t1.constant(static_cast<T>(0));
-    auto t2 = framework::EigenVector<T>::Flatten(*input_grad);
-    t2.device(context.GetEigenDevice<Place>()) = t2.constant(static_cast<T>(0));
+    if (filter_grad_) {
+      filter_grad.Resize(filter_matrix_shape);
+      auto t1 = framework::EigenVector<T>::Flatten(filter_grad);
+      t1.device(context.GetEigenDevice<Place>()) =
+          t1.constant(static_cast<T>(0));
+    }
+
+    if (input_grad) {
+      input_grad->mutable_data<T>(context.GetPlace());
+      auto t2 = framework::EigenVector<T>::Flatten(*input_grad);
+      t2.device(context.GetEigenDevice<Place>()) =
+          t2.constant(static_cast<T>(0));
+    }

     auto* device_context =
         const_cast<platform::DeviceContext*>(context.device_context_);
@@ -176,35 +186,47 @@ class GemmConvGrad2DKernel : public framework::OpKernel {
     // convolution backward weight operator: im2col + gemm
     int in_step = input_channels / groups;
     int out_step = output_channels / groups;
+    Tensor in_grad_batch;
+    Tensor in_batch;
     for (int i = 0; i < batch_size; i++) {
       Tensor out_grad_batch =
           output_grad->Slice<T>(i, i + 1).Resize(output_matrix_shape);
-      Tensor in_grad_batch = input_grad->Slice<T>(i, i + 1).Resize(input_shape);
-      Tensor in_batch = input->Slice<T>(i, i + 1).Resize(input_shape);
+      if (input_grad) {
+        in_grad_batch = input_grad->Slice<T>(i, i + 1).Resize(input_shape);
+      }
+      if (filter_grad_) {
+        in_batch = input->Slice<T>(i, i + 1).Resize(input_shape);
+      }
       for (int g = 0; g < groups; g++) {
-        // gemm
         Tensor out_grad_slice =
             out_grad_batch.Slice<T>(g * out_step, (g + 1) * out_step);
-        Tensor filter_slice = filter.Slice<T>(g * out_step, (g + 1) * out_step);
-        math::matmul<Place, T>(filter_slice, true, out_grad_slice, false,
-                               T(1.0), &col_matrix, T(0.0), device_context);
-
-        // col2im
-        Tensor in_grad_slice =
-            in_grad_batch.Slice<T>(g * in_step, (g + 1) * in_step);
-        col2im(in_grad_slice, col, strides[0], strides[1], paddings[0],
-               paddings[1], device_context);
-
-        // im2col
-        Tensor in_slice = in_batch.Slice<T>(g * in_step, (g + 1) * in_step);
-        im2col(in_slice, col, strides[0], strides[1], paddings[0], paddings[1],
-               device_context);
-
-        // gemm
-        Tensor filter_grad_slice =
-            filter_grad.Slice<T>(g * out_step, (g + 1) * out_step);
-        math::matmul<Place, T>(out_grad_slice, false, col_matrix, true, T(1.0),
-                               &filter_grad_slice, T(1.0), device_context);
+        if (input_grad) {
+          // gemm
+          Tensor filter_slice =
+              filter.Slice<T>(g * out_step, (g + 1) * out_step);
+          math::matmul<Place, T>(filter_slice, true, out_grad_slice, false,
+                                 T(1.0), &col_matrix, T(0.0), device_context);
+
+          // col2im
+          Tensor in_grad_slice =
+              in_grad_batch.Slice<T>(g * in_step, (g + 1) * in_step);
+          col2im(in_grad_slice, col, strides[0], strides[1], paddings[0],
+                 paddings[1], device_context);
+        }
+
+        if (filter_grad_) {
+          // im2col
+          Tensor in_slice = in_batch.Slice<T>(g * in_step, (g + 1) * in_step);
+          im2col(in_slice, col, strides[0], strides[1], paddings[0],
+                 paddings[1], device_context);
+
+          // gemm
+          Tensor filter_grad_slice =
+              filter_grad.Slice<T>(g * out_step, (g + 1) * out_step);
+          math::matmul<Place, T>(out_grad_slice, false, col_matrix, true,
+                                 T(1.0), &filter_grad_slice, T(1.0),
+                                 device_context);
+        }
       }
     }
   }

diff --git a/python/paddle/v2/framework/tests/test_conv2d_op.py b/python/paddle/v2/framework/tests/test_conv2d_op.py
index 64aeb6e8a9..3142a60a1a 100644
--- a/python/paddle/v2/framework/tests/test_conv2d_op.py
+++ b/python/paddle/v2/framework/tests/test_conv2d_op.py
@@ -75,6 +75,12 @@ class TestConv2dOp(OpTest):
     def test_check_grad(self):
         self.check_grad(set(['Input', 'Filter']), 'Output')

+    def test_check_grad_no_filter(self):
+        self.check_grad(['Input'], 'Output', no_grad_set=set(['Filter']))
+
+    def test_check_grad_no_input(self):
+        self.check_grad(['Filter'], 'Output', no_grad_set=set(['Input']))
+
     def init_groups(self):
         self.groups = 1

From 6c0129af951d3b209300d3635b5cb934f03ab3bb Mon Sep 17 00:00:00 2001
From: hedaoyuan
Date: Tue, 19 Sep 2017 11:15:29 +0800
Subject: [PATCH 19/19] Refine the GemmConvGrad2DKernel.

---
 paddle/operators/gemm_conv2d_op.h | 69 ++++++++++++-----------------
 1 file changed, 32 insertions(+), 37 deletions(-)

diff --git a/paddle/operators/gemm_conv2d_op.h b/paddle/operators/gemm_conv2d_op.h
index 96f4c06005..08b7df1dfe 100644
--- a/paddle/operators/gemm_conv2d_op.h
+++ b/paddle/operators/gemm_conv2d_op.h
@@ -109,18 +109,13 @@ class GemmConvGrad2DKernel : public framework::OpKernel {
         context.Input<Tensor>(framework::GradVarName("Output"));
     Tensor* input_grad =
         context.Output<Tensor>(framework::GradVarName("Input"));
-    Tensor* filter_grad_ =
+    Tensor* filter_grad =
         context.Output<Tensor>(framework::GradVarName("Filter"));

     // The filter and filter_grad will be reshaped in the calculations,
     // so here use an assignment operation,
     // that avoids modifying the variable in the Scope.
     Tensor filter = *context.Input<Tensor>("Filter");
-    Tensor filter_grad;
-    if (filter_grad_) {
-      filter_grad_->mutable_data<T>(context.GetPlace());
-      filter_grad = *filter_grad_;
-    }

     std::vector<int> strides = context.Attr<std::vector<int>>("strides");
     std::vector<int> paddings = context.Attr<std::vector<int>>("paddings");
@@ -165,20 +160,6 @@ class GemmConvGrad2DKernel : public framework::OpKernel {
                                            filter.numel() / filter.dims()[0]};
     filter.Resize(filter_matrix_shape);

-    if (filter_grad_) {
-      filter_grad.Resize(filter_matrix_shape);
-      auto t1 = framework::EigenVector<T>::Flatten(filter_grad);
-      t1.device(context.GetEigenDevice<Place>()) =
-          t1.constant(static_cast<T>(0));
-    }
-
-    if (input_grad) {
-      input_grad->mutable_data<T>(context.GetPlace());
-      auto t2 = framework::EigenVector<T>::Flatten(*input_grad);
-      t2.device(context.GetEigenDevice<Place>()) =
-          t2.constant(static_cast<T>(0));
-    }
-
     auto* device_context =
         const_cast<platform::DeviceContext*>(context.device_context_);
@@ -186,22 +167,21 @@ class GemmConvGrad2DKernel : public framework::OpKernel {
     // convolution backward weight operator: im2col + gemm
     int in_step = input_channels / groups;
     int out_step = output_channels / groups;
-    Tensor in_grad_batch;
-    Tensor in_batch;
-    for (int i = 0; i < batch_size; i++) {
-      Tensor out_grad_batch =
-          output_grad->Slice<T>(i, i + 1).Resize(output_matrix_shape);
-      if (input_grad) {
-        in_grad_batch = input_grad->Slice<T>(i, i + 1).Resize(input_shape);
-      }
-      if (filter_grad_) {
-        in_batch = input->Slice<T>(i, i + 1).Resize(input_shape);
-      }
-      for (int g = 0; g < groups; g++) {
-        Tensor out_grad_slice =
-            out_grad_batch.Slice<T>(g * out_step, (g + 1) * out_step);
-        if (input_grad) {
+
+    if (input_grad) {
+      input_grad->mutable_data<T>(context.GetPlace());
+      auto t = framework::EigenVector<T>::Flatten(*input_grad);
+      t.device(context.GetEigenDevice<Place>()) = t.constant(static_cast<T>(0));
+
+      for (int i = 0; i < batch_size; i++) {
+        Tensor out_grad_batch =
+            output_grad->Slice<T>(i, i + 1).Resize(output_matrix_shape);
+        Tensor in_grad_batch =
+            input_grad->Slice<T>(i, i + 1).Resize(input_shape);
+        for (int g = 0; g < groups; g++) {
           // gemm
+          Tensor out_grad_slice =
+              out_grad_batch.Slice<T>(g * out_step, (g + 1) * out_step);
           Tensor filter_slice =
               filter.Slice<T>(g * out_step, (g + 1) * out_step);
           math::matmul<Place, T>(filter_slice, true, out_grad_slice, false,
@@ -213,16 +193,31 @@ class GemmConvGrad2DKernel : public framework::OpKernel {
                                  T(1.0), &col_matrix, T(0.0), device_context);

           // col2im
           Tensor in_grad_slice =
               in_grad_batch.Slice<T>(g * in_step, (g + 1) * in_step);
           col2im(in_grad_slice, col, strides[0], strides[1], paddings[0],
                  paddings[1], device_context);
         }
+      }
+    }

-        if (filter_grad_) {
+    if (filter_grad) {
+      filter_grad->mutable_data<T>(context.GetPlace());
+      Tensor filter_grad_ = *filter_grad;
+      filter_grad_.Resize(filter_matrix_shape);
+      auto t = framework::EigenVector<T>::Flatten(filter_grad_);
+      t.device(context.GetEigenDevice<Place>()) = t.constant(static_cast<T>(0));
+
+      for (int i = 0; i < batch_size; i++) {
+        Tensor out_grad_batch =
+            output_grad->Slice<T>(i, i + 1).Resize(output_matrix_shape);
+        Tensor in_batch = input->Slice<T>(i, i + 1).Resize(input_shape);
+        for (int g = 0; g < groups; g++) {
           // im2col
+          Tensor out_grad_slice =
+              out_grad_batch.Slice<T>(g * out_step, (g + 1) * out_step);
           Tensor in_slice = in_batch.Slice<T>(g * in_step, (g + 1) * in_step);
           im2col(in_slice, col, strides[0], strides[1], paddings[0],
                  paddings[1], device_context);

           // gemm
           Tensor filter_grad_slice =
-              filter_grad.Slice<T>(g * out_step, (g + 1) * out_step);
+              filter_grad_.Slice<T>(g * out_step, (g + 1) * out_step);
           math::matmul<Place, T>(out_grad_slice, false, col_matrix, true,
                                  T(1.0), &filter_grad_slice, T(1.0),
                                  device_context);
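
Taken together, the kernels in this series implement convolution as im2col
followed by one GEMM per group (the filter is reshaped to filter_matrix_shape,
that is M x (C / groups * KH * KW)), and the backward pass reuses the same two
primitives, with col2im turning gradient columns back into an image. A minimal
NumPy sketch of the forward strategy under simplifying assumptions (groups ==
1, a single image, and hypothetical helper names im2col_nchw and
gemm_conv2d_forward, which are not the C++ functions above):

import numpy as np

def im2col_nchw(x, kh, kw, stride=1, padding=0):
    # Unfold one (C, H, W) image into a (C * KH * KW, OH * OW) matrix,
    # the same column layout the kernel's im2col step produces.
    c, h, w = x.shape
    oh = (h - kh + 2 * padding) // stride + 1
    ow = (w - kw + 2 * padding) // stride + 1
    xp = np.pad(x, ((0, 0), (padding, padding), (padding, padding)),
                mode='constant')
    cols = np.empty((c * kh * kw, oh * ow), dtype=x.dtype)
    for i in range(oh):
        for j in range(ow):
            patch = xp[:, i * stride:i * stride + kh,
                       j * stride:j * stride + kw]
            cols[:, i * ow + j] = patch.reshape(-1)
    return cols, oh, ow

def gemm_conv2d_forward(x, w, stride=1, padding=0):
    # x: (C, H, W), w: (M, C, KH, KW). Reshaping the filter to
    # (M, C * KH * KW) lets a single GEMM produce all M output channels.
    m, _, kh, kw = w.shape
    cols, oh, ow = im2col_nchw(x, kh, kw, stride, padding)
    out = np.dot(w.reshape(m, -1), cols)  # the GEMM
    return out.reshape(m, oh, ow)

The gradient kernel refined above runs the same primitives in reverse order:
a GEMM of the reshaped filter against the output gradient followed by col2im
produces the input gradient, and im2col of the input followed by a GEMM
against the output gradient accumulates the filter gradient. That symmetry is
what lets the refactor split the two gradients into independent loops, each
skippable when its output pointer is null.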