|
|
|
@ -49,10 +49,16 @@ __global__ void GatherCUDAKernel(const T* params, const IndexT* indices,
|
|
|
|
|
template <typename T, typename IndexT = int>
|
|
|
|
|
void GPUGather(const platform::DeviceContext& ctx, const Tensor& src,
|
|
|
|
|
const Tensor& index, Tensor* output) {
|
|
|
|
|
// PADDLE_ENFORCE(platform::is_gpu_place(place));
|
|
|
|
|
// check that index is 1-D, or 2-D with a second dimension of 1
|
|
|
|
|
PADDLE_ENFORCE(index.dims().size() == 1 ||
|
|
|
|
|
(index.dims().size() == 2 && index.dims()[1] == 1));
|
|
|
|
|
if (index.dims().size() == 1) {
|
|
|
|
|
PADDLE_ENFORCE_GT(index.dims()[0], 0,
|
|
|
|
|
"The index of gather_op should not be empty when the "
|
|
|
|
|
"index's rank is 1.");
|
|
|
|
|
} else if (index.dims().size() == 2) {
|
|
|
|
|
PADDLE_ENFORCE_EQ(index.dims()[1], 1,
|
|
|
|
|
" If the index's rank of gather_op is 2, the second "
|
|
|
|
|
"dimension should be 1.");
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
int index_size = index.dims()[0];
|
|
|
|
|
|
|
|
|
|