|
|
|
@ -31,38 +31,54 @@ using LoDTensor = framework::LoDTensor;
|
|
|
|
|
// Short alias for framework::SelectedRows.
using SelectedRows = framework::SelectedRows;
|
|
|
|
|
// Short alias for framework::DDim.
using DDim = framework::DDim;
|
|
|
|
|
|
|
|
|
|
template <typename T>
|
|
|
|
|
void emb_seqpool(const framework::ExecutionContext &context, const T *table,
|
|
|
|
|
const int64_t *idx, T *out, int64_t table_height,
|
|
|
|
|
int64_t table_width, int64_t idx_height, int64_t idx_width,
|
|
|
|
|
int64_t out_width) { // pool type == sum
|
|
|
|
|
PADDLE_ENFORCE_EQ(table_width * idx_width, out_width);
|
|
|
|
|
|
|
|
|
|
auto check_idx_value_valid = [&](int i) {
|
|
|
|
|
PADDLE_ENFORCE_LT(idx[i], table_height, "idx value: %d, i: %d", idx[i], i);
|
|
|
|
|
PADDLE_ENFORCE_GE(idx[i], 0, "idx value: %d, i: %d", idx[i], i);
|
|
|
|
|
};
|
|
|
|
|
auto blas = math::GetBlas<platform::CPUDeviceContext, T>(context);
|
|
|
|
|
|
|
|
|
|
for (int w = 0; w != idx_width; ++w) {
|
|
|
|
|
check_idx_value_valid(w);
|
|
|
|
|
blas.VCOPY(table_width, table + idx[w] * table_width,
|
|
|
|
|
out + w * table_width);
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
for (int h = 1; h < idx_height; ++h) {
|
|
|
|
|
for (int w = 0; w < idx_width; ++w) {
|
|
|
|
|
int i = h * idx_width + w;
|
|
|
|
|
check_idx_value_valid(i);
|
|
|
|
|
blas.AXPY(table_width, static_cast<T>(1), table + idx[i] * table_width,
|
|
|
|
|
out + w * table_width);
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
template <typename T>
|
|
|
|
|
struct EmbeddingVSumFunctor {
|
|
|
|
|
void operator()(const framework::ExecutionContext &context,
|
|
|
|
|
const LoDTensor *table_t, const LoDTensor *ids_t,
|
|
|
|
|
LoDTensor *output_t) {
|
|
|
|
|
auto *table = table_t->data<T>();
|
|
|
|
|
int64_t row_number = table_t->dims()[0];
|
|
|
|
|
int64_t row_width = table_t->dims()[1];
|
|
|
|
|
int64_t last_dim = output_t->dims()[1];
|
|
|
|
|
int64_t table_height = table_t->dims()[0];
|
|
|
|
|
int64_t table_width = table_t->dims()[1];
|
|
|
|
|
int64_t out_width = output_t->dims()[1];
|
|
|
|
|
const int64_t *ids = ids_t->data<int64_t>();
|
|
|
|
|
auto ids_lod = ids_t->lod()[0];
|
|
|
|
|
int64_t ids_count = ids_t->numel() / ids_lod.back();
|
|
|
|
|
|
|
|
|
|
int64_t idx_width = ids_t->numel() / ids_lod.back();
|
|
|
|
|
auto *output = output_t->mutable_data<T>(context.GetPlace());
|
|
|
|
|
|
|
|
|
|
auto blas = math::GetBlas<platform::CPUDeviceContext, T>(context);
|
|
|
|
|
PADDLE_ENFORCE_LE(table_width * idx_width, out_width);
|
|
|
|
|
for (int64_t i = 0; i != ids_lod.size() - 1; ++i) {
|
|
|
|
|
size_t begin = ids_lod[i] * ids_count;
|
|
|
|
|
for (int64_t j = 0; j != ids_count; ++j) {
|
|
|
|
|
PADDLE_ENFORCE_LT(ids[begin], row_number);
|
|
|
|
|
PADDLE_ENFORCE_GE(ids[begin], 0, "ids %d", i);
|
|
|
|
|
blas.VCOPY(row_width, table + ids[begin + j] * row_width,
|
|
|
|
|
output + i * last_dim + j * row_width);
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
for (int64_t r = (ids_lod[i] + 1) * ids_count;
|
|
|
|
|
r < ids_lod[i + 1] * ids_count; ++r) {
|
|
|
|
|
PADDLE_ENFORCE_LT(ids[r], row_number);
|
|
|
|
|
PADDLE_ENFORCE_GE(ids[r], 0, "ids %d", i);
|
|
|
|
|
blas.AXPY(row_width, 1., table + ids[r] * row_width,
|
|
|
|
|
output + i * last_dim + (r % ids_count) * row_width);
|
|
|
|
|
}
|
|
|
|
|
emb_seqpool(context, table, ids + ids_lod[i] * idx_width,
|
|
|
|
|
output + i * out_width, table_height, table_width,
|
|
|
|
|
ids_lod[i + 1] - ids_lod[i], idx_width, out_width);
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
};
|
|
|
|
|