From 8f2656ef5ca4ab16f06d94b8ca9392d3f0f760ae Mon Sep 17 00:00:00 2001
From: wawltor
Date: Mon, 16 Nov 2020 20:21:46 +0800
Subject: [PATCH] fix the gradient bug for the topk v2

fix the gradient bug for the topk v2
---
 paddle/fluid/operators/top_k_function_cuda.h  | 12 ++++---
 .../fluid/tests/unittests/test_top_k_v2_op.py | 32 +++++++++++--------
 2 files changed, 25 insertions(+), 19 deletions(-)

diff --git a/paddle/fluid/operators/top_k_function_cuda.h b/paddle/fluid/operators/top_k_function_cuda.h
index 57891699fd..0fd5f2ac01 100644
--- a/paddle/fluid/operators/top_k_function_cuda.h
+++ b/paddle/fluid/operators/top_k_function_cuda.h
@@ -335,6 +335,7 @@ __global__ void AssignGrad(T* x_grad, const int64_t* indices, const T* out_grad,
     for (size_t j = 0; j < cols; ++j) {
       x_grad[i * cols + j] = 0;
     }
+    __syncthreads();
     for (size_t j = 0; j < k; ++j) {
       size_t idx = indices[i * k + j];
       x_grad[i * cols + idx] = out_grad[i * k + j];
@@ -349,15 +350,16 @@ __global__ void AssignGradWithAxis(const T* grad_out, const int64_t* indices,
                                    int raw_height, int k) {
   // raw_height is the length of topk axis
   for (int i = blockIdx.x; i < pre; i += gridDim.x) {
-    const int& base_index = i * post * k;
-    const int& base_grad = i * post * raw_height;
+    int base_index = i * post * k;
+    int base_grad = i * post * raw_height;
     for (int j = threadIdx.x; j < raw_height * post; j += blockDim.x) {
       grad_in[base_grad + j] = static_cast<T>(0);
     }
+    __syncthreads();
     for (int j = threadIdx.x; j < k * post; j += blockDim.x) {
-      const int64_t idx_ij = indices[base_index + j];
-      const int64_t in_ij = base_grad + (idx_ij * post) + (j % post);
-      grad_in[in_ij] = grad_out[idx_ij];
+      int64_t idx_ij = indices[base_index + j];
+      int64_t in_ij = base_grad + (idx_ij * post) + (j % post);
+      grad_in[in_ij] = grad_out[base_index + j];
     }
   }
 }
diff --git a/python/paddle/fluid/tests/unittests/test_top_k_v2_op.py b/python/paddle/fluid/tests/unittests/test_top_k_v2_op.py
index b9d96f329b..94dcf15115 100644
--- a/python/paddle/fluid/tests/unittests/test_top_k_v2_op.py
+++ b/python/paddle/fluid/tests/unittests/test_top_k_v2_op.py
@@ -64,34 +64,38 @@ class TestTopkOp(OpTest):
 
 
 class TestTopkOp1(TestTopkOp):
-    def init_args(self):
-        self.k = 3
-        self.axis = 0
-        self.largest = True
-
-
-class TestTopkOp2(TestTopkOp):
     def init_args(self):
         self.k = 3
         self.axis = 0
         self.largest = False
 
 
-class TestTopkOp3(TestTopkOp):
+class TestTopkOp2(TestTopkOp):
     def init_args(self):
         self.k = 4
         self.axis = 0
         self.largest = False
 
 
-class TestTopkOp4(TestTopkOp):
+class TestTopkOp3(OpTest):
     def init_args(self):
-        self.k = 4
-        self.axis = 0
-        self.largest = False
+        self.k = 6
+        self.axis = 1
+        self.largest = True
 
+    def setUp(self):
+        self.op_type = "top_k_v2"
+        self.dtype = np.float64
+        self.input_data = np.random.rand(16, 100)
+        self.init_args()
+        self.inputs = {'X': self.input_data}
+        self.attrs = {'k': self.k, 'axis': self.axis, 'largest': self.largest}
+        output, indices = numpy_topk(
+            self.input_data, axis=self.axis, k=self.k, largest=self.largest)
+        self.outputs = {'Out': output, 'Indices': indices}
 
-class TestTopkOp5(TestTopkOp):
+
+class TestTopkOp4(TestTopkOp):
     def init_args(self):
         self.k = 3
         self.axis = 1
@@ -109,7 +113,7 @@ class TestTopkOp5(TestTopkOp):
         self.outputs = {'Out': output, 'Indices': indices}
 
 
-class TestTopkOp6(TestTopkOp):
+class TestTopkOp5(TestTopkOp):
    def init_args(self):
         self.k = 3
         self.axis = 1
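
Note (illustrative, not part of the patch): the gradient fix in AssignGradWithAxis is the switch from grad_in[in_ij] = grad_out[idx_ij] to grad_in[in_ij] = grad_out[base_index + j], so the scattered value is the j-th top-k gradient of row i rather than whatever sits at offset idx_ij of grad_out; the added __syncthreads() keeps a thread from zero-filling an element after another thread of the block has already scattered a gradient into it. Below is a minimal NumPy sketch of the corrected scatter, assuming the same flattened [pre, raw_height, post] input and [pre, k, post] top-k layout as the kernel. The helper name and the np.put_along_axis cross-check are mine, for illustration only, not PaddlePaddle API.

import numpy as np

def assign_grad_with_axis_ref(grad_out, indices, pre, post, raw_height, k):
    # CPU reference of the fixed AssignGradWithAxis: zero-fill grad_in, then
    # scatter each top-k gradient to the slot its index points at on the axis.
    grad_in = np.zeros(pre * raw_height * post, dtype=grad_out.dtype)
    grad_out = grad_out.reshape(-1)
    indices = indices.reshape(-1)
    for i in range(pre):
        base_index = i * post * k           # start of row i in grad_out / indices
        base_grad = i * post * raw_height   # start of row i in grad_in
        for j in range(k * post):
            idx_ij = indices[base_index + j]
            in_ij = base_grad + idx_ij * post + j % post
            # fixed line: read grad_out[base_index + j], not grad_out[idx_ij]
            grad_in[in_ij] = grad_out[base_index + j]
    return grad_in.reshape(pre, raw_height, post)

# quick check against a direct scatter for a [2, 5, 3] input with k = 2
x = np.random.rand(2, 5, 3)
idx = np.argsort(-x, axis=1)[:, :2, :]
gout = np.random.rand(2, 2, 3)
ref = np.zeros_like(x)
np.put_along_axis(ref, idx, gout, axis=1)
assert np.allclose(assign_grad_with_axis_ref(gout, idx, 2, 3, 5, 2), ref)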