Modify the arguments description of ConvFunctionBase, and add the definitions of the backward input and backward filter functions.

gangliao-patch-1
hedaoyuan 8 years ago
parent 3408b4b2f4
commit afbe556e56

@ -19,20 +19,36 @@ limitations under the License. */
namespace paddle { namespace paddle {
/*
 * \brief Based on the ConvFunctionBase class, the forward calculation,
 *        backward input calculation and backward filter calculation
 *        of convolution operations can be implemented.
 *
 * Arguments of forward and backward calculation:
 * 1. Forward calculation of convolution.
 *    inputs = {INPUT, FILTER}, outputs = {OUTPUT}
 *    The first and second input arguments are input image and filter data.
 *    The output argument is output image.
 *
 * 2. Backward input calculation of convolution.
 *    inputs = {OUTPUT_GRAD, FILTER}, outputs = {INPUT_GRAD}
 *    The first and second input arguments are output grad image
 *    and filter data.
 *    The output argument is input grad image.
 *
 * 3. Backward filter calculation of convolution.
 *    inputs = {OUTPUT_GRAD, INPUT}, outputs = {FILTER_GRAD}
 *    The first and second input arguments are output grad image
 *    and input image.
 *    The output argument is filter grad.
 *
 * Arguments format of input, filter and output:
 * 1. Input image, output image, input image gradient, output image gradient
 *    are all NCHW format. Where N is batch size, C is the number of channels,
 *    H and W is the height and width of image or image gradient.
 *
 * 2. The format of the filter data is MCHW, where M is the number of
 *    output image channels, C is the number of input image channels,
 *    H and W is height and width of filter.
 */
class ConvFunctionBase : public FunctionBase { class ConvFunctionBase : public FunctionBase {
public: public:
@ -49,17 +65,25 @@ public:
virtual void calc(const BufferArgs& inputs, const BufferArgs& outputs) {} virtual void calc(const BufferArgs& inputs, const BufferArgs& outputs) {}
void check(const BufferArgs& inputs, const BufferArgs& outputs) override { // input can be INPUT and INPUT_GRAD
CHECK_EQ(numInputs_, inputs.size()); // filter can be FILTER and FILTER_GRAD
CHECK_EQ(numOutputs_, outputs.size()); // output can be OUTPUT and OUTPUT_GRAD
void check(const TensorShape& input,
CHECK_EQ(inputs[0].shape().ndims(), (size_t)4); const TensorShape& filter,
CHECK_EQ(inputs[1].shape().ndims(), (size_t)4); const TensorShape& output) {
CHECK_EQ(outputs[0].shape().ndims(), (size_t)4); // inputs and outputs arguments should be 4-dimensional.
CHECK_EQ(input.ndims(), (size_t)4);
CHECK(inputs[0].shape()[0] == outputs[0].shape()[0]); CHECK_EQ(filter.ndims(), (size_t)4);
CHECK(inputs[0].shape()[1] / groups_ == inputs[1].shape()[1]); CHECK_EQ(output.ndims(), (size_t)4);
CHECK(outputs[0].shape()[1] == inputs[1].shape()[0]);
// The batchSize of the input needs to be equal to
// the batchSize of the output.
CHECK_EQ(input[0], output[0]);
// The input and output channel dimensions are the second and first
// dimensions of the filter shape.
CHECK_EQ(input[1] / groups_, filter[1]);
CHECK_EQ(output[1], filter[0]);
} }
protected: protected:

@ -68,17 +68,7 @@ public:
}; };
/*
 * \brief Forward calculation of convolution.
 */
template <DeviceType Device> template <DeviceType Device>
class GemmConvFunction : public ConvFunctionBase { class GemmConvFunction : public ConvFunctionBase {
@ -88,8 +78,21 @@ public:
} }
void calc(const BufferArgs& inputs, const BufferArgs& outputs) override { void calc(const BufferArgs& inputs, const BufferArgs& outputs) override {
check(inputs, outputs); CHECK_EQ(numInputs_, inputs.size());
CHECK_EQ(outputs[0].getArgType(), ASSIGN_TO); CHECK_EQ(numOutputs_, outputs.size());
// TODO(hedaoyuan): Need to define some index macros,
// to avoid using 0 and 1.
const TensorShape& input = inputs[0].shape();
const TensorShape& filter = inputs[1].shape();
const TensorShape& output = outputs[0].shape();
check(input, filter, output);
real beta;
if (outputs[0].getArgType() == ADD_TO) {
beta = 1.0;
} else {
beta = 0.0;
}
size_t batchSize = inputs[0].shape()[0]; size_t batchSize = inputs[0].shape()[0];
size_t inputChannels = inputs[0].shape()[1]; size_t inputChannels = inputs[0].shape()[1];
@ -143,7 +146,7 @@ public:
K, K,
colData, colData,
N, N,
0.0f, beta,
outputData + g * outputOffset, outputData + g * outputOffset,
N); N);
} }
@ -166,9 +169,53 @@ private:
MemoryHandlePtr memory_; MemoryHandlePtr memory_;
}; };
/*
 * \brief Backward input calculation of convolution.
 */
template <DeviceType Device>
class GemmConvGradInputFunction : public ConvFunctionBase {
public:
  void init(const FuncConfig& config) override {
    ConvFunctionBase::init(config);
  }

  // inputs = {OUTPUT_GRAD, FILTER}, outputs = {INPUT_GRAD}.
  // NOTE(review): only validates argument counts and shapes so far; the
  // actual gradient computation is not present in this block.
  void calc(const BufferArgs& inputs, const BufferArgs& outputs) override {
    CHECK_EQ(numInputs_, inputs.size());
    CHECK_EQ(numOutputs_, outputs.size());
    // The input gradient occupies the "input" slot and the output
    // gradient occupies the "output" slot of the shared shape check.
    check(outputs[0].shape(), inputs[1].shape(), inputs[0].shape());
  }
};
/*
 * \brief Backward filter calculation of convolution.
 */
template <DeviceType Device>
class GemmConvGradFilterFunction : public ConvFunctionBase {
public:
  void init(const FuncConfig& config) override {
    ConvFunctionBase::init(config);
  }

  // inputs = {OUTPUT_GRAD, INPUT}, outputs = {FILTER_GRAD}.
  // NOTE(review): only validates argument counts and shapes so far; the
  // actual gradient computation is not present in this block.
  void calc(const BufferArgs& inputs, const BufferArgs& outputs) override {
    CHECK_EQ(numInputs_, inputs.size());
    CHECK_EQ(numOutputs_, outputs.size());
    const TensorShape& gradOfOutput = inputs[0].shape();
    const TensorShape& image = inputs[1].shape();
    const TensorShape& gradOfFilter = outputs[0].shape();
    // The filter gradient occupies the "filter" slot of the shared check.
    check(image, gradOfFilter, gradOfOutput);
  }
};
REGISTER_TYPED_FUNC(GemmConv, CPU, GemmConvFunction); REGISTER_TYPED_FUNC(GemmConv, CPU, GemmConvFunction);
REGISTER_TYPED_FUNC(GemmConvGradInput, CPU, GemmConvGradInputFunction);
REGISTER_TYPED_FUNC(GemmConvGradFilter, CPU, GemmConvGradFilterFunction);
#ifndef PADDLE_ONLY_CPU #ifndef PADDLE_ONLY_CPU
REGISTER_TYPED_FUNC(GemmConv, GPU, GemmConvFunction); REGISTER_TYPED_FUNC(GemmConv, GPU, GemmConvFunction);
REGISTER_TYPED_FUNC(GemmConvGradInput, GPU, GemmConvGradInputFunction);
REGISTER_TYPED_FUNC(GemmConvGradFilter, GPU, GemmConvGradFilterFunction);
#endif #endif
} // namespace paddle } // namespace paddle

@ -91,7 +91,12 @@ public:
} }
void calc(const BufferArgs& inputs, const BufferArgs& outputs) override { void calc(const BufferArgs& inputs, const BufferArgs& outputs) override {
check(inputs, outputs); CHECK_EQ(numInputs_, inputs.size());
CHECK_EQ(numOutputs_, outputs.size());
const TensorShape& input = inputs[0].shape();
const TensorShape& filter = inputs[1].shape();
const TensorShape& output = outputs[0].shape();
check(input, filter, output);
CHECK_EQ(outputs[0].getArgType(), ASSIGN_TO); CHECK_EQ(outputs[0].getArgType(), ASSIGN_TO);
size_t batchSize = inputs[0].shape()[0]; size_t batchSize = inputs[0].shape()[0];

Loading…
Cancel
Save