|
|
|
@ -30,10 +30,10 @@ class LoDResetOp : public framework::OperatorWithKernel {
|
|
|
|
|
|
|
|
|
|
if (!ctx->HasInput("Y")) {
|
|
|
|
|
auto level0 = ctx->Attrs().Get<std::vector<int>>("target_lod");
|
|
|
|
|
PADDLE_ENFORCE_GT(level0.size(), 1,
|
|
|
|
|
PADDLE_ENFORCE_GT(level0.size(), 0,
|
|
|
|
|
"If Input(Y) not provided, the target lod should be "
|
|
|
|
|
"specified by attribute `target_lod`.");
|
|
|
|
|
} else {
|
|
|
|
|
} else if (ctx->IsRuntime()) {
|
|
|
|
|
ctx->ShareLoD("Y", "Out");
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
@ -48,6 +48,23 @@ class LoDResetOp : public framework::OperatorWithKernel {
|
|
|
|
|
}
|
|
|
|
|
};
|
|
|
|
|
|
|
|
|
|
class LoDResetOpVarTypeInference : public framework::VarTypeInference {
|
|
|
|
|
public:
|
|
|
|
|
void operator()(framework::InferVarTypeContext *ctx) const override {
|
|
|
|
|
auto x_var_name = ctx->Input("X").front();
|
|
|
|
|
auto out_var_name = ctx->Output("Out").front();
|
|
|
|
|
if (ctx->HasInput("Y")) {
|
|
|
|
|
auto y_var_name = ctx->Input("Y").front();
|
|
|
|
|
auto y_lod_level = std::max(ctx->GetLoDLevel(y_var_name), 1);
|
|
|
|
|
ctx->SetLoDLevel(out_var_name, y_lod_level);
|
|
|
|
|
} else {
|
|
|
|
|
ctx->SetLoDLevel(out_var_name, 1);
|
|
|
|
|
}
|
|
|
|
|
ctx->SetDataType(out_var_name, ctx->GetDataType(x_var_name));
|
|
|
|
|
ctx->SetType(out_var_name, paddle::framework::proto::VarType::LOD_TENSOR);
|
|
|
|
|
}
|
|
|
|
|
};
|
|
|
|
|
|
|
|
|
|
class LoDResetOpMaker : public framework::OpProtoAndCheckerMaker {
|
|
|
|
|
public:
|
|
|
|
|
void Make() override {
|
|
|
|
@ -177,9 +194,10 @@ DECLARE_NO_NEED_BUFFER_VARS_INFERENCE(LoDResetGradNoNeedBufferVarInference,
|
|
|
|
|
|
|
|
|
|
namespace ops = paddle::operators;
|
|
|
|
|
REGISTER_OPERATOR(lod_reset, ops::LoDResetOp, ops::LoDResetOpMaker,
|
|
|
|
|
ops::LoDResetGradDescMaker);
|
|
|
|
|
ops::LoDResetGradDescMaker, ops::LoDResetOpVarTypeInference);
|
|
|
|
|
REGISTER_OPERATOR(lod_reset_grad, ops::LoDResetGradOp,
|
|
|
|
|
ops::LoDResetGradNoNeedBufferVarInference);
|
|
|
|
|
|
|
|
|
|
REGISTER_OP_CPU_KERNEL(
|
|
|
|
|
lod_reset, ops::LoDResetKernel<paddle::platform::CPUPlace, float>,
|
|
|
|
|
ops::LoDResetKernel<paddle::platform::CPUPlace, double>,
|
|
|
|
|