|
|
|
@ -46,6 +46,19 @@ class LoadOp : public framework::OperatorBase {
|
|
|
|
|
auto *tensor = out_var->GetMutable<framework::LoDTensor>();
|
|
|
|
|
|
|
|
|
|
DeserializeFromStream(fin, tensor, *dev_ctx);
|
|
|
|
|
|
|
|
|
|
auto load_as_fp16 = Attr<bool>("load_as_fp16");
|
|
|
|
|
auto in_dtype = framework::ToDataType(tensor->type());
|
|
|
|
|
auto out_dtype = load_as_fp16 ? framework::proto::VarType::FP16 : in_dtype;
|
|
|
|
|
|
|
|
|
|
if (in_dtype != out_dtype) {
|
|
|
|
|
// convert to float16 tensor
|
|
|
|
|
auto in_kernel_type = framework::OpKernelType(in_dtype, place);
|
|
|
|
|
auto out_kernel_type = framework::OpKernelType(out_dtype, place);
|
|
|
|
|
framework::LoDTensor fp16_tensor;
|
|
|
|
|
framework::TransDataType(in_kernel_type, out_kernel_type, *tensor,
|
|
|
|
|
&fp16_tensor);
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
};
|
|
|
|
|
|
|
|
|
@ -54,6 +67,13 @@ class LoadOpProtoMaker : public framework::OpProtoAndCheckerMaker {
|
|
|
|
|
LoadOpProtoMaker(OpProto *proto, OpAttrChecker *op_checker)
|
|
|
|
|
: OpProtoAndCheckerMaker(proto, op_checker) {
|
|
|
|
|
AddOutput("Out", "(Tensor) The tensor need to be loaded");
|
|
|
|
|
AddAttr<bool>(
|
|
|
|
|
"load_as_fp16",
|
|
|
|
|
"(boolean, default false)"
|
|
|
|
|
"If true, the tensor will be first loaded and then "
|
|
|
|
|
"converted to float16 data type. Otherwise, the tensor will be "
|
|
|
|
|
"directly loaded without data type conversion.")
|
|
|
|
|
.SetDefault(false);
|
|
|
|
|
AddAttr<std::string>("file_path",
|
|
|
|
|
"(string) "
|
|
|
|
|
"Variable will be loaded from \"file_path\".")
|
|
|
|
|