@ -17,6 +17,7 @@ package pserver
import (
import (
"bufio"
"bufio"
"bytes"
"bytes"
"encoding/binary"
"encoding/gob"
"encoding/gob"
"encoding/json"
"encoding/json"
"errors"
"errors"
@ -26,11 +27,15 @@ import (
"os"
"os"
"path"
"path"
"strconv"
"strconv"
"strings"
"sync"
"sync"
"time"
"time"
"github.com/golang/protobuf/proto"
uuid "github.com/satori/go.uuid"
uuid "github.com/satori/go.uuid"
pb "github.com/PaddlePaddle/Paddle/go/proto"
log "github.com/inconshreveable/log15"
log "github.com/inconshreveable/log15"
)
)
@ -65,6 +70,46 @@ type Parameter struct {
Content [ ] byte
Content [ ] byte
}
}
func float32ToString ( b [ ] byte ) string {
f := make ( [ ] float32 , len ( b ) / 4 )
buf := bytes . NewReader ( b )
err := binary . Read ( buf , binary . LittleEndian , & f )
if err != nil {
return ""
}
return fmt . Sprintf ( "%v" , f )
}
func float32ByteToString ( c [ ] byte ) string {
var a [ ] byte
var b [ ] byte
if len ( c ) <= 80 {
a = c
} else {
a = c [ 0 : 40 ]
b = c [ len ( c ) - 40 : ]
}
var s string
s = float32ToString ( a )
if b == nil {
return s
}
s = strings . Replace ( s , "]" , "" , - 1 ) + "..." + strings . Replace ( float32ToString ( b ) , "[" , "" , - 1 )
return s
}
func ( p Parameter ) String ( ) string {
if p . ElementType != Float32 {
return fmt . Sprintf ( "name:%v ElementType:%v" ,
p . Name , p . ElementType )
}
return float32ByteToString ( p . Content )
}
// ParameterWithConfig contains the parameter and the configuration.
// ParameterWithConfig contains the parameter and the configuration.
type ParameterWithConfig struct {
type ParameterWithConfig struct {
Param Parameter
Param Parameter
@ -189,7 +234,9 @@ func (s *Service) InitParam(paramWithConfigs ParameterWithConfig, _ *int) error
default :
default :
}
}
// TODO(helin): parse parameter config
c := & pb . OptimizerConfig { }
proto . Unmarshal ( paramWithConfigs . Config , c )
log . Debug ( fmt . Sprintf ( "OptimizerConfig:%v" , c ) )
s . mu . Lock ( )
s . mu . Lock ( )
defer s . mu . Unlock ( )
defer s . mu . Unlock ( )
@ -239,7 +286,8 @@ func (s *Service) SendGrad(g Gradient, _ *int) error {
select {
select {
case <- s . initialized :
case <- s . initialized :
default :
default :
log . Warn ( "received gradient before initialization." , "name" , g . Name , "size" , len ( g . Content ) , "type" , g . ElementType )
log . Warn ( "received gradient before initialization." ,
"name" , g . Name , "size" , len ( g . Content ) , "type" , g . ElementType )
return errors . New ( Uninitialized )
return errors . New ( Uninitialized )
}
}
@ -248,10 +296,14 @@ func (s *Service) SendGrad(g Gradient, _ *int) error {
o , ok := s . optMap [ g . Name ]
o , ok := s . optMap [ g . Name ]
if ! ok {
if ! ok {
log . Warn ( "received gradient but can't find name." ,
"name" , g . Name , "size" , len ( g . Content ) , "type" , g . ElementType )
return fmt . Errorf ( "parameter: %s does not exist" , g . Name )
return fmt . Errorf ( "parameter: %s does not exist" , g . Name )
}
}
log . Info ( "received gradient from trainer, updating gradient." , "name" , g . Name , "size" , len ( g . Content ) , "type" , g . ElementType )
log . Debug ( Parameter ( g ) . String ( ) )
log . Info ( "received gradient from trainer, updating gradient." ,
"name" , g . Name , "size" , len ( g . Content ) , "type" , g . ElementType )
return o . UpdateParameter ( g )
return o . UpdateParameter ( g )
}
}
@ -277,7 +329,7 @@ func (s *Service) GetParam(name string, parameter *Parameter) error {
parameter . Name = name
parameter . Name = name
parameter . ElementType = opt . elementType
parameter . ElementType = opt . elementType
parameter . Content = opt . GetWeights ( )
parameter . Content = opt . GetWeights ( )
log . Debug ( parameter . String ( ) )
log . Info ( "sending parameter to the trainer" , "name" , parameter . Name , "size" , len ( parameter . Content ) , "type" , parameter . ElementType )
log . Info ( "sending parameter to the trainer" , "name" , parameter . Name , "size" , len ( parameter . Content ) , "type" , parameter . ElementType )
return nil
return nil
}
}