add_receive_eliminate_pass

pull/8900/head
lichenever 4 years ago
parent 63fcdb44b5
commit ee34ae9259

@ -155,6 +155,9 @@ OptimizeIRPassLib::OptimizeIRPassLib() {
virtual_dataset_eliminate_ = MakeSubstitution(std::make_shared<VirtualDatasetEliminater>(),
"virtual_dataset_eliminate", prim::kPrimVirtualDataset);
// Receive
receive_eliminate_ = MakeSubstitution(std::make_shared<ReceiveEliminater>(), "receive_eliminate", prim::kPrimReceive);
// Convert
print_tuple_wrapper_ =
MakeSubstitution(std::make_shared<PrintTupleWrapper>(), "print_tuple_wrapper", prim::kPrimPrint);

@ -99,6 +99,9 @@ class OptimizeIRPassLib {
// virtual dataset
SubstitutionPtr virtual_dataset_eliminate_;
// Receive
SubstitutionPtr receive_eliminate_;
// Convert
SubstitutionPtr print_tuple_wrapper_;

@ -102,6 +102,25 @@ class VirtualDatasetEliminater : public AnfVisitor {
void Visit(const AnfNodePtr &) override {}
};
// {prim::kPrimReceive, X} -> prim::kPrimReceive
class ReceiveEliminater : public AnfVisitor {
public:
AnfNodePtr operator()(const OptimizerPtr &, const AnfNodePtr &node) override {
if (!IsPrimitiveCNode(node, prim::kPrimReceive) || node->func_graph() == nullptr) {
return nullptr;
}
auto cnode = node->cast<CNodePtr>();
MS_EXCEPTION_IF_NULL(cnode);
if (cnode->inputs().size() == 1) {
return nullptr;
}
std::vector<AnfNodePtr> args = {cnode->input(0)};
return node->func_graph()->NewCNode(args);
}
void Visit(const AnfNodePtr &) override {}
};
// {prim::kPrimSameTypeShape, X, Y} -> X
class SameEliminater : public AnfVisitor {
public:

@ -201,7 +201,7 @@ OptPassGroupMap GetOptPassesB(const opt::irpass::OptimizeIRPassLib &irpass) {
{irpass.zero_like_fill_zero_, irpass.item_tuple_eliminate_, irpass.float_tuple_getitem_switch_,
irpass.reset_defer_inline_, irpass.inline_, irpass.special_op_eliminate_, irpass.get_make_ref_eliminate_,
irpass.incorporate_env_getitem_, irpass.incorporate_env_getitem_switch_, irpass.env_get_item_eliminate_,
irpass.incorporate_env_getitem_switch_layer_, irpass.value_based_eliminate_});
irpass.incorporate_env_getitem_switch_layer_, irpass.value_based_eliminate_, irpass.receive_eliminate_});
opt::OptPassConfig b_2 = opt::OptPassConfig({
irpass.replace_refkey_by_param_,
irpass.make_ref_eliminate_,

@ -185,6 +185,7 @@ inline const PrimitivePtr kPrimSGD = std::make_shared<Primitive>("SGD");
inline const PrimitivePtr kPrimMirror = std::make_shared<Primitive>("_MirrorOperator");
inline const PrimitivePtr kPrimVirtualDiv = std::make_shared<Primitive>("_VirtualDiv");
inline const PrimitivePtr kPrimVirtualDataset = std::make_shared<Primitive>("_VirtualDataset");
inline const PrimitivePtr kPrimReceive = std::make_shared<Primitive>("_Receive");
inline const PrimitivePtr kPrimAllReduce = std::make_shared<Primitive>("AllReduce");
inline const PrimitivePtr kPrimAllSwap = std::make_shared<Primitive>("AllSwap");
inline const PrimitivePtr kPrimBroadcast = std::make_shared<Primitive>("Broadcast");

Loading…
Cancel
Save