From dbf1d75f57c465696c82c618d593c4470e6d44ea Mon Sep 17 00:00:00 2001 From: hedaoyuan Date: Tue, 26 Dec 2017 15:27:42 +0800 Subject: [PATCH 01/33] Add a GemmConvMobileFunction. --- paddle/function/GemmConvOp.cpp | 152 +++++++++++++++++++++++++++++++++ 1 file changed, 152 insertions(+) diff --git a/paddle/function/GemmConvOp.cpp b/paddle/function/GemmConvOp.cpp index de7b70e271..08eb6a5490 100644 --- a/paddle/function/GemmConvOp.cpp +++ b/paddle/function/GemmConvOp.cpp @@ -134,6 +134,154 @@ public: } }; +/* + * \brief Forward calculation of convolution, optimized for mobile. + */ +template +class GemmConvMobileFunction : public ConvFunctionBase { +public: + void init(const FuncConfig& config) override { + ConvFunctionBase::init(config); + } + + void check(const BufferArgs& inputs, const BufferArgs& outputs) override { + const TensorShape& input = inputs[0].shape(); + const TensorShape& filter = inputs[1].shape(); + const TensorShape& output = outputs[0].shape(); + checkShape(input, filter, output); + } + + void calc(const BufferArgs& inputs, const BufferArgs& outputs) override { + CHECK_EQ(numInputs_, inputs.size()); + CHECK_EQ(numOutputs_, outputs.size()); + check(inputs, outputs); + // TODO(hedaoyuan): Need to define some index macros, + // to avoid using 0 and 1. 
+ const TensorShape& input = inputs[0].shape(); + const TensorShape& filter = inputs[1].shape(); + const TensorShape& output = outputs[0].shape(); + + real beta; + if (outputs[0].getArgType() == ADD_TO) { + beta = 1.0; + } else { + beta = 0.0; + } + + size_t batchSize = input[0]; + size_t inputChannels = input[1]; + size_t inputHeight = input[2]; + size_t inputWidth = input[3]; + size_t filterHeight = getFilterHeight(filter); + size_t filterWidth = getFilterWidth(filter); + size_t outputChannels = output[1]; + size_t outputHeight = output[2]; + size_t outputWidth = output[3]; + + real* inputData = inputs[0].data(); + real* filterData = inputs[1].data(); + real* outputData = outputs[0].data(); + bool needIm2col = isNeedIm2col(filter); + + TensorShape imShape = + TensorShape({inputChannels / groups_, inputHeight, inputWidth}); + + TensorShape colShape; + real* colData = NULL; + + size_t colHeight = inputChannels / groups_ * filterHeight * filterWidth; + size_t colWidth = outputHeight * outputWidth; + // Max col matrix height 256, Max col matrix width 2048 + size_t stepColHeight = std::min(colHeight, (size_t)256); + size_t stepColWidth = std::min(colWidth, (size_t)2048); + + if (needIm2col) { + colShape = TensorShape({inputChannels / groups_, + filterHeight, + filterWidth, + outputHeight, + outputWidth}); + + resizeBuffer(stepColHeight * stepColWidth * sizeof(real)); + colData = reinterpret_cast(memory_->getBuf()); + } + + Im2ColFunctor im2col; + GemmFunctor gemm; + size_t inputOffset = imShape.getElements(); + size_t outputOffset = + (outputChannels / groups_) * outputHeight * outputWidth; + size_t filterOffset = filter.getElements() / groups_; + + int nStride = colWidth; + int kStride = colHeight; + for (size_t i = 0; i < batchSize; i++) { + for (size_t g = 0; g < groups_; g++) { + if (needIm2col) { + real beta_ = beta; + for (size_t colHeightStart = 0; colHeightStart < colHeight; + colHeightStart += stepColHeight) { + for (size_t colWidthStart = 0; colWidthStart < 
colWidth; + colWidthStart += stepColWidth) { + int N = std::min(colWidth - colWidthStart, stepColWidth); + int K = std::min(colHeight - colHeightStart, stepColHeight); + // im2col + im2col(inputData + g * inputOffset, + imShape, + colData, + colShape, + strideH(), + strideW(), + paddingH(), + paddingW(), + colHeightStart, + K, + colWidthStart, + N); + + // gemm + int M = outputChannels / groups_; + gemm(CblasNoTrans, + CblasNoTrans, + M, + N, + K, + 1.0f, + filterData + g * filterOffset + colHeightStart, + kStride, + colData, + N, + beta_, + outputData + g * outputOffset + colWidthStart, + nStride); + } + beta_ = 1.0; + } + } else { + int M = outputChannels / groups_; + int N = outputHeight * outputWidth; + int K = inputChannels / groups_ * filterHeight * filterWidth; + gemm(CblasNoTrans, + CblasNoTrans, + M, + N, + K, + 1.0f, + filterData + g * filterOffset, + K, + inputData + g * inputOffset, + N, + beta, + outputData + g * outputOffset, + N); + } + } + inputData += inputChannels * inputHeight * inputWidth; + outputData += outputChannels * outputHeight * outputWidth; + } + } +}; + /* * \brief Backward input calculation of convolution. */ @@ -348,7 +496,11 @@ public: } }; +#ifdef PADDLE_MOBILE_INFERENCE +REGISTER_TYPED_FUNC(GemmConv, CPU, GemmConvMobileFunction); +#else REGISTER_TYPED_FUNC(GemmConv, CPU, GemmConvFunction); +#endif REGISTER_TYPED_FUNC(GemmConvGradInput, CPU, GemmConvGradInputFunction); REGISTER_TYPED_FUNC(GemmConvGradFilter, CPU, GemmConvGradFilterFunction); #ifdef PADDLE_WITH_CUDA From d775895e939eb9e4ce4378e349a76d56bd4af72d Mon Sep 17 00:00:00 2001 From: hedaoyuan Date: Tue, 26 Dec 2017 15:43:30 +0800 Subject: [PATCH 02/33] Add Im2ColMobileFunctor. 
--- paddle/function/GemmConvOp.cpp | 56 +++++++++++++++++----------------- paddle/function/Im2Col.h | 48 +++++++++++++++++++++++++++++ 2 files changed, 76 insertions(+), 28 deletions(-) diff --git a/paddle/function/GemmConvOp.cpp b/paddle/function/GemmConvOp.cpp index 08eb6a5490..75a5b4fe84 100644 --- a/paddle/function/GemmConvOp.cpp +++ b/paddle/function/GemmConvOp.cpp @@ -206,8 +206,7 @@ public: colData = reinterpret_cast(memory_->getBuf()); } - Im2ColFunctor im2col; - GemmFunctor gemm; + Im2ColMobileFunctor im2col; size_t inputOffset = imShape.getElements(); size_t outputOffset = (outputChannels / groups_) * outputHeight * outputWidth; @@ -241,19 +240,20 @@ public: // gemm int M = outputChannels / groups_; - gemm(CblasNoTrans, - CblasNoTrans, - M, - N, - K, - 1.0f, - filterData + g * filterOffset + colHeightStart, - kStride, - colData, - N, - beta_, - outputData + g * outputOffset + colWidthStart, - nStride); + BlasGemm::compute( + false, + false, + M, + N, + K, + 1.0f, + filterData + g * filterOffset + colHeightStart, + kStride, + colData, + N, + beta_, + outputData + g * outputOffset + colWidthStart, + nStride); } beta_ = 1.0; } @@ -261,19 +261,19 @@ public: int M = outputChannels / groups_; int N = outputHeight * outputWidth; int K = inputChannels / groups_ * filterHeight * filterWidth; - gemm(CblasNoTrans, - CblasNoTrans, - M, - N, - K, - 1.0f, - filterData + g * filterOffset, - K, - inputData + g * inputOffset, - N, - beta, - outputData + g * outputOffset, - N); + BlasGemm::compute(false, + false, + M, + N, + K, + 1.0f, + filterData + g * filterOffset, + K, + inputData + g * inputOffset, + N, + beta, + outputData + g * outputOffset, + N); } } inputData += inputChannels * inputHeight * inputWidth; diff --git a/paddle/function/Im2Col.h b/paddle/function/Im2Col.h index 0c37fc9724..f43ca465a2 100644 --- a/paddle/function/Im2Col.h +++ b/paddle/function/Im2Col.h @@ -98,4 +98,52 @@ public: int dilationWidth = 1); }; +template +class Im2ColMobileFunctor { +public: 
+ void operator()(const T* imData, + const TensorShape& imShape, + T* colData, + const TensorShape& colShape, + int strideHeight, + int strideWidth, + int paddingHeight, + int paddingWidth, + int colHeightStart, + int colHeightSize, + int colWidthStart, + int colWidthSize) { + int inputHeight = imShape[1]; + int inputWidth = imShape[2]; + int filterHeight = colShape[1]; + int filterWidth = colShape[2]; + int outputWidth = colShape[4]; + + for (int colh = 0; colh < colHeightSize; colh++) { + int wOffset = (colHeightStart + colh) % filterWidth; + int hOffset = ((colHeightStart + colh) / filterWidth) % filterHeight; + int c_im = (colHeightStart + colh) / filterWidth / filterHeight; + + for (int colw = 0; colw < colWidthSize; colw++) { + int h = (colWidthStart + colw) / outputWidth; + int w = (colWidthStart + colw) % outputWidth; + + int imRowIdx = h * strideHeight + hOffset; + int imColIdx = w * strideWidth + wOffset; + if ((imRowIdx - paddingHeight) < 0 || + (imRowIdx - paddingHeight) >= inputHeight || + (imColIdx - paddingWidth) < 0 || + (imColIdx - paddingWidth) >= inputWidth) { + colData[colh * colWidthSize + colw] = T(0); + } else { + imRowIdx += c_im * inputHeight - paddingHeight; + imColIdx -= paddingWidth; + colData[colh * colWidthSize + colw] = + imData[imRowIdx * inputWidth + imColIdx]; + } + } + } + } +}; + } // namespace paddle From 19547943bac716d73354fcdb33c6d909b65308b3 Mon Sep 17 00:00:00 2001 From: hedaoyuan Date: Tue, 26 Dec 2017 15:59:11 +0800 Subject: [PATCH 03/33] Add test for Im2ColMobileFunctor. 
--- paddle/function/Im2ColTest.cpp | 80 ++++++++++++++++++++++++++++++++++ 1 file changed, 80 insertions(+) diff --git a/paddle/function/Im2ColTest.cpp b/paddle/function/Im2ColTest.cpp index 1f085538d8..0dc58696f7 100644 --- a/paddle/function/Im2ColTest.cpp +++ b/paddle/function/Im2ColTest.cpp @@ -138,4 +138,84 @@ TEST(Im2ColFunctor, GPU) { TestIm2ColFunctor(); } #endif +template +void TestIm2ColMobileFunctor() { + for (size_t channels : {1, 5, 32}) { + for (size_t inputHeight : {5, 33, 100}) { + for (size_t inputWidth : {5, 32, 96}) { + for (size_t filterHeight : {1, 5}) { + for (size_t filterWidth : {3, 7}) { + for (size_t stride : {1, 2}) { + for (size_t padding : {0, 1}) { + for (size_t dilation : {1 /*, 3*/}) { + size_t filterSizeH = (filterHeight - 1) * dilation + 1; + size_t filterSizeW = (filterWidth - 1) * dilation + 1; + if (inputHeight + 2 * padding < filterSizeH || + inputWidth + 2 * padding < filterSizeW) + break; + if (padding >= filterSizeH || padding >= filterSizeW) break; + size_t outputHeight = + (inputHeight - filterSizeH + 2 * padding) / stride + 1; + size_t outputWidth = + (inputWidth - filterSizeW + 2 * padding) / stride + 1; + + TensorShape imShape = + TensorShape({channels, inputHeight, inputWidth}); + TensorShape colShape1 = TensorShape({channels, + filterHeight, + filterWidth, + outputHeight, + outputWidth}); + + size_t height = channels * filterHeight * filterWidth; + size_t width = outputHeight * outputWidth; + VectorPtr input1 = + Vector::create(imShape.getElements(), false); + VectorPtr input2 = + Vector::create(imShape.getElements(), false); + MatrixPtr output1 = + Matrix::create(height, width, false, false); + MatrixPtr output2 = + Matrix::create(height, width, false, false); + input1->uniform(0.001, 1); + input2->copyFrom(*input1); + + Im2ColFunctor im2Col1; + Im2ColMobileFunctor im2Col2; + im2Col1(input1->getData(), + imShape, + output1->getData(), + colShape1, + stride, + stride, + padding, + padding, + dilation, + dilation); + 
im2Col2(input2->getData(), + imShape, + output2->getData(), + colShape1, + stride, + stride, + padding, + padding, + 0, + height, + 0, + width); + + autotest::TensorCheckEqual(*output1, *output2); + } + } + } + } + } + } + } + } +} + +TEST(Im2ColFunctor, Mobile) { TestIm2ColMobileFunctor(); } + } // namespace paddle From a850dec991d7d6d28f2669a959b3198a7a796ce9 Mon Sep 17 00:00:00 2001 From: hedaoyuan Date: Tue, 26 Dec 2017 16:07:09 +0800 Subject: [PATCH 04/33] Add dilation. --- paddle/function/GemmConvOp.cpp | 2 ++ paddle/function/Im2Col.h | 6 ++++-- paddle/function/Im2ColTest.cpp | 4 +++- 3 files changed, 9 insertions(+), 3 deletions(-) diff --git a/paddle/function/GemmConvOp.cpp b/paddle/function/GemmConvOp.cpp index 75a5b4fe84..acf1415ebf 100644 --- a/paddle/function/GemmConvOp.cpp +++ b/paddle/function/GemmConvOp.cpp @@ -233,6 +233,8 @@ public: strideW(), paddingH(), paddingW(), + dilationH(), + dilationW(), colHeightStart, K, colWidthStart, diff --git a/paddle/function/Im2Col.h b/paddle/function/Im2Col.h index f43ca465a2..1053e4fd23 100644 --- a/paddle/function/Im2Col.h +++ b/paddle/function/Im2Col.h @@ -109,6 +109,8 @@ public: int strideWidth, int paddingHeight, int paddingWidth, + int dilationHeight, + int dilationWidth, int colHeightStart, int colHeightSize, int colWidthStart, @@ -128,8 +130,8 @@ public: int h = (colWidthStart + colw) / outputWidth; int w = (colWidthStart + colw) % outputWidth; - int imRowIdx = h * strideHeight + hOffset; - int imColIdx = w * strideWidth + wOffset; + int imRowIdx = h * strideHeight + hOffset * dilationHeight; + int imColIdx = w * strideWidth + wOffset * dilationWidth; if ((imRowIdx - paddingHeight) < 0 || (imRowIdx - paddingHeight) >= inputHeight || (imColIdx - paddingWidth) < 0 || diff --git a/paddle/function/Im2ColTest.cpp b/paddle/function/Im2ColTest.cpp index 0dc58696f7..c573469168 100644 --- a/paddle/function/Im2ColTest.cpp +++ b/paddle/function/Im2ColTest.cpp @@ -147,7 +147,7 @@ void TestIm2ColMobileFunctor() { for 
(size_t filterWidth : {3, 7}) { for (size_t stride : {1, 2}) { for (size_t padding : {0, 1}) { - for (size_t dilation : {1 /*, 3*/}) { + for (size_t dilation : {1, 3}) { size_t filterSizeH = (filterHeight - 1) * dilation + 1; size_t filterSizeW = (filterWidth - 1) * dilation + 1; if (inputHeight + 2 * padding < filterSizeH || @@ -200,6 +200,8 @@ void TestIm2ColMobileFunctor() { stride, padding, padding, + dilation, + dilation, 0, height, 0, From f453b7137f8ed5a10ff47901401a796338d6e504 Mon Sep 17 00:00:00 2001 From: hedaoyuan Date: Tue, 26 Dec 2017 16:10:15 +0800 Subject: [PATCH 05/33] Refine code. --- paddle/function/GemmConvOp.cpp | 9 ++++----- 1 file changed, 4 insertions(+), 5 deletions(-) diff --git a/paddle/function/GemmConvOp.cpp b/paddle/function/GemmConvOp.cpp index acf1415ebf..25cc3df667 100644 --- a/paddle/function/GemmConvOp.cpp +++ b/paddle/function/GemmConvOp.cpp @@ -126,14 +126,11 @@ public: inputData += inputChannels * inputHeight * inputWidth; outputData += outputChannels * outputHeight * outputWidth; } -#ifdef PADDLE_MOBILE_INFERENCE - if (Device == DEVICE_TYPE_CPU) { - memory_.reset(); - } -#endif } }; +#ifdef PADDLE_MOBILE_INFERENCE + /* * \brief Forward calculation of convolution, optimized for mobile. */ @@ -284,6 +281,8 @@ public: } }; +#endif + /* * \brief Backward input calculation of convolution. 
*/ From e5777d062bd916b44d12ee5b4e28c8cbef32524d Mon Sep 17 00:00:00 2001 From: typhoonzero Date: Wed, 27 Dec 2017 10:19:41 +0800 Subject: [PATCH 06/33] fix build link rt --- paddle/pybind/CMakeLists.txt | 1 + 1 file changed, 1 insertion(+) diff --git a/paddle/pybind/CMakeLists.txt b/paddle/pybind/CMakeLists.txt index 6afed7eec7..ced75cbfd8 100644 --- a/paddle/pybind/CMakeLists.txt +++ b/paddle/pybind/CMakeLists.txt @@ -3,6 +3,7 @@ if(WITH_PYTHON) SRCS pybind.cc exception.cc protobuf.cc const_value.cc DEPS pybind python backward proto_desc paddle_memory executor prune init ${GLOB_OP_LIB}) + target_link_libraries(paddle_pybind rt) endif(WITH_PYTHON) if(WITH_DOC) From fd2bf55016e6de50bbc436476050f1c442cb654c Mon Sep 17 00:00:00 2001 From: Yang Yu Date: Wed, 27 Dec 2017 10:29:29 +0800 Subject: [PATCH 07/33] Rename API of DeviceContext Make them as usual names. --- paddle/framework/init.cc | 2 +- paddle/framework/operator.cc | 4 +-- paddle/operators/array_operator.h | 4 +-- paddle/operators/array_to_lod_tensor_op.cc | 5 ++-- paddle/operators/assign_op.cc | 4 +-- paddle/operators/cond_op.cc | 4 +-- paddle/operators/feed_op.cc | 4 +-- paddle/operators/fetch_op.cc | 4 +-- paddle/operators/fill_constant_op.cc | 4 +-- paddle/operators/fill_op.cc | 5 ++-- paddle/operators/load_op.cc | 4 +-- paddle/operators/lod_tensor_to_array_op.cc | 5 ++-- paddle/operators/merge_lod_tensor_op.cc | 4 +-- paddle/operators/recurrent_op.cc | 9 +++--- .../reorder_lod_tensor_by_rank_op.cc | 4 +-- paddle/operators/save_op.cc | 4 +-- paddle/operators/shrink_rnn_memory_op.cc | 4 +-- paddle/operators/split_lod_tensor_op.cc | 4 +-- .../operators/tensor_array_read_write_op.cc | 10 ++++--- paddle/platform/device_context.cc | 20 +------------ paddle/platform/device_context.h | 12 ++------ paddle/platform/device_context_test.cu | 29 +++++-------------- paddle/platform/nccl_test.cu | 2 +- 23 files changed, 59 insertions(+), 92 deletions(-) diff --git a/paddle/framework/init.cc b/paddle/framework/init.cc 
index d6601090d5..682cff168d 100644 --- a/paddle/framework/init.cc +++ b/paddle/framework/init.cc @@ -71,7 +71,7 @@ bool InitDevices(const std::vector &devices) { places.emplace_back(platform::CPUPlace()); LOG(WARNING) << "Not specified CPU device, create CPU by Default."; } - platform::DeviceContextPool::Create(places); + platform::DeviceContextPool::Init(places); return true; } diff --git a/paddle/framework/operator.cc b/paddle/framework/operator.cc index 66840a2e03..307730de2e 100644 --- a/paddle/framework/operator.cc +++ b/paddle/framework/operator.cc @@ -387,8 +387,8 @@ void OperatorWithKernel::Run(const Scope& scope, const platform::Place& place) const { RuntimeInferShapeContext infer_shape_ctx(*this, scope); this->InferShape(&infer_shape_ctx); - platform::DeviceContextPool& pool = platform::DeviceContextPool::Get(); - auto dev_ctx = pool.Borrow(place); + platform::DeviceContextPool& pool = platform::DeviceContextPool::Instance(); + auto dev_ctx = pool.Get(place); // check if op[type] has kernel registered. 
auto& all_op_kernels = AllOpKernels(); diff --git a/paddle/operators/array_operator.h b/paddle/operators/array_operator.h index 060ffac827..e0eef5d9f9 100644 --- a/paddle/operators/array_operator.h +++ b/paddle/operators/array_operator.h @@ -35,8 +35,8 @@ class ArrayOp : public framework::OperatorBase { PADDLE_ENFORCE_EQ(i_tensor.numel(), 1); // get device context from pool - platform::DeviceContextPool &pool = platform::DeviceContextPool::Get(); - auto &dev_ctx = *pool.Borrow(place); + platform::DeviceContextPool &pool = platform::DeviceContextPool::Instance(); + auto &dev_ctx = *pool.Get(place); size_t offset; if (platform::is_gpu_place(i_tensor.place())) { diff --git a/paddle/operators/array_to_lod_tensor_op.cc b/paddle/operators/array_to_lod_tensor_op.cc index 0aa04c268b..49366fee8d 100644 --- a/paddle/operators/array_to_lod_tensor_op.cc +++ b/paddle/operators/array_to_lod_tensor_op.cc @@ -106,8 +106,9 @@ class ArrayToLoDTensorOp : public framework::OperatorBase { } auto slice = out->Slice(out_offset, out_offset + len); - platform::DeviceContextPool &pool = platform::DeviceContextPool::Get(); - auto &dev_ctx = *pool.Borrow(place); + platform::DeviceContextPool &pool = + platform::DeviceContextPool::Instance(); + auto &dev_ctx = *pool.Get(place); framework::CopyFrom(x[x_idx].Slice(start_offset, end_offset), place, dev_ctx, &slice); diff --git a/paddle/operators/assign_op.cc b/paddle/operators/assign_op.cc index 0560040509..7d77be3be1 100644 --- a/paddle/operators/assign_op.cc +++ b/paddle/operators/assign_op.cc @@ -82,8 +82,8 @@ class AssignOp : public framework::OperatorBase { out != nullptr, "The Output(Out) should not be null if the Input(X) is set."); - platform::DeviceContextPool &pool = platform::DeviceContextPool::Get(); - auto &dev_ctx = *pool.Borrow(place); + platform::DeviceContextPool &pool = platform::DeviceContextPool::Instance(); + auto &dev_ctx = *pool.Get(place); framework::VisitVarType(*x, AssignFunctor(out, dev_ctx)); } diff --git 
a/paddle/operators/cond_op.cc b/paddle/operators/cond_op.cc index 455fbd8ca3..e333002bfd 100644 --- a/paddle/operators/cond_op.cc +++ b/paddle/operators/cond_op.cc @@ -195,8 +195,8 @@ void CondOp::MergeDataFromSubnet(const framework::Scope& scope, void CondOp::Run(const Scope& scope, const platform::Place& place) const { // get device context from pool - platform::DeviceContextPool& pool = platform::DeviceContextPool::Get(); - auto& dev_ctx = *pool.Borrow(place); + platform::DeviceContextPool& pool = platform::DeviceContextPool::Instance(); + auto& dev_ctx = *pool.Get(place); PrepareDataForSubnet(scope, dev_ctx); std::vector& sub_scopes = GetSubScopes(scope); diff --git a/paddle/operators/feed_op.cc b/paddle/operators/feed_op.cc index cecbb7226a..48da52c3b6 100644 --- a/paddle/operators/feed_op.cc +++ b/paddle/operators/feed_op.cc @@ -49,8 +49,8 @@ class FeedOp : public framework::OperatorBase { auto *out_item = out_var->GetMutable(); // get device context from pool - platform::DeviceContextPool &pool = platform::DeviceContextPool::Get(); - auto &dev_ctx = *pool.Borrow(place); + platform::DeviceContextPool &pool = platform::DeviceContextPool::Instance(); + auto &dev_ctx = *pool.Get(place); framework::CopyFrom(feed_item, place, dev_ctx, out_item); out_item->set_lod(feed_item.lod()); diff --git a/paddle/operators/fetch_op.cc b/paddle/operators/fetch_op.cc index fa20a06540..387d1e0a74 100644 --- a/paddle/operators/fetch_op.cc +++ b/paddle/operators/fetch_op.cc @@ -52,8 +52,8 @@ class FetchOp : public framework::OperatorBase { // FIXME(yuyang18): Should we assume the fetch operator always generate // CPU outputs? 
- platform::DeviceContextPool &pool = platform::DeviceContextPool::Get(); - auto &dev_ctx = *pool.Borrow(place); + platform::DeviceContextPool &pool = platform::DeviceContextPool::Instance(); + auto &dev_ctx = *pool.Get(place); CopyFrom(src_item, platform::CPUPlace(), dev_ctx, &dst_item); dev_ctx.Wait(); diff --git a/paddle/operators/fill_constant_op.cc b/paddle/operators/fill_constant_op.cc index fe0706c4a9..dcd43a30c8 100644 --- a/paddle/operators/fill_constant_op.cc +++ b/paddle/operators/fill_constant_op.cc @@ -49,8 +49,8 @@ class FillConstantOp : public framework::OperatorBase { out.mutable_data(dev_place, framework::ToTypeIndex(data_type)); } - platform::DeviceContextPool &pool = platform::DeviceContextPool::Get(); - auto &dev_ctx = *pool.Borrow(dev_place); + platform::DeviceContextPool &pool = platform::DeviceContextPool::Instance(); + auto &dev_ctx = *pool.Get(dev_place); math::set_constant(dev_ctx, &out, value); } }; diff --git a/paddle/operators/fill_op.cc b/paddle/operators/fill_op.cc index 57b4ec6938..084ba1db62 100644 --- a/paddle/operators/fill_op.cc +++ b/paddle/operators/fill_op.cc @@ -69,8 +69,9 @@ class FillOp : public framework::OperatorBase { if (!force_cpu && platform::is_gpu_place(place)) { // Copy tensor to out - platform::DeviceContextPool &pool = platform::DeviceContextPool::Get(); - auto &dev_ctx = *pool.Borrow(place); + platform::DeviceContextPool &pool = + platform::DeviceContextPool::Instance(); + auto &dev_ctx = *pool.Get(place); framework::CopyFrom(tensor, place, dev_ctx, &out); } } diff --git a/paddle/operators/load_op.cc b/paddle/operators/load_op.cc index 5425375c1f..65f021d919 100644 --- a/paddle/operators/load_op.cc +++ b/paddle/operators/load_op.cc @@ -40,8 +40,8 @@ class LoadOp : public framework::OperatorBase { auto *tensor = out_var->GetMutable(); framework::DeserializeFromStream(fin, tensor); - platform::DeviceContextPool &pool = platform::DeviceContextPool::Get(); - auto &dev_ctx = *pool.Borrow(place); + 
platform::DeviceContextPool &pool = platform::DeviceContextPool::Instance(); + auto &dev_ctx = *pool.Get(place); if (platform::is_gpu_place(place)) { // copy CPU to GPU diff --git a/paddle/operators/lod_tensor_to_array_op.cc b/paddle/operators/lod_tensor_to_array_op.cc index ed99915bb7..8d164b4abc 100644 --- a/paddle/operators/lod_tensor_to_array_op.cc +++ b/paddle/operators/lod_tensor_to_array_op.cc @@ -88,8 +88,9 @@ class LoDTensorToArrayOp : public framework::OperatorBase { auto slice = out[i].Slice(static_cast(offset), static_cast(offset + len)); - platform::DeviceContextPool &pool = platform::DeviceContextPool::Get(); - auto &dev_ctx = *pool.Borrow(place); + platform::DeviceContextPool &pool = + platform::DeviceContextPool::Instance(); + auto &dev_ctx = *pool.Get(place); framework::CopyFrom(x.Slice(static_cast(each_range.begin), static_cast(each_range.end)), diff --git a/paddle/operators/merge_lod_tensor_op.cc b/paddle/operators/merge_lod_tensor_op.cc index 2287f34791..3f999e404f 100644 --- a/paddle/operators/merge_lod_tensor_op.cc +++ b/paddle/operators/merge_lod_tensor_op.cc @@ -30,8 +30,8 @@ class MergeLoDTensorOp : public framework::OperatorBase { void Run(const framework::Scope &scope, const platform::Place &dev_place) const override { // get device context from pool - platform::DeviceContextPool &pool = platform::DeviceContextPool::Get(); - auto &dev_ctx = *pool.Borrow(dev_place); + platform::DeviceContextPool &pool = platform::DeviceContextPool::Instance(); + auto &dev_ctx = *pool.Get(dev_place); auto &x = scope.FindVar(Input("X"))->Get(); auto &mask = scope.FindVar(Input("Mask"))->Get(); diff --git a/paddle/operators/recurrent_op.cc b/paddle/operators/recurrent_op.cc index 71769e67c7..056fa46949 100644 --- a/paddle/operators/recurrent_op.cc +++ b/paddle/operators/recurrent_op.cc @@ -272,8 +272,9 @@ class RecurrentOp : public RecurrentBase { false /*create_local_scope*/); // get device context from pool - platform::DeviceContextPool &pool = 
platform::DeviceContextPool::Get(); - auto &dev_ctx = *pool.Borrow(place); + platform::DeviceContextPool &pool = + platform::DeviceContextPool::Instance(); + auto &dev_ctx = *pool.Get(place); // Copy inside::output -> outside::output // outside::output[seq_offset: seq_offset + 1] = inside::output @@ -326,8 +327,8 @@ class RecurrentGradOp : public RecurrentBase { auto *program = block->Program(); // get device context from pool - platform::DeviceContextPool &pool = platform::DeviceContextPool::Get(); - auto &dev_ctx = *pool.Borrow(place); + platform::DeviceContextPool &pool = platform::DeviceContextPool::Instance(); + auto &dev_ctx = *pool.Get(place); for (size_t step_id = 0; step_id < seq_len; ++step_id) { size_t seq_offset = reverse ? step_id : seq_len - step_id - 1; diff --git a/paddle/operators/reorder_lod_tensor_by_rank_op.cc b/paddle/operators/reorder_lod_tensor_by_rank_op.cc index 1063388e25..8d652ff806 100644 --- a/paddle/operators/reorder_lod_tensor_by_rank_op.cc +++ b/paddle/operators/reorder_lod_tensor_by_rank_op.cc @@ -131,8 +131,8 @@ class ReorderLoDTensorByRankTableBase : public framework::OperatorBase { auto x_sliced = x.Slice(x_offset, x_offset + len); auto out_sliced = out->Slice(out_offset, out_offset + len); - platform::DeviceContextPool &pool = platform::DeviceContextPool::Get(); - auto &dev_ctx = *pool.Borrow(place); + platform::DeviceContextPool &pool = platform::DeviceContextPool::Instance(); + auto &dev_ctx = *pool.Get(place); framework::CopyFrom(x_sliced, out_sliced.place(), dev_ctx, &out_sliced); out_offset += len; return out_offset; diff --git a/paddle/operators/save_op.cc b/paddle/operators/save_op.cc index d045a8b5b8..4b1cbe8883 100644 --- a/paddle/operators/save_op.cc +++ b/paddle/operators/save_op.cc @@ -91,8 +91,8 @@ class SaveOp : public framework::OperatorBase { auto &tensor = var->Get(); // get device context from pool - platform::DeviceContextPool &pool = platform::DeviceContextPool::Get(); - auto &dev_ctx = *pool.Borrow(place); + 
platform::DeviceContextPool &pool = platform::DeviceContextPool::Instance(); + auto &dev_ctx = *pool.Get(place); framework::SerializeToStream(fout, tensor, dev_ctx); } diff --git a/paddle/operators/shrink_rnn_memory_op.cc b/paddle/operators/shrink_rnn_memory_op.cc index e8a4773547..e5ef0740b6 100644 --- a/paddle/operators/shrink_rnn_memory_op.cc +++ b/paddle/operators/shrink_rnn_memory_op.cc @@ -106,8 +106,8 @@ class ShrinkRNNMemoryGradOp : public ArrayOp { dx_tensor.mutable_data(x_tensor.place(), x_tensor.type()); // get device context from pool - platform::DeviceContextPool &pool = platform::DeviceContextPool::Get(); - auto &dev_ctx = *pool.Borrow(place); + platform::DeviceContextPool &pool = platform::DeviceContextPool::Instance(); + auto &dev_ctx = *pool.Get(place); if (dout_var == nullptr) { // dx_tensor fill zero math::set_constant(dev_ctx, &dx_tensor, 0.0f); diff --git a/paddle/operators/split_lod_tensor_op.cc b/paddle/operators/split_lod_tensor_op.cc index 89826ca6ee..2d8787d740 100644 --- a/paddle/operators/split_lod_tensor_op.cc +++ b/paddle/operators/split_lod_tensor_op.cc @@ -45,8 +45,8 @@ class SplitLoDTensorOp : public framework::OperatorBase { auto &x_lod = x.lod(); auto &mask_dim = mask.dims(); - platform::DeviceContextPool &pool = platform::DeviceContextPool::Get(); - auto &dev_ctx = *pool.Borrow(dev_place); + platform::DeviceContextPool &pool = platform::DeviceContextPool::Instance(); + auto &dev_ctx = *pool.Get(dev_place); std::unique_ptr cpu_mask{new framework::LoDTensor()}; if (platform::is_cpu_place(mask.place())) { diff --git a/paddle/operators/tensor_array_read_write_op.cc b/paddle/operators/tensor_array_read_write_op.cc index 9529aab573..53e38ec703 100644 --- a/paddle/operators/tensor_array_read_write_op.cc +++ b/paddle/operators/tensor_array_read_write_op.cc @@ -40,8 +40,9 @@ class WriteToArrayOp : public ArrayOp { if (x_tensor.memory_size() > 0) { auto *out_tensor = &out->at(offset); - platform::DeviceContextPool &pool = 
platform::DeviceContextPool::Get(); - auto &dev_ctx = *pool.Borrow(place); + platform::DeviceContextPool &pool = + platform::DeviceContextPool::Instance(); + auto &dev_ctx = *pool.Get(place); CopyFrom(x_tensor, place, dev_ctx, out_tensor); out_tensor->set_lod(x_tensor.lod()); @@ -132,8 +133,9 @@ class ReadFromArrayOp : public ArrayOp { auto *out_tensor = out->GetMutable(); size_t offset = GetOffset(scope, place); if (offset < x_array.size()) { - platform::DeviceContextPool &pool = platform::DeviceContextPool::Get(); - auto &dev_ctx = *pool.Borrow(place); + platform::DeviceContextPool &pool = + platform::DeviceContextPool::Instance(); + auto &dev_ctx = *pool.Get(place); framework::CopyFrom(x_array[offset], place, dev_ctx, out_tensor); out_tensor->set_lod(x_array[offset].lod()); } else { diff --git a/paddle/platform/device_context.cc b/paddle/platform/device_context.cc index e450ef32a4..ea07f2e002 100644 --- a/paddle/platform/device_context.cc +++ b/paddle/platform/device_context.cc @@ -17,7 +17,7 @@ namespace platform { DeviceContextPool* DeviceContextPool::pool = nullptr; -const platform::DeviceContext* DeviceContextPool::Borrow( +const platform::DeviceContext* DeviceContextPool::Get( const platform::Place& place) { auto it = device_contexts_.find(place); if (it == device_contexts_.end()) { @@ -28,24 +28,6 @@ const platform::DeviceContext* DeviceContextPool::Borrow( return it->second; } -std::vector DeviceContextPool::Borrow( - const std::vector& places) { - PADDLE_ENFORCE_GT(places.size(), 0); - PADDLE_ENFORCE_LE(places.size(), device_contexts_.size()); - std::vector borrowed_contexts; - for (auto& place : places) { - auto it = device_contexts_.find(place); - if (it != device_contexts_.end()) { - borrowed_contexts.emplace_back(it->second); - } else { - PADDLE_THROW( - "'Place' is not supported, Please re-compile with WITH_GPU " - "option"); - } - } - return borrowed_contexts; -} - DeviceContextPool::DeviceContextPool( const std::vector& places) { 
PADDLE_ENFORCE_GT(places.size(), 0); diff --git a/paddle/platform/device_context.h b/paddle/platform/device_context.h index 8ba12e1657..dfef2c16d8 100644 --- a/paddle/platform/device_context.h +++ b/paddle/platform/device_context.h @@ -109,13 +109,13 @@ class DeviceContextPool { public: explicit DeviceContextPool(const std::vector& places); - static DeviceContextPool& Get() { + static DeviceContextPool& Instance() { PADDLE_ENFORCE_NOT_NULL(pool, "Need to Create DeviceContextPool first!"); return *pool; } /*! \brief Create should only called by Init function */ - static DeviceContextPool& Create(const std::vector& places) { + static DeviceContextPool& Init(const std::vector& places) { if (pool == nullptr) { pool = new DeviceContextPool(places); } @@ -123,13 +123,7 @@ class DeviceContextPool { } /*! \brief Return handle of single device context. */ - const platform::DeviceContext* Borrow(const platform::Place& place); - - /*! \brief Return handle of multi-device context. */ - std::vector Borrow( - const std::vector& places); - - ~DeviceContextPool() {} + const platform::DeviceContext* Get(const platform::Place& place); private: static DeviceContextPool* pool; diff --git a/paddle/platform/device_context_test.cu b/paddle/platform/device_context_test.cu index 91011bf71c..ca10cf3463 100644 --- a/paddle/platform/device_context_test.cu +++ b/paddle/platform/device_context_test.cu @@ -71,35 +71,20 @@ TEST(Device, DeviceContextPool) { using paddle::platform::CPUPlace; using paddle::platform::CUDAPlace; - DeviceContextPool& pool = DeviceContextPool::Get(); - auto cpu_dev_ctx1 = pool.Borrow(CPUPlace()); - auto cpu_dev_ctx2 = pool.Borrow(CPUPlace()); - EXPECT_TRUE(cpu_dev_ctx2 == cpu_dev_ctx1); + DeviceContextPool& pool = DeviceContextPool::Instance(); + auto cpu_dev_ctx1 = pool.Get(CPUPlace()); + auto cpu_dev_ctx2 = pool.Get(CPUPlace()); + ASSERT_EQ(cpu_dev_ctx2, cpu_dev_ctx1); std::vector gpu_places; int count = paddle::platform::GetCUDADeviceCount(); for (int i = 0; i < 
count; ++i) { - gpu_places.emplace_back(CUDAPlace(i)); - } - auto dev_ctxs = pool.Borrow(gpu_places); - for (size_t i = 0; i < dev_ctxs.size(); ++i) { - auto* dev_ctx = static_cast(dev_ctxs[i]); - - // check same as CUDAPlace(i) - CUDAPlace place = boost::get(dev_ctx->GetPlace()); - EXPECT_EQ(place.GetDeviceId(), static_cast(i)); + auto dev_ctx = pool.Get(CUDAPlace(i)); + ASSERT_NE(dev_ctx, nullptr); } } int main(int argc, char** argv) { - int dev_count = paddle::platform::GetCUDADeviceCount(); - if (dev_count <= 1) { - LOG(WARNING) << "Cannot test multi-gpu DeviceContextPool, because the CUDA " - "device count is " - << dev_count; - return 0; - } - std::vector places; places.emplace_back(paddle::platform::CPUPlace()); @@ -109,7 +94,7 @@ int main(int argc, char** argv) { } VLOG(0) << " DeviceCount " << count; - paddle::platform::DeviceContextPool::Create(places); + paddle::platform::DeviceContextPool::Init(places); testing::InitGoogleTest(&argc, argv); return RUN_ALL_TESTS(); diff --git a/paddle/platform/nccl_test.cu b/paddle/platform/nccl_test.cu index 8f815863a7..ef6d845874 100644 --- a/paddle/platform/nccl_test.cu +++ b/paddle/platform/nccl_test.cu @@ -144,7 +144,7 @@ int main(int argc, char** argv) { } VLOG(0) << " DeviceCount " << count; - paddle::platform::DeviceContextPool::Create(places); + paddle::platform::DeviceContextPool::Init(places); testing::InitGoogleTest(&argc, argv); return RUN_ALL_TESTS(); From a5e1cf5a2eeec59740f5ff5c60dc104b2aa9b520 Mon Sep 17 00:00:00 2001 From: Yang Yu Date: Wed, 27 Dec 2017 10:29:29 +0800 Subject: [PATCH 08/33] Rename API of DeviceContext Make them as usual names. 
--- paddle/framework/init.cc | 2 +- paddle/framework/operator.cc | 4 +-- paddle/operators/array_operator.h | 4 +-- paddle/operators/array_to_lod_tensor_op.cc | 5 ++-- paddle/operators/assign_op.cc | 4 +-- paddle/operators/cond_op.cc | 4 +-- paddle/operators/feed_op.cc | 4 +-- paddle/operators/fetch_op.cc | 4 +-- paddle/operators/fill_constant_op.cc | 4 +-- paddle/operators/fill_op.cc | 5 ++-- paddle/operators/load_op.cc | 4 +-- paddle/operators/lod_tensor_to_array_op.cc | 5 ++-- paddle/operators/merge_lod_tensor_op.cc | 4 +-- paddle/operators/recurrent_op.cc | 9 +++--- .../reorder_lod_tensor_by_rank_op.cc | 4 +-- paddle/operators/save_op.cc | 4 +-- paddle/operators/shrink_rnn_memory_op.cc | 4 +-- paddle/operators/split_lod_tensor_op.cc | 4 +-- .../operators/tensor_array_read_write_op.cc | 10 ++++--- paddle/platform/device_context.cc | 20 +------------ paddle/platform/device_context.h | 12 ++------ paddle/platform/device_context_test.cu | 29 +++++-------------- paddle/platform/nccl_test.cu | 2 +- 23 files changed, 59 insertions(+), 92 deletions(-) diff --git a/paddle/framework/init.cc b/paddle/framework/init.cc index d6601090d5..682cff168d 100644 --- a/paddle/framework/init.cc +++ b/paddle/framework/init.cc @@ -71,7 +71,7 @@ bool InitDevices(const std::vector &devices) { places.emplace_back(platform::CPUPlace()); LOG(WARNING) << "Not specified CPU device, create CPU by Default."; } - platform::DeviceContextPool::Create(places); + platform::DeviceContextPool::Init(places); return true; } diff --git a/paddle/framework/operator.cc b/paddle/framework/operator.cc index 886f73e7b8..e8d4be8675 100644 --- a/paddle/framework/operator.cc +++ b/paddle/framework/operator.cc @@ -388,8 +388,8 @@ void OperatorWithKernel::Run(const Scope& scope, const platform::Place& place) const { RuntimeInferShapeContext infer_shape_ctx(*this, scope); this->InferShape(&infer_shape_ctx); - platform::DeviceContextPool& pool = platform::DeviceContextPool::Get(); - auto dev_ctx = pool.Borrow(place); 
+ platform::DeviceContextPool& pool = platform::DeviceContextPool::Instance(); + auto dev_ctx = pool.Get(place); // check if op[type] has kernel registered. auto& all_op_kernels = AllOpKernels(); diff --git a/paddle/operators/array_operator.h b/paddle/operators/array_operator.h index 060ffac827..e0eef5d9f9 100644 --- a/paddle/operators/array_operator.h +++ b/paddle/operators/array_operator.h @@ -35,8 +35,8 @@ class ArrayOp : public framework::OperatorBase { PADDLE_ENFORCE_EQ(i_tensor.numel(), 1); // get device context from pool - platform::DeviceContextPool &pool = platform::DeviceContextPool::Get(); - auto &dev_ctx = *pool.Borrow(place); + platform::DeviceContextPool &pool = platform::DeviceContextPool::Instance(); + auto &dev_ctx = *pool.Get(place); size_t offset; if (platform::is_gpu_place(i_tensor.place())) { diff --git a/paddle/operators/array_to_lod_tensor_op.cc b/paddle/operators/array_to_lod_tensor_op.cc index 0aa04c268b..49366fee8d 100644 --- a/paddle/operators/array_to_lod_tensor_op.cc +++ b/paddle/operators/array_to_lod_tensor_op.cc @@ -106,8 +106,9 @@ class ArrayToLoDTensorOp : public framework::OperatorBase { } auto slice = out->Slice(out_offset, out_offset + len); - platform::DeviceContextPool &pool = platform::DeviceContextPool::Get(); - auto &dev_ctx = *pool.Borrow(place); + platform::DeviceContextPool &pool = + platform::DeviceContextPool::Instance(); + auto &dev_ctx = *pool.Get(place); framework::CopyFrom(x[x_idx].Slice(start_offset, end_offset), place, dev_ctx, &slice); diff --git a/paddle/operators/assign_op.cc b/paddle/operators/assign_op.cc index 0560040509..7d77be3be1 100644 --- a/paddle/operators/assign_op.cc +++ b/paddle/operators/assign_op.cc @@ -82,8 +82,8 @@ class AssignOp : public framework::OperatorBase { out != nullptr, "The Output(Out) should not be null if the Input(X) is set."); - platform::DeviceContextPool &pool = platform::DeviceContextPool::Get(); - auto &dev_ctx = *pool.Borrow(place); + platform::DeviceContextPool &pool = 
platform::DeviceContextPool::Instance(); + auto &dev_ctx = *pool.Get(place); framework::VisitVarType(*x, AssignFunctor(out, dev_ctx)); } diff --git a/paddle/operators/cond_op.cc b/paddle/operators/cond_op.cc index 455fbd8ca3..e333002bfd 100644 --- a/paddle/operators/cond_op.cc +++ b/paddle/operators/cond_op.cc @@ -195,8 +195,8 @@ void CondOp::MergeDataFromSubnet(const framework::Scope& scope, void CondOp::Run(const Scope& scope, const platform::Place& place) const { // get device context from pool - platform::DeviceContextPool& pool = platform::DeviceContextPool::Get(); - auto& dev_ctx = *pool.Borrow(place); + platform::DeviceContextPool& pool = platform::DeviceContextPool::Instance(); + auto& dev_ctx = *pool.Get(place); PrepareDataForSubnet(scope, dev_ctx); std::vector& sub_scopes = GetSubScopes(scope); diff --git a/paddle/operators/feed_op.cc b/paddle/operators/feed_op.cc index cecbb7226a..48da52c3b6 100644 --- a/paddle/operators/feed_op.cc +++ b/paddle/operators/feed_op.cc @@ -49,8 +49,8 @@ class FeedOp : public framework::OperatorBase { auto *out_item = out_var->GetMutable(); // get device context from pool - platform::DeviceContextPool &pool = platform::DeviceContextPool::Get(); - auto &dev_ctx = *pool.Borrow(place); + platform::DeviceContextPool &pool = platform::DeviceContextPool::Instance(); + auto &dev_ctx = *pool.Get(place); framework::CopyFrom(feed_item, place, dev_ctx, out_item); out_item->set_lod(feed_item.lod()); diff --git a/paddle/operators/fetch_op.cc b/paddle/operators/fetch_op.cc index fa20a06540..387d1e0a74 100644 --- a/paddle/operators/fetch_op.cc +++ b/paddle/operators/fetch_op.cc @@ -52,8 +52,8 @@ class FetchOp : public framework::OperatorBase { // FIXME(yuyang18): Should we assume the fetch operator always generate // CPU outputs? 
- platform::DeviceContextPool &pool = platform::DeviceContextPool::Get(); - auto &dev_ctx = *pool.Borrow(place); + platform::DeviceContextPool &pool = platform::DeviceContextPool::Instance(); + auto &dev_ctx = *pool.Get(place); CopyFrom(src_item, platform::CPUPlace(), dev_ctx, &dst_item); dev_ctx.Wait(); diff --git a/paddle/operators/fill_constant_op.cc b/paddle/operators/fill_constant_op.cc index fe0706c4a9..dcd43a30c8 100644 --- a/paddle/operators/fill_constant_op.cc +++ b/paddle/operators/fill_constant_op.cc @@ -49,8 +49,8 @@ class FillConstantOp : public framework::OperatorBase { out.mutable_data(dev_place, framework::ToTypeIndex(data_type)); } - platform::DeviceContextPool &pool = platform::DeviceContextPool::Get(); - auto &dev_ctx = *pool.Borrow(dev_place); + platform::DeviceContextPool &pool = platform::DeviceContextPool::Instance(); + auto &dev_ctx = *pool.Get(dev_place); math::set_constant(dev_ctx, &out, value); } }; diff --git a/paddle/operators/fill_op.cc b/paddle/operators/fill_op.cc index 57b4ec6938..084ba1db62 100644 --- a/paddle/operators/fill_op.cc +++ b/paddle/operators/fill_op.cc @@ -69,8 +69,9 @@ class FillOp : public framework::OperatorBase { if (!force_cpu && platform::is_gpu_place(place)) { // Copy tensor to out - platform::DeviceContextPool &pool = platform::DeviceContextPool::Get(); - auto &dev_ctx = *pool.Borrow(place); + platform::DeviceContextPool &pool = + platform::DeviceContextPool::Instance(); + auto &dev_ctx = *pool.Get(place); framework::CopyFrom(tensor, place, dev_ctx, &out); } } diff --git a/paddle/operators/load_op.cc b/paddle/operators/load_op.cc index 5425375c1f..65f021d919 100644 --- a/paddle/operators/load_op.cc +++ b/paddle/operators/load_op.cc @@ -40,8 +40,8 @@ class LoadOp : public framework::OperatorBase { auto *tensor = out_var->GetMutable(); framework::DeserializeFromStream(fin, tensor); - platform::DeviceContextPool &pool = platform::DeviceContextPool::Get(); - auto &dev_ctx = *pool.Borrow(place); + 
platform::DeviceContextPool &pool = platform::DeviceContextPool::Instance(); + auto &dev_ctx = *pool.Get(place); if (platform::is_gpu_place(place)) { // copy CPU to GPU diff --git a/paddle/operators/lod_tensor_to_array_op.cc b/paddle/operators/lod_tensor_to_array_op.cc index ed99915bb7..8d164b4abc 100644 --- a/paddle/operators/lod_tensor_to_array_op.cc +++ b/paddle/operators/lod_tensor_to_array_op.cc @@ -88,8 +88,9 @@ class LoDTensorToArrayOp : public framework::OperatorBase { auto slice = out[i].Slice(static_cast(offset), static_cast(offset + len)); - platform::DeviceContextPool &pool = platform::DeviceContextPool::Get(); - auto &dev_ctx = *pool.Borrow(place); + platform::DeviceContextPool &pool = + platform::DeviceContextPool::Instance(); + auto &dev_ctx = *pool.Get(place); framework::CopyFrom(x.Slice(static_cast(each_range.begin), static_cast(each_range.end)), diff --git a/paddle/operators/merge_lod_tensor_op.cc b/paddle/operators/merge_lod_tensor_op.cc index 2287f34791..3f999e404f 100644 --- a/paddle/operators/merge_lod_tensor_op.cc +++ b/paddle/operators/merge_lod_tensor_op.cc @@ -30,8 +30,8 @@ class MergeLoDTensorOp : public framework::OperatorBase { void Run(const framework::Scope &scope, const platform::Place &dev_place) const override { // get device context from pool - platform::DeviceContextPool &pool = platform::DeviceContextPool::Get(); - auto &dev_ctx = *pool.Borrow(dev_place); + platform::DeviceContextPool &pool = platform::DeviceContextPool::Instance(); + auto &dev_ctx = *pool.Get(dev_place); auto &x = scope.FindVar(Input("X"))->Get(); auto &mask = scope.FindVar(Input("Mask"))->Get(); diff --git a/paddle/operators/recurrent_op.cc b/paddle/operators/recurrent_op.cc index 71769e67c7..056fa46949 100644 --- a/paddle/operators/recurrent_op.cc +++ b/paddle/operators/recurrent_op.cc @@ -272,8 +272,9 @@ class RecurrentOp : public RecurrentBase { false /*create_local_scope*/); // get device context from pool - platform::DeviceContextPool &pool = 
platform::DeviceContextPool::Get(); - auto &dev_ctx = *pool.Borrow(place); + platform::DeviceContextPool &pool = + platform::DeviceContextPool::Instance(); + auto &dev_ctx = *pool.Get(place); // Copy inside::output -> outside::output // outside::output[seq_offset: seq_offset + 1] = inside::output @@ -326,8 +327,8 @@ class RecurrentGradOp : public RecurrentBase { auto *program = block->Program(); // get device context from pool - platform::DeviceContextPool &pool = platform::DeviceContextPool::Get(); - auto &dev_ctx = *pool.Borrow(place); + platform::DeviceContextPool &pool = platform::DeviceContextPool::Instance(); + auto &dev_ctx = *pool.Get(place); for (size_t step_id = 0; step_id < seq_len; ++step_id) { size_t seq_offset = reverse ? step_id : seq_len - step_id - 1; diff --git a/paddle/operators/reorder_lod_tensor_by_rank_op.cc b/paddle/operators/reorder_lod_tensor_by_rank_op.cc index 1063388e25..8d652ff806 100644 --- a/paddle/operators/reorder_lod_tensor_by_rank_op.cc +++ b/paddle/operators/reorder_lod_tensor_by_rank_op.cc @@ -131,8 +131,8 @@ class ReorderLoDTensorByRankTableBase : public framework::OperatorBase { auto x_sliced = x.Slice(x_offset, x_offset + len); auto out_sliced = out->Slice(out_offset, out_offset + len); - platform::DeviceContextPool &pool = platform::DeviceContextPool::Get(); - auto &dev_ctx = *pool.Borrow(place); + platform::DeviceContextPool &pool = platform::DeviceContextPool::Instance(); + auto &dev_ctx = *pool.Get(place); framework::CopyFrom(x_sliced, out_sliced.place(), dev_ctx, &out_sliced); out_offset += len; return out_offset; diff --git a/paddle/operators/save_op.cc b/paddle/operators/save_op.cc index d045a8b5b8..4b1cbe8883 100644 --- a/paddle/operators/save_op.cc +++ b/paddle/operators/save_op.cc @@ -91,8 +91,8 @@ class SaveOp : public framework::OperatorBase { auto &tensor = var->Get(); // get device context from pool - platform::DeviceContextPool &pool = platform::DeviceContextPool::Get(); - auto &dev_ctx = *pool.Borrow(place); + 
platform::DeviceContextPool &pool = platform::DeviceContextPool::Instance(); + auto &dev_ctx = *pool.Get(place); framework::SerializeToStream(fout, tensor, dev_ctx); } diff --git a/paddle/operators/shrink_rnn_memory_op.cc b/paddle/operators/shrink_rnn_memory_op.cc index e8a4773547..e5ef0740b6 100644 --- a/paddle/operators/shrink_rnn_memory_op.cc +++ b/paddle/operators/shrink_rnn_memory_op.cc @@ -106,8 +106,8 @@ class ShrinkRNNMemoryGradOp : public ArrayOp { dx_tensor.mutable_data(x_tensor.place(), x_tensor.type()); // get device context from pool - platform::DeviceContextPool &pool = platform::DeviceContextPool::Get(); - auto &dev_ctx = *pool.Borrow(place); + platform::DeviceContextPool &pool = platform::DeviceContextPool::Instance(); + auto &dev_ctx = *pool.Get(place); if (dout_var == nullptr) { // dx_tensor fill zero math::set_constant(dev_ctx, &dx_tensor, 0.0f); diff --git a/paddle/operators/split_lod_tensor_op.cc b/paddle/operators/split_lod_tensor_op.cc index 89826ca6ee..2d8787d740 100644 --- a/paddle/operators/split_lod_tensor_op.cc +++ b/paddle/operators/split_lod_tensor_op.cc @@ -45,8 +45,8 @@ class SplitLoDTensorOp : public framework::OperatorBase { auto &x_lod = x.lod(); auto &mask_dim = mask.dims(); - platform::DeviceContextPool &pool = platform::DeviceContextPool::Get(); - auto &dev_ctx = *pool.Borrow(dev_place); + platform::DeviceContextPool &pool = platform::DeviceContextPool::Instance(); + auto &dev_ctx = *pool.Get(dev_place); std::unique_ptr cpu_mask{new framework::LoDTensor()}; if (platform::is_cpu_place(mask.place())) { diff --git a/paddle/operators/tensor_array_read_write_op.cc b/paddle/operators/tensor_array_read_write_op.cc index 9529aab573..53e38ec703 100644 --- a/paddle/operators/tensor_array_read_write_op.cc +++ b/paddle/operators/tensor_array_read_write_op.cc @@ -40,8 +40,9 @@ class WriteToArrayOp : public ArrayOp { if (x_tensor.memory_size() > 0) { auto *out_tensor = &out->at(offset); - platform::DeviceContextPool &pool = 
platform::DeviceContextPool::Get(); - auto &dev_ctx = *pool.Borrow(place); + platform::DeviceContextPool &pool = + platform::DeviceContextPool::Instance(); + auto &dev_ctx = *pool.Get(place); CopyFrom(x_tensor, place, dev_ctx, out_tensor); out_tensor->set_lod(x_tensor.lod()); @@ -132,8 +133,9 @@ class ReadFromArrayOp : public ArrayOp { auto *out_tensor = out->GetMutable(); size_t offset = GetOffset(scope, place); if (offset < x_array.size()) { - platform::DeviceContextPool &pool = platform::DeviceContextPool::Get(); - auto &dev_ctx = *pool.Borrow(place); + platform::DeviceContextPool &pool = + platform::DeviceContextPool::Instance(); + auto &dev_ctx = *pool.Get(place); framework::CopyFrom(x_array[offset], place, dev_ctx, out_tensor); out_tensor->set_lod(x_array[offset].lod()); } else { diff --git a/paddle/platform/device_context.cc b/paddle/platform/device_context.cc index e450ef32a4..ea07f2e002 100644 --- a/paddle/platform/device_context.cc +++ b/paddle/platform/device_context.cc @@ -17,7 +17,7 @@ namespace platform { DeviceContextPool* DeviceContextPool::pool = nullptr; -const platform::DeviceContext* DeviceContextPool::Borrow( +const platform::DeviceContext* DeviceContextPool::Get( const platform::Place& place) { auto it = device_contexts_.find(place); if (it == device_contexts_.end()) { @@ -28,24 +28,6 @@ const platform::DeviceContext* DeviceContextPool::Borrow( return it->second; } -std::vector DeviceContextPool::Borrow( - const std::vector& places) { - PADDLE_ENFORCE_GT(places.size(), 0); - PADDLE_ENFORCE_LE(places.size(), device_contexts_.size()); - std::vector borrowed_contexts; - for (auto& place : places) { - auto it = device_contexts_.find(place); - if (it != device_contexts_.end()) { - borrowed_contexts.emplace_back(it->second); - } else { - PADDLE_THROW( - "'Place' is not supported, Please re-compile with WITH_GPU " - "option"); - } - } - return borrowed_contexts; -} - DeviceContextPool::DeviceContextPool( const std::vector& places) { 
PADDLE_ENFORCE_GT(places.size(), 0); diff --git a/paddle/platform/device_context.h b/paddle/platform/device_context.h index 8ba12e1657..dfef2c16d8 100644 --- a/paddle/platform/device_context.h +++ b/paddle/platform/device_context.h @@ -109,13 +109,13 @@ class DeviceContextPool { public: explicit DeviceContextPool(const std::vector& places); - static DeviceContextPool& Get() { + static DeviceContextPool& Instance() { PADDLE_ENFORCE_NOT_NULL(pool, "Need to Create DeviceContextPool first!"); return *pool; } /*! \brief Create should only called by Init function */ - static DeviceContextPool& Create(const std::vector& places) { + static DeviceContextPool& Init(const std::vector& places) { if (pool == nullptr) { pool = new DeviceContextPool(places); } @@ -123,13 +123,7 @@ class DeviceContextPool { } /*! \brief Return handle of single device context. */ - const platform::DeviceContext* Borrow(const platform::Place& place); - - /*! \brief Return handle of multi-device context. */ - std::vector Borrow( - const std::vector& places); - - ~DeviceContextPool() {} + const platform::DeviceContext* Get(const platform::Place& place); private: static DeviceContextPool* pool; diff --git a/paddle/platform/device_context_test.cu b/paddle/platform/device_context_test.cu index 91011bf71c..ca10cf3463 100644 --- a/paddle/platform/device_context_test.cu +++ b/paddle/platform/device_context_test.cu @@ -71,35 +71,20 @@ TEST(Device, DeviceContextPool) { using paddle::platform::CPUPlace; using paddle::platform::CUDAPlace; - DeviceContextPool& pool = DeviceContextPool::Get(); - auto cpu_dev_ctx1 = pool.Borrow(CPUPlace()); - auto cpu_dev_ctx2 = pool.Borrow(CPUPlace()); - EXPECT_TRUE(cpu_dev_ctx2 == cpu_dev_ctx1); + DeviceContextPool& pool = DeviceContextPool::Instance(); + auto cpu_dev_ctx1 = pool.Get(CPUPlace()); + auto cpu_dev_ctx2 = pool.Get(CPUPlace()); + ASSERT_EQ(cpu_dev_ctx2, cpu_dev_ctx1); std::vector gpu_places; int count = paddle::platform::GetCUDADeviceCount(); for (int i = 0; i < 
count; ++i) { - gpu_places.emplace_back(CUDAPlace(i)); - } - auto dev_ctxs = pool.Borrow(gpu_places); - for (size_t i = 0; i < dev_ctxs.size(); ++i) { - auto* dev_ctx = static_cast(dev_ctxs[i]); - - // check same as CUDAPlace(i) - CUDAPlace place = boost::get(dev_ctx->GetPlace()); - EXPECT_EQ(place.GetDeviceId(), static_cast(i)); + auto dev_ctx = pool.Get(CUDAPlace(i)); + ASSERT_NE(dev_ctx, nullptr); } } int main(int argc, char** argv) { - int dev_count = paddle::platform::GetCUDADeviceCount(); - if (dev_count <= 1) { - LOG(WARNING) << "Cannot test multi-gpu DeviceContextPool, because the CUDA " - "device count is " - << dev_count; - return 0; - } - std::vector places; places.emplace_back(paddle::platform::CPUPlace()); @@ -109,7 +94,7 @@ int main(int argc, char** argv) { } VLOG(0) << " DeviceCount " << count; - paddle::platform::DeviceContextPool::Create(places); + paddle::platform::DeviceContextPool::Init(places); testing::InitGoogleTest(&argc, argv); return RUN_ALL_TESTS(); diff --git a/paddle/platform/nccl_test.cu b/paddle/platform/nccl_test.cu index 8f815863a7..ef6d845874 100644 --- a/paddle/platform/nccl_test.cu +++ b/paddle/platform/nccl_test.cu @@ -144,7 +144,7 @@ int main(int argc, char** argv) { } VLOG(0) << " DeviceCount " << count; - paddle::platform::DeviceContextPool::Create(places); + paddle::platform::DeviceContextPool::Init(places); testing::InitGoogleTest(&argc, argv); return RUN_ALL_TESTS(); From 8b877dd7c8ef71520d27cd187f6767fe6be02262 Mon Sep 17 00:00:00 2001 From: Yang Yu Date: Wed, 27 Dec 2017 10:29:29 +0800 Subject: [PATCH 09/33] Rename API of DeviceContext Make them as usual names. 
--- paddle/framework/init.cc | 2 +- paddle/framework/operator.cc | 4 +-- paddle/operators/array_operator.h | 4 +-- paddle/operators/array_to_lod_tensor_op.cc | 5 ++-- paddle/operators/assign_op.cc | 4 +-- paddle/operators/cond_op.cc | 4 +-- paddle/operators/feed_op.cc | 4 +-- paddle/operators/fetch_op.cc | 4 +-- paddle/operators/fill_constant_op.cc | 4 +-- paddle/operators/fill_op.cc | 5 ++-- paddle/operators/load_op.cc | 4 +-- paddle/operators/lod_tensor_to_array_op.cc | 5 ++-- paddle/operators/merge_lod_tensor_op.cc | 4 +-- paddle/operators/recurrent_op.cc | 9 +++--- .../reorder_lod_tensor_by_rank_op.cc | 4 +-- paddle/operators/save_op.cc | 4 +-- paddle/operators/shrink_rnn_memory_op.cc | 4 +-- paddle/operators/split_lod_tensor_op.cc | 4 +-- .../operators/tensor_array_read_write_op.cc | 10 ++++--- paddle/platform/device_context.cc | 20 +------------ paddle/platform/device_context.h | 12 ++------ paddle/platform/device_context_test.cu | 29 +++++-------------- paddle/platform/nccl_test.cu | 2 +- 23 files changed, 59 insertions(+), 92 deletions(-) diff --git a/paddle/framework/init.cc b/paddle/framework/init.cc index d6601090d5..682cff168d 100644 --- a/paddle/framework/init.cc +++ b/paddle/framework/init.cc @@ -71,7 +71,7 @@ bool InitDevices(const std::vector &devices) { places.emplace_back(platform::CPUPlace()); LOG(WARNING) << "Not specified CPU device, create CPU by Default."; } - platform::DeviceContextPool::Create(places); + platform::DeviceContextPool::Init(places); return true; } diff --git a/paddle/framework/operator.cc b/paddle/framework/operator.cc index 886f73e7b8..e8d4be8675 100644 --- a/paddle/framework/operator.cc +++ b/paddle/framework/operator.cc @@ -388,8 +388,8 @@ void OperatorWithKernel::Run(const Scope& scope, const platform::Place& place) const { RuntimeInferShapeContext infer_shape_ctx(*this, scope); this->InferShape(&infer_shape_ctx); - platform::DeviceContextPool& pool = platform::DeviceContextPool::Get(); - auto dev_ctx = pool.Borrow(place); 
+ platform::DeviceContextPool& pool = platform::DeviceContextPool::Instance(); + auto dev_ctx = pool.Get(place); // check if op[type] has kernel registered. auto& all_op_kernels = AllOpKernels(); diff --git a/paddle/operators/array_operator.h b/paddle/operators/array_operator.h index 060ffac827..e0eef5d9f9 100644 --- a/paddle/operators/array_operator.h +++ b/paddle/operators/array_operator.h @@ -35,8 +35,8 @@ class ArrayOp : public framework::OperatorBase { PADDLE_ENFORCE_EQ(i_tensor.numel(), 1); // get device context from pool - platform::DeviceContextPool &pool = platform::DeviceContextPool::Get(); - auto &dev_ctx = *pool.Borrow(place); + platform::DeviceContextPool &pool = platform::DeviceContextPool::Instance(); + auto &dev_ctx = *pool.Get(place); size_t offset; if (platform::is_gpu_place(i_tensor.place())) { diff --git a/paddle/operators/array_to_lod_tensor_op.cc b/paddle/operators/array_to_lod_tensor_op.cc index 0aa04c268b..49366fee8d 100644 --- a/paddle/operators/array_to_lod_tensor_op.cc +++ b/paddle/operators/array_to_lod_tensor_op.cc @@ -106,8 +106,9 @@ class ArrayToLoDTensorOp : public framework::OperatorBase { } auto slice = out->Slice(out_offset, out_offset + len); - platform::DeviceContextPool &pool = platform::DeviceContextPool::Get(); - auto &dev_ctx = *pool.Borrow(place); + platform::DeviceContextPool &pool = + platform::DeviceContextPool::Instance(); + auto &dev_ctx = *pool.Get(place); framework::CopyFrom(x[x_idx].Slice(start_offset, end_offset), place, dev_ctx, &slice); diff --git a/paddle/operators/assign_op.cc b/paddle/operators/assign_op.cc index 0560040509..7d77be3be1 100644 --- a/paddle/operators/assign_op.cc +++ b/paddle/operators/assign_op.cc @@ -82,8 +82,8 @@ class AssignOp : public framework::OperatorBase { out != nullptr, "The Output(Out) should not be null if the Input(X) is set."); - platform::DeviceContextPool &pool = platform::DeviceContextPool::Get(); - auto &dev_ctx = *pool.Borrow(place); + platform::DeviceContextPool &pool = 
platform::DeviceContextPool::Instance(); + auto &dev_ctx = *pool.Get(place); framework::VisitVarType(*x, AssignFunctor(out, dev_ctx)); } diff --git a/paddle/operators/cond_op.cc b/paddle/operators/cond_op.cc index 455fbd8ca3..e333002bfd 100644 --- a/paddle/operators/cond_op.cc +++ b/paddle/operators/cond_op.cc @@ -195,8 +195,8 @@ void CondOp::MergeDataFromSubnet(const framework::Scope& scope, void CondOp::Run(const Scope& scope, const platform::Place& place) const { // get device context from pool - platform::DeviceContextPool& pool = platform::DeviceContextPool::Get(); - auto& dev_ctx = *pool.Borrow(place); + platform::DeviceContextPool& pool = platform::DeviceContextPool::Instance(); + auto& dev_ctx = *pool.Get(place); PrepareDataForSubnet(scope, dev_ctx); std::vector& sub_scopes = GetSubScopes(scope); diff --git a/paddle/operators/feed_op.cc b/paddle/operators/feed_op.cc index cecbb7226a..48da52c3b6 100644 --- a/paddle/operators/feed_op.cc +++ b/paddle/operators/feed_op.cc @@ -49,8 +49,8 @@ class FeedOp : public framework::OperatorBase { auto *out_item = out_var->GetMutable(); // get device context from pool - platform::DeviceContextPool &pool = platform::DeviceContextPool::Get(); - auto &dev_ctx = *pool.Borrow(place); + platform::DeviceContextPool &pool = platform::DeviceContextPool::Instance(); + auto &dev_ctx = *pool.Get(place); framework::CopyFrom(feed_item, place, dev_ctx, out_item); out_item->set_lod(feed_item.lod()); diff --git a/paddle/operators/fetch_op.cc b/paddle/operators/fetch_op.cc index fa20a06540..387d1e0a74 100644 --- a/paddle/operators/fetch_op.cc +++ b/paddle/operators/fetch_op.cc @@ -52,8 +52,8 @@ class FetchOp : public framework::OperatorBase { // FIXME(yuyang18): Should we assume the fetch operator always generate // CPU outputs? 
- platform::DeviceContextPool &pool = platform::DeviceContextPool::Get(); - auto &dev_ctx = *pool.Borrow(place); + platform::DeviceContextPool &pool = platform::DeviceContextPool::Instance(); + auto &dev_ctx = *pool.Get(place); CopyFrom(src_item, platform::CPUPlace(), dev_ctx, &dst_item); dev_ctx.Wait(); diff --git a/paddle/operators/fill_constant_op.cc b/paddle/operators/fill_constant_op.cc index fe0706c4a9..dcd43a30c8 100644 --- a/paddle/operators/fill_constant_op.cc +++ b/paddle/operators/fill_constant_op.cc @@ -49,8 +49,8 @@ class FillConstantOp : public framework::OperatorBase { out.mutable_data(dev_place, framework::ToTypeIndex(data_type)); } - platform::DeviceContextPool &pool = platform::DeviceContextPool::Get(); - auto &dev_ctx = *pool.Borrow(dev_place); + platform::DeviceContextPool &pool = platform::DeviceContextPool::Instance(); + auto &dev_ctx = *pool.Get(dev_place); math::set_constant(dev_ctx, &out, value); } }; diff --git a/paddle/operators/fill_op.cc b/paddle/operators/fill_op.cc index 57b4ec6938..084ba1db62 100644 --- a/paddle/operators/fill_op.cc +++ b/paddle/operators/fill_op.cc @@ -69,8 +69,9 @@ class FillOp : public framework::OperatorBase { if (!force_cpu && platform::is_gpu_place(place)) { // Copy tensor to out - platform::DeviceContextPool &pool = platform::DeviceContextPool::Get(); - auto &dev_ctx = *pool.Borrow(place); + platform::DeviceContextPool &pool = + platform::DeviceContextPool::Instance(); + auto &dev_ctx = *pool.Get(place); framework::CopyFrom(tensor, place, dev_ctx, &out); } } diff --git a/paddle/operators/load_op.cc b/paddle/operators/load_op.cc index 5425375c1f..65f021d919 100644 --- a/paddle/operators/load_op.cc +++ b/paddle/operators/load_op.cc @@ -40,8 +40,8 @@ class LoadOp : public framework::OperatorBase { auto *tensor = out_var->GetMutable(); framework::DeserializeFromStream(fin, tensor); - platform::DeviceContextPool &pool = platform::DeviceContextPool::Get(); - auto &dev_ctx = *pool.Borrow(place); + 
platform::DeviceContextPool &pool = platform::DeviceContextPool::Instance(); + auto &dev_ctx = *pool.Get(place); if (platform::is_gpu_place(place)) { // copy CPU to GPU diff --git a/paddle/operators/lod_tensor_to_array_op.cc b/paddle/operators/lod_tensor_to_array_op.cc index ed99915bb7..8d164b4abc 100644 --- a/paddle/operators/lod_tensor_to_array_op.cc +++ b/paddle/operators/lod_tensor_to_array_op.cc @@ -88,8 +88,9 @@ class LoDTensorToArrayOp : public framework::OperatorBase { auto slice = out[i].Slice(static_cast(offset), static_cast(offset + len)); - platform::DeviceContextPool &pool = platform::DeviceContextPool::Get(); - auto &dev_ctx = *pool.Borrow(place); + platform::DeviceContextPool &pool = + platform::DeviceContextPool::Instance(); + auto &dev_ctx = *pool.Get(place); framework::CopyFrom(x.Slice(static_cast(each_range.begin), static_cast(each_range.end)), diff --git a/paddle/operators/merge_lod_tensor_op.cc b/paddle/operators/merge_lod_tensor_op.cc index 2287f34791..3f999e404f 100644 --- a/paddle/operators/merge_lod_tensor_op.cc +++ b/paddle/operators/merge_lod_tensor_op.cc @@ -30,8 +30,8 @@ class MergeLoDTensorOp : public framework::OperatorBase { void Run(const framework::Scope &scope, const platform::Place &dev_place) const override { // get device context from pool - platform::DeviceContextPool &pool = platform::DeviceContextPool::Get(); - auto &dev_ctx = *pool.Borrow(dev_place); + platform::DeviceContextPool &pool = platform::DeviceContextPool::Instance(); + auto &dev_ctx = *pool.Get(dev_place); auto &x = scope.FindVar(Input("X"))->Get(); auto &mask = scope.FindVar(Input("Mask"))->Get(); diff --git a/paddle/operators/recurrent_op.cc b/paddle/operators/recurrent_op.cc index 71769e67c7..056fa46949 100644 --- a/paddle/operators/recurrent_op.cc +++ b/paddle/operators/recurrent_op.cc @@ -272,8 +272,9 @@ class RecurrentOp : public RecurrentBase { false /*create_local_scope*/); // get device context from pool - platform::DeviceContextPool &pool = 
platform::DeviceContextPool::Get(); - auto &dev_ctx = *pool.Borrow(place); + platform::DeviceContextPool &pool = + platform::DeviceContextPool::Instance(); + auto &dev_ctx = *pool.Get(place); // Copy inside::output -> outside::output // outside::output[seq_offset: seq_offset + 1] = inside::output @@ -326,8 +327,8 @@ class RecurrentGradOp : public RecurrentBase { auto *program = block->Program(); // get device context from pool - platform::DeviceContextPool &pool = platform::DeviceContextPool::Get(); - auto &dev_ctx = *pool.Borrow(place); + platform::DeviceContextPool &pool = platform::DeviceContextPool::Instance(); + auto &dev_ctx = *pool.Get(place); for (size_t step_id = 0; step_id < seq_len; ++step_id) { size_t seq_offset = reverse ? step_id : seq_len - step_id - 1; diff --git a/paddle/operators/reorder_lod_tensor_by_rank_op.cc b/paddle/operators/reorder_lod_tensor_by_rank_op.cc index 1063388e25..8d652ff806 100644 --- a/paddle/operators/reorder_lod_tensor_by_rank_op.cc +++ b/paddle/operators/reorder_lod_tensor_by_rank_op.cc @@ -131,8 +131,8 @@ class ReorderLoDTensorByRankTableBase : public framework::OperatorBase { auto x_sliced = x.Slice(x_offset, x_offset + len); auto out_sliced = out->Slice(out_offset, out_offset + len); - platform::DeviceContextPool &pool = platform::DeviceContextPool::Get(); - auto &dev_ctx = *pool.Borrow(place); + platform::DeviceContextPool &pool = platform::DeviceContextPool::Instance(); + auto &dev_ctx = *pool.Get(place); framework::CopyFrom(x_sliced, out_sliced.place(), dev_ctx, &out_sliced); out_offset += len; return out_offset; diff --git a/paddle/operators/save_op.cc b/paddle/operators/save_op.cc index d045a8b5b8..4b1cbe8883 100644 --- a/paddle/operators/save_op.cc +++ b/paddle/operators/save_op.cc @@ -91,8 +91,8 @@ class SaveOp : public framework::OperatorBase { auto &tensor = var->Get(); // get device context from pool - platform::DeviceContextPool &pool = platform::DeviceContextPool::Get(); - auto &dev_ctx = *pool.Borrow(place); + 
platform::DeviceContextPool &pool = platform::DeviceContextPool::Instance(); + auto &dev_ctx = *pool.Get(place); framework::SerializeToStream(fout, tensor, dev_ctx); } diff --git a/paddle/operators/shrink_rnn_memory_op.cc b/paddle/operators/shrink_rnn_memory_op.cc index e8a4773547..e5ef0740b6 100644 --- a/paddle/operators/shrink_rnn_memory_op.cc +++ b/paddle/operators/shrink_rnn_memory_op.cc @@ -106,8 +106,8 @@ class ShrinkRNNMemoryGradOp : public ArrayOp { dx_tensor.mutable_data(x_tensor.place(), x_tensor.type()); // get device context from pool - platform::DeviceContextPool &pool = platform::DeviceContextPool::Get(); - auto &dev_ctx = *pool.Borrow(place); + platform::DeviceContextPool &pool = platform::DeviceContextPool::Instance(); + auto &dev_ctx = *pool.Get(place); if (dout_var == nullptr) { // dx_tensor fill zero math::set_constant(dev_ctx, &dx_tensor, 0.0f); diff --git a/paddle/operators/split_lod_tensor_op.cc b/paddle/operators/split_lod_tensor_op.cc index 89826ca6ee..2d8787d740 100644 --- a/paddle/operators/split_lod_tensor_op.cc +++ b/paddle/operators/split_lod_tensor_op.cc @@ -45,8 +45,8 @@ class SplitLoDTensorOp : public framework::OperatorBase { auto &x_lod = x.lod(); auto &mask_dim = mask.dims(); - platform::DeviceContextPool &pool = platform::DeviceContextPool::Get(); - auto &dev_ctx = *pool.Borrow(dev_place); + platform::DeviceContextPool &pool = platform::DeviceContextPool::Instance(); + auto &dev_ctx = *pool.Get(dev_place); std::unique_ptr cpu_mask{new framework::LoDTensor()}; if (platform::is_cpu_place(mask.place())) { diff --git a/paddle/operators/tensor_array_read_write_op.cc b/paddle/operators/tensor_array_read_write_op.cc index 9529aab573..53e38ec703 100644 --- a/paddle/operators/tensor_array_read_write_op.cc +++ b/paddle/operators/tensor_array_read_write_op.cc @@ -40,8 +40,9 @@ class WriteToArrayOp : public ArrayOp { if (x_tensor.memory_size() > 0) { auto *out_tensor = &out->at(offset); - platform::DeviceContextPool &pool = 
platform::DeviceContextPool::Get(); - auto &dev_ctx = *pool.Borrow(place); + platform::DeviceContextPool &pool = + platform::DeviceContextPool::Instance(); + auto &dev_ctx = *pool.Get(place); CopyFrom(x_tensor, place, dev_ctx, out_tensor); out_tensor->set_lod(x_tensor.lod()); @@ -132,8 +133,9 @@ class ReadFromArrayOp : public ArrayOp { auto *out_tensor = out->GetMutable(); size_t offset = GetOffset(scope, place); if (offset < x_array.size()) { - platform::DeviceContextPool &pool = platform::DeviceContextPool::Get(); - auto &dev_ctx = *pool.Borrow(place); + platform::DeviceContextPool &pool = + platform::DeviceContextPool::Instance(); + auto &dev_ctx = *pool.Get(place); framework::CopyFrom(x_array[offset], place, dev_ctx, out_tensor); out_tensor->set_lod(x_array[offset].lod()); } else { diff --git a/paddle/platform/device_context.cc b/paddle/platform/device_context.cc index e450ef32a4..ea07f2e002 100644 --- a/paddle/platform/device_context.cc +++ b/paddle/platform/device_context.cc @@ -17,7 +17,7 @@ namespace platform { DeviceContextPool* DeviceContextPool::pool = nullptr; -const platform::DeviceContext* DeviceContextPool::Borrow( +const platform::DeviceContext* DeviceContextPool::Get( const platform::Place& place) { auto it = device_contexts_.find(place); if (it == device_contexts_.end()) { @@ -28,24 +28,6 @@ const platform::DeviceContext* DeviceContextPool::Borrow( return it->second; } -std::vector DeviceContextPool::Borrow( - const std::vector& places) { - PADDLE_ENFORCE_GT(places.size(), 0); - PADDLE_ENFORCE_LE(places.size(), device_contexts_.size()); - std::vector borrowed_contexts; - for (auto& place : places) { - auto it = device_contexts_.find(place); - if (it != device_contexts_.end()) { - borrowed_contexts.emplace_back(it->second); - } else { - PADDLE_THROW( - "'Place' is not supported, Please re-compile with WITH_GPU " - "option"); - } - } - return borrowed_contexts; -} - DeviceContextPool::DeviceContextPool( const std::vector& places) { 
PADDLE_ENFORCE_GT(places.size(), 0); diff --git a/paddle/platform/device_context.h b/paddle/platform/device_context.h index 8ba12e1657..dfef2c16d8 100644 --- a/paddle/platform/device_context.h +++ b/paddle/platform/device_context.h @@ -109,13 +109,13 @@ class DeviceContextPool { public: explicit DeviceContextPool(const std::vector& places); - static DeviceContextPool& Get() { + static DeviceContextPool& Instance() { PADDLE_ENFORCE_NOT_NULL(pool, "Need to Create DeviceContextPool first!"); return *pool; } /*! \brief Create should only called by Init function */ - static DeviceContextPool& Create(const std::vector& places) { + static DeviceContextPool& Init(const std::vector& places) { if (pool == nullptr) { pool = new DeviceContextPool(places); } @@ -123,13 +123,7 @@ class DeviceContextPool { } /*! \brief Return handle of single device context. */ - const platform::DeviceContext* Borrow(const platform::Place& place); - - /*! \brief Return handle of multi-device context. */ - std::vector Borrow( - const std::vector& places); - - ~DeviceContextPool() {} + const platform::DeviceContext* Get(const platform::Place& place); private: static DeviceContextPool* pool; diff --git a/paddle/platform/device_context_test.cu b/paddle/platform/device_context_test.cu index 91011bf71c..ca10cf3463 100644 --- a/paddle/platform/device_context_test.cu +++ b/paddle/platform/device_context_test.cu @@ -71,35 +71,20 @@ TEST(Device, DeviceContextPool) { using paddle::platform::CPUPlace; using paddle::platform::CUDAPlace; - DeviceContextPool& pool = DeviceContextPool::Get(); - auto cpu_dev_ctx1 = pool.Borrow(CPUPlace()); - auto cpu_dev_ctx2 = pool.Borrow(CPUPlace()); - EXPECT_TRUE(cpu_dev_ctx2 == cpu_dev_ctx1); + DeviceContextPool& pool = DeviceContextPool::Instance(); + auto cpu_dev_ctx1 = pool.Get(CPUPlace()); + auto cpu_dev_ctx2 = pool.Get(CPUPlace()); + ASSERT_EQ(cpu_dev_ctx2, cpu_dev_ctx1); std::vector gpu_places; int count = paddle::platform::GetCUDADeviceCount(); for (int i = 0; i < 
count; ++i) { - gpu_places.emplace_back(CUDAPlace(i)); - } - auto dev_ctxs = pool.Borrow(gpu_places); - for (size_t i = 0; i < dev_ctxs.size(); ++i) { - auto* dev_ctx = static_cast(dev_ctxs[i]); - - // check same as CUDAPlace(i) - CUDAPlace place = boost::get(dev_ctx->GetPlace()); - EXPECT_EQ(place.GetDeviceId(), static_cast(i)); + auto dev_ctx = pool.Get(CUDAPlace(i)); + ASSERT_NE(dev_ctx, nullptr); } } int main(int argc, char** argv) { - int dev_count = paddle::platform::GetCUDADeviceCount(); - if (dev_count <= 1) { - LOG(WARNING) << "Cannot test multi-gpu DeviceContextPool, because the CUDA " - "device count is " - << dev_count; - return 0; - } - std::vector places; places.emplace_back(paddle::platform::CPUPlace()); @@ -109,7 +94,7 @@ int main(int argc, char** argv) { } VLOG(0) << " DeviceCount " << count; - paddle::platform::DeviceContextPool::Create(places); + paddle::platform::DeviceContextPool::Init(places); testing::InitGoogleTest(&argc, argv); return RUN_ALL_TESTS(); diff --git a/paddle/platform/nccl_test.cu b/paddle/platform/nccl_test.cu index 8f815863a7..ef6d845874 100644 --- a/paddle/platform/nccl_test.cu +++ b/paddle/platform/nccl_test.cu @@ -144,7 +144,7 @@ int main(int argc, char** argv) { } VLOG(0) << " DeviceCount " << count; - paddle::platform::DeviceContextPool::Create(places); + paddle::platform::DeviceContextPool::Init(places); testing::InitGoogleTest(&argc, argv); return RUN_ALL_TESTS(); From 42062c38b17f0a8ba3431bcb043e78b87440e6ad Mon Sep 17 00:00:00 2001 From: Yang Yu Date: Wed, 27 Dec 2017 11:12:18 +0800 Subject: [PATCH 10/33] Fix compile --- paddle/operators/beam_search_decode_op.cc | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/paddle/operators/beam_search_decode_op.cc b/paddle/operators/beam_search_decode_op.cc index 52c28e7f53..72e05607b0 100644 --- a/paddle/operators/beam_search_decode_op.cc +++ b/paddle/operators/beam_search_decode_op.cc @@ -57,8 +57,8 @@ class BeamSearchDecodeOp : public 
framework::OperatorBase { : OperatorBase(type, inputs, outputs, attrs) {} void Run(const framework::Scope& scope, const platform::Place& dev_place) const override { - platform::DeviceContextPool& pool = platform::DeviceContextPool::Get(); - auto& dev_ctx = *pool.Borrow(dev_place); + platform::DeviceContextPool& pool = platform::DeviceContextPool::Instance(); + auto& dev_ctx = *pool.Get(dev_place); framework::ExecutionContext ctx(*this, scope, dev_ctx); From b711870c4ae2803374a5d5d86f011aa819055b7c Mon Sep 17 00:00:00 2001 From: Yang Yu Date: Wed, 27 Dec 2017 13:24:40 +0800 Subject: [PATCH 11/33] Fix compile --- paddle/gserver/layers/MKLDNNLRNLayer.cpp | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/paddle/gserver/layers/MKLDNNLRNLayer.cpp b/paddle/gserver/layers/MKLDNNLRNLayer.cpp index 741984bb68..ac217f1363 100644 --- a/paddle/gserver/layers/MKLDNNLRNLayer.cpp +++ b/paddle/gserver/layers/MKLDNNLRNLayer.cpp @@ -29,7 +29,7 @@ bool MKLDNNLRNLayer::init(const LayerMap& layerMap, } /* the size of inputs for norm-layer is 1 */ - CHECK_EQ(config_.inputs_size(), 1UL); + CHECK_EQ(config_.inputs_size(), 1); const NormConfig& conf = config_.inputs(0).norm_conf(); localSize_ = conf.size(); alpha_ = conf.scale(); From 15309fde2c50a485fd120f749661ea16a6c75232 Mon Sep 17 00:00:00 2001 From: Yang Yu Date: Wed, 27 Dec 2017 14:45:04 +0800 Subject: [PATCH 12/33] Add API for HasNAN HasInf --- paddle/framework/tensor_util.h | 96 ++++++++++++++++++++++++++++++++ paddle/platform/device_context.h | 20 +++++++ paddle/platform/place.h | 28 +++++++++- 3 files changed, 143 insertions(+), 1 deletion(-) diff --git a/paddle/framework/tensor_util.h b/paddle/framework/tensor_util.h index ea4e4f22ea..5c7822814c 100644 --- a/paddle/framework/tensor_util.h +++ b/paddle/framework/tensor_util.h @@ -13,7 +13,10 @@ See the License for the specific language governing permissions and limitations under the License. 
*/ #pragma once +#include "paddle/framework/data_type.h" +#include "paddle/framework/eigen.h" #include "paddle/framework/tensor.h" +#include "paddle/platform/device_context.h" namespace paddle { namespace framework { @@ -205,5 +208,98 @@ inline void CopyToVector(const Tensor& src, std::vector* dst) { src_ptr, size); } +template +struct AnyDTypeVisitor { + Predicate predicate_; + const Tensor& tensor_; + const DevCtx& ctx_; + Tensor* out_; + + AnyDTypeVisitor(Predicate predicate, const Tensor& tensor, const DevCtx& ctx, + Tensor* out) + : predicate_(predicate), tensor_(tensor), ctx_(ctx), out_(out) {} + + template + void operator()() const { + auto t = EigenVector::Flatten(tensor_); + auto o = EigenScalar::From(*out_); + o.device(*ctx_.eigen_device()) = predicate_(t).any(); + } +}; + +template +inline void AnyImpl(Predicate predicate, const framework::Tensor& tensor, + const DevCtx& ctx, framework::Tensor* out) { + VisitDataType(ToDataType(tensor.type()), AnyDTypeVisitor( + predicate, tensor, ctx, out)); +} + +template +struct AnyVisitor : public boost::static_visitor { + const framework::Tensor& tensor_; + Predicate predicate_; + + AnyVisitor(const framework::Tensor& tensor, Predicate predicate) + : tensor_(tensor), predicate_(std::move(predicate)) {} + + template + bool operator()(const Place& place) const { + framework::Tensor out; + out.Resize({1}); + out.mutable_data(place); + auto* ctx = platform::DeviceContextPool::Instance().GetByPlace(place); + AnyImpl(predicate_, tensor_, *ctx, &out); + return this->GetResult(out, place); + } + + bool GetResult(const framework::Tensor& out, + const platform::CUDAPlace& gpu) const { + platform::CPUPlace cpu; + framework::Tensor tmp; + tmp.Resize({1}); + tmp.mutable_data(cpu); + platform::DeviceContextPool::Instance().Get(gpu)->Wait(); + CopyFrom(out, cpu, &tmp); + platform::DeviceContextPool::Instance().Get(gpu)->Wait(); + return GetResult(tmp, cpu); + } + + bool GetResult(const framework::Tensor& out, + const 
platform::CPUPlace& cpu) const { + return *out.data(); + } +}; + +template +inline bool Any(const framework::Tensor& tensor, Predicate predicate) { + AnyVisitor visitor(tensor, predicate); + auto place = tensor.place(); + return platform::VisitPlace(place, visitor); +} + +struct HasNanPredicate { + template + auto operator()(T eigen_vec) const -> decltype(std::declval().isnan()) { + return eigen_vec.isnan(); + } +}; + +inline bool HasNan(const framework::Tensor& tensor) { + HasNanPredicate predicate; + return Any(tensor, predicate); +} + +struct HasInfPredicate { + template + auto operator()(T eigen_vec) const -> decltype(std::declval().isinf()) { + return eigen_vec.isinf(); + } +}; + +inline bool HasInf(const framework::Tensor& tensor) { + HasInfPredicate predicate; + return Any(tensor, predicate); +} + } // namespace framework } // namespace paddle diff --git a/paddle/platform/device_context.h b/paddle/platform/device_context.h index dfef2c16d8..fd441d27f9 100644 --- a/paddle/platform/device_context.h +++ b/paddle/platform/device_context.h @@ -52,6 +52,14 @@ class CPUDeviceContext : public DeviceContext { std::unique_ptr eigen_device_; }; +template +struct DefaultDeviceContextType; + +template <> +struct DefaultDeviceContextType { + using TYPE = CPUDeviceContext; +}; + #ifdef PADDLE_WITH_CUDA class EigenCudaStreamDevice; @@ -90,6 +98,11 @@ class CUDADeviceContext : public DeviceContext { cublasHandle_t cublas_handle_; }; +template <> +struct DefaultDeviceContextType { + using T = CUDADeviceContext; +}; + class CUDNNDeviceContext : public CUDADeviceContext { public: explicit CUDNNDeviceContext(CUDAPlace place); @@ -125,6 +138,13 @@ class DeviceContextPool { /*! \brief Return handle of single device context. 
*/ const platform::DeviceContext* Get(const platform::Place& place); + template + const typename DefaultDeviceContextType::TYPE* GetByPlace( + const Place& place) { + return reinterpret_cast< + const typename DefaultDeviceContextType::TYPE*>(Get(place)); + } + private: static DeviceContextPool* pool; constexpr static int LEFT_SHIFT = 8; diff --git a/paddle/platform/place.h b/paddle/platform/place.h index d25eaa689f..76b5c502cc 100644 --- a/paddle/platform/place.h +++ b/paddle/platform/place.h @@ -15,7 +15,7 @@ limitations under the License. */ #pragma once #include - +#include "paddle/platform/enforce.h" #include "paddle/platform/variant.h" namespace paddle { @@ -64,5 +64,31 @@ bool places_are_same_class(const Place &, const Place &); std::ostream &operator<<(std::ostream &, const Place &); +template +struct PlaceVisitorWrapper + : public boost::static_visitor { + const Visitor &visitor_; + explicit PlaceVisitorWrapper(const Visitor &visitor) : visitor_(visitor) {} + + typename Visitor::result_type operator()(const CPUPlace &cpu) const { + return visitor_(cpu); + } + + typename Visitor::result_type operator()(const CUDAPlace &cuda) const { +#ifdef PADDLE_WITH_CUDA + return visitor_(cuda); +#else + PADDLE_THROW("Paddle is not compiled with CUDA. 
Cannot visit cuda device"); + return typename Visitor::result_type(); +#endif + } +}; + +template +typename Visitor::result_type VisitPlace(const Place &place, + const Visitor &visitor) { + return boost::apply_visitor(PlaceVisitorWrapper(visitor), place); +} + } // namespace platform } // namespace paddle From 4518252e572c53ff0b1e8ac4149537bb400b80b5 Mon Sep 17 00:00:00 2001 From: Yang Yu Date: Wed, 27 Dec 2017 14:47:08 +0800 Subject: [PATCH 13/33] Fix compile --- paddle/operators/nccl_op_test.cu.cc | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/paddle/operators/nccl_op_test.cu.cc b/paddle/operators/nccl_op_test.cu.cc index 34a6e1a58d..6546096069 100644 --- a/paddle/operators/nccl_op_test.cu.cc +++ b/paddle/operators/nccl_op_test.cu.cc @@ -305,7 +305,7 @@ int main(int argc, char **argv) { } VLOG(0) << " DeviceCount " << count; - paddle::platform::DeviceContextPool::Create(places); + paddle::platform::DeviceContextPool::Init(places); testing::InitGoogleTest(&argc, argv); From 3d282ec407a518ece37adb1b9ee5da57429a9904 Mon Sep 17 00:00:00 2001 From: Yang Yu Date: Wed, 27 Dec 2017 15:10:48 +0800 Subject: [PATCH 14/33] Add is_nan/is_inf --- paddle/framework/tensor_util.h | 12 +++++++----- paddle/framework/tensor_util_test.cc | 24 ++++++++++++++++++++++++ 2 files changed, 31 insertions(+), 5 deletions(-) diff --git a/paddle/framework/tensor_util.h b/paddle/framework/tensor_util.h index 5c7822814c..7d786ad614 100644 --- a/paddle/framework/tensor_util.h +++ b/paddle/framework/tensor_util.h @@ -277,21 +277,23 @@ inline bool Any(const framework::Tensor& tensor, Predicate predicate) { return platform::VisitPlace(place, visitor); } -struct HasNanPredicate { +struct HasNANPredicate { template - auto operator()(T eigen_vec) const -> decltype(std::declval().isnan()) { + auto operator()(const T& eigen_vec) const + -> decltype(std::declval().isnan()) { return eigen_vec.isnan(); } }; -inline bool HasNan(const framework::Tensor& tensor) { - HasNanPredicate 
predicate; +inline bool HasNAN(const framework::Tensor& tensor) { + HasNANPredicate predicate; return Any(tensor, predicate); } struct HasInfPredicate { template - auto operator()(T eigen_vec) const -> decltype(std::declval().isinf()) { + auto operator()(const T& eigen_vec) const + -> decltype(std::declval().isinf()) { return eigen_vec.isinf(); } }; diff --git a/paddle/framework/tensor_util_test.cc b/paddle/framework/tensor_util_test.cc index f388c19f28..01dfd4deb9 100644 --- a/paddle/framework/tensor_util_test.cc +++ b/paddle/framework/tensor_util_test.cc @@ -13,6 +13,7 @@ #include "paddle/framework/tensor_util.h" #include +#include #include namespace paddle { @@ -230,5 +231,28 @@ TEST(CopyToVector, Tensor) { #endif } +TEST(IsNAN, CPU) { + using namespace paddle::framework; + using namespace paddle::platform; + Tensor src; + float* buf = src.mutable_data({3}, CPUPlace()); + buf[0] = 0.0; + buf[1] = NAN; + buf[2] = 0.0; + + ASSERT_TRUE(HasNAN(src)); +} + +TEST(IsInf, CPU) { + using namespace paddle::framework; + using namespace paddle::platform; + Tensor src; + double* buf = src.mutable_data({3}, CPUPlace()); + buf[0] = 1.0; + buf[1] = INFINITY; + buf[2] = 0.0; + ASSERT_TRUE(HasInf(src)); +} + } // namespace framework } // namespace paddle From a5291f9ce2466326588792a2e58a5f777c5fc51e Mon Sep 17 00:00:00 2001 From: Yang Yu Date: Wed, 27 Dec 2017 15:11:17 +0800 Subject: [PATCH 15/33] Fix compile --- paddle/pybind/tensor_py.h | 9 +++++---- 1 file changed, 5 insertions(+), 4 deletions(-) diff --git a/paddle/pybind/tensor_py.h b/paddle/pybind/tensor_py.h index 67244d8260..64e981e4e8 100644 --- a/paddle/pybind/tensor_py.h +++ b/paddle/pybind/tensor_py.h @@ -63,9 +63,10 @@ struct CastToPyBufferImpl { auto *dst_ptr = static_cast(dst_tensor.mutable_data( tensor.dims(), platform::CPUPlace())); - platform::DeviceContextPool &pool = platform::DeviceContextPool::Get(); + platform::DeviceContextPool &pool = + platform::DeviceContextPool::Instance; auto dev_ctx = static_cast( - 
pool.Borrow(tensor.place())); + pool.Get(tensor.place())); paddle::platform::GpuMemcpyAsync( dst_ptr, src_ptr, sizeof(CUR_TYPE) * tensor.numel(), @@ -137,9 +138,9 @@ void PyCUDATensorSetFromArray( self.Resize(framework::make_ddim(dims)); auto *dst = self.mutable_data(place); - platform::DeviceContextPool &pool = platform::DeviceContextPool::Get(); + platform::DeviceContextPool &pool = platform::DeviceContextPool::Instance(); auto dev_ctx = - static_cast(pool.Borrow(place)); + static_cast(pool.Get(place)); paddle::platform::GpuMemcpyAsync(dst, array.data(), sizeof(T) * array.size(), cudaMemcpyHostToDevice, dev_ctx->stream()); } From 3ae781eb2bc139a946b7f195183e31304af49822 Mon Sep 17 00:00:00 2001 From: Yang Yu Date: Wed, 27 Dec 2017 15:45:13 +0800 Subject: [PATCH 16/33] Executor check nan --- paddle/framework/executor.cc | 33 +++++++++++++++++++++++++++------ 1 file changed, 27 insertions(+), 6 deletions(-) diff --git a/paddle/framework/executor.cc b/paddle/framework/executor.cc index 997773c168..9ee2ddb7c3 100644 --- a/paddle/framework/executor.cc +++ b/paddle/framework/executor.cc @@ -14,18 +14,17 @@ limitations under the License. */ #include "paddle/framework/executor.h" -#include -#include -#include #include -#include +#include "gflags/gflags.h" #include "paddle/framework/feed_fetch_type.h" #include "paddle/framework/lod_rank_table.h" -#include "paddle/framework/lod_tensor.h" #include "paddle/framework/lod_tensor_array.h" #include "paddle/framework/op_registry.h" -#include "paddle/framework/scope.h" + +DEFINE_bool(check_nan_inf, false, + "Checking whether operator produce NAN/INF or not. 
It will be " + "extremely slow so please use this flag wisely."); namespace paddle { namespace framework { @@ -58,6 +57,19 @@ static void CreateTensor(Variable* var, proto::VarDesc::VarType var_type) { } } +static void CheckTensorNANOrInf(const std::string& name, + const framework::Tensor& tensor) { + if (tensor.type().hash_code() != typeid(float).hash_code() && + tensor.type().hash_code() != typeid(double).hash_code()) { + return; + } + if (tensor.memory_size() == 0) { + return; + } + PADDLE_ENFORCE(!framework::HasInf(tensor), "Tensor %s has Inf", name); + PADDLE_ENFORCE(!framework::HasNAN(tensor), "Tensor %s has NAN", name); +} + void Executor::Run(const ProgramDesc& pdesc, Scope* scope, int block_id, bool create_local_scope, bool create_vars) { // TODO(tonyyang-svail): @@ -101,6 +113,15 @@ void Executor::Run(const ProgramDesc& pdesc, Scope* scope, int block_id, auto op = paddle::framework::OpRegistry::CreateOp(*op_desc); VLOG(3) << op->DebugString(); op->Run(*local_scope, place_); + if (FLAGS_check_nan_inf) { + for (auto& vname : op->OutputVars(true)) { + auto* var = local_scope->FindVar(vname); + if (var == nullptr) continue; + if (var->IsType()) { + CheckTensorNANOrInf(vname, var->Get()); + } + } + } } if (create_local_scope) { scope->DeleteScope(local_scope); From 16a84328c6947f224534cbd5e3218714adfb9e9b Mon Sep 17 00:00:00 2001 From: Yang Yu Date: Wed, 27 Dec 2017 15:47:18 +0800 Subject: [PATCH 17/33] Fix compile --- paddle/pybind/tensor_py.h | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/paddle/pybind/tensor_py.h b/paddle/pybind/tensor_py.h index 64e981e4e8..4d5e73e2c2 100644 --- a/paddle/pybind/tensor_py.h +++ b/paddle/pybind/tensor_py.h @@ -64,7 +64,7 @@ struct CastToPyBufferImpl { tensor.dims(), platform::CPUPlace())); platform::DeviceContextPool &pool = - platform::DeviceContextPool::Instance; + platform::DeviceContextPool::Instance(); auto dev_ctx = static_cast( pool.Get(tensor.place())); From 5162c41a9209da9daf5c440396ac3fbd516f16e7 
Mon Sep 17 00:00:00 2001 From: Yang Yu Date: Wed, 27 Dec 2017 16:02:28 +0800 Subject: [PATCH 18/33] Add gflags --- python/paddle/v2/fluid/__init__.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/python/paddle/v2/fluid/__init__.py b/python/paddle/v2/fluid/__init__.py index c72b573069..225b41c504 100644 --- a/python/paddle/v2/fluid/__init__.py +++ b/python/paddle/v2/fluid/__init__.py @@ -36,7 +36,7 @@ def __read_gflags_from_env__(): """ import sys import core - read_env_flags = ['use_pinned_memory'] + read_env_flags = ['use_pinned_memory', 'check_nan_inf'] if core.is_compile_gpu(): read_env_flags.append('fraction_of_gpu_memory_to_use') core.init_gflags([sys.argv[0]] + From 003917d881fec0192e97bae19abb41599c6b0083 Mon Sep 17 00:00:00 2001 From: Yang Yu Date: Thu, 28 Dec 2017 10:34:04 +0800 Subject: [PATCH 19/33] Fix compile --- paddle/platform/device_context.h | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/paddle/platform/device_context.h b/paddle/platform/device_context.h index fd441d27f9..2b366e6383 100644 --- a/paddle/platform/device_context.h +++ b/paddle/platform/device_context.h @@ -100,7 +100,7 @@ class CUDADeviceContext : public DeviceContext { template <> struct DefaultDeviceContextType { - using T = CUDADeviceContext; + using TYPE = CUDADeviceContext; }; class CUDNNDeviceContext : public CUDADeviceContext { From 878d2e919c5c15fabc659ed544da3b867272f0d2 Mon Sep 17 00:00:00 2001 From: Yang Yu Date: Thu, 28 Dec 2017 10:52:41 +0800 Subject: [PATCH 20/33] Fix compile --- paddle/framework/CMakeLists.txt | 6 +- paddle/framework/tensor_util.cc | 115 ++++++++++++++++++++++++++++++++ paddle/framework/tensor_util.cu | 1 + paddle/framework/tensor_util.h | 96 +------------------------- 4 files changed, 123 insertions(+), 95 deletions(-) create mode 100644 paddle/framework/tensor_util.cc create mode 120000 paddle/framework/tensor_util.cu diff --git a/paddle/framework/CMakeLists.txt b/paddle/framework/CMakeLists.txt index 
738684795d..f72f49bc5e 100644 --- a/paddle/framework/CMakeLists.txt +++ b/paddle/framework/CMakeLists.txt @@ -5,7 +5,11 @@ cc_library(ddim SRCS ddim.cc DEPS eigen3) cc_test(ddim_test SRCS ddim_test.cc DEPS ddim) nv_test(dim_test SRCS dim_test.cu DEPS ddim) -cc_library(tensor SRCS tensor.cc DEPS ddim place paddle_memory device_context) +if (WITH_GPU) + nv_binary(tensor SRCS tensor.cc tensor_util.cu DEPS ddim place paddle_memory device_context) +else() + cc_library(tensor SRCS tensor.cc tensor_util.cc DEPS ddim place paddle_memory device_context) +endif () cc_test(tensor_test SRCS tensor_test.cc DEPS tensor) cc_test(tensor_util_test SRCS tensor_util_test.cc DEPS tensor) diff --git a/paddle/framework/tensor_util.cc b/paddle/framework/tensor_util.cc new file mode 100644 index 0000000000..293c65a065 --- /dev/null +++ b/paddle/framework/tensor_util.cc @@ -0,0 +1,115 @@ +/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved. + + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. 
*/ + +#include "paddle/framework/tensor_util.h" + +namespace paddle { +namespace framework { +template +struct AnyDTypeVisitor { + Predicate predicate_; + const Tensor& tensor_; + const DevCtx& ctx_; + Tensor* out_; + + AnyDTypeVisitor(Predicate predicate, const Tensor& tensor, const DevCtx& ctx, + Tensor* out) + : predicate_(predicate), tensor_(tensor), ctx_(ctx), out_(out) {} + + template + void operator()() const { + auto t = EigenVector::Flatten(tensor_); + auto o = EigenScalar::From(*out_); + o.device(*ctx_.eigen_device()) = predicate_(t).any(); + } +}; + +template +inline void AnyImpl(Predicate predicate, const framework::Tensor& tensor, + const DevCtx& ctx, framework::Tensor* out) { + VisitDataType(ToDataType(tensor.type()), AnyDTypeVisitor( + predicate, tensor, ctx, out)); +} + +template +struct AnyVisitor : public boost::static_visitor { + const framework::Tensor& tensor_; + Predicate predicate_; + + AnyVisitor(const framework::Tensor& tensor, Predicate predicate) + : tensor_(tensor), predicate_(std::move(predicate)) {} + + template + bool operator()(const Place& place) const { + framework::Tensor out; + out.Resize({1}); + out.mutable_data(place); + auto* ctx = platform::DeviceContextPool::Instance().GetByPlace(place); + AnyImpl(predicate_, tensor_, *ctx, &out); + return this->GetResult(out, place); + } + + bool GetResult(const framework::Tensor& out, + const platform::CUDAPlace& gpu) const { + platform::CPUPlace cpu; + framework::Tensor tmp; + tmp.Resize({1}); + tmp.mutable_data(cpu); + platform::DeviceContextPool::Instance().Get(gpu)->Wait(); + CopyFrom(out, cpu, &tmp); + platform::DeviceContextPool::Instance().Get(gpu)->Wait(); + return GetResult(tmp, cpu); + } + + bool GetResult(const framework::Tensor& out, + const platform::CPUPlace& cpu) const { + return *out.data(); + } +}; + +template +inline bool Any(const framework::Tensor& tensor, Predicate predicate) { + AnyVisitor visitor(tensor, predicate); + auto place = tensor.place(); + return 
platform::VisitPlace(place, visitor); +} + +struct HasNANPredicate { + template + auto operator()(const T& eigen_vec) const + -> decltype(std::declval().isnan()) { + return eigen_vec.isnan(); + } +}; + +bool HasNAN(const framework::Tensor& tensor) { + HasNANPredicate predicate; + return Any(tensor, predicate); +} + +struct HasInfPredicate { + template + auto operator()(const T& eigen_vec) const + -> decltype(std::declval().isinf()) { + return eigen_vec.isinf(); + } +}; + +bool HasInf(const framework::Tensor& tensor) { + HasInfPredicate predicate; + return Any(tensor, predicate); +} + +} // namespace framework +} // namespace paddle diff --git a/paddle/framework/tensor_util.cu b/paddle/framework/tensor_util.cu new file mode 120000 index 0000000000..b00e6e59d9 --- /dev/null +++ b/paddle/framework/tensor_util.cu @@ -0,0 +1 @@ +./tensor_util.cc \ No newline at end of file diff --git a/paddle/framework/tensor_util.h b/paddle/framework/tensor_util.h index 7d786ad614..e71d8e5672 100644 --- a/paddle/framework/tensor_util.h +++ b/paddle/framework/tensor_util.h @@ -208,100 +208,8 @@ inline void CopyToVector(const Tensor& src, std::vector* dst) { src_ptr, size); } -template -struct AnyDTypeVisitor { - Predicate predicate_; - const Tensor& tensor_; - const DevCtx& ctx_; - Tensor* out_; - - AnyDTypeVisitor(Predicate predicate, const Tensor& tensor, const DevCtx& ctx, - Tensor* out) - : predicate_(predicate), tensor_(tensor), ctx_(ctx), out_(out) {} - - template - void operator()() const { - auto t = EigenVector::Flatten(tensor_); - auto o = EigenScalar::From(*out_); - o.device(*ctx_.eigen_device()) = predicate_(t).any(); - } -}; - -template -inline void AnyImpl(Predicate predicate, const framework::Tensor& tensor, - const DevCtx& ctx, framework::Tensor* out) { - VisitDataType(ToDataType(tensor.type()), AnyDTypeVisitor( - predicate, tensor, ctx, out)); -} - -template -struct AnyVisitor : public boost::static_visitor { - const framework::Tensor& tensor_; - Predicate predicate_; - 
- AnyVisitor(const framework::Tensor& tensor, Predicate predicate) - : tensor_(tensor), predicate_(std::move(predicate)) {} - - template - bool operator()(const Place& place) const { - framework::Tensor out; - out.Resize({1}); - out.mutable_data(place); - auto* ctx = platform::DeviceContextPool::Instance().GetByPlace(place); - AnyImpl(predicate_, tensor_, *ctx, &out); - return this->GetResult(out, place); - } - - bool GetResult(const framework::Tensor& out, - const platform::CUDAPlace& gpu) const { - platform::CPUPlace cpu; - framework::Tensor tmp; - tmp.Resize({1}); - tmp.mutable_data(cpu); - platform::DeviceContextPool::Instance().Get(gpu)->Wait(); - CopyFrom(out, cpu, &tmp); - platform::DeviceContextPool::Instance().Get(gpu)->Wait(); - return GetResult(tmp, cpu); - } - - bool GetResult(const framework::Tensor& out, - const platform::CPUPlace& cpu) const { - return *out.data(); - } -}; - -template -inline bool Any(const framework::Tensor& tensor, Predicate predicate) { - AnyVisitor visitor(tensor, predicate); - auto place = tensor.place(); - return platform::VisitPlace(place, visitor); -} - -struct HasNANPredicate { - template - auto operator()(const T& eigen_vec) const - -> decltype(std::declval().isnan()) { - return eigen_vec.isnan(); - } -}; - -inline bool HasNAN(const framework::Tensor& tensor) { - HasNANPredicate predicate; - return Any(tensor, predicate); -} - -struct HasInfPredicate { - template - auto operator()(const T& eigen_vec) const - -> decltype(std::declval().isinf()) { - return eigen_vec.isinf(); - } -}; - -inline bool HasInf(const framework::Tensor& tensor) { - HasInfPredicate predicate; - return Any(tensor, predicate); -} +extern bool HasNAN(const framework::Tensor& tensor); +extern bool HasInf(const framework::Tensor& tensor); } // namespace framework } // namespace paddle From a9a44e017c4b38cd7105365dd1ee3916fe3889ce Mon Sep 17 00:00:00 2001 From: Yang Yu Date: Thu, 28 Dec 2017 10:53:39 +0800 Subject: [PATCH 21/33] Fix compile --- 
paddle/framework/CMakeLists.txt | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/paddle/framework/CMakeLists.txt b/paddle/framework/CMakeLists.txt index f72f49bc5e..2af10a996c 100644 --- a/paddle/framework/CMakeLists.txt +++ b/paddle/framework/CMakeLists.txt @@ -6,7 +6,7 @@ cc_test(ddim_test SRCS ddim_test.cc DEPS ddim) nv_test(dim_test SRCS dim_test.cu DEPS ddim) if (WITH_GPU) - nv_binary(tensor SRCS tensor.cc tensor_util.cu DEPS ddim place paddle_memory device_context) + nv_library(tensor SRCS tensor.cc tensor_util.cu DEPS ddim place paddle_memory device_context) else() cc_library(tensor SRCS tensor.cc tensor_util.cc DEPS ddim place paddle_memory device_context) endif () From dd2bbf3a14fec5623609bf84377350e5812342f0 Mon Sep 17 00:00:00 2001 From: tensor-tang Date: Thu, 28 Dec 2017 13:51:17 +0800 Subject: [PATCH 22/33] update md5 of flowers dataset --- python/paddle/v2/dataset/flowers.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/python/paddle/v2/dataset/flowers.py b/python/paddle/v2/dataset/flowers.py index 634388094c..7bdddeaabe 100644 --- a/python/paddle/v2/dataset/flowers.py +++ b/python/paddle/v2/dataset/flowers.py @@ -44,7 +44,7 @@ __all__ = ['train', 'test', 'valid'] DATA_URL = 'http://www.robots.ox.ac.uk/~vgg/data/flowers/102/102flowers.tgz' LABEL_URL = 'http://www.robots.ox.ac.uk/~vgg/data/flowers/102/imagelabels.mat' SETID_URL = 'http://www.robots.ox.ac.uk/~vgg/data/flowers/102/setid.mat' -DATA_MD5 = '52808999861908f626f3c1f4e79d11fa' +DATA_MD5 = '33bfc11892f1e405ca193ae9a9f2a118' LABEL_MD5 = 'e0620be6f572b9609742df49c70aed4d' SETID_MD5 = 'a5357ecc9cb78c4bef273ce3793fc85c' # In official 'readme', tstid is the flag of test data From b7c4b58d3d041d4afe4da3d7f8b7d7366e8dce8d Mon Sep 17 00:00:00 2001 From: hedaoyuan Date: Thu, 28 Dec 2017 14:51:32 +0800 Subject: [PATCH 23/33] Follow comments. 
--- paddle/function/GemmConvOp.cpp | 6 ++++-- paddle/function/Im2Col.h | 2 +- paddle/function/Im2ColTest.cpp | 14 +++++++------- 3 files changed, 12 insertions(+), 10 deletions(-) diff --git a/paddle/function/GemmConvOp.cpp b/paddle/function/GemmConvOp.cpp index 25cc3df667..cbdbf5335d 100644 --- a/paddle/function/GemmConvOp.cpp +++ b/paddle/function/GemmConvOp.cpp @@ -189,8 +189,8 @@ public: size_t colHeight = inputChannels / groups_ * filterHeight * filterWidth; size_t colWidth = outputHeight * outputWidth; // Max col matrix height 256, Max col matrix width 1024 - size_t stepColHeight = std::min(colHeight, (size_t)256); - size_t stepColWidth = std::min(colWidth, (size_t)2048); + size_t stepColHeight = std::min(colHeight, static_cast(256)); + size_t stepColWidth = std::min(colWidth, static_cast(2048)); if (needIm2col) { colShape = TensorShape({inputChannels / groups_, @@ -278,6 +278,8 @@ public: inputData += inputChannels * inputHeight * inputWidth; outputData += outputChannels * outputHeight * outputWidth; } + + memory_.reset(); } }; diff --git a/paddle/function/Im2Col.h b/paddle/function/Im2Col.h index 1053e4fd23..36a9bcf84e 100644 --- a/paddle/function/Im2Col.h +++ b/paddle/function/Im2Col.h @@ -136,7 +136,7 @@ public: (imRowIdx - paddingHeight) >= inputHeight || (imColIdx - paddingWidth) < 0 || (imColIdx - paddingWidth) >= inputWidth) { - colData[colh * colWidthSize + colw] = T(0); + colData[colh * colWidthSize + colw] = static_cast(0); } else { imRowIdx += c_im * inputHeight - paddingHeight; imColIdx -= paddingWidth; diff --git a/paddle/function/Im2ColTest.cpp b/paddle/function/Im2ColTest.cpp index c573469168..3ba866dcdd 100644 --- a/paddle/function/Im2ColTest.cpp +++ b/paddle/function/Im2ColTest.cpp @@ -140,13 +140,13 @@ TEST(Im2ColFunctor, GPU) { TestIm2ColFunctor(); } template void TestIm2ColMobileFunctor() { - for (size_t channels : {1, 5, 32}) { - for (size_t inputHeight : {5, 33, 100}) { - for (size_t inputWidth : {5, 32, 96}) { - for (size_t 
filterHeight : {1, 5}) { - for (size_t filterWidth : {3, 7}) { - for (size_t stride : {1, 2}) { - for (size_t padding : {0, 1}) { + for (size_t channels : {32}) { + for (size_t inputHeight : {33, 100}) { + for (size_t inputWidth : {32, 96}) { + for (size_t filterHeight : {5}) { + for (size_t filterWidth : {7}) { + for (size_t stride : {2}) { + for (size_t padding : {1}) { for (size_t dilation : {1, 3}) { size_t filterSizeH = (filterHeight - 1) * dilation + 1; size_t filterSizeW = (filterWidth - 1) * dilation + 1; From 5022ee63597c0ac52a9b5344f81546f6c26b2dc7 Mon Sep 17 00:00:00 2001 From: Yancey Date: Thu, 28 Dec 2017 17:09:11 +0800 Subject: [PATCH 24/33] ThreadPool::Run interface return std::future (#7099) * Run interface return future * delete unused comments --- paddle/framework/threadpool.h | 19 +++++++++++++------ paddle/framework/threadpool_test.cc | 19 ++++++++++++------- 2 files changed, 25 insertions(+), 13 deletions(-) diff --git a/paddle/framework/threadpool.h b/paddle/framework/threadpool.h index 5f6b2d458f..bcd8190755 100644 --- a/paddle/framework/threadpool.h +++ b/paddle/framework/threadpool.h @@ -16,6 +16,7 @@ limitations under the License. */ #include #include +#include #include #include #include @@ -25,10 +26,11 @@ limitations under the License. */ namespace paddle { namespace framework { -typedef std::function Task; - class ThreadPool { public: + typedef std::packaged_task Task; + typedef std::function Fun; + /** * @brief Get a instance of threadpool, the thread number will * be specified as the number of hardware thread contexts @@ -61,13 +63,18 @@ class ThreadPool { /** * @brief Push a function to the queue, and will be scheduled and * executed if a thread is available. - * @param[in] Task will be pushed to the task queue. + * @param[in] Task, will be pushed to the task queue. + * @return std::future, we could wait for the task finished by + * f.wait(). 
*/ - void Run(const Task& fn) { + std::future Run(const Fun& fn) { std::unique_lock lock(mutex_); - tasks_.push(fn); + Task task(std::bind(fn)); + std::future f = task.get_future(); + tasks_.push(std::move(task)); lock.unlock(); scheduled_.notify_one(); + return f; } /** @@ -110,7 +117,7 @@ class ThreadPool { break; } // pop a task from the task queue - auto task = tasks_.front(); + auto task = std::move(tasks_.front()); tasks_.pop(); --available_; diff --git a/paddle/framework/threadpool_test.cc b/paddle/framework/threadpool_test.cc index 012d92a5ed..50b6238cd8 100644 --- a/paddle/framework/threadpool_test.cc +++ b/paddle/framework/threadpool_test.cc @@ -20,16 +20,21 @@ limitations under the License. */ namespace framework = paddle::framework; void do_sum(framework::ThreadPool* pool, std::atomic& sum, int cnt) { + std::vector> fs; for (int i = 0; i < cnt; ++i) { - pool->Run([&sum]() { sum.fetch_add(1); }); + auto f = pool->Run([&sum]() { sum.fetch_add(1); }); + fs.push_back(std::move(f)); + } + for (auto& f : fs) { + f.wait(); } } TEST(ThreadPool, ConcurrentInit) { framework::ThreadPool* pool; - int concurrent_cnt = 50; + int n = 50; std::vector threads; - for (int i = 0; i < concurrent_cnt; ++i) { + for (int i = 0; i < n; ++i) { std::thread t([&pool]() { pool = framework::ThreadPool::GetInstance(); }); threads.push_back(std::move(t)); } @@ -38,13 +43,13 @@ TEST(ThreadPool, ConcurrentInit) { } } -TEST(ThreadPool, ConcurrentStart) { +TEST(ThreadPool, ConcurrentRun) { framework::ThreadPool* pool = framework::ThreadPool::GetInstance(); std::atomic sum(0); std::vector threads; - int concurrent_cnt = 50; + int n = 50; // sum = (n * (n + 1)) / 2 - for (int i = 1; i <= concurrent_cnt; ++i) { + for (int i = 1; i <= n; ++i) { std::thread t(do_sum, pool, std::ref(sum), i); threads.push_back(std::move(t)); } @@ -52,5 +57,5 @@ TEST(ThreadPool, ConcurrentStart) { t.join(); } pool->Wait(); - EXPECT_EQ(sum, ((concurrent_cnt + 1) * concurrent_cnt) / 2); + EXPECT_EQ(sum, ((n + 1) 
* n) / 2); } From 3158b4b37a7743239030a331de56f9c227d14adf Mon Sep 17 00:00:00 2001 From: Yang Yu Date: Thu, 28 Dec 2017 17:50:29 +0800 Subject: [PATCH 25/33] Update tensor_util --- paddle/framework/CMakeLists.txt | 6 ++- paddle/framework/tensor_util.cc | 10 +++-- paddle/framework/tensor_util.h | 3 ++ paddle/framework/tensor_util_test.cu | 57 ++++++++++++++++++++++++++++ 4 files changed, 72 insertions(+), 4 deletions(-) create mode 100644 paddle/framework/tensor_util_test.cu diff --git a/paddle/framework/CMakeLists.txt b/paddle/framework/CMakeLists.txt index 2af10a996c..46dce7d1d2 100644 --- a/paddle/framework/CMakeLists.txt +++ b/paddle/framework/CMakeLists.txt @@ -12,7 +12,11 @@ else() endif () cc_test(tensor_test SRCS tensor_test.cc DEPS tensor) -cc_test(tensor_util_test SRCS tensor_util_test.cc DEPS tensor) +if (WITH_GPU) + nv_test(tensor_util_test SRCS tensor_util_test.cc tensor_util_test.cu DEPS tensor) +else() + cc_test(tensor_util_test SRCS tensor_util_test.cc DEPS tensor) +endif() cc_test(eigen_test SRCS eigen_test.cc DEPS tensor) diff --git a/paddle/framework/tensor_util.cc b/paddle/framework/tensor_util.cc index 293c65a065..7efc649d0b 100644 --- a/paddle/framework/tensor_util.cc +++ b/paddle/framework/tensor_util.cc @@ -31,6 +31,7 @@ struct AnyDTypeVisitor { void operator()() const { auto t = EigenVector::Flatten(tensor_); auto o = EigenScalar::From(*out_); + // return any of predicate_(t) is true. 
o.device(*ctx_.eigen_device()) = predicate_(t).any(); } }; @@ -66,9 +67,10 @@ struct AnyVisitor : public boost::static_visitor { framework::Tensor tmp; tmp.Resize({1}); tmp.mutable_data(cpu); - platform::DeviceContextPool::Instance().Get(gpu)->Wait(); - CopyFrom(out, cpu, &tmp); - platform::DeviceContextPool::Instance().Get(gpu)->Wait(); + auto gpuctx = platform::DeviceContextPool::Instance().Get(gpu); + gpuctx->Wait(); + CopyFrom(out, cpu, *gpuctx, &tmp); + gpuctx->Wait(); return GetResult(tmp, cpu); } @@ -89,6 +91,7 @@ struct HasNANPredicate { template auto operator()(const T& eigen_vec) const -> decltype(std::declval().isnan()) { + // Cast eigen_vector to vector of bool. true if is nan. return eigen_vec.isnan(); } }; @@ -102,6 +105,7 @@ struct HasInfPredicate { template auto operator()(const T& eigen_vec) const -> decltype(std::declval().isinf()) { + // Cast eigen_vector to vector of bool. true if is inf. return eigen_vec.isinf(); } }; diff --git a/paddle/framework/tensor_util.h b/paddle/framework/tensor_util.h index e71d8e5672..784170dae3 100644 --- a/paddle/framework/tensor_util.h +++ b/paddle/framework/tensor_util.h @@ -208,7 +208,10 @@ inline void CopyToVector(const Tensor& src, std::vector* dst) { src_ptr, size); } +// Returns true if a tensor contains NAN, i.e., Not A Number. extern bool HasNAN(const framework::Tensor& tensor); + +// Returns true if a tensor contains Inf, i.e., Infinity. extern bool HasInf(const framework::Tensor& tensor); } // namespace framework diff --git a/paddle/framework/tensor_util_test.cu b/paddle/framework/tensor_util_test.cu new file mode 100644 index 0000000000..ebd35fdf6c --- /dev/null +++ b/paddle/framework/tensor_util_test.cu @@ -0,0 +1,57 @@ +/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved. + + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. 
+ You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. */ + +#include "gtest/gtest.h" +#include "paddle/framework/tensor_util.h" +#include "paddle/platform/device_context.h" +#include "paddle/platform/place.h" + +namespace paddle { +namespace framework { + +static __global__ void FillNAN(float* buf) { + buf[0] = 0.0; + buf[1] = 0.1; + buf[2] = NAN; +} +static __global__ void FillInf(float* buf) { + buf[0] = 0.0; + buf[1] = INFINITY; + buf[2] = 0.5; +} + +TEST(HasNAN, GPU) { + Tensor tensor; + platform::CUDAPlace gpu(0); + auto& pool = platform::DeviceContextPool::Instance(); + auto* cuda_ctx = pool.GetByPlace(gpu); + float* buf = tensor.mutable_data({3}, gpu); + FillNAN<<<1, 1, 0, cuda_ctx->stream()>>>(buf); + cuda_ctx->Wait(); + ASSERT_TRUE(HasNAN(tensor)); +} + +TEST(HasInf, GPU) { + Tensor tensor; + platform::CUDAPlace gpu(0); + auto& pool = platform::DeviceContextPool::Instance(); + auto* cuda_ctx = pool.GetByPlace(gpu); + float* buf = tensor.mutable_data({3}, gpu); + FillInf<<<1, 1, 0, cuda_ctx->stream()>>>(buf); + cuda_ctx->Wait(); + ASSERT_TRUE(HasInf(tensor)); +} + +} // namespace framework +} // namespace paddle From cf9e09b115bae0ad9cbb2ad3594f0f10f30a813b Mon Sep 17 00:00:00 2001 From: tensor-tang Date: Thu, 28 Dec 2017 21:51:47 +0800 Subject: [PATCH 26/33] set openblas env to avoid threads conflicts --- benchmark/paddle/image/run_openblas_train.sh | 1 + 1 file changed, 1 insertion(+) diff --git a/benchmark/paddle/image/run_openblas_train.sh b/benchmark/paddle/image/run_openblas_train.sh index e9df83fee2..d82c8384e0 100755 --- a/benchmark/paddle/image/run_openblas_train.sh +++ 
b/benchmark/paddle/image/run_openblas_train.sh @@ -2,6 +2,7 @@ set -e function train() { unset OMP_NUM_THREADS MKL_NUM_THREADS OMP_DYNAMIC KMP_AFFINITY + export OPENBLAS_NUM_THREADS=1 topology=$1 layer_num=$2 bs=$3 From 33b5382efc8f3e58eda8bae24559f22d6485824c Mon Sep 17 00:00:00 2001 From: tensor-tang Date: Thu, 28 Dec 2017 22:15:59 +0800 Subject: [PATCH 27/33] auto set openblas env --- paddle/scripts/submit_local.sh.in | 23 ++++++++++++++--------- 1 file changed, 14 insertions(+), 9 deletions(-) diff --git a/paddle/scripts/submit_local.sh.in b/paddle/scripts/submit_local.sh.in index a94bc01b35..8a352b0078 100755 --- a/paddle/scripts/submit_local.sh.in +++ b/paddle/scripts/submit_local.sh.in @@ -71,9 +71,7 @@ function threads_config() { # auto set OMP_NUM_THREADS and MKL_NUM_THREADS # according to trainer_count and total processors # only when MKL enabled - if [ "@WITH_MKL@" == "OFF" ]; then - return 0 - fi + # auto set OPENBLAS_NUM_THREADS when do not use MKL processors=`grep "processor" /proc/cpuinfo|sort -u|wc -l` trainers=`grep -Eo 'trainer_count.[0-9]+' <<< "$@" |grep -Eo '[0-9]+'|xargs` if [ -z $trainers ]; then @@ -83,12 +81,19 @@ function threads_config() { if [ $threads -eq 0 ]; then threads=1 fi - if [ -z "$OMP_NUM_THREADS" ]; then - export OMP_NUM_THREADS=$threads - fi - if [ -z "$MKL_NUM_THREADS" ]; then - export MKL_NUM_THREADS=$threads + if [ "@WITH_MKL@" == "ON" ]; then + if [ -z "$OMP_NUM_THREADS" ]; then + export OMP_NUM_THREADS=$threads + fi + if [ -z "$MKL_NUM_THREADS" ]; then + export MKL_NUM_THREADS=$threads + fi + else + if [ -z "$OPENBLAS_NUM_THREADS" ]; then + export OPENBLAS_NUM_THREADS=$threads + fi fi + } PADDLE_CONF_HOME="$HOME/.config/paddle" @@ -150,7 +155,7 @@ fi case "$1" in "train") threads_config $@ - # echo $OMP_NUM_THREADS $MKL_NUM_THREADS + # echo $OMP_NUM_THREADS $MKL_NUM_THREADS $OPENBLAS_NUM_THREADS ${DEBUGGER} $PADDLE_BIN_PATH/paddle_trainer ${@:2} ;; "merge_model") From d630d3921452b3f92dd358caaf03fa7d33942627 Mon Sep 17 
00:00:00 2001 From: tensor-tang Date: Fri, 29 Dec 2017 12:18:23 +0800 Subject: [PATCH 28/33] auto set openblas env when inference and remove unused env for openblas --- benchmark/paddle/image/run_openblas_infer.sh | 16 ++++++++++------ benchmark/paddle/image/run_openblas_train.sh | 1 - 2 files changed, 10 insertions(+), 7 deletions(-) diff --git a/benchmark/paddle/image/run_openblas_infer.sh b/benchmark/paddle/image/run_openblas_infer.sh index da034f3b9d..71a49231a5 100755 --- a/benchmark/paddle/image/run_openblas_infer.sh +++ b/benchmark/paddle/image/run_openblas_infer.sh @@ -8,15 +8,19 @@ function clock_to_seconds() { } function infer() { - unset OMP_NUM_THREADS MKL_NUM_THREADS OMP_DYNAMIC KMP_AFFINITY topology=$1 layer_num=$2 bs=$3 - thread=`nproc` - if [ $thread -gt $bs ]; then - thread=$bs + trainers=`nproc` + if [ $trainers -gt $bs ]; then + trainers=$bs fi - log="logs/infer-${topology}-${layer_num}-${thread}openblas-${bs}.log" + log="logs/infer-${topology}-${layer_num}-${trainers}openblas-${bs}.log" + threads=$((`nproc` / trainers)) + if [ $threads -eq 0 ]; then + threads=1 + fi + export OPENBLAS_NUM_THREADS=$threads models_in="models/${topology}-${layer_num}/pass-00000/" if [ ! 
-d $models_in ]; then @@ -28,7 +32,7 @@ function infer() { --config="${topology}.py" \ --use_mkldnn=False \ --use_gpu=False \ - --trainer_count=$thread \ + --trainer_count=$trainers \ --log_period=$log_period \ --config_args="batch_size=${bs},layer_num=${layer_num},is_infer=True,num_samples=256" \ --init_model_path=$models_in \ diff --git a/benchmark/paddle/image/run_openblas_train.sh b/benchmark/paddle/image/run_openblas_train.sh index d82c8384e0..935cff6f2c 100755 --- a/benchmark/paddle/image/run_openblas_train.sh +++ b/benchmark/paddle/image/run_openblas_train.sh @@ -1,7 +1,6 @@ set -e function train() { - unset OMP_NUM_THREADS MKL_NUM_THREADS OMP_DYNAMIC KMP_AFFINITY export OPENBLAS_NUM_THREADS=1 topology=$1 layer_num=$2 From 5139e6c740f9829234de3cc4ed5a3fcd56e2331c Mon Sep 17 00:00:00 2001 From: Yang Yu Date: Fri, 29 Dec 2017 12:57:57 +0800 Subject: [PATCH 29/33] Follow comments --- paddle/framework/executor.cc | 6 +++--- paddle/framework/tensor_util.h | 4 ++-- paddle/framework/tensor_util_test.cc | 4 ++-- 3 files changed, 7 insertions(+), 7 deletions(-) diff --git a/paddle/framework/executor.cc b/paddle/framework/executor.cc index de4d3395eb..bf1f0471cc 100644 --- a/paddle/framework/executor.cc +++ b/paddle/framework/executor.cc @@ -59,11 +59,11 @@ static void CreateTensor(Variable* var, proto::VarDesc::VarType var_type) { static void CheckTensorNANOrInf(const std::string& name, const framework::Tensor& tensor) { - if (tensor.type().hash_code() != typeid(float).hash_code() && - tensor.type().hash_code() != typeid(double).hash_code()) { + if (tensor.memory_size() == 0) { return; } - if (tensor.memory_size() == 0) { + if (tensor.type().hash_code() != typeid(float).hash_code() && + tensor.type().hash_code() != typeid(double).hash_code()) { return; } PADDLE_ENFORCE(!framework::HasInf(tensor), "Tensor %s has Inf", name); diff --git a/paddle/framework/tensor_util.h b/paddle/framework/tensor_util.h index a86fab2925..6a21f8db1e 100644 --- 
a/paddle/framework/tensor_util.h +++ b/paddle/framework/tensor_util.h @@ -210,10 +210,10 @@ inline void CopyToVector(const Tensor& src, std::vector* dst) { } // Returns true if a tensor contains NAN, i.e., Not A Number. -extern bool HasNAN(const framework::Tensor& tensor); +bool HasNAN(const framework::Tensor& tensor); // Returns true if a tensor contains Inf, i.e., Infinity. -extern bool HasInf(const framework::Tensor& tensor); +bool HasInf(const framework::Tensor& tensor); inline void SerializeToStream(std::ostream& os, const Tensor& tensor, const platform::DeviceContext& dev_ctx) { diff --git a/paddle/framework/tensor_util_test.cc b/paddle/framework/tensor_util_test.cc index f00ce79548..0dc5166fca 100644 --- a/paddle/framework/tensor_util_test.cc +++ b/paddle/framework/tensor_util_test.cc @@ -231,7 +231,7 @@ TEST(CopyToVector, Tensor) { #endif } -TEST(IsNAN, CPU) { +TEST(HasNAN, CPU) { using namespace paddle::framework; using namespace paddle::platform; Tensor src; @@ -243,7 +243,7 @@ TEST(IsNAN, CPU) { ASSERT_TRUE(HasNAN(src)); } -TEST(IsInf, CPU) { +TEST(HasInf, CPU) { using namespace paddle::framework; using namespace paddle::platform; Tensor src; From e188f0c16041f560ce7efe3a763a9dc164a06f28 Mon Sep 17 00:00:00 2001 From: Luo Tao Date: Fri, 29 Dec 2017 15:36:43 +0800 Subject: [PATCH 30/33] add paddle version of pip install --- doc/getstarted/build_and_install/pip_install_cn.rst | 4 ++-- doc/getstarted/build_and_install/pip_install_en.rst | 4 ++-- doc/getstarted/index_cn.rst | 4 ++-- doc/getstarted/index_en.rst | 4 ++-- 4 files changed, 8 insertions(+), 8 deletions(-) diff --git a/doc/getstarted/build_and_install/pip_install_cn.rst b/doc/getstarted/build_and_install/pip_install_cn.rst index a4587f82a9..0c741e936b 100644 --- a/doc/getstarted/build_and_install/pip_install_cn.rst +++ b/doc/getstarted/build_and_install/pip_install_cn.rst @@ -11,14 +11,14 @@ PaddlePaddle可以使用常用的Python包管理工具 ------------------------------ 
-执行下面的命令即可在当前机器上安装PaddlePaddle的运行时环境,并自动下载安装依赖软件。 +执行下面的命令即可在当前机器上安装PaddlePaddle的运行时环境,并自动下载安装依赖软件,版本为cpu_avx_openblas。 .. code-block:: bash pip install paddlepaddle -如果需要安装支持GPU的版本,需要执行: +如果需要安装支持GPU的版本(cuda7.5_cudnn5_avx_openblas),需要执行: .. code-block:: bash diff --git a/doc/getstarted/build_and_install/pip_install_en.rst b/doc/getstarted/build_and_install/pip_install_en.rst index 55e31560a0..285ed09805 100644 --- a/doc/getstarted/build_and_install/pip_install_en.rst +++ b/doc/getstarted/build_and_install/pip_install_en.rst @@ -12,14 +12,14 @@ Install Using pip ------------------------------ Run the following command to install PaddlePaddle on the current -machine, it will also download requirements. +machine, it will also download requirements, the version is cpu_avx_openblas. .. code-block:: bash pip install paddlepaddle -If you wish to install GPU version, just run: +If you wish to install GPU version (cuda7.5_cudnn5_avx_openblas), just run: .. code-block:: bash diff --git a/doc/getstarted/index_cn.rst b/doc/getstarted/index_cn.rst index a9087be6f3..9f6ee25987 100644 --- a/doc/getstarted/index_cn.rst +++ b/doc/getstarted/index_cn.rst @@ -7,13 +7,13 @@ ++++++++ PaddlePaddle支持使用pip快速安装,目前支持CentOS 6以上, Ubuntu 14.04以及MacOS 10.12,并安装有Python2.7。 -执行下面的命令完成快速安装: +执行下面的命令完成快速安装,版本为cpu_avx_openblas: .. code-block:: bash pip install paddlepaddle -如果需要安装支持GPU的版本,需要执行: +如果需要安装支持GPU的版本(cuda7.5_cudnn5_avx_openblas),需要执行: .. code-block:: bash diff --git a/doc/getstarted/index_en.rst b/doc/getstarted/index_en.rst index d14e3f5c0c..063d9d880c 100644 --- a/doc/getstarted/index_en.rst +++ b/doc/getstarted/index_en.rst @@ -8,13 +8,13 @@ Quick Install You can use pip to install PaddlePaddle with a single command, supports CentOS 6 above, Ubuntu 14.04 above or MacOS 10.12, with Python 2.7 installed. -Simply run the following command to install: +Simply run the following command to install, the version is cpu_avx_openblas: .. 
code-block:: bash pip install paddlepaddle -If you need to install GPU version, run: +If you need to install GPU version (cuda7.5_cudnn5_avx_openblas), run: .. code-block:: bash From c144261d40ab7c5d24e29c03155310a53d79909e Mon Sep 17 00:00:00 2001 From: Luo Tao Date: Fri, 29 Dec 2017 16:34:08 +0800 Subject: [PATCH 31/33] add paddle version of docker --- doc/getstarted/build_and_install/docker_install_cn.rst | 8 ++++---- doc/getstarted/build_and_install/docker_install_en.rst | 6 +++--- 2 files changed, 7 insertions(+), 7 deletions(-) diff --git a/doc/getstarted/build_and_install/docker_install_cn.rst b/doc/getstarted/build_and_install/docker_install_cn.rst index fa1b6a3727..bae42593dd 100644 --- a/doc/getstarted/build_and_install/docker_install_cn.rst +++ b/doc/getstarted/build_and_install/docker_install_cn.rst @@ -15,7 +15,7 @@ 获取PaddlePaddle的Docker镜像 ------------------------------ -执行下面的命令获取最新的PaddlePaddle Docker镜像 +执行下面的命令获取最新的PaddlePaddle Docker镜像,版本为cpu_avx_mkl: .. code-block:: bash @@ -27,7 +27,7 @@ docker pull docker.paddlepaddle.org/paddle -下载GPU版本的Docker镜像: +下载GPU版本(cuda8.0_cudnn5_avx_mkl)的Docker镜像: .. code-block:: bash @@ -54,7 +54,7 @@ .. _docker_run: 在Docker中执行PaddlePaddle训练程序 ------------------------------- +---------------------------------- 假设您已经在当前目录(比如在/home/work)编写了一个PaddlePaddle的程序 :code:`train.py` (可以参考 `PaddlePaddleBook `_ @@ -82,7 +82,7 @@ .. _docker_run_book: 使用Docker启动PaddlePaddle Book教程 ------------------------------- +----------------------------------- 使用Docker可以快速在本地启动一个包含了PaddlePaddle官方Book教程的Jupyter Notebook,可以通过网页浏览。 PaddlePaddle Book是为用户和开发者制作的一个交互式的Jupyter Notebook。 diff --git a/doc/getstarted/build_and_install/docker_install_en.rst b/doc/getstarted/build_and_install/docker_install_en.rst index 06012bf65e..56a7c68e4d 100644 --- a/doc/getstarted/build_and_install/docker_install_en.rst +++ b/doc/getstarted/build_and_install/docker_install_en.rst @@ -16,7 +16,7 @@ After you've read above tutorials you may proceed the following steps. 
Pull PaddlePaddle Docker Image ------------------------------ -Run the following command to download the latest Docker images: +Run the following command to download the latest Docker images, the version is cpu_avx_mkl: .. code-block:: bash @@ -28,7 +28,7 @@ For users in China, we provide a faster mirror: docker pull docker.paddlepaddle.org/paddle -Download GPU version images: +Download GPU version (cuda8.0_cudnn5_avx_mkl) images: .. code-block:: bash @@ -58,7 +58,7 @@ and run: .. _docker_run: Launch your training program in Docker ------------------------------- +-------------------------------------- Assume that you have already written a PaddlePaddle program named :code:`train.py` under directory :code:`/home/work` (refer to From 5036cf03872a1a1b68cd974e21193ae82f5da071 Mon Sep 17 00:00:00 2001 From: QI JUN Date: Fri, 29 Dec 2017 16:43:10 +0800 Subject: [PATCH 32/33] add helper function to get appropriate DeviceContext (#7066) * add helper function to get appropriate DeviceContext --- paddle/framework/data_transform.h | 5 ++-- paddle/framework/data_transform_test.cc | 15 ++++++------ paddle/framework/operator.cc | 32 ++++++++++++++++++------- 3 files changed, 33 insertions(+), 19 deletions(-) diff --git a/paddle/framework/data_transform.h b/paddle/framework/data_transform.h index 2191dd3783..bd6d301c12 100644 --- a/paddle/framework/data_transform.h +++ b/paddle/framework/data_transform.h @@ -27,9 +27,8 @@ limitations under the License. 
*/ namespace paddle { namespace framework { -using DataTransformFn = - std::function ctx, - const Variable& in, Variable* out)>; +using DataTransformFn = std::function; using KernelTypePair = std::pair; struct KernelTypePairHash { diff --git a/paddle/framework/data_transform_test.cc b/paddle/framework/data_transform_test.cc index 4e2141ecd2..5f05e881fa 100644 --- a/paddle/framework/data_transform_test.cc +++ b/paddle/framework/data_transform_test.cc @@ -54,18 +54,18 @@ auto kernel1 = GenFromBit({0, 0, 0, 1}); auto kernel2 = GenFromBit({0, 0, 1, 0}); auto kernel3 = GenFromBit({0, 0, 1, 1}); -void TransDataType_t(std::vector ctx, - const Variable& in, Variable* out) { +void TransDataType_t(const platform::DeviceContext* ctx, const Variable& in, + Variable* out) { test_value++; } -void TransDataLayout_t(std::vector ctx, - const Variable& in, Variable* out) { +void TransDataLayout_t(const platform::DeviceContext* ctx, const Variable& in, + Variable* out) { test_value--; } -void TransLibraryType_t(std::vector ctx, - const Variable& in, Variable* out) { +void TransLibraryType_t(const platform::DeviceContext* ctx, const Variable& in, + Variable* out) { test_value += 2; } @@ -83,7 +83,8 @@ TEST(DataTransform, Register) { using namespace paddle::platform; auto& instance = DataTransformFnMap::Instance(); - std::vector ctx; + ASSERT_EQ(instance.Map().size(), 3UL); + DeviceContext* ctx = nullptr; paddle::framework::Variable in; paddle::framework::Variable out; diff --git a/paddle/framework/operator.cc b/paddle/framework/operator.cc index c0be11294c..a3ce96c409 100644 --- a/paddle/framework/operator.cc +++ b/paddle/framework/operator.cc @@ -384,6 +384,24 @@ class RuntimeInferShapeContext : public InferShapeContext { const Scope& scope_; }; +const platform::DeviceContext* GetDeviceContext( + framework::KernelTypePair& kernel_pair) { + auto& actual_kernel_key = kernel_pair.first; + auto& expected_kernel_key = kernel_pair.second; + platform::DeviceContextPool& pool = 
platform::DeviceContextPool::Instance(); + + if (platform::is_gpu_place(actual_kernel_key.place_) && + platform::is_cpu_place(expected_kernel_key.place_)) { + return pool.Get(actual_kernel_key.place_); + } else if (platform::is_cpu_place(actual_kernel_key.place_) && + platform::is_gpu_place(expected_kernel_key.place_)) { + return pool.Get(expected_kernel_key.place_); + } else { + PADDLE_THROW( + "Currently, model parallelism is only supported between CPU and CUDA"); + } +} + void OperatorWithKernel::Run(const Scope& scope, const platform::Place& place) const { RuntimeInferShapeContext infer_shape_ctx(*this, scope); @@ -418,9 +436,9 @@ void OperatorWithKernel::Run(const Scope& scope, "CPU and other devices. For example, multi-GPU model " "parallelism will failed."); } else { + auto kernel_pair = std::make_pair(actual_kernel_key, expected_kernel_key); const DataTransformFn* trans_fun = - DataTransformFnMap::Instance().GetNullable( - std::make_pair(actual_kernel_key, expected_kernel_key)); + DataTransformFnMap::Instance().GetNullable(kernel_pair); if (trans_fun) { auto input_vars = this->InputVars(); // TODO(qijun) filter the input vars that do not need to be transformed @@ -437,22 +455,18 @@ void OperatorWithKernel::Run(const Scope& scope, } if (!need_trans.empty()) { - // TODO(qijun) get appropriate DeviceContext from DeviceContext pool - platform::DeviceContext* trans_dev_ctx = nullptr; - std::vector trans_dev_ctx_vec{trans_dev_ctx}; + auto trans_dev_ctx = GetDeviceContext(kernel_pair); // Wait for transform starting dev_ctx->Wait(); for (auto var_name : need_trans) { - (*trans_fun)(trans_dev_ctx_vec, *(scope.FindVar(var_name)), + (*trans_fun)(trans_dev_ctx, *(scope.FindVar(var_name)), scope.FindVar(var_name + framework::KernelTypeToString( expected_kernel_key))); } // Wait for data transform finishing - for (auto ctx : trans_dev_ctx_vec) { - ctx->Wait(); - } + trans_dev_ctx->Wait(); } } } From d14ca1c39f16b3744cd42e27d86a21a1f5020e37 Mon Sep 17 00:00:00 2001 From: 
tensor-tang Date: Fri, 29 Dec 2017 17:30:28 +0800 Subject: [PATCH 33/33] fix inference crash of alexnet benchmark --- benchmark/paddle/image/alexnet.py | 6 +++++- 1 file changed, 5 insertions(+), 1 deletion(-) diff --git a/benchmark/paddle/image/alexnet.py b/benchmark/paddle/image/alexnet.py index 77d130ae34..cad6051f14 100644 --- a/benchmark/paddle/image/alexnet.py +++ b/benchmark/paddle/image/alexnet.py @@ -19,7 +19,11 @@ args = { 'num_samples': num_samples } define_py_data_sources2( - "train.list", None, module="provider", obj="process", args=args) + "train.list" if not is_infer else None, + "test.list" if is_infer else None, + module="provider", + obj="process", + args=args) settings( batch_size=batch_size,