|
|
|
@ -37,7 +37,7 @@ __global__ void LookupTable(T* output, const T* table, const int64_t* ids,
|
|
|
|
const T* tab = table + id * D;
|
|
|
|
const T* tab = table + id * D;
|
|
|
|
for (int i = idx; i < D; i += BlockDimX) {
|
|
|
|
for (int i = idx; i < D; i += BlockDimX) {
|
|
|
|
if (PaddingFlag) {
|
|
|
|
if (PaddingFlag) {
|
|
|
|
if (idx == padding_idx)
|
|
|
|
if (id == padding_idx)
|
|
|
|
out[i] = static_cast<T>(0);
|
|
|
|
out[i] = static_cast<T>(0);
|
|
|
|
else
|
|
|
|
else
|
|
|
|
out[i] = tab[i];
|
|
|
|
out[i] = tab[i];
|
|
|
|
|