mobile_baidu
typhoonzero 8 years ago
parent 579c92abc3
commit 00360e7eb5

@ -74,8 +74,9 @@ class LookupTableCUDAKernel : public framework::OpKernel<T> {
dim3 threads(128, 8);
dim3 grids(8, 1);
LookupTable<T, 128, 8,
8><<<grids, threads, 0, context.device_context().stream()>>>(
LookupTable<
T, 128, 8,
8><<<grids, threads, 0, context.cuda_device_context().stream()>>>(
output, table, ids, N, K, D);
}
};
@ -135,7 +136,7 @@ class LookupTableGradCUDAKernel : public framework::OpKernel<T> {
dim3 grids(8, 1);
LookupTableGrad<
T, 128, 8,
8><<<grids, threads, 0, context.device_context().stream()>>>(
8><<<grids, threads, 0, context.cuda_device_context().stream()>>>(
d_table, d_output, ids, N, K, D);
}
}

Loading…
Cancel
Save