|
|
|
@ -23,14 +23,21 @@ class SaveCombineOp : public framework::OperatorWithKernel {
|
|
|
|
|
public:
|
|
|
|
|
using framework::OperatorWithKernel::OperatorWithKernel;
|
|
|
|
|
|
|
|
|
|
void InferShape(framework::InferShapeContext *ctx) const override {}
|
|
|
|
|
void InferShape(framework::InferShapeContext* ctx) const override {}
|
|
|
|
|
|
|
|
|
|
protected:
|
|
|
|
|
framework::OpKernelType GetExpectedKernelType(
|
|
|
|
|
const framework::ExecutionContext &ctx) const override {
|
|
|
|
|
return framework::OpKernelType(ctx.MultiInput<Tensor>("X")[0]->type(),
|
|
|
|
|
const framework::ExecutionContext& ctx) const override {
|
|
|
|
|
return framework::OpKernelType(framework::proto::VarType::FP32,
|
|
|
|
|
ctx.GetPlace());
|
|
|
|
|
}
|
|
|
|
|
// TODO(lujun): The override here is just to bypass transform
|
|
|
|
|
// in operator impl, which is not elegant enough.
|
|
|
|
|
framework::OpKernelType GetKernelTypeForVar(
|
|
|
|
|
const std::string& var_name, const Tensor& tensor,
|
|
|
|
|
const framework::OpKernelType& expected_kernel_type) const override {
|
|
|
|
|
return expected_kernel_type;
|
|
|
|
|
}
|
|
|
|
|
};
|
|
|
|
|
|
|
|
|
|
class SaveCombineOpProtoMaker : public framework::OpProtoAndCheckerMaker {
|
|
|
|
@ -61,7 +68,7 @@ to a file on disk.
|
|
|
|
|
"(string)"
|
|
|
|
|
"The \"file_path\" where the LoDTensor variables will be saved.")
|
|
|
|
|
.AddCustomChecker(
|
|
|
|
|
[](const std::string &path) { return !path.empty(); });
|
|
|
|
|
[](const std::string& path) { return !path.empty(); });
|
|
|
|
|
}
|
|
|
|
|
};
|
|
|
|
|
|
|
|
|
@ -77,6 +84,4 @@ REGISTER_OP_CPU_KERNEL(
|
|
|
|
|
save_combine,
|
|
|
|
|
ops::SaveCombineOpKernel<paddle::platform::CPUDeviceContext, float>,
|
|
|
|
|
ops::SaveCombineOpKernel<paddle::platform::CPUDeviceContext, double>,
|
|
|
|
|
ops::SaveCombineOpKernel<paddle::platform::CPUDeviceContext, int>,
|
|
|
|
|
ops::SaveCombineOpKernel<paddle::platform::CPUDeviceContext, int8_t>,
|
|
|
|
|
ops::SaveCombineOpKernel<paddle::platform::CPUDeviceContext, int64_t>);
|
|
|
|
|
ops::SaveCombineOpKernel<paddle::platform::CPUDeviceContext, int>);
|
|
|
|
|