fix multi card

tonyyang-svail-patch-1
qiaolongfei 8 years ago
parent 5d305070ec
commit f8d0d84f7e

@ -220,7 +220,10 @@ def _callback_lookup_(op):
:return: callback function
"""
if op.type == 'parallel_do' and op.attr('use_nccl'):
all_vars = op.block.vars
param_names = set(op.input('parameters'))
param_names = filter(lambda name: all_vars[name].stop_gradient is False,
param_names)
param_grad_names = [n + "@GRAD" for n in param_names]
class ParallelDoCallBack(object):

@ -294,8 +294,7 @@ class ParallelDo(object):
params = list(set(params))
param_list = [parent_block.var(name) for name in params]
return filter(lambda param: param.stop_gradient is False, param_list)
return [parent_block.var(name) for name in params]
def complete_op(self):
main_program = self.helper.main_program

Loading…
Cancel
Save