|
|
|
@ -182,9 +182,32 @@ This operator will serialize and write a tensor/selected rows variable to file o
|
|
|
|
|
}
|
|
|
|
|
};
|
|
|
|
|
|
|
|
|
|
} // namespace operators
|
|
|
|
|
} // namespace paddle
|
|
|
|
|
class SaveOpVarTypeInference : public framework::VarTypeInference {
|
|
|
|
|
public:
|
|
|
|
|
void operator()(const framework::OpDesc &op_desc,
|
|
|
|
|
framework::BlockDesc *block) const override {
|
|
|
|
|
auto out_var_name = op_desc.Output("loopup_table_path").front();
|
|
|
|
|
auto &out_var = block->FindRecursiveOrCreateVar(out_var_name);
|
|
|
|
|
auto var_type = framework::proto::VarType::RAW;
|
|
|
|
|
out_var.SetType(var_type);
|
|
|
|
|
}
|
|
|
|
|
};
|
|
|
|
|
|
|
|
|
|
class SaveOpShapeInference : public framework::InferShapeBase {
|
|
|
|
|
public:
|
|
|
|
|
void operator()(framework::InferShapeContext *ctx) const override {}
|
|
|
|
|
};
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
// namespace operators
|
|
|
|
|
// namespace paddle
|
|
|
|
|
|
|
|
|
|
namespace ops = paddle::operators;
|
|
|
|
|
|
|
|
|
|
REGISTER_OPERATOR(save, ops::SaveOp, ops::SaveOpProtoMaker);
|
|
|
|
|
REGISTER_OPERATOR(save, ops::SaveOp,
|
|
|
|
|
paddle::framework::EmptyGradOpMaker,
|
|
|
|
|
ops::SaveOpProtoMaker,
|
|
|
|
|
ops::SaveOpVarTypeInference,
|
|
|
|
|
ops::SaveOpShapeInference);
|
|
|
|
|
|
|
|
|
|