Merge pull request #2468 from helinwang/master_dispatch
Implement master client for reading training tasks
gangliao-patch-1
commit
1a12720bb2
@ -0,0 +1,21 @@
|
||||
# Build script for the Go master client, produced as a shared library and
# optionally copied next to the Python bindings.
cmake_minimum_required(VERSION 3.0)

# Make the project's shared CMake modules visible: they are looked up in the
# "cmake" directory two levels above this one.
get_filename_component(PARENT_DIR ${CMAKE_CURRENT_SOURCE_DIR} DIRECTORY)
get_filename_component(PARENT_DIR ${PARENT_DIR} DIRECTORY)
list(APPEND CMAKE_MODULE_PATH "${PARENT_DIR}/cmake")

project(cxx_go C Go)

# NOTE(review): go_library() is presumably defined by the golang module
# included here -- confirm.
include(golang)
include(flags)

set(MASTER_LIB_NAME "paddle_master")
go_library(${MASTER_LIB_NAME} SHARED)

# When PROJ_ROOT is set, publish the built library into the Python package
# directory so that it can be loaded from there.
if(PROJ_ROOT)
  add_custom_command(
    OUTPUT "${PROJ_ROOT}/python/paddle/v2/master/lib${MASTER_LIB_NAME}.so"
    # Drop the header emitted next to the library; "-f" keeps the step from
    # failing when the header is not there.
    COMMAND ${CMAKE_COMMAND} -E remove -f
            "${CMAKE_CURRENT_BINARY_DIR}/lib${MASTER_LIB_NAME}.h"
    # Portable replacement for "cp".
    COMMAND ${CMAKE_COMMAND} -E copy
            "${CMAKE_CURRENT_BINARY_DIR}/lib${MASTER_LIB_NAME}.so"
            "${PROJ_ROOT}/python/paddle/v2/master/lib${MASTER_LIB_NAME}.so"
    DEPENDS ${MASTER_LIB_NAME}
    VERBATIM)
  add_custom_target(paddle_master_shared ALL
    DEPENDS "${PROJ_ROOT}/python/paddle/v2/master/lib${MASTER_LIB_NAME}.so")
endif()
|
@ -0,0 +1,110 @@
|
||||
package main
|
||||
|
||||
/*
|
||||
#include <stdlib.h>
|
||||
#include <string.h>
|
||||
#include <stdio.h>
|
||||
|
||||
#define PADDLE_MASTER_OK 0
|
||||
#define PADDLE_MASTER_ERROR -1
|
||||
|
||||
typedef int paddle_master_client;
|
||||
*/
|
||||
import "C"
|
||||
|
||||
import (
|
||||
"sync"
|
||||
"unsafe"
|
||||
|
||||
"github.com/PaddlePaddle/Paddle/go/master"
|
||||
log "github.com/sirupsen/logrus"
|
||||
)
|
||||
|
||||
var nullPtr = unsafe.Pointer(uintptr(0))
|
||||
var mu sync.Mutex
|
||||
var handleMap = make(map[C.paddle_master_client]*master.Client)
|
||||
var curHandle C.paddle_master_client
|
||||
|
||||
func add(c *master.Client) C.paddle_master_client {
|
||||
mu.Lock()
|
||||
defer mu.Unlock()
|
||||
client := curHandle
|
||||
curHandle++
|
||||
handleMap[client] = c
|
||||
return client
|
||||
}
|
||||
|
||||
func get(client C.paddle_master_client) *master.Client {
|
||||
mu.Lock()
|
||||
defer mu.Unlock()
|
||||
return handleMap[client]
|
||||
}
|
||||
|
||||
func remove(client C.paddle_master_client) *master.Client {
|
||||
mu.Lock()
|
||||
defer mu.Unlock()
|
||||
h := handleMap[client]
|
||||
delete(handleMap, client)
|
||||
return h
|
||||
}
|
||||
|
||||
// addresser is a plain string holding an address; it is what
// paddle_new_master_client hands to master.NewClient.
type addresser string

// Address returns the wrapped address string unchanged.
func (addr addresser) Address() string {
	return string(addr)
}
|
||||
|
||||
//export paddle_new_master_client
|
||||
func paddle_new_master_client(addr *C.char, bufSize int) C.paddle_master_client {
|
||||
a := C.GoString(addr)
|
||||
c := master.NewClient(addresser(a), bufSize)
|
||||
return add(c)
|
||||
}
|
||||
|
||||
//export paddle_release_master_client
|
||||
func paddle_release_master_client(client C.paddle_master_client) {
|
||||
remove(client)
|
||||
}
|
||||
|
||||
//export paddle_set_dataset
|
||||
func paddle_set_dataset(client C.paddle_master_client, path **C.char, size C.int) C.int {
|
||||
c := get(client)
|
||||
var paths []string
|
||||
for i := 0; i < int(size); i++ {
|
||||
ptr := (**C.char)(unsafe.Pointer(uintptr(unsafe.Pointer(path)) + uintptr(i)*unsafe.Sizeof(*path)))
|
||||
str := C.GoString(*ptr)
|
||||
paths = append(paths, str)
|
||||
}
|
||||
err := c.SetDataset(paths)
|
||||
if err != nil {
|
||||
log.Errorln(err)
|
||||
return C.PADDLE_MASTER_ERROR
|
||||
}
|
||||
|
||||
return C.PADDLE_MASTER_OK
|
||||
}
|
||||
|
||||
//export paddle_next_record
|
||||
func paddle_next_record(client C.paddle_master_client, record **C.uchar) C.int {
|
||||
c := get(client)
|
||||
r := c.NextRecord()
|
||||
if len(r) == 0 {
|
||||
*record = (*C.uchar)(nullPtr)
|
||||
return 0
|
||||
}
|
||||
|
||||
size := C.size_t(len(r))
|
||||
*record = (*C.uchar)(C.malloc(size))
|
||||
C.memcpy(unsafe.Pointer(*record), unsafe.Pointer(&r[0]), size)
|
||||
return C.int(size)
|
||||
}
|
||||
|
||||
//export mem_free
|
||||
func mem_free(p unsafe.Pointer) {
|
||||
// "free" may be a better name for this function, but doing so
|
||||
// will cause calling any function of this library from Python
|
||||
// ctypes hanging.
|
||||
C.free(p)
|
||||
}
|
||||
|
||||
// main is intentionally empty: package main has to define it, while the
// functionality of this file is reached through the exported functions.
// NOTE(review): presumably never invoked when built as a shared library -- confirm.
func main() {}
|
@ -0,0 +1,121 @@
|
||||
package master
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"net"
|
||||
"net/http"
|
||||
"net/rpc"
|
||||
"os"
|
||||
"strconv"
|
||||
"strings"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
log "github.com/sirupsen/logrus"
|
||||
|
||||
"github.com/PaddlePaddle/Paddle/go/connection"
|
||||
"github.com/PaddlePaddle/recordio"
|
||||
)
|
||||
|
||||
// Sizing of the synthetic dataset used by the tests in this file.
const (
	// totalTask is how many tasks a single pass is expected to hand out.
	totalTask = 20

	// chunkPerTask is the chunks-per-task value given to NewService.
	chunkPerTask = 10
)
|
||||
|
||||
func init() {
|
||||
log.SetLevel(log.ErrorLevel)
|
||||
}
|
||||
|
||||
// TestAddresser is a plain string holding an address; the test below passes
// it to the client's monitorMaster.
type TestAddresser string

// Address returns the wrapped address string unchanged.
func (addr TestAddresser) Address() string {
	return string(addr)
}
|
||||
|
||||
func TestGetFinishTask(t *testing.T) {
|
||||
const path = "/tmp/master_client_test_0"
|
||||
|
||||
l, err := net.Listen("tcp", ":0")
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
|
||||
ss := strings.Split(l.Addr().String(), ":")
|
||||
p, err := strconv.Atoi(ss[len(ss)-1])
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
|
||||
go func(l net.Listener) {
|
||||
s := NewService(chunkPerTask, time.Second, 1)
|
||||
server := rpc.NewServer()
|
||||
err := server.Register(s)
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
|
||||
mux := http.NewServeMux()
|
||||
mux.Handle(rpc.DefaultRPCPath, server)
|
||||
err = http.Serve(l, mux)
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
}(l)
|
||||
|
||||
f, err := os.Create(path)
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
|
||||
for i := 0; i < totalTask*chunkPerTask; i++ {
|
||||
w := recordio.NewWriter(f, -1, -1)
|
||||
w.Write(nil)
|
||||
// call Close to force RecordIO writing a chunk.
|
||||
w.Close()
|
||||
}
|
||||
f.Close()
|
||||
|
||||
// Manually intialize client to avoid calling c.getRecords()
|
||||
c := &Client{}
|
||||
c.conn = connection.New()
|
||||
go c.monitorMaster(TestAddresser(fmt.Sprintf(":%d", p)))
|
||||
c.SetDataset([]string{path})
|
||||
|
||||
checkOnePass := func(i int) {
|
||||
var tasks []Task
|
||||
for idx := 0; idx < totalTask; idx++ {
|
||||
task, err := c.getTask()
|
||||
if err != nil {
|
||||
t.Fatalf("Error: %v, pass: %d\n", err, i)
|
||||
}
|
||||
tasks = append(tasks, task)
|
||||
}
|
||||
|
||||
_, err = c.getTask()
|
||||
if err == nil {
|
||||
t.Fatalf("Should get error, pass: %d\n", i)
|
||||
}
|
||||
|
||||
err = c.taskFinished(tasks[0].ID)
|
||||
if err != nil {
|
||||
t.Fatalf("Error: %v, pass: %d\n", err, i)
|
||||
}
|
||||
tasks = tasks[1:]
|
||||
task, err := c.getTask()
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
tasks = append(tasks, task)
|
||||
|
||||
for _, task := range tasks {
|
||||
err = c.taskFinished(task.ID)
|
||||
if err != nil {
|
||||
t.Fatalf("Error: %v, pass: %d\n", err, i)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
for i := 0; i < 10; i++ {
|
||||
checkOnePass(i)
|
||||
}
|
||||
}
|
@ -0,0 +1,3 @@
|
||||
*.whl
|
||||
*.so
|
||||
*.pyc
|
@ -0,0 +1,3 @@
|
||||
# Re-export the names of the sibling "client" module. The explicit relative
# form works on both Python 2 and Python 3, whereas the implicit form
# "from client import *" fails on Python 3.
from .client import *

__all__ = ['client']
|
@ -0,0 +1,39 @@
|
||||
import ctypes
|
||||
import os
|
||||
|
||||
# The shared library is looked up in the directory containing this module
# and is loaded once, when the module is imported.
path = os.path.join(os.path.dirname(__file__), "libpaddle_master.so")
lib = ctypes.CDLL(path)
|
||||
|
||||
|
||||
class client(object):
    """
    client is a client to the master server.

    It is a thin wrapper around the C functions exported by the shared
    library loaded into the module-level ``lib``; ``self.c`` holds the
    integer handle returned by the library, or None once closed.
    """

    def __init__(self, addr, buf_size):
        """
        Create a master client.

        :param addr: address passed to paddle_new_master_client; text is
            encoded to UTF-8 bytes first.
        :param buf_size: buffer size passed through to the library.
        """
        if not isinstance(addr, bytes):
            addr = addr.encode("utf-8")
        self.c = lib.paddle_new_master_client(addr, buf_size)

    def close(self):
        """
        Release the underlying handle. Calling close more than once is a
        no-op.
        """
        if self.c is None:
            # Passing None to C would arrive as 0 and release the
            # unrelated handle 0.
            return
        lib.paddle_release_master_client(self.c)
        self.c = None

    def set_dataset(self, paths):
        """
        Hand the list of dataset paths to the library.

        :param paths: sequence of paths; text is encoded to UTF-8 bytes.
        :return: the status code returned by paddle_set_dataset.
        """
        holder_type = ctypes.c_char_p * len(paths)
        holder = holder_type()
        for idx, path in enumerate(paths):
            if not isinstance(path, bytes):
                path = path.encode("utf-8")
            holder[idx] = path
        return lib.paddle_set_dataset(self.c, holder, len(paths))

    def next_record(self):
        """
        Fetch the next record.

        :return: the record as a byte string, or "" for an empty record.
        :raises RuntimeError: if the library reports a negative size or
            returns no buffer for a non-empty record.
        """
        p = ctypes.c_void_p()
        size = lib.paddle_next_record(self.c, ctypes.pointer(p))
        if size == 0:
            # Empty record
            return ""
        if size < 0 or not p.value:
            raise RuntimeError("paddle_next_record failed (returned %d)" %
                               size)
        try:
            # Read exactly `size` bytes; going through c_char_p.value
            # would stop at the first NUL byte and truncate binary data.
            return ctypes.string_at(p.value, size)
        finally:
            # Memory created from C should be freed.
            lib.mem_free(p)
|
Loading…
Reference in new issue