|
|
|
@ -1040,22 +1040,21 @@ void CommonGradBroadcastCUDA(
|
|
|
|
|
// fallback
|
|
|
|
|
// to old fast path.
|
|
|
|
|
// 2. if both x and y need broadcast, then do it one by one.
|
|
|
|
|
bool fast_broadcast = false;
|
|
|
|
|
if (x_broadcast_pos.empty() && !y_broadcast_pos.empty()) {
|
|
|
|
|
can_split_y = SplitDims(y_broadcast_pos, max_dim);
|
|
|
|
|
if (can_split_y) {
|
|
|
|
|
// only y needs to broadcast on h
|
|
|
|
|
if (y_broadcast_pos[0] == 0) {
|
|
|
|
|
FastBroadCastHeightCUDAF(y_broadcast_pos, true);
|
|
|
|
|
} else {
|
|
|
|
|
LOG(ERROR) << "Error, broadcast should not into w broadcast";
|
|
|
|
|
fast_broadcast = true;
|
|
|
|
|
}
|
|
|
|
|
return;
|
|
|
|
|
} else if (y_broadcast_pos.size() == 1 ||
|
|
|
|
|
CheckContiguousDims(y_broadcast_pos)) { // for only one dim and
|
|
|
|
|
// contiguous broadcast.
|
|
|
|
|
// If the dims cannot be split, it means the input has 3 parts
|
|
|
|
|
FastBroadCastAllCUDAF(y_broadcast_pos, max_dim, true);
|
|
|
|
|
return;
|
|
|
|
|
fast_broadcast = true;
|
|
|
|
|
}
|
|
|
|
|
} else if (y_broadcast_pos.empty() && !x_broadcast_pos.empty()) {
|
|
|
|
|
// only x needs to broadcast
|
|
|
|
@ -1063,49 +1062,53 @@ void CommonGradBroadcastCUDA(
|
|
|
|
|
if (can_split_x) {
|
|
|
|
|
if (x_broadcast_pos[0] == 0) {
|
|
|
|
|
FastBroadCastHeightCUDAF(x_broadcast_pos, false);
|
|
|
|
|
} else {
|
|
|
|
|
// x needs to broadcast on w
|
|
|
|
|
LOG(ERROR) << "Error, broadcast should not into w broadcast";
|
|
|
|
|
fast_broadcast = true;
|
|
|
|
|
}
|
|
|
|
|
return;
|
|
|
|
|
} else if (x_broadcast_pos.size() == 1 ||
|
|
|
|
|
CheckContiguousDims(x_broadcast_pos)) {
|
|
|
|
|
FastBroadCastAllCUDAF(x_broadcast_pos, max_dim, false);
|
|
|
|
|
return;
|
|
|
|
|
fast_broadcast = true;
|
|
|
|
|
}
|
|
|
|
|
} else if (!x_broadcast_pos.empty() && !y_broadcast_pos.empty()) {
|
|
|
|
|
// broadcast x and y separately.
|
|
|
|
|
can_split_y = SplitDims(y_broadcast_pos, max_dim);
|
|
|
|
|
bool fast_broadcast_x = false;
|
|
|
|
|
bool fast_broadcast_y = false;
|
|
|
|
|
if (can_split_y) {
|
|
|
|
|
// begin at start.
|
|
|
|
|
if (y_broadcast_pos[0] == 0) {
|
|
|
|
|
FastCommonCUDAF(y_broadcast_pos, true);
|
|
|
|
|
} else {
|
|
|
|
|
// finish at end
|
|
|
|
|
LOG(ERROR) << "Error, broadcast should not into w broadcast";
|
|
|
|
|
fast_broadcast_y = true;
|
|
|
|
|
}
|
|
|
|
|
} else if (y_broadcast_pos.size() == 1) {
|
|
|
|
|
FastBroadCastOneCUDAF(y_broadcast_pos, max_dim, false);
|
|
|
|
|
can_split_y = true;
|
|
|
|
|
fast_broadcast_y = true;
|
|
|
|
|
}
|
|
|
|
|
can_split_x = SplitDims(x_broadcast_pos, max_dim);
|
|
|
|
|
if (can_split_x) {
|
|
|
|
|
if (x_broadcast_pos[0] == 0) {
|
|
|
|
|
FastCommonCUDAF(x_broadcast_pos, false);
|
|
|
|
|
} else {
|
|
|
|
|
LOG(ERROR) << "Error, broadcast should not into w broadcast";
|
|
|
|
|
fast_broadcast_x = true;
|
|
|
|
|
}
|
|
|
|
|
} else if (x_broadcast_pos.size() == 1) {
|
|
|
|
|
FastBroadCastOneCUDAF(x_broadcast_pos, max_dim, true);
|
|
|
|
|
can_split_x = true;
|
|
|
|
|
fast_broadcast_x = true;
|
|
|
|
|
}
|
|
|
|
|
VLOG(3) << "CommonBroadcast can_split_y:" << can_split_y
|
|
|
|
|
<< " can_split_x:" << can_split_x;
|
|
|
|
|
// if both x and y took the fast path, then return
|
|
|
|
|
if (can_split_y && can_split_x) return;
|
|
|
|
|
if (fast_broadcast_x && fast_broadcast_y) {
|
|
|
|
|
fast_broadcast = true;
|
|
|
|
|
}
|
|
|
|
|
if (can_split_y && can_split_x && fast_broadcast) return;
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
// Should remove memory copy, use reg instead.
|
|
|
|
|
if (fast_broadcast) {
|
|
|
|
|
return;
|
|
|
|
|
}
|
|
|
|
|
int x_blocks = 0;
|
|
|
|
|
int x_threads = 0;
|
|
|
|
|
ComputeBroadcastKernelSize(x_dims_array, out_dims_array, &x_blocks,
|
|
|
|
@ -1136,7 +1139,7 @@ void CommonGradBroadcastCUDA(
|
|
|
|
|
1, std::multiplies<int>());
|
|
|
|
|
int x_block_size = std::min(ELEMWISE_MAX_BLOCK_DIM, x_threads);
|
|
|
|
|
int y_block_size = std::min(ELEMWISE_MAX_BLOCK_DIM, y_threads);
|
|
|
|
|
if (dx && !can_split_x) {
|
|
|
|
|
if (dx) {
|
|
|
|
|
auto x_strides_order_tmp = memory::Alloc(ctx, bytes);
|
|
|
|
|
int *x_strides_order_gpu =
|
|
|
|
|
reinterpret_cast<int *>(x_strides_order_tmp->ptr());
|
|
|
|
@ -1153,7 +1156,7 @@ void CommonGradBroadcastCUDA(
|
|
|
|
|
x_strides_order_gpu, x_dims_order_gpu, x_data, y_data, out_data,
|
|
|
|
|
dout_data, dx_data, out_size, max_dim, x_threads, dx_op);
|
|
|
|
|
}
|
|
|
|
|
if (dy && !can_split_y) {
|
|
|
|
|
if (dy) {
|
|
|
|
|
auto y_strides_order_tmp = memory::Alloc(ctx, bytes);
|
|
|
|
|
int *y_strides_order_gpu =
|
|
|
|
|
reinterpret_cast<int *>(y_strides_order_tmp->ptr());
|
|
|
|
|