diff --git a/paddle/function/PadOp.cpp b/paddle/function/PadOp.cpp index 75e64a8ee4..df44fd0fa6 100644 --- a/paddle/function/PadOp.cpp +++ b/paddle/function/PadOp.cpp @@ -24,20 +24,19 @@ void Pad(real* outputs, const int inC, const int inH, const int inW, - const int padc0, - const int padc1, - const int padh0, - const int padh1, - const int padw0, - const int padw1) { - int outC = inC + padc0 + padc1; - int outH = inH + padh0 + padh1; - int outW = inW + padw0 + padw1; + const PadConf& pad) { + int cstart = pad.channelStart, cend = pad.channelEnd; + int hstart = pad.heightStart, hend = pad.heightEnd; + int wstart = pad.widthStart, wend = pad.widthEnd; + int outC = inC + cstart + cend; + int outH = inH + hstart + hend; + int outW = inW + wstart + wend; for (int i = 0; i < num; i++) { for (int c = 0; c < inC; c++) { for (int h = 0; h < inH; h++) { int inoff = ((i * inC + c) * inH + h) * inW; - int outoff = ((i * outC + c + padc0) * outH + h + padh0) * outW + padw0; + int outoff = + ((i * outC + c + cstart) * outH + h + hstart) * outW + wstart; memcpy(outputs + outoff, inputs + inoff, inW * sizeof(real)); } } @@ -51,20 +50,19 @@ void PadGrad(real* inGrad, const int inC, const int inH, const int inW, - const int padc0, - const int padc1, - const int padh0, - const int padh1, - const int padw0, - const int padw1) { - int outC = inC + padc0 + padc1; - int outH = inH + padh0 + padh1; - int outW = inW + padw0 + padw1; + const PadConf& pad) { + int cstart = pad.channelStart, cend = pad.channelEnd; + int hstart = pad.heightStart, hend = pad.heightEnd; + int wstart = pad.widthStart, wend = pad.widthEnd; + int outC = inC + cstart + cend; + int outH = inH + hstart + hend; + int outW = inW + wstart + wend; for (int i = 0; i < num; i++) { for (int c = 0; c < inC; c++) { for (int h = 0; h < inH; h++) { int inoff = ((i * inC + c) * inH + h) * inW; - int outoff = ((i * outC + c + padc0) * outH + h + padh0) * outW + padw0; + int outoff = + ((i * outC + c + cstart) * outH + h + hstart) * outW + wstart; CpuVector inG = CpuVector(inW, inGrad + inoff); CpuVector outG = CpuVector(inW, const_cast(outGrad + outoff)); inG += outG; @@ -73,22 +71,71 @@ void PadGrad(real* inGrad, } } +/** + * \brief Padding zeros to input according to the specify dimension. + * The struct pad_ contains the padding size in each dimension. + * The input and output is a 4D tensor. In PadFunc, we only + * pad zeros to the 2nd to 4th dimension. + * + * Argument in this Function: + * \param pad_ A struct object contains the padding size in each dimension. + * It has six integers. The channelStart and channelEnd indicates + * how many zeros to add before and after the input in channel + * dimension. And the heightStart and heightEnd indicates padding + * in height dimension. The widthStart and widthEnd indicates the + * padding in width dimension. + * \param inputs A 4D tensor, only one input. + * \param outputs A 4D tensor, the output value after padding. + * + * For example, + * Input(2,2,2,3) = [ + * [ [[1,2,3], [3,4,5]], + * [[2,3,5], [1,6,7]] ], + * [ [[4,3,1], [1,8,7]], + * [[3,8,9], [2,3,5]] ] + * ] # the shape is (1,2,2,3) + * + * pad_: if channelStart = channelEnd = 1, others are 0. + * Output(2,4,2,3) = [ + * [ [[0,0,0], [0,0,0]], + * [[1,2,3], [3,4,5]], + * [[2,3,5], [1,6,7]], + * [[0,0,0], [0,0,0]] ], + * [ [[0,0,0], [0,0,0]], + * [[4,3,1], [1,8,7]], + * [[3,8,9], [2,3,5]], + * [[0,0,0], [0,0,0]] ] + * ] # the shape is (2,4,2,3) + * + * pad_: if widthStart = 1, widthEnd = 2, others are 0. + * Output(2,2,2,6) = [ + * [ [[0,1,2,3,0,0], [0,3,4,5,0,0]], + * [[0,2,3,5,0,0], [0,1,6,7,0,0]] ], + * [ [[0,4,3,1,0,0], [0,1,8,7,0,0]], + * [[0,3,8,9,0,0], [0,2,3,5,0,0]] ], + * ] # the shape is (2,2,2,6) + * + * pad_: if heightStart = 1, heightEnd = 1, others are 0. + * Output(2,2,4,3) = [ + * [ [[0,0,0], [1,2,3], [3,4,5], [0,0,0]], + * [[0,0,0], [2,3,5], [1,6,7], [0,0,0]] ], + * [ [[0,0,0], [4,3,1], [1,8,7], [0,0,0]], + * [[0,0,0], [3,8,9], [2,3,5], [0,0,0]] ], + * ] # the shape is (2,2,4,3) + */ + template class PadFunc : public FunctionBase { public: void init(const FuncConfig& config) override { - padc0_ = config.get("padc0"); - padc1_ = config.get("padc1"); - padh0_ = config.get("padh0"); - padh1_ = config.get("padh1"); - padw0_ = config.get("padw0"); - padw1_ = config.get("padw1"); + pad_.channelStart = config.get("cstart"); + pad_.channelEnd = config.get("cend"); + pad_.heightStart = config.get("hstart"); + pad_.heightEnd = config.get("hend"); + pad_.widthStart = config.get("wstart"); + pad_.widthEnd = config.get("wend"); } - /** - * \param inputs[0] input value. - * \param outputs[0] output value. - */ void calc(const BufferArgs& inputs, const BufferArgs& outputs) override { CHECK_EQ(1UL, inputs.size()); CHECK_EQ(1UL, outputs.size()); @@ -108,39 +155,35 @@ public: inC, inH, inW, - padc0_, - padc1_, - padh0_, - padh1_, - padw0_, - padw1_); + pad_); } private: - int padc0_; - int padc1_; - int padh0_; - int padh1_; - int padw0_; - int padw1_; + PadConf pad_; }; +/** + * \brief The backward propagation of padding Function. Remove the elements + * in the padding positions of forward. + * + * Argument in this Function: + * \param pad_ The same meaning as it in PadFunc. + * \param inputs The gradient with respect to the output value of PadFunc. + * \param outputs The gradient with respect to the input value of PadFunc. + */ + template class PadGradFunc : public FunctionBase { public: void init(const FuncConfig& config) override { - padc0_ = config.get("padc0"); - padc1_ = config.get("padc1"); - padh0_ = config.get("padh0"); - padh1_ = config.get("padh1"); - padw0_ = config.get("padw0"); - padw1_ = config.get("padw1"); + pad_.channelStart = config.get("cstart"); + pad_.channelEnd = config.get("cend"); + pad_.heightStart = config.get("hstart"); + pad_.heightEnd = config.get("hend"); + pad_.widthStart = config.get("wstart"); + pad_.widthEnd = config.get("wend"); } - /** - * \param inputs[0] output grad. - * \param inouts[0] input grad. - */ void calc(const BufferArgs& inputs, const BufferArgs& outputs) override { CHECK_EQ(1UL, inputs.size()); CHECK_EQ(1UL, outputs.size()); @@ -163,21 +206,11 @@ public: inC, inH, inW, - padc0_, - padc1_, - padh0_, - padh1_, - padw0_, - padw1_); + pad_); } private: - int padc0_; - int padc1_; - int padh0_; - int padh1_; - int padw0_; - int padw1_; + PadConf pad_; }; REGISTER_TYPED_FUNC(Pad, CPU, PadFunc); diff --git a/paddle/function/PadOp.h b/paddle/function/PadOp.h index 4a5e8fe338..7b5c730a6a 100644 --- a/paddle/function/PadOp.h +++ b/paddle/function/PadOp.h @@ -18,29 +18,34 @@ limitations under the License. */ namespace paddle { +struct PadConf { + /// how many values to add before the data along channel dimension. + int channelStart; + /// how many values to add after the data along channel dimension. + int channelEnd; + /// how many values to add before the data along height dimension. + int heightStart; + /// how many values to add after the data along height dimension. + int heightEnd; + /// how many values to add before the data along width dimension. + int widthStart; + /// how many values to add after the data along width dimension. + int widthEnd; +}; + /** * \brief This funtion pads zeros to inputs according to the specify dimension. - * The data structure of image data is NCHW. - * - * \param[out] outputs save results. - * \param[in] inputs input data. - * \param[in] num batch size of input data. - * \param[in] inC channel number of input data. - * \param[in] inH height of input data. - * \param[in] inH with of input data. - * \param[in] padc0 how many values to add before the data in dimension of - * channel. - * \param[in] padc1 how many values to add after the data in dimension of - * channel. - * \param[in] padh0 how many values to add before the data in dimension of - * height. - * \param[in] padh1 how many values to add after the data in dimension of - * height. - * \param[in] padw0 how many values to add before the data in dimension of - * width. - * \param[in] padw1 how many values to add after the data in dimension of - * width. + * The input and output is a 4D tensor. Padding zeros from the 2nd to + * the 4th dimenstion according argument of pad. * + * \param[out] outputs save results. + * \param[in] inputs input data. + * \param[in] num batch size of input data. + * \param[in] inC channel number of input data. + * \param[in] inH height of input data. + * \param[in] inH with of input data. + * \param[in] pad the padding config, contains the size along the + * specify dimension. */ template void Pad(real* outputs, @@ -49,36 +54,19 @@ void Pad(real* outputs, const int inC, const int inH, const int inW, - const int padc0, - const int padc1, - const int padh0, - const int padh1, - const int padw0, - const int padw1); + const PadConf& pad); /** * \brief Padding operation backward. - * The data structure of image data is NCHW. - * - * \param[out] inGrad gradients of previous layer. - * \param[in] outGrad output gradients. - * \param[in] num batch size of input data. - * \param[in] inC channel number of input data. - * \param[in] inH height of input data. - * \param[in] inH with of input data. - * \param[in] padc0 how many values to add before the data in dimension of - * channel. - * \param[in] padc1 how many values to add after the data in dimension of - * channel. - * \param[in] padh0 how many values to add before the data in dimension of - * height. - * \param[in] padh1 how many values to add after the data in dimension of - * height. - * \param[in] padw0 how many values to add before the data in dimension of - * width. - * \param[in] padw1 how many values to add after the data in dimension of - * width. * + * \param[out] inGrad gradients of previous layer. + * \param[in] outGrad output gradients. + * \param[in] num batch size of input data. + * \param[in] inC channel number of input data. + * \param[in] inH height of input data. + * \param[in] inH with of input data. + * \param[in] pad the padding config, contains the size along the + * specify dimension. */ template void PadGrad(real* inGrad, @@ -87,10 +75,5 @@ void PadGrad(real* inGrad, const int inC, const int inH, const int inW, - const int padc0, - const int padc1, - const int padh0, - const int padh1, - const int padw0, - const int padw1); + const PadConf& pad); } // namespace paddle diff --git a/paddle/function/PadOpGpu.cu b/paddle/function/PadOpGpu.cu index 578d6e86d7..9104b1aca5 100644 --- a/paddle/function/PadOpGpu.cu +++ b/paddle/function/PadOpGpu.cu @@ -40,20 +40,18 @@ void Pad(real* outputs, const int inC, const int inH, const int inW, - const int padc0, - const int padc1, - const int padh0, - const int padh1, - const int padw0, - const int padw1) { + const PadConf& pad) { size_t nth = num * inC * inH * inW; int blockSize = 1024; int gridSize = (nth + 1024 - 1) / 1024; - int outC = inC + padc0 + padc1; - int outH = inH + padh0 + padh1; - int outW = inW + padw0 + padw1; + int cstart = pad.channelStart, cend = pad.channelEnd; + int hstart = pad.heightStart, hend = pad.heightEnd; + int wstart = pad.widthStart, wend = pad.widthEnd; + int outC = inC + cstart + cend; + int outH = inH + hstart + hend; + int outW = inW + wstart + wend; KePad<<>> - (outputs, inputs, inC, inH, inW, padc0, padh0, padw0, + (outputs, inputs, inC, inH, inW, cstart, hstart, wstart, outC, outH, outW, nth); CHECK_SYNC("Pad"); } @@ -81,20 +79,18 @@ void PadGrad(real* inGrad, const int inC, const int inH, const int inW, - const int padc0, - const int padc1, - const int padh0, - const int padh1, - const int padw0, - const int padw1) { + const PadConf& pad) { int nth = num * inC * inH * inW; int blockSize = 1024; int gridSize = (nth + 1024 - 1) / 1024; - int outC = inC + padc0 + padc1; - int outH = inH + padh0 + padh1; - int outW = inW + padw0 + padw1; + int cstart = pad.channelStart, cend = pad.channelEnd; + int hstart = pad.heightStart, hend = pad.heightEnd; + int wstart = pad.widthStart, wend = pad.widthEnd; + int outC = inC + cstart + cend; + int outH = inH + hstart + hend; + int outW = inW + wstart + wend; KePadDiff <<>> - (inGrad, outGrad, inC, inH, inW, padc0, padh0, padw0, + (inGrad, outGrad, inC, inH, inW, cstart, hstart, wstart, outC, outH, outW, nth); CHECK_SYNC("PadGrad"); } diff --git a/paddle/function/PadOpTest.cpp b/paddle/function/PadOpTest.cpp index dce2bac3e9..cd22d91135 100644 --- a/paddle/function/PadOpTest.cpp +++ b/paddle/function/PadOpTest.cpp @@ -27,12 +27,12 @@ TEST(Pad, real) { FunctionCompare compare("Pad", FuncConfig() - .set("padc0", 2) - .set("padc1", 3) - .set("padh0", 1) - .set("padh1", 2) - .set("padw0", 3) - .set("padw1", 2)); + .set("cstart", 2) + .set("cend", 3) + .set("hstart", 1) + .set("hend", 2) + .set("wstart", 3) + .set("wend", 2)); TensorShape inDims{numSamples, channels, imgSizeH, imgSizeW}; TensorShape outDims{ numSamples, channels + 5, imgSizeH + 3, imgSizeW + 5}; @@ -54,12 +54,12 @@ TEST(PadGrad, real) { << " imgSizeH=" << imgSizeH << " imgSizeW=" << imgSizeW; FunctionCompare compare("PadGrad", FuncConfig() - .set("padc0", 2) - .set("padc1", 3) - .set("padh0", 1) - .set("padh1", 2) - .set("padw0", 3) - .set("padw1", 2)); + .set("cstart", 2) + .set("cend", 3) + .set("hstart", 1) + .set("hend", 2) + .set("wstart", 3) + .set("wend", 2)); TensorShape inDims{numSamples, channels, imgSizeH, imgSizeW}; TensorShape outDims{ numSamples, channels + 5, imgSizeH + 3, imgSizeW + 5}; diff --git a/paddle/gserver/layers/PadLayer.cpp b/paddle/gserver/layers/PadLayer.cpp index a2a469ff92..bb618c09f9 100644 --- a/paddle/gserver/layers/PadLayer.cpp +++ b/paddle/gserver/layers/PadLayer.cpp @@ -49,21 +49,21 @@ bool PadLayer::init(const LayerMap& layerMap, createFunction(forward_, "Pad", FuncConfig() - .set("padc0", padc_[0]) - .set("padc1", padc_[1]) - .set("padh0", padh_[0]) - .set("padh1", padh_[1]) - .set("padw0", padw_[0]) - .set("padw1", padw_[1])); + .set("cstart", padc_[0]) + .set("cend", padc_[1]) + .set("hstart", padh_[0]) + .set("hend", padh_[1]) + .set("wstart", padw_[0]) + .set("wend", padw_[1])); createFunction(backward_, "PadGrad", FuncConfig() - .set("padc0", padc_[0]) - .set("padc1", padc_[1]) - .set("padh0", padh_[0]) - .set("padh1", padh_[1]) - .set("padw0", padw_[0]) - .set("padw1", padw_[1])); + .set("cstart", padc_[0]) + .set("cend", padc_[1]) + .set("hstart", padh_[0]) + .set("hend", padh_[1]) + .set("wstart", padw_[0]) + .set("wend", padw_[1])); return true; } diff --git a/python/paddle/trainer_config_helpers/layers.py b/python/paddle/trainer_config_helpers/layers.py index 66817fc93b..85a28e14ae 100644 --- a/python/paddle/trainer_config_helpers/layers.py +++ b/python/paddle/trainer_config_helpers/layers.py @@ -3617,6 +3617,31 @@ def pad_layer(input, input data and 3 zeros after the input data in channel dimension. pad_h means padding zeros in height dimension. pad_w means padding zeros in width dimension. + + For example, + + .. code-block:: + + input(2,2,2,3) = [ + [ [[1,2,3], [3,4,5]], + [[2,3,5], [1,6,7]] ], + [ [[4,3,1], [1,8,7]], + [[3,8,9], [2,3,5]] ] + ] + + pad_c=[1,1], pad_h=[0,0], pad_w=[0,0] + output(2,4,2,3) = [ + [ [[0,0,0], [0,0,0]], + [[1,2,3], [3,4,5]], + [[2,3,5], [1,6,7]], + [[0,0,0], [0,0,0]] ], + [ [[0,0,0], [0,0,0]], + [[4,3,1], [1,8,7]], + [[3,8,9], [2,3,5]], + [[0,0,0], [0,0,0]] ] + ] + + The simply usage is: .. code-block:: python