diff --git a/paddle/cuda/src/hl_top_k.cu b/paddle/cuda/src/hl_top_k.cu
index 4a737d5ba7..b17290557c 100644
--- a/paddle/cuda/src/hl_top_k.cu
+++ b/paddle/cuda/src/hl_top_k.cu
@@ -244,7 +244,7 @@ __device__ __forceinline__ void blockReduce(Pair* shTopK,
     if (--beamSize == 0) break;
     __syncthreads();
 
-    // temporary solution
+    // NOTE(zcd): temporary solution
     unsigned mask = 0u;
    CREATE_SHFL_MASK(mask, true);
 
diff --git a/paddle/fluid/operators/batch_norm_mkldnn_op.cc b/paddle/fluid/operators/batch_norm_mkldnn_op.cc
new file mode 100644
index 0000000000..0e4a56d4a4
--- /dev/null
+++ b/paddle/fluid/operators/batch_norm_mkldnn_op.cc
@@ -0,0 +1,325 @@
+/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License. */
+
+#include "mkldnn.hpp"
+#include "paddle/fluid/operators/batch_norm_op.h"
+#include "paddle/fluid/platform/mkldnn_helper.h"
+
+namespace paddle {
+namespace operators {
+
+using Tensor = framework::Tensor;
+using paddle::platform::MKLDNNDeviceContext;
+using paddle::platform::MKLDNNMemDesc;
+using mkldnn::memory;
+
+template <typename T>
+using EigenArrayMap =
+    Eigen::Map<Eigen::Array<T, Eigen::Dynamic, Eigen::Dynamic>>;
+template <typename T>
+using ConstEigenArrayMap =
+    Eigen::Map<const Eigen::Array<T, Eigen::Dynamic, Eigen::Dynamic>>;
+template <typename T>
+using EigenVectorArrayMap = Eigen::Map<Eigen::Array<T, Eigen::Dynamic, 1>>;
+template <typename T>
+using ConstEigenVectorArrayMap =
+    Eigen::Map<const Eigen::Array<T, Eigen::Dynamic, 1>>;
+
+namespace {
+template <typename T>
+struct bn_type_traits {
+  using op_type = T;
+  using op_desc = typename op_type::desc;
+  using op_prim = typename op_type::primitive_desc;
+};
+
+template <typename T, typename Container>
+void copy_to_weights(T scale_begin, T scale_end, T shift_begin, T shift_end,
+                     Container *c) {
+  auto it = std::begin(*c);
+
+  std::copy(scale_begin, scale_end, std::inserter(*c, it));
+  std::copy(
+      shift_begin, shift_end,
+      std::inserter(*c, std::next(it, std::distance(scale_begin, scale_end))));
+}
+
+template <typename Op, typename... Args>
+void run_batch_norm_op(Args &&... args) {
+  Op batch_norm_op{args...};
+
+  std::vector<mkldnn::primitive> pipeline;
+  pipeline.push_back(batch_norm_op);
+  mkldnn::stream(mkldnn::stream::kind::eager).submit(pipeline).wait();
+}
+
+template <typename T>
+inline void *cast_const_to_void(const T *t) {
+  return static_cast<void *>(const_cast<T *>(t));
+}
+}  // namespace
+
+template <typename T>
+class BatchNormMKLDNNOpKernel : public paddle::framework::OpKernel<T> {
+ public:
+  void Compute(const framework::ExecutionContext &ctx) const override {
+    auto data_layout_str = ctx.Attr<std::string>("data_layout");
+    auto data_layout = framework::StringToDataLayout(data_layout_str);
+    PADDLE_ENFORCE(data_layout == framework::DataLayout::kNCHW,
+                   "MKLDNN batch normalization handles only NCHW data layout");
+
+    const float epsilon = ctx.Attr<float>("epsilon");
+    const float momentum = ctx.Attr<float>("momentum");
+    const bool is_test = ctx.Attr<bool>("is_test");
+
+    const auto *x = ctx.Input<Tensor>("X");
+    const auto *mean = ctx.Input<Tensor>("Mean");
+    const auto *variance = ctx.Input<Tensor>("Variance");
+
+    auto &dev_ctx = ctx.template device_context<MKLDNNDeviceContext>();
+    auto mkldnn_engine = dev_ctx.GetEngine();
+
+    auto *y = ctx.Output<Tensor>("Y");
+    auto *mean_out = ctx.Output<Tensor>("MeanOut");
+    auto *variance_out = ctx.Output<Tensor>("VarianceOut");
+    auto *batch_mean = ctx.Output<Tensor>("SavedMean");
+    auto *batch_variance = ctx.Output<Tensor>("SavedVariance");
+
+    const auto *scale = ctx.Input<Tensor>("Scale");
+    const auto *shift = ctx.Input<Tensor>("Bias");
+
+    y->mutable_data<T>(ctx.GetPlace());
+    mean_out->mutable_data<T>(ctx.GetPlace());
+    variance_out->mutable_data<T>(ctx.GetPlace());
+
+    if (!is_test) {
+      batch_mean->mutable_data<T>(ctx.GetPlace());
+      batch_variance->mutable_data<T>(ctx.GetPlace());
+    }
+
+    auto propagation = is_test == true ? mkldnn::prop_kind::forward_scoring
+                                       : mkldnn::prop_kind::forward_training;
+
+    auto dims = paddle::framework::vectorize2int(x->dims());
+
+    auto src_md =
+        MKLDNNMemDesc(dims, memory::data_type::f32, memory::format::nchw);
+    auto dst_md =
+        MKLDNNMemDesc(dims, memory::data_type::f32, memory::format::nchw);
+
+    auto src_pd = mkldnn::memory::primitive_desc{src_md, mkldnn_engine};
+    auto dst_pd = mkldnn::memory::primitive_desc{dst_md, mkldnn_engine};
+
+    auto src = mkldnn::memory{src_pd, cast_const_to_void(x->data<T>())};
+    auto dst = mkldnn::memory{dst_pd, y->data<T>()};
+
+    unsigned flags = mkldnn::use_scale_shift;
+    if (is_test) flags |= mkldnn::use_global_stats;
+
+    using bn_fwd_types = bn_type_traits<mkldnn::batch_normalization_forward>;
+    auto batch_norm_fwd_desc =
+        bn_fwd_types::op_desc{propagation, src_md, epsilon, flags};
+    auto batch_norm_fwd_pd =
+        bn_fwd_types::op_prim{batch_norm_fwd_desc, mkldnn_engine};
+
+    const unsigned int ic = dims[1];
+
+    // MKLDNN requires a single piece of memory for scale and shift/bias data
+    const size_t scaleshift_size = 2 * ic;
+    std::vector<T> scaleshift_data;
+    scaleshift_data.reserve(scaleshift_size);
+
+    copy_to_weights(scale->data<T>(), scale->data<T>() + ic, shift->data<T>(),
+                    shift->data<T>() + ic, &scaleshift_data);
+
+    auto scaleshift_memory = mkldnn::memory{
+        batch_norm_fwd_pd.weights_primitive_desc(), scaleshift_data.data()};
+
+    if (is_test) {
+      auto mean_memory =
+          mkldnn::memory{batch_norm_fwd_pd.mean_primitive_desc(),
+                         cast_const_to_void(mean->data<T>())};
+
+      auto variance_memory =
+          mkldnn::memory{batch_norm_fwd_pd.variance_primitive_desc(),
+                         cast_const_to_void(variance->data<T>())};
+
+      run_batch_norm_op<typename bn_fwd_types::op_type>(
+          batch_norm_fwd_pd, src, (const mkldnn::primitive::at &)mean_memory,
+          (const mkldnn::primitive::at &)variance_memory, scaleshift_memory,
+          dst);
+    } else {
+      auto mean_memory =
+          mkldnn::memory{batch_norm_fwd_pd.mean_primitive_desc(),
+                         cast_const_to_void(batch_mean->data<T>())};
+
+      auto variance_memory =
+          mkldnn::memory{batch_norm_fwd_pd.variance_primitive_desc(),
+                         cast_const_to_void(batch_variance->data<T>())};
+
+      run_batch_norm_op<typename bn_fwd_types::op_type>(batch_norm_fwd_pd, src,
+                                                        scaleshift_memory, dst,
+                                                        mean_memory,
+                                                        variance_memory);
+    }
+
+    if (!is_test) {
+      const unsigned int in = dims[0];
+      const unsigned int sample_size = x->numel() / in / ic;
+
+      // saved_xx is used just for this batch of data
+      EigenVectorArrayMap<T> saved_mean_e(
+          batch_mean->mutable_data<T>(ctx.GetPlace()), ic);
+      EigenVectorArrayMap<T> saved_variance_e(
+          batch_variance->mutable_data<T>(ctx.GetPlace()), ic);
+      saved_mean_e.setZero();
+      saved_variance_e.setZero();
+
+      const unsigned int x_arr_size = in * ic;
+      ConstEigenArrayMap<T> x_arr(x->data<T>(), sample_size, x_arr_size);
+      for (unsigned int nc = 0; nc < x_arr_size; ++nc) {
+        saved_mean_e(nc % ic) += x_arr.col(nc).sum();
+      }
+      saved_mean_e /= in * sample_size;
+      for (unsigned int nc = 0; nc < x_arr_size; ++nc) {
+        saved_variance_e(nc % ic) +=
+            (x_arr.col(nc) - saved_mean_e(nc % ic)).matrix().squaredNorm();
+      }
+      saved_variance_e /= in * sample_size;
+
+      ConstEigenVectorArrayMap<T> mean_arr{mean->data<T>(), ic};
+      ConstEigenVectorArrayMap<T> variance_arr{variance->data<T>(), ic};
+
+      EigenVectorArrayMap<T> running_mean_arr(
+          mean_out->mutable_data<T>(ctx.GetPlace()), ic);
+      EigenVectorArrayMap<T> running_var_arr(
+          variance_out->mutable_data<T>(ctx.GetPlace()), ic);
+
+      auto one_minus_momentum = 1. - momentum;
+      running_mean_arr =
+          mean_arr * momentum + saved_mean_e * one_minus_momentum;
+      running_var_arr =
+          variance_arr * momentum + saved_variance_e * one_minus_momentum;
+    }
+  }
+};
+
+template <typename T>
+class BatchNormMKLDNNGradOpKernel : public paddle::framework::OpKernel<T> {
+ public:
+  void Compute(const paddle::framework::ExecutionContext &ctx) const override {
+    auto data_layout_str = ctx.Attr<std::string>("data_layout");
+    auto data_layout = framework::StringToDataLayout(data_layout_str);
+    PADDLE_ENFORCE(data_layout == framework::DataLayout::kNCHW,
+                   "MKLDNN batch normalization handles only NCHW data layout");
+
+    auto &dev_ctx = ctx.template device_context<MKLDNNDeviceContext>();
+    auto mkldnn_engine = dev_ctx.GetEngine();
+
+    const float epsilon = ctx.Attr<float>("epsilon");
+
+    const auto *x = ctx.Input<Tensor>("X");
+    const auto *scale = ctx.Input<Tensor>("Scale");
+    const auto *shift = ctx.Input<Tensor>("Bias");
+    const auto *batch_mean = ctx.Input<Tensor>("SavedMean");
+    const auto *batch_variance = ctx.Input<Tensor>("SavedVariance");
+
+    const auto *diff_y = ctx.Input<Tensor>(framework::GradVarName("Y"));
+    auto *diff_x = ctx.Output<Tensor>(framework::GradVarName("X"));
+    auto *diff_scale = ctx.Output<Tensor>(framework::GradVarName("Scale"));
+    auto *diff_shift = ctx.Output<Tensor>(framework::GradVarName("Bias"));
+
+    diff_x->mutable_data<T>(ctx.GetPlace());
+    diff_scale->mutable_data<T>(ctx.GetPlace());
+    diff_shift->mutable_data<T>(ctx.GetPlace());
+
+    auto dims = paddle::framework::vectorize2int(x->dims());
+    unsigned flags = mkldnn::use_scale_shift | !mkldnn::use_global_stats;
+
+    auto src_md =
+        MKLDNNMemDesc(dims, memory::data_type::f32, memory::format::nchw);
+    auto dst_md =
+        MKLDNNMemDesc(dims, memory::data_type::f32, memory::format::nchw);
+    auto diff_src_md =
+        MKLDNNMemDesc(dims, memory::data_type::f32, memory::format::nchw);
+    auto diff_dst_md =
+        MKLDNNMemDesc(dims, memory::data_type::f32, memory::format::nchw);
+
+    using bn_bwd_types = bn_type_traits<mkldnn::batch_normalization_backward>;
+    using bn_fwd_types = bn_type_traits<mkldnn::batch_normalization_forward>;
+
+    auto batch_norm_fwd_desc = bn_fwd_types::op_desc{
+        mkldnn::prop_kind::forward_training, src_md, epsilon, flags};
+    auto batch_norm_fwd_pd =
+        bn_fwd_types::op_prim{batch_norm_fwd_desc, mkldnn_engine};
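+
+    // A forward primitive descriptor is rebuilt here because MKLDNN creates
+    // the backward batch-norm primitive descriptor with the matching forward
+    // primitive descriptor as a hint.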
+    auto batch_norm_bwd_desc = bn_bwd_types::op_desc{
+        mkldnn::prop_kind::backward, diff_dst_md, dst_md, epsilon, flags};
+    auto batch_norm_bwd_pd = bn_bwd_types::op_prim{
+        batch_norm_bwd_desc, mkldnn_engine, batch_norm_fwd_pd};
+
+    auto src = mkldnn::memory{{src_md, mkldnn_engine},
+                              cast_const_to_void(x->data<T>())};
+
+    auto mean = mkldnn::memory{batch_norm_bwd_pd.mean_primitive_desc(),
+                               cast_const_to_void(batch_mean->data<T>())};
+
+    auto variance =
+        mkldnn::memory{batch_norm_bwd_pd.variance_primitive_desc(),
+                       cast_const_to_void(batch_variance->data<T>())};
+
+    auto diff_dst = mkldnn::memory{{diff_dst_md, mkldnn_engine},
+                                   cast_const_to_void(diff_y->data<T>())};
+
+    const unsigned int ic = dims[1];
+
+    const size_t scaleshift_size = 2 * ic;
+
+    std::vector<T> scaleshift_data;
+    scaleshift_data.reserve(scaleshift_size);
+    copy_to_weights(scale->data<T>(), scale->data<T>() + ic, shift->data<T>(),
+                    shift->data<T>() + ic, &scaleshift_data);
+
+    auto scaleshift_memory = mkldnn::memory{
+        batch_norm_bwd_pd.weights_primitive_desc(), scaleshift_data.data()};
+
+    std::vector<T> diff_scaleshift_data;
+    diff_scaleshift_data.reserve(scaleshift_size);
+    copy_to_weights(diff_scale->data<T>(), diff_scale->data<T>() + ic,
+                    diff_shift->data<T>(), diff_shift->data<T>() + ic,
+                    &diff_scaleshift_data);
+
+    auto diff_scaleshift_memory =
+        mkldnn::memory{batch_norm_bwd_pd.diff_weights_primitive_desc(),
+                       diff_scaleshift_data.data()};
+
+    auto diff_src = mkldnn::memory{{diff_src_md, mkldnn_engine},
+                                   static_cast<void *>(diff_x->data<T>())};
+
+    run_batch_norm_op<typename bn_bwd_types::op_type>(
+        batch_norm_bwd_pd, src, mean, variance, diff_dst, scaleshift_memory,
+        diff_src, diff_scaleshift_memory);
+
+    auto it = std::begin(diff_scaleshift_data);
+    std::copy(it, std::next(it, ic), diff_scale->data<T>());
+    std::copy(std::next(it, ic), std::end(diff_scaleshift_data),
+              diff_shift->data<T>());
+  }
+};
+}  // namespace operators
+}  // namespace paddle
+
+namespace ops = paddle::operators;
+REGISTER_OP_KERNEL(batch_norm, MKLDNN, paddle::platform::CPUPlace,
+                   ops::BatchNormMKLDNNOpKernel<float>);
+REGISTER_OP_KERNEL(batch_norm_grad, MKLDNN, paddle::platform::CPUPlace,
+                   ops::BatchNormMKLDNNGradOpKernel<float>);
diff --git a/paddle/fluid/operators/batch_norm_op.cc b/paddle/fluid/operators/batch_norm_op.cc
index f8b2505ccf..b4bd40d031 100644
--- a/paddle/fluid/operators/batch_norm_op.cc
+++ b/paddle/fluid/operators/batch_norm_op.cc
@@ -15,6 +15,9 @@ limitations under the License. */
 #include "paddle/fluid/operators/batch_norm_op.h"
 #include <string>
 #include "paddle/fluid/framework/data_layout.h"
+#ifdef PADDLE_WITH_MKLDNN
+#include "paddle/fluid/platform/mkldnn_helper.h"
+#endif
 
 namespace paddle {
 namespace operators {
@@ -106,7 +109,18 @@ class BatchNormOp : public framework::OperatorWithKernel {
     PADDLE_ENFORCE_EQ(bn_param_type, framework::ToDataType(
                                          ctx.Input<Tensor>("Variance")->type()),
                       "Variance input should be of float type");
-    return framework::OpKernelType(input_data_type, ctx.GetPlace());
+
+    framework::LibraryType library_{framework::LibraryType::kPlain};
+#ifdef PADDLE_WITH_MKLDNN
+    if (library_ == framework::LibraryType::kPlain &&
+        platform::CanMKLDNNBeUsed(ctx)) {
+      library_ = framework::LibraryType::kMKLDNN;
+    }
+#endif
+    // TODO(pzelazko-intel): enable MKLDNN layout when it's ready
+    framework::DataLayout layout = framework::DataLayout::kAnyLayout;
+    return framework::OpKernelType(input_data_type, ctx.GetPlace(), layout,
+                                   library_);
   }
 };
 
@@ -151,6 +165,9 @@ class BatchNormOpMaker : public framework::OpProtoAndCheckerMaker {
              "Variance of the current mini batch, "
              "will apply to output when training")
        .AsIntermediate();
+    AddAttr<bool>("use_mkldnn",
+                  "(bool, default false) Only used in mkldnn kernel")
+        .SetDefault(false);
    AddComment(R"DOC(
 Batch Normalization.
 
@@ -349,8 +366,19 @@ class BatchNormGradOp : public framework::OperatorWithKernel {
     if (t == nullptr) {
       PADDLE_THROW("can't find Y@GRAD");
     }
-    return framework::OpKernelType(framework::ToDataType(t->type()),
-                                   ctx.GetPlace());
+
+    framework::LibraryType library_{framework::LibraryType::kPlain};
+#ifdef PADDLE_WITH_MKLDNN
+    if (library_ == framework::LibraryType::kPlain &&
+        platform::CanMKLDNNBeUsed(ctx)) {
+      library_ = framework::LibraryType::kMKLDNN;
+    }
+#endif
+    // TODO(pzelazko-intel): enable MKLDNN layout when it's ready
+    framework::DataLayout layout = framework::DataLayout::kAnyLayout;
+    return framework::OpKernelType(
+        framework::ToDataType(ctx.Input<Tensor>("X")->type()), ctx.GetPlace(),
+        layout, library_);
   }
 };
 
@@ -474,6 +502,7 @@ class BatchNormGradMaker : public framework::SingleGradOpDescMaker {
 
     op->SetInput(framework::GradVarName("Y"), OutputGrad("Y"));
     op->SetInput("Scale", Input("Scale"));
+    op->SetInput("Bias", Input("Bias"));
     op->SetInput("SavedMean", Output("SavedMean"));
     op->SetInput("SavedVariance", Output("SavedVariance"));
 
diff --git a/paddle/fluid/operators/elementwise_op_function.h b/paddle/fluid/operators/elementwise_op_function.h
index 953aedc850..8b052611f8 100644
--- a/paddle/fluid/operators/elementwise_op_function.h
+++ b/paddle/fluid/operators/elementwise_op_function.h
@@ -22,6 +22,7 @@ limitations under the License. */
 #ifdef __NVCC__
 #include <cuda.h>
 #include <thrust/iterator/iterator_adaptor.h>
+#include "paddle/fluid/platform/cuda_device_function.h"
 #include "paddle/fluid/platform/cuda_primitives.h"
 constexpr int ELEMWISE_MAX_BLOCK_DIM = 1024;
 #endif
@@ -336,43 +337,6 @@ static void ElemwiseGradBroadcast1CPU(const T* x, const T* y, const T* out,
 }
 
 #ifdef __NVCC__
-
-template <typename T>
-__device__ T reduceSum(T val, int tid, int len) {
-  // NOTE(zcd): The warp size should be taken from the
-  // parameters of the GPU but not specified as 32 simply.
-  // To make the reduceSum more efficiently,
-  // I use Warp-Level Parallelism and assume the Warp size
-  // is 32 which may be different for different GPU,
-  // but most card's warp size is 32.
-  const int warpSize = 32;
-  __shared__ T shm[warpSize];
-  unsigned mask = 0u;
-  CREATE_SHFL_MASK(mask, tid < len);
-
-  for (int offset = warpSize / 2; offset > 0; offset /= 2)
-    val += platform::__shfl_down_sync(mask, val, offset);
-
-  if (tid < warpSize) shm[tid] = 0;
-
-  __syncthreads();
-
-  if (tid % warpSize == 0) {
-    shm[tid / warpSize] = val;
-  }
-  __syncthreads();
-
-  CREATE_SHFL_MASK(mask, tid < warpSize);
-
-  if (tid < warpSize) {
-    val = shm[tid];
-    for (int offset = warpSize / 2; offset > 0; offset /= 2)
-      val += platform::__shfl_down_sync(mask, val, offset);
-  }
-
-  return val;
-}
-
 template <typename T, typename DX_OP, typename DY_OP>
 static __global__ void ElemwiseGradBroadcast1CUDAKernel(
     const T* x, const T* y, const T* out, const T* dout, int h, int w,
@@ -395,7 +359,7 @@ static __global__ void ElemwiseGradBroadcast1CUDAKernel(
 
   if (dy) {
     h = h > ELEMWISE_MAX_BLOCK_DIM ? ELEMWISE_MAX_BLOCK_DIM : h;
-    val = reduceSum(val, tid, h);
+    val = paddle::platform::reduceSum(val, tid, h);
     if (threadIdx.x == 0) {
       dy[j] = val;
     }
@@ -472,7 +436,7 @@ static __global__ void ElemwiseGradBroadcast2CUDAKernel(
   if (dy) {
     int h = pre * post;
     h = h > ELEMWISE_MAX_BLOCK_DIM ? ELEMWISE_MAX_BLOCK_DIM : h;
-    val = reduceSum(val, tid, h);
+    val = paddle::platform::reduceSum(val, tid, h);
     if (threadIdx.x == 0) {
       dy[j] = val;
     }
diff --git a/paddle/fluid/operators/math/cross_entropy.cu b/paddle/fluid/operators/math/cross_entropy.cu
index 6d2ba2bd0d..0de58d5fdd 100644
--- a/paddle/fluid/operators/math/cross_entropy.cu
+++ b/paddle/fluid/operators/math/cross_entropy.cu
@@ -13,6 +13,7 @@ See the License for the specific language governing permissions and
 limitations under the License. */
 
 #include "paddle/fluid/operators/math/cross_entropy.h"
+#include "paddle/fluid/platform/cuda_device_function.h"
 #include "paddle/fluid/platform/cuda_primitives.h"
 
 namespace paddle {
@@ -30,66 +31,22 @@ __global__ void CrossEntropyKernel(T* Y, const T* X, const int64_t* label,
   }
 }
 
-template <typename T>
-__device__ __forceinline__ T sum_single_warp(T val) {
-  val += platform::__shfl_down_sync(0, val, 16);
-  val += platform::__shfl_down_sync(0, val, 8);
-  val += platform::__shfl_down_sync(0, val, 4);
-  val += platform::__shfl_down_sync(0, val, 2);
-  val += platform::__shfl_down_sync(0, val, 1);
-  return val;
-}
-
-// CUDA do not support dynamic arrary in template
-// https://stackoverflow.com/questions/20497209
-template <typename T>
-struct SharedMemory {
-  // Ensure that we won't compile any un-specialized types
-  __device__ T* GetPointer() { return NULL; }
-};
-
-template <>
-struct SharedMemory<float> {
-  __device__ float* GetPointer() {
-    extern __shared__ float s_float[];
-    return s_float;
-  }
-};
-
-template <>
-struct SharedMemory<double> {
-  __device__ double* GetPointer() {
-    extern __shared__ double s_double[];
-    return s_double;
-  }
-};
-
 template <typename T>
 __global__ void SoftCrossEntropyKernel(T* Y, const T* X, const T* label,
                                        const int class_num) {
   int tid = threadIdx.x;
-  SharedMemory<T> d_sum_shared;
-  T* d_sum = d_sum_shared.GetPointer();
-  d_sum[tid] = 0;
+  T val = 0;
 
-  int cur_idx = tid;
-  int next_idx = blockIdx.x * class_num + tid;
-  while (cur_idx < class_num) {
-    d_sum[tid] +=
-        math::TolerableValue<T>()(std::log(X[next_idx])) * label[next_idx];
-    next_idx += blockDim.x;
-    cur_idx += blockDim.x;
+  int idx = blockIdx.x * class_num + tid;
+  int end = blockIdx.x * class_num + class_num;
+  for (; idx < end; idx += blockDim.x) {
+    val += math::TolerableValue<T>()(std::log(X[idx])) * label[idx];
   }
 
-  __syncthreads();
-  for (unsigned int stride = blockDim.x >> 1; stride >= 32; stride >>= 1) {
-    if (tid < stride) d_sum[tid] += d_sum[tid + stride];
-    __syncthreads();
+  val = paddle::platform::reduceSum(val, tid, blockDim.x);
+  if (threadIdx.x == 0) {
+    Y[blockIdx.x] = -val;
   }
-
-  T val = d_sum[tid];
-  val = sum_single_warp<T>(val);
-  if (tid == 0) Y[blockIdx.x] = -val;
 }
 }  // namespace
 
@@ -113,9 +70,7 @@ class CrossEntropyFunctor<platform::CUDADeviceContext, T> {
                     ? 512
                     : pow(2, static_cast<int>(std::log2(class_num)));
 
-    SoftCrossEntropyKernel<T><<<
-        batch_size, block, block * sizeof(T),
-        reinterpret_cast<const platform::CUDADeviceContext&>(ctx).stream()>>>(
+    SoftCrossEntropyKernel<T><<<batch_size, block, 0, ctx.stream()>>>(
         loss_data, prob_data, label_data, class_num);
   } else {
     const int64_t* label_data = labels->data<int64_t>();
diff --git a/paddle/fluid/operators/row_conv_op.cu b/paddle/fluid/operators/row_conv_op.cu
index dd8e62aca4..79d08cf3d1 100644
--- a/paddle/fluid/operators/row_conv_op.cu
+++ b/paddle/fluid/operators/row_conv_op.cu
@@ -14,7 +14,7 @@ limitations under the License. */
 
 #include "paddle/fluid/operators/math/math_function.h"
 #include "paddle/fluid/operators/row_conv_op.h"
-#include "paddle/fluid/platform/cuda_primitives.h"
+#include "paddle/fluid/platform/cuda_device_function.h"
 
 namespace paddle {
 namespace operators {
diff --git a/paddle/fluid/operators/top_k_op.cu b/paddle/fluid/operators/top_k_op.cu
index a2e3973fe8..e447e78b49 100644
--- a/paddle/fluid/operators/top_k_op.cu
+++ b/paddle/fluid/operators/top_k_op.cu
@@ -15,7 +15,7 @@ limitations under the License. */
 #include "paddle/fluid/framework/op_registry.h"
 #include "paddle/fluid/operators/top_k_op.h"
 #include "paddle/fluid/platform/assert.h"
-#include "paddle/fluid/platform/cuda_primitives.h"
+#include "paddle/fluid/platform/cuda_device_function.h"
 
 namespace paddle {
 namespace operators {
@@ -236,12 +236,13 @@ __device__ __forceinline__ void BlockReduce(Pair<T>* sh_topk, int* maxid,
         sh_topk[tid] = topk[*beam];
       }
     }
-    // temporary solution
+    // NOTE(zcd): temporary solution
     unsigned mask = 0u;
     CREATE_SHFL_MASK(mask, true);
 
     if (maxid[0] / 32 == warp) {
-      if (__shfl_sync(mask, *beam, (maxid[0]) % 32, 32) == MaxLength) break;
+      if (platform::__shfl_sync(mask, *beam, (maxid[0]) % 32, 32) == MaxLength)
+        break;
     }
   }
 }
diff --git a/paddle/fluid/platform/cuda_device_function.h b/paddle/fluid/platform/cuda_device_function.h
new file mode 100644
index 0000000000..7cfeaab35b
--- /dev/null
+++ b/paddle/fluid/platform/cuda_device_function.h
@@ -0,0 +1,74 @@
+/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License. */
+
+#pragma once
+#include <cuda.h>
+
+namespace paddle {
+namespace platform {
+
+// __shfl_down and __shfl have been deprecated as of CUDA 9.0.
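+// Below CUDA 9.0 the *_sync wrappers forward to the legacy intrinsics and the
+// mask argument is ignored (CREATE_SHFL_MASK then yields 0u); from CUDA 9.0
+// on, CREATE_SHFL_MASK builds a real ballot mask for the *_sync intrinsics.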
+#if CUDA_VERSION < 9000
+template <typename T>
+__forceinline__ __device__ T __shfl_down_sync(unsigned, T val, int delta) {
+  return __shfl_down(val, delta);
+}
+
+template <typename T>
+__forceinline__ __device__ T __shfl_sync(unsigned, T val, int src_line,
+                                         int width) {
+  return __shfl(val, src_line, width);
+}
+#define CREATE_SHFL_MASK(mask, predicate) mask = 0u;
+#else
+#define FULL_WARP_MASK 0xFFFFFFFF
+#define CREATE_SHFL_MASK(mask, predicate) \
+  mask = __ballot_sync(FULL_WARP_MASK, (predicate))
+#endif
+
+template <typename T>
+__device__ T reduceSum(T val, int tid, int len) {
+  // NOTE(zcd): The warp size should be taken from the
+  // parameters of the GPU rather than simply specified as 32.
+  // To make reduceSum more efficient, Warp-Level
+  // Parallelism is used here, assuming a warp size of 32,
+  // which may differ between GPUs, but most cards' warp
+  // size is 32.
+  const int warpSize = 32;
+  __shared__ T shm[warpSize];
+  unsigned mask = 0u;
+  CREATE_SHFL_MASK(mask, tid < len);
+
+  for (int offset = warpSize / 2; offset > 0; offset /= 2)
+    val += platform::__shfl_down_sync(mask, val, offset);
+
+  if (tid < warpSize) shm[tid] = 0;
+
+  if (tid % warpSize == 0) {
+    shm[tid / warpSize] = val;
+  }
+  __syncthreads();
+
+  CREATE_SHFL_MASK(mask, tid < warpSize);
+
+  if (tid < warpSize) {
+    val = shm[tid];
+    for (int offset = warpSize / 2; offset > 0; offset /= 2)
+      val += platform::__shfl_down_sync(mask, val, offset);
+  }
+  return val;
+}
+
+}  // namespace platform
+}  // namespace paddle
diff --git a/paddle/fluid/platform/cuda_primitives.h b/paddle/fluid/platform/cuda_primitives.h
index 0f6e6159b6..d535ed2f89 100644
--- a/paddle/fluid/platform/cuda_primitives.h
+++ b/paddle/fluid/platform/cuda_primitives.h
@@ -65,26 +65,5 @@ CUDA_ATOMIC_WRAPPER(Add, double) {
   return __longlong_as_double(old);
 }
 #endif
-
-// __shfl_down has been deprecated as of CUDA 9.0.
-#if CUDA_VERSION < 9000
-template <typename T>
-__forceinline__ __device__ T __shfl_down_sync(unsigned, T val, int delta) {
-  return __shfl_down(val, delta);
-}
-
-template <typename T>
-__forceinline__ __device__ T __shfl_sync(unsigned, T val, int src_line,
-                                         int width) {
-  return __shfl(val, src_line, width);
-}
-
-#define CREATE_SHFL_MASK(mask, predicate) mask = 0u;
-#else
-#define FULL_WARP_MASK 0xFFFFFFFF
-#define CREATE_SHFL_MASK(mask, predicate) \
-  mask = __ballot_sync(FULL_WARP_MASK, (predicate))
-#endif
-
 }  // namespace platform
 }  // namespace paddle
diff --git a/paddle/scripts/paddle_build.sh b/paddle/scripts/paddle_build.sh
index 53455fd860..1595cc9e8a 100755
--- a/paddle/scripts/paddle_build.sh
+++ b/paddle/scripts/paddle_build.sh
@@ -40,6 +40,7 @@ function print_usage() {
     ${BLUE}capi${NONE}: generate paddle CAPI package
     ${BLUE}fluid_inference_lib${NONE}: deploy fluid inference library
     ${BLUE}check_style${NONE}: run code style check
+    ${BLUE}cicheck${NONE}: run CI tasks
     "
 }
@@ -453,6 +454,8 @@ function gen_capi_package() {
 }
 
 function gen_fluid_inference_lib() {
+    mkdir -p ${PADDLE_ROOT}/build
+    cd ${PADDLE_ROOT}/build
     if [ ${WITH_C_API:-OFF} == "OFF" ] ; then
         cat <<EOF
     ========================================
diff --git a/python/paddle/fluid/trainer.py b/python/paddle/fluid/trainer.py
--- a/python/paddle/fluid/trainer.py
+++ b/python/paddle/fluid/trainer.py
@@ ... @@ class Trainer(object):
+    def train(self,
+              num_epochs,
+              event_handler,
+              reader=None,
+              parallel=False,
+              feed_order=None):
+        """
+        Train the model.
+
+        Args:
+            num_epochs: The number of epochs. An epoch will process all data in reader
+            event_handler: The event handler. A function with type (ev:Event)->void
+            reader:
+            parallel: True if use multi-CPUs or multi-GPUs
+            feed_order: Feeding order of reader. None will follow the defining
+                order in the program
+
+        Returns:
+
+        """
+        if parallel:
+            raise NotImplementedError(
+                "Parallel Executor version of trainer is not implemented")
+
+        self._train_by_executor(num_epochs, event_handler, reader, feed_order)
 
     def test(self, reader):
         pass
+
+    def _get_scope_from_params(self, params):
+        """
+        Get Scope from parameter object.
+        Args:
+            params(Parameter|None): The parameter object instance. Could be None.
+
+        Returns: New scope if params is None. Or params.scope()
+        NOTE: This method is WIP. Not fully implemented.
+        """
+        if params is None:
+            return core.Scope()  # new scope when params is None
+        else:
+            raise NotImplementedError("Not implemented right now.")
+
+    @staticmethod
+    def _check_and_get_place(place):
+        """
+        Check the type of place or get the default place
+        Args:
+            place(None|core.CUDAPlace|core.CPUPlace): the place that trainer will be executed on.
+
+        Raises:
+            TypeError if the type mismatches.
+
+        Returns:
+            the original place if it is not None.
+            if fluid is compiled with CUDA, returns CUDAPlace(0) by default.
+            Otherwise returns CPUPlace by default.
+        """
+        if place is None:
+            if core.is_compiled_with_cuda():
+                return core.CUDAPlace(0)
+            else:
+                return core.CPUPlace()
+        else:
+            if not isinstance(place, core.CUDAPlace) and not isinstance(
+                    place, core.CPUPlace):
+                raise TypeError("Place should be either CUDAPlace or CPUPlace")
+            return place
+
+    @contextlib.contextmanager
+    def _prog_and_scope_guard(self):
+        with framework.program_guard(
+                main_program=self.train_program,
+                startup_program=self.startup_program):
+            with executor.scope_guard(self.scope):
+                yield
+
+    def _train_by_executor(self, num_epochs, event_handler, reader, feed_order):
+        """
+        Train by Executor and single device.
+
+        Args:
+            num_epochs:
+            event_handler:
+            reader:
+            feed_order:
+
+        Returns:
+
+        """
+        with self._prog_and_scope_guard():
+            exe = executor.Executor(self.place)
+            if feed_order is None:
+                feed_var_list = [
+                    var
+                    for var in self.train_program.global_block(
+                    ).vars.itervalues()
+                    if hasattr(var, 'is_data') and var.is_data
+                ]
+            else:
+                feed_var_list = [
+                    self.train_program.global_block().var(var_name)
+                    for var_name in feed_order
+                ]
+
+            feeder = data_feeder.DataFeeder(
+                feed_list=feed_var_list, place=self.place)
+            for epoch_id in range(num_epochs):
+                event_handler(BeginEpochEvent(epoch_id))
+                for step_id, data in enumerate(reader()):
+                    event_handler(BeginStepEvent(epoch_id, step_id))
+                    exe.run(feed=feeder.feed(data), fetch_list=[])
+                    event_handler(EndStepEvent(epoch_id, step_id))
+                event_handler(EndEpochEvent(epoch_id))
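
Usage note: a minimal sketch of how the event loop added to trainer.py above is driven. Only train() and the Begin/End Epoch/Step event flow are taken from the patch; the Trainer constructor, the reader, and the feed variable names below are assumptions, since they sit outside this diff.

    import paddle.fluid as fluid

    def event_handler(event):
        # _train_by_executor emits BeginEpochEvent, then Begin/EndStepEvent
        # around each exe.run(...) per batch, then EndEpochEvent.
        print(type(event).__name__)

    trainer = fluid.Trainer(...)  # hypothetical; __init__ is not in this diff
    trainer.train(
        num_epochs=2,
        event_handler=event_handler,
        reader=train_reader,         # assumed: a batched data reader
        feed_order=['x', 'label'])   # assumed names of the program's data vars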