|
|
|
@ -19,20 +19,36 @@ limitations under the License. */
|
|
|
|
|
namespace paddle {
|
|
|
|
|
|
|
|
|
|
/*
|
|
|
|
|
* Function Arguments:
|
|
|
|
|
* \brief Based on the ConvFunctionBase class, the forward calculation,
|
|
|
|
|
* backward input calculation and backward filter calculation
|
|
|
|
|
* of convolution operations can be implemented.
|
|
|
|
|
*
|
|
|
|
|
* \param inputs[0] Input image data, is NCHW format, where N is batch size,
|
|
|
|
|
* C is the number of channels, H and W is the height and
|
|
|
|
|
* width of input image.
|
|
|
|
|
* \param inputs[1] Filter data, is MCHW, where M is the number of output
|
|
|
|
|
* channels, C is the number of input channels, H and W
|
|
|
|
|
* is height and width of filter.
|
|
|
|
|
* \param outputs[0] Output image data, is NCHW format, where N is batch size,
|
|
|
|
|
* C is the number of channels, H and W is the height and
|
|
|
|
|
* width of output image.
|
|
|
|
|
* Arguments of forward and backward calculation:
|
|
|
|
|
* 1. Forward calculation of convolution.
|
|
|
|
|
* inputs = {INPUT, FILTER}, outputs = {OUTPUT}
|
|
|
|
|
* The first and second input arguments are input image and filter data.
|
|
|
|
|
* The output argument is output image.
|
|
|
|
|
*
|
|
|
|
|
* \note Implemented based on the ConvFunctionBase class only supports
|
|
|
|
|
* input data in the NCHW format.
|
|
|
|
|
* 2. Backward input calculation of convolution.
|
|
|
|
|
* inputs = {OUTPUT_GRAD, FILTER}, outputs = {INPUT_GRAD}
|
|
|
|
|
* The first and second input arguments are output grad image
|
|
|
|
|
* and filter data.
|
|
|
|
|
* The output argument is input grad image.
|
|
|
|
|
*
|
|
|
|
|
* 3. Backward filter calculation of convolution.
|
|
|
|
|
* inputs = {OUTPUT_GRAD, INPUT}, outputs = {FILTER_GRAD}
|
|
|
|
|
* The first and second input arguments are output grad image
|
|
|
|
|
* and input image.
|
|
|
|
|
* The output argument is filter grad.
|
|
|
|
|
*
|
|
|
|
|
* Arguments format of input, filter and output:
|
|
|
|
|
* 1. Input image, output image, input image gradient, output image gradient
|
|
|
|
|
* are all NCHW format. Where N is batch size, C is the number of channels,
|
|
|
|
|
* H and W is the height and width of image or image gradient.
|
|
|
|
|
*
|
|
|
|
|
* 2. The format of the filter data is MCHW, where M is the number of
|
|
|
|
|
* output image channels, C is the number of input image channels,
|
|
|
|
|
* H and W is height and width of filter.
|
|
|
|
|
*/
|
|
|
|
|
class ConvFunctionBase : public FunctionBase {
|
|
|
|
|
public:
|
|
|
|
@ -49,17 +65,25 @@ public:
|
|
|
|
|
|
|
|
|
|
virtual void calc(const BufferArgs& inputs, const BufferArgs& outputs) {}
|
|
|
|
|
|
|
|
|
|
void check(const BufferArgs& inputs, const BufferArgs& outputs) override {
|
|
|
|
|
CHECK_EQ(numInputs_, inputs.size());
|
|
|
|
|
CHECK_EQ(numOutputs_, outputs.size());
|
|
|
|
|
|
|
|
|
|
CHECK_EQ(inputs[0].shape().ndims(), (size_t)4);
|
|
|
|
|
CHECK_EQ(inputs[1].shape().ndims(), (size_t)4);
|
|
|
|
|
CHECK_EQ(outputs[0].shape().ndims(), (size_t)4);
|
|
|
|
|
|
|
|
|
|
CHECK(inputs[0].shape()[0] == outputs[0].shape()[0]);
|
|
|
|
|
CHECK(inputs[0].shape()[1] / groups_ == inputs[1].shape()[1]);
|
|
|
|
|
CHECK(outputs[0].shape()[1] == inputs[1].shape()[0]);
|
|
|
|
|
// input can be INPUT and INPUT_GRAD
|
|
|
|
|
// filter can be FILTER and FILTER_GRAD
|
|
|
|
|
// output can be OUTPUT and OUTPUT_GRAD
|
|
|
|
|
void check(const TensorShape& input,
|
|
|
|
|
const TensorShape& filter,
|
|
|
|
|
const TensorShape& output) {
|
|
|
|
|
// inputs and outputs arguments should be 4-dimensional.
|
|
|
|
|
CHECK_EQ(input.ndims(), (size_t)4);
|
|
|
|
|
CHECK_EQ(filter.ndims(), (size_t)4);
|
|
|
|
|
CHECK_EQ(output.ndims(), (size_t)4);
|
|
|
|
|
|
|
|
|
|
// The batchSize of the input needs to be equal to
|
|
|
|
|
// the batchSize of the output.
|
|
|
|
|
CHECK_EQ(input[0], output[0]);
|
|
|
|
|
|
|
|
|
|
// The input and output channel dimensions are the second and first
|
|
|
|
|
// dimensions of the filter shape.
|
|
|
|
|
CHECK_EQ(input[1] / groups_, filter[1]);
|
|
|
|
|
CHECK_EQ(output[1], filter[0]);
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
protected:
|
|
|
|
|