|
|
|
@ -103,12 +103,22 @@ class SaveOpKernel : public framework::OpKernel<T> {
|
|
|
|
|
const platform::Place &place,
|
|
|
|
|
const framework::Variable *var) const {
|
|
|
|
|
framework::Variable *out_put_var = ctx.OutputVar(LOOKUP_TABLE_PATH);
|
|
|
|
|
PADDLE_ENFORCE(
|
|
|
|
|
out_put_var != nullptr,
|
|
|
|
|
"Can not find variable kLookupTablePath for SaveSelectedRows");
|
|
|
|
|
auto *lt_var = out_put_var->GetMutable<std::string>();
|
|
|
|
|
|
|
|
|
|
std::string filename = lt_var->data();
|
|
|
|
|
auto file_path = ctx.Attr<std::string>("file_path");
|
|
|
|
|
auto overwrite = ctx.Attr<bool>("overwrite");
|
|
|
|
|
|
|
|
|
|
std::string filename = file_path;
|
|
|
|
|
|
|
|
|
|
if (out_put_var != nullptr) {
|
|
|
|
|
auto *lt_var = out_put_var->GetMutable<std::string>();
|
|
|
|
|
filename = *lt_var;
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
if (FileExists(filename) && !overwrite) {
|
|
|
|
|
PADDLE_THROW("%s is existed, cannot save to it when overwrite=false",
|
|
|
|
|
filename, overwrite);
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
VLOG(4) << "SaveSelectedRows get File name: " << filename;
|
|
|
|
|
|
|
|
|
|
MkDirRecursively(DirName(filename).c_str());
|
|
|
|
|