Merge pull request #8659 from Yancey1989/fix_dist_bug

Register var type inference in split_selected_rows op
Yancey committed 7 years ago via GitHub
commit 718642e93f

@@ -59,6 +59,16 @@ class SplitSelectedRowsOp : public framework::OperatorWithKernel {
   }
 };
 
+class SplitSelectedRowsOpInferVarType : public framework::VarTypeInference {
+ public:
+  void operator()(const framework::OpDesc &op_desc,
+                  framework::BlockDesc *block) const override {
+    for (auto &out_var : op_desc.Output("Out")) {
+      block->Var(out_var)->SetType(framework::proto::VarType::SELECTED_ROWS);
+    }
+  }
+};
+
 class SplitSelectedRowsGradMaker : public framework::SingleGradOpDescMaker {
  public:
   using framework::SingleGradOpDescMaker::SingleGradOpDescMaker;
@@ -80,7 +90,8 @@ class SplitSelectedRowsGradMaker : public framework::SingleGradOpDescMaker {
 namespace ops = paddle::operators;
 REGISTER_OPERATOR(split_selected_rows, ops::SplitSelectedRowsOp,
                   ops::SplitSelectedRowsOpMaker,
-                  ops::SplitSelectedRowsGradMaker);
+                  ops::SplitSelectedRowsGradMaker,
+                  ops::SplitSelectedRowsOpInferVarType);
 REGISTER_OP_CPU_KERNEL(
     split_selected_rows,
     ops::SplitSelectedRowsOpKernel<paddle::platform::CPUPlace, float>);
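
For context on what is being registered here: a VarTypeInference hook runs while the program description is built and fixes up the declared type of an op's outputs, which otherwise default to a dense LoD tensor. Below is a minimal Python sketch of that idea; the classes and names are toy stand-ins for illustration, not Paddle's actual API.

# Toy model of var-type inference; names are illustrative,
# not Paddle's real C++ classes.
from enum import Enum


class VarType(Enum):
    LOD_TENSOR = 0      # dense default
    SELECTED_ROWS = 1   # sparse row-wise format


class BlockDesc:
    """A block maps variable names to their declared types."""

    def __init__(self):
        self.vars = {}

    def var(self, name):
        # A variable defaults to LOD_TENSOR when first declared,
        # which is why ops with sparse outputs need an infer hook.
        self.vars.setdefault(name, VarType.LOD_TENSOR)
        return name

    def set_type(self, name, var_type):
        self.vars[name] = var_type


def infer_var_type(output_names, block):
    # Mirrors what SplitSelectedRowsOpInferVarType does: mark every
    # "Out" variable of the op as SELECTED_ROWS instead of the default.
    for out in output_names:
        block.set_type(block.var(out), VarType.SELECTED_ROWS)


block = BlockDesc()
infer_var_type(["Out@0", "Out@1"], block)
assert all(t is VarType.SELECTED_ROWS for t in block.vars.values())

Without the hook, both outputs would stay LOD_TENSOR in the block description even though the kernel actually produces SELECTED_ROWS, which is the mismatch this PR fixes.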

@@ -276,6 +276,7 @@ class DistributeTranspiler:
             pserver_program.global_block().create_var(
                 name=orig_var_name,
                 persistable=True,
+                type=v.type,
                 dtype=v.dtype,
                 shape=v.shape)
             print("create origin var: ", orig_var_name)
@@ -283,6 +284,7 @@ class DistributeTranspiler:
             var = pserver_program.global_block().create_var(
                 name="%s.trainer_%d" % (orig_var_name, trainer_id),
                 persistable=False,
+                type=v.type,
                 dtype=v.dtype,
                 shape=v.shape)
             recv_inputs.append(var)
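
Both transpiler hunks above propagate v.type when the received variables are recreated on the parameter server. Without that argument, create_var falls back to a dense default, so a SELECTED_ROWS gradient sent by a trainer lands in a variable declared with the wrong type. A dict-based sketch of the mismatch follows; the create_var helper here is a hypothetical stand-in, not Paddle's real API.

# Minimal sketch of why the receiving var must carry v.type;
# create_var is a dict-based stand-in, not Paddle's real helper.
def create_var(block, name, persistable, dtype, shape,
               type="LOD_TENSOR"):
    # With no explicit type, the var defaults to a dense tensor,
    # which is exactly the mismatch type=v.type avoids.
    block[name] = {"persistable": persistable, "dtype": dtype,
                   "shape": shape, "type": type}
    return block[name]


block = {}
# A gradient from a sparse update (e.g. an embedding lookup table):
v = {"dtype": "float32", "shape": [10000, 64], "type": "SELECTED_ROWS"}

# Without type=..., the pserver-side var is silently dense:
dense = create_var(block, "emb@GRAD", True, v["dtype"], v["shape"])
assert dense["type"] != v["type"]

# With the fix, the trainer-side type is carried over:
fixed = create_var(block, "emb@GRAD", True, v["dtype"], v["shape"],
                   type=v["type"])
assert fixed["type"] == v["type"]
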
@@ -551,11 +553,12 @@ class DistributeTranspiler:
                     type="sum",
                     inputs={"X": vars2merge},
                     outputs={"Out": merged_var})
-                optimize_block.append_op(
-                    type="scale",
-                    inputs={"X": merged_var},
-                    outputs={"Out": merged_var},
-                    attrs={"scale": 1.0 / float(self.trainers)})
+                if not merged_var.type == core.VarDesc.VarType.SELECTED_ROWS:
+                    optimize_block.append_op(
+                        type="scale",
+                        inputs={"X": merged_var},
+                        outputs={"Out": merged_var},
+                        attrs={"scale": 1.0 / float(self.trainers)})
                 new_inputs[key] = merged_var
             elif key == "Param":
                 # param is already created on global program
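
The guard added in this last hunk skips the 1.0 / trainers rescale for SELECTED_ROWS gradients, so the merged sparse rows keep their plain sum; as the diff suggests, only dense gradients are averaged this way. A scalar sketch of the branch, with toy stand-ins for the sum and scale ops:

# Sketch of the averaging branch above; sum() and the multiply are
# scalar stand-ins for the "sum" and "scale" ops, not Paddle code.
def merge_gradients(grads, trainers, is_selected_rows):
    merged = sum(grads)  # stand-in for the "sum" op
    if not is_selected_rows:
        # Stand-in for the "scale" op with scale=1.0/trainers.
        merged *= 1.0 / trainers
    return merged


# Dense gradients are averaged across trainers:
assert merge_gradients([2.0, 4.0], trainers=2, is_selected_rows=False) == 3.0
# SELECTED_ROWS gradients keep the plain row-wise sum:
assert merge_gradients([2.0, 4.0], trainers=2, is_selected_rows=True) == 6.0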
