|
|
|
@ -335,6 +335,7 @@ __global__ void AssignGrad(T* x_grad, const int64_t* indices, const T* out_grad,
|
|
|
|
|
for (size_t j = 0; j < cols; ++j) {
|
|
|
|
|
x_grad[i * cols + j] = 0;
|
|
|
|
|
}
|
|
|
|
|
__syncthreads();
|
|
|
|
|
for (size_t j = 0; j < k; ++j) {
|
|
|
|
|
size_t idx = indices[i * k + j];
|
|
|
|
|
x_grad[i * cols + idx] = out_grad[i * k + j];
|
|
|
|
@ -349,15 +350,16 @@ __global__ void AssignGradWithAxis(const T* grad_out, const int64_t* indices,
|
|
|
|
|
int raw_height, int k) {
|
|
|
|
|
// raw_height is the length of topk axis
|
|
|
|
|
for (int i = blockIdx.x; i < pre; i += gridDim.x) {
|
|
|
|
|
const int& base_index = i * post * k;
|
|
|
|
|
const int& base_grad = i * post * raw_height;
|
|
|
|
|
int base_index = i * post * k;
|
|
|
|
|
int base_grad = i * post * raw_height;
|
|
|
|
|
for (int j = threadIdx.x; j < raw_height * post; j += blockDim.x) {
|
|
|
|
|
grad_in[base_grad + j] = static_cast<T>(0);
|
|
|
|
|
}
|
|
|
|
|
__syncthreads();
|
|
|
|
|
for (int j = threadIdx.x; j < k * post; j += blockDim.x) {
|
|
|
|
|
const int64_t idx_ij = indices[base_index + j];
|
|
|
|
|
const int64_t in_ij = base_grad + (idx_ij * post) + (j % post);
|
|
|
|
|
grad_in[in_ij] = grad_out[idx_ij];
|
|
|
|
|
int64_t idx_ij = indices[base_index + j];
|
|
|
|
|
int64_t in_ij = base_grad + (idx_ij * post) + (j % post);
|
|
|
|
|
grad_in[in_ij] = grad_out[base_index + j];
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|