|
|
|
@ -21,7 +21,7 @@ limitations under the License. */
|
|
|
|
|
namespace paddle {
|
|
|
|
|
namespace operators {
|
|
|
|
|
|
|
|
|
|
static int FindOutIdx(int row, const std::vector<int>& abs_sections) {
|
|
|
|
|
static int FindOutIdx(int row, const std::vector<int64_t>& abs_sections) {
|
|
|
|
|
for (size_t i = 1; i < abs_sections.size(); ++i) {
|
|
|
|
|
if (row < abs_sections[i]) {
|
|
|
|
|
return i - 1;
|
|
|
|
@ -30,9 +30,9 @@ static int FindOutIdx(int row, const std::vector<int>& abs_sections) {
|
|
|
|
|
return abs_sections.size() - 1;
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
static std::vector<int> ToAbsoluteSection(
|
|
|
|
|
const std::vector<int>& height_sections) {
|
|
|
|
|
std::vector<int> abs_sections;
|
|
|
|
|
static std::vector<int64_t> ToAbsoluteSection(
|
|
|
|
|
const std::vector<int64_t>& height_sections) {
|
|
|
|
|
std::vector<int64_t> abs_sections;
|
|
|
|
|
abs_sections.resize(height_sections.size());
|
|
|
|
|
abs_sections[0] = 0;
|
|
|
|
|
for (size_t i = 1; i < height_sections.size(); ++i) {
|
|
|
|
@ -47,7 +47,7 @@ class SplitSelectedRowsOpKernel : public framework::OpKernel<T> {
|
|
|
|
|
void Compute(const framework::ExecutionContext& ctx) const override {
|
|
|
|
|
auto* x = ctx.Input<framework::SelectedRows>("X");
|
|
|
|
|
auto outs = ctx.MultiOutput<framework::SelectedRows>("Out");
|
|
|
|
|
auto height_sections = ctx.Attr<std::vector<int>>("height_sections");
|
|
|
|
|
auto height_sections = ctx.Attr<std::vector<int64_t>>("height_sections");
|
|
|
|
|
|
|
|
|
|
auto abs_sections = ToAbsoluteSection(height_sections);
|
|
|
|
|
|
|
|
|
|