@ -3,11 +3,13 @@ package client_test
import (
"context"
"io/ioutil"
"math/rand"
"net"
"net/http"
"net/rpc"
"strconv"
"strings"
"sync"
"testing"
"time"
@ -77,15 +79,33 @@ func initEtcdClient() {
log . Errorf ( "err %v" , err )
}
ctx , cancel := context . WithTimeout ( context . Background ( ) , timeout )
client . Delete ( ctx , pserver . PsDesired )
client . Delete ( ctx , pserver . PsPath )
client . Put ( ctx , pserver . PsDesired , strconv . Itoa ( numPserver ) )
_ , err = client . Delete ( ctx , pserver . PsDesired )
if err != nil {
panic ( err )
}
_ , err = client . Delete ( ctx , pserver . PsPath )
if err != nil {
panic ( err )
}
_ , err = client . Put ( ctx , pserver . PsDesired , strconv . Itoa ( numPserver ) )
if err != nil {
panic ( err )
}
ports := initClient ( )
for i := 0 ; i < numPserver ; i ++ {
client . Put ( ctx , pserver . PsPath + strconv . Itoa ( i ) , ":" + strconv . Itoa ( ports [ i ] ) )
_ , err = client . Put ( ctx , pserver . PsPath + strconv . Itoa ( i ) , ":" + strconv . Itoa ( ports [ i ] ) )
if err != nil {
panic ( err )
}
}
cancel ( )
client . Close ( )
err = client . Close ( )
if err != nil {
panic ( err )
}
}
type selector bool
@ -100,27 +120,34 @@ func (l lister) List() []client.Server {
return l
}
func ClientTes t( t * testing . T , c * client . Client ) {
func test Client( t * testing . T , c * client . Client ) {
selected := c . BeginInitParams ( )
if ! selected {
t . Fatal ( "should be selected." )
}
const numParameter = 100
const numParameter = 100 0
config , err := ioutil . ReadFile ( "./c/test/testdata/optimizer.pb" )
if err != nil {
t . Fatalf ( "read optimizer proto failed" )
}
var wg sync . WaitGroup
for i := 0 ; i < numParameter ; i ++ {
var p pserver . Parameter
p . Name = "p_" + strconv . Itoa ( i )
p . ElementType = pserver . Float32
p . Content = make ( [ ] byte , ( i + 1 ) * 100 )
err := c . InitParam ( pserver . ParameterWithConfig { Param : p , Config : config } )
if err != nil {
t . Fatal ( err )
}
wg . Add ( 1 )
go func ( i int ) {
var p pserver . Parameter
p . Name = "p_" + strconv . Itoa ( i )
p . ElementType = pserver . Float32
p . Content = make ( [ ] byte , ( i + 1 ) * 100 )
err := c . InitParam ( pserver . ParameterWithConfig { Param : p , Config : config } )
if err != nil {
t . Fatal ( err )
}
wg . Done ( )
} ( i )
}
wg . Wait ( )
err = c . FinishInitParams ( )
if err != nil {
@ -128,7 +155,7 @@ func ClientTest(t *testing.T, c *client.Client) {
}
var grads [ ] pserver . Gradient
for i := 0 ; i < numParameter / 2 ; i ++ {
for i := 0 ; i < numParameter ; i ++ {
var g pserver . Gradient
g . Name = "p_" + strconv . Itoa ( i )
g . ElementType = pserver . Float32
@ -136,9 +163,31 @@ func ClientTest(t *testing.T, c *client.Client) {
grads = append ( grads , g )
}
err = c . SendGrads ( grads )
if err != nil {
t . Fatal ( err )
const paramPerGroup = 10
const numGroups = numParameter / paramPerGroup
// shuffle send grads order
for i := range grads {
j := rand . Intn ( i + 1 )
grads [ i ] , grads [ j ] = grads [ j ] , grads [ i ]
}
for i := 0 ; i < numGroups ; i ++ {
var gs [ ] pserver . Gradient
if i == numGroups - 1 {
gs = grads [ i * paramPerGroup : ]
} else {
gs = grads [ i * paramPerGroup : ( i + 1 ) * paramPerGroup ]
}
wg . Add ( 1 )
go func ( gs [ ] pserver . Gradient ) {
err := c . SendGrads ( gs )
if err != nil {
t . Fatal ( err )
}
wg . Done ( )
} ( gs )
}
names := make ( [ ] string , numParameter )
@ -146,20 +195,35 @@ func ClientTest(t *testing.T, c *client.Client) {
names [ i ] = "p_" + strconv . Itoa ( i )
}
params , err := c . GetParams ( names )
if err != nil {
t . Fatal ( err )
}
for i := 0 ; i < numGroups ; i ++ {
var ns [ ] string
if i == numGroups - 1 {
ns = names [ i * paramPerGroup : ]
} else {
ns = names [ i * paramPerGroup : ( i + 1 ) * paramPerGroup ]
}
if len ( names ) != len ( params ) {
t . Fatalf ( "parameter size not match, need: %d, have: %d" , len ( names ) , len ( params ) )
}
wg . Add ( 1 )
go func ( ns [ ] string ) {
params , err := c . GetParams ( ns )
if err != nil {
t . Fatal ( err )
}
for i := range params {
if names [ i ] != params [ i ] . Name {
t . Fatalf ( "order of returned parameter does not required: parameter name: %s, required name: %s" , names [ i ] , params [ i ] . Name )
}
if len ( ns ) != len ( params ) {
t . Fatalf ( "parameter size not match, need: %d, have: %d" , len ( names ) , len ( params ) )
}
for i := range params {
if ns [ i ] != params [ i ] . Name {
t . Fatalf ( "order of returned parameter does not required: parameter name: %s, required name: %s" , ns [ i ] , params [ i ] . Name )
}
}
wg . Done ( )
} ( ns )
}
wg . Wait ( )
}
func TestNativeClient ( t * testing . T ) {
@ -169,13 +233,14 @@ func TestNativeClient(t *testing.T) {
servers [ i ] = client . Server { Index : i , Addr : ":" + strconv . Itoa ( pserverClientPorts [ i ] ) }
}
c1 := client . NewClient ( lister ( servers ) , len ( servers ) , selector ( true ) )
ClientTes t( t , c1 )
test Client( t , c1 )
}
// TODO: tmperary disable etcdClient test for dependency of etcd)
// EtcdClient is a disabled test, since we have not embedded etcd into
// our test.
func EtcdClient ( t * testing . T ) {
initEtcdClient ( )
etcdClient := client . NewEtcd ( etcdEndpoints )
c2 := client . NewClient ( etcdClient , etcdClient . Desired ( ) , selector ( true ) )
ClientTes t( t , c2 )
test Client( t , c2 )
}