|
|
@ -60,7 +60,9 @@ class SplitIdsOpKernel : public framework::OpKernel<T> {
|
|
|
|
} else if (ids_var->IsType<framework::SelectedRows>()) {
|
|
|
|
} else if (ids_var->IsType<framework::SelectedRows>()) {
|
|
|
|
const auto *ids_selected_rows = ctx.Input<framework::SelectedRows>("Ids");
|
|
|
|
const auto *ids_selected_rows = ctx.Input<framework::SelectedRows>("Ids");
|
|
|
|
auto &ids_dims = ids_selected_rows->value().dims();
|
|
|
|
auto &ids_dims = ids_selected_rows->value().dims();
|
|
|
|
PADDLE_ENFORCE_EQ(ids_dims[0], ids_selected_rows->rows().size(), "");
|
|
|
|
PADDLE_ENFORCE_EQ(ids_dims[0],
|
|
|
|
|
|
|
|
static_cast<int64_t>(ids_selected_rows->rows().size()),
|
|
|
|
|
|
|
|
"");
|
|
|
|
const T *ids = ids_selected_rows->value().data<T>();
|
|
|
|
const T *ids = ids_selected_rows->value().data<T>();
|
|
|
|
const auto &ids_rows = ids_selected_rows->rows();
|
|
|
|
const auto &ids_rows = ids_selected_rows->rows();
|
|
|
|
auto outs = ctx.MultiOutput<framework::SelectedRows>("Out");
|
|
|
|
auto outs = ctx.MultiOutput<framework::SelectedRows>("Out");
|
|
|
@ -77,7 +79,7 @@ class SplitIdsOpKernel : public framework::OpKernel<T> {
|
|
|
|
framework::DDim ddim = framework::make_ddim(
|
|
|
|
framework::DDim ddim = framework::make_ddim(
|
|
|
|
{static_cast<int64_t>(out->rows().size()), row_width});
|
|
|
|
{static_cast<int64_t>(out->rows().size()), row_width});
|
|
|
|
T *output = out->mutable_value()->mutable_data<T>(ddim, place);
|
|
|
|
T *output = out->mutable_value()->mutable_data<T>(ddim, place);
|
|
|
|
for (size_t i = 0; i < ddim[0]; ++i) {
|
|
|
|
for (int64_t i = 0; i < ddim[0]; ++i) {
|
|
|
|
memcpy(output + i * row_width, ids + out->rows()[i] * row_width,
|
|
|
|
memcpy(output + i * row_width, ids + out->rows()[i] * row_width,
|
|
|
|
row_width * sizeof(T));
|
|
|
|
row_width * sizeof(T));
|
|
|
|
}
|
|
|
|
}
|
|
|
|