From 43cee33a23cb1dce5501f0642a38f97dad8cea45 Mon Sep 17 00:00:00 2001 From: tensor-tang Date: Thu, 2 Aug 2018 15:38:44 +0800 Subject: [PATCH 01/29] add mkl packed gemm --- paddle/fluid/operators/math/blas.h | 37 +++++++++++++ paddle/fluid/operators/math/blas_impl.h | 73 +++++++++++++++++++++++++ paddle/fluid/platform/dynload/mklml.h | 8 +++ 3 files changed, 118 insertions(+) diff --git a/paddle/fluid/operators/math/blas.h b/paddle/fluid/operators/math/blas.h index 70f88f24f6..2470df9d78 100644 --- a/paddle/fluid/operators/math/blas.h +++ b/paddle/fluid/operators/math/blas.h @@ -90,6 +90,23 @@ class Blas { void GEMM(bool transA, bool transB, int M, int N, int K, T alpha, const T* A, int lda, const T* B, int ldb, T beta, T* C, int ldc) const; + template + T* GEMM_ALLOC(const CBLAS_IDENTIFIER id, const int M, const int N, + const int K) const; + + template + void GEMM_PACK(const CBLAS_IDENTIFIER id, const CBLAS_TRANSPOSE trans, int M, + int N, int K, const T alpha, const T* src, const int ld, + T* dst) const; + + template + void GEMM_COMPUTE(int transA, int transB, int M, int N, int K, const T* A, + const int lda, const T* B, const int ldb, T beta, T* C, + const int ldc) const; + + template + void GEMM_FREE(T* data) const; + template void MatMul(const framework::Tensor& mat_a, bool trans_a, const framework::Tensor& mat_b, bool trans_b, T alpha, @@ -146,6 +163,26 @@ class BlasT : private Blas { Base()->template GEMM(args...); } + template + T* GEMM_ALLOC(ARGS... args) const { + Base()->template GEMM_ALLOC(args...); + } + + template + void GEMM_PACK(ARGS... args) const { + Base()->template GEMM_PACK(args...); + } + + template + void GEMM_COMPUTE(ARGS... args) const { + Base()->template GEMM_COMPUTE(args...); + } + + template + void GEMM_FREE(ARGS... args) const { + Base()->template GEMM_FREE(args...); + } + template void MatMul(ARGS... args) const { Base()->template MatMul(args...); diff --git a/paddle/fluid/operators/math/blas_impl.h b/paddle/fluid/operators/math/blas_impl.h index a0802ef90c..4164fe6229 100644 --- a/paddle/fluid/operators/math/blas_impl.h +++ b/paddle/fluid/operators/math/blas_impl.h @@ -31,6 +31,26 @@ struct CBlas { platform::dynload::cblas_sgemm(args...); } + template + static float *GEMM_ALLOC(ARGS... args) { + return platform::dynload::cblas_sgemm_alloc(args...); + } + + template + static void GEMM_PACK(ARGS... args) { + platform::dynload::cblas_sgemm_pack(args...); + } + + template + static void GEMM_COMPUTE(ARGS... args) { + platform::dynload::cblas_sgemm_compute(args...); + } + + template + static void GEMM_FREE(ARGS... args) { + platform::dynload::cblas_sgemm_free(args...); + } + #ifdef PADDLE_WITH_LIBXSMM template static void SMM_GEMM(ARGS... args) { @@ -71,6 +91,26 @@ struct CBlas { platform::dynload::cblas_dgemm(args...); } + template + static double *GEMM_ALLOC(ARGS... args) { + return platform::dynload::cblas_dgemm_alloc(args...); + } + + template + static void GEMM_PACK(ARGS... args) { + platform::dynload::cblas_dgemm_pack(args...); + } + + template + static void GEMM_COMPUTE(ARGS... args) { + platform::dynload::cblas_dgemm_compute(args...); + } + + template + static void GEMM_FREE(ARGS... args) { + platform::dynload::cblas_dgemm_free(args...); + } + #ifdef PADDLE_WITH_LIBXSMM template static void SMM_GEMM(ARGS... 
args) { @@ -224,6 +264,39 @@ inline void GEMM_WARP(CBLAS_ORDER order, CBLAS_TRANSPOSE transA, beta, C, ldc); } +template <> +template +T *Blas::GEMM_ALLOC(const CBLAS_IDENTIFIER id, + const int M, const int N, + const int K) const { + return CBlas::GEMM_ALLOC(id, M, N, K); +} + +template <> +template +void Blas::GEMM_PACK(const CBLAS_IDENTIFIER id, + const CBLAS_TRANSPOSE trans, + int M, int N, int K, + const T alpha, const T *src, + const int ld, T *dst) const { + CBlas::GEMM_PACK(CblasRowMajor, id, trans, M, N, K, alpha, src, ld, dst); +} + +template <> +template +void Blas::GEMM_COMPUTE( + int transA, int transB, int M, int N, int K, const T *A, const int lda, + const T *B, const int ldb, T beta, T *C, const int ldc) const { + CBlas::GEMM_COMPUTE(CblasRowMajor, transA, transB, M, N, K, A, lda, B, ldb, + beta, C, ldc); +} + +template <> +template +void Blas::GEMM_FREE(T *data) const { + CBlas::GEMM_FREE(data); +} + template <> template void Blas::GEMM(CBLAS_TRANSPOSE transA, diff --git a/paddle/fluid/platform/dynload/mklml.h b/paddle/fluid/platform/dynload/mklml.h index 17acefe8cd..9e7a616094 100644 --- a/paddle/fluid/platform/dynload/mklml.h +++ b/paddle/fluid/platform/dynload/mklml.h @@ -60,6 +60,14 @@ extern void* mklml_dso_handle; __macro(cblas_dgemm_batch); \ __macro(vsAdd); \ __macro(vdAdd); \ + __macro(cblas_sgemm_alloc); \ + __macro(cblas_sgemm_pack); \ + __macro(cblas_sgemm_compute); \ + __macro(cblas_sgemm_free); \ + __macro(cblas_dgemm_alloc); \ + __macro(cblas_dgemm_pack); \ + __macro(cblas_dgemm_compute); \ + __macro(cblas_dgemm_free); \ __macro(MKL_Set_Num_Threads) MKLML_ROUTINE_EACH(DECLARE_DYNAMIC_LOAD_MKLML_WRAP); From d9cc6b18662295383f925e12b6a5e0cf5dabd14a Mon Sep 17 00:00:00 2001 From: tensor-tang Date: Fri, 3 Aug 2018 13:31:53 +0800 Subject: [PATCH 02/29] replace gru compute with details --- paddle/fluid/operators/gru_op.h | 28 +++++++++++++++++++++++++--- 1 file changed, 25 insertions(+), 3 deletions(-) diff --git a/paddle/fluid/operators/gru_op.h b/paddle/fluid/operators/gru_op.h index 3b0d93e54b..4e534789ce 100644 --- a/paddle/fluid/operators/gru_op.h +++ b/paddle/fluid/operators/gru_op.h @@ -16,7 +16,10 @@ limitations under the License. 
*/ #include #include "paddle/fluid/framework/eigen.h" #include "paddle/fluid/framework/op_registry.h" +#include "paddle/fluid/operators/math/blas.h" #include "paddle/fluid/operators/math/detail/activation_functions.h" +#include "paddle/fluid/operators/math/detail/gru_cpu_kernel.h" +#include "paddle/fluid/operators/math/detail/gru_kernel.h" #include "paddle/fluid/operators/math/gru_compute.h" #include "paddle/fluid/operators/math/math_function.h" #include "paddle/fluid/operators/math/sequence2batch.h" @@ -94,6 +97,7 @@ class GRUKernel : public framework::OpKernel { context.Attr("activation")); auto active_gate = math::detail::GetActivationType( context.Attr("gate_activation")); + auto blas = math::GetBlas(dev_ctx); for (size_t n = 0; n < num_batch; n++) { int bstart = static_cast(batch_starts[n]); int bend = static_cast(batch_starts[n + 1]); @@ -105,9 +109,27 @@ class GRUKernel : public framework::OpKernel { gru_value.output_value = hidden_t.data(); gru_value.gate_value = gate_t.data(); gru_value.reset_output_value = reset_hidden_prev_t.data(); - math::GRUUnitFunctor::compute( - dev_ctx, gru_value, frame_size, cur_batch_size, active_node, - active_gate); + if (gru_value.prev_out_value) { + blas.GEMM(false, false, cur_batch_size, frame_size * 2, frame_size, 1, + gru_value.prev_out_value, frame_size, gru_value.gate_weight, + frame_size * 2, 1, gru_value.gate_value, frame_size * 3); + } + + math::detail::forward_reset_output( + math::detail::forward::gru_resetOutput(), gru_value, frame_size, + cur_batch_size, active_gate); + + if (gru_value.prev_out_value) { + blas.GEMM(false, false, cur_batch_size, frame_size, frame_size, 1, + gru_value.reset_output_value, frame_size, + gru_value.state_weight, frame_size, 1, + gru_value.gate_value + frame_size * 2, frame_size * 3); + } + + math::detail::forward_final_output( + math::detail::forward::gru_finalOutput(), gru_value, frame_size, + cur_batch_size, active_node); + gru_value.prev_out_value = gru_value.output_value; } From 8c23f7c4f029ba3b22481ae27b721b7a4ac18e8b Mon Sep 17 00:00:00 2001 From: tensor-tang Date: Fri, 3 Aug 2018 18:44:36 +0800 Subject: [PATCH 03/29] fix blas and use packed weight --- paddle/fluid/operators/gru_op.h | 34 ++++++++++++++++++++++++------ paddle/fluid/operators/math/blas.h | 2 +- 2 files changed, 28 insertions(+), 8 deletions(-) diff --git a/paddle/fluid/operators/gru_op.h b/paddle/fluid/operators/gru_op.h index 4e534789ce..a9450337e7 100644 --- a/paddle/fluid/operators/gru_op.h +++ b/paddle/fluid/operators/gru_op.h @@ -98,6 +98,23 @@ class GRUKernel : public framework::OpKernel { auto active_gate = math::detail::GetActivationType( context.Attr("gate_activation")); auto blas = math::GetBlas(dev_ctx); + + // TODO(TJ): make a class, make one pack + T* packed_gate = blas.GEMM_ALLOC(CblasBMatrix, 1 /*height of C*/, + frame_size * 2 /*width of weight*/, + frame_size /*height of height*/); + PADDLE_ENFORCE(packed_gate); + blas.GEMM_PACK(CblasBMatrix, CblasNoTrans, 1 /*cur bs?*/, frame_size * 2, + frame_size, T(1.0), gru_value.gate_weight, frame_size * 2, + packed_gate); + T* packed_state = blas.GEMM_ALLOC(CblasBMatrix, 1 /*height of C*/, + frame_size /*width of weight*/, + frame_size /*height of height*/); + PADDLE_ENFORCE(packed_state); + blas.GEMM_PACK(CblasBMatrix, CblasNoTrans, 1 /*cur bs?*/, frame_size, + frame_size, T(1.0), gru_value.state_weight, frame_size, + packed_state); + for (size_t n = 0; n < num_batch; n++) { int bstart = static_cast(batch_starts[n]); int bend = static_cast(batch_starts[n + 1]); @@ -110,9 +127,10 @@ 
class GRUKernel : public framework::OpKernel { gru_value.gate_value = gate_t.data(); gru_value.reset_output_value = reset_hidden_prev_t.data(); if (gru_value.prev_out_value) { - blas.GEMM(false, false, cur_batch_size, frame_size * 2, frame_size, 1, - gru_value.prev_out_value, frame_size, gru_value.gate_weight, - frame_size * 2, 1, gru_value.gate_value, frame_size * 3); + blas.GEMM_COMPUTE(CblasNoTrans, CblasPacked, cur_batch_size, + frame_size * 2, frame_size, gru_value.prev_out_value, + frame_size, packed_gate, frame_size * 2, T(1), + gru_value.gate_value, frame_size * 3); } math::detail::forward_reset_output( @@ -120,10 +138,10 @@ class GRUKernel : public framework::OpKernel { cur_batch_size, active_gate); if (gru_value.prev_out_value) { - blas.GEMM(false, false, cur_batch_size, frame_size, frame_size, 1, - gru_value.reset_output_value, frame_size, - gru_value.state_weight, frame_size, 1, - gru_value.gate_value + frame_size * 2, frame_size * 3); + blas.GEMM_COMPUTE( + CblasNoTrans, CblasPacked, cur_batch_size, frame_size, frame_size, + gru_value.reset_output_value, frame_size, packed_state, frame_size, + T(1), gru_value.gate_value + frame_size * 2, frame_size * 3); } math::detail::forward_final_output( @@ -132,6 +150,8 @@ class GRUKernel : public framework::OpKernel { gru_value.prev_out_value = gru_value.output_value; } + blas.GEMM_FREE(packed_gate); + blas.GEMM_FREE(packed_state); math::Batch2LoDTensorFunctor to_seq; batch_hidden->set_lod(batch_gate->lod()); diff --git a/paddle/fluid/operators/math/blas.h b/paddle/fluid/operators/math/blas.h index 2470df9d78..485e96227e 100644 --- a/paddle/fluid/operators/math/blas.h +++ b/paddle/fluid/operators/math/blas.h @@ -165,7 +165,7 @@ class BlasT : private Blas { template T* GEMM_ALLOC(ARGS... args) const { - Base()->template GEMM_ALLOC(args...); + return Base()->template GEMM_ALLOC(args...); } template From 54c95e49f09e70233adb363b5b612cb8d427c116 Mon Sep 17 00:00:00 2001 From: tensor-tang Date: Mon, 6 Aug 2018 11:34:11 +0800 Subject: [PATCH 04/29] fix blas --- paddle/fluid/operators/math/blas.h | 4 ++++ paddle/fluid/operators/math/blas_impl.h | 2 ++ 2 files changed, 6 insertions(+) diff --git a/paddle/fluid/operators/math/blas.h b/paddle/fluid/operators/math/blas.h index 485e96227e..2558154e0b 100644 --- a/paddle/fluid/operators/math/blas.h +++ b/paddle/fluid/operators/math/blas.h @@ -90,6 +90,7 @@ class Blas { void GEMM(bool transA, bool transB, int M, int N, int K, T alpha, const T* A, int lda, const T* B, int ldb, T beta, T* C, int ldc) const; +#ifdef PADDLE_WITH_MKLML template T* GEMM_ALLOC(const CBLAS_IDENTIFIER id, const int M, const int N, const int K) const; @@ -106,6 +107,7 @@ class Blas { template void GEMM_FREE(T* data) const; +#endif template void MatMul(const framework::Tensor& mat_a, bool trans_a, @@ -163,6 +165,7 @@ class BlasT : private Blas { Base()->template GEMM(args...); } +#ifdef PADDLE_WITH_MKLML template T* GEMM_ALLOC(ARGS... args) const { return Base()->template GEMM_ALLOC(args...); @@ -182,6 +185,7 @@ class BlasT : private Blas { void GEMM_FREE(ARGS... args) const { Base()->template GEMM_FREE(args...); } +#endif template void MatMul(ARGS... 
args) const { diff --git a/paddle/fluid/operators/math/blas_impl.h b/paddle/fluid/operators/math/blas_impl.h index 4164fe6229..bf33821079 100644 --- a/paddle/fluid/operators/math/blas_impl.h +++ b/paddle/fluid/operators/math/blas_impl.h @@ -264,6 +264,7 @@ inline void GEMM_WARP(CBLAS_ORDER order, CBLAS_TRANSPOSE transA, beta, C, ldc); } +#ifdef PADDLE_WITH_MKLML template <> template T *Blas::GEMM_ALLOC(const CBLAS_IDENTIFIER id, @@ -296,6 +297,7 @@ template void Blas::GEMM_FREE(T *data) const { CBlas::GEMM_FREE(data); } +#endif template <> template From 18c322c2a1133bcc6350aea1b148bb6d767e6933 Mon Sep 17 00:00:00 2001 From: tensor-tang Date: Mon, 6 Aug 2018 11:38:59 +0800 Subject: [PATCH 05/29] separate cpu and gpu implementations for gru kernel compute --- paddle/fluid/operators/gru_op.cc | 138 +++++++++++++++++++++++++++- paddle/fluid/operators/gru_op.cu.cc | 90 ++++++++++++++++++ paddle/fluid/operators/gru_op.h | 123 ------------------------- 3 files changed, 225 insertions(+), 126 deletions(-) diff --git a/paddle/fluid/operators/gru_op.cc b/paddle/fluid/operators/gru_op.cc index 5c74687882..4847eb3626 100644 --- a/paddle/fluid/operators/gru_op.cc +++ b/paddle/fluid/operators/gru_op.cc @@ -211,6 +211,139 @@ class GRUGradOp : public framework::OperatorWithKernel { } }; +template +class GRUCPUKernel : public framework::OpKernel { + public: + void BatchCompute(const framework::ExecutionContext& context) const { + using DeviceContext = paddle::platform::CPUDeviceContext; + auto* input = context.Input("Input"); + auto* h0 = context.Input("H0"); + auto* weight = context.Input("Weight"); + const T* weight_data = weight->data(); + auto* bias = context.Input("Bias"); + auto* batch_gate = context.Output("BatchGate"); + batch_gate->mutable_data(context.GetPlace()); + auto* batch_reset_hidden_prev = + context.Output("BatchResetHiddenPrev"); + batch_reset_hidden_prev->mutable_data(context.GetPlace()); + auto* batch_hidden = context.Output("BatchHidden"); + batch_hidden->mutable_data(context.GetPlace()); + auto* hidden = context.Output("Hidden"); + hidden->mutable_data(context.GetPlace()); + + auto hidden_dims = hidden->dims(); + + bool is_reverse = context.Attr("is_reverse"); + math::LoDTensor2BatchFunctor to_batch; + auto& dev_ctx = context.template device_context(); + to_batch(dev_ctx, *input, batch_gate, true, is_reverse); + + if (bias) { + math::RowwiseAdd add_bias; + add_bias(dev_ctx, *batch_gate, *bias, batch_gate); + } + + int frame_size = hidden_dims[1]; + math::GRUMetaValue gru_value; + gru_value.gate_weight = const_cast(weight_data); + gru_value.state_weight = + const_cast(weight_data + 2 * frame_size * frame_size); + Tensor ordered_h0; + + framework::Vector order(batch_gate->lod()[2]); + + if (h0) { + // Since the batch computing for GRU reorders the input sequences + // according to their length. The initialized cell state also needs + // to reorder.
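// Note on the packed-GEMM path used a few lines below: GEMM_ALLOC,
// GEMM_PACK, GEMM_COMPUTE and GEMM_FREE map onto MKL's
// cblas_?gemm_alloc/pack/compute/free routines. A minimal sketch of that
// lifecycle, assuming T = float (m/n/k and the leading dimensions are
// illustrative, not values from this kernel):
//
//   float* packed_b = cblas_sgemm_alloc(CblasBMatrix, m, n, k);
//   cblas_sgemm_pack(CblasRowMajor, CblasBMatrix, CblasNoTrans, m, n, k,
//                    1.0f, b, ldb, packed_b);  // pack the weight once
//   cblas_sgemm_compute(CblasRowMajor, CblasNoTrans, CblasPacked, m, n, k,
//                       a, lda, packed_b, ldb, 1.0f, c, ldc);  // reuse per step
//   cblas_sgemm_free(packed_b);
//
// Packing once and reusing across time steps is what makes this pay off
// for recurrent ops, where the same weight multiplies every step.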
+ ReorderInitState( + context.template device_context(), *h0, order, + &ordered_h0, true); + gru_value.prev_out_value = ordered_h0.data(); + } else { + gru_value.prev_out_value = nullptr; + } + auto batch_starts = batch_gate->lod()[0]; + size_t num_batch = batch_starts.size() - 1; + auto active_node = math::detail::GetActivationType( + context.Attr("activation")); + auto active_gate = math::detail::GetActivationType( + context.Attr("gate_activation")); + +#ifdef PADDLE_WITH_MKLML + auto blas = math::GetBlas(dev_ctx); + // TODO(TJ): make a class + T* packed_gate = blas.GEMM_ALLOC(CblasBMatrix, 1 /*height of C*/, + frame_size * 2 /*width of weight*/, + frame_size /*height of height*/); + PADDLE_ENFORCE(packed_gate); + blas.GEMM_PACK(CblasBMatrix, CblasNoTrans, 1 /*cur bs?*/, frame_size * 2, + frame_size, T(1.0), gru_value.gate_weight, frame_size * 2, + packed_gate); + T* packed_state = blas.GEMM_ALLOC(CblasBMatrix, 1 /*height of C*/, + frame_size /*width of weight*/, + frame_size /*height of height*/); + PADDLE_ENFORCE(packed_state); + blas.GEMM_PACK(CblasBMatrix, CblasNoTrans, 1 /*cur bs?*/, frame_size, + frame_size, T(1.0), gru_value.state_weight, frame_size, + packed_state); +#endif + for (size_t n = 0; n < num_batch; n++) { + int bstart = static_cast(batch_starts[n]); + int bend = static_cast(batch_starts[n + 1]); + int cur_batch_size = bend - bstart; + + Tensor gate_t = batch_gate->Slice(bstart, bend); + Tensor reset_hidden_prev_t = batch_reset_hidden_prev->Slice(bstart, bend); + Tensor hidden_t = batch_hidden->Slice(bstart, bend); + gru_value.output_value = hidden_t.data(); + gru_value.gate_value = gate_t.data(); + gru_value.reset_output_value = reset_hidden_prev_t.data(); + +#ifdef PADDLE_WITH_MKLML + if (gru_value.prev_out_value) { + blas.GEMM_COMPUTE(CblasNoTrans, CblasPacked, cur_batch_size, + frame_size * 2, frame_size, gru_value.prev_out_value, + frame_size, packed_gate, frame_size * 2, T(1), + gru_value.gate_value, frame_size * 3); + } + + math::detail::forward_reset_output( + math::detail::forward::gru_resetOutput(), gru_value, frame_size, + cur_batch_size, active_gate); + + if (gru_value.prev_out_value) { + blas.GEMM_COMPUTE( + CblasNoTrans, CblasPacked, cur_batch_size, frame_size, frame_size, + gru_value.reset_output_value, frame_size, packed_state, frame_size, + T(1), gru_value.gate_value + frame_size * 2, frame_size * 3); + } + + math::detail::forward_final_output( + math::detail::forward::gru_finalOutput(), gru_value, frame_size, + cur_batch_size, active_node); +#else + math::GRUUnitFunctor::compute( + dev_ctx, gru_value, frame_size, cur_batch_size, active_node, + active_gate); +#endif + gru_value.prev_out_value = gru_value.output_value; + } +#ifdef PADDLE_WITH_MKLML + blas.GEMM_FREE(packed_gate); + blas.GEMM_FREE(packed_state); +#endif + + math::Batch2LoDTensorFunctor to_seq; + batch_hidden->set_lod(batch_gate->lod()); + to_seq(dev_ctx, *batch_hidden, hidden); + } + + void Compute(const framework::ExecutionContext& context) const override { + BatchCompute(context); + } +}; + } // namespace operators } // namespace paddle @@ -218,9 +351,8 @@ namespace ops = paddle::operators; REGISTER_OPERATOR(gru, ops::GRUOp, ops::GRUOpMaker, paddle::framework::DefaultGradOpDescMaker); REGISTER_OPERATOR(gru_grad, ops::GRUGradOp); -REGISTER_OP_CPU_KERNEL( - gru, ops::GRUKernel, - ops::GRUKernel); +REGISTER_OP_CPU_KERNEL(gru, ops::GRUCPUKernel, + ops::GRUCPUKernel); REGISTER_OP_CPU_KERNEL( gru_grad, ops::GRUGradKernel, ops::GRUGradKernel); diff --git a/paddle/fluid/operators/gru_op.cu.cc 
b/paddle/fluid/operators/gru_op.cu.cc index baf455a840..55721c283d 100644 --- a/paddle/fluid/operators/gru_op.cu.cc +++ b/paddle/fluid/operators/gru_op.cu.cc @@ -14,6 +14,96 @@ limitations under the License. */ #include "paddle/fluid/operators/gru_op.h" +namespace paddle { +namespace operators { + +template +class GRUKernel : public framework::OpKernel { + public: + void BatchCompute(const framework::ExecutionContext& context) const { + auto* input = context.Input("Input"); + auto* h0 = context.Input("H0"); + auto* weight = context.Input("Weight"); + const T* weight_data = weight->data(); + auto* bias = context.Input("Bias"); + auto* batch_gate = context.Output("BatchGate"); + batch_gate->mutable_data(context.GetPlace()); + auto* batch_reset_hidden_prev = + context.Output("BatchResetHiddenPrev"); + batch_reset_hidden_prev->mutable_data(context.GetPlace()); + auto* batch_hidden = context.Output("BatchHidden"); + batch_hidden->mutable_data(context.GetPlace()); + auto* hidden = context.Output("Hidden"); + hidden->mutable_data(context.GetPlace()); + + auto hidden_dims = hidden->dims(); + + bool is_reverse = context.Attr("is_reverse"); + math::LoDTensor2BatchFunctor to_batch; + auto& dev_ctx = context.template device_context(); + to_batch(dev_ctx, *input, batch_gate, true, is_reverse); + + if (bias) { + math::RowwiseAdd add_bias; + add_bias(dev_ctx, *batch_gate, *bias, batch_gate); + } + + int frame_size = hidden_dims[1]; + math::GRUMetaValue gru_value; + gru_value.gate_weight = const_cast(weight_data); + gru_value.state_weight = + const_cast(weight_data + 2 * frame_size * frame_size); + Tensor ordered_h0; + + framework::Vector order(batch_gate->lod()[2]); + + if (h0) { + // Since the batch computing for GRU reorders the input sequences + // according to their length. The initialized cell state also needs + // to reorder. 
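// For reference, the per-time-step recurrence that GRUUnitFunctor::compute
// here (and the decomposed MKL path in the CPU kernel) evaluates is roughly
// the following, with D = frame_size and gate_value laid out as
// [update, reset, candidate] (a sketch of the math, not literal code):
//
//   gate[:, 0:2D]  += h_prev * gate_weight           // update/reset GEMM
//   u, r            = act_gate(gate[:, 0:2D])
//   reset_output    = r .* h_prev
//   gate[:, 2D:3D] += reset_output * state_weight    // candidate GEMM
//   c               = act_node(gate[:, 2D:3D])
//   h               = (1 - u) .* h_prev + u .* c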
+ ReorderInitState( + context.template device_context(), *h0, order, + &ordered_h0, true); + gru_value.prev_out_value = ordered_h0.data(); + } else { + gru_value.prev_out_value = nullptr; + } + auto batch_starts = batch_gate->lod()[0]; + size_t num_batch = batch_starts.size() - 1; + auto active_node = math::detail::GetActivationType( + context.Attr("activation")); + auto active_gate = math::detail::GetActivationType( + context.Attr("gate_activation")); + for (size_t n = 0; n < num_batch; n++) { + int bstart = static_cast(batch_starts[n]); + int bend = static_cast(batch_starts[n + 1]); + int cur_batch_size = bend - bstart; + + Tensor gate_t = batch_gate->Slice(bstart, bend); + Tensor reset_hidden_prev_t = batch_reset_hidden_prev->Slice(bstart, bend); + Tensor hidden_t = batch_hidden->Slice(bstart, bend); + gru_value.output_value = hidden_t.data(); + gru_value.gate_value = gate_t.data(); + gru_value.reset_output_value = reset_hidden_prev_t.data(); + math::GRUUnitFunctor::compute( + dev_ctx, gru_value, frame_size, cur_batch_size, active_node, + active_gate); + gru_value.prev_out_value = gru_value.output_value; + } + + math::Batch2LoDTensorFunctor to_seq; + batch_hidden->set_lod(batch_gate->lod()); + to_seq(dev_ctx, *batch_hidden, hidden); + } + + void Compute(const framework::ExecutionContext& context) const override { + BatchCompute(context); + } +}; + +} // namespace operators +} // namespace paddle + namespace ops = paddle::operators; REGISTER_OP_CUDA_KERNEL( gru, ops::GRUKernel, diff --git a/paddle/fluid/operators/gru_op.h b/paddle/fluid/operators/gru_op.h index a9450337e7..0bf4e6bc44 100644 --- a/paddle/fluid/operators/gru_op.h +++ b/paddle/fluid/operators/gru_op.h @@ -40,129 +40,6 @@ inline void ReorderInitState(const DeviceContext& ctx, row_shuffle(ctx, src, index_lod, dst, indexed_src); } -template -class GRUKernel : public framework::OpKernel { - public: - void BatchCompute(const framework::ExecutionContext& context) const { - auto* input = context.Input("Input"); - auto* h0 = context.Input("H0"); - auto* weight = context.Input("Weight"); - const T* weight_data = weight->data(); - auto* bias = context.Input("Bias"); - auto* batch_gate = context.Output("BatchGate"); - batch_gate->mutable_data(context.GetPlace()); - auto* batch_reset_hidden_prev = - context.Output("BatchResetHiddenPrev"); - batch_reset_hidden_prev->mutable_data(context.GetPlace()); - auto* batch_hidden = context.Output("BatchHidden"); - batch_hidden->mutable_data(context.GetPlace()); - auto* hidden = context.Output("Hidden"); - hidden->mutable_data(context.GetPlace()); - - auto hidden_dims = hidden->dims(); - - bool is_reverse = context.Attr("is_reverse"); - math::LoDTensor2BatchFunctor to_batch; - auto& dev_ctx = context.template device_context(); - to_batch(dev_ctx, *input, batch_gate, true, is_reverse); - - if (bias) { - math::RowwiseAdd add_bias; - add_bias(dev_ctx, *batch_gate, *bias, batch_gate); - } - - int frame_size = hidden_dims[1]; - math::GRUMetaValue gru_value; - gru_value.gate_weight = const_cast(weight_data); - gru_value.state_weight = - const_cast(weight_data + 2 * frame_size * frame_size); - Tensor ordered_h0; - - framework::Vector order(batch_gate->lod()[2]); - - if (h0) { - // Since the batch computing for GRU reorders the input sequences - // according to their length. The initialized cell state also needs - // to reorder. 
- ReorderInitState( - context.template device_context(), *h0, order, - &ordered_h0, true); - gru_value.prev_out_value = ordered_h0.data(); - } else { - gru_value.prev_out_value = nullptr; - } - auto batch_starts = batch_gate->lod()[0]; - size_t num_batch = batch_starts.size() - 1; - auto active_node = math::detail::GetActivationType( - context.Attr("activation")); - auto active_gate = math::detail::GetActivationType( - context.Attr("gate_activation")); - auto blas = math::GetBlas(dev_ctx); - - // TODO(TJ): make a class, make one pack - T* packed_gate = blas.GEMM_ALLOC(CblasBMatrix, 1 /*height of C*/, - frame_size * 2 /*width of weight*/, - frame_size /*height of height*/); - PADDLE_ENFORCE(packed_gate); - blas.GEMM_PACK(CblasBMatrix, CblasNoTrans, 1 /*cur bs?*/, frame_size * 2, - frame_size, T(1.0), gru_value.gate_weight, frame_size * 2, - packed_gate); - T* packed_state = blas.GEMM_ALLOC(CblasBMatrix, 1 /*height of C*/, - frame_size /*width of weight*/, - frame_size /*height of height*/); - PADDLE_ENFORCE(packed_state); - blas.GEMM_PACK(CblasBMatrix, CblasNoTrans, 1 /*cur bs?*/, frame_size, - frame_size, T(1.0), gru_value.state_weight, frame_size, - packed_state); - - for (size_t n = 0; n < num_batch; n++) { - int bstart = static_cast(batch_starts[n]); - int bend = static_cast(batch_starts[n + 1]); - int cur_batch_size = bend - bstart; - - Tensor gate_t = batch_gate->Slice(bstart, bend); - Tensor reset_hidden_prev_t = batch_reset_hidden_prev->Slice(bstart, bend); - Tensor hidden_t = batch_hidden->Slice(bstart, bend); - gru_value.output_value = hidden_t.data(); - gru_value.gate_value = gate_t.data(); - gru_value.reset_output_value = reset_hidden_prev_t.data(); - if (gru_value.prev_out_value) { - blas.GEMM_COMPUTE(CblasNoTrans, CblasPacked, cur_batch_size, - frame_size * 2, frame_size, gru_value.prev_out_value, - frame_size, packed_gate, frame_size * 2, T(1), - gru_value.gate_value, frame_size * 3); - } - - math::detail::forward_reset_output( - math::detail::forward::gru_resetOutput(), gru_value, frame_size, - cur_batch_size, active_gate); - - if (gru_value.prev_out_value) { - blas.GEMM_COMPUTE( - CblasNoTrans, CblasPacked, cur_batch_size, frame_size, frame_size, - gru_value.reset_output_value, frame_size, packed_state, frame_size, - T(1), gru_value.gate_value + frame_size * 2, frame_size * 3); - } - - math::detail::forward_final_output( - math::detail::forward::gru_finalOutput(), gru_value, frame_size, - cur_batch_size, active_node); - - gru_value.prev_out_value = gru_value.output_value; - } - blas.GEMM_FREE(packed_gate); - blas.GEMM_FREE(packed_state); - - math::Batch2LoDTensorFunctor to_seq; - batch_hidden->set_lod(batch_gate->lod()); - to_seq(dev_ctx, *batch_hidden, hidden); - } - - void Compute(const framework::ExecutionContext& context) const override { - BatchCompute(context); - } -}; - template class GRUGradKernel : public framework::OpKernel { public: From 61052cdbc6cd048410aebb0df514fba6f8931347 Mon Sep 17 00:00:00 2001 From: chenweihang Date: Wed, 8 Aug 2018 10:22:36 +0000 Subject: [PATCH 06/29] polish high frequency enforce error message --- paddle/fluid/platform/enforce.h | 10 ++++++---- paddle/fluid/platform/gpu_info.cc | 20 ++++++++++---------- 2 files changed, 16 insertions(+), 14 deletions(-) diff --git a/paddle/fluid/platform/enforce.h b/paddle/fluid/platform/enforce.h index 566485cd3c..cad60275a2 100644 --- a/paddle/fluid/platform/enforce.h +++ b/paddle/fluid/platform/enforce.h @@ -263,7 +263,8 @@ inline void throw_on_error(T e) { * PADDLE_ENFORCE_EQ(a, b); * * will raise 
an expression described as follows: - * "enforce a == b failed, 1 != 2" with detailed stack information. + * "Data check failed. Expected input a == b, but received a(1) != b(2)." + * with detailed stack information. * * extra messages is also supported, for example: * PADDLE_ENFORCE(a, b, "some simple enforce failed between %d numbers", 2) @@ -292,9 +293,10 @@ inline void throw_on_error(T e) { #define __PADDLE_BINARY_COMPARE(__VAL0, __VAL1, __CMP, __INV_CMP, ...) \ do { \ if (UNLIKELY(!((__VAL0)__CMP(__VAL1)))) { \ - PADDLE_THROW("enforce %s " #__CMP " %s failed, %s " #__INV_CMP \ - " %s\n%s", \ - #__VAL0, #__VAL1, paddle::string::to_string(__VAL0), \ + PADDLE_THROW("Data check failed. Expected %s " #__CMP \ + " %s, but received %s:%s " #__INV_CMP " %s:%s.\n%s", \ + #__VAL0, #__VAL1, #__VAL0, \ + paddle::string::to_string(__VAL0), #__VAL1, \ paddle::string::to_string(__VAL1), \ paddle::string::Sprintf("" __VA_ARGS__)); \ } \ diff --git a/paddle/fluid/platform/gpu_info.cc b/paddle/fluid/platform/gpu_info.cc index 4cee93f3a4..f9e2e8c69d 100644 --- a/paddle/fluid/platform/gpu_info.cc +++ b/paddle/fluid/platform/gpu_info.cc @@ -100,25 +100,25 @@ size_t GpuMinChunkSize() { size_t GpuMaxChunkSize() { size_t total = 0; - size_t available = 0; + size_t available_memory = 0; - GpuMemoryUsage(&available, &total); - VLOG(10) << "GPU Usage " << available / 1024 / 1024 << "M/" + GpuMemoryUsage(&available_memory, &total); + VLOG(10) << "GPU Usage " << available_memory / 1024 / 1024 << "M/" << total / 1024 / 1024 << "M"; size_t reserving = static_cast(0.05 * total); // If available less than minimum chunk size, no usable memory exists. - available = - std::min(std::max(available, GpuMinChunkSize()) - GpuMinChunkSize(), - total - reserving); + available_memory = std::min( + std::max(available_memory, GpuMinChunkSize()) - GpuMinChunkSize(), + total - reserving); // Reserving the rest memory for page tables, etc. - size_t allocating = static_cast(FLAGS_fraction_of_gpu_memory_to_use * - (total - reserving)); + size_t allocating_memory = static_cast( + FLAGS_fraction_of_gpu_memory_to_use * (total - reserving)); - PADDLE_ENFORCE_LE(allocating, available); + PADDLE_ENFORCE_LE(allocating_memory, available_memory); - return allocating; + return allocating_memory; } void GpuMemcpyAsync(void *dst, const void *src, size_t count, From b1dd4149b90dde40640de2baf0190d611cb24486 Mon Sep 17 00:00:00 2001 From: chenweihang Date: Thu, 9 Aug 2018 03:02:25 +0000 Subject: [PATCH 07/29] adjust enforce test cases --- paddle/fluid/platform/enforce_test.cc | 30 +++++++++++++++++---------- 1 file changed, 19 insertions(+), 11 deletions(-) diff --git a/paddle/fluid/platform/enforce_test.cc b/paddle/fluid/platform/enforce_test.cc index 0e8684581a..8dcf39fdaa 100644 --- a/paddle/fluid/platform/enforce_test.cc +++ b/paddle/fluid/platform/enforce_test.cc @@ -54,7 +54,9 @@ TEST(ENFORCE_EQ, NO_EXTRA_MSG_FAIL) { PADDLE_ENFORCE_EQ(a, 1 + 3); } catch (paddle::platform::EnforceNotMet error) { caught_exception = true; - HasPrefix(StringPiece(error.what()), "enforce a == 1 + 3 failed, 2 != 4"); + HasPrefix( + StringPiece(error.what()), + "Data check failed. Expected a == 1 + 3, but received a:2 != 1 + 3:4."); } EXPECT_TRUE(caught_exception); } @@ -67,7 +69,8 @@ TEST(ENFORCE_EQ, EXTRA_MSG_FAIL) { } catch (paddle::platform::EnforceNotMet error) { caught_exception = true; HasPrefix(StringPiece(error.what()), - "enforce a == 1 + 3 failed, 2 != 4\ntheir size not match"); + "Data check failed. 
Expected a == 1 + 3, but received a:2 != 1 + " + "3:4.\ntheir size not match"); } EXPECT_TRUE(caught_exception); } @@ -84,8 +87,9 @@ TEST(ENFORCE_NE, FAIL) { PADDLE_ENFORCE_NE(1.0, 1UL); } catch (paddle::platform::EnforceNotMet error) { caught_exception = true; - EXPECT_TRUE(HasPrefix(StringPiece(error.what()), - "enforce 1.0 != 1UL failed, 1 == 1")) + EXPECT_TRUE(HasPrefix( + StringPiece(error.what()), + "Data check failed. Expected 1.0 != 1UL, but received 1.0:1 == 1UL:1.")) << error.what() << " does not have expected prefix"; } EXPECT_TRUE(caught_exception); @@ -98,8 +102,9 @@ TEST(ENFORCE_GT, FAIL) { PADDLE_ENFORCE_GT(1, 2UL); } catch (paddle::platform::EnforceNotMet error) { caught_exception = true; - EXPECT_TRUE( - HasPrefix(StringPiece(error.what()), "enforce 1 > 2UL failed, 1 <= 2")); + EXPECT_TRUE(HasPrefix( + StringPiece(error.what()), + "Data check failed. Expected 1 > 2UL, but received 1:1 <= 2UL:2.")); } EXPECT_TRUE(caught_exception); } @@ -116,8 +121,9 @@ TEST(ENFORCE_GE, FAIL) { PADDLE_ENFORCE_GE(1, 2UL); } catch (paddle::platform::EnforceNotMet error) { caught_exception = true; - EXPECT_TRUE( - HasPrefix(StringPiece(error.what()), "enforce 1 >= 2UL failed, 1 < 2")); + EXPECT_TRUE(HasPrefix( + StringPiece(error.what()), + "Data check failed. Expected 1 >= 2UL, but received 1:1 < 2UL:2.")); } EXPECT_TRUE(caught_exception); } @@ -135,8 +141,9 @@ TEST(ENFORCE_LE, FAIL) { PADDLE_ENFORCE_GT(1, 2UL); } catch (paddle::platform::EnforceNotMet error) { caught_exception = true; - EXPECT_TRUE( - HasPrefix(StringPiece(error.what()), "enforce 1 > 2UL failed, 1 <= 2")); + EXPECT_TRUE(HasPrefix( + StringPiece(error.what()), + "Data check failed. Expected 1 > 2UL, but received 1:1 <= 2UL:2.")); } EXPECT_TRUE(caught_exception); } @@ -153,7 +160,8 @@ TEST(ENFORCE_LT, FAIL) { } catch (paddle::platform::EnforceNotMet error) { caught_exception = true; EXPECT_TRUE(HasPrefix(StringPiece(error.what()), - "enforce 1UL < 0.12 failed, 1 >= 0.12")); + "Data check failed. Expected 1UL < 0.12, but " + "received 1UL:1 >= 0.12:0.12.")); } EXPECT_TRUE(caught_exception); } From 5377edd282bf4998d675d5551bb5b4e420fe4122 Mon Sep 17 00:00:00 2001 From: tensor-tang Date: Mon, 13 Aug 2018 11:35:11 +0800 Subject: [PATCH 08/29] refine packed condition --- paddle/fluid/operators/gru_op.cc | 135 ++++++++++++++++++------------- paddle/fluid/operators/gru_op.h | 3 - 2 files changed, 79 insertions(+), 59 deletions(-) diff --git a/paddle/fluid/operators/gru_op.cc b/paddle/fluid/operators/gru_op.cc index 4847eb3626..2b5094925c 100644 --- a/paddle/fluid/operators/gru_op.cc +++ b/paddle/fluid/operators/gru_op.cc @@ -14,6 +14,11 @@ limitations under the License. 
*/ #include "paddle/fluid/operators/gru_op.h" #include +#include "paddle/fluid/operators/math/blas.h" +#include "paddle/fluid/operators/math/detail/gru_cpu_kernel.h" +#include "paddle/fluid/operators/math/detail/gru_kernel.h" + +DECLARE_int32(paddle_num_threads); namespace paddle { namespace operators { @@ -264,76 +269,94 @@ class GRUCPUKernel : public framework::OpKernel { gru_value.prev_out_value = nullptr; } auto batch_starts = batch_gate->lod()[0]; - size_t num_batch = batch_starts.size() - 1; + size_t seq_len = batch_starts.size() - 1; auto active_node = math::detail::GetActivationType( context.Attr("activation")); auto active_gate = math::detail::GetActivationType( context.Attr("gate_activation")); #ifdef PADDLE_WITH_MKLML - auto blas = math::GetBlas(dev_ctx); - // TODO(TJ): make a class - T* packed_gate = blas.GEMM_ALLOC(CblasBMatrix, 1 /*height of C*/, - frame_size * 2 /*width of weight*/, - frame_size /*height of height*/); - PADDLE_ENFORCE(packed_gate); - blas.GEMM_PACK(CblasBMatrix, CblasNoTrans, 1 /*cur bs?*/, frame_size * 2, - frame_size, T(1.0), gru_value.gate_weight, frame_size * 2, - packed_gate); - T* packed_state = blas.GEMM_ALLOC(CblasBMatrix, 1 /*height of C*/, - frame_size /*width of weight*/, - frame_size /*height of height*/); - PADDLE_ENFORCE(packed_state); - blas.GEMM_PACK(CblasBMatrix, CblasNoTrans, 1 /*cur bs?*/, frame_size, - frame_size, T(1.0), gru_value.state_weight, frame_size, - packed_state); -#endif - for (size_t n = 0; n < num_batch; n++) { - int bstart = static_cast(batch_starts[n]); - int bend = static_cast(batch_starts[n + 1]); - int cur_batch_size = bend - bstart; - - Tensor gate_t = batch_gate->Slice(bstart, bend); - Tensor reset_hidden_prev_t = batch_reset_hidden_prev->Slice(bstart, bend); - Tensor hidden_t = batch_hidden->Slice(bstart, bend); - gru_value.output_value = hidden_t.data(); - gru_value.gate_value = gate_t.data(); - gru_value.reset_output_value = reset_hidden_prev_t.data(); + if (FLAGS_paddle_num_threads >= 4) { + auto blas = math::GetBlas(dev_ctx); + T* packed_gate = blas.GEMM_ALLOC(CblasBMatrix, 1 /*height of C*/, + frame_size * 2 /*width of weight*/, + frame_size /*height of height*/); + PADDLE_ENFORCE(packed_gate); + blas.GEMM_PACK(CblasBMatrix, CblasNoTrans, 1 /*cur bs?*/, frame_size * 2, + frame_size, T(1.0), gru_value.gate_weight, frame_size * 2, + packed_gate); + T* packed_state = blas.GEMM_ALLOC(CblasBMatrix, 1 /*height of C*/, + frame_size /*width of weight*/, + frame_size /*height of height*/); + PADDLE_ENFORCE(packed_state); + blas.GEMM_PACK(CblasBMatrix, CblasNoTrans, 1 /*cur bs?*/, frame_size, + frame_size, T(1.0), gru_value.state_weight, frame_size, + packed_state); + for (size_t n = 0; n < seq_len; n++) { + int bstart = static_cast(batch_starts[n]); + int bend = static_cast(batch_starts[n + 1]); + int cur_batch_size = bend - bstart; -#ifdef PADDLE_WITH_MKLML - if (gru_value.prev_out_value) { - blas.GEMM_COMPUTE(CblasNoTrans, CblasPacked, cur_batch_size, - frame_size * 2, frame_size, gru_value.prev_out_value, - frame_size, packed_gate, frame_size * 2, T(1), - gru_value.gate_value, frame_size * 3); - } + Tensor gate_t = batch_gate->Slice(bstart, bend); + Tensor reset_hidden_prev_t = + batch_reset_hidden_prev->Slice(bstart, bend); + Tensor hidden_t = batch_hidden->Slice(bstart, bend); + gru_value.output_value = hidden_t.data(); + gru_value.gate_value = gate_t.data(); + gru_value.reset_output_value = reset_hidden_prev_t.data(); - math::detail::forward_reset_output( - math::detail::forward::gru_resetOutput(), gru_value, 
frame_size, - cur_batch_size, active_gate); + if (gru_value.prev_out_value) { + blas.GEMM_COMPUTE( + CblasNoTrans, CblasPacked, cur_batch_size, frame_size * 2, + frame_size, gru_value.prev_out_value, frame_size, packed_gate, + frame_size * 2, T(1), gru_value.gate_value, frame_size * 3); + } - if (gru_value.prev_out_value) { - blas.GEMM_COMPUTE( - CblasNoTrans, CblasPacked, cur_batch_size, frame_size, frame_size, - gru_value.reset_output_value, frame_size, packed_state, frame_size, - T(1), gru_value.gate_value + frame_size * 2, frame_size * 3); + math::detail::forward_reset_output( + math::detail::forward::gru_resetOutput(), gru_value, frame_size, + cur_batch_size, active_gate); + + if (gru_value.prev_out_value) { + blas.GEMM_COMPUTE( + CblasNoTrans, CblasPacked, cur_batch_size, frame_size, frame_size, + gru_value.reset_output_value, frame_size, packed_state, + frame_size, T(1), gru_value.gate_value + frame_size * 2, + frame_size * 3); + } + + math::detail::forward_final_output( + math::detail::forward::gru_finalOutput(), gru_value, frame_size, + cur_batch_size, active_node); + + gru_value.prev_out_value = gru_value.output_value; } - math::detail::forward_final_output( - math::detail::forward::gru_finalOutput(), gru_value, frame_size, - cur_batch_size, active_node); -#else - math::GRUUnitFunctor::compute( - dev_ctx, gru_value, frame_size, cur_batch_size, active_node, - active_gate); + blas.GEMM_FREE(packed_gate); + blas.GEMM_FREE(packed_state); + } else { #endif - gru_value.prev_out_value = gru_value.output_value; - } + for (size_t n = 0; n < seq_len; n++) { + int bstart = static_cast(batch_starts[n]); + int bend = static_cast(batch_starts[n + 1]); + int cur_batch_size = bend - bstart; + + Tensor gate_t = batch_gate->Slice(bstart, bend); + Tensor reset_hidden_prev_t = + batch_reset_hidden_prev->Slice(bstart, bend); + Tensor hidden_t = batch_hidden->Slice(bstart, bend); + gru_value.output_value = hidden_t.data(); + gru_value.gate_value = gate_t.data(); + gru_value.reset_output_value = reset_hidden_prev_t.data(); + + math::GRUUnitFunctor::compute( + dev_ctx, gru_value, frame_size, cur_batch_size, active_node, + active_gate); + + gru_value.prev_out_value = gru_value.output_value; + } #ifdef PADDLE_WITH_MKLML - blas.GEMM_FREE(packed_gate); - blas.GEMM_FREE(packed_state); + } #endif - math::Batch2LoDTensorFunctor to_seq; batch_hidden->set_lod(batch_gate->lod()); to_seq(dev_ctx, *batch_hidden, hidden); diff --git a/paddle/fluid/operators/gru_op.h b/paddle/fluid/operators/gru_op.h index 0bf4e6bc44..0b551e8046 100644 --- a/paddle/fluid/operators/gru_op.h +++ b/paddle/fluid/operators/gru_op.h @@ -16,10 +16,7 @@ limitations under the License. 
*/ #include #include "paddle/fluid/framework/eigen.h" #include "paddle/fluid/framework/op_registry.h" -#include "paddle/fluid/operators/math/blas.h" #include "paddle/fluid/operators/math/detail/activation_functions.h" -#include "paddle/fluid/operators/math/detail/gru_cpu_kernel.h" -#include "paddle/fluid/operators/math/detail/gru_kernel.h" #include "paddle/fluid/operators/math/gru_compute.h" #include "paddle/fluid/operators/math/math_function.h" #include "paddle/fluid/operators/math/sequence2batch.h" From 171a0e2b42e1dea669056bbc6093e572e1c88e0a Mon Sep 17 00:00:00 2001 From: tensor-tang Date: Mon, 13 Aug 2018 18:01:43 +0800 Subject: [PATCH 09/29] add some comment --- paddle/fluid/operators/gru_op.cc | 1 + 1 file changed, 1 insertion(+) diff --git a/paddle/fluid/operators/gru_op.cc b/paddle/fluid/operators/gru_op.cc index 2b5094925c..087f903a8b 100644 --- a/paddle/fluid/operators/gru_op.cc +++ b/paddle/fluid/operators/gru_op.cc @@ -276,6 +276,7 @@ class GRUCPUKernel : public framework::OpKernel { context.Attr("gate_activation")); #ifdef PADDLE_WITH_MKLML + // use MKL packed to speedup GEMM if (FLAGS_paddle_num_threads >= 4) { auto blas = math::GetBlas(dev_ctx); T* packed_gate = blas.GEMM_ALLOC(CblasBMatrix, 1 /*height of C*/, From 038cbf799d290e3e7cc129b59a2bea7b7e40055a Mon Sep 17 00:00:00 2001 From: tensor-tang Date: Mon, 13 Aug 2018 22:49:58 +0800 Subject: [PATCH 10/29] add bias for fc op --- paddle/fluid/operators/fc_op.cc | 72 ++++++++++++++++++++++++++++----- 1 file changed, 62 insertions(+), 10 deletions(-) diff --git a/paddle/fluid/operators/fc_op.cc b/paddle/fluid/operators/fc_op.cc index a9ae1396db..5fee30e146 100644 --- a/paddle/fluid/operators/fc_op.cc +++ b/paddle/fluid/operators/fc_op.cc @@ -30,21 +30,34 @@ void FCOp::InferShape(framework::InferShapeContext* ctx) const { auto w_dims = ctx->GetInputDim("W"); std::vector output_shape({in_dims[0], w_dims[1]}); + if (ctx->HasInput("Bias")) { + auto bias_dims = ctx->GetInputDim("Bias"); + PADDLE_ENFORCE_EQ(bias_dims[0], 1, "The shape of Bias must be [1, dim]."); + PADDLE_ENFORCE_EQ(bias_dims[1], framework::product(w_dims) / w_dims[0], + "The shape of Bias must be [1, dim]."); + } PADDLE_ENFORCE(in_dims.size() == 2 || in_dims.size() == 4, "Fully Connected input should be 2-D or 4-D tensor."); PADDLE_ENFORCE(w_dims.size() == 2 || w_dims.size() == 4, "Fully Connected input should be 2-D or 4-D tensor."); + PADDLE_ENFORCE_EQ(framework::product(w_dims) / w_dims[0], + framework::product(in_dims) / in_dims[0], + "Fully Connected input and weigth size do not match."); + ctx->SetOutputDim("Out", framework::make_ddim(output_shape)); ctx->ShareLoD("Input", "Out"); } framework::OpKernelType FCOp::GetExpectedKernelType( const framework::ExecutionContext& ctx) const { - framework::LibraryType library{framework::LibraryType::kMKLDNN}; - framework::DataLayout layout{framework::DataLayout::kMKLDNN}; - + framework::LibraryType library = framework::LibraryType::kPlain; + framework::DataLayout layout = framework::DataLayout::kAnyLayout; + if (ctx.Attr("use_mkldnn");) { + library = framework::LibraryType::kMKLDNN; + layout = framework::DataLayout::kMKLDNN; + } return framework::OpKernelType( framework::ToDataType(ctx.Input("Input")->type()), ctx.GetPlace(), layout, library); @@ -60,13 +73,22 @@ void FCOpGrad::InferShape(framework::InferShapeContext* ctx) const { if (ctx->HasOutput(framework::GradVarName("W"))) { ctx->SetOutputDim(framework::GradVarName("W"), w_dims); } + + if (ctx->HasInput("Bias")) { + auto bias_dims = ctx->GetInputDim("Bias"); + 
PADDLE_ENFORCE(ctx->HasOutput(framework::GradVarName("Bias")); + ctx->SetOutputDim(framework::GradVarName("Bias"), bias_dims); + } } framework::OpKernelType FCOpGrad::GetExpectedKernelType( const framework::ExecutionContext& ctx) const { - framework::LibraryType library{framework::LibraryType::kMKLDNN}; - framework::DataLayout layout{framework::DataLayout::kMKLDNN}; - + framework::LibraryType library = framework::LibraryType::kPlain; + framework::DataLayout layout = framework::DataLayout::kAnyLayout; + if (ctx.Attr("use_mkldnn");) { + library = framework::LibraryType::kMKLDNN; + layout = framework::DataLayout::kMKLDNN; + } return framework::OpKernelType( framework::ToDataType(ctx.Input("Input")->type()), ctx.GetPlace(), layout, library); @@ -75,12 +97,12 @@ framework::OpKernelType FCOpGrad::GetExpectedKernelType( void FCOpMaker::Make() { AddInput("Input", "(Tensor) The input tensor of fully connected operator. "); AddInput("W", "(Tensor), The second input tensor of fc op."); + AddInput("Bias", "(Tensor, optional) Bias vector with shape (1 x D") + .AsDispensable(); AddOutput("Out", "(Tensor) The output tensor of fully connected operator. "); AddAttr("use_mkldnn", "(bool, default false) Only used in mkldnn kernel") .SetDefault(false); - AddAttr("bias_attr", "(bool, default false) Only used in mkldnn kernel") - .SetDefault(false); AddComment(R"DOC( Fully Connected Operator. @@ -94,9 +116,39 @@ void FCOpMaker::Make() { )DOC"); } +template +class FCOpKernel : public framework::OpKernel { + public: + void Compute(const paddle::framework::ExecutionContext& ctx) const override { + PADDLE_ENFORCE(paddle::platform::is_cpu_place(ctx.GetPlace()), + "It must use CPUPlace."); + auto& dev_ctx = ctx.template device_context(); + auto blas = math::GetBlas(dev_ctx); + auto input = ctx.Input("Input"); + auto w = ctx.Input("W"); + auto b = ctx.Input("Bias"); + + const T* input_data = input->data(); + const T* w_data = w->data(); + auto output = ctx.Output("Out"); + T* output_data = output->mutable_data(ctx.GetPlace()); + + auto in_dims = ctx->GetInputDim("Input"); + auto w_dims = ctx->GetInputDim("W"); + std::vector output_shape({in_dims[0], w_dims[1]}); + + if (bias) { + const T* bias_data = bias->data(); + } + } +}; + } // namespace operators } // namespace paddle -REGISTER_OPERATOR(fc, paddle::operators::FCOp, paddle::operators::FCOpMaker, +namespace ops = paddle::operators; +REGISTER_OPERATOR(fc, ops::FCOp, ops::FCOpMaker, paddle::framework::DefaultGradOpDescMaker); -REGISTER_OPERATOR(fc_grad, paddle::operators::FCOpGrad); +REGISTER_OPERATOR(fc_grad, ops::FCOpGrad); +REGISTER_OP_CPU_KERNEL(fc, ops::FCMKLDNNOpKernel, + ops::FCMKLDNNOpKernel); From e133df60373b92d1e35b2f34144e7067dbb9752b Mon Sep 17 00:00:00 2001 From: tensor-tang Date: Mon, 13 Aug 2018 23:40:58 +0800 Subject: [PATCH 11/29] enable native fc forward --- paddle/fluid/operators/fc_mkldnn_op.cc | 1 + paddle/fluid/operators/fc_op.cc | 55 +++++++++++++++----------- 2 files changed, 33 insertions(+), 23 deletions(-) diff --git a/paddle/fluid/operators/fc_mkldnn_op.cc b/paddle/fluid/operators/fc_mkldnn_op.cc index 99fa659a35..68a47dd6ad 100644 --- a/paddle/fluid/operators/fc_mkldnn_op.cc +++ b/paddle/fluid/operators/fc_mkldnn_op.cc @@ -128,6 +128,7 @@ class FCMKLDNNOpKernel : public paddle::framework::OpKernel { PADDLE_ENFORCE(input->dims().size() == 2 || input->dims().size() == 4, "Input must be with 2 or 4 dimensions, i.e. NCHW"); + // TODO(intel): the src weight is io and mkldnn weight need be transposed ! 
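// A sketch of the io -> oi reorder that TODO refers to, assuming a
// row-major I x O source (the native fc weight layout) and a row-major
// O x I destination for MKLDNN; names here are illustrative:
//
//   for (int o = 0; o < O; ++o) {
//     for (int i = 0; i < I; ++i) {
//       w_oi[o * I + i] = w_io[i * O + o];
//     }
//   }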
PADDLE_ENFORCE(w->dims().size() == 2 || w->dims().size() == 4, "Weights must be with 2 or 4 dimensions, i.e. OI or OIHW"); diff --git a/paddle/fluid/operators/fc_op.cc b/paddle/fluid/operators/fc_op.cc index 5fee30e146..e71f63c134 100644 --- a/paddle/fluid/operators/fc_op.cc +++ b/paddle/fluid/operators/fc_op.cc @@ -15,6 +15,8 @@ limitations under the License. */ #include "paddle/fluid/operators/fc_op.h" #include +DECLARE_int32(paddle_num_threads); + namespace paddle { namespace operators { @@ -25,25 +27,23 @@ void FCOp::InferShape(framework::InferShapeContext* ctx) const { "Out(Output) of Fully Connected should not be null."); PADDLE_ENFORCE(ctx->HasInput("W"), "W(Input) of Fully Connected should not be null."); - + // NCHW auto in_dims = ctx->GetInputDim("Input"); + // IO, I=C*H*W auto w_dims = ctx->GetInputDim("W"); std::vector output_shape({in_dims[0], w_dims[1]}); if (ctx->HasInput("Bias")) { auto bias_dims = ctx->GetInputDim("Bias"); PADDLE_ENFORCE_EQ(bias_dims[0], 1, "The shape of Bias must be [1, dim]."); - PADDLE_ENFORCE_EQ(bias_dims[1], framework::product(w_dims) / w_dims[0], + PADDLE_ENFORCE_EQ(bias_dims[1], w_dims[1], "The shape of Bias must be [1, dim]."); } PADDLE_ENFORCE(in_dims.size() == 2 || in_dims.size() == 4, "Fully Connected input should be 2-D or 4-D tensor."); - - PADDLE_ENFORCE(w_dims.size() == 2 || w_dims.size() == 4, - "Fully Connected input should be 2-D or 4-D tensor."); - - PADDLE_ENFORCE_EQ(framework::product(w_dims) / w_dims[0], - framework::product(in_dims) / in_dims[0], + PADDLE_ENFORCE_EQ(w_dims.size(), 2UL, + "Fully Connected input should be 2-D tensor."); + PADDLE_ENFORCE_EQ(framework::product(in_dims) / in_dims[0], w_dims[0], "Fully Connected input and weigth size do not match."); ctx->SetOutputDim("Out", framework::make_ddim(output_shape)); @@ -54,7 +54,7 @@ framework::OpKernelType FCOp::GetExpectedKernelType( const framework::ExecutionContext& ctx) const { framework::LibraryType library = framework::LibraryType::kPlain; framework::DataLayout layout = framework::DataLayout::kAnyLayout; - if (ctx.Attr("use_mkldnn");) { + if (ctx.Attr("use_mkldnn")) { library = framework::LibraryType::kMKLDNN; layout = framework::DataLayout::kMKLDNN; } @@ -75,8 +75,9 @@ void FCOpGrad::InferShape(framework::InferShapeContext* ctx) const { } if (ctx->HasInput("Bias")) { + PADDLE_ENFORCE(ctx->HasOutput(framework::GradVarName("Bias")), + "Should have bias grad"); auto bias_dims = ctx->GetInputDim("Bias"); - PADDLE_ENFORCE(ctx->HasOutput(framework::GradVarName("Bias")); ctx->SetOutputDim(framework::GradVarName("Bias"), bias_dims); } } @@ -85,7 +86,7 @@ framework::OpKernelType FCOpGrad::GetExpectedKernelType( const framework::ExecutionContext& ctx) const { framework::LibraryType library = framework::LibraryType::kPlain; framework::DataLayout layout = framework::DataLayout::kAnyLayout; - if (ctx.Attr("use_mkldnn");) { + if (ctx.Attr("use_mkldnn")) { library = framework::LibraryType::kMKLDNN; layout = framework::DataLayout::kMKLDNN; } @@ -95,9 +96,11 @@ framework::OpKernelType FCOpGrad::GetExpectedKernelType( } void FCOpMaker::Make() { - AddInput("Input", "(Tensor) The input tensor of fully connected operator. "); - AddInput("W", "(Tensor), The second input tensor of fc op."); - AddInput("Bias", "(Tensor, optional) Bias vector with shape (1 x D") + AddInput("Input", + "(Tensor), The input tensor of fully connected operator with format " + "(NCHW). 
"); + AddInput("W", "(Tensor), The weight fc op with shape (I, O)."); + AddInput("Bias", "(Tensor, optional) Bias vector with shape (1 x O") .AsDispensable(); AddOutput("Out", "(Tensor) The output tensor of fully connected operator. "); AddAttr("use_mkldnn", @@ -120,25 +123,32 @@ template class FCOpKernel : public framework::OpKernel { public: void Compute(const paddle::framework::ExecutionContext& ctx) const override { - PADDLE_ENFORCE(paddle::platform::is_cpu_place(ctx.GetPlace()), + PADDLE_ENFORCE(platform::is_cpu_place(ctx.GetPlace()), "It must use CPUPlace."); - auto& dev_ctx = ctx.template device_context(); - auto blas = math::GetBlas(dev_ctx); auto input = ctx.Input("Input"); auto w = ctx.Input("W"); auto b = ctx.Input("Bias"); + auto output = ctx.Output("Out"); + auto in_dims = ctx->GetInputDim("Input"); + auto w_dims = ctx->GetInputDim("W"); + auto& dev_ctx = ctx.template device_context(); + auto blas = math::GetBlas(dev_ctx); const T* input_data = input->data(); const T* w_data = w->data(); - auto output = ctx.Output("Out"); T* output_data = output->mutable_data(ctx.GetPlace()); - auto in_dims = ctx->GetInputDim("Input"); - auto w_dims = ctx->GetInputDim("W"); - std::vector output_shape({in_dims[0], w_dims[1]}); + blas.GEMM(CblasNoTrans, CblasNoTrans, in_dims[0], w_dims[1], w_dims[0], + static_cast(1), input_data, w_data, static_cast(0), + output_data); if (bias) { const T* bias_data = bias->data(); +#pragma omp parallel for if (FLAGS_paddle_num_threads > 1) + for (int bs = 0; bs < in_dims[0]; bs++) { + blas.AXPY(w_dims[1], static_cast(1), bias_data, + output_data + bs * w_dimws[1]); + } } } }; @@ -150,5 +160,4 @@ namespace ops = paddle::operators; REGISTER_OPERATOR(fc, ops::FCOp, ops::FCOpMaker, paddle::framework::DefaultGradOpDescMaker); REGISTER_OPERATOR(fc_grad, ops::FCOpGrad); -REGISTER_OP_CPU_KERNEL(fc, ops::FCMKLDNNOpKernel, - ops::FCMKLDNNOpKernel); +REGISTER_OP_CPU_KERNEL(fc, ops::FCOpKernel, ops::FCOpKernel); From 4b5986bb77b06432f44bcd7f1e9352f8ca5dae2f Mon Sep 17 00:00:00 2001 From: tensor-tang Date: Tue, 14 Aug 2018 13:36:03 +0800 Subject: [PATCH 12/29] enable fc op in normal case --- paddle/fluid/operators/CMakeLists.txt | 6 ------ paddle/fluid/operators/fc_op.cc | 13 +++++++------ 2 files changed, 7 insertions(+), 12 deletions(-) diff --git a/paddle/fluid/operators/CMakeLists.txt b/paddle/fluid/operators/CMakeLists.txt index 4c3b8ec781..8cd80ca6be 100644 --- a/paddle/fluid/operators/CMakeLists.txt +++ b/paddle/fluid/operators/CMakeLists.txt @@ -295,12 +295,6 @@ op_library(channel_recv_op DEPS concurrency) list(REMOVE_ITEM GENERAL_OPS ${DEPS_OPS}) -# The fully connected layer is deleted when the WITH_MKLDNN flag is OFF -# Because the fully connected layer has only one MKLDNN's operator -if(NOT WITH_MKLDNN) - list(REMOVE_ITEM GENERAL_OPS fc_op) -endif(NOT WITH_MKLDNN) - foreach(src ${GENERAL_OPS}) op_library(${src}) endforeach() diff --git a/paddle/fluid/operators/fc_op.cc b/paddle/fluid/operators/fc_op.cc index e71f63c134..ec8dfb659c 100644 --- a/paddle/fluid/operators/fc_op.cc +++ b/paddle/fluid/operators/fc_op.cc @@ -14,6 +14,7 @@ limitations under the License. 
*/ #include "paddle/fluid/operators/fc_op.h" #include +#include "paddle/fluid/operators/math/blas.h" DECLARE_int32(paddle_num_threads); @@ -127,13 +128,13 @@ class FCOpKernel : public framework::OpKernel { "It must use CPUPlace."); auto input = ctx.Input("Input"); auto w = ctx.Input("W"); - auto b = ctx.Input("Bias"); + auto bias = ctx.Input("Bias"); auto output = ctx.Output("Out"); - auto in_dims = ctx->GetInputDim("Input"); - auto w_dims = ctx->GetInputDim("W"); + auto in_dims = input->dims(); + auto w_dims = w->dims(); - auto& dev_ctx = ctx.template device_context(); - auto blas = math::GetBlas(dev_ctx); + auto& dev_ctx = ctx.template device_context(); + auto blas = math::GetBlas(dev_ctx); const T* input_data = input->data(); const T* w_data = w->data(); T* output_data = output->mutable_data(ctx.GetPlace()); @@ -147,7 +148,7 @@ class FCOpKernel : public framework::OpKernel { #pragma omp parallel for if (FLAGS_paddle_num_threads > 1) for (int bs = 0; bs < in_dims[0]; bs++) { blas.AXPY(w_dims[1], static_cast(1), bias_data, - output_data + bs * w_dimws[1]); + output_data + bs * w_dims[1]); } } } From 45d0259a6746d67f23f538b0b20f51a4af2f6d3f Mon Sep 17 00:00:00 2001 From: tensor-tang Date: Tue, 14 Aug 2018 15:05:04 +0800 Subject: [PATCH 13/29] add fc forward test --- .../fluid/tests/unittests/test_fc_op.py | 90 +++++++++++++++++++ 1 file changed, 90 insertions(+) create mode 100644 python/paddle/fluid/tests/unittests/test_fc_op.py diff --git a/python/paddle/fluid/tests/unittests/test_fc_op.py b/python/paddle/fluid/tests/unittests/test_fc_op.py new file mode 100644 index 0000000000..2bb920710a --- /dev/null +++ b/python/paddle/fluid/tests/unittests/test_fc_op.py @@ -0,0 +1,90 @@ +# Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +import unittest +import numpy as np +from op_test import OpTest + + +def fc_refer(matrix, with_bias): + in_n, in_c, in_h, in_w = matrix.input.shape + w_i, w_o = matrix.weights.shape + + x_data = np.reshape(matrix.input, [in_n, in_c * in_h * in_w]) + w_data = np.reshape(matrix.weights, [w_i, w_o]) + b_data = np.reshape(matrix.bias, [1, w_o]) + result = None + + if with_bias: + result = np.dot(x_data, w_data) + b_data + else: + result = np.dot(x_data, w_data) + + return result + + +class MatrixGenerate: + def __init__(self, mb, ic, oc, h, w): + self.input = np.random.random((mb, ic, h, w)).astype("float32") + self.weights = np.random.random((ic * h * w, oc)).astype("float32") + self.bias = np.random.random((1, oc)).astype("float32") + + +class TestFCOp(OpTest): + def setUp(self): + self.op_type = "fc" + self.matrix = MatrixGenerate(1, 10, 15, 3, 3) + + self.with_bias = True + if self.with_bias: + self.inputs = { + 'Input': self.matrix.input, + 'W': self.matrix.weights, + 'Bias': self.matrix.bias + } + else: + self.inputs = {'Input': self.matrix.input, 'W': self.matrix.weights} + + self.attrs = {'use_mkldnn': False} + + self.outputs = {'Out': fc_refer(self.matrix, self.with_bias)} + + def test_check_output(self): + self.check_output() + + +class TestFCOpBiasBoth(TestFCOp): + def init_shapes(self, mb, ic, oc, h, w): + for with_bias in {True, False}: + self.with_bias = with_bias + self.matrix = MatrixGenerate(mb, ic, oc, h, w) + + +class TestFCOp1(TestFCOpBiasBoth): + def init_op_type(self): + self.init_shapes(2, 8, 10, 1, 1) + + +class TestFCOp2(TestFCOpBiasBoth): + def init_op_type(self): + self.init_shapes(4, 5, 6, 2, 2) + + +class TestFCOp4(TestFCOpBiasBoth): + def init_op_type(self): + self.init_shapes(1, 32, 64, 3, 3) + + +if __name__ == "__main__": + unittest.main() From b9dbb7c5cbad8e25cb16af07af6b58764c27ae6e Mon Sep 17 00:00:00 2001 From: tensor-tang Date: Tue, 14 Aug 2018 15:47:15 +0800 Subject: [PATCH 14/29] fix bias attri in mkldnn fc --- paddle/fluid/operators/fc_mkldnn_op.cc | 10 +++++++--- .../paddle/fluid/tests/unittests/test_fc_mkldnn_op.py | 9 ++------- 2 files changed, 9 insertions(+), 10 deletions(-) diff --git a/paddle/fluid/operators/fc_mkldnn_op.cc b/paddle/fluid/operators/fc_mkldnn_op.cc index 68a47dd6ad..e595f1a627 100644 --- a/paddle/fluid/operators/fc_mkldnn_op.cc +++ b/paddle/fluid/operators/fc_mkldnn_op.cc @@ -125,14 +125,16 @@ class FCMKLDNNOpKernel : public paddle::framework::OpKernel { auto input = ctx.Input("Input"); auto w = ctx.Input("W"); + auto bias = ctx.Input("Bias"); PADDLE_ENFORCE(input->dims().size() == 2 || input->dims().size() == 4, "Input must be with 2 or 4 dimensions, i.e. NCHW"); - // TODO(intel): the src weight is io and mkldnn weight need be transposed ! + // TODO(intel friends): the native weight format is io, + // but the mkldnn weight format is oihw, which may need be transposed. PADDLE_ENFORCE(w->dims().size() == 2 || w->dims().size() == 4, "Weights must be with 2 or 4 dimensions, i.e. OI or OIHW"); - bool with_bias = ctx.Attr("bias_attr"); + bool with_bias = bias != nullptr; MKLDNNMD md(input, w, with_bias); std::shared_ptr pd = @@ -155,6 +157,7 @@ class FCMKLDNNOpKernel : public paddle::framework::OpKernel { auto dst_memory = mem.dst(output_data); auto src_memory = mem.src(input_data); auto weights_memory = mem.weights(w_data); + // TODO(intel friends): bias memory should also be obtain from bias->data() auto bias_memory = mem.bias(); auto forward = with_bias ? 
mkldnn::inner_product_forward( @@ -217,7 +220,8 @@ class FCMKLDNNGradOpKernel : public paddle::framework::OpKernel { const Tensor* out_grad = ctx.Input(framework::GradVarName("Out")); const T* out_grad_data = out_grad->data(); - bool with_bias = ctx.Attr("bias_attr"); + auto bias = ctx.Input("Bias"); + bool with_bias = bias != nullptr; MKLDNNMD md(input, w, with_bias); MKLDNNMemory mem(&md, mkldnn_engine); diff --git a/python/paddle/fluid/tests/unittests/test_fc_mkldnn_op.py b/python/paddle/fluid/tests/unittests/test_fc_mkldnn_op.py index 3f547f3c48..099e6e6064 100644 --- a/python/paddle/fluid/tests/unittests/test_fc_mkldnn_op.py +++ b/python/paddle/fluid/tests/unittests/test_fc_mkldnn_op.py @@ -22,6 +22,7 @@ def fully_connected_naive(input, weights, bias_data=None): w_h, w_c = weights.shape x_data = np.reshape(input, [in_n, in_c * in_h * in_w]) + # this transpose should be implemented at C code w_data = np.transpose(np.reshape(weights, (w_c, in_c * in_h * in_w))) result = None @@ -43,15 +44,11 @@ class TestFCMKLDNNOp(OpTest): def setUp(self): self.op_type = "fc" self.use_mkldnn = True - self.with_bias = True self.matrix = MatrixGenerate(1, 10, 15, 3, 3) self.inputs = {'Input': self.matrix.input, 'W': self.matrix.weights} - self.attrs = { - 'use_mkldnn': self.use_mkldnn, - 'with_bias': self.with_bias - } + self.attrs = {'use_mkldnn': self.use_mkldnn, } self.outputs = { 'Out': fully_connected_naive(self.matrix.input, self.matrix.weights) @@ -85,13 +82,11 @@ class TestFCMKLDNNOp3(TestFCMKLDNNOp): class TestFCMKLDNNOp4(TestFCMKLDNNOp): def init_op_type(self): - self.with_bias = False self.matrix = MatrixGenerate(2, 32, 48, 2, 2) class TestFCMKLDNNOp4(TestFCMKLDNNOp): def init_op_type(self): - self.with_bias = False self.matrix = MatrixGenerate(2, 32, 1000, 6, 6) From 742300baa8a24cea467a45cc55f63ca894b0625f Mon Sep 17 00:00:00 2001 From: tensor-tang Date: Tue, 14 Aug 2018 15:52:31 +0800 Subject: [PATCH 15/29] fix unkown omp pragmas --- paddle/fluid/operators/fc_op.cc | 2 ++ 1 file changed, 2 insertions(+) diff --git a/paddle/fluid/operators/fc_op.cc b/paddle/fluid/operators/fc_op.cc index ec8dfb659c..099ca52c8e 100644 --- a/paddle/fluid/operators/fc_op.cc +++ b/paddle/fluid/operators/fc_op.cc @@ -145,7 +145,9 @@ class FCOpKernel : public framework::OpKernel { if (bias) { const T* bias_data = bias->data(); +#ifdef PADDLE_WITH_MKLML #pragma omp parallel for if (FLAGS_paddle_num_threads > 1) +#endif for (int bs = 0; bs < in_dims[0]; bs++) { blas.AXPY(w_dims[1], static_cast(1), bias_data, output_data + bs * w_dims[1]); From 21d5b942282ae32bba4613b31f5429b65afc1532 Mon Sep 17 00:00:00 2001 From: chenweihang Date: Tue, 14 Aug 2018 08:24:21 +0000 Subject: [PATCH 16/29] error message refine: add demangle api to attribute type --- paddle/fluid/framework/CMakeLists.txt | 2 + paddle/fluid/framework/attribute.h | 8 +- paddle/fluid/framework/attribute_type.h | 97 +++++++++++++++++++ paddle/fluid/framework/attribute_type_test.cc | 46 +++++++++ 4 files changed, 150 insertions(+), 3 deletions(-) create mode 100644 paddle/fluid/framework/attribute_type.h create mode 100644 paddle/fluid/framework/attribute_type_test.cc diff --git a/paddle/fluid/framework/CMakeLists.txt b/paddle/fluid/framework/CMakeLists.txt index 6440607dbe..b3fe2d97a8 100644 --- a/paddle/fluid/framework/CMakeLists.txt +++ b/paddle/fluid/framework/CMakeLists.txt @@ -115,6 +115,8 @@ cc_test(cow_ptr_tests SRCS details/cow_ptr_test.cc) # cc_test(channel_test SRCS channel_test.cc) cc_test(tuple_test SRCS tuple_test.cc ) 
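+# attribute_type_test exercises the demangle() helper added by this patch.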
+cc_test(attribute_type_test SRCS attribute_type_test.cc) + # disable test temporarily. # TODO https://github.com/PaddlePaddle/Paddle/issues/11971 # cc_test(concurrency_test SRCS concurrency_test.cc DEPS go_op channel_close_op channel_create_op diff --git a/paddle/fluid/framework/attribute.h b/paddle/fluid/framework/attribute.h index 8428bf8e33..2b05528257 100644 --- a/paddle/fluid/framework/attribute.h +++ b/paddle/fluid/framework/attribute.h @@ -20,6 +20,7 @@ limitations under the License. */ #include #include +#include "paddle/fluid/framework/attribute_type.h" #include "paddle/fluid/framework/framework.pb.h" #include "paddle/fluid/framework/type_defs.h" #include "paddle/fluid/platform/enforce.h" @@ -128,7 +129,8 @@ struct ExtractAttribute { attr_value = &boost::get(attr); } catch (boost::bad_get& bad_get) { PADDLE_THROW("Cannot get attribute %s by type %s, its type is %s", - attr_name_, typeid(T).name(), attr.type().name()); + attr_name_, paddle::framework::demangle(typeid(T).name()), + paddle::framework::demangle(attr.type().name())); } return attr_value; } @@ -160,7 +162,7 @@ struct ExtractAttribute { attr_value = &boost::get(attr); } catch (boost::bad_get& bad_get) { PADDLE_THROW("Cannot get attribute %s by type bool, its type is %s", - attr_name_, attr.type().name()); + attr_name_, paddle::framework::demangle(attr.type().name())); } return attr_value; } @@ -186,7 +188,7 @@ struct ExtractAttribute { attr_value = &boost::get(attr); } catch (boost::bad_get& bad_get) { PADDLE_THROW("Cannot get attribute %s by type int64_t, its type is %s", - attr_name_, attr.type().name()); + attr_name_, paddle::framework::demangle(attr.type().name())); } return attr_value; } diff --git a/paddle/fluid/framework/attribute_type.h b/paddle/fluid/framework/attribute_type.h new file mode 100644 index 0000000000..337dcde775 --- /dev/null +++ b/paddle/fluid/framework/attribute_type.h @@ -0,0 +1,97 @@ +/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. */ + +#pragma once + +#include + +// __has_include is currently supported by GCC and Clang. However GCC 4.9 may +// have issues and +// returns 1 for 'defined( __has_include )', while '__has_include' is actually +// not supported: +#if defined(__has_include) && (!defined(BOOST_GCC) || (__GNUC__ + 0) >= 5) +#if __has_include() +#define PADDLE_FRAMEWORK_HAS_CXXABI_H +#endif +#elif defined(__GLIBCXX__) || defined(__GLIBCPP__) +#define PADDLE_FRAMEWORK_HAS_CXXABI_H +#endif + +#if defined(PADDLE_FRAMEWORK_HAS_CXXABI_H) +#include +// For some archtectures (mips, mips64, x86, x86_64) cxxabi.h in Android NDK is +// implemented by gabi++ library +// which does not implement abi::__cxa_demangle(). We detect this implementation +// by checking the include guard here. 
+#if defined(__GABIXX_CXXABI_H__) +#undef PADDLE_FRAMEWORK_HAS_CXXABI_H +#else +#include +#include +#endif +#endif + +namespace paddle { +namespace framework { + +inline char const* demangle_alloc(char const* name); +inline void demangle_free(char const* name); + +class scoped_demangled_name { + private: + char const* m_p; + + public: + explicit scoped_demangled_name(char const* name) + : m_p(demangle_alloc(name)) {} + + ~scoped_demangled_name() { demangle_free(m_p); } + + char const* get() const { return m_p; } + + scoped_demangled_name(scoped_demangled_name const&) = delete; + scoped_demangled_name& operator=(scoped_demangled_name const&) = delete; +}; + +#if defined(PADDLE_FRAMEWORK_HAS_CXXABI_H) + +inline char const* demangle_alloc(char const* name) { + int status = 0; + std::size_t size = 0; + return abi::__cxa_demangle(name, NULL, &size, &status); +} + +inline void demangle_free(char const* name) { + std::free(const_cast(name)); +} + +inline std::string demangle(char const* name) { + scoped_demangled_name demangled_name(name); + char const* p = demangled_name.get(); + if (!p) p = name; + return p; +} + +#else + +inline char const* demangle_alloc(char const* name) { return name; } + +inline void demangle_free(char const*) {} + +inline std::string demangle(char const* name) { return name; } + +#endif + +} // namespace framework +} // namespace paddle diff --git a/paddle/fluid/framework/attribute_type_test.cc b/paddle/fluid/framework/attribute_type_test.cc new file mode 100644 index 0000000000..82537b8a0f --- /dev/null +++ b/paddle/fluid/framework/attribute_type_test.cc @@ -0,0 +1,46 @@ +/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
*/ + +#include +#include + +#include "gtest/gtest.h" +#include "paddle/fluid/framework/attribute_type.h" + +TEST(Attribute, TypeName) { + bool boolean; + int integer; + float ft; + std::string str; + std::vector booleans; + std::vector integers; + std::vector strings; + + EXPECT_EQ("bool", paddle::framework::demangle(typeid(boolean).name())); + EXPECT_EQ("int", paddle::framework::demangle(typeid(integer).name())); + EXPECT_EQ("float", paddle::framework::demangle(typeid(ft).name())); + EXPECT_EQ( + "std::__cxx11::basic_string, " + "std::allocator >", + paddle::framework::demangle(typeid(str).name())); + EXPECT_EQ("std::vector >", + paddle::framework::demangle(typeid(booleans).name())); + EXPECT_EQ("std::vector >", + paddle::framework::demangle(typeid(integers).name())); + EXPECT_EQ( + "std::vector, " + "std::allocator >, std::allocator, std::allocator > > >", + paddle::framework::demangle(typeid(strings).name())); +} From da39d84a48d1445d6bb9fb10e8d7d17d9053c7b7 Mon Sep 17 00:00:00 2001 From: chenweihang Date: Tue, 14 Aug 2018 09:55:47 +0000 Subject: [PATCH 17/29] refine by reviewer's advice --- paddle/fluid/platform/enforce.h | 4 ++-- paddle/fluid/platform/enforce_test.cc | 14 +++++++------- paddle/fluid/platform/gpu_info.cc | 21 +++++++++++---------- 3 files changed, 20 insertions(+), 19 deletions(-) diff --git a/paddle/fluid/platform/enforce.h b/paddle/fluid/platform/enforce.h index cad60275a2..81b5359b40 100644 --- a/paddle/fluid/platform/enforce.h +++ b/paddle/fluid/platform/enforce.h @@ -263,7 +263,7 @@ inline void throw_on_error(T e) { * PADDLE_ENFORCE_EQ(a, b); * * will raise an expression described as follows: - * "Data check failed. Expected input a == b, but received a(1) != b(2)." + * "Enforce failed. Expected input a == b, but received a(1) != b(2)." * with detailed stack information. * * extra messages is also supported, for example: @@ -293,7 +293,7 @@ inline void throw_on_error(T e) { #define __PADDLE_BINARY_COMPARE(__VAL0, __VAL1, __CMP, __INV_CMP, ...) \ do { \ if (UNLIKELY(!((__VAL0)__CMP(__VAL1)))) { \ - PADDLE_THROW("Data check failed. Expected %s " #__CMP \ + PADDLE_THROW("Enforce failed. Expected %s " #__CMP \ " %s, but received %s:%s " #__INV_CMP " %s:%s.\n%s", \ #__VAL0, #__VAL1, #__VAL0, \ paddle::string::to_string(__VAL0), #__VAL1, \ diff --git a/paddle/fluid/platform/enforce_test.cc b/paddle/fluid/platform/enforce_test.cc index 8dcf39fdaa..d521829655 100644 --- a/paddle/fluid/platform/enforce_test.cc +++ b/paddle/fluid/platform/enforce_test.cc @@ -56,7 +56,7 @@ TEST(ENFORCE_EQ, NO_EXTRA_MSG_FAIL) { caught_exception = true; HasPrefix( StringPiece(error.what()), - "Data check failed. Expected a == 1 + 3, but received a:2 != 1 + 3:4."); + "Enforce failed. Expected a == 1 + 3, but received a:2 != 1 + 3:4."); } EXPECT_TRUE(caught_exception); } @@ -69,7 +69,7 @@ TEST(ENFORCE_EQ, EXTRA_MSG_FAIL) { } catch (paddle::platform::EnforceNotMet error) { caught_exception = true; HasPrefix(StringPiece(error.what()), - "Data check failed. Expected a == 1 + 3, but received a:2 != 1 + " + "Enforce failed. Expected a == 1 + 3, but received a:2 != 1 + " "3:4.\ntheir size not match"); } EXPECT_TRUE(caught_exception); @@ -89,7 +89,7 @@ TEST(ENFORCE_NE, FAIL) { caught_exception = true; EXPECT_TRUE(HasPrefix( StringPiece(error.what()), - "Data check failed. Expected 1.0 != 1UL, but received 1.0:1 == 1UL:1.")) + "Enforce failed. 
Expected 1.0 != 1UL, but received 1.0:1 == 1UL:1.")) << error.what() << " does not have expected prefix"; } EXPECT_TRUE(caught_exception); @@ -104,7 +104,7 @@ TEST(ENFORCE_GT, FAIL) { caught_exception = true; EXPECT_TRUE(HasPrefix( StringPiece(error.what()), - "Data check failed. Expected 1 > 2UL, but received 1:1 <= 2UL:2.")); + "Enforce failed. Expected 1 > 2UL, but received 1:1 <= 2UL:2.")); } EXPECT_TRUE(caught_exception); } @@ -123,7 +123,7 @@ TEST(ENFORCE_GE, FAIL) { caught_exception = true; EXPECT_TRUE(HasPrefix( StringPiece(error.what()), - "Data check failed. Expected 1 >= 2UL, but received 1:1 < 2UL:2.")); + "Enforce failed. Expected 1 >= 2UL, but received 1:1 < 2UL:2.")); } EXPECT_TRUE(caught_exception); } @@ -143,7 +143,7 @@ TEST(ENFORCE_LE, FAIL) { caught_exception = true; EXPECT_TRUE(HasPrefix( StringPiece(error.what()), - "Data check failed. Expected 1 > 2UL, but received 1:1 <= 2UL:2.")); + "Enforce failed. Expected 1 > 2UL, but received 1:1 <= 2UL:2.")); } EXPECT_TRUE(caught_exception); } @@ -160,7 +160,7 @@ TEST(ENFORCE_LT, FAIL) { } catch (paddle::platform::EnforceNotMet error) { caught_exception = true; EXPECT_TRUE(HasPrefix(StringPiece(error.what()), - "Data check failed. Expected 1UL < 0.12, but " + "Enforce failed. Expected 1UL < 0.12, but " "received 1UL:1 >= 0.12:0.12.")); } EXPECT_TRUE(caught_exception); diff --git a/paddle/fluid/platform/gpu_info.cc b/paddle/fluid/platform/gpu_info.cc index f9e2e8c69d..126636d879 100644 --- a/paddle/fluid/platform/gpu_info.cc +++ b/paddle/fluid/platform/gpu_info.cc @@ -100,25 +100,26 @@ size_t GpuMinChunkSize() { size_t GpuMaxChunkSize() { size_t total = 0; - size_t available_memory = 0; + size_t available = 0; - GpuMemoryUsage(&available_memory, &total); - VLOG(10) << "GPU Usage " << available_memory / 1024 / 1024 << "M/" + GpuMemoryUsage(&available, &total); + VLOG(10) << "GPU Usage " << available / 1024 / 1024 << "M/" << total / 1024 / 1024 << "M"; size_t reserving = static_cast(0.05 * total); // If available less than minimum chunk size, no usable memory exists. - available_memory = std::min( - std::max(available_memory, GpuMinChunkSize()) - GpuMinChunkSize(), - total - reserving); + available = + std::min(std::max(available, GpuMinChunkSize()) - GpuMinChunkSize(), + total - reserving); // Reserving the rest memory for page tables, etc. 
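+  // In effect:
+  //   available  = min(max(available, min_chunk) - min_chunk, total - reserving)
+  //   allocating = fraction_of_gpu_memory_to_use * (total - reserving)
+  // where reserving is fixed at 5% of total memory, and allocating is
+  // checked below against what is actually available.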
- size_t allocating_memory = static_cast( - FLAGS_fraction_of_gpu_memory_to_use * (total - reserving)); + size_t allocating = static_cast(FLAGS_fraction_of_gpu_memory_to_use * + (total - reserving)); - PADDLE_ENFORCE_LE(allocating_memory, available_memory); + PADDLE_ENFORCE_LE(allocating, available, + "Insufficient GPU memory to allocation."); - return allocating_memory; + return allocating; } void GpuMemcpyAsync(void *dst, const void *src, size_t count, From 7797e55f425d22a4fa812c9ac14eb906828828f2 Mon Sep 17 00:00:00 2001 From: chenweihang Date: Tue, 14 Aug 2018 11:28:07 +0000 Subject: [PATCH 18/29] use paddle::platform::demangle --- paddle/fluid/framework/CMakeLists.txt | 2 - paddle/fluid/framework/attribute.h | 9 +- paddle/fluid/framework/attribute_type.h | 97 ------------------- paddle/fluid/framework/attribute_type_test.cc | 46 --------- 4 files changed, 4 insertions(+), 150 deletions(-) delete mode 100644 paddle/fluid/framework/attribute_type.h delete mode 100644 paddle/fluid/framework/attribute_type_test.cc diff --git a/paddle/fluid/framework/CMakeLists.txt b/paddle/fluid/framework/CMakeLists.txt index b3fe2d97a8..6440607dbe 100644 --- a/paddle/fluid/framework/CMakeLists.txt +++ b/paddle/fluid/framework/CMakeLists.txt @@ -115,8 +115,6 @@ cc_test(cow_ptr_tests SRCS details/cow_ptr_test.cc) # cc_test(channel_test SRCS channel_test.cc) cc_test(tuple_test SRCS tuple_test.cc ) -cc_test(attribute_type_test SRCS attribute_type_test.cc) - # disable test temporarily. # TODO https://github.com/PaddlePaddle/Paddle/issues/11971 # cc_test(concurrency_test SRCS concurrency_test.cc DEPS go_op channel_close_op channel_create_op diff --git a/paddle/fluid/framework/attribute.h b/paddle/fluid/framework/attribute.h index 2b05528257..14ca3e9620 100644 --- a/paddle/fluid/framework/attribute.h +++ b/paddle/fluid/framework/attribute.h @@ -20,7 +20,6 @@ limitations under the License. */ #include #include -#include "paddle/fluid/framework/attribute_type.h" #include "paddle/fluid/framework/framework.pb.h" #include "paddle/fluid/framework/type_defs.h" #include "paddle/fluid/platform/enforce.h" @@ -129,8 +128,8 @@ struct ExtractAttribute { attr_value = &boost::get(attr); } catch (boost::bad_get& bad_get) { PADDLE_THROW("Cannot get attribute %s by type %s, its type is %s", - attr_name_, paddle::framework::demangle(typeid(T).name()), - paddle::framework::demangle(attr.type().name())); + attr_name_, paddle::platform::demangle(typeid(T).name()), + paddle::platform::demangle(attr.type().name())); } return attr_value; } @@ -162,7 +161,7 @@ struct ExtractAttribute { attr_value = &boost::get(attr); } catch (boost::bad_get& bad_get) { PADDLE_THROW("Cannot get attribute %s by type bool, its type is %s", - attr_name_, paddle::framework::demangle(attr.type().name())); + attr_name_, paddle::platform::demangle(attr.type().name())); } return attr_value; } @@ -188,7 +187,7 @@ struct ExtractAttribute { attr_value = &boost::get(attr); } catch (boost::bad_get& bad_get) { PADDLE_THROW("Cannot get attribute %s by type int64_t, its type is %s", - attr_name_, paddle::framework::demangle(attr.type().name())); + attr_name_, paddle::platform::demangle(attr.type().name())); } return attr_value; } diff --git a/paddle/fluid/framework/attribute_type.h b/paddle/fluid/framework/attribute_type.h deleted file mode 100644 index 337dcde775..0000000000 --- a/paddle/fluid/framework/attribute_type.h +++ /dev/null @@ -1,97 +0,0 @@ -/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved. 
- -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. */ - -#pragma once - -#include - -// __has_include is currently supported by GCC and Clang. However GCC 4.9 may -// have issues and -// returns 1 for 'defined( __has_include )', while '__has_include' is actually -// not supported: -#if defined(__has_include) && (!defined(BOOST_GCC) || (__GNUC__ + 0) >= 5) -#if __has_include() -#define PADDLE_FRAMEWORK_HAS_CXXABI_H -#endif -#elif defined(__GLIBCXX__) || defined(__GLIBCPP__) -#define PADDLE_FRAMEWORK_HAS_CXXABI_H -#endif - -#if defined(PADDLE_FRAMEWORK_HAS_CXXABI_H) -#include -// For some archtectures (mips, mips64, x86, x86_64) cxxabi.h in Android NDK is -// implemented by gabi++ library -// which does not implement abi::__cxa_demangle(). We detect this implementation -// by checking the include guard here. -#if defined(__GABIXX_CXXABI_H__) -#undef PADDLE_FRAMEWORK_HAS_CXXABI_H -#else -#include -#include -#endif -#endif - -namespace paddle { -namespace framework { - -inline char const* demangle_alloc(char const* name); -inline void demangle_free(char const* name); - -class scoped_demangled_name { - private: - char const* m_p; - - public: - explicit scoped_demangled_name(char const* name) - : m_p(demangle_alloc(name)) {} - - ~scoped_demangled_name() { demangle_free(m_p); } - - char const* get() const { return m_p; } - - scoped_demangled_name(scoped_demangled_name const&) = delete; - scoped_demangled_name& operator=(scoped_demangled_name const&) = delete; -}; - -#if defined(PADDLE_FRAMEWORK_HAS_CXXABI_H) - -inline char const* demangle_alloc(char const* name) { - int status = 0; - std::size_t size = 0; - return abi::__cxa_demangle(name, NULL, &size, &status); -} - -inline void demangle_free(char const* name) { - std::free(const_cast(name)); -} - -inline std::string demangle(char const* name) { - scoped_demangled_name demangled_name(name); - char const* p = demangled_name.get(); - if (!p) p = name; - return p; -} - -#else - -inline char const* demangle_alloc(char const* name) { return name; } - -inline void demangle_free(char const*) {} - -inline std::string demangle(char const* name) { return name; } - -#endif - -} // namespace framework -} // namespace paddle diff --git a/paddle/fluid/framework/attribute_type_test.cc b/paddle/fluid/framework/attribute_type_test.cc deleted file mode 100644 index 82537b8a0f..0000000000 --- a/paddle/fluid/framework/attribute_type_test.cc +++ /dev/null @@ -1,46 +0,0 @@ -/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. 
*/ - -#include -#include - -#include "gtest/gtest.h" -#include "paddle/fluid/framework/attribute_type.h" - -TEST(Attribute, TypeName) { - bool boolean; - int integer; - float ft; - std::string str; - std::vector booleans; - std::vector integers; - std::vector strings; - - EXPECT_EQ("bool", paddle::framework::demangle(typeid(boolean).name())); - EXPECT_EQ("int", paddle::framework::demangle(typeid(integer).name())); - EXPECT_EQ("float", paddle::framework::demangle(typeid(ft).name())); - EXPECT_EQ( - "std::__cxx11::basic_string, " - "std::allocator >", - paddle::framework::demangle(typeid(str).name())); - EXPECT_EQ("std::vector >", - paddle::framework::demangle(typeid(booleans).name())); - EXPECT_EQ("std::vector >", - paddle::framework::demangle(typeid(integers).name())); - EXPECT_EQ( - "std::vector, " - "std::allocator >, std::allocator, std::allocator > > >", - paddle::framework::demangle(typeid(strings).name())); -} From e38eca26e2d307ee4f4f0303970d1072ea5f56b9 Mon Sep 17 00:00:00 2001 From: minqiyang Date: Tue, 14 Aug 2018 19:33:04 +0800 Subject: [PATCH 19/29] Add libpng dependencies to yum Correct libnccl dir --- tools/manylinux1/Dockerfile.x64 | 2 +- tools/manylinux1/build_scripts/build.sh | 2 +- tools/manylinux1/build_scripts/install_nccl2.sh | 2 +- 3 files changed, 3 insertions(+), 3 deletions(-) diff --git a/tools/manylinux1/Dockerfile.x64 b/tools/manylinux1/Dockerfile.x64 index bca0b77ad7..34c54303bd 100644 --- a/tools/manylinux1/Dockerfile.x64 +++ b/tools/manylinux1/Dockerfile.x64 @@ -13,7 +13,7 @@ ENV PATH /opt/rh/devtoolset-2/root/usr/bin:$PATH ENV LD_LIBRARY_PATH /opt/rh/devtoolset-2/root/usr/lib64:/opt/rh/devtoolset-2/root/usr/lib:/usr/local/lib64:/usr/local/lib:${LD_LIBRARY_PATH} ENV PKG_CONFIG_PATH=/usr/local/lib/pkgconfig -RUN yum install -y sqlite-devel zlib-devel openssl-devel pcre-devel vim tk-devel tkinter libtool xz +RUN yum install -y sqlite-devel zlib-devel openssl-devel pcre-devel vim tk-devel tkinter libtool xz freetype-devel libpng-devel COPY build_scripts /build_scripts RUN bash build_scripts/build.sh && \ bash build_scripts/install_nccl2.sh && rm -r build_scripts diff --git a/tools/manylinux1/build_scripts/build.sh b/tools/manylinux1/build_scripts/build.sh index 93591fa9dd..d99d3db2ed 100644 --- a/tools/manylinux1/build_scripts/build.sh +++ b/tools/manylinux1/build_scripts/build.sh @@ -105,7 +105,7 @@ curl-config --features rm -rf /usr/local/ssl # Install patchelf (latest with unreleased bug fixes) -curl -sLO https://nipy.bic.berkeley.edu/manylinux/patchelf-0.9njs2.tar.gz +curl -sLO http://nipy.bic.berkeley.edu/manylinux/patchelf-0.9njs2.tar.gz check_sha256sum patchelf-0.9njs2.tar.gz $PATCHELF_HASH tar -xzf patchelf-0.9njs2.tar.gz (cd patchelf-0.9njs2 && ./configure && make && make install) diff --git a/tools/manylinux1/build_scripts/install_nccl2.sh b/tools/manylinux1/build_scripts/install_nccl2.sh index 282c5c290d..43a99d8287 100644 --- a/tools/manylinux1/build_scripts/install_nccl2.sh +++ b/tools/manylinux1/build_scripts/install_nccl2.sh @@ -21,5 +21,5 @@ for sub_deb in $DEBS; do ar x $sub_deb && tar xf data.tar.xz done mv -f usr/include/nccl.h /usr/local/include/ -mv -f usr/lib/libnccl* /usr/local/lib/ +mv -f usr/lib/x86_64-linux-gnu/libnccl* /usr/local/lib/ rm -rf $DIR From 3373535b213c7ad5c24121e9a4e56534bc40e05b Mon Sep 17 00:00:00 2001 From: luotao1 Date: Tue, 14 Aug 2018 16:08:36 +0800 Subject: [PATCH 20/29] fix specific cudnn include and library path --- cmake/configure.cmake | 4 ++++ cmake/external/anakin.cmake | 2 +- 2 files changed, 5 insertions(+), 
1 deletion(-)

diff --git a/cmake/configure.cmake b/cmake/configure.cmake
index c35096e09b..ae90a529b1 100644
--- a/cmake/configure.cmake
+++ b/cmake/configure.cmake
@@ -104,6 +104,10 @@ if(WITH_GPU)
     if(${CUDNN_MAJOR_VERSION} VERSION_LESS 7)
       message(FATAL_ERROR "Anakin needs CUDNN >= 7.0 to compile")
     endif()
+    set(ENV{CUDNN_INCLUDE_DIR} ${CUDNN_INCLUDE_DIR})
+    set(ENV{CUDNN_LIBRARY} ${CUDNN_LIBRARY})
+    message(STATUS "cudnn include header is ${CUDNN_INCLUDE_DIR}/cudnn.h")
+    message(STATUS "cudnn library is ${CUDNN_LIBRARY}")
   endif()
 elseif(WITH_AMD_GPU)
   add_definitions(-DPADDLE_WITH_HIP)
diff --git a/cmake/external/anakin.cmake b/cmake/external/anakin.cmake
index 403873a510..5de7ca8f46 100644
--- a/cmake/external/anakin.cmake
+++ b/cmake/external/anakin.cmake
@@ -37,7 +37,7 @@ ExternalProject_Add(
     ${EXTERNAL_PROJECT_LOG_ARGS}
     # TODO(luotao): use PaddlePaddle/Anakin later
     GIT_REPOSITORY "https://github.com/luotao1/Anakin"
-    GIT_TAG "3957ae9263eaa0b1986758dac60a88852afb09be"
+    GIT_TAG "842a89ae3747ede25d8acbc29030d2eb602ced1f"
     PREFIX ${ANAKIN_SOURCE_DIR}
     UPDATE_COMMAND ""
     CMAKE_ARGS -DUSE_GPU_PLACE=YES

From d84a1a0010fc038a7da2ee7cf3ebb4f93353f1a4 Mon Sep 17 00:00:00 2001
From: tensor-tang
Date: Wed, 15 Aug 2018 12:24:03 +0800
Subject: [PATCH 21/29] fc op use cpu only

---
 paddle/fluid/operators/CMakeLists.txt | 5 +++++
 1 file changed, 5 insertions(+)

diff --git a/paddle/fluid/operators/CMakeLists.txt b/paddle/fluid/operators/CMakeLists.txt
index ae37c70929..c3f7c42a82 100644
--- a/paddle/fluid/operators/CMakeLists.txt
+++ b/paddle/fluid/operators/CMakeLists.txt
@@ -158,6 +158,11 @@ function(op_library TARGET)
       file(APPEND ${pybind_file} "USE_OP_DEVICE_KERNEL(relu, MKLDNN);\n")
     else()
       file(APPEND ${pybind_file} "USE_OP_DEVICE_KERNEL(${TARGET}, MKLDNN);\n")
+      # HACK: fc only has a CPU kernel
+      if (${MKLDNN_FILE} STREQUAL "fc_mkldnn_op")
+        file(APPEND ${pybind_file} "USE_CPU_ONLY_OP(${TARGET});\n")
+        set(pybind_flag 1)
+      endif()
     endif()
   endif()

From eee38464dc5477480fd36e57305f36c9519c9c00 Mon Sep 17 00:00:00 2001
From: tensor-tang
Date: Wed, 15 Aug 2018 13:36:32 +0800
Subject: [PATCH 22/29] refine fc op use cpu only

---
 paddle/fluid/operators/CMakeLists.txt | 8 +++-----
 1 file changed, 3 insertions(+), 5 deletions(-)

diff --git a/paddle/fluid/operators/CMakeLists.txt b/paddle/fluid/operators/CMakeLists.txt
index c3f7c42a82..e8b5dec9d4 100644
--- a/paddle/fluid/operators/CMakeLists.txt
+++ b/paddle/fluid/operators/CMakeLists.txt
@@ -158,11 +158,6 @@ function(op_library TARGET)
       file(APPEND ${pybind_file} "USE_OP_DEVICE_KERNEL(relu, MKLDNN);\n")
     else()
       file(APPEND ${pybind_file} "USE_OP_DEVICE_KERNEL(${TARGET}, MKLDNN);\n")
-      # HACK: fc only has a CPU kernel
-      if (${MKLDNN_FILE} STREQUAL "fc_mkldnn_op")
-        file(APPEND ${pybind_file} "USE_CPU_ONLY_OP(${TARGET});\n")
-        set(pybind_flag 1)
-      endif()
     endif()
   endif()
@@ -175,6 +170,9 @@ function(op_library TARGET)
     file(APPEND ${pybind_file} "USE_OP(fake_dequantize_max_abs);\n")
   elseif(${TARGET} STREQUAL "tensorrt_engine_op")
     message(STATUS "Pybind skips [tensorrt_engine_op], for this OP is only used in inference")
+  elseif(${TARGET} STREQUAL "fc")
+    # HACK: fc has only MKLDNN and plain CPU kernels, which would mismatch the CPU-only condition
+    file(APPEND ${pybind_file} "USE_CPU_ONLY_OP(${TARGET});\n")
   else()
     file(APPEND ${pybind_file} "USE_OP(${TARGET});\n")
   endif()

From d06849305a67d6645699384ae87ec1870e5756e3 Mon Sep 17 00:00:00 2001
From: gongweibao
Date: Wed, 15 Aug 2018 21:17:14 +0800
Subject: [PATCH 23/29] parameter dispatcher.
(#12666) --- paddle/fluid/framework/threadpool.cc | 7 ++ .../distributed/variable_response.cc | 7 +- paddle/fluid/operators/listen_and_serv_op.cc | 5 +- python/paddle/fluid/__init__.py | 2 +- python/paddle/fluid/initializer.py | 1 - .../fluid/tests/unittests/CMakeLists.txt | 4 +- .../fluid/tests/unittests/test_dist_train.py | 17 +++ .../tests/unittests/test_dist_transpiler.py | 50 ++++++--- .../fluid/transpiler/distribute_transpiler.py | 100 +++++++++++++++--- 9 files changed, 162 insertions(+), 31 deletions(-) diff --git a/paddle/fluid/framework/threadpool.cc b/paddle/fluid/framework/threadpool.cc index f26f212d4d..18cdca3a65 100644 --- a/paddle/fluid/framework/threadpool.cc +++ b/paddle/fluid/framework/threadpool.cc @@ -20,6 +20,9 @@ DEFINE_int32(io_threadpool_size, 100, "number of threads used for doing IO, default 100"); +DEFINE_int32(dist_threadpool_size, 0, + "number of threads used for distributed executed."); + namespace paddle { namespace framework { @@ -35,6 +38,10 @@ void ThreadPool::Init() { if (threadpool_.get() == nullptr) { // TODO(Yancey1989): specify the max threads number int num_threads = std::thread::hardware_concurrency(); + if (FLAGS_dist_threadpool_size > 0) { + num_threads = FLAGS_dist_threadpool_size; + VLOG(1) << "set dist_threadpool_size to " << num_threads; + } PADDLE_ENFORCE_GT(num_threads, 0); threadpool_.reset(new ThreadPool(num_threads)); } diff --git a/paddle/fluid/operators/distributed/variable_response.cc b/paddle/fluid/operators/distributed/variable_response.cc index 466bce18af..8e38b3713f 100644 --- a/paddle/fluid/operators/distributed/variable_response.cc +++ b/paddle/fluid/operators/distributed/variable_response.cc @@ -190,12 +190,15 @@ bool VariableResponse::ProcSerializedField( #endif } + VLOG(7) << "ProcSerializedField:" << meta_.varname() + << ", type:" << meta_.type() << std::endl; framework::DDim dims = GetDims(meta_.dims()); if (meta_.type() == sendrecv::LOD_TENSOR) { PADDLE_ENFORCE(meta_.lod_size() >= 0, "lod info should be got first!"); if (!CopyLodTensorData(input, *dev_ctx_, dims, num_bytes)) { return false; } + return true; } @@ -206,7 +209,9 @@ bool VariableResponse::ProcSerializedField( return true; } - return true; + PADDLE_ENFORCE("not supported var types:", meta_.varname(), meta_.type()); + + return false; } }; // namespace distributed diff --git a/paddle/fluid/operators/listen_and_serv_op.cc b/paddle/fluid/operators/listen_and_serv_op.cc index b194807696..f196e18fe1 100644 --- a/paddle/fluid/operators/listen_and_serv_op.cc +++ b/paddle/fluid/operators/listen_and_serv_op.cc @@ -123,8 +123,11 @@ void ListenAndServOp::RunSyncLoop( optimize_prepared.begin(), std::shared_ptr(nullptr)); + // Trainers will get all parameters from pserver in the + // startup program, so we will wait RequestGet first + rpc_service_->SetCond(distributed::kRequestGet); + rpc_service_->WaitBarrier(distributed::kRequestGet); rpc_service_->ResetBarrierCounter(); - while (true) { rpc_service_->Profiler().OneStep(); // Get from multiple trainers, we don't care about the order in which diff --git a/python/paddle/fluid/__init__.py b/python/paddle/fluid/__init__.py index 1ae05dec8d..9aac3c7fc1 100644 --- a/python/paddle/fluid/__init__.py +++ b/python/paddle/fluid/__init__.py @@ -122,7 +122,7 @@ def __bootstrap__(): 'use_pinned_memory', 'check_nan_inf', 'benchmark', 'warpctc_dir', 'eager_delete_scope', 'use_mkldnn', 'initial_cpu_memory_in_mb', 'init_allocated_mem', 'free_idle_memory', 'paddle_num_threads', - 'cpu_deterministic' + "dist_threadpool_size", 
'cpu_deterministic' ] if core.is_compiled_with_dist(): read_env_flags.append('rpc_deadline') diff --git a/python/paddle/fluid/initializer.py b/python/paddle/fluid/initializer.py index 3f740dd7c5..6dedbae7a6 100644 --- a/python/paddle/fluid/initializer.py +++ b/python/paddle/fluid/initializer.py @@ -15,7 +15,6 @@ from . import framework import numpy as np import contextlib -from .framework import convert_np_dtype_to_dtype_ from .core import VarDesc __all__ = [ diff --git a/python/paddle/fluid/tests/unittests/CMakeLists.txt b/python/paddle/fluid/tests/unittests/CMakeLists.txt index a6a911721d..e7dd85ef5c 100644 --- a/python/paddle/fluid/tests/unittests/CMakeLists.txt +++ b/python/paddle/fluid/tests/unittests/CMakeLists.txt @@ -59,8 +59,8 @@ py_test_modules(test_warpctc_op MODULES test_warpctc_op ENVS FLAGS_warpctc_dir=$ if(WITH_DISTRIBUTE) py_test_modules(test_dist_train MODULES test_dist_train SERIAL) set_tests_properties(test_listen_and_serv_op PROPERTIES TIMEOUT 20) - set_tests_properties(test_dist_mnist PROPERTIES TIMEOUT 180) - set_tests_properties(test_dist_word2vec PROPERTIES TIMEOUT 180) + set_tests_properties(test_dist_mnist PROPERTIES TIMEOUT 200) + set_tests_properties(test_dist_word2vec PROPERTIES TIMEOUT 200) endif() py_test_modules(test_parallel_executor_crf MODULES test_parallel_executor_crf SERIAL) py_test_modules(test_parallel_executor_fetch_feed MODULES test_parallel_executor_fetch_feed SERIAL) diff --git a/python/paddle/fluid/tests/unittests/test_dist_train.py b/python/paddle/fluid/tests/unittests/test_dist_train.py index aab8969a96..55aa923f5a 100644 --- a/python/paddle/fluid/tests/unittests/test_dist_train.py +++ b/python/paddle/fluid/tests/unittests/test_dist_train.py @@ -26,6 +26,12 @@ from paddle.fluid.layers.io import ListenAndServ from paddle.fluid.layers.io import Recv from paddle.fluid.layers.io import Send +from paddle.fluid import core + +RPC_OP_ROLE_ATTR_NAME = op_role_attr_name = core.op_proto_and_checker_maker.kOpRoleAttrName( +) +RPC_OP_ROLE_ATTR_VALUE = core.op_proto_and_checker_maker.OpRole.RPC + class TestSendOp(unittest.TestCase): def test_send(self): @@ -89,18 +95,29 @@ class TestSendOp(unittest.TestCase): def init_client(self, place, port): main = fluid.Program() with fluid.program_guard(main): + main.global_block().append_op( + type="fetch_barrier", + inputs={}, + outputs={}, + attrs={ + "endpoints": ["127.0.0.1:{0}".format(port)], + RPC_OP_ROLE_ATTR_NAME: RPC_OP_ROLE_ATTR_VALUE + }) + x = layers.data( shape=[32, 32], dtype='float32', name='X', append_batch_size=False) fluid.initializer.Constant(value=2.3)(x, main.global_block()) + get_var = main.global_block().create_var( name="scale_0.tmp_0", # server side var dtype="float32", persistable=False, shape=[32, 32]) fluid.initializer.Constant(value=2.3)(get_var, main.global_block()) + Send("127.0.0.1:%d" % port, [x]) o = Recv("127.0.0.1:%d" % port, [get_var]) diff --git a/python/paddle/fluid/tests/unittests/test_dist_transpiler.py b/python/paddle/fluid/tests/unittests/test_dist_transpiler.py index 55f8b3eff8..124abf4ccd 100644 --- a/python/paddle/fluid/tests/unittests/test_dist_transpiler.py +++ b/python/paddle/fluid/tests/unittests/test_dist_transpiler.py @@ -18,6 +18,7 @@ import unittest import paddle.fluid as fluid from paddle.fluid.transpiler.distribute_transpiler import delete_ops import traceback +import collections class TranspilerTest(unittest.TestCase): @@ -53,9 +54,18 @@ class TranspilerTest(unittest.TestCase): self.origin_prog = main.clone() return main - def get_trainer(self, config=None, 
sync_mode=True): - t = self._transpiler_instance(config, sync_mode) - return t.get_trainer_program() + def get_trainer(self, config=None): + src = fluid.default_startup_program().clone() + + t = self._transpiler_instance(config) + + trainer_main = t.get_trainer_program() + trainer_startup = fluid.default_startup_program() + + assert (src.num_blocks == 1) + assert (trainer_startup.num_blocks == src.num_blocks) + + return trainer_main, trainer_startup def get_pserver(self, ep, config=None, sync_mode=True): t = self._transpiler_instance(config, sync_mode) @@ -91,7 +101,21 @@ class TestBasicModel(TranspilerTest): pserver, startup = self.get_pserver(self.pserver1_ep) pserver2, startup2 = self.get_pserver(self.pserver2_ep) - trainer = self.get_trainer() + trainer, trainer_startup = self.get_trainer() + + # splited var blocks should be in startup program + self.assertTrue("fc_w.block0" in trainer_startup.global_block().vars) + self.assertTrue("fc_w.block1" in trainer_startup.global_block().vars) + self.assertTrue("fc_w" in trainer_startup.global_block().vars) + self.assertTrue("fc_b" in trainer_startup.global_block().vars) + self.assertTrue("fc_w@GRAD" not in trainer_startup.global_block().vars) + self.assertTrue("fc_b@GRAD" not in trainer_startup.global_block().vars) + + src = [op.type for op in trainer_startup.global_block().ops] + dst = ['fill_constant', 'fill_constant', 'uniform_random', 'recv', 'recv', \ + 'fetch_barrier', 'concat'] + + self.assertEqual(src, dst) self.assertEqual([op.type for op in trainer.global_block().ops], [ 'mul', 'elementwise_add', 'elementwise_sub', 'square', 'mean', @@ -142,7 +166,7 @@ class TestBasicModelWithLargeBlockSize(TranspilerTest): pserver, startup = self.get_pserver(self.pserver1_ep, config) pserver2, startup2 = self.get_pserver(self.pserver2_ep, config) - trainer = self.get_trainer(config) + trainer, _ = self.get_trainer(config) self.assertEqual([op.type for op in trainer.global_block().ops], [ 'mul', 'elementwise_add', 'elementwise_sub', 'square', 'mean', @@ -226,7 +250,7 @@ class TestLRDecay(TranspilerTest): def transpiler_test_impl(self): pserver, startup = self.get_pserver(self.pserver1_ep) - trainer = self.get_trainer() + trainer, _ = self.get_trainer() self.assertEqual(len(pserver.blocks), 4) lr_decay_ops = [op.type for op in pserver.blocks[1].ops] @@ -256,7 +280,7 @@ class TestLRDecayConditional(TranspilerTest): def transpiler_test_impl(self): pserver, startup = self.get_pserver(self.pserver1_ep) - trainer = self.get_trainer() + trainer, _ = self.get_trainer() serv_op = pserver.blocks[0].ops[0] sub_blocks = [] @@ -305,7 +329,7 @@ class TestL2Decay(TranspilerTest): def transpiler_test_impl(self): pserver, startup = self.get_pserver(self.pserver1_ep) - trainer = self.get_trainer() + trainer, _ = self.get_trainer() self.assertEqual(len(pserver.blocks), 3) self.assertEqual([op.type for op in pserver.blocks[1].ops], @@ -340,7 +364,7 @@ class TestL2DecayWithPiecewise(TranspilerTest): def transpiler_test_impl(self): pserver, startup = self.get_pserver(self.pserver1_ep) - trainer = self.get_trainer() + trainer, _ = self.get_trainer() self.assertEqual(len(pserver.blocks), 9) self.assertEqual([op.type for op in pserver.blocks[1].ops], [ @@ -415,7 +439,7 @@ class TestLocalLookupTable(TestDistLookupTableBase): self.assertEqual([op.type for op in pserver1.blocks[2].ops], ["sum", "adam", "scale", "scale"]) - trainer = self.get_trainer() + trainer, _ = self.get_trainer() self.assertEqual(len(trainer.blocks), 1) ops = [ 'lookup_table', 'sequence_pool', 
'lookup_table', 'sequence_pool', @@ -453,7 +477,7 @@ class TestDistLookupTable(TestDistLookupTableBase): # 5 save table self.assertEqual([op.type for op in pserver1.blocks[5].ops], ["save"]) - trainer = self.get_trainer() + trainer, _ = self.get_trainer() self.assertEqual(len(trainer.blocks), 1) ops = [ 'split_ids', 'prefetch', 'merge_ids', 'sequence_pool', 'split_ids', @@ -486,7 +510,7 @@ class TestAsyncLocalLookupTable(TestDistLookupTableBase): self.assertEqual([op.type for op in pserver1.blocks[2].ops], ["adam", "scale", "scale"]) - trainer = self.get_trainer(config) + trainer, _ = self.get_trainer(config) self.assertEqual(len(trainer.blocks), 1) ops = [ 'lookup_table', 'sequence_pool', 'lookup_table', 'sequence_pool', @@ -525,7 +549,7 @@ class TestAsyncDistLookupTable(TestDistLookupTableBase): # 5 save table self.assertEqual([op.type for op in pserver1.blocks[5].ops], ["save"]) - trainer = self.get_trainer(config) + trainer, _ = self.get_trainer(config) self.assertEqual(len(trainer.blocks), 1) ops = [ 'split_ids', 'prefetch', 'merge_ids', 'sequence_pool', 'split_ids', diff --git a/python/paddle/fluid/transpiler/distribute_transpiler.py b/python/paddle/fluid/transpiler/distribute_transpiler.py index c97beea1b3..ce4709f23b 100644 --- a/python/paddle/fluid/transpiler/distribute_transpiler.py +++ b/python/paddle/fluid/transpiler/distribute_transpiler.py @@ -195,6 +195,9 @@ class DistributeTranspiler(object): if program is None: program = default_main_program() self.origin_program = program + self.origin_startup_program = default_startup_program().clone() + + self.startup_program = default_startup_program() self.trainer_num = trainers self.sync_mode = sync_mode self.trainer_id = trainer_id @@ -205,10 +208,10 @@ class DistributeTranspiler(object): ps_dispatcher = self.config.split_method(self.pserver_endpoints) self.has_distributed_lookup_table = self._has_distributed_lookup_table() - # split and create vars, then put splited vars in dicts for later use. + # step 1: split and create vars, then put splited vars in dicts for later use. self._init_splited_vars() - # step 3.1: insert send op to send gradient vars to parameter servers + # step 2: insert send op to send gradient vars to parameter servers ps_dispatcher.reset() send_vars = [] @@ -265,7 +268,7 @@ class DistributeTranspiler(object): RPC_OP_ROLE_ATTR_NAME: RPC_OP_ROLE_ATTR_VALUE }) - # step 3.2: insert recv op to receive parameters from parameter server + # step 3: insert recv op to receive parameters from parameter server recv_vars = [] for _, var in enumerate(send_vars): recv_vars.append(self.grad_param_mapping[var]) @@ -312,6 +315,8 @@ class DistributeTranspiler(object): outputs={"Out": [orig_param]}, attrs={"axis": 0}) + self._get_trainer_startup_program(recv_vars=recv_vars, eplist=eplist) + if self.has_distributed_lookup_table: self._replace_lookup_table_op_with_prefetch(program, pserver_endpoints) @@ -328,8 +333,78 @@ class DistributeTranspiler(object): # FIXME(typhoonzero): Also ops like clip_gradient, lrn_decay? delete_ops(self.origin_program.global_block(), self.optimize_ops) self.origin_program.__str__() + return self.origin_program + def _get_trainer_startup_program(self, + recv_vars, + eplist, + startup_program=None): + """ + Get transpiled trainer side startup program. + + Args: + startup_program(Program): Startup program. + + Returns: + Program: trainer side startup program. + """ + if startup_program is None: + startup_program = self.startup_program + + # FIXME(gongwb): delete not need ops. 
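+        # What remains builds, per trainer, a startup program that ends with:
+        #   recv ops (one per group of parameter pieces hosted on a pserver),
+        #   a fetch_barrier op that waits until every piece has arrived,
+        #   and concat ops that stitch split parameters back together.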
+        # note that: some parameters are not trainable, and the ops that
+        # act on them can't be deleted.
+
+        for varname, splited_var in self.param_var_mapping.iteritems():
+            # Get the eplist of recv vars
+            eps = []
+            for var in splited_var:
+                index = [v.name for v in recv_vars].index(var.name)
+                eps.append(eplist[index])
+
+            for var in splited_var:
+                if startup_program.global_block().has_var(var.name):
+                    continue
+
+                startup_program.global_block().create_var(
+                    name=var.name,
+                    persistable=False,
+                    type=var.type,
+                    dtype=var.dtype,
+                    shape=var.shape,
+                    lod_level=var.lod_level)
+
+            op = startup_program.global_block().append_op(
+                type="recv",
+                inputs={},
+                outputs={"Out": splited_var},
+                attrs={
+                    "epmap": eps,
+                    RPC_OP_ROLE_ATTR_NAME: RPC_OP_ROLE_ATTR_VALUE
+                })
+
+        startup_program.global_block().append_op(
+            type="fetch_barrier",
+            inputs={},
+            outputs={},
+            attrs={
+                "endpoints": self.pserver_endpoints,
+                RPC_OP_ROLE_ATTR_NAME: RPC_OP_ROLE_ATTR_VALUE
+            })
+
+        for varname, splited_var in self.param_var_mapping.iteritems():
+            # add concat ops to merge the split parameters received from the
+            # parameter servers.
+            if len(splited_var) <= 1:
+                continue
+            orig_param = startup_program.global_block().vars[varname]
+            startup_program.global_block().append_op(
+                type="concat",
+                inputs={"X": splited_var},
+                outputs={"Out": [orig_param]},
+                attrs={"axis": 0})
+
+        return startup_program
+
     def get_pserver_program(self, endpoint):
         """
         Get parameter server side program.
@@ -576,14 +651,16 @@ class DistributeTranspiler(object):
             new_outputs = dict()
             # do not append startup op if var is not on this pserver
             op_on_pserver = False
-            for key in op.output_names:
-                newname, _ = _get_splited_name_and_shape(op.output(key)[0])
-                if newname:
-                    op_on_pserver = True
-                    new_outputs[key] = created_var_map[newname]
-                elif op.output(key)[0] in pserver_vars:
-                    op_on_pserver = True
-                    new_outputs[key] = pserver_vars[op.output(key)[0]]
+            # TODO(gongwb): remove this line.
+            if op.type not in ["recv", "fetch_barrier", "concat"]:
+                for key in op.output_names:
+                    newname, _ = _get_splited_name_and_shape(op.output(key)[0])
+                    if newname:
+                        op_on_pserver = True
+                        new_outputs[key] = created_var_map[newname]
+                    elif op.output(key)[0] in pserver_vars:
+                        op_on_pserver = True
+                        new_outputs[key] = pserver_vars[op.output(key)[0]]

             if op_on_pserver:
                 # most startup program ops have no inputs
@@ -1022,7 +1099,6 @@ class DistributeTranspiler(object):
                 var_mapping[varname] = \
                     [program.global_block().var(orig_var.name)]
                 continue
-            var_mapping[varname] = []
             orig_shape = orig_var.shape
             orig_dim1_flatten = 1

From c108376506faa8c51f489a4c1e658a446424453a Mon Sep 17 00:00:00 2001
From: jerrywgz
Date: Wed, 15 Aug 2018 22:38:25 +0800
Subject: [PATCH 24/29] Add three modes for prelu_op (#12630)

* Add three modes for prelu_op.
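* A minimal NumPy sketch of the three weight-sharing modes (it mirrors the
  reference computation in the new unit test; the NCHW layout and the alpha
  shapes below are illustrative assumptions):

    import numpy as np

    def prelu_ref(x, alpha):
        # Broadcasting alpha against x selects the sharing mode:
        #   all:     alpha.shape == (1,)           one weight overall
        #   channel: alpha.shape == (1, C, 1, 1)   one weight per channel
        #   element: alpha.shape == x.shape        one weight per element
        return np.maximum(x, 0.) + np.minimum(x, 0.) * alpha

    x = np.random.normal(size=(3, 5, 5, 10)).astype("float32")
    y = prelu_ref(x, np.random.rand(1, 5, 1, 1).astype("float32"))  # channel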
--- paddle/fluid/API.spec | 1 + paddle/fluid/operators/prelu_op.cc | 65 +++++++-- paddle/fluid/operators/prelu_op.cu | 22 --- paddle/fluid/operators/prelu_op.h | 125 ++++++++++-------- python/paddle/fluid/layers/nn.py | 54 ++++++++ .../fluid/tests/unittests/test_layers.py | 15 +++ .../fluid/tests/unittests/test_prelu_op.py | 56 ++++++-- 7 files changed, 237 insertions(+), 101 deletions(-) delete mode 100644 paddle/fluid/operators/prelu_op.cu diff --git a/paddle/fluid/API.spec b/paddle/fluid/API.spec index c020ff45ad..ea9105d79c 100644 --- a/paddle/fluid/API.spec +++ b/paddle/fluid/API.spec @@ -159,6 +159,7 @@ paddle.fluid.layers.relu ArgSpec(args=['x'], varargs=None, keywords=None, defaul paddle.fluid.layers.log ArgSpec(args=['x'], varargs=None, keywords=None, defaults=None) paddle.fluid.layers.crop ArgSpec(args=['x', 'shape', 'offsets', 'name'], varargs=None, keywords=None, defaults=(None, None, None)) paddle.fluid.layers.rank_loss ArgSpec(args=['label', 'left', 'right', 'name'], varargs=None, keywords=None, defaults=(None,)) +paddle.fluid.layers.prelu ArgSpec(args=['x', 'mode', 'param_attr', 'name'], varargs=None, keywords=None, defaults=(None, None)) paddle.fluid.layers.flatten ArgSpec(args=['x', 'axis', 'name'], varargs=None, keywords=None, defaults=(1, None)) paddle.fluid.layers.data ArgSpec(args=['name', 'shape', 'append_batch_size', 'dtype', 'lod_level', 'type', 'stop_gradient'], varargs=None, keywords=None, defaults=(True, 'float32', 0, VarType.LOD_TENSOR, True)) paddle.fluid.layers.open_recordio_file ArgSpec(args=['filename', 'shapes', 'lod_levels', 'dtypes', 'pass_num', 'for_parallel'], varargs=None, keywords=None, defaults=(1, True)) diff --git a/paddle/fluid/operators/prelu_op.cc b/paddle/fluid/operators/prelu_op.cc index db040509bc..23d9ea88f6 100644 --- a/paddle/fluid/operators/prelu_op.cc +++ b/paddle/fluid/operators/prelu_op.cc @@ -1,11 +1,8 @@ /* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved. - Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at - http://www.apache.org/licenses/LICENSE-2.0 - Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
@@ -26,14 +23,40 @@ class PReluOp : public framework::OperatorWithKernel { : OperatorWithKernel(type, inputs, outputs, attrs) {} void InferShape(framework::InferShapeContext *ctx) const override { + std::string mode = ctx->Attrs().Get("mode"); + + auto x_dim = ctx->GetInputDim("X"); PADDLE_ENFORCE(ctx->HasInput("X"), "Input(X) should not be null"); PADDLE_ENFORCE(ctx->HasInput("Alpha"), "Input(Alpha) should not be null"); - PADDLE_ENFORCE(product(ctx->GetInputDim("Alpha")) == 1, - "Size of weight Alpha must be one."); + PADDLE_ENFORCE(ctx->HasOutput("Out"), "Output(Out) should not be null"); - ctx->SetOutputDim("Out", ctx->GetInputDim("X")); + if (mode == "all") { + PADDLE_ENFORCE(product(ctx->GetInputDim("Alpha")) == 1, + "For mode 'all', size of weight Alpha must be one."); + } else if (mode == "channel") { + PADDLE_ENFORCE(product(ctx->GetInputDim("Alpha")) == x_dim[1], + "For channel-wise mode, size of weight Alpha must be " + "equal to the number of channels, should be %d", + x_dim[1]); + } else if (mode == "element") { + PADDLE_ENFORCE(product(ctx->GetInputDim("Alpha")) == product(x_dim), + "For element-wise mode, size of weight Alpha must be " + "equal to the number of input, should be %d", + product(x_dim)); + } else { + PADDLE_THROW("Unkown mode %s", mode); + } + ctx->SetOutputDim("Out", x_dim); ctx->ShareLoD("X", /*->*/ "Out"); } + + protected: + framework::OpKernelType GetExpectedKernelType( + const framework::ExecutionContext &ctx) const override { + return framework::OpKernelType( + framework::ToDataType(ctx.Input("X")->type()), + platform::CPUPlace()); + } }; class PReluOpMaker : public framework::OpProtoAndCheckerMaker { @@ -44,9 +67,7 @@ class PReluOpMaker : public framework::OpProtoAndCheckerMaker { AddOutput("Out", "The output tensor of prelu operator."); AddComment(R"DOC( PRelu Operator. - The equation is: - $$ f(x) = \begin{cases} @@ -54,11 +75,15 @@ f(x) = x, \qquad \text{if} \ x >= 0 \end{cases} $$ - The input `X` can carry the LoD (Level of Details) information, or not. And the output shares the LoD information with input `X`. - +There are modes: + all: all elements share same weight + channel: elements in a channel share same weight + element: each element has a weight )DOC"); + AddAttr("mode", "The mode for inputs to share weights.") + .SetDefault("all"); } }; @@ -71,9 +96,23 @@ class PReluGradOp : public framework::OperatorWithKernel { PADDLE_ENFORCE(ctx->HasInput("X"), "Input(X) must not be null."); PADDLE_ENFORCE(ctx->HasInput(framework::GradVarName("Out")), "Input(Out@GRAD) should not be null"); - ctx->SetOutputDim(framework::GradVarName("X"), ctx->GetInputDim("X")); - ctx->SetOutputDim(framework::GradVarName("Alpha"), - ctx->GetInputDim("Alpha")); + auto x_grad_name = framework::GradVarName("X"); + auto alpha_grad_name = framework::GradVarName("Alpha"); + + if (ctx->HasOutput(x_grad_name)) { + ctx->SetOutputDim(x_grad_name, ctx->GetInputDim("X")); + } + if (ctx->HasOutput(alpha_grad_name)) { + ctx->SetOutputDim(alpha_grad_name, ctx->GetInputDim("Alpha")); + } + } + + protected: + framework::OpKernelType GetExpectedKernelType( + const framework::ExecutionContext &ctx) const override { + return framework::OpKernelType( + framework::ToDataType(ctx.Input("X")->type()), + platform::CPUPlace()); } }; diff --git a/paddle/fluid/operators/prelu_op.cu b/paddle/fluid/operators/prelu_op.cu deleted file mode 100644 index 37d934a290..0000000000 --- a/paddle/fluid/operators/prelu_op.cu +++ /dev/null @@ -1,22 +0,0 @@ -/* Copyright (c) 2016 PaddlePaddle Authors. 
All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. */ - -#include "paddle/fluid/operators/prelu_op.h" - -REGISTER_OP_CUDA_KERNEL( - prelu, - paddle::operators::PReluKernel); -REGISTER_OP_CUDA_KERNEL(prelu_grad, - paddle::operators::PReluGradKernel< - paddle::platform::CUDADeviceContext, float>); diff --git a/paddle/fluid/operators/prelu_op.h b/paddle/fluid/operators/prelu_op.h index a6197d3548..f9076cbc67 100644 --- a/paddle/fluid/operators/prelu_op.h +++ b/paddle/fluid/operators/prelu_op.h @@ -1,11 +1,8 @@ /* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved. - Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at - http://www.apache.org/licenses/LICENSE-2.0 - Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. @@ -13,32 +10,16 @@ See the License for the specific language governing permissions and limitations under the License. */ #pragma once +#include #include "paddle/fluid/framework/eigen.h" #include "paddle/fluid/framework/op_registry.h" #include "paddle/fluid/platform/transform.h" - namespace paddle { namespace operators { using Tensor = framework::Tensor; using platform::Transform; -template -class PReluFunctor { - public: - explicit PReluFunctor(const T* alpha) : alpha_(alpha) {} - - HOSTDEVICE T operator()(const T& x) const { - if (x > 0) - return x; - else - return x * (*alpha_); - } - - private: - const T* alpha_; -}; - template class PReluKernel : public framework::OpKernel { public: @@ -50,53 +31,93 @@ class PReluKernel : public framework::OpKernel { const T* x_ptr = x->data(); T* o_ptr = out->mutable_data(context.GetPlace()); - auto* alpha_ptr = alpha->data(); + const T* alpha_ptr = alpha->data(); + std::string mode = context.Attr("mode"); int numel = x->numel(); - - Transform trans; - trans(context.template device_context(), x_ptr, - x_ptr + numel, o_ptr, PReluFunctor(alpha_ptr)); - } -}; - -template -class PReluGradFunctor { - public: - explicit PReluGradFunctor(const T* alpha) : alpha_(alpha) {} - - HOSTDEVICE T operator()(const T& out, const T& dout) const { - if (out > 0) - return dout; - else - return dout * (*alpha_); + auto dim = x->dims(); + int index = 0; + int i = 0; + int temp = 0; + if (mode == "channel") { + for (i = 0; i < numel; i++) { + temp = numel / (dim[0] * dim[1]); + index = (i / temp) % dim[1]; + o_ptr[i] = x_ptr[i] > 0 ? x_ptr[i] : alpha_ptr[index] * x_ptr[i]; + } + } else if (mode == "element") { + for (i = 0; i < numel; i++) { + o_ptr[i] = x_ptr[i] > 0 ? x_ptr[i] : alpha_ptr[i] * x_ptr[i]; + } + } else { + for (i = 0; i < numel; i++) { + o_ptr[i] = x_ptr[i] > 0 ? 
x_ptr[i] : alpha_ptr[0] * x_ptr[i]; + } + } } - - private: - const T* alpha_; }; template class PReluGradKernel : public framework::OpKernel { public: void Compute(const framework::ExecutionContext& context) const override { + auto* x = context.Input("X"); auto* dx = context.Output(framework::GradVarName("X")); auto* dout = context.Input(framework::GradVarName("Out")); - + auto* dalpha = context.Output(framework::GradVarName("Alpha")); auto* out = context.Input("Out"); auto* alpha = context.Input("Alpha"); - auto* alpha_ptr = alpha->data(); - - T* dx_ptr = dx->mutable_data(context.GetPlace()); + const T* alpha_ptr = alpha->data(); + const T* x_ptr = x->data(); const T* dout_ptr = dout->data(); const T* out_ptr = out->data(); - int numel = dx->numel(); - - Transform trans; - trans(context.template device_context(), out_ptr, - out_ptr + numel, dout_ptr, dx_ptr, PReluGradFunctor(alpha_ptr)); - - // TODO(Zhuoyuan): add dalpha upgrade when GPU kernels ready + std::string mode = context.Attr("mode"); + int numel = x->numel(); + auto dim = x->dims(); + int index = 0; + int i = 0; + int temp = 0; + if (dx) { + T* dx_ptr = dx->mutable_data(context.GetPlace()); + if (mode == "channel") { + for (i = 0; i < numel; i++) { + temp = numel / (dim[0] * dim[1]); + index = (i / temp) % dim[1]; + dx_ptr[i] = + out_ptr[i] > 0 ? dout_ptr[i] : alpha_ptr[index] * dout_ptr[i]; + } + } else if (mode == "element") { + for (i = 0; i < numel; i++) { + dx_ptr[i] = out_ptr[i] > 0 ? dout_ptr[i] : alpha_ptr[i] * dout_ptr[i]; + } + } else { + for (i = 0; i < numel; i++) { + dx_ptr[i] = out_ptr[i] > 0 ? dout_ptr[i] : alpha_ptr[0] * dout_ptr[i]; + } + } + } + + index = 0; + if (dalpha) { + T* dalpha_ptr = dalpha->mutable_data(context.GetPlace()); + if (mode == "channel") { + for (i = 0; i < numel; i++) { + temp = numel / (dim[0] * dim[1]); + index = (i / temp) % dim[1]; + dalpha_ptr[index] += out_ptr[i] > 0 ? 0 : x_ptr[i] * dout_ptr[i]; + } + } else if (mode == "element") { + for (i = 0; i < numel; i++) { + dalpha_ptr[i] += out_ptr[i] > 0 ? 0 : x_ptr[i] * dout_ptr[i]; + } + } else { + for (i = 0; i < numel; i++) { + dalpha_ptr[0] += out_ptr[i] > 0 ? 0 : x_ptr[i] * dout_ptr[i]; + } + } + } + + // TODO(Guanzhong): add GPU kernels } }; diff --git a/python/paddle/fluid/layers/nn.py b/python/paddle/fluid/layers/nn.py index c75e7eeb43..3e50fc91d9 100644 --- a/python/paddle/fluid/layers/nn.py +++ b/python/paddle/fluid/layers/nn.py @@ -112,6 +112,7 @@ __all__ = [ 'log', 'crop', 'rank_loss', + 'prelu', 'flatten', ] @@ -5364,6 +5365,59 @@ def rank_loss(label, left, right, name=None): return out +def prelu(x, mode, param_attr=None, name=None): + """ + Equation: + + y = \max(0, x) + alpha \min(0, x) + + Args: + x (Variable): The input tensor. + param_attr(ParamAttr|None): The parameter attribute for the learnable + weight (alpha). + mode (string): The mode for weight sharing + all: all elements share same weight + channel:elements in a channel share same weight + element:each element has a weight + name(str|None): A name for this layer(optional). If set None, the layer + will be named automatically. + + Returns: + Variable: The output tensor with the same shape as input. + + Examples: + + .. 
code-block:: python + + x = fluid.layers.data(name="x", shape=[10,10], dtype="float32") + mode = 'channel' + output = fluid.layers.prelu(x,mode) + """ + helper = LayerHelper('prelu', **locals()) + if mode not in ['all', 'channel', 'element']: + raise ValueError('mode should be one of all, channel, element.') + alpha_shape = [1] + if mode == 'channel': + alpha_shape = [1, x.shape[1], 1, 1] + elif mode == 'element': + alpha_shape = x.shape + dtype = helper.input_dtype(input_param_name='x') + alpha = helper.create_parameter( + attr=param_attr, + shape=alpha_shape, + dtype='float32', + is_bias=False, + default_initializer=Constant(1.0)) + out = helper.create_tmp_variable(dtype) + helper.append_op( + type="prelu", + inputs={"X": x, + 'Alpha': alpha}, + attrs={"mode": mode}, + outputs={"Out": out}) + return out + + def flatten(x, axis=1, name=None): """ **Flatten layer** diff --git a/python/paddle/fluid/tests/unittests/test_layers.py b/python/paddle/fluid/tests/unittests/test_layers.py index 38a138a8fa..07fd0575d3 100644 --- a/python/paddle/fluid/tests/unittests/test_layers.py +++ b/python/paddle/fluid/tests/unittests/test_layers.py @@ -21,6 +21,7 @@ import paddle.fluid.nets as nets from paddle.fluid.framework import Program, program_guard, default_main_program from paddle.fluid.param_attr import ParamAttr import decorators +from paddle.fluid.initializer import Constant class TestBook(unittest.TestCase): @@ -485,6 +486,20 @@ class TestBook(unittest.TestCase): self.assertIsNotNone(out) print(str(program)) + def test_prelu(self): + program = Program() + with program_guard(program): + input = layers.data( + name="input", shape=[5, 200, 100, 100], dtype="float32") + mode = 'channel' + out = layers.prelu( + input, + mode, + param_attr=ParamAttr(initializer=Constant(1.0)), + name='prelu') + self.assertIsNotNone(out) + print(str(program)) + if __name__ == '__main__': unittest.main() diff --git a/python/paddle/fluid/tests/unittests/test_prelu_op.py b/python/paddle/fluid/tests/unittests/test_prelu_op.py index ae19a553bb..cb7de3fc93 100644 --- a/python/paddle/fluid/tests/unittests/test_prelu_op.py +++ b/python/paddle/fluid/tests/unittests/test_prelu_op.py @@ -20,30 +20,58 @@ from op_test import OpTest class PReluTest(OpTest): def setUp(self): self.op_type = "prelu" - x_np = np.random.normal(size=(10, 10)).astype("float32") - - for pos, val in np.ndenumerate(x_np): - # Since zero point in prelu is not differentiable, avoid randomize - # zero. - while abs(val) < 1e-3: - x_np[pos] = np.random.normal() - val = x_np[pos] - - x_np_sign = np.sign(x_np) - x_np = x_np_sign * np.maximum(x_np, .005) - alpha_np = np.array([.1], dtype="float32") - self.inputs = {'X': x_np, 'Alpha': alpha_np} + self.initTestCase() + x_np = np.random.normal(size=(3, 5, 5, 10)).astype("float32") + + # Since zero point in prelu is not differentiable, avoid randomize + # zero. + x_np[np.abs(x_np) < 0.005] = 0.02 + + if self.attrs == {'mode': "all"}: + alpha_np = np.random.rand(1).astype("float32") + self.inputs = {'X': x_np, 'Alpha': alpha_np} + elif self.attrs == {'mode': "channel"}: + alpha_np = np.random.rand(1, x_np.shape[1], 1, 1).astype("float32") + self.inputs = {'X': x_np, 'Alpha': alpha_np} + else: + alpha_np = np.random.rand(*x_np.shape).astype("float32") + self.inputs = {'X': x_np, 'Alpha': alpha_np} + out_np = np.maximum(self.inputs['X'], 0.) out_np = out_np + np.minimum(self.inputs['X'], 0.) 
* self.inputs['Alpha'] assert out_np is not self.inputs['X'] self.outputs = {'Out': out_np} + def initTestCase(self): + self.attrs = {'mode': "channel"} + def test_check_output(self): self.check_output() def test_check_grad(self): - self.check_grad(['X'], 'Out') + self.check_grad(['X', 'Alpha'], 'Out') + + def test_check_grad_ignore_x(self): + self.check_grad(['Alpha'], 'Out', no_grad_set=set('X')) + + def test_check_grad_ignore_alpha(self): + self.check_grad(['X'], 'Out', no_grad_set=set('Alpha')) + + +class TestCase1(PReluTest): + def initTestCase(self): + self.attrs = {'mode': "all"} + + +class TestCase2(PReluTest): + def initTestCase(self): + self.attrs = {'mode': "channel"} + + +class TestCase3(PReluTest): + def initTestCase(self): + self.attrs = {'mode': "element"} if __name__ == "__main__": From bf3c34960f2a59a2616957f8fb4107b2ac7aa02b Mon Sep 17 00:00:00 2001 From: dzhwinter Date: Thu, 16 Aug 2018 11:00:55 +0800 Subject: [PATCH 25/29] "cherry picked operators changes" (#12184) * "cherry picked operators changes" * "remove duplicated code" * "add constant setter" * "add get expected kernel" * "fix ci" * "add fill constant" --- paddle/fluid/operators/activation_op.cu | 4 +- paddle/fluid/operators/activation_op.h | 12 ++-- paddle/fluid/operators/assign_value_op.cu.cc | 5 +- paddle/fluid/operators/conv_cudnn_op.cu.cc | 56 +++++++++++------- paddle/fluid/operators/cross_entropy_op.cu | 12 ++-- paddle/fluid/operators/elementwise_add_op.cu | 3 +- paddle/fluid/operators/elementwise_div_op.cu | 9 ++- paddle/fluid/operators/elementwise_mul_op.cu | 8 ++- .../fluid/operators/elementwise_op_function.h | 4 +- paddle/fluid/operators/elementwise_sub_op.cu | 8 ++- paddle/fluid/operators/fill_constant_op.cc | 53 ++++++----------- paddle/fluid/operators/fill_constant_op.cu.cc | 26 ++++++++ paddle/fluid/operators/fill_constant_op.h | 48 +++++++++++++++ paddle/fluid/operators/fill_op.cc | 2 +- paddle/fluid/operators/gaussian_random_op.cu | 2 + paddle/fluid/operators/math/cross_entropy.cu | 20 ++++++- paddle/fluid/operators/math/cross_entropy.h | 17 ++++++ .../operators/math/selected_rows_functor.cu | 13 +++- paddle/fluid/operators/math/softmax.cu | 3 + paddle/fluid/operators/mean_op.cu | 10 ++-- paddle/fluid/operators/mean_op.h | 2 +- paddle/fluid/operators/mul_op.cu.cc | 7 ++- paddle/fluid/operators/pool_cudnn_op.cu.cc | 6 +- paddle/fluid/operators/scale_op.cu | 6 +- paddle/fluid/operators/softmax_cudnn_op.cu.cc | 3 +- paddle/fluid/operators/softmax_op.cu.cc | 3 +- paddle/fluid/operators/sum_op.cu | 5 +- paddle/fluid/operators/sum_op.h | 2 +- paddle/fluid/operators/top_k_op.cu | 28 +++++++-- paddle/fluid/operators/uniform_random_op.cu | 59 ++++++++++++++++--- 30 files changed, 328 insertions(+), 108 deletions(-) create mode 100644 paddle/fluid/operators/fill_constant_op.cu.cc create mode 100644 paddle/fluid/operators/fill_constant_op.h diff --git a/paddle/fluid/operators/activation_op.cu b/paddle/fluid/operators/activation_op.cu index 27487b396c..d3a7ceed46 100644 --- a/paddle/fluid/operators/activation_op.cu +++ b/paddle/fluid/operators/activation_op.cu @@ -26,6 +26,8 @@ namespace plat = paddle::platform; act_type##_grad, ops::ActivationGradKernel>, \ ops::ActivationGradKernel>); + ops::grad_functor>, \ + ops::ActivationGradKernel>); FOR_EACH_KERNEL_FUNCTOR(REGISTER_ACTIVATION_CUDA_KERNEL); diff --git a/paddle/fluid/operators/activation_op.h b/paddle/fluid/operators/activation_op.h index 9124151926..48f3b5a5bc 100644 --- a/paddle/fluid/operators/activation_op.h +++ 
b/paddle/fluid/operators/activation_op.h @@ -333,8 +333,7 @@ struct SqrtGradFunctor : public BaseActivationFunctor { template void operator()(Device d, X x, Out out, dOut dout, dX dx) const { - const Out out_conj = Eigen::numext::conj(out); - dx.device(d) = static_cast(0.5) * dout / out_conj; + dx.device(d) = static_cast(0.5) * dout / out; } }; @@ -740,7 +739,7 @@ struct PowGradFunctor : public BaseActivationFunctor { typename dX> void operator()(Device d, X x, Out out, dOut dout, dX dx) const { dx.device(d) = dout * static_cast(factor) * - x.pow(static_cast(factor - static_cast(1))); + x.pow(static_cast(factor) - static_cast(1)); } }; @@ -863,10 +862,11 @@ struct SwishGradFunctor : public BaseActivationFunctor { template void operator()(Device d, X x, Out out, dOut dout, dX dx) const { + T b = static_cast(beta); auto temp1 = static_cast(1) / - (static_cast(1) + (static_cast(-beta) * x).exp()); - auto temp2 = temp1 * (static_cast(1) - (beta * out)); - dx.device(d) = dout * ((beta * out) + temp2); + (static_cast(1) + (static_cast(-b) * x).exp()); + auto temp2 = temp1 * (static_cast(1) - (b * out)); + dx.device(d) = dout * ((b * out) + temp2); } }; diff --git a/paddle/fluid/operators/assign_value_op.cu.cc b/paddle/fluid/operators/assign_value_op.cu.cc index 08bfde5dc9..0ff174b388 100644 --- a/paddle/fluid/operators/assign_value_op.cu.cc +++ b/paddle/fluid/operators/assign_value_op.cu.cc @@ -13,7 +13,10 @@ See the License for the specific language governing permissions and limitations under the License. */ #include "paddle/fluid/operators/assign_value_op.h" +#include "paddle/fluid/platform/float16.h" namespace ops = paddle::operators; +namespace plat = paddle::platform; REGISTER_OP_CUDA_KERNEL(assign_value, ops::AssignValueKernel, - ops::AssignValueKernel); + ops::AssignValueKernel, + ops::AssignValueKernel); diff --git a/paddle/fluid/operators/conv_cudnn_op.cu.cc b/paddle/fluid/operators/conv_cudnn_op.cu.cc index 22cbf680c0..59bfe8f61d 100644 --- a/paddle/fluid/operators/conv_cudnn_op.cu.cc +++ b/paddle/fluid/operators/conv_cudnn_op.cu.cc @@ -39,6 +39,27 @@ using ScalingParamType = typename platform::CudnnDataType::ScalingParamType; static constexpr size_t kCONV_CUDNN_WORKSPACE_LIMIT_BYTES = static_cast(1024) * 1024 * 1024; +template +// bool EnableFp16(const T& dummy, const DeviceContext& dev_ctx, +bool EnableFp16(const DeviceContext& dev_ctx, + cudnnConvolutionDescriptor_t cudnn_conv_desc) { +#if CUDA_VERSION >= 9000 && CUDNN_VERSION_MIN(7, 0, 1) + // Tensor core is supported since the volta GPU and + // is only enabled when input and filter data are float16 + if (dev_ctx.GetComputeCapability() >= 70 && + std::type_index(typeid(T)) == + std::type_index(typeid(platform::float16))) { + PADDLE_ENFORCE(platform::dynload::cudnnSetConvolutionMathType( + cudnn_conv_desc, CUDNN_TENSOR_OP_MATH)); + return true; + } else { + PADDLE_ENFORCE(platform::dynload::cudnnSetConvolutionMathType( + cudnn_conv_desc, CUDNN_DEFAULT_MATH)); + } +#endif + return false; +} + template class CUDNNConvOpKernel : public framework::OpKernel { public: @@ -128,27 +149,14 @@ class CUDNNConvOpKernel : public framework::OpKernel { cudnnConvolutionFwdAlgo_t algo; auto& dev_ctx = ctx.template device_context(); auto handle = dev_ctx.cudnn_handle(); - - CUDNN_ENFORCE(platform::dynload::cudnnGetConvolutionForwardAlgorithm( - handle, cudnn_input_desc, cudnn_filter_desc, cudnn_conv_desc, - cudnn_output_desc, CUDNN_CONVOLUTION_FWD_SPECIFY_WORKSPACE_LIMIT, - workspace_size_limit, &algo)); - -#if CUDA_VERSION >= 9000 && 
CUDNN_VERSION_MIN(7, 0, 1) - // Tensor core is supported since the volta GPU and - // is only enabled when input and filter data are float16 - if (dev_ctx.GetComputeCapability() >= 70 && - std::type_index(typeid(T)) == - std::type_index(typeid(platform::float16))) { - CUDNN_ENFORCE(platform::dynload::cudnnSetConvolutionMathType( - cudnn_conv_desc, CUDNN_TENSOR_OP_MATH)); - // Currently tensor core is only enabled using this algo + if (EnableFp16(dev_ctx, cudnn_conv_desc)) { algo = CUDNN_CONVOLUTION_FWD_ALGO_IMPLICIT_PRECOMP_GEMM; } else { - CUDNN_ENFORCE(platform::dynload::cudnnSetConvolutionMathType( - cudnn_conv_desc, CUDNN_DEFAULT_MATH)); + PADDLE_ENFORCE(platform::dynload::cudnnGetConvolutionForwardAlgorithm( + handle, cudnn_input_desc, cudnn_filter_desc, cudnn_conv_desc, + cudnn_output_desc, CUDNN_CONVOLUTION_FWD_SPECIFY_WORKSPACE_LIMIT, + workspace_size_limit, &algo)); } -#endif // get workspace size able to allocate CUDNN_ENFORCE(platform::dynload::cudnnGetConvolutionForwardWorkspaceSize( @@ -288,6 +296,9 @@ class CUDNNConvGradOpKernel : public framework::OpKernel { } else { data_algo = CUDNN_CONVOLUTION_BWD_DATA_ALGO_1; } + if (EnableFp16(dev_ctx, cudnn_conv_desc)) { + data_algo = CUDNN_CONVOLUTION_BWD_DATA_ALGO_1; + } CUDNN_ENFORCE( platform::dynload::cudnnGetConvolutionBackwardDataWorkspaceSize( @@ -307,6 +318,9 @@ class CUDNNConvGradOpKernel : public framework::OpKernel { } else { filter_algo = CUDNN_CONVOLUTION_BWD_FILTER_ALGO_1; } + if (EnableFp16(dev_ctx, cudnn_conv_desc)) { + filter_algo = CUDNN_CONVOLUTION_BWD_FILTER_ALGO_1; + } CUDNN_ENFORCE( platform::dynload::cudnnGetConvolutionBackwardFilterWorkspaceSize( @@ -362,7 +376,8 @@ REGISTER_OP_KERNEL(conv2d, CUDNN, plat::CUDAPlace, paddle::operators::CUDNNConvOpKernel); REGISTER_OP_KERNEL(conv2d_grad, CUDNN, plat::CUDAPlace, paddle::operators::CUDNNConvGradOpKernel, - paddle::operators::CUDNNConvGradOpKernel); + paddle::operators::CUDNNConvGradOpKernel, + paddle::operators::CUDNNConvGradOpKernel); REGISTER_OP_KERNEL(conv3d, CUDNN, plat::CUDAPlace, paddle::operators::CUDNNConvOpKernel, @@ -370,4 +385,5 @@ REGISTER_OP_KERNEL(conv3d, CUDNN, plat::CUDAPlace, paddle::operators::CUDNNConvOpKernel); REGISTER_OP_KERNEL(conv3d_grad, CUDNN, plat::CUDAPlace, paddle::operators::CUDNNConvGradOpKernel, - paddle::operators::CUDNNConvGradOpKernel); + paddle::operators::CUDNNConvGradOpKernel, + paddle::operators::CUDNNConvGradOpKernel) diff --git a/paddle/fluid/operators/cross_entropy_op.cu b/paddle/fluid/operators/cross_entropy_op.cu index 30dbd5bd3d..65fd3a5dbc 100644 --- a/paddle/fluid/operators/cross_entropy_op.cu +++ b/paddle/fluid/operators/cross_entropy_op.cu @@ -13,12 +13,16 @@ See the License for the specific language governing permissions and limitations under the License. 
*/ #include "paddle/fluid/operators/cross_entropy_op.h" +#include "paddle/fluid/platform/float16.h" namespace ops = paddle::operators; +namespace plat = paddle::platform; using CUDACtx = paddle::platform::CUDADeviceContext; REGISTER_OP_CUDA_KERNEL(cross_entropy, ops::CrossEntropyOpKernel, - ops::CrossEntropyOpKernel); -REGISTER_OP_CUDA_KERNEL(cross_entropy_grad, - ops::CrossEntropyGradientOpKernel, - ops::CrossEntropyGradientOpKernel); + ops::CrossEntropyOpKernel, + ops::CrossEntropyOpKernel); +REGISTER_OP_CUDA_KERNEL( + cross_entropy_grad, ops::CrossEntropyGradientOpKernel, + ops::CrossEntropyGradientOpKernel, + ops::CrossEntropyGradientOpKernel); diff --git a/paddle/fluid/operators/elementwise_add_op.cu b/paddle/fluid/operators/elementwise_add_op.cu index dfff518f17..f9f5c66d34 100644 --- a/paddle/fluid/operators/elementwise_add_op.cu +++ b/paddle/fluid/operators/elementwise_add_op.cu @@ -30,4 +30,5 @@ REGISTER_OP_CUDA_KERNEL( ops::ElementwiseAddGradKernel, ops::ElementwiseAddGradKernel, ops::ElementwiseAddGradKernel, - ops::ElementwiseAddGradKernel); + ops::ElementwiseAddGradKernel, + ops::ElementwiseAddGradKernel); diff --git a/paddle/fluid/operators/elementwise_div_op.cu b/paddle/fluid/operators/elementwise_div_op.cu index 588d1f7420..4cc7ba0f43 100644 --- a/paddle/fluid/operators/elementwise_div_op.cu +++ b/paddle/fluid/operators/elementwise_div_op.cu @@ -14,19 +14,24 @@ limitations under the License. */ #define EIGEN_USE_GPU #include "paddle/fluid/operators/elementwise_div_op.h" +#include "paddle/fluid/platform/float16.h" namespace ops = paddle::operators; +namespace plat = paddle::platform; REGISTER_OP_CUDA_KERNEL( elementwise_div, ops::ElementwiseDivKernel, ops::ElementwiseDivKernel, ops::ElementwiseDivKernel, - ops::ElementwiseDivKernel); + ops::ElementwiseDivKernel, + ops::ElementwiseDivKernel); REGISTER_OP_CUDA_KERNEL( elementwise_div_grad, ops::ElementwiseDivGradKernel, ops::ElementwiseDivGradKernel, ops::ElementwiseDivGradKernel, + ops::ElementwiseDivGradKernel, ops::ElementwiseDivGradKernel); + plat::float16>); diff --git a/paddle/fluid/operators/elementwise_mul_op.cu b/paddle/fluid/operators/elementwise_mul_op.cu index 2fb1b4bee6..350d43168d 100644 --- a/paddle/fluid/operators/elementwise_mul_op.cu +++ b/paddle/fluid/operators/elementwise_mul_op.cu @@ -14,19 +14,25 @@ limitations under the License. 
*/ #define EIGEN_USE_GPU #include "paddle/fluid/operators/elementwise_mul_op.h" +#include "paddle/fluid/platform/float16.h" namespace ops = paddle::operators; +namespace plat = paddle::platform; REGISTER_OP_CUDA_KERNEL( elementwise_mul, ops::ElementwiseMulKernel, ops::ElementwiseMulKernel, ops::ElementwiseMulKernel, - ops::ElementwiseMulKernel); + ops::ElementwiseMulKernel, + ops::ElementwiseMulKernel); REGISTER_OP_CUDA_KERNEL( elementwise_mul_grad, ops::ElementwiseMulGradKernel, ops::ElementwiseMulGradKernel, ops::ElementwiseMulGradKernel, + ops::ElementwiseMulGradKernel, ops::ElementwiseMulGradKernel); diff --git a/paddle/fluid/operators/elementwise_op_function.h b/paddle/fluid/operators/elementwise_op_function.h index bc3e95e904..7223a972d2 100644 --- a/paddle/fluid/operators/elementwise_op_function.h +++ b/paddle/fluid/operators/elementwise_op_function.h @@ -350,7 +350,7 @@ static __global__ void ElemwiseGradBroadcast1CUDAKernel( int j = blockIdx.x; int i = threadIdx.x; int tid = threadIdx.x; - T val = 0; + T val(0); do { int x_offset = i * w + j; @@ -418,7 +418,7 @@ static __global__ void ElemwiseGradBroadcast2CUDAKernel( int tid = threadIdx.x; int j = blockIdx.x; - T val = 0; + T val(0); int ttid = tid; while (true) { diff --git a/paddle/fluid/operators/elementwise_sub_op.cu b/paddle/fluid/operators/elementwise_sub_op.cu index 8709f686f9..ff3f6f8a2c 100644 --- a/paddle/fluid/operators/elementwise_sub_op.cu +++ b/paddle/fluid/operators/elementwise_sub_op.cu @@ -14,19 +14,25 @@ limitations under the License. */ #define EIGEN_USE_GPU #include "paddle/fluid/operators/elementwise_sub_op.h" +#include "paddle/fluid/platform/float16.h" namespace ops = paddle::operators; +namespace plat = paddle::platform; REGISTER_OP_CUDA_KERNEL( elementwise_sub, ops::ElementwiseSubKernel, ops::ElementwiseSubKernel, ops::ElementwiseSubKernel, - ops::ElementwiseSubKernel); + ops::ElementwiseSubKernel, + ops::ElementwiseSubKernel); REGISTER_OP_CUDA_KERNEL( elementwise_sub_grad, ops::ElementwiseSubGradKernel, ops::ElementwiseSubGradKernel, ops::ElementwiseSubGradKernel, + ops::ElementwiseSubGradKernel, ops::ElementwiseSubGradKernel); diff --git a/paddle/fluid/operators/fill_constant_op.cc b/paddle/fluid/operators/fill_constant_op.cc index 130f18dde4..862249269e 100644 --- a/paddle/fluid/operators/fill_constant_op.cc +++ b/paddle/fluid/operators/fill_constant_op.cc @@ -12,48 +12,28 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. 
*/ -#include "paddle/fluid/framework/data_type.h" -#include "paddle/fluid/framework/op_registry.h" -#include "paddle/fluid/operators/math/math_function.h" -#include "paddle/fluid/platform/device_context.h" +#include "paddle/fluid/operators/fill_constant_op.h" +#include "paddle/fluid/platform/float16.h" namespace paddle { namespace operators { -class FillConstantInferShape : public framework::InferShapeBase { +class FillConstantOp : public framework::OperatorWithKernel { public: - void operator()(framework::InferShapeContext *ctx) const override { + using framework::OperatorWithKernel::OperatorWithKernel; + + void InferShape(framework::InferShapeContext* ctx) const override { PADDLE_ENFORCE(ctx->HasOutput("Out"), "Output(Out) of FillConstantOp should not be null."); - auto &shape = ctx->Attrs().Get>("shape"); + auto& shape = ctx->Attrs().Get>("shape"); ctx->SetOutputDim("Out", framework::make_ddim(shape)); } -}; - -class FillConstantOp : public framework::OperatorBase { - public: - using framework::OperatorBase::OperatorBase; - - private: - void RunImpl(const framework::Scope &scope, - const platform::Place &dev_place) const override { - auto data_type = - static_cast(Attr("dtype")); - auto value = Attr("value"); - auto force_cpu = Attr("force_cpu"); - auto &out = - *scope.FindVar(Output("Out"))->GetMutable(); - out.Resize(framework::make_ddim(Attr>("shape"))); - if (force_cpu) { - auto cpu = platform::CPUPlace(); - out.mutable_data(cpu, framework::ToTypeIndex(data_type)); - } else { - out.mutable_data(dev_place, framework::ToTypeIndex(data_type)); - } - platform::DeviceContextPool &pool = platform::DeviceContextPool::Instance(); - auto &dev_ctx = *pool.Get(dev_place); - math::set_constant(dev_ctx, &out, value); + framework::OpKernelType GetExpectedKernelType( + const framework::ExecutionContext& ctx) const override { + return framework::OpKernelType( + static_cast(ctx.Attr("dtype")), + ctx.device_context()); } }; @@ -87,6 +67,11 @@ Fill up a variable with specified constant value. } // namespace paddle namespace ops = paddle::operators; -REGISTER_OPERATOR(fill_constant, ops::FillConstantOp, - ops::FillConstantInferShape, ops::FillConstantOpMaker, +REGISTER_OPERATOR(fill_constant, ops::FillConstantOp, ops::FillConstantOpMaker, paddle::framework::EmptyGradOpMaker); +REGISTER_OP_CPU_KERNEL( + fill_constant, + ops::FillConstantOpKernel, + ops::FillConstantOpKernel, + ops::FillConstantOpKernel, + ops::FillConstantOpKernel) diff --git a/paddle/fluid/operators/fill_constant_op.cu.cc b/paddle/fluid/operators/fill_constant_op.cu.cc new file mode 100644 index 0000000000..51ccaefa43 --- /dev/null +++ b/paddle/fluid/operators/fill_constant_op.cu.cc @@ -0,0 +1,26 @@ +// Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+ +#include "paddle/fluid/operators/fill_constant_op.h" +#include "paddle/fluid/platform/float16.h" + +namespace ops = paddle::operators; +REGISTER_OP_CUDA_KERNEL( + fill_constant, + ops::FillConstantOpKernel, + ops::FillConstantOpKernel, + ops::FillConstantOpKernel, + ops::FillConstantOpKernel, + ops::FillConstantOpKernel) diff --git a/paddle/fluid/operators/fill_constant_op.h b/paddle/fluid/operators/fill_constant_op.h new file mode 100644 index 0000000000..b2a2a7b2fa --- /dev/null +++ b/paddle/fluid/operators/fill_constant_op.h @@ -0,0 +1,48 @@ +// Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#pragma once +#include + +#include "paddle/fluid/framework/data_type.h" +#include "paddle/fluid/framework/op_registry.h" +#include "paddle/fluid/operators/math/math_function.h" + +namespace paddle { +namespace operators { + +template +class FillConstantOpKernel : public framework::OpKernel { + public: + void Compute(const framework::ExecutionContext& ctx) const override { + auto data_type = + static_cast(ctx.Attr("dtype")); + auto value = ctx.Attr("value"); + auto force_cpu = ctx.Attr("force_cpu"); + auto* out = ctx.Output("Out"); + out->Resize(framework::make_ddim(ctx.Attr>("shape"))); + if (force_cpu) { + auto cpu = platform::CPUPlace(); + out->mutable_data(cpu, framework::ToTypeIndex(data_type)); + } else { + out->mutable_data(ctx.GetPlace(), framework::ToTypeIndex(data_type)); + } + + math::set_constant(ctx.template device_context(), out, + value); + } +}; + +} // namespace operators +} // namespace paddle diff --git a/paddle/fluid/operators/fill_op.cc b/paddle/fluid/operators/fill_op.cc index 925dc19061..352a17c927 100644 --- a/paddle/fluid/operators/fill_op.cc +++ b/paddle/fluid/operators/fill_op.cc @@ -16,6 +16,7 @@ limitations under the License. */ #include "paddle/fluid/framework/op_registry.h" #include "paddle/fluid/operators/detail/safe_ref.h" #include "paddle/fluid/platform/device_context.h" +#include "paddle/fluid/platform/float16.h" namespace paddle { namespace operators { @@ -69,7 +70,6 @@ class FillOp : public framework::OperatorBase { framework::VisitDataType( dtype, FillOpVisitor(&tensor, Attr>("value"))); - if (!force_cpu && platform::is_gpu_place(place)) { // Copy tensor to out platform::DeviceContextPool &pool = diff --git a/paddle/fluid/operators/gaussian_random_op.cu b/paddle/fluid/operators/gaussian_random_op.cu index 7784856417..b490723795 100644 --- a/paddle/fluid/operators/gaussian_random_op.cu +++ b/paddle/fluid/operators/gaussian_random_op.cu @@ -15,6 +15,7 @@ limitations under the License. 
*/ #include #include "paddle/fluid/framework/op_registry.h" #include "paddle/fluid/framework/operator.h" +#include "paddle/fluid/platform/float16.h" namespace paddle { namespace operators { @@ -60,6 +61,7 @@ class GPUGaussianRandomKernel : public framework::OpKernel { } // namespace operators } // namespace paddle +namespace plat = paddle::platform; REGISTER_OP_CUDA_KERNEL(gaussian_random, paddle::operators::GPUGaussianRandomKernel, paddle::operators::GPUGaussianRandomKernel); diff --git a/paddle/fluid/operators/math/cross_entropy.cu b/paddle/fluid/operators/math/cross_entropy.cu index 0de58d5fdd..58b85abf82 100644 --- a/paddle/fluid/operators/math/cross_entropy.cu +++ b/paddle/fluid/operators/math/cross_entropy.cu @@ -15,11 +15,25 @@ limitations under the License. */ #include "paddle/fluid/operators/math/cross_entropy.h" #include "paddle/fluid/platform/cuda_device_function.h" #include "paddle/fluid/platform/cuda_primitives.h" +#include "paddle/fluid/platform/float16.h" namespace paddle { namespace operators { namespace math { +template +HOSTDEVICE T log(const T& val) { + return std::log(val); +} + +template <> +HOSTDEVICE platform::float16 log(const platform::float16& val) { + // strange bug: hlog does not exist. + return static_cast(0); + // half tmp = static_cast(val); + // return static_cast(hlog(tmp)); +} + namespace { template __global__ void CrossEntropyKernel(T* Y, const T* X, const int64_t* label, @@ -35,12 +49,12 @@ template __global__ void SoftCrossEntropyKernel(T* Y, const T* X, const T* label, const int class_num) { int tid = threadIdx.x; - T val = 0; + T val(0); int idx = blockIdx.x * class_num + tid; int end = blockIdx.x * class_num + class_num; for (; idx < end; idx += blockDim.x) { - val += math::TolerableValue()(std::log(X[idx])) * label[idx]; + val += math::TolerableValue()(log(X[idx])) * label[idx]; } val = paddle::platform::reduceSum(val, tid, blockDim.x); @@ -84,6 +98,8 @@ class CrossEntropyFunctor { template class CrossEntropyFunctor; template class CrossEntropyFunctor; +template class CrossEntropyFunctor; } // namespace math } // namespace operators } // namespace paddle diff --git a/paddle/fluid/operators/math/cross_entropy.h b/paddle/fluid/operators/math/cross_entropy.h index adc5b3fe47..2e4e4781c2 100644 --- a/paddle/fluid/operators/math/cross_entropy.h +++ b/paddle/fluid/operators/math/cross_entropy.h @@ -13,8 +13,10 @@ See the License for the specific language governing permissions and limitations under the License. */ #pragma once +#include #include "paddle/fluid/framework/eigen.h" #include "paddle/fluid/framework/tensor.h" +#include "paddle/fluid/platform/float16.h" #include "paddle/fluid/platform/hostdevice.h" namespace paddle { @@ -33,6 +35,21 @@ struct TolerableValue { } }; +// float16 value clipping behaves differently. +using paddle::platform::float16; +using paddle::platform::isfinite; +template <> +struct TolerableValue { + HOSTDEVICE float16 operator()(const float16& x) const { + if (isfinite(x)) + return x; + else if (x > static_cast(0)) + return std::numeric_limits::max(); + else + return std::numeric_limits::min(); + } +}; + template class CrossEntropyFunctor { public: diff --git a/paddle/fluid/operators/math/selected_rows_functor.cu b/paddle/fluid/operators/math/selected_rows_functor.cu index a92762c7fe..00dbfc11a2 100644 --- a/paddle/fluid/operators/math/selected_rows_functor.cu +++ b/paddle/fluid/operators/math/selected_rows_functor.cu @@ -18,6 +18,7 @@ limitations under the License. 
*/ #include "paddle/fluid/operators/math/math_function.h" #include "paddle/fluid/operators/math/selected_rows_functor.h" #include "paddle/fluid/platform/cuda_primitives.h" +#include "paddle/fluid/platform/float16.h" namespace paddle { namespace operators { @@ -76,6 +77,7 @@ struct SelectedRowsAdd { template struct SelectedRowsAdd; template struct SelectedRowsAdd; +template struct SelectedRowsAdd; namespace { template @@ -120,7 +122,7 @@ struct SelectedRowsAddTensor { auto* out_data = output->data(); SetConstant functor; - functor(context, output, 0.0); + functor(context, output, static_cast(0)); const int block_size = 256; dim3 threads(block_size, 1); @@ -138,6 +140,8 @@ struct SelectedRowsAddTensor { template struct SelectedRowsAddTensor; template struct SelectedRowsAddTensor; +template struct SelectedRowsAddTensor; template struct SelectedRowsAddTo { @@ -177,6 +181,8 @@ template struct SelectedRowsAddTo; template struct SelectedRowsAddTo; template struct SelectedRowsAddTo; template struct SelectedRowsAddTo; +template struct SelectedRowsAddTo; namespace { template @@ -229,6 +235,8 @@ template struct SelectedRowsAddToTensor; template struct SelectedRowsAddToTensor; template struct SelectedRowsAddToTensor; template struct SelectedRowsAddToTensor; +template struct SelectedRowsAddToTensor; namespace scatter { @@ -276,7 +284,7 @@ struct MergeAdd { context.GetPlace()); math::SetConstant constant_functor; - constant_functor(context, out.mutable_value(), 0.0); + constant_functor(context, out.mutable_value(), static_cast(0)); auto* out_data = out.mutable_value()->data(); auto* input_data = input.value().data(); @@ -300,6 +308,7 @@ template struct MergeAdd; template struct MergeAdd; template struct MergeAdd; template struct MergeAdd; +template struct MergeAdd; template __global__ void UpdateToTensorKernel(const T* selected_rows, diff --git a/paddle/fluid/operators/math/softmax.cu b/paddle/fluid/operators/math/softmax.cu index 3effe77625..785c4baecb 100644 --- a/paddle/fluid/operators/math/softmax.cu +++ b/paddle/fluid/operators/math/softmax.cu @@ -94,12 +94,15 @@ void SoftmaxGradCUDNNFunctor::operator()( template class SoftmaxCUDNNFunctor; template class SoftmaxCUDNNFunctor; template class SoftmaxCUDNNFunctor; +template class SoftmaxGradCUDNNFunctor; template class SoftmaxGradCUDNNFunctor; template class SoftmaxGradCUDNNFunctor; template class SoftmaxFunctor; template class SoftmaxFunctor; template class SoftmaxFunctor; +template class SoftmaxGradFunctor; template class SoftmaxGradFunctor; template class SoftmaxGradFunctor; diff --git a/paddle/fluid/operators/mean_op.cu b/paddle/fluid/operators/mean_op.cu index 91e0ab28ef..07aa23754f 100644 --- a/paddle/fluid/operators/mean_op.cu +++ b/paddle/fluid/operators/mean_op.cu @@ -12,14 +12,16 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. 
*/ -#define EIGEN_USE_GPU - #include "paddle/fluid/operators/mean_op.h" +#include "paddle/fluid/platform/float16.h" namespace ops = paddle::operators; +namespace plat = paddle::platform; REGISTER_OP_CUDA_KERNEL( mean, ops::MeanKernel, - ops::MeanKernel); + ops::MeanKernel, + ops::MeanKernel); REGISTER_OP_CUDA_KERNEL( mean_grad, ops::MeanGradKernel, - ops::MeanGradKernel); + ops::MeanGradKernel, + ops::MeanGradKernel); diff --git a/paddle/fluid/operators/mean_op.h b/paddle/fluid/operators/mean_op.h index 362e9f9ae8..a41d50ae0b 100644 --- a/paddle/fluid/operators/mean_op.h +++ b/paddle/fluid/operators/mean_op.h @@ -55,7 +55,7 @@ class MeanGradKernel : public framework::OpKernel { IG->mutable_data(context.GetPlace()); T ig_size = static_cast(IG->numel()); - Eigen::DSizes bcast(ig_size); + Eigen::DSizes bcast(static_cast(ig_size)); EigenVector::Flatten(*IG).device( *context.template device_context().eigen_device()) = diff --git a/paddle/fluid/operators/mul_op.cu.cc b/paddle/fluid/operators/mul_op.cu.cc index 81f3e42bf4..6c5a83c6a5 100644 --- a/paddle/fluid/operators/mul_op.cu.cc +++ b/paddle/fluid/operators/mul_op.cu.cc @@ -20,6 +20,7 @@ namespace plat = paddle::platform; REGISTER_OP_CUDA_KERNEL(mul, ops::MulKernel, ops::MulKernel, ops::MulKernel); -REGISTER_OP_CUDA_KERNEL(mul_grad, - ops::MulGradKernel, - ops::MulGradKernel); +REGISTER_OP_CUDA_KERNEL( + mul_grad, ops::MulGradKernel, + ops::MulGradKernel, + ops::MulGradKernel); diff --git a/paddle/fluid/operators/pool_cudnn_op.cu.cc b/paddle/fluid/operators/pool_cudnn_op.cu.cc index 31f083565f..9fdbee818a 100644 --- a/paddle/fluid/operators/pool_cudnn_op.cu.cc +++ b/paddle/fluid/operators/pool_cudnn_op.cu.cc @@ -174,7 +174,8 @@ REGISTER_OP_KERNEL(pool2d, CUDNN, plat::CUDAPlace, ops::PoolCUDNNOpKernel); REGISTER_OP_KERNEL(pool2d_grad, CUDNN, plat::CUDAPlace, ops::PoolCUDNNGradOpKernel, - ops::PoolCUDNNGradOpKernel); + ops::PoolCUDNNGradOpKernel, + ops::PoolCUDNNGradOpKernel); REGISTER_OP_KERNEL(pool3d, CUDNN, plat::CUDAPlace, ops::PoolCUDNNOpKernel, @@ -182,4 +183,5 @@ REGISTER_OP_KERNEL(pool3d, CUDNN, plat::CUDAPlace, ops::PoolCUDNNOpKernel); REGISTER_OP_KERNEL(pool3d_grad, CUDNN, plat::CUDAPlace, ops::PoolCUDNNGradOpKernel, - ops::PoolCUDNNGradOpKernel); + ops::PoolCUDNNGradOpKernel, + ops::PoolCUDNNGradOpKernel); diff --git a/paddle/fluid/operators/scale_op.cu b/paddle/fluid/operators/scale_op.cu index 04c802da12..d266867046 100644 --- a/paddle/fluid/operators/scale_op.cu +++ b/paddle/fluid/operators/scale_op.cu @@ -13,11 +13,15 @@ See the License for the specific language governing permissions and limitations under the License. 
*/ #include "paddle/fluid/operators/scale_op.h" +#include "paddle/fluid/platform/float16.h" +namespace plat = paddle::platform; REGISTER_OP_CUDA_KERNEL( scale, paddle::operators::ScaleKernel, paddle::operators::ScaleKernel, paddle::operators::ScaleKernel, paddle::operators::ScaleKernel); + int64_t>, + paddle::operators::ScaleKernel); diff --git a/paddle/fluid/operators/softmax_cudnn_op.cu.cc b/paddle/fluid/operators/softmax_cudnn_op.cu.cc index 2bdb23e999..c2d45c3d2e 100644 --- a/paddle/fluid/operators/softmax_cudnn_op.cu.cc +++ b/paddle/fluid/operators/softmax_cudnn_op.cu.cc @@ -78,4 +78,5 @@ REGISTER_OP_KERNEL(softmax, CUDNN, plat::CUDAPlace, ops::SoftmaxCUDNNKernel, ops::SoftmaxCUDNNKernel); REGISTER_OP_KERNEL(softmax_grad, CUDNN, plat::CUDAPlace, - ops::SoftmaxGradCUDNNKernel); + ops::SoftmaxGradCUDNNKernel, + ops::SoftmaxGradCUDNNKernel); diff --git a/paddle/fluid/operators/softmax_op.cu.cc b/paddle/fluid/operators/softmax_op.cu.cc index 5fb4f011d9..19359b7eef 100644 --- a/paddle/fluid/operators/softmax_op.cu.cc +++ b/paddle/fluid/operators/softmax_op.cu.cc @@ -23,4 +23,5 @@ REGISTER_OP_CUDA_KERNEL( ops::SoftmaxKernel); REGISTER_OP_CUDA_KERNEL( softmax_grad, ops::SoftmaxGradKernel, - ops::SoftmaxGradKernel); + ops::SoftmaxGradKernel, + ops::SoftmaxGradKernel); diff --git a/paddle/fluid/operators/sum_op.cu b/paddle/fluid/operators/sum_op.cu index 89bcd1bbc8..db4c2d6c11 100644 --- a/paddle/fluid/operators/sum_op.cu +++ b/paddle/fluid/operators/sum_op.cu @@ -11,10 +11,13 @@ limitations under the License. */ #define EIGEN_USE_GPU #include "paddle/fluid/operators/sum_op.h" +#include "paddle/fluid/platform/float16.h" namespace ops = paddle::operators; +namespace plat = paddle::platform; REGISTER_OP_CUDA_KERNEL( sum, ops::SumKernel, ops::SumKernel, ops::SumKernel, - ops::SumKernel); + ops::SumKernel, + ops::SumKernel); diff --git a/paddle/fluid/operators/sum_op.h b/paddle/fluid/operators/sum_op.h index 49a4afb3a8..dda6772796 100644 --- a/paddle/fluid/operators/sum_op.h +++ b/paddle/fluid/operators/sum_op.h @@ -46,7 +46,7 @@ class SumKernel : public framework::OpKernel { if (!in_place) { math::SetConstant constant_functor; constant_functor(context.template device_context(), out, - 0.0); + static_cast(0)); } math::SelectedRowsAddToTensor functor; diff --git a/paddle/fluid/operators/top_k_op.cu b/paddle/fluid/operators/top_k_op.cu index 9da8551eb2..5fc0784f66 100644 --- a/paddle/fluid/operators/top_k_op.cu +++ b/paddle/fluid/operators/top_k_op.cu @@ -11,16 +11,19 @@ distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. 
*/ +#include #include "paddle/fluid/framework/op_registry.h" #include "paddle/fluid/operators/top_k_op.h" #include "paddle/fluid/platform/assert.h" #include "paddle/fluid/platform/cuda_device_function.h" +#include "paddle/fluid/platform/float16.h" namespace paddle { namespace operators { using Tensor = framework::Tensor; +using paddle::platform::float16; template struct Pair { @@ -32,6 +35,11 @@ struct Pair { id = id; } + __device__ __forceinline__ void clear() { + v = -INFINITY; + id = -1; + } + __device__ __forceinline__ void operator=(const Pair& in) { v = in.v; id = in.id; @@ -53,6 +61,12 @@ struct Pair { int64_t id; }; +template <> +__device__ __forceinline__ void Pair::clear() { + v = platform::raw_uint16_to_float16(0x400); + id = -1; +} + template __device__ __forceinline__ void AddTo(Pair topk[], const Pair& p, int beam_size) { @@ -150,7 +164,7 @@ __device__ __forceinline__ void ThreadGetTopK(Pair topk[], int* beam, if (k < MaxLength - (*beam)) { topk[k] = topk[k + *beam]; } else { - topk[k].set(-INFINITY, -1); + topk[k].clear(); } } if (!(*is_empty)) { @@ -160,7 +174,7 @@ __device__ __forceinline__ void ThreadGetTopK(Pair topk[], int* beam, } *max = topk[MaxLength - 1]; - if ((*max).v == -1) *is_empty = true; + if ((*max).v == static_cast(-1)) *is_empty = true; *beam = 0; } } @@ -181,7 +195,7 @@ __device__ __forceinline__ void ThreadGetTopK(Pair topk[], int* beam, if (k < MaxLength - *beam) { topk[k] = topk[k + *beam]; } else { - topk[k].set(-INFINITY, -1); + topk[k].set(std::numeric_limits::min(), -1); } } if (!(*is_empty)) { @@ -273,7 +287,7 @@ __global__ void KeMatrixTopK(T* output, int output_stride, int64_t* indices, bool firststep = true; for (int k = 0; k < MaxLength; k++) { - topk[k].set(-INFINITY, -1); + topk[k].clear(); } while (k) { ThreadGetTopK(topk, &beam, k, @@ -325,5 +339,7 @@ class TopkOpCUDAKernel : public framework::OpKernel { } // namespace operators } // namespace paddle -REGISTER_OP_CUDA_KERNEL(top_k, paddle::operators::TopkOpCUDAKernel, - paddle::operators::TopkOpCUDAKernel); +REGISTER_OP_CUDA_KERNEL( + top_k, paddle::operators::TopkOpCUDAKernel, + paddle::operators::TopkOpCUDAKernel, + paddle::operators::TopkOpCUDAKernel); diff --git a/paddle/fluid/operators/uniform_random_op.cu b/paddle/fluid/operators/uniform_random_op.cu index e1c7323a30..2b8039a0c1 100644 --- a/paddle/fluid/operators/uniform_random_op.cu +++ b/paddle/fluid/operators/uniform_random_op.cu @@ -11,10 +11,14 @@ distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. */ +#include #include #include +#include "paddle/fluid/framework/eigen.h" #include "paddle/fluid/framework/op_registry.h" #include "paddle/fluid/framework/operator.h" +#include "paddle/fluid/platform/float16.h" +#include "paddle/fluid/platform/transform.h" namespace paddle { namespace operators { @@ -36,6 +40,11 @@ struct UniformGenerator { } }; +template +struct CastFunctor { + HOSTDEVICE V operator()(const T& a) { return static_cast(a); } +}; + // It seems that Eigen::Tensor::random in GPU will SEGFAULT. // Use std::random and thrust::random(thrust is a std library in CUDA) to // implement uniform random. 
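
The next hunk adds the float16 path to this kernel: because UniformGenerator produces single-precision values, the kernel first fills a float32 master copy and then casts it element-wise into the float16 output via CastFunctor. A rough NumPy sketch of the same two-step scheme; the function name and seed below are illustrative, not part of the patch:

    import numpy as np

    def uniform_random_fp16(shape, low, high, seed=0):
        # Step 1: draw the samples in float32, mirroring the kernel's
        # float master copy (the generator has no half-precision output).
        rng = np.random.RandomState(seed)
        master = rng.uniform(low, high, size=shape).astype("float32")
        # Step 2: cast element-wise to half precision, mirroring CastFunctor.
        return master.astype("float16")

    print(uniform_random_fp16((2, 3), -1.0, 1.0, seed=10))

Generating in the wider type and narrowing afterwards is the design choice the hunk makes in CUDA with thrust::transform plus platform::Transform.
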
@@ -66,18 +75,50 @@ class GPUUniformRandomKernel : public framework::OpKernel { T max = static_cast(context.Attr("max")); thrust::counting_iterator index_sequence_begin(0); int64_t size = tensor->numel(); - thrust::transform(index_sequence_begin, index_sequence_begin + size, - thrust::device_ptr(data), - UniformGenerator(min, max, seed)); + if (out_var->IsType() && + std::type_index(typeid(T)) == + std::type_index(typeid(platform::float16))) { + framework::Tensor master_copy_tensor; + master_copy_tensor.Resize(tensor->dims()); + float* master_copy_tensor_data = + master_copy_tensor.mutable_data(context.GetPlace()); + thrust::transform(index_sequence_begin, index_sequence_begin + size, + thrust::device_ptr(master_copy_tensor_data), + UniformGenerator(static_cast(min), + static_cast(max), seed)); + platform::Transform trans; + auto* in_begin = master_copy_tensor.data(); + auto* in_end = in_begin + master_copy_tensor.numel(); + auto* out_begin = tensor->mutable_data(context.GetPlace()); + trans(context.template device_context(), + in_begin, in_end, out_begin, CastFunctor()); + } else { + thrust::transform(index_sequence_begin, index_sequence_begin + size, + thrust::device_ptr(data), + UniformGenerator(min, max, seed)); + } + if (VLOG_IS_ON(5)) { + framework::Tensor cpu_tensor; + framework::TensorCopySync(*tensor, platform::CPUPlace(), &cpu_tensor); + auto& dev_ctx = + *platform::DeviceContextPool::Instance().Get(context.GetPlace()); + dev_ctx.Wait(); + auto x = framework::EigenVector::Flatten(cpu_tensor); + VLOG(5) << "The Uniform output " << x; + } } }; } // namespace operators } // namespace paddle -REGISTER_OP_CUDA_KERNEL(uniform_random, - paddle::operators::GPUUniformRandomKernel, - paddle::operators::GPUUniformRandomKernel); -REGISTER_OP_CUDA_KERNEL(uniform_random_batch_size_like, - paddle::operators::GPUUniformRandomKernel, - paddle::operators::GPUUniformRandomKernel); +namespace plat = paddle::platform; +REGISTER_OP_CUDA_KERNEL( + uniform_random, paddle::operators::GPUUniformRandomKernel, + paddle::operators::GPUUniformRandomKernel, + paddle::operators::GPUUniformRandomKernel); +REGISTER_OP_CUDA_KERNEL( + uniform_random_batch_size_like, + paddle::operators::GPUUniformRandomKernel, + paddle::operators::GPUUniformRandomKernel, + paddle::operators::GPUUniformRandomKernel); From 9f3789944c2c98605f26ffd224fbe1df02fa2e68 Mon Sep 17 00:00:00 2001 From: luotao1 Date: Thu, 16 Aug 2018 11:34:21 +0800 Subject: [PATCH 26/29] use latest anakin commit --- CMakeLists.txt | 3 ++- cmake/external/anakin.cmake | 5 ++--- paddle/fluid/inference/api/CMakeLists.txt | 2 +- 3 files changed, 5 insertions(+), 5 deletions(-) diff --git a/CMakeLists.txt b/CMakeLists.txt index 920c20d6f8..6844772711 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -204,11 +204,12 @@ include(external/snappy) # download snappy include(external/snappystream) include(external/threadpool) -set(WITH_ANAKIN OFF CACHE STRING "Disable Anakin first, will add it later." FORCE) if(WITH_GPU) include(cuda) include(tensorrt) include(external/anakin) +elseif() + set(WITH_ANAKIN OFF CACHE STRING "Anakin is used in GPU only now." 
FORCE) endif() include(cudnn) # set cudnn libraries, must before configure diff --git a/cmake/external/anakin.cmake b/cmake/external/anakin.cmake index 5de7ca8f46..455ef91ac5 100644 --- a/cmake/external/anakin.cmake +++ b/cmake/external/anakin.cmake @@ -35,9 +35,8 @@ set(ANAKIN_COMPILE_EXTRA_FLAGS ExternalProject_Add( extern_anakin ${EXTERNAL_PROJECT_LOG_ARGS} - # TODO(luotao): use PaddlePaddle/Anakin later - GIT_REPOSITORY "https://github.com/luotao1/Anakin" - GIT_TAG "842a89ae3747ede25d8acbc29030d2eb602ced1f" + GIT_REPOSITORY "https://github.com/PaddlePaddle/Anakin" + GIT_TAG "04256ba78fa3da0beb74e8036c8efd68c12824d6" PREFIX ${ANAKIN_SOURCE_DIR} UPDATE_COMMAND "" CMAKE_ARGS -DUSE_GPU_PLACE=YES diff --git a/paddle/fluid/inference/api/CMakeLists.txt b/paddle/fluid/inference/api/CMakeLists.txt index 83867e0a2c..a72e27d651 100644 --- a/paddle/fluid/inference/api/CMakeLists.txt +++ b/paddle/fluid/inference/api/CMakeLists.txt @@ -60,7 +60,7 @@ cc_library(paddle_inference_tensorrt_subgraph_engine inference_api_test(test_api_tensorrt_subgraph_engine SRC api_tensorrt_subgraph_engine_tester.cc ARGS test_word2vec) endif() -if (WITH_ANAKIN) # only needed in CI +if (WITH_ANAKIN AND WITH_GPU) # only needed in CI # compile the libinference_anakin_api.a and anakin.so. nv_library(inference_anakin_api SRCS api.cc api_anakin_engine.cc DEPS anakin_shared anakin_saber) #nv_library(inference_anakin_api_shared SHARED SRCS api.cc api_anakin_engine.cc DEPS anakin) From c44fb003715aab90d14f0d0fce020d0b65ec6fbf Mon Sep 17 00:00:00 2001 From: qingqing01 Date: Thu, 16 Aug 2018 12:01:22 +0800 Subject: [PATCH 27/29] Add name in relu and log API. (#12438) --- paddle/fluid/API.spec | 4 ++-- python/paddle/fluid/layers/nn.py | 8 ++++++-- 2 files changed, 8 insertions(+), 4 deletions(-) diff --git a/paddle/fluid/API.spec b/paddle/fluid/API.spec index ea9105d79c..e963902a50 100644 --- a/paddle/fluid/API.spec +++ b/paddle/fluid/API.spec @@ -155,8 +155,8 @@ paddle.fluid.layers.resize_bilinear ArgSpec(args=['input', 'out_shape', 'scale', paddle.fluid.layers.gather ArgSpec(args=['input', 'index'], varargs=None, keywords=None, defaults=None) paddle.fluid.layers.random_crop ArgSpec(args=['x', 'shape', 'seed'], varargs=None, keywords=None, defaults=(None,)) paddle.fluid.layers.mean_iou ArgSpec(args=['input', 'label', 'num_classes'], varargs=None, keywords=None, defaults=None) -paddle.fluid.layers.relu ArgSpec(args=['x'], varargs=None, keywords=None, defaults=None) -paddle.fluid.layers.log ArgSpec(args=['x'], varargs=None, keywords=None, defaults=None) +paddle.fluid.layers.relu ArgSpec(args=['x', 'name'], varargs=None, keywords=None, defaults=(None,)) +paddle.fluid.layers.log ArgSpec(args=['x', 'name'], varargs=None, keywords=None, defaults=(None,)) paddle.fluid.layers.crop ArgSpec(args=['x', 'shape', 'offsets', 'name'], varargs=None, keywords=None, defaults=(None, None, None)) paddle.fluid.layers.rank_loss ArgSpec(args=['label', 'left', 'right', 'name'], varargs=None, keywords=None, defaults=(None,)) paddle.fluid.layers.prelu ArgSpec(args=['x', 'mode', 'param_attr', 'name'], varargs=None, keywords=None, defaults=(None, None)) diff --git a/python/paddle/fluid/layers/nn.py b/python/paddle/fluid/layers/nn.py index 3e50fc91d9..be852b6711 100644 --- a/python/paddle/fluid/layers/nn.py +++ b/python/paddle/fluid/layers/nn.py @@ -5090,7 +5090,7 @@ def random_crop(x, shape, seed=None): return out -def log(x): +def log(x, name=None): """ Calculates the natural log of the given input tensor, element-wise. 
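
Both signatures in this file now take an optional name, which replaces the auto-generated output variable name. A minimal usage sketch, assuming a standard fluid program; the layer names here are illustrative:

    import paddle.fluid as fluid

    x = fluid.layers.data(name="x", shape=[32], dtype="float32")
    # The new optional argument lets callers name these outputs explicitly
    # instead of relying on automatically generated names.
    y = fluid.layers.relu(x, name="relu_out")
    z = fluid.layers.log(y, name="log_out")
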
@@ -5100,6 +5100,8 @@ def log(x): Args: x (Variable): Input tensor. + name (str|None, default None): A name for this layer (optional). If set None, + the layer will be named automatically. Returns: Variable: The natural log of the input tensor computed element-wise. @@ -5117,7 +5119,7 @@ return out -def relu(x): +def relu(x, name=None): """ Relu takes one input data (Tensor) and produces one output data (Tensor) where the rectified linear function, y = max(0, x), is applied to @@ -5129,6 +5131,8 @@ def relu(x): Args: x (Variable): The input tensor. + name (str|None, default None): A name for this layer (optional). If set None, + the layer will be named automatically. Returns: Variable: The output tensor with the same shape as input. From 317e18abd2aa69390dcc6a0d6760ba954597863e Mon Sep 17 00:00:00 2001 From: Qingsheng Li Date: Thu, 16 Aug 2018 13:00:55 +0800 Subject: [PATCH 28/29] Remove Data Sharing between input and output in scatter_op (#12672) * Remove Data Sharing between input and output in scatter_op * Removed data sharing in backward op --- paddle/fluid/operators/scatter_op.h | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/paddle/fluid/operators/scatter_op.h b/paddle/fluid/operators/scatter_op.h index d29947b55e..181bb1af5c 100644 --- a/paddle/fluid/operators/scatter_op.h +++ b/paddle/fluid/operators/scatter_op.h @@ -35,7 +35,7 @@ class ScatterOpKernel : public framework::OpKernel { auto *Out = ctx.Output("Out"); // In place output: Out = X, Out[Ids] += Updates - Out->ShareDataWith(*X); + framework::TensorCopySync(*X, ctx.GetPlace(), Out); // Apply ScatterUpdate: Out[index] += Updates[:] ScatterAssign(ctx.device_context(), *Updates, *Ids, Out); } @@ -53,7 +53,7 @@ class ScatterGradientOpKernel : public framework::OpKernel { auto *dOut = ctx.Input(framework::GradVarName("Out")); // In place gradient: dX = dO - dX->ShareDataWith(*dOut); + framework::TensorCopySync(*dOut, ctx.GetPlace(), dX); dUpdates->mutable_data(ctx.GetPlace()); // Gradient by Gather: dUpdates += dO[Ids] CPUGather(ctx.device_context(), *dOut, *Ids, dUpdates); From d7873e14124a157980049f3dc6a281638ce437ee Mon Sep 17 00:00:00 2001 From: dzhwinter Date: Thu, 16 Aug 2018 13:48:46 +0800 Subject: [PATCH 29/29] remove patchelf in windows (#12710) * remove patchelf in windows * "follow comment" --- .gitignore | 2 ++ cmake/configure.cmake | 4 ++++ python/CMakeLists.txt | 5 +++-- 3 files changed, 9 insertions(+), 2 deletions(-) diff --git a/.gitignore b/.gitignore index 9e3a0b499f..b92bb9cc12 100644 --- a/.gitignore +++ b/.gitignore @@ -5,6 +5,7 @@ python/paddle/v2/fluid/tests/book/image_classification_resnet.inference.model/ python/paddle/v2/fluid/tests/book/image_classification_vgg.inference.model/ python/paddle/v2/fluid/tests/book/label_semantic_roles.inference.model/ *.DS_Store +*.vs build/ build_doc/ *.user @@ -15,6 +16,7 @@ build_doc/ .cproject .pydevproject .settings/ +CMakeSettings.json Makefile .test_env/ third_party/ diff --git a/cmake/configure.cmake b/cmake/configure.cmake index ae90a529b1..d14162e0a6 100644 --- a/cmake/configure.cmake +++ b/cmake/configure.cmake @@ -56,6 +56,10 @@ if(NOT CMAKE_CROSSCOMPILING) set(SIMD_FLAG ${SSE3_FLAG}) endif() endif() +if(UNIX AND NOT APPLE) + # exclude Apple from the *nix OS family + set(LINUX TRUE) +endif(UNIX AND NOT APPLE) if(NOT WITH_GOLANG) add_definitions(-DPADDLE_WITHOUT_GOLANG) diff --git a/python/CMakeLists.txt b/python/CMakeLists.txt index 2590081150..9cdcb87df5 100644 --- a/python/CMakeLists.txt +++ b/python/CMakeLists.txt @@ -97,10 +97,11 @@ if(APPLE) if(NOT 
INSTALL_NAME_TOOL_EXECUTABLE)
     message(FATAL_ERROR "install_name_tool not found, please check.\n")
   endif()
-else(APPLE)
+endif()
+if(LINUX)
   find_program(PATCHELF_EXECUTABLE patchelf)
   if(NOT PATCHELF_EXECUTABLE)
     message(FATAL_ERROR "patchelf not found, please install it.\n"
             "For Ubuntu, the command is: apt-get install -y patchelf.")
   endif()
-endif(APPLE)
+endif(LINUX)
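
Looking back at the scatter_op patch above: ShareDataWith aliased Out (and dX) onto the input buffer, so the subsequent in-place ScatterAssign also mutated X; TensorCopySync instead gives the output its own storage. A NumPy sketch of the intended semantics, following the kernel's own "Out = X, Out[Ids] += Updates" comment; the names are illustrative:

    import numpy as np

    def scatter_add(x, ids, updates):
        # Copy first: the point of the fix is that X itself stays untouched.
        out = x.copy()
        # Out[ids] += updates; np.add.at accumulates over repeated ids.
        np.add.at(out, ids, updates)
        return out

    x = np.zeros((4, 3), dtype="float32")
    ids = np.array([1, 1, 3])
    updates = np.ones((3, 3), dtype="float32")
    print(scatter_add(x, ids, updates))  # x is unchanged afterwards
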