parent
10343123e3
commit
bcb80756af
@ -0,0 +1,51 @@
|
||||
// Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
#include "paddle/fluid/recordio/scanner.h"
|
||||
#include "paddle/fluid/platform/enforce.h"
|
||||
|
||||
namespace paddle {
|
||||
namespace recordio {
|
||||
Scanner::Scanner(std::unique_ptr<std::istream> &&stream)
|
||||
: stream_(std::move(stream)) {
|
||||
Reset();
|
||||
}
|
||||
|
||||
Scanner::Scanner(const std::string &filename) {
|
||||
stream_.reset(new std::ifstream(filename));
|
||||
Reset();
|
||||
}
|
||||
|
||||
void Scanner::Reset() {
|
||||
stream_->seekg(0, std::ios::beg);
|
||||
ParseNextChunk();
|
||||
}
|
||||
|
||||
const std::string &Scanner::Next() {
|
||||
PADDLE_ENFORCE(!eof_, "StopIteration");
|
||||
auto &rec = cur_chunk_.Record(offset_++);
|
||||
if (offset_ == cur_chunk_.NumRecords()) {
|
||||
ParseNextChunk();
|
||||
}
|
||||
return rec;
|
||||
}
|
||||
|
||||
void Scanner::ParseNextChunk() {
|
||||
eof_ = !cur_chunk_.Parse(*stream_);
|
||||
offset_ = 0;
|
||||
}
|
||||
|
||||
bool Scanner::HasNext() const { return !eof_; }
|
||||
} // namespace recordio
|
||||
} // namespace paddle
|
@ -0,0 +1,44 @@
|
||||
// Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
#pragma once
|
||||
|
||||
#include <fstream>
|
||||
#include <memory>
|
||||
#include "paddle/fluid/recordio/chunk.h"
|
||||
namespace paddle {
|
||||
namespace recordio {
|
||||
|
||||
class Scanner {
|
||||
public:
|
||||
explicit Scanner(std::unique_ptr<std::istream>&& stream);
|
||||
|
||||
explicit Scanner(const std::string& filename);
|
||||
|
||||
void Reset();
|
||||
|
||||
const std::string& Next();
|
||||
|
||||
bool HasNext() const;
|
||||
|
||||
private:
|
||||
std::unique_ptr<std::istream> stream_;
|
||||
Chunk cur_chunk_;
|
||||
size_t offset_;
|
||||
bool eof_;
|
||||
|
||||
void ParseNextChunk();
|
||||
};
|
||||
} // namespace recordio
|
||||
} // namespace paddle
|
@ -0,0 +1,35 @@
|
||||
// Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
#include "paddle/fluid/recordio/writer.h"
|
||||
|
||||
namespace paddle {
|
||||
namespace recordio {
|
||||
void Writer::Write(const std::string& record) {
|
||||
cur_chunk_.Add(record);
|
||||
if (cur_chunk_.NumRecords() >= max_num_records_in_chunk_) {
|
||||
Flush();
|
||||
}
|
||||
}
|
||||
|
||||
void Writer::Flush() {
|
||||
cur_chunk_.Write(stream_, compressor_);
|
||||
cur_chunk_.Clear();
|
||||
}
|
||||
|
||||
Writer::~Writer() {
|
||||
PADDLE_ENFORCE(cur_chunk_.Empty(), "Writer must be flushed when destroy.");
|
||||
}
|
||||
|
||||
} // namespace recordio
|
||||
} // namespace paddle
|
@ -0,0 +1,44 @@
|
||||
// Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
#pragma once
|
||||
#include "paddle/fluid/framework/lod_tensor.h"
|
||||
#include "paddle/fluid/recordio/chunk.h"
|
||||
namespace paddle {
|
||||
namespace recordio {
|
||||
|
||||
class Writer {
|
||||
public:
|
||||
Writer(std::ostream* sout,
|
||||
Compressor compressor,
|
||||
size_t max_num_records_in_chunk = 1000)
|
||||
: stream_(*sout),
|
||||
max_num_records_in_chunk_(max_num_records_in_chunk),
|
||||
compressor_(compressor) {}
|
||||
|
||||
void Write(const std::string& record);
|
||||
|
||||
void Flush();
|
||||
|
||||
~Writer();
|
||||
|
||||
private:
|
||||
std::ostream& stream_;
|
||||
size_t max_num_records_in_chunk_;
|
||||
Chunk cur_chunk_;
|
||||
Compressor compressor_;
|
||||
};
|
||||
|
||||
} // namespace recordio
|
||||
} // namespace paddle
|
@ -0,0 +1,44 @@
|
||||
// Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
#include "gtest/gtest.h"
|
||||
|
||||
#include <sstream>
|
||||
#include "paddle/fluid/recordio/scanner.h"
|
||||
#include "paddle/fluid/recordio/writer.h"
|
||||
|
||||
TEST(WriterScanner, Normal) {
|
||||
std::stringstream* stream = new std::stringstream();
|
||||
|
||||
{
|
||||
paddle::recordio::Writer writer(stream,
|
||||
paddle::recordio::Compressor::kSnappy);
|
||||
writer.Write("ABC");
|
||||
writer.Write("BCD");
|
||||
writer.Write("CDE");
|
||||
writer.Flush();
|
||||
}
|
||||
|
||||
{
|
||||
stream->seekg(0, std::ios::beg);
|
||||
std::unique_ptr<std::istream> stream_ptr(stream);
|
||||
paddle::recordio::Scanner scanner(std::move(stream_ptr));
|
||||
ASSERT_TRUE(scanner.HasNext());
|
||||
ASSERT_EQ(scanner.Next(), "ABC");
|
||||
ASSERT_EQ("BCD", scanner.Next());
|
||||
ASSERT_TRUE(scanner.HasNext());
|
||||
ASSERT_EQ("CDE", scanner.Next());
|
||||
ASSERT_FALSE(scanner.HasNext());
|
||||
}
|
||||
}
|
Loading…
Reference in new issue