|
|
|
@ -105,7 +105,7 @@ class TestDistRunnerBase(object):
|
|
|
|
|
build_stra.reduce_strategy = fluid.BuildStrategy.ReduceStrategy.AllReduce
|
|
|
|
|
|
|
|
|
|
if args.batch_merge_repeat > 1:
|
|
|
|
|
pass_builder = build_stra._create_passes_from_strategy()
|
|
|
|
|
pass_builder = build_stra._finalize_strategy_and_create_passes()
|
|
|
|
|
mypass = pass_builder.insert_pass(
|
|
|
|
|
len(pass_builder.all_passes()) - 2, "multi_batch_merge_pass")
|
|
|
|
|
mypass.set_int("num_repeats", args.batch_merge_repeat)
|
|
|
|
|