@ -37,7 +37,7 @@ __global__ void LookupTable(T* output, const T* table, const int64_t* ids,
const T* tab = table + id * D;
for (int i = idx; i < D; i += BlockDimX) {
if (PaddingFlag) {
if (idx == padding_idx)
if (id == padding_idx)
out[i] = static_cast<T>(0);
else
out[i] = tab[i];