mobile_baidu
typhoonzero 8 years ago
parent 579c92abc3
commit 00360e7eb5

@ -74,8 +74,9 @@ class LookupTableCUDAKernel : public framework::OpKernel<T> {
dim3 threads(128, 8); dim3 threads(128, 8);
dim3 grids(8, 1); dim3 grids(8, 1);
LookupTable<T, 128, 8, LookupTable<
8><<<grids, threads, 0, context.device_context().stream()>>>( T, 128, 8,
8><<<grids, threads, 0, context.cuda_device_context().stream()>>>(
output, table, ids, N, K, D); output, table, ids, N, K, D);
} }
}; };
@ -135,7 +136,7 @@ class LookupTableGradCUDAKernel : public framework::OpKernel<T> {
dim3 grids(8, 1); dim3 grids(8, 1);
LookupTableGrad< LookupTableGrad<
T, 128, 8, T, 128, 8,
8><<<grids, threads, 0, context.device_context().stream()>>>( 8><<<grids, threads, 0, context.cuda_device_context().stream()>>>(
d_table, d_output, ids, N, K, D); d_table, d_output, ids, N, K, D);
} }
} }

Loading…
Cancel
Save