|
|
|
@ -22,7 +22,6 @@ limitations under the License. */
|
|
|
|
|
#include "paddle/fluid/framework/op_registry.h"
|
|
|
|
|
#include "paddle/fluid/framework/selected_rows.h"
|
|
|
|
|
#include "paddle/fluid/operators/jit/kernels.h"
|
|
|
|
|
#include "paddle/fluid/operators/math/blas.h"
|
|
|
|
|
|
|
|
|
|
namespace paddle {
|
|
|
|
|
namespace operators {
|
|
|
|
@ -47,7 +46,7 @@ struct EmbeddingVSumFunctor {
|
|
|
|
|
auto *output = output_t->mutable_data<T>(context.GetPlace());
|
|
|
|
|
|
|
|
|
|
PADDLE_ENFORCE_LE(table_width * idx_width, out_width);
|
|
|
|
|
PADDLE_ENFORCE_GT(ids_lod.size(), 1UL);
|
|
|
|
|
PADDLE_ENFORCE_GT(ids_lod.size(), 1UL, "The LoD[0] could NOT be empty");
|
|
|
|
|
|
|
|
|
|
jit::emb_seq_pool_attr_t attr(table_height, table_width, 0, idx_width,
|
|
|
|
|
out_width, jit::SeqPoolType::kSum);
|
|
|
|
@ -83,11 +82,11 @@ class FusedEmbeddingSeqPoolKernel : public framework::OpKernel<T> {
|
|
|
|
|
FusedEmbeddingSeqPoolLastDim(table_var->dims(), ids_t->dims());
|
|
|
|
|
const auto &ids_lod = ids_t->lod();
|
|
|
|
|
// in run time, the LoD of ids must be 1
|
|
|
|
|
PADDLE_ENFORCE(ids_lod.size(), 1u, "The LoD level of Input(Ids) must be 1");
|
|
|
|
|
PADDLE_ENFORCE_GE(ids_lod[0].size(), 1u, "The LoD could NOT be empty");
|
|
|
|
|
PADDLE_ENFORCE(ids_lod.size(), 1UL,
|
|
|
|
|
"The LoD level of Input(Ids) must be 1");
|
|
|
|
|
int64_t batch_size = ids_lod[0].size() - 1;
|
|
|
|
|
// in run time, the shape from Ids -> output
|
|
|
|
|
// should be [seq_length, 1] -> [batch_size, embedding_size]
|
|
|
|
|
// should be [seq_length, 1] -> [batch_size, last_dim]
|
|
|
|
|
output_t->Resize({batch_size, last_dim});
|
|
|
|
|
|
|
|
|
|
if (combiner_type == "sum") {
|
|
|
|
@ -125,7 +124,7 @@ class FusedEmbeddingSeqPoolGradKernel : public framework::OpKernel<T> {
|
|
|
|
|
auto *ids_data = ids->data<int64_t>();
|
|
|
|
|
int64_t ids_num = ids->numel();
|
|
|
|
|
auto lod = ids->lod()[0];
|
|
|
|
|
int64_t row_width = d_output->dims()[1];
|
|
|
|
|
int64_t out_width = d_output->dims()[1];
|
|
|
|
|
|
|
|
|
|
framework::Vector<int64_t> *new_rows = d_table->mutable_rows();
|
|
|
|
|
new_rows->resize(ids_num);
|
|
|
|
@ -136,15 +135,13 @@ class FusedEmbeddingSeqPoolGradKernel : public framework::OpKernel<T> {
|
|
|
|
|
T *d_table_data = d_table_value->mutable_data<T>(context.GetPlace());
|
|
|
|
|
const T *d_output_data = d_output->data<T>();
|
|
|
|
|
|
|
|
|
|
auto blas = math::GetBlas<platform::CPUDeviceContext, T>(context);
|
|
|
|
|
auto vbroadcast = jit::Get<jit::kVBroadcast, jit::VBroadcastTuples<T>,
|
|
|
|
|
platform::CPUPlace>(out_width);
|
|
|
|
|
for (int i = 0; i < static_cast<int>(lod.size()) - 1; ++i) {
|
|
|
|
|
int64_t h = static_cast<int64_t>(lod[i + 1] - lod[i]);
|
|
|
|
|
int64_t in_offset = lod[i] * row_width;
|
|
|
|
|
const T *out_pos = d_output_data + i * row_width;
|
|
|
|
|
T *in_pos = d_table_data + in_offset;
|
|
|
|
|
for (int r = 0; r != h; ++r) {
|
|
|
|
|
blas.VCOPY(row_width, out_pos, in_pos + r * row_width);
|
|
|
|
|
}
|
|
|
|
|
const T *src = d_output_data + i * out_width;
|
|
|
|
|
T *dst = d_table_data + lod[i] * out_width;
|
|
|
|
|
vbroadcast(src, dst, h, out_width);
|
|
|
|
|
}
|
|
|
|
|
} else {
|
|
|
|
|
LOG(ERROR) << "Dense is not supported in fused_embedding_seq_pool_op now";
|
|
|
|
|