Merge branch 'develop' of https://github.com/PaddlePaddle/Paddle into variable_input
commit
7430d30598
@ -0,0 +1,38 @@
|
||||
FROM ubuntu:16.04
|
||||
MAINTAINER PaddlePaddle Authors <paddle-dev@baidu.com>
|
||||
|
||||
ARG UBUNTU_MIRROR
|
||||
RUN /bin/bash -c 'if [[ -n ${UBUNTU_MIRROR} ]]; then sed -i 's#http://archive.ubuntu.com/ubuntu#${UBUNTU_MIRROR}#g' /etc/apt/sources.list; fi'
|
||||
|
||||
ENV HOME=/root \
|
||||
ANDROID_NDK_HOME=/opt/android-ndk-linux \
|
||||
ANDROID_STANDALONE_TOOLCHAIN=/opt/android-toolchain-gcc
|
||||
|
||||
RUN apt-get update && \
|
||||
apt-get install -y \
|
||||
git python-dev python-pip python-numpy \
|
||||
wget curl tar unzip gcc g++ locales clang-format-3.8 swig cmake && \
|
||||
apt-get clean -y
|
||||
|
||||
# git credential to skip password typing
|
||||
RUN git config --global credential.helper store
|
||||
|
||||
# Fix locales to en_US.UTF-8
|
||||
RUN localedef -i en_US -f UTF-8 en_US.UTF-8
|
||||
|
||||
RUN pip install --upgrade pip && \
|
||||
pip install -U 'protobuf==3.1.0' && \
|
||||
pip install -U wheel sphinx && \
|
||||
pip install pre-commit
|
||||
|
||||
# Android NDK
|
||||
RUN mkdir /opt/android-ndk-tmp && \
|
||||
cd /opt/android-ndk-tmp && \
|
||||
wget -q https://dl.google.com/android/repository/android-ndk-r14b-linux-x86_64.zip && \
|
||||
unzip -q android-ndk-r14b-linux-x86_64.zip && \
|
||||
mv android-ndk-r14b ${ANDROID_NDK_HOME} && \
|
||||
${ANDROID_NDK_HOME}/build/tools/make-standalone-toolchain.sh --arch=arm --platform=android-21 --install-dir=${ANDROID_STANDALONE_TOOLCHAIN} && \
|
||||
rm -rf /opt/android-ndk-tmp && \
|
||||
rm -rf ${ANDROID_NDK_HOME}
|
||||
|
||||
CMD ["bash", "/paddle/paddle/scripts/docker/build_android.sh"]
|
@ -0,0 +1,46 @@
|
||||
if(NOT CMAKE_Go_COMPILER)
|
||||
if(NOT $ENV{GO_COMPILER} STREQUAL "")
|
||||
get_filename_component(CMAKE_Go_COMPILER_INIT $ENV{GO_COMPILER} PROGRAM PROGRAM_ARGS CMAKE_Go_FLAGS_ENV_INIT)
|
||||
|
||||
if(CMAKE_Go_FLAGS_ENV_INIT)
|
||||
set(CMAKE_Go_COMPILER_ARG1 "${CMAKE_Go_FLAGS_ENV_INIT}" CACHE STRING "First argument to Go compiler")
|
||||
endif()
|
||||
|
||||
if(NOT EXISTS ${CMAKE_Go_COMPILER_INIT})
|
||||
message(SEND_ERROR "Could not find compiler set in environment variable GO_COMPILER:\n$ENV{GO_COMPILER}.")
|
||||
endif()
|
||||
|
||||
endif()
|
||||
|
||||
set(Go_BIN_PATH
|
||||
$ENV{GOPATH}
|
||||
$ENV{GOROOT}
|
||||
$ENV{GOROOT}/bin
|
||||
$ENV{GO_COMPILER}
|
||||
/usr/bin
|
||||
/usr/local/bin
|
||||
)
|
||||
|
||||
if(CMAKE_Go_COMPILER_INIT)
|
||||
set(CMAKE_Go_COMPILER ${CMAKE_Go_COMPILER_INIT} CACHE PATH "Go Compiler")
|
||||
else()
|
||||
find_program(CMAKE_Go_COMPILER
|
||||
NAMES go
|
||||
PATHS ${Go_BIN_PATH}
|
||||
)
|
||||
if(CMAKE_Go_COMPILER)
|
||||
EXEC_PROGRAM(${CMAKE_Go_COMPILER} ARGS version OUTPUT_VARIABLE GOLANG_VERSION)
|
||||
STRING(REGEX MATCH "go[0-9]+[.0-9]*[ /A-Za-z0-9]*" VERSION "${GOLANG_VERSION}")
|
||||
message("-- The Golang compiler identification is ${VERSION}")
|
||||
message("-- Check for working Golang compiler: ${CMAKE_Go_COMPILER}")
|
||||
endif()
|
||||
endif()
|
||||
|
||||
endif()
|
||||
|
||||
mark_as_advanced(CMAKE_Go_COMPILER)
|
||||
|
||||
configure_file(${CMAKE_CURRENT_SOURCE_DIR}/cmake/CMakeGoCompiler.cmake.in
|
||||
${CMAKE_PLATFORM_INFO_DIR}/CMakeGoCompiler.cmake @ONLY)
|
||||
|
||||
set(CMAKE_Go_COMPILER_ENV_VAR "GO_COMPILER")
|
@ -0,0 +1,8 @@
|
||||
set(CMAKE_Go_COMPILER "@CMAKE_Go_COMPILER@")
|
||||
set(CMAKE_Go_COMPILER_LOADED 1)
|
||||
|
||||
set(CMAKE_Go_SOURCE_FILE_EXTENSIONS go)
|
||||
set(CMAKE_Go_LINKER_PREFERENCE 40)
|
||||
set(CMAKE_Go_OUTPUT_EXTENSION .o)
|
||||
set(CMAKE_Go_OUTPUT_EXTENSION_REPLACE 1)
|
||||
set(CMAKE_Go_COMPILER_ENV_VAR "GO_COMPILER")
|
@ -0,0 +1,7 @@
|
||||
if(NOT CMAKE_Go_COMPILE_OBJECT)
|
||||
set(CMAKE_Go_COMPILE_OBJECT "go tool compile -l -N -o <OBJECT> <SOURCE> ")
|
||||
endif()
|
||||
|
||||
if(NOT CMAKE_Go_LINK_EXECUTABLE)
|
||||
set(CMAKE_Go_LINK_EXECUTABLE "go tool link -o <TARGET> <OBJECTS> ")
|
||||
endif()
|
@ -0,0 +1 @@
|
||||
set(CMAKE_Go_COMPILER_WORKS 1 CACHE INTERNAL "")
|
@ -0,0 +1,9 @@
|
||||
include_directories(${CMAKE_CURRENT_BINARY_DIR})
|
||||
|
||||
go_library(adder SRCS adder.go)
|
||||
|
||||
cc_test(cgo_test
|
||||
SRCS
|
||||
cgo_test.cc
|
||||
DEPS
|
||||
adder)
|
@ -0,0 +1,10 @@
|
||||
package main
|
||||
|
||||
import "C"
|
||||
|
||||
//export GoAdder
|
||||
func GoAdder(x, y int) int {
|
||||
return x + y
|
||||
}
|
||||
|
||||
func main() {} // Required but ignored
|
@ -1,8 +1,8 @@
|
||||
cmake_minimum_required(VERSION 3.0)
|
||||
|
||||
include_directories(/env/gopath/src/github.com/PaddlePaddle/Paddle/paddle/go/cclient/build/)
|
||||
include_directories(${CMAKE_BINARY_DIR})
|
||||
|
||||
add_executable(main main.c)
|
||||
add_dependencies(main client)
|
||||
set (CMAKE_EXE_LINKER_FLAGS "-pthread")
|
||||
target_link_libraries(main /env/gopath/src/github.com/PaddlePaddle/Paddle/paddle/go/cclient/build/libclient.a) # ${GTEST_LIBRARIES})
|
||||
target_link_libraries(main ${CMAKE_BINARY_DIR}/libclient.a)
|
||||
|
@ -0,0 +1,5 @@
|
||||
#include <iostream>
|
||||
#include "gtest/gtest.h"
|
||||
#include "libadder.h"
|
||||
|
||||
TEST(Cgo, Invoke) { EXPECT_EQ(GoAdder(30, 12), 42); }
|
@ -0,0 +1 @@
|
||||
pserver
|
@ -0,0 +1,33 @@
|
||||
package main
|
||||
|
||||
import (
|
||||
"flag"
|
||||
"net"
|
||||
"net/http"
|
||||
"net/rpc"
|
||||
"strconv"
|
||||
|
||||
"github.com/PaddlePaddle/Paddle/paddle/go/pserver"
|
||||
)
|
||||
|
||||
func main() {
|
||||
port := flag.Int("p", 0, "port of the pserver")
|
||||
flag.Parse()
|
||||
|
||||
s := pserver.NewService()
|
||||
err := rpc.Register(s)
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
|
||||
rpc.HandleHTTP()
|
||||
l, err := net.Listen("tcp", ":"+strconv.Itoa(*port))
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
|
||||
err = http.Serve(l, nil)
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
}
|
@ -0,0 +1,52 @@
|
||||
#include <stdlib.h>
|
||||
|
||||
#include "optimizer.h"
|
||||
|
||||
typedef int (*update_func)(void*, void*, paddle_element_type, const void*, int);
|
||||
typedef void (*release_func)(void*);
|
||||
|
||||
typedef struct paddle_optimizer {
|
||||
update_func update;
|
||||
release_func release;
|
||||
void* optimizer;
|
||||
} paddle_optimizer;
|
||||
|
||||
void paddle_release_optimizer(paddle_optimizer* o) {
|
||||
o->release(o->optimizer);
|
||||
free(o);
|
||||
}
|
||||
|
||||
int paddle_update_parameter(paddle_optimizer* o,
|
||||
void* buffer,
|
||||
paddle_element_type element_type,
|
||||
const void* gradient,
|
||||
int num_bytes) {
|
||||
return o->update(o->optimizer, buffer, element_type, gradient, num_bytes);
|
||||
}
|
||||
|
||||
typedef struct { double learning_rate; } SGD_optimizer;
|
||||
|
||||
int update_SGD(void* optimizer,
|
||||
void* buffer,
|
||||
paddle_element_type element_type,
|
||||
const void* gradient,
|
||||
int num_bytes) {
|
||||
SGD_optimizer* o = (SGD_optimizer*)optimizer;
|
||||
// TODO
|
||||
return 0;
|
||||
}
|
||||
|
||||
void release_SGD(void* optimizer) {
|
||||
SGD_optimizer* o = (SGD_optimizer*)optimizer;
|
||||
// nothing allocated on heap
|
||||
}
|
||||
|
||||
paddle_optimizer* paddle_create_SGD_optimizer(double learning_rate) {
|
||||
SGD_optimizer* impl = (SGD_optimizer*)malloc(sizeof(SGD_optimizer));
|
||||
impl->learning_rate = learning_rate;
|
||||
paddle_optimizer* opt = (paddle_optimizer*)malloc(sizeof(paddle_optimizer));
|
||||
opt->update = update_SGD;
|
||||
opt->release = release_SGD;
|
||||
opt->optimizer = impl;
|
||||
return opt;
|
||||
}
|
@ -0,0 +1,51 @@
|
||||
package pserver
|
||||
|
||||
/*
|
||||
#include "optimizer.h"
|
||||
*/
|
||||
import "C"
|
||||
import (
|
||||
"fmt"
|
||||
"unsafe"
|
||||
)
|
||||
|
||||
type optimizerType int
|
||||
|
||||
const (
|
||||
sgd optimizerType = iota
|
||||
)
|
||||
|
||||
var nullPtr = unsafe.Pointer(uintptr(0))
|
||||
|
||||
type optimizer struct {
|
||||
opt *C.struct_paddle_optimizer
|
||||
}
|
||||
|
||||
func newOptimizer(t optimizerType, learning_rate float64) *optimizer {
|
||||
o := &optimizer{}
|
||||
o.opt = C.paddle_create_SGD_optimizer(C.double(learning_rate))
|
||||
return o
|
||||
}
|
||||
|
||||
func (o *optimizer) UpdateParameter(p Parameter, g Gradient) error {
|
||||
if len(p.Content) != len(g.Content) {
|
||||
return fmt.Errorf("parameter and gradient length not match, parameter: %d, gradient: %d", len(p.Content), len(g.Content))
|
||||
}
|
||||
|
||||
if p.ElementType != g.ElementType {
|
||||
return fmt.Errorf("parameter and gradient element type not match, parameter: %v, gradient: %v", p.ElementType, g.ElementType)
|
||||
}
|
||||
|
||||
r := C.paddle_update_parameter(o.opt, unsafe.Pointer(&p.Content[0]), C.paddle_element_type(p.ElementType), unsafe.Pointer(&g.Content[0]), C.int(len(g.Content)))
|
||||
if r != 0 {
|
||||
return fmt.Errorf("optimizer update returned error code: %d", r)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (o *optimizer) Cleanup() {
|
||||
if unsafe.Pointer(o.opt) != nullPtr {
|
||||
C.paddle_release_optimizer(o.opt)
|
||||
o.opt = (*C.struct_paddle_optimizer)(nullPtr)
|
||||
}
|
||||
}
|
@ -0,0 +1,22 @@
|
||||
#ifndef PADDLE_PSERVER_OPTIMIZER_H
|
||||
#define PADDLE_PSERVER_OPTIMIZER_H
|
||||
|
||||
typedef enum {
|
||||
PADDLE_ELEMENT_TYPE_INT32 = 0,
|
||||
PADDLE_ELEMENT_TYPE_UINT32 = 1,
|
||||
PADDLE_ELEMENT_TYPE_INT64 = 2,
|
||||
PADDLE_ELEMENT_TYPE_UINT64 = 3,
|
||||
PADDLE_ELEMENT_TYPE_FLOAT32 = 4,
|
||||
PADDLE_ELEMENT_TYPE_FLOAT64 = 5,
|
||||
} paddle_element_type;
|
||||
|
||||
struct paddle_optimizer;
|
||||
struct paddle_optimizer* paddle_create_SGD_optimizer(double learning_rate);
|
||||
void paddle_release_optimizer(struct paddle_optimizer* o);
|
||||
int paddle_update_parameter(struct paddle_optimizer* o,
|
||||
void* buffer,
|
||||
paddle_element_type element_type,
|
||||
const void* gradient,
|
||||
int num_bytes);
|
||||
|
||||
#endif /* PADDLE_PSERVER_OPTIMIZER_H */
|
@ -0,0 +1,8 @@
|
||||
package pserver
|
||||
|
||||
import "testing"
|
||||
|
||||
func TestSGDCreateRelease(t *testing.T) {
|
||||
o := newOptimizer(sgd, 1)
|
||||
o.Cleanup()
|
||||
}
|
@ -0,0 +1,190 @@
|
||||
package pserver
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"fmt"
|
||||
"sync"
|
||||
)
|
||||
|
||||
// ElementType is the type of elements of a Parameter.
|
||||
type ElementType int
|
||||
|
||||
var ErrAlreadyInitialized = errors.New("pserver already initialized")
|
||||
var ErrUninitialized = errors.New("pserver not fully initialized")
|
||||
|
||||
// Supported element types
|
||||
const (
|
||||
Int32 ElementType = iota
|
||||
UInt32
|
||||
Int64
|
||||
UInt64
|
||||
Float32
|
||||
Float64
|
||||
)
|
||||
|
||||
// Parameter is a piece of data to sync with the parameter server.
|
||||
type Parameter struct {
|
||||
Name string
|
||||
ElementType ElementType
|
||||
Content []byte
|
||||
}
|
||||
|
||||
// ParameterWithConfig contains the parameter and the configuration.
|
||||
type ParameterWithConfig struct {
|
||||
Param Parameter
|
||||
Config []byte // parameter configuration in Proto Buffer format
|
||||
}
|
||||
|
||||
// Gradient is the gradient of the parameter.
|
||||
type Gradient Parameter
|
||||
|
||||
// Service is the RPC service for pserver.
|
||||
type Service struct {
|
||||
initialized chan struct{}
|
||||
|
||||
mu sync.Mutex
|
||||
opt *optimizer
|
||||
paramMap map[string]Parameter
|
||||
}
|
||||
|
||||
// NewService creates a new service.
|
||||
func NewService() *Service {
|
||||
s := &Service{}
|
||||
s.paramMap = make(map[string]Parameter)
|
||||
s.initialized = make(chan struct{})
|
||||
return s
|
||||
}
|
||||
|
||||
// BeginInitParams tells the parameter server that the parameter
|
||||
// initialization has begun.
|
||||
func (s *Service) BeginInitParams(config []byte, dummy *int) error {
|
||||
select {
|
||||
case <-s.initialized:
|
||||
return ErrAlreadyInitialized
|
||||
default:
|
||||
}
|
||||
|
||||
s.mu.Lock()
|
||||
defer s.mu.Unlock()
|
||||
|
||||
if s.opt != nil {
|
||||
s.opt.Cleanup()
|
||||
}
|
||||
|
||||
// TODO(helin): parse learning rate from config
|
||||
s.opt = newOptimizer(sgd, 0.01)
|
||||
return nil
|
||||
}
|
||||
|
||||
// InitParam initializes a parameter.
|
||||
func (s *Service) InitParam(paramWithConfigs ParameterWithConfig, dummy *int) error {
|
||||
select {
|
||||
case <-s.initialized:
|
||||
return ErrAlreadyInitialized
|
||||
default:
|
||||
}
|
||||
|
||||
// TODO(helin): parse parameter config
|
||||
|
||||
s.mu.Lock()
|
||||
defer s.mu.Unlock()
|
||||
|
||||
// TODO(helin): check if paramWithConfigs.Param.Content is
|
||||
// properly memory aligned, if not, make copy to a memory
|
||||
// aligned region.
|
||||
s.paramMap[paramWithConfigs.Param.Name] = paramWithConfigs.Param
|
||||
return nil
|
||||
}
|
||||
|
||||
// FinishInitParams tells the parameter server that the parameter
|
||||
// initialization has finished.
|
||||
func (s *Service) FinishInitParams(dummy0 int, dummy1 *int) error {
|
||||
select {
|
||||
case <-s.initialized:
|
||||
return ErrAlreadyInitialized
|
||||
default:
|
||||
}
|
||||
|
||||
close(s.initialized)
|
||||
return nil
|
||||
}
|
||||
|
||||
// SendGrads sends gradients to parameter servers for parameter
|
||||
// optimization.
|
||||
func (s *Service) SendGrads(grads []Gradient, dummy *int) error {
|
||||
select {
|
||||
case <-s.initialized:
|
||||
default:
|
||||
return ErrUninitialized
|
||||
}
|
||||
|
||||
count := len(grads)
|
||||
if count == 0 {
|
||||
return nil
|
||||
}
|
||||
|
||||
s.mu.Lock()
|
||||
defer s.mu.Unlock()
|
||||
|
||||
for _, g := range grads {
|
||||
if _, ok := s.paramMap[g.Name]; !ok {
|
||||
return fmt.Errorf("parameter: %s does not exist", g.Name)
|
||||
}
|
||||
}
|
||||
|
||||
errCh := make(chan error, count)
|
||||
for _, g := range grads {
|
||||
go func(p Parameter, g Gradient) {
|
||||
err := s.opt.UpdateParameter(p, g)
|
||||
errCh <- err
|
||||
}(s.paramMap[g.Name], g)
|
||||
}
|
||||
|
||||
recv := 0
|
||||
for err := range errCh {
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
recv++
|
||||
if recv == count {
|
||||
break
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// GetParams gets parameters from the parameter server.
|
||||
func (s *Service) GetParams(names []string, parameters *[]Parameter) error {
|
||||
<-s.initialized
|
||||
s.mu.Lock()
|
||||
defer s.mu.Unlock()
|
||||
|
||||
for _, n := range names {
|
||||
if _, ok := s.paramMap[n]; !ok {
|
||||
return fmt.Errorf("parameter: %s does not exist", n)
|
||||
}
|
||||
}
|
||||
|
||||
*parameters = make([]Parameter, len(names))
|
||||
for i, n := range names {
|
||||
// The parameter content (a byte slice) may change
|
||||
// during RPC serialization due to write from other
|
||||
// goroutine, we allow it since mini-batch based deep
|
||||
// learning optimization methods are stochastic in
|
||||
// nature. This race condition is allowed deliberately
|
||||
// to save the program from making a copy of the
|
||||
// paramter content.
|
||||
(*parameters)[i] = s.paramMap[n]
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// Save tells the parameter server to save parameters.
|
||||
func (s *Service) Save(path string, dummy *int) error {
|
||||
<-s.initialized
|
||||
|
||||
// TODO
|
||||
return nil
|
||||
}
|
@ -0,0 +1,165 @@
|
||||
package pserver_test
|
||||
|
||||
import (
|
||||
"reflect"
|
||||
"sync"
|
||||
"testing"
|
||||
|
||||
"github.com/PaddlePaddle/Paddle/paddle/go/pserver"
|
||||
)
|
||||
|
||||
func TestFull(t *testing.T) {
|
||||
s := pserver.NewService()
|
||||
var dummy int
|
||||
err := s.BeginInitParams(nil, &dummy)
|
||||
if err != nil {
|
||||
t.FailNow()
|
||||
}
|
||||
|
||||
var p pserver.Parameter
|
||||
p.Name = "param_a"
|
||||
p.Content = []byte{1, 0, 0, 0, 2, 0, 0, 0, 3, 0, 0, 0}
|
||||
p.ElementType = pserver.Int32
|
||||
err = s.InitParam(pserver.ParameterWithConfig{p, nil}, &dummy)
|
||||
if err != nil {
|
||||
t.FailNow()
|
||||
}
|
||||
|
||||
var p1 pserver.Parameter
|
||||
p1.Name = "param_b"
|
||||
p1.Content = []byte{0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0}
|
||||
p1.ElementType = pserver.Float32
|
||||
err = s.InitParam(pserver.ParameterWithConfig{p1, nil}, &dummy)
|
||||
if err != nil {
|
||||
t.FailNow()
|
||||
}
|
||||
|
||||
err = s.FinishInitParams(0, &dummy)
|
||||
if err != nil {
|
||||
t.FailNow()
|
||||
}
|
||||
|
||||
var params []pserver.Parameter
|
||||
err = s.GetParams([]string{"param_b", "param_a"}, ¶ms)
|
||||
if err != nil {
|
||||
t.FailNow()
|
||||
}
|
||||
|
||||
if len(params) != 2 || !reflect.DeepEqual(params[0], p1) || !reflect.DeepEqual(params[0], p1) {
|
||||
t.FailNow()
|
||||
}
|
||||
|
||||
grads := []pserver.Gradient{pserver.Gradient(p1), pserver.Gradient(p)}
|
||||
err = s.SendGrads(grads, &dummy)
|
||||
if err != nil {
|
||||
t.FailNow()
|
||||
}
|
||||
|
||||
var params1 []pserver.Parameter
|
||||
err = s.GetParams([]string{"param_b", "param_a"}, ¶ms1)
|
||||
if err != nil {
|
||||
t.FailNow()
|
||||
}
|
||||
|
||||
if len(params) != 2 {
|
||||
t.FailNow()
|
||||
}
|
||||
|
||||
// don't compare content, since it's already changed by
|
||||
// gradient update.
|
||||
params1[0].Content = nil
|
||||
params1[0].Content = nil
|
||||
p.Content = nil
|
||||
p1.Content = nil
|
||||
|
||||
if !reflect.DeepEqual(params1[0], p1) || !reflect.DeepEqual(params1[0], p1) {
|
||||
t.FailNow()
|
||||
}
|
||||
}
|
||||
|
||||
func TestMultipleInit(t *testing.T) {
|
||||
s := pserver.NewService()
|
||||
var dummy int
|
||||
err := s.BeginInitParams(nil, &dummy)
|
||||
if err != nil {
|
||||
t.FailNow()
|
||||
}
|
||||
|
||||
// this is fine, it's possible for client to call init
|
||||
// multiple times.
|
||||
err = s.BeginInitParams(nil, &dummy)
|
||||
if err != nil {
|
||||
t.FailNow()
|
||||
}
|
||||
|
||||
err = s.FinishInitParams(0, &dummy)
|
||||
if err != nil {
|
||||
t.FailNow()
|
||||
}
|
||||
|
||||
err = s.FinishInitParams(0, &dummy)
|
||||
if err != pserver.ErrAlreadyInitialized {
|
||||
t.FailNow()
|
||||
}
|
||||
|
||||
err = s.BeginInitParams(nil, &dummy)
|
||||
if err != pserver.ErrAlreadyInitialized {
|
||||
t.FailNow()
|
||||
}
|
||||
}
|
||||
|
||||
func TestUninitialized(t *testing.T) {
|
||||
s := pserver.NewService()
|
||||
var dummy int
|
||||
err := s.SendGrads(nil, &dummy)
|
||||
if err != pserver.ErrUninitialized {
|
||||
t.FailNow()
|
||||
}
|
||||
}
|
||||
|
||||
func TestBlockUntilInitialized(t *testing.T) {
|
||||
s := pserver.NewService()
|
||||
ch := make(chan struct{}, 2)
|
||||
var wg sync.WaitGroup
|
||||
wg.Add(1)
|
||||
go func() {
|
||||
var params []pserver.Parameter
|
||||
err := s.GetParams(nil, ¶ms)
|
||||
if err != nil {
|
||||
t.FailNow()
|
||||
}
|
||||
wg.Done()
|
||||
ch <- struct{}{}
|
||||
}()
|
||||
|
||||
wg.Add(1)
|
||||
go func() {
|
||||
var dummy int
|
||||
err := s.Save("", &dummy)
|
||||
if err != nil {
|
||||
t.FailNow()
|
||||
}
|
||||
wg.Done()
|
||||
ch <- struct{}{}
|
||||
}()
|
||||
|
||||
var dummy int
|
||||
err := s.BeginInitParams(nil, &dummy)
|
||||
if err != nil {
|
||||
t.FailNow()
|
||||
}
|
||||
|
||||
select {
|
||||
case <-ch:
|
||||
// some function returned before initialization is completed.
|
||||
t.FailNow()
|
||||
default:
|
||||
}
|
||||
|
||||
err = s.FinishInitParams(0, &dummy)
|
||||
if err != nil {
|
||||
t.FailNow()
|
||||
}
|
||||
|
||||
wg.Wait()
|
||||
}
|
Some files were not shown because too many files have changed in this diff Show More
Loading…
Reference in new issue