|
|
|
@ -37,6 +37,7 @@ type ParameterWithConfig struct {
|
|
|
|
|
// Gradient is the gradient of the parameter.
|
|
|
|
|
type Gradient Parameter
|
|
|
|
|
|
|
|
|
|
// Service is the RPC service for pserver.
|
|
|
|
|
type Service struct {
|
|
|
|
|
initialized chan struct{}
|
|
|
|
|
|
|
|
|
@ -45,6 +46,7 @@ type Service struct {
|
|
|
|
|
paramMap map[string]Parameter
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
// NewService creates a new service.
|
|
|
|
|
func NewService() *Service {
|
|
|
|
|
s := &Service{}
|
|
|
|
|
s.paramMap = make(map[string]Parameter)
|
|
|
|
@ -52,6 +54,8 @@ func NewService() *Service {
|
|
|
|
|
return s
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
// BeginInitParams tells the parameter server that the parameter
|
|
|
|
|
// initialization has begun.
|
|
|
|
|
func (s *Service) BeginInitParams(config []byte, dummy *int) error {
|
|
|
|
|
select {
|
|
|
|
|
case <-s.initialized:
|
|
|
|
@ -71,6 +75,7 @@ func (s *Service) BeginInitParams(config []byte, dummy *int) error {
|
|
|
|
|
return nil
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
// InitParam initializes a parameter.
|
|
|
|
|
func (s *Service) InitParam(paramWithConfigs ParameterWithConfig, dummy *int) error {
|
|
|
|
|
select {
|
|
|
|
|
case <-s.initialized:
|
|
|
|
@ -90,6 +95,8 @@ func (s *Service) InitParam(paramWithConfigs ParameterWithConfig, dummy *int) er
|
|
|
|
|
return nil
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
// FinishInitParams tells the parameter server that the parameter
|
|
|
|
|
// initialization has finished.
|
|
|
|
|
func (s *Service) FinishInitParams(dummy0 int, dummy1 *int) error {
|
|
|
|
|
select {
|
|
|
|
|
case <-s.initialized:
|
|
|
|
@ -101,6 +108,8 @@ func (s *Service) FinishInitParams(dummy0 int, dummy1 *int) error {
|
|
|
|
|
return nil
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
// SendGrads sends gradients to parameter servers for parameter
|
|
|
|
|
// optimization.
|
|
|
|
|
func (s *Service) SendGrads(grads []Gradient, dummy *int) error {
|
|
|
|
|
<-s.initialized
|
|
|
|
|
|
|
|
|
@ -140,6 +149,7 @@ func (s *Service) SendGrads(grads []Gradient, dummy *int) error {
|
|
|
|
|
return nil
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
// GetParams gets parameters from the parameter server.
|
|
|
|
|
func (s *Service) GetParams(names []string, parameters *[]Parameter) error {
|
|
|
|
|
<-s.initialized
|
|
|
|
|
s.mu.Lock()
|
|
|
|
@ -166,7 +176,8 @@ func (s *Service) GetParams(names []string, parameters *[]Parameter) error {
|
|
|
|
|
return nil
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
func (s *Service) SaveModel(path string, dummy *int) error {
|
|
|
|
|
// Save tells the parameter server to save parameters.
|
|
|
|
|
func (s *Service) Save(path string, dummy *int) error {
|
|
|
|
|
<-s.initialized
|
|
|
|
|
|
|
|
|
|
// TODO
|
|
|
|
|