|
|
@ -27,30 +27,47 @@ class SequenceEnumerateKernel : public framework::OpKernel<T> {
|
|
|
|
auto* in = context.Input<LoDTensor>("X");
|
|
|
|
auto* in = context.Input<LoDTensor>("X");
|
|
|
|
auto* out = context.Output<LoDTensor>("Out");
|
|
|
|
auto* out = context.Output<LoDTensor>("Out");
|
|
|
|
int win_size = context.Attr<int>("win_size");
|
|
|
|
int win_size = context.Attr<int>("win_size");
|
|
|
|
int pad_value = context.Attr<int>("pad_value");
|
|
|
|
auto pad_value = static_cast<T>(context.Attr<int>("pad_value"));
|
|
|
|
|
|
|
|
|
|
|
|
auto in_dims = in->dims();
|
|
|
|
auto in_dims = in->dims();
|
|
|
|
auto in_lod = in->lod();
|
|
|
|
auto lod0 = in->lod()[0];
|
|
|
|
|
|
|
|
|
|
|
|
PADDLE_ENFORCE_EQ(
|
|
|
|
PADDLE_ENFORCE_EQ(
|
|
|
|
static_cast<uint64_t>(in_dims[0]), in_lod[0].back(),
|
|
|
|
static_cast<uint64_t>(in_dims[0]), lod0.back(),
|
|
|
|
"The actual input data's size mismatched with LoD information.");
|
|
|
|
"The actual input data's size mismatched with LoD information.");
|
|
|
|
|
|
|
|
PADDLE_ENFORCE_EQ(
|
|
|
|
|
|
|
|
in_dims.size(), 2UL,
|
|
|
|
|
|
|
|
"Input(X) of SequenceEnumerate operator's rank should be 2.");
|
|
|
|
|
|
|
|
PADDLE_ENFORCE_EQ(in_dims[1], 1,
|
|
|
|
|
|
|
|
"Input(X) of SequenceEnumerate operator's 2nd "
|
|
|
|
|
|
|
|
"dimension should be 1.");
|
|
|
|
|
|
|
|
|
|
|
|
// Generate enumerate sequence set
|
|
|
|
// Generate enumerate sequence set
|
|
|
|
auto lod0 = in_lod[0];
|
|
|
|
|
|
|
|
auto in_data = in->data<T>();
|
|
|
|
auto in_data = in->data<T>();
|
|
|
|
out->Resize({in_dims[0], win_size});
|
|
|
|
out->Resize({in_dims[0], win_size});
|
|
|
|
|
|
|
|
out->set_lod(in->lod());
|
|
|
|
auto out_data = out->mutable_data<T>(context.GetPlace());
|
|
|
|
auto out_data = out->mutable_data<T>(context.GetPlace());
|
|
|
|
for (size_t i = 0; i < lod0.size() - 1; ++i) {
|
|
|
|
for (size_t i = 0; i < lod0.size() - 1; ++i) {
|
|
|
|
for (size_t idx = lod0[i]; idx < lod0[i + 1]; ++idx) {
|
|
|
|
int start = lod0[i];
|
|
|
|
for (int word_idx = 0; word_idx < win_size; ++word_idx) {
|
|
|
|
int end = lod0[i + 1];
|
|
|
|
size_t word_pos = idx + word_idx;
|
|
|
|
int copy_size = win_size < end - start + 1 ? win_size : end - start + 1;
|
|
|
|
out_data[win_size * idx + word_idx] =
|
|
|
|
int mid = end + 1 - copy_size;
|
|
|
|
word_pos < lod0[i + 1] ? in_data[word_pos] : pad_value;
|
|
|
|
int pad_num = win_size - copy_size;
|
|
|
|
|
|
|
|
copy_size *= sizeof(T);
|
|
|
|
|
|
|
|
for (int idx = start; idx < mid; ++idx) {
|
|
|
|
|
|
|
|
std::memcpy(out_data, in_data + idx, copy_size);
|
|
|
|
|
|
|
|
out_data += win_size;
|
|
|
|
|
|
|
|
}
|
|
|
|
|
|
|
|
for (int idx = mid; idx < end; ++idx) {
|
|
|
|
|
|
|
|
copy_size -= sizeof(T);
|
|
|
|
|
|
|
|
pad_num++;
|
|
|
|
|
|
|
|
std::memcpy(out_data, in_data + idx, copy_size);
|
|
|
|
|
|
|
|
T* pdata = out_data + copy_size / sizeof(T);
|
|
|
|
|
|
|
|
for (int i = 0; i < pad_num; ++i) {
|
|
|
|
|
|
|
|
pdata[i] = pad_value;
|
|
|
|
}
|
|
|
|
}
|
|
|
|
|
|
|
|
out_data += win_size;
|
|
|
|
}
|
|
|
|
}
|
|
|
|
}
|
|
|
|
}
|
|
|
|
out->set_lod(in->lod());
|
|
|
|
|
|
|
|
}
|
|
|
|
}
|
|
|
|
};
|
|
|
|
};
|
|
|
|
|
|
|
|
|
|
|
|