|
|
@ -122,7 +122,7 @@ class TestDistRunnerBase(object):
|
|
|
|
if args.batch_merge_repeat > 1:
|
|
|
|
if args.batch_merge_repeat > 1:
|
|
|
|
pass_builder = build_stra._finalize_strategy_and_create_passes()
|
|
|
|
pass_builder = build_stra._finalize_strategy_and_create_passes()
|
|
|
|
mypass = pass_builder.insert_pass(
|
|
|
|
mypass = pass_builder.insert_pass(
|
|
|
|
len(pass_builder.all_passes()) - 2, "multi_batch_merge_pass")
|
|
|
|
len(pass_builder.all_passes()) - 3, "multi_batch_merge_pass")
|
|
|
|
mypass.set_int("num_repeats", args.batch_merge_repeat)
|
|
|
|
mypass.set_int("num_repeats", args.batch_merge_repeat)
|
|
|
|
|
|
|
|
|
|
|
|
if args.update_method == "nccl2":
|
|
|
|
if args.update_method == "nccl2":
|
|
|
|