@ -104,9 +104,9 @@ class ParameterServerRuntime(RuntimeBase):
def _init_worker ( self ) :
def _init_worker ( self ) :
def sync_strategy_envs ( ) :
def sync_strategy_envs ( ) :
kwargs = { }
kwargs = { }
kwargs [ " pserver_endpoints " ] = self . role_maker . get_pserver_endpoints (
kwargs [
)
" pserver_endpoints " ] = self . role_maker . _get_pserver_endpoints ( )
kwargs [ " trainer_id " ] = self . role_maker . worker_index( )
kwargs [ " trainer_id " ] = self . role_maker . _ worker_index( )
return kwargs
return kwargs
def geo_strategy_envs ( ) :
def geo_strategy_envs ( ) :
@ -150,7 +150,7 @@ class ParameterServerRuntime(RuntimeBase):
return " # " . join ( init_attrs )
return " # " . join ( init_attrs )
kwargs = { }
kwargs = { }
kwargs [ " trainers " ] = self . role_maker . worker_num( )
kwargs [ " trainers " ] = self . role_maker . _ worker_num( )
kwargs [ " sparse_attrs " ] = get_sparse_attrs ( )
kwargs [ " sparse_attrs " ] = get_sparse_attrs ( )
return kwargs
return kwargs
@ -338,7 +338,7 @@ class ParameterServerRuntime(RuntimeBase):
block . append_op (
block . append_op (
type = ' recv_save ' ,
type = ' recv_save ' ,
attrs = {
attrs = {
" trainer_id " : self . role_maker . worker_index( ) ,
" trainer_id " : self . role_maker . _ worker_index( ) ,
" shape " : var . shape ,
" shape " : var . shape ,
" slice_shapes " :
" slice_shapes " :
[ " , " . join ( [ str ( i ) for i in var . shape ] ) ] ,
[ " , " . join ( [ str ( i ) for i in var . shape ] ) ] ,
@ -378,14 +378,15 @@ class ParameterServerRuntime(RuntimeBase):
block . append_op (
block . append_op (
type = ' recv_save ' ,
type = ' recv_save ' ,
attrs = {
attrs = {
" trainer_id " : self . role_maker . worker_index( ) ,
" trainer_id " : self . role_maker . _ worker_index( ) ,
" shape " : var . shape ,
" shape " : var . shape ,
" slice_shapes " : slice_shapes ,
" slice_shapes " : slice_shapes ,
" slice_varnames " : var_ctx . split_varnames ( ) ,
" slice_varnames " : var_ctx . split_varnames ( ) ,
" remote_varnames " : var_ctx . split_varnames ( ) ,
" remote_varnames " : var_ctx . split_varnames ( ) ,
" is_sparse " : True ,
" is_sparse " : True ,
" endpoints " : var_ctx . split_endpoints ( ) ,
" endpoints " : var_ctx . split_endpoints ( ) ,
" pserver_num " : len ( self . role_maker . get_pserver_endpoints ( ) ) ,
" pserver_num " :
len ( self . role_maker . _get_pserver_endpoints ( ) ) ,
" file_path " : os . path . join ( dirname , var . name )
" file_path " : os . path . join ( dirname , var . name )
} )
} )
@ -403,7 +404,7 @@ class ParameterServerRuntime(RuntimeBase):
block . append_op (
block . append_op (
type = ' recv_save ' ,
type = ' recv_save ' ,
attrs = {
attrs = {
" trainer_id " : self . role_maker . worker_index( ) ,
" trainer_id " : self . role_maker . _ worker_index( ) ,
" shape " : var . shape ,
" shape " : var . shape ,
" slice_shapes " : slice_shapes ,
" slice_shapes " : slice_shapes ,
" slice_varnames " : slice_varnames ,
" slice_varnames " : slice_varnames ,
@ -411,7 +412,7 @@ class ParameterServerRuntime(RuntimeBase):
" is_sparse " : True ,
" is_sparse " : True ,
" endpoints " : var_ctx . split_endpoints ( ) ,
" endpoints " : var_ctx . split_endpoints ( ) ,
" pserver_num " :
" pserver_num " :
len ( self . role_maker . get_pserver_endpoints( ) ) ,
len ( self . role_maker . _ get_pserver_endpoints( ) ) ,
" file_path " : os . path . join ( dirname , var . name )
" file_path " : os . path . join ( dirname , var . name )
} )
} )
@ -422,7 +423,7 @@ class ParameterServerRuntime(RuntimeBase):
block . append_op (
block . append_op (
type = ' recv_save ' ,
type = ' recv_save ' ,
attrs = {
attrs = {
" trainer_id " : self . role_maker . worker_index( ) ,
" trainer_id " : self . role_maker . _ worker_index( ) ,
" shape " : var . shape ,
" shape " : var . shape ,
" slice_shapes " :
" slice_shapes " :
[ " , " . join ( [ str ( i ) for i in var . shape ] ) ] ,
[ " , " . join ( [ str ( i ) for i in var . shape ] ) ] ,