diff --git a/python/paddle/fluid/distribute_transpiler.py b/python/paddle/fluid/distribute_transpiler.py
index 9a514bd9a3..6ac3b82671 100644
--- a/python/paddle/fluid/distribute_transpiler.py
+++ b/python/paddle/fluid/distribute_transpiler.py
@@ -358,7 +358,7 @@ class DistributeTranspiler:
                 type=v.type,
                 dtype=v.dtype,
                 shape=v.shape)
-            if self.trainer_num > 1:
+            if self.sync_mode and self.trainer_num > 1:
                 for trainer_id in xrange(self.trainer_num):
                     var = pserver_program.global_block().create_var(
                         name="%s.trainer_%d" % (orig_var_name, trainer_id),
@@ -688,17 +688,6 @@ class DistributeTranspiler:
                 self.table_name)],
             persistable=False)
 
-        # create grad vars in pserver program
-        table_grad_var = self.table_param_grad[1]
-        table_grad_list = [
-            pserver_program.global_block().create_var(
-                name="%s.trainer_%d.pserver_%d" %
-                (table_grad_var.name, index, pserver_index),
-                type=table_grad_var.type,
-                shape=table_grad_var.shape,
-                dtype=table_grad_var.dtype) for index in range(self.trainer_num)
-        ]
-
         # create table optimize block in pserver program
         table_opt_op = [
             op for op in self.optimize_ops
@@ -708,11 +697,24 @@ class DistributeTranspiler:
         # only support sgd now
         assert table_opt_op.type == "sgd"
 
-        # append sum op for table_grad_list
-        table_opt_block.append_op(
-            type="sum",
-            inputs={"X": table_grad_list},
-            outputs={"Out": [grad_var]})
+        if self.sync_mode:
+            # create grad vars in pserver program
+            table_grad_var = self.table_param_grad[1]
+            table_grad_list = [
+                pserver_program.global_block().create_var(
+                    name="%s.trainer_%d.pserver_%d" %
+                    (table_grad_var.name, index, pserver_index),
+                    type=table_grad_var.type,
+                    shape=table_grad_var.shape,
+                    dtype=table_grad_var.dtype)
+                for index in range(self.trainer_num)
+            ]
+
+            # append sum op for table_grad_list
+            table_opt_block.append_op(
+                type="sum",
+                inputs={"X": table_grad_list},
+                outputs={"Out": [grad_var]})
 
         lr_var = pserver_program.global_block().vars[table_opt_op.input(
             "LearningRate")[0]]
@@ -751,7 +753,7 @@ class DistributeTranspiler:
         for varname, splited in block_map.iteritems():
             orig_var = program.global_block().var(varname)
             if len(splited) == 1:
-                if add_trainer_suffix:
+                if self.sync_mode and add_trainer_suffix:
                     new_var_name = "%s.trainer_%d" % \
                         (orig_var.name, self.trainer_id)
                     program.global_block().rename_var(varname, new_var_name)
@@ -775,7 +777,7 @@ class DistributeTranspiler:
                 if len(orig_shape) >= 2:
                     splited_shape.extend(orig_shape[1:])
                 new_var_name = ""
-                if add_trainer_suffix:
+                if self.sync_mode and add_trainer_suffix:
                     new_var_name = "%s.block%d.trainer_%d" % \
                         (varname, i, self.trainer_id)
                 else:
@@ -907,7 +909,7 @@ class DistributeTranspiler:
                     pserver_block.vars[self._orig_varname(grad_block.name)]
                 grad_to_block_id.append(merged_var.name + ":" + str(
                     optimize_block.idx))
-                if self.trainer_num > 1:
+                if self.sync_mode and self.trainer_num > 1:
                     vars2merge = []
                     for i in xrange(self.trainer_num):
                         per_trainer_name = "%s.trainer_%d" % \
@@ -925,6 +927,7 @@ class DistributeTranspiler:
                         inputs={"X": merged_var},
                         outputs={"Out": merged_var},
                         attrs={"scale": 1.0 / float(self.trainer_num)})
+                new_inputs[key] = merged_var
             elif key == "Param":
                 # param is already created on global program
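
The behavior this patch gates on sync_mode can be summarized as follows: in sync mode the pserver keeps one received gradient variable per trainer (the "%s.trainer_%d" names), sums them, and scales the result by 1/trainer_num before the optimizer block runs; in async mode gradients are applied one at a time as they arrive, so none of those per-trainer variables or sum/scale ops are created. The plain-Python sketch below is illustrative only, under that reading of the patch; merge_gradients and apply_sgd are hypothetical stand-ins, not Paddle Fluid APIs.

# Illustrative sketch only (not part of the patch): models the pserver-side
# gradient handling this change gates on sync_mode.

def merge_gradients(recv_grads, sync_mode, trainer_num):
    """Mimic the transpiled pserver block.

    recv_grads: in sync mode, one gradient per trainer (analogous to the
                "%s.trainer_%d" variables); in async mode, a single gradient
                from whichever trainer just finished a step.
    """
    if sync_mode and trainer_num > 1:
        # sum op over the per-trainer gradients ...
        merged = sum(recv_grads)
        # ... followed by a scale op with scale = 1.0 / trainer_num
        merged *= 1.0 / float(trainer_num)
        return merged
    # async mode: no per-trainer variables exist; apply gradients as they arrive
    return recv_grads[0]

def apply_sgd(param, grad, lr):
    # the table optimize block in the patch asserts the optimizer is sgd
    return param - lr * grad

if __name__ == "__main__":
    # sync: three trainers each contribute a gradient for the same step
    print(apply_sgd(1.0, merge_gradients([0.3, 0.6, 0.9], True, 3), lr=0.1))
    # async: a single trainer's gradient is applied immediately
    print(apply_sgd(1.0, merge_gradients([0.3], False, 3), lr=0.1))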