@ -1119,6 +1119,7 @@ to transpile() call.")
def _split_table_grad_and_add_send_vars ( self , program , pserver_endpoints ) :
# 2. add split_ids_op and send_op to send gradient to pservers
# there should only be one table_name
all_ops = program . global_block ( ) . ops
table_grad_name = grad_var_name ( self . table_name )
@ -1143,7 +1144,7 @@ to transpile() call.")
if self . sync_mode else [ ]
} ,
attrs = {
" sync_mode " : self . sync_mode ,
" sync_mode " : not self . sync_mode ,
" epmap " : pserver_endpoints ,
RPC_OP_ROLE_ATTR_NAME : RPC_OP_ROLE_ATTR_VALUE ,
OP_ROLE_VAR_ATTR_NAME : [
@ -1189,7 +1190,15 @@ to transpile() call.")
def _create_table_optimize_block ( self , pserver_index , pserver_program ,
pre_block_idx , grad_to_block_id ) :
# STEP: create table optimize block
table_opt_block = pserver_program . _create_block ( pre_block_idx )
# create table param and grad var in pserver program
# create table optimize block in pserver program
table_opt_op = [
op for op in self . optimize_ops
if ' Param ' in op . input_names and op . input ( " Param " ) [ 0 ] ==
self . table_name
] [ 0 ]
origin_param_var = self . origin_program . global_block ( ) . vars [
self . table_name ]
@ -1205,19 +1214,16 @@ to transpile() call.")
dtype = origin_param_var . dtype ,
type = core . VarDesc . VarType . SELECTED_ROWS ,
persistable = True )
# parameter must be selected rows
param_var . desc . set_type ( core . VarDesc . VarType . SELECTED_ROWS )
grad_var = pserver_program . global_block ( ) . _clone_variable (
self . origin_program . global_block ( ) . vars [ grad_var_name (
self . table_name ) ] )
# create table optimize block in pserver program
table_opt_op = [
op for op in self . optimize_ops
if ' Param ' in op . input_names and op . input ( " Param " ) [ 0 ] ==
self . table_name
] [ 0 ]
table_opt_block = pserver_program . _create_block ( pre_block_idx )
lr_var = pserver_program . global_block ( ) . _clone_variable (
self . origin_program . global_block ( ) . vars [ table_opt_op . input (
" LearningRate " ) [ 0 ] ] )
if self . sync_mode :
# create grad vars in pserver program
@ -1249,8 +1255,6 @@ to transpile() call.")
grad_var = pserver_program . global_block ( ) . _rename_var (
origin_grad_name , splited_grad_name )
lr_var = pserver_program . global_block ( ) . vars [ table_opt_op . input (
" LearningRate " ) [ 0 ] ]
inputs = {
" Param " : [ param_var ] ,
" Grad " : [ grad_var ] ,