@@ -46,9 +46,9 @@ __global__ void GatherCUDAKernel(const T* params, const int* indices, T* output,
  * return: output tensor
  */
 template <typename T>
-void GPUGather(const Place& place, const Tensor* src, const Tensor* index,
-               Tensor* output) {
-  PADDLE_ENFORCE(platform::is_gpu_place(place));
+void GPUGather(const platform::DeviceContext& ctx, const Tensor* src,
+               const Tensor* index, Tensor* output) {
+  // PADDLE_ENFORCE(platform::is_gpu_place(place));
   // check index of shape 1-D
   PADDLE_ENFORCE(index->dims().size() == 1);
   int index_size = index->dims()[0];
@@ -68,8 +68,11 @@ void GPUGather(const Place& place, const Tensor* src, const Tensor* index,
   int block = 512;
   int n = slice_size * index_size;
   int grid = (n + block - 1) / block;
-  GatherCUDAKernel<T><<<grid, block>>>(p_src, p_index, p_output, index_size,
-                                       slice_size);
+
+  GatherCUDAKernel<T><<<
+      grid, block, 0,
+      reinterpret_cast<const platform::CUDADeviceContext&>(ctx).stream()>>>(
+      p_src, p_index, p_output, index_size, slice_size);
 }
 
 }  // namespace operators
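For reference, here is a minimal standalone sketch of the launch pattern the second hunk adopts: size the grid with ceiling division, then pass an explicit `cudaStream_t` as the fourth launch-config argument (the third is dynamic shared memory, 0 here), so the kernel is ordered with other work on that stream instead of the legacy default stream. The kernel body, the `GatherKernel` name, and the host-side scaffolding below are illustrative assumptions based on the `params`/`indices`/`output` signature in the hunk header, not Paddle code:

```cuda
#include <cstdio>
#include <cuda_runtime.h>

// Hypothetical stand-in for GatherCUDAKernel: thread i copies one element
// of a selected row of params into output.
template <typename T>
__global__ void GatherKernel(const T* params, const int* indices, T* output,
                             int index_size, int slice_size) {
  int i = blockIdx.x * blockDim.x + threadIdx.x;
  if (i < index_size * slice_size) {
    int indices_i = i / slice_size;  // which gathered row
    int slice_i = i % slice_size;    // offset inside that row
    output[i] = params[indices[indices_i] * slice_size + slice_i];
  }
}

int main() {
  const int index_size = 2, slice_size = 4, n = index_size * slice_size;
  float h_params[12] = {0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11};  // 3 rows x 4
  int h_indices[index_size] = {2, 0};  // gather row 2, then row 0
  float h_out[n];

  float *d_params, *d_out;
  int* d_indices;
  cudaMalloc(&d_params, sizeof(h_params));
  cudaMalloc(&d_indices, sizeof(h_indices));
  cudaMalloc(&d_out, sizeof(h_out));
  cudaMemcpy(d_params, h_params, sizeof(h_params), cudaMemcpyHostToDevice);
  cudaMemcpy(d_indices, h_indices, sizeof(h_indices), cudaMemcpyHostToDevice);

  // The pattern from the diff: ceiling-division grid size, explicit stream.
  cudaStream_t stream;
  cudaStreamCreate(&stream);
  int block = 512;
  int grid = (n + block - 1) / block;
  GatherKernel<float><<<grid, block, 0, stream>>>(d_params, d_indices, d_out,
                                                  index_size, slice_size);
  cudaStreamSynchronize(stream);

  cudaMemcpy(h_out, d_out, sizeof(h_out), cudaMemcpyDeviceToHost);
  for (int i = 0; i < n; ++i) printf("%g ", h_out[i]);  // 8 9 10 11 0 1 2 3
  printf("\n");

  cudaStreamDestroy(stream);
  cudaFree(d_params);
  cudaFree(d_indices);
  cudaFree(d_out);
  return 0;
}
```

In the diff itself, `reinterpret_cast<const platform::CUDADeviceContext&>(ctx).stream()` plays the role of `stream` here: `GPUGather` now receives the generic `platform::DeviceContext` and downcasts to reach the CUDA stream owned by that context, and since `place` is no longer a parameter, the first hunk comments out the old `is_gpu_place` check rather than porting it.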