|
|
@ -708,10 +708,10 @@ static __global__ void FastCommonGradBroadcastAllCUDAKernel(
|
|
|
|
int x_offset = b_i * post + b_j;
|
|
|
|
int x_offset = b_i * post + b_j;
|
|
|
|
if (dy) {
|
|
|
|
if (dy) {
|
|
|
|
dy[y_offset] =
|
|
|
|
dy[y_offset] =
|
|
|
|
dy_op(x[x_offset], y[y_offset], out[x_offset], dout[x_offset]);
|
|
|
|
dy_op(x[x_offset], y[y_offset], out[y_offset], dout[y_offset]);
|
|
|
|
}
|
|
|
|
}
|
|
|
|
if (dx) {
|
|
|
|
if (dx) {
|
|
|
|
val += dx_op(x[x_offset], y[y_offset], out[x_offset], dout[x_offset]);
|
|
|
|
val += dx_op(x[x_offset], y[y_offset], out[y_offset], dout[y_offset]);
|
|
|
|
}
|
|
|
|
}
|
|
|
|
}
|
|
|
|
}
|
|
|
|
if (dx) {
|
|
|
|
if (dx) {
|
|
|
@ -1674,7 +1674,6 @@ void CommonElementwiseBroadcastBackward(
|
|
|
|
GetBroadcastDimsArrays(x_dims, y_dims, x_dims_array.data(),
|
|
|
|
GetBroadcastDimsArrays(x_dims, y_dims, x_dims_array.data(),
|
|
|
|
y_dims_array.data(), out_dims_array.data(), max_dim,
|
|
|
|
y_dims_array.data(), out_dims_array.data(), max_dim,
|
|
|
|
axis);
|
|
|
|
axis);
|
|
|
|
|
|
|
|
|
|
|
|
// for inplace strategy. memset will make dx and dout clear and get wrong
|
|
|
|
// for inplace strategy. memset will make dx and dout clear and get wrong
|
|
|
|
// result.
|
|
|
|
// result.
|
|
|
|
if (dx && dx->IsSharedBufferWith(dout)) {
|
|
|
|
if (dx && dx->IsSharedBufferWith(dout)) {
|
|
|
@ -1762,7 +1761,6 @@ void ElemwiseGradComputeWithBroadcast(
|
|
|
|
get_mid_dims(y_dims, x_dims_trimed, axis_trim, &pre, &n, &post,
|
|
|
|
get_mid_dims(y_dims, x_dims_trimed, axis_trim, &pre, &n, &post,
|
|
|
|
&is_run_common_broadcast);
|
|
|
|
&is_run_common_broadcast);
|
|
|
|
}
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
|
|
// special case for common backward implementation.
|
|
|
|
// special case for common backward implementation.
|
|
|
|
if (is_run_common_broadcast) {
|
|
|
|
if (is_run_common_broadcast) {
|
|
|
|
CommonElementwiseBroadcastBackward<DeviceContext, T, DX_OP, DY_OP>(
|
|
|
|
CommonElementwiseBroadcastBackward<DeviceContext, T, DX_OP, DY_OP>(
|
|
|
|