|
|
|
@ -22,7 +22,7 @@ class SplitSelectedRowsOpMaker : public framework::OpProtoAndCheckerMaker {
|
|
|
|
|
SplitSelectedRowsOpMaker(OpProto *proto, OpAttrChecker *op_checker)
|
|
|
|
|
: OpProtoAndCheckerMaker(proto, op_checker) {
|
|
|
|
|
AddInput("X", "The input SelectedRows.");
|
|
|
|
|
AddOutput("Out", "The outputs of input SelectedRows.").AsDuplicable();
|
|
|
|
|
AddOutput("Out", "The outputs of the input SelectedRows.").AsDuplicable();
|
|
|
|
|
AddAttr<std::vector<int>>("height_sections",
|
|
|
|
|
"Height for each output SelectedRows.")
|
|
|
|
|
.SetDefault(std::vector<int>({}));
|
|
|
|
@ -56,27 +56,6 @@ class SplitSelectedRowsOp : public framework::OperatorWithKernel {
|
|
|
|
|
PADDLE_ENFORCE(ctx->HasInput("X"), "SplitSelectedRowsOp must has input X.");
|
|
|
|
|
PADDLE_ENFORCE(ctx->HasOutputs("Out"),
|
|
|
|
|
"SplitSelectedRowsOp must has output Out.");
|
|
|
|
|
|
|
|
|
|
std::vector<int> height_sections =
|
|
|
|
|
ctx->Attrs().Get<std::vector<int>>("height_sections");
|
|
|
|
|
int64_t n = ctx->Outputs("Out").size();
|
|
|
|
|
|
|
|
|
|
std::vector<framework::DDim> outs_dims;
|
|
|
|
|
outs_dims.reserve(n);
|
|
|
|
|
|
|
|
|
|
// make output dims
|
|
|
|
|
for (int64_t i = 0; i < n; ++i) {
|
|
|
|
|
auto dims = ctx->GetInputDim("X");
|
|
|
|
|
if (height_sections.size()) {
|
|
|
|
|
PADDLE_ENFORCE_EQ(
|
|
|
|
|
height_sections.size(), static_cast<size_t>(n),
|
|
|
|
|
"The size of height section should be the same with height"
|
|
|
|
|
" section size.");
|
|
|
|
|
dims[0] = height_sections[i];
|
|
|
|
|
}
|
|
|
|
|
outs_dims.push_back(dims);
|
|
|
|
|
}
|
|
|
|
|
ctx->SetOutputsDim("Out", outs_dims);
|
|
|
|
|
}
|
|
|
|
|
};
|
|
|
|
|
|
|
|
|
|