|
|
|
@ -397,6 +397,24 @@ class ParallelDoGradOpShapeInference : public framework::InferShapeBase {
|
|
|
|
|
}
|
|
|
|
|
};
|
|
|
|
|
|
|
|
|
|
class ParallelDoGradOpVarTypeInference : public framework::VarTypeInference {
|
|
|
|
|
public:
|
|
|
|
|
void operator()(const framework::OpDesc &op_desc,
|
|
|
|
|
framework::BlockDesc *block) const override {
|
|
|
|
|
framework::BlockDesc *sub_block =
|
|
|
|
|
boost::get<framework::BlockDesc *>(op_desc.GetAttr(kParallelBlock));
|
|
|
|
|
for (auto &out_vars : op_desc.Outputs()) {
|
|
|
|
|
for (auto &out_var : out_vars.second) {
|
|
|
|
|
auto &var = block->FindRecursiveOrCreateVar(out_var);
|
|
|
|
|
auto sub_var = sub_block->FindRecursiveOrCreateVar(out_var);
|
|
|
|
|
if (sub_var.GetType() != var.GetType()) {
|
|
|
|
|
var.SetType(sub_var.GetType());
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
};
|
|
|
|
|
|
|
|
|
|
} // namespace operators
|
|
|
|
|
} // namespace paddle
|
|
|
|
|
|
|
|
|
@ -404,4 +422,5 @@ REGISTER_OPERATOR(parallel_do, paddle::operators::ParallelDoOp,
|
|
|
|
|
paddle::operators::ParallelDoOpProtoMaker,
|
|
|
|
|
paddle::operators::ParallelDoGradOpDescMaker);
|
|
|
|
|
REGISTER_OPERATOR(parallel_do_grad, paddle::operators::ParallelDoGradOp,
|
|
|
|
|
paddle::operators::ParallelDoGradOpShapeInference);
|
|
|
|
|
paddle::operators::ParallelDoGradOpShapeInference,
|
|
|
|
|
paddle::operators::ParallelDoGradOpVarTypeInference);
|
|
|
|
|