parent
							
								
									830877f59b
								
							
						
					
					
						commit
						e10040ca8a
					
				| @ -0,0 +1,177 @@ | ||||
| /* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve.
 | ||||
| 
 | ||||
| Licensed under the Apache License, Version 2.0 (the "License"); | ||||
| you may not use this file except in compliance with the License. | ||||
| You may obtain a copy of the License at | ||||
| 
 | ||||
|     http://www.apache.org/licenses/LICENSE-2.0
 | ||||
| 
 | ||||
| Unless required by applicable law or agreed to in writing, software | ||||
| distributed under the License is distributed on an "AS IS" BASIS, | ||||
| WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||||
| See the License for the specific language governing permissions and | ||||
| limitations under the License. */ | ||||
| 
 | ||||
| #include "CropOp.h" | ||||
| #include "paddle/math/Vector.h" | ||||
| #include "paddle/function/TensorShape.h" | ||||
| namespace paddle { | ||||
| 
 | ||||
| static inline CropConf castToCropConf(const FuncConfig& conf) { | ||||
|   return {conf.get<std::vector<uint32_t>>("crop_corner"), | ||||
|           conf.get<std::vector<uint32_t>>("crop_shape")}; | ||||
| } | ||||
| 
 | ||||
| template <> | ||||
| void Crop<DEVICE_TYPE_CPU>(real* outputs, | ||||
|                            const real* inputs, | ||||
|                            const TensorShape inShape, | ||||
|                            const CropConf& crop) { | ||||
|   int cCrop = crop.corner[0]; | ||||
|   int hCrop = crop.corner[1]; | ||||
|   int wCrop = crop.corner[2]; | ||||
| 
 | ||||
|   int num = inShape[0]; | ||||
|   int inC = inShape[1]; | ||||
|   int inH = inShape[2]; | ||||
|   int inW = inShape[3]; | ||||
| 
 | ||||
|   int outC = crop.shape[0]; | ||||
|   int outH = crop.shape[1]; | ||||
|   int outW = crop.shape[2]; | ||||
| 
 | ||||
|   for (int n = 0; n < num; n++) { | ||||
|     for (int c = 0; c < outC; c++) { | ||||
|       for (int h = 0; h < outH; h++) { | ||||
|         int outoff = ((n * outC + c) * outH + h) * outW; | ||||
|         int inoff = ((n * inC + c + cCrop) * inH + h + hCrop) * inW + wCrop; | ||||
|         memcpy(outputs + outoff, inputs + inoff, outW * sizeof(real)); | ||||
|       } | ||||
|     } | ||||
|   } | ||||
| } | ||||
| 
 | ||||
| template <> | ||||
| void CropGrad<DEVICE_TYPE_CPU>(const real* inGrad, | ||||
|                                real* outGrad, | ||||
|                                const TensorShape outShape, | ||||
|                                const CropConf& crop) { | ||||
|   int cCrop = crop.corner[0]; | ||||
|   int hCrop = crop.corner[1]; | ||||
|   int wCrop = crop.corner[2]; | ||||
| 
 | ||||
|   int num = outShape[0]; | ||||
|   int outC = outShape[1]; | ||||
|   int outH = outShape[2]; | ||||
|   int outW = outShape[3]; | ||||
| 
 | ||||
|   int inC = crop.shape[0]; | ||||
|   int inH = crop.shape[1]; | ||||
|   int inW = crop.shape[2]; | ||||
| 
 | ||||
|   for (int n = 0; n < num; n++) { | ||||
|     for (int c = 0; c < inC; c++) { | ||||
|       for (int h = 0; h < inH; h++) { | ||||
|         int outoff = ((n * outC + c + cCrop) * outH + h + hCrop) * outW + wCrop; | ||||
|         int inoff = ((n * inC + c) * inH + h) * inW; | ||||
|         CpuVector inG = CpuVector(inW, const_cast<real*>(inGrad + inoff)); | ||||
|         CpuVector outG = CpuVector(inW, outGrad + outoff); | ||||
|         outG += inG; | ||||
|       } | ||||
|     } | ||||
|   } | ||||
| } | ||||
| 
 | ||||
| /**
 | ||||
|  * \brief Crop input according to the specify corner and shape. | ||||
|  *        The input and output is a 4D tensor. In CropFunc, we only | ||||
|  *        crop the 2nd to 4th dimension. | ||||
|  * | ||||
|  * Argument in this Function: | ||||
|  * \param pad_    A struct object contains the cropping corner and shape. | ||||
|  * \param inputs  A 4D tensor, only one input. | ||||
|  * \param outputs A 4D tensor, the output value after cropping. | ||||
|  * | ||||
|  * For example, | ||||
|  * Input(2,2,2,3) = [ | ||||
|  *                    [ [[1,2,3], [3,4,5]], | ||||
|  *                      [[2,3,5], [1,6,7]] ], | ||||
|  *                    [ [[4,3,1], [1,8,7]], | ||||
|  *                      [[3,8,9], [2,3,5]] ] | ||||
|  *                  ] # the input shape is (2,2,2,3) | ||||
|  * | ||||
|  * pad_: if corner = (0,1,1) and crop_shape = (2,1,2) | ||||
|  * Output(2,2,1,2) = [ | ||||
|  *                    [ [[4,5]], | ||||
|  *                      [[6,7]] ], | ||||
|  *                    [ [[8,7]], | ||||
|  *                      [[3,5]] ] | ||||
|  *                  ] # the input shape is (2,2,2,3) | ||||
|  */ | ||||
| template <DeviceType Device> | ||||
| class CropFunc : public FunctionBase { | ||||
| public: | ||||
|   void init(const FuncConfig& config) override { | ||||
|     crop_ = castToCropConf(config); | ||||
|   } | ||||
| 
 | ||||
|   void calc(const BufferArgs& inputs, const BufferArgs& outputs) override { | ||||
|     CHECK_EQ(1UL, inputs.size()); | ||||
|     CHECK_EQ(1UL, outputs.size()); | ||||
|     CHECK_EQ(outputs[0].shape()[1], crop_.shape[0]); | ||||
|     CHECK_EQ(outputs[0].shape()[2], crop_.shape[1]); | ||||
|     CHECK_EQ(outputs[0].shape()[3], crop_.shape[2]); | ||||
|     CHECK_EQ(outputs[0].getArgType(), ASSIGN_TO); | ||||
| 
 | ||||
|     TensorShape inShape = inputs[0].shape(); | ||||
| 
 | ||||
|     Crop<Device>( | ||||
|         outputs[0].data<real>(), inputs[0].data<real>(), inShape, crop_); | ||||
|   } | ||||
| 
 | ||||
| private: | ||||
|   CropConf crop_; | ||||
| }; | ||||
| 
 | ||||
| /**
 | ||||
|  * \brief The backward propagation of cropping Function. | ||||
|  * | ||||
|  * Argument in this Function: | ||||
|  * \param crop_    The same meaning as it in CropFunc. | ||||
|  * \param inputs  The gradient with respect to the output value of CropFunc. | ||||
|  * \param outputs The gradient with respect to the input value of CropFunc. | ||||
|  */ | ||||
| 
 | ||||
| template <DeviceType Device> | ||||
| class CropGradFunc : public FunctionBase { | ||||
| public: | ||||
|   void init(const FuncConfig& config) override { | ||||
|     crop_ = castToCropConf(config); | ||||
|   } | ||||
| 
 | ||||
|   void calc(const BufferArgs& inputs, const BufferArgs& outputs) override { | ||||
|     CHECK_EQ(1UL, inputs.size()); | ||||
|     CHECK_EQ(1UL, outputs.size()); | ||||
|     CHECK_EQ(inputs[0].shape()[1], crop_.shape[0]); | ||||
|     CHECK_EQ(inputs[0].shape()[2], crop_.shape[1]); | ||||
|     CHECK_EQ(inputs[0].shape()[3], crop_.shape[2]); | ||||
|     CHECK_EQ(outputs[0].getArgType(), ASSIGN_TO); | ||||
| 
 | ||||
|     TensorShape outShape = outputs[0].shape(); | ||||
| 
 | ||||
|     CropGrad<Device>( | ||||
|         inputs[0].data<real>(), outputs[0].data<real>(), outShape, crop_); | ||||
|   } | ||||
| 
 | ||||
| private: | ||||
|   CropConf crop_; | ||||
| }; | ||||
| 
 | ||||
| REGISTER_TYPED_FUNC(Crop, CPU, CropFunc); | ||||
| REGISTER_TYPED_FUNC(CropGrad, CPU, CropGradFunc); | ||||
| #ifndef PADDLE_ONLY_CPU | ||||
| REGISTER_TYPED_FUNC(Crop, GPU, CropFunc); | ||||
| REGISTER_TYPED_FUNC(CropGrad, GPU, CropGradFunc); | ||||
| #endif | ||||
| 
 | ||||
| }  // namespace paddle
 | ||||
| @ -0,0 +1,56 @@ | ||||
| /* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve.
 | ||||
| 
 | ||||
| Licensed under the Apache License, Version 2.0 (the "License"); | ||||
| you may not use this file except in compliance with the License. | ||||
| You may obtain a copy of the License at | ||||
| 
 | ||||
|     http://www.apache.org/licenses/LICENSE-2.0
 | ||||
| 
 | ||||
| Unless required by applicable law or agreed to in writing, software | ||||
| distributed under the License is distributed on an "AS IS" BASIS, | ||||
| WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||||
| See the License for the specific language governing permissions and | ||||
| limitations under the License. */ | ||||
| 
 | ||||
| #pragma once | ||||
| 
 | ||||
| #include "Function.h" | ||||
| 
 | ||||
| namespace paddle { | ||||
| 
 | ||||
| struct CropConf { | ||||
|   /// The upper left corner of croped result
 | ||||
|   std::vector<uint32_t> corner; | ||||
|   /// The shape of croped result
 | ||||
|   std::vector<uint32_t> shape; | ||||
| }; | ||||
| 
 | ||||
| /**
 | ||||
|  * \brief  This funtion crops inputs according to the specify start point and | ||||
|  *shape. | ||||
|  * | ||||
|  * \param[out] outputs	save results. | ||||
|  * \param[in]  inputs	input data. | ||||
|  * \param[in]  inShape  the shape of input tensor. | ||||
|  * \param[in]  crop     the cropping config | ||||
|  */ | ||||
| template <DeviceType Device> | ||||
| void Crop(real* outputs, | ||||
|           const real* inputs, | ||||
|           const TensorShape inShape, | ||||
|           const CropConf& crop); | ||||
| 
 | ||||
| /**
 | ||||
|  * \brief   Cropping operation backward. | ||||
|  * | ||||
|  * \param[out] inGrad	gradients of previous layer | ||||
|  * \param[in]  outGrad  output gradient | ||||
|  * \param[in]  inShape  the shape of input tensor. | ||||
|  * \param[in]  crop     the cropping config | ||||
|  */ | ||||
| template <DeviceType Device> | ||||
| void CropGrad(const real* inGrad, | ||||
|               real* outGrad, | ||||
|               const TensorShape inShape, | ||||
|               const CropConf& crop); | ||||
| }  // namespace paddle
 | ||||
| @ -0,0 +1,109 @@ | ||||
| /* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve. | ||||
| 
 | ||||
| Licensed under the Apache License, Version 2.0 (the "License"); | ||||
| you may not use this file except in compliance with the License. | ||||
| You may obtain a copy of the License at | ||||
| 
 | ||||
|     http://www.apache.org/licenses/LICENSE-2.0 | ||||
| 
 | ||||
| Unless required by applicable law or agreed to in writing, software | ||||
| distributed under the License is distributed on an "AS IS" BASIS, | ||||
| WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||||
| See the License for the specific language governing permissions and | ||||
| limitations under the License. */ | ||||
| 
 | ||||
| #include "hl_base.h" | ||||
| #include "CropOp.h" | ||||
| 
 | ||||
| namespace paddle { | ||||
| 
 | ||||
| __global__ void KeCrop(real* outputs, const real* inputs, | ||||
|                       int inC, int inH, int inW, | ||||
|                       int cropC, int cropH, int cropW, | ||||
|                       int outC, int outH, int outW, int nthreads) { | ||||
|   const int idx = threadIdx.x + blockIdx.x * blockDim.x; | ||||
|   if (idx < nthreads) { | ||||
|     const int w = idx % outW; | ||||
|     const int h = (idx / outW) % outH; | ||||
|     const int c = (idx / outW / outH) % outC; | ||||
|     const int n = idx / outW / outH / outC; | ||||
| 
 | ||||
|     const int off = ((n * inC + c + cropC) * inH + h + cropH) * inW + cropW + w; | ||||
|     outputs[idx] = inputs[off]; | ||||
|   } | ||||
| } | ||||
| 
 | ||||
| template <> | ||||
| void Crop<DEVICE_TYPE_GPU>(real* outputs, | ||||
|                           const real* inputs, | ||||
| 						  const TensorShape inShape, | ||||
|                           const CropConf& crop) { | ||||
|   int cropC = crop.corner[0]; | ||||
|   int cropH = crop.corner[1]; | ||||
|   int cropW = crop.corner[2]; | ||||
| 
 | ||||
|   int num = inShape[0]; | ||||
|   int inC = inShape[1]; | ||||
|   int inH = inShape[2]; | ||||
|   int inW = inShape[3]; | ||||
| 
 | ||||
|   int outC = crop.shape[0]; | ||||
|   int outH = crop.shape[1]; | ||||
|   int outW = crop.shape[2]; | ||||
|    | ||||
|   size_t nth = num * outC * outH * outW; | ||||
|   int blockSize = 1024; | ||||
|   int gridSize = (nth + blockSize - 1) / blockSize; | ||||
|    | ||||
|   KeCrop<<<gridSize, blockSize, 0, STREAM_DEFAULT>>> | ||||
|     (outputs, inputs, inC, inH, inW, cropC, cropH, cropW, | ||||
|      outC, outH, outW, nth); | ||||
|   CHECK_SYNC("Crop"); | ||||
| } | ||||
| 
 | ||||
| __global__ void KeCropDiff(const real* inGrad, real* outGrad, | ||||
|                           int inC, int inH, int inW, | ||||
|                           int cropC, int cropH, int cropW, | ||||
|                           int outC, int outH, int outW, int nthreads) { | ||||
|   const int idx = threadIdx.x + blockIdx.x * blockDim.x; | ||||
|   if (idx < nthreads) { | ||||
|     const int w = idx % inW; | ||||
|     const int h = (idx / inW) % inH; | ||||
|     const int c = (idx / inW / inH) % inC; | ||||
|     const int n = idx / inW / inH / inC; | ||||
| 
 | ||||
|     const int off = ((n * outC + c + cropC) * outH + h + cropH) * outW + cropW + w; | ||||
|      | ||||
|     outGrad[off] += inGrad[idx]; | ||||
|   } | ||||
| } | ||||
| 
 | ||||
| template <> | ||||
| void CropGrad<DEVICE_TYPE_GPU>(const real* inGrad, | ||||
|                               real* outGrad, | ||||
|                               const TensorShape outShape, | ||||
|                               const CropConf& crop) { | ||||
|   int cropC = crop.corner[0]; | ||||
|   int cropH = crop.corner[1]; | ||||
|   int cropW = crop.corner[2]; | ||||
| 
 | ||||
|   int num = outShape[0]; | ||||
|   int outC = outShape[1]; | ||||
|   int outH = outShape[2]; | ||||
|   int outW = outShape[3]; | ||||
| 
 | ||||
|   int inC = crop.shape[0]; | ||||
|   int inH = crop.shape[1]; | ||||
|   int inW = crop.shape[2]; | ||||
|    | ||||
|   size_t nth = num * inC * inH * inW; | ||||
|   int blockSize = 1024; | ||||
|   int gridSize = (nth + blockSize - 1) / blockSize; | ||||
| 
 | ||||
|   KeCropDiff <<<gridSize, blockSize, 0, STREAM_DEFAULT>>> | ||||
|     (inGrad, outGrad, inC, inH, inW, cropC, cropH, cropW, | ||||
|      outC, outH, outW, nth); | ||||
|   CHECK_SYNC("CropGrad"); | ||||
| } | ||||
| 
 | ||||
| }  // namespace paddle | ||||
| @ -0,0 +1,47 @@ | ||||
| /* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve.
 | ||||
| 
 | ||||
| Licensed under the Apache License, Version 2.0 (the "License"); | ||||
| you may not use this file except in compliance with the License. | ||||
| You may obtain a copy of the License at | ||||
| 
 | ||||
|     http://www.apache.org/licenses/LICENSE-2.0
 | ||||
| 
 | ||||
| Unless required by applicable law or agreed to in writing, software | ||||
| distributed under the License is distributed on an "AS IS" BASIS, | ||||
| WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||||
| See the License for the specific language governing permissions and | ||||
| limitations under the License. */ | ||||
| 
 | ||||
| #include <gtest/gtest.h> | ||||
| #include "FunctionTest.h" | ||||
| 
 | ||||
| namespace paddle { | ||||
| 
 | ||||
| TEST(Crop, real) { | ||||
|   for (size_t numSamples : {5, 32}) { | ||||
|     for (size_t channels : {5, 5, 32}) { | ||||
|       for (size_t imgSizeH : {5, 33, 100}) { | ||||
|         for (size_t imgSizeW : {5, 32, 96}) { | ||||
|           VLOG(3) << " numSamples=" << numSamples << " channels=" << channels | ||||
|                   << " imgSizeH=" << imgSizeH << " imgSizeW=" << imgSizeW; | ||||
|           for (bool test_grad : {false, true}) { | ||||
|             FunctionCompare compare( | ||||
|                 test_grad ? "CropGrad" : "Crop", | ||||
|                 FuncConfig() | ||||
|                     .set<std::vector<uint32_t>>("crop_corner", {1, 1, 1}) | ||||
|                     .set<std::vector<uint32_t>>("crop_shape", {2, 3, 3})); | ||||
|             TensorShape inDims{numSamples, channels, imgSizeH, imgSizeW}; | ||||
|             TensorShape outDims{numSamples, 2, 3, 3}; | ||||
|             compare.addInputs( | ||||
|                 BufferArg(VALUE_TYPE_FLOAT, test_grad ? outDims : inDims)); | ||||
|             compare.addOutputs(BufferArg( | ||||
|                 VALUE_TYPE_FLOAT, test_grad ? inDims : outDims, ASSIGN_TO)); | ||||
|             compare.run(); | ||||
|           } | ||||
|         } | ||||
|       } | ||||
|     } | ||||
|   } | ||||
| } | ||||
| 
 | ||||
| }  // namespace paddle
 | ||||
| @ -0,0 +1,101 @@ | ||||
| /* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve.
 | ||||
| 
 | ||||
| Licensed under the Apache License, Version 2.0 (the "License"); | ||||
| you may not use this file except in compliance with the License. | ||||
| You may obtain a copy of the License at | ||||
| 
 | ||||
|     http://www.apache.org/licenses/LICENSE-2.0
 | ||||
| 
 | ||||
| Unless required by applicable law or agreed to in writing, software | ||||
| distributed under the License is distributed on an "AS IS" BASIS, | ||||
| WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||||
| See the License for the specific language governing permissions and | ||||
| limitations under the License. */ | ||||
| 
 | ||||
| #include "CropLayer.h" | ||||
| #include "paddle/utils/Stat.h" | ||||
| 
 | ||||
| namespace paddle { | ||||
| 
 | ||||
| REGISTER_LAYER(crop, CropLayer); | ||||
| 
 | ||||
| bool CropLayer::init(const LayerMap& layerMap, | ||||
|                      const ParameterMap& parameterMap) { | ||||
|   /* Initialize the basic parent class */ | ||||
|   Layer::init(layerMap, parameterMap); | ||||
| 
 | ||||
|   auto& crop_conf = config_.inputs(0).crop_conf(); | ||||
|   auto& img_conf = crop_conf.image_conf(); | ||||
|   CHECK_EQ(config_.inputs_size(), 1); | ||||
|   inDims_ = TensorShape( | ||||
|       {0, | ||||
|        img_conf.channels(), | ||||
|        img_conf.has_img_size_y() ? img_conf.img_size_y() : img_conf.img_size(), | ||||
|        img_conf.img_size()}); | ||||
| 
 | ||||
|   crop_corner_ = {crop_conf.crop_corner(0), | ||||
|                   crop_conf.crop_corner(1), | ||||
|                   crop_conf.crop_corner(2)}; | ||||
|   crop_shape_ = {crop_conf.crop_shape(0), | ||||
|                  crop_conf.crop_shape(1), | ||||
|                  crop_conf.crop_shape(2)}; | ||||
| 
 | ||||
|   outDims_ = TensorShape(4); | ||||
|   setOutDims(0); | ||||
| 
 | ||||
|   createFunction(forward_, | ||||
|                  "Crop", | ||||
|                  FuncConfig() | ||||
|                      .set("crop_corner", crop_corner_) | ||||
|                      .set("crop_shape", crop_shape_)); | ||||
|   createFunction(backward_, | ||||
|                  "CropGrad", | ||||
|                  FuncConfig() | ||||
|                      .set("crop_corner", crop_corner_) | ||||
|                      .set("crop_shape", crop_shape_)); | ||||
| 
 | ||||
|   return true; | ||||
| } | ||||
| 
 | ||||
| void CropLayer::setOutDims(const size_t batchSize) { | ||||
|   outDims_.reshape({batchSize, crop_shape_[0], crop_shape_[1], crop_shape_[2]}); | ||||
| } | ||||
| 
 | ||||
| void CropLayer::setTensorDim(const size_t batchSize) { | ||||
|   CHECK_EQ(static_cast<int>(inputLayers_.size()), 1); | ||||
|   inDims_.setDim(0, batchSize); | ||||
|   int h = inputLayers_[0]->getOutput().getFrameHeight(); | ||||
|   if (h != 0) inDims_.setDim(2, h); | ||||
|   int w = inputLayers_[0]->getOutput().getFrameWidth(); | ||||
|   if (w != 0) inDims_.setDim(3, w); | ||||
|   setOutDims(batchSize); | ||||
| } | ||||
| 
 | ||||
| void CropLayer::forward(PassType passType) { | ||||
|   Layer::forward(passType); | ||||
|   MatrixPtr input = inputLayers_[0]->getOutputValue(); | ||||
|   size_t batchSize = input->getHeight(); | ||||
|   setTensorDim(batchSize); | ||||
|   int size = outDims_[1] * outDims_[2] * outDims_[3]; | ||||
|   resetOutput(batchSize, size); | ||||
|   MatrixPtr outV = getOutputValue(); | ||||
|   REGISTER_TIMER_INFO("CropForward", getName().c_str()); | ||||
| 
 | ||||
|   BufferArgs inputs; | ||||
|   BufferArgs outputs; | ||||
|   inputs.addArg(*getInputValue(0), inDims_); | ||||
|   outputs.addArg(*getOutputValue(), outDims_, ASSIGN_TO); | ||||
|   forward_[0]->calc(inputs, outputs); | ||||
| } | ||||
| 
 | ||||
| void CropLayer::backward(const UpdateCallback& callback) { | ||||
|   (void)callback; | ||||
|   REGISTER_TIMER_INFO("CropBackward", getName().c_str()); | ||||
| 
 | ||||
|   BufferArgs inputs; | ||||
|   BufferArgs outputs; | ||||
|   inputs.addArg(*getOutputGrad(), outDims_); | ||||
|   outputs.addArg(*getInputGrad(0), inDims_, ADD_TO); | ||||
|   backward_[0]->calc(inputs, outputs); | ||||
| } | ||||
| }  // namespace paddle
 | ||||
| @ -0,0 +1,46 @@ | ||||
| /* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve.
 | ||||
| 
 | ||||
| Licensed under the Apache License, Version 2.0 (the "License"); | ||||
| you may not use this file except in compliance with the License. | ||||
| You may obtain a copy of the License at | ||||
| 
 | ||||
|     http://www.apache.org/licenses/LICENSE-2.0
 | ||||
| 
 | ||||
| Unless required by applicable law or agreed to in writing, software | ||||
| distributed under the License is distributed on an "AS IS" BASIS, | ||||
| WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||||
| See the License for the specific language governing permissions and | ||||
| limitations under the License. */ | ||||
| 
 | ||||
| #pragma once | ||||
| 
 | ||||
| #include "Layer.h" | ||||
| 
 | ||||
| namespace paddle { | ||||
| 
 | ||||
| /**
 | ||||
|  * \brief  This layer crop inputs according to the specify corner and shape. | ||||
|  *         The input and output is a 4D tensor. Cropping from the 2nd to | ||||
|  *         the 4th dimenstion. | ||||
|  */ | ||||
| class CropLayer : public Layer { | ||||
| public: | ||||
|   explicit CropLayer(const LayerConfig& config) : Layer(config) {} | ||||
| 
 | ||||
|   ~CropLayer() {} | ||||
| 
 | ||||
|   bool init(const LayerMap& layerMap, | ||||
|             const ParameterMap& parameterMap) override; | ||||
|   void forward(PassType passType) override; | ||||
|   void backward(const UpdateCallback& callback = nullptr) override; | ||||
| 
 | ||||
| protected: | ||||
|   void setOutDims(const size_t batchSize); | ||||
|   void setTensorDim(const size_t batchSize); | ||||
| 
 | ||||
|   std::vector<uint32_t> crop_corner_; | ||||
|   std::vector<uint32_t> crop_shape_; | ||||
|   TensorShape inDims_; | ||||
|   TensorShape outDims_; | ||||
| }; | ||||
| }  // namespace paddle
 | ||||
					Loading…
					
					
				
		Reference in new issue