|
|
|
@ -3,11 +3,13 @@ package client_test
|
|
|
|
|
import (
|
|
|
|
|
"context"
|
|
|
|
|
"io/ioutil"
|
|
|
|
|
"math/rand"
|
|
|
|
|
"net"
|
|
|
|
|
"net/http"
|
|
|
|
|
"net/rpc"
|
|
|
|
|
"strconv"
|
|
|
|
|
"strings"
|
|
|
|
|
"sync"
|
|
|
|
|
"testing"
|
|
|
|
|
"time"
|
|
|
|
|
|
|
|
|
@ -111,16 +113,23 @@ func testClient(t *testing.T, c *client.Client) {
|
|
|
|
|
if err != nil {
|
|
|
|
|
t.Fatalf("read optimizer proto failed")
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
var wg sync.WaitGroup
|
|
|
|
|
for i := 0; i < numParameter; i++ {
|
|
|
|
|
var p pserver.Parameter
|
|
|
|
|
p.Name = "p_" + strconv.Itoa(i)
|
|
|
|
|
p.ElementType = pserver.Float32
|
|
|
|
|
p.Content = make([]byte, (i+1)*100)
|
|
|
|
|
err := c.InitParam(pserver.ParameterWithConfig{Param: p, Config: config})
|
|
|
|
|
if err != nil {
|
|
|
|
|
t.Fatal(err)
|
|
|
|
|
}
|
|
|
|
|
wg.Add(1)
|
|
|
|
|
go func(i int) {
|
|
|
|
|
var p pserver.Parameter
|
|
|
|
|
p.Name = "p_" + strconv.Itoa(i)
|
|
|
|
|
p.ElementType = pserver.Float32
|
|
|
|
|
p.Content = make([]byte, (i+1)*100)
|
|
|
|
|
err := c.InitParam(pserver.ParameterWithConfig{Param: p, Config: config})
|
|
|
|
|
if err != nil {
|
|
|
|
|
t.Fatal(err)
|
|
|
|
|
}
|
|
|
|
|
wg.Done()
|
|
|
|
|
}(i)
|
|
|
|
|
}
|
|
|
|
|
wg.Wait()
|
|
|
|
|
|
|
|
|
|
err = c.FinishInitParams()
|
|
|
|
|
if err != nil {
|
|
|
|
@ -136,9 +145,31 @@ func testClient(t *testing.T, c *client.Client) {
|
|
|
|
|
grads = append(grads, g)
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
err = c.SendGrads(grads)
|
|
|
|
|
if err != nil {
|
|
|
|
|
t.Fatal(err)
|
|
|
|
|
const paramPerGroup = 10
|
|
|
|
|
const numGroups = numParameter / paramPerGroup
|
|
|
|
|
|
|
|
|
|
// shuffle send grads order
|
|
|
|
|
for i := range grads {
|
|
|
|
|
j := rand.Intn(i + 1)
|
|
|
|
|
grads[i], grads[j] = grads[j], grads[i]
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
for i := 0; i < numGroups; i++ {
|
|
|
|
|
var gs []pserver.Gradient
|
|
|
|
|
if i == numGroups-1 {
|
|
|
|
|
gs = grads[i*paramPerGroup:]
|
|
|
|
|
} else {
|
|
|
|
|
gs = grads[i*paramPerGroup : (i+1)*paramPerGroup]
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
wg.Add(1)
|
|
|
|
|
go func(gs []pserver.Gradient) {
|
|
|
|
|
err = c.SendGrads(gs)
|
|
|
|
|
if err != nil {
|
|
|
|
|
t.Fatal(err)
|
|
|
|
|
}
|
|
|
|
|
wg.Done()
|
|
|
|
|
}(gs)
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
names := make([]string, numParameter)
|
|
|
|
@ -146,20 +177,35 @@ func testClient(t *testing.T, c *client.Client) {
|
|
|
|
|
names[i] = "p_" + strconv.Itoa(i)
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
params, err := c.GetParams(names)
|
|
|
|
|
if err != nil {
|
|
|
|
|
t.Fatal(err)
|
|
|
|
|
}
|
|
|
|
|
for i := 0; i < numGroups; i++ {
|
|
|
|
|
var ns []string
|
|
|
|
|
if i == numGroups-1 {
|
|
|
|
|
ns = names[i*paramPerGroup:]
|
|
|
|
|
} else {
|
|
|
|
|
ns = names[i*paramPerGroup : (i+1)*paramPerGroup]
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
if len(names) != len(params) {
|
|
|
|
|
t.Fatalf("parameter size not match, need: %d, have: %d", len(names), len(params))
|
|
|
|
|
}
|
|
|
|
|
wg.Add(1)
|
|
|
|
|
go func(ns []string) {
|
|
|
|
|
params, err := c.GetParams(ns)
|
|
|
|
|
if err != nil {
|
|
|
|
|
t.Fatal(err)
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
for i := range params {
|
|
|
|
|
if names[i] != params[i].Name {
|
|
|
|
|
t.Fatalf("order of returned parameter does not required: parameter name: %s, required name: %s", names[i], params[i].Name)
|
|
|
|
|
}
|
|
|
|
|
if len(ns) != len(params) {
|
|
|
|
|
t.Fatalf("parameter size not match, need: %d, have: %d", len(names), len(params))
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
for i := range params {
|
|
|
|
|
if ns[i] != params[i].Name {
|
|
|
|
|
t.Fatalf("order of returned parameter does not required: parameter name: %s, required name: %s", ns[i], params[i].Name)
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
wg.Done()
|
|
|
|
|
}(ns)
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
wg.Wait()
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
func TestNativeClient(t *testing.T) {
|
|
|
|
|