|
|
|
@ -74,10 +74,9 @@ class LookupTableCUDAKernel : public framework::OpKernel<T> {
|
|
|
|
|
|
|
|
|
|
dim3 threads(128, 8);
|
|
|
|
|
dim3 grids(8, 1);
|
|
|
|
|
LookupTable<T, 128, 8, 8><<<
|
|
|
|
|
grids, threads, 0, reinterpret_cast<const platform::CUDADeviceContext&>(
|
|
|
|
|
context.device_context())
|
|
|
|
|
.stream()>>>(output, table, ids, N, K, D);
|
|
|
|
|
LookupTable<T, 128, 8,
|
|
|
|
|
8><<<grids, threads, 0, context.device_context().stream()>>>(
|
|
|
|
|
output, table, ids, N, K, D);
|
|
|
|
|
}
|
|
|
|
|
};
|
|
|
|
|
|
|
|
|
@ -95,9 +94,7 @@ class LookupTableGradCUDAKernel : public framework::OpKernel<T> {
|
|
|
|
|
auto* ids_data = ids->data<int64_t>();
|
|
|
|
|
auto ids_dim = ids->dims();
|
|
|
|
|
|
|
|
|
|
auto stream = reinterpret_cast<const platform::CUDADeviceContext&>(
|
|
|
|
|
context.device_context())
|
|
|
|
|
.stream();
|
|
|
|
|
auto stream = context.cuda_device_context().stream();
|
|
|
|
|
// copy GPU memory to CPU pinned memory
|
|
|
|
|
framework::Vector<int64_t> new_rows;
|
|
|
|
|
new_rows.resize(ids_dim[0]);
|
|
|
|
@ -136,11 +133,10 @@ class LookupTableGradCUDAKernel : public framework::OpKernel<T> {
|
|
|
|
|
|
|
|
|
|
dim3 threads(128, 8);
|
|
|
|
|
dim3 grids(8, 1);
|
|
|
|
|
LookupTableGrad<T, 128, 8,
|
|
|
|
|
8><<<grids, threads, 0,
|
|
|
|
|
reinterpret_cast<const platform::CUDADeviceContext&>(
|
|
|
|
|
context.device_context())
|
|
|
|
|
.stream()>>>(d_table, d_output, ids, N, K, D);
|
|
|
|
|
LookupTableGrad<
|
|
|
|
|
T, 128, 8,
|
|
|
|
|
8><<<grids, threads, 0, context.device_context().stream()>>>(
|
|
|
|
|
d_table, d_output, ids, N, K, D);
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
};
|
|
|
|
|