|
|
|
@ -62,7 +62,7 @@ class CheckpointLoadOp : public framework::OperatorBase {
|
|
|
|
|
return;
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
auto inp_var_names = Output("Out");
|
|
|
|
|
auto inp_var_names = Inputs("X");
|
|
|
|
|
PADDLE_ENFORCE_GT(static_cast<int>(inp_var_names.size()), 0,
|
|
|
|
|
"The number of input variables should be greater than 0");
|
|
|
|
|
// get device context from pool
|
|
|
|
@ -102,7 +102,10 @@ class CheckpointLoadOpProtoMaker : public framework::OpProtoAndCheckerMaker {
|
|
|
|
|
public:
|
|
|
|
|
CheckpointLoadOpProtoMaker(OpProto *proto, OpAttrChecker *op_checker)
|
|
|
|
|
: OpProtoAndCheckerMaker(proto, op_checker) {
|
|
|
|
|
AddOutput("Out", "(Tensor) The tensor need to be loaded");
|
|
|
|
|
AddInput(
|
|
|
|
|
"X",
|
|
|
|
|
"(vector) Input LoDTensors that need to be saved together in a file.")
|
|
|
|
|
.AsDuplicable();
|
|
|
|
|
AddComment(R"DOC(
|
|
|
|
|
CheckpointLoad operator
|
|
|
|
|
|
|
|
|
|