|
|
|
|
@ -31,8 +31,8 @@ __global__ void LookupTable(T *output, const T *table, const int64_t *ids,
|
|
|
|
|
|
|
|
|
|
while (idy < K) {
|
|
|
|
|
int64_t id = ids[idy];
|
|
|
|
|
PADDLE_ASSERT(id >= 0);
|
|
|
|
|
PADDLE_ASSERT(id < N);
|
|
|
|
|
PADDLE_ASSERT_MSG_CODE(id >= 0, "received id:", id);
|
|
|
|
|
PADDLE_ASSERT_MSG_CODE(id < N, "received id:", id);
|
|
|
|
|
T *out = output + idy * D;
|
|
|
|
|
const T *tab = table + id * D;
|
|
|
|
|
for (int i = idx; i < D; i += BlockDimX) {
|
|
|
|
|
@ -57,9 +57,9 @@ __global__ void LookupTableGrad(T *table, const T *output, const int64_t *ids,
|
|
|
|
|
int idy = blockIdx.x + threadIdx.y * GridDimX;
|
|
|
|
|
|
|
|
|
|
while (idy < K) {
|
|
|
|
|
int id = ids[idy];
|
|
|
|
|
PADDLE_ASSERT(id >= 0);
|
|
|
|
|
PADDLE_ASSERT(id < N);
|
|
|
|
|
int64_t id = ids[idy];
|
|
|
|
|
PADDLE_ASSERT_MSG_CODE(id >= 0, "received id:", id);
|
|
|
|
|
PADDLE_ASSERT_MSG_CODE(id < N, "received id:", id);
|
|
|
|
|
const T *out = output + idy * D;
|
|
|
|
|
T *tab = table + id * D;
|
|
|
|
|
for (int i = idx; i < D; i += BlockDimX) {
|
|
|
|
|
|