parent
91f82aba5c
commit
fa5c3f1f73
@ -0,0 +1,81 @@
|
|||||||
|
package main
|
||||||
|
|
||||||
|
/*
|
||||||
|
|
||||||
|
typedef int paddle_master_client;
|
||||||
|
*/
|
||||||
|
import "C"
|
||||||
|
|
||||||
|
import (
|
||||||
|
"log"
|
||||||
|
"sync"
|
||||||
|
"unsafe"
|
||||||
|
|
||||||
|
"github.com/PaddlePaddle/Paddle/go/master"
|
||||||
|
)
|
||||||
|
|
||||||
|
var mu sync.Mutex
|
||||||
|
var handleMap = make(map[C.paddle_master_client]*master.Client)
|
||||||
|
var curHandle C.paddle_master_client
|
||||||
|
|
||||||
|
func add(c *master.Client) C.paddle_master_client {
|
||||||
|
mu.Lock()
|
||||||
|
defer mu.Unlock()
|
||||||
|
client := curHandle
|
||||||
|
curHandle++
|
||||||
|
handleMap[client] = c
|
||||||
|
return client
|
||||||
|
}
|
||||||
|
|
||||||
|
func get(client C.paddle_master_client) *master.Client {
|
||||||
|
mu.Lock()
|
||||||
|
defer mu.Unlock()
|
||||||
|
return handleMap[client]
|
||||||
|
}
|
||||||
|
|
||||||
|
func remove(client C.paddle_master_client) *master.Client {
|
||||||
|
mu.Lock()
|
||||||
|
defer mu.Unlock()
|
||||||
|
h := handleMap[client]
|
||||||
|
delete(handleMap, client)
|
||||||
|
return h
|
||||||
|
}
|
||||||
|
|
||||||
|
type addresser string
|
||||||
|
|
||||||
|
func (a addresser) Address() string {
|
||||||
|
return string(a)
|
||||||
|
}
|
||||||
|
|
||||||
|
//paddle_new_master_client
|
||||||
|
func paddle_new_master_client(addr *C.char, buf_size C.int) C.paddle_master_client {
|
||||||
|
a := C.GoString(addr)
|
||||||
|
c := master.NewClient(addresser(a), int(buf_size))
|
||||||
|
return add(c)
|
||||||
|
}
|
||||||
|
|
||||||
|
//export paddle_new_etcd_master_client
|
||||||
|
func paddle_new_etcd_master_client(etcd_addr *C.char) C.paddle_master_client {
|
||||||
|
// TODO(helin): fault tolerant master client using etcd.
|
||||||
|
panic("not implemented.")
|
||||||
|
}
|
||||||
|
|
||||||
|
//export paddle_set_dataset
|
||||||
|
func paddle_set_dataset(client C.paddle_master_client, path **C.char, size C.int) C.int {
|
||||||
|
c := get(client)
|
||||||
|
var paths []string
|
||||||
|
for i := 0; i < int(size); i++ {
|
||||||
|
ptr := (**C.char)(unsafe.Pointer(uintptr(unsafe.Pointer(path)) + uintptr(size)))
|
||||||
|
str := C.GoString(*ptr)
|
||||||
|
paths = append(paths, str)
|
||||||
|
}
|
||||||
|
err := c.SetDataset(paths)
|
||||||
|
if err != nil {
|
||||||
|
log.Println(err)
|
||||||
|
return -1
|
||||||
|
}
|
||||||
|
|
||||||
|
return 0
|
||||||
|
}
|
||||||
|
|
||||||
|
func main() {}
|
@ -0,0 +1,120 @@
|
|||||||
|
package master
|
||||||
|
|
||||||
|
import (
|
||||||
|
"fmt"
|
||||||
|
"net"
|
||||||
|
"net/http"
|
||||||
|
"net/rpc"
|
||||||
|
"os"
|
||||||
|
"strconv"
|
||||||
|
"strings"
|
||||||
|
"testing"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
log "github.com/sirupsen/logrus"
|
||||||
|
|
||||||
|
"github.com/PaddlePaddle/Paddle/go/connection"
|
||||||
|
"github.com/PaddlePaddle/recordio"
|
||||||
|
)
|
||||||
|
|
||||||
|
const (
|
||||||
|
totalTask = 20
|
||||||
|
chunkPerTask = 10
|
||||||
|
)
|
||||||
|
|
||||||
|
func init() {
|
||||||
|
log.SetLevel(log.ErrorLevel)
|
||||||
|
}
|
||||||
|
|
||||||
|
type TestAddresser string
|
||||||
|
|
||||||
|
func (a TestAddresser) Address() string {
|
||||||
|
return string(a)
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestGetFinishTask(t *testing.T) {
|
||||||
|
const path = "/tmp/master_client_test_0"
|
||||||
|
|
||||||
|
l, err := net.Listen("tcp", ":0")
|
||||||
|
if err != nil {
|
||||||
|
panic(err)
|
||||||
|
}
|
||||||
|
|
||||||
|
ss := strings.Split(l.Addr().String(), ":")
|
||||||
|
p, err := strconv.Atoi(ss[len(ss)-1])
|
||||||
|
if err != nil {
|
||||||
|
panic(err)
|
||||||
|
}
|
||||||
|
|
||||||
|
go func(l net.Listener) {
|
||||||
|
s := NewService(chunkPerTask, time.Second, 1)
|
||||||
|
server := rpc.NewServer()
|
||||||
|
err := server.Register(s)
|
||||||
|
if err != nil {
|
||||||
|
panic(err)
|
||||||
|
}
|
||||||
|
|
||||||
|
mux := http.NewServeMux()
|
||||||
|
mux.Handle(rpc.DefaultRPCPath, server)
|
||||||
|
err = http.Serve(l, mux)
|
||||||
|
if err != nil {
|
||||||
|
panic(err)
|
||||||
|
}
|
||||||
|
}(l)
|
||||||
|
|
||||||
|
f, err := os.Create(path)
|
||||||
|
if err != nil {
|
||||||
|
panic(err)
|
||||||
|
}
|
||||||
|
|
||||||
|
for i := 0; i < totalTask*chunkPerTask; i++ {
|
||||||
|
w := recordio.NewWriter(f, -1, -1)
|
||||||
|
w.Write(nil)
|
||||||
|
// call Close to force RecordIO writing a chunk.
|
||||||
|
w.Close()
|
||||||
|
}
|
||||||
|
f.Close()
|
||||||
|
|
||||||
|
c := &Client{}
|
||||||
|
c.conn = connection.New()
|
||||||
|
go c.monitorMaster(TestAddresser(fmt.Sprintf(":%d", p)))
|
||||||
|
c.SetDataset([]string{path})
|
||||||
|
|
||||||
|
checkOnePass := func(i int) {
|
||||||
|
var tasks []Task
|
||||||
|
for idx := 0; idx < totalTask; idx++ {
|
||||||
|
task, err := c.getTask()
|
||||||
|
if err != nil {
|
||||||
|
t.Fatal(err, " pass:", i)
|
||||||
|
}
|
||||||
|
tasks = append(tasks, task)
|
||||||
|
}
|
||||||
|
|
||||||
|
_, err = c.getTask()
|
||||||
|
if err == nil {
|
||||||
|
t.Fatal("Should get error. Pass:", i)
|
||||||
|
}
|
||||||
|
|
||||||
|
err = c.taskFinished(tasks[0].ID)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatal(err, "pass:", i)
|
||||||
|
}
|
||||||
|
tasks = tasks[1:]
|
||||||
|
task, err := c.getTask()
|
||||||
|
if err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
tasks = append(tasks, task)
|
||||||
|
|
||||||
|
for _, task := range tasks {
|
||||||
|
err = c.taskFinished(task.ID)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatal(err, " pass:", i)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
for i := 0; i < 10; i++ {
|
||||||
|
checkOnePass(i)
|
||||||
|
}
|
||||||
|
}
|
Loading…
Reference in new issue