|
|
|
@ -59,6 +59,16 @@ class SplitSelectedRowsOp : public framework::OperatorWithKernel {
|
|
|
|
|
}
|
|
|
|
|
};
|
|
|
|
|
|
|
|
|
|
class SplitSelectedRowsOpInferVarType : public framework::VarTypeInference {
|
|
|
|
|
public:
|
|
|
|
|
void operator()(const framework::OpDesc &op_desc,
|
|
|
|
|
framework::BlockDesc *block) const override {
|
|
|
|
|
for (auto &out_var : op_desc.Output("Out")) {
|
|
|
|
|
block->Var(out_var)->SetType(framework::proto::VarType::SELECTED_ROWS);
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
};
|
|
|
|
|
|
|
|
|
|
class SplitSelectedRowsGradMaker : public framework::SingleGradOpDescMaker {
|
|
|
|
|
public:
|
|
|
|
|
using framework::SingleGradOpDescMaker::SingleGradOpDescMaker;
|
|
|
|
@ -80,7 +90,8 @@ class SplitSelectedRowsGradMaker : public framework::SingleGradOpDescMaker {
|
|
|
|
|
namespace ops = paddle::operators;
|
|
|
|
|
REGISTER_OPERATOR(split_selected_rows, ops::SplitSelectedRowsOp,
|
|
|
|
|
ops::SplitSelectedRowsOpMaker,
|
|
|
|
|
ops::SplitSelectedRowsGradMaker);
|
|
|
|
|
ops::SplitSelectedRowsGradMaker,
|
|
|
|
|
ops::SplitSelectedRowsOpInferVarType);
|
|
|
|
|
REGISTER_OP_CPU_KERNEL(
|
|
|
|
|
split_selected_rows,
|
|
|
|
|
ops::SplitSelectedRowsOpKernel<paddle::platform::CPUPlace, float>);
|
|
|
|
|