|
|
@ -53,25 +53,27 @@ struct SequenceExpandFunctor<platform::CPUDeviceContext, T> {
|
|
|
|
const framework::Vector<size_t>& ref_lod, /*expand referenced lod*/
|
|
|
|
const framework::Vector<size_t>& ref_lod, /*expand referenced lod*/
|
|
|
|
LoDTensor* out) {
|
|
|
|
LoDTensor* out) {
|
|
|
|
int out_offset = 0;
|
|
|
|
int out_offset = 0;
|
|
|
|
auto& eigen_place = *context.eigen_device();
|
|
|
|
int x_item_length = x.numel() / x.dims()[0];
|
|
|
|
|
|
|
|
auto out_data = out->data<T>();
|
|
|
|
|
|
|
|
auto x_data = x.data<T>();
|
|
|
|
for (size_t i = 1; i < ref_lod.size(); ++i) {
|
|
|
|
for (size_t i = 1; i < ref_lod.size(); ++i) {
|
|
|
|
int repeat_num = ref_lod[i] - ref_lod[i - 1];
|
|
|
|
int repeat_num = ref_lod[i] - ref_lod[i - 1];
|
|
|
|
int x_start = x_lod[i - 1];
|
|
|
|
int x_start = x_lod[i - 1];
|
|
|
|
int x_end = x_lod[i];
|
|
|
|
int x_end = x_lod[i];
|
|
|
|
int x_seq_len = x_end - x_start;
|
|
|
|
int x_seq_len = x_end - x_start;
|
|
|
|
if (repeat_num > 0) {
|
|
|
|
if (repeat_num > 0) {
|
|
|
|
auto x_sub_tensor = x.Slice(x_start, x_end);
|
|
|
|
|
|
|
|
x_sub_tensor.Resize({1, x_sub_tensor.numel()});
|
|
|
|
|
|
|
|
int out_start = out_offset;
|
|
|
|
int out_start = out_offset;
|
|
|
|
if (out->lod().size() == 1) {
|
|
|
|
if (out->lod().size() == 1) {
|
|
|
|
out_start = out->lod()[0][out_offset];
|
|
|
|
out_start = out->lod()[0][out_offset];
|
|
|
|
}
|
|
|
|
}
|
|
|
|
auto out_sub_tensor =
|
|
|
|
for (int j = 0; j < repeat_num; j++) {
|
|
|
|
out->Slice(out_start, out_start + x_seq_len * repeat_num);
|
|
|
|
for (int k = 0; k < x_seq_len; k++) {
|
|
|
|
out_sub_tensor.Resize({repeat_num, x_sub_tensor.dims()[1]});
|
|
|
|
for (int l = 0; l < x_item_length; l++) {
|
|
|
|
EigenMatrix<T>::From(out_sub_tensor).device(eigen_place) =
|
|
|
|
out_data[(out_start + j * x_seq_len + k) * x_item_length + l] =
|
|
|
|
EigenMatrix<T>::From(x_sub_tensor)
|
|
|
|
x_data[(x_start + k) * x_item_length + l];
|
|
|
|
.broadcast(Eigen::array<int, 2>({{repeat_num, 1}}));
|
|
|
|
}
|
|
|
|
|
|
|
|
}
|
|
|
|
|
|
|
|
}
|
|
|
|
}
|
|
|
|
}
|
|
|
|
out_offset += repeat_num;
|
|
|
|
out_offset += repeat_num;
|
|
|
|
}
|
|
|
|
}
|
|
|
|