|
|
@ -47,7 +47,7 @@ void SetAttrsForFusionNode(const AnfNodePtr &sub_anf, const AnfNodePtr &fusion_n
|
|
|
|
|
|
|
|
|
|
|
|
const BaseRef ConfusionSoftmaxGradRule::DefinePattern() const {
|
|
|
|
const BaseRef ConfusionSoftmaxGradRule::DefinePattern() const {
|
|
|
|
return VectorRef(
|
|
|
|
return VectorRef(
|
|
|
|
{prim::kPrimSub, input0_, VectorRef({prim::kPrimReduceSum, VectorRef({prim::kPrimMul, input0_, input1_})})});
|
|
|
|
{prim::kPrimSub, input0_, VectorRef({prim::kPrimReduceSum, VectorRef({prim::kPrimMul, input1_, input0_})})});
|
|
|
|
}
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
|
|
const AnfNodePtr ConfusionSoftmaxGradRule::Process(const FuncGraphPtr &graph, const AnfNodePtr &node,
|
|
|
|
const AnfNodePtr ConfusionSoftmaxGradRule::Process(const FuncGraphPtr &graph, const AnfNodePtr &node,
|
|
|
|