From 3e01a4048f28ad5cf4b33fb808b07965d9e7ff5d Mon Sep 17 00:00:00 2001 From: tensor-tang <tangjian03@baidu.com> Date: Fri, 28 Dec 2018 16:34:13 +0000 Subject: [PATCH 01/28] add refer seqpool jitkernel --- paddle/fluid/operators/jit/kernel_base.h | 20 +++++++++++++++++++ paddle/fluid/operators/jit/kernel_key.cc | 6 ++++++ .../fluid/operators/jit/refer/CMakeLists.txt | 1 + paddle/fluid/operators/jit/refer/refer.cc | 2 ++ paddle/fluid/operators/jit/refer/refer.h | 16 +++++++++++++++ 5 files changed, 45 insertions(+) diff --git a/paddle/fluid/operators/jit/kernel_base.h b/paddle/fluid/operators/jit/kernel_base.h index b4a2d5d473..8f13fbb16e 100644 --- a/paddle/fluid/operators/jit/kernel_base.h +++ b/paddle/fluid/operators/jit/kernel_base.h @@ -41,6 +41,7 @@ typedef enum { kCRFDecoding, kLayerNorm, kNCHW16CMulNC, + kSeqPool, } KernelType; template <typename T> @@ -112,6 +113,25 @@ struct GRUTuples { typedef void (*func_type)(gru_t*, const gru_attr_t*); }; +typedef enum { + non = 0, + sum, + avg, + sqrt, +} SeqPoolType; + +typedef struct { + int h, w; + SeqPoolType type; +} seq_pool_attr_t; + +template <typename T> +struct SeqPoolTuples { + typedef T data_type; + typedef seq_pool_attr_t attr_type; + typedef void (*func_type)(const T*, T*, const seq_pool_attr_t*); +}; + template <typename T> struct CRFDecodingTuples { typedef T data_type; diff --git a/paddle/fluid/operators/jit/kernel_key.cc b/paddle/fluid/operators/jit/kernel_key.cc index 4e6a19f04f..6b0025a75a 100644 --- a/paddle/fluid/operators/jit/kernel_key.cc +++ b/paddle/fluid/operators/jit/kernel_key.cc @@ -42,6 +42,12 @@ size_t JitCodeKey<gru_attr_t>(const gru_attr_t& attr) { (static_cast<int>(attr.act_cand) << act_type_shift); } +template <> +size_t JitCodeKey<seq_pool_attr_t>(const seq_pool_attr_t& attr) { + size_t key = static_cast<size_t>(attr.type); + return key + (attr.w << act_type_shift); +} + } // namespace jit } // namespace operators } // namespace paddle diff --git a/paddle/fluid/operators/jit/refer/CMakeLists.txt b/paddle/fluid/operators/jit/refer/CMakeLists.txt index 07497b7320..0f626bb3bf 100644 --- a/paddle/fluid/operators/jit/refer/CMakeLists.txt +++ b/paddle/fluid/operators/jit/refer/CMakeLists.txt @@ -26,3 +26,4 @@ USE_JITKERNEL_REFER(kGRUHtPart2) USE_JITKERNEL_REFER(kCRFDecoding) USE_JITKERNEL_REFER(kLayerNorm) USE_JITKERNEL_REFER(kNCHW16CMulNC) +USE_JITKERNEL_REFER(kSeqPool) diff --git a/paddle/fluid/operators/jit/refer/refer.cc b/paddle/fluid/operators/jit/refer/refer.cc index d196266326..85381daa47 100644 --- a/paddle/fluid/operators/jit/refer/refer.cc +++ b/paddle/fluid/operators/jit/refer/refer.cc @@ -47,4 +47,6 @@ REGISTER_REFER_KERNEL(kLayerNorm, LayerNorm); REGISTER_REFER_KERNEL(kNCHW16CMulNC, NCHW16CMulNC); +REGISTER_REFER_KERNEL(kSeqPool, SeqPool); + #undef REGISTER_REFER_KERNEL diff --git a/paddle/fluid/operators/jit/refer/refer.h b/paddle/fluid/operators/jit/refer/refer.h index 0fd1b89dfd..52fe2de02a 100644 --- a/paddle/fluid/operators/jit/refer/refer.h +++ b/paddle/fluid/operators/jit/refer/refer.h @@ -332,6 +332,20 @@ void NCHW16CMulNC(const T* x, const T* y, T* z, int height, int width) { } } +template <typename T> +void SeqPool(const T* x, T* y, const seq_pool_attr_t* attr) { + PADDLE_ENFORCE(attr->type == SeqPoolType::sum, "Only support sum yet"); + for (int w = 0; w < attr->w; ++w) { + const T* src = x + w; + T* dst = y + w; + *dst = static_cast<T>(0); + for (int h = 0; h < attr->h; ++h) { + *dst = *dst + *src; + src += attr->w; + } + } +} + #define DECLARE_REFER_KERNEL(name, tuples) \ template 
<typename T> \ class name##Kernel : public ReferKernel<tuples<T>> { \ @@ -370,6 +384,8 @@ DECLARE_REFER_KERNEL(LayerNorm, LayerNormTuples); DECLARE_REFER_KERNEL(NCHW16CMulNC, NCHW16CMulNCTuples); +DECLARE_REFER_KERNEL(SeqPool, SeqPoolTuples); + #undef DECLARE_REFER_KERNEL } // namespace refer From e58a569c6cdb8ab66c7dff69395518cee224fe67 Mon Sep 17 00:00:00 2001 From: tensor-tang <tangjian03@baidu.com> Date: Fri, 28 Dec 2018 16:35:00 +0000 Subject: [PATCH 02/28] use seqpool jitkernel --- paddle/fluid/operators/math/CMakeLists.txt | 2 +- .../fluid/operators/math/sequence_pooling.cc | 32 ++++++++++++------- 2 files changed, 22 insertions(+), 12 deletions(-) diff --git a/paddle/fluid/operators/math/CMakeLists.txt b/paddle/fluid/operators/math/CMakeLists.txt index ea6aebd291..600ab14d37 100644 --- a/paddle/fluid/operators/math/CMakeLists.txt +++ b/paddle/fluid/operators/math/CMakeLists.txt @@ -51,7 +51,7 @@ math_library(pooling) math_library(selected_rows_functor DEPS selected_rows math_function blas) math_library(sequence2batch) math_library(sequence_padding) -math_library(sequence_pooling DEPS math_function) +math_library(sequence_pooling DEPS math_function jit_kernel_helper) math_library(sequence_scale) math_library(softmax DEPS math_function) diff --git a/paddle/fluid/operators/math/sequence_pooling.cc b/paddle/fluid/operators/math/sequence_pooling.cc index 6d491dbf1e..23dc516933 100644 --- a/paddle/fluid/operators/math/sequence_pooling.cc +++ b/paddle/fluid/operators/math/sequence_pooling.cc @@ -14,6 +14,7 @@ limitations under the License. */ #include <string> +#include "paddle/fluid/operators/jit/kernels.h" #include "paddle/fluid/operators/math/blas.h" #include "paddle/fluid/operators/math/math_function.h" #include "paddle/fluid/operators/math/sequence_pooling.h" @@ -239,15 +240,33 @@ class SequencePoolFunctor<platform::CPUDeviceContext, T> { last_pool(context, input, output); return; } - if (pooltype == "FIRST") { math::FirstSeqPoolFunctor<T> first_pool; first_pool(context, input, output); return; } + auto lod = input.lod()[0]; + if (pooltype == "SUM") { + auto place = context.GetPlace(); + PADDLE_ENFORCE(platform::is_cpu_place(place)); + const T* src = input.data<T>(); + T* dst = output->mutable_data<T>(place); + jit::seq_pool_attr_t attr; + attr.w = input.numel() / input.dims()[0]; + attr.type = jit::SeqPoolType::sum; + auto seqpool = + jit::Get<jit::kSeqPool, jit::SeqPoolTuples<T>, platform::CPUPlace>( + attr); + for (int i = 0; i < static_cast<int>(lod.size()) - 1; ++i) { + attr.h = static_cast<int>(lod[i + 1] - lod[i]); + seqpool(src, dst, &attr); + dst += attr.w; + src += attr.h * attr.w; + } + return; + } auto& place = *context.eigen_device(); - auto blas = math::GetBlas<platform::CPUDeviceContext, T>(context); for (int i = 0; i < static_cast<int>(lod.size()) - 1; ++i) { Tensor in_t = input.Slice(static_cast<int>(lod[i]), static_cast<int>(lod[i + 1])); @@ -258,15 +277,6 @@ class SequencePoolFunctor<platform::CPUDeviceContext, T> { auto out_e = EigenVector<T>::Flatten(out_t); if (pooltype == "AVERAGE") { out_e.device(place) = in_e.mean(Eigen::array<int, 1>({{0}})); - } else if (pooltype == "SUM") { - if (h > 0) { - const T* in_data = in_t.data<T>(); - T* out_data = out_t.mutable_data<T>(context.GetPlace()); - blas.VCOPY(w, in_data, out_data); - for (int64_t r = 1; r != h; ++r) { - blas.AXPY(w, 1., in_data + r * w, out_data); - } - } } else if (pooltype == "SQRT") { out_e.device(place) = in_e.sum(Eigen::array<int, 1>({{0}})) / std::sqrt(static_cast<T>(h)); From 
142bb417483f9e0e71a26d24d30eb01c6d2f7754 Mon Sep 17 00:00:00 2001 From: tensor-tang <tangjian03@baidu.com> Date: Sat, 29 Dec 2018 05:13:08 +0000 Subject: [PATCH 03/28] add seqpool jitkernel test and benchmark --- paddle/fluid/operators/jit/benchmark.cc | 21 ++++++++ paddle/fluid/operators/jit/helper.cc | 15 ++++++ paddle/fluid/operators/jit/helper.h | 6 +++ paddle/fluid/operators/jit/kernel_base.h | 19 ++++---- paddle/fluid/operators/jit/refer/refer.h | 2 +- paddle/fluid/operators/jit/test.cc | 48 +++++++++++++++++++ .../fluid/operators/math/sequence_pooling.cc | 2 +- 7 files changed, 103 insertions(+), 10 deletions(-) diff --git a/paddle/fluid/operators/jit/benchmark.cc b/paddle/fluid/operators/jit/benchmark.cc index 437005825d..f64e43389a 100644 --- a/paddle/fluid/operators/jit/benchmark.cc +++ b/paddle/fluid/operators/jit/benchmark.cc @@ -190,6 +190,24 @@ void BenchGRUKernel() { } } +template <paddle::operators::jit::KernelType KT, typename T, typename PlaceType> +void BenchSeqPoolKernel() { + std::vector<jit::SeqPoolType> pool_types = {jit::SeqPoolType::kSum}; + for (auto type : pool_types) { + for (int h : TestSizes()) { + for (int w : TestSizes()) { + const jit::seq_pool_attr_t attr(h, w, type); + std::vector<T> x(h * w), y(w); + RandomVec<T>(h * w, x.data(), -2.f, 2.f); + const T* x_data = x.data(); + T* y_data = y.data(); + BenchAllImpls<KT, jit::SeqPoolTuples<T>, PlaceType>(attr, x_data, + y_data, &attr); + } + } + } +} + // Benchmark all jit kernels including jitcode, mkl and refer. // To use this tool, run command: ./benchmark [options...] // Options: @@ -228,4 +246,7 @@ int main(int argc, char* argv[]) { BenchGRUKernel<jit::kGRUH1, T, PlaceType>(); BenchGRUKernel<jit::kGRUHtPart1, T, PlaceType>(); BenchGRUKernel<jit::kGRUHtPart2, T, PlaceType>(); + + // seq pool function + BenchSeqPoolKernel<jit::kSeqPool, T, PlaceType>(); } diff --git a/paddle/fluid/operators/jit/helper.cc b/paddle/fluid/operators/jit/helper.cc index d00584baa0..7d02590f2e 100644 --- a/paddle/fluid/operators/jit/helper.cc +++ b/paddle/fluid/operators/jit/helper.cc @@ -26,6 +26,7 @@ namespace jit { const char* to_string(KernelType kt) { switch (kt) { + ONE_CASE(kNone); ONE_CASE(kVMul); ONE_CASE(kVAdd); ONE_CASE(kVAddRelu); @@ -45,12 +46,26 @@ const char* to_string(KernelType kt) { ONE_CASE(kCRFDecoding); ONE_CASE(kLayerNorm); ONE_CASE(kNCHW16CMulNC); + ONE_CASE(kSeqPool); default: PADDLE_THROW("Not support type: %d, or forget to add it.", kt); return "NOT JITKernel"; } return nullptr; } + +const char* to_string(SeqPoolType tp) { + switch (tp) { + ONE_CASE(kNonePoolType); + ONE_CASE(kSum); + ONE_CASE(kAvg); + ONE_CASE(kSqrt); + default: + PADDLE_THROW("Not support type: %d, or forget to add it.", tp); + return "NOT PoolType"; + } + return nullptr; +} #undef ONE_CASE KernelType to_kerneltype(const std::string& act) { diff --git a/paddle/fluid/operators/jit/helper.h b/paddle/fluid/operators/jit/helper.h index 412df86aa1..fbf34fc4b3 100644 --- a/paddle/fluid/operators/jit/helper.h +++ b/paddle/fluid/operators/jit/helper.h @@ -119,6 +119,7 @@ typename KernelTuples::func_type Get( } const char* to_string(KernelType kt); +const char* to_string(SeqPoolType kt); KernelType to_kerneltype(const std::string& act); @@ -134,6 +135,11 @@ inline std::ostream& operator<<(std::ostream& os, const gru_attr_t& attr) { << "],act_cand[" << to_string(attr.act_cand) << "]"; return os; } +inline std::ostream& operator<<(std::ostream& os, const seq_pool_attr_t& attr) { + os << "height_size[" << attr.h << "],width_size[" << attr.w << 
"],pool_type[" + << to_string(attr.type) << "]"; + return os; +} } // namespace jit } // namespace operators diff --git a/paddle/fluid/operators/jit/kernel_base.h b/paddle/fluid/operators/jit/kernel_base.h index 8f13fbb16e..2659374650 100644 --- a/paddle/fluid/operators/jit/kernel_base.h +++ b/paddle/fluid/operators/jit/kernel_base.h @@ -44,6 +44,13 @@ typedef enum { kSeqPool, } KernelType; +typedef enum { + kNonePoolType = 0, + kSum, + kAvg, + kSqrt, +} SeqPoolType; + template <typename T> struct XYZNTuples { typedef T data_type; @@ -113,16 +120,12 @@ struct GRUTuples { typedef void (*func_type)(gru_t*, const gru_attr_t*); }; -typedef enum { - non = 0, - sum, - avg, - sqrt, -} SeqPoolType; - -typedef struct { +typedef struct seq_pool_attr_s { int h, w; SeqPoolType type; + seq_pool_attr_s() = default; + explicit seq_pool_attr_s(int height, int width, SeqPoolType pool_type) + : h(height), w(width), type(pool_type) {} } seq_pool_attr_t; template <typename T> diff --git a/paddle/fluid/operators/jit/refer/refer.h b/paddle/fluid/operators/jit/refer/refer.h index 52fe2de02a..c2aa922528 100644 --- a/paddle/fluid/operators/jit/refer/refer.h +++ b/paddle/fluid/operators/jit/refer/refer.h @@ -334,7 +334,7 @@ void NCHW16CMulNC(const T* x, const T* y, T* z, int height, int width) { template <typename T> void SeqPool(const T* x, T* y, const seq_pool_attr_t* attr) { - PADDLE_ENFORCE(attr->type == SeqPoolType::sum, "Only support sum yet"); + PADDLE_ENFORCE(attr->type == SeqPoolType::kSum, "Only support sum yet"); for (int w = 0; w < attr->w; ++w) { const T* src = x + w; T* dst = y + w; diff --git a/paddle/fluid/operators/jit/test.cc b/paddle/fluid/operators/jit/test.cc index a73e2a60ae..0f1776507a 100644 --- a/paddle/fluid/operators/jit/test.cc +++ b/paddle/fluid/operators/jit/test.cc @@ -211,6 +211,24 @@ struct TestFuncWithRefer<jit::GRUTuples<T>, std::vector<T>, std::vector<T>, } }; +template <typename T> +struct TestFuncWithRefer<jit::SeqPoolTuples<T>, std::vector<T>, + std::vector<T>> { + void operator()(const typename jit::SeqPoolTuples<T>::func_type tgt, + const std::vector<T>& x, const std::vector<T>& yref, + const typename jit::SeqPoolTuples<T>::attr_type& attr) { + EXPECT_TRUE(tgt != nullptr); + EXPECT_EQ(x.size() % yref.size(), 0); + int w = yref.size(); + std::vector<T> y(w); + const T* x_data = x.data(); + const T* yref_data = yref.data(); + T* y_data = y.data(); + tgt(x_data, y_data, &attr); + ExpectEQ<T>(y_data, yref_data, w); + } +}; + template <paddle::operators::jit::KernelType KT, typename KernelTuples, typename PlaceType, typename... Args> void TestAllImpls(const typename KernelTuples::attr_type& attr, Args... 
args) { @@ -415,6 +433,30 @@ void TestGRUKernel() { } } +template <paddle::operators::jit::KernelType KT, typename T, typename PlaceType> +void TestSeqPoolKernel() { + VLOG(10) << "===== Test JITKernel " << jit::to_string(KT); + // TODO(TJ): support more + std::vector<jit::SeqPoolType> pool_types = {jit::SeqPoolType::kSum}; + for (auto type : pool_types) { + for (int h : TestSizes()) { + for (int w : TestSizes()) { + const jit::seq_pool_attr_t attr(h, w, type); + auto ref = jit::GetRefer<KT, jit::SeqPoolTuples<T>>(); + EXPECT_TRUE(ref != nullptr); + std::vector<T> x(h * w), yref(w); + RandomVec<T>(h * w, x.data(), -2.f, 2.f); + const T* x_data = x.data(); + T* yref_data = yref.data(); + ref(x_data, yref_data, &attr); + VLOG(10) << attr; + TestAllImpls<KT, jit::SeqPoolTuples<T>, PlaceType, std::vector<T>, + std::vector<T>>(attr, x, yref, attr); + } + } + } +} + template <paddle::operators::jit::KernelType KT, typename T, typename PlaceType> void TestNCHW16CMulNCKernel() { VLOG(10) << "===== Test JITKernel " << jit::to_string(KT); @@ -569,6 +611,12 @@ TEST(JITKernel, kGRUHtPart2) { TestGRUKernel<jit::kGRUHtPart2, double, paddle::platform::CPUPlace>(); } +TEST(JITKernel, kSeqPool) { + namespace jit = paddle::operators::jit; + TestSeqPoolKernel<jit::kSeqPool, float, paddle::platform::CPUPlace>(); + TestSeqPoolKernel<jit::kSeqPool, double, paddle::platform::CPUPlace>(); +} + TEST(JITKernel, kNCHW16CMulNC) { namespace jit = paddle::operators::jit; TestNCHW16CMulNCKernel<jit::kNCHW16CMulNC, float, diff --git a/paddle/fluid/operators/math/sequence_pooling.cc b/paddle/fluid/operators/math/sequence_pooling.cc index 23dc516933..98707c936d 100644 --- a/paddle/fluid/operators/math/sequence_pooling.cc +++ b/paddle/fluid/operators/math/sequence_pooling.cc @@ -254,7 +254,7 @@ class SequencePoolFunctor<platform::CPUDeviceContext, T> { T* dst = output->mutable_data<T>(place); jit::seq_pool_attr_t attr; attr.w = input.numel() / input.dims()[0]; - attr.type = jit::SeqPoolType::sum; + attr.type = jit::SeqPoolType::kSum; auto seqpool = jit::Get<jit::kSeqPool, jit::SeqPoolTuples<T>, platform::CPUPlace>( attr); From c50060bb264a3e70ef55abfdd8ab74416cb14121 Mon Sep 17 00:00:00 2001 From: tensor-tang <tangjian03@baidu.com> Date: Sat, 29 Dec 2018 06:26:02 +0000 Subject: [PATCH 04/28] add jitcode impl and use it --- paddle/fluid/operators/jit/gen/CMakeLists.txt | 1 + paddle/fluid/operators/jit/gen/seqpool.cc | 132 ++++++++++++++++++ paddle/fluid/operators/jit/gen/seqpool.h | 98 +++++++++++++ paddle/fluid/operators/jit/kernel_key.cc | 7 +- .../fluid/operators/math/sequence_pooling.cc | 6 +- 5 files changed, 239 insertions(+), 5 deletions(-) create mode 100644 paddle/fluid/operators/jit/gen/seqpool.cc create mode 100644 paddle/fluid/operators/jit/gen/seqpool.h diff --git a/paddle/fluid/operators/jit/gen/CMakeLists.txt b/paddle/fluid/operators/jit/gen/CMakeLists.txt index 8a54010830..2b8c758a03 100644 --- a/paddle/fluid/operators/jit/gen/CMakeLists.txt +++ b/paddle/fluid/operators/jit/gen/CMakeLists.txt @@ -26,3 +26,4 @@ USE_JITKERNEL_GEN(kGRUH1) USE_JITKERNEL_GEN(kGRUHtPart1) USE_JITKERNEL_GEN(kGRUHtPart2) USE_JITKERNEL_GEN(kNCHW16CMulNC) +USE_JITKERNEL_GEN(kSeqPool) diff --git a/paddle/fluid/operators/jit/gen/seqpool.cc b/paddle/fluid/operators/jit/gen/seqpool.cc new file mode 100644 index 0000000000..ce6801b030 --- /dev/null +++ b/paddle/fluid/operators/jit/gen/seqpool.cc @@ -0,0 +1,132 @@ +/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved. 
+ * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. */ + +#include "paddle/fluid/operators/jit/gen/seqpool.h" +#include "paddle/fluid/operators/jit/registry.h" +#include "paddle/fluid/platform/cpu_info.h" + +namespace paddle { +namespace operators { +namespace jit { +namespace gen { + +void SeqPoolJitCode::genCode() { + constexpr int block = YMM_FLOAT_BLOCK; + constexpr int max_num_regs = 8; + const int num_block = w_ / block; + const int num_groups = num_block / max_num_regs; + int rest_num_regs = num_block % max_num_regs; + if (type_ == SeqPoolType::kAvg) { + float scalar = 1.f / h_; + mov(reg32_scalar, scalar); + } else if (type_ == SeqPoolType::kSqrt) { + float scalar = 1.f / std::sqrt(static_cast<float>(h_)); + mov(reg32_scalar, scalar); + } + + // TODO(TJ): make height load from params + const int group_len = max_num_regs * block * sizeof(float); + for (int g = 0; g < num_groups; ++g) { + pool_height<ymm_t>(g * group_len, block, max_num_regs); + } + if (rest_num_regs > 0) { + pool_height<ymm_t>(num_groups * group_len, block, rest_num_regs); + } + + // rest part + const int rest = w_ % block; + const bool has_block4 = rest / 4 > 0; + const bool has_block2 = (rest % 4) / 2 > 0; + const bool has_block1 = (rest % 2) == 1; + const int w_offset = num_block * YMM_FLOAT_BLOCK * sizeof(float); + for (int h = 0; h < h_; ++h) { + int offset = h * w_ * sizeof(float) + w_offset; + const int shift_regs = (h == 0) ? 
0 : max_num_regs; + int reg_idx = 0; + if (has_block4) { + vmovups(xmm_t(reg_idx + shift_regs), ptr[param1 + offset]); + offset += sizeof(float) * 4; + reg_idx++; + } + if (has_block2) { + vmovq(xmm_t(reg_idx + shift_regs), ptr[param1 + offset]); + offset += sizeof(float) * 2; + reg_idx++; + } + if (has_block1) { + vmovss(xmm_t(reg_idx + shift_regs), ptr[param1 + offset]); + reg_idx++; + } + rest_num_regs = reg_idx; + if (h > 0) { + for (int i = 0; i < reg_idx; ++i) { + vaddps(xmm_t(i), xmm_t(i), xmm_t(i + max_num_regs)); + } + } + } + // save right now + int offset = w_offset; + if (type_ == SeqPoolType::kAvg || type_ == SeqPoolType::kSqrt) { + vbroadcastss(xmm_t(max_num_regs - 1), reg32_scalar); + for (int i = 0; i < rest_num_regs; ++i) { + vmulps(xmm_t(i), xmm_t(i), xmm_t(max_num_regs - 1)); + } + } + int reg_idx = 0; + if (has_block4) { + vmovups(ptr[param2 + offset], xmm_t(reg_idx)); + offset += sizeof(float) * 4; + reg_idx++; + } + if (has_block2) { + vmovq(ptr[param2 + offset], xmm_t(reg_idx)); + offset += sizeof(float) * 2; + reg_idx++; + } + if (has_block1) { + vmovss(ptr[param2 + offset], xmm_t(reg_idx)); + } + ret(); +} + +class SeqPoolCreator : public JitCodeCreator<seq_pool_attr_t> { + public: + bool UseMe(const seq_pool_attr_t& attr) const override { + return platform::MayIUse(platform::avx); + } + size_t CodeSize(const seq_pool_attr_t& attr) const override { + // TODO(TJ): remove attr.h when enabled height + bool yes = + attr.type == SeqPoolType::kAvg || attr.type == SeqPoolType::kSqrt; + return 96 /* basic */ + + ((attr.w / YMM_FLOAT_BLOCK + 4 /* rest */) * 2 /* for sum */ + * (attr.h + (yes ? 3 : 1 /*for avg or sqrt*/))) * + 8; + } + std::unique_ptr<GenBase> CreateJitCode( + const seq_pool_attr_t& attr) const override { + PADDLE_ENFORCE_GT(attr.w, 0); + PADDLE_ENFORCE_GT(attr.h, 0); + return make_unique<SeqPoolJitCode>(attr, CodeSize(attr)); + } +}; + +} // namespace gen +} // namespace jit +} // namespace operators +} // namespace paddle + +namespace gen = paddle::operators::jit::gen; + +REGISTER_JITKERNEL_GEN(kSeqPool, gen::SeqPoolCreator); diff --git a/paddle/fluid/operators/jit/gen/seqpool.h b/paddle/fluid/operators/jit/gen/seqpool.h new file mode 100644 index 0000000000..eb2d191382 --- /dev/null +++ b/paddle/fluid/operators/jit/gen/seqpool.h @@ -0,0 +1,98 @@ +/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
*/ + +#pragma once + +#include <string> +#include "glog/logging.h" +#include "paddle/fluid/operators/jit/gen/jitcode.h" + +namespace paddle { +namespace operators { +namespace jit { +namespace gen { + +class SeqPoolJitCode : public JitCode { + public: + explicit SeqPoolJitCode(const seq_pool_attr_t& attr, + size_t code_size = 256 * 1024, + void* code_ptr = nullptr) + : JitCode(code_size, code_ptr), h_(attr.h), w_(attr.w), type_(attr.type) { + if (type_ != SeqPoolType::kSum) { + LOG(FATAL) << "Only support sum pool yet "; + } + this->genCode(); + } + + virtual const char* name() const { + std::string base = "SeqPoolJitCode"; + if (type_ == SeqPoolType::kSum) { + base += "_Sum"; + } else if (type_ == SeqPoolType::kAvg) { + base += "_Avg"; + } else if (type_ == SeqPoolType::kSqrt) { + base += "_Sqrt"; + } + base += ("_W" + std::to_string(w_)); + // TODO(TJ): make h load from params + base += ("_H" + std::to_string(h_)); + return base.c_str(); + } + void genCode() override; + + protected: + template <typename JMM> + void pool_height(int w_offset, int block, int max_num_regs) { + for (int h = 0; h < h_; ++h) { + int offset = h * w_ * sizeof(float) + w_offset; + const int shift_regs = (h == 0) ? 0 : max_num_regs; + for (int i = 0; i < max_num_regs; ++i) { + vmovups(JMM(i + shift_regs), ptr[param1 + offset]); + offset += sizeof(float) * block; + } + if (h > 0) { + // sum anyway + for (int i = 0; i < max_num_regs; ++i) { + vaddps(JMM(i), JMM(i), JMM(i + max_num_regs)); + } + } + } + // save right now + if (type_ == SeqPoolType::kAvg || type_ == SeqPoolType::kSqrt) { + vbroadcastss(JMM(max_num_regs), reg32_scalar); + } + int offset = w_offset; + for (int i = 0; i < max_num_regs; ++i) { + if (type_ == SeqPoolType::kAvg || type_ == SeqPoolType::kSqrt) { + vmulps(JMM(i), JMM(i), JMM(max_num_regs)); + } + vmovups(ptr[param2 + offset], JMM(i)); + offset += sizeof(float) * block; + } + } + + private: + int h_; + int w_; + SeqPoolType type_; + reg64_t param1{abi_param1}; + reg64_t param2{abi_param2}; + reg64_t param3{abi_param3}; + reg32_t reg32_scalar{r8d}; +}; + +} // namespace gen +} // namespace jit +} // namespace operators +} // namespace paddle diff --git a/paddle/fluid/operators/jit/kernel_key.cc b/paddle/fluid/operators/jit/kernel_key.cc index 6b0025a75a..db78ed8ad8 100644 --- a/paddle/fluid/operators/jit/kernel_key.cc +++ b/paddle/fluid/operators/jit/kernel_key.cc @@ -44,8 +44,11 @@ size_t JitCodeKey<gru_attr_t>(const gru_attr_t& attr) { template <> size_t JitCodeKey<seq_pool_attr_t>(const seq_pool_attr_t& attr) { - size_t key = static_cast<size_t>(attr.type); - return key + (attr.w << act_type_shift); + size_t key = attr.w; + // TODO(TJ): support height, then removed it from key + constexpr int w_shift = 30; + return (key << act_type_shift) + static_cast<int>(attr.type) + + (static_cast<size_t>(attr.h) << (act_type_shift + w_shift)); } } // namespace jit diff --git a/paddle/fluid/operators/math/sequence_pooling.cc b/paddle/fluid/operators/math/sequence_pooling.cc index 98707c936d..283e2e251a 100644 --- a/paddle/fluid/operators/math/sequence_pooling.cc +++ b/paddle/fluid/operators/math/sequence_pooling.cc @@ -255,11 +255,11 @@ class SequencePoolFunctor<platform::CPUDeviceContext, T> { jit::seq_pool_attr_t attr; attr.w = input.numel() / input.dims()[0]; attr.type = jit::SeqPoolType::kSum; - auto seqpool = - jit::Get<jit::kSeqPool, jit::SeqPoolTuples<T>, platform::CPUPlace>( - attr); for (int i = 0; i < static_cast<int>(lod.size()) - 1; ++i) { attr.h = static_cast<int>(lod[i + 1] - lod[i]); + auto 
seqpool = + jit::Get<jit::kSeqPool, jit::SeqPoolTuples<T>, platform::CPUPlace>( + attr); seqpool(src, dst, &attr); dst += attr.w; src += attr.h * attr.w; From 92201d3956a4f64615baf5bc9e979bcfc6bd09bd Mon Sep 17 00:00:00 2001 From: tensor-tang <tangjian03@baidu.com> Date: Fri, 4 Jan 2019 06:41:40 +0000 Subject: [PATCH 05/28] support avg and sqrt pool and add mkl impl test=develop --- .../operators/jit/more/mkl/CMakeLists.txt | 1 + paddle/fluid/operators/jit/more/mkl/mkl.cc | 31 +++++++++++++++++++ paddle/fluid/operators/jit/more/mkl/mkl.h | 26 ++++++++++++++++ paddle/fluid/operators/jit/refer/refer.h | 9 ++++++ 4 files changed, 67 insertions(+) diff --git a/paddle/fluid/operators/jit/more/mkl/CMakeLists.txt b/paddle/fluid/operators/jit/more/mkl/CMakeLists.txt index 863cc720d6..f5ed2f0572 100644 --- a/paddle/fluid/operators/jit/more/mkl/CMakeLists.txt +++ b/paddle/fluid/operators/jit/more/mkl/CMakeLists.txt @@ -9,3 +9,4 @@ USE_JITKERNEL_MORE(kVScal, mkl) USE_JITKERNEL_MORE(kVExp, mkl) USE_JITKERNEL_MORE(kVSigmoid, mkl) USE_JITKERNEL_MORE(kVTanh, mkl) +USE_JITKERNEL_MORE(kSeqPool, mkl) diff --git a/paddle/fluid/operators/jit/more/mkl/mkl.cc b/paddle/fluid/operators/jit/more/mkl/mkl.cc index a5b088d481..5a499ac2c0 100644 --- a/paddle/fluid/operators/jit/more/mkl/mkl.cc +++ b/paddle/fluid/operators/jit/more/mkl/mkl.cc @@ -72,6 +72,26 @@ void VExp<double>(const double* x, double* y, int n) { platform::dynload::vdExp(n, x, y); } +template <> +void VCopy<float>(const float* x, float* y, int n) { + platform::dynload::cblas_scopy(n, x, 1, y, 1); +} + +template <> +void VCopy<double>(const double* x, double* y, int n) { + platform::dynload::cblas_dcopy(n, x, 1, y, 1); +} + +template <> +void VAXPY<float>(float a, const float* x, float* y, int n) { + platform::dynload::cblas_saxpy(n, a, x, 1, y, 1); +} + +template <> +void VAXPY<double>(double a, const double* x, double* y, int n) { + platform::dynload::cblas_daxpy(n, a, x, 1, y, 1); +} + // TODO(TJ): tuning me carefully on AVX, AVX2 and AVX512 template <> bool VMulKernel<float>::UseMe(const int& d) const { @@ -103,6 +123,16 @@ bool VTanhKernel<float>::UseMe(const int& d) const { return d > 7; } +template <> +bool SeqPoolKernel<float>::UseMe(const seq_pool_attr_t& attr) const { + return true; +} + +template <> +bool SeqPoolKernel<double>::UseMe(const seq_pool_attr_t& attr) const { + return true; +} + #define AWALYS_USE_ME_WITH_DOUBLE(func) \ template <> \ bool func##Kernel<double>::UseMe(const int& d) const { \ @@ -135,5 +165,6 @@ REGISTER_MKL_KERNEL(kVScal, VScal); REGISTER_MKL_KERNEL(kVExp, VExp); REGISTER_MKL_KERNEL(kVSigmoid, VSigmoid); REGISTER_MKL_KERNEL(kVTanh, VTanh); +REGISTER_MKL_KERNEL(kSeqPool, SeqPool); #undef REGISTER_MKL_KERNEL diff --git a/paddle/fluid/operators/jit/more/mkl/mkl.h b/paddle/fluid/operators/jit/more/mkl/mkl.h index ee1031c028..0a3816db24 100644 --- a/paddle/fluid/operators/jit/more/mkl/mkl.h +++ b/paddle/fluid/operators/jit/more/mkl/mkl.h @@ -14,6 +14,7 @@ #pragma once +#include <cmath> #include <type_traits> #include "paddle/fluid/operators/jit/kernel_base.h" @@ -35,6 +36,12 @@ void VScal(const T* a, const T* x, T* y, int n); template <typename T> void VExp(const T* x, T* y, int n); +template <typename T> +void VCopy(const T* x, T* y, int n); + +template <typename T> +void VAXPY(T a, const T* x, T* y, int n); + template <typename T> void VSigmoid(const T* x, T* y, int n) { const T min = SIGMOID_THRESHOLD_MIN; @@ -60,6 +67,23 @@ void VTanh(const T* x, T* y, int n) { } } +template <typename T> +void SeqPool(const T* x, T* 
y, const seq_pool_attr_t* attr) { + VCopy<T>(x, y, attr->w); + for (int h = 1; h != attr->h; ++h) { + VAXPY<T>(static_cast<T>(1), x + h * attr->w, y, attr->w); + } + if (attr->type == SeqPoolType::kAvg || attr->type == SeqPoolType::kSqrt) { + T scalar = static_cast<T>(1); + if (attr->type == SeqPoolType::kAvg) { + scalar = scalar / static_cast<T>(attr->h); + } else { + scalar = scalar / std::sqrt(static_cast<T>(attr->h)); + } + VScal<T>(&scalar, y, y, attr->w); + } +} + #define DECLARE_MKL_KERNEL(name, tuples) \ template <typename T> \ class name##Kernel : public KernelMore<tuples<T>> { \ @@ -81,6 +105,8 @@ DECLARE_MKL_KERNEL(VExp, XYNTuples); DECLARE_MKL_KERNEL(VSigmoid, XYNTuples); DECLARE_MKL_KERNEL(VTanh, XYNTuples); +DECLARE_MKL_KERNEL(SeqPool, SeqPoolTuples); + #undef DECLARE_MKL_KERNEL } // namespace mkl diff --git a/paddle/fluid/operators/jit/refer/refer.h b/paddle/fluid/operators/jit/refer/refer.h index c2aa922528..4e19783c86 100644 --- a/paddle/fluid/operators/jit/refer/refer.h +++ b/paddle/fluid/operators/jit/refer/refer.h @@ -344,6 +344,15 @@ void SeqPool(const T* x, T* y, const seq_pool_attr_t* attr) { src += attr->w; } } + if (attr->type == SeqPoolType::kAvg || attr->type == SeqPoolType::kSqrt) { + T scalar = static_cast<T>(1); + if (attr->type == SeqPoolType::kAvg) { + scalar = scalar / static_cast<T>(attr->h); + } else { + scalar = scalar / std::sqrt(static_cast<T>(attr->h)); + } + VScal<T>(&scalar, y, y, attr->w); + } } #define DECLARE_REFER_KERNEL(name, tuples) \ From f4c990e7b8493304b61249417aaaca45d95e5174 Mon Sep 17 00:00:00 2001 From: minqiyang <minqiyang@baidu.com> Date: Mon, 7 Jan 2019 12:54:37 +0800 Subject: [PATCH 06/28] Add fused embedding ops --- .../fused/fused_embedding_seq_pool_op.cc | 194 ++++++++++++++++++ .../fused/fused_embedding_seq_pool_op.h | 142 +++++++++++++ 2 files changed, 336 insertions(+) create mode 100644 paddle/fluid/operators/fused/fused_embedding_seq_pool_op.cc create mode 100644 paddle/fluid/operators/fused/fused_embedding_seq_pool_op.h diff --git a/paddle/fluid/operators/fused/fused_embedding_seq_pool_op.cc b/paddle/fluid/operators/fused/fused_embedding_seq_pool_op.cc new file mode 100644 index 0000000000..fe4c73f472 --- /dev/null +++ b/paddle/fluid/operators/fused/fused_embedding_seq_pool_op.cc @@ -0,0 +1,194 @@ +/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
*/
+
+#include "paddle/fluid/operators/fused/fused_embedding_seq_pool_op.h"
+#include "paddle/fluid/framework/var_type_inference.h"
+
+namespace paddle {
+namespace operators {
+
+class FusedEmbeddingSeqPoolOp : public framework::OperatorWithKernel {
+ public:
+  using framework::OperatorWithKernel::OperatorWithKernel;
+
+  void InferShape(framework::InferShapeContext* ctx) const override {
+    PADDLE_ENFORCE(ctx->HasInput("W"),
+                   "Input W of FusedEmbeddingSeqPoolOp should not be null.");
+    PADDLE_ENFORCE(ctx->HasInput("Ids"),
+                   "Input Ids of FusedEmbeddingSeqPoolOp should not be null.");
+    PADDLE_ENFORCE(ctx->HasOutput("Out"),
+                   "Output of FusedEmbeddingSeqPoolOp should not be null.");
+
+    auto table_dims = ctx->GetInputDim("W");
+    auto ids_dims = ctx->GetInputDim("Ids");
+    const std::string& combiner = ctx->Attrs().Get<std::string>("combiner");
+
+    PADDLE_ENFORCE_EQ(table_dims.size(), 2);
+    PADDLE_ENFORCE_GE(ids_dims.size(), 1,
+                      "The dim size of the 'Ids' tensor must be greater than "
+                      "or equal to 1.");
+    PADDLE_ENFORCE_EQ(ids_dims[ids_dims.size() - 1], 1,
+                      "The last dimension of the 'Ids' tensor must be 1.");
+    // we only support sum now
+    PADDLE_ENFORCE_EQ(combiner, "sum");
+
+    int64_t last_dim = table_dims[1];
+    for (int i = 1; i != ids_dims.size(); ++i) {
+      last_dim *= ids_dims[i];
+    }
+
+    if (ctx->IsRuntime()) {
+      framework::Variable* ids_var =
+          boost::get<framework::Variable*>(ctx->GetInputVarPtrs("Ids")[0]);
+      const auto& ids_lod = ids_var->Get<LoDTensor>().lod();
+
+      // in run time, the LoD of ids must be 1
+      PADDLE_ENFORCE_EQ(ids_lod.size(), 1u,
+                        "The LoD level of Input(Ids) must be 1");
+      PADDLE_ENFORCE_GE(ids_lod[0].size(), 1u, "The LoD could NOT be empty");
+
+      int64_t batch_size = ids_lod[0].size() - 1;
+
+      // in run time, the shape from Ids -> output
+      // should be [seq_length, 1] -> [batch_size, embedding_size]
+      ctx->SetOutputDim("Out", framework::make_ddim({batch_size, last_dim}));
+    } else {
+      // in compile time, the lod level of ids must be 1
+      framework::VarDesc* ids_desc =
+          boost::get<framework::VarDesc*>(ctx->GetInputVarPtrs("Ids")[0]);
+      PADDLE_ENFORCE_EQ(ids_desc->GetLoDLevel(), 1);
+
+      // in compile time, the shape from Ids -> output
+      // should be [-1, 1] -> [-1, embedding_size]
+      ctx->SetOutputDim("Out", framework::make_ddim({-1, last_dim}));
+    }
+  }
+
+ protected:
+  framework::OpKernelType GetExpectedKernelType(
+      const framework::ExecutionContext& ctx) const override {
+    auto data_type = framework::GetDataTypeOfVar(ctx.InputVar("W"));
+    return framework::OpKernelType(data_type, ctx.device_context());
+  }
+};
+
+class FusedEmbeddingSeqPoolOpMaker : public framework::OpProtoAndCheckerMaker {
+ public:
+  void Make() override {
+    AddInput("W",
+             "(Tensor) The input represents embedding tensors, "
+             "which is a learnable parameter.");
+    AddInput("Ids",
+             "An input with type int32 or int64 "
+             "contains the ids to be looked up in W. "
+             "The last dimension size must be 1.");
+    AddOutput("Out", "The lookup results, which have the same type as W.");
+    AddAttr<std::string>("combiner",
+                         "(string, default sum) "
+                         "A string specifying the reduction op. Currently only "
+                         "sum is supported; it computes the weighted sum of "
+                         "the embedding results for each row.")
+        .SetDefault("sum");
+    // NOTE(minqiyang): grad_inplace is a temporary attribute,
+    // please do NOT set this attribute in python layer.
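+    // (it only marks whether the grad op may reuse the input's variable to
+    // avoid an extra allocation, as the attribute string below describes)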
+ AddAttr<bool>("grad_inplace", + "(boolean, default false) " + "If the grad op reuse the input's variable.") + .SetDefault(false); + AddAttr<bool>("is_sparse", + "(boolean, default false) " + "Sparse update.") + .SetDefault(false); + AddComment(R"DOC( +FusedEmbeddingSeqPool Operator. + +Computes embeddings for the given ids and weights. + +This operator is used to perform lookups on the parameter W, +then computes the weighted sum of the lookups results for each row +and concatenated into a dense tensor. + +The input Ids should carry the LoD (Level of Details) information. +And the output will change the LoD information with input Ids. + +)DOC"); + } +}; + +class FusedEmbeddingSeqPoolOpGradDescMaker + : public framework::DefaultGradOpDescMaker<true> { + using ::paddle::framework::DefaultGradOpDescMaker< + true>::DefaultGradOpDescMaker; + + protected: + virtual std::string GradOpType() const { + return "fused_embedding_seq_pool_grad"; + } +}; + +class FusedEmbeddingSeqPoolOpGrad : public framework::OperatorWithKernel { + public: + using framework::OperatorWithKernel::OperatorWithKernel; + + void InferShape(framework::InferShapeContext* ctx) const override { + auto table_dims = ctx->GetInputDim("W"); + ctx->SetOutputDim(framework::GradVarName("W"), table_dims); + } + + protected: + framework::OpKernelType GetExpectedKernelType( + const framework::ExecutionContext& ctx) const override { + auto data_type = framework::GetDataTypeOfVar(ctx.InputVar("W")); + return framework::OpKernelType(data_type, ctx.device_context()); + } +}; + +class FusedEmbeddingSeqPoolOpGradVarTypeInference + : public framework::VarTypeInference { + public: + void operator()(const framework::OpDesc& op_desc, + framework::BlockDesc* block) const override { + auto out_var_name = op_desc.Output(framework::GradVarName("W")).front(); + auto attr = op_desc.GetAttr("is_sparse"); + bool is_sparse = boost::get<bool>(attr); + if (is_sparse) { + VLOG(3) << "fused_embedding_seq_pool_grad op " + << framework::GradVarName("W") << " is set to SelectedRows"; + block->Var(out_var_name) + ->SetType(framework::proto::VarType::SELECTED_ROWS); + } else { + VLOG(3) << "fused_embedding_seq_pool_grad op " + << framework::GradVarName("W") << " is set to LoDTensor"; + block->Var(out_var_name)->SetType(framework::proto::VarType::LOD_TENSOR); + } + block->Var(out_var_name)->SetDataType(block->Var("W")->GetDataType()); + } +}; + +} // namespace operators +} // namespace paddle + +namespace ops = paddle::operators; +REGISTER_OPERATOR(fused_embedding_seq_pool, ops::FusedEmbeddingSeqPoolOp, + ops::FusedEmbeddingSeqPoolOpGradDescMaker, + ops::FusedEmbeddingSeqPoolOpMaker); +REGISTER_OPERATOR(fused_embedding_seq_pool_grad, + ops::FusedEmbeddingSeqPoolOpGrad, + ops::FusedEmbeddingSeqPoolOpGradVarTypeInference); + +REGISTER_OP_CPU_KERNEL(fused_embedding_seq_pool, + ops::FusedEmbeddingSeqPoolKernel<float>, + ops::FusedEmbeddingSeqPoolKernel<double>); +REGISTER_OP_CPU_KERNEL(fused_embedding_seq_pool_grad, + ops::FusedEmbeddingSeqPoolGradKernel<float>, + ops::FusedEmbeddingSeqPoolGradKernel<double>); diff --git a/paddle/fluid/operators/fused/fused_embedding_seq_pool_op.h b/paddle/fluid/operators/fused/fused_embedding_seq_pool_op.h new file mode 100644 index 0000000000..38dfae8ad6 --- /dev/null +++ b/paddle/fluid/operators/fused/fused_embedding_seq_pool_op.h @@ -0,0 +1,142 @@ +/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved. 
+ +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. */ + +#pragma once + +#include <string> +#include <vector> + +#include "paddle/fluid/framework/eigen.h" +#include "paddle/fluid/framework/lod_tensor.h" +#include "paddle/fluid/framework/op_registry.h" +#include "paddle/fluid/framework/selected_rows.h" +#include "paddle/fluid/operators/math/blas.h" + +namespace paddle { +namespace operators { + +using Tensor = framework::Tensor; +using LoDTensor = framework::LoDTensor; +using SelectedRows = framework::SelectedRows; +using DDim = framework::DDim; + +template <typename T> +struct EmbeddingVSumFunctor { + void operator()(const framework::ExecutionContext &context, + const LoDTensor *table_t, const LoDTensor *ids_t, + LoDTensor *output_t) { + auto *table = table_t->data<T>(); + int64_t row_number = table_t->dims()[0]; + int64_t row_width = table_t->dims()[1]; + int64_t last_dim = output_t->dims()[1]; + int64_t *ids = const_cast<int64_t *>(ids_t->data<int64_t>()); + auto ids_lod = ids_t->lod()[0]; + int64_t ids_count = ids_t->numel() / ids_lod.back(); + + auto *output = output_t->mutable_data<T>(context.GetPlace()); + + auto blas = math::GetBlas<platform::CPUDeviceContext, T>(context); + for (int64_t i = 0; i != ids_lod.size() - 1; ++i) { + size_t begin = ids_lod[i] * ids_count; + for (int64_t j = 0; j != ids_count; ++j) { + PADDLE_ENFORCE_LT(ids[begin], row_number); + PADDLE_ENFORCE_GE(ids[begin], 0, "ids %d", i); + blas.VCOPY(row_width, table + ids[begin + j] * row_width, + output + i * last_dim + j * row_width); + } + + for (int64_t r = (ids_lod[i] + 1) * ids_count; + r < ids_lod[i + 1] * ids_count; ++r) { + PADDLE_ENFORCE_LT(ids[r], row_number); + PADDLE_ENFORCE_GE(ids[r], 0, "ids %d", i); + blas.AXPY(row_width, 1., table + ids[r] * row_width, + output + i * last_dim + (r % ids_count) * row_width); + } + } + } +}; + +template <typename T> +class FusedEmbeddingSeqPoolKernel : public framework::OpKernel<T> { + public: + void Compute(const framework::ExecutionContext &context) const override { + const LoDTensor *ids_t = context.Input<LoDTensor>("Ids"); // int tensor + LoDTensor *output_t = context.Output<LoDTensor>("Out"); // float tensor + const LoDTensor *table_var = context.Input<LoDTensor>("W"); + const std::string &combiner_type = context.Attr<std::string>("combiner"); + + if (combiner_type == "sum") { + EmbeddingVSumFunctor<T> functor; + functor(context, table_var, ids_t, output_t); + } + } +}; + +template <typename T> +class FusedEmbeddingSeqPoolGradKernel : public framework::OpKernel<T> { + public: + void Compute(const framework::ExecutionContext &context) const override { + auto *table_var = context.InputVar("W"); + DDim table_dim; + if (table_var->IsType<LoDTensor>()) { + table_dim = context.Input<LoDTensor>("W")->dims(); + } else if (table_var->IsType<SelectedRows>()) { + auto *table_t = context.Input<SelectedRows>("W"); + table_dim = table_t->value().dims(); + } else { + PADDLE_THROW( + "The parameter W of a LookupTable " + "must be either LoDTensor or SelectedRows"); + } + + bool is_sparse = 
context.Attr<bool>("is_sparse"); + // Since paddings are not trainable and fixed in forward, the gradient of + // paddings makes no sense and we don't deal with it in backward. + if (is_sparse) { + auto *ids = context.Input<LoDTensor>("Ids"); + auto *d_output = context.Input<LoDTensor>(framework::GradVarName("Out")); + auto *d_table = context.Output<SelectedRows>(framework::GradVarName("W")); + + auto *ids_data = ids->data<int64_t>(); + int64_t ids_num = ids->numel(); + auto lod = ids->lod()[0]; + int64_t row_width = d_output->dims()[1]; + + framework::Vector<int64_t> *new_rows = d_table->mutable_rows(); + new_rows->resize(ids_num); + std::memcpy(&(*new_rows)[0], ids_data, ids_num * sizeof(int64_t)); + + auto *d_table_value = d_table->mutable_value(); + d_table_value->Resize({ids_num, table_dim[1]}); + T *d_table_data = d_table_value->mutable_data<T>(context.GetPlace()); + const T *d_output_data = d_output->data<T>(); + + auto blas = math::GetBlas<platform::CPUDeviceContext, T>(context); + for (int i = 0; i < static_cast<int>(lod.size()) - 1; ++i) { + int64_t h = static_cast<int64_t>(lod[i + 1] - lod[i]); + int64_t in_offset = lod[i] * row_width; + const T *out_pos = d_output_data + i * row_width; + T *in_pos = d_table_data + in_offset; + for (int r = 0; r != h; ++r) { + blas.VCOPY(row_width, out_pos, in_pos + r * row_width); + } + } + } else { + LOG(ERROR) << "Dense is not supported in fused_embedding_seq_pool_op now"; + } + } +}; + +} // namespace operators +} // namespace paddle From dc0ecffd6c4115019cfcbcc13b17a20511888c9b Mon Sep 17 00:00:00 2001 From: minqiyang <minqiyang@baidu.com> Date: Mon, 7 Jan 2019 12:55:03 +0800 Subject: [PATCH 07/28] Add ut for fused ops --- .../unittests/test_fused_emb_seq_pool_op.py | 51 +++++++++++++++++++ 1 file changed, 51 insertions(+) create mode 100644 python/paddle/fluid/tests/unittests/test_fused_emb_seq_pool_op.py diff --git a/python/paddle/fluid/tests/unittests/test_fused_emb_seq_pool_op.py b/python/paddle/fluid/tests/unittests/test_fused_emb_seq_pool_op.py new file mode 100644 index 0000000000..584e309bef --- /dev/null +++ b/python/paddle/fluid/tests/unittests/test_fused_emb_seq_pool_op.py @@ -0,0 +1,51 @@ +# Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +from __future__ import print_function + +import unittest +import numpy as np +from op_test import OpTest +import paddle.fluid.core as core +import paddle.fluid as fluid +from paddle.fluid.op import Operator +import paddle.compat as cpt + + +class TestFusedEmbeddingSeqPoolOp(OpTest): + def setUp(self): + self.op_type = "fused_embedding_seq_pool" + self.emb_size = 2 + table = np.random.random((17, self.emb_size)).astype("float32") + ids = np.array([[[4], [3]], [[4], [3]], [[2], [1]], + [[16], [1]]]).astype("int64") + merged_ids = np.array([4, 2, 16]).astype("int64") + ids_expand = np.expand_dims(ids, axis=1) + self.lod = [[3, 1]] + self.attrs = {'is_sparse': True} + self.inputs = {'W': table, 'Ids': (ids_expand, self.lod)} + self.outputs = { + 'Out': np.reshape( + np.array([ + table[[4, 3]] + table[[4, 3]] + table[[2, 1]], + table[[16, 1]] + ]), [len(self.lod[0]), 2 * self.emb_size]) + } + + def test_check_output(self): + self.check_output() + + +if __name__ == "__main__": + unittest.main() From e0591deebc02202c4ae8bfc95f31be606b8192b8 Mon Sep 17 00:00:00 2001 From: tensor-tang <tangjian03@baidu.com> Date: Fri, 4 Jan 2019 14:40:43 +0000 Subject: [PATCH 08/28] enhance seqpool jitcode --- paddle/fluid/operators/jit/benchmark.cc | 4 +- paddle/fluid/operators/jit/gen/seqpool.cc | 55 +-------- paddle/fluid/operators/jit/gen/seqpool.h | 134 ++++++++++++++++++++-- 3 files changed, 126 insertions(+), 67 deletions(-) diff --git a/paddle/fluid/operators/jit/benchmark.cc b/paddle/fluid/operators/jit/benchmark.cc index f64e43389a..37a552fb6d 100644 --- a/paddle/fluid/operators/jit/benchmark.cc +++ b/paddle/fluid/operators/jit/benchmark.cc @@ -194,8 +194,8 @@ template <paddle::operators::jit::KernelType KT, typename T, typename PlaceType> void BenchSeqPoolKernel() { std::vector<jit::SeqPoolType> pool_types = {jit::SeqPoolType::kSum}; for (auto type : pool_types) { - for (int h : TestSizes()) { - for (int w : TestSizes()) { + for (int w : TestSizes()) { + for (int h : TestSizes()) { const jit::seq_pool_attr_t attr(h, w, type); std::vector<T> x(h * w), y(w); RandomVec<T>(h * w, x.data(), -2.f, 2.f); diff --git a/paddle/fluid/operators/jit/gen/seqpool.cc b/paddle/fluid/operators/jit/gen/seqpool.cc index ce6801b030..fd83f83436 100644 --- a/paddle/fluid/operators/jit/gen/seqpool.cc +++ b/paddle/fluid/operators/jit/gen/seqpool.cc @@ -35,7 +35,6 @@ void SeqPoolJitCode::genCode() { mov(reg32_scalar, scalar); } - // TODO(TJ): make height load from params const int group_len = max_num_regs * block * sizeof(float); for (int g = 0; g < num_groups; ++g) { pool_height<ymm_t>(g * group_len, block, max_num_regs); @@ -44,59 +43,9 @@ void SeqPoolJitCode::genCode() { pool_height<ymm_t>(num_groups * group_len, block, rest_num_regs); } - // rest part + // part of rest_w * height const int rest = w_ % block; - const bool has_block4 = rest / 4 > 0; - const bool has_block2 = (rest % 4) / 2 > 0; - const bool has_block1 = (rest % 2) == 1; - const int w_offset = num_block * YMM_FLOAT_BLOCK * sizeof(float); - for (int h = 0; h < h_; ++h) { - int offset = h * w_ * sizeof(float) + w_offset; - const int shift_regs = (h == 0) ? 
0 : max_num_regs; - int reg_idx = 0; - if (has_block4) { - vmovups(xmm_t(reg_idx + shift_regs), ptr[param1 + offset]); - offset += sizeof(float) * 4; - reg_idx++; - } - if (has_block2) { - vmovq(xmm_t(reg_idx + shift_regs), ptr[param1 + offset]); - offset += sizeof(float) * 2; - reg_idx++; - } - if (has_block1) { - vmovss(xmm_t(reg_idx + shift_regs), ptr[param1 + offset]); - reg_idx++; - } - rest_num_regs = reg_idx; - if (h > 0) { - for (int i = 0; i < reg_idx; ++i) { - vaddps(xmm_t(i), xmm_t(i), xmm_t(i + max_num_regs)); - } - } - } - // save right now - int offset = w_offset; - if (type_ == SeqPoolType::kAvg || type_ == SeqPoolType::kSqrt) { - vbroadcastss(xmm_t(max_num_regs - 1), reg32_scalar); - for (int i = 0; i < rest_num_regs; ++i) { - vmulps(xmm_t(i), xmm_t(i), xmm_t(max_num_regs - 1)); - } - } - int reg_idx = 0; - if (has_block4) { - vmovups(ptr[param2 + offset], xmm_t(reg_idx)); - offset += sizeof(float) * 4; - reg_idx++; - } - if (has_block2) { - vmovq(ptr[param2 + offset], xmm_t(reg_idx)); - offset += sizeof(float) * 2; - reg_idx++; - } - if (has_block1) { - vmovss(ptr[param2 + offset], xmm_t(reg_idx)); - } + pool_height_of_rest_width(rest, (w_ - rest) * sizeof(float), max_num_regs); ret(); } diff --git a/paddle/fluid/operators/jit/gen/seqpool.h b/paddle/fluid/operators/jit/gen/seqpool.h index eb2d191382..48288d8c2a 100644 --- a/paddle/fluid/operators/jit/gen/seqpool.h +++ b/paddle/fluid/operators/jit/gen/seqpool.h @@ -17,6 +17,7 @@ #include <string> #include "glog/logging.h" #include "paddle/fluid/operators/jit/gen/jitcode.h" +#include "paddle/fluid/platform/enforce.h" namespace paddle { namespace operators { @@ -45,8 +46,6 @@ class SeqPoolJitCode : public JitCode { base += "_Sqrt"; } base += ("_W" + std::to_string(w_)); - // TODO(TJ): make h load from params - base += ("_H" + std::to_string(h_)); return base.c_str(); } void genCode() override; @@ -54,25 +53,36 @@ class SeqPoolJitCode : public JitCode { protected: template <typename JMM> void pool_height(int w_offset, int block, int max_num_regs) { - for (int h = 0; h < h_; ++h) { - int offset = h * w_ * sizeof(float) + w_offset; - const int shift_regs = (h == 0) ? 
0 : max_num_regs; - for (int i = 0; i < max_num_regs; ++i) { - vmovups(JMM(i + shift_regs), ptr[param1 + offset]); - offset += sizeof(float) * block; - } - if (h > 0) { - // sum anyway + int offset = w_offset; + for (int i = 0; i < max_num_regs; ++i) { + vmovups(JMM(i), ptr[param1 + offset]); + offset += sizeof(float) * block; + } + if (h_ > 1) { + Label l_next_h; + mov(reg_h, 1); + mov(reg_tmp, param1); + add(reg_tmp, w_ * sizeof(float) + w_offset); + L(l_next_h); + { + mov(reg_ptr_src_i, reg_tmp); for (int i = 0; i < max_num_regs; ++i) { + vmovups(JMM(i + max_num_regs), ptr[reg_ptr_src_i]); + // sum anyway vaddps(JMM(i), JMM(i), JMM(i + max_num_regs)); + add(reg_ptr_src_i, sizeof(float) * block); } + inc(reg_h); + add(reg_tmp, w_ * sizeof(float)); + cmp(reg_h, h_); + jl(l_next_h, T_NEAR); } } // save right now if (type_ == SeqPoolType::kAvg || type_ == SeqPoolType::kSqrt) { vbroadcastss(JMM(max_num_regs), reg32_scalar); } - int offset = w_offset; + offset = w_offset; for (int i = 0; i < max_num_regs; ++i) { if (type_ == SeqPoolType::kAvg || type_ == SeqPoolType::kSqrt) { vmulps(JMM(i), JMM(i), JMM(max_num_regs)); @@ -82,6 +92,102 @@ class SeqPoolJitCode : public JitCode { } } + void pool_height_of_rest_width(int rest, int w_offset, int max_num_regs) { + const int rest_used_num_regs = load_rest(rest, w_offset, 0); + const bool has_block4 = rest / 4 > 0; + const bool has_block2 = (rest % 4) / 2 > 0; + const bool has_block1 = (rest % 2) == 1; + if (h_ > 1) { + Label l_next_h; + mov(reg_h, 1); + mov(reg_tmp, param1); + add(reg_tmp, w_ * sizeof(float) + w_offset); + L(l_next_h); + { + // int used_regs =load_rest(rest, h * w_ * sizeof(float) + w_offset, + // max_num_regs); + int reg_idx = 0; + mov(reg_ptr_src_i, reg_tmp); + if (has_block4) { + vmovups(xmm_t(reg_idx + max_num_regs), ptr[reg_ptr_src_i]); + add(reg_ptr_src_i, sizeof(float) * 4); + reg_idx++; + } + if (has_block2) { + vmovups(xmm_t(reg_idx + max_num_regs), ptr[reg_ptr_src_i]); + add(reg_ptr_src_i, sizeof(float) * 2); + reg_idx++; + } + if (has_block1) { + vmovss(xmm_t(reg_idx + max_num_regs), ptr[reg_ptr_src_i]); + reg_idx++; + } + PADDLE_ENFORCE_EQ(reg_idx, rest_used_num_regs, + "All heights should use same regs"); + for (int i = 0; i < reg_idx; ++i) { + vaddps(xmm_t(i), xmm_t(i), xmm_t(i + max_num_regs)); + } + inc(reg_h); + add(reg_tmp, w_ * sizeof(float)); + cmp(reg_h, h_); + jl(l_next_h, T_NEAR); + } + } + // save right now + if (type_ == SeqPoolType::kAvg || type_ == SeqPoolType::kSqrt) { + vbroadcastss(xmm_t(max_num_regs - 1), reg32_scalar); + for (int i = 0; i < rest_used_num_regs; ++i) { + vmulps(xmm_t(i), xmm_t(i), xmm_t(max_num_regs - 1)); + } + } + save_rest(rest, w_offset); + } + + // return the number of used regs, use start from reg 0 + int load_rest(int rest, int w_offset, const int num_shift_regs, + const int reg_start = 0) { + const bool has_block4 = rest / 4 > 0; + const bool has_block2 = (rest % 4) / 2 > 0; + const bool has_block1 = (rest % 2) == 1; + int reg_idx = reg_start; + if (has_block4) { + vmovups(xmm_t(reg_idx + num_shift_regs), ptr[param1 + w_offset]); + w_offset += sizeof(float) * 4; + reg_idx++; + } + if (has_block2) { + vmovq(xmm_t(reg_idx + num_shift_regs), ptr[param1 + w_offset]); + w_offset += sizeof(float) * 2; + reg_idx++; + } + if (has_block1) { + vmovss(xmm_t(reg_idx + num_shift_regs), ptr[param1 + w_offset]); + reg_idx++; + } + return reg_idx; + } + + // use reg start from 0 + void save_rest(int rest, int w_offset, int reg_start = 0) { + const bool has_block4 = rest / 4 > 0; + const bool 
has_block2 = (rest % 4) / 2 > 0; + const bool has_block1 = (rest % 2) == 1; + int reg_idx = reg_start; + if (has_block4) { + vmovups(ptr[param2 + w_offset], xmm_t(reg_idx)); + w_offset += sizeof(float) * 4; + reg_idx++; + } + if (has_block2) { + vmovq(ptr[param2 + w_offset], xmm_t(reg_idx)); + w_offset += sizeof(float) * 2; + reg_idx++; + } + if (has_block1) { + vmovss(ptr[param2 + w_offset], xmm_t(reg_idx)); + } + } + private: int h_; int w_; @@ -90,6 +196,10 @@ class SeqPoolJitCode : public JitCode { reg64_t param2{abi_param2}; reg64_t param3{abi_param3}; reg32_t reg32_scalar{r8d}; + + reg64_t reg_h{r9}; + reg64_t reg_ptr_src_i{r10}; + reg64_t reg_tmp{r11}; }; } // namespace gen From 0145f40f4576fa035b92e3876ca9c4cfefbc5c52 Mon Sep 17 00:00:00 2001 From: tensor-tang <tangjian03@baidu.com> Date: Sat, 5 Jan 2019 11:34:15 +0000 Subject: [PATCH 09/28] use height from params of jitcode --- paddle/fluid/operators/jit/benchmark.cc | 3 +- paddle/fluid/operators/jit/gen/seqpool.cc | 17 +- paddle/fluid/operators/jit/gen/seqpool.h | 162 ++++++++++-------- paddle/fluid/operators/jit/kernel_base.h | 6 +- paddle/fluid/operators/jit/kernel_key.cc | 6 +- paddle/fluid/operators/jit/refer/refer.h | 1 - paddle/fluid/operators/jit/test.cc | 7 +- .../fluid/operators/math/sequence_pooling.cc | 12 +- 8 files changed, 117 insertions(+), 97 deletions(-) diff --git a/paddle/fluid/operators/jit/benchmark.cc b/paddle/fluid/operators/jit/benchmark.cc index 37a552fb6d..4cbada4a5b 100644 --- a/paddle/fluid/operators/jit/benchmark.cc +++ b/paddle/fluid/operators/jit/benchmark.cc @@ -195,8 +195,9 @@ void BenchSeqPoolKernel() { std::vector<jit::SeqPoolType> pool_types = {jit::SeqPoolType::kSum}; for (auto type : pool_types) { for (int w : TestSizes()) { + jit::seq_pool_attr_t attr(w, type); for (int h : TestSizes()) { - const jit::seq_pool_attr_t attr(h, w, type); + attr.h = h; std::vector<T> x(h * w), y(w); RandomVec<T>(h * w, x.data(), -2.f, 2.f); const T* x_data = x.data(); diff --git a/paddle/fluid/operators/jit/gen/seqpool.cc b/paddle/fluid/operators/jit/gen/seqpool.cc index fd83f83436..d651f282bf 100644 --- a/paddle/fluid/operators/jit/gen/seqpool.cc +++ b/paddle/fluid/operators/jit/gen/seqpool.cc @@ -13,6 +13,7 @@ * limitations under the License. 
*/ #include "paddle/fluid/operators/jit/gen/seqpool.h" +#include <stddef.h> // offsetof #include "paddle/fluid/operators/jit/registry.h" #include "paddle/fluid/platform/cpu_info.h" @@ -21,20 +22,22 @@ namespace operators { namespace jit { namespace gen { +thread_local float ALIGN32_BEG float_h[1] ALIGN32_END = { + 1.f}; // TODO(TJ): try move to private + void SeqPoolJitCode::genCode() { constexpr int block = YMM_FLOAT_BLOCK; constexpr int max_num_regs = 8; const int num_block = w_ / block; const int num_groups = num_block / max_num_regs; int rest_num_regs = num_block % max_num_regs; - if (type_ == SeqPoolType::kAvg) { - float scalar = 1.f / h_; - mov(reg32_scalar, scalar); - } else if (type_ == SeqPoolType::kSqrt) { - float scalar = 1.f / std::sqrt(static_cast<float>(h_)); - mov(reg32_scalar, scalar); + mov(reg32_int_h, dword[param_attr]); + if (type_ == SeqPoolType::kAvg || type_ == SeqPoolType::kSqrt) { + mov(reg_tmp, reinterpret_cast<size_t>(float_h)); + fild(dword[param_attr]); + fstp(dword[reg_tmp]); + mov(reg32_fp_h, dword[reg_tmp]); } - const int group_len = max_num_regs * block * sizeof(float); for (int g = 0; g < num_groups; ++g) { pool_height<ymm_t>(g * group_len, block, max_num_regs); diff --git a/paddle/fluid/operators/jit/gen/seqpool.h b/paddle/fluid/operators/jit/gen/seqpool.h index 48288d8c2a..c61bf27cc1 100644 --- a/paddle/fluid/operators/jit/gen/seqpool.h +++ b/paddle/fluid/operators/jit/gen/seqpool.h @@ -16,6 +16,7 @@ #include <string> #include "glog/logging.h" +#include "paddle/fluid/operators/jit/gen/act.h" // for ones #include "paddle/fluid/operators/jit/gen/jitcode.h" #include "paddle/fluid/platform/enforce.h" @@ -29,7 +30,7 @@ class SeqPoolJitCode : public JitCode { explicit SeqPoolJitCode(const seq_pool_attr_t& attr, size_t code_size = 256 * 1024, void* code_ptr = nullptr) - : JitCode(code_size, code_ptr), h_(attr.h), w_(attr.w), type_(attr.type) { + : JitCode(code_size, code_ptr), w_(attr.w), type_(attr.type) { if (type_ != SeqPoolType::kSum) { LOG(FATAL) << "Only support sum pool yet "; } @@ -55,39 +56,48 @@ class SeqPoolJitCode : public JitCode { void pool_height(int w_offset, int block, int max_num_regs) { int offset = w_offset; for (int i = 0; i < max_num_regs; ++i) { - vmovups(JMM(i), ptr[param1 + offset]); + vmovups(JMM(i), ptr[param_src + offset]); offset += sizeof(float) * block; } - if (h_ > 1) { - Label l_next_h; - mov(reg_h, 1); - mov(reg_tmp, param1); - add(reg_tmp, w_ * sizeof(float) + w_offset); - L(l_next_h); - { - mov(reg_ptr_src_i, reg_tmp); - for (int i = 0; i < max_num_regs; ++i) { - vmovups(JMM(i + max_num_regs), ptr[reg_ptr_src_i]); - // sum anyway - vaddps(JMM(i), JMM(i), JMM(i + max_num_regs)); - add(reg_ptr_src_i, sizeof(float) * block); - } - inc(reg_h); - add(reg_tmp, w_ * sizeof(float)); - cmp(reg_h, h_); - jl(l_next_h, T_NEAR); + cmp(reg32_int_h, 1); + Label l_next_h, l_h_done; + jle(l_h_done, T_NEAR); + mov(reg_h_i, 1); + mov(reg_tmp, param_src); + add(reg_tmp, w_ * sizeof(float) + w_offset); + L(l_next_h); + { + mov(reg_ptr_src_i, reg_tmp); + for (int i = 0; i < max_num_regs; ++i) { + vmovups(JMM(i + max_num_regs), ptr[reg_ptr_src_i]); + // sum anyway + vaddps(JMM(i), JMM(i), JMM(i + max_num_regs)); + add(reg_ptr_src_i, sizeof(float) * block); } + inc(reg_h_i); + add(reg_tmp, w_ * sizeof(float)); + cmp(reg_h_i, reg32_int_h); + jl(l_next_h, T_NEAR); } + L(l_h_done); // save right now if (type_ == SeqPoolType::kAvg || type_ == SeqPoolType::kSqrt) { - vbroadcastss(JMM(max_num_regs), reg32_scalar); + mov(reg_tmp, 
reinterpret_cast<size_t>(exp_float_consts)); + vmovups(JMM(max_num_regs), ptr[reg_tmp + OFFSET_EXP_ONE]); + movd(JMM(max_num_regs + 1), reg32_fp_h); + if (type_ == SeqPoolType::kSqrt) { + vsqrtps(JMM(max_num_regs + 1), JMM(max_num_regs + 1)); + } + vdivps(JMM(max_num_regs + 2), JMM(max_num_regs), JMM(max_num_regs + 1)); + vbroadcastss(JMM(max_num_regs), + JMM(max_num_regs + 2)); // TODO(TJ): fix me } offset = w_offset; for (int i = 0; i < max_num_regs; ++i) { if (type_ == SeqPoolType::kAvg || type_ == SeqPoolType::kSqrt) { vmulps(JMM(i), JMM(i), JMM(max_num_regs)); } - vmovups(ptr[param2 + offset], JMM(i)); + vmovups(ptr[param_dst + offset], JMM(i)); offset += sizeof(float) * block; } } @@ -97,47 +107,54 @@ class SeqPoolJitCode : public JitCode { const bool has_block4 = rest / 4 > 0; const bool has_block2 = (rest % 4) / 2 > 0; const bool has_block1 = (rest % 2) == 1; - if (h_ > 1) { - Label l_next_h; - mov(reg_h, 1); - mov(reg_tmp, param1); - add(reg_tmp, w_ * sizeof(float) + w_offset); - L(l_next_h); - { - // int used_regs =load_rest(rest, h * w_ * sizeof(float) + w_offset, - // max_num_regs); - int reg_idx = 0; - mov(reg_ptr_src_i, reg_tmp); - if (has_block4) { - vmovups(xmm_t(reg_idx + max_num_regs), ptr[reg_ptr_src_i]); - add(reg_ptr_src_i, sizeof(float) * 4); - reg_idx++; - } - if (has_block2) { - vmovups(xmm_t(reg_idx + max_num_regs), ptr[reg_ptr_src_i]); - add(reg_ptr_src_i, sizeof(float) * 2); - reg_idx++; - } - if (has_block1) { - vmovss(xmm_t(reg_idx + max_num_regs), ptr[reg_ptr_src_i]); - reg_idx++; - } - PADDLE_ENFORCE_EQ(reg_idx, rest_used_num_regs, - "All heights should use same regs"); - for (int i = 0; i < reg_idx; ++i) { - vaddps(xmm_t(i), xmm_t(i), xmm_t(i + max_num_regs)); - } - inc(reg_h); - add(reg_tmp, w_ * sizeof(float)); - cmp(reg_h, h_); - jl(l_next_h, T_NEAR); + cmp(reg32_int_h, 1); + Label l_next_h, l_h_done; + jle(l_h_done, T_NEAR); + mov(reg_h_i, 1); + mov(reg_tmp, param_src); + add(reg_tmp, w_ * sizeof(float) + w_offset); + L(l_next_h); + { + int reg_idx = 0; + mov(reg_ptr_src_i, reg_tmp); + if (has_block4) { + vmovups(xmm_t(reg_idx + max_num_regs), ptr[reg_ptr_src_i]); + add(reg_ptr_src_i, sizeof(float) * 4); + reg_idx++; + } + if (has_block2) { + vmovups(xmm_t(reg_idx + max_num_regs), ptr[reg_ptr_src_i]); + add(reg_ptr_src_i, sizeof(float) * 2); + reg_idx++; + } + if (has_block1) { + vmovss(xmm_t(reg_idx + max_num_regs), ptr[reg_ptr_src_i]); + reg_idx++; } + PADDLE_ENFORCE_EQ(reg_idx, rest_used_num_regs, + "All heights should use same regs"); + for (int i = 0; i < reg_idx; ++i) { + vaddps(xmm_t(i), xmm_t(i), xmm_t(i + max_num_regs)); + } + inc(reg_h_i); + add(reg_tmp, w_ * sizeof(float)); + cmp(reg_h_i, reg32_int_h); + jl(l_next_h, T_NEAR); } + L(l_h_done); // save right now if (type_ == SeqPoolType::kAvg || type_ == SeqPoolType::kSqrt) { - vbroadcastss(xmm_t(max_num_regs - 1), reg32_scalar); + mov(reg_tmp, reinterpret_cast<size_t>(exp_float_consts)); + vmovups(xmm_t(max_num_regs), ptr[reg_tmp + OFFSET_EXP_ONE]); + movd(xmm_t(max_num_regs + 1), reg32_fp_h); + if (type_ == SeqPoolType::kSqrt) { + vsqrtps(xmm_t(max_num_regs + 1), xmm_t(max_num_regs + 1)); + } + vdivps(xmm_t(max_num_regs + 2), xmm_t(max_num_regs), + xmm_t(max_num_regs + 1)); + vbroadcastss(xmm_t(max_num_regs), xmm_t(max_num_regs + 2)); for (int i = 0; i < rest_used_num_regs; ++i) { - vmulps(xmm_t(i), xmm_t(i), xmm_t(max_num_regs - 1)); + vmulps(xmm_t(i), xmm_t(i), xmm_t(max_num_regs)); } } save_rest(rest, w_offset); @@ -151,17 +168,17 @@ class SeqPoolJitCode : public JitCode { const bool 
has_block1 = (rest % 2) == 1; int reg_idx = reg_start; if (has_block4) { - vmovups(xmm_t(reg_idx + num_shift_regs), ptr[param1 + w_offset]); + vmovups(xmm_t(reg_idx + num_shift_regs), ptr[param_src + w_offset]); w_offset += sizeof(float) * 4; reg_idx++; } if (has_block2) { - vmovq(xmm_t(reg_idx + num_shift_regs), ptr[param1 + w_offset]); + vmovq(xmm_t(reg_idx + num_shift_regs), ptr[param_src + w_offset]); w_offset += sizeof(float) * 2; reg_idx++; } if (has_block1) { - vmovss(xmm_t(reg_idx + num_shift_regs), ptr[param1 + w_offset]); + vmovss(xmm_t(reg_idx + num_shift_regs), ptr[param_src + w_offset]); reg_idx++; } return reg_idx; @@ -174,32 +191,33 @@ class SeqPoolJitCode : public JitCode { const bool has_block1 = (rest % 2) == 1; int reg_idx = reg_start; if (has_block4) { - vmovups(ptr[param2 + w_offset], xmm_t(reg_idx)); + vmovups(ptr[param_dst + w_offset], xmm_t(reg_idx)); w_offset += sizeof(float) * 4; reg_idx++; } if (has_block2) { - vmovq(ptr[param2 + w_offset], xmm_t(reg_idx)); + vmovq(ptr[param_dst + w_offset], xmm_t(reg_idx)); w_offset += sizeof(float) * 2; reg_idx++; } if (has_block1) { - vmovss(ptr[param2 + w_offset], xmm_t(reg_idx)); + vmovss(ptr[param_dst + w_offset], xmm_t(reg_idx)); } } private: - int h_; int w_; SeqPoolType type_; - reg64_t param1{abi_param1}; - reg64_t param2{abi_param2}; - reg64_t param3{abi_param3}; - reg32_t reg32_scalar{r8d}; + reg64_t param_src{abi_param1}; + reg64_t param_dst{abi_param2}; + reg64_t param_attr{abi_param3}; + reg64_t reg_tmp{rax}; + + reg32_t reg32_int_h{r8d}; + reg32_t reg32_fp_h{r9d}; - reg64_t reg_h{r9}; - reg64_t reg_ptr_src_i{r10}; - reg64_t reg_tmp{r11}; + reg64_t reg_h_i{r10}; + reg64_t reg_ptr_src_i{r11}; }; } // namespace gen diff --git a/paddle/fluid/operators/jit/kernel_base.h b/paddle/fluid/operators/jit/kernel_base.h index 2659374650..2a7697a6f2 100644 --- a/paddle/fluid/operators/jit/kernel_base.h +++ b/paddle/fluid/operators/jit/kernel_base.h @@ -46,7 +46,7 @@ typedef enum { typedef enum { kNonePoolType = 0, - kSum, + kSum = 1, kAvg, kSqrt, } SeqPoolType; @@ -121,10 +121,10 @@ struct GRUTuples { }; typedef struct seq_pool_attr_s { - int h, w; + int h, w; // h should always be the first one SeqPoolType type; seq_pool_attr_s() = default; - explicit seq_pool_attr_s(int height, int width, SeqPoolType pool_type) + explicit seq_pool_attr_s(int width, SeqPoolType pool_type, int height = 1) : h(height), w(width), type(pool_type) {} } seq_pool_attr_t; diff --git a/paddle/fluid/operators/jit/kernel_key.cc b/paddle/fluid/operators/jit/kernel_key.cc index db78ed8ad8..61de386886 100644 --- a/paddle/fluid/operators/jit/kernel_key.cc +++ b/paddle/fluid/operators/jit/kernel_key.cc @@ -45,10 +45,8 @@ size_t JitCodeKey<gru_attr_t>(const gru_attr_t& attr) { template <> size_t JitCodeKey<seq_pool_attr_t>(const seq_pool_attr_t& attr) { size_t key = attr.w; - // TODO(TJ): support height, then removed it from key - constexpr int w_shift = 30; - return (key << act_type_shift) + static_cast<int>(attr.type) + - (static_cast<size_t>(attr.h) << (act_type_shift + w_shift)); + constexpr int pool_type_shift = 3; + return (key << pool_type_shift) + static_cast<int>(attr.type); } } // namespace jit diff --git a/paddle/fluid/operators/jit/refer/refer.h b/paddle/fluid/operators/jit/refer/refer.h index 4e19783c86..b4e9c8dd10 100644 --- a/paddle/fluid/operators/jit/refer/refer.h +++ b/paddle/fluid/operators/jit/refer/refer.h @@ -334,7 +334,6 @@ void NCHW16CMulNC(const T* x, const T* y, T* z, int height, int width) { template <typename T> void SeqPool(const 
T* x, T* y, const seq_pool_attr_t* attr) { - PADDLE_ENFORCE(attr->type == SeqPoolType::kSum, "Only support sum yet"); for (int w = 0; w < attr->w; ++w) { const T* src = x + w; T* dst = y + w; diff --git a/paddle/fluid/operators/jit/test.cc b/paddle/fluid/operators/jit/test.cc index 0f1776507a..5e05c71f40 100644 --- a/paddle/fluid/operators/jit/test.cc +++ b/paddle/fluid/operators/jit/test.cc @@ -439,9 +439,10 @@ void TestSeqPoolKernel() { // TODO(TJ): support more std::vector<jit::SeqPoolType> pool_types = {jit::SeqPoolType::kSum}; for (auto type : pool_types) { - for (int h : TestSizes()) { - for (int w : TestSizes()) { - const jit::seq_pool_attr_t attr(h, w, type); + for (int w : TestSizes()) { + jit::seq_pool_attr_t attr(w, type); + for (int h : TestSizes()) { + attr.h = h; auto ref = jit::GetRefer<KT, jit::SeqPoolTuples<T>>(); EXPECT_TRUE(ref != nullptr); std::vector<T> x(h * w), yref(w); diff --git a/paddle/fluid/operators/math/sequence_pooling.cc b/paddle/fluid/operators/math/sequence_pooling.cc index 283e2e251a..2a47502614 100644 --- a/paddle/fluid/operators/math/sequence_pooling.cc +++ b/paddle/fluid/operators/math/sequence_pooling.cc @@ -252,14 +252,14 @@ class SequencePoolFunctor<platform::CPUDeviceContext, T> { PADDLE_ENFORCE(platform::is_cpu_place(place)); const T* src = input.data<T>(); T* dst = output->mutable_data<T>(place); - jit::seq_pool_attr_t attr; - attr.w = input.numel() / input.dims()[0]; - attr.type = jit::SeqPoolType::kSum; + jit::seq_pool_attr_t attr( + static_cast<int>(input.numel() / input.dims()[0]), + jit::SeqPoolType::kSum); + auto seqpool = + jit::Get<jit::kSeqPool, jit::SeqPoolTuples<T>, platform::CPUPlace>( + attr); for (int i = 0; i < static_cast<int>(lod.size()) - 1; ++i) { attr.h = static_cast<int>(lod[i + 1] - lod[i]); - auto seqpool = - jit::Get<jit::kSeqPool, jit::SeqPoolTuples<T>, platform::CPUPlace>( - attr); seqpool(src, dst, &attr); dst += attr.w; src += attr.h * attr.w; From 123b98f417d064e780412f316f4ca43988f4d0d2 Mon Sep 17 00:00:00 2001 From: tensor-tang <tangjian03@baidu.com> Date: Mon, 7 Jan 2019 06:07:23 +0000 Subject: [PATCH 10/28] refine heigth and codesize and support all pool test=develop --- paddle/fluid/operators/jit/benchmark.cc | 3 ++- paddle/fluid/operators/jit/gen/seqpool.cc | 27 +++++++++++----------- paddle/fluid/operators/jit/gen/seqpool.h | 28 +++++++---------------- paddle/fluid/operators/jit/test.cc | 4 ++-- 4 files changed, 26 insertions(+), 36 deletions(-) diff --git a/paddle/fluid/operators/jit/benchmark.cc b/paddle/fluid/operators/jit/benchmark.cc index 4cbada4a5b..bde2791add 100644 --- a/paddle/fluid/operators/jit/benchmark.cc +++ b/paddle/fluid/operators/jit/benchmark.cc @@ -192,7 +192,8 @@ void BenchGRUKernel() { template <paddle::operators::jit::KernelType KT, typename T, typename PlaceType> void BenchSeqPoolKernel() { - std::vector<jit::SeqPoolType> pool_types = {jit::SeqPoolType::kSum}; + std::vector<jit::SeqPoolType> pool_types = { + jit::SeqPoolType::kSum, jit::SeqPoolType::kAvg, jit::SeqPoolType::kSqrt}; for (auto type : pool_types) { for (int w : TestSizes()) { jit::seq_pool_attr_t attr(w, type); diff --git a/paddle/fluid/operators/jit/gen/seqpool.cc b/paddle/fluid/operators/jit/gen/seqpool.cc index d651f282bf..530d24ee1f 100644 --- a/paddle/fluid/operators/jit/gen/seqpool.cc +++ b/paddle/fluid/operators/jit/gen/seqpool.cc @@ -13,7 +13,7 @@ * limitations under the License. 
*/ #include "paddle/fluid/operators/jit/gen/seqpool.h" -#include <stddef.h> // offsetof +#include "paddle/fluid/operators/jit/gen/act.h" // for exp_float_consts ones #include "paddle/fluid/operators/jit/registry.h" #include "paddle/fluid/platform/cpu_info.h" @@ -22,9 +22,6 @@ namespace operators { namespace jit { namespace gen { -thread_local float ALIGN32_BEG float_h[1] ALIGN32_END = { - 1.f}; // TODO(TJ): try move to private - void SeqPoolJitCode::genCode() { constexpr int block = YMM_FLOAT_BLOCK; constexpr int max_num_regs = 8; @@ -33,10 +30,17 @@ void SeqPoolJitCode::genCode() { int rest_num_regs = num_block % max_num_regs; mov(reg32_int_h, dword[param_attr]); if (type_ == SeqPoolType::kAvg || type_ == SeqPoolType::kSqrt) { - mov(reg_tmp, reinterpret_cast<size_t>(float_h)); + mov(reg_tmp, reinterpret_cast<size_t>(exp_float_consts)); + vmovups(xmm_t(1), ptr[reg_tmp + OFFSET_EXP_ONE]); + mov(reg_tmp, reinterpret_cast<size_t>(fp_h_)); fild(dword[param_attr]); fstp(dword[reg_tmp]); - mov(reg32_fp_h, dword[reg_tmp]); + vmovss(xmm_t(0), ptr[reg_tmp]); + if (type_ == SeqPoolType::kSqrt) { + vsqrtps(xmm_t(0), xmm_t(0)); + } + vdivps(xmm_t(1), xmm_t(1), xmm_t(0)); + vmovss(ptr[reg_tmp], xmm_t(1)); } const int group_len = max_num_regs * block * sizeof(float); for (int g = 0; g < num_groups; ++g) { @@ -45,7 +49,6 @@ void SeqPoolJitCode::genCode() { if (rest_num_regs > 0) { pool_height<ymm_t>(num_groups * group_len, block, rest_num_regs); } - // part of rest_w * height const int rest = w_ % block; pool_height_of_rest_width(rest, (w_ - rest) * sizeof(float), max_num_regs); @@ -58,12 +61,10 @@ class SeqPoolCreator : public JitCodeCreator<seq_pool_attr_t> { return platform::MayIUse(platform::avx); } size_t CodeSize(const seq_pool_attr_t& attr) const override { - // TODO(TJ): remove attr.h when enabled height - bool yes = - attr.type == SeqPoolType::kAvg || attr.type == SeqPoolType::kSqrt; - return 96 /* basic */ + - ((attr.w / YMM_FLOAT_BLOCK + 4 /* rest */) * 2 /* for sum */ - * (attr.h + (yes ? 
3 : 1 /*for avg or sqrt*/))) * + return 96 + + ((attr.w / YMM_FLOAT_BLOCK + 4 /* for rest */) * + 4 /* load, mul and save */ + + 256) * 8; } std::unique_ptr<GenBase> CreateJitCode( diff --git a/paddle/fluid/operators/jit/gen/seqpool.h b/paddle/fluid/operators/jit/gen/seqpool.h index c61bf27cc1..fcbbb3c84c 100644 --- a/paddle/fluid/operators/jit/gen/seqpool.h +++ b/paddle/fluid/operators/jit/gen/seqpool.h @@ -16,7 +16,6 @@ #include <string> #include "glog/logging.h" -#include "paddle/fluid/operators/jit/gen/act.h" // for ones #include "paddle/fluid/operators/jit/gen/jitcode.h" #include "paddle/fluid/platform/enforce.h" @@ -31,9 +30,11 @@ class SeqPoolJitCode : public JitCode { size_t code_size = 256 * 1024, void* code_ptr = nullptr) : JitCode(code_size, code_ptr), w_(attr.w), type_(attr.type) { - if (type_ != SeqPoolType::kSum) { + if (!(type_ == SeqPoolType::kSum || type_ == SeqPoolType::kAvg || + type_ == SeqPoolType::kSqrt)) { LOG(FATAL) << "Only support sum pool yet "; } + fp_h_[0] = 1.f; this->genCode(); } @@ -82,15 +83,8 @@ class SeqPoolJitCode : public JitCode { L(l_h_done); // save right now if (type_ == SeqPoolType::kAvg || type_ == SeqPoolType::kSqrt) { - mov(reg_tmp, reinterpret_cast<size_t>(exp_float_consts)); - vmovups(JMM(max_num_regs), ptr[reg_tmp + OFFSET_EXP_ONE]); - movd(JMM(max_num_regs + 1), reg32_fp_h); - if (type_ == SeqPoolType::kSqrt) { - vsqrtps(JMM(max_num_regs + 1), JMM(max_num_regs + 1)); - } - vdivps(JMM(max_num_regs + 2), JMM(max_num_regs), JMM(max_num_regs + 1)); - vbroadcastss(JMM(max_num_regs), - JMM(max_num_regs + 2)); // TODO(TJ): fix me + mov(reg_tmp, reinterpret_cast<size_t>(fp_h_)); + vbroadcastss(JMM(max_num_regs), ptr[reg_tmp]); } offset = w_offset; for (int i = 0; i < max_num_regs; ++i) { @@ -144,15 +138,8 @@ class SeqPoolJitCode : public JitCode { L(l_h_done); // save right now if (type_ == SeqPoolType::kAvg || type_ == SeqPoolType::kSqrt) { - mov(reg_tmp, reinterpret_cast<size_t>(exp_float_consts)); - vmovups(xmm_t(max_num_regs), ptr[reg_tmp + OFFSET_EXP_ONE]); - movd(xmm_t(max_num_regs + 1), reg32_fp_h); - if (type_ == SeqPoolType::kSqrt) { - vsqrtps(xmm_t(max_num_regs + 1), xmm_t(max_num_regs + 1)); - } - vdivps(xmm_t(max_num_regs + 2), xmm_t(max_num_regs), - xmm_t(max_num_regs + 1)); - vbroadcastss(xmm_t(max_num_regs), xmm_t(max_num_regs + 2)); + mov(reg_tmp, reinterpret_cast<size_t>(fp_h_)); + vbroadcastss(xmm_t(max_num_regs), ptr[reg_tmp]); for (int i = 0; i < rest_used_num_regs; ++i) { vmulps(xmm_t(i), xmm_t(i), xmm_t(max_num_regs)); } @@ -206,6 +193,7 @@ class SeqPoolJitCode : public JitCode { } private: + float ALIGN32_BEG fp_h_[1] ALIGN32_END; int w_; SeqPoolType type_; reg64_t param_src{abi_param1}; diff --git a/paddle/fluid/operators/jit/test.cc b/paddle/fluid/operators/jit/test.cc index 5e05c71f40..30291bfef3 100644 --- a/paddle/fluid/operators/jit/test.cc +++ b/paddle/fluid/operators/jit/test.cc @@ -436,8 +436,8 @@ void TestGRUKernel() { template <paddle::operators::jit::KernelType KT, typename T, typename PlaceType> void TestSeqPoolKernel() { VLOG(10) << "===== Test JITKernel " << jit::to_string(KT); - // TODO(TJ): support more - std::vector<jit::SeqPoolType> pool_types = {jit::SeqPoolType::kSum}; + std::vector<jit::SeqPoolType> pool_types = { + jit::SeqPoolType::kSum, jit::SeqPoolType::kAvg, jit::SeqPoolType::kSqrt}; for (auto type : pool_types) { for (int w : TestSizes()) { jit::seq_pool_attr_t attr(w, type); From 4bfa110fd893ee402ba1b052ddce7f26b257b442 Mon Sep 17 00:00:00 2001 From: minqiyang <minqiyang@baidu.com> Date: Mon, 7 Jan 
2019 16:28:44 +0800 Subject: [PATCH 11/28] Add no lock optimize pass test=develop --- CMakeLists.txt | 2 + cmake/FindJeMalloc.cmake | 7 + cmake/generic.cmake | 2 +- paddle/fluid/framework/details/CMakeLists.txt | 2 +- .../fluid/framework/details/build_strategy.cc | 1 + paddle/fluid/framework/ir/CMakeLists.txt | 1 + .../framework/ir/lock_free_optimize_pass.cc | 360 ++++++++++++++++++ .../framework/ir/lock_free_optimize_pass.h | 130 +++++++ 8 files changed, 503 insertions(+), 2 deletions(-) create mode 100644 paddle/fluid/framework/ir/lock_free_optimize_pass.cc create mode 100644 paddle/fluid/framework/ir/lock_free_optimize_pass.h diff --git a/CMakeLists.txt b/CMakeLists.txt index d6aa8f1b85..74d869307d 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -12,6 +12,8 @@ # See the License for the specific language governing permissions and # limitations under the License +set(CMAKE_VERBOSE_MAKEFILE on) + cmake_minimum_required(VERSION 3.0) set(CMAKE_MODULE_PATH ${CMAKE_MODULE_PATH} "${CMAKE_CURRENT_SOURCE_DIR}/cmake") set(PADDLE_SOURCE_DIR ${CMAKE_CURRENT_SOURCE_DIR}) diff --git a/cmake/FindJeMalloc.cmake b/cmake/FindJeMalloc.cmake index 7911f77c4c..b95287160b 100644 --- a/cmake/FindJeMalloc.cmake +++ b/cmake/FindJeMalloc.cmake @@ -19,3 +19,10 @@ find_package_handle_standard_args(jemalloc DEFAULT_MSG JEMALLOC_LIBRARIES JEMALL mark_as_advanced( JEMALLOC_LIBRARIES JEMALLOC_INCLUDE_DIR) + +if (JEMALLOC_FOUND) + add_library(jemalloc::jemalloc UNKNOWN IMPORTED) + set_target_properties(jemalloc::jemalloc PROPERTIES + IMPORTED_LOCATION ${JEMALLOC_LIBRARIES} + INTERFACE_INCLUDE_DIRECTORIES "${JEMALLOC_INCLUDE_DIR}") +endif() diff --git a/cmake/generic.cmake b/cmake/generic.cmake index 4e31392b98..05293b8b06 100644 --- a/cmake/generic.cmake +++ b/cmake/generic.cmake @@ -117,7 +117,7 @@ function(common_link TARGET_NAME) endif() if (WITH_JEMALLOC) - target_link_libraries(${TARGET_NAME} ${JEMALLOC_LIBRARIES}) + target_link_libraries(${TARGET_NAME} jemalloc::jemalloc) endif() endfunction() diff --git a/paddle/fluid/framework/details/CMakeLists.txt b/paddle/fluid/framework/details/CMakeLists.txt index 179aa14528..c1ba6606f1 100644 --- a/paddle/fluid/framework/details/CMakeLists.txt +++ b/paddle/fluid/framework/details/CMakeLists.txt @@ -94,4 +94,4 @@ cc_library(build_strategy SRCS build_strategy.cc DEPS graph_viz_pass multi_devices_graph_pass multi_devices_graph_print_pass multi_devices_graph_check_pass fuse_elewise_add_act_pass multi_batch_merge_pass - memory_optimize_pass) + memory_optimize_pass lock_free_optimize_pass) diff --git a/paddle/fluid/framework/details/build_strategy.cc b/paddle/fluid/framework/details/build_strategy.cc index 43c2eb7178..f65b3598b0 100644 --- a/paddle/fluid/framework/details/build_strategy.cc +++ b/paddle/fluid/framework/details/build_strategy.cc @@ -208,3 +208,4 @@ USE_PASS(analysis_var_pass); USE_PASS(sequential_execution_pass); USE_PASS(all_reduce_deps_pass); USE_PASS(modify_op_lock_and_record_event_pass); +USE_PASS(lock_free_optimize_pass); diff --git a/paddle/fluid/framework/ir/CMakeLists.txt b/paddle/fluid/framework/ir/CMakeLists.txt index 6d795e1e2d..6e6db3d3ef 100644 --- a/paddle/fluid/framework/ir/CMakeLists.txt +++ b/paddle/fluid/framework/ir/CMakeLists.txt @@ -31,6 +31,7 @@ cc_library(fuse_pass_base SRCS fuse_pass_base.cc DEPS pass) pass_library(graph_to_program_pass base) pass_library(graph_viz_pass base) +pass_library(lock_free_optimize_pass base) pass_library(fc_fuse_pass inference) pass_library(attention_lstm_fuse_pass inference) 
pass_library(infer_clean_graph_pass inference)
diff --git a/paddle/fluid/framework/ir/lock_free_optimize_pass.cc b/paddle/fluid/framework/ir/lock_free_optimize_pass.cc
new file mode 100644
index 0000000000..96e7060aac
--- /dev/null
+++ b/paddle/fluid/framework/ir/lock_free_optimize_pass.cc
@@ -0,0 +1,360 @@
+// Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include "paddle/fluid/framework/ir/lock_free_optimize_pass.h"
+
+#include <string>
+#include <unordered_set>
+#include <vector>
+
+#include "paddle/fluid/framework/ir/node.h"
+#include "paddle/fluid/framework/op_proto_maker.h"
+#include "paddle/fluid/framework/operator.h"
+#include "paddle/fluid/platform/enforce.h"
+
+namespace paddle {
+namespace framework {
+namespace ir {
+
+const char kSumGradOpName[] = "sum";
+// TODO(minqiyang): only sgd is supported at current time, please add
+// other optimizers later.
+const char kOptimizerType[] = "sgd";
+
+std::unique_ptr<ir::Graph> LockFreeOptimizePass::ApplyImpl(
+    std::unique_ptr<ir::Graph> graph) const {
+  PADDLE_ENFORCE(graph.get());
+
+  // We could collect all weights' names from SGD, where
+  // W1 <- SGD(W0, Grad0)
+  std::unordered_set<std::string> weight_var_set;
+  for (auto* node : graph->Nodes()) {
+    if (IsOpNamed(node, kOptimizerType)) {
+      auto& param_out_vars = node->Op()->Output("ParamOut");
+      PADDLE_ENFORCE(param_out_vars.size() == 1u);
+      weight_var_set.insert(param_out_vars[0]);
+    }
+  }
+
+  // find all gradient merge ops via weight name, where
+  // Grad0 <- SUM(Grad1, Grad2, Grad3 ...)
+  std::unordered_set<ir::Node*> grad_sum_op_set;
+  for (ir::Node* node : graph->Nodes()) {
+    if (IsOpNamed(node, kSumGradOpName)) {
+      for (ir::Node* output : node->outputs) {
+        // strip the last grad suffix @GRAD
+        std::string var_name = output->Name();
+        const std::string suffix(kGradVarSuffix);
+        if (var_name != suffix && var_name.size() > suffix.size() &&
+            var_name.substr(var_name.size() - suffix.size()) == suffix) {
+          // if so then strip it off
+          var_name = var_name.substr(0, var_name.size() - suffix.size());
+          if (weight_var_set.find(var_name) != weight_var_set.end()) {
+            grad_sum_op_set.insert(node);
+            break;
+          }
+        }
+      }
+    }
+  }
+
+  // get the forward op and backward op pairs, where
+  // out <- forward(X, W)
+  // Grad1 <- backward(out, X')
+  // Grad0 <- SUM(Grad1, Grad2, Grad3 ...)
+  // W1 <- SGD(W0, Grad0)
+  for (ir::Node* node : grad_sum_op_set) {
+    for (ir::Node* merged_grad_var : node->outputs) {
+      // find the optimizers connected with sum op
+      if (IsVarNameEndsWith(merged_grad_var, kGradVarSuffix) &&
+          merged_grad_var->outputs.size() == 1u) {
+        ir::Node* opt_node = merged_grad_var->outputs[0];
+        LOG(ERROR) << "Found opt node " << opt_node->Name();
+
+        // find the backward op connected with sum op
+        for (ir::Node* unmerged_grad_var : node->inputs) {
+          if (IsVarNameContains(unmerged_grad_var, kGradVarSuffix) &&
+              unmerged_grad_var->inputs.size() == 1u) {
+            ir::Node* backward_op = unmerged_grad_var->inputs[0];
+
+            LOG(ERROR) << "Found backward_op " << backward_op->Name();
+
+            // find the forward op related to the backward op
+            ir::Node* forward_op =
+                FindForwardOpViaBackwardOp(graph.get(), backward_op);
+
+            LOG(ERROR) << "Found forward_op " << forward_op->Name();
+
+            PADDLE_ENFORCE(forward_op);
+
+            Node* new_optimizer_node = CreateNewSGDNode(
+                graph.get(), forward_op, backward_op, node, opt_node);
+
+            PADDLE_ENFORCE(new_optimizer_node);
+          }
+        }
+      }
+    }
+  }
+
+  // Remove the sum_op and its outputs and connected optimizers
+  for (Node* sum_op : grad_sum_op_set) {
+    for (Node* sum_op_output : sum_op->outputs) {
+      for (Node* optimize_op : sum_op_output->outputs) {
+        if (optimize_op->NodeType() == Node::Type::kOperation &&
+            optimize_op->Name() == kOptimizerType) {
+          LOG(ERROR) << "remove optimize_op: " << optimize_op->Name() << "_"
+                     << optimize_op->id();
+          graph->RemoveNode(optimize_op);
+        }
+      }
+      LOG(ERROR) << "remove sum_op_output: " << sum_op_output->Name() << "_"
+                 << sum_op_output->id();
+      graph->RemoveNode(sum_op_output);
+    }
+    LOG(ERROR) << "remove sum_op: " << sum_op->Name() << "_" << sum_op->id();
+    graph->RemoveNode(sum_op);
+  }
+
+  for (auto* node : graph->Nodes()) {
+    for (Node* output_node : node->outputs) {
+      if (output_node->Name() == "sgd") {
+        LOG(ERROR) << "Node link to SGD: " << node->Name() << "_" << node->id()
+                   << " --> " << output_node->Name() << "_"
+                   << output_node->id();
+        for (Node* input_node : node->inputs) {
+          LOG(ERROR) << "SGD Input link: " << input_node->Name() << "_"
+                     << input_node->id() << " --> " << node->Name() << "_"
+                     << node->id();
+        }
+      }
+    }
+  }
+
+  return graph;
+}
+
+ir::Node* LockFreeOptimizePass::CreateNewSGDNode(
+    ir::Graph* graph, ir::Node* forward_node, ir::Node* backward_node,
+    ir::Node* grad_sum_node, ir::Node* optimize_node) const {
+  PADDLE_ENFORCE(graph);
+  PADDLE_ENFORCE(forward_node);
+  PADDLE_ENFORCE(backward_node);
+  PADDLE_ENFORCE(grad_sum_node);
+  PADDLE_ENFORCE(optimize_node);
+
+  // find the grad var node between the grad sum node and backward_node
+  std::vector<ir::Node*> grad_vars =
+      FindConnectedNode(backward_node, grad_sum_node);
+  ir::Node* grad_node = nullptr;
+  for (ir::Node* node : grad_vars) {
+    if (!ir::IsControlDepVar(*node)) {
+      grad_node = node;
+    }
+  }
+  PADDLE_ENFORCE(grad_node);
+
+  // create a new SGD node
+  OpDesc* old_desc = optimize_node->Op();
+  // keep the new optimizer in the same block as the old one
+  OpDesc new_desc(*old_desc, old_desc->Block());
+  new_desc.SetInput("Param", old_desc->Input("Param"));
+  new_desc.SetInput("LearningRate", old_desc->Input("LearningRate"));
+  new_desc.SetInput("Grad", std::vector<std::string>({grad_node->Name()}));
+  new_desc.SetOutput("ParamOut", old_desc->Output("ParamOut"));
+
+  std::vector<std::string> op_role_vars = boost::get<std::vector<std::string>>(
+      new_desc.GetAttr(framework::OpProtoAndCheckerMaker::OpRoleVarAttrName()));
+  //
replace the second op role var, because the grad name was
+  // changed in the new optimizer
+  op_role_vars.pop_back();
+  op_role_vars.push_back(grad_node->Name());
+  new_desc.SetAttr(framework::OpProtoAndCheckerMaker::OpRoleVarAttrName(),
+                   op_role_vars);
+  new_desc.SetType(kOptimizerType);
+
+  // set backward op's op role var; this will be used to
+  // set device_id in multi_device_pass
+  backward_node->Op()->SetAttr(
+      framework::OpProtoAndCheckerMaker::OpRoleVarAttrName(), op_role_vars);
+  // backward_node->Op()->SetAttr(
+  //     framework::OpProtoAndCheckerMaker::OpRoleVarAttrName(), {});
+
+  // keep the same output nodes between the new optimizer and the
+  // old one
+  Node* sgd_node = graph->CreateOpNode(&new_desc);
+
+  // change all outputs of the optimize_node to the new one
+  ReplaceAllDownstreamNode(optimize_node, sgd_node);
+
+  // find connected node between forward node and optimize node
+  // and replace the optimize node with the new sgd node
+  std::vector<ir::Node*> forward_opt_connected_nodes =
+      FindConnectedNode(forward_node, optimize_node);
+  for (ir::Node* node : forward_opt_connected_nodes) {
+    ReplaceUpstreamNode(node, optimize_node, sgd_node);
+  }
+
+  // find connected node between backward node and optimize node
+  // and replace the optimize node with the new sgd node
+  std::vector<ir::Node*> backward_opt_connected_nodes =
+      FindConnectedNode(backward_node, optimize_node);
+  for (ir::Node* node : backward_opt_connected_nodes) {
+    ReplaceUpstreamNode(node, optimize_node, sgd_node);
+  }
+
+  // SGD must have only one Param and one LR input
+  PADDLE_ENFORCE(old_desc->Input("LearningRate").size() == 1u);
+  PADDLE_ENFORCE(old_desc->Input("Param").size() == 1u);
+
+  // LR and weight nodes should be copied
+  for (Node* upstream_node : optimize_node->inputs) {
+    if (upstream_node->Name() == old_desc->Input("LearningRate")[0] ||
+        upstream_node->Name() == old_desc->Input("Param")[0]) {
+      ReplaceUpstreamNode(upstream_node, optimize_node, sgd_node);
+    }
+  }
+
+  LOG(ERROR) << "Create new opt node" << sgd_node->Name() << "_"
+             << sgd_node->id();
+
+  return sgd_node;
+}
+
+std::vector<ir::Node*> LockFreeOptimizePass::FindConnectedNode(
+    ir::Node* upstream_node, ir::Node* downstream_node) const {
+  std::vector<ir::Node*> result;
+  for (ir::Node* out_node : upstream_node->outputs) {
+    for (ir::Node* in_node : downstream_node->inputs) {
+      if (in_node == out_node) {
+        result.push_back(in_node);
+      }
+    }
+  }
+
+  return result;
+}
+
+void LockFreeOptimizePass::ReplaceUpstreamNode(
+    ir::Node* upstream_node, ir::Node* old_optimizer_node,
+    ir::Node* new_optimizer_node) const {
+  PADDLE_ENFORCE(upstream_node);
+  PADDLE_ENFORCE(old_optimizer_node);
+  PADDLE_ENFORCE(new_optimizer_node);
+
+  // Remove the old_optimizer_node from upstream_node's outputs vector
+  auto& output_node_vec = upstream_node->outputs;
+  for (auto output_node_iter = output_node_vec.begin();
+       output_node_iter != output_node_vec.end();) {
+    if (*output_node_iter == old_optimizer_node) {
+      output_node_vec.erase(output_node_iter);
+      break;
+    } else {
+      ++output_node_iter;
+    }
+  }
+
+  // Add the new_optimizer_node to upstream_node's outputs vector
+  output_node_vec.emplace_back(new_optimizer_node);
+  new_optimizer_node->inputs.emplace_back(upstream_node);
+}
+
+void LockFreeOptimizePass::ReplaceAllDownstreamNode(
+    ir::Node* old_optimizer_node, ir::Node* new_optimizer_node) const {
+  PADDLE_ENFORCE(old_optimizer_node);
+  PADDLE_ENFORCE(new_optimizer_node);
+
+  for (ir::Node* downstream_node : old_optimizer_node->outputs) {
+    // Remove the
old_optimizer_node from downstream_node's inputs vector
+    auto& input_node_vec = downstream_node->inputs;
+    for (auto input_node_iter = input_node_vec.begin();
+         input_node_iter != input_node_vec.end();) {
+      if (*input_node_iter == old_optimizer_node) {
+        input_node_vec.erase(input_node_iter);
+        break;
+      } else {
+        ++input_node_iter;
+      }
+    }
+
+    // Add the new_optimizer_node to downstream_node's inputs vector
+    input_node_vec.emplace_back(new_optimizer_node);
+    new_optimizer_node->outputs.emplace_back(downstream_node);
+  }
+}
+
+ir::Node* LockFreeOptimizePass::FindForwardOpViaBackwardOp(
+    ir::Graph* graph, ir::Node* backward_node) const {
+  PADDLE_ENFORCE(graph);
+  PADDLE_ENFORCE(backward_node);
+
+  // strip the suffix _grad of backward_node's name
+  std::string forward_op_name = backward_node->Name();
+  const std::string suffix("_grad");
+  if (forward_op_name != suffix && forward_op_name.size() > suffix.size() &&
+      forward_op_name.substr(forward_op_name.size() - suffix.size()) ==
+          suffix) {
+    // if so then strip it off
+    forward_op_name =
+        forward_op_name.substr(0, forward_op_name.size() - suffix.size());
+  } else {
+    LOG(WARNING) << "Illegal backward node's name " << backward_node->Name()
+                 << " id " << backward_node->id();
+
+    return nullptr;
+  }
+
+  for (ir::Node* node : graph->Nodes()) {
+    if (node->Name() == forward_op_name) {
+      if (node->outputs.size() == 0u) {
+        // if forward_node has no output, then it has NO grad op
+        continue;
+      }
+
+      // check whether every input of the backward_op that ends with @GRAD
+      // matches the gradient of one of this forward op's outputs
+      bool is_related_forward_node = true;
+      for (ir::Node* backward_input : backward_node->inputs) {
+        if (IsVarNameEndsWith(backward_input, kGradVarSuffix)) {
+          bool meets_correct_output = false;
+          for (ir::Node* forward_output : node->outputs) {
+            if (forward_output->Name() + kGradVarSuffix ==
+                backward_input->Name()) {
+              meets_correct_output = true;
+              break;
+            }
+          }
+
+          if (!meets_correct_output) {
+            is_related_forward_node = false;
+            break;
+          }
+        }
+      }
+
+      if (is_related_forward_node) {
+        return node;
+      }
+    }
+  }
+
+  return nullptr;
+}
+
+}  // namespace ir
+}  // namespace framework
+}  // namespace paddle
+
+REGISTER_PASS(lock_free_optimize_pass,
+              paddle::framework::ir::LockFreeOptimizePass);
diff --git a/paddle/fluid/framework/ir/lock_free_optimize_pass.h b/paddle/fluid/framework/ir/lock_free_optimize_pass.h
new file mode 100644
index 0000000000..7310f596f8
--- /dev/null
+++ b/paddle/fluid/framework/ir/lock_free_optimize_pass.h
@@ -0,0 +1,130 @@
+// Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#ifndef PADDLE_FLUID_FRAMEWORK_IR_LOCK_FREE_OPTIMIZE_PASS_H_
+#define PADDLE_FLUID_FRAMEWORK_IR_LOCK_FREE_OPTIMIZE_PASS_H_
+
+#include <string>
+#include <vector>
+
+#include <boost/algorithm/string/predicate.hpp>
+
+#include "paddle/fluid/framework/ir/graph.h"
+#include "paddle/fluid/framework/ir/pass.h"
+
+namespace paddle {
+namespace framework {
+namespace ir {
+
+class Node;
+
+/*
+* Remove the sum op of all gradients of the backward op.
+* And remove the dependencies of the optimizer related to the
+* same backward op.
+*
+* Before this pass:
+*
+*   forward_op1   forward_op2
+*        |             |
+*    grad_op1      grad_op2
+*         \           /
+*           \       /
+*            sum_op
+*              |
+*            sgd_op
+*
+* After this pass:
+*   forward_op1   forward_op2
+*        |             |
+*    grad_op1      grad_op2
+*        |             |
+*    sgd_op1       sgd_op2
+*
+* sgd_op1 and sgd_op2 will update the same weight which holds the same
+* memory, so we can benefit from the acceleration
+*/
+class LockFreeOptimizePass : public Pass {
+ public:
+  virtual ~LockFreeOptimizePass() {}
+
+ protected:
+  std::unique_ptr<ir::Graph> ApplyImpl(std::unique_ptr<ir::Graph> graph) const;
+
+ private:
+  // Create a new sgd node via current optimizer node
+  ir::Node* CreateNewSGDNode(ir::Graph* graph, ir::Node* forward_node,
+                             ir::Node* backward_node, ir::Node* grad_sum_node,
+                             ir::Node* optimize_node) const;
+
+  // Replace the input weight's optimizers
+  void ReplaceUpstreamNode(ir::Node* upstream_node,
+                           ir::Node* old_optimizer_node,
+                           ir::Node* new_optimizer_node) const;
+
+  // Replace the output weight's optimizers
+  void ReplaceAllDownstreamNode(ir::Node* old_optimizer_node,
+                                ir::Node* new_optimizer_node) const;
+
+  // Find all weight variables in graph
+  bool FindAllWeightVars(ir::Graph* graph) const;
+
+  // Find the forward_op node via the backward_op node
+  ir::Node* FindForwardOpViaBackwardOp(ir::Graph* graph,
+                                       ir::Node* backward_node) const;
+
+  std::vector<ir::Node*> FindConnectedNode(ir::Node* upstream_node,
+                                           ir::Node* downstream_node) const;
+
+  inline bool IsOpNamed(ir::Node* node, const std::string& name) const {
+    PADDLE_ENFORCE(node);
+
+    return node->NodeType() == Node::Type::kOperation && node->Name() == name;
+  }
+
+  inline bool IsVarNamed(ir::Node* node, const std::string& name) const {
+    PADDLE_ENFORCE(node);
+
+    return node->NodeType() == Node::Type::kVariable && node->Name() == name;
+  }
+
+  inline bool IsVarNameEndsWith(ir::Node* node, const std::string& name) const {
+    PADDLE_ENFORCE(node);
+
+    return node->NodeType() == Node::Type::kVariable &&
+           boost::algorithm::ends_with(node->Name(), name);
+  }
+
+  inline bool IsVarNameContains(ir::Node* node, const std::string& name) const {
+    PADDLE_ENFORCE(node);
+
+    return node->NodeType() == Node::Type::kVariable &&
+           node->Name().find(name) != std::string::npos;
+  }
+
+  inline bool IsControlDepFrom(ir::Node* ctrl_dep_node, ir::Node* node) const {
+    PADDLE_ENFORCE(ctrl_dep_node);
+    PADDLE_ENFORCE(node);
+
+    return IsControlDepVar(*ctrl_dep_node) &&
+           ctrl_dep_node->inputs.size() >= 1u &&
+           ctrl_dep_node->inputs[0] == node;
+  }
+};
+
+}  // namespace ir
+}  // namespace framework
+}  // namespace paddle
+
+#endif  // PADDLE_FLUID_FRAMEWORK_IR_LOCK_FREE_OPTIMIZE_PASS_H_
From 00e4de04bfa0ab0b90d153694fc7c597378bac16 Mon Sep 17 00:00:00 2001
From: minqiyang <minqiyang@baidu.com>
Date: Mon, 7 Jan 2019 16:44:07 +0800
Subject: [PATCH 12/28] Polish code

---
 paddle/fluid/operators/fused/fused_embedding_seq_pool_op.h | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/paddle/fluid/operators/fused/fused_embedding_seq_pool_op.h
b/paddle/fluid/operators/fused/fused_embedding_seq_pool_op.h index 38dfae8ad6..758432fd9e 100644 --- a/paddle/fluid/operators/fused/fused_embedding_seq_pool_op.h +++ b/paddle/fluid/operators/fused/fused_embedding_seq_pool_op.h @@ -40,7 +40,7 @@ struct EmbeddingVSumFunctor { int64_t row_number = table_t->dims()[0]; int64_t row_width = table_t->dims()[1]; int64_t last_dim = output_t->dims()[1]; - int64_t *ids = const_cast<int64_t *>(ids_t->data<int64_t>()); + const int64_t *ids = ids_t->data<int64_t>(); auto ids_lod = ids_t->lod()[0]; int64_t ids_count = ids_t->numel() / ids_lod.back(); From ee59e60f779749a3d431a54f68a32ebc5624df02 Mon Sep 17 00:00:00 2001 From: Tao Luo <luotao02@baidu.com> Date: Mon, 7 Jan 2019 16:59:48 +0800 Subject: [PATCH 13/28] update mklml version test=develop --- CMakeLists.txt | 5 ----- cmake/external/boost.cmake | 7 ++----- cmake/external/mklml.cmake | 24 +++++++++++------------- 3 files changed, 13 insertions(+), 23 deletions(-) diff --git a/CMakeLists.txt b/CMakeLists.txt index 66dcef0013..8ba8554456 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -126,11 +126,6 @@ if(ANDROID OR IOS) add_definitions(-DPADDLE_MOBILE_INFERENCE) endif() -if (APPLE) - set(WITH_MKL OFF CACHE STRING - "Disable MKL for building on mac" FORCE) -endif() - if (WIN32) set(WITH_DISTRIBUTE OFF CACHE STRING "Disable DISTRIBUTE when compiling for Windows" FORCE) diff --git a/cmake/external/boost.cmake b/cmake/external/boost.cmake index 5a78a1d1b7..12412a51a0 100644 --- a/cmake/external/boost.cmake +++ b/cmake/external/boost.cmake @@ -23,11 +23,8 @@ set(BOOST_PROJECT "extern_boost") # checked that the devtools package of CentOS 6 installs boost 1.41.0. # So we use 1.41.0 here. set(BOOST_VER "1.41.0") -if((NOT DEFINED BOOST_TAR) OR (NOT DEFINED BOOST_URL)) - message(STATUS "use pre defined download url") - set(BOOST_TAR "boost_1_41_0" CACHE STRING "" FORCE) - set(BOOST_URL "http://paddlepaddledeps.cdn.bcebos.com/${BOOST_TAR}.tar.gz" CACHE STRING "" FORCE) -endif() +set(BOOST_TAR "boost_1_41_0" CACHE STRING "" FORCE) +set(BOOST_URL "http://paddlepaddledeps.cdn.bcebos.com/${BOOST_TAR}.tar.gz" CACHE STRING "" FORCE) MESSAGE(STATUS "BOOST_TAR: ${BOOST_TAR}, BOOST_URL: ${BOOST_URL}") diff --git a/cmake/external/mklml.cmake b/cmake/external/mklml.cmake index 96127e78d6..c94878b6c7 100644 --- a/cmake/external/mklml.cmake +++ b/cmake/external/mklml.cmake @@ -36,19 +36,17 @@ else() endif() SET(CMAKE_INSTALL_RPATH "${CMAKE_INSTALL_RPATH}" "${MKLML_ROOT}/lib") -IF((NOT DEFINED MKLML_VER) OR (NOT DEFINED MKLML_URL)) - MESSAGE(STATUS "use pre defined download url") - if(WIN32) - SET(MKLML_VER "mklml_win_2019.0.1.20180928" CACHE STRING "" FORCE) - SET(MKLML_URL "https://paddlepaddledeps.cdn.bcebos.com/${MKLML_VER}.zip" CACHE STRING "" FORCE) - elseif(APPLE) - SET(MKLML_VER "mklml_mac_2019.0.1.20180928" CACHE STRING "" FORCE) - SET(MKLML_URL "http://paddlepaddledeps.cdn.bcebos.com/${MKLML_VER}.tgz" CACHE STRING "" FORCE) - else() - SET(MKLML_VER "mklml_lnx_2019.0.1.20180928" CACHE STRING "" FORCE) - SET(MKLML_URL "http://paddlepaddledeps.cdn.bcebos.com/${MKLML_VER}.tgz" CACHE STRING "" FORCE) - ENDIF() -endif() +SET(TIME_VERSION "2019.0.1.20181227") +if(WIN32) + SET(MKLML_VER "mklml_win_${TIME_VERSION}" CACHE STRING "" FORCE) + SET(MKLML_URL "https://paddlepaddledeps.cdn.bcebos.com/${MKLML_VER}.zip" CACHE STRING "" FORCE) +elseif(APPLE) + SET(MKLML_VER "mklml_mac_${TIME_VERSION}" CACHE STRING "" FORCE) + SET(MKLML_URL "http://paddlepaddledeps.cdn.bcebos.com/${MKLML_VER}.tgz" CACHE STRING "" FORCE) +else() + 
SET(MKLML_VER "mklml_lnx_${TIME_VERSION}" CACHE STRING "" FORCE) + SET(MKLML_URL "http://paddlepaddledeps.cdn.bcebos.com/${MKLML_VER}.tgz" CACHE STRING "" FORCE) +ENDIF() SET(MKLML_PROJECT "extern_mklml") MESSAGE(STATUS "MKLML_VER: ${MKLML_VER}, MKLML_URL: ${MKLML_URL}") From 7f45b9511aa1cf18f36709627a01a59bc1d3e661 Mon Sep 17 00:00:00 2001 From: minqiyang <minqiyang@baidu.com> Date: Mon, 7 Jan 2019 22:54:01 +0800 Subject: [PATCH 14/28] Polish code --- paddle/fluid/framework/operator.cc | 1 + paddle/fluid/operators/hash_op.h | 2 +- 2 files changed, 2 insertions(+), 1 deletion(-) diff --git a/paddle/fluid/framework/operator.cc b/paddle/fluid/framework/operator.cc index f10da22aec..afece8e3d2 100644 --- a/paddle/fluid/framework/operator.cc +++ b/paddle/fluid/framework/operator.cc @@ -29,6 +29,7 @@ DECLARE_bool(benchmark); DEFINE_bool(check_nan_inf, false, "Checking whether operator produce NAN/INF or not. It will be " "extremely slow so please use this flag wisely."); +DEFINE_int32(inner_op_parallelism, 0, "number of threads for inner op"); namespace paddle { namespace framework { diff --git a/paddle/fluid/operators/hash_op.h b/paddle/fluid/operators/hash_op.h index 9781bb0f45..1ed3ffe9aa 100644 --- a/paddle/fluid/operators/hash_op.h +++ b/paddle/fluid/operators/hash_op.h @@ -45,7 +45,7 @@ class HashKerel : public framework::OpKernel<T> { for (int idx = 0; idx < seq_length; ++idx) { for (int ihash = 0; ihash != num_hash; ++ihash) { output[idx * num_hash + ihash] = - XXH64(input, sizeof(int) * last_dim, ihash) % mod_by; + XXH32(input, sizeof(int) * last_dim, ihash) % mod_by; } input += last_dim; } From 1bfbc0d963db26fcf72b9b53d568e0b102d50a5d Mon Sep 17 00:00:00 2001 From: minqiyang <minqiyang@baidu.com> Date: Mon, 7 Jan 2019 22:54:47 +0800 Subject: [PATCH 15/28] Polish code test=develop --- paddle/fluid/framework/operator.cc | 1 - 1 file changed, 1 deletion(-) diff --git a/paddle/fluid/framework/operator.cc b/paddle/fluid/framework/operator.cc index afece8e3d2..f10da22aec 100644 --- a/paddle/fluid/framework/operator.cc +++ b/paddle/fluid/framework/operator.cc @@ -29,7 +29,6 @@ DECLARE_bool(benchmark); DEFINE_bool(check_nan_inf, false, "Checking whether operator produce NAN/INF or not. 
It will be " "extremely slow so please use this flag wisely."); -DEFINE_int32(inner_op_parallelism, 0, "number of threads for inner op"); namespace paddle { namespace framework { From b76695418ad6cfe16f5fe54f9768fdf3b467a241 Mon Sep 17 00:00:00 2001 From: minqiyang <minqiyang@baidu.com> Date: Mon, 7 Jan 2019 22:55:59 +0800 Subject: [PATCH 16/28] Polish log test=develop --- .../framework/ir/lock_free_optimize_pass.cc | 30 +++++++++---------- 1 file changed, 14 insertions(+), 16 deletions(-) diff --git a/paddle/fluid/framework/ir/lock_free_optimize_pass.cc b/paddle/fluid/framework/ir/lock_free_optimize_pass.cc index 96e7060aac..92e897ca9c 100644 --- a/paddle/fluid/framework/ir/lock_free_optimize_pass.cc +++ b/paddle/fluid/framework/ir/lock_free_optimize_pass.cc @@ -80,7 +80,7 @@ std::unique_ptr<ir::Graph> LockFreeOptimizePass::ApplyImpl( if (IsVarNameEndsWith(merged_grad_var, kGradVarSuffix) && merged_grad_var->outputs.size() == 1u) { ir::Node* opt_node = merged_grad_var->outputs[0]; - LOG(ERROR) << "Found opt node " << opt_node->Name(); + VLOG(3) << "Found opt node " << opt_node->Name(); // find the backward op connected with sum op for (ir::Node* unmerged_grad_var : node->inputs) { @@ -88,13 +88,13 @@ std::unique_ptr<ir::Graph> LockFreeOptimizePass::ApplyImpl( unmerged_grad_var->inputs.size() == 1u) { ir::Node* backward_op = unmerged_grad_var->inputs[0]; - LOG(ERROR) << "Found backward_op " << backward_op->Name(); + VLOG(3) << "Found backward_op " << backward_op->Name(); // find the forward op related to the backward op ir::Node* forward_op = FindForwardOpViaBackwardOp(graph.get(), backward_op); - LOG(ERROR) << "Found forward_op " << forward_op->Name(); + VLOG(3) << "Found forward_op " << forward_op->Name(); PADDLE_ENFORCE(forward_op); @@ -114,29 +114,28 @@ std::unique_ptr<ir::Graph> LockFreeOptimizePass::ApplyImpl( for (Node* optimize_op : sum_op_output->outputs) { if (optimize_op->NodeType() == Node::Type::kOperation && optimize_op->Name() == kOptimizerType) { - LOG(ERROR) << "remove optimize_op: " << optimize_op->Name() << "_" - << optimize_op->id(); + VLOG(3) << "remove optimize_op: " << optimize_op->Name() << "_" + << optimize_op->id(); graph->RemoveNode(optimize_op); } } - LOG(ERROR) << "remove sum_op_output: " << sum_op_output->Name() << "_" - << sum_op_output->id(); + VLOG(3) << "remove sum_op_output: " << sum_op_output->Name() << "_" + << sum_op_output->id(); graph->RemoveNode(sum_op_output); } - LOG(ERROR) << "remove sum_op: " << sum_op->Name() << "_" << sum_op->id(); + VLOG(3) << "remove sum_op: " << sum_op->Name() << "_" << sum_op->id(); graph->RemoveNode(sum_op); } for (auto* node : graph->Nodes()) { for (Node* output_node : node->outputs) { if (output_node->Name() == "sgd") { - LOG(ERROR) << "Node link to SGD: " << node->Name() << "_" << node->id() - << " --> " << output_node->Name() << "_" - << output_node->id(); + VLOG(3) << "Node link to SGD: " << node->Name() << "_" << node->id() + << " --> " << output_node->Name() << "_" << output_node->id(); for (Node* input_node : node->inputs) { - LOG(ERROR) << "SGD Input link: " << input_node->Name() << "_" - << input_node->id() << " --> " << node->Name() << "_" - << node->id(); + VLOG(3) << "SGD Input link: " << input_node->Name() << "_" + << input_node->id() << " --> " << node->Name() << "_" + << node->id(); } } } @@ -226,8 +225,7 @@ ir::Node* LockFreeOptimizePass::CreateNewSGDNode( } } - LOG(ERROR) << "Create new opt node" << sgd_node->Name() << "_" - << sgd_node->id(); + VLOG(3) << "Create new opt node" << sgd_node->Name() << "_" 
<< sgd_node->id(); return sgd_node; } From 5979953720ce35e5607f227d7b4c2400df0b8a35 Mon Sep 17 00:00:00 2001 From: minqiyang <minqiyang@baidu.com> Date: Mon, 7 Jan 2019 22:56:41 +0800 Subject: [PATCH 17/28] Remove debug info test=develop --- CMakeLists.txt | 2 -- 1 file changed, 2 deletions(-) diff --git a/CMakeLists.txt b/CMakeLists.txt index 74d869307d..d6aa8f1b85 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -12,8 +12,6 @@ # See the License for the specific language governing permissions and # limitations under the License -set(CMAKE_VERBOSE_MAKEFILE on) - cmake_minimum_required(VERSION 3.0) set(CMAKE_MODULE_PATH ${CMAKE_MODULE_PATH} "${CMAKE_CURRENT_SOURCE_DIR}/cmake") set(PADDLE_SOURCE_DIR ${CMAKE_CURRENT_SOURCE_DIR}) From 7c7342bf125ef2859d1dd7628ad5a494ffe315b9 Mon Sep 17 00:00:00 2001 From: sneaxiy <sneaxiy@126.com> Date: Tue, 8 Jan 2019 03:33:13 +0000 Subject: [PATCH 18/28] fix scope.var() test=develop --- paddle/fluid/framework/scope.cc | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/paddle/fluid/framework/scope.cc b/paddle/fluid/framework/scope.cc index a5742dbd3d..9536185609 100644 --- a/paddle/fluid/framework/scope.cc +++ b/paddle/fluid/framework/scope.cc @@ -87,11 +87,12 @@ Variable* Scope::Var(const std::string& name) { } Variable* Scope::Var(std::string* name) { - auto new_name = string::Sprintf("%p.%d", this, vars_.size()); + SCOPE_VARS_WRITER_LOCK + auto new_name = std::to_string(reinterpret_cast<uintptr_t>(this)) + "." + + std::to_string(vars_.size()); if (name != nullptr) { *name = new_name; } - SCOPE_VARS_WRITER_LOCK return VarInternal(new_name); } From ed409ac9f4fa57dbf8785f24dde4b55714555fc4 Mon Sep 17 00:00:00 2001 From: sneaxiy <sneaxiy@126.com> Date: Tue, 8 Jan 2019 03:37:59 +0000 Subject: [PATCH 19/28] Revert "Revert "Remove op handle lock"" test=develop --- paddle/fluid/operators/math/blas_impl.cu.h | 134 +++++++++---------- paddle/fluid/platform/cuda_helper.h | 58 ++++++++ paddle/fluid/platform/device_context.cc | 18 ++- paddle/fluid/platform/device_context.h | 76 ++++------- paddle/fluid/platform/device_context_test.cu | 3 - 5 files changed, 159 insertions(+), 130 deletions(-) create mode 100644 paddle/fluid/platform/cuda_helper.h diff --git a/paddle/fluid/operators/math/blas_impl.cu.h b/paddle/fluid/operators/math/blas_impl.cu.h index d35073029a..58f7be12ce 100644 --- a/paddle/fluid/operators/math/blas_impl.cu.h +++ b/paddle/fluid/operators/math/blas_impl.cu.h @@ -62,27 +62,19 @@ struct CUBlas<float> { cudaDataType_t Atype, int lda, const void *B, cudaDataType_t Btype, int ldb, const float *beta, void *C, cudaDataType_t Ctype, int ldc) { - // Because the gcc 4.8 doesn't expand template parameter pack that - // appears in a lambda-expression, I can not use template parameter pack - // here. - auto cublas_call = [&]() { +// Because the gcc 4.8 doesn't expand template parameter pack that +// appears in a lambda-expression, I can not use template parameter pack +// here. #if CUDA_VERSION >= 8000 - VLOG(5) << "use_tensor_op_math: " - << (platform::TensorCoreAvailable() ? "True" : "False"); + VLOG(5) << "use_tensor_op_math: " + << (dev_ctx->tensor_core_available() ? 
"True" : "False"); + dev_ctx->TensorCoreCublasCallIfAvailable([&](cublasHandle_t handle) { PADDLE_ENFORCE(platform::dynload::cublasSgemmEx( - dev_ctx->cublas_handle(), transa, transb, m, n, k, alpha, A, Atype, - lda, B, Btype, ldb, beta, C, Ctype, ldc)); + handle, transa, transb, m, n, k, alpha, A, Atype, lda, B, Btype, ldb, + beta, C, Ctype, ldc)); + }); #else - PADDLE_THROW("cublasSgemmEx is supported on cuda >= 8.0"); -#endif - }; - -#if CUDA_VERSION >= 9000 - // NOTES: To use Tensor Core, we should change the cublas config, - // but the cublas may be hold by multi-thread. - dev_ctx->CublasCall(cublas_call, CUBLAS_TENSOR_OP_MATH); -#else - cublas_call(); + PADDLE_THROW("cublasSgemmEx is supported on cuda >= 8.0"); #endif } }; @@ -170,32 +162,24 @@ struct CUBlas<platform::float16> { cudaDataType_t Btype, int ldb, const void *beta, void *C, cudaDataType_t Ctype, int ldc, cudaDataType_t computeType) { - auto cublas_call = [&]() { #if CUDA_VERSION >= 8000 - cublasGemmAlgo_t algo = CUBLAS_GEMM_DFALT; + cublasGemmAlgo_t algo = CUBLAS_GEMM_DFALT; #if CUDA_VERSION >= 9000 - bool use_tensor_op_math = platform::TensorCoreAvailable(); - if (use_tensor_op_math) { - algo = CUBLAS_GEMM_DFALT_TENSOR_OP; - } - VLOG(5) << "use_tensor_op_math: " - << (use_tensor_op_math ? "True" : "False"); + bool use_tensor_op_math = dev_ctx->tensor_core_available(); + if (use_tensor_op_math) { + algo = CUBLAS_GEMM_DFALT_TENSOR_OP; + } + VLOG(5) << "use_tensor_op_math: " + << (use_tensor_op_math ? "True" : "False"); #endif // CUDA_VERSION >= 9000 + dev_ctx->TensorCoreCublasCallIfAvailable([&](cublasHandle_t handle) { PADDLE_ENFORCE(platform::dynload::cublasGemmEx( - dev_ctx->cublas_handle(), transa, transb, m, n, k, alpha, A, Atype, - lda, B, Btype, ldb, beta, C, Ctype, ldc, computeType, algo)); + handle, transa, transb, m, n, k, alpha, A, Atype, lda, B, Btype, ldb, + beta, C, Ctype, ldc, computeType, algo)); + }); #else - PADDLE_THROW("cublasGemmEx is supported on cuda >= 8.0"); -#endif - }; - -#if CUDA_VERSION >= 9000 - // NOTES: To use Tensor Core, we should change the cublas config, - // but the cublas may be hold by multi-thread. 
- dev_ctx->CublasCall(cublas_call, CUBLAS_TENSOR_OP_MATH); -#else - cublas_call(); + PADDLE_THROW("cublasGemmEx is supported on cuda >= 8.0"); #endif } }; @@ -223,9 +207,10 @@ void Blas<platform::CUDADeviceContext>::GEMM(CBLAS_TRANSPOSE transA, CUDA_R_32F, N); } else { #endif // CUDA_VERSION >= 8000 - - CUBlas<T>::GEMM(context_.cublas_handle(), cuTransB, cuTransA, N, M, K, - &alpha, B, ldb, A, lda, &beta, C, N); + context_.CublasCall([&](cublasHandle_t handle) { + CUBlas<T>::GEMM(handle, cuTransB, cuTransA, N, M, K, &alpha, B, ldb, A, + lda, &beta, C, N); + }); #if CUDA_VERSION >= 8000 } @@ -266,9 +251,12 @@ inline void Blas<platform::CUDADeviceContext>::GEMM( CUDA_R_16F, lda, &h_beta, C, CUDA_R_16F, N, CUDA_R_32F); #else // CUDA 7.5 does not support cublasGemmEx, hence we fall back to use hgemm - CUBlas<platform::float16>::GEMM(context_.cublas_handle(), cuTransB, cuTransA, - N, M, K, &h_alpha, h_B, ldb, h_A, lda, - &h_beta, h_C, N); + + context_.CublasCall([&](cublasHandle_t handle) { + CUBlas<platform::float16>::GEMM(handle, cuTransB, cuTransA, N, M, K, + &h_alpha, h_B, ldb, h_A, lda, &h_beta, h_C, + N); + }); #endif // CUDA_VERSION >= 8000 } @@ -292,8 +280,10 @@ void Blas<platform::CUDADeviceContext>::GEMM(bool transA, bool transB, int M, } else { #endif // CUDA_VERSION >= 8000 - CUBlas<T>::GEMM(context_.cublas_handle(), cuTransB, cuTransA, N, M, K, - &alpha, B, ldb, A, lda, &beta, C, ldc); + context_.CublasCall([&](cublasHandle_t handle) { + CUBlas<T>::GEMM(handle, cuTransB, cuTransA, N, M, K, &alpha, B, ldb, A, + lda, &beta, C, ldc); + }); #if CUDA_VERSION >= 8000 } @@ -311,16 +301,19 @@ inline void Blas<platform::CUDADeviceContext>::GEMM( cublasOperation_t cuTransA = transA ? CUBLAS_OP_T : CUBLAS_OP_N; cublasOperation_t cuTransB = transB ? CUBLAS_OP_T : CUBLAS_OP_N; - CUBlas<platform::float16>::GEMM(context_.cublas_handle(), cuTransB, cuTransA, - N, M, K, &alpha, B, ldb, A, lda, &beta, C, - ldc); + context_.CublasCall([&](cublasHandle_t handle) { + CUBlas<platform::float16>::GEMM(handle, cuTransB, cuTransA, N, M, K, &alpha, + B, ldb, A, lda, &beta, C, ldc); + }); } template <> template <typename T> void Blas<platform::CUDADeviceContext>::AXPY(int n, T alpha, const T *x, T *y) const { - CUBlas<T>::AXPY(context_.cublas_handle(), n, &alpha, x, 1, y, 1); + context_.CublasCall([&](cublasHandle_t handle) { + CUBlas<T>::AXPY(handle, n, &alpha, x, 1, y, 1); + }); } template <> @@ -330,8 +323,9 @@ void Blas<platform::CUDADeviceContext>::GEMV(bool trans_a, int M, int N, T beta, T *C) const { cublasOperation_t cuTransA = !trans_a ? CUBLAS_OP_T : CUBLAS_OP_N; - CUBlas<T>::GEMV(context_.cublas_handle(), cuTransA, N, M, &alpha, A, N, B, 1, - &beta, C, 1); + context_.CublasCall([&](cublasHandle_t handle) { + CUBlas<T>::GEMV(handle, cuTransA, N, M, &alpha, A, N, B, 1, &beta, C, 1); + }); } template <> @@ -353,28 +347,28 @@ void Blas<platform::CUDADeviceContext>::BatchedGEMM( #if CUDA_VERSION >= 9010 if (FLAGS_enable_cublas_tensor_op_math && std::is_same<T, float>::value) { - auto cublas_call = [&]() { - cublasGemmAlgo_t algo = CUBLAS_GEMM_DFALT; - bool use_tensor_op_math = platform::TensorCoreAvailable(); - if (use_tensor_op_math) { - algo = CUBLAS_GEMM_DFALT_TENSOR_OP; - } - VLOG(5) << "use_tensor_op_math: " - << (use_tensor_op_math ? "True" : "False"); - + cublasGemmAlgo_t algo = CUBLAS_GEMM_DFALT; + bool use_tensor_op_math = context_.tensor_core_available(); + if (use_tensor_op_math) { + algo = CUBLAS_GEMM_DFALT_TENSOR_OP; + } + VLOG(5) << "use_tensor_op_math: " + << (use_tensor_op_math ? 
"True" : "False"); + + context_.TensorCoreCublasCallIfAvailable([&](cublasHandle_t handle) { PADDLE_ENFORCE(platform::dynload::cublasGemmStridedBatchedEx( - context_.cublas_handle(), cuTransB, cuTransA, N, M, K, &alpha, B, - CUDA_R_32F, ldb, strideB, A, CUDA_R_32F, lda, strideA, &beta, C, - CUDA_R_32F, ldc, strideC, batchCount, CUDA_R_32F, algo)); - }; - auto &dev_ctx = const_cast<platform::CUDADeviceContext &>(context_); - dev_ctx.CublasCall(cublas_call, CUBLAS_TENSOR_OP_MATH); + handle, cuTransB, cuTransA, N, M, K, &alpha, B, CUDA_R_32F, ldb, + strideB, A, CUDA_R_32F, lda, strideA, &beta, C, CUDA_R_32F, ldc, + strideC, batchCount, CUDA_R_32F, algo)); + }); } else { #endif // CUDA_VERSION >= 9010 - CUBlas<T>::GEMM_STRIDED_BATCH(context_.cublas_handle(), cuTransB, cuTransA, - N, M, K, &alpha, B, ldb, strideB, A, lda, - strideA, &beta, C, ldc, strideC, batchCount); + context_.CublasCall([&](cublasHandle_t handle) { + CUBlas<T>::GEMM_STRIDED_BATCH(handle, cuTransB, cuTransA, N, M, K, &alpha, + B, ldb, strideB, A, lda, strideA, &beta, C, + ldc, strideC, batchCount); + }); #if CUDA_VERSION >= 9010 } diff --git a/paddle/fluid/platform/cuda_helper.h b/paddle/fluid/platform/cuda_helper.h new file mode 100644 index 0000000000..122de72e15 --- /dev/null +++ b/paddle/fluid/platform/cuda_helper.h @@ -0,0 +1,58 @@ +// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+ +#pragma once + +#include <mutex> // NOLINT + +#include "paddle/fluid/platform/dynload/cublas.h" +#include "paddle/fluid/platform/macros.h" + +#if CUDA_VERSION < 9000 +enum cublasMath_t { CUBLAS_DEFAULT_MATH = 0 }; +#endif + +namespace paddle { +namespace platform { + +class CublasHandleHolder { + public: + CublasHandleHolder(cudaStream_t stream, cublasMath_t math_type) { + PADDLE_ENFORCE(dynload::cublasCreate(&handle_)); + PADDLE_ENFORCE(dynload::cublasSetStream(handle_, stream)); +#if CUDA_VERSION >= 9000 + if (math_type == CUBLAS_TENSOR_OP_MATH) { + PADDLE_ENFORCE( + dynload::cublasSetMathMode(handle_, CUBLAS_TENSOR_OP_MATH)); + } +#endif + } + + ~CublasHandleHolder() { PADDLE_ENFORCE(dynload::cublasDestroy(handle_)); } + + template <typename Callback> + inline void Call(Callback &&callback) const { + std::lock_guard<std::mutex> guard(mtx_); + callback(handle_); + } + + private: + DISABLE_COPY_AND_ASSIGN(CublasHandleHolder); + + cublasHandle_t handle_; + mutable std::mutex mtx_; +}; + +} // namespace platform +} // namespace paddle diff --git a/paddle/fluid/platform/device_context.cc b/paddle/fluid/platform/device_context.cc index 022afb686b..be7f4949d6 100644 --- a/paddle/fluid/platform/device_context.cc +++ b/paddle/fluid/platform/device_context.cc @@ -245,8 +245,15 @@ CUDADeviceContext::CUDADeviceContext(CUDAPlace place) eigen_stream_.reset(new EigenCudaStreamDevice()); eigen_stream_->Reinitialize(&stream_, place); eigen_device_.reset(new Eigen::GpuDevice(eigen_stream_.get())); - PADDLE_ENFORCE(dynload::cublasCreate(&cublas_handle_)); - PADDLE_ENFORCE(dynload::cublasSetStream(cublas_handle_, stream_)); + cublas_handle_.reset(new CublasHandleHolder(stream_, CUBLAS_DEFAULT_MATH)); + + if (TensorCoreAvailable()) { +#if CUDA_VERSION >= 9000 + cublas_tensor_core_handle_.reset( + new CublasHandleHolder(stream_, CUBLAS_TENSOR_OP_MATH)); +#endif + } + if (dynload::HasCUDNN()) { cudnn_holder_.reset(new CudnnHolder(&stream_, place)); } @@ -306,7 +313,8 @@ CUDADeviceContext::~CUDADeviceContext() { SetDeviceId(place_.device); Wait(); WaitStreamCallback(); - PADDLE_ENFORCE(dynload::cublasDestroy(cublas_handle_)); + cublas_handle_.reset(); + cublas_tensor_core_handle_.reset(); eigen_stream_.reset(); eigen_device_.reset(); PADDLE_ENFORCE(cudaStreamDestroy(stream_)); @@ -335,8 +343,8 @@ Eigen::GpuDevice* CUDADeviceContext::eigen_device() const { return eigen_device_.get(); } -cublasHandle_t CUDADeviceContext::cublas_handle() const { - return cublas_handle_; +bool CUDADeviceContext::tensor_core_available() const { + return cublas_tensor_core_handle_ != nullptr; } cudnnHandle_t CUDADeviceContext::cudnn_handle() const { diff --git a/paddle/fluid/platform/device_context.h b/paddle/fluid/platform/device_context.h index 7e87580189..c81d17380c 100644 --- a/paddle/fluid/platform/device_context.h +++ b/paddle/fluid/platform/device_context.h @@ -20,6 +20,7 @@ limitations under the License. 
*/ #include "paddle/fluid/memory/malloc.h" #include "paddle/fluid/platform/temporary_allocator.h" #ifdef PADDLE_WITH_CUDA +#include "paddle/fluid/platform/cuda_helper.h" #include "paddle/fluid/platform/dynload/cublas.h" #include "paddle/fluid/platform/dynload/cudnn.h" #include "paddle/fluid/platform/gpu_info.h" @@ -209,39 +210,6 @@ class CudnnWorkspaceHandle { std::unique_ptr<std::lock_guard<std::mutex>> guard_; }; -#if CUDA_VERSION >= 9000 -class ScopedCublasMathMode { - public: - ScopedCublasMathMode(cublasHandle_t handle, cublasMath_t new_math_mode) - : handle_(handle) { - need_reset = false; - PADDLE_ENFORCE( - platform::dynload::cublasGetMathMode(handle_, &old_math_mode_), - "Failed to get old cublas math mode"); - if (old_math_mode_ != new_math_mode) { - PADDLE_ENFORCE( - platform::dynload::cublasSetMathMode(handle_, new_math_mode), - "Failed to set old cublas math mode"); - need_reset = true; - } - } - - ~ScopedCublasMathMode() { - if (need_reset) { - PADDLE_ENFORCE( - platform::dynload::cublasSetMathMode(handle_, old_math_mode_), - "Failed to set old cublas math mode"); - } - } - - private: - cublasHandle_t handle_; - cublasMath_t old_math_mode_; - bool need_reset; -}; - -#endif - class CUDADeviceContext : public DeviceContext { public: explicit CUDADeviceContext(CUDAPlace place); @@ -262,8 +230,25 @@ class CUDADeviceContext : public DeviceContext { /*! \brief Return eigen device in the device context. */ Eigen::GpuDevice* eigen_device() const; - /*! \brief Return cublas handle in the device context. */ - cublasHandle_t cublas_handle() const; + /*! \brief Call cublas function safely. */ + template <typename Callback> + inline void CublasCall(Callback&& callback) const { + cublas_handle_->Call(std::forward<Callback>(callback)); + } + + /*! \brief Check whether tensor core is supported */ + bool tensor_core_available() const; + + /*! \brief Call cublas function with Tensor Core safely. If + Tensor Core is not available, use DEFAULT_MATH instead. */ + template <typename Callback> + inline void TensorCoreCublasCallIfAvailable(Callback&& callback) const { + if (cublas_tensor_core_handle_) { + cublas_tensor_core_handle_->Call(std::forward<Callback>(callback)); + } else { + cublas_handle_->Call(std::forward<Callback>(callback)); + } + } /*! \brief Return cudnn handle in the device context. */ cudnnHandle_t cudnn_handle() const; @@ -282,7 +267,6 @@ class CUDADeviceContext : public DeviceContext { template <typename Callback> void RecordEvent(cudaEvent_t ev, Callback callback) { - std::lock_guard<std::mutex> guard(mtx_); callback(); PADDLE_ENFORCE(cudaEventRecord(ev, stream_)); } @@ -294,18 +278,6 @@ class CUDADeviceContext : public DeviceContext { void WaitStreamCallback() const { callback_manager_->Wait(); } -#if CUDA_VERSION >= 9000 - /*! \brief CublasCall may need to change cublas's config, - * but the cublas may be hold by multi-thread, so we should - * add lock here. 
*/ - template <typename Callback> - void CublasCall(Callback callback, cublasMath_t new_math) { - std::lock_guard<std::mutex> guard(cublas_mtx_); - ScopedCublasMathMode scoped_cublas_math(cublas_handle_, new_math); - callback(); - } -#endif - private: CUDAPlace place_; @@ -313,7 +285,9 @@ class CUDADeviceContext : public DeviceContext { std::unique_ptr<EigenCudaStreamDevice> eigen_stream_; std::unique_ptr<CudnnHolder> cudnn_holder_; cudaStream_t stream_; - cublasHandle_t cublas_handle_; + + std::unique_ptr<CublasHandleHolder> cublas_handle_; + std::unique_ptr<CublasHandleHolder> cublas_tensor_core_handle_; int compute_capability_; int runtime_version_; @@ -321,12 +295,10 @@ class CUDADeviceContext : public DeviceContext { int multi_process_; int max_threads_per_mp_; - mutable std::mutex mtx_; - // StreamCallbackManager is thread-safe std::unique_ptr<StreamCallbackManager> callback_manager_; - mutable std::mutex cublas_mtx_; + DISABLE_COPY_AND_ASSIGN(CUDADeviceContext); }; template <> diff --git a/paddle/fluid/platform/device_context_test.cu b/paddle/fluid/platform/device_context_test.cu index 171d2979a0..5b3aa98efb 100644 --- a/paddle/fluid/platform/device_context_test.cu +++ b/paddle/fluid/platform/device_context_test.cu @@ -43,9 +43,6 @@ TEST(Device, CUDADeviceContext) { ASSERT_NE(nullptr, gpu_device); cudnnHandle_t cudnn_handle = device_context->cudnn_handle(); ASSERT_NE(nullptr, cudnn_handle); - cublasHandle_t cublas_handle = device_context->cublas_handle(); - ASSERT_NE(nullptr, cublas_handle); - ASSERT_NE(nullptr, device_context->stream()); delete device_context; } } From 49c31e5da409f9af01182ea74a91d605e3ca9747 Mon Sep 17 00:00:00 2001 From: Tao Luo <luotao02@baidu.com> Date: Mon, 7 Jan 2019 20:31:20 +0800 Subject: [PATCH 20/28] disable mkl for mac test=develop --- CMakeLists.txt | 5 +++++ cmake/external/mklml.cmake | 30 +++++++++++++++--------------- 2 files changed, 20 insertions(+), 15 deletions(-) diff --git a/CMakeLists.txt b/CMakeLists.txt index 8ba8554456..66dcef0013 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -126,6 +126,11 @@ if(ANDROID OR IOS) add_definitions(-DPADDLE_MOBILE_INFERENCE) endif() +if (APPLE) + set(WITH_MKL OFF CACHE STRING + "Disable MKL for building on mac" FORCE) +endif() + if (WIN32) set(WITH_DISTRIBUTE OFF CACHE STRING "Disable DISTRIBUTE when compiling for Windows" FORCE) diff --git a/cmake/external/mklml.cmake b/cmake/external/mklml.cmake index c94878b6c7..43322a257a 100644 --- a/cmake/external/mklml.cmake +++ b/cmake/external/mklml.cmake @@ -16,6 +16,12 @@ IF(NOT ${WITH_MKLML}) return() ENDIF(NOT ${WITH_MKLML}) +IF(APPLE) + MESSAGE(WARNING "Mac is not supported with MKLML in Paddle yet. 
Force WITH_MKLML=OFF.") + SET(WITH_MKLML OFF CACHE STRING "Disable MKLML package in MacOS" FORCE) + return() +ENDIF() + INCLUDE(ExternalProject) SET(MKLML_DST_DIR "mklml") SET(MKLML_INSTALL_ROOT "${THIRD_PARTY_PATH}/install") @@ -23,29 +29,23 @@ SET(MKLML_INSTALL_DIR ${MKLML_INSTALL_ROOT}/${MKLML_DST_DIR}) SET(MKLML_ROOT ${MKLML_INSTALL_DIR}) SET(MKLML_INC_DIR ${MKLML_ROOT}/include) SET(MKLML_LIB_DIR ${MKLML_ROOT}/lib) -if(WIN32) +SET(CMAKE_INSTALL_RPATH "${CMAKE_INSTALL_RPATH}" "${MKLML_ROOT}/lib") + +SET(TIME_VERSION "2019.0.1.20181227") +IF(WIN32) + SET(MKLML_VER "mklml_win_${TIME_VERSION}" CACHE STRING "" FORCE) + SET(MKLML_URL "https://paddlepaddledeps.cdn.bcebos.com/${MKLML_VER}.zip" CACHE STRING "" FORCE) SET(MKLML_LIB ${MKLML_LIB_DIR}/mklml.lib) SET(MKLML_IOMP_LIB ${MKLML_LIB_DIR}/libiomp5md.lib) SET(MKLML_SHARED_LIB ${MKLML_LIB_DIR}/mklml.dll) SET(MKLML_SHARED_IOMP_LIB ${MKLML_LIB_DIR}/libiomp5md.dll) -else() +ELSE() + SET(MKLML_VER "mklml_lnx_${TIME_VERSION}" CACHE STRING "" FORCE) + SET(MKLML_URL "http://paddlepaddledeps.cdn.bcebos.com/${MKLML_VER}.tgz" CACHE STRING "" FORCE) SET(MKLML_LIB ${MKLML_LIB_DIR}/libmklml_intel.so) SET(MKLML_IOMP_LIB ${MKLML_LIB_DIR}/libiomp5.so) SET(MKLML_SHARED_LIB ${MKLML_LIB_DIR}/libmklml_intel.so) SET(MKLML_SHARED_IOMP_LIB ${MKLML_LIB_DIR}/libiomp5.so) -endif() -SET(CMAKE_INSTALL_RPATH "${CMAKE_INSTALL_RPATH}" "${MKLML_ROOT}/lib") - -SET(TIME_VERSION "2019.0.1.20181227") -if(WIN32) - SET(MKLML_VER "mklml_win_${TIME_VERSION}" CACHE STRING "" FORCE) - SET(MKLML_URL "https://paddlepaddledeps.cdn.bcebos.com/${MKLML_VER}.zip" CACHE STRING "" FORCE) -elseif(APPLE) - SET(MKLML_VER "mklml_mac_${TIME_VERSION}" CACHE STRING "" FORCE) - SET(MKLML_URL "http://paddlepaddledeps.cdn.bcebos.com/${MKLML_VER}.tgz" CACHE STRING "" FORCE) -else() - SET(MKLML_VER "mklml_lnx_${TIME_VERSION}" CACHE STRING "" FORCE) - SET(MKLML_URL "http://paddlepaddledeps.cdn.bcebos.com/${MKLML_VER}.tgz" CACHE STRING "" FORCE) ENDIF() SET(MKLML_PROJECT "extern_mklml") From 7b7d0d0caf85fc2d104ac285cfa367ff46490fa1 Mon Sep 17 00:00:00 2001 From: minqiyang <minqiyang@baidu.com> Date: Tue, 8 Jan 2019 13:30:09 +0800 Subject: [PATCH 21/28] Change hash function back test=develop --- paddle/fluid/operators/hash_op.h | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/paddle/fluid/operators/hash_op.h b/paddle/fluid/operators/hash_op.h index 1ed3ffe9aa..9781bb0f45 100644 --- a/paddle/fluid/operators/hash_op.h +++ b/paddle/fluid/operators/hash_op.h @@ -45,7 +45,7 @@ class HashKerel : public framework::OpKernel<T> { for (int idx = 0; idx < seq_length; ++idx) { for (int ihash = 0; ihash != num_hash; ++ihash) { output[idx * num_hash + ihash] = - XXH32(input, sizeof(int) * last_dim, ihash) % mod_by; + XXH64(input, sizeof(int) * last_dim, ihash) % mod_by; } input += last_dim; } From 23bdd0a223cc3e88c62fb8f48155c83455c9fede Mon Sep 17 00:00:00 2001 From: superjomn <yanchunwei@outlook.com> Date: Tue, 8 Jan 2019 15:11:48 +0800 Subject: [PATCH 22/28] fix analysis_tester bug test=develop --- paddle/fluid/inference/analysis/analyzer_tester.cc | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/paddle/fluid/inference/analysis/analyzer_tester.cc b/paddle/fluid/inference/analysis/analyzer_tester.cc index f84e1ab6b8..4c84d02d86 100644 --- a/paddle/fluid/inference/analysis/analyzer_tester.cc +++ b/paddle/fluid/inference/analysis/analyzer_tester.cc @@ -80,8 +80,8 @@ void TestWord2vecPrediction(const std::string& model_path) { i++) { LOG(INFO) << "data: " << 
static_cast<float*>(outputs.front().data.data())[i] << " result: " << result[i]; - PADDLE_ENFORCE(static_cast<float*>(outputs.front().data.data())[i], - result[i]); + EXPECT_NEAR(static_cast<float*>(outputs.front().data.data())[i], result[i], + 1e-3); } } From 69fd3fdb5206045cfcee90d98b52cf070f1dcae1 Mon Sep 17 00:00:00 2001 From: tensor-tang <tangjian03@baidu.com> Date: Tue, 8 Jan 2019 09:11:39 +0000 Subject: [PATCH 23/28] fix debug build error test=develop --- paddle/fluid/inference/analysis/passes/CMakeLists.txt | 1 + 1 file changed, 1 insertion(+) diff --git a/paddle/fluid/inference/analysis/passes/CMakeLists.txt b/paddle/fluid/inference/analysis/passes/CMakeLists.txt index d3ea511d8f..add9b70f2c 100644 --- a/paddle/fluid/inference/analysis/passes/CMakeLists.txt +++ b/paddle/fluid/inference/analysis/passes/CMakeLists.txt @@ -7,4 +7,5 @@ set(analysis_deps ${analysis_deps} ir_graph_build_pass ir_analysis_pass analysis_passes + subgraph_detector CACHE INTERNAL "") From bc205ef37453e0f7ab1f74abb123c3367ceee3c7 Mon Sep 17 00:00:00 2001 From: sneaxiy <sneaxiy@126.com> Date: Tue, 8 Jan 2019 10:28:01 +0000 Subject: [PATCH 24/28] fix same name func test=develop --- paddle/fluid/framework/var_type_traits.cc | 8 +++++--- paddle/fluid/framework/var_type_traits.h | 4 ++-- paddle/fluid/framework/var_type_traits_test.cc | 9 +++++---- 3 files changed, 12 insertions(+), 9 deletions(-) diff --git a/paddle/fluid/framework/var_type_traits.cc b/paddle/fluid/framework/var_type_traits.cc index c3c5bab23b..a37b1fbab8 100644 --- a/paddle/fluid/framework/var_type_traits.cc +++ b/paddle/fluid/framework/var_type_traits.cc @@ -105,13 +105,15 @@ struct VarIdToTypeIndexMapHolder { } // namespace detail -const std::type_index &ToTypeIndex(int var_id) { +const std::type_index &VarTraitIdToTypeIndex(int var_id) { return detail::VarIdToTypeIndexMapHolder::ToTypeIndex(var_id); } -const char *ToTypeName(int var_id) { return ToTypeIndex(var_id).name(); } +const char *ToTypeName(int var_id) { + return VarTraitIdToTypeIndex(var_id).name(); +} -int ToTypeId(const std::type_index &type) { +int TypeIndexToVarTraitId(const std::type_index &type) { return detail::VarIdToTypeIndexMapHolder::ToTypeId(type); } diff --git a/paddle/fluid/framework/var_type_traits.h b/paddle/fluid/framework/var_type_traits.h index cc68cf2ab8..733542e497 100644 --- a/paddle/fluid/framework/var_type_traits.h +++ b/paddle/fluid/framework/var_type_traits.h @@ -66,8 +66,8 @@ namespace paddle { namespace framework { const char *ToTypeName(int var_id); -const std::type_index &ToTypeIndex(int var_id); -int ToTypeId(const std::type_index &type); +const std::type_index &VarTraitIdToTypeIndex(int var_id); +int TypeIndexToVarTraitId(const std::type_index &type); namespace detail { diff --git a/paddle/fluid/framework/var_type_traits_test.cc b/paddle/fluid/framework/var_type_traits_test.cc index 00840d634d..a47275e1ca 100644 --- a/paddle/fluid/framework/var_type_traits_test.cc +++ b/paddle/fluid/framework/var_type_traits_test.cc @@ -45,10 +45,11 @@ struct TypeIndexChecker { constexpr auto kId = VarTypeTrait<Type>::kId; std::type_index actual_type(typeid(Type)); EXPECT_EQ(std::string(ToTypeName(kId)), std::string(actual_type.name())); - EXPECT_EQ(ToTypeIndex(kId), actual_type); - EXPECT_EQ(ToTypeId(actual_type), kId); - EXPECT_EQ(ToTypeIndex(ToTypeId(actual_type)), actual_type); - EXPECT_EQ(ToTypeId(ToTypeIndex(kId)), kId); + EXPECT_EQ(VarTraitIdToTypeIndex(kId), actual_type); + EXPECT_EQ(TypeIndexToVarTraitId(actual_type), kId); + 
EXPECT_EQ(VarTraitIdToTypeIndex(TypeIndexToVarTraitId(actual_type)), + actual_type); + EXPECT_EQ(TypeIndexToVarTraitId(VarTraitIdToTypeIndex(kId)), kId); EXPECT_TRUE(var_id_set->count(kId) == 0); // NOLINT EXPECT_TRUE(type_index_set->count(actual_type) == 0); // NOLINT From 55a0672378329764a1b1429d9cfc8def91317e63 Mon Sep 17 00:00:00 2001 From: chengduo <zhaochengduo@baidu.com> Date: Tue, 8 Jan 2019 05:20:48 -0600 Subject: [PATCH 25/28] fix compute_75 of cuda_cmake (#15209) test=develop --- cmake/cuda.cmake | 14 +++++++++++++- 1 file changed, 13 insertions(+), 1 deletion(-) diff --git a/cmake/cuda.cmake b/cmake/cuda.cmake index 10ecdf0ea8..16432ce2b8 100644 --- a/cmake/cuda.cmake +++ b/cmake/cuda.cmake @@ -2,9 +2,11 @@ if(NOT WITH_GPU) return() endif() -set(paddle_known_gpu_archs "30 35 50 52 60 61 70 75") +set(paddle_known_gpu_archs "30 35 50 52 60 61 70") set(paddle_known_gpu_archs7 "30 35 50 52") set(paddle_known_gpu_archs8 "30 35 50 52 60 61") +set(paddle_known_gpu_archs9 "30 35 50 52 60 61 70") +set(paddle_known_gpu_archs10 "30 35 50 52 60 61 70 75") ###################################################################################### # A function for automatic detection of GPUs installed (if autodetection is enabled) @@ -155,6 +157,16 @@ elseif (${CUDA_VERSION} LESS 9.0) # CUDA 8.x # warning for now. list(APPEND CUDA_NVCC_FLAGS "-Wno-deprecated-gpu-targets") add_definitions("-DPADDLE_CUDA_BINVER=\"80\"") +elseif (${CUDA_VERSION} LESS 10.0) # CUDA 9.x + set(paddle_known_gpu_archs ${paddle_known_gpu_archs9}) + list(APPEND CUDA_NVCC_FLAGS "-D_MWAITXINTRIN_H_INCLUDED") + list(APPEND CUDA_NVCC_FLAGS "-D__STRICT_ANSI__") + add_definitions("-DPADDLE_CUDA_BINVER=\"90\"") +elseif (${CUDA_VERSION} LESS 11.0) # CUDA 10.x + set(paddle_known_gpu_archs ${paddle_known_gpu_archs10}) + list(APPEND CUDA_NVCC_FLAGS "-D_MWAITXINTRIN_H_INCLUDED") + list(APPEND CUDA_NVCC_FLAGS "-D__STRICT_ANSI__") + add_definitions("-DPADDLE_CUDA_BINVER=\"100\"") endif() include_directories(${CUDA_INCLUDE_DIRS}) From e4184008a4e4aa60fbd21d43209256ec1114186f Mon Sep 17 00:00:00 2001 From: mozga-intel <mateusz.ozga@intel.com> Date: Tue, 8 Jan 2019 16:37:03 +0100 Subject: [PATCH 26/28] PADDLE_WITH_NGRAPH was removed from the code test=develop --- paddle/fluid/operators/ngraph/ops/binary_unnary_op.h | 2 -- paddle/fluid/operators/ngraph/ops/elementwise_scalar_op.h | 2 -- paddle/fluid/operators/ngraph/ops/fill_constant_op.h | 2 -- paddle/fluid/operators/ngraph/ops/mean_op.h | 2 -- paddle/fluid/operators/ngraph/ops/mul_op.h | 2 -- paddle/fluid/operators/ngraph/ops/scale_op.h | 2 -- paddle/fluid/operators/ngraph/ops/top_k_op.h | 2 -- 7 files changed, 14 deletions(-) diff --git a/paddle/fluid/operators/ngraph/ops/binary_unnary_op.h b/paddle/fluid/operators/ngraph/ops/binary_unnary_op.h index 6610380fcf..0c0d25d0cd 100644 --- a/paddle/fluid/operators/ngraph/ops/binary_unnary_op.h +++ b/paddle/fluid/operators/ngraph/ops/binary_unnary_op.h @@ -12,7 +12,6 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. 
*/ -#ifdef PADDLE_WITH_NGRAPH #pragma once #include <string> @@ -48,4 +47,3 @@ static void BuildUnaryNode( } // namespace ngraphs } // namespace operators } // namespace paddle -#endif diff --git a/paddle/fluid/operators/ngraph/ops/elementwise_scalar_op.h b/paddle/fluid/operators/ngraph/ops/elementwise_scalar_op.h index 15fbd58b02..8f5092963c 100644 --- a/paddle/fluid/operators/ngraph/ops/elementwise_scalar_op.h +++ b/paddle/fluid/operators/ngraph/ops/elementwise_scalar_op.h @@ -12,7 +12,6 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. */ -#ifdef PADDLE_WITH_NGRAPH #pragma once #include <string> @@ -58,4 +57,3 @@ std::shared_ptr<ngraph::Node> ElementwiseScalar( } // namespace ngraphs } // namespace operators } // namespace paddle -#endif diff --git a/paddle/fluid/operators/ngraph/ops/fill_constant_op.h b/paddle/fluid/operators/ngraph/ops/fill_constant_op.h index 5eff69e7b1..406a4314f8 100644 --- a/paddle/fluid/operators/ngraph/ops/fill_constant_op.h +++ b/paddle/fluid/operators/ngraph/ops/fill_constant_op.h @@ -12,7 +12,6 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. */ -#ifdef PADDLE_WITH_NGRAPH #pragma once #include <string> @@ -58,4 +57,3 @@ void BuildFillConstantNode( } // namespace ngraphs } // namespace operators } // namespace paddle -#endif diff --git a/paddle/fluid/operators/ngraph/ops/mean_op.h b/paddle/fluid/operators/ngraph/ops/mean_op.h index 7fcf8f09cd..4c44bc4c11 100644 --- a/paddle/fluid/operators/ngraph/ops/mean_op.h +++ b/paddle/fluid/operators/ngraph/ops/mean_op.h @@ -12,7 +12,6 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. */ -#ifdef PADDLE_WITH_NGRAPH #pragma once #include <functional> @@ -65,4 +64,3 @@ void BuildMeanGradNode( } // namespace ngraphs } // namespace operators } // namespace paddle -#endif diff --git a/paddle/fluid/operators/ngraph/ops/mul_op.h b/paddle/fluid/operators/ngraph/ops/mul_op.h index 9e12e5d7c3..4a6cbebe24 100644 --- a/paddle/fluid/operators/ngraph/ops/mul_op.h +++ b/paddle/fluid/operators/ngraph/ops/mul_op.h @@ -12,7 +12,6 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. */ -#ifdef PADDLE_WITH_NGRAPH #pragma once #include <string> @@ -131,4 +130,3 @@ static void BuildMulGradNode( } // namespace ngraphs } // namespace operators } // namespace paddle -#endif diff --git a/paddle/fluid/operators/ngraph/ops/scale_op.h b/paddle/fluid/operators/ngraph/ops/scale_op.h index 24ab0702aa..91a57d0be6 100644 --- a/paddle/fluid/operators/ngraph/ops/scale_op.h +++ b/paddle/fluid/operators/ngraph/ops/scale_op.h @@ -12,7 +12,6 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. 
*/
-#ifdef PADDLE_WITH_NGRAPH
 #pragma once

 #include <string>
@@ -38,4 +37,3 @@ void BuildScaleNode(
 }  // namespace ngraphs
 }  // namespace operators
 }  // namespace paddle
-#endif
diff --git a/paddle/fluid/operators/ngraph/ops/top_k_op.h b/paddle/fluid/operators/ngraph/ops/top_k_op.h
index 2b7254497c..ea66953a12 100644
--- a/paddle/fluid/operators/ngraph/ops/top_k_op.h
+++ b/paddle/fluid/operators/ngraph/ops/top_k_op.h
@@ -12,7 +12,6 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 See the License for the specific language governing permissions and
 limitations under the License. */
-#ifdef PADDLE_WITH_NGRAPH
 #pragma once

 #include <string>
@@ -48,4 +47,3 @@ void BuildTopKNode(
 }  // namespace ngraphs
 }  // namespace operators
 }  // namespace paddle
-#endif

From a037378fdb96773f44e0c12c14d2119b7e76996a Mon Sep 17 00:00:00 2001
From: qingqing01 <dangqingqing@baidu.com>
Date: Wed, 9 Jan 2019 10:16:40 +0800
Subject: [PATCH 27/28] Fix error with cuDNN version less than 7.1. (#15219)

Since conv_fusion_op is not exposed to Python, remove the env flag
in __init__.py

test=develop
---
 python/paddle/fluid/__init__.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/python/paddle/fluid/__init__.py b/python/paddle/fluid/__init__.py
index f9f3807b15..2c17716500 100644
--- a/python/paddle/fluid/__init__.py
+++ b/python/paddle/fluid/__init__.py
@@ -155,7 +155,7 @@ def __bootstrap__():
         'fraction_of_gpu_memory_to_use', 'cudnn_deterministic',
         'enable_cublas_tensor_op_math', 'conv_workspace_size_limit',
         'cudnn_exhaustive_search', 'memory_optimize_debug', 'selected_gpus',
-        'cudnn_exhaustive_search_times', 'sync_nccl_allreduce'
+        'sync_nccl_allreduce'
     ]
     core.init_gflags([sys.argv[0]] +

From f23a257e905e61f513c2a68cdfd9fb39d8ff16db Mon Sep 17 00:00:00 2001
From: Tao Luo <luotao02@baidu.com>
Date: Wed, 9 Jan 2019 11:26:14 +0800
Subject: [PATCH 28/28] use the new MKLDNN repo url

test=develop
---
 cmake/external/mkldnn.cmake | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/cmake/external/mkldnn.cmake b/cmake/external/mkldnn.cmake
index a9b99e9ab8..03f0dee859 100644
--- a/cmake/external/mkldnn.cmake
+++ b/cmake/external/mkldnn.cmake
@@ -55,7 +55,7 @@ ExternalProject_Add(
     ${MKLDNN_PROJECT}
     ${EXTERNAL_PROJECT_LOG_ARGS}
     DEPENDS ${MKLDNN_DEPENDS}
-    GIT_REPOSITORY "https://github.com/01org/mkl-dnn.git"
+    GIT_REPOSITORY "https://github.com/intel/mkl-dnn.git"
     GIT_TAG "830a10059a018cd2634d94195140cf2d8790a75a"
     PREFIX ${MKLDNN_SOURCES_DIR}
     UPDATE_COMMAND ""
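
A note on the cuBLAS handle refactor earlier in this series: callers no longer
fetch a raw cublasHandle_t from CUDADeviceContext; they pass a callback that
runs while the handle's mutex is held, so the handle can never escape the
critical section. The sketch below shows the intended call pattern. It is
illustrative only: the CUDADeviceContext instance named dev_ctx, the free
functions, and the dynload::cublasSscal_v2 wrapper are assumptions for the
example (substitute any cuBLAS routine that dynload actually exposes, such as
cublasGemmStridedBatchedEx from the blas_impl.cu.h diff).

// Minimal usage sketch for the callback-based cuBLAS API; assumptions as
// noted above, not the upstream implementation.
#include "paddle/fluid/platform/device_context.h"
#include "paddle/fluid/platform/dynload/cublas.h"
#include "paddle/fluid/platform/enforce.h"

namespace plat = paddle::platform;

// Scale n floats on the device through the lock-guarded default-math handle.
void ScaleOnDevice(const plat::CUDADeviceContext &dev_ctx, float *d_data,
                   int n, float alpha) {
  dev_ctx.CublasCall([&](cublasHandle_t handle) {
    // CublasHandleHolder::Call() holds its mutex for the duration of this
    // lambda, so concurrent callers on the same context are serialized.
    // cublasSscal_v2 is a real cuBLAS entry point, but whether dynload
    // wraps it at this revision is an assumption of the sketch.
    PADDLE_ENFORCE(
        plat::dynload::cublasSscal_v2(handle, n, &alpha, d_data, 1));
  });
}

// Prefer the Tensor Core handle when the device supports it; per the diff,
// the call transparently falls back to the default-math handle otherwise.
void GemmOnDevice(const plat::CUDADeviceContext &dev_ctx) {
  if (dev_ctx.tensor_core_available()) {
    // Tensor Core math will be used for eligible shapes and dtypes.
  }
  dev_ctx.TensorCoreCublasCallIfAvailable([&](cublasHandle_t handle) {
    // Issue cublasGemmEx / cublasGemmStridedBatchedEx style calls here.
  });
}

The design point worth noting is that both entry points take the callback by
forwarding reference and never return the handle, which lets the
destructor-managed CublasHandleHolder own locking instead of every call site,
and is why the old per-context mutex and ScopedCublasMathMode could be deleted.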