|
|
|
@ -72,10 +72,11 @@ class SplitSelectedRowsOpKernel : public framework::OpKernel<T> {
|
|
|
|
|
for (size_t i = 0; i < outs_rows_idx.size(); ++i) {
|
|
|
|
|
auto rows_idx = outs_rows_idx[i];
|
|
|
|
|
outs[i]->set_height(height_sections[i]);
|
|
|
|
|
auto dims = x->GetCompleteDims();
|
|
|
|
|
dims[0] = rows_idx.size();
|
|
|
|
|
outs[i]->mutable_value()->mutable_data<T>(dims, x->place());
|
|
|
|
|
outs[i]->mutable_rows()->clear();
|
|
|
|
|
if (rows_idx.size() > 0) {
|
|
|
|
|
auto dims = x->GetCompleteDims();
|
|
|
|
|
dims[0] = rows_idx.size();
|
|
|
|
|
outs[i]->mutable_value()->mutable_data<T>(dims, x->place());
|
|
|
|
|
for (auto idx : rows_idx) {
|
|
|
|
|
outs[i]->mutable_rows()->push_back(idx - abs_sections[i]);
|
|
|
|
|
}
|
|
|
|
@ -98,6 +99,8 @@ class SplitSelectedRowsOpKernel : public framework::OpKernel<T> {
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
PADDLE_ENFORCE_EQ(rows_idx.size(), outs[i]->rows().size(),
|
|
|
|
|
"rows should has the same size with tensor dim 0");
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
};
|
|
|
|
|