|
|
|
@ -278,7 +278,7 @@ class DistSeResneXt2x2:
|
|
|
|
|
|
|
|
|
|
def run_trainer(self, place, endpoints, trainer_id, trainers, is_dist=True):
|
|
|
|
|
test_program, avg_cost, train_reader, test_reader, batch_acc, predict = get_model(
|
|
|
|
|
batch_size=20)
|
|
|
|
|
batch_size=2)
|
|
|
|
|
if is_dist:
|
|
|
|
|
t = get_transpiler(trainer_id,
|
|
|
|
|
fluid.default_main_program(), endpoints,
|
|
|
|
@ -294,11 +294,7 @@ class DistSeResneXt2x2:
|
|
|
|
|
strategy.num_threads = 1
|
|
|
|
|
strategy.allow_op_delay = False
|
|
|
|
|
exe = fluid.ParallelExecutor(
|
|
|
|
|
True,
|
|
|
|
|
loss_name=avg_cost.name,
|
|
|
|
|
exec_strategy=strategy,
|
|
|
|
|
num_trainers=trainers,
|
|
|
|
|
trainer_id=trainer_id)
|
|
|
|
|
True, loss_name=avg_cost.name, exec_strategy=strategy)
|
|
|
|
|
|
|
|
|
|
feed_var_list = [
|
|
|
|
|
var for var in trainer_prog.global_block().vars.itervalues()
|
|
|
|
|