Op (Save/LoadCombine) error message enhancement (#23647)

* op save/load_combine error msg polish, test=develop

* fix detail error, test=develop
revert-23830-2.0-beta
Chen Weihang 5 years ago committed by GitHub
parent b61aaa2c10
commit 0b6f09e74f
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

@ -35,21 +35,27 @@ class LoadCombineOpKernel : public framework::OpKernel<T> {
auto model_from_memory = ctx.Attr<bool>("model_from_memory"); auto model_from_memory = ctx.Attr<bool>("model_from_memory");
auto out_var_names = ctx.OutputNames("Out"); auto out_var_names = ctx.OutputNames("Out");
PADDLE_ENFORCE_GT( PADDLE_ENFORCE_GT(out_var_names.size(), 0UL,
static_cast<int>(out_var_names.size()), 0, platform::errors::InvalidArgument(
"The number of output variables should be greater than 0."); "The number of variables to be loaded is %d, expect "
"it to be greater than 0.",
out_var_names.size()));
if (!model_from_memory) { if (!model_from_memory) {
std::ifstream fin(filename, std::ios::binary); std::ifstream fin(filename, std::ios::binary);
PADDLE_ENFORCE(static_cast<bool>(fin), PADDLE_ENFORCE_EQ(
"OP(LoadCombine) fail to open file %s, please check " static_cast<bool>(fin), true,
"whether the model file is complete or damaged.", platform::errors::Unavailable(
filename); "LoadCombine operator fails to open file %s, please check "
"whether the model file is complete or damaged.",
filename));
LoadParamsFromBuffer(ctx, place, &fin, load_as_fp16, out_var_names); LoadParamsFromBuffer(ctx, place, &fin, load_as_fp16, out_var_names);
} else { } else {
PADDLE_ENFORCE(!filename.empty(), PADDLE_ENFORCE_NE(
"OP(LoadCombine) fail to open file %s, please check " filename.empty(), true,
"whether the model file is complete or damaged.", platform::errors::Unavailable(
filename); "LoadCombine operator fails to open file %s, please check "
"whether the model file is complete or damaged.",
filename));
std::stringstream fin(filename, std::ios::in | std::ios::binary); std::stringstream fin(filename, std::ios::in | std::ios::binary);
LoadParamsFromBuffer(ctx, place, &fin, load_as_fp16, out_var_names); LoadParamsFromBuffer(ctx, place, &fin, load_as_fp16, out_var_names);
} }
@ -64,16 +70,19 @@ class LoadCombineOpKernel : public framework::OpKernel<T> {
auto out_vars = context.MultiOutputVar("Out"); auto out_vars = context.MultiOutputVar("Out");
for (size_t i = 0; i < out_var_names.size(); i++) { for (size_t i = 0; i < out_var_names.size(); i++) {
PADDLE_ENFORCE(out_vars[i] != nullptr, PADDLE_ENFORCE_NOT_NULL(
"Output variable %s cannot be found", out_var_names[i]); out_vars[i], platform::errors::InvalidArgument(
"The variable %s to be loaded cannot be found.",
out_var_names[i]));
auto *tensor = out_vars[i]->GetMutable<framework::LoDTensor>(); auto *tensor = out_vars[i]->GetMutable<framework::LoDTensor>();
// Error checking // Error checking
PADDLE_ENFORCE( PADDLE_ENFORCE_EQ(
static_cast<bool>(*buffer), static_cast<bool>(*buffer), true,
"There is a problem with loading model parameters. " platform::errors::Unavailable(
"Please check whether the model file is complete or damaged."); "An error occurred while loading model parameters. "
"Please check whether the model file is complete or damaged."));
// Get data from fin to tensor // Get data from fin to tensor
DeserializeFromStream(*buffer, tensor, dev_ctx); DeserializeFromStream(*buffer, tensor, dev_ctx);
@ -100,9 +109,10 @@ class LoadCombineOpKernel : public framework::OpKernel<T> {
} }
} }
buffer->peek(); buffer->peek();
PADDLE_ENFORCE(buffer->eof(), PADDLE_ENFORCE_EQ(buffer->eof(), true,
"You are not allowed to load partial data via " platform::errors::Unavailable(
"load_combine_op, use load_op instead."); "Not allowed to load partial data via "
"load_combine_op, please use load_op instead."));
} }
}; };

@ -41,31 +41,40 @@ class SaveCombineOpKernel : public framework::OpKernel<T> {
bool is_present = FileExists(filename); bool is_present = FileExists(filename);
if (is_present && !overwrite) { if (is_present && !overwrite) {
PADDLE_THROW("%s exists!, cannot save_combine to it when overwrite=false", PADDLE_THROW(platform::errors::PreconditionNotMet(
filename, overwrite); "%s exists! Cannot save_combine to it when overwrite is set to "
"false.",
filename, overwrite));
} }
MkDirRecursively(DirName(filename).c_str()); MkDirRecursively(DirName(filename).c_str());
std::ofstream fout(filename, std::ios::binary); std::ofstream fout(filename, std::ios::binary);
PADDLE_ENFORCE(static_cast<bool>(fout), "Cannot open %s to write", PADDLE_ENFORCE_EQ(static_cast<bool>(fout), true,
filename); platform::errors::Unavailable(
"Cannot open %s to save variables.", filename));
auto inp_var_names = ctx.InputNames("X"); auto inp_var_names = ctx.InputNames("X");
auto &inp_vars = ctx.MultiInputVar("X"); auto &inp_vars = ctx.MultiInputVar("X");
PADDLE_ENFORCE_GT(static_cast<int>(inp_var_names.size()), 0, PADDLE_ENFORCE_GT(inp_var_names.size(), 0UL,
"The number of input variables should be greater than 0"); platform::errors::InvalidArgument(
"The number of variables to be saved is %d, expect "
"it to be greater than 0.",
inp_var_names.size()));
// get device context from pool // get device context from pool
platform::DeviceContextPool &pool = platform::DeviceContextPool::Instance(); platform::DeviceContextPool &pool = platform::DeviceContextPool::Instance();
auto &dev_ctx = *pool.Get(place); auto &dev_ctx = *pool.Get(place);
for (size_t i = 0; i < inp_var_names.size(); i++) { for (size_t i = 0; i < inp_var_names.size(); i++) {
PADDLE_ENFORCE(inp_vars[i] != nullptr, PADDLE_ENFORCE_NOT_NULL(
"Cannot find variable %s for save_combine_op", inp_vars[i],
inp_var_names[i]); platform::errors::InvalidArgument("Cannot find variable %s to save.",
PADDLE_ENFORCE(inp_vars[i]->IsType<framework::LoDTensor>(), inp_var_names[i]));
"SaveCombineOp only supports LoDTensor, %s has wrong type", PADDLE_ENFORCE_EQ(inp_vars[i]->IsType<framework::LoDTensor>(), true,
inp_var_names[i]); platform::errors::InvalidArgument(
"SaveCombine operator only supports saving "
"LoDTensor variable, %s has wrong type.",
inp_var_names[i]));
auto &tensor = inp_vars[i]->Get<framework::LoDTensor>(); auto &tensor = inp_vars[i]->Get<framework::LoDTensor>();
// Serialize tensors one by one // Serialize tensors one by one

Loading…
Cancel
Save