|
|
|
@ -21,6 +21,7 @@ limitations under the License. */
|
|
|
|
|
#include "paddle/fluid/framework/lod_tensor.h"
|
|
|
|
|
#include "paddle/fluid/framework/op_registry.h"
|
|
|
|
|
#include "paddle/fluid/framework/selected_rows.h"
|
|
|
|
|
#include "paddle/fluid/operators/jit/kernels.h"
|
|
|
|
|
#include "paddle/fluid/operators/math/blas.h"
|
|
|
|
|
|
|
|
|
|
namespace paddle {
|
|
|
|
@ -31,35 +32,6 @@ using LoDTensor = framework::LoDTensor;
|
|
|
|
|
using SelectedRows = framework::SelectedRows;
|
|
|
|
|
using DDim = framework::DDim;
|
|
|
|
|
|
|
|
|
|
template <typename T>
|
|
|
|
|
void emb_seqpool(const framework::ExecutionContext &context, const T *table,
|
|
|
|
|
const int64_t *idx, T *out, int64_t table_height,
|
|
|
|
|
int64_t table_width, int64_t idx_height, int64_t idx_width,
|
|
|
|
|
int64_t out_width) { // pool type == sum
|
|
|
|
|
PADDLE_ENFORCE_EQ(table_width * idx_width, out_width);
|
|
|
|
|
|
|
|
|
|
auto check_idx_value_valid = [&](int i) {
|
|
|
|
|
PADDLE_ENFORCE_LT(idx[i], table_height, "idx value: %d, i: %d", idx[i], i);
|
|
|
|
|
PADDLE_ENFORCE_GE(idx[i], 0, "idx value: %d, i: %d", idx[i], i);
|
|
|
|
|
};
|
|
|
|
|
auto blas = math::GetBlas<platform::CPUDeviceContext, T>(context);
|
|
|
|
|
|
|
|
|
|
for (int w = 0; w != idx_width; ++w) {
|
|
|
|
|
check_idx_value_valid(w);
|
|
|
|
|
blas.VCOPY(table_width, table + idx[w] * table_width,
|
|
|
|
|
out + w * table_width);
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
for (int h = 1; h < idx_height; ++h) {
|
|
|
|
|
for (int w = 0; w < idx_width; ++w) {
|
|
|
|
|
int i = h * idx_width + w;
|
|
|
|
|
check_idx_value_valid(i);
|
|
|
|
|
blas.AXPY(table_width, static_cast<T>(1), table + idx[i] * table_width,
|
|
|
|
|
out + w * table_width);
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
template <typename T>
|
|
|
|
|
struct EmbeddingVSumFunctor {
|
|
|
|
|
void operator()(const framework::ExecutionContext &context,
|
|
|
|
@ -75,10 +47,15 @@ struct EmbeddingVSumFunctor {
|
|
|
|
|
auto *output = output_t->mutable_data<T>(context.GetPlace());
|
|
|
|
|
|
|
|
|
|
PADDLE_ENFORCE_LE(table_width * idx_width, out_width);
|
|
|
|
|
|
|
|
|
|
jit::emb_seq_pool_attr_t attr(table_height, table_width, 0, idx_width,
|
|
|
|
|
out_width, jit::SeqPoolType::kSum);
|
|
|
|
|
for (int64_t i = 0; i != ids_lod.size() - 1; ++i) {
|
|
|
|
|
emb_seqpool(context, table, ids + ids_lod[i] * idx_width,
|
|
|
|
|
output + i * out_width, table_height, table_width,
|
|
|
|
|
ids_lod[i + 1] - ids_lod[i], idx_width, out_width);
|
|
|
|
|
attr.index_height = ids_lod[i + 1] - ids_lod[i];
|
|
|
|
|
auto emb_seqpool = jit::Get<jit::kEmbSeqPool, jit::EmbSeqPoolTuples<T>,
|
|
|
|
|
platform::CPUPlace>(attr);
|
|
|
|
|
emb_seqpool(table, ids + ids_lod[i] * idx_width, output + i * out_width,
|
|
|
|
|
&attr);
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
};
|
|
|
|
|