commit
28e9807247
@ -1,8 +1,8 @@
|
|||||||
cmake_minimum_required(VERSION 3.0)
|
cmake_minimum_required(VERSION 3.0)
|
||||||
|
|
||||||
include_directories(/env/gopath/src/github.com/PaddlePaddle/Paddle/paddle/go/cclient/build/)
|
include_directories(${CMAKE_BINARY_DIR})
|
||||||
|
|
||||||
add_executable(main main.c)
|
add_executable(main main.c)
|
||||||
add_dependencies(main client)
|
add_dependencies(main client)
|
||||||
set (CMAKE_EXE_LINKER_FLAGS "-pthread")
|
set (CMAKE_EXE_LINKER_FLAGS "-pthread")
|
||||||
target_link_libraries(main /env/gopath/src/github.com/PaddlePaddle/Paddle/paddle/go/cclient/build/libclient.a) # ${GTEST_LIBRARIES})
|
target_link_libraries(main ${CMAKE_BINARY_DIR}/libclient.a)
|
||||||
|
@ -0,0 +1,36 @@
|
|||||||
|
# RecordIO
|
||||||
|
|
||||||
|
## Write
|
||||||
|
|
||||||
|
```go
|
||||||
|
f, e := os.Create("a_file.recordio")
|
||||||
|
w := recordio.NewWriter(f)
|
||||||
|
w.Write([]byte("Hello"))
|
||||||
|
w.Write([]byte("World!"))
|
||||||
|
w.Close()
|
||||||
|
```
|
||||||
|
|
||||||
|
## Read
|
||||||
|
|
||||||
|
1. Load chunk index:
|
||||||
|
|
||||||
|
```go
|
||||||
|
f, e := os.Open("a_file.recordio")
|
||||||
|
idx, e := recordio.LoadIndex(f)
|
||||||
|
fmt.Println("Total records: ", idx.Len())
|
||||||
|
```
|
||||||
|
|
||||||
|
2. Create one or more scanner to read a range of records. The
|
||||||
|
following example reads the range
|
||||||
|
[1, 3), i.e., the second and the third records:
|
||||||
|
|
||||||
|
```go
|
||||||
|
f, e := os.Open("a_file.recordio")
|
||||||
|
s := recrodio.NewScanner(f, idx, 1, 3)
|
||||||
|
for s.Scan() {
|
||||||
|
fmt.Println(string(s.Record()))
|
||||||
|
}
|
||||||
|
if s.Err() != nil && s.Err() != io.EOF {
|
||||||
|
log.Fatalf("Something wrong with scanning: %v", e)
|
||||||
|
}
|
||||||
|
```
|
@ -0,0 +1,181 @@
|
|||||||
|
package recordio
|
||||||
|
|
||||||
|
import (
|
||||||
|
"bytes"
|
||||||
|
"compress/gzip"
|
||||||
|
"encoding/binary"
|
||||||
|
"fmt"
|
||||||
|
"hash/crc32"
|
||||||
|
"io"
|
||||||
|
|
||||||
|
"github.com/golang/snappy"
|
||||||
|
)
|
||||||
|
|
||||||
|
// A Chunk contains the Header and optionally compressed records. To
|
||||||
|
// create a chunk, just use ch := &Chunk{}.
|
||||||
|
type Chunk struct {
|
||||||
|
records [][]byte
|
||||||
|
numBytes int // sum of record lengths.
|
||||||
|
}
|
||||||
|
|
||||||
|
func (ch *Chunk) add(record []byte) {
|
||||||
|
ch.records = append(ch.records, record)
|
||||||
|
ch.numBytes += len(record)
|
||||||
|
}
|
||||||
|
|
||||||
|
// dump the chunk into w, and clears the chunk and makes it ready for
|
||||||
|
// the next add invocation.
|
||||||
|
func (ch *Chunk) dump(w io.Writer, compressorIndex int) error {
|
||||||
|
// NOTE: don't check ch.numBytes instead, because empty
|
||||||
|
// records are allowed.
|
||||||
|
if len(ch.records) == 0 {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// Write raw records and their lengths into data buffer.
|
||||||
|
var data bytes.Buffer
|
||||||
|
|
||||||
|
for _, r := range ch.records {
|
||||||
|
var rs [4]byte
|
||||||
|
binary.LittleEndian.PutUint32(rs[:], uint32(len(r)))
|
||||||
|
|
||||||
|
if _, e := data.Write(rs[:]); e != nil {
|
||||||
|
return fmt.Errorf("Failed to write record length: %v", e)
|
||||||
|
}
|
||||||
|
|
||||||
|
if _, e := data.Write(r); e != nil {
|
||||||
|
return fmt.Errorf("Failed to write record: %v", e)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
compressed, e := compressData(&data, compressorIndex)
|
||||||
|
if e != nil {
|
||||||
|
return e
|
||||||
|
}
|
||||||
|
|
||||||
|
// Write chunk header and compressed data.
|
||||||
|
hdr := &Header{
|
||||||
|
checkSum: crc32.ChecksumIEEE(compressed.Bytes()),
|
||||||
|
compressor: uint32(compressorIndex),
|
||||||
|
compressedSize: uint32(compressed.Len()),
|
||||||
|
numRecords: uint32(len(ch.records)),
|
||||||
|
}
|
||||||
|
|
||||||
|
if _, e := hdr.write(w); e != nil {
|
||||||
|
return fmt.Errorf("Failed to write chunk header: %v", e)
|
||||||
|
}
|
||||||
|
|
||||||
|
if _, e := w.Write(compressed.Bytes()); e != nil {
|
||||||
|
return fmt.Errorf("Failed to write chunk data: %v", e)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Clear the current chunk.
|
||||||
|
ch.records = nil
|
||||||
|
ch.numBytes = 0
|
||||||
|
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
type noopCompressor struct {
|
||||||
|
*bytes.Buffer
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *noopCompressor) Close() error {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func compressData(src io.Reader, compressorIndex int) (*bytes.Buffer, error) {
|
||||||
|
compressed := new(bytes.Buffer)
|
||||||
|
var compressor io.WriteCloser
|
||||||
|
|
||||||
|
switch compressorIndex {
|
||||||
|
case NoCompression:
|
||||||
|
compressor = &noopCompressor{compressed}
|
||||||
|
case Snappy:
|
||||||
|
compressor = snappy.NewBufferedWriter(compressed)
|
||||||
|
case Gzip:
|
||||||
|
compressor = gzip.NewWriter(compressed)
|
||||||
|
default:
|
||||||
|
return nil, fmt.Errorf("Unknown compression algorithm: %d", compressorIndex)
|
||||||
|
}
|
||||||
|
|
||||||
|
if _, e := io.Copy(compressor, src); e != nil {
|
||||||
|
return nil, fmt.Errorf("Failed to compress chunk data: %v", e)
|
||||||
|
}
|
||||||
|
compressor.Close()
|
||||||
|
|
||||||
|
return compressed, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// parse the specified chunk from r.
|
||||||
|
func parseChunk(r io.ReadSeeker, chunkOffset int64) (*Chunk, error) {
|
||||||
|
var e error
|
||||||
|
var hdr *Header
|
||||||
|
|
||||||
|
if _, e = r.Seek(chunkOffset, io.SeekStart); e != nil {
|
||||||
|
return nil, fmt.Errorf("Failed to seek chunk: %v", e)
|
||||||
|
}
|
||||||
|
|
||||||
|
hdr, e = parseHeader(r)
|
||||||
|
if e != nil {
|
||||||
|
return nil, fmt.Errorf("Failed to parse chunk header: %v", e)
|
||||||
|
}
|
||||||
|
|
||||||
|
var buf bytes.Buffer
|
||||||
|
if _, e = io.CopyN(&buf, r, int64(hdr.compressedSize)); e != nil {
|
||||||
|
return nil, fmt.Errorf("Failed to read chunk data: %v", e)
|
||||||
|
}
|
||||||
|
|
||||||
|
if hdr.checkSum != crc32.ChecksumIEEE(buf.Bytes()) {
|
||||||
|
return nil, fmt.Errorf("Checksum checking failed.")
|
||||||
|
}
|
||||||
|
|
||||||
|
deflated, e := deflateData(&buf, int(hdr.compressor))
|
||||||
|
if e != nil {
|
||||||
|
return nil, e
|
||||||
|
}
|
||||||
|
|
||||||
|
ch := &Chunk{}
|
||||||
|
for i := 0; i < int(hdr.numRecords); i++ {
|
||||||
|
var rs [4]byte
|
||||||
|
if _, e = deflated.Read(rs[:]); e != nil {
|
||||||
|
return nil, fmt.Errorf("Failed to read record length: %v", e)
|
||||||
|
}
|
||||||
|
|
||||||
|
r := make([]byte, binary.LittleEndian.Uint32(rs[:]))
|
||||||
|
if _, e = deflated.Read(r); e != nil {
|
||||||
|
return nil, fmt.Errorf("Failed to read a record: %v", e)
|
||||||
|
}
|
||||||
|
|
||||||
|
ch.records = append(ch.records, r)
|
||||||
|
ch.numBytes += len(r)
|
||||||
|
}
|
||||||
|
|
||||||
|
return ch, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func deflateData(src io.Reader, compressorIndex int) (*bytes.Buffer, error) {
|
||||||
|
var e error
|
||||||
|
var deflator io.Reader
|
||||||
|
|
||||||
|
switch compressorIndex {
|
||||||
|
case NoCompression:
|
||||||
|
deflator = src
|
||||||
|
case Snappy:
|
||||||
|
deflator = snappy.NewReader(src)
|
||||||
|
case Gzip:
|
||||||
|
deflator, e = gzip.NewReader(src)
|
||||||
|
if e != nil {
|
||||||
|
return nil, fmt.Errorf("Failed to create gzip reader: %v", e)
|
||||||
|
}
|
||||||
|
default:
|
||||||
|
return nil, fmt.Errorf("Unknown compression algorithm: %d", compressorIndex)
|
||||||
|
}
|
||||||
|
|
||||||
|
deflated := new(bytes.Buffer)
|
||||||
|
if _, e = io.Copy(deflated, deflator); e != nil {
|
||||||
|
return nil, fmt.Errorf("Failed to deflate chunk data: %v", e)
|
||||||
|
}
|
||||||
|
|
||||||
|
return deflated, nil
|
||||||
|
}
|
@ -0,0 +1,59 @@
|
|||||||
|
package recordio
|
||||||
|
|
||||||
|
import (
|
||||||
|
"encoding/binary"
|
||||||
|
"fmt"
|
||||||
|
"io"
|
||||||
|
)
|
||||||
|
|
||||||
|
const (
|
||||||
|
// NoCompression means writing raw chunk data into files.
|
||||||
|
// With other choices, chunks are compressed before written.
|
||||||
|
NoCompression = iota
|
||||||
|
// Snappy had been the default compressing algorithm widely
|
||||||
|
// used in Google. It compromises between speech and
|
||||||
|
// compression ratio.
|
||||||
|
Snappy
|
||||||
|
// Gzip is a well-known compression algorithm. It is
|
||||||
|
// recommmended only you are looking for compression ratio.
|
||||||
|
Gzip
|
||||||
|
|
||||||
|
magicNumber uint32 = 0x01020304
|
||||||
|
defaultCompressor = Snappy
|
||||||
|
)
|
||||||
|
|
||||||
|
// Header is the metadata of Chunk.
|
||||||
|
type Header struct {
|
||||||
|
checkSum uint32
|
||||||
|
compressor uint32
|
||||||
|
compressedSize uint32
|
||||||
|
numRecords uint32
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *Header) write(w io.Writer) (int, error) {
|
||||||
|
var buf [20]byte
|
||||||
|
binary.LittleEndian.PutUint32(buf[0:4], magicNumber)
|
||||||
|
binary.LittleEndian.PutUint32(buf[4:8], c.checkSum)
|
||||||
|
binary.LittleEndian.PutUint32(buf[8:12], c.compressor)
|
||||||
|
binary.LittleEndian.PutUint32(buf[12:16], c.compressedSize)
|
||||||
|
binary.LittleEndian.PutUint32(buf[16:20], c.numRecords)
|
||||||
|
return w.Write(buf[:])
|
||||||
|
}
|
||||||
|
|
||||||
|
func parseHeader(r io.Reader) (*Header, error) {
|
||||||
|
var buf [20]byte
|
||||||
|
if _, e := r.Read(buf[:]); e != nil {
|
||||||
|
return nil, e
|
||||||
|
}
|
||||||
|
|
||||||
|
if v := binary.LittleEndian.Uint32(buf[0:4]); v != magicNumber {
|
||||||
|
return nil, fmt.Errorf("Failed to parse magic number")
|
||||||
|
}
|
||||||
|
|
||||||
|
return &Header{
|
||||||
|
checkSum: binary.LittleEndian.Uint32(buf[4:8]),
|
||||||
|
compressor: binary.LittleEndian.Uint32(buf[8:12]),
|
||||||
|
compressedSize: binary.LittleEndian.Uint32(buf[12:16]),
|
||||||
|
numRecords: binary.LittleEndian.Uint32(buf[16:20]),
|
||||||
|
}, nil
|
||||||
|
}
|
@ -0,0 +1,135 @@
|
|||||||
|
package recordio
|
||||||
|
|
||||||
|
import "io"
|
||||||
|
|
||||||
|
// Index consists offsets and sizes of the consequetive chunks in a RecordIO file.
|
||||||
|
type Index struct {
|
||||||
|
chunkOffsets []int64
|
||||||
|
chunkLens []uint32
|
||||||
|
numRecords int // the number of all records in a file.
|
||||||
|
chunkRecords []int // the number of records in chunks.
|
||||||
|
}
|
||||||
|
|
||||||
|
// LoadIndex scans the file and parse chunkOffsets, chunkLens, and len.
|
||||||
|
func LoadIndex(r io.ReadSeeker) (*Index, error) {
|
||||||
|
f := &Index{}
|
||||||
|
offset := int64(0)
|
||||||
|
var e error
|
||||||
|
var hdr *Header
|
||||||
|
|
||||||
|
for {
|
||||||
|
hdr, e = parseHeader(r)
|
||||||
|
if e != nil {
|
||||||
|
break
|
||||||
|
}
|
||||||
|
|
||||||
|
f.chunkOffsets = append(f.chunkOffsets, offset)
|
||||||
|
f.chunkLens = append(f.chunkLens, hdr.numRecords)
|
||||||
|
f.chunkRecords = append(f.chunkRecords, int(hdr.numRecords))
|
||||||
|
f.numRecords += int(hdr.numRecords)
|
||||||
|
|
||||||
|
offset, e = r.Seek(int64(hdr.compressedSize), io.SeekCurrent)
|
||||||
|
if e != nil {
|
||||||
|
break
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
if e == io.EOF {
|
||||||
|
return f, nil
|
||||||
|
}
|
||||||
|
return nil, e
|
||||||
|
}
|
||||||
|
|
||||||
|
// NumRecords returns the total number of records in a RecordIO file.
|
||||||
|
func (r *Index) NumRecords() int {
|
||||||
|
return r.numRecords
|
||||||
|
}
|
||||||
|
|
||||||
|
// NumChunks returns the total number of chunks in a RecordIO file.
|
||||||
|
func (r *Index) NumChunks() int {
|
||||||
|
return len(r.chunkLens)
|
||||||
|
}
|
||||||
|
|
||||||
|
// ChunkIndex return the Index of i-th Chunk.
|
||||||
|
func (r *Index) ChunkIndex(i int) *Index {
|
||||||
|
idx := &Index{}
|
||||||
|
idx.chunkOffsets = []int64{r.chunkOffsets[i]}
|
||||||
|
idx.chunkLens = []uint32{r.chunkLens[i]}
|
||||||
|
idx.chunkRecords = []int{r.chunkRecords[i]}
|
||||||
|
idx.numRecords = idx.chunkRecords[0]
|
||||||
|
return idx
|
||||||
|
}
|
||||||
|
|
||||||
|
// Locate returns the index of chunk that contains the given record,
|
||||||
|
// and the record index within the chunk. It returns (-1, -1) if the
|
||||||
|
// record is out of range.
|
||||||
|
func (r *Index) Locate(recordIndex int) (int, int) {
|
||||||
|
sum := 0
|
||||||
|
for i, l := range r.chunkLens {
|
||||||
|
sum += int(l)
|
||||||
|
if recordIndex < sum {
|
||||||
|
return i, recordIndex - sum + int(l)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return -1, -1
|
||||||
|
}
|
||||||
|
|
||||||
|
// Scanner scans records in a specified range within [0, numRecords).
|
||||||
|
type Scanner struct {
|
||||||
|
reader io.ReadSeeker
|
||||||
|
index *Index
|
||||||
|
start, end, cur int
|
||||||
|
chunkIndex int
|
||||||
|
chunk *Chunk
|
||||||
|
err error
|
||||||
|
}
|
||||||
|
|
||||||
|
// NewScanner creates a scanner that sequencially reads records in the
|
||||||
|
// range [start, start+len). If start < 0, it scans from the
|
||||||
|
// beginning. If len < 0, it scans till the end of file.
|
||||||
|
func NewScanner(r io.ReadSeeker, index *Index, start, len int) *Scanner {
|
||||||
|
if start < 0 {
|
||||||
|
start = 0
|
||||||
|
}
|
||||||
|
if len < 0 || start+len >= index.NumRecords() {
|
||||||
|
len = index.NumRecords() - start
|
||||||
|
}
|
||||||
|
|
||||||
|
return &Scanner{
|
||||||
|
reader: r,
|
||||||
|
index: index,
|
||||||
|
start: start,
|
||||||
|
end: start + len,
|
||||||
|
cur: start - 1, // The intial status required by Scan.
|
||||||
|
chunkIndex: -1,
|
||||||
|
chunk: &Chunk{},
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Scan moves the cursor forward for one record and loads the chunk
|
||||||
|
// containing the record if not yet.
|
||||||
|
func (s *Scanner) Scan() bool {
|
||||||
|
s.cur++
|
||||||
|
|
||||||
|
if s.cur >= s.end {
|
||||||
|
s.err = io.EOF
|
||||||
|
} else {
|
||||||
|
if ci, _ := s.index.Locate(s.cur); s.chunkIndex != ci {
|
||||||
|
s.chunkIndex = ci
|
||||||
|
s.chunk, s.err = parseChunk(s.reader, s.index.chunkOffsets[ci])
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return s.err == nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// Record returns the record under the current cursor.
|
||||||
|
func (s *Scanner) Record() []byte {
|
||||||
|
_, ri := s.index.Locate(s.cur)
|
||||||
|
return s.chunk.records[ri]
|
||||||
|
}
|
||||||
|
|
||||||
|
// Error returns the error that stopped Scan.
|
||||||
|
func (s *Scanner) Error() error {
|
||||||
|
return s.err
|
||||||
|
}
|
@ -0,0 +1,90 @@
|
|||||||
|
package recordio
|
||||||
|
|
||||||
|
import (
|
||||||
|
"bytes"
|
||||||
|
"testing"
|
||||||
|
"unsafe"
|
||||||
|
|
||||||
|
"github.com/stretchr/testify/assert"
|
||||||
|
)
|
||||||
|
|
||||||
|
func TestChunkHead(t *testing.T) {
|
||||||
|
assert := assert.New(t)
|
||||||
|
|
||||||
|
c := &Header{
|
||||||
|
checkSum: 123,
|
||||||
|
compressor: 456,
|
||||||
|
compressedSize: 789,
|
||||||
|
}
|
||||||
|
|
||||||
|
var buf bytes.Buffer
|
||||||
|
_, e := c.write(&buf)
|
||||||
|
assert.Nil(e)
|
||||||
|
|
||||||
|
cc, e := parseHeader(&buf)
|
||||||
|
assert.Nil(e)
|
||||||
|
assert.Equal(c, cc)
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestWriteAndRead(t *testing.T) {
|
||||||
|
assert := assert.New(t)
|
||||||
|
|
||||||
|
data := []string{
|
||||||
|
"12345",
|
||||||
|
"1234",
|
||||||
|
"12"}
|
||||||
|
|
||||||
|
var buf bytes.Buffer
|
||||||
|
w := NewWriter(&buf, 10, NoCompression) // use a small maxChunkSize.
|
||||||
|
|
||||||
|
n, e := w.Write([]byte(data[0])) // not exceed chunk size.
|
||||||
|
assert.Nil(e)
|
||||||
|
assert.Equal(5, n)
|
||||||
|
|
||||||
|
n, e = w.Write([]byte(data[1])) // not exceed chunk size.
|
||||||
|
assert.Nil(e)
|
||||||
|
assert.Equal(4, n)
|
||||||
|
|
||||||
|
n, e = w.Write([]byte(data[2])) // exeeds chunk size, dump and create a new chunk.
|
||||||
|
assert.Nil(e)
|
||||||
|
assert.Equal(n, 2)
|
||||||
|
|
||||||
|
assert.Nil(w.Close()) // flush the second chunk.
|
||||||
|
assert.Nil(w.Writer)
|
||||||
|
|
||||||
|
n, e = w.Write([]byte("anything")) // not effective after close.
|
||||||
|
assert.NotNil(e)
|
||||||
|
assert.Equal(n, 0)
|
||||||
|
|
||||||
|
idx, e := LoadIndex(bytes.NewReader(buf.Bytes()))
|
||||||
|
assert.Nil(e)
|
||||||
|
assert.Equal([]uint32{2, 1}, idx.chunkLens)
|
||||||
|
assert.Equal(
|
||||||
|
[]int64{0,
|
||||||
|
int64(4 + // magic number
|
||||||
|
unsafe.Sizeof(Header{}) +
|
||||||
|
5 + // first record
|
||||||
|
4 + // second record
|
||||||
|
2*4)}, // two record legnths
|
||||||
|
idx.chunkOffsets)
|
||||||
|
|
||||||
|
s := NewScanner(bytes.NewReader(buf.Bytes()), idx, -1, -1)
|
||||||
|
i := 0
|
||||||
|
for s.Scan() {
|
||||||
|
assert.Equal(data[i], string(s.Record()))
|
||||||
|
i++
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestWriteEmptyFile(t *testing.T) {
|
||||||
|
assert := assert.New(t)
|
||||||
|
|
||||||
|
var buf bytes.Buffer
|
||||||
|
w := NewWriter(&buf, 10, NoCompression) // use a small maxChunkSize.
|
||||||
|
assert.Nil(w.Close())
|
||||||
|
assert.Equal(0, buf.Len())
|
||||||
|
|
||||||
|
idx, e := LoadIndex(bytes.NewReader(buf.Bytes()))
|
||||||
|
assert.Nil(e)
|
||||||
|
assert.Equal(0, idx.NumRecords())
|
||||||
|
}
|
@ -0,0 +1,81 @@
|
|||||||
|
package recordio_test
|
||||||
|
|
||||||
|
import (
|
||||||
|
"bytes"
|
||||||
|
"reflect"
|
||||||
|
"testing"
|
||||||
|
|
||||||
|
"github.com/PaddlePaddle/Paddle/paddle/go/recordio"
|
||||||
|
)
|
||||||
|
|
||||||
|
func TestWriteRead(t *testing.T) {
|
||||||
|
const total = 1000
|
||||||
|
var buf bytes.Buffer
|
||||||
|
w := recordio.NewWriter(&buf, 0, -1)
|
||||||
|
for i := 0; i < total; i++ {
|
||||||
|
_, err := w.Write(make([]byte, i))
|
||||||
|
if err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
w.Close()
|
||||||
|
|
||||||
|
idx, err := recordio.LoadIndex(bytes.NewReader(buf.Bytes()))
|
||||||
|
if err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
|
||||||
|
if idx.NumRecords() != total {
|
||||||
|
t.Fatal("num record does not match:", idx.NumRecords(), total)
|
||||||
|
}
|
||||||
|
|
||||||
|
s := recordio.NewScanner(bytes.NewReader(buf.Bytes()), idx, -1, -1)
|
||||||
|
i := 0
|
||||||
|
for s.Scan() {
|
||||||
|
if !reflect.DeepEqual(s.Record(), make([]byte, i)) {
|
||||||
|
t.Fatal("not equal:", len(s.Record()), len(make([]byte, i)))
|
||||||
|
}
|
||||||
|
i++
|
||||||
|
}
|
||||||
|
|
||||||
|
if i != total {
|
||||||
|
t.Fatal("total count not match:", i, total)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestChunkIndex(t *testing.T) {
|
||||||
|
const total = 1000
|
||||||
|
var buf bytes.Buffer
|
||||||
|
w := recordio.NewWriter(&buf, 0, -1)
|
||||||
|
for i := 0; i < total; i++ {
|
||||||
|
_, err := w.Write(make([]byte, i))
|
||||||
|
if err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
w.Close()
|
||||||
|
|
||||||
|
idx, err := recordio.LoadIndex(bytes.NewReader(buf.Bytes()))
|
||||||
|
if err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
|
||||||
|
if idx.NumChunks() != total {
|
||||||
|
t.Fatal("unexpected chunk num:", idx.NumChunks(), total)
|
||||||
|
}
|
||||||
|
|
||||||
|
for i := 0; i < total; i++ {
|
||||||
|
newIdx := idx.ChunkIndex(i)
|
||||||
|
s := recordio.NewScanner(bytes.NewReader(buf.Bytes()), newIdx, -1, -1)
|
||||||
|
j := 0
|
||||||
|
for s.Scan() {
|
||||||
|
if !reflect.DeepEqual(s.Record(), make([]byte, i)) {
|
||||||
|
t.Fatal("not equal:", len(s.Record()), len(make([]byte, i)))
|
||||||
|
}
|
||||||
|
j++
|
||||||
|
}
|
||||||
|
if j != 1 {
|
||||||
|
t.Fatal("unexpected record per chunk:", j)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
@ -0,0 +1,60 @@
|
|||||||
|
package recordio
|
||||||
|
|
||||||
|
import (
|
||||||
|
"fmt"
|
||||||
|
"io"
|
||||||
|
)
|
||||||
|
|
||||||
|
const (
|
||||||
|
defaultMaxChunkSize = 32 * 1024 * 1024
|
||||||
|
)
|
||||||
|
|
||||||
|
// Writer creates a RecordIO file.
|
||||||
|
type Writer struct {
|
||||||
|
io.Writer // Set to nil to mark a closed writer.
|
||||||
|
chunk *Chunk
|
||||||
|
maxChunkSize int // total records size, excluding metadata, before compression.
|
||||||
|
compressor int
|
||||||
|
}
|
||||||
|
|
||||||
|
// NewWriter creates a RecordIO file writer. Each chunk is compressed
|
||||||
|
// using the deflate algorithm given compression level. Note that
|
||||||
|
// level 0 means no compression and -1 means default compression.
|
||||||
|
func NewWriter(w io.Writer, maxChunkSize, compressor int) *Writer {
|
||||||
|
if maxChunkSize < 0 {
|
||||||
|
maxChunkSize = defaultMaxChunkSize
|
||||||
|
}
|
||||||
|
|
||||||
|
if compressor < 0 {
|
||||||
|
compressor = defaultCompressor
|
||||||
|
}
|
||||||
|
|
||||||
|
return &Writer{
|
||||||
|
Writer: w,
|
||||||
|
chunk: &Chunk{},
|
||||||
|
maxChunkSize: maxChunkSize,
|
||||||
|
compressor: compressor}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Writes a record. It returns an error if Close has been called.
|
||||||
|
func (w *Writer) Write(record []byte) (int, error) {
|
||||||
|
if w.Writer == nil {
|
||||||
|
return 0, fmt.Errorf("Cannot write since writer had been closed")
|
||||||
|
}
|
||||||
|
|
||||||
|
if w.chunk.numBytes+len(record) > w.maxChunkSize {
|
||||||
|
if e := w.chunk.dump(w.Writer, w.compressor); e != nil {
|
||||||
|
return 0, e
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
w.chunk.add(record)
|
||||||
|
return len(record), nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// Close flushes the current chunk and makes the writer invalid.
|
||||||
|
func (w *Writer) Close() error {
|
||||||
|
e := w.chunk.dump(w.Writer, w.compressor)
|
||||||
|
w.Writer = nil
|
||||||
|
return e
|
||||||
|
}
|
File diff suppressed because it is too large
Load Diff
@ -0,0 +1,33 @@
|
|||||||
|
# Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
|
||||||
|
import paddle.v2.dataset.mq2007
|
||||||
|
import unittest
|
||||||
|
|
||||||
|
|
||||||
|
class TestMQ2007(unittest.TestCase):
|
||||||
|
def test_pairwise(self):
|
||||||
|
for label, query_left, query_right in paddle.v2.dataset.mq2007.test(
|
||||||
|
format="pairwise"):
|
||||||
|
self.assertEqual(query_left.shape(), (46, ))
|
||||||
|
self.assertEqual(query_right.shape(), (46, ))
|
||||||
|
|
||||||
|
def test_listwise(self):
|
||||||
|
for label_array, query_array in paddle.v2.dataset.mq2007.test(
|
||||||
|
format="listwise"):
|
||||||
|
self.assertEqual(len(label_array), len(query_array))
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
unittest.main()
|
Loading…
Reference in new issue