@ -1,6 +1,12 @@
package main
/ *
# include < stdlib . h >
# include < string . h >
# include < stdio . h >
# define PADDLE_MASTER_OK 0
# define PADDLE_MASTER_ERROR - 1
typedef int paddle_master_client ;
* /
@ -14,6 +20,7 @@ import (
"github.com/PaddlePaddle/Paddle/go/master"
)
var nullPtr = unsafe . Pointer ( uintptr ( 0 ) )
var mu sync . Mutex
var handleMap = make ( map [ C . paddle_master_client ] * master . Client )
var curHandle C . paddle_master_client
@ -47,17 +54,16 @@ func (a addresser) Address() string {
return string ( a )
}
// paddle_new_master_client
func paddle_new_master_client ( addr * C . char , buf_size C . int ) C . paddle_master_client {
// export paddle_new_master_client
func paddle_new_master_client ( addr * C . char ) C . paddle_master_client {
a := C . GoString ( addr )
c := master . NewClient ( addresser ( a ) , int ( buf_size ) )
c := master . NewClient ( addresser ( a ) )
return add ( c )
}
//export paddle_new_etcd_master_client
func paddle_new_etcd_master_client ( etcd_addr * C . char ) C . paddle_master_client {
// TODO(helin): fault tolerant master client using etcd.
panic ( "not implemented." )
//export paddle_release_master_client
func paddle_release_master_client ( client C . paddle_master_client ) {
remove ( client )
}
//export paddle_set_dataset
@ -65,17 +71,40 @@ func paddle_set_dataset(client C.paddle_master_client, path **C.char, size C.int
c := get ( client )
var paths [ ] string
for i := 0 ; i < int ( size ) ; i ++ {
ptr := ( * * C . char ) ( unsafe . Pointer ( uintptr ( unsafe . Pointer ( path ) ) + uintptr ( size) ) )
ptr := ( * * C . char ) ( unsafe . Pointer ( uintptr ( unsafe . Pointer ( path ) ) + uintptr ( i) * un safe. S izeof( * path ) ) )
str := C . GoString ( * ptr )
paths = append ( paths , str )
}
err := c . SetDataset ( paths )
if err != nil {
log . Println ( err )
return - 1
return C . PADDLE_MASTER_ERROR
}
return C . PADDLE_MASTER_OK
}
//export paddle_next_record
func paddle_next_record ( client C . paddle_master_client , record * * C . uchar ) C . int {
c := get ( client )
r := c . NextRecord ( )
if len ( r ) == 0 {
* record = ( * C . uchar ) ( nullPtr )
return 0
}
return 0
size := C . size_t ( len ( r ) )
* record = ( * C . uchar ) ( C . malloc ( size ) )
C . memcpy ( unsafe . Pointer ( * record ) , unsafe . Pointer ( & r [ 0 ] ) , size )
return C . int ( size )
}
//export mem_free
func mem_free ( p unsafe . Pointer ) {
// "free" may be a better name for this function, but doing so
// will cause calling any function of this library from Python
// ctypes hanging.
C . free ( p )
}
func main ( ) { }