You can not select more than 25 topics
Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
166 lines
3.0 KiB
166 lines
3.0 KiB
8 years ago
|
package pserver_test
|
||
|
|
||
|
import (
|
||
|
"reflect"
|
||
|
"sync"
|
||
|
"testing"
|
||
|
|
||
8 years ago
|
"github.com/PaddlePaddle/Paddle/go/pserver"
|
||
8 years ago
|
)
|
||
|
|
||
|
func TestFull(t *testing.T) {
|
||
|
s := pserver.NewService()
|
||
|
var dummy int
|
||
|
err := s.BeginInitParams(nil, &dummy)
|
||
|
if err != nil {
|
||
|
t.FailNow()
|
||
|
}
|
||
|
|
||
|
var p pserver.Parameter
|
||
|
p.Name = "param_a"
|
||
|
p.Content = []byte{1, 0, 0, 0, 2, 0, 0, 0, 3, 0, 0, 0}
|
||
|
p.ElementType = pserver.Int32
|
||
|
err = s.InitParam(pserver.ParameterWithConfig{p, nil}, &dummy)
|
||
|
if err != nil {
|
||
|
t.FailNow()
|
||
|
}
|
||
|
|
||
|
var p1 pserver.Parameter
|
||
|
p1.Name = "param_b"
|
||
|
p1.Content = []byte{0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0}
|
||
|
p1.ElementType = pserver.Float32
|
||
|
err = s.InitParam(pserver.ParameterWithConfig{p1, nil}, &dummy)
|
||
|
if err != nil {
|
||
|
t.FailNow()
|
||
|
}
|
||
|
|
||
|
err = s.FinishInitParams(0, &dummy)
|
||
|
if err != nil {
|
||
|
t.FailNow()
|
||
|
}
|
||
|
|
||
|
var params []pserver.Parameter
|
||
|
err = s.GetParams([]string{"param_b", "param_a"}, ¶ms)
|
||
|
if err != nil {
|
||
|
t.FailNow()
|
||
|
}
|
||
|
|
||
|
if len(params) != 2 || !reflect.DeepEqual(params[0], p1) || !reflect.DeepEqual(params[0], p1) {
|
||
|
t.FailNow()
|
||
|
}
|
||
|
|
||
|
grads := []pserver.Gradient{pserver.Gradient(p1), pserver.Gradient(p)}
|
||
|
err = s.SendGrads(grads, &dummy)
|
||
|
if err != nil {
|
||
|
t.FailNow()
|
||
|
}
|
||
|
|
||
|
var params1 []pserver.Parameter
|
||
|
err = s.GetParams([]string{"param_b", "param_a"}, ¶ms1)
|
||
|
if err != nil {
|
||
|
t.FailNow()
|
||
|
}
|
||
|
|
||
|
if len(params) != 2 {
|
||
|
t.FailNow()
|
||
|
}
|
||
|
|
||
8 years ago
|
// don't compare content, since it's already changed by
|
||
|
// gradient update.
|
||
8 years ago
|
params1[0].Content = nil
|
||
|
params1[0].Content = nil
|
||
|
p.Content = nil
|
||
|
p1.Content = nil
|
||
|
|
||
|
if !reflect.DeepEqual(params1[0], p1) || !reflect.DeepEqual(params1[0], p1) {
|
||
|
t.FailNow()
|
||
|
}
|
||
|
}
|
||
|
|
||
|
func TestMultipleInit(t *testing.T) {
|
||
|
s := pserver.NewService()
|
||
|
var dummy int
|
||
|
err := s.BeginInitParams(nil, &dummy)
|
||
|
if err != nil {
|
||
|
t.FailNow()
|
||
|
}
|
||
|
|
||
|
// this is fine, it's possible for client to call init
|
||
|
// multiple times.
|
||
|
err = s.BeginInitParams(nil, &dummy)
|
||
|
if err != nil {
|
||
|
t.FailNow()
|
||
|
}
|
||
|
|
||
|
err = s.FinishInitParams(0, &dummy)
|
||
|
if err != nil {
|
||
|
t.FailNow()
|
||
|
}
|
||
|
|
||
|
err = s.FinishInitParams(0, &dummy)
|
||
8 years ago
|
if err != pserver.ErrAlreadyInitialized {
|
||
8 years ago
|
t.FailNow()
|
||
|
}
|
||
|
|
||
|
err = s.BeginInitParams(nil, &dummy)
|
||
8 years ago
|
if err != pserver.ErrAlreadyInitialized {
|
||
8 years ago
|
t.FailNow()
|
||
|
}
|
||
|
}
|
||
|
|
||
8 years ago
|
func TestUninitialized(t *testing.T) {
|
||
|
s := pserver.NewService()
|
||
|
var dummy int
|
||
|
err := s.SendGrads(nil, &dummy)
|
||
|
if err != pserver.ErrUninitialized {
|
||
|
t.FailNow()
|
||
|
}
|
||
|
}
|
||
|
|
||
8 years ago
|
func TestBlockUntilInitialized(t *testing.T) {
|
||
|
s := pserver.NewService()
|
||
8 years ago
|
ch := make(chan struct{}, 2)
|
||
8 years ago
|
var wg sync.WaitGroup
|
||
|
wg.Add(1)
|
||
|
go func() {
|
||
|
var params []pserver.Parameter
|
||
|
err := s.GetParams(nil, ¶ms)
|
||
|
if err != nil {
|
||
|
t.FailNow()
|
||
|
}
|
||
|
wg.Done()
|
||
8 years ago
|
ch <- struct{}{}
|
||
8 years ago
|
}()
|
||
|
|
||
|
wg.Add(1)
|
||
|
go func() {
|
||
|
var dummy int
|
||
8 years ago
|
err := s.Save("", &dummy)
|
||
8 years ago
|
if err != nil {
|
||
|
t.FailNow()
|
||
|
}
|
||
|
wg.Done()
|
||
8 years ago
|
ch <- struct{}{}
|
||
8 years ago
|
}()
|
||
|
|
||
|
var dummy int
|
||
|
err := s.BeginInitParams(nil, &dummy)
|
||
|
if err != nil {
|
||
|
t.FailNow()
|
||
|
}
|
||
|
|
||
8 years ago
|
select {
|
||
|
case <-ch:
|
||
|
// some function returned before initialization is completed.
|
||
|
t.FailNow()
|
||
|
default:
|
||
|
}
|
||
|
|
||
8 years ago
|
err = s.FinishInitParams(0, &dummy)
|
||
|
if err != nil {
|
||
|
t.FailNow()
|
||
|
}
|
||
|
|
||
|
wg.Wait()
|
||
|
}
|