|
|
|
@ -225,7 +225,7 @@ class DistributeTranspiler:
|
|
|
|
|
if len(splited_vars) <= 1:
|
|
|
|
|
continue
|
|
|
|
|
orig_var = program.global_block().vars[varname]
|
|
|
|
|
if orig_var == core.VarDesc.VarType.SELECTED_ROWS:
|
|
|
|
|
if orig_var.type == core.VarDesc.VarType.SELECTED_ROWS:
|
|
|
|
|
height_sections = []
|
|
|
|
|
for v in splited_vars:
|
|
|
|
|
height_sections.append(v.shape[0])
|
|
|
|
@ -234,7 +234,7 @@ class DistributeTranspiler:
|
|
|
|
|
inputs={"X": orig_var},
|
|
|
|
|
outputs={"Out": splited_vars},
|
|
|
|
|
attrs={"height_sections": height_sections})
|
|
|
|
|
elif orig_var == core.VarDesc.VarType.LOD_TENSOR:
|
|
|
|
|
elif orig_var.type == core.VarDesc.VarType.LOD_TENSOR:
|
|
|
|
|
sections = []
|
|
|
|
|
for v in splited_vars:
|
|
|
|
|
sections.append(v.shape[0])
|
|
|
|
|