|
|
|
@ -13,14 +13,15 @@ See the License for the specific language governing permissions and
|
|
|
|
|
limitations under the License. */
|
|
|
|
|
|
|
|
|
|
#pragma once
|
|
|
|
|
#include <algorithm>
|
|
|
|
|
#include "paddle/fluid/framework/eigen.h"
|
|
|
|
|
#include "paddle/fluid/framework/op_registry.h"
|
|
|
|
|
#include "paddle/fluid/framework/operator.h"
|
|
|
|
|
#include "paddle/fluid/platform/transform.h"
|
|
|
|
|
|
|
|
|
|
#ifdef __NVCC__
|
|
|
|
|
#include <cuda.h>
|
|
|
|
|
#include <thrust/iterator/iterator_adaptor.h>
|
|
|
|
|
#include "paddle/fluid/platform/cuda_helper.h"
|
|
|
|
|
constexpr int ELEMWISE_MAX_BLOCK_DIM = 1024;
|
|
|
|
|
#endif
|
|
|
|
|
|
|
|
|
@ -43,35 +44,35 @@ namespace operators {
|
|
|
|
|
*/
|
|
|
|
|
// Splits x_dims into three extents around the span that y_dims is matched
// against, starting at `axis`:
//   *pre  = product of x_dims[0 .. axis)
//   *n    = product of y_dims (each entry must equal x_dims[axis + i])
//   *post = product of x_dims[axis + y_dims.size() .. x_dims.size())
// Results are written through the output pointers; all three are reset to 1
// first, so an empty range contributes a factor of 1.
// Raises via PADDLE_ENFORCE_EQ with "Broadcast dimension mismatch." when an
// entry of y_dims differs from the corresponding entry of x_dims.
inline void get_mid_dims(const framework::DDim& x_dims,
                         const framework::DDim& y_dims, const int axis,
                         int* pre, int* n, int* post) {
  *pre = 1;
  *n = 1;
  *post = 1;

  // Leading dimensions of x that precede the matched span.
  for (int i = 0; i < axis; ++i) {
    (*pre) *= x_dims[i];
  }

  // Dimensions shared by x and y; they must agree element by element.
  for (int i = 0; i < y_dims.size(); ++i) {
    PADDLE_ENFORCE_EQ(x_dims[i + axis], y_dims[i],
                      "Broadcast dimension mismatch.");
    (*n) *= y_dims[i];
  }

  // Trailing dimensions of x that follow the matched span.
  for (int i = axis + y_dims.size(); i < x_dims.size(); ++i) {
    (*post) *= x_dims[i];
  }
}
|
|
|
|
|
|
|
|
|
|
// Strips trailing dimensions equal to 1 from *dims in place.
// If every dimension is 1, the result has zero dimensions. *dims is only
// reassigned when at least one trailing dimension was actually removed.
inline void trim_trailing_singular_dims(framework::DDim* dims) {
  // Remove trailing dimensions of size 1 for y
  auto actual_dims_size = dims->size();
  // Walk backwards until the first dimension that is not 1 (or until empty).
  for (; actual_dims_size != 0; --actual_dims_size) {
    if ((*dims)[actual_dims_size - 1] != 1) break;
  }
  if (actual_dims_size != dims->size()) {
    // Rebuild the DDim from the kept prefix of the dimension list.
    auto actual_dims = framework::vectorize(*dims);
    actual_dims.resize(actual_dims_size);
    *dims = framework::make_ddim(actual_dims);
  }
}
|
|
|
|
|
|
|
|
|
@ -159,7 +160,7 @@ class RowwiseTransformIterator<T, platform::CUDADeviceContext>
|
|
|
|
|
RowwiseTransformIterator<T, platform::CUDADeviceContext>, const T*>
|
|
|
|
|
super_t;
|
|
|
|
|
// Builds the iterator over `x`: the adaptor base and begin_ both start at x,
// and n is kept in n_.
HOSTDEVICE RowwiseTransformIterator(const T* x, int n)
    : super_t(x), begin_(x), n_(n) {}
|
|
|
|
|
friend class thrust::iterator_core_access;
|
|
|
|
|
|
|
|
|
|
private:
|
|
|
|
@ -179,7 +180,7 @@ class MidWiseTransformIterator<T, platform::CUDADeviceContext>
|
|
|
|
|
MidWiseTransformIterator<T, platform::CUDADeviceContext>, const T*>
|
|
|
|
|
super_t;
|
|
|
|
|
// Builds the iterator over `x`: the adaptor base and begin_ both start at x;
// n and post are kept in n_ and post_.
HOSTDEVICE MidWiseTransformIterator(const T* x, int n, int post)
    : super_t(x), begin_(x), n_(n), post_(post) {}
|
|
|
|
|
friend class thrust::iterator_core_access;
|
|
|
|
|
|
|
|
|
|
private:
|
|
|
|
@ -333,6 +334,55 @@ static void ElemwiseGradBroadcast1CPU(const T* x, const T* y, const T* out,
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
#ifdef __NVCC__
|
|
|
|
|
|
|
|
|
|
// __shfl_down has been deprecated as of CUDA 9.0.
#if CUDA_VERSION < 9000
// Compatibility shim for toolkits older than CUDA 9: exposes the *_sync
// spelling and forwards to the legacy intrinsic, ignoring the mask argument.
template <typename T>
__forceinline__ __device__ T __shfl_down_sync(unsigned, T val, int delta) {
  return __shfl_down(val, delta);
}
// The mask is unused by the shim above, so it is simply cleared.
// No trailing semicolon here: the call site supplies it, exactly as in the
// CUDA >= 9 branch, so both variants expand to a single statement.
#define CREATE_SHFL_MASK(mask, predicate) mask = 0u
#else
#define FULL_WARP_MASK 0xFFFFFFFF
// Sets `mask` to the ballot of `predicate` across the calling warp.
#define CREATE_SHFL_MASK(mask, predicate) \
  mask = __ballot_sync(FULL_WARP_MASK, (predicate))
#endif
|
|
|
|
|
|
|
|
|
|
// Block-wide sum of `val` using warp shuffles plus one shared-memory stage.
//
// Parameters:
//   val - this thread's partial value.
//   tid - index of the calling thread; used as a lane/warp index below
//         (callers in this file pass a threadIdx.x-derived value - confirm).
//   len - number of participating threads; only used to build the shuffle
//         mask (tid < len).
//
// Returns: the reduced sum in the thread with tid == 0. The value returned to
// other threads is a partial result and should not be used.
//
// Requirements visible from the code:
//   * Contains __syncthreads(), so every thread of the block must call it.
//   * shm holds one slot per warp (32 slots), so tid / 32 must be < 32.
//
// NOTE(review): when the last warp is only partially populated, the shuffle
// reads lanes that are not in `mask`; the CUDA documentation leaves the value
// read from such lanes undefined - confirm callers' block sizes.
template <typename T>
__device__ T reduceSum(T val, int tid, int len) {
  // TODO(zcd): The warp size should be taken from the
  // parameters of the GPU but not specified as 32 simply.
  // To make the reduceSum more efficiently,
  // I use Warp-Level Parallelism and assume the Warp size
  // is 32 which may be different for different GPU,
  // but most card's warp size is 32.
  const int kWarpSize = 32;  // named to avoid shadowing the built-in warpSize
  __shared__ T shm[kWarpSize];
  unsigned mask = 0u;
  CREATE_SHFL_MASK(mask, tid < len);

  // Stage 1: tree reduction inside each warp; lane 0 ends with the warp sum.
  for (int offset = kWarpSize / 2; offset > 0; offset /= 2)
    val += __shfl_down_sync(mask, val, offset);

  // Clear every per-warp slot so slots of non-existent warps contribute 0.
  if (tid < kWarpSize) shm[tid] = 0;

  __syncthreads();

  // Lane 0 of each warp publishes its warp's partial sum.
  if (tid % kWarpSize == 0) {
    shm[tid / kWarpSize] = val;
  }

  // All per-warp sums must be visible before the first warp reads them;
  // without this barrier the read below races with the writes above.
  __syncthreads();

  CREATE_SHFL_MASK(mask, tid < kWarpSize);

  // Stage 2: the first warp reduces the per-warp partial sums.
  if (tid < kWarpSize) {
    val = shm[tid];
    for (int offset = kWarpSize / 2; offset > 0; offset /= 2)
      val += __shfl_down_sync(mask, val, offset);
  }

  return val;
}
|
|
|
|
|
|
|
|
|
|
template <typename T, typename DX_OP, typename DY_OP>
|
|
|
|
|
static __global__ void ElemwiseGradBroadcast1CUDAKernel(
|
|
|
|
|
const T* x, const T* y, const T* out, const T* dout, int h, int w,
|
|
|
|
@ -355,7 +405,7 @@ static __global__ void ElemwiseGradBroadcast1CUDAKernel(
|
|
|
|
|
|
|
|
|
|
if (dy) {
|
|
|
|
|
h = h > ELEMWISE_MAX_BLOCK_DIM ? ELEMWISE_MAX_BLOCK_DIM : h;
|
|
|
|
|
val = platform::reduceSum(val, tid, h);
|
|
|
|
|
val = reduceSum(val, tid, h);
|
|
|
|
|
if (threadIdx.x == 0) {
|
|
|
|
|
dy[j] = val;
|
|
|
|
|
}
|
|
|
|
@ -432,7 +482,7 @@ static __global__ void ElemwiseGradBroadcast2CUDAKernel(
|
|
|
|
|
if (dy) {
|
|
|
|
|
int h = pre * post;
|
|
|
|
|
h = h > ELEMWISE_MAX_BLOCK_DIM ? ELEMWISE_MAX_BLOCK_DIM : h;
|
|
|
|
|
val = platform::reduceSum(val, tid, h);
|
|
|
|
|
val = reduceSum(val, tid, h);
|
|
|
|
|
if (threadIdx.x == 0) {
|
|
|
|
|
dy[j] = val;
|
|
|
|
|
}
|
|
|
|
@ -472,11 +522,11 @@ void ElemwiseGradCompute(const framework::ExecutionContext& ctx,
|
|
|
|
|
auto y_dim = y.dims();
|
|
|
|
|
|
|
|
|
|
axis = (axis == -1 ? x_dim.size() - y_dim.size() : axis);
|
|
|
|
|
trim_trailing_singular_dims(y_dim);
|
|
|
|
|
trim_trailing_singular_dims(&y_dim);
|
|
|
|
|
axis = (y_dim.size() == 0) ? x_dim.size() : axis;
|
|
|
|
|
|
|
|
|
|
int pre, n, post;
|
|
|
|
|
get_mid_dims(x_dim, y_dim, axis, pre, n, post);
|
|
|
|
|
get_mid_dims(x_dim, y_dim, axis, &pre, &n, &post);
|
|
|
|
|
if (post == 1) {
|
|
|
|
|
int h = pre;
|
|
|
|
|
int w = n;
|
|
|
|
@ -514,7 +564,7 @@ void ElemwiseGradCompute(const framework::ExecutionContext& ctx,
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
};
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
template <typename DeviceContext, typename T, typename functor,
|
|
|
|
|
typename broadcastfunctor, typename broadcast2functor>
|
|
|
|
@ -543,11 +593,11 @@ void ElementwiseGradCompute(const framework::ExecutionContext& ctx,
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
axis = (axis == -1 ? x_dims.size() - y_dims.size() : axis);
|
|
|
|
|
trim_trailing_singular_dims(y_dims);
|
|
|
|
|
trim_trailing_singular_dims(&y_dims);
|
|
|
|
|
axis = (y_dims.size() == 0) ? x_dims.size() : axis;
|
|
|
|
|
|
|
|
|
|
int pre, n, post;
|
|
|
|
|
get_mid_dims(x_dims, y_dims, axis, pre, n, post);
|
|
|
|
|
get_mid_dims(x_dims, y_dims, axis, &pre, &n, &post);
|
|
|
|
|
|
|
|
|
|
if (post == 1) {
|
|
|
|
|
broadcastfunctor f;
|
|
|
|
@ -582,11 +632,11 @@ void ElementwiseComputeEx(const framework::ExecutionContext& ctx,
|
|
|
|
|
axis = (axis == -1 ? x_dims.size() - y_dims.size() : axis);
|
|
|
|
|
PADDLE_ENFORCE(axis >= 0 && axis < x_dims.size(),
|
|
|
|
|
"Axis should be in range [0, x_dims)");
|
|
|
|
|
trim_trailing_singular_dims(y_dims);
|
|
|
|
|
trim_trailing_singular_dims(&y_dims);
|
|
|
|
|
axis = (y_dims.size() == 0) ? x_dims.size() : axis;
|
|
|
|
|
|
|
|
|
|
int pre, n, post;
|
|
|
|
|
get_mid_dims(x_dims, y_dims, axis, pre, n, post);
|
|
|
|
|
get_mid_dims(x_dims, y_dims, axis, &pre, &n, &post);
|
|
|
|
|
if (post == 1) {
|
|
|
|
|
functor.RunRowWise(n, pre);
|
|
|
|
|
return;
|
|
|
|
|