|
|
@ -46,8 +46,8 @@ class SequenceReshapeKernel : public framework::OpKernel<T> {
|
|
|
|
} else {
|
|
|
|
} else {
|
|
|
|
auto& out_lod = *out->mutable_lod();
|
|
|
|
auto& out_lod = *out->mutable_lod();
|
|
|
|
out_lod.resize(1);
|
|
|
|
out_lod.resize(1);
|
|
|
|
out_lod[0].clear();
|
|
|
|
out_lod[0].resize(seq_num + 1);
|
|
|
|
out_lod[0].push_back(0);
|
|
|
|
out_lod[0][0] = 0;
|
|
|
|
for (int i = 0; i < seq_num; ++i) {
|
|
|
|
for (int i = 0; i < seq_num; ++i) {
|
|
|
|
size_t seq_len = in_lod_l0[i + 1] - in_lod_l0[i];
|
|
|
|
size_t seq_len = in_lod_l0[i + 1] - in_lod_l0[i];
|
|
|
|
size_t offset = 0;
|
|
|
|
size_t offset = 0;
|
|
|
@ -57,11 +57,10 @@ class SequenceReshapeKernel : public framework::OpKernel<T> {
|
|
|
|
"be divided by new_dim with no remainder for each "
|
|
|
|
"be divided by new_dim with no remainder for each "
|
|
|
|
"sequence. The %dth sequence is invalid.",
|
|
|
|
"sequence. The %dth sequence is invalid.",
|
|
|
|
i + 1);
|
|
|
|
i + 1);
|
|
|
|
out_lod[0].push_back(out_lod[0].back() + offset);
|
|
|
|
out_lod[0][i + 1] = out_lod[0][i] + offset;
|
|
|
|
}
|
|
|
|
}
|
|
|
|
}
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
|
|
out->mutable_data<T>(context.GetPlace());
|
|
|
|
|
|
|
|
framework::Copy(*in, context.GetPlace(), out);
|
|
|
|
framework::Copy(*in, context.GetPlace(), out);
|
|
|
|
out->Resize({static_cast<int64_t>(out->lod()[0].back()), out_width});
|
|
|
|
out->Resize({static_cast<int64_t>(out->lod()[0].back()), out_width});
|
|
|
|
}
|
|
|
|
}
|
|
|
|