@ -19,21 +19,7 @@ typedef struct {
int content_len ;
} paddle_parameter , paddle_gradient ;
static inline void paddle_release_param ( paddle_parameter * param ) {
if ( param != NULL ) {
if ( param - > name != NULL ) {
free ( param - > name ) ;
}
if ( param - > content != NULL ) {
free ( param - > content ) ;
}
free ( param ) ;
}
}
typedef int client ;
typedef int paddle_pserver_client ;
* /
import "C"
@ -48,10 +34,10 @@ import (
var nullPtr = unsafe . Pointer ( uintptr ( 0 ) )
var mu sync . Mutex
var handleMap = make ( map [ C . client] * pserver . Client )
var curHandle C . client
var handleMap = make ( map [ C . paddle_pserver_ client] * pserver . Client )
var curHandle C . paddle_pserver_ client
func add ( c * pserver . Client ) C . client {
func add ( c * pserver . Client ) C . paddle_pserver_ client {
mu . Lock ( )
defer mu . Unlock ( )
client := curHandle
@ -60,13 +46,13 @@ func add(c *pserver.Client) C.client {
return client
}
func get ( client C . client) * pserver . Client {
func get ( client C . paddle_pserver_ client) * pserver . Client {
mu . Lock ( )
defer mu . Unlock ( )
return handleMap [ client ]
}
func remove ( client C . client) * pserver . Client {
func remove ( client C . paddle_pserver_ client) * pserver . Client {
mu . Lock ( )
defer mu . Unlock ( )
h := handleMap [ client ]
@ -100,7 +86,7 @@ func (l lister) List() []pserver.Server {
}
//export paddle_new_pserver_client
func paddle_new_pserver_client ( addrs * C . char , selected int ) C . client {
func paddle_new_pserver_client ( addrs * C . char , selected int ) C . paddle_pserver_ client {
a := C . GoString ( addrs )
as := strings . Split ( a , "," )
servers := make ( [ ] pserver . Server , len ( as ) )
@ -113,18 +99,18 @@ func paddle_new_pserver_client(addrs *C.char, selected int) C.client {
}
//export paddle_new_etcd_pserver_client
func paddle_new_etcd_pserver_client ( etcd_addr * C . char ) C . client {
func paddle_new_etcd_pserver_client ( etcd_addr * C . char ) C . paddle_pserver_ client {
// TODO(helin): fault tolerant pserver client using etcd.
panic ( "not implemented." )
}
//export paddle_pserver_client_release
func paddle_pserver_client_release ( client C . client) {
func paddle_pserver_client_release ( client C . paddle_pserver_ client) {
remove ( client )
}
//export paddle_begin_init_params
func paddle_begin_init_params ( client C . client) C . int {
func paddle_begin_init_params ( client C . paddle_pserver_ client) C . int {
c := get ( client )
if selected := c . BeginInitParams ( ) ; selected {
return 1
@ -133,7 +119,7 @@ func paddle_begin_init_params(client C.client) C.int {
}
//export paddle_init_param
func paddle_init_param ( client C . client, param C . paddle_parameter , param_config unsafe . Pointer , config_len C . int ) C . int {
func paddle_init_param ( client C . paddle_pserver_ client, param C . paddle_parameter , param_config unsafe . Pointer , config_len C . int ) C . int {
et := pserver . ElementType ( param . element_type )
name := C . GoString ( param . name )
content := cArrayToSlice ( unsafe . Pointer ( param . content ) , int ( param . content_len ) )
@ -143,7 +129,12 @@ func paddle_init_param(client C.client, param C.paddle_parameter, param_config u
}
c := get ( client )
err := c . InitParam ( pc )
if err != nil {
if err . Error ( ) == pserver . AlreadyInitialized {
log . Println ( "parameter" , name , "already initialized, treat paddle_init_param as sucessful." )
return 0
}
log . Println ( err )
return - 1
}
@ -152,10 +143,15 @@ func paddle_init_param(client C.client, param C.paddle_parameter, param_config u
}
//export paddle_finish_init_params
func paddle_finish_init_params ( client C . client) C . int {
func paddle_finish_init_params ( client C . paddle_pserver_ client) C . int {
c := get ( client )
err := c . FinishInitParams ( )
if err != nil {
if err . Error ( ) == pserver . AlreadyInitialized {
log . Println ( "parameters already initialized, treat paddle_finish_init_params as sucessful." )
return 0
}
log . Println ( err )
return - 1
}
@ -164,7 +160,7 @@ func paddle_finish_init_params(client C.client) C.int {
}
//export paddle_send_grads
func paddle_send_grads ( client C . client, grads * C . paddle_gradient , total C . int ) C . int {
func paddle_send_grads ( client C . paddle_pserver_ client, grads * C . paddle_gradient , total C . int ) C . int {
var gs [ ] pserver . Gradient
for i := 0 ; i < int ( total ) ; i ++ {
grad := ( * C . paddle_gradient ) ( unsafe . Pointer ( ( uintptr ( unsafe . Pointer ( grads ) ) + uintptr ( i ) * unsafe . Sizeof ( * grads ) ) ) )
@ -185,11 +181,11 @@ func paddle_send_grads(client C.client, grads *C.paddle_gradient, total C.int) C
}
//export paddle_get_params
func paddle_get_params ( client C . client, names * * C . char , dst * * C . paddle_parameter , total C . int ) C . int {
func paddle_get_params ( client C . paddle_pserver_ client, dst * * C . paddle_parameter , total C . int ) C . int {
var ns [ ] string
for i := 0 ; i < int ( total ) ; i ++ {
name := * ( * * C . cha r) ( unsafe . Pointer ( ( uintptr ( unsafe . Pointer ( names ) ) + uintptr ( i ) * unsafe . Sizeof ( * names ) ) ) )
ns = append ( ns , C . GoString ( name) )
param := * ( * * C . paddle_paramete r) ( unsafe . Pointer ( ( uintptr ( unsafe . Pointer ( dst ) ) + uintptr ( i ) * unsafe . Sizeof ( * dst ) ) ) )
ns = append ( ns , C . GoString ( param. name) )
}
c := get ( client )
ps , err := c . GetParams ( ns )
@ -198,44 +194,32 @@ func paddle_get_params(client C.client, names **C.char, dst **C.paddle_parameter
return - 1
}
for i := 0 ; i < int ( total ) ; i ++ {
if i >= len ( ps ) {
break
if len ( ps ) != len ( ns ) {
return - 1
}
for i := range ps {
if ns [ i ] != ps [ i ] . Name {
return - 1
}
}
for i := 0 ; i < int ( total ) ; i ++ {
p := ps [ i ]
param := * ( * * C . paddle_parameter ) ( unsafe . Pointer ( ( uintptr ( unsafe . Pointer ( dst ) ) + uintptr ( i ) * unsafe . Sizeof ( * dst ) ) ) )
nameReady := false
contentAllocated := false
if unsafe . Pointer ( param ) == nullPtr {
param = ( * C . paddle_parameter ) ( C . calloc ( 1 , C . size_t ( unsafe . Sizeof ( * param ) ) ) )
log . Println ( "Error: must pre-allocate parameter." )
return - 1
} else {
if unsafe . Pointer ( param . name ) != nullPtr {
if n := C . GoString ( param . name ) ; n != p . Name {
log . Println ( "Warning: the pre-allocated parameter name does not match the parameter name, it will be freed." , n , p . Name )
C . free ( unsafe . Pointer ( param . name ) )
} else {
nameReady = true
}
}
if unsafe . Pointer ( param . content ) != nullPtr {
if int ( param . content_len ) == len ( p . Content ) {
contentAllocated = true
} else {
log . Println ( "Warning: the pre-allocated content len does not match parameter content len, the pre-allocated content will be freed." , param . content_len , len ( p . Content ) )
C . free ( unsafe . Pointer ( param . content ) )
if int ( param . content_len ) != len ( p . Content ) {
log . Println ( "Error: the pre-allocated content len does not match parameter content len." , param . content_len , len ( p . Content ) )
return - 1
}
}
}
if ! nameReady {
param . name = C . CString ( p . Name )
}
if ! contentAllocated {
param . content = ( * C . uchar ) ( C . malloc ( C . size_t ( len ( p . Content ) ) ) )
}
C . memcpy ( unsafe . Pointer ( param . content ) , unsafe . Pointer ( & p . Content [ 0 ] ) , C . size_t ( len ( p . Content ) ) )
param . content_len = C . int ( len ( p . Content ) )
param . element_type = C . paddle_element_type ( p . ElementType )
@ -245,7 +229,7 @@ func paddle_get_params(client C.client, names **C.char, dst **C.paddle_parameter
}
//export paddle_save_model
func paddle_save_model ( client C . client, path * C . char ) C . int {
func paddle_save_model ( client C . paddle_pserver_ client, path * C . char ) C . int {
p := C . GoString ( path )
c := get ( client )
err := c . Save ( p )