|
|
@ -3,11 +3,13 @@ package client_test
|
|
|
|
import (
|
|
|
|
import (
|
|
|
|
"context"
|
|
|
|
"context"
|
|
|
|
"io/ioutil"
|
|
|
|
"io/ioutil"
|
|
|
|
|
|
|
|
"math/rand"
|
|
|
|
"net"
|
|
|
|
"net"
|
|
|
|
"net/http"
|
|
|
|
"net/http"
|
|
|
|
"net/rpc"
|
|
|
|
"net/rpc"
|
|
|
|
"strconv"
|
|
|
|
"strconv"
|
|
|
|
"strings"
|
|
|
|
"strings"
|
|
|
|
|
|
|
|
"sync"
|
|
|
|
"testing"
|
|
|
|
"testing"
|
|
|
|
"time"
|
|
|
|
"time"
|
|
|
|
|
|
|
|
|
|
|
@ -100,18 +102,22 @@ func (l lister) List() []client.Server {
|
|
|
|
return l
|
|
|
|
return l
|
|
|
|
}
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
|
|
func ClientTest(t *testing.T, c *client.Client) {
|
|
|
|
func testClient(t *testing.T, c *client.Client) {
|
|
|
|
selected := c.BeginInitParams()
|
|
|
|
selected := c.BeginInitParams()
|
|
|
|
if !selected {
|
|
|
|
if !selected {
|
|
|
|
t.Fatal("should be selected.")
|
|
|
|
t.Fatal("should be selected.")
|
|
|
|
}
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
|
|
const numParameter = 100
|
|
|
|
const numParameter = 1000
|
|
|
|
config, err := ioutil.ReadFile("./c/test/testdata/optimizer.pb")
|
|
|
|
config, err := ioutil.ReadFile("./c/test/testdata/optimizer.pb")
|
|
|
|
if err != nil {
|
|
|
|
if err != nil {
|
|
|
|
t.Fatalf("read optimizer proto failed")
|
|
|
|
t.Fatalf("read optimizer proto failed")
|
|
|
|
}
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
var wg sync.WaitGroup
|
|
|
|
for i := 0; i < numParameter; i++ {
|
|
|
|
for i := 0; i < numParameter; i++ {
|
|
|
|
|
|
|
|
wg.Add(1)
|
|
|
|
|
|
|
|
go func(i int) {
|
|
|
|
var p pserver.Parameter
|
|
|
|
var p pserver.Parameter
|
|
|
|
p.Name = "p_" + strconv.Itoa(i)
|
|
|
|
p.Name = "p_" + strconv.Itoa(i)
|
|
|
|
p.ElementType = pserver.Float32
|
|
|
|
p.ElementType = pserver.Float32
|
|
|
@ -120,7 +126,10 @@ func ClientTest(t *testing.T, c *client.Client) {
|
|
|
|
if err != nil {
|
|
|
|
if err != nil {
|
|
|
|
t.Fatal(err)
|
|
|
|
t.Fatal(err)
|
|
|
|
}
|
|
|
|
}
|
|
|
|
|
|
|
|
wg.Done()
|
|
|
|
|
|
|
|
}(i)
|
|
|
|
}
|
|
|
|
}
|
|
|
|
|
|
|
|
wg.Wait()
|
|
|
|
|
|
|
|
|
|
|
|
err = c.FinishInitParams()
|
|
|
|
err = c.FinishInitParams()
|
|
|
|
if err != nil {
|
|
|
|
if err != nil {
|
|
|
@ -128,7 +137,7 @@ func ClientTest(t *testing.T, c *client.Client) {
|
|
|
|
}
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
|
|
var grads []pserver.Gradient
|
|
|
|
var grads []pserver.Gradient
|
|
|
|
for i := 0; i < numParameter/2; i++ {
|
|
|
|
for i := 0; i < numParameter; i++ {
|
|
|
|
var g pserver.Gradient
|
|
|
|
var g pserver.Gradient
|
|
|
|
g.Name = "p_" + strconv.Itoa(i)
|
|
|
|
g.Name = "p_" + strconv.Itoa(i)
|
|
|
|
g.ElementType = pserver.Float32
|
|
|
|
g.ElementType = pserver.Float32
|
|
|
@ -136,30 +145,67 @@ func ClientTest(t *testing.T, c *client.Client) {
|
|
|
|
grads = append(grads, g)
|
|
|
|
grads = append(grads, g)
|
|
|
|
}
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
|
|
err = c.SendGrads(grads)
|
|
|
|
const paramPerGroup = 10
|
|
|
|
|
|
|
|
const numGroups = numParameter / paramPerGroup
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
// shuffle send grads order
|
|
|
|
|
|
|
|
for i := range grads {
|
|
|
|
|
|
|
|
j := rand.Intn(i + 1)
|
|
|
|
|
|
|
|
grads[i], grads[j] = grads[j], grads[i]
|
|
|
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
for i := 0; i < numGroups; i++ {
|
|
|
|
|
|
|
|
var gs []pserver.Gradient
|
|
|
|
|
|
|
|
if i == numGroups-1 {
|
|
|
|
|
|
|
|
gs = grads[i*paramPerGroup:]
|
|
|
|
|
|
|
|
} else {
|
|
|
|
|
|
|
|
gs = grads[i*paramPerGroup : (i+1)*paramPerGroup]
|
|
|
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
wg.Add(1)
|
|
|
|
|
|
|
|
go func(gs []pserver.Gradient) {
|
|
|
|
|
|
|
|
err = c.SendGrads(gs)
|
|
|
|
if err != nil {
|
|
|
|
if err != nil {
|
|
|
|
t.Fatal(err)
|
|
|
|
t.Fatal(err)
|
|
|
|
}
|
|
|
|
}
|
|
|
|
|
|
|
|
wg.Done()
|
|
|
|
|
|
|
|
}(gs)
|
|
|
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
|
|
names := make([]string, numParameter)
|
|
|
|
names := make([]string, numParameter)
|
|
|
|
for i := 0; i < numParameter; i++ {
|
|
|
|
for i := 0; i < numParameter; i++ {
|
|
|
|
names[i] = "p_" + strconv.Itoa(i)
|
|
|
|
names[i] = "p_" + strconv.Itoa(i)
|
|
|
|
}
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
|
|
params, err := c.GetParams(names)
|
|
|
|
for i := 0; i < numGroups; i++ {
|
|
|
|
|
|
|
|
var ns []string
|
|
|
|
|
|
|
|
if i == numGroups-1 {
|
|
|
|
|
|
|
|
ns = names[i*paramPerGroup:]
|
|
|
|
|
|
|
|
} else {
|
|
|
|
|
|
|
|
ns = names[i*paramPerGroup : (i+1)*paramPerGroup]
|
|
|
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
wg.Add(1)
|
|
|
|
|
|
|
|
go func(ns []string) {
|
|
|
|
|
|
|
|
params, err := c.GetParams(ns)
|
|
|
|
if err != nil {
|
|
|
|
if err != nil {
|
|
|
|
t.Fatal(err)
|
|
|
|
t.Fatal(err)
|
|
|
|
}
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
|
|
if len(names) != len(params) {
|
|
|
|
if len(ns) != len(params) {
|
|
|
|
t.Fatalf("parameter size not match, need: %d, have: %d", len(names), len(params))
|
|
|
|
t.Fatalf("parameter size not match, need: %d, have: %d", len(names), len(params))
|
|
|
|
}
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
|
|
for i := range params {
|
|
|
|
for i := range params {
|
|
|
|
if names[i] != params[i].Name {
|
|
|
|
if ns[i] != params[i].Name {
|
|
|
|
t.Fatalf("order of returned parameter does not required: parameter name: %s, required name: %s", names[i], params[i].Name)
|
|
|
|
t.Fatalf("order of returned parameter does not required: parameter name: %s, required name: %s", ns[i], params[i].Name)
|
|
|
|
|
|
|
|
}
|
|
|
|
}
|
|
|
|
}
|
|
|
|
|
|
|
|
wg.Done()
|
|
|
|
|
|
|
|
}(ns)
|
|
|
|
}
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
wg.Wait()
|
|
|
|
}
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
|
|
func TestNativeClient(t *testing.T) {
|
|
|
|
func TestNativeClient(t *testing.T) {
|
|
|
@ -169,13 +215,14 @@ func TestNativeClient(t *testing.T) {
|
|
|
|
servers[i] = client.Server{Index: i, Addr: ":" + strconv.Itoa(pserverClientPorts[i])}
|
|
|
|
servers[i] = client.Server{Index: i, Addr: ":" + strconv.Itoa(pserverClientPorts[i])}
|
|
|
|
}
|
|
|
|
}
|
|
|
|
c1 := client.NewClient(lister(servers), len(servers), selector(true))
|
|
|
|
c1 := client.NewClient(lister(servers), len(servers), selector(true))
|
|
|
|
ClientTest(t, c1)
|
|
|
|
testClient(t, c1)
|
|
|
|
}
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
|
|
// TODO: tmperary disable etcdClient test for dependency of etcd)
|
|
|
|
// EtcdClient is a disabled test, since we have not embedded etcd into
|
|
|
|
|
|
|
|
// our test.
|
|
|
|
func EtcdClient(t *testing.T) {
|
|
|
|
func EtcdClient(t *testing.T) {
|
|
|
|
initEtcdClient()
|
|
|
|
initEtcdClient()
|
|
|
|
etcdClient := client.NewEtcd(etcdEndpoints)
|
|
|
|
etcdClient := client.NewEtcd(etcdEndpoints)
|
|
|
|
c2 := client.NewClient(etcdClient, etcdClient.Desired(), selector(true))
|
|
|
|
c2 := client.NewClient(etcdClient, etcdClient.Desired(), selector(true))
|
|
|
|
ClientTest(t, c2)
|
|
|
|
testClient(t, c2)
|
|
|
|
}
|
|
|
|
}
|
|
|
|