|
|
|
@ -32,7 +32,8 @@ class SplitSelectedRowsOpKernel : public framework::OpKernel<T> {
|
|
|
|
|
|
|
|
|
|
auto abs_sections = ToAbsoluteSection(height_sections);
|
|
|
|
|
|
|
|
|
|
auto x_rows = x->rows();
|
|
|
|
|
auto& x_rows = x->rows();
|
|
|
|
|
auto height = x->height();
|
|
|
|
|
std::vector<std::vector<int>> outs_rows_idx;
|
|
|
|
|
std::vector<std::vector<int>> outs_dense_idx;
|
|
|
|
|
|
|
|
|
@ -44,8 +45,10 @@ class SplitSelectedRowsOpKernel : public framework::OpKernel<T> {
|
|
|
|
|
|
|
|
|
|
// split rows index into output sparse vars
|
|
|
|
|
for (size_t i = 0; i < x_rows.size(); ++i) {
|
|
|
|
|
int out_idx = FindOutIdx(x_rows[i], abs_sections);
|
|
|
|
|
outs_rows_idx[out_idx].push_back(x_rows[i]);
|
|
|
|
|
auto& id = x_rows[i];
|
|
|
|
|
PADDLE_ENFORCE_LT(id, height);
|
|
|
|
|
int out_idx = GetSectionIndex(id, abs_sections);
|
|
|
|
|
outs_rows_idx[out_idx].push_back(id);
|
|
|
|
|
outs_dense_idx[out_idx].push_back(i);
|
|
|
|
|
}
|
|
|
|
|
auto place = ctx.GetPlace();
|
|
|
|
@ -59,7 +62,9 @@ class SplitSelectedRowsOpKernel : public framework::OpKernel<T> {
|
|
|
|
|
outs[i]->mutable_rows()->clear();
|
|
|
|
|
if (rows_idx.size() > 0) {
|
|
|
|
|
for (auto idx : rows_idx) {
|
|
|
|
|
outs[i]->mutable_rows()->push_back(idx - abs_sections[i]);
|
|
|
|
|
auto id_offset = idx - abs_sections[i];
|
|
|
|
|
PADDLE_ENFORCE_LT(id_offset, height_sections[i]);
|
|
|
|
|
outs[i]->mutable_rows()->push_back(id_offset);
|
|
|
|
|
}
|
|
|
|
|
auto dst = outs[i]->mutable_value()->mutable_data<T>(ctx.GetPlace());
|
|
|
|
|
for (size_t j = 0; j < rows_idx.size(); j++) {
|
|
|
|
|