From 013d0a268591829d7f757deeb3c23c58915c2d95 Mon Sep 17 00:00:00 2001 From: wanghaoshuang Date: Fri, 16 Jun 2017 19:02:46 +0800 Subject: [PATCH 01/37] add crop layer --- paddle/function/CMakeLists.txt | 1 + paddle/function/CropOp.cpp | 177 ++++++++++++++++++++++++++++ paddle/function/CropOp.h | 56 +++++++++ paddle/function/CropOpGpu.cu | 109 +++++++++++++++++ paddle/function/CropOpTest.cpp | 47 ++++++++ paddle/gserver/layers/CropLayer.cpp | 101 ++++++++++++++++ paddle/gserver/layers/CropLayer.h | 46 ++++++++ 7 files changed, 537 insertions(+) create mode 100644 paddle/function/CropOp.cpp create mode 100644 paddle/function/CropOp.h create mode 100644 paddle/function/CropOpGpu.cu create mode 100644 paddle/function/CropOpTest.cpp create mode 100644 paddle/gserver/layers/CropLayer.cpp create mode 100644 paddle/gserver/layers/CropLayer.h diff --git a/paddle/function/CMakeLists.txt b/paddle/function/CMakeLists.txt index 1518a8a654..f19a1eb777 100644 --- a/paddle/function/CMakeLists.txt +++ b/paddle/function/CMakeLists.txt @@ -37,6 +37,7 @@ if(WITH_GPU) add_simple_unittest(MulOpTest) add_simple_unittest(CosSimOpTest) add_simple_unittest(RowConvOpTest) + add_simple_unittest(CropOpTest) endif() add_simple_unittest(ConvOpTest) diff --git a/paddle/function/CropOp.cpp b/paddle/function/CropOp.cpp new file mode 100644 index 0000000000..4d47d9c149 --- /dev/null +++ b/paddle/function/CropOp.cpp @@ -0,0 +1,177 @@ +/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. */ + +#include "CropOp.h" +#include "paddle/math/Vector.h" +#include "paddle/function/TensorShape.h" +namespace paddle { + +static inline CropConf castToCropConf(const FuncConfig& conf) { + return {conf.get>("crop_corner"), + conf.get>("crop_shape")}; +} + +template <> +void Crop(real* outputs, + const real* inputs, + const TensorShape inShape, + const CropConf& crop) { + int cCrop = crop.corner[0]; + int hCrop = crop.corner[1]; + int wCrop = crop.corner[2]; + + int num = inShape[0]; + int inC = inShape[1]; + int inH = inShape[2]; + int inW = inShape[3]; + + int outC = crop.shape[0]; + int outH = crop.shape[1]; + int outW = crop.shape[2]; + + for (int n = 0; n < num; n++) { + for (int c = 0; c < outC; c++) { + for (int h = 0; h < outH; h++) { + int outoff = ((n * outC + c) * outH + h) * outW; + int inoff = ((n * inC + c + cCrop) * inH + h + hCrop) * inW + wCrop; + memcpy(outputs + outoff, inputs + inoff, outW * sizeof(real)); + } + } + } +} + +template <> +void CropGrad(const real* inGrad, + real* outGrad, + const TensorShape outShape, + const CropConf& crop) { + int cCrop = crop.corner[0]; + int hCrop = crop.corner[1]; + int wCrop = crop.corner[2]; + + int num = outShape[0]; + int outC = outShape[1]; + int outH = outShape[2]; + int outW = outShape[3]; + + int inC = crop.shape[0]; + int inH = crop.shape[1]; + int inW = crop.shape[2]; + + for (int n = 0; n < num; n++) { + for (int c = 0; c < inC; c++) { + for (int h = 0; h < inH; h++) { + int outoff = ((n * outC + c + cCrop) * outH + h + hCrop) * outW + wCrop; + int inoff = ((n * inC + c) * inH + h) * inW; + CpuVector inG = CpuVector(inW, const_cast(inGrad + inoff)); + CpuVector outG = CpuVector(inW, outGrad + outoff); + outG += inG; + } + } + } +} + +/** + * \brief Crop input according to the specify corner and shape. + * The input and output is a 4D tensor. In CropFunc, we only + * crop the 2nd to 4th dimension. + * + * Argument in this Function: + * \param pad_ A struct object contains the cropping corner and shape. + * \param inputs A 4D tensor, only one input. + * \param outputs A 4D tensor, the output value after cropping. + * + * For example, + * Input(2,2,2,3) = [ + * [ [[1,2,3], [3,4,5]], + * [[2,3,5], [1,6,7]] ], + * [ [[4,3,1], [1,8,7]], + * [[3,8,9], [2,3,5]] ] + * ] # the input shape is (2,2,2,3) + * + * pad_: if corner = (0,1,1) and crop_shape = (2,1,2) + * Output(2,2,1,2) = [ + * [ [[4,5]], + * [[6,7]] ], + * [ [[8,7]], + * [[3,5]] ] + * ] # the input shape is (2,2,2,3) + */ +template +class CropFunc : public FunctionBase { +public: + void init(const FuncConfig& config) override { + crop_ = castToCropConf(config); + } + + void calc(const BufferArgs& inputs, const BufferArgs& outputs) override { + CHECK_EQ(1UL, inputs.size()); + CHECK_EQ(1UL, outputs.size()); + CHECK_EQ(outputs[0].shape()[1], crop_.shape[0]); + CHECK_EQ(outputs[0].shape()[2], crop_.shape[1]); + CHECK_EQ(outputs[0].shape()[3], crop_.shape[2]); + CHECK_EQ(outputs[0].getArgType(), ASSIGN_TO); + + TensorShape inShape = inputs[0].shape(); + + Crop( + outputs[0].data(), inputs[0].data(), inShape, crop_); + } + +private: + CropConf crop_; +}; + +/** + * \brief The backward propagation of cropping Function. + * + * Argument in this Function: + * \param crop_ The same meaning as it in CropFunc. + * \param inputs The gradient with respect to the output value of CropFunc. + * \param outputs The gradient with respect to the input value of CropFunc. + */ + +template +class CropGradFunc : public FunctionBase { +public: + void init(const FuncConfig& config) override { + crop_ = castToCropConf(config); + } + + void calc(const BufferArgs& inputs, const BufferArgs& outputs) override { + CHECK_EQ(1UL, inputs.size()); + CHECK_EQ(1UL, outputs.size()); + CHECK_EQ(inputs[0].shape()[1], crop_.shape[0]); + CHECK_EQ(inputs[0].shape()[2], crop_.shape[1]); + CHECK_EQ(inputs[0].shape()[3], crop_.shape[2]); + CHECK_EQ(outputs[0].getArgType(), ASSIGN_TO); + + TensorShape outShape = outputs[0].shape(); + + CropGrad( + inputs[0].data(), outputs[0].data(), outShape, crop_); + } + +private: + CropConf crop_; +}; + +REGISTER_TYPED_FUNC(Crop, CPU, CropFunc); +REGISTER_TYPED_FUNC(CropGrad, CPU, CropGradFunc); +#ifndef PADDLE_ONLY_CPU +REGISTER_TYPED_FUNC(Crop, GPU, CropFunc); +REGISTER_TYPED_FUNC(CropGrad, GPU, CropGradFunc); +#endif + +} // namespace paddle diff --git a/paddle/function/CropOp.h b/paddle/function/CropOp.h new file mode 100644 index 0000000000..78a55bd43e --- /dev/null +++ b/paddle/function/CropOp.h @@ -0,0 +1,56 @@ +/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. */ + +#pragma once + +#include "Function.h" + +namespace paddle { + +struct CropConf { + /// The upper left corner of croped result + std::vector corner; + /// The shape of croped result + std::vector shape; +}; + +/** + * \brief This funtion crops inputs according to the specify start point and + *shape. + * + * \param[out] outputs save results. + * \param[in] inputs input data. + * \param[in] inShape the shape of input tensor. + * \param[in] crop the cropping config + */ +template +void Crop(real* outputs, + const real* inputs, + const TensorShape inShape, + const CropConf& crop); + +/** + * \brief Cropping operation backward. + * + * \param[out] inGrad gradients of previous layer + * \param[in] outGrad output gradient + * \param[in] inShape the shape of input tensor. + * \param[in] crop the cropping config + */ +template +void CropGrad(const real* inGrad, + real* outGrad, + const TensorShape inShape, + const CropConf& crop); +} // namespace paddle diff --git a/paddle/function/CropOpGpu.cu b/paddle/function/CropOpGpu.cu new file mode 100644 index 0000000000..f7d7d03abd --- /dev/null +++ b/paddle/function/CropOpGpu.cu @@ -0,0 +1,109 @@ +/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. */ + +#include "hl_base.h" +#include "CropOp.h" + +namespace paddle { + +__global__ void KeCrop(real* outputs, const real* inputs, + int inC, int inH, int inW, + int cropC, int cropH, int cropW, + int outC, int outH, int outW, int nthreads) { + const int idx = threadIdx.x + blockIdx.x * blockDim.x; + if (idx < nthreads) { + const int w = idx % outW; + const int h = (idx / outW) % outH; + const int c = (idx / outW / outH) % outC; + const int n = idx / outW / outH / outC; + + const int off = ((n * inC + c + cropC) * inH + h + cropH) * inW + cropW + w; + outputs[idx] = inputs[off]; + } +} + +template <> +void Crop(real* outputs, + const real* inputs, + const TensorShape inShape, + const CropConf& crop) { + int cropC = crop.corner[0]; + int cropH = crop.corner[1]; + int cropW = crop.corner[2]; + + int num = inShape[0]; + int inC = inShape[1]; + int inH = inShape[2]; + int inW = inShape[3]; + + int outC = crop.shape[0]; + int outH = crop.shape[1]; + int outW = crop.shape[2]; + + size_t nth = num * outC * outH * outW; + int blockSize = 1024; + int gridSize = (nth + blockSize - 1) / blockSize; + + KeCrop<<>> + (outputs, inputs, inC, inH, inW, cropC, cropH, cropW, + outC, outH, outW, nth); + CHECK_SYNC("Crop"); +} + +__global__ void KeCropDiff(const real* inGrad, real* outGrad, + int inC, int inH, int inW, + int cropC, int cropH, int cropW, + int outC, int outH, int outW, int nthreads) { + const int idx = threadIdx.x + blockIdx.x * blockDim.x; + if (idx < nthreads) { + const int w = idx % inW; + const int h = (idx / inW) % inH; + const int c = (idx / inW / inH) % inC; + const int n = idx / inW / inH / inC; + + const int off = ((n * outC + c + cropC) * outH + h + cropH) * outW + cropW + w; + + outGrad[off] += inGrad[idx]; + } +} + +template <> +void CropGrad(const real* inGrad, + real* outGrad, + const TensorShape outShape, + const CropConf& crop) { + int cropC = crop.corner[0]; + int cropH = crop.corner[1]; + int cropW = crop.corner[2]; + + int num = outShape[0]; + int outC = outShape[1]; + int outH = outShape[2]; + int outW = outShape[3]; + + int inC = crop.shape[0]; + int inH = crop.shape[1]; + int inW = crop.shape[2]; + + size_t nth = num * inC * inH * inW; + int blockSize = 1024; + int gridSize = (nth + blockSize - 1) / blockSize; + + KeCropDiff <<>> + (inGrad, outGrad, inC, inH, inW, cropC, cropH, cropW, + outC, outH, outW, nth); + CHECK_SYNC("CropGrad"); +} + +} // namespace paddle diff --git a/paddle/function/CropOpTest.cpp b/paddle/function/CropOpTest.cpp new file mode 100644 index 0000000000..62b4bd9fde --- /dev/null +++ b/paddle/function/CropOpTest.cpp @@ -0,0 +1,47 @@ +/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. */ + +#include +#include "FunctionTest.h" + +namespace paddle { + +TEST(Crop, real) { + for (size_t numSamples : {5, 32}) { + for (size_t channels : {5, 5, 32}) { + for (size_t imgSizeH : {5, 33, 100}) { + for (size_t imgSizeW : {5, 32, 96}) { + VLOG(3) << " numSamples=" << numSamples << " channels=" << channels + << " imgSizeH=" << imgSizeH << " imgSizeW=" << imgSizeW; + for (bool test_grad : {false, true}) { + FunctionCompare compare( + test_grad ? "CropGrad" : "Crop", + FuncConfig() + .set>("crop_corner", {1, 1, 1}) + .set>("crop_shape", {2, 3, 3})); + TensorShape inDims{numSamples, channels, imgSizeH, imgSizeW}; + TensorShape outDims{numSamples, 2, 3, 3}; + compare.addInputs( + BufferArg(VALUE_TYPE_FLOAT, test_grad ? outDims : inDims)); + compare.addOutputs(BufferArg( + VALUE_TYPE_FLOAT, test_grad ? inDims : outDims, ASSIGN_TO)); + compare.run(); + } + } + } + } + } +} + +} // namespace paddle diff --git a/paddle/gserver/layers/CropLayer.cpp b/paddle/gserver/layers/CropLayer.cpp new file mode 100644 index 0000000000..ab23d4617e --- /dev/null +++ b/paddle/gserver/layers/CropLayer.cpp @@ -0,0 +1,101 @@ +/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. */ + +#include "CropLayer.h" +#include "paddle/utils/Stat.h" + +namespace paddle { + +REGISTER_LAYER(crop, CropLayer); + +bool CropLayer::init(const LayerMap& layerMap, + const ParameterMap& parameterMap) { + /* Initialize the basic parent class */ + Layer::init(layerMap, parameterMap); + + auto& crop_conf = config_.inputs(0).crop_conf(); + auto& img_conf = crop_conf.image_conf(); + CHECK_EQ(config_.inputs_size(), 1); + inDims_ = TensorShape( + {0, + img_conf.channels(), + img_conf.has_img_size_y() ? img_conf.img_size_y() : img_conf.img_size(), + img_conf.img_size()}); + + crop_corner_ = {crop_conf.crop_corner(0), + crop_conf.crop_corner(1), + crop_conf.crop_corner(2)}; + crop_shape_ = {crop_conf.crop_shape(0), + crop_conf.crop_shape(1), + crop_conf.crop_shape(2)}; + + outDims_ = TensorShape(4); + setOutDims(0); + + createFunction(forward_, + "Crop", + FuncConfig() + .set("crop_corner", crop_corner_) + .set("crop_shape", crop_shape_)); + createFunction(backward_, + "CropGrad", + FuncConfig() + .set("crop_corner", crop_corner_) + .set("crop_shape", crop_shape_)); + + return true; +} + +void CropLayer::setOutDims(const size_t batchSize) { + outDims_.reshape({batchSize, crop_shape_[0], crop_shape_[1], crop_shape_[2]}); +} + +void CropLayer::setTensorDim(const size_t batchSize) { + CHECK_EQ(static_cast(inputLayers_.size()), 1); + inDims_.setDim(0, batchSize); + int h = inputLayers_[0]->getOutput().getFrameHeight(); + if (h != 0) inDims_.setDim(2, h); + int w = inputLayers_[0]->getOutput().getFrameWidth(); + if (w != 0) inDims_.setDim(3, w); + setOutDims(batchSize); +} + +void CropLayer::forward(PassType passType) { + Layer::forward(passType); + MatrixPtr input = inputLayers_[0]->getOutputValue(); + size_t batchSize = input->getHeight(); + setTensorDim(batchSize); + int size = outDims_[1] * outDims_[2] * outDims_[3]; + resetOutput(batchSize, size); + MatrixPtr outV = getOutputValue(); + REGISTER_TIMER_INFO("CropForward", getName().c_str()); + + BufferArgs inputs; + BufferArgs outputs; + inputs.addArg(*getInputValue(0), inDims_); + outputs.addArg(*getOutputValue(), outDims_, ASSIGN_TO); + forward_[0]->calc(inputs, outputs); +} + +void CropLayer::backward(const UpdateCallback& callback) { + (void)callback; + REGISTER_TIMER_INFO("CropBackward", getName().c_str()); + + BufferArgs inputs; + BufferArgs outputs; + inputs.addArg(*getOutputGrad(), outDims_); + outputs.addArg(*getInputGrad(0), inDims_, ADD_TO); + backward_[0]->calc(inputs, outputs); +} +} // namespace paddle diff --git a/paddle/gserver/layers/CropLayer.h b/paddle/gserver/layers/CropLayer.h new file mode 100644 index 0000000000..3ce89707ca --- /dev/null +++ b/paddle/gserver/layers/CropLayer.h @@ -0,0 +1,46 @@ +/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. */ + +#pragma once + +#include "Layer.h" + +namespace paddle { + +/** + * \brief This layer crop inputs according to the specify corner and shape. + * The input and output is a 4D tensor. Cropping from the 2nd to + * the 4th dimenstion. + */ +class CropLayer : public Layer { +public: + explicit CropLayer(const LayerConfig& config) : Layer(config) {} + + ~CropLayer() {} + + bool init(const LayerMap& layerMap, + const ParameterMap& parameterMap) override; + void forward(PassType passType) override; + void backward(const UpdateCallback& callback = nullptr) override; + +protected: + void setOutDims(const size_t batchSize); + void setTensorDim(const size_t batchSize); + + std::vector crop_corner_; + std::vector crop_shape_; + TensorShape inDims_; + TensorShape outDims_; +}; +} // namespace paddle From 90ed2004a56a955dc6a1413e1d4c624caf31780b Mon Sep 17 00:00:00 2001 From: wanghaoshuang Date: Thu, 22 Jun 2017 16:54:07 +0800 Subject: [PATCH 02/37] Refine configure option of crop layer 1. change configure content to 'axis, offset, shape' 2. add an optional input to crop layer as cropping reference --- paddle/function/CropOp.cpp | 63 ++++++++++++--------------- paddle/function/CropOp.h | 15 ++----- paddle/function/CropOpGpu.cu | 32 ++++++++------ paddle/function/CropOpTest.cpp | 4 +- paddle/gserver/layers/CropLayer.cpp | 67 ++++++++++++++++++++++------- paddle/gserver/layers/CropLayer.h | 13 ++++-- 6 files changed, 114 insertions(+), 80 deletions(-) diff --git a/paddle/function/CropOp.cpp b/paddle/function/CropOp.cpp index 4d47d9c149..0d511ceef5 100644 --- a/paddle/function/CropOp.cpp +++ b/paddle/function/CropOp.cpp @@ -17,28 +17,27 @@ limitations under the License. */ #include "paddle/function/TensorShape.h" namespace paddle { -static inline CropConf castToCropConf(const FuncConfig& conf) { - return {conf.get>("crop_corner"), - conf.get>("crop_shape")}; -} - template <> void Crop(real* outputs, const real* inputs, const TensorShape inShape, - const CropConf& crop) { - int cCrop = crop.corner[0]; - int hCrop = crop.corner[1]; - int wCrop = crop.corner[2]; + const FuncConfig& conf) { + std::vector crop_corner = + conf.get>("crop_corner"); + std::vector crop_shape = + conf.get>("crop_shape"); + int cCrop = crop_corner[1]; + int hCrop = crop_corner[2]; + int wCrop = crop_corner[3]; int num = inShape[0]; int inC = inShape[1]; int inH = inShape[2]; int inW = inShape[3]; - int outC = crop.shape[0]; - int outH = crop.shape[1]; - int outW = crop.shape[2]; + int outC = crop_shape[1]; + int outH = crop_shape[2]; + int outW = crop_shape[3]; for (int n = 0; n < num; n++) { for (int c = 0; c < outC; c++) { @@ -55,19 +54,23 @@ template <> void CropGrad(const real* inGrad, real* outGrad, const TensorShape outShape, - const CropConf& crop) { - int cCrop = crop.corner[0]; - int hCrop = crop.corner[1]; - int wCrop = crop.corner[2]; + const FuncConfig& conf) { + std::vector crop_corner = + conf.get>("crop_corner"); + std::vector crop_shape = + conf.get>("crop_shape"); + int cCrop = crop_corner[1]; + int hCrop = crop_corner[2]; + int wCrop = crop_corner[3]; int num = outShape[0]; int outC = outShape[1]; int outH = outShape[2]; int outW = outShape[3]; - int inC = crop.shape[0]; - int inH = crop.shape[1]; - int inW = crop.shape[2]; + int inC = crop_shape[1]; + int inH = crop_shape[2]; + int inW = crop_shape[3]; for (int n = 0; n < num; n++) { for (int c = 0; c < inC; c++) { @@ -111,26 +114,21 @@ void CropGrad(const real* inGrad, template class CropFunc : public FunctionBase { public: - void init(const FuncConfig& config) override { - crop_ = castToCropConf(config); - } + void init(const FuncConfig& config) override { conf_ = config; } void calc(const BufferArgs& inputs, const BufferArgs& outputs) override { CHECK_EQ(1UL, inputs.size()); CHECK_EQ(1UL, outputs.size()); - CHECK_EQ(outputs[0].shape()[1], crop_.shape[0]); - CHECK_EQ(outputs[0].shape()[2], crop_.shape[1]); - CHECK_EQ(outputs[0].shape()[3], crop_.shape[2]); CHECK_EQ(outputs[0].getArgType(), ASSIGN_TO); TensorShape inShape = inputs[0].shape(); Crop( - outputs[0].data(), inputs[0].data(), inShape, crop_); + outputs[0].data(), inputs[0].data(), inShape, conf_); } private: - CropConf crop_; + FuncConfig conf_; }; /** @@ -145,26 +143,21 @@ private: template class CropGradFunc : public FunctionBase { public: - void init(const FuncConfig& config) override { - crop_ = castToCropConf(config); - } + void init(const FuncConfig& config) override { conf_ = config; } void calc(const BufferArgs& inputs, const BufferArgs& outputs) override { CHECK_EQ(1UL, inputs.size()); CHECK_EQ(1UL, outputs.size()); - CHECK_EQ(inputs[0].shape()[1], crop_.shape[0]); - CHECK_EQ(inputs[0].shape()[2], crop_.shape[1]); - CHECK_EQ(inputs[0].shape()[3], crop_.shape[2]); CHECK_EQ(outputs[0].getArgType(), ASSIGN_TO); TensorShape outShape = outputs[0].shape(); CropGrad( - inputs[0].data(), outputs[0].data(), outShape, crop_); + inputs[0].data(), outputs[0].data(), outShape, conf_); } private: - CropConf crop_; + FuncConfig conf_; }; REGISTER_TYPED_FUNC(Crop, CPU, CropFunc); diff --git a/paddle/function/CropOp.h b/paddle/function/CropOp.h index 78a55bd43e..71e8c4c00e 100644 --- a/paddle/function/CropOp.h +++ b/paddle/function/CropOp.h @@ -18,13 +18,6 @@ limitations under the License. */ namespace paddle { -struct CropConf { - /// The upper left corner of croped result - std::vector corner; - /// The shape of croped result - std::vector shape; -}; - /** * \brief This funtion crops inputs according to the specify start point and *shape. @@ -32,13 +25,13 @@ struct CropConf { * \param[out] outputs save results. * \param[in] inputs input data. * \param[in] inShape the shape of input tensor. - * \param[in] crop the cropping config + * \param[in] conf the cropping config */ template void Crop(real* outputs, const real* inputs, const TensorShape inShape, - const CropConf& crop); + const FuncConfig& conf); /** * \brief Cropping operation backward. @@ -46,11 +39,11 @@ void Crop(real* outputs, * \param[out] inGrad gradients of previous layer * \param[in] outGrad output gradient * \param[in] inShape the shape of input tensor. - * \param[in] crop the cropping config + * \param[in] conf the cropping config */ template void CropGrad(const real* inGrad, real* outGrad, const TensorShape inShape, - const CropConf& crop); + const FuncConfig& conf); } // namespace paddle diff --git a/paddle/function/CropOpGpu.cu b/paddle/function/CropOpGpu.cu index f7d7d03abd..cadb58b6e9 100644 --- a/paddle/function/CropOpGpu.cu +++ b/paddle/function/CropOpGpu.cu @@ -37,19 +37,21 @@ template <> void Crop(real* outputs, const real* inputs, const TensorShape inShape, - const CropConf& crop) { - int cropC = crop.corner[0]; - int cropH = crop.corner[1]; - int cropW = crop.corner[2]; + const FuncConfig& conf) { + std::vector crop_corner = conf.get>("crop_corner"); + std::vector crop_shape = conf.get>("crop_shape"); + int cropC = crop_corner[1]; + int cropH = crop_corner[2]; + int cropW = crop_corner[3]; int num = inShape[0]; int inC = inShape[1]; int inH = inShape[2]; int inW = inShape[3]; - int outC = crop.shape[0]; - int outH = crop.shape[1]; - int outW = crop.shape[2]; + int outC = crop_shape[1]; + int outH = crop_shape[2]; + int outW = crop_shape[3]; size_t nth = num * outC * outH * outW; int blockSize = 1024; @@ -82,19 +84,21 @@ template <> void CropGrad(const real* inGrad, real* outGrad, const TensorShape outShape, - const CropConf& crop) { - int cropC = crop.corner[0]; - int cropH = crop.corner[1]; - int cropW = crop.corner[2]; + const FuncConfig& conf) { + std::vector crop_corner = conf.get>("crop_corner"); + std::vector crop_shape = conf.get>("crop_shape"); + int cropC = crop_corner[1]; + int cropH = crop_corner[2]; + int cropW = crop_corner[3]; int num = outShape[0]; int outC = outShape[1]; int outH = outShape[2]; int outW = outShape[3]; - int inC = crop.shape[0]; - int inH = crop.shape[1]; - int inW = crop.shape[2]; + int inC = crop_shape[1]; + int inH = crop_shape[2]; + int inW = crop_shape[3]; size_t nth = num * inC * inH * inW; int blockSize = 1024; diff --git a/paddle/function/CropOpTest.cpp b/paddle/function/CropOpTest.cpp index 62b4bd9fde..c331a70d1f 100644 --- a/paddle/function/CropOpTest.cpp +++ b/paddle/function/CropOpTest.cpp @@ -28,8 +28,8 @@ TEST(Crop, real) { FunctionCompare compare( test_grad ? "CropGrad" : "Crop", FuncConfig() - .set>("crop_corner", {1, 1, 1}) - .set>("crop_shape", {2, 3, 3})); + .set>("crop_corner", {0, 1, 1, 1}) + .set>("crop_shape", {0, 2, 3, 3})); TensorShape inDims{numSamples, channels, imgSizeH, imgSizeW}; TensorShape outDims{numSamples, 2, 3, 3}; compare.addInputs( diff --git a/paddle/gserver/layers/CropLayer.cpp b/paddle/gserver/layers/CropLayer.cpp index ab23d4617e..198ceffb46 100644 --- a/paddle/gserver/layers/CropLayer.cpp +++ b/paddle/gserver/layers/CropLayer.cpp @@ -25,20 +25,57 @@ bool CropLayer::init(const LayerMap& layerMap, Layer::init(layerMap, parameterMap); auto& crop_conf = config_.inputs(0).crop_conf(); - auto& img_conf = crop_conf.image_conf(); - CHECK_EQ(config_.inputs_size(), 1); - inDims_ = TensorShape( - {0, - img_conf.channels(), - img_conf.has_img_size_y() ? img_conf.img_size_y() : img_conf.img_size(), - img_conf.img_size()}); - - crop_corner_ = {crop_conf.crop_corner(0), - crop_conf.crop_corner(1), - crop_conf.crop_corner(2)}; - crop_shape_ = {crop_conf.crop_shape(0), - crop_conf.crop_shape(1), - crop_conf.crop_shape(2)}; + crop_axis_ = crop_conf.axis(); + for (int i = 0; i < crop_conf.offset_size(); i++) { + crop_offsets_[i] = crop_conf.offset(i); + } + + // 1. get input_0 shape + auto& input0_img_conf = config_.inputs(0).image_conf(); + inDims_ = TensorShape({0, + input0_img_conf.channels(), + input0_img_conf.has_img_size_y() + ? input0_img_conf.img_size_y() + : input0_img_conf.img_size(), + input0_img_conf.img_size()}); + + // 2. get output shape from input_1 or crop shap conf + if (config_.inputs_size() == 2) { + auto& input1_img_conf = config_.inputs(1).image_conf(); + targetDims_ = TensorShape({0, + input1_img_conf.channels(), + input1_img_conf.has_img_size_y() + ? input1_img_conf.img_size_y() + : input1_img_conf.img_size(), + input1_img_conf.img_size()}); + } else { + targetDims_ = TensorShape({crop_conf.shape(0), + crop_conf.shape(1), + crop_conf.shape(2), + crop_conf.shape(3)}); + } + + // 3. get final crop shape + int dimSize = 4; + for (int i = 0; i < dimSize; i++) { + if (i >= crop_axis_) { + crop_shape_[i] = targetDims_[i]; + } else { + crop_shape_[i] = inDims_[i]; + } + } + + // 4. get final crop corner + crop_corner_ = {0, 0, 0, 0}; + for (int i = 0; i < dimSize; i++) { + if (i >= crop_axis_) { + if (crop_offsets_.size() > 1) { + crop_corner_[i] = crop_offsets_[i - crop_axis_]; + } else { + crop_corner_[i] = crop_offsets_[0]; + } + } + } outDims_ = TensorShape(4); setOutDims(0); @@ -58,7 +95,7 @@ bool CropLayer::init(const LayerMap& layerMap, } void CropLayer::setOutDims(const size_t batchSize) { - outDims_.reshape({batchSize, crop_shape_[0], crop_shape_[1], crop_shape_[2]}); + outDims_.reshape({batchSize, crop_shape_[1], crop_shape_[2], crop_shape_[3]}); } void CropLayer::setTensorDim(const size_t batchSize) { diff --git a/paddle/gserver/layers/CropLayer.h b/paddle/gserver/layers/CropLayer.h index 3ce89707ca..23cede1c3f 100644 --- a/paddle/gserver/layers/CropLayer.h +++ b/paddle/gserver/layers/CropLayer.h @@ -19,9 +19,13 @@ limitations under the License. */ namespace paddle { /** - * \brief This layer crop inputs according to the specify corner and shape. - * The input and output is a 4D tensor. Cropping from the 2nd to - * the 4th dimenstion. + * \brief This layer crop input according to the specify conf. + * input_0: input to be cropped + * input_1: optional reference input + * axis: start dimension to be croped + * offset: offset of cropping in each dimension + * shape: if reference input layer was not setted, + * crop input as this shape conf */ class CropLayer : public Layer { public: @@ -38,9 +42,12 @@ protected: void setOutDims(const size_t batchSize); void setTensorDim(const size_t batchSize); + int32_t crop_axis_; + std::vector crop_offsets_; std::vector crop_corner_; std::vector crop_shape_; TensorShape inDims_; + TensorShape targetDims_; TensorShape outDims_; }; } // namespace paddle From 701827f59cb5727676818c2fffb2b07766528436 Mon Sep 17 00:00:00 2001 From: wanghaoshuang Date: Wed, 5 Jul 2017 00:53:32 +0800 Subject: [PATCH 03/37] Add grad test and python wrapper for crop layer --- paddle/function/CropOp.cpp | 2 +- paddle/function/CropOpTest.cpp | 2 +- paddle/gserver/layers/CropLayer.cpp | 23 ++++---- paddle/gserver/tests/CMakeLists.txt | 2 +- paddle/gserver/tests/test_LayerGrad.cpp | 28 ++++++++++ proto/ModelConfig.proto | 8 ++- python/paddle/trainer/config_parser.py | 45 ++++++++++++++++ .../paddle/trainer_config_helpers/layers.py | 54 +++++++++++++++++++ 8 files changed, 147 insertions(+), 17 deletions(-) diff --git a/paddle/function/CropOp.cpp b/paddle/function/CropOp.cpp index 0d511ceef5..1bb194a9bc 100644 --- a/paddle/function/CropOp.cpp +++ b/paddle/function/CropOp.cpp @@ -148,7 +148,7 @@ public: void calc(const BufferArgs& inputs, const BufferArgs& outputs) override { CHECK_EQ(1UL, inputs.size()); CHECK_EQ(1UL, outputs.size()); - CHECK_EQ(outputs[0].getArgType(), ASSIGN_TO); + CHECK_EQ(outputs[0].getArgType(), ADD_TO); TensorShape outShape = outputs[0].shape(); diff --git a/paddle/function/CropOpTest.cpp b/paddle/function/CropOpTest.cpp index c331a70d1f..71d9b05812 100644 --- a/paddle/function/CropOpTest.cpp +++ b/paddle/function/CropOpTest.cpp @@ -25,7 +25,7 @@ TEST(Crop, real) { VLOG(3) << " numSamples=" << numSamples << " channels=" << channels << " imgSizeH=" << imgSizeH << " imgSizeW=" << imgSizeW; for (bool test_grad : {false, true}) { - FunctionCompare compare( + CpuGpuFuncCompare compare( test_grad ? "CropGrad" : "Crop", FuncConfig() .set>("crop_corner", {0, 1, 1, 1}) diff --git a/paddle/gserver/layers/CropLayer.cpp b/paddle/gserver/layers/CropLayer.cpp index 198ceffb46..b2fa17b400 100644 --- a/paddle/gserver/layers/CropLayer.cpp +++ b/paddle/gserver/layers/CropLayer.cpp @@ -14,7 +14,6 @@ limitations under the License. */ #include "CropLayer.h" #include "paddle/utils/Stat.h" - namespace paddle { REGISTER_LAYER(crop, CropLayer); @@ -24,10 +23,9 @@ bool CropLayer::init(const LayerMap& layerMap, /* Initialize the basic parent class */ Layer::init(layerMap, parameterMap); - auto& crop_conf = config_.inputs(0).crop_conf(); - crop_axis_ = crop_conf.axis(); - for (int i = 0; i < crop_conf.offset_size(); i++) { - crop_offsets_[i] = crop_conf.offset(i); + crop_axis_ = config_.axis(); + for (int i = 0; i < config_.offset_size(); i++) { + crop_offsets_.push_back(config_.offset(i)); } // 1. get input_0 shape @@ -38,7 +36,6 @@ bool CropLayer::init(const LayerMap& layerMap, ? input0_img_conf.img_size_y() : input0_img_conf.img_size(), input0_img_conf.img_size()}); - // 2. get output shape from input_1 or crop shap conf if (config_.inputs_size() == 2) { auto& input1_img_conf = config_.inputs(1).image_conf(); @@ -49,19 +46,19 @@ bool CropLayer::init(const LayerMap& layerMap, : input1_img_conf.img_size(), input1_img_conf.img_size()}); } else { - targetDims_ = TensorShape({crop_conf.shape(0), - crop_conf.shape(1), - crop_conf.shape(2), - crop_conf.shape(3)}); + targetDims_ = TensorShape({config_.shape(0), + config_.shape(1), + config_.shape(2), + config_.shape(3)}); } // 3. get final crop shape int dimSize = 4; for (int i = 0; i < dimSize; i++) { if (i >= crop_axis_) { - crop_shape_[i] = targetDims_[i]; + crop_shape_.push_back(targetDims_[i]); } else { - crop_shape_[i] = inDims_[i]; + crop_shape_.push_back(inDims_[i]); } } @@ -99,7 +96,7 @@ void CropLayer::setOutDims(const size_t batchSize) { } void CropLayer::setTensorDim(const size_t batchSize) { - CHECK_EQ(static_cast(inputLayers_.size()), 1); + CHECK_EQ(static_cast(inputLayers_.size()), 2); inDims_.setDim(0, batchSize); int h = inputLayers_[0]->getOutput().getFrameHeight(); if (h != 0) inDims_.setDim(2, h); diff --git a/paddle/gserver/tests/CMakeLists.txt b/paddle/gserver/tests/CMakeLists.txt index 92f6cbcfe5..a43adc7ce7 100644 --- a/paddle/gserver/tests/CMakeLists.txt +++ b/paddle/gserver/tests/CMakeLists.txt @@ -56,7 +56,7 @@ add_test(NAME test_DetectionOutput add_unittest_without_exec(test_ConvUnify test_ConvUnify.cpp LayerGradUtil.cpp) - + add_test(NAME test_ConvUnify COMMAND test_ConvUnify) ################# test_BatchNorm ####################### diff --git a/paddle/gserver/tests/test_LayerGrad.cpp b/paddle/gserver/tests/test_LayerGrad.cpp index 59d1e9273d..20a83d7aa1 100644 --- a/paddle/gserver/tests/test_LayerGrad.cpp +++ b/paddle/gserver/tests/test_LayerGrad.cpp @@ -1792,6 +1792,34 @@ TEST(Layer, RowConvLayer) { } } +TEST(Layer, CropLayer) { + TestConfig config; + // config input_0 + config.inputDefs.push_back({INPUT_DATA, "layer_0", 1024, 0}); + LayerInputConfig* input = config.layerConfig.add_inputs(); + ImageConfig* img = input->mutable_image_conf(); + img->set_channels(4); + img->set_img_size(16); + config.layerConfig.set_axis(2); + config.layerConfig.add_offset(0); + config.layerConfig.add_offset(0); + + // config input_1 + config.inputDefs.push_back({INPUT_DATA, "layer_1", 128, 0}); + input = config.layerConfig.add_inputs(); + img = input->mutable_image_conf(); + img->set_channels(2); + img->set_img_size(8); + + // config crop layer + config.layerConfig.set_type("crop"); + config.layerConfig.set_name("cropLayer"); + + for (auto useGpu : {false, true}) { + testLayerGrad(config, "crop", 100, false, useGpu, false); + } +} + int main(int argc, char** argv) { testing::InitGoogleTest(&argc, argv); initMain(argc, argv); diff --git a/proto/ModelConfig.proto b/proto/ModelConfig.proto index 37cd16c798..83f72c137b 100644 --- a/proto/ModelConfig.proto +++ b/proto/ModelConfig.proto @@ -472,10 +472,16 @@ message LayerConfig { // blank label used in ctc loss optional uint32 blank = 52 [default = 0]; - // stride parameter for seqlastins layer, AverageLayer, MaxLayer, which + // stride parameter for seqlastins layer, AverageLayer, MaxLayer, which // controls the scope of pooling operation. can be set > 0. // leave empty or set to -1 to disable this stride pooling. optional int32 seq_pool_stride = 53 [default = -1]; + + // for crop layer + optional int32 axis = 54 [default = 2]; + repeated uint32 offset = 55; + repeated uint32 shape = 56; + } message EvaluatorConfig { diff --git a/python/paddle/trainer/config_parser.py b/python/paddle/trainer/config_parser.py index 370529ed97..8c529fdfd3 100644 --- a/python/paddle/trainer/config_parser.py +++ b/python/paddle/trainer/config_parser.py @@ -1986,6 +1986,51 @@ class PadLayer(LayerBase): self.config.size = out_ch * out_h * out_w +@config_layer('crop') +class CropLayer(LayerBase): + def __init__(self, inputs, axis, offset, shape, name, **xargs): + super(CropLayer, self).__init__(name, 'crop', 0, inputs=inputs, **xargs) + self.conf.axis = axis + self.conf.axis = offset + self.conf.axis = shape + + crop = self.inputs[0].crop + self.config.inputs[0].crop_conf.axis = crop.axis + self.config.inputs[0].crop_conf.offset.extend(crop.offset) + self.config.inputs[0].crop_conf.shape.extend(crop.shape) + + # get channel, width and height from input_0 layer + input_layer = self.get_input_layer(0) + image_conf = self.config.inputs[0].image_conf + image_conf.img_size = input_layer.width + image_conf.img_size_y = input_layer.height + image_conf.channels = input_layer.size / (input_layer.width * + input_layer.height) + out_ch = image_conf.channels + out_h = image_conf.img_size + out_w = image_conf.img_size_y + if len(self.inputs) == 2: + # get channels, width and height from input_1 layer + input_layer = self.get_input_layer(1) + image_conf = self.config.inputs[1].image_conf + image_conf.img_size = input_layer.width + image_conf.img_size_y = input_layer.height + image_conf.channels = input_layer.size / (input_layer.width * + input_layer.height) + out_ch = image_conf.channels + out_h = image_conf.img_size_y + out_w = image_conf.img_size + else: + # set channels, width and heigth of current layer + if len(shape) > 2: + out_ch = shape[-3] + if len(shape) > 1: + out_h = shape[-2] + if len(shape) > 0: + out_w = shape[-1] + self.set_cnn_layer(name, out_h, out_w, out_ch) + + @config_layer('batch_norm') class BatchNormLayer(LayerBase): layer_type = 'batch_norm' diff --git a/python/paddle/trainer_config_helpers/layers.py b/python/paddle/trainer_config_helpers/layers.py index 206de1f8e1..f9de086cba 100755 --- a/python/paddle/trainer_config_helpers/layers.py +++ b/python/paddle/trainer_config_helpers/layers.py @@ -217,6 +217,7 @@ class LayerType(object): SMOOTH_L1 = 'smooth_l1' PRELU = 'prelu' + CROP_LAYER = 'crop' @staticmethod def is_layer_type(type_name): @@ -5853,3 +5854,56 @@ def prelu_layer(input, layer_type=LayerType.PRELU, parents=input, size=l.config.size) + + +@wrap_name_default() +@layer_support() +def crop_layer(input, axis, offset, shape=None, name=None, layer_attr=None): + """ + The crop layer crop images by offset and shape. User can set crop shape by + args 'shape' explicitly or by reference input layer. + + + The example usage is: + + .. code-block:: python + + crop = crop_layer(input=[image_input, reference_input], axis=2, offset=[2, 3]) + + :param input: The input layer.If two inputs were setted, + the second input will be regarded as reference input + :type input: LayerOutput or Sequence + :param axis: start axis to be cropped. To image input layer: + - 0: batch size + - 1: channels + - 2: height + - 3: width + :type partial_sum: int + :param offset: The crop offset + :type offset: Sequence + :param shape: The shape to be cropped. Default is None. + :type shape: Sqquence | None + :param name: Name of this layer. + :type name: basestring + :return: LayerOutput object. + :rtype: LayerOutput + """ + if isinstance(input, LayerOutput): + input = [input] + elif isinstance(input, Projection): + input = [input] + else: + assert isinstance(input, collections.Sequence) + l = Layer( + inputs=[x.name for x in input], + axis=axis, + offset=offset, + shape=shape, + name=name, + type=LayerType.CROP_LAYER, + **ExtraLayerAttribute.to_kwargs(layer_attr)) + return LayerOutput( + name=name, + layer_type=LayerType.CROP_LAYER, + parents=input, + size=l.config.size) From cbd61c7719b148043f4b8a4f3feacca57c17f1ab Mon Sep 17 00:00:00 2001 From: wanghaoshuang Date: Wed, 5 Jul 2017 10:36:22 +0800 Subject: [PATCH 04/37] fix crop function test --- paddle/function/CropOpTest.cpp | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/paddle/function/CropOpTest.cpp b/paddle/function/CropOpTest.cpp index 71d9b05812..dcba972e10 100644 --- a/paddle/function/CropOpTest.cpp +++ b/paddle/function/CropOpTest.cpp @@ -34,8 +34,10 @@ TEST(Crop, real) { TensorShape outDims{numSamples, 2, 3, 3}; compare.addInputs( BufferArg(VALUE_TYPE_FLOAT, test_grad ? outDims : inDims)); - compare.addOutputs(BufferArg( - VALUE_TYPE_FLOAT, test_grad ? inDims : outDims, ASSIGN_TO)); + compare.addOutputs(BufferArg(VALUE_TYPE_FLOAT, + test_grad ? inDims : outDims, + tes_grad ? ADD_TO : ASSIGN_TO), + test_grad ? ADD_TO : ASSIGN_TO); compare.run(); } } From e10040ca8a9b4b9d9eb8275cab468edefd94caf9 Mon Sep 17 00:00:00 2001 From: wanghaoshuang Date: Fri, 16 Jun 2017 19:02:46 +0800 Subject: [PATCH 05/37] add crop layer --- paddle/function/CMakeLists.txt | 1 + paddle/function/CropOp.cpp | 177 ++++++++++++++++++++++++++++ paddle/function/CropOp.h | 56 +++++++++ paddle/function/CropOpGpu.cu | 109 +++++++++++++++++ paddle/function/CropOpTest.cpp | 47 ++++++++ paddle/gserver/layers/CropLayer.cpp | 101 ++++++++++++++++ paddle/gserver/layers/CropLayer.h | 46 ++++++++ 7 files changed, 537 insertions(+) create mode 100644 paddle/function/CropOp.cpp create mode 100644 paddle/function/CropOp.h create mode 100644 paddle/function/CropOpGpu.cu create mode 100644 paddle/function/CropOpTest.cpp create mode 100644 paddle/gserver/layers/CropLayer.cpp create mode 100644 paddle/gserver/layers/CropLayer.h diff --git a/paddle/function/CMakeLists.txt b/paddle/function/CMakeLists.txt index 1518a8a654..f19a1eb777 100644 --- a/paddle/function/CMakeLists.txt +++ b/paddle/function/CMakeLists.txt @@ -37,6 +37,7 @@ if(WITH_GPU) add_simple_unittest(MulOpTest) add_simple_unittest(CosSimOpTest) add_simple_unittest(RowConvOpTest) + add_simple_unittest(CropOpTest) endif() add_simple_unittest(ConvOpTest) diff --git a/paddle/function/CropOp.cpp b/paddle/function/CropOp.cpp new file mode 100644 index 0000000000..4d47d9c149 --- /dev/null +++ b/paddle/function/CropOp.cpp @@ -0,0 +1,177 @@ +/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. */ + +#include "CropOp.h" +#include "paddle/math/Vector.h" +#include "paddle/function/TensorShape.h" +namespace paddle { + +static inline CropConf castToCropConf(const FuncConfig& conf) { + return {conf.get>("crop_corner"), + conf.get>("crop_shape")}; +} + +template <> +void Crop(real* outputs, + const real* inputs, + const TensorShape inShape, + const CropConf& crop) { + int cCrop = crop.corner[0]; + int hCrop = crop.corner[1]; + int wCrop = crop.corner[2]; + + int num = inShape[0]; + int inC = inShape[1]; + int inH = inShape[2]; + int inW = inShape[3]; + + int outC = crop.shape[0]; + int outH = crop.shape[1]; + int outW = crop.shape[2]; + + for (int n = 0; n < num; n++) { + for (int c = 0; c < outC; c++) { + for (int h = 0; h < outH; h++) { + int outoff = ((n * outC + c) * outH + h) * outW; + int inoff = ((n * inC + c + cCrop) * inH + h + hCrop) * inW + wCrop; + memcpy(outputs + outoff, inputs + inoff, outW * sizeof(real)); + } + } + } +} + +template <> +void CropGrad(const real* inGrad, + real* outGrad, + const TensorShape outShape, + const CropConf& crop) { + int cCrop = crop.corner[0]; + int hCrop = crop.corner[1]; + int wCrop = crop.corner[2]; + + int num = outShape[0]; + int outC = outShape[1]; + int outH = outShape[2]; + int outW = outShape[3]; + + int inC = crop.shape[0]; + int inH = crop.shape[1]; + int inW = crop.shape[2]; + + for (int n = 0; n < num; n++) { + for (int c = 0; c < inC; c++) { + for (int h = 0; h < inH; h++) { + int outoff = ((n * outC + c + cCrop) * outH + h + hCrop) * outW + wCrop; + int inoff = ((n * inC + c) * inH + h) * inW; + CpuVector inG = CpuVector(inW, const_cast(inGrad + inoff)); + CpuVector outG = CpuVector(inW, outGrad + outoff); + outG += inG; + } + } + } +} + +/** + * \brief Crop input according to the specify corner and shape. + * The input and output is a 4D tensor. In CropFunc, we only + * crop the 2nd to 4th dimension. + * + * Argument in this Function: + * \param pad_ A struct object contains the cropping corner and shape. + * \param inputs A 4D tensor, only one input. + * \param outputs A 4D tensor, the output value after cropping. + * + * For example, + * Input(2,2,2,3) = [ + * [ [[1,2,3], [3,4,5]], + * [[2,3,5], [1,6,7]] ], + * [ [[4,3,1], [1,8,7]], + * [[3,8,9], [2,3,5]] ] + * ] # the input shape is (2,2,2,3) + * + * pad_: if corner = (0,1,1) and crop_shape = (2,1,2) + * Output(2,2,1,2) = [ + * [ [[4,5]], + * [[6,7]] ], + * [ [[8,7]], + * [[3,5]] ] + * ] # the input shape is (2,2,2,3) + */ +template +class CropFunc : public FunctionBase { +public: + void init(const FuncConfig& config) override { + crop_ = castToCropConf(config); + } + + void calc(const BufferArgs& inputs, const BufferArgs& outputs) override { + CHECK_EQ(1UL, inputs.size()); + CHECK_EQ(1UL, outputs.size()); + CHECK_EQ(outputs[0].shape()[1], crop_.shape[0]); + CHECK_EQ(outputs[0].shape()[2], crop_.shape[1]); + CHECK_EQ(outputs[0].shape()[3], crop_.shape[2]); + CHECK_EQ(outputs[0].getArgType(), ASSIGN_TO); + + TensorShape inShape = inputs[0].shape(); + + Crop( + outputs[0].data(), inputs[0].data(), inShape, crop_); + } + +private: + CropConf crop_; +}; + +/** + * \brief The backward propagation of cropping Function. + * + * Argument in this Function: + * \param crop_ The same meaning as it in CropFunc. + * \param inputs The gradient with respect to the output value of CropFunc. + * \param outputs The gradient with respect to the input value of CropFunc. + */ + +template +class CropGradFunc : public FunctionBase { +public: + void init(const FuncConfig& config) override { + crop_ = castToCropConf(config); + } + + void calc(const BufferArgs& inputs, const BufferArgs& outputs) override { + CHECK_EQ(1UL, inputs.size()); + CHECK_EQ(1UL, outputs.size()); + CHECK_EQ(inputs[0].shape()[1], crop_.shape[0]); + CHECK_EQ(inputs[0].shape()[2], crop_.shape[1]); + CHECK_EQ(inputs[0].shape()[3], crop_.shape[2]); + CHECK_EQ(outputs[0].getArgType(), ASSIGN_TO); + + TensorShape outShape = outputs[0].shape(); + + CropGrad( + inputs[0].data(), outputs[0].data(), outShape, crop_); + } + +private: + CropConf crop_; +}; + +REGISTER_TYPED_FUNC(Crop, CPU, CropFunc); +REGISTER_TYPED_FUNC(CropGrad, CPU, CropGradFunc); +#ifndef PADDLE_ONLY_CPU +REGISTER_TYPED_FUNC(Crop, GPU, CropFunc); +REGISTER_TYPED_FUNC(CropGrad, GPU, CropGradFunc); +#endif + +} // namespace paddle diff --git a/paddle/function/CropOp.h b/paddle/function/CropOp.h new file mode 100644 index 0000000000..78a55bd43e --- /dev/null +++ b/paddle/function/CropOp.h @@ -0,0 +1,56 @@ +/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. */ + +#pragma once + +#include "Function.h" + +namespace paddle { + +struct CropConf { + /// The upper left corner of croped result + std::vector corner; + /// The shape of croped result + std::vector shape; +}; + +/** + * \brief This funtion crops inputs according to the specify start point and + *shape. + * + * \param[out] outputs save results. + * \param[in] inputs input data. + * \param[in] inShape the shape of input tensor. + * \param[in] crop the cropping config + */ +template +void Crop(real* outputs, + const real* inputs, + const TensorShape inShape, + const CropConf& crop); + +/** + * \brief Cropping operation backward. + * + * \param[out] inGrad gradients of previous layer + * \param[in] outGrad output gradient + * \param[in] inShape the shape of input tensor. + * \param[in] crop the cropping config + */ +template +void CropGrad(const real* inGrad, + real* outGrad, + const TensorShape inShape, + const CropConf& crop); +} // namespace paddle diff --git a/paddle/function/CropOpGpu.cu b/paddle/function/CropOpGpu.cu new file mode 100644 index 0000000000..f7d7d03abd --- /dev/null +++ b/paddle/function/CropOpGpu.cu @@ -0,0 +1,109 @@ +/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. */ + +#include "hl_base.h" +#include "CropOp.h" + +namespace paddle { + +__global__ void KeCrop(real* outputs, const real* inputs, + int inC, int inH, int inW, + int cropC, int cropH, int cropW, + int outC, int outH, int outW, int nthreads) { + const int idx = threadIdx.x + blockIdx.x * blockDim.x; + if (idx < nthreads) { + const int w = idx % outW; + const int h = (idx / outW) % outH; + const int c = (idx / outW / outH) % outC; + const int n = idx / outW / outH / outC; + + const int off = ((n * inC + c + cropC) * inH + h + cropH) * inW + cropW + w; + outputs[idx] = inputs[off]; + } +} + +template <> +void Crop(real* outputs, + const real* inputs, + const TensorShape inShape, + const CropConf& crop) { + int cropC = crop.corner[0]; + int cropH = crop.corner[1]; + int cropW = crop.corner[2]; + + int num = inShape[0]; + int inC = inShape[1]; + int inH = inShape[2]; + int inW = inShape[3]; + + int outC = crop.shape[0]; + int outH = crop.shape[1]; + int outW = crop.shape[2]; + + size_t nth = num * outC * outH * outW; + int blockSize = 1024; + int gridSize = (nth + blockSize - 1) / blockSize; + + KeCrop<<>> + (outputs, inputs, inC, inH, inW, cropC, cropH, cropW, + outC, outH, outW, nth); + CHECK_SYNC("Crop"); +} + +__global__ void KeCropDiff(const real* inGrad, real* outGrad, + int inC, int inH, int inW, + int cropC, int cropH, int cropW, + int outC, int outH, int outW, int nthreads) { + const int idx = threadIdx.x + blockIdx.x * blockDim.x; + if (idx < nthreads) { + const int w = idx % inW; + const int h = (idx / inW) % inH; + const int c = (idx / inW / inH) % inC; + const int n = idx / inW / inH / inC; + + const int off = ((n * outC + c + cropC) * outH + h + cropH) * outW + cropW + w; + + outGrad[off] += inGrad[idx]; + } +} + +template <> +void CropGrad(const real* inGrad, + real* outGrad, + const TensorShape outShape, + const CropConf& crop) { + int cropC = crop.corner[0]; + int cropH = crop.corner[1]; + int cropW = crop.corner[2]; + + int num = outShape[0]; + int outC = outShape[1]; + int outH = outShape[2]; + int outW = outShape[3]; + + int inC = crop.shape[0]; + int inH = crop.shape[1]; + int inW = crop.shape[2]; + + size_t nth = num * inC * inH * inW; + int blockSize = 1024; + int gridSize = (nth + blockSize - 1) / blockSize; + + KeCropDiff <<>> + (inGrad, outGrad, inC, inH, inW, cropC, cropH, cropW, + outC, outH, outW, nth); + CHECK_SYNC("CropGrad"); +} + +} // namespace paddle diff --git a/paddle/function/CropOpTest.cpp b/paddle/function/CropOpTest.cpp new file mode 100644 index 0000000000..62b4bd9fde --- /dev/null +++ b/paddle/function/CropOpTest.cpp @@ -0,0 +1,47 @@ +/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. */ + +#include +#include "FunctionTest.h" + +namespace paddle { + +TEST(Crop, real) { + for (size_t numSamples : {5, 32}) { + for (size_t channels : {5, 5, 32}) { + for (size_t imgSizeH : {5, 33, 100}) { + for (size_t imgSizeW : {5, 32, 96}) { + VLOG(3) << " numSamples=" << numSamples << " channels=" << channels + << " imgSizeH=" << imgSizeH << " imgSizeW=" << imgSizeW; + for (bool test_grad : {false, true}) { + FunctionCompare compare( + test_grad ? "CropGrad" : "Crop", + FuncConfig() + .set>("crop_corner", {1, 1, 1}) + .set>("crop_shape", {2, 3, 3})); + TensorShape inDims{numSamples, channels, imgSizeH, imgSizeW}; + TensorShape outDims{numSamples, 2, 3, 3}; + compare.addInputs( + BufferArg(VALUE_TYPE_FLOAT, test_grad ? outDims : inDims)); + compare.addOutputs(BufferArg( + VALUE_TYPE_FLOAT, test_grad ? inDims : outDims, ASSIGN_TO)); + compare.run(); + } + } + } + } + } +} + +} // namespace paddle diff --git a/paddle/gserver/layers/CropLayer.cpp b/paddle/gserver/layers/CropLayer.cpp new file mode 100644 index 0000000000..ab23d4617e --- /dev/null +++ b/paddle/gserver/layers/CropLayer.cpp @@ -0,0 +1,101 @@ +/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. */ + +#include "CropLayer.h" +#include "paddle/utils/Stat.h" + +namespace paddle { + +REGISTER_LAYER(crop, CropLayer); + +bool CropLayer::init(const LayerMap& layerMap, + const ParameterMap& parameterMap) { + /* Initialize the basic parent class */ + Layer::init(layerMap, parameterMap); + + auto& crop_conf = config_.inputs(0).crop_conf(); + auto& img_conf = crop_conf.image_conf(); + CHECK_EQ(config_.inputs_size(), 1); + inDims_ = TensorShape( + {0, + img_conf.channels(), + img_conf.has_img_size_y() ? img_conf.img_size_y() : img_conf.img_size(), + img_conf.img_size()}); + + crop_corner_ = {crop_conf.crop_corner(0), + crop_conf.crop_corner(1), + crop_conf.crop_corner(2)}; + crop_shape_ = {crop_conf.crop_shape(0), + crop_conf.crop_shape(1), + crop_conf.crop_shape(2)}; + + outDims_ = TensorShape(4); + setOutDims(0); + + createFunction(forward_, + "Crop", + FuncConfig() + .set("crop_corner", crop_corner_) + .set("crop_shape", crop_shape_)); + createFunction(backward_, + "CropGrad", + FuncConfig() + .set("crop_corner", crop_corner_) + .set("crop_shape", crop_shape_)); + + return true; +} + +void CropLayer::setOutDims(const size_t batchSize) { + outDims_.reshape({batchSize, crop_shape_[0], crop_shape_[1], crop_shape_[2]}); +} + +void CropLayer::setTensorDim(const size_t batchSize) { + CHECK_EQ(static_cast(inputLayers_.size()), 1); + inDims_.setDim(0, batchSize); + int h = inputLayers_[0]->getOutput().getFrameHeight(); + if (h != 0) inDims_.setDim(2, h); + int w = inputLayers_[0]->getOutput().getFrameWidth(); + if (w != 0) inDims_.setDim(3, w); + setOutDims(batchSize); +} + +void CropLayer::forward(PassType passType) { + Layer::forward(passType); + MatrixPtr input = inputLayers_[0]->getOutputValue(); + size_t batchSize = input->getHeight(); + setTensorDim(batchSize); + int size = outDims_[1] * outDims_[2] * outDims_[3]; + resetOutput(batchSize, size); + MatrixPtr outV = getOutputValue(); + REGISTER_TIMER_INFO("CropForward", getName().c_str()); + + BufferArgs inputs; + BufferArgs outputs; + inputs.addArg(*getInputValue(0), inDims_); + outputs.addArg(*getOutputValue(), outDims_, ASSIGN_TO); + forward_[0]->calc(inputs, outputs); +} + +void CropLayer::backward(const UpdateCallback& callback) { + (void)callback; + REGISTER_TIMER_INFO("CropBackward", getName().c_str()); + + BufferArgs inputs; + BufferArgs outputs; + inputs.addArg(*getOutputGrad(), outDims_); + outputs.addArg(*getInputGrad(0), inDims_, ADD_TO); + backward_[0]->calc(inputs, outputs); +} +} // namespace paddle diff --git a/paddle/gserver/layers/CropLayer.h b/paddle/gserver/layers/CropLayer.h new file mode 100644 index 0000000000..3ce89707ca --- /dev/null +++ b/paddle/gserver/layers/CropLayer.h @@ -0,0 +1,46 @@ +/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. */ + +#pragma once + +#include "Layer.h" + +namespace paddle { + +/** + * \brief This layer crop inputs according to the specify corner and shape. + * The input and output is a 4D tensor. Cropping from the 2nd to + * the 4th dimenstion. + */ +class CropLayer : public Layer { +public: + explicit CropLayer(const LayerConfig& config) : Layer(config) {} + + ~CropLayer() {} + + bool init(const LayerMap& layerMap, + const ParameterMap& parameterMap) override; + void forward(PassType passType) override; + void backward(const UpdateCallback& callback = nullptr) override; + +protected: + void setOutDims(const size_t batchSize); + void setTensorDim(const size_t batchSize); + + std::vector crop_corner_; + std::vector crop_shape_; + TensorShape inDims_; + TensorShape outDims_; +}; +} // namespace paddle From d1d70ec8319a55964231f2e925ef8cb881c94497 Mon Sep 17 00:00:00 2001 From: wanghaoshuang Date: Thu, 22 Jun 2017 16:54:07 +0800 Subject: [PATCH 06/37] Refine configure option of crop layer 1. change configure content to 'axis, offset, shape' 2. add an optional input to crop layer as cropping reference --- paddle/function/CropOp.cpp | 63 ++++++++++++--------------- paddle/function/CropOp.h | 15 ++----- paddle/function/CropOpGpu.cu | 32 ++++++++------ paddle/function/CropOpTest.cpp | 4 +- paddle/gserver/layers/CropLayer.cpp | 67 ++++++++++++++++++++++------- paddle/gserver/layers/CropLayer.h | 13 ++++-- 6 files changed, 114 insertions(+), 80 deletions(-) diff --git a/paddle/function/CropOp.cpp b/paddle/function/CropOp.cpp index 4d47d9c149..0d511ceef5 100644 --- a/paddle/function/CropOp.cpp +++ b/paddle/function/CropOp.cpp @@ -17,28 +17,27 @@ limitations under the License. */ #include "paddle/function/TensorShape.h" namespace paddle { -static inline CropConf castToCropConf(const FuncConfig& conf) { - return {conf.get>("crop_corner"), - conf.get>("crop_shape")}; -} - template <> void Crop(real* outputs, const real* inputs, const TensorShape inShape, - const CropConf& crop) { - int cCrop = crop.corner[0]; - int hCrop = crop.corner[1]; - int wCrop = crop.corner[2]; + const FuncConfig& conf) { + std::vector crop_corner = + conf.get>("crop_corner"); + std::vector crop_shape = + conf.get>("crop_shape"); + int cCrop = crop_corner[1]; + int hCrop = crop_corner[2]; + int wCrop = crop_corner[3]; int num = inShape[0]; int inC = inShape[1]; int inH = inShape[2]; int inW = inShape[3]; - int outC = crop.shape[0]; - int outH = crop.shape[1]; - int outW = crop.shape[2]; + int outC = crop_shape[1]; + int outH = crop_shape[2]; + int outW = crop_shape[3]; for (int n = 0; n < num; n++) { for (int c = 0; c < outC; c++) { @@ -55,19 +54,23 @@ template <> void CropGrad(const real* inGrad, real* outGrad, const TensorShape outShape, - const CropConf& crop) { - int cCrop = crop.corner[0]; - int hCrop = crop.corner[1]; - int wCrop = crop.corner[2]; + const FuncConfig& conf) { + std::vector crop_corner = + conf.get>("crop_corner"); + std::vector crop_shape = + conf.get>("crop_shape"); + int cCrop = crop_corner[1]; + int hCrop = crop_corner[2]; + int wCrop = crop_corner[3]; int num = outShape[0]; int outC = outShape[1]; int outH = outShape[2]; int outW = outShape[3]; - int inC = crop.shape[0]; - int inH = crop.shape[1]; - int inW = crop.shape[2]; + int inC = crop_shape[1]; + int inH = crop_shape[2]; + int inW = crop_shape[3]; for (int n = 0; n < num; n++) { for (int c = 0; c < inC; c++) { @@ -111,26 +114,21 @@ void CropGrad(const real* inGrad, template class CropFunc : public FunctionBase { public: - void init(const FuncConfig& config) override { - crop_ = castToCropConf(config); - } + void init(const FuncConfig& config) override { conf_ = config; } void calc(const BufferArgs& inputs, const BufferArgs& outputs) override { CHECK_EQ(1UL, inputs.size()); CHECK_EQ(1UL, outputs.size()); - CHECK_EQ(outputs[0].shape()[1], crop_.shape[0]); - CHECK_EQ(outputs[0].shape()[2], crop_.shape[1]); - CHECK_EQ(outputs[0].shape()[3], crop_.shape[2]); CHECK_EQ(outputs[0].getArgType(), ASSIGN_TO); TensorShape inShape = inputs[0].shape(); Crop( - outputs[0].data(), inputs[0].data(), inShape, crop_); + outputs[0].data(), inputs[0].data(), inShape, conf_); } private: - CropConf crop_; + FuncConfig conf_; }; /** @@ -145,26 +143,21 @@ private: template class CropGradFunc : public FunctionBase { public: - void init(const FuncConfig& config) override { - crop_ = castToCropConf(config); - } + void init(const FuncConfig& config) override { conf_ = config; } void calc(const BufferArgs& inputs, const BufferArgs& outputs) override { CHECK_EQ(1UL, inputs.size()); CHECK_EQ(1UL, outputs.size()); - CHECK_EQ(inputs[0].shape()[1], crop_.shape[0]); - CHECK_EQ(inputs[0].shape()[2], crop_.shape[1]); - CHECK_EQ(inputs[0].shape()[3], crop_.shape[2]); CHECK_EQ(outputs[0].getArgType(), ASSIGN_TO); TensorShape outShape = outputs[0].shape(); CropGrad( - inputs[0].data(), outputs[0].data(), outShape, crop_); + inputs[0].data(), outputs[0].data(), outShape, conf_); } private: - CropConf crop_; + FuncConfig conf_; }; REGISTER_TYPED_FUNC(Crop, CPU, CropFunc); diff --git a/paddle/function/CropOp.h b/paddle/function/CropOp.h index 78a55bd43e..71e8c4c00e 100644 --- a/paddle/function/CropOp.h +++ b/paddle/function/CropOp.h @@ -18,13 +18,6 @@ limitations under the License. */ namespace paddle { -struct CropConf { - /// The upper left corner of croped result - std::vector corner; - /// The shape of croped result - std::vector shape; -}; - /** * \brief This funtion crops inputs according to the specify start point and *shape. @@ -32,13 +25,13 @@ struct CropConf { * \param[out] outputs save results. * \param[in] inputs input data. * \param[in] inShape the shape of input tensor. - * \param[in] crop the cropping config + * \param[in] conf the cropping config */ template void Crop(real* outputs, const real* inputs, const TensorShape inShape, - const CropConf& crop); + const FuncConfig& conf); /** * \brief Cropping operation backward. @@ -46,11 +39,11 @@ void Crop(real* outputs, * \param[out] inGrad gradients of previous layer * \param[in] outGrad output gradient * \param[in] inShape the shape of input tensor. - * \param[in] crop the cropping config + * \param[in] conf the cropping config */ template void CropGrad(const real* inGrad, real* outGrad, const TensorShape inShape, - const CropConf& crop); + const FuncConfig& conf); } // namespace paddle diff --git a/paddle/function/CropOpGpu.cu b/paddle/function/CropOpGpu.cu index f7d7d03abd..cadb58b6e9 100644 --- a/paddle/function/CropOpGpu.cu +++ b/paddle/function/CropOpGpu.cu @@ -37,19 +37,21 @@ template <> void Crop(real* outputs, const real* inputs, const TensorShape inShape, - const CropConf& crop) { - int cropC = crop.corner[0]; - int cropH = crop.corner[1]; - int cropW = crop.corner[2]; + const FuncConfig& conf) { + std::vector crop_corner = conf.get>("crop_corner"); + std::vector crop_shape = conf.get>("crop_shape"); + int cropC = crop_corner[1]; + int cropH = crop_corner[2]; + int cropW = crop_corner[3]; int num = inShape[0]; int inC = inShape[1]; int inH = inShape[2]; int inW = inShape[3]; - int outC = crop.shape[0]; - int outH = crop.shape[1]; - int outW = crop.shape[2]; + int outC = crop_shape[1]; + int outH = crop_shape[2]; + int outW = crop_shape[3]; size_t nth = num * outC * outH * outW; int blockSize = 1024; @@ -82,19 +84,21 @@ template <> void CropGrad(const real* inGrad, real* outGrad, const TensorShape outShape, - const CropConf& crop) { - int cropC = crop.corner[0]; - int cropH = crop.corner[1]; - int cropW = crop.corner[2]; + const FuncConfig& conf) { + std::vector crop_corner = conf.get>("crop_corner"); + std::vector crop_shape = conf.get>("crop_shape"); + int cropC = crop_corner[1]; + int cropH = crop_corner[2]; + int cropW = crop_corner[3]; int num = outShape[0]; int outC = outShape[1]; int outH = outShape[2]; int outW = outShape[3]; - int inC = crop.shape[0]; - int inH = crop.shape[1]; - int inW = crop.shape[2]; + int inC = crop_shape[1]; + int inH = crop_shape[2]; + int inW = crop_shape[3]; size_t nth = num * inC * inH * inW; int blockSize = 1024; diff --git a/paddle/function/CropOpTest.cpp b/paddle/function/CropOpTest.cpp index 62b4bd9fde..c331a70d1f 100644 --- a/paddle/function/CropOpTest.cpp +++ b/paddle/function/CropOpTest.cpp @@ -28,8 +28,8 @@ TEST(Crop, real) { FunctionCompare compare( test_grad ? "CropGrad" : "Crop", FuncConfig() - .set>("crop_corner", {1, 1, 1}) - .set>("crop_shape", {2, 3, 3})); + .set>("crop_corner", {0, 1, 1, 1}) + .set>("crop_shape", {0, 2, 3, 3})); TensorShape inDims{numSamples, channels, imgSizeH, imgSizeW}; TensorShape outDims{numSamples, 2, 3, 3}; compare.addInputs( diff --git a/paddle/gserver/layers/CropLayer.cpp b/paddle/gserver/layers/CropLayer.cpp index ab23d4617e..198ceffb46 100644 --- a/paddle/gserver/layers/CropLayer.cpp +++ b/paddle/gserver/layers/CropLayer.cpp @@ -25,20 +25,57 @@ bool CropLayer::init(const LayerMap& layerMap, Layer::init(layerMap, parameterMap); auto& crop_conf = config_.inputs(0).crop_conf(); - auto& img_conf = crop_conf.image_conf(); - CHECK_EQ(config_.inputs_size(), 1); - inDims_ = TensorShape( - {0, - img_conf.channels(), - img_conf.has_img_size_y() ? img_conf.img_size_y() : img_conf.img_size(), - img_conf.img_size()}); - - crop_corner_ = {crop_conf.crop_corner(0), - crop_conf.crop_corner(1), - crop_conf.crop_corner(2)}; - crop_shape_ = {crop_conf.crop_shape(0), - crop_conf.crop_shape(1), - crop_conf.crop_shape(2)}; + crop_axis_ = crop_conf.axis(); + for (int i = 0; i < crop_conf.offset_size(); i++) { + crop_offsets_[i] = crop_conf.offset(i); + } + + // 1. get input_0 shape + auto& input0_img_conf = config_.inputs(0).image_conf(); + inDims_ = TensorShape({0, + input0_img_conf.channels(), + input0_img_conf.has_img_size_y() + ? input0_img_conf.img_size_y() + : input0_img_conf.img_size(), + input0_img_conf.img_size()}); + + // 2. get output shape from input_1 or crop shap conf + if (config_.inputs_size() == 2) { + auto& input1_img_conf = config_.inputs(1).image_conf(); + targetDims_ = TensorShape({0, + input1_img_conf.channels(), + input1_img_conf.has_img_size_y() + ? input1_img_conf.img_size_y() + : input1_img_conf.img_size(), + input1_img_conf.img_size()}); + } else { + targetDims_ = TensorShape({crop_conf.shape(0), + crop_conf.shape(1), + crop_conf.shape(2), + crop_conf.shape(3)}); + } + + // 3. get final crop shape + int dimSize = 4; + for (int i = 0; i < dimSize; i++) { + if (i >= crop_axis_) { + crop_shape_[i] = targetDims_[i]; + } else { + crop_shape_[i] = inDims_[i]; + } + } + + // 4. get final crop corner + crop_corner_ = {0, 0, 0, 0}; + for (int i = 0; i < dimSize; i++) { + if (i >= crop_axis_) { + if (crop_offsets_.size() > 1) { + crop_corner_[i] = crop_offsets_[i - crop_axis_]; + } else { + crop_corner_[i] = crop_offsets_[0]; + } + } + } outDims_ = TensorShape(4); setOutDims(0); @@ -58,7 +95,7 @@ bool CropLayer::init(const LayerMap& layerMap, } void CropLayer::setOutDims(const size_t batchSize) { - outDims_.reshape({batchSize, crop_shape_[0], crop_shape_[1], crop_shape_[2]}); + outDims_.reshape({batchSize, crop_shape_[1], crop_shape_[2], crop_shape_[3]}); } void CropLayer::setTensorDim(const size_t batchSize) { diff --git a/paddle/gserver/layers/CropLayer.h b/paddle/gserver/layers/CropLayer.h index 3ce89707ca..23cede1c3f 100644 --- a/paddle/gserver/layers/CropLayer.h +++ b/paddle/gserver/layers/CropLayer.h @@ -19,9 +19,13 @@ limitations under the License. */ namespace paddle { /** - * \brief This layer crop inputs according to the specify corner and shape. - * The input and output is a 4D tensor. Cropping from the 2nd to - * the 4th dimenstion. + * \brief This layer crop input according to the specify conf. + * input_0: input to be cropped + * input_1: optional reference input + * axis: start dimension to be croped + * offset: offset of cropping in each dimension + * shape: if reference input layer was not setted, + * crop input as this shape conf */ class CropLayer : public Layer { public: @@ -38,9 +42,12 @@ protected: void setOutDims(const size_t batchSize); void setTensorDim(const size_t batchSize); + int32_t crop_axis_; + std::vector crop_offsets_; std::vector crop_corner_; std::vector crop_shape_; TensorShape inDims_; + TensorShape targetDims_; TensorShape outDims_; }; } // namespace paddle From 5e6e1f636a356b6ae7d25ff8494354349b3b4f5f Mon Sep 17 00:00:00 2001 From: wanghaoshuang Date: Wed, 5 Jul 2017 00:53:32 +0800 Subject: [PATCH 07/37] Add grad test and python wrapper for crop layer --- paddle/function/CropOp.cpp | 2 +- paddle/function/CropOpTest.cpp | 2 +- paddle/gserver/layers/CropLayer.cpp | 23 ++++---- paddle/gserver/tests/CMakeLists.txt | 2 +- paddle/gserver/tests/test_LayerGrad.cpp | 28 ++++++++++ proto/ModelConfig.proto | 8 ++- python/paddle/trainer/config_parser.py | 45 ++++++++++++++++ .../paddle/trainer_config_helpers/layers.py | 54 +++++++++++++++++++ 8 files changed, 147 insertions(+), 17 deletions(-) diff --git a/paddle/function/CropOp.cpp b/paddle/function/CropOp.cpp index 0d511ceef5..1bb194a9bc 100644 --- a/paddle/function/CropOp.cpp +++ b/paddle/function/CropOp.cpp @@ -148,7 +148,7 @@ public: void calc(const BufferArgs& inputs, const BufferArgs& outputs) override { CHECK_EQ(1UL, inputs.size()); CHECK_EQ(1UL, outputs.size()); - CHECK_EQ(outputs[0].getArgType(), ASSIGN_TO); + CHECK_EQ(outputs[0].getArgType(), ADD_TO); TensorShape outShape = outputs[0].shape(); diff --git a/paddle/function/CropOpTest.cpp b/paddle/function/CropOpTest.cpp index c331a70d1f..71d9b05812 100644 --- a/paddle/function/CropOpTest.cpp +++ b/paddle/function/CropOpTest.cpp @@ -25,7 +25,7 @@ TEST(Crop, real) { VLOG(3) << " numSamples=" << numSamples << " channels=" << channels << " imgSizeH=" << imgSizeH << " imgSizeW=" << imgSizeW; for (bool test_grad : {false, true}) { - FunctionCompare compare( + CpuGpuFuncCompare compare( test_grad ? "CropGrad" : "Crop", FuncConfig() .set>("crop_corner", {0, 1, 1, 1}) diff --git a/paddle/gserver/layers/CropLayer.cpp b/paddle/gserver/layers/CropLayer.cpp index 198ceffb46..b2fa17b400 100644 --- a/paddle/gserver/layers/CropLayer.cpp +++ b/paddle/gserver/layers/CropLayer.cpp @@ -14,7 +14,6 @@ limitations under the License. */ #include "CropLayer.h" #include "paddle/utils/Stat.h" - namespace paddle { REGISTER_LAYER(crop, CropLayer); @@ -24,10 +23,9 @@ bool CropLayer::init(const LayerMap& layerMap, /* Initialize the basic parent class */ Layer::init(layerMap, parameterMap); - auto& crop_conf = config_.inputs(0).crop_conf(); - crop_axis_ = crop_conf.axis(); - for (int i = 0; i < crop_conf.offset_size(); i++) { - crop_offsets_[i] = crop_conf.offset(i); + crop_axis_ = config_.axis(); + for (int i = 0; i < config_.offset_size(); i++) { + crop_offsets_.push_back(config_.offset(i)); } // 1. get input_0 shape @@ -38,7 +36,6 @@ bool CropLayer::init(const LayerMap& layerMap, ? input0_img_conf.img_size_y() : input0_img_conf.img_size(), input0_img_conf.img_size()}); - // 2. get output shape from input_1 or crop shap conf if (config_.inputs_size() == 2) { auto& input1_img_conf = config_.inputs(1).image_conf(); @@ -49,19 +46,19 @@ bool CropLayer::init(const LayerMap& layerMap, : input1_img_conf.img_size(), input1_img_conf.img_size()}); } else { - targetDims_ = TensorShape({crop_conf.shape(0), - crop_conf.shape(1), - crop_conf.shape(2), - crop_conf.shape(3)}); + targetDims_ = TensorShape({config_.shape(0), + config_.shape(1), + config_.shape(2), + config_.shape(3)}); } // 3. get final crop shape int dimSize = 4; for (int i = 0; i < dimSize; i++) { if (i >= crop_axis_) { - crop_shape_[i] = targetDims_[i]; + crop_shape_.push_back(targetDims_[i]); } else { - crop_shape_[i] = inDims_[i]; + crop_shape_.push_back(inDims_[i]); } } @@ -99,7 +96,7 @@ void CropLayer::setOutDims(const size_t batchSize) { } void CropLayer::setTensorDim(const size_t batchSize) { - CHECK_EQ(static_cast(inputLayers_.size()), 1); + CHECK_EQ(static_cast(inputLayers_.size()), 2); inDims_.setDim(0, batchSize); int h = inputLayers_[0]->getOutput().getFrameHeight(); if (h != 0) inDims_.setDim(2, h); diff --git a/paddle/gserver/tests/CMakeLists.txt b/paddle/gserver/tests/CMakeLists.txt index 92f6cbcfe5..a43adc7ce7 100644 --- a/paddle/gserver/tests/CMakeLists.txt +++ b/paddle/gserver/tests/CMakeLists.txt @@ -56,7 +56,7 @@ add_test(NAME test_DetectionOutput add_unittest_without_exec(test_ConvUnify test_ConvUnify.cpp LayerGradUtil.cpp) - + add_test(NAME test_ConvUnify COMMAND test_ConvUnify) ################# test_BatchNorm ####################### diff --git a/paddle/gserver/tests/test_LayerGrad.cpp b/paddle/gserver/tests/test_LayerGrad.cpp index 59d1e9273d..20a83d7aa1 100644 --- a/paddle/gserver/tests/test_LayerGrad.cpp +++ b/paddle/gserver/tests/test_LayerGrad.cpp @@ -1792,6 +1792,34 @@ TEST(Layer, RowConvLayer) { } } +TEST(Layer, CropLayer) { + TestConfig config; + // config input_0 + config.inputDefs.push_back({INPUT_DATA, "layer_0", 1024, 0}); + LayerInputConfig* input = config.layerConfig.add_inputs(); + ImageConfig* img = input->mutable_image_conf(); + img->set_channels(4); + img->set_img_size(16); + config.layerConfig.set_axis(2); + config.layerConfig.add_offset(0); + config.layerConfig.add_offset(0); + + // config input_1 + config.inputDefs.push_back({INPUT_DATA, "layer_1", 128, 0}); + input = config.layerConfig.add_inputs(); + img = input->mutable_image_conf(); + img->set_channels(2); + img->set_img_size(8); + + // config crop layer + config.layerConfig.set_type("crop"); + config.layerConfig.set_name("cropLayer"); + + for (auto useGpu : {false, true}) { + testLayerGrad(config, "crop", 100, false, useGpu, false); + } +} + int main(int argc, char** argv) { testing::InitGoogleTest(&argc, argv); initMain(argc, argv); diff --git a/proto/ModelConfig.proto b/proto/ModelConfig.proto index 37cd16c798..83f72c137b 100644 --- a/proto/ModelConfig.proto +++ b/proto/ModelConfig.proto @@ -472,10 +472,16 @@ message LayerConfig { // blank label used in ctc loss optional uint32 blank = 52 [default = 0]; - // stride parameter for seqlastins layer, AverageLayer, MaxLayer, which + // stride parameter for seqlastins layer, AverageLayer, MaxLayer, which // controls the scope of pooling operation. can be set > 0. // leave empty or set to -1 to disable this stride pooling. optional int32 seq_pool_stride = 53 [default = -1]; + + // for crop layer + optional int32 axis = 54 [default = 2]; + repeated uint32 offset = 55; + repeated uint32 shape = 56; + } message EvaluatorConfig { diff --git a/python/paddle/trainer/config_parser.py b/python/paddle/trainer/config_parser.py index 370529ed97..8c529fdfd3 100644 --- a/python/paddle/trainer/config_parser.py +++ b/python/paddle/trainer/config_parser.py @@ -1986,6 +1986,51 @@ class PadLayer(LayerBase): self.config.size = out_ch * out_h * out_w +@config_layer('crop') +class CropLayer(LayerBase): + def __init__(self, inputs, axis, offset, shape, name, **xargs): + super(CropLayer, self).__init__(name, 'crop', 0, inputs=inputs, **xargs) + self.conf.axis = axis + self.conf.axis = offset + self.conf.axis = shape + + crop = self.inputs[0].crop + self.config.inputs[0].crop_conf.axis = crop.axis + self.config.inputs[0].crop_conf.offset.extend(crop.offset) + self.config.inputs[0].crop_conf.shape.extend(crop.shape) + + # get channel, width and height from input_0 layer + input_layer = self.get_input_layer(0) + image_conf = self.config.inputs[0].image_conf + image_conf.img_size = input_layer.width + image_conf.img_size_y = input_layer.height + image_conf.channels = input_layer.size / (input_layer.width * + input_layer.height) + out_ch = image_conf.channels + out_h = image_conf.img_size + out_w = image_conf.img_size_y + if len(self.inputs) == 2: + # get channels, width and height from input_1 layer + input_layer = self.get_input_layer(1) + image_conf = self.config.inputs[1].image_conf + image_conf.img_size = input_layer.width + image_conf.img_size_y = input_layer.height + image_conf.channels = input_layer.size / (input_layer.width * + input_layer.height) + out_ch = image_conf.channels + out_h = image_conf.img_size_y + out_w = image_conf.img_size + else: + # set channels, width and heigth of current layer + if len(shape) > 2: + out_ch = shape[-3] + if len(shape) > 1: + out_h = shape[-2] + if len(shape) > 0: + out_w = shape[-1] + self.set_cnn_layer(name, out_h, out_w, out_ch) + + @config_layer('batch_norm') class BatchNormLayer(LayerBase): layer_type = 'batch_norm' diff --git a/python/paddle/trainer_config_helpers/layers.py b/python/paddle/trainer_config_helpers/layers.py index 206de1f8e1..f9de086cba 100755 --- a/python/paddle/trainer_config_helpers/layers.py +++ b/python/paddle/trainer_config_helpers/layers.py @@ -217,6 +217,7 @@ class LayerType(object): SMOOTH_L1 = 'smooth_l1' PRELU = 'prelu' + CROP_LAYER = 'crop' @staticmethod def is_layer_type(type_name): @@ -5853,3 +5854,56 @@ def prelu_layer(input, layer_type=LayerType.PRELU, parents=input, size=l.config.size) + + +@wrap_name_default() +@layer_support() +def crop_layer(input, axis, offset, shape=None, name=None, layer_attr=None): + """ + The crop layer crop images by offset and shape. User can set crop shape by + args 'shape' explicitly or by reference input layer. + + + The example usage is: + + .. code-block:: python + + crop = crop_layer(input=[image_input, reference_input], axis=2, offset=[2, 3]) + + :param input: The input layer.If two inputs were setted, + the second input will be regarded as reference input + :type input: LayerOutput or Sequence + :param axis: start axis to be cropped. To image input layer: + - 0: batch size + - 1: channels + - 2: height + - 3: width + :type partial_sum: int + :param offset: The crop offset + :type offset: Sequence + :param shape: The shape to be cropped. Default is None. + :type shape: Sqquence | None + :param name: Name of this layer. + :type name: basestring + :return: LayerOutput object. + :rtype: LayerOutput + """ + if isinstance(input, LayerOutput): + input = [input] + elif isinstance(input, Projection): + input = [input] + else: + assert isinstance(input, collections.Sequence) + l = Layer( + inputs=[x.name for x in input], + axis=axis, + offset=offset, + shape=shape, + name=name, + type=LayerType.CROP_LAYER, + **ExtraLayerAttribute.to_kwargs(layer_attr)) + return LayerOutput( + name=name, + layer_type=LayerType.CROP_LAYER, + parents=input, + size=l.config.size) From 86bdb2f33fa9e9e806e8248b14a172ce4e0557c6 Mon Sep 17 00:00:00 2001 From: wanghaoshuang Date: Wed, 5 Jul 2017 10:36:22 +0800 Subject: [PATCH 08/37] fix crop function test --- paddle/function/CropOpTest.cpp | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/paddle/function/CropOpTest.cpp b/paddle/function/CropOpTest.cpp index 71d9b05812..dcba972e10 100644 --- a/paddle/function/CropOpTest.cpp +++ b/paddle/function/CropOpTest.cpp @@ -34,8 +34,10 @@ TEST(Crop, real) { TensorShape outDims{numSamples, 2, 3, 3}; compare.addInputs( BufferArg(VALUE_TYPE_FLOAT, test_grad ? outDims : inDims)); - compare.addOutputs(BufferArg( - VALUE_TYPE_FLOAT, test_grad ? inDims : outDims, ASSIGN_TO)); + compare.addOutputs(BufferArg(VALUE_TYPE_FLOAT, + test_grad ? inDims : outDims, + tes_grad ? ADD_TO : ASSIGN_TO), + test_grad ? ADD_TO : ASSIGN_TO); compare.run(); } } From cf868918012f29b94628cff7e80cfc6e65bf0ee6 Mon Sep 17 00:00:00 2001 From: wanghaoshuang Date: Wed, 5 Jul 2017 11:34:16 +0800 Subject: [PATCH 09/37] fix unittest of crop layer --- paddle/function/CropOpTest.cpp | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/paddle/function/CropOpTest.cpp b/paddle/function/CropOpTest.cpp index dcba972e10..6f11abfdf6 100644 --- a/paddle/function/CropOpTest.cpp +++ b/paddle/function/CropOpTest.cpp @@ -36,7 +36,7 @@ TEST(Crop, real) { BufferArg(VALUE_TYPE_FLOAT, test_grad ? outDims : inDims)); compare.addOutputs(BufferArg(VALUE_TYPE_FLOAT, test_grad ? inDims : outDims, - tes_grad ? ADD_TO : ASSIGN_TO), + test_grad ? ADD_TO : ASSIGN_TO), test_grad ? ADD_TO : ASSIGN_TO); compare.run(); } From acfd2fc6dfc1bf06bbfd6e25496ca1dfde881551 Mon Sep 17 00:00:00 2001 From: wanghaoshuang Date: Wed, 5 Jul 2017 11:54:47 +0800 Subject: [PATCH 10/37] fix cpp format --- paddle/function/CropOp.cpp | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/paddle/function/CropOp.cpp b/paddle/function/CropOp.cpp index 1bb194a9bc..39e06fc120 100644 --- a/paddle/function/CropOp.cpp +++ b/paddle/function/CropOp.cpp @@ -13,8 +13,9 @@ See the License for the specific language governing permissions and limitations under the License. */ #include "CropOp.h" -#include "paddle/math/Vector.h" #include "paddle/function/TensorShape.h" +#include "paddle/math/Vector.h" + namespace paddle { template <> From 69b12225cc19919005d4cc1b4bb814a93ad205b3 Mon Sep 17 00:00:00 2001 From: wanghaoshuang Date: Tue, 11 Jul 2017 10:15:48 +0800 Subject: [PATCH 11/37] fix crop layer python wrapper bug --- python/paddle/trainer/config_parser.py | 13 ++++--------- 1 file changed, 4 insertions(+), 9 deletions(-) diff --git a/python/paddle/trainer/config_parser.py b/python/paddle/trainer/config_parser.py index 2d1b4a3b30..2f96d6fc0b 100644 --- a/python/paddle/trainer/config_parser.py +++ b/python/paddle/trainer/config_parser.py @@ -1988,16 +1988,11 @@ class PadLayer(LayerBase): @config_layer('crop') class CropLayer(LayerBase): - def __init__(self, inputs, axis, offset, shape, name, **xargs): + def __init__(self, name, inputs, axis, offset, shape, **xargs): super(CropLayer, self).__init__(name, 'crop', 0, inputs=inputs, **xargs) - self.conf.axis = axis - self.conf.axis = offset - self.conf.axis = shape - - crop = self.inputs[0].crop - self.config.inputs[0].crop_conf.axis = crop.axis - self.config.inputs[0].crop_conf.offset.extend(crop.offset) - self.config.inputs[0].crop_conf.shape.extend(crop.shape) + self.config.axis = axis + self.config.offset.extend(offset) + self.config.shape.extend(shape) # get channel, width and height from input_0 layer input_layer = self.get_input_layer(0) From f812de2cce882dbfa84f0696e466aa8ef9de30a0 Mon Sep 17 00:00:00 2001 From: liaogang Date: Sat, 15 Jul 2017 01:36:27 +0800 Subject: [PATCH 12/37] ENH: unify PADDLE_ENFORCE --- paddle/framework/CMakeLists.txt | 1 - paddle/framework/attr_checker.h | 2 +- paddle/framework/enforce.h | 69 ------------------ paddle/framework/enforce_test.cc | 35 --------- paddle/framework/op_registry_test.cc | 6 +- paddle/framework/tensor.h | 2 +- paddle/framework/tensor_test.cc | 2 +- paddle/memory/detail/system_allocator.cc | 5 +- paddle/platform/CMakeLists.txt | 2 + paddle/platform/cpu_info.cc | 1 - paddle/platform/device_context.h | 53 ++++++-------- paddle/platform/dynload/dynamic_loader.cc | 2 +- paddle/platform/error.h | 87 ----------------------- paddle/platform/gpu_info.cc | 10 +-- 14 files changed, 39 insertions(+), 238 deletions(-) delete mode 100644 paddle/framework/enforce.h delete mode 100644 paddle/framework/enforce_test.cc delete mode 100644 paddle/platform/error.h diff --git a/paddle/framework/CMakeLists.txt b/paddle/framework/CMakeLists.txt index 8415ce67e9..272649effc 100644 --- a/paddle/framework/CMakeLists.txt +++ b/paddle/framework/CMakeLists.txt @@ -5,7 +5,6 @@ nv_test(dim_test SRCS dim_test.cu DEPS ddim) cc_test(tensor_test SRCS tensor_test.cc DEPS ddim) cc_test(variable_test SRCS variable_test.cc) cc_test(scope_test SRCS scope_test.cc) -cc_test(enforce_test SRCS enforce_test.cc) proto_library(attr_type SRCS attr_type.proto) proto_library(op_proto SRCS op_proto.proto DEPS attr_type) cc_test(op_proto_test SRCS op_proto_test.cc DEPS op_proto protobuf) diff --git a/paddle/framework/attr_checker.h b/paddle/framework/attr_checker.h index c0c33d8114..b527539d53 100644 --- a/paddle/framework/attr_checker.h +++ b/paddle/framework/attr_checker.h @@ -5,7 +5,7 @@ #include #include #include -#include "paddle/framework/enforce.h" +#include "paddle/platform/enforce.h" namespace paddle { namespace framework { diff --git a/paddle/framework/enforce.h b/paddle/framework/enforce.h deleted file mode 100644 index 56cb7f9564..0000000000 --- a/paddle/framework/enforce.h +++ /dev/null @@ -1,69 +0,0 @@ -/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve. -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - http://www.apache.org/licenses/LICENSE-2.0 -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. */ - -#pragma once -#include -#include -#include - -namespace paddle { -namespace framework { - -/** - * @brief Enforce exception. Inherits std::exception - * - * All enforce condition not met, will throw an EnforceNotMet exception. - */ -class EnforceNotMet : public std::exception { - public: - EnforceNotMet(const std::string& msg, const char* file, int fileline) { - std::ostringstream sout; - sout << msg << " at [" << file << ":" << fileline << "];"; - all_msg_ = sout.str(); - } - - const char* what() const noexcept override { return all_msg_.c_str(); } - - private: - std::string all_msg_; -}; - -// From https://stackoverflow.com/questions/30130930/ -// __buildin_expect is in C++ 11 standard. Since the condition which enforced -// should be true in most situation, it will make the compiler generate faster -// code by adding `UNLIKELY` macro. -#define UNLIKELY(condition) __builtin_expect(static_cast(condition), 0) - -/** - * @brief Throw a EnforceNotMet exception, automatically filled __FILE__ & - * __LINE__ - * - * This macro take __VA_ARGS__, user can pass any type if that type can - * serialize to std::ostream - */ -#define PADDLE_THROW(...) \ - do { \ - throw ::paddle::framework::EnforceNotMet( \ - ::paddle::string::Sprintf(__VA_ARGS__), __FILE__, __LINE__); \ - } while (0) - -/** - * @brief Enforce a condition, otherwise throw an EnforceNotMet - */ -#define PADDLE_ENFORCE(condition, ...) \ - do { \ - if (UNLIKELY(!(condition))) { \ - PADDLE_THROW(__VA_ARGS__); \ - } \ - } while (0) - -} // namespace framework -} // namespace paddle diff --git a/paddle/framework/enforce_test.cc b/paddle/framework/enforce_test.cc deleted file mode 100644 index f8da1a192f..0000000000 --- a/paddle/framework/enforce_test.cc +++ /dev/null @@ -1,35 +0,0 @@ -/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve. -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - http://www.apache.org/licenses/LICENSE-2.0 -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. */ - -#include -#include - -TEST(ENFORCE, OK) { - PADDLE_ENFORCE(true, "Enforce is ok %d now %f", 123, 0.345); - size_t val = 1; - const size_t limit = 10; - PADDLE_ENFORCE(val < limit, "Enforce is OK too"); -} - -TEST(ENFORCE, FAILED) { - bool in_catch = false; - try { - PADDLE_ENFORCE(false, "Enforce is not ok %d at all", 123); - } catch (paddle::framework::EnforceNotMet err) { - in_catch = true; - std::string msg = "Enforce is not ok 123 at all"; - const char* what = err.what(); - for (size_t i = 0; i < msg.length(); ++i) { - ASSERT_EQ(what[i], msg[i]); - } - } - ASSERT_TRUE(in_catch); -} \ No newline at end of file diff --git a/paddle/framework/op_registry_test.cc b/paddle/framework/op_registry_test.cc index 4791d4aaab..0a93655728 100644 --- a/paddle/framework/op_registry_test.cc +++ b/paddle/framework/op_registry_test.cc @@ -91,7 +91,7 @@ TEST(OpRegistry, IllegalAttr) { try { paddle::framework::OperatorPtr op __attribute__((unused)) = paddle::framework::OpRegistry::CreateOp(op_desc); - } catch (paddle::framework::EnforceNotMet err) { + } catch (paddle::platform::EnforceNotMet err) { caught = true; std::string msg = "larger_than check fail"; const char* err_msg = err.what(); @@ -138,7 +138,7 @@ TEST(OpRegistry, CustomChecker) { try { paddle::framework::OperatorPtr op __attribute__((unused)) = paddle::framework::OpRegistry::CreateOp(op_desc); - } catch (paddle::framework::EnforceNotMet err) { + } catch (paddle::platform::EnforceNotMet err) { caught = true; std::string msg = "Attribute 'test_attr' is required!"; const char* err_msg = err.what(); @@ -157,7 +157,7 @@ TEST(OpRegistry, CustomChecker) { try { paddle::framework::OperatorPtr op __attribute__((unused)) = paddle::framework::OpRegistry::CreateOp(op_desc); - } catch (paddle::framework::EnforceNotMet err) { + } catch (paddle::platform::EnforceNotMet err) { caught = true; std::string msg = "'test_attr' must be even!"; const char* err_msg = err.what(); diff --git a/paddle/framework/tensor.h b/paddle/framework/tensor.h index 62e0710a82..5fdbb4f07a 100644 --- a/paddle/framework/tensor.h +++ b/paddle/framework/tensor.h @@ -18,7 +18,7 @@ limitations under the License. */ #include #include #include "paddle/framework/ddim.h" -#include "paddle/framework/enforce.h" +#include "paddle/platform/enforce.h" #include "paddle/memory/memory.h" #include "paddle/platform/place.h" diff --git a/paddle/framework/tensor_test.cc b/paddle/framework/tensor_test.cc index 255f69372f..34ea380b4e 100644 --- a/paddle/framework/tensor_test.cc +++ b/paddle/framework/tensor_test.cc @@ -33,7 +33,7 @@ TEST(Tensor, DataAssert) { bool caught = false; try { src_tensor.data(); - } catch (paddle::framework::EnforceNotMet err) { + } catch (paddle::platform::EnforceNotMet err) { caught = true; std::string msg = "Tenosr holds no memory. Call Tensor::mutable_data first."; diff --git a/paddle/memory/detail/system_allocator.cc b/paddle/memory/detail/system_allocator.cc index 1579174b1a..f61e67a329 100644 --- a/paddle/memory/detail/system_allocator.cc +++ b/paddle/memory/detail/system_allocator.cc @@ -14,7 +14,7 @@ limitations under the License. */ #include "paddle/memory/detail/system_allocator.h" #include "paddle/platform/assert.h" -#include "paddle/platform/error.h" +#include "paddle/platform/enforce.h" #include "paddle/platform/gpu_info.h" #include // for malloc and free @@ -128,8 +128,7 @@ void GPUAllocator::Free(void* p, size_t size, size_t index) { // process is terminating, in which case we don't care if // cudaFree succeeds. if (err != cudaErrorCudartUnloading) { - platform::throw_on_error(err, - "cudaFree{Host} failed in GPUAllocator::Free."); + PADDLE_ENFORCE(err, "cudaFree{Host} failed in GPUAllocator::Free."); } } diff --git a/paddle/platform/CMakeLists.txt b/paddle/platform/CMakeLists.txt index 6ac4035c0f..bd77bb7daa 100644 --- a/paddle/platform/CMakeLists.txt +++ b/paddle/platform/CMakeLists.txt @@ -8,6 +8,8 @@ cc_test(place_test SRCS place_test.cc DEPS place glog gflags) add_subdirectory(dynload) +cc_test(enforce_test SRCS enforce_test.cc) + IF(WITH_GPU) set(GPU_CTX_DEPS dynload_cuda dynamic_loader) ELSE() diff --git a/paddle/platform/cpu_info.cc b/paddle/platform/cpu_info.cc index 1905cfeee6..f2cbd863cf 100644 --- a/paddle/platform/cpu_info.cc +++ b/paddle/platform/cpu_info.cc @@ -22,7 +22,6 @@ limitations under the License. */ #endif #include "gflags/gflags.h" -#include "paddle/platform/error.h" DEFINE_double(fraction_of_cpu_memory_to_use, 1, "Default use 100% of CPU memory for PaddlePaddle," diff --git a/paddle/platform/device_context.h b/paddle/platform/device_context.h index 51c8e13913..d2569fdc91 100644 --- a/paddle/platform/device_context.h +++ b/paddle/platform/device_context.h @@ -11,7 +11,7 @@ limitations under the License. */ #pragma once -#include "paddle/framework/enforce.h" +#include "paddle/platform/enforce.h" #ifndef PADDLE_ONLY_CPU #include "paddle/platform/dynload/cublas.h" #include "paddle/platform/dynload/cudnn.h" @@ -74,8 +74,7 @@ class CUDADeviceContext : public DeviceContext { public: explicit CUDADeviceContext(const GPUPlace gpu_place) : gpu_place_(gpu_place) { GPUPlaceGuard guard(gpu_place_); - paddle::platform::throw_on_error(cudaStreamCreate(&stream_), - "cudaStreamCreate failed"); + PADDLE_ENFORCE(cudaStreamCreate(&stream_), "cudaStreamCreate failed"); eigen_stream_.reset(new Eigen::CudaStreamDevice(&stream_)); eigen_device_.reset(new Eigen::GpuDevice(eigen_stream_.get())); } @@ -86,8 +85,8 @@ class CUDADeviceContext : public DeviceContext { } void Wait() { - paddle::platform::throw_on_error(cudaStreamSynchronize(stream_), - "cudaStreamSynchronize failed"); + PADDLE_ENFORCE(cudaStreamSynchronize(stream_), + "cudaStreamSynchronize failed"); } cudaStream_t stream() { return stream_; } @@ -97,12 +96,11 @@ class CUDADeviceContext : public DeviceContext { cublasHandle_t cublas_handle() { if (!blas_handle_) { GPUPlaceGuard guard(gpu_place_); - PADDLE_ENFORCE(paddle::platform::dynload::cublasCreate(&blas_handle_) == - CUBLAS_STATUS_SUCCESS, + PADDLE_ENFORCE(paddle::platform::dynload::cublasCreate(&blas_handle_), "cublasCreate failed"); - PADDLE_ENFORCE(paddle::platform::dynload::cublasSetStream( - blas_handle_, stream_) == CUBLAS_STATUS_SUCCESS, - "cublasSetStream failed"); + PADDLE_ENFORCE( + paddle::platform::dynload::cublasSetStream(blas_handle_, stream_), + "cublasSetStream failed"); } return blas_handle_; } @@ -110,12 +108,11 @@ class CUDADeviceContext : public DeviceContext { cudnnHandle_t cudnn_handle() { if (!dnn_handle_) { GPUPlaceGuard guard(gpu_place_); - PADDLE_ENFORCE(paddle::platform::dynload::cudnnCreate(&dnn_handle_) == - CUDNN_STATUS_SUCCESS, + PADDLE_ENFORCE(paddle::platform::dynload::cudnnCreate(&dnn_handle_), "cudnnCreate failed"); - PADDLE_ENFORCE(paddle::platform::dynload::cudnnSetStream( - dnn_handle_, stream_) == CUDNN_STATUS_SUCCESS, - "cudnnSetStream failed"); + PADDLE_ENFORCE( + paddle::platform::dynload::cudnnSetStream(dnn_handle_, stream_), + "cudnnSetStream failed"); } return dnn_handle_; } @@ -124,16 +121,15 @@ class CUDADeviceContext : public DeviceContext { if (!rand_generator_) { GPUPlaceGuard guard(gpu_place_); PADDLE_ENFORCE(paddle::platform::dynload::curandCreateGenerator( - &rand_generator_, CURAND_RNG_PSEUDO_DEFAULT) == - CURAND_STATUS_SUCCESS, + &rand_generator_, CURAND_RNG_PSEUDO_DEFAULT), "curandCreateGenerator failed"); PADDLE_ENFORCE( paddle::platform::dynload::curandSetPseudoRandomGeneratorSeed( - rand_generator_, random_seed_) == CURAND_STATUS_SUCCESS, + rand_generator_, random_seed_), "curandSetPseudoRandomGeneratorSeed failed"); - PADDLE_ENFORCE(paddle::platform::dynload::curandSetStream( - rand_generator_, stream_) == CURAND_STATUS_SUCCESS, - "curandSetStream failed"); + PADDLE_ENFORCE( + paddle::platform::dynload::curandSetStream(rand_generator_, stream_), + "curandSetStream failed"); } return rand_generator_; } @@ -141,26 +137,23 @@ class CUDADeviceContext : public DeviceContext { ~CUDADeviceContext() { Wait(); if (blas_handle_) { - PADDLE_ENFORCE(paddle::platform::dynload::cublasDestroy(blas_handle_) == - CUBLAS_STATUS_SUCCESS, + PADDLE_ENFORCE(paddle::platform::dynload::cublasDestroy(blas_handle_), "cublasDestroy failed"); } if (dnn_handle_) { - PADDLE_ENFORCE(paddle::platform::dynload::cudnnDestroy(dnn_handle_) == - CUDNN_STATUS_SUCCESS, + PADDLE_ENFORCE(paddle::platform::dynload::cudnnDestroy(dnn_handle_), "cudnnDestroy failed"); } if (rand_generator_) { - PADDLE_ENFORCE(paddle::platform::dynload::curandDestroyGenerator( - rand_generator_) == CURAND_STATUS_SUCCESS, - "curandDestroyGenerator failed"); + PADDLE_ENFORCE( + paddle::platform::dynload::curandDestroyGenerator(rand_generator_), + "curandDestroyGenerator failed"); } eigen_stream_.reset(); eigen_device_.reset(); - paddle::platform::throw_on_error(cudaStreamDestroy(stream_), - "cudaStreamDestroy failed"); + PADDLE_ENFORCE(cudaStreamDestroy(stream_), "cudaStreamDestroy failed"); } private: diff --git a/paddle/platform/dynload/dynamic_loader.cc b/paddle/platform/dynload/dynamic_loader.cc index dd914e006d..ae9a0a982c 100644 --- a/paddle/platform/dynload/dynamic_loader.cc +++ b/paddle/platform/dynload/dynamic_loader.cc @@ -19,7 +19,7 @@ limitations under the License. */ #include #include "gflags/gflags.h" #include "glog/logging.h" -#include "paddle/framework/enforce.h" +#include "paddle/platform/enforce.h" DEFINE_string(cudnn_dir, "", "Specify path for loading libcudnn.so. For instance, " diff --git a/paddle/platform/error.h b/paddle/platform/error.h deleted file mode 100644 index 93424bb610..0000000000 --- a/paddle/platform/error.h +++ /dev/null @@ -1,87 +0,0 @@ -#pragma once - -#include -#include -#include - -#ifndef PADDLE_ONLY_CPU - -#include -#include -#include -#include -#include - -#endif // PADDLE_ONLY_CPU - -namespace paddle { -namespace platform { - -#ifndef PADDLE_ONLY_CPU - -inline void throw_on_error(cudaError_t e, const char* message) { - if (e) { - throw thrust::system_error(e, thrust::cuda_category(), message); - } -} - -inline void throw_on_error(curandStatus_t stat, const char* message) { - if (stat != CURAND_STATUS_SUCCESS) { - throw thrust::system_error(cudaErrorLaunchFailure, thrust::cuda_category(), - message); - } -} - -inline void throw_on_error(cudnnStatus_t stat, const char* message) { - std::stringstream ss; - if (stat == CUDNN_STATUS_SUCCESS) { - return; - } else { - ss << cudnnGetErrorString(stat); - ss << ", " << message; - throw std::runtime_error(ss.str()); - } -} - -inline void throw_on_error(cublasStatus_t stat, const char* message) { - std::stringstream ss; - if (stat == CUBLAS_STATUS_SUCCESS) { - return; - } else if (stat == CUBLAS_STATUS_NOT_INITIALIZED) { - ss << "CUBLAS: not initialized"; - } else if (stat == CUBLAS_STATUS_ALLOC_FAILED) { - ss << "CUBLAS: alloc failed"; - } else if (stat == CUBLAS_STATUS_INVALID_VALUE) { - ss << "CUBLAS: invalid value"; - } else if (stat == CUBLAS_STATUS_ARCH_MISMATCH) { - ss << "CUBLAS: arch mismatch"; - } else if (stat == CUBLAS_STATUS_MAPPING_ERROR) { - ss << "CUBLAS: mapping error"; - } else if (stat == CUBLAS_STATUS_EXECUTION_FAILED) { - ss << "CUBLAS: execution failed"; - } else if (stat == CUBLAS_STATUS_INTERNAL_ERROR) { - ss << "CUBLAS: internal error"; - } else if (stat == CUBLAS_STATUS_NOT_SUPPORTED) { - ss << "CUBLAS: not supported"; - } else if (stat == CUBLAS_STATUS_LICENSE_ERROR) { - ss << "CUBLAS: license error"; - } - ss << ", " << message; - throw std::runtime_error(ss.str()); -} - -inline void throw_on_error(cublasStatus_t stat) { - const char* message = ""; - throw_on_error(stat, message); -} - -#endif // PADDLE_ONLY_CPU - -inline void throw_on_error(int stat, const char* message) { - if (stat) { - throw std::runtime_error(message + (", stat = " + std::to_string(stat))); - } -} - -} // namespace platform -} // namespace paddle diff --git a/paddle/platform/gpu_info.cc b/paddle/platform/gpu_info.cc index a1383d3524..cf9921e870 100644 --- a/paddle/platform/gpu_info.cc +++ b/paddle/platform/gpu_info.cc @@ -14,7 +14,7 @@ limitations under the License. */ #include "paddle/platform/gpu_info.h" #include "gflags/gflags.h" -#include "paddle/platform/error.h" +#include "paddle/platform/enforce.h" DEFINE_double(fraction_of_gpu_memory_to_use, 0.95, "Default use 95% of GPU memory for PaddlePaddle," @@ -25,7 +25,7 @@ namespace platform { int GetDeviceCount() { int count; - throw_on_error( + PADDLE_ENFORCE( cudaGetDeviceCount(&count), "cudaGetDeviceCount failed in paddle::platform::GetDeviceCount"); return count; @@ -33,19 +33,19 @@ int GetDeviceCount() { int GetCurrentDeviceId() { int device_id; - throw_on_error( + PADDLE_ENFORCE( cudaGetDevice(&device_id), "cudaGetDevice failed in paddle::platform::GetCurrentDeviceId"); return device_id; } void SetDeviceId(int id) { - throw_on_error(cudaSetDevice(id), + PADDLE_ENFORCE(cudaSetDevice(id), "cudaSetDevice failed in paddle::platform::SetDeviceId"); } void GpuMemoryUsage(size_t& available, size_t& total) { - throw_on_error(cudaMemGetInfo(&available, &total), + PADDLE_ENFORCE(cudaMemGetInfo(&available, &total), "cudaMemGetInfo failed in paddle::platform::GetMemoryUsage"); } From 2680dca9c8cfc4087bdfd0a402e9b9ec116ea824 Mon Sep 17 00:00:00 2001 From: liaogang Date: Sat, 15 Jul 2017 01:37:02 +0800 Subject: [PATCH 13/37] ENH: add cuda enforce to PADDLE_ENFORCE --- paddle/platform/enforce.h | 160 ++++++++++++++++++++++++++++++++ paddle/platform/enforce_test.cc | 35 +++++++ 2 files changed, 195 insertions(+) create mode 100644 paddle/platform/enforce.h create mode 100644 paddle/platform/enforce_test.cc diff --git a/paddle/platform/enforce.h b/paddle/platform/enforce.h new file mode 100644 index 0000000000..0e40bd798c --- /dev/null +++ b/paddle/platform/enforce.h @@ -0,0 +1,160 @@ +/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. */ + +#pragma once + +#include +#include +#include +#include + +#ifndef PADDLE_ONLY_CPU + +#include "paddle/platform/dynload/cublas.h" +#include "paddle/platform/dynload/cudnn.h" +#include "paddle/platform/dynload/curand.h" + +#include +#include +#include +#include +#include + +#endif // PADDLE_ONLY_CPU + +namespace paddle { +namespace platform { + +/** + * @brief Enforce exception. Inherits std::exception + * + * All enforce condition not met, will throw an EnforceNotMet exception. + */ +class EnforceNotMet : public std::exception { + public: + EnforceNotMet(const std::string& msg, const char* file, int fileline) { + std::ostringstream sout; + sout << msg << " at [" << file << ":" << fileline << "];"; + all_msg_ = sout.str(); + } + + const char* what() const noexcept override { return all_msg_.c_str(); } + + private: + std::string all_msg_; +}; + +// From https://stackoverflow.com/questions/30130930/ +// __buildin_expect is in C++ 11 standard. Since the condition which enforced +// should be true in most situation, it will make the compiler generate faster +// code by adding `UNLIKELY` macro. +#define UNLIKELY(condition) __builtin_expect(static_cast(condition), 0) + +/** + * @brief Throw a EnforceNotMet exception, automatically filled __FILE__ & + * __LINE__ + * + * This macro take __VA_ARGS__, user can pass any type if that type can + * serialize to std::ostream + */ +#define PADDLE_THROW(...) \ + do { \ + throw ::paddle::platform::EnforceNotMet( \ + ::paddle::string::Sprintf(__VA_ARGS__), __FILE__, __LINE__); \ + } while (0) + +#ifndef PADDLE_ONLY_CPU + +template +inline void throw_on_error(cudaError_t e, const Args&... args) { + if (UNLIKELY(!(e))) { + std::stringstream ss; + ss << ::paddle::string::Sprintf(args...); + ss << ::paddle::string::Sprintf(" at [%s:%s];", __FILE__, __LINE__); + throw thrust::system_error(e, thrust::cuda_category(), ss.str()); + } +} + +template +inline void throw_on_error(curandStatus_t stat, const Args&... args) { + if (stat != CURAND_STATUS_SUCCESS) { + std::stringstream ss; + ss << ::paddle::string::Sprintf(args...); + ss << ::paddle::string::Sprintf(" at [%s:%s];", __FILE__, __LINE__); + throw thrust::system_error(cudaErrorLaunchFailure, thrust::cuda_category(), + ss.str()); + } +} + +template +inline void throw_on_error(cudnnStatus_t stat, const Args&... args) { + if (stat == CUDNN_STATUS_SUCCESS) { + return; + } else { + std::stringstream ss; + ss << ::paddle::platform::dynload::cudnnGetErrorString(stat); + ss << ", " << ::paddle::string::Sprintf(args...); + ss << ::paddle::string::Sprintf(" at [%s:%s];", __FILE__, __LINE__); + throw std::runtime_error(ss.str()); + } +} + +template +inline void throw_on_error(cublasStatus_t stat, const Args&... args) { + std::stringstream ss; + if (stat == CUBLAS_STATUS_SUCCESS) { + return; + } else if (stat == CUBLAS_STATUS_NOT_INITIALIZED) { + ss << "CUBLAS: not initialized"; + } else if (stat == CUBLAS_STATUS_ALLOC_FAILED) { + ss << "CUBLAS: alloc failed"; + } else if (stat == CUBLAS_STATUS_INVALID_VALUE) { + ss << "CUBLAS: invalid value"; + } else if (stat == CUBLAS_STATUS_ARCH_MISMATCH) { + ss << "CUBLAS: arch mismatch"; + } else if (stat == CUBLAS_STATUS_MAPPING_ERROR) { + ss << "CUBLAS: mapping error"; + } else if (stat == CUBLAS_STATUS_EXECUTION_FAILED) { + ss << "CUBLAS: execution failed"; + } else if (stat == CUBLAS_STATUS_INTERNAL_ERROR) { + ss << "CUBLAS: internal error"; + } else if (stat == CUBLAS_STATUS_NOT_SUPPORTED) { + ss << "CUBLAS: not supported"; + } else if (stat == CUBLAS_STATUS_LICENSE_ERROR) { + ss << "CUBLAS: license error"; + } + ss << ", " << ::paddle::string::Sprintf(args...); + ss << ::paddle::string::Sprintf(" at [%s:%s];", __FILE__, __LINE__); + throw std::runtime_error(ss.str()); +} + +#endif // PADDLE_ONLY_CPU + +template +inline void throw_on_error(int stat, const Args&... args) { + if (UNLIKELY(!(stat))) { + PADDLE_THROW(args...); + } +} + +/** + * @brief Enforce a condition, otherwise throw an EnforceNotMet + */ +#define PADDLE_ENFORCE(condition, ...) \ + do { \ + ::paddle::platform::throw_on_error(condition, __VA_ARGS__); \ + } while (0) + +} // namespace platform +} // namespace paddle diff --git a/paddle/platform/enforce_test.cc b/paddle/platform/enforce_test.cc new file mode 100644 index 0000000000..2d96b51ab0 --- /dev/null +++ b/paddle/platform/enforce_test.cc @@ -0,0 +1,35 @@ +/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve. +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + http://www.apache.org/licenses/LICENSE-2.0 +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. */ + +#include +#include + +TEST(ENFORCE, OK) { + PADDLE_ENFORCE(true, "Enforce is ok %d now %f", 123, 0.345); + size_t val = 1; + const size_t limit = 10; + PADDLE_ENFORCE(val < limit, "Enforce is OK too"); +} + +TEST(ENFORCE, FAILED) { + bool in_catch = false; + try { + PADDLE_ENFORCE(false, "Enforce is not ok %d at all", 123); + } catch (paddle::platform::EnforceNotMet err) { + in_catch = true; + std::string msg = "Enforce is not ok 123 at all"; + const char* what = err.what(); + for (size_t i = 0; i < msg.length(); ++i) { + ASSERT_EQ(what[i], msg[i]); + } + } + ASSERT_TRUE(in_catch); +} From 90c4cd8323ab7dc375e70ce9e84949854f58ec80 Mon Sep 17 00:00:00 2001 From: liaogang Date: Sat, 15 Jul 2017 08:29:55 +0800 Subject: [PATCH 14/37] FIX: header file --- paddle/framework/tensor.h | 2 +- paddle/platform/device_context.h | 5 +++-- 2 files changed, 4 insertions(+), 3 deletions(-) diff --git a/paddle/framework/tensor.h b/paddle/framework/tensor.h index 5fdbb4f07a..c6b9c00554 100644 --- a/paddle/framework/tensor.h +++ b/paddle/framework/tensor.h @@ -18,8 +18,8 @@ limitations under the License. */ #include #include #include "paddle/framework/ddim.h" -#include "paddle/platform/enforce.h" #include "paddle/memory/memory.h" +#include "paddle/platform/enforce.h" #include "paddle/platform/place.h" namespace paddle { diff --git a/paddle/platform/device_context.h b/paddle/platform/device_context.h index d2569fdc91..2dded7d79e 100644 --- a/paddle/platform/device_context.h +++ b/paddle/platform/device_context.h @@ -12,15 +12,16 @@ limitations under the License. */ #pragma once #include "paddle/platform/enforce.h" +#include "paddle/platform/place.h" + #ifndef PADDLE_ONLY_CPU #include "paddle/platform/dynload/cublas.h" #include "paddle/platform/dynload/cudnn.h" #include "paddle/platform/dynload/curand.h" -#include "paddle/platform/error.h" #include "paddle/platform/gpu_info.h" #define EIGEN_USE_GPU #endif -#include + #include #include From 7010a5da1a0d91da41fddd4799eff157efa19014 Mon Sep 17 00:00:00 2001 From: liaogang Date: Sat, 15 Jul 2017 10:09:08 +0800 Subject: [PATCH 15/37] FIX: throw_on_error on cuda --- paddle/platform/enforce.h | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/paddle/platform/enforce.h b/paddle/platform/enforce.h index 0e40bd798c..9431204a68 100644 --- a/paddle/platform/enforce.h +++ b/paddle/platform/enforce.h @@ -78,7 +78,7 @@ class EnforceNotMet : public std::exception { template inline void throw_on_error(cudaError_t e, const Args&... args) { - if (UNLIKELY(!(e))) { + if (e) { std::stringstream ss; ss << ::paddle::string::Sprintf(args...); ss << ::paddle::string::Sprintf(" at [%s:%s];", __FILE__, __LINE__); From d3373c5b853d0570842fbadedb1d969b94cef1bc Mon Sep 17 00:00:00 2001 From: Yu Yang Date: Tue, 18 Jul 2017 18:55:29 +0800 Subject: [PATCH 16/37] Skeleton Of fully connected operator * Fc operator is a grouped operator, which combined by may internal operators. * InferShape & Run a FC operator in Python. --- paddle/framework/attr_checker.h | 35 +++++++++ paddle/framework/net.cc | 6 +- paddle/framework/net.h | 5 +- paddle/operators/CMakeLists.txt | 6 +- paddle/operators/fc_op.cc | 76 +++++++++++++++++++ paddle/pybind/CMakeLists.txt | 2 +- paddle/pybind/pybind.cc | 46 ++++++----- .../paddle/v2/framework/tests/CMakeLists.txt | 2 +- .../paddle/v2/framework/tests/test_fc_op.py | 43 +++++++++++ 9 files changed, 195 insertions(+), 26 deletions(-) create mode 100644 paddle/operators/fc_op.cc create mode 100644 python/paddle/v2/framework/tests/test_fc_op.py diff --git a/paddle/framework/attr_checker.h b/paddle/framework/attr_checker.h index c0c33d8114..f2d88f3cb0 100644 --- a/paddle/framework/attr_checker.h +++ b/paddle/framework/attr_checker.h @@ -4,6 +4,7 @@ #include #include #include +#include #include #include "paddle/framework/enforce.h" @@ -41,6 +42,35 @@ class DefaultValueSetter { T default_value_; }; +template +class EnumInContainer { + public: + explicit EnumInContainer(const std::unordered_set& c) : container_(c) {} + void operator()(T& val) const { + PADDLE_ENFORCE(container_.find(val) != container_.end(), + "Value %s is not in enum container %s", val, + ContainerDebugString()); + } + + private: + std::string ContainerDebugString() const { + std::ostringstream sout; + sout << "["; + size_t cnt = 0; + for (auto& v : container_) { + sout << v; + ++cnt; + if (cnt != container_.size()) { + sout << " ,"; + } + } + sout << "]"; + return sout.str(); + } + + std::unordered_set container_; +}; + // check whether a certain attribute fit its limits // an attribute can have more than one limits template @@ -50,6 +80,11 @@ class TypedAttrChecker { public: TypedAttrChecker(const std::string& attr_name) : attr_name_(attr_name) {} + TypedAttrChecker& InEnum(const std::unordered_set& range) { + value_checkers_.push_back(EnumInContainer(range)); + return *this; + } + TypedAttrChecker& LargerThan(const T& lower_bound) { value_checkers_.push_back(LargerThanChecker(lower_bound)); return *this; diff --git a/paddle/framework/net.cc b/paddle/framework/net.cc index b9cd732d40..501536657d 100644 --- a/paddle/framework/net.cc +++ b/paddle/framework/net.cc @@ -19,7 +19,10 @@ namespace paddle { namespace framework { -void PlainNet::CompleteAddOp() { +void PlainNet::CompleteAddOp(bool calc) { + add_op_done_ = true; + if (!calc) return; + std::unordered_set input_set; std::unordered_set output_set; std::unordered_set temp_output; @@ -52,7 +55,6 @@ void PlainNet::CompleteAddOp() { } attrs_["temporary_index"] = tmp_index; - add_op_done_ = true; } std::string PlainNet::DebugString() const { diff --git a/paddle/framework/net.h b/paddle/framework/net.h index 33bb30ea07..19c5fa223b 100644 --- a/paddle/framework/net.h +++ b/paddle/framework/net.h @@ -16,7 +16,6 @@ limitations under the License. */ #include #include -#include "paddle/framework/net_proto.pb.h" #include "paddle/framework/op_proto.pb.h" #include "paddle/framework/op_registry.h" #include "paddle/framework/scope.h" @@ -41,7 +40,7 @@ namespace framework { class Net : public OperatorBase { public: virtual void AddOp(const OperatorPtr& op) = 0; - virtual void CompleteAddOp() = 0; + virtual void CompleteAddOp(bool calc) = 0; }; using NetPtr = std::shared_ptr; @@ -86,7 +85,7 @@ class PlainNet : public Net { ops_.push_back(op); } - void CompleteAddOp() override; + void CompleteAddOp(bool calculate = true) override; std::string DebugString() const override; diff --git a/paddle/operators/CMakeLists.txt b/paddle/operators/CMakeLists.txt index f47c3a4208..bc64bfd7ec 100644 --- a/paddle/operators/CMakeLists.txt +++ b/paddle/operators/CMakeLists.txt @@ -27,7 +27,8 @@ function(op_library TARGET) endif() list(LENGTH cu_srcs cu_srcs_len) - if (${cu_srcs_len} EQUAL 0) + list(LENGTH op_library_DEPS dep_len) + if (${cu_srcs_len} EQUAL 0 AND ${dep_len} EQUAL 0) message(WARNING "The op library ${TARGET} not support GPU!") endif() @@ -47,3 +48,6 @@ op_library(mul_op SRCS mul_op.cc mul_op.cu) op_library(rowwise_add_op SRCS rowwise_add_op.cu rowwise_add_op.cc) op_library(sigmoid_op SRCS sigmoid_op.cu sigmoid_op.cc) op_library(softmax_op SRCS softmax_op.cc softmax_op.cu) + +op_library(fc_op SRCS fc_op.cc DEPS mul_op rowwise_add_op sigmoid_op + softmax_op net) diff --git a/paddle/operators/fc_op.cc b/paddle/operators/fc_op.cc new file mode 100644 index 0000000000..01e96f4c48 --- /dev/null +++ b/paddle/operators/fc_op.cc @@ -0,0 +1,76 @@ +/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve. + + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. */ + +#include "paddle/framework/net.h" +#include "paddle/framework/op_registry.h" +#include "paddle/framework/operator.h" + +namespace paddle { +namespace operators { + +class FullyConnectedOp : public framework::PlainNet { +public: + void Init() override { + AddOp(framework::OpRegistry::CreateOp("mul", + { + Input("X"), Input("W"), + }, + {Output("before_act")}, + {})); + auto b = Input("b"); + if (b != framework::OperatorBase::EMPTY_VAR_NAME()) { + AddOp(framework::OpRegistry::CreateOp("rowwise_add", + {Output("before_act"), Input("b")}, + {Output("before_act")}, + {})); + } + + auto activation = GetAttr("activation"); + AddOp(framework::OpRegistry::CreateOp( + activation, {Output("before_act")}, {Output("Y")}, {})); + CompleteAddOp(false); + } +}; + +class FullyConnectedOpMaker : public framework::OpProtoAndCheckerMaker { +public: + FullyConnectedOpMaker(framework::OpProto *proto, + framework::OpAttrChecker *op_checker) + : OpProtoAndCheckerMaker(proto, op_checker) { + AddInput("X", "the input of fc operator"); + AddInput("W", "the weight of fc operator"); + AddInput("b", "the bias of fc operator"); + + AddOutput("Y", "the output of fc operator"); + AddOutput( + "before_act", "the before activation output of fc operator", true); + AddAttr("activation", "The activation key for fc layer") + .SetDefault("sigmoid") + .InEnum({"sigmoid", "softmax"}); + + //! TODO(yuyang18): Complete comment; + AddComment("FullyConnected Operator"); + } +}; +} // namespace operators +} // namespace paddle + +USE_OP(mul); +USE_OP(rowwise_add); +USE_OP(sigmoid); +USE_OP(softmax); + +REGISTER_OP(fc, + paddle::operators::FullyConnectedOp, + paddle::operators::FullyConnectedOpMaker); diff --git a/paddle/pybind/CMakeLists.txt b/paddle/pybind/CMakeLists.txt index 00b14a9432..29fb29c7c1 100644 --- a/paddle/pybind/CMakeLists.txt +++ b/paddle/pybind/CMakeLists.txt @@ -1,2 +1,2 @@ cc_library(paddle_pybind SHARED SRCS pybind.cc DEPS pybind python - add_op mul_op rowwise_add_op sigmoid_op softmax_op) + add_op fc_op) diff --git a/paddle/pybind/pybind.cc b/paddle/pybind/pybind.cc index fc9c6544c3..e0f4c02459 100644 --- a/paddle/pybind/pybind.cc +++ b/paddle/pybind/pybind.cc @@ -14,6 +14,7 @@ limitations under the License. */ #include #include +#include #include #include #include @@ -26,10 +27,7 @@ namespace py = pybind11; namespace pd = paddle::framework; USE_OP(add_two); -USE_OP(softmax); -USE_OP(mul); -USE_OP(rowwise_add); -USE_OP(sigmoid); +USE_OP_WITHOUT_KERNEL(fc); PYBIND11_PLUGIN(core) { py::module m("core", "C++ core of Paddle Paddle"); @@ -53,7 +51,9 @@ PYBIND11_PLUGIN(core) { self.mutable_data(paddle::platform::CPUPlace()); }) .def("set", paddle::pybind::PyTensorSetFromArray) - .def("set", paddle::pybind::PyTensorSetFromArray); + .def("set", paddle::pybind::PyTensorSetFromArray) + .def("shape", + [](pd::Tensor& self) { return pd::vectorize(self.dims()); }); py::class_(m, "Variable", R"DOC(Variable Class. @@ -83,15 +83,16 @@ All parameter, weight, gradient are variables in Paddle. //! @note: Be careful! PyBind will return std::string as an unicode, not //! Python str. If you want a str object, you should cast them in Python. - m.def("get_all_op_protos", []() -> std::vector { + m.def("get_all_op_protos", []() -> std::vector { auto& protos = pd::OpRegistry::protos(); - std::vector ret_values; + std::vector ret_values; for (auto it = protos.begin(); it != protos.end(); ++it) { PADDLE_ENFORCE(it->second.IsInitialized(), "OpProto must all be initialized"); - ret_values.emplace_back(); - PADDLE_ENFORCE(it->second.SerializeToString(&ret_values.back()), + std::string str; + PADDLE_ENFORCE(it->second.SerializeToString(&str), "Serialize OpProto Error. This could be a bug of Paddle."); + ret_values.push_back(py::bytes(str)); } return ret_values; }); @@ -101,17 +102,26 @@ All parameter, weight, gradient are variables in Paddle. .def("empty", pd::OperatorBase::EMPTY_VAR_NAME) .def("temp", pd::OperatorBase::TMP_VAR_NAME); + py::class_(m, "DeviceContext") + .def_static("cpu_context", []() -> paddle::platform::DeviceContext* { + return new paddle::platform::CPUDeviceContext(); + }); + py::class_(m, "Operator") .def("__str__", &pd::OperatorBase::DebugString) - .def_static("create", [](const std::string& protobin) { - pd::OpDesc desc; - PADDLE_ENFORCE(desc.ParsePartialFromString(protobin), - "Cannot parse user input to OpDesc"); - PADDLE_ENFORCE(desc.IsInitialized(), - "User OpDesc is not initialized, reason %s", - desc.InitializationErrorString()); - return pd::OpRegistry::CreateOp(desc); - }); + .def_static("create", + [](const std::string& protobin) { + pd::OpDesc desc; + PADDLE_ENFORCE(desc.ParsePartialFromString(protobin), + "Cannot parse user input to OpDesc"); + PADDLE_ENFORCE(desc.IsInitialized(), + "User OpDesc is not initialized, reason %s", + desc.InitializationErrorString()); + return pd::OpRegistry::CreateOp(desc); + }) + .def("infer_shape", &pd::OperatorBase::InferShape) + .def("run", &pd::OperatorBase::Run) + .def("outputs", [](const pd::OperatorPtr& op) { return op->outputs_; }); return m.ptr(); } diff --git a/python/paddle/v2/framework/tests/CMakeLists.txt b/python/paddle/v2/framework/tests/CMakeLists.txt index 4ce2bef6fc..b75b7442d1 100644 --- a/python/paddle/v2/framework/tests/CMakeLists.txt +++ b/python/paddle/v2/framework/tests/CMakeLists.txt @@ -1,3 +1,3 @@ add_python_test(test_framework test_protobuf.py test_scope.py test_default_scope_funcs.py test_op_creation_methods.py - test_tensor.py) + test_tensor.py test_fc_op.py) diff --git a/python/paddle/v2/framework/tests/test_fc_op.py b/python/paddle/v2/framework/tests/test_fc_op.py new file mode 100644 index 0000000000..59e7e61249 --- /dev/null +++ b/python/paddle/v2/framework/tests/test_fc_op.py @@ -0,0 +1,43 @@ +import paddle.v2.framework.core as core +import unittest +import numpy +import paddle.v2.framework.create_op_creation_methods as creation + + +class TestFc(unittest.TestCase): + def test_fc(self): + scope = core.Scope(None) + x = scope.create_var("X") + x_tensor = x.get_tensor() + x_tensor.set_dims([1000, 784]) + x_tensor.alloc_float() + + w = scope.create_var("W") + w_tensor = w.get_tensor() + w_tensor.set_dims([784, 100]) + w_tensor.alloc_float() + + w_tensor.set(numpy.random.random((784, 100)).astype("float32")) + + # Set a real numpy array here. + # x_tensor.set(numpy.array([])) + + op = creation.op_creations.fc(X="X", Y="Y", W="W") + + for out in op.outputs(): + if scope.get_var(out) is None: + scope.create_var(out).get_tensor() + + tensor = scope.get_var("Y").get_tensor() + op.infer_shape(scope) + self.assertEqual([1000, 100], tensor.shape()) + + ctx = core.DeviceContext.cpu_context() + + op.run(scope, ctx) + + # After complete all ops, check Y is expect or not. + + +if __name__ == '__main__': + unittest.main() From 3402b6ad39c5ac8ba40a6981e206e554490217ff Mon Sep 17 00:00:00 2001 From: Yu Yang Date: Tue, 18 Jul 2017 20:35:34 +0800 Subject: [PATCH 17/37] Add Unittest of add_two_op --- .../framework/create_op_creation_methods.py | 4 ++ .../paddle/v2/framework/tests/CMakeLists.txt | 2 +- .../paddle/v2/framework/tests/op_test_util.py | 50 +++++++++++++++++++ .../v2/framework/tests/test_add_two_op.py | 17 +++++++ 4 files changed, 72 insertions(+), 1 deletion(-) create mode 100644 python/paddle/v2/framework/tests/op_test_util.py create mode 100644 python/paddle/v2/framework/tests/test_add_two_op.py diff --git a/python/paddle/v2/framework/create_op_creation_methods.py b/python/paddle/v2/framework/create_op_creation_methods.py index c2a7ae7692..7248c3f52a 100644 --- a/python/paddle/v2/framework/create_op_creation_methods.py +++ b/python/paddle/v2/framework/create_op_creation_methods.py @@ -217,6 +217,10 @@ def create_op_creation_method(op_proto): return core.Operator.create(opdesc.SerializeToString()) __impl__.__doc__ = get_docstring_from_op_proto(op_proto) + __impl__.all_input_args = [var.name for var in op_proto.inputs] + __impl__.all_output_args = [var.name for var in op_proto.outputs] + __impl__.all_attr_args = [attr.name for attr in op_proto.attrs] + return __impl__ diff --git a/python/paddle/v2/framework/tests/CMakeLists.txt b/python/paddle/v2/framework/tests/CMakeLists.txt index b75b7442d1..f71009aa85 100644 --- a/python/paddle/v2/framework/tests/CMakeLists.txt +++ b/python/paddle/v2/framework/tests/CMakeLists.txt @@ -1,3 +1,3 @@ add_python_test(test_framework test_protobuf.py test_scope.py test_default_scope_funcs.py test_op_creation_methods.py - test_tensor.py test_fc_op.py) + test_tensor.py test_fc_op.py test_add_two_op.py) diff --git a/python/paddle/v2/framework/tests/op_test_util.py b/python/paddle/v2/framework/tests/op_test_util.py new file mode 100644 index 0000000000..237f9b7eb0 --- /dev/null +++ b/python/paddle/v2/framework/tests/op_test_util.py @@ -0,0 +1,50 @@ +import paddle.v2.framework.core as core +import unittest +import numpy +import paddle.v2.framework.create_op_creation_methods as creation + + +class OpTestMeta(type): + def __new__(cls, name, bases, attrs): + obj = super(OpTestMeta, cls).__new__(cls, name, bases, attrs) + + def test_all(self): + func = getattr(creation.op_creations, self.type, None) + self.assertIsNotNone(func) + + scope = core.Scope(None) + kwargs = dict() + + for in_name in func.all_input_args: + if hasattr(self, in_name): + kwargs[in_name] = in_name + var = scope.create_var(in_name).get_tensor() + arr = getattr(self, in_name) + var.set_dims(arr.shape) + var.set(arr) + else: + kwargs[in_name] = "@EMPTY@" + + for out_name in func.all_output_args: + if hasattr(self, out_name): + kwargs[out_name] = out_name + scope.create_var(out_name).get_tensor() + + for attr_name in func.all_attr_args: + if hasattr(self, attr_name): + kwargs[attr_name] = getattr(self, attr_name) + + op = func(**kwargs) + + op.infer_shape(scope) + + ctx = core.DeviceContext.cpu_context() + op.run(scope, ctx) + + for out_name in func.all_output_args: + actual = numpy.array(scope.get_var(out_name).get_tensor()) + expect = getattr(self, out_name) + numpy.testing.assert_almost_equal(actual, expect) + + obj.test_all = test_all + return obj diff --git a/python/paddle/v2/framework/tests/test_add_two_op.py b/python/paddle/v2/framework/tests/test_add_two_op.py new file mode 100644 index 0000000000..a06d7a78ec --- /dev/null +++ b/python/paddle/v2/framework/tests/test_add_two_op.py @@ -0,0 +1,17 @@ +import unittest +from op_test_util import OpTestMeta +import numpy + + +class TestAddOp(unittest.TestCase): + __metaclass__ = OpTestMeta + + def setUp(self): + self.type = "add_two" + self.X = numpy.random.random((342, 345)).astype("float32") + self.Y = numpy.random.random((342, 345)).astype("float32") + self.Out = self.X + self.Y + + +if __name__ == '__main__': + unittest.main() From 642d3c4687eb91c3a7fd026e3d8ae15957c8836d Mon Sep 17 00:00:00 2001 From: Yi Wang Date: Tue, 18 Jul 2017 15:05:33 -0700 Subject: [PATCH 18/37] Refactorize Tensor to Eigen convesion --- paddle/framework/ddim.h | 11 ---- paddle/framework/eigen.h | 103 ++++++++++++++++++++++++++++++++ paddle/framework/tensor.h | 60 ------------------- paddle/framework/tensor_types.h | 67 --------------------- 4 files changed, 103 insertions(+), 138 deletions(-) create mode 100644 paddle/framework/eigen.h delete mode 100644 paddle/framework/tensor_types.h diff --git a/paddle/framework/ddim.h b/paddle/framework/ddim.h index 070850375d..06c4c583b3 100644 --- a/paddle/framework/ddim.h +++ b/paddle/framework/ddim.h @@ -119,17 +119,6 @@ int arity(const DDim& ddim); std::ostream& operator<<(std::ostream&, const DDim&); -template -Eigen::DSizes ToEigenDSizes(const DDim& dims) { - int rank = arity(dims); - PADDLE_ENFORCE(rank == NDIMS, "DDim and NDIMS must be same"); - Eigen::DSizes dsizes; - for (int d = 0; d < rank; d++) { - dsizes[d] = dims[d]; - } - return dsizes; -} - } // namespace framework } // namespace paddle diff --git a/paddle/framework/eigen.h b/paddle/framework/eigen.h new file mode 100644 index 0000000000..edbbc2694a --- /dev/null +++ b/paddle/framework/eigen.h @@ -0,0 +1,103 @@ +/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. */ + +#pragma once + +#include "paddle/platform/tensor.h" +#include "unsupported/Eigen/CXX11/Tensor" + +namespace paddle { +namespace framework { + +// EigenDim converts paddle::platform::DDim into Eigen::DSizes. +template +struct EigenDim { + typedef Eigen::DSizes Type; + + static Type From(const DDim& dims) { + PADDLE_ENFORCE(arity(dims) == D, "D must match arity(DDim)"); + Type ret; + for (int d = 0; d < rank; d++) { + ret[d] = dims[d]; + } + return ret; + } +}; + +// Interpret paddle::platform::Tensor as EigenTensor and EigenConstTensor. +template +struct EigenTensor { + using Type = Eigen::TensorMap, + Eigen::Aligned>; + + using ConstType = + Eigen::TensorMap, + Eigen::Aligned> + ConstTensor; + + static Type From(Tensor& tensor, DDim dims) { + return Type(tensor.data(), EigenDim::From(dims)); + } + + static Type From(Tensor& tensor) { return From(tensor, tensor.dims_); } + + static ConstType From(const Tensor& tensor, DDim dims) { + return ConstType(tensor.data(), EigenDim::From(dims)); + } + + static ConstType From(const Tensor& tensor) { + return From(tensor, tensor.dims_); + } +}; + +// Interpret paddle::platform::Tensor as EigenVecotr and EigenConstVector. +template +struct EigenVector { + using EigenVector = + Eigen::TensorMap, + Eigen::Aligned>; + + using EigenConstVector = + Eigen::TensorMap, + Eigen::Aligned>; + + static Type From(Tensor& tensor) { return EigenTensor::From(tensor); } + + static ConstType From(const Tensor& tensor) { + return EigenTensor::From(tensor); + } +}; + +// Interpret paddle::platform::Tensor as EigenMatrix and EigenConstMatrix. +template +struct EigenMatrix { + template + using EigenMatrix = + Eigen::TensorMap, + Eigen::Aligned>; + + template + using EigenConstMatrix = + Eigen::TensorMap, + Eigen::Aligned>; + + static Type From(Tensor& tensor) { return EigenTensor::From(tensor); } + + static ConstType From(const Tensor& tensor) { + return EigenTensor::From(tensor); + } +}; + +} // namespace framework +} // namespace paddle diff --git a/paddle/framework/tensor.h b/paddle/framework/tensor.h index 4f07350e59..1235b53227 100644 --- a/paddle/framework/tensor.h +++ b/paddle/framework/tensor.h @@ -86,66 +86,6 @@ class Tensor { offset_); } - template - typename TTypes::Tensor shaped(DDim new_dims) { - Eigen::array dims = - paddle::framework::ToEigenDSizes(new_dims); - return typename TTypes::Tensor(raw_data(), dims); - } - - template - typename TTypes::Tensor tensor() { - return typename TTypes::Tensor( - raw_data(), paddle::framework::ToEigenDSizes(dims_)); - } - - // flat to rank = 1 - template - typename TTypes::Flat flat() { - return shaped(make_ddim({static_cast(product(dims_))})); - } - - // to TensorType Vec - template - typename TTypes::Vec vec() { - return tensor(); - } - - // to TensorType Matrix - template - typename TTypes::Matrix matrix() { - return tensor(); - } - - // const versions of all the methods above. - template - typename TTypes::Tensor shaped(DDim new_dims) const { - Eigen::array dims = - paddle::framework::ToEigenDSizes(new_dims); - return typename TTypes::Tensor(data(), dims); - } - - template - typename TTypes::ConstantTensor tensor() const { - return typename TTypes::Tensor( - data(), paddle::framework::ToEigenDSizes(dims_)); - } - - template - typename TTypes::ConstFlat flat() const { - return shaped(make_ddim({static_cast(product(dims_))})); - } - - template - typename TTypes::ConstVec vec() const { - return tensor(); - } - - template - typename TTypes::ConstMatrix matrix() const { - return tensor(); - } - template void ShareDataFrom(const Tensor& src) { src.CheckDims(); diff --git a/paddle/framework/tensor_types.h b/paddle/framework/tensor_types.h deleted file mode 100644 index 4bf27a377e..0000000000 --- a/paddle/framework/tensor_types.h +++ /dev/null @@ -1,67 +0,0 @@ -/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. */ - -#pragma once - -#include "unsupported/Eigen/CXX11/Tensor" - -namespace paddle { -namespace framework { - -// Helper to define Tensor types given that the scalar is of type T. -template -struct TTypes { - // Rank- tensor of scalar type T. - typedef Eigen::TensorMap, - Eigen::Aligned> - Tensor; - typedef Eigen::TensorMap< - Eigen::Tensor, Eigen::Aligned> - ConstTensor; - - // Scalar tensor (implemented as a rank-0 tensor) of scalar type T. - typedef Eigen::TensorMap< - Eigen::TensorFixedSize, Eigen::RowMajor, IndexType>, - Eigen::Aligned> - Scalar; - typedef Eigen::TensorMap, - Eigen::RowMajor, IndexType>, - Eigen::Aligned> - ConstScalar; - - // Rank-1 tensor (vector) of scalar type T. - typedef Eigen::TensorMap, - Eigen::Aligned> - Flat; - typedef Eigen::TensorMap< - Eigen::Tensor, Eigen::Aligned> - ConstFlat; - typedef Eigen::TensorMap, - Eigen::Aligned> - Vec; - typedef Eigen::TensorMap< - Eigen::Tensor, Eigen::Aligned> - ConstVec; - - // Rank-2 tensor (matrix) of scalar type T. - typedef Eigen::TensorMap, - Eigen::Aligned> - Matrix; - typedef Eigen::TensorMap< - Eigen::Tensor, Eigen::Aligned> - ConstMatrix; -}; - -} // namespace framework -} // namespace paddle From cb1d1f167c95b0c7ded6cb2c68d65de35765c6a5 Mon Sep 17 00:00:00 2001 From: Yi Wang Date: Tue, 18 Jul 2017 15:35:51 -0700 Subject: [PATCH 19/37] Add unit test --- paddle/framework/eigen_test.cc | 37 ++++++++++++++++++++++++++++++++++ paddle/framework/tensor.h | 15 +++++++++++--- 2 files changed, 49 insertions(+), 3 deletions(-) create mode 100644 paddle/framework/eigen_test.cc diff --git a/paddle/framework/eigen_test.cc b/paddle/framework/eigen_test.cc new file mode 100644 index 0000000000..c5f27a3298 --- /dev/null +++ b/paddle/framework/eigen_test.cc @@ -0,0 +1,37 @@ +/* + Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve. + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at + http://www.apache.org/licenses/LICENSE-2.0 + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. +*/ + +#include "paddle/framework/eigen.h" + +#include + +#include "paddle/framework/tensor.h" + +TEST(Eigen, Tensor) { + using paddle::platform::Tensor; + using paddle::platform::EigenTensor; + using paddle::platform::make_ddim; + + Tensor t; + float* p = t.mutable_data(make_ddim({1, 2, 3}), CPUPlace()); + for (int i = 0; i < 1 * 2 * 3; i++) { + p[i] = static_cast(i); + } + + EigenTensor::Type et = EigenTensor::From(t); + // TODO: check the content of et. +} + +TEST(Eigen, Vector) {} + +TEST(Eigen, Matrix) {} diff --git a/paddle/framework/tensor.h b/paddle/framework/tensor.h index 1235b53227..405393fb11 100644 --- a/paddle/framework/tensor.h +++ b/paddle/framework/tensor.h @@ -20,7 +20,6 @@ limitations under the License. */ #include #include "paddle/framework/ddim.h" #include "paddle/framework/enforce.h" -#include "paddle/framework/tensor_types.h" #include "paddle/memory/memory.h" #include "paddle/platform/place.h" #include "unsupported/Eigen/CXX11/Tensor" @@ -35,6 +34,18 @@ struct CastToPyBufferImpl; namespace framework { class Tensor { + template + friend struct paddle::pybind::details::CastToPyBufferImpl; + + template + friend struct EigenTensor; + + template + friend struct EigenVector; + + template + friend struct EigenMatrix; + public: Tensor() : offset_(0) {} @@ -191,8 +202,6 @@ class Tensor { std::shared_ptr holder_; // holds the memory block if allocated. DDim dims_; size_t offset_; // marks the begin of tensor data area. - template - friend struct paddle::pybind::details::CastToPyBufferImpl; }; } // namespace framework From 00ed56430782f953ab42e549fe94938271f9e194 Mon Sep 17 00:00:00 2001 From: Yi Wang Date: Tue, 18 Jul 2017 16:40:22 -0700 Subject: [PATCH 20/37] Update --- paddle/framework/CMakeLists.txt | 3 +++ paddle/operators/add_op.h | 6 ++++-- 2 files changed, 7 insertions(+), 2 deletions(-) diff --git a/paddle/framework/CMakeLists.txt b/paddle/framework/CMakeLists.txt index eb34164623..a00b9c8190 100644 --- a/paddle/framework/CMakeLists.txt +++ b/paddle/framework/CMakeLists.txt @@ -4,8 +4,11 @@ cc_test(enforce_test SRCS enforce_test.cc DEPS enforce) cc_library(ddim SRCS ddim.cc DEPS eigen3) cc_test(ddim_test SRCS ddim_test.cc DEPS ddim) nv_test(dim_test SRCS dim_test.cu DEPS ddim) + cc_library(tensor SRCS tensor.cc DEPS ddim place enforce paddle_memory) cc_test(tensor_test SRCS tensor_test.cc DEPS tensor) +cc_test(eigen_test SRCS eigen_test.cc DEPS tensor) + cc_test(variable_test SRCS variable_test.cc) cc_test(scope_test SRCS scope_test.cc) proto_library(attr_type SRCS attr_type.proto) diff --git a/paddle/operators/add_op.h b/paddle/operators/add_op.h index e08b3fb187..e7c106a23f 100644 --- a/paddle/operators/add_op.h +++ b/paddle/operators/add_op.h @@ -14,6 +14,7 @@ limitations under the License. */ #pragma once #include "glog/logging.h" +#include "paddle/framework/eigen.h" #include "paddle/framework/operator.h" namespace paddle { @@ -29,8 +30,9 @@ public: output->mutable_data(context.GetPlace()); - output->flat().device(*(context.GetEigenDevice())) = - input0.flat() + input1.flat(); + framework::EigenVector::From(*output).device( + *(context.GetEigenDevice())) = + framework::EigenVector(*input0) + framework::EigenVector(*input1); } }; From 2538e20787bf8e652a0acaf129fa73ce06abf20b Mon Sep 17 00:00:00 2001 From: Yi Wang Date: Tue, 18 Jul 2017 16:42:59 -0700 Subject: [PATCH 21/37] Fix wrong inclusion path --- paddle/framework/eigen.h | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/paddle/framework/eigen.h b/paddle/framework/eigen.h index edbbc2694a..28641a389f 100644 --- a/paddle/framework/eigen.h +++ b/paddle/framework/eigen.h @@ -14,7 +14,7 @@ limitations under the License. */ #pragma once -#include "paddle/platform/tensor.h" +#include "paddle/framework/tensor.h" #include "unsupported/Eigen/CXX11/Tensor" namespace paddle { From 3d4e808ce418dc95c3391eaabe24b2d9f4d0e33d Mon Sep 17 00:00:00 2001 From: Helin Wang Date: Wed, 19 Jul 2017 01:05:59 +0000 Subject: [PATCH 22/37] cmake: fix problem that go_library is never rebuilt. `merge_static_libs` also have the similar logic of using ${dummyfile}, I am not sure if there needs a change or not. --- cmake/generic.cmake | 22 +++++++++++++++++++++- 1 file changed, 21 insertions(+), 1 deletion(-) diff --git a/cmake/generic.cmake b/cmake/generic.cmake index e42e75c12a..534be0abe2 100644 --- a/cmake/generic.cmake +++ b/cmake/generic.cmake @@ -290,8 +290,22 @@ function(go_library TARGET_NAME) set(${TARGET_NAME}_LIB_NAME "${CMAKE_STATIC_LIBRARY_PREFIX}${TARGET_NAME}${CMAKE_STATIC_LIBRARY_SUFFIX}" CACHE STRING "output library name for target ${TARGET_NAME}") endif() - # Add dummy code to support `make target_name` under Terminal Command set(dummyfile ${CMAKE_CURRENT_BINARY_DIR}/${TARGET_NAME}_dummy.c) + + # This custom command will always run since it depends on a not + # existing file. + add_custom_command( + OUTPUT dummy_rebulid_${TARGET_NAME} + COMMAND cmake -E touch ${dummyfile} + ) + # Create a custom target that depends on the custom command output + # file, so the custom command can be referenced as a dependency by + # `add_dependencies`. + add_custom_target(rebuild_${TARGET_NAME} + DEPENDS dummy_rebulid_${TARGET_NAME} + ) + + # Add dummy code to support `make target_name` under Terminal Command file(WRITE ${dummyfile} "const char * dummy = \"${dummyfile}\";") if (go_library_SHARED OR go_library_shared) add_library(${TARGET_NAME} SHARED ${dummyfile}) @@ -302,6 +316,12 @@ function(go_library TARGET_NAME) add_dependencies(${TARGET_NAME} ${go_library_DEPS}) endif(go_library_DEPS) + # The "source file" of the library is `${dummyfile}` which never + # change, so the target will never rebuild. Make the target depends + # on the custom command that touches the library "source file", so + # rebuild will always happen. + add_dependencies(${TARGET_NAME} rebuild_${TARGET_NAME}) + set(${TARGET_NAME}_LIB_PATH "${CMAKE_CURRENT_BINARY_DIR}/${${TARGET_NAME}_LIB_NAME}" CACHE STRING "output library path for target ${TARGET_NAME}") file(GLOB GO_SOURCE RELATIVE "${CMAKE_CURRENT_SOURCE_DIR}" "*.go") From 1981eaf922f3636a9f49209757d52c527d2dbe96 Mon Sep 17 00:00:00 2001 From: Yi Wang Date: Tue, 18 Jul 2017 18:37:29 -0700 Subject: [PATCH 23/37] Fix Tensor::data interface --- paddle/framework/eigen.h | 21 ++++++++------------- paddle/framework/eigen_test.cc | 22 ++++++++++++++-------- paddle/framework/tensor.h | 8 ++++---- 3 files changed, 26 insertions(+), 25 deletions(-) diff --git a/paddle/framework/eigen.h b/paddle/framework/eigen.h index 28641a389f..cd87b042df 100644 --- a/paddle/framework/eigen.h +++ b/paddle/framework/eigen.h @@ -28,7 +28,7 @@ struct EigenDim { static Type From(const DDim& dims) { PADDLE_ENFORCE(arity(dims) == D, "D must match arity(DDim)"); Type ret; - for (int d = 0; d < rank; d++) { + for (int d = 0; d < arity(dims); d++) { ret[d] = dims[d]; } return ret; @@ -43,8 +43,7 @@ struct EigenTensor { using ConstType = Eigen::TensorMap, - Eigen::Aligned> - ConstTensor; + Eigen::Aligned>; static Type From(Tensor& tensor, DDim dims) { return Type(tensor.data(), EigenDim::From(dims)); @@ -64,11 +63,10 @@ struct EigenTensor { // Interpret paddle::platform::Tensor as EigenVecotr and EigenConstVector. template struct EigenVector { - using EigenVector = - Eigen::TensorMap, - Eigen::Aligned>; + using Type = Eigen::TensorMap, + Eigen::Aligned>; - using EigenConstVector = + using ConstType = Eigen::TensorMap, Eigen::Aligned>; @@ -82,13 +80,10 @@ struct EigenVector { // Interpret paddle::platform::Tensor as EigenMatrix and EigenConstMatrix. template struct EigenMatrix { - template - using EigenMatrix = - Eigen::TensorMap, - Eigen::Aligned>; + using Type = Eigen::TensorMap, + Eigen::Aligned>; - template - using EigenConstMatrix = + using ConstType = Eigen::TensorMap, Eigen::Aligned>; diff --git a/paddle/framework/eigen_test.cc b/paddle/framework/eigen_test.cc index c5f27a3298..23eec7533f 100644 --- a/paddle/framework/eigen_test.cc +++ b/paddle/framework/eigen_test.cc @@ -12,26 +12,32 @@ */ #include "paddle/framework/eigen.h" - #include -#include "paddle/framework/tensor.h" +namespace paddle { +namespace framework { -TEST(Eigen, Tensor) { - using paddle::platform::Tensor; - using paddle::platform::EigenTensor; - using paddle::platform::make_ddim; +TEST(EigenDim, From) { + EigenDim<3>::Type ed = EigenDim<3>::From(make_ddim({1, 2, 3})); + EXPECT_EQ(1, ed[0]); + EXPECT_EQ(2, ed[1]); + EXPECT_EQ(3, ed[2]); +} +TEST(Eigen, Tensor) { Tensor t; - float* p = t.mutable_data(make_ddim({1, 2, 3}), CPUPlace()); + float* p = t.mutable_data(make_ddim({1, 2, 3}), platform::CPUPlace()); for (int i = 0; i < 1 * 2 * 3; i++) { p[i] = static_cast(i); } - EigenTensor::Type et = EigenTensor::From(t); + EigenTensor::Type et = EigenTensor::From(t); // TODO: check the content of et. } TEST(Eigen, Vector) {} TEST(Eigen, Matrix) {} + +} // namespace platform +} // namespace paddle diff --git a/paddle/framework/tensor.h b/paddle/framework/tensor.h index 405393fb11..8fbf42e7f6 100644 --- a/paddle/framework/tensor.h +++ b/paddle/framework/tensor.h @@ -37,13 +37,13 @@ class Tensor { template friend struct paddle::pybind::details::CastToPyBufferImpl; - template + template friend struct EigenTensor; - template + template friend struct EigenVector; - template + template friend struct EigenMatrix; public: @@ -57,7 +57,7 @@ class Tensor { } template - T* raw_data() const { + T* data() { CheckDims(); return reinterpret_cast(reinterpret_cast(holder_->ptr()) + offset_); From 0a0b4caaa7cea5c2b205cc58cef08cdfb48de3c1 Mon Sep 17 00:00:00 2001 From: Yu Yang Date: Wed, 19 Jul 2017 13:07:28 +0800 Subject: [PATCH 24/37] Change Operator::create use py::bytes not std::string --- paddle/pybind/pybind.cc | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/paddle/pybind/pybind.cc b/paddle/pybind/pybind.cc index e0f4c02459..7e84550f77 100644 --- a/paddle/pybind/pybind.cc +++ b/paddle/pybind/pybind.cc @@ -110,7 +110,7 @@ All parameter, weight, gradient are variables in Paddle. py::class_(m, "Operator") .def("__str__", &pd::OperatorBase::DebugString) .def_static("create", - [](const std::string& protobin) { + [](py::bytes protobin) { pd::OpDesc desc; PADDLE_ENFORCE(desc.ParsePartialFromString(protobin), "Cannot parse user input to OpDesc"); From 3e7819c2762b5b9c93828844d4b4e201c996f5bf Mon Sep 17 00:00:00 2001 From: wanghaoshuang Date: Wed, 19 Jul 2017 13:47:17 +0800 Subject: [PATCH 25/37] 1. Reading image shape from input data instead of image_config 2. Add crop layer unitest 3. Fix bugs --- CMakeLists.txt | 2 +- paddle/function/CropOp.cpp | 34 ++++--- paddle/function/CropOp.h | 2 + paddle/function/CropOpGpu.cu | 24 ++--- paddle/gserver/layers/CropLayer.cpp | 89 +++++++++++-------- paddle/gserver/layers/CropLayer.h | 5 +- python/paddle/trainer/config_parser.py | 23 ----- .../paddle/trainer_config_helpers/layers.py | 12 ++- .../tests/configs/test_crop.py | 21 +++++ 9 files changed, 113 insertions(+), 99 deletions(-) create mode 100644 python/paddle/trainer_config_helpers/tests/configs/test_crop.py diff --git a/CMakeLists.txt b/CMakeLists.txt index 15a7c6b074..fdc62b3151 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -13,7 +13,7 @@ # limitations under the License cmake_minimum_required(VERSION 3.0) - +SET(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -ldl -lpthread") set(CMAKE_MODULE_PATH ${CMAKE_MODULE_PATH} "${CMAKE_CURRENT_SOURCE_DIR}/cmake") set(PROJ_ROOT ${CMAKE_CURRENT_SOURCE_DIR}) set(PROJ_BINARY_ROOT ${CMAKE_CURRENT_BINARY_DIR}) diff --git a/paddle/function/CropOp.cpp b/paddle/function/CropOp.cpp index 39e06fc120..f12ee43e3d 100644 --- a/paddle/function/CropOp.cpp +++ b/paddle/function/CropOp.cpp @@ -22,11 +22,10 @@ template <> void Crop(real* outputs, const real* inputs, const TensorShape inShape, + const TensorShape outShape, const FuncConfig& conf) { std::vector crop_corner = conf.get>("crop_corner"); - std::vector crop_shape = - conf.get>("crop_shape"); int cCrop = crop_corner[1]; int hCrop = crop_corner[2]; int wCrop = crop_corner[3]; @@ -36,9 +35,9 @@ void Crop(real* outputs, int inH = inShape[2]; int inW = inShape[3]; - int outC = crop_shape[1]; - int outH = crop_shape[2]; - int outW = crop_shape[3]; + int outC = outShape[1]; + int outH = outShape[2]; + int outW = outShape[3]; for (int n = 0; n < num; n++) { for (int c = 0; c < outC; c++) { @@ -54,12 +53,11 @@ void Crop(real* outputs, template <> void CropGrad(const real* inGrad, real* outGrad, + const TensorShape inShape, const TensorShape outShape, const FuncConfig& conf) { std::vector crop_corner = conf.get>("crop_corner"); - std::vector crop_shape = - conf.get>("crop_shape"); int cCrop = crop_corner[1]; int hCrop = crop_corner[2]; int wCrop = crop_corner[3]; @@ -69,9 +67,9 @@ void CropGrad(const real* inGrad, int outH = outShape[2]; int outW = outShape[3]; - int inC = crop_shape[1]; - int inH = crop_shape[2]; - int inW = crop_shape[3]; + int inC = inShape[1]; + int inH = inShape[2]; + int inW = inShape[3]; for (int n = 0; n < num; n++) { for (int c = 0; c < inC; c++) { @@ -123,9 +121,13 @@ public: CHECK_EQ(outputs[0].getArgType(), ASSIGN_TO); TensorShape inShape = inputs[0].shape(); + TensorShape outShape = outputs[0].shape(); - Crop( - outputs[0].data(), inputs[0].data(), inShape, conf_); + Crop(outputs[0].data(), + inputs[0].data(), + inShape, + outShape, + conf_); } private: @@ -152,9 +154,13 @@ public: CHECK_EQ(outputs[0].getArgType(), ADD_TO); TensorShape outShape = outputs[0].shape(); + TensorShape inShape = inputs[0].shape(); - CropGrad( - inputs[0].data(), outputs[0].data(), outShape, conf_); + CropGrad(inputs[0].data(), + outputs[0].data(), + inShape, + outShape, + conf_); } private: diff --git a/paddle/function/CropOp.h b/paddle/function/CropOp.h index 71e8c4c00e..87986fbdc7 100644 --- a/paddle/function/CropOp.h +++ b/paddle/function/CropOp.h @@ -31,6 +31,7 @@ template void Crop(real* outputs, const real* inputs, const TensorShape inShape, + const TensorShape outShape, const FuncConfig& conf); /** @@ -45,5 +46,6 @@ template void CropGrad(const real* inGrad, real* outGrad, const TensorShape inShape, + const TensorShape outShape, const FuncConfig& conf); } // namespace paddle diff --git a/paddle/function/CropOpGpu.cu b/paddle/function/CropOpGpu.cu index cadb58b6e9..37ce6de064 100644 --- a/paddle/function/CropOpGpu.cu +++ b/paddle/function/CropOpGpu.cu @@ -37,9 +37,9 @@ template <> void Crop(real* outputs, const real* inputs, const TensorShape inShape, + const TensorShape outShape, const FuncConfig& conf) { std::vector crop_corner = conf.get>("crop_corner"); - std::vector crop_shape = conf.get>("crop_shape"); int cropC = crop_corner[1]; int cropH = crop_corner[2]; int cropW = crop_corner[3]; @@ -49,14 +49,14 @@ void Crop(real* outputs, int inH = inShape[2]; int inW = inShape[3]; - int outC = crop_shape[1]; - int outH = crop_shape[2]; - int outW = crop_shape[3]; - + int outC = outShape[1]; + int outH = outShape[2]; + int outW = outShape[3]; + size_t nth = num * outC * outH * outW; int blockSize = 1024; int gridSize = (nth + blockSize - 1) / blockSize; - + KeCrop<<>> (outputs, inputs, inC, inH, inW, cropC, cropH, cropW, outC, outH, outW, nth); @@ -75,7 +75,7 @@ __global__ void KeCropDiff(const real* inGrad, real* outGrad, const int n = idx / inW / inH / inC; const int off = ((n * outC + c + cropC) * outH + h + cropH) * outW + cropW + w; - + outGrad[off] += inGrad[idx]; } } @@ -83,10 +83,10 @@ __global__ void KeCropDiff(const real* inGrad, real* outGrad, template <> void CropGrad(const real* inGrad, real* outGrad, + const TensorShape inShape, const TensorShape outShape, const FuncConfig& conf) { std::vector crop_corner = conf.get>("crop_corner"); - std::vector crop_shape = conf.get>("crop_shape"); int cropC = crop_corner[1]; int cropH = crop_corner[2]; int cropW = crop_corner[3]; @@ -96,10 +96,10 @@ void CropGrad(const real* inGrad, int outH = outShape[2]; int outW = outShape[3]; - int inC = crop_shape[1]; - int inH = crop_shape[2]; - int inW = crop_shape[3]; - + int inC = inShape[1]; + int inH = inShape[2]; + int inW = inShape[3]; + size_t nth = num * inC * inH * inW; int blockSize = 1024; int gridSize = (nth + blockSize - 1) / blockSize; diff --git a/paddle/gserver/layers/CropLayer.cpp b/paddle/gserver/layers/CropLayer.cpp index b2fa17b400..69ad913420 100644 --- a/paddle/gserver/layers/CropLayer.cpp +++ b/paddle/gserver/layers/CropLayer.cpp @@ -22,7 +22,8 @@ bool CropLayer::init(const LayerMap& layerMap, const ParameterMap& parameterMap) { /* Initialize the basic parent class */ Layer::init(layerMap, parameterMap); - + CHECK_LE(static_cast(inputLayers_.size()), 2); + CHECK_GE(static_cast(inputLayers_.size()), 1); crop_axis_ = config_.axis(); for (int i = 0; i < config_.offset_size(); i++) { crop_offsets_.push_back(config_.offset(i)); @@ -36,8 +37,14 @@ bool CropLayer::init(const LayerMap& layerMap, ? input0_img_conf.img_size_y() : input0_img_conf.img_size(), input0_img_conf.img_size()}); - // 2. get output shape from input_1 or crop shap conf - if (config_.inputs_size() == 2) { + // 2. get target dims from config + if (config_.inputs_size() == 1) { + targetDims_ = TensorShape({config_.shape(0), + config_.shape(1), + config_.shape(2), + config_.shape(3)}); + } else { + // 2. get input_1 shape auto& input1_img_conf = config_.inputs(1).image_conf(); targetDims_ = TensorShape({0, input1_img_conf.channels(), @@ -45,24 +52,10 @@ bool CropLayer::init(const LayerMap& layerMap, ? input1_img_conf.img_size_y() : input1_img_conf.img_size(), input1_img_conf.img_size()}); - } else { - targetDims_ = TensorShape({config_.shape(0), - config_.shape(1), - config_.shape(2), - config_.shape(3)}); } - // 3. get final crop shape + // 3. get final crop corner int dimSize = 4; - for (int i = 0; i < dimSize; i++) { - if (i >= crop_axis_) { - crop_shape_.push_back(targetDims_[i]); - } else { - crop_shape_.push_back(inDims_[i]); - } - } - - // 4. get final crop corner crop_corner_ = {0, 0, 0, 0}; for (int i = 0; i < dimSize; i++) { if (i >= crop_axis_) { @@ -75,43 +68,61 @@ bool CropLayer::init(const LayerMap& layerMap, } outDims_ = TensorShape(4); - setOutDims(0); - - createFunction(forward_, - "Crop", - FuncConfig() - .set("crop_corner", crop_corner_) - .set("crop_shape", crop_shape_)); - createFunction(backward_, - "CropGrad", - FuncConfig() - .set("crop_corner", crop_corner_) - .set("crop_shape", crop_shape_)); + + createFunction( + forward_, "Crop", FuncConfig().set("crop_corner", crop_corner_)); + createFunction( + backward_, "CropGrad", FuncConfig().set("crop_corner", crop_corner_)); return true; } -void CropLayer::setOutDims(const size_t batchSize) { - outDims_.reshape({batchSize, crop_shape_[1], crop_shape_[2], crop_shape_[3]}); +void CropLayer::setOutDims() { + MatrixPtr input = inputLayers_[1]->getOutputValue(); + size_t batchSize = input->getHeight(); + // get target dims from input_1 + if (config_.inputs_size() == 2) { + targetDims_.setDim(0, batchSize); + int ch = config_.inputs(0).image_conf().channels(); + if (ch != 0) targetDims_.setDim(1, ch); + int h = inputLayers_[1]->getOutput().getFrameHeight(); + if (h != 0) targetDims_.setDim(2, h); + int w = inputLayers_[1]->getOutput().getFrameWidth(); + if (w != 0) targetDims_.setDim(3, w); + } + // get final crop shape from target dims and crop axis + std::vector crop_shape; + int dimSize = 4; + for (int i = 0; i < dimSize; i++) { + if (i >= crop_axis_) { + crop_shape.push_back(targetDims_[i]); + } else { + crop_shape.push_back(inDims_[i]); + } + } + + outDims_.reshape( + {crop_shape[0], crop_shape[1], crop_shape[2], crop_shape[3]}); + output_.setFrameHeight(crop_shape[2]); + output_.setFrameWidth(crop_shape[3]); } -void CropLayer::setTensorDim(const size_t batchSize) { - CHECK_EQ(static_cast(inputLayers_.size()), 2); +void CropLayer::setInDims() { + MatrixPtr input = inputLayers_[0]->getOutputValue(); + size_t batchSize = input->getHeight(); inDims_.setDim(0, batchSize); int h = inputLayers_[0]->getOutput().getFrameHeight(); if (h != 0) inDims_.setDim(2, h); int w = inputLayers_[0]->getOutput().getFrameWidth(); if (w != 0) inDims_.setDim(3, w); - setOutDims(batchSize); } void CropLayer::forward(PassType passType) { Layer::forward(passType); - MatrixPtr input = inputLayers_[0]->getOutputValue(); - size_t batchSize = input->getHeight(); - setTensorDim(batchSize); + setInDims(); + setOutDims(); int size = outDims_[1] * outDims_[2] * outDims_[3]; - resetOutput(batchSize, size); + resetOutput(outDims_[0], size); MatrixPtr outV = getOutputValue(); REGISTER_TIMER_INFO("CropForward", getName().c_str()); diff --git a/paddle/gserver/layers/CropLayer.h b/paddle/gserver/layers/CropLayer.h index 23cede1c3f..6b62026210 100644 --- a/paddle/gserver/layers/CropLayer.h +++ b/paddle/gserver/layers/CropLayer.h @@ -39,13 +39,12 @@ public: void backward(const UpdateCallback& callback = nullptr) override; protected: - void setOutDims(const size_t batchSize); - void setTensorDim(const size_t batchSize); + void setOutDims(); + void setInDims(); int32_t crop_axis_; std::vector crop_offsets_; std::vector crop_corner_; - std::vector crop_shape_; TensorShape inDims_; TensorShape targetDims_; TensorShape outDims_; diff --git a/python/paddle/trainer/config_parser.py b/python/paddle/trainer/config_parser.py index e599fa85ff..6b50d9cbf7 100644 --- a/python/paddle/trainer/config_parser.py +++ b/python/paddle/trainer/config_parser.py @@ -2005,29 +2005,6 @@ class CropLayer(LayerBase): image_conf.img_size_y = input_layer.height image_conf.channels = input_layer.size / (input_layer.width * input_layer.height) - out_ch = image_conf.channels - out_h = image_conf.img_size - out_w = image_conf.img_size_y - if len(self.inputs) == 2: - # get channels, width and height from input_1 layer - input_layer = self.get_input_layer(1) - image_conf = self.config.inputs[1].image_conf - image_conf.img_size = input_layer.width - image_conf.img_size_y = input_layer.height - image_conf.channels = input_layer.size / (input_layer.width * - input_layer.height) - out_ch = image_conf.channels - out_h = image_conf.img_size_y - out_w = image_conf.img_size - else: - # set channels, width and heigth of current layer - if len(shape) > 2: - out_ch = shape[-3] - if len(shape) > 1: - out_h = shape[-2] - if len(shape) > 0: - out_w = shape[-1] - self.set_cnn_layer(name, out_h, out_w, out_ch) @config_layer('batch_norm') diff --git a/python/paddle/trainer_config_helpers/layers.py b/python/paddle/trainer_config_helpers/layers.py index b42cb02bff..5a7e91dd39 100755 --- a/python/paddle/trainer_config_helpers/layers.py +++ b/python/paddle/trainer_config_helpers/layers.py @@ -5881,9 +5881,9 @@ def prelu_layer(input, @wrap_name_default() @layer_support() -def crop_layer(input, axis, offset, shape=None, name=None, layer_attr=None): +def crop_layer(input, offset, axis=2, shape=None, name=None, layer_attr=None): """ - The crop layer crop images by offset and shape. User can set crop shape by + The crop layer crops images by offset and shape. User can set crop shape by args 'shape' explicitly or by reference input layer. @@ -5896,16 +5896,16 @@ def crop_layer(input, axis, offset, shape=None, name=None, layer_attr=None): :param input: The input layer.If two inputs were setted, the second input will be regarded as reference input :type input: LayerOutput or Sequence + :param offset: The crop offset + :type offset: Sequence :param axis: start axis to be cropped. To image input layer: - 0: batch size - 1: channels - 2: height - 3: width :type partial_sum: int - :param offset: The crop offset - :type offset: Sequence :param shape: The shape to be cropped. Default is None. - :type shape: Sqquence | None + :type shape: Sequence | None :param name: Name of this layer. :type name: basestring :return: LayerOutput object. @@ -5913,8 +5913,6 @@ def crop_layer(input, axis, offset, shape=None, name=None, layer_attr=None): """ if isinstance(input, LayerOutput): input = [input] - elif isinstance(input, Projection): - input = [input] else: assert isinstance(input, collections.Sequence) l = Layer( diff --git a/python/paddle/trainer_config_helpers/tests/configs/test_crop.py b/python/paddle/trainer_config_helpers/tests/configs/test_crop.py new file mode 100644 index 0000000000..8314a7e9a5 --- /dev/null +++ b/python/paddle/trainer_config_helpers/tests/configs/test_crop.py @@ -0,0 +1,21 @@ +from paddle.trainer_config_helpers import * + +settings(batch_size=1000, learning_rate=1e-5) + +data = data_layer(name='data', size=2016, height=48, width=42) +refernce_data = data_layer(name='data', size=768, height=16, width=16) + +conv = img_conv_layer( + input=data, + filter_size=3, + num_channels=1, + num_filters=16, + padding=1, + act=LinearActivation(), + bias_attr=True) + +pool = img_pool_layer(input=conv, pool_size=2, stride=2, pool_type=MaxPooling()) + +crop = crop_layer(input=[pool, refernce_data], axis=2) + +outputs(pad) From d9fa6159b7b9109e76c8841388c7811eeac2eb6b Mon Sep 17 00:00:00 2001 From: qijun Date: Wed, 19 Jul 2017 14:06:58 +0800 Subject: [PATCH 26/37] add Flatten method to EigenVector --- paddle/framework/eigen.h | 15 +++++++++++++-- paddle/framework/eigen_test.cc | 6 +++++- paddle/operators/add_op.h | 5 +++-- 3 files changed, 21 insertions(+), 5 deletions(-) diff --git a/paddle/framework/eigen.h b/paddle/framework/eigen.h index cd87b042df..f5865635be 100644 --- a/paddle/framework/eigen.h +++ b/paddle/framework/eigen.h @@ -23,7 +23,7 @@ namespace framework { // EigenDim converts paddle::platform::DDim into Eigen::DSizes. template struct EigenDim { - typedef Eigen::DSizes Type; + using Type = Eigen::DSizes; static Type From(const DDim& dims) { PADDLE_ENFORCE(arity(dims) == D, "D must match arity(DDim)"); @@ -69,12 +69,23 @@ struct EigenVector { using ConstType = Eigen::TensorMap, Eigen::Aligned>; - + // From is to transfer a one dimension Tensor into a one dimension EigenVector static Type From(Tensor& tensor) { return EigenTensor::From(tensor); } + // Flatten is to reshape a Tensor into a one dimension EigenVector + static Type Flatten(Tensor& tensor) { + return EigenTensor::From( + tensor, make_ddim({static_cast(product(tensor.dims_))})); + } + static ConstType From(const Tensor& tensor) { return EigenTensor::From(tensor); } + + static ConstType Flatten(const Tensor& tensor) { + return EigenTensor::From( + tensor, make_ddim({static_cast(product(tensor.dims_))})); + } }; // Interpret paddle::platform::Tensor as EigenMatrix and EigenConstMatrix. diff --git a/paddle/framework/eigen_test.cc b/paddle/framework/eigen_test.cc index 23eec7533f..eca2dce60e 100644 --- a/paddle/framework/eigen_test.cc +++ b/paddle/framework/eigen_test.cc @@ -32,6 +32,10 @@ TEST(Eigen, Tensor) { } EigenTensor::Type et = EigenTensor::From(t); + + for (int i = 0; i < 1 * 2 * 3; i++) { + EXPECT_EQ(et(i), i); + } // TODO: check the content of et. } @@ -39,5 +43,5 @@ TEST(Eigen, Vector) {} TEST(Eigen, Matrix) {} -} // namespace platform +} // namespace framework } // namespace paddle diff --git a/paddle/operators/add_op.h b/paddle/operators/add_op.h index e7c106a23f..39d54a63bd 100644 --- a/paddle/operators/add_op.h +++ b/paddle/operators/add_op.h @@ -30,9 +30,10 @@ public: output->mutable_data(context.GetPlace()); - framework::EigenVector::From(*output).device( + framework::EigenVector::Flatten(*output).device( *(context.GetEigenDevice())) = - framework::EigenVector(*input0) + framework::EigenVector(*input1); + framework::EigenVector::Flatten(input0) + + framework::EigenVector::Flatten(input1); } }; From cff8762f2628e5abf562387bc076f23c6c029f7c Mon Sep 17 00:00:00 2001 From: qijun Date: Wed, 19 Jul 2017 14:33:07 +0800 Subject: [PATCH 27/37] add more uinttest for EigenTensor --- paddle/framework/eigen_test.cc | 65 +++++++++++++++++++++++++++++++--- 1 file changed, 61 insertions(+), 4 deletions(-) diff --git a/paddle/framework/eigen_test.cc b/paddle/framework/eigen_test.cc index eca2dce60e..b954c8d857 100644 --- a/paddle/framework/eigen_test.cc +++ b/paddle/framework/eigen_test.cc @@ -33,15 +33,72 @@ TEST(Eigen, Tensor) { EigenTensor::Type et = EigenTensor::From(t); + EXPECT_EQ(1, et.dimension(0)); + EXPECT_EQ(2, et.dimension(1)); + EXPECT_EQ(3, et.dimension(2)); + + for (int i = 0; i < 1; i++) { + for (int j = 0; j < 2; j++) { + for (int k = 0; k < 3; k++) { + EXPECT_EQ((i * 2 + j) * 3 + k, et(i, j, k)); + } + } + } + for (int i = 0; i < 1 * 2 * 3; i++) { + EXPECT_EQ(i, et(i)); + } +} + +TEST(Eigen, VectorFrom) { + Tensor t; + float* p = t.mutable_data(make_ddim({6}), platform::CPUPlace()); + for (int i = 0; i < 6; i++) { + p[i] = static_cast(i); + } + + EigenVector::Type ev = EigenVector::From(t); + + EXPECT_EQ(6, ev.dimension(0)); + + for (int i = 0; i < 6; i++) { + EXPECT_EQ(i, ev(i)); + } +} + +TEST(Eigen, VectorFlatten) { + Tensor t; + float* p = t.mutable_data(make_ddim({1, 2, 3}), platform::CPUPlace()); + for (int i = 0; i < 1 * 2 * 3; i++) { + p[i] = static_cast(i); + } + + EigenVector::Type ev = EigenVector::Flatten(t); + + EXPECT_EQ(1 * 2 * 3, ev.dimension(0)); + for (int i = 0; i < 1 * 2 * 3; i++) { - EXPECT_EQ(et(i), i); + EXPECT_EQ(i, ev(i)); } - // TODO: check the content of et. } -TEST(Eigen, Vector) {} +TEST(Eigen, Matrix) { + Tensor t; + float* p = t.mutable_data(make_ddim({2, 3}), platform::CPUPlace()); + for (int i = 0; i < 2 * 3; i++) { + p[i] = static_cast(i); + } + + EigenMatrix::Type em = EigenMatrix::From(t); -TEST(Eigen, Matrix) {} + EXPECT_EQ(2, em.dimension(0)); + EXPECT_EQ(3, em.dimension(1)); + + for (int i = 0; i < 2; i++) { + for (int j = 0; j < 3; j++) { + EXPECT_EQ(i * 3 + j, em(i, j)); + } + } +} } // namespace framework } // namespace paddle From fab896c5a0219f2ffdc2ca034106407a98ddce65 Mon Sep 17 00:00:00 2001 From: liaogang Date: Wed, 19 Jul 2017 15:01:29 +0800 Subject: [PATCH 28/37] Remove using namespace --- paddle/platform/enforce_test.cc | 2 -- 1 file changed, 2 deletions(-) diff --git a/paddle/platform/enforce_test.cc b/paddle/platform/enforce_test.cc index 0a7ccd0819..d7152f8150 100644 --- a/paddle/platform/enforce_test.cc +++ b/paddle/platform/enforce_test.cc @@ -12,8 +12,6 @@ limitations under the License. */ #include "paddle/platform/enforce.h" #include "gtest/gtest.h" -using namespace paddle; - TEST(ENFORCE, OK) { PADDLE_ENFORCE(true, "Enforce is ok %d now %f", 123, 0.345); size_t val = 1; From 97c2a9a9698b8e8364ed99b66ea4232c527ed042 Mon Sep 17 00:00:00 2001 From: liaogang Date: Wed, 19 Jul 2017 16:43:56 +0800 Subject: [PATCH 29/37] Fix: compiler error under gpu --- paddle/platform/enforce.h | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/paddle/platform/enforce.h b/paddle/platform/enforce.h index b6707659f2..6c1cd443c9 100644 --- a/paddle/platform/enforce.h +++ b/paddle/platform/enforce.h @@ -76,7 +76,7 @@ inline void throw_on_error(cudnnStatus_t stat, const Args&... args) { } else { // clang-format off throw std::runtime_error( - platform::dynload::cudnnGetErrorString(stat) + ", " + + platform::dynload::cudnnGetErrorString(stat) + string::Sprintf(args...) + string::Sprintf(" at [%s:%s];", __FILE__, __LINE__)); // clang-format on @@ -107,7 +107,8 @@ inline void throw_on_error(cublasStatus_t stat, const Args&... args) { } else if (stat == CUBLAS_STATUS_LICENSE_ERROR) { ss << "CUBLAS: license error"; } - throw std::runtime_error(ss + ", " + string::Sprintf(args...) + + ss << ", "; + throw std::runtime_error(ss + string::Sprintf(args...) + string::Sprintf(" at [%s:%s];", __FILE__, __LINE__)); } From 676b76da4a6600ede7a59078290743e5b8076ba8 Mon Sep 17 00:00:00 2001 From: wanghaoshuang Date: Wed, 19 Jul 2017 16:47:03 +0800 Subject: [PATCH 30/37] fix cmake --- CMakeLists.txt | 1 - 1 file changed, 1 deletion(-) diff --git a/CMakeLists.txt b/CMakeLists.txt index 9a85224843..2a6b0a20e4 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -13,7 +13,6 @@ # limitations under the License cmake_minimum_required(VERSION 3.0) -SET(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -ldl -lpthread") set(CMAKE_MODULE_PATH ${CMAKE_MODULE_PATH} "${CMAKE_CURRENT_SOURCE_DIR}/cmake") set(PROJ_ROOT ${CMAKE_CURRENT_SOURCE_DIR}) set(PROJ_BINARY_ROOT ${CMAKE_CURRENT_BINARY_DIR}) From 57c27b4e0013e4d3f51b41ae6950f70ae11be2e1 Mon Sep 17 00:00:00 2001 From: qijun Date: Wed, 19 Jul 2017 17:11:49 +0800 Subject: [PATCH 31/37] make EigenTensor default unaligned and follow comments --- paddle/framework/eigen.h | 53 +++++++++++---------------------------- paddle/framework/tensor.h | 7 ++---- 2 files changed, 16 insertions(+), 44 deletions(-) diff --git a/paddle/framework/eigen.h b/paddle/framework/eigen.h index f5865635be..4ba4fd4d11 100644 --- a/paddle/framework/eigen.h +++ b/paddle/framework/eigen.h @@ -36,14 +36,15 @@ struct EigenDim { }; // Interpret paddle::platform::Tensor as EigenTensor and EigenConstTensor. -template +template struct EigenTensor { - using Type = Eigen::TensorMap, - Eigen::Aligned>; + // TODO(qijun) Now, default type in unaligned, and we will make a benchmark on + // the speed of aligned and unaligned version in future. + using Type = Eigen::TensorMap>; using ConstType = - Eigen::TensorMap, - Eigen::Aligned>; + Eigen::TensorMap>; static Type From(Tensor& tensor, DDim dims) { return Type(tensor.data(), EigenDim::From(dims)); @@ -60,50 +61,24 @@ struct EigenTensor { } }; -// Interpret paddle::platform::Tensor as EigenVecotr and EigenConstVector. -template -struct EigenVector { - using Type = Eigen::TensorMap, - Eigen::Aligned>; - - using ConstType = - Eigen::TensorMap, - Eigen::Aligned>; - // From is to transfer a one dimension Tensor into a one dimension EigenVector - static Type From(Tensor& tensor) { return EigenTensor::From(tensor); } - +template +struct EigenVector : public EigenTensor { // Flatten is to reshape a Tensor into a one dimension EigenVector - static Type Flatten(Tensor& tensor) { + static typename EigenTensor::Type Flatten(Tensor& tensor) { return EigenTensor::From( tensor, make_ddim({static_cast(product(tensor.dims_))})); } - static ConstType From(const Tensor& tensor) { - return EigenTensor::From(tensor); - } - - static ConstType Flatten(const Tensor& tensor) { + static typename EigenTensor::ConstType Flatten(const Tensor& tensor) { return EigenTensor::From( tensor, make_ddim({static_cast(product(tensor.dims_))})); } }; -// Interpret paddle::platform::Tensor as EigenMatrix and EigenConstMatrix. -template -struct EigenMatrix { - using Type = Eigen::TensorMap, - Eigen::Aligned>; - - using ConstType = - Eigen::TensorMap, - Eigen::Aligned>; - - static Type From(Tensor& tensor) { return EigenTensor::From(tensor); } - - static ConstType From(const Tensor& tensor) { - return EigenTensor::From(tensor); - } -}; +template +using EigenMatrix = EigenTensor; } // namespace framework } // namespace paddle diff --git a/paddle/framework/tensor.h b/paddle/framework/tensor.h index 8fbf42e7f6..8fd131cf89 100644 --- a/paddle/framework/tensor.h +++ b/paddle/framework/tensor.h @@ -37,15 +37,12 @@ class Tensor { template friend struct paddle::pybind::details::CastToPyBufferImpl; - template + template friend struct EigenTensor; - template + template friend struct EigenVector; - template - friend struct EigenMatrix; - public: Tensor() : offset_(0) {} From d6d057b4e8187df049f6f3ad7879fa045f2fc816 Mon Sep 17 00:00:00 2001 From: qijun Date: Wed, 19 Jul 2017 20:21:49 +0800 Subject: [PATCH 32/37] change EQ to NEAR for float value --- paddle/framework/eigen_test.cc | 31 ++++++++++++++----------------- 1 file changed, 14 insertions(+), 17 deletions(-) diff --git a/paddle/framework/eigen_test.cc b/paddle/framework/eigen_test.cc index b954c8d857..a9fa728e49 100644 --- a/paddle/framework/eigen_test.cc +++ b/paddle/framework/eigen_test.cc @@ -19,9 +19,9 @@ namespace framework { TEST(EigenDim, From) { EigenDim<3>::Type ed = EigenDim<3>::From(make_ddim({1, 2, 3})); - EXPECT_EQ(1, ed[0]); - EXPECT_EQ(2, ed[1]); - EXPECT_EQ(3, ed[2]); + ASSERT_EQ(1, ed[0]); + ASSERT_EQ(2, ed[1]); + ASSERT_EQ(3, ed[2]); } TEST(Eigen, Tensor) { @@ -33,20 +33,17 @@ TEST(Eigen, Tensor) { EigenTensor::Type et = EigenTensor::From(t); - EXPECT_EQ(1, et.dimension(0)); - EXPECT_EQ(2, et.dimension(1)); - EXPECT_EQ(3, et.dimension(2)); + ASSERT_EQ(1, et.dimension(0)); + ASSERT_EQ(2, et.dimension(1)); + ASSERT_EQ(3, et.dimension(2)); for (int i = 0; i < 1; i++) { for (int j = 0; j < 2; j++) { for (int k = 0; k < 3; k++) { - EXPECT_EQ((i * 2 + j) * 3 + k, et(i, j, k)); + ASSERT_NEAR((i * 2 + j) * 3 + k, et(i, j, k), 1e-6f); } } } - for (int i = 0; i < 1 * 2 * 3; i++) { - EXPECT_EQ(i, et(i)); - } } TEST(Eigen, VectorFrom) { @@ -58,10 +55,10 @@ TEST(Eigen, VectorFrom) { EigenVector::Type ev = EigenVector::From(t); - EXPECT_EQ(6, ev.dimension(0)); + ASSERT_EQ(6, ev.dimension(0)); for (int i = 0; i < 6; i++) { - EXPECT_EQ(i, ev(i)); + ASSERT_NEAR(i, ev(i), 1e-6f); } } @@ -74,10 +71,10 @@ TEST(Eigen, VectorFlatten) { EigenVector::Type ev = EigenVector::Flatten(t); - EXPECT_EQ(1 * 2 * 3, ev.dimension(0)); + ASSERT_EQ(1 * 2 * 3, ev.dimension(0)); for (int i = 0; i < 1 * 2 * 3; i++) { - EXPECT_EQ(i, ev(i)); + ASSERT_NEAR(i, ev(i), 1e-6f); } } @@ -90,12 +87,12 @@ TEST(Eigen, Matrix) { EigenMatrix::Type em = EigenMatrix::From(t); - EXPECT_EQ(2, em.dimension(0)); - EXPECT_EQ(3, em.dimension(1)); + ASSERT_EQ(2, em.dimension(0)); + ASSERT_EQ(3, em.dimension(1)); for (int i = 0; i < 2; i++) { for (int j = 0; j < 3; j++) { - EXPECT_EQ(i * 3 + j, em(i, j)); + ASSERT_NEAR(i * 3 + j, em(i, j), 1e-6f); } } } From 2d2ee47bda7ad98956b914f2d81faf5e09b09eef Mon Sep 17 00:00:00 2001 From: liaogang Date: Wed, 19 Jul 2017 20:24:07 +0800 Subject: [PATCH 33/37] FIX: fix string --- paddle/platform/enforce.h | 23 +++++++++++------------ 1 file changed, 11 insertions(+), 12 deletions(-) diff --git a/paddle/platform/enforce.h b/paddle/platform/enforce.h index 6c1cd443c9..5d440dec48 100644 --- a/paddle/platform/enforce.h +++ b/paddle/platform/enforce.h @@ -85,30 +85,29 @@ inline void throw_on_error(cudnnStatus_t stat, const Args&... args) { template inline void throw_on_error(cublasStatus_t stat, const Args&... args) { - std::stringstream ss; + std::string err; if (stat == CUBLAS_STATUS_SUCCESS) { return; } else if (stat == CUBLAS_STATUS_NOT_INITIALIZED) { - ss << "CUBLAS: not initialized"; + err = "CUBLAS: not initialized, "; } else if (stat == CUBLAS_STATUS_ALLOC_FAILED) { - ss << "CUBLAS: alloc failed"; + err = "CUBLAS: alloc failed, "; } else if (stat == CUBLAS_STATUS_INVALID_VALUE) { - ss << "CUBLAS: invalid value"; + err = "CUBLAS: invalid value, "; } else if (stat == CUBLAS_STATUS_ARCH_MISMATCH) { - ss << "CUBLAS: arch mismatch"; + err = "CUBLAS: arch mismatch, "; } else if (stat == CUBLAS_STATUS_MAPPING_ERROR) { - ss << "CUBLAS: mapping error"; + err = "CUBLAS: mapping error, "; } else if (stat == CUBLAS_STATUS_EXECUTION_FAILED) { - ss << "CUBLAS: execution failed"; + err = "CUBLAS: execution failed, "; } else if (stat == CUBLAS_STATUS_INTERNAL_ERROR) { - ss << "CUBLAS: internal error"; + err = "CUBLAS: internal error, "; } else if (stat == CUBLAS_STATUS_NOT_SUPPORTED) { - ss << "CUBLAS: not supported"; + err = "CUBLAS: not supported, "; } else if (stat == CUBLAS_STATUS_LICENSE_ERROR) { - ss << "CUBLAS: license error"; + err = "CUBLAS: license error, "; } - ss << ", "; - throw std::runtime_error(ss + string::Sprintf(args...) + + throw std::runtime_error(err + string::Sprintf(args...) + string::Sprintf(" at [%s:%s];", __FILE__, __LINE__)); } From e3b27d19982b6eef33329ab0e9dcf718dd4c343e Mon Sep 17 00:00:00 2001 From: Qiao Longfei Date: Wed, 19 Jul 2017 22:30:43 +0800 Subject: [PATCH 34/37] Add sgd op (#2950) * a simplest SGD op --- paddle/operators/CMakeLists.txt | 2 + paddle/operators/sgd_op.cc | 61 +++++++++++++++++++ paddle/operators/sgd_op.cu | 5 ++ paddle/operators/sgd_op.h | 39 ++++++++++++ paddle/operators/sgd_op_test.cc | 22 +++++++ paddle/pybind/CMakeLists.txt | 2 +- paddle/pybind/pybind.cc | 1 + .../paddle/v2/framework/tests/CMakeLists.txt | 2 +- .../paddle/v2/framework/tests/test_sgd_op.py | 18 ++++++ 9 files changed, 150 insertions(+), 2 deletions(-) create mode 100644 paddle/operators/sgd_op.cc create mode 100644 paddle/operators/sgd_op.cu create mode 100644 paddle/operators/sgd_op.h create mode 100644 paddle/operators/sgd_op_test.cc create mode 100644 python/paddle/v2/framework/tests/test_sgd_op.py diff --git a/paddle/operators/CMakeLists.txt b/paddle/operators/CMakeLists.txt index bc64bfd7ec..a37720e509 100644 --- a/paddle/operators/CMakeLists.txt +++ b/paddle/operators/CMakeLists.txt @@ -51,3 +51,5 @@ op_library(softmax_op SRCS softmax_op.cc softmax_op.cu) op_library(fc_op SRCS fc_op.cc DEPS mul_op rowwise_add_op sigmoid_op softmax_op net) + +op_library(sgd_op SRCS sgd_op.cc sgd_op.cu) diff --git a/paddle/operators/sgd_op.cc b/paddle/operators/sgd_op.cc new file mode 100644 index 0000000000..04df87a3ad --- /dev/null +++ b/paddle/operators/sgd_op.cc @@ -0,0 +1,61 @@ +/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. */ + +#include "paddle/operators/sgd_op.h" +#include "paddle/framework/op_registry.h" +#include "paddle/framework/tensor.h" + +namespace paddle { +namespace operators { + +class SGDOp : public framework::OperatorWithKernel { +protected: + void InferShape( + const std::vector &inputs, + const std::vector &outputs) const override { + PADDLE_ENFORCE(inputs.size() == 2, "Input size of SGDOp must be two"); + PADDLE_ENFORCE(outputs.size() == 1, "Output size of SGDOp must be one"); + PADDLE_ENFORCE(inputs[0] != nullptr, "inputs[0] mast be set"); + PADDLE_ENFORCE(inputs[1] != nullptr, "inputs[1] mast be set"); + PADDLE_ENFORCE(outputs[0] != nullptr, "outputs[0] mast be set"); + PADDLE_ENFORCE(inputs[0]->dims() == inputs[1]->dims(), + "Two input of SGD Op's dimension must be same."); + outputs[0]->set_dims(inputs[0]->dims()); + } +}; + +class SGDOpMaker : public framework::OpProtoAndCheckerMaker { +public: + SGDOpMaker(framework::OpProto *proto, framework::OpAttrChecker *op_checker) + : framework::OpProtoAndCheckerMaker(proto, op_checker) { + AddInput("param", "input parameter"); + AddInput("grad", "input gradient"); + AddOutput("param_out", "output parameter"); + AddAttr("learning_rate", "learning rate of sgd"); + AddComment(R"DOC( + +Simplest sgd algorithm. + +param_out = param - learning_rate * grad; + +)DOC"); + } +}; +} // namespace operators +} // namespace paddle + +REGISTER_OP(sgd, paddle::operators::SGDOp, paddle::operators::SGDOpMaker); +typedef paddle::operators::SGDOpKernel<::paddle::platform::CPUPlace, float> + SGDOpKernel_CPU_float; +REGISTER_OP_CPU_KERNEL(sgd, SGDOpKernel_CPU_float); diff --git a/paddle/operators/sgd_op.cu b/paddle/operators/sgd_op.cu new file mode 100644 index 0000000000..400425db10 --- /dev/null +++ b/paddle/operators/sgd_op.cu @@ -0,0 +1,5 @@ +#include "paddle/operators/sgd_op.h" +#include "paddle/framework/op_registry.h" + +typedef paddle::operators::SGDOpKernel<::paddle::platform::GPUPlace, float> SGDOpKernel_GPU_float; +REGISTER_OP_GPU_KERNEL(sgd, SGDOpKernel_GPU_float); \ No newline at end of file diff --git a/paddle/operators/sgd_op.h b/paddle/operators/sgd_op.h new file mode 100644 index 0000000000..2ee21ef8f9 --- /dev/null +++ b/paddle/operators/sgd_op.h @@ -0,0 +1,39 @@ +/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. */ + +#pragma once +#include "glog/logging.h" +#include "paddle/framework/operator.h" + +namespace paddle { +namespace operators { + +template +class SGDOpKernel : public framework::OpKernel { +public: + void Compute(const framework::KernelContext& ctx) const override { + auto param = ctx.Input("param")->Get(); + auto grad = ctx.Input("grad")->Get(); + auto* param_out = ctx.Output(0)->GetMutable(); + float lr = ctx.op_.GetAttr("learning_rate"); + + param_out->mutable_data(ctx.GetPlace()); + + param_out->flat().device(*(ctx.GetEigenDevice())) = + param.flat() - lr * grad.flat(); + } +}; + +} // namespace operators +} // namespace paddle diff --git a/paddle/operators/sgd_op_test.cc b/paddle/operators/sgd_op_test.cc new file mode 100644 index 0000000000..75137259f5 --- /dev/null +++ b/paddle/operators/sgd_op_test.cc @@ -0,0 +1,22 @@ +/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. */ + +#include +#include +USE_OP(sgd); +TEST(SGDOp, GetOpProto) { + auto& protos = paddle::framework::OpRegistry::protos(); + auto it = protos.find("sgd"); + ASSERT_NE(it, protos.end()); +} diff --git a/paddle/pybind/CMakeLists.txt b/paddle/pybind/CMakeLists.txt index 29fb29c7c1..6354dd211d 100644 --- a/paddle/pybind/CMakeLists.txt +++ b/paddle/pybind/CMakeLists.txt @@ -1,2 +1,2 @@ cc_library(paddle_pybind SHARED SRCS pybind.cc DEPS pybind python - add_op fc_op) + add_op fc_op sgd_op) diff --git a/paddle/pybind/pybind.cc b/paddle/pybind/pybind.cc index 7e84550f77..54707a2859 100644 --- a/paddle/pybind/pybind.cc +++ b/paddle/pybind/pybind.cc @@ -28,6 +28,7 @@ namespace pd = paddle::framework; USE_OP(add_two); USE_OP_WITHOUT_KERNEL(fc); +USE_OP(sgd); PYBIND11_PLUGIN(core) { py::module m("core", "C++ core of Paddle Paddle"); diff --git a/python/paddle/v2/framework/tests/CMakeLists.txt b/python/paddle/v2/framework/tests/CMakeLists.txt index f71009aa85..ec076e40c9 100644 --- a/python/paddle/v2/framework/tests/CMakeLists.txt +++ b/python/paddle/v2/framework/tests/CMakeLists.txt @@ -1,3 +1,3 @@ add_python_test(test_framework test_protobuf.py test_scope.py test_default_scope_funcs.py test_op_creation_methods.py - test_tensor.py test_fc_op.py test_add_two_op.py) + test_tensor.py test_fc_op.py test_add_two_op.py test_sgd_op.py) diff --git a/python/paddle/v2/framework/tests/test_sgd_op.py b/python/paddle/v2/framework/tests/test_sgd_op.py new file mode 100644 index 0000000000..405d73b224 --- /dev/null +++ b/python/paddle/v2/framework/tests/test_sgd_op.py @@ -0,0 +1,18 @@ +import unittest +import numpy +from op_test_util import OpTestMeta + + +class TestSGD(unittest.TestCase): + __metaclass__ = OpTestMeta + + def setUp(self): + self.type = "sgd" + self.param = numpy.random.random((342, 345)).astype("float32") + self.grad = numpy.random.random((342, 345)).astype("float32") + self.learning_rate = 0.1 + self.param_out = self.param - self.learning_rate * self.grad + + +if __name__ == "__main__": + unittest.main() From 5e8a4f16c77333f887656fff21ec2357f8f83790 Mon Sep 17 00:00:00 2001 From: liaogang Date: Wed, 19 Jul 2017 22:33:28 +0800 Subject: [PATCH 35/37] Fix conflcts --- paddle/framework/tensor.h | 1 - 1 file changed, 1 deletion(-) diff --git a/paddle/framework/tensor.h b/paddle/framework/tensor.h index 0f99fc89f8..93c6fad5d3 100644 --- a/paddle/framework/tensor.h +++ b/paddle/framework/tensor.h @@ -19,7 +19,6 @@ limitations under the License. */ #include #include #include "paddle/framework/ddim.h" -#include "paddle/framework/tensor_types.h" #include "paddle/memory/memory.h" #include "paddle/platform/enforce.h" #include "paddle/platform/place.h" From e4984f13e9ddaa035234f0672781b6e324591ed8 Mon Sep 17 00:00:00 2001 From: qiaolongfei Date: Wed, 19 Jul 2017 23:02:27 +0800 Subject: [PATCH 36/37] fix tensor usage in sgd-op --- paddle/operators/sgd_op.h | 7 +++++-- 1 file changed, 5 insertions(+), 2 deletions(-) diff --git a/paddle/operators/sgd_op.h b/paddle/operators/sgd_op.h index 2ee21ef8f9..4b2d214618 100644 --- a/paddle/operators/sgd_op.h +++ b/paddle/operators/sgd_op.h @@ -14,6 +14,7 @@ limitations under the License. */ #pragma once #include "glog/logging.h" +#include "paddle/framework/eigen.h" #include "paddle/framework/operator.h" namespace paddle { @@ -30,8 +31,10 @@ public: param_out->mutable_data(ctx.GetPlace()); - param_out->flat().device(*(ctx.GetEigenDevice())) = - param.flat() - lr * grad.flat(); + framework::EigenVector::Flatten(*param_out) + .device(*(ctx.GetEigenDevice())) = + framework::EigenVector::Flatten(param) - + lr * framework::EigenVector::Flatten(grad); } }; From a98346f4cd1a0468ac2d1d30574607698f7432bc Mon Sep 17 00:00:00 2001 From: Yu Yang Date: Wed, 19 Jul 2017 21:06:07 -0500 Subject: [PATCH 37/37] Add comment to `OpTestMeta` (#2968) --- python/paddle/v2/framework/tests/op_test_util.py | 12 ++++++++++++ 1 file changed, 12 insertions(+) diff --git a/python/paddle/v2/framework/tests/op_test_util.py b/python/paddle/v2/framework/tests/op_test_util.py index 237f9b7eb0..b1fa12cc89 100644 --- a/python/paddle/v2/framework/tests/op_test_util.py +++ b/python/paddle/v2/framework/tests/op_test_util.py @@ -5,6 +5,18 @@ import paddle.v2.framework.create_op_creation_methods as creation class OpTestMeta(type): + """ + Operator Test ClassMeta. + + It injects `test_all` method into user's OperatorTest class, to make Python + unittest module run that method. + + The `test_all` read what value is stored in `self`. It use self's values to + create and run a operator, and check whether that op is OK or not. + + See `test_add_two_op` for example usage. + """ + def __new__(cls, name, bases, attrs): obj = super(OpTestMeta, cls).__new__(cls, name, bases, attrs)