From 43cee33a23cb1dce5501f0642a38f97dad8cea45 Mon Sep 17 00:00:00 2001 From: tensor-tang Date: Thu, 2 Aug 2018 15:38:44 +0800 Subject: [PATCH 01/94] add mkl packed gemm --- paddle/fluid/operators/math/blas.h | 37 +++++++++++++ paddle/fluid/operators/math/blas_impl.h | 73 +++++++++++++++++++++++++ paddle/fluid/platform/dynload/mklml.h | 8 +++ 3 files changed, 118 insertions(+) diff --git a/paddle/fluid/operators/math/blas.h b/paddle/fluid/operators/math/blas.h index 70f88f24f6..2470df9d78 100644 --- a/paddle/fluid/operators/math/blas.h +++ b/paddle/fluid/operators/math/blas.h @@ -90,6 +90,23 @@ class Blas { void GEMM(bool transA, bool transB, int M, int N, int K, T alpha, const T* A, int lda, const T* B, int ldb, T beta, T* C, int ldc) const; + template + T* GEMM_ALLOC(const CBLAS_IDENTIFIER id, const int M, const int N, + const int K) const; + + template + void GEMM_PACK(const CBLAS_IDENTIFIER id, const CBLAS_TRANSPOSE trans, int M, + int N, int K, const T alpha, const T* src, const int ld, + T* dst) const; + + template + void GEMM_COMPUTE(int transA, int transB, int M, int N, int K, const T* A, + const int lda, const T* B, const int ldb, T beta, T* C, + const int ldc) const; + + template + void GEMM_FREE(T* data) const; + template void MatMul(const framework::Tensor& mat_a, bool trans_a, const framework::Tensor& mat_b, bool trans_b, T alpha, @@ -146,6 +163,26 @@ class BlasT : private Blas { Base()->template GEMM(args...); } + template + T* GEMM_ALLOC(ARGS... args) const { + Base()->template GEMM_ALLOC(args...); + } + + template + void GEMM_PACK(ARGS... args) const { + Base()->template GEMM_PACK(args...); + } + + template + void GEMM_COMPUTE(ARGS... args) const { + Base()->template GEMM_COMPUTE(args...); + } + + template + void GEMM_FREE(ARGS... args) const { + Base()->template GEMM_FREE(args...); + } + template void MatMul(ARGS... args) const { Base()->template MatMul(args...); diff --git a/paddle/fluid/operators/math/blas_impl.h b/paddle/fluid/operators/math/blas_impl.h index a0802ef90c..4164fe6229 100644 --- a/paddle/fluid/operators/math/blas_impl.h +++ b/paddle/fluid/operators/math/blas_impl.h @@ -31,6 +31,26 @@ struct CBlas { platform::dynload::cblas_sgemm(args...); } + template + static float *GEMM_ALLOC(ARGS... args) { + return platform::dynload::cblas_sgemm_alloc(args...); + } + + template + static void GEMM_PACK(ARGS... args) { + platform::dynload::cblas_sgemm_pack(args...); + } + + template + static void GEMM_COMPUTE(ARGS... args) { + platform::dynload::cblas_sgemm_compute(args...); + } + + template + static void GEMM_FREE(ARGS... args) { + platform::dynload::cblas_sgemm_free(args...); + } + #ifdef PADDLE_WITH_LIBXSMM template static void SMM_GEMM(ARGS... args) { @@ -71,6 +91,26 @@ struct CBlas { platform::dynload::cblas_dgemm(args...); } + template + static double *GEMM_ALLOC(ARGS... args) { + return platform::dynload::cblas_dgemm_alloc(args...); + } + + template + static void GEMM_PACK(ARGS... args) { + platform::dynload::cblas_dgemm_pack(args...); + } + + template + static void GEMM_COMPUTE(ARGS... args) { + platform::dynload::cblas_dgemm_compute(args...); + } + + template + static void GEMM_FREE(ARGS... args) { + platform::dynload::cblas_dgemm_free(args...); + } + #ifdef PADDLE_WITH_LIBXSMM template static void SMM_GEMM(ARGS... 
args) { @@ -224,6 +264,39 @@ inline void GEMM_WARP(CBLAS_ORDER order, CBLAS_TRANSPOSE transA, beta, C, ldc); } +template <> +template +T *Blas::GEMM_ALLOC(const CBLAS_IDENTIFIER id, + const int M, const int N, + const int K) const { + return CBlas::GEMM_ALLOC(id, M, N, K); +} + +template <> +template +void Blas::GEMM_PACK(const CBLAS_IDENTIFIER id, + const CBLAS_TRANSPOSE trans, + int M, int N, int K, + const T alpha, const T *src, + const int ld, T *dst) const { + CBlas::GEMM_PACK(CblasRowMajor, id, trans, M, N, K, alpha, src, ld, dst); +} + +template <> +template +void Blas::GEMM_COMPUTE( + int transA, int transB, int M, int N, int K, const T *A, const int lda, + const T *B, const int ldb, T beta, T *C, const int ldc) const { + CBlas::GEMM_COMPUTE(CblasRowMajor, transA, transB, M, N, K, A, lda, B, ldb, + beta, C, ldc); +} + +template <> +template +void Blas::GEMM_FREE(T *data) const { + CBlas::GEMM_FREE(data); +} + template <> template void Blas::GEMM(CBLAS_TRANSPOSE transA, diff --git a/paddle/fluid/platform/dynload/mklml.h b/paddle/fluid/platform/dynload/mklml.h index 17acefe8cd..9e7a616094 100644 --- a/paddle/fluid/platform/dynload/mklml.h +++ b/paddle/fluid/platform/dynload/mklml.h @@ -60,6 +60,14 @@ extern void* mklml_dso_handle; __macro(cblas_dgemm_batch); \ __macro(vsAdd); \ __macro(vdAdd); \ + __macro(cblas_sgemm_alloc); \ + __macro(cblas_sgemm_pack); \ + __macro(cblas_sgemm_compute); \ + __macro(cblas_sgemm_free); \ + __macro(cblas_dgemm_alloc); \ + __macro(cblas_dgemm_pack); \ + __macro(cblas_dgemm_compute); \ + __macro(cblas_dgemm_free); \ __macro(MKL_Set_Num_Threads) MKLML_ROUTINE_EACH(DECLARE_DYNAMIC_LOAD_MKLML_WRAP); From d9cc6b18662295383f925e12b6a5e0cf5dabd14a Mon Sep 17 00:00:00 2001 From: tensor-tang Date: Fri, 3 Aug 2018 13:31:53 +0800 Subject: [PATCH 02/94] replace gru compute with details --- paddle/fluid/operators/gru_op.h | 28 +++++++++++++++++++++++++--- 1 file changed, 25 insertions(+), 3 deletions(-) diff --git a/paddle/fluid/operators/gru_op.h b/paddle/fluid/operators/gru_op.h index 3b0d93e54b..4e534789ce 100644 --- a/paddle/fluid/operators/gru_op.h +++ b/paddle/fluid/operators/gru_op.h @@ -16,7 +16,10 @@ limitations under the License. 
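[Editorial note] The GEMM_ALLOC/GEMM_PACK/GEMM_COMPUTE/GEMM_FREE wrappers and the dynload entries added above expose Intel MKL's packed GEMM: one operand is reordered into MKL's internal layout once and then reused across many multiplications. A minimal standalone sketch of that lifecycle in terms of the raw cblas_sgemm_* routines being loaded (sizes here are illustrative, and it assumes an MKL recent enough to ship these symbols, roughly 11.3+):

    #include <mkl.h>
    #include <vector>

    int main() {
      const int M = 4, N = 8, K = 16;  // C (MxN) = A (MxK) * B (KxN)
      std::vector<float> A(M * K, 1.0f), B(K * N, 1.0f), C(M * N, 0.0f);

      // MKL allocates internal storage sized for the packed form of B.
      float* packed_B = cblas_sgemm_alloc(CblasBMatrix, M, N, K);
      // Reorder B into that storage; alpha is folded in at pack time.
      cblas_sgemm_pack(CblasRowMajor, CblasBMatrix, CblasNoTrans, M, N, K, 1.0f,
                       B.data(), N, packed_B);
      // C = A * packed(B) + 0 * C; CblasPacked marks the pre-packed operand.
      cblas_sgemm_compute(CblasRowMajor, CblasNoTrans, CblasPacked, M, N, K,
                          A.data(), K, packed_B, N, 0.0f, C.data(), N);
      cblas_sgemm_free(packed_B);
      return 0;
    }

The payoff comes when a single packed_B feeds many compute calls, which is exactly the weight-reuse pattern the GRU changes in the following commits are after.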
*/ #include #include "paddle/fluid/framework/eigen.h" #include "paddle/fluid/framework/op_registry.h" +#include "paddle/fluid/operators/math/blas.h" #include "paddle/fluid/operators/math/detail/activation_functions.h" +#include "paddle/fluid/operators/math/detail/gru_cpu_kernel.h" +#include "paddle/fluid/operators/math/detail/gru_kernel.h" #include "paddle/fluid/operators/math/gru_compute.h" #include "paddle/fluid/operators/math/math_function.h" #include "paddle/fluid/operators/math/sequence2batch.h" @@ -94,6 +97,7 @@ class GRUKernel : public framework::OpKernel { context.Attr("activation")); auto active_gate = math::detail::GetActivationType( context.Attr("gate_activation")); + auto blas = math::GetBlas(dev_ctx); for (size_t n = 0; n < num_batch; n++) { int bstart = static_cast(batch_starts[n]); int bend = static_cast(batch_starts[n + 1]); @@ -105,9 +109,27 @@ class GRUKernel : public framework::OpKernel { gru_value.output_value = hidden_t.data(); gru_value.gate_value = gate_t.data(); gru_value.reset_output_value = reset_hidden_prev_t.data(); - math::GRUUnitFunctor::compute( - dev_ctx, gru_value, frame_size, cur_batch_size, active_node, - active_gate); + if (gru_value.prev_out_value) { + blas.GEMM(false, false, cur_batch_size, frame_size * 2, frame_size, 1, + gru_value.prev_out_value, frame_size, gru_value.gate_weight, + frame_size * 2, 1, gru_value.gate_value, frame_size * 3); + } + + math::detail::forward_reset_output( + math::detail::forward::gru_resetOutput(), gru_value, frame_size, + cur_batch_size, active_gate); + + if (gru_value.prev_out_value) { + blas.GEMM(false, false, cur_batch_size, frame_size, frame_size, 1, + gru_value.reset_output_value, frame_size, + gru_value.state_weight, frame_size, 1, + gru_value.gate_value + frame_size * 2, frame_size * 3); + } + + math::detail::forward_final_output( + math::detail::forward::gru_finalOutput(), gru_value, frame_size, + cur_batch_size, active_node); + gru_value.prev_out_value = gru_value.output_value; } From 8c23f7c4f029ba3b22481ae27b721b7a4ac18e8b Mon Sep 17 00:00:00 2001 From: tensor-tang Date: Fri, 3 Aug 2018 18:44:36 +0800 Subject: [PATCH 03/94] fix blas and use packed weight --- paddle/fluid/operators/gru_op.h | 34 ++++++++++++++++++++++++------ paddle/fluid/operators/math/blas.h | 2 +- 2 files changed, 28 insertions(+), 8 deletions(-) diff --git a/paddle/fluid/operators/gru_op.h b/paddle/fluid/operators/gru_op.h index 4e534789ce..a9450337e7 100644 --- a/paddle/fluid/operators/gru_op.h +++ b/paddle/fluid/operators/gru_op.h @@ -98,6 +98,23 @@ class GRUKernel : public framework::OpKernel { auto active_gate = math::detail::GetActivationType( context.Attr("gate_activation")); auto blas = math::GetBlas(dev_ctx); + + // TODO(TJ): make a class, make one pack + T* packed_gate = blas.GEMM_ALLOC(CblasBMatrix, 1 /*height of C*/, + frame_size * 2 /*width of weight*/, + frame_size /*height of height*/); + PADDLE_ENFORCE(packed_gate); + blas.GEMM_PACK(CblasBMatrix, CblasNoTrans, 1 /*cur bs?*/, frame_size * 2, + frame_size, T(1.0), gru_value.gate_weight, frame_size * 2, + packed_gate); + T* packed_state = blas.GEMM_ALLOC(CblasBMatrix, 1 /*height of C*/, + frame_size /*width of weight*/, + frame_size /*height of height*/); + PADDLE_ENFORCE(packed_state); + blas.GEMM_PACK(CblasBMatrix, CblasNoTrans, 1 /*cur bs?*/, frame_size, + frame_size, T(1.0), gru_value.state_weight, frame_size, + packed_state); + for (size_t n = 0; n < num_batch; n++) { int bstart = static_cast(batch_starts[n]); int bend = static_cast(batch_starts[n + 1]); @@ -110,9 +127,10 @@ 
class GRUKernel : public framework::OpKernel { gru_value.gate_value = gate_t.data(); gru_value.reset_output_value = reset_hidden_prev_t.data(); if (gru_value.prev_out_value) { - blas.GEMM(false, false, cur_batch_size, frame_size * 2, frame_size, 1, - gru_value.prev_out_value, frame_size, gru_value.gate_weight, - frame_size * 2, 1, gru_value.gate_value, frame_size * 3); + blas.GEMM_COMPUTE(CblasNoTrans, CblasPacked, cur_batch_size, + frame_size * 2, frame_size, gru_value.prev_out_value, + frame_size, packed_gate, frame_size * 2, T(1), + gru_value.gate_value, frame_size * 3); } math::detail::forward_reset_output( @@ -120,10 +138,10 @@ class GRUKernel : public framework::OpKernel { cur_batch_size, active_gate); if (gru_value.prev_out_value) { - blas.GEMM(false, false, cur_batch_size, frame_size, frame_size, 1, - gru_value.reset_output_value, frame_size, - gru_value.state_weight, frame_size, 1, - gru_value.gate_value + frame_size * 2, frame_size * 3); + blas.GEMM_COMPUTE( + CblasNoTrans, CblasPacked, cur_batch_size, frame_size, frame_size, + gru_value.reset_output_value, frame_size, packed_state, frame_size, + T(1), gru_value.gate_value + frame_size * 2, frame_size * 3); } math::detail::forward_final_output( @@ -132,6 +150,8 @@ class GRUKernel : public framework::OpKernel { gru_value.prev_out_value = gru_value.output_value; } + blas.GEMM_FREE(packed_gate); + blas.GEMM_FREE(packed_state); math::Batch2LoDTensorFunctor to_seq; batch_hidden->set_lod(batch_gate->lod()); diff --git a/paddle/fluid/operators/math/blas.h b/paddle/fluid/operators/math/blas.h index 2470df9d78..485e96227e 100644 --- a/paddle/fluid/operators/math/blas.h +++ b/paddle/fluid/operators/math/blas.h @@ -165,7 +165,7 @@ class BlasT : private Blas { template T* GEMM_ALLOC(ARGS... args) const { - Base()->template GEMM_ALLOC(args...); + return Base()->template GEMM_ALLOC(args...); } template From e0ab2f71589a71e918a94dd307d18f9a54864199 Mon Sep 17 00:00:00 2001 From: tangwei12 Date: Sun, 5 Aug 2018 15:21:32 +0800 Subject: [PATCH 04/94] new sampling op --- paddle/fluid/operators/sampling_id_op.cc | 64 ++++++++++++++++++++++ paddle/fluid/operators/sampling_id_op.cu | 40 ++++++++++++++ paddle/fluid/operators/sampling_id_op.h | 68 ++++++++++++++++++++++++ 3 files changed, 172 insertions(+) create mode 100644 paddle/fluid/operators/sampling_id_op.cc create mode 100644 paddle/fluid/operators/sampling_id_op.cu create mode 100644 paddle/fluid/operators/sampling_id_op.h diff --git a/paddle/fluid/operators/sampling_id_op.cc b/paddle/fluid/operators/sampling_id_op.cc new file mode 100644 index 0000000000..20e3d43217 --- /dev/null +++ b/paddle/fluid/operators/sampling_id_op.cc @@ -0,0 +1,64 @@ +/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
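[Editorial note] For reference, the two GEMMs that patches 02 and 03 split out of GRUUnitFunctor::compute correspond to the standard GRU recurrences. Activation and interpolation conventions vary between implementations; this is one common form, with \odot the elementwise product and x_{*,t} the input projections already accumulated in gate_value:

    u_t = \sigma_g\left(x_{u,t} + h_{t-1} W_u\right), \qquad
    r_t = \sigma_g\left(x_{r,t} + h_{t-1} W_r\right)

    \tilde{h}_t = \sigma_c\left(x_{c,t} + (r_t \odot h_{t-1}) W_c\right), \qquad
    h_t = u_t \odot h_{t-1} + (1 - u_t) \odot \tilde{h}_t

The first GEMM multiplies h_{t-1} by the concatenated [W_u, W_r] (hence the frame_size * 2 width and the packed_gate buffer), forward_reset_output applies the gate activation and writes r_t \odot h_{t-1} into reset_output_value, the second GEMM multiplies that by W_c (the frame_size-wide packed_state) into the last third of gate_value, and forward_final_output forms h_t.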
*/ + +#include "paddle/fluid/operators/sampling_id_op.h" + +namespace paddle { +namespace operators { + +using Tensor = framework::Tensor; + +class SamplingIdOp : public framework::OperatorWithKernel { + public: + using framework::OperatorWithKernel::OperatorWithKernel; + + void InferShape(framework::InferShapeContext *ctx) const override { + PADDLE_ENFORCE(ctx->HasInput("X"), + "Input(X) of RowConvOp should not be null."); + PADDLE_ENFORCE(ctx->HasOutput("Out"), + "Output(Out) of RowConvOp should not be null."); + + auto input_dims = ctx->GetInputDim("X"); + + framework::DDim dims = input_dims; + ctx->SetOutputDim("Out", dims); + ctx->ShareLoD("X", "Out"); + } +}; + +class SamplingIdOpMaker : public framework::OpProtoAndCheckerMaker { + public: + void Make() override { + AddInput("X", + "The input tensor of softmax. " + "2-D with shape [batch_size, input_feature_dimensions]."); + AddOutput("Out", "Sliced data tensor."); + + AddComment(R"DOC( +SamplingId Operator. + @brief A layer for sampling id from multinomial distribution from the + input layer. Sampling one id for one sample. The result is stored in + output_.ids. +)DOC"); + } +}; +} // namespace operators +} // namespace paddle + +namespace ops = paddle::operators; +REGISTER_OP_CUDA_KERNEL( + slice, ops::SamplingIdKernel, + ops::SamplingIdKernel, + ops::SamplingIdKernel, + ops::SamplingIdKernel); diff --git a/paddle/fluid/operators/sampling_id_op.cu b/paddle/fluid/operators/sampling_id_op.cu new file mode 100644 index 0000000000..4fa10de2cd --- /dev/null +++ b/paddle/fluid/operators/sampling_id_op.cu @@ -0,0 +1,40 @@ +/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. */ + +#include +#include +#include "paddle/fluid/operators/sampling_id_op.h" + +namespace paddle { +namespace operators { + +using Tensor = framework::Tensor; + +class SamplingIdOp : public framework::OperatorWithKernel { + public: + using framework::OperatorWithKernel::OperatorWithKernel; + void InferShape(framework::InferShapeContext *ctx) const override {} +} +} // namespace operators +} // namespace paddle + +namespace ops = paddle::operators; +REGISTER_OPERATOR(samplingid, ops::SamplingIdOp, ops::SamplingIdOpMaker, + paddle::framework::EmptyGradOpMaker); + +REGISTER_OP_CPU_KERNEL( + slice, ops::SamplingIdKernel, + ops::SamplingIdKernel, + ops::SamplingIdKernel, + ops::SamplingIdKernel); diff --git a/paddle/fluid/operators/sampling_id_op.h b/paddle/fluid/operators/sampling_id_op.h new file mode 100644 index 0000000000..eeb72d8f7d --- /dev/null +++ b/paddle/fluid/operators/sampling_id_op.h @@ -0,0 +1,68 @@ +/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. 
+You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. */ +#pragma once + +#include +#include +#include "paddle/fluid/framework/op_registry.h" + +namespace paddle { +namespace operators { + +template +class SamplingIdKernel : public framework::OpKernel { + /// Produces random floating-point values, uniformly distributed on [0, 1). + std::uniform_real_distribution rand1_; + + public: + void Compute(const framework::ExecutionContext& context) const override { + const Tensor* input = context.Input("X"); + const int batch_size = static_cast(input->dims()[0]); + const int width = static_cast(input->dims()[1]); + + std::vector ids(batchSize); + auto& reng = get(); + + for (size_t i = 0; i < batchSize; ++i) { + double r = rand1_(reng); + int id = dim - 1; + for (int j = 0; j < dim; ++j) { + if ((r -= buf[i * dim + j]) < 0) { + id = j; + break; + } + } + ids[i] = id; + } + + std::vector out_dim; + out_dim.push_back(static_cast(batch_size)); + + Tensor* output = context.Output("Output"); + output->Resize(framework::make_ddim(in_dim)); + output->mutable_data(context.GetPlace()); + framework::TensorFromVector(ids, context.device_context(), output); + } + + std::default_random_engine& get() { + auto engine = new std::default_random_engine; + engine->seed(defaultSeed); + return *engine; + } + + private: + unsigned int defaultSeed = 0; +} +} // namespace operators +} // namespace paddle From 3206970b770fb3a45d7a7c85566cab6b16db28d7 Mon Sep 17 00:00:00 2001 From: tangwei12 Date: Mon, 6 Aug 2018 10:24:59 +0800 Subject: [PATCH 05/94] sampling op rename --- paddle/fluid/operators/sampling_id_op.cc | 3 ++- paddle/fluid/operators/sampling_id_op.cu | 4 ++-- 2 files changed, 4 insertions(+), 3 deletions(-) diff --git a/paddle/fluid/operators/sampling_id_op.cc b/paddle/fluid/operators/sampling_id_op.cc index 20e3d43217..b9e3b0372d 100644 --- a/paddle/fluid/operators/sampling_id_op.cc +++ b/paddle/fluid/operators/sampling_id_op.cc @@ -58,7 +58,8 @@ SamplingId Operator. 
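[Editorial note] The sampling loop above is inverse-CDF sampling: draw r uniformly from [0, 1), then walk the row of probabilities until the running sum passes r. A minimal self-contained sketch of the same scheme (sample_id is a hypothetical helper name; each row is assumed to already sum to 1, e.g. a softmax output):

    #include <random>
    #include <vector>

    int sample_id(const std::vector<float>& row, std::mt19937& gen) {
      std::uniform_real_distribution<double> dis(0.0, 1.0);
      double r = dis(gen);
      int id = static_cast<int>(row.size()) - 1;  // fallback if rounding leaves r > 0
      for (std::size_t j = 0; j < row.size(); ++j) {
        if ((r -= row[j]) < 0) {  // first index where the cumulative sum exceeds r
          id = static_cast<int>(j);
          break;
        }
      }
      return id;
    }

Note that the kernel as first committed still has mismatched names (batchSize vs. batch_size, dim vs. width, the undefined buf); the follow-up commits below rename these and bring the loop in line with this scheme.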
namespace ops = paddle::operators; REGISTER_OP_CUDA_KERNEL( - slice, ops::SamplingIdKernel, + sampling_id, + ops::SamplingIdKernel, ops::SamplingIdKernel, ops::SamplingIdKernel, ops::SamplingIdKernel); diff --git a/paddle/fluid/operators/sampling_id_op.cu b/paddle/fluid/operators/sampling_id_op.cu index 4fa10de2cd..f82ba68ce4 100644 --- a/paddle/fluid/operators/sampling_id_op.cu +++ b/paddle/fluid/operators/sampling_id_op.cu @@ -30,11 +30,11 @@ class SamplingIdOp : public framework::OperatorWithKernel { } // namespace paddle namespace ops = paddle::operators; -REGISTER_OPERATOR(samplingid, ops::SamplingIdOp, ops::SamplingIdOpMaker, +REGISTER_OPERATOR(sampling_id, ops::SamplingIdOp, ops::SamplingIdOpMaker, paddle::framework::EmptyGradOpMaker); REGISTER_OP_CPU_KERNEL( - slice, ops::SamplingIdKernel, + sampling_id, ops::SamplingIdKernel, ops::SamplingIdKernel, ops::SamplingIdKernel, ops::SamplingIdKernel); From 54c95e49f09e70233adb363b5b612cb8d427c116 Mon Sep 17 00:00:00 2001 From: tensor-tang Date: Mon, 6 Aug 2018 11:34:11 +0800 Subject: [PATCH 06/94] fix blas --- paddle/fluid/operators/math/blas.h | 4 ++++ paddle/fluid/operators/math/blas_impl.h | 2 ++ 2 files changed, 6 insertions(+) diff --git a/paddle/fluid/operators/math/blas.h b/paddle/fluid/operators/math/blas.h index 485e96227e..2558154e0b 100644 --- a/paddle/fluid/operators/math/blas.h +++ b/paddle/fluid/operators/math/blas.h @@ -90,6 +90,7 @@ class Blas { void GEMM(bool transA, bool transB, int M, int N, int K, T alpha, const T* A, int lda, const T* B, int ldb, T beta, T* C, int ldc) const; +#ifdef PADDLE_WITH_MKLML template T* GEMM_ALLOC(const CBLAS_IDENTIFIER id, const int M, const int N, const int K) const; @@ -106,6 +107,7 @@ class Blas { template void GEMM_FREE(T* data) const; +#endif template void MatMul(const framework::Tensor& mat_a, bool trans_a, @@ -163,6 +165,7 @@ class BlasT : private Blas { Base()->template GEMM(args...); } +#ifdef PADDLE_WITH_MKLML template T* GEMM_ALLOC(ARGS... args) const { return Base()->template GEMM_ALLOC(args...); @@ -182,6 +185,7 @@ class BlasT : private Blas { void GEMM_FREE(ARGS... args) const { Base()->template GEMM_FREE(args...); } +#endif template void MatMul(ARGS... 
args) const { diff --git a/paddle/fluid/operators/math/blas_impl.h b/paddle/fluid/operators/math/blas_impl.h index 4164fe6229..bf33821079 100644 --- a/paddle/fluid/operators/math/blas_impl.h +++ b/paddle/fluid/operators/math/blas_impl.h @@ -264,6 +264,7 @@ inline void GEMM_WARP(CBLAS_ORDER order, CBLAS_TRANSPOSE transA, beta, C, ldc); } +#ifdef PADDLE_WITH_MKLML template <> template T *Blas::GEMM_ALLOC(const CBLAS_IDENTIFIER id, @@ -296,6 +297,7 @@ template void Blas::GEMM_FREE(T *data) const { CBlas::GEMM_FREE(data); } +#endif template <> template From 18c322c2a1133bcc6350aea1b148bb6d767e6933 Mon Sep 17 00:00:00 2001 From: tensor-tang Date: Mon, 6 Aug 2018 11:38:59 +0800 Subject: [PATCH 07/94] separate cpu and gpu implementations for gru kernel compute --- paddle/fluid/operators/gru_op.cc | 138 +++++++++++++++++++++- paddle/fluid/operators/gru_op.cu.cc | 90 ++++++++++++++++ paddle/fluid/operators/gru_op.h | 123 ------------------- 3 files changed, 225 insertions(+), 126 deletions(-) diff --git a/paddle/fluid/operators/gru_op.cc b/paddle/fluid/operators/gru_op.cc index 5c74687882..4847eb3626 100644 --- a/paddle/fluid/operators/gru_op.cc +++ b/paddle/fluid/operators/gru_op.cc @@ -211,6 +211,139 @@ class GRUGradOp : public framework::OperatorWithKernel { } }; +template +class GRUCPUKernel : public framework::OpKernel { + public: + void BatchCompute(const framework::ExecutionContext& context) const { + using DeviceContext = paddle::platform::CPUDeviceContext; + auto* input = context.Input("Input"); + auto* h0 = context.Input("H0"); + auto* weight = context.Input("Weight"); + const T* weight_data = weight->data(); + auto* bias = context.Input("Bias"); + auto* batch_gate = context.Output("BatchGate"); + batch_gate->mutable_data(context.GetPlace()); + auto* batch_reset_hidden_prev = + context.Output("BatchResetHiddenPrev"); + batch_reset_hidden_prev->mutable_data(context.GetPlace()); + auto* batch_hidden = context.Output("BatchHidden"); + batch_hidden->mutable_data(context.GetPlace()); + auto* hidden = context.Output("Hidden"); + hidden->mutable_data(context.GetPlace()); + + auto hidden_dims = hidden->dims(); + + bool is_reverse = context.Attr("is_reverse"); + math::LoDTensor2BatchFunctor to_batch; + auto& dev_ctx = context.template device_context(); + to_batch(dev_ctx, *input, batch_gate, true, is_reverse); + + if (bias) { + math::RowwiseAdd add_bias; + add_bias(dev_ctx, *batch_gate, *bias, batch_gate); + } + + int frame_size = hidden_dims[1]; + math::GRUMetaValue gru_value; + gru_value.gate_weight = const_cast(weight_data); + gru_value.state_weight = + const_cast(weight_data + 2 * frame_size * frame_size); + Tensor ordered_h0; + + framework::Vector order(batch_gate->lod()[2]); + + if (h0) { + // Since the batch computing for GRU reorders the input sequences + // according to their length. The initialized cell state also needs + // to reorder.
+ ReorderInitState( + context.template device_context(), *h0, order, + &ordered_h0, true); + gru_value.prev_out_value = ordered_h0.data(); + } else { + gru_value.prev_out_value = nullptr; + } + auto batch_starts = batch_gate->lod()[0]; + size_t num_batch = batch_starts.size() - 1; + auto active_node = math::detail::GetActivationType( + context.Attr("activation")); + auto active_gate = math::detail::GetActivationType( + context.Attr("gate_activation")); + +#ifdef PADDLE_WITH_MKLML + auto blas = math::GetBlas(dev_ctx); + // TODO(TJ): make a class + T* packed_gate = blas.GEMM_ALLOC(CblasBMatrix, 1 /*height of C*/, + frame_size * 2 /*width of weight*/, + frame_size /*height of height*/); + PADDLE_ENFORCE(packed_gate); + blas.GEMM_PACK(CblasBMatrix, CblasNoTrans, 1 /*cur bs?*/, frame_size * 2, + frame_size, T(1.0), gru_value.gate_weight, frame_size * 2, + packed_gate); + T* packed_state = blas.GEMM_ALLOC(CblasBMatrix, 1 /*height of C*/, + frame_size /*width of weight*/, + frame_size /*height of height*/); + PADDLE_ENFORCE(packed_state); + blas.GEMM_PACK(CblasBMatrix, CblasNoTrans, 1 /*cur bs?*/, frame_size, + frame_size, T(1.0), gru_value.state_weight, frame_size, + packed_state); +#endif + for (size_t n = 0; n < num_batch; n++) { + int bstart = static_cast(batch_starts[n]); + int bend = static_cast(batch_starts[n + 1]); + int cur_batch_size = bend - bstart; + + Tensor gate_t = batch_gate->Slice(bstart, bend); + Tensor reset_hidden_prev_t = batch_reset_hidden_prev->Slice(bstart, bend); + Tensor hidden_t = batch_hidden->Slice(bstart, bend); + gru_value.output_value = hidden_t.data(); + gru_value.gate_value = gate_t.data(); + gru_value.reset_output_value = reset_hidden_prev_t.data(); + +#ifdef PADDLE_WITH_MKLML + if (gru_value.prev_out_value) { + blas.GEMM_COMPUTE(CblasNoTrans, CblasPacked, cur_batch_size, + frame_size * 2, frame_size, gru_value.prev_out_value, + frame_size, packed_gate, frame_size * 2, T(1), + gru_value.gate_value, frame_size * 3); + } + + math::detail::forward_reset_output( + math::detail::forward::gru_resetOutput(), gru_value, frame_size, + cur_batch_size, active_gate); + + if (gru_value.prev_out_value) { + blas.GEMM_COMPUTE( + CblasNoTrans, CblasPacked, cur_batch_size, frame_size, frame_size, + gru_value.reset_output_value, frame_size, packed_state, frame_size, + T(1), gru_value.gate_value + frame_size * 2, frame_size * 3); + } + + math::detail::forward_final_output( + math::detail::forward::gru_finalOutput(), gru_value, frame_size, + cur_batch_size, active_node); +#else + math::GRUUnitFunctor::compute( + dev_ctx, gru_value, frame_size, cur_batch_size, active_node, + active_gate); +#endif + gru_value.prev_out_value = gru_value.output_value; + } +#ifdef PADDLE_WITH_MKLML + blas.GEMM_FREE(packed_gate); + blas.GEMM_FREE(packed_state); +#endif + + math::Batch2LoDTensorFunctor to_seq; + batch_hidden->set_lod(batch_gate->lod()); + to_seq(dev_ctx, *batch_hidden, hidden); + } + + void Compute(const framework::ExecutionContext& context) const override { + BatchCompute(context); + } +}; + } // namespace operators } // namespace paddle @@ -218,9 +351,8 @@ namespace ops = paddle::operators; REGISTER_OPERATOR(gru, ops::GRUOp, ops::GRUOpMaker, paddle::framework::DefaultGradOpDescMaker); REGISTER_OPERATOR(gru_grad, ops::GRUGradOp); -REGISTER_OP_CPU_KERNEL( - gru, ops::GRUKernel, - ops::GRUKernel); +REGISTER_OP_CPU_KERNEL(gru, ops::GRUCPUKernel, + ops::GRUCPUKernel); REGISTER_OP_CPU_KERNEL( gru_grad, ops::GRUGradKernel, ops::GRUGradKernel); diff --git a/paddle/fluid/operators/gru_op.cu.cc 
b/paddle/fluid/operators/gru_op.cu.cc index baf455a840..55721c283d 100644 --- a/paddle/fluid/operators/gru_op.cu.cc +++ b/paddle/fluid/operators/gru_op.cu.cc @@ -14,6 +14,96 @@ limitations under the License. */ #include "paddle/fluid/operators/gru_op.h" +namespace paddle { +namespace operators { + +template +class GRUKernel : public framework::OpKernel { + public: + void BatchCompute(const framework::ExecutionContext& context) const { + auto* input = context.Input("Input"); + auto* h0 = context.Input("H0"); + auto* weight = context.Input("Weight"); + const T* weight_data = weight->data(); + auto* bias = context.Input("Bias"); + auto* batch_gate = context.Output("BatchGate"); + batch_gate->mutable_data(context.GetPlace()); + auto* batch_reset_hidden_prev = + context.Output("BatchResetHiddenPrev"); + batch_reset_hidden_prev->mutable_data(context.GetPlace()); + auto* batch_hidden = context.Output("BatchHidden"); + batch_hidden->mutable_data(context.GetPlace()); + auto* hidden = context.Output("Hidden"); + hidden->mutable_data(context.GetPlace()); + + auto hidden_dims = hidden->dims(); + + bool is_reverse = context.Attr("is_reverse"); + math::LoDTensor2BatchFunctor to_batch; + auto& dev_ctx = context.template device_context(); + to_batch(dev_ctx, *input, batch_gate, true, is_reverse); + + if (bias) { + math::RowwiseAdd add_bias; + add_bias(dev_ctx, *batch_gate, *bias, batch_gate); + } + + int frame_size = hidden_dims[1]; + math::GRUMetaValue gru_value; + gru_value.gate_weight = const_cast(weight_data); + gru_value.state_weight = + const_cast(weight_data + 2 * frame_size * frame_size); + Tensor ordered_h0; + + framework::Vector order(batch_gate->lod()[2]); + + if (h0) { + // Since the batch computing for GRU reorders the input sequences + // according to their length. The initialized cell state also needs + // to reorder. 
+ ReorderInitState( + context.template device_context(), *h0, order, + &ordered_h0, true); + gru_value.prev_out_value = ordered_h0.data(); + } else { + gru_value.prev_out_value = nullptr; + } + auto batch_starts = batch_gate->lod()[0]; + size_t num_batch = batch_starts.size() - 1; + auto active_node = math::detail::GetActivationType( + context.Attr("activation")); + auto active_gate = math::detail::GetActivationType( + context.Attr("gate_activation")); + for (size_t n = 0; n < num_batch; n++) { + int bstart = static_cast(batch_starts[n]); + int bend = static_cast(batch_starts[n + 1]); + int cur_batch_size = bend - bstart; + + Tensor gate_t = batch_gate->Slice(bstart, bend); + Tensor reset_hidden_prev_t = batch_reset_hidden_prev->Slice(bstart, bend); + Tensor hidden_t = batch_hidden->Slice(bstart, bend); + gru_value.output_value = hidden_t.data(); + gru_value.gate_value = gate_t.data(); + gru_value.reset_output_value = reset_hidden_prev_t.data(); + math::GRUUnitFunctor::compute( + dev_ctx, gru_value, frame_size, cur_batch_size, active_node, + active_gate); + gru_value.prev_out_value = gru_value.output_value; + } + + math::Batch2LoDTensorFunctor to_seq; + batch_hidden->set_lod(batch_gate->lod()); + to_seq(dev_ctx, *batch_hidden, hidden); + } + + void Compute(const framework::ExecutionContext& context) const override { + BatchCompute(context); + } +}; + +} // namespace operators +} // namespace paddle + namespace ops = paddle::operators; REGISTER_OP_CUDA_KERNEL( gru, ops::GRUKernel, diff --git a/paddle/fluid/operators/gru_op.h b/paddle/fluid/operators/gru_op.h index a9450337e7..0bf4e6bc44 100644 --- a/paddle/fluid/operators/gru_op.h +++ b/paddle/fluid/operators/gru_op.h @@ -40,129 +40,6 @@ inline void ReorderInitState(const DeviceContext& ctx, row_shuffle(ctx, src, index_lod, dst, indexed_src); } -template -class GRUKernel : public framework::OpKernel { - public: - void BatchCompute(const framework::ExecutionContext& context) const { - auto* input = context.Input("Input"); - auto* h0 = context.Input("H0"); - auto* weight = context.Input("Weight"); - const T* weight_data = weight->data(); - auto* bias = context.Input("Bias"); - auto* batch_gate = context.Output("BatchGate"); - batch_gate->mutable_data(context.GetPlace()); - auto* batch_reset_hidden_prev = - context.Output("BatchResetHiddenPrev"); - batch_reset_hidden_prev->mutable_data(context.GetPlace()); - auto* batch_hidden = context.Output("BatchHidden"); - batch_hidden->mutable_data(context.GetPlace()); - auto* hidden = context.Output("Hidden"); - hidden->mutable_data(context.GetPlace()); - - auto hidden_dims = hidden->dims(); - - bool is_reverse = context.Attr("is_reverse"); - math::LoDTensor2BatchFunctor to_batch; - auto& dev_ctx = context.template device_context(); - to_batch(dev_ctx, *input, batch_gate, true, is_reverse); - - if (bias) { - math::RowwiseAdd add_bias; - add_bias(dev_ctx, *batch_gate, *bias, batch_gate); - } - - int frame_size = hidden_dims[1]; - math::GRUMetaValue gru_value; - gru_value.gate_weight = const_cast(weight_data); - gru_value.state_weight = - const_cast(weight_data + 2 * frame_size * frame_size); - Tensor ordered_h0; - - framework::Vector order(batch_gate->lod()[2]); - - if (h0) { - // Since the batch computing for GRU reorders the input sequences - // according to their length. The initialized cell state also needs - // to reorder. 
- ReorderInitState( - context.template device_context(), *h0, order, - &ordered_h0, true); - gru_value.prev_out_value = ordered_h0.data(); - } else { - gru_value.prev_out_value = nullptr; - } - auto batch_starts = batch_gate->lod()[0]; - size_t num_batch = batch_starts.size() - 1; - auto active_node = math::detail::GetActivationType( - context.Attr("activation")); - auto active_gate = math::detail::GetActivationType( - context.Attr("gate_activation")); - auto blas = math::GetBlas(dev_ctx); - - // TODO(TJ): make a class, make one pack - T* packed_gate = blas.GEMM_ALLOC(CblasBMatrix, 1 /*height of C*/, - frame_size * 2 /*width of weight*/, - frame_size /*height of height*/); - PADDLE_ENFORCE(packed_gate); - blas.GEMM_PACK(CblasBMatrix, CblasNoTrans, 1 /*cur bs?*/, frame_size * 2, - frame_size, T(1.0), gru_value.gate_weight, frame_size * 2, - packed_gate); - T* packed_state = blas.GEMM_ALLOC(CblasBMatrix, 1 /*height of C*/, - frame_size /*width of weight*/, - frame_size /*height of height*/); - PADDLE_ENFORCE(packed_state); - blas.GEMM_PACK(CblasBMatrix, CblasNoTrans, 1 /*cur bs?*/, frame_size, - frame_size, T(1.0), gru_value.state_weight, frame_size, - packed_state); - - for (size_t n = 0; n < num_batch; n++) { - int bstart = static_cast(batch_starts[n]); - int bend = static_cast(batch_starts[n + 1]); - int cur_batch_size = bend - bstart; - - Tensor gate_t = batch_gate->Slice(bstart, bend); - Tensor reset_hidden_prev_t = batch_reset_hidden_prev->Slice(bstart, bend); - Tensor hidden_t = batch_hidden->Slice(bstart, bend); - gru_value.output_value = hidden_t.data(); - gru_value.gate_value = gate_t.data(); - gru_value.reset_output_value = reset_hidden_prev_t.data(); - if (gru_value.prev_out_value) { - blas.GEMM_COMPUTE(CblasNoTrans, CblasPacked, cur_batch_size, - frame_size * 2, frame_size, gru_value.prev_out_value, - frame_size, packed_gate, frame_size * 2, T(1), - gru_value.gate_value, frame_size * 3); - } - - math::detail::forward_reset_output( - math::detail::forward::gru_resetOutput(), gru_value, frame_size, - cur_batch_size, active_gate); - - if (gru_value.prev_out_value) { - blas.GEMM_COMPUTE( - CblasNoTrans, CblasPacked, cur_batch_size, frame_size, frame_size, - gru_value.reset_output_value, frame_size, packed_state, frame_size, - T(1), gru_value.gate_value + frame_size * 2, frame_size * 3); - } - - math::detail::forward_final_output( - math::detail::forward::gru_finalOutput(), gru_value, frame_size, - cur_batch_size, active_node); - - gru_value.prev_out_value = gru_value.output_value; - } - blas.GEMM_FREE(packed_gate); - blas.GEMM_FREE(packed_state); - - math::Batch2LoDTensorFunctor to_seq; - batch_hidden->set_lod(batch_gate->lod()); - to_seq(dev_ctx, *batch_hidden, hidden); - } - - void Compute(const framework::ExecutionContext& context) const override { - BatchCompute(context); - } -}; - template class GRUGradKernel : public framework::OpKernel { public: From 4973e07be3fab37b7559b9a8abce12260a3233ea Mon Sep 17 00:00:00 2001 From: tangwei12 Date: Mon, 6 Aug 2018 14:01:21 +0800 Subject: [PATCH 08/94] sampling op optimize --- paddle/fluid/operators/sampling_id_op.cc | 14 +++++---- paddle/fluid/operators/sampling_id_op.cu | 14 ++++----- paddle/fluid/operators/sampling_id_op.h | 36 +++++++++++++----------- 3 files changed, 34 insertions(+), 30 deletions(-) diff --git a/paddle/fluid/operators/sampling_id_op.cc b/paddle/fluid/operators/sampling_id_op.cc index b9e3b0372d..9729537d1e 100644 --- a/paddle/fluid/operators/sampling_id_op.cc +++ b/paddle/fluid/operators/sampling_id_op.cc @@ -57,9 
+57,11 @@ SamplingId Operator. } // namespace paddle namespace ops = paddle::operators; -REGISTER_OP_CUDA_KERNEL( - sampling_id, - ops::SamplingIdKernel, - ops::SamplingIdKernel, - ops::SamplingIdKernel, - ops::SamplingIdKernel); +REGISTER_OPERATOR(sampling_id, ops::SamplingIdOp, ops::SamplingIdOpMaker, + paddle::framework::EmptyGradOpMaker); + +REGISTER_OP_CPU_KERNEL( + sampling_id, ops::SamplingIdKernel, + ops::SamplingIdKernel, + ops::SamplingIdKernel, + ops::SamplingIdKernel); diff --git a/paddle/fluid/operators/sampling_id_op.cu b/paddle/fluid/operators/sampling_id_op.cu index f82ba68ce4..e467165b6d 100644 --- a/paddle/fluid/operators/sampling_id_op.cu +++ b/paddle/fluid/operators/sampling_id_op.cu @@ -30,11 +30,9 @@ class SamplingIdOp : public framework::OperatorWithKernel { } // namespace paddle namespace ops = paddle::operators; -REGISTER_OPERATOR(sampling_id, ops::SamplingIdOp, ops::SamplingIdOpMaker, - paddle::framework::EmptyGradOpMaker); - -REGISTER_OP_CPU_KERNEL( - sampling_id, ops::SamplingIdKernel, - ops::SamplingIdKernel, - ops::SamplingIdKernel, - ops::SamplingIdKernel); +REGISTER_OP_CUDA_KERNEL( + sampling_id, + ops::SamplingIdKernel, + ops::SamplingIdKernel, + ops::SamplingIdKernel, + ops::SamplingIdKernel); diff --git a/paddle/fluid/operators/sampling_id_op.h b/paddle/fluid/operators/sampling_id_op.h index eeb72d8f7d..5bb1991fc5 100644 --- a/paddle/fluid/operators/sampling_id_op.h +++ b/paddle/fluid/operators/sampling_id_op.h @@ -15,30 +15,31 @@ limitations under the License. */ #include #include +#include "paddle/fluid/framework/lod_tensor.h" #include "paddle/fluid/framework/op_registry.h" namespace paddle { namespace operators { +using Tensor = framework::Tensor; + template class SamplingIdKernel : public framework::OpKernel { - /// Produces random floating-point values, uniformly distributed on [0, 1). 
- std::uniform_real_distribution rand1_; - public: void Compute(const framework::ExecutionContext& context) const override { const Tensor* input = context.Input("X"); const int batch_size = static_cast(input->dims()[0]); const int width = static_cast(input->dims()[1]); - std::vector ids(batchSize); - auto& reng = get(); + std::vector ins_vector; + framework::TensorToVector(*input, context.device_context(), &ins_vector); - for (size_t i = 0; i < batchSize; ++i) { - double r = rand1_(reng); - int id = dim - 1; - for (int j = 0; j < dim; ++j) { - if ((r -= buf[i * dim + j]) < 0) { + std::vector ids(batch_size); + for (size_t i = 0; i < batch_size; ++i) { + double r = this->get_rand(); + int id = width - 1; + for (int j = 0; j < width; ++j) { + if ((r -= ins_vector[i * width + j]) < 0) { id = j; break; } @@ -50,19 +51,22 @@ class SamplingIdKernel : public framework::OpKernel { out_dim.push_back(static_cast(batch_size)); Tensor* output = context.Output("Output"); - output->Resize(framework::make_ddim(in_dim)); + output->Resize(framework::make_ddim(out_dim)); output->mutable_data(context.GetPlace()); framework::TensorFromVector(ids, context.device_context(), output); } - std::default_random_engine& get() { - auto engine = new std::default_random_engine; - engine->seed(defaultSeed); - return *engine; + double get_rand() const { + // Will be used to obtain a seed for the random number engine + std::random_device rd; + // Standard mersenne_twister_engine seeded with rd() + std::mt19937 gen(rd()); + std::uniform_real_distribution<> dis(0, 1); + return dis(gen); } private: unsigned int defaultSeed = 0; -} +}; } // namespace operators } // namespace paddle From 1f618c4ff9622259489546535c85309e4b619ebb Mon Sep 17 00:00:00 2001 From: minqiyang Date: Mon, 6 Aug 2018 16:43:24 +0800 Subject: [PATCH 09/94] Fix the overfix of 2to3 in xrange --- python/paddle/dataset/conll05.py | 2 +- python/paddle/dataset/imdb.py | 3 ++- python/paddle/dataset/imikolov.py | 5 ++-- python/paddle/dataset/mnist.py | 1 + python/paddle/dataset/tests/common_test.py | 1 + python/paddle/dataset/uci_housing.py | 5 ++-- python/paddle/fluid/executor.py | 2 +- python/paddle/fluid/framework.py | 18 ++++++++----- python/paddle/fluid/layer_helper.py | 2 +- python/paddle/fluid/layers/detection.py | 27 ++++++++++--------- python/paddle/fluid/layers/io.py | 7 ++--- python/paddle/fluid/nets.py | 3 ++- python/paddle/fluid/parallel_executor.py | 5 ++-- python/paddle/fluid/tests/demo/pyreader.py | 5 ++-- .../fluid/tests/unittests/dist_transformer.py | 3 ++- .../tests/unittests/test_split_ids_op.py | 3 ++- 16 files changed, 55 insertions(+), 37 deletions(-) diff --git a/python/paddle/dataset/conll05.py b/python/paddle/dataset/conll05.py index 25623feabb..724202b956 100644 --- a/python/paddle/dataset/conll05.py +++ b/python/paddle/dataset/conll05.py @@ -24,7 +24,7 @@ import tarfile import gzip import itertools import paddle.dataset.common -from six.moves import zip +from six.moves import zip, range __all__ = ['test, get_dict', 'get_embedding', 'convert'] diff --git a/python/paddle/dataset/imdb.py b/python/paddle/dataset/imdb.py index e7fe4e0b7e..39fc29fdac 100644 --- a/python/paddle/dataset/imdb.py +++ b/python/paddle/dataset/imdb.py @@ -25,6 +25,7 @@ import collections import tarfile import re import string +from six.moves import range __all__ = ['build_dict', 'train', 'test', 'convert'] @@ -66,7 +67,7 @@ def build_dict(pattern, cutoff): dictionary = sorted(word_freq, key=lambda x: (-x[1], x[0])) words, _ = list(zip(*dictionary)) - word_idx = 
dict(list(zip(words, list(range(len(words)))))) + word_idx = dict(list(zip(words, range(len(words))))) word_idx[''] = len(words) return word_idx diff --git a/python/paddle/dataset/imikolov.py b/python/paddle/dataset/imikolov.py index bc007c9d3c..bfb087ff38 100644 --- a/python/paddle/dataset/imikolov.py +++ b/python/paddle/dataset/imikolov.py @@ -14,13 +14,14 @@ """ imikolov's simple dataset. -This module will download dataset from +This module will download dataset from http://www.fit.vutbr.cz/~imikolov/rnnlm/ and parse training set and test set into paddle reader creators. """ import paddle.dataset.common import collections import tarfile +from six.moves import range __all__ = ['train', 'test', 'build_dict', 'convert'] @@ -68,7 +69,7 @@ def build_dict(min_word_freq=50): word_freq_sorted = sorted(word_freq, key=lambda x: (-x[1], x[0])) words, _ = list(zip(*word_freq_sorted)) - word_idx = dict(list(zip(words, list(range(len(words)))))) + word_idx = dict(list(zip(words, range(len(words))))) word_idx[''] = len(words) return word_idx diff --git a/python/paddle/dataset/mnist.py b/python/paddle/dataset/mnist.py index ffa9008c80..55e82fa755 100644 --- a/python/paddle/dataset/mnist.py +++ b/python/paddle/dataset/mnist.py @@ -21,6 +21,7 @@ import paddle.dataset.common import subprocess import numpy import platform +from six.moves import range __all__ = ['train', 'test', 'convert'] URL_PREFIX = 'http://yann.lecun.com/exdb/mnist/' diff --git a/python/paddle/dataset/tests/common_test.py b/python/paddle/dataset/tests/common_test.py index 777cd06a19..ede3d593eb 100644 --- a/python/paddle/dataset/tests/common_test.py +++ b/python/paddle/dataset/tests/common_test.py @@ -16,6 +16,7 @@ import paddle.dataset.common import unittest import tempfile import glob +from six.moves import range class TestCommon(unittest.TestCase): diff --git a/python/paddle/dataset/uci_housing.py b/python/paddle/dataset/uci_housing.py index 410ca7af0d..cc946762da 100644 --- a/python/paddle/dataset/uci_housing.py +++ b/python/paddle/dataset/uci_housing.py @@ -22,6 +22,7 @@ parse training set and test set into paddle reader creators. import os import numpy as np +import six import tempfile import tarfile import os @@ -74,7 +75,7 @@ def load_data(filename, feature_num=14, ratio=0.8): maximums, minimums, avgs = data.max(axis=0), data.min(axis=0), data.sum( axis=0) / data.shape[0] feature_range(maximums[:-1], minimums[:-1]) - for i in range(feature_num - 1): + for i in six.moves.range(feature_num - 1): data[:, i] = (data[:, i] - avgs[i]) / (maximums[i] - minimums[i]) offset = int(data.shape[0] * ratio) UCI_TRAIN_DATA = data[:offset] @@ -137,7 +138,7 @@ def predict_reader(): It returns just one tuple data to do inference. 
:return: one tuple data - :rtype: tuple + :rtype: tuple """ global UCI_TEST_DATA load_data(paddle.dataset.common.download(URL, 'uci_housing', MD5)) diff --git a/python/paddle/fluid/executor.py b/python/paddle/fluid/executor.py index 35da1d06a2..8437a9f20f 100644 --- a/python/paddle/fluid/executor.py +++ b/python/paddle/fluid/executor.py @@ -346,7 +346,7 @@ class Executor(object): def _fetch_data(self, fetch_list, fetch_var_name, scope): outs = [ core.get_fetch_variable(scope, fetch_var_name, i) - for i in range(len(fetch_list)) + for i in six.moves.range(len(fetch_list)) ] return outs diff --git a/python/paddle/fluid/framework.py b/python/paddle/fluid/framework.py index f0653a43ce..9a2c8adc03 100644 --- a/python/paddle/fluid/framework.py +++ b/python/paddle/fluid/framework.py @@ -1497,7 +1497,9 @@ class Program(object): else: p = Program() p.desc = core.ProgramDesc(self.desc) - p.blocks = [Block(p, i) for i in range(self.desc.num_blocks())] + p.blocks = [ + Block(p, i) for i in six.moves.range(self.desc.num_blocks()) + ] p._sync_with_cpp() p._copy_param_info_from(self) @@ -1549,7 +1551,9 @@ class Program(object): targets_idx.append([t.block.idx, t.idx]) res = Program() res.desc = core.prune(self.desc, targets_idx) - res.blocks = [Block(res, i) for i in range(res.desc.num_blocks())] + res.blocks = [ + Block(res, i) for i in six.moves.range(res.desc.num_blocks()) + ] res._sync_with_cpp() return res @@ -1590,13 +1594,15 @@ class Program(object): root_block._remove_var(var.name()) # change all `is_test` attributes to True - for i in range(res.desc.num_blocks()): + for i in six.moves.range(res.desc.num_blocks()): block = res.desc.block(i) - for j in range(block.op_size()): + for j in six.moves.range(block.op_size()): op = block.op(j) if op.has_attr('is_test'): op.set_attr('is_test', True) - res.blocks = [Block(res, i) for i in range(res.desc.num_blocks())] + res.blocks = [ + Block(res, i) for i in six.moves.range(res.desc.num_blocks()) + ] res._sync_with_cpp() return res @@ -1616,7 +1622,7 @@ class Program(object): """ p = Program() p.desc = core.ProgramDesc(binary_str) - p.blocks = [Block(p, i) for i in range(p.desc.num_blocks())] + p.blocks = [Block(p, i) for i in six.moves.range(p.desc.num_blocks())] p._sync_with_cpp() return p diff --git a/python/paddle/fluid/layer_helper.py b/python/paddle/fluid/layer_helper.py index 64337465ed..5f66f54cf7 100644 --- a/python/paddle/fluid/layer_helper.py +++ b/python/paddle/fluid/layer_helper.py @@ -85,7 +85,7 @@ class LayerHelper(object): raise ValueError("parameter number mismatch") elif len(param_attr) == 1 and length != 1: tmp = [None] * length - for i in range(length): + for i in six.moves.range(length): tmp[i] = copy.deepcopy(param_attr[0]) param_attr = tmp return param_attr diff --git a/python/paddle/fluid/layers/detection.py b/python/paddle/fluid/layers/detection.py index 9fae96d9bc..c11455b7a6 100644 --- a/python/paddle/fluid/layers/detection.py +++ b/python/paddle/fluid/layers/detection.py @@ -21,6 +21,7 @@ from ..layer_helper import LayerHelper from . import tensor from . import nn import math +import six from functools import reduce __all__ = [ @@ -102,7 +103,7 @@ def rpn_target_assign(loc, examples. Returns: - tuple: + tuple: A tuple(predicted_scores, predicted_location, target_label, target_bbox) is returned. The predicted_scores and predicted_location is the predicted result of the RPN. @@ -113,7 +114,7 @@ def rpn_target_assign(loc, anchors. 
The predicted_scores is a 2D Tensor with shape [F + B, 1], and the shape of target_label is same as the shape of the predicted_scores, B is the number of the background - anchors, the F and B is depends on the input of this operator. + anchors, the F and B is depends on the input of this operator. Examples: .. code-block:: python @@ -230,8 +231,8 @@ def detection_output(loc, nms_eta(float): The parameter for adaptive NMS. Returns: - Variable: - + Variable: + The detection outputs is a LoDTensor with shape [No, 6]. Each row has six values: [label, confidence, xmin, ymin, xmax, ymax]. `No` is the total number of detections in this mini-batch. For each @@ -501,7 +502,7 @@ def target_assign(input, Assumed that the row offset for each instance in `neg_indices` is called neg_lod, for i-th instance and each `id` of neg_indices in this instance: - + .. code-block:: text out[i][id][0 : K] = {mismatch_value, mismatch_value, ...} @@ -519,11 +520,11 @@ def target_assign(input, mismatch_value (float32): Fill this value to the mismatched location. Returns: - tuple: - A tuple(out, out_weight) is returned. out is a 3D Tensor with - shape [N, P, K], N and P is the same as they are in - `neg_indices`, K is the same as it in input of X. If - `match_indices[i][j]`. out_weight is the weight for output with + tuple: + A tuple(out, out_weight) is returned. out is a 3D Tensor with + shape [N, P, K], N and P is the same as they are in + `neg_indices`, K is the same as it in input of X. If + `match_indices[i][j]`. out_weight is the weight for output with the shape of [N, P, 1]. Examples: @@ -822,7 +823,7 @@ def prior_box(input, offset(float): Prior boxes center offset. Default: 0.5 name(str): Name of the prior box op. Default: None. min_max_aspect_ratios_order(bool): If set True, the output prior box is - in order of [min, max, aspect_ratios], which is consistent with + in order of [min, max, aspect_ratios], which is consistent with Caffe. Please note, this order affects the weights order of convolution layer followed by and does not affect the final detection results. Default: False. @@ -965,7 +966,7 @@ def multi_box_head(inputs, stride(int|list|tuple): The stride of conv2d. Default:1, name(str): Name of the prior box layer. Default: None. min_max_aspect_ratios_order(bool): If set True, the output prior box is - in order of [min, max, aspect_ratios], which is consistent with + in order of [min, max, aspect_ratios], which is consistent with Caffe. Please note, this order affects the weights order of convolution layer followed by and does not affect the fininal detection results. Default: False. @@ -1033,7 +1034,7 @@ def multi_box_head(inputs, min_sizes = [] max_sizes = [] step = int(math.floor(((max_ratio - min_ratio)) / (num_layer - 2))) - for ratio in range(min_ratio, max_ratio + 1, step): + for ratio in six.moves.range(min_ratio, max_ratio + 1, step): min_sizes.append(base_size * ratio / 100.) max_sizes.append(base_size * (ratio + step) / 100.) min_sizes = [base_size * .10] + min_sizes diff --git a/python/paddle/fluid/layers/io.py b/python/paddle/fluid/layers/io.py index 327ae30981..f9b01203e2 100644 --- a/python/paddle/fluid/layers/io.py +++ b/python/paddle/fluid/layers/io.py @@ -13,6 +13,7 @@ # limitations under the License. 
import contextlib import multiprocessing +import six import threading from ..data_feeder import DataFeeder @@ -69,7 +70,7 @@ def data(name, """ helper = LayerHelper('data', **locals()) shape = list(shape) - for i in range(len(shape)): + for i in six.moves.range(len(shape)): if shape[i] is None: shape[i] = -1 append_batch_size = False @@ -674,7 +675,7 @@ def py_reader(capacity, def __tensor_provider__(): for slots in paddle_reader(): - yield [slots[str(idx)] for idx in xrange(counter)] + yield [slots[str(idx)] for idx in six.moves.xrange(counter)] __set_tensor_provider__(__tensor_provider__) @@ -1005,7 +1006,7 @@ class Preprocessor(object): source_lod_levels = self.underlying_reader.desc.lod_levels() self.source_var_names = [ unique_name("preprocessor_source") - for _ in range(len(source_shapes)) + for _ in six.moves.range(len(source_shapes)) ] source_vars = [] for var_name, shape, dtype, lod_level in zip( diff --git a/python/paddle/fluid/nets.py b/python/paddle/fluid/nets.py index 08480671d8..46e4c70195 100644 --- a/python/paddle/fluid/nets.py +++ b/python/paddle/fluid/nets.py @@ -11,6 +11,7 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. +import six from . import layers __all__ = [ @@ -210,7 +211,7 @@ def img_conv_group(input, conv_with_batchnorm = __extend_list__(conv_with_batchnorm) conv_batchnorm_drop_rate = __extend_list__(conv_batchnorm_drop_rate) - for i in range(len(conv_num_filter)): + for i in six.moves.range(len(conv_num_filter)): local_conv_act = conv_act if conv_with_batchnorm[i]: local_conv_act = None diff --git a/python/paddle/fluid/parallel_executor.py b/python/paddle/fluid/parallel_executor.py index eabe6bb901..97849672b2 100644 --- a/python/paddle/fluid/parallel_executor.py +++ b/python/paddle/fluid/parallel_executor.py @@ -19,6 +19,7 @@ from . import framework from . import executor import warnings import sys +import six import os __all__ = ['ParallelExecutor', 'ExecutionStrategy', 'BuildStrategy'] @@ -95,7 +96,7 @@ class ParallelExecutor(object): self._places = [] self._act_places = [] if use_cuda: - for i in range(core.get_cuda_device_count()): + for i in six.moves.range(core.get_cuda_device_count()): p = core.Place() self._act_places.append(core.CUDAPlace(i)) p.set_place(self._act_places[-1]) @@ -103,7 +104,7 @@ class ParallelExecutor(object): else: cpu_num = int( os.environ.get('CPU_NUM', multiprocessing.cpu_count())) - for i in range(cpu_num): + for i in six.moves.range(cpu_num): p = core.Place() self._act_places.append(core.CPUPlace()) p.set_place(self._act_places[-1]) diff --git a/python/paddle/fluid/tests/demo/pyreader.py b/python/paddle/fluid/tests/demo/pyreader.py index 8206540193..737644a25f 100644 --- a/python/paddle/fluid/tests/demo/pyreader.py +++ b/python/paddle/fluid/tests/demo/pyreader.py @@ -13,6 +13,7 @@ # limitations under the License. 
import numpy +import six import paddle import paddle.dataset.mnist as mnist @@ -31,7 +32,7 @@ def network(is_train): hidden = img - for i in xrange(2): + for i in six.moves.xrange(2): hidden = fluid.layers.fc(input=hidden, size=100, act='tanh') hidden = fluid.layers.dropout( hidden, dropout_prob=0.5, is_test=not is_train) @@ -74,7 +75,7 @@ def main(): test_reader.decorate_paddle_reader(paddle.batch(mnist.test(), 512)) - for epoch_id in xrange(10): + for epoch_id in six.moves.xrange(10): train_reader.start() try: while True: diff --git a/python/paddle/fluid/tests/unittests/dist_transformer.py b/python/paddle/fluid/tests/unittests/dist_transformer.py index ee8020a735..6bd4ecbbe1 100644 --- a/python/paddle/fluid/tests/unittests/dist_transformer.py +++ b/python/paddle/fluid/tests/unittests/dist_transformer.py @@ -22,6 +22,7 @@ import paddle.fluid as fluid from paddle.fluid import core import os import sys +import six import transformer_model import paddle.dataset.wmt16 as wmt16 @@ -222,7 +223,7 @@ class DistTransformer2x2(object): first_loss, = exe.run(fetch_list=[avg_cost.name]) print(first_loss) - for i in xrange(5): + for i in six.moves.xrange(5): _ = exe.run(fetch_list=[avg_cost.name]) last_loss, = exe.run(fetch_list=[avg_cost.name]) print(last_loss) diff --git a/python/paddle/fluid/tests/unittests/test_split_ids_op.py b/python/paddle/fluid/tests/unittests/test_split_ids_op.py index ca78613098..20bba3ac33 100644 --- a/python/paddle/fluid/tests/unittests/test_split_ids_op.py +++ b/python/paddle/fluid/tests/unittests/test_split_ids_op.py @@ -14,6 +14,7 @@ import unittest import numpy as np +import six from op_test import OpTest import paddle.fluid.core as core from paddle.fluid.op import Operator @@ -59,7 +60,7 @@ class TestSpliteIds(unittest.TestCase): x_tensor = x.get_tensor() x_tensor.set(np_array, place) - outs_name = ["out%d" % i for i in xrange(3)] + outs_name = ["out%d" % i for i in six.moves.xrange(3)] outs = [ scope.var(var_name).get_selected_rows() for var_name in outs_name ] From da2cc99f67480858f444765773fa9ff6be6835a2 Mon Sep 17 00:00:00 2001 From: tangwei12 Date: Mon, 6 Aug 2018 17:46:43 +0800 Subject: [PATCH 10/94] sampling op optimize --- paddle/fluid/operators/sampling_id_op.cc | 7 ++- paddle/fluid/operators/sampling_id_op.cu | 13 ------ paddle/fluid/operators/sampling_id_op.h | 2 +- .../tests/unittests/test_sampling_id_op.py | 45 +++++++++++++++++++ 4 files changed, 49 insertions(+), 18 deletions(-) create mode 100644 python/paddle/fluid/tests/unittests/test_sampling_id_op.py diff --git a/paddle/fluid/operators/sampling_id_op.cc b/paddle/fluid/operators/sampling_id_op.cc index 9729537d1e..17f6461fcb 100644 --- a/paddle/fluid/operators/sampling_id_op.cc +++ b/paddle/fluid/operators/sampling_id_op.cc @@ -25,9 +25,9 @@ class SamplingIdOp : public framework::OperatorWithKernel { void InferShape(framework::InferShapeContext *ctx) const override { PADDLE_ENFORCE(ctx->HasInput("X"), - "Input(X) of RowConvOp should not be null."); + "Input(X) of SamplingIdOp should not be null."); PADDLE_ENFORCE(ctx->HasOutput("Out"), - "Output(Out) of RowConvOp should not be null."); + "Output(Out) of SamplingIdOp should not be null."); auto input_dims = ctx->GetInputDim("X"); @@ -43,8 +43,7 @@ class SamplingIdOpMaker : public framework::OpProtoAndCheckerMaker { AddInput("X", "The input tensor of softmax. " "2-D with shape [batch_size, input_feature_dimensions]."); - AddOutput("Out", "Sliced data tensor."); - + AddOutput("Out", "SamplingId data tensor."); AddComment(R"DOC( SamplingId Operator. 
@brief A layer for sampling id from multinomial distribution from the diff --git a/paddle/fluid/operators/sampling_id_op.cu b/paddle/fluid/operators/sampling_id_op.cu index e467165b6d..c0bb9c916c 100644 --- a/paddle/fluid/operators/sampling_id_op.cu +++ b/paddle/fluid/operators/sampling_id_op.cu @@ -16,19 +16,6 @@ limitations under the License. */ #include #include "paddle/fluid/operators/sampling_id_op.h" -namespace paddle { -namespace operators { - -using Tensor = framework::Tensor; - -class SamplingIdOp : public framework::OperatorWithKernel { - public: - using framework::OperatorWithKernel::OperatorWithKernel; - void InferShape(framework::InferShapeContext *ctx) const override {} -} -} // namespace operators -} // namespace paddle - namespace ops = paddle::operators; REGISTER_OP_CUDA_KERNEL( sampling_id, diff --git a/paddle/fluid/operators/sampling_id_op.h b/paddle/fluid/operators/sampling_id_op.h index 5bb1991fc5..4d962b4809 100644 --- a/paddle/fluid/operators/sampling_id_op.h +++ b/paddle/fluid/operators/sampling_id_op.h @@ -50,7 +50,7 @@ class SamplingIdKernel : public framework::OpKernel { std::vector out_dim; out_dim.push_back(static_cast(batch_size)); - Tensor* output = context.Output("Output"); + Tensor* output = context.Output("Out"); output->Resize(framework::make_ddim(out_dim)); output->mutable_data(context.GetPlace()); framework::TensorFromVector(ids, context.device_context(), output); diff --git a/python/paddle/fluid/tests/unittests/test_sampling_id_op.py b/python/paddle/fluid/tests/unittests/test_sampling_id_op.py new file mode 100644 index 0000000000..86d86acfb5 --- /dev/null +++ b/python/paddle/fluid/tests/unittests/test_sampling_id_op.py @@ -0,0 +1,45 @@ +# Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +import unittest +import numpy as np +from op_test import OpTest + +import paddle.fluid.core as core +from paddle.fluid.op import Operator + + +class TestSamplingIdOp(OpTest): + def setUp(self): + self.op_type = "sampling_id" + self.use_mkldnn = False + self.init_kernel_type() + X = np.random.random((3, 4)).astype('float32') + self.inputs = {"X": X} + Y = np.random.random(3).astype('float32') + self.outputs = {'Out': Y} + self.attrs = {'use_mkldnn': self.use_mkldnn} + + def test_check_output(self): + self.check_output() + + def test_check_grad(self): + self.check_grad(['X'], 'Out') + + def init_kernel_type(self): + pass + + +if __name__ == "__main__": + unittest.main() From 4cd504d3b4fbab768ea8720830cd5048612e510d Mon Sep 17 00:00:00 2001 From: tangwei12 Date: Mon, 6 Aug 2018 21:07:31 +0800 Subject: [PATCH 11/94] bug fix --- paddle/fluid/operators/sampling_id_op.h | 12 ++++++++---- .../fluid/tests/unittests/test_sampling_id_op.py | 15 ++++++++------- 2 files changed, 16 insertions(+), 11 deletions(-) diff --git a/paddle/fluid/operators/sampling_id_op.h b/paddle/fluid/operators/sampling_id_op.h index 4d962b4809..3d724e3ae7 100644 --- a/paddle/fluid/operators/sampling_id_op.h +++ b/paddle/fluid/operators/sampling_id_op.h @@ -13,7 +13,11 @@ See the License for the specific language governing permissions and limitations under the License. */ #pragma once +#include +#include +#include #include +#include #include #include "paddle/fluid/framework/lod_tensor.h" #include "paddle/fluid/framework/op_registry.h" @@ -34,17 +38,17 @@ class SamplingIdKernel : public framework::OpKernel { std::vector ins_vector; framework::TensorToVector(*input, context.device_context(), &ins_vector); - std::vector ids(batch_size); + std::vector ids(batch_size); for (size_t i = 0; i < batch_size; ++i) { double r = this->get_rand(); - int id = width - 1; + int idx = width - 1; for (int j = 0; j < width; ++j) { if ((r -= ins_vector[i * width + j]) < 0) { - id = j; + idx = j; break; } } - ids[i] = id; + ids[i] = ins_vector[i * width + idx]; } std::vector out_dim; diff --git a/python/paddle/fluid/tests/unittests/test_sampling_id_op.py b/python/paddle/fluid/tests/unittests/test_sampling_id_op.py index 86d86acfb5..e3e7153049 100644 --- a/python/paddle/fluid/tests/unittests/test_sampling_id_op.py +++ b/python/paddle/fluid/tests/unittests/test_sampling_id_op.py @@ -25,17 +25,18 @@ class TestSamplingIdOp(OpTest): self.op_type = "sampling_id" self.use_mkldnn = False self.init_kernel_type() - X = np.random.random((3, 4)).astype('float32') - self.inputs = {"X": X} - Y = np.random.random(3).astype('float32') - self.outputs = {'Out': Y} + self.X = np.random.random((8, 4)).astype('float32') + self.inputs = {"X": self.X} + self.Y = np.random.random(8).astype('float32') + self.outputs = {'Out': self.Y} self.attrs = {'use_mkldnn': self.use_mkldnn} def test_check_output(self): - self.check_output() + self.check_output_customized(self.verify_output) - def test_check_grad(self): - self.check_grad(['X'], 'Out') + def verify_output(self, outs): + out = np.array(outs[0]) + self.assertEqual(len(out), len(self.Y)) def init_kernel_type(self): pass From 5b9716d1f6f8653a4785e50582548f473df21a0a Mon Sep 17 00:00:00 2001 From: tangwei12 Date: Mon, 6 Aug 2018 21:22:35 +0800 Subject: [PATCH 12/94] add dims check --- paddle/fluid/operators/sampling_id_op.cc | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/paddle/fluid/operators/sampling_id_op.cc b/paddle/fluid/operators/sampling_id_op.cc index 17f6461fcb..d13eeabcb9 100644 --- 
a/paddle/fluid/operators/sampling_id_op.cc +++ b/paddle/fluid/operators/sampling_id_op.cc @@ -30,6 +30,8 @@ class SamplingIdOp : public framework::OperatorWithKernel { "Output(Out) of SamplingIdOp should not be null."); auto input_dims = ctx->GetInputDim("X"); + PADDLE_ENFORCE(input_dims.size() == 2, + "Input(X, Filter) should be 2-D tensor."); framework::DDim dims = input_dims; ctx->SetOutputDim("Out", dims); @@ -46,10 +48,8 @@ class SamplingIdOpMaker : public framework::OpProtoAndCheckerMaker { AddOutput("Out", "SamplingId data tensor."); AddComment(R"DOC( SamplingId Operator. - @brief A layer for sampling id from multinomial distribution from the - input layer. Sampling one id for one sample. The result is stored in - output_.ids. -)DOC"); +A layer for sampling id from multinomial distribution from the + input layer. Sampling one id for one sample.)DOC"); } }; } // namespace operators From 9c63fef63ca721a8e69c723314040fb9e9a5ad3d Mon Sep 17 00:00:00 2001 From: tangwei12 Date: Mon, 6 Aug 2018 22:01:46 +0800 Subject: [PATCH 13/94] random optimize --- paddle/fluid/operators/sampling_id_op.h | 16 +++++++++++----- 1 file changed, 11 insertions(+), 5 deletions(-) diff --git a/paddle/fluid/operators/sampling_id_op.h b/paddle/fluid/operators/sampling_id_op.h index 3d724e3ae7..7ad25fa13a 100644 --- a/paddle/fluid/operators/sampling_id_op.h +++ b/paddle/fluid/operators/sampling_id_op.h @@ -40,7 +40,7 @@ class SamplingIdKernel : public framework::OpKernel { std::vector ids(batch_size); for (size_t i = 0; i < batch_size; ++i) { - double r = this->get_rand(); + double r = this->getRandReal(); int idx = width - 1; for (int j = 0; j < width; ++j) { if ((r -= ins_vector[i * width + j]) < 0) { @@ -60,17 +60,23 @@ class SamplingIdKernel : public framework::OpKernel { framework::TensorFromVector(ids, context.device_context(), output); } - double get_rand() const { + private: + double getRandReal() const { + std::call_once(init_flag_, &SamplingIdKernel::getRndInstance); + return rnd(); + } + + static void getRndInstance() { // Will be used to obtain a seed for the random number engine std::random_device rd; // Standard mersenne_twister_engine seeded with rd() std::mt19937 gen(rd()); std::uniform_real_distribution<> dis(0, 1); - return dis(gen); + rnd = std::bind(dis, gen); } - private: - unsigned int defaultSeed = 0; + static std::once_flag init_flag_; + static std::function<> rnd; }; } // namespace operators } // namespace paddle From b30bdde15a841b2918f7ef8125f1afd3672a322d Mon Sep 17 00:00:00 2001 From: tangwei12 Date: Mon, 6 Aug 2018 22:15:10 +0800 Subject: [PATCH 14/94] random optimize --- paddle/fluid/operators/sampling_id_op.h | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/paddle/fluid/operators/sampling_id_op.h b/paddle/fluid/operators/sampling_id_op.h index 7ad25fa13a..7f3ca8e761 100644 --- a/paddle/fluid/operators/sampling_id_op.h +++ b/paddle/fluid/operators/sampling_id_op.h @@ -72,11 +72,11 @@ class SamplingIdKernel : public framework::OpKernel { // Standard mersenne_twister_engine seeded with rd() std::mt19937 gen(rd()); std::uniform_real_distribution<> dis(0, 1); - rnd = std::bind(dis, gen); + rnd = std::bind(dis, std::ref(gen)); } static std::once_flag init_flag_; - static std::function<> rnd; + static std::function rnd; }; } // namespace operators } // namespace paddle From 6abe819f0758a9cdb55aded43a119cdad491617c Mon Sep 17 00:00:00 2001 From: minqiyang Date: Tue, 7 Aug 2018 23:15:32 +0800 Subject: [PATCH 15/94] Fix pybind11 problem Fix str and bytes problem Fix 
sorted problem Fix math problem Fix CI problem --- paddle/fluid/framework/attribute.h | 10 ++- paddle/fluid/framework/op_desc.cc | 51 +++++++++++++ paddle/fluid/pybind/protobuf.cc | 6 +- paddle/fluid/pybind/pybind.cc | 2 + python/paddle/dataset/cifar.py | 2 +- python/paddle/dataset/common.py | 3 +- python/paddle/dataset/image.py | 44 +++++------ python/paddle/dataset/mnist.py | 21 ++++-- python/paddle/dataset/uci_housing.py | 2 +- python/paddle/dataset/wmt16.py | 3 +- python/paddle/fluid/backward.py | 37 +++++----- python/paddle/fluid/compat.py | 74 +++++++++++++++++++ python/paddle/fluid/framework.py | 34 ++++----- python/paddle/fluid/graphviz.py | 3 +- python/paddle/fluid/io.py | 16 +--- python/paddle/fluid/layers/io.py | 2 +- python/paddle/fluid/layers/nn.py | 30 ++++---- python/paddle/fluid/parallel_executor.py | 13 +++- .../cifar10_small_test_set.py | 4 +- .../tests/book/test_image_classification.py | 2 +- .../paddle/fluid/tests/unittests/op_test.py | 17 ++--- .../tests/unittests/test_data_balance.py | 2 +- .../fluid/tests/unittests/test_dist_base.py | 7 +- .../fluid/tests/unittests/test_pool2d_op.py | 16 ++-- .../tests/unittests/test_reader_reset.py | 2 +- .../unittests/test_reorder_lod_tensor.py | 4 +- .../fluid/tests/unittests/test_roi_pool_op.py | 23 +++--- .../fluid/tests/unittests/test_unpool_op.py | 8 +- .../fluid/tests/unittests/test_warpctc_op.py | 2 +- .../fluid/transpiler/details/program_utils.py | 4 +- .../fluid/transpiler/distribute_transpiler.py | 2 +- tools/test_runner.py | 1 + 32 files changed, 292 insertions(+), 155 deletions(-) create mode 100644 python/paddle/fluid/compat.py diff --git a/paddle/fluid/framework/attribute.h b/paddle/fluid/framework/attribute.h index 8428bf8e33..ea91ac2bb0 100644 --- a/paddle/fluid/framework/attribute.h +++ b/paddle/fluid/framework/attribute.h @@ -82,7 +82,10 @@ class DefaultValueSetter { public: explicit DefaultValueSetter(T default_value) : default_value_(default_value) {} - void operator()(T& value) const { value = default_value_; } + void operator()(T* value) const { + PADDLE_ENFORCE(value != nullptr, "Can not set default value to nullptr"); + *value = default_value_; + } private: T default_value_; @@ -199,6 +202,7 @@ struct ExtractAttribute { template class TypedAttrChecker { typedef std::function ValueChecker; + typedef std::function ValueSetter; public: explicit TypedAttrChecker(const std::string& attr_name) @@ -241,7 +245,7 @@ class TypedAttrChecker { "Attribute '%s' is required!", attr_name_); // default_value_setter_ has no more than one element T val; - (default_value_setter_[0])(val); + (default_value_setter_[0])(&val); attr_map[attr_name_] = val; } Attribute& attr = attr_map.at(attr_name_); @@ -255,7 +259,7 @@ class TypedAttrChecker { private: std::string attr_name_; std::vector value_checkers_; - std::vector default_value_setter_; + std::vector default_value_setter_; }; // check whether op's all attributes fit their own limits diff --git a/paddle/fluid/framework/op_desc.cc b/paddle/fluid/framework/op_desc.cc index a190199f1c..984ea3a3dd 100644 --- a/paddle/fluid/framework/op_desc.cc +++ b/paddle/fluid/framework/op_desc.cc @@ -202,6 +202,57 @@ std::vector OpDesc::AttrNames() const { } void OpDesc::SetAttr(const std::string &name, const Attribute &v) { + // NOTICE(minqiyang): pybind11 will take the empty list in python as + // the std::vector type in C++; so we have to change the attr's type + // here if we meet this issue + proto::AttrType attr_type = static_cast(v.which() - 1); + if (attr_type == proto::AttrType::INTS && 
+ boost::get>(v).size() == 0u) { + proto::OpProto proto = OpInfoMap::Instance().Get(Type()).Proto(); + // Find current attr via attr name and set the correct attribute value + for (int i = 0; i != proto.attrs_size(); ++i) { + const proto::OpProto::Attr &attr = proto.attrs(i); + if (attr.name() == name) { + switch (attr.type()) { + case proto::AttrType::BOOLEANS: { + VLOG(11) << "SetAttr: " << Type() << ", " << name + << " from INTS to BOOLEANS"; + this->attrs_[name] = std::vector(); + break; + } + case proto::AttrType::INTS: { + VLOG(11) << "SetAttr: " << Type() << ", " << name + << " from INTS to INTS"; + this->attrs_[name] = std::vector(); + break; + } + case proto::AttrType::FLOATS: { + VLOG(11) << "SetAttr: " << Type() << ", " << name + << " from INTS to FLOATS"; + this->attrs_[name] = std::vector(); + break; + } + case proto::AttrType::STRINGS: { + VLOG(11) << "SetAttr: " << Type() << ", " << name + << " from INTS to STRINGS"; + this->attrs_[name] = std::vector(); + break; + } + case proto::AttrType::BLOCKS: { + VLOG(11) << "SetAttr: " << Type() << ", " << name + << " from INTS to BLOCKS"; + this->SetBlocksAttr(name, std::vector()); + return; + } + default: + PADDLE_THROW("Wrong attr type %d", attr.type()); + } + need_update_ = true; + return; + } + } + } + this->attrs_[name] = v; need_update_ = true; } diff --git a/paddle/fluid/pybind/protobuf.cc b/paddle/fluid/pybind/protobuf.cc index 2199f5311f..2372db9715 100644 --- a/paddle/fluid/pybind/protobuf.cc +++ b/paddle/fluid/pybind/protobuf.cc @@ -205,11 +205,7 @@ void BindBlockDesc(pybind11::module *m) { void BindVarDsec(pybind11::module *m) { pybind11::class_ var_desc(*m, "VarDesc", ""); var_desc - .def("name", - [](pd::VarDesc &self) { - pybind11::bytes name = self.Name(); - return name; - }, + .def("name", [](pd::VarDesc &self) { return self.Name(); }, pybind11::return_value_policy::reference) .def("set_name", &pd::VarDesc::SetName) .def("set_shape", &pd::VarDesc::SetShape) diff --git a/paddle/fluid/pybind/pybind.cc b/paddle/fluid/pybind/pybind.cc index 2320f3e4db..8e6412fc86 100644 --- a/paddle/fluid/pybind/pybind.cc +++ b/paddle/fluid/pybind/pybind.cc @@ -54,6 +54,8 @@ limitations under the License. 
*/ #include "paddle/fluid/platform/gpu_info.h" #endif +#include "pybind11/stl.h" + // disable auto conversion to list in Python PYBIND11_MAKE_OPAQUE(paddle::framework::LoDTensorArray); diff --git a/python/paddle/dataset/cifar.py b/python/paddle/dataset/cifar.py index f6b4ff8fbd..e399b5215f 100644 --- a/python/paddle/dataset/cifar.py +++ b/python/paddle/dataset/cifar.py @@ -53,7 +53,7 @@ def reader_creator(filename, sub_name, cycle=False): yield (sample / 255.0).astype(numpy.float32), int(label) def reader(): - with tarfile.open(filename, mode='r') as f: + with tarfile.open(filename, mode='rb') as f: names = (each_item.name for each_item in f if sub_name in each_item.name) diff --git a/python/paddle/dataset/common.py b/python/paddle/dataset/common.py index 6195cc50df..1161a57059 100644 --- a/python/paddle/dataset/common.py +++ b/python/paddle/dataset/common.py @@ -20,6 +20,7 @@ import shutil import sys import importlib import paddle.dataset +import paddle.fluid.compat as cpt import six.moves.cPickle as pickle import glob @@ -93,7 +94,7 @@ def download(url, module_name, md5sum, save_name=None): total_length = int(total_length) for data in r.iter_content(chunk_size=4096): dl += len(data) - f.write(data) + f.write(cpt.to_literal_str(data)) done = int(50 * dl / total_length) sys.stdout.write("\r[%s%s]" % ('=' * done, ' ' * (50 - done))) diff --git a/python/paddle/dataset/image.py b/python/paddle/dataset/image.py index 3b3d89c93c..f7e7c854fe 100644 --- a/python/paddle/dataset/image.py +++ b/python/paddle/dataset/image.py @@ -56,7 +56,7 @@ def batch_images_from_tar(data_file, :type data_file: string :param dataset_name: 'train','test' or 'valid' :type dataset_name: string - :param img2label: a dic with image file name as key + :param img2label: a dic with image file name as key and image's label as value :type img2label: dic :param num_per_batch: image number per batch file @@ -88,7 +88,7 @@ def batch_images_from_tar(data_file, output['data'] = data pickle.dump( output, - open('%s/batch_%d' % (out_path, file_id), 'w'), + open('%s/batch_%d' % (out_path, file_id), 'wb'), protocol=pickle.HIGHEST_PROTOCOL) file_id += 1 data = [] @@ -99,7 +99,7 @@ def batch_images_from_tar(data_file, output['data'] = data pickle.dump( output, - open('%s/batch_%d' % (out_path, file_id), 'w'), + open('%s/batch_%d' % (out_path, file_id), 'wb'), protocol=pickle.HIGHEST_PROTOCOL) with open(meta_file, 'a') as meta: @@ -113,7 +113,7 @@ def load_image_bytes(bytes, is_color=True): Load an color or gray image from bytes array. Example usage: - + .. code-block:: python with open('cat.jpg') as f: @@ -137,7 +137,7 @@ def load_image(file, is_color=True): Load an color or gray image from the file path. Example usage: - + .. code-block:: python im = load_image('cat.jpg') @@ -161,16 +161,16 @@ def load_image(file, is_color=True): def resize_short(im, size): - """ + """ Resize an image so that the length of shorter edge is size. Example usage: - + .. code-block:: python im = load_image('cat.jpg') im = resize_short(im, 256) - + :param im: the input image with HWC layout. :type im: ndarray :param size: the shorter edge size of image after resizing. @@ -193,17 +193,17 @@ def to_chw(im, order=(2, 0, 1)): according the order (2,0,1). Example usage: - + .. code-block:: python im = load_image('cat.jpg') im = resize_short(im, 256) im = to_chw(im) - + :param im: the input image with HWC layout. :type im: ndarray :param order: the transposed order. 
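
The to_chw helper documented in this hunk is a plain axis permutation; assuming OpenCV-style HWC input, the equivalent NumPy call is:

    import numpy as np

    im_hwc = np.zeros((224, 224, 3), dtype=np.float32)  # height, width, channel
    im_chw = im_hwc.transpose((2, 0, 1))                # -> (3, 224, 224)
    assert im_chw.shape == (3, 224, 224)
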
- :type order: tuple|list + :type order: tuple|list """ assert len(im.shape) == len(order) im = im.transpose(order) @@ -215,11 +215,11 @@ def center_crop(im, size, is_color=True): Crop the center of image with size. Example usage: - + .. code-block:: python im = center_crop(im, 224) - + :param im: the input image with HWC layout. :type im: ndarray :param size: the cropping size. @@ -243,11 +243,11 @@ def random_crop(im, size, is_color=True): Randomly crop input image with size. Example usage: - + .. code-block:: python im = random_crop(im, 224) - + :param im: the input image with HWC layout. :type im: ndarray :param size: the cropping size. @@ -272,11 +272,11 @@ def left_right_flip(im, is_color=True): Return the flipped image. Example usage: - + .. code-block:: python im = left_right_flip(im) - + :param im: input image with HWC layout or HW layout for gray image :type im: ndarray :param is_color: whether input image is color or not @@ -299,7 +299,7 @@ def simple_transform(im, resizing, croping and flipping. Example usage: - + .. code-block:: python im = simple_transform(im, 256, 224, True) @@ -314,7 +314,7 @@ def simple_transform(im, :type is_train: bool :param is_color: whether the image is color or not. :type is_color: bool - :param mean: the mean values, which can be element-wise mean values or + :param mean: the mean values, which can be element-wise mean values or mean values per channel. :type mean: numpy array | list """ @@ -332,7 +332,7 @@ def simple_transform(im, im = im.astype('float32') if mean is not None: mean = np.array(mean, dtype=np.float32) - # mean value, may be one value per channel + # mean value, may be one value per channel if mean.ndim == 1 and is_color: mean = mean[:, np.newaxis, np.newaxis] elif mean.ndim == 1: @@ -357,7 +357,7 @@ def load_and_transform(filename, for the transform operations. Example usage: - + .. code-block:: python im = load_and_transform('cat.jpg', 256, 224, True) @@ -372,7 +372,7 @@ def load_and_transform(filename, :type is_train: bool :param is_color: whether the image is color or not. :type is_color: bool - :param mean: the mean values, which can be element-wise mean values or + :param mean: the mean values, which can be element-wise mean values or mean values per channel. :type mean: numpy array | list """ diff --git a/python/paddle/dataset/mnist.py b/python/paddle/dataset/mnist.py index 55e82fa755..28e6a04795 100644 --- a/python/paddle/dataset/mnist.py +++ b/python/paddle/dataset/mnist.py @@ -21,6 +21,8 @@ import paddle.dataset.common import subprocess import numpy import platform +import six +import tempfile from six.moves import range __all__ = ['train', 'test', 'convert'] @@ -46,23 +48,28 @@ def reader_creator(image_filename, label_filename, buffer_size): # According to http://stackoverflow.com/a/38061619/724872, we # cannot use standard package gzip here. 
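
The 'w' -> 'wb' mode changes in batch_images_from_tar above are required on Python 3, where pickle emits bytes and a text-mode file rejects them. A small sketch (the path is illustrative):

    import pickle

    batch = {'label': [0, 1], 'data': [b'\x00' * 8, b'\x01' * 8]}
    # Text mode ('w') raises "TypeError: write() argument must be str, not
    # bytes" on Python 3, so batch files are opened in binary mode instead.
    with open('/tmp/batch_0', 'wb') as f:
        pickle.dump(batch, f, protocol=pickle.HIGHEST_PROTOCOL)
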
- m = subprocess.Popen([zcat_cmd, image_filename], stdout=subprocess.PIPE) - m.stdout.read(16) # skip some magic bytes + tmp_image_file = tempfile.TemporaryFile(prefix='paddle_dataset') + m = subprocess.Popen( + [zcat_cmd, image_filename], stdout=tmp_image_file).communicate() + tmp_image_file.seek(16) # skip some magic bytes - l = subprocess.Popen([zcat_cmd, label_filename], stdout=subprocess.PIPE) - l.stdout.read(8) # skip some magic bytes + # Python3 will not take stdout as file + tmp_label_file = tempfile.TemporaryFile(prefix='paddle_dataset') + l = subprocess.Popen( + [zcat_cmd, label_filename], stdout=tmp_label_file).communicate() + tmp_label_file.seek(8) # skip some magic bytes try: # reader could be break. while True: labels = numpy.fromfile( - l.stdout, 'ubyte', count=buffer_size).astype("int") + tmp_label_file, 'ubyte', count=buffer_size).astype("int") if labels.size != buffer_size: break # numpy.fromfile returns empty slice after EOF. images = numpy.fromfile( - m.stdout, 'ubyte', count=buffer_size * 28 * 28).reshape( - (buffer_size, 28 * 28)).astype('float32') + tmp_image_file, 'ubyte', count=buffer_size * 28 * + 28).reshape((buffer_size, 28 * 28)).astype('float32') images = images / 255.0 * 2.0 - 1.0 diff --git a/python/paddle/dataset/uci_housing.py b/python/paddle/dataset/uci_housing.py index cc946762da..2ba8ddcc1f 100644 --- a/python/paddle/dataset/uci_housing.py +++ b/python/paddle/dataset/uci_housing.py @@ -71,7 +71,7 @@ def load_data(filename, feature_num=14, ratio=0.8): return data = np.fromfile(filename, sep=' ') - data = data.reshape(data.shape[0] / feature_num, feature_num) + data = data.reshape(data.shape[0] // feature_num, feature_num) maximums, minimums, avgs = data.max(axis=0), data.min(axis=0), data.sum( axis=0) / data.shape[0] feature_range(maximums[:-1], minimums[:-1]) diff --git a/python/paddle/dataset/wmt16.py b/python/paddle/dataset/wmt16.py index 4e3c466c38..186f9476d8 100644 --- a/python/paddle/dataset/wmt16.py +++ b/python/paddle/dataset/wmt16.py @@ -29,6 +29,7 @@ Multi30K: Multilingual English-German Image Descriptions. """ import os +import six import tarfile import gzip from collections import defaultdict @@ -120,7 +121,7 @@ def reader_creator(tar_file, file_name, src_dict_size, trg_dict_size, src_lang): with tarfile.open(tar_file, mode="r") as f: for line in f.extractfile(file_name): - line_split = line.strip().split("\t") + line_split = line.strip().split(six.b("\t")) if len(line_split) != 2: continue src_words = line_split[src_col].split() diff --git a/python/paddle/fluid/backward.py b/python/paddle/fluid/backward.py index f33fa7218b..6430d3a264 100644 --- a/python/paddle/fluid/backward.py +++ b/python/paddle/fluid/backward.py @@ -17,6 +17,7 @@ from . import core import collections import copy import six +from . import compat as cpt from . 
import unique_name __all__ = ['append_backward'] @@ -75,10 +76,10 @@ def _infer_var_data_type_(grad_var_name, block): """ Infer the data type of given grad variable """ - grad_var = block.desc.find_var(grad_var_name.encode("ascii")) - fwd_name = _strip_grad_suffix_(grad_var_name.encode("ascii")) - if block.desc.has_var_recursive(fwd_name): - fwd_var = block.desc.find_var_recursive(fwd_name.encode("ascii")) + grad_var = block.desc.find_var(cpt.to_bytes(grad_var_name)) + fwd_name = _strip_grad_suffix_(grad_var_name) + if block.desc.has_var_recursive(cpt.to_bytes(fwd_name)): + fwd_var = block.desc.find_var_recursive(cpt.to_bytes(fwd_name)) grad_var.set_dtype(fwd_var.dtype()) else: grad_var.set_dtype(core.VarDesc.VarType.FP32) @@ -102,8 +103,10 @@ def _some_in_set_(cands, s): """ if len(cands) == 0: return False - for c in cands: - if c in s: + literal_set = cpt.to_literal_str(s) + literal_cands = cpt.to_literal_str(cands) + for c in literal_cands: + if c in literal_set: return True return False @@ -114,9 +117,8 @@ def _strip_grad_suffix_(name): e.g. x@GRAD ==> x y@GRAD@RENAME@1 ==> y """ - if isinstance(name, six.text_type): - name = name.encode() - pos = name.find(six.b(core.grad_var_suffix())) + name = cpt.to_literal_str(name) + pos = name.find(core.grad_var_suffix()) return name[:pos] if pos != -1 else name @@ -125,9 +127,7 @@ def _append_grad_suffix_(name): Append grad suffix to the given variable name e.g. x ==> x@GRAD """ - if isinstance(name, six.text_type): - name = name.encode() - return name + six.b(core.grad_var_suffix()) + return cpt.to_literal_str(name) + core.grad_var_suffix() def _addup_repetitive_outputs_(op_descs): @@ -364,7 +364,8 @@ def _append_backward_ops_(block, # Getting op's corresponding grad_op grad_op_desc, op_grad_to_var = core.get_grad_op_desc( - op.desc, no_grad_dict[block.idx], grad_sub_block_list) + op.desc, + cpt.to_literal_str(no_grad_dict[block.idx]), grad_sub_block_list) grad_op_descs.extend(grad_op_desc) grad_to_var.update(op_grad_to_var) @@ -411,11 +412,10 @@ def _append_backward_vars_(block, start_op_idx, grad_to_var, grad_info_map): new_vars = set() # create new gradient variables for grad_var_name in op_desc.output_arg_names(): - grad_var_name = grad_var_name.encode("ascii") - if block.desc.has_var_recursive( - grad_var_name) or grad_var_name == core.empty_var_name(): + if block.desc.has_var_recursive(cpt.to_bytes( + grad_var_name)) or grad_var_name == core.empty_var_name(): continue - block.desc.var(grad_var_name) + block.desc.var(cpt.to_bytes(grad_var_name)) new_vars.add(grad_var_name) if grad_var_name not in grad_to_var: continue @@ -597,11 +597,12 @@ def append_backward(loss, parameter_list=None, no_grad_set=None, parameters = parameter_list else: params = program.global_block().all_parameters() + program.global_block().iter_parameters() parameters = [param.name for param in params] params_and_grads = [] for param in parameters: - if param not in grad_info_map: + if cpt.to_literal_str(param) not in grad_info_map: continue grad_info = grad_info_map[param] grad_block = grad_info[1] diff --git a/python/paddle/fluid/compat.py b/python/paddle/fluid/compat.py new file mode 100644 index 0000000000..05633583cc --- /dev/null +++ b/python/paddle/fluid/compat.py @@ -0,0 +1,74 @@ +# Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import six + + +# str and bytes related functions +def to_literal_str(obj): + if isinstance(obj, list): + return [_to_literal_str(item) for item in obj] + elif isinstance(obj, set): + return set([_to_literal_str(item) for item in obj]) + else: + return _to_literal_str(obj) + + +def _to_literal_str(obj): + if isinstance(obj, six.binary_type): + return obj.decode('latin-1') + elif isinstance(obj, six.text_type): + return obj + else: + return six.u(obj) + + +def to_bytes(obj): + if isinstance(obj, list): + return [_to_bytes(item) for item in obj] + elif isinstance(obj, set): + return set([_to_bytes(item) for item in obj]) + else: + return _to_bytes(obj) + + +def _to_bytes(obj): + if isinstance(obj, six.text_type): + return obj.encode('latin-1') + elif isinstance(obj, six.binary_type): + return obj + else: + return six.b(obj) + + +# math related functions +import math + + +def round(x, d=0): + """ + Compatible round which act the same behaviour in Python3. + + Args: + x(float) : The number to round halfway. + + Returns: + round result of x + """ + p = 10**d + return float(math.floor((x * p) + math.copysign(0.5, x))) / p + + +def floor_division(x, y): + return x // y diff --git a/python/paddle/fluid/framework.py b/python/paddle/fluid/framework.py index 9a2c8adc03..4e08836471 100644 --- a/python/paddle/fluid/framework.py +++ b/python/paddle/fluid/framework.py @@ -19,6 +19,7 @@ import six import numpy as np +from . import compat as cpt from .proto import framework_pb2 try: from . 
import core @@ -87,7 +88,7 @@ def convert_np_dtype_to_dtype_(np_dtype): elif dtype == np.uint8: return core.VarDesc.VarType.UINT8 else: - raise ValueError("Not supported numpy dtype " + six.binary_type(dtype)) + raise ValueError("Not supported numpy dtype %s" % dtype) def dtype_is_floating(dtype): @@ -198,11 +199,11 @@ class Variable(object): if name is None: name = unique_name.generate('_generated_var') is_new_var = False - name = name if isinstance(name, six.binary_type) else name.encode() - self.desc = self.block.desc.find_var(name) + name = cpt.to_literal_str(name) + self.desc = self.block.desc.find_var(cpt.to_bytes(name)) if self.desc is None: - self.desc = self.block.desc.var(name) + self.desc = self.block.desc.var(cpt.to_bytes(name)) is_new_var = True if is_new_var: @@ -325,7 +326,7 @@ class Variable(object): @property def name(self): - return self.desc.name() + return cpt.to_literal_str(self.desc.name()) @name.setter def name(self, new_name): @@ -529,10 +530,7 @@ class Operator(object): elif isinstance(arg, six.binary_type): in_arg_names.append(arg.decode()) else: - if isinstance(arg.name, six.string_types): - in_arg_names.append(arg.name) - elif isinstance(arg.name, six.binary_type): - in_arg_names.append(arg.name.decode()) + in_arg_names.append(cpt.to_literal_str(arg.name)) self.desc.set_input(in_proto.name, in_arg_names) else: self.desc.set_input(in_proto.name, []) @@ -561,12 +559,7 @@ class Operator(object): (out_proto.name, len(out_args))) out_arg_names = [] for arg in out_args: - if isinstance(arg.name, six.string_types): - out_arg_names.append(arg.name) - elif isinstance(arg.name, six.binary_type): - out_arg_names.append(arg.name.decode()) - else: - out_arg_names.append(six.u(arg.name)) + out_arg_names.append(cpt.to_literal_str(arg.name)) arg.op = self self.desc.set_output(out_proto.name, out_arg_names) @@ -994,6 +987,9 @@ class Block(object): Returns: Variable: the Variable with the giving name. """ + name = cpt.to_literal_str(name) + new_name = cpt.to_literal_str(new_name) + if not self.has_var(name): raise ValueError("var %s is not in current block" % name) v = self.var(name) @@ -1012,9 +1008,9 @@ class Block(object): else: raise ValueError("unsupported var type: %s", type(v)) orig_var_type = v.type - self.desc._rename_var(name, new_name) + self.desc._rename_var(cpt.to_bytes(name), cpt.to_bytes(new_name)) # NOTE: v is destroyed by C++ after calling _rename_var. 
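
The compat helpers used throughout these framework.py hunks normalize names crossing the pybind boundary: to_literal_str decodes bytes to native str (latin-1), to_bytes does the reverse, and both map over lists and sets. Roughly:

    import paddle.fluid.compat as cpt

    name = cpt.to_literal_str(b'fc_0.w_0')           # -> 'fc_0.w_0'
    assert cpt.to_bytes(name) == b'fc_0.w_0'         # round-trips losslessly
    assert cpt.to_bytes(['a', 'b']) == [b'a', b'b']  # converted element-wise
    # cpt.round keeps Python 2 semantics (ties away from zero); Python 3's
    # built-in round() uses banker's rounding, so round(0.5) == 0 there.
    assert cpt.round(0.5) == 1.0
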
- d = self.desc.find_var(new_name) + d = self.desc.find_var(cpt.to_bytes(new_name)) if var_type == "Parameter": var = Parameter( self, @@ -1045,7 +1041,7 @@ class Block(object): def _remove_var(self, name): self._sync_with_cpp() - self.desc._remove_var(name) + self.desc._remove_var(cpt.to_bytes(name)) del self.vars[name] def create_parameter(self, *args, **kwargs): @@ -1128,7 +1124,7 @@ class Block(object): # sync variables removed from c++ end for var in list(self.vars.keys()): - if not self.desc.find_var(var): + if not self.desc.find_var(cpt.to_bytes(var)): self.vars.pop(var) # sync operators from cpp diff --git a/python/paddle/fluid/graphviz.py b/python/paddle/fluid/graphviz.py index ba67bf5ae6..0557d7fd8a 100644 --- a/python/paddle/fluid/graphviz.py +++ b/python/paddle/fluid/graphviz.py @@ -106,7 +106,8 @@ class Graph(object): def _rank_repr(self): ranks = sorted( list(self.rank_groups.items()), - cmp=lambda a, b: a[1].priority > b[1].priority) + key=functools.cmp_to_key( + lambda a, b: a[1].priority > b[1].priority)) repr = [] for x in ranks: repr.append(str(x[1])) diff --git a/python/paddle/fluid/io.py b/python/paddle/fluid/io.py index 55e517f1f4..78e5ef30cc 100644 --- a/python/paddle/fluid/io.py +++ b/python/paddle/fluid/io.py @@ -600,25 +600,15 @@ def save_inference_model(dirname, # "./infer_model". """ - if isinstance(feeded_var_names, six.binary_type): + if isinstance(feeded_var_names, six.string_types): feeded_var_names = [feeded_var_names] - elif isinstance(feeded_var_names, six.text_type): - feeded_var_names = [feeded_var_names.encode()] else: if len(feeded_var_names) > 0: # TODO(paddle-dev): polish these code blocks if not (bool(feeded_var_names) and all( - isinstance(name, six.binary_type) + isinstance(name, six.string_types) for name in feeded_var_names)): - if not (all( - isinstance(name, six.text_type) - for name in feeded_var_names)): - raise ValueError( - "'feed_var_names' should be a list of str.") - else: - feeded_var_names = [ - name.encode() for name in feeded_var_names - ] + raise ValueError("'feed_var_names' should be a list of str.") if isinstance(target_vars, Variable): target_vars = [target_vars] diff --git a/python/paddle/fluid/layers/io.py b/python/paddle/fluid/layers/io.py index f9b01203e2..bac641327d 100644 --- a/python/paddle/fluid/layers/io.py +++ b/python/paddle/fluid/layers/io.py @@ -751,7 +751,7 @@ def open_files(filenames, else: buffer_size = int(buffer_size) - if isinstance(filenames, basestring): + if isinstance(filenames, six.string_types): filenames = [filenames] dtypes = [convert_np_dtype_to_dtype_(dt) for dt in dtypes] shape_concat = [] diff --git a/python/paddle/fluid/layers/nn.py b/python/paddle/fluid/layers/nn.py index a82fdf41a6..d1ae284d54 100644 --- a/python/paddle/fluid/layers/nn.py +++ b/python/paddle/fluid/layers/nn.py @@ -360,7 +360,7 @@ def dynamic_lstm(input, """ helper = LayerHelper('lstm', **locals()) - size = size / 4 + size = size // 4 weight = helper.create_parameter( attr=helper.param_attr, shape=[size, 4 * size], dtype=dtype) bias_size = [1, 7 * size] @@ -1498,7 +1498,7 @@ def conv2d(input, raise ValueError("use_cudnn should be True or False") input_shape = input.shape - filter_shape = [num_filters, num_filter_channels] + filter_size + filter_shape = [num_filters, int(num_filter_channels)] + filter_size def _get_default_param_initializer(): std = (2.0 / (filter_size[0]**2 * num_channels))**0.5 @@ -2669,15 +2669,15 @@ def beam_search(pre_ids, Refer to `Beam search `_ for more details. 
- - This layer does the search in beams for one time step. Specifically, it + + This layer does the search in beams for one time step. Specifically, it selects the top-K candidate word ids of current step from :attr:`ids` according to their :attr:`scores` for all source sentences, where K is :attr:`beam_size` and :attr:`ids, scores` are predicted results from the computation cell. Additionally, :attr:`pre_ids` and :attr:`pre_scores` are the output of beam_search at previous step, they are needed for special use to handle ended candidate translations. - + Note that the :attr:`scores` passed in should be accumulated scores, and length penalty should be done with extra operators before calculating the accumulated scores if needed, also suggest finding top-K before it and @@ -3878,7 +3878,7 @@ def nce(input, def hsigmoid(input, label, num_classes, param_attr=None, bias_attr=None): """ The hierarchical sigmoid operator is used to accelerate the training - process of language model. This operator organizes the classes into a + process of language model. This operator organizes the classes into a complete binary tree, each leaf node represents a class(a word) and each internal node acts as a binary classifier. For each word there's a unique path from root to it's leaf node, hsigmoid calculate the cost for each @@ -3888,9 +3888,9 @@ def hsigmoid(input, label, num_classes, param_attr=None, bias_attr=None): Refer to `Hierarchical Probabilistic Neural Network Language Model `_ - + Args: - input (Variable): The input tensor variable with shape + input (Variable): The input tensor variable with shape :math:`[N \\times D]`, where :math:`N` is the size of mini-batch, and :math:`D` is the feature size. label (Variable): The tensor variable contains labels of training data. @@ -3898,7 +3898,7 @@ def hsigmoid(input, label, num_classes, param_attr=None, bias_attr=None): num_classes: (int), The number of classes, must not be less than 2. param_attr (ParamAttr|list of ParamAttr, default None): The parameter attribute for learnable parameters/weights of this layer. - bias_attr (ParamAttr|list of ParamAttr, default None): The parameter + bias_attr (ParamAttr|list of ParamAttr, default None): The parameter attribute for the bias of this layer. If it is set to False, no bias will be applied. @@ -5293,23 +5293,23 @@ def rank_loss(label, left, right, name=None): is a pairwise ranking model with a training sample consisting of a pair of documents, A and B. Label P indicates whether A is ranked higher than B or not: - + P = {0, 1} or {0, 0.5, 1}, where 0.5 means that there is no information about the rank of the input pair. - + Rank loss layer takes three inputs: left (o_i), right (o_j) and label (P_{i,j}). The inputs respectively represent RankNet's output scores for documents A and B and the value of label P. The following equation computes rank loss C_{i,j} from the inputs: - + $$ C_{i,j} = -\tilde{P_{ij}} * o_{i,j} + \log(1 + e^{o_{i,j}}) \\ o_{i,j} = o_i - o_j \\ \tilde{P_{i,j}} = \left \{0, 0.5, 1 \right \} \ or \ \left \{0, 1 \right \} $$ - - Rank loss layer takes batch inputs with size batch_size (batch_size >= 1). - + + Rank loss layer takes batch inputs with size batch_size (batch_size >= 1). + Args: label (Variable): Indicats whether A ranked higher than B or not. left (Variable): RankNet's output score for doc A. 
diff --git a/python/paddle/fluid/parallel_executor.py b/python/paddle/fluid/parallel_executor.py index 97849672b2..7c723ba264 100644 --- a/python/paddle/fluid/parallel_executor.py +++ b/python/paddle/fluid/parallel_executor.py @@ -17,6 +17,7 @@ import multiprocessing from . import core from . import framework from . import executor +from . import compat as cpt import warnings import sys import six @@ -154,11 +155,14 @@ class ParallelExecutor(object): self.executor = core.ParallelExecutor( self._places, set([ - p.name for p in main.global_block().iter_parameters() + cpt.to_literal_str(p.name) + for p in main.global_block().iter_parameters() if not p.stop_gradient ]), - set(self.persistable_vars), main.desc, loss_name - if loss_name else '', scope, local_scopes, exec_strategy, + set(cpt.to_literal_str(var) + for var in self.persistable_vars), main.desc, + cpt.to_literal_str(loss_name) + if loss_name else six.u(''), scope, local_scopes, exec_strategy, build_strategy, num_trainers, trainer_id) self.scope = scope @@ -270,7 +274,8 @@ class ParallelExecutor(object): self.executor.feed_tensors_into_local_scopes(res) fetch_var_name = '@FETCHED_VAR_NAME@' - self.executor.run(fetch_list, fetch_var_name) + self.executor.run( + cpt.to_literal_str(fetch_list), cpt.to_literal_str(fetch_var_name)) arr = self.scope.find_var(fetch_var_name).get_lod_tensor_array() if self.is_dist: diff --git a/python/paddle/fluid/tests/book/high-level-api/image_classification/cifar10_small_test_set.py b/python/paddle/fluid/tests/book/high-level-api/image_classification/cifar10_small_test_set.py index 9e4c384d92..e7b709f31b 100644 --- a/python/paddle/fluid/tests/book/high-level-api/image_classification/cifar10_small_test_set.py +++ b/python/paddle/fluid/tests/book/high-level-api/image_classification/cifar10_small_test_set.py @@ -30,7 +30,7 @@ images per class. 
import itertools import numpy -import paddle.v2.dataset.common +import paddle.dataset.common import tarfile from six.moves import cPickle as pickle from six.moves import zip @@ -78,6 +78,6 @@ def train10(batch_size=None): :rtype: callable """ return reader_creator( - paddle.v2.dataset.common.download(CIFAR10_URL, 'cifar', CIFAR10_MD5), + paddle.dataset.common.download(CIFAR10_URL, 'cifar', CIFAR10_MD5), 'data_batch', batch_size=batch_size) diff --git a/python/paddle/fluid/tests/book/test_image_classification.py b/python/paddle/fluid/tests/book/test_image_classification.py index de6fe5f140..b6685fe2c2 100644 --- a/python/paddle/fluid/tests/book/test_image_classification.py +++ b/python/paddle/fluid/tests/book/test_image_classification.py @@ -60,7 +60,7 @@ def resnet_cifar10(input, depth=32): return tmp assert (depth - 2) % 6 == 0 - n = (depth - 2) / 6 + n = (depth - 2) // 6 conv1 = conv_bn_layer( input=input, ch_out=16, filter_size=3, stride=1, padding=1) res1 = layer_warp(basicblock, conv1, 16, 16, n, 1) diff --git a/python/paddle/fluid/tests/unittests/op_test.py b/python/paddle/fluid/tests/unittests/op_test.py index b27d773f09..1ed14e35b1 100644 --- a/python/paddle/fluid/tests/unittests/op_test.py +++ b/python/paddle/fluid/tests/unittests/op_test.py @@ -15,6 +15,7 @@ import unittest import numpy as np import random +import six import time import itertools import collections @@ -26,15 +27,13 @@ from paddle.fluid.op import Operator from paddle.fluid.executor import Executor from paddle.fluid.framework import Program, OpProtoHolder, Variable from testsuite import create_op, set_input, append_input_output, append_loss_ops -from functools import reduce -from six.moves import zip def randomize_probability(batch_size, class_num, dtype='float32'): prob = np.random.uniform( 0.1, 1.0, size=(batch_size, class_num)).astype(dtype) prob_sum = prob.sum(axis=1) - for i in range(len(prob)): + for i in six.moves.xrange(len(prob)): prob[i] /= prob_sum[i] return prob @@ -51,7 +50,7 @@ def get_numeric_gradient(place, set_input(scope, op, inputs, place) def product(dim): - return reduce(lambda a, b: a * b, dim, 1) + return six.moves.reduce(lambda a, b: a * b, dim, 1) def get_output(): sum = [] @@ -103,7 +102,7 @@ def get_numeric_gradient(place, # we only compute gradient of one element each time. # we use a for loop to compute the gradient of every element. - for i in range(tensor_size): + for i in six.moves.xrange(tensor_size): if in_place: set_input(scope, op, inputs, place) @@ -161,7 +160,7 @@ class OpTest(unittest.TestCase): assert isinstance( numpy_dict, dict), "self.inputs, self.outputs must be numpy_dict" - for var_name, var_value in numpy_dict.items(): + for var_name, var_value in six.iteritems(numpy_dict): if isinstance(var_value, (np.ndarray, np.generic)): self.try_call_once(var_value.dtype) elif isinstance(var_value, (list, tuple)): @@ -225,7 +224,7 @@ class OpTest(unittest.TestCase): def _get_io_vars(self, block, numpy_inputs): inputs = {} - for name, value in numpy_inputs.items(): + for name, value in six.iteritems(numpy_inputs): if isinstance(value, list): var_list = [ block.var(sub_name) for sub_name, sub_value in value @@ -268,7 +267,7 @@ class OpTest(unittest.TestCase): # if the fetch_list is customized by user, we use it directly. # if not, fill the fetch_list by the user configured outputs in test. 
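
The op_test changes above all follow one pattern: route Python-2-only builtins through six so the same source runs on both interpreters. The moved names used here behave as follows:

    import six

    total = six.moves.reduce(lambda a, b: a * b, [2, 3, 4], 1)  # reduce is not a builtin on Py3
    for i in six.moves.xrange(3):         # xrange on Py2, range on Py3
        pass
    for k, v in six.iteritems({'x': 1}):  # dict.iteritems on Py2, dict.items on Py3
        pass
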
if len(fetch_list) == 0: - for var_name, var in outputs.items(): + for var_name, var in six.iteritems(outputs): if isinstance(var, list): for v in var: fetch_list.append(v) @@ -371,7 +370,7 @@ class OpTest(unittest.TestCase): def __assert_is_close(self, numeric_grads, analytic_grads, names, max_relative_error, msg_prefix): - for a, b, name in zip(numeric_grads, analytic_grads, names): + for a, b, name in six.moves.zip(numeric_grads, analytic_grads, names): abs_a = np.abs(a) abs_a[abs_a < 1e-3] = 1 diff --git a/python/paddle/fluid/tests/unittests/test_data_balance.py b/python/paddle/fluid/tests/unittests/test_data_balance.py index 951282e8ba..d3c7b6e714 100644 --- a/python/paddle/fluid/tests/unittests/test_data_balance.py +++ b/python/paddle/fluid/tests/unittests/test_data_balance.py @@ -14,7 +14,7 @@ import unittest import paddle.fluid as fluid -import paddle.v2 as paddle +import paddle as paddle import numpy as np diff --git a/python/paddle/fluid/tests/unittests/test_dist_base.py b/python/paddle/fluid/tests/unittests/test_dist_base.py index 1aaab6f906..f543a39d83 100644 --- a/python/paddle/fluid/tests/unittests/test_dist_base.py +++ b/python/paddle/fluid/tests/unittests/test_dist_base.py @@ -16,9 +16,12 @@ import time import unittest import os import sys +import six import signal import subprocess +import paddle.fluid.compat as cpt + class TestDistBase(unittest.TestCase): def setUp(self): @@ -78,7 +81,7 @@ class TestDistBase(unittest.TestCase): env=env_local) local_proc.wait() out, err = local_proc.communicate() - local_ret = out + local_ret = cpt.to_literal_str(out) sys.stderr.write('local_loss: %s\n' % local_ret) sys.stderr.write('local_stderr: %s\n' % err) @@ -116,7 +119,7 @@ class TestDistBase(unittest.TestCase): tr1_proc.wait() out, err = tr0_proc.communicate() sys.stderr.write('dist_stderr: %s\n' % err) - loss_data0 = out + loss_data0 = cpt.to_literal_str(out) sys.stderr.write('dist_loss: %s\n' % loss_data0) lines = loss_data0.split("\n") dist_first_loss = eval(lines[0].replace(" ", ","))[0] diff --git a/python/paddle/fluid/tests/unittests/test_pool2d_op.py b/python/paddle/fluid/tests/unittests/test_pool2d_op.py index 1cf70311b4..a75194f34a 100644 --- a/python/paddle/fluid/tests/unittests/test_pool2d_op.py +++ b/python/paddle/fluid/tests/unittests/test_pool2d_op.py @@ -29,11 +29,11 @@ def max_pool2D_forward_naive(x, if global_pool == 1: ksize = [H, W] H_out = (H - ksize[0] + 2 * paddings[0] + strides[0] - 1 - ) / strides[0] + 1 if ceil_mode else (H - ksize[0] + 2 * - paddings[0]) / strides[0] + 1 + ) // strides[0] + 1 if ceil_mode else ( + H - ksize[0] + 2 * paddings[0]) // strides[0] + 1 W_out = (W - ksize[1] + 2 * paddings[1] + strides[1] - 1 - ) / strides[1] + 1 if ceil_mode else (W - ksize[1] + 2 * - paddings[1]) / strides[1] + 1 + ) // strides[1] + 1 if ceil_mode else ( + W - ksize[1] + 2 * paddings[1]) // strides[1] + 1 out = np.zeros((N, C, H_out, W_out)) for i in range(H_out): for j in range(W_out): @@ -57,11 +57,11 @@ def avg_pool2D_forward_naive(x, if global_pool == 1: ksize = [H, W] H_out = (H - ksize[0] + 2 * paddings[0] + strides[0] - 1 - ) / strides[0] + 1 if ceil_mode else (H - ksize[0] + 2 * - paddings[0]) / strides[0] + 1 + ) // strides[0] + 1 if ceil_mode else ( + H - ksize[0] + 2 * paddings[0]) // strides[0] + 1 W_out = (W - ksize[1] + 2 * paddings[1] + strides[1] - 1 - ) / strides[1] + 1 if ceil_mode else (W - ksize[1] + 2 * - paddings[1]) / strides[1] + 1 + ) // strides[1] + 1 if ceil_mode else ( + W - ksize[1] + 2 * paddings[1]) // strides[1] + 1 out = 
np.zeros((N, C, H_out, W_out)) for i in range(H_out): for j in range(W_out): diff --git a/python/paddle/fluid/tests/unittests/test_reader_reset.py b/python/paddle/fluid/tests/unittests/test_reader_reset.py index 3ad85d5748..d3ab991c84 100644 --- a/python/paddle/fluid/tests/unittests/test_reader_reset.py +++ b/python/paddle/fluid/tests/unittests/test_reader_reset.py @@ -13,7 +13,7 @@ # limitations under the License. import paddle.fluid as fluid -import paddle.v2 as paddle +import paddle as paddle import numpy as np import unittest diff --git a/python/paddle/fluid/tests/unittests/test_reorder_lod_tensor.py b/python/paddle/fluid/tests/unittests/test_reorder_lod_tensor.py index 6e1cd56b3e..e51408944c 100644 --- a/python/paddle/fluid/tests/unittests/test_reorder_lod_tensor.py +++ b/python/paddle/fluid/tests/unittests/test_reorder_lod_tensor.py @@ -17,6 +17,7 @@ import paddle.fluid as fluid import paddle.fluid.core as core from paddle.fluid.layers.control_flow import lod_rank_table import numpy +import functools class TestReorderLoDTensor(unittest.TestCase): @@ -101,7 +102,8 @@ class TestReorderLoDTensor(unittest.TestCase): rank_table = [] # list of (index, length) for i in range(len(ref_lod)): rank_table.append((i, ref_lod[i])) - rank_table = sorted(rank_table, lambda x, y: y[1] - x[1]) + rank_table = sorted( + rank_table, key=functools.cmp_to_key(lambda x, y: y[1] - x[1])) # compute the input sequence info according to input_lod input_value, input_lod = self.data[self.data_desc[0][0]] diff --git a/python/paddle/fluid/tests/unittests/test_roi_pool_op.py b/python/paddle/fluid/tests/unittests/test_roi_pool_op.py index df5684ab17..0f38b742d9 100644 --- a/python/paddle/fluid/tests/unittests/test_roi_pool_op.py +++ b/python/paddle/fluid/tests/unittests/test_roi_pool_op.py @@ -16,6 +16,7 @@ import unittest import numpy as np import math import sys +import paddle.fluid.compat as cpt from op_test import OpTest @@ -59,10 +60,10 @@ class TestROIPoolOp(OpTest): for i in range(self.rois_num): roi = self.rois[i] roi_batch_id = roi[0] - roi_start_w = int(round(roi[1] * self.spatial_scale)) - roi_start_h = int(round(roi[2] * self.spatial_scale)) - roi_end_w = int(round(roi[3] * self.spatial_scale)) - roi_end_h = int(round(roi[4] * self.spatial_scale)) + roi_start_w = int(cpt.round(roi[1] * self.spatial_scale)) + roi_start_h = int(cpt.round(roi[2] * self.spatial_scale)) + roi_end_w = int(cpt.round(roi[3] * self.spatial_scale)) + roi_end_h = int(cpt.round(roi[4] * self.spatial_scale)) roi_height = int(max(roi_end_h - roi_start_h + 1, 1)) roi_width = int(max(roi_end_w - roi_start_w + 1, 1)) @@ -97,8 +98,8 @@ class TestROIPoolOp(OpTest): for w in range(wstart, wend): if x_i[c, h, w] > out_data[i, c, ph, pw]: out_data[i, c, ph, pw] = x_i[c, h, w] - argmax_data[i, c, ph, pw] = h * \ - self.width + w + argmax_data[i, c, ph, + pw] = h * self.width + w self.outs = out_data.astype('float32') self.argmaxes = argmax_data.astype('int64') @@ -110,14 +111,14 @@ class TestROIPoolOp(OpTest): self.rois_lod[0].append(bno + 1) for i in range(bno + 1): x1 = np.random.random_integers( - 0, self.width / self.spatial_scale - self.pooled_width) + 0, self.width // self.spatial_scale - self.pooled_width) y1 = np.random.random_integers( - 0, self.height / self.spatial_scale - self.pooled_height) + 0, self.height // self.spatial_scale - self.pooled_height) x2 = np.random.random_integers(x1 + self.pooled_width, - self.width / self.spatial_scale) - y2 = np.random.random_integers(y1 + self.pooled_height, - self.height / self.spatial_scale) 
+ self.width // self.spatial_scale) + y2 = np.random.random_integers( + y1 + self.pooled_height, self.height // self.spatial_scale) roi = [bno, x1, y1, x2, y2] rois.append(roi) diff --git a/python/paddle/fluid/tests/unittests/test_unpool_op.py b/python/paddle/fluid/tests/unittests/test_unpool_op.py index ecce4cdde2..49dc559ed7 100644 --- a/python/paddle/fluid/tests/unittests/test_unpool_op.py +++ b/python/paddle/fluid/tests/unittests/test_unpool_op.py @@ -27,7 +27,7 @@ def unpool2dmax_forward_naive(input, indices, ksize, strides, paddings): for h in range(s2): for w in range(s3): index = indices[nidx, cidx, h, w] - hidx = (index - index % out_wsize) / out_wsize + hidx = (index - index % out_wsize) // out_wsize widx = index % out_wsize out[nidx, cidx, int(hidx), int(widx)] = \ input[nidx, cidx, h, w] @@ -41,9 +41,9 @@ class TestUnpoolOp(OpTest): self.init_test_case() pre_input = np.random.random(self.shape).astype("float32") nsize, csize, hsize, wsize = pre_input.shape - hsize_out = (hsize - self.ksize[0] + 2 * self.paddings[0]) / \ + hsize_out = (hsize - self.ksize[0] + 2 * self.paddings[0]) // \ self.strides[0] + 1 - wsize_out = (wsize - self.ksize[1] + 2 * self.paddings[1]) / \ + wsize_out = (wsize - self.ksize[1] + 2 * self.paddings[1]) // \ self.strides[1] + 1 input = np.zeros((nsize, csize, hsize_out, wsize_out)) indices = np.zeros((nsize, csize, hsize_out, wsize_out)) @@ -62,7 +62,7 @@ class TestUnpoolOp(OpTest): input[nidx, cidx, i, j] = x_masked.max() arg = x_masked.argmax() indices[nidx, cidx, i, j] = \ - (r_start + arg / self.ksize[1]) * wsize + \ + (r_start + arg // self.ksize[1]) * wsize + \ c_start + arg % self.ksize[1] output = self.unpool2d_forward_naive(input, indices, self.ksize, \ self.strides, self.paddings).astype("float32") diff --git a/python/paddle/fluid/tests/unittests/test_warpctc_op.py b/python/paddle/fluid/tests/unittests/test_warpctc_op.py index 9f1aaee472..d647a17692 100644 --- a/python/paddle/fluid/tests/unittests/test_warpctc_op.py +++ b/python/paddle/fluid/tests/unittests/test_warpctc_op.py @@ -132,7 +132,7 @@ class CTCForward(object): for k in range(end - start): j = k + start if j & 1 == 1: - label_idx = j / 2 + label_idx = j // 2 label_val = labels_a_sequence[label_idx, 0] fv = self.log_add(forward_vars[i - 1, j], forward_vars[i - 1, j - 1]) diff --git a/python/paddle/fluid/transpiler/details/program_utils.py b/python/paddle/fluid/transpiler/details/program_utils.py index 76d10777f5..291c8fb27b 100644 --- a/python/paddle/fluid/transpiler/details/program_utils.py +++ b/python/paddle/fluid/transpiler/details/program_utils.py @@ -12,12 +12,14 @@ # See the License for the specific language governing permissions and # limitations under the License. 
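
The `/` -> `//` edits in these tests address the core Python 3 hazard in shape arithmetic: int / int now yields a float, and float extents break ndarray construction. For example:

    import numpy as np

    hsize, ksize, pad, stride = 11, 3, 1, 2
    h_out = (hsize - ksize + 2 * pad) // stride + 1  # 6 on Python 2 and 3
    np.zeros((1, 1, h_out, h_out))                   # ok
    # With '/', Python 3 computes 6.0, and np.zeros((1, 1, 6.0, 6.0)) raises
    # "TypeError: 'float' object cannot be interpreted as an integer".
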
+import six + def delete_ops(block, ops): try: start = list(block.ops).index(ops[0]) end = list(block.ops).index(ops[-1]) - [block._remove_op(start) for _ in range(end - start + 1)] + [block._remove_op(start) for _ in six.moves.range(end - start + 1)] except Exception as e: raise e block.program._sync_with_cpp() diff --git a/python/paddle/fluid/transpiler/distribute_transpiler.py b/python/paddle/fluid/transpiler/distribute_transpiler.py index 4d6761436e..aca9aafd52 100644 --- a/python/paddle/fluid/transpiler/distribute_transpiler.py +++ b/python/paddle/fluid/transpiler/distribute_transpiler.py @@ -1017,7 +1017,7 @@ class DistributeTranspiler(object): for i, block in enumerate(splited): size = block[1] - rows = size / orig_dim1_flatten + rows = size // orig_dim1_flatten splited_shape = [rows] if len(orig_shape) >= 2: splited_shape.extend(orig_shape[1:]) diff --git a/tools/test_runner.py b/tools/test_runner.py index 2d6a3cf8a9..9b9f165e73 100644 --- a/tools/test_runner.py +++ b/tools/test_runner.py @@ -13,6 +13,7 @@ # limitations under the License. from __future__ import print_function + import unittest import os import sys From e6ae1e4ffce3a39f21a2ca7d0a7d2e9883f83528 Mon Sep 17 00:00:00 2001 From: minqiyang Date: Wed, 8 Aug 2018 15:52:43 +0800 Subject: [PATCH 16/94] Replace the dependency of paddle.v2 dataset --- python/paddle/dataset/cifar.py | 2 +- python/paddle/dataset/wmt14.py | 10 ++++++---- python/paddle/dataset/wmt16.py | 9 +++++---- .../tests/unittests/test_data_balance.py | 2 +- .../tests/unittests/test_preprocessor.py | 4 ++-- .../fluid/tests/unittests/test_profiler.py | 2 +- .../tests/unittests/test_protobuf_descs.py | 19 ++++++++++--------- .../tests/unittests/test_reader_reset.py | 2 +- .../tests/unittests/test_recordio_reader.py | 4 ++-- 9 files changed, 29 insertions(+), 25 deletions(-) diff --git a/python/paddle/dataset/cifar.py b/python/paddle/dataset/cifar.py index e399b5215f..f6b4ff8fbd 100644 --- a/python/paddle/dataset/cifar.py +++ b/python/paddle/dataset/cifar.py @@ -53,7 +53,7 @@ def reader_creator(filename, sub_name, cycle=False): yield (sample / 255.0).astype(numpy.float32), int(label) def reader(): - with tarfile.open(filename, mode='rb') as f: + with tarfile.open(filename, mode='r') as f: names = (each_item.name for each_item in f if sub_name in each_item.name) diff --git a/python/paddle/dataset/wmt14.py b/python/paddle/dataset/wmt14.py index 250fd03ffb..7488e21f1f 100644 --- a/python/paddle/dataset/wmt14.py +++ b/python/paddle/dataset/wmt14.py @@ -19,10 +19,12 @@ http://paddlepaddle.cdn.bcebos.com/demo/wmt_shrinked_data/wmt14.tgz and parse training set and test set into paddle reader creators. 
""" +import six import tarfile import gzip import paddle.dataset.common +import paddle.fluid.compat as cpt __all__ = [ 'train', @@ -40,8 +42,8 @@ URL_TRAIN = ('http://paddlepaddle.cdn.bcebos.com/demo/' 'wmt_shrinked_data/wmt14.tgz') MD5_TRAIN = '0791583d57d5beb693b9414c5b36798c' # BLEU of this trained model is 26.92 -URL_MODEL = 'http://paddlepaddle.bj.bcebos.com/demo/wmt_14/wmt14_model.tar.gz' -MD5_MODEL = '0cb4a5366189b6acba876491c8724fa3' +URL_MODEL = 'http://paddlemodels.bj.bcebos.com/wmt/wmt14.tgz' +MD5_MODEL = '0791583d57d5beb693b9414c5b36798c' START = "" END = "" @@ -54,7 +56,7 @@ def __read_to_dict(tar_file, dict_size): out_dict = dict() for line_count, line in enumerate(fd): if line_count < size: - out_dict[line.strip()] = line_count + out_dict[cpt.to_literal_str(line.strip())] = line_count else: break return out_dict @@ -85,7 +87,7 @@ def reader_creator(tar_file, file_name, dict_size): ] for name in names: for line in f.extractfile(name): - line_split = line.strip().split('\t') + line_split = line.strip().split(six.b('\t')) if len(line_split) != 2: continue src_seq = line_split[0] # one source sequence diff --git a/python/paddle/dataset/wmt16.py b/python/paddle/dataset/wmt16.py index 186f9476d8..cd34b523eb 100644 --- a/python/paddle/dataset/wmt16.py +++ b/python/paddle/dataset/wmt16.py @@ -35,6 +35,7 @@ import gzip from collections import defaultdict import paddle.dataset.common +import paddle.fluid.compat as cpt __all__ = [ "train", @@ -82,16 +83,16 @@ def __load_dict(tar_file, dict_size, lang, reverse=False): dict_path = os.path.join(paddle.dataset.common.DATA_HOME, "wmt16/%s_%d.dict" % (lang, dict_size)) if not os.path.exists(dict_path) or ( - len(open(dict_path, "r").readlines()) != dict_size): + len(open(dict_path, "rb").readlines()) != dict_size): __build_dict(tar_file, dict_size, dict_path, lang) word_dict = {} - with open(dict_path, "r") as fdict: + with open(dict_path, "rb") as fdict: for idx, line in enumerate(fdict): if reverse: - word_dict[idx] = line.strip() + word_dict[idx] = cpt.to_literal_str(line.strip()) else: - word_dict[line.strip()] = idx + word_dict[cpt.to_literal_str(line.strip())] = idx return word_dict diff --git a/python/paddle/fluid/tests/unittests/test_data_balance.py b/python/paddle/fluid/tests/unittests/test_data_balance.py index d3c7b6e714..09edf05fd7 100644 --- a/python/paddle/fluid/tests/unittests/test_data_balance.py +++ b/python/paddle/fluid/tests/unittests/test_data_balance.py @@ -14,7 +14,7 @@ import unittest import paddle.fluid as fluid -import paddle as paddle +import paddle import numpy as np diff --git a/python/paddle/fluid/tests/unittests/test_preprocessor.py b/python/paddle/fluid/tests/unittests/test_preprocessor.py index cbf1a7e0c5..6a82746c61 100644 --- a/python/paddle/fluid/tests/unittests/test_preprocessor.py +++ b/python/paddle/fluid/tests/unittests/test_preprocessor.py @@ -15,9 +15,9 @@ import unittest import numpy as np +import paddle import paddle.fluid as fluid -import paddle.v2 as paddle -import paddle.v2.dataset.mnist as mnist +import paddle.dataset.mnist as mnist class TestPreprocessor(unittest.TestCase): diff --git a/python/paddle/fluid/tests/unittests/test_profiler.py b/python/paddle/fluid/tests/unittests/test_profiler.py index 9f8d33f9bb..705d01165a 100644 --- a/python/paddle/fluid/tests/unittests/test_profiler.py +++ b/python/paddle/fluid/tests/unittests/test_profiler.py @@ -93,7 +93,7 @@ class TestProfiler(unittest.TestCase): "profiler is enabled only with GPU") def test_all_profiler(self): self.net_profiler('All', 
'/tmp/profile_out') - with open('/tmp/profile_out', 'r') as f: + with open('/tmp/profile_out', 'rb') as f: self.assertGreater(len(f.read()), 0) diff --git a/python/paddle/fluid/tests/unittests/test_protobuf_descs.py b/python/paddle/fluid/tests/unittests/test_protobuf_descs.py index 621dd68134..2176db71b9 100644 --- a/python/paddle/fluid/tests/unittests/test_protobuf_descs.py +++ b/python/paddle/fluid/tests/unittests/test_protobuf_descs.py @@ -14,6 +14,7 @@ import unittest import paddle.fluid.core as core +import paddle.fluid.compat as cpt from paddle.fluid.framework import Program @@ -108,7 +109,7 @@ class TestVarDesc(unittest.TestCase): def test_shape(self): program_desc = core.ProgramDesc() block = program_desc.block(0) - var = block.var('my_var') + var = block.var(cpt.to_bytes('my_var')) var.set_type(core.VarDesc.VarType.SELECTED_ROWS) src_shape = [3, 2, 10, 8] var.set_shape(src_shape) @@ -119,7 +120,7 @@ class TestVarDesc(unittest.TestCase): def test_multiple_shape(self): program_desc = core.ProgramDesc() block = program_desc.block(0) - var = block.var('my_reader') + var = block.var(cpt.to_bytes('my_reader')) var.set_type(core.VarDesc.VarType.READER) src_shapes = [[2, 3, 3], [4, 5], [6, 7, 8, 9]] var.set_shapes(src_shapes) @@ -130,7 +131,7 @@ class TestVarDesc(unittest.TestCase): def test_dtype(self): program_desc = core.ProgramDesc() block = program_desc.block(0) - var = block.var('my_var') + var = block.var(cpt.to_bytes('my_var')) var.set_type(core.VarDesc.VarType.LOD_TENSOR) var.set_dtype(core.VarDesc.VarType.INT32) self.assertEqual(core.VarDesc.VarType.INT32, var.dtype()) @@ -139,7 +140,7 @@ class TestVarDesc(unittest.TestCase): def test_multiple_dtype(self): program_desc = core.ProgramDesc() block = program_desc.block(0) - var = block.var('my_reader') + var = block.var(cpt.to_bytes('my_reader')) var.set_type(core.VarDesc.VarType.READER) src_types = [ core.VarDesc.VarType.INT32, core.VarDesc.VarType.FP64, @@ -152,7 +153,7 @@ class TestVarDesc(unittest.TestCase): def test_multiple_lod_level(self): program_desc = core.ProgramDesc() block = program_desc.block(0) - var = block.var('my_reader') + var = block.var(cpt.to_bytes('my_reader')) var.set_type(core.VarDesc.VarType.READER) src_types = [3, 1, 2] var.set_lod_levels(src_types) @@ -166,12 +167,12 @@ class TestBlockDesc(unittest.TestCase): self.assertIsNotNone(program_desc) block = program_desc.block(0) self.assertIsNotNone(block) - var1 = block.var("var1") - var2 = block.var("var2") - var3 = block.var("var3") + var1 = block.var(cpt.to_bytes("var1")) + var2 = block.var(cpt.to_bytes("var2")) + var3 = block.var(cpt.to_bytes("var3")) all_vars = block.all_vars() self.assertEqual(set(all_vars), {var1, var2, var3}) - var2_re = block.find_var("var2") + var2_re = block.find_var(cpt.to_bytes("var2")) self.assertEqual(var2_re, var2) def test_add_op(self): diff --git a/python/paddle/fluid/tests/unittests/test_reader_reset.py b/python/paddle/fluid/tests/unittests/test_reader_reset.py index d3ab991c84..698612acf4 100644 --- a/python/paddle/fluid/tests/unittests/test_reader_reset.py +++ b/python/paddle/fluid/tests/unittests/test_reader_reset.py @@ -13,7 +13,7 @@ # limitations under the License. 
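The pattern behind these reader and test fixes is worth spelling out: under Python 3 a file or tar member opened in binary mode yields bytes, so anything used as text afterwards must be decoded explicitly, while Python 2's str was already a byte string. A minimal sketch of the idiom, assuming a UTF-8 vocabulary file (the helper below is illustrative, not part of the patch):

    import six

    def load_vocab(path, size):
        # Open in binary mode so Python 2 and Python 3 read identical data,
        # then decode each line before using it as a dictionary key.
        vocab = {}
        with open(path, 'rb') as f:
            for idx, line in enumerate(f):
                if idx >= size:
                    break
                word = line.strip()
                if isinstance(word, six.binary_type):
                    word = word.decode('utf-8')
                vocab[word] = idx
        return vocab

This is the same idiom the wmt14/wmt16 dictionary loaders above adopt via mode "rb" plus cpt.to_literal_str.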
import paddle.fluid as fluid -import paddle as paddle +import paddle import numpy as np import unittest diff --git a/python/paddle/fluid/tests/unittests/test_recordio_reader.py b/python/paddle/fluid/tests/unittests/test_recordio_reader.py index 69a522e273..09c3167152 100644 --- a/python/paddle/fluid/tests/unittests/test_recordio_reader.py +++ b/python/paddle/fluid/tests/unittests/test_recordio_reader.py @@ -15,8 +15,8 @@ import unittest import paddle.fluid as fluid -import paddle.v2 as paddle -import paddle.v2.dataset.mnist as mnist +import paddle +import paddle.dataset.mnist as mnist class TestRecordIO(unittest.TestCase): From 3ec6d60c3491c033a3a67bf95f337ef0f70751b4 Mon Sep 17 00:00:00 2001 From: minqiyang Date: Wed, 8 Aug 2018 18:04:32 +0800 Subject: [PATCH 17/94] Fix write bytes in dataset download --- python/paddle/dataset/cifar.py | 5 +++++ python/paddle/dataset/common.py | 14 ++++++++++---- python/paddle/dataset/conll05.py | 13 +++++++------ python/paddle/fluid/executor.py | 2 +- .../test_image_classification_resnet.py | 2 +- 5 files changed, 24 insertions(+), 12 deletions(-) diff --git a/python/paddle/dataset/cifar.py b/python/paddle/dataset/cifar.py index f6b4ff8fbd..0e5bbfc45a 100644 --- a/python/paddle/dataset/cifar.py +++ b/python/paddle/dataset/cifar.py @@ -59,6 +59,11 @@ def reader_creator(filename, sub_name, cycle=False): while True: for name in names: + import sys + print(name) + sys.stdout.flush() + print(f.extractfile(name)) + sys.stdout.flush() batch = pickle.load(f.extractfile(name)) for item in read_batch(batch): yield item diff --git a/python/paddle/dataset/common.py b/python/paddle/dataset/common.py index 1161a57059..8abb4d2790 100644 --- a/python/paddle/dataset/common.py +++ b/python/paddle/dataset/common.py @@ -86,15 +86,21 @@ def download(url, module_name, md5sum, save_name=None): total_length = r.headers.get('content-length') if total_length is None: - with open(filename, 'w') as f: - shutil.copyfileobj(r.raw, f) + with open(filename, 'wb') as f: + import sys + print("write follow block") + sys.stdout.flush() + shutil.copyfileobj(cpt.to_bytes(r.raw), f) else: - with open(filename, 'w') as f: + with open(filename, 'wb') as f: + import sys + print("write follow length") + sys.stdout.flush() dl = 0 total_length = int(total_length) for data in r.iter_content(chunk_size=4096): dl += len(data) - f.write(cpt.to_literal_str(data)) + f.write(cpt.to_bytes(data)) done = int(50 * dl / total_length) sys.stdout.write("\r[%s%s]" % ('=' * done, ' ' * (50 - done))) diff --git a/python/paddle/dataset/conll05.py b/python/paddle/dataset/conll05.py index 724202b956..190688ba2c 100644 --- a/python/paddle/dataset/conll05.py +++ b/python/paddle/dataset/conll05.py @@ -24,19 +24,20 @@ import tarfile import gzip import itertools import paddle.dataset.common +import paddle.fluid.compat as cpt from six.moves import zip, range __all__ = ['test, get_dict', 'get_embedding', 'convert'] DATA_URL = 'http://www.cs.upc.edu/~srlconll/conll05st-tests.tar.gz' DATA_MD5 = '387719152ae52d60422c016e92a742fc' -WORDDICT_URL = 'http://paddlepaddle.bj.bcebos.com/demo/srl_dict_and_embedding/wordDict.txt' +WORDDICT_URL = 'http://paddlemodels.bj.bcebos.com/conll05st/wordDict.txt' WORDDICT_MD5 = 'ea7fb7d4c75cc6254716f0177a506baa' -VERBDICT_URL = 'http://paddlepaddle.bj.bcebos.com/demo/srl_dict_and_embedding/verbDict.txt' +VERBDICT_URL = 'http://paddlemodels.bj.bcebos.com/conll05st/verbDict.txt' VERBDICT_MD5 = '0d2977293bbb6cbefab5b0f97db1e77c' -TRGDICT_URL = 
'http://paddlepaddle.bj.bcebos.com/demo/srl_dict_and_embedding/targetDict.txt' +TRGDICT_URL = 'http://paddlemodels.bj.bcebos.com/conll05st/targetDict.txt' TRGDICT_MD5 = 'd8c7f03ceb5fc2e5a0fa7503a4353751' -EMB_URL = 'http://paddlepaddle.bj.bcebos.com/demo/srl_dict_and_embedding/emb' +EMB_URL = 'http://paddlemodels.bj.bcebos.com/conll05st/emb' EMB_MD5 = 'bf436eb0faa1f6f9103017f8be57cdb7' UNK_IDX = 0 @@ -89,8 +90,8 @@ def corpus_reader(data_path, words_name, props_name): labels = [] one_seg = [] for word, label in zip(words_file, props_file): - word = word.strip() - label = label.strip().split() + word = cpt.to_literal_str(word.strip()) + label = cpt.to_literal_str(label.strip().split()) if len(label) == 0: # end of sentence for i in range(len(one_seg[0])): diff --git a/python/paddle/fluid/executor.py b/python/paddle/fluid/executor.py index 8437a9f20f..a0cc7fac34 100644 --- a/python/paddle/fluid/executor.py +++ b/python/paddle/fluid/executor.py @@ -320,7 +320,7 @@ class Executor(object): # append fetch_operators if not has_fetch_operators(global_block, fetch_list, fetch_var_name): for i, var in enumerate(fetch_list): - assert isinstance(var, Variable) or isinstance(var, str), ( + assert isinstance(var, Variable) or isinstance(var, six.text_type), ( "Wrong type for fetch_list[%s]: %s" % (i, type(var))) global_block.append_op( type='fetch', diff --git a/python/paddle/fluid/tests/book/high-level-api/image_classification/test_image_classification_resnet.py b/python/paddle/fluid/tests/book/high-level-api/image_classification/test_image_classification_resnet.py index a1f62db093..54c59ac075 100644 --- a/python/paddle/fluid/tests/book/high-level-api/image_classification/test_image_classification_resnet.py +++ b/python/paddle/fluid/tests/book/high-level-api/image_classification/test_image_classification_resnet.py @@ -55,7 +55,7 @@ def resnet_cifar10(input, depth=32): return tmp assert (depth - 2) % 6 == 0 - n = (depth - 2) / 6 + n = (depth - 2) // 6 conv1 = conv_bn_layer( input=input, ch_out=16, filter_size=3, stride=1, padding=1) res1 = layer_warp(basicblock, conv1, 16, 16, n, 1) From 61052cdbc6cd048410aebb0df514fba6f8931347 Mon Sep 17 00:00:00 2001 From: chenweihang Date: Wed, 8 Aug 2018 10:22:36 +0000 Subject: [PATCH 18/94] polish high frequency enforce error message --- paddle/fluid/platform/enforce.h | 10 ++++++---- paddle/fluid/platform/gpu_info.cc | 20 ++++++++++---------- 2 files changed, 16 insertions(+), 14 deletions(-) diff --git a/paddle/fluid/platform/enforce.h b/paddle/fluid/platform/enforce.h index 566485cd3c..cad60275a2 100644 --- a/paddle/fluid/platform/enforce.h +++ b/paddle/fluid/platform/enforce.h @@ -263,7 +263,8 @@ inline void throw_on_error(T e) { * PADDLE_ENFORCE_EQ(a, b); * * will raise an expression described as follows: - * "enforce a == b failed, 1 != 2" with detailed stack information. + * "Data check failed. Expected input a == b, but received a(1) != b(2)." + * with detailed stack information. * * extra messages is also supported, for example: * PADDLE_ENFORCE(a, b, "some simple enforce failed between %d numbers", 2) @@ -292,9 +293,10 @@ inline void throw_on_error(T e) { #define __PADDLE_BINARY_COMPARE(__VAL0, __VAL1, __CMP, __INV_CMP, ...) \ do { \ if (UNLIKELY(!((__VAL0)__CMP(__VAL1)))) { \ - PADDLE_THROW("enforce %s " #__CMP " %s failed, %s " #__INV_CMP \ - " %s\n%s", \ - #__VAL0, #__VAL1, paddle::string::to_string(__VAL0), \ + PADDLE_THROW("Data check failed. 
Expected %s " #__CMP \ + " %s, but received %s:%s " #__INV_CMP " %s:%s.\n%s", \ + #__VAL0, #__VAL1, #__VAL0, \ + paddle::string::to_string(__VAL0), #__VAL1, \ paddle::string::to_string(__VAL1), \ paddle::string::Sprintf("" __VA_ARGS__)); \ } \ diff --git a/paddle/fluid/platform/gpu_info.cc b/paddle/fluid/platform/gpu_info.cc index 4cee93f3a4..f9e2e8c69d 100644 --- a/paddle/fluid/platform/gpu_info.cc +++ b/paddle/fluid/platform/gpu_info.cc @@ -100,25 +100,25 @@ size_t GpuMinChunkSize() { size_t GpuMaxChunkSize() { size_t total = 0; - size_t available = 0; + size_t available_memory = 0; - GpuMemoryUsage(&available, &total); - VLOG(10) << "GPU Usage " << available / 1024 / 1024 << "M/" + GpuMemoryUsage(&available_memory, &total); + VLOG(10) << "GPU Usage " << available_memory / 1024 / 1024 << "M/" << total / 1024 / 1024 << "M"; size_t reserving = static_cast(0.05 * total); // If available less than minimum chunk size, no usable memory exists. - available = - std::min(std::max(available, GpuMinChunkSize()) - GpuMinChunkSize(), - total - reserving); + available_memory = std::min( + std::max(available_memory, GpuMinChunkSize()) - GpuMinChunkSize(), + total - reserving); // Reserving the rest memory for page tables, etc. - size_t allocating = static_cast(FLAGS_fraction_of_gpu_memory_to_use * - (total - reserving)); + size_t allocating_memory = static_cast( + FLAGS_fraction_of_gpu_memory_to_use * (total - reserving)); - PADDLE_ENFORCE_LE(allocating, available); + PADDLE_ENFORCE_LE(allocating_memory, available_memory); - return allocating; + return allocating_memory; } void GpuMemcpyAsync(void *dst, const void *src, size_t count, From e102e5dd834314b413aee1f979782ee34d5ad4c4 Mon Sep 17 00:00:00 2001 From: minqiyang Date: Wed, 8 Aug 2018 21:40:48 +0800 Subject: [PATCH 19/94] Fix cifar10 decompress problem --- python/paddle/dataset/cifar.py | 21 ++++++++----------- python/paddle/dataset/common.py | 10 ++------- python/paddle/fluid/compat.py | 4 ++-- .../cifar10_small_test_set.py | 17 ++++++++------- 4 files changed, 23 insertions(+), 29 deletions(-) diff --git a/python/paddle/dataset/cifar.py b/python/paddle/dataset/cifar.py index 0e5bbfc45a..0d07462e68 100644 --- a/python/paddle/dataset/cifar.py +++ b/python/paddle/dataset/cifar.py @@ -32,7 +32,7 @@ import itertools import numpy import paddle.dataset.common import tarfile -from six.moves import zip +import six from six.moves import cPickle as pickle __all__ = ['train100', 'test100', 'train10', 'test10', 'convert'] @@ -46,25 +46,22 @@ CIFAR100_MD5 = 'eb9058c3a382ffc7106e4002c42a8d85' def reader_creator(filename, sub_name, cycle=False): def read_batch(batch): - data = batch['data'] - labels = batch.get('labels', batch.get('fine_labels', None)) + data = batch[six.b('data')] + labels = batch.get(six.b('labels'), batch.get(six.b('fine_labels'), None)) assert labels is not None - for sample, label in zip(data, labels): + for sample, label in six.moves.zip(data, labels): yield (sample / 255.0).astype(numpy.float32), int(label) def reader(): with tarfile.open(filename, mode='r') as f: - names = (each_item.name for each_item in f - if sub_name in each_item.name) + names = [each_item.name for each_item in f if sub_name in each_item.name] while True: for name in names: - import sys - print(name) - sys.stdout.flush() - print(f.extractfile(name)) - sys.stdout.flush() - batch = pickle.load(f.extractfile(name)) + if six.PY2: + batch = pickle.load(f.extractfile(name)) + else: + batch = pickle.load(f.extractfile(name), encoding='bytes') for item in 
read_batch(batch): yield item if not cycle: diff --git a/python/paddle/dataset/common.py b/python/paddle/dataset/common.py index 8abb4d2790..07e6b199c0 100644 --- a/python/paddle/dataset/common.py +++ b/python/paddle/dataset/common.py @@ -87,20 +87,14 @@ def download(url, module_name, md5sum, save_name=None): if total_length is None: with open(filename, 'wb') as f: - import sys - print("write follow block") - sys.stdout.flush() - shutil.copyfileobj(cpt.to_bytes(r.raw), f) + shutil.copyfileobj(r.raw, f) else: with open(filename, 'wb') as f: - import sys - print("write follow length") - sys.stdout.flush() dl = 0 total_length = int(total_length) for data in r.iter_content(chunk_size=4096): dl += len(data) - f.write(cpt.to_bytes(data)) + f.write(data) done = int(50 * dl / total_length) sys.stdout.write("\r[%s%s]" % ('=' * done, ' ' * (50 - done))) diff --git a/python/paddle/fluid/compat.py b/python/paddle/fluid/compat.py index 05633583cc..fe23a5929a 100644 --- a/python/paddle/fluid/compat.py +++ b/python/paddle/fluid/compat.py @@ -27,7 +27,7 @@ def to_literal_str(obj): def _to_literal_str(obj): if isinstance(obj, six.binary_type): - return obj.decode('latin-1') + return obj.decode('utf-8') elif isinstance(obj, six.text_type): return obj else: @@ -45,7 +45,7 @@ def to_bytes(obj): def _to_bytes(obj): if isinstance(obj, six.text_type): - return obj.encode('latin-1') + return obj.encode('utf-8') elif isinstance(obj, six.binary_type): return obj else: diff --git a/python/paddle/fluid/tests/book/high-level-api/image_classification/cifar10_small_test_set.py b/python/paddle/fluid/tests/book/high-level-api/image_classification/cifar10_small_test_set.py index e7b709f31b..9afac4143e 100644 --- a/python/paddle/fluid/tests/book/high-level-api/image_classification/cifar10_small_test_set.py +++ b/python/paddle/fluid/tests/book/high-level-api/image_classification/cifar10_small_test_set.py @@ -32,8 +32,8 @@ import itertools import numpy import paddle.dataset.common import tarfile +import six from six.moves import cPickle as pickle -from six.moves import zip __all__ = ['train10'] @@ -44,20 +44,23 @@ CIFAR10_MD5 = 'c58f30108f718f92721af3b95e74349a' def reader_creator(filename, sub_name, batch_size=None): def read_batch(batch): - data = batch['data'] - labels = batch.get('labels', batch.get('fine_labels', None)) + data = batch[six.b('data')] + labels = batch.get(six.b('labels'), batch.get(six.b('fine_labels'), None)) assert labels is not None - for sample, label in zip(data, labels): + for sample, label in six.moves.zip(data, labels): yield (sample / 255.0).astype(numpy.float32), int(label) def reader(): with tarfile.open(filename, mode='r') as f: - names = (each_item.name for each_item in f - if sub_name in each_item.name) + names = [each_item.name for each_item in f + if sub_name in each_item.name] batch_count = 0 for name in names: - batch = pickle.load(f.extractfile(name)) + if six.PY2: + batch = pickle.load(f.extractfile(name)) + else: + batch = pickle.load(f.extractfile(name), encoding='bytes') for item in read_batch(batch): if isinstance(batch_size, int) and batch_count > batch_size: break From 2a799e3bf24a869fec6a6c24968aa552a18f4ca0 Mon Sep 17 00:00:00 2001 From: minqiyang Date: Wed, 8 Aug 2018 22:47:36 +0800 Subject: [PATCH 20/94] Fix dist_se_resnet problem --- python/paddle/fluid/tests/unittests/dist_se_resnext.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/python/paddle/fluid/tests/unittests/dist_se_resnext.py b/python/paddle/fluid/tests/unittests/dist_se_resnext.py index 
7199ef1020..fc7422f12b 100644 --- a/python/paddle/fluid/tests/unittests/dist_se_resnext.py +++ b/python/paddle/fluid/tests/unittests/dist_se_resnext.py @@ -171,7 +171,7 @@ class SE_ResNeXt(): num_filters=num_filters, filter_size=filter_size, stride=stride, - padding=(filter_size - 1) / 2, + padding=(filter_size - 1) // 2, groups=groups, act=None, bias_attr=False) @@ -182,7 +182,7 @@ class SE_ResNeXt(): input=input, pool_size=0, pool_type='avg', global_pooling=True) stdv = 1.0 / math.sqrt(pool.shape[1] * 1.0) squeeze = fluid.layers.fc(input=pool, - size=num_channels / reduction_ratio, + size=num_channels // reduction_ratio, act='relu') stdv = 1.0 / math.sqrt(squeeze.shape[1] * 1.0) excitation = fluid.layers.fc(input=squeeze, From c2fce7dd24ebb97e635ecf9ca04a3a7a8e316343 Mon Sep 17 00:00:00 2001 From: minqiyang Date: Wed, 8 Aug 2018 22:52:54 +0800 Subject: [PATCH 21/94] Fix bos problem --- python/paddle/dataset/conll05.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/python/paddle/dataset/conll05.py b/python/paddle/dataset/conll05.py index 190688ba2c..b5a87d624d 100644 --- a/python/paddle/dataset/conll05.py +++ b/python/paddle/dataset/conll05.py @@ -29,7 +29,7 @@ from six.moves import zip, range __all__ = ['test, get_dict', 'get_embedding', 'convert'] -DATA_URL = 'http://www.cs.upc.edu/~srlconll/conll05st-tests.tar.gz' +DATA_URL = 'http://paddlemodels.bj.bcebos.com/conll05st/conll05st-tests.tar.gz' DATA_MD5 = '387719152ae52d60422c016e92a742fc' WORDDICT_URL = 'http://paddlemodels.bj.bcebos.com/conll05st/wordDict.txt' WORDDICT_MD5 = 'ea7fb7d4c75cc6254716f0177a506baa' From 46a2694633a5052236db953d640225046fe5d34b Mon Sep 17 00:00:00 2001 From: minqiyang Date: Wed, 8 Aug 2018 23:46:13 +0800 Subject: [PATCH 22/94] Fix exception problem --- python/paddle/fluid/compat.py | 8 ++++++++ python/paddle/fluid/tests/unittests/test_exception.py | 8 +++++--- 2 files changed, 13 insertions(+), 3 deletions(-) diff --git a/python/paddle/fluid/compat.py b/python/paddle/fluid/compat.py index fe23a5929a..6cb59d50a7 100644 --- a/python/paddle/fluid/compat.py +++ b/python/paddle/fluid/compat.py @@ -72,3 +72,11 @@ def round(x, d=0): def floor_division(x, y): return x // y + +# exception related functions +def get_exception_message(exc): + if six.PY2: + return exc.message + else: + return str(exc) + diff --git a/python/paddle/fluid/tests/unittests/test_exception.py b/python/paddle/fluid/tests/unittests/test_exception.py index bb7c0f88f6..6e4ea273a9 100644 --- a/python/paddle/fluid/tests/unittests/test_exception.py +++ b/python/paddle/fluid/tests/unittests/test_exception.py @@ -12,19 +12,21 @@ # See the License for the specific language governing permissions and # limitations under the License. 
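The / to // substitutions running through these patches all address the same Python 3 change: / is now true division and returns a float even for two int operands, while // floors on both versions. A small illustration (the values are made up for the example):

    # Python 2: 5 / 2 == 2          Python 3: 5 / 2 == 2.5
    # // stays integral on both versions, so shape arithmetic is safe.
    filter_size = 3
    padding = (filter_size - 1) // 2
    assert padding == 1 and isinstance(padding, int)

Passing a float where an operator expects an integral padding or channel count is exactly what broke these models under Python 3, hence the floor_division helper in compat.py and the direct // rewrites.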
+import paddle.fluid.compat as cpt import paddle.fluid.core as core import unittest class TestException(unittest.TestCase): def test_exception(self): - ex = None + exception = None try: core.__unittest_throw_exception__() except core.EnforceNotMet as ex: - self.assertIn("test exception", ex.message) + self.assertIn("test exception", cpt.get_exception_message(ex)) + exception = ex - self.assertIsNotNone(ex) + self.assertIsNotNone(exception) if __name__ == "__main__": From ee1d08abba7a2565c2ec1f4791254abdc32b95f8 Mon Sep 17 00:00:00 2001 From: minqiyang Date: Thu, 9 Aug 2018 00:52:25 +0800 Subject: [PATCH 23/94] Fix CI issues --- python/paddle/dataset/movielens.py | 11 ++++++-- python/paddle/dataset/wmt16.py | 2 +- python/paddle/fluid/compat.py | 28 +++++++++---------- python/paddle/fluid/op.py | 3 ++ .../unittests/test_conv3d_transpose_op.py | 2 +- .../fluid/tests/unittests/test_maxout_op.py | 2 +- 6 files changed, 27 insertions(+), 21 deletions(-) diff --git a/python/paddle/dataset/movielens.py b/python/paddle/dataset/movielens.py index 056ec21786..354b7d4aee 100644 --- a/python/paddle/dataset/movielens.py +++ b/python/paddle/dataset/movielens.py @@ -27,6 +27,8 @@ import paddle.dataset.common import re import random import functools +import six +import paddle.fluid.compat as cpt __all__ = [ 'train', 'test', 'get_movie_title_dict', 'max_movie_id', 'max_user_id', @@ -112,6 +114,7 @@ def __initialize_meta_info__(): categories_set = set() with package.open('ml-1m/movies.dat') as movie_file: for i, line in enumerate(movie_file): + line = cpt.to_literal_str(line, encoding='latin') movie_id, title, categories = line.strip().split('::') categories = categories.split('|') for c in categories: @@ -136,6 +139,7 @@ def __initialize_meta_info__(): USER_INFO = dict() with package.open('ml-1m/users.dat') as user_file: for line in user_file: + line = cpt.to_literal_str(line, encoding='latin') uid, gender, age, job, _ = line.strip().split("::") USER_INFO[int(uid)] = UserInfo( index=uid, gender=gender, age=age, job_id=job) @@ -148,6 +152,7 @@ def __reader__(rand_seed=0, test_ratio=0.1, is_test=False): with zipfile.ZipFile(file=fn) as package: with package.open('ml-1m/ratings.dat') as rating: for line in rating: + line = cpt.to_literal_str(line, encoding='latin') if (rand.random() < test_ratio) == is_test: uid, mov_id, rating, _ = line.strip().split("::") uid = int(uid) @@ -187,7 +192,7 @@ def max_movie_id(): Get the maximum value of movie id. """ __initialize_meta_info__() - return reduce(__max_index_info__, list(MOVIE_INFO.values())).index + return six.moves.reduce(__max_index_info__, list(MOVIE_INFO.values())).index def max_user_id(): @@ -195,7 +200,7 @@ def max_user_id(): Get the maximum value of user id. """ __initialize_meta_info__() - return reduce(__max_index_info__, list(USER_INFO.values())).index + return six.moves.reduce(__max_index_info__, list(USER_INFO.values())).index def __max_job_id_impl__(a, b): @@ -210,7 +215,7 @@ def max_job_id(): Get the maximum value of job id. 
""" __initialize_meta_info__() - return reduce(__max_job_id_impl__, list(USER_INFO.values())).job_id + return six.moves.reduce(__max_job_id_impl__, list(USER_INFO.values())).job_id def movie_categories(): diff --git a/python/paddle/dataset/wmt16.py b/python/paddle/dataset/wmt16.py index cd34b523eb..3e453a6479 100644 --- a/python/paddle/dataset/wmt16.py +++ b/python/paddle/dataset/wmt16.py @@ -62,7 +62,7 @@ def __build_dict(tar_file, dict_size, save_path, lang): word_dict = defaultdict(int) with tarfile.open(tar_file, mode="r") as f: for line in f.extractfile("wmt16/train"): - line_split = line.strip().split("\t") + line_split = line.strip().split(six.b("\t")) if len(line_split) != 2: continue sen = line_split[0] if lang == "en" else line_split[1] for w in sen.split(): diff --git a/python/paddle/fluid/compat.py b/python/paddle/fluid/compat.py index 6cb59d50a7..32f567253e 100644 --- a/python/paddle/fluid/compat.py +++ b/python/paddle/fluid/compat.py @@ -13,39 +13,40 @@ # limitations under the License. import six +import math # str and bytes related functions -def to_literal_str(obj): +def to_literal_str(obj, encoding='utf-8'): if isinstance(obj, list): - return [_to_literal_str(item) for item in obj] + return [_to_literal_str(item, encoding) for item in obj] elif isinstance(obj, set): - return set([_to_literal_str(item) for item in obj]) + return set([_to_literal_str(item, encoding) for item in obj]) else: - return _to_literal_str(obj) + return _to_literal_str(obj, encoding) -def _to_literal_str(obj): +def _to_literal_str(obj, encoding): if isinstance(obj, six.binary_type): - return obj.decode('utf-8') + return obj.decode(encoding) elif isinstance(obj, six.text_type): return obj else: return six.u(obj) -def to_bytes(obj): +def to_bytes(obj, encoding='utf-8'): if isinstance(obj, list): - return [_to_bytes(item) for item in obj] + return [_to_bytes(item, encoding) for item in obj] elif isinstance(obj, set): - return set([_to_bytes(item) for item in obj]) + return set([_to_bytes(item, encoding) for item in obj]) else: - return _to_bytes(obj) + return _to_bytes(obj, encoding) -def _to_bytes(obj): +def _to_bytes(obj, encoding): if isinstance(obj, six.text_type): - return obj.encode('utf-8') + return obj.encode(encoding) elif isinstance(obj, six.binary_type): return obj else: @@ -53,9 +54,6 @@ def _to_bytes(obj): # math related functions -import math - - def round(x, d=0): """ Compatible round which act the same behaviour in Python3. diff --git a/python/paddle/fluid/op.py b/python/paddle/fluid/op.py index 93f021a360..a2db5bad51 100644 --- a/python/paddle/fluid/op.py +++ b/python/paddle/fluid/op.py @@ -12,6 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. 
+import numpy as np import six import paddle.fluid.core as core @@ -99,6 +100,8 @@ class OpDescCreationMethod(object): new_attr = op_desc.attrs.add() new_attr.name = attr.name new_attr.type = attr.type + if isinstance(user_defined_attr, np.ndarray): + user_defined_attr = user_defined_attr.tolist() if attr.type == framework_pb2.INT: new_attr.i = user_defined_attr elif attr.type == framework_pb2.FLOAT: diff --git a/python/paddle/fluid/tests/unittests/test_conv3d_transpose_op.py b/python/paddle/fluid/tests/unittests/test_conv3d_transpose_op.py index 300fa5e8bd..2e55b89392 100644 --- a/python/paddle/fluid/tests/unittests/test_conv3d_transpose_op.py +++ b/python/paddle/fluid/tests/unittests/test_conv3d_transpose_op.py @@ -25,7 +25,7 @@ def conv3dtranspose_forward_naive(input_, filter_, attrs): groups = attrs['groups'] assert in_c == f_c out_c = f_out_c * groups - sub_in_c = in_c / groups + sub_in_c = in_c // groups stride, pad, dilations = attrs['strides'], attrs['paddings'], attrs[ 'dilations'] diff --git a/python/paddle/fluid/tests/unittests/test_maxout_op.py b/python/paddle/fluid/tests/unittests/test_maxout_op.py index f5ddf72516..2151853ae1 100644 --- a/python/paddle/fluid/tests/unittests/test_maxout_op.py +++ b/python/paddle/fluid/tests/unittests/test_maxout_op.py @@ -19,7 +19,7 @@ from op_test import OpTest def maxout_forward_naive(input, groups): s0, s1, s2, s3 = input.shape - return np.ndarray([s0, s1 / groups, groups, s2, s3], \ + return np.ndarray([s0, s1 // groups, groups, s2, s3], \ buffer = input, dtype=input.dtype).max(axis=(2)) From db7d8136a321c98eaf0b1a7b152c07736d1496fe Mon Sep 17 00:00:00 2001 From: minqiyang Date: Thu, 9 Aug 2018 01:15:28 +0800 Subject: [PATCH 24/94] Fix CI issue --- python/paddle/dataset/imdb.py | 12 ++++++------ python/paddle/fluid/framework.py | 7 +++---- .../test_memopt_image_classification_train.py | 2 +- .../transpiler/memory_optimization_transpiler.py | 11 ++++++----- 4 files changed, 16 insertions(+), 16 deletions(-) diff --git a/python/paddle/dataset/imdb.py b/python/paddle/dataset/imdb.py index 39fc29fdac..60a9062c46 100644 --- a/python/paddle/dataset/imdb.py +++ b/python/paddle/dataset/imdb.py @@ -25,7 +25,7 @@ import collections import tarfile import re import string -from six.moves import range +import six __all__ = ['build_dict', 'train', 'test', 'convert'] @@ -43,13 +43,13 @@ def tokenize(pattern): # sequential access of member files, other than # tarfile.extractfile, which does random access and might # destroy hard disks. - tf = next(tarf) + tf = tarf.next() while tf != None: if bool(pattern.match(tf.name)): # newline and punctuations removal and ad-hoc tokenization. - yield tarf.extractfile(tf).read().rstrip("\n\r").translate( - None, string.punctuation).lower().split() - tf = next(tarf) + yield tarf.extractfile(tf).read().rstrip(six.b("\n\r")).translate( + None, six.b(string.punctuation)).lower().split() + tf = tarf.next() def build_dict(pattern, cutoff): @@ -67,7 +67,7 @@ def build_dict(pattern, cutoff): dictionary = sorted(word_freq, key=lambda x: (-x[1], x[0])) words, _ = list(zip(*dictionary)) - word_idx = dict(list(zip(words, range(len(words))))) + word_idx = dict(list(zip(words, six.moves.range(len(words))))) word_idx['<unk>'] = len(words) return word_idx diff --git a/python/paddle/fluid/framework.py b/python/paddle/fluid/framework.py index 4e08836471..c3d2b7a4b2 100644 --- a/python/paddle/fluid/framework.py +++ b/python/paddle/fluid/framework.py @@ -905,10 +905,9 @@ class Block(object): Variable: the Variable with the given name. 
""" if not isinstance(name, six.string_types): - if not isinstance(name, six.binary_type): - raise TypeError( - "var require string as parameter, but get %s instead." % - (type(name))) + raise TypeError( + "var require string as parameter, but get %s instead." % + (type(name))) v = self.vars.get(name, None) if v is None: raise ValueError("var %s not in this block" % name) diff --git a/python/paddle/fluid/tests/book_memory_optimization/test_memopt_image_classification_train.py b/python/paddle/fluid/tests/book_memory_optimization/test_memopt_image_classification_train.py index b2a59d27da..8831dac336 100644 --- a/python/paddle/fluid/tests/book_memory_optimization/test_memopt_image_classification_train.py +++ b/python/paddle/fluid/tests/book_memory_optimization/test_memopt_image_classification_train.py @@ -56,7 +56,7 @@ def resnet_cifar10(input, depth=32): return tmp assert (depth - 2) % 6 == 0 - n = (depth - 2) / 6 + n = (depth - 2) // 6 conv1 = conv_bn_layer( input=input, ch_out=16, filter_size=3, stride=1, padding=1) res1 = layer_warp(basicblock, conv1, 16, 16, n, 1) diff --git a/python/paddle/fluid/transpiler/memory_optimization_transpiler.py b/python/paddle/fluid/transpiler/memory_optimization_transpiler.py index 20ba7ed2b0..c072ef0822 100644 --- a/python/paddle/fluid/transpiler/memory_optimization_transpiler.py +++ b/python/paddle/fluid/transpiler/memory_optimization_transpiler.py @@ -14,6 +14,7 @@ from collections import defaultdict from .. import core +from .. import compat from ..framework import Program, default_main_program, Parameter from ..backward import _rename_arg_ from functools import reduce @@ -125,15 +126,15 @@ class ControlFlowGraph(object): def _has_var(self, block_desc, var_name, is_forward): if is_forward: - return block_desc.has_var(str(var_name)) + return block_desc.has_var(cpt.to_bytes(var_name)) else: - return block_desc.has_var_recursive(str(var_name)) + return block_desc.has_var_recursive(cpt.to_bytes(var_name)) def _find_var(self, block_desc, var_name, is_forward): if is_forward: - return block_desc.find_var(str(var_name)) + return block_desc.find_var(cpt.to_bytes(var_name)) else: - return block_desc.find_var_recursive(str(var_name)) + return block_desc.find_var_recursive(cpt.to_bytes(var_name)) def _check_var_validity(self, block_desc, x, is_forward): if str(x) == "@EMPTY@": @@ -258,7 +259,7 @@ class ControlFlowGraph(object): # Rename the var to the cache var already with # memory allocated in order to reuse the memory. 
_rename_arg_(self._ops, x, cache_var, begin_idx=i) - self._program.block(block_desc.id).var(str( + self._program.block(block_desc.id).var(cpt.to_literal_str( x)).desc = self._find_var(block_desc, cache_var, is_forward) self._update_graph(x, cache_var, begin_idx=i) From b1dd4149b90dde40640de2baf0190d611cb24486 Mon Sep 17 00:00:00 2001 From: chenweihang Date: Thu, 9 Aug 2018 03:02:25 +0000 Subject: [PATCH 25/94] adjust enforce test cases --- paddle/fluid/platform/enforce_test.cc | 30 +++++++++++++++++---------- 1 file changed, 19 insertions(+), 11 deletions(-) diff --git a/paddle/fluid/platform/enforce_test.cc b/paddle/fluid/platform/enforce_test.cc index 0e8684581a..8dcf39fdaa 100644 --- a/paddle/fluid/platform/enforce_test.cc +++ b/paddle/fluid/platform/enforce_test.cc @@ -54,7 +54,9 @@ TEST(ENFORCE_EQ, NO_EXTRA_MSG_FAIL) { PADDLE_ENFORCE_EQ(a, 1 + 3); } catch (paddle::platform::EnforceNotMet error) { caught_exception = true; - HasPrefix(StringPiece(error.what()), "enforce a == 1 + 3 failed, 2 != 4"); + HasPrefix( + StringPiece(error.what()), + "Data check failed. Expected a == 1 + 3, but received a:2 != 1 + 3:4."); } EXPECT_TRUE(caught_exception); } @@ -67,7 +69,8 @@ TEST(ENFORCE_EQ, EXTRA_MSG_FAIL) { } catch (paddle::platform::EnforceNotMet error) { caught_exception = true; HasPrefix(StringPiece(error.what()), - "enforce a == 1 + 3 failed, 2 != 4\ntheir size not match"); + "Data check failed. Expected a == 1 + 3, but received a:2 != 1 + " + "3:4.\ntheir size not match"); } EXPECT_TRUE(caught_exception); } @@ -84,8 +87,9 @@ TEST(ENFORCE_NE, FAIL) { PADDLE_ENFORCE_NE(1.0, 1UL); } catch (paddle::platform::EnforceNotMet error) { caught_exception = true; - EXPECT_TRUE(HasPrefix(StringPiece(error.what()), - "enforce 1.0 != 1UL failed, 1 == 1")) + EXPECT_TRUE(HasPrefix( + StringPiece(error.what()), + "Data check failed. Expected 1.0 != 1UL, but received 1.0:1 == 1UL:1.")) << error.what() << " does not have expected prefix"; } EXPECT_TRUE(caught_exception); @@ -98,8 +102,9 @@ TEST(ENFORCE_GT, FAIL) { PADDLE_ENFORCE_GT(1, 2UL); } catch (paddle::platform::EnforceNotMet error) { caught_exception = true; - EXPECT_TRUE( - HasPrefix(StringPiece(error.what()), "enforce 1 > 2UL failed, 1 <= 2")); + EXPECT_TRUE(HasPrefix( + StringPiece(error.what()), + "Data check failed. Expected 1 > 2UL, but received 1:1 <= 2UL:2.")); } EXPECT_TRUE(caught_exception); } @@ -116,8 +121,9 @@ TEST(ENFORCE_GE, FAIL) { PADDLE_ENFORCE_GE(1, 2UL); } catch (paddle::platform::EnforceNotMet error) { caught_exception = true; - EXPECT_TRUE( - HasPrefix(StringPiece(error.what()), "enforce 1 >= 2UL failed, 1 < 2")); + EXPECT_TRUE(HasPrefix( + StringPiece(error.what()), + "Data check failed. Expected 1 >= 2UL, but received 1:1 < 2UL:2.")); } EXPECT_TRUE(caught_exception); } @@ -135,8 +141,9 @@ TEST(ENFORCE_LE, FAIL) { PADDLE_ENFORCE_GT(1, 2UL); } catch (paddle::platform::EnforceNotMet error) { caught_exception = true; - EXPECT_TRUE( - HasPrefix(StringPiece(error.what()), "enforce 1 > 2UL failed, 1 <= 2")); + EXPECT_TRUE(HasPrefix( + StringPiece(error.what()), + "Data check failed. Expected 1 > 2UL, but received 1:1 <= 2UL:2.")); } EXPECT_TRUE(caught_exception); } @@ -153,7 +160,8 @@ TEST(ENFORCE_LT, FAIL) { } catch (paddle::platform::EnforceNotMet error) { caught_exception = true; EXPECT_TRUE(HasPrefix(StringPiece(error.what()), - "enforce 1UL < 0.12 failed, 1 >= 0.12")); + "Data check failed. 
Expected 1UL < 0.12, but " "received 1UL:1 >= 0.12:0.12.")); } EXPECT_TRUE(caught_exception); } From c3fdf3aee4ff397982fa0433db4af7da9794047b Mon Sep 17 00:00:00 2001 From: minqiyang Date: Thu, 9 Aug 2018 15:19:24 +0800 Subject: [PATCH 26/94] Fix divide problem in CI Fix pb_protobuf2 FromString problem --- paddle/fluid/framework/op_desc.cc | 1 + python/paddle/fluid/backward.py | 2 +- python/paddle/fluid/debugger.py | 2 +- python/paddle/fluid/framework.py | 5 ++++- .../fluid/tests/unittests/test_conv2d_op.py | 22 +++++++++---------- .../unittests/test_conv2d_transpose_op.py | 4 ++-- .../fluid/tests/unittests/test_lrn_op.py | 2 +- .../tests/unittests/test_operator_desc.py | 5 +++-- .../test_parallel_executor_seresnext.py | 4 ++-- .../memory_optimization_transpiler.py | 2 +- 10 files changed, 27 insertions(+), 22 deletions(-) diff --git a/paddle/fluid/framework/op_desc.cc b/paddle/fluid/framework/op_desc.cc index 984ea3a3dd..c473b11292 100644 --- a/paddle/fluid/framework/op_desc.cc +++ b/paddle/fluid/framework/op_desc.cc @@ -202,6 +202,7 @@ std::vector<std::string> OpDesc::AttrNames() const { } void OpDesc::SetAttr(const std::string &name, const Attribute &v) { + VLOG(11) << "SetAttr: " << Type() << ", " << name << ", " << v.which(); // NOTICE(minqiyang): pybind11 will take the empty list in python as // the std::vector<int> type in C++; so we have to change the attr's type // here if we meet this issue diff --git a/python/paddle/fluid/backward.py b/python/paddle/fluid/backward.py index 6430d3a264..804608827b 100644 --- a/python/paddle/fluid/backward.py +++ b/python/paddle/fluid/backward.py @@ -243,7 +243,7 @@ from .proto import framework_pb2 def serialize_op_decs(op_desc): protostr = op_desc.serialize_to_string() - proto = framework_pb2.OpDesc.FromString(str(protostr)) + proto = framework_pb2.OpDesc.FromString(six.binary_type(protostr)) return proto.__str__() diff --git a/python/paddle/fluid/debugger.py b/python/paddle/fluid/debugger.py index b7a92cf044..dd8523f95b 100644 --- a/python/paddle/fluid/debugger.py +++ b/python/paddle/fluid/debugger.py @@ -225,7 +225,7 @@ def draw_block_graphviz(block, highlights=None, path="./temp.dot"): graph = GraphPreviewGenerator("some graph") # collect parameters and args protostr = block.desc.serialize_to_string() - desc = framework_pb2.BlockDesc.FromString(str(protostr)) + desc = framework_pb2.BlockDesc.FromString(six.binary_type(protostr)) def need_highlight(name): if highlights is None: return False diff --git a/python/paddle/fluid/framework.py b/python/paddle/fluid/framework.py index c3d2b7a4b2..af80cd9ca1 100644 --- a/python/paddle/fluid/framework.py +++ b/python/paddle/fluid/framework.py @@ -28,7 +28,7 @@ except ImportError as e: """NOTE: You may need to run \"export LD_LIBRARY_PATH=/usr/local/lib:$LD_LIBRARY_PATH\" if you encounter \"libmkldnn.so not found\" errors. If you have python installed in other directory, replace \"/usr/local/lib\" with your own - directory. The original error is: \n""" + e.message) + directory. The original error is: \n""" + cpt.get_exception_message(e)) except Exception as e: raise e from . 
import unique_name @@ -574,6 +574,9 @@ class Operator(object): attr_val = self.attrs[attr_name] self._update_desc_attr(attr_name, attr_val) + import sys + print('self.attrs', self.attrs) + sys.stdout.flush() self.desc.check_attrs() if self.has_kernel(type): self.desc.infer_var_type(self.block.desc) diff --git a/python/paddle/fluid/tests/unittests/test_conv2d_op.py b/python/paddle/fluid/tests/unittests/test_conv2d_op.py index a478649541..bdfa17ebc9 100644 --- a/python/paddle/fluid/tests/unittests/test_conv2d_op.py +++ b/python/paddle/fluid/tests/unittests/test_conv2d_op.py @@ -24,12 +24,12 @@ def conv2d_forward_naive(input, filter, group, conv_param): out_c, f_c, f_h, f_w = filter.shape assert f_c * group == in_c assert np.mod(out_c, group) == 0 - sub_out_c = out_c / group + sub_out_c = out_c // group stride, pad, dilation = conv_param['stride'], conv_param['pad'], conv_param[ 'dilation'] - out_h = 1 + (in_h + 2 * pad[0] - (dilation[0] * (f_h - 1) + 1)) / stride[0] - out_w = 1 + (in_w + 2 * pad[1] - (dilation[1] * (f_w - 1) + 1)) / stride[1] + out_h = 1 + (in_h + 2 * pad[0] - (dilation[0] * (f_h - 1) + 1)) // stride[0] + out_w = 1 + (in_w + 2 * pad[1] - (dilation[1] * (f_w - 1) + 1)) // stride[1] out = np.zeros((in_n, out_c, out_h, out_w)) d_bolck_h = (dilation[0] * (f_h - 1) + 1) @@ -160,7 +160,7 @@ class TestConv2dOp(OpTest): self.stride = [1, 1] self.input_size = [2, 3, 5, 5] # NCHW assert np.mod(self.input_size[1], self.groups) == 0 - f_c = self.input_size[1] / self.groups + f_c = self.input_size[1] // self.groups self.filter_size = [6, f_c, 3, 3] def init_dilation(self): @@ -179,7 +179,7 @@ class TestWithPad(TestConv2dOp): self.stride = [1, 1] self.input_size = [2, 3, 5, 5] # NCHW assert np.mod(self.input_size[1], self.groups) == 0 - f_c = self.input_size[1] / self.groups + f_c = self.input_size[1] // self.groups self.filter_size = [6, f_c, 3, 3] @@ -189,7 +189,7 @@ class TestWithStride(TestConv2dOp): self.stride = [2, 2] self.input_size = [2, 3, 6, 6] # NCHW assert np.mod(self.input_size[1], self.groups) == 0 - f_c = self.input_size[1] / self.groups + f_c = self.input_size[1] // self.groups self.filter_size = [6, f_c, 3, 3] @@ -204,7 +204,7 @@ class TestWith1x1(TestConv2dOp): self.stride = [1, 1] self.input_size = [2, 3, 5, 5] # NCHW assert np.mod(self.input_size[1], self.groups) == 0 - f_c = self.input_size[1] / self.groups + f_c = self.input_size[1] // self.groups self.filter_size = [6, f_c, 1, 1] def init_group(self): @@ -217,7 +217,7 @@ class TestWithDilation(TestConv2dOp): self.stride = [1, 1] self.input_size = [2, 3, 10, 10] # NCHW assert np.mod(self.input_size[1], self.groups) == 0 - f_c = self.input_size[1] / self.groups + f_c = self.input_size[1] // self.groups self.filter_size = [6, f_c, 3, 3] def init_dilation(self): @@ -233,7 +233,7 @@ class TestWithInput1x1Filter1x1(TestConv2dOp): self.stride = [1, 1] self.input_size = [2, 3, 1, 1] # NCHW assert np.mod(self.input_size[1], self.groups) == 0 - f_c = self.input_size[1] / self.groups + f_c = self.input_size[1] // self.groups self.filter_size = [6, f_c, 1, 1] def init_group(self): @@ -350,7 +350,7 @@ class TestDepthwiseConv(TestConv2dOp): self.input_size = [2, 3, 5, 5] # NCHW self.groups = 3 assert np.mod(self.input_size[1], self.groups) == 0 - f_c = self.input_size[1] / self.groups + f_c = self.input_size[1] // self.groups self.filter_size = [6, f_c, 3, 3] self.op_type = "depthwise_conv2d" @@ -362,7 +362,7 @@ class TestDepthwiseConv2(TestConv2dOp): self.input_size = [2, 3, 5, 5] # NCHW self.groups = 3 assert 
np.mod(self.input_size[1], self.groups) == 0 - f_c = self.input_size[1] / self.groups + f_c = self.input_size[1] // self.groups self.filter_size = [6, f_c, 3, 3] self.op_type = "depthwise_conv2d" diff --git a/python/paddle/fluid/tests/unittests/test_conv2d_transpose_op.py b/python/paddle/fluid/tests/unittests/test_conv2d_transpose_op.py index af6cd99b0d..1cb50afca5 100644 --- a/python/paddle/fluid/tests/unittests/test_conv2d_transpose_op.py +++ b/python/paddle/fluid/tests/unittests/test_conv2d_transpose_op.py @@ -25,7 +25,7 @@ def conv2dtranspose_forward_naive(input_, filter_, attrs): groups = attrs['groups'] assert in_c == f_c out_c = f_out_c * groups - sub_in_c = in_c / groups + sub_in_c = in_c // groups stride, pad, dilations = attrs['strides'], attrs['paddings'], attrs[ 'dilations'] @@ -258,7 +258,7 @@ class TestDepthwiseConvTranspose(TestConv2dTransposeOp): self.input_size = [2, 8, 16, 16] # NCHW self.groups = 8 assert np.mod(self.input_size[1], self.groups) == 0 - f_c = self.input_size[1] / self.groups + f_c = self.input_size[1] // self.groups self.filter_size = [self.input_size[1], f_c, 4, 4] self.op_type = "depthwise_conv2d_transpose" diff --git a/python/paddle/fluid/tests/unittests/test_lrn_op.py b/python/paddle/fluid/tests/unittests/test_lrn_op.py index eaff45cbb2..b0930440f2 100644 --- a/python/paddle/fluid/tests/unittests/test_lrn_op.py +++ b/python/paddle/fluid/tests/unittests/test_lrn_op.py @@ -34,7 +34,7 @@ class TestLRNOp(OpTest): return x + 1 def get_out(self): - start = -(self.n - 1) / 2 + start = -(self.n - 1) // 2 end = start + self.n mid = np.empty((self.N, self.C, self.H, self.W)).astype("float32") diff --git a/python/paddle/fluid/tests/unittests/test_operator_desc.py b/python/paddle/fluid/tests/unittests/test_operator_desc.py index c098a5a0cb..cbdbd957a5 100644 --- a/python/paddle/fluid/tests/unittests/test_operator_desc.py +++ b/python/paddle/fluid/tests/unittests/test_operator_desc.py @@ -15,6 +15,7 @@ import unittest import paddle.fluid.core as core +import paddle.fluid.compat as cpt from paddle.fluid.framework import Program, default_startup_program @@ -29,13 +30,13 @@ class TestOperator(unittest.TestCase): self.assertFail() except ValueError as v_err: self.assertEqual( - v_err.message, + cpt.get_exception_message(v_err), "`type` to initilized an Operator can not be None.") try: block.append_op(type="no_such_op") self.assertFail() except ValueError as a_err: - self.assertEqual(a_err.message, + self.assertEqual(cpt.get_exception_message(a_err), "Operator \"no_such_op\" has not been registered.") def test_op_desc_creation(self): diff --git a/python/paddle/fluid/tests/unittests/test_parallel_executor_seresnext.py b/python/paddle/fluid/tests/unittests/test_parallel_executor_seresnext.py index b04c24b9bd..4b4e5e6898 100644 --- a/python/paddle/fluid/tests/unittests/test_parallel_executor_seresnext.py +++ b/python/paddle/fluid/tests/unittests/test_parallel_executor_seresnext.py @@ -33,7 +33,7 @@ def squeeze_excitation(input, num_channels, reduction_ratio): pool = fluid.layers.reduce_mean(input=reshape, dim=2) squeeze = fluid.layers.fc(input=pool, - size=num_channels / reduction_ratio, + size=num_channels // reduction_ratio, act='relu') excitation = fluid.layers.fc(input=squeeze, size=num_channels, @@ -49,7 +49,7 @@ def conv_bn_layer(input, num_filters, filter_size, stride=1, groups=1, num_filters=num_filters, filter_size=filter_size, stride=stride, - padding=(filter_size - 1) / 2, + padding=(filter_size - 1) // 2, groups=groups, act=None, bias_attr=False) diff --git 
a/python/paddle/fluid/transpiler/memory_optimization_transpiler.py b/python/paddle/fluid/transpiler/memory_optimization_transpiler.py index c072ef0822..2732030b95 100644 --- a/python/paddle/fluid/transpiler/memory_optimization_transpiler.py +++ b/python/paddle/fluid/transpiler/memory_optimization_transpiler.py @@ -14,7 +14,7 @@ from collections import defaultdict from .. import core -from .. import compat +from .. import compat as cpt from ..framework import Program, default_main_program, Parameter from ..backward import _rename_arg_ from functools import reduce From 09103084d332c2fc794e07179b02d39d7027c17b Mon Sep 17 00:00:00 2001 From: minqiyang Date: Thu, 9 Aug 2018 21:42:41 +0800 Subject: [PATCH 27/94] Polish compat.py and add unittest for it --- paddle/fluid/framework/op_desc.cc | 1 - python/paddle/fluid/compat.py | 162 +++++- python/paddle/fluid/layers/detection.py | 9 +- .../fluid/tests/unittests/test_compat.py | 490 ++++++++++++++++++ 4 files changed, 650 insertions(+), 12 deletions(-) create mode 100644 python/paddle/fluid/tests/unittests/test_compat.py diff --git a/paddle/fluid/framework/op_desc.cc b/paddle/fluid/framework/op_desc.cc index c473b11292..984ea3a3dd 100644 --- a/paddle/fluid/framework/op_desc.cc +++ b/paddle/fluid/framework/op_desc.cc @@ -202,7 +202,6 @@ std::vector<std::string> OpDesc::AttrNames() const { } void OpDesc::SetAttr(const std::string &name, const Attribute &v) { - VLOG(11) << "SetAttr: " << Type() << ", " << name << ", " << v.which(); // NOTICE(minqiyang): pybind11 will take the empty list in python as // the std::vector<int> type in C++; so we have to change the attr's type // here if we meet this issue diff --git a/python/paddle/fluid/compat.py b/python/paddle/fluid/compat.py index 32f567253e..16932a2a0c 100644 --- a/python/paddle/fluid/compat.py +++ b/python/paddle/fluid/compat.py @@ -15,18 +15,77 @@ import six import math +__all__ = [ + 'to_literal_str', + 'to_bytes', + 'round', + 'floor_division', + 'get_exception_message', +] # str and bytes related functions -def to_literal_str(obj, encoding='utf-8'): +def to_literal_str(obj, encoding='utf-8', inplace=False): + """ + All strings in PaddlePaddle should be represented as literal strings. + This function will convert the object to a literal string without any encoding. + Especially, if the object type is a list or set container, we will iterate + all items in the object and convert them to literal strings. + + In Python3: + Decode the bytes type object to str type with specific encoding + + In Python2: + Decode the str type object to unicode type with specific encoding + + Args: + obj(unicode|str|bytes|list|set) : The object to be decoded. 
+ encoding(str) : The encoding format to decode a string + inplace(bool) : Whether to modify the original object in place or to create a new one + + Returns: + Decoded result of obj + """ + if obj is None: + return obj + + if isinstance(obj, list): - return [_to_literal_str(item, encoding) for item in obj] + if inplace: + for i in six.moves.xrange(len(obj)): + obj[i] = _to_literal_str(obj[i], encoding) + return obj + else: + return [_to_literal_str(item, encoding) for item in obj] elif isinstance(obj, set): - return set([_to_literal_str(item, encoding) for item in obj]) + if inplace: + for item in obj: + obj.remove(item) + obj.add(_to_literal_str(item, encoding)) + return obj + else: + return set([_to_literal_str(item, encoding) for item in obj]) else: return _to_literal_str(obj, encoding) def _to_literal_str(obj, encoding): + """ + In Python3: Decode the bytes type object to str type with specific encoding + + In Python2: + Decode the str type object to unicode type with specific encoding, + or we just return the unicode string of object + + Args: + obj(unicode|str|bytes) : The object to be decoded. + encoding(str) : The encoding format + + Returns: + decoded result of obj + """ + if obj is None: + return obj + if isinstance(obj, six.binary_type): return obj.decode(encoding) elif isinstance(obj, six.text_type): @@ -35,16 +94,70 @@ def _to_literal_str(obj, encoding): return six.u(obj) -def to_bytes(obj, encoding='utf-8'): +def to_bytes(obj, encoding='utf-8', inplace=False): + """ + All strings in PaddlePaddle should be represented as literal strings. + This function will convert the object to bytes with a specific encoding. + Especially, if the object type is a list or set container, we will iterate + all items in the object and convert them to bytes. + + In Python3: + Encode the str type object to bytes type with specific encoding + + In Python2: + Encode the unicode type object to str type with specific encoding, + or we just return the 8-bit string of object + + Args: + obj(unicode|str|bytes|list|set) : The object to be encoded. + encoding(str) : The encoding format to encode a string + inplace(bool) : Whether to modify the original object in place or to create a new one + + Returns: + Encoded result of obj + """ + if obj is None: + return obj + + if isinstance(obj, list): - return [_to_bytes(item, encoding) for item in obj] + if inplace: + for i in six.moves.xrange(len(obj)): + obj[i] = _to_bytes(obj[i], encoding) + return obj + else: + return [_to_bytes(item, encoding) for item in obj] elif isinstance(obj, set): - return set([_to_bytes(item, encoding) for item in obj]) + if inplace: + for item in obj: + obj.remove(item) + obj.add(_to_bytes(item, encoding)) + return obj + else: + return set([_to_bytes(item, encoding) for item in obj]) else: return _to_bytes(obj, encoding) def _to_bytes(obj, encoding): + """ + In Python3: + Encode the str type object to bytes type with specific encoding + + In Python2: + Encode the unicode type object to str type with specific encoding, + or we just return the 8-bit string of object + + Args: + obj(unicode|str|bytes) : The object to be encoded. 
+ encoding(str) : The encoding format + + Returns: + encoded result of obj + """ + if obj is None: + return obj + + assert encoding is not None if isinstance(obj, six.text_type): return obj.encode(encoding) elif isinstance(obj, six.binary_type): @@ -64,15 +177,48 @@ def round(x, d=0): Returns: round result of x """ - p = 10**d - return float(math.floor((x * p) + math.copysign(0.5, x))) / p + if six.PY3: + # The official workaround for round in Python3 is incorrect + # we implement it according to this answer: https://www.techforgeek.info/round_python.html + if x > 0.0: + p = 10 ** d + return float(math.floor((x * p) + math.copysign(0.5, x))) / p + else: + p = 10 ** d + return float(math.ceil((x * p) + math.copysign(0.5, x))) / p + else: + import __builtin__ + return __builtin__.round(x, d) def floor_division(x, y): + """ + Compatible division which behaves the same way in Python3 and Python2, + whose result will be an int value of floor(x / y) in Python3 and the value of + (x / y) in Python2. + + Args: + x(int|float) : The number to divide. + y(int|float) : The number to be divided + + Returns: + division result of x // y + """ return x // y # exception related functions def get_exception_message(exc): + """ + Get the error message of a specific exception + + Args: + exc(Exception) : The exception to get the error message from. + + Returns: + the error message of exc + """ + assert exc is not None + if six.PY2: return exc.message else: diff --git a/python/paddle/fluid/layers/detection.py b/python/paddle/fluid/layers/detection.py index c11455b7a6..e21a7a3ddd 100644 --- a/python/paddle/fluid/layers/detection.py +++ b/python/paddle/fluid/layers/detection.py @@ -20,6 +20,7 @@ from .layer_function_generator import autodoc, templatedoc from ..layer_helper import LayerHelper from . import tensor from . import nn +from .. import compat as cpt import math import six from functools import reduce @@ -1104,7 +1105,8 @@ def multi_box_head(inputs, mbox_loc = nn.transpose(mbox_loc, perm=[0, 2, 3, 1]) new_shape = [ mbox_loc.shape[0], - mbox_loc.shape[1] * mbox_loc.shape[2] * mbox_loc.shape[3] / 4, 4 + mbox_loc.shape[1] * mbox_loc.shape[2] * cpt.floor_division(mbox_loc.shape[3], 4), + 4 ] mbox_loc_flatten = nn.reshape(mbox_loc, shape=new_shape) mbox_locs.append(mbox_loc_flatten) @@ -1119,8 +1121,9 @@ def multi_box_head(inputs, stride=stride) conf_loc = nn.transpose(conf_loc, perm=[0, 2, 3, 1]) new_shape = [ - conf_loc.shape[0], conf_loc.shape[1] * conf_loc.shape[2] * - conf_loc.shape[3] / num_classes, num_classes + conf_loc.shape[0], + conf_loc.shape[1] * conf_loc.shape[2] * cpt.floor_division(conf_loc.shape[3], num_classes), + num_classes ] conf_loc_flatten = nn.reshape(conf_loc, shape=new_shape) mbox_confs.append(conf_loc_flatten) diff --git a/python/paddle/fluid/tests/unittests/test_compat.py b/python/paddle/fluid/tests/unittests/test_compat.py new file mode 100644 index 0000000000..0725d2c49a --- /dev/null +++ b/python/paddle/fluid/tests/unittests/test_compat.py @@ -0,0 +1,490 @@ +# Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+# See the License for the specific language governing permissions and +# limitations under the License. + +import unittest +import paddle.fluid.compat as cpt +import six + + +class TestCompatible(unittest.TestCase): + def test_to_literal_str(self): + # Only support python2.x and python3.x now + self.assertTrue(six.PY2 | six.PY3) + + if six.PY2: + # check None + self.assertIsNone(cpt.to_literal_str(None)) + + # check all string related types + self.assertTrue(isinstance(cpt.to_literal_str(str("")), unicode)) + self.assertTrue(isinstance(cpt.to_literal_str(str("123")), unicode)) + self.assertTrue(isinstance(cpt.to_literal_str(b""), unicode)) + self.assertTrue(isinstance(cpt.to_literal_str(b""), unicode)) + self.assertTrue(isinstance(cpt.to_literal_str(u""), unicode)) + self.assertTrue(isinstance(cpt.to_literal_str(u""), unicode)) + + self.assertEqual(u"", cpt.to_literal_str(str(""))) + self.assertEqual(u"123", cpt.to_literal_str(str("123"))) + self.assertEqual(u"", cpt.to_literal_str(b"")) + self.assertEqual(u"123", cpt.to_literal_str(b"123")) + self.assertEqual(u"", cpt.to_literal_str(u"")) + self.assertEqual(u"123", cpt.to_literal_str(u"123")) + + # check list types, not inplace + l = [""] + l2 = cpt.to_literal_str(l) + self.assertTrue(isinstance(l2, list)) + self.assertFalse(l is l2) + self.assertEqual(l, l2) + self.assertEqual([u""], l2) + l = ["", "123"] + l2 = cpt.to_literal_str(l) + self.assertTrue(isinstance(l2, list)) + self.assertFalse(l is l2) + self.assertEqual(l, l2) + self.assertEqual([u"", u"123"], l2) + l = ["", b'123', u"321"] + l2 = cpt.to_literal_str(l) + self.assertTrue(isinstance(l2, list)) + self.assertFalse(l is l2) + self.assertEqual(l, l2) + self.assertEqual([u"", u"123", u"321"], l2) + for i in l2: + self.assertTrue(isinstance(i, unicode)) + + + # check list types, inplace + l = [""] + l2 = cpt.to_literal_str(l, inplace=True) + self.assertTrue(isinstance(l2, list)) + self.assertTrue(l is l2) + self.assertEqual(l, l2) + self.assertEqual([u""], l2) + l = ["", "123"] + l2 = cpt.to_literal_str(l, inplace=True) + self.assertTrue(isinstance(l2, list)) + self.assertTrue(l is l2) + self.assertEqual(l, l2) + self.assertEqual([u"", u"123"], l2) + l = ["", b"123", u"321"] + l2 = cpt.to_literal_str(l, inplace=True) + self.assertTrue(isinstance(l2, list)) + self.assertTrue(l is l2) + self.assertEqual(l, l2) + self.assertEqual([u"", u"123", u"321"], l2) + + # check set types, not inplace + l = set("") + l2 = cpt.to_literal_str(l, inplace=False) + self.assertTrue(isinstance(l2, set)) + self.assertFalse(l is l2) + self.assertEqual(l, l2) + self.assertEqual(set(u""), l2) + l = set([b"", b"123"]) + l2 = cpt.to_literal_str(l, inplace=False) + self.assertTrue(isinstance(l2, set)) + self.assertFalse(l is l2) + self.assertEqual(l, l2) + self.assertEqual(set([u"", u"123"]), l2) + l = set(["", b"123", u"321"]) + l2 = cpt.to_literal_str(l, inplace=False) + self.assertTrue(isinstance(l2, set)) + self.assertFalse(l is l2) + self.assertEqual(l, l2) + self.assertEqual(set([u"", u"123", u"321"]), l2) + for i in l2: + self.assertTrue(isinstance(i, unicode)) + + # check set types, inplace + l = set("") + l2 = cpt.to_literal_str(l, inplace=True) + self.assertTrue(isinstance(l2, set)) + self.assertTrue(l is l2) + self.assertEqual(l, l2) + self.assertEqual(set(u""), l2) + l = set([b"", b"123"]) + l2 = cpt.to_literal_str(l, inplace=True) + self.assertTrue(isinstance(l2, set)) + self.assertTrue(l is l2) + self.assertEqual(l, l2) + self.assertEqual(set([u"", u"123"]), l2) + l = set(["", b"123", 
u"321"]) + l2 = cpt.to_literal_str(l, inplace=True) + self.assertTrue(isinstance(l2, set)) + self.assertTrue(l is l2) + self.assertEqual(l, l2) + self.assertEqual(set([u"", u"123", u"321"]), l2) + + elif six.PY3: + self.assertIsNone(cpt.to_literal_str(None)) + + self.assertTrue(isinstance(cpt.to_literal_str(str("")), str)) + self.assertTrue(isinstance(cpt.to_literal_str(str("123")), str)) + self.assertTrue(isinstance(cpt.to_literal_str(b""), str)) + self.assertTrue(isinstance(cpt.to_literal_str(b""), str)) + self.assertTrue(isinstance(cpt.to_literal_str(u""), str)) + self.assertTrue(isinstance(cpt.to_literal_str(u""), str)) + + self.assertEqual("", cpt.to_literal_str(str(""))) + self.assertEqual("123", cpt.to_literal_str(str("123"))) + self.assertEqual("", cpt.to_literal_str(b"")) + self.assertEqual("123", cpt.to_literal_str(b"123")) + self.assertEqual("", cpt.to_literal_str(u"")) + self.assertEqual("123", cpt.to_literal_str(u"123")) + + # check list types, not inplace + l = [""] + l2 = cpt.to_literal_str(l) + self.assertTrue(isinstance(l2, list)) + self.assertFalse(l is l2) + self.assertEqual(l, l2) + self.assertEqual([""], l2) + l = ["", "123"] + l2 = cpt.to_literal_str(l) + self.assertTrue(isinstance(l2, list)) + self.assertFalse(l is l2) + self.assertEqual(l, l2) + self.assertEqual(["", "123"], l2) + l = ["", b"123", u"321"] + l2 = cpt.to_literal_str(l) + self.assertTrue(isinstance(l2, list)) + self.assertFalse(l is l2) + self.assertNotEqual(l, l2) + self.assertEqual(["", "123", "321"], l2) + + # check list types, inplace + l = [""] + l2 = cpt.to_literal_str(l, inplace=True) + self.assertTrue(isinstance(l2, list)) + self.assertTrue(l is l2) + self.assertEqual(l, l2) + self.assertEqual([""], l2) + l = ["", b"123"] + l2 = cpt.to_literal_str(l, inplace=True) + self.assertTrue(isinstance(l2, list)) + self.assertTrue(l is l2) + self.assertEqual(l, l2) + self.assertEqual(["", "123"], l2) + l = ["", b"123", u"321"] + l2 = cpt.to_literal_str(l, inplace=True) + self.assertTrue(isinstance(l2, list)) + self.assertTrue(l is l2) + self.assertEqual(l, l2) + self.assertEqual(["", "123", "321"], l2) + for i in l2: + self.assertTrue(isinstance(i, str)) + + # check set types, not inplace + l = set("") + l2 = cpt.to_literal_str(l, inplace=False) + self.assertTrue(isinstance(l2, set)) + self.assertFalse(l is l2) + self.assertEqual(l, l2) + self.assertEqual(set(""), l2) + l = set([b"", b"123"]) + l2 = cpt.to_literal_str(l, inplace=False) + self.assertTrue(isinstance(l2, set)) + self.assertFalse(l is l2) + self.assertNotEqual(l, l2) + self.assertEqual(set(["", "123"]), l2) + l = set(["", b"123", u"321"]) + l2 = cpt.to_literal_str(l, inplace=False) + self.assertTrue(isinstance(l2, set)) + self.assertFalse(l is l2) + self.assertNotEqual(l, l2) + self.assertEqual(set(["", "123", "321"]), l2) + + # check set types, inplace + l = set("") + l2 = cpt.to_literal_str(l, inplace=True) + self.assertTrue(isinstance(l2, set)) + self.assertTrue(l is l2) + self.assertEqual(l, l2) + self.assertEqual(set(""), l2) + l = set([b"", b"123"]) + l2 = cpt.to_literal_str(l, inplace=True) + self.assertTrue(isinstance(l2, set)) + self.assertTrue(l is l2) + self.assertEqual(l, l2) + self.assertEqual(set(["", "123"]), l2) + l = set(["", b"123", u"321"]) + l2 = cpt.to_literal_str(l, inplace=True) + self.assertTrue(isinstance(l2, set)) + self.assertTrue(l is l2) + self.assertEqual(l, l2) + self.assertEqual(set(["", "123", "321"]), l2) + for i in l2: + self.assertTrue(isinstance(i, str)) + + def test_to_bytes(self): + # Only support 
python2.x and python3.x now + self.assertTrue(six.PY2 | six.PY3) + + if six.PY2: + # check None + self.assertIsNone(cpt.to_bytes(None)) + + # check all string related types + self.assertTrue(isinstance(cpt.to_bytes(str("")), bytes)) + self.assertTrue(isinstance(cpt.to_bytes(str("123")), bytes)) + self.assertTrue(isinstance(cpt.to_bytes(b""), bytes)) + self.assertTrue(isinstance(cpt.to_bytes(b""), bytes)) + self.assertTrue(isinstance(cpt.to_bytes(u""), bytes)) + self.assertTrue(isinstance(cpt.to_bytes(u""), bytes)) + + self.assertEqual(b"", cpt.to_bytes(str(""))) + self.assertEqual(b"123", cpt.to_bytes(str("123"))) + self.assertEqual(b"", cpt.to_bytes(b"")) + self.assertEqual(b"123", cpt.to_bytes(b"123")) + self.assertEqual(b"", cpt.to_bytes(u"")) + self.assertEqual(b"123", cpt.to_bytes(u"123")) + + # check list types, not inplace + l = [""] + l2 = cpt.to_bytes(l) + self.assertTrue(isinstance(l2, list)) + self.assertFalse(l is l2) + self.assertEqual(l, l2) + self.assertEqual([b""], l2) + l = ["", "123"] + l2 = cpt.to_bytes(l) + self.assertTrue(isinstance(l2, list)) + self.assertFalse(l is l2) + self.assertEqual(l, l2) + self.assertEqual([b"", b"123"], l2) + l = ["", b'123', u"321"] + l2 = cpt.to_bytes(l) + self.assertTrue(isinstance(l2, list)) + self.assertFalse(l is l2) + self.assertEqual(l, l2) + self.assertEqual([b"", b"123", b"321"], l2) + for i in l2: + self.assertTrue(isinstance(i, bytes)) + + + # check list types, inplace + l = [""] + l2 = cpt.to_bytes(l, inplace=True) + self.assertTrue(isinstance(l2, list)) + self.assertTrue(l is l2) + self.assertEqual(l, l2) + self.assertEqual([b""], l2) + l = ["", "123"] + l2 = cpt.to_bytes(l, inplace=True) + self.assertTrue(isinstance(l2, list)) + self.assertTrue(l is l2) + self.assertEqual(l, l2) + self.assertEqual([b"", b"123"], l2) + l = ["", b"123", u"321"] + l2 = cpt.to_bytes(l, inplace=True) + self.assertTrue(isinstance(l2, list)) + self.assertTrue(l is l2) + self.assertEqual(l, l2) + self.assertEqual([b"", b"123", b"321"], l2) + + # check set types, not inplace + l = set("") + l2 = cpt.to_bytes(l, inplace=False) + self.assertTrue(isinstance(l2, set)) + self.assertFalse(l is l2) + self.assertEqual(l, l2) + self.assertEqual(set(b""), l2) + l = set([b"", b"123"]) + l2 = cpt.to_bytes(l, inplace=False) + self.assertTrue(isinstance(l2, set)) + self.assertFalse(l is l2) + self.assertEqual(l, l2) + self.assertEqual(set([b"", b"123"]), l2) + l = set(["", b"123", u"321"]) + l2 = cpt.to_bytes(l, inplace=False) + self.assertTrue(isinstance(l2, set)) + self.assertFalse(l is l2) + self.assertEqual(l, l2) + self.assertEqual(set([b"", b"123", b"321"]), l2) + for i in l2: + self.assertTrue(isinstance(i, bytes)) + + # check set types, inplace + l = set("") + l2 = cpt.to_bytes(l, inplace=True) + self.assertTrue(isinstance(l2, set)) + self.assertTrue(l is l2) + self.assertEqual(l, l2) + self.assertEqual(set(b""), l2) + l = set([b"", b"123"]) + l2 = cpt.to_bytes(l, inplace=True) + self.assertTrue(isinstance(l2, set)) + self.assertTrue(l is l2) + self.assertEqual(l, l2) + self.assertEqual(set([b"", b"123"]), l2) + l = set(["", b"123", u"321"]) + l2 = cpt.to_bytes(l, inplace=True) + self.assertTrue(isinstance(l2, set)) + self.assertTrue(l is l2) + self.assertEqual(l, l2) + self.assertEqual(set([b"", b"123", b"321"]), l2) + + elif six.PY3: + self.assertIsNone(cpt.to_bytes(None)) + + self.assertTrue(isinstance(cpt.to_bytes(str("")), bytes)) + self.assertTrue(isinstance(cpt.to_bytes(str("123")), bytes)) + self.assertTrue(isinstance(cpt.to_bytes(b""), bytes)) + 
self.assertTrue(isinstance(cpt.to_bytes(b""), bytes)) + self.assertTrue(isinstance(cpt.to_bytes(u""), bytes)) + self.assertTrue(isinstance(cpt.to_bytes(u""), bytes)) + + self.assertEqual(b"", cpt.to_bytes(str(""))) + self.assertEqual(b"123", cpt.to_bytes(str("123"))) + self.assertEqual(b"", cpt.to_bytes(b"")) + self.assertEqual(b"123", cpt.to_bytes(b"123")) + self.assertEqual(b"", cpt.to_bytes(u"")) + self.assertEqual(b"123", cpt.to_bytes(u"123")) + + # check list types, not inplace + l = [""] + l2 = cpt.to_bytes(l) + self.assertTrue(isinstance(l2, list)) + self.assertFalse(l is l2) + self.assertNotEqual(l, l2) + self.assertEqual([b""], l2) + l = ["", "123"] + l2 = cpt.to_bytes(l) + self.assertTrue(isinstance(l2, list)) + self.assertFalse(l is l2) + self.assertNotEqual(l, l2) + self.assertEqual([b"", b"123"], l2) + l = ["", b"123", u"321"] + l2 = cpt.to_bytes(l) + self.assertTrue(isinstance(l2, list)) + self.assertFalse(l is l2) + self.assertNotEqual(l, l2) + self.assertEqual([b"", b"123", b"321"], l2) + + # check list types, inplace + l = [""] + l2 = cpt.to_bytes(l, inplace=True) + self.assertTrue(isinstance(l2, list)) + self.assertTrue(l is l2) + self.assertEqual(l, l2) + self.assertEqual([b""], l2) + l = ["", b"123"] + l2 = cpt.to_bytes(l, inplace=True) + self.assertTrue(isinstance(l2, list)) + self.assertTrue(l is l2) + self.assertEqual(l, l2) + self.assertEqual([b"", b"123"], l2) + l = ["", b"123", u"321"] + l2 = cpt.to_bytes(l, inplace=True) + self.assertTrue(isinstance(l2, list)) + self.assertTrue(l is l2) + self.assertEqual(l, l2) + self.assertEqual([b"", b"123", b"321"], l2) + for i in l2: + self.assertTrue(isinstance(i, bytes)) + + # check set types, not inplace + l = set([""]) + l2 = cpt.to_bytes(l, inplace=False) + self.assertTrue(isinstance(l2, set)) + self.assertFalse(l is l2) + self.assertNotEqual(l, l2) + self.assertEqual(set([b""]), l2) + l = set([u"", u"123"]) + l2 = cpt.to_bytes(l, inplace=False) + self.assertTrue(isinstance(l2, set)) + self.assertFalse(l is l2) + self.assertNotEqual(l, l2) + self.assertEqual(set([b"", b"123"]), l2) + l = set(["", b"123", u"321"]) + l2 = cpt.to_bytes(l, inplace=False) + self.assertTrue(isinstance(l2, set)) + self.assertFalse(l is l2) + self.assertNotEqual(l, l2) + self.assertEqual(set([b"", b"123", b"321"]), l2) + + # check set types, inplace + l = set("") + l2 = cpt.to_bytes(l, inplace=True) + self.assertTrue(isinstance(l2, set)) + self.assertTrue(l is l2) + self.assertEqual(l, l2) + self.assertEqual(set(b""), l2) + l = set([u"", u"123"]) + l2 = cpt.to_bytes(l, inplace=True) + self.assertTrue(isinstance(l2, set)) + self.assertTrue(l is l2) + self.assertEqual(l, l2) + self.assertEqual(set([b"", b"123"]), l2) + l = set(["", b"123", u"321"]) + l2 = cpt.to_bytes(l, inplace=True) + self.assertTrue(isinstance(l2, set)) + self.assertTrue(l is l2) + self.assertEqual(l, l2) + self.assertEqual(set([b"", b"123", b"321"]), l2) + for i in l2: + self.assertTrue(isinstance(i, bytes)) + + def test_round(self): + self.assertEqual(3.0, cpt.round(3.4)) + self.assertEqual(4.0, cpt.round(3.5)) + self.assertEqual(0.0, cpt.round(0.1)) + self.assertEqual(-0.0, cpt.round(-0.1)) + self.assertEqual(-3.0, cpt.round(-3.4)) + self.assertEqual(-4.0, cpt.round(-3.5)) + self.assertEqual(5.0, cpt.round(5)) + self.assertRaises(TypeError, cpt.round, None) + + def test_floor_division(self): + self.assertEqual(0.0, cpt.floor_division(3, 4)) + self.assertEqual(1.0, cpt.floor_division(4, 3)) + self.assertEqual(2.0, cpt.floor_division(6, 3)) + self.assertEqual(-2.0, 
cpt.floor_division(-4, 3)) + self.assertEqual(-2.0, cpt.floor_division(-6, 3)) + self.assertRaises(ZeroDivisionError, cpt.floor_division, 3, 0) + self.assertRaises(TypeError, cpt.floor_division, None, None) + + def test_get_exception_message(self): + exception_message = "test_message" + self.assertRaises(AssertionError, cpt.get_exception_message, None) + if six.PY2: + self.assertRaises(AttributeError, cpt.get_exception_message, exception_message) + try: + raise RuntimeError(exception_message) + except Exception as e: + self.assertEqual(exception_message, cpt.get_exception_message(e)) + self.assertIsNotNone(e) + + try: + raise Exception(exception_message) + except Exception as e: + self.assertEqual(exception_message, cpt.get_exception_message(e)) + self.assertIsNotNone(e) + + if six.PY3: + try: + raise RuntimeError(exception_message) + except Exception as e: + self.assertEqual(exception_message, cpt.get_exception_message(e)) + self.assertIsNotNone(e) + + try: + raise Exception(exception_message) + except Exception as e: + self.assertEqual(exception_message, cpt.get_exception_message(e)) + self.assertIsNotNone(e) + + +if __name__ == "__main__": + unittest.main() From 938a9e4faa8deb6a9a7a7af04b7fb761b6983823 Mon Sep 17 00:00:00 2001 From: minqiyang Date: Thu, 9 Aug 2018 21:57:17 +0800 Subject: [PATCH 28/94] Polish code --- python/paddle/fluid/framework.py | 3 --- 1 file changed, 3 deletions(-) diff --git a/python/paddle/fluid/framework.py b/python/paddle/fluid/framework.py index af80cd9ca1..aec06febcc 100644 --- a/python/paddle/fluid/framework.py +++ b/python/paddle/fluid/framework.py @@ -574,9 +574,6 @@ class Operator(object): attr_val = self.attrs[attr_name] self._update_desc_attr(attr_name, attr_val) - import sys - print('self.attrs', self.attrs) - sys.stdout.flush() self.desc.check_attrs() if self.has_kernel(type): self.desc.infer_var_type(self.block.desc) From 59adf7ced11acbc659953055b9746d3065699058 Mon Sep 17 00:00:00 2001 From: minqiyang Date: Thu, 9 Aug 2018 23:48:36 +0800 Subject: [PATCH 29/94] Fix round(0.0) special issue --- python/paddle/fluid/compat.py | 4 +++- python/paddle/fluid/debugger.py | 1 + python/paddle/fluid/graphviz.py | 1 + python/paddle/fluid/profiler.py | 3 ++- python/paddle/fluid/tests/unittests/test_compat.py | 2 ++ .../paddle/fluid/tests/unittests/test_conv_shift_op.py | 2 +- python/paddle/fluid/tests/unittests/test_gru_op.py | 3 ++- .../fluid/tests/unittests/test_inference_model_io.py | 5 +++-- .../paddle/fluid/tests/unittests/test_pool_max_op.py | 10 +++++----- 9 files changed, 20 insertions(+), 11 deletions(-) diff --git a/python/paddle/fluid/compat.py b/python/paddle/fluid/compat.py index 16932a2a0c..f0ea9d4aac 100644 --- a/python/paddle/fluid/compat.py +++ b/python/paddle/fluid/compat.py @@ -183,9 +183,11 @@ def round(x, d=0): if x > 0.0: p = 10 ** d return float(math.floor((x * p) + math.copysign(0.5, x))) / p - else: + elif x < 0.0: p = 10 ** d return float(math.ceil((x * p) + math.copysign(0.5, x))) / p + else: + return math.copysign(0.0, x) else: import __builtin__ return __builtin__.round(x, d) diff --git a/python/paddle/fluid/debugger.py b/python/paddle/fluid/debugger.py index dd8523f95b..ea6c14df72 100644 --- a/python/paddle/fluid/debugger.py +++ b/python/paddle/fluid/debugger.py @@ -13,6 +13,7 @@ # limitations under the License. 
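The compat.py hunk in this patch adds an explicit zero branch: previously round(0.0) fell through to the ceil() path and returned 1.0. A minimal standalone sketch of the behaviour after the fix (illustration only; the function name here is invented, the real helper is paddle.fluid.compat.round, and Python 3 semantics are assumed):

    import math

    def round_half_away_from_zero(x, d=0):
        # mirrors cpt.round() on Python 3: round half away from zero
        p = 10 ** d
        if x > 0.0:
            return float(math.floor((x * p) + math.copysign(0.5, x))) / p
        elif x < 0.0:
            return float(math.ceil((x * p) + math.copysign(0.5, x))) / p
        # the new special case: return a zero that keeps the sign of x
        return math.copysign(0.0, x)

    assert round_half_away_from_zero(2.5) == 3.0  # built-in round(2.5) == 2.0 on Python 3
    assert math.copysign(1.0, round_half_away_from_zero(-0.0)) == -1.0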
import sys +import six import re from .graphviz import GraphPreviewGenerator from .proto import framework_pb2 diff --git a/python/paddle/fluid/graphviz.py b/python/paddle/fluid/graphviz.py index 0557d7fd8a..5e823418bd 100644 --- a/python/paddle/fluid/graphviz.py +++ b/python/paddle/fluid/graphviz.py @@ -15,6 +15,7 @@ import os import random import six +import functools import subprocess import logging diff --git a/python/paddle/fluid/profiler.py b/python/paddle/fluid/profiler.py index 60e9215457..5fbb35abdd 100644 --- a/python/paddle/fluid/profiler.py +++ b/python/paddle/fluid/profiler.py @@ -15,6 +15,7 @@ from . import core from contextlib import contextmanager import os +import six __all__ = [ 'cuda_profiler', 'reset_profiler', 'profiler', 'start_profiler', @@ -88,7 +89,7 @@ def cuda_profiler(output_file, output_mode=None, config=None): config = NVPROF_CONFIG if config is None else config config_file = 'nvprof_config_file' with open(config_file, 'wb') as fp: - fp.writelines(["%s\n" % item for item in config]) + fp.writelines([six.b("%s\n" % item) for item in config]) core.nvprof_init(output_file, output_mode, config_file) # Enables profiler collection by the active CUDA profiling tool. core.nvprof_start() diff --git a/python/paddle/fluid/tests/unittests/test_compat.py b/python/paddle/fluid/tests/unittests/test_compat.py index 0725d2c49a..20e93515de 100644 --- a/python/paddle/fluid/tests/unittests/test_compat.py +++ b/python/paddle/fluid/tests/unittests/test_compat.py @@ -440,6 +440,8 @@ class TestCompatible(unittest.TestCase): self.assertEqual(3.0, cpt.round(3.4)) self.assertEqual(4.0, cpt.round(3.5)) self.assertEqual(0.0, cpt.round(0.1)) + self.assertEqual(0.0, cpt.round(0.0)) + self.assertEqual(-0.0, cpt.round(-0.0)) self.assertEqual(-0.0, cpt.round(-0.1)) self.assertEqual(-3.0, cpt.round(-3.4)) self.assertEqual(-4.0, cpt.round(-3.5)) diff --git a/python/paddle/fluid/tests/unittests/test_conv_shift_op.py b/python/paddle/fluid/tests/unittests/test_conv_shift_op.py index 9fdb7baa90..d524832058 100644 --- a/python/paddle/fluid/tests/unittests/test_conv_shift_op.py +++ b/python/paddle/fluid/tests/unittests/test_conv_shift_op.py @@ -21,7 +21,7 @@ def conv_shift_forward(x, y): out = np.zeros_like(x) M = x.shape[1] N = y.shape[1] - y_half_width = (N - 1) / 2 + y_half_width = (N - 1) // 2 for i in range(M): for j in range(N): out[:, i] += x[:, (i + j + M - y_half_width) % M] * y[:, j] diff --git a/python/paddle/fluid/tests/unittests/test_gru_op.py b/python/paddle/fluid/tests/unittests/test_gru_op.py index 86a2c674d0..4bbec06a91 100644 --- a/python/paddle/fluid/tests/unittests/test_gru_op.py +++ b/python/paddle/fluid/tests/unittests/test_gru_op.py @@ -15,6 +15,7 @@ import unittest import numpy as np import math +import functools from op_test import OpTest from test_lstm_op import identity, sigmoid, tanh, relu @@ -38,7 +39,7 @@ class TestGRUOp(OpTest): for i in range(len(seq_lens)): seq_starts.append(seq_starts[-1] + seq_lens[i]) sorted_seqs = sorted( - list(range(len(seq_lens))), lambda x, y: seq_lens[y] - seq_lens[x]) + list(range(len(seq_lens))), key=functools.cmp_to_key(lambda x, y: seq_lens[y] - seq_lens[x])) num_batch = seq_lens[sorted_seqs[0]] for batch_idx in range(num_batch): idx_in_seq = [] diff --git a/python/paddle/fluid/tests/unittests/test_inference_model_io.py b/python/paddle/fluid/tests/unittests/test_inference_model_io.py index 4cd203155f..66cc78e4d4 100644 --- a/python/paddle/fluid/tests/unittests/test_inference_model_io.py +++ 
b/python/paddle/fluid/tests/unittests/test_inference_model_io.py @@ -14,6 +14,7 @@ import unittest +import six import numpy as np import paddle.fluid.core as core @@ -48,7 +49,7 @@ class TestBook(unittest.TestCase): exe.run(init_program, feed={}, fetch_list=[]) - for i in range(100): + for i in six.moves.xrange(100): tensor_x = np.array( [[1, 1], [1, 2], [3, 4], [5, 2]]).astype("float32") tensor_y = np.array([[-2], [-3], [-7], [-7]]).astype("float32") @@ -64,7 +65,7 @@ class TestBook(unittest.TestCase): 'y': tensor_y}, fetch_list=[avg_cost])[0] - reload(executor) # reload to build a new scope + six.moves.reload_module(executor) # reload to build a new scope exe = executor.Executor(place) [infer_prog, feed_var_names, fetch_vars] = load_inference_model( diff --git a/python/paddle/fluid/tests/unittests/test_pool_max_op.py b/python/paddle/fluid/tests/unittests/test_pool_max_op.py index e6a9f6f08c..9a23fde340 100644 --- a/python/paddle/fluid/tests/unittests/test_pool_max_op.py +++ b/python/paddle/fluid/tests/unittests/test_pool_max_op.py @@ -24,9 +24,9 @@ def max_pool3D_forward_naive(x, ksize, strides, paddings, global_pool=False): ksize = [D, H, W] paddings = [0, 0, 0] - D_out = (D - ksize[0] + 2 * paddings[0]) / strides[0] + 1 - H_out = (H - ksize[1] + 2 * paddings[1]) / strides[1] + 1 - W_out = (W - ksize[2] + 2 * paddings[2]) / strides[2] + 1 + D_out = (D - ksize[0] + 2 * paddings[0]) // strides[0] + 1 + H_out = (H - ksize[1] + 2 * paddings[1]) // strides[1] + 1 + W_out = (W - ksize[2] + 2 * paddings[2]) // strides[2] + 1 out = np.zeros((N, C, D_out, H_out, W_out)) mask = np.zeros((N, C, D_out, H_out, W_out)) for k in range(D_out): @@ -63,8 +63,8 @@ def max_pool2D_forward_naive(x, ksize, strides, paddings, global_pool=False): ksize = [H, W] paddings = [0, 0] - H_out = (H - ksize[0] + 2 * paddings[0]) / strides[0] + 1 - W_out = (W - ksize[1] + 2 * paddings[1]) / strides[1] + 1 + H_out = (H - ksize[0] + 2 * paddings[0]) // strides[0] + 1 + W_out = (W - ksize[1] + 2 * paddings[1]) // strides[1] + 1 out = np.zeros((N, C, H_out, W_out)) mask = np.zeros((N, C, H_out, W_out)) for i in range(H_out): From be6ecec46f7b2b1d661a101744ee838dffc519ea Mon Sep 17 00:00:00 2001 From: minqiyang Date: Fri, 10 Aug 2018 00:08:24 +0800 Subject: [PATCH 30/94] Fix unittests' division issues --- python/paddle/fluid/layers/nn.py | 24 +++++++++---------- .../fluid/tests/unittests/test_conv3d_op.py | 18 +++++++------- .../fluid/tests/unittests/test_infer_shape.py | 13 +++++----- .../fluid/tests/unittests/test_layers.py | 4 ++-- .../fluid/tests/unittests/test_pool3d_op.py | 24 +++++++++---------- 5 files changed, 42 insertions(+), 41 deletions(-) diff --git a/python/paddle/fluid/layers/nn.py b/python/paddle/fluid/layers/nn.py index d1ae284d54..37e860b08c 100644 --- a/python/paddle/fluid/layers/nn.py +++ b/python/paddle/fluid/layers/nn.py @@ -550,7 +550,7 @@ def dynamic_lstmp(input, """ helper = LayerHelper('lstmp', **locals()) - size = size / 4 + size = size // 4 weight = helper.create_parameter( attr=helper.param_attr, shape=[proj_size, 4 * size], dtype=dtype) proj_weight = helper.create_parameter( @@ -778,7 +778,7 @@ def gru_unit(input, helper = LayerHelper('gru_unit', **locals()) dtype = helper.input_dtype() - size = size / 3 + size = size // 3 # create weight weight = helper.create_parameter( @@ -1258,7 +1258,7 @@ def sequence_conv(input, outputs={"Out": pre_bias}, attrs={ 'contextStride': filter_stride, - 'contextStart': -int(filter_size / 2), + 'contextStart': -int(filter_size // 2), 'contextLength': filter_size 
}) pre_act = helper.append_bias_op(pre_bias) @@ -1487,7 +1487,7 @@ def conv2d(input, else: if num_channels % groups != 0: raise ValueError("num_channels must be divisible by groups.") - num_filter_channels = num_channels / groups + num_filter_channels = num_channels // groups filter_size = utils.convert_to_list(filter_size, 2, 'filter_size') stride = utils.convert_to_list(stride, 2, 'stride') @@ -1649,7 +1649,7 @@ def conv3d(input, else: if num_channels % groups != 0: raise ValueError("num_channels must be divisible by groups.") - num_filter_channels = num_channels / groups + num_filter_channels = num_channels // groups filter_size = utils.convert_to_list(filter_size, 3, 'filter_size') stride = utils.convert_to_list(stride, 3, 'stride') @@ -2384,16 +2384,16 @@ def conv2d_transpose(input, w_in = input.shape[3] filter_size_h = (output_size[0] - (h_in - 1) * stride[0] + 2 * - padding[0] - 1) / dilation[0] + 1 + padding[0] - 1) // dilation[0] + 1 filter_size_w = (output_size[1] - (w_in - 1) * stride[1] + 2 * - padding[1] - 1) / dilation[1] + 1 + padding[1] - 1) // dilation[1] + 1 filter_size = [filter_size_h, filter_size_w] else: filter_size = utils.convert_to_list(filter_size, 2, 'conv2d_transpose.filter_size') groups = 1 if groups is None else groups - filter_shape = [input_channel, num_filters / groups] + filter_size + filter_shape = [input_channel, num_filters // groups] + filter_size img_filter = helper.create_parameter( dtype=input.dtype, shape=filter_shape, attr=helper.param_attr) @@ -2551,18 +2551,18 @@ def conv3d_transpose(input, w_in = input.shape[4] filter_size_d = (output_size[0] - (d_in - 1) * stride[0] + 2 * - padding[0] - 1) / dilation[0] + 1 + padding[0] - 1) // dilation[0] + 1 filter_size_h = (output_size[1] - (h_in - 1) * stride[1] + 2 * - padding[1] - 1) / dilation[1] + 1 + padding[1] - 1) // dilation[1] + 1 filter_size_w = (output_size[2] - (w_in - 1) * stride[2] + 2 * - padding[2] - 1) / dilation[2] + 1 + padding[2] - 1) // dilation[2] + 1 filter_size = [filter_size_d, filter_size_h, filter_size_w] else: filter_size = utils.convert_to_list(filter_size, 3, 'conv3d_transpose.filter_size') groups = 1 if groups is None else groups - filter_shape = [input_channel, num_filters / groups] + filter_size + filter_shape = [input_channel, num_filters // groups] + filter_size img_filter = helper.create_parameter( dtype=input.dtype, shape=filter_shape, attr=helper.param_attr) diff --git a/python/paddle/fluid/tests/unittests/test_conv3d_op.py b/python/paddle/fluid/tests/unittests/test_conv3d_op.py index dd4ef7cc94..e473ebacea 100644 --- a/python/paddle/fluid/tests/unittests/test_conv3d_op.py +++ b/python/paddle/fluid/tests/unittests/test_conv3d_op.py @@ -24,14 +24,14 @@ def conv3d_forward_naive(input, filter, group, conv_param): out_c, f_c, f_d, f_h, f_w = filter.shape assert f_c * group == in_c assert np.mod(out_c, group) == 0 - sub_out_c = out_c / group + sub_out_c = out_c // group stride, pad, dilation = conv_param['stride'], conv_param['pad'], conv_param[ 'dilations'] - out_d = 1 + (in_d + 2 * pad[0] - (dilation[0] * (f_d - 1) + 1)) / stride[0] - out_h = 1 + (in_h + 2 * pad[1] - (dilation[1] * (f_h - 1) + 1)) / stride[1] - out_w = 1 + (in_w + 2 * pad[2] - (dilation[2] * (f_w - 1) + 1)) / stride[2] + out_d = 1 + (in_d + 2 * pad[0] - (dilation[0] * (f_d - 1) + 1)) // stride[0] + out_h = 1 + (in_h + 2 * pad[1] - (dilation[1] * (f_h - 1) + 1)) // stride[1] + out_w = 1 + (in_w + 2 * pad[2] - (dilation[2] * (f_w - 1) + 1)) // stride[2] out = np.zeros((in_n, out_c, out_d, out_h, out_w)) @@ 
-166,7 +166,7 @@ class TestConv3dOp(OpTest): self.stride = [1, 1, 1] self.input_size = [2, 3, 4, 4, 4] # NCDHW assert np.mod(self.input_size[1], self.groups) == 0 - f_c = self.input_size[1] / self.groups + f_c = self.input_size[1] // self.groups self.filter_size = [6, f_c, 3, 3, 3] def init_dilation(self): @@ -185,7 +185,7 @@ class TestCase1(TestConv3dOp): self.stride = [1, 1, 1] self.input_size = [2, 3, 4, 4, 4] # NCDHW assert np.mod(self.input_size[1], self.groups) == 0 - f_c = self.input_size[1] / self.groups + f_c = self.input_size[1] // self.groups self.filter_size = [6, f_c, 3, 3, 3] @@ -205,7 +205,7 @@ class TestWith1x1(TestConv3dOp): self.stride = [1, 1, 1] self.input_size = [2, 3, 4, 4, 4] # NCHW assert np.mod(self.input_size[1], self.groups) == 0 - f_c = self.input_size[1] / self.groups + f_c = self.input_size[1] // self.groups self.filter_size = [6, f_c, 1, 1, 1] def init_dilation(self): @@ -221,7 +221,7 @@ class TestWithInput1x1Filter1x1(TestConv3dOp): self.stride = [1, 1, 1] self.input_size = [2, 3, 1, 1, 1] # NCHW assert np.mod(self.input_size[1], self.groups) == 0 - f_c = self.input_size[1] / self.groups + f_c = self.input_size[1] // self.groups self.filter_size = [6, f_c, 1, 1, 1] def init_dilation(self): @@ -237,7 +237,7 @@ class TestWithDilation(TestConv3dOp): self.stride = [1, 1, 1] self.input_size = [2, 3, 6, 6, 6] # NCDHW assert np.mod(self.input_size[1], self.groups) == 0 - f_c = self.input_size[1] / self.groups + f_c = self.input_size[1] // self.groups self.filter_size = [6, f_c, 2, 2, 2] def init_dilation(self): diff --git a/python/paddle/fluid/tests/unittests/test_infer_shape.py b/python/paddle/fluid/tests/unittests/test_infer_shape.py index 699a2d4246..ede51f6550 100644 --- a/python/paddle/fluid/tests/unittests/test_infer_shape.py +++ b/python/paddle/fluid/tests/unittests/test_infer_shape.py @@ -14,6 +14,7 @@ import unittest +import six import paddle.fluid.core as core @@ -27,14 +28,14 @@ class TestInferShape(unittest.TestCase): shape = [10, 20] # prepare input/output - x1 = block.var("x1") + x1 = block.var(six.b("x1")) x1.set_type(core.VarDesc.VarType.LOD_TENSOR) x1.set_shape(shape) - x2 = block.var("x2") + x2 = block.var(six.b("x2")) x2.set_type(core.VarDesc.VarType.LOD_TENSOR) x2.set_shape(shape) - out = block.var("out") + out = block.var(six.b("out")) out.set_type(core.VarDesc.VarType.LOD_TENSOR) # prepare the operator @@ -57,14 +58,14 @@ class TestInferShape(unittest.TestCase): y_shape = [20, 30] # prepare input/output - x1 = block.var("x") + x1 = block.var(six.b("x")) x1.set_type(core.VarDesc.VarType.LOD_TENSOR) x1.set_shape(x_shape) - x2 = block.var("y") + x2 = block.var(six.b("y")) x2.set_type(core.VarDesc.VarType.LOD_TENSOR) x2.set_shape(y_shape) - out = block.var("out") + out = block.var(six.b("out")) out.set_type(core.VarDesc.VarType.LOD_TENSOR) # prepare the operator diff --git a/python/paddle/fluid/tests/unittests/test_layers.py b/python/paddle/fluid/tests/unittests/test_layers.py index 8f2dac786d..aae5a24f6c 100644 --- a/python/paddle/fluid/tests/unittests/test_layers.py +++ b/python/paddle/fluid/tests/unittests/test_layers.py @@ -158,7 +158,7 @@ class TestBook(unittest.TestCase): input=crf_decode, label=label, chunk_scheme="IOB", - num_chunk_types=(label_dict_len - 1) / 2) + num_chunk_types=(label_dict_len - 1) // 2) self.assertFalse(crf is None) self.assertFalse(crf_decode is None) @@ -285,7 +285,7 @@ class TestBook(unittest.TestCase): name='word_{0}'.format(i), shape=[1], dtype='int64')) dict_size = 10000 - label_word = int(window_size / 2) + 1 + 
label_word = int(window_size // 2) + 1 embs = [] for i in range(window_size): diff --git a/python/paddle/fluid/tests/unittests/test_pool3d_op.py b/python/paddle/fluid/tests/unittests/test_pool3d_op.py index 92c64b3792..a358c84991 100644 --- a/python/paddle/fluid/tests/unittests/test_pool3d_op.py +++ b/python/paddle/fluid/tests/unittests/test_pool3d_op.py @@ -29,14 +29,14 @@ def max_pool3D_forward_naive(x, if global_pool == 1: ksize = [D, H, W] D_out = (D - ksize[0] + 2 * paddings[0] + strides[0] - 1 - ) / strides[0] + 1 if ceil_mode else (H - ksize[0] + 2 * - paddings[0]) / strides[0] + 1 + ) // strides[0] + 1 if ceil_mode else (H - ksize[0] + 2 * + paddings[0]) // strides[0] + 1 H_out = (H - ksize[1] + 2 * paddings[1] + strides[1] - 1 - ) / strides[1] + 1 if ceil_mode else (W - ksize[1] + 2 * - paddings[1]) / strides[1] + 1 + ) // strides[1] + 1 if ceil_mode else (W - ksize[1] + 2 * + paddings[1]) // strides[1] + 1 W_out = (W - ksize[2] + 2 * paddings[2] + strides[2] - 1 - ) / strides[2] + 1 if ceil_mode else (W - ksize[2] + 2 * - paddings[2]) / strides[2] + 1 + ) // strides[2] + 1 if ceil_mode else (W - ksize[2] + 2 * + paddings[2]) // strides[2] + 1 out = np.zeros((N, C, D_out, H_out, W_out)) for k in range(D_out): d_start = np.max((k * strides[0] - paddings[0], 0)) @@ -63,14 +63,14 @@ def avg_pool3D_forward_naive(x, if global_pool == 1: ksize = [D, H, W] D_out = (D - ksize[0] + 2 * paddings[0] + strides[0] - 1 - ) / strides[0] + 1 if ceil_mode else (H - ksize[0] + 2 * - paddings[0]) / strides[0] + 1 + ) // strides[0] + 1 if ceil_mode else (H - ksize[0] + 2 * + paddings[0]) // strides[0] + 1 H_out = (H - ksize[1] + 2 * paddings[1] + strides[1] - 1 - ) / strides[1] + 1 if ceil_mode else (W - ksize[1] + 2 * - paddings[1]) / strides[1] + 1 + ) // strides[1] + 1 if ceil_mode else (W - ksize[1] + 2 * + paddings[1]) // strides[1] + 1 W_out = (W - ksize[2] + 2 * paddings[2] + strides[2] - 1 - ) / strides[2] + 1 if ceil_mode else (W - ksize[2] + 2 * - paddings[2]) / strides[2] + 1 + ) // strides[2] + 1 if ceil_mode else (W - ksize[2] + 2 * + paddings[2]) // strides[2] + 1 out = np.zeros((N, C, D_out, H_out, W_out)) for k in range(D_out): d_start = np.max((k * strides[0] - paddings[0], 0)) From 9cd5999032b41cdd324ef011a3e6a0ea3a185461 Mon Sep 17 00:00:00 2001 From: minqiyang Date: Fri, 10 Aug 2018 12:00:15 +0800 Subject: [PATCH 31/94] Fix dist transpiler unordered dict issue --- .../tests/unittests/test_dist_transpiler.py | 1 + .../fluid/transpiler/distribute_transpiler.py | 24 ++++++++++--------- 2 files changed, 14 insertions(+), 11 deletions(-) diff --git a/python/paddle/fluid/tests/unittests/test_dist_transpiler.py b/python/paddle/fluid/tests/unittests/test_dist_transpiler.py index 4d1a89da3d..a42a9718fd 100644 --- a/python/paddle/fluid/tests/unittests/test_dist_transpiler.py +++ b/python/paddle/fluid/tests/unittests/test_dist_transpiler.py @@ -46,6 +46,7 @@ class TranspilerTest(unittest.TestCase): def get_main_program(self): main = fluid.Program() + main.random_seed = 1 with fluid.program_guard(main): self.net_conf() self.origin_prog = main.clone() diff --git a/python/paddle/fluid/transpiler/distribute_transpiler.py b/python/paddle/fluid/transpiler/distribute_transpiler.py index aca9aafd52..252afc058b 100644 --- a/python/paddle/fluid/transpiler/distribute_transpiler.py +++ b/python/paddle/fluid/transpiler/distribute_transpiler.py @@ -31,6 +31,7 @@ Steps to transpile pserver: import math import random import numpy as np +import collections from .ps_dispatcher import RoundRobin, HashName, 
PSDispatcher from .. import core, framework @@ -218,8 +219,9 @@ class DistributeTranspiler(object): # fc_b@GRAD_trainer_0, fc_b@GRAD_trainer_1 --> pserver2 # shuffle the map will avoid the uneven distribution above grad_var_mapping_items = list(self.grad_var_mapping.items()) + if not self.config.slice_var_up: - random.seed(self.trainer_num) + random.seed(self.origin_program.random_seed) random.shuffle(grad_var_mapping_items) for orig_varname, splited_vars in grad_var_mapping_items: @@ -557,14 +559,14 @@ class DistributeTranspiler(object): # 1. create vars in pserver program to startup program pserver_vars = pserver_program.global_block().vars - created_var_map = dict() + created_var_map = collections.OrderedDict() for _, var in list(pserver_vars.items()): tmpvar = s_prog.global_block()._clone_variable(var) created_var_map[var.name] = tmpvar # 2. rename op outputs for op in orig_s_prog.global_block().ops: - new_outputs = dict() + new_outputs = collections.OrderedDict() # do not append startup op if var is not on this pserver op_on_pserver = False for key in op.output_names: @@ -703,7 +705,7 @@ class DistributeTranspiler(object): self.origin_program, grad_blocks, add_trainer_suffix=self.trainer_num > 1) - self.grad_param_mapping = dict() + self.grad_param_mapping = collections.OrderedDict() for g, p in zip(grad_blocks, param_blocks): g_name, g_bid, _ = g.split(":") p_name, p_bid, _ = p.split(":") @@ -711,7 +713,7 @@ class DistributeTranspiler(object): self.param_var_mapping[p_name][int(p_bid)] # create mapping of endpoint -> split var to create pserver side program - self.param_grad_ep_mapping = dict() + self.param_grad_ep_mapping = collections.OrderedDict() [ self.param_grad_ep_mapping.update({ ep: { @@ -981,14 +983,14 @@ class DistributeTranspiler(object): block_list (list[(varname, block_id, block_size)]): List of gradient blocks. add_trainer_suffix (Bool): Add trainer suffix to new variable's name if set True. Returns: - var_mapping (dict(varname->[new_varname_variable])):A dict mapping + var_mapping (collections.OrderedDict(varname->[new_varname_variable])):A dict mapping from original var name to each var split. 
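    For illustration, a minimal sketch of the grouping that the method body
    performs just below (the variable names and block sizes are invented):

        import collections

        # toy entries in the "varname:offset:size" format carried by block_list
        block_list = ["fc_w:0:1000", "fc_w:1000:1000", "fc_b:0:10"]

        block_map = collections.OrderedDict()
        for block_str in block_list:
            varname, offset, size = block_str.split(":")
            block_map.setdefault(varname, []).append((int(offset), int(size)))

        # -> OrderedDict([('fc_w', [(0, 1000), (1000, 1000)]), ('fc_b', [(0, 10)])])

    Swapping dict for OrderedDict keeps the iteration order of these mappings
    deterministic across Python processes, which is the "unordered dict issue"
    named in the subject of this patch.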
""" # varname->[(block_id, current_block_size)] - block_map = dict() + block_map = collections.OrderedDict() - var_mapping = dict() + var_mapping = collections.OrderedDict() for block_str in block_list: varname, offset, size = block_str.split(":") if varname not in block_map: @@ -1181,7 +1183,7 @@ class DistributeTranspiler(object): grad_to_block_id, origin_program, merged_var): program = optimize_block.program pserver_block = program.global_block() - new_inputs = dict() + new_inputs = collections.OrderedDict() # update param/grad shape first, then other inputs like # moment can use the updated shape for key in opt_op.input_names: @@ -1359,7 +1361,7 @@ class DistributeTranspiler(object): def _get_input_map_from_op(self, varmap, op): """Returns a dict from op input name to the vars in varmap.""" - iomap = dict() + iomap = collections.OrderedDict() for key in op.input_names: vars = [] for varname in op.input(key): @@ -1372,7 +1374,7 @@ class DistributeTranspiler(object): def _get_output_map_from_op(self, varmap, op): """Returns a dict from op output name to the vars in varmap.""" - iomap = dict() + iomap = collections.OrderedDict() for key in op.output_names: vars = [] for varname in op.output(key): From 6dc07e7f95ec44d348a1a597203edb397c1460ce Mon Sep 17 00:00:00 2001 From: minqiyang Date: Fri, 10 Aug 2018 13:50:20 +0800 Subject: [PATCH 32/94] Replace items() with six.moves.iteritems() to improve memory usage --- python/paddle/dataset/imdb.py | 7 +++--- python/paddle/dataset/imikolov.py | 10 ++++---- python/paddle/dataset/sentiment.py | 3 ++- python/paddle/dataset/wmt14.py | 4 ++-- python/paddle/dataset/wmt16.py | 2 +- python/paddle/fluid/backward.py | 10 ++++---- python/paddle/fluid/framework.py | 2 +- python/paddle/fluid/graphviz.py | 9 ++++---- python/paddle/fluid/layers/control_flow.py | 3 ++- python/paddle/fluid/metrics.py | 13 ++++++----- .../paddle/fluid/tests/unittests/benchmark.py | 4 ++-- .../tests/unittests/test_detection_map_op.py | 3 ++- .../tests/unittests/test_lod_rank_table.py | 3 ++- .../test_positive_negative_pair_op.py | 3 ++- python/paddle/fluid/trainer.py | 5 ++-- .../fluid/transpiler/distribute_transpiler.py | 23 +++++++++---------- 16 files changed, 57 insertions(+), 47 deletions(-) diff --git a/python/paddle/dataset/imdb.py b/python/paddle/dataset/imdb.py index 60a9062c46..7c915062c3 100644 --- a/python/paddle/dataset/imdb.py +++ b/python/paddle/dataset/imdb.py @@ -47,8 +47,9 @@ def tokenize(pattern): while tf != None: if bool(pattern.match(tf.name)): # newline and punctuations removal and ad-hoc tokenization. - yield tarf.extractfile(tf).read().rstrip(six.b("\n\r")).translate( - None, six.b(string.punctuation)).lower().split() + yield tarf.extractfile(tf).read().rstrip(six.b( + "\n\r")).translate( + None, six.b(string.punctuation)).lower().split() tf = tarf.next() @@ -63,7 +64,7 @@ def build_dict(pattern, cutoff): word_freq[word] += 1 # Not sure if we should prune less-frequent words here. - word_freq = [x for x in list(word_freq.items()) if x[1] > cutoff] + word_freq = [x for x in six.moves.iteritems(word_freq) if x[1] > cutoff] dictionary = sorted(word_freq, key=lambda x: (-x[1], x[0])) words, _ = list(zip(*dictionary)) diff --git a/python/paddle/dataset/imikolov.py b/python/paddle/dataset/imikolov.py index bfb087ff38..bfa287ce80 100644 --- a/python/paddle/dataset/imikolov.py +++ b/python/paddle/dataset/imikolov.py @@ -21,7 +21,7 @@ into paddle reader creators. 
import paddle.dataset.common import collections import tarfile -from six.moves import range +import six __all__ = ['train', 'test', 'build_dict', 'convert'] @@ -65,11 +65,13 @@ def build_dict(min_word_freq=50): # remove for now, since we will set it as last index del word_freq[''] - word_freq = [x for x in list(word_freq.items()) if x[1] > min_word_freq] + word_freq = [ + x for x in six.moves.iteritems(word_freq) if x[1] > min_word_freq + ] word_freq_sorted = sorted(word_freq, key=lambda x: (-x[1], x[0])) words, _ = list(zip(*word_freq_sorted)) - word_idx = dict(list(zip(words, range(len(words))))) + word_idx = dict(list(zip(words, six.moves.range(len(words))))) word_idx[''] = len(words) return word_idx @@ -90,7 +92,7 @@ def reader_creator(filename, word_idx, n, data_type): l = [''] + l.strip().split() + [''] if len(l) >= n: l = [word_idx.get(w, UNK) for w in l] - for i in range(n, len(l) + 1): + for i in six.moves.range(n, len(l) + 1): yield tuple(l[i - n:i]) elif DataType.SEQ == data_type: l = l.strip().split() diff --git a/python/paddle/dataset/sentiment.py b/python/paddle/dataset/sentiment.py index 953ada057b..078ba74bef 100644 --- a/python/paddle/dataset/sentiment.py +++ b/python/paddle/dataset/sentiment.py @@ -20,6 +20,7 @@ The script fetch and preprocess movie_reviews data set that provided by NLTK TODO(yuyang18): Complete dataset. """ +import six import collections from itertools import chain @@ -64,7 +65,7 @@ def get_word_dict(): for field in movie_reviews.fileids(category): for words in movie_reviews.words(field): word_freq_dict[words] += 1 - words_sort_list = list(word_freq_dict.items()) + words_sort_list = six.moves.iteritems(word_freq_dict) words_sort_list.sort(cmp=lambda a, b: b[1] - a[1]) for index, word in enumerate(words_sort_list): words_freq_sorted.append((word[0], index)) diff --git a/python/paddle/dataset/wmt14.py b/python/paddle/dataset/wmt14.py index 7488e21f1f..3c413c71c6 100644 --- a/python/paddle/dataset/wmt14.py +++ b/python/paddle/dataset/wmt14.py @@ -156,8 +156,8 @@ def get_dict(dict_size, reverse=True): tar_file = paddle.dataset.common.download(URL_TRAIN, 'wmt14', MD5_TRAIN) src_dict, trg_dict = __read_to_dict(tar_file, dict_size) if reverse: - src_dict = {v: k for k, v in list(src_dict.items())} - trg_dict = {v: k for k, v in list(trg_dict.items())} + src_dict = {v: k for k, v in six.moves.iteritems(src_dict)} + trg_dict = {v: k for k, v in six.moves.iteritems(trg_dict)} return src_dict, trg_dict diff --git a/python/paddle/dataset/wmt16.py b/python/paddle/dataset/wmt16.py index 3e453a6479..e59fa531d1 100644 --- a/python/paddle/dataset/wmt16.py +++ b/python/paddle/dataset/wmt16.py @@ -72,7 +72,7 @@ def __build_dict(tar_file, dict_size, save_path, lang): fout.write("%s\n%s\n%s\n" % (START_MARK, END_MARK, UNK_MARK)) for idx, word in enumerate( sorted( - iter(list(word_dict.items())), + six.moves.iteritems(word_dict), key=lambda x: x[1], reverse=True)): if idx + 3 == dict_size: break diff --git a/python/paddle/fluid/backward.py b/python/paddle/fluid/backward.py index 804608827b..07aea2d8e6 100644 --- a/python/paddle/fluid/backward.py +++ b/python/paddle/fluid/backward.py @@ -46,13 +46,13 @@ def _create_op_desc_(op_type, inputs, outputs, attrs): """ op_desc = core.OpDesc() op_desc.set_type(op_type) - for para, args in list(inputs.items()): + for para, args in six.moves.iteritems(inputs): op_desc.set_input( para, list( map(lambda arg: arg.decode() if isinstance(arg, six.binary_type) else arg, args))) - for para, args in list(outputs.items()): + for para, args in 
six.moves.iteritems(outputs): op_desc.set_output( para, list( @@ -64,7 +64,7 @@ def _create_op_desc_(op_type, inputs, outputs, attrs): if op_role_attr_name not in attrs: attrs[ op_role_attr_name] = core.op_proto_and_checker_maker.OpRole.Backward - for name, val in list(attrs.items()): + for name, val in six.moves.iteritems(attrs): if isinstance(val, framework.Block): op_desc.set_block_attr(name, val.desc) else: @@ -187,7 +187,7 @@ def _addup_repetitive_outputs_(op_descs): op_desc.set_output(param_name, arg_names) renamed_vars[var_name].append(new_name) - for var_name, inputs in list(renamed_vars.items()): + for var_name, inputs in six.moves.iteritems(renamed_vars): if len(inputs) > 1: pending_sum_ops.append( (_create_op_desc_("sum", {"X": inputs}, {"Out": [var_name]}, @@ -445,7 +445,7 @@ def _rename_grad_(block, start_op_idx, grad_to_var, target_grad_map): op_desc.rename_output(name, new_name) var_map[name] = new_name - for g, ng in list(var_map.items()): + for g, ng in six.moves.iteritems(var_map): if g in grad_to_var: grad_to_var[ng] = grad_to_var[g] grad_to_var.pop(g) diff --git a/python/paddle/fluid/framework.py b/python/paddle/fluid/framework.py index aec06febcc..0156be9045 100644 --- a/python/paddle/fluid/framework.py +++ b/python/paddle/fluid/framework.py @@ -958,7 +958,7 @@ class Block(object): return list(self.iter_parameters()) def iter_parameters(self): - return (item[1] for item in list(self.vars.items()) + return (item[1] for item in six.moves.iteritems(self.vars) if isinstance(item[1], Parameter)) def create_var(self, *args, **kwargs): diff --git a/python/paddle/fluid/graphviz.py b/python/paddle/fluid/graphviz.py index 5e823418bd..966f58a977 100644 --- a/python/paddle/fluid/graphviz.py +++ b/python/paddle/fluid/graphviz.py @@ -106,7 +106,7 @@ class Graph(object): def _rank_repr(self): ranks = sorted( - list(self.rank_groups.items()), + six.moves.iteritems(self.rank_groups), key=functools.cmp_to_key( lambda a, b: a[1].priority > b[1].priority)) repr = [] @@ -150,8 +150,9 @@ class Node(object): reprs = '{name} [label={label} {extra} ];'.format( name=self.name, label=self.label, - extra=',' + ','.join("%s=%s" % (key, crepr(value)) - for key, value in list(self.attrs.items())) + extra=',' + ','.join( + "%s=%s" % (key, crepr(value)) + for key, value in six.moves.iteritems(self.attrs)) if self.attrs else "") return reprs @@ -175,7 +176,7 @@ class Edge(object): target=self.target.name, extra="" if not self.attrs else "[" + ','.join("{}={}".format(attr[0], crepr(attr[1])) - for attr in list(self.attrs.items())) + "]") + for attr in six.moves.iteritems(self.attrs)) + "]") return repr diff --git a/python/paddle/fluid/layers/control_flow.py b/python/paddle/fluid/layers/control_flow.py index 9fb7b4d0ca..2fc0961699 100644 --- a/python/paddle/fluid/layers/control_flow.py +++ b/python/paddle/fluid/layers/control_flow.py @@ -22,6 +22,7 @@ from ..initializer import force_init_on_cpu from .ops import logical_and, logical_not, logical_or import numpy import warnings +import six from functools import reduce __all__ = [ @@ -602,7 +603,7 @@ class StaticRNN(object): boot_memories = [] pre_memories = [] memories = [] - for _, mem in list(self.memories.items()): + for _, mem in six.moves.iteritems(self.memories): boot_memories.append(mem.init) pre_memories.append(mem.pre_mem.name) mem_var = rnn_block.var(mem.mem.name) diff --git a/python/paddle/fluid/metrics.py b/python/paddle/fluid/metrics.py index cd89345227..3bfcfd8585 100644 --- a/python/paddle/fluid/metrics.py +++ 
b/python/paddle/fluid/metrics.py @@ -14,11 +14,12 @@ """ Fluid Metrics -The metrics are accomplished via Python natively. +The metrics are accomplished via Python natively. """ import numpy as np import copy import warnings +import six __all__ = [ 'MetricBase', @@ -79,10 +80,10 @@ class MetricBase(object): """ states = { attr: value - for attr, value in list(self.__dict__.items()) + for attr, value in six.moves.iteritems(self.__dict__) if not attr.startswith("_") } - for attr, value in list(states.items()): + for attr, value in six.moves.iteritems(states): if isinstance(value, int): setattr(self, attr, 0) elif isinstance(value, float): @@ -105,7 +106,7 @@ class MetricBase(object): """ states = { attr: value - for attr, value in list(self.__dict__.items()) + for attr, value in six.moves.iteritems(self.__dict__) if not attr.startswith("_") } config = {} @@ -141,10 +142,10 @@ class CompositeMetric(MetricBase): """ Composite multiple metrics in one instance. for example, merge F1, accuracy, recall into one Metric. - + Examples: .. code-block:: python - + labels = fluid.layers.data(name="data", shape=[1], dtype="int32") data = fluid.layers.data(name="data", shape=[32, 32], dtype="int32") pred = fluid.layers.fc(input=data, size=1000, act="tanh") diff --git a/python/paddle/fluid/tests/unittests/benchmark.py b/python/paddle/fluid/tests/unittests/benchmark.py index b98a92dcbe..0dbde89fcd 100644 --- a/python/paddle/fluid/tests/unittests/benchmark.py +++ b/python/paddle/fluid/tests/unittests/benchmark.py @@ -54,7 +54,7 @@ class BenchmarkSuite(OpTest): def _get_input_names(self): inputs = [] - for name, value in list(self.inputs.items()): + for name, value in six.moves.iteritems(self.inputs): if isinstance(value, list): inputs.extend([sub_name for sub_name, _ in value]) inputs.append(name) @@ -62,7 +62,7 @@ class BenchmarkSuite(OpTest): def _get_output_names(self): outputs = [] - for var_name, var in list(self.outputs.items()): + for var_name, var in six.moves.iteritems(self.outputs): if isinstance(var, list): for sub_var_name, sub_var in var: outputs.append(sub_var_name) diff --git a/python/paddle/fluid/tests/unittests/test_detection_map_op.py b/python/paddle/fluid/tests/unittests/test_detection_map_op.py index 8b66d1b270..07c77c03cb 100644 --- a/python/paddle/fluid/tests/unittests/test_detection_map_op.py +++ b/python/paddle/fluid/tests/unittests/test_detection_map_op.py @@ -14,6 +14,7 @@ import unittest import numpy as np +import six import sys import collections import math @@ -176,7 +177,7 @@ class TestDetectionMAPOp(OpTest): true_pos[label].append([score, tp]) false_pos[label].append([score, fp]) - for (label, label_pos_num) in list(label_count.items()): + for (label, label_pos_num) in six.moves.iteritems(label_count): if label_pos_num == 0 or label not in true_pos: continue label_true_pos = true_pos[label] label_false_pos = false_pos[label] diff --git a/python/paddle/fluid/tests/unittests/test_lod_rank_table.py b/python/paddle/fluid/tests/unittests/test_lod_rank_table.py index d53ead381d..180e9c4cb3 100644 --- a/python/paddle/fluid/tests/unittests/test_lod_rank_table.py +++ b/python/paddle/fluid/tests/unittests/test_lod_rank_table.py @@ -18,6 +18,7 @@ from paddle.fluid.executor import Executor import paddle.fluid.core as core import numpy import unittest +import six class TestLoDRankTable(unittest.TestCase): @@ -36,7 +37,7 @@ class TestLoDRankTable(unittest.TestCase): exe.run(scope=scope, feed={'x': tensor}) var = scope.find_var(rank_table.name) table = var.get_lod_rank_table() - 
self.assertEqual([(0, 5), (1, 1), (2, 1)], list(table.items())) + self.assertEqual([(0, 5), (1, 1), (2, 1)], six.moves.iteritems(table)) if __name__ == '__main__': diff --git a/python/paddle/fluid/tests/unittests/test_positive_negative_pair_op.py b/python/paddle/fluid/tests/unittests/test_positive_negative_pair_op.py index 8c76393bda..0e525cbff8 100644 --- a/python/paddle/fluid/tests/unittests/test_positive_negative_pair_op.py +++ b/python/paddle/fluid/tests/unittests/test_positive_negative_pair_op.py @@ -15,6 +15,7 @@ import unittest import itertools import numpy as np +import six from op_test import OpTest @@ -32,7 +33,7 @@ def py_pnpair_op(score, label, query, column=-1, weight=None): # accumulate statistics pos, neg, neu = 0, 0, 0 - for _, ranks in list(predictions.items()): + for _, ranks in six.moves.iteritems(predictions): for e1, e2 in itertools.combinations(ranks, 2): s1, s2, l1, l2, w1, w2 = e1[0], e2[0], e1[1], e2[1], e1[2], e2[2] w = (w1 + w2) * 0.5 diff --git a/python/paddle/fluid/trainer.py b/python/paddle/fluid/trainer.py index eed9b49ef4..ac4db36de4 100644 --- a/python/paddle/fluid/trainer.py +++ b/python/paddle/fluid/trainer.py @@ -16,6 +16,7 @@ import contextlib import os import errno import shutil +import six import time from . import core @@ -618,7 +619,7 @@ def build_feed_var_list(program, feed_order): "The values of 'feed_order' should be a permutation of [0, len(feed_order))" ) sorted_pair_list = sorted( - list(feed_order.items()), key=lambda item: item[1]) + six.moves.iteritems(feed_order), key=lambda item: item[1]) feed_var_list = [ program.global_block().var(pair[0]) for pair in sorted_pair_list ] @@ -1036,7 +1037,7 @@ def _save_trainer_args(dirname, trainer_id, trainer_args): cur_dir = _get_trainer_dir(dirname, trainer_id) - for name, value in list(trainer_args.items()): + for name, value in six.moves.iteritems(trainer_args): args_file = os.path.join(cur_dir, name) with open(args_file, 'w') as f: f.write(str(value)) diff --git a/python/paddle/fluid/transpiler/distribute_transpiler.py b/python/paddle/fluid/transpiler/distribute_transpiler.py index 252afc058b..8d863a7856 100644 --- a/python/paddle/fluid/transpiler/distribute_transpiler.py +++ b/python/paddle/fluid/transpiler/distribute_transpiler.py @@ -218,7 +218,8 @@ class DistributeTranspiler(object): # fc_w@GRAD_trainer_0, fc_w@GRAD_trainer_1 --> pserver1 # fc_b@GRAD_trainer_0, fc_b@GRAD_trainer_1 --> pserver2 # shuffle the map will avoid the uneven distribution above - grad_var_mapping_items = list(self.grad_var_mapping.items()) + grad_var_mapping_items = list( + six.moves.iteritems(self.grad_var_mapping)) if not self.config.slice_var_up: random.seed(self.origin_program.random_seed) @@ -279,7 +280,7 @@ class DistributeTranspiler(object): self.param_grad_ep_mapping[ep]["grads"].append(send_vars[i]) # step4: Concat the parameters splits together after recv. - for varname, splited_var in list(self.param_var_mapping.items()): + for varname, splited_var in six.moves.iteritems(self.param_var_mapping): eps = [] for var in splited_var: index = [v.name for v in recv_vars].index(var.name) @@ -303,7 +304,7 @@ class DistributeTranspiler(object): RPC_OP_ROLE_ATTR_NAME: RPC_OP_ROLE_ATTR_VALUE }) - for varname, splited_var in list(self.param_var_mapping.items()): + for varname, splited_var in six.moves.iteritems(self.param_var_mapping): if len(splited_var) <= 1: continue orig_param = program.global_block().vars[varname] @@ -560,7 +561,7 @@ class DistributeTranspiler(object): # 1. 
create vars in pserver program to startup program pserver_vars = pserver_program.global_block().vars created_var_map = collections.OrderedDict() - for _, var in list(pserver_vars.items()): + for _, var in six.moves.iteritems(pserver_vars): tmpvar = s_prog.global_block()._clone_variable(var) created_var_map[var.name] = tmpvar @@ -997,7 +998,7 @@ class DistributeTranspiler(object): block_map[varname] = [] block_map[varname].append((int(offset), int(size))) - for varname, splited in list(block_map.items()): + for varname, splited in six.moves.iteritems(block_map): orig_var = program.global_block().var(varname) if len(splited) == 1: if self.sync_mode and add_trainer_suffix: @@ -1248,9 +1249,7 @@ class DistributeTranspiler(object): def _is_splited_grad_var(self, var, var_dict): grad_block = None - # TODO(minqiyang): replace these items() with six.iteritems() to - # improve memory - for _, g in list(var_dict.items()): + for _, g in six.moves.iteritems(var_dict): if self._orig_varname(g.name) == self._orig_varname(var.name): if g.name.find(".trainer_") == -1: grad_block = g @@ -1260,7 +1259,7 @@ class DistributeTranspiler(object): def _clone_lr_op(self, program, block, op): inputs = self._get_input_map_from_op( self.origin_program.global_block().vars, op) - for key, varlist in list(inputs.items()): + for key, varlist in six.moves.iteritems(inputs): if not isinstance(varlist, list): varlist = [varlist] for var in varlist: @@ -1269,7 +1268,7 @@ class DistributeTranspiler(object): outputs = self._get_output_map_from_op( self.origin_program.global_block().vars, op) - for key, varlist in list(outputs.items()): + for key, varlist in six.moves.iteritems(outputs): if not isinstance(varlist, list): varlist = [varlist] for var in varlist: @@ -1284,7 +1283,7 @@ class DistributeTranspiler(object): # Append the ops for parameters that do not need to be optimized/updated inputs = self._get_input_map_from_op( self.origin_program.global_block().vars, opt_op) - for key, varlist in list(inputs.items()): + for key, varlist in six.moves.iteritems(inputs): if not isinstance(varlist, list): varlist = [varlist] for var in varlist: @@ -1303,7 +1302,7 @@ class DistributeTranspiler(object): outputs = self._get_output_map_from_op( self.origin_program.global_block().vars, opt_op) - for key, varlist in list(outputs.items()): + for key, varlist in six.moves.iteritems(outputs): if not isinstance(varlist, list): varlist = [varlist] for var in varlist: From 76ee482e18ef8b6a319a5203c45e4db8bc5bb0c5 Mon Sep 17 00:00:00 2001 From: minqiyang Date: Fri, 10 Aug 2018 17:00:21 +0800 Subject: [PATCH 33/94] Fix cv2 issues --- python/paddle/dataset/image.py | 19 +++++++++++++++---- .../paddle/fluid/tests/unittests/op_test.py | 5 +++++ python/paddle/reader/decorator.py | 5 +++-- 3 files changed, 23 insertions(+), 6 deletions(-) diff --git a/python/paddle/dataset/image.py b/python/paddle/dataset/image.py index f7e7c854fe..99d2c5f899 100644 --- a/python/paddle/dataset/image.py +++ b/python/paddle/dataset/image.py @@ -33,6 +33,11 @@ import numpy as np try: import cv2 except ImportError: + import sys + sys.stderr.write( + '''Warning with paddle image module: opencv-python should be imported, + or paddle image module could NOT work; please install opencv-python first.''' + ) cv2 = None import os import tarfile @@ -126,6 +131,8 @@ def load_image_bytes(bytes, is_color=True): load and return a gray image. 
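This file now guards each loader with assert cv2 is not None, since the import above falls back to cv2 = None when opencv-python is missing (the same guard is added to load_image just below). A minimal sketch of the pattern (standalone; the cv2.imread() body is an assumption, as the hunk does not show it):

    import sys

    try:
        import cv2
    except ImportError:
        sys.stderr.write("paddle image module requires opencv-python; "
                         "install it before using the image helpers.\n")
        cv2 = None

    def load_image(path, is_color=True):
        # fail fast with a clear message when OpenCV is unavailable
        assert cv2 is not None, "opencv-python is required for load_image()"
        flag = 1 if is_color else 0  # 1: color, 0: grayscale
        return cv2.imread(path, flag)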
:type is_color: bool """ + assert cv2 is not None + flag = 1 if is_color else 0 file_bytes = np.asarray(bytearray(bytes), dtype=np.uint8) img = cv2.imdecode(file_bytes, flag) @@ -149,6 +156,8 @@ def load_image(file, is_color=True): load and return a gray image. :type is_color: bool """ + assert cv2 is not None + # cv2.IMAGE_COLOR for OpenCV3 # cv2.CV_LOAD_IMAGE_COLOR for older OpenCV Version # cv2.IMAGE_GRAYSCALE for OpenCV3 @@ -176,12 +185,14 @@ def resize_short(im, size): :param size: the shorter edge size of image after resizing. :type size: int """ + assert cv2 is not None + h, w = im.shape[:2] h_new, w_new = size, size if h > w: - h_new = size * h / w + h_new = size * h // w else: - w_new = size * w / h + w_new = size * w // h im = cv2.resize(im, (h_new, w_new), interpolation=cv2.INTER_CUBIC) return im @@ -228,8 +239,8 @@ def center_crop(im, size, is_color=True): :type is_color: bool """ h, w = im.shape[:2] - h_start = (h - size) / 2 - w_start = (w - size) / 2 + h_start = (h - size) // 2 + w_start = (w - size) // 2 h_end, w_end = h_start + size, w_start + size if is_color: im = im[h_start:h_end, w_start:w_end, :] diff --git a/python/paddle/fluid/tests/unittests/op_test.py b/python/paddle/fluid/tests/unittests/op_test.py index 1ed14e35b1..ada4ad70f0 100644 --- a/python/paddle/fluid/tests/unittests/op_test.py +++ b/python/paddle/fluid/tests/unittests/op_test.py @@ -362,9 +362,14 @@ class OpTest(unittest.TestCase): def check_output_customized(self, checker): places = self._get_places() + import sys + print('places', places) for place in places: outs = self.calc_output(place) outs = [np.array(out) for out in outs] + import sys + print('outs', outs) + sys.stdout.flush() checker(outs) def __assert_is_close(self, numeric_grads, analytic_grads, names, diff --git a/python/paddle/reader/decorator.py b/python/paddle/reader/decorator.py index ce410e61b9..d53694959b 100644 --- a/python/paddle/reader/decorator.py +++ b/python/paddle/reader/decorator.py @@ -27,6 +27,7 @@ from six.moves import zip import itertools import random import zlib +import paddle.fluid.compat as cpt def map_readers(func, *readers): @@ -390,9 +391,9 @@ class PipeReader: buff = self.process.stdout.read(self.bufsize) if buff: if self.file_type == "gzip": - decomp_buff = self.dec.decompress(buff) + decomp_buff = cpt.to_literal_str(self.dec.decompress(buff)) elif self.file_type == "plain": - decomp_buff = buff + decomp_buff = cpt.to_literal_str(buff) else: raise TypeError("file_type %s is not allowed" % self.file_type) From e4e9450e88a5851cb5bbff12860def544127f850 Mon Sep 17 00:00:00 2001 From: minqiyang Date: Fri, 10 Aug 2018 21:15:10 +0800 Subject: [PATCH 34/94] Fix random crop op problem --- python/paddle/fluid/tests/unittests/test_random_crop_op.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/python/paddle/fluid/tests/unittests/test_random_crop_op.py b/python/paddle/fluid/tests/unittests/test_random_crop_op.py index 1c708d0386..1acd377b1f 100644 --- a/python/paddle/fluid/tests/unittests/test_random_crop_op.py +++ b/python/paddle/fluid/tests/unittests/test_random_crop_op.py @@ -21,7 +21,7 @@ from op_test import OpTest class TestRandomCropOp(OpTest): def setUp(self): to_crop = np.array([[[1, 2, 3, 4], [5, 6, 7, 8], [9, 10, 11, 12]]] * - 5).astype("float32") + 5).astype(np.int32) self.possible_res = [ np.array([[1, 2, 3], [5, 6, 7]]), np.array([[2, 3, 4], [6, 7, 8]]), np.array([[5, 6, 7], [9, 10, 11]]), From 5d4238cdccbe914b102d336c5427e865cbf7b7d7 Mon Sep 17 00:00:00 2001 From: minqiyang Date: Fri, 10 Aug 
2018 21:22:36 +0800 Subject: [PATCH 35/94] Fix six.iteritems problem --- python/paddle/dataset/imdb.py | 2 +- python/paddle/dataset/imikolov.py | 2 +- python/paddle/dataset/sentiment.py | 2 +- python/paddle/dataset/wmt14.py | 4 ++-- python/paddle/dataset/wmt16.py | 3 +-- python/paddle/fluid/backward.py | 10 ++++----- python/paddle/fluid/framework.py | 2 +- python/paddle/fluid/graphviz.py | 9 ++++---- python/paddle/fluid/layers/control_flow.py | 2 +- python/paddle/fluid/metrics.py | 6 +++--- .../paddle/fluid/tests/unittests/benchmark.py | 4 ++-- .../tests/unittests/test_detection_map_op.py | 2 +- .../tests/unittests/test_lod_rank_table.py | 2 +- .../test_positive_negative_pair_op.py | 2 +- python/paddle/fluid/trainer.py | 4 ++-- .../fluid/transpiler/distribute_transpiler.py | 21 +++++++++---------- 16 files changed, 37 insertions(+), 40 deletions(-) diff --git a/python/paddle/dataset/imdb.py b/python/paddle/dataset/imdb.py index 7c915062c3..903e93d34f 100644 --- a/python/paddle/dataset/imdb.py +++ b/python/paddle/dataset/imdb.py @@ -64,7 +64,7 @@ def build_dict(pattern, cutoff): word_freq[word] += 1 # Not sure if we should prune less-frequent words here. - word_freq = [x for x in six.moves.iteritems(word_freq) if x[1] > cutoff] + word_freq = [x for x in six.iteritems(word_freq) if x[1] > cutoff] dictionary = sorted(word_freq, key=lambda x: (-x[1], x[0])) words, _ = list(zip(*dictionary)) diff --git a/python/paddle/dataset/imikolov.py b/python/paddle/dataset/imikolov.py index bfa287ce80..422eaef644 100644 --- a/python/paddle/dataset/imikolov.py +++ b/python/paddle/dataset/imikolov.py @@ -66,7 +66,7 @@ def build_dict(min_word_freq=50): del word_freq[''] word_freq = [ - x for x in six.moves.iteritems(word_freq) if x[1] > min_word_freq + x for x in six.iteritems(word_freq) if x[1] > min_word_freq ] word_freq_sorted = sorted(word_freq, key=lambda x: (-x[1], x[0])) diff --git a/python/paddle/dataset/sentiment.py b/python/paddle/dataset/sentiment.py index 078ba74bef..25cd59df92 100644 --- a/python/paddle/dataset/sentiment.py +++ b/python/paddle/dataset/sentiment.py @@ -65,7 +65,7 @@ def get_word_dict(): for field in movie_reviews.fileids(category): for words in movie_reviews.words(field): word_freq_dict[words] += 1 - words_sort_list = six.moves.iteritems(word_freq_dict) + words_sort_list = six.iteritems(word_freq_dict) words_sort_list.sort(cmp=lambda a, b: b[1] - a[1]) for index, word in enumerate(words_sort_list): words_freq_sorted.append((word[0], index)) diff --git a/python/paddle/dataset/wmt14.py b/python/paddle/dataset/wmt14.py index 3c413c71c6..75363a30f3 100644 --- a/python/paddle/dataset/wmt14.py +++ b/python/paddle/dataset/wmt14.py @@ -156,8 +156,8 @@ def get_dict(dict_size, reverse=True): tar_file = paddle.dataset.common.download(URL_TRAIN, 'wmt14', MD5_TRAIN) src_dict, trg_dict = __read_to_dict(tar_file, dict_size) if reverse: - src_dict = {v: k for k, v in six.moves.iteritems(src_dict)} - trg_dict = {v: k for k, v in six.moves.iteritems(trg_dict)} + src_dict = {v: k for k, v in six.iteritems(src_dict)} + trg_dict = {v: k for k, v in six.iteritems(trg_dict)} return src_dict, trg_dict diff --git a/python/paddle/dataset/wmt16.py b/python/paddle/dataset/wmt16.py index e59fa531d1..c5772e1f19 100644 --- a/python/paddle/dataset/wmt16.py +++ b/python/paddle/dataset/wmt16.py @@ -72,8 +72,7 @@ def __build_dict(tar_file, dict_size, save_path, lang): fout.write("%s\n%s\n%s\n" % (START_MARK, END_MARK, UNK_MARK)) for idx, word in enumerate( sorted( - six.moves.iteritems(word_dict), - key=lambda x: 
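This patch moves every call from `six.moves.iteritems` to `six.iteritems`: the helper lives at the top level of six (`six.moves` only hosts relocated stdlib modules), and it dispatches to `dict.iteritems()` on Python 2 and `dict.items()` on Python 3. In isolation:

    import six

    word_freq = {'the': 120, 'a': 87, 'of': 64}
    # Lazy iteration on Python 2, a view iterator on Python 3.
    filtered = [x for x in six.iteritems(word_freq) if x[1] > 80]
    print(sorted(filtered, key=lambda x: (-x[1], x[0])))
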
x[1], + six.iteritems(word_dict), key=lambda x: x[1], reverse=True)): if idx + 3 == dict_size: break fout.write("%s\n" % (word[0])) diff --git a/python/paddle/fluid/backward.py b/python/paddle/fluid/backward.py index 07aea2d8e6..f51acdac6e 100644 --- a/python/paddle/fluid/backward.py +++ b/python/paddle/fluid/backward.py @@ -46,13 +46,13 @@ def _create_op_desc_(op_type, inputs, outputs, attrs): """ op_desc = core.OpDesc() op_desc.set_type(op_type) - for para, args in six.moves.iteritems(inputs): + for para, args in six.iteritems(inputs): op_desc.set_input( para, list( map(lambda arg: arg.decode() if isinstance(arg, six.binary_type) else arg, args))) - for para, args in six.moves.iteritems(outputs): + for para, args in six.iteritems(outputs): op_desc.set_output( para, list( @@ -64,7 +64,7 @@ def _create_op_desc_(op_type, inputs, outputs, attrs): if op_role_attr_name not in attrs: attrs[ op_role_attr_name] = core.op_proto_and_checker_maker.OpRole.Backward - for name, val in six.moves.iteritems(attrs): + for name, val in six.iteritems(attrs): if isinstance(val, framework.Block): op_desc.set_block_attr(name, val.desc) else: @@ -187,7 +187,7 @@ def _addup_repetitive_outputs_(op_descs): op_desc.set_output(param_name, arg_names) renamed_vars[var_name].append(new_name) - for var_name, inputs in six.moves.iteritems(renamed_vars): + for var_name, inputs in six.iteritems(renamed_vars): if len(inputs) > 1: pending_sum_ops.append( (_create_op_desc_("sum", {"X": inputs}, {"Out": [var_name]}, @@ -445,7 +445,7 @@ def _rename_grad_(block, start_op_idx, grad_to_var, target_grad_map): op_desc.rename_output(name, new_name) var_map[name] = new_name - for g, ng in six.moves.iteritems(var_map): + for g, ng in six.iteritems(var_map): if g in grad_to_var: grad_to_var[ng] = grad_to_var[g] grad_to_var.pop(g) diff --git a/python/paddle/fluid/framework.py b/python/paddle/fluid/framework.py index 0156be9045..304df414f6 100644 --- a/python/paddle/fluid/framework.py +++ b/python/paddle/fluid/framework.py @@ -958,7 +958,7 @@ class Block(object): return list(self.iter_parameters()) def iter_parameters(self): - return (item[1] for item in six.moves.iteritems(self.vars) + return (item[1] for item in six.iteritems(self.vars) if isinstance(item[1], Parameter)) def create_var(self, *args, **kwargs): diff --git a/python/paddle/fluid/graphviz.py b/python/paddle/fluid/graphviz.py index 966f58a977..27d4a7d8dc 100644 --- a/python/paddle/fluid/graphviz.py +++ b/python/paddle/fluid/graphviz.py @@ -106,7 +106,7 @@ class Graph(object): def _rank_repr(self): ranks = sorted( - six.moves.iteritems(self.rank_groups), + six.iteritems(self.rank_groups), key=functools.cmp_to_key( lambda a, b: a[1].priority > b[1].priority)) repr = [] @@ -150,9 +150,8 @@ class Node(object): reprs = '{name} [label={label} {extra} ];'.format( name=self.name, label=self.label, - extra=',' + ','.join( - "%s=%s" % (key, crepr(value)) - for key, value in six.moves.iteritems(self.attrs)) + extra=',' + ','.join("%s=%s" % (key, crepr(value)) + for key, value in six.iteritems(self.attrs)) if self.attrs else "") return reprs @@ -176,7 +175,7 @@ class Edge(object): target=self.target.name, extra="" if not self.attrs else "[" + ','.join("{}={}".format(attr[0], crepr(attr[1])) - for attr in six.moves.iteritems(self.attrs)) + "]") + for attr in six.iteritems(self.attrs)) + "]") return repr diff --git a/python/paddle/fluid/layers/control_flow.py b/python/paddle/fluid/layers/control_flow.py index 2fc0961699..730075a1ec 100644 --- a/python/paddle/fluid/layers/control_flow.py +++ 
b/python/paddle/fluid/layers/control_flow.py @@ -603,7 +603,7 @@ class StaticRNN(object): boot_memories = [] pre_memories = [] memories = [] - for _, mem in six.moves.iteritems(self.memories): + for _, mem in six.iteritems(self.memories): boot_memories.append(mem.init) pre_memories.append(mem.pre_mem.name) mem_var = rnn_block.var(mem.mem.name) diff --git a/python/paddle/fluid/metrics.py b/python/paddle/fluid/metrics.py index 3bfcfd8585..19df1e1dcb 100644 --- a/python/paddle/fluid/metrics.py +++ b/python/paddle/fluid/metrics.py @@ -80,10 +80,10 @@ class MetricBase(object): """ states = { attr: value - for attr, value in six.moves.iteritems(self.__dict__) + for attr, value in six.iteritems(self.__dict__) if not attr.startswith("_") } - for attr, value in six.moves.iteritems(states): + for attr, value in six.iteritems(states): if isinstance(value, int): setattr(self, attr, 0) elif isinstance(value, float): @@ -106,7 +106,7 @@ class MetricBase(object): """ states = { attr: value - for attr, value in six.moves.iteritems(self.__dict__) + for attr, value in six.iteritems(self.__dict__) if not attr.startswith("_") } config = {} diff --git a/python/paddle/fluid/tests/unittests/benchmark.py b/python/paddle/fluid/tests/unittests/benchmark.py index 0dbde89fcd..d334d8b60c 100644 --- a/python/paddle/fluid/tests/unittests/benchmark.py +++ b/python/paddle/fluid/tests/unittests/benchmark.py @@ -54,7 +54,7 @@ class BenchmarkSuite(OpTest): def _get_input_names(self): inputs = [] - for name, value in six.moves.iteritems(self.inputs): + for name, value in six.iteritems(self.inputs): if isinstance(value, list): inputs.extend([sub_name for sub_name, _ in value]) inputs.append(name) @@ -62,7 +62,7 @@ class BenchmarkSuite(OpTest): def _get_output_names(self): outputs = [] - for var_name, var in six.moves.iteritems(self.outputs): + for var_name, var in six.iteritems(self.outputs): if isinstance(var, list): for sub_var_name, sub_var in var: outputs.append(sub_var_name) diff --git a/python/paddle/fluid/tests/unittests/test_detection_map_op.py b/python/paddle/fluid/tests/unittests/test_detection_map_op.py index 07c77c03cb..a471f62852 100644 --- a/python/paddle/fluid/tests/unittests/test_detection_map_op.py +++ b/python/paddle/fluid/tests/unittests/test_detection_map_op.py @@ -177,7 +177,7 @@ class TestDetectionMAPOp(OpTest): true_pos[label].append([score, tp]) false_pos[label].append([score, fp]) - for (label, label_pos_num) in six.moves.iteritems(label_count): + for (label, label_pos_num) in six.iteritems(label_count): if label_pos_num == 0 or label not in true_pos: continue label_true_pos = true_pos[label] label_false_pos = false_pos[label] diff --git a/python/paddle/fluid/tests/unittests/test_lod_rank_table.py b/python/paddle/fluid/tests/unittests/test_lod_rank_table.py index 180e9c4cb3..ea57412660 100644 --- a/python/paddle/fluid/tests/unittests/test_lod_rank_table.py +++ b/python/paddle/fluid/tests/unittests/test_lod_rank_table.py @@ -37,7 +37,7 @@ class TestLoDRankTable(unittest.TestCase): exe.run(scope=scope, feed={'x': tensor}) var = scope.find_var(rank_table.name) table = var.get_lod_rank_table() - self.assertEqual([(0, 5), (1, 1), (2, 1)], six.moves.iteritems(table)) + self.assertEqual([(0, 5), (1, 1), (2, 1)], six.iteritems(table)) if __name__ == '__main__': diff --git a/python/paddle/fluid/tests/unittests/test_positive_negative_pair_op.py b/python/paddle/fluid/tests/unittests/test_positive_negative_pair_op.py index 0e525cbff8..fcb308ae2c 100644 --- 
a/python/paddle/fluid/tests/unittests/test_positive_negative_pair_op.py +++ b/python/paddle/fluid/tests/unittests/test_positive_negative_pair_op.py @@ -33,7 +33,7 @@ def py_pnpair_op(score, label, query, column=-1, weight=None): # accumulate statistics pos, neg, neu = 0, 0, 0 - for _, ranks in six.moves.iteritems(predictions): + for _, ranks in six.iteritems(predictions): for e1, e2 in itertools.combinations(ranks, 2): s1, s2, l1, l2, w1, w2 = e1[0], e2[0], e1[1], e2[1], e1[2], e2[2] w = (w1 + w2) * 0.5 diff --git a/python/paddle/fluid/trainer.py b/python/paddle/fluid/trainer.py index ac4db36de4..5d549e68d1 100644 --- a/python/paddle/fluid/trainer.py +++ b/python/paddle/fluid/trainer.py @@ -619,7 +619,7 @@ def build_feed_var_list(program, feed_order): "The values of 'feed_order' should be a permutation of [0, len(feed_order))" ) sorted_pair_list = sorted( - six.moves.iteritems(feed_order), key=lambda item: item[1]) + six.iteritems(feed_order), key=lambda item: item[1]) feed_var_list = [ program.global_block().var(pair[0]) for pair in sorted_pair_list ] @@ -1037,7 +1037,7 @@ def _save_trainer_args(dirname, trainer_id, trainer_args): cur_dir = _get_trainer_dir(dirname, trainer_id) - for name, value in six.moves.iteritems(trainer_args): + for name, value in six.iteritems(trainer_args): args_file = os.path.join(cur_dir, name) with open(args_file, 'w') as f: f.write(str(value)) diff --git a/python/paddle/fluid/transpiler/distribute_transpiler.py b/python/paddle/fluid/transpiler/distribute_transpiler.py index 8d863a7856..4effc7dfda 100644 --- a/python/paddle/fluid/transpiler/distribute_transpiler.py +++ b/python/paddle/fluid/transpiler/distribute_transpiler.py @@ -218,8 +218,7 @@ class DistributeTranspiler(object): # fc_w@GRAD_trainer_0, fc_w@GRAD_trainer_1 --> pserver1 # fc_b@GRAD_trainer_0, fc_b@GRAD_trainer_1 --> pserver2 # shuffle the map will avoid the uneven distribution above - grad_var_mapping_items = list( - six.moves.iteritems(self.grad_var_mapping)) + grad_var_mapping_items = list(six.iteritems(self.grad_var_mapping)) if not self.config.slice_var_up: random.seed(self.origin_program.random_seed) @@ -280,7 +279,7 @@ class DistributeTranspiler(object): self.param_grad_ep_mapping[ep]["grads"].append(send_vars[i]) # step4: Concat the parameters splits together after recv. - for varname, splited_var in six.moves.iteritems(self.param_var_mapping): + for varname, splited_var in six.iteritems(self.param_var_mapping): eps = [] for var in splited_var: index = [v.name for v in recv_vars].index(var.name) @@ -304,7 +303,7 @@ class DistributeTranspiler(object): RPC_OP_ROLE_ATTR_NAME: RPC_OP_ROLE_ATTR_VALUE }) - for varname, splited_var in six.moves.iteritems(self.param_var_mapping): + for varname, splited_var in six.iteritems(self.param_var_mapping): if len(splited_var) <= 1: continue orig_param = program.global_block().vars[varname] @@ -561,7 +560,7 @@ class DistributeTranspiler(object): # 1. 
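Several of these call sites sort the result of `iteritems`, which only works on Python 3 when a `key` is supplied, because the old `cmp` argument is gone. Where code genuinely needs a comparison function, `functools.cmp_to_key` is the portable bridge (the GRU unit test later in this series uses exactly this), e.g.:

    import functools

    seq_lens = [3, 1, 5, 2]
    # Descending by length, written as a cmp function and bridged to a key.
    order = sorted(
        range(len(seq_lens)),
        key=functools.cmp_to_key(lambda x, y: seq_lens[y] - seq_lens[x]))
    assert order == [2, 0, 3, 1]
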
create vars in pserver program to startup program pserver_vars = pserver_program.global_block().vars created_var_map = collections.OrderedDict() - for _, var in six.moves.iteritems(pserver_vars): + for _, var in six.iteritems(pserver_vars): tmpvar = s_prog.global_block()._clone_variable(var) created_var_map[var.name] = tmpvar @@ -998,7 +997,7 @@ class DistributeTranspiler(object): block_map[varname] = [] block_map[varname].append((int(offset), int(size))) - for varname, splited in six.moves.iteritems(block_map): + for varname, splited in six.iteritems(block_map): orig_var = program.global_block().var(varname) if len(splited) == 1: if self.sync_mode and add_trainer_suffix: @@ -1249,7 +1248,7 @@ class DistributeTranspiler(object): def _is_splited_grad_var(self, var, var_dict): grad_block = None - for _, g in six.moves.iteritems(var_dict): + for _, g in six.iteritems(var_dict): if self._orig_varname(g.name) == self._orig_varname(var.name): if g.name.find(".trainer_") == -1: grad_block = g @@ -1259,7 +1258,7 @@ class DistributeTranspiler(object): def _clone_lr_op(self, program, block, op): inputs = self._get_input_map_from_op( self.origin_program.global_block().vars, op) - for key, varlist in six.moves.iteritems(inputs): + for key, varlist in six.iteritems(inputs): if not isinstance(varlist, list): varlist = [varlist] for var in varlist: @@ -1268,7 +1267,7 @@ class DistributeTranspiler(object): outputs = self._get_output_map_from_op( self.origin_program.global_block().vars, op) - for key, varlist in six.moves.iteritems(outputs): + for key, varlist in six.iteritems(outputs): if not isinstance(varlist, list): varlist = [varlist] for var in varlist: @@ -1283,7 +1282,7 @@ class DistributeTranspiler(object): # Append the ops for parameters that do not need to be optimized/updated inputs = self._get_input_map_from_op( self.origin_program.global_block().vars, opt_op) - for key, varlist in six.moves.iteritems(inputs): + for key, varlist in six.iteritems(inputs): if not isinstance(varlist, list): varlist = [varlist] for var in varlist: @@ -1302,7 +1301,7 @@ class DistributeTranspiler(object): outputs = self._get_output_map_from_op( self.origin_program.global_block().vars, opt_op) - for key, varlist in six.moves.iteritems(outputs): + for key, varlist in six.iteritems(outputs): if not isinstance(varlist, list): varlist = [varlist] for var in varlist: From 23447cd5a045f21884e596a2b2adb8c75ee224c3 Mon Sep 17 00:00:00 2001 From: minqiyang Date: Fri, 10 Aug 2018 22:57:06 +0800 Subject: [PATCH 36/94] Fix parallel_executor_fetch_feed issue --- python/paddle/dataset/flowers.py | 4 ++-- python/paddle/reader/tests/creator_test.py | 3 ++- 2 files changed, 4 insertions(+), 3 deletions(-) diff --git a/python/paddle/dataset/flowers.py b/python/paddle/dataset/flowers.py index 914dae348b..7d14cc5dc8 100644 --- a/python/paddle/dataset/flowers.py +++ b/python/paddle/dataset/flowers.py @@ -116,8 +116,8 @@ def reader_creator(data_file, for file in open(file_list): file = file.strip() batch = None - with open(file, 'r') as f: - batch = pickle.load(f) + with open(file, 'rb') as f: + batch = pickle.loads(f.read()) data = batch['data'] labels = batch['label'] for sample, label in zip(data, batch['label']): diff --git a/python/paddle/reader/tests/creator_test.py b/python/paddle/reader/tests/creator_test.py index c4238c12a7..567f38c96e 100644 --- a/python/paddle/reader/tests/creator_test.py +++ b/python/paddle/reader/tests/creator_test.py @@ -29,6 +29,7 @@ import os import unittest import numpy as np import paddle.reader.creator 
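Opening the file in `'rb'` is the real fix in flowers.py: pickled data is bytes, and Python 3 refuses to unpickle from a text-mode file. When the archive was written by Python 2, Python 3 additionally needs an `encoding` argument, as the cifar reader in this same series does. A portable sketch:

    import pickle
    import six

    def load_batch(path):
        with open(path, 'rb') as f:  # binary mode: pickle payloads are bytes
            if six.PY2:
                return pickle.load(f)
            # 'bytes' lets Python 3 read str-keyed pickles written by Python 2.
            return pickle.load(f, encoding='bytes')
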
+import six class TestNumpyArray(unittest.TestCase): @@ -37,7 +38,7 @@ class TestNumpyArray(unittest.TestCase): x = np.array(l, np.int32) reader = paddle.reader.creator.np_array(x) for idx, e in enumerate(reader()): - self.assertItemsEqual(e, l[idx]) + six.assertCountEqual(e, l[idx]) class TestTextFile(unittest.TestCase): From 0bfd62be3d0bb194fc23500828cc1327f8b9dae6 Mon Sep 17 00:00:00 2001 From: tangwei12 Date: Sun, 12 Aug 2018 17:02:40 +0800 Subject: [PATCH 37/94] remove gpu supported, will add it later --- paddle/fluid/operators/sampling_id_op.cu | 25 ------------------------ 1 file changed, 25 deletions(-) delete mode 100644 paddle/fluid/operators/sampling_id_op.cu diff --git a/paddle/fluid/operators/sampling_id_op.cu b/paddle/fluid/operators/sampling_id_op.cu deleted file mode 100644 index c0bb9c916c..0000000000 --- a/paddle/fluid/operators/sampling_id_op.cu +++ /dev/null @@ -1,25 +0,0 @@ -/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. */ - -#include -#include -#include "paddle/fluid/operators/sampling_id_op.h" - -namespace ops = paddle::operators; -REGISTER_OP_CUDA_KERNEL( - sampling_id, - ops::SamplingIdKernel, - ops::SamplingIdKernel, - ops::SamplingIdKernel, - ops::SamplingIdKernel); From 5377edd282bf4998d675d5551bb5b4e420fe4122 Mon Sep 17 00:00:00 2001 From: tensor-tang Date: Mon, 13 Aug 2018 11:35:11 +0800 Subject: [PATCH 38/94] refine packed condition --- paddle/fluid/operators/gru_op.cc | 135 ++++++++++++++++++------------- paddle/fluid/operators/gru_op.h | 3 - 2 files changed, 79 insertions(+), 59 deletions(-) diff --git a/paddle/fluid/operators/gru_op.cc b/paddle/fluid/operators/gru_op.cc index 4847eb3626..2b5094925c 100644 --- a/paddle/fluid/operators/gru_op.cc +++ b/paddle/fluid/operators/gru_op.cc @@ -14,6 +14,11 @@ limitations under the License. 
*/ #include "paddle/fluid/operators/gru_op.h" #include +#include "paddle/fluid/operators/math/blas.h" +#include "paddle/fluid/operators/math/detail/gru_cpu_kernel.h" +#include "paddle/fluid/operators/math/detail/gru_kernel.h" + +DECLARE_int32(paddle_num_threads); namespace paddle { namespace operators { @@ -264,76 +269,94 @@ class GRUCPUKernel : public framework::OpKernel { gru_value.prev_out_value = nullptr; } auto batch_starts = batch_gate->lod()[0]; - size_t num_batch = batch_starts.size() - 1; + size_t seq_len = batch_starts.size() - 1; auto active_node = math::detail::GetActivationType( context.Attr("activation")); auto active_gate = math::detail::GetActivationType( context.Attr("gate_activation")); #ifdef PADDLE_WITH_MKLML - auto blas = math::GetBlas(dev_ctx); - // TODO(TJ): make a class - T* packed_gate = blas.GEMM_ALLOC(CblasBMatrix, 1 /*height of C*/, - frame_size * 2 /*width of weight*/, - frame_size /*height of height*/); - PADDLE_ENFORCE(packed_gate); - blas.GEMM_PACK(CblasBMatrix, CblasNoTrans, 1 /*cur bs?*/, frame_size * 2, - frame_size, T(1.0), gru_value.gate_weight, frame_size * 2, - packed_gate); - T* packed_state = blas.GEMM_ALLOC(CblasBMatrix, 1 /*height of C*/, - frame_size /*width of weight*/, - frame_size /*height of height*/); - PADDLE_ENFORCE(packed_state); - blas.GEMM_PACK(CblasBMatrix, CblasNoTrans, 1 /*cur bs?*/, frame_size, - frame_size, T(1.0), gru_value.state_weight, frame_size, - packed_state); -#endif - for (size_t n = 0; n < num_batch; n++) { - int bstart = static_cast(batch_starts[n]); - int bend = static_cast(batch_starts[n + 1]); - int cur_batch_size = bend - bstart; - - Tensor gate_t = batch_gate->Slice(bstart, bend); - Tensor reset_hidden_prev_t = batch_reset_hidden_prev->Slice(bstart, bend); - Tensor hidden_t = batch_hidden->Slice(bstart, bend); - gru_value.output_value = hidden_t.data(); - gru_value.gate_value = gate_t.data(); - gru_value.reset_output_value = reset_hidden_prev_t.data(); + if (FLAGS_paddle_num_threads >= 4) { + auto blas = math::GetBlas(dev_ctx); + T* packed_gate = blas.GEMM_ALLOC(CblasBMatrix, 1 /*height of C*/, + frame_size * 2 /*width of weight*/, + frame_size /*height of height*/); + PADDLE_ENFORCE(packed_gate); + blas.GEMM_PACK(CblasBMatrix, CblasNoTrans, 1 /*cur bs?*/, frame_size * 2, + frame_size, T(1.0), gru_value.gate_weight, frame_size * 2, + packed_gate); + T* packed_state = blas.GEMM_ALLOC(CblasBMatrix, 1 /*height of C*/, + frame_size /*width of weight*/, + frame_size /*height of height*/); + PADDLE_ENFORCE(packed_state); + blas.GEMM_PACK(CblasBMatrix, CblasNoTrans, 1 /*cur bs?*/, frame_size, + frame_size, T(1.0), gru_value.state_weight, frame_size, + packed_state); + for (size_t n = 0; n < seq_len; n++) { + int bstart = static_cast(batch_starts[n]); + int bend = static_cast(batch_starts[n + 1]); + int cur_batch_size = bend - bstart; -#ifdef PADDLE_WITH_MKLML - if (gru_value.prev_out_value) { - blas.GEMM_COMPUTE(CblasNoTrans, CblasPacked, cur_batch_size, - frame_size * 2, frame_size, gru_value.prev_out_value, - frame_size, packed_gate, frame_size * 2, T(1), - gru_value.gate_value, frame_size * 3); - } + Tensor gate_t = batch_gate->Slice(bstart, bend); + Tensor reset_hidden_prev_t = + batch_reset_hidden_prev->Slice(bstart, bend); + Tensor hidden_t = batch_hidden->Slice(bstart, bend); + gru_value.output_value = hidden_t.data(); + gru_value.gate_value = gate_t.data(); + gru_value.reset_output_value = reset_hidden_prev_t.data(); - math::detail::forward_reset_output( - math::detail::forward::gru_resetOutput(), gru_value, 
frame_size, - cur_batch_size, active_gate); + if (gru_value.prev_out_value) { + blas.GEMM_COMPUTE( + CblasNoTrans, CblasPacked, cur_batch_size, frame_size * 2, + frame_size, gru_value.prev_out_value, frame_size, packed_gate, + frame_size * 2, T(1), gru_value.gate_value, frame_size * 3); + } - if (gru_value.prev_out_value) { - blas.GEMM_COMPUTE( - CblasNoTrans, CblasPacked, cur_batch_size, frame_size, frame_size, - gru_value.reset_output_value, frame_size, packed_state, frame_size, - T(1), gru_value.gate_value + frame_size * 2, frame_size * 3); + math::detail::forward_reset_output( + math::detail::forward::gru_resetOutput(), gru_value, frame_size, + cur_batch_size, active_gate); + + if (gru_value.prev_out_value) { + blas.GEMM_COMPUTE( + CblasNoTrans, CblasPacked, cur_batch_size, frame_size, frame_size, + gru_value.reset_output_value, frame_size, packed_state, + frame_size, T(1), gru_value.gate_value + frame_size * 2, + frame_size * 3); + } + + math::detail::forward_final_output( + math::detail::forward::gru_finalOutput(), gru_value, frame_size, + cur_batch_size, active_node); + + gru_value.prev_out_value = gru_value.output_value; } - math::detail::forward_final_output( - math::detail::forward::gru_finalOutput(), gru_value, frame_size, - cur_batch_size, active_node); -#else - math::GRUUnitFunctor::compute( - dev_ctx, gru_value, frame_size, cur_batch_size, active_node, - active_gate); + blas.GEMM_FREE(packed_gate); + blas.GEMM_FREE(packed_state); + } else { #endif - gru_value.prev_out_value = gru_value.output_value; - } + for (size_t n = 0; n < seq_len; n++) { + int bstart = static_cast(batch_starts[n]); + int bend = static_cast(batch_starts[n + 1]); + int cur_batch_size = bend - bstart; + + Tensor gate_t = batch_gate->Slice(bstart, bend); + Tensor reset_hidden_prev_t = + batch_reset_hidden_prev->Slice(bstart, bend); + Tensor hidden_t = batch_hidden->Slice(bstart, bend); + gru_value.output_value = hidden_t.data(); + gru_value.gate_value = gate_t.data(); + gru_value.reset_output_value = reset_hidden_prev_t.data(); + + math::GRUUnitFunctor::compute( + dev_ctx, gru_value, frame_size, cur_batch_size, active_node, + active_gate); + + gru_value.prev_out_value = gru_value.output_value; + } #ifdef PADDLE_WITH_MKLML - blas.GEMM_FREE(packed_gate); - blas.GEMM_FREE(packed_state); + } #endif - math::Batch2LoDTensorFunctor to_seq; batch_hidden->set_lod(batch_gate->lod()); to_seq(dev_ctx, *batch_hidden, hidden); diff --git a/paddle/fluid/operators/gru_op.h b/paddle/fluid/operators/gru_op.h index 0bf4e6bc44..0b551e8046 100644 --- a/paddle/fluid/operators/gru_op.h +++ b/paddle/fluid/operators/gru_op.h @@ -16,10 +16,7 @@ limitations under the License. 
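The restructuring above leaves two explicit loop bodies: a packed-GEMM path taken when MKL is present and `FLAGS_paddle_num_threads >= 4` (the gate and state weights are packed once, then reused by every timestep batch), and the original `GRUUnitFunctor` fallback. The win comes from amortizing the pack across the sequence. Conceptually, with numpy standing in for the cblas_*gemm_pack/compute calls (a sketch of the idea, not the MKL API):

    import numpy as np

    def gru_gate_matmuls(prev_out_batches, gate_weight):
        # "Packing" reorders the weight into a BLAS-friendly layout once;
        # a contiguous copy stands in for cblas_sgemm_pack here.
        packed = np.ascontiguousarray(gate_weight)
        outs = []
        for prev_out in prev_out_batches:    # one GEMM per timestep batch
            outs.append(prev_out @ packed)   # reuses the packed operand
        return outs
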
*/ #include #include "paddle/fluid/framework/eigen.h" #include "paddle/fluid/framework/op_registry.h" -#include "paddle/fluid/operators/math/blas.h" #include "paddle/fluid/operators/math/detail/activation_functions.h" -#include "paddle/fluid/operators/math/detail/gru_cpu_kernel.h" -#include "paddle/fluid/operators/math/detail/gru_kernel.h" #include "paddle/fluid/operators/math/gru_compute.h" #include "paddle/fluid/operators/math/math_function.h" #include "paddle/fluid/operators/math/sequence2batch.h" From e7c7cbaa42e01aa09df156cdfa05c906f00ca20e Mon Sep 17 00:00:00 2001 From: minqiyang Date: Mon, 13 Aug 2018 14:39:19 +0800 Subject: [PATCH 39/94] Port new added files to Python3 --- .../paddle/fluid/contrib/memory_usage_calc.py | 20 ++++++++++--------- .../fluid/tests/unittests/dist_transformer.py | 1 + 2 files changed, 12 insertions(+), 9 deletions(-) diff --git a/python/paddle/fluid/contrib/memory_usage_calc.py b/python/paddle/fluid/contrib/memory_usage_calc.py index 5da846edb6..f0316a70ec 100644 --- a/python/paddle/fluid/contrib/memory_usage_calc.py +++ b/python/paddle/fluid/contrib/memory_usage_calc.py @@ -14,12 +14,14 @@ """ This module provides a memory usage calculation function for users. The purpose of this API is to allow users to estimate memory usage of -a program under a specific batch size, then users can set an appropriate -batch size to fully utilize a GPU. +a program under a specific batch size, then users can set an appropriate +batch size to fully utilize a GPU. This API is still under active development and may change drastically. """ +import six + from .. import core from ..framework import Program, Variable @@ -45,15 +47,15 @@ def memory_usage(program, batch_size): Args: program(Program): The current Program. - batch_size(int): The current input data batch_size. - + batch_size(int): The current input data batch_size. + Returns: min_total_memory(float): the estimated memory usage lower bound. max_total_memory(float): the estimated memory usage upper bound. unit_str(string): the unit of the estimated usage result.
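The estimator in this module walks every variable of the global block and accumulates `element_count * sizeof(dtype)`, substituting the requested batch size for the unspecified (-1) dimension. The arithmetic in outline (a standalone sketch; the real `dtype_to_size` table lives in the module):

    dtype_to_size = {'float32': 4, 'float16': 2, 'int64': 8}

    def estimate_bytes(var_shapes, batch_size):
        total = 0
        for shape, dtype in var_shapes:
            count = 1
            for dim in shape:
                count *= batch_size if dim == -1 else dim
            total += count * dtype_to_size[dtype]
        return total

    # A [-1, 784] float32 input at batch size 10 costs 10*784*4 = 31360 bytes.
    assert estimate_bytes([((-1, 784), 'float32')], 10) == 31360
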
- + Examples: - + >>> import paddle.fluid as fluid >>> lower_usage, upper_usage, unit = fluid.contrib.memory_usage( fluid.default_main_program(), batch_size=10) @@ -72,7 +74,7 @@ def memory_usage(program, batch_size): # Get the var_name list of first block and calculate total_memory = 0.0 - for var in program.global_block().vars.itervalues(): + for var in six.itervalues(program.global_block().vars): data_count = 1 for x in var.shape: if x == -1: @@ -81,10 +83,10 @@ def memory_usage(program, batch_size): data_count *= x var_memory = data_count * dtype_to_size[var.dtype] if DEBUG: - print "%s memory usage: %d" % (var.name, var_memory) + print("%s memory usage: %d" % (var.name, var_memory)) total_memory += var_memory if DEBUG: - print "total memory usage: %.2f" % (total_memory) + print("total memory usage: %.2f" % (total_memory)) # Convert appropriate unit unit_str = "B" diff --git a/python/paddle/fluid/tests/unittests/dist_transformer.py b/python/paddle/fluid/tests/unittests/dist_transformer.py index 6bd4ecbbe1..41125d38bd 100644 --- a/python/paddle/fluid/tests/unittests/dist_transformer.py +++ b/python/paddle/fluid/tests/unittests/dist_transformer.py @@ -160,6 +160,7 @@ def get_model(): avg_cost = transformer(use_feed=False) optimizer = fluid.optimizer.Adam() optimizer.minimize(avg_cost) + fluid.memory_optimize(fluid.default_main_program()) return avg_cost From a3539845f2e778e4c989b16b05c60dfd1439dfc2 Mon Sep 17 00:00:00 2001 From: minqiyang Date: Mon, 13 Aug 2018 15:46:12 +0800 Subject: [PATCH 40/94] Polish python code style --- python/paddle/dataset/cifar.py | 10 +++++--- python/paddle/dataset/movielens.py | 3 ++- python/paddle/fluid/compat.py | 7 +++--- python/paddle/fluid/executor.py | 5 ++-- python/paddle/fluid/layers/detection.py | 10 ++++---- .../cifar10_small_test_set.py | 8 ++++--- .../fluid/tests/unittests/test_compat.py | 17 +++++++------ .../fluid/tests/unittests/test_gru_op.py | 3 ++- .../tests/unittests/test_operator_desc.py | 5 ++-- .../fluid/tests/unittests/test_pool3d_op.py | 24 +++++++++---------- .../memory_optimization_transpiler.py | 6 ++--- 11 files changed, 55 insertions(+), 43 deletions(-) diff --git a/python/paddle/dataset/cifar.py b/python/paddle/dataset/cifar.py index 0d07462e68..cfe9deeab0 100644 --- a/python/paddle/dataset/cifar.py +++ b/python/paddle/dataset/cifar.py @@ -47,21 +47,25 @@ CIFAR100_MD5 = 'eb9058c3a382ffc7106e4002c42a8d85' def reader_creator(filename, sub_name, cycle=False): def read_batch(batch): data = batch[six.b('data')] - labels = batch.get(six.b('labels'), batch.get(six.b('fine_labels'), None)) + labels = batch.get( + six.b('labels'), batch.get(six.b('fine_labels'), None)) assert labels is not None for sample, label in six.moves.zip(data, labels): yield (sample / 255.0).astype(numpy.float32), int(label) def reader(): with tarfile.open(filename, mode='r') as f: - names = [each_item.name for each_item in f if sub_name in each_item.name] + names = [ + each_item.name for each_item in f if sub_name in each_item.name + ] while True: for name in names: if six.PY2: batch = pickle.load(f.extractfile(name)) else: - batch = pickle.load(f.extractfile(name), encoding='bytes') + batch = pickle.load( + f.extractfile(name), encoding='bytes') for item in read_batch(batch): yield item if not cycle: diff --git a/python/paddle/dataset/movielens.py b/python/paddle/dataset/movielens.py index 354b7d4aee..137d6ca8d0 100644 --- a/python/paddle/dataset/movielens.py +++ b/python/paddle/dataset/movielens.py @@ -215,7 +215,8 @@ def max_job_id(): Get the maximum value of 
job id. """ __initialize_meta_info__() - return six.moves.reduce(__max_job_id_impl__, list(USER_INFO.values())).job_id + return six.moves.reduce(__max_job_id_impl__, + list(USER_INFO.values())).job_id def movie_categories(): diff --git a/python/paddle/fluid/compat.py b/python/paddle/fluid/compat.py index f0ea9d4aac..6d3e7794de 100644 --- a/python/paddle/fluid/compat.py +++ b/python/paddle/fluid/compat.py @@ -23,6 +23,7 @@ __all__ = [ 'get_exception_message', ] + # str and bytes related functions def to_literal_str(obj, encoding='utf-8', inplace=False): """ @@ -181,10 +182,10 @@ def round(x, d=0): # The official workaround of round in Python3 is incorrect # we implement according to this answer: https://www.techforgeek.info/round_python.html if x > 0.0: - p = 10 ** d + p = 10**d return float(math.floor((x * p) + math.copysign(0.5, x))) / p elif x < 0.0: - p = 10 ** d + p = 10**d return float(math.ceil((x * p) + math.copysign(0.5, x))) / p else: return math.copysign(0.0, x) @@ -208,6 +209,7 @@ def floor_division(x, y): """ return x // y + # exception related functions def get_exception_message(exc): """ @@ -225,4 +227,3 @@ def get_exception_message(exc): return exc.message else: return str(exc) - diff --git a/python/paddle/fluid/executor.py b/python/paddle/fluid/executor.py index a0cc7fac34..d840e41476 100644 --- a/python/paddle/fluid/executor.py +++ b/python/paddle/fluid/executor.py @@ -320,8 +320,9 @@ class Executor(object): # append fetch_operators if not has_fetch_operators(global_block, fetch_list, fetch_var_name): for i, var in enumerate(fetch_list): - assert isinstance(var, Variable) or isinstance(var, six.text_type), ( - "Wrong type for fetch_list[%s]: %s" % (i, type(var))) + assert isinstance(var, Variable) or isinstance( + var, six.text_type), ("Wrong type for fetch_list[%s]: %s" % + (i, type(var))) global_block.append_op( type='fetch', inputs={'X': [var]}, diff --git a/python/paddle/fluid/layers/detection.py b/python/paddle/fluid/layers/detection.py index e21a7a3ddd..0de66d1144 100644 --- a/python/paddle/fluid/layers/detection.py +++ b/python/paddle/fluid/layers/detection.py @@ -1104,9 +1104,8 @@ def multi_box_head(inputs, mbox_loc = nn.transpose(mbox_loc, perm=[0, 2, 3, 1]) new_shape = [ - mbox_loc.shape[0], - mbox_loc.shape[1] * mbox_loc.shape[2] * cpt.floor_division(mbox_loc.shape[3], 4), - 4 + mbox_loc.shape[0], mbox_loc.shape[1] * mbox_loc.shape[2] * + cpt.floor_division(mbox_loc.shape[3], 4), 4 ] mbox_loc_flatten = nn.reshape(mbox_loc, shape=new_shape) mbox_locs.append(mbox_loc_flatten) @@ -1121,9 +1120,8 @@ def multi_box_head(inputs, conf_loc = nn.transpose(conf_loc, perm=[0, 2, 3, 1]) new_shape = [ - conf_loc.shape[0], - conf_loc.shape[1] * conf_loc.shape[2] * cpt.floor_division(conf_loc.shape[3], num_classes), - num_classes + conf_loc.shape[0], conf_loc.shape[1] * conf_loc.shape[2] * + cpt.floor_division(conf_loc.shape[3], num_classes), num_classes ] conf_loc_flatten = nn.reshape(conf_loc, shape=new_shape) mbox_confs.append(conf_loc_flatten) diff --git a/python/paddle/fluid/tests/book/high-level-api/image_classification/cifar10_small_test_set.py b/python/paddle/fluid/tests/book/high-level-api/image_classification/cifar10_small_test_set.py index 9afac4143e..c03e73542a 100644 --- a/python/paddle/fluid/tests/book/high-level-api/image_classification/cifar10_small_test_set.py +++ b/python/paddle/fluid/tests/book/high-level-api/image_classification/cifar10_small_test_set.py @@ -45,15 +45,17 @@ CIFAR10_MD5 = 'c58f30108f718f92721af3b95e74349a' def
reader_creator(filename, sub_name, batch_size=None): def read_batch(batch): data = batch[six.b('data')] - labels = batch.get(six.b('labels'), batch.get(six.b('fine_labels'), None)) + labels = batch.get( + six.b('labels'), batch.get(six.b('fine_labels'), None)) assert labels is not None for sample, label in six.moves.zip(data, labels): yield (sample / 255.0).astype(numpy.float32), int(label) def reader(): with tarfile.open(filename, mode='r') as f: - names = [each_item.name for each_item in f - if sub_name in each_item.name] + names = [ + each_item.name for each_item in f if sub_name in each_item.name + ] batch_count = 0 for name in names: diff --git a/python/paddle/fluid/tests/unittests/test_compat.py b/python/paddle/fluid/tests/unittests/test_compat.py index 20e93515de..00216d33e1 100644 --- a/python/paddle/fluid/tests/unittests/test_compat.py +++ b/python/paddle/fluid/tests/unittests/test_compat.py @@ -63,7 +63,6 @@ class TestCompatible(unittest.TestCase): for i in l2: self.assertTrue(isinstance(i, unicode)) - # check list types, inplace l = [""] l2 = cpt.to_literal_str(l, inplace=True) @@ -272,7 +271,6 @@ class TestCompatible(unittest.TestCase): for i in l2: self.assertTrue(isinstance(i, bytes)) - # check list types, inplace l = [""] l2 = cpt.to_bytes(l, inplace=True) @@ -461,30 +459,35 @@ class TestCompatible(unittest.TestCase): exception_message = "test_message" self.assertRaises(AssertionError, cpt.get_exception_message, None) if six.PY2: - self.assertRaises(AttributeError, cpt.get_exception_message, exception_message) + self.assertRaises(AttributeError, cpt.get_exception_message, + exception_message) try: raise RuntimeError(exception_message) except Exception as e: - self.assertEqual(exception_message, cpt.get_exception_message(e)) + self.assertEqual(exception_message, + cpt.get_exception_message(e)) self.assertIsNotNone(e) try: raise Exception(exception_message) except Exception as e: - self.assertEqual(exception_message, cpt.get_exception_message(e)) + self.assertEqual(exception_message, + cpt.get_exception_message(e)) self.assertIsNotNone(e) if six.PY3: try: raise RuntimeError(exception_message) except Exception as e: - self.assertEqual(exception_message, cpt.get_exception_message(e)) + self.assertEqual(exception_message, + cpt.get_exception_message(e)) self.assertIsNotNone(e) try: raise Exception(exception_message) except Exception as e: - self.assertEqual(exception_message, cpt.get_exception_message(e)) + self.assertEqual(exception_message, + cpt.get_exception_message(e)) self.assertIsNotNone(e) diff --git a/python/paddle/fluid/tests/unittests/test_gru_op.py b/python/paddle/fluid/tests/unittests/test_gru_op.py index 4bbec06a91..1d8db37fe7 100644 --- a/python/paddle/fluid/tests/unittests/test_gru_op.py +++ b/python/paddle/fluid/tests/unittests/test_gru_op.py @@ -39,7 +39,8 @@ class TestGRUOp(OpTest): for i in range(len(seq_lens)): seq_starts.append(seq_starts[-1] + seq_lens[i]) sorted_seqs = sorted( - list(range(len(seq_lens))), key=functools.cmp_to_key(lambda x, y: seq_lens[y] - seq_lens[x])) + list(range(len(seq_lens))), + key=functools.cmp_to_key(lambda x, y: seq_lens[y] - seq_lens[x])) num_batch = seq_lens[sorted_seqs[0]] for batch_idx in range(num_batch): idx_in_seq = [] diff --git a/python/paddle/fluid/tests/unittests/test_operator_desc.py b/python/paddle/fluid/tests/unittests/test_operator_desc.py index cbdbd957a5..5634e29d01 100644 --- a/python/paddle/fluid/tests/unittests/test_operator_desc.py +++ b/python/paddle/fluid/tests/unittests/test_operator_desc.py @@ -36,8 
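`cpt.round`, exercised by the compat tests above, exists because Python 3 switched the builtin to banker's rounding: `round(2.5)` is `2` on Python 3 but `3.0` on Python 2. The compat version keeps the old half-away-from-zero behaviour with `math.copysign`, essentially:

    import math

    def round_half_away(x, d=0):
        # Python 2 semantics: halves always move away from zero.
        p = 10 ** d
        if x > 0.0:
            return float(math.floor((x * p) + math.copysign(0.5, x))) / p
        elif x < 0.0:
            return float(math.ceil((x * p) + math.copysign(0.5, x))) / p
        return math.copysign(0.0, x)

    assert round_half_away(2.5) == 3.0 and round_half_away(-2.5) == -3.0
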
+36,9 @@ class TestOperator(unittest.TestCase): block.append_op(type="no_such_op") self.assertFail() except ValueError as a_err: - self.assertEqual(cpt.get_exception_message(a_err), - "Operator \"no_such_op\" has not been registered.") + self.assertEqual( + cpt.get_exception_message(a_err), + "Operator \"no_such_op\" has not been registered.") def test_op_desc_creation(self): program = Program() diff --git a/python/paddle/fluid/tests/unittests/test_pool3d_op.py b/python/paddle/fluid/tests/unittests/test_pool3d_op.py index a358c84991..8b96a0e22a 100644 --- a/python/paddle/fluid/tests/unittests/test_pool3d_op.py +++ b/python/paddle/fluid/tests/unittests/test_pool3d_op.py @@ -29,14 +29,14 @@ def max_pool3D_forward_naive(x, if global_pool == 1: ksize = [D, H, W] D_out = (D - ksize[0] + 2 * paddings[0] + strides[0] - 1 - ) // strides[0] + 1 if ceil_mode else (H - ksize[0] + 2 * - paddings[0]) // strides[0] + 1 + ) // strides[0] + 1 if ceil_mode else ( + H - ksize[0] + 2 * paddings[0]) // strides[0] + 1 H_out = (H - ksize[1] + 2 * paddings[1] + strides[1] - 1 - ) // strides[1] + 1 if ceil_mode else (W - ksize[1] + 2 * - paddings[1]) // strides[1] + 1 + ) // strides[1] + 1 if ceil_mode else ( + W - ksize[1] + 2 * paddings[1]) // strides[1] + 1 W_out = (W - ksize[2] + 2 * paddings[2] + strides[2] - 1 - ) // strides[2] + 1 if ceil_mode else (W - ksize[2] + 2 * - paddings[2]) // strides[2] + 1 + ) // strides[2] + 1 if ceil_mode else ( + W - ksize[2] + 2 * paddings[2]) // strides[2] + 1 out = np.zeros((N, C, D_out, H_out, W_out)) for k in range(D_out): d_start = np.max((k * strides[0] - paddings[0], 0)) @@ -63,14 +63,14 @@ def avg_pool3D_forward_naive(x, if global_pool == 1: ksize = [D, H, W] D_out = (D - ksize[0] + 2 * paddings[0] + strides[0] - 1 - ) // strides[0] + 1 if ceil_mode else (H - ksize[0] + 2 * - paddings[0]) // strides[0] + 1 + ) // strides[0] + 1 if ceil_mode else ( + H - ksize[0] + 2 * paddings[0]) // strides[0] + 1 H_out = (H - ksize[1] + 2 * paddings[1] + strides[1] - 1 - ) // strides[1] + 1 if ceil_mode else (W - ksize[1] + 2 * - paddings[1]) // strides[1] + 1 + ) // strides[1] + 1 if ceil_mode else ( + W - ksize[1] + 2 * paddings[1]) // strides[1] + 1 W_out = (W - ksize[2] + 2 * paddings[2] + strides[2] - 1 - ) // strides[2] + 1 if ceil_mode else (W - ksize[2] + 2 * - paddings[2]) // strides[2] + 1 + ) // strides[2] + 1 if ceil_mode else ( + W - ksize[2] + 2 * paddings[2]) // strides[2] + 1 out = np.zeros((N, C, D_out, H_out, W_out)) for k in range(D_out): d_start = np.max((k * strides[0] - paddings[0], 0)) diff --git a/python/paddle/fluid/transpiler/memory_optimization_transpiler.py b/python/paddle/fluid/transpiler/memory_optimization_transpiler.py index 2732030b95..57c72f465b 100644 --- a/python/paddle/fluid/transpiler/memory_optimization_transpiler.py +++ b/python/paddle/fluid/transpiler/memory_optimization_transpiler.py @@ -259,9 +259,9 @@ class ControlFlowGraph(object): # Rename the var to the cache var already with # memory allocated in order to reuse the memory. 
_rename_arg_(self._ops, x, cache_var, begin_idx=i) - self._program.block(block_desc.id).var(cpt.to_literal_str( - x)).desc = self._find_var(block_desc, cache_var, - is_forward) + self._program.block(block_desc.id).var( + cpt.to_literal_str(x)).desc = self._find_var( + block_desc, cache_var, is_forward) self._update_graph(x, cache_var, begin_idx=i) break From 47561c34b0fdd27df852dbf2c09230bf3bee412b Mon Sep 17 00:00:00 2001 From: minqiyang Date: Mon, 13 Aug 2018 15:59:33 +0800 Subject: [PATCH 41/94] Fix python2 CI issues --- python/paddle/fluid/executor.py | 5 ++--- python/paddle/fluid/tests/unittests/test_lod_rank_table.py | 2 +- python/paddle/reader/tests/creator_test.py | 2 +- 3 files changed, 4 insertions(+), 5 deletions(-) diff --git a/python/paddle/fluid/executor.py b/python/paddle/fluid/executor.py index d840e41476..8437a9f20f 100644 --- a/python/paddle/fluid/executor.py +++ b/python/paddle/fluid/executor.py @@ -320,9 +320,8 @@ class Executor(object): # append fetch_operators if not has_fetch_operators(global_block, fetch_list, fetch_var_name): for i, var in enumerate(fetch_list): - assert isinstance(var, Variable) or isinstance( - var, six.text_type), ("Wrong type for fetch_list[%s]: %s" % - (i, type(var))) + assert isinstance(var, Variable) or isinstance(var, str), ( + "Wrong type for fetch_list[%s]: %s" % (i, type(var))) global_block.append_op( type='fetch', inputs={'X': [var]}, diff --git a/python/paddle/fluid/tests/unittests/test_lod_rank_table.py b/python/paddle/fluid/tests/unittests/test_lod_rank_table.py index ea57412660..cae8f3fb81 100644 --- a/python/paddle/fluid/tests/unittests/test_lod_rank_table.py +++ b/python/paddle/fluid/tests/unittests/test_lod_rank_table.py @@ -37,7 +37,7 @@ class TestLoDRankTable(unittest.TestCase): exe.run(scope=scope, feed={'x': tensor}) var = scope.find_var(rank_table.name) table = var.get_lod_rank_table() - self.assertEqual([(0, 5), (1, 1), (2, 1)], six.iteritems(table)) + self.assertEqual([(0, 5), (1, 1), (2, 1)], list(six.iteritems(table))) if __name__ == '__main__': diff --git a/python/paddle/reader/tests/creator_test.py b/python/paddle/reader/tests/creator_test.py index 567f38c96e..d7107610a5 100644 --- a/python/paddle/reader/tests/creator_test.py +++ b/python/paddle/reader/tests/creator_test.py @@ -38,7 +38,7 @@ class TestNumpyArray(unittest.TestCase): x = np.array(l, np.int32) reader = paddle.reader.creator.np_array(x) for idx, e in enumerate(reader()): - six.assertCountEqual(e, l[idx]) + six.assertCountEqual(self, e, l[idx]) class TestTextFile(unittest.TestCase): From 5b452bfd59461a6f9282b17ca9762f39a2b3eb17 Mon Sep 17 00:00:00 2001 From: minqiyang Date: Mon, 13 Aug 2018 16:14:46 +0800 Subject: [PATCH 42/94] Remove the overfix of lod_rank_table --- python/paddle/fluid/tests/unittests/test_lod_rank_table.py | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/python/paddle/fluid/tests/unittests/test_lod_rank_table.py b/python/paddle/fluid/tests/unittests/test_lod_rank_table.py index cae8f3fb81..d53ead381d 100644 --- a/python/paddle/fluid/tests/unittests/test_lod_rank_table.py +++ b/python/paddle/fluid/tests/unittests/test_lod_rank_table.py @@ -18,7 +18,6 @@ from paddle.fluid.executor import Executor import paddle.fluid.core as core import numpy import unittest -import six class TestLoDRankTable(unittest.TestCase): @@ -37,7 +36,7 @@ class TestLoDRankTable(unittest.TestCase): exe.run(scope=scope, feed={'x': tensor}) var = scope.find_var(rank_table.name) table = var.get_lod_rank_table() - self.assertEqual([(0, 5), (1, 1), (2, 
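The creator_test fix above is easy to misread: six exposes `assertCountEqual` as a module-level function taking the TestCase as its first argument, because the method itself is named `assertItemsEqual` on Python 2 and `assertCountEqual` on Python 3. Usage:

    import unittest

    import six

    class TestReader(unittest.TestCase):
        def test_elements(self):
            # Dispatches to assertItemsEqual (Py2) / assertCountEqual (Py3).
            six.assertCountEqual(self, [1, 2, 3], [3, 2, 1])

    if __name__ == '__main__':
        unittest.main()
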
1)], list(six.iteritems(table))) + self.assertEqual([(0, 5), (1, 1), (2, 1)], list(table.items())) if __name__ == '__main__': From 171a0e2b42e1dea669056bbc6093e572e1c88e0a Mon Sep 17 00:00:00 2001 From: tensor-tang Date: Mon, 13 Aug 2018 18:01:43 +0800 Subject: [PATCH 43/94] add some comment --- paddle/fluid/operators/gru_op.cc | 1 + 1 file changed, 1 insertion(+) diff --git a/paddle/fluid/operators/gru_op.cc b/paddle/fluid/operators/gru_op.cc index 2b5094925c..087f903a8b 100644 --- a/paddle/fluid/operators/gru_op.cc +++ b/paddle/fluid/operators/gru_op.cc @@ -276,6 +276,7 @@ class GRUCPUKernel : public framework::OpKernel { context.Attr("gate_activation")); #ifdef PADDLE_WITH_MKLML + // use MKL packed to speedup GEMM if (FLAGS_paddle_num_threads >= 4) { auto blas = math::GetBlas(dev_ctx); T* packed_gate = blas.GEMM_ALLOC(CblasBMatrix, 1 /*height of C*/, From fae5c1f5140c44227b4f2e030c6f924fe902b2fc Mon Sep 17 00:00:00 2001 From: minqiyang Date: Mon, 13 Aug 2018 18:52:18 +0800 Subject: [PATCH 44/94] Fix the input of executor --- python/paddle/fluid/executor.py | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/python/paddle/fluid/executor.py b/python/paddle/fluid/executor.py index 8437a9f20f..a6b9440150 100644 --- a/python/paddle/fluid/executor.py +++ b/python/paddle/fluid/executor.py @@ -320,8 +320,9 @@ class Executor(object): # append fetch_operators if not has_fetch_operators(global_block, fetch_list, fetch_var_name): for i, var in enumerate(fetch_list): - assert isinstance(var, Variable) or isinstance(var, str), ( - "Wrong type for fetch_list[%s]: %s" % (i, type(var))) + assert isinstance(var, Variable) or isinstance( + var, six.string_types), ( + "Wrong type for fetch_list[%s]: %s" % (i, type(var))) global_block.append_op( type='fetch', inputs={'X': [var]}, From 6988aea989259829320769c610df5573358976a8 Mon Sep 17 00:00:00 2001 From: minqiyang Date: Mon, 13 Aug 2018 19:09:35 +0800 Subject: [PATCH 45/94] Fix long type in Python3 --- python/paddle/fluid/compat.py | 8 ++++++++ python/paddle/fluid/tests/unittests/test_compat.py | 8 ++++++++ .../paddle/fluid/tests/unittests/test_lookup_table_op.py | 3 ++- 3 files changed, 18 insertions(+), 1 deletion(-) diff --git a/python/paddle/fluid/compat.py b/python/paddle/fluid/compat.py index 6d3e7794de..62826c7ce9 100644 --- a/python/paddle/fluid/compat.py +++ b/python/paddle/fluid/compat.py @@ -16,6 +16,7 @@ import six import math __all__ = [ + 'long_type', 'to_literal_str', 'to_bytes', 'round', @@ -23,6 +24,13 @@ __all__ = [ 'get_exception_message', ] +if six.PY2: + int_type = int + long_type = long +else: + int_type = int + long_type = int + # str and bytes related functions def to_literal_str(obj, encoding='utf-8', inplace=False): diff --git a/python/paddle/fluid/tests/unittests/test_compat.py b/python/paddle/fluid/tests/unittests/test_compat.py index 00216d33e1..525789ddb6 100644 --- a/python/paddle/fluid/tests/unittests/test_compat.py +++ b/python/paddle/fluid/tests/unittests/test_compat.py @@ -18,6 +18,14 @@ import six class TestCompatible(unittest.TestCase): + def test_type(self): + if six.PY2: + self.assertEqual(cpt.int_type, int) + self.assertEqual(cpt.long_type, long) + else: + self.assertEqual(cpt.int_type, int) + self.assertEqual(cpt.long_type, int) + def test_to_literal_str(self): # Only support python2.x and python3.x now self.assertTrue(six.PY2 | six.PY3) diff --git a/python/paddle/fluid/tests/unittests/test_lookup_table_op.py b/python/paddle/fluid/tests/unittests/test_lookup_table_op.py index 
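The `long_type` alias introduced here is the stock recipe for the removed `long` builtin: Python 3 folded it into `int`, so call sites such as `cpt.long_type(padding_idx)` stay identical on both interpreters. Reduced to its core:

    import six

    if six.PY2:
        long_type = long  # noqa: F821 -- the builtin only exists on Python 2
    else:
        long_type = int

    padding_idx = long_type(42)  # long on Python 2, int on Python 3
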
ac25f432df..a325422c31 100644 --- a/python/paddle/fluid/tests/unittests/test_lookup_table_op.py +++ b/python/paddle/fluid/tests/unittests/test_lookup_table_op.py @@ -17,6 +17,7 @@ import numpy as np from op_test import OpTest import paddle.fluid.core as core from paddle.fluid.op import Operator +import paddle.fluid.compat as cpt class TestLookupTableOp(OpTest): @@ -71,7 +72,7 @@ class TestLookupTableOpWithTensorIdsAndPadding(TestLookupTableOpWithTensorIds): flatten_idx = ids.flatten() padding_idx = np.random.choice(flatten_idx, 1)[0] self.outputs['Out'][np.squeeze(ids == padding_idx)] = np.zeros(31) - self.attrs = {'padding_idx': long(padding_idx)} + self.attrs = {'padding_idx': cpt.long_type(padding_idx)} self.check_output() def test_check_grad(self): From 038cbf799d290e3e7cc129b59a2bea7b7e40055a Mon Sep 17 00:00:00 2001 From: tensor-tang Date: Mon, 13 Aug 2018 22:49:58 +0800 Subject: [PATCH 46/94] add bias for fc op --- paddle/fluid/operators/fc_op.cc | 72 ++++++++++++++++++++++++++++----- 1 file changed, 62 insertions(+), 10 deletions(-) diff --git a/paddle/fluid/operators/fc_op.cc b/paddle/fluid/operators/fc_op.cc index a9ae1396db..5fee30e146 100644 --- a/paddle/fluid/operators/fc_op.cc +++ b/paddle/fluid/operators/fc_op.cc @@ -30,21 +30,34 @@ void FCOp::InferShape(framework::InferShapeContext* ctx) const { auto w_dims = ctx->GetInputDim("W"); std::vector output_shape({in_dims[0], w_dims[1]}); + if (ctx->HasInput("Bias")) { + auto bias_dims = ctx->GetInputDim("Bias"); + PADDLE_ENFORCE_EQ(bias_dims[0], 1, "The shape of Bias must be [1, dim]."); + PADDLE_ENFORCE_EQ(bias_dims[1], framework::product(w_dims) / w_dims[0], + "The shape of Bias must be [1, dim]."); + } PADDLE_ENFORCE(in_dims.size() == 2 || in_dims.size() == 4, "Fully Connected input should be 2-D or 4-D tensor."); PADDLE_ENFORCE(w_dims.size() == 2 || w_dims.size() == 4, "Fully Connected input should be 2-D or 4-D tensor."); + PADDLE_ENFORCE_EQ(framework::product(w_dims) / w_dims[0], + framework::product(in_dims) / in_dims[0], + "Fully Connected input and weight size do not match."); + ctx->SetOutputDim("Out", framework::make_ddim(output_shape)); ctx->ShareLoD("Input", "Out"); } framework::OpKernelType FCOp::GetExpectedKernelType( const framework::ExecutionContext& ctx) const { - framework::LibraryType library{framework::LibraryType::kMKLDNN}; - framework::DataLayout layout{framework::DataLayout::kMKLDNN}; - + framework::LibraryType library = framework::LibraryType::kPlain; + framework::DataLayout layout = framework::DataLayout::kAnyLayout; + if (ctx.Attr("use_mkldnn")) { + library = framework::LibraryType::kMKLDNN; + layout = framework::DataLayout::kMKLDNN; + } return framework::OpKernelType( framework::ToDataType(ctx.Input("Input")->type()), ctx.GetPlace(), layout, library); @@ -60,13 +73,22 @@ void FCOpGrad::InferShape(framework::InferShapeContext* ctx) const { if (ctx->HasOutput(framework::GradVarName("W"))) { ctx->SetOutputDim(framework::GradVarName("W"), w_dims); } + + if (ctx->HasInput("Bias")) { + auto bias_dims = ctx->GetInputDim("Bias"); + PADDLE_ENFORCE(ctx->HasOutput(framework::GradVarName("Bias"))); + ctx->SetOutputDim(framework::GradVarName("Bias"), bias_dims); + } } framework::OpKernelType FCOpGrad::GetExpectedKernelType( const framework::ExecutionContext& ctx) const { - framework::LibraryType library{framework::LibraryType::kMKLDNN}; - framework::DataLayout layout{framework::DataLayout::kMKLDNN}; - + framework::LibraryType library = framework::LibraryType::kPlain; + framework::DataLayout layout =
framework::DataLayout::kAnyLayout; + if (ctx.Attr("use_mkldnn")) { + library = framework::LibraryType::kMKLDNN; + layout = framework::DataLayout::kMKLDNN; + } return framework::OpKernelType( framework::ToDataType(ctx.Input("Input")->type()), ctx.GetPlace(), layout, library); @@ -75,12 +97,12 @@ framework::OpKernelType FCOpGrad::GetExpectedKernelType( void FCOpMaker::Make() { AddInput("Input", "(Tensor) The input tensor of fully connected operator. "); AddInput("W", "(Tensor), The second input tensor of fc op."); + AddInput("Bias", "(Tensor, optional) Bias vector with shape (1 x D)") + .AsDispensable(); AddOutput("Out", "(Tensor) The output tensor of fully connected operator. "); AddAttr("use_mkldnn", "(bool, default false) Only used in mkldnn kernel") .SetDefault(false); - AddAttr("bias_attr", "(bool, default false) Only used in mkldnn kernel") - .SetDefault(false); AddComment(R"DOC( Fully Connected Operator. @@ -94,9 +116,39 @@ void FCOpMaker::Make() { )DOC"); } +template +class FCOpKernel : public framework::OpKernel { + public: + void Compute(const paddle::framework::ExecutionContext& ctx) const override { + PADDLE_ENFORCE(paddle::platform::is_cpu_place(ctx.GetPlace()), + "It must use CPUPlace."); + auto& dev_ctx = ctx.template device_context(); + auto blas = math::GetBlas(dev_ctx); + auto input = ctx.Input("Input"); + auto w = ctx.Input("W"); + auto b = ctx.Input("Bias"); + + const T* input_data = input->data(); + const T* w_data = w->data(); + auto output = ctx.Output("Out"); + T* output_data = output->mutable_data(ctx.GetPlace()); + + auto in_dims = input->dims(); + auto w_dims = w->dims(); + std::vector output_shape({in_dims[0], w_dims[1]}); + + // TODO(tensor-tang): wire up Out = Input * W (+ Bias) via blas; this + // kernel is still a stub in this commit. + if (b) { + const T* bias_data = b->data(); + } + } +}; + } // namespace operators } // namespace paddle -REGISTER_OPERATOR(fc, paddle::operators::FCOp, paddle::operators::FCOpMaker, +namespace ops = paddle::operators; +REGISTER_OPERATOR(fc, ops::FCOp, ops::FCOpMaker, paddle::framework::DefaultGradOpDescMaker); -REGISTER_OPERATOR(fc_grad, paddle::operators::FCOpGrad); +REGISTER_OPERATOR(fc_grad, ops::FCOpGrad); +REGISTER_OP_CPU_KERNEL(fc, ops::FCOpKernel, + ops::FCOpKernel); From 812de6e8066320cfec32af246e5575adb653cfdc Mon Sep 17 00:00:00 2001 From: minqiyang Date: Tue, 14 Aug 2018 10:26:03 +0800 Subject: [PATCH 47/94] Port utils to Python3 --- python/paddle/utils/dump_config.py | 4 +- python/paddle/utils/image_multiproc.py | 16 ++++---- python/paddle/utils/image_util.py | 8 ++-- python/paddle/utils/make_model_diagram.py | 45 +++++++++++++---------- python/paddle/utils/merge_model.py | 2 +- python/paddle/utils/plotcurve.py | 3 +- python/paddle/utils/predefined_net.py | 5 ++- python/paddle/utils/preprocess_img.py | 14 +++---- python/paddle/utils/preprocess_util.py | 10 ++--- python/paddle/utils/show_pb.py | 8 ++-- python/paddle/utils/torch2paddle.py | 4 +- 11 files changed, 66 insertions(+), 53 deletions(-) diff --git a/python/paddle/utils/dump_config.py b/python/paddle/utils/dump_config.py index d27af7f762..6a96a0a78f 100644 --- a/python/paddle/utils/dump_config.py +++ b/python/paddle/utils/dump_config.py @@ -37,9 +37,9 @@ if __name__ == '__main__': assert isinstance(conf, TrainerConfig_pb2.TrainerConfig) if whole_conf: - print conf + print(conf) else: if binary: sys.stdout.write(conf.model_config.SerializeToString()) else: - print conf.model_config + print(conf.model_config) diff --git a/python/paddle/utils/image_multiproc.py b/python/paddle/utils/image_multiproc.py index 3e3e519f76..d1bbda3fd3
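The InferShape checks in this patch pin Bias to shape [1, D], a single row broadcast across the whole batch, and the stubbed CPU kernel is being set up for one GEMM plus that broadcast add. In numpy terms the target computation is (a sketch of the math, not the Paddle kernel):

    import numpy as np

    def fc_forward(x, w, b=None):
        # x: [batch, in_dim], w: [in_dim, out_dim], b: [1, out_dim] or None
        out = x @ w
        if b is not None:
            assert b.shape == (1, w.shape[1])
            out = out + b  # the single row broadcasts over the batch
        return out

    out = fc_forward(np.ones((4, 8)), np.ones((8, 16)), np.zeros((1, 16)))
    assert out.shape == (4, 16)
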
100644 --- a/python/paddle/utils/image_multiproc.py +++ b/python/paddle/utils/image_multiproc.py @@ -15,7 +15,8 @@ import os, sys import numpy as np from PIL import Image -from cStringIO import StringIO +import six +from six.moves import cStringIO as StringIO import multiprocessing import functools import itertools @@ -187,7 +188,8 @@ class PILTransformer(ImageTransformer): return self.transform(im) -def job(is_img_string, transformer, (data, label)): +def job(is_img_string, transformer, data_label_pack): + (data, label) = data_label_pack if is_img_string: return transformer.transform_from_string(data), label else: @@ -208,7 +210,7 @@ class MultiProcessImageTransformer(object): """ Processing image with multi-process. If it is used in PyDataProvider, the simple usage for CNN is as follows: - + .. code-block:: python def hook(settings, is_train, **kwargs): @@ -229,7 +231,7 @@ class MultiProcessImageTransformer(object): @provider(init_hook=hook, pool_size=20480) def process(settings, file_list): with open(file_list, 'r') as fdata: - for line in fdata: + for line in fdata: data_dic = np.load(line.strip()) # load the data batch pickled by Pickle. data = data_dic['data'] labels = data_dic['label'] @@ -249,10 +251,10 @@ class MultiProcessImageTransformer(object): :type channel_swap: tuple or list :param mean: the mean values of image, per-channel mean or element-wise mean. :type mean: array, The dimension is 1 for per-channel mean. - The dimension is 3 for element-wise mean. + The dimension is 3 for element-wise mean. :param is_train: training period or testing period. :type is_train: bool. - :param is_color: the image is color or gray. + :param is_color: the image is color or gray. :type is_color: bool. :param is_img_string: The input can be the file name of image or image string. :type is_img_string: bool. @@ -273,4 +275,4 @@ class MultiProcessImageTransformer(object): def run(self, data, label): fun = functools.partial(job, self.is_img_string, self.transformer) return self.pool.imap_unordered( - fun, itertools.izip(data, label), chunksize=100 * self.procnum) + fun, six.moves.zip(data, label), chunksize=100 * self.procnum) diff --git a/python/paddle/utils/image_util.py b/python/paddle/utils/image_util.py index d3d79b1440..a8092349cd 100644 --- a/python/paddle/utils/image_util.py +++ b/python/paddle/utils/image_util.py @@ -14,7 +14,7 @@ import numpy as np from PIL import Image -from cStringIO import StringIO +from six.moves import cStringIO as StringIO def resize_image(img, target_size): @@ -34,7 +34,7 @@ def flip(im): """ Return the flipped image. Flip an image along the horizontal direction. - im: input image, (H x W x K) ndarrays + im: input image, (H x W x K) ndarrays """ if len(im.shape) == 3: return im[:, :, ::-1] @@ -132,7 +132,7 @@ def load_meta(meta_path, mean_img_size, crop_size, color=True): def load_image(img_path, is_color=True): """ - Load image and return. + Load image and return. img_path: image path. is_color: is color image or not.
""" @@ -205,7 +205,7 @@ class ImageTransformer: def set_mean(self, mean): if mean is not None: - # mean value, may be one value per channel + # mean value, may be one value per channel if mean.ndim == 1: mean = mean[:, np.newaxis, np.newaxis] else: diff --git a/python/paddle/utils/make_model_diagram.py b/python/paddle/utils/make_model_diagram.py index 40f99075de..52759d3ad2 100644 --- a/python/paddle/utils/make_model_diagram.py +++ b/python/paddle/utils/make_model_diagram.py @@ -15,6 +15,9 @@ # Generate dot diagram file for the given paddle model config # The generated file can be viewed using Graphviz (http://graphviz.org) +from __future__ import print_function + +import six import sys import traceback @@ -61,9 +64,9 @@ def make_diagram_from_proto(model_config, dot_file): name2id[mem.link_name]) return s - print >> f, 'digraph graphname {' - print >> f, 'node [width=0.375,height=0.25];' - for i in xrange(len(model_config.layers)): + print('digraph graphname {', file=f) + print('node [width=0.375,height=0.25];', file=f) + for i in six.moves.xrange(len(model_config.layers)): l = model_config.layers[i] name2id[l.name] = i @@ -71,12 +74,12 @@ def make_diagram_from_proto(model_config, dot_file): for sub_model in model_config.sub_models: if sub_model.name == 'root': continue - print >> f, 'subgraph cluster_%s {' % i - print >> f, 'style=dashed;' + print('subgraph cluster_%s {' % i, file=f) + print('style=dashed;', file=f) label = '%s ' % sub_model.name if sub_model.reversed: label += '<==' - print >> f, 'label = "%s";' % label + print('label = "%s";' % label, file=f) i += 1 submodel_layers.add(sub_model.name) for layer_name in sub_model.layer_names: @@ -84,37 +87,41 @@ def make_diagram_from_proto(model_config, dot_file): lid = name2id[layer_name] layer_config = model_config.layers[lid] label = make_layer_label(layer_config) - print >> f, 'l%s [label="%s", shape=box];' % (lid, label) - print >> f, '}' + print('l%s [label="%s", shape=box];' % (lid, label), file=f) + print('}', file=f) - for i in xrange(len(model_config.layers)): + for i in six.moves.xrange(len(model_config.layers)): l = model_config.layers[i] if l.name not in submodel_layers: label = make_layer_label(l) - print >> f, 'l%s [label="%s", shape=box];' % (i, label) + print('l%s [label="%s", shape=box];' % (i, label), file=f) for sub_model in model_config.sub_models: if sub_model.name == 'root': continue for link in sub_model.in_links: - print >> f, make_link(link) + print(make_link(link), file=f) for link in sub_model.out_links: - print >> f, make_link(link) + print(make_link(link), file=f) for mem in sub_model.memories: - print >> f, make_mem(mem) + print(make_mem(mem), file=f) - for i in xrange(len(model_config.layers)): + for i in six.moves.xrange(len(model_config.layers)): for l in model_config.layers[i].inputs: - print >> f, 'l%s -> l%s [label="%s"];' % ( - name2id[l.input_layer_name], i, l.input_parameter_name) + print( + 'l%s -> l%s [label="%s"];' % (name2id[l.input_layer_name], i, + l.input_parameter_name), + file=f) - print >> f, '}' + print('}', file=f) f.close() def usage(): - print >> sys.stderr, ("Usage: python show_model_diagram.py" + - " CONFIG_FILE DOT_FILE [config_str]") + print( + ("Usage: python show_model_diagram.py" + + " CONFIG_FILE DOT_FILE [config_str]"), + file=sys.stderr) exit(1) diff --git a/python/paddle/utils/merge_model.py b/python/paddle/utils/merge_model.py index 2b10020772..b74649e936 100644 --- a/python/paddle/utils/merge_model.py +++ b/python/paddle/utils/merge_model.py @@ -70,4 +70,4 @@ def 
merge_v2_model(net, param_file, output_file): for pname in param_names: params.serialize(pname, f) - print 'Generate %s success!' % (output_file) + print('Generate %s success!' % (output_file)) diff --git a/python/paddle/utils/plotcurve.py b/python/paddle/utils/plotcurve.py index 27bd8157d3..a95e5497e2 100644 --- a/python/paddle/utils/plotcurve.py +++ b/python/paddle/utils/plotcurve.py @@ -44,6 +44,7 @@ To use this script to generate plot for AvgCost, error: python plotcurve.py -i paddle.INFO -o figure.png AvgCost error """ +import six import sys import matplotlib # the following line is added immediately after import matplotlib @@ -91,7 +92,7 @@ def plot_paddle_curve(keys, inputfile, outputfile, format='png', sys.stderr.write("No data to plot. Exiting!\n") return m = len(keys) + 1 - for i in xrange(1, m): + for i in six.moves.xrange(1, m): pyplot.plot( x[:, 0], x[:, i], diff --git a/python/paddle/utils/predefined_net.py b/python/paddle/utils/predefined_net.py index fa05f981f2..2801f4877c 100644 --- a/python/paddle/utils/predefined_net.py +++ b/python/paddle/utils/predefined_net.py @@ -13,6 +13,7 @@ # limitations under the License. import numpy as np +import six import os from paddle.trainer.config_parser import * from paddle.utils.preprocess_img import \ @@ -112,7 +113,7 @@ def simple_conv_net(data_conf, is_color=False): num_classes: num of classes. is_color: whether the input images are color. """ - for k, v in data_conf.iteritems(): + for k, v in six.iteritems(data_conf): globals()[k] = v data_input, label_input, num_image_channels = \ image_data_layers(image_size, num_classes, is_color, is_predict) @@ -340,7 +341,7 @@ def small_vgg(data_conf, is_predict=False): num_classes: num of classes. is_color: whether the input images are color. """ - for k, v in data_conf.iteritems(): + for k, v in six.iteritems(data_conf): globals()[k] = v vgg_conv_net(image_size, num_classes, num_layers=[2, 2, 3, 3], diff --git a/python/paddle/utils/preprocess_img.py b/python/paddle/utils/preprocess_img.py index 975f1e9ede..a322f7b769 100644 --- a/python/paddle/utils/preprocess_img.py +++ b/python/paddle/utils/preprocess_img.py @@ -17,9 +17,9 @@ import os import random import numpy as np import PIL.Image as Image -import StringIO -import preprocess_util -from image_util import crop_img +from six.moves import cStringIO as StringIO +from . import preprocess_util +from .image_util import crop_img def resize_image(img, target_size): @@ -52,7 +52,7 @@ class DiskImage: def read_image(self): if self.img is None: - print "reading: " + self.path + print("reading: " + self.path) image = resize_image(Image.open(self.path), self.target_size) self.img = image @@ -69,7 +69,7 @@ class DiskImage: convert the image into the paddle batch format. 
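The remaining utils conversions in this patch follow the same few py3-safe patterns; a condensed, purely illustrative sketch:

    from __future__ import print_function  # makes print a function on Python 2
    import sys
    import six

    f = sys.stdout
    print('digraph graphname {', file=f)   # replaces: print >> f, '...'

    for i in six.moves.xrange(3):          # replaces bare xrange(), gone in py3
        print('l%s [shape=box];' % i, file=f)

    conf = {'image_size': 32, 'num_classes': 10}
    for k, v in six.iteritems(conf):       # replaces dict.iteritems()
        print('%s=%s' % (k, v), file=f)
    print('}', file=f)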
""" self.read_image() - output = StringIO.StringIO() + output = StringIO() self.img.save(output, "jpeg") contents = output.getvalue() return contents @@ -127,7 +127,7 @@ class ImageClassificationDatasetCreater(preprocess_util.DatasetCreater): image_path = items[0] label_name = items[1] if not label_name in label_set: - label_set[label_name] = len(label_set.keys()) + label_set[label_name] = len(list(label_set.keys())) img = DiskImage(path=image_path, target_size=self.target_size) label = preprocess_util.Lablel( label=label_set[label_name], name=label_name) @@ -144,7 +144,7 @@ class ImageClassificationDatasetCreater(preprocess_util.DatasetCreater): return create_dataset_from_list(path) label_set = preprocess_util.get_label_set_from_dir(path) data = [] - for l_name in label_set.keys(): + for l_name in list(label_set.keys()): image_paths = preprocess_util.list_images( os.path.join(path, l_name)) for p in image_paths: diff --git a/python/paddle/utils/preprocess_util.py b/python/paddle/utils/preprocess_util.py index 1d17a48824..05b2067d01 100644 --- a/python/paddle/utils/preprocess_util.py +++ b/python/paddle/utils/preprocess_util.py @@ -14,7 +14,7 @@ import os import math -import cPickle as pickle +import six.moves.cPickle as pickle import random import collections @@ -169,7 +169,7 @@ class Dataset: random.shuffle(keyvalue_indices[k]) num_data_per_key_batch = \ - math.ceil(num_per_batch / float(len(keyvalue_indices.keys()))) + math.ceil(num_per_batch / float(len(list(keyvalue_indices.keys())))) if num_data_per_key_batch < 2: raise Exception("The number of data in a batch is too small") @@ -182,8 +182,8 @@ class Dataset: end_idx = int( min(begin_idx + num_data_per_key_batch, len(keyvalue_indices[k]))) - print "begin_idx, end_idx" - print begin_idx, end_idx + print("begin_idx, end_idx") + print(begin_idx, end_idx) for idx in range(begin_idx, end_idx): permuted_data.append(self.data[keyvalue_indices[k][idx]]) keyvalue_readpointer[k] = end_idx @@ -357,6 +357,6 @@ class DatasetCreater(object): data_batcher.create_batches_and_list( self.output_path, self.train_list_name, self.test_list_name, self.label_set_name) - self.num_classes = len(train_label_set.keys()) + self.num_classes = len(list(train_label_set.keys())) self.create_meta_file(train_data) return out_path diff --git a/python/paddle/utils/show_pb.py b/python/paddle/utils/show_pb.py index 20614826d1..da7a71a665 100644 --- a/python/paddle/utils/show_pb.py +++ b/python/paddle/utils/show_pb.py @@ -15,6 +15,8 @@ Show the content of proto buffer data file of PADDLE """ +from __future__ import print_function + import os import sys from google.protobuf.internal.decoder import _DecodeVarint @@ -39,7 +41,7 @@ def read_proto(file, message): def usage(): - print >> sys.stderr, "Usage: python show_pb.py PROTO_DATA_FILE" + print("Usage: python show_pb.py PROTO_DATA_FILE", file=sys.stderr) exit(1) @@ -50,8 +52,8 @@ if __name__ == '__main__': f = open(sys.argv[1]) header = DataFormat.DataHeader() read_proto(f, header) - print header + print(header) sample = DataFormat.DataSample() while read_proto(f, sample): - print sample + print(sample) diff --git a/python/paddle/utils/torch2paddle.py b/python/paddle/utils/torch2paddle.py index 91490111a1..398d3aa4e0 100644 --- a/python/paddle/utils/torch2paddle.py +++ b/python/paddle/utils/torch2paddle.py @@ -24,7 +24,7 @@ import sys import struct import numpy as np import torchfile -import cPickle as pickle +import six.moves.cPickle as pickle import argparse @@ -48,7 +48,7 @@ def save_net_parameters(layers, params, 
output_path): biases = params[i * 2 + 1] weight_file = os.path.join(output_path, '_%s.w0' % layers[i]) biases_file = os.path.join(output_path, '_%s.wbias' % layers[i]) - print "Saving for layer %s." % layers[i] + print("Saving for layer %s." % layers[i]) save_layer_parameters(weight_file, [weight]) save_layer_parameters(biases_file, biases) From 92aa20616dd7b053de209a11a9fec33dc8579f4f Mon Sep 17 00:00:00 2001 From: minqiyang Date: Tue, 14 Aug 2018 11:31:45 +0800 Subject: [PATCH 48/94] Polish the code style --- python/paddle/fluid/tests/unittests/test_dist_base.py | 1 + 1 file changed, 1 insertion(+) diff --git a/python/paddle/fluid/tests/unittests/test_dist_base.py b/python/paddle/fluid/tests/unittests/test_dist_base.py index 618c910b71..78ed29be56 100644 --- a/python/paddle/fluid/tests/unittests/test_dist_base.py +++ b/python/paddle/fluid/tests/unittests/test_dist_base.py @@ -123,6 +123,7 @@ def runtime_main(test_class): ) else fluid.CPUPlace() model.run_trainer(p, endpoints, trainer_id, trainers, is_dist) + import paddle.fluid.compat as cpt From e133df60373b92d1e35b2f34144e7067dbb9752b Mon Sep 17 00:00:00 2001 From: tensor-tang Date: Mon, 13 Aug 2018 23:40:58 +0800 Subject: [PATCH 49/94] enable native fc forward --- paddle/fluid/operators/fc_mkldnn_op.cc | 1 + paddle/fluid/operators/fc_op.cc | 55 +++++++++++++++----------- 2 files changed, 33 insertions(+), 23 deletions(-) diff --git a/paddle/fluid/operators/fc_mkldnn_op.cc b/paddle/fluid/operators/fc_mkldnn_op.cc index 99fa659a35..68a47dd6ad 100644 --- a/paddle/fluid/operators/fc_mkldnn_op.cc +++ b/paddle/fluid/operators/fc_mkldnn_op.cc @@ -128,6 +128,7 @@ class FCMKLDNNOpKernel : public paddle::framework::OpKernel { PADDLE_ENFORCE(input->dims().size() == 2 || input->dims().size() == 4, "Input must be with 2 or 4 dimensions, i.e. NCHW"); + // TODO(intel): the src weight is io and mkldnn weight need be transposed ! PADDLE_ENFORCE(w->dims().size() == 2 || w->dims().size() == 4, "Weights must be with 2 or 4 dimensions, i.e. OI or OIHW"); diff --git a/paddle/fluid/operators/fc_op.cc b/paddle/fluid/operators/fc_op.cc index 5fee30e146..e71f63c134 100644 --- a/paddle/fluid/operators/fc_op.cc +++ b/paddle/fluid/operators/fc_op.cc @@ -15,6 +15,8 @@ limitations under the License. 
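Before the fc_op.cc changes below, it helps to pin down the shape contract that the reworked InferShape enforces; a NumPy sketch with made-up sizes:

    import numpy as np

    N, C, H, W, O = 2, 3, 4, 4, 10        # illustrative sizes only
    x = np.random.rand(N, C, H, W)        # "Input", NCHW
    w = np.random.rand(C * H * W, O)      # "W", shape (I, O) with I = C*H*W
    b = np.random.rand(1, O)              # "Bias", shape (1, O)

    # the check: product(in_dims) / in_dims[0] must equal w_dims[0]
    assert x.size // x.shape[0] == w.shape[0]

    out = x.reshape(N, -1).dot(w) + b     # "Out", shape (N, O); bias broadcasts
    assert out.shape == (N, O)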
*/ #include "paddle/fluid/operators/fc_op.h" #include +DECLARE_int32(paddle_num_threads); + namespace paddle { namespace operators { @@ -25,25 +27,23 @@ void FCOp::InferShape(framework::InferShapeContext* ctx) const { "Out(Output) of Fully Connected should not be null."); PADDLE_ENFORCE(ctx->HasInput("W"), "W(Input) of Fully Connected should not be null."); - + // NCHW auto in_dims = ctx->GetInputDim("Input"); + // IO, I=C*H*W auto w_dims = ctx->GetInputDim("W"); std::vector output_shape({in_dims[0], w_dims[1]}); if (ctx->HasInput("Bias")) { auto bias_dims = ctx->GetInputDim("Bias"); PADDLE_ENFORCE_EQ(bias_dims[0], 1, "The shape of Bias must be [1, dim]."); - PADDLE_ENFORCE_EQ(bias_dims[1], framework::product(w_dims) / w_dims[0], + PADDLE_ENFORCE_EQ(bias_dims[1], w_dims[1], "The shape of Bias must be [1, dim]."); } PADDLE_ENFORCE(in_dims.size() == 2 || in_dims.size() == 4, "Fully Connected input should be 2-D or 4-D tensor."); - - PADDLE_ENFORCE(w_dims.size() == 2 || w_dims.size() == 4, - "Fully Connected input should be 2-D or 4-D tensor."); - - PADDLE_ENFORCE_EQ(framework::product(w_dims) / w_dims[0], - framework::product(in_dims) / in_dims[0], + PADDLE_ENFORCE_EQ(w_dims.size(), 2UL, + "Fully Connected input should be 2-D tensor."); + PADDLE_ENFORCE_EQ(framework::product(in_dims) / in_dims[0], w_dims[0], "Fully Connected input and weigth size do not match."); ctx->SetOutputDim("Out", framework::make_ddim(output_shape)); @@ -54,7 +54,7 @@ framework::OpKernelType FCOp::GetExpectedKernelType( const framework::ExecutionContext& ctx) const { framework::LibraryType library = framework::LibraryType::kPlain; framework::DataLayout layout = framework::DataLayout::kAnyLayout; - if (ctx.Attr("use_mkldnn");) { + if (ctx.Attr("use_mkldnn")) { library = framework::LibraryType::kMKLDNN; layout = framework::DataLayout::kMKLDNN; } @@ -75,8 +75,9 @@ void FCOpGrad::InferShape(framework::InferShapeContext* ctx) const { } if (ctx->HasInput("Bias")) { + PADDLE_ENFORCE(ctx->HasOutput(framework::GradVarName("Bias")), + "Should have bias grad"); auto bias_dims = ctx->GetInputDim("Bias"); - PADDLE_ENFORCE(ctx->HasOutput(framework::GradVarName("Bias")); ctx->SetOutputDim(framework::GradVarName("Bias"), bias_dims); } } @@ -85,7 +86,7 @@ framework::OpKernelType FCOpGrad::GetExpectedKernelType( const framework::ExecutionContext& ctx) const { framework::LibraryType library = framework::LibraryType::kPlain; framework::DataLayout layout = framework::DataLayout::kAnyLayout; - if (ctx.Attr("use_mkldnn");) { + if (ctx.Attr("use_mkldnn")) { library = framework::LibraryType::kMKLDNN; layout = framework::DataLayout::kMKLDNN; } @@ -95,9 +96,11 @@ framework::OpKernelType FCOpGrad::GetExpectedKernelType( } void FCOpMaker::Make() { - AddInput("Input", "(Tensor) The input tensor of fully connected operator. "); - AddInput("W", "(Tensor), The second input tensor of fc op."); - AddInput("Bias", "(Tensor, optional) Bias vector with shape (1 x D") + AddInput("Input", + "(Tensor), The input tensor of fully connected operator with format " + "(NCHW). "); + AddInput("W", "(Tensor), The weight fc op with shape (I, O)."); + AddInput("Bias", "(Tensor, optional) Bias vector with shape (1 x O") .AsDispensable(); AddOutput("Out", "(Tensor) The output tensor of fully connected operator. 
"); AddAttr("use_mkldnn", @@ -120,25 +123,32 @@ template class FCOpKernel : public framework::OpKernel { public: void Compute(const paddle::framework::ExecutionContext& ctx) const override { - PADDLE_ENFORCE(paddle::platform::is_cpu_place(ctx.GetPlace()), + PADDLE_ENFORCE(platform::is_cpu_place(ctx.GetPlace()), "It must use CPUPlace."); - auto& dev_ctx = ctx.template device_context(); - auto blas = math::GetBlas(dev_ctx); auto input = ctx.Input("Input"); auto w = ctx.Input("W"); auto b = ctx.Input("Bias"); + auto output = ctx.Output("Out"); + auto in_dims = ctx->GetInputDim("Input"); + auto w_dims = ctx->GetInputDim("W"); + auto& dev_ctx = ctx.template device_context(); + auto blas = math::GetBlas(dev_ctx); const T* input_data = input->data(); const T* w_data = w->data(); - auto output = ctx.Output("Out"); T* output_data = output->mutable_data(ctx.GetPlace()); - auto in_dims = ctx->GetInputDim("Input"); - auto w_dims = ctx->GetInputDim("W"); - std::vector output_shape({in_dims[0], w_dims[1]}); + blas.GEMM(CblasNoTrans, CblasNoTrans, in_dims[0], w_dims[1], w_dims[0], + static_cast(1), input_data, w_data, static_cast(0), + output_data); if (bias) { const T* bias_data = bias->data(); +#pragma omp parallel for if (FLAGS_paddle_num_threads > 1) + for (int bs = 0; bs < in_dims[0]; bs++) { + blas.AXPY(w_dims[1], static_cast(1), bias_data, + output_data + bs * w_dimws[1]); + } } } }; @@ -150,5 +160,4 @@ namespace ops = paddle::operators; REGISTER_OPERATOR(fc, ops::FCOp, ops::FCOpMaker, paddle::framework::DefaultGradOpDescMaker); REGISTER_OPERATOR(fc_grad, ops::FCOpGrad); -REGISTER_OP_CPU_KERNEL(fc, ops::FCMKLDNNOpKernel, - ops::FCMKLDNNOpKernel); +REGISTER_OP_CPU_KERNEL(fc, ops::FCOpKernel, ops::FCOpKernel); From 4b5986bb77b06432f44bcd7f1e9352f8ca5dae2f Mon Sep 17 00:00:00 2001 From: tensor-tang Date: Tue, 14 Aug 2018 13:36:03 +0800 Subject: [PATCH 50/94] enable fc op in normal case --- paddle/fluid/operators/CMakeLists.txt | 6 ------ paddle/fluid/operators/fc_op.cc | 13 +++++++------ 2 files changed, 7 insertions(+), 12 deletions(-) diff --git a/paddle/fluid/operators/CMakeLists.txt b/paddle/fluid/operators/CMakeLists.txt index 4c3b8ec781..8cd80ca6be 100644 --- a/paddle/fluid/operators/CMakeLists.txt +++ b/paddle/fluid/operators/CMakeLists.txt @@ -295,12 +295,6 @@ op_library(channel_recv_op DEPS concurrency) list(REMOVE_ITEM GENERAL_OPS ${DEPS_OPS}) -# The fully connected layer is deleted when the WITH_MKLDNN flag is OFF -# Because the fully connected layer has only one MKLDNN's operator -if(NOT WITH_MKLDNN) - list(REMOVE_ITEM GENERAL_OPS fc_op) -endif(NOT WITH_MKLDNN) - foreach(src ${GENERAL_OPS}) op_library(${src}) endforeach() diff --git a/paddle/fluid/operators/fc_op.cc b/paddle/fluid/operators/fc_op.cc index e71f63c134..ec8dfb659c 100644 --- a/paddle/fluid/operators/fc_op.cc +++ b/paddle/fluid/operators/fc_op.cc @@ -14,6 +14,7 @@ limitations under the License. 
*/ #include "paddle/fluid/operators/fc_op.h" #include +#include "paddle/fluid/operators/math/blas.h" DECLARE_int32(paddle_num_threads); @@ -127,13 +128,13 @@ class FCOpKernel : public framework::OpKernel { "It must use CPUPlace."); auto input = ctx.Input("Input"); auto w = ctx.Input("W"); - auto b = ctx.Input("Bias"); + auto bias = ctx.Input("Bias"); auto output = ctx.Output("Out"); - auto in_dims = ctx->GetInputDim("Input"); - auto w_dims = ctx->GetInputDim("W"); + auto in_dims = input->dims(); + auto w_dims = w->dims(); - auto& dev_ctx = ctx.template device_context(); - auto blas = math::GetBlas(dev_ctx); + auto& dev_ctx = ctx.template device_context(); + auto blas = math::GetBlas(dev_ctx); const T* input_data = input->data(); const T* w_data = w->data(); T* output_data = output->mutable_data(ctx.GetPlace()); @@ -147,7 +148,7 @@ class FCOpKernel : public framework::OpKernel { #pragma omp parallel for if (FLAGS_paddle_num_threads > 1) for (int bs = 0; bs < in_dims[0]; bs++) { blas.AXPY(w_dims[1], static_cast(1), bias_data, - output_data + bs * w_dimws[1]); + output_data + bs * w_dims[1]); } } } From 45d0259a6746d67f23f538b0b20f51a4af2f6d3f Mon Sep 17 00:00:00 2001 From: tensor-tang Date: Tue, 14 Aug 2018 15:05:04 +0800 Subject: [PATCH 51/94] add fc forward test --- .../fluid/tests/unittests/test_fc_op.py | 90 +++++++++++++++++++ 1 file changed, 90 insertions(+) create mode 100644 python/paddle/fluid/tests/unittests/test_fc_op.py diff --git a/python/paddle/fluid/tests/unittests/test_fc_op.py b/python/paddle/fluid/tests/unittests/test_fc_op.py new file mode 100644 index 0000000000..2bb920710a --- /dev/null +++ b/python/paddle/fluid/tests/unittests/test_fc_op.py @@ -0,0 +1,90 @@ +# Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +import unittest +import numpy as np +from op_test import OpTest + + +def fc_refer(matrix, with_bias): + in_n, in_c, in_h, in_w = matrix.input.shape + w_i, w_o = matrix.weights.shape + + x_data = np.reshape(matrix.input, [in_n, in_c * in_h * in_w]) + w_data = np.reshape(matrix.weights, [w_i, w_o]) + b_data = np.reshape(matrix.bias, [1, w_o]) + result = None + + if with_bias: + result = np.dot(x_data, w_data) + b_data + else: + result = np.dot(x_data, w_data) + + return result + + +class MatrixGenerate: + def __init__(self, mb, ic, oc, h, w): + self.input = np.random.random((mb, ic, h, w)).astype("float32") + self.weights = np.random.random((ic * h * w, oc)).astype("float32") + self.bias = np.random.random((1, oc)).astype("float32") + + +class TestFCOp(OpTest): + def setUp(self): + self.op_type = "fc" + self.matrix = MatrixGenerate(1, 10, 15, 3, 3) + + self.with_bias = True + if self.with_bias: + self.inputs = { + 'Input': self.matrix.input, + 'W': self.matrix.weights, + 'Bias': self.matrix.bias + } + else: + self.inputs = {'Input': self.matrix.input, 'W': self.matrix.weights} + + self.attrs = {'use_mkldnn': False} + + self.outputs = {'Out': fc_refer(self.matrix, self.with_bias)} + + def test_check_output(self): + self.check_output() + + +class TestFCOpBiasBoth(TestFCOp): + def init_shapes(self, mb, ic, oc, h, w): + for with_bias in {True, False}: + self.with_bias = with_bias + self.matrix = MatrixGenerate(mb, ic, oc, h, w) + + +class TestFCOp1(TestFCOpBiasBoth): + def init_op_type(self): + self.init_shapes(2, 8, 10, 1, 1) + + +class TestFCOp2(TestFCOpBiasBoth): + def init_op_type(self): + self.init_shapes(4, 5, 6, 2, 2) + + +class TestFCOp4(TestFCOpBiasBoth): + def init_op_type(self): + self.init_shapes(1, 32, 64, 3, 3) + + +if __name__ == "__main__": + unittest.main() From b9dbb7c5cbad8e25cb16af07af6b58764c27ae6e Mon Sep 17 00:00:00 2001 From: tensor-tang Date: Tue, 14 Aug 2018 15:47:15 +0800 Subject: [PATCH 52/94] fix bias attri in mkldnn fc --- paddle/fluid/operators/fc_mkldnn_op.cc | 10 +++++++--- .../paddle/fluid/tests/unittests/test_fc_mkldnn_op.py | 9 ++------- 2 files changed, 9 insertions(+), 10 deletions(-) diff --git a/paddle/fluid/operators/fc_mkldnn_op.cc b/paddle/fluid/operators/fc_mkldnn_op.cc index 68a47dd6ad..e595f1a627 100644 --- a/paddle/fluid/operators/fc_mkldnn_op.cc +++ b/paddle/fluid/operators/fc_mkldnn_op.cc @@ -125,14 +125,16 @@ class FCMKLDNNOpKernel : public paddle::framework::OpKernel { auto input = ctx.Input("Input"); auto w = ctx.Input("W"); + auto bias = ctx.Input("Bias"); PADDLE_ENFORCE(input->dims().size() == 2 || input->dims().size() == 4, "Input must be with 2 or 4 dimensions, i.e. NCHW"); - // TODO(intel): the src weight is io and mkldnn weight need be transposed ! + // TODO(intel friends): the native weight format is io, + // but the mkldnn weight format is oihw, which may need be transposed. PADDLE_ENFORCE(w->dims().size() == 2 || w->dims().size() == 4, "Weights must be with 2 or 4 dimensions, i.e. OI or OIHW"); - bool with_bias = ctx.Attr("bias_attr"); + bool with_bias = bias != nullptr; MKLDNNMD md(input, w, with_bias); std::shared_ptr pd = @@ -155,6 +157,7 @@ class FCMKLDNNOpKernel : public paddle::framework::OpKernel { auto dst_memory = mem.dst(output_data); auto src_memory = mem.src(input_data); auto weights_memory = mem.weights(w_data); + // TODO(intel friends): bias memory should also be obtain from bias->data() auto bias_memory = mem.bias(); auto forward = with_bias ? 
mkldnn::inner_product_forward( @@ -217,7 +220,8 @@ class FCMKLDNNGradOpKernel : public paddle::framework::OpKernel { const Tensor* out_grad = ctx.Input(framework::GradVarName("Out")); const T* out_grad_data = out_grad->data(); - bool with_bias = ctx.Attr("bias_attr"); + auto bias = ctx.Input("Bias"); + bool with_bias = bias != nullptr; MKLDNNMD md(input, w, with_bias); MKLDNNMemory mem(&md, mkldnn_engine); diff --git a/python/paddle/fluid/tests/unittests/test_fc_mkldnn_op.py b/python/paddle/fluid/tests/unittests/test_fc_mkldnn_op.py index 3f547f3c48..099e6e6064 100644 --- a/python/paddle/fluid/tests/unittests/test_fc_mkldnn_op.py +++ b/python/paddle/fluid/tests/unittests/test_fc_mkldnn_op.py @@ -22,6 +22,7 @@ def fully_connected_naive(input, weights, bias_data=None): w_h, w_c = weights.shape x_data = np.reshape(input, [in_n, in_c * in_h * in_w]) + # this transpose should be implemented at C code w_data = np.transpose(np.reshape(weights, (w_c, in_c * in_h * in_w))) result = None @@ -43,15 +44,11 @@ class TestFCMKLDNNOp(OpTest): def setUp(self): self.op_type = "fc" self.use_mkldnn = True - self.with_bias = True self.matrix = MatrixGenerate(1, 10, 15, 3, 3) self.inputs = {'Input': self.matrix.input, 'W': self.matrix.weights} - self.attrs = { - 'use_mkldnn': self.use_mkldnn, - 'with_bias': self.with_bias - } + self.attrs = {'use_mkldnn': self.use_mkldnn, } self.outputs = { 'Out': fully_connected_naive(self.matrix.input, self.matrix.weights) @@ -85,13 +82,11 @@ class TestFCMKLDNNOp3(TestFCMKLDNNOp): class TestFCMKLDNNOp4(TestFCMKLDNNOp): def init_op_type(self): - self.with_bias = False self.matrix = MatrixGenerate(2, 32, 48, 2, 2) class TestFCMKLDNNOp4(TestFCMKLDNNOp): def init_op_type(self): - self.with_bias = False self.matrix = MatrixGenerate(2, 32, 1000, 6, 6) From 742300baa8a24cea467a45cc55f63ca894b0625f Mon Sep 17 00:00:00 2001 From: tensor-tang Date: Tue, 14 Aug 2018 15:52:31 +0800 Subject: [PATCH 53/94] fix unkown omp pragmas --- paddle/fluid/operators/fc_op.cc | 2 ++ 1 file changed, 2 insertions(+) diff --git a/paddle/fluid/operators/fc_op.cc b/paddle/fluid/operators/fc_op.cc index ec8dfb659c..099ca52c8e 100644 --- a/paddle/fluid/operators/fc_op.cc +++ b/paddle/fluid/operators/fc_op.cc @@ -145,7 +145,9 @@ class FCOpKernel : public framework::OpKernel { if (bias) { const T* bias_data = bias->data(); +#ifdef PADDLE_WITH_MKLML #pragma omp parallel for if (FLAGS_paddle_num_threads > 1) +#endif for (int bs = 0; bs < in_dims[0]; bs++) { blas.AXPY(w_dims[1], static_cast(1), bias_data, output_data + bs * w_dims[1]); From 21d5b942282ae32bba4613b31f5429b65afc1532 Mon Sep 17 00:00:00 2001 From: chenweihang Date: Tue, 14 Aug 2018 08:24:21 +0000 Subject: [PATCH 54/94] error message refine: add demangle api to attribute type --- paddle/fluid/framework/CMakeLists.txt | 2 + paddle/fluid/framework/attribute.h | 8 +- paddle/fluid/framework/attribute_type.h | 97 +++++++++++++++++++ paddle/fluid/framework/attribute_type_test.cc | 46 +++++++++ 4 files changed, 150 insertions(+), 3 deletions(-) create mode 100644 paddle/fluid/framework/attribute_type.h create mode 100644 paddle/fluid/framework/attribute_type_test.cc diff --git a/paddle/fluid/framework/CMakeLists.txt b/paddle/fluid/framework/CMakeLists.txt index 6440607dbe..b3fe2d97a8 100644 --- a/paddle/fluid/framework/CMakeLists.txt +++ b/paddle/fluid/framework/CMakeLists.txt @@ -115,6 +115,8 @@ cc_test(cow_ptr_tests SRCS details/cow_ptr_test.cc) # cc_test(channel_test SRCS channel_test.cc) cc_test(tuple_test SRCS tuple_test.cc ) 
+cc_test(attribute_type_test SRCS attribute_type_test.cc) + # disable test temporarily. # TODO https://github.com/PaddlePaddle/Paddle/issues/11971 # cc_test(concurrency_test SRCS concurrency_test.cc DEPS go_op channel_close_op channel_create_op diff --git a/paddle/fluid/framework/attribute.h b/paddle/fluid/framework/attribute.h index 8428bf8e33..2b05528257 100644 --- a/paddle/fluid/framework/attribute.h +++ b/paddle/fluid/framework/attribute.h @@ -20,6 +20,7 @@ limitations under the License. */ #include #include +#include "paddle/fluid/framework/attribute_type.h" #include "paddle/fluid/framework/framework.pb.h" #include "paddle/fluid/framework/type_defs.h" #include "paddle/fluid/platform/enforce.h" @@ -128,7 +129,8 @@ struct ExtractAttribute { attr_value = &boost::get(attr); } catch (boost::bad_get& bad_get) { PADDLE_THROW("Cannot get attribute %s by type %s, its type is %s", - attr_name_, typeid(T).name(), attr.type().name()); + attr_name_, paddle::framework::demangle(typeid(T).name()), + paddle::framework::demangle(attr.type().name())); } return attr_value; } @@ -160,7 +162,7 @@ struct ExtractAttribute { attr_value = &boost::get(attr); } catch (boost::bad_get& bad_get) { PADDLE_THROW("Cannot get attribute %s by type bool, its type is %s", - attr_name_, attr.type().name()); + attr_name_, paddle::framework::demangle(attr.type().name())); } return attr_value; } @@ -186,7 +188,7 @@ struct ExtractAttribute { attr_value = &boost::get(attr); } catch (boost::bad_get& bad_get) { PADDLE_THROW("Cannot get attribute %s by type int64_t, its type is %s", - attr_name_, attr.type().name()); + attr_name_, paddle::framework::demangle(attr.type().name())); } return attr_value; } diff --git a/paddle/fluid/framework/attribute_type.h b/paddle/fluid/framework/attribute_type.h new file mode 100644 index 0000000000..337dcde775 --- /dev/null +++ b/paddle/fluid/framework/attribute_type.h @@ -0,0 +1,97 @@ +/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. */ + +#pragma once + +#include + +// __has_include is currently supported by GCC and Clang. However GCC 4.9 may +// have issues and +// returns 1 for 'defined( __has_include )', while '__has_include' is actually +// not supported: +#if defined(__has_include) && (!defined(BOOST_GCC) || (__GNUC__ + 0) >= 5) +#if __has_include() +#define PADDLE_FRAMEWORK_HAS_CXXABI_H +#endif +#elif defined(__GLIBCXX__) || defined(__GLIBCPP__) +#define PADDLE_FRAMEWORK_HAS_CXXABI_H +#endif + +#if defined(PADDLE_FRAMEWORK_HAS_CXXABI_H) +#include +// For some archtectures (mips, mips64, x86, x86_64) cxxabi.h in Android NDK is +// implemented by gabi++ library +// which does not implement abi::__cxa_demangle(). We detect this implementation +// by checking the include guard here. 
+#if defined(__GABIXX_CXXABI_H__) +#undef PADDLE_FRAMEWORK_HAS_CXXABI_H +#else +#include +#include +#endif +#endif + +namespace paddle { +namespace framework { + +inline char const* demangle_alloc(char const* name); +inline void demangle_free(char const* name); + +class scoped_demangled_name { + private: + char const* m_p; + + public: + explicit scoped_demangled_name(char const* name) + : m_p(demangle_alloc(name)) {} + + ~scoped_demangled_name() { demangle_free(m_p); } + + char const* get() const { return m_p; } + + scoped_demangled_name(scoped_demangled_name const&) = delete; + scoped_demangled_name& operator=(scoped_demangled_name const&) = delete; +}; + +#if defined(PADDLE_FRAMEWORK_HAS_CXXABI_H) + +inline char const* demangle_alloc(char const* name) { + int status = 0; + std::size_t size = 0; + return abi::__cxa_demangle(name, NULL, &size, &status); +} + +inline void demangle_free(char const* name) { + std::free(const_cast(name)); +} + +inline std::string demangle(char const* name) { + scoped_demangled_name demangled_name(name); + char const* p = demangled_name.get(); + if (!p) p = name; + return p; +} + +#else + +inline char const* demangle_alloc(char const* name) { return name; } + +inline void demangle_free(char const*) {} + +inline std::string demangle(char const* name) { return name; } + +#endif + +} // namespace framework +} // namespace paddle diff --git a/paddle/fluid/framework/attribute_type_test.cc b/paddle/fluid/framework/attribute_type_test.cc new file mode 100644 index 0000000000..82537b8a0f --- /dev/null +++ b/paddle/fluid/framework/attribute_type_test.cc @@ -0,0 +1,46 @@ +/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
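The expected strings in the unit test that follows can be reproduced outside the build with binutils' c++filt, which performs the same demangling as abi::__cxa_demangle; a small sketch (assumes c++filt is on PATH and Python 3.7+ for subprocess.run's capture_output):

    import subprocess

    def demangle(mangled):
        # typeid(...).name() yields a mangled *type* encoding, so pass -t
        # ("treat as type") rather than a _Z-prefixed function name
        out = subprocess.run(["c++filt", "-t"], input=mangled,
                             capture_output=True, text=True, check=True)
        return out.stdout.strip()

    print(demangle("i"))                  # -> int
    print(demangle("St6vectorIiSaIiEE"))  # -> std::vector<int, std::allocator<int> >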
*/ + +#include +#include + +#include "gtest/gtest.h" +#include "paddle/fluid/framework/attribute_type.h" + +TEST(Attribute, TypeName) { + bool boolean; + int integer; + float ft; + std::string str; + std::vector booleans; + std::vector integers; + std::vector strings; + + EXPECT_EQ("bool", paddle::framework::demangle(typeid(boolean).name())); + EXPECT_EQ("int", paddle::framework::demangle(typeid(integer).name())); + EXPECT_EQ("float", paddle::framework::demangle(typeid(ft).name())); + EXPECT_EQ( + "std::__cxx11::basic_string, " + "std::allocator >", + paddle::framework::demangle(typeid(str).name())); + EXPECT_EQ("std::vector >", + paddle::framework::demangle(typeid(booleans).name())); + EXPECT_EQ("std::vector >", + paddle::framework::demangle(typeid(integers).name())); + EXPECT_EQ( + "std::vector, " + "std::allocator >, std::allocator, std::allocator > > >", + paddle::framework::demangle(typeid(strings).name())); +} From 55d7f55c63a33aa2fc8f8a4d1c5c2024ec8a137d Mon Sep 17 00:00:00 2001 From: minqiyang Date: Tue, 14 Aug 2018 16:34:04 +0800 Subject: [PATCH 55/94] Revert the changes to attribute.h --- paddle/fluid/framework/attribute.h | 10 +++------- 1 file changed, 3 insertions(+), 7 deletions(-) diff --git a/paddle/fluid/framework/attribute.h b/paddle/fluid/framework/attribute.h index ea91ac2bb0..8428bf8e33 100644 --- a/paddle/fluid/framework/attribute.h +++ b/paddle/fluid/framework/attribute.h @@ -82,10 +82,7 @@ class DefaultValueSetter { public: explicit DefaultValueSetter(T default_value) : default_value_(default_value) {} - void operator()(T* value) const { - PADDLE_ENFORCE(value != nullptr, "Can not set default value to nullptr"); - *value = default_value_; - } + void operator()(T& value) const { value = default_value_; } private: T default_value_; @@ -202,7 +199,6 @@ struct ExtractAttribute { template class TypedAttrChecker { typedef std::function ValueChecker; - typedef std::function ValueSetter; public: explicit TypedAttrChecker(const std::string& attr_name) @@ -245,7 +241,7 @@ class TypedAttrChecker { "Attribute '%s' is required!", attr_name_); // default_value_setter_ has no more than one element T val; - (default_value_setter_[0])(&val); + (default_value_setter_[0])(val); attr_map[attr_name_] = val; } Attribute& attr = attr_map.at(attr_name_); @@ -259,7 +255,7 @@ class TypedAttrChecker { private: std::string attr_name_; std::vector value_checkers_; - std::vector default_value_setter_; + std::vector default_value_setter_; }; // check whether op's all attributes fit their own limits From ae39709e5930e10d7555ce75454053e8ca6f5880 Mon Sep 17 00:00:00 2001 From: minqiyang Date: Tue, 14 Aug 2018 17:06:53 +0800 Subject: [PATCH 56/94] Polish code --- paddle/fluid/framework/op_desc.cc | 85 +++++++------- paddle/fluid/framework/op_desc.h | 2 + python/paddle/dataset/cifar.py | 5 +- python/paddle/dataset/common.py | 1 - python/paddle/dataset/conll05.py | 4 +- python/paddle/dataset/movielens.py | 6 +- python/paddle/dataset/wmt14.py | 2 +- python/paddle/dataset/wmt16.py | 4 +- python/paddle/fluid/backward.py | 12 +- python/paddle/fluid/compat.py | 16 +-- python/paddle/fluid/framework.py | 12 +- python/paddle/fluid/parallel_executor.py | 8 +- .../fluid/tests/unittests/test_compat.py | 108 +++++++++--------- .../fluid/tests/unittests/test_dist_base.py | 4 +- .../memory_optimization_transpiler.py | 2 +- python/paddle/reader/decorator.py | 4 +- 16 files changed, 141 insertions(+), 134 deletions(-) diff --git a/paddle/fluid/framework/op_desc.cc b/paddle/fluid/framework/op_desc.cc index 
984ea3a3dd..b0a07ccad6 100644 --- a/paddle/fluid/framework/op_desc.cc +++ b/paddle/fluid/framework/op_desc.cc @@ -208,49 +208,44 @@ void OpDesc::SetAttr(const std::string &name, const Attribute &v) { proto::AttrType attr_type = static_cast(v.which() - 1); if (attr_type == proto::AttrType::INTS && boost::get>(v).size() == 0u) { - proto::OpProto proto = OpInfoMap::Instance().Get(Type()).Proto(); // Find current attr via attr name and set the correct attribute value - for (int i = 0; i != proto.attrs_size(); ++i) { - const proto::OpProto::Attr &attr = proto.attrs(i); - if (attr.name() == name) { - switch (attr.type()) { - case proto::AttrType::BOOLEANS: { - VLOG(11) << "SetAttr: " << Type() << ", " << name - << " from INTS to BOOLEANS"; - this->attrs_[name] = std::vector(); - break; - } - case proto::AttrType::INTS: { - VLOG(11) << "SetAttr: " << Type() << ", " << name - << " from INTS to INTS"; - this->attrs_[name] = std::vector(); - break; - } - case proto::AttrType::FLOATS: { - VLOG(11) << "SetAttr: " << Type() << ", " << name - << " from INTS to FLOATS"; - this->attrs_[name] = std::vector(); - break; - } - case proto::AttrType::STRINGS: { - VLOG(11) << "SetAttr: " << Type() << ", " << name - << " from INTS to STRINGS"; - this->attrs_[name] = std::vector(); - break; - } - case proto::AttrType::BLOCKS: { - VLOG(11) << "SetAttr: " << Type() << ", " << name - << " from INTS to BLOCKS"; - this->SetBlocksAttr(name, std::vector()); - return; - } - default: - PADDLE_THROW("Wrong attr type %d", attr.type()); - } - need_update_ = true; + const proto::OpProto::Attr& attr = GetProtoAttr(name); + switch (attr.type()) { + case proto::AttrType::BOOLEANS: { + VLOG(11) << "SetAttr: " << Type() << ", " << name + << " from INTS to BOOLEANS"; + this->attrs_[name] = std::vector(); + break; + } + case proto::AttrType::INTS: { + VLOG(11) << "SetAttr: " << Type() << ", " << name + << " from INTS to INTS"; + this->attrs_[name] = std::vector(); + break; + } + case proto::AttrType::FLOATS: { + VLOG(11) << "SetAttr: " << Type() << ", " << name + << " from INTS to FLOATS"; + this->attrs_[name] = std::vector(); + break; + } + case proto::AttrType::STRINGS: { + VLOG(11) << "SetAttr: " << Type() << ", " << name + << " from INTS to STRINGS"; + this->attrs_[name] = std::vector(); + break; + } + case proto::AttrType::BLOCKS: { + VLOG(11) << "SetAttr: " << Type() << ", " << name + << " from INTS to BLOCKS"; + this->SetBlocksAttr(name, std::vector()); return; } + default: + PADDLE_THROW("Wrong attr type %d", attr.type()); } + need_update_ = true; + return; } this->attrs_[name] = v; @@ -280,6 +275,18 @@ Attribute OpDesc::GetAttr(const std::string &name) const { return it->second; } +const proto::OpProto::Attr& OpDesc::GetProtoAttr(const std::string &name) { + proto::OpProto& proto = OpInfoMap::Instance().Get(Type()).Proto(); + for (int i = 0; i != proto.attrs_size(); ++i) { + const proto::OpProto::Attr &attr = proto.attrs(i); + if (attr.name() == name) { + return attr; + } + } + + PADDLE_THROW("Attribute %s is not found in proto %s", name, proto.type()); +} + Attribute OpDesc::GetNullableAttr(const std::string &name) const { auto it = attrs_.find(name); if (it != attrs_.end()) { diff --git a/paddle/fluid/framework/op_desc.h b/paddle/fluid/framework/op_desc.h index 74dd8ec002..0421f36a35 100644 --- a/paddle/fluid/framework/op_desc.h +++ b/paddle/fluid/framework/op_desc.h @@ -81,6 +81,8 @@ class OpDesc { Attribute GetAttr(const std::string &name) const; + const proto::OpProto::Attr& GetProtoAttr(const std::string &name) const; 
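The special case above exists because an empty list sent from the Python side arrives as INTS with no element type, so only the attribute type declared in the op proto can disambiguate it; a rough Python analogue of the dispatch (names are illustrative):

    def resolve_empty_attr(declared_type):
        # each branch mirrors one variant alternative in the C++ switch
        empty_value = {
            'BOOLEANS': [],   # std::vector<bool>()
            'INTS': [],       # std::vector<int>()
            'FLOATS': [],     # std::vector<float>()
            'STRINGS': [],    # std::vector<std::string>()
            'BLOCKS': [],     # routed to SetBlocksAttr(name, {})
        }
        if declared_type not in empty_value:
            raise ValueError('Wrong attr type %s' % declared_type)
        return empty_value[declared_type]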
+ Attribute GetNullableAttr(const std::string &name) const; int GetBlockAttr(const std::string &name) const; diff --git a/python/paddle/dataset/cifar.py b/python/paddle/dataset/cifar.py index cfe9deeab0..b42bc192b2 100644 --- a/python/paddle/dataset/cifar.py +++ b/python/paddle/dataset/cifar.py @@ -55,9 +55,8 @@ def reader_creator(filename, sub_name, cycle=False): def reader(): with tarfile.open(filename, mode='r') as f: - names = [ - each_item.name for each_item in f if sub_name in each_item.name - ] + names = (each_item.name for each_item in f + if sub_name in each_item.name) while True: for name in names: diff --git a/python/paddle/dataset/common.py b/python/paddle/dataset/common.py index 07e6b199c0..a75cabd676 100644 --- a/python/paddle/dataset/common.py +++ b/python/paddle/dataset/common.py @@ -20,7 +20,6 @@ import shutil import sys import importlib import paddle.dataset -import paddle.fluid.compat as cpt import six.moves.cPickle as pickle import glob diff --git a/python/paddle/dataset/conll05.py b/python/paddle/dataset/conll05.py index e5a5b3a343..c716b5af67 100644 --- a/python/paddle/dataset/conll05.py +++ b/python/paddle/dataset/conll05.py @@ -90,8 +90,8 @@ def corpus_reader(data_path, words_name, props_name): labels = [] one_seg = [] for word, label in zip(words_file, props_file): - word = cpt.to_literal_str(word.strip()) - label = cpt.to_literal_str(label.strip().split()) + word = cpt.to_text(word.strip()) + label = cpt.to_text(label.strip().split()) if len(label) == 0: # end of sentence for i in range(len(one_seg[0])): diff --git a/python/paddle/dataset/movielens.py b/python/paddle/dataset/movielens.py index 137d6ca8d0..15bbd00c0b 100644 --- a/python/paddle/dataset/movielens.py +++ b/python/paddle/dataset/movielens.py @@ -114,7 +114,7 @@ def __initialize_meta_info__(): categories_set = set() with package.open('ml-1m/movies.dat') as movie_file: for i, line in enumerate(movie_file): - line = cpt.to_literal_str(line, encoding='latin') + line = cpt.to_text(line, encoding='latin') movie_id, title, categories = line.strip().split('::') categories = categories.split('|') for c in categories: @@ -139,7 +139,7 @@ def __initialize_meta_info__(): USER_INFO = dict() with package.open('ml-1m/users.dat') as user_file: for line in user_file: - line = cpt.to_literal_str(line, encoding='latin') + line = cpt.to_text(line, encoding='latin') uid, gender, age, job, _ = line.strip().split("::") USER_INFO[int(uid)] = UserInfo( index=uid, gender=gender, age=age, job_id=job) @@ -152,7 +152,7 @@ def __reader__(rand_seed=0, test_ratio=0.1, is_test=False): with zipfile.ZipFile(file=fn) as package: with package.open('ml-1m/ratings.dat') as rating: for line in rating: - line = cpt.to_literal_str(line, encoding='latin') + line = cpt.to_text(line, encoding='latin') if (rand.random() < test_ratio) == is_test: uid, mov_id, rating, _ = line.strip().split("::") uid = int(uid) diff --git a/python/paddle/dataset/wmt14.py b/python/paddle/dataset/wmt14.py index ecd39a79f1..d35e706131 100644 --- a/python/paddle/dataset/wmt14.py +++ b/python/paddle/dataset/wmt14.py @@ -55,7 +55,7 @@ def __read_to_dict(tar_file, dict_size): out_dict = dict() for line_count, line in enumerate(fd): if line_count < size: - out_dict[cpt.to_literal_str(line.strip())] = line_count + out_dict[cpt.to_text(line.strip())] = line_count else: break return out_dict diff --git a/python/paddle/dataset/wmt16.py b/python/paddle/dataset/wmt16.py index c5772e1f19..f0051c7736 100644 --- a/python/paddle/dataset/wmt16.py +++ b/python/paddle/dataset/wmt16.py 
@@ -89,9 +89,9 @@ def __load_dict(tar_file, dict_size, lang, reverse=False): with open(dict_path, "rb") as fdict: for idx, line in enumerate(fdict): if reverse: - word_dict[idx] = cpt.to_literal_str(line.strip()) + word_dict[idx] = cpt.to_text(line.strip()) else: - word_dict[cpt.to_literal_str(line.strip())] = idx + word_dict[cpt.to_text(line.strip())] = idx return word_dict diff --git a/python/paddle/fluid/backward.py b/python/paddle/fluid/backward.py index 7ff77162aa..1c10d06c51 100644 --- a/python/paddle/fluid/backward.py +++ b/python/paddle/fluid/backward.py @@ -103,8 +103,8 @@ def _some_in_set_(cands, s): """ if len(cands) == 0: return False - literal_set = cpt.to_literal_str(s) - literal_cands = cpt.to_literal_str(cands) + literal_set = cpt.to_text(s) + literal_cands = cpt.to_text(cands) for c in literal_cands: if c in literal_set: return True @@ -117,7 +117,7 @@ def _strip_grad_suffix_(name): e.g. x@GRAD ==> x y@GRAD@RENAME@1 ==> y """ - name = cpt.to_literal_str(name) + name = cpt.to_text(name) pos = name.find(core.grad_var_suffix()) return name[:pos] if pos != -1 else name @@ -127,7 +127,7 @@ def _append_grad_suffix_(name): Append grad suffix to the given variable name e.g. x ==> x@GRAD """ - return cpt.to_literal_str(name) + core.grad_var_suffix() + return cpt.to_text(name) + core.grad_var_suffix() def _addup_repetitive_outputs_(op_descs): @@ -365,7 +365,7 @@ def _append_backward_ops_(block, # Getting op's corresponding grad_op grad_op_desc, op_grad_to_var = core.get_grad_op_desc( op.desc, - cpt.to_literal_str(no_grad_dict[block.idx]), grad_sub_block_list) + cpt.to_text(no_grad_dict[block.idx]), grad_sub_block_list) grad_op_descs.extend(grad_op_desc) grad_to_var.update(op_grad_to_var) @@ -600,7 +600,7 @@ def append_backward(loss, parameter_list=None, no_grad_set=None, params_and_grads = [] for param in parameters: - if cpt.to_literal_str(param) not in grad_info_map: + if cpt.to_text(param) not in grad_info_map: continue grad_info = grad_info_map[param] grad_block = grad_info[1] diff --git a/python/paddle/fluid/compat.py b/python/paddle/fluid/compat.py index 62826c7ce9..50726b6fa1 100644 --- a/python/paddle/fluid/compat.py +++ b/python/paddle/fluid/compat.py @@ -17,7 +17,7 @@ import math __all__ = [ 'long_type', - 'to_literal_str', + 'to_text', 'to_bytes', 'round', 'floor_division', @@ -33,7 +33,7 @@ else: # str and bytes related functions -def to_literal_str(obj, encoding='utf-8', inplace=False): +def to_text(obj, encoding='utf-8', inplace=False): """ All string in PaddlePaddle should be represented as a literal string. This function will convert object to a literal string without any encoding. 
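In use, the renamed helper keeps the old semantics; a condensed sketch of the cases the updated tests below exercise:

    import paddle.fluid.compat as cpt

    # bytes and unicode both normalize to the text type
    assert cpt.to_text(b"123") == u"123"
    assert cpt.to_text(u"123") == u"123"

    # containers convert element-wise; inplace=True mutates lists and sets
    l = [b"", b"123"]
    assert cpt.to_text(l) == [u"", u"123"]    # returns a new list
    assert cpt.to_text(l, inplace=True) is l  # same object, now text
    assert cpt.to_text(set([b"a"])) == set([u"a"])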
@@ -60,23 +60,23 @@ def to_literal_str(obj, encoding='utf-8', inplace=False): if isinstance(obj, list): if inplace: for i in six.moves.xrange(len(obj)): - obj[i] = _to_literal_str(obj[i], encoding) + obj[i] = _to_text(obj[i], encoding) return obj else: - return [_to_literal_str(item, encoding) for item in obj] + return [_to_text(item, encoding) for item in obj] elif isinstance(obj, set): if inplace: for item in obj: obj.remove(item) - obj.add(_to_literal_str(item, encoding)) + obj.add(_to_text(item, encoding)) return obj else: - return set([_to_literal_str(item, encoding) for item in obj]) + return set([_to_text(item, encoding) for item in obj]) else: - return _to_literal_str(obj, encoding) + return _to_text(obj, encoding) -def _to_literal_str(obj, encoding): +def _to_text(obj, encoding): """ In Python3: Decode the bytes type object to str type with specific encoding diff --git a/python/paddle/fluid/framework.py b/python/paddle/fluid/framework.py index 073172ddba..5203aa160c 100644 --- a/python/paddle/fluid/framework.py +++ b/python/paddle/fluid/framework.py @@ -199,7 +199,7 @@ class Variable(object): if name is None: name = unique_name.generate('_generated_var') is_new_var = False - name = cpt.to_literal_str(name) + name = cpt.to_text(name) self.desc = self.block.desc.find_var(cpt.to_bytes(name)) if self.desc is None: @@ -326,7 +326,7 @@ class Variable(object): @property def name(self): - return cpt.to_literal_str(self.desc.name()) + return cpt.to_text(self.desc.name()) @name.setter def name(self, new_name): @@ -530,7 +530,7 @@ class Operator(object): elif isinstance(arg, six.binary_type): in_arg_names.append(arg.decode()) else: - in_arg_names.append(cpt.to_literal_str(arg.name)) + in_arg_names.append(cpt.to_text(arg.name)) self.desc.set_input(in_proto.name, in_arg_names) else: self.desc.set_input(in_proto.name, []) @@ -559,7 +559,7 @@ class Operator(object): (out_proto.name, len(out_args))) out_arg_names = [] for arg in out_args: - out_arg_names.append(cpt.to_literal_str(arg.name)) + out_arg_names.append(cpt.to_text(arg.name)) arg.op = self self.desc.set_output(out_proto.name, out_arg_names) @@ -986,8 +986,8 @@ class Block(object): Returns: Variable: the Variable with the giving name. 
""" - name = cpt.to_literal_str(name) - new_name = cpt.to_literal_str(new_name) + name = cpt.to_text(name) + new_name = cpt.to_text(new_name) if not self.has_var(name): raise ValueError("var %s is not in current block" % name) diff --git a/python/paddle/fluid/parallel_executor.py b/python/paddle/fluid/parallel_executor.py index 2618872cc6..6654e33847 100644 --- a/python/paddle/fluid/parallel_executor.py +++ b/python/paddle/fluid/parallel_executor.py @@ -155,13 +155,13 @@ class ParallelExecutor(object): self.executor = core.ParallelExecutor( self._places, set([ - cpt.to_literal_str(p.name) + cpt.to_text(p.name) for p in main.global_block().iter_parameters() if not p.stop_gradient ]), - set(cpt.to_literal_str(var) + set(cpt.to_text(var) for var in self.persistable_vars), main.desc, - cpt.to_literal_str(loss_name) + cpt.to_text(loss_name) if loss_name else six.u(''), scope, local_scopes, exec_strategy, build_strategy, num_trainers, trainer_id) self.scope = scope @@ -275,7 +275,7 @@ class ParallelExecutor(object): fetch_var_name = '@FETCHED_VAR_NAME@' self.executor.run( - cpt.to_literal_str(fetch_list), cpt.to_literal_str(fetch_var_name)) + cpt.to_text(fetch_list), cpt.to_text(fetch_var_name)) arr = self.scope.find_var(fetch_var_name).get_lod_tensor_array() if self.is_dist: diff --git a/python/paddle/fluid/tests/unittests/test_compat.py b/python/paddle/fluid/tests/unittests/test_compat.py index 525789ddb6..7f45ee7052 100644 --- a/python/paddle/fluid/tests/unittests/test_compat.py +++ b/python/paddle/fluid/tests/unittests/test_compat.py @@ -26,44 +26,44 @@ class TestCompatible(unittest.TestCase): self.assertEqual(cpt.int_type, int) self.assertEqual(cpt.long_type, int) - def test_to_literal_str(self): + def test_to_text(self): # Only support python2.x and python3.x now self.assertTrue(six.PY2 | six.PY3) if six.PY2: # check None - self.assertIsNone(cpt.to_literal_str(None)) + self.assertIsNone(cpt.to_text(None)) # check all string related types - self.assertTrue(isinstance(cpt.to_literal_str(str("")), unicode)) - self.assertTrue(isinstance(cpt.to_literal_str(str("123")), unicode)) - self.assertTrue(isinstance(cpt.to_literal_str(b""), unicode)) - self.assertTrue(isinstance(cpt.to_literal_str(b""), unicode)) - self.assertTrue(isinstance(cpt.to_literal_str(u""), unicode)) - self.assertTrue(isinstance(cpt.to_literal_str(u""), unicode)) - - self.assertEqual(u"", cpt.to_literal_str(str(""))) - self.assertEqual(u"123", cpt.to_literal_str(str("123"))) - self.assertEqual(u"", cpt.to_literal_str(b"")) - self.assertEqual(u"123", cpt.to_literal_str(b"123")) - self.assertEqual(u"", cpt.to_literal_str(u"")) - self.assertEqual(u"123", cpt.to_literal_str(u"123")) + self.assertTrue(isinstance(cpt.to_text(str("")), unicode)) + self.assertTrue(isinstance(cpt.to_text(str("123")), unicode)) + self.assertTrue(isinstance(cpt.to_text(b""), unicode)) + self.assertTrue(isinstance(cpt.to_text(b""), unicode)) + self.assertTrue(isinstance(cpt.to_text(u""), unicode)) + self.assertTrue(isinstance(cpt.to_text(u""), unicode)) + + self.assertEqual(u"", cpt.to_text(str(""))) + self.assertEqual(u"123", cpt.to_text(str("123"))) + self.assertEqual(u"", cpt.to_text(b"")) + self.assertEqual(u"123", cpt.to_text(b"123")) + self.assertEqual(u"", cpt.to_text(u"")) + self.assertEqual(u"123", cpt.to_text(u"123")) # check list types, not inplace l = [""] - l2 = cpt.to_literal_str(l) + l2 = cpt.to_text(l) self.assertTrue(isinstance(l2, list)) self.assertFalse(l is l2) self.assertEqual(l, l2) self.assertEqual([u""], l2) l = ["", "123"] - l2 = 
cpt.to_literal_str(l) + l2 = cpt.to_text(l) self.assertTrue(isinstance(l2, list)) self.assertFalse(l is l2) self.assertEqual(l, l2) self.assertEqual([u"", u"123"], l2) l = ["", b'123', u"321"] - l2 = cpt.to_literal_str(l) + l2 = cpt.to_text(l) self.assertTrue(isinstance(l2, list)) self.assertFalse(l is l2) self.assertEqual(l, l2) @@ -73,19 +73,19 @@ class TestCompatible(unittest.TestCase): # check list types, inplace l = [""] - l2 = cpt.to_literal_str(l, inplace=True) + l2 = cpt.to_text(l, inplace=True) self.assertTrue(isinstance(l2, list)) self.assertTrue(l is l2) self.assertEqual(l, l2) self.assertEqual([u""], l2) l = ["", "123"] - l2 = cpt.to_literal_str(l, inplace=True) + l2 = cpt.to_text(l, inplace=True) self.assertTrue(isinstance(l2, list)) self.assertTrue(l is l2) self.assertEqual(l, l2) self.assertEqual([u"", u"123"], l2) l = ["", b"123", u"321"] - l2 = cpt.to_literal_str(l, inplace=True) + l2 = cpt.to_text(l, inplace=True) self.assertTrue(isinstance(l2, list)) self.assertTrue(l is l2) self.assertEqual(l, l2) @@ -93,19 +93,19 @@ class TestCompatible(unittest.TestCase): # check set types, not inplace l = set("") - l2 = cpt.to_literal_str(l, inplace=False) + l2 = cpt.to_text(l, inplace=False) self.assertTrue(isinstance(l2, set)) self.assertFalse(l is l2) self.assertEqual(l, l2) self.assertEqual(set(u""), l2) l = set([b"", b"123"]) - l2 = cpt.to_literal_str(l, inplace=False) + l2 = cpt.to_text(l, inplace=False) self.assertTrue(isinstance(l2, set)) self.assertFalse(l is l2) self.assertEqual(l, l2) self.assertEqual(set([u"", u"123"]), l2) l = set(["", b"123", u"321"]) - l2 = cpt.to_literal_str(l, inplace=False) + l2 = cpt.to_text(l, inplace=False) self.assertTrue(isinstance(l2, set)) self.assertFalse(l is l2) self.assertEqual(l, l2) @@ -115,56 +115,56 @@ class TestCompatible(unittest.TestCase): # check set types, inplace l = set("") - l2 = cpt.to_literal_str(l, inplace=True) + l2 = cpt.to_text(l, inplace=True) self.assertTrue(isinstance(l2, set)) self.assertTrue(l is l2) self.assertEqual(l, l2) self.assertEqual(set(u""), l2) l = set([b"", b"123"]) - l2 = cpt.to_literal_str(l, inplace=True) + l2 = cpt.to_text(l, inplace=True) self.assertTrue(isinstance(l2, set)) self.assertTrue(l is l2) self.assertEqual(l, l2) self.assertEqual(set([u"", u"123"]), l2) l = set(["", b"123", u"321"]) - l2 = cpt.to_literal_str(l, inplace=True) + l2 = cpt.to_text(l, inplace=True) self.assertTrue(isinstance(l2, set)) self.assertTrue(l is l2) self.assertEqual(l, l2) self.assertEqual(set([u"", u"123", u"321"]), l2) elif six.PY3: - self.assertIsNone(cpt.to_literal_str(None)) - - self.assertTrue(isinstance(cpt.to_literal_str(str("")), str)) - self.assertTrue(isinstance(cpt.to_literal_str(str("123")), str)) - self.assertTrue(isinstance(cpt.to_literal_str(b""), str)) - self.assertTrue(isinstance(cpt.to_literal_str(b""), str)) - self.assertTrue(isinstance(cpt.to_literal_str(u""), str)) - self.assertTrue(isinstance(cpt.to_literal_str(u""), str)) - - self.assertEqual("", cpt.to_literal_str(str(""))) - self.assertEqual("123", cpt.to_literal_str(str("123"))) - self.assertEqual("", cpt.to_literal_str(b"")) - self.assertEqual("123", cpt.to_literal_str(b"123")) - self.assertEqual("", cpt.to_literal_str(u"")) - self.assertEqual("123", cpt.to_literal_str(u"123")) + self.assertIsNone(cpt.to_text(None)) + + self.assertTrue(isinstance(cpt.to_text(str("")), str)) + self.assertTrue(isinstance(cpt.to_text(str("123")), str)) + self.assertTrue(isinstance(cpt.to_text(b""), str)) + self.assertTrue(isinstance(cpt.to_text(b""), str)) + 
self.assertTrue(isinstance(cpt.to_text(u""), str)) + self.assertTrue(isinstance(cpt.to_text(u""), str)) + + self.assertEqual("", cpt.to_text(str(""))) + self.assertEqual("123", cpt.to_text(str("123"))) + self.assertEqual("", cpt.to_text(b"")) + self.assertEqual("123", cpt.to_text(b"123")) + self.assertEqual("", cpt.to_text(u"")) + self.assertEqual("123", cpt.to_text(u"123")) # check list types, not inplace l = [""] - l2 = cpt.to_literal_str(l) + l2 = cpt.to_text(l) self.assertTrue(isinstance(l2, list)) self.assertFalse(l is l2) self.assertEqual(l, l2) self.assertEqual([""], l2) l = ["", "123"] - l2 = cpt.to_literal_str(l) + l2 = cpt.to_text(l) self.assertTrue(isinstance(l2, list)) self.assertFalse(l is l2) self.assertEqual(l, l2) self.assertEqual(["", "123"], l2) l = ["", b"123", u"321"] - l2 = cpt.to_literal_str(l) + l2 = cpt.to_text(l) self.assertTrue(isinstance(l2, list)) self.assertFalse(l is l2) self.assertNotEqual(l, l2) @@ -172,19 +172,19 @@ class TestCompatible(unittest.TestCase): # check list types, inplace l = [""] - l2 = cpt.to_literal_str(l, inplace=True) + l2 = cpt.to_text(l, inplace=True) self.assertTrue(isinstance(l2, list)) self.assertTrue(l is l2) self.assertEqual(l, l2) self.assertEqual([""], l2) l = ["", b"123"] - l2 = cpt.to_literal_str(l, inplace=True) + l2 = cpt.to_text(l, inplace=True) self.assertTrue(isinstance(l2, list)) self.assertTrue(l is l2) self.assertEqual(l, l2) self.assertEqual(["", "123"], l2) l = ["", b"123", u"321"] - l2 = cpt.to_literal_str(l, inplace=True) + l2 = cpt.to_text(l, inplace=True) self.assertTrue(isinstance(l2, list)) self.assertTrue(l is l2) self.assertEqual(l, l2) @@ -194,19 +194,19 @@ class TestCompatible(unittest.TestCase): # check set types, not inplace l = set("") - l2 = cpt.to_literal_str(l, inplace=False) + l2 = cpt.to_text(l, inplace=False) self.assertTrue(isinstance(l2, set)) self.assertFalse(l is l2) self.assertEqual(l, l2) self.assertEqual(set(""), l2) l = set([b"", b"123"]) - l2 = cpt.to_literal_str(l, inplace=False) + l2 = cpt.to_text(l, inplace=False) self.assertTrue(isinstance(l2, set)) self.assertFalse(l is l2) self.assertNotEqual(l, l2) self.assertEqual(set(["", "123"]), l2) l = set(["", b"123", u"321"]) - l2 = cpt.to_literal_str(l, inplace=False) + l2 = cpt.to_text(l, inplace=False) self.assertTrue(isinstance(l2, set)) self.assertFalse(l is l2) self.assertNotEqual(l, l2) @@ -214,19 +214,19 @@ class TestCompatible(unittest.TestCase): # check set types, inplace l = set("") - l2 = cpt.to_literal_str(l, inplace=True) + l2 = cpt.to_text(l, inplace=True) self.assertTrue(isinstance(l2, set)) self.assertTrue(l is l2) self.assertEqual(l, l2) self.assertEqual(set(""), l2) l = set([b"", b"123"]) - l2 = cpt.to_literal_str(l, inplace=True) + l2 = cpt.to_text(l, inplace=True) self.assertTrue(isinstance(l2, set)) self.assertTrue(l is l2) self.assertEqual(l, l2) self.assertEqual(set(["", "123"]), l2) l = set(["", b"123", u"321"]) - l2 = cpt.to_literal_str(l, inplace=True) + l2 = cpt.to_text(l, inplace=True) self.assertTrue(isinstance(l2, set)) self.assertTrue(l is l2) self.assertEqual(l, l2) diff --git a/python/paddle/fluid/tests/unittests/test_dist_base.py b/python/paddle/fluid/tests/unittests/test_dist_base.py index 78ed29be56..e0b545183c 100644 --- a/python/paddle/fluid/tests/unittests/test_dist_base.py +++ b/python/paddle/fluid/tests/unittests/test_dist_base.py @@ -186,7 +186,7 @@ class TestDistBase(unittest.TestCase): env=env_local) local_proc.wait() out, err = local_proc.communicate() - local_ret = cpt.to_literal_str(out) + 
local_ret = cpt.to_text(out) sys.stderr.write('local_loss: %s\n' % local_ret) sys.stderr.write('local_stderr: %s\n' % err) @@ -224,7 +224,7 @@ class TestDistBase(unittest.TestCase): tr1_proc.wait() out, err = tr0_proc.communicate() sys.stderr.write('dist_stderr: %s\n' % err) - loss_data0 = cpt.to_literal_str(out) + loss_data0 = cpt.to_text(out) sys.stderr.write('dist_loss: %s\n' % loss_data0) lines = loss_data0.split("\n") dist_first_loss = eval(lines[0].replace(" ", ","))[0] diff --git a/python/paddle/fluid/transpiler/memory_optimization_transpiler.py b/python/paddle/fluid/transpiler/memory_optimization_transpiler.py index 57c72f465b..06cb16db6f 100644 --- a/python/paddle/fluid/transpiler/memory_optimization_transpiler.py +++ b/python/paddle/fluid/transpiler/memory_optimization_transpiler.py @@ -260,7 +260,7 @@ class ControlFlowGraph(object): # memory allocated in order to reuse the memory. _rename_arg_(self._ops, x, cache_var, begin_idx=i) self._program.block(block_desc.id).var( - cpt.to_literal_str(x)).desc = self._find_var( + cpt.to_text(x)).desc = self._find_var( block_desc, cache_var, is_forward) self._update_graph(x, cache_var, begin_idx=i) break diff --git a/python/paddle/reader/decorator.py b/python/paddle/reader/decorator.py index d53694959b..6ec36b2bb9 100644 --- a/python/paddle/reader/decorator.py +++ b/python/paddle/reader/decorator.py @@ -391,9 +391,9 @@ class PipeReader: buff = self.process.stdout.read(self.bufsize) if buff: if self.file_type == "gzip": - decomp_buff = cpt.to_literal_str(self.dec.decompress(buff)) + decomp_buff = cpt.to_text(self.dec.decompress(buff)) elif self.file_type == "plain": - decomp_buff = cpt.to_literal_str(buff) + decomp_buff = cpt.to_text(buff) else: raise TypeError("file_type %s is not allowed" % self.file_type) From 5338417b4754d1ed9dd46d5c1e951aeea46d4888 Mon Sep 17 00:00:00 2001 From: minqiyang Date: Tue, 14 Aug 2018 17:12:24 +0800 Subject: [PATCH 57/94] Polish code style --- paddle/fluid/framework/op_desc.cc | 6 +++--- paddle/fluid/framework/op_desc.h | 2 +- python/paddle/fluid/backward.py | 3 +-- python/paddle/fluid/parallel_executor.py | 6 ++---- .../fluid/transpiler/memory_optimization_transpiler.py | 6 +++--- 5 files changed, 10 insertions(+), 13 deletions(-) diff --git a/paddle/fluid/framework/op_desc.cc b/paddle/fluid/framework/op_desc.cc index b0a07ccad6..9399b8675e 100644 --- a/paddle/fluid/framework/op_desc.cc +++ b/paddle/fluid/framework/op_desc.cc @@ -209,7 +209,7 @@ void OpDesc::SetAttr(const std::string &name, const Attribute &v) { if (attr_type == proto::AttrType::INTS && boost::get<std::vector<int>>(v).size() == 0u) { // Find current attr via attr name and set the correct attribute value - const proto::OpProto::Attr& attr = GetProtoAttr(name); + const proto::OpProto::Attr &attr = GetProtoAttr(name); switch (attr.type()) { case proto::AttrType::BOOLEANS: { VLOG(11) << "SetAttr: " << Type() << ", " << name @@ -275,8 +275,8 @@ Attribute OpDesc::GetAttr(const std::string &name) const { return it->second; } -const proto::OpProto::Attr& OpDesc::GetProtoAttr(const std::string &name) { - proto::OpProto& proto = OpInfoMap::Instance().Get(Type()).Proto(); +const proto::OpProto::Attr &OpDesc::GetProtoAttr(const std::string &name) { + proto::OpProto &proto = OpInfoMap::Instance().Get(Type()).Proto(); for (int i = 0; i != proto.attrs_size(); ++i) { const proto::OpProto::Attr &attr = proto.attrs(i); if (attr.name() == name) { diff --git a/paddle/fluid/framework/op_desc.h b/paddle/fluid/framework/op_desc.h index 0421f36a35..6805d25934 100644 ---
a/paddle/fluid/framework/op_desc.h +++ b/paddle/fluid/framework/op_desc.h @@ -81,7 +81,7 @@ class OpDesc { Attribute GetAttr(const std::string &name) const; - const proto::OpProto::Attr& GetProtoAttr(const std::string &name) const; + const proto::OpProto::Attr &GetProtoAttr(const std::string &name) const; Attribute GetNullableAttr(const std::string &name) const; diff --git a/python/paddle/fluid/backward.py b/python/paddle/fluid/backward.py index 1c10d06c51..e552b79219 100644 --- a/python/paddle/fluid/backward.py +++ b/python/paddle/fluid/backward.py @@ -364,8 +364,7 @@ def _append_backward_ops_(block, # Getting op's corresponding grad_op grad_op_desc, op_grad_to_var = core.get_grad_op_desc( - op.desc, - cpt.to_text(no_grad_dict[block.idx]), grad_sub_block_list) + op.desc, cpt.to_text(no_grad_dict[block.idx]), grad_sub_block_list) grad_op_descs.extend(grad_op_desc) grad_to_var.update(op_grad_to_var) diff --git a/python/paddle/fluid/parallel_executor.py b/python/paddle/fluid/parallel_executor.py index 6654e33847..35c3ab59c2 100644 --- a/python/paddle/fluid/parallel_executor.py +++ b/python/paddle/fluid/parallel_executor.py @@ -159,8 +159,7 @@ class ParallelExecutor(object): for p in main.global_block().iter_parameters() if not p.stop_gradient ]), - set(cpt.to_text(var) - for var in self.persistable_vars), main.desc, + set(cpt.to_text(var) for var in self.persistable_vars), main.desc, cpt.to_text(loss_name) if loss_name else six.u(''), scope, local_scopes, exec_strategy, build_strategy, num_trainers, trainer_id) @@ -274,8 +273,7 @@ class ParallelExecutor(object): self.executor.feed_tensors_into_local_scopes(res) fetch_var_name = '@FETCHED_VAR_NAME@' - self.executor.run( - cpt.to_text(fetch_list), cpt.to_text(fetch_var_name)) + self.executor.run(cpt.to_text(fetch_list), cpt.to_text(fetch_var_name)) arr = self.scope.find_var(fetch_var_name).get_lod_tensor_array() if self.is_dist: diff --git a/python/paddle/fluid/transpiler/memory_optimization_transpiler.py b/python/paddle/fluid/transpiler/memory_optimization_transpiler.py index 06cb16db6f..293c7841ec 100644 --- a/python/paddle/fluid/transpiler/memory_optimization_transpiler.py +++ b/python/paddle/fluid/transpiler/memory_optimization_transpiler.py @@ -259,9 +259,9 @@ class ControlFlowGraph(object): # Rename the var to the cache var already with # memory allocated in order to reuse the memory. 
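Aside on the helper these Python hunks lean on: block variables are registered under text (unicode) names, while values handed back from the C++ core are often bytes, so lookups are funneled through cpt.to_text. A minimal sketch of the intended semantics, for illustration only; the real paddle.compat.to_text also supports an inplace flag:

    def to_text_sketch(obj, encoding='utf-8'):
        # bytes -> text; text passes through; lists/sets convert element-wise.
        if obj is None:
            return None
        if isinstance(obj, bytes):
            return obj.decode(encoding)
        if isinstance(obj, (list, set)):
            return type(obj)(to_text_sketch(o, encoding) for o in obj)
        return obj

    assert to_text_sketch(b"123") == u"123"
    assert to_text_sketch(["", b"123", u"321"]) == ["", u"123", u"321"]

This matches the expectations exercised by the test_compat cases earlier in the series.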
_rename_arg_(self._ops, x, cache_var, begin_idx=i) - self._program.block(block_desc.id).var( - cpt.to_text(x)).desc = self._find_var( - block_desc, cache_var, is_forward) + self._program.block(block_desc.id).var(cpt.to_text( + x)).desc = self._find_var(block_desc, cache_var, + is_forward) self._update_graph(x, cache_var, begin_idx=i) break From 0707abb51b0bdbd4cddd3c0c62ce5288515217b1 Mon Sep 17 00:00:00 2001 From: tangwei12 Date: Tue, 14 Aug 2018 17:16:46 +0800 Subject: [PATCH 58/94] lookup table fix --- python/paddle/fluid/transpiler/distribute_transpiler.py | 8 +++++++- 1 file changed, 7 insertions(+), 1 deletion(-) diff --git a/python/paddle/fluid/transpiler/distribute_transpiler.py b/python/paddle/fluid/transpiler/distribute_transpiler.py index 1bb86acdf8..0328c172cd 100644 --- a/python/paddle/fluid/transpiler/distribute_transpiler.py +++ b/python/paddle/fluid/transpiler/distribute_transpiler.py @@ -877,9 +877,15 @@ class DistributeTranspiler(object): # create table param and grad var in pserver program origin_param_var = self.origin_program.global_block().vars[ self.table_name] + + zero_dim = long( + math.ceil(origin_param_var.shape[0] / len(self.pserver_endpoints))) + table_shape = list(origin_param_var.shape) + table_shape[0] = zero_dim + param_var = pserver_program.global_block().create_var( name=origin_param_var.name, - shape=origin_param_var.shape, + shape=table_shape, dtype=origin_param_var.dtype, type=core.VarDesc.VarType.SELECTED_ROWS, persistable=True) From 7e0f66e99ae47401c2339e7f57f9b4dd499b4501 Mon Sep 17 00:00:00 2001 From: minqiyang Date: Tue, 14 Aug 2018 17:36:46 +0800 Subject: [PATCH 59/94] Polish code --- paddle/fluid/framework/op_desc.cc | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/paddle/fluid/framework/op_desc.cc b/paddle/fluid/framework/op_desc.cc index 9399b8675e..af26cf2872 100644 --- a/paddle/fluid/framework/op_desc.cc +++ b/paddle/fluid/framework/op_desc.cc @@ -275,8 +275,9 @@ Attribute OpDesc::GetAttr(const std::string &name) const { return it->second; } -const proto::OpProto::Attr &OpDesc::GetProtoAttr(const std::string &name) { - proto::OpProto &proto = OpInfoMap::Instance().Get(Type()).Proto(); +const proto::OpProto::Attr &OpDesc::GetProtoAttr( + const std::string &name) const { + const proto::OpProto &proto = OpInfoMap::Instance().Get(Type()).Proto(); for (int i = 0; i != proto.attrs_size(); ++i) { const proto::OpProto::Attr &attr = proto.attrs(i); if (attr.name() == name) { From da39d84a48d1445d6bb9fb10e8d7d17d9053c7b7 Mon Sep 17 00:00:00 2001 From: chenweihang Date: Tue, 14 Aug 2018 09:55:47 +0000 Subject: [PATCH 60/94] refine by reviewer's advice --- paddle/fluid/platform/enforce.h | 4 ++-- paddle/fluid/platform/enforce_test.cc | 14 +++++++------- paddle/fluid/platform/gpu_info.cc | 21 +++++++++++---------- 3 files changed, 20 insertions(+), 19 deletions(-) diff --git a/paddle/fluid/platform/enforce.h b/paddle/fluid/platform/enforce.h index cad60275a2..81b5359b40 100644 --- a/paddle/fluid/platform/enforce.h +++ b/paddle/fluid/platform/enforce.h @@ -263,7 +263,7 @@ inline void throw_on_error(T e) { * PADDLE_ENFORCE_EQ(a, b); * * will raise an expression described as follows: - * "Data check failed. Expected input a == b, but received a(1) != b(2)." + * "Enforce failed. Expected input a == b, but received a(1) != b(2)." * with detailed stack information. * * extra messages is also supported, for example: @@ -293,7 +293,7 @@ inline void throw_on_error(T e) { #define __PADDLE_BINARY_COMPARE(__VAL0, __VAL1, __CMP, __INV_CMP, ...) 
\ do { \ if (UNLIKELY(!((__VAL0)__CMP(__VAL1)))) { \ - PADDLE_THROW("Data check failed. Expected %s " #__CMP \ + PADDLE_THROW("Enforce failed. Expected %s " #__CMP \ " %s, but received %s:%s " #__INV_CMP " %s:%s.\n%s", \ #__VAL0, #__VAL1, #__VAL0, \ paddle::string::to_string(__VAL0), #__VAL1, \ diff --git a/paddle/fluid/platform/enforce_test.cc b/paddle/fluid/platform/enforce_test.cc index 8dcf39fdaa..d521829655 100644 --- a/paddle/fluid/platform/enforce_test.cc +++ b/paddle/fluid/platform/enforce_test.cc @@ -56,7 +56,7 @@ TEST(ENFORCE_EQ, NO_EXTRA_MSG_FAIL) { caught_exception = true; HasPrefix( StringPiece(error.what()), - "Data check failed. Expected a == 1 + 3, but received a:2 != 1 + 3:4."); + "Enforce failed. Expected a == 1 + 3, but received a:2 != 1 + 3:4."); } EXPECT_TRUE(caught_exception); } @@ -69,7 +69,7 @@ TEST(ENFORCE_EQ, EXTRA_MSG_FAIL) { } catch (paddle::platform::EnforceNotMet error) { caught_exception = true; HasPrefix(StringPiece(error.what()), - "Data check failed. Expected a == 1 + 3, but received a:2 != 1 + " + "Enforce failed. Expected a == 1 + 3, but received a:2 != 1 + " "3:4.\ntheir size not match"); } EXPECT_TRUE(caught_exception); @@ -89,7 +89,7 @@ TEST(ENFORCE_NE, FAIL) { caught_exception = true; EXPECT_TRUE(HasPrefix( StringPiece(error.what()), - "Data check failed. Expected 1.0 != 1UL, but received 1.0:1 == 1UL:1.")) + "Enforce failed. Expected 1.0 != 1UL, but received 1.0:1 == 1UL:1.")) << error.what() << " does not have expected prefix"; } EXPECT_TRUE(caught_exception); @@ -104,7 +104,7 @@ TEST(ENFORCE_GT, FAIL) { caught_exception = true; EXPECT_TRUE(HasPrefix( StringPiece(error.what()), - "Data check failed. Expected 1 > 2UL, but received 1:1 <= 2UL:2.")); + "Enforce failed. Expected 1 > 2UL, but received 1:1 <= 2UL:2.")); } EXPECT_TRUE(caught_exception); } @@ -123,7 +123,7 @@ TEST(ENFORCE_GE, FAIL) { caught_exception = true; EXPECT_TRUE(HasPrefix( StringPiece(error.what()), - "Data check failed. Expected 1 >= 2UL, but received 1:1 < 2UL:2.")); + "Enforce failed. Expected 1 >= 2UL, but received 1:1 < 2UL:2.")); } EXPECT_TRUE(caught_exception); } @@ -143,7 +143,7 @@ TEST(ENFORCE_LE, FAIL) { caught_exception = true; EXPECT_TRUE(HasPrefix( StringPiece(error.what()), - "Data check failed. Expected 1 > 2UL, but received 1:1 <= 2UL:2.")); + "Enforce failed. Expected 1 > 2UL, but received 1:1 <= 2UL:2.")); } EXPECT_TRUE(caught_exception); } @@ -160,7 +160,7 @@ TEST(ENFORCE_LT, FAIL) { } catch (paddle::platform::EnforceNotMet error) { caught_exception = true; EXPECT_TRUE(HasPrefix(StringPiece(error.what()), - "Data check failed. Expected 1UL < 0.12, but " + "Enforce failed. Expected 1UL < 0.12, but " "received 1UL:1 >= 0.12:0.12.")); } EXPECT_TRUE(caught_exception); diff --git a/paddle/fluid/platform/gpu_info.cc b/paddle/fluid/platform/gpu_info.cc index f9e2e8c69d..126636d879 100644 --- a/paddle/fluid/platform/gpu_info.cc +++ b/paddle/fluid/platform/gpu_info.cc @@ -100,25 +100,26 @@ size_t GpuMinChunkSize() { size_t GpuMaxChunkSize() { size_t total = 0; - size_t available_memory = 0; + size_t available = 0; - GpuMemoryUsage(&available_memory, &total); - VLOG(10) << "GPU Usage " << available_memory / 1024 / 1024 << "M/" + GpuMemoryUsage(&available, &total); + VLOG(10) << "GPU Usage " << available / 1024 / 1024 << "M/" << total / 1024 / 1024 << "M"; size_t reserving = static_cast(0.05 * total); // If available less than minimum chunk size, no usable memory exists. 
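Worked example of the chunk-size bounds being edited here (all device numbers are invented; the 0.05 reserve ratio and the fraction flag mirror the C++ above):

    min_chunk = 1 << 21                 # stand-in for GpuMinChunkSize()
    fraction = 0.92                     # FLAGS_fraction_of_gpu_memory_to_use
    total = 16 * 1024**3                # pretend a 16 GiB card
    available = 12 * 1024**3            # pretend 12 GiB is currently free

    reserving = int(0.05 * total)       # held back for page tables etc.
    available = min(max(available, min_chunk) - min_chunk, total - reserving)
    allocating = int(fraction * (total - reserving))
    print(allocating <= available)      # False with these numbers

With these numbers the requested chunk exceeds what is actually free, which is exactly the situation the "Insufficient GPU memory for allocation." enforce below reports.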
- available_memory = std::min( - std::max(available_memory, GpuMinChunkSize()) - GpuMinChunkSize(), - total - reserving); + available = + std::min(std::max(available, GpuMinChunkSize()) - GpuMinChunkSize(), + total - reserving); // Reserving the rest memory for page tables, etc. - size_t allocating_memory = static_cast<size_t>( - FLAGS_fraction_of_gpu_memory_to_use * (total - reserving)); + size_t allocating = static_cast<size_t>(FLAGS_fraction_of_gpu_memory_to_use * + (total - reserving)); - PADDLE_ENFORCE_LE(allocating_memory, available_memory); + PADDLE_ENFORCE_LE(allocating, available, + "Insufficient GPU memory for allocation."); - return allocating_memory; + return allocating; } void GpuMemcpyAsync(void *dst, const void *src, size_t count, From e0d5f8a8207b5366582888a7d4bd2086f39c02ce Mon Sep 17 00:00:00 2001 From: minqiyang Date: Tue, 14 Aug 2018 18:16:59 +0800 Subject: [PATCH 61/94] Move compat module to python/paddle --- paddle/fluid/pybind/protobuf.cc | 3 +-- python/paddle/__init__.py | 1 + python/paddle/{fluid => }/compat.py | 0 python/paddle/dataset/conll05.py | 2 +- python/paddle/dataset/movielens.py | 2 +- python/paddle/dataset/wmt14.py | 2 +- python/paddle/dataset/wmt16.py | 2 +- python/paddle/fluid/backward.py | 2 +- python/paddle/fluid/framework.py | 2 +- python/paddle/fluid/layers/detection.py | 2 +- python/paddle/fluid/parallel_executor.py | 2 +- python/paddle/fluid/tests/unittests/test_compat.py | 2 +- python/paddle/fluid/tests/unittests/test_dist_base.py | 2 +- python/paddle/fluid/tests/unittests/test_exception.py | 2 +- python/paddle/fluid/tests/unittests/test_lookup_table_op.py | 2 +- python/paddle/fluid/tests/unittests/test_operator_desc.py | 2 +- python/paddle/fluid/tests/unittests/test_protobuf_descs.py | 2 +- python/paddle/fluid/tests/unittests/test_roi_pool_op.py | 2 +- .../paddle/fluid/transpiler/memory_optimization_transpiler.py | 2 +- python/paddle/reader/decorator.py | 2 +- 20 files changed, 19 insertions(+), 19 deletions(-) rename python/paddle/{fluid => }/compat.py (100%) diff --git a/paddle/fluid/pybind/protobuf.cc b/paddle/fluid/pybind/protobuf.cc index 2372db9715..7f06f7a9d7 100644 --- a/paddle/fluid/pybind/protobuf.cc +++ b/paddle/fluid/pybind/protobuf.cc @@ -205,8 +205,7 @@ void BindBlockDesc(pybind11::module *m) { void BindVarDsec(pybind11::module *m) { pybind11::class_<pd::VarDesc> var_desc(*m, "VarDesc", ""); var_desc - .def("name", [](pd::VarDesc &self) { return self.Name(); }, - pybind11::return_value_policy::reference) + .def("name", &pd::VarDesc::Name, pybind11::return_value_policy::reference) .def("set_name", &pd::VarDesc::SetName) .def("set_shape", &pd::VarDesc::SetShape) .def("set_shapes", &pd::VarDesc::SetShapes) diff --git a/python/paddle/__init__.py b/python/paddle/__init__.py index 241a07a352..53746afdb2 100644 --- a/python/paddle/__init__.py +++ b/python/paddle/__init__.py @@ -24,4 +24,5 @@ except ImportError: import paddle.reader import paddle.dataset import paddle.batch +import paddle.compat batch = batch.batch diff --git a/python/paddle/fluid/compat.py b/python/paddle/compat.py similarity index 100% rename from python/paddle/fluid/compat.py rename to python/paddle/compat.py diff --git a/python/paddle/dataset/conll05.py b/python/paddle/dataset/conll05.py index c716b5af67..b23d127eeb 100644 --- a/python/paddle/dataset/conll05.py +++ b/python/paddle/dataset/conll05.py @@ -24,7 +24,7 @@ import tarfile import gzip import itertools import paddle.dataset.common -import paddle.fluid.compat as cpt +import paddle.compat as cpt from six.moves import zip, range __all__ =
['test', 'get_dict', 'get_embedding', 'convert'] diff --git a/python/paddle/dataset/movielens.py b/python/paddle/dataset/movielens.py index 15bbd00c0b..fe07daf5c3 100644 --- a/python/paddle/dataset/movielens.py +++ b/python/paddle/dataset/movielens.py @@ -28,7 +28,7 @@ import re import random import functools import six -import paddle.fluid.compat as cpt +import paddle.compat as cpt __all__ = [ 'train', 'test', 'get_movie_title_dict', 'max_movie_id', 'max_user_id', diff --git a/python/paddle/dataset/wmt14.py b/python/paddle/dataset/wmt14.py index d35e706131..cf366309c0 100644 --- a/python/paddle/dataset/wmt14.py +++ b/python/paddle/dataset/wmt14.py @@ -24,7 +24,7 @@ import tarfile import gzip import paddle.dataset.common -import paddle.fluid.compat as cpt +import paddle.compat as cpt __all__ = [ 'train', diff --git a/python/paddle/dataset/wmt16.py b/python/paddle/dataset/wmt16.py index f0051c7736..d68a6e8be7 100644 --- a/python/paddle/dataset/wmt16.py +++ b/python/paddle/dataset/wmt16.py @@ -35,7 +35,7 @@ import gzip from collections import defaultdict import paddle.dataset.common -import paddle.fluid.compat as cpt +import paddle.compat as cpt __all__ = [ "train", diff --git a/python/paddle/fluid/backward.py b/python/paddle/fluid/backward.py index e552b79219..aa62d34491 100644 --- a/python/paddle/fluid/backward.py +++ b/python/paddle/fluid/backward.py @@ -17,7 +17,7 @@ from . import core import collections import copy import six -from . import compat as cpt +from .. import compat as cpt from . import unique_name __all__ = ['append_backward'] diff --git a/python/paddle/fluid/framework.py b/python/paddle/fluid/framework.py index 5203aa160c..fbeb0e5940 100644 --- a/python/paddle/fluid/framework.py +++ b/python/paddle/fluid/framework.py @@ -19,7 +19,7 @@ import six import numpy as np -from . import compat as cpt +from .. import compat as cpt from .proto import framework_pb2 try: from . import core diff --git a/python/paddle/fluid/layers/detection.py b/python/paddle/fluid/layers/detection.py index 4f317a4030..9f8cb6ef0b 100644 --- a/python/paddle/fluid/layers/detection.py +++ b/python/paddle/fluid/layers/detection.py @@ -20,7 +20,7 @@ from .layer_function_generator import autodoc, templatedoc from ..layer_helper import LayerHelper from . import tensor from . import nn -from .. import compat as cpt +from ... import compat as cpt import math import six from functools import reduce diff --git a/python/paddle/fluid/parallel_executor.py b/python/paddle/fluid/parallel_executor.py index 35c3ab59c2..ac87b12a1c 100644 --- a/python/paddle/fluid/parallel_executor.py +++ b/python/paddle/fluid/parallel_executor.py @@ -17,7 +17,7 @@ import multiprocessing from . import core from . import framework from . import executor -from . import compat as cpt +from .. import compat as cpt import warnings import sys import six diff --git a/python/paddle/fluid/tests/unittests/test_compat.py b/python/paddle/fluid/tests/unittests/test_compat.py index 7f45ee7052..eabcced5d1 100644 --- a/python/paddle/fluid/tests/unittests/test_compat.py +++ b/python/paddle/fluid/tests/unittests/test_compat.py @@ -13,7 +13,7 @@ # limitations under the License.
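Because this patch relocates the module, code outside the repository that must run against both pre-move and post-move wheels typically guards the import. An illustrative shim, not part of the patch:

    try:
        import paddle.compat as cpt        # layout after this series
    except ImportError:
        import paddle.fluid.compat as cpt  # older layout

    print(cpt.to_text(b"hello"))           # u'hello' on Python 2 and 3

Inside the repository itself no shim is needed; the diffs below simply rewrite every import site.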
import unittest -import paddle.fluid.compat as cpt +import paddle.compat as cpt import six diff --git a/python/paddle/fluid/tests/unittests/test_dist_base.py b/python/paddle/fluid/tests/unittests/test_dist_base.py index e0b545183c..e059f2cd2a 100644 --- a/python/paddle/fluid/tests/unittests/test_dist_base.py +++ b/python/paddle/fluid/tests/unittests/test_dist_base.py @@ -124,7 +124,7 @@ def runtime_main(test_class): model.run_trainer(p, endpoints, trainer_id, trainers, is_dist) -import paddle.fluid.compat as cpt +import paddle.compat as cpt class TestDistBase(unittest.TestCase): diff --git a/python/paddle/fluid/tests/unittests/test_exception.py b/python/paddle/fluid/tests/unittests/test_exception.py index 6e4ea273a9..a43df91342 100644 --- a/python/paddle/fluid/tests/unittests/test_exception.py +++ b/python/paddle/fluid/tests/unittests/test_exception.py @@ -12,7 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. -import paddle.fluid.compat as cpt +import paddle.compat as cpt import paddle.fluid.core as core import unittest diff --git a/python/paddle/fluid/tests/unittests/test_lookup_table_op.py b/python/paddle/fluid/tests/unittests/test_lookup_table_op.py index a325422c31..77fb8154f0 100644 --- a/python/paddle/fluid/tests/unittests/test_lookup_table_op.py +++ b/python/paddle/fluid/tests/unittests/test_lookup_table_op.py @@ -17,7 +17,7 @@ import numpy as np from op_test import OpTest import paddle.fluid.core as core from paddle.fluid.op import Operator -import paddle.fluid.compat as cpt +import paddle.compat as cpt class TestLookupTableOp(OpTest): diff --git a/python/paddle/fluid/tests/unittests/test_operator_desc.py b/python/paddle/fluid/tests/unittests/test_operator_desc.py index 5634e29d01..21a113f509 100644 --- a/python/paddle/fluid/tests/unittests/test_operator_desc.py +++ b/python/paddle/fluid/tests/unittests/test_operator_desc.py @@ -15,7 +15,7 @@ import unittest import paddle.fluid.core as core -import paddle.fluid.compat as cpt +import paddle.compat as cpt from paddle.fluid.framework import Program, default_startup_program diff --git a/python/paddle/fluid/tests/unittests/test_protobuf_descs.py b/python/paddle/fluid/tests/unittests/test_protobuf_descs.py index 2176db71b9..37de792114 100644 --- a/python/paddle/fluid/tests/unittests/test_protobuf_descs.py +++ b/python/paddle/fluid/tests/unittests/test_protobuf_descs.py @@ -14,7 +14,7 @@ import unittest import paddle.fluid.core as core -import paddle.fluid.compat as cpt +import paddle.compat as cpt from paddle.fluid.framework import Program diff --git a/python/paddle/fluid/tests/unittests/test_roi_pool_op.py b/python/paddle/fluid/tests/unittests/test_roi_pool_op.py index 0f38b742d9..9b0a3f26b7 100644 --- a/python/paddle/fluid/tests/unittests/test_roi_pool_op.py +++ b/python/paddle/fluid/tests/unittests/test_roi_pool_op.py @@ -16,7 +16,7 @@ import unittest import numpy as np import math import sys -import paddle.fluid.compat as cpt +import paddle.compat as cpt from op_test import OpTest diff --git a/python/paddle/fluid/transpiler/memory_optimization_transpiler.py b/python/paddle/fluid/transpiler/memory_optimization_transpiler.py index 293c7841ec..0de994dda3 100644 --- a/python/paddle/fluid/transpiler/memory_optimization_transpiler.py +++ b/python/paddle/fluid/transpiler/memory_optimization_transpiler.py @@ -14,7 +14,7 @@ from collections import defaultdict from .. import core -from .. import compat as cpt +from ... 
import compat as cpt from ..framework import Program, default_main_program, Parameter from ..backward import _rename_arg_ from functools import reduce diff --git a/python/paddle/reader/decorator.py b/python/paddle/reader/decorator.py index 6ec36b2bb9..6d7ac876fd 100644 --- a/python/paddle/reader/decorator.py +++ b/python/paddle/reader/decorator.py @@ -27,7 +27,7 @@ from six.moves import zip import itertools import random import zlib -import paddle.fluid.compat as cpt +import paddle.compat as cpt def map_readers(func, *readers): From f42247e55f9978054dd1057d46e89adc6f928135 Mon Sep 17 00:00:00 2001 From: tangwei12 Date: Tue, 14 Aug 2018 18:33:26 +0800 Subject: [PATCH 62/94] change to int to compatible with py3 --- python/paddle/fluid/transpiler/distribute_transpiler.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/python/paddle/fluid/transpiler/distribute_transpiler.py b/python/paddle/fluid/transpiler/distribute_transpiler.py index 0328c172cd..0faf9d0969 100644 --- a/python/paddle/fluid/transpiler/distribute_transpiler.py +++ b/python/paddle/fluid/transpiler/distribute_transpiler.py @@ -878,7 +878,7 @@ class DistributeTranspiler(object): origin_param_var = self.origin_program.global_block().vars[ self.table_name] - zero_dim = long( + zero_dim = int( math.ceil(origin_param_var.shape[0] / len(self.pserver_endpoints))) table_shape = list(origin_param_var.shape) table_shape[0] = zero_dim From c01caba1205c7ec3f0dd907a90e3530c0927a23c Mon Sep 17 00:00:00 2001 From: luotao1 Date: Tue, 14 Aug 2018 18:59:14 +0800 Subject: [PATCH 63/94] fix inference doc url error --- doc/fluid/new_docs/advanced_usage/deploy/native_infer.rst | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/doc/fluid/new_docs/advanced_usage/deploy/native_infer.rst b/doc/fluid/new_docs/advanced_usage/deploy/native_infer.rst index e1eee3f818..3571f81326 100644 --- a/doc/fluid/new_docs/advanced_usage/deploy/native_infer.rst +++ b/doc/fluid/new_docs/advanced_usage/deploy/native_infer.rst @@ -4,7 +4,7 @@ Paddle 预测 API 为了更简单方便的预测部署,Fluid 提供了一套高层 API 用来隐藏底层不同的优化实现。 -`预测库相关代码 `__ +`预测库相关代码 `__ 包括 - 头文件 ``paddle_inference_api.h`` 定义了所有的接口 @@ -104,5 +104,5 @@ engine ------------ - `inference - demos `__ -- `复杂单线程/多线程例子 `__ + demos `__ +- `复杂单线程/多线程例子 `__ From 7797e55f425d22a4fa812c9ac14eb906828828f2 Mon Sep 17 00:00:00 2001 From: chenweihang Date: Tue, 14 Aug 2018 11:28:07 +0000 Subject: [PATCH 64/94] use paddle::platform::demangle --- paddle/fluid/framework/CMakeLists.txt | 2 - paddle/fluid/framework/attribute.h | 9 +- paddle/fluid/framework/attribute_type.h | 97 ------------------- paddle/fluid/framework/attribute_type_test.cc | 46 --------- 4 files changed, 4 insertions(+), 150 deletions(-) delete mode 100644 paddle/fluid/framework/attribute_type.h delete mode 100644 paddle/fluid/framework/attribute_type_test.cc diff --git a/paddle/fluid/framework/CMakeLists.txt b/paddle/fluid/framework/CMakeLists.txt index b3fe2d97a8..6440607dbe 100644 --- a/paddle/fluid/framework/CMakeLists.txt +++ b/paddle/fluid/framework/CMakeLists.txt @@ -115,8 +115,6 @@ cc_test(cow_ptr_tests SRCS details/cow_ptr_test.cc) # cc_test(channel_test SRCS channel_test.cc) cc_test(tuple_test SRCS tuple_test.cc ) -cc_test(attribute_type_test SRCS attribute_type_test.cc) - # disable test temporarily. 
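Looking back at the long-to-int change in patch 62 a few hunks up: Python 3 removed long, and its int is already arbitrary-precision, so wrapping the ceiling in int() is the portable spelling. Sketch with assumed sizes (1000 table rows over 3 pservers):

    import math

    table_rows, num_pservers = 1000, 3
    shard_rows = int(math.ceil(table_rows / float(num_pservers)))
    print(shard_rows)  # 334; long(...) here would raise NameError on Python 3

The float() coercion keeps the division true division on Python 2 as well; a related note accompanies the unit test in patch 72 below.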
# TODO https://github.com/PaddlePaddle/Paddle/issues/11971 # cc_test(concurrency_test SRCS concurrency_test.cc DEPS go_op channel_close_op channel_create_op diff --git a/paddle/fluid/framework/attribute.h b/paddle/fluid/framework/attribute.h index 2b05528257..14ca3e9620 100644 --- a/paddle/fluid/framework/attribute.h +++ b/paddle/fluid/framework/attribute.h @@ -20,7 +20,6 @@ limitations under the License. */ #include #include -#include "paddle/fluid/framework/attribute_type.h" #include "paddle/fluid/framework/framework.pb.h" #include "paddle/fluid/framework/type_defs.h" #include "paddle/fluid/platform/enforce.h" @@ -129,8 +128,8 @@ struct ExtractAttribute { attr_value = &boost::get<T>(attr); } catch (boost::bad_get& bad_get) { PADDLE_THROW("Cannot get attribute %s by type %s, its type is %s", - attr_name_, paddle::framework::demangle(typeid(T).name()), - paddle::framework::demangle(attr.type().name())); + attr_name_, paddle::platform::demangle(typeid(T).name()), + paddle::platform::demangle(attr.type().name())); } return attr_value; } @@ -162,7 +161,7 @@ struct ExtractAttribute<bool> { attr_value = &boost::get<bool>(attr); } catch (boost::bad_get& bad_get) { PADDLE_THROW("Cannot get attribute %s by type bool, its type is %s", - attr_name_, paddle::framework::demangle(attr.type().name())); + attr_name_, paddle::platform::demangle(attr.type().name())); } return attr_value; } @@ -188,7 +187,7 @@ struct ExtractAttribute<int64_t> { attr_value = &boost::get<int64_t>(attr); } catch (boost::bad_get& bad_get) { PADDLE_THROW("Cannot get attribute %s by type int64_t, its type is %s", - attr_name_, paddle::framework::demangle(attr.type().name())); + attr_name_, paddle::platform::demangle(attr.type().name())); } return attr_value; } diff --git a/paddle/fluid/framework/attribute_type.h b/paddle/fluid/framework/attribute_type.h deleted file mode 100644 index 337dcde775..0000000000 --- a/paddle/fluid/framework/attribute_type.h +++ /dev/null @@ -1,97 +0,0 @@ -/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. */ - -#pragma once - -#include - -// __has_include is currently supported by GCC and Clang. However GCC 4.9 may -// have issues and -// returns 1 for 'defined( __has_include )', while '__has_include' is actually -// not supported: -#if defined(__has_include) && (!defined(BOOST_GCC) || (__GNUC__ + 0) >= 5) -#if __has_include(<cxxabi.h>) -#define PADDLE_FRAMEWORK_HAS_CXXABI_H -#endif -#elif defined(__GLIBCXX__) || defined(__GLIBCPP__) -#define PADDLE_FRAMEWORK_HAS_CXXABI_H -#endif - -#if defined(PADDLE_FRAMEWORK_HAS_CXXABI_H) -#include -// For some architectures (mips, mips64, x86, x86_64) cxxabi.h in Android NDK is -// implemented by gabi++ library -// which does not implement abi::__cxa_demangle(). We detect this implementation -// by checking the include guard here.
-#if defined(__GABIXX_CXXABI_H__) -#undef PADDLE_FRAMEWORK_HAS_CXXABI_H -#else -#include -#include -#endif -#endif - -namespace paddle { -namespace framework { - -inline char const* demangle_alloc(char const* name); -inline void demangle_free(char const* name); - -class scoped_demangled_name { - private: - char const* m_p; - - public: - explicit scoped_demangled_name(char const* name) - : m_p(demangle_alloc(name)) {} - - ~scoped_demangled_name() { demangle_free(m_p); } - - char const* get() const { return m_p; } - - scoped_demangled_name(scoped_demangled_name const&) = delete; - scoped_demangled_name& operator=(scoped_demangled_name const&) = delete; -}; - -#if defined(PADDLE_FRAMEWORK_HAS_CXXABI_H) - -inline char const* demangle_alloc(char const* name) { - int status = 0; - std::size_t size = 0; - return abi::__cxa_demangle(name, NULL, &size, &status); -} - -inline void demangle_free(char const* name) { - std::free(const_cast(name)); -} - -inline std::string demangle(char const* name) { - scoped_demangled_name demangled_name(name); - char const* p = demangled_name.get(); - if (!p) p = name; - return p; -} - -#else - -inline char const* demangle_alloc(char const* name) { return name; } - -inline void demangle_free(char const*) {} - -inline std::string demangle(char const* name) { return name; } - -#endif - -} // namespace framework -} // namespace paddle diff --git a/paddle/fluid/framework/attribute_type_test.cc b/paddle/fluid/framework/attribute_type_test.cc deleted file mode 100644 index 82537b8a0f..0000000000 --- a/paddle/fluid/framework/attribute_type_test.cc +++ /dev/null @@ -1,46 +0,0 @@ -/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. 
*/ - -#include -#include - -#include "gtest/gtest.h" -#include "paddle/fluid/framework/attribute_type.h" - -TEST(Attribute, TypeName) { - bool boolean; - int integer; - float ft; - std::string str; - std::vector booleans; - std::vector integers; - std::vector strings; - - EXPECT_EQ("bool", paddle::framework::demangle(typeid(boolean).name())); - EXPECT_EQ("int", paddle::framework::demangle(typeid(integer).name())); - EXPECT_EQ("float", paddle::framework::demangle(typeid(ft).name())); - EXPECT_EQ( - "std::__cxx11::basic_string, " - "std::allocator >", - paddle::framework::demangle(typeid(str).name())); - EXPECT_EQ("std::vector >", - paddle::framework::demangle(typeid(booleans).name())); - EXPECT_EQ("std::vector >", - paddle::framework::demangle(typeid(integers).name())); - EXPECT_EQ( - "std::vector, " - "std::allocator >, std::allocator, std::allocator > > >", - paddle::framework::demangle(typeid(strings).name())); -} From e38eca26e2d307ee4f4f0303970d1072ea5f56b9 Mon Sep 17 00:00:00 2001 From: minqiyang Date: Tue, 14 Aug 2018 19:33:04 +0800 Subject: [PATCH 65/94] Add libpng dependencies to yum Correct libnccl dir --- tools/manylinux1/Dockerfile.x64 | 2 +- tools/manylinux1/build_scripts/build.sh | 2 +- tools/manylinux1/build_scripts/install_nccl2.sh | 2 +- 3 files changed, 3 insertions(+), 3 deletions(-) diff --git a/tools/manylinux1/Dockerfile.x64 b/tools/manylinux1/Dockerfile.x64 index bca0b77ad7..34c54303bd 100644 --- a/tools/manylinux1/Dockerfile.x64 +++ b/tools/manylinux1/Dockerfile.x64 @@ -13,7 +13,7 @@ ENV PATH /opt/rh/devtoolset-2/root/usr/bin:$PATH ENV LD_LIBRARY_PATH /opt/rh/devtoolset-2/root/usr/lib64:/opt/rh/devtoolset-2/root/usr/lib:/usr/local/lib64:/usr/local/lib:${LD_LIBRARY_PATH} ENV PKG_CONFIG_PATH=/usr/local/lib/pkgconfig -RUN yum install -y sqlite-devel zlib-devel openssl-devel pcre-devel vim tk-devel tkinter libtool xz +RUN yum install -y sqlite-devel zlib-devel openssl-devel pcre-devel vim tk-devel tkinter libtool xz freetype-devel libpng-devel COPY build_scripts /build_scripts RUN bash build_scripts/build.sh && \ bash build_scripts/install_nccl2.sh && rm -r build_scripts diff --git a/tools/manylinux1/build_scripts/build.sh b/tools/manylinux1/build_scripts/build.sh index 93591fa9dd..d99d3db2ed 100644 --- a/tools/manylinux1/build_scripts/build.sh +++ b/tools/manylinux1/build_scripts/build.sh @@ -105,7 +105,7 @@ curl-config --features rm -rf /usr/local/ssl # Install patchelf (latest with unreleased bug fixes) -curl -sLO https://nipy.bic.berkeley.edu/manylinux/patchelf-0.9njs2.tar.gz +curl -sLO http://nipy.bic.berkeley.edu/manylinux/patchelf-0.9njs2.tar.gz check_sha256sum patchelf-0.9njs2.tar.gz $PATCHELF_HASH tar -xzf patchelf-0.9njs2.tar.gz (cd patchelf-0.9njs2 && ./configure && make && make install) diff --git a/tools/manylinux1/build_scripts/install_nccl2.sh b/tools/manylinux1/build_scripts/install_nccl2.sh index 282c5c290d..43a99d8287 100644 --- a/tools/manylinux1/build_scripts/install_nccl2.sh +++ b/tools/manylinux1/build_scripts/install_nccl2.sh @@ -21,5 +21,5 @@ for sub_deb in $DEBS; do ar x $sub_deb && tar xf data.tar.xz done mv -f usr/include/nccl.h /usr/local/include/ -mv -f usr/lib/libnccl* /usr/local/lib/ +mv -f usr/lib/x86_64-linux-gnu/libnccl* /usr/local/lib/ rm -rf $DIR From 3373535b213c7ad5c24121e9a4e56534bc40e05b Mon Sep 17 00:00:00 2001 From: luotao1 Date: Tue, 14 Aug 2018 16:08:36 +0800 Subject: [PATCH 66/94] fix specific cudnn include and library path --- cmake/configure.cmake | 4 ++++ cmake/external/anakin.cmake | 2 +- 2 files changed, 5 insertions(+), 
1 deletion(-) diff --git a/cmake/configure.cmake b/cmake/configure.cmake index c35096e09b..ae90a529b1 100644 --- a/cmake/configure.cmake +++ b/cmake/configure.cmake @@ -104,6 +104,10 @@ if(WITH_GPU) if(${CUDNN_MAJOR_VERSION} VERSION_LESS 7) message(FATAL_ERROR "Anakin needs CUDNN >= 7.0 to compile") endif() + set(ENV{CUDNN_INCLUDE_DIR} ${CUDNN_INCLUDE_DIR}) + set(ENV{CUDNN_LIBRARY} ${CUDNN_LIBRARY}) + message(STATUS "cudnn include header is ${CUDNN_INCLUDE_DIR}/cudnn.h") + message(STATUS "cudnn library is ${CUDNN_LIBRARY}") endif() elseif(WITH_AMD_GPU) add_definitions(-DPADDLE_WITH_HIP) diff --git a/cmake/external/anakin.cmake b/cmake/external/anakin.cmake index 403873a510..5de7ca8f46 100644 --- a/cmake/external/anakin.cmake +++ b/cmake/external/anakin.cmake @@ -37,7 +37,7 @@ ExternalProject_Add( ${EXTERNAL_PROJECT_LOG_ARGS} # TODO(luotao): use PaddlePaddle/Anakin later GIT_REPOSITORY "https://github.com/luotao1/Anakin" - GIT_TAG "3957ae9263eaa0b1986758dac60a88852afb09be" + GIT_TAG "842a89ae3747ede25d8acbc29030d2eb602ced1f" PREFIX ${ANAKIN_SOURCE_DIR} UPDATE_COMMAND "" CMAKE_ARGS -DUSE_GPU_PLACE=YES From 1f86c88f4afaa3c2ea2bd9baabce521cc943e451 Mon Sep 17 00:00:00 2001 From: minqiyang Date: Tue, 14 Aug 2018 21:51:06 +0800 Subject: [PATCH 67/94] Remove random order of fetch_list in test_random_crop_op --- CMakeLists.txt | 3 +-- python/paddle/fluid/tests/unittests/op_test.py | 6 +----- python/paddle/fluid/tests/unittests/test_random_crop_op.py | 7 ++++--- 3 files changed, 6 insertions(+), 10 deletions(-) diff --git a/CMakeLists.txt b/CMakeLists.txt index f56c5d382a..920c20d6f8 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -204,12 +204,11 @@ include(external/snappy) # download snappy include(external/snappystream) include(external/threadpool) +set(WITH_ANAKIN OFF CACHE STRING "Disable Anakin first, will add it later." FORCE) if(WITH_GPU) include(cuda) include(tensorrt) include(external/anakin) -else() - set(WITH_ANAKIN OFF CACHE STRING "Anakin is valid only when GPU is set." 
FORCE) endif() include(cudnn) # set cudnn libraries, must before configure diff --git a/python/paddle/fluid/tests/unittests/op_test.py b/python/paddle/fluid/tests/unittests/op_test.py index ada4ad70f0..75373ae2e1 100644 --- a/python/paddle/fluid/tests/unittests/op_test.py +++ b/python/paddle/fluid/tests/unittests/op_test.py @@ -362,14 +362,10 @@ class OpTest(unittest.TestCase): def check_output_customized(self, checker): places = self._get_places() - import sys - print('places', places) for place in places: outs = self.calc_output(place) outs = [np.array(out) for out in outs] - import sys - print('outs', outs) - sys.stdout.flush() + outs.sort(key=len) checker(outs) def __assert_is_close(self, numeric_grads, analytic_grads, names, diff --git a/python/paddle/fluid/tests/unittests/test_random_crop_op.py b/python/paddle/fluid/tests/unittests/test_random_crop_op.py index 1acd377b1f..27e5db4991 100644 --- a/python/paddle/fluid/tests/unittests/test_random_crop_op.py +++ b/python/paddle/fluid/tests/unittests/test_random_crop_op.py @@ -23,9 +23,10 @@ class TestRandomCropOp(OpTest): to_crop = np.array([[[1, 2, 3, 4], [5, 6, 7, 8], [9, 10, 11, 12]]] * 5).astype(np.int32) self.possible_res = [ - np.array([[1, 2, 3], [5, 6, 7]]), np.array([[2, 3, 4], [6, 7, 8]]), - np.array([[5, 6, 7], [9, 10, 11]]), - np.array([[6, 7, 8], [10, 11, 12]]) + np.array([[1, 2, 3], [5, 6, 7]]).astype(np.int32), + np.array([[2, 3, 4], [6, 7, 8]]).astype(np.int32), + np.array([[5, 6, 7], [9, 10, 11]]).astype(np.int32), + np.array([[6, 7, 8], [10, 11, 12]]).astype(np.int32) ] self.op_type = "random_crop" self.inputs = {'X': to_crop, 'Seed': np.array([10])} From 478f73c18871957fa70baa0536105a8fabe6dbce Mon Sep 17 00:00:00 2001 From: tangwei12 Date: Wed, 15 Aug 2018 10:45:39 +0800 Subject: [PATCH 68/94] merge header in cc --- paddle/fluid/operators/sampling_id_op.cc | 54 +++++++++++++++- paddle/fluid/operators/sampling_id_op.h | 82 ------------------------ 2 files changed, 52 insertions(+), 84 deletions(-) delete mode 100644 paddle/fluid/operators/sampling_id_op.h diff --git a/paddle/fluid/operators/sampling_id_op.cc b/paddle/fluid/operators/sampling_id_op.cc index d13eeabcb9..4929a7edc2 100644 --- a/paddle/fluid/operators/sampling_id_op.cc +++ b/paddle/fluid/operators/sampling_id_op.cc @@ -12,18 +12,68 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. 
*/ -#include "paddle/fluid/operators/sampling_id_op.h" +#include +#include +#include +#include +#include +#include +#include "paddle/fluid/framework/op_registry.h" namespace paddle { namespace operators { using Tensor = framework::Tensor; +template +class SamplingIdKernel : public framework::OpKernel { + public: + void Compute(const framework::ExecutionContext& context) const override { + const Tensor* input = context.Input("X"); + const int batch_size = static_cast(input->dims()[0]); + const int width = static_cast(input->dims()[1]); + + std::vector ins_vector; + framework::TensorToVector(*input, context.device_context(), &ins_vector); + + std::vector ids(batch_size); + for (size_t i = 0; i < batch_size; ++i) { + double r = getRandReal(); + int idx = width - 1; + for (int j = 0; j < width; ++j) { + if ((r -= ins_vector[i * width + j]) < 0) { + idx = j; + break; + } + } + ids[i] = ins_vector[i * width + idx]; + } + + std::vector out_dim; + out_dim.push_back(static_cast(batch_size)); + + Tensor* output = context.Output("Out"); + output->Resize(framework::make_ddim(out_dim)); + output->mutable_data(context.GetPlace()); + framework::TensorFromVector(ids, context.device_context(), output); + } + + private: + double getRandReal() const { + std::random_device + rd; // Will be used to obtain a seed for the random number engine + std::mt19937 gen(rd()); // Standard mersenne_twister_engine seeded with + // rd() + std::uniform_real_distribution<> dis(1.0, 2.0); + return dis(gen); + } +}; + class SamplingIdOp : public framework::OperatorWithKernel { public: using framework::OperatorWithKernel::OperatorWithKernel; - void InferShape(framework::InferShapeContext *ctx) const override { + void InferShape(framework::InferShapeContext* ctx) const override { PADDLE_ENFORCE(ctx->HasInput("X"), "Input(X) of SamplingIdOp should not be null."); PADDLE_ENFORCE(ctx->HasOutput("Out"), diff --git a/paddle/fluid/operators/sampling_id_op.h b/paddle/fluid/operators/sampling_id_op.h deleted file mode 100644 index 7f3ca8e761..0000000000 --- a/paddle/fluid/operators/sampling_id_op.h +++ /dev/null @@ -1,82 +0,0 @@ -/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. 
*/ -#pragma once - -#include -#include -#include -#include -#include -#include -#include "paddle/fluid/framework/lod_tensor.h" -#include "paddle/fluid/framework/op_registry.h" - -namespace paddle { -namespace operators { - -using Tensor = framework::Tensor; - -template -class SamplingIdKernel : public framework::OpKernel { - public: - void Compute(const framework::ExecutionContext& context) const override { - const Tensor* input = context.Input("X"); - const int batch_size = static_cast(input->dims()[0]); - const int width = static_cast(input->dims()[1]); - - std::vector ins_vector; - framework::TensorToVector(*input, context.device_context(), &ins_vector); - - std::vector ids(batch_size); - for (size_t i = 0; i < batch_size; ++i) { - double r = this->getRandReal(); - int idx = width - 1; - for (int j = 0; j < width; ++j) { - if ((r -= ins_vector[i * width + j]) < 0) { - idx = j; - break; - } - } - ids[i] = ins_vector[i * width + idx]; - } - - std::vector out_dim; - out_dim.push_back(static_cast(batch_size)); - - Tensor* output = context.Output("Out"); - output->Resize(framework::make_ddim(out_dim)); - output->mutable_data(context.GetPlace()); - framework::TensorFromVector(ids, context.device_context(), output); - } - - private: - double getRandReal() const { - std::call_once(init_flag_, &SamplingIdKernel::getRndInstance); - return rnd(); - } - - static void getRndInstance() { - // Will be used to obtain a seed for the random number engine - std::random_device rd; - // Standard mersenne_twister_engine seeded with rd() - std::mt19937 gen(rd()); - std::uniform_real_distribution<> dis(0, 1); - rnd = std::bind(dis, std::ref(gen)); - } - - static std::once_flag init_flag_; - static std::function rnd; -}; -} // namespace operators -} // namespace paddle From 4d4491ef6a3c45c688b904941a85f9951b130c3d Mon Sep 17 00:00:00 2001 From: minqiyang Date: Wed, 15 Aug 2018 12:15:10 +0800 Subject: [PATCH 69/94] Fix new added code --- python/paddle/fluid/framework.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/python/paddle/fluid/framework.py b/python/paddle/fluid/framework.py index d413f96e95..621e46b0f9 100644 --- a/python/paddle/fluid/framework.py +++ b/python/paddle/fluid/framework.py @@ -1668,7 +1668,7 @@ class Program(object): root_block._remove_op(0, read_op_idx + 1) for var in root_block.all_vars(): if var.type() == core.VarDesc.VarType.READER: - root_block._remove_var(var.name()) + root_block._remove_var(cpt.to_bytes(var.name())) # change all `is_test` attributes to True for i in six.moves.range(res.desc.num_blocks()): From d84a1a0010fc038a7da2ee7cf3ebb4f93353f1a4 Mon Sep 17 00:00:00 2001 From: tensor-tang Date: Wed, 15 Aug 2018 12:24:03 +0800 Subject: [PATCH 70/94] fc op use cpu only --- paddle/fluid/operators/CMakeLists.txt | 5 +++++ 1 file changed, 5 insertions(+) diff --git a/paddle/fluid/operators/CMakeLists.txt b/paddle/fluid/operators/CMakeLists.txt index ae37c70929..c3f7c42a82 100644 --- a/paddle/fluid/operators/CMakeLists.txt +++ b/paddle/fluid/operators/CMakeLists.txt @@ -158,6 +158,11 @@ function(op_library TARGET) file(APPEND ${pybind_file} "USE_OP_DEVICE_KERNEL(relu, MKLDNN);\n") else() file(APPEND ${pybind_file} "USE_OP_DEVICE_KERNEL(${TARGET}, MKLDNN);\n") + # HACK: fc only have cpu kernel + if (${MKLDNN_FILE} STREQUAL "fc_mkldnn_op") + file(APPEND ${pybind_file} "USE_CPU_ONLY_OP(${TARGET});\n") + set(pybind_flag 1) + endif() endif() endif() From eee38464dc5477480fd36e57305f36c9519c9c00 Mon Sep 17 00:00:00 2001 From: tensor-tang Date: Wed, 15 Aug 2018 13:36:32 
+0800 Subject: [PATCH 71/94] refine fc op use cpu only --- paddle/fluid/operators/CMakeLists.txt | 8 +++----- 1 file changed, 3 insertions(+), 5 deletions(-) diff --git a/paddle/fluid/operators/CMakeLists.txt b/paddle/fluid/operators/CMakeLists.txt index c3f7c42a82..e8b5dec9d4 100644 --- a/paddle/fluid/operators/CMakeLists.txt +++ b/paddle/fluid/operators/CMakeLists.txt @@ -158,11 +158,6 @@ function(op_library TARGET) file(APPEND ${pybind_file} "USE_OP_DEVICE_KERNEL(relu, MKLDNN);\n") else() file(APPEND ${pybind_file} "USE_OP_DEVICE_KERNEL(${TARGET}, MKLDNN);\n") - # HACK: fc only have cpu kernel - if (${MKLDNN_FILE} STREQUAL "fc_mkldnn_op") - file(APPEND ${pybind_file} "USE_CPU_ONLY_OP(${TARGET});\n") - set(pybind_flag 1) - endif() endif() endif() @@ -175,6 +170,9 @@ function(op_library TARGET) file(APPEND ${pybind_file} "USE_OP(fake_dequantize_max_abs);\n") elseif(${TARGET} STREQUAL "tensorrt_engine_op") message(STATUS "Pybind skips [tensorrt_engine_op], for this OP is only used in inference") + elseif(${TARGET} STREQUAL "fc") + # HACK: fc only has mkldnn and cpu kernels, which would mismatch the cpu-only condition + file(APPEND ${pybind_file} "USE_CPU_ONLY_OP(${TARGET});\n") else() file(APPEND ${pybind_file} "USE_OP(${TARGET});\n") endif() From 13e99cf92fe7e64caa94ddeeb84d8e2a168ca3ec Mon Sep 17 00:00:00 2001 From: tangwei12 Date: Wed, 15 Aug 2018 15:17:58 +0800 Subject: [PATCH 72/94] add unit test --- .../tests/unittests/test_dist_transpiler.py | 25 ++++++++++++++++--- 1 file changed, 22 insertions(+), 3 deletions(-) diff --git a/python/paddle/fluid/tests/unittests/test_dist_transpiler.py b/python/paddle/fluid/tests/unittests/test_dist_transpiler.py index b6f4f0726f..a536c21071 100644 --- a/python/paddle/fluid/tests/unittests/test_dist_transpiler.py +++ b/python/paddle/fluid/tests/unittests/test_dist_transpiler.py @@ -12,6 +12,8 @@ # See the License for the specific language governing permissions and # limitations under the License.
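A caveat worth keeping in mind for the test added below: math.ceil(a / b) only rounds up when the division is true division. Quick check, for illustration:

    import math

    print(int(math.ceil(1000 / 3)))    # Python 3: 334 (true division)
                                       # Python 2: 333 (floor division feeds ceil)
    print(int(math.ceil(1000 / 3.0)))  # 334 on both

The assertion in TestDistLookupTableSliceSize still holds either way, because the transpiler (patches 58 and 62 above) computes its shard size with the same expression, so both sides floor or both sides round up together.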
+import math + import unittest import paddle.fluid as fluid from paddle.fluid.transpiler.distribute_transpiler import delete_ops @@ -362,12 +364,13 @@ class TestL2DecayWithPiecewise(TranspilerTest): class TestDistLookupTableBase(TranspilerTest): def network_with_table(self, is_sparse, is_distributed): + self.table_size = 1000 + self.emb_size = 64 + def emb_pool(ids): - table_size = 1000 - emb_size = 64 emb = fluid.layers.embedding( input=ids, - size=[table_size, emb_size], + size=[self.table_size, self.emb_size], dtype='float32', param_attr='shared_w', # share parameter is_sparse=is_sparse, @@ -536,6 +539,22 @@ class TestAsyncDistLookupTable(TestDistLookupTableBase): self.assertEqual([op.type for op in trainer.blocks[0].ops], ops) +class TestDistLookupTableSliceSize(TestDistLookupTableBase): + def net_conf(self): + self.network_with_table(is_sparse=True, is_distributed=True) + + def transpiler_test_impl(self): + config = fluid.DistributeTranspilerConfig() + pserver1, startup1 = self.get_pserver(self.pserver1_ep, config) + + self.assertTrue(self.transpiler.has_distributed_lookup_table) + lookup_table_var = pserver1.global_block().vars[ + self.transpiler.table_name] + row_size = lookup_table_var.shape[0] + calc_row_size = int(math.ceil(self.table_size / self.pservers)) + self.assertEqual(row_size, calc_row_size) + + class TestRMSPropOptimizer(TranspilerTest): def net_conf(self): x = fluid.layers.data(name='x', shape=[1000], dtype='float32') From 4661f5589dc95a3bd3736848b820990c4c6e32d3 Mon Sep 17 00:00:00 2001 From: tangwei12 Date: Wed, 15 Aug 2018 16:01:53 +0800 Subject: [PATCH 73/94] random optimize --- paddle/fluid/operators/sampling_id_op.cc | 44 +++++++++++++++--------- 1 file changed, 27 insertions(+), 17 deletions(-) diff --git a/paddle/fluid/operators/sampling_id_op.cc b/paddle/fluid/operators/sampling_id_op.cc index 4929a7edc2..f8f94553be 100644 --- a/paddle/fluid/operators/sampling_id_op.cc +++ b/paddle/fluid/operators/sampling_id_op.cc @@ -36,9 +36,19 @@ class SamplingIdKernel : public framework::OpKernel<T> { std::vector<T> ins_vector; framework::TensorToVector(*input, context.device_context(), &ins_vector); + unsigned int seed = static_cast<unsigned int>(ctx.Attr<int>("seed")); + std::minstd_rand engine; + if (seed == 0) { + seed = std::random_device()(); + } + engine.seed(seed); + std::uniform_real_distribution<T> dist( + static_cast<T>(ctx.Attr<float>("min")), + static_cast<T>(ctx.Attr<float>("max"))); + std::vector<T> ids(batch_size); for (size_t i = 0; i < batch_size; ++i) { - double r = getRandReal(); + T r = dist(engine); int idx = width - 1; for (int j = 0; j < width; ++j) { if ((r -= ins_vector[i * width + j]) < 0) { @@ -57,16 +67,6 @@ class SamplingIdKernel : public framework::OpKernel<T> { output->mutable_data<T>(context.GetPlace()); framework::TensorFromVector(ids, context.device_context(), output); } - - private: - double getRandReal() const { - std::random_device - rd; // Will be used to obtain a seed for the random number engine - std::mt19937 gen(rd()); // Standard mersenne_twister_engine seeded with - // rd() - std::uniform_real_distribution<> dis(1.0, 2.0); - return dis(gen); - } }; class SamplingIdOp : public framework::OperatorWithKernel { @@ -78,6 +78,9 @@ class SamplingIdOp : public framework::OperatorWithKernel { "Input(X) of SamplingIdOp should not be null."); PADDLE_ENFORCE(ctx->HasOutput("Out"), "Output(Out) of SamplingIdOp should not be null."); + PADDLE_ENFORCE( + ctx->Attrs().Get<float>("min") < ctx->Attrs().Get<float>("max"), + "min must be less than max"); auto input_dims = ctx->GetInputDim("X");
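Stepping out of the diff for a moment: the kernel loop above is inverse-CDF sampling. For each row of X (assumed to hold per-class probabilities), draw r from [min, max) and walk the row subtracting entries until the running remainder goes negative; that column is the sample. A runnable sketch of the same loop, illustrative only and not the Paddle API (note the kernel then stores ins_vector[i * width + idx], the value at the sampled position, while this sketch returns the index):

    import random

    def sampling_id_row(probs, rng=random.random):
        r = rng()                 # assumes the min=0.0, max=1.0 defaults
        idx = len(probs) - 1      # fall back to the last column
        for j, p in enumerate(probs):
            r -= p
            if r < 0:
                idx = j
                break
        return idx

    print(sampling_id_row([0.1, 0.2, 0.3, 0.4]))

With normalized rows and r uniform on [0, 1), column j is returned with probability probs[j].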
PADDLE_ENFORCE(input_dims.size() == 2, @@ -99,7 +102,17 @@ class SamplingIdOpMaker : public framework::OpProtoAndCheckerMaker { AddComment(R"DOC( SamplingId Operator. A layer for sampling id from multinomial distribution from the - input layer. Sampling one id for one sample.)DOC"); + input. Sampling one id for one sample.)DOC"); + AddAttr<float>("min", "Minimum value of random. [default 0.0].") .SetDefault(0.0f); + AddAttr<float>("max", "Maximum value of random. [default 1.0].") .SetDefault(1.0f); + AddAttr<int>("seed", "Random seed used for the random number engine. " "0 means use a seed generated by the system." "Note that if seed is not 0, this operator will always " "generate the same random numbers every time. [default 0].") .SetDefault(0); } }; } // namespace operators @@ -109,8 +122,5 @@ namespace ops = paddle::operators; REGISTER_OPERATOR(sampling_id, ops::SamplingIdOp, ops::SamplingIdOpMaker, paddle::framework::EmptyGradOpMaker); -REGISTER_OP_CPU_KERNEL( - sampling_id, ops::SamplingIdKernel<float>, - ops::SamplingIdKernel<double>, - ops::SamplingIdKernel<int>, - ops::SamplingIdKernel<int64_t>); +REGISTER_OP_CPU_KERNEL(sampling_id, paddle::operators::SamplingIdKernel<float>, + paddle::operators::SamplingIdKernel<double>); From 60dda7bf9f86a832b81d0e05f40a3fc1f462c456 Mon Sep 17 00:00:00 2001 From: tangwei12 Date: Wed, 15 Aug 2018 16:27:22 +0800 Subject: [PATCH 74/94] add gpu Implementation --- paddle/fluid/operators/sampling_id_op.cc | 4 +- paddle/fluid/operators/sampling_id_op.cu | 87 ++++++++++++++++++++++++ 2 files changed, 89 insertions(+), 2 deletions(-) create mode 100644 paddle/fluid/operators/sampling_id_op.cu diff --git a/paddle/fluid/operators/sampling_id_op.cc b/paddle/fluid/operators/sampling_id_op.cc index f8f94553be..2549758a8e 100644 --- a/paddle/fluid/operators/sampling_id_op.cc +++ b/paddle/fluid/operators/sampling_id_op.cc @@ -25,7 +25,7 @@ namespace operators { using Tensor = framework::Tensor; -template +template class SamplingIdKernel : public framework::OpKernel<T> { public: void Compute(const framework::ExecutionContext& context) const override { @@ -48,7 +48,7 @@ class SamplingIdKernel : public framework::OpKernel<T> { std::vector<T> ids(batch_size); for (size_t i = 0; i < batch_size; ++i) { - double r = dist(engine); + T r = dist(engine); int idx = width - 1; for (int j = 0; j < width; ++j) { if ((r -= ins_vector[i * width + j]) < 0) { diff --git a/paddle/fluid/operators/sampling_id_op.cu b/paddle/fluid/operators/sampling_id_op.cu new file mode 100644 index 0000000000..791675b73b --- /dev/null +++ b/paddle/fluid/operators/sampling_id_op.cu @@ -0,0 +1,87 @@ +/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved. + + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License.
*/ +#include +#include +#include "paddle/fluid/framework/op_registry.h" +#include "paddle/fluid/framework/operator.h" + +template <typename T> +struct UniformGenerator { + T min_, max_; + unsigned int seed_; + + __host__ __device__ UniformGenerator(T min, T max, int seed) + : min_(min), max_(max), seed_(seed) {} + + __host__ __device__ T operator()(const unsigned int n) const { + thrust::minstd_rand rng; + rng.seed(seed_); + thrust::uniform_real_distribution<T> dist(min_, max_); + rng.discard(n); + return dist(rng); + } +}; + +namespace paddle { +namespace operators { + +using Tensor = framework::Tensor; + +template <typename T> +class SamplingIdKernel : public framework::OpKernel<T> { + public: + void Compute(const framework::ExecutionContext& context) const override { + const Tensor* input = context.Input<Tensor>("X"); + const int batch_size = static_cast<int>(input->dims()[0]); + const int width = static_cast<int>(input->dims()[1]); + + std::vector<T> ins_vector; + framework::TensorToVector(*input, context.device_context(), &ins_vector); + + unsigned int seed = static_cast<unsigned int>(context.Attr<int>("seed")); + if (seed == 0) { + std::random_device rd; + seed = rd(); + } + T min = static_cast<T>(context.Attr<float>("min")); + T max = static_cast<T>(context.Attr<float>("max")); + + std::vector<T> ids(batch_size); + for (size_t i = 0; i < batch_size; ++i) { + T r = UniformGenerator<T>(min, max, seed)(i); // invoke the functor at index i + int idx = width - 1; + for (int j = 0; j < width; ++j) { + if ((r -= ins_vector[i * width + j]) < 0) { + idx = j; + break; + } + } + ids[i] = ins_vector[i * width + idx]; + } + + std::vector<int64_t> out_dim; + out_dim.push_back(static_cast<int64_t>(batch_size)); + + Tensor* output = context.Output<Tensor>("Out"); + output->Resize(framework::make_ddim(out_dim)); + output->mutable_data<T>(context.GetPlace()); + framework::TensorFromVector(ids, context.device_context(), output); + } +}; + +} // namespace operators +} // namespace paddle + +REGISTER_OP_CPU_KERNEL(sampling_id, paddle::operators::SamplingIdKernel<float>, + paddle::operators::SamplingIdKernel<double>); From 470fb7c5c39ad0f84baf15de67f618e8826b6d79 Mon Sep 17 00:00:00 2001 From: tangwei12 Date: Wed, 15 Aug 2018 18:15:16 +0800 Subject: [PATCH 75/94] bug fix --- paddle/fluid/operators/sampling_id_op.cc | 6 +++--- paddle/fluid/operators/sampling_id_op.cu | 7 ++++--- 2 files changed, 7 insertions(+), 6 deletions(-) diff --git a/paddle/fluid/operators/sampling_id_op.cc b/paddle/fluid/operators/sampling_id_op.cc index 2549758a8e..e88310745f 100644 --- a/paddle/fluid/operators/sampling_id_op.cc +++ b/paddle/fluid/operators/sampling_id_op.cc @@ -36,15 +36,15 @@ class SamplingIdKernel : public framework::OpKernel<T> { std::vector<T> ins_vector; framework::TensorToVector(*input, context.device_context(), &ins_vector); - unsigned int seed = static_cast<unsigned int>(ctx.Attr<int>("seed")); + unsigned int seed = static_cast<unsigned int>(context.Attr<int>("seed")); std::minstd_rand engine; if (seed == 0) { seed = std::random_device()(); } engine.seed(seed); std::uniform_real_distribution<T> dist( - static_cast<T>(ctx.Attr<float>("min")), - static_cast<T>(ctx.Attr<float>("max"))); + static_cast<T>(context.Attr<float>("min")), + static_cast<T>(context.Attr<float>("max"))); std::vector<T> ids(batch_size); for (size_t i = 0; i < batch_size; ++i) { diff --git a/paddle/fluid/operators/sampling_id_op.cu b/paddle/fluid/operators/sampling_id_op.cu index 791675b73b..b104710374 100644 --- a/paddle/fluid/operators/sampling_id_op.cu +++ b/paddle/fluid/operators/sampling_id_op.cu @@ -39,7 +39,7 @@ namespace operators { using Tensor = framework::Tensor; template <typename T> -class SamplingIdKernel : public framework::OpKernel<T> { +class SamplingIdGPUKernel : public framework::OpKernel<T> {
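A property worth noting about the generator defined above: it re-seeds a fresh thrust::minstd_rand on every call and then uses rng.discard(n), so a fixed (seed, n) pair always yields the same draw regardless of evaluation order. The same counter-style idea in plain Python, as an illustrative analogy only (thrust's discard skips states without generating them):

    import random

    def draw_at(seed, n):
        rng = random.Random(seed)
        for _ in range(n):        # emulate rng.discard(n)
            rng.random()
        return rng.random()

    assert draw_at(42, 7) == draw_at(42, 7)   # order-independent, reproducible

This is what makes thrust-based random ops deterministic for a fixed seed attribute.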
public: void Compute(const framework::ExecutionContext& context) const override { const Tensor* input = context.Input("X"); @@ -83,5 +83,6 @@ class SamplingIdKernel : public framework::OpKernel { } // namespace operators } // namespace paddle -REGISTER_OP_CPU_KERNEL(sampling_id, paddle::operators::SamplingIdKernel, - paddle::operators::SamplingIdKernel); +REGISTER_OP_CUDA_KERNEL(sampling_id, + paddle::operators::SamplingIdGPUKernel, + paddle::operators::SamplingIdGPUKernel); From 507f47973270156095726506e9a1cf68f9fb4b05 Mon Sep 17 00:00:00 2001 From: minqiyang Date: Wed, 15 Aug 2018 19:06:43 +0800 Subject: [PATCH 76/94] Polish code --- python/paddle/fluid/layers/detection.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/python/paddle/fluid/layers/detection.py b/python/paddle/fluid/layers/detection.py index abe479693c..9baf5f84fd 100644 --- a/python/paddle/fluid/layers/detection.py +++ b/python/paddle/fluid/layers/detection.py @@ -1111,7 +1111,7 @@ def multi_box_head(inputs, mbox_loc = nn.transpose(mbox_loc, perm=[0, 2, 3, 1]) compile_shape = [ mbox_loc.shape[0], cpt.floor_division( - box_loc.shape[1] * mbox_loc.shape[2] * mbox_loc.shape[3], 4), 4 + mbox_loc.shape[1] * mbox_loc.shape[2] * mbox_loc.shape[3], 4), 4 ] run_shape = tensor.assign(numpy.array([0, -1, 4]).astype("int32")) mbox_loc_flatten = nn.reshape( From baa6273c541a6ab4af588fd9fb211c53323f3640 Mon Sep 17 00:00:00 2001 From: tangwei12 Date: Wed, 15 Aug 2018 19:21:03 +0800 Subject: [PATCH 77/94] unit test optimize --- .../paddle/fluid/tests/unittests/test_sampling_id_op.py | 9 +++++++-- 1 file changed, 7 insertions(+), 2 deletions(-) diff --git a/python/paddle/fluid/tests/unittests/test_sampling_id_op.py b/python/paddle/fluid/tests/unittests/test_sampling_id_op.py index e3e7153049..708265b457 100644 --- a/python/paddle/fluid/tests/unittests/test_sampling_id_op.py +++ b/python/paddle/fluid/tests/unittests/test_sampling_id_op.py @@ -29,14 +29,19 @@ class TestSamplingIdOp(OpTest): self.inputs = {"X": self.X} self.Y = np.random.random(8).astype('float32') self.outputs = {'Out': self.Y} - self.attrs = {'use_mkldnn': self.use_mkldnn} + self.attrs = {'max': 1.0, 'min': 0.0, 'seed': 1} def test_check_output(self): self.check_output_customized(self.verify_output) + y1 = self.out + self.check_output_customized(self.verify_output) + y2 = self.out + self.assertTrue(np.array_equal(y1, y2)) + self.assertEqual(len(y1), len(self.Y)) def verify_output(self, outs): out = np.array(outs[0]) - self.assertEqual(len(out), len(self.Y)) + self.out = out def init_kernel_type(self): pass From 5d2834fcf7d0d797d5050c5d3481388363cf0817 Mon Sep 17 00:00:00 2001 From: Yan Chunwei Date: Wed, 15 Aug 2018 19:44:16 +0800 Subject: [PATCH 78/94] fea/ir support fuse, based on graph pattern detection helper (#12636) --- paddle/fluid/framework/ir/CMakeLists.txt | 3 + .../framework/ir/graph_pattern_detecter.cc | 186 ++++++++++++++++++ .../framework/ir/graph_pattern_detecter.h | 181 +++++++++++++++++ .../ir/graph_pattern_detecter_tester.cc | 172 ++++++++++++++++ paddle/fluid/framework/ir/graph_traits.cc | 69 +++++++ paddle/fluid/framework/ir/graph_traits.h | 90 +++++++++ paddle/fluid/framework/ir/node.h | 3 + 7 files changed, 704 insertions(+) create mode 100644 paddle/fluid/framework/ir/graph_pattern_detecter.cc create mode 100644 paddle/fluid/framework/ir/graph_pattern_detecter.h create mode 100644 paddle/fluid/framework/ir/graph_pattern_detecter_tester.cc create mode 100644 paddle/fluid/framework/ir/graph_traits.cc create mode 100644 
paddle/fluid/framework/ir/graph_traits.h diff --git a/paddle/fluid/framework/ir/CMakeLists.txt b/paddle/fluid/framework/ir/CMakeLists.txt index bf7d76a8a6..923a7083d4 100644 --- a/paddle/fluid/framework/ir/CMakeLists.txt +++ b/paddle/fluid/framework/ir/CMakeLists.txt @@ -3,7 +3,10 @@ cc_library(graph SRCS graph.cc DEPS node) cc_library(graph_helper SRCS graph_helper.cc DEPS graph) cc_library(pass SRCS pass.cc DEPS graph node graph_helper) cc_library(graph_viz_pass SRCS graph_viz_pass.cc DEPS graph pass graph_helper) +cc_library(graph_traits SRCS graph_traits.cc DEPS graph) +cc_library(graph_pattern_detecter SRCS graph_pattern_detecter.cc DEPS graph graph_helper graph_traits) cc_test(pass_test SRCS pass_test.cc DEPS graph pass graph_helper) cc_test(graph_test SRCS graph_test.cc DEPS graph graph_helper op_registry) cc_test(graph_helper_test SRCS graph_helper_test.cc DEPS graph graph_helper op_registry) +cc_test(test_graph_pattern_detecter SRCS graph_pattern_detecter_tester.cc DEPS graph_pattern_detecter) diff --git a/paddle/fluid/framework/ir/graph_pattern_detecter.cc b/paddle/fluid/framework/ir/graph_pattern_detecter.cc new file mode 100644 index 0000000000..f27d9b0509 --- /dev/null +++ b/paddle/fluid/framework/ir/graph_pattern_detecter.cc @@ -0,0 +1,186 @@ +// Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include +#include + +#include "paddle/fluid/framework/ir/graph_helper.h" +#include "paddle/fluid/framework/ir/graph_pattern_detecter.h" +#include "paddle/fluid/framework/ir/graph_traits.h" +#include "paddle/fluid/platform/enforce.h" + +namespace paddle { +namespace framework { +namespace ir { + +PDNode* PDPattern::NewNode(PDNode::teller_t&& teller, const std::string& name) { + nodes_.emplace_back(new PDNode(std::move(teller), name)); + auto* cur = nodes_.back().get(); + return cur; +} + +void PDPattern::AddEdge(PDNode* a, PDNode* b) { + PADDLE_ENFORCE(a); + PADDLE_ENFORCE(b); + PADDLE_ENFORCE(a != b, "can't connect to the same nodes."); + edges_.emplace_back(a, b); +} + +void GraphPatternDetecter::operator()(Graph* graph, + GraphPatternDetecter::handle_t handler) { + if (!MarkPDNodesInGraph(*graph)) return; + auto subgraphs = DetectPatterns(); + UniquePatterns(&subgraphs); + RemoveOverlappedMatch(&subgraphs); + + for (auto& g : subgraphs) { + handler(g, graph); + } +} + +bool GraphPatternDetecter::MarkPDNodesInGraph(const ir::Graph& graph) { + if (graph.Nodes().empty()) return false; + + for (auto& node : GraphTraits::DFS(graph)) { + for (const auto& pdnode : pattern_.nodes()) { + if (pdnode->Tell(&node)) { + pdnodes2nodes_[pdnode.get()].insert(&node); + } + } + } + return !pdnodes2nodes_.empty(); +} + +struct HitGroup { + std::unordered_map roles; + + bool Match(Node* node, PDNode* pat) { + return !roles.count(pat) || roles.at(pat) == node; + } + + void Register(Node* node, PDNode* pat) { roles[pat] = node; } +}; + +// Tell whether Node a links to b. 
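+// That is, a links to b iff b appears directly in a->outputs.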
+bool IsNodesLink(Node* a, Node* b) { + for (auto* node : a->outputs) { + if (b == node) { + return true; + } + } + return false; +} + +std::vector +GraphPatternDetecter::DetectPatterns() { + // Init empty subgraphs. + std::vector result; + std::vector init_groups; + PADDLE_ENFORCE(!pattern_.edges().empty(), "At least one edge is needed"); + auto* first_pnode = pattern_.edges().front().first; + if (!pdnodes2nodes_.count(first_pnode)) return result; + for (auto* node : pdnodes2nodes_[first_pnode]) { + HitGroup group; + group.roles[first_pnode] = node; + init_groups.emplace_back(group); + } + + int step = 0; + std::array, 2> bi_records; + bi_records[0] = std::move(init_groups); + + // Extend a PDNode to subgraphs by deducing the connection relations defined + // in edges of PDNodes. + for (const auto& edge : pattern_.edges()) { + // Each role has two PDNodes, which indicates two roles. + // Detect two Nodes that can match these two roles and they are connected. + auto& pre_groups = bi_records[step % 2]; + auto& cur_groups = bi_records[1 - (step++ % 2)]; + cur_groups.clear(); + // source -> target + for (Node* source : pdnodes2nodes_[edge.first]) { + for (Node* target : pdnodes2nodes_[edge.second]) { + // TODO(Superjomn) add some prune strategies. + for (const auto& group : pre_groups) { + HitGroup new_group = group; + if (IsNodesLink(source, target) && + new_group.Match(source, edge.first)) { + new_group.Register(source, edge.first); + if (new_group.Match(target, edge.second)) { + new_group.Register(target, edge.second); + cur_groups.push_back(new_group); + // TODO(Superjomn) need to unique + } + } + } + } + } + } + + for (auto& group : bi_records[step % 2]) { + GraphPatternDetecter::subgraph_t subgraph; + for (auto& role : group.roles) { + subgraph.emplace(role.first, role.second); + } + result.emplace_back(subgraph); + } + return result; +} + +void GraphPatternDetecter::UniquePatterns( + std::vector* subgraphs) { + if (subgraphs->empty()) return; + std::vector result; + + std::unordered_set set; + for (auto& g : *subgraphs) { + size_t key = 0; + for (auto& item : g) { + key ^= std::hash{}(item.first); + key ^= std::hash{}(item.second); + } + if (!set.count(key)) { + result.emplace_back(g); + set.insert(key); + } + } + *subgraphs = result; +} + +void GraphPatternDetecter::RemoveOverlappedMatch( + std::vector* subgraphs) { + std::vector result; + std::unordered_set node_set; + + for (const auto& subgraph : *subgraphs) { + bool valid = true; + for (auto& item : subgraph) { + if (node_set.count(item.second)) { + valid = false; + break; + } + } + if (valid) { + for (auto& item : subgraph) { + node_set.insert(item.second); + } + result.push_back(subgraph); + } + } + *subgraphs = result; +} + +} // namespace ir +} // namespace framework +} // namespace paddle diff --git a/paddle/fluid/framework/ir/graph_pattern_detecter.h b/paddle/fluid/framework/ir/graph_pattern_detecter.h new file mode 100644 index 0000000000..1778bf0000 --- /dev/null +++ b/paddle/fluid/framework/ir/graph_pattern_detecter.h @@ -0,0 +1,181 @@ +// Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. 
+// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#pragma once + +#ifdef PADDLE_WITH_TESTING +#include +#endif + +#include +#include "paddle/fluid/framework/ir/graph.h" +#include "paddle/fluid/framework/ir/node.h" + +namespace paddle { +namespace framework { +namespace ir { + +// Some basic terminology: +// - PDPattern: a pattern defined as a data flow graph. +// - PDNode: the node in the pattern, each PDNode represents an `ir::Node` +// that meets some conditions defined in `PDNode.teller`. +// - A pattern is defined with PDNodes and edges. + +// Pattern detector node. This node helps to build a pattern. +struct PDNode { + // tell whether an ir::Node* is a candidate for a PDNode. + using teller_t = std::function; + + PDNode(teller_t&& teller, const std::string& name = "") + : teller_(teller), name_(name) { + PADDLE_ENFORCE(teller_ != nullptr, "invalid teller function is set."); + } + + PDNode(PDNode&& other) = default; + + std::vector inlinks; + std::vector outlinks; + + bool Tell(Node* node) const { + PADDLE_ENFORCE(teller_ != nullptr, "teller should be set for a PDNode"); + return teller_(node); + } + + const std::string& name() const { return name_; } + + PDNode(const PDNode&) = delete; + PDNode& operator=(const PDNode&) = delete; + + private: + teller_t teller_; + std::string name_; +}; + +/* + * A pattern in a graph, which is defined with PDNodes and edges. Most graph + * patterns can be divided into PDNodes and link relations between them. + * + * For example, the FC fusion needs to filter the MUL and ELEMENTWISE_ADD + * operators from the computation graph; the MUL's output should have only one + * consumer, which is the ELEMENTWISE_ADD. + * This pattern can be defined with the following pseudo code: + * + * // Create two operator PDNodes. + * MUL = PDPattern.NewNode() + * ELE = PDPattern.NewNode() + * // Create the variable PDNodes. + * MUL_out = PDPattern.NewNode() + * // Add tellers to define some rules that help to filter the target Nodes. + * MUL.teller = lambda(node): node->IsOp() && node->Op()->Type == "mul"; + * ELE.teller = lambda(node): \ + * node->IsOp() && node->Op()->Type == "elementwise_add"; + * MUL_out.teller = lambda(node): node->IsVar() && (MUL in node->inputs) + * && (ELE in node->outputs) + * + * One can add more specific tellers for PDNodes or edges; both the Operator + * and Variable Nodes can be ruled in PDNode.teller. + * + * PDPattern can record general patterns, such as ones that represent + * - Op in CPU -> Op in GPU -> Op in CPU, to find out abnormal IO places. + * - Ops whose inputs and outputs share the same variables + */ +class PDPattern { + public: + using edge_t = std::pair; + + void AddEdge(PDNode* a, PDNode* b); + + PDNode* NewNode(PDNode::teller_t&& teller, const std::string& name = ""); + + const std::vector>& nodes() const { return nodes_; } + const std::vector& edges() const { return edges_; } + + private: +#ifdef PADDLE_WITH_TESTING + FRIEND_TEST(PDPattern, AddEdge); + FRIEND_TEST(PDPattern, NewNode); +#endif + + std::vector> nodes_; + std::vector edges_; +}; + +/* + * GraphPatternDetecter helps to detect the specific patterns in the graph.
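+ * A typical target is the mul+elementwise_add pair above that can be fused into an fc op.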
+ * Input a pattern, output a list of the matched subgraphs/nodes. + * This helper can be used to support fusion (e.g. conv+batchnorm => batchnorm). + * + * The algorithm has three phases: + * 1. Mark the nodes that match the defined PDNodes in a PDPattern, + * 2. Extend a PDNode to subgraphs by deducing the connection relations defined + * in PDPattern (the edges), + * 3. Get the filtered subgraphs and treat them with a pre-defined handler. + * + * Usage: + * // Create a detector + * GraphPatternDetecter detector; + * // Define the detector's pattern, by adding PDNodes and defining the edges. + * auto* node0 = detector.mutable_pattern().AddNode(...) + * auto* node1 = detector.mutable_pattern().AddNode(...) + * node0->teller = some lambda. + * node1->teller = some lambda. + * detector.mutable_pattern().AddEdge(node0, node1); + * // Create a handler, to define the behavior of treating the filtered + * // subgraphs that comply with the patterns. + * GraphPatternDetecter::handle_t handler = some lambda + * // Execute the detector. + * detector(&graph, handler); + */ +class GraphPatternDetecter { + public: + using subgraph_t = std::unordered_map; + + // Operate on the detected pattern. + using handle_t = + std::function; + + void operator()(Graph* graph, handle_t handler); + + const PDPattern& pattern() const { return pattern_; } + PDPattern* mutable_pattern() { return &pattern_; } + + private: + // Mark the nodes that fit the pattern. + bool MarkPDNodesInGraph(const ir::Graph& graph); + + // Detect all the patterns and output the hit records. + std::vector DetectPatterns(); + + // Remove duplicate patterns. + void UniquePatterns(std::vector* subgraphs); + + // Remove overlapped matched subgraphs; when two overlap, keep the earlier one. + void RemoveOverlappedMatch(std::vector* subgraphs); + +#ifdef PADDLE_WITH_TESTING + FRIEND_TEST(GraphPatternDetecter, MarkPDNodesInGraph); + FRIEND_TEST(GraphPatternDetecter, DetectPatterns); +#endif + + private: + using hit_rcd_t = + std::pair; + PDPattern pattern_; + std::vector marked_records_; + std::unordered_map> pdnodes2nodes_; +}; + +} // namespace ir +} // namespace framework +} // namespace paddle diff --git a/paddle/fluid/framework/ir/graph_pattern_detecter_tester.cc b/paddle/fluid/framework/ir/graph_pattern_detecter_tester.cc new file mode 100644 index 0000000000..993c885a81 --- /dev/null +++ b/paddle/fluid/framework/ir/graph_pattern_detecter_tester.cc @@ -0,0 +1,172 @@ +// Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License.
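+// The tests below build a small hand-wired graph (op1..op5, var1..var4) and +// exercise node marking, pattern detection and overlap removal.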
+ +#include "paddle/fluid/framework/ir/graph_pattern_detecter.h" + +#include + +namespace paddle { +namespace framework { +namespace ir { + +void BuildGraph(Graph* g) { + ir::Node* o1 = g->CreateEmptyNode("op1", Node::Type::kOperation); + ir::Node* o2 = g->CreateEmptyNode("op2", Node::Type::kOperation); + ir::Node* o3 = g->CreateEmptyNode("op3", Node::Type::kOperation); + ir::Node* o4 = g->CreateEmptyNode("op4", Node::Type::kOperation); + ir::Node* o5 = g->CreateEmptyNode("op5", Node::Type::kOperation); + ir::Node* v1 = g->CreateEmptyNode("var1", Node::Type::kVariable); + ir::Node* v2 = g->CreateEmptyNode("var2", Node::Type::kVariable); + ir::Node* v3 = g->CreateEmptyNode("var3", Node::Type::kVariable); + ir::Node* v4 = g->CreateEmptyNode("var4", Node::Type::kVariable); + + // o1->v1->o2 + o1->outputs.push_back(v1); + o2->inputs.push_back(v1); + v1->inputs.push_back(o1); + v1->outputs.push_back(o2); + // o2->v2->o3 + // o2->v2->o4 + o2->outputs.push_back(v2); + o3->inputs.push_back(v2); + o4->inputs.push_back(v2); + v2->inputs.push_back(o2); + v2->outputs.push_back(o3); + v2->outputs.push_back(o4); + // o2->v3->o5 + o2->outputs.push_back(v3); + o5->inputs.push_back(v3); + v3->inputs.push_back(o2); + v3->outputs.push_back(o5); + // o3-v4->o5 + o3->outputs.push_back(v4); + o5->inputs.push_back(v4); + v4->inputs.push_back(o3); + v4->outputs.push_back(o5); +} + +TEST(PDPattern, NewNode) { + PDPattern x; + auto* n = x.NewNode([](Node* x) { return true; }); + ASSERT_TRUE(n); + ASSERT_EQ(x.nodes_.size(), 1UL); +} + +TEST(PDPattern, AddEdge) { + PDPattern x; + auto* a = x.NewNode([](Node* x) { return true; }); + auto* b = x.NewNode([](Node* x) { return true; }); + ASSERT_TRUE(a); + ASSERT_TRUE(b); + x.AddEdge(a, b); + ASSERT_EQ(x.nodes_.size(), 2UL); + ASSERT_EQ(x.edges_.size(), 1UL); + ASSERT_EQ(x.edges_.front().first, a); + ASSERT_EQ(x.edges_.front().second, b); + + ASSERT_EQ(x.nodes().size(), 2UL); + ASSERT_EQ(x.edges().size(), 1UL); + ASSERT_EQ(x.edges().front().first, a); + ASSERT_EQ(x.edges().front().second, b); +} + +TEST(GraphPatternDetecter, MarkPDNodesInGraph) { + GraphPatternDetecter x; + // mark o2, o3, v2 + + // The pattern is a graph: + // o2(a node named o2) -> v2(a node named v2) + // v2 -> o3(a node named o3) + auto* o2 = x.pattern_.NewNode([](Node* node) { + // The teller can be any condition, such as op type, or variable's shape. + return node && node->Name() == "op2" && node->IsOp(); + }); + auto* o3 = x.pattern_.NewNode([](Node* node) { + // The teller can be any condition, such as op type, or variable's shape. + return node && node->Name() == "op3" && node->IsOp(); + }); + auto* v2 = x.pattern_.NewNode([](Node* node) { + // The teller can be any condition, such as op type, or variable's shape. 
+ return node && node->Name() == "var2" && node->IsVar(); + }); + + ASSERT_FALSE(o2->Tell(nullptr)); + ASSERT_FALSE(o3->Tell(nullptr)); + ASSERT_FALSE(v2->Tell(nullptr)); + + x.pattern_.AddEdge(o2, v2); + x.pattern_.AddEdge(v2, o3); + + ASSERT_EQ(x.pattern_.edges().size(), 2UL); + ASSERT_EQ(x.pattern_.edges()[0].first, o2); + ASSERT_EQ(x.pattern_.edges()[0].second, v2); + ASSERT_EQ(x.pattern_.edges()[1].first, v2); + ASSERT_EQ(x.pattern_.edges()[1].second, o3); + + ProgramDesc program; + Graph graph(program); + BuildGraph(&graph); + + x.MarkPDNodesInGraph(graph); + + ASSERT_EQ(x.pdnodes2nodes_.size(), 3UL); + + auto subgraphs = x.DetectPatterns(); + ASSERT_EQ(subgraphs.size(), 1UL); +} + +TEST(GraphPatternDetecter, MultiSubgraph) { + ProgramDesc program; + Graph graph(program); + BuildGraph(&graph); + + GraphPatternDetecter x; + + // The pattern is a graph: + // op -> var + auto* any_op = x.mutable_pattern()->NewNode( + [](Node* node) { + return node->IsOp() && (node->Name() == "op2" || node->Name() == "op3"); + }, + "OP0"); + auto* any_var = x.mutable_pattern()->NewNode( + [](Node* node) { return node->IsVar(); }, "VAR"); + auto* any_op1 = x.mutable_pattern()->NewNode( + [](Node* node) { return node->IsOp(); }, "OP1"); + + x.mutable_pattern()->AddEdge(any_op, any_var); + x.mutable_pattern()->AddEdge(any_var, any_op1); + + int count = 0; + GraphPatternDetecter::handle_t handle = [&]( + const GraphPatternDetecter::subgraph_t& s, Graph* g) { + LOG(INFO) << "Detect " << s.at(any_op)->Name() << " -> " + << s.at(any_var)->Name() << " -> " << s.at(any_op1)->Name(); + count++; + }; + + x(&graph, handle); + + // 1. Detect op3 -> var4 -> op5 + // 2. Detect op2 -> var2 -> op3 + // 3. Detect op2 -> var2 -> op4 + // 4. Detect op2 -> var3 -> op5 + // But 2 and 3 and 4 overlapped, so keep 2, so the final choices are 1 and 2 + ASSERT_GE(count, 1UL); + ASSERT_LE(count, 2UL); +} + +} // namespace ir +} // namespace framework +} // namespace paddle diff --git a/paddle/fluid/framework/ir/graph_traits.cc b/paddle/fluid/framework/ir/graph_traits.cc new file mode 100644 index 0000000000..8f548913e4 --- /dev/null +++ b/paddle/fluid/framework/ir/graph_traits.cc @@ -0,0 +1,69 @@ +// Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+ +#include "paddle/fluid/framework/ir/graph_traits.h" + +namespace paddle { +namespace framework { +namespace ir { + +// +// NodesDFSIterator +// +NodesDFSIterator::NodesDFSIterator(const std::vector &source) { + for (auto *x : source) stack_.push(x); +} + +NodesDFSIterator::NodesDFSIterator(NodesDFSIterator &&other) noexcept + : stack_(std::move(other.stack_)), + visited_(std::move(other.visited_)) {} + +NodesDFSIterator::NodesDFSIterator(const NodesDFSIterator &other) + : stack_(other.stack_), visited_(other.visited_) {} + +Node &NodesDFSIterator::operator*() { + PADDLE_ENFORCE(!stack_.empty()); + return *stack_.top(); +} + +NodesDFSIterator &NodesDFSIterator::operator++() { + PADDLE_ENFORCE(!stack_.empty(), "the iterator exceeds range"); + visited_.insert(stack_.top()); + auto *cur = stack_.top(); + stack_.pop(); + for (auto *x : cur->outputs) { + if (!visited_.count(x)) { + stack_.push(x); + } + } + return *this; +} +bool NodesDFSIterator::operator==(const NodesDFSIterator &other) { + if (stack_.empty()) return other.stack_.empty(); + if ((!stack_.empty()) && (!other.stack_.empty())) { + return stack_.top() == other.stack_.top(); + } + return false; +} + +NodesDFSIterator &NodesDFSIterator::operator=(const NodesDFSIterator &other) { + stack_ = other.stack_; + visited_ = other.visited_; + return *this; +} +Node *NodesDFSIterator::operator->() { return stack_.top(); } + +} // namespace ir +} // namespace framework +} // namespace paddle diff --git a/paddle/fluid/framework/ir/graph_traits.h b/paddle/fluid/framework/ir/graph_traits.h new file mode 100644 index 0000000000..edbe45acb9 --- /dev/null +++ b/paddle/fluid/framework/ir/graph_traits.h @@ -0,0 +1,90 @@ +// Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include +#include "paddle/fluid/framework/ir/graph.h" +#include "paddle/fluid/framework/ir/node.h" + +namespace paddle { +namespace framework { +namespace ir { + +template +class iterator_range { + IteratorT begin_, end_; + + public: + template + explicit iterator_range(Container &&c) : begin_(c.begin()), end_(c.end()) {} + + iterator_range(const IteratorT &begin, const IteratorT &end) + : begin_(begin), end_(end) {} + + const IteratorT &begin() const { return begin_; } + const IteratorT &end() const { return end_; } +}; + +// DFS iterator on nodes. +struct NodesDFSIterator + : public std::iterator { + NodesDFSIterator() = default; + explicit NodesDFSIterator(const std::vector &source); + NodesDFSIterator(NodesDFSIterator &&other) noexcept; + NodesDFSIterator(const NodesDFSIterator &other); + + Node &operator*(); + NodesDFSIterator &operator++(); + // TODO(Superjomn) current implementation just compare the first + // element, need to compare the graph and all the elements in the queue and + // set. 
+ NodesDFSIterator &operator=(const NodesDFSIterator &other); + bool operator==(const NodesDFSIterator &other); + bool operator!=(const NodesDFSIterator &other) { return !(*this == other); } + Node *operator->(); + + private: + std::stack stack_; + std::unordered_set visited_; +}; + +/* + * GraphTraits contains some graph traversal algorithms. + * + * Usage: + * + */ +struct GraphTraits { + static iterator_range DFS(const Graph &g) { + auto start_points = ExtractStartPoints(g); + NodesDFSIterator x(start_points); + return iterator_range(NodesDFSIterator(start_points), + NodesDFSIterator()); + } + + private: + // The nodes those have no input will be treated as start points. + static std::vector ExtractStartPoints(const Graph &g) { + std::vector result; + for (auto *node : g.Nodes()) { + if (node->inputs.empty()) { + result.push_back(node); + } + } + return result; + } +}; + +} // namespace ir +} // namespace framework +} // namespace paddle diff --git a/paddle/fluid/framework/ir/node.h b/paddle/fluid/framework/ir/node.h index b3138fccee..9c0765ab8c 100644 --- a/paddle/fluid/framework/ir/node.h +++ b/paddle/fluid/framework/ir/node.h @@ -58,6 +58,9 @@ class Node { return op_desc_; } + bool IsOp() const { return type_ == Type::kOperation; } + bool IsVar() const { return type_ == Type::kVariable; } + std::vector inputs; std::vector outputs; From d06849305a67d6645699384ae87ec1870e5756e3 Mon Sep 17 00:00:00 2001 From: gongweibao Date: Wed, 15 Aug 2018 21:17:14 +0800 Subject: [PATCH 79/94] parameter dispather. (#12666) --- paddle/fluid/framework/threadpool.cc | 7 ++ .../distributed/variable_response.cc | 7 +- paddle/fluid/operators/listen_and_serv_op.cc | 5 +- python/paddle/fluid/__init__.py | 2 +- python/paddle/fluid/initializer.py | 1 - .../fluid/tests/unittests/CMakeLists.txt | 4 +- .../fluid/tests/unittests/test_dist_train.py | 17 +++ .../tests/unittests/test_dist_transpiler.py | 50 ++++++--- .../fluid/transpiler/distribute_transpiler.py | 100 +++++++++++++++--- 9 files changed, 162 insertions(+), 31 deletions(-) diff --git a/paddle/fluid/framework/threadpool.cc b/paddle/fluid/framework/threadpool.cc index f26f212d4d..18cdca3a65 100644 --- a/paddle/fluid/framework/threadpool.cc +++ b/paddle/fluid/framework/threadpool.cc @@ -20,6 +20,9 @@ DEFINE_int32(io_threadpool_size, 100, "number of threads used for doing IO, default 100"); +DEFINE_int32(dist_threadpool_size, 0, + "number of threads used for distributed executed."); + namespace paddle { namespace framework { @@ -35,6 +38,10 @@ void ThreadPool::Init() { if (threadpool_.get() == nullptr) { // TODO(Yancey1989): specify the max threads number int num_threads = std::thread::hardware_concurrency(); + if (FLAGS_dist_threadpool_size > 0) { + num_threads = FLAGS_dist_threadpool_size; + VLOG(1) << "set dist_threadpool_size to " << num_threads; + } PADDLE_ENFORCE_GT(num_threads, 0); threadpool_.reset(new ThreadPool(num_threads)); } diff --git a/paddle/fluid/operators/distributed/variable_response.cc b/paddle/fluid/operators/distributed/variable_response.cc index 466bce18af..8e38b3713f 100644 --- a/paddle/fluid/operators/distributed/variable_response.cc +++ b/paddle/fluid/operators/distributed/variable_response.cc @@ -190,12 +190,15 @@ bool VariableResponse::ProcSerializedField( #endif } + VLOG(7) << "ProcSerializedField:" << meta_.varname() + << ", type:" << meta_.type() << std::endl; framework::DDim dims = GetDims(meta_.dims()); if (meta_.type() == sendrecv::LOD_TENSOR) { PADDLE_ENFORCE(meta_.lod_size() >= 0, "lod info should be got first!"); if 
(!CopyLodTensorData(input, *dev_ctx_, dims, num_bytes)) { return false; } + return true; } @@ -206,7 +209,9 @@ bool VariableResponse::ProcSerializedField( return true; } - return true; + PADDLE_ENFORCE("not supported var types:", meta_.varname(), meta_.type()); + + return false; } }; // namespace distributed diff --git a/paddle/fluid/operators/listen_and_serv_op.cc b/paddle/fluid/operators/listen_and_serv_op.cc index b194807696..f196e18fe1 100644 --- a/paddle/fluid/operators/listen_and_serv_op.cc +++ b/paddle/fluid/operators/listen_and_serv_op.cc @@ -123,8 +123,11 @@ void ListenAndServOp::RunSyncLoop( optimize_prepared.begin(), std::shared_ptr(nullptr)); + // Trainers will get all parameters from pserver in the + // startup program, so we will wait RequestGet first + rpc_service_->SetCond(distributed::kRequestGet); + rpc_service_->WaitBarrier(distributed::kRequestGet); rpc_service_->ResetBarrierCounter(); - while (true) { rpc_service_->Profiler().OneStep(); // Get from multiple trainers, we don't care about the order in which diff --git a/python/paddle/fluid/__init__.py b/python/paddle/fluid/__init__.py index 1ae05dec8d..9aac3c7fc1 100644 --- a/python/paddle/fluid/__init__.py +++ b/python/paddle/fluid/__init__.py @@ -122,7 +122,7 @@ def __bootstrap__(): 'use_pinned_memory', 'check_nan_inf', 'benchmark', 'warpctc_dir', 'eager_delete_scope', 'use_mkldnn', 'initial_cpu_memory_in_mb', 'init_allocated_mem', 'free_idle_memory', 'paddle_num_threads', - 'cpu_deterministic' + "dist_threadpool_size", 'cpu_deterministic' ] if core.is_compiled_with_dist(): read_env_flags.append('rpc_deadline') diff --git a/python/paddle/fluid/initializer.py b/python/paddle/fluid/initializer.py index 3f740dd7c5..6dedbae7a6 100644 --- a/python/paddle/fluid/initializer.py +++ b/python/paddle/fluid/initializer.py @@ -15,7 +15,6 @@ from . 
import framework import numpy as np import contextlib -from .framework import convert_np_dtype_to_dtype_ from .core import VarDesc __all__ = [ diff --git a/python/paddle/fluid/tests/unittests/CMakeLists.txt b/python/paddle/fluid/tests/unittests/CMakeLists.txt index a6a911721d..e7dd85ef5c 100644 --- a/python/paddle/fluid/tests/unittests/CMakeLists.txt +++ b/python/paddle/fluid/tests/unittests/CMakeLists.txt @@ -59,8 +59,8 @@ py_test_modules(test_warpctc_op MODULES test_warpctc_op ENVS FLAGS_warpctc_dir=$ if(WITH_DISTRIBUTE) py_test_modules(test_dist_train MODULES test_dist_train SERIAL) set_tests_properties(test_listen_and_serv_op PROPERTIES TIMEOUT 20) - set_tests_properties(test_dist_mnist PROPERTIES TIMEOUT 180) - set_tests_properties(test_dist_word2vec PROPERTIES TIMEOUT 180) + set_tests_properties(test_dist_mnist PROPERTIES TIMEOUT 200) + set_tests_properties(test_dist_word2vec PROPERTIES TIMEOUT 200) endif() py_test_modules(test_parallel_executor_crf MODULES test_parallel_executor_crf SERIAL) py_test_modules(test_parallel_executor_fetch_feed MODULES test_parallel_executor_fetch_feed SERIAL) diff --git a/python/paddle/fluid/tests/unittests/test_dist_train.py b/python/paddle/fluid/tests/unittests/test_dist_train.py index aab8969a96..55aa923f5a 100644 --- a/python/paddle/fluid/tests/unittests/test_dist_train.py +++ b/python/paddle/fluid/tests/unittests/test_dist_train.py @@ -26,6 +26,12 @@ from paddle.fluid.layers.io import ListenAndServ from paddle.fluid.layers.io import Recv from paddle.fluid.layers.io import Send +from paddle.fluid import core + +RPC_OP_ROLE_ATTR_NAME = op_role_attr_name = core.op_proto_and_checker_maker.kOpRoleAttrName( +) +RPC_OP_ROLE_ATTR_VALUE = core.op_proto_and_checker_maker.OpRole.RPC + class TestSendOp(unittest.TestCase): def test_send(self): @@ -89,18 +95,29 @@ class TestSendOp(unittest.TestCase): def init_client(self, place, port): main = fluid.Program() with fluid.program_guard(main): + main.global_block().append_op( + type="fetch_barrier", + inputs={}, + outputs={}, + attrs={ + "endpoints": ["127.0.0.1:{0}".format(port)], + RPC_OP_ROLE_ATTR_NAME: RPC_OP_ROLE_ATTR_VALUE + }) + x = layers.data( shape=[32, 32], dtype='float32', name='X', append_batch_size=False) fluid.initializer.Constant(value=2.3)(x, main.global_block()) + get_var = main.global_block().create_var( name="scale_0.tmp_0", # server side var dtype="float32", persistable=False, shape=[32, 32]) fluid.initializer.Constant(value=2.3)(get_var, main.global_block()) + Send("127.0.0.1:%d" % port, [x]) o = Recv("127.0.0.1:%d" % port, [get_var]) diff --git a/python/paddle/fluid/tests/unittests/test_dist_transpiler.py b/python/paddle/fluid/tests/unittests/test_dist_transpiler.py index 55f8b3eff8..124abf4ccd 100644 --- a/python/paddle/fluid/tests/unittests/test_dist_transpiler.py +++ b/python/paddle/fluid/tests/unittests/test_dist_transpiler.py @@ -18,6 +18,7 @@ import unittest import paddle.fluid as fluid from paddle.fluid.transpiler.distribute_transpiler import delete_ops import traceback +import collections class TranspilerTest(unittest.TestCase): @@ -53,9 +54,18 @@ class TranspilerTest(unittest.TestCase): self.origin_prog = main.clone() return main - def get_trainer(self, config=None, sync_mode=True): - t = self._transpiler_instance(config, sync_mode) - return t.get_trainer_program() + def get_trainer(self, config=None): + src = fluid.default_startup_program().clone() + + t = self._transpiler_instance(config) + + trainer_main = t.get_trainer_program() + trainer_startup = fluid.default_startup_program() 
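+ # transpile() has filled the default startup program with recv/concat ops in place.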
+ + assert (src.num_blocks == 1) + assert (trainer_startup.num_blocks == src.num_blocks) + + return trainer_main, trainer_startup def get_pserver(self, ep, config=None, sync_mode=True): t = self._transpiler_instance(config, sync_mode) @@ -91,7 +101,21 @@ class TestBasicModel(TranspilerTest): pserver, startup = self.get_pserver(self.pserver1_ep) pserver2, startup2 = self.get_pserver(self.pserver2_ep) - trainer = self.get_trainer() + trainer, trainer_startup = self.get_trainer() + + # splited var blocks should be in startup program + self.assertTrue("fc_w.block0" in trainer_startup.global_block().vars) + self.assertTrue("fc_w.block1" in trainer_startup.global_block().vars) + self.assertTrue("fc_w" in trainer_startup.global_block().vars) + self.assertTrue("fc_b" in trainer_startup.global_block().vars) + self.assertTrue("fc_w@GRAD" not in trainer_startup.global_block().vars) + self.assertTrue("fc_b@GRAD" not in trainer_startup.global_block().vars) + + src = [op.type for op in trainer_startup.global_block().ops] + dst = ['fill_constant', 'fill_constant', 'uniform_random', 'recv', 'recv', \ + 'fetch_barrier', 'concat'] + + self.assertEqual(src, dst) self.assertEqual([op.type for op in trainer.global_block().ops], [ 'mul', 'elementwise_add', 'elementwise_sub', 'square', 'mean', @@ -142,7 +166,7 @@ class TestBasicModelWithLargeBlockSize(TranspilerTest): pserver, startup = self.get_pserver(self.pserver1_ep, config) pserver2, startup2 = self.get_pserver(self.pserver2_ep, config) - trainer = self.get_trainer(config) + trainer, _ = self.get_trainer(config) self.assertEqual([op.type for op in trainer.global_block().ops], [ 'mul', 'elementwise_add', 'elementwise_sub', 'square', 'mean', @@ -226,7 +250,7 @@ class TestLRDecay(TranspilerTest): def transpiler_test_impl(self): pserver, startup = self.get_pserver(self.pserver1_ep) - trainer = self.get_trainer() + trainer, _ = self.get_trainer() self.assertEqual(len(pserver.blocks), 4) lr_decay_ops = [op.type for op in pserver.blocks[1].ops] @@ -256,7 +280,7 @@ class TestLRDecayConditional(TranspilerTest): def transpiler_test_impl(self): pserver, startup = self.get_pserver(self.pserver1_ep) - trainer = self.get_trainer() + trainer, _ = self.get_trainer() serv_op = pserver.blocks[0].ops[0] sub_blocks = [] @@ -305,7 +329,7 @@ class TestL2Decay(TranspilerTest): def transpiler_test_impl(self): pserver, startup = self.get_pserver(self.pserver1_ep) - trainer = self.get_trainer() + trainer, _ = self.get_trainer() self.assertEqual(len(pserver.blocks), 3) self.assertEqual([op.type for op in pserver.blocks[1].ops], @@ -340,7 +364,7 @@ class TestL2DecayWithPiecewise(TranspilerTest): def transpiler_test_impl(self): pserver, startup = self.get_pserver(self.pserver1_ep) - trainer = self.get_trainer() + trainer, _ = self.get_trainer() self.assertEqual(len(pserver.blocks), 9) self.assertEqual([op.type for op in pserver.blocks[1].ops], [ @@ -415,7 +439,7 @@ class TestLocalLookupTable(TestDistLookupTableBase): self.assertEqual([op.type for op in pserver1.blocks[2].ops], ["sum", "adam", "scale", "scale"]) - trainer = self.get_trainer() + trainer, _ = self.get_trainer() self.assertEqual(len(trainer.blocks), 1) ops = [ 'lookup_table', 'sequence_pool', 'lookup_table', 'sequence_pool', @@ -453,7 +477,7 @@ class TestDistLookupTable(TestDistLookupTableBase): # 5 save table self.assertEqual([op.type for op in pserver1.blocks[5].ops], ["save"]) - trainer = self.get_trainer() + trainer, _ = self.get_trainer() self.assertEqual(len(trainer.blocks), 1) ops = [ 'split_ids', 'prefetch', 
'merge_ids', 'sequence_pool', 'split_ids', @@ -486,7 +510,7 @@ class TestAsyncLocalLookupTable(TestDistLookupTableBase): self.assertEqual([op.type for op in pserver1.blocks[2].ops], ["adam", "scale", "scale"]) - trainer = self.get_trainer(config) + trainer, _ = self.get_trainer(config) self.assertEqual(len(trainer.blocks), 1) ops = [ 'lookup_table', 'sequence_pool', 'lookup_table', 'sequence_pool', @@ -525,7 +549,7 @@ class TestAsyncDistLookupTable(TestDistLookupTableBase): # 5 save table self.assertEqual([op.type for op in pserver1.blocks[5].ops], ["save"]) - trainer = self.get_trainer(config) + trainer, _ = self.get_trainer(config) self.assertEqual(len(trainer.blocks), 1) ops = [ 'split_ids', 'prefetch', 'merge_ids', 'sequence_pool', 'split_ids', diff --git a/python/paddle/fluid/transpiler/distribute_transpiler.py b/python/paddle/fluid/transpiler/distribute_transpiler.py index c97beea1b3..ce4709f23b 100644 --- a/python/paddle/fluid/transpiler/distribute_transpiler.py +++ b/python/paddle/fluid/transpiler/distribute_transpiler.py @@ -195,6 +195,9 @@ class DistributeTranspiler(object): if program is None: program = default_main_program() self.origin_program = program + self.origin_startup_program = default_startup_program().clone() + + self.startup_program = default_startup_program() self.trainer_num = trainers self.sync_mode = sync_mode self.trainer_id = trainer_id @@ -205,10 +208,10 @@ class DistributeTranspiler(object): ps_dispatcher = self.config.split_method(self.pserver_endpoints) self.has_distributed_lookup_table = self._has_distributed_lookup_table() - # split and create vars, then put splited vars in dicts for later use. + # step 1: split and create vars, then put splited vars in dicts for later use. self._init_splited_vars() - # step 3.1: insert send op to send gradient vars to parameter servers + # step 2: insert send op to send gradient vars to parameter servers ps_dispatcher.reset() send_vars = [] @@ -265,7 +268,7 @@ class DistributeTranspiler(object): RPC_OP_ROLE_ATTR_NAME: RPC_OP_ROLE_ATTR_VALUE }) - # step 3.2: insert recv op to receive parameters from parameter server + # step 3: insert recv op to receive parameters from parameter server recv_vars = [] for _, var in enumerate(send_vars): recv_vars.append(self.grad_param_mapping[var]) @@ -312,6 +315,8 @@ class DistributeTranspiler(object): outputs={"Out": [orig_param]}, attrs={"axis": 0}) + self._get_trainer_startup_program(recv_vars=recv_vars, eplist=eplist) + if self.has_distributed_lookup_table: self._replace_lookup_table_op_with_prefetch(program, pserver_endpoints) @@ -328,8 +333,78 @@ class DistributeTranspiler(object): # FIXME(typhoonzero): Also ops like clip_gradient, lrn_decay? delete_ops(self.origin_program.global_block(), self.optimize_ops) self.origin_program.__str__() + return self.origin_program + def _get_trainer_startup_program(self, + recv_vars, + eplist, + startup_program=None): + """ + Get transpiled trainer side startup program. + + Args: + startup_program(Program): Startup program. + + Returns: + Program: trainer side startup program. + """ + if startup_program is None: + startup_program = self.startup_program + + # FIXME(gongwb): delete unneeded ops. + # note that some parameters are not trainable and those ops can't be deleted.
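+ # For now every original startup op is kept; the recv/fetch_barrier/concat ops below are only appended.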
+ + for varname, splited_var in self.param_var_mapping.iteritems(): + # Get the eplist of recv vars + eps = [] + for var in splited_var: + index = [v.name for v in recv_vars].index(var.name) + eps.append(eplist[index]) + + for var in splited_var: + if startup_program.global_block().has_var(var.name): + continue + + startup_program.global_block().create_var( + name=var.name, + persistable=False, + type=var.type, + dtype=var.dtype, + shape=var.shape, + lod_level=var.lod_level) + + op = startup_program.global_block().append_op( + type="recv", + inputs={}, + outputs={"Out": splited_var}, + attrs={ + "epmap": eps, + RPC_OP_ROLE_ATTR_NAME: RPC_OP_ROLE_ATTR_VALUE + }) + + startup_program.global_block().append_op( + type="fetch_barrier", + inputs={}, + outputs={}, + attrs={ + "endpoints": self.pserver_endpoints, + RPC_OP_ROLE_ATTR_NAME: RPC_OP_ROLE_ATTR_VALUE + }) + + for varname, splited_var in self.param_var_mapping.iteritems(): + # add concat ops to merge splited parameters received from parameter servers. + if len(splited_var) <= 1: + continue + orig_param = startup_program.global_block().vars[varname] + startup_program.global_block().append_op( + type="concat", + inputs={"X": splited_var}, + outputs={"Out": [orig_param]}, + attrs={"axis": 0}) + + return startup_program + def get_pserver_program(self, endpoint): """ Get parameter server side program. @@ -576,14 +651,16 @@ class DistributeTranspiler(object): new_outputs = dict() # do not append startup op if var is not on this pserver op_on_pserver = False - for key in op.output_names: - newname, _ = _get_splited_name_and_shape(op.output(key)[0]) - if newname: - op_on_pserver = True - new_outputs[key] = created_var_map[newname] - elif op.output(key)[0] in pserver_vars: - op_on_pserver = True - new_outputs[key] = pserver_vars[op.output(key)[0]] + # TODO(gongwb): remove this line.
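+ # recv/fetch_barrier/concat ops come from the transpiled trainer startup program and have no pserver-side output vars, so they are skipped here.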
+ if op.type not in ["recv", "fetch_barrier", "concat"]: + for key in op.output_names: + newname, _ = _get_splited_name_and_shape(op.output(key)[0]) + if newname: + op_on_pserver = True + new_outputs[key] = created_var_map[newname] + elif op.output(key)[0] in pserver_vars: + op_on_pserver = True + new_outputs[key] = pserver_vars[op.output(key)[0]] if op_on_pserver: # most startup program ops have no inputs @@ -1022,7 +1099,6 @@ class DistributeTranspiler(object): var_mapping[varname] = \ [program.global_block().var(orig_var.name)] continue - var_mapping[varname] = [] orig_shape = orig_var.shape orig_dim1_flatten = 1 From 9f09d68678c66e4759ce0bffc338cae87d5ec9d5 Mon Sep 17 00:00:00 2001 From: tangwei12 Date: Wed, 15 Aug 2018 21:22:14 +0800 Subject: [PATCH 80/94] add enforce --- paddle/fluid/operators/sampling_id_op.cc | 4 ++++ paddle/fluid/operators/sampling_id_op.cu | 7 ++++++- 2 files changed, 10 insertions(+), 1 deletion(-) diff --git a/paddle/fluid/operators/sampling_id_op.cc b/paddle/fluid/operators/sampling_id_op.cc index e88310745f..ca7b246901 100644 --- a/paddle/fluid/operators/sampling_id_op.cc +++ b/paddle/fluid/operators/sampling_id_op.cc @@ -33,6 +33,10 @@ class SamplingIdKernel : public framework::OpKernel { const int batch_size = static_cast(input->dims()[0]); const int width = static_cast(input->dims()[1]); + PADDLE_ENFORCE_GE(batch_size, 0, + "batch_size(dims[0]) must be nonnegative."); + PADDLE_ENFORCE_GE(width, 0, "width(dims[1]) must be nonnegative."); + std::vector ins_vector; framework::TensorToVector(*input, context.device_context(), &ins_vector); diff --git a/paddle/fluid/operators/sampling_id_op.cu b/paddle/fluid/operators/sampling_id_op.cu index b104710374..114df044af 100644 --- a/paddle/fluid/operators/sampling_id_op.cu +++ b/paddle/fluid/operators/sampling_id_op.cu @@ -46,6 +46,10 @@ class SamplingIdGPUKernel : public framework::OpKernel { const int batch_size = static_cast(input->dims()[0]); const int width = static_cast(input->dims()[1]); + PADDLE_ENFORCE_GE(batch_size, 0, + "batch_size(dims[0]) must be nonnegative."); + PADDLE_ENFORCE_GE(width, 0, "width(dims[1]) must be nonnegative."); + std::vector ins_vector; framework::TensorToVector(*input, context.device_context(), &ins_vector); @@ -56,10 +60,11 @@ class SamplingIdGPUKernel : public framework::OpKernel { } T min = static_cast(context.Attr("min")); T max = static_cast(context.Attr("max")); + UniformGenerator gen = UniformGenerator(min, max, seed); std::vector ids(batch_size); for (size_t i = 0; i < batch_size; ++i) { - T r = UniformGenerator(min, max, seed); + T r = gen(0); int idx = width - 1; for (int j = 0; j < width; ++j) { if ((r -= ins_vector[i * width + j]) < 0) { From c108376506faa8c51f489a4c1e658a446424453a Mon Sep 17 00:00:00 2001 From: jerrywgz Date: Wed, 15 Aug 2018 22:38:25 +0800 Subject: [PATCH 81/94] Add three modes for prelu_op (#12630) * Add three modes for prelu_op. 
--- paddle/fluid/API.spec | 1 + paddle/fluid/operators/prelu_op.cc | 65 +++++++-- paddle/fluid/operators/prelu_op.cu | 22 --- paddle/fluid/operators/prelu_op.h | 125 ++++++++++-------- python/paddle/fluid/layers/nn.py | 54 ++++++++ .../fluid/tests/unittests/test_layers.py | 15 +++ .../fluid/tests/unittests/test_prelu_op.py | 56 ++++++-- 7 files changed, 237 insertions(+), 101 deletions(-) delete mode 100644 paddle/fluid/operators/prelu_op.cu diff --git a/paddle/fluid/API.spec b/paddle/fluid/API.spec index c020ff45ad..ea9105d79c 100644 --- a/paddle/fluid/API.spec +++ b/paddle/fluid/API.spec @@ -159,6 +159,7 @@ paddle.fluid.layers.relu ArgSpec(args=['x'], varargs=None, keywords=None, defaul paddle.fluid.layers.log ArgSpec(args=['x'], varargs=None, keywords=None, defaults=None) paddle.fluid.layers.crop ArgSpec(args=['x', 'shape', 'offsets', 'name'], varargs=None, keywords=None, defaults=(None, None, None)) paddle.fluid.layers.rank_loss ArgSpec(args=['label', 'left', 'right', 'name'], varargs=None, keywords=None, defaults=(None,)) +paddle.fluid.layers.prelu ArgSpec(args=['x', 'mode', 'param_attr', 'name'], varargs=None, keywords=None, defaults=(None, None)) paddle.fluid.layers.flatten ArgSpec(args=['x', 'axis', 'name'], varargs=None, keywords=None, defaults=(1, None)) paddle.fluid.layers.data ArgSpec(args=['name', 'shape', 'append_batch_size', 'dtype', 'lod_level', 'type', 'stop_gradient'], varargs=None, keywords=None, defaults=(True, 'float32', 0, VarType.LOD_TENSOR, True)) paddle.fluid.layers.open_recordio_file ArgSpec(args=['filename', 'shapes', 'lod_levels', 'dtypes', 'pass_num', 'for_parallel'], varargs=None, keywords=None, defaults=(1, True)) diff --git a/paddle/fluid/operators/prelu_op.cc b/paddle/fluid/operators/prelu_op.cc index db040509bc..23d9ea88f6 100644 --- a/paddle/fluid/operators/prelu_op.cc +++ b/paddle/fluid/operators/prelu_op.cc @@ -1,11 +1,8 @@ /* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved. - Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at - http://www.apache.org/licenses/LICENSE-2.0 - Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
@@ -26,14 +23,40 @@ class PReluOp : public framework::OperatorWithKernel { : OperatorWithKernel(type, inputs, outputs, attrs) {} void InferShape(framework::InferShapeContext *ctx) const override { + std::string mode = ctx->Attrs().Get("mode"); + + auto x_dim = ctx->GetInputDim("X"); PADDLE_ENFORCE(ctx->HasInput("X"), "Input(X) should not be null"); PADDLE_ENFORCE(ctx->HasInput("Alpha"), "Input(Alpha) should not be null"); - PADDLE_ENFORCE(product(ctx->GetInputDim("Alpha")) == 1, - "Size of weight Alpha must be one."); + PADDLE_ENFORCE(ctx->HasOutput("Out"), "Output(Out) should not be null"); - ctx->SetOutputDim("Out", ctx->GetInputDim("X")); + if (mode == "all") { + PADDLE_ENFORCE(product(ctx->GetInputDim("Alpha")) == 1, + "For mode 'all', size of weight Alpha must be one."); + } else if (mode == "channel") { + PADDLE_ENFORCE(product(ctx->GetInputDim("Alpha")) == x_dim[1], + "For channel-wise mode, size of weight Alpha must be " + "equal to the number of channels, should be %d", + x_dim[1]); + } else if (mode == "element") { + PADDLE_ENFORCE(product(ctx->GetInputDim("Alpha")) == product(x_dim), + "For element-wise mode, size of weight Alpha must be " + "equal to the number of input elements, should be %d", + product(x_dim)); + } else { + PADDLE_THROW("Unknown mode %s", mode); + } + ctx->SetOutputDim("Out", x_dim); ctx->ShareLoD("X", /*->*/ "Out"); } + + protected: + framework::OpKernelType GetExpectedKernelType( + const framework::ExecutionContext &ctx) const override { + return framework::OpKernelType( + framework::ToDataType(ctx.Input("X")->type()), + platform::CPUPlace()); + } }; class PReluOpMaker : public framework::OpProtoAndCheckerMaker { @@ -44,9 +67,7 @@ class PReluOpMaker : public framework::OpProtoAndCheckerMaker { AddOutput("Out", "The output tensor of prelu operator."); AddComment(R"DOC( PRelu Operator. - The equation is: - $$ f(x) = \begin{cases} @@ -54,11 +75,15 @@ f(x) = x, \qquad \text{if} \ x >= 0 \end{cases} $$ - The input `X` can carry the LoD (Level of Details) information, or not. And the output shares the LoD information with input `X`. - +There are three modes: + all: all elements share same weight + channel: elements in a channel share same weight + element: each element has a weight )DOC"); + AddAttr("mode", "The mode for inputs to share weights.") + .SetDefault("all"); } }; @@ -71,9 +96,23 @@ class PReluGradOp : public framework::OperatorWithKernel { PADDLE_ENFORCE(ctx->HasInput("X"), "Input(X) must not be null."); PADDLE_ENFORCE(ctx->HasInput(framework::GradVarName("Out")), "Input(Out@GRAD) should not be null"); - ctx->SetOutputDim(framework::GradVarName("X"), ctx->GetInputDim("X")); - ctx->SetOutputDim(framework::GradVarName("Alpha"), - ctx->GetInputDim("Alpha")); + auto x_grad_name = framework::GradVarName("X"); + auto alpha_grad_name = framework::GradVarName("Alpha"); + + if (ctx->HasOutput(x_grad_name)) { + ctx->SetOutputDim(x_grad_name, ctx->GetInputDim("X")); + } + if (ctx->HasOutput(alpha_grad_name)) { + ctx->SetOutputDim(alpha_grad_name, ctx->GetInputDim("Alpha")); + } + } + + protected: + framework::OpKernelType GetExpectedKernelType( + const framework::ExecutionContext &ctx) const override { + return framework::OpKernelType( + framework::ToDataType(ctx.Input("X")->type()), + platform::CPUPlace()); } }; diff --git a/paddle/fluid/operators/prelu_op.cu b/paddle/fluid/operators/prelu_op.cu deleted file mode 100644 index 37d934a290..0000000000 --- a/paddle/fluid/operators/prelu_op.cu +++ /dev/null @@ -1,22 +0,0 @@ -/* Copyright (c) 2016 PaddlePaddle Authors.
All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. */ - -#include "paddle/fluid/operators/prelu_op.h" - -REGISTER_OP_CUDA_KERNEL( - prelu, - paddle::operators::PReluKernel); -REGISTER_OP_CUDA_KERNEL(prelu_grad, - paddle::operators::PReluGradKernel< - paddle::platform::CUDADeviceContext, float>); diff --git a/paddle/fluid/operators/prelu_op.h b/paddle/fluid/operators/prelu_op.h index a6197d3548..f9076cbc67 100644 --- a/paddle/fluid/operators/prelu_op.h +++ b/paddle/fluid/operators/prelu_op.h @@ -1,11 +1,8 @@ /* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved. - Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at - http://www.apache.org/licenses/LICENSE-2.0 - Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. @@ -13,32 +10,16 @@ See the License for the specific language governing permissions and limitations under the License. */ #pragma once +#include #include "paddle/fluid/framework/eigen.h" #include "paddle/fluid/framework/op_registry.h" #include "paddle/fluid/platform/transform.h" - namespace paddle { namespace operators { using Tensor = framework::Tensor; using platform::Transform; -template -class PReluFunctor { - public: - explicit PReluFunctor(const T* alpha) : alpha_(alpha) {} - - HOSTDEVICE T operator()(const T& x) const { - if (x > 0) - return x; - else - return x * (*alpha_); - } - - private: - const T* alpha_; -}; - template class PReluKernel : public framework::OpKernel { public: @@ -50,53 +31,93 @@ class PReluKernel : public framework::OpKernel { const T* x_ptr = x->data(); T* o_ptr = out->mutable_data(context.GetPlace()); - auto* alpha_ptr = alpha->data(); + const T* alpha_ptr = alpha->data(); + std::string mode = context.Attr("mode"); int numel = x->numel(); - - Transform trans; - trans(context.template device_context(), x_ptr, - x_ptr + numel, o_ptr, PReluFunctor(alpha_ptr)); - } -}; - -template -class PReluGradFunctor { - public: - explicit PReluGradFunctor(const T* alpha) : alpha_(alpha) {} - - HOSTDEVICE T operator()(const T& out, const T& dout) const { - if (out > 0) - return dout; - else - return dout * (*alpha_); + auto dim = x->dims(); + int index = 0; + int i = 0; + int temp = 0; + if (mode == "channel") { + for (i = 0; i < numel; i++) { + temp = numel / (dim[0] * dim[1]); + index = (i / temp) % dim[1]; + o_ptr[i] = x_ptr[i] > 0 ? x_ptr[i] : alpha_ptr[index] * x_ptr[i]; + } + } else if (mode == "element") { + for (i = 0; i < numel; i++) { + o_ptr[i] = x_ptr[i] > 0 ? x_ptr[i] : alpha_ptr[i] * x_ptr[i]; + } + } else { + for (i = 0; i < numel; i++) { + o_ptr[i] = x_ptr[i] > 0 ? 
x_ptr[i] : alpha_ptr[0] * x_ptr[i]; + } + } + } +}; + +template +class PReluGradKernel : public framework::OpKernel { + public: + void Compute(const framework::ExecutionContext& context) const override { + auto* x = context.Input("X"); auto* dx = context.Output(framework::GradVarName("X")); auto* dout = context.Input(framework::GradVarName("Out")); - + auto* dalpha = context.Output(framework::GradVarName("Alpha")); auto* out = context.Input("Out"); auto* alpha = context.Input("Alpha"); - auto* alpha_ptr = alpha->data(); - - T* dx_ptr = dx->mutable_data(context.GetPlace()); + const T* alpha_ptr = alpha->data(); + const T* x_ptr = x->data(); const T* dout_ptr = dout->data(); const T* out_ptr = out->data(); - int numel = dx->numel(); - - Transform trans; - trans(context.template device_context(), out_ptr, - out_ptr + numel, dout_ptr, dx_ptr, PReluGradFunctor(alpha_ptr)); - - // TODO(Zhuoyuan): add dalpha upgrade when GPU kernels ready + std::string mode = context.Attr("mode"); + int numel = x->numel(); + auto dim = x->dims(); + int index = 0; + int i = 0; + int temp = 0; + if (dx) { + T* dx_ptr = dx->mutable_data(context.GetPlace()); + if (mode == "channel") { + for (i = 0; i < numel; i++) { + temp = numel / (dim[0] * dim[1]); + index = (i / temp) % dim[1]; + dx_ptr[i] = + out_ptr[i] > 0 ? dout_ptr[i] : alpha_ptr[index] * dout_ptr[i]; + } + } else if (mode == "element") { + for (i = 0; i < numel; i++) { + dx_ptr[i] = out_ptr[i] > 0 ? dout_ptr[i] : alpha_ptr[i] * dout_ptr[i]; + } + } else { + for (i = 0; i < numel; i++) { + dx_ptr[i] = out_ptr[i] > 0 ? dout_ptr[i] : alpha_ptr[0] * dout_ptr[i]; + } + } + } + + index = 0; + if (dalpha) { + T* dalpha_ptr = dalpha->mutable_data(context.GetPlace()); + if (mode == "channel") { + for (i = 0; i < numel; i++) { + temp = numel / (dim[0] * dim[1]); + index = (i / temp) % dim[1]; + dalpha_ptr[index] += out_ptr[i] > 0 ? 0 : x_ptr[i] * dout_ptr[i]; + } + } else if (mode == "element") { + for (i = 0; i < numel; i++) { + dalpha_ptr[i] += out_ptr[i] > 0 ? 0 : x_ptr[i] * dout_ptr[i]; + } + } else { + for (i = 0; i < numel; i++) { + dalpha_ptr[0] += out_ptr[i] > 0 ? 0 : x_ptr[i] * dout_ptr[i]; + } + } + } + + // TODO(Guanzhong): add GPU kernels + } +}; diff --git a/python/paddle/fluid/layers/nn.py b/python/paddle/fluid/layers/nn.py index c75e7eeb43..3e50fc91d9 100644 --- a/python/paddle/fluid/layers/nn.py +++ b/python/paddle/fluid/layers/nn.py @@ -112,6 +112,7 @@ __all__ = [ 'log', 'crop', 'rank_loss', + 'prelu', 'flatten', ] @@ -5364,6 +5365,59 @@ def rank_loss(label, left, right, name=None): return out + +def prelu(x, mode, param_attr=None, name=None): + """ + Equation: + + y = \max(0, x) + alpha \min(0, x) + + Args: + x (Variable): The input tensor. + param_attr(ParamAttr|None): The parameter attribute for the learnable + weight (alpha). + mode (string): The mode for weight sharing + all: all elements share same weight + channel: elements in a channel share same weight + element: each element has a weight + name(str|None): A name for this layer(optional). If set None, the layer + will be named automatically. + + Returns: + Variable: The output tensor with the same shape as input. + + Examples: + + ..
code-block:: python + + x = fluid.layers.data(name="x", shape=[10,10], dtype="float32") + mode = 'channel' + output = fluid.layers.prelu(x,mode) + """ + helper = LayerHelper('prelu', **locals()) + if mode not in ['all', 'channel', 'element']: + raise ValueError('mode should be one of all, channel, element.') + alpha_shape = [1] + if mode == 'channel': + alpha_shape = [1, x.shape[1], 1, 1] + elif mode == 'element': + alpha_shape = x.shape + dtype = helper.input_dtype(input_param_name='x') + alpha = helper.create_parameter( + attr=param_attr, + shape=alpha_shape, + dtype='float32', + is_bias=False, + default_initializer=Constant(1.0)) + out = helper.create_tmp_variable(dtype) + helper.append_op( + type="prelu", + inputs={"X": x, + 'Alpha': alpha}, + attrs={"mode": mode}, + outputs={"Out": out}) + return out + + def flatten(x, axis=1, name=None): """ **Flatten layer** diff --git a/python/paddle/fluid/tests/unittests/test_layers.py b/python/paddle/fluid/tests/unittests/test_layers.py index 38a138a8fa..07fd0575d3 100644 --- a/python/paddle/fluid/tests/unittests/test_layers.py +++ b/python/paddle/fluid/tests/unittests/test_layers.py @@ -21,6 +21,7 @@ import paddle.fluid.nets as nets from paddle.fluid.framework import Program, program_guard, default_main_program from paddle.fluid.param_attr import ParamAttr import decorators +from paddle.fluid.initializer import Constant class TestBook(unittest.TestCase): @@ -485,6 +486,20 @@ class TestBook(unittest.TestCase): self.assertIsNotNone(out) print(str(program)) + def test_prelu(self): + program = Program() + with program_guard(program): + input = layers.data( + name="input", shape=[5, 200, 100, 100], dtype="float32") + mode = 'channel' + out = layers.prelu( + input, + mode, + param_attr=ParamAttr(initializer=Constant(1.0)), + name='prelu') + self.assertIsNotNone(out) + print(str(program)) + if __name__ == '__main__': unittest.main() diff --git a/python/paddle/fluid/tests/unittests/test_prelu_op.py b/python/paddle/fluid/tests/unittests/test_prelu_op.py index ae19a553bb..cb7de3fc93 100644 --- a/python/paddle/fluid/tests/unittests/test_prelu_op.py +++ b/python/paddle/fluid/tests/unittests/test_prelu_op.py @@ -20,30 +20,58 @@ from op_test import OpTest class PReluTest(OpTest): def setUp(self): self.op_type = "prelu" - x_np = np.random.normal(size=(10, 10)).astype("float32") - - for pos, val in np.ndenumerate(x_np): - # Since zero point in prelu is not differentiable, avoid randomize - # zero. - while abs(val) < 1e-3: - x_np[pos] = np.random.normal() - val = x_np[pos] - - x_np_sign = np.sign(x_np) - x_np = x_np_sign * np.maximum(x_np, .005) - alpha_np = np.array([.1], dtype="float32") - self.inputs = {'X': x_np, 'Alpha': alpha_np} + self.initTestCase() + x_np = np.random.normal(size=(3, 5, 5, 10)).astype("float32") + + # Since zero point in prelu is not differentiable, avoid randomize + # zero. + x_np[np.abs(x_np) < 0.005] = 0.02 + + if self.attrs == {'mode': "all"}: + alpha_np = np.random.rand(1).astype("float32") + self.inputs = {'X': x_np, 'Alpha': alpha_np} + elif self.attrs == {'mode': "channel"}: + alpha_np = np.random.rand(1, x_np.shape[1], 1, 1).astype("float32") + self.inputs = {'X': x_np, 'Alpha': alpha_np} + else: + alpha_np = np.random.rand(*x_np.shape).astype("float32") + self.inputs = {'X': x_np, 'Alpha': alpha_np} + out_np = np.maximum(self.inputs['X'], 0.) out_np = out_np + np.minimum(self.inputs['X'], 0.) 
* self.inputs['Alpha'] assert out_np is not self.inputs['X'] self.outputs = {'Out': out_np} + def initTestCase(self): + self.attrs = {'mode': "channel"} + def test_check_output(self): self.check_output() def test_check_grad(self): - self.check_grad(['X'], 'Out') + self.check_grad(['X', 'Alpha'], 'Out') + + def test_check_grad_ignore_x(self): + self.check_grad(['Alpha'], 'Out', no_grad_set=set('X')) + + def test_check_grad_ignore_alpha(self): + self.check_grad(['X'], 'Out', no_grad_set=set('Alpha')) + + +class TestCase1(PReluTest): + def initTestCase(self): + self.attrs = {'mode': "all"} + + +class TestCase2(PReluTest): + def initTestCase(self): + self.attrs = {'mode': "channel"} + + +class TestCase3(PReluTest): + def initTestCase(self): + self.attrs = {'mode': "element"} if __name__ == "__main__": From 99d3f089201f6967378d2d97b9f0b57ab3bc5a45 Mon Sep 17 00:00:00 2001 From: minqiyang Date: Wed, 15 Aug 2018 22:58:15 +0800 Subject: [PATCH 82/94] Add print_function for all python files --- python/paddle/dataset/cifar.py | 2 + python/paddle/dataset/common.py | 2 + python/paddle/dataset/conll05.py | 2 + python/paddle/dataset/flowers.py | 3 ++ python/paddle/dataset/image.py | 3 ++ python/paddle/dataset/imdb.py | 2 + python/paddle/dataset/imikolov.py | 3 ++ python/paddle/dataset/mnist.py | 3 ++ python/paddle/dataset/movielens.py | 2 + python/paddle/dataset/mq2007.py | 2 + python/paddle/dataset/sentiment.py | 2 + python/paddle/dataset/tests/cifar_test.py | 2 + python/paddle/dataset/tests/common_test.py | 2 + python/paddle/dataset/tests/flowers_test.py | 2 + python/paddle/dataset/tests/imdb_test.py | 2 + python/paddle/dataset/tests/imikolov_test.py | 2 + python/paddle/dataset/tests/mnist_test.py | 2 + python/paddle/dataset/tests/mq2007_test.py | 2 + python/paddle/dataset/tests/test_image.py | 2 + python/paddle/dataset/tests/test_sentiment.py | 2 + python/paddle/dataset/tests/voc2012_test.py | 2 + python/paddle/dataset/tests/wmt16_test.py | 2 + python/paddle/dataset/uci_housing.py | 2 +- python/paddle/dataset/voc2012.py | 2 + python/paddle/dataset/wmt14.py | 3 ++ python/paddle/dataset/wmt16.py | 2 + python/paddle/fluid/average.py | 2 + python/paddle/fluid/backward.py | 2 + python/paddle/fluid/clip.py | 2 + python/paddle/fluid/concurrency.py | 2 + python/paddle/fluid/contrib/__init__.py | 2 + .../paddle/fluid/contrib/decoder/__init__.py | 2 + .../contrib/decoder/beam_search_decoder.py | 2 + .../paddle/fluid/contrib/memory_usage_calc.py | 2 + python/paddle/fluid/data_feeder.py | 2 + python/paddle/fluid/debugger.py | 2 + python/paddle/fluid/default_scope_funcs.py | 2 + python/paddle/fluid/evaluator.py | 2 + python/paddle/fluid/executor.py | 2 + python/paddle/fluid/framework.py | 2 + python/paddle/fluid/graphviz.py | 2 + python/paddle/fluid/inferencer.py | 2 + python/paddle/fluid/initializer.py | 2 + python/paddle/fluid/io.py | 2 + python/paddle/fluid/layer_helper.py | 2 + python/paddle/fluid/layers/__init__.py | 2 + python/paddle/fluid/layers/control_flow.py | 2 + python/paddle/fluid/layers/detection.py | 2 + python/paddle/fluid/layers/device.py | 2 + python/paddle/fluid/layers/io.py | 2 + .../fluid/layers/layer_function_generator.py | 2 + .../fluid/layers/learning_rate_scheduler.py | 12 +++--- python/paddle/fluid/layers/math_op_patch.py | 2 + python/paddle/fluid/layers/metric_op.py | 18 +++++---- python/paddle/fluid/layers/nn.py | 38 +++++++------------ python/paddle/fluid/layers/ops.py | 2 + python/paddle/fluid/layers/tensor.py | 2 + python/paddle/fluid/layers/utils.py | 2 + 
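Stepping back to the prelu tests just above: each check_grad call compares the operator's analytic gradient against a numerically estimated one, and no_grad_set excludes an input from that comparison. The numeric side rests on central finite differences; a minimal sketch of the idea (a hypothetical helper, not OpTest's actual implementation):

    import numpy as np

    def numeric_grad(f, x, eps=1e-4):
        # Estimate df/dx for a scalar-valued f by central differences.
        grad = np.zeros_like(x)
        it = np.nditer(x, flags=["multi_index"])
        while not it.finished:
            idx = it.multi_index
            orig = x[idx]
            x[idx] = orig + eps
            hi = f(x)
            x[idx] = orig - eps
            lo = f(x)
            x[idx] = orig                    # restore before the next element
            grad[idx] = (hi - lo) / (2.0 * eps)
            it.iternext()
        return grad

    # e.g., with the prelu_ref sketch above:
    # numeric_grad(lambda t: prelu_ref(t, alpha, "all").sum(), x)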
python/paddle/fluid/lod_tensor.py | 2 + python/paddle/fluid/metrics.py | 3 ++ python/paddle/fluid/net_drawer.py | 2 + python/paddle/fluid/nets.py | 2 + python/paddle/fluid/op.py | 2 + python/paddle/fluid/optimizer.py | 2 + python/paddle/fluid/param_attr.py | 2 + python/paddle/fluid/profiler.py | 2 + python/paddle/fluid/recordio_writer.py | 2 + python/paddle/fluid/regularizer.py | 2 + .../fit_a_line/test_fit_a_line.py | 2 + .../cifar10_small_test_set.py | 2 + .../test_image_classification_resnet.py | 2 + .../test_image_classification_vgg.py | 2 + .../test_label_semantic_roles_newapi.py | 2 + .../test_machine_translation.py | 2 + .../test_recognize_digits_conv.py | 2 + .../test_recognize_digits_mlp.py | 2 + .../test_recommender_system_newapi.py | 2 + .../test_understand_sentiment_conv.py | 2 + .../test_understand_sentiment_dynamic_rnn.py | 2 + .../test_understand_sentiment_stacked_lstm.py | 2 + .../word2vec/test_word2vec_new_api.py | 2 + .../tests/book/notest_understand_sentiment.py | 2 + .../fluid/tests/book/test_fit_a_line.py | 2 + .../tests/book/test_image_classification.py | 2 + .../tests/book/test_label_semantic_roles.py | 2 + .../tests/book/test_machine_translation.py | 2 + .../fluid/tests/book/test_recognize_digits.py | 2 + .../tests/book/test_recommender_system.py | 2 + .../tests/book/test_rnn_encoder_decoder.py | 2 + .../paddle/fluid/tests/book/test_word2vec.py | 2 + .../test_memopt_fit_a_line.py | 2 + .../test_memopt_image_classification_train.py | 2 + .../test_memopt_machine_translation.py | 2 + python/paddle/fluid/tests/demo/fc_gan.py | 2 + .../file_reader/convert_data_to_recordio.py | 2 + .../fluid/tests/demo/file_reader/train.py | 2 + python/paddle/fluid/tests/demo/pyreader.py | 2 + .../paddle/fluid/tests/no_test_concurrency.py | 2 + .../paddle/fluid/tests/notest_concurrency.py | 2 + .../fluid/tests/test_beam_search_decoder.py | 2 + python/paddle/fluid/tests/test_cpp_reader.py | 2 + python/paddle/fluid/tests/test_data_feeder.py | 2 + python/paddle/fluid/tests/test_detection.py | 2 + python/paddle/fluid/tests/test_error_clip.py | 2 + .../paddle/fluid/tests/test_gradient_clip.py | 2 + python/paddle/fluid/tests/test_if_else_op.py | 2 + python/paddle/fluid/tests/test_lod_tensor.py | 2 + .../tests/test_python_operator_overriding.py | 2 + .../paddle/fluid/tests/unittests/benchmark.py | 2 + .../fluid/tests/unittests/benchmark_sum_op.py | 2 + .../fluid/tests/unittests/decorators.py | 2 + .../fluid/tests/unittests/dist_mnist.py | 2 + .../fluid/tests/unittests/dist_se_resnext.py | 2 + .../fluid/tests/unittests/dist_transformer.py | 2 + .../fluid/tests/unittests/dist_word2vec.py | 2 + .../paddle/fluid/tests/unittests/op_test.py | 2 + .../unittests/parallel_executor_test_base.py | 2 + .../fluid/tests/unittests/test_accuracy_op.py | 2 + .../unittests/test_activation_mkldnn_op.py | 2 + .../tests/unittests/test_activation_op.py | 2 + .../fluid/tests/unittests/test_adadelta_op.py | 2 + .../fluid/tests/unittests/test_adagrad_op.py | 2 + .../fluid/tests/unittests/test_adam_op.py | 2 + .../fluid/tests/unittests/test_adamax_op.py | 2 + .../unittests/test_anchor_generator_op.py | 2 + .../tests/unittests/test_arg_min_max_op.py | 2 + .../fluid/tests/unittests/test_argsort_op.py | 2 + .../unittests/test_array_read_write_op.py | 2 + .../fluid/tests/unittests/test_assign_op.py | 2 + .../tests/unittests/test_assign_value_op.py | 2 + .../fluid/tests/unittests/test_auc_op.py | 2 + .../unittests/test_batch_norm_mkldnn_op.py | 2 + .../tests/unittests/test_batch_norm_op.py | 2 + 
.../unittests/test_beam_search_decode_op.py | 2 + .../tests/unittests/test_beam_search_op.py | 2 + .../unittests/test_bilinear_interp_op.py | 2 + .../test_bilinear_tensor_product_op.py | 2 + .../unittests/test_bipartite_match_op.py | 2 + .../tests/unittests/test_box_coder_op.py | 2 + .../tests/unittests/test_calc_gradient.py | 2 + .../fluid/tests/unittests/test_cast_op.py | 2 + .../tests/unittests/test_chunk_eval_op.py | 2 + .../tests/unittests/test_clip_by_norm_op.py | 2 + .../fluid/tests/unittests/test_clip_op.py | 2 + .../fluid/tests/unittests/test_compare_op.py | 2 + .../fluid/tests/unittests/test_compat.py | 2 + .../fluid/tests/unittests/test_concat_op.py | 2 + .../tests/unittests/test_conditional_block.py | 2 + .../fluid/tests/unittests/test_const_value.py | 2 + .../tests/unittests/test_conv2d_mkldnn_op.py | 2 + .../fluid/tests/unittests/test_conv2d_op.py | 2 + .../unittests/test_conv2d_transpose_op.py | 2 + .../fluid/tests/unittests/test_conv3d_op.py | 2 + .../unittests/test_conv3d_transpose_op.py | 2 + .../tests/unittests/test_conv_shift_op.py | 2 + .../fluid/tests/unittests/test_cos_sim_op.py | 2 + .../unittests/test_create_op_doc_string.py | 2 + .../tests/unittests/test_crf_decoding_op.py | 2 + .../fluid/tests/unittests/test_crop_op.py | 2 + .../tests/unittests/test_cross_entropy_op.py | 2 + .../fluid/tests/unittests/test_ctc_align.py | 2 + .../fluid/tests/unittests/test_cumsum_op.py | 2 + .../tests/unittests/test_data_balance.py | 2 + .../fluid/tests/unittests/test_debugger.py | 2 + .../unittests/test_decayed_adagrad_op.py | 2 + .../unittests/test_default_scope_funcs.py | 2 + .../fluid/tests/unittests/test_desc_clone.py | 2 + .../tests/unittests/test_detection_map_op.py | 2 + .../fluid/tests/unittests/test_dist_base.py | 2 + .../fluid/tests/unittests/test_dist_mnist.py | 2 + .../tests/unittests/test_dist_se_resnext.py | 2 + .../fluid/tests/unittests/test_dist_train.py | 2 + .../tests/unittests/test_dist_transformer.py | 2 + .../tests/unittests/test_dist_transpiler.py | 2 + .../tests/unittests/test_dist_word2vec.py | 2 + .../fluid/tests/unittests/test_dropout_op.py | 2 + .../fluid/tests/unittests/test_dyn_rnn.py | 2 + .../unittests/test_dynrnn_gradient_check.py | 2 + .../unittests/test_dynrnn_static_input.py | 2 + .../tests/unittests/test_edit_distance_op.py | 2 + .../test_elementwise_add_mkldnn_op.py | 2 + .../unittests/test_elementwise_add_op.py | 2 + .../unittests/test_elementwise_div_op.py | 2 + .../unittests/test_elementwise_gradient_op.py | 2 + .../unittests/test_elementwise_max_op.py | 2 + .../unittests/test_elementwise_min_op.py | 2 + .../unittests/test_elementwise_mul_op.py | 2 + .../unittests/test_elementwise_pow_op.py | 2 + .../unittests/test_elementwise_sub_op.py | 2 + .../fluid/tests/unittests/test_exception.py | 2 + .../tests/unittests/test_executor_and_mul.py | 2 + .../fluid/tests/unittests/test_expand_op.py | 2 + .../tests/unittests/test_extract_rows_op.py | 2 + .../unittests/test_fake_dequantize_op.py | 2 + .../tests/unittests/test_fake_quantize_op.py | 2 + .../tests/unittests/test_fc_mkldnn_op.py | 2 + .../tests/unittests/test_feed_fetch_method.py | 2 + .../fluid/tests/unittests/test_fetch_var.py | 2 + .../test_fill_constant_batch_size_like_op.py | 2 + .../tests/unittests/test_fill_constant_op.py | 2 + .../fluid/tests/unittests/test_fill_op.py | 2 + .../unittests/test_fill_zeros_like_op.py | 2 + .../fluid/tests/unittests/test_flatten_op.py | 2 + .../unittests/test_framework_debug_str.py | 2 + .../fluid/tests/unittests/test_ftrl_op.py | 2 + 
.../test_fused_elemwise_activation_op.py | 2 + .../fluid/tests/unittests/test_gather_op.py | 2 + ...test_gaussian_random_batch_size_like_op.py | 2 + .../test_gaussian_random_mkldnn_op.py | 2 + .../unittests/test_gaussian_random_op.py | 2 + .../tests/unittests/test_get_places_op.py | 2 + .../fluid/tests/unittests/test_gru_op.py | 2 + .../fluid/tests/unittests/test_gru_unit_op.py | 2 + .../tests/unittests/test_hinge_loss_op.py | 2 + .../fluid/tests/unittests/test_hsigmoid_op.py | 2 + .../tests/unittests/test_huber_loss_op.py | 2 + .../tests/unittests/test_im2sequence_op.py | 2 + .../test_image_classification_layer.py | 2 + .../fluid/tests/unittests/test_infer_shape.py | 2 + .../unittests/test_inference_model_io.py | 2 + .../fluid/tests/unittests/test_initializer.py | 2 + .../tests/unittests/test_iou_similarity_op.py | 2 + .../fluid/tests/unittests/test_is_empty_op.py | 2 + .../fluid/tests/unittests/test_l1_norm_op.py | 2 + .../tests/unittests/test_label_smooth_op.py | 2 + .../tests/unittests/test_layer_norm_op.py | 2 + .../unittests/test_learning_rate_scheduler.py | 2 + .../unittests/test_linear_chain_crf_op.py | 2 + .../unittests/test_listen_and_serv_op.py | 2 + .../unittests/test_lod_array_length_op.py | 2 + .../tests/unittests/test_lod_rank_table.py | 2 + .../tests/unittests/test_lod_reset_op.py | 2 + .../tests/unittests/test_lod_tensor_array.py | 2 + .../unittests/test_lod_tensor_array_ops.py | 2 + .../fluid/tests/unittests/test_log_loss_op.py | 2 + .../fluid/tests/unittests/test_logical_op.py | 2 + .../unittests/test_lookup_sparse_table_op.py | 2 + .../tests/unittests/test_lookup_table_op.py | 2 + .../tests/unittests/test_lrn_mkldnn_op.py | 2 + .../fluid/tests/unittests/test_lrn_op.py | 2 + .../fluid/tests/unittests/test_lstm_op.py | 2 + .../tests/unittests/test_lstm_unit_op.py | 2 + .../fluid/tests/unittests/test_lstmp_op.py | 2 + .../unittests/test_margin_rank_loss_op.py | 2 + .../tests/unittests/test_math_op_patch.py | 2 + .../fluid/tests/unittests/test_matmul_op.py | 2 + .../fluid/tests/unittests/test_maxout_op.py | 2 + .../fluid/tests/unittests/test_mean_iou.py | 2 + .../fluid/tests/unittests/test_mean_op.py | 2 + .../tests/unittests/test_memory_usage.py | 2 +- .../tests/unittests/test_merge_ids_op.py | 2 + .../unittests/test_mine_hard_examples_op.py | 2 + .../fluid/tests/unittests/test_minus_op.py | 2 + .../unittests/test_modified_huber_loss_op.py | 2 + .../fluid/tests/unittests/test_momentum_op.py | 2 + .../fluid/tests/unittests/test_mul_op.py | 2 + .../tests/unittests/test_multi_file_reader.py | 2 + .../tests/unittests/test_multi_pass_reader.py | 2 + .../tests/unittests/test_multiclass_nms_op.py | 2 + .../unittests/test_multihead_attention.py | 2 + .../tests/unittests/test_multiplex_op.py | 2 + .../paddle/fluid/tests/unittests/test_nce.py | 2 + .../unittests/test_network_with_dtype.py | 2 + .../fluid/tests/unittests/test_norm_op.py | 2 + .../unittests/test_normalization_wrapper.py | 2 + .../fluid/tests/unittests/test_nvprof.py | 2 + .../fluid/tests/unittests/test_one_hot_op.py | 2 + .../tests/unittests/test_op_support_gpu.py | 2 + .../fluid/tests/unittests/test_operator.py | 2 + .../tests/unittests/test_operator_desc.py | 2 + .../fluid/tests/unittests/test_optimizer.py | 2 + .../fluid/tests/unittests/test_pad_op.py | 2 + .../unittests/test_parallel_executor_crf.py | 2 + .../test_parallel_executor_fetch_feed.py | 2 + .../unittests/test_parallel_executor_mnist.py | 2 + .../test_parallel_executor_seresnext.py | 2 + ...test_parallel_executor_test_while_train.py | 2 + 
.../test_parallel_executor_transformer.py | 2 + .../fluid/tests/unittests/test_parallel_op.py | 2 + .../fluid/tests/unittests/test_parameter.py | 2 + .../unittests/test_polygon_box_transform.py | 2 + .../tests/unittests/test_pool2d_mkldnn_op.py | 2 + .../fluid/tests/unittests/test_pool2d_op.py | 2 + .../fluid/tests/unittests/test_pool3d_op.py | 2 + .../fluid/tests/unittests/test_pool_max_op.py | 2 + .../test_positive_negative_pair_op.py | 2 + .../unittests/test_precision_recall_op.py | 2 + .../fluid/tests/unittests/test_prelu_op.py | 2 + .../tests/unittests/test_preprocessor.py | 2 + .../fluid/tests/unittests/test_print_op.py | 2 + .../tests/unittests/test_prior_box_op.py | 2 + .../fluid/tests/unittests/test_profiler.py | 2 + .../fluid/tests/unittests/test_protobuf.py | 2 + .../tests/unittests/test_protobuf_descs.py | 2 + .../unittests/test_proximal_adagrad_op.py | 2 + .../tests/unittests/test_proximal_gd_op.py | 2 + .../unittests/test_py_reader_push_pop.py | 2 + .../test_py_reader_using_executor.py | 2 + .../tests/unittests/test_random_crop_op.py | 2 + .../tests/unittests/test_rank_loss_op.py | 2 + .../tests/unittests/test_reader_reset.py | 2 + .../tests/unittests/test_recordio_reader.py | 2 + .../tests/unittests/test_recurrent_op.py | 2 + .../fluid/tests/unittests/test_reduce_op.py | 2 + .../fluid/tests/unittests/test_registry.py | 2 + .../fluid/tests/unittests/test_regularizer.py | 2 + .../unittests/test_reorder_lod_tensor.py | 2 + .../fluid/tests/unittests/test_reshape_op.py | 2 + .../fluid/tests/unittests/test_reverse_op.py | 2 + .../fluid/tests/unittests/test_rmsprop_op.py | 2 + .../unittests/test_rnn_memory_helper_op.py | 2 + .../fluid/tests/unittests/test_roi_pool_op.py | 2 + .../fluid/tests/unittests/test_row_conv_op.py | 2 + .../unittests/test_rpn_target_assign_op.py | 2 + .../fluid/tests/unittests/test_scale_op.py | 2 + .../fluid/tests/unittests/test_scatter_op.py | 2 + .../fluid/tests/unittests/test_scope.py | 2 + .../tests/unittests/test_selected_rows.py | 2 + .../tests/unittests/test_seq_concat_op.py | 2 + .../fluid/tests/unittests/test_seq_conv.py | 2 + .../fluid/tests/unittests/test_seq_pool.py | 2 + .../tests/unittests/test_sequence_erase_op.py | 2 + .../tests/unittests/test_sequence_expand.py | 2 + .../tests/unittests/test_sequence_reshape.py | 2 + .../tests/unittests/test_sequence_slice_op.py | 2 + .../unittests/test_sequence_softmax_op.py | 2 + .../fluid/tests/unittests/test_sgd_op.py | 2 + .../fluid/tests/unittests/test_shape_op.py | 2 + .../tests/unittests/test_shrink_rnn_memory.py | 2 + ...st_sigmoid_cross_entropy_with_logits_op.py | 2 + .../fluid/tests/unittests/test_sign_op.py | 2 + .../fluid/tests/unittests/test_slice_op.py | 2 + .../fluid/tests/unittests/test_slice_var.py | 2 + .../tests/unittests/test_smooth_l1_loss_op.py | 2 + .../fluid/tests/unittests/test_softmax_op.py | 2 + .../test_softmax_with_cross_entropy_op.py | 2 + .../test_split_and_merge_lod_tensor_op.py | 2 + .../tests/unittests/test_split_ids_op.py | 2 + .../fluid/tests/unittests/test_split_op.py | 2 + .../unittests/test_split_selected_rows_op.py | 2 + .../fluid/tests/unittests/test_spp_op.py | 2 + .../unittests/test_squared_l2_distance_op.py | 2 + .../unittests/test_squared_l2_norm_op.py | 2 + .../fluid/tests/unittests/test_squeeze_op.py | 2 + .../tests/unittests/test_sum_mkldnn_op.py | 2 + .../fluid/tests/unittests/test_sum_op.py | 2 + .../fluid/tests/unittests/test_switch.py | 2 + .../tests/unittests/test_target_assign_op.py | 2 + .../fluid/tests/unittests/test_tensor.py | 2 + 
.../fluid/tests/unittests/test_top_k_op.py | 2 + .../tests/unittests/test_transpose_op.py | 2 + .../test_uniform_random_batch_size_like_op.py | 2 + .../tests/unittests/test_uniform_random_op.py | 2 + .../fluid/tests/unittests/test_unique_name.py | 2 + .../fluid/tests/unittests/test_unpool_op.py | 2 + .../tests/unittests/test_unsqueeze_op.py | 2 + .../fluid/tests/unittests/test_variable.py | 2 + .../fluid/tests/unittests/test_version.py | 2 + .../fluid/tests/unittests/test_warpctc_op.py | 2 + .../unittests/test_weight_normalization.py | 2 + .../fluid/tests/unittests/test_while_op.py | 2 + .../paddle/fluid/tests/unittests/testsuite.py | 2 + .../tests/unittests/transformer_model.py | 2 + python/paddle/fluid/trainer.py | 2 + python/paddle/fluid/transpiler/__init__.py | 2 + .../fluid/transpiler/details/__init__.py | 2 + .../fluid/transpiler/details/program_utils.py | 2 + .../paddle/fluid/transpiler/details/ufind.py | 2 + .../fluid/transpiler/distribute_transpiler.py | 2 + .../fluid/transpiler/inference_transpiler.py | 2 + .../memory_optimization_transpiler.py | 2 + .../paddle/fluid/transpiler/ps_dispatcher.py | 2 + python/paddle/fluid/unique_name.py | 2 + 373 files changed, 774 insertions(+), 40 deletions(-) diff --git a/python/paddle/dataset/cifar.py b/python/paddle/dataset/cifar.py index b42bc192b2..b83fa78c4c 100644 --- a/python/paddle/dataset/cifar.py +++ b/python/paddle/dataset/cifar.py @@ -28,6 +28,8 @@ images per class. """ +from __future__ import print_function + import itertools import numpy import paddle.dataset.common diff --git a/python/paddle/dataset/common.py b/python/paddle/dataset/common.py index a75cabd676..1d7ff582c8 100644 --- a/python/paddle/dataset/common.py +++ b/python/paddle/dataset/common.py @@ -12,6 +12,8 @@ # See the License for the specific language governing permissions and # limitations under the License. +from __future__ import print_function + import requests import hashlib import os diff --git a/python/paddle/dataset/conll05.py b/python/paddle/dataset/conll05.py index b23d127eeb..55cfd92721 100644 --- a/python/paddle/dataset/conll05.py +++ b/python/paddle/dataset/conll05.py @@ -20,6 +20,8 @@ dataset. And a pre-trained word vector model based on Wikipedia corpus is used to initialize SRL model. """ +from __future__ import print_function + import tarfile import gzip import itertools diff --git a/python/paddle/dataset/flowers.py b/python/paddle/dataset/flowers.py index 7d14cc5dc8..17c768424f 100644 --- a/python/paddle/dataset/flowers.py +++ b/python/paddle/dataset/flowers.py @@ -28,6 +28,9 @@ Graphics and Image Processing (2008) http://www.robots.ox.ac.uk/~vgg/publications/papers/nilsback08.{pdf,ps.gz}. """ + +from __future__ import print_function + import itertools import functools from .common import download diff --git a/python/paddle/dataset/image.py b/python/paddle/dataset/image.py index 99d2c5f899..1cd50bd180 100644 --- a/python/paddle/dataset/image.py +++ b/python/paddle/dataset/image.py @@ -29,6 +29,9 @@ the image layout as follows. formats can be used for training. Noted that, the format should be keep consistent between the training and inference peroid. """ + +from __future__ import print_function + import numpy as np try: import cv2 diff --git a/python/paddle/dataset/imdb.py b/python/paddle/dataset/imdb.py index 903e93d34f..fd92523a94 100644 --- a/python/paddle/dataset/imdb.py +++ b/python/paddle/dataset/imdb.py @@ -20,6 +20,8 @@ of 25,000 highly polar movie reviews for training, and 25,000 for testing. 
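The dataset modules being touched here all follow the convention their docstrings describe: the downloaded data is parsed into "reader creators", functions that return a fresh generator over samples on each call. A minimal sketch of the pattern (names are illustrative, not from the patch):

    def train():
        # reader creator: each call returns a new pass over the data
        def reader():
            for image, label in load_samples():  # hypothetical parsed download
                yield image, label
        return reader

    # consumers iterate a fresh pass: for image, label in train()(): ...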
Besides, this module also provides API for building dictionary. """ +from __future__ import print_function + import paddle.dataset.common import collections import tarfile diff --git a/python/paddle/dataset/imikolov.py b/python/paddle/dataset/imikolov.py index 422eaef644..8eecb75231 100644 --- a/python/paddle/dataset/imikolov.py +++ b/python/paddle/dataset/imikolov.py @@ -18,6 +18,9 @@ This module will download dataset from http://www.fit.vutbr.cz/~imikolov/rnnlm/ and parse training set and test set into paddle reader creators. """ + +from __future__ import print_function + import paddle.dataset.common import collections import tarfile diff --git a/python/paddle/dataset/mnist.py b/python/paddle/dataset/mnist.py index 28e6a04795..3038747bf8 100644 --- a/python/paddle/dataset/mnist.py +++ b/python/paddle/dataset/mnist.py @@ -17,6 +17,9 @@ MNIST dataset. This module will download dataset from http://yann.lecun.com/exdb/mnist/ and parse training set and test set into paddle reader creators. """ + +from __future__ import print_function + import paddle.dataset.common import subprocess import numpy diff --git a/python/paddle/dataset/movielens.py b/python/paddle/dataset/movielens.py index fe07daf5c3..c98e0019f7 100644 --- a/python/paddle/dataset/movielens.py +++ b/python/paddle/dataset/movielens.py @@ -22,6 +22,8 @@ set and test set into paddle reader creators. """ +from __future__ import print_function + import zipfile import paddle.dataset.common import re diff --git a/python/paddle/dataset/mq2007.py b/python/paddle/dataset/mq2007.py index cc4d088316..d5740f30c8 100644 --- a/python/paddle/dataset/mq2007.py +++ b/python/paddle/dataset/mq2007.py @@ -23,6 +23,8 @@ http://research.microsoft.com/en-us/um/beijing/projects/letor/LETOR4.0/Data/MQ20 """ +from __future__ import print_function + import os import functools import rarfile diff --git a/python/paddle/dataset/sentiment.py b/python/paddle/dataset/sentiment.py index 25cd59df92..22d867beea 100644 --- a/python/paddle/dataset/sentiment.py +++ b/python/paddle/dataset/sentiment.py @@ -20,6 +20,8 @@ The script fetch and preprocess movie_reviews data set that provided by NLTK TODO(yuyang18): Complete dataset. """ +from __future__ import print_function + import six import collections from itertools import chain diff --git a/python/paddle/dataset/tests/cifar_test.py b/python/paddle/dataset/tests/cifar_test.py index 839125b09d..8e514f0fd9 100644 --- a/python/paddle/dataset/tests/cifar_test.py +++ b/python/paddle/dataset/tests/cifar_test.py @@ -12,6 +12,8 @@ # See the License for the specific language governing permissions and # limitations under the License. +from __future__ import print_function + import paddle.dataset.cifar import unittest diff --git a/python/paddle/dataset/tests/common_test.py b/python/paddle/dataset/tests/common_test.py index ede3d593eb..0ce7d83f37 100644 --- a/python/paddle/dataset/tests/common_test.py +++ b/python/paddle/dataset/tests/common_test.py @@ -12,6 +12,8 @@ # See the License for the specific language governing permissions and # limitations under the License. +from __future__ import print_function + import paddle.dataset.common import unittest import tempfile diff --git a/python/paddle/dataset/tests/flowers_test.py b/python/paddle/dataset/tests/flowers_test.py index 06260fd796..06a0a7761c 100644 --- a/python/paddle/dataset/tests/flowers_test.py +++ b/python/paddle/dataset/tests/flowers_test.py @@ -12,6 +12,8 @@ # See the License for the specific language governing permissions and # limitations under the License. 
+from __future__ import print_function + import paddle.dataset.flowers import unittest diff --git a/python/paddle/dataset/tests/imdb_test.py b/python/paddle/dataset/tests/imdb_test.py index 539da04944..415947e347 100644 --- a/python/paddle/dataset/tests/imdb_test.py +++ b/python/paddle/dataset/tests/imdb_test.py @@ -12,6 +12,8 @@ # See the License for the specific language governing permissions and # limitations under the License. +from __future__ import print_function + import paddle.dataset.imdb import unittest import re diff --git a/python/paddle/dataset/tests/imikolov_test.py b/python/paddle/dataset/tests/imikolov_test.py index 50f50d947d..1f78a5dd4d 100644 --- a/python/paddle/dataset/tests/imikolov_test.py +++ b/python/paddle/dataset/tests/imikolov_test.py @@ -12,6 +12,8 @@ # See the License for the specific language governing permissions and # limitations under the License. +from __future__ import print_function + import paddle.dataset.imikolov import unittest diff --git a/python/paddle/dataset/tests/mnist_test.py b/python/paddle/dataset/tests/mnist_test.py index 8ada19d3f2..fbb5d92649 100644 --- a/python/paddle/dataset/tests/mnist_test.py +++ b/python/paddle/dataset/tests/mnist_test.py @@ -12,6 +12,8 @@ # See the License for the specific language governing permissions and # limitations under the License. +from __future__ import print_function + import paddle.dataset.mnist import unittest diff --git a/python/paddle/dataset/tests/mq2007_test.py b/python/paddle/dataset/tests/mq2007_test.py index fba388724a..ee0897e88f 100644 --- a/python/paddle/dataset/tests/mq2007_test.py +++ b/python/paddle/dataset/tests/mq2007_test.py @@ -12,6 +12,8 @@ # See the License for the specific language governing permissions and # limitations under the License. +from __future__ import print_function + import paddle.dataset.mq2007 import unittest diff --git a/python/paddle/dataset/tests/test_image.py b/python/paddle/dataset/tests/test_image.py index 8bd56607ae..32d2eb17ae 100644 --- a/python/paddle/dataset/tests/test_image.py +++ b/python/paddle/dataset/tests/test_image.py @@ -12,6 +12,8 @@ # See the License for the specific language governing permissions and # limitations under the License. +from __future__ import print_function + import unittest import numpy as np diff --git a/python/paddle/dataset/tests/test_sentiment.py b/python/paddle/dataset/tests/test_sentiment.py index 37326517f7..bb9830132e 100644 --- a/python/paddle/dataset/tests/test_sentiment.py +++ b/python/paddle/dataset/tests/test_sentiment.py @@ -15,6 +15,8 @@ # See the License for the specific language governing permissions and # limitations under the License. +from __future__ import print_function + import unittest import nltk import paddle.dataset.sentiment as st diff --git a/python/paddle/dataset/tests/voc2012_test.py b/python/paddle/dataset/tests/voc2012_test.py index 0d285461a8..cddeb91cab 100644 --- a/python/paddle/dataset/tests/voc2012_test.py +++ b/python/paddle/dataset/tests/voc2012_test.py @@ -12,6 +12,8 @@ # See the License for the specific language governing permissions and # limitations under the License. 
+from __future__ import print_function + import paddle.dataset.voc2012 import unittest diff --git a/python/paddle/dataset/tests/wmt16_test.py b/python/paddle/dataset/tests/wmt16_test.py index 8b949d8bf5..be121bb101 100644 --- a/python/paddle/dataset/tests/wmt16_test.py +++ b/python/paddle/dataset/tests/wmt16_test.py @@ -12,6 +12,8 @@ # See the License for the specific language governing permissions and # limitations under the License. +from __future__ import print_function + import paddle.dataset.wmt16 import unittest diff --git a/python/paddle/dataset/uci_housing.py b/python/paddle/dataset/uci_housing.py index 2ba8ddcc1f..f87fdcc4f0 100644 --- a/python/paddle/dataset/uci_housing.py +++ b/python/paddle/dataset/uci_housing.py @@ -19,7 +19,7 @@ https://archive.ics.uci.edu/ml/machine-learning-databases/housing/ and parse training set and test set into paddle reader creators. """ -import os +from __future__ import print_function import numpy as np import six diff --git a/python/paddle/dataset/voc2012.py b/python/paddle/dataset/voc2012.py index 9c945574db..5068893765 100644 --- a/python/paddle/dataset/voc2012.py +++ b/python/paddle/dataset/voc2012.py @@ -19,6 +19,8 @@ to training/test sets has been maintained. The total number of images with segmentation has been increased from 7,062 to 9,993. """ +from __future__ import print_function + import tarfile import io import numpy as np diff --git a/python/paddle/dataset/wmt14.py b/python/paddle/dataset/wmt14.py index cf366309c0..f8c1a33574 100644 --- a/python/paddle/dataset/wmt14.py +++ b/python/paddle/dataset/wmt14.py @@ -19,6 +19,9 @@ http://paddlepaddle.cdn.bcebos.com/demo/wmt_shrinked_data/wmt14.tgz and parse training set and test set into paddle reader creators. """ + +from __future__ import print_function + import six import tarfile import gzip diff --git a/python/paddle/dataset/wmt16.py b/python/paddle/dataset/wmt16.py index d68a6e8be7..f30dcd518e 100644 --- a/python/paddle/dataset/wmt16.py +++ b/python/paddle/dataset/wmt16.py @@ -28,6 +28,8 @@ Multi30K: Multilingual English-German Image Descriptions. } """ +from __future__ import print_function + import os import six import tarfile diff --git a/python/paddle/fluid/average.py b/python/paddle/fluid/average.py index 358e24df31..42cd3b3642 100644 --- a/python/paddle/fluid/average.py +++ b/python/paddle/fluid/average.py @@ -12,6 +12,8 @@ # See the License for the specific language governing permissions and # limitations under the License. +from __future__ import print_function + import numpy as np import warnings """ diff --git a/python/paddle/fluid/backward.py b/python/paddle/fluid/backward.py index 3824b21ec2..a415cdbeaa 100644 --- a/python/paddle/fluid/backward.py +++ b/python/paddle/fluid/backward.py @@ -12,6 +12,8 @@ # See the License for the specific language governing permissions and # limitations under the License. +from __future__ import print_function + from paddle.fluid import framework as framework from . import core import collections diff --git a/python/paddle/fluid/clip.py b/python/paddle/fluid/clip.py index 4b0a792f78..ba7ba3b5e9 100644 --- a/python/paddle/fluid/clip.py +++ b/python/paddle/fluid/clip.py @@ -12,6 +12,8 @@ # See the License for the specific language governing permissions and # limitations under the License. 
+from __future__ import print_function + import copy import six diff --git a/python/paddle/fluid/concurrency.py b/python/paddle/fluid/concurrency.py index 676a52a917..b4a06f23a6 100644 --- a/python/paddle/fluid/concurrency.py +++ b/python/paddle/fluid/concurrency.py @@ -12,6 +12,8 @@ # See the License for the specific language governing permissions and # limitations under the License. +from __future__ import print_function + from .layers.control_flow import BlockGuard, equal from .framework import Operator from .layer_helper import LayerHelper, unique_name diff --git a/python/paddle/fluid/contrib/__init__.py b/python/paddle/fluid/contrib/__init__.py index 58f2da1c3b..5607f11932 100644 --- a/python/paddle/fluid/contrib/__init__.py +++ b/python/paddle/fluid/contrib/__init__.py @@ -12,6 +12,8 @@ # See the License for the specific language governing permissions and # limitations under the License. +from __future__ import print_function + from . import decoder from .decoder import * from . import memory_usage_calc diff --git a/python/paddle/fluid/contrib/decoder/__init__.py b/python/paddle/fluid/contrib/decoder/__init__.py index 6343c1543d..9f973fd3c9 100644 --- a/python/paddle/fluid/contrib/decoder/__init__.py +++ b/python/paddle/fluid/contrib/decoder/__init__.py @@ -12,6 +12,8 @@ # See the License for the specific language governing permissions and # limitations under the License. +from __future__ import print_function + from . import beam_search_decoder from .beam_search_decoder import * diff --git a/python/paddle/fluid/contrib/decoder/beam_search_decoder.py b/python/paddle/fluid/contrib/decoder/beam_search_decoder.py index d268a948f7..f2b7ac8375 100644 --- a/python/paddle/fluid/contrib/decoder/beam_search_decoder.py +++ b/python/paddle/fluid/contrib/decoder/beam_search_decoder.py @@ -20,6 +20,8 @@ without using the low level API such as while ops. This API is still under active development and may change drastically. """ +from __future__ import print_function + import contextlib import numpy as np import six diff --git a/python/paddle/fluid/contrib/memory_usage_calc.py b/python/paddle/fluid/contrib/memory_usage_calc.py index f0316a70ec..09721e430b 100644 --- a/python/paddle/fluid/contrib/memory_usage_calc.py +++ b/python/paddle/fluid/contrib/memory_usage_calc.py @@ -20,6 +20,8 @@ batch size to fully utilize a GPU. This API is still under active development and may change drastically. """ +from __future__ import print_function + import six from .. import core diff --git a/python/paddle/fluid/data_feeder.py b/python/paddle/fluid/data_feeder.py index 9452cf0e2a..631bbfe1fe 100644 --- a/python/paddle/fluid/data_feeder.py +++ b/python/paddle/fluid/data_feeder.py @@ -12,6 +12,8 @@ # See the License for the specific language governing permissions and # limitations under the License. +from __future__ import print_function + from . import core import numpy import os diff --git a/python/paddle/fluid/debugger.py b/python/paddle/fluid/debugger.py index ea6c14df72..63060a77d1 100644 --- a/python/paddle/fluid/debugger.py +++ b/python/paddle/fluid/debugger.py @@ -12,6 +12,8 @@ # See the License for the specific language governing permissions and # limitations under the License. 
+from __future__ import print_function + import sys import six import re diff --git a/python/paddle/fluid/default_scope_funcs.py b/python/paddle/fluid/default_scope_funcs.py index f8faf69425..a5b2c84dfe 100644 --- a/python/paddle/fluid/default_scope_funcs.py +++ b/python/paddle/fluid/default_scope_funcs.py @@ -26,6 +26,8 @@ A `scoped_function` will take a `function` as input. That function will be invoked in a new local scope. """ +from __future__ import print_function + import paddle.fluid.core import threading diff --git a/python/paddle/fluid/evaluator.py b/python/paddle/fluid/evaluator.py index c0671cce9a..7a82038ff7 100644 --- a/python/paddle/fluid/evaluator.py +++ b/python/paddle/fluid/evaluator.py @@ -12,6 +12,8 @@ # See the License for the specific language governing permissions and # limitations under the License. +from __future__ import print_function + import warnings import numpy as np diff --git a/python/paddle/fluid/executor.py b/python/paddle/fluid/executor.py index b4d9989851..288951cd7c 100644 --- a/python/paddle/fluid/executor.py +++ b/python/paddle/fluid/executor.py @@ -12,6 +12,8 @@ # See the License for the specific language governing permissions and # limitations under the License. +from __future__ import print_function + import numpy as np import contextlib import six diff --git a/python/paddle/fluid/framework.py b/python/paddle/fluid/framework.py index 621e46b0f9..2377ac5f92 100644 --- a/python/paddle/fluid/framework.py +++ b/python/paddle/fluid/framework.py @@ -12,6 +12,8 @@ # See the License for the specific language governing permissions and # limitations under the License. +from __future__ import print_function + import collections import contextlib import re diff --git a/python/paddle/fluid/graphviz.py b/python/paddle/fluid/graphviz.py index 27d4a7d8dc..2b18d854d1 100644 --- a/python/paddle/fluid/graphviz.py +++ b/python/paddle/fluid/graphviz.py @@ -12,6 +12,8 @@ # See the License for the specific language governing permissions and # limitations under the License. +from __future__ import print_function + import os import random import six diff --git a/python/paddle/fluid/inferencer.py b/python/paddle/fluid/inferencer.py index ff382d8b83..3d2ef56617 100644 --- a/python/paddle/fluid/inferencer.py +++ b/python/paddle/fluid/inferencer.py @@ -12,6 +12,8 @@ # See the License for the specific language governing permissions and # limitations under the License. +from __future__ import print_function + import contextlib from . import core diff --git a/python/paddle/fluid/initializer.py b/python/paddle/fluid/initializer.py index 3f740dd7c5..4680fa700e 100644 --- a/python/paddle/fluid/initializer.py +++ b/python/paddle/fluid/initializer.py @@ -12,6 +12,8 @@ # See the License for the specific language governing permissions and # limitations under the License. +from __future__ import print_function + from . import framework import numpy as np import contextlib diff --git a/python/paddle/fluid/io.py b/python/paddle/fluid/io.py index e277c85021..6b67128fbf 100644 --- a/python/paddle/fluid/io.py +++ b/python/paddle/fluid/io.py @@ -12,6 +12,8 @@ # See the License for the specific language governing permissions and # limitations under the License. 
+from __future__ import print_function + import os import errno import time diff --git a/python/paddle/fluid/layer_helper.py b/python/paddle/fluid/layer_helper.py index 82df4c6c54..bd9727b6ac 100644 --- a/python/paddle/fluid/layer_helper.py +++ b/python/paddle/fluid/layer_helper.py @@ -12,6 +12,8 @@ # See the License for the specific language governing permissions and # limitations under the License. +from __future__ import print_function + import copy import itertools import six diff --git a/python/paddle/fluid/layers/__init__.py b/python/paddle/fluid/layers/__init__.py index a48e360463..a2a808777d 100644 --- a/python/paddle/fluid/layers/__init__.py +++ b/python/paddle/fluid/layers/__init__.py @@ -12,6 +12,8 @@ # See the License for the specific language governing permissions and # limitations under the License. +from __future__ import print_function + from . import ops from .ops import * from . import nn diff --git a/python/paddle/fluid/layers/control_flow.py b/python/paddle/fluid/layers/control_flow.py index 730075a1ec..173567a0a3 100644 --- a/python/paddle/fluid/layers/control_flow.py +++ b/python/paddle/fluid/layers/control_flow.py @@ -11,6 +11,8 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. + +from __future__ import print_function import contextlib from .layer_function_generator import autodoc, templatedoc diff --git a/python/paddle/fluid/layers/detection.py b/python/paddle/fluid/layers/detection.py index 9baf5f84fd..7207147884 100644 --- a/python/paddle/fluid/layers/detection.py +++ b/python/paddle/fluid/layers/detection.py @@ -15,6 +15,8 @@ All layers just related to the detection neural network. """ +from __future__ import print_function + from .layer_function_generator import generate_layer_fn from .layer_function_generator import autodoc, templatedoc from ..layer_helper import LayerHelper diff --git a/python/paddle/fluid/layers/device.py b/python/paddle/fluid/layers/device.py index bb1fb7fd57..43ebd160de 100644 --- a/python/paddle/fluid/layers/device.py +++ b/python/paddle/fluid/layers/device.py @@ -15,6 +15,8 @@ All util layers. """ +from __future__ import print_function + from .layer_function_generator import autodoc from ..framework import unique_name from ..layer_helper import LayerHelper diff --git a/python/paddle/fluid/layers/io.py b/python/paddle/fluid/layers/io.py index bac641327d..21a295a098 100644 --- a/python/paddle/fluid/layers/io.py +++ b/python/paddle/fluid/layers/io.py @@ -11,6 +11,8 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. + +from __future__ import print_function import contextlib import multiprocessing import six diff --git a/python/paddle/fluid/layers/layer_function_generator.py b/python/paddle/fluid/layers/layer_function_generator.py index c0d72620b1..8963d74de0 100644 --- a/python/paddle/fluid/layers/layer_function_generator.py +++ b/python/paddle/fluid/layers/layer_function_generator.py @@ -11,6 +11,8 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. 
+ +from __future__ import print_function import re import functools import warnings diff --git a/python/paddle/fluid/layers/learning_rate_scheduler.py b/python/paddle/fluid/layers/learning_rate_scheduler.py index daf91a40f7..be368007dd 100644 --- a/python/paddle/fluid/layers/learning_rate_scheduler.py +++ b/python/paddle/fluid/layers/learning_rate_scheduler.py @@ -20,6 +20,8 @@ User can also implement their own learning_rate_decay strategy according to this module. """ +from __future__ import print_function + from . import control_flow from . import nn from . import ops @@ -72,10 +74,10 @@ def noam_decay(d_model, warmup_steps): def exponential_decay(learning_rate, decay_steps, decay_rate, staircase=False): """ - Applies exponential decay to the learning rate. + Applies exponential decay to the learning rate. - When training a model, it is often recommended to lower the learning rate as the - training progresses. By using this function, the learning rate will be decayed by + When training a model, it is often recommended to lower the learning rate as the + training progresses. By using this function, the learning rate will be decayed by 'decay_rate' every 'decay_steps' steps. >>> if staircase == True: @@ -148,8 +150,8 @@ def inverse_time_decay(learning_rate, decay_steps, decay_rate, staircase=False): """ Applies inverse time decay to the initial learning rate. - When training a model, it is often recommended to lower the learning rate as the - training progresses. By using this function, an inverse decay function will be + When training a model, it is often recommended to lower the learning rate as the + training progresses. By using this function, an inverse decay function will be applied to the initial learning rate. >>> if staircase == True: diff --git a/python/paddle/fluid/layers/math_op_patch.py b/python/paddle/fluid/layers/math_op_patch.py index 0e10a91d25..a458cebfb1 100644 --- a/python/paddle/fluid/layers/math_op_patch.py +++ b/python/paddle/fluid/layers/math_op_patch.py @@ -12,6 +12,8 @@ # See the License for the specific language governing permissions and # limitations under the License. +from __future__ import print_function + from ..framework import Variable, unique_name from .layer_function_generator import OpProtoHolder from ..initializer import force_init_on_cpu diff --git a/python/paddle/fluid/layers/metric_op.py b/python/paddle/fluid/layers/metric_op.py index 49bae1e8af..2c3bdd77e1 100644 --- a/python/paddle/fluid/layers/metric_op.py +++ b/python/paddle/fluid/layers/metric_op.py @@ -15,6 +15,8 @@ All layers just related to metric. """ +from __future__ import print_function + import warnings from ..layer_helper import LayerHelper from ..initializer import Normal, Constant @@ -81,9 +83,9 @@ def auc(input, label, curve='ROC', num_thresholds=200, topk=1): **Area Under the Curve (AUC) Layer** This implementation computes the AUC according to forward output and label. - It is used very widely in binary classification evaluation. + It is used very widely in binary classification evaluation. - Note: If input label contains values other than 0 and 1, it will be cast + Note: If input label contains values other than 0 and 1, it will be cast to `bool`. Find the relevant definitions `here `_. @@ -93,14 +95,14 @@ def auc(input, label, curve='ROC', num_thresholds=200, topk=1): 2. PR: Precision Recall Args: - input(Variable): A floating-point 2D Variable, values are in the range - [0, 1]. Each row is sorted in descending order. This - input should be the output of topk. 
Typically, this + input(Variable): A floating-point 2D Variable, values are in the range + [0, 1]. Each row is sorted in descending order. This + input should be the output of topk. Typically, this Variable indicates the probability of each label. - label(Variable): A 2D int Variable indicating the label of the training + label(Variable): A 2D int Variable indicating the label of the training data. The height is batch size and width is always 1. curve(str): Curve type, can be 'ROC' or 'PR'. Default 'ROC'. - num_thresholds(int): The number of thresholds to use when discretizing + num_thresholds(int): The number of thresholds to use when discretizing the roc curve. Default 200. topk(int): only topk number of prediction output will be used for auc. @@ -109,7 +111,7 @@ def auc(input, label, curve='ROC', num_thresholds=200, topk=1): Examples: .. code-block:: python - + # network is a binary classification model and label the ground truth prediction = network(image, is_infer=True) auc_out=fluid.layers.auc(input=prediction, label=label) diff --git a/python/paddle/fluid/layers/nn.py b/python/paddle/fluid/layers/nn.py index aed09914bb..0c1f78e435 100644 --- a/python/paddle/fluid/layers/nn.py +++ b/python/paddle/fluid/layers/nn.py @@ -11,24 +11,12 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. - -# Copyright (c ) 2018 PaddlePaddle Authors. All Rights Reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. """ All layers just related to the neural network. """ +from __future__ import print_function + from ..layer_helper import LayerHelper from ..initializer import Normal, Constant from ..framework import Variable @@ -1319,15 +1307,15 @@ def sequence_softmax(input, param_attr=None, bias_attr=None, use_cudnn=True): def softmax(input, param_attr=None, bias_attr=None, use_cudnn=True, name=None): """ - The input of the softmax operator is a tensor of any rank. The output tensor + The input of the softmax operator is a tensor of any rank. The output tensor has the same shape as the input. - The input tensor will first be logically flattened to a 2-D matrix. The matrix's - second dimension(row length) is as same as the last dimension of the input - tensor, and the first dimension(column length) is the product of all other - dimensions of the input tensor. For each row of the matrix, the softmax operator - squashes the K-dimensional(K is the width of the matrix, which is also the size - of the input tensor's last dimension) vector of arbitrary real values to a + The input tensor will first be logically flattened to a 2-D matrix. The matrix's + second dimension(row length) is as same as the last dimension of the input + tensor, and the first dimension(column length) is the product of all other + dimensions of the input tensor. 
For each row of the matrix, the softmax operator + squashes the K-dimensional(K is the width of the matrix, which is also the size + of the input tensor's last dimension) vector of arbitrary real values to a K-dimensional vector of real values in the range [0, 1] that add up to 1. It computes the exponential of the given dimension and the sum of exponential @@ -5377,7 +5365,7 @@ def flatten(x, axis=1, name=None): axis = 2 We get: Out.shape = (3 * 100, 4 * 100) - + Case 2: Given X.shape = (3, 100, 100, 4) @@ -5388,8 +5376,8 @@ def flatten(x, axis=1, name=None): Args: x (Variable): A tensor of rank >= axis. - axis (int): Indicate up to which input dimensions (exclusive) should - be flattened to the outer dimension of the output. + axis (int): Indicate up to which input dimensions (exclusive) should + be flattened to the outer dimension of the output. The value for axis must be in the range [0, R], where R is the rank of the input tensor. When axis = 0, the shape of the output tensor is (1, (d_0 X d_1 ... d_n), where the @@ -5405,7 +5393,7 @@ def flatten(x, axis=1, name=None): Raises: ValueError: If x is not a variable. - ValueError: If axis is not in range [0, rank(x)]. + ValueError: If axis is not in range [0, rank(x)]. Examples: diff --git a/python/paddle/fluid/layers/ops.py b/python/paddle/fluid/layers/ops.py index f70c7f2258..cc4a7de163 100644 --- a/python/paddle/fluid/layers/ops.py +++ b/python/paddle/fluid/layers/ops.py @@ -11,6 +11,8 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. + +from __future__ import print_function from .layer_function_generator import generate_layer_fn __activations__ = [ diff --git a/python/paddle/fluid/layers/tensor.py b/python/paddle/fluid/layers/tensor.py index b93d721c12..04e71497aa 100644 --- a/python/paddle/fluid/layers/tensor.py +++ b/python/paddle/fluid/layers/tensor.py @@ -12,6 +12,8 @@ # See the License for the specific language governing permissions and # limitations under the License. +from __future__ import print_function + from ..layer_helper import LayerHelper from ..param_attr import ParamAttr from ..framework import convert_np_dtype_to_dtype_ diff --git a/python/paddle/fluid/layers/utils.py b/python/paddle/fluid/layers/utils.py index 49ec308883..5688f04ab2 100644 --- a/python/paddle/fluid/layers/utils.py +++ b/python/paddle/fluid/layers/utils.py @@ -11,6 +11,8 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. + +from __future__ import print_function import numpy as np diff --git a/python/paddle/fluid/lod_tensor.py b/python/paddle/fluid/lod_tensor.py index 53c33616f5..a9de09f31f 100644 --- a/python/paddle/fluid/lod_tensor.py +++ b/python/paddle/fluid/lod_tensor.py @@ -12,6 +12,8 @@ # See the License for the specific language governing permissions and # limitations under the License. +from __future__ import print_function + from . import core import numpy as np diff --git a/python/paddle/fluid/metrics.py b/python/paddle/fluid/metrics.py index 19df1e1dcb..592cb23eb9 100644 --- a/python/paddle/fluid/metrics.py +++ b/python/paddle/fluid/metrics.py @@ -16,6 +16,9 @@ Fluid Metrics The metrics are accomplished via Python natively. 
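Of the docstrings reflowed above, the softmax one is the most algorithmic: the input is logically flattened to a 2-D matrix whose row length is the input's last dimension, and a softmax is applied per row. The described semantics in numpy (a sketch of the documented behavior, not the operator's kernel):

    import numpy as np

    def softmax_ref(x):
        mat = x.reshape(-1, x.shape[-1])             # rows = all leading dims
        mat = mat - mat.max(axis=1, keepdims=True)   # stabilize the exponent
        e = np.exp(mat)
        return (e / e.sum(axis=1, keepdims=True)).reshape(x.shape)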
""" + +from __future__ import print_function + import numpy as np import copy import warnings diff --git a/python/paddle/fluid/net_drawer.py b/python/paddle/fluid/net_drawer.py index 623a7d3fd0..0b61c23d07 100644 --- a/python/paddle/fluid/net_drawer.py +++ b/python/paddle/fluid/net_drawer.py @@ -12,6 +12,8 @@ # See the License for the specific language governing permissions and # limitations under the License. +from __future__ import print_function + import argparse import json import logging diff --git a/python/paddle/fluid/nets.py b/python/paddle/fluid/nets.py index 46e4c70195..051fe84364 100644 --- a/python/paddle/fluid/nets.py +++ b/python/paddle/fluid/nets.py @@ -11,6 +11,8 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. + +from __future__ import print_function import six from . import layers diff --git a/python/paddle/fluid/op.py b/python/paddle/fluid/op.py index a2db5bad51..667db10d3e 100644 --- a/python/paddle/fluid/op.py +++ b/python/paddle/fluid/op.py @@ -12,6 +12,8 @@ # See the License for the specific language governing permissions and # limitations under the License. +from __future__ import print_function + import numpy as np import six diff --git a/python/paddle/fluid/optimizer.py b/python/paddle/fluid/optimizer.py index a07325f46a..031ddd09a0 100644 --- a/python/paddle/fluid/optimizer.py +++ b/python/paddle/fluid/optimizer.py @@ -11,6 +11,8 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. + +from __future__ import print_function import re from collections import defaultdict from paddle.fluid.framework import Program, Variable diff --git a/python/paddle/fluid/param_attr.py b/python/paddle/fluid/param_attr.py index afae577656..f0be794327 100644 --- a/python/paddle/fluid/param_attr.py +++ b/python/paddle/fluid/param_attr.py @@ -12,6 +12,8 @@ # See the License for the specific language governing permissions and # limitations under the License. +from __future__ import print_function + import six from .initializer import Initializer, Xavier, Constant diff --git a/python/paddle/fluid/profiler.py b/python/paddle/fluid/profiler.py index 5bbbdf7fe7..e05885f5f5 100644 --- a/python/paddle/fluid/profiler.py +++ b/python/paddle/fluid/profiler.py @@ -12,6 +12,8 @@ # See the License for the specific language governing permissions and # limitations under the License. +from __future__ import print_function + from . import core from contextlib import contextmanager import os diff --git a/python/paddle/fluid/recordio_writer.py b/python/paddle/fluid/recordio_writer.py index 93b38ad3fa..a69c0c29d4 100644 --- a/python/paddle/fluid/recordio_writer.py +++ b/python/paddle/fluid/recordio_writer.py @@ -12,6 +12,8 @@ # See the License for the specific language governing permissions and # limitations under the License. +from __future__ import print_function + import os import contextlib from . import core diff --git a/python/paddle/fluid/regularizer.py b/python/paddle/fluid/regularizer.py index 6eaac4432d..da38626111 100644 --- a/python/paddle/fluid/regularizer.py +++ b/python/paddle/fluid/regularizer.py @@ -12,6 +12,8 @@ # See the License for the specific language governing permissions and # limitations under the License. +from __future__ import print_function + from . import framework from . 
import core diff --git a/python/paddle/fluid/tests/book/high-level-api/fit_a_line/test_fit_a_line.py b/python/paddle/fluid/tests/book/high-level-api/fit_a_line/test_fit_a_line.py index 36a1a223cf..f6017a455d 100644 --- a/python/paddle/fluid/tests/book/high-level-api/fit_a_line/test_fit_a_line.py +++ b/python/paddle/fluid/tests/book/high-level-api/fit_a_line/test_fit_a_line.py @@ -12,6 +12,8 @@ # See the License for the specific language governing permissions and # limitations under the License. +from __future__ import print_function + import paddle import paddle.fluid as fluid import contextlib diff --git a/python/paddle/fluid/tests/book/high-level-api/image_classification/cifar10_small_test_set.py b/python/paddle/fluid/tests/book/high-level-api/image_classification/cifar10_small_test_set.py index c03e73542a..48c0f3d361 100644 --- a/python/paddle/fluid/tests/book/high-level-api/image_classification/cifar10_small_test_set.py +++ b/python/paddle/fluid/tests/book/high-level-api/image_classification/cifar10_small_test_set.py @@ -28,6 +28,8 @@ images per class. """ +from __future__ import print_function + import itertools import numpy import paddle.dataset.common diff --git a/python/paddle/fluid/tests/book/high-level-api/image_classification/test_image_classification_resnet.py b/python/paddle/fluid/tests/book/high-level-api/image_classification/test_image_classification_resnet.py index 54c59ac075..be494a0d34 100644 --- a/python/paddle/fluid/tests/book/high-level-api/image_classification/test_image_classification_resnet.py +++ b/python/paddle/fluid/tests/book/high-level-api/image_classification/test_image_classification_resnet.py @@ -12,6 +12,8 @@ # See the License for the specific language governing permissions and # limitations under the License. +from __future__ import print_function + import paddle import paddle.fluid as fluid import numpy diff --git a/python/paddle/fluid/tests/book/high-level-api/image_classification/test_image_classification_vgg.py b/python/paddle/fluid/tests/book/high-level-api/image_classification/test_image_classification_vgg.py index 8429551765..dbc7bc06c9 100644 --- a/python/paddle/fluid/tests/book/high-level-api/image_classification/test_image_classification_vgg.py +++ b/python/paddle/fluid/tests/book/high-level-api/image_classification/test_image_classification_vgg.py @@ -12,6 +12,8 @@ # See the License for the specific language governing permissions and # limitations under the License. +from __future__ import print_function + import paddle import paddle.fluid as fluid import numpy diff --git a/python/paddle/fluid/tests/book/high-level-api/label_semantic_roles/test_label_semantic_roles_newapi.py b/python/paddle/fluid/tests/book/high-level-api/label_semantic_roles/test_label_semantic_roles_newapi.py index e3602e2d56..ec4e1c768c 100755 --- a/python/paddle/fluid/tests/book/high-level-api/label_semantic_roles/test_label_semantic_roles_newapi.py +++ b/python/paddle/fluid/tests/book/high-level-api/label_semantic_roles/test_label_semantic_roles_newapi.py @@ -12,6 +12,8 @@ # See the License for the specific language governing permissions and # limitations under the License. 
+from __future__ import print_function + import paddle import paddle.fluid as fluid import numpy as np diff --git a/python/paddle/fluid/tests/book/high-level-api/machine_translation/test_machine_translation.py b/python/paddle/fluid/tests/book/high-level-api/machine_translation/test_machine_translation.py index 6fb0c85a8b..560f118958 100644 --- a/python/paddle/fluid/tests/book/high-level-api/machine_translation/test_machine_translation.py +++ b/python/paddle/fluid/tests/book/high-level-api/machine_translation/test_machine_translation.py @@ -11,6 +11,8 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. + +from __future__ import print_function import contextlib import numpy as np diff --git a/python/paddle/fluid/tests/book/high-level-api/recognize_digits/test_recognize_digits_conv.py b/python/paddle/fluid/tests/book/high-level-api/recognize_digits/test_recognize_digits_conv.py index 898807db6f..187bef1b0c 100644 --- a/python/paddle/fluid/tests/book/high-level-api/recognize_digits/test_recognize_digits_conv.py +++ b/python/paddle/fluid/tests/book/high-level-api/recognize_digits/test_recognize_digits_conv.py @@ -12,6 +12,8 @@ # See the License for the specific language governing permissions and # limitations under the License. +from __future__ import print_function + import argparse import paddle.fluid as fluid import paddle.fluid.core as core diff --git a/python/paddle/fluid/tests/book/high-level-api/recognize_digits/test_recognize_digits_mlp.py b/python/paddle/fluid/tests/book/high-level-api/recognize_digits/test_recognize_digits_mlp.py index 6dd64be315..b95e7db122 100644 --- a/python/paddle/fluid/tests/book/high-level-api/recognize_digits/test_recognize_digits_mlp.py +++ b/python/paddle/fluid/tests/book/high-level-api/recognize_digits/test_recognize_digits_mlp.py @@ -12,6 +12,8 @@ # See the License for the specific language governing permissions and # limitations under the License. +from __future__ import print_function + import argparse import paddle.fluid as fluid import paddle diff --git a/python/paddle/fluid/tests/book/high-level-api/recommender_system/test_recommender_system_newapi.py b/python/paddle/fluid/tests/book/high-level-api/recommender_system/test_recommender_system_newapi.py index 60f3d8e105..9e2767783b 100644 --- a/python/paddle/fluid/tests/book/high-level-api/recommender_system/test_recommender_system_newapi.py +++ b/python/paddle/fluid/tests/book/high-level-api/recommender_system/test_recommender_system_newapi.py @@ -12,6 +12,8 @@ # See the License for the specific language governing permissions and # limitations under the License. +from __future__ import print_function + import math import sys import numpy as np diff --git a/python/paddle/fluid/tests/book/high-level-api/understand_sentiment/test_understand_sentiment_conv.py b/python/paddle/fluid/tests/book/high-level-api/understand_sentiment/test_understand_sentiment_conv.py index 24e65d1bd5..097c2a468f 100644 --- a/python/paddle/fluid/tests/book/high-level-api/understand_sentiment/test_understand_sentiment_conv.py +++ b/python/paddle/fluid/tests/book/high-level-api/understand_sentiment/test_understand_sentiment_conv.py @@ -12,6 +12,8 @@ # See the License for the specific language governing permissions and # limitations under the License. 
+from __future__ import print_function + import paddle import paddle.fluid as fluid from functools import partial diff --git a/python/paddle/fluid/tests/book/high-level-api/understand_sentiment/test_understand_sentiment_dynamic_rnn.py b/python/paddle/fluid/tests/book/high-level-api/understand_sentiment/test_understand_sentiment_dynamic_rnn.py index b3b1505a0f..5f74cd1425 100644 --- a/python/paddle/fluid/tests/book/high-level-api/understand_sentiment/test_understand_sentiment_dynamic_rnn.py +++ b/python/paddle/fluid/tests/book/high-level-api/understand_sentiment/test_understand_sentiment_dynamic_rnn.py @@ -12,6 +12,8 @@ # See the License for the specific language governing permissions and # limitations under the License. +from __future__ import print_function + import paddle import paddle.fluid as fluid from functools import partial diff --git a/python/paddle/fluid/tests/book/high-level-api/understand_sentiment/test_understand_sentiment_stacked_lstm.py b/python/paddle/fluid/tests/book/high-level-api/understand_sentiment/test_understand_sentiment_stacked_lstm.py index 25f99ff0fd..284a6ca168 100644 --- a/python/paddle/fluid/tests/book/high-level-api/understand_sentiment/test_understand_sentiment_stacked_lstm.py +++ b/python/paddle/fluid/tests/book/high-level-api/understand_sentiment/test_understand_sentiment_stacked_lstm.py @@ -12,6 +12,8 @@ # See the License for the specific language governing permissions and # limitations under the License. +from __future__ import print_function + import paddle import paddle.fluid as fluid from functools import partial diff --git a/python/paddle/fluid/tests/book/high-level-api/word2vec/test_word2vec_new_api.py b/python/paddle/fluid/tests/book/high-level-api/word2vec/test_word2vec_new_api.py index 02e65cf56c..1c7cf3199a 100644 --- a/python/paddle/fluid/tests/book/high-level-api/word2vec/test_word2vec_new_api.py +++ b/python/paddle/fluid/tests/book/high-level-api/word2vec/test_word2vec_new_api.py @@ -12,6 +12,8 @@ # See the License for the specific language governing permissions and # limitations under the License. +from __future__ import print_function + import paddle import paddle.fluid as fluid import numpy as np diff --git a/python/paddle/fluid/tests/book/notest_understand_sentiment.py b/python/paddle/fluid/tests/book/notest_understand_sentiment.py index ce6342c2da..82f1c6615f 100644 --- a/python/paddle/fluid/tests/book/notest_understand_sentiment.py +++ b/python/paddle/fluid/tests/book/notest_understand_sentiment.py @@ -12,6 +12,8 @@ # See the License for the specific language governing permissions and # limitations under the License. +from __future__ import print_function + from paddle.fluid.layers.device import get_places import unittest import paddle.fluid as fluid diff --git a/python/paddle/fluid/tests/book/test_fit_a_line.py b/python/paddle/fluid/tests/book/test_fit_a_line.py index 37b64fa94a..334294ab48 100644 --- a/python/paddle/fluid/tests/book/test_fit_a_line.py +++ b/python/paddle/fluid/tests/book/test_fit_a_line.py @@ -12,6 +12,8 @@ # See the License for the specific language governing permissions and # limitations under the License. 
+from __future__ import print_function + import paddle import paddle.fluid as fluid import contextlib diff --git a/python/paddle/fluid/tests/book/test_image_classification.py b/python/paddle/fluid/tests/book/test_image_classification.py index b6685fe2c2..9fe361425c 100644 --- a/python/paddle/fluid/tests/book/test_image_classification.py +++ b/python/paddle/fluid/tests/book/test_image_classification.py @@ -12,6 +12,8 @@ # See the License for the specific language governing permissions and # limitations under the License. +from __future__ import print_function + import paddle import paddle.fluid as fluid import contextlib diff --git a/python/paddle/fluid/tests/book/test_label_semantic_roles.py b/python/paddle/fluid/tests/book/test_label_semantic_roles.py index b7ac911caf..f63387a906 100644 --- a/python/paddle/fluid/tests/book/test_label_semantic_roles.py +++ b/python/paddle/fluid/tests/book/test_label_semantic_roles.py @@ -12,6 +12,8 @@ # See the License for the specific language governing permissions and # limitations under the License. +from __future__ import print_function + import contextlib import math import numpy as np diff --git a/python/paddle/fluid/tests/book/test_machine_translation.py b/python/paddle/fluid/tests/book/test_machine_translation.py index 462faad3e1..5e241aaa32 100644 --- a/python/paddle/fluid/tests/book/test_machine_translation.py +++ b/python/paddle/fluid/tests/book/test_machine_translation.py @@ -11,6 +11,8 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. + +from __future__ import print_function import contextlib import numpy as np diff --git a/python/paddle/fluid/tests/book/test_recognize_digits.py b/python/paddle/fluid/tests/book/test_recognize_digits.py index 3e5f76d12d..da216d0cc4 100644 --- a/python/paddle/fluid/tests/book/test_recognize_digits.py +++ b/python/paddle/fluid/tests/book/test_recognize_digits.py @@ -12,6 +12,8 @@ # See the License for the specific language governing permissions and # limitations under the License. +from __future__ import print_function + import paddle.fluid.core as core import math import os diff --git a/python/paddle/fluid/tests/book/test_recommender_system.py b/python/paddle/fluid/tests/book/test_recommender_system.py index b30c8771fc..cf8c48f346 100644 --- a/python/paddle/fluid/tests/book/test_recommender_system.py +++ b/python/paddle/fluid/tests/book/test_recommender_system.py @@ -12,6 +12,8 @@ # See the License for the specific language governing permissions and # limitations under the License. +from __future__ import print_function + import math import sys import os diff --git a/python/paddle/fluid/tests/book/test_rnn_encoder_decoder.py b/python/paddle/fluid/tests/book/test_rnn_encoder_decoder.py index 2e79be2bd0..91c8705aa4 100644 --- a/python/paddle/fluid/tests/book/test_rnn_encoder_decoder.py +++ b/python/paddle/fluid/tests/book/test_rnn_encoder_decoder.py @@ -12,6 +12,8 @@ # See the License for the specific language governing permissions and # limitations under the License. 
+from __future__ import print_function + import numpy as np import paddle import paddle.fluid as fluid diff --git a/python/paddle/fluid/tests/book/test_word2vec.py b/python/paddle/fluid/tests/book/test_word2vec.py index e761e05795..fe063eb462 100644 --- a/python/paddle/fluid/tests/book/test_word2vec.py +++ b/python/paddle/fluid/tests/book/test_word2vec.py @@ -12,6 +12,8 @@ # See the License for the specific language governing permissions and # limitations under the License. +from __future__ import print_function + import paddle import paddle.fluid as fluid from paddle.fluid.layers.device import get_places diff --git a/python/paddle/fluid/tests/book_memory_optimization/test_memopt_fit_a_line.py b/python/paddle/fluid/tests/book_memory_optimization/test_memopt_fit_a_line.py index ccc62b442f..f530f8f488 100644 --- a/python/paddle/fluid/tests/book_memory_optimization/test_memopt_fit_a_line.py +++ b/python/paddle/fluid/tests/book_memory_optimization/test_memopt_fit_a_line.py @@ -12,6 +12,8 @@ # See the License for the specific language governing permissions and # limitations under the License. +from __future__ import print_function + import math import sys diff --git a/python/paddle/fluid/tests/book_memory_optimization/test_memopt_image_classification_train.py b/python/paddle/fluid/tests/book_memory_optimization/test_memopt_image_classification_train.py index 8831dac336..3951e7b8ca 100644 --- a/python/paddle/fluid/tests/book_memory_optimization/test_memopt_image_classification_train.py +++ b/python/paddle/fluid/tests/book_memory_optimization/test_memopt_image_classification_train.py @@ -12,6 +12,8 @@ # See the License for the specific language governing permissions and # limitations under the License. +from __future__ import print_function + import sys import paddle diff --git a/python/paddle/fluid/tests/book_memory_optimization/test_memopt_machine_translation.py b/python/paddle/fluid/tests/book_memory_optimization/test_memopt_machine_translation.py index 323ddfb691..1ad51936b5 100644 --- a/python/paddle/fluid/tests/book_memory_optimization/test_memopt_machine_translation.py +++ b/python/paddle/fluid/tests/book_memory_optimization/test_memopt_machine_translation.py @@ -12,6 +12,8 @@ # See the License for the specific language governing permissions and # limitations under the License. +from __future__ import print_function + import numpy as np import paddle import paddle.fluid as fluid diff --git a/python/paddle/fluid/tests/demo/fc_gan.py b/python/paddle/fluid/tests/demo/fc_gan.py index 3d92f50f0a..bd77779ce6 100644 --- a/python/paddle/fluid/tests/demo/fc_gan.py +++ b/python/paddle/fluid/tests/demo/fc_gan.py @@ -12,6 +12,8 @@ # See the License for the specific language governing permissions and # limitations under the License. +from __future__ import print_function + import errno import math import os diff --git a/python/paddle/fluid/tests/demo/file_reader/convert_data_to_recordio.py b/python/paddle/fluid/tests/demo/file_reader/convert_data_to_recordio.py index a00325d79b..45a104ec96 100644 --- a/python/paddle/fluid/tests/demo/file_reader/convert_data_to_recordio.py +++ b/python/paddle/fluid/tests/demo/file_reader/convert_data_to_recordio.py @@ -12,6 +12,8 @@ # See the License for the specific language governing permissions and # limitations under the License. 
+from __future__ import print_function + import sys import paddle.fluid as fluid import paddle.v2 as paddle diff --git a/python/paddle/fluid/tests/demo/file_reader/train.py b/python/paddle/fluid/tests/demo/file_reader/train.py index bc3a6dc81d..5f5d2848da 100644 --- a/python/paddle/fluid/tests/demo/file_reader/train.py +++ b/python/paddle/fluid/tests/demo/file_reader/train.py @@ -12,6 +12,8 @@ # See the License for the specific language governing permissions and # limitations under the License. +from __future__ import print_function + import paddle.fluid as fluid import numpy import sys diff --git a/python/paddle/fluid/tests/demo/pyreader.py b/python/paddle/fluid/tests/demo/pyreader.py index 737644a25f..ec61e0ebae 100644 --- a/python/paddle/fluid/tests/demo/pyreader.py +++ b/python/paddle/fluid/tests/demo/pyreader.py @@ -12,6 +12,8 @@ # See the License for the specific language governing permissions and # limitations under the License. +from __future__ import print_function + import numpy import six diff --git a/python/paddle/fluid/tests/no_test_concurrency.py b/python/paddle/fluid/tests/no_test_concurrency.py index 3bc0c9808e..b5d7676f4a 100644 --- a/python/paddle/fluid/tests/no_test_concurrency.py +++ b/python/paddle/fluid/tests/no_test_concurrency.py @@ -12,6 +12,8 @@ # See the License for the specific language governing permissions and # limitations under the License. +from __future__ import print_function + import unittest import paddle.fluid as fluid import paddle.fluid.core as core diff --git a/python/paddle/fluid/tests/notest_concurrency.py b/python/paddle/fluid/tests/notest_concurrency.py index 77107f8b36..fd9da4cce0 100644 --- a/python/paddle/fluid/tests/notest_concurrency.py +++ b/python/paddle/fluid/tests/notest_concurrency.py @@ -12,6 +12,8 @@ # See the License for the specific language governing permissions and # limitations under the License. +from __future__ import print_function + import unittest import paddle.fluid as fluid import paddle.fluid.core as core diff --git a/python/paddle/fluid/tests/test_beam_search_decoder.py b/python/paddle/fluid/tests/test_beam_search_decoder.py index 8bf750940d..fe8a9daa3b 100644 --- a/python/paddle/fluid/tests/test_beam_search_decoder.py +++ b/python/paddle/fluid/tests/test_beam_search_decoder.py @@ -15,6 +15,8 @@ A simple machine translation demo using beam search decoder. """ +from __future__ import print_function + import contextlib import numpy as np import paddle diff --git a/python/paddle/fluid/tests/test_cpp_reader.py b/python/paddle/fluid/tests/test_cpp_reader.py index 6cc291dfcf..b2a5253b95 100644 --- a/python/paddle/fluid/tests/test_cpp_reader.py +++ b/python/paddle/fluid/tests/test_cpp_reader.py @@ -12,6 +12,8 @@ # See the License for the specific language governing permissions and # limitations under the License. +from __future__ import print_function + import paddle import paddle.fluid as fluid import numpy as np diff --git a/python/paddle/fluid/tests/test_data_feeder.py b/python/paddle/fluid/tests/test_data_feeder.py index 30b7a634a2..01de564aa4 100644 --- a/python/paddle/fluid/tests/test_data_feeder.py +++ b/python/paddle/fluid/tests/test_data_feeder.py @@ -12,6 +12,8 @@ # See the License for the specific language governing permissions and # limitations under the License. 
+from __future__ import print_function + import paddle.fluid as fluid import unittest diff --git a/python/paddle/fluid/tests/test_detection.py b/python/paddle/fluid/tests/test_detection.py index fd45abd0a7..1467e72caa 100644 --- a/python/paddle/fluid/tests/test_detection.py +++ b/python/paddle/fluid/tests/test_detection.py @@ -12,6 +12,8 @@ # See the License for the specific language governing permissions and # limitations under the License. +from __future__ import print_function + import paddle.fluid as fluid import paddle.fluid.layers as layers from paddle.fluid.framework import Program, program_guard diff --git a/python/paddle/fluid/tests/test_error_clip.py b/python/paddle/fluid/tests/test_error_clip.py index e8edd7fbbb..3c977afc7c 100644 --- a/python/paddle/fluid/tests/test_error_clip.py +++ b/python/paddle/fluid/tests/test_error_clip.py @@ -12,6 +12,8 @@ # See the License for the specific language governing permissions and # limitations under the License. +from __future__ import print_function + import numpy as np import paddle import paddle.fluid as fluid diff --git a/python/paddle/fluid/tests/test_gradient_clip.py b/python/paddle/fluid/tests/test_gradient_clip.py index d530601f13..266687fcd0 100644 --- a/python/paddle/fluid/tests/test_gradient_clip.py +++ b/python/paddle/fluid/tests/test_gradient_clip.py @@ -12,6 +12,8 @@ # See the License for the specific language governing permissions and # limitations under the License. +from __future__ import print_function + import numpy as np import paddle import paddle.fluid as fluid diff --git a/python/paddle/fluid/tests/test_if_else_op.py b/python/paddle/fluid/tests/test_if_else_op.py index 082f64c146..10918a985f 100644 --- a/python/paddle/fluid/tests/test_if_else_op.py +++ b/python/paddle/fluid/tests/test_if_else_op.py @@ -12,6 +12,8 @@ # See the License for the specific language governing permissions and # limitations under the License. +from __future__ import print_function + import paddle import paddle.fluid.layers as layers from paddle.fluid.framework import Program, program_guard diff --git a/python/paddle/fluid/tests/test_lod_tensor.py b/python/paddle/fluid/tests/test_lod_tensor.py index f7a9dd4129..722b5f07b0 100644 --- a/python/paddle/fluid/tests/test_lod_tensor.py +++ b/python/paddle/fluid/tests/test_lod_tensor.py @@ -12,6 +12,8 @@ # See the License for the specific language governing permissions and # limitations under the License. +from __future__ import print_function + import paddle.fluid as fluid from paddle.fluid.lod_tensor import create_lod_tensor, create_random_int_lodtensor import numpy as np diff --git a/python/paddle/fluid/tests/test_python_operator_overriding.py b/python/paddle/fluid/tests/test_python_operator_overriding.py index b5ac97eac5..5f92c437ec 100644 --- a/python/paddle/fluid/tests/test_python_operator_overriding.py +++ b/python/paddle/fluid/tests/test_python_operator_overriding.py @@ -12,6 +12,8 @@ # See the License for the specific language governing permissions and # limitations under the License. +from __future__ import print_function + import unittest import numpy as np diff --git a/python/paddle/fluid/tests/unittests/benchmark.py b/python/paddle/fluid/tests/unittests/benchmark.py index d334d8b60c..9ea95f3e87 100644 --- a/python/paddle/fluid/tests/unittests/benchmark.py +++ b/python/paddle/fluid/tests/unittests/benchmark.py @@ -12,6 +12,8 @@ # See the License for the specific language governing permissions and # limitations under the License. 
+from __future__ import print_function + import numpy as np import unittest import time diff --git a/python/paddle/fluid/tests/unittests/benchmark_sum_op.py b/python/paddle/fluid/tests/unittests/benchmark_sum_op.py index 91a5f1bca4..0e7338b839 100644 --- a/python/paddle/fluid/tests/unittests/benchmark_sum_op.py +++ b/python/paddle/fluid/tests/unittests/benchmark_sum_op.py @@ -12,6 +12,8 @@ # See the License for the specific language governing permissions and # limitations under the License. +from __future__ import print_function + import unittest import numpy as np diff --git a/python/paddle/fluid/tests/unittests/decorators.py b/python/paddle/fluid/tests/unittests/decorators.py index d1165e2a91..1a5f4540cf 100644 --- a/python/paddle/fluid/tests/unittests/decorators.py +++ b/python/paddle/fluid/tests/unittests/decorators.py @@ -12,6 +12,8 @@ # See the License for the specific language governing permissions and # limitations under the License. +from __future__ import print_function + import paddle.fluid as fluid __all__ = ['many_times', 'prog_scope'] diff --git a/python/paddle/fluid/tests/unittests/dist_mnist.py b/python/paddle/fluid/tests/unittests/dist_mnist.py index 8f5ba33f7c..722b3e159a 100644 --- a/python/paddle/fluid/tests/unittests/dist_mnist.py +++ b/python/paddle/fluid/tests/unittests/dist_mnist.py @@ -12,6 +12,8 @@ # See the License for the specific language governing permissions and # limitations under the License. +from __future__ import print_function + import numpy as np import argparse import time diff --git a/python/paddle/fluid/tests/unittests/dist_se_resnext.py b/python/paddle/fluid/tests/unittests/dist_se_resnext.py index f9ac6612df..b0ee6ff9f5 100644 --- a/python/paddle/fluid/tests/unittests/dist_se_resnext.py +++ b/python/paddle/fluid/tests/unittests/dist_se_resnext.py @@ -12,6 +12,8 @@ # See the License for the specific language governing permissions and # limitations under the License. +from __future__ import print_function + import numpy as np import argparse import six diff --git a/python/paddle/fluid/tests/unittests/dist_transformer.py b/python/paddle/fluid/tests/unittests/dist_transformer.py index 41125d38bd..ab4c5c3f36 100644 --- a/python/paddle/fluid/tests/unittests/dist_transformer.py +++ b/python/paddle/fluid/tests/unittests/dist_transformer.py @@ -12,6 +12,8 @@ # See the License for the specific language governing permissions and # limitations under the License. +from __future__ import print_function + import numpy as np import argparse import time diff --git a/python/paddle/fluid/tests/unittests/dist_word2vec.py b/python/paddle/fluid/tests/unittests/dist_word2vec.py index 54a70f4adb..0ad994a258 100644 --- a/python/paddle/fluid/tests/unittests/dist_word2vec.py +++ b/python/paddle/fluid/tests/unittests/dist_word2vec.py @@ -12,6 +12,8 @@ # See the License for the specific language governing permissions and # limitations under the License. +from __future__ import print_function + import numpy as np import argparse import time diff --git a/python/paddle/fluid/tests/unittests/op_test.py b/python/paddle/fluid/tests/unittests/op_test.py index 75373ae2e1..972e44c952 100644 --- a/python/paddle/fluid/tests/unittests/op_test.py +++ b/python/paddle/fluid/tests/unittests/op_test.py @@ -12,6 +12,8 @@ # See the License for the specific language governing permissions and # limitations under the License. 
+from __future__ import print_function + import unittest import numpy as np import random diff --git a/python/paddle/fluid/tests/unittests/parallel_executor_test_base.py b/python/paddle/fluid/tests/unittests/parallel_executor_test_base.py index 67c35e9de7..9be53c4609 100644 --- a/python/paddle/fluid/tests/unittests/parallel_executor_test_base.py +++ b/python/paddle/fluid/tests/unittests/parallel_executor_test_base.py @@ -12,6 +12,8 @@ # See the License for the specific language governing permissions and # limitations under the License. +from __future__ import print_function + import multiprocessing import os import unittest diff --git a/python/paddle/fluid/tests/unittests/test_accuracy_op.py b/python/paddle/fluid/tests/unittests/test_accuracy_op.py index db1861fd10..1b2b53f2d4 100644 --- a/python/paddle/fluid/tests/unittests/test_accuracy_op.py +++ b/python/paddle/fluid/tests/unittests/test_accuracy_op.py @@ -12,6 +12,8 @@ # See the License for the specific language governing permissions and # limitations under the License. +from __future__ import print_function + import unittest import numpy as np from op_test import OpTest diff --git a/python/paddle/fluid/tests/unittests/test_activation_mkldnn_op.py b/python/paddle/fluid/tests/unittests/test_activation_mkldnn_op.py index 7d554c2276..611d0dd076 100644 --- a/python/paddle/fluid/tests/unittests/test_activation_mkldnn_op.py +++ b/python/paddle/fluid/tests/unittests/test_activation_mkldnn_op.py @@ -12,6 +12,8 @@ # See the License for the specific language governing permissions and # limitations under the License. +from __future__ import print_function + import unittest import numpy as np import paddle.fluid.core as core diff --git a/python/paddle/fluid/tests/unittests/test_activation_op.py b/python/paddle/fluid/tests/unittests/test_activation_op.py index 34f9cf0620..30651c1326 100644 --- a/python/paddle/fluid/tests/unittests/test_activation_op.py +++ b/python/paddle/fluid/tests/unittests/test_activation_op.py @@ -12,6 +12,8 @@ # See the License for the specific language governing permissions and # limitations under the License. +from __future__ import print_function + import unittest import numpy as np import paddle.fluid.core as core diff --git a/python/paddle/fluid/tests/unittests/test_adadelta_op.py b/python/paddle/fluid/tests/unittests/test_adadelta_op.py index 1b892e64c7..969a7da3b7 100644 --- a/python/paddle/fluid/tests/unittests/test_adadelta_op.py +++ b/python/paddle/fluid/tests/unittests/test_adadelta_op.py @@ -12,6 +12,8 @@ # See the License for the specific language governing permissions and # limitations under the License. +from __future__ import print_function + import unittest import numpy as np from op_test import OpTest diff --git a/python/paddle/fluid/tests/unittests/test_adagrad_op.py b/python/paddle/fluid/tests/unittests/test_adagrad_op.py index 2f0ea79f4d..fc3b7ce2fd 100644 --- a/python/paddle/fluid/tests/unittests/test_adagrad_op.py +++ b/python/paddle/fluid/tests/unittests/test_adagrad_op.py @@ -12,6 +12,8 @@ # See the License for the specific language governing permissions and # limitations under the License. 
+from __future__ import print_function + import unittest import numpy as np import paddle.fluid.core as core diff --git a/python/paddle/fluid/tests/unittests/test_adam_op.py b/python/paddle/fluid/tests/unittests/test_adam_op.py index fa4b39879c..5318d2f976 100644 --- a/python/paddle/fluid/tests/unittests/test_adam_op.py +++ b/python/paddle/fluid/tests/unittests/test_adam_op.py @@ -12,6 +12,8 @@ # See the License for the specific language governing permissions and # limitations under the License. +from __future__ import print_function + import unittest import numpy as np from op_test import OpTest diff --git a/python/paddle/fluid/tests/unittests/test_adamax_op.py b/python/paddle/fluid/tests/unittests/test_adamax_op.py index 8099beefa5..a6d1be7616 100644 --- a/python/paddle/fluid/tests/unittests/test_adamax_op.py +++ b/python/paddle/fluid/tests/unittests/test_adamax_op.py @@ -12,6 +12,8 @@ # See the License for the specific language governing permissions and # limitations under the License. +from __future__ import print_function + import unittest import numpy as np from op_test import OpTest diff --git a/python/paddle/fluid/tests/unittests/test_anchor_generator_op.py b/python/paddle/fluid/tests/unittests/test_anchor_generator_op.py index 9c7d5d41f0..d31eaa0114 100644 --- a/python/paddle/fluid/tests/unittests/test_anchor_generator_op.py +++ b/python/paddle/fluid/tests/unittests/test_anchor_generator_op.py @@ -12,6 +12,8 @@ # See the License for the specific language governing permissions and # limitations under the License. +from __future__ import print_function + import unittest import numpy as np import sys diff --git a/python/paddle/fluid/tests/unittests/test_arg_min_max_op.py b/python/paddle/fluid/tests/unittests/test_arg_min_max_op.py index e04412f809..0712e102b3 100644 --- a/python/paddle/fluid/tests/unittests/test_arg_min_max_op.py +++ b/python/paddle/fluid/tests/unittests/test_arg_min_max_op.py @@ -12,6 +12,8 @@ # See the License for the specific language governing permissions and # limitations under the License. +from __future__ import print_function + import unittest import numpy as np from op_test import OpTest diff --git a/python/paddle/fluid/tests/unittests/test_argsort_op.py b/python/paddle/fluid/tests/unittests/test_argsort_op.py index b29a102a38..7bc6f2599d 100644 --- a/python/paddle/fluid/tests/unittests/test_argsort_op.py +++ b/python/paddle/fluid/tests/unittests/test_argsort_op.py @@ -12,6 +12,8 @@ # See the License for the specific language governing permissions and # limitations under the License. +from __future__ import print_function + import unittest import numpy as np from op_test import OpTest diff --git a/python/paddle/fluid/tests/unittests/test_array_read_write_op.py b/python/paddle/fluid/tests/unittests/test_array_read_write_op.py index 0000fb0958..b86d0bc43a 100644 --- a/python/paddle/fluid/tests/unittests/test_array_read_write_op.py +++ b/python/paddle/fluid/tests/unittests/test_array_read_write_op.py @@ -12,6 +12,8 @@ # See the License for the specific language governing permissions and # limitations under the License. 
+from __future__ import print_function + import unittest import paddle.fluid.core as core import paddle.fluid.layers as layers diff --git a/python/paddle/fluid/tests/unittests/test_assign_op.py b/python/paddle/fluid/tests/unittests/test_assign_op.py index e93c02bd3e..ba2eecfaf1 100644 --- a/python/paddle/fluid/tests/unittests/test_assign_op.py +++ b/python/paddle/fluid/tests/unittests/test_assign_op.py @@ -12,6 +12,8 @@ # See the License for the specific language governing permissions and # limitations under the License. +from __future__ import print_function + import op_test import numpy import unittest diff --git a/python/paddle/fluid/tests/unittests/test_assign_value_op.py b/python/paddle/fluid/tests/unittests/test_assign_value_op.py index 02f2e6eddc..5a9d8efef1 100644 --- a/python/paddle/fluid/tests/unittests/test_assign_value_op.py +++ b/python/paddle/fluid/tests/unittests/test_assign_value_op.py @@ -12,6 +12,8 @@ # See the License for the specific language governing permissions and # limitations under the License. +from __future__ import print_function + import paddle.fluid as fluid import paddle.fluid.layers as layers import op_test diff --git a/python/paddle/fluid/tests/unittests/test_auc_op.py b/python/paddle/fluid/tests/unittests/test_auc_op.py index 6580c70ca6..5393a17e67 100644 --- a/python/paddle/fluid/tests/unittests/test_auc_op.py +++ b/python/paddle/fluid/tests/unittests/test_auc_op.py @@ -12,6 +12,8 @@ # See the License for the specific language governing permissions and # limitations under the License. +from __future__ import print_function + import unittest import numpy as np from op_test import OpTest diff --git a/python/paddle/fluid/tests/unittests/test_batch_norm_mkldnn_op.py b/python/paddle/fluid/tests/unittests/test_batch_norm_mkldnn_op.py index 18fa546159..1286cee8dc 100644 --- a/python/paddle/fluid/tests/unittests/test_batch_norm_mkldnn_op.py +++ b/python/paddle/fluid/tests/unittests/test_batch_norm_mkldnn_op.py @@ -12,6 +12,8 @@ # See the License for the specific language governing permissions and # limitations under the License. +from __future__ import print_function + import unittest import numpy as np import paddle.fluid.core as core diff --git a/python/paddle/fluid/tests/unittests/test_batch_norm_op.py b/python/paddle/fluid/tests/unittests/test_batch_norm_op.py index f805fdc35f..80261eff4e 100644 --- a/python/paddle/fluid/tests/unittests/test_batch_norm_op.py +++ b/python/paddle/fluid/tests/unittests/test_batch_norm_op.py @@ -12,6 +12,8 @@ # See the License for the specific language governing permissions and # limitations under the License. +from __future__ import print_function + import unittest import numpy as np import paddle.fluid.core as core diff --git a/python/paddle/fluid/tests/unittests/test_beam_search_decode_op.py b/python/paddle/fluid/tests/unittests/test_beam_search_decode_op.py index 4a3ac2a31e..51eee41ab2 100644 --- a/python/paddle/fluid/tests/unittests/test_beam_search_decode_op.py +++ b/python/paddle/fluid/tests/unittests/test_beam_search_decode_op.py @@ -12,6 +12,8 @@ # See the License for the specific language governing permissions and # limitations under the License. 
+from __future__ import print_function + import unittest import numpy as np diff --git a/python/paddle/fluid/tests/unittests/test_beam_search_op.py b/python/paddle/fluid/tests/unittests/test_beam_search_op.py index e8283fc942..c28dda4b53 100644 --- a/python/paddle/fluid/tests/unittests/test_beam_search_op.py +++ b/python/paddle/fluid/tests/unittests/test_beam_search_op.py @@ -12,6 +12,8 @@ # See the License for the specific language governing permissions and # limitations under the License. +from __future__ import print_function + import logging from paddle.fluid.op import Operator, DynamicRecurrentOp import paddle.fluid.core as core diff --git a/python/paddle/fluid/tests/unittests/test_bilinear_interp_op.py b/python/paddle/fluid/tests/unittests/test_bilinear_interp_op.py index b04f25ef87..bed847c3c1 100644 --- a/python/paddle/fluid/tests/unittests/test_bilinear_interp_op.py +++ b/python/paddle/fluid/tests/unittests/test_bilinear_interp_op.py @@ -12,6 +12,8 @@ # See the License for the specific language governing permissions and # limitations under the License. +from __future__ import print_function + import unittest import numpy as np from op_test import OpTest diff --git a/python/paddle/fluid/tests/unittests/test_bilinear_tensor_product_op.py b/python/paddle/fluid/tests/unittests/test_bilinear_tensor_product_op.py index d20a11e27e..46831119c5 100644 --- a/python/paddle/fluid/tests/unittests/test_bilinear_tensor_product_op.py +++ b/python/paddle/fluid/tests/unittests/test_bilinear_tensor_product_op.py @@ -12,6 +12,8 @@ # See the License for the specific language governing permissions and # limitations under the License. +from __future__ import print_function + import unittest import numpy as np from op_test import OpTest diff --git a/python/paddle/fluid/tests/unittests/test_bipartite_match_op.py b/python/paddle/fluid/tests/unittests/test_bipartite_match_op.py index ceeca25b74..5cc8e2ba15 100644 --- a/python/paddle/fluid/tests/unittests/test_bipartite_match_op.py +++ b/python/paddle/fluid/tests/unittests/test_bipartite_match_op.py @@ -11,6 +11,8 @@ #WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. #See the License for the specific language governing permissions and #limitations under the License. + +from __future__ import print_function import unittest import numpy as np from op_test import OpTest diff --git a/python/paddle/fluid/tests/unittests/test_box_coder_op.py b/python/paddle/fluid/tests/unittests/test_box_coder_op.py index 4ce9a4783e..2511c5c22e 100644 --- a/python/paddle/fluid/tests/unittests/test_box_coder_op.py +++ b/python/paddle/fluid/tests/unittests/test_box_coder_op.py @@ -12,6 +12,8 @@ # See the License for the specific language governing permissions and # limitations under the License. +from __future__ import print_function + import unittest import numpy as np import sys diff --git a/python/paddle/fluid/tests/unittests/test_calc_gradient.py b/python/paddle/fluid/tests/unittests/test_calc_gradient.py index 7f2a9e6971..4120a18b72 100644 --- a/python/paddle/fluid/tests/unittests/test_calc_gradient.py +++ b/python/paddle/fluid/tests/unittests/test_calc_gradient.py @@ -12,6 +12,8 @@ # See the License for the specific language governing permissions and # limitations under the License. 
+from __future__ import print_function + import unittest import paddle.fluid as fluid diff --git a/python/paddle/fluid/tests/unittests/test_cast_op.py b/python/paddle/fluid/tests/unittests/test_cast_op.py index b8d3ed3aa3..71a2ccb6da 100644 --- a/python/paddle/fluid/tests/unittests/test_cast_op.py +++ b/python/paddle/fluid/tests/unittests/test_cast_op.py @@ -12,6 +12,8 @@ # See the License for the specific language governing permissions and # limitations under the License. +from __future__ import print_function + import op_test import unittest import numpy as np diff --git a/python/paddle/fluid/tests/unittests/test_chunk_eval_op.py b/python/paddle/fluid/tests/unittests/test_chunk_eval_op.py index 354110f1f9..48eb8e9f75 100644 --- a/python/paddle/fluid/tests/unittests/test_chunk_eval_op.py +++ b/python/paddle/fluid/tests/unittests/test_chunk_eval_op.py @@ -12,6 +12,8 @@ # See the License for the specific language governing permissions and # limitations under the License. +from __future__ import print_function + import unittest import numpy as np from op_test import OpTest diff --git a/python/paddle/fluid/tests/unittests/test_clip_by_norm_op.py b/python/paddle/fluid/tests/unittests/test_clip_by_norm_op.py index 129958fa28..6103c3aafc 100644 --- a/python/paddle/fluid/tests/unittests/test_clip_by_norm_op.py +++ b/python/paddle/fluid/tests/unittests/test_clip_by_norm_op.py @@ -12,6 +12,8 @@ # See the License for the specific language governing permissions and # limitations under the License. +from __future__ import print_function + import unittest import numpy as np from op_test import OpTest diff --git a/python/paddle/fluid/tests/unittests/test_clip_op.py b/python/paddle/fluid/tests/unittests/test_clip_op.py index 3df80c8ec8..32677bdb4c 100644 --- a/python/paddle/fluid/tests/unittests/test_clip_op.py +++ b/python/paddle/fluid/tests/unittests/test_clip_op.py @@ -12,6 +12,8 @@ # See the License for the specific language governing permissions and # limitations under the License. +from __future__ import print_function + import unittest import numpy as np from op_test import OpTest diff --git a/python/paddle/fluid/tests/unittests/test_compare_op.py b/python/paddle/fluid/tests/unittests/test_compare_op.py index 405afebae8..437ad35538 100644 --- a/python/paddle/fluid/tests/unittests/test_compare_op.py +++ b/python/paddle/fluid/tests/unittests/test_compare_op.py @@ -12,6 +12,8 @@ # See the License for the specific language governing permissions and # limitations under the License. +from __future__ import print_function + import op_test import unittest import numpy diff --git a/python/paddle/fluid/tests/unittests/test_compat.py b/python/paddle/fluid/tests/unittests/test_compat.py index eabcced5d1..1c2c46f99a 100644 --- a/python/paddle/fluid/tests/unittests/test_compat.py +++ b/python/paddle/fluid/tests/unittests/test_compat.py @@ -12,6 +12,8 @@ # See the License for the specific language governing permissions and # limitations under the License. +from __future__ import print_function + import unittest import paddle.compat as cpt import six diff --git a/python/paddle/fluid/tests/unittests/test_concat_op.py b/python/paddle/fluid/tests/unittests/test_concat_op.py index e9f3c45dc4..436ab7d49f 100644 --- a/python/paddle/fluid/tests/unittests/test_concat_op.py +++ b/python/paddle/fluid/tests/unittests/test_concat_op.py @@ -12,6 +12,8 @@ # See the License for the specific language governing permissions and # limitations under the License. 
+from __future__ import print_function + import unittest import numpy as np from op_test import OpTest diff --git a/python/paddle/fluid/tests/unittests/test_conditional_block.py b/python/paddle/fluid/tests/unittests/test_conditional_block.py index 77869a1242..5b2b71d050 100644 --- a/python/paddle/fluid/tests/unittests/test_conditional_block.py +++ b/python/paddle/fluid/tests/unittests/test_conditional_block.py @@ -12,6 +12,8 @@ # See the License for the specific language governing permissions and # limitations under the License. +from __future__ import print_function + import unittest import paddle.fluid.layers as layers import paddle.fluid.core as core diff --git a/python/paddle/fluid/tests/unittests/test_const_value.py b/python/paddle/fluid/tests/unittests/test_const_value.py index 58ac6fa0a9..0b2431d772 100644 --- a/python/paddle/fluid/tests/unittests/test_const_value.py +++ b/python/paddle/fluid/tests/unittests/test_const_value.py @@ -12,6 +12,8 @@ # See the License for the specific language governing permissions and # limitations under the License. +from __future__ import print_function + import unittest import paddle.fluid.framework as framework diff --git a/python/paddle/fluid/tests/unittests/test_conv2d_mkldnn_op.py b/python/paddle/fluid/tests/unittests/test_conv2d_mkldnn_op.py index d0de7ad52c..1902a98698 100644 --- a/python/paddle/fluid/tests/unittests/test_conv2d_mkldnn_op.py +++ b/python/paddle/fluid/tests/unittests/test_conv2d_mkldnn_op.py @@ -12,6 +12,8 @@ # See the License for the specific language governing permissions and # limitations under the License. +from __future__ import print_function + import unittest from test_conv2d_op import TestConv2dOp, TestWithPad, TestWithStride diff --git a/python/paddle/fluid/tests/unittests/test_conv2d_op.py b/python/paddle/fluid/tests/unittests/test_conv2d_op.py index 6467e302a5..6a2732e939 100644 --- a/python/paddle/fluid/tests/unittests/test_conv2d_op.py +++ b/python/paddle/fluid/tests/unittests/test_conv2d_op.py @@ -12,6 +12,8 @@ # See the License for the specific language governing permissions and # limitations under the License. +from __future__ import print_function + import unittest import numpy as np diff --git a/python/paddle/fluid/tests/unittests/test_conv2d_transpose_op.py b/python/paddle/fluid/tests/unittests/test_conv2d_transpose_op.py index 1cb50afca5..2a320e735b 100644 --- a/python/paddle/fluid/tests/unittests/test_conv2d_transpose_op.py +++ b/python/paddle/fluid/tests/unittests/test_conv2d_transpose_op.py @@ -12,6 +12,8 @@ # See the License for the specific language governing permissions and # limitations under the License. +from __future__ import print_function + import unittest import numpy as np diff --git a/python/paddle/fluid/tests/unittests/test_conv3d_op.py b/python/paddle/fluid/tests/unittests/test_conv3d_op.py index e473ebacea..ddaf99fe06 100644 --- a/python/paddle/fluid/tests/unittests/test_conv3d_op.py +++ b/python/paddle/fluid/tests/unittests/test_conv3d_op.py @@ -12,6 +12,8 @@ # See the License for the specific language governing permissions and # limitations under the License. 
+from __future__ import print_function + import unittest import numpy as np diff --git a/python/paddle/fluid/tests/unittests/test_conv3d_transpose_op.py b/python/paddle/fluid/tests/unittests/test_conv3d_transpose_op.py index 2e55b89392..8d9075961c 100644 --- a/python/paddle/fluid/tests/unittests/test_conv3d_transpose_op.py +++ b/python/paddle/fluid/tests/unittests/test_conv3d_transpose_op.py @@ -12,6 +12,8 @@ # See the License for the specific language governing permissions and # limitations under the License. +from __future__ import print_function + import unittest import numpy as np diff --git a/python/paddle/fluid/tests/unittests/test_conv_shift_op.py b/python/paddle/fluid/tests/unittests/test_conv_shift_op.py index d524832058..b7364e869e 100644 --- a/python/paddle/fluid/tests/unittests/test_conv_shift_op.py +++ b/python/paddle/fluid/tests/unittests/test_conv_shift_op.py @@ -12,6 +12,8 @@ # See the License for the specific language governing permissions and # limitations under the License. +from __future__ import print_function + import unittest import numpy as np from op_test import OpTest diff --git a/python/paddle/fluid/tests/unittests/test_cos_sim_op.py b/python/paddle/fluid/tests/unittests/test_cos_sim_op.py index 1b27cd5767..3c3fd6d4d7 100644 --- a/python/paddle/fluid/tests/unittests/test_cos_sim_op.py +++ b/python/paddle/fluid/tests/unittests/test_cos_sim_op.py @@ -12,6 +12,8 @@ # See the License for the specific language governing permissions and # limitations under the License. +from __future__ import print_function + import unittest import numpy as np from op_test import OpTest diff --git a/python/paddle/fluid/tests/unittests/test_create_op_doc_string.py b/python/paddle/fluid/tests/unittests/test_create_op_doc_string.py index 07c89eefc3..fd34c8fc93 100644 --- a/python/paddle/fluid/tests/unittests/test_create_op_doc_string.py +++ b/python/paddle/fluid/tests/unittests/test_create_op_doc_string.py @@ -12,6 +12,8 @@ # See the License for the specific language governing permissions and # limitations under the License. +from __future__ import print_function + import unittest import paddle.fluid.layers as layers diff --git a/python/paddle/fluid/tests/unittests/test_crf_decoding_op.py b/python/paddle/fluid/tests/unittests/test_crf_decoding_op.py index 122b076c2d..51bd1300e6 100644 --- a/python/paddle/fluid/tests/unittests/test_crf_decoding_op.py +++ b/python/paddle/fluid/tests/unittests/test_crf_decoding_op.py @@ -12,6 +12,8 @@ # See the License for the specific language governing permissions and # limitations under the License. +from __future__ import print_function + import unittest import random import numpy as np diff --git a/python/paddle/fluid/tests/unittests/test_crop_op.py b/python/paddle/fluid/tests/unittests/test_crop_op.py index 4016089c01..d7bcfba8de 100644 --- a/python/paddle/fluid/tests/unittests/test_crop_op.py +++ b/python/paddle/fluid/tests/unittests/test_crop_op.py @@ -12,6 +12,8 @@ # See the License for the specific language governing permissions and # limitations under the License. 
+from __future__ import print_function + import unittest import numpy as np from op_test import OpTest diff --git a/python/paddle/fluid/tests/unittests/test_cross_entropy_op.py b/python/paddle/fluid/tests/unittests/test_cross_entropy_op.py index 86ac159323..fa367f95fc 100644 --- a/python/paddle/fluid/tests/unittests/test_cross_entropy_op.py +++ b/python/paddle/fluid/tests/unittests/test_cross_entropy_op.py @@ -12,6 +12,8 @@ # See the License for the specific language governing permissions and # limitations under the License. +from __future__ import print_function + import unittest import numpy as np from op_test import OpTest, randomize_probability diff --git a/python/paddle/fluid/tests/unittests/test_ctc_align.py b/python/paddle/fluid/tests/unittests/test_ctc_align.py index 131b4076f4..5f17d2d407 100644 --- a/python/paddle/fluid/tests/unittests/test_ctc_align.py +++ b/python/paddle/fluid/tests/unittests/test_ctc_align.py @@ -12,6 +12,8 @@ # See the License for the specific language governing permissions and # limitations under the License. +from __future__ import print_function + import sys import unittest import numpy as np diff --git a/python/paddle/fluid/tests/unittests/test_cumsum_op.py b/python/paddle/fluid/tests/unittests/test_cumsum_op.py index 04e7f0b945..13a4eacece 100644 --- a/python/paddle/fluid/tests/unittests/test_cumsum_op.py +++ b/python/paddle/fluid/tests/unittests/test_cumsum_op.py @@ -12,6 +12,8 @@ # See the License for the specific language governing permissions and # limitations under the License. +from __future__ import print_function + import unittest import numpy as np from op_test import OpTest diff --git a/python/paddle/fluid/tests/unittests/test_data_balance.py b/python/paddle/fluid/tests/unittests/test_data_balance.py index 09edf05fd7..e39eedd282 100644 --- a/python/paddle/fluid/tests/unittests/test_data_balance.py +++ b/python/paddle/fluid/tests/unittests/test_data_balance.py @@ -12,6 +12,8 @@ # See the License for the specific language governing permissions and # limitations under the License. +from __future__ import print_function + import unittest import paddle.fluid as fluid import paddle diff --git a/python/paddle/fluid/tests/unittests/test_debugger.py b/python/paddle/fluid/tests/unittests/test_debugger.py index 870952f2f9..f4c9466d63 100644 --- a/python/paddle/fluid/tests/unittests/test_debugger.py +++ b/python/paddle/fluid/tests/unittests/test_debugger.py @@ -12,6 +12,8 @@ # See the License for the specific language governing permissions and # limitations under the License. +from __future__ import print_function + import unittest import paddle.fluid as fluid import paddle.fluid.core as core diff --git a/python/paddle/fluid/tests/unittests/test_decayed_adagrad_op.py b/python/paddle/fluid/tests/unittests/test_decayed_adagrad_op.py index 84c44d4817..a664a1529f 100644 --- a/python/paddle/fluid/tests/unittests/test_decayed_adagrad_op.py +++ b/python/paddle/fluid/tests/unittests/test_decayed_adagrad_op.py @@ -12,6 +12,8 @@ # See the License for the specific language governing permissions and # limitations under the License. 
+from __future__ import print_function + import unittest import numpy as np from op_test import OpTest diff --git a/python/paddle/fluid/tests/unittests/test_default_scope_funcs.py b/python/paddle/fluid/tests/unittests/test_default_scope_funcs.py index 868bcca881..01a7b68248 100644 --- a/python/paddle/fluid/tests/unittests/test_default_scope_funcs.py +++ b/python/paddle/fluid/tests/unittests/test_default_scope_funcs.py @@ -12,6 +12,8 @@ # See the License for the specific language governing permissions and # limitations under the License. +from __future__ import print_function + from paddle.fluid.default_scope_funcs import * import unittest diff --git a/python/paddle/fluid/tests/unittests/test_desc_clone.py b/python/paddle/fluid/tests/unittests/test_desc_clone.py index 8603d3a5b3..88d44e453c 100644 --- a/python/paddle/fluid/tests/unittests/test_desc_clone.py +++ b/python/paddle/fluid/tests/unittests/test_desc_clone.py @@ -12,6 +12,8 @@ # See the License for the specific language governing permissions and # limitations under the License. +from __future__ import print_function + import numpy as np import argparse import time diff --git a/python/paddle/fluid/tests/unittests/test_detection_map_op.py b/python/paddle/fluid/tests/unittests/test_detection_map_op.py index a471f62852..f6eb8f2c6d 100644 --- a/python/paddle/fluid/tests/unittests/test_detection_map_op.py +++ b/python/paddle/fluid/tests/unittests/test_detection_map_op.py @@ -12,6 +12,8 @@ # See the License for the specific language governing permissions and # limitations under the License. +from __future__ import print_function + import unittest import numpy as np import six diff --git a/python/paddle/fluid/tests/unittests/test_dist_base.py b/python/paddle/fluid/tests/unittests/test_dist_base.py index ab028dd36f..4c71181d0d 100644 --- a/python/paddle/fluid/tests/unittests/test_dist_base.py +++ b/python/paddle/fluid/tests/unittests/test_dist_base.py @@ -11,6 +11,8 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. + +from __future__ import print_function import time import unittest diff --git a/python/paddle/fluid/tests/unittests/test_dist_mnist.py b/python/paddle/fluid/tests/unittests/test_dist_mnist.py index b3ccec9a7d..4ec68d411b 100644 --- a/python/paddle/fluid/tests/unittests/test_dist_mnist.py +++ b/python/paddle/fluid/tests/unittests/test_dist_mnist.py @@ -11,6 +11,8 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. + +from __future__ import print_function import unittest from test_dist_base import TestDistBase diff --git a/python/paddle/fluid/tests/unittests/test_dist_se_resnext.py b/python/paddle/fluid/tests/unittests/test_dist_se_resnext.py index a33a338fc1..16525f6fdb 100644 --- a/python/paddle/fluid/tests/unittests/test_dist_se_resnext.py +++ b/python/paddle/fluid/tests/unittests/test_dist_se_resnext.py @@ -11,6 +11,8 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. 
+ +from __future__ import print_function import unittest from test_dist_base import TestDistBase diff --git a/python/paddle/fluid/tests/unittests/test_dist_train.py b/python/paddle/fluid/tests/unittests/test_dist_train.py index aab8969a96..52a655635a 100644 --- a/python/paddle/fluid/tests/unittests/test_dist_train.py +++ b/python/paddle/fluid/tests/unittests/test_dist_train.py @@ -12,6 +12,8 @@ # See the License for the specific language governing permissions and # limitations under the License. +from __future__ import print_function + import os import time import unittest diff --git a/python/paddle/fluid/tests/unittests/test_dist_transformer.py b/python/paddle/fluid/tests/unittests/test_dist_transformer.py index 68cd35d751..313207ff9c 100644 --- a/python/paddle/fluid/tests/unittests/test_dist_transformer.py +++ b/python/paddle/fluid/tests/unittests/test_dist_transformer.py @@ -12,6 +12,8 @@ # See the License for the specific language governing permissions and # limitations under the License. +from __future__ import print_function + import unittest from test_dist_base import TestDistBase diff --git a/python/paddle/fluid/tests/unittests/test_dist_transpiler.py b/python/paddle/fluid/tests/unittests/test_dist_transpiler.py index c5ada4f743..1531e43799 100644 --- a/python/paddle/fluid/tests/unittests/test_dist_transpiler.py +++ b/python/paddle/fluid/tests/unittests/test_dist_transpiler.py @@ -12,6 +12,8 @@ # See the License for the specific language governing permissions and # limitations under the License. +from __future__ import print_function + import math import unittest diff --git a/python/paddle/fluid/tests/unittests/test_dist_word2vec.py b/python/paddle/fluid/tests/unittests/test_dist_word2vec.py index 543d0f9dc2..e43992c488 100644 --- a/python/paddle/fluid/tests/unittests/test_dist_word2vec.py +++ b/python/paddle/fluid/tests/unittests/test_dist_word2vec.py @@ -11,6 +11,8 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. + +from __future__ import print_function import unittest from test_dist_base import TestDistBase diff --git a/python/paddle/fluid/tests/unittests/test_dropout_op.py b/python/paddle/fluid/tests/unittests/test_dropout_op.py index eaa3435a86..0296bc2af4 100644 --- a/python/paddle/fluid/tests/unittests/test_dropout_op.py +++ b/python/paddle/fluid/tests/unittests/test_dropout_op.py @@ -12,6 +12,8 @@ # See the License for the specific language governing permissions and # limitations under the License. +from __future__ import print_function + import unittest import numpy as np import paddle.fluid.core as core diff --git a/python/paddle/fluid/tests/unittests/test_dyn_rnn.py b/python/paddle/fluid/tests/unittests/test_dyn_rnn.py index fdc6adc93b..d84dab1499 100644 --- a/python/paddle/fluid/tests/unittests/test_dyn_rnn.py +++ b/python/paddle/fluid/tests/unittests/test_dyn_rnn.py @@ -12,6 +12,8 @@ # See the License for the specific language governing permissions and # limitations under the License. 
+from __future__ import print_function
+
 import paddle.fluid as fluid
 import paddle
 import unittest
diff --git a/python/paddle/fluid/tests/unittests/test_dynrnn_gradient_check.py b/python/paddle/fluid/tests/unittests/test_dynrnn_gradient_check.py
index 7756885166..9d635f36fe 100644
--- a/python/paddle/fluid/tests/unittests/test_dynrnn_gradient_check.py
+++ b/python/paddle/fluid/tests/unittests/test_dynrnn_gradient_check.py
@@ -12,6 +12,8 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.

+from __future__ import print_function
+
 import numpy
 import random
 import collections
diff --git a/python/paddle/fluid/tests/unittests/test_dynrnn_static_input.py b/python/paddle/fluid/tests/unittests/test_dynrnn_static_input.py
index d182889a97..b4359fc69a 100644
--- a/python/paddle/fluid/tests/unittests/test_dynrnn_static_input.py
+++ b/python/paddle/fluid/tests/unittests/test_dynrnn_static_input.py
@@ -12,6 +12,8 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.

+from __future__ import print_function
+
 import unittest
 import paddle
 import paddle.fluid.core as core
diff --git a/python/paddle/fluid/tests/unittests/test_edit_distance_op.py b/python/paddle/fluid/tests/unittests/test_edit_distance_op.py
index 816562621b..4d03523025 100644
--- a/python/paddle/fluid/tests/unittests/test_edit_distance_op.py
+++ b/python/paddle/fluid/tests/unittests/test_edit_distance_op.py
@@ -12,6 +12,8 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.

+from __future__ import print_function
+
 import unittest
 import numpy as np
 from op_test import OpTest
diff --git a/python/paddle/fluid/tests/unittests/test_elementwise_add_mkldnn_op.py b/python/paddle/fluid/tests/unittests/test_elementwise_add_mkldnn_op.py
index bcdbfc8e52..d85cc1f856 100644
--- a/python/paddle/fluid/tests/unittests/test_elementwise_add_mkldnn_op.py
+++ b/python/paddle/fluid/tests/unittests/test_elementwise_add_mkldnn_op.py
@@ -11,6 +11,8 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
+
+from __future__ import print_function
 import unittest
 import numpy as np
 import paddle.fluid.core as core
diff --git a/python/paddle/fluid/tests/unittests/test_elementwise_add_op.py b/python/paddle/fluid/tests/unittests/test_elementwise_add_op.py
index fb9a496126..5aec5d8e38 100644
--- a/python/paddle/fluid/tests/unittests/test_elementwise_add_op.py
+++ b/python/paddle/fluid/tests/unittests/test_elementwise_add_op.py
@@ -11,6 +11,8 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
+
+from __future__ import print_function
 import unittest
 import numpy as np
 import paddle.fluid.core as core
diff --git a/python/paddle/fluid/tests/unittests/test_elementwise_div_op.py b/python/paddle/fluid/tests/unittests/test_elementwise_div_op.py
index bfe022af6d..cadaf1df53 100644
--- a/python/paddle/fluid/tests/unittests/test_elementwise_div_op.py
+++ b/python/paddle/fluid/tests/unittests/test_elementwise_div_op.py
@@ -11,6 +11,8 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
+
+from __future__ import print_function
 import unittest
 import numpy as np
 from op_test import OpTest
diff --git a/python/paddle/fluid/tests/unittests/test_elementwise_gradient_op.py b/python/paddle/fluid/tests/unittests/test_elementwise_gradient_op.py
index 6f35004489..9f452ffde7 100644
--- a/python/paddle/fluid/tests/unittests/test_elementwise_gradient_op.py
+++ b/python/paddle/fluid/tests/unittests/test_elementwise_gradient_op.py
@@ -11,6 +11,8 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
+
+from __future__ import print_function
 import unittest
 import numpy as np

diff --git a/python/paddle/fluid/tests/unittests/test_elementwise_max_op.py b/python/paddle/fluid/tests/unittests/test_elementwise_max_op.py
index b6cd18a579..43c58710ba 100644
--- a/python/paddle/fluid/tests/unittests/test_elementwise_max_op.py
+++ b/python/paddle/fluid/tests/unittests/test_elementwise_max_op.py
@@ -12,6 +12,8 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.

+from __future__ import print_function
+
 import unittest
 import numpy as np
 from op_test import OpTest
diff --git a/python/paddle/fluid/tests/unittests/test_elementwise_min_op.py b/python/paddle/fluid/tests/unittests/test_elementwise_min_op.py
index 92099724fe..45c861e2c3 100644
--- a/python/paddle/fluid/tests/unittests/test_elementwise_min_op.py
+++ b/python/paddle/fluid/tests/unittests/test_elementwise_min_op.py
@@ -12,6 +12,8 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.

+from __future__ import print_function
+
 import unittest
 import numpy as np
 from op_test import OpTest
diff --git a/python/paddle/fluid/tests/unittests/test_elementwise_mul_op.py b/python/paddle/fluid/tests/unittests/test_elementwise_mul_op.py
index 2742bb21d9..775c2253ab 100644
--- a/python/paddle/fluid/tests/unittests/test_elementwise_mul_op.py
+++ b/python/paddle/fluid/tests/unittests/test_elementwise_mul_op.py
@@ -11,6 +11,8 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
+
+from __future__ import print_function
 import unittest
 import numpy as np
 from op_test import OpTest
diff --git a/python/paddle/fluid/tests/unittests/test_elementwise_pow_op.py b/python/paddle/fluid/tests/unittests/test_elementwise_pow_op.py
index a3fd18669c..7bf642f03f 100644
--- a/python/paddle/fluid/tests/unittests/test_elementwise_pow_op.py
+++ b/python/paddle/fluid/tests/unittests/test_elementwise_pow_op.py
@@ -11,6 +11,8 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
+
+from __future__ import print_function
 import unittest
 import numpy as np
 from op_test import OpTest
diff --git a/python/paddle/fluid/tests/unittests/test_elementwise_sub_op.py b/python/paddle/fluid/tests/unittests/test_elementwise_sub_op.py
index 1854232194..6cb88a8bb1 100644
--- a/python/paddle/fluid/tests/unittests/test_elementwise_sub_op.py
+++ b/python/paddle/fluid/tests/unittests/test_elementwise_sub_op.py
@@ -11,6 +11,8 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
+
+from __future__ import print_function
 import unittest
 import numpy as np
 from op_test import OpTest
diff --git a/python/paddle/fluid/tests/unittests/test_exception.py b/python/paddle/fluid/tests/unittests/test_exception.py
index a43df91342..798ed53cdd 100644
--- a/python/paddle/fluid/tests/unittests/test_exception.py
+++ b/python/paddle/fluid/tests/unittests/test_exception.py
@@ -12,6 +12,8 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.

+from __future__ import print_function
+
 import paddle.compat as cpt
 import paddle.fluid.core as core
 import unittest
diff --git a/python/paddle/fluid/tests/unittests/test_executor_and_mul.py b/python/paddle/fluid/tests/unittests/test_executor_and_mul.py
index e1272c1d6d..b1f89eca6e 100644
--- a/python/paddle/fluid/tests/unittests/test_executor_and_mul.py
+++ b/python/paddle/fluid/tests/unittests/test_executor_and_mul.py
@@ -12,6 +12,8 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.

+from __future__ import print_function
+
 import unittest
 import numpy

diff --git a/python/paddle/fluid/tests/unittests/test_expand_op.py b/python/paddle/fluid/tests/unittests/test_expand_op.py
index a91e3aef5a..67a8d8f072 100644
--- a/python/paddle/fluid/tests/unittests/test_expand_op.py
+++ b/python/paddle/fluid/tests/unittests/test_expand_op.py
@@ -12,6 +12,8 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.

+from __future__ import print_function
+
 import unittest
 import numpy as np
 from op_test import OpTest
diff --git a/python/paddle/fluid/tests/unittests/test_extract_rows_op.py b/python/paddle/fluid/tests/unittests/test_extract_rows_op.py
index 6a41c44fe6..8629bcf0f2 100644
--- a/python/paddle/fluid/tests/unittests/test_extract_rows_op.py
+++ b/python/paddle/fluid/tests/unittests/test_extract_rows_op.py
@@ -12,6 +12,8 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.

+from __future__ import print_function
+
 import unittest
 import numpy as np
 import paddle.fluid.core as core
diff --git a/python/paddle/fluid/tests/unittests/test_fake_dequantize_op.py b/python/paddle/fluid/tests/unittests/test_fake_dequantize_op.py
index 026ac2112b..d84ebed3fa 100644
--- a/python/paddle/fluid/tests/unittests/test_fake_dequantize_op.py
+++ b/python/paddle/fluid/tests/unittests/test_fake_dequantize_op.py
@@ -12,6 +12,8 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.

+from __future__ import print_function
+
 import unittest
 import numpy as np
 import math
diff --git a/python/paddle/fluid/tests/unittests/test_fake_quantize_op.py b/python/paddle/fluid/tests/unittests/test_fake_quantize_op.py
index 6c6aa9d3bb..cc0494774a 100644
--- a/python/paddle/fluid/tests/unittests/test_fake_quantize_op.py
+++ b/python/paddle/fluid/tests/unittests/test_fake_quantize_op.py
@@ -12,6 +12,8 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.

+from __future__ import print_function
+
 import unittest
 import numpy as np
 from op_test import OpTest
diff --git a/python/paddle/fluid/tests/unittests/test_fc_mkldnn_op.py b/python/paddle/fluid/tests/unittests/test_fc_mkldnn_op.py
index 3f547f3c48..8795fa8c1c 100644
--- a/python/paddle/fluid/tests/unittests/test_fc_mkldnn_op.py
+++ b/python/paddle/fluid/tests/unittests/test_fc_mkldnn_op.py
@@ -12,6 +12,8 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.

+from __future__ import print_function
+
 import unittest
 import numpy as np
 from op_test import OpTest
diff --git a/python/paddle/fluid/tests/unittests/test_feed_fetch_method.py b/python/paddle/fluid/tests/unittests/test_feed_fetch_method.py
index 8b9da84311..b823d397e9 100644
--- a/python/paddle/fluid/tests/unittests/test_feed_fetch_method.py
+++ b/python/paddle/fluid/tests/unittests/test_feed_fetch_method.py
@@ -12,6 +12,8 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.

+from __future__ import print_function
+
 import paddle.fluid.core as core
 import unittest
 import numpy as np
diff --git a/python/paddle/fluid/tests/unittests/test_fetch_var.py b/python/paddle/fluid/tests/unittests/test_fetch_var.py
index e6f37f0b4c..de339d821b 100644
--- a/python/paddle/fluid/tests/unittests/test_fetch_var.py
+++ b/python/paddle/fluid/tests/unittests/test_fetch_var.py
@@ -12,6 +12,8 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.

+from __future__ import print_function
+
 import paddle.fluid as fluid
 import paddle.fluid.layers as layers
 import op_test
diff --git a/python/paddle/fluid/tests/unittests/test_fill_constant_batch_size_like_op.py b/python/paddle/fluid/tests/unittests/test_fill_constant_batch_size_like_op.py
index 0c75cf33f5..fdc8a118e5 100644
--- a/python/paddle/fluid/tests/unittests/test_fill_constant_batch_size_like_op.py
+++ b/python/paddle/fluid/tests/unittests/test_fill_constant_batch_size_like_op.py
@@ -12,6 +12,8 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.

+from __future__ import print_function
+
 import unittest
 import numpy as np
 from op_test import OpTest
diff --git a/python/paddle/fluid/tests/unittests/test_fill_constant_op.py b/python/paddle/fluid/tests/unittests/test_fill_constant_op.py
index 5e2ddb218a..44fb1d047d 100644
--- a/python/paddle/fluid/tests/unittests/test_fill_constant_op.py
+++ b/python/paddle/fluid/tests/unittests/test_fill_constant_op.py
@@ -12,6 +12,8 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.

+from __future__ import print_function
+
 import unittest
 import numpy as np
 from op_test import OpTest
diff --git a/python/paddle/fluid/tests/unittests/test_fill_op.py b/python/paddle/fluid/tests/unittests/test_fill_op.py
index 762d29199e..b734ee05b3 100644
--- a/python/paddle/fluid/tests/unittests/test_fill_op.py
+++ b/python/paddle/fluid/tests/unittests/test_fill_op.py
@@ -12,6 +12,8 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.

+from __future__ import print_function
+
 import unittest
 import numpy as np
 from op_test import OpTest
diff --git a/python/paddle/fluid/tests/unittests/test_fill_zeros_like_op.py b/python/paddle/fluid/tests/unittests/test_fill_zeros_like_op.py
index c9b3e4ba13..eec73d0beb 100644
--- a/python/paddle/fluid/tests/unittests/test_fill_zeros_like_op.py
+++ b/python/paddle/fluid/tests/unittests/test_fill_zeros_like_op.py
@@ -12,6 +12,8 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.

+from __future__ import print_function
+
 import unittest
 import numpy as np
 from op_test import OpTest
diff --git a/python/paddle/fluid/tests/unittests/test_flatten_op.py b/python/paddle/fluid/tests/unittests/test_flatten_op.py
index f8692ce2ea..17b01e0312 100644
--- a/python/paddle/fluid/tests/unittests/test_flatten_op.py
+++ b/python/paddle/fluid/tests/unittests/test_flatten_op.py
@@ -12,6 +12,8 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.

+from __future__ import print_function
+
 import unittest
 import numpy as np

diff --git a/python/paddle/fluid/tests/unittests/test_framework_debug_str.py b/python/paddle/fluid/tests/unittests/test_framework_debug_str.py
index c906c74afe..72f43e56cc 100644
--- a/python/paddle/fluid/tests/unittests/test_framework_debug_str.py
+++ b/python/paddle/fluid/tests/unittests/test_framework_debug_str.py
@@ -12,6 +12,8 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.

+from __future__ import print_function
+
 import unittest
 from paddle.fluid.framework import Program

diff --git a/python/paddle/fluid/tests/unittests/test_ftrl_op.py b/python/paddle/fluid/tests/unittests/test_ftrl_op.py
index 5f7581391a..a6390b054f 100644
--- a/python/paddle/fluid/tests/unittests/test_ftrl_op.py
+++ b/python/paddle/fluid/tests/unittests/test_ftrl_op.py
@@ -12,6 +12,8 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.

+from __future__ import print_function
+
 import unittest
 import numpy as np
 from op_test import OpTest
diff --git a/python/paddle/fluid/tests/unittests/test_fused_elemwise_activation_op.py b/python/paddle/fluid/tests/unittests/test_fused_elemwise_activation_op.py
index ec0a939e9e..97e1b9061a 100644
--- a/python/paddle/fluid/tests/unittests/test_fused_elemwise_activation_op.py
+++ b/python/paddle/fluid/tests/unittests/test_fused_elemwise_activation_op.py
@@ -11,6 +11,8 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
+
+from __future__ import print_function
 import unittest
 import numpy as np
 import paddle.fluid.core as core
diff --git a/python/paddle/fluid/tests/unittests/test_gather_op.py b/python/paddle/fluid/tests/unittests/test_gather_op.py
index 4ae9086480..bd5785aa55 100644
--- a/python/paddle/fluid/tests/unittests/test_gather_op.py
+++ b/python/paddle/fluid/tests/unittests/test_gather_op.py
@@ -12,6 +12,8 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.

+from __future__ import print_function
+
 import unittest
 import numpy as np
 from op_test import OpTest
diff --git a/python/paddle/fluid/tests/unittests/test_gaussian_random_batch_size_like_op.py b/python/paddle/fluid/tests/unittests/test_gaussian_random_batch_size_like_op.py
index 1398166a74..9a0631fa26 100644
--- a/python/paddle/fluid/tests/unittests/test_gaussian_random_batch_size_like_op.py
+++ b/python/paddle/fluid/tests/unittests/test_gaussian_random_batch_size_like_op.py
@@ -12,6 +12,8 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.

+from __future__ import print_function
+
 import unittest
 import numpy as np
 from op_test import OpTest
diff --git a/python/paddle/fluid/tests/unittests/test_gaussian_random_mkldnn_op.py b/python/paddle/fluid/tests/unittests/test_gaussian_random_mkldnn_op.py
index 3ae877a608..9777ec3906 100644
--- a/python/paddle/fluid/tests/unittests/test_gaussian_random_mkldnn_op.py
+++ b/python/paddle/fluid/tests/unittests/test_gaussian_random_mkldnn_op.py
@@ -12,6 +12,8 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.

+from __future__ import print_function
+
 import unittest
 from test_gaussian_random_op import TestGaussianRandomOp

diff --git a/python/paddle/fluid/tests/unittests/test_gaussian_random_op.py b/python/paddle/fluid/tests/unittests/test_gaussian_random_op.py
index 8481500fd7..496aa41110 100644
--- a/python/paddle/fluid/tests/unittests/test_gaussian_random_op.py
+++ b/python/paddle/fluid/tests/unittests/test_gaussian_random_op.py
@@ -12,6 +12,8 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.

+from __future__ import print_function
+
 import unittest
 import numpy

diff --git a/python/paddle/fluid/tests/unittests/test_get_places_op.py b/python/paddle/fluid/tests/unittests/test_get_places_op.py
index 964423e2d2..441666a97b 100644
--- a/python/paddle/fluid/tests/unittests/test_get_places_op.py
+++ b/python/paddle/fluid/tests/unittests/test_get_places_op.py
@@ -12,6 +12,8 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.

+from __future__ import print_function
+
 import paddle.fluid as fluid
 from paddle.fluid.layers.device import get_places
 import decorators
diff --git a/python/paddle/fluid/tests/unittests/test_gru_op.py b/python/paddle/fluid/tests/unittests/test_gru_op.py
index 1d8db37fe7..001fd7efb1 100644
--- a/python/paddle/fluid/tests/unittests/test_gru_op.py
+++ b/python/paddle/fluid/tests/unittests/test_gru_op.py
@@ -12,6 +12,8 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.

+from __future__ import print_function
+
 import unittest
 import numpy as np
 import math
diff --git a/python/paddle/fluid/tests/unittests/test_gru_unit_op.py b/python/paddle/fluid/tests/unittests/test_gru_unit_op.py
index 87a9eba4d9..b5a66fdf08 100644
--- a/python/paddle/fluid/tests/unittests/test_gru_unit_op.py
+++ b/python/paddle/fluid/tests/unittests/test_gru_unit_op.py
@@ -12,6 +12,8 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.

+from __future__ import print_function
+
 import math
 import unittest
 import numpy as np
diff --git a/python/paddle/fluid/tests/unittests/test_hinge_loss_op.py b/python/paddle/fluid/tests/unittests/test_hinge_loss_op.py
index 70586c6be3..1eb441e2c5 100644
--- a/python/paddle/fluid/tests/unittests/test_hinge_loss_op.py
+++ b/python/paddle/fluid/tests/unittests/test_hinge_loss_op.py
@@ -12,6 +12,8 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.

+from __future__ import print_function
+
 import unittest
 import numpy as np
 from op_test import OpTest
diff --git a/python/paddle/fluid/tests/unittests/test_hsigmoid_op.py b/python/paddle/fluid/tests/unittests/test_hsigmoid_op.py
index daa5da8d95..6948ae3002 100644
--- a/python/paddle/fluid/tests/unittests/test_hsigmoid_op.py
+++ b/python/paddle/fluid/tests/unittests/test_hsigmoid_op.py
@@ -12,6 +12,8 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.

+from __future__ import print_function
+
 import unittest
 import numpy as np
 import math
diff --git a/python/paddle/fluid/tests/unittests/test_huber_loss_op.py b/python/paddle/fluid/tests/unittests/test_huber_loss_op.py
index a8d0a77625..0055ef0052 100644
--- a/python/paddle/fluid/tests/unittests/test_huber_loss_op.py
+++ b/python/paddle/fluid/tests/unittests/test_huber_loss_op.py
@@ -12,6 +12,8 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.

+from __future__ import print_function
+
 import unittest
 import numpy as np
 from op_test import OpTest
diff --git a/python/paddle/fluid/tests/unittests/test_im2sequence_op.py b/python/paddle/fluid/tests/unittests/test_im2sequence_op.py
index 13bc576874..833e46483c 100644
--- a/python/paddle/fluid/tests/unittests/test_im2sequence_op.py
+++ b/python/paddle/fluid/tests/unittests/test_im2sequence_op.py
@@ -11,6 +11,8 @@
 #WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 #See the License for the specific language governing permissions and
 #limitations under the License.
+
+from __future__ import print_function
 import unittest
 import numpy as np
 from op_test import OpTest
diff --git a/python/paddle/fluid/tests/unittests/test_image_classification_layer.py b/python/paddle/fluid/tests/unittests/test_image_classification_layer.py
index 23b1ed957a..405637969a 100644
--- a/python/paddle/fluid/tests/unittests/test_image_classification_layer.py
+++ b/python/paddle/fluid/tests/unittests/test_image_classification_layer.py
@@ -12,6 +12,8 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.

+from __future__ import print_function
+
 import unittest
 import paddle.fluid as fluid

diff --git a/python/paddle/fluid/tests/unittests/test_infer_shape.py b/python/paddle/fluid/tests/unittests/test_infer_shape.py
index ede51f6550..a3d700aad8 100644
--- a/python/paddle/fluid/tests/unittests/test_infer_shape.py
+++ b/python/paddle/fluid/tests/unittests/test_infer_shape.py
@@ -12,6 +12,8 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.

+from __future__ import print_function
+
 import unittest
 import six

diff --git a/python/paddle/fluid/tests/unittests/test_inference_model_io.py b/python/paddle/fluid/tests/unittests/test_inference_model_io.py
index 66cc78e4d4..9962702f69 100644
--- a/python/paddle/fluid/tests/unittests/test_inference_model_io.py
+++ b/python/paddle/fluid/tests/unittests/test_inference_model_io.py
@@ -12,6 +12,8 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.

+from __future__ import print_function
+
 import unittest
 import six

diff --git a/python/paddle/fluid/tests/unittests/test_initializer.py b/python/paddle/fluid/tests/unittests/test_initializer.py
index b215e37986..ab7183f88d 100644
--- a/python/paddle/fluid/tests/unittests/test_initializer.py
+++ b/python/paddle/fluid/tests/unittests/test_initializer.py
@@ -12,6 +12,8 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.

+from __future__ import print_function
+
 import numpy as np
 import unittest

diff --git a/python/paddle/fluid/tests/unittests/test_iou_similarity_op.py b/python/paddle/fluid/tests/unittests/test_iou_similarity_op.py
index eff4212d91..7c1808cf99 100644
--- a/python/paddle/fluid/tests/unittests/test_iou_similarity_op.py
+++ b/python/paddle/fluid/tests/unittests/test_iou_similarity_op.py
@@ -12,6 +12,8 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.

+from __future__ import print_function
+
 import unittest
 import numpy as np
 import numpy.random as random
diff --git a/python/paddle/fluid/tests/unittests/test_is_empty_op.py b/python/paddle/fluid/tests/unittests/test_is_empty_op.py
index 11121d9b65..26d607718a 100644
--- a/python/paddle/fluid/tests/unittests/test_is_empty_op.py
+++ b/python/paddle/fluid/tests/unittests/test_is_empty_op.py
@@ -12,6 +12,8 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.

+from __future__ import print_function
+
 import unittest
 import numpy as np
 from op_test import OpTest
diff --git a/python/paddle/fluid/tests/unittests/test_l1_norm_op.py b/python/paddle/fluid/tests/unittests/test_l1_norm_op.py
index fa5b18a16f..4e24a78ee5 100644
--- a/python/paddle/fluid/tests/unittests/test_l1_norm_op.py
+++ b/python/paddle/fluid/tests/unittests/test_l1_norm_op.py
@@ -12,6 +12,8 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.

+from __future__ import print_function
+
 import numpy as np
 import unittest
 from op_test import OpTest
diff --git a/python/paddle/fluid/tests/unittests/test_label_smooth_op.py b/python/paddle/fluid/tests/unittests/test_label_smooth_op.py
index ca21289a0d..62d385bc52 100644
--- a/python/paddle/fluid/tests/unittests/test_label_smooth_op.py
+++ b/python/paddle/fluid/tests/unittests/test_label_smooth_op.py
@@ -12,6 +12,8 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.

+from __future__ import print_function
+
 import unittest
 import numpy as np
 from op_test import OpTest
diff --git a/python/paddle/fluid/tests/unittests/test_layer_norm_op.py b/python/paddle/fluid/tests/unittests/test_layer_norm_op.py
index 295887ccd1..fb6c43136f 100644
--- a/python/paddle/fluid/tests/unittests/test_layer_norm_op.py
+++ b/python/paddle/fluid/tests/unittests/test_layer_norm_op.py
@@ -11,6 +11,8 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
+
+from __future__ import print_function
 import unittest
 import numpy as np

diff --git a/python/paddle/fluid/tests/unittests/test_learning_rate_scheduler.py b/python/paddle/fluid/tests/unittests/test_learning_rate_scheduler.py
index e628195e72..0d3e6d73e0 100644
--- a/python/paddle/fluid/tests/unittests/test_learning_rate_scheduler.py
+++ b/python/paddle/fluid/tests/unittests/test_learning_rate_scheduler.py
@@ -12,6 +12,8 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.

+from __future__ import print_function
+
 import copy
 import math
 import unittest
diff --git a/python/paddle/fluid/tests/unittests/test_linear_chain_crf_op.py b/python/paddle/fluid/tests/unittests/test_linear_chain_crf_op.py
index 696d0ab4fa..6e31e9204e 100644
--- a/python/paddle/fluid/tests/unittests/test_linear_chain_crf_op.py
+++ b/python/paddle/fluid/tests/unittests/test_linear_chain_crf_op.py
@@ -12,6 +12,8 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.

+from __future__ import print_function
+
 import unittest
 import random
 import numpy as np
diff --git a/python/paddle/fluid/tests/unittests/test_listen_and_serv_op.py b/python/paddle/fluid/tests/unittests/test_listen_and_serv_op.py
index 1cdc695010..48b52a5412 100644
--- a/python/paddle/fluid/tests/unittests/test_listen_and_serv_op.py
+++ b/python/paddle/fluid/tests/unittests/test_listen_and_serv_op.py
@@ -12,6 +12,8 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.

+from __future__ import print_function
+
 import paddle
 import paddle.fluid as fluid
 import os
diff --git a/python/paddle/fluid/tests/unittests/test_lod_array_length_op.py b/python/paddle/fluid/tests/unittests/test_lod_array_length_op.py
index d8b4e40662..15485df5ac 100644
--- a/python/paddle/fluid/tests/unittests/test_lod_array_length_op.py
+++ b/python/paddle/fluid/tests/unittests/test_lod_array_length_op.py
@@ -12,6 +12,8 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.

+from __future__ import print_function
+
 import unittest
 import paddle.fluid.layers as layers
 from paddle.fluid.executor import Executor
diff --git a/python/paddle/fluid/tests/unittests/test_lod_rank_table.py b/python/paddle/fluid/tests/unittests/test_lod_rank_table.py
index d53ead381d..865ca118d5 100644
--- a/python/paddle/fluid/tests/unittests/test_lod_rank_table.py
+++ b/python/paddle/fluid/tests/unittests/test_lod_rank_table.py
@@ -12,6 +12,8 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.

+from __future__ import print_function
+
 from paddle.fluid.layers import data
 from paddle.fluid.layers.control_flow import lod_rank_table
 from paddle.fluid.executor import Executor
diff --git a/python/paddle/fluid/tests/unittests/test_lod_reset_op.py b/python/paddle/fluid/tests/unittests/test_lod_reset_op.py
index 77905c4b96..31f364a42f 100644
--- a/python/paddle/fluid/tests/unittests/test_lod_reset_op.py
+++ b/python/paddle/fluid/tests/unittests/test_lod_reset_op.py
@@ -12,6 +12,8 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.

+from __future__ import print_function
+
 import unittest
 import numpy as np
 from op_test import OpTest
diff --git a/python/paddle/fluid/tests/unittests/test_lod_tensor_array.py b/python/paddle/fluid/tests/unittests/test_lod_tensor_array.py
index 0ac6d9b81d..6ad27de9a0 100644
--- a/python/paddle/fluid/tests/unittests/test_lod_tensor_array.py
+++ b/python/paddle/fluid/tests/unittests/test_lod_tensor_array.py
@@ -12,6 +12,8 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.

+from __future__ import print_function
+
 import unittest
 import paddle.fluid.core as core
 import numpy
diff --git a/python/paddle/fluid/tests/unittests/test_lod_tensor_array_ops.py b/python/paddle/fluid/tests/unittests/test_lod_tensor_array_ops.py
index 9789ff4af6..6a78ef5078 100644
--- a/python/paddle/fluid/tests/unittests/test_lod_tensor_array_ops.py
+++ b/python/paddle/fluid/tests/unittests/test_lod_tensor_array_ops.py
@@ -12,6 +12,8 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.

+from __future__ import print_function
+
 import unittest
 import paddle.fluid.core as core
 import numpy
diff --git a/python/paddle/fluid/tests/unittests/test_log_loss_op.py b/python/paddle/fluid/tests/unittests/test_log_loss_op.py
index d3980b8db9..784f4f648d 100644
--- a/python/paddle/fluid/tests/unittests/test_log_loss_op.py
+++ b/python/paddle/fluid/tests/unittests/test_log_loss_op.py
@@ -12,6 +12,8 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.

+from __future__ import print_function
+
 import unittest
 import numpy as np
 from op_test import OpTest
diff --git a/python/paddle/fluid/tests/unittests/test_logical_op.py b/python/paddle/fluid/tests/unittests/test_logical_op.py
index 1d7dfe60f2..521851a3d5 100644
--- a/python/paddle/fluid/tests/unittests/test_logical_op.py
+++ b/python/paddle/fluid/tests/unittests/test_logical_op.py
@@ -12,6 +12,8 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.

+from __future__ import print_function
+
 import op_test
 import unittest
 import numpy as np
diff --git a/python/paddle/fluid/tests/unittests/test_lookup_sparse_table_op.py b/python/paddle/fluid/tests/unittests/test_lookup_sparse_table_op.py
index aa9eae1e88..7f75d0e6e9 100644
--- a/python/paddle/fluid/tests/unittests/test_lookup_sparse_table_op.py
+++ b/python/paddle/fluid/tests/unittests/test_lookup_sparse_table_op.py
@@ -12,6 +12,8 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.

+from __future__ import print_function
+
 import unittest
 import numpy as np
 from op_test import OpTest
diff --git a/python/paddle/fluid/tests/unittests/test_lookup_table_op.py b/python/paddle/fluid/tests/unittests/test_lookup_table_op.py
index 77fb8154f0..4990ee898d 100644
--- a/python/paddle/fluid/tests/unittests/test_lookup_table_op.py
+++ b/python/paddle/fluid/tests/unittests/test_lookup_table_op.py
@@ -12,6 +12,8 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.

+from __future__ import print_function
+
 import unittest
 import numpy as np
 from op_test import OpTest
diff --git a/python/paddle/fluid/tests/unittests/test_lrn_mkldnn_op.py b/python/paddle/fluid/tests/unittests/test_lrn_mkldnn_op.py
index 966a16dc87..f6bb2ab7a6 100644
--- a/python/paddle/fluid/tests/unittests/test_lrn_mkldnn_op.py
+++ b/python/paddle/fluid/tests/unittests/test_lrn_mkldnn_op.py
@@ -12,6 +12,8 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.

+from __future__ import print_function
+
 import unittest
 from test_lrn_op import TestLRNOp

diff --git a/python/paddle/fluid/tests/unittests/test_lrn_op.py b/python/paddle/fluid/tests/unittests/test_lrn_op.py
index b0930440f2..bb91f26bbb 100644
--- a/python/paddle/fluid/tests/unittests/test_lrn_op.py
+++ b/python/paddle/fluid/tests/unittests/test_lrn_op.py
@@ -12,6 +12,8 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.

+from __future__ import print_function
+
 import unittest
 import numpy as np
 from op_test import OpTest
diff --git a/python/paddle/fluid/tests/unittests/test_lstm_op.py b/python/paddle/fluid/tests/unittests/test_lstm_op.py
index 705a24bd8f..76a24123fc 100644
--- a/python/paddle/fluid/tests/unittests/test_lstm_op.py
+++ b/python/paddle/fluid/tests/unittests/test_lstm_op.py
@@ -12,6 +12,8 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.

+from __future__ import print_function
+
 import unittest
 import numpy as np
 from op_test import OpTest
diff --git a/python/paddle/fluid/tests/unittests/test_lstm_unit_op.py b/python/paddle/fluid/tests/unittests/test_lstm_unit_op.py
index e343265874..eaa6b774c4 100644
--- a/python/paddle/fluid/tests/unittests/test_lstm_unit_op.py
+++ b/python/paddle/fluid/tests/unittests/test_lstm_unit_op.py
@@ -12,6 +12,8 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.

+from __future__ import print_function
+
 import unittest
 import numpy as np
 from op_test import OpTest
diff --git a/python/paddle/fluid/tests/unittests/test_lstmp_op.py b/python/paddle/fluid/tests/unittests/test_lstmp_op.py
index ed2262da4b..9c3ec45515 100644
--- a/python/paddle/fluid/tests/unittests/test_lstmp_op.py
+++ b/python/paddle/fluid/tests/unittests/test_lstmp_op.py
@@ -11,6 +11,8 @@
 #WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 #See the License for the specific language governing permissions and
 #limitations under the License.
+
+from __future__ import print_function
 import unittest
 import numpy as np
 import test_lstm_op as LstmTest
diff --git a/python/paddle/fluid/tests/unittests/test_margin_rank_loss_op.py b/python/paddle/fluid/tests/unittests/test_margin_rank_loss_op.py
index 97c112487f..4a7e952436 100644
--- a/python/paddle/fluid/tests/unittests/test_margin_rank_loss_op.py
+++ b/python/paddle/fluid/tests/unittests/test_margin_rank_loss_op.py
@@ -12,6 +12,8 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.

+from __future__ import print_function
+
 import unittest
 import numpy as np
 from op_test import OpTest
diff --git a/python/paddle/fluid/tests/unittests/test_math_op_patch.py b/python/paddle/fluid/tests/unittests/test_math_op_patch.py
index 852a80261e..b25d40a3a1 100644
--- a/python/paddle/fluid/tests/unittests/test_math_op_patch.py
+++ b/python/paddle/fluid/tests/unittests/test_math_op_patch.py
@@ -12,6 +12,8 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.

+from __future__ import print_function
+
 import unittest
 import decorators
 import paddle.fluid as fluid
diff --git a/python/paddle/fluid/tests/unittests/test_matmul_op.py b/python/paddle/fluid/tests/unittests/test_matmul_op.py
index cae2c8fa87..abf10437d8 100644
--- a/python/paddle/fluid/tests/unittests/test_matmul_op.py
+++ b/python/paddle/fluid/tests/unittests/test_matmul_op.py
@@ -12,6 +12,8 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.

+from __future__ import print_function
+
 import unittest
 import numpy as np
 from op_test import OpTest
diff --git a/python/paddle/fluid/tests/unittests/test_maxout_op.py b/python/paddle/fluid/tests/unittests/test_maxout_op.py
index 2151853ae1..d588b22fe2 100644
--- a/python/paddle/fluid/tests/unittests/test_maxout_op.py
+++ b/python/paddle/fluid/tests/unittests/test_maxout_op.py
@@ -12,6 +12,8 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.

+from __future__ import print_function
+
 import unittest
 import numpy as np
 from op_test import OpTest
diff --git a/python/paddle/fluid/tests/unittests/test_mean_iou.py b/python/paddle/fluid/tests/unittests/test_mean_iou.py
index 32b4ee1847..03e9448317 100644
--- a/python/paddle/fluid/tests/unittests/test_mean_iou.py
+++ b/python/paddle/fluid/tests/unittests/test_mean_iou.py
@@ -12,6 +12,8 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.

+from __future__ import print_function
+
 from __future__ import division
 import unittest
 import numpy as np
diff --git a/python/paddle/fluid/tests/unittests/test_mean_op.py b/python/paddle/fluid/tests/unittests/test_mean_op.py
index 15472a8fc4..ff338f0e00 100644
--- a/python/paddle/fluid/tests/unittests/test_mean_op.py
+++ b/python/paddle/fluid/tests/unittests/test_mean_op.py
@@ -12,6 +12,8 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.

+from __future__ import print_function
+
 import unittest
 import numpy as np
 from op_test import OpTest
diff --git a/python/paddle/fluid/tests/unittests/test_memory_usage.py b/python/paddle/fluid/tests/unittests/test_memory_usage.py
index f9daf83652..4cdb5b5d9f 100644
--- a/python/paddle/fluid/tests/unittests/test_memory_usage.py
+++ b/python/paddle/fluid/tests/unittests/test_memory_usage.py
@@ -34,7 +34,7 @@ def train_simulator(test_batch_size=10):
     sgd_optimizer = fluid.optimizer.SGD(learning_rate=0.001)
     sgd_optimizer.minimize(avg_cost)

-    # Calculate memory usage in current network config 
+    # Calculate memory usage in current network config
     lower_usage, upper_usage, unit = fluid.contrib.memory_usage(
         fluid.default_main_program(), batch_size=test_batch_size)

diff --git a/python/paddle/fluid/tests/unittests/test_merge_ids_op.py b/python/paddle/fluid/tests/unittests/test_merge_ids_op.py
index f209bdf30f..26ce702411 100644
--- a/python/paddle/fluid/tests/unittests/test_merge_ids_op.py
+++ b/python/paddle/fluid/tests/unittests/test_merge_ids_op.py
@@ -12,6 +12,8 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.

+from __future__ import print_function
+
 import unittest
 import numpy as np
 from op_test import OpTest
diff --git a/python/paddle/fluid/tests/unittests/test_mine_hard_examples_op.py b/python/paddle/fluid/tests/unittests/test_mine_hard_examples_op.py
index 54ee85c1a7..4e5cc91268 100644
--- a/python/paddle/fluid/tests/unittests/test_mine_hard_examples_op.py
+++ b/python/paddle/fluid/tests/unittests/test_mine_hard_examples_op.py
@@ -12,6 +12,8 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.

+from __future__ import print_function
+
 import unittest
 import numpy as np
 import sys
diff --git a/python/paddle/fluid/tests/unittests/test_minus_op.py b/python/paddle/fluid/tests/unittests/test_minus_op.py
index ee32bd4992..54253b17b9 100644
--- a/python/paddle/fluid/tests/unittests/test_minus_op.py
+++ b/python/paddle/fluid/tests/unittests/test_minus_op.py
@@ -12,6 +12,8 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.

+from __future__ import print_function
+
 import unittest
 import numpy as np
 from op_test import OpTest
diff --git a/python/paddle/fluid/tests/unittests/test_modified_huber_loss_op.py b/python/paddle/fluid/tests/unittests/test_modified_huber_loss_op.py
index 62035efe8e..02fecfe47e 100644
--- a/python/paddle/fluid/tests/unittests/test_modified_huber_loss_op.py
+++ b/python/paddle/fluid/tests/unittests/test_modified_huber_loss_op.py
@@ -12,6 +12,8 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.

+from __future__ import print_function
+
 import unittest
 import numpy as np
 from op_test import OpTest
diff --git a/python/paddle/fluid/tests/unittests/test_momentum_op.py b/python/paddle/fluid/tests/unittests/test_momentum_op.py
index c75d3bd276..7137fd0fdb 100644
--- a/python/paddle/fluid/tests/unittests/test_momentum_op.py
+++ b/python/paddle/fluid/tests/unittests/test_momentum_op.py
@@ -12,6 +12,8 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.

+from __future__ import print_function
+
 import unittest
 import numpy as np
 from op_test import OpTest
diff --git a/python/paddle/fluid/tests/unittests/test_mul_op.py b/python/paddle/fluid/tests/unittests/test_mul_op.py
index bbc782c1bc..fca4ffa88b 100644
--- a/python/paddle/fluid/tests/unittests/test_mul_op.py
+++ b/python/paddle/fluid/tests/unittests/test_mul_op.py
@@ -12,6 +12,8 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.

+from __future__ import print_function
+
 import unittest
 import numpy as np
 import paddle.fluid.core as core
diff --git a/python/paddle/fluid/tests/unittests/test_multi_file_reader.py b/python/paddle/fluid/tests/unittests/test_multi_file_reader.py
index cb0ea96ff6..09788868cc 100644
--- a/python/paddle/fluid/tests/unittests/test_multi_file_reader.py
+++ b/python/paddle/fluid/tests/unittests/test_multi_file_reader.py
@@ -12,6 +12,8 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.

+from __future__ import print_function
+
 import unittest
 import paddle.fluid as fluid

diff --git a/python/paddle/fluid/tests/unittests/test_multi_pass_reader.py b/python/paddle/fluid/tests/unittests/test_multi_pass_reader.py
index 7fc9f55044..4fae11e928 100644
--- a/python/paddle/fluid/tests/unittests/test_multi_pass_reader.py
+++ b/python/paddle/fluid/tests/unittests/test_multi_pass_reader.py
@@ -12,6 +12,8 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.

+from __future__ import print_function
+
 import unittest
 import paddle.fluid as fluid

diff --git a/python/paddle/fluid/tests/unittests/test_multiclass_nms_op.py b/python/paddle/fluid/tests/unittests/test_multiclass_nms_op.py
index 10cb78a08d..df0562dcc7 100644
--- a/python/paddle/fluid/tests/unittests/test_multiclass_nms_op.py
+++ b/python/paddle/fluid/tests/unittests/test_multiclass_nms_op.py
@@ -11,6 +11,8 @@
 #WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 #See the License for the specific language governing permissions and
 #limitations under the License.
+
+from __future__ import print_function
 import unittest
 import numpy as np
 import copy
diff --git a/python/paddle/fluid/tests/unittests/test_multihead_attention.py b/python/paddle/fluid/tests/unittests/test_multihead_attention.py
index 80c3c67967..f60da862ac 100644
--- a/python/paddle/fluid/tests/unittests/test_multihead_attention.py
+++ b/python/paddle/fluid/tests/unittests/test_multihead_attention.py
@@ -12,6 +12,8 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.

+from __future__ import print_function
+
 import unittest
 import paddle.fluid as fluid
 import paddle.fluid.core as core
diff --git a/python/paddle/fluid/tests/unittests/test_multiplex_op.py b/python/paddle/fluid/tests/unittests/test_multiplex_op.py
index 03cad8b43b..1567a74808 100644
--- a/python/paddle/fluid/tests/unittests/test_multiplex_op.py
+++ b/python/paddle/fluid/tests/unittests/test_multiplex_op.py
@@ -12,6 +12,8 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.

+from __future__ import print_function
+
 import unittest
 import numpy as np
 from op_test import OpTest
diff --git a/python/paddle/fluid/tests/unittests/test_nce.py b/python/paddle/fluid/tests/unittests/test_nce.py
index 7431a142c5..0745bd274f 100644
--- a/python/paddle/fluid/tests/unittests/test_nce.py
+++ b/python/paddle/fluid/tests/unittests/test_nce.py
@@ -12,6 +12,8 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.

+from __future__ import print_function
+
 import unittest
 import numpy as np
 from op_test import OpTest
diff --git a/python/paddle/fluid/tests/unittests/test_network_with_dtype.py b/python/paddle/fluid/tests/unittests/test_network_with_dtype.py
index d4835dd184..60dcf195da 100644
--- a/python/paddle/fluid/tests/unittests/test_network_with_dtype.py
+++ b/python/paddle/fluid/tests/unittests/test_network_with_dtype.py
@@ -12,6 +12,8 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.

+from __future__ import print_function
+
 import unittest
 import numpy as np

diff --git a/python/paddle/fluid/tests/unittests/test_norm_op.py b/python/paddle/fluid/tests/unittests/test_norm_op.py
index 108a665f37..22bc45ff1e 100644
--- a/python/paddle/fluid/tests/unittests/test_norm_op.py
+++ b/python/paddle/fluid/tests/unittests/test_norm_op.py
@@ -12,6 +12,8 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.

+from __future__ import print_function
+
 import unittest
 import numpy as np
 from op_test import OpTest
diff --git a/python/paddle/fluid/tests/unittests/test_normalization_wrapper.py b/python/paddle/fluid/tests/unittests/test_normalization_wrapper.py
index 198c68866d..24fdcf8c88 100644
--- a/python/paddle/fluid/tests/unittests/test_normalization_wrapper.py
+++ b/python/paddle/fluid/tests/unittests/test_normalization_wrapper.py
@@ -12,6 +12,8 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.

+from __future__ import print_function
+
 import unittest
 import paddle.fluid as fluid
 import paddle.fluid.core as core
diff --git a/python/paddle/fluid/tests/unittests/test_nvprof.py b/python/paddle/fluid/tests/unittests/test_nvprof.py
index 226e5e5d11..da943d64da 100644
--- a/python/paddle/fluid/tests/unittests/test_nvprof.py
+++ b/python/paddle/fluid/tests/unittests/test_nvprof.py
@@ -12,6 +12,8 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.

+from __future__ import print_function
+
 import unittest
 import os
 import numpy as np
diff --git a/python/paddle/fluid/tests/unittests/test_one_hot_op.py b/python/paddle/fluid/tests/unittests/test_one_hot_op.py
index 06fccd39ac..7afdae804a 100644
--- a/python/paddle/fluid/tests/unittests/test_one_hot_op.py
+++ b/python/paddle/fluid/tests/unittests/test_one_hot_op.py
@@ -12,6 +12,8 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.

+from __future__ import print_function
+
 import unittest
 import numpy as np
 import math
diff --git a/python/paddle/fluid/tests/unittests/test_op_support_gpu.py b/python/paddle/fluid/tests/unittests/test_op_support_gpu.py
index 5fafb8280e..e203fccd03 100644
--- a/python/paddle/fluid/tests/unittests/test_op_support_gpu.py
+++ b/python/paddle/fluid/tests/unittests/test_op_support_gpu.py
@@ -12,6 +12,8 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.

+from __future__ import print_function
+
 import unittest
 import paddle.fluid.core as core

diff --git a/python/paddle/fluid/tests/unittests/test_operator.py b/python/paddle/fluid/tests/unittests/test_operator.py
index 5e418fe6ac..544fca8cec 100644
--- a/python/paddle/fluid/tests/unittests/test_operator.py
+++ b/python/paddle/fluid/tests/unittests/test_operator.py
@@ -12,6 +12,8 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.

+from __future__ import print_function
+
 import unittest
 import paddle.fluid.op as op

diff --git a/python/paddle/fluid/tests/unittests/test_operator_desc.py b/python/paddle/fluid/tests/unittests/test_operator_desc.py
index 21a113f509..6d01955993 100644
--- a/python/paddle/fluid/tests/unittests/test_operator_desc.py
+++ b/python/paddle/fluid/tests/unittests/test_operator_desc.py
@@ -12,6 +12,8 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.

+from __future__ import print_function
+
 import unittest
 import paddle.fluid.core as core

diff --git a/python/paddle/fluid/tests/unittests/test_optimizer.py b/python/paddle/fluid/tests/unittests/test_optimizer.py
index 18921d727f..4374d198f2 100644
--- a/python/paddle/fluid/tests/unittests/test_optimizer.py
+++ b/python/paddle/fluid/tests/unittests/test_optimizer.py
@@ -12,6 +12,8 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.

+from __future__ import print_function
+
 import unittest
 import paddle.fluid.framework as framework

diff --git a/python/paddle/fluid/tests/unittests/test_pad_op.py b/python/paddle/fluid/tests/unittests/test_pad_op.py
index 300f3ffcb8..58e56ca1a4 100644
--- a/python/paddle/fluid/tests/unittests/test_pad_op.py
+++ b/python/paddle/fluid/tests/unittests/test_pad_op.py
@@ -12,6 +12,8 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.

+from __future__ import print_function
+
 import unittest
 import numpy as np
 from op_test import OpTest
diff --git a/python/paddle/fluid/tests/unittests/test_parallel_executor_crf.py b/python/paddle/fluid/tests/unittests/test_parallel_executor_crf.py
index d17e493c36..6d6917300c 100644
--- a/python/paddle/fluid/tests/unittests/test_parallel_executor_crf.py
+++ b/python/paddle/fluid/tests/unittests/test_parallel_executor_crf.py
@@ -12,6 +12,8 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.

+from __future__ import print_function
+
 import paddle.dataset.conll05 as conll05
 import paddle.fluid as fluid
 import unittest
diff --git a/python/paddle/fluid/tests/unittests/test_parallel_executor_fetch_feed.py b/python/paddle/fluid/tests/unittests/test_parallel_executor_fetch_feed.py
index a43f2e7c49..372ef748b2 100644
--- a/python/paddle/fluid/tests/unittests/test_parallel_executor_fetch_feed.py
+++ b/python/paddle/fluid/tests/unittests/test_parallel_executor_fetch_feed.py
@@ -12,6 +12,8 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.

+from __future__ import print_function
+
 import paddle.dataset.flowers as flowers
 import math
 import paddle.fluid as fluid
diff --git a/python/paddle/fluid/tests/unittests/test_parallel_executor_mnist.py b/python/paddle/fluid/tests/unittests/test_parallel_executor_mnist.py
index 9448d89cd5..893acd763f 100644
--- a/python/paddle/fluid/tests/unittests/test_parallel_executor_mnist.py
+++ b/python/paddle/fluid/tests/unittests/test_parallel_executor_mnist.py
@@ -12,6 +12,8 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.

+from __future__ import print_function
+
 from parallel_executor_test_base import TestParallelExecutorBase
 import paddle.fluid as fluid
 import paddle.fluid.core as core
diff --git a/python/paddle/fluid/tests/unittests/test_parallel_executor_seresnext.py b/python/paddle/fluid/tests/unittests/test_parallel_executor_seresnext.py
index 672e94480a..cc2d692e18 100644
--- a/python/paddle/fluid/tests/unittests/test_parallel_executor_seresnext.py
+++ b/python/paddle/fluid/tests/unittests/test_parallel_executor_seresnext.py
@@ -12,6 +12,8 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.

+from __future__ import print_function
+
 import paddle.fluid as fluid
 import paddle.fluid.layers.ops as ops
 from paddle.fluid.initializer import init_on_cpu
diff --git a/python/paddle/fluid/tests/unittests/test_parallel_executor_test_while_train.py b/python/paddle/fluid/tests/unittests/test_parallel_executor_test_while_train.py
index fcb5947ff0..f5a0ba6246 100644
--- a/python/paddle/fluid/tests/unittests/test_parallel_executor_test_while_train.py
+++ b/python/paddle/fluid/tests/unittests/test_parallel_executor_test_while_train.py
@@ -12,6 +12,8 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.

+from __future__ import print_function
+
 import paddle.fluid as fluid
 import paddle.fluid.core as core
 import numpy as np
diff --git a/python/paddle/fluid/tests/unittests/test_parallel_executor_transformer.py b/python/paddle/fluid/tests/unittests/test_parallel_executor_transformer.py
index 8203d5d1fc..5ad922725a 100644
--- a/python/paddle/fluid/tests/unittests/test_parallel_executor_transformer.py
+++ b/python/paddle/fluid/tests/unittests/test_parallel_executor_transformer.py
@@ -12,6 +12,8 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.

+from __future__ import print_function
+
 import paddle.fluid as fluid
 import transformer_model
 import numpy as np
diff --git a/python/paddle/fluid/tests/unittests/test_parallel_op.py b/python/paddle/fluid/tests/unittests/test_parallel_op.py
index c9617e3677..d7b9af8bac 100644
--- a/python/paddle/fluid/tests/unittests/test_parallel_op.py
+++ b/python/paddle/fluid/tests/unittests/test_parallel_op.py
@@ -12,6 +12,8 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.

+from __future__ import print_function
+
 import unittest
 import paddle.fluid as fluid

diff --git a/python/paddle/fluid/tests/unittests/test_parameter.py b/python/paddle/fluid/tests/unittests/test_parameter.py
index e09865074e..df42e6cb9a 100644
--- a/python/paddle/fluid/tests/unittests/test_parameter.py
+++ b/python/paddle/fluid/tests/unittests/test_parameter.py
@@ -12,6 +12,8 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.

+from __future__ import print_function
+
 import unittest
 from paddle.fluid.framework import default_main_program
 import paddle.fluid.core as core
diff --git a/python/paddle/fluid/tests/unittests/test_polygon_box_transform.py b/python/paddle/fluid/tests/unittests/test_polygon_box_transform.py
index 8aff4e87f6..dfedf8190f 100644
--- a/python/paddle/fluid/tests/unittests/test_polygon_box_transform.py
+++ b/python/paddle/fluid/tests/unittests/test_polygon_box_transform.py
@@ -12,6 +12,8 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.

+from __future__ import print_function
+
 import unittest
 import numpy as np
 from op_test import OpTest
diff --git a/python/paddle/fluid/tests/unittests/test_pool2d_mkldnn_op.py b/python/paddle/fluid/tests/unittests/test_pool2d_mkldnn_op.py
index 003ebba18b..14d7ed9057 100644
--- a/python/paddle/fluid/tests/unittests/test_pool2d_mkldnn_op.py
+++ b/python/paddle/fluid/tests/unittests/test_pool2d_mkldnn_op.py
@@ -12,6 +12,8 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.

+from __future__ import print_function
+
 import unittest
 from test_pool2d_op import TestPool2d_Op, TestCase1, TestCase2, TestCase3, TestCase4, TestCase5

diff --git a/python/paddle/fluid/tests/unittests/test_pool2d_op.py b/python/paddle/fluid/tests/unittests/test_pool2d_op.py
index a75194f34a..26969bd523 100644
--- a/python/paddle/fluid/tests/unittests/test_pool2d_op.py
+++ b/python/paddle/fluid/tests/unittests/test_pool2d_op.py
@@ -12,6 +12,8 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.

+from __future__ import print_function
+
 import unittest
 import numpy as np

diff --git a/python/paddle/fluid/tests/unittests/test_pool3d_op.py b/python/paddle/fluid/tests/unittests/test_pool3d_op.py
index 8b96a0e22a..77045c1307 100644
--- a/python/paddle/fluid/tests/unittests/test_pool3d_op.py
+++ b/python/paddle/fluid/tests/unittests/test_pool3d_op.py
@@ -12,6 +12,8 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.

+from __future__ import print_function
+
 import unittest
 import numpy as np

diff --git a/python/paddle/fluid/tests/unittests/test_pool_max_op.py b/python/paddle/fluid/tests/unittests/test_pool_max_op.py
index 9a23fde340..488ff431d4 100644
--- a/python/paddle/fluid/tests/unittests/test_pool_max_op.py
+++ b/python/paddle/fluid/tests/unittests/test_pool_max_op.py
@@ -12,6 +12,8 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.

+from __future__ import print_function
+
 import unittest
 import numpy as np
 from op_test import OpTest
diff --git a/python/paddle/fluid/tests/unittests/test_positive_negative_pair_op.py b/python/paddle/fluid/tests/unittests/test_positive_negative_pair_op.py
index fcb308ae2c..afe8d212d6 100644
--- a/python/paddle/fluid/tests/unittests/test_positive_negative_pair_op.py
+++ b/python/paddle/fluid/tests/unittests/test_positive_negative_pair_op.py
@@ -12,6 +12,8 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.

+from __future__ import print_function + import unittest import itertools import numpy as np diff --git a/python/paddle/fluid/tests/unittests/test_precision_recall_op.py b/python/paddle/fluid/tests/unittests/test_precision_recall_op.py index 5ae425fee1..6456376259 100644 --- a/python/paddle/fluid/tests/unittests/test_precision_recall_op.py +++ b/python/paddle/fluid/tests/unittests/test_precision_recall_op.py @@ -12,6 +12,8 @@ # See the License for the specific language governing permissions and # limitations under the License. +from __future__ import print_function + import unittest import numpy as np from op_test import OpTest diff --git a/python/paddle/fluid/tests/unittests/test_prelu_op.py b/python/paddle/fluid/tests/unittests/test_prelu_op.py index ae19a553bb..e0ea74b6ad 100644 --- a/python/paddle/fluid/tests/unittests/test_prelu_op.py +++ b/python/paddle/fluid/tests/unittests/test_prelu_op.py @@ -12,6 +12,8 @@ # See the License for the specific language governing permissions and # limitations under the License. +from __future__ import print_function + import unittest import numpy as np from op_test import OpTest diff --git a/python/paddle/fluid/tests/unittests/test_preprocessor.py b/python/paddle/fluid/tests/unittests/test_preprocessor.py index 6a82746c61..98e609b769 100644 --- a/python/paddle/fluid/tests/unittests/test_preprocessor.py +++ b/python/paddle/fluid/tests/unittests/test_preprocessor.py @@ -12,6 +12,8 @@ # See the License for the specific language governing permissions and # limitations under the License. +from __future__ import print_function + import unittest import numpy as np diff --git a/python/paddle/fluid/tests/unittests/test_print_op.py b/python/paddle/fluid/tests/unittests/test_print_op.py index b461c5c940..ac682d6181 100644 --- a/python/paddle/fluid/tests/unittests/test_print_op.py +++ b/python/paddle/fluid/tests/unittests/test_print_op.py @@ -12,6 +12,8 @@ # See the License for the specific language governing permissions and # limitations under the License. +from __future__ import print_function + import unittest import paddle.fluid.core as core from paddle.fluid.executor import Executor diff --git a/python/paddle/fluid/tests/unittests/test_prior_box_op.py b/python/paddle/fluid/tests/unittests/test_prior_box_op.py index e15554737b..7381b74af7 100644 --- a/python/paddle/fluid/tests/unittests/test_prior_box_op.py +++ b/python/paddle/fluid/tests/unittests/test_prior_box_op.py @@ -12,6 +12,8 @@ # See the License for the specific language governing permissions and # limitations under the License. +from __future__ import print_function + import unittest import numpy as np import sys diff --git a/python/paddle/fluid/tests/unittests/test_profiler.py b/python/paddle/fluid/tests/unittests/test_profiler.py index 705d01165a..38a7c913bf 100644 --- a/python/paddle/fluid/tests/unittests/test_profiler.py +++ b/python/paddle/fluid/tests/unittests/test_profiler.py @@ -12,6 +12,8 @@ # See the License for the specific language governing permissions and # limitations under the License. +from __future__ import print_function + import unittest import os import numpy as np diff --git a/python/paddle/fluid/tests/unittests/test_protobuf.py b/python/paddle/fluid/tests/unittests/test_protobuf.py index c3f1fa8018..7b80927c48 100644 --- a/python/paddle/fluid/tests/unittests/test_protobuf.py +++ b/python/paddle/fluid/tests/unittests/test_protobuf.py @@ -12,6 +12,8 @@ # See the License for the specific language governing permissions and # limitations under the License. 
+from __future__ import print_function + import paddle.fluid.proto.framework_pb2 as framework_pb2 import unittest diff --git a/python/paddle/fluid/tests/unittests/test_protobuf_descs.py b/python/paddle/fluid/tests/unittests/test_protobuf_descs.py index f7087299cf..d24b5cbd06 100644 --- a/python/paddle/fluid/tests/unittests/test_protobuf_descs.py +++ b/python/paddle/fluid/tests/unittests/test_protobuf_descs.py @@ -12,6 +12,8 @@ # See the License for the specific language governing permissions and # limitations under the License. +from __future__ import print_function + import unittest import paddle.fluid.core as core import paddle.compat as cpt diff --git a/python/paddle/fluid/tests/unittests/test_proximal_adagrad_op.py b/python/paddle/fluid/tests/unittests/test_proximal_adagrad_op.py index 3c26895850..57e96f1fa3 100644 --- a/python/paddle/fluid/tests/unittests/test_proximal_adagrad_op.py +++ b/python/paddle/fluid/tests/unittests/test_proximal_adagrad_op.py @@ -12,6 +12,8 @@ # See the License for the specific language governing permissions and # limitations under the License. +from __future__ import print_function + import unittest import numpy as np from op_test import OpTest diff --git a/python/paddle/fluid/tests/unittests/test_proximal_gd_op.py b/python/paddle/fluid/tests/unittests/test_proximal_gd_op.py index 137594b9a0..067502baec 100644 --- a/python/paddle/fluid/tests/unittests/test_proximal_gd_op.py +++ b/python/paddle/fluid/tests/unittests/test_proximal_gd_op.py @@ -12,6 +12,8 @@ # See the License for the specific language governing permissions and # limitations under the License. +from __future__ import print_function + import unittest import numpy as np from op_test import OpTest diff --git a/python/paddle/fluid/tests/unittests/test_py_reader_push_pop.py b/python/paddle/fluid/tests/unittests/test_py_reader_push_pop.py index f9bda5e470..3efe5aac88 100644 --- a/python/paddle/fluid/tests/unittests/test_py_reader_push_pop.py +++ b/python/paddle/fluid/tests/unittests/test_py_reader_push_pop.py @@ -12,6 +12,8 @@ # See the License for the specific language governing permissions and # limitations under the License. +from __future__ import print_function + import unittest import paddle.fluid as fluid import numpy as np diff --git a/python/paddle/fluid/tests/unittests/test_py_reader_using_executor.py b/python/paddle/fluid/tests/unittests/test_py_reader_using_executor.py index 9a379bdbaa..931cac409f 100644 --- a/python/paddle/fluid/tests/unittests/test_py_reader_using_executor.py +++ b/python/paddle/fluid/tests/unittests/test_py_reader_using_executor.py @@ -12,6 +12,8 @@ # See the License for the specific language governing permissions and # limitations under the License. +from __future__ import print_function + import unittest import paddle.fluid as fluid import paddle.fluid.core as core diff --git a/python/paddle/fluid/tests/unittests/test_random_crop_op.py b/python/paddle/fluid/tests/unittests/test_random_crop_op.py index 27e5db4991..f29dddff7a 100644 --- a/python/paddle/fluid/tests/unittests/test_random_crop_op.py +++ b/python/paddle/fluid/tests/unittests/test_random_crop_op.py @@ -12,6 +12,8 @@ # See the License for the specific language governing permissions and # limitations under the License. 
+from __future__ import print_function + import unittest import numpy as np import paddle.fluid.core as core diff --git a/python/paddle/fluid/tests/unittests/test_rank_loss_op.py b/python/paddle/fluid/tests/unittests/test_rank_loss_op.py index 7eba1e2077..c9fa24b103 100644 --- a/python/paddle/fluid/tests/unittests/test_rank_loss_op.py +++ b/python/paddle/fluid/tests/unittests/test_rank_loss_op.py @@ -12,6 +12,8 @@ # See the License for the specific language governing permissions and # limitations under the License. +from __future__ import print_function + import unittest import numpy as np from op_test import OpTest diff --git a/python/paddle/fluid/tests/unittests/test_reader_reset.py b/python/paddle/fluid/tests/unittests/test_reader_reset.py index 698612acf4..8ad11d76f6 100644 --- a/python/paddle/fluid/tests/unittests/test_reader_reset.py +++ b/python/paddle/fluid/tests/unittests/test_reader_reset.py @@ -12,6 +12,8 @@ # See the License for the specific language governing permissions and # limitations under the License. +from __future__ import print_function + import paddle.fluid as fluid import paddle import numpy as np diff --git a/python/paddle/fluid/tests/unittests/test_recordio_reader.py b/python/paddle/fluid/tests/unittests/test_recordio_reader.py index 09c3167152..c5210bb208 100644 --- a/python/paddle/fluid/tests/unittests/test_recordio_reader.py +++ b/python/paddle/fluid/tests/unittests/test_recordio_reader.py @@ -12,6 +12,8 @@ # See the License for the specific language governing permissions and # limitations under the License. +from __future__ import print_function + import unittest import paddle.fluid as fluid diff --git a/python/paddle/fluid/tests/unittests/test_recurrent_op.py b/python/paddle/fluid/tests/unittests/test_recurrent_op.py index 2e22df2beb..6dfc85e301 100644 --- a/python/paddle/fluid/tests/unittests/test_recurrent_op.py +++ b/python/paddle/fluid/tests/unittests/test_recurrent_op.py @@ -12,6 +12,8 @@ # See the License for the specific language governing permissions and # limitations under the License. +from __future__ import print_function + import unittest import paddle.fluid.layers as layers diff --git a/python/paddle/fluid/tests/unittests/test_reduce_op.py b/python/paddle/fluid/tests/unittests/test_reduce_op.py index 06d116601b..328f0f0011 100644 --- a/python/paddle/fluid/tests/unittests/test_reduce_op.py +++ b/python/paddle/fluid/tests/unittests/test_reduce_op.py @@ -12,6 +12,8 @@ # See the License for the specific language governing permissions and # limitations under the License. +from __future__ import print_function + import unittest import numpy as np from op_test import OpTest diff --git a/python/paddle/fluid/tests/unittests/test_registry.py b/python/paddle/fluid/tests/unittests/test_registry.py index a361c4624e..7381bb61eb 100644 --- a/python/paddle/fluid/tests/unittests/test_registry.py +++ b/python/paddle/fluid/tests/unittests/test_registry.py @@ -11,6 +11,8 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. 
+ +from __future__ import print_function import unittest import paddle.fluid as fluid diff --git a/python/paddle/fluid/tests/unittests/test_regularizer.py b/python/paddle/fluid/tests/unittests/test_regularizer.py index 9b1c4ceada..6727335c60 100644 --- a/python/paddle/fluid/tests/unittests/test_regularizer.py +++ b/python/paddle/fluid/tests/unittests/test_regularizer.py @@ -12,6 +12,8 @@ # See the License for the specific language governing permissions and # limitations under the License. +from __future__ import print_function + import unittest import paddle.fluid.framework as framework diff --git a/python/paddle/fluid/tests/unittests/test_reorder_lod_tensor.py b/python/paddle/fluid/tests/unittests/test_reorder_lod_tensor.py index e51408944c..28c8c4699a 100644 --- a/python/paddle/fluid/tests/unittests/test_reorder_lod_tensor.py +++ b/python/paddle/fluid/tests/unittests/test_reorder_lod_tensor.py @@ -12,6 +12,8 @@ # See the License for the specific language governing permissions and # limitations under the License. +from __future__ import print_function + import unittest import paddle.fluid as fluid import paddle.fluid.core as core diff --git a/python/paddle/fluid/tests/unittests/test_reshape_op.py b/python/paddle/fluid/tests/unittests/test_reshape_op.py index 2f5558578a..1de35dc35b 100644 --- a/python/paddle/fluid/tests/unittests/test_reshape_op.py +++ b/python/paddle/fluid/tests/unittests/test_reshape_op.py @@ -12,6 +12,8 @@ # See the License for the specific language governing permissions and # limitations under the License. +from __future__ import print_function + import unittest import numpy as np diff --git a/python/paddle/fluid/tests/unittests/test_reverse_op.py b/python/paddle/fluid/tests/unittests/test_reverse_op.py index f845575a02..e83f548c22 100644 --- a/python/paddle/fluid/tests/unittests/test_reverse_op.py +++ b/python/paddle/fluid/tests/unittests/test_reverse_op.py @@ -12,6 +12,8 @@ # See the License for the specific language governing permissions and # limitations under the License. +from __future__ import print_function + import unittest import numpy as np from op_test import OpTest diff --git a/python/paddle/fluid/tests/unittests/test_rmsprop_op.py b/python/paddle/fluid/tests/unittests/test_rmsprop_op.py index 0d84a5853e..3d4623c74d 100644 --- a/python/paddle/fluid/tests/unittests/test_rmsprop_op.py +++ b/python/paddle/fluid/tests/unittests/test_rmsprop_op.py @@ -12,6 +12,8 @@ # See the License for the specific language governing permissions and # limitations under the License. +from __future__ import print_function + import unittest import numpy as np from op_test import OpTest diff --git a/python/paddle/fluid/tests/unittests/test_rnn_memory_helper_op.py b/python/paddle/fluid/tests/unittests/test_rnn_memory_helper_op.py index 178606f059..9bfec8e9bd 100644 --- a/python/paddle/fluid/tests/unittests/test_rnn_memory_helper_op.py +++ b/python/paddle/fluid/tests/unittests/test_rnn_memory_helper_op.py @@ -12,6 +12,8 @@ # See the License for the specific language governing permissions and # limitations under the License. 
+from __future__ import print_function + import unittest from paddle.fluid.framework import Program diff --git a/python/paddle/fluid/tests/unittests/test_roi_pool_op.py b/python/paddle/fluid/tests/unittests/test_roi_pool_op.py index 9b0a3f26b7..ed7f467835 100644 --- a/python/paddle/fluid/tests/unittests/test_roi_pool_op.py +++ b/python/paddle/fluid/tests/unittests/test_roi_pool_op.py @@ -12,6 +12,8 @@ # See the License for the specific language governing permissions and # limitations under the License. +from __future__ import print_function + import unittest import numpy as np import math diff --git a/python/paddle/fluid/tests/unittests/test_row_conv_op.py b/python/paddle/fluid/tests/unittests/test_row_conv_op.py index 07dcd10868..2f13f067ef 100644 --- a/python/paddle/fluid/tests/unittests/test_row_conv_op.py +++ b/python/paddle/fluid/tests/unittests/test_row_conv_op.py @@ -12,6 +12,8 @@ # See the License for the specific language governing permissions and # limitations under the License. +from __future__ import print_function + import unittest import numpy as np from op_test import OpTest diff --git a/python/paddle/fluid/tests/unittests/test_rpn_target_assign_op.py b/python/paddle/fluid/tests/unittests/test_rpn_target_assign_op.py index df6e0faaca..08c462d903 100644 --- a/python/paddle/fluid/tests/unittests/test_rpn_target_assign_op.py +++ b/python/paddle/fluid/tests/unittests/test_rpn_target_assign_op.py @@ -12,6 +12,8 @@ # See the License for the specific language governing permissions and # limitations under the License. +from __future__ import print_function + import unittest import numpy as np import paddle.fluid.core as core diff --git a/python/paddle/fluid/tests/unittests/test_scale_op.py b/python/paddle/fluid/tests/unittests/test_scale_op.py index 53f59c3990..0a8a43253d 100644 --- a/python/paddle/fluid/tests/unittests/test_scale_op.py +++ b/python/paddle/fluid/tests/unittests/test_scale_op.py @@ -12,6 +12,8 @@ # See the License for the specific language governing permissions and # limitations under the License. +from __future__ import print_function + import unittest import numpy as np from op_test import OpTest diff --git a/python/paddle/fluid/tests/unittests/test_scatter_op.py b/python/paddle/fluid/tests/unittests/test_scatter_op.py index fb17287436..088996f9d7 100644 --- a/python/paddle/fluid/tests/unittests/test_scatter_op.py +++ b/python/paddle/fluid/tests/unittests/test_scatter_op.py @@ -12,6 +12,8 @@ # See the License for the specific language governing permissions and # limitations under the License. +from __future__ import print_function + import unittest import numpy as np from op_test import OpTest diff --git a/python/paddle/fluid/tests/unittests/test_scope.py b/python/paddle/fluid/tests/unittests/test_scope.py index d249a989a9..45fcbfba6e 100644 --- a/python/paddle/fluid/tests/unittests/test_scope.py +++ b/python/paddle/fluid/tests/unittests/test_scope.py @@ -12,6 +12,8 @@ # See the License for the specific language governing permissions and # limitations under the License. +from __future__ import print_function + import paddle.fluid.core import unittest diff --git a/python/paddle/fluid/tests/unittests/test_selected_rows.py b/python/paddle/fluid/tests/unittests/test_selected_rows.py index f504a06fff..2f34f79b8e 100644 --- a/python/paddle/fluid/tests/unittests/test_selected_rows.py +++ b/python/paddle/fluid/tests/unittests/test_selected_rows.py @@ -12,6 +12,8 @@ # See the License for the specific language governing permissions and # limitations under the License. 
+from __future__ import print_function + import paddle.fluid.core as core import unittest import numpy as np diff --git a/python/paddle/fluid/tests/unittests/test_seq_concat_op.py b/python/paddle/fluid/tests/unittests/test_seq_concat_op.py index 11ffa761a6..9d1d139721 100644 --- a/python/paddle/fluid/tests/unittests/test_seq_concat_op.py +++ b/python/paddle/fluid/tests/unittests/test_seq_concat_op.py @@ -12,6 +12,8 @@ # See the License for the specific language governing permissions and # limitations under the License. +from __future__ import print_function + import unittest import numpy as np import sys diff --git a/python/paddle/fluid/tests/unittests/test_seq_conv.py b/python/paddle/fluid/tests/unittests/test_seq_conv.py index 1a6e1aad79..dcc86382e5 100644 --- a/python/paddle/fluid/tests/unittests/test_seq_conv.py +++ b/python/paddle/fluid/tests/unittests/test_seq_conv.py @@ -12,6 +12,8 @@ # See the License for the specific language governing permissions and # limitations under the License. +from __future__ import print_function + import unittest import numpy as np import random diff --git a/python/paddle/fluid/tests/unittests/test_seq_pool.py b/python/paddle/fluid/tests/unittests/test_seq_pool.py index 0b3659d7a6..66e77714c5 100644 --- a/python/paddle/fluid/tests/unittests/test_seq_pool.py +++ b/python/paddle/fluid/tests/unittests/test_seq_pool.py @@ -12,6 +12,8 @@ # See the License for the specific language governing permissions and # limitations under the License. +from __future__ import print_function + import unittest import numpy as np from op_test import OpTest diff --git a/python/paddle/fluid/tests/unittests/test_sequence_erase_op.py b/python/paddle/fluid/tests/unittests/test_sequence_erase_op.py index 8f0765277a..92cd5b0cbc 100644 --- a/python/paddle/fluid/tests/unittests/test_sequence_erase_op.py +++ b/python/paddle/fluid/tests/unittests/test_sequence_erase_op.py @@ -12,6 +12,8 @@ # See the License for the specific language governing permissions and # limitations under the License. +from __future__ import print_function + import unittest import numpy as np from op_test import OpTest diff --git a/python/paddle/fluid/tests/unittests/test_sequence_expand.py b/python/paddle/fluid/tests/unittests/test_sequence_expand.py index 5ff0dab23e..ffd4026dba 100644 --- a/python/paddle/fluid/tests/unittests/test_sequence_expand.py +++ b/python/paddle/fluid/tests/unittests/test_sequence_expand.py @@ -12,6 +12,8 @@ # See the License for the specific language governing permissions and # limitations under the License. +from __future__ import print_function + import unittest import numpy as np from op_test import OpTest diff --git a/python/paddle/fluid/tests/unittests/test_sequence_reshape.py b/python/paddle/fluid/tests/unittests/test_sequence_reshape.py index 39b02ecf6d..f11fa6c39c 100644 --- a/python/paddle/fluid/tests/unittests/test_sequence_reshape.py +++ b/python/paddle/fluid/tests/unittests/test_sequence_reshape.py @@ -12,6 +12,8 @@ # See the License for the specific language governing permissions and # limitations under the License. 
+from __future__ import print_function + import unittest import numpy as np import math diff --git a/python/paddle/fluid/tests/unittests/test_sequence_slice_op.py b/python/paddle/fluid/tests/unittests/test_sequence_slice_op.py index 313e485d1e..1561490087 100644 --- a/python/paddle/fluid/tests/unittests/test_sequence_slice_op.py +++ b/python/paddle/fluid/tests/unittests/test_sequence_slice_op.py @@ -12,6 +12,8 @@ # See the License for the specific language governing permissions and # limitations under the License. +from __future__ import print_function + import unittest import numpy as np import sys diff --git a/python/paddle/fluid/tests/unittests/test_sequence_softmax_op.py b/python/paddle/fluid/tests/unittests/test_sequence_softmax_op.py index c4fc8b74cf..3e00e7d95f 100644 --- a/python/paddle/fluid/tests/unittests/test_sequence_softmax_op.py +++ b/python/paddle/fluid/tests/unittests/test_sequence_softmax_op.py @@ -12,6 +12,8 @@ # See the License for the specific language governing permissions and # limitations under the License. +from __future__ import print_function + import unittest import numpy as np from op_test import OpTest diff --git a/python/paddle/fluid/tests/unittests/test_sgd_op.py b/python/paddle/fluid/tests/unittests/test_sgd_op.py index 3126293f9d..c14a83b4bb 100644 --- a/python/paddle/fluid/tests/unittests/test_sgd_op.py +++ b/python/paddle/fluid/tests/unittests/test_sgd_op.py @@ -12,6 +12,8 @@ # See the License for the specific language governing permissions and # limitations under the License. +from __future__ import print_function + import unittest import numpy as np import paddle.fluid.core as core diff --git a/python/paddle/fluid/tests/unittests/test_shape_op.py b/python/paddle/fluid/tests/unittests/test_shape_op.py index a62ee05007..02231ea943 100644 --- a/python/paddle/fluid/tests/unittests/test_shape_op.py +++ b/python/paddle/fluid/tests/unittests/test_shape_op.py @@ -12,6 +12,8 @@ # See the License for the specific language governing permissions and # limitations under the License. +from __future__ import print_function + import unittest import numpy as np from op_test import OpTest diff --git a/python/paddle/fluid/tests/unittests/test_shrink_rnn_memory.py b/python/paddle/fluid/tests/unittests/test_shrink_rnn_memory.py index a994bf181a..97f79f9421 100644 --- a/python/paddle/fluid/tests/unittests/test_shrink_rnn_memory.py +++ b/python/paddle/fluid/tests/unittests/test_shrink_rnn_memory.py @@ -12,6 +12,8 @@ # See the License for the specific language governing permissions and # limitations under the License. +from __future__ import print_function + import unittest import paddle.fluid.core as core from paddle.fluid.executor import Executor diff --git a/python/paddle/fluid/tests/unittests/test_sigmoid_cross_entropy_with_logits_op.py b/python/paddle/fluid/tests/unittests/test_sigmoid_cross_entropy_with_logits_op.py index c435796569..97ff203499 100644 --- a/python/paddle/fluid/tests/unittests/test_sigmoid_cross_entropy_with_logits_op.py +++ b/python/paddle/fluid/tests/unittests/test_sigmoid_cross_entropy_with_logits_op.py @@ -12,6 +12,8 @@ # See the License for the specific language governing permissions and # limitations under the License. 
+from __future__ import print_function + import numpy as np from op_test import OpTest from scipy.special import logit diff --git a/python/paddle/fluid/tests/unittests/test_sign_op.py b/python/paddle/fluid/tests/unittests/test_sign_op.py index 087a0c575b..85a9d9cae4 100644 --- a/python/paddle/fluid/tests/unittests/test_sign_op.py +++ b/python/paddle/fluid/tests/unittests/test_sign_op.py @@ -12,6 +12,8 @@ # See the License for the specific language governing permissions and # limitations under the License. +from __future__ import print_function + import unittest import numpy as np from op_test import OpTest diff --git a/python/paddle/fluid/tests/unittests/test_slice_op.py b/python/paddle/fluid/tests/unittests/test_slice_op.py index 1a48bce3bb..134df38eea 100644 --- a/python/paddle/fluid/tests/unittests/test_slice_op.py +++ b/python/paddle/fluid/tests/unittests/test_slice_op.py @@ -12,6 +12,8 @@ # See the License for the specific language governing permissions and # limitations under the License. +from __future__ import print_function + import unittest import numpy as np from op_test import OpTest diff --git a/python/paddle/fluid/tests/unittests/test_slice_var.py b/python/paddle/fluid/tests/unittests/test_slice_var.py index 82305b23a1..fab63b7d56 100644 --- a/python/paddle/fluid/tests/unittests/test_slice_var.py +++ b/python/paddle/fluid/tests/unittests/test_slice_var.py @@ -12,6 +12,8 @@ # See the License for the specific language governing permissions and # limitations under the License. +from __future__ import print_function + import math import unittest from paddle.fluid.transpiler.distribute_transpiler import slice_variable diff --git a/python/paddle/fluid/tests/unittests/test_smooth_l1_loss_op.py b/python/paddle/fluid/tests/unittests/test_smooth_l1_loss_op.py index e74664dac4..8ab6833821 100644 --- a/python/paddle/fluid/tests/unittests/test_smooth_l1_loss_op.py +++ b/python/paddle/fluid/tests/unittests/test_smooth_l1_loss_op.py @@ -12,6 +12,8 @@ # See the License for the specific language governing permissions and # limitations under the License. +from __future__ import print_function + import unittest import numpy as np from op_test import OpTest diff --git a/python/paddle/fluid/tests/unittests/test_softmax_op.py b/python/paddle/fluid/tests/unittests/test_softmax_op.py index 70ad05597c..d88aa1ae1c 100644 --- a/python/paddle/fluid/tests/unittests/test_softmax_op.py +++ b/python/paddle/fluid/tests/unittests/test_softmax_op.py @@ -12,6 +12,8 @@ # See the License for the specific language governing permissions and # limitations under the License. +from __future__ import print_function + import unittest import numpy as np from op_test import OpTest diff --git a/python/paddle/fluid/tests/unittests/test_softmax_with_cross_entropy_op.py b/python/paddle/fluid/tests/unittests/test_softmax_with_cross_entropy_op.py index c0d9fc8f22..b7e5ff6d52 100644 --- a/python/paddle/fluid/tests/unittests/test_softmax_with_cross_entropy_op.py +++ b/python/paddle/fluid/tests/unittests/test_softmax_with_cross_entropy_op.py @@ -12,6 +12,8 @@ # See the License for the specific language governing permissions and # limitations under the License. 
+from __future__ import print_function + import unittest import numpy as np diff --git a/python/paddle/fluid/tests/unittests/test_split_and_merge_lod_tensor_op.py b/python/paddle/fluid/tests/unittests/test_split_and_merge_lod_tensor_op.py index ea1146166d..5397d5c521 100644 --- a/python/paddle/fluid/tests/unittests/test_split_and_merge_lod_tensor_op.py +++ b/python/paddle/fluid/tests/unittests/test_split_and_merge_lod_tensor_op.py @@ -12,6 +12,8 @@ # See the License for the specific language governing permissions and # limitations under the License. +from __future__ import print_function + import unittest import paddle.fluid.core as core import numpy as np diff --git a/python/paddle/fluid/tests/unittests/test_split_ids_op.py b/python/paddle/fluid/tests/unittests/test_split_ids_op.py index 20bba3ac33..4c3d025898 100644 --- a/python/paddle/fluid/tests/unittests/test_split_ids_op.py +++ b/python/paddle/fluid/tests/unittests/test_split_ids_op.py @@ -12,6 +12,8 @@ # See the License for the specific language governing permissions and # limitations under the License. +from __future__ import print_function + import unittest import numpy as np import six diff --git a/python/paddle/fluid/tests/unittests/test_split_op.py b/python/paddle/fluid/tests/unittests/test_split_op.py index 6b67a52e81..3c5dd782f8 100644 --- a/python/paddle/fluid/tests/unittests/test_split_op.py +++ b/python/paddle/fluid/tests/unittests/test_split_op.py @@ -12,6 +12,8 @@ # See the License for the specific language governing permissions and # limitations under the License. +from __future__ import print_function + import unittest import numpy as np from op_test import OpTest diff --git a/python/paddle/fluid/tests/unittests/test_split_selected_rows_op.py b/python/paddle/fluid/tests/unittests/test_split_selected_rows_op.py index 2b261820e0..41a5ee59ea 100644 --- a/python/paddle/fluid/tests/unittests/test_split_selected_rows_op.py +++ b/python/paddle/fluid/tests/unittests/test_split_selected_rows_op.py @@ -12,6 +12,8 @@ # See the License for the specific language governing permissions and # limitations under the License. +from __future__ import print_function + import unittest import paddle.fluid.core as core import numpy as np diff --git a/python/paddle/fluid/tests/unittests/test_spp_op.py b/python/paddle/fluid/tests/unittests/test_spp_op.py index 3cbfc2a703..a6c2cccd39 100644 --- a/python/paddle/fluid/tests/unittests/test_spp_op.py +++ b/python/paddle/fluid/tests/unittests/test_spp_op.py @@ -12,6 +12,8 @@ # See the License for the specific language governing permissions and # limitations under the License. +from __future__ import print_function + import unittest import numpy as np from op_test import OpTest diff --git a/python/paddle/fluid/tests/unittests/test_squared_l2_distance_op.py b/python/paddle/fluid/tests/unittests/test_squared_l2_distance_op.py index 78bc300ebe..a8bc1004d9 100644 --- a/python/paddle/fluid/tests/unittests/test_squared_l2_distance_op.py +++ b/python/paddle/fluid/tests/unittests/test_squared_l2_distance_op.py @@ -12,6 +12,8 @@ # See the License for the specific language governing permissions and # limitations under the License. 
+from __future__ import print_function + import unittest import numpy as np from op_test import OpTest diff --git a/python/paddle/fluid/tests/unittests/test_squared_l2_norm_op.py b/python/paddle/fluid/tests/unittests/test_squared_l2_norm_op.py index 609445d522..439bae9510 100644 --- a/python/paddle/fluid/tests/unittests/test_squared_l2_norm_op.py +++ b/python/paddle/fluid/tests/unittests/test_squared_l2_norm_op.py @@ -12,6 +12,8 @@ # See the License for the specific language governing permissions and # limitations under the License. +from __future__ import print_function + import numpy as np import unittest from numpy import linalg as LA diff --git a/python/paddle/fluid/tests/unittests/test_squeeze_op.py b/python/paddle/fluid/tests/unittests/test_squeeze_op.py index bca6af2fd5..a2a5584459 100644 --- a/python/paddle/fluid/tests/unittests/test_squeeze_op.py +++ b/python/paddle/fluid/tests/unittests/test_squeeze_op.py @@ -12,6 +12,8 @@ # See the License for the specific language governing permissions and # limitations under the License. +from __future__ import print_function + import unittest import numpy as np diff --git a/python/paddle/fluid/tests/unittests/test_sum_mkldnn_op.py b/python/paddle/fluid/tests/unittests/test_sum_mkldnn_op.py index 7956897d68..55820f31b8 100644 --- a/python/paddle/fluid/tests/unittests/test_sum_mkldnn_op.py +++ b/python/paddle/fluid/tests/unittests/test_sum_mkldnn_op.py @@ -12,6 +12,8 @@ # See the License for the specific language governing permissions and # limitations under the License. +from __future__ import print_function + import unittest from test_sum_op import TestSumOp diff --git a/python/paddle/fluid/tests/unittests/test_sum_op.py b/python/paddle/fluid/tests/unittests/test_sum_op.py index 1d90414e13..9dc93048e6 100644 --- a/python/paddle/fluid/tests/unittests/test_sum_op.py +++ b/python/paddle/fluid/tests/unittests/test_sum_op.py @@ -12,6 +12,8 @@ # See the License for the specific language governing permissions and # limitations under the License. +from __future__ import print_function + import unittest import numpy as np from op_test import OpTest diff --git a/python/paddle/fluid/tests/unittests/test_switch.py b/python/paddle/fluid/tests/unittests/test_switch.py index 528c5cce4b..2a9c07a889 100644 --- a/python/paddle/fluid/tests/unittests/test_switch.py +++ b/python/paddle/fluid/tests/unittests/test_switch.py @@ -12,6 +12,8 @@ # See the License for the specific language governing permissions and # limitations under the License. +from __future__ import print_function + import unittest import paddle.fluid.core as core diff --git a/python/paddle/fluid/tests/unittests/test_target_assign_op.py b/python/paddle/fluid/tests/unittests/test_target_assign_op.py index bd20889752..aec219f806 100644 --- a/python/paddle/fluid/tests/unittests/test_target_assign_op.py +++ b/python/paddle/fluid/tests/unittests/test_target_assign_op.py @@ -12,6 +12,8 @@ # See the License for the specific language governing permissions and # limitations under the License. +from __future__ import print_function + import unittest import numpy as np import random diff --git a/python/paddle/fluid/tests/unittests/test_tensor.py b/python/paddle/fluid/tests/unittests/test_tensor.py index 5ccc876ae8..e9d0f8a019 100644 --- a/python/paddle/fluid/tests/unittests/test_tensor.py +++ b/python/paddle/fluid/tests/unittests/test_tensor.py @@ -12,6 +12,8 @@ # See the License for the specific language governing permissions and # limitations under the License. 
+from __future__ import print_function + import paddle.fluid.core as core import unittest import numpy diff --git a/python/paddle/fluid/tests/unittests/test_top_k_op.py b/python/paddle/fluid/tests/unittests/test_top_k_op.py index cbc3da5503..e54e170f7f 100644 --- a/python/paddle/fluid/tests/unittests/test_top_k_op.py +++ b/python/paddle/fluid/tests/unittests/test_top_k_op.py @@ -12,6 +12,8 @@ # See the License for the specific language governing permissions and # limitations under the License. +from __future__ import print_function + import unittest import numpy as np from op_test import OpTest diff --git a/python/paddle/fluid/tests/unittests/test_transpose_op.py b/python/paddle/fluid/tests/unittests/test_transpose_op.py index ebd63fbd49..0853f80b82 100644 --- a/python/paddle/fluid/tests/unittests/test_transpose_op.py +++ b/python/paddle/fluid/tests/unittests/test_transpose_op.py @@ -12,6 +12,8 @@ # See the License for the specific language governing permissions and # limitations under the License. +from __future__ import print_function + import unittest import numpy as np from op_test import OpTest diff --git a/python/paddle/fluid/tests/unittests/test_uniform_random_batch_size_like_op.py b/python/paddle/fluid/tests/unittests/test_uniform_random_batch_size_like_op.py index e033e86114..7b8be24d9d 100644 --- a/python/paddle/fluid/tests/unittests/test_uniform_random_batch_size_like_op.py +++ b/python/paddle/fluid/tests/unittests/test_uniform_random_batch_size_like_op.py @@ -12,6 +12,8 @@ # See the License for the specific language governing permissions and # limitations under the License. +from __future__ import print_function + import unittest import numpy as np from op_test import OpTest diff --git a/python/paddle/fluid/tests/unittests/test_uniform_random_op.py b/python/paddle/fluid/tests/unittests/test_uniform_random_op.py index 346a949b6e..d6a5d68765 100644 --- a/python/paddle/fluid/tests/unittests/test_uniform_random_op.py +++ b/python/paddle/fluid/tests/unittests/test_uniform_random_op.py @@ -12,6 +12,8 @@ # See the License for the specific language governing permissions and # limitations under the License. +from __future__ import print_function + import unittest import numpy as np from op_test import OpTest diff --git a/python/paddle/fluid/tests/unittests/test_unique_name.py b/python/paddle/fluid/tests/unittests/test_unique_name.py index 49ef335618..b8c751b2e9 100644 --- a/python/paddle/fluid/tests/unittests/test_unique_name.py +++ b/python/paddle/fluid/tests/unittests/test_unique_name.py @@ -12,6 +12,8 @@ # See the License for the specific language governing permissions and # limitations under the License. +from __future__ import print_function + import unittest import paddle.fluid as fluid diff --git a/python/paddle/fluid/tests/unittests/test_unpool_op.py b/python/paddle/fluid/tests/unittests/test_unpool_op.py index 49dc559ed7..b0c7c3c866 100644 --- a/python/paddle/fluid/tests/unittests/test_unpool_op.py +++ b/python/paddle/fluid/tests/unittests/test_unpool_op.py @@ -12,6 +12,8 @@ # See the License for the specific language governing permissions and # limitations under the License. 
+from __future__ import print_function + import unittest import numpy as np from op_test import OpTest diff --git a/python/paddle/fluid/tests/unittests/test_unsqueeze_op.py b/python/paddle/fluid/tests/unittests/test_unsqueeze_op.py index 7a4aa0a40b..5fcabe4c83 100644 --- a/python/paddle/fluid/tests/unittests/test_unsqueeze_op.py +++ b/python/paddle/fluid/tests/unittests/test_unsqueeze_op.py @@ -12,6 +12,8 @@ # See the License for the specific language governing permissions and # limitations under the License. +from __future__ import print_function + import unittest import numpy as np diff --git a/python/paddle/fluid/tests/unittests/test_variable.py b/python/paddle/fluid/tests/unittests/test_variable.py index 49784e21c4..b0830e130d 100644 --- a/python/paddle/fluid/tests/unittests/test_variable.py +++ b/python/paddle/fluid/tests/unittests/test_variable.py @@ -12,6 +12,8 @@ # See the License for the specific language governing permissions and # limitations under the License. +from __future__ import print_function + import unittest from paddle.fluid.framework import default_main_program, Program, convert_np_dtype_to_dtype_ import paddle.fluid.core as core diff --git a/python/paddle/fluid/tests/unittests/test_version.py b/python/paddle/fluid/tests/unittests/test_version.py index a09c8a759b..42a0e5c802 100644 --- a/python/paddle/fluid/tests/unittests/test_version.py +++ b/python/paddle/fluid/tests/unittests/test_version.py @@ -12,6 +12,8 @@ # See the License for the specific language governing permissions and # limitations under the License. +from __future__ import print_function + import unittest import re diff --git a/python/paddle/fluid/tests/unittests/test_warpctc_op.py b/python/paddle/fluid/tests/unittests/test_warpctc_op.py index d647a17692..5e3aa13546 100644 --- a/python/paddle/fluid/tests/unittests/test_warpctc_op.py +++ b/python/paddle/fluid/tests/unittests/test_warpctc_op.py @@ -12,6 +12,8 @@ # See the License for the specific language governing permissions and # limitations under the License. +from __future__ import print_function + import sys import unittest import numpy as np diff --git a/python/paddle/fluid/tests/unittests/test_weight_normalization.py b/python/paddle/fluid/tests/unittests/test_weight_normalization.py index 436f9b9f86..e990d8b249 100644 --- a/python/paddle/fluid/tests/unittests/test_weight_normalization.py +++ b/python/paddle/fluid/tests/unittests/test_weight_normalization.py @@ -12,6 +12,8 @@ # See the License for the specific language governing permissions and # limitations under the License. +from __future__ import print_function + import unittest import numpy import collections diff --git a/python/paddle/fluid/tests/unittests/test_while_op.py b/python/paddle/fluid/tests/unittests/test_while_op.py index 790e6afe5f..b75373cf24 100644 --- a/python/paddle/fluid/tests/unittests/test_while_op.py +++ b/python/paddle/fluid/tests/unittests/test_while_op.py @@ -12,6 +12,8 @@ # See the License for the specific language governing permissions and # limitations under the License. 
+from __future__ import print_function + import unittest import paddle.fluid.layers as layers from paddle.fluid.executor import Executor diff --git a/python/paddle/fluid/tests/unittests/testsuite.py b/python/paddle/fluid/tests/unittests/testsuite.py index c6e176ca31..31ae25f02c 100644 --- a/python/paddle/fluid/tests/unittests/testsuite.py +++ b/python/paddle/fluid/tests/unittests/testsuite.py @@ -12,6 +12,8 @@ # See the License for the specific language governing permissions and # limitations under the License. +from __future__ import print_function + import numpy as np import paddle.fluid.core as core diff --git a/python/paddle/fluid/tests/unittests/transformer_model.py b/python/paddle/fluid/tests/unittests/transformer_model.py index 868a0248be..f0e74aff6b 100644 --- a/python/paddle/fluid/tests/unittests/transformer_model.py +++ b/python/paddle/fluid/tests/unittests/transformer_model.py @@ -12,6 +12,8 @@ # See the License for the specific language governing permissions and # limitations under the License. +from __future__ import print_function + from functools import partial import numpy as np diff --git a/python/paddle/fluid/trainer.py b/python/paddle/fluid/trainer.py index 5d549e68d1..294308f187 100644 --- a/python/paddle/fluid/trainer.py +++ b/python/paddle/fluid/trainer.py @@ -12,6 +12,8 @@ # See the License for the specific language governing permissions and # limitations under the License. +from __future__ import print_function + import contextlib import os import errno diff --git a/python/paddle/fluid/transpiler/__init__.py b/python/paddle/fluid/transpiler/__init__.py index a8622ad544..8429e2fd7c 100644 --- a/python/paddle/fluid/transpiler/__init__.py +++ b/python/paddle/fluid/transpiler/__init__.py @@ -12,6 +12,8 @@ # See the License for the specific language governing permissions and # limitations under the License. +from __future__ import print_function + from .distribute_transpiler import DistributeTranspiler, DistributeTranspilerConfig from .inference_transpiler import InferenceTranspiler from .memory_optimization_transpiler import memory_optimize, release_memory diff --git a/python/paddle/fluid/transpiler/details/__init__.py b/python/paddle/fluid/transpiler/details/__init__.py index 1bfab1f219..5e98266a76 100644 --- a/python/paddle/fluid/transpiler/details/__init__.py +++ b/python/paddle/fluid/transpiler/details/__init__.py @@ -12,5 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. +from __future__ import print_function + from .program_utils import * from .ufind import * diff --git a/python/paddle/fluid/transpiler/details/program_utils.py b/python/paddle/fluid/transpiler/details/program_utils.py index 291c8fb27b..640dbf4bbe 100644 --- a/python/paddle/fluid/transpiler/details/program_utils.py +++ b/python/paddle/fluid/transpiler/details/program_utils.py @@ -12,6 +12,8 @@ # See the License for the specific language governing permissions and # limitations under the License. +from __future__ import print_function + import six diff --git a/python/paddle/fluid/transpiler/details/ufind.py b/python/paddle/fluid/transpiler/details/ufind.py index 0e30d0e3f9..aa63af7dcf 100644 --- a/python/paddle/fluid/transpiler/details/ufind.py +++ b/python/paddle/fluid/transpiler/details/ufind.py @@ -12,6 +12,8 @@ # See the License for the specific language governing permissions and # limitations under the License. +from __future__ import print_function + class UnionFind(object): """ Union-find data structure. 
diff --git a/python/paddle/fluid/transpiler/distribute_transpiler.py b/python/paddle/fluid/transpiler/distribute_transpiler.py index 4c6df361fe..836477a9e6 100644 --- a/python/paddle/fluid/transpiler/distribute_transpiler.py +++ b/python/paddle/fluid/transpiler/distribute_transpiler.py @@ -11,6 +11,8 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. + +from __future__ import print_function """ Steps to transpile trainer: 1. split variable to multiple blocks, aligned by product(dim[1:]) (width). diff --git a/python/paddle/fluid/transpiler/inference_transpiler.py b/python/paddle/fluid/transpiler/inference_transpiler.py index 87f20bbccf..42005839c4 100644 --- a/python/paddle/fluid/transpiler/inference_transpiler.py +++ b/python/paddle/fluid/transpiler/inference_transpiler.py @@ -12,6 +12,8 @@ # See the License for the specific language governing permissions and # limitations under the License. +from __future__ import print_function + import os import numpy as np from .. import core diff --git a/python/paddle/fluid/transpiler/memory_optimization_transpiler.py b/python/paddle/fluid/transpiler/memory_optimization_transpiler.py index 0de994dda3..3e58e125de 100644 --- a/python/paddle/fluid/transpiler/memory_optimization_transpiler.py +++ b/python/paddle/fluid/transpiler/memory_optimization_transpiler.py @@ -12,6 +12,8 @@ # See the License for the specific language governing permissions and # limitations under the License. +from __future__ import print_function + from collections import defaultdict from .. import core from ... import compat as cpt diff --git a/python/paddle/fluid/transpiler/ps_dispatcher.py b/python/paddle/fluid/transpiler/ps_dispatcher.py index dcffadd531..6a6d14a69b 100644 --- a/python/paddle/fluid/transpiler/ps_dispatcher.py +++ b/python/paddle/fluid/transpiler/ps_dispatcher.py @@ -12,6 +12,8 @@ # See the License for the specific language governing permissions and # limitations under the License. +from __future__ import print_function + class PSDispatcher(object): """ diff --git a/python/paddle/fluid/unique_name.py b/python/paddle/fluid/unique_name.py index b125eba4f8..b9957a699e 100644 --- a/python/paddle/fluid/unique_name.py +++ b/python/paddle/fluid/unique_name.py @@ -12,6 +12,8 @@ # See the License for the specific language governing permissions and # limitations under the License. 
+from __future__ import print_function
+
 import collections
 import contextlib
 import six

From bf3c34960f2a59a2616957f8fb4107b2ac7aa02b Mon Sep 17 00:00:00 2001
From: dzhwinter
Date: Thu, 16 Aug 2018 11:00:55 +0800
Subject: [PATCH 83/94] "cherry picked operators changes" (#12184)

* "cherry picked operators changes"

* "remove duplicated code"

* "add constant setter"

* "add get expected kernel"

* "fix ci"

* "add fill constant"
---
 paddle/fluid/operators/activation_op.cu       |  4 +-
 paddle/fluid/operators/activation_op.h        | 12 ++--
 paddle/fluid/operators/assign_value_op.cu.cc  |  5 +-
 paddle/fluid/operators/conv_cudnn_op.cu.cc    | 56 +++++++++++-------
 paddle/fluid/operators/cross_entropy_op.cu    | 12 ++--
 paddle/fluid/operators/elementwise_add_op.cu  |  3 +-
 paddle/fluid/operators/elementwise_div_op.cu  |  9 ++-
 paddle/fluid/operators/elementwise_mul_op.cu  |  8 ++-
 .../fluid/operators/elementwise_op_function.h |  4 +-
 paddle/fluid/operators/elementwise_sub_op.cu  |  8 ++-
 paddle/fluid/operators/fill_constant_op.cc    | 53 ++++++-----------
 paddle/fluid/operators/fill_constant_op.cu.cc | 26 ++++++++
 paddle/fluid/operators/fill_constant_op.h     | 48 +++++++++++++++
 paddle/fluid/operators/fill_op.cc             |  2 +-
 paddle/fluid/operators/gaussian_random_op.cu  |  2 +
 paddle/fluid/operators/math/cross_entropy.cu  | 20 ++++++-
 paddle/fluid/operators/math/cross_entropy.h   | 17 ++++++
 .../operators/math/selected_rows_functor.cu   | 13 +++-
 paddle/fluid/operators/math/softmax.cu        |  3 +
 paddle/fluid/operators/mean_op.cu             | 10 ++--
 paddle/fluid/operators/mean_op.h              |  2 +-
 paddle/fluid/operators/mul_op.cu.cc           |  7 ++-
 paddle/fluid/operators/pool_cudnn_op.cu.cc    |  6 +-
 paddle/fluid/operators/scale_op.cu            |  6 +-
 paddle/fluid/operators/softmax_cudnn_op.cu.cc |  3 +-
 paddle/fluid/operators/softmax_op.cu.cc       |  3 +-
 paddle/fluid/operators/sum_op.cu              |  5 +-
 paddle/fluid/operators/sum_op.h               |  2 +-
 paddle/fluid/operators/top_k_op.cu            | 28 +++++++--
 paddle/fluid/operators/uniform_random_op.cu   | 59 ++++++++++++++++---
 30 files changed, 328 insertions(+), 108 deletions(-)
 create mode 100644 paddle/fluid/operators/fill_constant_op.cu.cc
 create mode 100644 paddle/fluid/operators/fill_constant_op.h

diff --git a/paddle/fluid/operators/activation_op.cu b/paddle/fluid/operators/activation_op.cu
index 27487b396c..d3a7ceed46 100644
--- a/paddle/fluid/operators/activation_op.cu
+++ b/paddle/fluid/operators/activation_op.cu
@@ -26,6 +26,8 @@ namespace plat = paddle::platform;
       act_type##_grad, ops::ActivationGradKernel<plat::CUDADeviceContext,    \
                                                  ops::grad_functor<float>>,  \
       ops::ActivationGradKernel<plat::CUDADeviceContext,                     \
-                                ops::grad_functor<double>>);
+                                ops::grad_functor<double>>,                  \
+      ops::ActivationGradKernel<plat::CUDADeviceContext,                     \
+                                ops::grad_functor<plat::float16>>);
 
 FOR_EACH_KERNEL_FUNCTOR(REGISTER_ACTIVATION_CUDA_KERNEL);
diff --git a/paddle/fluid/operators/activation_op.h b/paddle/fluid/operators/activation_op.h
index 9124151926..48f3b5a5bc 100644
--- a/paddle/fluid/operators/activation_op.h
+++ b/paddle/fluid/operators/activation_op.h
@@ -333,8 +333,7 @@ struct SqrtGradFunctor : public BaseActivationFunctor<T> {
   template <typename Device, typename X, typename Out, typename dOut,
             typename dX>
   void operator()(Device d, X x, Out out, dOut dout, dX dx) const {
-    const Out out_conj = Eigen::numext::conj(out);
-    dx.device(d) = static_cast<T>(0.5) * dout / out_conj;
+    dx.device(d) = static_cast<T>(0.5) * dout / out;
   }
 };
 
@@ -740,7 +739,7 @@ struct PowGradFunctor : public BaseActivationFunctor<T> {
             typename dX>
   void operator()(Device d, X x, Out out, dOut dout, dX dx) const {
     dx.device(d) = dout * static_cast<T>(factor) *
-                   x.pow(static_cast<T>(factor - static_cast<T>(1)));
+                   x.pow(static_cast<T>(factor) - static_cast<T>(1));
   }
 };
 
@@ -863,10 +862,11 @@ struct SwishGradFunctor : public BaseActivationFunctor<T> {
   template <typename Device, typename X, typename Out, typename dOut,
             typename dX>
   void operator()(Device d, X x, Out out, dOut dout, dX dx) const {
+    T b = static_cast<T>(beta);
     auto temp1 = static_cast<T>(1) /
-                 (static_cast<T>(1) + (static_cast<T>(-beta) * x).exp());
-    auto temp2 = temp1 * (static_cast<T>(1) - (beta * out));
-    dx.device(d) = dout * ((beta * out) + temp2);
+                 (static_cast<T>(1) + (static_cast<T>(-b) * x).exp());
+    auto temp2 = temp1 * (static_cast<T>(1) - (b * out));
+    dx.device(d) = dout * ((b * out) + temp2);
   }
 };
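Two of the functor fixes above are numeric rather than cosmetic: PowGradFunctor now casts `factor` to `T` before subtracting 1, and SwishGradFunctor reads `beta` once into `T b` so every intermediate stays in the kernel's element type, which matters once `T` is `plat::float16`. For reference, the gradient the Swish functor computes follows from y = x * sigmoid(beta * x); a short derivation, added here for clarity and not part of the patch:

    \[
    \frac{d}{dx}\bigl(x\,\sigma(\beta x)\bigr)
      = \sigma(\beta x) + \beta x\,\sigma(\beta x)\bigl(1 - \sigma(\beta x)\bigr)
      = \beta y + \sigma(\beta x)\bigl(1 - \beta y\bigr)
    \]

In the code, `temp1` is sigma(beta x), `temp2` is sigma(beta x) * (1 - beta y), and the kernel returns `dout * ((b * out) + temp2)`.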
diff --git a/paddle/fluid/operators/assign_value_op.cu.cc b/paddle/fluid/operators/assign_value_op.cu.cc
index 08bfde5dc9..0ff174b388 100644
--- a/paddle/fluid/operators/assign_value_op.cu.cc
+++ b/paddle/fluid/operators/assign_value_op.cu.cc
@@ -13,7 +13,10 @@ See the License for the specific language governing permissions and
 limitations under the License. */
 
 #include "paddle/fluid/operators/assign_value_op.h"
+#include "paddle/fluid/platform/float16.h"
 
 namespace ops = paddle::operators;
+namespace plat = paddle::platform;
 REGISTER_OP_CUDA_KERNEL(assign_value, ops::AssignValueKernel<int>,
-                        ops::AssignValueKernel<float>);
+                        ops::AssignValueKernel<float>,
+                        ops::AssignValueKernel<plat::float16>);
diff --git a/paddle/fluid/operators/conv_cudnn_op.cu.cc b/paddle/fluid/operators/conv_cudnn_op.cu.cc
index 22cbf680c0..59bfe8f61d 100644
--- a/paddle/fluid/operators/conv_cudnn_op.cu.cc
+++ b/paddle/fluid/operators/conv_cudnn_op.cu.cc
@@ -39,6 +39,27 @@ using ScalingParamType = typename platform::CudnnDataType<T>::ScalingParamType;
 static constexpr size_t kCONV_CUDNN_WORKSPACE_LIMIT_BYTES =
     static_cast<size_t>(1024) * 1024 * 1024;
 
+template <typename T, typename DeviceContext>
+// bool EnableFp16(const T& dummy, const DeviceContext& dev_ctx,
+bool EnableFp16(const DeviceContext& dev_ctx,
+                cudnnConvolutionDescriptor_t cudnn_conv_desc) {
+#if CUDA_VERSION >= 9000 && CUDNN_VERSION_MIN(7, 0, 1)
+  // Tensor core is supported since the volta GPU and
+  // is only enabled when input and filter data are float16
+  if (dev_ctx.GetComputeCapability() >= 70 &&
+      std::type_index(typeid(T)) ==
+          std::type_index(typeid(platform::float16))) {
+    PADDLE_ENFORCE(platform::dynload::cudnnSetConvolutionMathType(
+        cudnn_conv_desc, CUDNN_TENSOR_OP_MATH));
+    return true;
+  } else {
+    PADDLE_ENFORCE(platform::dynload::cudnnSetConvolutionMathType(
+        cudnn_conv_desc, CUDNN_DEFAULT_MATH));
+  }
+#endif
+  return false;
+}
+
 template <typename T>
 class CUDNNConvOpKernel : public framework::OpKernel<T> {
  public:
@@ -128,27 +149,14 @@ class CUDNNConvOpKernel : public framework::OpKernel<T> {
     cudnnConvolutionFwdAlgo_t algo;
     auto& dev_ctx = ctx.template device_context<platform::CUDADeviceContext>();
     auto handle = dev_ctx.cudnn_handle();
-
-    CUDNN_ENFORCE(platform::dynload::cudnnGetConvolutionForwardAlgorithm(
-        handle, cudnn_input_desc, cudnn_filter_desc, cudnn_conv_desc,
-        cudnn_output_desc, CUDNN_CONVOLUTION_FWD_SPECIFY_WORKSPACE_LIMIT,
-        workspace_size_limit, &algo));
-
-#if CUDA_VERSION >= 9000 && CUDNN_VERSION_MIN(7, 0, 1)
-    // Tensor core is supported since the volta GPU and
-    // is only enabled when input and filter data are float16
-    if (dev_ctx.GetComputeCapability() >= 70 &&
-        std::type_index(typeid(T)) ==
-            std::type_index(typeid(platform::float16))) {
-      CUDNN_ENFORCE(platform::dynload::cudnnSetConvolutionMathType(
-          cudnn_conv_desc, CUDNN_TENSOR_OP_MATH));
-      // Currently tensor core is only enabled using this algo
+    if (EnableFp16<T>(dev_ctx, cudnn_conv_desc)) {
       algo = CUDNN_CONVOLUTION_FWD_ALGO_IMPLICIT_PRECOMP_GEMM;
     } else {
-      CUDNN_ENFORCE(platform::dynload::cudnnSetConvolutionMathType(
-          cudnn_conv_desc, CUDNN_DEFAULT_MATH));
+      PADDLE_ENFORCE(platform::dynload::cudnnGetConvolutionForwardAlgorithm(
+          handle, cudnn_input_desc, cudnn_filter_desc, cudnn_conv_desc,
+          cudnn_output_desc, CUDNN_CONVOLUTION_FWD_SPECIFY_WORKSPACE_LIMIT,
+          workspace_size_limit, &algo));
     }
-#endif
 
     // get workspace size able to allocate
     CUDNN_ENFORCE(platform::dynload::cudnnGetConvolutionForwardWorkspaceSize(
@@ -288,6 +296,9 @@ class CUDNNConvGradOpKernel : public framework::OpKernel<T> {
       } else {
         data_algo = CUDNN_CONVOLUTION_BWD_DATA_ALGO_1;
       }
+      if (EnableFp16<T>(dev_ctx, cudnn_conv_desc)) {
+        data_algo = CUDNN_CONVOLUTION_BWD_DATA_ALGO_1;
+      }
 
       CUDNN_ENFORCE(
           platform::dynload::cudnnGetConvolutionBackwardDataWorkspaceSize(
@@ -307,6 +318,9 @@ class CUDNNConvGradOpKernel : public framework::OpKernel<T> {
       } else {
         filter_algo = CUDNN_CONVOLUTION_BWD_FILTER_ALGO_1;
       }
+      if (EnableFp16<T>(dev_ctx, cudnn_conv_desc)) {
+        filter_algo = CUDNN_CONVOLUTION_BWD_FILTER_ALGO_1;
+      }
 
       CUDNN_ENFORCE(
           platform::dynload::cudnnGetConvolutionBackwardFilterWorkspaceSize(
@@ -362,7 +376,8 @@ REGISTER_OP_KERNEL(conv2d, CUDNN, plat::CUDAPlace,
                    paddle::operators::CUDNNConvOpKernel<plat::float16>);
 REGISTER_OP_KERNEL(conv2d_grad, CUDNN, plat::CUDAPlace,
                    paddle::operators::CUDNNConvGradOpKernel<float>,
-                   paddle::operators::CUDNNConvGradOpKernel<double>);
+                   paddle::operators::CUDNNConvGradOpKernel<double>,
+                   paddle::operators::CUDNNConvGradOpKernel<plat::float16>);
 
 REGISTER_OP_KERNEL(conv3d, CUDNN, plat::CUDAPlace,
                    paddle::operators::CUDNNConvOpKernel<float>,
@@ -370,4 +385,5 @@ REGISTER_OP_KERNEL(conv3d, CUDNN, plat::CUDAPlace,
                    paddle::operators::CUDNNConvOpKernel<plat::float16>);
 REGISTER_OP_KERNEL(conv3d_grad, CUDNN, plat::CUDAPlace,
                    paddle::operators::CUDNNConvGradOpKernel<float>,
-                   paddle::operators::CUDNNConvGradOpKernel<double>);
+                   paddle::operators::CUDNNConvGradOpKernel<double>,
+                   paddle::operators::CUDNNConvGradOpKernel<plat::float16>)
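The new EnableFp16 helper centralizes the tensor-core gate: CUDNN_TENSOR_OP_MATH is requested only when the device reports compute capability 70 or higher (Volta) and the element type is float16, which the patch checks via typeid. Below is a minimal self-contained sketch of the same gate written with <type_traits> instead of RTTI; every name in it is illustrative, not Paddle's or cuDNN's API:

    #include <cstdio>
    #include <type_traits>

    // Stand-in for paddle::platform::float16; only its type identity matters.
    struct float16 { unsigned short x; };

    // Compile-time form of the check EnableFp16 performs with typeid():
    // tensor cores need Volta (sm_70 or newer) and half-precision data.
    template <typename T>
    bool WantsTensorOpMath(int compute_capability) {
      return compute_capability >= 70 && std::is_same<T, float16>::value;
    }

    int main() {
      std::printf("float   on sm_70: %d\n", WantsTensorOpMath<float>(70));    // 0
      std::printf("float16 on sm_60: %d\n", WantsTensorOpMath<float16>(60));  // 0
      std::printf("float16 on sm_70: %d\n", WantsTensorOpMath<float16>(70));  // 1
      return 0;
    }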
diff --git a/paddle/fluid/operators/cross_entropy_op.cu b/paddle/fluid/operators/cross_entropy_op.cu
index 30dbd5bd3d..65fd3a5dbc 100644
--- a/paddle/fluid/operators/cross_entropy_op.cu
+++ b/paddle/fluid/operators/cross_entropy_op.cu
@@ -13,12 +13,16 @@ See the License for the specific language governing permissions and
 limitations under the License. */
 
 #include "paddle/fluid/operators/cross_entropy_op.h"
+#include "paddle/fluid/platform/float16.h"
 
 namespace ops = paddle::operators;
+namespace plat = paddle::platform;
 using CUDACtx = paddle::platform::CUDADeviceContext;
 REGISTER_OP_CUDA_KERNEL(cross_entropy,
                         ops::CrossEntropyOpKernel<CUDACtx, float>,
-                        ops::CrossEntropyOpKernel<CUDACtx, double>);
-REGISTER_OP_CUDA_KERNEL(cross_entropy_grad,
-                        ops::CrossEntropyGradientOpKernel<CUDACtx, float>,
-                        ops::CrossEntropyGradientOpKernel<CUDACtx, double>);
+                        ops::CrossEntropyOpKernel<CUDACtx, double>,
+                        ops::CrossEntropyOpKernel<CUDACtx, plat::float16>);
+REGISTER_OP_CUDA_KERNEL(
+    cross_entropy_grad, ops::CrossEntropyGradientOpKernel<CUDACtx, float>,
+    ops::CrossEntropyGradientOpKernel<CUDACtx, double>,
+    ops::CrossEntropyGradientOpKernel<CUDACtx, plat::float16>);
diff --git a/paddle/fluid/operators/elementwise_add_op.cu b/paddle/fluid/operators/elementwise_add_op.cu
index dfff518f17..f9f5c66d34 100644
--- a/paddle/fluid/operators/elementwise_add_op.cu
+++ b/paddle/fluid/operators/elementwise_add_op.cu
@@ -30,4 +30,5 @@ REGISTER_OP_CUDA_KERNEL(
     ops::ElementwiseAddGradKernel<plat::CUDADeviceContext, float>,
     ops::ElementwiseAddGradKernel<plat::CUDADeviceContext, double>,
     ops::ElementwiseAddGradKernel<plat::CUDADeviceContext, int>,
-    ops::ElementwiseAddGradKernel<plat::CUDADeviceContext, int64_t>);
+    ops::ElementwiseAddGradKernel<plat::CUDADeviceContext, int64_t>,
+    ops::ElementwiseAddGradKernel<plat::CUDADeviceContext, plat::float16>);
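Each REGISTER_OP_CUDA_KERNEL call above lists one kernel instantiation per element type, and the framework picks among them at run time from the tensor's dtype; adding float16 support is therefore mostly a matter of appending one more instantiation per op. A toy model of that dtype-keyed dispatch, with hypothetical names only, not Paddle's real registry:

    #include <cstdio>
    #include <functional>
    #include <map>
    #include <string>

    // Toy registry: dtype tag -> type-erased kernel. One registration call
    // contributes one entry per element type, as the macros above do.
    std::map<std::string, std::function<void()>>& Registry() {
      static std::map<std::string, std::function<void()>> r;
      return r;
    }

    template <typename T>
    void AddKernel() {  // stands in for an ElementwiseAddKernel instantiation
      std::printf("add kernel, %zu-byte elements\n", sizeof(T));
    }

    int main() {
      Registry()["fp32"] = AddKernel<float>;
      Registry()["fp64"] = AddKernel<double>;
      Registry()["int64"] = AddKernel<long long>;
      Registry().at("fp64")();  // the runtime dtype selects the instantiation
      return 0;
    }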
diff --git a/paddle/fluid/operators/elementwise_div_op.cu b/paddle/fluid/operators/elementwise_div_op.cu
index 588d1f7420..4cc7ba0f43 100644
--- a/paddle/fluid/operators/elementwise_div_op.cu
+++ b/paddle/fluid/operators/elementwise_div_op.cu
@@ -14,19 +14,24 @@ limitations under the License. */
 #define EIGEN_USE_GPU
 #include "paddle/fluid/operators/elementwise_div_op.h"
+#include "paddle/fluid/platform/float16.h"
 
 namespace ops = paddle::operators;
+namespace plat = paddle::platform;
 
 REGISTER_OP_CUDA_KERNEL(
     elementwise_div,
     ops::ElementwiseDivKernel<plat::CUDADeviceContext, float>,
     ops::ElementwiseDivKernel<plat::CUDADeviceContext, double>,
     ops::ElementwiseDivKernel<plat::CUDADeviceContext, int>,
-    ops::ElementwiseDivKernel<plat::CUDADeviceContext, int64_t>);
+    ops::ElementwiseDivKernel<plat::CUDADeviceContext, int64_t>,
+    ops::ElementwiseDivKernel<plat::CUDADeviceContext,
+                              plat::float16>);
 REGISTER_OP_CUDA_KERNEL(
     elementwise_div_grad,
     ops::ElementwiseDivGradKernel<plat::CUDADeviceContext, float>,
     ops::ElementwiseDivGradKernel<plat::CUDADeviceContext, double>,
     ops::ElementwiseDivGradKernel<plat::CUDADeviceContext, int>,
+    ops::ElementwiseDivGradKernel<plat::CUDADeviceContext, int64_t>,
     ops::ElementwiseDivGradKernel<plat::CUDADeviceContext,
-                                  int64_t>);
+                                  plat::float16>);
diff --git a/paddle/fluid/operators/elementwise_mul_op.cu b/paddle/fluid/operators/elementwise_mul_op.cu
index 2fb1b4bee6..350d43168d 100644
--- a/paddle/fluid/operators/elementwise_mul_op.cu
+++ b/paddle/fluid/operators/elementwise_mul_op.cu
@@ -14,19 +14,25 @@ limitations under the License. */
 #define EIGEN_USE_GPU
 #include "paddle/fluid/operators/elementwise_mul_op.h"
+#include "paddle/fluid/platform/float16.h"
 
 namespace ops = paddle::operators;
+namespace plat = paddle::platform;
 
 REGISTER_OP_CUDA_KERNEL(
     elementwise_mul, ops::ElementwiseMulKernel<plat::CUDADeviceContext, float>,
     ops::ElementwiseMulKernel<plat::CUDADeviceContext, double>,
     ops::ElementwiseMulKernel<plat::CUDADeviceContext, int>,
-    ops::ElementwiseMulKernel<plat::CUDADeviceContext, int64_t>);
+    ops::ElementwiseMulKernel<plat::CUDADeviceContext, int64_t>,
+    ops::ElementwiseMulKernel<plat::CUDADeviceContext, plat::float16>);
 REGISTER_OP_CUDA_KERNEL(
     elementwise_mul_grad,
     ops::ElementwiseMulGradKernel<plat::CUDADeviceContext, float>,
     ops::ElementwiseMulGradKernel<plat::CUDADeviceContext, double>,
     ops::ElementwiseMulGradKernel<plat::CUDADeviceContext, int>,
+    ops::ElementwiseMulGradKernel<plat::CUDADeviceContext, plat::float16>,
     ops::ElementwiseMulGradKernel<plat::CUDADeviceContext, int64_t>);
diff --git a/paddle/fluid/operators/elementwise_op_function.h b/paddle/fluid/operators/elementwise_op_function.h
index bc3e95e904..7223a972d2 100644
--- a/paddle/fluid/operators/elementwise_op_function.h
+++ b/paddle/fluid/operators/elementwise_op_function.h
@@ -350,7 +350,7 @@ static __global__ void ElemwiseGradBroadcast1CUDAKernel(
   int j = blockIdx.x;
   int i = threadIdx.x;
   int tid = threadIdx.x;
-  T val = 0;
+  T val(0);
 
   do {
     int x_offset = i * w + j;
@@ -418,7 +418,7 @@ static __global__ void ElemwiseGradBroadcast2CUDAKernel(
   int tid = threadIdx.x;
   int j = blockIdx.x;
 
-  T val = 0;
+  T val(0);
   int ttid = tid;
 
   while (true) {
diff --git a/paddle/fluid/operators/elementwise_sub_op.cu b/paddle/fluid/operators/elementwise_sub_op.cu
index 8709f686f9..ff3f6f8a2c 100644
--- a/paddle/fluid/operators/elementwise_sub_op.cu
+++ b/paddle/fluid/operators/elementwise_sub_op.cu
@@ -14,19 +14,25 @@ limitations under the License. */
 #define EIGEN_USE_GPU
 #include "paddle/fluid/operators/elementwise_sub_op.h"
+#include "paddle/fluid/platform/float16.h"
 
 namespace ops = paddle::operators;
+namespace plat = paddle::platform;
 
 REGISTER_OP_CUDA_KERNEL(
     elementwise_sub, ops::ElementwiseSubKernel<plat::CUDADeviceContext, float>,
     ops::ElementwiseSubKernel<plat::CUDADeviceContext, double>,
     ops::ElementwiseSubKernel<plat::CUDADeviceContext, int>,
-    ops::ElementwiseSubKernel<plat::CUDADeviceContext, int64_t>);
+    ops::ElementwiseSubKernel<plat::CUDADeviceContext, int64_t>,
+    ops::ElementwiseSubKernel<plat::CUDADeviceContext, plat::float16>);
 REGISTER_OP_CUDA_KERNEL(
     elementwise_sub_grad,
     ops::ElementwiseSubGradKernel<plat::CUDADeviceContext, float>,
     ops::ElementwiseSubGradKernel<plat::CUDADeviceContext, double>,
     ops::ElementwiseSubGradKernel<plat::CUDADeviceContext, int>,
+    ops::ElementwiseSubGradKernel<plat::CUDADeviceContext, plat::float16>,
     ops::ElementwiseSubGradKernel<plat::CUDADeviceContext, int64_t>);
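The `T val = 0;` to `T val(0);` change in elementwise_op_function.h is what lets these broadcast kernels compile for float16: copy-initialization from an integer literal requires an implicit conversion, which half types typically do not provide, while direct-initialization may call an explicit constructor. A standalone illustration, with a hypothetical Half type standing in for platform::float16:

    #include <cstdint>
    #include <cstdio>

    // Hypothetical half type: constructible from int only through an
    // explicit constructor, like many float16 implementations.
    struct Half {
      uint16_t bits;
      explicit Half(int v) : bits(static_cast<uint16_t>(v)) {}
    };

    int main() {
      // Half a = 0;  // error: copy-initialization needs an implicit conversion
      Half b(0);      // ok: direct-initialization can use the explicit ctor
      std::printf("bits = %d\n", b.bits);
      return 0;
    }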
*/ -#include "paddle/fluid/framework/data_type.h" -#include "paddle/fluid/framework/op_registry.h" -#include "paddle/fluid/operators/math/math_function.h" -#include "paddle/fluid/platform/device_context.h" +#include "paddle/fluid/operators/fill_constant_op.h" +#include "paddle/fluid/platform/float16.h" namespace paddle { namespace operators { -class FillConstantInferShape : public framework::InferShapeBase { +class FillConstantOp : public framework::OperatorWithKernel { public: - void operator()(framework::InferShapeContext *ctx) const override { + using framework::OperatorWithKernel::OperatorWithKernel; + + void InferShape(framework::InferShapeContext* ctx) const override { PADDLE_ENFORCE(ctx->HasOutput("Out"), "Output(Out) of FillConstantOp should not be null."); - auto &shape = ctx->Attrs().Get>("shape"); + auto& shape = ctx->Attrs().Get>("shape"); ctx->SetOutputDim("Out", framework::make_ddim(shape)); } -}; - -class FillConstantOp : public framework::OperatorBase { - public: - using framework::OperatorBase::OperatorBase; - - private: - void RunImpl(const framework::Scope &scope, - const platform::Place &dev_place) const override { - auto data_type = - static_cast(Attr("dtype")); - auto value = Attr("value"); - auto force_cpu = Attr("force_cpu"); - auto &out = - *scope.FindVar(Output("Out"))->GetMutable(); - out.Resize(framework::make_ddim(Attr>("shape"))); - if (force_cpu) { - auto cpu = platform::CPUPlace(); - out.mutable_data(cpu, framework::ToTypeIndex(data_type)); - } else { - out.mutable_data(dev_place, framework::ToTypeIndex(data_type)); - } - platform::DeviceContextPool &pool = platform::DeviceContextPool::Instance(); - auto &dev_ctx = *pool.Get(dev_place); - math::set_constant(dev_ctx, &out, value); + framework::OpKernelType GetExpectedKernelType( + const framework::ExecutionContext& ctx) const override { + return framework::OpKernelType( + static_cast(ctx.Attr("dtype")), + ctx.device_context()); } }; @@ -87,6 +67,11 @@ Fill up a variable with specified constant value. } // namespace paddle namespace ops = paddle::operators; -REGISTER_OPERATOR(fill_constant, ops::FillConstantOp, - ops::FillConstantInferShape, ops::FillConstantOpMaker, +REGISTER_OPERATOR(fill_constant, ops::FillConstantOp, ops::FillConstantOpMaker, paddle::framework::EmptyGradOpMaker); +REGISTER_OP_CPU_KERNEL( + fill_constant, + ops::FillConstantOpKernel, + ops::FillConstantOpKernel, + ops::FillConstantOpKernel, + ops::FillConstantOpKernel) diff --git a/paddle/fluid/operators/fill_constant_op.cu.cc b/paddle/fluid/operators/fill_constant_op.cu.cc new file mode 100644 index 0000000000..51ccaefa43 --- /dev/null +++ b/paddle/fluid/operators/fill_constant_op.cu.cc @@ -0,0 +1,26 @@ +// Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+ +#include "paddle/fluid/operators/fill_constant_op.h" +#include "paddle/fluid/platform/float16.h" + +namespace ops = paddle::operators; +REGISTER_OP_CUDA_KERNEL( + fill_constant, + ops::FillConstantOpKernel, + ops::FillConstantOpKernel, + ops::FillConstantOpKernel, + ops::FillConstantOpKernel, + ops::FillConstantOpKernel) diff --git a/paddle/fluid/operators/fill_constant_op.h b/paddle/fluid/operators/fill_constant_op.h new file mode 100644 index 0000000000..b2a2a7b2fa --- /dev/null +++ b/paddle/fluid/operators/fill_constant_op.h @@ -0,0 +1,48 @@ +// Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#pragma once +#include + +#include "paddle/fluid/framework/data_type.h" +#include "paddle/fluid/framework/op_registry.h" +#include "paddle/fluid/operators/math/math_function.h" + +namespace paddle { +namespace operators { + +template +class FillConstantOpKernel : public framework::OpKernel { + public: + void Compute(const framework::ExecutionContext& ctx) const override { + auto data_type = + static_cast(ctx.Attr("dtype")); + auto value = ctx.Attr("value"); + auto force_cpu = ctx.Attr("force_cpu"); + auto* out = ctx.Output("Out"); + out->Resize(framework::make_ddim(ctx.Attr>("shape"))); + if (force_cpu) { + auto cpu = platform::CPUPlace(); + out->mutable_data(cpu, framework::ToTypeIndex(data_type)); + } else { + out->mutable_data(ctx.GetPlace(), framework::ToTypeIndex(data_type)); + } + + math::set_constant(ctx.template device_context(), out, + value); + } +}; + +} // namespace operators +} // namespace paddle diff --git a/paddle/fluid/operators/fill_op.cc b/paddle/fluid/operators/fill_op.cc index 925dc19061..352a17c927 100644 --- a/paddle/fluid/operators/fill_op.cc +++ b/paddle/fluid/operators/fill_op.cc @@ -16,6 +16,7 @@ limitations under the License. */ #include "paddle/fluid/framework/op_registry.h" #include "paddle/fluid/operators/detail/safe_ref.h" #include "paddle/fluid/platform/device_context.h" +#include "paddle/fluid/platform/float16.h" namespace paddle { namespace operators { @@ -69,7 +70,6 @@ class FillOp : public framework::OperatorBase { framework::VisitDataType( dtype, FillOpVisitor(&tensor, Attr>("value"))); - if (!force_cpu && platform::is_gpu_place(place)) { // Copy tensor to out platform::DeviceContextPool &pool = diff --git a/paddle/fluid/operators/gaussian_random_op.cu b/paddle/fluid/operators/gaussian_random_op.cu index 7784856417..b490723795 100644 --- a/paddle/fluid/operators/gaussian_random_op.cu +++ b/paddle/fluid/operators/gaussian_random_op.cu @@ -15,6 +15,7 @@ limitations under the License. 
*/ #include #include "paddle/fluid/framework/op_registry.h" #include "paddle/fluid/framework/operator.h" +#include "paddle/fluid/platform/float16.h" namespace paddle { namespace operators { @@ -60,6 +61,7 @@ class GPUGaussianRandomKernel : public framework::OpKernel { } // namespace operators } // namespace paddle +namespace plat = paddle::platform; REGISTER_OP_CUDA_KERNEL(gaussian_random, paddle::operators::GPUGaussianRandomKernel, paddle::operators::GPUGaussianRandomKernel); diff --git a/paddle/fluid/operators/math/cross_entropy.cu b/paddle/fluid/operators/math/cross_entropy.cu index 0de58d5fdd..58b85abf82 100644 --- a/paddle/fluid/operators/math/cross_entropy.cu +++ b/paddle/fluid/operators/math/cross_entropy.cu @@ -15,11 +15,25 @@ limitations under the License. */ #include "paddle/fluid/operators/math/cross_entropy.h" #include "paddle/fluid/platform/cuda_device_function.h" #include "paddle/fluid/platform/cuda_primitives.h" +#include "paddle/fluid/platform/float16.h" namespace paddle { namespace operators { namespace math { +template +HOSTDEVICE T log(const T& val) { + return std::log(val); +} + +template <> +HOSTDEVICE platform::float16 log(const platform::float16& val) { + // strange bug: hlog does not exist. + return static_cast(0); + // half tmp = static_cast(val); + // return static_cast(hlog(tmp)); +} + namespace { template __global__ void CrossEntropyKernel(T* Y, const T* X, const int64_t* label, @@ -35,12 +49,12 @@ template __global__ void SoftCrossEntropyKernel(T* Y, const T* X, const T* label, const int class_num) { int tid = threadIdx.x; - T val = 0; + T val(0); int idx = blockIdx.x * class_num + tid; int end = blockIdx.x * class_num + class_num; for (; idx < end; idx += blockDim.x) { - val += math::TolerableValue()(std::log(X[idx])) * label[idx]; + val += math::TolerableValue()(log(X[idx])) * label[idx]; } val = paddle::platform::reduceSum(val, tid, blockDim.x); @@ -84,6 +98,8 @@ class CrossEntropyFunctor { template class CrossEntropyFunctor; template class CrossEntropyFunctor; +template class CrossEntropyFunctor; } // namespace math } // namespace operators } // namespace paddle diff --git a/paddle/fluid/operators/math/cross_entropy.h b/paddle/fluid/operators/math/cross_entropy.h index adc5b3fe47..2e4e4781c2 100644 --- a/paddle/fluid/operators/math/cross_entropy.h +++ b/paddle/fluid/operators/math/cross_entropy.h @@ -13,8 +13,10 @@ See the License for the specific language governing permissions and limitations under the License. */ #pragma once +#include #include "paddle/fluid/framework/eigen.h" #include "paddle/fluid/framework/tensor.h" +#include "paddle/fluid/platform/float16.h" #include "paddle/fluid/platform/hostdevice.h" namespace paddle { @@ -33,6 +35,21 @@ struct TolerableValue { } }; +// float16 value clipping behaves differently. +using paddle::platform::float16; +using paddle::platform::isfinite; +template <> +struct TolerableValue { + HOSTDEVICE float16 operator()(const float16& x) const { + if (isfinite(x)) + return x; + else if (x > static_cast(0)) + return std::numeric_limits::max(); + else + return std::numeric_limits::min(); + } +}; + template class CrossEntropyFunctor { public: diff --git a/paddle/fluid/operators/math/selected_rows_functor.cu b/paddle/fluid/operators/math/selected_rows_functor.cu index a92762c7fe..00dbfc11a2 100644 --- a/paddle/fluid/operators/math/selected_rows_functor.cu +++ b/paddle/fluid/operators/math/selected_rows_functor.cu @@ -18,6 +18,7 @@ limitations under the License.
*/ #include "paddle/fluid/operators/math/math_function.h" #include "paddle/fluid/operators/math/selected_rows_functor.h" #include "paddle/fluid/platform/cuda_primitives.h" +#include "paddle/fluid/platform/float16.h" namespace paddle { namespace operators { @@ -76,6 +77,7 @@ struct SelectedRowsAdd { template struct SelectedRowsAdd; template struct SelectedRowsAdd; +template struct SelectedRowsAdd; namespace { template @@ -120,7 +122,7 @@ struct SelectedRowsAddTensor { auto* out_data = output->data(); SetConstant functor; - functor(context, output, 0.0); + functor(context, output, static_cast(0)); const int block_size = 256; dim3 threads(block_size, 1); @@ -138,6 +140,8 @@ struct SelectedRowsAddTensor { template struct SelectedRowsAddTensor; template struct SelectedRowsAddTensor; +template struct SelectedRowsAddTensor; template struct SelectedRowsAddTo { @@ -177,6 +181,8 @@ template struct SelectedRowsAddTo; template struct SelectedRowsAddTo; template struct SelectedRowsAddTo; template struct SelectedRowsAddTo; +template struct SelectedRowsAddTo; namespace { template @@ -229,6 +235,8 @@ template struct SelectedRowsAddToTensor; template struct SelectedRowsAddToTensor; template struct SelectedRowsAddToTensor; template struct SelectedRowsAddToTensor; +template struct SelectedRowsAddToTensor; namespace scatter { @@ -276,7 +284,7 @@ struct MergeAdd { context.GetPlace()); math::SetConstant constant_functor; - constant_functor(context, out.mutable_value(), 0.0); + constant_functor(context, out.mutable_value(), static_cast(0)); auto* out_data = out.mutable_value()->data(); auto* input_data = input.value().data(); @@ -300,6 +308,7 @@ template struct MergeAdd; template struct MergeAdd; template struct MergeAdd; template struct MergeAdd; +template struct MergeAdd; template __global__ void UpdateToTensorKernel(const T* selected_rows, diff --git a/paddle/fluid/operators/math/softmax.cu b/paddle/fluid/operators/math/softmax.cu index 3effe77625..785c4baecb 100644 --- a/paddle/fluid/operators/math/softmax.cu +++ b/paddle/fluid/operators/math/softmax.cu @@ -94,12 +94,15 @@ void SoftmaxGradCUDNNFunctor::operator()( template class SoftmaxCUDNNFunctor; template class SoftmaxCUDNNFunctor; template class SoftmaxCUDNNFunctor; +template class SoftmaxGradCUDNNFunctor; template class SoftmaxGradCUDNNFunctor; template class SoftmaxGradCUDNNFunctor; template class SoftmaxFunctor; template class SoftmaxFunctor; template class SoftmaxFunctor; +template class SoftmaxGradFunctor; template class SoftmaxGradFunctor; template class SoftmaxGradFunctor; diff --git a/paddle/fluid/operators/mean_op.cu b/paddle/fluid/operators/mean_op.cu index 91e0ab28ef..07aa23754f 100644 --- a/paddle/fluid/operators/mean_op.cu +++ b/paddle/fluid/operators/mean_op.cu @@ -12,14 +12,16 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. 
*/ -#define EIGEN_USE_GPU - #include "paddle/fluid/operators/mean_op.h" +#include "paddle/fluid/platform/float16.h" namespace ops = paddle::operators; +namespace plat = paddle::platform; REGISTER_OP_CUDA_KERNEL( mean, ops::MeanKernel, - ops::MeanKernel); + ops::MeanKernel, + ops::MeanKernel); REGISTER_OP_CUDA_KERNEL( mean_grad, ops::MeanGradKernel, - ops::MeanGradKernel); + ops::MeanGradKernel, + ops::MeanGradKernel); diff --git a/paddle/fluid/operators/mean_op.h b/paddle/fluid/operators/mean_op.h index 362e9f9ae8..a41d50ae0b 100644 --- a/paddle/fluid/operators/mean_op.h +++ b/paddle/fluid/operators/mean_op.h @@ -55,7 +55,7 @@ class MeanGradKernel : public framework::OpKernel { IG->mutable_data(context.GetPlace()); T ig_size = static_cast(IG->numel()); - Eigen::DSizes bcast(ig_size); + Eigen::DSizes bcast(static_cast(ig_size)); EigenVector::Flatten(*IG).device( *context.template device_context().eigen_device()) = diff --git a/paddle/fluid/operators/mul_op.cu.cc b/paddle/fluid/operators/mul_op.cu.cc index 81f3e42bf4..6c5a83c6a5 100644 --- a/paddle/fluid/operators/mul_op.cu.cc +++ b/paddle/fluid/operators/mul_op.cu.cc @@ -20,6 +20,7 @@ namespace plat = paddle::platform; REGISTER_OP_CUDA_KERNEL(mul, ops::MulKernel, ops::MulKernel, ops::MulKernel); -REGISTER_OP_CUDA_KERNEL(mul_grad, - ops::MulGradKernel, - ops::MulGradKernel); +REGISTER_OP_CUDA_KERNEL( + mul_grad, ops::MulGradKernel, + ops::MulGradKernel, + ops::MulGradKernel); diff --git a/paddle/fluid/operators/pool_cudnn_op.cu.cc b/paddle/fluid/operators/pool_cudnn_op.cu.cc index 31f083565f..9fdbee818a 100644 --- a/paddle/fluid/operators/pool_cudnn_op.cu.cc +++ b/paddle/fluid/operators/pool_cudnn_op.cu.cc @@ -174,7 +174,8 @@ REGISTER_OP_KERNEL(pool2d, CUDNN, plat::CUDAPlace, ops::PoolCUDNNOpKernel); REGISTER_OP_KERNEL(pool2d_grad, CUDNN, plat::CUDAPlace, ops::PoolCUDNNGradOpKernel, - ops::PoolCUDNNGradOpKernel); + ops::PoolCUDNNGradOpKernel, + ops::PoolCUDNNGradOpKernel); REGISTER_OP_KERNEL(pool3d, CUDNN, plat::CUDAPlace, ops::PoolCUDNNOpKernel, @@ -182,4 +183,5 @@ REGISTER_OP_KERNEL(pool3d, CUDNN, plat::CUDAPlace, ops::PoolCUDNNOpKernel); REGISTER_OP_KERNEL(pool3d_grad, CUDNN, plat::CUDAPlace, ops::PoolCUDNNGradOpKernel, - ops::PoolCUDNNGradOpKernel); + ops::PoolCUDNNGradOpKernel, + ops::PoolCUDNNGradOpKernel); diff --git a/paddle/fluid/operators/scale_op.cu b/paddle/fluid/operators/scale_op.cu index 04c802da12..d266867046 100644 --- a/paddle/fluid/operators/scale_op.cu +++ b/paddle/fluid/operators/scale_op.cu @@ -13,11 +13,15 @@ See the License for the specific language governing permissions and limitations under the License. 
*/ #include "paddle/fluid/operators/scale_op.h" +#include "paddle/fluid/platform/float16.h" +namespace plat = paddle::platform; REGISTER_OP_CUDA_KERNEL( scale, paddle::operators::ScaleKernel, paddle::operators::ScaleKernel, paddle::operators::ScaleKernel, paddle::operators::ScaleKernel); + int64_t>, + paddle::operators::ScaleKernel); diff --git a/paddle/fluid/operators/softmax_cudnn_op.cu.cc b/paddle/fluid/operators/softmax_cudnn_op.cu.cc index 2bdb23e999..c2d45c3d2e 100644 --- a/paddle/fluid/operators/softmax_cudnn_op.cu.cc +++ b/paddle/fluid/operators/softmax_cudnn_op.cu.cc @@ -78,4 +78,5 @@ REGISTER_OP_KERNEL(softmax, CUDNN, plat::CUDAPlace, ops::SoftmaxCUDNNKernel, ops::SoftmaxCUDNNKernel); REGISTER_OP_KERNEL(softmax_grad, CUDNN, plat::CUDAPlace, - ops::SoftmaxGradCUDNNKernel); + ops::SoftmaxGradCUDNNKernel, + ops::SoftmaxGradCUDNNKernel); diff --git a/paddle/fluid/operators/softmax_op.cu.cc b/paddle/fluid/operators/softmax_op.cu.cc index 5fb4f011d9..19359b7eef 100644 --- a/paddle/fluid/operators/softmax_op.cu.cc +++ b/paddle/fluid/operators/softmax_op.cu.cc @@ -23,4 +23,5 @@ REGISTER_OP_CUDA_KERNEL( ops::SoftmaxKernel); REGISTER_OP_CUDA_KERNEL( softmax_grad, ops::SoftmaxGradKernel, - ops::SoftmaxGradKernel); + ops::SoftmaxGradKernel, + ops::SoftmaxGradKernel); diff --git a/paddle/fluid/operators/sum_op.cu b/paddle/fluid/operators/sum_op.cu index 89bcd1bbc8..db4c2d6c11 100644 --- a/paddle/fluid/operators/sum_op.cu +++ b/paddle/fluid/operators/sum_op.cu @@ -11,10 +11,13 @@ limitations under the License. */ #define EIGEN_USE_GPU #include "paddle/fluid/operators/sum_op.h" +#include "paddle/fluid/platform/float16.h" namespace ops = paddle::operators; +namespace plat = paddle::platform; REGISTER_OP_CUDA_KERNEL( sum, ops::SumKernel, ops::SumKernel, ops::SumKernel, - ops::SumKernel); + ops::SumKernel, + ops::SumKernel); diff --git a/paddle/fluid/operators/sum_op.h b/paddle/fluid/operators/sum_op.h index 49a4afb3a8..dda6772796 100644 --- a/paddle/fluid/operators/sum_op.h +++ b/paddle/fluid/operators/sum_op.h @@ -46,7 +46,7 @@ class SumKernel : public framework::OpKernel { if (!in_place) { math::SetConstant constant_functor; constant_functor(context.template device_context(), out, - 0.0); + static_cast(0)); } math::SelectedRowsAddToTensor functor; diff --git a/paddle/fluid/operators/top_k_op.cu b/paddle/fluid/operators/top_k_op.cu index 9da8551eb2..5fc0784f66 100644 --- a/paddle/fluid/operators/top_k_op.cu +++ b/paddle/fluid/operators/top_k_op.cu @@ -11,16 +11,19 @@ distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. 
*/ +#include #include "paddle/fluid/framework/op_registry.h" #include "paddle/fluid/operators/top_k_op.h" #include "paddle/fluid/platform/assert.h" #include "paddle/fluid/platform/cuda_device_function.h" +#include "paddle/fluid/platform/float16.h" namespace paddle { namespace operators { using Tensor = framework::Tensor; +using paddle::platform::float16; template struct Pair { @@ -32,6 +35,11 @@ struct Pair { id = id; } + __device__ __forceinline__ void clear() { + v = -INFINITY; + id = -1; + } + __device__ __forceinline__ void operator=(const Pair& in) { v = in.v; id = in.id; @@ -53,6 +61,12 @@ struct Pair { int64_t id; }; +template <> +__device__ __forceinline__ void Pair::clear() { + v = platform::raw_uint16_to_float16(0x400); + id = -1; +} + template __device__ __forceinline__ void AddTo(Pair topk[], const Pair& p, int beam_size) { @@ -150,7 +164,7 @@ __device__ __forceinline__ void ThreadGetTopK(Pair topk[], int* beam, if (k < MaxLength - (*beam)) { topk[k] = topk[k + *beam]; } else { - topk[k].set(-INFINITY, -1); + topk[k].clear(); } } if (!(*is_empty)) { @@ -160,7 +174,7 @@ __device__ __forceinline__ void ThreadGetTopK(Pair topk[], int* beam, } *max = topk[MaxLength - 1]; - if ((*max).v == -1) *is_empty = true; + if ((*max).v == static_cast(-1)) *is_empty = true; *beam = 0; } } @@ -181,7 +195,7 @@ __device__ __forceinline__ void ThreadGetTopK(Pair topk[], int* beam, if (k < MaxLength - *beam) { topk[k] = topk[k + *beam]; } else { - topk[k].set(-INFINITY, -1); + topk[k].set(std::numeric_limits::min(), -1); } } if (!(*is_empty)) { @@ -273,7 +287,7 @@ __global__ void KeMatrixTopK(T* output, int output_stride, int64_t* indices, bool firststep = true; for (int k = 0; k < MaxLength; k++) { - topk[k].set(-INFINITY, -1); + topk[k].clear(); } while (k) { ThreadGetTopK(topk, &beam, k, @@ -325,5 +339,7 @@ class TopkOpCUDAKernel : public framework::OpKernel { } // namespace operators } // namespace paddle -REGISTER_OP_CUDA_KERNEL(top_k, paddle::operators::TopkOpCUDAKernel, - paddle::operators::TopkOpCUDAKernel); +REGISTER_OP_CUDA_KERNEL( + top_k, paddle::operators::TopkOpCUDAKernel, + paddle::operators::TopkOpCUDAKernel, + paddle::operators::TopkOpCUDAKernel); diff --git a/paddle/fluid/operators/uniform_random_op.cu b/paddle/fluid/operators/uniform_random_op.cu index e1c7323a30..2b8039a0c1 100644 --- a/paddle/fluid/operators/uniform_random_op.cu +++ b/paddle/fluid/operators/uniform_random_op.cu @@ -11,10 +11,14 @@ distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. */ +#include #include #include +#include "paddle/fluid/framework/eigen.h" #include "paddle/fluid/framework/op_registry.h" #include "paddle/fluid/framework/operator.h" +#include "paddle/fluid/platform/float16.h" +#include "paddle/fluid/platform/transform.h" namespace paddle { namespace operators { @@ -36,6 +40,11 @@ struct UniformGenerator { } }; +template +struct CastFunctor { + HOSTDEVICE V operator()(const T& a) { return static_cast(a); } +}; + // It seems that Eigen::Tensor::random in GPU will SEGFAULT. // Use std::random and thrust::random(thrust is a std library in CUDA) to // implement uniform random. 
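One detail of the float16 path in the uniform_random kernel below deserves a note: the random values are generated into a float32 master copy first and only cast to float16 afterwards, since the thrust generator works in float. A rough numpy-only model of that two-step scheme (function and variable names are ours, not Paddle's):

import numpy as np

def uniform_random_fp16(shape, low, high, seed):
    # generate into a float32 "master copy" first, as the GPU kernel does
    master = np.random.RandomState(seed).uniform(low, high, size=shape).astype('float32')
    # then cast element-wise, mirroring the CastFunctor transform below
    return master.astype('float16')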
@@ -66,18 +75,50 @@ class GPUUniformRandomKernel : public framework::OpKernel { T max = static_cast(context.Attr("max")); thrust::counting_iterator index_sequence_begin(0); int64_t size = tensor->numel(); - thrust::transform(index_sequence_begin, index_sequence_begin + size, - thrust::device_ptr(data), - UniformGenerator(min, max, seed)); + if (out_var->IsType() && + std::type_index(typeid(T)) == + std::type_index(typeid(platform::float16))) { + framework::Tensor master_copy_tensor; + master_copy_tensor.Resize(tensor->dims()); + float* master_copy_tensor_data = + master_copy_tensor.mutable_data(context.GetPlace()); + thrust::transform(index_sequence_begin, index_sequence_begin + size, + thrust::device_ptr(master_copy_tensor_data), + UniformGenerator(static_cast(min), + static_cast(max), seed)); + platform::Transform trans; + auto* in_begin = master_copy_tensor.data(); + auto* in_end = in_begin + master_copy_tensor.numel(); + auto* out_begin = tensor->mutable_data(context.GetPlace()); + trans(context.template device_context(), + in_begin, in_end, out_begin, CastFunctor()); + } else { + thrust::transform(index_sequence_begin, index_sequence_begin + size, + thrust::device_ptr(data), + UniformGenerator(min, max, seed)); + } + if (VLOG_IS_ON(5)) { + framework::Tensor cpu_tensor; + framework::TensorCopySync(*tensor, platform::CPUPlace(), &cpu_tensor); + auto& dev_ctx = + *platform::DeviceContextPool::Instance().Get(context.GetPlace()); + dev_ctx.Wait(); + auto x = framework::EigenVector::Flatten(cpu_tensor); + VLOG(5) << "The Uniform output " << x; + } } }; } // namespace operators } // namespace paddle -REGISTER_OP_CUDA_KERNEL(uniform_random, - paddle::operators::GPUUniformRandomKernel, - paddle::operators::GPUUniformRandomKernel); -REGISTER_OP_CUDA_KERNEL(uniform_random_batch_size_like, - paddle::operators::GPUUniformRandomKernel, - paddle::operators::GPUUniformRandomKernel); +namespace plat = paddle::platform; +REGISTER_OP_CUDA_KERNEL( + uniform_random, paddle::operators::GPUUniformRandomKernel, + paddle::operators::GPUUniformRandomKernel, + paddle::operators::GPUUniformRandomKernel); +REGISTER_OP_CUDA_KERNEL( + uniform_random_batch_size_like, + paddle::operators::GPUUniformRandomKernel, + paddle::operators::GPUUniformRandomKernel, + paddle::operators::GPUUniformRandomKernel); From 822496f6268c70fc53e50a1a15fb7a3f0a3fe49e Mon Sep 17 00:00:00 2001 From: tangwei12 Date: Thu, 16 Aug 2018 11:26:49 +0800 Subject: [PATCH 84/94] merge cpu and gpu --- paddle/fluid/operators/sampling_id_op.cc | 56 +--------------- paddle/fluid/operators/sampling_id_op.cu | 82 ++---------------------- paddle/fluid/operators/sampling_id_op.h | 80 +++++++++++++++++++++++ 3 files changed, 85 insertions(+), 133 deletions(-) create mode 100644 paddle/fluid/operators/sampling_id_op.h diff --git a/paddle/fluid/operators/sampling_id_op.cc b/paddle/fluid/operators/sampling_id_op.cc index ca7b246901..724463c95c 100644 --- a/paddle/fluid/operators/sampling_id_op.cc +++ b/paddle/fluid/operators/sampling_id_op.cc @@ -12,67 +12,13 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. 
*/ -#include -#include -#include -#include -#include -#include -#include "paddle/fluid/framework/op_registry.h" +#include "paddle/fluid/operators/sampling_id_op.h" namespace paddle { namespace operators { using Tensor = framework::Tensor; -template -class SamplingIdKernel : public framework::OpKernel { - public: - void Compute(const framework::ExecutionContext& context) const override { - const Tensor* input = context.Input("X"); - const int batch_size = static_cast(input->dims()[0]); - const int width = static_cast(input->dims()[1]); - - PADDLE_ENFORCE_GE(batch_size, 0, - "batch_size(dims[0]) must be nonnegative."); - PADDLE_ENFORCE_GE(width, 0, "width(dims[1]) must be nonnegative."); - - std::vector ins_vector; - framework::TensorToVector(*input, context.device_context(), &ins_vector); - - unsigned int seed = static_cast(context.Attr("seed")); - std::minstd_rand engine; - if (seed == 0) { - seed = std::random_device()(); - } - engine.seed(seed); - std::uniform_real_distribution dist( - static_cast(context.Attr("min")), - static_cast(context.Attr("max"))); - - std::vector ids(batch_size); - for (size_t i = 0; i < batch_size; ++i) { - T r = dist(engine); - int idx = width - 1; - for (int j = 0; j < width; ++j) { - if ((r -= ins_vector[i * width + j]) < 0) { - idx = j; - break; - } - } - ids[i] = ins_vector[i * width + idx]; - } - - std::vector out_dim; - out_dim.push_back(static_cast(batch_size)); - - Tensor* output = context.Output("Out"); - output->Resize(framework::make_ddim(out_dim)); - output->mutable_data(context.GetPlace()); - framework::TensorFromVector(ids, context.device_context(), output); - } -}; - class SamplingIdOp : public framework::OperatorWithKernel { public: using framework::OperatorWithKernel::OperatorWithKernel; diff --git a/paddle/fluid/operators/sampling_id_op.cu b/paddle/fluid/operators/sampling_id_op.cu index 114df044af..a4f0470314 100644 --- a/paddle/fluid/operators/sampling_id_op.cu +++ b/paddle/fluid/operators/sampling_id_op.cu @@ -11,83 +11,9 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. 
*/ -#include -#include -#include "paddle/fluid/framework/op_registry.h" -#include "paddle/fluid/framework/operator.h" -template -struct UniformGenerator { - T min_, max_; - unsigned int seed_; +#include "paddle/fluid/operators/sampling_id_op.h" - __host__ __device__ UniformGenerator(T min, T max, int seed) - : min_(min), max_(max), seed_(seed) {} - - __host__ __device__ T operator()(const unsigned int n) const { - thrust::minstd_rand rng; - rng.seed(seed_); - thrust::uniform_real_distribution dist(min_, max_); - rng.discard(n); - return dist(rng); - } -}; - -namespace paddle { -namespace operators { - -using Tensor = framework::Tensor; - -template -class SamplingIdGPUKernel : public framework::OpKernel { - public: - void Compute(const framework::ExecutionContext& context) const override { - const Tensor* input = context.Input("X"); - const int batch_size = static_cast(input->dims()[0]); - const int width = static_cast(input->dims()[1]); - - PADDLE_ENFORCE_GE(batch_size, 0, - "batch_size(dims[0]) must be nonnegative."); - PADDLE_ENFORCE_GE(width, 0, "width(dims[1]) must be nonnegative."); - - std::vector ins_vector; - framework::TensorToVector(*input, context.device_context(), &ins_vector); - - unsigned int seed = static_cast(context.Attr("seed")); - if (seed == 0) { - std::random_device rd; - seed = rd(); - } - T min = static_cast(context.Attr("min")); - T max = static_cast(context.Attr("max")); - UniformGenerator gen = UniformGenerator(min, max, seed); - - std::vector ids(batch_size); - for (size_t i = 0; i < batch_size; ++i) { - T r = gen(0); - int idx = width - 1; - for (int j = 0; j < width; ++j) { - if ((r -= ins_vector[i * width + j]) < 0) { - idx = j; - break; - } - } - ids[i] = ins_vector[i * width + idx]; - } - - std::vector out_dim; - out_dim.push_back(static_cast(batch_size)); - - Tensor* output = context.Output("Out"); - output->Resize(framework::make_ddim(out_dim)); - output->mutable_data(context.GetPlace()); - framework::TensorFromVector(ids, context.device_context(), output); - } -}; - -} // namespace operators -} // namespace paddle - -REGISTER_OP_CUDA_KERNEL(sampling_id, - paddle::operators::SamplingIdGPUKernel, - paddle::operators::SamplingIdGPUKernel); +namespace ops = paddle::operators; +REGISTER_OP_CUDA_KERNEL(sampling_id, paddle::operators::SamplingIdKernel, + paddle::operators::SamplingIdKernel); diff --git a/paddle/fluid/operators/sampling_id_op.h b/paddle/fluid/operators/sampling_id_op.h new file mode 100644 index 0000000000..f730a9746d --- /dev/null +++ b/paddle/fluid/operators/sampling_id_op.h @@ -0,0 +1,80 @@ +/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved. + + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. 
*/ + +#pragma once + +#include +#include +#include +#include +#include +#include + +#include "paddle/fluid/framework/op_registry.h" + +namespace paddle { +namespace operators { + +using Tensor = framework::Tensor; + +template +class SamplingIdKernel : public framework::OpKernel { + public: + void Compute(const framework::ExecutionContext& context) const override { + const Tensor* input = context.Input("X"); + const int batch_size = static_cast(input->dims()[0]); + const int width = static_cast(input->dims()[1]); + + PADDLE_ENFORCE_GE(batch_size, 0, + "batch_size(dims[0]) must be nonnegative."); + PADDLE_ENFORCE_GE(width, 0, "width(dims[1]) must be nonnegative."); + + std::vector ins_vector; + framework::TensorToVector(*input, context.device_context(), &ins_vector); + + unsigned int seed = static_cast(context.Attr("seed")); + std::minstd_rand engine; + if (seed == 0) { + seed = std::random_device()(); + } + engine.seed(seed); + std::uniform_real_distribution dist( + static_cast(context.Attr("min")), + static_cast(context.Attr("max"))); + + std::vector ids(batch_size); + for (size_t i = 0; i < batch_size; ++i) { + T r = dist(engine); + int idx = width - 1; + for (int j = 0; j < width; ++j) { + if ((r -= ins_vector[i * width + j]) < 0) { + idx = j; + break; + } + } + ids[i] = ins_vector[i * width + idx]; + } + + std::vector out_dim; + out_dim.push_back(static_cast(batch_size)); + + Tensor* output = context.Output("Out"); + output->Resize(framework::make_ddim(out_dim)); + output->mutable_data(context.GetPlace()); + framework::TensorFromVector(ids, context.device_context(), output); + } +}; + +} // namespace operators +} // namespace paddle From 9f3789944c2c98605f26ffd224fbe1df02fa2e68 Mon Sep 17 00:00:00 2001 From: luotao1 Date: Thu, 16 Aug 2018 11:34:21 +0800 Subject: [PATCH 85/94] use latest anakin commit --- CMakeLists.txt | 3 ++- cmake/external/anakin.cmake | 5 ++--- paddle/fluid/inference/api/CMakeLists.txt | 2 +- 3 files changed, 5 insertions(+), 5 deletions(-) diff --git a/CMakeLists.txt b/CMakeLists.txt index 920c20d6f8..6844772711 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -204,11 +204,12 @@ include(external/snappy) # download snappy include(external/snappystream) include(external/threadpool) -set(WITH_ANAKIN OFF CACHE STRING "Disable Anakin first, will add it later." FORCE) if(WITH_GPU) include(cuda) include(tensorrt) include(external/anakin) +elseif() + set(WITH_ANAKIN OFF CACHE STRING "Anakin is used in GPU only now." 
FORCE) endif() include(cudnn) # set cudnn libraries, must before configure diff --git a/cmake/external/anakin.cmake b/cmake/external/anakin.cmake index 5de7ca8f46..455ef91ac5 100644 --- a/cmake/external/anakin.cmake +++ b/cmake/external/anakin.cmake @@ -35,9 +35,8 @@ set(ANAKIN_COMPILE_EXTRA_FLAGS ExternalProject_Add( extern_anakin ${EXTERNAL_PROJECT_LOG_ARGS} - # TODO(luotao): use PaddlePaddle/Anakin later - GIT_REPOSITORY "https://github.com/luotao1/Anakin" - GIT_TAG "842a89ae3747ede25d8acbc29030d2eb602ced1f" + GIT_REPOSITORY "https://github.com/PaddlePaddle/Anakin" + GIT_TAG "04256ba78fa3da0beb74e8036c8efd68c12824d6" PREFIX ${ANAKIN_SOURCE_DIR} UPDATE_COMMAND "" CMAKE_ARGS -DUSE_GPU_PLACE=YES diff --git a/paddle/fluid/inference/api/CMakeLists.txt b/paddle/fluid/inference/api/CMakeLists.txt index 83867e0a2c..a72e27d651 100644 --- a/paddle/fluid/inference/api/CMakeLists.txt +++ b/paddle/fluid/inference/api/CMakeLists.txt @@ -60,7 +60,7 @@ cc_library(paddle_inference_tensorrt_subgraph_engine inference_api_test(test_api_tensorrt_subgraph_engine SRC api_tensorrt_subgraph_engine_tester.cc ARGS test_word2vec) endif() -if (WITH_ANAKIN) # only needed in CI +if (WITH_ANAKIN AND WITH_GPU) # only needed in CI # compile the libinference_anakin_api.a and anakin.so. nv_library(inference_anakin_api SRCS api.cc api_anakin_engine.cc DEPS anakin_shared anakin_saber) #nv_library(inference_anakin_api_shared SHARED SRCS api.cc api_anakin_engine.cc DEPS anakin) From c44fb003715aab90d14f0d0fce020d0b65ec6fbf Mon Sep 17 00:00:00 2001 From: qingqing01 Date: Thu, 16 Aug 2018 12:01:22 +0800 Subject: [PATCH 86/94] Add name in relu and log API. (#12438) --- paddle/fluid/API.spec | 4 ++-- python/paddle/fluid/layers/nn.py | 8 ++++++-- 2 files changed, 8 insertions(+), 4 deletions(-) diff --git a/paddle/fluid/API.spec b/paddle/fluid/API.spec index ea9105d79c..e963902a50 100644 --- a/paddle/fluid/API.spec +++ b/paddle/fluid/API.spec @@ -155,8 +155,8 @@ paddle.fluid.layers.resize_bilinear ArgSpec(args=['input', 'out_shape', 'scale', paddle.fluid.layers.gather ArgSpec(args=['input', 'index'], varargs=None, keywords=None, defaults=None) paddle.fluid.layers.random_crop ArgSpec(args=['x', 'shape', 'seed'], varargs=None, keywords=None, defaults=(None,)) paddle.fluid.layers.mean_iou ArgSpec(args=['input', 'label', 'num_classes'], varargs=None, keywords=None, defaults=None) -paddle.fluid.layers.relu ArgSpec(args=['x'], varargs=None, keywords=None, defaults=None) -paddle.fluid.layers.log ArgSpec(args=['x'], varargs=None, keywords=None, defaults=None) +paddle.fluid.layers.relu ArgSpec(args=['x', 'name'], varargs=None, keywords=None, defaults=(None,)) +paddle.fluid.layers.log ArgSpec(args=['x', 'name'], varargs=None, keywords=None, defaults=(None,)) paddle.fluid.layers.crop ArgSpec(args=['x', 'shape', 'offsets', 'name'], varargs=None, keywords=None, defaults=(None, None, None)) paddle.fluid.layers.rank_loss ArgSpec(args=['label', 'left', 'right', 'name'], varargs=None, keywords=None, defaults=(None,)) paddle.fluid.layers.prelu ArgSpec(args=['x', 'mode', 'param_attr', 'name'], varargs=None, keywords=None, defaults=(None, None)) diff --git a/python/paddle/fluid/layers/nn.py b/python/paddle/fluid/layers/nn.py index 3e50fc91d9..be852b6711 100644 --- a/python/paddle/fluid/layers/nn.py +++ b/python/paddle/fluid/layers/nn.py @@ -5090,7 +5090,7 @@ def random_crop(x, shape, seed=None): return out -def log(x): +def log(x, name=None): """ Calculates the natural log of the given input tensor, element-wise. 
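A small, hypothetical usage sketch for the new name arguments introduced in this patch (x stands for any fluid Variable; the output names are placeholders):

import paddle.fluid as fluid

x = fluid.layers.data(name='x', shape=[32], dtype='float32')
# both layers now accept an optional name for the created output variable
y = fluid.layers.log(x, name='log_out')
z = fluid.layers.relu(x, name='relu_out')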
@@ -5100,6 +5100,8 @@ def log(x): Args: x (Variable): Input tensor. + name (str|None, default None): A name for this layer. If set None, + the layer will be named automatically. Returns: Variable: The natural log of the input tensor computed element-wise. @@ -5117,7 +5119,7 @@ return out -def relu(x): +def relu(x, name=None): """ Relu takes one input data (Tensor) and produces one output data (Tensor) where the rectified linear function, y = max(0, x), is applied to @@ -5129,6 +5131,8 @@ Args: x (Variable): The input tensor. + name (str|None, default None): A name for this layer. If set None, + the layer will be named automatically. Returns: Variable: The output tensor with the same shape as input. From 317e18abd2aa69390dcc6a0d6760ba954597863e Mon Sep 17 00:00:00 2001 From: Qingsheng Li Date: Thu, 16 Aug 2018 13:00:55 +0800 Subject: [PATCH 87/94] Remove Data Sharing between input and output in scatter_op (#12672) * Remove Data Sharing between input and output in scatter_op * Removed data sharing in backward op --- paddle/fluid/operators/scatter_op.h | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/paddle/fluid/operators/scatter_op.h b/paddle/fluid/operators/scatter_op.h index d29947b55e..181bb1af5c 100644 --- a/paddle/fluid/operators/scatter_op.h +++ b/paddle/fluid/operators/scatter_op.h @@ -35,7 +35,7 @@ class ScatterOpKernel : public framework::OpKernel { auto *Out = ctx.Output("Out"); // In place output: Out = X, Out[Ids] += Updates - Out->ShareDataWith(*X); + framework::TensorCopySync(*X, ctx.GetPlace(), Out); // Apply ScatterUpdate: Out[index] += Updates[:] ScatterAssign(ctx.device_context(), *Updates, *Ids, Out); } @@ -53,7 +53,7 @@ class ScatterGradientOpKernel : public framework::OpKernel { auto *dOut = ctx.Input(framework::GradVarName("Out")); // In place gradient: dX = dO - dX->ShareDataWith(*dOut); + framework::TensorCopySync(*dOut, ctx.GetPlace(), dX); dUpdates->mutable_data(ctx.GetPlace()); // Gradient by Gather: dUpdates += dO[Ids] CPUGather(ctx.device_context(), *dOut, *Ids, dUpdates); From d7873e14124a157980049f3dc6a281638ce437ee Mon Sep 17 00:00:00 2001 From: dzhwinter Date: Thu, 16 Aug 2018 13:48:46 +0800 Subject: [PATCH 88/94] remove patchelf in windows (#12710) * remove patchelf in windows * "follow comment" --- .gitignore | 2 ++ cmake/configure.cmake | 4 ++++ python/CMakeLists.txt | 5 +++-- 3 files changed, 9 insertions(+), 2 deletions(-) diff --git a/.gitignore b/.gitignore index 9e3a0b499f..b92bb9cc12 100644 --- a/.gitignore +++ b/.gitignore @@ -5,6 +5,7 @@ python/paddle/v2/fluid/tests/book/image_classification_resnet.inference.model/ python/paddle/v2/fluid/tests/book/image_classification_vgg.inference.model/ python/paddle/v2/fluid/tests/book/label_semantic_roles.inference.model/ *.DS_Store +*.vs build/ build_doc/ *.user @@ -15,6 +16,7 @@ build_doc/ .cproject .pydevproject .settings/ +CMakeSettings.json Makefile .test_env/ third_party/ diff --git a/cmake/configure.cmake b/cmake/configure.cmake index ae90a529b1..d14162e0a6 100644 --- a/cmake/configure.cmake +++ b/cmake/configure.cmake @@ -56,6 +56,10 @@ if(NOT CMAKE_CROSSCOMPILING) set(SIMD_FLAG ${SSE3_FLAG}) endif() endif() +if(UNIX AND NOT APPLE) + # except Apple from the *nix OS family + set(LINUX TRUE) +endif(UNIX AND NOT APPLE) if(NOT WITH_GOLANG) add_definitions(-DPADDLE_WITHOUT_GOLANG) diff --git a/python/CMakeLists.txt b/python/CMakeLists.txt index 2590081150..9cdcb87df5 100644 --- a/python/CMakeLists.txt +++ b/python/CMakeLists.txt @@ -97,10 +97,11 @@ if(APPLE) if(NOT
INSTALL_NAME_TOOL_EXECUTABLE) message(FATAL_ERROR "install_name_tool not found, please check.\n") endif() -else(APPLE) +endif() +if(LINUX) find_program(PATCHELF_EXECUTABLE patchelf) if(NOT PATCHELF_EXECUTABLE) message(FATAL_ERROR "patchelf not found, please install it.\n" "For Ubuntu, the command is: apt-get install -y patchelf.") endif() -endif(APPLE) +endif(LINUX) From 1ef5f2c3e834a26137907d9150307ba257fa2568 Mon Sep 17 00:00:00 2001 From: minqiyang Date: Thu, 16 Aug 2018 18:12:28 +0800 Subject: [PATCH 89/94] Make flowers reader and parallel_executor more efficient --- python/paddle/dataset/flowers.py | 2 +- python/paddle/fluid/parallel_executor.py | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/python/paddle/dataset/flowers.py b/python/paddle/dataset/flowers.py index 17c768424f..aa73bbaf70 100644 --- a/python/paddle/dataset/flowers.py +++ b/python/paddle/dataset/flowers.py @@ -120,7 +120,7 @@ def reader_creator(data_file, file = file.strip() batch = None with open(file, 'rb') as f: - batch = pickle.loads(f.read()) + batch = pickle.load(f) data = batch['data'] labels = batch['label'] for sample, label in zip(data, batch['label']): diff --git a/python/paddle/fluid/parallel_executor.py b/python/paddle/fluid/parallel_executor.py index ac87b12a1c..a7765c9591 100644 --- a/python/paddle/fluid/parallel_executor.py +++ b/python/paddle/fluid/parallel_executor.py @@ -273,7 +273,7 @@ class ParallelExecutor(object): self.executor.feed_tensors_into_local_scopes(res) fetch_var_name = '@FETCHED_VAR_NAME@' - self.executor.run(cpt.to_text(fetch_list), cpt.to_text(fetch_var_name)) + self.executor.run(fetch_list, fetch_var_name) arr = self.scope.find_var(fetch_var_name).get_lod_tensor_array() if self.is_dist: From 546a26f08178e04264b0b2842afabd70412c53ca Mon Sep 17 00:00:00 2001 From: luotao1 Date: Thu, 16 Aug 2018 20:30:29 +0800 Subject: [PATCH 90/94] add mklml depends for anakin --- cmake/external/anakin.cmake | 1 + 1 file changed, 1 insertion(+) diff --git a/cmake/external/anakin.cmake b/cmake/external/anakin.cmake index 455ef91ac5..75ed529cd6 100644 --- a/cmake/external/anakin.cmake +++ b/cmake/external/anakin.cmake @@ -35,6 +35,7 @@ set(ANAKIN_COMPILE_EXTRA_FLAGS ExternalProject_Add( extern_anakin ${EXTERNAL_PROJECT_LOG_ARGS} + DEPENDS ${MKLML_PROJECT} GIT_REPOSITORY "https://github.com/PaddlePaddle/Anakin" GIT_TAG "04256ba78fa3da0beb74e8036c8efd68c12824d6" PREFIX ${ANAKIN_SOURCE_DIR} UPDATE_COMMAND "" CMAKE_ARGS -DUSE_GPU_PLACE=YES From 447936551ec188387fa9dd1539c3e2d2190d7993 Mon Sep 17 00:00:00 2001 From: luotao1 Date: Thu, 16 Aug 2018 21:40:11 +0800 Subject: [PATCH 91/94] quick fix anakin on 5117 cpu --- cmake/external/anakin.cmake | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/cmake/external/anakin.cmake b/cmake/external/anakin.cmake index 75ed529cd6..855897394a 100644 --- a/cmake/external/anakin.cmake +++ b/cmake/external/anakin.cmake @@ -36,8 +36,9 @@ ExternalProject_Add( extern_anakin ${EXTERNAL_PROJECT_LOG_ARGS} DEPENDS ${MKLML_PROJECT} - GIT_REPOSITORY "https://github.com/PaddlePaddle/Anakin" - GIT_TAG "04256ba78fa3da0beb74e8036c8efd68c12824d6" + # The Anakin code errors out on an Intel(R) Xeon(R) Gold 5117 CPU; temporarily do not compile the avx512-related code.
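The flowers reader change above deserves a gloss: pickle.loads(f.read()) first reads the whole batch file into an intermediate bytes object and then parses it, whereas pickle.load(f) deserializes straight from the file object and skips that copy. A self-contained sketch of the preferred pattern (the path argument is made up):

import pickle

def read_batch(path):
    with open(path, 'rb') as f:
        # parse directly from the stream; no intermediate bytes object
        return pickle.load(f)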
+ GIT_REPOSITORY "https://github.com/luotao1/Anakin" + GIT_TAG "bcf17aabe7921ceb7bce591244b4f9dce7dba5c8" PREFIX ${ANAKIN_SOURCE_DIR} UPDATE_COMMAND "" CMAKE_ARGS -DUSE_GPU_PLACE=YES From 64d48f4d6af515e7d3da00f37ee9938f7dd6eb96 Mon Sep 17 00:00:00 2001 From: Qiao Longfei Date: Fri, 17 Aug 2018 09:15:38 +0800 Subject: [PATCH 92/94] fix mac compile (#12751) --- paddle/fluid/framework/ir/graph_pattern_detecter.cc | 1 + 1 file changed, 1 insertion(+) diff --git a/paddle/fluid/framework/ir/graph_pattern_detecter.cc b/paddle/fluid/framework/ir/graph_pattern_detecter.cc index f27d9b0509..dcc4382792 100644 --- a/paddle/fluid/framework/ir/graph_pattern_detecter.cc +++ b/paddle/fluid/framework/ir/graph_pattern_detecter.cc @@ -12,6 +12,7 @@ // See the License for the specific language governing permissions and // limitations under the License. +#include #include #include From 653fad08f8f7a717c20756eebfc1b4ab860d4618 Mon Sep 17 00:00:00 2001 From: Qiao Longfei Date: Fri, 17 Aug 2018 10:52:09 +0800 Subject: [PATCH 93/94] Optimize selected rows for dist lookup table with pthread rwlock (#12635) Optimize selected rows for dist lookup table with rwlock --- paddle/fluid/framework/rw_lock.h | 46 ++++++ paddle/fluid/framework/selected_rows.cc | 110 ++++++++------ paddle/fluid/framework/selected_rows.h | 63 ++++---- paddle/fluid/framework/selected_rows_test.cc | 143 +++++++++++++++--- .../operators/distributed/rpc_server_test.cc | 3 +- .../fluid/operators/lookup_sparse_table_op.cc | 53 +------ paddle/fluid/operators/sgd_op.h | 2 +- paddle/fluid/operators/uniform_random_op.cc | 4 +- paddle/fluid/pybind/pybind.cc | 1 + .../unittests/test_lookup_sparse_table_op.py | 57 +++---- .../fluid/tests/unittests/test_sgd_op.py | 1 + 11 files changed, 298 insertions(+), 185 deletions(-) create mode 100644 paddle/fluid/framework/rw_lock.h diff --git a/paddle/fluid/framework/rw_lock.h b/paddle/fluid/framework/rw_lock.h new file mode 100644 index 0000000000..2a4009b765 --- /dev/null +++ b/paddle/fluid/framework/rw_lock.h @@ -0,0 +1,46 @@ +/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + +http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
*/ + +#pragma once + +#include + +namespace paddle { +namespace framework { + +struct RWLock { + RWLock() { pthread_rwlock_init(&lock_, nullptr); } + + ~RWLock() { pthread_rwlock_destroy(&lock_); } + + void RDLock() { + PADDLE_ENFORCE_EQ(pthread_rwlock_rdlock(&lock_), 0, + "acquire read lock failed"); + } + + void WRLock() { + PADDLE_ENFORCE_EQ(pthread_rwlock_wrlock(&lock_), 0, + "acquire write lock failed"); + } + + void UNLock() { + PADDLE_ENFORCE_EQ(pthread_rwlock_unlock(&lock_), 0, "unlock failed"); + } + + private: + pthread_rwlock_t lock_; +}; + +} // namespace framework +} // namespace paddle diff --git a/paddle/fluid/framework/selected_rows.cc b/paddle/fluid/framework/selected_rows.cc index 06ed87e7e8..c202b0a5be 100644 --- a/paddle/fluid/framework/selected_rows.cc +++ b/paddle/fluid/framework/selected_rows.cc @@ -120,66 +120,76 @@ bool SelectedRows::HasKey(int64_t key) const { : true; } -std::vector> SelectedRows::Get( - const std::vector& keys, framework::Tensor* value) const { +int64_t SelectedRows::AutoGrownIndex(int64_t key, bool auto_grown) { + rwlock_->RDLock(); + auto iter = id_to_index_.find(key); + if (iter == id_to_index_.end()) { + rwlock_->UNLock(); + if (!auto_grown) { + PADDLE_THROW("key %d not found", key); + } + rwlock_->WRLock(); + auto map_size = id_to_index_.size(); + auto vector_size = rows_.size(); + if (map_size != vector_size) { + rwlock_->UNLock(); + PADDLE_THROW( + "id_to_index_ size %d should have the same size with rows_ %d", + map_size, vector_size); + } + auto write_iter = id_to_index_.find(key); + if (write_iter == id_to_index_.end()) { + size_t row_num = rows_.size(); + if (row_num == value_->dims()[0]) { + rwlock_->UNLock(); + PADDLE_THROW("selected rows is full, then length exceed %d", row_num); + } + // key logic to put a key into id_to_index_ + rows_.push_back(key); + auto index = static_cast(rows_.size() - 1); + id_to_index_[key] = index; + rwlock_->UNLock(); + return index; + } else { + auto index = write_iter->second; + rwlock_->UNLock(); + return index; + } + } else { + auto index = iter->second; + rwlock_->UNLock(); + return index; + } +} + +void SelectedRows::SyncIndex() { + rwlock_->WRLock(); + id_to_index_.clear(); + for (size_t i = 0; i < rows_.size(); ++i) { + id_to_index_[rows_[i]] = i; + } + rwlock_->UNLock(); +} + +void SelectedRows::Get(const framework::Tensor& ids, framework::Tensor* value, + bool auto_grown) { PADDLE_ENFORCE(value->IsInitialized(), "The value tensor should be initialized."); - std::vector> non_keys_pair; - if (keys.empty()) { + if (ids.numel() == 0) { VLOG(3) << "keys is empty, please check data!"; } else { int64_t value_width = value_->numel() / value_->dims()[0]; PADDLE_ENFORCE_EQ(value_width, value->numel() / value->dims()[0], "output tensor should have the same shape with table " "except the dims[0]."); - - for (size_t i = 0; i < keys.size(); ++i) { - int64_t index = Index(keys[i]); - if (index == -1) { - non_keys_pair.push_back( - std::make_pair(keys[i], static_cast(i))); - } else { - framework::VisitDataType( - framework::ToDataType(value_->type()), - TensorCopyVisitor(value, i * value_width, *value_.get(), - index * value_width, value_width)); - } + for (size_t i = 0; i < ids.numel(); ++i) { + int64_t index = AutoGrownIndex(ids.data()[i], auto_grown); + framework::VisitDataType( + framework::ToDataType(value_->type()), + TensorCopyVisitor(value, i * value_width, *value_.get(), + index * value_width, value_width)); } } - return non_keys_pair; -} - -bool SelectedRows::Set(int64_t key, const framework::Tensor& 
value) { - PADDLE_ENFORCE(value.IsInitialized(), "The value should be initialized."); - if (value_->IsInitialized()) { - PADDLE_ENFORCE_EQ( - value.type(), value_->type(), - "The type of the value should be same with the original value"); - } - PADDLE_ENFORCE_EQ(value.dims()[0], static_cast(1), - "The first dim of value should be 1."); - std::lock_guard lock(*auto_grown_mutex_.get()); - auto index = Index(key); - bool is_new_key = false; - if (index == -1) { - rows_.push_back(key); - index = rows_.size() - 1; - is_new_key = true; - // whether need to resize the table - if (static_cast(rows_.size()) > value_->dims()[0]) { - auto dims = value_->dims(); - dims[0] = (dims[0] + 1) << 1; - framework::VisitDataType(framework::ToDataType(value.type()), - ReAllocateVisitor(dims, value_.get())); - } - } - - framework::VisitDataType( - framework::ToDataType(value.type()), - TensorCopyVisitor(value_.get(), - index * value_->numel() / value_->dims()[0], value, - static_cast(0), value.numel())); - return is_new_key; } } // namespace framework diff --git a/paddle/fluid/framework/selected_rows.h b/paddle/fluid/framework/selected_rows.h index 7160670ddd..daf5e95304 100644 --- a/paddle/fluid/framework/selected_rows.h +++ b/paddle/fluid/framework/selected_rows.h @@ -17,10 +17,12 @@ limitations under the License. */ #include #include #include // NOLINT +#include #include #include #include "paddle/fluid/framework/lod_tensor.h" +#include "paddle/fluid/framework/rw_lock.h" #include "paddle/fluid/framework/tensor.h" #include "paddle/fluid/memory/memcpy.h" @@ -48,13 +50,13 @@ class SelectedRows { SelectedRows(const std::vector& rows, const int64_t& height) : rows_(rows), height_(height) { value_.reset(new Tensor()); - auto_grown_mutex_.reset(new std::mutex); + rwlock_.reset(new RWLock); } SelectedRows() { height_ = 0; value_.reset(new Tensor()); - auto_grown_mutex_.reset(new std::mutex); + rwlock_.reset(new RWLock); } platform::Place place() const { return value_->place(); } @@ -74,47 +76,51 @@ class SelectedRows { void set_rows(const Vector& rows) { rows_ = rows; } /* - * @brief wheter has the specified key in the table. + * @brief Get the index of key in rows + * + * @return -1 if the key does not exists. + */ + int64_t Index(int64_t key) const { + auto it = std::find(rows_.begin(), rows_.end(), key); + if (it == rows_.end()) { + PADDLE_THROW("id %s not in table", key); + } + return static_cast(std::distance(rows_.begin(), it)); + } + + /* + * @brief whether has the specified key in the table. * * @return true if the key is exists. */ bool HasKey(int64_t key) const; /* - * @brief Get value by the key list, if the + * @brief Get value by the key list. + * Note!!! this interface is only used when selected_rows is used as + * parameters + * for distribute lookup table. * * @return a list of pair which contains the non-exists key and the index in * the value */ - std::vector> Get(const std::vector& keys, - framework::Tensor* value) const; + void Get(const framework::Tensor& ids, framework::Tensor* value, + bool auto_grown = false); /* - * @brief Set a key-value pair into the table. - * This function will double the value memory if it's not engouth. + * @brief Get the index of the key from id_to_index_ map. If the key not + * exist, + * add the key into id_to_index_. * - * @note: - * 1. The first dim of the value should be 1 - * 2. The value should be initialized and the data type - * should be the same with the table. - * - * @return true if the key is a new one, otherwise false + * Note!!! 
this interface is only used when selected_rows is used as + * parameters + * for distribute lookup table. * + * @return index of the key. */ - bool Set(int64_t key, const Tensor& value); + int64_t AutoGrownIndex(int64_t key, bool auto_grown); - /* - * @brief Get the index of key in rows - * - * @return -1 if the key does not exists. - */ - int64_t Index(int64_t key) const { - auto it = std::find(rows_.begin(), rows_.end(), key); - if (it == rows_.end()) { - return static_cast(-1); - } - return static_cast(std::distance(rows_.begin(), it)); - } + void SyncIndex(); DDim GetCompleteDims() const { std::vector dims = vectorize(value_->dims()); @@ -127,9 +133,10 @@ class SelectedRows { // SelectedRows are simply concated when adding together. Until a // SelectedRows add a Tensor, will the duplicate rows be handled. Vector rows_; + std::unordered_map id_to_index_; std::unique_ptr value_{nullptr}; int64_t height_; - std::unique_ptr auto_grown_mutex_{nullptr}; + std::unique_ptr rwlock_{nullptr}; }; /* diff --git a/paddle/fluid/framework/selected_rows_test.cc b/paddle/fluid/framework/selected_rows_test.cc index eefcaa5672..5ca864cfdf 100644 --- a/paddle/fluid/framework/selected_rows_test.cc +++ b/paddle/fluid/framework/selected_rows_test.cc @@ -9,8 +9,11 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. */ -#include "paddle/fluid/framework/selected_rows.h" +#include +#include // NOLINT + #include "gtest/gtest.h" +#include "paddle/fluid/framework/selected_rows.h" namespace paddle { namespace framework { @@ -59,39 +62,129 @@ TEST_F(SelectedRowsTester, SerializeAndDeseralize) { ASSERT_EQ(selected_rows_->GetCompleteDims(), dst_tensor.GetCompleteDims()); } -TEST_F(SelectedRowsTester, SparseTable) { +TEST(SelectedRows, SparseTable) { platform::CPUPlace cpu; SelectedRows table; + + int64_t table_size = 100; + int64_t embedding_width = 8; // initialize a sparse table - table.mutable_value()->Resize(framework::make_ddim({1, 100})); - table.mutable_value()->mutable_data(cpu); - table.mutable_rows()->push_back(1); + table.mutable_value()->Resize( + framework::make_ddim({table_size, embedding_width})); + auto* data = table.mutable_value()->mutable_data(cpu); + for (int64_t i = 0; i < table_size; ++i) { + for (int64_t j = 0; j < embedding_width; ++j) { + data[i * embedding_width + j] = static_cast(i); + } + } + ASSERT_EQ(table.AutoGrownIndex(10, true), 0); + ASSERT_EQ(table.AutoGrownIndex(8, true), 1); + ASSERT_EQ(table.AutoGrownIndex(8, true), 1); + ASSERT_EQ(table.AutoGrownIndex(6, true), 2); + ASSERT_TRUE(table.HasKey(10)); + ASSERT_TRUE(table.HasKey(8)); + ASSERT_TRUE(table.HasKey(6)); + ASSERT_EQ(table.rows().size(), 3); + + framework::Tensor ids; + ids.Resize(framework::make_ddim({4})); + auto* ids_data = ids.mutable_data(cpu); + ids_data[0] = static_cast(6); + ids_data[1] = static_cast(6); + ids_data[2] = static_cast(8); + ids_data[3] = static_cast(10); - int64_t key = 10000; - int64_t non_key = 999; - framework::Tensor value; - value.Resize(framework::make_ddim({1, 100})); - auto ptr = value.mutable_data(cpu); - ptr[0] = static_cast(10); + framework::Tensor get_value; + auto* value_data = get_value.mutable_data( + framework::make_ddim({4, embedding_width}), cpu); + table.Get(ids, &get_value); - ASSERT_EQ(table.rows().size(), static_cast(1)); - ASSERT_EQ(table.HasKey(key), false); + for (int j = 0; j < embedding_width; ++j) { + ASSERT_EQ(value_data[0 * embedding_width + j], 2); + } 
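To make the rewritten lookup semantics easier to follow before the remaining test assertions, here is a hedged Python-only model of AutoGrownIndex; the names are ours, and the real implementation above guards the map with a framework::RWLock and caps growth at the height of the value tensor:

import threading

class SparseTableModel(object):
    def __init__(self, capacity):
        self.rows = []                 # mirrors rows_
        self.id_to_index = {}          # mirrors id_to_index_
        self.capacity = capacity       # mirrors value_->dims()[0]
        self.lock = threading.Lock()   # coarse stand-in for the RWLock

    def auto_grown_index(self, key, auto_grown):
        with self.lock:
            if key in self.id_to_index:
                return self.id_to_index[key]
            if not auto_grown:
                raise KeyError(key)
            if len(self.rows) >= self.capacity:
                raise RuntimeError('selected rows is full')
            self.rows.append(key)
            index = len(self.rows) - 1
            self.id_to_index[key] = index
            return index

Looking up an existing key returns its stable index; a new key appends a row and records its index, which is exactly the behavior the multithreaded tests below exercise.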
+ for (int j = 0; j < embedding_width; ++j) { + ASSERT_EQ(value_data[1 * embedding_width + j], 2); + } + for (int j = 0; j < embedding_width; ++j) { + ASSERT_EQ(value_data[2 * embedding_width + j], 1); + } + for (int j = 0; j < embedding_width; ++j) { + ASSERT_EQ(value_data[3 * embedding_width + j], 0); + } +} - table.Set(key, value); +void f1(SelectedRows* table, int table_size) { + for (int i = 1000000; i > 0; --i) { + auto id = i % table_size; + int64_t index1 = table->AutoGrownIndex(id, true); + int64_t index2 = table->AutoGrownIndex(id, false); + int64_t index3 = table->AutoGrownIndex(id, true); + ASSERT_EQ(index1, index2); + ASSERT_EQ(index2, index3); + } +} - ASSERT_EQ(table.rows().size(), static_cast(2)); - ASSERT_EQ(table.HasKey(key), true); - // check re-allocate - ASSERT_EQ(table.value().dims()[0], static_cast(4)); +void f2(SelectedRows* table, int table_size) { + for (int i = 0; i < 1000000; ++i) { + auto id = i % table_size; + int64_t index1 = table->AutoGrownIndex(id, true); + int64_t index2 = table->AutoGrownIndex(id, false); + int64_t index3 = table->AutoGrownIndex(id, true); + ASSERT_EQ(index1, index2); + ASSERT_EQ(index2, index3); + } +} - framework::Tensor get_value; - get_value.mutable_data(framework::make_ddim({2, 100}), cpu); - std::vector keys({non_key, key}); - auto non_key_pairs = table.Get(keys, &get_value); +void f3(SelectedRows* table, int table_size) { + clock_t t1 = clock(); + for (int i = 100000; i > 0; --i) { + auto id1 = table->AutoGrownIndex(i % table_size, true); + auto id2 = table->Index(i % table_size); + ASSERT_EQ(id1, id2); + } + clock_t t2 = clock(); + std::cout << "f3 run time:" << t2 - t1 << std::endl; +} + +void f4(SelectedRows* table, int table_size) { + clock_t t1 = clock(); + for (int i = 0; i < 100000; ++i) { + auto id1 = table->AutoGrownIndex(i % table_size, true); + auto id2 = table->Index(i % table_size); + ASSERT_EQ(id1, id2); + } + clock_t t2 = clock(); + std::cout << "f4 run time:" << t2 - t1 << std::endl; +} + +TEST(SelectedRows, MultiThreadAutoIndex) { + platform::CPUPlace cpu; + SelectedRows table; + + int64_t table_size = 100000; + int64_t embedding_width = 8; + // initialize a sparse table + table.mutable_value()->Resize( + framework::make_ddim({table_size, embedding_width})); + auto* data = table.mutable_value()->mutable_data(cpu); + for (int64_t i = 0; i < table_size; ++i) { + for (int64_t j = 0; j < embedding_width; ++j) { + data[i * embedding_width + j] = static_cast(i); + } + } - ASSERT_EQ(get_value.data()[100], static_cast(10)); - ASSERT_EQ(non_key_pairs.size(), static_cast(1)); - ASSERT_EQ(non_key_pairs[0].first, non_key); + std::thread t1(f1, &table, table_size); + std::thread t11(f1, &table, table_size); + std::thread t2(f2, &table, table_size); + std::thread t22(f2, &table, table_size); + t1.join(); + t11.join(); + t2.join(); + t22.join(); + std::thread t3(f3, &table, table_size); + std::thread t4(f4, &table, table_size); + t3.join(); + t4.join(); } } // namespace framework diff --git a/paddle/fluid/operators/distributed/rpc_server_test.cc b/paddle/fluid/operators/distributed/rpc_server_test.cc index b50830c362..d6176e1443 100644 --- a/paddle/fluid/operators/distributed/rpc_server_test.cc +++ b/paddle/fluid/operators/distributed/rpc_server_test.cc @@ -78,10 +78,9 @@ void InitTensorsOnServer(framework::Scope* scope, platform::CPUPlace* place, int64_t rows_numel) { CreateVarsOnScope(scope, place); auto w = scope->Var("w")->GetMutable(); - auto rows = w->mutable_rows(); - for (int64_t i = 0; i < rows_numel; ++i) 
   auto w_value = w->mutable_value();
   w_value->Resize({rows_numel, 10});
+  for (int64_t i = 0; i < rows_numel; ++i) w->AutoGrownIndex(i, true);
 
   auto ptr = w_value->mutable_data<float>(*place);
 
diff --git a/paddle/fluid/operators/lookup_sparse_table_op.cc b/paddle/fluid/operators/lookup_sparse_table_op.cc
index 2ce11e712f..de3f0990e1 100644
--- a/paddle/fluid/operators/lookup_sparse_table_op.cc
+++ b/paddle/fluid/operators/lookup_sparse_table_op.cc
@@ -17,7 +17,6 @@ limitations under the License. */
 #include "paddle/fluid/framework/data_type.h"
 #include "paddle/fluid/framework/op_registry.h"
 #include "paddle/fluid/operators/math/math_function.h"
-#include "paddle/fluid/platform/device_context.h"
 
 namespace paddle {
 namespace operators {
@@ -46,10 +45,6 @@ class LookupSparseTableOp : public framework::OperatorBase {
     auto out_var = scope.FindVar(Output("Out"));
     auto w_var = scope.FindVar(Input("W"));
     auto ids_var = scope.FindVar(Input("Ids"));
-    unsigned int seed = static_cast<unsigned int>(Attr<int>("seed"));
-    float min = Attr<float>("min");
-    float max = Attr<float>("max");
-    bool auto_grown_table = Attr<bool>("auto_grown_table");
 
     PADDLE_ENFORCE(out_var->IsType<framework::LoDTensor>(),
                    "The type of Out var should be LodTensor.");
@@ -60,46 +55,17 @@ class LookupSparseTableOp : public framework::OperatorBase {
     auto &ids_t = ids_var->Get<framework::LoDTensor>();
     auto out_t = out_var->GetMutable<framework::LoDTensor>();
     auto w_t = w_var->GetMutable<framework::SelectedRows>();
-    std::vector<int64_t> keys;
-    keys.resize(ids_t.numel());
-    for (int64_t i = 0; i < ids_t.numel(); ++i) {
-      keys[i] = ids_t.data<int64_t>()[i];
-    }
 
     // TODO(Yancey1989): support CUDA Place for the sparse table
     platform::CPUPlace cpu;
     auto out_shape = w_t->value().dims();
-    out_shape[0] = keys.size();
+    out_shape[0] = ids_t.numel();
     out_t->Resize(out_shape);
     out_t->mutable_data(cpu, w_t->value().type());
     PADDLE_ENFORCE_EQ(framework::ToDataType(w_t->value().type()),
                       framework::proto::VarType::FP32,
                       "The sparse table only support FP32");
-    auto non_keys_pair = w_t->Get(keys, out_t);
-    if (!auto_grown_table) {
-      PADDLE_ENFORCE_EQ(non_keys_pair.size(), static_cast<size_t>(0),
-                        "there is some keys does exists in the sparse table.");
-    }
-    auto value_shape = w_t->value().dims();
-    value_shape[0] = 1;
-    for (const auto &it : non_keys_pair) {
-      const auto key = it.first;
-      const auto index = it.second;
-      framework::Tensor value;
-      value.Resize(value_shape);
-      auto data = value.mutable_data<float>(cpu);
-
-      std::minstd_rand engine;
-      engine.seed(seed);
-      std::uniform_real_distribution<float> dist(min, max);
-      int64_t size = value.numel();
-      for (int64_t i = 0; i < size; ++i) {
-        data[i] = dist(engine);
-      }
-      w_t->Set(key, value);
-      memory::Copy(cpu, out_t->mutable_data<float>(cpu) + index * value.numel(),
-                   cpu, value.data<float>(), value.numel() * sizeof(float));
-    }
+    w_t->Get(ids_t, out_t, true);
   }
 };
 
@@ -121,21 +87,6 @@ class LookupSparseTableOpMaker : public framework::OpProtoAndCheckerMaker {
                  "Otherwise the given value indicates padding the output "
                  "with zeros whenever lookup encounters it in Ids.")
         .SetDefault(kNoPadding);
-    AddAttr<float>("min",
-                   "(float, default -1.0) "
-                   "Minimum value of uniform random")
-        .SetDefault(-1.0f);
-    AddAttr<float>("max",
-                   "(float, default 1.0) "
-                   "Maximum value of uniform random")
-        .SetDefault(1.0f);
-    AddAttr<int>("seed",
-                 "(int, default 0) "
-                 "Random seed used for generating samples. "
-                 "0 means use a seed generated by the system."
- "Note that if seed is not 0, this operator will always " - "generate the same random numbers every time.") - .SetDefault(0); AddAttr("auto_grown_table", "(bool default false)" "Whether create new value if for nonexistent key.") diff --git a/paddle/fluid/operators/sgd_op.h b/paddle/fluid/operators/sgd_op.h index 2685ce217e..d8b0165b2a 100644 --- a/paddle/fluid/operators/sgd_op.h +++ b/paddle/fluid/operators/sgd_op.h @@ -111,7 +111,7 @@ class SGDOpKernel : public framework::OpKernel { for (size_t i = 0; i < grad.rows().size(); i++) { PADDLE_ENFORCE(grad.rows()[i] < grad.height(), "Input rows index should less than height"); - int64_t id_index = param.Index(grad.rows()[i]); + int64_t id_index = param_out->AutoGrownIndex(grad.rows()[i], false); PADDLE_ENFORCE_GE(id_index, static_cast(0), "id should be in the table"); for (int64_t j = 0; j < grad_row_width; j++) { diff --git a/paddle/fluid/operators/uniform_random_op.cc b/paddle/fluid/operators/uniform_random_op.cc index edd1baa4ac..5248767c2e 100644 --- a/paddle/fluid/operators/uniform_random_op.cc +++ b/paddle/fluid/operators/uniform_random_op.cc @@ -30,8 +30,10 @@ class CPUUniformRandomKernel : public framework::OpKernel { tensor = out_var->GetMutable(); } else if (out_var->IsType()) { auto shape = ctx.Attr>("shape"); - tensor = out_var->GetMutable()->mutable_value(); + auto* selected_rows = out_var->GetMutable(); + tensor = selected_rows->mutable_value(); tensor->Resize(framework::make_ddim(shape)); + selected_rows->mutable_rows()->reserve(shape[0]); } else { PADDLE_THROW( "uniform_random_op's output only" diff --git a/paddle/fluid/pybind/pybind.cc b/paddle/fluid/pybind/pybind.cc index 40ced8e1c7..6c58478b0d 100644 --- a/paddle/fluid/pybind/pybind.cc +++ b/paddle/fluid/pybind/pybind.cc @@ -249,6 +249,7 @@ PYBIND11_PLUGIN(core) { self.set_rows(new_rows); #endif }) + .def("sync_index", [](SelectedRows &instance) { instance.SyncIndex(); }) .def("rows", [](SelectedRows &self) { auto rows = self.rows(); std::vector new_rows; diff --git a/python/paddle/fluid/tests/unittests/test_lookup_sparse_table_op.py b/python/paddle/fluid/tests/unittests/test_lookup_sparse_table_op.py index 7f75d0e6e9..11e5d8b536 100644 --- a/python/paddle/fluid/tests/unittests/test_lookup_sparse_table_op.py +++ b/python/paddle/fluid/tests/unittests/test_lookup_sparse_table_op.py @@ -21,36 +21,27 @@ import paddle.fluid.core as core from paddle.fluid.op import Operator -def output_hist(out): - hist, _ = np.histogram(out, range=(-5, 10)) - hist = hist.astype("float32") - hist /= float(out.size) - prob = 0.1 * np.ones((10)) - return hist, prob - - class TestLookupSpraseTable(OpTest): def check_with_place(self, place): scope = core.Scope() - # create and initialize Id Variable - ids = scope.var("Ids").get_tensor() - ids_array = np.array([0, 2, 3, 5, 100]).astype("int64") - ids.set(ids_array, place) - # create and initialize W Variable - rows = [0, 1, 2, 3, 4, 5, 6] - row_numel = 10000 + table_size = 10000 + row_numel = 8 w_selected_rows = scope.var('W').get_selected_rows() - w_selected_rows.set_height(len(rows)) - w_selected_rows.set_rows(rows) - w_array = np.ones((len(rows), row_numel)).astype("float32") - for i in range(len(rows)): + w_selected_rows.set_height(table_size) + w_array = np.ones((table_size, row_numel)).astype("float32") + for i in range(table_size): w_array[i] *= i w_tensor = w_selected_rows.get_tensor() w_tensor.set(w_array, place) + # create and initialize Id Variable + ids = scope.var("Ids").get_tensor() + ids_array1 = np.array([0, 2, 3, 2, 5, 0, 
100]).astype("int64") + ids.set(ids_array1, place) + # create Out Variable out_tensor = scope.var('Out').get_tensor() @@ -66,16 +57,28 @@ class TestLookupSpraseTable(OpTest): lookup_table.run(scope, place) # get result from Out - result_array = np.array(out_tensor) + result_array1 = np.array(out_tensor) # all(): return True if all elements of the iterable are true (or if the iterable is empty) - for idx, row in enumerate(ids_array[:-2]): - assert (row == result_array[idx]).all() + assert (result_array1[0] == w_array[0]).all() + assert (result_array1[1] == w_array[1]).all() + assert (result_array1[2] == w_array[2]).all() + assert (result_array1[3] == w_array[1]).all() + assert (result_array1[4] == w_array[3]).all() + assert (result_array1[5] == w_array[0]).all() + assert (result_array1[6] == w_array[4]).all() + + # create and initialize Id Variable + ids = scope.var("Ids").get_tensor() + ids_array2 = np.array([4, 2, 3, 7, 100000]).astype("int64") + ids.set(ids_array2, place) + lookup_table.run(scope, place) - # check the random value - hist, prob = output_hist(result_array[-1]) - self.assertTrue( - np.allclose( - hist, prob, rtol=0, atol=0.01), "hist: " + str(hist)) + result_array2 = np.array(out_tensor) + assert (result_array2[0] == w_array[5]).all() + assert (result_array2[1] == w_array[1]).all() + assert (result_array2[2] == w_array[2]).all() + assert (result_array2[3] == w_array[6]).all() + assert (result_array2[4] == w_array[7]).all() def test_w_is_selected_rows(self): places = [core.CPUPlace()] diff --git a/python/paddle/fluid/tests/unittests/test_sgd_op.py b/python/paddle/fluid/tests/unittests/test_sgd_op.py index c14a83b4bb..b46e4bfb86 100644 --- a/python/paddle/fluid/tests/unittests/test_sgd_op.py +++ b/python/paddle/fluid/tests/unittests/test_sgd_op.py @@ -126,6 +126,7 @@ class TestSGDOpOptimizeSelectedRows(unittest.TestCase): w_selected_rows = scope.var('Param').get_selected_rows() w_selected_rows.set_height(len(param_rows)) w_selected_rows.set_rows(param_rows) + w_selected_rows.sync_index() w_array = np.ones((len(param_rows), row_width)).astype("float32") for i in range(len(param_rows)): w_array[i] *= i From 4069262f0e06da6f240ac4c9e90ba0403a94bc4d Mon Sep 17 00:00:00 2001 From: dzhwinter Date: Fri, 17 Aug 2018 10:58:01 +0800 Subject: [PATCH 94/94] Revert ""cherry picked operators changes" (#12184)" (#12747) This reverts commit bf3c34960f2a59a2616957f8fb4107b2ac7aa02b. 
--- paddle/fluid/operators/activation_op.cu | 4 +- paddle/fluid/operators/activation_op.h | 12 ++-- paddle/fluid/operators/assign_value_op.cu.cc | 5 +- paddle/fluid/operators/conv_cudnn_op.cu.cc | 56 +++++++----------- paddle/fluid/operators/cross_entropy_op.cu | 12 ++-- paddle/fluid/operators/elementwise_add_op.cu | 3 +- paddle/fluid/operators/elementwise_div_op.cu | 9 +-- paddle/fluid/operators/elementwise_mul_op.cu | 8 +-- .../fluid/operators/elementwise_op_function.h | 4 +- paddle/fluid/operators/elementwise_sub_op.cu | 8 +-- paddle/fluid/operators/fill_constant_op.cc | 53 +++++++++++------ paddle/fluid/operators/fill_constant_op.cu.cc | 26 -------- paddle/fluid/operators/fill_constant_op.h | 48 --------------- paddle/fluid/operators/fill_op.cc | 2 +- paddle/fluid/operators/gaussian_random_op.cu | 2 - paddle/fluid/operators/math/cross_entropy.cu | 20 +------ paddle/fluid/operators/math/cross_entropy.h | 17 ------ .../operators/math/selected_rows_functor.cu | 13 +--- paddle/fluid/operators/math/softmax.cu | 3 - paddle/fluid/operators/mean_op.cu | 10 ++-- paddle/fluid/operators/mean_op.h | 2 +- paddle/fluid/operators/mul_op.cu.cc | 7 +-- paddle/fluid/operators/pool_cudnn_op.cu.cc | 6 +- paddle/fluid/operators/scale_op.cu | 6 +- paddle/fluid/operators/softmax_cudnn_op.cu.cc | 3 +- paddle/fluid/operators/softmax_op.cu.cc | 3 +- paddle/fluid/operators/sum_op.cu | 5 +- paddle/fluid/operators/sum_op.h | 2 +- paddle/fluid/operators/top_k_op.cu | 28 ++------- paddle/fluid/operators/uniform_random_op.cu | 59 +++---------------- 30 files changed, 108 insertions(+), 328 deletions(-) delete mode 100644 paddle/fluid/operators/fill_constant_op.cu.cc delete mode 100644 paddle/fluid/operators/fill_constant_op.h diff --git a/paddle/fluid/operators/activation_op.cu b/paddle/fluid/operators/activation_op.cu index d3a7ceed46..27487b396c 100644 --- a/paddle/fluid/operators/activation_op.cu +++ b/paddle/fluid/operators/activation_op.cu @@ -26,8 +26,6 @@ namespace plat = paddle::platform; act_type##_grad, ops::ActivationGradKernel>, \ ops::ActivationGradKernel>, \ - ops::ActivationGradKernel>); + ops::grad_functor>); FOR_EACH_KERNEL_FUNCTOR(REGISTER_ACTIVATION_CUDA_KERNEL); diff --git a/paddle/fluid/operators/activation_op.h b/paddle/fluid/operators/activation_op.h index 48f3b5a5bc..9124151926 100644 --- a/paddle/fluid/operators/activation_op.h +++ b/paddle/fluid/operators/activation_op.h @@ -333,7 +333,8 @@ struct SqrtGradFunctor : public BaseActivationFunctor { template void operator()(Device d, X x, Out out, dOut dout, dX dx) const { - dx.device(d) = static_cast(0.5) * dout / out; + const Out out_conj = Eigen::numext::conj(out); + dx.device(d) = static_cast(0.5) * dout / out_conj; } }; @@ -739,7 +740,7 @@ struct PowGradFunctor : public BaseActivationFunctor { typename dX> void operator()(Device d, X x, Out out, dOut dout, dX dx) const { dx.device(d) = dout * static_cast(factor) * - x.pow(static_cast(factor) - static_cast(1)); + x.pow(static_cast(factor - static_cast(1))); } }; @@ -862,11 +863,10 @@ struct SwishGradFunctor : public BaseActivationFunctor { template void operator()(Device d, X x, Out out, dOut dout, dX dx) const { - T b = static_cast(beta); auto temp1 = static_cast(1) / - (static_cast(1) + (static_cast(-b) * x).exp()); - auto temp2 = temp1 * (static_cast(1) - (b * out)); - dx.device(d) = dout * ((b * out) + temp2); + (static_cast(1) + (static_cast(-beta) * x).exp()); + auto temp2 = temp1 * (static_cast(1) - (beta * out)); + dx.device(d) = dout * ((beta * out) + temp2); } }; diff --git 
a/paddle/fluid/operators/assign_value_op.cu.cc b/paddle/fluid/operators/assign_value_op.cu.cc index 0ff174b388..08bfde5dc9 100644 --- a/paddle/fluid/operators/assign_value_op.cu.cc +++ b/paddle/fluid/operators/assign_value_op.cu.cc @@ -13,10 +13,7 @@ See the License for the specific language governing permissions and limitations under the License. */ #include "paddle/fluid/operators/assign_value_op.h" -#include "paddle/fluid/platform/float16.h" namespace ops = paddle::operators; -namespace plat = paddle::platform; REGISTER_OP_CUDA_KERNEL(assign_value, ops::AssignValueKernel, - ops::AssignValueKernel, - ops::AssignValueKernel); + ops::AssignValueKernel); diff --git a/paddle/fluid/operators/conv_cudnn_op.cu.cc b/paddle/fluid/operators/conv_cudnn_op.cu.cc index 59bfe8f61d..22cbf680c0 100644 --- a/paddle/fluid/operators/conv_cudnn_op.cu.cc +++ b/paddle/fluid/operators/conv_cudnn_op.cu.cc @@ -39,27 +39,6 @@ using ScalingParamType = typename platform::CudnnDataType::ScalingParamType; static constexpr size_t kCONV_CUDNN_WORKSPACE_LIMIT_BYTES = static_cast(1024) * 1024 * 1024; -template -// bool EnableFp16(const T& dummy, const DeviceContext& dev_ctx, -bool EnableFp16(const DeviceContext& dev_ctx, - cudnnConvolutionDescriptor_t cudnn_conv_desc) { -#if CUDA_VERSION >= 9000 && CUDNN_VERSION_MIN(7, 0, 1) - // Tensor core is supported since the volta GPU and - // is only enabled when input and filter data are float16 - if (dev_ctx.GetComputeCapability() >= 70 && - std::type_index(typeid(T)) == - std::type_index(typeid(platform::float16))) { - PADDLE_ENFORCE(platform::dynload::cudnnSetConvolutionMathType( - cudnn_conv_desc, CUDNN_TENSOR_OP_MATH)); - return true; - } else { - PADDLE_ENFORCE(platform::dynload::cudnnSetConvolutionMathType( - cudnn_conv_desc, CUDNN_DEFAULT_MATH)); - } -#endif - return false; -} - template class CUDNNConvOpKernel : public framework::OpKernel { public: @@ -149,14 +128,27 @@ class CUDNNConvOpKernel : public framework::OpKernel { cudnnConvolutionFwdAlgo_t algo; auto& dev_ctx = ctx.template device_context(); auto handle = dev_ctx.cudnn_handle(); - if (EnableFp16(dev_ctx, cudnn_conv_desc)) { + + CUDNN_ENFORCE(platform::dynload::cudnnGetConvolutionForwardAlgorithm( + handle, cudnn_input_desc, cudnn_filter_desc, cudnn_conv_desc, + cudnn_output_desc, CUDNN_CONVOLUTION_FWD_SPECIFY_WORKSPACE_LIMIT, + workspace_size_limit, &algo)); + +#if CUDA_VERSION >= 9000 && CUDNN_VERSION_MIN(7, 0, 1) + // Tensor core is supported since the volta GPU and + // is only enabled when input and filter data are float16 + if (dev_ctx.GetComputeCapability() >= 70 && + std::type_index(typeid(T)) == + std::type_index(typeid(platform::float16))) { + CUDNN_ENFORCE(platform::dynload::cudnnSetConvolutionMathType( + cudnn_conv_desc, CUDNN_TENSOR_OP_MATH)); + // Currently tensor core is only enabled using this algo algo = CUDNN_CONVOLUTION_FWD_ALGO_IMPLICIT_PRECOMP_GEMM; } else { - PADDLE_ENFORCE(platform::dynload::cudnnGetConvolutionForwardAlgorithm( - handle, cudnn_input_desc, cudnn_filter_desc, cudnn_conv_desc, - cudnn_output_desc, CUDNN_CONVOLUTION_FWD_SPECIFY_WORKSPACE_LIMIT, - workspace_size_limit, &algo)); + CUDNN_ENFORCE(platform::dynload::cudnnSetConvolutionMathType( + cudnn_conv_desc, CUDNN_DEFAULT_MATH)); } +#endif // get workspace size able to allocate CUDNN_ENFORCE(platform::dynload::cudnnGetConvolutionForwardWorkspaceSize( @@ -296,9 +288,6 @@ class CUDNNConvGradOpKernel : public framework::OpKernel { } else { data_algo = CUDNN_CONVOLUTION_BWD_DATA_ALGO_1; } - if (EnableFp16(dev_ctx, 
cudnn_conv_desc)) { - data_algo = CUDNN_CONVOLUTION_BWD_DATA_ALGO_1; - } CUDNN_ENFORCE( platform::dynload::cudnnGetConvolutionBackwardDataWorkspaceSize( @@ -318,9 +307,6 @@ class CUDNNConvGradOpKernel : public framework::OpKernel { } else { filter_algo = CUDNN_CONVOLUTION_BWD_FILTER_ALGO_1; } - if (EnableFp16(dev_ctx, cudnn_conv_desc)) { - filter_algo = CUDNN_CONVOLUTION_BWD_FILTER_ALGO_1; - } CUDNN_ENFORCE( platform::dynload::cudnnGetConvolutionBackwardFilterWorkspaceSize( @@ -376,8 +362,7 @@ REGISTER_OP_KERNEL(conv2d, CUDNN, plat::CUDAPlace, paddle::operators::CUDNNConvOpKernel); REGISTER_OP_KERNEL(conv2d_grad, CUDNN, plat::CUDAPlace, paddle::operators::CUDNNConvGradOpKernel, - paddle::operators::CUDNNConvGradOpKernel, - paddle::operators::CUDNNConvGradOpKernel); + paddle::operators::CUDNNConvGradOpKernel); REGISTER_OP_KERNEL(conv3d, CUDNN, plat::CUDAPlace, paddle::operators::CUDNNConvOpKernel, @@ -385,5 +370,4 @@ REGISTER_OP_KERNEL(conv3d, CUDNN, plat::CUDAPlace, paddle::operators::CUDNNConvOpKernel); REGISTER_OP_KERNEL(conv3d_grad, CUDNN, plat::CUDAPlace, paddle::operators::CUDNNConvGradOpKernel, - paddle::operators::CUDNNConvGradOpKernel, - paddle::operators::CUDNNConvGradOpKernel) + paddle::operators::CUDNNConvGradOpKernel); diff --git a/paddle/fluid/operators/cross_entropy_op.cu b/paddle/fluid/operators/cross_entropy_op.cu index 65fd3a5dbc..30dbd5bd3d 100644 --- a/paddle/fluid/operators/cross_entropy_op.cu +++ b/paddle/fluid/operators/cross_entropy_op.cu @@ -13,16 +13,12 @@ See the License for the specific language governing permissions and limitations under the License. */ #include "paddle/fluid/operators/cross_entropy_op.h" -#include "paddle/fluid/platform/float16.h" namespace ops = paddle::operators; -namespace plat = paddle::platform; using CUDACtx = paddle::platform::CUDADeviceContext; REGISTER_OP_CUDA_KERNEL(cross_entropy, ops::CrossEntropyOpKernel, - ops::CrossEntropyOpKernel, - ops::CrossEntropyOpKernel); -REGISTER_OP_CUDA_KERNEL( - cross_entropy_grad, ops::CrossEntropyGradientOpKernel, - ops::CrossEntropyGradientOpKernel, - ops::CrossEntropyGradientOpKernel); + ops::CrossEntropyOpKernel); +REGISTER_OP_CUDA_KERNEL(cross_entropy_grad, + ops::CrossEntropyGradientOpKernel, + ops::CrossEntropyGradientOpKernel); diff --git a/paddle/fluid/operators/elementwise_add_op.cu b/paddle/fluid/operators/elementwise_add_op.cu index f9f5c66d34..dfff518f17 100644 --- a/paddle/fluid/operators/elementwise_add_op.cu +++ b/paddle/fluid/operators/elementwise_add_op.cu @@ -30,5 +30,4 @@ REGISTER_OP_CUDA_KERNEL( ops::ElementwiseAddGradKernel, ops::ElementwiseAddGradKernel, ops::ElementwiseAddGradKernel, - ops::ElementwiseAddGradKernel, - ops::ElementwiseAddGradKernel); + ops::ElementwiseAddGradKernel); diff --git a/paddle/fluid/operators/elementwise_div_op.cu b/paddle/fluid/operators/elementwise_div_op.cu index 4cc7ba0f43..588d1f7420 100644 --- a/paddle/fluid/operators/elementwise_div_op.cu +++ b/paddle/fluid/operators/elementwise_div_op.cu @@ -14,24 +14,19 @@ limitations under the License. 
*/ #define EIGEN_USE_GPU #include "paddle/fluid/operators/elementwise_div_op.h" -#include "paddle/fluid/platform/float16.h" namespace ops = paddle::operators; -namespace plat = paddle::platform; REGISTER_OP_CUDA_KERNEL( elementwise_div, ops::ElementwiseDivKernel, ops::ElementwiseDivKernel, ops::ElementwiseDivKernel, - ops::ElementwiseDivKernel, - ops::ElementwiseDivKernel); + ops::ElementwiseDivKernel); REGISTER_OP_CUDA_KERNEL( elementwise_div_grad, ops::ElementwiseDivGradKernel, ops::ElementwiseDivGradKernel, ops::ElementwiseDivGradKernel, - ops::ElementwiseDivGradKernel, ops::ElementwiseDivGradKernel); + int64_t>); diff --git a/paddle/fluid/operators/elementwise_mul_op.cu b/paddle/fluid/operators/elementwise_mul_op.cu index 350d43168d..2fb1b4bee6 100644 --- a/paddle/fluid/operators/elementwise_mul_op.cu +++ b/paddle/fluid/operators/elementwise_mul_op.cu @@ -14,25 +14,19 @@ limitations under the License. */ #define EIGEN_USE_GPU #include "paddle/fluid/operators/elementwise_mul_op.h" -#include "paddle/fluid/platform/float16.h" namespace ops = paddle::operators; -namespace plat = paddle::platform; REGISTER_OP_CUDA_KERNEL( elementwise_mul, ops::ElementwiseMulKernel, ops::ElementwiseMulKernel, ops::ElementwiseMulKernel, - ops::ElementwiseMulKernel, - ops::ElementwiseMulKernel); + ops::ElementwiseMulKernel); REGISTER_OP_CUDA_KERNEL( elementwise_mul_grad, ops::ElementwiseMulGradKernel, ops::ElementwiseMulGradKernel, ops::ElementwiseMulGradKernel, - ops::ElementwiseMulGradKernel, ops::ElementwiseMulGradKernel); diff --git a/paddle/fluid/operators/elementwise_op_function.h b/paddle/fluid/operators/elementwise_op_function.h index 7223a972d2..bc3e95e904 100644 --- a/paddle/fluid/operators/elementwise_op_function.h +++ b/paddle/fluid/operators/elementwise_op_function.h @@ -350,7 +350,7 @@ static __global__ void ElemwiseGradBroadcast1CUDAKernel( int j = blockIdx.x; int i = threadIdx.x; int tid = threadIdx.x; - T val(0); + T val = 0; do { int x_offset = i * w + j; @@ -418,7 +418,7 @@ static __global__ void ElemwiseGradBroadcast2CUDAKernel( int tid = threadIdx.x; int j = blockIdx.x; - T val(0); + T val = 0; int ttid = tid; while (true) { diff --git a/paddle/fluid/operators/elementwise_sub_op.cu b/paddle/fluid/operators/elementwise_sub_op.cu index ff3f6f8a2c..8709f686f9 100644 --- a/paddle/fluid/operators/elementwise_sub_op.cu +++ b/paddle/fluid/operators/elementwise_sub_op.cu @@ -14,25 +14,19 @@ limitations under the License. */ #define EIGEN_USE_GPU #include "paddle/fluid/operators/elementwise_sub_op.h" -#include "paddle/fluid/platform/float16.h" namespace ops = paddle::operators; -namespace plat = paddle::platform; REGISTER_OP_CUDA_KERNEL( elementwise_sub, ops::ElementwiseSubKernel, ops::ElementwiseSubKernel, ops::ElementwiseSubKernel, - ops::ElementwiseSubKernel, - ops::ElementwiseSubKernel); + ops::ElementwiseSubKernel); REGISTER_OP_CUDA_KERNEL( elementwise_sub_grad, ops::ElementwiseSubGradKernel, ops::ElementwiseSubGradKernel, ops::ElementwiseSubGradKernel, - ops::ElementwiseSubGradKernel, ops::ElementwiseSubGradKernel); diff --git a/paddle/fluid/operators/fill_constant_op.cc b/paddle/fluid/operators/fill_constant_op.cc index 862249269e..130f18dde4 100644 --- a/paddle/fluid/operators/fill_constant_op.cc +++ b/paddle/fluid/operators/fill_constant_op.cc @@ -12,28 +12,48 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. 
 */
-#include "paddle/fluid/operators/fill_constant_op.h"
-#include "paddle/fluid/platform/float16.h"
+#include "paddle/fluid/framework/data_type.h"
+#include "paddle/fluid/framework/op_registry.h"
+#include "paddle/fluid/operators/math/math_function.h"
+#include "paddle/fluid/platform/device_context.h"
 
 namespace paddle {
 namespace operators {
 
-class FillConstantOp : public framework::OperatorWithKernel {
+class FillConstantInferShape : public framework::InferShapeBase {
  public:
-  using framework::OperatorWithKernel::OperatorWithKernel;
-
-  void InferShape(framework::InferShapeContext* ctx) const override {
+  void operator()(framework::InferShapeContext *ctx) const override {
     PADDLE_ENFORCE(ctx->HasOutput("Out"),
                    "Output(Out) of FillConstantOp should not be null.");
-    auto& shape = ctx->Attrs().Get<std::vector<int>>("shape");
+    auto &shape = ctx->Attrs().Get<std::vector<int>>("shape");
     ctx->SetOutputDim("Out", framework::make_ddim(shape));
   }
+};
+
+class FillConstantOp : public framework::OperatorBase {
+ public:
+  using framework::OperatorBase::OperatorBase;
+
+ private:
+  void RunImpl(const framework::Scope &scope,
+               const platform::Place &dev_place) const override {
+    auto data_type =
+        static_cast<framework::proto::VarType::Type>(Attr<int>("dtype"));
+    auto value = Attr<float>("value");
+    auto force_cpu = Attr<bool>("force_cpu");
+    auto &out =
+        *scope.FindVar(Output("Out"))->GetMutable<framework::LoDTensor>();
+    out.Resize(framework::make_ddim(Attr<std::vector<int>>("shape")));
+    if (force_cpu) {
+      auto cpu = platform::CPUPlace();
+      out.mutable_data(cpu, framework::ToTypeIndex(data_type));
+    } else {
+      out.mutable_data(dev_place, framework::ToTypeIndex(data_type));
+    }
 
-  framework::OpKernelType GetExpectedKernelType(
-      const framework::ExecutionContext& ctx) const override {
-    return framework::OpKernelType(
-        static_cast<framework::proto::VarType::Type>(ctx.Attr<int>("dtype")),
-        ctx.device_context());
+    platform::DeviceContextPool &pool = platform::DeviceContextPool::Instance();
+    auto &dev_ctx = *pool.Get(dev_place);
+    math::set_constant(dev_ctx, &out, value);
   }
 };
 
@@ -67,11 +87,6 @@ Fill up a variable with specified constant value.
 }  // namespace paddle
 
 namespace ops = paddle::operators;
-REGISTER_OPERATOR(fill_constant, ops::FillConstantOp, ops::FillConstantOpMaker,
+REGISTER_OPERATOR(fill_constant, ops::FillConstantOp,
+                  ops::FillConstantInferShape, ops::FillConstantOpMaker,
                   paddle::framework::EmptyGradOpMaker);
-REGISTER_OP_CPU_KERNEL(
-    fill_constant,
-    ops::FillConstantOpKernel<paddle::platform::CPUDeviceContext, float>,
-    ops::FillConstantOpKernel<paddle::platform::CPUDeviceContext, double>,
-    ops::FillConstantOpKernel<paddle::platform::CPUDeviceContext, int64_t>,
-    ops::FillConstantOpKernel<paddle::platform::CPUDeviceContext, int>)
diff --git a/paddle/fluid/operators/fill_constant_op.cu.cc b/paddle/fluid/operators/fill_constant_op.cu.cc
deleted file mode 100644
index 51ccaefa43..0000000000
--- a/paddle/fluid/operators/fill_constant_op.cu.cc
+++ /dev/null
@@ -1,26 +0,0 @@
-// Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
-//
-// Licensed under the Apache License, Version 2.0 (the "License");
-// you may not use this file except in compliance with the License.
-// You may obtain a copy of the License at
-//
-// http://www.apache.org/licenses/LICENSE-2.0
-//
-// Unless required by applicable law or agreed to in writing, software
-// distributed under the License is distributed on an "AS IS" BASIS,
-// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-// See the License for the specific language governing permissions and
-// limitations under the License.
- -#include "paddle/fluid/operators/fill_constant_op.h" -#include "paddle/fluid/platform/float16.h" - -namespace ops = paddle::operators; -REGISTER_OP_CUDA_KERNEL( - fill_constant, - ops::FillConstantOpKernel, - ops::FillConstantOpKernel, - ops::FillConstantOpKernel, - ops::FillConstantOpKernel, - ops::FillConstantOpKernel) diff --git a/paddle/fluid/operators/fill_constant_op.h b/paddle/fluid/operators/fill_constant_op.h deleted file mode 100644 index b2a2a7b2fa..0000000000 --- a/paddle/fluid/operators/fill_constant_op.h +++ /dev/null @@ -1,48 +0,0 @@ -// Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -#pragma once -#include - -#include "paddle/fluid/framework/data_type.h" -#include "paddle/fluid/framework/op_registry.h" -#include "paddle/fluid/operators/math/math_function.h" - -namespace paddle { -namespace operators { - -template -class FillConstantOpKernel : public framework::OpKernel { - public: - void Compute(const framework::ExecutionContext& ctx) const override { - auto data_type = - static_cast(ctx.Attr("dtype")); - auto value = ctx.Attr("value"); - auto force_cpu = ctx.Attr("force_cpu"); - auto* out = ctx.Output("Out"); - out->Resize(framework::make_ddim(ctx.Attr>("shape"))); - if (force_cpu) { - auto cpu = platform::CPUPlace(); - out->mutable_data(cpu, framework::ToTypeIndex(data_type)); - } else { - out->mutable_data(ctx.GetPlace(), framework::ToTypeIndex(data_type)); - } - - math::set_constant(ctx.template device_context(), out, - value); - } -}; - -} // namespace operators -} // namespace paddle diff --git a/paddle/fluid/operators/fill_op.cc b/paddle/fluid/operators/fill_op.cc index 352a17c927..925dc19061 100644 --- a/paddle/fluid/operators/fill_op.cc +++ b/paddle/fluid/operators/fill_op.cc @@ -16,7 +16,6 @@ limitations under the License. */ #include "paddle/fluid/framework/op_registry.h" #include "paddle/fluid/operators/detail/safe_ref.h" #include "paddle/fluid/platform/device_context.h" -#include "paddle/fluid/platform/float16.h" namespace paddle { namespace operators { @@ -70,6 +69,7 @@ class FillOp : public framework::OperatorBase { framework::VisitDataType( dtype, FillOpVisitor(&tensor, Attr>("value"))); + if (!force_cpu && platform::is_gpu_place(place)) { // Copy tensor to out platform::DeviceContextPool &pool = diff --git a/paddle/fluid/operators/gaussian_random_op.cu b/paddle/fluid/operators/gaussian_random_op.cu index b490723795..7784856417 100644 --- a/paddle/fluid/operators/gaussian_random_op.cu +++ b/paddle/fluid/operators/gaussian_random_op.cu @@ -15,7 +15,6 @@ limitations under the License. 
*/ #include #include "paddle/fluid/framework/op_registry.h" #include "paddle/fluid/framework/operator.h" -#include "paddle/fluid/platform/float16.h" namespace paddle { namespace operators { @@ -61,7 +60,6 @@ class GPUGaussianRandomKernel : public framework::OpKernel { } // namespace operators } // namespace paddle -namespace plat = paddle::platform; REGISTER_OP_CUDA_KERNEL(gaussian_random, paddle::operators::GPUGaussianRandomKernel, paddle::operators::GPUGaussianRandomKernel); diff --git a/paddle/fluid/operators/math/cross_entropy.cu b/paddle/fluid/operators/math/cross_entropy.cu index 58b85abf82..0de58d5fdd 100644 --- a/paddle/fluid/operators/math/cross_entropy.cu +++ b/paddle/fluid/operators/math/cross_entropy.cu @@ -15,25 +15,11 @@ limitations under the License. */ #include "paddle/fluid/operators/math/cross_entropy.h" #include "paddle/fluid/platform/cuda_device_function.h" #include "paddle/fluid/platform/cuda_primitives.h" -#include "paddle/fluid/platform/float16.h" namespace paddle { namespace operators { namespace math { -template -HOSTDEVICE T log(const T& val) { - return std::log(val); -} - -template <> -HOSTDEVICE platform::float16 log(const platform::float16& val) { - // strage bug, hlog is not exists. - return static_cast(0); - // half tmp = static_cast(val); - // return static_cast(hlog(tmp)); -} - namespace { template __global__ void CrossEntropyKernel(T* Y, const T* X, const int64_t* label, @@ -49,12 +35,12 @@ template __global__ void SoftCrossEntropyKernel(T* Y, const T* X, const T* label, const int class_num) { int tid = threadIdx.x; - T val(0); + T val = 0; int idx = blockIdx.x * class_num + tid; int end = blockIdx.x * class_num + class_num; for (; idx < end; idx += blockDim.x) { - val += math::TolerableValue()(log(X[idx])) * label[idx]; + val += math::TolerableValue()(std::log(X[idx])) * label[idx]; } val = paddle::platform::reduceSum(val, tid, blockDim.x); @@ -98,8 +84,6 @@ class CrossEntropyFunctor { template class CrossEntropyFunctor; template class CrossEntropyFunctor; -template class CrossEntropyFunctor; } // namespace math } // namespace operators } // namespace paddle diff --git a/paddle/fluid/operators/math/cross_entropy.h b/paddle/fluid/operators/math/cross_entropy.h index 2e4e4781c2..adc5b3fe47 100644 --- a/paddle/fluid/operators/math/cross_entropy.h +++ b/paddle/fluid/operators/math/cross_entropy.h @@ -13,10 +13,8 @@ See the License for the specific language governing permissions and limitations under the License. */ #pragma once -#include #include "paddle/fluid/framework/eigen.h" #include "paddle/fluid/framework/tensor.h" -#include "paddle/fluid/platform/float16.h" #include "paddle/fluid/platform/hostdevice.h" namespace paddle { @@ -35,21 +33,6 @@ struct TolerableValue { } }; -// float16 value clip behave different. -using paddle::platform::float16; -using paddle::platform::isfinite; -template <> -struct TolerableValue { - HOSTDEVICE float16 operator()(const float16& x) const { - if (isfinite(x)) - return x; - else if (x > static_cast(0)) - return std::numeric_limits::max(); - else - return std::numeric_limits::min(); - } -}; - template class CrossEntropyFunctor { public: diff --git a/paddle/fluid/operators/math/selected_rows_functor.cu b/paddle/fluid/operators/math/selected_rows_functor.cu index 00dbfc11a2..a92762c7fe 100644 --- a/paddle/fluid/operators/math/selected_rows_functor.cu +++ b/paddle/fluid/operators/math/selected_rows_functor.cu @@ -18,7 +18,6 @@ limitations under the License. 
*/ #include "paddle/fluid/operators/math/math_function.h" #include "paddle/fluid/operators/math/selected_rows_functor.h" #include "paddle/fluid/platform/cuda_primitives.h" -#include "paddle/fluid/platform/float16.h" namespace paddle { namespace operators { @@ -77,7 +76,6 @@ struct SelectedRowsAdd { template struct SelectedRowsAdd; template struct SelectedRowsAdd; -template struct SelectedRowsAdd; namespace { template @@ -122,7 +120,7 @@ struct SelectedRowsAddTensor { auto* out_data = output->data(); SetConstant functor; - functor(context, output, static_cast(0)); + functor(context, output, 0.0); const int block_size = 256; dim3 threads(block_size, 1); @@ -140,8 +138,6 @@ struct SelectedRowsAddTensor { template struct SelectedRowsAddTensor; template struct SelectedRowsAddTensor; -template struct SelectedRowsAddTensor; template struct SelectedRowsAddTo { @@ -181,8 +177,6 @@ template struct SelectedRowsAddTo; template struct SelectedRowsAddTo; template struct SelectedRowsAddTo; template struct SelectedRowsAddTo; -template struct SelectedRowsAddTo; namespace { template @@ -235,8 +229,6 @@ template struct SelectedRowsAddToTensor; template struct SelectedRowsAddToTensor; template struct SelectedRowsAddToTensor; template struct SelectedRowsAddToTensor; -template struct SelectedRowsAddToTensor; namespace scatter { @@ -284,7 +276,7 @@ struct MergeAdd { context.GetPlace()); math::SetConstant constant_functor; - constant_functor(context, out.mutable_value(), static_cast(0)); + constant_functor(context, out.mutable_value(), 0.0); auto* out_data = out.mutable_value()->data(); auto* input_data = input.value().data(); @@ -308,7 +300,6 @@ template struct MergeAdd; template struct MergeAdd; template struct MergeAdd; template struct MergeAdd; -template struct MergeAdd; template __global__ void UpdateToTensorKernel(const T* selected_rows, diff --git a/paddle/fluid/operators/math/softmax.cu b/paddle/fluid/operators/math/softmax.cu index 785c4baecb..3effe77625 100644 --- a/paddle/fluid/operators/math/softmax.cu +++ b/paddle/fluid/operators/math/softmax.cu @@ -94,15 +94,12 @@ void SoftmaxGradCUDNNFunctor::operator()( template class SoftmaxCUDNNFunctor; template class SoftmaxCUDNNFunctor; template class SoftmaxCUDNNFunctor; -template class SoftmaxGradCUDNNFunctor; template class SoftmaxGradCUDNNFunctor; template class SoftmaxGradCUDNNFunctor; template class SoftmaxFunctor; template class SoftmaxFunctor; template class SoftmaxFunctor; -template class SoftmaxGradFunctor; template class SoftmaxGradFunctor; template class SoftmaxGradFunctor; diff --git a/paddle/fluid/operators/mean_op.cu b/paddle/fluid/operators/mean_op.cu index 07aa23754f..91e0ab28ef 100644 --- a/paddle/fluid/operators/mean_op.cu +++ b/paddle/fluid/operators/mean_op.cu @@ -12,16 +12,14 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. 
*/ +#define EIGEN_USE_GPU + #include "paddle/fluid/operators/mean_op.h" -#include "paddle/fluid/platform/float16.h" namespace ops = paddle::operators; -namespace plat = paddle::platform; REGISTER_OP_CUDA_KERNEL( mean, ops::MeanKernel, - ops::MeanKernel, - ops::MeanKernel); + ops::MeanKernel); REGISTER_OP_CUDA_KERNEL( mean_grad, ops::MeanGradKernel, - ops::MeanGradKernel, - ops::MeanGradKernel); + ops::MeanGradKernel); diff --git a/paddle/fluid/operators/mean_op.h b/paddle/fluid/operators/mean_op.h index a41d50ae0b..362e9f9ae8 100644 --- a/paddle/fluid/operators/mean_op.h +++ b/paddle/fluid/operators/mean_op.h @@ -55,7 +55,7 @@ class MeanGradKernel : public framework::OpKernel { IG->mutable_data(context.GetPlace()); T ig_size = static_cast(IG->numel()); - Eigen::DSizes bcast(static_cast(ig_size)); + Eigen::DSizes bcast(ig_size); EigenVector::Flatten(*IG).device( *context.template device_context().eigen_device()) = diff --git a/paddle/fluid/operators/mul_op.cu.cc b/paddle/fluid/operators/mul_op.cu.cc index 6c5a83c6a5..81f3e42bf4 100644 --- a/paddle/fluid/operators/mul_op.cu.cc +++ b/paddle/fluid/operators/mul_op.cu.cc @@ -20,7 +20,6 @@ namespace plat = paddle::platform; REGISTER_OP_CUDA_KERNEL(mul, ops::MulKernel, ops::MulKernel, ops::MulKernel); -REGISTER_OP_CUDA_KERNEL( - mul_grad, ops::MulGradKernel, - ops::MulGradKernel, - ops::MulGradKernel); +REGISTER_OP_CUDA_KERNEL(mul_grad, + ops::MulGradKernel, + ops::MulGradKernel); diff --git a/paddle/fluid/operators/pool_cudnn_op.cu.cc b/paddle/fluid/operators/pool_cudnn_op.cu.cc index 9fdbee818a..31f083565f 100644 --- a/paddle/fluid/operators/pool_cudnn_op.cu.cc +++ b/paddle/fluid/operators/pool_cudnn_op.cu.cc @@ -174,8 +174,7 @@ REGISTER_OP_KERNEL(pool2d, CUDNN, plat::CUDAPlace, ops::PoolCUDNNOpKernel); REGISTER_OP_KERNEL(pool2d_grad, CUDNN, plat::CUDAPlace, ops::PoolCUDNNGradOpKernel, - ops::PoolCUDNNGradOpKernel, - ops::PoolCUDNNGradOpKernel); + ops::PoolCUDNNGradOpKernel); REGISTER_OP_KERNEL(pool3d, CUDNN, plat::CUDAPlace, ops::PoolCUDNNOpKernel, @@ -183,5 +182,4 @@ REGISTER_OP_KERNEL(pool3d, CUDNN, plat::CUDAPlace, ops::PoolCUDNNOpKernel); REGISTER_OP_KERNEL(pool3d_grad, CUDNN, plat::CUDAPlace, ops::PoolCUDNNGradOpKernel, - ops::PoolCUDNNGradOpKernel, - ops::PoolCUDNNGradOpKernel); + ops::PoolCUDNNGradOpKernel); diff --git a/paddle/fluid/operators/scale_op.cu b/paddle/fluid/operators/scale_op.cu index d266867046..04c802da12 100644 --- a/paddle/fluid/operators/scale_op.cu +++ b/paddle/fluid/operators/scale_op.cu @@ -13,15 +13,11 @@ See the License for the specific language governing permissions and limitations under the License. 
*/ #include "paddle/fluid/operators/scale_op.h" -#include "paddle/fluid/platform/float16.h" -namespace plat = paddle::platform; REGISTER_OP_CUDA_KERNEL( scale, paddle::operators::ScaleKernel, paddle::operators::ScaleKernel, paddle::operators::ScaleKernel, paddle::operators::ScaleKernel, - paddle::operators::ScaleKernel); + int64_t>); diff --git a/paddle/fluid/operators/softmax_cudnn_op.cu.cc b/paddle/fluid/operators/softmax_cudnn_op.cu.cc index c2d45c3d2e..2bdb23e999 100644 --- a/paddle/fluid/operators/softmax_cudnn_op.cu.cc +++ b/paddle/fluid/operators/softmax_cudnn_op.cu.cc @@ -78,5 +78,4 @@ REGISTER_OP_KERNEL(softmax, CUDNN, plat::CUDAPlace, ops::SoftmaxCUDNNKernel, ops::SoftmaxCUDNNKernel); REGISTER_OP_KERNEL(softmax_grad, CUDNN, plat::CUDAPlace, - ops::SoftmaxGradCUDNNKernel, - ops::SoftmaxGradCUDNNKernel); + ops::SoftmaxGradCUDNNKernel); diff --git a/paddle/fluid/operators/softmax_op.cu.cc b/paddle/fluid/operators/softmax_op.cu.cc index 19359b7eef..5fb4f011d9 100644 --- a/paddle/fluid/operators/softmax_op.cu.cc +++ b/paddle/fluid/operators/softmax_op.cu.cc @@ -23,5 +23,4 @@ REGISTER_OP_CUDA_KERNEL( ops::SoftmaxKernel); REGISTER_OP_CUDA_KERNEL( softmax_grad, ops::SoftmaxGradKernel, - ops::SoftmaxGradKernel, - ops::SoftmaxGradKernel); + ops::SoftmaxGradKernel); diff --git a/paddle/fluid/operators/sum_op.cu b/paddle/fluid/operators/sum_op.cu index db4c2d6c11..89bcd1bbc8 100644 --- a/paddle/fluid/operators/sum_op.cu +++ b/paddle/fluid/operators/sum_op.cu @@ -11,13 +11,10 @@ limitations under the License. */ #define EIGEN_USE_GPU #include "paddle/fluid/operators/sum_op.h" -#include "paddle/fluid/platform/float16.h" namespace ops = paddle::operators; -namespace plat = paddle::platform; REGISTER_OP_CUDA_KERNEL( sum, ops::SumKernel, ops::SumKernel, ops::SumKernel, - ops::SumKernel, - ops::SumKernel); + ops::SumKernel); diff --git a/paddle/fluid/operators/sum_op.h b/paddle/fluid/operators/sum_op.h index dda6772796..49a4afb3a8 100644 --- a/paddle/fluid/operators/sum_op.h +++ b/paddle/fluid/operators/sum_op.h @@ -46,7 +46,7 @@ class SumKernel : public framework::OpKernel { if (!in_place) { math::SetConstant constant_functor; constant_functor(context.template device_context(), out, - static_cast(0)); + 0.0); } math::SelectedRowsAddToTensor functor; diff --git a/paddle/fluid/operators/top_k_op.cu b/paddle/fluid/operators/top_k_op.cu index 5fc0784f66..9da8551eb2 100644 --- a/paddle/fluid/operators/top_k_op.cu +++ b/paddle/fluid/operators/top_k_op.cu @@ -11,19 +11,16 @@ distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. 
*/ -#include #include "paddle/fluid/framework/op_registry.h" #include "paddle/fluid/operators/top_k_op.h" #include "paddle/fluid/platform/assert.h" #include "paddle/fluid/platform/cuda_device_function.h" -#include "paddle/fluid/platform/float16.h" namespace paddle { namespace operators { using Tensor = framework::Tensor; -using paddle::platform::float16; template struct Pair { @@ -35,11 +32,6 @@ struct Pair { id = id; } - __device__ __forceinline__ void clear() { - v = -INFINITY; - id = -1; - } - __device__ __forceinline__ void operator=(const Pair& in) { v = in.v; id = in.id; @@ -61,12 +53,6 @@ struct Pair { int64_t id; }; -template <> -__device__ __forceinline__ void Pair::clear() { - v = platform::raw_uint16_to_float16(0x400); - id = -1; -} - template __device__ __forceinline__ void AddTo(Pair topk[], const Pair& p, int beam_size) { @@ -164,7 +150,7 @@ __device__ __forceinline__ void ThreadGetTopK(Pair topk[], int* beam, if (k < MaxLength - (*beam)) { topk[k] = topk[k + *beam]; } else { - topk[k].clear(); + topk[k].set(-INFINITY, -1); } } if (!(*is_empty)) { @@ -174,7 +160,7 @@ __device__ __forceinline__ void ThreadGetTopK(Pair topk[], int* beam, } *max = topk[MaxLength - 1]; - if ((*max).v == static_cast(-1)) *is_empty = true; + if ((*max).v == -1) *is_empty = true; *beam = 0; } } @@ -195,7 +181,7 @@ __device__ __forceinline__ void ThreadGetTopK(Pair topk[], int* beam, if (k < MaxLength - *beam) { topk[k] = topk[k + *beam]; } else { - topk[k].set(std::numeric_limits::min(), -1); + topk[k].set(-INFINITY, -1); } } if (!(*is_empty)) { @@ -287,7 +273,7 @@ __global__ void KeMatrixTopK(T* output, int output_stride, int64_t* indices, bool firststep = true; for (int k = 0; k < MaxLength; k++) { - topk[k].clear(); + topk[k].set(-INFINITY, -1); } while (k) { ThreadGetTopK(topk, &beam, k, @@ -339,7 +325,5 @@ class TopkOpCUDAKernel : public framework::OpKernel { } // namespace operators } // namespace paddle -REGISTER_OP_CUDA_KERNEL( - top_k, paddle::operators::TopkOpCUDAKernel, - paddle::operators::TopkOpCUDAKernel, - paddle::operators::TopkOpCUDAKernel); +REGISTER_OP_CUDA_KERNEL(top_k, paddle::operators::TopkOpCUDAKernel, + paddle::operators::TopkOpCUDAKernel); diff --git a/paddle/fluid/operators/uniform_random_op.cu b/paddle/fluid/operators/uniform_random_op.cu index 2b8039a0c1..e1c7323a30 100644 --- a/paddle/fluid/operators/uniform_random_op.cu +++ b/paddle/fluid/operators/uniform_random_op.cu @@ -11,14 +11,10 @@ distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. */ -#include #include #include -#include "paddle/fluid/framework/eigen.h" #include "paddle/fluid/framework/op_registry.h" #include "paddle/fluid/framework/operator.h" -#include "paddle/fluid/platform/float16.h" -#include "paddle/fluid/platform/transform.h" namespace paddle { namespace operators { @@ -40,11 +36,6 @@ struct UniformGenerator { } }; -template -struct CastFunctor { - HOSTDEVICE V operator()(const T& a) { return static_cast(a); } -}; - // It seems that Eigen::Tensor::random in GPU will SEGFAULT. // Use std::random and thrust::random(thrust is a std library in CUDA) to // implement uniform random. 
@@ -75,50 +66,18 @@ class GPUUniformRandomKernel : public framework::OpKernel<T> {
     T max = static_cast<T>(context.Attr<float>("max"));
     thrust::counting_iterator<unsigned int> index_sequence_begin(0);
     int64_t size = tensor->numel();
-    if (out_var->IsType<framework::LoDTensor>() &&
-        std::type_index(typeid(T)) ==
-            std::type_index(typeid(platform::float16))) {
-      framework::Tensor master_copy_tensor;
-      master_copy_tensor.Resize(tensor->dims());
-      float* master_copy_tensor_data =
-          master_copy_tensor.mutable_data<float>(context.GetPlace());
-      thrust::transform(index_sequence_begin, index_sequence_begin + size,
-                        thrust::device_ptr<float>(master_copy_tensor_data),
-                        UniformGenerator<float>(static_cast<float>(min),
-                                                static_cast<float>(max), seed));
-      platform::Transform<platform::CUDADeviceContext> trans;
-      auto* in_begin = master_copy_tensor.data<float>();
-      auto* in_end = in_begin + master_copy_tensor.numel();
-      auto* out_begin = tensor->mutable_data<T>(context.GetPlace());
-      trans(context.template device_context<platform::CUDADeviceContext>(),
-            in_begin, in_end, out_begin, CastFunctor<float, T>());
-    } else {
-      thrust::transform(index_sequence_begin, index_sequence_begin + size,
-                        thrust::device_ptr<T>(data),
-                        UniformGenerator<T>(min, max, seed));
-    }
-    if (VLOG_IS_ON(5)) {
-      framework::Tensor cpu_tensor;
-      framework::TensorCopySync(*tensor, platform::CPUPlace(), &cpu_tensor);
-      auto& dev_ctx =
-          *platform::DeviceContextPool::Instance().Get(context.GetPlace());
-      dev_ctx.Wait();
-      auto x = framework::EigenVector<T>::Flatten(cpu_tensor);
-      VLOG(5) << "The Uniform output " << x;
-    }
+    thrust::transform(index_sequence_begin, index_sequence_begin + size,
+                      thrust::device_ptr<T>(data),
+                      UniformGenerator<T>(min, max, seed));
   }
 };
 
 }  // namespace operators
 }  // namespace paddle
 
-namespace plat = paddle::platform;
-REGISTER_OP_CUDA_KERNEL(
-    uniform_random, paddle::operators::GPUUniformRandomKernel<float>,
-    paddle::operators::GPUUniformRandomKernel<double>,
-    paddle::operators::GPUUniformRandomKernel<plat::float16>);
-REGISTER_OP_CUDA_KERNEL(
-    uniform_random_batch_size_like,
-    paddle::operators::GPUUniformRandomKernel<float>,
-    paddle::operators::GPUUniformRandomKernel<double>,
-    paddle::operators::GPUUniformRandomKernel<plat::float16>);
+REGISTER_OP_CUDA_KERNEL(uniform_random,
+                        paddle::operators::GPUUniformRandomKernel<float>,
+                        paddle::operators::GPUUniformRandomKernel<double>);
+REGISTER_OP_CUDA_KERNEL(uniform_random_batch_size_like,
+                        paddle::operators::GPUUniformRandomKernel<float>,
+                        paddle::operators::GPUUniformRandomKernel<double>);
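After the revert, the retained UniformGenerator functor still generates each
element with a per-index scheme: re-seed a lightweight engine with the user
seed, skip the first n draws, and take draw n, so the value at index n does
not depend on how thrust partitions the index range across GPU threads. A
host-side C++ analogue, for illustration only (the function name FillUniform
is invented, not Paddle API):

#include <cstddef>
#include <random>
#include <vector>

// Mirrors UniformGenerator<T>::operator()(n): seed, discard n, draw once.
void FillUniform(std::vector<float>* out, float min, float max,
                 unsigned int seed) {
  for (std::size_t n = 0; n < out->size(); ++n) {
    std::minstd_rand rng(seed);
    rng.discard(n);  // jump to the n-th draw of the stream
    std::uniform_real_distribution<float> dist(min, max);
    (*out)[n] = dist(rng);
  }
}

// Usage: std::vector<float> v(8); FillUniform(&v, -1.0f, 1.0f, 0);

The per-element re-seeding is what keeps the output deterministic for a
fixed seed and independent of the kernel's launch configuration; a single
engine shared across threads would not be.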