|
|
|
@ -69,7 +69,6 @@ class SaveOp : public framework::OperatorBase {
|
|
|
|
|
private:
|
|
|
|
|
void RunImpl(const framework::Scope &scope,
|
|
|
|
|
const platform::Place &place) const override {
|
|
|
|
|
|
|
|
|
|
auto iname = Input("X");
|
|
|
|
|
auto *var = scope.FindVar(iname);
|
|
|
|
|
PADDLE_ENFORCE(var != nullptr, "Cannot find variable %s for save_op",
|
|
|
|
@ -132,8 +131,11 @@ class SaveOp : public framework::OperatorBase {
|
|
|
|
|
void SaveSelectedRows(const framework::Scope &scope,
|
|
|
|
|
const platform::Place &place,
|
|
|
|
|
framework::Variable *var) const {
|
|
|
|
|
auto *lt_var = scope.FindVar("loopup_table_path")->GetMutable<std::string>();
|
|
|
|
|
PADDLE_ENFORCE(lt_var != nullptr, "Cannot find variable loopup_table_path for SaveSelectedRows");
|
|
|
|
|
auto *lt_var =
|
|
|
|
|
scope.FindVar("loopup_table_path")->GetMutable<std::string>();
|
|
|
|
|
PADDLE_ENFORCE(
|
|
|
|
|
lt_var != nullptr,
|
|
|
|
|
"Can not find variable loopup_table_path for SaveSelectedRows");
|
|
|
|
|
std::string filename = lt_var->data();
|
|
|
|
|
VLOG(4) << "SaveSelectedRows get File name: " << filename;
|
|
|
|
|
|
|
|
|
@ -195,17 +197,11 @@ class SaveOpShapeInference : public framework::InferShapeBase {
|
|
|
|
|
public:
|
|
|
|
|
void operator()(framework::InferShapeContext *ctx) const override {}
|
|
|
|
|
};
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
// namespace operators
|
|
|
|
|
// namespace paddle
|
|
|
|
|
} // namespace operators
|
|
|
|
|
} // namespace paddle
|
|
|
|
|
|
|
|
|
|
namespace ops = paddle::operators;
|
|
|
|
|
|
|
|
|
|
REGISTER_OPERATOR(save, ops::SaveOp,
|
|
|
|
|
paddle::framework::EmptyGradOpMaker,
|
|
|
|
|
ops::SaveOpProtoMaker,
|
|
|
|
|
ops::SaveOpVarTypeInference,
|
|
|
|
|
REGISTER_OPERATOR(save, ops::SaveOp, paddle::framework::EmptyGradOpMaker,
|
|
|
|
|
ops::SaveOpProtoMaker, ops::SaveOpVarTypeInference,
|
|
|
|
|
ops::SaveOpShapeInference);
|
|
|
|
|
|
|
|
|
|