dist train use split_by_ref

wangkuiyi-patch-2
typhoonzero 7 years ago
parent 0c6eef3e58
commit ed89b7b7e6

@ -824,7 +824,7 @@ class DistributeTranspiler:
for v in splited_vars:
sections.append(v.shape[0])
program.global_block().append_op(
type="split",
type="split_byref",
inputs={"X": orig_var},
outputs={"Out": splited_vars},
attrs={"sections": sections} # assume split evenly

Loading…
Cancel
Save