Fix the elementwise gradient broadcast when the second input is larger than the first (#30818)

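Fix the elementwise gradient broadcast when the second input (y) is larger than the first (x): index out and dout with the y offset instead of the x offset.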
revert-31068-fix_conv3d_windows
wawltor 4 years ago committed by GitHub
parent 6e1e036a75
commit b7560a59ab
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

@ -708,10 +708,10 @@ static __global__ void FastCommonGradBroadcastAllCUDAKernel(
int x_offset = b_i * post + b_j; int x_offset = b_i * post + b_j;
if (dy) { if (dy) {
dy[y_offset] = dy[y_offset] =
dy_op(x[x_offset], y[y_offset], out[x_offset], dout[x_offset]); dy_op(x[x_offset], y[y_offset], out[y_offset], dout[y_offset]);
} }
if (dx) { if (dx) {
val += dx_op(x[x_offset], y[y_offset], out[x_offset], dout[x_offset]); val += dx_op(x[x_offset], y[y_offset], out[y_offset], dout[y_offset]);
} }
} }
if (dx) { if (dx) {
@ -1674,7 +1674,6 @@ void CommonElementwiseBroadcastBackward(
GetBroadcastDimsArrays(x_dims, y_dims, x_dims_array.data(), GetBroadcastDimsArrays(x_dims, y_dims, x_dims_array.data(),
y_dims_array.data(), out_dims_array.data(), max_dim, y_dims_array.data(), out_dims_array.data(), max_dim,
axis); axis);
// for inplace strategy. memset will make dx and dout clear and get wrong // for inplace strategy. memset will make dx and dout clear and get wrong
// result. // result.
if (dx && dx->IsSharedBufferWith(dout)) { if (dx && dx->IsSharedBufferWith(dout)) {
@ -1762,7 +1761,6 @@ void ElemwiseGradComputeWithBroadcast(
get_mid_dims(y_dims, x_dims_trimed, axis_trim, &pre, &n, &post, get_mid_dims(y_dims, x_dims_trimed, axis_trim, &pre, &n, &post,
&is_run_common_broadcast); &is_run_common_broadcast);
} }
// special case for common backward implementation. // special case for common backward implementation.
if (is_run_common_broadcast) { if (is_run_common_broadcast) {
CommonElementwiseBroadcastBackward<DeviceContext, T, DX_OP, DY_OP>( CommonElementwiseBroadcastBackward<DeviceContext, T, DX_OP, DY_OP>(

@ -381,6 +381,16 @@ class TestElementwiseAddOp_xsize_lessthan_ysize_add(TestElementwiseAddOp):
self.axis = 2 self.axis = 2
class TestElementwiseAddOp_same_shape_ysize_large(TestElementwiseAddOp):
    """Elementwise-add case where x and y have the same rank but y is larger.

    x has shape (10, 1, 12) and y has shape (10, 3, 12), so x is the operand
    that gets broadcast (along its middle dimension) to match y.
    """

    def init_input_output(self):
        # Shapes are named so the size-1 middle dimension of x stands out.
        x_shape = (10, 1, 12)
        y_shape = (10, 3, 12)
        # x is drawn before y, matching the original order of RNG calls.
        self.x = np.random.rand(*x_shape).astype(self.dtype)
        self.y = np.random.rand(*y_shape).astype(self.dtype)
        # Reference result computed with NumPy's own broadcasting.
        self.out = np.add(self.x, self.y)

    def init_axis(self):
        self.axis = 0
class TestElementwiseAddOpError(unittest.TestCase): class TestElementwiseAddOpError(unittest.TestCase):
def test_errors(self): def test_errors(self):
with program_guard(Program(), Program()): with program_guard(Program(), Program()):

Loading…
Cancel
Save