|
|
@ -369,7 +369,7 @@ class DistributeTranspiler(object):
|
|
|
|
# FIXME(gongwb): delete not need ops.
|
|
|
|
# FIXME(gongwb): delete not need ops.
|
|
|
|
# note that: some parameter is not trainable and those ops can't be deleted.
|
|
|
|
# note that: some parameter is not trainable and those ops can't be deleted.
|
|
|
|
|
|
|
|
|
|
|
|
for varname, splited_var in self.param_var_mapping.iteritems():
|
|
|
|
for varname, splited_var in six.iteritems(self.param_var_mapping):
|
|
|
|
# Get the eplist of recv vars
|
|
|
|
# Get the eplist of recv vars
|
|
|
|
eps = []
|
|
|
|
eps = []
|
|
|
|
for var in splited_var:
|
|
|
|
for var in splited_var:
|
|
|
@ -406,7 +406,7 @@ class DistributeTranspiler(object):
|
|
|
|
RPC_OP_ROLE_ATTR_NAME: RPC_OP_ROLE_ATTR_VALUE
|
|
|
|
RPC_OP_ROLE_ATTR_NAME: RPC_OP_ROLE_ATTR_VALUE
|
|
|
|
})
|
|
|
|
})
|
|
|
|
|
|
|
|
|
|
|
|
for varname, splited_var in self.param_var_mapping.iteritems():
|
|
|
|
for varname, splited_var in six.iteritems(self.param_var_mapping):
|
|
|
|
#add concat ops to merge splited parameters received from parameter servers.
|
|
|
|
#add concat ops to merge splited parameters received from parameter servers.
|
|
|
|
if len(splited_var) <= 1:
|
|
|
|
if len(splited_var) <= 1:
|
|
|
|
continue
|
|
|
|
continue
|
|
|
|