|
|
|
|
@ -220,7 +220,10 @@ def _callback_lookup_(op):
|
|
|
|
|
:return: callback function
|
|
|
|
|
"""
|
|
|
|
|
if op.type == 'parallel_do' and op.attr('use_nccl'):
|
|
|
|
|
all_vars = op.block.vars
|
|
|
|
|
param_names = set(op.input('parameters'))
|
|
|
|
|
param_names = filter(lambda name: all_vars[name].stop_gradient is False,
|
|
|
|
|
param_names)
|
|
|
|
|
param_grad_names = [n + "@GRAD" for n in param_names]
|
|
|
|
|
|
|
|
|
|
class ParallelDoCallBack(object):
|
|
|
|
|
|