commit
						6597ccb01f
					
				@ -1,3 +0,0 @@
 | 
				
			||||
vendor/
 | 
				
			||||
.glide/
 | 
				
			||||
proto/*.go
 | 
				
			||||
@ -1,23 +0,0 @@
 | 
				
			||||
# Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved.
 | 
				
			||||
#
 | 
				
			||||
# Licensed under the Apache License, Version 2.0 (the "License");
 | 
				
			||||
# you may not use this file except in compliance with the License.
 | 
				
			||||
# You may obtain a copy of the License at
 | 
				
			||||
#
 | 
				
			||||
# http://www.apache.org/licenses/LICENSE-2.0
 | 
				
			||||
#
 | 
				
			||||
# Unless required by applicable law or agreed to in writing, software
 | 
				
			||||
# distributed under the License is distributed on an "AS IS" BASIS,
 | 
				
			||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 | 
				
			||||
# See the License for the specific language governing permissions and
 | 
				
			||||
# limitations under the License.
 | 
				
			||||
#
 | 
				
			||||
 | 
				
			||||
add_subdirectory(pserver/client/c)
 | 
				
			||||
add_subdirectory(cmd/pserver)
 | 
				
			||||
add_subdirectory(cmd/master)
 | 
				
			||||
add_subdirectory(master/c)
 | 
				
			||||
add_subdirectory(master)
 | 
				
			||||
add_subdirectory(pserver)
 | 
				
			||||
add_subdirectory(pserver/client)
 | 
				
			||||
add_subdirectory(utils/networkhelper)
 | 
				
			||||
@ -1,15 +0,0 @@
 | 
				
			||||
# Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved.
 | 
				
			||||
#
 | 
				
			||||
# Licensed under the Apache License, Version 2.0 (the "License");
 | 
				
			||||
# you may not use this file except in compliance with the License.
 | 
				
			||||
# You may obtain a copy of the License at
 | 
				
			||||
#
 | 
				
			||||
# http://www.apache.org/licenses/LICENSE-2.0
 | 
				
			||||
#
 | 
				
			||||
# Unless required by applicable law or agreed to in writing, software
 | 
				
			||||
# distributed under the License is distributed on an "AS IS" BASIS,
 | 
				
			||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 | 
				
			||||
# See the License for the specific language governing permissions and
 | 
				
			||||
# limitations under the License.
 | 
				
			||||
 | 
				
			||||
go_binary(master SRC master.go)
 | 
				
			||||
@ -1,120 +0,0 @@
 | 
				
			||||
// Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved.
 | 
				
			||||
 | 
				
			||||
// Licensed under the Apache License, Version 2.0 (the "License");
 | 
				
			||||
// you may not use this file except in compliance with the License.
 | 
				
			||||
// You may obtain a copy of the License at
 | 
				
			||||
 | 
				
			||||
// http://www.apache.org/licenses/LICENSE-2.0
 | 
				
			||||
 | 
				
			||||
// Unless required by applicable law or agreed to in writing, software
 | 
				
			||||
// distributed under the License is distributed on an "AS IS" BASIS,
 | 
				
			||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 | 
				
			||||
// See the License for the specific language governing permissions and
 | 
				
			||||
// limitations under the License.
 | 
				
			||||
 | 
				
			||||
package main
 | 
				
			||||
 | 
				
			||||
import (
 | 
				
			||||
	"fmt"
 | 
				
			||||
	"net"
 | 
				
			||||
	"net/http"
 | 
				
			||||
	"net/rpc"
 | 
				
			||||
	"os"
 | 
				
			||||
	"os/signal"
 | 
				
			||||
	"strconv"
 | 
				
			||||
	"strings"
 | 
				
			||||
	"time"
 | 
				
			||||
 | 
				
			||||
	log "github.com/inconshreveable/log15"
 | 
				
			||||
	"github.com/namsral/flag"
 | 
				
			||||
 | 
				
			||||
	"github.com/PaddlePaddle/Paddle/go/master"
 | 
				
			||||
	"github.com/PaddlePaddle/Paddle/go/utils/networkhelper"
 | 
				
			||||
)
 | 
				
			||||
 | 
				
			||||
func main() {
 | 
				
			||||
	port := flag.Int("port", 8080, "port of the master server.")
 | 
				
			||||
	ttlSec := flag.Int("ttl", 60, "etcd lease TTL in seconds.")
 | 
				
			||||
	endpoints := flag.String("endpoints", "http://127.0.0.1:2379", "comma separated etcd endpoints. If empty, fault tolerance will not be enabled.")
 | 
				
			||||
	taskTimeoutDur := flag.Duration("task-timout-dur", 20*time.Minute, "task timout duration.")
 | 
				
			||||
	taskTimeoutMax := flag.Int("task-timeout-max", 3, "max timtout count for each task before it being declared failed task.")
 | 
				
			||||
	chunkPerTask := flag.Int("chunk-per-task", 10, "chunk per task.")
 | 
				
			||||
	logLevel := flag.String("log-level", "info",
 | 
				
			||||
		"log level, possible values: debug, info, warn, error, crit")
 | 
				
			||||
	flag.Parse()
 | 
				
			||||
 | 
				
			||||
	lvl, err := log.LvlFromString(*logLevel)
 | 
				
			||||
	if err != nil {
 | 
				
			||||
		panic(err)
 | 
				
			||||
	}
 | 
				
			||||
 | 
				
			||||
	log.Root().SetHandler(
 | 
				
			||||
		log.LvlFilterHandler(lvl, log.CallerStackHandler("%+v", log.StderrHandler)),
 | 
				
			||||
	)
 | 
				
			||||
 | 
				
			||||
	if *endpoints == "" {
 | 
				
			||||
		log.Warn("-endpoints not set, fault tolerance not be enabled.")
 | 
				
			||||
	}
 | 
				
			||||
 | 
				
			||||
	var store master.Store
 | 
				
			||||
	if *endpoints != "" {
 | 
				
			||||
		eps := strings.Split(*endpoints, ",")
 | 
				
			||||
		ip, err := networkhelper.GetExternalIP()
 | 
				
			||||
		if err != nil {
 | 
				
			||||
			log.Crit("get external ip error", log.Ctx{"error": err})
 | 
				
			||||
			panic(err)
 | 
				
			||||
		}
 | 
				
			||||
 | 
				
			||||
		addr := fmt.Sprintf("%s:%d", ip, *port)
 | 
				
			||||
		store, err = master.NewEtcdClient(eps, addr, master.DefaultLockPath, master.DefaultAddrPath, master.DefaultStatePath, *ttlSec)
 | 
				
			||||
		if err != nil {
 | 
				
			||||
			log.Crit("error creating etcd client.", log.Ctx{"error": err})
 | 
				
			||||
			panic(err)
 | 
				
			||||
		}
 | 
				
			||||
	} else {
 | 
				
			||||
		store = &master.InMemStore{}
 | 
				
			||||
	}
 | 
				
			||||
 | 
				
			||||
	shutdown := func() {
 | 
				
			||||
		log.Info("shutting down gracefully")
 | 
				
			||||
		err := store.Shutdown()
 | 
				
			||||
		if err != nil {
 | 
				
			||||
			log.Error("shutdown error", log.Ctx{"error": err})
 | 
				
			||||
		}
 | 
				
			||||
	}
 | 
				
			||||
 | 
				
			||||
	// Guaranteed to run even panic happens.
 | 
				
			||||
	defer shutdown()
 | 
				
			||||
 | 
				
			||||
	c := make(chan os.Signal, 1)
 | 
				
			||||
	signal.Notify(c, os.Interrupt)
 | 
				
			||||
 | 
				
			||||
	s, err := master.NewService(store, *chunkPerTask, *taskTimeoutDur, *taskTimeoutMax)
 | 
				
			||||
	if err != nil {
 | 
				
			||||
		log.Crit("error creating new service.", log.Ctx{"error": err})
 | 
				
			||||
		panic(err)
 | 
				
			||||
	}
 | 
				
			||||
 | 
				
			||||
	err = rpc.Register(s)
 | 
				
			||||
	if err != nil {
 | 
				
			||||
		log.Crit("error registering to etcd.", log.Ctx{"error": err})
 | 
				
			||||
		panic(err)
 | 
				
			||||
	}
 | 
				
			||||
 | 
				
			||||
	rpc.HandleHTTP()
 | 
				
			||||
	l, err := net.Listen("tcp", ":"+strconv.Itoa(*port))
 | 
				
			||||
	if err != nil {
 | 
				
			||||
		log.Crit("error listing to port", log.Ctx{"error": err, "port": *port})
 | 
				
			||||
		panic(err)
 | 
				
			||||
	}
 | 
				
			||||
 | 
				
			||||
	go func() {
 | 
				
			||||
		err = http.Serve(l, nil)
 | 
				
			||||
		if err != nil {
 | 
				
			||||
			log.Crit("error serving HTTP", log.Ctx{"error": err})
 | 
				
			||||
			panic(err)
 | 
				
			||||
		}
 | 
				
			||||
	}()
 | 
				
			||||
 | 
				
			||||
	<-c
 | 
				
			||||
}
 | 
				
			||||
@ -1 +0,0 @@
 | 
				
			||||
pserver
 | 
				
			||||
@ -1,15 +0,0 @@
 | 
				
			||||
# Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved.
 | 
				
			||||
#
 | 
				
			||||
# Licensed under the Apache License, Version 2.0 (the "License");
 | 
				
			||||
# you may not use this file except in compliance with the License.
 | 
				
			||||
# You may obtain a copy of the License at
 | 
				
			||||
#
 | 
				
			||||
# http://www.apache.org/licenses/LICENSE-2.0
 | 
				
			||||
#
 | 
				
			||||
# Unless required by applicable law or agreed to in writing, software
 | 
				
			||||
# distributed under the License is distributed on an "AS IS" BASIS,
 | 
				
			||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 | 
				
			||||
# See the License for the specific language governing permissions and
 | 
				
			||||
# limitations under the License.
 | 
				
			||||
 | 
				
			||||
go_binary(pserver SRCS pserver.go DEPS paddle_go_optimizer)
 | 
				
			||||
@ -1,108 +0,0 @@
 | 
				
			||||
// Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved.
 | 
				
			||||
 | 
				
			||||
// Licensed under the Apache License, Version 2.0 (the "License");
 | 
				
			||||
// you may not use this file except in compliance with the License.
 | 
				
			||||
// You may obtain a copy of the License at
 | 
				
			||||
 | 
				
			||||
// http://www.apache.org/licenses/LICENSE-2.0
 | 
				
			||||
 | 
				
			||||
// Unless required by applicable law or agreed to in writing, software
 | 
				
			||||
// distributed under the License is distributed on an "AS IS" BASIS,
 | 
				
			||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 | 
				
			||||
// See the License for the specific language governing permissions and
 | 
				
			||||
// limitations under the License.
 | 
				
			||||
 | 
				
			||||
package main
 | 
				
			||||
 | 
				
			||||
import (
 | 
				
			||||
	"net"
 | 
				
			||||
	"net/http"
 | 
				
			||||
	"net/rpc"
 | 
				
			||||
	"os"
 | 
				
			||||
	"os/signal"
 | 
				
			||||
	"strconv"
 | 
				
			||||
	"time"
 | 
				
			||||
 | 
				
			||||
	"github.com/namsral/flag"
 | 
				
			||||
	"github.com/topicai/candy"
 | 
				
			||||
 | 
				
			||||
	"github.com/PaddlePaddle/Paddle/go/pserver"
 | 
				
			||||
	log "github.com/inconshreveable/log15"
 | 
				
			||||
)
 | 
				
			||||
 | 
				
			||||
func main() {
 | 
				
			||||
	port := flag.Int("port", 8001, "port of the pserver")
 | 
				
			||||
	index := flag.Int("index", -1, "index of the pserver, set to -1 if use etcd for auto pserver index registry")
 | 
				
			||||
	etcdEndpoint := flag.String("etcd-endpoint", "http://127.0.0.1:2379",
 | 
				
			||||
		"comma separated endpoint string for pserver to connect to etcd")
 | 
				
			||||
	dialTimeout := flag.Duration("dial-timeout", 5*time.Second, "dial timeout")
 | 
				
			||||
	etcdTTL := flag.Int("etcd-ttl", 5, "etcd time to live in seconds")
 | 
				
			||||
	numPservers := flag.Int("num-pservers", 1, "total pserver count in a training job")
 | 
				
			||||
	checkpointPath := flag.String("checkpoint-path", "/checkpoints/", "save checkpoint path")
 | 
				
			||||
	checkpointInterval := flag.Duration("checkpoint-interval", 600*time.Second, "save checkpoint per interval seconds")
 | 
				
			||||
	logLevel := flag.String("log-level", "info",
 | 
				
			||||
		"log level, possible values: debug, info, warn, error, crit")
 | 
				
			||||
	flag.Parse()
 | 
				
			||||
 | 
				
			||||
	lvl, err := log.LvlFromString(*logLevel)
 | 
				
			||||
	if err != nil {
 | 
				
			||||
		panic(err)
 | 
				
			||||
	}
 | 
				
			||||
 | 
				
			||||
	log.Root().SetHandler(
 | 
				
			||||
		log.LvlFilterHandler(lvl, log.CallerStackHandler("%+v", log.StderrHandler)),
 | 
				
			||||
	)
 | 
				
			||||
 | 
				
			||||
	var idx int
 | 
				
			||||
 | 
				
			||||
	var cp pserver.Checkpoint
 | 
				
			||||
	var e *pserver.EtcdClient
 | 
				
			||||
	if *index >= 0 {
 | 
				
			||||
		idx = *index
 | 
				
			||||
	} else {
 | 
				
			||||
		e = pserver.NewEtcdClient(*etcdEndpoint, *numPservers, *dialTimeout, *etcdTTL)
 | 
				
			||||
		idx, err = e.Register(*port)
 | 
				
			||||
		candy.Must(err)
 | 
				
			||||
 | 
				
			||||
		cp, err = pserver.LoadCheckpoint(e, idx)
 | 
				
			||||
		if err != nil {
 | 
				
			||||
			if err == pserver.ErrCheckpointNotFound {
 | 
				
			||||
				log.Info("load checkpoint error", "error", err)
 | 
				
			||||
			} else {
 | 
				
			||||
				panic(err)
 | 
				
			||||
			}
 | 
				
			||||
		}
 | 
				
			||||
	}
 | 
				
			||||
 | 
				
			||||
	shutdown := func() {
 | 
				
			||||
		log.Info("shutting down gracefully")
 | 
				
			||||
		sErr := e.Shutdown()
 | 
				
			||||
		if sErr != nil {
 | 
				
			||||
			log.Error("error shutting down", log.Ctx{"error": sErr})
 | 
				
			||||
		}
 | 
				
			||||
	}
 | 
				
			||||
 | 
				
			||||
	// Guaranteed to run even panic happens.
 | 
				
			||||
	defer shutdown()
 | 
				
			||||
 | 
				
			||||
	c := make(chan os.Signal, 1)
 | 
				
			||||
	signal.Notify(c, os.Interrupt)
 | 
				
			||||
 | 
				
			||||
	s, err := pserver.NewService(idx, *checkpointInterval, *checkpointPath, e, cp)
 | 
				
			||||
	candy.Must(err)
 | 
				
			||||
 | 
				
			||||
	err = rpc.Register(s)
 | 
				
			||||
	candy.Must(err)
 | 
				
			||||
 | 
				
			||||
	rpc.HandleHTTP()
 | 
				
			||||
	l, err := net.Listen("tcp", ":"+strconv.Itoa(*port))
 | 
				
			||||
	candy.Must(err)
 | 
				
			||||
 | 
				
			||||
	go func() {
 | 
				
			||||
		log.Info("serving pserver", log.Ctx{"port": *port})
 | 
				
			||||
		err = http.Serve(l, nil)
 | 
				
			||||
		candy.Must(err)
 | 
				
			||||
	}()
 | 
				
			||||
 | 
				
			||||
	<-c
 | 
				
			||||
}
 | 
				
			||||
@ -1,120 +0,0 @@
 | 
				
			||||
// Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved.
 | 
				
			||||
 | 
				
			||||
// Licensed under the Apache License, Version 2.0 (the "License");
 | 
				
			||||
// you may not use this file except in compliance with the License.
 | 
				
			||||
// You may obtain a copy of the License at
 | 
				
			||||
 | 
				
			||||
// http://www.apache.org/licenses/LICENSE-2.0
 | 
				
			||||
 | 
				
			||||
// Unless required by applicable law or agreed to in writing, software
 | 
				
			||||
// distributed under the License is distributed on an "AS IS" BASIS,
 | 
				
			||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 | 
				
			||||
// See the License for the specific language governing permissions and
 | 
				
			||||
// limitations under the License.
 | 
				
			||||
 | 
				
			||||
package connection
 | 
				
			||||
 | 
				
			||||
import (
 | 
				
			||||
	"errors"
 | 
				
			||||
	"net/rpc"
 | 
				
			||||
	"sync"
 | 
				
			||||
 | 
				
			||||
	log "github.com/sirupsen/logrus"
 | 
				
			||||
)
 | 
				
			||||
 | 
				
			||||
// TODO(helin): add TCP re-connect logic
 | 
				
			||||
 | 
				
			||||
// Conn is a connection to a parameter server
 | 
				
			||||
type Conn struct {
 | 
				
			||||
	mu       sync.Mutex
 | 
				
			||||
	client   *rpc.Client
 | 
				
			||||
	waitConn chan struct{}
 | 
				
			||||
}
 | 
				
			||||
 | 
				
			||||
// New creates a new connection.
 | 
				
			||||
func New() *Conn {
 | 
				
			||||
	c := &Conn{}
 | 
				
			||||
	return c
 | 
				
			||||
}
 | 
				
			||||
 | 
				
			||||
// Close closes the connection.
 | 
				
			||||
func (c *Conn) Close() error {
 | 
				
			||||
	c.mu.Lock()
 | 
				
			||||
	defer c.mu.Unlock()
 | 
				
			||||
 | 
				
			||||
	if c.client == nil {
 | 
				
			||||
		return nil
 | 
				
			||||
	}
 | 
				
			||||
 | 
				
			||||
	return c.client.Close()
 | 
				
			||||
}
 | 
				
			||||
 | 
				
			||||
// Connect connects the connection to a address.
 | 
				
			||||
func (c *Conn) Connect(addr string) error {
 | 
				
			||||
	c.mu.Lock()
 | 
				
			||||
	if c.client != nil {
 | 
				
			||||
		err := c.client.Close()
 | 
				
			||||
		if err != nil {
 | 
				
			||||
			c.mu.Unlock()
 | 
				
			||||
			return err
 | 
				
			||||
		}
 | 
				
			||||
 | 
				
			||||
		c.client = nil
 | 
				
			||||
	}
 | 
				
			||||
	c.mu.Unlock()
 | 
				
			||||
 | 
				
			||||
	client, err := rpc.DialHTTP("tcp", addr)
 | 
				
			||||
	if err != nil {
 | 
				
			||||
		return err
 | 
				
			||||
	}
 | 
				
			||||
 | 
				
			||||
	c.mu.Lock()
 | 
				
			||||
	defer c.mu.Unlock()
 | 
				
			||||
 | 
				
			||||
	if c.client == nil {
 | 
				
			||||
		c.client = client
 | 
				
			||||
		if c.waitConn != nil {
 | 
				
			||||
			close(c.waitConn)
 | 
				
			||||
			c.waitConn = nil
 | 
				
			||||
		}
 | 
				
			||||
	} else {
 | 
				
			||||
		err := client.Close()
 | 
				
			||||
		if err != nil {
 | 
				
			||||
			log.Errorln(err)
 | 
				
			||||
		}
 | 
				
			||||
 | 
				
			||||
		return errors.New("client already set from a concurrent goroutine")
 | 
				
			||||
	}
 | 
				
			||||
 | 
				
			||||
	return nil
 | 
				
			||||
}
 | 
				
			||||
 | 
				
			||||
// TODO(helin): refactor Call to be able to perform given retry
 | 
				
			||||
// policy.
 | 
				
			||||
 | 
				
			||||
// Call make a RPC call.
 | 
				
			||||
//
 | 
				
			||||
// Call will be blocked until the connection to remote RPC service
 | 
				
			||||
// being established.
 | 
				
			||||
func (c *Conn) Call(serviceMethod string, args interface{}, reply interface{}) error {
 | 
				
			||||
	c.mu.Lock()
 | 
				
			||||
	client := c.client
 | 
				
			||||
	var waitCh chan struct{}
 | 
				
			||||
	if client == nil {
 | 
				
			||||
		if c.waitConn != nil {
 | 
				
			||||
			waitCh = c.waitConn
 | 
				
			||||
		} else {
 | 
				
			||||
			waitCh = make(chan struct{})
 | 
				
			||||
			c.waitConn = waitCh
 | 
				
			||||
		}
 | 
				
			||||
	}
 | 
				
			||||
	c.mu.Unlock()
 | 
				
			||||
 | 
				
			||||
	if waitCh != nil {
 | 
				
			||||
		// wait until new connection being established
 | 
				
			||||
		<-waitCh
 | 
				
			||||
		return c.Call(serviceMethod, args, reply)
 | 
				
			||||
	}
 | 
				
			||||
 | 
				
			||||
	return client.Call(serviceMethod, args, reply)
 | 
				
			||||
}
 | 
				
			||||
@ -1,17 +0,0 @@
 | 
				
			||||
# Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved.
 | 
				
			||||
#
 | 
				
			||||
# Licensed under the Apache License, Version 2.0 (the "License");
 | 
				
			||||
# you may not use this file except in compliance with the License.
 | 
				
			||||
# You may obtain a copy of the License at
 | 
				
			||||
#
 | 
				
			||||
# http://www.apache.org/licenses/LICENSE-2.0
 | 
				
			||||
#
 | 
				
			||||
# Unless required by applicable law or agreed to in writing, software
 | 
				
			||||
# distributed under the License is distributed on an "AS IS" BASIS,
 | 
				
			||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 | 
				
			||||
# See the License for the specific language governing permissions and
 | 
				
			||||
# limitations under the License.
 | 
				
			||||
#
 | 
				
			||||
if(WITH_TESTING)
 | 
				
			||||
  go_test(master_test)
 | 
				
			||||
endif()
 | 
				
			||||
@ -1,15 +0,0 @@
 | 
				
			||||
# Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved.
 | 
				
			||||
#
 | 
				
			||||
# Licensed under the Apache License, Version 2.0 (the "License");
 | 
				
			||||
# you may not use this file except in compliance with the License.
 | 
				
			||||
# You may obtain a copy of the License at
 | 
				
			||||
#
 | 
				
			||||
# http://www.apache.org/licenses/LICENSE-2.0
 | 
				
			||||
#
 | 
				
			||||
# Unless required by applicable law or agreed to in writing, software
 | 
				
			||||
# distributed under the License is distributed on an "AS IS" BASIS,
 | 
				
			||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 | 
				
			||||
# See the License for the specific language governing permissions and
 | 
				
			||||
# limitations under the License.
 | 
				
			||||
#
 | 
				
			||||
go_library(paddle_master SHARED DEPS paddle_go_optimizer)
 | 
				
			||||
@ -1,196 +0,0 @@
 | 
				
			||||
// Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved.
 | 
				
			||||
 | 
				
			||||
// Licensed under the Apache License, Version 2.0 (the "License");
 | 
				
			||||
// you may not use this file except in compliance with the License.
 | 
				
			||||
// You may obtain a copy of the License at
 | 
				
			||||
 | 
				
			||||
// http://www.apache.org/licenses/LICENSE-2.0
 | 
				
			||||
 | 
				
			||||
// Unless required by applicable law or agreed to in writing, software
 | 
				
			||||
// distributed under the License is distributed on an "AS IS" BASIS,
 | 
				
			||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 | 
				
			||||
// See the License for the specific language governing permissions and
 | 
				
			||||
// limitations under the License.
 | 
				
			||||
 | 
				
			||||
package main
 | 
				
			||||
 | 
				
			||||
/*
 | 
				
			||||
#include <stdlib.h>
 | 
				
			||||
#include <string.h>
 | 
				
			||||
#include <stdio.h>
 | 
				
			||||
#define PADDLE_MASTER_OK    0
 | 
				
			||||
#define PADDLE_MASTER_ERROR -1
 | 
				
			||||
 | 
				
			||||
#define PADDLE_SAVE_MODEL_OK   1
 | 
				
			||||
#define PADDLE_SAVE_MODEL_SKIP 0
 | 
				
			||||
 | 
				
			||||
typedef int paddle_master_client;
 | 
				
			||||
*/
 | 
				
			||||
import "C"
 | 
				
			||||
 | 
				
			||||
import (
 | 
				
			||||
	"strings"
 | 
				
			||||
	"sync"
 | 
				
			||||
	"time"
 | 
				
			||||
	"unsafe"
 | 
				
			||||
 | 
				
			||||
	"github.com/PaddlePaddle/Paddle/go/master"
 | 
				
			||||
	log "github.com/inconshreveable/log15"
 | 
				
			||||
)
 | 
				
			||||
 | 
				
			||||
var mu sync.Mutex
 | 
				
			||||
var handleMap = make(map[C.paddle_master_client]*master.Client)
 | 
				
			||||
var curHandle C.paddle_master_client
 | 
				
			||||
 | 
				
			||||
func init() {
 | 
				
			||||
	log.Root().SetHandler(
 | 
				
			||||
		log.LvlFilterHandler(log.LvlWarn, log.CallerStackHandler("%+v", log.StderrHandler)),
 | 
				
			||||
	)
 | 
				
			||||
}
 | 
				
			||||
 | 
				
			||||
func add(c *master.Client) C.paddle_master_client {
 | 
				
			||||
	mu.Lock()
 | 
				
			||||
	defer mu.Unlock()
 | 
				
			||||
	client := curHandle
 | 
				
			||||
	curHandle++
 | 
				
			||||
	handleMap[client] = c
 | 
				
			||||
	return client
 | 
				
			||||
}
 | 
				
			||||
 | 
				
			||||
func get(client C.paddle_master_client) *master.Client {
 | 
				
			||||
	mu.Lock()
 | 
				
			||||
	defer mu.Unlock()
 | 
				
			||||
	return handleMap[client]
 | 
				
			||||
}
 | 
				
			||||
 | 
				
			||||
func remove(client C.paddle_master_client) *master.Client {
 | 
				
			||||
	mu.Lock()
 | 
				
			||||
	defer mu.Unlock()
 | 
				
			||||
	h := handleMap[client]
 | 
				
			||||
	delete(handleMap, client)
 | 
				
			||||
	return h
 | 
				
			||||
}
 | 
				
			||||
 | 
				
			||||
//export paddle_new_etcd_master_client
 | 
				
			||||
//
 | 
				
			||||
// bufSize is the record buffer size.
 | 
				
			||||
func paddle_new_etcd_master_client(etcdEndpoints *C.char, timeout int, bufSize int) C.paddle_master_client {
 | 
				
			||||
	p := C.GoString(etcdEndpoints)
 | 
				
			||||
	endpoints := strings.Split(p, ",")
 | 
				
			||||
	c, err := master.NewClient(
 | 
				
			||||
		master.WithEtcd(endpoints, time.Duration(timeout)*time.Second),
 | 
				
			||||
		master.WithBuffer(bufSize),
 | 
				
			||||
	)
 | 
				
			||||
	if err != nil {
 | 
				
			||||
		panic(err)
 | 
				
			||||
	}
 | 
				
			||||
 | 
				
			||||
	return add(c)
 | 
				
			||||
}
 | 
				
			||||
 | 
				
			||||
//export paddle_new_master_client
 | 
				
			||||
//
 | 
				
			||||
// bufSize is the record buffer size.
 | 
				
			||||
func paddle_new_master_client(addr *C.char, bufSize int) C.paddle_master_client {
 | 
				
			||||
	a := C.GoString(addr)
 | 
				
			||||
	c, err := master.NewClient(master.WithAddr(a), master.WithBuffer(bufSize))
 | 
				
			||||
	if err != nil {
 | 
				
			||||
		panic(err)
 | 
				
			||||
	}
 | 
				
			||||
 | 
				
			||||
	return add(c)
 | 
				
			||||
}
 | 
				
			||||
 | 
				
			||||
//export paddle_release_master_client
 | 
				
			||||
func paddle_release_master_client(client C.paddle_master_client) {
 | 
				
			||||
	remove(client)
 | 
				
			||||
}
 | 
				
			||||
 | 
				
			||||
//export paddle_start_get_records
 | 
				
			||||
func paddle_start_get_records(client C.paddle_master_client, pass C.int) {
 | 
				
			||||
	c := get(client)
 | 
				
			||||
	c.StartGetRecords(int(pass))
 | 
				
			||||
}
 | 
				
			||||
 | 
				
			||||
//export paddle_set_dataset
 | 
				
			||||
func paddle_set_dataset(client C.paddle_master_client, path **C.char, size C.int) C.int {
 | 
				
			||||
	c := get(client)
 | 
				
			||||
	var paths []string
 | 
				
			||||
	for i := 0; i < int(size); i++ {
 | 
				
			||||
		ptr := (**C.char)(unsafe.Pointer(uintptr(unsafe.Pointer(path)) + uintptr(i)*unsafe.Sizeof(*path)))
 | 
				
			||||
		str := C.GoString(*ptr)
 | 
				
			||||
		paths = append(paths, str)
 | 
				
			||||
	}
 | 
				
			||||
	err := c.SetDataset(paths)
 | 
				
			||||
	if err != nil {
 | 
				
			||||
		log.Error("error set dataset",
 | 
				
			||||
			log.Ctx{"error": err, "paths": paths})
 | 
				
			||||
		return C.PADDLE_MASTER_ERROR
 | 
				
			||||
	}
 | 
				
			||||
 | 
				
			||||
	return C.PADDLE_MASTER_OK
 | 
				
			||||
}
 | 
				
			||||
 | 
				
			||||
// paddle_next_record gets the nexts training record.
 | 
				
			||||
//
 | 
				
			||||
// returns number of bytes of the records if success, -1 if failed, -2 if pass end.
 | 
				
			||||
//
 | 
				
			||||
//export paddle_next_record
 | 
				
			||||
func paddle_next_record(client C.paddle_master_client, record **C.uchar) C.int {
 | 
				
			||||
	c := get(client)
 | 
				
			||||
	r, err := c.NextRecord()
 | 
				
			||||
	if err != nil {
 | 
				
			||||
		// NOTE: use errors to indicate pass ends
 | 
				
			||||
		if err.Error() == master.ErrAllTaskFailed.Error() ||
 | 
				
			||||
			err.Error() == master.ErrNoMoreAvailable.Error() ||
 | 
				
			||||
			err.Error() == master.ErrPassBefore.Error() {
 | 
				
			||||
			return -2
 | 
				
			||||
		}
 | 
				
			||||
		*record = (*C.uchar)(nil)
 | 
				
			||||
		return -1
 | 
				
			||||
	}
 | 
				
			||||
 | 
				
			||||
	if len(r) == 0 {
 | 
				
			||||
		// Empty record
 | 
				
			||||
		*record = (*C.uchar)(nil)
 | 
				
			||||
		return 0
 | 
				
			||||
	}
 | 
				
			||||
 | 
				
			||||
	size := C.size_t(len(r))
 | 
				
			||||
	*record = (*C.uchar)(C.malloc(size))
 | 
				
			||||
	C.memcpy(unsafe.Pointer(*record), unsafe.Pointer(&r[0]), size)
 | 
				
			||||
	return C.int(size)
 | 
				
			||||
}
 | 
				
			||||
 | 
				
			||||
// paddle_request_save_model requests the master server to approve the
 | 
				
			||||
// caller to save the model.
 | 
				
			||||
//
 | 
				
			||||
// returns 1 if the save the model request is approved, 0 if the
 | 
				
			||||
// request is rejected because other trainer is saving the model, -1
 | 
				
			||||
// if error happened.
 | 
				
			||||
//
 | 
				
			||||
//export paddle_request_save_model
 | 
				
			||||
func paddle_request_save_model(client C.paddle_master_client, trainerID string, blockMS int) C.int {
 | 
				
			||||
	c := get(client)
 | 
				
			||||
	need, err := c.RequestSaveModel(trainerID, time.Duration(blockMS)*time.Millisecond)
 | 
				
			||||
	if err != nil {
 | 
				
			||||
		log.Error("error request save model", log.Ctx{"error": err})
 | 
				
			||||
		return C.PADDLE_MASTER_ERROR
 | 
				
			||||
	}
 | 
				
			||||
 | 
				
			||||
	if need {
 | 
				
			||||
		return C.PADDLE_SAVE_MODEL_OK
 | 
				
			||||
	}
 | 
				
			||||
 | 
				
			||||
	return C.PADDLE_SAVE_MODEL_SKIP
 | 
				
			||||
}
 | 
				
			||||
 | 
				
			||||
//export mem_free
 | 
				
			||||
func mem_free(p unsafe.Pointer) {
 | 
				
			||||
	// "free" may be a better name for this function, but doing so
 | 
				
			||||
	// will cause calling any function of this library from Python
 | 
				
			||||
	// ctypes hanging.
 | 
				
			||||
	C.free(p)
 | 
				
			||||
}
 | 
				
			||||
 | 
				
			||||
func main() {}
 | 
				
			||||
											
												
													File diff suppressed because it is too large
													Load Diff
												
											
										
									
								@ -1,152 +0,0 @@
 | 
				
			||||
// Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved.
 | 
				
			||||
 | 
				
			||||
// Licensed under the Apache License, Version 2.0 (the "License");
 | 
				
			||||
// you may not use this file except in compliance with the License.
 | 
				
			||||
// You may obtain a copy of the License at
 | 
				
			||||
 | 
				
			||||
// http://www.apache.org/licenses/LICENSE-2.0
 | 
				
			||||
 | 
				
			||||
// Unless required by applicable law or agreed to in writing, software
 | 
				
			||||
// distributed under the License is distributed on an "AS IS" BASIS,
 | 
				
			||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 | 
				
			||||
// See the License for the specific language governing permissions and
 | 
				
			||||
// limitations under the License.
 | 
				
			||||
 | 
				
			||||
package master
 | 
				
			||||
 | 
				
			||||
import (
 | 
				
			||||
	"fmt"
 | 
				
			||||
	"net"
 | 
				
			||||
	"net/http"
 | 
				
			||||
	"net/rpc"
 | 
				
			||||
	"os"
 | 
				
			||||
	"strconv"
 | 
				
			||||
	"strings"
 | 
				
			||||
	"testing"
 | 
				
			||||
	"time"
 | 
				
			||||
 | 
				
			||||
	"github.com/PaddlePaddle/Paddle/go/connection"
 | 
				
			||||
	"github.com/PaddlePaddle/recordio"
 | 
				
			||||
)
 | 
				
			||||
 | 
				
			||||
const (
 | 
				
			||||
	totalTask    = 20
 | 
				
			||||
	chunkPerTask = 10
 | 
				
			||||
)
 | 
				
			||||
 | 
				
			||||
func TestGetFinishTask(t *testing.T) {
 | 
				
			||||
	const path = "/tmp/master_client_test_0"
 | 
				
			||||
 | 
				
			||||
	l, err := net.Listen("tcp", ":0")
 | 
				
			||||
	if err != nil {
 | 
				
			||||
		panic(err)
 | 
				
			||||
	}
 | 
				
			||||
 | 
				
			||||
	ss := strings.Split(l.Addr().String(), ":")
 | 
				
			||||
	p, err := strconv.Atoi(ss[len(ss)-1])
 | 
				
			||||
	if err != nil {
 | 
				
			||||
		panic(err)
 | 
				
			||||
	}
 | 
				
			||||
	go func(l net.Listener) {
 | 
				
			||||
		s, sErr := NewService(&InMemStore{}, chunkPerTask, time.Second, 1)
 | 
				
			||||
		if sErr != nil {
 | 
				
			||||
			panic(sErr)
 | 
				
			||||
		}
 | 
				
			||||
 | 
				
			||||
		server := rpc.NewServer()
 | 
				
			||||
		sErr = server.Register(s)
 | 
				
			||||
		if sErr != nil {
 | 
				
			||||
			panic(sErr)
 | 
				
			||||
		}
 | 
				
			||||
 | 
				
			||||
		mux := http.NewServeMux()
 | 
				
			||||
		mux.Handle(rpc.DefaultRPCPath, server)
 | 
				
			||||
		sErr = http.Serve(l, mux)
 | 
				
			||||
		if sErr != nil {
 | 
				
			||||
			panic(sErr)
 | 
				
			||||
		}
 | 
				
			||||
	}(l)
 | 
				
			||||
 | 
				
			||||
	f, err := os.Create(path)
 | 
				
			||||
	if err != nil {
 | 
				
			||||
		panic(err)
 | 
				
			||||
	}
 | 
				
			||||
 | 
				
			||||
	for i := 0; i < totalTask*chunkPerTask; i++ {
 | 
				
			||||
		w := recordio.NewWriter(f, -1, -1)
 | 
				
			||||
		_, err = w.Write(nil)
 | 
				
			||||
		if err != nil {
 | 
				
			||||
			panic(err)
 | 
				
			||||
		}
 | 
				
			||||
 | 
				
			||||
		// call Close to force RecordIO writing a chunk.
 | 
				
			||||
		err = w.Close()
 | 
				
			||||
		if err != nil {
 | 
				
			||||
			panic(err)
 | 
				
			||||
		}
 | 
				
			||||
	}
 | 
				
			||||
	err = f.Close()
 | 
				
			||||
	if err != nil {
 | 
				
			||||
		panic(err)
 | 
				
			||||
	}
 | 
				
			||||
 | 
				
			||||
	// Manually intialize client to avoid calling c.getRecords()
 | 
				
			||||
	c := &Client{}
 | 
				
			||||
	c.conn = connection.New()
 | 
				
			||||
	addr := fmt.Sprintf(":%d", p)
 | 
				
			||||
	ch := make(chan string, 1)
 | 
				
			||||
	ch <- addr
 | 
				
			||||
	go c.monitorMaster(ch)
 | 
				
			||||
 | 
				
			||||
	err = c.SetDataset([]string{path})
 | 
				
			||||
	if err != nil {
 | 
				
			||||
		panic(err)
 | 
				
			||||
	}
 | 
				
			||||
 | 
				
			||||
	checkOnePass := func(i int) {
 | 
				
			||||
		var tasks []Task
 | 
				
			||||
		for idx := 0; idx < totalTask; idx++ {
 | 
				
			||||
			task, cErr := c.getTask(i)
 | 
				
			||||
			if cErr != nil && cErr.Error() != ErrNoMoreAvailable.Error() && cErr.Error() != ErrPassAfter.Error() {
 | 
				
			||||
				t.Fatalf("error: %v, pass: %d\n", cErr, i)
 | 
				
			||||
			}
 | 
				
			||||
			tasks = append(tasks, task)
 | 
				
			||||
		}
 | 
				
			||||
 | 
				
			||||
		// getting task before task finishes should return error
 | 
				
			||||
		_, cErr := c.getTask(i)
 | 
				
			||||
		if cErr == nil {
 | 
				
			||||
			t.Fatalf("Should get error, pass: %d\n", i)
 | 
				
			||||
		}
 | 
				
			||||
 | 
				
			||||
		cErr = c.taskFinished(tasks[0].Meta.ID)
 | 
				
			||||
		if cErr != nil {
 | 
				
			||||
			t.Fatalf("Error: %v, pass: %d\n", cErr, i)
 | 
				
			||||
		}
 | 
				
			||||
		// call taskFailed once won't put the task to failed queue, just ensure
 | 
				
			||||
		// the call
 | 
				
			||||
		cErr = c.taskFailed(tasks[0].Meta)
 | 
				
			||||
		if cErr != nil {
 | 
				
			||||
			t.Fatalf("Error: %v, pass: %d\n", cErr, i)
 | 
				
			||||
		}
 | 
				
			||||
 | 
				
			||||
		tasks = tasks[1:]
 | 
				
			||||
		_, cErr = c.getTask(i)
 | 
				
			||||
		if cErr != nil && cErr.Error() != ErrNoMoreAvailable.Error() && cErr.Error() != ErrPassAfter.Error() {
 | 
				
			||||
			t.Fatalf("Should be ErrNoMoreAvailable or ErrPassAfter: %s", cErr)
 | 
				
			||||
		}
 | 
				
			||||
 | 
				
			||||
		for _, task := range tasks {
 | 
				
			||||
			cErr = c.taskFinished(task.Meta.ID)
 | 
				
			||||
			if cErr != nil {
 | 
				
			||||
				t.Fatal(cErr)
 | 
				
			||||
			}
 | 
				
			||||
		}
 | 
				
			||||
	}
 | 
				
			||||
 | 
				
			||||
	for i := 0; i < 10; i++ {
 | 
				
			||||
		// init pass data
 | 
				
			||||
		c.StartGetRecords(i)
 | 
				
			||||
		checkOnePass(i)
 | 
				
			||||
	}
 | 
				
			||||
}
 | 
				
			||||
@ -1,150 +0,0 @@
 | 
				
			||||
// Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved.
 | 
				
			||||
 | 
				
			||||
// Licensed under the Apache License, Version 2.0 (the "License");
 | 
				
			||||
// you may not use this file except in compliance with the License.
 | 
				
			||||
// You may obtain a copy of the License at
 | 
				
			||||
 | 
				
			||||
// http://www.apache.org/licenses/LICENSE-2.0
 | 
				
			||||
 | 
				
			||||
// Unless required by applicable law or agreed to in writing, software
 | 
				
			||||
// distributed under the License is distributed on an "AS IS" BASIS,
 | 
				
			||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 | 
				
			||||
// See the License for the specific language governing permissions and
 | 
				
			||||
// limitations under the License.
 | 
				
			||||
 | 
				
			||||
package master_test
 | 
				
			||||
 | 
				
			||||
import (
 | 
				
			||||
	"fmt"
 | 
				
			||||
	"net"
 | 
				
			||||
	"net/http"
 | 
				
			||||
	"net/rpc"
 | 
				
			||||
	"os"
 | 
				
			||||
	"runtime"
 | 
				
			||||
	"strconv"
 | 
				
			||||
	"strings"
 | 
				
			||||
	"sync"
 | 
				
			||||
	"testing"
 | 
				
			||||
	"time"
 | 
				
			||||
 | 
				
			||||
	"github.com/PaddlePaddle/Paddle/go/master"
 | 
				
			||||
	"github.com/PaddlePaddle/recordio"
 | 
				
			||||
)
 | 
				
			||||
 | 
				
			||||
// tool function for testing output goroutine ids
 | 
				
			||||
func goid() int {
 | 
				
			||||
	var buf [64]byte
 | 
				
			||||
	n := runtime.Stack(buf[:], false)
 | 
				
			||||
	idField := strings.Fields(strings.TrimPrefix(string(buf[:n]), "goroutine "))[0]
 | 
				
			||||
	id, err := strconv.Atoi(idField)
 | 
				
			||||
	if err != nil {
 | 
				
			||||
		panic(fmt.Sprintf("cannot get goroutine id: %v", err))
 | 
				
			||||
	}
 | 
				
			||||
	return id
 | 
				
			||||
}
 | 
				
			||||
 | 
				
			||||
func TestNextRecord(t *testing.T) {
 | 
				
			||||
	const (
 | 
				
			||||
		path  = "/tmp/master_client_TestFull"
 | 
				
			||||
		total = 50
 | 
				
			||||
	)
 | 
				
			||||
	l, err := net.Listen("tcp", ":0")
 | 
				
			||||
	if err != nil {
 | 
				
			||||
		panic(err)
 | 
				
			||||
	}
 | 
				
			||||
 | 
				
			||||
	ss := strings.Split(l.Addr().String(), ":")
 | 
				
			||||
	p, err := strconv.Atoi(ss[len(ss)-1])
 | 
				
			||||
	if err != nil {
 | 
				
			||||
		panic(err)
 | 
				
			||||
	}
 | 
				
			||||
	go func(l net.Listener) {
 | 
				
			||||
		s, err := master.NewService(&master.InMemStore{}, 1, time.Second*60, 1)
 | 
				
			||||
		if err != nil {
 | 
				
			||||
			panic(err)
 | 
				
			||||
		}
 | 
				
			||||
 | 
				
			||||
		server := rpc.NewServer()
 | 
				
			||||
		err = server.Register(s)
 | 
				
			||||
		if err != nil {
 | 
				
			||||
			panic(err)
 | 
				
			||||
		}
 | 
				
			||||
 | 
				
			||||
		mux := http.NewServeMux()
 | 
				
			||||
		mux.Handle(rpc.DefaultRPCPath, server)
 | 
				
			||||
		err = http.Serve(l, mux)
 | 
				
			||||
		if err != nil {
 | 
				
			||||
			panic(err)
 | 
				
			||||
		}
 | 
				
			||||
	}(l)
 | 
				
			||||
 | 
				
			||||
	f, err := os.Create(path)
 | 
				
			||||
	if err != nil {
 | 
				
			||||
		panic(err)
 | 
				
			||||
	}
 | 
				
			||||
 | 
				
			||||
	w := recordio.NewWriter(f, 1, -1)
 | 
				
			||||
	for i := 0; i < total; i++ {
 | 
				
			||||
		_, err = w.Write([]byte{byte(i)})
 | 
				
			||||
		if err != nil {
 | 
				
			||||
			panic(err)
 | 
				
			||||
		}
 | 
				
			||||
	}
 | 
				
			||||
 | 
				
			||||
	err = w.Close()
 | 
				
			||||
	if err != nil {
 | 
				
			||||
		panic(err)
 | 
				
			||||
	}
 | 
				
			||||
 | 
				
			||||
	err = f.Close()
 | 
				
			||||
	if err != nil {
 | 
				
			||||
		panic(err)
 | 
				
			||||
	}
 | 
				
			||||
 | 
				
			||||
	// start several client to test task fetching
 | 
				
			||||
	var wg sync.WaitGroup
 | 
				
			||||
	for i := 0; i < 4; i++ {
 | 
				
			||||
		wg.Add(1)
 | 
				
			||||
		// test for multiple concurrent clients
 | 
				
			||||
		go func() {
 | 
				
			||||
			defer wg.Done()
 | 
				
			||||
			// each go-routine needs a single client connection instance
 | 
				
			||||
			c, e := master.NewClient(master.WithAddr(fmt.Sprintf(":%d", p)), master.WithBuffer(1))
 | 
				
			||||
			if e != nil {
 | 
				
			||||
				t.Fatal(e)
 | 
				
			||||
			}
 | 
				
			||||
			e = c.SetDataset([]string{path})
 | 
				
			||||
			if e != nil {
 | 
				
			||||
				panic(e)
 | 
				
			||||
			}
 | 
				
			||||
 | 
				
			||||
			// test for n passes
 | 
				
			||||
			for pass := 0; pass < 10; pass++ {
 | 
				
			||||
				c.StartGetRecords(pass)
 | 
				
			||||
 | 
				
			||||
				received := make(map[byte]bool)
 | 
				
			||||
				taskid := 0
 | 
				
			||||
				for {
 | 
				
			||||
					r, e := c.NextRecord()
 | 
				
			||||
					if e != nil {
 | 
				
			||||
						// ErrorPassAfter will wait, else break for next pass
 | 
				
			||||
						if e.Error() == master.ErrPassBefore.Error() ||
 | 
				
			||||
							e.Error() == master.ErrNoMoreAvailable.Error() {
 | 
				
			||||
							break
 | 
				
			||||
						}
 | 
				
			||||
						t.Fatal(pass, taskid, "Read error:", e)
 | 
				
			||||
					}
 | 
				
			||||
					if len(r) != 1 {
 | 
				
			||||
						t.Fatal(pass, taskid, "Length should be 1.", r)
 | 
				
			||||
					}
 | 
				
			||||
					if received[r[0]] {
 | 
				
			||||
						t.Fatal(pass, taskid, "Received duplicate.", received, r)
 | 
				
			||||
					}
 | 
				
			||||
					taskid++
 | 
				
			||||
					received[r[0]] = true
 | 
				
			||||
				}
 | 
				
			||||
			}
 | 
				
			||||
		}()
 | 
				
			||||
	}
 | 
				
			||||
	wg.Wait()
 | 
				
			||||
}
 | 
				
			||||
@ -1,201 +0,0 @@
 | 
				
			||||
// Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved.
 | 
				
			||||
 | 
				
			||||
// Licensed under the Apache License, Version 2.0 (the "License");
 | 
				
			||||
// you may not use this file except in compliance with the License.
 | 
				
			||||
// You may obtain a copy of the License at
 | 
				
			||||
 | 
				
			||||
// http://www.apache.org/licenses/LICENSE-2.0
 | 
				
			||||
 | 
				
			||||
// Unless required by applicable law or agreed to in writing, software
 | 
				
			||||
// distributed under the License is distributed on an "AS IS" BASIS,
 | 
				
			||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 | 
				
			||||
// See the License for the specific language governing permissions and
 | 
				
			||||
// limitations under the License.
 | 
				
			||||
 | 
				
			||||
package master
 | 
				
			||||
 | 
				
			||||
import (
 | 
				
			||||
	"context"
 | 
				
			||||
	"time"
 | 
				
			||||
 | 
				
			||||
	"github.com/coreos/etcd/clientv3"
 | 
				
			||||
	"github.com/coreos/etcd/clientv3/concurrency"
 | 
				
			||||
	log "github.com/inconshreveable/log15"
 | 
				
			||||
)
 | 
				
			||||
 | 
				
			||||
const (
 | 
				
			||||
	// DefaultLockPath is the default etcd master lock path.
 | 
				
			||||
	DefaultLockPath = "/master/lock"
 | 
				
			||||
	// DefaultStatePath is the default etcd key for master state.
 | 
				
			||||
	DefaultStatePath = "/master/state"
 | 
				
			||||
	// DefaultAddrPath is the default etcd key for master address.
 | 
				
			||||
	DefaultAddrPath = "/master/addr"
 | 
				
			||||
)
 | 
				
			||||
 | 
				
			||||
// EtcdClient is the etcd client that the master uses for fault
 | 
				
			||||
// tolerance and service registry.
 | 
				
			||||
type EtcdClient struct {
 | 
				
			||||
	lockPath  string
 | 
				
			||||
	statePath string
 | 
				
			||||
	client    *clientv3.Client
 | 
				
			||||
	lock      *concurrency.Mutex
 | 
				
			||||
	sess      *concurrency.Session
 | 
				
			||||
}
 | 
				
			||||
 | 
				
			||||
// NewEtcdClient creates a new EtcdClient.
 | 
				
			||||
func NewEtcdClient(endpoints []string, addr string, lockPath, addrPath, statePath string, ttlSec int) (*EtcdClient, error) {
 | 
				
			||||
	log.Debug("Connecting to etcd", log.Ctx{"endpoint": endpoints})
 | 
				
			||||
	cli, err := clientv3.New(clientv3.Config{
 | 
				
			||||
		Endpoints:   endpoints,
 | 
				
			||||
		DialTimeout: dialTimeout,
 | 
				
			||||
	})
 | 
				
			||||
	if err != nil {
 | 
				
			||||
		return nil, err
 | 
				
			||||
	}
 | 
				
			||||
 | 
				
			||||
	sess, err := concurrency.NewSession(cli, concurrency.WithTTL(ttlSec))
 | 
				
			||||
	if err != nil {
 | 
				
			||||
		return nil, err
 | 
				
			||||
	}
 | 
				
			||||
 | 
				
			||||
	lock := concurrency.NewMutex(sess, lockPath)
 | 
				
			||||
	// It's fine for the lock to get stuck, in this case we have
 | 
				
			||||
	// multiple master servers running (only configured to have
 | 
				
			||||
	// one master running, but split-brain problem may cause
 | 
				
			||||
	// multiple master servers running), and the cluster management
 | 
				
			||||
	// software will kill one of them.
 | 
				
			||||
	log.Info("Trying to acquire lock.", log.Ctx{"path": lockPath})
 | 
				
			||||
	err = lock.Lock(context.TODO())
 | 
				
			||||
	if err != nil {
 | 
				
			||||
		return nil, err
 | 
				
			||||
	}
 | 
				
			||||
	log.Info("Successfully acquired lock at %s.", log.Ctx{"path": lockPath})
 | 
				
			||||
 | 
				
			||||
	put := clientv3.OpPut(addrPath, addr)
 | 
				
			||||
	resp, err := cli.Txn(context.Background()).If(lock.IsOwner()).Then(put).Commit()
 | 
				
			||||
	if err != nil {
 | 
				
			||||
		return nil, err
 | 
				
			||||
	}
 | 
				
			||||
 | 
				
			||||
	if !resp.Succeeded {
 | 
				
			||||
		log.Crit("No longer owns the master lock. Exiting.")
 | 
				
			||||
		panic("No longer owns the master lock. Exiting.")
 | 
				
			||||
	}
 | 
				
			||||
 | 
				
			||||
	e := &EtcdClient{
 | 
				
			||||
		lockPath:  lockPath,
 | 
				
			||||
		statePath: statePath,
 | 
				
			||||
		client:    cli,
 | 
				
			||||
		lock:      lock,
 | 
				
			||||
		sess:      sess,
 | 
				
			||||
	}
 | 
				
			||||
 | 
				
			||||
	return e, nil
 | 
				
			||||
}
 | 
				
			||||
 | 
				
			||||
// Save saves the state into the etcd.
 | 
				
			||||
func (e *EtcdClient) Save(state []byte) error {
 | 
				
			||||
	ctx := context.TODO()
 | 
				
			||||
	put := clientv3.OpPut(e.statePath, string(state))
 | 
				
			||||
	resp, err := e.client.Txn(ctx).If(e.lock.IsOwner()).Then(put).Commit()
 | 
				
			||||
	if err != nil {
 | 
				
			||||
		return err
 | 
				
			||||
	}
 | 
				
			||||
 | 
				
			||||
	if !resp.Succeeded {
 | 
				
			||||
		log.Error("No longer owns the lock, trying to lock again")
 | 
				
			||||
		ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
 | 
				
			||||
		err := e.lock.Lock(ctx)
 | 
				
			||||
		cancel()
 | 
				
			||||
		if err != nil {
 | 
				
			||||
			// We lost the master lock and can not acquire
 | 
				
			||||
			// it back, it means some other master is
 | 
				
			||||
			// already started. We don't want cluster
 | 
				
			||||
			// management system to kill the master server
 | 
				
			||||
			// who is holding the lock and running
 | 
				
			||||
			// correctly. So the most feasible solution is
 | 
				
			||||
			// to kill current master server. The current
 | 
				
			||||
			// state is not saved, but the trainer's RPC
 | 
				
			||||
			// call will fail, so the trainer will retry.
 | 
				
			||||
			log.Crit("Could not acquire the lock at %s: %v. Exiting.", log.Ctx{"path": e.lockPath, "error": err})
 | 
				
			||||
			panic("Could not acquire the lock at %s: %v. Exiting.")
 | 
				
			||||
		}
 | 
				
			||||
		log.Info("Successfully acquired lock at %s.", e.lockPath)
 | 
				
			||||
		return e.Save(state)
 | 
				
			||||
	}
 | 
				
			||||
 | 
				
			||||
	return nil
 | 
				
			||||
}
 | 
				
			||||
 | 
				
			||||
// Load loads the state from etcd.
 | 
				
			||||
func (e *EtcdClient) Load() ([]byte, error) {
 | 
				
			||||
	ctx := context.TODO()
 | 
				
			||||
	get := clientv3.OpGet(e.statePath)
 | 
				
			||||
 | 
				
			||||
	resp, err := e.client.Txn(ctx).If(e.lock.IsOwner()).Then(get).Commit()
 | 
				
			||||
	if err != nil {
 | 
				
			||||
		return nil, err
 | 
				
			||||
	}
 | 
				
			||||
 | 
				
			||||
	if !resp.Succeeded {
 | 
				
			||||
		log.Error("No longer owns the lock, trying to lock and load again.")
 | 
				
			||||
		err = e.lock.Lock(context.Background())
 | 
				
			||||
		if err != nil {
 | 
				
			||||
			return nil, err
 | 
				
			||||
		}
 | 
				
			||||
 | 
				
			||||
		return e.Load()
 | 
				
			||||
	}
 | 
				
			||||
 | 
				
			||||
	kvs := resp.Responses[0].GetResponseRange().Kvs
 | 
				
			||||
	if len(kvs) == 0 {
 | 
				
			||||
		// No state exists
 | 
				
			||||
		return nil, nil
 | 
				
			||||
	}
 | 
				
			||||
 | 
				
			||||
	state := kvs[0].Value
 | 
				
			||||
	return state, nil
 | 
				
			||||
}
 | 
				
			||||
 | 
				
			||||
// Shutdown shuts down the etcd client gracefully.
 | 
				
			||||
func (e *EtcdClient) Shutdown() error {
 | 
				
			||||
	err := e.sess.Close()
 | 
				
			||||
	newErr := e.client.Close()
 | 
				
			||||
	if newErr != nil {
 | 
				
			||||
		if err == nil {
 | 
				
			||||
			err = newErr
 | 
				
			||||
		} else {
 | 
				
			||||
			log.Error("shutdown error", log.Ctx{"error": newErr})
 | 
				
			||||
		}
 | 
				
			||||
	}
 | 
				
			||||
 | 
				
			||||
	return err
 | 
				
			||||
}
 | 
				
			||||
 | 
				
			||||
// GetKey gets the value by the specify key.
 | 
				
			||||
func GetKey(c *clientv3.Client, key string, timeout time.Duration) (string, error) {
 | 
				
			||||
	ctx, cancel := context.WithTimeout(context.Background(), timeout)
 | 
				
			||||
	resp, err := c.Get(ctx, key)
 | 
				
			||||
	cancel()
 | 
				
			||||
	if err != nil {
 | 
				
			||||
		return "", err
 | 
				
			||||
	}
 | 
				
			||||
	kvs := resp.Kvs
 | 
				
			||||
	if len(kvs) == 0 {
 | 
				
			||||
		return "", nil
 | 
				
			||||
	}
 | 
				
			||||
	v := kvs[0].Value
 | 
				
			||||
	return string(v), nil
 | 
				
			||||
}
 | 
				
			||||
 | 
				
			||||
// watchKey watches the specify key and send to valChan if there is some event.
 | 
				
			||||
func watchKey(c *clientv3.Client, key string, valChan chan<- string) {
 | 
				
			||||
	rch := c.Watch(context.Background(), key)
 | 
				
			||||
	for wresp := range rch {
 | 
				
			||||
		for _, ev := range wresp.Events {
 | 
				
			||||
			// if received event is DELETE, the value will be an empty string
 | 
				
			||||
			log.Info("received event.", log.Ctx{"type": ev.Type, "key": ev.Kv.Key, "value": ev.Kv.Value})
 | 
				
			||||
			valChan <- string(ev.Kv.Value)
 | 
				
			||||
		}
 | 
				
			||||
	}
 | 
				
			||||
}
 | 
				
			||||
@ -1,47 +0,0 @@
 | 
				
			||||
// Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved.
 | 
				
			||||
 | 
				
			||||
// Licensed under the Apache License, Version 2.0 (the "License");
 | 
				
			||||
// you may not use this file except in compliance with the License.
 | 
				
			||||
// You may obtain a copy of the License at
 | 
				
			||||
 | 
				
			||||
// http://www.apache.org/licenses/LICENSE-2.0
 | 
				
			||||
 | 
				
			||||
// Unless required by applicable law or agreed to in writing, software
 | 
				
			||||
// distributed under the License is distributed on an "AS IS" BASIS,
 | 
				
			||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 | 
				
			||||
// See the License for the specific language governing permissions and
 | 
				
			||||
// limitations under the License.
 | 
				
			||||
 | 
				
			||||
package master
 | 
				
			||||
 | 
				
			||||
import "sync"
 | 
				
			||||
 | 
				
			||||
// InMemStore is an in memory implementation of Store interface.
 | 
				
			||||
//
 | 
				
			||||
// It does not tolerate the fault that causes the program to crash.
 | 
				
			||||
type InMemStore struct {
 | 
				
			||||
	mu  sync.Mutex
 | 
				
			||||
	buf []byte
 | 
				
			||||
}
 | 
				
			||||
 | 
				
			||||
// Save saves the state into the in-memory store.
 | 
				
			||||
func (m *InMemStore) Save(state []byte) error {
 | 
				
			||||
	m.mu.Lock()
 | 
				
			||||
	defer m.mu.Unlock()
 | 
				
			||||
 | 
				
			||||
	m.buf = state
 | 
				
			||||
	return nil
 | 
				
			||||
}
 | 
				
			||||
 | 
				
			||||
// Load loads the state from the in-memory store.
 | 
				
			||||
func (m *InMemStore) Load() ([]byte, error) {
 | 
				
			||||
	m.mu.Lock()
 | 
				
			||||
	defer m.mu.Unlock()
 | 
				
			||||
 | 
				
			||||
	return m.buf, nil
 | 
				
			||||
}
 | 
				
			||||
 | 
				
			||||
// Shutdown shuts down the in mem store.
 | 
				
			||||
func (m *InMemStore) Shutdown() error {
 | 
				
			||||
	return nil
 | 
				
			||||
}
 | 
				
			||||
											
												
													File diff suppressed because it is too large
													Load Diff
												
											
										
									
								@ -1,52 +0,0 @@
 | 
				
			||||
// Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved.
 | 
				
			||||
 | 
				
			||||
// Licensed under the Apache License, Version 2.0 (the "License");
 | 
				
			||||
// you may not use this file except in compliance with the License.
 | 
				
			||||
// You may obtain a copy of the License at
 | 
				
			||||
 | 
				
			||||
// http://www.apache.org/licenses/LICENSE-2.0
 | 
				
			||||
 | 
				
			||||
// Unless required by applicable law or agreed to in writing, software
 | 
				
			||||
// distributed under the License is distributed on an "AS IS" BASIS,
 | 
				
			||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 | 
				
			||||
// See the License for the specific language governing permissions and
 | 
				
			||||
// limitations under the License.
 | 
				
			||||
 | 
				
			||||
package master
 | 
				
			||||
 | 
				
			||||
import "testing"
 | 
				
			||||
 | 
				
			||||
func TestPartitionCount(t *testing.T) {
 | 
				
			||||
	cs := make([]Chunk, 100)
 | 
				
			||||
	ts := partition(cs, 5)
 | 
				
			||||
	if len(ts) != 20 {
 | 
				
			||||
		t.Error(len(ts))
 | 
				
			||||
	}
 | 
				
			||||
 | 
				
			||||
	cs = make([]Chunk, 101)
 | 
				
			||||
	ts = partition(cs, 5)
 | 
				
			||||
	if len(ts) != 21 {
 | 
				
			||||
		t.Error(len(ts))
 | 
				
			||||
	}
 | 
				
			||||
 | 
				
			||||
	ts = partition(cs, 1)
 | 
				
			||||
	if len(ts) != 101 {
 | 
				
			||||
		t.Error(len(ts))
 | 
				
			||||
	}
 | 
				
			||||
 | 
				
			||||
	ts = partition(cs, 0)
 | 
				
			||||
	if len(ts) != 101 {
 | 
				
			||||
		t.Error(len(ts))
 | 
				
			||||
	}
 | 
				
			||||
}
 | 
				
			||||
 | 
				
			||||
func TestPartionIndex(t *testing.T) {
 | 
				
			||||
	cs := make([]Chunk, 100)
 | 
				
			||||
	ts := partition(cs, 20)
 | 
				
			||||
	for i := range ts {
 | 
				
			||||
		// test auto increament ids
 | 
				
			||||
		if i > 0 && ts[i].Task.Meta.ID != ts[i-1].Task.Meta.ID+1 {
 | 
				
			||||
			t.Error(ts[i], i)
 | 
				
			||||
		}
 | 
				
			||||
	}
 | 
				
			||||
}
 | 
				
			||||
@ -1,72 +0,0 @@
 | 
				
			||||
package master_test
 | 
				
			||||
 | 
				
			||||
import (
 | 
				
			||||
	"io/ioutil"
 | 
				
			||||
	"net/url"
 | 
				
			||||
	"os"
 | 
				
			||||
	"strings"
 | 
				
			||||
	"testing"
 | 
				
			||||
	"time"
 | 
				
			||||
 | 
				
			||||
	"github.com/PaddlePaddle/Paddle/go/master"
 | 
				
			||||
	"github.com/coreos/etcd/clientv3"
 | 
				
			||||
	"github.com/coreos/etcd/embed"
 | 
				
			||||
	"github.com/stretchr/testify/assert"
 | 
				
			||||
)
 | 
				
			||||
 | 
				
			||||
func TestNewServiceWithEtcd(t *testing.T) {
 | 
				
			||||
	// setup an embed etcd server
 | 
				
			||||
	etcdDir, err := ioutil.TempDir("", "")
 | 
				
			||||
	if err != nil {
 | 
				
			||||
		t.Fatal(err)
 | 
				
			||||
	}
 | 
				
			||||
	cfg := embed.NewConfig()
 | 
				
			||||
	lpurl, _ := url.Parse("http://localhost:0")
 | 
				
			||||
	lcurl, _ := url.Parse("http://localhost:0")
 | 
				
			||||
	cfg.LPUrls = []url.URL{*lpurl}
 | 
				
			||||
	cfg.LCUrls = []url.URL{*lcurl}
 | 
				
			||||
	cfg.Dir = etcdDir
 | 
				
			||||
	e, err := embed.StartEtcd(cfg)
 | 
				
			||||
	if err != nil {
 | 
				
			||||
		t.Fatal(err)
 | 
				
			||||
	}
 | 
				
			||||
	defer func() {
 | 
				
			||||
		e.Close()
 | 
				
			||||
		if err := os.RemoveAll(etcdDir); err != nil {
 | 
				
			||||
			t.Fatal(err)
 | 
				
			||||
		}
 | 
				
			||||
	}()
 | 
				
			||||
 | 
				
			||||
	<-e.Server.ReadyNotify()
 | 
				
			||||
 | 
				
			||||
	port := strings.Split(e.Clients[0].Addr().String(), ":")[1]
 | 
				
			||||
	endpoint := "127.0.0.1:" + port
 | 
				
			||||
 | 
				
			||||
	ep := []string{endpoint}
 | 
				
			||||
	masterAddr := "127.0.0.1:3306"
 | 
				
			||||
	store, err := master.NewEtcdClient(ep, masterAddr, master.DefaultLockPath, master.DefaultAddrPath, master.DefaultStatePath, 30)
 | 
				
			||||
	if err != nil {
 | 
				
			||||
		t.Fatal(err)
 | 
				
			||||
	}
 | 
				
			||||
 | 
				
			||||
	_, err = master.NewService(store, 10, 10, 3)
 | 
				
			||||
	if err != nil {
 | 
				
			||||
		t.Fatal(err)
 | 
				
			||||
	}
 | 
				
			||||
	cli, err := clientv3.New(clientv3.Config{
 | 
				
			||||
		Endpoints:   ep,
 | 
				
			||||
		DialTimeout: 3 * time.Second,
 | 
				
			||||
	})
 | 
				
			||||
	if err != nil {
 | 
				
			||||
		t.Fatal(err)
 | 
				
			||||
	}
 | 
				
			||||
	v, err := master.GetKey(cli, master.DefaultAddrPath, 3*time.Second)
 | 
				
			||||
	if err != nil {
 | 
				
			||||
		t.Fatal(err)
 | 
				
			||||
	}
 | 
				
			||||
	if err := cli.Close(); err != nil {
 | 
				
			||||
		t.Fatal(err)
 | 
				
			||||
	}
 | 
				
			||||
	// test master process registry itself into etcd server.
 | 
				
			||||
	assert.Equal(t, masterAddr, v, "master process should registry itself into etcd server.")
 | 
				
			||||
}
 | 
				
			||||
@ -1,4 +0,0 @@
 | 
				
			||||
# Ignore everything in this directory
 | 
				
			||||
*
 | 
				
			||||
# Except this file
 | 
				
			||||
!.gitignore
 | 
				
			||||
@ -1,17 +0,0 @@
 | 
				
			||||
# Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved.
 | 
				
			||||
#
 | 
				
			||||
# Licensed under the Apache License, Version 2.0 (the "License");
 | 
				
			||||
# you may not use this file except in compliance with the License.
 | 
				
			||||
# You may obtain a copy of the License at
 | 
				
			||||
#
 | 
				
			||||
# http://www.apache.org/licenses/LICENSE-2.0
 | 
				
			||||
#
 | 
				
			||||
# Unless required by applicable law or agreed to in writing, software
 | 
				
			||||
# distributed under the License is distributed on an "AS IS" BASIS,
 | 
				
			||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 | 
				
			||||
# See the License for the specific language governing permissions and
 | 
				
			||||
# limitations under the License.
 | 
				
			||||
#
 | 
				
			||||
if(WITH_TESTING)
 | 
				
			||||
  go_test(pserver_test DEPS paddle_go_optimizer gen_proto_go)
 | 
				
			||||
endif()
 | 
				
			||||
@ -1,17 +0,0 @@
 | 
				
			||||
# Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved.
 | 
				
			||||
#
 | 
				
			||||
# Licensed under the Apache License, Version 2.0 (the "License");
 | 
				
			||||
# you may not use this file except in compliance with the License.
 | 
				
			||||
# You may obtain a copy of the License at
 | 
				
			||||
#
 | 
				
			||||
# http://www.apache.org/licenses/LICENSE-2.0
 | 
				
			||||
#
 | 
				
			||||
# Unless required by applicable law or agreed to in writing, software
 | 
				
			||||
# distributed under the License is distributed on an "AS IS" BASIS,
 | 
				
			||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 | 
				
			||||
# See the License for the specific language governing permissions and
 | 
				
			||||
# limitations under the License.
 | 
				
			||||
#
 | 
				
			||||
if(WITH_TESTING)
 | 
				
			||||
  go_test(pserver_client_test DEPS paddle_go_optimizer)
 | 
				
			||||
endif()
 | 
				
			||||
@ -1 +0,0 @@
 | 
				
			||||
libpaddle_go_optimizer.a
 | 
				
			||||
Some files were not shown because too many files have changed in this diff Show More
					Loading…
					
					
				
		Reference in new issue