@@ -58,6 +58,19 @@ static __global__ void FillIndex(T* indices, T num_rows, T num_cols) {
  }
}

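// Scatters the upstream gradient dO back to the input gradient dX: within
// each row, the value at sorted position i is written to the original
// position recorded in indices.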
template <typename T, typename IndType>
static __global__ void FillGrad(const T* dO, const IndType* indices, T* dX,
                                IndType num_rows, IndType num_cols) {
  int col_id = threadIdx.x;
  int row_id = blockIdx.x;

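  // Grid-stride loop over rows and block-stride loop over columns, so any
  // matrix shape is covered regardless of the launch configuration.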
  for (IndType j = row_id; j < num_rows; j += gridDim.x) {
    for (IndType i = col_id; i < num_cols; i += blockDim.x) {
      dX[j * num_cols + indices[j * num_cols + i]] = dO[j * num_cols + i];
    }
  }
}

// Sort according to the `descending` flag: true sorts in descending order,
// false (the default) sorts in ascending order.
template <typename T, typename IndType>
@@ -160,6 +173,35 @@ void ArgFullSort(const platform::CUDADeviceContext& ctx, const Tensor* input,
                 temp_storage_bytes, cudaGetErrorString(err));
}

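// Host-side launcher for FillGrad: chooses the launch configuration from the
// row and column counts and scatters dO into dX on the context's stream.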
template <typename T, typename IndType>
void ArgFullAssign(const platform::CUDADeviceContext& ctx, const Tensor* dO,
                   const Tensor* indices, Tensor* dX, const IndType num_rows,
                   const IndType num_cols) {
  auto cu_stream = ctx.stream();

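  // Round the row width up to the next power-of-two block size, clamped to
  // the range [64, 1024] threads per block.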
  auto ComputeBlockSize = [](IndType col) {
    if (col > 512)
      return 1024;
    else if (col > 256 && col <= 512)
      return 512;
    else if (col > 128 && col <= 256)
      return 256;
    else if (col > 64 && col <= 128)
      return 128;
    else
      return 64;
  };

  int block_size = ComputeBlockSize(num_cols);

  int maxGridDimX = ctx.GetCUDAMaxGridDimSize().x;
  // In practice num_rows is smaller than the maximum grid dimension, so each
  // row normally gets its own block; larger inputs fall back to the
  // grid-stride loop inside FillGrad.
  int grid_size = num_rows < maxGridDimX ? num_rows : maxGridDimX;
  FillGrad<<<grid_size, block_size, 0, cu_stream>>>(
      dO->data<T>(), indices->data<IndType>(), dX->data<T>(), num_rows,
      num_cols);
}

template <typename T>
class ArgsortOpCUDAKernel : public framework::OpKernel<T> {
 public:
@@ -234,6 +276,81 @@ class ArgsortOpCUDAKernel : public framework::OpKernel<T> {
  }
};

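// Gradient kernel of argsort: routes dO back through the saved Indices. When
// sorting along the last axis the scatter runs directly; otherwise the
// tensors are transposed so that the target axis becomes the last dimension.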
template <typename T>
class ArgsortGradOpCUDAKernel : public framework::OpKernel<T> {
 public:
  void Compute(const framework::ExecutionContext& ctx) const override {
    auto* indices = ctx.Input<Tensor>("Indices");
    auto* dX = ctx.Output<Tensor>(framework::GradVarName("X"));
    auto* dO = ctx.Input<Tensor>(framework::GradVarName("Out"));
    int axis = ctx.Attr<int>("axis");

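    // Allocate dX and zero-initialize it before the gradients are scattered.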
    dX->mutable_data<T>(ctx.GetPlace());
    auto dxt = framework::EigenVector<T>::Flatten(*dX);
    auto& place = *ctx.template device_context<platform::CUDADeviceContext>()
                       .eigen_device();
    dxt.device(place) = dxt.constant(static_cast<T>(0));
    if (dO->numel() == 0) return;

    auto in_dims = indices->dims();
    axis = (axis < 0) ? (in_dims.size() + axis) : axis;

    int64_t numel = indices->numel();

    // Special case for full sort, speedup ~190x.
    if (axis == -1 || axis + 1 == in_dims.size()) {
      const int64_t input_height = framework::product(
          framework::slice_ddim(in_dims, 0, in_dims.size() - 1));
      const int64_t input_width = in_dims[in_dims.size() - 1];
      const auto& dev_ctx = ctx.cuda_device_context();
      ArgFullAssign<T, int64_t>(dev_ctx, dO, indices, dX, input_height,
                                input_width);
    } else {
      // If not a full sort along the last axis, transpose so that the target
      // axis becomes the last dimension, scatter along it, then transpose the
      // result back.
      std::vector<int> trans;
      for (int i = 0; i < axis; i++) {
        trans.push_back(i);
      }
      trans.push_back(in_dims.size() - 1);
      for (int i = axis + 1; i < in_dims.size() - 1; i++) {
        trans.push_back(i);
      }
      trans.push_back(axis);
      framework::DDim trans_dims(in_dims);
      for (int i = 0; i < trans.size(); i++) {
        trans_dims[i] = in_dims[trans[i]];
      }

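      // Transposed copies of dO and Indices with the sort axis moved to the
      // last dimension.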
      Tensor trans_dO;
      trans_dO.mutable_data<T>(trans_dims, ctx.GetPlace());
      Tensor trans_ind;
      trans_ind.mutable_data<int64_t>(trans_dims, ctx.GetPlace());
      int ndims = trans.size();
      const auto& dev_ctx = ctx.cuda_device_context();
      // Do transpose
      TransCompute<platform::CUDADeviceContext, T>(ndims, dev_ctx, *dO,
                                                   &trans_dO, trans);
      TransCompute<platform::CUDADeviceContext, int64_t>(
          ndims, dev_ctx, *indices, &trans_ind, trans);

      const int64_t input_height = framework::product(
          framework::slice_ddim(trans_dims, 0, trans_dims.size() - 1));
      const int64_t input_width = trans_dims[trans_dims.size() - 1];

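      // Holds the scattered gradients in the transposed layout before they
      // are transposed back into dX.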
      Tensor tmp_out;
      tmp_out.mutable_data<T>(trans_dims, ctx.GetPlace());

      ArgFullAssign<T, int64_t>(dev_ctx, &trans_dO, &trans_ind, &tmp_out,
                                input_height, input_width);

      // Transpose back to the original layout.
      TransCompute<platform::CUDADeviceContext, T>(ndims, dev_ctx, tmp_out, dX,
                                                   trans);
      return;
    }
  }
};

}  // namespace operators
}  // namespace paddle

@@ -243,3 +360,9 @@ REGISTER_OP_CUDA_KERNEL(
    paddle::operators::ArgsortOpCUDAKernel<int>,
    paddle::operators::ArgsortOpCUDAKernel<int64_t>,
    paddle::operators::ArgsortOpCUDAKernel<paddle::platform::float16>);
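// Register the gradient kernel for the same element types as the forward
// kernel.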
REGISTER_OP_CUDA_KERNEL(
    argsort_grad, paddle::operators::ArgsortGradOpCUDAKernel<float>,
    paddle::operators::ArgsortGradOpCUDAKernel<double>,
    paddle::operators::ArgsortGradOpCUDAKernel<int>,
    paddle::operators::ArgsortGradOpCUDAKernel<int64_t>,
    paddle::operators::ArgsortGradOpCUDAKernel<paddle::platform::float16>);