@@ -30,12 +30,17 @@ class SplitIdsOpKernel : public framework::OpKernel<T> {
   void Compute(const framework::ExecutionContext &ctx) const override {
     auto place = ctx.GetPlace();
     if (!platform::is_cpu_place(place)) {
-      PADDLE_THROW("SplitIds do not support GPU kernel");
+      PADDLE_THROW(platform::errors::Unimplemented(
+          "SplitIds do not support GPU kernel"));
     }
 
     const auto ids_vars = ctx.MultiInputVar("Ids");
 
-    PADDLE_ENFORCE_GT(ids_vars.size(), 0, "The number of Ids should > 0");
+    PADDLE_ENFORCE_GT(
+        ids_vars.size(), 0,
+        platform::errors::InvalidArgument(
+            "The number of Ids expected > 0, but got %d",
+            ids_vars.size()));
     auto *ids_var = ids_vars[0];
 
     if (ids_var->IsType<framework::LoDTensor>()) {
@@ -83,9 +88,6 @@ class SplitIdsOpKernel : public framework::OpKernel<T> {
     } else if (ids_var->IsType<framework::SelectedRows>()) {
       const auto *ids_selected_rows = ctx.Input<framework::SelectedRows>("Ids");
       auto &ids_dims = ids_selected_rows->value().dims();
-      PADDLE_ENFORCE_EQ(ids_dims[0],
-                        static_cast<int64_t>(ids_selected_rows->rows().size()),
-                        "");
       const T *ids_data = ids_selected_rows->value().data<T>();
       const auto &ids_rows = ids_selected_rows->rows();
       auto outs = ctx.MultiOutput<framework::SelectedRows>("Out");
@@ -114,9 +116,9 @@ class SplitIdsOpKernel : public framework::OpKernel<T> {
         }
       }
     } else {
-      PADDLE_THROW(
+      PADDLE_THROW(platform::errors::InvalidArgument(
           "% should be LoDTensor or SelectedRows, but the received type is %s",
-          ctx.InputNames("Ids")[0], framework::ToTypeName(ids_var->Type()));
+          ctx.InputNames("Ids")[0], framework::ToTypeName(ids_var->Type())));
     }
   }
 };
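
Note (not part of the diff): a minimal sketch of the error-reporting pattern this change adopts, assuming the PADDLE_* macros and platform::errors builders that ship with PaddlePaddle (typically available via paddle/fluid/platform/enforce.h); the variable num_ids below is hypothetical and used only for illustration.

  // Wrap the printf-style message in a typed error builder instead of
  // passing a bare string to the enforce/throw macros.
  size_t num_ids = ids_vars.size();
  PADDLE_ENFORCE_GT(num_ids, 0,
                    platform::errors::InvalidArgument(
                        "The number of Ids expected > 0, but got %d", num_ids));
  if (!platform::is_cpu_place(place)) {
    PADDLE_THROW(platform::errors::Unimplemented(
        "SplitIds do not support GPU kernel"));
  }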