|
|
|
@ -291,7 +291,7 @@ const AnfNodePtr FusedBatchNormFusion::Process(const FuncGraphPtr &func_graph, c
|
|
|
|
|
return bn_training_update_outputs[0];
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
const BaseRef FusedBatchNormMixPrecisionFusion::DefinePattern() const {
|
|
|
|
|
const BaseRef FusedBatchNormMixPrecisionFusion0::DefinePattern() const {
|
|
|
|
|
std::shared_ptr<Var> Xs = std::make_shared<SeqVar>();
|
|
|
|
|
VarPtr index0 = std::make_shared<CondVar>(IsC);
|
|
|
|
|
VarPtr index1 = std::make_shared<CondVar>(IsC);
|
|
|
|
@ -313,5 +313,28 @@ const BaseRef FusedBatchNormMixPrecisionFusion::DefinePattern() const {
|
|
|
|
|
VectorRef depend0 = VectorRef({prim::kPrimDepend, tuple_getitem0, assign_sub0});
|
|
|
|
|
return VectorRef({prim::kPrimDepend, depend0, assign_sub1});
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
const BaseRef FusedBatchNormMixPrecisionFusion1::DefinePattern() const {
|
|
|
|
|
std::shared_ptr<Var> Xs = std::make_shared<SeqVar>();
|
|
|
|
|
VarPtr index0 = std::make_shared<CondVar>(IsC);
|
|
|
|
|
VarPtr index1 = std::make_shared<CondVar>(IsC);
|
|
|
|
|
VarPtr index2 = std::make_shared<CondVar>(IsC);
|
|
|
|
|
VectorRef batch_norm = VectorRef({batch_norm_var_, data_input0_var_, data_input1_var_, data_input2_var_, Xs});
|
|
|
|
|
VectorRef tuple_getitem0 = VectorRef({prim::kPrimTupleGetItem, batch_norm, index0});
|
|
|
|
|
VectorRef tuple_getitem1 = VectorRef({prim::kPrimTupleGetItem, batch_norm, index1});
|
|
|
|
|
VectorRef tuple_getitem2 = VectorRef({prim::kPrimTupleGetItem, batch_norm, index2});
|
|
|
|
|
VectorRef cast_variable_input0 = VectorRef({prim::kPrimCast, variable_input0_var_});
|
|
|
|
|
VectorRef cast_variable_input1 = VectorRef({prim::kPrimCast, variable_input1_var_});
|
|
|
|
|
VectorRef sub0 = VectorRef({prim::kPrimSub, cast_variable_input0, tuple_getitem1});
|
|
|
|
|
VectorRef sub1 = VectorRef({prim::kPrimSub, cast_variable_input1, tuple_getitem2});
|
|
|
|
|
VectorRef cast0 = VectorRef({prim::kPrimCast, sub0});
|
|
|
|
|
VectorRef cast1 = VectorRef({prim::kPrimCast, sub1});
|
|
|
|
|
VectorRef mul0 = VectorRef({prim::kPrimMul, cast0, constant_input0_var_});
|
|
|
|
|
VectorRef mul1 = VectorRef({prim::kPrimMul, cast1, constant_input1_var_});
|
|
|
|
|
VectorRef assign_sub0 = VectorRef({prim::kPrimAssignSub, variable_input0_var_, mul0});
|
|
|
|
|
VectorRef assign_sub1 = VectorRef({prim::kPrimAssignSub, variable_input1_var_, mul1});
|
|
|
|
|
VectorRef depend0 = VectorRef({prim::kPrimDepend, tuple_getitem0, assign_sub0});
|
|
|
|
|
return VectorRef({prim::kPrimDepend, depend0, assign_sub1});
|
|
|
|
|
}
|
|
|
|
|
} // namespace opt
|
|
|
|
|
} // namespace mindspore
|
|
|
|
|