|
|
|
@ -21,9 +21,8 @@ namespace reader {
|
|
|
|
|
template <bool ThreadSafe>
|
|
|
|
|
class RecordIOFileReader : public framework::FileReader {
|
|
|
|
|
public:
|
|
|
|
|
explicit RecordIOFileReader(const std::string& filename,
|
|
|
|
|
const std::vector<framework::DDim>& dims)
|
|
|
|
|
: FileReader(dims),
|
|
|
|
|
explicit RecordIOFileReader(const std::string& filename)
|
|
|
|
|
: FileReader(),
|
|
|
|
|
scanner_(filename),
|
|
|
|
|
dev_ctx_(*platform::DeviceContextPool::Instance().Get(
|
|
|
|
|
platform::CPUPlace())) {
|
|
|
|
@ -58,20 +57,10 @@ class CreateRecordIOReaderOp : public framework::OperatorBase {
|
|
|
|
|
private:
|
|
|
|
|
void RunImpl(const framework::Scope& scope,
|
|
|
|
|
const platform::Place& dev_place) const override {
|
|
|
|
|
const auto& shape_concat = Attr<std::vector<int>>("shape_concat");
|
|
|
|
|
const auto& ranks = Attr<std::vector<int>>("ranks");
|
|
|
|
|
PADDLE_ENFORCE(!shape_concat.empty() && !ranks.empty());
|
|
|
|
|
PADDLE_ENFORCE_EQ(std::accumulate(ranks.begin(), ranks.end(), 0),
|
|
|
|
|
static_cast<int>(shape_concat.size()),
|
|
|
|
|
"The accumulate of all ranks should be equal to the "
|
|
|
|
|
"shape concat's length.");
|
|
|
|
|
std::string filename = Attr<std::string>("filename");
|
|
|
|
|
|
|
|
|
|
auto* out = scope.FindVar(Output("Out"))
|
|
|
|
|
->template GetMutable<framework::ReaderHolder>();
|
|
|
|
|
|
|
|
|
|
out->Reset(new RecordIOFileReader<true>(
|
|
|
|
|
filename, RestoreShapes(shape_concat, ranks)));
|
|
|
|
|
out->Reset(new RecordIOFileReader<true>(filename));
|
|
|
|
|
}
|
|
|
|
|
};
|
|
|
|
|
|
|
|
|
|