From 03ccb9a461db7650fd1dc749f2f61a4df253bf31 Mon Sep 17 00:00:00 2001 From: Yihua Xu Date: Thu, 15 Nov 2018 16:07:16 +0800 Subject: [PATCH 1/4] Optimize the stack operator --- paddle/fluid/operators/stack_op.h | 13 ++++++++++--- 1 file changed, 10 insertions(+), 3 deletions(-) diff --git a/paddle/fluid/operators/stack_op.h b/paddle/fluid/operators/stack_op.h index d236c5b943..f1692ae956 100644 --- a/paddle/fluid/operators/stack_op.h +++ b/paddle/fluid/operators/stack_op.h @@ -147,16 +147,23 @@ class StackKernel : public framework::OpKernel { auto &dim = x[0]->dims(); for (auto i = 0; i < axis; ++i) pre *= dim[i]; for (auto i = axis; i < dim.size(); ++i) post *= dim[i]; - int total_num = pre * n * post; - auto &dev_ctx = ctx.template device_context(); #ifdef __NVCC__ thrust::device_vector device_x_vec(x_datas); auto x_data_arr = device_x_vec.data().get(); #else auto x_data_arr = x_datas.data(); #endif - StackFunctorForRange(dev_ctx, x_data_arr, y_data, total_num, n, post); + size_t x_offset = 0; + size_t y_offset = 0; + for (int i = 0; i < pre; i++) { + for (int j = 0; j < n; j++) { + std::memcpy(y_data + y_offset, x_data_arr[j] + x_offset, + post * sizeof(T)); + y_offset += post; + } + x_offset += post; + } #ifdef __NVCC__ // Wait() must be called because device_x_vec may be destructed before // kernel ends From be50670348a23b35172e2420baeb058321ab3e13 Mon Sep 17 00:00:00 2001 From: Yihua Xu Date: Tue, 20 Nov 2018 08:24:00 +0800 Subject: [PATCH 2/4] Remove the remnant code (test=develop) --- paddle/fluid/operators/stack_op.h | 27 --------------------------- 1 file changed, 27 deletions(-) diff --git a/paddle/fluid/operators/stack_op.h b/paddle/fluid/operators/stack_op.h index f1692ae956..56a12852a9 100644 --- a/paddle/fluid/operators/stack_op.h +++ b/paddle/fluid/operators/stack_op.h @@ -72,25 +72,6 @@ class StackOpMaker : public framework::OpProtoAndCheckerMaker { } }; -template -struct StackFunctor { - HOSTDEVICE StackFunctor(const VecXType &x, T *y, int n, int post) - : x_(x), y_(y), n_(n), post_(post) {} - - HOSTDEVICE void operator()(int idx) { - int i = idx / (n_ * post_); - int which_x = idx / post_ - i * n_; - int x_index = i * post_ + idx % post_; - y_[idx] = x_[which_x][x_index]; - } - - private: - VecXType x_; - T *y_; - int n_; - int post_; -}; - template struct StackGradFunctor { HOSTDEVICE StackGradFunctor(const VecDxType &dx, const T *dy, int n, int post) @@ -110,14 +91,6 @@ struct StackGradFunctor { int post_; }; -template -static inline void StackFunctorForRange(const DeviceContext &ctx, - const VecXType &x, T *y, int total_num, - int n, int post) { - platform::ForRange for_range(ctx, total_num); - for_range(StackFunctor(x, y, n, post)); -} - template static inline void StackGradFunctorForRange(const DeviceContext &ctx, const VecDxType &dx, const T *dy, From d91740acb1e49e4baaad02aeda379f27f6ec0f69 Mon Sep 17 00:00:00 2001 From: Yihua Xu Date: Tue, 20 Nov 2018 08:25:48 +0800 Subject: [PATCH 3/4] Revert "Remove the remnant code (test=develop)" This reverts commit be50670348a23b35172e2420baeb058321ab3e13. --- paddle/fluid/operators/stack_op.h | 27 +++++++++++++++++++++++++++ 1 file changed, 27 insertions(+) diff --git a/paddle/fluid/operators/stack_op.h b/paddle/fluid/operators/stack_op.h index 56a12852a9..f1692ae956 100644 --- a/paddle/fluid/operators/stack_op.h +++ b/paddle/fluid/operators/stack_op.h @@ -72,6 +72,25 @@ class StackOpMaker : public framework::OpProtoAndCheckerMaker { } }; +template +struct StackFunctor { + HOSTDEVICE StackFunctor(const VecXType &x, T *y, int n, int post) + : x_(x), y_(y), n_(n), post_(post) {} + + HOSTDEVICE void operator()(int idx) { + int i = idx / (n_ * post_); + int which_x = idx / post_ - i * n_; + int x_index = i * post_ + idx % post_; + y_[idx] = x_[which_x][x_index]; + } + + private: + VecXType x_; + T *y_; + int n_; + int post_; +}; + template struct StackGradFunctor { HOSTDEVICE StackGradFunctor(const VecDxType &dx, const T *dy, int n, int post) @@ -91,6 +110,14 @@ struct StackGradFunctor { int post_; }; +template +static inline void StackFunctorForRange(const DeviceContext &ctx, + const VecXType &x, T *y, int total_num, + int n, int post) { + platform::ForRange for_range(ctx, total_num); + for_range(StackFunctor(x, y, n, post)); +} + template static inline void StackGradFunctorForRange(const DeviceContext &ctx, const VecDxType &dx, const T *dy, From a906a361be831b9b425a9f197036fef506020857 Mon Sep 17 00:00:00 2001 From: Yihua Xu Date: Tue, 20 Nov 2018 08:30:27 +0800 Subject: [PATCH 4/4] Add the macro for NVCC (test=develop) --- paddle/fluid/operators/stack_op.h | 15 ++++++++++----- 1 file changed, 10 insertions(+), 5 deletions(-) diff --git a/paddle/fluid/operators/stack_op.h b/paddle/fluid/operators/stack_op.h index f1692ae956..3d132e4397 100644 --- a/paddle/fluid/operators/stack_op.h +++ b/paddle/fluid/operators/stack_op.h @@ -149,11 +149,20 @@ class StackKernel : public framework::OpKernel { for (auto i = axis; i < dim.size(); ++i) post *= dim[i]; #ifdef __NVCC__ + int total_num = pre * n * post; + auto &dev_ctx = ctx.template device_context(); + thrust::device_vector device_x_vec(x_datas); auto x_data_arr = device_x_vec.data().get(); + + StackFunctorForRange(dev_ctx, x_data_arr, y_data, total_num, n, post); + + // Wait() must be called because device_x_vec may be destructed before + // kernel ends + dev_ctx.Wait(); #else auto x_data_arr = x_datas.data(); -#endif + size_t x_offset = 0; size_t y_offset = 0; for (int i = 0; i < pre; i++) { @@ -164,10 +173,6 @@ class StackKernel : public framework::OpKernel { } x_offset += post; } -#ifdef __NVCC__ - // Wait() must be called because device_x_vec may be destructed before - // kernel ends - dev_ctx.Wait(); #endif } };