|
|
|
@ -69,15 +69,6 @@ class SaveOp : public framework::OperatorBase {
|
|
|
|
|
private:
|
|
|
|
|
void RunImpl(const framework::Scope &scope,
|
|
|
|
|
const platform::Place &place) const override {
|
|
|
|
|
auto filename = Attr<std::string>("file_path");
|
|
|
|
|
auto overwrite = Attr<bool>("overwrite");
|
|
|
|
|
|
|
|
|
|
if (FileExists(filename) && !overwrite) {
|
|
|
|
|
PADDLE_THROW("%s is existed, cannot save to it when overwrite=false",
|
|
|
|
|
filename, overwrite);
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
MkDirRecursively(DirName(filename).c_str());
|
|
|
|
|
|
|
|
|
|
auto iname = Input("X");
|
|
|
|
|
auto *var = scope.FindVar(iname);
|
|
|
|
@ -85,7 +76,7 @@ class SaveOp : public framework::OperatorBase {
|
|
|
|
|
iname);
|
|
|
|
|
|
|
|
|
|
if (var->IsType<framework::LoDTensor>()) {
|
|
|
|
|
SaveLodTensor(filename, place, var);
|
|
|
|
|
SaveLodTensor(place, var);
|
|
|
|
|
} else if (var->IsType<framework::SelectedRows>()) {
|
|
|
|
|
SaveSelectedRows(scope, place, var);
|
|
|
|
|
} else {
|
|
|
|
@ -96,8 +87,18 @@ class SaveOp : public framework::OperatorBase {
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
void SaveLodTensor(const std::string &filename, const platform::Place &place,
|
|
|
|
|
void SaveLodTensor( const platform::Place &place,
|
|
|
|
|
framework::Variable *var) const {
|
|
|
|
|
auto filename = Attr<std::string>("file_path");
|
|
|
|
|
auto overwrite = Attr<bool>("overwrite");
|
|
|
|
|
|
|
|
|
|
if (FileExists(filename) && !overwrite) {
|
|
|
|
|
PADDLE_THROW("%s is existed, cannot save to it when overwrite=false",
|
|
|
|
|
filename, overwrite);
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
MkDirRecursively(DirName(filename).c_str());
|
|
|
|
|
|
|
|
|
|
auto &tensor = var->Get<framework::LoDTensor>();
|
|
|
|
|
|
|
|
|
|
// get device context from pool
|
|
|
|
|