|
|
|
@ -548,6 +548,64 @@ static __global__ void FastCommonGradBroadcastAllCUDAKernel(
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
template <typename T, typename OP>
|
|
|
|
|
static __global__ void FastCommonGradBroadcastOneCUDAKernel(
|
|
|
|
|
const T *x, const T *y, const T *out, const T *dout, int pre, int n,
|
|
|
|
|
int post, int y_pre, int y_n, int y_post, bool is_xsize, OP op, T *dd) {
|
|
|
|
|
int tid = threadIdx.x;
|
|
|
|
|
int bid = blockIdx.x;
|
|
|
|
|
|
|
|
|
|
T val(0);
|
|
|
|
|
if (is_xsize) {
|
|
|
|
|
// do reduce for x
|
|
|
|
|
for (int i = tid; i < n; i += ELEMWISE_MAX_BLOCK_DIM) {
|
|
|
|
|
int b_i = bid / post;
|
|
|
|
|
int b_j = bid % post;
|
|
|
|
|
int x_offset = b_i * n * post + b_j;
|
|
|
|
|
int out_offset = b_i * n * post + i * post + b_j;
|
|
|
|
|
|
|
|
|
|
// Get y pre rows id with x post and y_pre.
|
|
|
|
|
int b_yi = bid / (post * y_pre);
|
|
|
|
|
int b_yj = bid % y_post;
|
|
|
|
|
int y_offset = b_yi * y_n + i * y_post + b_yj;
|
|
|
|
|
|
|
|
|
|
if (dd) {
|
|
|
|
|
val += op(x[x_offset], y[y_offset], out[out_offset], dout[out_offset]);
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
if (dd) {
|
|
|
|
|
int h = n > ELEMWISE_MAX_BLOCK_DIM ? ELEMWISE_MAX_BLOCK_DIM : n;
|
|
|
|
|
val = paddle::platform::reduceSum(val, tid, h);
|
|
|
|
|
if (tid == 0) {
|
|
|
|
|
dd[bid] = val;
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
} else {
|
|
|
|
|
// do reduce for y
|
|
|
|
|
for (int i = tid; i < n; i += ELEMWISE_MAX_BLOCK_DIM) {
|
|
|
|
|
int b_i = bid / post;
|
|
|
|
|
int b_j = bid % post;
|
|
|
|
|
int y_offset = b_i * n * post + b_j;
|
|
|
|
|
int out_offset = b_i * n * post + i * post + b_j;
|
|
|
|
|
|
|
|
|
|
int b_yi = bid / (post * y_pre);
|
|
|
|
|
int b_yj = bid % y_post;
|
|
|
|
|
int x_offset = b_yi * y_n + i * y_post + b_yj;
|
|
|
|
|
|
|
|
|
|
if (dd) {
|
|
|
|
|
val += op(x[x_offset], y[y_offset], out[out_offset], dout[out_offset]);
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
if (dd) {
|
|
|
|
|
int h = n > ELEMWISE_MAX_BLOCK_DIM ? ELEMWISE_MAX_BLOCK_DIM : n;
|
|
|
|
|
val = paddle::platform::reduceSum(val, tid, h);
|
|
|
|
|
if (tid == 0) {
|
|
|
|
|
dd[bid] = val;
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
// Check input can be split into 2 parts
|
|
|
|
|
static inline bool SplitDims(const std::vector<int> &y_broadcast_pos,
|
|
|
|
|
int max_dim) {
|
|
|
|
@ -568,6 +626,16 @@ static inline bool SplitDims(const std::vector<int> &y_broadcast_pos,
|
|
|
|
|
return can_split_dim2;
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
// Returns true when every entry of broadcast_pos is exactly one greater than
// the entry before it, i.e. the positions form a single run of consecutive
// indices. An empty or single-element vector has no adjacent pair that could
// violate this, so it also yields true.
static inline bool CheckContiguousDims(const std::vector<int> &broadcast_pos) {
  // size_t index: avoids comparing a signed int against the unsigned size().
  for (size_t i = 1; i < broadcast_pos.size(); ++i) {
    if (broadcast_pos[i] != broadcast_pos[i - 1] + 1) {
      return false;
    }
  }
  return true;
}
|
|
|
|
|
|
|
|
|
|
template <typename T, typename DX_OP, typename DY_OP>
|
|
|
|
|
void CommonGradBroadcastCUDA(
|
|
|
|
|
const framework::Tensor &x, const framework::Tensor &y,
|
|
|
|
@ -644,6 +712,7 @@ void CommonGradBroadcastCUDA(
|
|
|
|
|
y_broadcast_pos.emplace_back(i);
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
auto stream = ctx.stream();
|
|
|
|
|
bool can_split_x = false;
|
|
|
|
|
bool can_split_y = false;
|
|
|
|
@ -751,10 +820,22 @@ void CommonGradBroadcastCUDA(
|
|
|
|
|
int axis = broadcast_pos[0];
|
|
|
|
|
int pre = std::accumulate(out_dims_array, out_dims_array + axis, 1,
|
|
|
|
|
std::multiplies<int>());
|
|
|
|
|
int mid = out_dims_array[axis];
|
|
|
|
|
int post =
|
|
|
|
|
std::accumulate(out_dims_array + axis + 1, out_dims_array + max_dim, 1,
|
|
|
|
|
std::multiplies<int>());
|
|
|
|
|
int mid = 1;
|
|
|
|
|
int post = 1;
|
|
|
|
|
|
|
|
|
|
if (broadcast_pos.size() == 1) {
|
|
|
|
|
mid = out_dims_array[axis];
|
|
|
|
|
post =
|
|
|
|
|
std::accumulate(out_dims_array + axis + 1, out_dims_array + max_dim,
|
|
|
|
|
1, std::multiplies<int>());
|
|
|
|
|
} else {
|
|
|
|
|
mid = std::accumulate(out_dims_array + axis,
|
|
|
|
|
out_dims_array + broadcast_pos.back() + 1, 1,
|
|
|
|
|
std::multiplies<int>());
|
|
|
|
|
post =
|
|
|
|
|
std::accumulate(out_dims_array + broadcast_pos.back() + 1,
|
|
|
|
|
out_dims_array + max_dim, 1, std::multiplies<int>());
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
VLOG(3) << "FastBroadCastAllCUDAF pre:" << pre << " mid:" << mid
|
|
|
|
|
<< " post:" << post;
|
|
|
|
@ -767,6 +848,55 @@ void CommonGradBroadcastCUDA(
|
|
|
|
|
dy_op, dx_data, dy_data);
|
|
|
|
|
};
|
|
|
|
|
|
|
|
|
|
auto FastBroadCastOneCUDAF = [&](const std::vector<int> &broadcast_pos,
|
|
|
|
|
int max_dim, bool is_x) {
|
|
|
|
|
int axis = broadcast_pos[0];
|
|
|
|
|
int pre = std::accumulate(out_dims_array, out_dims_array + axis, 1,
|
|
|
|
|
std::multiplies<int>());
|
|
|
|
|
int mid = out_dims_array[axis];
|
|
|
|
|
int post =
|
|
|
|
|
std::accumulate(out_dims_array + axis + 1, out_dims_array + max_dim, 1,
|
|
|
|
|
std::multiplies<int>());
|
|
|
|
|
|
|
|
|
|
int k_pre;
|
|
|
|
|
int k_mid;
|
|
|
|
|
int k_post;
|
|
|
|
|
|
|
|
|
|
if (is_x) {
|
|
|
|
|
k_pre = std::accumulate(y_dims_array, y_dims_array + axis, 1,
|
|
|
|
|
std::multiplies<int>());
|
|
|
|
|
k_mid = y_dims_array[axis];
|
|
|
|
|
k_post = std::accumulate(y_dims_array + axis + 1, y_dims_array + max_dim,
|
|
|
|
|
1, std::multiplies<int>());
|
|
|
|
|
int block_size = std::min(ELEMWISE_MAX_BLOCK_DIM, mid);
|
|
|
|
|
int grid_size = pre * post;
|
|
|
|
|
// we need to calc y offset with blockid, so do x_pre/y_pre to get left
|
|
|
|
|
// size.
|
|
|
|
|
if (k_pre != pre) k_pre = pre / k_pre;
|
|
|
|
|
|
|
|
|
|
FastCommonGradBroadcastOneCUDAKernel<<<grid_size, block_size, 0,
|
|
|
|
|
stream>>>(
|
|
|
|
|
x_data, y_data, out_data, dout_data, pre, mid, post, k_pre, k_mid,
|
|
|
|
|
k_post, true, dx_op, dx_data);
|
|
|
|
|
} else {
|
|
|
|
|
k_pre = std::accumulate(x_dims_array, x_dims_array + axis, 1,
|
|
|
|
|
std::multiplies<int>());
|
|
|
|
|
k_mid = x_dims_array[axis];
|
|
|
|
|
k_post = std::accumulate(x_dims_array + axis + 1, x_dims_array + max_dim,
|
|
|
|
|
1, std::multiplies<int>());
|
|
|
|
|
int block_size = std::min(ELEMWISE_MAX_BLOCK_DIM, mid);
|
|
|
|
|
int grid_size = pre * post;
|
|
|
|
|
if (k_pre != pre) k_pre = pre / k_pre;
|
|
|
|
|
|
|
|
|
|
FastCommonGradBroadcastOneCUDAKernel<<<grid_size, block_size, 0,
|
|
|
|
|
stream>>>(
|
|
|
|
|
x_data, y_data, out_data, dout_data, pre, mid, post, k_pre, k_mid,
|
|
|
|
|
k_post, false, dy_op, dy_data);
|
|
|
|
|
}
|
|
|
|
|
VLOG(3) << "FastBroadCastOneCUDAF pre:" << pre << " mid:" << mid
|
|
|
|
|
<< " post:" << post;
|
|
|
|
|
};
|
|
|
|
|
|
|
|
|
|
// do fast elementwise if: 1. only one input needs to do broadcast, we can
|
|
|
|
|
// fallback
|
|
|
|
|
// to old fast path.
|
|
|
|
@ -781,7 +911,9 @@ void CommonGradBroadcastCUDA(
|
|
|
|
|
LOG(ERROR) << "Error, broadcast should not into w broadcast";
|
|
|
|
|
}
|
|
|
|
|
return;
|
|
|
|
|
} else if (y_broadcast_pos.size() == 1) { // for only one dim broadcast.
|
|
|
|
|
} else if (y_broadcast_pos.size() == 1 ||
|
|
|
|
|
CheckContiguousDims(y_broadcast_pos)) { // for only one dim and
|
|
|
|
|
// contiguous broadcast.
|
|
|
|
|
// If cannot split, which means input has 3 parts
|
|
|
|
|
FastBroadCastAllCUDAF(y_broadcast_pos, max_dim, true);
|
|
|
|
|
return;
|
|
|
|
@ -797,7 +929,8 @@ void CommonGradBroadcastCUDA(
|
|
|
|
|
LOG(ERROR) << "Error, broadcast should not into w broadcast";
|
|
|
|
|
}
|
|
|
|
|
return;
|
|
|
|
|
} else if (x_broadcast_pos.size() == 1) {
|
|
|
|
|
} else if (x_broadcast_pos.size() == 1 ||
|
|
|
|
|
CheckContiguousDims(x_broadcast_pos)) {
|
|
|
|
|
FastBroadCastAllCUDAF(x_broadcast_pos, max_dim, false);
|
|
|
|
|
return;
|
|
|
|
|
}
|
|
|
|
@ -812,6 +945,9 @@ void CommonGradBroadcastCUDA(
|
|
|
|
|
// finish at end
|
|
|
|
|
LOG(ERROR) << "Error, broadcast should not into w broadcast";
|
|
|
|
|
}
|
|
|
|
|
} else if (y_broadcast_pos.size() == 1) {
|
|
|
|
|
FastBroadCastOneCUDAF(y_broadcast_pos, max_dim, false);
|
|
|
|
|
can_split_y = true;
|
|
|
|
|
}
|
|
|
|
|
can_split_x = SplitDims(x_broadcast_pos, max_dim);
|
|
|
|
|
if (can_split_x) {
|
|
|
|
@ -820,6 +956,9 @@ void CommonGradBroadcastCUDA(
|
|
|
|
|
} else {
|
|
|
|
|
LOG(ERROR) << "Error, broadcast should not into w broadcast";
|
|
|
|
|
}
|
|
|
|
|
} else if (x_broadcast_pos.size() == 1) {
|
|
|
|
|
FastBroadCastOneCUDAF(x_broadcast_pos, max_dim, true);
|
|
|
|
|
can_split_x = true;
|
|
|
|
|
}
|
|
|
|
|
VLOG(3) << "CommonBroadcast can_split_y:" << can_split_y
|
|
|
|
|
<< " can_split_x:" << can_split_x;
|
|
|
|
@ -1492,6 +1631,10 @@ void CommonElementwiseBroadcastBackward(
|
|
|
|
|
dx->mutable_data<T>(x_dims, ctx.GetPlace());
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
VLOG(3) << "CommonElementwiseBroadcastBackward xdims:"
|
|
|
|
|
<< framework::make_ddim(x_dims_array)
|
|
|
|
|
<< " ydim:" << framework::make_ddim(y_dims_array);
|
|
|
|
|
|
|
|
|
|
if (platform::is_gpu_place(ctx.GetPlace())) {
|
|
|
|
|
#ifdef __NVCC__
|
|
|
|
|
CommonGradBroadcastCUDA<T, DX_OP, DY_OP>(
|
|
|
|
|