From a868d010658642c3612ca37b488123d63623a967 Mon Sep 17 00:00:00 2001 From: Helin Wang Date: Thu, 25 May 2017 20:53:13 -0400 Subject: [PATCH 1/8] add cgo wrapper for recordio, make go_cmake automatically download go dependency --- paddle/go/CMakeLists.txt | 11 - paddle/go/adder.go | 10 - paddle/go/cclient/CMakeLists.txt | 31 +-- paddle/go/cclient/cclient.go | 7 +- paddle/go/cgo_test.cc | 5 - .../cmake/CMakeDetermineGoCompiler.cmake | 2 +- .../cmake/CMakeGoCompiler.cmake.in | 0 .../cmake/CMakeGoInformation.cmake | 0 .../cmake/CMakeTestGoCompiler.cmake | 0 paddle/go/{cclient => }/cmake/flags.cmake | 4 +- paddle/go/{cclient => }/cmake/golang.cmake | 40 ++-- paddle/go/crecordio/CMakeLists.txt | 12 + paddle/go/crecordio/crecordio.go | 208 ++++++++++++++++++ paddle/go/crecordio/register.go | 61 +++++ paddle/go/crecordio/test/CMakeLists.txt | 8 + paddle/go/crecordio/test/test.c | 31 +++ paddle/go/recordio/README.md | 5 +- 17 files changed, 361 insertions(+), 74 deletions(-) delete mode 100644 paddle/go/CMakeLists.txt delete mode 100644 paddle/go/adder.go delete mode 100644 paddle/go/cgo_test.cc rename paddle/go/{cclient => }/cmake/CMakeDetermineGoCompiler.cmake (94%) rename paddle/go/{cclient => }/cmake/CMakeGoCompiler.cmake.in (100%) rename paddle/go/{cclient => }/cmake/CMakeGoInformation.cmake (100%) rename paddle/go/{cclient => }/cmake/CMakeTestGoCompiler.cmake (100%) rename paddle/go/{cclient => }/cmake/flags.cmake (95%) rename paddle/go/{cclient => }/cmake/golang.cmake (50%) create mode 100644 paddle/go/crecordio/CMakeLists.txt create mode 100644 paddle/go/crecordio/crecordio.go create mode 100644 paddle/go/crecordio/register.go create mode 100644 paddle/go/crecordio/test/CMakeLists.txt create mode 100644 paddle/go/crecordio/test/test.c diff --git a/paddle/go/CMakeLists.txt b/paddle/go/CMakeLists.txt deleted file mode 100644 index 51c5252d66..0000000000 --- a/paddle/go/CMakeLists.txt +++ /dev/null @@ -1,11 +0,0 @@ -include_directories(${CMAKE_CURRENT_BINARY_DIR}) - -go_library(adder SRCS adder.go) - -if (WITH_TESTING) - cc_test(cgo_test - SRCS - cgo_test.cc - DEPS - adder) -endif() diff --git a/paddle/go/adder.go b/paddle/go/adder.go deleted file mode 100644 index e14f40fd9f..0000000000 --- a/paddle/go/adder.go +++ /dev/null @@ -1,10 +0,0 @@ -package main - -import "C" - -//export GoAdder -func GoAdder(x, y int) int { - return x + y -} - -func main() {} // Required but ignored diff --git a/paddle/go/cclient/CMakeLists.txt b/paddle/go/cclient/CMakeLists.txt index c85ff3db09..e3e9fa9f1a 100644 --- a/paddle/go/cclient/CMakeLists.txt +++ b/paddle/go/cclient/CMakeLists.txt @@ -1,31 +1,12 @@ cmake_minimum_required(VERSION 3.0) -if(GTEST_INCLUDE_DIR AND GTEST_LIBRARIES) - message("-- Found gtest (include: ${GTEST_INCLUDE_DIR}, library: ${GTEST_LIBRARIES})") -else() - # find cmake directory modules - get_filename_component(PARENT_DIR ${CMAKE_CURRENT_SOURCE_DIR} DIRECTORY) - get_filename_component(PARENT_DIR ${PARENT_DIR} DIRECTORY) - get_filename_component(PARENT_DIR ${PARENT_DIR} DIRECTORY) +get_filename_component(PARENT_DIR ${CMAKE_CURRENT_SOURCE_DIR} DIRECTORY) +set(CMAKE_MODULE_PATH ${CMAKE_MODULE_PATH} "${PARENT_DIR}/cmake") - set(CMAKE_MODULE_PATH ${CMAKE_MODULE_PATH} "${PARENT_DIR}/cmake") +project(cxx_go C Go) - # enable c++11 - set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -std=c++11") +include(golang) +include(flags) - # enable gtest - set(THIRD_PARTY_PATH ./third_party) - set(WITH_TESTING ON) - include(external/gtest) -endif() - -set(CMAKE_MODULE_PATH ${CMAKE_MODULE_PATH} "${CMAKE_CURRENT_SOURCE_DIR}/cmake") - -project(cxx_go CXX C Go) - -include(cmake/golang.cmake) -include(cmake/flags.cmake) - -ExternalGoProject_Add(pserver github.com/PaddlePaddle/Paddle/paddle/go/pserver) -add_go_library(client STATIC pserver) +add_go_library(client STATIC) add_subdirectory(test) diff --git a/paddle/go/cclient/cclient.go b/paddle/go/cclient/cclient.go index dc86d47e8d..654b6f68a4 100644 --- a/paddle/go/cclient/cclient.go +++ b/paddle/go/cclient/cclient.go @@ -78,8 +78,11 @@ func cArrayToSlice(p unsafe.Pointer, len int) []byte { return nil } - // create a Go clice backed by a C array, - // reference: https://github.com/golang/go/wiki/cgo#turning-c-arrays-into-go-slices + // create a Go clice backed by a C array, reference: + // https://github.com/golang/go/wiki/cgo#turning-c-arrays-into-go-slices + // + // Go garbage collector will not interact with this data, need + // to be freed from C side. return (*[1 << 30]byte)(p)[:len:len] } diff --git a/paddle/go/cgo_test.cc b/paddle/go/cgo_test.cc deleted file mode 100644 index 64efa606ff..0000000000 --- a/paddle/go/cgo_test.cc +++ /dev/null @@ -1,5 +0,0 @@ -#include -#include "gtest/gtest.h" -#include "libadder.h" - -TEST(Cgo, Invoke) { EXPECT_EQ(GoAdder(30, 12), 42); } diff --git a/paddle/go/cclient/cmake/CMakeDetermineGoCompiler.cmake b/paddle/go/cmake/CMakeDetermineGoCompiler.cmake similarity index 94% rename from paddle/go/cclient/cmake/CMakeDetermineGoCompiler.cmake rename to paddle/go/cmake/CMakeDetermineGoCompiler.cmake index b3f8fbe271..a9bb6906c7 100644 --- a/paddle/go/cclient/cmake/CMakeDetermineGoCompiler.cmake +++ b/paddle/go/cmake/CMakeDetermineGoCompiler.cmake @@ -38,7 +38,7 @@ endif() mark_as_advanced(CMAKE_Go_COMPILER) -configure_file(${CMAKE_CURRENT_SOURCE_DIR}/cmake/CMakeGoCompiler.cmake.in +configure_file(${CMAKE_MODULE_PATH}/CMakeGoCompiler.cmake.in ${CMAKE_PLATFORM_INFO_DIR}/CMakeGoCompiler.cmake @ONLY) set(CMAKE_Go_COMPILER_ENV_VAR "GO_COMPILER") diff --git a/paddle/go/cclient/cmake/CMakeGoCompiler.cmake.in b/paddle/go/cmake/CMakeGoCompiler.cmake.in similarity index 100% rename from paddle/go/cclient/cmake/CMakeGoCompiler.cmake.in rename to paddle/go/cmake/CMakeGoCompiler.cmake.in diff --git a/paddle/go/cclient/cmake/CMakeGoInformation.cmake b/paddle/go/cmake/CMakeGoInformation.cmake similarity index 100% rename from paddle/go/cclient/cmake/CMakeGoInformation.cmake rename to paddle/go/cmake/CMakeGoInformation.cmake diff --git a/paddle/go/cclient/cmake/CMakeTestGoCompiler.cmake b/paddle/go/cmake/CMakeTestGoCompiler.cmake similarity index 100% rename from paddle/go/cclient/cmake/CMakeTestGoCompiler.cmake rename to paddle/go/cmake/CMakeTestGoCompiler.cmake diff --git a/paddle/go/cclient/cmake/flags.cmake b/paddle/go/cmake/flags.cmake similarity index 95% rename from paddle/go/cclient/cmake/flags.cmake rename to paddle/go/cmake/flags.cmake index 062d5ab660..a167c432a9 100644 --- a/paddle/go/cclient/cmake/flags.cmake +++ b/paddle/go/cmake/flags.cmake @@ -21,7 +21,7 @@ function(CheckCompilerCXX11Flag) if (${CMAKE_CXX_COMPILER_VERSION} VERSION_LESS 3.3) message(FATAL_ERROR "Unsupported Clang version. Clang >= 3.3 required.") endif() - endif() + endif() endif() endfunction() @@ -42,4 +42,4 @@ if (CUDA_VERSION VERSION_GREATER "8.0" OR CUDA_VERSION VERSION_EQUAL "8.0") list(APPEND __arch_flags " -gencode arch=compute_60,code=sm_60") endif() -set(CUDA_NVCC_FLAGS ${__arch_flags} ${CUDA_NVCC_FLAGS}) \ No newline at end of file +set(CUDA_NVCC_FLAGS ${__arch_flags} ${CUDA_NVCC_FLAGS}) diff --git a/paddle/go/cclient/cmake/golang.cmake b/paddle/go/cmake/golang.cmake similarity index 50% rename from paddle/go/cclient/cmake/golang.cmake rename to paddle/go/cmake/golang.cmake index 5d39868bfd..caddaae1bf 100644 --- a/paddle/go/cclient/cmake/golang.cmake +++ b/paddle/go/cmake/golang.cmake @@ -1,22 +1,7 @@ set(GOPATH "${CMAKE_CURRENT_BINARY_DIR}/go") file(MAKE_DIRECTORY ${GOPATH}) - -function(ExternalGoProject_Add TARG) - add_custom_target(${TARG} env GOPATH=${GOPATH} ${CMAKE_Go_COMPILER} get ${ARGN}) -endfunction(ExternalGoProject_Add) - -function(add_go_executable NAME) - file(GLOB GO_SOURCE RELATIVE "${CMAKE_CURRENT_SOURCE_DIR}" "*.go") - add_custom_command(OUTPUT ${OUTPUT_DIR}/.timestamp - COMMAND env GOPATH=${GOPATH} ${CMAKE_Go_COMPILER} build - -o "${CMAKE_CURRENT_BINARY_DIR}/${NAME}" - ${CMAKE_GO_FLAGS} ${GO_SOURCE} - WORKING_DIRECTORY ${CMAKE_CURRENT_LIST_DIR}) - - add_custom_target(${NAME} ALL DEPENDS ${OUTPUT_DIR}/.timestamp ${ARGN}) - install(PROGRAMS ${CMAKE_CURRENT_BINARY_DIR}/${NAME} DESTINATION bin) -endfunction(add_go_executable) - +set(PADDLE_IN_GOPATH "${GOPATH}/src/github.com/PaddlePaddle") +file(MAKE_DIRECTORY ${PADDLE_IN_GOPATH}) function(ADD_GO_LIBRARY NAME BUILD_TYPE) if(BUILD_TYPE STREQUAL "STATIC") @@ -32,6 +17,26 @@ function(ADD_GO_LIBRARY NAME BUILD_TYPE) endif() file(GLOB GO_SOURCE RELATIVE "${CMAKE_CURRENT_SOURCE_DIR}" "*.go") + file(RELATIVE_PATH rel ${CMAKE_BINARY_DIR} ${CMAKE_CURRENT_SOURCE_DIR}) + + # find Paddle directory. + get_filename_component(PARENT_DIR ${CMAKE_CURRENT_SOURCE_DIR} DIRECTORY) + get_filename_component(PARENT_DIR ${PARENT_DIR} DIRECTORY) + get_filename_component(PADDLE_DIR ${PARENT_DIR} DIRECTORY) + + # automatically get all dependencies specified in the source code + # for given target. + add_custom_target(goGet env GOPATH=${GOPATH} ${CMAKE_Go_COMPILER} get -d ${rel}/...) + + # make a symlink that references Paddle inside $GOPATH, so go get + # will use the local changes in Paddle rather than checkout Paddle + # in github. + if(NOT EXISTS ${PADDLE_IN_GOPATH}) + add_custom_target(copyPaddle + COMMAND ln -s ${PADDLE_DIR} ${PADDLE_IN_GOPATH}) + add_dependencies(goGet copyPaddle) + endif() + add_custom_command(OUTPUT ${OUTPUT_DIR}/.timestamp COMMAND env GOPATH=${GOPATH} ${CMAKE_Go_COMPILER} build ${BUILD_MODE} -o "${CMAKE_CURRENT_BINARY_DIR}/${LIB_NAME}" @@ -39,6 +44,7 @@ function(ADD_GO_LIBRARY NAME BUILD_TYPE) WORKING_DIRECTORY ${CMAKE_CURRENT_LIST_DIR}) add_custom_target(${NAME} ALL DEPENDS ${OUTPUT_DIR}/.timestamp ${ARGN}) + add_dependencies(${NAME} goGet) if(NOT BUILD_TYPE STREQUAL "STATIC") install(PROGRAMS ${CMAKE_CURRENT_BINARY_DIR}/${LIB_NAME} DESTINATION bin) diff --git a/paddle/go/crecordio/CMakeLists.txt b/paddle/go/crecordio/CMakeLists.txt new file mode 100644 index 0000000000..db8f556e50 --- /dev/null +++ b/paddle/go/crecordio/CMakeLists.txt @@ -0,0 +1,12 @@ +cmake_minimum_required(VERSION 3.0) + +get_filename_component(PARENT_DIR ${CMAKE_CURRENT_SOURCE_DIR} DIRECTORY) +set(CMAKE_MODULE_PATH ${CMAKE_MODULE_PATH} "${PARENT_DIR}/cmake") + +project(cxx_go C Go) + +include(golang) +include(flags) + +add_go_library(recordio STATIC) +add_subdirectory(test) diff --git a/paddle/go/crecordio/crecordio.go b/paddle/go/crecordio/crecordio.go new file mode 100644 index 0000000000..3335d0795f --- /dev/null +++ b/paddle/go/crecordio/crecordio.go @@ -0,0 +1,208 @@ +package main + +/* +#include + +typedef int reader; +typedef int writer; +*/ +import "C" + +import ( + "io" + "log" + "os" + "path/filepath" + "strings" + "unsafe" + + "github.com/PaddlePaddle/Paddle/paddle/go/recordio" +) + +var nullPtr = unsafe.Pointer(uintptr(0)) + +type writer struct { + w *recordio.Writer + f *os.File +} + +type reader struct { + buffer chan []byte + cancel chan struct{} +} + +func read(paths []string, buffer chan<- []byte, cancel chan struct{}) { + var curFile *os.File + var curScanner *recordio.Scanner + var pathIdx int + + var nextFile func() bool + nextFile = func() bool { + if pathIdx >= len(paths) { + return false + } + + path := paths[pathIdx] + pathIdx++ + f, err := os.Open(path) + if err != nil { + return nextFile() + } + + idx, err := recordio.LoadIndex(f) + if err != nil { + log.Println(err) + err = f.Close() + if err != nil { + log.Println(err) + } + + return nextFile() + } + + curFile = f + curScanner = recordio.NewScanner(f, idx, 0, -1) + return true + } + + more := nextFile() + if !more { + close(buffer) + return + } + + closeFile := func() { + err := curFile.Close() + if err != nil { + log.Println(err) + } + curFile = nil + } + + for { + for curScanner.Scan() { + select { + case buffer <- curScanner.Record(): + case <-cancel: + close(buffer) + closeFile() + return + } + } + + if err := curScanner.Error(); err != nil && err != io.EOF { + log.Println(err) + } + + closeFile() + more := nextFile() + if !more { + close(buffer) + return + } + } +} + +//export paddle_new_writer +func paddle_new_writer(path *C.char) C.writer { + p := C.GoString(path) + f, err := os.Create(p) + if err != nil { + log.Println(err) + return -1 + } + + w := recordio.NewWriter(f, -1, -1) + writer := &writer{f: f, w: w} + return addWriter(writer) +} + +func cArrayToSlice(p unsafe.Pointer, len int) []byte { + if p == nullPtr { + return nil + } + + // create a Go clice backed by a C array, reference: + // https://github.com/golang/go/wiki/cgo#turning-c-arrays-into-go-slices + // + // Go garbage collector will not interact with this data, need + // to be freed from C side. + return (*[1 << 30]byte)(p)[:len:len] +} + +//export paddle_writer_write +func paddle_writer_write(writer C.writer, buf *C.uchar, size C.int) int { + w := getWriter(writer) + b := cArrayToSlice(unsafe.Pointer(buf), int(size)) + _, err := w.w.Write(b) + if err != nil { + log.Println(err) + return -1 + } + + return 0 +} + +//export paddle_writer_release +func paddle_writer_release(writer C.writer) { + w := removeWriter(writer) + w.w.Close() + w.f.Close() +} + +//export paddle_new_reader +func paddle_new_reader(path *C.char, bufferSize C.int) C.reader { + p := C.GoString(path) + ss := strings.Split(p, ",") + var paths []string + for _, s := range ss { + match, err := filepath.Glob(s) + if err != nil { + log.Printf("error applying glob to %s: %v\n", s, err) + return -1 + } + + paths = append(paths, match...) + } + + if len(paths) == 0 { + log.Println("no valid path provided.", p) + return -1 + } + + buffer := make(chan []byte, int(bufferSize)) + cancel := make(chan struct{}) + r := &reader{buffer: buffer, cancel: cancel} + go read(paths, buffer, cancel) + return addReader(r) +} + +//export paddle_reader_next_item +func paddle_reader_next_item(reader C.reader, size *C.int) *C.uchar { + r := getReader(reader) + buf, ok := <-r.buffer + if !ok { + // channel closed and empty, reached EOF. + *size = -1 + return (*C.uchar)(nullPtr) + } + + if len(buf) == 0 { + // empty item + *size = 0 + return (*C.uchar)(nullPtr) + } + + ptr := C.malloc(C.size_t(len(buf))) + C.memcpy(ptr, unsafe.Pointer(&buf[0]), C.size_t(len(buf))) + *size = C.int(len(buf)) + return (*C.uchar)(ptr) +} + +//export paddle_reader_release +func paddle_reader_release(reader C.reader) { + r := removeReader(reader) + close(r.cancel) +} + +func main() {} // Required but ignored diff --git a/paddle/go/crecordio/register.go b/paddle/go/crecordio/register.go new file mode 100644 index 0000000000..61dfdbd4ab --- /dev/null +++ b/paddle/go/crecordio/register.go @@ -0,0 +1,61 @@ +package main + +/* +typedef int reader; +typedef int writer; +*/ +import "C" + +import "sync" + +var mu sync.Mutex +var handleMap = make(map[C.reader]*reader) +var curHandle C.reader +var writerMap = make(map[C.writer]*writer) +var curWriterHandle C.writer + +func addReader(r *reader) C.reader { + mu.Lock() + defer mu.Unlock() + reader := curHandle + curHandle++ + handleMap[reader] = r + return reader +} + +func getReader(reader C.reader) *reader { + mu.Lock() + defer mu.Unlock() + return handleMap[reader] +} + +func removeReader(reader C.reader) *reader { + mu.Lock() + defer mu.Unlock() + r := handleMap[reader] + delete(handleMap, reader) + return r +} + +func addWriter(w *writer) C.writer { + mu.Lock() + defer mu.Unlock() + writer := curWriterHandle + curWriterHandle++ + writerMap[writer] = w + return writer +} + +func getWriter(writer C.writer) *writer { + mu.Lock() + defer mu.Unlock() + return writerMap[writer] +} + +func removeWriter(writer C.writer) *writer { + mu.Lock() + defer mu.Unlock() + w := writerMap[writer] + delete(writerMap, writer) + return w +} diff --git a/paddle/go/crecordio/test/CMakeLists.txt b/paddle/go/crecordio/test/CMakeLists.txt new file mode 100644 index 0000000000..bac1006ae1 --- /dev/null +++ b/paddle/go/crecordio/test/CMakeLists.txt @@ -0,0 +1,8 @@ +cmake_minimum_required(VERSION 3.0) + +include_directories(${CMAKE_BINARY_DIR}) + +add_executable(recordio_test test.c) +add_dependencies(recordio_test recordio) +set (CMAKE_EXE_LINKER_FLAGS "-pthread") +target_link_libraries(recordio_test ${CMAKE_BINARY_DIR}/librecordio.a) diff --git a/paddle/go/crecordio/test/test.c b/paddle/go/crecordio/test/test.c new file mode 100644 index 0000000000..bbf5964fd3 --- /dev/null +++ b/paddle/go/crecordio/test/test.c @@ -0,0 +1,31 @@ +#include +#include + +#include "librecordio.h" + +void panic() { + // TODO(helin): fix: gtest using cmake is not working, using this + // hacky way for now. + *(void*)0; +} + +int main() { + writer w = paddle_new_writer("/tmp/test"); + paddle_writer_write(w, "hello", 6); + paddle_writer_write(w, "hi", 3); + paddle_writer_release(w); + + reader r = paddle_new_reader("/tmp/test", 10); + int size; + unsigned char* item = paddle_reader_next_item(r, &size); + if (!strcmp(item, "hello") || size != 6) { + panic(); + } + free(item); + + item = paddle_reader_next_item(r, &size); + if (!strcmp(item, "hi") || size != 2) { + panic(); + } + free(item); +} diff --git a/paddle/go/recordio/README.md b/paddle/go/recordio/README.md index 8b0b9308b1..fbf568ceba 100644 --- a/paddle/go/recordio/README.md +++ b/paddle/go/recordio/README.md @@ -8,6 +8,7 @@ w := recordio.NewWriter(f) w.Write([]byte("Hello")) w.Write([]byte("World!")) w.Close() +f.Close() ``` ## Read @@ -18,6 +19,7 @@ w.Close() f, e := os.Open("a_file.recordio") idx, e := recordio.LoadIndex(f) fmt.Println("Total records: ", idx.Len()) + f.Close() ``` 2. Create one or more scanner to read a range of records. The @@ -30,7 +32,8 @@ w.Close() for s.Scan() { fmt.Println(string(s.Record())) } - if s.Err() != nil && s.Err() != io.EOF { + if s.Error() != nil && s.Error() != io.EOF { log.Fatalf("Something wrong with scanning: %v", e) } + f.Close() ``` From 9e8503b64cf49e3cc0eb531d227e3681597720d2 Mon Sep 17 00:00:00 2001 From: Helin Wang Date: Fri, 26 May 2017 23:54:48 +0000 Subject: [PATCH 2/8] fix comment --- paddle/go/cclient/cclient.go | 2 +- paddle/go/crecordio/crecordio.go | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/paddle/go/cclient/cclient.go b/paddle/go/cclient/cclient.go index 654b6f68a4..ee2d9d24fd 100644 --- a/paddle/go/cclient/cclient.go +++ b/paddle/go/cclient/cclient.go @@ -82,7 +82,7 @@ func cArrayToSlice(p unsafe.Pointer, len int) []byte { // https://github.com/golang/go/wiki/cgo#turning-c-arrays-into-go-slices // // Go garbage collector will not interact with this data, need - // to be freed from C side. + // to be freed properly. return (*[1 << 30]byte)(p)[:len:len] } diff --git a/paddle/go/crecordio/crecordio.go b/paddle/go/crecordio/crecordio.go index 3335d0795f..cfc15d29a6 100644 --- a/paddle/go/crecordio/crecordio.go +++ b/paddle/go/crecordio/crecordio.go @@ -126,7 +126,7 @@ func cArrayToSlice(p unsafe.Pointer, len int) []byte { // https://github.com/golang/go/wiki/cgo#turning-c-arrays-into-go-slices // // Go garbage collector will not interact with this data, need - // to be freed from C side. + // to be freed properly. return (*[1 << 30]byte)(p)[:len:len] } From f074198e2795658030d5aeca1e7038373d050b74 Mon Sep 17 00:00:00 2001 From: Helin Wang Date: Sat, 27 May 2017 00:31:41 +0000 Subject: [PATCH 3/8] clang format --- paddle/go/crecordio/test/test.c | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/paddle/go/crecordio/test/test.c b/paddle/go/crecordio/test/test.c index bbf5964fd3..598b0965d8 100644 --- a/paddle/go/crecordio/test/test.c +++ b/paddle/go/crecordio/test/test.c @@ -1,5 +1,5 @@ -#include #include +#include #include "librecordio.h" @@ -22,7 +22,7 @@ int main() { panic(); } free(item); - + item = paddle_reader_next_item(r, &size); if (!strcmp(item, "hi") || size != 2) { panic(); From cab5076860aa7ecbd8595aeb47e0e2536d401c7c Mon Sep 17 00:00:00 2001 From: Helin Wang Date: Sat, 27 May 2017 00:42:38 +0000 Subject: [PATCH 4/8] do not include paddle/go into cmake yet. --- paddle/CMakeLists.txt | 7 ++++--- 1 file changed, 4 insertions(+), 3 deletions(-) diff --git a/paddle/CMakeLists.txt b/paddle/CMakeLists.txt index cf31b4a342..9898dc083e 100644 --- a/paddle/CMakeLists.txt +++ b/paddle/CMakeLists.txt @@ -9,9 +9,10 @@ add_subdirectory(pserver) add_subdirectory(trainer) add_subdirectory(scripts) -if(CMAKE_Go_COMPILER) - add_subdirectory(go) -endif() +# Do not build go directory until go cmake is working smoothly. +# if(CMAKE_Go_COMPILER) +# add_subdirectory(go) +# endif() find_package(Boost QUIET) From 0e80dadf37e6ba532fcccf17d52ebd6f746ec1e6 Mon Sep 17 00:00:00 2001 From: Helin Wang Date: Sat, 27 May 2017 00:54:10 +0000 Subject: [PATCH 5/8] release reader in c example --- paddle/go/crecordio/test/test.c | 1 + 1 file changed, 1 insertion(+) diff --git a/paddle/go/crecordio/test/test.c b/paddle/go/crecordio/test/test.c index 598b0965d8..5461a0911f 100644 --- a/paddle/go/crecordio/test/test.c +++ b/paddle/go/crecordio/test/test.c @@ -28,4 +28,5 @@ int main() { panic(); } free(item); + paddle_reader_release(r); } From 633171c2d3a1f6b8e245844fa4fb254895565da7 Mon Sep 17 00:00:00 2001 From: Helin Wang Date: Sat, 27 May 2017 14:49:05 +0000 Subject: [PATCH 6/8] fix according to comments --- paddle/go/cclient/test/main.c | 19 ++-- paddle/go/cmake/golang.cmake | 8 +- paddle/go/crecordio/crecordio.go | 169 +++++++---------------------- paddle/go/crecordio/test/test.c | 53 ++++++--- paddle/go/recordio/README.md | 2 +- paddle/go/recordio/multi_reader.go | 140 ++++++++++++++++++++++++ paddle/go/recordio/reader.go | 9 +- 7 files changed, 239 insertions(+), 161 deletions(-) create mode 100644 paddle/go/recordio/multi_reader.go diff --git a/paddle/go/cclient/test/main.c b/paddle/go/cclient/test/main.c index 28e3d03b7a..abfb32e560 100644 --- a/paddle/go/cclient/test/main.c +++ b/paddle/go/cclient/test/main.c @@ -1,11 +1,12 @@ -#include "libclient.h" +#include -//#include "gtest/gtest.h" +#include "libclient.h" -void panic() { +void fail() { // TODO(helin): fix: gtest using cmake is not working, using this // hacky way for now. - *(void*)0; + printf("test failed.\n"); + exit(-1); } int main() { @@ -35,7 +36,7 @@ retry: goto retry; } } else { - panic(); + fail(); } char content[] = {0x00, 0x11, 0x22}; @@ -44,25 +45,25 @@ retry: {"param_b", PADDLE_ELEMENT_TYPE_FLOAT32, content, 3}}; if (!paddle_send_grads(c, grads, 2)) { - panic(); + fail(); } paddle_parameter* params[2] = {NULL, NULL}; char* names[] = {"param_a", "param_b"}; if (!paddle_get_params(c, names, params, 2)) { - panic(); + fail(); } // get parameters again by reusing the allocated parameter buffers. if (!paddle_get_params(c, names, params, 2)) { - panic(); + fail(); } paddle_release_param(params[0]); paddle_release_param(params[1]); if (!paddle_save_model(c, "/tmp/")) { - panic(); + fail(); } return 0; diff --git a/paddle/go/cmake/golang.cmake b/paddle/go/cmake/golang.cmake index caddaae1bf..0ac17a967b 100644 --- a/paddle/go/cmake/golang.cmake +++ b/paddle/go/cmake/golang.cmake @@ -31,11 +31,9 @@ function(ADD_GO_LIBRARY NAME BUILD_TYPE) # make a symlink that references Paddle inside $GOPATH, so go get # will use the local changes in Paddle rather than checkout Paddle # in github. - if(NOT EXISTS ${PADDLE_IN_GOPATH}) - add_custom_target(copyPaddle - COMMAND ln -s ${PADDLE_DIR} ${PADDLE_IN_GOPATH}) - add_dependencies(goGet copyPaddle) - endif() + add_custom_target(copyPaddle + COMMAND ln -sf ${PADDLE_DIR} ${PADDLE_IN_GOPATH}) + add_dependencies(goGet copyPaddle) add_custom_command(OUTPUT ${OUTPUT_DIR}/.timestamp COMMAND env GOPATH=${GOPATH} ${CMAKE_Go_COMPILER} build ${BUILD_MODE} diff --git a/paddle/go/crecordio/crecordio.go b/paddle/go/crecordio/crecordio.go index cfc15d29a6..e96bb49017 100644 --- a/paddle/go/crecordio/crecordio.go +++ b/paddle/go/crecordio/crecordio.go @@ -9,10 +9,8 @@ typedef int writer; import "C" import ( - "io" "log" "os" - "path/filepath" "strings" "unsafe" @@ -27,84 +25,24 @@ type writer struct { } type reader struct { - buffer chan []byte - cancel chan struct{} + scanner *recordio.MultiScanner } -func read(paths []string, buffer chan<- []byte, cancel chan struct{}) { - var curFile *os.File - var curScanner *recordio.Scanner - var pathIdx int - - var nextFile func() bool - nextFile = func() bool { - if pathIdx >= len(paths) { - return false - } - - path := paths[pathIdx] - pathIdx++ - f, err := os.Open(path) - if err != nil { - return nextFile() - } - - idx, err := recordio.LoadIndex(f) - if err != nil { - log.Println(err) - err = f.Close() - if err != nil { - log.Println(err) - } - - return nextFile() - } - - curFile = f - curScanner = recordio.NewScanner(f, idx, 0, -1) - return true - } - - more := nextFile() - if !more { - close(buffer) - return - } - - closeFile := func() { - err := curFile.Close() - if err != nil { - log.Println(err) - } - curFile = nil +func cArrayToSlice(p unsafe.Pointer, len int) []byte { + if p == nullPtr { + return nil } - for { - for curScanner.Scan() { - select { - case buffer <- curScanner.Record(): - case <-cancel: - close(buffer) - closeFile() - return - } - } - - if err := curScanner.Error(); err != nil && err != io.EOF { - log.Println(err) - } - - closeFile() - more := nextFile() - if !more { - close(buffer) - return - } - } + // create a Go clice backed by a C array, reference: + // https://github.com/golang/go/wiki/cgo#turning-c-arrays-into-go-slices + // + // Go garbage collector will not interact with this data, need + // to be freed properly. + return (*[1 << 30]byte)(p)[:len:len] } -//export paddle_new_writer -func paddle_new_writer(path *C.char) C.writer { +//export create_recordio_writer +func create_recordio_writer(path *C.char) C.writer { p := C.GoString(path) f, err := os.Create(p) if err != nil { @@ -117,21 +55,8 @@ func paddle_new_writer(path *C.char) C.writer { return addWriter(writer) } -func cArrayToSlice(p unsafe.Pointer, len int) []byte { - if p == nullPtr { - return nil - } - - // create a Go clice backed by a C array, reference: - // https://github.com/golang/go/wiki/cgo#turning-c-arrays-into-go-slices - // - // Go garbage collector will not interact with this data, need - // to be freed properly. - return (*[1 << 30]byte)(p)[:len:len] -} - -//export paddle_writer_write -func paddle_writer_write(writer C.writer, buf *C.uchar, size C.int) int { +//export write_recordio +func write_recordio(writer C.writer, buf *C.uchar, size C.int) int { w := getWriter(writer) b := cArrayToSlice(unsafe.Pointer(buf), int(size)) _, err := w.w.Write(b) @@ -143,66 +68,50 @@ func paddle_writer_write(writer C.writer, buf *C.uchar, size C.int) int { return 0 } -//export paddle_writer_release -func paddle_writer_release(writer C.writer) { +//export release_recordio +func release_recordio(writer C.writer) { w := removeWriter(writer) w.w.Close() w.f.Close() } -//export paddle_new_reader -func paddle_new_reader(path *C.char, bufferSize C.int) C.reader { +//export create_recordio_reader +func create_recordio_reader(path *C.char) C.reader { p := C.GoString(path) - ss := strings.Split(p, ",") - var paths []string - for _, s := range ss { - match, err := filepath.Glob(s) - if err != nil { - log.Printf("error applying glob to %s: %v\n", s, err) - return -1 - } - - paths = append(paths, match...) - } - - if len(paths) == 0 { - log.Println("no valid path provided.", p) + s, err := recordio.NewMultiScanner(strings.Split(p, ",")) + if err != nil { + log.Println(err) return -1 } - buffer := make(chan []byte, int(bufferSize)) - cancel := make(chan struct{}) - r := &reader{buffer: buffer, cancel: cancel} - go read(paths, buffer, cancel) + r := &reader{scanner: s} return addReader(r) } -//export paddle_reader_next_item -func paddle_reader_next_item(reader C.reader, size *C.int) *C.uchar { +//export read_next_item +func read_next_item(reader C.reader, size *C.int) *C.uchar { r := getReader(reader) - buf, ok := <-r.buffer - if !ok { - // channel closed and empty, reached EOF. - *size = -1 - return (*C.uchar)(nullPtr) - } + if r.scanner.Scan() { + buf := r.scanner.Record() + *size = C.int(len(buf)) + + if len(buf) == 0 { + return (*C.uchar)(nullPtr) + } - if len(buf) == 0 { - // empty item - *size = 0 - return (*C.uchar)(nullPtr) + ptr := C.malloc(C.size_t(len(buf))) + C.memcpy(ptr, unsafe.Pointer(&buf[0]), C.size_t(len(buf))) + return (*C.uchar)(ptr) } - ptr := C.malloc(C.size_t(len(buf))) - C.memcpy(ptr, unsafe.Pointer(&buf[0]), C.size_t(len(buf))) - *size = C.int(len(buf)) - return (*C.uchar)(ptr) + *size = -1 + return (*C.uchar)(nullPtr) } -//export paddle_reader_release -func paddle_reader_release(reader C.reader) { +//export release_recordio_reader +func release_recordio_reader(reader C.reader) { r := removeReader(reader) - close(r.cancel) + r.scanner.Close() } func main() {} // Required but ignored diff --git a/paddle/go/crecordio/test/test.c b/paddle/go/crecordio/test/test.c index 5461a0911f..54c3773ee9 100644 --- a/paddle/go/crecordio/test/test.c +++ b/paddle/go/crecordio/test/test.c @@ -3,30 +3,55 @@ #include "librecordio.h" -void panic() { +void fail() { // TODO(helin): fix: gtest using cmake is not working, using this // hacky way for now. - *(void*)0; + printf("test failed.\n"); + exit(-1); } int main() { - writer w = paddle_new_writer("/tmp/test"); - paddle_writer_write(w, "hello", 6); - paddle_writer_write(w, "hi", 3); - paddle_writer_release(w); + writer w = create_recordio_writer("/tmp/test_recordio_0"); + write_recordio(w, "hello", 6); + write_recordio(w, "hi", 3); + release_recordio(w); - reader r = paddle_new_reader("/tmp/test", 10); + w = create_recordio_writer("/tmp/test_recordio_1"); + write_recordio(w, "dog", 4); + write_recordio(w, "cat", 4); + release_recordio(w); + + reader r = create_recordio_reader("/tmp/test_recordio_*"); int size; - unsigned char* item = paddle_reader_next_item(r, &size); - if (!strcmp(item, "hello") || size != 6) { - panic(); + unsigned char* item = read_next_item(r, &size); + if (strcmp(item, "hello") || size != 6) { + fail(); + } + + free(item); + + item = read_next_item(r, &size); + if (strcmp(item, "hi") || size != 3) { + fail(); } free(item); - item = paddle_reader_next_item(r, &size); - if (!strcmp(item, "hi") || size != 2) { - panic(); + item = read_next_item(r, &size); + if (strcmp(item, "dog") || size != 4) { + fail(); } free(item); - paddle_reader_release(r); + + item = read_next_item(r, &size); + if (strcmp(item, "cat") || size != 4) { + fail(); + } + free(item); + + item = read_next_item(r, &size); + if (item != NULL || size != -1) { + fail(); + } + + release_recordio_reader(r); } diff --git a/paddle/go/recordio/README.md b/paddle/go/recordio/README.md index fbf568ceba..50e7e95476 100644 --- a/paddle/go/recordio/README.md +++ b/paddle/go/recordio/README.md @@ -32,7 +32,7 @@ f.Close() for s.Scan() { fmt.Println(string(s.Record())) } - if s.Error() != nil && s.Error() != io.EOF { + if s.Err() != nil { log.Fatalf("Something wrong with scanning: %v", e) } f.Close() diff --git a/paddle/go/recordio/multi_reader.go b/paddle/go/recordio/multi_reader.go new file mode 100644 index 0000000000..07e2834211 --- /dev/null +++ b/paddle/go/recordio/multi_reader.go @@ -0,0 +1,140 @@ +package recordio + +import ( + "fmt" + "os" + "path/filepath" +) + +// MultiScanner is a scanner for multiple recordio files. +type MultiScanner struct { + paths []string + curFile *os.File + curScanner *Scanner + pathIdx int + end bool + err error +} + +// NewMultiScanner creates a new MultiScanner. +func NewMultiScanner(paths []string) (*MultiScanner, error) { + var ps []string + for _, s := range paths { + match, err := filepath.Glob(s) + if err != nil { + return nil, err + } + + ps = append(ps, match...) + } + + if len(ps) == 0 { + return nil, fmt.Errorf("no valid path provided: %v", paths) + } + + return &MultiScanner{paths: ps}, nil +} + +// Scan moves the cursor forward for one record and loads the chunk +// containing the record if not yet. +func (s *MultiScanner) Scan() bool { + if s.err != nil { + return false + } + + if s.end { + return false + } + + if s.curScanner == nil { + more, err := s.nextFile() + if err != nil { + s.err = err + return false + } + + if !more { + s.end = true + return false + } + } + + curMore := s.curScanner.Scan() + s.err = s.curScanner.Err() + + if s.err != nil { + return curMore + } + + if !curMore { + err := s.curFile.Close() + if err != nil { + s.err = err + return false + } + s.curFile = nil + + more, err := s.nextFile() + if err != nil { + s.err = err + return false + } + + if !more { + s.end = true + return false + } + + return s.Scan() + } + return true +} + +// Err returns the first non-EOF error that was encountered by the +// Scanner. +func (s *MultiScanner) Err() error { + return s.err +} + +// Record returns the record under the current cursor. +func (s *MultiScanner) Record() []byte { + if s.curScanner == nil { + return nil + } + + return s.curScanner.Record() +} + +// Close release the resources. +func (s *MultiScanner) Close() error { + s.curScanner = nil + if s.curFile != nil { + err := s.curFile.Close() + s.curFile = nil + return err + } + return nil +} + +func (s *MultiScanner) nextFile() (bool, error) { + if s.pathIdx >= len(s.paths) { + return false, nil + } + + path := s.paths[s.pathIdx] + s.pathIdx++ + f, err := os.Open(path) + if err != nil { + return false, err + } + + idx, err := LoadIndex(f) + if err != nil { + f.Close() + return false, err + } + + s.curFile = f + s.curScanner = NewScanner(f, idx, 0, -1) + return true, nil +} diff --git a/paddle/go/recordio/reader.go b/paddle/go/recordio/reader.go index a12c604f7b..d00aef7ca9 100644 --- a/paddle/go/recordio/reader.go +++ b/paddle/go/recordio/reader.go @@ -129,7 +129,12 @@ func (s *Scanner) Record() []byte { return s.chunk.records[ri] } -// Error returns the error that stopped Scan. -func (s *Scanner) Error() error { +// Err returns the first non-EOF error that was encountered by the +// Scanner. +func (s *Scanner) Err() error { + if s.err == io.EOF { + return nil + } + return s.err } From ec5db3801c20c4ac2f93b404eadfc9c2ed48c079 Mon Sep 17 00:00:00 2001 From: Helin Wang Date: Sat, 27 May 2017 14:51:40 +0000 Subject: [PATCH 7/8] fix according to comments --- paddle/go/cclient/CMakeLists.txt | 2 +- paddle/go/cmake/golang.cmake | 4 ++-- paddle/go/crecordio/CMakeLists.txt | 2 +- 3 files changed, 4 insertions(+), 4 deletions(-) diff --git a/paddle/go/cclient/CMakeLists.txt b/paddle/go/cclient/CMakeLists.txt index e3e9fa9f1a..dfd104fb58 100644 --- a/paddle/go/cclient/CMakeLists.txt +++ b/paddle/go/cclient/CMakeLists.txt @@ -8,5 +8,5 @@ project(cxx_go C Go) include(golang) include(flags) -add_go_library(client STATIC) +go_library(client STATIC) add_subdirectory(test) diff --git a/paddle/go/cmake/golang.cmake b/paddle/go/cmake/golang.cmake index 0ac17a967b..e73b0c865b 100644 --- a/paddle/go/cmake/golang.cmake +++ b/paddle/go/cmake/golang.cmake @@ -3,7 +3,7 @@ file(MAKE_DIRECTORY ${GOPATH}) set(PADDLE_IN_GOPATH "${GOPATH}/src/github.com/PaddlePaddle") file(MAKE_DIRECTORY ${PADDLE_IN_GOPATH}) -function(ADD_GO_LIBRARY NAME BUILD_TYPE) +function(GO_LIBRARY NAME BUILD_TYPE) if(BUILD_TYPE STREQUAL "STATIC") set(BUILD_MODE -buildmode=c-archive) set(LIB_NAME "lib${NAME}.a") @@ -47,4 +47,4 @@ function(ADD_GO_LIBRARY NAME BUILD_TYPE) if(NOT BUILD_TYPE STREQUAL "STATIC") install(PROGRAMS ${CMAKE_CURRENT_BINARY_DIR}/${LIB_NAME} DESTINATION bin) endif() -endfunction(ADD_GO_LIBRARY) +endfunction(GO_LIBRARY) diff --git a/paddle/go/crecordio/CMakeLists.txt b/paddle/go/crecordio/CMakeLists.txt index db8f556e50..c395fe0b4a 100644 --- a/paddle/go/crecordio/CMakeLists.txt +++ b/paddle/go/crecordio/CMakeLists.txt @@ -8,5 +8,5 @@ project(cxx_go C Go) include(golang) include(flags) -add_go_library(recordio STATIC) +go_library(recordio STATIC) add_subdirectory(test) From 2fa274cf3672df25a0926b100cf2585596d24fe0 Mon Sep 17 00:00:00 2001 From: Helin Wang Date: Tue, 30 May 2017 20:34:27 +0000 Subject: [PATCH 8/8] fix according to comments --- paddle/go/crecordio/crecordio.go | 35 +++++++++---------- paddle/go/crecordio/test/test.c | 27 +++++++------- .../recordio/{reader.go => range_scanner.go} | 16 ++++----- paddle/go/recordio/recordio_internal_test.go | 2 +- paddle/go/recordio/recordio_test.go | 4 +-- .../recordio/{multi_reader.go => scanner.go} | 24 ++++++------- 6 files changed, 53 insertions(+), 55 deletions(-) rename paddle/go/recordio/{reader.go => range_scanner.go} (88%) rename paddle/go/recordio/{multi_reader.go => scanner.go} (77%) diff --git a/paddle/go/crecordio/crecordio.go b/paddle/go/crecordio/crecordio.go index e96bb49017..33f97de8cf 100644 --- a/paddle/go/crecordio/crecordio.go +++ b/paddle/go/crecordio/crecordio.go @@ -25,7 +25,7 @@ type writer struct { } type reader struct { - scanner *recordio.MultiScanner + scanner *recordio.Scanner } func cArrayToSlice(p unsafe.Pointer, len int) []byte { @@ -55,21 +55,21 @@ func create_recordio_writer(path *C.char) C.writer { return addWriter(writer) } -//export write_recordio -func write_recordio(writer C.writer, buf *C.uchar, size C.int) int { +//export recordio_write +func recordio_write(writer C.writer, buf *C.uchar, size C.int) C.int { w := getWriter(writer) b := cArrayToSlice(unsafe.Pointer(buf), int(size)) - _, err := w.w.Write(b) + c, err := w.w.Write(b) if err != nil { log.Println(err) return -1 } - return 0 + return C.int(c) } -//export release_recordio -func release_recordio(writer C.writer) { +//export release_recordio_writer +func release_recordio_writer(writer C.writer) { w := removeWriter(writer) w.w.Close() w.f.Close() @@ -78,7 +78,7 @@ func release_recordio(writer C.writer) { //export create_recordio_reader func create_recordio_reader(path *C.char) C.reader { p := C.GoString(path) - s, err := recordio.NewMultiScanner(strings.Split(p, ",")) + s, err := recordio.NewScanner(strings.Split(p, ",")...) if err != nil { log.Println(err) return -1 @@ -88,24 +88,23 @@ func create_recordio_reader(path *C.char) C.reader { return addReader(r) } -//export read_next_item -func read_next_item(reader C.reader, size *C.int) *C.uchar { +//export recordio_read +func recordio_read(reader C.reader, record **C.uchar) C.int { r := getReader(reader) if r.scanner.Scan() { buf := r.scanner.Record() - *size = C.int(len(buf)) - if len(buf) == 0 { - return (*C.uchar)(nullPtr) + *record = (*C.uchar)(nullPtr) + return 0 } - ptr := C.malloc(C.size_t(len(buf))) - C.memcpy(ptr, unsafe.Pointer(&buf[0]), C.size_t(len(buf))) - return (*C.uchar)(ptr) + size := C.int(len(buf)) + *record = (*C.uchar)(C.malloc(C.size_t(len(buf)))) + C.memcpy(unsafe.Pointer(*record), unsafe.Pointer(&buf[0]), C.size_t(len(buf))) + return size } - *size = -1 - return (*C.uchar)(nullPtr) + return -1 } //export release_recordio_reader diff --git a/paddle/go/crecordio/test/test.c b/paddle/go/crecordio/test/test.c index 54c3773ee9..b25536a9d7 100644 --- a/paddle/go/crecordio/test/test.c +++ b/paddle/go/crecordio/test/test.c @@ -12,44 +12,43 @@ void fail() { int main() { writer w = create_recordio_writer("/tmp/test_recordio_0"); - write_recordio(w, "hello", 6); - write_recordio(w, "hi", 3); - release_recordio(w); + recordio_write(w, "hello", 6); + recordio_write(w, "hi", 3); + release_recordio_writer(w); w = create_recordio_writer("/tmp/test_recordio_1"); - write_recordio(w, "dog", 4); - write_recordio(w, "cat", 4); - release_recordio(w); + recordio_write(w, "dog", 4); + recordio_write(w, "cat", 4); + release_recordio_writer(w); reader r = create_recordio_reader("/tmp/test_recordio_*"); - int size; - unsigned char* item = read_next_item(r, &size); + unsigned char* item = NULL; + int size = recordio_read(r, &item); if (strcmp(item, "hello") || size != 6) { fail(); } - free(item); - item = read_next_item(r, &size); + size = recordio_read(r, &item); if (strcmp(item, "hi") || size != 3) { fail(); } free(item); - item = read_next_item(r, &size); + size = recordio_read(r, &item); if (strcmp(item, "dog") || size != 4) { fail(); } free(item); - item = read_next_item(r, &size); + size = recordio_read(r, &item); if (strcmp(item, "cat") || size != 4) { fail(); } free(item); - item = read_next_item(r, &size); - if (item != NULL || size != -1) { + size = recordio_read(r, &item); + if (size != -1) { fail(); } diff --git a/paddle/go/recordio/reader.go b/paddle/go/recordio/range_scanner.go similarity index 88% rename from paddle/go/recordio/reader.go rename to paddle/go/recordio/range_scanner.go index d00aef7ca9..46e2eee68c 100644 --- a/paddle/go/recordio/reader.go +++ b/paddle/go/recordio/range_scanner.go @@ -74,8 +74,8 @@ func (r *Index) Locate(recordIndex int) (int, int) { return -1, -1 } -// Scanner scans records in a specified range within [0, numRecords). -type Scanner struct { +// RangeScanner scans records in a specified range within [0, numRecords). +type RangeScanner struct { reader io.ReadSeeker index *Index start, end, cur int @@ -84,10 +84,10 @@ type Scanner struct { err error } -// NewScanner creates a scanner that sequencially reads records in the +// NewRangeScanner creates a scanner that sequencially reads records in the // range [start, start+len). If start < 0, it scans from the // beginning. If len < 0, it scans till the end of file. -func NewScanner(r io.ReadSeeker, index *Index, start, len int) *Scanner { +func NewRangeScanner(r io.ReadSeeker, index *Index, start, len int) *RangeScanner { if start < 0 { start = 0 } @@ -95,7 +95,7 @@ func NewScanner(r io.ReadSeeker, index *Index, start, len int) *Scanner { len = index.NumRecords() - start } - return &Scanner{ + return &RangeScanner{ reader: r, index: index, start: start, @@ -108,7 +108,7 @@ func NewScanner(r io.ReadSeeker, index *Index, start, len int) *Scanner { // Scan moves the cursor forward for one record and loads the chunk // containing the record if not yet. -func (s *Scanner) Scan() bool { +func (s *RangeScanner) Scan() bool { s.cur++ if s.cur >= s.end { @@ -124,14 +124,14 @@ func (s *Scanner) Scan() bool { } // Record returns the record under the current cursor. -func (s *Scanner) Record() []byte { +func (s *RangeScanner) Record() []byte { _, ri := s.index.Locate(s.cur) return s.chunk.records[ri] } // Err returns the first non-EOF error that was encountered by the // Scanner. -func (s *Scanner) Err() error { +func (s *RangeScanner) Err() error { if s.err == io.EOF { return nil } diff --git a/paddle/go/recordio/recordio_internal_test.go b/paddle/go/recordio/recordio_internal_test.go index e0f7dd0407..30e317925d 100644 --- a/paddle/go/recordio/recordio_internal_test.go +++ b/paddle/go/recordio/recordio_internal_test.go @@ -68,7 +68,7 @@ func TestWriteAndRead(t *testing.T) { 2*4)}, // two record legnths idx.chunkOffsets) - s := NewScanner(bytes.NewReader(buf.Bytes()), idx, -1, -1) + s := NewRangeScanner(bytes.NewReader(buf.Bytes()), idx, -1, -1) i := 0 for s.Scan() { assert.Equal(data[i], string(s.Record())) diff --git a/paddle/go/recordio/recordio_test.go b/paddle/go/recordio/recordio_test.go index 8bf1b020ab..ab117d2050 100644 --- a/paddle/go/recordio/recordio_test.go +++ b/paddle/go/recordio/recordio_test.go @@ -29,7 +29,7 @@ func TestWriteRead(t *testing.T) { t.Fatal("num record does not match:", idx.NumRecords(), total) } - s := recordio.NewScanner(bytes.NewReader(buf.Bytes()), idx, -1, -1) + s := recordio.NewRangeScanner(bytes.NewReader(buf.Bytes()), idx, -1, -1) i := 0 for s.Scan() { if !reflect.DeepEqual(s.Record(), make([]byte, i)) { @@ -66,7 +66,7 @@ func TestChunkIndex(t *testing.T) { for i := 0; i < total; i++ { newIdx := idx.ChunkIndex(i) - s := recordio.NewScanner(bytes.NewReader(buf.Bytes()), newIdx, -1, -1) + s := recordio.NewRangeScanner(bytes.NewReader(buf.Bytes()), newIdx, -1, -1) j := 0 for s.Scan() { if !reflect.DeepEqual(s.Record(), make([]byte, i)) { diff --git a/paddle/go/recordio/multi_reader.go b/paddle/go/recordio/scanner.go similarity index 77% rename from paddle/go/recordio/multi_reader.go rename to paddle/go/recordio/scanner.go index 07e2834211..865228ff65 100644 --- a/paddle/go/recordio/multi_reader.go +++ b/paddle/go/recordio/scanner.go @@ -6,18 +6,18 @@ import ( "path/filepath" ) -// MultiScanner is a scanner for multiple recordio files. -type MultiScanner struct { +// Scanner is a scanner for multiple recordio files. +type Scanner struct { paths []string curFile *os.File - curScanner *Scanner + curScanner *RangeScanner pathIdx int end bool err error } -// NewMultiScanner creates a new MultiScanner. -func NewMultiScanner(paths []string) (*MultiScanner, error) { +// NewScanner creates a new Scanner. +func NewScanner(paths ...string) (*Scanner, error) { var ps []string for _, s := range paths { match, err := filepath.Glob(s) @@ -32,12 +32,12 @@ func NewMultiScanner(paths []string) (*MultiScanner, error) { return nil, fmt.Errorf("no valid path provided: %v", paths) } - return &MultiScanner{paths: ps}, nil + return &Scanner{paths: ps}, nil } // Scan moves the cursor forward for one record and loads the chunk // containing the record if not yet. -func (s *MultiScanner) Scan() bool { +func (s *Scanner) Scan() bool { if s.err != nil { return false } @@ -92,12 +92,12 @@ func (s *MultiScanner) Scan() bool { // Err returns the first non-EOF error that was encountered by the // Scanner. -func (s *MultiScanner) Err() error { +func (s *Scanner) Err() error { return s.err } // Record returns the record under the current cursor. -func (s *MultiScanner) Record() []byte { +func (s *Scanner) Record() []byte { if s.curScanner == nil { return nil } @@ -106,7 +106,7 @@ func (s *MultiScanner) Record() []byte { } // Close release the resources. -func (s *MultiScanner) Close() error { +func (s *Scanner) Close() error { s.curScanner = nil if s.curFile != nil { err := s.curFile.Close() @@ -116,7 +116,7 @@ func (s *MultiScanner) Close() error { return nil } -func (s *MultiScanner) nextFile() (bool, error) { +func (s *Scanner) nextFile() (bool, error) { if s.pathIdx >= len(s.paths) { return false, nil } @@ -135,6 +135,6 @@ func (s *MultiScanner) nextFile() (bool, error) { } s.curFile = f - s.curScanner = NewScanner(f, idx, 0, -1) + s.curScanner = NewRangeScanner(f, idx, 0, -1) return true, nil }