|
|
@ -35,21 +35,27 @@ class LoadCombineOpKernel : public framework::OpKernel<T> {
|
|
|
|
auto model_from_memory = ctx.Attr<bool>("model_from_memory");
|
|
|
|
auto model_from_memory = ctx.Attr<bool>("model_from_memory");
|
|
|
|
auto out_var_names = ctx.OutputNames("Out");
|
|
|
|
auto out_var_names = ctx.OutputNames("Out");
|
|
|
|
|
|
|
|
|
|
|
|
PADDLE_ENFORCE_GT(
|
|
|
|
PADDLE_ENFORCE_GT(out_var_names.size(), 0UL,
|
|
|
|
static_cast<int>(out_var_names.size()), 0,
|
|
|
|
platform::errors::InvalidArgument(
|
|
|
|
"The number of output variables should be greater than 0.");
|
|
|
|
"The number of variables to be loaded is %d, expect "
|
|
|
|
|
|
|
|
"it to be greater than 0.",
|
|
|
|
|
|
|
|
out_var_names.size()));
|
|
|
|
if (!model_from_memory) {
|
|
|
|
if (!model_from_memory) {
|
|
|
|
std::ifstream fin(filename, std::ios::binary);
|
|
|
|
std::ifstream fin(filename, std::ios::binary);
|
|
|
|
PADDLE_ENFORCE(static_cast<bool>(fin),
|
|
|
|
PADDLE_ENFORCE_EQ(
|
|
|
|
"OP(LoadCombine) fail to open file %s, please check "
|
|
|
|
static_cast<bool>(fin), true,
|
|
|
|
"whether the model file is complete or damaged.",
|
|
|
|
platform::errors::Unavailable(
|
|
|
|
filename);
|
|
|
|
"LoadCombine operator fails to open file %s, please check "
|
|
|
|
|
|
|
|
"whether the model file is complete or damaged.",
|
|
|
|
|
|
|
|
filename));
|
|
|
|
LoadParamsFromBuffer(ctx, place, &fin, load_as_fp16, out_var_names);
|
|
|
|
LoadParamsFromBuffer(ctx, place, &fin, load_as_fp16, out_var_names);
|
|
|
|
} else {
|
|
|
|
} else {
|
|
|
|
PADDLE_ENFORCE(!filename.empty(),
|
|
|
|
PADDLE_ENFORCE_NE(
|
|
|
|
"OP(LoadCombine) fail to open file %s, please check "
|
|
|
|
filename.empty(), true,
|
|
|
|
"whether the model file is complete or damaged.",
|
|
|
|
platform::errors::Unavailable(
|
|
|
|
filename);
|
|
|
|
"LoadCombine operator fails to open file %s, please check "
|
|
|
|
|
|
|
|
"whether the model file is complete or damaged.",
|
|
|
|
|
|
|
|
filename));
|
|
|
|
std::stringstream fin(filename, std::ios::in | std::ios::binary);
|
|
|
|
std::stringstream fin(filename, std::ios::in | std::ios::binary);
|
|
|
|
LoadParamsFromBuffer(ctx, place, &fin, load_as_fp16, out_var_names);
|
|
|
|
LoadParamsFromBuffer(ctx, place, &fin, load_as_fp16, out_var_names);
|
|
|
|
}
|
|
|
|
}
|
|
|
@ -64,16 +70,19 @@ class LoadCombineOpKernel : public framework::OpKernel<T> {
|
|
|
|
auto out_vars = context.MultiOutputVar("Out");
|
|
|
|
auto out_vars = context.MultiOutputVar("Out");
|
|
|
|
|
|
|
|
|
|
|
|
for (size_t i = 0; i < out_var_names.size(); i++) {
|
|
|
|
for (size_t i = 0; i < out_var_names.size(); i++) {
|
|
|
|
PADDLE_ENFORCE(out_vars[i] != nullptr,
|
|
|
|
PADDLE_ENFORCE_NOT_NULL(
|
|
|
|
"Output variable %s cannot be found", out_var_names[i]);
|
|
|
|
out_vars[i], platform::errors::InvalidArgument(
|
|
|
|
|
|
|
|
"The variable %s to be loaded cannot be found.",
|
|
|
|
|
|
|
|
out_var_names[i]));
|
|
|
|
|
|
|
|
|
|
|
|
auto *tensor = out_vars[i]->GetMutable<framework::LoDTensor>();
|
|
|
|
auto *tensor = out_vars[i]->GetMutable<framework::LoDTensor>();
|
|
|
|
|
|
|
|
|
|
|
|
// Error checking
|
|
|
|
// Error checking
|
|
|
|
PADDLE_ENFORCE(
|
|
|
|
PADDLE_ENFORCE_EQ(
|
|
|
|
static_cast<bool>(*buffer),
|
|
|
|
static_cast<bool>(*buffer), true,
|
|
|
|
"There is a problem with loading model parameters. "
|
|
|
|
platform::errors::Unavailable(
|
|
|
|
"Please check whether the model file is complete or damaged.");
|
|
|
|
"An error occurred while loading model parameters. "
|
|
|
|
|
|
|
|
"Please check whether the model file is complete or damaged."));
|
|
|
|
|
|
|
|
|
|
|
|
// Get data from fin to tensor
|
|
|
|
// Get data from fin to tensor
|
|
|
|
DeserializeFromStream(*buffer, tensor, dev_ctx);
|
|
|
|
DeserializeFromStream(*buffer, tensor, dev_ctx);
|
|
|
@ -100,9 +109,10 @@ class LoadCombineOpKernel : public framework::OpKernel<T> {
|
|
|
|
}
|
|
|
|
}
|
|
|
|
}
|
|
|
|
}
|
|
|
|
buffer->peek();
|
|
|
|
buffer->peek();
|
|
|
|
PADDLE_ENFORCE(buffer->eof(),
|
|
|
|
PADDLE_ENFORCE_EQ(buffer->eof(), true,
|
|
|
|
"You are not allowed to load partial data via "
|
|
|
|
platform::errors::Unavailable(
|
|
|
|
"load_combine_op, use load_op instead.");
|
|
|
|
"Not allowed to load partial data via "
|
|
|
|
|
|
|
|
"load_combine_op, please use load_op instead."));
|
|
|
|
}
|
|
|
|
}
|
|
|
|
};
|
|
|
|
};
|
|
|
|
|
|
|
|
|
|
|
|