|
|
@ -64,7 +64,7 @@ class LoadCombineOp : public framework::OperatorBase {
|
|
|
|
auto *tensor = out_var->GetMutable<framework::LoDTensor>();
|
|
|
|
auto *tensor = out_var->GetMutable<framework::LoDTensor>();
|
|
|
|
|
|
|
|
|
|
|
|
// Error checking
|
|
|
|
// Error checking
|
|
|
|
PADDLE_ENFORCE(static_cast<bool>(buffer), "Cannot read more");
|
|
|
|
PADDLE_ENFORCE(static_cast<bool>(*buffer), "Cannot read more");
|
|
|
|
|
|
|
|
|
|
|
|
// Get data from fin to tensor
|
|
|
|
// Get data from fin to tensor
|
|
|
|
DeserializeFromStream(*buffer, tensor, dev_ctx);
|
|
|
|
DeserializeFromStream(*buffer, tensor, dev_ctx);
|
|
|
@ -90,6 +90,10 @@ class LoadCombineOp : public framework::OperatorBase {
|
|
|
|
tensor->ShareDataWith(fp16_tensor);
|
|
|
|
tensor->ShareDataWith(fp16_tensor);
|
|
|
|
}
|
|
|
|
}
|
|
|
|
}
|
|
|
|
}
|
|
|
|
|
|
|
|
buffer->peek();
|
|
|
|
|
|
|
|
PADDLE_ENFORCE(buffer->eof(),
|
|
|
|
|
|
|
|
"You are not allowed to load partial data via "
|
|
|
|
|
|
|
|
"load_combine_op, use load_op instead.");
|
|
|
|
}
|
|
|
|
}
|
|
|
|
};
|
|
|
|
};
|
|
|
|
|
|
|
|
|
|
|
|